@@ -605,10 +605,6 @@ def model(labels):
605605@pytest .mark .skipif (
606606 "CI" in os .environ , reason = "Compiling time the whole sampling process is slow."
607607)
608- @pytest .mark .xfail (
609- os .getenv ("JAX_CHECK_TRACER_LEAKS" ) == "1" ,
610- reason = "Expected tracer leak: https://github.com/pyro-ppl/numpyro/issues/2000" ,
611- )
612608def test_chain_inside_jit (kernel_cls , chain_method ):
613609 # NB: this feature is useful for consensus MC.
614610 # Caution: compiling time will be slow (~ 90s)
@@ -665,10 +661,6 @@ def get_samples(rng_key, data, step_size, trajectory_length, target_accept_prob)
665661@pytest .mark .skipif (
666662 "CI" in os .environ , reason = "Compiling time the whole sampling process is slow."
667663)
668- @pytest .mark .xfail (
669- os .getenv ("JAX_CHECK_TRACER_LEAKS" ) == "1" ,
670- reason = "Expected tracer leak: https://github.com/pyro-ppl/numpyro/issues/2000" ,
671- )
672664def test_chain_jit_args_smoke (chain_method , compile_args ):
673665 def model (data ):
674666 concentration = jnp .array ([1.0 , 1.0 , 1.0 ])
@@ -782,10 +774,6 @@ def potential_fn(z):
782774
783775@pytest .mark .parametrize ("jit_args" , [False , True ])
784776@pytest .mark .parametrize ("shape" , [50 , 100 ])
785- @pytest .mark .xfail (
786- os .getenv ("JAX_CHECK_TRACER_LEAKS" ) == "1" ,
787- reason = "Expected tracer leak: https://github.com/pyro-ppl/numpyro/issues/2000" ,
788- )
789777def test_reuse_mcmc_run (jit_args , shape ):
790778 y1 = np .random .normal (3 , 0.1 , (100 ,))
791779 y2 = np .random .normal (- 3 , 0.1 , (shape ,))
@@ -806,10 +794,6 @@ def model(y_obs):
806794
807795
808796@pytest .mark .parametrize ("jit_args" , [False , True ])
809- @pytest .mark .xfail (
810- os .getenv ("JAX_CHECK_TRACER_LEAKS" ) == "1" ,
811- reason = "Expected tracer leak: https://github.com/pyro-ppl/numpyro/issues/2000" ,
812- )
813797def test_model_with_multiple_exec_paths (jit_args ):
814798 def model (a = None , b = None , z = None ):
815799 int_term = numpyro .sample ("a" , dist .Normal (0.0 , 0.2 ))
@@ -839,6 +823,46 @@ def model(a=None, b=None, z=None):
839823 assert set (mcmc .get_samples ()) == {"a" , "x" , "y" , "sigma" }
840824
841825
826+ def test_mcmc_inside_jit_no_tracer_leak ():
827+ """Regression test for https://github.com/pyro-ppl/numpyro/issues/2000"""
828+ from numpyro .infer .mcmc import _collect_and_postprocess
829+ from numpyro .util import fori_collect
830+
831+ def model (data ):
832+ concentration = jnp .array ([1.0 , 1.0 , 1.0 ])
833+ p_latent = numpyro .sample ("p_latent" , dist .Dirichlet (concentration ))
834+ numpyro .sample ("obs" , dist .Categorical (p_latent ), obs = data )
835+
836+ @jit
837+ def get_samples (rng_key , data ):
838+ kernel = HMC (
839+ model , step_size = 1.0 , trajectory_length = 1.0 , target_accept_prob = 0.8
840+ )
841+ mcmc = MCMC (
842+ kernel ,
843+ num_warmup = 5 ,
844+ num_samples = 10 ,
845+ num_chains = 1 ,
846+ chain_method = "sequential" ,
847+ progress_bar = False ,
848+ )
849+ mcmc .run (rng_key , data )
850+ return mcmc .get_samples ()
851+
852+ data = dist .Categorical (jnp .array ([0.1 , 0.6 , 0.3 ])).sample (random .key (1 ), (100 ,))
853+ samples = get_samples (random .key (2 ), data )
854+ assert "p_latent" in samples
855+
856+ # Verify no traced values leaked into module-level caches
857+ for cached_fn in [_collect_and_postprocess , fori_collect ]:
858+ cache = getattr (cached_fn , "_cache" , {})
859+ leaves = jax .tree .leaves (list (cache .keys ()) + list (cache .values ()))
860+ for leaf in leaves :
861+ assert not isinstance (leaf , jax .core .Tracer ), (
862+ f"Tracer leaked into { cached_fn .__name__ } ._cache"
863+ )
864+
865+
842866@pytest .mark .parametrize ("num_chains" , [1 , 2 ])
843867@pytest .mark .parametrize ("chain_method" , ["parallel" , "sequential" , "vectorized" ])
844868@pytest .mark .parametrize ("progress_bar" , [True , False ])
0 commit comments