Skip to content

Commit 0adb638

Browse files
igerberclaude
andcommitted
synthetic-control: address CI codex R3 — bound _v_starts startup cost (P2)
_v_starts() eagerly computed the inverse-variance and univariate-fit heuristic candidates (the latter = O(k) inner Frank-Wolfe solves) before truncating to n_starts, so n_starts=1 still paid the univariate loop. Generate candidates lazily and stop once `target = max(n_starts, 1)` are collected: n_starts=1 now returns the uniform start without the univariate loop. Candidate ORDER is unchanged, so any given n_starts yields the same set as before (default n_starts=4 is identical — Basque Tier-2 parity preserved) — only unused work is skipped. Regression: test_n_starts_one_runs. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 95cd4f9 commit 0adb638

2 files changed

Lines changed: 41 additions & 22 deletions

File tree

diff_diff/synthetic_control.py

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -944,40 +944,50 @@ def _to_theta(v: np.ndarray) -> Optional[np.ndarray]:
944944
theta = theta - np.mean(theta)
945945
return theta if np.all(np.isfinite(theta)) else None
946946

947+
# Candidates are generated lazily and we stop as soon as n_starts are collected,
948+
# so a small n_starts does not pay for heuristic starts it would only discard. In
949+
# particular n_starts=1 returns the uniform start without running the O(k) univariate
950+
# inner-solve loop below. The candidate ORDER (uniform -> inverse-variance ->
951+
# univariate-fit -> Dirichlet) is unchanged, so any given n_starts yields the same
952+
# set as before — only unused work is skipped.
953+
target = max(n_starts, 1)
947954
candidates: List[np.ndarray] = [np.zeros(k)] # uniform V
948955

949956
# inverse row variance of the standardized predictors over donors+treated
950-
combined = np.column_stack([X0s, X1s.reshape(-1, 1)])
951-
row_var = np.var(combined, axis=1, ddof=1)
952-
inv_var = np.where(row_var > 0, 1.0 / np.maximum(row_var, 1e-12), 0.0)
953-
if np.sum(inv_var) > 0:
954-
t = _to_theta(inv_var / np.sum(inv_var))
955-
if t is not None:
956-
candidates.append(t)
957-
958-
# univariate-fit start: v_i ∝ 1 / (pre-outcome MSPE of W solved with V=e_i)
959-
uni_mspe = np.empty(k)
960-
for i in range(k):
961-
e = np.zeros(k)
962-
e[i] = 1.0
963-
w_i, _ = _inner_solve_W(X1s, X0s, e, inner_max_iter, inner_min_decrease)
964-
uni_mspe[i] = float(np.mean((Z1 - Z0 @ w_i) ** 2))
965-
inv_mspe = np.where(uni_mspe > 0, 1.0 / np.maximum(uni_mspe, 1e-12), 0.0)
966-
if np.sum(inv_mspe) > 0:
967-
t = _to_theta(inv_mspe / np.sum(inv_mspe))
968-
if t is not None:
969-
candidates.append(t)
957+
if len(candidates) < target:
958+
combined = np.column_stack([X0s, X1s.reshape(-1, 1)])
959+
row_var = np.var(combined, axis=1, ddof=1)
960+
inv_var = np.where(row_var > 0, 1.0 / np.maximum(row_var, 1e-12), 0.0)
961+
if np.sum(inv_var) > 0:
962+
t = _to_theta(inv_var / np.sum(inv_var))
963+
if t is not None:
964+
candidates.append(t)
965+
966+
# univariate-fit start: v_i ∝ 1 / (pre-outcome MSPE of W solved with V=e_i).
967+
# Skipped entirely when enough candidates are already collected (saves k inner solves).
968+
if len(candidates) < target:
969+
uni_mspe = np.empty(k)
970+
for i in range(k):
971+
e = np.zeros(k)
972+
e[i] = 1.0
973+
w_i, _ = _inner_solve_W(X1s, X0s, e, inner_max_iter, inner_min_decrease)
974+
uni_mspe[i] = float(np.mean((Z1 - Z0 @ w_i) ** 2))
975+
inv_mspe = np.where(uni_mspe > 0, 1.0 / np.maximum(uni_mspe, 1e-12), 0.0)
976+
if np.sum(inv_mspe) > 0:
977+
t = _to_theta(inv_mspe / np.sum(inv_mspe))
978+
if t is not None:
979+
candidates.append(t)
970980

971981
# random Dirichlet draws to reach n_starts (bounded attempts as a backstop)
972982
attempts = 0
973983
max_attempts = 10 * n_starts + 20
974-
while len(candidates) < n_starts and attempts < max_attempts:
984+
while len(candidates) < target and attempts < max_attempts:
975985
attempts += 1
976986
t = _to_theta(rng.dirichlet(np.ones(k)))
977987
if t is not None:
978988
candidates.append(t)
979989

980-
return candidates[: max(n_starts, 1)]
990+
return candidates[:target]
981991

982992

983993
def _outer_solve_V(

tests/test_methodology_synthetic_control.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,15 @@ def test_outer_v_nonconvergence_warning():
314314
)
315315

316316

317+
def test_n_starts_one_runs():
318+
# n_starts=1 uses only the uniform start (short-circuits the heuristic candidates)
319+
# and still produces a valid nested fit.
320+
df, _, _ = _make_panel()
321+
res = synthetic_control(df, "y", "treated", "unit", "year", seed=0, n_starts=1)
322+
assert np.isfinite(res.att)
323+
assert abs(sum(res.donor_weights.values()) - 1.0) < 1e-6
324+
325+
317326
def test_non_finite_outcome_rejected():
318327
df, years, T0 = _make_panel()
319328
df = df.copy()

0 commit comments

Comments
 (0)