|
3 | 3 | """ |
4 | 4 |
|
5 | 5 | from collections import namedtuple |
| 6 | +import datetime |
6 | 7 | from os import environ |
7 | | -from typing import Optional |
| 8 | +from copy import copy |
| 9 | +from typing import Iterable, Optional |
8 | 10 |
|
9 | 11 | import numpy as np |
10 | 12 | import tensorstore as ts |
11 | 13 | import zstandard as zstd |
12 | 14 | from graph_tool import Graph |
13 | 15 |
|
| 16 | +from pychunkedgraph.graph import types |
| 17 | +from pychunkedgraph.graph.chunks import utils as chunk_utils |
| 18 | +from pychunkedgraph.graph.utils import basetypes |
| 19 | + |
14 | 20 | from ..utils import basetypes |
15 | 21 |
|
16 | 22 |
|
@@ -189,3 +195,125 @@ def get_edges(source: str, nodes: np.ndarray) -> Edges: |
189 | 195 | affinities=np.concatenate(affinities), |
190 | 196 | areas=np.concatenate(areas), |
191 | 197 | ) |
| 198 | + |
| 199 | + |
def get_stale_nodes(
    cg,
    edge_nodes: Iterable[basetypes.NODE_ID],
    parent_ts: Optional[datetime.datetime] = None,
):
    """
    Find nodes in `edge_nodes` that are stale as of `parent_ts`.

    A node is considered stale when a supervoxel of the node has a different
    parent at the node's own layer, i.e. the node has been superseded by a
    newer ID.

    Returns:
        Tuple of (stale_nodes, edge_supervoxels): `stale_nodes` is a 1d array
        containing the stale members of `edge_nodes`; `edge_supervoxels` holds
        one representative supervoxel per input node (same order as input).
    """
    # one representative supervoxel per node, aligned with edge_nodes
    edge_supervoxels = cg.get_supervoxels(edge_nodes)
    # nodes can be at different layers due to skip connections
    edge_nodes_layers = cg.get_chunk_layers(edge_nodes)
    stale_nodes = [types.empty_1d]
    for layer in np.unique(edge_nodes_layers):
        layer_mask = edge_nodes_layers == layer
        layer_nodes = edge_nodes[layer_mask]
        # BUGFIX: query roots only for this layer's supervoxels. The original
        # passed ALL supervoxels, which breaks the element-wise comparison
        # below whenever edge_nodes span more than one layer.
        current_ids = cg.get_roots(
            edge_supervoxels[layer_mask],
            stop_layer=layer,
            ceil=False,
            time_stamp=parent_ts,
        )
        stale_mask = layer_nodes != current_ids
        stale_nodes.append(layer_nodes[stale_mask])
    return np.concatenate(stale_nodes), edge_supervoxels
| 223 | + |
| 224 | + |
def get_latest_edges(cg, stale_edges: Iterable, edge_layers: Iterable) -> dict:
    """
    For each of stale_edges [[`node`, `partner`]], get their L2 edge equivalent.
    Then get supervoxels of those L2 IDs and get parent(s) at `node` level.
    These parents would be the new identities for the stale `partner`.

    Args:
        cg: chunkedgraph instance.
        stale_edges: iterable of [node, partner] pairs whose `partner` is stale.
        edge_layers: per-edge cross-chunk layer, aligned with `stale_edges`.

    Returns:
        Dict mapping each stale edge, keyed as the tuple (node, partner), to a
        2d array of refreshed L2 edges whose partner side lies in the partner
        node's boundary L2 IDs.
        NOTE(review): the original implementation ended in a debug print and
        returned nothing despite the `-> dict` annotation; the aggregation
        shape here is inferred from the docstring — confirm against callers.
    """
    _nodes = np.unique(stale_edges)
    # timestamps of each node's creation; used to query edges as of the
    # partner's time
    nodes_ts_map = dict(zip(_nodes, cg.get_node_timestamps(_nodes, return_numpy=False)))
    layers, coords = cg.get_chunk_layers_and_coordinates(_nodes)
    layers_d = dict(zip(_nodes, layers))
    coords_d = dict(zip(_nodes, coords))

    def _get_normalized_coords(node_a, node_b) -> tuple:
        """Chunk coordinates of both nodes at a common (max) layer."""
        max_layer = layers_d[node_a]
        coord_a, coord_b = coords_d[node_a], coords_d[node_b]
        if layers_d[node_a] != layers_d[node_b]:
            # normalize if nodes are not from the same layer
            max_layer = max(layers_d[node_a], layers_d[node_b])
            chunk_a = cg.get_parent_chunk_id(node_a, parent_layer=max_layer)
            chunk_b = cg.get_parent_chunk_id(node_b, parent_layer=max_layer)
            coord_a, coord_b = cg.get_chunk_coordinates_multiple([chunk_a, chunk_b])
        return max_layer, coord_a, coord_b

    def _get_l2chunkids_along_boundary(max_layer, coord_a, coord_b):
        """L2 chunk IDs on either side of the face shared by the two chunks."""
        direction = coord_a - coord_b
        axis = np.flatnonzero(direction)
        # chunks must be face-adjacent: exactly one axis differs
        assert len(axis) == 1, f"{direction}, {coord_a}, {coord_b}"
        axis = axis[0]
        children_a = chunk_utils.get_bounding_children_chunks(
            cg.meta, max_layer, coord_a, children_layer=2
        )
        children_b = chunk_utils.get_bounding_children_chunks(
            cg.meta, max_layer, coord_b, children_layer=2
        )
        if direction[axis] > 0:
            # a is on the positive side: its boundary L2 chunks sit at `mid`,
            # b's sit one L2 chunk below
            mid = coord_a[axis] * 2 ** (max_layer - 2)
            l2chunks_a = children_a[children_a[:, axis] == mid]
            l2chunks_b = children_b[children_b[:, axis] == mid - 1]
        else:
            mid = coord_b[axis] * 2 ** (max_layer - 2)
            l2chunks_a = children_a[children_a[:, axis] == mid - 1]
            l2chunks_b = children_b[children_b[:, axis] == mid]

        l2chunk_ids_a = chunk_utils.get_chunk_ids_from_coords(cg.meta, 2, l2chunks_a)
        l2chunk_ids_b = chunk_utils.get_chunk_ids_from_coords(cg.meta, 2, l2chunks_b)
        return l2chunk_ids_a, l2chunk_ids_b

    def _get_filtered_l2ids(node_a, node_b, chunks_map):
        """Descend from each node to its L2 children within boundary chunks."""

        def _filter(node):
            result = []
            children = cg.get_children(node)
            while True:
                # keep only children lying in this node's boundary chunks
                chunk_ids = cg.get_chunk_ids_from_node_ids(children)
                mask = np.isin(chunk_ids, chunks_map[node])
                children = children[mask]

                mask = cg.get_chunk_layers(children) == 2
                result.append(children[mask])

                # descend until no children above layer 2 remain
                mask = cg.get_chunk_layers(children) > 2
                if children[mask].size == 0:
                    break
                children = cg.get_children(children[mask], flatten=True)
            return np.concatenate(result)

        return _filter(node_a), _filter(node_b)

    latest_edges_d = {}
    chunks_map = {}
    for edge_layer, _edge in zip(edge_layers, stale_edges):
        node_a, node_b = _edge
        mlayer, coord_a, coord_b = _get_normalized_coords(node_a, node_b)
        chunks_a, chunks_b = _get_l2chunkids_along_boundary(mlayer, coord_a, coord_b)

        # collect boundary chunk IDs at every layer from 2 up to mlayer-1
        # so _filter can match children at any intermediate layer
        chunks_map[node_a] = []
        chunks_map[node_b] = []
        _layer = 2
        while _layer < mlayer:
            chunks_map[node_a].append(chunks_a)
            chunks_map[node_b].append(chunks_b)
            chunks_a = np.unique(cg.get_parent_chunk_id_multiple(chunks_a))
            chunks_b = np.unique(cg.get_parent_chunk_id_multiple(chunks_b))
            _layer += 1
        # NOTE(review): if mlayer == 2 these lists are empty and
        # np.concatenate raises — presumably cross edges always imply
        # mlayer > 2; confirm.
        chunks_map[node_a] = np.concatenate(chunks_map[node_a])
        chunks_map[node_b] = np.concatenate(chunks_map[node_b])

        l2ids_a, l2ids_b = _get_filtered_l2ids(node_a, node_b, chunks_map)
        # edges as of the stale partner's timestamp
        edges_d = cg.get_cross_chunk_edges(
            node_ids=l2ids_a, time_stamp=nodes_ts_map[node_b]
        )

        _edges = [v.get(edge_layer, types.empty_2d) for v in edges_d.values()]
        _edges = np.concatenate(_edges) if _edges else types.empty_2d
        # keep only edges whose partner side is in node_b's boundary L2 IDs
        mask = np.isin(_edges[:, 1], l2ids_b)
        # BUGFIX: the original printed this result and returned None despite
        # the `-> dict` annotation; collect and return it instead.
        latest_edges_d[(node_a, node_b)] = _edges[mask]
    return latest_edges_d
0 commit comments