Skip to content

Commit 20a126e

Browse files
committed
Vectorize loops and remove unused variables in triple_diff.py
Address code review feedback:
- Vectorize 3 for loops in _doubly_robust() using numpy boolean indexing
- Vectorize for loop in _compute_ipw_se() using np.where and boolean indexing
- Remove unused variables: p_ref (line 857), n_pre, n_post
- Fix import ordering via ruff
1 parent 959dbbd commit 20a126e

1 file changed

Lines changed: 24 additions & 39 deletions

File tree

diff_diff/triple_diff.py

Lines changed: 24 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
compute_robust_se,
4545
)
4646

47-
4847
# =============================================================================
4948
# Results Classes
5049
# =============================================================================
@@ -853,9 +852,6 @@ def _ipw_estimation(
853852
p_cell_3 = np.clip(p_cell_3, self.pscore_trim, 1 - self.pscore_trim)
854853
p_cell_4 = np.clip(p_cell_4, self.pscore_trim, 1 - self.pscore_trim)
855854

856-
# Reference probability: P(G=1, P=1)
857-
p_ref = np.mean(cell_1)
858-
859855
# IPW estimator for DDD
860856
# The DDD-IPW estimator reweights each cell to have the same
861857
# covariate distribution as the effectively treated (G=1, P=1)
@@ -1022,29 +1018,22 @@ def _doubly_robust(
10221018
inf_11 += mu_fitted * cell_1.astype(float) / p_ref
10231019

10241020
# Cell 2 (G=1, P=0)
1025-
inf_10 = np.zeros(n)
10261021
w_10 = cell_2.astype(float) * (p_cell_1 / p_cell_2)
10271022
inf_10 = w_10 * (y - mu_fitted) / p_ref
1028-
# Add outcome model contribution for cell 2
1029-
for i in range(n):
1030-
if cell_2[i]:
1031-
inf_10[i] += mu_fitted[i] * (p_cell_1[i] / p_cell_2[i]) / p_ref
1023+
# Add outcome model contribution for cell 2 (vectorized)
1024+
inf_10[cell_2] += mu_fitted[cell_2] * (p_cell_1[cell_2] / p_cell_2[cell_2]) / p_ref
10321025

10331026
# Cell 3 (G=0, P=1)
1034-
inf_01 = np.zeros(n)
10351027
w_01 = cell_3.astype(float) * (p_cell_1 / p_cell_3)
10361028
inf_01 = w_01 * (y - mu_fitted) / p_ref
1037-
for i in range(n):
1038-
if cell_3[i]:
1039-
inf_01[i] += mu_fitted[i] * (p_cell_1[i] / p_cell_3[i]) / p_ref
1029+
# Add outcome model contribution for cell 3 (vectorized)
1030+
inf_01[cell_3] += mu_fitted[cell_3] * (p_cell_1[cell_3] / p_cell_3[cell_3]) / p_ref
10401031

10411032
# Cell 4 (G=0, P=0)
1042-
inf_00 = np.zeros(n)
10431033
w_00 = cell_4.astype(float) * (p_cell_1 / p_cell_4)
10441034
inf_00 = w_00 * (y - mu_fitted) / p_ref
1045-
for i in range(n):
1046-
if cell_4[i]:
1047-
inf_00[i] += mu_fitted[i] * (p_cell_1[i] / p_cell_4[i]) / p_ref
1035+
# Add outcome model contribution for cell 4 (vectorized)
1036+
inf_00[cell_4] += mu_fitted[cell_4] * (p_cell_1[cell_4] / p_cell_4[cell_4]) / p_ref
10481037

10491038
# Compute cell-time means using DR formula
10501039
def dr_mean(inf_vals, t_mask):
@@ -1071,9 +1060,6 @@ def dr_mean(inf_vals, t_mask):
10711060
# Use the simpler variance formula for the DDD estimator
10721061
# Var(DDD) ≈ sum of variances of cell means / cell_sizes
10731062

1074-
n_pre = np.sum(pre_mask)
1075-
n_post = np.sum(post_mask)
1076-
10771063
# Compute variances within each cell-time combination
10781064
def cell_var(cell_mask, t_mask, y_vals):
10791065
mask = cell_mask & t_mask
@@ -1143,32 +1129,31 @@ def _compute_ipw_se(
11431129
) -> float:
11441130
"""Compute standard error for IPW estimator using influence function."""
11451131
n = len(y)
1146-
pre_mask = T == 0
11471132
post_mask = T == 1
11481133

1149-
# Influence function for IPW estimator
1134+
# Influence function for IPW estimator (vectorized)
11501135
inf_func = np.zeros(n)
11511136

11521137
n_ref = np.sum(cell_1)
11531138
p_ref = n_ref / n
11541139

1155-
for i in range(n):
1156-
if post_mask[i]:
1157-
sign = 1.0
1158-
else:
1159-
sign = -1.0
1160-
1161-
if cell_1[i]:
1162-
inf_func[i] = sign * (y[i] - att) / p_ref
1163-
elif cell_2[i]:
1164-
w = p_cell_1[i] / p_cell_2[i]
1165-
inf_func[i] = -sign * y[i] * w / p_ref
1166-
elif cell_3[i]:
1167-
w = p_cell_1[i] / p_cell_3[i]
1168-
inf_func[i] = -sign * y[i] * w / p_ref
1169-
elif cell_4[i]:
1170-
w = p_cell_1[i] / p_cell_4[i]
1171-
inf_func[i] = sign * y[i] * w / p_ref
1140+
# Sign: +1 for post, -1 for pre
1141+
sign = np.where(post_mask, 1.0, -1.0)
1142+
1143+
# Cell 1 (G=1, P=1): sign * (y - att) / p_ref
1144+
inf_func[cell_1] = sign[cell_1] * (y[cell_1] - att) / p_ref
1145+
1146+
# Cell 2 (G=1, P=0): -sign * y * (p_cell_1 / p_cell_2) / p_ref
1147+
w_2 = p_cell_1[cell_2] / p_cell_2[cell_2]
1148+
inf_func[cell_2] = -sign[cell_2] * y[cell_2] * w_2 / p_ref
1149+
1150+
# Cell 3 (G=0, P=1): -sign * y * (p_cell_1 / p_cell_3) / p_ref
1151+
w_3 = p_cell_1[cell_3] / p_cell_3[cell_3]
1152+
inf_func[cell_3] = -sign[cell_3] * y[cell_3] * w_3 / p_ref
1153+
1154+
# Cell 4 (G=0, P=0): sign * y * (p_cell_1 / p_cell_4) / p_ref
1155+
w_4 = p_cell_1[cell_4] / p_cell_4[cell_4]
1156+
inf_func[cell_4] = sign[cell_4] * y[cell_4] * w_4 / p_ref
11721157

11731158
var_inf = np.var(inf_func, ddof=1)
11741159
se = np.sqrt(var_inf / n)

0 commit comments

Comments
 (0)