Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/NonlinearSolveBase/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "NonlinearSolveBase"
uuid = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
version = "2.26.1"
version = "2.26.2"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]

[deps]
Expand Down
31 changes: 30 additions & 1 deletion lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,43 @@ function Enzyme.EnzymeRules.augmented_primal(
kwargs...
)

dres = Enzyme.make_zero(res[1])
dres = _make_solution_zero(res[1])
primal = EnzymeRules.needs_primal(config) ? res[1] : nothing
shadow = EnzymeRules.needs_shadow(config) ? dres : nothing
tup = (dres, res[2])
RetType = Enzyme.EnzymeRules.augmented_rule_return_type(config, RT)
return RetType(primal, shadow, tup::Any)
end

# Construct the shadow `NonlinearSolution` returned by the augmented primal.
#
# Why not call `Enzyme.make_zero(sol)` directly? That call walks `sol` recursively
# and allocates a fresh zero buffer for every mutable field it finds. This includes
# `sol.prob.p` and `sol.prob.u0`, which are aliased to the active `p` / `u0` shadows
# the outer caller is differentiating against. Once that aliasing is severed, a
# cotangent that a downstream consumer writes into the returned `sol.prob.p` field
# (e.g. anything reading the solution's parameters) lands in a dangling buffer rather
# than the buffer the outer Enzyme tape is tracking, and that gradient contribution
# is silently dropped.
#
# The workaround is to pre-seed the `make_zero` seen-set so that `prob.p` and
# `prob.u0` each map to themselves. Recursion into those fields then short-circuits
# via `haskey(seen, …)` and the existing objects are reused, preserving aliasing with
# the outer shadow. The field that actually carries the derivative (`sol.u`) is not
# pre-seeded and therefore still receives a fresh zero buffer.
@inline function _make_solution_zero(sol)
    problem = sol.prob
    aliases = IdDict()
    # Identity entries: each pre-seeded object is its own "zero" in the seen-set.
    _preseed_alias!(aliases, problem.p)
    _preseed_alias!(aliases, problem.u0)
    return Enzyme.make_zero(Core.Typeof(sol), aliases, sol, Val(false))
end

# Register `v` in `seen` as an identity entry (`v => v`) when `v` is mutable.
# `nothing` and any other non-mutable value leave `seen` untouched.
# Always returns `nothing`.
@inline _preseed_alias!(::IdDict, ::Nothing) = nothing
@inline function _preseed_alias!(seen::IdDict, v)
    # Only mutable objects are recorded; everything else is skipped.
    ismutable(v) && setindex!(seen, v, v)
    return nothing
end

function Enzyme.EnzymeRules.reverse(
config::Enzyme.EnzymeRules.RevConfigWidth{1},
func::Const{typeof(NonlinearSolveBase.solve_up)}, ::Type{RT}, tape, prob,
Expand Down
54 changes: 54 additions & 0 deletions lib/NonlinearSolveBase/test/enzyme_make_solution_zero.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
module EnzymeMakeSolutionZeroTests

using Test
using NonlinearSolveBase
import ChainRulesCore, Enzyme # triggers NonlinearSolveBaseEnzymeExt
using SciMLBase

const EXT = Base.get_extension(NonlinearSolveBase, :NonlinearSolveBaseEnzymeExt)

@testset "EnzymeExt._make_solution_zero preserves prob.p / prob.u0 aliasing" begin
    # Background: the extension's augmented-primal rule builds the shadow of the
    # returned solution from the primal solution. Plain `Enzyme.make_zero`
    # recursively zeros every mutable field of `sol`, `sol.prob.p` and
    # `sol.prob.u0` included, even though the outer caller has already registered
    # those as active shadows for the `p` / `u0` arguments. With the aliasing
    # severed, a cotangent written into the returned `sol.prob.p` (or `.u0`) by a
    # downstream consumer ends up in a dangling buffer instead of the one the outer
    # Enzyme tape is tracking, and that gradient contribution is silently dropped.
    #
    # `_make_solution_zero` avoids this by pre-seeding the `make_zero` seen-set
    # with identity entries for `prob.p` and `prob.u0`, so the recursion
    # short-circuits and the original buffers are reused verbatim in the shadow.

    residual(u, p) = u .^ 2 .- p
    u_init = [1.0, 1.0]
    params = [2.0, 4.0]
    problem = NonlinearProblem{false}(residual, u_init, params)
    solution = SciMLBase.build_solution(problem, nothing, [1.5, 2.0], zeros(2))

    # Baseline: naive `Enzyme.make_zero` hands back fresh buffers for
    # prob.p / prob.u0.
    naive_shadow = Enzyme.make_zero(solution)
    @test objectid(naive_shadow.prob.p) != objectid(solution.prob.p)
    @test objectid(naive_shadow.prob.u0) != objectid(solution.prob.u0)

    # With the extension helper, both stay aliased to the primal.
    shadow = EXT._make_solution_zero(solution)
    @test objectid(shadow.prob.p) == objectid(solution.prob.p)
    @test objectid(shadow.prob.u0) == objectid(solution.prob.u0)
    @test shadow.prob.p === solution.prob.p
    @test shadow.prob.u0 === solution.prob.u0
    # `u`, the derivative-carrying field, must still be a fresh, all-zero buffer.
    @test objectid(shadow.u) != objectid(solution.u)
    @test all(iszero, shadow.u)

    # `nothing` path of the pre-seed helper: a problem whose `p` is `nothing`
    # must not crash, and `u0` must still be aliased.
    problem_nop = NonlinearProblem{false}(residual, u_init, nothing)
    solution_nop = SciMLBase.build_solution(
        problem_nop, nothing, [1.5, 2.0], zeros(2)
    )
    shadow_nop = EXT._make_solution_zero(solution_nop)
    @test shadow_nop.prob.p === nothing
    @test shadow_nop.prob.u0 === solution_nop.prob.u0
end

end
4 changes: 4 additions & 0 deletions lib/NonlinearSolveBase/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,8 @@ using InteractiveUtils, Test
@testset "EnzymeExt _accum_tangent! caches accumulation (#935)" begin
include("enzyme_accum_tangent.jl")
end

@testset "EnzymeExt _make_solution_zero preserves prob.p/u0 aliasing" begin
include("enzyme_make_solution_zero.jl")
end
end
Loading