diff --git a/lib/NonlinearSolveBase/Project.toml b/lib/NonlinearSolveBase/Project.toml index a8a2169c1b..61e56aa65b 100644 --- a/lib/NonlinearSolveBase/Project.toml +++ b/lib/NonlinearSolveBase/Project.toml @@ -1,6 +1,6 @@ name = "NonlinearSolveBase" uuid = "be0214bd-f91f-a760-ac4e-3421ce2b2da0" -version = "2.26.0" +version = "2.26.1" authors = ["Avik Pal and contributors"] [deps] @@ -110,12 +110,15 @@ julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BandedMatrices = "aae01518-5342-5314-be14-df237901396f" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "BandedMatrices", "ExplicitImports", "ForwardDiff", "InteractiveUtils", "LinearAlgebra", "SparseArrays", "Test"] +test = ["Aqua", "BandedMatrices", "ChainRulesCore", "Enzyme", "ExplicitImports", "ForwardDiff", "InteractiveUtils", "LinearAlgebra", "SciMLStructures", "SparseArrays", "Test"] diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl index 54c46eebcb..4b82f92593 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl @@ -13,7 +13,16 @@ import SciMLStructures # - Another SciMLStructure # - A broadcastable array # In all cases, accumulation goes through the SciMLStructures interface. -function _accum_tangent!(dval, darg) +# +# `diff_tunables` mirrors the sensealg field of the same name and means +# "differentiate only the Tunable portion." When `true` (the default and +# the value carried by `SteadyStateAdjoint`/`Quadrature`/`Gauss` adjoints +# unless the user opted out) only the Tunable slice of a structured +# `darg` is accumulated. When `false`, `SciMLSensitivity.steadystatebackpass` +# returns a structured cotangent whose gradient contribution may live in +# non-Tunable fields such as `caches` (e.g. SCC sub-problem buffers feeding +# `explicitfuns!`), so those fields are walked in as well. +function _accum_tangent!(dval, darg; diff_tunables::Bool = true) if SciMLStructures.isscimlstructure(dval) && !(dval isa AbstractArray) if SciMLStructures.isscimlstructure(darg) shadow_tunables, _, _ = SciMLStructures.canonicalize( @@ -24,6 +33,13 @@ function _accum_tangent!(dval, darg) ) shadow_tunables .+= darg_tunables SciMLStructures.replace!(SciMLStructures.Tunable(), dval, shadow_tunables) + if !diff_tunables + for field in fieldnames(typeof(darg)) + field === :tunable && continue + hasfield(typeof(dval), field) || continue + _accum_nested!(getfield(dval, field), getfield(darg, field)) + end + end elseif darg isa AbstractVector shadow_tunables, _, _ = SciMLStructures.canonicalize( SciMLStructures.Tunable(), dval, @@ -117,6 +133,28 @@ function Enzyme.EnzymeRules.reverse( ) where {RT <: Enzyme.Annotation} dres, clos = tape dargs = clos(dres) + # Mirror the `diff_tunables` choice the inner adjoint will make. When the + # user passes a concrete sensealg, honor its `diff_tunables` field. When + # the outer sensealg is `nothing` (default), `_concrete_solve_adjoint` + # delegates to `automatic_sensealg_choice`, which picks + # `diff_tunables = Val(false)` whenever `prob.p` is a SciMLStructure with + # a non-empty `caches` field (e.g. an MTKParameters tied to an + # SCCNonlinearProblem's `explicitfuns!` coupling). Reproducing that + # predicate here lets the accumulator walk every non-Tunable field of a + # structured `darg` so the meaningful cotangent isn't dropped. + diff_tunables = let s = sensealg.val, pv = p.val + if s isa SciMLBase.AbstractSensitivityAlgorithm && + hasproperty(s, :diff_tunables) + !(getproperty(s, :diff_tunables) isa Val{false}) + else + !( + SciMLStructures.isscimlstructure(pv) && + !(pv isa AbstractArray) && + hasfield(typeof(pv), :caches) && + !isempty(pv.caches) + ) + end + end for (darg, ptr) in zip(dargs, (func, prob, sensealg, u0, p, args...)) if ptr isa Enzyme.Const continue @@ -125,9 +163,9 @@ function Enzyme.EnzymeRules.reverse( continue end if ptr isa MixedDuplicated - _accum_tangent!(ptr.dval[], darg) + _accum_tangent!(ptr.dval[], darg; diff_tunables) else - _accum_tangent!(ptr.dval, darg) + _accum_tangent!(ptr.dval, darg; diff_tunables) end end Enzyme.make_zero!(dres.u) diff --git a/lib/NonlinearSolveBase/test/enzyme_accum_tangent.jl b/lib/NonlinearSolveBase/test/enzyme_accum_tangent.jl new file mode 100644 index 0000000000..8675ab220c --- /dev/null +++ b/lib/NonlinearSolveBase/test/enzyme_accum_tangent.jl @@ -0,0 +1,59 @@ +module EnzymeAccumTangentTests + +using Test +using NonlinearSolveBase +import ChainRulesCore, Enzyme # triggers NonlinearSolveBaseEnzymeExt +import SciMLStructures +import SciMLStructures: Tunable + +mutable struct MockMTKParams + tunable::Vector{Float64} + caches::Tuple{Vector{Float64}} +end + +SciMLStructures.isscimlstructure(::MockMTKParams) = true +SciMLStructures.ismutablescimlstructure(::MockMTKParams) = true +function SciMLStructures.canonicalize(::Tunable, p::MockMTKParams) + return p.tunable, (val) -> MockMTKParams(collect(val), p.caches), true +end +function SciMLStructures.replace!(::Tunable, p::MockMTKParams, val) + p.tunable .= val + return nothing +end + +const EXT = Base.get_extension(NonlinearSolveBase, :NonlinearSolveBaseEnzymeExt) + +@testset "EnzymeExt._accum_tangent! gates non-Tunable walk on diff_tunables (#935)" begin + # Regression for SciML/NonlinearSolve.jl#935. The reverse rule + # mirrors `sensealg.diff_tunables` into this kwarg: + # + # * `diff_tunables = true` (default) — backpass returned a + # Tunable-only cotangent; leave non-Tunable fields of any + # structured `darg` alone. + # * `diff_tunables = false` — backpass returned a *structured* + # cotangent (e.g. caches from SCC `explicitfuns!` coupling). + # Walk every non-Tunable field of `darg` into `dval` so the + # meaningful contribution lands. + + # diff_tunables = false: caches must accumulate. + dval = MockMTKParams([0.0, 0.0], (zeros(3),)) + darg = MockMTKParams([1.0, 2.0], ([10.0, 20.0, 30.0],)) + EXT._accum_tangent!(dval, darg; diff_tunables = false) + @test dval.tunable == [1.0, 2.0] + @test dval.caches[1] == [10.0, 20.0, 30.0] + + # And `+=` semantics on repeated calls. + darg2 = MockMTKParams([0.5, 0.5], ([1.0, 2.0, 3.0],)) + EXT._accum_tangent!(dval, darg2; diff_tunables = false) + @test dval.tunable == [1.5, 2.5] + @test dval.caches[1] == [11.0, 22.0, 33.0] + + # diff_tunables = true (default): only Tunable touched. + dval2 = MockMTKParams([0.0, 0.0], (zeros(3),)) + darg3 = MockMTKParams([1.0, 2.0], ([10.0, 20.0, 30.0],)) + EXT._accum_tangent!(dval2, darg3) + @test dval2.tunable == [1.0, 2.0] + @test dval2.caches[1] == zeros(3) +end + +end diff --git a/lib/NonlinearSolveBase/test/runtests.jl b/lib/NonlinearSolveBase/test/runtests.jl index 3c60fed32f..df91f4d6a1 100644 --- a/lib/NonlinearSolveBase/test/runtests.jl +++ b/lib/NonlinearSolveBase/test/runtests.jl @@ -130,4 +130,8 @@ using InteractiveUtils, Test NonlinearSolveBase.maybe_wrap_nonlinear_f(prob_3d) ) end + + @testset "EnzymeExt _accum_tangent! caches accumulation (#935)" begin + include("enzyme_accum_tangent.jl") + end end