Skip to content

Commit d86b108

Browse files
authored
Merge pull request #3 from igerber/claude/init-did-library-pvNmf
2 parents 860f8c8 + 2889a84 commit d86b108

3 files changed

Lines changed: 614 additions & 1 deletion

File tree

README.md

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,8 @@ results.r_squared
235235

236236
### Parallel Trends
237237

238+
**Simple slope-based test:**
239+
238240
```python
239241
from diff_diff.utils import check_parallel_trends
240242

@@ -248,7 +250,51 @@ trends = check_parallel_trends(
248250
print(f"Treated trend: {trends['treated_trend']:.4f}")
249251
print(f"Control trend: {trends['control_trend']:.4f}")
250252
print(f"Difference p-value: {trends['p_value']:.4f}")
251-
print(f"Parallel trends plausible: {trends['parallel_trends_plausible']}")
253+
```
254+
255+
**Robust distributional test (Wasserstein distance):**
256+
257+
```python
258+
from diff_diff.utils import check_parallel_trends_robust
259+
260+
results = check_parallel_trends_robust(
261+
data,
262+
outcome='outcome',
263+
time='period',
264+
treatment_group='treated',
265+
unit='firm_id', # Unit identifier for panel data
266+
pre_periods=[2018, 2019], # Pre-treatment periods
267+
n_permutations=1000 # Permutations for p-value
268+
)
269+
270+
print(f"Wasserstein distance: {results['wasserstein_distance']:.4f}")
271+
print(f"Wasserstein p-value: {results['wasserstein_p_value']:.4f}")
272+
print(f"KS test p-value: {results['ks_p_value']:.4f}")
273+
print(f"Parallel trends plausible: {results['parallel_trends_plausible']}")
274+
```
275+
276+
The Wasserstein (Earth Mover's) distance compares the full distribution of outcome changes, not just means. This is more robust to:
277+
- Non-normal distributions
278+
- Heterogeneous effects across units
279+
- Outliers
280+
281+
**Equivalence testing (TOST):**
282+
283+
```python
284+
from diff_diff.utils import equivalence_test_trends
285+
286+
results = equivalence_test_trends(
287+
data,
288+
outcome='outcome',
289+
time='period',
290+
treatment_group='treated',
291+
unit='firm_id',
292+
equivalence_margin=0.5 # Define "practically equivalent"
293+
)
294+
295+
print(f"Mean difference: {results['mean_difference']:.4f}")
296+
print(f"TOST p-value: {results['tost_p_value']:.4f}")
297+
print(f"Trends equivalent: {results['equivalent']}")
252298
```
253299

254300
## API Reference

diff_diff/utils.py

Lines changed: 331 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,3 +248,334 @@ def compute_trend(group_data):
248248
"p_value": p_value,
249249
"parallel_trends_plausible": p_value > 0.05 if not np.isnan(p_value) else None,
250250
}
251+
252+
253+
def check_parallel_trends_robust(
254+
data: pd.DataFrame,
255+
outcome: str,
256+
time: str,
257+
treatment_group: str,
258+
unit: str = None,
259+
pre_periods: list = None,
260+
n_permutations: int = 1000,
261+
seed: int = None
262+
) -> dict:
263+
"""
264+
Perform robust parallel trends testing using distributional comparisons.
265+
266+
Uses the Wasserstein (Earth Mover's) distance to compare the full
267+
distribution of outcome changes between treated and control groups,
268+
with permutation-based inference.
269+
270+
Parameters
271+
----------
272+
data : pd.DataFrame
273+
Panel data with repeated observations over time.
274+
outcome : str
275+
Name of outcome variable column.
276+
time : str
277+
Name of time period column.
278+
treatment_group : str
279+
Name of treatment group indicator column (0/1).
280+
unit : str, optional
281+
Name of unit identifier column. If provided, computes unit-level
282+
changes. Otherwise uses observation-level data.
283+
pre_periods : list, optional
284+
List of pre-treatment time periods. If None, uses first half of periods.
285+
n_permutations : int, default=1000
286+
Number of permutations for computing p-value.
287+
seed : int, optional
288+
Random seed for reproducibility.
289+
290+
Returns
291+
-------
292+
dict
293+
Dictionary containing:
294+
- wasserstein_distance: Wasserstein distance between group distributions
295+
- wasserstein_p_value: Permutation-based p-value
296+
- ks_statistic: Kolmogorov-Smirnov test statistic
297+
- ks_p_value: KS test p-value
298+
- mean_difference: Difference in mean changes
299+
- variance_ratio: Ratio of variances in changes
300+
- treated_changes: Array of outcome changes for treated
301+
- control_changes: Array of outcome changes for control
302+
- parallel_trends_plausible: Boolean assessment
303+
304+
Examples
305+
--------
306+
>>> results = check_parallel_trends_robust(
307+
... data, outcome='sales', time='year',
308+
... treatment_group='treated', unit='firm_id'
309+
... )
310+
>>> print(f"Wasserstein distance: {results['wasserstein_distance']:.4f}")
311+
>>> print(f"P-value: {results['wasserstein_p_value']:.4f}")
312+
313+
Notes
314+
-----
315+
The Wasserstein distance (Earth Mover's Distance) measures the minimum
316+
"cost" of transforming one distribution into another. Unlike simple
317+
mean comparisons, it captures differences in the entire distribution
318+
shape, making it more robust to non-normal data and heterogeneous effects.
319+
320+
A small Wasserstein distance and high p-value suggest the distributions
321+
of pre-treatment changes are similar, supporting the parallel trends
322+
assumption.
323+
"""
324+
if seed is not None:
325+
np.random.seed(seed)
326+
327+
# Identify pre-treatment periods
328+
if pre_periods is None:
329+
all_periods = sorted(data[time].unique())
330+
mid_point = len(all_periods) // 2
331+
pre_periods = all_periods[:mid_point]
332+
333+
pre_data = data[data[time].isin(pre_periods)].copy()
334+
335+
# Compute outcome changes
336+
treated_changes, control_changes = _compute_outcome_changes(
337+
pre_data, outcome, time, treatment_group, unit
338+
)
339+
340+
if len(treated_changes) < 2 or len(control_changes) < 2:
341+
return {
342+
"wasserstein_distance": np.nan,
343+
"wasserstein_p_value": np.nan,
344+
"ks_statistic": np.nan,
345+
"ks_p_value": np.nan,
346+
"mean_difference": np.nan,
347+
"variance_ratio": np.nan,
348+
"treated_changes": treated_changes,
349+
"control_changes": control_changes,
350+
"parallel_trends_plausible": None,
351+
"error": "Insufficient data for comparison",
352+
}
353+
354+
# Compute Wasserstein distance
355+
wasserstein_dist = stats.wasserstein_distance(treated_changes, control_changes)
356+
357+
# Permutation test for Wasserstein distance
358+
all_changes = np.concatenate([treated_changes, control_changes])
359+
n_treated = len(treated_changes)
360+
n_total = len(all_changes)
361+
362+
permuted_distances = np.zeros(n_permutations)
363+
for i in range(n_permutations):
364+
perm_idx = np.random.permutation(n_total)
365+
perm_treated = all_changes[perm_idx[:n_treated]]
366+
perm_control = all_changes[perm_idx[n_treated:]]
367+
permuted_distances[i] = stats.wasserstein_distance(perm_treated, perm_control)
368+
369+
# P-value: proportion of permuted distances >= observed
370+
wasserstein_p = np.mean(permuted_distances >= wasserstein_dist)
371+
372+
# Kolmogorov-Smirnov test
373+
ks_stat, ks_p = stats.ks_2samp(treated_changes, control_changes)
374+
375+
# Additional summary statistics
376+
mean_diff = np.mean(treated_changes) - np.mean(control_changes)
377+
var_treated = np.var(treated_changes, ddof=1)
378+
var_control = np.var(control_changes, ddof=1)
379+
var_ratio = var_treated / var_control if var_control > 0 else np.nan
380+
381+
# Normalized Wasserstein (relative to pooled std)
382+
pooled_std = np.std(all_changes, ddof=1)
383+
wasserstein_normalized = wasserstein_dist / pooled_std if pooled_std > 0 else np.nan
384+
385+
# Assessment: parallel trends plausible if p-value > 0.05
386+
# and normalized Wasserstein is small (< 0.2 as rule of thumb)
387+
plausible = bool(
388+
wasserstein_p > 0.05 and
389+
(wasserstein_normalized < 0.2 if not np.isnan(wasserstein_normalized) else True)
390+
)
391+
392+
return {
393+
"wasserstein_distance": wasserstein_dist,
394+
"wasserstein_normalized": wasserstein_normalized,
395+
"wasserstein_p_value": wasserstein_p,
396+
"ks_statistic": ks_stat,
397+
"ks_p_value": ks_p,
398+
"mean_difference": mean_diff,
399+
"variance_ratio": var_ratio,
400+
"n_treated": len(treated_changes),
401+
"n_control": len(control_changes),
402+
"treated_changes": treated_changes,
403+
"control_changes": control_changes,
404+
"parallel_trends_plausible": plausible,
405+
}
406+
407+
408+
def _compute_outcome_changes(
409+
data: pd.DataFrame,
410+
outcome: str,
411+
time: str,
412+
treatment_group: str,
413+
unit: str = None
414+
) -> tuple:
415+
"""
416+
Compute period-to-period outcome changes for treated and control groups.
417+
418+
Parameters
419+
----------
420+
data : pd.DataFrame
421+
Panel data.
422+
outcome : str
423+
Outcome variable column.
424+
time : str
425+
Time period column.
426+
treatment_group : str
427+
Treatment group indicator column.
428+
unit : str, optional
429+
Unit identifier column.
430+
431+
Returns
432+
-------
433+
tuple
434+
(treated_changes, control_changes) as numpy arrays.
435+
"""
436+
if unit is not None:
437+
# Unit-level changes: compute change for each unit across periods
438+
data_sorted = data.sort_values([unit, time])
439+
data_sorted["_outcome_change"] = data_sorted.groupby(unit)[outcome].diff()
440+
441+
# Remove NaN from first period of each unit
442+
changes_data = data_sorted.dropna(subset=["_outcome_change"])
443+
444+
treated_changes = changes_data[
445+
changes_data[treatment_group] == 1
446+
]["_outcome_change"].values
447+
448+
control_changes = changes_data[
449+
changes_data[treatment_group] == 0
450+
]["_outcome_change"].values
451+
else:
452+
# Aggregate changes: compute mean change per period per group
453+
periods = sorted(data[time].unique())
454+
455+
treated_data = data[data[treatment_group] == 1]
456+
control_data = data[data[treatment_group] == 0]
457+
458+
# Compute period means
459+
treated_means = treated_data.groupby(time)[outcome].mean()
460+
control_means = control_data.groupby(time)[outcome].mean()
461+
462+
# Compute changes between consecutive periods
463+
treated_changes = np.diff(treated_means.values)
464+
control_changes = np.diff(control_means.values)
465+
466+
return treated_changes.astype(float), control_changes.astype(float)
467+
468+
469+
def equivalence_test_trends(
470+
data: pd.DataFrame,
471+
outcome: str,
472+
time: str,
473+
treatment_group: str,
474+
unit: str = None,
475+
pre_periods: list = None,
476+
equivalence_margin: float = None
477+
) -> dict:
478+
"""
479+
Perform equivalence testing (TOST) for parallel trends.
480+
481+
Tests whether the difference in trends is practically equivalent to zero
482+
using Two One-Sided Tests (TOST) procedure.
483+
484+
Parameters
485+
----------
486+
data : pd.DataFrame
487+
Panel data.
488+
outcome : str
489+
Name of outcome variable column.
490+
time : str
491+
Name of time period column.
492+
treatment_group : str
493+
Name of treatment group indicator column.
494+
unit : str, optional
495+
Name of unit identifier column.
496+
pre_periods : list, optional
497+
List of pre-treatment time periods.
498+
equivalence_margin : float, optional
499+
The margin for equivalence (delta). If None, uses 0.5 * pooled SD
500+
of outcome changes as a default.
501+
502+
Returns
503+
-------
504+
dict
505+
Dictionary containing:
506+
- mean_difference: Difference in mean changes
507+
- equivalence_margin: The margin used
508+
- lower_p_value: P-value for lower bound test
509+
- upper_p_value: P-value for upper bound test
510+
- tost_p_value: Maximum of the two p-values
511+
- equivalent: Boolean indicating equivalence at alpha=0.05
512+
"""
513+
# Get pre-treatment periods
514+
if pre_periods is None:
515+
all_periods = sorted(data[time].unique())
516+
mid_point = len(all_periods) // 2
517+
pre_periods = all_periods[:mid_point]
518+
519+
pre_data = data[data[time].isin(pre_periods)].copy()
520+
521+
# Compute outcome changes
522+
treated_changes, control_changes = _compute_outcome_changes(
523+
pre_data, outcome, time, treatment_group, unit
524+
)
525+
526+
if len(treated_changes) < 2 or len(control_changes) < 2:
527+
return {
528+
"mean_difference": np.nan,
529+
"equivalence_margin": np.nan,
530+
"lower_p_value": np.nan,
531+
"upper_p_value": np.nan,
532+
"tost_p_value": np.nan,
533+
"equivalent": None,
534+
"error": "Insufficient data",
535+
}
536+
537+
# Compute statistics
538+
mean_diff = np.mean(treated_changes) - np.mean(control_changes)
539+
se_diff = np.sqrt(
540+
np.var(treated_changes, ddof=1) / len(treated_changes) +
541+
np.var(control_changes, ddof=1) / len(control_changes)
542+
)
543+
544+
# Set equivalence margin if not provided
545+
if equivalence_margin is None:
546+
pooled_changes = np.concatenate([treated_changes, control_changes])
547+
equivalence_margin = 0.5 * np.std(pooled_changes, ddof=1)
548+
549+
# Degrees of freedom (Welch-Satterthwaite approximation)
550+
var_t = np.var(treated_changes, ddof=1)
551+
var_c = np.var(control_changes, ddof=1)
552+
n_t = len(treated_changes)
553+
n_c = len(control_changes)
554+
555+
df = ((var_t/n_t + var_c/n_c)**2 /
556+
((var_t/n_t)**2/(n_t-1) + (var_c/n_c)**2/(n_c-1)))
557+
558+
# TOST: Two one-sided tests
559+
# Test 1: H0: diff <= -margin vs H1: diff > -margin
560+
t_lower = (mean_diff - (-equivalence_margin)) / se_diff
561+
p_lower = stats.t.sf(t_lower, df)
562+
563+
# Test 2: H0: diff >= margin vs H1: diff < margin
564+
t_upper = (mean_diff - equivalence_margin) / se_diff
565+
p_upper = stats.t.cdf(t_upper, df)
566+
567+
# TOST p-value is the maximum of the two
568+
tost_p = max(p_lower, p_upper)
569+
570+
return {
571+
"mean_difference": mean_diff,
572+
"se_difference": se_diff,
573+
"equivalence_margin": equivalence_margin,
574+
"lower_t_stat": t_lower,
575+
"upper_t_stat": t_upper,
576+
"lower_p_value": p_lower,
577+
"upper_p_value": p_upper,
578+
"tost_p_value": tost_p,
579+
"degrees_of_freedom": df,
580+
"equivalent": bool(tost_p < 0.05),
581+
}

0 commit comments

Comments
 (0)