diff --git a/Project.toml b/Project.toml index 59576f897..019e79437 100644 --- a/Project.toml +++ b/Project.toml @@ -26,6 +26,7 @@ IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" +NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0" OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8" PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46" QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" @@ -91,6 +92,7 @@ Mooncake = "0.5.24" Reactant = "0.2.22" NLsolve = "4.5.1" NonlinearSolve = "3.0.1, 4" +NonlinearSolveBase = "2.27" SCCNonlinearSolve = "1" Optimization = "4, 5" OptimizationOptimisers = "0.3" diff --git a/test/desauty_dae_mwe.jl b/test/desauty_dae_mwe.jl index 688548436..5b7353200 100644 --- a/test/desauty_dae_mwe.jl +++ b/test/desauty_dae_mwe.jl @@ -121,50 +121,43 @@ eqs = [ end @testset "Enzyme through init" begin - # Annotations follow the documented user-side pattern: - # `Const(loss)` for the closure that captures the mutable - # `NonlinearProblem`/`SCCNonlinearProblem`, and - # `set_runtime_activity(Reverse)` so Enzyme's activity analysis - # tolerates the runtime-activity transitions through MTK's - # `remake` path. The inner `solve` pins `NewtonRaphson()` - # explicitly so Enzyme's type analysis does not trip on the - # polyalgorithm Union NonlinearSolve would otherwise dispatch - # through. The previously-reported `EnzymeMutabilityException` - # on the mutable closure capture is correct upstream behavior - # per EnzymeAD/Enzyme.jl#3117 — annotating with `Const` is the - # fix. + # `Enzyme.gradient(Const(closure), tunables)` does not allocate + # shadows for the closure's captures, so when the closure captures + # a mutable `iprob` (whose `iprob.p.caches` is shared via the + # `SciMLStructures.replace` `@set!` repack and then mutated by + # the inner `solve!`), the derivative info carried by those cache + # writes has nowhere to land and is silently dropped. The + # idiomatic Enzyme pattern is to express the loss as a plain + # function whose captured mutable state is passed as an explicit + # `Duplicated` argument. We also reconstruct `irepack` *inside* + # the loss from the duplicated `iprob_`, so its captured + # parameter template shares the Enzyme shadow. # - # With these annotations, the plain `NonlinearProblem` case - # (use_scc = false) now passes. The `SCCNonlinearProblem` case - # (use_scc = true) still trips a `MixedDuplicated` / - # `Core.SimpleVector` MethodError further down in Enzyme's - # runtime-activity wrapping for the MTK-System / - # NonlinearSolution types involved in SCC sub-problem - # assembly — tracked in SciMLSensitivity.jl#1359. When that - # lifts, flipping `@test_broken` → `@test` in the `use_scc` - # branch is the only change needed here. - enzyme_init_loss = let iprob = iprob, irepack = irepack - p -> begin - iprob2 = remake(iprob, p = irepack(p)) - sol = solve(iprob2, NewtonRaphson()) - sum(sol.u) - end - end - if use_scc - @test_broken begin - igs = Enzyme.gradient( - Enzyme.set_runtime_activity(Enzyme.Reverse), - Enzyme.Const(enzyme_init_loss), itunables, - ) - !iszero(sum(igs)) - end - else - igs = Enzyme.gradient( - Enzyme.set_runtime_activity(Enzyme.Reverse), - Enzyme.Const(enzyme_init_loss), itunables, + # The inner `solve` pins `NewtonRaphson()` explicitly so Enzyme's + # type analysis does not trip on the polyalgorithm Union + # NonlinearSolve would otherwise dispatch through. The + # previously-reported `EnzymeMutabilityException` on the mutable + # closure capture is correct upstream behavior per + # EnzymeAD/Enzyme.jl#3117 — annotating with `Const` is the fix. + function enzyme_init_loss(t, iprob_) + _, irepack_, _ = SS.canonicalize( + SS.Tunable(), parameter_values(iprob_), ) - @test !iszero(sum(igs)) + iprob2 = remake(iprob_, p = irepack_(t)) + sol = solve(iprob2, NewtonRaphson()) + return sum(sol.u) end + diprob = Enzyme.make_zero(iprob) + dtunables = zero(itunables) + Enzyme.autodiff( + Enzyme.set_runtime_activity(Enzyme.Reverse), + Enzyme.Const(enzyme_init_loss), + Enzyme.Active, + Enzyme.Duplicated(itunables, dtunables), + Enzyme.Duplicated(iprob, diprob), + ) + @test !iszero(sum(dtunables)) + @test isapprox(dtunables, fd_init_grad, rtol = 0.05) end @testset "Mooncake through init" begin