Skip to content

Commit 7c361ee

Browse files
igerberclaude
andcommitted
synthetic-control: address CI codex R5 — surface inner non-convergence during V search (P1)
During the nested V search, _inner_solve_W's convergence flag was discarded on every intermediate evaluation (univariate starts + objective calls), so the outer optimizer could silently rank truncated W*(V) solves if inner solves hit inner_max_iter; only the final re-solve was surfaced. Now _v_starts returns its inner-solve counts and _outer_solve_V tallies intermediate non-convergence across the univariate starts AND every objective evaluation, emitting one aggregated UserWarning when the rate exceeds 5% (mirrors the synthetic_did.py bootstrap-FW aggregation). Healthy fits (converging inner solves) stay silent — Basque Tier-2 unaffected. Regression: test_inner_v_search_nonconvergence_warning. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 3b98d28 commit 7c361ee

2 files changed

Lines changed: 47 additions & 5 deletions

File tree

diff_diff/synthetic_control.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -925,9 +925,13 @@ def _v_starts(
925925
rng: np.random.Generator,
926926
inner_max_iter: int,
927927
inner_min_decrease: float,
928-
) -> List[np.ndarray]:
928+
) -> Tuple[List[np.ndarray], int, int]:
929929
"""Build a list of DISTINCT starting ``theta`` vectors for the outer V search.
930930
931+
Returns ``(candidates, n_inner_solves, n_inner_nonconverged)`` — the latter two
932+
count the inner Frank-Wolfe solves run by the univariate-fit heuristic so the
933+
caller can surface aggregate intermediate non-convergence.
934+
931935
Heuristic starts: uniform V; inverse-row-variance V (computed from the
932936
UNSTANDARDIZED predictors ``X1``/``X0`` — on the standardized rows every variance
933937
is 1 by construction, so it would collapse to the uniform start); univariate-fit V
@@ -976,12 +980,17 @@ def _add_unique(t: Optional[np.ndarray], pool: List[np.ndarray]) -> None:
976980

977981
# univariate-fit start: v_i ∝ 1 / (pre-outcome MSPE of W solved with V=e_i).
978982
# Skipped entirely when enough candidates are already collected (saves k inner solves).
983+
inner_total = 0
984+
inner_nonconv = 0
979985
if len(candidates) < target:
980986
uni_mspe = np.empty(k)
981987
for i in range(k):
982988
e = np.zeros(k)
983989
e[i] = 1.0
984-
w_i, _ = _inner_solve_W(X1s, X0s, e, inner_max_iter, inner_min_decrease)
990+
w_i, conv_i = _inner_solve_W(X1s, X0s, e, inner_max_iter, inner_min_decrease)
991+
inner_total += 1
992+
if not conv_i:
993+
inner_nonconv += 1
985994
uni_mspe[i] = float(np.mean((Z1 - Z0 @ w_i) ** 2))
986995
inv_mspe = np.where(uni_mspe > 0, 1.0 / np.maximum(uni_mspe, 1e-12), 0.0)
987996
if np.sum(inv_mspe) > 0:
@@ -994,7 +1003,7 @@ def _add_unique(t: Optional[np.ndarray], pool: List[np.ndarray]) -> None:
9941003
attempts += 1
9951004
_add_unique(_to_theta(rng.dirichlet(np.ones(k))), candidates)
9961005

997-
return candidates[:target]
1006+
return candidates[:target], inner_total, inner_nonconv
9981007

9991008

10001009
def _outer_solve_V(
@@ -1021,9 +1030,17 @@ def _outer_solve_V(
10211030
w, converged = _inner_solve_W(X1s, X0s, v, inner_max_iter, inner_min_decrease)
10221031
return v, w, converged, float(np.mean((Z1 - Z0 @ w) ** 2))
10231032

1033+
# Track inner Frank-Wolfe non-convergence across ALL intermediate evaluations so
1034+
# the outer search cannot silently rank truncated W*(V) solves (codex). `_inner_solve_W`
1035+
# suppresses its own per-call warning during the search; we aggregate here.
1036+
_st = {"total": 0, "nonconv": 0}
1037+
10241038
def objective(theta: np.ndarray) -> float:
10251039
v = _softmax(theta)
1026-
w, _ = _inner_solve_W(X1s, X0s, v, inner_max_iter, inner_min_decrease)
1040+
w, conv = _inner_solve_W(X1s, X0s, v, inner_max_iter, inner_min_decrease)
1041+
_st["total"] += 1
1042+
if not conv:
1043+
_st["nonconv"] += 1
10271044
return float(np.mean((Z1 - Z0 @ w) ** 2))
10281045

10291046
nm_options = {"maxiter": 1000, "xatol": 1e-8, "fatol": 1e-8}
@@ -1038,9 +1055,11 @@ def objective(theta: np.ndarray) -> float:
10381055
powell_options["ftol"] = powell_options.pop("fatol")
10391056

10401057
rng = np.random.default_rng(seed)
1041-
starts = _v_starts(
1058+
starts, start_total, start_nonconv = _v_starts(
10421059
k, X1, X0, X1s, X0s, Z1, Z0, n_starts, rng, inner_max_iter, inner_min_decrease
10431060
)
1061+
_st["total"] += start_total
1062+
_st["nonconv"] += start_nonconv
10441063

10451064
best_x: np.ndarray = starts[0]
10461065
best_fun = np.inf
@@ -1071,6 +1090,21 @@ def objective(theta: np.ndarray) -> float:
10711090
stacklevel=3,
10721091
)
10731092

1093+
# Aggregate intermediate inner Frank-Wolfe non-convergence across the whole nested
1094+
# search (univariate starts + every objective evaluation). Per-call FW warnings are
1095+
# suppressed during the search, so without this the outer optimizer could silently
1096+
# rank truncated W*(V) solves. Threshold mirrors synthetic_did.py's 5% rule.
1097+
if _st["nonconv"] > 0.05 * max(_st["total"], 1):
1098+
warnings.warn(
1099+
f"Inner Frank-Wolfe did not converge on {_st['nonconv']} of {_st['total']} "
1100+
f"weight solves during nested V selection (inner_max_iter={inner_max_iter}); "
1101+
"the outer search may have ranked truncated W*(V) solutions, so the selected "
1102+
"V / donor weights / ATT may be sub-optimal. Increase inner_max_iter or relax "
1103+
"inner_min_decrease.",
1104+
UserWarning,
1105+
stacklevel=3,
1106+
)
1107+
10741108
v_star = _softmax(best_x)
10751109
w_star, converged = _inner_solve_W(X1s, X0s, v_star, inner_max_iter, inner_min_decrease)
10761110
mspe = float(np.mean((Z1 - Z0 @ w_star) ** 2))

tests/test_methodology_synthetic_control.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,14 @@ def test_outer_v_nonconvergence_warning():
314314
)
315315

316316

317+
def test_inner_v_search_nonconvergence_warning():
318+
# Intermediate inner solves during the nested V search must not be silent: forcing
319+
# inner_max_iter=1 makes them truncate, and the estimator emits an aggregated warning.
320+
df, _, _ = _make_panel()
321+
with pytest.warns(UserWarning, match="during nested V selection"):
322+
synthetic_control(df, "y", "treated", "unit", "year", seed=0, inner_max_iter=1)
323+
324+
317325
def test_n_starts_one_runs():
318326
# n_starts=1 uses only the uniform start (short-circuits the heuristic candidates)
319327
# and still produces a valid nested fit.

0 commit comments

Comments
 (0)