Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 48 additions & 38 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -603,22 +603,31 @@ function _is_gpu_sparse(A)
end

"""
_do_qr_fallback(cache::LinearCache, alg, sol, reason::Symbol, args...; kwargs...)
_do_qr_fallback(cache::LinearCache, alg, sol, reason::Symbol)

Perform QR fallback after LU failure or residual check failure. Restores `cache.A`
from `A_backup` (since LU may have modified it in-place) and solves with column-pivoted QR
(or NoPivot for GPU arrays which don't support scalar indexing).
`reason` is `:lu_failure` or `:residual_check` for appropriate log messages.
"""
function _do_qr_fallback(cache::LinearCache, alg, sol, reason::Symbol, args...; kwargs...)
function _do_qr_fallback(cache::LinearCache, alg, sol, reason::Symbol)
# Always extract solution data from `cache` rather than `sol`. The QR
# fallback path calls `solve!(cache, QRFactorization(...))` recursively;
# during precompile inference, that inner call's return type gets capped
# to a non-concrete UnionAll (Julia's inference complexity limit). Reading
# `cache.u` (statically typed) and using `cache` for the solution cache
# field keeps the return type of this helper concrete, which propagates
# up through `_default_lu_solve_with_fallback` and the @generated
# `solve!(cache, ::DefaultLinearSolver)` body.
rc = sol.retcode
iters = sol.iters
if is_cusparse(cache.A)
@SciMLMessage(
"LU factorization failed for GPU sparse matrix but QR fallback is not supported for CuSparse. Returning LU failure.",
cache.verbose, :default_lu_fallback
)
return SciMLBase.build_linear_solution(
alg, sol.u, sol.resid, sol.cache;
retcode = sol.retcode, iters = sol.iters, stats = sol.stats
alg, cache.u, nothing, cache; retcode = rc, iters = iters, stats = nothing
)
end
if cache.A === cache.cacheval.A_backup
Expand All @@ -627,8 +636,7 @@ function _do_qr_fallback(cache::LinearCache, alg, sol, reason::Symbol, args...;
cache.verbose, :default_lu_fallback
)
return SciMLBase.build_linear_solution(
alg, sol.u, sol.resid, sol.cache;
retcode = sol.retcode, iters = sol.iters, stats = sol.stats
alg, cache.u, nothing, cache; retcode = rc, iters = iters, stats = nothing
)
end
if reason === :residual_check
Expand All @@ -645,33 +653,34 @@ function _do_qr_fallback(cache::LinearCache, alg, sol, reason::Symbol, args...;
copyto!(cache.A, cache.cacheval.A_backup)
cache.isfresh = true
pivot = _qr_fallback_pivot(cache.A)
sol = SciMLBase.solve!(cache, QRFactorization(pivot), args...; kwargs...)
qr_sol = SciMLBase.solve!(cache, QRFactorization(pivot))
cache.cacheval.fell_back_to_qr = true
return SciMLBase.build_linear_solution(
alg, sol.u, sol.resid, sol.cache;
retcode = sol.retcode, iters = sol.iters, stats = sol.stats
alg, cache.u, nothing, cache;
retcode = qr_sol.retcode, iters = qr_sol.iters, stats = nothing
)
end

"""
_reuse_qr_fallback(cache::LinearCache, alg, args...; kwargs...)
_reuse_qr_fallback(cache::LinearCache, alg)

Reuse the cached QR factorization from a previous QR fallback. Called when
`fell_back_to_qr` is `true` and `isfresh` is `false`, meaning the matrix hasn't
changed since the QR fallback and we should keep using QR instead of the
(potentially corrupted) LU factorization.
"""
function _reuse_qr_fallback(cache::LinearCache, alg, args...; kwargs...)
function _reuse_qr_fallback(cache::LinearCache, alg)
pivot = _qr_fallback_pivot(cache.A)
sol = SciMLBase.solve!(cache, QRFactorization(pivot), args...; kwargs...)
qr_sol = SciMLBase.solve!(cache, QRFactorization(pivot))
# Use cache directly for type-stable inference (see _do_qr_fallback).
return SciMLBase.build_linear_solution(
alg, sol.u, sol.resid, sol.cache;
retcode = sol.retcode, iters = sol.iters, stats = sol.stats
alg, cache.u, nothing, cache;
retcode = qr_sol.retcode, iters = qr_sol.iters, stats = nothing
)
end

"""
_default_lu_solve_with_fallback(cache::LinearCache, alg::DefaultLinearSolver, sol, args...; kwargs...)
_default_lu_solve_with_fallback(cache::LinearCache, alg::DefaultLinearSolver, sol)

Post-process an LU solve result: if LU explicitly failed, the solution contains NaN/Inf,
or the residual check returned `APosterioriSafetyFailure`, fall back to column-pivoted QR.
Expand All @@ -682,26 +691,27 @@ The NaN/Inf check catches floating-point-near-singular matrices where LU "succee
near-zero pivots. This is O(n) and has zero false positives.
"""
function _default_lu_solve_with_fallback(
cache::LinearCache, alg::DefaultLinearSolver, sol, args...; kwargs...
cache::LinearCache, alg::DefaultLinearSolver, sol
)
if alg.safetyfallback
if sol.retcode === ReturnCode.Failure
return _do_qr_fallback(cache, alg, sol, :lu_failure, args...; kwargs...)
return _do_qr_fallback(cache, alg, sol, :lu_failure)
end
if sol.retcode === ReturnCode.Success && any(!isfinite, sol.u)
@SciMLMessage(
"LU solve produced non-finite values (NaN/Inf), falling back to QR. Matrix is likely near-singular.",
cache.verbose, :default_lu_fallback
)
return _do_qr_fallback(cache, alg, sol, :lu_failure, args...; kwargs...)
return _do_qr_fallback(cache, alg, sol, :lu_failure)
end
if sol.retcode === ReturnCode.APosterioriSafetyFailure
return _do_qr_fallback(cache, alg, sol, :residual_check, args...; kwargs...)
return _do_qr_fallback(cache, alg, sol, :residual_check)
end
end
# Use cache directly for type-stable inference (see _do_qr_fallback).
return SciMLBase.build_linear_solution(
alg, sol.u, sol.resid, sol.cache;
retcode = sol.retcode, iters = sol.iters, stats = sol.stats
alg, cache.u, nothing, cache;
retcode = sol.retcode, iters = sol.iters, stats = nothing
)
end

Expand Down Expand Up @@ -762,54 +772,54 @@ end
# its own residual check and returns APosterioriSafetyFailure if needed.
inner_alg_expr = _algchoice_to_alg_with_safety(alg)
newex = quote
sol = SciMLBase.solve!(cache, $inner_alg_expr, args...; kwargs...)
_default_lu_solve_with_fallback(cache, alg, sol, args...; kwargs...)
sol = SciMLBase.solve!(cache, $inner_alg_expr)
_default_lu_solve_with_fallback(cache, alg, sol)
end
elseif alg == Symbol(DefaultAlgorithmChoice.RFLUFactorization)
inner_alg_expr = _algchoice_to_alg_with_safety(alg)
newex = quote
if !userecursivefactorization(nothing)
error("Default algorithm calling solve on RecursiveFactorization without the package being loaded. This shouldn't happen.")
end
sol = SciMLBase.solve!(cache, $inner_alg_expr, args...; kwargs...)
_default_lu_solve_with_fallback(cache, alg, sol, args...; kwargs...)
sol = SciMLBase.solve!(cache, $inner_alg_expr)
_default_lu_solve_with_fallback(cache, alg, sol)
end
elseif alg == Symbol(DefaultAlgorithmChoice.BLISLUFactorization)
inner_alg_expr = _algchoice_to_alg_with_safety(alg)
newex = quote
if !useblis(nothing)
error("Default algorithm calling solve on BLISLUFactorization without the extension being loaded. This shouldn't happen.")
end
sol = SciMLBase.solve!(cache, $inner_alg_expr, args...; kwargs...)
_default_lu_solve_with_fallback(cache, alg, sol, args...; kwargs...)
sol = SciMLBase.solve!(cache, $inner_alg_expr)
_default_lu_solve_with_fallback(cache, alg, sol)
end
elseif alg == Symbol(DefaultAlgorithmChoice.CudaOffloadLUFactorization)
inner_alg_expr = _algchoice_to_alg_with_safety(alg)
newex = quote
if !usecuda(nothing)
error("Default algorithm calling solve on CudaOffloadLUFactorization without CUDA.jl being loaded. This shouldn't happen.")
end
sol = SciMLBase.solve!(cache, $inner_alg_expr, args...; kwargs...)
_default_lu_solve_with_fallback(cache, alg, sol, args...; kwargs...)
sol = SciMLBase.solve!(cache, $inner_alg_expr)
_default_lu_solve_with_fallback(cache, alg, sol)
end
elseif alg == Symbol(DefaultAlgorithmChoice.MetalLUFactorization)
inner_alg_expr = _algchoice_to_alg_with_safety(alg)
newex = quote
if !usemetal(nothing)
error("Default algorithm calling solve on MetalLUFactorization without Metal.jl being loaded. This shouldn't happen.")
end
sol = SciMLBase.solve!(cache, $inner_alg_expr, args...; kwargs...)
_default_lu_solve_with_fallback(cache, alg, sol, args...; kwargs...)
sol = SciMLBase.solve!(cache, $inner_alg_expr)
_default_lu_solve_with_fallback(cache, alg, sol)
end
else
if alg in LinearSolve._SPARSE_ONLY_ALGORITHMS
newex = quote
if !(cache.A isa Array)
sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...)
sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)))
SciMLBase.build_linear_solution(
alg, sol.u, sol.resid, sol.cache;
alg, cache.u, nothing, cache;
retcode = sol.retcode,
iters = sol.iters, stats = sol.stats
iters = sol.iters, stats = nothing
)
else
error(
Expand All @@ -820,11 +830,11 @@ end
end
else
newex = quote
sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...)
sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)))
SciMLBase.build_linear_solution(
alg, sol.u, sol.resid, sol.cache;
alg, cache.u, nothing, cache;
retcode = sol.retcode,
iters = sol.iters, stats = sol.stats
iters = sol.iters, stats = nothing
)
end
end
Expand All @@ -843,7 +853,7 @@ end
return quote
if cache.cacheval isa DefaultLinearSolverInit &&
cache.cacheval.fell_back_to_qr && !cache.isfresh
_reuse_qr_fallback(cache, alg, args...; kwargs...)
_reuse_qr_fallback(cache, alg)
else
$alg_dispatch
end
Expand Down
13 changes: 10 additions & 3 deletions src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -456,15 +456,22 @@ end
function do_factorization(alg::QRFactorization, A, b, u)
A = convert(AbstractMatrix, A)
if ArrayInterface.can_setindex(typeof(A))
if alg.inplace && !issparsematrixcsc(A) && !(A isa GPUArraysCore.AnyGPUArray) &&
!is_cusparse(A)
# Sparse CSC (SPQR) does not accept a pivoting strategy, and CUDA's
# `qr` does not accept extra args either. Use the no-arg `qr(A)`
# form in those cases. For other CPU matrices, always pass
# `alg.pivot` so the return type is determined by the static
# `QRFactorization{P}` parameter (otherwise this branch returns
# `Union{QRCompactWY, QRPivoted}` depending on `alg.inplace`).
if A isa GPUArraysCore.AnyGPUArray || is_cusparse(A) || issparsematrixcsc(A)
fact = qr(A)
elseif alg.inplace
if A isa Symmetric
fact = qr(A, alg.pivot)
else
fact = qr!(A, alg.pivot)
end
else
fact = qr(A) # CUDA.jl does not allow other args!
fact = qr(A, alg.pivot)
end
else
fact = qr(A, alg.pivot)
Expand Down
58 changes: 58 additions & 0 deletions test/nopre/jet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,61 @@ end
JET.@test_opt init(dual_prob)
end
end

