diff --git a/src/default.jl b/src/default.jl index fd18c6ef9..4b2f323c4 100644 --- a/src/default.jl +++ b/src/default.jl @@ -603,22 +603,31 @@ function _is_gpu_sparse(A) end """ - _do_qr_fallback(cache::LinearCache, alg, sol, reason::Symbol, args...; kwargs...) + _do_qr_fallback(cache::LinearCache, alg, sol, reason::Symbol) Perform QR fallback after LU failure or residual check failure. Restores `cache.A` from `A_backup` (since LU may have modified it in-place) and solves with column-pivoted QR (or NoPivot for GPU arrays which don't support scalar indexing). `reason` is `:lu_failure` or `:residual_check` for appropriate log messages. """ -function _do_qr_fallback(cache::LinearCache, alg, sol, reason::Symbol, args...; kwargs...) +function _do_qr_fallback(cache::LinearCache, alg, sol, reason::Symbol) + # Always extract solution data from `cache` rather than `sol`. The QR + # fallback path calls `solve!(cache, QRFactorization(...))` recursively; + # during precompile inference, that inner call's return type gets capped + # to a non-concrete UnionAll (Julia's inference complexity limit). Reading + # `cache.u` (statically typed) and using `cache` for the solution cache + # field keeps the return type of this helper concrete, which propagates + # up through `_default_lu_solve_with_fallback` and the @generated + # `solve!(cache, ::DefaultLinearSolver)` body. + rc = sol.retcode + iters = sol.iters if is_cusparse(cache.A) @SciMLMessage( "LU factorization failed for GPU sparse matrix but QR fallback is not supported for CuSparse. Returning LU failure.", cache.verbose, :default_lu_fallback ) return SciMLBase.build_linear_solution( - alg, sol.u, sol.resid, sol.cache; - retcode = sol.retcode, iters = sol.iters, stats = sol.stats + alg, cache.u, nothing, cache; retcode = rc, iters = iters, stats = nothing ) end if cache.A === cache.cacheval.A_backup @@ -627,8 +636,7 @@ function _do_qr_fallback(cache::LinearCache, alg, sol, reason::Symbol, args...; cache.verbose, :default_lu_fallback ) return SciMLBase.build_linear_solution( - alg, sol.u, sol.resid, sol.cache; - retcode = sol.retcode, iters = sol.iters, stats = sol.stats + alg, cache.u, nothing, cache; retcode = rc, iters = iters, stats = nothing ) end if reason === :residual_check @@ -645,33 +653,34 @@ function _do_qr_fallback(cache::LinearCache, alg, sol, reason::Symbol, args...; copyto!(cache.A, cache.cacheval.A_backup) cache.isfresh = true pivot = _qr_fallback_pivot(cache.A) - sol = SciMLBase.solve!(cache, QRFactorization(pivot), args...; kwargs...) + qr_sol = SciMLBase.solve!(cache, QRFactorization(pivot)) cache.cacheval.fell_back_to_qr = true return SciMLBase.build_linear_solution( - alg, sol.u, sol.resid, sol.cache; - retcode = sol.retcode, iters = sol.iters, stats = sol.stats + alg, cache.u, nothing, cache; + retcode = qr_sol.retcode, iters = qr_sol.iters, stats = nothing ) end """ - _reuse_qr_fallback(cache::LinearCache, alg, args...; kwargs...) + _reuse_qr_fallback(cache::LinearCache, alg) Reuse the cached QR factorization from a previous QR fallback. Called when `fell_back_to_qr` is `true` and `isfresh` is `false`, meaning the matrix hasn't changed since the QR fallback and we should keep using QR instead of the (potentially corrupted) LU factorization. """ -function _reuse_qr_fallback(cache::LinearCache, alg, args...; kwargs...) +function _reuse_qr_fallback(cache::LinearCache, alg) pivot = _qr_fallback_pivot(cache.A) - sol = SciMLBase.solve!(cache, QRFactorization(pivot), args...; kwargs...) + qr_sol = SciMLBase.solve!(cache, QRFactorization(pivot)) + # Use cache directly for type-stable inference (see _do_qr_fallback). return SciMLBase.build_linear_solution( - alg, sol.u, sol.resid, sol.cache; - retcode = sol.retcode, iters = sol.iters, stats = sol.stats + alg, cache.u, nothing, cache; + retcode = qr_sol.retcode, iters = qr_sol.iters, stats = nothing ) end """ - _default_lu_solve_with_fallback(cache::LinearCache, alg::DefaultLinearSolver, sol, args...; kwargs...) + _default_lu_solve_with_fallback(cache::LinearCache, alg::DefaultLinearSolver, sol) Post-process an LU solve result: if LU explicitly failed, the solution contains NaN/Inf, or the residual check returned `APosterioriSafetyFailure`, fall back to column-pivoted QR. @@ -682,26 +691,27 @@ The NaN/Inf check catches floating-point-near-singular matrices where LU "succee near-zero pivots. This is O(n) and has zero false positives. """ function _default_lu_solve_with_fallback( - cache::LinearCache, alg::DefaultLinearSolver, sol, args...; kwargs... + cache::LinearCache, alg::DefaultLinearSolver, sol ) if alg.safetyfallback if sol.retcode === ReturnCode.Failure - return _do_qr_fallback(cache, alg, sol, :lu_failure, args...; kwargs...) + return _do_qr_fallback(cache, alg, sol, :lu_failure) end if sol.retcode === ReturnCode.Success && any(!isfinite, sol.u) @SciMLMessage( "LU solve produced non-finite values (NaN/Inf), falling back to QR. Matrix is likely near-singular.", cache.verbose, :default_lu_fallback ) - return _do_qr_fallback(cache, alg, sol, :lu_failure, args...; kwargs...) + return _do_qr_fallback(cache, alg, sol, :lu_failure) end if sol.retcode === ReturnCode.APosterioriSafetyFailure - return _do_qr_fallback(cache, alg, sol, :residual_check, args...; kwargs...) + return _do_qr_fallback(cache, alg, sol, :residual_check) end end + # Use cache directly for type-stable inference (see _do_qr_fallback). return SciMLBase.build_linear_solution( - alg, sol.u, sol.resid, sol.cache; - retcode = sol.retcode, iters = sol.iters, stats = sol.stats + alg, cache.u, nothing, cache; + retcode = sol.retcode, iters = sol.iters, stats = nothing ) end @@ -762,8 +772,8 @@ end # its own residual check and returns APosterioriSafetyFailure if needed. inner_alg_expr = _algchoice_to_alg_with_safety(alg) newex = quote - sol = SciMLBase.solve!(cache, $inner_alg_expr, args...; kwargs...) - _default_lu_solve_with_fallback(cache, alg, sol, args...; kwargs...) + sol = SciMLBase.solve!(cache, $inner_alg_expr) + _default_lu_solve_with_fallback(cache, alg, sol) end elseif alg == Symbol(DefaultAlgorithmChoice.RFLUFactorization) inner_alg_expr = _algchoice_to_alg_with_safety(alg) @@ -771,8 +781,8 @@ end if !userecursivefactorization(nothing) error("Default algorithm calling solve on RecursiveFactorization without the package being loaded. This shouldn't happen.") end - sol = SciMLBase.solve!(cache, $inner_alg_expr, args...; kwargs...) - _default_lu_solve_with_fallback(cache, alg, sol, args...; kwargs...) + sol = SciMLBase.solve!(cache, $inner_alg_expr) + _default_lu_solve_with_fallback(cache, alg, sol) end elseif alg == Symbol(DefaultAlgorithmChoice.BLISLUFactorization) inner_alg_expr = _algchoice_to_alg_with_safety(alg) @@ -780,8 +790,8 @@ end if !useblis(nothing) error("Default algorithm calling solve on BLISLUFactorization without the extension being loaded. This shouldn't happen.") end - sol = SciMLBase.solve!(cache, $inner_alg_expr, args...; kwargs...) - _default_lu_solve_with_fallback(cache, alg, sol, args...; kwargs...) + sol = SciMLBase.solve!(cache, $inner_alg_expr) + _default_lu_solve_with_fallback(cache, alg, sol) end elseif alg == Symbol(DefaultAlgorithmChoice.CudaOffloadLUFactorization) inner_alg_expr = _algchoice_to_alg_with_safety(alg) @@ -789,8 +799,8 @@ end if !usecuda(nothing) error("Default algorithm calling solve on CudaOffloadLUFactorization without CUDA.jl being loaded. This shouldn't happen.") end - sol = SciMLBase.solve!(cache, $inner_alg_expr, args...; kwargs...) - _default_lu_solve_with_fallback(cache, alg, sol, args...; kwargs...) + sol = SciMLBase.solve!(cache, $inner_alg_expr) + _default_lu_solve_with_fallback(cache, alg, sol) end elseif alg == Symbol(DefaultAlgorithmChoice.MetalLUFactorization) inner_alg_expr = _algchoice_to_alg_with_safety(alg) @@ -798,18 +808,18 @@ end if !usemetal(nothing) error("Default algorithm calling solve on MetalLUFactorization without Metal.jl being loaded. This shouldn't happen.") end - sol = SciMLBase.solve!(cache, $inner_alg_expr, args...; kwargs...) - _default_lu_solve_with_fallback(cache, alg, sol, args...; kwargs...) + sol = SciMLBase.solve!(cache, $inner_alg_expr) + _default_lu_solve_with_fallback(cache, alg, sol) end else if alg in LinearSolve._SPARSE_ONLY_ALGORITHMS newex = quote if !(cache.A isa Array) - sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...) + sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg))) SciMLBase.build_linear_solution( - alg, sol.u, sol.resid, sol.cache; + alg, cache.u, nothing, cache; retcode = sol.retcode, - iters = sol.iters, stats = sol.stats + iters = sol.iters, stats = nothing ) else error( @@ -820,11 +830,11 @@ end end else newex = quote - sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...) + sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg))) SciMLBase.build_linear_solution( - alg, sol.u, sol.resid, sol.cache; + alg, cache.u, nothing, cache; retcode = sol.retcode, - iters = sol.iters, stats = sol.stats + iters = sol.iters, stats = nothing ) end end @@ -843,7 +853,7 @@ end return quote if cache.cacheval isa DefaultLinearSolverInit && cache.cacheval.fell_back_to_qr && !cache.isfresh - _reuse_qr_fallback(cache, alg, args...; kwargs...) + _reuse_qr_fallback(cache, alg) else $alg_dispatch end diff --git a/src/factorization.jl b/src/factorization.jl index f345a978d..0b68ef1e5 100644 --- a/src/factorization.jl +++ b/src/factorization.jl @@ -456,15 +456,22 @@ end function do_factorization(alg::QRFactorization, A, b, u) A = convert(AbstractMatrix, A) if ArrayInterface.can_setindex(typeof(A)) - if alg.inplace && !issparsematrixcsc(A) && !(A isa GPUArraysCore.AnyGPUArray) && - !is_cusparse(A) + # Sparse CSC (SPQR) does not accept a pivoting strategy, and CUDA's + # `qr` does not accept extra args either. Use the no-arg `qr(A)` + # form in those cases. For other CPU matrices, always pass + # `alg.pivot` so the return type is determined by the static + # `QRFactorization{P}` parameter (otherwise this branch returns + # `Union{QRCompactWY, QRPivoted}` depending on `alg.inplace`). + if A isa GPUArraysCore.AnyGPUArray || is_cusparse(A) || issparsematrixcsc(A) + fact = qr(A) + elseif alg.inplace if A isa Symmetric fact = qr(A, alg.pivot) else fact = qr!(A, alg.pivot) end else - fact = qr(A) # CUDA.jl does not allow other args! + fact = qr(A, alg.pivot) end else fact = qr(A, alg.pivot) diff --git a/test/nopre/jet.jl b/test/nopre/jet.jl index b890d3e63..660c908d4 100644 --- a/test/nopre/jet.jl +++ b/test/nopre/jet.jl @@ -197,3 +197,61 @@ end JET.@test_opt init(dual_prob) end end + +# Concrete-return-type QA for `solve!(cache)`. Guards against the regression +# where `solve!(cache)` through `DefaultLinearSolver` returned +# `LinearSolution{_A, _B, _C, _D, DefaultLinearSolver, _E, _F} where {...}` +# (a UnionAll over 6 free type parameters) instead of a concrete LinearSolution. +_solve_alg(A, b, alg) = solve!(init(LinearProblem(A, b), alg)) +_solve_default(A, b) = solve!(init(LinearProblem(A, b))) + +@testset "solve!(cache) returns concrete LinearSolution — default solver" begin + # Headline case: `solve!(cache)` after `init(LinearProblem(A, b))` must not + # return a UnionAll-typed LinearSolution. Was broken by the + # `_default_lu_solve_with_fallback`/`_do_qr_fallback` helpers reading + # `sol.u`/`sol.resid`/`sol.cache`/`sol.stats` from an inner `sol` whose + # rettype got capped to `Any` during precompile. + rt = Core.Compiler.return_type( + _solve_default, Tuple{Matrix{Float64}, Vector{Float64}} + ) + @test isconcretetype(rt) + @test rt <: LinearSolve.SciMLBase.LinearSolution{Float64, 1, Vector{Float64}} +end + +@testset "solve!(cache) is concrete for each algorithm" begin + algs_concrete = ( + LUFactorization(), + GenericLUFactorization(), + QRFactorization(LinearAlgebra.ColumnNorm()), + QRFactorization(LinearAlgebra.NoPivot()), + DiagonalFactorization(), + SVDFactorization(), + CholeskyFactorization(), + NormalCholeskyFactorization(), + ) + for alg in algs_concrete + @testset "$(nameof(typeof(alg)))" begin + rt = Core.Compiler.return_type( + _solve_alg, + Tuple{Matrix{Float64}, Vector{Float64}, typeof(alg)} + ) + @test isconcretetype(rt) + end + end + + # Known unrelated inference issues — tracked separately, not what this + # group is guarding against. + algs_broken = ( + BunchKaufmanFactorization(), + LDLtFactorization(), + ) + for alg in algs_broken + @testset "$(nameof(typeof(alg))) (broken)" begin + rt = Core.Compiler.return_type( + _solve_alg, + Tuple{Matrix{Float64}, Vector{Float64}, typeof(alg)} + ) + @test_broken isconcretetype(rt) + end + end +end