Skip to content
This repository was archived by the owner on Apr 24, 2024. It is now read-only.
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions src/equisolve/numpy/models/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _solve(
y: TensorBlock,
alpha: TensorBlock,
sample_weight: TensorBlock,
cond: float,
rcond: Optional[float] = None,
) -> TensorBlock:
"""A regularized solver using ``np.linalg.lstsq``."""
self._used_auto_solver = None
Expand Down Expand Up @@ -197,10 +197,14 @@ def _solve(
# and b is [y*sqrt(w), 0]
X_eff = np.vstack([sqrt_sw_arr * X_arr, np.diag(np.sqrt(alpha_arr))])
y_eff = np.hstack([y_arr * sqrt_sw_arr.flatten(), np.zeros(num_properties)])
w = scipy.linalg.lstsq(X_eff, y_eff, cond=cond, overwrite_a=True)[0].ravel()
if rcond is None:
rcond = max(X_arr.shape) * np.finfo(X_arr.dtype.char.lower()).eps
w = scipy.linalg.lstsq(X_eff, y_eff, cond=rcond, overwrite_a=True)[
0
].ravel()
else:
raise ValueError(
f"Unknown solver {self._solver} only 'auto', 'cholesky',"
f"Unknown solver {self._solver!r} only 'auto', 'cholesky',"
" 'cholesky_dual' and 'lstsq' are supported."
)

Expand All @@ -225,7 +229,7 @@ def fit(
alpha: Union[float, TensorMap] = 1.0,
sample_weight: Union[float, TensorMap] = None,
solver="auto",
cond: float = None,
rcond: Optional[float] = None,
) -> None:
"""Fit a regression model to each block in `X`.

Expand Down Expand Up @@ -261,10 +265,11 @@ def fit(
on the dual problem (X@X.T) w_dual = y,
the primal weights are obtained by w = X.T @ w_dual
- **"lstsq"**: using :func:`scipy.linalg.lstsq` on the linear system X w = y
:param cond:
:param rcond:
Cut-off ratio for small singular values during the fit. For the purposes of
rank determination, singular values are treated as zero if they are smaller
than `cond` times the largest singular value in "weights" matrix.
than `cond` times the largest singular value in "weights" matrix. Only
important when solver "lstsq" is used.
"""
self._solver = solver

Expand Down Expand Up @@ -303,7 +308,7 @@ def fit(
alpha_block = alpha.block(key)
sw_block = sample_weight.block(key)

weight_block = self._solve(X_block, y_block, alpha_block, sw_block, cond)
weight_block = self._solve(X_block, y_block, alpha_block, sw_block, rcond)

weights_blocks.append(weight_block)

Expand Down