Skip to content

Commit d096a13

Browse files
authored
fix leak mcmc (#2165)
1 parent d11c8dc commit d096a13

File tree

3 files changed

+60
-18
lines changed

3 files changed

+60
-18
lines changed

numpyro/infer/mcmc.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from numpyro.diagnostics import print_summary
1717
from numpyro.util import (
18+
_is_under_jax_transform,
1819
cached_by,
1920
find_stack_level,
2021
fori_collect,
@@ -447,7 +448,7 @@ def _postprocess_fn(state, args, kwargs):
447448
)
448449

449450
fns = sample_fn, postprocess_fn
450-
if key is not None:
451+
if key is not None and not _is_under_jax_transform():
451452
self._cache[key] = fns
452453
return fns
453454

@@ -539,7 +540,8 @@ def _compile(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):
539540
kwargs = jax.tree.map(lambda x: _hashable(x), tuple(sorted(kwargs.items())))
540541
key = rng_key + args + kwargs
541542
try:
542-
self._init_state_cache[key] = self._last_state
543+
if not _is_under_jax_transform():
544+
self._init_state_cache[key] = self._last_state
543545
# If unhashable arguments are provided, return None
544546
except TypeError:
545547
pass

numpyro/util.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,12 +193,28 @@ def identity(x, *args, **kwargs):
193193
return x
194194

195195

196+
def _is_under_jax_transform():
197+
"""Check if we are currently under a JAX transform (e.g. jit, vmap).
198+
199+
When under a transform, caching functions that close over traced values
200+
can cause tracer leaks (see https://github.com/pyro-ppl/numpyro/issues/2000).
201+
"""
202+
from jax._src.core import trace_state_clean
203+
204+
return not trace_state_clean()
205+
206+
196207
def cached_by(outer_fn, *keys):
197208
# Restrict cache size to prevent ref cycles.
198209
max_size = 8
199210
outer_fn._cache = getattr(outer_fn, "_cache", OrderedDict())
200211

201212
def _wrapped(fn):
213+
# Skip caching when inside a JAX tracing context to avoid
214+
# tracer leaks (https://github.com/pyro-ppl/numpyro/issues/2000).
215+
if _is_under_jax_transform():
216+
return fn
217+
202218
fn_cache = outer_fn._cache
203219
hashkeys = (*keys, fn.__name__)
204220
if hashkeys in fn_cache:

test/infer/test_mcmc.py

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -605,10 +605,6 @@ def model(labels):
605605
@pytest.mark.skipif(
606606
"CI" in os.environ, reason="Compiling time the whole sampling process is slow."
607607
)
608-
@pytest.mark.xfail(
609-
os.getenv("JAX_CHECK_TRACER_LEAKS") == "1",
610-
reason="Expected tracer leak: https://github.com/pyro-ppl/numpyro/issues/2000",
611-
)
612608
def test_chain_inside_jit(kernel_cls, chain_method):
613609
# NB: this feature is useful for consensus MC.
614610
# Caution: compiling time will be slow (~ 90s)
@@ -665,10 +661,6 @@ def get_samples(rng_key, data, step_size, trajectory_length, target_accept_prob)
665661
@pytest.mark.skipif(
666662
"CI" in os.environ, reason="Compiling time the whole sampling process is slow."
667663
)
668-
@pytest.mark.xfail(
669-
os.getenv("JAX_CHECK_TRACER_LEAKS") == "1",
670-
reason="Expected tracer leak: https://github.com/pyro-ppl/numpyro/issues/2000",
671-
)
672664
def test_chain_jit_args_smoke(chain_method, compile_args):
673665
def model(data):
674666
concentration = jnp.array([1.0, 1.0, 1.0])
@@ -782,10 +774,6 @@ def potential_fn(z):
782774

783775
@pytest.mark.parametrize("jit_args", [False, True])
784776
@pytest.mark.parametrize("shape", [50, 100])
785-
@pytest.mark.xfail(
786-
os.getenv("JAX_CHECK_TRACER_LEAKS") == "1",
787-
reason="Expected tracer leak: https://github.com/pyro-ppl/numpyro/issues/2000",
788-
)
789777
def test_reuse_mcmc_run(jit_args, shape):
790778
y1 = np.random.normal(3, 0.1, (100,))
791779
y2 = np.random.normal(-3, 0.1, (shape,))
@@ -806,10 +794,6 @@ def model(y_obs):
806794

807795

808796
@pytest.mark.parametrize("jit_args", [False, True])
809-
@pytest.mark.xfail(
810-
os.getenv("JAX_CHECK_TRACER_LEAKS") == "1",
811-
reason="Expected tracer leak: https://github.com/pyro-ppl/numpyro/issues/2000",
812-
)
813797
def test_model_with_multiple_exec_paths(jit_args):
814798
def model(a=None, b=None, z=None):
815799
int_term = numpyro.sample("a", dist.Normal(0.0, 0.2))
@@ -839,6 +823,46 @@ def model(a=None, b=None, z=None):
839823
assert set(mcmc.get_samples()) == {"a", "x", "y", "sigma"}
840824

841825

826+
def test_mcmc_inside_jit_no_tracer_leak():
827+
"""Regression test for https://github.com/pyro-ppl/numpyro/issues/2000"""
828+
from numpyro.infer.mcmc import _collect_and_postprocess
829+
from numpyro.util import fori_collect
830+
831+
def model(data):
832+
concentration = jnp.array([1.0, 1.0, 1.0])
833+
p_latent = numpyro.sample("p_latent", dist.Dirichlet(concentration))
834+
numpyro.sample("obs", dist.Categorical(p_latent), obs=data)
835+
836+
@jit
837+
def get_samples(rng_key, data):
838+
kernel = HMC(
839+
model, step_size=1.0, trajectory_length=1.0, target_accept_prob=0.8
840+
)
841+
mcmc = MCMC(
842+
kernel,
843+
num_warmup=5,
844+
num_samples=10,
845+
num_chains=1,
846+
chain_method="sequential",
847+
progress_bar=False,
848+
)
849+
mcmc.run(rng_key, data)
850+
return mcmc.get_samples()
851+
852+
data = dist.Categorical(jnp.array([0.1, 0.6, 0.3])).sample(random.key(1), (100,))
853+
samples = get_samples(random.key(2), data)
854+
assert "p_latent" in samples
855+
856+
# Verify no traced values leaked into module-level caches
857+
for cached_fn in [_collect_and_postprocess, fori_collect]:
858+
cache = getattr(cached_fn, "_cache", {})
859+
leaves = jax.tree.leaves(list(cache.keys()) + list(cache.values()))
860+
for leaf in leaves:
861+
assert not isinstance(leaf, jax.core.Tracer), (
862+
f"Tracer leaked into {cached_fn.__name__}._cache"
863+
)
864+
865+
842866
@pytest.mark.parametrize("num_chains", [1, 2])
843867
@pytest.mark.parametrize("chain_method", ["parallel", "sequential", "vectorized"])
844868
@pytest.mark.parametrize("progress_bar", [True, False])

0 commit comments

Comments
 (0)