@@ -307,6 +307,79 @@ cohort_data = aggregate_to_cohorts(
307307# Result: mean outcome by treatment group and period
308308```
309309
310+ ### Rank Control Units
311+
312+ Select the best control units for a DiD or Synthetic DiD analysis by ranking them on how closely their pre-treatment outcomes resemble those of the treated units:
313+
314+ ``` python
315+ from diff_diff import rank_control_units, generate_did_data
316+
317+ # Generate sample data
318+ data = generate_did_data(n_units=50, n_periods=6, seed=42)
319+
320+ # Rank control units by their similarity to treated units
321+ ranking = rank_control_units(
322+     data,
323+     unit_column='unit',
324+     time_column='period',
325+     outcome_column='outcome',
326+     treatment_column='treated',
327+     n_top=10  # Return top 10 controls
328+ )
329+
330+ print(ranking[['unit', 'quality_score', 'pre_trend_rmse']])
331+ ```
332+
333+ Output:
334+ ```
335+    unit  quality_score  pre_trend_rmse
336+ 0    35         1.0000          0.4521
337+ 1    42         0.9234          0.5123
338+ 2    28         0.8876          0.5892
339+ ...
340+ ```
341+
342+ To also match on covariates, pass `covariates` along with the relative weights for outcome and covariate similarity:
343+
344+ ``` python
345+ # Add covariate-based matching
346+ ranking = rank_control_units(
347+     data,
348+     unit_column='unit',
349+     time_column='period',
350+     outcome_column='outcome',
351+     treatment_column='treated',
352+     covariates=['size', 'age'],  # Match on these too
353+     outcome_weight=0.7,          # 70% weight on outcome trends
354+     covariate_weight=0.3         # 30% weight on covariate similarity
355+ )
356+ ```
357+
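358+ Filter the data to the treated units plus the ranked controls before fitting SyntheticDiD (pass `n_top` to `rank_control_units` to limit how many controls are kept):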
359+
360+ ``` python
361+ from diff_diff import SyntheticDiD
362+
363+ # Get top control units
364+ top_controls = ranking['unit'].tolist()
365+
366+ # Filter data to treated + top controls
367+ filtered_data = data[
368+     (data['treated'] == 1) | (data['unit'].isin(top_controls))
369+ ]
370+
371+ # Fit SyntheticDiD with selected controls
372+ sdid = SyntheticDiD()
373+ results = sdid.fit(
374+     filtered_data,
375+     outcome='outcome',
376+     treatment='treated',
377+     unit='unit',
378+     time='period',
379+     post_periods=[3, 4, 5]
380+ )
381+ ```
382+
310383## Usage
311384
312385### Basic DiD with Column Names
@@ -1026,6 +1099,31 @@ aggregate_to_cohorts(
10261099)
10271100```
10281101
1102+ #### rank_control_units
1103+
1104+ ``` python
1105+ rank_control_units(
1106+ data, # Panel data in long format
1107+ unit_column, # Unit identifier column
1108+ time_column, # Time period column
1109+ outcome_column, # Outcome variable column
1110+ treatment_column = None , # Treatment indicator column (0/1)
1111+ treated_units = None , # Explicit list of treated unit IDs
1112+ pre_periods = None , # Pre-treatment periods (default: first half)
1113+ covariates = None , # Covariate columns for matching
1114+ outcome_weight = 0.7 , # Weight for outcome trend similarity (0-1)
1115+ covariate_weight = 0.3 , # Weight for covariate distance (0-1)
1116+ exclude_units = None , # Units to exclude from control pool
1117+ require_units = None , # Units that must appear in output
1118+ n_top = None , # Return only top N controls
1119+ suggest_treatment_candidates = False , # Identify treatment candidates
1120+ n_treatment_candidates = 5 , # Number of treatment candidates
1121+ lambda_reg = 0.0 # Regularization for synthetic weights
1122+ )
1123+ ```
1124+
1125+ Returns a DataFrame with columns: `unit`, `quality_score`, `outcome_trend_score`, `covariate_score`, `synthetic_weight`, `pre_trend_rmse`, `is_required`.
1126+
10291127## Requirements
10301128
10311129- Python >= 3.9
0 commit comments