Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions ast_java.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

import posixpath
import sys
import threading
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Iterable

import tree_sitter_java as _ts_java
Expand Down Expand Up @@ -215,10 +215,23 @@
})


# tree-sitter's ``Parser`` mutates internal state during ``parse()`` and is NOT
# thread-safe, so each OS thread gets its own instance. ``parse_java`` is called
# concurrently from worker threads when indexing runs with cocoindex's inflight
# parallelism — both directly (java_index_flow_lancedb.py: process_java_file
# offloads parse+enrich to asyncio.to_thread) and transitively
# (graph_enrich.collect_annotation_meta_chain -> _collect_annotation_decl_index
# -> parse_java, reached from enrich_chunk). Construction is lazy and happens
# once per thread (a ``Language`` wrapper plus its ``Parser``), which also
# preserves parse parallelism instead of serializing it behind a lock.
#
# NOTE: do not memoize ``_parser`` with ``lru_cache`` — a process-wide cache
# would hand the same ``Parser`` to every thread and reintroduce the race.
_parser_tls = threading.local()


def _parser() -> Parser:
    """Return the calling thread's tree-sitter ``Parser`` for Java.

    The parser is stored on ``_parser_tls`` under the ``parser`` attribute and
    is built the first time a given thread asks for it; later calls from the
    same thread return that same instance.

    Returns:
        A ``Parser`` configured with the Java grammar, private to the
        current thread.
    """
    p = getattr(_parser_tls, "parser", None)
    if p is None:
        # First call on this thread: build the parser and remember it.
        _parser_tls.parser = p = Parser(Language(_ts_java.language()))
    return p


# ---------- dataclasses ----------
Expand Down
95 changes: 80 additions & 15 deletions java_index_flow_lancedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""
from __future__ import annotations

import asyncio
import inspect
import os
import sys
Expand Down Expand Up @@ -50,7 +51,7 @@
)
from path_filtering import LayeredIgnore
from ast_java import ONTOLOGY_VERSION, parse_java
from graph_enrich import enrich_chunk
from graph_enrich import collect_annotation_meta_chain, enrich_chunk, load_brownfield_overrides

# Older cocoindex (e.g. 1.0.0a43) uses ``tracked=False``; newer releases renamed
# the flag to ``detect_change`` (default False) and reject ``tracked``.
Expand Down Expand Up @@ -306,6 +307,40 @@ async def _lance_cm() -> AsyncIterator[Any]:
yield


def _parse_and_enrich_java(
    content_bytes: bytes,
    chunks: list[Any],
    rel: str,
    project_root: Path,
) -> list[Any]:
    """Parse a single Java source and compute one enrichment per chunk.

    The file is parsed once with ``parse_java``; ``enrich_chunk`` is then
    called for each entry of ``chunks`` using that chunk's start/end byte
    offsets, so the returned list lines up index-for-index with ``chunks``.

    This helper is synchronous on purpose: ``process_java_file`` runs it via
    ``asyncio.to_thread`` (vectors perf lever #2) so the event loop stays
    free to drive other files and keep the embedder's batching queue fed
    while the worker thread parses and enriches.

    Thread-safety: ``parse_java`` relies on a per-thread tree-sitter
    ``Parser`` (see ``ast_java._parser``), which is what makes concurrent
    calls from worker threads acceptable — including the nested
    ``parse_java`` reached from ``enrich_chunk`` via
    ``collect_annotation_meta_chain`` → ``_collect_annotation_decl_index``.

    Args:
        content_bytes: Encoded source text of the Java file.
        chunks: Split chunks exposing ``start.byte_offset`` and
            ``end.byte_offset``.
        rel: Path of the file, forwarded to ``enrich_chunk`` as ``file_path``.
        project_root: Project root forwarded to ``enrich_chunk``.

    Returns:
        A list of ``enrich_chunk`` results, in the same order as ``chunks``.
    """
    tree = parse_java(content_bytes)
    enriched: list[Any] = []
    for piece in chunks:
        enriched.append(
            enrich_chunk(
                tree,
                chunk_start_byte=piece.start.byte_offset,
                chunk_end_byte=piece.end.byte_offset,
                file_path=rel,
                project_root=project_root,
            )
        )
    return enriched


