Skip to content

Commit f4ee204

Browse files
authored
Implementing stable 1.0.0 geff release (#167)
* working version with vlen mask saving * updating inner type computation
1 parent 5007e4a commit f4ee204

File tree

5 files changed

+63
-29
lines changed

5 files changed

+63
-29
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ dependencies = [
4848
"ilpy >= 0.5.1",
4949
"pyarrow",
5050
"bidict>=0.23.1",
51-
"geff @ git+https://github.com/live-image-tracking-tools/geff.git@b751718f81d107e1fdda2df2afb62253039c137b",
51+
"geff>=1.0.0",
5252
]
5353

5454
[project.optional-dependencies]

src/tracksdata/graph/_base_graph.py

Lines changed: 48 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,18 @@
55
from pathlib import Path
66
from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload
77

8+
import geff
89
import numpy as np
910
import polars as pl
1011
import rustworkx as rx
11-
from geff import GeffMetadata
12-
from geff.metadata_schema import Axis
13-
from geff.rustworkx.io import read_rx
14-
from geff.write_arrays import write_arrays
12+
from geff.core_io import construct_var_len_props, write_arrays
13+
from geff_spec import Axis, PropMetadata
1514
from numpy.typing import ArrayLike
1615
from zarr.storage import StoreLike
1716

1817
from tracksdata.attrs import AttrComparison, NodeAttr
1918
from tracksdata.constants import DEFAULT_ATTR_KEYS
20-
from tracksdata.utils._dtypes import column_to_bytes, column_to_numpy
19+
from tracksdata.utils._dtypes import column_to_numpy, polars_dtype_to_numpy_dtype
2120
from tracksdata.utils._logging import LOG
2221
from tracksdata.utils._multiprocessing import multiprocessing_apply
2322

@@ -1184,7 +1183,7 @@ def from_geff(
11841183
from tracksdata.graph import IndexedRXGraph
11851184

11861185
# this performs a roundtrip with the rustworkx graph
1187-
rx_graph, _ = read_rx(geff_store)
1186+
rx_graph, _ = geff.read(geff_store, backend="rustworkx")
11881187

11891188
if not isinstance(rx_graph, rx.PyDiGraph):
11901189
LOG.warning("The graph is not a directed graph, converting to directed graph.")
@@ -1196,6 +1195,16 @@ def from_geff(
11961195
**kwargs,
11971196
)
11981197

1198+
if DEFAULT_ATTR_KEYS.MASK in indexed_graph.node_attr_keys:
1199+
from tracksdata.nodes._mask import Mask
1200+
1201+
# unsafe operation, changing graph content inplace
1202+
for node_attr in indexed_graph.rx_graph.nodes():
1203+
node_attr[DEFAULT_ATTR_KEYS.MASK] = Mask(
1204+
node_attr[DEFAULT_ATTR_KEYS.MASK].astype(bool),
1205+
bbox=node_attr[DEFAULT_ATTR_KEYS.BBOX],
1206+
)
1207+
11991208
if cls == IndexedRXGraph:
12001209
return indexed_graph
12011210

@@ -1204,7 +1213,7 @@ def from_geff(
12041213
def to_geff(
12051214
self,
12061215
geff_store: StoreLike,
1207-
geff_metadata: GeffMetadata | None = None,
1216+
geff_metadata: geff.GeffMetadata | None = None,
12081217
zarr_format: Literal[2, 3] = 3,
12091218
) -> None:
12101219
"""
@@ -1225,9 +1234,12 @@ def to_geff(
12251234
"""
12261235

12271236
node_attrs = self.node_attrs()
1237+
node_ids = node_attrs[DEFAULT_ATTR_KEYS.NODE_ID].to_numpy()
1238+
node_attrs = node_attrs.drop(DEFAULT_ATTR_KEYS.NODE_ID)
12281239

1229-
if DEFAULT_ATTR_KEYS.MASK in node_attrs.columns:
1230-
node_attrs = column_to_bytes(node_attrs, DEFAULT_ATTR_KEYS.MASK)
1240+
edge_attrs = self.edge_attrs().drop(DEFAULT_ATTR_KEYS.EDGE_ID)
1241+
edge_ids = edge_attrs.select(DEFAULT_ATTR_KEYS.EDGE_SOURCE, DEFAULT_ATTR_KEYS.EDGE_TARGET).to_numpy()
1242+
edge_attrs = edge_attrs.drop(DEFAULT_ATTR_KEYS.EDGE_SOURCE, DEFAULT_ATTR_KEYS.EDGE_TARGET)
12311243

12321244
if geff_metadata is None:
12331245
axes = [Axis(name=DEFAULT_ATTR_KEYS.T, type="time")]
@@ -1240,37 +1252,47 @@ def to_geff(
12401252
else:
12411253
track_node_props = None
12421254

1243-
geff_metadata = GeffMetadata(
1255+
node_props_metadata = {
1256+
k: PropMetadata(
1257+
identifier=k,
1258+
dtype=polars_dtype_to_numpy_dtype(v.dtype) if k != DEFAULT_ATTR_KEYS.MASK else np.uint64,
1259+
varlength=k == DEFAULT_ATTR_KEYS.MASK,
1260+
)
1261+
for k, v in node_attrs.to_dict().items()
1262+
}
1263+
edge_props_metadata = {
1264+
k: PropMetadata(identifier=k, dtype=polars_dtype_to_numpy_dtype(v.dtype))
1265+
for k, v in edge_attrs.to_dict().items()
1266+
}
1267+
1268+
geff_metadata = geff.GeffMetadata(
12441269
directed=True,
12451270
axes=axes,
1271+
node_props_metadata=node_props_metadata,
1272+
edge_props_metadata=edge_props_metadata,
12461273
track_node_props=track_node_props,
12471274
)
12481275

1249-
edge_attrs = self.edge_attrs().drop(DEFAULT_ATTR_KEYS.EDGE_ID)
1250-
12511276
node_dict = {
1252-
k: (column_to_numpy(v), None) for k, v in node_attrs.drop(DEFAULT_ATTR_KEYS.NODE_ID).to_dict().items()
1277+
k: {"values": column_to_numpy(v), "missing": None}
1278+
for k, v in node_attrs.to_dict().items()
1279+
if k != DEFAULT_ATTR_KEYS.MASK
12531280
}
12541281

1255-
edge_ids = edge_attrs.select(DEFAULT_ATTR_KEYS.EDGE_SOURCE, DEFAULT_ATTR_KEYS.EDGE_TARGET).to_numpy()
1256-
1257-
edge_dict = {
1258-
k: (column_to_numpy(v), None)
1259-
for k, v in edge_attrs.drop(
1260-
DEFAULT_ATTR_KEYS.EDGE_SOURCE,
1261-
DEFAULT_ATTR_KEYS.EDGE_TARGET,
1282+
if DEFAULT_ATTR_KEYS.MASK in node_attrs.columns:
1283+
node_dict[DEFAULT_ATTR_KEYS.MASK] = construct_var_len_props(
1284+
[mask.mask.astype(np.uint64) for mask in node_attrs[DEFAULT_ATTR_KEYS.MASK]]
12621285
)
1263-
.to_dict()
1264-
.items()
1265-
}
1286+
1287+
edge_dict = {k: {"values": column_to_numpy(v), "missing": None} for k, v in edge_attrs.to_dict().items()}
12661288

12671289
write_arrays(
12681290
geff_store,
1269-
metadata=geff_metadata,
1270-
node_ids=node_attrs[DEFAULT_ATTR_KEYS.NODE_ID].to_numpy(),
1291+
node_ids=node_ids.astype(np.uint64),
12711292
node_props=node_dict,
1272-
edge_ids=edge_ids,
1293+
edge_ids=edge_ids.astype(np.uint64),
12731294
edge_props=edge_dict,
1295+
metadata=geff_metadata,
12741296
zarr_format=zarr_format,
12751297
)
12761298

src/tracksdata/graph/_test/test_graph_backends.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from tracksdata.attrs import EdgeAttr, NodeAttr
1010
from tracksdata.constants import DEFAULT_ATTR_KEYS
11-
from tracksdata.graph import BaseGraph, RustWorkXGraph, SQLGraph
11+
from tracksdata.graph import BaseGraph, IndexedRXGraph, RustWorkXGraph, SQLGraph
1212
from tracksdata.io._numpy_array import from_array
1313
from tracksdata.nodes._mask import Mask
1414

@@ -1779,7 +1779,7 @@ def test_geff_roundtrip(graph_backend: BaseGraph) -> None:
17791779

17801780
graph_backend.to_geff(geff_store=output_store)
17811781

1782-
geff_graph = RustWorkXGraph.from_geff(output_store)
1782+
geff_graph = IndexedRXGraph.from_geff(output_store)
17831783

17841784
assert geff_graph.num_nodes == 3
17851785
assert geff_graph.num_edges == 2
@@ -1789,6 +1789,9 @@ def test_geff_roundtrip(graph_backend: BaseGraph) -> None:
17891789
assert set(graph_backend.node_attr_keys) == set(geff_graph.node_attr_keys)
17901790
assert set(graph_backend.edge_attr_keys) == set(geff_graph.edge_attr_keys)
17911791

1792+
for node_id in geff_graph.node_ids():
1793+
assert geff_graph[node_id].to_dict() == graph_backend[node_id].to_dict()
1794+
17921795
assert rx.is_isomorphic(
17931796
rx_graph,
17941797
geff_graph.rx_graph,

src/tracksdata/nodes/_mask.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from collections.abc import Sequence
22
from functools import cached_property, lru_cache
3+
from typing import Any
34

45
import blosc2
56
import numpy as np
@@ -278,6 +279,11 @@ def from_coordinates(
278279

279280
return cls(mask, bbox)
280281

282+
def __eq__(self, other: Any) -> bool:
283+
if not isinstance(other, Mask):
284+
return False
285+
return np.array_equal(self.bbox, other.bbox) and np.array_equal(self.mask, other.mask)
286+
281287

282288
class MaskDiskAttrs(GenericFuncNodeAttrs):
283289
"""

src/tracksdata/utils/_dtypes.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ def polars_dtype_to_numpy_dtype(polars_dtype: DataType) -> np.dtype:
4848
np.dtype
4949
The numpy dtype.
5050
"""
51+
if isinstance(polars_dtype, pl.Array | pl.List):
52+
polars_dtype = polars_dtype.inner
53+
5154
try:
5255
return _POLARS_DTYPE_TO_NUMPY_DTYPE[polars_dtype]
5356
except KeyError as e:

0 commit comments

Comments
 (0)