11# pylint: disable=invalid-name, missing-docstring, too-many-locals, c-extension-no-member
22
3- import datetime , random
3+ import datetime , logging , random
44from typing import Dict
55from typing import List
66from typing import Tuple
1616from . import types
1717from . import attributes
1818from . import cache as cache_utils
19- from .edges import get_latest_edges_wrapper , flip_ids
19+ from .edges import get_latest_edges_wrapper , flip_ids , get_new_nodes
2020from .edges .utils import concatenate_cross_edge_dicts
2121from .edges .utils import merge_cross_edge_dicts
2222from .utils import basetypes
2525from ..utils .general import in2d
2626from ..debug .utils import sanity_check , sanity_check_single
2727
28+ logger = logging .getLogger (__name__ )
29+
2830
2931def _init_old_hierarchy (cg , l2ids : np .ndarray , parent_ts : datetime .datetime = None ):
3032 """
@@ -684,11 +686,8 @@ def _update_cross_edge_cache_batched(self, new_ids: list):
684686
685687 # Distribute results back to each parent's cache
686688 # Key insight: edges[:, 0] are children, map them to their parent
687- edge_parents = self .cg .get_roots (
688- edge_nodes ,
689- stop_layer = parent_layer ,
690- ceil = False ,
691- time_stamp = self ._last_ts ,
689+ edge_parents = get_new_nodes (
690+ self .cg , edge_nodes , parent_layer , self ._last_ts
692691 )
693692 edge_parents_d = dict (zip (edge_nodes , edge_parents ))
694693 for new_id in new_ids :
@@ -714,6 +713,48 @@ def _update_cross_edge_cache_batched(self, new_ids: list):
714713 self .cg .cache .cross_chunk_edges_cache [new_id ] = parent_cx_edges_d
715714 return updated_entries
716715
def _get_new_ids(self, chunk_id, count, is_root):
    """
    Obtain `count` node IDs for `chunk_id` that have no existing entry.

    Candidate IDs are requested from the ID client and checked with
    `read_nodes`; candidates that already exist are dropped. The first
    request asks for `count` IDs; each further request doubles the batch
    size (capped at 2 ** 16) until enough unused IDs have been collected.

    Returns a list with at most `count` IDs (exactly `count` for count >= 0).
    """
    fresh_ids = []
    request_size = count
    while True:
        if len(fresh_ids) >= count:
            # the last batch may overshoot; surplus IDs are simply dropped
            return fresh_ids[:count]

        batch = self.cg.id_client.create_node_ids(
            chunk_id, request_size, root_chunk=is_root
        )
        taken = self.cg.client.read_nodes(node_ids=batch)
        # keep only the candidates `read_nodes` did not return an entry for
        fresh_ids.extend(set(batch) - taken.keys())
        request_size = min(2 * request_size, 2**16)
728+
def _get_new_parents(self, layer, ccs, graph_ids) -> tuple[dict, dict]:
    """
    Pick a parent layer and chunk for every connected component and fetch
    the new parent IDs for all components in one request per chunk.

    For a component with more than one node the parent lives at `layer + 1`.
    A single-node component is a skip connection: its parent is placed at the
    lowest layer above `layer` where the node has cross chunk edges, or at
    the top layer (`self.cg.meta.layer_count`) when it has none.

    Returns a pair of dicts:
      - chunk ID -> list of new IDs, one per component assigned to that chunk
      - component position in `ccs` -> (parent layer, parent chunk ID)
    """
    placement = {}
    demand = defaultdict(int)
    for position, member_idx in enumerate(ccs):
        members = graph_ids[member_idx]
        # computed afresh for every component
        target_layer = layer + 1
        if len(members) == 1:
            # skip connection
            node = members[0]
            target_layer = self.cg.meta.layer_count
            edges_d = self.cg.get_cross_chunk_edges(
                [node], time_stamp=self._last_ts
            )
            target_layer = next(
                (
                    lyr
                    for lyr in range(layer + 1, self.cg.meta.layer_count)
                    if len(edges_d[node].get(lyr, types.empty_2d)) > 0
                ),
                target_layer,
            )
        parent_chunk = self.cg.get_parent_chunk_id(members[0], target_layer)
        placement[position] = (target_layer, parent_chunk)
        demand[parent_chunk] += 1

    # NOTE(review): the shuffle presumably spreads ID requests across chunks
    # when several operations run at once — confirm the intent.
    chunks = list(demand)
    random.shuffle(chunks)
    chunk_layers = self.cg.get_chunk_layers(chunks)
    ids_by_chunk = {
        chunk: self._get_new_ids(
            chunk, demand[chunk], chunk_layer == self.cg.meta.layer_count
        )
        for chunk, chunk_layer in zip(chunks, chunk_layers)
    }
    return ids_by_chunk, placement
757+
717758 def _create_new_parents (self , layer : int ):
718759 """
719760 keep track of old IDs
@@ -726,37 +767,13 @@ def _create_new_parents(self, layer: int):
726767 """
727768 new_ids = self ._new_ids_d [layer ]
728769 layer_node_ids = self ._get_layer_node_ids (new_ids , layer )
729- components , graph_ids = self ._get_connected_components (layer_node_ids , layer )
730- for cc_indices in components :
731- parent_layer = layer + 1 # must be reset for each connected component
732- cc_ids = graph_ids [cc_indices ]
733- if len (cc_ids ) == 1 :
734- # skip connection
735- parent_layer = self .cg .meta .layer_count
736- cx_edges_d = self .cg .get_cross_chunk_edges (
737- [cc_ids [0 ]], time_stamp = self ._last_ts
738- )
739- for l in range (layer + 1 , self .cg .meta .layer_count ):
740- if len (cx_edges_d [cc_ids [0 ]].get (l , types .empty_2d )) > 0 :
741- parent_layer = l
742- break
770+ ccs , _ids = self ._get_connected_components (layer_node_ids , layer )
771+ new_parents_map , cc_layer_chunk_map = self ._get_new_parents (layer , ccs , _ids )
743772
744- # TODO: handle skip connected root id creation separately
745- chunk_id = self .cg .get_parent_chunk_id (cc_ids [0 ], parent_layer )
746- is_root = parent_layer == self .cg .meta .layer_count
747- batch_size = 1
748- parent = None
749- while parent is None :
750- candidate_ids = self .cg .id_client .create_node_ids (
751- chunk_id , batch_size , root_chunk = is_root
752- )
753- existing = self .cg .client .read_nodes (node_ids = candidate_ids )
754- for cid in candidate_ids :
755- if cid not in existing :
756- parent = cid
757- break
758- if parent is None :
759- batch_size = min (batch_size * 2 , 2 ** 16 )
773+ for i , cc_indices in enumerate (ccs ):
774+ cc_ids = _ids [cc_indices ]
775+ parent_layer , chunk_id = cc_layer_chunk_map [i ]
776+ parent = new_parents_map [chunk_id ].pop ()
760777
761778 self ._new_ids_d [parent_layer ].append (parent )
762779 self ._update_id_lineage (parent , cc_ids , layer , parent_layer )
@@ -786,19 +803,20 @@ def run(self) -> Iterable:
786803 """
787804 self ._new_ids_d [2 ] = self ._new_l2_ids
788805 for layer in range (2 , self .cg .meta .layer_count ):
789- if len (self ._new_ids_d [layer ]) == 0 :
806+ new_nodes = self ._new_ids_d [layer ]
807+ if len (new_nodes ) == 0 :
790808 continue
791- self .cg .cache .new_ids .update (self . _new_ids_d [ layer ] )
809+ self .cg .cache .new_ids .update (new_nodes )
792810 # all new IDs in this layer have been created
793811 # update their cross chunk edges and their neighbors'
794812 with self ._profiler .profile (f"l{ layer } _update_cx_cache" ):
795- entries = self ._update_cross_edge_cache_batched (self . _new_ids_d [ layer ] )
813+ entries = self ._update_cross_edge_cache_batched (new_nodes )
796814 self .new_entries .extend (entries )
797815
798816 with self ._profiler .profile (f"l{ layer } _update_neighbor_cx" ):
799817 entries = _update_neighbor_cx_edges (
800818 self .cg ,
801- self . _new_ids_d [ layer ] ,
819+ new_nodes ,
802820 self ._new_old_id_d ,
803821 self ._old_new_id_d ,
804822 time_stamp = self ._time_stamp ,
@@ -861,10 +879,24 @@ def _get_cross_edges_val_dicts(self):
861879 return val_dicts
862880
863881 def create_new_entries (self ) -> List :
882+ max_layer = self .cg .meta .layer_count
864883 val_dicts = self ._get_cross_edges_val_dicts ()
865- for layer in range (2 , self . cg . meta . layer_count + 1 ):
884+ for layer in range (2 , max_layer + 1 ):
866885 new_ids = self ._new_ids_d [layer ]
867886 for id_ in new_ids :
887+ if self .do_sanity_check :
888+ root_layer = self .cg .get_chunk_layer (self .cg .get_root (id_ ))
889+ assert root_layer == max_layer , (id_ , self .cg .get_root (id_ ))
890+
891+ if layer < max_layer :
892+ try :
893+ _parent = self .cg .get_parent (id_ )
894+ _children = self .cg .get_children (_parent )
895+ assert id_ in _children , (layer , id_ , _parent , _children )
896+ except TypeError as e :
897+ logger .error (id_ , _parent , self .cg .get_root (id_ ))
898+ raise TypeError from e
899+
868900 val_dict = val_dicts .get (id_ , {})
869901 children = self .cg .get_children (id_ )
870902 err = f"parent layer less than children; op { self ._opid } "
0 commit comments