DiffEqBaseEnzymeExt: port NS#936/937 fixes (caches walk + alias preserve) #3671
Closed
ChrisRackauckas-Claude wants to merge 1 commit into
The Enzyme reverse rule for `DiffEqBase.solve_up` had the same two latent
bugs that NonlinearSolveBase's Enzyme extension grew over the past year on
the steady-state / NLLS adjoint path. Both are silent gradient drops; the
ODE / DDE adjoint hits them the moment `prob.p` is an MTKParameters
carrying non-empty `caches` (e.g. anything tied to MTK's `explicitfuns!`
SCC coupling or `paraminit`-driven init).
* NS#936: replace the bare `ptr.dval .+= darg` accumulation with
`_accum_tangent!`, which honors the SciMLStructures interface for
structured shadows and walks the non-Tunable fields (`caches`,
`initials`, `discrete`, `constant`, `nonnumeric`, …) when the
inner adjoint returns a structured cotangent. The `diff_tunables`
kwarg mirrors `sensealg.diff_tunables` (and, when sensealg is
`nothing`, mirrors `automatic_sensealg_choice`'s predicate: a
SciMLStructure `prob.p` with a non-empty `caches` field forces
`Val(false)`, so the non-Tunable walk fires).
* NS#937: replace `Enzyme.make_zero(res[1])::RT` with
`_make_solution_zero`, which pre-seeds the `make_zero` IdDict so
`sol.prob.p` and `sol.prob.u0` keep aliasing the outer Enzyme
shadows. Without this, the recursive `make_zero` allocates fresh
buffers for those fields and any cotangent written into the
returned `sol.prob.p` by a downstream consumer goes into a
dangling buffer.
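
As a rough illustration of the NS#936 item above, the sketch below shows how a tunable-only accumulation silently drops the `caches` part of a structured cotangent, and how a `diff_tunables` flag can switch on the extra walk. Everything in it (`SketchParams`, `SketchSensealg`, `choose_diff_tunables`, `accum_sketch!`) is a hypothetical stand-in written with Base Julia only. It is not the real `_accum_tangent!`, it does not use SciMLStructures, and defaulting to `true` when the caches are empty is an assumption.

```julia
# Hypothetical stand-ins; none of these are the real SciML or MTK types.
struct SketchParams
    tunable::Vector{Float64}
    caches::Vector{Float64}
end

struct SketchSensealg
    diff_tunables::Bool
end

# An explicit sensealg wins. With `nothing`, non-empty caches force
# diff_tunables = false so that the non-tunable walk runs (assumed default: true).
choose_diff_tunables(alg::SketchSensealg, p::SketchParams) = alg.diff_tunables
choose_diff_tunables(::Nothing, p::SketchParams) = isempty(p.caches)

# Accumulate the structured cotangent `darg` into the shadow `dval`.
function accum_sketch!(dval::SketchParams, darg::SketchParams; diff_tunables::Bool = true)
    dval.tunable .+= darg.tunable      # tunable slice is always accumulated
    if !diff_tunables
        dval.caches .+= darg.caches    # non-tunable walk
    end
    return dval
end

darg = SketchParams([1.0, 2.0], [0.5, 0.5, 0.5])

# Tunable-only accumulation: the caches part of the cotangent is dropped.
tunable_only = accum_sketch!(SketchParams(zeros(2), zeros(3)), darg)

# Derived flag: these parameters carry non-empty caches, so the walk runs.
flag = choose_diff_tunables(nothing, darg)
walked = accum_sketch!(SketchParams(zeros(2), zeros(3)), darg; diff_tunables = flag)
```

In this toy setting `tunable_only.caches` stays all zeros while `walked.caches` picks up the cotangent, which is the kind of difference the regression test described below is said to check.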
Helpers are lifted verbatim from
`NonlinearSolveBaseEnzymeExt._accum_tangent!` /
`_make_solution_zero` (kwarg name, predicate, dispatch tree all
identical) so future fixes to one extension transplant cleanly to
the other.
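
The aliasing problem behind NS#937 can be pictured with a small identity-keyed cache. The sketch below is a simplified analogue in Base Julia, not the real `_make_solution_zero` and not `Enzyme.make_zero`: `SketchProb`, `SketchSol`, and `zero_sketch` are hypothetical names, and seeding the cache with caller-owned shadows `dp` and `du0` is an assumption about the mechanism made for illustration.

```julia
# Hypothetical miniature of a solution/problem pair; not the real SciML types.
struct SketchProb
    u0::Vector{Float64}
    p::Vector{Float64}
end

struct SketchSol
    u::Vector{Float64}
    prob::SketchProb
end

# Recursive zeroing with an identity-keyed cache: a buffer already present in
# `seen` maps to its stored shadow instead of a newly allocated one.
zero_sketch(x::Vector{Float64}, seen::IdDict) = get!(() -> zero(x), seen, x)
zero_sketch(prob::SketchProb, seen::IdDict) =
    SketchProb(zero_sketch(prob.u0, seen), zero_sketch(prob.p, seen))
zero_sketch(sol::SketchSol, seen::IdDict) =
    SketchSol(zero_sketch(sol.u, seen), zero_sketch(sol.prob, seen))

sol = SketchSol([1.0, 2.0], SketchProb([1.0], [3.0, 4.0]))
dp, du0 = zeros(2), zeros(1)            # shadows owned by the caller

naive = zero_sketch(sol, IdDict())      # allocates fresh buffers everywhere

seen = IdDict{Any, Any}()
seen[sol.prob.p] = dp                   # pre-seed: primal buffer => outer shadow
seen[sol.prob.u0] = du0
seeded = zero_sketch(sol, seen)

naive.prob.p[1] += 1.0                  # lost: `dp` never sees this write
seeded.prob.p[2] += 1.0                 # lands in `dp`
```

Only the pre-seeded result shares its `prob.p` and `prob.u0` buffers with the caller, so a cotangent written through `seeded.prob.p` is visible in `dp`, while the write through `naive.prob.p` goes to a buffer nothing else references. `seeded.u` is still a freshly allocated zero vector.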
Tests:
* `test/enzyme_accum_tangent.jl` — `MockMTKParams` regression that
fails on the old `.+=` path (caches stay zero) and passes with
the new accumulator under `diff_tunables = false`.
* `test/enzyme_make_solution_zero.jl` — verifies `dsol.prob.p ===
sol.prob.p` and `dsol.prob.u0 === sol.prob.u0` (and that the
`NullParameters` path doesn't crash the pre-seed helper).
Both new testsets pass locally on Julia 1.10 alongside the full
`Pkg.test("DiffEqBase")` suite.
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Closing as not required. Ablation study with the desauty SCC init test rewritten to use the proper Enzyme API ( |
Please ignore this PR until reviewed by @ChrisRackauckas.
Summary
The Enzyme reverse rule for `DiffEqBase.solve_up` in `lib/DiffEqBase/ext/DiffEqBaseEnzymeExt.jl` had the same two latent bugs that `NonlinearSolveBaseEnzymeExt` accumulated over the past year on the steady-state / NLLS adjoint path. Both are silent gradient drops, and both fire on the ODE / DDE adjoint as soon as `prob.p` is an MTKParameters carrying non-empty `caches` (anything wired to MTK's `explicitfuns!` SCC coupling or `paraminit`-driven init).

This PR is a direct port of the two NonlinearSolveBase fixes:

* NS#936 (caches walk): replace the bare `ptr.dval .+= darg` accumulation in the reverse rule with `_accum_tangent!(ptr.dval, darg; diff_tunables)`. The new helper:
  * Uses `SciMLStructures.canonicalize(Tunable(), …)` / `replace!` when the shadow is a SciMLStructure, so the Tunable slice is accumulated correctly instead of falling back to an undefined element-wise iteration.
  * With `diff_tunables = false`, also walks each non-Tunable field (`caches`, `initials`, `discrete`, `constant`, `nonnumeric`, …) of a structured `darg`. This is exactly what `SciMLSensitivity.adjointbackpass` returns whenever `automatic_sensealg_choice` picks `diff_tunables = Val(false)` (i.e. whenever `prob.p` is a SciMLStructure with non-empty `caches`).
  * The `diff_tunables` value is derived inside the reverse rule the same way NS#936 derives it: honor `sensealg.diff_tunables` when the user passed an explicit sensealg, else mirror `automatic_sensealg_choice`'s predicate on `prob.p`.
* NS#937 (alias preserve): replace `Enzyme.make_zero(res[1])::RT` in the augmented primal with `_make_solution_zero(res[1])::RT`. The new helper pre-seeds the `make_zero` `IdDict` so `sol.prob.p` and `sol.prob.u0` keep aliasing the outer Enzyme shadows. Without this, recursive `make_zero` allocates fresh buffers for those fields and any cotangent written into the returned `sol.prob.p` by a downstream consumer goes into a dangling buffer instead of the buffer the outer tape is tracking.

The helpers are lifted verbatim from `NonlinearSolveBaseEnzymeExt` (same kwarg name `diff_tunables`, same predicate, same `_accum_nested!` dispatch tree, same `_preseed_alias!` guard). Future fixes to one extension should transplant cleanly to the other.

Reference fixes:

* `NonlinearSolve.jl/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl:25-89` (`_accum_tangent!`)
* `NonlinearSolve.jl/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl:140-153` (`_make_solution_zero`)
* The `automatic_sensealg_choice` predicate the kwarg derivation mirrors lives in `SciMLSensitivity/src/concrete_solve.jl:136-336`.

Test plan

New regression tests in `lib/DiffEqBase/test/`:

* `enzyme_accum_tangent.jl`: `MockMTKParams` regression. With the old `.+=` path the caches stay zero; with the new accumulator under `diff_tunables = false` the caches accumulate and `+=` repeated-call semantics hold. With `diff_tunables = true` (default), only Tunable is touched. 6/6 pass.
* `enzyme_make_solution_zero.jl`: builds an `ODESolution` and verifies `dsol.prob.p === sol.prob.p`, `dsol.prob.u0 === sol.prob.u0`, while `dsol.u` is still a fresh zero buffer. Naive `Enzyme.make_zero` fails the aliasing check (regression guard). Also covers the `NullParameters` / `nothing` path of the pre-seed helper. 9/9 pass.
* Both are included in the `Core` group of `lib/DiffEqBase/test/runtests.jl` (gated on `isempty(VERSION.prerelease)`, matching the extension's own gate).
* `Pkg.test("DiffEqBase")` on Julia 1.10: all pre-existing tests still pass alongside the new ones.

Notes
🤖 Generated with Claude Code