Skip to content

Commit ba57d79

Browse files
committed
wip: find stale edges and their latest nodes
1 parent ed467ed commit ba57d79

File tree

6 files changed

+210
-25
lines changed

6 files changed

+210
-25
lines changed

pychunkedgraph/graph/chunkedgraph.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -940,6 +940,11 @@ def get_parent_chunk_id(
940940
self.meta, node_or_chunk_id, parent_layer
941941
)
942942

def get_parent_chunk_id_multiple(self, node_or_chunk_ids: typing.Sequence):
    """Delegate to ``chunk_hierarchy`` for the parent chunk IDs of many nodes.

    Assumes all IDs are at the same layer (enforced by the helper).
    """
    return chunk_hierarchy.get_parent_chunk_id_multiple(self.meta, node_or_chunk_ids)
943948
def get_parent_chunk_ids(self, node_or_chunk_id: basetypes.NODE_ID):
944949
return chunk_hierarchy.get_parent_chunk_ids(self.meta, node_or_chunk_id)
945950

@@ -984,3 +989,38 @@ def get_operation_ids(self, node_ids: typing.Sequence):
984989
except KeyError:
985990
...
986991
return result
992+
993+
def get_supervoxels(self, node_ids):
    """Returns the first supervoxel found for each node_id."""
    # Walk down the hierarchy by always following the first child of each
    # pending node until a layer-1 (supervoxel) child is reached.
    first_sv = {}
    originals = np.copy(node_ids)
    pending = np.copy(node_ids)
    children_d = self.get_children(node_ids)
    while True:
        pending = np.array(
            [children_d[node][0] for node in pending], dtype=basetypes.NODE_ID
        )
        at_layer1 = self.get_chunk_layers(pending) == 1
        first_sv.update(zip(node_ids[at_layer1], pending[at_layer1]))
        node_ids = node_ids[~at_layer1]
        pending = pending[~at_layer1]
        if pending.size == 0:
            break
        children_d = self.get_children(pending)
    # Re-emit results in the caller's original order.
    return np.array([first_sv[node] for node in originals], dtype=basetypes.NODE_ID)
1012+
1013+
def get_chunk_layers_and_coordinates(self, node_or_chunk_ids: typing.Sequence):
1014+
"""
1015+
Helper function that wraps get chunk layer and coordinates for nodes at any layer.
1016+
"""
1017+
node_or_chunk_ids = np.array(node_or_chunk_ids, dtype=basetypes.NODE_ID)
1018+
layers = self.get_chunk_layers(node_or_chunk_ids)
1019+
chunk_coords = np.zeros(shape=(len(node_or_chunk_ids), 3))
1020+
for _layer in np.unique(layers):
1021+
mask = layers == _layer
1022+
_nodes = node_or_chunk_ids[mask]
1023+
chunk_coords[mask] = chunk_utils.get_chunk_coordinates_multiple(
1024+
self.meta, _nodes
1025+
)
1026+
return layers, chunk_coords

