Skip to content

Commit c5fce4e

Browse files
authored
fix(gh-2123): Validate component distributions of MixtureSameFamily have parameter-independent support (#2127)
fixes #2123
1 parent c38f0f4 commit c5fce4e

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

numpyro/distributions/mixtures.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,13 @@ def __init__(
222222
*,
223223
validate_args: Optional[bool] = None,
224224
):
225+
assert isinstance(
226+
component_distribution.support, constraints.ParameterFreeConstraint
227+
), (
228+
f"Invalid component distribution: {type(component_distribution).__name__}. "
229+
"The mixture components must have a support that does not depend on their parameters "
230+
f"(expected ParameterFreeConstraint, but found {component_distribution.support})."
231+
)
225232
_check_mixing_distribution(mixing_distribution)
226233
mixture_size = mixing_distribution.probs.shape[-1]
227234
if not isinstance(component_distribution, Distribution):

test/test_distributions_mixture.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Copyright Contributors to the Pyro project.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
5+
import numpy as np
46
import pytest
57

68
import jax
@@ -175,3 +177,31 @@ def _test_mixture(mixing_distribution, component_distribution):
175177
if mixture.event_shape == ():
176178
cdf = mixture.cdf(samples)
177179
assert cdf.shape == (*sample_shape, *mixture.shape())
180+
181+
182+
@pytest.mark.parametrize(
183+
"component_dist",
184+
[
185+
dist.Uniform(low=np.array([0.0, 5.0]), high=np.array([1.0, 10.0])),
186+
dist.TruncatedNormal(loc=np.zeros(2), scale=np.ones(2), low=0.0, high=1.0),
187+
],
188+
)
189+
def test_mixture_rejects_parameter_dependent_components(component_dist):
190+
mixing_dist = dist.Categorical(probs=np.array([0.5, 0.5]))
191+
with pytest.raises(
192+
AssertionError, match="expected ParameterFreeConstraint, but found "
193+
):
194+
dist.MixtureSameFamily(mixing_dist, component_dist)
195+
196+
197+
@pytest.mark.parametrize(
198+
"component_dist",
199+
[
200+
dist.Normal(np.zeros(2), np.ones(2)),
201+
dist.Exponential(np.ones(2)),
202+
dist.Bernoulli(probs=np.array([0.5, 0.5])),
203+
],
204+
)
205+
def test_mixture_accepts_parameter_free_components(component_dist):
206+
mixing_dist = dist.Categorical(probs=np.array([0.3, 0.7]))
207+
dist.MixtureSameFamily(mixing_dist, component_dist)

0 commit comments

Comments
 (0)