diff --git a/lib/NonlinearSolveBase/Project.toml b/lib/NonlinearSolveBase/Project.toml index 61e56aa65..ce2dc5a83 100644 --- a/lib/NonlinearSolveBase/Project.toml +++ b/lib/NonlinearSolveBase/Project.toml @@ -1,6 +1,6 @@ name = "NonlinearSolveBase" uuid = "be0214bd-f91f-a760-ac4e-3421ce2b2da0" -version = "2.26.1" +version = "2.26.2" authors = ["Avik Pal and contributors"] [deps] diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl index 4b82f9259..f007a6d2c 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl @@ -115,7 +115,7 @@ function Enzyme.EnzymeRules.augmented_primal( kwargs... ) - dres = Enzyme.make_zero(res[1]) + dres = _make_solution_zero(res[1]) primal = EnzymeRules.needs_primal(config) ? res[1] : nothing shadow = EnzymeRules.needs_shadow(config) ? dres : nothing tup = (dres, res[2]) @@ -123,6 +123,35 @@ function Enzyme.EnzymeRules.augmented_primal( return RetType(primal, shadow, tup::Any) end +# Build the shadow `NonlinearSolution` for the augmented primal. A plain +# `Enzyme.make_zero(sol)` recursively allocates fresh zero buffers for every +# mutable field of `sol`, including `sol.prob.p` and `sol.prob.u0`, which are +# aliased to the active `p` / `u0` shadows the outer caller is differentiating +# against. Severing that aliasing means a cotangent written into the returned +# `sol.prob.p` field by a downstream consumer (e.g. anything reading the +# solution's parameters) goes into a dangling buffer instead of the buffer +# the outer Enzyme tape is tracking, silently dropping that gradient +# contribution. +# +# Pre-seed the `make_zero` seen-set so `prob.p` and `prob.u0` map to +# themselves: recursion into those fields short-circuits via `haskey(seen, …)`, +# preserving aliasing with the outer shadow. The actual derivative-carrying +# field (`sol.u`) still gets a fresh zero buffer. 
+@inline function _make_solution_zero(sol) + seen = IdDict() + _preseed_alias!(seen, sol.prob.p) + _preseed_alias!(seen, sol.prob.u0) + return Enzyme.make_zero(Core.Typeof(sol), seen, sol, Val(false)) +end + +@inline _preseed_alias!(::IdDict, ::Nothing) = nothing +@inline function _preseed_alias!(seen::IdDict, v) + if ismutable(v) + seen[v] = v + end + return nothing +end + function Enzyme.EnzymeRules.reverse( config::Enzyme.EnzymeRules.RevConfigWidth{1}, func::Const{typeof(NonlinearSolveBase.solve_up)}, ::Type{RT}, tape, prob, diff --git a/lib/NonlinearSolveBase/test/enzyme_make_solution_zero.jl b/lib/NonlinearSolveBase/test/enzyme_make_solution_zero.jl new file mode 100644 index 000000000..d92c97e74 --- /dev/null +++ b/lib/NonlinearSolveBase/test/enzyme_make_solution_zero.jl @@ -0,0 +1,54 @@ +module EnzymeMakeSolutionZeroTests + +using Test +using NonlinearSolveBase +import ChainRulesCore, Enzyme # triggers NonlinearSolveBaseEnzymeExt +using SciMLBase + +const EXT = Base.get_extension(NonlinearSolveBase, :NonlinearSolveBaseEnzymeExt) + +@testset "EnzymeExt._make_solution_zero preserves prob.p / prob.u0 aliasing" begin + # The augmented primal used to build the return-value shadow via `make_zero(sol)`. + # A plain `Enzyme.make_zero` recursively zeros every mutable field of + # `sol`, including `sol.prob.p` and `sol.prob.u0`, which the outer caller + # has already registered as active shadows for the `p` / `u0` arguments. + # Severing that aliasing means any cotangent written into the returned + # `sol.prob.p` (or `.u0`) by a downstream consumer lands in a dangling + # buffer instead of the one the outer Enzyme tape is tracking, silently + # dropping that gradient contribution. + # + # `_make_solution_zero` pre-seeds the `make_zero` seen-set with identity + # entries for `prob.p` and `prob.u0` so the recursion short-circuits and + # the original buffers are reused verbatim in the shadow. 
+ + f(u, p) = u .^ 2 .- p + u0 = [1.0, 1.0] + p = [2.0, 4.0] + prob = NonlinearProblem{false}(f, u0, p) + sol = SciMLBase.build_solution(prob, nothing, [1.5, 2.0], zeros(2)) + + # Naive `Enzyme.make_zero` allocates fresh buffers for prob.p / prob.u0. + dsol_naive = Enzyme.make_zero(sol) + @test objectid(dsol_naive.prob.p) != objectid(sol.prob.p) + @test objectid(dsol_naive.prob.u0) != objectid(sol.prob.u0) + + # The extension helper keeps them aliased to the primal. + dsol = EXT._make_solution_zero(sol) + @test objectid(dsol.prob.p) == objectid(sol.prob.p) + @test objectid(dsol.prob.u0) == objectid(sol.prob.u0) + @test dsol.prob.p === sol.prob.p + @test dsol.prob.u0 === sol.prob.u0 + # The actual derivative-carrying field (u) is still a fresh zero buffer. + @test objectid(dsol.u) != objectid(sol.u) + @test all(iszero, dsol.u) + + # Guard the `nothing` path: a problem whose `p` is `nothing` must not + # crash the pre-seed helper. + prob_nop = NonlinearProblem{false}(f, u0, nothing) + sol_nop = SciMLBase.build_solution(prob_nop, nothing, [1.5, 2.0], zeros(2)) + dsol_nop = EXT._make_solution_zero(sol_nop) + @test dsol_nop.prob.p === nothing + @test dsol_nop.prob.u0 === sol_nop.prob.u0 +end + +end diff --git a/lib/NonlinearSolveBase/test/runtests.jl b/lib/NonlinearSolveBase/test/runtests.jl index df91f4d6a..8084e3dc1 100644 --- a/lib/NonlinearSolveBase/test/runtests.jl +++ b/lib/NonlinearSolveBase/test/runtests.jl @@ -134,4 +134,8 @@ using InteractiveUtils, Test @testset "EnzymeExt _accum_tangent! caches accumulation (#935)" begin include("enzyme_accum_tangent.jl") end + + @testset "EnzymeExt _make_solution_zero preserves prob.p/u0 aliasing" begin + include("enzyme_make_solution_zero.jl") + end end