Skip to content

Commit 999da34

Browse files
igerberclaude
andcommitted
Fix not-yet-treated control mask to respect anticipation parameter
The not-yet-treated control group in both ContinuousDiD and CallawaySantAnna used `G > t` instead of `G > t + anticipation`, incorrectly including cohorts in the anticipation window as controls. This matches R's `did::compute.att_gt()` logic. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 5449bbb commit 999da34

5 files changed

Lines changed: 107 additions & 3 deletions

File tree

diff_diff/continuous_did.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -690,7 +690,9 @@ def _compute_dose_response_gt(
690690
control_mask = never_treated_mask
691691
else:
692692
# Not-yet-treated: never-treated + first_treat > t
693-
control_mask = never_treated_mask | ((unit_cohorts > t) & (unit_cohorts != g))
693+
control_mask = never_treated_mask | (
694+
(unit_cohorts > t + self.anticipation) & (unit_cohorts != g)
695+
)
694696
n_control = int(np.sum(control_mask))
695697
if n_control == 0:
696698
warnings.warn(

diff_diff/staggered.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,9 @@ def _compute_att_gt_fast(
491491
else: # not_yet_treated
492492
# Not yet treated at time t: never-treated OR (first_treat > t AND not cohort g)
493493
# Must exclude cohort g since they are the treated group for this ATT(g,t)
494-
control_mask = never_treated_mask | ((unit_cohorts > t) & (unit_cohorts != g))
494+
control_mask = never_treated_mask | (
495+
(unit_cohorts > t + self.anticipation) & (unit_cohorts != g)
496+
)
495497

496498
# Extract outcomes for base and post periods
497499
y_base = outcome_matrix[:, base_col]

docs/methodology/REGISTRY.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,8 @@ The multiplier bootstrap uses random weights w_i with E[w]=0 and Var(w)=1:
341341
- Anticipation: `anticipation` parameter shifts reference period
342342
- Group aggregation includes periods t >= g - anticipation (not just t >= g)
343343
- Both analytical SE and bootstrap SE aggregation respect anticipation
344+
- Not-yet-treated + anticipation: control mask uses `G > t + anticipation`
345+
(not just `G > t`) to exclude cohorts in the anticipation window
344346
- Rank-deficient design matrix (covariate collinearity):
345347
- Detection: Pivoted QR decomposition with tolerance `1e-07` (R's `qr()` default)
346348
- Handling: Warns and drops linearly dependent columns, sets NA for dropped coefficients (R-style, matches `lm()`)
@@ -378,7 +380,7 @@ The multiplier bootstrap uses random weights w_i with E[w]=0 and Var(w)=1:
378380
- Always excludes cohort g from controls when computing ATT(g,t)
379381
- This applies to both pre-treatment (t < g) and post-treatment (t >= g) periods
380382
- For pre-treatment periods: even though cohort g hasn't been treated yet at time t, they are the treated group for this ATT(g,t) and cannot serve as their own controls
381-
- Control mask: `never_treated OR (first_treat > t AND first_treat != g)`
383+
- Control mask: `never_treated OR (first_treat > t + anticipation AND first_treat != g)`
382384

383385
**Reference implementation(s):**
384386
- R: `did::att_gt()` (Callaway & Sant'Anna's official package)
@@ -427,6 +429,10 @@ This is stronger than standard PT because it conditions on specific dose values.
427429
- **All-same dose**: B-spline basis collapses; ACRT(d) = 0 everywhere.
428430
- **Rank deficiency**: When n_treated <= n_basis, cell is skipped.
429431
- **Balanced panel required**: Matches R `contdid` v0.1.0.
432+
- **Anticipation + not-yet-treated**: Control mask uses `G > t + anticipation`
433+
(not just `G > t`) to exclude cohorts in the anticipation window from
434+
not-yet-treated controls. When `anticipation=0` (default), behavior is
435+
unchanged.
430436
- **Boundary knots**: Knots are built once from all treated doses (global, not per-cell) to ensure a common basis across (g,t) cells for aggregation. Evaluation grid is clamped to training-dose boundary knots (`range(dose)`). R's `contdid` v0.1.0 has an inconsistency where `splines2::bSpline(dvals)` uses `range(dvals)` instead of `range(dose)`, which can produce extrapolation artifacts at dose grid extremes. Our approach avoids extrapolation and is methodologically sound.
431437

432438
### Implementation Checklist

tests/test_continuous_did.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,66 @@ def test_anticipation_event_study(self):
801801
)
802802
assert np.isfinite(results.event_study_effects[-1]["effect"])
803803

804+
def test_anticipation_not_yet_treated_excludes_anticipation_window(self):
805+
"""Not-yet-treated controls must exclude cohorts in the anticipation window.
806+
807+
With anticipation=1 and cohort g=3, computing ATT(g=3, t=4) should use
808+
threshold t + anticipation = 5, so cohort g=5 (unit_cohorts == 5) fails
809+
> 5 and is correctly excluded. Without the fix, threshold is t=4 and
810+
cohort g=5 passes > 4, contaminating controls with treated units.
811+
"""
812+
rng = np.random.default_rng(42)
813+
n_per_group = 20
814+
periods = [1, 2, 3, 4, 5, 6]
815+
816+
rows = []
817+
# Never-treated group
818+
for i in range(n_per_group):
819+
uid = i
820+
for t in periods:
821+
rows.append({
822+
"unit": uid, "period": t, "first_treat": 0,
823+
"dose": 0.0, "outcome": rng.normal(0, 0.5),
824+
})
825+
826+
# Early cohort: g=3, treatment effect = +5*dose at t>=3
827+
for i in range(n_per_group):
828+
uid = n_per_group + i
829+
d = rng.uniform(1, 3)
830+
for t in periods:
831+
y = rng.normal(0, 0.5) + (5.0 * d if t >= 3 else 0)
832+
rows.append({
833+
"unit": uid, "period": t, "first_treat": 3,
834+
"dose": d, "outcome": y,
835+
})
836+
837+
# Late cohort: g=5, treatment effect = +5*dose at t>=5
838+
for i in range(n_per_group):
839+
uid = 2 * n_per_group + i
840+
d = rng.uniform(1, 3)
841+
for t in periods:
842+
y = rng.normal(0, 0.5) + (5.0 * d if t >= 5 else 0)
843+
rows.append({
844+
"unit": uid, "period": t, "first_treat": 5,
845+
"dose": d, "outcome": y,
846+
})
847+
848+
data = pd.DataFrame(rows)
849+
850+
est = ContinuousDiD(
851+
anticipation=1, control_group="not_yet_treated", n_bootstrap=0,
852+
)
853+
results = est.fit(
854+
data, "outcome", "unit", "period", "first_treat", "dose",
855+
)
856+
857+
assert np.isfinite(results.overall_att), (
858+
"overall_att should be finite with anticipation + not_yet_treated"
859+
)
860+
assert results.dose_response_att is not None, (
861+
"dose-response curve should exist"
862+
)
863+
804864

805865
class TestEmptyPostTreatment:
806866
"""Test guard for empty post-treatment cells."""

tests/test_staggered.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2692,6 +2692,40 @@ def test_group_effects_anticipation_boundary(self):
26922692
post_treatment = [t for (gg, t) in gt_for_group if t >= g]
26932693
assert len(post_treatment) > 0, "Should have post-treatment effects"
26942694

2695+
def test_not_yet_treated_with_anticipation_excludes_anticipation_window(self):
2696+
"""Not-yet-treated controls must exclude cohorts in the anticipation window.
2697+
2698+
With anticipation=1, the control mask should use G > t + anticipation
2699+
(not just G > t). Without the fix, cohorts about to be treated are
2700+
incorrectly included as controls, biasing pre-treatment ATTs toward
2701+
the treatment effect (~3.0) instead of near zero.
2702+
"""
2703+
data = generate_staggered_data(
2704+
n_units=100, n_periods=10, n_cohorts=2,
2705+
treatment_effect=3.0, seed=42,
2706+
)
2707+
2708+
cs = CallawaySantAnna(anticipation=1, control_group="not_yet_treated")
2709+
result = cs.fit(
2710+
data, outcome="outcome", unit="unit",
2711+
time="time", first_treat="first_treat",
2712+
)
2713+
2714+
groups = sorted(
2715+
g for g in data[data["first_treat"] > 0]["first_treat"].unique()
2716+
)
2717+
2718+
for g in groups:
2719+
for (gg, t), eff in result.group_time_effects.items():
2720+
if gg != g:
2721+
continue
2722+
# Pre-treatment: t < g - anticipation
2723+
if t < g - 1:
2724+
assert abs(eff["effect"]) < 1.5, (
2725+
f"Pre-treatment ATT(g={g}, t={t}) = {eff['effect']:.3f} "
2726+
f"should be near zero (< 1.5); contaminated controls?"
2727+
)
2728+
26952729

26962730
class TestCallawaySantAnnaTStatNaN:
26972731
"""Tests for NaN t_stat when SE is invalid."""

0 commit comments

Comments
 (0)