Skip to content

Commit 0a4290f

Browse files
committed
fix(edits): batch create higher layer ids, reuse future roots from lock, sanity check defaults to true
1 parent d27d653 commit 0a4290f

File tree

7 files changed

+127
-81
lines changed

7 files changed

+127
-81
lines changed

pychunkedgraph/graph/cache.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ def parents_multiple(
179179
time_stamp=time_stamp,
180180
fail_to_zero=fail_to_zero,
181181
)
182+
mask = mask | (parents == 0)
182183
update(self.parents_cache, node_ids[~mask], parents[~mask])
183184
return parents
184185

pychunkedgraph/graph/chunkedgraph.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def get_parents(
214214
end_time=time_stamp,
215215
end_time_inclusive=True,
216216
)
217-
if not parent_rows:
217+
if not parent_rows and not fail_to_zero:
218218
return types.empty_1d
219219

220220
parents = []
@@ -736,8 +736,8 @@ def get_l2_agglomerations(
736736
else:
737737
all_chunk_edges = all_chunk_edges.get_pairs()
738738
supervoxels = self.get_children(level2_ids, flatten=True)
739-
mask0 = np.in1d(all_chunk_edges[:, 0], supervoxels)
740-
mask1 = np.in1d(all_chunk_edges[:, 1], supervoxels)
739+
mask0 = np.isin(all_chunk_edges[:, 0], supervoxels)
740+
mask1 = np.isin(all_chunk_edges[:, 1], supervoxels)
741741
return all_chunk_edges[mask0 & mask1]
742742

743743
l2id_children_d = self.get_children(level2_ids)
@@ -809,7 +809,7 @@ def add_edges(
809809
source_coords: typing.Sequence[int] = None,
810810
sink_coords: typing.Sequence[int] = None,
811811
allow_same_segment_merge: typing.Optional[bool] = False,
812-
do_sanity_check: typing.Optional[bool] = False,
812+
do_sanity_check: typing.Optional[bool] = True,
813813
) -> operation.GraphEditOperation.Result:
814814
"""
815815
Adds an edge to the chunkedgraph
@@ -842,7 +842,7 @@ def remove_edges(
842842
path_augment: bool = True,
843843
disallow_isolating_cut: bool = True,
844844
bb_offset: typing.Tuple[int, int, int] = (240, 240, 24),
845-
do_sanity_check: typing.Optional[bool] = False,
845+
do_sanity_check: typing.Optional[bool] = True,
846846
) -> operation.GraphEditOperation.Result:
847847
"""
848848
Removes edges - either directly or after applying a mincut

pychunkedgraph/graph/client/bigtable/client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ def lock_roots(
467467

468468
lock_results = {}
469469
root_ids = np.unique(new_root_ids)
470-
max_workers = max(1, len(root_ids) // 2)
470+
max_workers = max(1, len(root_ids))
471471
with ThreadPoolExecutor(max_workers=max_workers) as executor:
472472
future_to_root = {
473473
executor.submit(self.lock_root, root_id, operation_id): root_id
@@ -518,7 +518,7 @@ def lock_roots_indefinitely(
518518

519519
root_ids = np.unique(new_root_ids)
520520
lock_results = {}
521-
max_workers = max(1, len(root_ids) // 2)
521+
max_workers = max(1, len(root_ids))
522522
failed_to_lock = []
523523
with ThreadPoolExecutor(max_workers=max_workers) as executor:
524524
future_to_root = {
@@ -601,7 +601,7 @@ def renew_lock(self, root_id: np.uint64, operation_id: np.uint64) -> bool:
601601

602602
def renew_locks(self, root_ids: np.uint64, operation_id: np.uint64) -> bool:
603603
"""Renews existing root node locks with operation_id to extend time."""
604-
max_workers = max(1, len(root_ids) // 2)
604+
max_workers = max(1, len(root_ids))
605605
with ThreadPoolExecutor(max_workers=max_workers) as executor:
606606
futures = {
607607
executor.submit(self.renew_lock, root_id, operation_id): root_id
@@ -640,7 +640,7 @@ def get_consolidated_lock_timestamp(
640640
"""Minimum of multiple lock timestamps."""
641641
if len(root_ids) == 0:
642642
return None
643-
max_workers = max(1, len(root_ids) // 2)
643+
max_workers = max(1, len(root_ids))
644644
with ThreadPoolExecutor(max_workers=max_workers) as executor:
645645
futures = {
646646
executor.submit(self.get_lock_timestamp, root_id, op_id): (

pychunkedgraph/graph/edges/__init__.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def flip_ids(id_map, node_ids):
207207
return np.concatenate(ids).astype(basetypes.NODE_ID)
208208

209209

210-
def _get_new_nodes(
210+
def get_new_nodes(
211211
cg, nodes: np.ndarray, layer: int, parent_ts: datetime.datetime = None
212212
):
213213
unique_nodes, inverse = np.unique(nodes, return_inverse=True)
@@ -248,7 +248,7 @@ def get_stale_nodes(
248248
for layer in np.unique(node_layers):
249249
_mask = node_layers == layer
250250
layer_nodes = nodes[_mask]
251-
_nodes = _get_new_nodes(cg, supervoxels[_mask], layer, parent_ts)
251+
_nodes = get_new_nodes(cg, supervoxels[_mask], layer, parent_ts)
252252
stale_mask = layer_nodes != _nodes
253253
stale_nodes.append(layer_nodes[stale_mask])
254254
return np.concatenate(stale_nodes)
@@ -447,8 +447,11 @@ def _get_parents_b(edges, parent_ts, layer, fallback: bool = False):
447447
Searches for new partners that may have any edges to `edges[:,0]`.
448448
"""
449449
if PARENTS_CACHE is None:
450+
# PARENTS_CACHE is populated only during migration;
451+
# without a migration the fallback lookup does not apply, so it is disabled below.
450452
children_b = cg.get_children(edges[:, 1], flatten=True)
451453
parents_b = np.unique(cg.get_parents(children_b, time_stamp=parent_ts))
454+
fallback = False
452455
else:
453456
children_b = _get_children_from_cache(edges[:, 1])
454457
_populate_parents_cache(children_b)
@@ -564,7 +567,7 @@ def _get_new_edge(edge, edge_layer, parent_ts, padding, fallback: bool = False):
564567
if fallback:
565568
parents_b = _get_parents_b(_edges, parent_ts, edge_layer, True)
566569

567-
parents_b = np.unique(_get_new_nodes(cg, parents_b, mlayer, parent_ts))
570+
parents_b = np.unique(get_new_nodes(cg, parents_b, mlayer, parent_ts))
568571
parents_a = np.array([node_a] * parents_b.size, dtype=basetypes.NODE_ID)
569572
return np.column_stack((parents_a, parents_b))
570573

pychunkedgraph/graph/edits.py

Lines changed: 74 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# pylint: disable=invalid-name, missing-docstring, too-many-locals, c-extension-no-member
22

3-
import datetime, random
3+
import datetime, logging, random
44
from typing import Dict
55
from typing import List
66
from typing import Tuple
@@ -16,7 +16,7 @@
1616
from . import types
1717
from . import attributes
1818
from . import cache as cache_utils
19-
from .edges import get_latest_edges_wrapper, flip_ids
19+
from .edges import get_latest_edges_wrapper, flip_ids, get_new_nodes
2020
from .edges.utils import concatenate_cross_edge_dicts
2121
from .edges.utils import merge_cross_edge_dicts
2222
from .utils import basetypes
@@ -25,6 +25,8 @@
2525
from ..utils.general import in2d
2626
from ..debug.utils import sanity_check, sanity_check_single
2727

28+
logger = logging.getLogger(__name__)
29+
2830

2931
def _init_old_hierarchy(cg, l2ids: np.ndarray, parent_ts: datetime.datetime = None):
3032
"""
@@ -684,11 +686,8 @@ def _update_cross_edge_cache_batched(self, new_ids: list):
684686

685687
# Distribute results back to each parent's cache
686688
# Key insight: edges[:, 0] are children, map them to their parent
687-
edge_parents = self.cg.get_roots(
688-
edge_nodes,
689-
stop_layer=parent_layer,
690-
ceil=False,
691-
time_stamp=self._last_ts,
689+
edge_parents = get_new_nodes(
690+
self.cg, edge_nodes, parent_layer, self._last_ts
692691
)
693692
edge_parents_d = dict(zip(edge_nodes, edge_parents))
694693
for new_id in new_ids:
@@ -714,6 +713,48 @@ def _update_cross_edge_cache_batched(self, new_ids: list):
714713
self.cg.cache.cross_chunk_edges_cache[new_id] = parent_cx_edges_d
715714
return updated_entries
716715

716+
def _get_new_ids(self, chunk_id, count, is_root):
    """
    Reserve `count` node IDs in `chunk_id` that are not already in use.

    Candidate IDs are requested from `self.cg.id_client` in batches and
    looked up with a single `read_nodes` call per batch; every candidate
    that comes back as a key of that read is treated as taken and dropped.
    The batch size starts at `count` and doubles after each round (capped
    at 2**16) until enough unused IDs have been collected.

    Args:
        chunk_id: chunk in which the IDs are created.
        count: number of unused IDs required.
        is_root: forwarded to `create_node_ids` as `root_chunk`.

    Returns:
        A list of exactly `count` unused IDs, in the order the ID client
        produced them. Empty (with no client calls) when `count` is 0.
    """
    new_ids = []
    batch_size = count
    while len(new_ids) < count:
        candidate_ids = self.cg.id_client.create_node_ids(
            chunk_id, batch_size, root_chunk=is_root
        )
        existing = self.cg.client.read_nodes(node_ids=candidate_ids)
        # dict.fromkeys de-duplicates while keeping allocation order; a plain
        # set difference would hand the IDs back in arbitrary hash order.
        new_ids.extend(
            cid for cid in dict.fromkeys(candidate_ids) if cid not in existing
        )
        batch_size = min(batch_size * 2, 2**16)
    # The last batch may overshoot; surplus candidates are simply not used.
    return new_ids[:count]
728+
729+
def _get_new_parents(self, layer, ccs, graph_ids) -> tuple[dict, dict]:
    """
    Decide where each connected component's parent goes and reserve its ID.

    For every component in `ccs` (index arrays into `graph_ids`) the parent
    layer is `layer + 1`, unless the component has a single member. A single
    member is skip-connected: its parent layer is the lowest layer above
    `layer` at which it has cross chunk edges, or the top layer if it has
    none. The parent chunk is derived from the component's first member.

    Returns:
        (chunk_new_ids_map, cc_layer_chunk_map) where
        chunk_new_ids_map maps parent chunk ID -> list of reserved IDs, one
        per component assigned to that chunk, and
        cc_layer_chunk_map maps component position in `ccs` ->
        (parent_layer, parent_chunk_id).
    """
    top_layer = self.cg.meta.layer_count
    placement = {}
    ids_needed = defaultdict(int)

    for cc_pos, member_idx in enumerate(ccs):
        members = graph_ids[member_idx]
        first = members[0]
        if len(members) != 1:
            # regular case: parent sits directly above; reset per component
            target_layer = layer + 1
        else:
            # skip connection
            cx_edges_d = self.cg.get_cross_chunk_edges(
                [first], time_stamp=self._last_ts
            )
            target_layer = next(
                (
                    lyr
                    for lyr in range(layer + 1, top_layer)
                    if len(cx_edges_d[first].get(lyr, types.empty_2d)) > 0
                ),
                top_layer,
            )
        target_chunk = self.cg.get_parent_chunk_id(first, target_layer)
        placement[cc_pos] = (target_layer, target_chunk)
        ids_needed[target_chunk] += 1

    # Visit chunks in random order when reserving IDs.
    chunk_ids = list(ids_needed)
    random.shuffle(chunk_ids)
    chunk_layers = self.cg.get_chunk_layers(chunk_ids)
    reserved = {
        chunk: self._get_new_ids(chunk, ids_needed[chunk], chunk_layer == top_layer)
        for chunk, chunk_layer in zip(chunk_ids, chunk_layers)
    }
    return reserved, placement
757+
717758
def _create_new_parents(self, layer: int):
718759
"""
719760
keep track of old IDs
@@ -726,37 +767,13 @@ def _create_new_parents(self, layer: int):
726767
"""
727768
new_ids = self._new_ids_d[layer]
728769
layer_node_ids = self._get_layer_node_ids(new_ids, layer)
729-
components, graph_ids = self._get_connected_components(layer_node_ids, layer)
730-
for cc_indices in components:
731-
parent_layer = layer + 1 # must be reset for each connected component
732-
cc_ids = graph_ids[cc_indices]
733-
if len(cc_ids) == 1:
734-
# skip connection
735-
parent_layer = self.cg.meta.layer_count
736-
cx_edges_d = self.cg.get_cross_chunk_edges(
737-
[cc_ids[0]], time_stamp=self._last_ts
738-
)
739-
for l in range(layer + 1, self.cg.meta.layer_count):
740-
if len(cx_edges_d[cc_ids[0]].get(l, types.empty_2d)) > 0:
741-
parent_layer = l
742-
break
770+
ccs, _ids = self._get_connected_components(layer_node_ids, layer)
771+
new_parents_map, cc_layer_chunk_map = self._get_new_parents(layer, ccs, _ids)
743772

744-
# TODO: handle skip connected root id creation separately
745-
chunk_id = self.cg.get_parent_chunk_id(cc_ids[0], parent_layer)
746-
is_root = parent_layer == self.cg.meta.layer_count
747-
batch_size = 1
748-
parent = None
749-
while parent is None:
750-
candidate_ids = self.cg.id_client.create_node_ids(
751-
chunk_id, batch_size, root_chunk=is_root
752-
)
753-
existing = self.cg.client.read_nodes(node_ids=candidate_ids)
754-
for cid in candidate_ids:
755-
if cid not in existing:
756-
parent = cid
757-
break
758-
if parent is None:
759-
batch_size = min(batch_size * 2, 2**16)
773+
for i, cc_indices in enumerate(ccs):
774+
cc_ids = _ids[cc_indices]
775+
parent_layer, chunk_id = cc_layer_chunk_map[i]
776+
parent = new_parents_map[chunk_id].pop()
760777

761778
self._new_ids_d[parent_layer].append(parent)
762779
self._update_id_lineage(parent, cc_ids, layer, parent_layer)
@@ -786,19 +803,20 @@ def run(self) -> Iterable:
786803
"""
787804
self._new_ids_d[2] = self._new_l2_ids
788805
for layer in range(2, self.cg.meta.layer_count):
789-
if len(self._new_ids_d[layer]) == 0:
806+
new_nodes = self._new_ids_d[layer]
807+
if len(new_nodes) == 0:
790808
continue
791-
self.cg.cache.new_ids.update(self._new_ids_d[layer])
809+
self.cg.cache.new_ids.update(new_nodes)
792810
# all new IDs in this layer have been created
793811
# update their cross chunk edges and their neighbors'
794812
with self._profiler.profile(f"l{layer}_update_cx_cache"):
795-
entries = self._update_cross_edge_cache_batched(self._new_ids_d[layer])
813+
entries = self._update_cross_edge_cache_batched(new_nodes)
796814
self.new_entries.extend(entries)
797815

798816
with self._profiler.profile(f"l{layer}_update_neighbor_cx"):
799817
entries = _update_neighbor_cx_edges(
800818
self.cg,
801-
self._new_ids_d[layer],
819+
new_nodes,
802820
self._new_old_id_d,
803821
self._old_new_id_d,
804822
time_stamp=self._time_stamp,
@@ -861,10 +879,24 @@ def _get_cross_edges_val_dicts(self):
861879
return val_dicts
862880

863881
def create_new_entries(self) -> List:
882+
max_layer = self.cg.meta.layer_count
864883
val_dicts = self._get_cross_edges_val_dicts()
865-
for layer in range(2, self.cg.meta.layer_count + 1):
884+
for layer in range(2, max_layer + 1):
866885
new_ids = self._new_ids_d[layer]
867886
for id_ in new_ids:
887+
if self.do_sanity_check:
888+
root_layer = self.cg.get_chunk_layer(self.cg.get_root(id_))
889+
assert root_layer == max_layer, (id_, self.cg.get_root(id_))
890+
891+
if layer < max_layer:
892+
try:
893+
_parent = self.cg.get_parent(id_)
894+
_children = self.cg.get_children(_parent)
895+
assert id_ in _children, (layer, id_, _parent, _children)
896+
except TypeError as e:
897+
logger.error(id_, _parent, self.cg.get_root(id_))
898+
raise TypeError from e
899+
868900
val_dict = val_dicts.get(id_, {})
869901
children = self.cg.get_children(id_)
870902
err = f"parent layer less than children; op {self._opid}"

0 commit comments

Comments
 (0)