|
19 | 19 |
|
20 | 20 | from diff_diff.results import DiDResults, MultiPeriodDiDResults, PeriodEffect |
21 | 21 | from diff_diff.utils import ( |
| 22 | + WildBootstrapResults, |
22 | 23 | compute_confidence_interval, |
23 | 24 | compute_p_value, |
24 | 25 | compute_robust_se, |
@@ -279,22 +280,9 @@ def fit( |
279 | 280 | if self.inference == "wild_bootstrap" and self.cluster is not None: |
280 | 281 | # Wild cluster bootstrap for few-cluster inference |
281 | 282 | cluster_ids = data[self.cluster].values |
282 | | - bootstrap_results = wild_bootstrap_se( |
283 | | - X, y, residuals, cluster_ids, |
284 | | - coefficient_index=att_idx, |
285 | | - n_bootstrap=self.n_bootstrap, |
286 | | - weight_type=self.bootstrap_weights, |
287 | | - alpha=self.alpha, |
288 | | - seed=self.seed, |
289 | | - return_distribution=False |
| 283 | + se, p_value, conf_int, t_stat, vcov, _ = self._run_wild_bootstrap_inference( |
| 284 | + X, y, residuals, cluster_ids, att_idx |
290 | 285 | ) |
291 | | - self._bootstrap_results = bootstrap_results |
292 | | - se = bootstrap_results.se |
293 | | - p_value = bootstrap_results.p_value |
294 | | - conf_int = (bootstrap_results.ci_lower, bootstrap_results.ci_upper) |
295 | | - t_stat = bootstrap_results.t_stat_original |
296 | | - # Also compute vcov for storage (using cluster-robust for consistency) |
297 | | - vcov = compute_robust_se(X, residuals, cluster_ids) |
298 | 286 | elif self.cluster is not None: |
299 | 287 | cluster_ids = data[self.cluster].values |
300 | 288 | vcov = compute_robust_se(X, residuals, cluster_ids) |
@@ -408,6 +396,56 @@ def _fit_ols(self, X: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray |
408 | 396 |
|
409 | 397 | return coefficients, residuals, fitted, r_squared |
410 | 398 |
|
def _run_wild_bootstrap_inference(
    self,
    X: np.ndarray,
    y: np.ndarray,
    residuals: np.ndarray,
    cluster_ids: np.ndarray,
    coefficient_index: int,
) -> Tuple[float, float, Tuple[float, float], float, np.ndarray, WildBootstrapResults]:
    """
    Perform wild cluster bootstrap inference for one coefficient.

    Delegates to ``wild_bootstrap_se`` using this estimator's
    ``n_bootstrap``, ``bootstrap_weights``, ``alpha`` and ``seed``
    settings, caches the returned results object on
    ``self._bootstrap_results``, and unpacks the quantities the caller
    needs. The bootstrap distribution itself is not requested
    (``return_distribution=False``).

    Parameters
    ----------
    X : np.ndarray
        Design matrix.
    y : np.ndarray
        Outcome vector.
    residuals : np.ndarray
        OLS residuals.
    cluster_ids : np.ndarray
        Cluster identifier for each observation.
    coefficient_index : int
        Index of the coefficient for which inference is computed.

    Returns
    -------
    se : float
        ``se`` attribute of the bootstrap results.
    p_value : float
        ``p_value`` attribute of the bootstrap results.
    conf_int : tuple of float
        ``(ci_lower, ci_upper)`` taken from the bootstrap results.
    t_stat : float
        ``t_stat_original`` attribute of the bootstrap results.
    vcov : np.ndarray
        Output of ``compute_robust_se(X, residuals, cluster_ids)``,
        kept alongside the bootstrap quantities for storage.
    bootstrap_results : WildBootstrapResults
        The full results object (also stored on ``self._bootstrap_results``).
    """
    boot = wild_bootstrap_se(
        X,
        y,
        residuals,
        cluster_ids,
        coefficient_index=coefficient_index,
        n_bootstrap=self.n_bootstrap,
        weight_type=self.bootstrap_weights,
        alpha=self.alpha,
        seed=self.seed,
        return_distribution=False,
    )
    # Cache before anything else so the results stay reachable from the instance.
    self._bootstrap_results = boot

    std_error, pval = boot.se, boot.p_value
    interval = (boot.ci_lower, boot.ci_upper)
    original_t = boot.t_stat_original

    # The bootstrap does not supply a covariance matrix; store the
    # cluster-robust one for consistency with the non-bootstrap path.
    cluster_vcov = compute_robust_se(X, residuals, cluster_ids)

    return std_error, pval, interval, original_t, cluster_vcov, boot
| 448 | + |
411 | 449 | def _parse_formula( |
412 | 450 | self, formula: str, data: pd.DataFrame |
413 | 451 | ) -> Tuple[str, str, str, Optional[List[str]]]: |
|
0 commit comments