@@ -1442,3 +1442,65 @@ def test_psu_period_factor(self):
14421442 # Same structure
14431443 assert set (data_low .columns ) == set (data_high .columns )
14441444 assert len (data_low ) == len (data_high )
1445+
def test_psu_period_factor_deff_regression(self):
    """Verify psu_period_factor=1.0 gives DEFF > 1 for the tutorial scenario.

    Generates survey DiD data with ``psu_period_factor=1.0``, restricts it
    to a 2x2 comparison (``first_treat`` in {0, 3}, ``period`` in {2, 3}),
    and checks two things:

    * the survey-design standard error of the DiD estimate is larger than
      the naive standard error, and
    * the design effect reported for the ``treat_x_post`` coefficient is
      greater than 1.
    """
    import warnings

    from diff_diff import DifferenceInDifferences, SurveyDesign
    from diff_diff.linalg import LinearRegression
    from diff_diff.prep import generate_survey_did_data

    # Scope the blanket "ignore" filter to this test body so the global
    # warning configuration is restored on exit, however the test is run.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")

        df = generate_survey_did_data(
            n_units=200,
            n_periods=8,
            cohort_periods=[3, 5],
            never_treated_frac=0.3,
            treatment_effect=2.0,
            n_strata=5,
            psu_per_stratum=8,
            fpc_per_stratum=200.0,
            weight_variation="moderate",
            psu_re_sd=2.0,
            psu_period_factor=1.0,
            seed=42,
        )
        sd = SurveyDesign(weights="weight", strata="stratum", psu="psu", fpc="fpc")

        # 2x2 subset: survey SE must exceed naive SE.
        # NOTE(review): first_treat == 0 presumably marks never-treated
        # units -- confirm against generate_survey_did_data.
        in_cohort = df["first_treat"].isin([0, 3])
        in_window = df["period"].isin([2, 3])
        c3 = df[in_cohort & in_window].copy()
        c3["post"] = (c3["period"] == 3).astype(int)
        c3["treat"] = (c3["first_treat"] == 3).astype(int)

        did = DifferenceInDifferences()
        r_naive = did.fit(c3, outcome="outcome", treatment="treat", time="post")
        r_survey = did.fit(
            c3,
            outcome="outcome",
            treatment="treat",
            time="post",
            survey_design=sd,
        )
        assert r_survey.se > r_naive.se, (
            f"Survey SE ({r_survey.se:.4f}) should exceed naive SE ({r_naive.se:.4f})"
        )

        # DEFF for treat_x_post must be > 1.
        c3["treat_x_post"] = c3["treat"] * c3["post"]
        resolved = sd.resolve(c3)
        reg = LinearRegression(include_intercept=True, survey_design=resolved)
        reg.fit(
            X=c3[["treat", "post", "treat_x_post"]].values,
            y=c3["outcome"].values,
        )

        # Look the coefficient up by name in the same list that is handed to
        # compute_deff, so the index cannot drift if the list is edited.
        coef_names = ["intercept", "treat", "post", "treat_x_post"]
        deff = reg.compute_deff(coefficient_names=coef_names)
        txp_deff = deff.deff[coef_names.index("treat_x_post")]
        assert txp_deff > 1.0, f"DEFF for treat_x_post ({txp_deff:.2f}) should be > 1"
1492+
def test_psu_period_factor_validation(self):
    """Check that invalid psu_period_factor values are rejected.

    A negative value, NaN, and positive infinity must each make
    ``generate_survey_did_data`` raise ``ValueError`` with a message that
    matches ``psu_period_factor``.
    """
    import math

    import pytest

    from diff_diff.prep import generate_survey_did_data

    # Same three invalid inputs, checked in the same order.
    invalid_factors = (-1.0, math.nan, math.inf)
    for bad_factor in invalid_factors:
        with pytest.raises(ValueError, match="psu_period_factor"):
            generate_survey_did_data(psu_period_factor=bad_factor, seed=42)
0 commit comments