Skip to content

Commit 3f9cdea

Browse files
authored
fix (#2164)
1 parent 781b0e7 commit 3f9cdea

File tree

2 files changed

+33
-1
lines changed

2 files changed

+33
-1
lines changed

numpyro/infer/elbo.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -982,7 +982,11 @@ def get_importance_trace_enum(
982982
site["kl"] = to_funsor(
983983
kl_qp, output=funsor.Real, dim_to_name=dim_to_name
984984
)
985-
elif not is_model and (model_trace[name].get("kl") is not None):
985+
elif (
986+
not is_model
987+
and name in model_trace
988+
and (model_trace[name].get("kl") is not None)
989+
):
986990
# skip logq computation if analytic kl was computed
987991
pass
988992
else:

test/contrib/test_enum_elbo.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2493,3 +2493,31 @@ def enum_loss_fn(params_raw):
24932493

24942494
assert_equal(enum_loss, graph_loss, prec=1e-3)
24952495
assert_equal(enum_grads, graph_grads, prec=2e-2)
2496+
2497+
2498+
def test_trace_enum_elbo_with_auxiliary_site():
2499+
"""Regression test for https://github.com/pyro-ppl/numpyro/issues/2013.
2500+
TraceEnum_ELBO should handle guide sites marked with is_auxiliary=True
2501+
that have no corresponding site in the model.
2502+
"""
2503+
2504+
def model():
2505+
loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
2506+
x = pyro.sample(
2507+
"x",
2508+
dist.Categorical(jnp.array([0.3, 0.7])),
2509+
infer={"enumerate": "parallel"},
2510+
)
2511+
pyro.sample("obs", dist.Normal(loc + x, 1.0), obs=jnp.array(1.5))
2512+
2513+
def guide():
2514+
aux = pyro.sample(
2515+
"aux",
2516+
dist.Normal(jnp.zeros(2), jnp.ones(2)).to_event(1),
2517+
infer={"is_auxiliary": True},
2518+
)
2519+
pyro.sample("loc", dist.Delta(aux[0]))
2520+
2521+
elbo = infer.TraceEnum_ELBO()
2522+
loss = elbo.loss(random.key(0), {}, model, guide)
2523+
assert jnp.isfinite(loss), f"Expected finite loss, got {loss}"

0 commit comments

Comments
 (0)