@@ -116,7 +116,7 @@ def _detect_rank_deficiency(
116116
117117 # Compute pivoted QR decomposition: X @ P = Q @ R
118118 # P is a permutation matrix, represented as pivot indices
119- Q , R , pivot = qr (X , mode = ' economic' , pivoting = True )
119+ Q , R , pivot = qr (X , mode = " economic" , pivoting = True )
120120
121121 # Determine rank tolerance
122122 # R's qr() uses tol = 1e-07 by default, which is sqrt(eps) ≈ 1.49e-08
@@ -169,8 +169,7 @@ def _format_dropped_columns(
169169 return ""
170170
171171 if column_names is not None :
172- names = [column_names [i ] if i < len (column_names ) else f"column { i } "
173- for i in dropped_cols ]
172+ names = [column_names [i ] if i < len (column_names ) else f"column { i } " for i in dropped_cols ]
174173 if len (names ) == 1 :
175174 return f"'{ names [0 ]} '"
176175 elif len (names ) <= 5 :
@@ -251,10 +250,12 @@ def _solve_ols_rust(
251250 cluster_ids : Optional [np .ndarray ] = None ,
252251 return_vcov : bool = True ,
253252 return_fitted : bool = False ,
254- ) -> Optional [Union [
255- Tuple [np .ndarray , np .ndarray , Optional [np .ndarray ]],
256- Tuple [np .ndarray , np .ndarray , np .ndarray , Optional [np .ndarray ]],
257- ]]:
253+ ) -> Optional [
254+ Union [
255+ Tuple [np .ndarray , np .ndarray , Optional [np .ndarray ]],
256+ Tuple [np .ndarray , np .ndarray , np .ndarray , Optional [np .ndarray ]],
257+ ]
258+ ]:
258259 """
259260 Rust backend implementation of solve_ols for full-rank matrices.
260261
@@ -447,8 +448,7 @@ def solve_ols(
447448 raise ValueError (f"y must be 1-dimensional, got shape { y .shape } " )
448449 if X .shape [0 ] != y .shape [0 ]:
449450 raise ValueError (
450- f"X and y must have same number of observations: "
451- f"{ X .shape [0 ]} vs { y .shape [0 ]} "
451+ f"X and y must have same number of observations: " f"{ X .shape [0 ]} vs { y .shape [0 ]} "
452452 )
453453
454454 n , k = X .shape
@@ -484,7 +484,8 @@ def solve_ols(
484484 if skip_rank_check :
485485 if HAS_RUST_BACKEND and _rust_solve_ols is not None :
486486 result = _solve_ols_rust (
487- X , y ,
487+ X ,
488+ y ,
488489 cluster_ids = cluster_ids ,
489490 return_vcov = return_vcov ,
490491 return_fitted = return_fitted ,
@@ -494,7 +495,8 @@ def solve_ols(
494495 # Fall through to NumPy on numerical instability
495496 # Fall through to Python without rank check (user guarantees full rank)
496497 return _solve_ols_numpy (
497- X , y ,
498+ X ,
499+ y ,
498500 cluster_ids = cluster_ids ,
499501 return_vcov = return_vcov ,
500502 return_fitted = return_fitted ,
@@ -521,7 +523,8 @@ def solve_ols(
521523 # - No Rust → Python backend (works for all cases)
522524 if HAS_RUST_BACKEND and _rust_solve_ols is not None and not is_rank_deficient :
523525 result = _solve_ols_rust (
524- X , y ,
526+ X ,
527+ y ,
525528 cluster_ids = cluster_ids ,
526529 return_vcov = return_vcov ,
527530 return_fitted = return_fitted ,
@@ -531,7 +534,8 @@ def solve_ols(
531534 # signaled us to fall back to Python backend
532535 if result is None :
533536 return _solve_ols_numpy (
534- X , y ,
537+ X ,
538+ y ,
535539 cluster_ids = cluster_ids ,
536540 return_vcov = return_vcov ,
537541 return_fitted = return_fitted ,
@@ -555,7 +559,8 @@ def solve_ols(
555559 # and SVD disagreed about rank. Python's QR will re-detect and
556560 # apply R-style NaN handling for dropped columns.
557561 return _solve_ols_numpy (
558- X , y ,
562+ X ,
563+ y ,
559564 cluster_ids = cluster_ids ,
560565 return_vcov = return_vcov ,
561566 return_fitted = return_fitted ,
@@ -569,7 +574,8 @@ def solve_ols(
569574 # Use NumPy implementation for rank-deficient cases (R-style NA handling)
570575 # or when Rust backend is not available
571576 return _solve_ols_numpy (
572- X , y ,
577+ X ,
578+ y ,
573579 cluster_ids = cluster_ids ,
574580 return_vcov = return_vcov ,
575581 return_fitted = return_fitted ,
@@ -834,9 +840,7 @@ def _compute_robust_vcov_numpy(
834840 n_clusters = len (unique_clusters )
835841
836842 if n_clusters < 2 :
837- raise ValueError (
838- f"Need at least 2 clusters for cluster-robust SEs, got { n_clusters } "
839- )
843+ raise ValueError (f"Need at least 2 clusters for cluster-robust SEs, got { n_clusters } " )
840844
841845 # Small-sample adjustment
842846 adjustment = (n_clusters / (n_clusters - 1 )) * ((n - 1 ) / (n - k ))
@@ -871,6 +875,193 @@ def _compute_robust_vcov_numpy(
871875 return vcov
872876
873877
878+ # Empirical thresholds used by solve_logit to flag near-separation in the logistic model
879+ # (predicted probabilities collapse to 0/1): max |coefficient| above the first, or a fitted probability within the second of 0 or 1.
880+ _LOGIT_SEPARATION_COEF_THRESHOLD = 10
881+ _LOGIT_SEPARATION_PROB_THRESHOLD = 1e-5
882+
883+
def solve_logit(
    X: np.ndarray,
    y: np.ndarray,
    max_iter: int = 25,
    tol: float = 1e-8,
    check_separation: bool = True,
    rank_deficient_action: str = "warn",
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Fit logistic regression via IRLS (Fisher scoring).

    Follows the iteratively reweighted least squares scheme used by R's
    ``glm(family=binomial)``: working weights ``mu*(1-mu)`` and working
    response ``eta + (y-mu)/(mu*(1-mu))``. Each step solves the weighted
    least-squares problem with ``np.linalg.lstsq``.

    Parameters
    ----------
    X : np.ndarray
        Feature matrix (n_samples, n_features). Intercept added automatically.
    y : np.ndarray
        Binary outcome (0/1), one-dimensional with n_samples entries.
    max_iter : int, default 25
        Maximum IRLS iterations (R's ``glm`` default).
    tol : float, default 1e-8
        Convergence tolerance on the largest absolute coefficient change
        between two consecutive iterations (default value taken from R's
        ``glm``).
    check_separation : bool, default True
        Whether to check for near-separation and emit warnings.
    rank_deficient_action : str, default "warn"
        How to handle rank-deficient design matrices:
        - "warn": Emit warning and drop columns (default)
        - "error": Raise ValueError
        - "silent": Drop columns silently

    Returns
    -------
    beta : np.ndarray
        Fitted coefficients (including intercept as element 0). Columns
        dropped for rank deficiency get a coefficient of 0.
    probs : np.ndarray
        Predicted probabilities.

    Raises
    ------
    ValueError
        If ``rank_deficient_action`` is not a recognised value, if ``X`` is
        not 2-dimensional, if ``y`` is not 1-dimensional, if ``X`` and ``y``
        disagree on the number of observations, or if the design matrix is
        rank-deficient and ``rank_deficient_action="error"``.

    Warns
    -----
    UserWarning
        On rank deficiency (with ``rank_deficient_action="warn"``), on
        non-convergence, and (with ``check_separation=True``) on large
        coefficients or fitted probabilities extremely close to 0 or 1.
    """
    # Validate arguments before doing any numerical work.
    valid_actions = {"warn", "error", "silent"}
    if rank_deficient_action not in valid_actions:
        raise ValueError(
            f"rank_deficient_action must be one of {valid_actions}, "
            f"got '{rank_deficient_action}'"
        )

    # Shape checks mirror solve_ols: without them a mis-shaped y silently
    # broadcasts inside the IRLS loop and fails with an unrelated error.
    X = np.asarray(X)
    y = np.asarray(y)
    if X.ndim != 2:
        raise ValueError(f"X must be 2-dimensional, got shape {X.shape}")
    if y.ndim != 1:
        raise ValueError(f"y must be 1-dimensional, got shape {y.shape}")
    if X.shape[0] != y.shape[0]:
        raise ValueError(
            f"X and y must have same number of observations: {X.shape[0]} vs {y.shape[0]}"
        )

    n, p = X.shape
    X_with_intercept = np.column_stack([np.ones(n), X])
    k = p + 1  # number of parameters including intercept

    # Check rank deficiency once before iterating
    _, dropped_cols, _ = _detect_rank_deficiency(X_with_intercept)
    has_dropped = len(dropped_cols) > 0
    if has_dropped:
        col_desc = _format_dropped_columns(dropped_cols)
        if rank_deficient_action == "error":
            raise ValueError(
                f"Rank-deficient design matrix in logistic regression: "
                f"dropping {col_desc}. Propensity score estimates may be unreliable."
            )
        elif rank_deficient_action == "warn":
            warnings.warn(
                f"Rank-deficient design matrix in logistic regression: "
                f"dropping {col_desc}. Propensity score estimates may be unreliable.",
                UserWarning,
                stacklevel=2,
            )
        # Sorted indices of the columns that survive the rank check.
        kept_cols = np.setdiff1d(np.arange(k), dropped_cols)
        X_solve = X_with_intercept[:, kept_cols]
    else:
        kept_cols = np.arange(k)
        X_solve = X_with_intercept

    # IRLS (Fisher scoring), started from the all-zero coefficient vector
    beta_solve = np.zeros(X_solve.shape[1])
    converged = False

    for _ in range(max_iter):
        eta = X_solve @ beta_solve
        # Clip to prevent overflow in exp
        eta = np.clip(eta, -500, 500)
        mu = 1.0 / (1.0 + np.exp(-eta))
        # Clip mu to prevent zero working weights
        mu = np.clip(mu, 1e-10, 1 - 1e-10)

        # Working weights and working response
        w = mu * (1.0 - mu)
        z = eta + (y - mu) / w

        # Weighted least squares: solve (X'WX) beta = X'Wz by scaling rows
        # with sqrt(w) and handing the result to an ordinary lstsq solve.
        sqrt_w = np.sqrt(w)
        Xw = X_solve * sqrt_w[:, None]
        zw = z * sqrt_w
        beta_new, _, _, _ = np.linalg.lstsq(Xw, zw, rcond=None)

        # Check convergence on the largest coefficient change
        step = np.max(np.abs(beta_new - beta_solve))
        beta_solve = beta_new
        if step < tol:
            converged = True
            break

    # Final predicted probabilities
    eta_final = X_solve @ beta_solve
    eta_final = np.clip(eta_final, -500, 500)
    probs = 1.0 / (1.0 + np.exp(-eta_final))

    # Warnings
    if not converged:
        warnings.warn(
            f"Logistic regression did not converge in {max_iter} iterations. "
            f"Propensity score estimates may be unreliable.",
            UserWarning,
            stacklevel=2,
        )

    if check_separation:
        if np.max(np.abs(beta_solve)) > _LOGIT_SEPARATION_COEF_THRESHOLD:
            warnings.warn(
                "Large coefficients detected in propensity score model "
                f"(max|beta| > {_LOGIT_SEPARATION_COEF_THRESHOLD}), "
                "suggesting potential separation.",
                UserWarning,
                stacklevel=2,
            )
        n_extreme = int(
            np.sum(
                (probs < _LOGIT_SEPARATION_PROB_THRESHOLD)
                | (probs > 1 - _LOGIT_SEPARATION_PROB_THRESHOLD)
            )
        )
        if n_extreme > 0:
            warnings.warn(
                f"Near-separation detected in propensity score model: "
                f"{n_extreme} of {n} observations have predicted probabilities "
                f"within {_LOGIT_SEPARATION_PROB_THRESHOLD} of 0 or 1. ATT estimates may be sensitive to "
                f"model specification.",
                UserWarning,
                stacklevel=2,
            )

    # Expand beta back to full size if columns were dropped; the dropped
    # positions keep the zero they were initialised with.
    if has_dropped:
        beta_full = np.zeros(k)
        beta_full[kept_cols] = beta_solve
    else:
        beta_full = beta_solve

    return beta_full, probs
1035+
1036+
def _check_propensity_diagnostics(
    pscore: np.ndarray,
    trim_bound: float = 0.01,
) -> None:
    """
    Emit a warning when some propensity scores fall outside the trimming band.

    A score counts as extreme when it is strictly below ``trim_bound`` or
    strictly above ``1 - trim_bound``. Nothing is returned and the scores
    are not modified; the only effect is a ``UserWarning`` when at least
    one score is extreme.

    Parameters
    ----------
    pscore : np.ndarray
        Predicted probabilities.
    trim_bound : float, default 0.01
        Trimming threshold.
    """
    upper_bound = 1 - trim_bound
    outside_band = (pscore < trim_bound) | (pscore > upper_bound)
    extreme_count = int(np.sum(outside_band))

    # Nothing to report when every score sits inside the band.
    if extreme_count <= 0:
        return

    total_count = len(pscore)
    share_pct = 100.0 * extreme_count / total_count
    message = (
        f"Propensity scores for {extreme_count} of {total_count} observations "
        f"({share_pct:.1f}%) were outside [{trim_bound}, {upper_bound}] "
        f"and will be trimmed. This may indicate near-separation in "
        f"the propensity score model."
    )
    # stacklevel=2 attributes the warning to the caller of this helper.
    warnings.warn(message, UserWarning, stacklevel=2)
1063+
1064+
8741065def compute_r_squared (
8751066 y : np .ndarray ,
8761067 residuals : np .ndarray ,
@@ -1149,7 +1340,8 @@ def fit(
11491340 if self .robust or effective_cluster_ids is not None :
11501341 # Use solve_ols with robust/cluster SEs
11511342 coefficients , residuals , fitted , vcov = solve_ols (
1152- X , y ,
1343+ X ,
1344+ y ,
11531345 cluster_ids = effective_cluster_ids ,
11541346 return_fitted = True ,
11551347 return_vcov = compute_vcov ,
@@ -1158,7 +1350,8 @@ def fit(
11581350 else :
11591351 # Classical OLS - compute vcov separately
11601352 coefficients , residuals , fitted , _ = solve_ols (
1161- X , y ,
1353+ X ,
1354+ y ,
11621355 return_fitted = True ,
11631356 return_vcov = False ,
11641357 rank_deficient_action = self .rank_deficient_action ,
@@ -1294,6 +1487,7 @@ def get_inference(
12941487 # Handle zero or negative SE (indicates perfect fit or numerical issues)
12951488 if se <= 0 :
12961489 import warnings
1490+
12971491 warnings .warn (
12981492 f"Standard error is zero or negative (se={ se } ) for coefficient at index { index } . "
12991493 "This may indicate perfect multicollinearity or numerical issues." ,
@@ -1319,6 +1513,7 @@ def get_inference(
13191513 # Warn if df is non-positive and fall back to normal distribution
13201514 if effective_df is not None and effective_df <= 0 :
13211515 import warnings
1516+
13221517 warnings .warn (
13231518 f"Degrees of freedom is non-positive (df={ effective_df } ). "
13241519 "Using normal distribution instead of t-distribution for inference." ,
@@ -1396,10 +1591,7 @@ def get_all_inference(
13961591 Inference results for each coefficient in order.
13971592 """
13981593 self ._check_fitted ()
1399- return [
1400- self .get_inference (i , alpha = alpha , df = df )
1401- for i in range (len (self .coefficients_ ))
1402- ]
1594+ return [self .get_inference (i , alpha = alpha , df = df ) for i in range (len (self .coefficients_ ))]
14031595
14041596 def r_squared (self , adjusted : bool = False ) -> float :
14051597 """
@@ -1424,9 +1616,7 @@ def r_squared(self, adjusted: bool = False) -> float:
14241616 self ._check_fitted ()
14251617 # Use effective params for adjusted R² to match df correction
14261618 n_params = self .n_params_effective_ if adjusted else self .n_params_
1427- return compute_r_squared (
1428- self ._y , self .residuals_ , adjusted = adjusted , n_params = n_params
1429- )
1619+ return compute_r_squared (self ._y , self .residuals_ , adjusted = adjusted , n_params = n_params )
14301620
14311621 def predict (self , X : np .ndarray ) -> np .ndarray :
14321622 """
0 commit comments