Skip to content

Commit 1d3f5a5

Browse files
authored
Merge pull request #202 from igerber/csa-review
Fix CallawaySantAnna propensity score estimation (IRLS)
2 parents 667e82f + aeb6ecc commit 1d3f5a5

9 files changed

Lines changed: 2559 additions & 1906 deletions

File tree

TODO.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ Deferred items from PR reviews that were not addressed before merge.
6161
|-------|----------|----|----------|
6262
| Tutorial notebooks not executed in CI | `docs/tutorials/*.ipynb` | #159 | Low |
6363
| R comparison tests spawn separate `Rscript` per test (slow CI) | `tests/test_methodology_twfe.py:294` | #139 | Low |
64+
| CS R helpers hard-code `xformla = ~ 1`; no covariate-adjusted R benchmark for IRLS path | `tests/test_methodology_callaway.py` | #202 | Low |
6465

6566
---
6667

diff_diff/linalg.py

Lines changed: 217 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def _detect_rank_deficiency(
116116

117117
# Compute pivoted QR decomposition: X @ P = Q @ R
118118
# P is a permutation matrix, represented as pivot indices
119-
Q, R, pivot = qr(X, mode='economic', pivoting=True)
119+
Q, R, pivot = qr(X, mode="economic", pivoting=True)
120120

121121
# Determine rank tolerance
122122
# R's qr() uses tol = 1e-07 by default, which is sqrt(eps) ≈ 1.49e-08
@@ -169,8 +169,7 @@ def _format_dropped_columns(
169169
return ""
170170

171171
if column_names is not None:
172-
names = [column_names[i] if i < len(column_names) else f"column {i}"
173-
for i in dropped_cols]
172+
names = [column_names[i] if i < len(column_names) else f"column {i}" for i in dropped_cols]
174173
if len(names) == 1:
175174
return f"'{names[0]}'"
176175
elif len(names) <= 5:
@@ -251,10 +250,12 @@ def _solve_ols_rust(
251250
cluster_ids: Optional[np.ndarray] = None,
252251
return_vcov: bool = True,
253252
return_fitted: bool = False,
254-
) -> Optional[Union[
255-
Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]],
256-
Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]],
257-
]]:
253+
) -> Optional[
254+
Union[
255+
Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]],
256+
Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]],
257+
]
258+
]:
258259
"""
259260
Rust backend implementation of solve_ols for full-rank matrices.
260261
@@ -447,8 +448,7 @@ def solve_ols(
447448
raise ValueError(f"y must be 1-dimensional, got shape {y.shape}")
448449
if X.shape[0] != y.shape[0]:
449450
raise ValueError(
450-
f"X and y must have same number of observations: "
451-
f"{X.shape[0]} vs {y.shape[0]}"
451+
f"X and y must have same number of observations: " f"{X.shape[0]} vs {y.shape[0]}"
452452
)
453453

454454
n, k = X.shape
@@ -484,7 +484,8 @@ def solve_ols(
484484
if skip_rank_check:
485485
if HAS_RUST_BACKEND and _rust_solve_ols is not None:
486486
result = _solve_ols_rust(
487-
X, y,
487+
X,
488+
y,
488489
cluster_ids=cluster_ids,
489490
return_vcov=return_vcov,
490491
return_fitted=return_fitted,
@@ -494,7 +495,8 @@ def solve_ols(
494495
# Fall through to NumPy on numerical instability
495496
# Fall through to Python without rank check (user guarantees full rank)
496497
return _solve_ols_numpy(
497-
X, y,
498+
X,
499+
y,
498500
cluster_ids=cluster_ids,
499501
return_vcov=return_vcov,
500502
return_fitted=return_fitted,
@@ -521,7 +523,8 @@ def solve_ols(
521523
# - No Rust → Python backend (works for all cases)
522524
if HAS_RUST_BACKEND and _rust_solve_ols is not None and not is_rank_deficient:
523525
result = _solve_ols_rust(
524-
X, y,
526+
X,
527+
y,
525528
cluster_ids=cluster_ids,
526529
return_vcov=return_vcov,
527530
return_fitted=return_fitted,
@@ -531,7 +534,8 @@ def solve_ols(
531534
# signaled us to fall back to Python backend
532535
if result is None:
533536
return _solve_ols_numpy(
534-
X, y,
537+
X,
538+
y,
535539
cluster_ids=cluster_ids,
536540
return_vcov=return_vcov,
537541
return_fitted=return_fitted,
@@ -555,7 +559,8 @@ def solve_ols(
555559
# and SVD disagreed about rank. Python's QR will re-detect and
556560
# apply R-style NaN handling for dropped columns.
557561
return _solve_ols_numpy(
558-
X, y,
562+
X,
563+
y,
559564
cluster_ids=cluster_ids,
560565
return_vcov=return_vcov,
561566
return_fitted=return_fitted,
@@ -569,7 +574,8 @@ def solve_ols(
569574
# Use NumPy implementation for rank-deficient cases (R-style NA handling)
570575
# or when Rust backend is not available
571576
return _solve_ols_numpy(
572-
X, y,
577+
X,
578+
y,
573579
cluster_ids=cluster_ids,
574580
return_vcov=return_vcov,
575581
return_fitted=return_fitted,
@@ -834,9 +840,7 @@ def _compute_robust_vcov_numpy(
834840
n_clusters = len(unique_clusters)
835841

836842
if n_clusters < 2:
837-
raise ValueError(
838-
f"Need at least 2 clusters for cluster-robust SEs, got {n_clusters}"
839-
)
843+
raise ValueError(f"Need at least 2 clusters for cluster-robust SEs, got {n_clusters}")
840844

841845
# Small-sample adjustment
842846
adjustment = (n_clusters / (n_clusters - 1)) * ((n - 1) / (n - k))
@@ -871,6 +875,193 @@ def _compute_robust_vcov_numpy(
871875
return vcov
872876

873877

878+
# Empirical threshold: coefficients above this magnitude suggest near-separation
879+
# in the logistic model (predicted probabilities collapse to 0/1).
880+
_LOGIT_SEPARATION_COEF_THRESHOLD = 10
881+
_LOGIT_SEPARATION_PROB_THRESHOLD = 1e-5
882+
883+
884+
def solve_logit(
885+
X: np.ndarray,
886+
y: np.ndarray,
887+
max_iter: int = 25,
888+
tol: float = 1e-8,
889+
check_separation: bool = True,
890+
rank_deficient_action: str = "warn",
891+
) -> Tuple[np.ndarray, np.ndarray]:
892+
"""
893+
Fit logistic regression via IRLS (Fisher scoring).
894+
895+
Matches R's ``glm(family=binomial)`` algorithm: iteratively reweighted
896+
least squares with working weights ``mu*(1-mu)`` and working response
897+
``eta + (y-mu)/(mu*(1-mu))``.
898+
899+
Parameters
900+
----------
901+
X : np.ndarray
902+
Feature matrix (n_samples, n_features). Intercept added automatically.
903+
y : np.ndarray
904+
Binary outcome (0/1).
905+
max_iter : int, default 25
906+
Maximum IRLS iterations (R's ``glm`` default).
907+
tol : float, default 1e-8
908+
Convergence tolerance on coefficient change (R's ``glm`` default).
909+
check_separation : bool, default True
910+
Whether to check for near-separation and emit warnings.
911+
rank_deficient_action : str, default "warn"
912+
How to handle rank-deficient design matrices:
913+
- "warn": Emit warning and drop columns (default)
914+
- "error": Raise ValueError
915+
- "silent": Drop columns silently
916+
917+
Returns
918+
-------
919+
beta : np.ndarray
920+
Fitted coefficients (including intercept as element 0).
921+
probs : np.ndarray
922+
Predicted probabilities.
923+
"""
924+
n, p = X.shape
925+
X_with_intercept = np.column_stack([np.ones(n), X])
926+
k = p + 1 # number of parameters including intercept
927+
928+
# Validate rank_deficient_action
929+
valid_actions = {"warn", "error", "silent"}
930+
if rank_deficient_action not in valid_actions:
931+
raise ValueError(
932+
f"rank_deficient_action must be one of {valid_actions}, "
933+
f"got '{rank_deficient_action}'"
934+
)
935+
936+
# Check rank deficiency once before iterating
937+
rank_info = _detect_rank_deficiency(X_with_intercept)
938+
rank, dropped_cols, _ = rank_info
939+
if len(dropped_cols) > 0:
940+
col_desc = _format_dropped_columns(dropped_cols)
941+
if rank_deficient_action == "error":
942+
raise ValueError(
943+
f"Rank-deficient design matrix in logistic regression: "
944+
f"dropping {col_desc}. Propensity score estimates may be unreliable."
945+
)
946+
elif rank_deficient_action == "warn":
947+
warnings.warn(
948+
f"Rank-deficient design matrix in logistic regression: "
949+
f"dropping {col_desc}. Propensity score estimates may be unreliable.",
950+
UserWarning,
951+
stacklevel=2,
952+
)
953+
kept_cols = np.array([i for i in range(k) if i not in dropped_cols])
954+
X_solve = X_with_intercept[:, kept_cols]
955+
else:
956+
kept_cols = np.arange(k)
957+
X_solve = X_with_intercept
958+
959+
# IRLS (Fisher scoring)
960+
beta_solve = np.zeros(X_solve.shape[1])
961+
converged = False
962+
963+
for iteration in range(max_iter):
964+
eta = X_solve @ beta_solve
965+
# Clip to prevent overflow in exp
966+
eta = np.clip(eta, -500, 500)
967+
mu = 1.0 / (1.0 + np.exp(-eta))
968+
# Clip mu to prevent zero working weights
969+
mu = np.clip(mu, 1e-10, 1 - 1e-10)
970+
971+
# Working weights and working response
972+
w = mu * (1.0 - mu)
973+
z = eta + (y - mu) / w
974+
975+
# Weighted least squares: solve (X'WX) beta = X'Wz
976+
sqrt_w = np.sqrt(w)
977+
Xw = X_solve * sqrt_w[:, None]
978+
zw = z * sqrt_w
979+
beta_new, _, _, _ = np.linalg.lstsq(Xw, zw, rcond=None)
980+
981+
# Check convergence
982+
if np.max(np.abs(beta_new - beta_solve)) < tol:
983+
beta_solve = beta_new
984+
converged = True
985+
break
986+
beta_solve = beta_new
987+
988+
# Final predicted probabilities
989+
eta_final = X_solve @ beta_solve
990+
eta_final = np.clip(eta_final, -500, 500)
991+
probs = 1.0 / (1.0 + np.exp(-eta_final))
992+
993+
# Warnings
994+
if not converged:
995+
warnings.warn(
996+
f"Logistic regression did not converge in {max_iter} iterations. "
997+
f"Propensity score estimates may be unreliable.",
998+
UserWarning,
999+
stacklevel=2,
1000+
)
1001+
1002+
if check_separation:
1003+
if np.max(np.abs(beta_solve)) > _LOGIT_SEPARATION_COEF_THRESHOLD:
1004+
warnings.warn(
1005+
"Large coefficients detected in propensity score model "
1006+
f"(max|beta| > {_LOGIT_SEPARATION_COEF_THRESHOLD}), "
1007+
"suggesting potential separation.",
1008+
UserWarning,
1009+
stacklevel=2,
1010+
)
1011+
n_extreme = int(
1012+
np.sum(
1013+
(probs < _LOGIT_SEPARATION_PROB_THRESHOLD)
1014+
| (probs > 1 - _LOGIT_SEPARATION_PROB_THRESHOLD)
1015+
)
1016+
)
1017+
if n_extreme > 0:
1018+
warnings.warn(
1019+
f"Near-separation detected in propensity score model: "
1020+
f"{n_extreme} of {n} observations have predicted probabilities "
1021+
f"within {_LOGIT_SEPARATION_PROB_THRESHOLD} of 0 or 1. ATT estimates may be sensitive to "
1022+
f"model specification.",
1023+
UserWarning,
1024+
stacklevel=2,
1025+
)
1026+
1027+
# Expand beta back to full size if columns were dropped
1028+
if len(dropped_cols) > 0:
1029+
beta_full = np.zeros(k)
1030+
beta_full[kept_cols] = beta_solve
1031+
else:
1032+
beta_full = beta_solve
1033+
1034+
return beta_full, probs
1035+
1036+
1037+
def _check_propensity_diagnostics(
1038+
pscore: np.ndarray,
1039+
trim_bound: float = 0.01,
1040+
) -> None:
1041+
"""
1042+
Warn if propensity scores are extreme.
1043+
1044+
Parameters
1045+
----------
1046+
pscore : np.ndarray
1047+
Predicted probabilities.
1048+
trim_bound : float, default 0.01
1049+
Trimming threshold.
1050+
"""
1051+
n_extreme = int(np.sum((pscore < trim_bound) | (pscore > 1 - trim_bound)))
1052+
if n_extreme > 0:
1053+
n_total = len(pscore)
1054+
pct = 100.0 * n_extreme / n_total
1055+
warnings.warn(
1056+
f"Propensity scores for {n_extreme} of {n_total} observations "
1057+
f"({pct:.1f}%) were outside [{trim_bound}, {1 - trim_bound}] "
1058+
f"and will be trimmed. This may indicate near-separation in "
1059+
f"the propensity score model.",
1060+
UserWarning,
1061+
stacklevel=2,
1062+
)
1063+
1064+
8741065
def compute_r_squared(
8751066
y: np.ndarray,
8761067
residuals: np.ndarray,
@@ -1149,7 +1340,8 @@ def fit(
11491340
if self.robust or effective_cluster_ids is not None:
11501341
# Use solve_ols with robust/cluster SEs
11511342
coefficients, residuals, fitted, vcov = solve_ols(
1152-
X, y,
1343+
X,
1344+
y,
11531345
cluster_ids=effective_cluster_ids,
11541346
return_fitted=True,
11551347
return_vcov=compute_vcov,
@@ -1158,7 +1350,8 @@ def fit(
11581350
else:
11591351
# Classical OLS - compute vcov separately
11601352
coefficients, residuals, fitted, _ = solve_ols(
1161-
X, y,
1353+
X,
1354+
y,
11621355
return_fitted=True,
11631356
return_vcov=False,
11641357
rank_deficient_action=self.rank_deficient_action,
@@ -1294,6 +1487,7 @@ def get_inference(
12941487
# Handle zero or negative SE (indicates perfect fit or numerical issues)
12951488
if se <= 0:
12961489
import warnings
1490+
12971491
warnings.warn(
12981492
f"Standard error is zero or negative (se={se}) for coefficient at index {index}. "
12991493
"This may indicate perfect multicollinearity or numerical issues.",
@@ -1319,6 +1513,7 @@ def get_inference(
13191513
# Warn if df is non-positive and fall back to normal distribution
13201514
if effective_df is not None and effective_df <= 0:
13211515
import warnings
1516+
13221517
warnings.warn(
13231518
f"Degrees of freedom is non-positive (df={effective_df}). "
13241519
"Using normal distribution instead of t-distribution for inference.",
@@ -1396,10 +1591,7 @@ def get_all_inference(
13961591
Inference results for each coefficient in order.
13971592
"""
13981593
self._check_fitted()
1399-
return [
1400-
self.get_inference(i, alpha=alpha, df=df)
1401-
for i in range(len(self.coefficients_))
1402-
]
1594+
return [self.get_inference(i, alpha=alpha, df=df) for i in range(len(self.coefficients_))]
14031595

14041596
def r_squared(self, adjusted: bool = False) -> float:
14051597
"""
@@ -1424,9 +1616,7 @@ def r_squared(self, adjusted: bool = False) -> float:
14241616
self._check_fitted()
14251617
# Use effective params for adjusted R² to match df correction
14261618
n_params = self.n_params_effective_ if adjusted else self.n_params_
1427-
return compute_r_squared(
1428-
self._y, self.residuals_, adjusted=adjusted, n_params=n_params
1429-
)
1619+
return compute_r_squared(self._y, self.residuals_, adjusted=adjusted, n_params=n_params)
14301620

14311621
def predict(self, X: np.ndarray) -> np.ndarray:
14321622
"""

0 commit comments

Comments
 (0)