# Concrete-return-type QA for `solve!(cache)`. Guards against the regression
# where `solve!(cache)` through `DefaultLinearSolver` returned
# `LinearSolution{_A, _B, _C, _D, DefaultLinearSolver, _E, _F} where {...}`
# (a UnionAll over 6 free type parameters) instead of a concrete LinearSolution.
# Inference probe: build a `LinearProblem` from `A` and `b`, initialise a cache
# with the explicitly supplied algorithm `alg`, and solve that cache in place.
# The value returned is whatever `solve!` returns for the cache.
function _solve_alg(A, b, alg)
    prob = LinearProblem(A, b)
    cache = init(prob, alg)
    return solve!(cache)
end
# Inference probe: same as the explicit-algorithm variant, but `init` is called
# without an algorithm argument so the package's default solver choice is used.
function _solve_default(A, b)
    prob = LinearProblem(A, b)
    cache = init(prob)
    return solve!(cache)
end

@testset "solve!(cache) returns concrete LinearSolution — default solver" begin
    # Headline regression guard: for a dense Float64 system, the inferred
    # return type of `solve!(init(LinearProblem(A, b)))` must be a concrete
    # `LinearSolution`, not a non-concrete (UnionAll-typed) one. The earlier
    # failure came from the `_default_lu_solve_with_fallback` /
    # `_do_qr_fallback` helpers reading `sol.u`, `sol.resid`, `sol.cache` and
    # `sol.stats` off an inner `sol` whose inferred type was not concrete.
    argtypes = Tuple{Matrix{Float64}, Vector{Float64}}
    inferred = Core.Compiler.return_type(_solve_default, argtypes)

    @test isconcretetype(inferred)
    @test inferred <: LinearSolve.SciMLBase.LinearSolution{Float64, 1, Vector{Float64}}
