Skip to content

Commit 21ed981

Browse files
anna-grimanna-grim
andauthored
Feat mlp comparison (#665)
* refactor: feedfoward and normalization * gnn with mlp comp * upd --------- Co-authored-by: anna-grim <anna.grim@alleninstitute.org>
1 parent 506b025 commit 21ed981

File tree

1 file changed

+31
-1
lines changed

1 file changed

+31
-1
lines changed

src/neuron_proofreader/merge_proofreading/merge_datasets.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,6 @@ def clip_fragments_to_groundtruth(self, brain_id, graph):
586586
graph : SkeletonGraph
587587
Fragment graph to be clipped.
588588
"""
589-
# Compute projection distances
590589
assert brain_id in self.gt_graphs, "Must load GT before fragments!"
591590
d_gt, _ = self.gt_graphs[brain_id].kdtree.query(graph.node_xyz)
592591
nodes = np.where(d_gt > 100)[0]
@@ -621,6 +620,37 @@ def is_nearby_merge_site(self, brain_id, node):
621620
dist, _ = self.merge_site_kdtrees[brain_id].query(xyz)
622621
return dist < 100
623622

623+
def relabel_nodes(self):
624+
"""
625+
Reassigns contiguous node IDs and update all dependent structures.
626+
"""
627+
# Set node ids
628+
old_node_ids = np.array(self.nodes, dtype=int)
629+
new_node_ids = np.arange(len(old_node_ids))
630+
631+
# Set edge ids
632+
old_to_new = dict(zip(old_node_ids, new_node_ids))
633+
old_edge_ids = list(self.edges)
634+
old_irr_edge_ids = self.irreducible.edges
635+
edge_attrs = {(i, j): data for i, j, data in self.edges(data=True)}
636+
637+
# Reset graph
638+
self.clear()
639+
for (i, j) in old_edge_ids:
640+
self.add_edge(old_to_new[i], old_to_new[j], **edge_attrs[(i, j)])
641+
642+
self.irreducible.clear()
643+
for (i, j) in old_irr_edge_ids:
644+
self.irreducible.add_edge(old_to_new[i], old_to_new[j])
645+
646+
# Update attributes
647+
self.node_radius = self.node_radius[old_node_ids]
648+
self.node_xyz = self.node_xyz[old_node_ids]
649+
self.node_component_id = self.node_component_id[old_node_ids]
650+
651+
self.reassign_component_ids()
652+
self.set_kdtree()
653+
624654
def sample_node_nearby_soma(self, brain_id):
625655
subgraph = self.gt_graphs[brain_id].get_rooted_subgraph(0, 600)
626656
gt_node = util.sample_once(subgraph.nodes)

0 commit comments

Comments
 (0)