Commit e42d5b1

igerber and claude committed
Validate pscore_trim at fit() to guard against set_params bypass
set_params() can inject invalid pscore_trim values since it bypasses __init__ validation. Add check at fit() entry to catch 0.0, negative, and >=0.5 values before they reach IPW/DR weight formulas. Also update TripleDifference registry fallback note for error mode.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent b151613 commit e42d5b1
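
To make the bypass described in the commit message concrete, here is a minimal sketch. `TrimmedEstimator` and `fit_raises` are hypothetical stand-ins, not the library's classes, and the `set_params()` body assumes a scikit-learn-style setter that assigns attributes without re-running `__init__`. Only the range check and its error message are taken from the diff.

```python
class TrimmedEstimator:
    """Hypothetical estimator with a pscore_trim parameter (illustration only)."""

    def __init__(self, pscore_trim=0.01):
        self._validate_pscore_trim(pscore_trim)
        self.pscore_trim = pscore_trim

    @staticmethod
    def _validate_pscore_trim(value):
        # Same check as the commit. The chained comparison is also False
        # for NaN, so NaN is rejected along with 0.0, negatives, and >= 0.5.
        if not (0 < value < 0.5):
            raise ValueError(f"pscore_trim must be in (0, 0.5), got {value}")

    def set_params(self, **params):
        # Assumed sklearn-style setter: plain attribute assignment,
        # so the __init__ validation never runs here.
        for name, value in params.items():
            if not hasattr(self, name):
                raise ValueError(f"Unknown parameter: {name}")
            setattr(self, name, value)
        return self

    def fit(self):
        # Re-validate at fit() entry to catch values injected via set_params().
        self._validate_pscore_trim(self.pscore_trim)
        return self


def fit_raises(value):
    """Return True if fit() rejects a pscore_trim injected via set_params()."""
    est = TrimmedEstimator()
    est.set_params(pscore_trim=value)  # accepted without complaint
    try:
        est.fit()
    except ValueError:
        return True
    return False
```

Validating inside `set_params()` would also work, but a check at `fit()` entry additionally covers direct attribute assignment such as `est.pscore_trim = 0.0`.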

3 files changed

Lines changed: 32 additions & 1 deletion

diff_diff/staggered.py

Lines changed: 4 additions & 0 deletions
@@ -1082,6 +1082,10 @@ def fit(
         ValueError
             If required columns are missing or data validation fails.
         """
+        # Validate pscore_trim (may have been changed via set_params)
+        if not (0 < self.pscore_trim < 0.5):
+            raise ValueError(f"pscore_trim must be in (0, 0.5), got {self.pscore_trim}")
+
         # Normalize empty covariates list to None
         if covariates is not None and len(covariates) == 0:
             covariates = None
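
The open interval (0, 0.5) matters because of how a trimming bound typically enters inverse-probability weights. The sketch below assumes that propensity scores are clipped to [pscore_trim, 1 - pscore_trim] and then turned into odds weights p / (1 - p). The library's actual formulas are not shown in this diff, so `ipw_odds_weights` is a hypothetical helper.

```python
import numpy as np


def ipw_odds_weights(pscore, pscore_trim):
    """Clip scores to [trim, 1 - trim], then return odds weights p / (1 - p).

    Hypothetical helper: the clipping rule is an assumption for illustration.
    """
    clipped = np.clip(pscore, pscore_trim, 1.0 - pscore_trim)
    # Suppress the divide-by-zero warning so the unbounded case can be inspected.
    with np.errstate(divide="ignore"):
        return clipped / (1.0 - clipped)


pscore = np.array([0.0, 0.2, 0.8, 1.0])

# A valid bound keeps every weight finite.
ok = ipw_odds_weights(pscore, 0.01)

# pscore_trim = 0.0 disables trimming, so a score of 1.0 gives an infinite weight.
unbounded = ipw_odds_weights(pscore, 0.0)

# pscore_trim = 0.5 collapses the clip interval to the single point 0.5,
# so every observation ends up with the same weight.
degenerate = ipw_odds_weights(pscore, 0.5)
```

Under this assumption, negative bounds fail the same way as 0.0 (no effective trimming), and bounds above 0.5 put the lower clip limit above the upper one.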

docs/methodology/REGISTRY.md

Lines changed: 2 additions & 1 deletion
@@ -1191,7 +1191,8 @@ has no additional effect.
 - Cluster IDs: must not contain NaN (raises `ValueError`)
 - Overlap warning: emitted when >5% of observations are trimmed at pscore bounds (IPW/DR only)
 - Propensity score estimation failure: falls back to unconditional probability P(subgroup=4),
-  sets hessian=None (skipping PS correction in influence function), emits UserWarning
+  sets hessian=None (skipping PS correction in influence function), emits UserWarning.
+  Exception: when `rank_deficient_action="error"`, the error is re-raised instead of falling back.
 - Collinear covariates: detected via pivoted QR in `solve_ols()`, action controlled by
   `rank_deficient_action` ("warn", "error", "silent")
 - Non-finite influence function values (e.g., from extreme propensity scores in IPW/DR
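
The fallback rule in the registry note can be sketched roughly as follows. `estimate_pscore_with_fallback` and its `fit_pscore` callable are hypothetical, the set of caught exception types is an assumption, and the unconditional probability is simplified to the mean of a binary indicator rather than the registry's P(subgroup=4).

```python
import warnings

import numpy as np


def estimate_pscore_with_fallback(fit_pscore, treated, rank_deficient_action="warn"):
    """Return (pscore, hessian), falling back if estimation fails.

    fit_pscore: hypothetical callable returning (pscore, hessian).
    treated: binary indicator array used for the unconditional fallback.
    """
    try:
        return fit_pscore()
    except (np.linalg.LinAlgError, ValueError):
        if rank_deficient_action == "error":
            # Error mode: surface the failure instead of masking it.
            raise
        warnings.warn(
            "Propensity score estimation failed; using unconditional probability",
            UserWarning,
        )
        pscore = np.full(len(treated), float(np.mean(treated)))
        # hessian=None signals that the PS correction should be skipped.
        return pscore, None
```

Re-raising in error mode keeps `rank_deficient_action="error"` strict: a caller who asked for errors does not get a silently degraded estimate.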

tests/test_staggered.py

Lines changed: 26 additions & 0 deletions
@@ -3220,6 +3220,32 @@ def test_set_params_pscore_trim(self):
         cs.set_params(pscore_trim=0.1)
         assert cs.pscore_trim == 0.1
 
+    def test_set_params_invalid_pscore_trim_rejected_at_fit(self):
+        """Invalid pscore_trim via set_params() raises ValueError at fit()."""
+        np.random.seed(42)
+        n_units, n_periods = 50, 6
+        units = np.repeat(np.arange(n_units), n_periods)
+        times = np.tile(np.arange(n_periods), n_units)
+        first_treat = np.zeros(n_units)
+        first_treat[n_units // 2 :] = 3
+        first_treat_expanded = np.repeat(first_treat, n_periods)
+        post = (times >= first_treat_expanded) & (first_treat_expanded > 0)
+        outcomes = 1.0 + 2.0 * post + np.random.randn(len(units)) * 0.5
+        data = pd.DataFrame(
+            {
+                "unit": units,
+                "time": times,
+                "outcome": outcomes,
+                "first_treat": first_treat_expanded.astype(int),
+            }
+        )
+
+        for bad_val in [0.0, -0.1, 0.5]:
+            cs = CallawaySantAnna(estimation_method="ipw")
+            cs.set_params(pscore_trim=bad_val)
+            with pytest.raises(ValueError, match="pscore_trim must be in"):
+                cs.fit(data, outcome="outcome", unit="unit", time="time", first_treat="first_treat")
+
     def test_default_pscore_trim(self):
         """Default pscore_trim is 0.01."""
         cs = CallawaySantAnna()
