|
| 1 | +import torch |
1 | 2 | from ..utils import BaseTestClass |
| 3 | +import pytest |
2 | 4 | import pandas as pd |
3 | 5 | import numpy as np |
4 | 6 |
|
5 | 7 |
|
6 | 8 | class TestClass(BaseTestClass): |
7 | | - def test_draw_samples(self, dist_class): |
8 | | - # Create data for testing |
9 | | - predt_params = pd.DataFrame(np.array([0.5 for _ in range(dist_class.dist.n_dist_param)], dtype="float32")).T |
10 | | - |
| 9 | + @pytest.mark.parametrize("n_obs", [1, 5]) |
| 10 | + def test_draw_samples(self, dist_class, n_obs): |
| 11 | + # Create data for testing with n_obs observations |
| 12 | + predt_params = pd.DataFrame( |
| 13 | + np.array( |
| 14 | + [[0.5 for _ in range(dist_class.dist.n_dist_param)] for _ in range(n_obs)], |
| 15 | + dtype="float32", |
| 16 | + ) |
| 17 | + ) |
11 | 18 | # Call the function |
12 | 19 | dist_samples = dist_class.dist.draw_samples(predt_params) |
13 | 20 |
|
14 | 21 | # Assertions |
15 | 22 | if str(dist_class.dist).split(".")[2] != "Expectile": |
16 | 23 | assert isinstance(dist_samples, (pd.DataFrame, type(None))) |
| 24 | + # row count must match number of observations |
| 25 | + assert dist_samples.shape[0] == predt_params.shape[0] |
17 | 26 | assert not dist_samples.isna().any().any() |
18 | 27 | assert not np.isinf(dist_samples).any().any() |
19 | 28 |
|
20 | | - def test_draw_samples_mixture(self, mixture_class): |
21 | | - # Create data for testing |
22 | | - predt_params = pd.DataFrame(np.array([0.5 for _ in range(mixture_class.dist.n_dist_param)], dtype="float32")).T |
23 | | - |
| 29 | + @pytest.mark.parametrize("n_obs", [1, 5]) |
| 30 | + def test_draw_samples_mixture(self, mixture_class, n_obs): |
| 31 | + # Create data for testing with n_obs observations |
| 32 | + predt_params = pd.DataFrame( |
| 33 | + np.array( |
| 34 | + [[0.5 for _ in range(mixture_class.dist.n_dist_param)] for _ in range(n_obs)], |
| 35 | + dtype="float32", |
| 36 | + ) |
| 37 | + ) |
24 | 38 | # Call the function |
25 | 39 | dist_samples = mixture_class.dist.draw_samples(predt_params) |
26 | 40 |
|
27 | 41 | # Assertions |
28 | 42 | assert isinstance(dist_samples, (pd.DataFrame, type(None))) |
| 43 | + # row count must match number of observations |
| 44 | + assert dist_samples.shape[0] == predt_params.shape[0] |
29 | 45 | assert not dist_samples.isna().any().any() |
30 | 46 | assert not np.isinf(dist_samples).any().any() |
0 commit comments