Skip to content

Commit 490e35c

Browse files
jeongyoonleeclaude
andauthored
Fix uplift tree p-value NaN from division by zero (#585) (#882)
* Fix uplift tree p-value NaN from division by zero (#585) When a tree node has zero treatment or control observations (n_t=0 or n_c=0), the p-value variance formula divides by zero, producing NaN. Also guards against zero variance when all observations in a group have the same outcome (p=0 or p=1). Returns p_value=1.0 (maximally non-significant) in these degenerate cases. * Add edge-case test for p-value with sparse treatment groups Test verifies predictions don't contain NaN when tree nodes have zero treatment or control observations (min_samples_treatment=0, heavily imbalanced groups). --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 7975c30 commit 490e35c

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

causalml/inference/tree/uplift.pyx

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2095,7 +2095,14 @@ class UpliftTreeClassifier:
20952095
else:
20962096
p_t = cur_summary_p[suboptTreatment]
20972097
n_t = cur_summary_n[suboptTreatment]
2098-
p_value = (1. - stats.norm.cdf(fabs(p_c - p_t) / sqrt(p_t * (1 - p_t) / n_t + p_c * (1 - p_c) / n_c))) * 2
2098+
if n_t > 0 and n_c > 0:
2099+
variance = p_t * (1 - p_t) / n_t + p_c * (1 - p_c) / n_c
2100+
if variance > 0:
2101+
p_value = (1. - stats.norm.cdf(fabs(p_c - p_t) / sqrt(variance))) * 2
2102+
else:
2103+
p_value = 1.0
2104+
else:
2105+
p_value = 1.0
20992106
upliftScore = [maxDiff, p_value]
21002107

21012108
bestGain = 0.0

tests/test_uplift_trees.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,3 +365,27 @@ def test_UpliftTreeClassifier_with_nan_in_categorical_features():
365365
preds = uplift_model.predict(X_test)
366366
assert preds is not None
367367
assert uplift_model.fitted_uplift_tree is not None
368+
369+
370+
def test_uplift_tree_pvalue_no_nan_with_sparse_groups():
371+
"""Test that p-values don't become NaN when tree nodes have zero
372+
treatment or control observations (issue #585)."""
373+
np.random.seed(RANDOM_SEED)
374+
n = 50
375+
X = np.random.randn(n, 3)
376+
# Heavily imbalanced: only 2 samples in treatment1
377+
treatment = np.array([CONTROL_NAME] * 45 + ["treatment1"] * 2 + ["treatment2"] * 3)
378+
y = np.random.randint(0, 2, n)
379+
380+
model = UpliftTreeClassifier(
381+
control_name=CONTROL_NAME,
382+
min_samples_leaf=1,
383+
min_samples_treatment=0,
384+
max_depth=5,
385+
)
386+
model.fit(X, treatment, y)
387+
preds = model.predict(X)
388+
389+
assert not np.any(
390+
np.isnan(preds)
391+
), "Predictions contain NaN (likely from NaN p-values)"

0 commit comments

Comments
 (0)