diff --git a/ast_java.py b/ast_java.py index f6bf063..c3dc239 100644 --- a/ast_java.py +++ b/ast_java.py @@ -14,8 +14,8 @@ import posixpath import sys +import threading from dataclasses import dataclass, field -from functools import lru_cache from typing import Iterable import tree_sitter_java as _ts_java @@ -215,10 +215,23 @@ }) -@lru_cache(maxsize=1) +# tree-sitter's ``Parser`` mutates internal state during ``parse()`` and is NOT +# thread-safe, so each OS thread gets its own instance. ``parse_java`` is called +# concurrently from worker threads when indexing runs with cocoindex's inflight +# parallelism — both directly (java_index_flow_lancedb.py: process_java_file +# offloads parse+enrich to asyncio.to_thread) and transitively +# (graph_enrich.collect_annotation_meta_chain -> _collect_annotation_decl_index +# -> parse_java, reached from enrich_chunk). The ``Language`` is immutable and +# shared; per-thread ``Parser`` construction is lazy and cheap (once per thread), +# which also preserves parse parallelism instead of serializing it. +_parser_tls = threading.local() + + def _parser() -> Parser: - lang = Language(_ts_java.language()) - return Parser(lang) + p = getattr(_parser_tls, "parser", None) + if p is None: + _parser_tls.parser = p = Parser(Language(_ts_java.language())) + return p # ---------- dataclasses ---------- diff --git a/java_index_flow_lancedb.py b/java_index_flow_lancedb.py index 0f6edac..3308118 100644 --- a/java_index_flow_lancedb.py +++ b/java_index_flow_lancedb.py @@ -16,6 +16,7 @@ """ from __future__ import annotations +import asyncio import inspect import os import sys @@ -50,7 +51,7 @@ ) from path_filtering import LayeredIgnore from ast_java import ONTOLOGY_VERSION, parse_java -from graph_enrich import enrich_chunk +from graph_enrich import collect_annotation_meta_chain, enrich_chunk, load_brownfield_overrides # Older cocoindex (e.g. 1.0.0a43) uses ``tracked=False``; newer releases renamed # the flag to ``detect_change`` (default False) and reject ``tracked``. @@ -306,6 +307,40 @@ async def _lance_cm() -> AsyncIterator[Any]: yield +def _parse_and_enrich_java( + content_bytes: bytes, + chunks: list[Any], + rel: str, + project_root: Path, +) -> list[Any]: + """Parse one Java file and enrich every chunk, off the event loop. + + Returns a list of :class:`graph_enrich.ChunkEnrichment` aligned 1:1 with + ``chunks``. Intended to run via ``asyncio.to_thread`` from + ``process_java_file`` (vectors perf lever #2): while the worker thread + parses + enriches, the event loop is free to drive other files and keep the + embedder's batching queue fed. + + Thread-safety: ``parse_java`` uses a per-thread tree-sitter ``Parser`` + (see ``ast_java._parser``), so it is safe to call concurrently from these + worker threads — including the transitive ``parse_java`` that ``enrich_chunk`` + triggers via ``collect_annotation_meta_chain`` → ``_collect_annotation_decl_index``. + ``enrich_chunk`` is otherwise pure-Python over the now-immutable AST; its + ``lru_cache`` reads are thread-safe under the GIL. + """ + ast = parse_java(content_bytes) + return [ + enrich_chunk( + ast, + chunk_start_byte=ch.start.byte_offset, + chunk_end_byte=ch.end.byte_offset, + file_path=rel, + project_root=project_root, + ) + for ch in chunks + ] + + @coco.fn(memo=True) async def process_java_file( file: localfs.File, @@ -326,6 +361,9 @@ async def process_java_file( language = detect_code_language(filename=file.file_path.path.name) or "text" cs, mn, ov = JAVA_CHUNK + # ``splitter.split`` stays inline: the module-level ``RecursiveSplitter`` + # shares one Rust object, so keeping split on the event loop preserves its + # existing single-threaded access (no new cross-file concurrency hazard). chunks = splitter.split( content, cs, @@ -335,18 +373,21 @@ async def process_java_file( ) rel = file.file_path.path.as_posix() content_bytes = content.encode("utf-8", errors="replace") - ast = parse_java(content_bytes) - for ch in chunks: + # (vectors perf lever #2) parse + enrich off the event loop so the loop can + # keep the embedder's batching queue fed while this file is being parsed. + # parse_java is thread-safe (per-thread tree-sitter Parser in ast_java). + enrichments = await asyncio.to_thread( + _parse_and_enrich_java, content_bytes, chunks, rel, project_root + ) + # (vectors perf lever #1) embed all chunks concurrently so the batched + # embedder groups them into one ``model.encode(...)`` (max_batch_size=64) + # instead of N serial batch-of-1 calls. Dominant win for ``increment`` + # (few changed files → little cross-file concurrency → otherwise no batching). + embeddings = await asyncio.gather(*(embedder.embed(ch.text) for ch in chunks)) + + for ch, enrich, emb in zip(chunks, enrichments, embeddings): rs, re = chunk_key_range(ch) - enrich = enrich_chunk( - ast, - chunk_start_byte=ch.start.byte_offset, - chunk_end_byte=ch.end.byte_offset, - file_path=rel, - project_root=project_root, - ) - emb = await embedder.embed(ch.text) table.declare_row( row=JavaLanceChunk( id=str(uuid.uuid4()), @@ -401,9 +442,11 @@ async def process_sql_file( ) rel = file.file_path.path.as_posix() - for ch in chunks: + # (vectors perf lever #1) embed chunks concurrently → batched encode. + embeddings = await asyncio.gather(*(embedder.embed(ch.text) for ch in chunks)) + + for ch, emb in zip(chunks, embeddings): rs, re = chunk_key_range(ch) - emb = await embedder.embed(ch.text) table.declare_row( row=SqlLanceChunk( id=str(uuid.uuid4()), @@ -448,9 +491,11 @@ async def process_yaml_file( ) rel = file.file_path.path.as_posix() - for ch in chunks: + # (vectors perf lever #1) embed chunks concurrently → batched encode. + embeddings = await asyncio.gather(*(embedder.embed(ch.text) for ch in chunks)) + + for ch, emb in zip(chunks, embeddings): rs, re = chunk_key_range(ch) - emb = await embedder.embed(ch.text) table.declare_row( row=YamlLanceChunk( id=str(uuid.uuid4()), @@ -501,6 +546,26 @@ async def app_main() -> None: ) project_root = coco.use_context(PROJECT_ROOT) + # Warm per-project enrichment caches ONCE on the event-loop thread, BEFORE + # coco.mount_each fans files into worker threads. collect_annotation_meta_chain + # and load_brownfield_overrides are lru_cached per (resolved) project root; + # without warming, the first wave of concurrent process_java_file worker + # threads each cold-miss and redundantly walk+parse the ENTIRE project (a + # thundering herd that would offset the embedding-batching win on large + # repos — perf lever #2 made enrich concurrent). With warming, every worker + # hits a populated cache (lru_cache reads are thread-safe). Key derivation + # mirrors enrich_chunk exactly so the warmed entries are the ones workers hit. + try: + load_brownfield_overrides(project_root) + try: + prs = str(Path(project_root).resolve()) + except OSError: + prs = str(project_root) + collect_annotation_meta_chain(prs) + except Exception: + # Warm-up must never break indexing — a failure just means workers + # cold-miss lazily (the pre-warming behavior). Swallow and continue. + pass _ignore = LayeredIgnore(project_root) _walk_excludes = _ignore.cocoindex_excluded_patterns() # Emit ONE approximate total so the parent's renderer can show a determinate diff --git a/tests/test_ast_java_thread_safety.py b/tests/test_ast_java_thread_safety.py new file mode 100644 index 0000000..dc5e3f2 --- /dev/null +++ b/tests/test_ast_java_thread_safety.py @@ -0,0 +1,66 @@ +"""Regression test: ``parse_java`` must be safe to call from multiple threads. + +``ast_java._parser()`` returns a **per-thread** tree-sitter ``Parser`` because +``Parser.parse()`` mutates internal parser state and is not thread-safe on a +shared instance. ``parse_java`` is now reached concurrently from worker threads +when indexing runs with cocoindex's inflight parallelism (both directly from +``process_java_file`` and transitively from ``enrich_chunk`` → +``collect_annotation_meta_chain``). This test locks that invariant in: a future +change that reverts to a single shared ``Parser`` would corrupt parses here +(wrong counts / ``parse_error`` / native crash). +""" +from __future__ import annotations + +from concurrent.futures import ThreadPoolExecutor + +from ast_java import parse_java + +_SRC_A = b""" +package com.example.alpha; + +import java.util.List; + +public class Alpha { + private final Beta beta; + public Alpha(Beta beta) { this.beta = beta; } + public void run(int n) { + for (int i = 0; i < n; i++) { beta.handle(i); } + } +} +""" + +_SRC_B = b""" +package com.example.beta; + +public class Beta { + public void handle(int x) { System.out.println(x); } + protected int compute(long a, long b) { return (int)(a + b); } +} +""" + + +def _facts(src: bytes) -> tuple[str, int, int]: + """Stable structural fingerprint: (package, #types, #methods).""" + ast = parse_java(src) + methods = sum(len(t.methods) for t in ast.all_types) + return (ast.package, len(ast.all_types), methods) + + +def test_parse_java_concurrent_matches_single_threaded() -> None: + ref_a = _facts(_SRC_A) + ref_b = _facts(_SRC_B) + # Loose sanity: the single-threaded references must be non-trivial and + # distinct, so the equality check below is actually exercising something. + assert ref_a[1] >= 1 and ref_a[2] >= 1 + assert ref_b[1] >= 1 and ref_b[2] >= 1 + assert ref_a != ref_b + + # 16 threads each parse both sources 60×; every result must match the + # single-threaded reference. A shared Parser would corrupt some parses. + def worker() -> bool: + return all(_facts(_SRC_A) == ref_a and _facts(_SRC_B) == ref_b for _ in range(60)) + + with ThreadPoolExecutor(max_workers=16) as ex: + results = list(ex.map(lambda _: worker(), range(16))) + + assert all(results)