Skip to content

Commit 0ea4ff2

Browse files
igerberclaude
andcommitted
Address AI review P2 findings: validate psu_period_factor and add DEFF regression test
Add input validation (finite, non-negative) for psu_period_factor. Add behavioral regression test that verifies the tutorial scenario produces survey SE > naive SE and DEFF > 1 for treat_x_post. Add validation test for negative/NaN/inf inputs. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent b8ff383 commit 0ea4ff2

2 files changed

Lines changed: 68 additions & 0 deletions

File tree

diff_diff/prep_dgp.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1272,6 +1272,12 @@ def generate_survey_did_data(
12721272
f"(g >= 2 ensures at least one pre-treatment period)"
12731273
)
12741274

1275+
if not np.isfinite(psu_period_factor) or psu_period_factor < 0:
1276+
raise ValueError(
1277+
f"psu_period_factor must be finite and non-negative, "
1278+
f"got {psu_period_factor}"
1279+
)
1280+
12751281
valid_wv = ("none", "moderate", "high")
12761282
if weight_variation not in valid_wv:
12771283
raise ValueError(

tests/test_prep.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1442,3 +1442,65 @@ def test_psu_period_factor(self):
14421442
# Same structure
14431443
assert set(data_low.columns) == set(data_high.columns)
14441444
assert len(data_low) == len(data_high)
1445+
1446+
def test_psu_period_factor_deff_regression(self):
1447+
"""Verify psu_period_factor=1.0 gives DEFF > 1 for the tutorial scenario."""
1448+
import warnings
1449+
1450+
from diff_diff import (
1451+
CallawaySantAnna,
1452+
DifferenceInDifferences,
1453+
SurveyDesign,
1454+
)
1455+
from diff_diff.linalg import LinearRegression
1456+
from diff_diff.prep import generate_survey_did_data
1457+
1458+
warnings.filterwarnings("ignore")
1459+
df = generate_survey_did_data(
1460+
n_units=200, n_periods=8, cohort_periods=[3, 5],
1461+
never_treated_frac=0.3, treatment_effect=2.0,
1462+
n_strata=5, psu_per_stratum=8, fpc_per_stratum=200.0,
1463+
weight_variation="moderate", psu_re_sd=2.0,
1464+
psu_period_factor=1.0, seed=42,
1465+
)
1466+
sd = SurveyDesign(weights="weight", strata="stratum", psu="psu", fpc="fpc")
1467+
1468+
# 2x2 subset: survey SE must exceed naive SE
1469+
c3 = df[(df["first_treat"].isin([0, 3])) & (df["period"].isin([2, 3]))].copy()
1470+
c3["post"] = (c3["period"] == 3).astype(int)
1471+
c3["treat"] = (c3["first_treat"] == 3).astype(int)
1472+
did = DifferenceInDifferences()
1473+
r_naive = did.fit(c3, outcome="outcome", treatment="treat", time="post")
1474+
r_survey = did.fit(
1475+
c3, outcome="outcome", treatment="treat", time="post",
1476+
survey_design=sd,
1477+
)
1478+
assert r_survey.se > r_naive.se, (
1479+
f"Survey SE ({r_survey.se:.4f}) should exceed naive SE ({r_naive.se:.4f})"
1480+
)
1481+
1482+
# DEFF for treat_x_post must be > 1
1483+
c3["treat_x_post"] = c3["treat"] * c3["post"]
1484+
resolved = sd.resolve(c3)
1485+
reg = LinearRegression(include_intercept=True, survey_design=resolved)
1486+
reg.fit(X=c3[["treat", "post", "treat_x_post"]].values, y=c3["outcome"].values)
1487+
deff = reg.compute_deff(
1488+
coefficient_names=["intercept", "treat", "post", "treat_x_post"]
1489+
)
1490+
txp_deff = deff.deff[3] # treat_x_post
1491+
assert txp_deff > 1.0, f"DEFF for treat_x_post ({txp_deff:.2f}) should be > 1"
1492+
1493+
def test_psu_period_factor_validation(self):
1494+
"""Test that invalid psu_period_factor values raise ValueError."""
1495+
import math
1496+
1497+
import pytest
1498+
1499+
from diff_diff.prep import generate_survey_did_data
1500+
1501+
with pytest.raises(ValueError, match="psu_period_factor"):
1502+
generate_survey_did_data(psu_period_factor=-1.0, seed=42)
1503+
with pytest.raises(ValueError, match="psu_period_factor"):
1504+
generate_survey_did_data(psu_period_factor=math.nan, seed=42)
1505+
with pytest.raises(ValueError, match="psu_period_factor"):
1506+
generate_survey_did_data(psu_period_factor=math.inf, seed=42)

0 commit comments

Comments
 (0)