Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Expand Down Expand Up @@ -91,6 +92,7 @@ Mooncake = "0.5.24"
Reactant = "0.2.22"
NLsolve = "4.5.1"
NonlinearSolve = "3.0.1, 4"
NonlinearSolveBase = "2.27"
SCCNonlinearSolve = "1"
Optimization = "4, 5"
OptimizationOptimisers = "0.3"
Expand Down
75 changes: 34 additions & 41 deletions test/desauty_dae_mwe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,50 +121,43 @@ eqs = [
end

@testset "Enzyme through init" begin
# Annotations follow the documented user-side pattern:
# `Const(loss)` for the closure that captures the mutable
# `NonlinearProblem`/`SCCNonlinearProblem`, and
# `set_runtime_activity(Reverse)` so Enzyme's activity analysis
# tolerates the runtime-activity transitions through MTK's
# `remake` path. The inner `solve` pins `NewtonRaphson()`
# explicitly so Enzyme's type analysis does not trip on the
# polyalgorithm Union NonlinearSolve would otherwise dispatch
# through. The previously-reported `EnzymeMutabilityException`
# on the mutable closure capture is correct upstream behavior
# per EnzymeAD/Enzyme.jl#3117 — annotating with `Const` is the
# fix.
# `Enzyme.gradient(Const(closure), tunables)` does not allocate
# shadows for the closure's captures, so when the closure captures
# a mutable `iprob` (whose `iprob.p.caches` is shared via the
# `SciMLStructures.replace` `@set!` repack and then mutated by
# the inner `solve!`), the derivative info carried by those cache
# writes has nowhere to land and is silently dropped. The
# idiomatic Enzyme pattern is to express the loss as a plain
# function whose captured mutable state is passed as an explicit
# `Duplicated` argument. We also reconstruct `irepack` *inside*
# the loss from the duplicated `iprob_`, so its captured
# parameter template shares the Enzyme shadow.
#
# With these annotations, the plain `NonlinearProblem` case
# (use_scc = false) now passes. The `SCCNonlinearProblem` case
# (use_scc = true) still trips a `MixedDuplicated` /
# `Core.SimpleVector` MethodError further down in Enzyme's
# runtime-activity wrapping for the MTK-System /
# NonlinearSolution types involved in SCC sub-problem
# assembly — tracked in SciMLSensitivity.jl#1359. When that
# lifts, flipping `@test_broken` → `@test` in the `use_scc`
# branch is the only change needed here.
enzyme_init_loss = let iprob = iprob, irepack = irepack
p -> begin
iprob2 = remake(iprob, p = irepack(p))
sol = solve(iprob2, NewtonRaphson())
sum(sol.u)
end
end
if use_scc
@test_broken begin
igs = Enzyme.gradient(
Enzyme.set_runtime_activity(Enzyme.Reverse),
Enzyme.Const(enzyme_init_loss), itunables,
)
!iszero(sum(igs))
end
else
igs = Enzyme.gradient(
Enzyme.set_runtime_activity(Enzyme.Reverse),
Enzyme.Const(enzyme_init_loss), itunables,
# The inner `solve` pins `NewtonRaphson()` explicitly so Enzyme's
# type analysis does not trip on the polyalgorithm Union
# NonlinearSolve would otherwise dispatch through. The
# previously-reported `EnzymeMutabilityException` on the mutable
# closure capture is correct upstream behavior per
# EnzymeAD/Enzyme.jl#3117 — annotating with `Const` is the fix.
function enzyme_init_loss(t, iprob_)
_, irepack_, _ = SS.canonicalize(
SS.Tunable(), parameter_values(iprob_),
)
@test !iszero(sum(igs))
iprob2 = remake(iprob_, p = irepack_(t))
sol = solve(iprob2, NewtonRaphson())
return sum(sol.u)
end
diprob = Enzyme.make_zero(iprob)
dtunables = zero(itunables)
Enzyme.autodiff(
Enzyme.set_runtime_activity(Enzyme.Reverse),
Enzyme.Const(enzyme_init_loss),
Enzyme.Active,
Enzyme.Duplicated(itunables, dtunables),
Enzyme.Duplicated(iprob, diprob),
)
@test !iszero(sum(dtunables))
@test isapprox(dtunables, fd_init_grad, rtol = 0.05)
end

@testset "Mooncake through init" begin
Expand Down
Loading