Skip to content

Commit e7c99ee

Browse files
authored
Merge pull request #259 from igerber/survey-maturity
Add SDR replicate method and FPC support for ImputationDiD/TwoStageDiD
2 parents 33ca749 + b671c89 commit e7c99ee

7 files changed

Lines changed: 1054 additions & 99 deletions

File tree

diff_diff/imputation.py

Lines changed: 104 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -212,9 +212,9 @@ def fit(
212212
relative times in [-balance_e, max_h].
213213
survey_design : SurveyDesign, optional
214214
Survey design specification for design-based inference. Supports
215-
pweight only (aweight/fweight raise ValueError). FPC raises
216-
NotImplementedError. PSU is used as cluster variable for Theorem 3
217-
variance. Strata enters survey df for t-distribution inference.
215+
pweight only (aweight/fweight raise ValueError). Supports strata,
216+
PSU, and FPC for design-based variance via compute_survey_if_variance().
217+
Strata enters survey df for t-distribution inference.
218218
Both analytical (n_bootstrap=0) and bootstrap inference are supported.
219219
220220
Returns
@@ -276,12 +276,8 @@ def fit(
276276
f"got '{resolved_survey.weight_type}'. The survey variance math "
277277
f"assumes probability weights (pweight)."
278278
)
279-
if resolved_survey.fpc is not None:
280-
raise NotImplementedError(
281-
"ImputationDiD does not yet support FPC (finite population "
282-
"correction) in SurveyDesign. Weights, strata (for survey df), "
283-
"and PSU (for cluster-robust variance) are supported."
284-
)
279+
# FPC is supported — threaded through compute_survey_if_variance()
280+
# in _compute_conservative_variance().
285281

286282
# Bootstrap + survey supported via PSU-level multiplier bootstrap.
287283

@@ -518,6 +514,7 @@ def fit(
518514
cluster_var=cluster_var,
519515
kept_cov_mask=kept_cov_mask,
520516
survey_weights=survey_weights,
517+
resolved_survey=(resolved_survey if not _uses_replicate_imp else None),
521518
)
522519

523520
# Survey degrees of freedom for t-distribution inference
@@ -537,26 +534,47 @@ def fit(
537534

538535
if aggregate in ("event_study", "all"):
539536
event_study_effects = self._aggregate_event_study(
540-
df=df, outcome=outcome, unit=unit, time=time,
541-
first_treat=first_treat, covariates=covariates,
542-
omega_0_mask=omega_0_mask, omega_1_mask=omega_1_mask,
543-
unit_fe=unit_fe, time_fe=time_fe, grand_mean=grand_mean,
544-
delta_hat=delta_hat, cluster_var=cluster_var,
545-
treatment_groups=treatment_groups, balance_e=balance_e,
546-
kept_cov_mask=kept_cov_mask, survey_weights=survey_weights,
537+
df=df,
538+
outcome=outcome,
539+
unit=unit,
540+
time=time,
541+
first_treat=first_treat,
542+
covariates=covariates,
543+
omega_0_mask=omega_0_mask,
544+
omega_1_mask=omega_1_mask,
545+
unit_fe=unit_fe,
546+
time_fe=time_fe,
547+
grand_mean=grand_mean,
548+
delta_hat=delta_hat,
549+
cluster_var=cluster_var,
550+
treatment_groups=treatment_groups,
551+
balance_e=balance_e,
552+
kept_cov_mask=kept_cov_mask,
553+
survey_weights=survey_weights,
547554
survey_df=_survey_df,
555+
resolved_survey=(resolved_survey if not _uses_replicate_imp else None),
548556
)
549557

550558
if aggregate in ("group", "all"):
551559
group_effects = self._aggregate_group(
552-
df=df, outcome=outcome, unit=unit, time=time,
553-
first_treat=first_treat, covariates=covariates,
554-
omega_0_mask=omega_0_mask, omega_1_mask=omega_1_mask,
555-
unit_fe=unit_fe, time_fe=time_fe, grand_mean=grand_mean,
556-
delta_hat=delta_hat, cluster_var=cluster_var,
560+
df=df,
561+
outcome=outcome,
562+
unit=unit,
563+
time=time,
564+
first_treat=first_treat,
565+
covariates=covariates,
566+
omega_0_mask=omega_0_mask,
567+
omega_1_mask=omega_1_mask,
568+
unit_fe=unit_fe,
569+
time_fe=time_fe,
570+
grand_mean=grand_mean,
571+
delta_hat=delta_hat,
572+
cluster_var=cluster_var,
557573
treatment_groups=treatment_groups,
558-
kept_cov_mask=kept_cov_mask, survey_weights=survey_weights,
574+
kept_cov_mask=kept_cov_mask,
575+
survey_weights=survey_weights,
559576
survey_df=_survey_df,
577+
resolved_survey=(resolved_survey if not _uses_replicate_imp else None),
560578
)
561579

562580
# Replicate variance: derive keys from actual outputs (after filtering)
@@ -568,13 +586,13 @@ def fit(
568586

569587
# Derive keys from actual outputs (excludes filtered/Prop5/ref)
570588
_sorted_rel_times = sorted(
571-
e for e in (event_study_effects or {}).keys()
589+
e
590+
for e in (event_study_effects or {}).keys()
572591
if np.isfinite(event_study_effects[e]["effect"])
573592
and event_study_effects[e].get("n_obs", 1) > 0
574593
)
575594
_sorted_groups = sorted(
576-
g for g in (group_effects or {}).keys()
577-
if np.isfinite(group_effects[g]["effect"])
595+
g for g in (group_effects or {}).keys() if np.isfinite(group_effects[g]["effect"])
578596
)
579597
_n_es = len(_sorted_rel_times)
580598

@@ -583,13 +601,9 @@ def fit(
583601
if balance_e is not None and _sorted_rel_times:
584602
df_1 = df.loc[omega_1_mask]
585603
rel_times_all = df_1["_rel_time"].values
586-
all_horizons_full = sorted(
587-
set(int(h) for h in rel_times_all if np.isfinite(h))
588-
)
604+
all_horizons_full = sorted(set(int(h) for h in rel_times_all if np.isfinite(h)))
589605
if self.horizon_max is not None:
590-
all_horizons_full = [
591-
h for h in all_horizons_full if abs(h) <= self.horizon_max
592-
]
606+
all_horizons_full = [h for h in all_horizons_full if abs(h) <= self.horizon_max]
593607
cohort_rel_times = self._build_cohort_rel_times(df, first_treat)
594608
_balanced_mask_treated = self._compute_balanced_cohort_mask(
595609
df_1, first_treat, all_horizons_full, balance_e, cohort_rel_times
@@ -598,11 +612,25 @@ def fit(
598612
# Single vectorized refit: [overall, es_e0..., grp_g0...]
599613
def _refit_imp(w_r):
600614
ufe_r, tfe_r, gm_r, delta_r, _ = self._fit_untreated_model(
601-
df, outcome, unit, time, covariates, omega_0_mask, weights=w_r,
615+
df,
616+
outcome,
617+
unit,
618+
time,
619+
covariates,
620+
omega_0_mask,
621+
weights=w_r,
602622
)
603623
tau_r, _ = self._impute_treatment_effects(
604-
df, outcome, unit, time, covariates, omega_1_mask,
605-
ufe_r, tfe_r, gm_r, delta_r,
624+
df,
625+
outcome,
626+
unit,
627+
time,
628+
covariates,
629+
omega_1_mask,
630+
ufe_r,
631+
tfe_r,
632+
gm_r,
633+
delta_r,
606634
)
607635
fin = np.isfinite(tau_r)
608636
treated_w = w_r[omega_1_mask.values]
@@ -1314,7 +1342,7 @@ def _compute_cluster_psi_sums(
13141342
ve_series = pd.Series(ve_product, index=df.index)
13151343
cluster_sums = ve_series.groupby(cluster_ids).sum()
13161344

1317-
return cluster_sums.values, cluster_sums.index.values
1345+
return cluster_sums.values, cluster_sums.index.values, ve_product
13181346

13191347
def _compute_conservative_variance(
13201348
self,
@@ -1334,6 +1362,7 @@ def _compute_conservative_variance(
13341362
cluster_var: str,
13351363
kept_cov_mask: Optional[np.ndarray] = None,
13361364
survey_weights: Optional[np.ndarray] = None,
1365+
resolved_survey=None,
13371366
) -> float:
13381367
"""
13391368
Compute conservative clustered variance (Theorem 3, Equation 7).
@@ -1346,14 +1375,17 @@ def _compute_conservative_variance(
13461375
survey_weights : np.ndarray, optional
13471376
Full-panel survey weights. When provided, untreated denominators
13481377
in v_it use survey-weighted sums instead of raw counts.
1378+
resolved_survey : ResolvedSurveyDesign, optional
1379+
When provided, uses design-based variance via
1380+
``compute_survey_if_variance()`` (supports strata, PSU, FPC).
13491381
13501382
Returns
13511383
-------
13521384
float
13531385
Standard error.
13541386
"""
13551387
sw_0 = survey_weights[omega_0_mask.values] if survey_weights is not None else None
1356-
cluster_psi_sums, _ = self._compute_cluster_psi_sums(
1388+
cluster_psi_sums, _, ve_product = self._compute_cluster_psi_sums(
13571389
df=df,
13581390
outcome=outcome,
13591391
unit=unit,
@@ -1371,6 +1403,16 @@ def _compute_conservative_variance(
13711403
kept_cov_mask=kept_cov_mask,
13721404
survey_weights_0=sw_0,
13731405
)
1406+
1407+
if resolved_survey is not None:
1408+
# Design-based variance with strata/PSU/FPC support
1409+
from diff_diff.survey import compute_survey_if_variance
1410+
1411+
variance = compute_survey_if_variance(ve_product, resolved_survey)
1412+
if np.isnan(variance):
1413+
return np.nan
1414+
return np.sqrt(max(variance, 0.0))
1415+
13741416
sigma_sq = float((cluster_psi_sums**2).sum())
13751417
return np.sqrt(max(sigma_sq, 0.0))
13761418

@@ -1588,6 +1630,7 @@ def _aggregate_event_study(
15881630
kept_cov_mask: Optional[np.ndarray] = None,
15891631
survey_weights: Optional[np.ndarray] = None,
15901632
survey_df: Optional[int] = None,
1633+
resolved_survey=None,
15911634
) -> Dict[int, Dict[str, Any]]:
15921635
"""Aggregate treatment effects by event-study horizon."""
15931636
df_1 = df.loc[omega_1_mask]
@@ -1679,9 +1722,7 @@ def _aggregate_event_study(
16791722
)
16801723
pre_rel_times = [h for h in pre_rel_times if h != ref_period]
16811724
if self.horizon_max is not None:
1682-
pre_rel_times = [
1683-
h for h in pre_rel_times if abs(h) <= self.horizon_max
1684-
]
1725+
pre_rel_times = [h for h in pre_rel_times if abs(h) <= self.horizon_max]
16851726
if pre_rel_times:
16861727
pre_effects, _, _ = self._compute_lead_coefficients(
16871728
df_0,
@@ -1783,6 +1824,7 @@ def _aggregate_event_study(
17831824
cluster_var=cluster_var,
17841825
kept_cov_mask=kept_cov_mask,
17851826
survey_weights=survey_weights,
1827+
resolved_survey=resolved_survey,
17861828
)
17871829

17881830
t_stat, p_value, conf_int = safe_inference(effect, se, alpha=self.alpha, df=survey_df)
@@ -1845,6 +1887,7 @@ def _aggregate_group(
18451887
kept_cov_mask: Optional[np.ndarray] = None,
18461888
survey_weights: Optional[np.ndarray] = None,
18471889
survey_df: Optional[int] = None,
1890+
resolved_survey=None,
18481891
) -> Dict[Any, Dict[str, Any]]:
18491892
"""Aggregate treatment effects by cohort."""
18501893
df_1 = df.loc[omega_1_mask]
@@ -1916,6 +1959,7 @@ def _aggregate_group(
19161959
cluster_var=cluster_var,
19171960
kept_cov_mask=kept_cov_mask,
19181961
survey_weights=survey_weights,
1962+
resolved_survey=resolved_survey,
19191963
)
19201964

19211965
t_stat, p_value, conf_int = safe_inference(effect, se, alpha=self.alpha, df=survey_df)
@@ -2023,14 +2067,19 @@ def _compute_lead_coefficients(
20232067
for h in pre_rel_times:
20242068
n_obs = int(df_0[f"_lead_{h}"].sum())
20252069
effects[h] = {
2026-
"effect": np.nan, "se": np.nan, "t_stat": np.nan,
2027-
"p_value": np.nan, "conf_int": (np.nan, np.nan),
2070+
"effect": np.nan,
2071+
"se": np.nan,
2072+
"t_stat": np.nan,
2073+
"p_value": np.nan,
2074+
"conf_int": (np.nan, np.nan),
20282075
"n_obs": n_obs,
20292076
}
20302077
for col in lead_cols:
20312078
df_0.drop(columns=col, inplace=True)
2032-
return effects, np.full(len(pre_rel_times), np.nan), np.full(
2033-
(len(pre_rel_times), len(pre_rel_times)), np.nan
2079+
return (
2080+
effects,
2081+
np.full(len(pre_rel_times), np.nan),
2082+
np.full((len(pre_rel_times), len(pre_rel_times)), np.nan),
20342083
)
20352084

20362085
coefficients = result[0]
@@ -2134,8 +2183,15 @@ def _pretrend_test(self, n_leads: Optional[int] = None) -> Dict[str, Any]:
21342183

21352184
# Use shared lead coefficient computation
21362185
effects, gamma, V_gamma = self._compute_lead_coefficients(
2137-
df_0, outcome, unit, time, first_treat, covariates,
2138-
cluster_var, pre_rel_times, alpha=self.alpha,
2186+
df_0,
2187+
outcome,
2188+
unit,
2189+
time,
2190+
first_treat,
2191+
covariates,
2192+
cluster_var,
2193+
pre_rel_times,
2194+
alpha=self.alpha,
21392195
)
21402196

21412197
n_leads_actual = len(pre_rel_times)
@@ -2249,9 +2305,9 @@ def imputation_did(
22492305
Balance event study to cohorts observed at all relative times.
22502306
survey_design : SurveyDesign, optional
22512307
Survey design specification for design-based inference. Supports
2252-
pweight only (aweight/fweight raise ValueError). FPC raises
2253-
NotImplementedError. PSU is used as cluster variable for Theorem 3
2254-
variance. Strata enters survey df for t-distribution inference.
2308+
pweight only (aweight/fweight raise ValueError). Supports strata,
2309+
PSU, and FPC for design-based variance. Strata enters survey df
2310+
for t-distribution inference.
22552311
Both analytical (n_bootstrap=0) and bootstrap inference are supported.
22562312
**kwargs
22572313
Additional keyword arguments passed to ImputationDiD constructor.

diff_diff/imputation_bootstrap.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def _compute_cluster_psi_sums(
9191
cluster_var: str,
9292
kept_cov_mask: Optional[np.ndarray] = None,
9393
survey_weights_0: Optional[np.ndarray] = None,
94-
) -> Tuple[np.ndarray, np.ndarray]: ...
94+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: ...
9595

9696
@staticmethod
9797
def _build_cohort_rel_times(
@@ -165,7 +165,9 @@ def _precompute_bootstrap_psi(
165165
)
166166

167167
# Overall ATT
168-
overall_psi, cluster_ids = self._compute_cluster_psi_sums(**common, weights=overall_weights)
168+
overall_psi, cluster_ids, _ = self._compute_cluster_psi_sums(
169+
**common, weights=overall_weights
170+
)
169171
result["overall"] = (overall_psi, cluster_ids)
170172

171173
# Event study: per-horizon weights
@@ -227,7 +229,7 @@ def _precompute_bootstrap_psi(
227229
if n_valid_h == 0:
228230
continue
229231

230-
psi_h, _ = self._compute_cluster_psi_sums(**common, weights=weights_h)
232+
psi_h, _, _ = self._compute_cluster_psi_sums(**common, weights=weights_h)
231233
result["event_study"][h] = psi_h
232234

233235
# Group effects: per-group weights
@@ -265,7 +267,7 @@ def _precompute_bootstrap_psi(
265267
if n_valid_g == 0:
266268
continue
267269

268-
psi_g, _ = self._compute_cluster_psi_sums(**common, weights=weights_g)
270+
psi_g, _, _ = self._compute_cluster_psi_sums(**common, weights=weights_g)
269271
result["group"][g] = psi_g
270272

271273
return result

0 commit comments

Comments
 (0)