Commit 2889a84

Add robust parallel trends testing with Wasserstein distance

- Add check_parallel_trends_robust() using Wasserstein (Earth Mover's) distance for distributional comparison of pre-treatment outcome changes
- Include permutation-based p-value for statistical inference
- Add Kolmogorov-Smirnov test as complementary distributional test
- Add equivalence_test_trends() using TOST procedure
- Compute normalized Wasserstein and variance ratio diagnostics
- Add 9 new tests for robust parallel trends functionality
- Update README with usage examples for all three approaches

The Wasserstein distance is more robust than simple slope comparisons because it captures differences in the full distribution shape, not just means, making it better suited for heterogeneous effects.

1 parent e81d95b commit 2889a84
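
The rationale in the commit message can be illustrated with a small sketch that is not part of the commit. It builds two synthetic samples with identical means but different spreads: a comparison of means sees no difference, while `scipy.stats.wasserstein_distance` reports a clearly nonzero distance. The arrays `narrow` and `wide` are illustrative assumptions, not data from the repository.

```python
import numpy as np
from scipy import stats

# Two hypothetical samples of pre-treatment outcome changes with the same
# mean (zero) but different spread.
narrow = np.linspace(-1.0, 1.0, 201)
wide = np.linspace(-3.0, 3.0, 201)

# A mean-based comparison cannot tell the two samples apart.
mean_gap = np.mean(wide) - np.mean(narrow)

# The Wasserstein distance compares the full empirical distributions.
w_dist = stats.wasserstein_distance(narrow, wide)

print(f"Difference in means:  {abs(mean_gap):.4f}")
print(f"Wasserstein distance: {w_dist:.4f}")
```

A slope or mean comparison would treat these two groups as equivalent, which is the gap the distributional test in this commit is meant to close.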

3 files changed

Lines changed: 614 additions & 1 deletion

README.md

Lines changed: 47 additions & 1 deletion
@@ -235,6 +235,8 @@ results.r_squared

 ### Parallel Trends

+**Simple slope-based test:**
+
 ```python
 from diff_diff.utils import check_parallel_trends

@@ -248,7 +250,51 @@ trends = check_parallel_trends(
 print(f"Treated trend: {trends['treated_trend']:.4f}")
 print(f"Control trend: {trends['control_trend']:.4f}")
 print(f"Difference p-value: {trends['p_value']:.4f}")
-print(f"Parallel trends plausible: {trends['parallel_trends_plausible']}")
+```
+
+**Robust distributional test (Wasserstein distance):**
+
+```python
+from diff_diff.utils import check_parallel_trends_robust
+
+results = check_parallel_trends_robust(
+    data,
+    outcome='outcome',
+    time='period',
+    treatment_group='treated',
+    unit='firm_id',             # Unit identifier for panel data
+    pre_periods=[2018, 2019],   # Pre-treatment periods
+    n_permutations=1000         # Permutations for p-value
+)
+
+print(f"Wasserstein distance: {results['wasserstein_distance']:.4f}")
+print(f"Wasserstein p-value: {results['wasserstein_p_value']:.4f}")
+print(f"KS test p-value: {results['ks_p_value']:.4f}")
+print(f"Parallel trends plausible: {results['parallel_trends_plausible']}")
+```
+
+The Wasserstein (Earth Mover's) distance compares the full distribution of outcome changes, not just means. This is more robust to:
+- Non-normal distributions
+- Heterogeneous effects across units
+- Outliers
+
+**Equivalence testing (TOST):**
+
+```python
+from diff_diff.utils import equivalence_test_trends
+
+results = equivalence_test_trends(
+    data,
+    outcome='outcome',
+    time='period',
+    treatment_group='treated',
+    unit='firm_id',
+    equivalence_margin=0.5  # Define "practically equivalent"
+)
+
+print(f"Mean difference: {results['mean_difference']:.4f}")
+print(f"TOST p-value: {results['tost_p_value']:.4f}")
+print(f"Trends equivalent: {results['equivalent']}")
 ```

 ## API Reference

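To make the README's equivalence-testing example more concrete, here is a minimal standalone sketch of a TOST on two samples using Welch t-tests. It follows the same steps as `equivalence_test_trends()` in this commit but is not the library's code: the helper name `tost_welch` and the synthetic `treated` and `control` arrays are assumptions made for illustration.

```python
import numpy as np
from scipy import stats


def tost_welch(a, b, margin, alpha=0.05):
    """Two one-sided Welch t-tests for |mean(a) - mean(b)| < margin."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    diff = a.mean() - b.mean()
    va = a.var(ddof=1) / len(a)
    vb = b.var(ddof=1) / len(b)
    se = np.sqrt(va + vb)
    # Welch-Satterthwaite degrees of freedom
    df = (va + vb) ** 2 / (va**2 / (len(a) - 1) + vb**2 / (len(b) - 1))
    p_lower = stats.t.sf((diff + margin) / se, df)   # H0: diff <= -margin
    p_upper = stats.t.cdf((diff - margin) / se, df)  # H0: diff >= +margin
    p_value = max(p_lower, p_upper)
    return diff, p_value, bool(p_value < alpha)


# Hypothetical pre-treatment changes: control is shifted by 0.05.
treated = np.linspace(-1.0, 1.0, 101)
control = treated + 0.05

_, p_wide, eq_wide = tost_welch(treated, control, margin=0.5)
_, p_tight, eq_tight = tost_welch(treated, control, margin=0.01)

print(f"margin=0.50: p={p_wide:.4f}, equivalent={eq_wide}")
print(f"margin=0.01: p={p_tight:.4f}, equivalent={eq_tight}")
```

The same data are declared equivalent under a wide margin and not equivalent under a tight one, so the conclusion depends on how `equivalence_margin` is chosen. That choice should come from subject-matter knowledge about what counts as a negligible difference in trends.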
diff_diff/utils.py

Lines changed: 331 additions & 0 deletions
@@ -248,3 +248,334 @@ def compute_trend(group_data):
         "p_value": p_value,
         "parallel_trends_plausible": p_value > 0.05 if not np.isnan(p_value) else None,
     }
+
+
+def check_parallel_trends_robust(
+    data: pd.DataFrame,
+    outcome: str,
+    time: str,
+    treatment_group: str,
+    unit: str = None,
+    pre_periods: list = None,
+    n_permutations: int = 1000,
+    seed: int = None
+) -> dict:
+    """
+    Perform robust parallel trends testing using distributional comparisons.
+
+    Uses the Wasserstein (Earth Mover's) distance to compare the full
+    distribution of outcome changes between treated and control groups,
+    with permutation-based inference.
+
+    Parameters
+    ----------
+    data : pd.DataFrame
+        Panel data with repeated observations over time.
+    outcome : str
+        Name of outcome variable column.
+    time : str
+        Name of time period column.
+    treatment_group : str
+        Name of treatment group indicator column (0/1).
+    unit : str, optional
+        Name of unit identifier column. If provided, computes unit-level
+        changes. Otherwise uses observation-level data.
+    pre_periods : list, optional
+        List of pre-treatment time periods. If None, uses first half of periods.
+    n_permutations : int, default=1000
+        Number of permutations for computing p-value.
+    seed : int, optional
+        Random seed for reproducibility.
+
+    Returns
+    -------
+    dict
+        Dictionary containing:
+        - wasserstein_distance: Wasserstein distance between group distributions
+        - wasserstein_p_value: Permutation-based p-value
+        - ks_statistic: Kolmogorov-Smirnov test statistic
+        - ks_p_value: KS test p-value
+        - mean_difference: Difference in mean changes
+        - variance_ratio: Ratio of variances in changes
+        - treated_changes: Array of outcome changes for treated
+        - control_changes: Array of outcome changes for control
+        - parallel_trends_plausible: Boolean assessment
+
+    Examples
+    --------
+    >>> results = check_parallel_trends_robust(
+    ...     data, outcome='sales', time='year',
+    ...     treatment_group='treated', unit='firm_id'
+    ... )
+    >>> print(f"Wasserstein distance: {results['wasserstein_distance']:.4f}")
+    >>> print(f"P-value: {results['wasserstein_p_value']:.4f}")
+
+    Notes
+    -----
+    The Wasserstein distance (Earth Mover's Distance) measures the minimum
+    "cost" of transforming one distribution into another. Unlike simple
+    mean comparisons, it captures differences in the entire distribution
+    shape, making it more robust to non-normal data and heterogeneous effects.
+
+    A small Wasserstein distance and high p-value suggest the distributions
+    of pre-treatment changes are similar, supporting the parallel trends
+    assumption.
+    """
+    if seed is not None:
+        np.random.seed(seed)
+
+    # Identify pre-treatment periods
+    if pre_periods is None:
+        all_periods = sorted(data[time].unique())
+        mid_point = len(all_periods) // 2
+        pre_periods = all_periods[:mid_point]
+
+    pre_data = data[data[time].isin(pre_periods)].copy()
+
+    # Compute outcome changes
+    treated_changes, control_changes = _compute_outcome_changes(
+        pre_data, outcome, time, treatment_group, unit
+    )
+
+    if len(treated_changes) < 2 or len(control_changes) < 2:
+        return {
+            "wasserstein_distance": np.nan,
+            "wasserstein_p_value": np.nan,
+            "ks_statistic": np.nan,
+            "ks_p_value": np.nan,
+            "mean_difference": np.nan,
+            "variance_ratio": np.nan,
+            "treated_changes": treated_changes,
+            "control_changes": control_changes,
+            "parallel_trends_plausible": None,
+            "error": "Insufficient data for comparison",
+        }
+
+    # Compute Wasserstein distance
+    wasserstein_dist = stats.wasserstein_distance(treated_changes, control_changes)
+
+    # Permutation test for Wasserstein distance
+    all_changes = np.concatenate([treated_changes, control_changes])
+    n_treated = len(treated_changes)
+    n_total = len(all_changes)
+
+    permuted_distances = np.zeros(n_permutations)
+    for i in range(n_permutations):
+        perm_idx = np.random.permutation(n_total)
+        perm_treated = all_changes[perm_idx[:n_treated]]
+        perm_control = all_changes[perm_idx[n_treated:]]
+        permuted_distances[i] = stats.wasserstein_distance(perm_treated, perm_control)
+
+    # P-value: proportion of permuted distances >= observed
+    wasserstein_p = np.mean(permuted_distances >= wasserstein_dist)
+
+    # Kolmogorov-Smirnov test
+    ks_stat, ks_p = stats.ks_2samp(treated_changes, control_changes)
+
+    # Additional summary statistics
+    mean_diff = np.mean(treated_changes) - np.mean(control_changes)
+    var_treated = np.var(treated_changes, ddof=1)
+    var_control = np.var(control_changes, ddof=1)
+    var_ratio = var_treated / var_control if var_control > 0 else np.nan
+
+    # Normalized Wasserstein (relative to pooled std)
+    pooled_std = np.std(all_changes, ddof=1)
+    wasserstein_normalized = wasserstein_dist / pooled_std if pooled_std > 0 else np.nan
+
+    # Assessment: parallel trends plausible if p-value > 0.05
+    # and normalized Wasserstein is small (< 0.2 as rule of thumb)
+    plausible = bool(
+        wasserstein_p > 0.05 and
+        (wasserstein_normalized < 0.2 if not np.isnan(wasserstein_normalized) else True)
+    )
+
+    return {
+        "wasserstein_distance": wasserstein_dist,
+        "wasserstein_normalized": wasserstein_normalized,
+        "wasserstein_p_value": wasserstein_p,
+        "ks_statistic": ks_stat,
+        "ks_p_value": ks_p,
+        "mean_difference": mean_diff,
+        "variance_ratio": var_ratio,
+        "n_treated": len(treated_changes),
+        "n_control": len(control_changes),
+        "treated_changes": treated_changes,
+        "control_changes": control_changes,
+        "parallel_trends_plausible": plausible,
+    }
+
+
+def _compute_outcome_changes(
+    data: pd.DataFrame,
+    outcome: str,
+    time: str,
+    treatment_group: str,
+    unit: str = None
+) -> tuple:
+    """
+    Compute period-to-period outcome changes for treated and control groups.
+
+    Parameters
+    ----------
+    data : pd.DataFrame
+        Panel data.
+    outcome : str
+        Outcome variable column.
+    time : str
+        Time period column.
+    treatment_group : str
+        Treatment group indicator column.
+    unit : str, optional
+        Unit identifier column.
+
+    Returns
+    -------
+    tuple
+        (treated_changes, control_changes) as numpy arrays.
+    """
+    if unit is not None:
+        # Unit-level changes: compute change for each unit across periods
+        data_sorted = data.sort_values([unit, time])
+        data_sorted["_outcome_change"] = data_sorted.groupby(unit)[outcome].diff()
+
+        # Remove NaN from first period of each unit
+        changes_data = data_sorted.dropna(subset=["_outcome_change"])
+
+        treated_changes = changes_data[
+            changes_data[treatment_group] == 1
+        ]["_outcome_change"].values
+
+        control_changes = changes_data[
+            changes_data[treatment_group] == 0
+        ]["_outcome_change"].values
+    else:
+        # Aggregate changes: compute mean change per period per group
+        periods = sorted(data[time].unique())
+
+        treated_data = data[data[treatment_group] == 1]
+        control_data = data[data[treatment_group] == 0]
+
+        # Compute period means
+        treated_means = treated_data.groupby(time)[outcome].mean()
+        control_means = control_data.groupby(time)[outcome].mean()
+
+        # Compute changes between consecutive periods
+        treated_changes = np.diff(treated_means.values)
+        control_changes = np.diff(control_means.values)
+
+    return treated_changes.astype(float), control_changes.astype(float)
+
+
+def equivalence_test_trends(
+    data: pd.DataFrame,
+    outcome: str,
+    time: str,
+    treatment_group: str,
+    unit: str = None,
+    pre_periods: list = None,
+    equivalence_margin: float = None
+) -> dict:
+    """
+    Perform equivalence testing (TOST) for parallel trends.
+
+    Tests whether the difference in trends is practically equivalent to zero
+    using Two One-Sided Tests (TOST) procedure.
+
+    Parameters
+    ----------
+    data : pd.DataFrame
+        Panel data.
+    outcome : str
+        Name of outcome variable column.
+    time : str
+        Name of time period column.
+    treatment_group : str
+        Name of treatment group indicator column.
+    unit : str, optional
+        Name of unit identifier column.
+    pre_periods : list, optional
+        List of pre-treatment time periods.
+    equivalence_margin : float, optional
+        The margin for equivalence (delta). If None, uses 0.5 * pooled SD
+        of outcome changes as a default.
+
+    Returns
+    -------
+    dict
+        Dictionary containing:
+        - mean_difference: Difference in mean changes
+        - equivalence_margin: The margin used
+        - lower_p_value: P-value for lower bound test
+        - upper_p_value: P-value for upper bound test
+        - tost_p_value: Maximum of the two p-values
+        - equivalent: Boolean indicating equivalence at alpha=0.05
+    """
+    # Get pre-treatment periods
+    if pre_periods is None:
+        all_periods = sorted(data[time].unique())
+        mid_point = len(all_periods) // 2
+        pre_periods = all_periods[:mid_point]
+
+    pre_data = data[data[time].isin(pre_periods)].copy()
+
+    # Compute outcome changes
+    treated_changes, control_changes = _compute_outcome_changes(
+        pre_data, outcome, time, treatment_group, unit
+    )
+
+    if len(treated_changes) < 2 or len(control_changes) < 2:
+        return {
+            "mean_difference": np.nan,
+            "equivalence_margin": np.nan,
+            "lower_p_value": np.nan,
+            "upper_p_value": np.nan,
+            "tost_p_value": np.nan,
+            "equivalent": None,
+            "error": "Insufficient data",
+        }
+
+    # Compute statistics
+    mean_diff = np.mean(treated_changes) - np.mean(control_changes)
+    se_diff = np.sqrt(
+        np.var(treated_changes, ddof=1) / len(treated_changes) +
+        np.var(control_changes, ddof=1) / len(control_changes)
+    )
+
+    # Set equivalence margin if not provided
+    if equivalence_margin is None:
+        pooled_changes = np.concatenate([treated_changes, control_changes])
+        equivalence_margin = 0.5 * np.std(pooled_changes, ddof=1)
+
+    # Degrees of freedom (Welch-Satterthwaite approximation)
+    var_t = np.var(treated_changes, ddof=1)
+    var_c = np.var(control_changes, ddof=1)
+    n_t = len(treated_changes)
+    n_c = len(control_changes)
+
+    df = ((var_t/n_t + var_c/n_c)**2 /
+          ((var_t/n_t)**2/(n_t-1) + (var_c/n_c)**2/(n_c-1)))
+
+    # TOST: Two one-sided tests
+    # Test 1: H0: diff <= -margin vs H1: diff > -margin
+    t_lower = (mean_diff - (-equivalence_margin)) / se_diff
+    p_lower = stats.t.sf(t_lower, df)
+
+    # Test 2: H0: diff >= margin vs H1: diff < margin
+    t_upper = (mean_diff - equivalence_margin) / se_diff
+    p_upper = stats.t.cdf(t_upper, df)
+
+    # TOST p-value is the maximum of the two
+    tost_p = max(p_lower, p_upper)
+
+    return {
+        "mean_difference": mean_diff,
+        "se_difference": se_diff,
+        "equivalence_margin": equivalence_margin,
+        "lower_t_stat": t_lower,
+        "upper_t_stat": t_upper,
+        "lower_p_value": p_lower,
+        "upper_p_value": p_upper,
+        "tost_p_value": tost_p,
+        "degrees_of_freedom": df,
+        "equivalent": bool(tost_p < 0.05),
+    }

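The permutation inference in `check_parallel_trends_robust()` can be sketched in isolation. The version below is a variation on the committed loop, not a copy of it: it draws from a local `np.random.default_rng` generator instead of seeding NumPy's global state, and it applies the common add-one correction so the reported p-value is never exactly zero. The helper name `permutation_p_value` and the sample arrays are hypothetical.

```python
import numpy as np
from scipy import stats


def permutation_p_value(x, y, n_permutations=999, seed=None):
    """Permutation p-value for the Wasserstein distance between x and y."""
    rng = np.random.default_rng(seed)  # local generator, no global state
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    observed = stats.wasserstein_distance(x, y)

    pooled = np.concatenate([x, y])
    n_x = len(x)
    exceed = 0
    for _ in range(n_permutations):
        # Reassign group labels at random, keeping the group sizes fixed.
        shuffled = rng.permutation(pooled)
        dist = stats.wasserstein_distance(shuffled[:n_x], shuffled[n_x:])
        exceed += int(dist >= observed)

    # Add-one correction: the observed labelling counts as one permutation.
    return (exceed + 1) / (n_permutations + 1)


base = np.linspace(0.0, 1.0, 50)
p_shifted = permutation_p_value(base, base + 5.0, seed=0)      # clearly different
p_identical = permutation_p_value(base, base.copy(), seed=0)   # identical samples

print(f"Shifted samples:   p = {p_shifted:.4f}")
print(f"Identical samples: p = {p_identical:.4f}")
```

With the add-one correction, the smallest attainable p-value is `1 / (n_permutations + 1)`, which makes the resolution of the test explicit when choosing `n_permutations`.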