Skip to content

Commit 3e57eac

Browse files
authored
Merge pull request #177 from igerber/worktree-continuous-did
Add ContinuousDiD estimator (Callaway, Goodman-Bacon & Sant'Anna 2024)
2 parents 27fcc17 + 999da34 commit 3e57eac

19 files changed

Lines changed: 4505 additions & 301 deletions

TODO.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ Deferred items from PR reviews that were not addressed before merge.
4444
| Issue | Location | PR | Priority |
4545
|-------|----------|----|----------|
4646
| ImputationDiD dense `(A0'A0).toarray()` scales O((U+T+K)^2), OOM risk on large panels | `imputation.py` | #141 | Medium (deferred — only triggers when sparse solver fails; fixing requires sparse least-squares alternatives) |
47+
| Bootstrap NaN-gating gap: manual SE/CI/p-value without non-finite filtering or SE<=0 guard | `imputation_bootstrap.py`, `two_stage_bootstrap.py` | #177 | Medium — migrate to `compute_effect_bootstrap_stats` from `bootstrap_utils.py` |
4748

4849
#### Performance
4950

diff_diff/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
aggregate_to_cohorts,
7171
balance_panel,
7272
create_event_time,
73+
generate_continuous_did_data,
7374
generate_did_data,
7475
generate_ddd_data,
7576
generate_event_study_data,
@@ -122,6 +123,11 @@
122123
TripleDifferenceResults,
123124
triple_difference,
124125
)
126+
from diff_diff.continuous_did import (
127+
ContinuousDiD,
128+
ContinuousDiDResults,
129+
DoseResponseCurve,
130+
)
125131
from diff_diff.trop import (
126132
TROP,
127133
TROPResults,
@@ -161,6 +167,7 @@
161167
"MultiPeriodDiD",
162168
"SyntheticDiD",
163169
"CallawaySantAnna",
170+
"ContinuousDiD",
164171
"SunAbraham",
165172
"ImputationDiD",
166173
"TwoStageDiD",
@@ -181,6 +188,8 @@
181188
"CallawaySantAnnaResults",
182189
"CSBootstrapResults",
183190
"GroupTimeEffect",
191+
"ContinuousDiDResults",
192+
"DoseResponseCurve",
184193
"SunAbrahamResults",
185194
"SABootstrapResults",
186195
"ImputationDiDResults",
@@ -228,6 +237,7 @@
228237
"generate_ddd_data",
229238
"generate_panel_data",
230239
"generate_event_study_data",
240+
"generate_continuous_did_data",
231241
"create_event_time",
232242
"aggregate_to_cohorts",
233243
"rank_control_units",

diff_diff/bootstrap_utils.py

Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
"""
2+
Shared bootstrap utilities for multiplier bootstrap inference.
3+
4+
Provides weight generation, percentile CI, and p-value helpers used by
5+
both CallawaySantAnna and ContinuousDiD estimators.
6+
"""
7+
8+
import warnings
9+
from typing import Optional, Tuple
10+
11+
import numpy as np
12+
13+
from diff_diff._backend import HAS_RUST_BACKEND, _rust_bootstrap_weights
14+
15+
__all__ = [
16+
"generate_bootstrap_weights",
17+
"generate_bootstrap_weights_batch",
18+
"generate_bootstrap_weights_batch_numpy",
19+
"compute_percentile_ci",
20+
"compute_bootstrap_pvalue",
21+
"compute_effect_bootstrap_stats",
22+
]
23+
24+
25+
def generate_bootstrap_weights(
26+
n_units: int,
27+
weight_type: str,
28+
rng: np.random.Generator,
29+
) -> np.ndarray:
30+
"""
31+
Generate bootstrap weights for multiplier bootstrap.
32+
33+
Parameters
34+
----------
35+
n_units : int
36+
Number of units (clusters) to generate weights for.
37+
weight_type : str
38+
Type of weights: "rademacher", "mammen", or "webb".
39+
rng : np.random.Generator
40+
Random number generator.
41+
42+
Returns
43+
-------
44+
np.ndarray
45+
Array of bootstrap weights with shape (n_units,).
46+
"""
47+
if weight_type == "rademacher":
48+
return rng.choice([-1.0, 1.0], size=n_units)
49+
elif weight_type == "mammen":
50+
sqrt5 = np.sqrt(5)
51+
val1 = -(sqrt5 - 1) / 2
52+
val2 = (sqrt5 + 1) / 2
53+
p1 = (sqrt5 + 1) / (2 * sqrt5)
54+
return rng.choice([val1, val2], size=n_units, p=[p1, 1 - p1])
55+
elif weight_type == "webb":
56+
values = np.array([
57+
-np.sqrt(3 / 2), -np.sqrt(2 / 2), -np.sqrt(1 / 2),
58+
np.sqrt(1 / 2), np.sqrt(2 / 2), np.sqrt(3 / 2)
59+
])
60+
return rng.choice(values, size=n_units)
61+
else:
62+
raise ValueError(
63+
f"weight_type must be 'rademacher', 'mammen', or 'webb', "
64+
f"got '{weight_type}'"
65+
)
66+
67+
68+
def generate_bootstrap_weights_batch(
69+
n_bootstrap: int,
70+
n_units: int,
71+
weight_type: str,
72+
rng: np.random.Generator,
73+
) -> np.ndarray:
74+
"""
75+
Generate all bootstrap weights at once (vectorized).
76+
77+
Uses Rust backend if available for parallel generation.
78+
79+
Parameters
80+
----------
81+
n_bootstrap : int
82+
Number of bootstrap iterations.
83+
n_units : int
84+
Number of units (clusters) to generate weights for.
85+
weight_type : str
86+
Type of weights: "rademacher", "mammen", or "webb".
87+
rng : np.random.Generator
88+
Random number generator.
89+
90+
Returns
91+
-------
92+
np.ndarray
93+
Array of bootstrap weights with shape (n_bootstrap, n_units).
94+
"""
95+
if HAS_RUST_BACKEND and _rust_bootstrap_weights is not None:
96+
seed = rng.integers(0, 2**63 - 1)
97+
return _rust_bootstrap_weights(n_bootstrap, n_units, weight_type, seed)
98+
return generate_bootstrap_weights_batch_numpy(n_bootstrap, n_units, weight_type, rng)
99+
100+
101+
def generate_bootstrap_weights_batch_numpy(
102+
n_bootstrap: int,
103+
n_units: int,
104+
weight_type: str,
105+
rng: np.random.Generator,
106+
) -> np.ndarray:
107+
"""
108+
NumPy fallback implementation of :func:`generate_bootstrap_weights_batch`.
109+
110+
Parameters
111+
----------
112+
n_bootstrap : int
113+
Number of bootstrap iterations.
114+
n_units : int
115+
Number of units (clusters) to generate weights for.
116+
weight_type : str
117+
Type of weights: "rademacher", "mammen", or "webb".
118+
rng : np.random.Generator
119+
Random number generator.
120+
121+
Returns
122+
-------
123+
np.ndarray
124+
Array of bootstrap weights with shape (n_bootstrap, n_units).
125+
"""
126+
if weight_type == "rademacher":
127+
return rng.choice([-1.0, 1.0], size=(n_bootstrap, n_units))
128+
elif weight_type == "mammen":
129+
sqrt5 = np.sqrt(5)
130+
val1 = -(sqrt5 - 1) / 2
131+
val2 = (sqrt5 + 1) / 2
132+
p1 = (sqrt5 + 1) / (2 * sqrt5)
133+
return rng.choice([val1, val2], size=(n_bootstrap, n_units), p=[p1, 1 - p1])
134+
elif weight_type == "webb":
135+
values = np.array([
136+
-np.sqrt(3 / 2), -np.sqrt(2 / 2), -np.sqrt(1 / 2),
137+
np.sqrt(1 / 2), np.sqrt(2 / 2), np.sqrt(3 / 2)
138+
])
139+
return rng.choice(values, size=(n_bootstrap, n_units))
140+
else:
141+
raise ValueError(
142+
f"weight_type must be 'rademacher', 'mammen', or 'webb', "
143+
f"got '{weight_type}'"
144+
)
145+
146+
147+
def compute_percentile_ci(
148+
boot_dist: np.ndarray,
149+
alpha: float,
150+
) -> Tuple[float, float]:
151+
"""
152+
Compute percentile confidence interval from bootstrap distribution.
153+
154+
Parameters
155+
----------
156+
boot_dist : np.ndarray
157+
Bootstrap distribution (1-D array).
158+
alpha : float
159+
Significance level (e.g., 0.05 for 95% CI).
160+
161+
Returns
162+
-------
163+
tuple of float
164+
``(lower, upper)`` confidence interval bounds.
165+
"""
166+
lower = float(np.percentile(boot_dist, alpha / 2 * 100))
167+
upper = float(np.percentile(boot_dist, (1 - alpha / 2) * 100))
168+
return (lower, upper)
169+
170+
171+
def compute_bootstrap_pvalue(
172+
original_effect: float,
173+
boot_dist: np.ndarray,
174+
n_valid: Optional[int] = None,
175+
) -> float:
176+
"""
177+
Compute two-sided bootstrap p-value using the percentile method.
178+
179+
Parameters
180+
----------
181+
original_effect : float
182+
Original point estimate.
183+
boot_dist : np.ndarray
184+
Bootstrap distribution of the effect.
185+
n_valid : int, optional
186+
Number of valid bootstrap samples for p-value floor.
187+
If None, uses ``len(boot_dist)``.
188+
189+
Returns
190+
-------
191+
float
192+
Two-sided bootstrap p-value.
193+
"""
194+
if original_effect >= 0:
195+
p_one_sided = np.mean(boot_dist <= 0)
196+
else:
197+
p_one_sided = np.mean(boot_dist >= 0)
198+
199+
p_value = min(2 * p_one_sided, 1.0)
200+
n_for_floor = n_valid if n_valid is not None else len(boot_dist)
201+
p_value = max(p_value, 1 / (n_for_floor + 1))
202+
return float(p_value)
203+
204+
205+
def compute_effect_bootstrap_stats(
206+
original_effect: float,
207+
boot_dist: np.ndarray,
208+
alpha: float = 0.05,
209+
context: str = "bootstrap distribution",
210+
) -> Tuple[float, Tuple[float, float], float]:
211+
"""
212+
Compute bootstrap statistics for a single effect.
213+
214+
Filters non-finite samples, returning NaN for all statistics if
215+
fewer than 50% of samples are valid.
216+
217+
Parameters
218+
----------
219+
original_effect : float
220+
Original point estimate.
221+
boot_dist : np.ndarray
222+
Bootstrap distribution of the effect.
223+
alpha : float, default=0.05
224+
Significance level.
225+
context : str, optional
226+
Description for warning messages.
227+
228+
Returns
229+
-------
230+
se : float
231+
Bootstrap standard error.
232+
ci : tuple of float
233+
Percentile confidence interval.
234+
p_value : float
235+
Bootstrap p-value.
236+
"""
237+
if not np.isfinite(original_effect):
238+
return np.nan, (np.nan, np.nan), np.nan
239+
240+
finite_mask = np.isfinite(boot_dist)
241+
n_valid = np.sum(finite_mask)
242+
n_total = len(boot_dist)
243+
244+
if n_valid < n_total:
245+
n_nonfinite = n_total - n_valid
246+
warnings.warn(
247+
f"Dropping {n_nonfinite}/{n_total} non-finite bootstrap samples "
248+
f"in {context}. Bootstrap estimates based on remaining valid samples.",
249+
RuntimeWarning,
250+
stacklevel=3,
251+
)
252+
253+
if n_valid < n_total * 0.5:
254+
warnings.warn(
255+
f"Too few valid bootstrap samples ({n_valid}/{n_total}) in {context}. "
256+
"Returning NaN for SE/CI/p-value to signal invalid inference.",
257+
RuntimeWarning,
258+
stacklevel=3,
259+
)
260+
return np.nan, (np.nan, np.nan), np.nan
261+
262+
valid_dist = boot_dist[finite_mask]
263+
se = float(np.std(valid_dist, ddof=1))
264+
265+
# Guard: if SE is not finite or zero, all inference fields must be NaN.
266+
if not np.isfinite(se) or se <= 0:
267+
warnings.warn(
268+
f"Bootstrap SE is non-finite or zero (n_valid={n_valid}) in {context}. "
269+
"Returning NaN for SE/CI/p-value.",
270+
RuntimeWarning,
271+
stacklevel=3,
272+
)
273+
return np.nan, (np.nan, np.nan), np.nan
274+
275+
ci = compute_percentile_ci(valid_dist, alpha)
276+
p_value = compute_bootstrap_pvalue(
277+
original_effect, valid_dist, n_valid=len(valid_dist)
278+
)
279+
return se, ci, p_value

0 commit comments

Comments
 (0)