From dc48690b898f2f38da047760236c4c2059ce6478 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas-Claude Date: Sun, 24 May 2026 04:35:50 -0400 Subject: [PATCH 1/3] NonlinearSolveBaseEnzymeExt: accumulate non-Tunable cotangent fields MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When SciMLSensitivity's `steadystatebackpass` runs under `diff_tunables = Val(false)` — which is the deliberate choice when `MTKParameters` carries SCC `caches` so those buffers participate in the adjoint — the cotangent it returns is a structured `MTKParameters` with non-zero contributions in `caches` (and potentially other non-Tunable portions), not just `tunable`. The previous `_accum_tangent!` SciMLStructure→SciMLStructure branch only round-tripped the Tunable slice via `SciMLStructures.canonicalize(Tunable(), ...)` / `replace!`, silently dropping every non-Tunable field of `darg`. That manifested downstream (see SciML/NonlinearSolve.jl#935) as Enzyme gradients of `solve` on `SCCNonlinearProblem`-backed initialization producing zero when the meaningful gradient flowed through `caches`-mediated coupling. Walk the remaining fields of `darg` via the existing `_accum_nested!` helper so caches/initials/discrete/constant/nonnumeric contributions are accumulated as well. Tests added with a minimal `MockMTKParams` SciMLStructure-compliant type that exercises caches accumulation directly through the extension's internal helper. 
Co-Authored-By: Chris Rackauckas --- lib/NonlinearSolveBase/Project.toml | 7 ++- .../ext/NonlinearSolveBaseEnzymeExt.jl | 12 +++++ .../test/enzyme_accum_tangent.jl | 49 +++++++++++++++++++ lib/NonlinearSolveBase/test/runtests.jl | 4 ++ 4 files changed, 70 insertions(+), 2 deletions(-) create mode 100644 lib/NonlinearSolveBase/test/enzyme_accum_tangent.jl diff --git a/lib/NonlinearSolveBase/Project.toml b/lib/NonlinearSolveBase/Project.toml index a8a2169c1b..61e56aa65b 100644 --- a/lib/NonlinearSolveBase/Project.toml +++ b/lib/NonlinearSolveBase/Project.toml @@ -1,6 +1,6 @@ name = "NonlinearSolveBase" uuid = "be0214bd-f91f-a760-ac4e-3421ce2b2da0" -version = "2.26.0" +version = "2.26.1" authors = ["Avik Pal and contributors"] [deps] @@ -110,12 +110,15 @@ julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BandedMatrices = "aae01518-5342-5314-be14-df237901396f" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "BandedMatrices", "ExplicitImports", "ForwardDiff", "InteractiveUtils", "LinearAlgebra", "SparseArrays", "Test"] +test = ["Aqua", "BandedMatrices", "ChainRulesCore", "Enzyme", "ExplicitImports", "ForwardDiff", "InteractiveUtils", "LinearAlgebra", "SciMLStructures", "SparseArrays", "Test"] diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl index 54c46eebcb..6266d9bba6 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl @@ -24,6 
+24,18 @@ function _accum_tangent!(dval, darg) ) shadow_tunables .+= darg_tunables SciMLStructures.replace!(SciMLStructures.Tunable(), dval, shadow_tunables) + # When the upstream rule returns a structured cotangent + # (e.g. SciMLSensitivity's steadystatebackpass under + # diff_tunables = Val(false)), the gradient contribution may + # live in non-Tunable fields such as `caches` (SCC sub-problem + # buffers feeding `explicitfuns!`). Accumulate every non-Tunable + # field that both sides expose so those contributions are not + # silently dropped. + for field in fieldnames(typeof(darg)) + field === :tunable && continue + hasfield(typeof(dval), field) || continue + _accum_nested!(getfield(dval, field), getfield(darg, field)) + end elseif darg isa AbstractVector shadow_tunables, _, _ = SciMLStructures.canonicalize( SciMLStructures.Tunable(), dval, diff --git a/lib/NonlinearSolveBase/test/enzyme_accum_tangent.jl b/lib/NonlinearSolveBase/test/enzyme_accum_tangent.jl new file mode 100644 index 0000000000..d3437e1a2e --- /dev/null +++ b/lib/NonlinearSolveBase/test/enzyme_accum_tangent.jl @@ -0,0 +1,49 @@ +module EnzymeAccumTangentTests + +using Test +using NonlinearSolveBase +import ChainRulesCore, Enzyme # triggers NonlinearSolveBaseEnzymeExt +import SciMLStructures +import SciMLStructures: Tunable + +mutable struct MockMTKParams + tunable::Vector{Float64} + caches::Tuple{Vector{Float64}} +end + +SciMLStructures.isscimlstructure(::MockMTKParams) = true +SciMLStructures.ismutablescimlstructure(::MockMTKParams) = true +function SciMLStructures.canonicalize(::Tunable, p::MockMTKParams) + return p.tunable, (val) -> MockMTKParams(collect(val), p.caches), true +end +function SciMLStructures.replace!(::Tunable, p::MockMTKParams, val) + p.tunable .= val + return nothing +end + +const EXT = Base.get_extension(NonlinearSolveBase, :NonlinearSolveBaseEnzymeExt) + +@testset "EnzymeExt._accum_tangent! 
walks non-Tunable fields (caches)" begin + # Regression for SciML/NonlinearSolve.jl#935: when SciMLSensitivity's + # `steadystatebackpass` returns a structured cotangent under + # `diff_tunables = Val(false)` (e.g. SCC explicitfuns! coupling), + # the gradient contribution lives in `caches`, not only `tunable`. + # The accumulator must walk those non-Tunable fields too — otherwise + # the meaningful cotangent is silently dropped and the user observes + # a zero gradient. + dval = MockMTKParams([0.0, 0.0], (zeros(3),)) + darg = MockMTKParams([1.0, 2.0], ([10.0, 20.0, 30.0],)) + + EXT._accum_tangent!(dval, darg) + + @test dval.tunable == [1.0, 2.0] + @test dval.caches[1] == [10.0, 20.0, 30.0] + + # Accumulate again — verify it adds, doesn't overwrite. + darg2 = MockMTKParams([0.5, 0.5], ([1.0, 2.0, 3.0],)) + EXT._accum_tangent!(dval, darg2) + @test dval.tunable == [1.5, 2.5] + @test dval.caches[1] == [11.0, 22.0, 33.0] +end + +end diff --git a/lib/NonlinearSolveBase/test/runtests.jl b/lib/NonlinearSolveBase/test/runtests.jl index 3c60fed32f..df91f4d6a1 100644 --- a/lib/NonlinearSolveBase/test/runtests.jl +++ b/lib/NonlinearSolveBase/test/runtests.jl @@ -130,4 +130,8 @@ using InteractiveUtils, Test NonlinearSolveBase.maybe_wrap_nonlinear_f(prob_3d) ) end + + @testset "EnzymeExt _accum_tangent! caches accumulation (#935)" begin + include("enzyme_accum_tangent.jl") + end end From dbefada45c6d5591080661abd73f06abcb9eea17 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas-Claude Date: Sun, 24 May 2026 04:43:03 -0400 Subject: [PATCH 2/3] Gate non-Tunable accumulation on diff_tunables (review feedback) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the unconditional "walk every non-Tunable field whenever both sides are SciMLStructures" with an explicit `diff_tunables::Bool` kwarg on `_accum_tangent!` that mirrors the sensealg field of the same name. 
`diff_tunables = true` (default) — backpass returned a Tunable-only cotangent; leave non-Tunable fields of any structured `darg` alone. `diff_tunables = false` — backpass returned a structured cotangent whose meaningful contribution may live in `caches`/`initials`/etc.; walk every non-Tunable field in. The `reverse` rule derives this from `sensealg.val.diff_tunables`. Test extended to exercise both directions. Co-Authored-By: Chris Rackauckas --- .../ext/NonlinearSolveBaseEnzymeExt.jl | 43 +++++++++++++------ .../test/enzyme_accum_tangent.jl | 36 ++++++++++------ 2 files changed, 52 insertions(+), 27 deletions(-) diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl index 6266d9bba6..7442a40ac0 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl @@ -13,7 +13,16 @@ import SciMLStructures # - Another SciMLStructure # - A broadcastable array # In all cases, accumulation goes through the SciMLStructures interface. -function _accum_tangent!(dval, darg) +# +# `diff_tunables` mirrors the sensealg field of the same name and means +# "differentiate only the Tunable portion." When `true` (the default and +# the value carried by `SteadyStateAdjoint`/`Quadrature`/`Gauss` adjoints +# unless the user opted out) only the Tunable slice of a structured +# `darg` is accumulated. When `false`, `SciMLSensitivity.steadystatebackpass` +# returns a structured cotangent whose gradient contribution may live in +# non-Tunable fields such as `caches` (e.g. SCC sub-problem buffers feeding +# `explicitfuns!`), so those fields are walked in as well. 
+function _accum_tangent!(dval, darg; diff_tunables::Bool = true) if SciMLStructures.isscimlstructure(dval) && !(dval isa AbstractArray) if SciMLStructures.isscimlstructure(darg) shadow_tunables, _, _ = SciMLStructures.canonicalize( @@ -24,17 +33,12 @@ function _accum_tangent!(dval, darg) ) shadow_tunables .+= darg_tunables SciMLStructures.replace!(SciMLStructures.Tunable(), dval, shadow_tunables) - # When the upstream rule returns a structured cotangent - # (e.g. SciMLSensitivity's steadystatebackpass under - # diff_tunables = Val(false)), the gradient contribution may - # live in non-Tunable fields such as `caches` (SCC sub-problem - # buffers feeding `explicitfuns!`). Accumulate every non-Tunable - # field that both sides expose so those contributions are not - # silently dropped. - for field in fieldnames(typeof(darg)) - field === :tunable && continue - hasfield(typeof(dval), field) || continue - _accum_nested!(getfield(dval, field), getfield(darg, field)) + if !diff_tunables + for field in fieldnames(typeof(darg)) + field === :tunable && continue + hasfield(typeof(dval), field) || continue + _accum_nested!(getfield(dval, field), getfield(darg, field)) + end end elseif darg isa AbstractVector shadow_tunables, _, _ = SciMLStructures.canonicalize( @@ -129,6 +133,17 @@ function Enzyme.EnzymeRules.reverse( ) where {RT <: Enzyme.Annotation} dres, clos = tape dargs = clos(dres) + # Mirror the sensealg's `diff_tunables` (default `true`). When `false`, + # `SciMLSensitivity.steadystatebackpass` returns a structured cotangent + # whose meaningful contribution may live in non-Tunable fields such as + # `caches`, so the accumulator walks those too. 
+ diff_tunables = let s = sensealg.val + !( + s isa SciMLBase.AbstractSensitivityAlgorithm && + hasproperty(s, :diff_tunables) && + getproperty(s, :diff_tunables) isa Val{false} + ) + end for (darg, ptr) in zip(dargs, (func, prob, sensealg, u0, p, args...)) if ptr isa Enzyme.Const continue @@ -137,9 +152,9 @@ function Enzyme.EnzymeRules.reverse( continue end if ptr isa MixedDuplicated - _accum_tangent!(ptr.dval[], darg) + _accum_tangent!(ptr.dval[], darg; diff_tunables) else - _accum_tangent!(ptr.dval, darg) + _accum_tangent!(ptr.dval, darg; diff_tunables) end end Enzyme.make_zero!(dres.u) diff --git a/lib/NonlinearSolveBase/test/enzyme_accum_tangent.jl b/lib/NonlinearSolveBase/test/enzyme_accum_tangent.jl index d3437e1a2e..8675ab220c 100644 --- a/lib/NonlinearSolveBase/test/enzyme_accum_tangent.jl +++ b/lib/NonlinearSolveBase/test/enzyme_accum_tangent.jl @@ -23,27 +23,37 @@ end const EXT = Base.get_extension(NonlinearSolveBase, :NonlinearSolveBaseEnzymeExt) -@testset "EnzymeExt._accum_tangent! walks non-Tunable fields (caches)" begin - # Regression for SciML/NonlinearSolve.jl#935: when SciMLSensitivity's - # `steadystatebackpass` returns a structured cotangent under - # `diff_tunables = Val(false)` (e.g. SCC explicitfuns! coupling), - # the gradient contribution lives in `caches`, not only `tunable`. - # The accumulator must walk those non-Tunable fields too — otherwise - # the meaningful cotangent is silently dropped and the user observes - # a zero gradient. +@testset "EnzymeExt._accum_tangent! gates non-Tunable walk on diff_tunables (#935)" begin + # Regression for SciML/NonlinearSolve.jl#935. The reverse rule + # mirrors `sensealg.diff_tunables` into this kwarg: + # + # * `diff_tunables = true` (default) — backpass returned a + # Tunable-only cotangent; leave non-Tunable fields of any + # structured `darg` alone. + # * `diff_tunables = false` — backpass returned a *structured* + # cotangent (e.g. caches from SCC `explicitfuns!` coupling). 
+ # Walk every non-Tunable field of `darg` into `dval` so the + # meaningful contribution lands. + + # diff_tunables = false: caches must accumulate. dval = MockMTKParams([0.0, 0.0], (zeros(3),)) darg = MockMTKParams([1.0, 2.0], ([10.0, 20.0, 30.0],)) - - EXT._accum_tangent!(dval, darg) - + EXT._accum_tangent!(dval, darg; diff_tunables = false) @test dval.tunable == [1.0, 2.0] @test dval.caches[1] == [10.0, 20.0, 30.0] - # Accumulate again — verify it adds, doesn't overwrite. + # And `+=` semantics on repeated calls. darg2 = MockMTKParams([0.5, 0.5], ([1.0, 2.0, 3.0],)) - EXT._accum_tangent!(dval, darg2) + EXT._accum_tangent!(dval, darg2; diff_tunables = false) @test dval.tunable == [1.5, 2.5] @test dval.caches[1] == [11.0, 22.0, 33.0] + + # diff_tunables = true (default): only Tunable touched. + dval2 = MockMTKParams([0.0, 0.0], (zeros(3),)) + darg3 = MockMTKParams([1.0, 2.0], ([10.0, 20.0, 30.0],)) + EXT._accum_tangent!(dval2, darg3) + @test dval2.tunable == [1.0, 2.0] + @test dval2.caches[1] == zeros(3) end end From f78b70e80726af1aaf150899ecf707db53bff53c Mon Sep 17 00:00:00 2001 From: ChrisRackauckas-Claude Date: Sun, 24 May 2026 06:07:21 -0400 Subject: [PATCH 3/3] Reverse rule: derive diff_tunables from prob.p when sensealg is default MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The Enzyme `reverse` rule's `sensealg` arg is whatever the user passed to `solve` — typically `Nothing`. The *inner* sensealg that ends up producing the cotangent is resolved later inside `_concrete_solve_adjoint`'s call to `automatic_sensealg_choice`, which picks `SteadyStateAdjoint(diff_tunables = Val(false))` whenever `prob.p` is a SciMLStructure with non-empty `caches` (e.g. an MTKParameters tied to an SCCNonlinearProblem's `explicitfuns!` coupling). 
The previous wiring only consulted the outer sensealg and therefore defaulted to `diff_tunables = true` in the SCC case, so the non-Tunable walk in `_accum_tangent!` was suppressed exactly when it was supposed to fire. Reproduce the same predicate locally (`isscimlstructure(p) && !(p isa AbstractArray) && hasfield(typeof(p), :caches) && !isempty(p.caches)`) when the outer sensealg lacks an explicit `diff_tunables` field, so the rule mirrors what the inner adjoint will actually do. Verified with a temporary `@info` probe that `_accum_tangent!` now receives `diff_tunables=false` and accumulates the structured cotangent through `darg.caches`. (The desauty `use_scc=true` end-to-end gradient remains zero, but for an unrelated upstream reason: the inner sensealg backpass returns `darg.tunable = 0`, which is independent of this code path.) Co-Authored-By: Chris Rackauckas --- .../ext/NonlinearSolveBaseEnzymeExt.jl | 31 +++++++++++++------ 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl index 7442a40ac0..4b82f92593 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl @@ -133,16 +133,27 @@ function Enzyme.EnzymeRules.reverse( ) where {RT <: Enzyme.Annotation} dres, clos = tape dargs = clos(dres) - # Mirror the sensealg's `diff_tunables` (default `true`). When `false`, - # `SciMLSensitivity.steadystatebackpass` returns a structured cotangent - # whose meaningful contribution may live in non-Tunable fields such as - # `caches`, so the accumulator walks those too. - diff_tunables = let s = sensealg.val - !( - s isa SciMLBase.AbstractSensitivityAlgorithm && - hasproperty(s, :diff_tunables) && - getproperty(s, :diff_tunables) isa Val{false} - ) + # Mirror the `diff_tunables` choice the inner adjoint will make. When the + # user passes a concrete sensealg, honor its `diff_tunables` field. 
When + # the outer sensealg is `nothing` (default), `_concrete_solve_adjoint` + # delegates to `automatic_sensealg_choice`, which picks + # `diff_tunables = Val(false)` whenever `prob.p` is a SciMLStructure with + # a non-empty `caches` field (e.g. an MTKParameters tied to an + # SCCNonlinearProblem's `explicitfuns!` coupling). Reproducing that + # predicate here lets the accumulator walk every non-Tunable field of a + # structured `darg` so the meaningful cotangent isn't dropped. + diff_tunables = let s = sensealg.val, pv = p.val + if s isa SciMLBase.AbstractSensitivityAlgorithm && + hasproperty(s, :diff_tunables) + !(getproperty(s, :diff_tunables) isa Val{false}) + else + !( + SciMLStructures.isscimlstructure(pv) && + !(pv isa AbstractArray) && + hasfield(typeof(pv), :caches) && + !isempty(pv.caches) + ) + end end for (darg, ptr) in zip(dargs, (func, prob, sensealg, u0, p, args...)) if ptr isa Enzyme.Const