@coco.fn(memo=True)
async def process_java_file(
file: localfs.File,
Expand All @@ -326,6 +361,9 @@ async def process_java_file(

language = detect_code_language(filename=file.file_path.path.name) or "text"
cs, mn, ov = JAVA_CHUNK
# ``splitter.split`` stays inline: the module-level ``RecursiveSplitter``
# shares one Rust object, so keeping split on the event loop preserves its
# existing single-threaded access (no new cross-file concurrency hazard).
chunks = splitter.split(
content,
cs,
Expand All @@ -335,18 +373,21 @@ async def process_java_file(
)
rel = file.file_path.path.as_posix()
content_bytes = content.encode("utf-8", errors="replace")
ast = parse_java(content_bytes)

for ch in chunks:
# (vectors perf lever #2) parse + enrich off the event loop so the loop can
# keep the embedder's batching queue fed while this file is being parsed.
# parse_java is thread-safe (per-thread tree-sitter Parser in ast_java).
enrichments = await asyncio.to_thread(
_parse_and_enrich_java, content_bytes, chunks, rel, project_root
)
# (vectors perf lever #1) embed all chunks concurrently so the batched
# embedder groups them into one ``model.encode(...)`` (max_batch_size=64)
# instead of N serial batch-of-1 calls. Dominant win for ``increment``
# (few changed files → little cross-file concurrency → otherwise no batching).
embeddings = await asyncio.gather(*(embedder.embed(ch.text) for ch in chunks))

for ch, enrich, emb in zip(chunks, enrichments, embeddings):
rs, re = chunk_key_range(ch)
enrich = enrich_chunk(
ast,
chunk_start_byte=ch.start.byte_offset,
chunk_end_byte=ch.end.byte_offset,
file_path=rel,
project_root=project_root,
)
emb = await embedder.embed(ch.text)
table.declare_row(
row=JavaLanceChunk(
id=str(uuid.uuid4()),
Expand Down Expand Up @@ -401,9 +442,11 @@ async def process_sql_file(
)
rel = file.file_path.path.as_posix()

for ch in chunks:
# (vectors perf lever #1) embed chunks concurrently → batched encode.
embeddings = await asyncio.gather(*(embedder.embed(ch.text) for ch in chunks))

for ch, emb in zip(chunks, embeddings):
rs, re = chunk_key_range(ch)
emb = await embedder.embed(ch.text)
table.declare_row(
row=SqlLanceChunk(
id=str(uuid.uuid4()),
Expand Down Expand Up @@ -448,9 +491,11 @@ async def process_yaml_file(
)
rel = file.file_path.path.as_posix()

for ch in chunks:
# (vectors perf lever #1) embed chunks concurrently → batched encode.
embeddings = await asyncio.gather(*(embedder.embed(ch.text) for ch in chunks))

for ch, emb in zip(chunks, embeddings):
rs, re = chunk_key_range(ch)
emb = await embedder.embed(ch.text)
table.declare_row(
row=YamlLanceChunk(
id=str(uuid.uuid4()),
Expand Down Expand Up @@ -501,6 +546,26 @@ async def app_main() -> None:
)

project_root = coco.use_context(PROJECT_ROOT)
# Warm per-project enrichment caches ONCE on the event-loop thread, BEFORE
# coco.mount_each fans files into worker threads. collect_annotation_meta_chain
# and load_brownfield_overrides are lru_cached per (resolved) project root;
# without warming, the first wave of concurrent process_java_file worker
# threads each cold-miss and redundantly walk+parse the ENTIRE project (a
# thundering herd that would offset the embedding-batching win on large
# repos — perf lever #2 made enrich concurrent). With warming, every worker
# hits a populated cache (lru_cache reads are thread-safe). Key derivation
# mirrors enrich_chunk exactly so the warmed entries are the ones workers hit.
try:
load_brownfield_overrides(project_root)
try:
prs = str(Path(project_root).resolve())
except OSError:
prs = str(project_root)
collect_annotation_meta_chain(prs)
except Exception:
# Warm-up must never break indexing — a failure just means workers
# cold-miss lazily (the pre-warming behavior). Swallow and continue.
pass
_ignore = LayeredIgnore(project_root)
_walk_excludes = _ignore.cocoindex_excluded_patterns()
# Emit ONE approximate total so the parent's renderer can show a determinate
Expand Down
66 changes: 66 additions & 0 deletions tests/test_ast_java_thread_safety.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""Regression test: ``parse_java`` must be safe to call from multiple threads.

``ast_java._parser()`` returns a **per-thread** tree-sitter ``Parser`` because
``Parser.parse()`` mutates internal parser state and is not thread-safe on a
shared instance. ``parse_java`` is now reached concurrently from worker threads
when indexing runs with cocoindex's inflight parallelism (both directly from
``process_java_file`` and transitively from ``enrich_chunk`` →
``collect_annotation_meta_chain``). This test locks that invariant in: a future
change that reverts to a single shared ``Parser`` would corrupt parses here
(wrong counts / ``parse_error`` / native crash).
"""
from __future__ import annotations

from concurrent.futures import ThreadPoolExecutor

from ast_java import parse_java

_SRC_A = b"""
package com.example.alpha;

import java.util.List;

public class Alpha {
private final Beta beta;
public Alpha(Beta beta) { this.beta = beta; }
public void run(int n) {
for (int i = 0; i < n; i++) { beta.handle(i); }
}
}
"""

_SRC_B = b"""
package com.example.beta;

public class Beta {
public void handle(int x) { System.out.println(x); }
protected int compute(long a, long b) { return (int)(a + b); }
}
"""


def _facts(src: bytes) -> tuple[str, int, int]:
    """Summarise a parse as ``(package, type count, total method count)``.

    The tuple is compared across threads, so it only contains values read
    straight off the parsed AST: its ``package``, the length of
    ``all_types``, and the summed length of each type's ``methods``.
    """
    tree = parse_java(src)
    method_total = 0
    for type_decl in tree.all_types:
        method_total += len(type_decl.methods)
    return (tree.package, len(tree.all_types), method_total)


def test_parse_java_concurrent_matches_single_threaded() -> None:
    """Concurrent ``parse_java`` calls must reproduce the single-threaded result."""
    expected_a = _facts(_SRC_A)
    expected_b = _facts(_SRC_B)
    # Guard against a vacuous pass: each reference must contain at least one
    # type and one method, and the two fixtures must fingerprint differently.
    for expected in (expected_a, expected_b):
        assert expected[1] >= 1 and expected[2] >= 1
    assert expected_a != expected_b

    def worker() -> bool:
        # Parse both fixtures 60 times; stop at the first mismatch against
        # the single-threaded references.
        for _ in range(60):
            if _facts(_SRC_A) != expected_a or _facts(_SRC_B) != expected_b:
                return False
        return True

    # Fan 16 workers out over a 16-thread pool and collect every verdict
    # before the pool is shut down.
    with ThreadPoolExecutor(max_workers=16) as pool:
        pending = [pool.submit(worker) for _ in range(16)]
        results = [job.result() for job in pending]

    assert all(results)
Loading