55from pathlib import Path
66from typing import TYPE_CHECKING , Any , Literal , TypeVar , overload
77
8+ import geff
89import numpy as np
910import polars as pl
1011import rustworkx as rx
11- from geff import GeffMetadata
12- from geff .metadata_schema import Axis
13- from geff .rustworkx .io import read_rx
14- from geff .write_arrays import write_arrays
12+ from geff .core_io import construct_var_len_props , write_arrays
13+ from geff_spec import Axis , PropMetadata
1514from numpy .typing import ArrayLike
1615from zarr .storage import StoreLike
1716
1817from tracksdata .attrs import AttrComparison , NodeAttr
1918from tracksdata .constants import DEFAULT_ATTR_KEYS
20- from tracksdata .utils ._dtypes import column_to_bytes , column_to_numpy
19+ from tracksdata .utils ._dtypes import column_to_numpy , polars_dtype_to_numpy_dtype
2120from tracksdata .utils ._logging import LOG
2221from tracksdata .utils ._multiprocessing import multiprocessing_apply
2322
@@ -1184,7 +1183,7 @@ def from_geff(
11841183 from tracksdata .graph import IndexedRXGraph
11851184
11861185 # this performs a roundtrip with the rustworkx graph
1187- rx_graph , _ = read_rx (geff_store )
1186+ rx_graph , _ = geff . read (geff_store , backend = "rustworkx" )
11881187
11891188 if not isinstance (rx_graph , rx .PyDiGraph ):
11901189 LOG .warning ("The graph is not a directed graph, converting to directed graph." )
@@ -1196,6 +1195,16 @@ def from_geff(
11961195 ** kwargs ,
11971196 )
11981197
1198+ if DEFAULT_ATTR_KEYS .MASK in indexed_graph .node_attr_keys :
1199+ from tracksdata .nodes ._mask import Mask
1200+
1201+ # unsafe operation, changing graph content inplace
1202+ for node_attr in indexed_graph .rx_graph .nodes ():
1203+ node_attr [DEFAULT_ATTR_KEYS .MASK ] = Mask (
1204+ node_attr [DEFAULT_ATTR_KEYS .MASK ].astype (bool ),
1205+ bbox = node_attr [DEFAULT_ATTR_KEYS .BBOX ],
1206+ )
1207+
11991208 if cls == IndexedRXGraph :
12001209 return indexed_graph
12011210
@@ -1204,7 +1213,7 @@ def from_geff(
12041213 def to_geff (
12051214 self ,
12061215 geff_store : StoreLike ,
1207- geff_metadata : GeffMetadata | None = None ,
1216+ geff_metadata : geff . GeffMetadata | None = None ,
12081217 zarr_format : Literal [2 , 3 ] = 3 ,
12091218 ) -> None :
12101219 """
@@ -1225,9 +1234,12 @@ def to_geff(
12251234 """
12261235
12271236 node_attrs = self .node_attrs ()
1237+ node_ids = node_attrs [DEFAULT_ATTR_KEYS .NODE_ID ].to_numpy ()
1238+ node_attrs = node_attrs .drop (DEFAULT_ATTR_KEYS .NODE_ID )
12281239
1229- if DEFAULT_ATTR_KEYS .MASK in node_attrs .columns :
1230- node_attrs = column_to_bytes (node_attrs , DEFAULT_ATTR_KEYS .MASK )
1240+ edge_attrs = self .edge_attrs ().drop (DEFAULT_ATTR_KEYS .EDGE_ID )
1241+ edge_ids = edge_attrs .select (DEFAULT_ATTR_KEYS .EDGE_SOURCE , DEFAULT_ATTR_KEYS .EDGE_TARGET ).to_numpy ()
1242+ edge_attrs = edge_attrs .drop (DEFAULT_ATTR_KEYS .EDGE_SOURCE , DEFAULT_ATTR_KEYS .EDGE_TARGET )
12311243
12321244 if geff_metadata is None :
12331245 axes = [Axis (name = DEFAULT_ATTR_KEYS .T , type = "time" )]
@@ -1240,37 +1252,47 @@ def to_geff(
12401252 else :
12411253 track_node_props = None
12421254
1243- geff_metadata = GeffMetadata (
1255+ node_props_metadata = {
1256+ k : PropMetadata (
1257+ identifier = k ,
1258+ dtype = polars_dtype_to_numpy_dtype (v .dtype ) if k != DEFAULT_ATTR_KEYS .MASK else np .uint64 ,
1259+ varlength = k == DEFAULT_ATTR_KEYS .MASK ,
1260+ )
1261+ for k , v in node_attrs .to_dict ().items ()
1262+ }
1263+ edge_props_metadata = {
1264+ k : PropMetadata (identifier = k , dtype = polars_dtype_to_numpy_dtype (v .dtype ))
1265+ for k , v in edge_attrs .to_dict ().items ()
1266+ }
1267+
1268+ geff_metadata = geff .GeffMetadata (
12441269 directed = True ,
12451270 axes = axes ,
1271+ node_props_metadata = node_props_metadata ,
1272+ edge_props_metadata = edge_props_metadata ,
12461273 track_node_props = track_node_props ,
12471274 )
12481275
1249- edge_attrs = self .edge_attrs ().drop (DEFAULT_ATTR_KEYS .EDGE_ID )
1250-
12511276 node_dict = {
1252- k : (column_to_numpy (v ), None ) for k , v in node_attrs .drop (DEFAULT_ATTR_KEYS .NODE_ID ).to_dict ().items ()
1277+ k : {"values" : column_to_numpy (v ), "missing" : None }
1278+ for k , v in node_attrs .to_dict ().items ()
1279+ if k != DEFAULT_ATTR_KEYS .MASK
12531280 }
12541281
1255- edge_ids = edge_attrs .select (DEFAULT_ATTR_KEYS .EDGE_SOURCE , DEFAULT_ATTR_KEYS .EDGE_TARGET ).to_numpy ()
1256-
1257- edge_dict = {
1258- k : (column_to_numpy (v ), None )
1259- for k , v in edge_attrs .drop (
1260- DEFAULT_ATTR_KEYS .EDGE_SOURCE ,
1261- DEFAULT_ATTR_KEYS .EDGE_TARGET ,
1282+ if DEFAULT_ATTR_KEYS .MASK in node_attrs .columns :
1283+ node_dict [DEFAULT_ATTR_KEYS .MASK ] = construct_var_len_props (
1284+ [mask .mask .astype (np .uint64 ) for mask in node_attrs [DEFAULT_ATTR_KEYS .MASK ]]
12621285 )
1263- .to_dict ()
1264- .items ()
1265- }
1286+
1287+ edge_dict = {k : {"values" : column_to_numpy (v ), "missing" : None } for k , v in edge_attrs .to_dict ().items ()}
12661288
12671289 write_arrays (
12681290 geff_store ,
1269- metadata = geff_metadata ,
1270- node_ids = node_attrs [DEFAULT_ATTR_KEYS .NODE_ID ].to_numpy (),
1291+ node_ids = node_ids .astype (np .uint64 ),
12711292 node_props = node_dict ,
1272- edge_ids = edge_ids ,
1293+ edge_ids = edge_ids . astype ( np . uint64 ) ,
12731294 edge_props = edge_dict ,
1295+ metadata = geff_metadata ,
12741296 zarr_format = zarr_format ,
12751297 )
12761298
0 commit comments