pychunkedgraph/graph/chunks/hierarchy.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def get_children_chunk_ids(
4343
else:
4444
children_coords = get_children_chunk_coords(meta, layer, (x, y, z))
4545
children_chunk_ids = []
46-
for (x, y, z) in children_coords:
46+
for x, y, z in children_coords:
4747
children_chunk_ids.append(
4848
utils.get_chunk_id(meta, layer=layer - 1, x=x, y=y, z=z)
4949
)
@@ -62,6 +62,19 @@ def get_parent_chunk_id(
6262
return utils.get_chunk_id(meta, layer=parent_layer, x=x, y=y, z=z)
6363

6464

65+
def get_parent_chunk_id_multiple(
    meta: ChunkedGraphMeta, node_or_chunk_ids: np.ndarray
) -> np.ndarray:
    """Parent chunk IDs for multiple nodes. Assumes nodes at same layer."""

    unique_layers = np.unique(utils.get_chunk_layers(meta, node_or_chunk_ids))
    assert unique_layers.size == 1, unique_layers
    # Parent chunks live one layer up; their coordinates are the child
    # coordinates collapsed by the graph fanout.
    parent_coords = (
        utils.get_chunk_coordinates_multiple(meta, node_or_chunk_ids)
        // meta.graph_config.FANOUT
    )
    return utils.get_chunk_ids_from_coords(
        meta, layer=unique_layers[0] + 1, coords=parent_coords
    )
76+
77+
6578
def get_parent_chunk_ids(
6679
meta: ChunkedGraphMeta, node_or_chunk_id: np.uint64
6780
) -> np.ndarray:

pychunkedgraph/graph/edges/__init__.py

Lines changed: 129 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,20 @@
33
"""
44

55
from collections import namedtuple
6+
import datetime
67
from os import environ
7-
from typing import Optional
8+
from copy import copy
9+
from typing import Iterable, Optional
810

911
import numpy as np
1012
import tensorstore as ts
1113
import zstandard as zstd
1214
from graph_tool import Graph
1315

16+
from pychunkedgraph.graph import types
17+
from pychunkedgraph.graph.chunks import utils as chunk_utils
18+
from pychunkedgraph.graph.utils import basetypes
19+
1420
from ..utils import basetypes
1521

1622

@@ -189,3 +195,125 @@ def get_edges(source: str, nodes: np.ndarray) -> Edges:
189195
affinities=np.concatenate(affinities),
190196
areas=np.concatenate(areas),
191197
)
198+
199+
200+
def get_stale_nodes(
    cg, edge_nodes: Iterable[basetypes.NODE_ID], parent_ts: datetime.datetime = None
):
    """
    Checks to see if partner nodes in edges (edges[:,1]) are stale.
    This is done by getting a supervoxel of the node and check
    if it has a new parent at the same layer as the node.

    Returns a tuple ``(stale_nodes, edge_supervoxels)``.
    """
    edge_supervoxels = cg.get_supervoxels(edge_nodes)
    # nodes can be at different layers due to skip connections
    edge_nodes_layers = cg.get_chunk_layers(edge_nodes)
    stale_nodes = [types.empty_1d]
    for layer in np.unique(edge_nodes_layers):
        layer_mask = edge_nodes_layers == layer
        layer_nodes = edge_nodes[layer_mask]
        # BUGFIX: query parents only for this layer's supervoxels; the
        # original passed ALL edge_supervoxels, so the elementwise
        # comparison below was misaligned (or raised a shape error)
        # whenever the nodes spanned more than one layer.
        latest_nodes = cg.get_roots(
            edge_supervoxels[layer_mask],
            stop_layer=layer,
            ceil=False,
            time_stamp=parent_ts,
        )
        stale_mask = layer_nodes != latest_nodes
        stale_nodes.append(layer_nodes[stale_mask])
    return np.concatenate(stale_nodes), edge_supervoxels
223+
224+
225+
def get_latest_edges(cg, stale_edges: Iterable, edge_layers: Iterable) -> dict:
    """
    For each of stale_edges [[`node`, `partner`]], get their L2 edge equivalent.
    Then get supervoxels of those L2 IDs and get parent(s) at `node` level.
    These parents would be the new identities for the stale `partner`.
    """
    # NOTE(review): this function is work-in-progress (commit message "wip") —
    # it ends in a debug `print` and never returns the `dict` promised by the
    # annotation. Do not call it expecting a result yet.

    _nodes = np.unique(stale_edges)
    # Per-node creation timestamps; used below to query cross-chunk edges
    # as of the partner node's timestamp.
    nodes_ts_map = dict(zip(_nodes, cg.get_node_timestamps(_nodes, return_numpy=False)))
    layers, coords = cg.get_chunk_layers_and_coordinates(_nodes)
    layers_d = dict(zip(_nodes, layers))
    coords_d = dict(zip(_nodes, coords))

    def _get_normalized_coords(node_a, node_b) -> tuple:
        # Bring both nodes' chunk coordinates to a common (max) layer so
        # they are comparable; returns (max_layer, coord_a, coord_b).
        max_layer = layers_d[node_a]
        coord_a, coord_b = coords_d[node_a], coords_d[node_b]
        if layers_d[node_a] != layers_d[node_b]:
            # normalize if nodes are not from the same layer
            max_layer = max(layers_d[node_a], layers_d[node_b])
            chunk_a = cg.get_parent_chunk_id(node_a, parent_layer=max_layer)
            chunk_b = cg.get_parent_chunk_id(node_b, parent_layer=max_layer)
            coord_a, coord_b = cg.get_chunk_coordinates_multiple([chunk_a, chunk_b])
        return max_layer, coord_a, coord_b

    def _get_l2chunkids_along_boundary(max_layer, coord_a, coord_b):
        # L2 chunk IDs on both sides of the face shared by the two chunks.
        # Assumes the chunks are face-adjacent (exactly one axis differs).
        direction = coord_a - coord_b
        axis = np.flatnonzero(direction)
        assert len(axis) == 1, f"{direction}, {coord_a}, {coord_b}"
        axis = axis[0]
        children_a = chunk_utils.get_bounding_children_chunks(
            cg.meta, max_layer, coord_a, children_layer=2
        )
        children_b = chunk_utils.get_bounding_children_chunks(
            cg.meta, max_layer, coord_b, children_layer=2
        )
        # `mid` is the boundary plane in L2 chunk coordinates; 2**(max_layer-2)
        # converts a layer `max_layer` coordinate to the L2 grid
        # (presumably assumes fanout == 2 — TODO confirm).
        if direction[axis] > 0:
            mid = coord_a[axis] * 2 ** (max_layer - 2)
            l2chunks_a = children_a[children_a[:, axis] == mid]
            l2chunks_b = children_b[children_b[:, axis] == mid - 1]
        else:
            mid = coord_b[axis] * 2 ** (max_layer - 2)
            l2chunks_a = children_a[children_a[:, axis] == mid - 1]
            l2chunks_b = children_b[children_b[:, axis] == mid]

        l2chunk_ids_a = chunk_utils.get_chunk_ids_from_coords(cg.meta, 2, l2chunks_a)
        l2chunk_ids_b = chunk_utils.get_chunk_ids_from_coords(cg.meta, 2, l2chunks_b)
        return l2chunk_ids_a, l2chunk_ids_b

    def _get_filtered_l2ids(node_a, node_b, chunks_map):
        # Descend from each node, keeping only children whose chunks lie in
        # chunks_map[node]; collect the L2 IDs reached along the way.
        def _filter(node):
            result = []
            children = cg.get_children(node)
            while True:
                chunk_ids = cg.get_chunk_ids_from_node_ids(children)
                mask = np.isin(chunk_ids, chunks_map[node])
                children = children[mask]

                mask = cg.get_chunk_layers(children) == 2
                result.append(children[mask])

                # continue descending through children above layer 2
                mask = cg.get_chunk_layers(children) > 2
                if children[mask].size == 0:
                    break
                children = cg.get_children(children[mask], flatten=True)
            return np.concatenate(result)

        return _filter(node_a), _filter(node_b)

    chunks_map = {}
    for edge_layer, _edge in zip(edge_layers, stale_edges):
        node_a, node_b = _edge
        mlayer, coord_a, coord_b = _get_normalized_coords(node_a, node_b)
        chunks_a, chunks_b = _get_l2chunkids_along_boundary(mlayer, coord_a, coord_b)

        # Accumulate boundary chunk IDs at every layer from 2 up to mlayer,
        # so descent in _get_filtered_l2ids can match chunks at any level.
        chunks_map[node_a] = []
        chunks_map[node_b] = []
        _layer = 2
        while _layer < mlayer:
            chunks_map[node_a].append(chunks_a)
            chunks_map[node_b].append(chunks_b)
            chunks_a = np.unique(cg.get_parent_chunk_id_multiple(chunks_a))
            chunks_b = np.unique(cg.get_parent_chunk_id_multiple(chunks_b))
            _layer += 1
        chunks_map[node_a] = np.concatenate(chunks_map[node_a])
        chunks_map[node_b] = np.concatenate(chunks_map[node_b])

        l2ids_a, l2ids_b = _get_filtered_l2ids(node_a, node_b, chunks_map)
        # Cross-chunk edges as of node_b's timestamp: the stale partner's
        # replacement must exist at that point in time.
        edges_d = cg.get_cross_chunk_edges(node_ids=l2ids_a, time_stamp=nodes_ts_map[node_b])

        _edges = []
        for v in edges_d.values():
            _edges.append(v.get(edge_layer, types.empty_2d))
        _edges = np.concatenate(_edges)
        mask = np.isin(_edges[:,1], l2ids_b)
        # TODO(wip): debug output — should collect these into the result dict
        # and return it (annotation promises `-> dict`).
        print(_edges[mask])

pychunkedgraph/graph/edges/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def categorize_edges_v2(
135135

136136

137137
def get_cross_chunk_edges_layer(meta: ChunkedGraphMeta, cross_edges: Iterable):
138-
"""Computes the layer in which a cross chunk edge becomes relevant.
138+
"""Computes the layer in which an atomic cross chunk edge becomes relevant.
139139
I.e. if a cross chunk edge links two nodes in layer 4 this function
140140
returns 3.
141141
:param cross_edges: n x 2 array

pychunkedgraph/graph/edits.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from . import types
1616
from . import attributes
1717
from . import cache as cache_utils
18+
from .edges import get_stale_nodes
1819
from .edges.utils import concatenate_cross_edge_dicts
1920
from .edges.utils import merge_cross_edge_dicts
2021
from .utils import basetypes
@@ -248,6 +249,7 @@ def add_edges(
248249
new_roots = create_parents.run()
249250
sanity_check(cg, new_roots, operation_id)
250251
create_parents.create_new_entries()
252+
raise ValueError("success merge")
251253
return new_roots, new_l2_ids, create_parents.new_entries
252254

253255

@@ -378,6 +380,7 @@ def remove_edges(
378380
new_roots = create_parents.run()
379381
sanity_check(cg, new_roots, operation_id)
380382
create_parents.create_new_entries()
383+
raise ValueError("success split")
381384
return new_roots, new_l2_ids, create_parents.new_entries
382385

383386

@@ -497,25 +500,6 @@ def _update_neighbor_cross_edges(
497500
return updated_entries
498501

499502

500-
def get_supervoxels(cg, node_ids):
501-
"""Returns the first supervoxel found for each node_id."""
502-
result = {}
503-
node_ids_copy = np.copy(node_ids)
504-
children = np.copy(node_ids)
505-
children_d = cg.get_children(node_ids)
506-
while True:
507-
children = [children_d[k][0] for k in children]
508-
children = np.array(children, dtype=basetypes.NODE_ID)
509-
mask = cg.get_chunk_layers(children) == 1
510-
result.update([(node, sv) for node, sv in zip(node_ids[mask], children[mask])])
511-
node_ids = node_ids[~mask]
512-
children = children[~mask]
513-
if children.size == 0:
514-
break
515-
children_d = cg.get_children(children)
516-
return np.array([result[k] for k in node_ids_copy], dtype=basetypes.NODE_ID)
517-
518-
519503
class CreateParentNodes:
520504
def __init__(
521505
self,
@@ -605,8 +589,28 @@ def _update_cross_edge_cache(self, parent, children):
605589
children, time_stamp=self._last_successful_ts
606590
)
607591
cx_edges_d = concatenate_cross_edge_dicts(cx_edges_d.values())
608-
edge_nodes = np.unique(np.concatenate([*cx_edges_d.values(), types.empty_2d]))
609-
edge_supervoxels = get_supervoxels(self.cg, edge_nodes)
592+
593+
_cx_edges = [types.empty_2d]
594+
_edge_layers = [types.empty_1d]
595+
for k, v in cx_edges_d.items():
596+
_cx_edges.append(v)
597+
_edge_layers.append([k] * len(v))
598+
_cx_edges = np.concatenate(_cx_edges)
599+
_edge_layers = np.concatenate(_edge_layers, dtype=int)
600+
601+
edge_nodes = np.unique(_cx_edges)
602+
stale_nodes, edge_supervoxels = get_stale_nodes(
603+
self.cg, edge_nodes, parent_ts=self._last_successful_ts
604+
)
605+
stale_mask = np.isin(edge_nodes, stale_nodes)
606+
607+
if np.any(stale_mask):
608+
mask = _cx_edges[:, 1] == stale_nodes
609+
stale_edges = _cx_edges[mask]
610+
stalte_edge_layers = _edge_layers[mask]
611+
# latest_parents = get_latest_nodes(self.cg, stale_edges)
612+
raise ValueError()
613+
610614
edge_parents = self.cg.get_roots(
611615
edge_supervoxels,
612616
stop_layer=parent_layer,

pychunkedgraph/ingest/upgrade/parent_layer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def update_cross_edges(cg: ChunkedGraph, layer, node, node_ts, earliest_ts) -> l
110110
if edges.size == 0:
111111
continue
112112
nodes = np.unique(edges[:, 1])
113-
svs = get_supervoxels(cg, nodes)
113+
svs = cg.get_supervoxels(nodes)
114114
parents = cg.get_roots(svs, time_stamp=ts, stop_layer=layer, ceil=False)
115115
edge_parents_d = dict(zip(nodes, parents))
116116
val_dict = {}

0 commit comments

Comments
 (0)