From d813e910fefdb37e3c97132b4ac23dc799585052 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas-Claude Date: Sun, 24 May 2026 07:05:17 -0400 Subject: [PATCH] EnzymeExt: preserve prob.p/prob.u0 aliasing in return-value shadow MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `Enzyme.make_zero(sol)` in the augmented_primal return path recursively allocates fresh zero buffers for every mutable field of the `NonlinearSolution`, including `sol.prob.p` and `sol.prob.u0`. Those fields alias the outer caller's active `p` / `u0` shadows, so severing the aliasing means any cotangent a downstream consumer writes back into `sol.prob.p` (or `.u0`) lands in a dangling buffer instead of the buffer the outer Enzyme tape is tracking, silently dropping that contribution. Replace the call with `_make_solution_zero(sol)`, which pre-seeds the `make_zero` IdDict so `prob.p` and `prob.u0` map to themselves and the recursion short-circuits — the original buffers are reused verbatim while `sol.u` (the actual derivative-carrying field) still gets a fresh zero buffer. Guards `nothing` parameters and non-mutable values. Unit test asserts (a) naive `Enzyme.make_zero` breaks aliasing on a `NonlinearProblem`-backed solution, (b) `_make_solution_zero` preserves it (`===` and `objectid`), (c) `sol.u` is a fresh zero buffer, (d) `nothing` `p` doesn't crash the pre-seed helper. Independent of the `_accum_tangent!` caches-walk work in this branch. Does not by itself fix the unrelated polyalg `MixedDuplicated` MethodError, which has been traced to Enzyme's `create_activity_wrapper` emitting `MixedDuplicated(::T, ::T)` for `wrap_sol(::NonlinearSolution)` on the type-unstable generic dispatch path, an upstream Enzyme issue. 
Co-Authored-By: Chris Rackauckas --- lib/NonlinearSolveBase/Project.toml | 2 +- .../ext/NonlinearSolveBaseEnzymeExt.jl | 31 ++++++++++- .../test/enzyme_make_solution_zero.jl | 54 +++++++++++++++++++ lib/NonlinearSolveBase/test/runtests.jl | 4 ++ 4 files changed, 89 insertions(+), 2 deletions(-) create mode 100644 lib/NonlinearSolveBase/test/enzyme_make_solution_zero.jl diff --git a/lib/NonlinearSolveBase/Project.toml b/lib/NonlinearSolveBase/Project.toml index 61e56aa65..ce2dc5a83 100644 --- a/lib/NonlinearSolveBase/Project.toml +++ b/lib/NonlinearSolveBase/Project.toml @@ -1,6 +1,6 @@ name = "NonlinearSolveBase" uuid = "be0214bd-f91f-a760-ac4e-3421ce2b2da0" -version = "2.26.1" +version = "2.26.2" authors = ["Avik Pal and contributors"] [deps] diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl index 4b82f9259..f007a6d2c 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl @@ -115,7 +115,7 @@ function Enzyme.EnzymeRules.augmented_primal( kwargs... ) - dres = Enzyme.make_zero(res[1]) + dres = _make_solution_zero(res[1]) primal = EnzymeRules.needs_primal(config) ? res[1] : nothing shadow = EnzymeRules.needs_shadow(config) ? dres : nothing tup = (dres, res[2]) @@ -123,6 +123,35 @@ function Enzyme.EnzymeRules.augmented_primal( return RetType(primal, shadow, tup::Any) end +# Build the shadow `NonlinearSolution` for the augmented primal. A plain +# `Enzyme.make_zero(sol)` recursively allocates fresh zero buffers for every +# mutable field of `sol`, including `sol.prob.p` and `sol.prob.u0`, which are +# aliased to the active `p` / `u0` shadows the outer caller is differentiating +# against. Severing that aliasing means a cotangent written into the returned +# `sol.prob.p` field by a downstream consumer (e.g. 
anything reading the +# solution's parameters) goes into a dangling buffer instead of the buffer +# the outer Enzyme tape is tracking, silently dropping that gradient +# contribution. +# +# Pre-seed the `make_zero` seen-set so `prob.p` and `prob.u0` map to +# themselves: recursion into those fields short-circuits via `haskey(seen, …)`, +# preserving aliasing with the outer shadow. The actual derivative-carrying +# field (`sol.u`) still gets a fresh zero buffer. +@inline function _make_solution_zero(sol) + seen = IdDict() + _preseed_alias!(seen, sol.prob.p) + _preseed_alias!(seen, sol.prob.u0) + return Enzyme.make_zero(Core.Typeof(sol), seen, sol, Val(false)) +end + +@inline _preseed_alias!(::IdDict, ::Nothing) = nothing +@inline function _preseed_alias!(seen::IdDict, v) + if ismutable(v) + seen[v] = v + end + return nothing +end + function Enzyme.EnzymeRules.reverse( config::Enzyme.EnzymeRules.RevConfigWidth{1}, func::Const{typeof(NonlinearSolveBase.solve_up)}, ::Type{RT}, tape, prob, diff --git a/lib/NonlinearSolveBase/test/enzyme_make_solution_zero.jl b/lib/NonlinearSolveBase/test/enzyme_make_solution_zero.jl new file mode 100644 index 000000000..d92c97e74 --- /dev/null +++ b/lib/NonlinearSolveBase/test/enzyme_make_solution_zero.jl @@ -0,0 +1,54 @@ +module EnzymeMakeSolutionZeroTests + +using Test +using NonlinearSolveBase +import ChainRulesCore, Enzyme # triggers NonlinearSolveBaseEnzymeExt +using SciMLBase + +const EXT = Base.get_extension(NonlinearSolveBase, :NonlinearSolveBaseEnzymeExt) + +@testset "EnzymeExt._make_solution_zero preserves prob.p / prob.u0 aliasing" begin + # The `augmented_primal` rule builds the return-value shadow of `sol`. + # A plain `Enzyme.make_zero` recursively zeros every mutable field of + # `sol`, including `sol.prob.p` and `sol.prob.u0`, which the outer caller + # has already registered as active shadows for the `p` / `u0` arguments. 
+ # Severing that aliasing means any cotangent written into the returned + # `sol.prob.p` (or `.u0`) by a downstream consumer lands in a dangling + # buffer instead of the one the outer Enzyme tape is tracking, silently + # dropping that gradient contribution. + # + # `_make_solution_zero` pre-seeds the `make_zero` seen-set with identity + # entries for `prob.p` and `prob.u0` so the recursion short-circuits and + # the original buffers are reused verbatim in the shadow. + + f(u, p) = u .^ 2 .- p + u0 = [1.0, 1.0] + p = [2.0, 4.0] + prob = NonlinearProblem{false}(f, u0, p) + sol = SciMLBase.build_solution(prob, nothing, [1.5, 2.0], zeros(2)) + + # Naive `Enzyme.make_zero` allocates fresh buffers for prob.p / prob.u0. + dsol_naive = Enzyme.make_zero(sol) + @test objectid(dsol_naive.prob.p) != objectid(sol.prob.p) + @test objectid(dsol_naive.prob.u0) != objectid(sol.prob.u0) + + # The extension helper keeps them aliased to the primal. + dsol = EXT._make_solution_zero(sol) + @test objectid(dsol.prob.p) == objectid(sol.prob.p) + @test objectid(dsol.prob.u0) == objectid(sol.prob.u0) + @test dsol.prob.p === sol.prob.p + @test dsol.prob.u0 === sol.prob.u0 + # The actual derivative-carrying field (u) is still a fresh zero buffer. + @test objectid(dsol.u) != objectid(sol.u) + @test all(iszero, dsol.u) + + # Guard the `nothing` path: a problem whose `p` is `nothing` + # must not crash the pre-seed helper, and `u0` must stay aliased. 
+ prob_nop = NonlinearProblem{false}(f, u0, nothing) + sol_nop = SciMLBase.build_solution(prob_nop, nothing, [1.5, 2.0], zeros(2)) + dsol_nop = EXT._make_solution_zero(sol_nop) + @test dsol_nop.prob.p === nothing + @test dsol_nop.prob.u0 === sol_nop.prob.u0 +end + +end diff --git a/lib/NonlinearSolveBase/test/runtests.jl b/lib/NonlinearSolveBase/test/runtests.jl index df91f4d6a..8084e3dc1 100644 --- a/lib/NonlinearSolveBase/test/runtests.jl +++ b/lib/NonlinearSolveBase/test/runtests.jl @@ -134,4 +134,8 @@ using InteractiveUtils, Test @testset "EnzymeExt _accum_tangent! caches accumulation (#935)" begin include("enzyme_accum_tangent.jl") end + + @testset "EnzymeExt _make_solution_zero preserves prob.p/u0 aliasing" begin + include("enzyme_make_solution_zero.jl") + end end