Skip to content

FIX: Running MCMC within JIT causes tracer leaks#2165

Merged
fehiepsi merged 1 commit intopyro-ppl:masterfrom
juanitorduz:mcmc-leak
Apr 1, 2026
Merged

FIX: Running MCMC within JIT causes tracer leaks#2165
fehiepsi merged 1 commit intopyro-ppl:masterfrom
juanitorduz:mcmc-leak

Conversation

@juanitorduz
Copy link
Copy Markdown
Collaborator

@juanitorduz juanitorduz commented Mar 30, 2026

Closes #2000

(fixed using Claude code)

This is a path to a solution, @fehiepsi. I am unsure if this is the type of solution we are after.

Root Cause

When MCMC runs inside a @jit context, JAX tracers (abstract values used during tracing) get captured in module-level caches through this chain:

_collect_and_postprocess._cache → cached key tuple → partial(_postprocess_fn) →
  closes over MCMC self → self._last_state → HMCState → JAX DynamicJaxprTracers

Three caching sites were affected:

  1. cached_by decorator (used by _collect_and_postprocess and fori_collect): Module-level ._cache OrderedDict storing functions that transitively reference traced values
  2. MCMC._cache: Instance-level cache of sample_fn/postprocess_fn pairs
  3. MCMC._init_state_cache: Instance-level cache of HMC states

Fix

Added _is_under_jax_transform() which uses jax._src.core.trace_state_clean() to detect when code is executing inside a JAX tracing context. When tracing is active, all three caching mechanisms are bypassed:

  • numpyro/util.py: cached_by skips caching when under JIT; new _is_under_jax_transform() helper
  • numpyro/infer/mcmc.py: _get_cached_fns() and _compile() skip caching when under JIT

Tests

  • Removed 4 xfail markers from existing tests that were expected to fail with JAX_CHECK_TRACER_LEAKS=1
  • Added test_mcmc_inside_jit_no_tracer_leak regression test that runs MCMC inside @jit and verifies no tracers leak into module-level caches

@juanitorduz juanitorduz self-assigned this Mar 30, 2026
@juanitorduz juanitorduz marked this pull request as draft March 30, 2026 19:08
@juanitorduz juanitorduz marked this pull request as ready for review March 30, 2026 20:03
@juanitorduz juanitorduz requested a review from fehiepsi March 30, 2026 20:03
@juanitorduz juanitorduz changed the title Issue #2000: Running MCMC within JIT causes tracer leaks FIX: Running MCMC within JIT causes tracer leaks Mar 30, 2026
@fehiepsi
Copy link
Copy Markdown
Member

This seems to be the right approach. Let me think a bit more. Thanks for addressing this!

after this PR, numpyro seems to be leak-free. :)

@fehiepsi fehiepsi merged commit d096a13 into pyro-ppl:master Apr 1, 2026
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Running mcmc within jit might cause tracer leak

2 participants