@@ -212,9 +212,9 @@ def fit(
212212 relative times in [-balance_e, max_h].
213213 survey_design : SurveyDesign, optional
214214 Survey design specification for design-based inference. Supports
215- pweight only (aweight/fweight raise ValueError). FPC raises
216- NotImplementedError. PSU is used as cluster variable for Theorem 3
217- variance. Strata enters survey df for t-distribution inference.
215+ pweight only (aweight/fweight raise ValueError). Supports strata,
216+ PSU, and FPC for design-based variance via compute_survey_if_variance().
217+ Strata enters survey df for t-distribution inference.
218218 Both analytical (n_bootstrap=0) and bootstrap inference are supported.
219219
220220 Returns
@@ -276,12 +276,8 @@ def fit(
276276 f"got '{ resolved_survey .weight_type } '. The survey variance math "
277277 f"assumes probability weights (pweight)."
278278 )
279- if resolved_survey .fpc is not None :
280- raise NotImplementedError (
281- "ImputationDiD does not yet support FPC (finite population "
282- "correction) in SurveyDesign. Weights, strata (for survey df), "
283- "and PSU (for cluster-robust variance) are supported."
284- )
279+ # FPC is supported — threaded through compute_survey_if_variance()
280+ # in _compute_conservative_variance().
285281
286282 # Bootstrap + survey supported via PSU-level multiplier bootstrap.
287283
@@ -518,6 +514,7 @@ def fit(
518514 cluster_var = cluster_var ,
519515 kept_cov_mask = kept_cov_mask ,
520516 survey_weights = survey_weights ,
517+ resolved_survey = (resolved_survey if not _uses_replicate_imp else None ),
521518 )
522519
523520 # Survey degrees of freedom for t-distribution inference
@@ -537,26 +534,47 @@ def fit(
537534
538535 if aggregate in ("event_study" , "all" ):
539536 event_study_effects = self ._aggregate_event_study (
540- df = df , outcome = outcome , unit = unit , time = time ,
541- first_treat = first_treat , covariates = covariates ,
542- omega_0_mask = omega_0_mask , omega_1_mask = omega_1_mask ,
543- unit_fe = unit_fe , time_fe = time_fe , grand_mean = grand_mean ,
544- delta_hat = delta_hat , cluster_var = cluster_var ,
545- treatment_groups = treatment_groups , balance_e = balance_e ,
546- kept_cov_mask = kept_cov_mask , survey_weights = survey_weights ,
537+ df = df ,
538+ outcome = outcome ,
539+ unit = unit ,
540+ time = time ,
541+ first_treat = first_treat ,
542+ covariates = covariates ,
543+ omega_0_mask = omega_0_mask ,
544+ omega_1_mask = omega_1_mask ,
545+ unit_fe = unit_fe ,
546+ time_fe = time_fe ,
547+ grand_mean = grand_mean ,
548+ delta_hat = delta_hat ,
549+ cluster_var = cluster_var ,
550+ treatment_groups = treatment_groups ,
551+ balance_e = balance_e ,
552+ kept_cov_mask = kept_cov_mask ,
553+ survey_weights = survey_weights ,
547554 survey_df = _survey_df ,
555+ resolved_survey = (resolved_survey if not _uses_replicate_imp else None ),
548556 )
549557
550558 if aggregate in ("group" , "all" ):
551559 group_effects = self ._aggregate_group (
552- df = df , outcome = outcome , unit = unit , time = time ,
553- first_treat = first_treat , covariates = covariates ,
554- omega_0_mask = omega_0_mask , omega_1_mask = omega_1_mask ,
555- unit_fe = unit_fe , time_fe = time_fe , grand_mean = grand_mean ,
556- delta_hat = delta_hat , cluster_var = cluster_var ,
560+ df = df ,
561+ outcome = outcome ,
562+ unit = unit ,
563+ time = time ,
564+ first_treat = first_treat ,
565+ covariates = covariates ,
566+ omega_0_mask = omega_0_mask ,
567+ omega_1_mask = omega_1_mask ,
568+ unit_fe = unit_fe ,
569+ time_fe = time_fe ,
570+ grand_mean = grand_mean ,
571+ delta_hat = delta_hat ,
572+ cluster_var = cluster_var ,
557573 treatment_groups = treatment_groups ,
558- kept_cov_mask = kept_cov_mask , survey_weights = survey_weights ,
574+ kept_cov_mask = kept_cov_mask ,
575+ survey_weights = survey_weights ,
559576 survey_df = _survey_df ,
577+ resolved_survey = (resolved_survey if not _uses_replicate_imp else None ),
560578 )
561579
562580 # Replicate variance: derive keys from actual outputs (after filtering)
@@ -568,13 +586,13 @@ def fit(
568586
569587 # Derive keys from actual outputs (excludes filtered/Prop5/ref)
570588 _sorted_rel_times = sorted (
571- e for e in (event_study_effects or {}).keys ()
589+ e
590+ for e in (event_study_effects or {}).keys ()
572591 if np .isfinite (event_study_effects [e ]["effect" ])
573592 and event_study_effects [e ].get ("n_obs" , 1 ) > 0
574593 )
575594 _sorted_groups = sorted (
576- g for g in (group_effects or {}).keys ()
577- if np .isfinite (group_effects [g ]["effect" ])
595+ g for g in (group_effects or {}).keys () if np .isfinite (group_effects [g ]["effect" ])
578596 )
579597 _n_es = len (_sorted_rel_times )
580598
@@ -583,13 +601,9 @@ def fit(
583601 if balance_e is not None and _sorted_rel_times :
584602 df_1 = df .loc [omega_1_mask ]
585603 rel_times_all = df_1 ["_rel_time" ].values
586- all_horizons_full = sorted (
587- set (int (h ) for h in rel_times_all if np .isfinite (h ))
588- )
604+ all_horizons_full = sorted (set (int (h ) for h in rel_times_all if np .isfinite (h )))
589605 if self .horizon_max is not None :
590- all_horizons_full = [
591- h for h in all_horizons_full if abs (h ) <= self .horizon_max
592- ]
606+ all_horizons_full = [h for h in all_horizons_full if abs (h ) <= self .horizon_max ]
593607 cohort_rel_times = self ._build_cohort_rel_times (df , first_treat )
594608 _balanced_mask_treated = self ._compute_balanced_cohort_mask (
595609 df_1 , first_treat , all_horizons_full , balance_e , cohort_rel_times
@@ -598,11 +612,25 @@ def fit(
598612 # Single vectorized refit: [overall, es_e0..., grp_g0...]
599613 def _refit_imp (w_r ):
600614 ufe_r , tfe_r , gm_r , delta_r , _ = self ._fit_untreated_model (
601- df , outcome , unit , time , covariates , omega_0_mask , weights = w_r ,
615+ df ,
616+ outcome ,
617+ unit ,
618+ time ,
619+ covariates ,
620+ omega_0_mask ,
621+ weights = w_r ,
602622 )
603623 tau_r , _ = self ._impute_treatment_effects (
604- df , outcome , unit , time , covariates , omega_1_mask ,
605- ufe_r , tfe_r , gm_r , delta_r ,
624+ df ,
625+ outcome ,
626+ unit ,
627+ time ,
628+ covariates ,
629+ omega_1_mask ,
630+ ufe_r ,
631+ tfe_r ,
632+ gm_r ,
633+ delta_r ,
606634 )
607635 fin = np .isfinite (tau_r )
608636 treated_w = w_r [omega_1_mask .values ]
@@ -1314,7 +1342,7 @@ def _compute_cluster_psi_sums(
13141342 ve_series = pd .Series (ve_product , index = df .index )
13151343 cluster_sums = ve_series .groupby (cluster_ids ).sum ()
13161344
1317- return cluster_sums .values , cluster_sums .index .values
1345+ return cluster_sums .values , cluster_sums .index .values , ve_product
13181346
13191347 def _compute_conservative_variance (
13201348 self ,
@@ -1334,6 +1362,7 @@ def _compute_conservative_variance(
13341362 cluster_var : str ,
13351363 kept_cov_mask : Optional [np .ndarray ] = None ,
13361364 survey_weights : Optional [np .ndarray ] = None ,
1365+ resolved_survey = None ,
13371366 ) -> float :
13381367 """
13391368 Compute conservative clustered variance (Theorem 3, Equation 7).
@@ -1346,14 +1375,17 @@ def _compute_conservative_variance(
13461375 survey_weights : np.ndarray, optional
13471376 Full-panel survey weights. When provided, untreated denominators
13481377 in v_it use survey-weighted sums instead of raw counts.
1378+ resolved_survey : ResolvedSurveyDesign, optional
1379+ When provided, uses design-based variance via
1380+ ``compute_survey_if_variance()`` (supports strata, PSU, FPC).
13491381
13501382 Returns
13511383 -------
13521384 float
13531385 Standard error.
13541386 """
13551387 sw_0 = survey_weights [omega_0_mask .values ] if survey_weights is not None else None
1356- cluster_psi_sums , _ = self ._compute_cluster_psi_sums (
1388+ cluster_psi_sums , _ , ve_product = self ._compute_cluster_psi_sums (
13571389 df = df ,
13581390 outcome = outcome ,
13591391 unit = unit ,
@@ -1371,6 +1403,16 @@ def _compute_conservative_variance(
13711403 kept_cov_mask = kept_cov_mask ,
13721404 survey_weights_0 = sw_0 ,
13731405 )
1406+
1407+ if resolved_survey is not None :
1408+ # Design-based variance with strata/PSU/FPC support
1409+ from diff_diff .survey import compute_survey_if_variance
1410+
1411+ variance = compute_survey_if_variance (ve_product , resolved_survey )
1412+ if np .isnan (variance ):
1413+ return np .nan
1414+ return np .sqrt (max (variance , 0.0 ))
1415+
13741416 sigma_sq = float ((cluster_psi_sums ** 2 ).sum ())
13751417 return np .sqrt (max (sigma_sq , 0.0 ))
13761418
@@ -1588,6 +1630,7 @@ def _aggregate_event_study(
15881630 kept_cov_mask : Optional [np .ndarray ] = None ,
15891631 survey_weights : Optional [np .ndarray ] = None ,
15901632 survey_df : Optional [int ] = None ,
1633+ resolved_survey = None ,
15911634 ) -> Dict [int , Dict [str , Any ]]:
15921635 """Aggregate treatment effects by event-study horizon."""
15931636 df_1 = df .loc [omega_1_mask ]
@@ -1679,9 +1722,7 @@ def _aggregate_event_study(
16791722 )
16801723 pre_rel_times = [h for h in pre_rel_times if h != ref_period ]
16811724 if self .horizon_max is not None :
1682- pre_rel_times = [
1683- h for h in pre_rel_times if abs (h ) <= self .horizon_max
1684- ]
1725+ pre_rel_times = [h for h in pre_rel_times if abs (h ) <= self .horizon_max ]
16851726 if pre_rel_times :
16861727 pre_effects , _ , _ = self ._compute_lead_coefficients (
16871728 df_0 ,
@@ -1783,6 +1824,7 @@ def _aggregate_event_study(
17831824 cluster_var = cluster_var ,
17841825 kept_cov_mask = kept_cov_mask ,
17851826 survey_weights = survey_weights ,
1827+ resolved_survey = resolved_survey ,
17861828 )
17871829
17881830 t_stat , p_value , conf_int = safe_inference (effect , se , alpha = self .alpha , df = survey_df )
@@ -1845,6 +1887,7 @@ def _aggregate_group(
18451887 kept_cov_mask : Optional [np .ndarray ] = None ,
18461888 survey_weights : Optional [np .ndarray ] = None ,
18471889 survey_df : Optional [int ] = None ,
1890+ resolved_survey = None ,
18481891 ) -> Dict [Any , Dict [str , Any ]]:
18491892 """Aggregate treatment effects by cohort."""
18501893 df_1 = df .loc [omega_1_mask ]
@@ -1916,6 +1959,7 @@ def _aggregate_group(
19161959 cluster_var = cluster_var ,
19171960 kept_cov_mask = kept_cov_mask ,
19181961 survey_weights = survey_weights ,
1962+ resolved_survey = resolved_survey ,
19191963 )
19201964
19211965 t_stat , p_value , conf_int = safe_inference (effect , se , alpha = self .alpha , df = survey_df )
@@ -2023,14 +2067,19 @@ def _compute_lead_coefficients(
20232067 for h in pre_rel_times :
20242068 n_obs = int (df_0 [f"_lead_{ h } " ].sum ())
20252069 effects [h ] = {
2026- "effect" : np .nan , "se" : np .nan , "t_stat" : np .nan ,
2027- "p_value" : np .nan , "conf_int" : (np .nan , np .nan ),
2070+ "effect" : np .nan ,
2071+ "se" : np .nan ,
2072+ "t_stat" : np .nan ,
2073+ "p_value" : np .nan ,
2074+ "conf_int" : (np .nan , np .nan ),
20282075 "n_obs" : n_obs ,
20292076 }
20302077 for col in lead_cols :
20312078 df_0 .drop (columns = col , inplace = True )
2032- return effects , np .full (len (pre_rel_times ), np .nan ), np .full (
2033- (len (pre_rel_times ), len (pre_rel_times )), np .nan
2079+ return (
2080+ effects ,
2081+ np .full (len (pre_rel_times ), np .nan ),
2082+ np .full ((len (pre_rel_times ), len (pre_rel_times )), np .nan ),
20342083 )
20352084
20362085 coefficients = result [0 ]
@@ -2134,8 +2183,15 @@ def _pretrend_test(self, n_leads: Optional[int] = None) -> Dict[str, Any]:
21342183
21352184 # Use shared lead coefficient computation
21362185 effects , gamma , V_gamma = self ._compute_lead_coefficients (
2137- df_0 , outcome , unit , time , first_treat , covariates ,
2138- cluster_var , pre_rel_times , alpha = self .alpha ,
2186+ df_0 ,
2187+ outcome ,
2188+ unit ,
2189+ time ,
2190+ first_treat ,
2191+ covariates ,
2192+ cluster_var ,
2193+ pre_rel_times ,
2194+ alpha = self .alpha ,
21392195 )
21402196
21412197 n_leads_actual = len (pre_rel_times )
@@ -2249,9 +2305,9 @@ def imputation_did(
22492305 Balance event study to cohorts observed at all relative times.
22502306 survey_design : SurveyDesign, optional
22512307 Survey design specification for design-based inference. Supports
2252- pweight only (aweight/fweight raise ValueError). FPC raises
2253- NotImplementedError. PSU is used as cluster variable for Theorem 3
2254- variance. Strata enters survey df for t-distribution inference.
2308+ pweight only (aweight/fweight raise ValueError). Supports strata,
2309+ PSU, and FPC for design-based variance. Strata enters survey df
2310+ for t-distribution inference.
22552311 Both analytical (n_bootstrap=0) and bootstrap inference are supported.
22562312 **kwargs
22572313 Additional keyword arguments passed to ImputationDiD constructor.
0 commit comments