11# pylint: disable=invalid-name, missing-docstring, too-many-locals, c-extension-no-member
22
33import datetime
4- import time
5- import os
64from typing import Dict
75from typing import List
86from typing import Tuple
97from typing import Iterable
108from typing import Set
119from collections import defaultdict
12- from contextlib import contextmanager
1310
1411import fastremap
1512import numpy as np
2926from ..debug .utils import sanity_check , sanity_check_single
3027
3128
class HierarchicalProfiler:
    """
    Hierarchical profiler for detailed timing breakdowns.

    Nested ``profile()`` context managers build dotted paths
    ("outer.inner"); each completed block records its wall-clock elapsed
    time under that path. ``print_report`` prints a per-depth breakdown
    and a top-10 summary.

    Fix vs. previous revision: removed the dead ``self.stack`` attribute —
    it was initialized and cleared but never read or written by any method.
    """

    def __init__(self, enabled: bool = True):
        self.enabled = enabled
        # dotted path -> list of elapsed seconds, one entry per completed call
        self.timings: Dict[str, List[float]] = defaultdict(list)
        # dotted path -> number of completed calls
        self.call_counts: Dict[str, int] = defaultdict(int)
        # names of currently open profile() blocks, outermost first
        self.current_path: List[str] = []

    @contextmanager
    def profile(self, name: str):
        """Context manager timing a code block under ``name`` (nestable)."""
        if not self.enabled:
            yield
            return

        # Full path is built BEFORE pushing `name`, so it ends with `name`.
        full_path = ".".join(self.current_path + [name])
        self.current_path.append(name)
        start_time = time.perf_counter()

        try:
            yield
        finally:
            # Record even if the block raised, and always unwind the path.
            elapsed = time.perf_counter() - start_time
            self.timings[full_path].append(elapsed)
            self.call_counts[full_path] += 1
            self.current_path.pop()

    def print_report(self, operation_id=None):
        """Print a detailed timing breakdown; no-op if disabled or empty."""
        if not self.enabled or not self.timings:
            return

        print("\n" + "=" * 80)
        print(f"PROFILER REPORT{f' (operation_id={operation_id})' if operation_id else ''}")
        print("=" * 80)

        # Group by depth level (number of dots in the path)
        by_depth: Dict[int, List[Tuple[str, float, int]]] = defaultdict(list)
        for path, times in self.timings.items():
            depth = path.count(".")
            total_time = sum(times)
            count = self.call_counts[path]
            by_depth[depth].append((path, total_time, count))

        # Sort each level by total time, slowest first
        for depth in sorted(by_depth.keys()):
            items = sorted(by_depth[depth], key=lambda x: -x[1])
            for path, total_time, count in items:
                indent = "  " * depth
                avg_time = total_time / count if count > 0 else 0
                if count > 1:
                    print(
                        f"{indent}{path}: {total_time * 1000:.2f}ms total "
                        f"({count} calls, {avg_time * 1000:.2f}ms avg)"
                    )
                else:
                    print(f"{indent}{path}: {total_time * 1000:.2f}ms")

        # Summary counts only top-level (undotted) paths to avoid
        # double-counting time already included in parents.
        print("-" * 80)
        top_level_total = sum(
            sum(times) for path, times in self.timings.items() if "." not in path
        )
        print(f"Total top-level time: {top_level_total * 1000:.2f}ms")

        # Top 10 slowest operations across all levels
        print("\nTop 10 slowest operations:")
        all_ops = [
            (path, sum(times), self.call_counts[path])
            for path, times in self.timings.items()
        ]
        all_ops.sort(key=lambda x: -x[1])
        for i, (path, total_time, count) in enumerate(all_ops[:10]):
            pct = (total_time / top_level_total * 100) if top_level_total > 0 else 0
            print(f"  {i + 1}. {path}: {total_time * 1000:.2f}ms ({pct:.1f}%)")

        print("=" * 80 + "\n")

    def reset(self):
        """Reset all timing data."""
        self.timings.clear()
        self.call_counts.clear()
        self.current_path.clear()
