Skip to content

DiffEqBaseEnzymeExt: port NS#936/937 fixes (caches walk + alias preserve)#3671

Closed
ChrisRackauckas-Claude wants to merge 1 commit into
SciML:masterfrom
ChrisRackauckas-Claude:enzyme-accum-caches-and-alias
Closed

DiffEqBaseEnzymeExt: port NS#936/937 fixes (caches walk + alias preserve)#3671
ChrisRackauckas-Claude wants to merge 1 commit into
SciML:masterfrom
ChrisRackauckas-Claude:enzyme-accum-caches-and-alias

Conversation

@ChrisRackauckas-Claude
Copy link
Copy Markdown
Contributor

Please ignore this PR until reviewed by @ChrisRackauckas.

Summary

The Enzyme reverse rule for DiffEqBase.solve_up in lib/DiffEqBase/ext/DiffEqBaseEnzymeExt.jl had the same two latent bugs that NonlinearSolveBaseEnzymeExt accumulated over the past year on the steady-state / NLLS adjoint path. Both are silent gradient drops, and both fire on the ODE / DDE adjoint as soon as prob.p is an MTKParameters carrying non-empty caches (anything wired to MTK's explicitfuns! SCC coupling or paraminit-driven init).

This PR is a direct port of the two NonlinearSolveBase fixes:

  • NS#936 (caches walk): replace the bare ptr.dval .+= darg accumulation in the reverse rule with _accum_tangent!(ptr.dval, darg; diff_tunables). The new helper:

    • Goes through SciMLStructures.canonicalize(Tunable(), …) / replace! when the shadow is a SciMLStructure, so the Tunable slice is accumulated correctly instead of falling back to an undefined element-wise iteration.
    • When diff_tunables = false, also walks each non-Tunable field (caches, initials, discrete, constant, nonnumeric, …) of a structured darg. This is exactly what SciMLSensitivity.adjointbackpass returns whenever automatic_sensealg_choice picks diff_tunables = Val(false) (i.e. whenever prob.p is a SciMLStructure with non-empty caches).
    • The diff_tunables value is derived inside the reverse rule the same way NS#936 derives it: honor sensealg.diff_tunables when the user passed an explicit sensealg, else mirror automatic_sensealg_choice's predicate on prob.p.
  • NS#937 (alias preserve): replace Enzyme.make_zero(res[1])::RT in the augmented primal with _make_solution_zero(res[1])::RT. The new helper pre-seeds the make_zero IdDict so sol.prob.p and sol.prob.u0 keep aliasing the outer Enzyme shadows. Without this, recursive make_zero allocates fresh buffers for those fields and any cotangent written into the returned sol.prob.p by a downstream consumer goes into a dangling buffer instead of the buffer the outer tape is tracking.

The helpers are lifted verbatim from NonlinearSolveBaseEnzymeExt (same kwarg name diff_tunables, same predicate, same _accum_nested! dispatch tree, same _preseed_alias! guard). Future fixes to one extension should transplant cleanly to the other.

Reference fixes:

  • NonlinearSolve.jl/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl:25-89 (_accum_tangent!)
  • NonlinearSolve.jl/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl:140-153 (_make_solution_zero)

automatic_sensealg_choice predicate the kwarg derivation mirrors lives in SciMLSensitivity/src/concrete_solve.jl:136-336.

Test plan

New regression tests in lib/DiffEqBase/test/:

  • enzyme_accum_tangent.jlMockMTKParams regression. With the old .+= path the caches stay zero; with the new accumulator under diff_tunables = false the caches accumulate and += repeated-call semantics hold. With diff_tunables = true (default), only Tunable is touched. 6/6 pass.
  • enzyme_make_solution_zero.jl — builds an ODESolution and verifies dsol.prob.p === sol.prob.p, dsol.prob.u0 === sol.prob.u0, while dsol.u is still a fresh zero buffer. Naive Enzyme.make_zero fails the aliasing check (regression guard). Also covers the NullParameters / nothing path of the pre-seed helper. 9/9 pass.
  • Both testsets are wired into the Core group of lib/DiffEqBase/test/runtests.jl (gated on isempty(VERSION.prerelease), matching the extension's own gate).
  • Full Pkg.test("DiffEqBase") on Julia 1.10 — all pre-existing tests still pass alongside the new ones.

Notes

  • End-to-end Enzyme.gradient on a real ODE with MTKParameters + non-empty caches is best verified against the desauty-equivalent reproducer post-merge — that requires a much heavier MTK / SciMLSensitivity setup than is appropriate for this PR's unit-test surface.

🤖 Generated with Claude Code

…rve)

The Enzyme reverse rule for `DiffEqBase.solve_up` had the same two latent
bugs that NonlinearSolveBase's Enzyme extension grew over the past year on
the steady-state / NLLS adjoint path. Both are silent gradient drops; the
ODE / DDE adjoint hits them the moment `prob.p` is an MTKParameters
carrying non-empty `caches` (e.g. anything tied to MTK's `explicitfuns!`
SCC coupling or `paraminit`-driven init).

* NS#936: replace the bare `ptr.dval .+= darg` accumulation with
  `_accum_tangent!`, which honors the SciMLStructures interface for
  structured shadows and walks the non-Tunable fields (`caches`,
  `initials`, `discrete`, `constant`, `nonnumeric`, …) when the
  inner adjoint returns a structured cotangent. The `diff_tunables`
  kwarg mirrors `sensealg.diff_tunables` (and, when sensealg is
  `nothing`, mirrors `automatic_sensealg_choice`'s predicate: a
  SciMLStructure `prob.p` with a non-empty `caches` field forces
  `Val(false)`, so the non-Tunable walk fires).

* NS#937: replace `Enzyme.make_zero(res[1])::RT` with
  `_make_solution_zero`, which pre-seeds the `make_zero` IdDict so
  `sol.prob.p` and `sol.prob.u0` keep aliasing the outer Enzyme
  shadows. Without this, the recursive `make_zero` allocates fresh
  buffers for those fields and any cotangent written into the
  returned `sol.prob.p` by a downstream consumer goes into a
  dangling buffer.

Helpers are lifted verbatim from
`NonlinearSolveBaseEnzymeExt._accum_tangent!` /
`_make_solution_zero` (kwarg name, predicate, dispatch tree all
identical) so future fixes to one extension transplant cleanly to
the other.

Tests:
  * `test/enzyme_accum_tangent.jl` — `MockMTKParams` regression that
    fails on the old `.+=` path (caches stay zero) and passes with
    the new accumulator under `diff_tunables = false`.
  * `test/enzyme_make_solution_zero.jl` — verifies `dsol.prob.p ===
    sol.prob.p` and `dsol.prob.u0 === sol.prob.u0` (and that the
    `NullParameters` path doesn't crash the pre-seed helper).

Both new testsets pass locally on Julia 1.10 alongside the full
`Pkg.test(\"DiffEqBase\")` suite.

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
@ChrisRackauckas-Claude
Copy link
Copy Markdown
Contributor Author

Closing as not required. Ablation study with the desauty SCC init test rewritten to use the proper Enzyme API (Enzyme.autodiff + Duplicated(iprob), SciML/SciMLSensitivity.jl#1454) confirms this port doesn't change the outcome — desauty still passes with this PR reverted (DiffEqBase pinned to registry v7.5.1). The bug pattern this PR mirrors (NS#936/#937) is real, but neither half is load-bearing on the ODE adjoint path when downstream callers follow the documented Enzyme contract. Can be reopened if a concrete test case demonstrates the need.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants