Skip to content

Commit cdc90b3

Browse files
committed
test: analyze cache hits/misses
1 parent 59a1d14 commit cdc90b3

File tree

3 files changed

+163
-0
lines changed

3 files changed

+163
-0
lines changed

pychunkedgraph/__init__.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,56 @@
11
__version__ = "3.1.6"
2+
3+
import sys
4+
import warnings
5+
import logging as stdlib_logging # Use alias to avoid conflict with pychunkedgraph.logging
6+
7+
# Suppress annoying warning from python_jsonschema_objects dependency
8+
warnings.filterwarnings("ignore", message="Schema id not specified", module="python_jsonschema_objects")
9+
10+
# Export logging levels for convenience
11+
DEBUG = stdlib_logging.DEBUG
12+
INFO = stdlib_logging.INFO
13+
WARNING = stdlib_logging.WARNING
14+
ERROR = stdlib_logging.ERROR
15+
16+
# Set up library-level logger with NullHandler (Python logging best practice)
17+
stdlib_logging.getLogger(__name__).addHandler(stdlib_logging.NullHandler())
18+
19+
20+
def configure_logging(level=stdlib_logging.INFO, format_str=None, stream=None):
21+
"""
22+
Configure logging for pychunkedgraph. Call this to enable log output.
23+
24+
Works in Jupyter notebooks and scripts.
25+
26+
Args:
27+
level: Logging level (default: INFO). Use pychunkedgraph.DEBUG, .INFO, .WARNING, .ERROR
28+
format_str: Custom format string (optional)
29+
stream: Output stream (default: sys.stdout for Jupyter compatibility)
30+
31+
Example:
32+
import pychunkedgraph
33+
pychunkedgraph.configure_logging() # Enable INFO level logging
34+
pychunkedgraph.configure_logging(pychunkedgraph.DEBUG) # Enable DEBUG level
35+
"""
36+
if format_str is None:
37+
format_str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
38+
if stream is None:
39+
stream = sys.stdout
40+
41+
# Get root logger for pychunkedgraph
42+
logger = stdlib_logging.getLogger(__name__)
43+
logger.setLevel(level)
44+
45+
# Remove existing handlers and add fresh StreamHandler
46+
# This allows reconfiguring with different levels/formats
47+
for h in logger.handlers[:]:
48+
if isinstance(h, stdlib_logging.StreamHandler) and not isinstance(h, stdlib_logging.NullHandler):
49+
logger.removeHandler(h)
50+
51+
handler = stdlib_logging.StreamHandler(stream)
52+
handler.setLevel(level)
53+
handler.setFormatter(stdlib_logging.Formatter(format_str))
54+
logger.addHandler(handler)
55+
56+
return logger

pychunkedgraph/graph/cache.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
"""
33
Cache nodes, parents, children and cross edges.
44
"""
5+
import traceback
6+
from collections import defaultdict
57
from sys import maxsize
68
from datetime import datetime
79

@@ -40,6 +42,30 @@ def __init__(self, cg):
4042
self.children_cache = LRUCache(maxsize=maxsize)
4143
self.cross_chunk_edges_cache = LRUCache(maxsize=maxsize)
4244

45+
# Stats tracking for cache hits/misses
46+
self.stats = {
47+
"parents": {"hits": 0, "misses": 0, "calls": 0},
48+
"children": {"hits": 0, "misses": 0, "calls": 0},
49+
"cross_chunk_edges": {"hits": 0, "misses": 0, "calls": 0},
50+
}
51+
# Track where calls/misses come from
52+
self.call_sources = defaultdict(lambda: defaultdict(lambda: {"calls": 0, "misses": 0}))
53+
54+
def _get_caller(self, skip_frames=2):
55+
"""Get caller info (filename:line:function)."""
56+
stack = traceback.extract_stack()
57+
# Skip frames: _get_caller, the cache method, and go to actual caller
58+
if len(stack) > skip_frames:
59+
frame = stack[-(skip_frames + 1)]
60+
return f"{frame.filename.split('/')[-1]}:{frame.lineno}:{frame.name}"
61+
return "unknown"
62+
63+
def _record_call(self, cache_type, misses=0):
64+
"""Record a call and its source."""
65+
caller = self._get_caller(skip_frames=3)
66+
self.call_sources[cache_type][caller]["calls"] += 1
67+
self.call_sources[cache_type][caller]["misses"] += misses
68+
4369
def __len__(self):
4470
return (
4571
len(self.parents_cache)
@@ -52,14 +78,53 @@ def clear(self):
5278
self.children_cache.clear()
5379
self.cross_chunk_edges_cache.clear()
5480

81+
def get_stats(self):
82+
"""Return stats with hit rates calculated."""
83+
result = {}
84+
for name, s in self.stats.items():
85+
total = s["hits"] + s["misses"]
86+
hit_rate = s["hits"] / total if total > 0 else 0
87+
result[name] = {
88+
**s,
89+
"total": total,
90+
"hit_rate": f"{hit_rate:.1%}",
91+
"sources": dict(self.call_sources[name]),
92+
}
93+
return result
94+
95+
def reset_stats(self):
96+
for s in self.stats.values():
97+
s["hits"] = 0
98+
s["misses"] = 0
99+
s["calls"] = 0
100+
self.call_sources.clear()
101+
55102
def parent(self, node_id: np.uint64, *, time_stamp: datetime = None):
103+
self.stats["parents"]["calls"] += 1
104+
is_cached = node_id in self.parents_cache
105+
miss_count = 0 if is_cached else 1
106+
if is_cached:
107+
self.stats["parents"]["hits"] += 1
108+
else:
109+
self.stats["parents"]["misses"] += 1
110+
self._record_call("parents", misses=miss_count)
111+
56112
@cached(cache=self.parents_cache, key=lambda node_id: node_id)
57113
def parent_decorated(node_id):
58114
return self._cg.get_parent(node_id, raw_only=True, time_stamp=time_stamp)
59115

60116
return parent_decorated(node_id)
61117

62118
def children(self, node_id):
119+
self.stats["children"]["calls"] += 1
120+
is_cached = node_id in self.children_cache
121+
miss_count = 0 if is_cached else 1
122+
if is_cached:
123+
self.stats["children"]["hits"] += 1
124+
else:
125+
self.stats["children"]["misses"] += 1
126+
self._record_call("children", misses=miss_count)
127+
63128
@cached(cache=self.children_cache, key=lambda node_id: node_id)
64129
def children_decorated(node_id):
65130
children = self._cg.get_children(node_id, raw_only=True)
@@ -69,6 +134,15 @@ def children_decorated(node_id):
69134
return children_decorated(node_id)
70135

71136
def cross_chunk_edges(self, node_id, *, time_stamp: datetime = None):
137+
self.stats["cross_chunk_edges"]["calls"] += 1
138+
is_cached = node_id in self.cross_chunk_edges_cache
139+
miss_count = 0 if is_cached else 1
140+
if is_cached:
141+
self.stats["cross_chunk_edges"]["hits"] += 1
142+
else:
143+
self.stats["cross_chunk_edges"]["misses"] += 1
144+
self._record_call("cross_chunk_edges", misses=miss_count)
145+
72146
@cached(cache=self.cross_chunk_edges_cache, key=lambda node_id: node_id)
73147
def cross_edges_decorated(node_id):
74148
edges = self._cg.get_cross_chunk_edges(
@@ -82,7 +156,13 @@ def parents_multiple(self, node_ids: np.ndarray, *, time_stamp: datetime = None)
82156
node_ids = np.array(node_ids, dtype=NODE_ID, copy=False)
83157
if not node_ids.size:
84158
return node_ids
159+
self.stats["parents"]["calls"] += 1
85160
mask = np.in1d(node_ids, np.fromiter(self.parents_cache.keys(), dtype=NODE_ID))
161+
hits = int(np.sum(mask))
162+
misses = len(node_ids) - hits
163+
self.stats["parents"]["hits"] += hits
164+
self.stats["parents"]["misses"] += misses
165+
self._record_call("parents", misses=misses)
86166
parents = node_ids.copy()
87167
parents[mask] = self._parent_vec(node_ids[mask])
88168
parents[~mask] = self._cg.get_parents(
@@ -96,7 +176,13 @@ def children_multiple(self, node_ids: np.ndarray, *, flatten=False):
96176
node_ids = np.array(node_ids, dtype=NODE_ID, copy=False)
97177
if not node_ids.size:
98178
return result
179+
self.stats["children"]["calls"] += 1
99180
mask = np.in1d(node_ids, np.fromiter(self.children_cache.keys(), dtype=NODE_ID))
181+
hits = int(np.sum(mask))
182+
misses = len(node_ids) - hits
183+
self.stats["children"]["hits"] += hits
184+
self.stats["children"]["misses"] += misses
185+
self._record_call("children", misses=misses)
100186
cached_children_ = self._children_vec(node_ids[mask])
101187
result.update({id_: c_ for id_, c_ in zip(node_ids[mask], cached_children_)})
102188
result.update(self._cg.get_children(node_ids[~mask], raw_only=True))
@@ -114,9 +200,15 @@ def cross_chunk_edges_multiple(
114200
node_ids = np.array(node_ids, dtype=NODE_ID, copy=False)
115201
if not node_ids.size:
116202
return result
203+
self.stats["cross_chunk_edges"]["calls"] += 1
117204
mask = np.in1d(
118205
node_ids, np.fromiter(self.cross_chunk_edges_cache.keys(), dtype=NODE_ID)
119206
)
207+
hits = int(np.sum(mask))
208+
misses = len(node_ids) - hits
209+
self.stats["cross_chunk_edges"]["hits"] += hits
210+
self.stats["cross_chunk_edges"]["misses"] += misses
211+
self._record_call("cross_chunk_edges", misses=misses)
120212
cached_edges_ = self._cross_chunk_edges_vec(node_ids[mask])
121213
result.update(
122214
{id_: edges_ for id_, edges_ in zip(node_ids[mask], cached_edges_)}

pychunkedgraph/graph/operation.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# pylint: disable=invalid-name, missing-docstring, too-many-lines, protected-access, broad-exception-raised
22

3+
import logging
34
from abc import ABC, abstractmethod
45
from collections import namedtuple
56
from datetime import datetime
@@ -16,6 +17,8 @@
1617
import numpy as np
1718
from google.cloud import bigtable
1819

20+
logger = logging.getLogger(__name__)
21+
1922
from . import locks
2023
from . import edits
2124
from . import types
@@ -444,6 +447,19 @@ def execute(
444447
operation_id=lock.operation_id,
445448
timestamp=override_ts if override_ts else timestamp,
446449
)
450+
# Log cache stats
451+
if self.cg.cache:
452+
stats = self.cg.cache.get_stats()
453+
lines = [f"[Op {lock.operation_id}] Cache:"]
454+
for name, s in stats.items():
455+
lines.append(f" {name}: {s['hit_rate']} hit ({s['hits']}/{s['total']}) calls={s['calls']}")
456+
# Show top miss sources if any
457+
if s.get("sources"):
458+
top_sources = sorted(s["sources"].items(), key=lambda x: -x[1]["misses"])[:3]
459+
if top_sources and any(src[1]["misses"] > 0 for src in top_sources):
460+
src_str = ", ".join(f"{k}({v['misses']})" for k, v in top_sources if v["misses"] > 0)
461+
lines.append(f" miss sources: {src_str}")
462+
logger.info("\n".join(lines))
447463
if self.cg.meta.READ_ONLY:
448464
# return without persisting changes
449465
return GraphEditOperation.Result(

0 commit comments

Comments
 (0)