end

@testset "solve!(cache) is concrete for each algorithm" begin
    # Inferred return type of `_solve_alg` on a dense Float64 system when the
    # algorithm argument has the type of `alg`.
    infer_solve_rt(alg) = Core.Compiler.return_type(
        _solve_alg,
        Tuple{Matrix{Float64}, Vector{Float64}, typeof(alg)}
    )

    # Algorithms whose `solve!(cache)` is expected to infer concretely.
    expected_concrete = (
        LUFactorization(),
        GenericLUFactorization(),
        QRFactorization(LinearAlgebra.ColumnNorm()),
        QRFactorization(LinearAlgebra.NoPivot()),
        DiagonalFactorization(),
        SVDFactorization(),
        CholeskyFactorization(),
        NormalCholeskyFactorization(),
    )
    @testset "$(nameof(typeof(alg)))" for alg in expected_concrete
        # Keep the inference call outside `@test` so a thrown exception is
        # reported as an error of the testset rather than folded into the test.
        rt = infer_solve_rt(alg)
        @test isconcretetype(rt)
    end

    # Known unrelated inference issues, tracked separately; they are not what
    # this group guards against, so they are recorded as broken.
    known_broken = (
        BunchKaufmanFactorization(),
        LDLtFactorization(),
    )
    @testset "$(nameof(typeof(alg))) (broken)" for alg in known_broken
        # Inference stays outside `@test_broken`: an exception here must
        # surface as an error, not be counted as an expected breakage.
        rt = infer_solve_rt(alg)
        @test_broken isconcretetype(rt)
    end
end
Loading