@@ -248,3 +248,334 @@ def compute_trend(group_data):
248248 "p_value" : p_value ,
249249 "parallel_trends_plausible" : p_value > 0.05 if not np .isnan (p_value ) else None ,
250250 }
251+
252+
253+ def check_parallel_trends_robust (
254+ data : pd .DataFrame ,
255+ outcome : str ,
256+ time : str ,
257+ treatment_group : str ,
258+ unit : str = None ,
259+ pre_periods : list = None ,
260+ n_permutations : int = 1000 ,
261+ seed : int = None
262+ ) -> dict :
263+ """
264+ Perform robust parallel trends testing using distributional comparisons.
265+
266+ Uses the Wasserstein (Earth Mover's) distance to compare the full
267+ distribution of outcome changes between treated and control groups,
268+ with permutation-based inference.
269+
270+ Parameters
271+ ----------
272+ data : pd.DataFrame
273+ Panel data with repeated observations over time.
274+ outcome : str
275+ Name of outcome variable column.
276+ time : str
277+ Name of time period column.
278+ treatment_group : str
279+ Name of treatment group indicator column (0/1).
280+ unit : str, optional
281+ Name of unit identifier column. If provided, computes unit-level
282+ changes. Otherwise uses observation-level data.
283+ pre_periods : list, optional
284+ List of pre-treatment time periods. If None, uses first half of periods.
285+ n_permutations : int, default=1000
286+ Number of permutations for computing p-value.
287+ seed : int, optional
288+ Random seed for reproducibility.
289+
290+ Returns
291+ -------
292+ dict
293+ Dictionary containing:
294+ - wasserstein_distance: Wasserstein distance between group distributions
295+ - wasserstein_p_value: Permutation-based p-value
296+ - ks_statistic: Kolmogorov-Smirnov test statistic
297+ - ks_p_value: KS test p-value
298+ - mean_difference: Difference in mean changes
299+ - variance_ratio: Ratio of variances in changes
300+ - treated_changes: Array of outcome changes for treated
301+ - control_changes: Array of outcome changes for control
302+ - parallel_trends_plausible: Boolean assessment
303+
304+ Examples
305+ --------
306+ >>> results = check_parallel_trends_robust(
307+ ... data, outcome='sales', time='year',
308+ ... treatment_group='treated', unit='firm_id'
309+ ... )
310+ >>> print(f"Wasserstein distance: {results['wasserstein_distance']:.4f}")
311+ >>> print(f"P-value: {results['wasserstein_p_value']:.4f}")
312+
313+ Notes
314+ -----
315+ The Wasserstein distance (Earth Mover's Distance) measures the minimum
316+ "cost" of transforming one distribution into another. Unlike simple
317+ mean comparisons, it captures differences in the entire distribution
318+ shape, making it more robust to non-normal data and heterogeneous effects.
319+
320+ A small Wasserstein distance and high p-value suggest the distributions
321+ of pre-treatment changes are similar, supporting the parallel trends
322+ assumption.
323+ """
324+ if seed is not None :
325+ np .random .seed (seed )
326+
327+ # Identify pre-treatment periods
328+ if pre_periods is None :
329+ all_periods = sorted (data [time ].unique ())
330+ mid_point = len (all_periods ) // 2
331+ pre_periods = all_periods [:mid_point ]
332+
333+ pre_data = data [data [time ].isin (pre_periods )].copy ()
334+
335+ # Compute outcome changes
336+ treated_changes , control_changes = _compute_outcome_changes (
337+ pre_data , outcome , time , treatment_group , unit
338+ )
339+
340+ if len (treated_changes ) < 2 or len (control_changes ) < 2 :
341+ return {
342+ "wasserstein_distance" : np .nan ,
343+ "wasserstein_p_value" : np .nan ,
344+ "ks_statistic" : np .nan ,
345+ "ks_p_value" : np .nan ,
346+ "mean_difference" : np .nan ,
347+ "variance_ratio" : np .nan ,
348+ "treated_changes" : treated_changes ,
349+ "control_changes" : control_changes ,
350+ "parallel_trends_plausible" : None ,
351+ "error" : "Insufficient data for comparison" ,
352+ }
353+
354+ # Compute Wasserstein distance
355+ wasserstein_dist = stats .wasserstein_distance (treated_changes , control_changes )
356+
357+ # Permutation test for Wasserstein distance
358+ all_changes = np .concatenate ([treated_changes , control_changes ])
359+ n_treated = len (treated_changes )
360+ n_total = len (all_changes )
361+
362+ permuted_distances = np .zeros (n_permutations )
363+ for i in range (n_permutations ):
364+ perm_idx = np .random .permutation (n_total )
365+ perm_treated = all_changes [perm_idx [:n_treated ]]
366+ perm_control = all_changes [perm_idx [n_treated :]]
367+ permuted_distances [i ] = stats .wasserstein_distance (perm_treated , perm_control )
368+
369+ # P-value: proportion of permuted distances >= observed
370+ wasserstein_p = np .mean (permuted_distances >= wasserstein_dist )
371+
372+ # Kolmogorov-Smirnov test
373+ ks_stat , ks_p = stats .ks_2samp (treated_changes , control_changes )
374+
375+ # Additional summary statistics
376+ mean_diff = np .mean (treated_changes ) - np .mean (control_changes )
377+ var_treated = np .var (treated_changes , ddof = 1 )
378+ var_control = np .var (control_changes , ddof = 1 )
379+ var_ratio = var_treated / var_control if var_control > 0 else np .nan
380+
381+ # Normalized Wasserstein (relative to pooled std)
382+ pooled_std = np .std (all_changes , ddof = 1 )
383+ wasserstein_normalized = wasserstein_dist / pooled_std if pooled_std > 0 else np .nan
384+
385+ # Assessment: parallel trends plausible if p-value > 0.05
386+ # and normalized Wasserstein is small (< 0.2 as rule of thumb)
387+ plausible = bool (
388+ wasserstein_p > 0.05 and
389+ (wasserstein_normalized < 0.2 if not np .isnan (wasserstein_normalized ) else True )
390+ )
391+
392+ return {
393+ "wasserstein_distance" : wasserstein_dist ,
394+ "wasserstein_normalized" : wasserstein_normalized ,
395+ "wasserstein_p_value" : wasserstein_p ,
396+ "ks_statistic" : ks_stat ,
397+ "ks_p_value" : ks_p ,
398+ "mean_difference" : mean_diff ,
399+ "variance_ratio" : var_ratio ,
400+ "n_treated" : len (treated_changes ),
401+ "n_control" : len (control_changes ),
402+ "treated_changes" : treated_changes ,
403+ "control_changes" : control_changes ,
404+ "parallel_trends_plausible" : plausible ,
405+ }
406+
407+
408+ def _compute_outcome_changes (
409+ data : pd .DataFrame ,
410+ outcome : str ,
411+ time : str ,
412+ treatment_group : str ,
413+ unit : str = None
414+ ) -> tuple :
415+ """
416+ Compute period-to-period outcome changes for treated and control groups.
417+
418+ Parameters
419+ ----------
420+ data : pd.DataFrame
421+ Panel data.
422+ outcome : str
423+ Outcome variable column.
424+ time : str
425+ Time period column.
426+ treatment_group : str
427+ Treatment group indicator column.
428+ unit : str, optional
429+ Unit identifier column.
430+
431+ Returns
432+ -------
433+ tuple
434+ (treated_changes, control_changes) as numpy arrays.
435+ """
436+ if unit is not None :
437+ # Unit-level changes: compute change for each unit across periods
438+ data_sorted = data .sort_values ([unit , time ])
439+ data_sorted ["_outcome_change" ] = data_sorted .groupby (unit )[outcome ].diff ()
440+
441+ # Remove NaN from first period of each unit
442+ changes_data = data_sorted .dropna (subset = ["_outcome_change" ])
443+
444+ treated_changes = changes_data [
445+ changes_data [treatment_group ] == 1
446+ ]["_outcome_change" ].values
447+
448+ control_changes = changes_data [
449+ changes_data [treatment_group ] == 0
450+ ]["_outcome_change" ].values
451+ else :
452+ # Aggregate changes: compute mean change per period per group
453+ periods = sorted (data [time ].unique ())
454+
455+ treated_data = data [data [treatment_group ] == 1 ]
456+ control_data = data [data [treatment_group ] == 0 ]
457+
458+ # Compute period means
459+ treated_means = treated_data .groupby (time )[outcome ].mean ()
460+ control_means = control_data .groupby (time )[outcome ].mean ()
461+
462+ # Compute changes between consecutive periods
463+ treated_changes = np .diff (treated_means .values )
464+ control_changes = np .diff (control_means .values )
465+
466+ return treated_changes .astype (float ), control_changes .astype (float )
467+
468+
469+ def equivalence_test_trends (
470+ data : pd .DataFrame ,
471+ outcome : str ,
472+ time : str ,
473+ treatment_group : str ,
474+ unit : str = None ,
475+ pre_periods : list = None ,
476+ equivalence_margin : float = None
477+ ) -> dict :
478+ """
479+ Perform equivalence testing (TOST) for parallel trends.
480+
481+ Tests whether the difference in trends is practically equivalent to zero
482+ using Two One-Sided Tests (TOST) procedure.
483+
484+ Parameters
485+ ----------
486+ data : pd.DataFrame
487+ Panel data.
488+ outcome : str
489+ Name of outcome variable column.
490+ time : str
491+ Name of time period column.
492+ treatment_group : str
493+ Name of treatment group indicator column.
494+ unit : str, optional
495+ Name of unit identifier column.
496+ pre_periods : list, optional
497+ List of pre-treatment time periods.
498+ equivalence_margin : float, optional
499+ The margin for equivalence (delta). If None, uses 0.5 * pooled SD
500+ of outcome changes as a default.
501+
502+ Returns
503+ -------
504+ dict
505+ Dictionary containing:
506+ - mean_difference: Difference in mean changes
507+ - equivalence_margin: The margin used
508+ - lower_p_value: P-value for lower bound test
509+ - upper_p_value: P-value for upper bound test
510+ - tost_p_value: Maximum of the two p-values
511+ - equivalent: Boolean indicating equivalence at alpha=0.05
512+ """
513+ # Get pre-treatment periods
514+ if pre_periods is None :
515+ all_periods = sorted (data [time ].unique ())
516+ mid_point = len (all_periods ) // 2
517+ pre_periods = all_periods [:mid_point ]
518+
519+ pre_data = data [data [time ].isin (pre_periods )].copy ()
520+
521+ # Compute outcome changes
522+ treated_changes , control_changes = _compute_outcome_changes (
523+ pre_data , outcome , time , treatment_group , unit
524+ )
525+
526+ if len (treated_changes ) < 2 or len (control_changes ) < 2 :
527+ return {
528+ "mean_difference" : np .nan ,
529+ "equivalence_margin" : np .nan ,
530+ "lower_p_value" : np .nan ,
531+ "upper_p_value" : np .nan ,
532+ "tost_p_value" : np .nan ,
533+ "equivalent" : None ,
534+ "error" : "Insufficient data" ,
535+ }
536+
537+ # Compute statistics
538+ mean_diff = np .mean (treated_changes ) - np .mean (control_changes )
539+ se_diff = np .sqrt (
540+ np .var (treated_changes , ddof = 1 ) / len (treated_changes ) +
541+ np .var (control_changes , ddof = 1 ) / len (control_changes )
542+ )
543+
544+ # Set equivalence margin if not provided
545+ if equivalence_margin is None :
546+ pooled_changes = np .concatenate ([treated_changes , control_changes ])
547+ equivalence_margin = 0.5 * np .std (pooled_changes , ddof = 1 )
548+
549+ # Degrees of freedom (Welch-Satterthwaite approximation)
550+ var_t = np .var (treated_changes , ddof = 1 )
551+ var_c = np .var (control_changes , ddof = 1 )
552+ n_t = len (treated_changes )
553+ n_c = len (control_changes )
554+
555+ df = ((var_t / n_t + var_c / n_c )** 2 /
556+ ((var_t / n_t )** 2 / (n_t - 1 ) + (var_c / n_c )** 2 / (n_c - 1 )))
557+
558+ # TOST: Two one-sided tests
559+ # Test 1: H0: diff <= -margin vs H1: diff > -margin
560+ t_lower = (mean_diff - (- equivalence_margin )) / se_diff
561+ p_lower = stats .t .sf (t_lower , df )
562+
563+ # Test 2: H0: diff >= margin vs H1: diff < margin
564+ t_upper = (mean_diff - equivalence_margin ) / se_diff
565+ p_upper = stats .t .cdf (t_upper , df )
566+
567+ # TOST p-value is the maximum of the two
568+ tost_p = max (p_lower , p_upper )
569+
570+ return {
571+ "mean_difference" : mean_diff ,
572+ "se_difference" : se_diff ,
573+ "equivalence_margin" : equivalence_margin ,
574+ "lower_t_stat" : t_lower ,
575+ "upper_t_stat" : t_upper ,
576+ "lower_p_value" : p_lower ,
577+ "upper_p_value" : p_upper ,
578+ "tost_p_value" : tost_p ,
579+ "degrees_of_freedom" : df ,
580+ "equivalent" : bool (tost_p < 0.05 ),
581+ }
0 commit comments