@@ -2493,3 +2493,31 @@ def enum_loss_fn(params_raw):
24932493
24942494 assert_equal (enum_loss , graph_loss , prec = 1e-3 )
24952495 assert_equal (enum_grads , graph_grads , prec = 2e-2 )
2496+
2497+
2498+ def test_trace_enum_elbo_with_auxiliary_site ():
2499+ """Regression test for https://github.com/pyro-ppl/numpyro/issues/2013.
2500+ TraceEnum_ELBO should handle guide sites marked with is_auxiliary=True
2501+ that have no corresponding site in the model.
2502+ """
2503+
2504+ def model ():
2505+ loc = pyro .sample ("loc" , dist .Normal (0.0 , 1.0 ))
2506+ x = pyro .sample (
2507+ "x" ,
2508+ dist .Categorical (jnp .array ([0.3 , 0.7 ])),
2509+ infer = {"enumerate" : "parallel" },
2510+ )
2511+ pyro .sample ("obs" , dist .Normal (loc + x , 1.0 ), obs = jnp .array (1.5 ))
2512+
2513+ def guide ():
2514+ aux = pyro .sample (
2515+ "aux" ,
2516+ dist .Normal (jnp .zeros (2 ), jnp .ones (2 )).to_event (1 ),
2517+ infer = {"is_auxiliary" : True },
2518+ )
2519+ pyro .sample ("loc" , dist .Delta (aux [0 ]))
2520+
2521+ elbo = infer .TraceEnum_ELBO ()
2522+ loss = elbo .loss (random .key (0 ), {}, model , guide )
2523+ assert jnp .isfinite (loss ), f"Expected finite loss, got { loss } "
0 commit comments