121-
122-
123- # Global profiler instance - enable via environment variable
124- PROFILER_ENABLED = os .environ .get ("PCG_PROFILER_ENABLED" , "1" ) == "1"
125- _profiler : HierarchicalProfiler = None
126-
127-
def get_profiler() -> HierarchicalProfiler:
    """Return the process-wide profiler, constructing it lazily on first use."""
    global _profiler
    if _profiler is not None:
        return _profiler
    _profiler = HierarchicalProfiler(enabled=PROFILER_ENABLED)
    return _profiler
134-
135-
def reset_profiler():
    """Clear all timing data on the global profiler, if one has been created."""
    global _profiler
    if _profiler is None:
        return
    _profiler.reset()
141-
142-
14329def _init_old_hierarchy (cg , l2ids : np .ndarray , parent_ts : datetime .datetime = None ):
14430 """
14531 Populates old hierarcy from child to root and also gets children of intermediate nodes.
@@ -169,11 +55,8 @@ def _analyze_affected_edges(
16955
17056 Also returns new cross edges dicts for nodes crossing chunk boundary.
17157 """
172- profiler = get_profiler ()
173-
17458 supervoxels = np .unique (atomic_edges )
175- with profiler .profile ("analyze_get_parents" ):
176- parents = cg .get_parents (supervoxels , time_stamp = parent_ts )
59+ parents = cg .get_parents (supervoxels , time_stamp = parent_ts )
17760 sv_parent_d = dict (zip (supervoxels .tolist (), parents ))
17861 edge_layers = cg .get_cross_chunk_edges_layer (atomic_edges )
17962 parent_edges = [
@@ -318,48 +201,32 @@ def add_edges(
318201 allow_same_segment_merge = False ,
319202 stitch_mode : bool = False ,
320203):
321- profiler = get_profiler ()
322- profiler .reset () # Reset for fresh profiling
323-
324- with profiler .profile ("add_edges" ):
325- with profiler .profile ("analyze_affected_edges" ):
326- edges , l2_cross_edges_d = _analyze_affected_edges (
327- cg , atomic_edges , parent_ts = parent_ts
328- )
329-
330- l2ids = np .unique (edges )
331- if not allow_same_segment_merge and not stitch_mode :
332- with profiler .profile ("validate_roots" ):
333- roots = cg .get_roots (l2ids , assert_roots = True , time_stamp = parent_ts )
334- assert np .unique (roots ).size >= 2 , "L2 IDs must belong to different roots."
335-
336- new_old_id_d = defaultdict (set )
337- old_new_id_d = defaultdict (set )
338-
339- with profiler .profile ("init_old_hierarchy" ):
340- old_hierarchy_d = _init_old_hierarchy (cg , l2ids , parent_ts = parent_ts )
341-
342- with profiler .profile ("get_children" ):
343- atomic_children_d = cg .get_children (l2ids )
344-
345- with profiler .profile ("get_cross_chunk_edges" ):
346- cross_edges_d = merge_cross_edge_dicts (
347- cg .get_cross_chunk_edges (l2ids , time_stamp = parent_ts ), l2_cross_edges_d
348- )
204+ edges , l2_cross_edges_d = _analyze_affected_edges (
205+ cg , atomic_edges , parent_ts = parent_ts
206+ )
207+ l2ids = np .unique (edges )
208+ if not allow_same_segment_merge and not stitch_mode :
209+ roots = cg .get_roots (l2ids , assert_roots = True , time_stamp = parent_ts )
210+ assert np .unique (roots ).size >= 2 , "L2 IDs must belong to different roots."
349211
350- with profiler .profile ("build_graph" ):
351- graph , _ , _ , graph_ids = flatgraph .build_gt_graph (edges , make_directed = True )
352- components = flatgraph .connected_components (graph )
212+ new_old_id_d = defaultdict (set )
213+ old_new_id_d = defaultdict (set )
214+ old_hierarchy_d = _init_old_hierarchy (cg , l2ids , parent_ts = parent_ts )
215+ atomic_children_d = cg .get_children (l2ids )
216+ cross_edges_d = merge_cross_edge_dicts (
217+ cg .get_cross_chunk_edges (l2ids , time_stamp = parent_ts ), l2_cross_edges_d
218+ )
353219
354- with profiler .profile ("create_l2_ids" ):
355- new_l2_ids = []
356- for cc_indices in components :
357- l2ids_ = graph_ids [cc_indices ]
358- new_id = cg .id_client .create_node_id (cg .get_chunk_id (l2ids_ [0 ]))
359- new_l2_ids .append (new_id )
360- new_old_id_d [new_id ].update (l2ids_ )
361- for id_ in l2ids_ :
362- old_new_id_d [id_ ].add (new_id )
220+ graph , _ , _ , graph_ids = flatgraph .build_gt_graph (edges , make_directed = True )
221+ components = flatgraph .connected_components (graph )
222+ new_l2_ids = []
223+ for cc_indices in components :
224+ l2ids_ = graph_ids [cc_indices ]
225+ new_id = cg .id_client .create_node_id (cg .get_chunk_id (l2ids_ [0 ]))
226+ new_l2_ids .append (new_id )
227+ new_old_id_d [new_id ].update (l2ids_ )
228+ for id_ in l2ids_ :
229+ old_new_id_d [id_ ].add (new_id )
363230
364231 # update cache
365232 # map parent to new merged children and vice versa
@@ -368,20 +235,19 @@ def add_edges(
368235 cg .cache .children_cache [new_id ] = merged_children
369236 cache_utils .update (cg .cache .parents_cache , merged_children , new_id )
370237
371- # update cross chunk edges by replacing old_ids with new
372- # this can be done only after all new IDs have been created
373- with profiler .profile ("update_cross_edges" ):
374- for new_id , cc_indices in zip (new_l2_ids , components ):
375- l2ids_ = graph_ids [cc_indices ]
376- new_cx_edges_d = {}
377- cx_edges = [cross_edges_d [l2id ] for l2id in l2ids_ ]
378- cx_edges_d = concatenate_cross_edge_dicts (cx_edges , unique = True )
379- temp_map = {k : next (iter (v )) for k , v in old_new_id_d .items ()}
380- for layer , edges in cx_edges_d .items ():
381- edges = fastremap .remap (edges , temp_map , preserve_missing_labels = True )
382- new_cx_edges_d [layer ] = edges
383- assert np .all (edges [:, 0 ] == new_id )
384- cg .cache .cross_chunk_edges_cache [new_id ] = new_cx_edges_d
238+ # update cross chunk edges by replacing old_ids with new
239+ # this can be done only after all new IDs have been created
240+ for new_id , cc_indices in zip (new_l2_ids , components ):
241+ l2ids_ = graph_ids [cc_indices ]
242+ new_cx_edges_d = {}
243+ cx_edges = [cross_edges_d [l2id ] for l2id in l2ids_ ]
244+ cx_edges_d = concatenate_cross_edge_dicts (cx_edges , unique = True )
245+ temp_map = {k : next (iter (v )) for k , v in old_new_id_d .items ()}
246+ for layer , edges in cx_edges_d .items ():
247+ edges = fastremap .remap (edges , temp_map , preserve_missing_labels = True )
248+ new_cx_edges_d [layer ] = edges
249+ assert np .all (edges [:, 0 ] == new_id )
250+ cg .cache .cross_chunk_edges_cache [new_id ] = new_cx_edges_d
385251
386252 profiler = get_profiler ()
387253 profiler .reset ()
@@ -643,12 +509,8 @@ def _update_neighbor_cx_edges(
643509 and then write to storage to consolidate the mutations.
644510 Returns mutations to updated counterparts/partner nodes.
645511 """
646- profiler = get_profiler ()
647512 updated_counterparts = {}
648-
649- with profiler .profile ("neighbor_get_cross_chunk_edges" ):
650- newid_cx_edges_d = cg .get_cross_chunk_edges (new_ids , time_stamp = parent_ts )
651-
513+ newid_cx_edges_d = cg .get_cross_chunk_edges (new_ids , time_stamp = parent_ts )
652514 node_map = {}
653515 for k , v in old_new_id .items ():
654516 if len (v ) == 1 :
@@ -670,14 +532,11 @@ def _update_neighbor_cx_edges(
670532 cg , new_id , node_map , cp_layers , all_cx_edges_d
671533 )
672534 updated_counterparts .update (result )
673-
674- with profiler .profile ("neighbor_create_mutations" ):
675- updated_entries = []
676- for node , val_dict in updated_counterparts .items ():
677- rowkey = serialize_uint64 (node )
678- row = cg .client .mutate_row (rowkey , val_dict , time_stamp = time_stamp )
679- updated_entries .append (row )
680-
535+ updated_entries = []
536+ for node , val_dict in updated_counterparts .items ():
537+ rowkey = serialize_uint64 (node )
538+ row = cg .client .mutate_row (rowkey , val_dict , time_stamp = time_stamp )
539+ updated_entries .append (row )
681540 return updated_entries
682541
683542
@@ -748,7 +607,6 @@ def _get_layer_node_ids(
748607 # get their parents, then children of those parents
749608 old_parents = self .cg .get_parents (old_ids , time_stamp = self ._last_ts )
750609 siblings = self .cg .get_children (np .unique (old_parents ), flatten = True )
751-
752610 # replace old identities with new IDs
753611 mask = np .isin (siblings , old_ids )
754612 node_ids = [_flip_ids (self ._old_new_id_d , old_ids ), siblings [~ mask ], new_ids ]
@@ -781,7 +639,6 @@ def _update_cross_edge_cache(self, parent, children):
781639 ceil = False ,
782640 time_stamp = self ._last_ts ,
783641 )
784-
785642 edge_parents_d = dict (zip (edge_nodes , edge_parents ))
786643 new_cx_edges_d = {}
787644 for layer in range (parent_layer , self .cg .meta .layer_count ):
@@ -862,14 +719,9 @@ def _create_new_parents(self, layer: int):
862719 update parent old IDs
863720 """
864721 new_ids = self ._new_ids_d [layer ]
865-
866- with self ._profiler .profile ("get_layer_node_ids" ):
867- layer_node_ids = self ._get_layer_node_ids (new_ids , layer )
868-
869- with self ._profiler .profile ("get_connected_components" ):
870- components , graph_ids = self ._get_connected_components (layer_node_ids , layer )
871-
872- for cc_idx , cc_indices in enumerate (components ):
722+ layer_node_ids = self ._get_layer_node_ids (new_ids , layer )
723+ components , graph_ids = self ._get_connected_components (layer_node_ids , layer )
724+ for cc_indices in components :
873725 parent_layer = layer + 1 # must be reset for each connected component
874726 cc_ids = graph_ids [cc_indices ]
875727 if len (cc_ids ) == 1 :
@@ -998,35 +850,30 @@ def _get_cross_edges_val_dicts(self):
998850 return val_dicts
999851
def create_new_entries(self) -> List:
    """
    Persist hierarchy mutations for all newly created IDs.

    For every new node at layer 2 and above: writes its Child column,
    writes a Parent pointer row for each of its children, and finally
    updates root ID lineage. Mutated rows are appended to
    ``self.new_entries``.
    """
    val_dicts = self._get_cross_edges_val_dicts()
    for layer in range(2, self.cg.meta.layer_count + 1):
        for id_ in self._new_ids_d[layer]:
            val_dict = val_dicts.get(id_, {})
            children = self.cg.get_children(id_)
            err = f"parent layer less than children; op {self._operation_id}"
            # A parent must live strictly above all of its children.
            assert np.max(
                self.cg.get_chunk_layers(children)
            ) < self.cg.get_chunk_layer(id_), err
            val_dict[attributes.Hierarchy.Child] = children
            self.new_entries.append(
                self.cg.client.mutate_row(
                    serialize_uint64(id_),
                    val_dict,
                    time_stamp=self._time_stamp,
                )
            )
            # Point every child back at its new parent.
            for child_id in children:
                self.new_entries.append(
                    self.cg.client.mutate_row(
                        serialize_uint64(child_id),
                        {attributes.Hierarchy.Parent: id_},
                        time_stamp=self._time_stamp,
                    )
                )
    self._update_root_id_lineage()
0 commit comments