Skip to content

Commit af78dcc

Browse files
committed
edits: batch parents-dict and neighbor cross-chunk (cx) edge updates to reduce I/O latency; add cache hit/miss logging and a hierarchical profiler; initial stitch mode
1 parent 59a1d14 commit af78dcc

File tree

10 files changed

+552
-229
lines changed

10 files changed

+552
-229
lines changed

pychunkedgraph/__init__.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
__version__ = "3.1.6"


import sys
import warnings
import logging as stdlib_logging  # alias: avoid clashing with pychunkedgraph.logging

# Suppress a noisy warning emitted by the python_jsonschema_objects dependency.
warnings.filterwarnings(
    "ignore", message="Schema id not specified", module="python_jsonschema_objects"
)

# Re-export the stdlib logging levels for caller convenience
# (pychunkedgraph.DEBUG, .INFO, .WARNING, .ERROR).
DEBUG = stdlib_logging.DEBUG
INFO = stdlib_logging.INFO
WARNING = stdlib_logging.WARNING
ERROR = stdlib_logging.ERROR

# Library best practice (Python logging HOWTO): attach a NullHandler so the
# library stays silent unless the application opts in, e.g. by calling
# configure_logging() below or doing its own logging setup.
stdlib_logging.getLogger(__name__).addHandler(stdlib_logging.NullHandler())


def configure_logging(level=stdlib_logging.INFO, format_str=None, stream=None):
    """
    Configure logging for pychunkedgraph. Call this to enable log output.

    Works in Jupyter notebooks and scripts. Safe to call repeatedly: any
    StreamHandler attached by a previous call is replaced, not duplicated.

    Args:
        level: Logging level (default: INFO). Use pychunkedgraph.DEBUG, .INFO, .WARNING, .ERROR
        format_str: Custom format string (optional)
        stream: Output stream (default: sys.stdout for Jupyter compatibility)

    Returns:
        The configured ``pychunkedgraph`` package logger.

    Example:
        import pychunkedgraph
        pychunkedgraph.configure_logging()  # Enable INFO level logging
        pychunkedgraph.configure_logging(pychunkedgraph.DEBUG)  # Enable DEBUG level
    """
    if format_str is None:
        format_str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    if stream is None:
        stream = sys.stdout

    # Root logger for the pychunkedgraph package.
    logger = stdlib_logging.getLogger(__name__)
    logger.setLevel(level)

    # Remove any StreamHandler attached earlier so reconfiguring with a
    # different level/format does not produce duplicate log lines.
    # NullHandler is NOT a StreamHandler subclass, so the library's silent
    # default handler survives this sweep untouched (the previous extra
    # `not isinstance(h, NullHandler)` check was dead code).
    for existing in list(logger.handlers):
        if isinstance(existing, stdlib_logging.StreamHandler):
            logger.removeHandler(existing)

    handler = stdlib_logging.StreamHandler(stream)
    handler.setLevel(level)
    handler.setFormatter(stdlib_logging.Formatter(format_str))
    logger.addHandler(handler)

    return logger


# NOTE(review): a previous revision called configure_logging() right here, at
# import time. That forced a stdout handler onto every consumer and defeated
# the NullHandler opt-in above, so the call was removed; applications should
# call configure_logging() themselves when they want log output.

pychunkedgraph/debug/profiler.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

import os
import time
from collections import defaultdict
from contextlib import contextmanager


