Skip to content

Commit 77b7bc2

Browse files
committed
keep edits.py identical to pcgv3
1 parent 266028e commit 77b7bc2

File tree

1 file changed

+69
-222
lines changed

1 file changed

+69
-222
lines changed

pychunkedgraph/graph/edits.py

Lines changed: 69 additions & 222 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,12 @@
11
# pylint: disable=invalid-name, missing-docstring, too-many-locals, c-extension-no-member
22

33
import datetime
4-
import time
5-
import os
64
from typing import Dict
75
from typing import List
86
from typing import Tuple
97
from typing import Iterable
108
from typing import Set
119
from collections import defaultdict
12-
from contextlib import contextmanager
1310

1411
import fastremap
1512
import numpy as np
@@ -29,117 +26,6 @@
2926
from ..debug.utils import sanity_check, sanity_check_single
3027

3128

32-
class HierarchicalProfiler:
33-
"""
34-
Hierarchical profiler for detailed timing breakdowns.
35-
Tracks timing at multiple levels and prints a breakdown at the end.
36-
"""
37-
38-
def __init__(self, enabled: bool = True):
39-
self.enabled = enabled
40-
self.timings: Dict[str, List[float]] = defaultdict(list)
41-
self.call_counts: Dict[str, int] = defaultdict(int)
42-
self.stack: List[Tuple[str, float]] = []
43-
self.current_path: List[str] = []
44-
45-
@contextmanager
46-
def profile(self, name: str):
47-
"""Context manager for profiling a code block."""
48-
if not self.enabled:
49-
yield
50-
return
51-
52-
full_path = ".".join(self.current_path + [name])
53-
self.current_path.append(name)
54-
start_time = time.perf_counter()
55-
56-
try:
57-
yield
58-
finally:
59-
elapsed = time.perf_counter() - start_time
60-
self.timings[full_path].append(elapsed)
61-
self.call_counts[full_path] += 1
62-
self.current_path.pop()
63-
64-
def print_report(self, operation_id=None):
65-
"""Print a detailed timing breakdown."""
66-
if not self.enabled or not self.timings:
67-
return
68-
69-
print("\n" + "=" * 80)
70-
print(f"PROFILER REPORT{f' (operation_id={operation_id})' if operation_id else ''}")
71-
print("=" * 80)
72-
73-
# Group by depth level
74-
by_depth: Dict[int, List[Tuple[str, float, int]]] = defaultdict(list)
75-
for path, times in self.timings.items():
76-
depth = path.count(".")
77-
total_time = sum(times)
78-
count = self.call_counts[path]
79-
by_depth[depth].append((path, total_time, count))
80-
81-
# Sort each level by total time
82-
for depth in sorted(by_depth.keys()):
83-
items = sorted(by_depth[depth], key=lambda x: -x[1])
84-
for path, total_time, count in items:
85-
indent = " " * depth
86-
avg_time = total_time / count if count > 0 else 0
87-
if count > 1:
88-
print(
89-
f"{indent}{path}: {total_time*1000:.2f}ms total "
90-
f"({count} calls, {avg_time*1000:.2f}ms avg)"
91-
)
92-
else:
93-
print(f"{indent}{path}: {total_time*1000:.2f}ms")
94-
95-
# Print summary
96-
print("-" * 80)
97-
top_level_total = sum(
98-
sum(times) for path, times in self.timings.items() if "." not in path
99-
)
100-
print(f"Total top-level time: {top_level_total*1000:.2f}ms")
101-
102-
# Print top 10 slowest operations
103-
print("\nTop 10 slowest operations:")
104-
all_ops = [
105-
(path, sum(times), self.call_counts[path])
106-
for path, times in self.timings.items()
107-
]
108-
all_ops.sort(key=lambda x: -x[1])
109-
for i, (path, total_time, count) in enumerate(all_ops[:10]):
110-
pct = (total_time / top_level_total * 100) if top_level_total > 0 else 0
111-
print(f" {i+1}. {path}: {total_time*1000:.2f}ms ({pct:.1f}%)")
112-
113-
print("=" * 80 + "\n")
114-
115-
def reset(self):
116-
"""Reset all timing data."""
117-
self.timings.clear()
118-
self.call_counts.clear()
119-
self.stack.clear()
120-
self.current_path.clear()
121-
122-
123-
# Global profiler instance - enable via environment variable
124-
PROFILER_ENABLED = os.environ.get("PCG_PROFILER_ENABLED", "1") == "1"
125-
_profiler: HierarchicalProfiler = None
126-
127-
128-
def get_profiler() -> HierarchicalProfiler:
129-
"""Get or create the global profiler instance."""
130-
global _profiler
131-
if _profiler is None:
132-
_profiler = HierarchicalProfiler(enabled=PROFILER_ENABLED)
133-
return _profiler
134-
135-
136-
def reset_profiler():
137-
"""Reset the global profiler."""
138-
global _profiler
139-
if _profiler is not None:
140-
_profiler.reset()
141-
142-
14329
def _init_old_hierarchy(cg, l2ids: np.ndarray, parent_ts: datetime.datetime = None):
14430
"""
14531
Populates old hierarcy from child to root and also gets children of intermediate nodes.
@@ -169,11 +55,8 @@ def _analyze_affected_edges(
16955
17056
Also returns new cross edges dicts for nodes crossing chunk boundary.
17157
"""
172-
profiler = get_profiler()
173-
17458
supervoxels = np.unique(atomic_edges)
175-
with profiler.profile("analyze_get_parents"):
176-
parents = cg.get_parents(supervoxels, time_stamp=parent_ts)
59+
parents = cg.get_parents(supervoxels, time_stamp=parent_ts)
17760
sv_parent_d = dict(zip(supervoxels.tolist(), parents))
17861
edge_layers = cg.get_cross_chunk_edges_layer(atomic_edges)
17962
parent_edges = [
@@ -318,48 +201,32 @@ def add_edges(
318201
allow_same_segment_merge=False,
319202
stitch_mode: bool = False,
320203
):
321-
profiler = get_profiler()
322-
profiler.reset() # Reset for fresh profiling
323-
324-
with profiler.profile("add_edges"):
325-
with profiler.profile("analyze_affected_edges"):
326-
edges, l2_cross_edges_d = _analyze_affected_edges(
327-
cg, atomic_edges, parent_ts=parent_ts
328-
)
329-
330-
l2ids = np.unique(edges)
331-
if not allow_same_segment_merge and not stitch_mode:
332-
with profiler.profile("validate_roots"):
333-
roots = cg.get_roots(l2ids, assert_roots=True, time_stamp=parent_ts)
334-
assert np.unique(roots).size >= 2, "L2 IDs must belong to different roots."
335-
336-
new_old_id_d = defaultdict(set)
337-
old_new_id_d = defaultdict(set)
338-
339-
with profiler.profile("init_old_hierarchy"):
340-
old_hierarchy_d = _init_old_hierarchy(cg, l2ids, parent_ts=parent_ts)
341-
342-
with profiler.profile("get_children"):
343-
atomic_children_d = cg.get_children(l2ids)
344-
345-
with profiler.profile("get_cross_chunk_edges"):
346-
cross_edges_d = merge_cross_edge_dicts(
347-
cg.get_cross_chunk_edges(l2ids, time_stamp=parent_ts), l2_cross_edges_d
348-
)
204+
edges, l2_cross_edges_d = _analyze_affected_edges(
205+
cg, atomic_edges, parent_ts=parent_ts
206+
)
207+
l2ids = np.unique(edges)
208+
if not allow_same_segment_merge and not stitch_mode:
209+
roots = cg.get_roots(l2ids, assert_roots=True, time_stamp=parent_ts)
210+
assert np.unique(roots).size >= 2, "L2 IDs must belong to different roots."
349211

350-
with profiler.profile("build_graph"):
351-
graph, _, _, graph_ids = flatgraph.build_gt_graph(edges, make_directed=True)
352-
components = flatgraph.connected_components(graph)
212+
new_old_id_d = defaultdict(set)
213+
old_new_id_d = defaultdict(set)
214+
old_hierarchy_d = _init_old_hierarchy(cg, l2ids, parent_ts=parent_ts)
215+
atomic_children_d = cg.get_children(l2ids)
216+
cross_edges_d = merge_cross_edge_dicts(
217+
cg.get_cross_chunk_edges(l2ids, time_stamp=parent_ts), l2_cross_edges_d
218+
)
353219

354-
with profiler.profile("create_l2_ids"):
355-
new_l2_ids = []
356-
for cc_indices in components:
357-
l2ids_ = graph_ids[cc_indices]
358-
new_id = cg.id_client.create_node_id(cg.get_chunk_id(l2ids_[0]))
359-
new_l2_ids.append(new_id)
360-
new_old_id_d[new_id].update(l2ids_)
361-
for id_ in l2ids_:
362-
old_new_id_d[id_].add(new_id)
220+
graph, _, _, graph_ids = flatgraph.build_gt_graph(edges, make_directed=True)
221+
components = flatgraph.connected_components(graph)
222+
new_l2_ids = []
223+
for cc_indices in components:
224+
l2ids_ = graph_ids[cc_indices]
225+
new_id = cg.id_client.create_node_id(cg.get_chunk_id(l2ids_[0]))
226+
new_l2_ids.append(new_id)
227+
new_old_id_d[new_id].update(l2ids_)
228+
for id_ in l2ids_:
229+
old_new_id_d[id_].add(new_id)
363230

364231
# update cache
365232
# map parent to new merged children and vice versa
@@ -368,20 +235,19 @@ def add_edges(
368235
cg.cache.children_cache[new_id] = merged_children
369236
cache_utils.update(cg.cache.parents_cache, merged_children, new_id)
370237

371-
# update cross chunk edges by replacing old_ids with new
372-
# this can be done only after all new IDs have been created
373-
with profiler.profile("update_cross_edges"):
374-
for new_id, cc_indices in zip(new_l2_ids, components):
375-
l2ids_ = graph_ids[cc_indices]
376-
new_cx_edges_d = {}
377-
cx_edges = [cross_edges_d[l2id] for l2id in l2ids_]
378-
cx_edges_d = concatenate_cross_edge_dicts(cx_edges, unique=True)
379-
temp_map = {k: next(iter(v)) for k, v in old_new_id_d.items()}
380-
for layer, edges in cx_edges_d.items():
381-
edges = fastremap.remap(edges, temp_map, preserve_missing_labels=True)
382-
new_cx_edges_d[layer] = edges
383-
assert np.all(edges[:, 0] == new_id)
384-
cg.cache.cross_chunk_edges_cache[new_id] = new_cx_edges_d
238+
# update cross chunk edges by replacing old_ids with new
239+
# this can be done only after all new IDs have been created
240+
for new_id, cc_indices in zip(new_l2_ids, components):
241+
l2ids_ = graph_ids[cc_indices]
242+
new_cx_edges_d = {}
243+
cx_edges = [cross_edges_d[l2id] for l2id in l2ids_]
244+
cx_edges_d = concatenate_cross_edge_dicts(cx_edges, unique=True)
245+
temp_map = {k: next(iter(v)) for k, v in old_new_id_d.items()}
246+
for layer, edges in cx_edges_d.items():
247+
edges = fastremap.remap(edges, temp_map, preserve_missing_labels=True)
248+
new_cx_edges_d[layer] = edges
249+
assert np.all(edges[:, 0] == new_id)
250+
cg.cache.cross_chunk_edges_cache[new_id] = new_cx_edges_d
385251

386252
profiler = get_profiler()
387253
profiler.reset()
@@ -643,12 +509,8 @@ def _update_neighbor_cx_edges(
643509
and then write to storage to consolidate the mutations.
644510
Returns mutations to updated counterparts/partner nodes.
645511
"""
646-
profiler = get_profiler()
647512
updated_counterparts = {}
648-
649-
with profiler.profile("neighbor_get_cross_chunk_edges"):
650-
newid_cx_edges_d = cg.get_cross_chunk_edges(new_ids, time_stamp=parent_ts)
651-
513+
newid_cx_edges_d = cg.get_cross_chunk_edges(new_ids, time_stamp=parent_ts)
652514
node_map = {}
653515
for k, v in old_new_id.items():
654516
if len(v) == 1:
@@ -670,14 +532,11 @@ def _update_neighbor_cx_edges(
670532
cg, new_id, node_map, cp_layers, all_cx_edges_d
671533
)
672534
updated_counterparts.update(result)
673-
674-
with profiler.profile("neighbor_create_mutations"):
675-
updated_entries = []
676-
for node, val_dict in updated_counterparts.items():
677-
rowkey = serialize_uint64(node)
678-
row = cg.client.mutate_row(rowkey, val_dict, time_stamp=time_stamp)
679-
updated_entries.append(row)
680-
535+
updated_entries = []
536+
for node, val_dict in updated_counterparts.items():
537+
rowkey = serialize_uint64(node)
538+
row = cg.client.mutate_row(rowkey, val_dict, time_stamp=time_stamp)
539+
updated_entries.append(row)
681540
return updated_entries
682541

683542

@@ -748,7 +607,6 @@ def _get_layer_node_ids(
748607
# get their parents, then children of those parents
749608
old_parents = self.cg.get_parents(old_ids, time_stamp=self._last_ts)
750609
siblings = self.cg.get_children(np.unique(old_parents), flatten=True)
751-
752610
# replace old identities with new IDs
753611
mask = np.isin(siblings, old_ids)
754612
node_ids = [_flip_ids(self._old_new_id_d, old_ids), siblings[~mask], new_ids]
@@ -781,7 +639,6 @@ def _update_cross_edge_cache(self, parent, children):
781639
ceil=False,
782640
time_stamp=self._last_ts,
783641
)
784-
785642
edge_parents_d = dict(zip(edge_nodes, edge_parents))
786643
new_cx_edges_d = {}
787644
for layer in range(parent_layer, self.cg.meta.layer_count):
@@ -862,14 +719,9 @@ def _create_new_parents(self, layer: int):
862719
update parent old IDs
863720
"""
864721
new_ids = self._new_ids_d[layer]
865-
866-
with self._profiler.profile("get_layer_node_ids"):
867-
layer_node_ids = self._get_layer_node_ids(new_ids, layer)
868-
869-
with self._profiler.profile("get_connected_components"):
870-
components, graph_ids = self._get_connected_components(layer_node_ids, layer)
871-
872-
for cc_idx, cc_indices in enumerate(components):
722+
layer_node_ids = self._get_layer_node_ids(new_ids, layer)
723+
components, graph_ids = self._get_connected_components(layer_node_ids, layer)
724+
for cc_indices in components:
873725
parent_layer = layer + 1 # must be reset for each connected component
874726
cc_ids = graph_ids[cc_indices]
875727
if len(cc_ids) == 1:
@@ -998,35 +850,30 @@ def _get_cross_edges_val_dicts(self):
998850
return val_dicts
999851

1000852
def create_new_entries(self) -> List:
1001-
with self._profiler.profile("get_cross_edges_val_dicts"):
1002-
val_dicts = self._get_cross_edges_val_dicts()
1003-
1004-
with self._profiler.profile("build_hierarchy_entries"):
1005-
for layer in range(2, self.cg.meta.layer_count + 1):
1006-
new_ids = self._new_ids_d[layer]
1007-
for id_ in new_ids:
1008-
val_dict = val_dicts.get(id_, {})
1009-
children = self.cg.get_children(id_)
1010-
err = f"parent layer less than children; op {self._operation_id}"
1011-
assert np.max(
1012-
self.cg.get_chunk_layers(children)
1013-
) < self.cg.get_chunk_layer(id_), err
1014-
val_dict[attributes.Hierarchy.Child] = children
853+
val_dicts = self._get_cross_edges_val_dicts()
854+
for layer in range(2, self.cg.meta.layer_count + 1):
855+
new_ids = self._new_ids_d[layer]
856+
for id_ in new_ids:
857+
val_dict = val_dicts.get(id_, {})
858+
children = self.cg.get_children(id_)
859+
err = f"parent layer less than children; op {self._operation_id}"
860+
assert np.max(
861+
self.cg.get_chunk_layers(children)
862+
) < self.cg.get_chunk_layer(id_), err
863+
val_dict[attributes.Hierarchy.Child] = children
864+
self.new_entries.append(
865+
self.cg.client.mutate_row(
866+
serialize_uint64(id_),
867+
val_dict,
868+
time_stamp=self._time_stamp,
869+
)
870+
)
871+
for child_id in children:
1015872
self.new_entries.append(
1016873
self.cg.client.mutate_row(
1017-
serialize_uint64(id_),
1018-
val_dict,
874+
serialize_uint64(child_id),
875+
{attributes.Hierarchy.Parent: id_},
1019876
time_stamp=self._time_stamp,
1020877
)
1021878
)
1022-
for child_id in children:
1023-
self.new_entries.append(
1024-
self.cg.client.mutate_row(
1025-
serialize_uint64(child_id),
1026-
{attributes.Hierarchy.Parent: id_},
1027-
time_stamp=self._time_stamp,
1028-
)
1029-
)
1030-
1031-
with self._profiler.profile("update_root_id_lineage"):
1032-
self._update_root_id_lineage()
879+
self._update_root_id_lineage()

0 commit comments

Comments (0)