Skip to content

Commit a4d8e79

Browse files
jeongyoonlee and claude committed
Add edge-case test for p-value with sparse treatment groups
Test verifies predictions don't contain NaN when tree nodes have zero treatment or control observations (min_samples_treatment=0, heavily imbalanced groups). Co-Authored-By: Claude Opus 4.6 <[email protected]>
1 parent 86756a8 commit a4d8e79

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

tests/test_uplift_trees.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,3 +365,27 @@ def test_UpliftTreeClassifier_with_nan_in_categorical_features():
365365
preds = uplift_model.predict(X_test)
366366
assert preds is not None
367367
assert uplift_model.fitted_uplift_tree is not None
368+
369+
370+
def test_uplift_tree_pvalue_no_nan_with_sparse_groups():
    """Check that predictions stay NaN-free with very sparse treatment groups.

    Fits an ``UpliftTreeClassifier`` with ``min_samples_treatment=0`` on a
    heavily imbalanced sample so that nodes may end up with no treatment or
    no control observations, then asserts that the predictions on the
    training data contain no NaN values (issue #585).
    """
    # Seed first: the draws below must happen in this order (features, then
    # outcomes) to reproduce the same random data.
    np.random.seed(RANDOM_SEED)
    n_samples = 50
    features = np.random.randn(n_samples, 3)

    # Heavily imbalanced assignment: 45 control rows, and only 2 and 3 rows
    # in the two treatment groups respectively.
    group_sizes = ((CONTROL_NAME, 45), ("treatment1", 2), ("treatment2", 3))
    treatment = np.array(
        [group for group, size in group_sizes for _ in range(size)]
    )
    outcome = np.random.randint(0, 2, n_samples)

    # Permissive settings: leaves of a single sample and no minimum number
    # of samples per treatment group.
    tree_params = {
        "control_name": CONTROL_NAME,
        "min_samples_leaf": 1,
        "min_samples_treatment": 0,
        "max_depth": 5,
    }
    model = UpliftTreeClassifier(**tree_params)
    model.fit(features, treatment, outcome)
    preds = model.predict(features)

    has_nan = np.isnan(preds).any()
    assert not has_nan, "Predictions contain NaN (likely from NaN p-values)"

0 commit comments

Comments
 (0)