class HierarchicalProfiler:
    """
    Hierarchical profiler for detailed timing breakdowns.
    Tracks timing at multiple levels and prints a breakdown at the end.
    """

    def __init__(self, enabled: bool = True):
        # When False, profile() is a no-op and print_report() prints nothing.
        self.enabled = enabled
        # Wall-clock samples keyed by dotted path, e.g. "outer.inner".
        self.timings: Dict[str, List[float]] = defaultdict(list)
        self.call_counts: Dict[str, int] = defaultdict(int)
        # NOTE(review): `stack` is never pushed to anywhere in this module;
        # kept (and cleared in reset) in case external code reads it.
        self.stack: List[Tuple[str, float]] = []
        # Names of the profile() blocks currently open, outermost first.
        self.current_path: List[str] = []

    @contextmanager
    def profile(self, name: str):
        """Context manager for profiling a code block.

        Nested uses build dotted paths ("outer.inner"). Timing and call
        counts are recorded even if the wrapped block raises.
        """
        if not self.enabled:
            yield
            return

        full_path = ".".join(self.current_path + [name])
        self.current_path.append(name)
        start_time = time.perf_counter()

        try:
            yield
        finally:
            # Record in `finally` so partial (exception-raising) runs still
            # show up and the path stack stays balanced.
            elapsed = time.perf_counter() - start_time
            self.timings[full_path].append(elapsed)
            self.call_counts[full_path] += 1
            self.current_path.pop()

    def print_report(self, operation_id=None):
        """Print a detailed timing breakdown (no-op when disabled or empty)."""
        if not self.enabled or not self.timings:
            return

        print("\n" + "=" * 80)
        # Bug fix: compare against None so operation_id=0 is still displayed.
        print(
            f"PROFILER REPORT{f' (operation_id={operation_id})' if operation_id is not None else ''}"
        )
        print("=" * 80)

        # Group by depth level (number of dots in the path).
        by_depth: Dict[int, List[Tuple[str, float, int]]] = defaultdict(list)
        for path, times in self.timings.items():
            depth = path.count(".")
            total_time = sum(times)
            count = self.call_counts[path]
            by_depth[depth].append((path, total_time, count))

        # Sort each level by total time, slowest first.
        for depth in sorted(by_depth.keys()):
            items = sorted(by_depth[depth], key=lambda x: -x[1])
            for path, total_time, count in items:
                indent = "  " * depth
                avg_time = total_time / count if count > 0 else 0
                if count > 1:
                    print(
                        f"{indent}{path}: {total_time*1000:.2f}ms total "
                        f"({count} calls, {avg_time*1000:.2f}ms avg)"
                    )
                else:
                    print(f"{indent}{path}: {total_time*1000:.2f}ms")

        # Summary: only dot-free paths are top-level.
        print("-" * 80)
        top_level_total = sum(
            sum(times) for path, times in self.timings.items() if "." not in path
        )
        print(f"Total top-level time: {top_level_total*1000:.2f}ms")

        # Top 10 slowest operations across all depths.
        print("\nTop 10 slowest operations:")
        all_ops = [
            (path, sum(times), self.call_counts[path])
            for path, times in self.timings.items()
        ]
        all_ops.sort(key=lambda x: -x[1])
        for i, (path, total_time, count) in enumerate(all_ops[:10]):
            pct = (total_time / top_level_total * 100) if top_level_total > 0 else 0
            print(f"  {i+1}. {path}: {total_time*1000:.2f}ms ({pct:.1f}%)")

        print("=" * 80 + "\n")

    def reset(self):
        """Reset all timing data."""
        self.timings.clear()
        self.call_counts.clear()
        self.stack.clear()
        self.current_path.clear()


# Global profiler instance. NOTE: profiling is enabled by DEFAULT; set
# PCG_PROFILER_ENABLED to anything other than "1" (e.g. "0") to disable.
PROFILER_ENABLED = os.environ.get("PCG_PROFILER_ENABLED", "1") == "1"
_profiler: Optional[HierarchicalProfiler] = None  # lazily created by get_profiler()


def get_profiler() -> HierarchicalProfiler:
    """Get or create the global profiler instance."""
    global _profiler
    if _profiler is None:
        _profiler = HierarchicalProfiler(enabled=PROFILER_ENABLED)
    return _profiler


def reset_profiler():
    """Reset the global profiler, if one has been created."""
    global _profiler
    if _profiler is not None:
        _profiler.reset()

pychunkedgraph/debug/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,20 @@ def update_graph_id(cg, new_graph_id:str):
5656
new_gc = GraphConfig(**old_gc)
5757
new_meta = ChunkedGraphMeta(new_gc, cg.meta.data_source, cg.meta.custom_data)
5858
cg.update_meta(new_meta, overwrite=True)
59+
60+
61+
def get_random_l1_ids(cg, n_chunks=100, n_per_chunk=10, seed=None):
    """Generate random node IDs drawn from randomly selected chunks.

    Args:
        cg: ChunkedGraph instance.
        n_chunks: Number of random chunks to sample.
        n_per_chunk: Number of random segment IDs drawn per chunk.
        seed: Optional seed for reproducible sampling.

    Returns:
        np.ndarray of uint64 node IDs. Chunks holding fewer than two
        segments are skipped, so fewer than ``n_chunks * n_per_chunk``
        IDs may be returned.

    NOTE(review): the name says "l1" but chunk coordinates come from
    ``layer_chunk_bounds[2]`` and IDs are built with ``layer=2`` --
    confirm which layer is actually intended.
    """
    # Bug fix: `if seed:` silently ignored a seed of 0; compare to None.
    if seed is not None:
        np.random.seed(seed)
    bounds = cg.meta.layer_chunk_bounds[2]
    ids = []
    for _ in range(n_chunks):
        cx, cy, cz = [np.random.randint(0, b) for b in bounds]
        chunk_id = cg.get_chunk_id(layer=2, x=cx, y=cy, z=cz)
        max_seg = cg.get_segment_id(cg.id_client.get_max_node_id(chunk_id))
        if max_seg < 2:
            # Chunk is (nearly) empty; nothing useful to sample.
            continue
        for seg in np.random.randint(1, max_seg + 1, n_per_chunk):
            ids.append(cg.get_node_id(np.uint64(seg), np.uint64(chunk_id)))
    return np.array(ids, dtype=np.uint64)

