MTKParameters repack: allocate fresh caches buffers (fixes Enzyme silent-zero on SCC init) #4557
The `repack` closure returned by `SciMLStructures.canonicalize(::Tunable, ::MTKParameters)` previously delegated to `SciMLStructures.replace`, which only sets the `tunable` field via `@set!` and keeps every other field (including the `caches` tuple of mutable `Vector`s) aliased to the captured template `p`. Each call to `repack(new_tunable)` therefore produced an `MTKParameters` whose `caches` buffers were the very same `Vector` objects as those in the closure's captured `iprob.p.caches`.

This aliasing breaks Enzyme reverse-mode AD when a loss closure (e.g. `p -> sum(solve(remake(iprob, p = irepack(p))).u)`) is differentiated. Because Enzyme sees the same `Vector` identity on both sides of the call, it does not allocate a separate shadow buffer for the cache. Cotangents flowing into `p.caches[i]` either silently produce zero (silent-zero) or read back the primal value as the cotangent (wrong-but-nonzero, roughly `Jᵀ × primal_value`). This is the root cause of the long-standing desauty `use_scc=true` Enzyme zero-gradient bug.

The fix: have `repack` construct a new `MTKParameters` with freshly allocated `caches` buffers (`ntuple(i -> copy(p.caches[i]), …)`), so each repack produces its own independent cache storage. Other fields (`initials`, `discrete`, `constant`, `nonnumeric`) are left shared with the template, matching previous behavior, since they have not been demonstrated to cause AD issues.

Note: the symmetric `canonicalize(::Initials, ::MTKParameters)` and the generic `for (Portion, …)` loop covering `Discrete`/`Constants`/`Nonnumeric`/`Caches` use the same shallow-replace pattern. If downstream AD users hit analogous problems through those portions, they will need similar treatment, but the demonstrated bug is in `Tunable`, so this patch is scoped accordingly.
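The difference between the two repack strategies can be shown without ModelingToolkit or Enzyme. The sketch below is a minimal stand-in, not the code in this diff: `ToyParams`, `shallow_repack`, and `fresh_repack` are hypothetical names introduced only for illustration, and the example demonstrates the buffer aliasing alone, not the Enzyme gradient behavior.

```julia
# Hypothetical stand-in for MTKParameters: an immutable struct whose
# `caches` field is a tuple of mutable Vectors.
struct ToyParams{T, C}
    tunable::T
    caches::C
end

# Shallow replace: swaps `tunable` only; `caches` keeps the same Vector objects.
shallow_repack(p::ToyParams, newvals) = ToyParams(newvals, p.caches)

# Fresh-cache repack: copies every cache buffer so the result owns its storage.
function fresh_repack(p::ToyParams, newvals)
    caches = ntuple(i -> copy(p.caches[i]), length(p.caches))
    return ToyParams(newvals, caches)
end

template = ToyParams([1.0, 2.0], ([0.0, 0.0], [0.0]))

aliased = shallow_repack(template, [3.0, 4.0])
fresh = fresh_repack(template, [3.0, 4.0])

# `===` on mutable objects compares identity, the same property objectid exposes.
println(aliased.caches[1] === template.caches[1])  # true: shared buffer
println(fresh.caches[1] === template.caches[1])    # false: independent buffer

# A write through the aliased result is visible in the template,
# while the freshly copied buffer is unaffected.
aliased.caches[1][1] = 42.0
println(template.caches[1][1])  # 42.0
println(fresh.caches[1][1])     # 0.0
```

The copy is per buffer rather than a `deepcopy` of the whole struct, which mirrors the scoping described above: only the cache storage becomes independent, and every other field would stay shared with the template.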
Verification:
- `objectid(repack(t).caches[1]) != objectid(p.caches[1])` (alias broken)
- End-to-end desauty MWE: Enzyme reverse-mode now matches FiniteDiff to 8 significant figures on the SCC-initialized ODE (was zero before).

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Holding this PR: the repack-fresh-caches behavior in this diff is treating the symptom. The semantics of [...] The reproducer [...] Won't merge this PR; closing once the Enzyme issue is filed and acknowledged.
Minimal upstream reproducer filed: EnzymeAD/Enzyme.jl#3124

The bug is purely Enzyme: closure-captured mutable buffer aliased into a new struct → [...]

Closing this PR as soon as Enzyme.jl#3124 lands; this MTK-side workaround is the wrong layer to fix it.
Closing: this fix is treating a symptom, not the real issue. The repack-aliasing-caches behavior is intended MTK semantics ( [...]
Summary
Fixes a long-standing Enzyme reverse-mode AD bug where gradients through SCC-initialized ODE problems silently return zero (or wrong-but-nonzero values).
The bug
`SciMLStructures.canonicalize(::Tunable, ::MTKParameters)` returns a `repack` closure that delegates to `SciMLStructures.replace(::Tunable, p, newvals)`, which is shallow: `@set! p.tunable = newvals; return p`. Every other field of the returned `MTKParameters` (including the mutable `Vector` buffers in `p.caches`) is therefore aliased, by object identity, to the captured template `p`.

When a loss closure such as `p -> sum(solve(remake(iprob, p = irepack(p))).u)` is differentiated with `Enzyme.gradient(Reverse, Const(loss), tunable)`, Enzyme sees the same `Vector` identity on both sides of the `irepack` call. Because no separate shadow buffer is allocated for the aliased cache, cotangents flowing into `p.caches[i]` during reverse-mode either silently produce zero (silent-zero) or read back the primal as the cotangent (wrong-but-nonzero, roughly `Jᵀ × primal_value`). This is the root cause of the long-standing desauty `use_scc=true` Enzyme zero-gradient bug.

The aliasing was confirmed via a `cache_alias_break_test.jl` experiment: when the loss explicitly copies caches via `ConstructionBase.setproperties(p_initial; caches = ntuple(i -> copy(p_initial.caches[i]), …))`, Enzyme matches FiniteDiff to 8 significant figures.

The fix
`lib/ModelingToolkitBase/src/systems/parameter_buffer.jl`: have the `repack` closure for the `Tunable` portion construct a new `MTKParameters` with freshly allocated `caches` buffers (`ntuple(i -> copy(p.caches[i]), length(p.caches))`). Other fields (`initials`, `discrete`, `constant`, `nonnumeric`) are left shared with the template, matching previous behavior, since they have not been demonstrated to cause AD issues.

Other shallow-replace sites (not changed here)
The symmetric `canonicalize(::Initials, ::MTKParameters)` and the generic `for (Portion, …)` loop covering `Discrete`/`Constants`/`Nonnumeric`/`Caches` use the same shallow-replace pattern (`@set! p.field = newvals`). If downstream AD users hit analogous problems through those portions, they will need similar treatment, but the demonstrated bug is in `Tunable`, so this patch is scoped accordingly. Happy to extend if reviewers prefer symmetry.

Verification
Alias-break smoke check (objectid before/after):
End-to-end desauty MWE (the bug-reproducing case): Enzyme reverse-mode now matches FiniteDiff:

[0.0, 0.0, 0.3662747126219529, 0.3100659003339509, 0.0, 0.0, 0.7325494252439058, 0.0, 0.0, 0.0]
[0.0, 0.0, 0.3662747126234614, 0.3100659003292303, 0.0, 0.0, 0.7325494252469228, 0.0, 0.0, 0.0]

Test plan
- `objectid(repack(t).caches[i]) != objectid(p.caches[i])` for all caches

Note
Draft — please ignore until reviewed by @ChrisRackauckas.