Skip to content

Commit b728ff8

Browse files
authored
Merge pull request #56 from nucflash/20251205-fix-draw_samples
Fixes `draw_samples()` shape when testset consists of exactly one sample
2 parents 0f9db92 + 2ddf4fe commit b728ff8

File tree

4 files changed

+42
-12
lines changed

4 files changed

+42
-12
lines changed

lightgbmlss/distributions/distribution_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,12 @@ def draw_samples(self,
325325
pred_params = torch.tensor(predt_params.values)
326326
dist_kwargs = {arg_name: param for arg_name, param in zip(self.distribution_arg_names, pred_params.T)}
327327
dist_pred = self.distribution(**dist_kwargs)
328-
dist_samples = dist_pred.sample((n_samples,)).squeeze().detach().numpy().T
328+
# Sample: shape is (n_samples, n_obs, *event_shape)
329+
dist_samples = dist_pred.sample((n_samples,)).detach().numpy()
330+
# Flatten any event dimensions but keep (n_samples, n_obs) as outer structure
331+
dist_samples = dist_samples.reshape(n_samples, -1) # (n_samples, n_obs)
332+
dist_samples = dist_samples.T # (n_obs, n_samples)
333+
329334
dist_samples = pd.DataFrame(dist_samples)
330335
dist_samples.columns = [str("y_sample") + str(i) for i in range(dist_samples.shape[1])]
331336
else:

lightgbmlss/distributions/flow_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,8 +397,13 @@ def draw_samples(self,
397397
# Replace parameters with estimated ones
398398
_, flow_dist_pred = self.replace_parameters(pred_params, flow_dist_pred)
399399

400-
# Draw samples
401-
flow_samples = pd.DataFrame(flow_dist_pred.sample((n_samples,)).squeeze().detach().numpy().T)
400+
# Draw samples (n_samples, n_obs, *event_shape)
401+
flow_samples = flow_dist_pred.sample((n_samples,)).detach().numpy()
402+
# Flatten event dims, keep (n_samples, n_obs) as outer shape
403+
flow_samples = flow_samples.reshape(n_samples, -1) # (n_samples, n_obs)
404+
flow_samples = flow_samples.T # (n_obs, n_samples)
405+
406+
flow_samples = pd.DataFrame(flow_samples)
402407
flow_samples.columns = [str("y_sample") + str(i) for i in range(flow_samples.shape[1])]
403408

404409
if self.discrete:

lightgbmlss/distributions/mixture_distribution_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,11 @@ def draw_samples(self,
404404
pred_params = torch.tensor(predt_params.values).reshape(-1, self.n_dist_param)
405405
pred_params = torch.split(pred_params, self.M, dim=1)
406406
dist_pred = self.create_mixture_distribution(pred_params)
407-
dist_samples = dist_pred.sample((n_samples,)).squeeze().detach().numpy().T
407+
# sample (n_samples, n_obs, *event_shape)
408+
dist_samples = dist_pred.sample((n_samples,)).detach().numpy()
409+
# Flatten event dims, keep (n_samples, n_obs) as outer shape
410+
dist_samples = dist_samples.reshape(n_samples, -1) # (n_samples, n_obs)
411+
dist_samples = dist_samples.T # (n_obs, n_samples)
408412
dist_samples = pd.DataFrame(dist_samples)
409413
dist_samples.columns = [str("y_sample") + str(i) for i in range(dist_samples.shape[1])]
410414

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,46 @@
1+
import torch
12
from ..utils import BaseTestClass
3+
import pytest
24
import pandas as pd
35
import numpy as np
46

57

68
class TestClass(BaseTestClass):
7-
def test_draw_samples(self, dist_class):
8-
# Create data for testing
9-
predt_params = pd.DataFrame(np.array([0.5 for _ in range(dist_class.dist.n_dist_param)], dtype="float32")).T
10-
9+
@pytest.mark.parametrize("n_obs", [1, 5])
10+
def test_draw_samples(self, dist_class, n_obs):
11+
# Create data for testing with n_obs observations
12+
predt_params = pd.DataFrame(
13+
np.array(
14+
[[0.5 for _ in range(dist_class.dist.n_dist_param)] for _ in range(n_obs)],
15+
dtype="float32",
16+
)
17+
)
1118
# Call the function
1219
dist_samples = dist_class.dist.draw_samples(predt_params)
1320

1421
# Assertions
1522
if str(dist_class.dist).split(".")[2] != "Expectile":
1623
assert isinstance(dist_samples, (pd.DataFrame, type(None)))
24+
# row count must match number of observations
25+
assert dist_samples.shape[0] == predt_params.shape[0]
1726
assert not dist_samples.isna().any().any()
1827
assert not np.isinf(dist_samples).any().any()
1928

20-
def test_draw_samples_mixture(self, mixture_class):
21-
# Create data for testing
22-
predt_params = pd.DataFrame(np.array([0.5 for _ in range(mixture_class.dist.n_dist_param)], dtype="float32")).T
23-
29+
@pytest.mark.parametrize("n_obs", [1, 5])
30+
def test_draw_samples_mixture(self, mixture_class, n_obs):
31+
# Create data for testing with n_obs observations
32+
predt_params = pd.DataFrame(
33+
np.array(
34+
[[0.5 for _ in range(mixture_class.dist.n_dist_param)] for _ in range(n_obs)],
35+
dtype="float32",
36+
)
37+
)
2438
# Call the function
2539
dist_samples = mixture_class.dist.draw_samples(predt_params)
2640

2741
# Assertions
2842
assert isinstance(dist_samples, (pd.DataFrame, type(None)))
43+
# row count must match number of observations
44+
assert dist_samples.shape[0] == predt_params.shape[0]
2945
assert not dist_samples.isna().any().any()
3046
assert not np.isinf(dist_samples).any().any()

0 commit comments

Comments
 (0)