pychunkedgraph/graph/cache.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
"""
33
Cache nodes, parents, children and cross edges.
44
"""
5+
import traceback
6+
from collections import defaultdict
57
from sys import maxsize
68
from datetime import datetime
79

@@ -40,6 +42,30 @@ def __init__(self, cg):
4042
self.children_cache = LRUCache(maxsize=maxsize)
4143
self.cross_chunk_edges_cache = LRUCache(maxsize=maxsize)
4244

45+
# Stats tracking for cache hits/misses
46+
self.stats = {
47+
"parents": {"hits": 0, "misses": 0, "calls": 0},
48+
"children": {"hits": 0, "misses": 0, "calls": 0},
49+
"cross_chunk_edges": {"hits": 0, "misses": 0, "calls": 0},
50+
}
51+
# Track where calls/misses come from
52+
self.call_sources = defaultdict(lambda: defaultdict(lambda: {"calls": 0, "misses": 0}))
53+
54+
def _get_caller(self, skip_frames=2):
55+
"""Get caller info (filename:line:function)."""
56+
stack = traceback.extract_stack()
57+
# Skip frames: _get_caller, the cache method, and go to actual caller
58+
if len(stack) > skip_frames:
59+
frame = stack[-(skip_frames + 1)]
60+
return f"{frame.filename.split('/')[-1]}:{frame.lineno}:{frame.name}"
61+
return "unknown"
62+
63+
def _record_call(self, cache_type, misses=0):
64+
"""Record a call and its source."""
65+
caller = self._get_caller(skip_frames=3)
66+
self.call_sources[cache_type][caller]["calls"] += 1
67+
self.call_sources[cache_type][caller]["misses"] += misses
68+
4369
def __len__(self):
4470
return (
4571
len(self.parents_cache)
@@ -52,14 +78,53 @@ def clear(self):
5278
self.children_cache.clear()
5379
self.cross_chunk_edges_cache.clear()
5480

81+
def get_stats(self):
    """Return per-cache counters augmented with total, hit rate and call sites."""
    summary = {}
    for cache_name, counters in self.stats.items():
        lookups = counters["hits"] + counters["misses"]
        rate = counters["hits"] / lookups if lookups > 0 else 0
        summary[cache_name] = dict(
            counters,
            total=lookups,
            hit_rate=f"{rate:.1%}",
            sources=dict(self.call_sources[cache_name]),
        )
    return summary
94+
95+
def reset_stats(self):
    """Zero every hit/miss/call counter and forget recorded call sites."""
    for counters in self.stats.values():
        for key in ("hits", "misses", "calls"):
            counters[key] = 0
    self.call_sources.clear()
101+
55102
def parent(self, node_id: np.uint64, *, time_stamp: datetime = None):
    """Return the parent of ``node_id``, served from ``parents_cache`` when possible.

    Hit/miss accounting is done up front with an explicit membership probe,
    so the stats reflect the cache state as seen before the cached lookup
    below executes.
    """
    self.stats["parents"]["calls"] += 1
    is_cached = node_id in self.parents_cache
    miss_count = 0 if is_cached else 1
    if is_cached:
        self.stats["parents"]["hits"] += 1
    else:
        self.stats["parents"]["misses"] += 1
    self._record_call("parents", misses=miss_count)

    # The cachetools-decorated closure is rebuilt on every call so it can
    # capture the current ``time_stamp``; the cache object itself
    # (self.parents_cache) is shared across calls, so memoization still works.
    @cached(cache=self.parents_cache, key=lambda node_id: node_id)
    def parent_decorated(node_id):
        # NOTE(review): presumably raw_only=True makes the underlying lookup
        # bypass this cache (avoiding recursion) -- confirm against
        # ChunkedGraph.get_parent.
        return self._cg.get_parent(node_id, raw_only=True, time_stamp=time_stamp)

    return parent_decorated(node_id)
61117

62118
def children(self, node_id):
119+
self.stats["children"]["calls"] += 1
120+
is_cached = node_id in self.children_cache
121+
miss_count = 0 if is_cached else 1
122+
if is_cached:
123+
self.stats["children"]["hits"] += 1
124+
else:
125+
self.stats["children"]["misses"] += 1
126+
self._record_call("children", misses=miss_count)
127+
63128
@cached(cache=self.children_cache, key=lambda node_id: node_id)
64129
def children_decorated(node_id):
65130
children = self._cg.get_children(node_id, raw_only=True)
@@ -69,6 +134,15 @@ def children_decorated(node_id):
69134
return children_decorated(node_id)
70135

71136
def cross_chunk_edges(self, node_id, *, time_stamp: datetime = None):
137+
self.stats["cross_chunk_edges"]["calls"] += 1
138+
is_cached = node_id in self.cross_chunk_edges_cache
139+
miss_count = 0 if is_cached else 1
140+
if is_cached:
141+
self.stats["cross_chunk_edges"]["hits"] += 1
142+
else:
143+
self.stats["cross_chunk_edges"]["misses"] += 1
144+
self._record_call("cross_chunk_edges", misses=miss_count)
145+
72146
@cached(cache=self.cross_chunk_edges_cache, key=lambda node_id: node_id)
73147
def cross_edges_decorated(node_id):
74148
edges = self._cg.get_cross_chunk_edges(
@@ -82,7 +156,13 @@ def parents_multiple(self, node_ids: np.ndarray, *, time_stamp: datetime = None)
82156
node_ids = np.array(node_ids, dtype=NODE_ID, copy=False)
83157
if not node_ids.size:
84158
return node_ids
159+
self.stats["parents"]["calls"] += 1
85160
mask = np.in1d(node_ids, np.fromiter(self.parents_cache.keys(), dtype=NODE_ID))
161+
hits = int(np.sum(mask))
162+
misses = len(node_ids) - hits
163+
self.stats["parents"]["hits"] += hits
164+
self.stats["parents"]["misses"] += misses
165+
self._record_call("parents", misses=misses)
86166
parents = node_ids.copy()
87167
parents[mask] = self._parent_vec(node_ids[mask])
88168
parents[~mask] = self._cg.get_parents(
@@ -96,7 +176,13 @@ def children_multiple(self, node_ids: np.ndarray, *, flatten=False):
96176
node_ids = np.array(node_ids, dtype=NODE_ID, copy=False)
97177
if not node_ids.size:
98178
return result
179+
self.stats["children"]["calls"] += 1
99180
mask = np.in1d(node_ids, np.fromiter(self.children_cache.keys(), dtype=NODE_ID))
181+
hits = int(np.sum(mask))
182+
misses = len(node_ids) - hits
183+
self.stats["children"]["hits"] += hits
184+
self.stats["children"]["misses"] += misses
185+
self._record_call("children", misses=misses)
100186
cached_children_ = self._children_vec(node_ids[mask])
101187
result.update({id_: c_ for id_, c_ in zip(node_ids[mask], cached_children_)})
102188
result.update(self._cg.get_children(node_ids[~mask], raw_only=True))
@@ -114,9 +200,15 @@ def cross_chunk_edges_multiple(
114200
node_ids = np.array(node_ids, dtype=NODE_ID, copy=False)
115201
if not node_ids.size:
116202
return result
203+
self.stats["cross_chunk_edges"]["calls"] += 1
117204
mask = np.in1d(
118205
node_ids, np.fromiter(self.cross_chunk_edges_cache.keys(), dtype=NODE_ID)
119206
)
207+
hits = int(np.sum(mask))
208+
misses = len(node_ids) - hits
209+
self.stats["cross_chunk_edges"]["hits"] += hits
210+
self.stats["cross_chunk_edges"]["misses"] += misses
211+
self._record_call("cross_chunk_edges", misses=misses)
120212
cached_edges_ = self._cross_chunk_edges_vec(node_ids[mask])
121213
result.update(
122214
{id_: edges_ for id_, edges_ in zip(node_ids[mask], cached_edges_)}

0 commit comments

Comments
 (0)