|
1 | 1 | # Copyright Contributors to the Pyro project. |
2 | 2 | # SPDX-License-Identifier: Apache-2.0 |
3 | 3 |
|
| 4 | + |
| 5 | +import numpy as np |
4 | 6 | import pytest |
5 | 7 |
|
6 | 8 | import jax |
@@ -175,3 +177,31 @@ def _test_mixture(mixing_distribution, component_distribution): |
175 | 177 | if mixture.event_shape == (): |
176 | 178 | cdf = mixture.cdf(samples) |
177 | 179 | assert cdf.shape == (*sample_shape, *mixture.shape()) |
| 180 | + |
| 181 | + |
| 182 | +@pytest.mark.parametrize( |
| 183 | + "component_dist", |
| 184 | + [ |
| 185 | + dist.Uniform(low=np.array([0.0, 5.0]), high=np.array([1.0, 10.0])), |
| 186 | + dist.TruncatedNormal(loc=np.zeros(2), scale=np.ones(2), low=0.0, high=1.0), |
| 187 | + ], |
| 188 | +) |
| 189 | +def test_mixture_rejects_parameter_dependent_components(component_dist): |
| 190 | + mixing_dist = dist.Categorical(probs=np.array([0.5, 0.5])) |
| 191 | + with pytest.raises( |
| 192 | + AssertionError, match="expected ParameterFreeConstraint, but found " |
| 193 | + ): |
| 194 | + dist.MixtureSameFamily(mixing_dist, component_dist) |
| 195 | + |
| 196 | + |
| 197 | +@pytest.mark.parametrize( |
| 198 | + "component_dist", |
| 199 | + [ |
| 200 | + dist.Normal(np.zeros(2), np.ones(2)), |
| 201 | + dist.Exponential(np.ones(2)), |
| 202 | + dist.Bernoulli(probs=np.array([0.5, 0.5])), |
| 203 | + ], |
| 204 | +) |
| 205 | +def test_mixture_accepts_parameter_free_components(component_dist): |
| 206 | + mixing_dist = dist.Categorical(probs=np.array([0.3, 0.7])) |
| 207 | + dist.MixtureSameFamily(mixing_dist, component_dist) |
0 commit comments