
MTKParameters repack: allocate fresh caches buffers (fixes Enzyme silent-zero on SCC init)#4557

Closed
ChrisRackauckas-Claude wants to merge 1 commit into SciML:master from ChrisRackauckas-Claude:enzyme-fresh-caches-on-repack


@ChrisRackauckas-Claude

Summary

Fixes a long-standing Enzyme reverse-mode AD bug where gradients through SCC-initialized ODE problems silently return zero (or wrong-but-nonzero values).

The bug

SciMLStructures.canonicalize(::Tunable, ::MTKParameters) returns a repack closure that delegates to SciMLStructures.replace(::Tunable, p, newvals), which is shallow: @set! p.tunable = newvals; return p. Every other field of the returned MTKParameters (including the mutable Vector buffers in p.caches) is therefore aliased — by object identity — to the captured template p.
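The aliasing described above can be illustrated without any SciML packages. The following is a minimal sketch, assuming a hypothetical `ToyParams` struct and `shallow_replace` helper that stand in for `MTKParameters` and `SciMLStructures.replace`; neither name exists in ModelingToolkit.

```julia
# Hypothetical stand-in for MTKParameters: one tunable vector plus a tuple
# of mutable cache buffers. Not the real MTK type.
struct ToyParams{C<:Tuple}
    tunable::Vector{Float64}
    caches::C
end

# Shallow replace (assumed analogue of `@set! p.tunable = newvals`):
# swap `tunable`, reuse every other field as-is.
shallow_replace(p::ToyParams, newvals) = ToyParams(newvals, p.caches)

template = ToyParams([1.0, 2.0], ([0.0, 0.0],))
repacked = shallow_replace(template, [3.0, 4.0])

# The cache Vector is the same object in both structs.
aliased = repacked.caches[1] === template.caches[1]

# A write through the repacked struct is visible through the template.
repacked.caches[1][1] = 42.0
leaked = template.caches[1][1]
```

Here `aliased` is `true` and the write through `repacked` shows up in `template`, which is the object-identity sharing the description refers to.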

When a loss closure of the form

let iprob = iprob, irepack = irepack
    p -> begin
        iprob2 = remake(iprob, p = irepack(p))
        sol = solve(iprob2)
        sum(sol.u)
    end
end

is differentiated with Enzyme.gradient(Reverse, Const(loss), tunable), Enzyme sees the same Vector identity on both sides of the irepack call. Because no separate shadow buffer is allocated for the aliased cache, cotangents flowing into p.caches[i] during reverse-mode either silently produce zero (silent-zero) or read back the primal as the cotangent (wrong-but-nonzero, roughly Jᵀ × primal_value). This is the root cause of the long-standing desauty use_scc=true Enzyme zero-gradient bug.

The aliasing was confirmed via a cache_alias_break_test.jl experiment: when the loss explicitly copies caches via ConstructionBase.setproperties(p_initial; caches = ntuple(i -> copy(p_initial.caches[i]), …)), Enzyme matches FiniteDiff to 8 significant figures.

The fix

lib/ModelingToolkitBase/src/systems/parameter_buffer.jl: have the repack closure for the Tunable portion construct a new MTKParameters with freshly allocated caches buffers (ntuple(i -> copy(p.caches[i]), length(p.caches))). Other fields (initials, discrete, constant, nonnumeric) are left shared with the template, matching previous behavior, since they have not been demonstrated to cause AD issues.

function SciMLStructures.canonicalize(::SciMLStructures.Tunable, p::MTKParameters)
    arr = p.tunable
    repack = let p = p
        function (new_val)
            return MTKParameters(
                new_val,
                p.initials,
                p.discrete,
                p.constant,
                p.nonnumeric,
                ntuple(i -> copy(p.caches[i]), length(p.caches)),
            )
        end
    end
    return arr, repack, true
end
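The effect of the fresh-caches repack can be checked on a toy type. This is a sketch under stated assumptions: `FreshParams` and `make_repack` are hypothetical names that only mirror the shape of the closure above, and identity is checked with `!==` rather than by comparing `objectid` values.

```julia
# Hypothetical stand-in for MTKParameters (not the real MTK type).
struct FreshParams{C<:Tuple}
    tunable::Vector{Float64}
    caches::C
end

# Assumed analogue of the patched repack: copy each cache buffer so the
# returned struct owns independent storage.
function make_repack(p::FreshParams)
    return function (new_val)
        fresh = ntuple(i -> copy(p.caches[i]), length(p.caches))
        return FreshParams(new_val, fresh)
    end
end

template = FreshParams([1.0, 2.0], ([0.0, 0.0], [5.0]))
repack = make_repack(template)
repacked = repack([3.0, 4.0])

# Every cache buffer is a distinct object from the template's buffer.
all_fresh = all(i -> repacked.caches[i] !== template.caches[i],
                1:length(template.caches))

# Writes to the repacked caches no longer reach the template.
repacked.caches[1][1] = 42.0
template_untouched = template.caches[1][1] == 0.0
```

The cost of this design is one allocation per cache buffer on every `repack` call, which is the trade-off against the lightweight rewrap.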

Other shallow-replace sites (not changed here)

The symmetric canonicalize(::Initials, ::MTKParameters) and the generic for (Portion, …) loop covering Discrete/Constants/Nonnumeric/Caches use the same shallow-replace pattern (@set! p.field = newvals). If downstream AD users hit analogous problems through those portions, they will need similar treatment — but the demonstrated bug is in Tunable, so this patch is scoped accordingly. Happy to extend if reviewers prefer symmetry.

Verification

  1. Alias-break smoke check (objectid before/after):

    caches[1]: template=7387335378339599272 repacked=5186687618985786704 shared=false
    PASS: all caches are fresh allocations
    
  2. End-to-end desauty MWE (the bug-reproducing case): Enzyme reverse-mode now matches FiniteDiff:

    • FD: [0.0, 0.0, 0.3662747126219529, 0.3100659003339509, 0.0, 0.0, 0.7325494252439058, 0.0, 0.0, 0.0]
    • Enzyme: [0.0, 0.0, 0.3662747126234614, 0.3100659003292303, 0.0, 0.0, 0.7325494252469228, 0.0, 0.0, 0.0]
    • Matches FD to 8 sig figs (was zero before).

Test plan

  • Smoke check: objectid(repack(t).caches[i]) != objectid(p.caches[i]) for all caches
  • End-to-end Enzyme gradient on desauty SCC-initialized ODE matches FD to 8 sig figs
  • MTKBase test suite (currently running locally; will update if regressions surface)
  • CI

Note

Draft — please ignore until reviewed by @ChrisRackauckas.

The `repack` closure returned by `SciMLStructures.canonicalize(::Tunable,
::MTKParameters)` previously delegated to `SciMLStructures.replace`, which
only sets the `tunable` field via `@set!` and keeps every other field
(including the `caches` tuple of mutable `Vector`s) aliased to the captured
template `p`. Each call to `repack(new_tunable)` therefore produced an
`MTKParameters` whose `caches` Vector buffers were shared identities with
the closure's captured `iprob.p.caches`.

This aliasing breaks Enzyme reverse-mode AD when a loss closure (e.g.
`p -> sum(solve(remake(iprob, p = irepack(p))).u)`) is differentiated.
Because Enzyme sees the same Vector identity on both sides of the call,
it does not allocate a separate shadow buffer for the cache. Cotangents
flowing into `p.caches[i]` either silently produce zero (silent-zero) or
read back the primal value as the cotangent (wrong-but-nonzero,
roughly `Jᵀ × primal_value`). This is the root cause of the long-standing
desauty `use_scc=true` Enzyme zero-gradient bug.

The fix: have `repack` construct a new `MTKParameters` with freshly
allocated `caches` buffers (`ntuple(i -> copy(p.caches[i]), …)`), so each
repack produces its own independent cache storage. Other fields
(`initials`, `discrete`, `constant`, `nonnumeric`) are left shared with
the template, matching previous behavior, since they have not been
demonstrated to cause AD issues.

Note: the symmetric `canonicalize(::Initials, ::MTKParameters)` and the
generic `for (Portion, …)` loop covering `Discrete`/`Constants`/
`Nonnumeric`/`Caches` use the same shallow-replace pattern. If downstream
AD users hit analogous problems through those portions, they will need
similar treatment — but the demonstrated bug is in `Tunable`, so this
patch is scoped accordingly.

Verification:
- `objectid(repack(t).caches[1]) != objectid(p.caches[1])` (alias broken)
- End-to-end desauty MWE: Enzyme reverse-mode now matches FiniteDiff to
  8 significant figures on the SCC-initialized ODE (was zero before).

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
@ChrisRackauckas-Claude
Author

Holding this PR: the fresh-caches repack in this diff treats the symptom. It is intentional (and useful) that the repack returned by SciMLStructures.canonicalize(::Tunable, ::MTKParameters) yields an MTKParameters whose non-tunable fields alias the captured template; repack is a lightweight rewrap, not a deep copy. That Enzyme misbehaves on this aliased pattern points to an upstream Enzyme issue, not an MTK design bug.

The reproducer cache_alias_break_test.jl shows that breaking the alias in user code gives the correct gradient to 8 sig figs — but the broken case has aliased mutable buffers captured in a closure, which Enzyme should handle. Will extract a minimal MTK-free reproducer and file it against EnzymeAD/Enzyme.jl.

Won't merge this PR; closing once the Enzyme issue is filed and acknowledged.

@ChrisRackauckas-Claude
Author

Minimal upstream reproducer filed: EnzymeAD/Enzyme.jl#3124

The bug is purely in Enzyme: when a closure-captured mutable buffer is aliased into a new struct, set_runtime_activity(Reverse) silently returns a zero gradient (plain Reverse errors). The MWE is 40 lines with no SciML deps. The desauty manifestation is just MTK's repack producing an MTKParameters whose caches field aliases the closure-captured iprob.p.caches, which is exactly the pattern the upstream MWE captures.

Closing this PR as soon as Enzyme.jl#3124 lands — this MTK-side workaround is the wrong layer to fix it.

@ChrisRackauckas-Claude
Author

Closing — this fix is treating a symptom, not the real issue. The repack-aliasing-caches behavior is intended MTK semantics (SciMLStructures.canonicalize(::Tunable, p)'s repack is a lightweight rewrap, not a deep copy). The downstream desauty SCC init Enzyme.gradient bug it was working around is the user's misuse of the Enzyme API (Const(closure_with_mutable_captures) is documented to not allocate shadows for captures). Fixed properly downstream in SciML/SciMLSensitivity.jl#1454 by restructuring the test loss to use Enzyme.autodiff with explicit Duplicated(iprob, diprob).
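The general idea of passing a mutable buffer as an explicit `Duplicated` argument, instead of capturing it in a closure marked `Const`, can be sketched on a toy function. This is not the SciMLSensitivity test code and it does not reproduce the desauty bug; `toy_loss`, `cache`, and `dcache` are hypothetical names chosen for illustration, and the example assumes the Enzyme.jl package is installed.

```julia
using Enzyme

# Hypothetical loss: writes into a scratch buffer, then reduces over it.
# The buffer is an explicit argument rather than a closure capture.
function toy_loss(x::Vector{Float64}, cache::Vector{Float64})
    for i in eachindex(x)
        cache[i] = 2 * x[i]
    end
    s = 0.0
    for i in eachindex(cache)
        s += cache[i]^2
    end
    return s
end

x = [1.0, 2.0, 3.0]
dx = zero(x)          # shadow for x; receives the gradient
cache = zeros(3)
dcache = zero(cache)  # explicit shadow for the mutable scratch buffer

# Both the input and the scratch buffer get their own shadow storage.
Enzyme.autodiff(Reverse, toy_loss, Active,
                Duplicated(x, dx), Duplicated(cache, dcache))
```

Since the loss is the sum of `(2 * x[i])^2`, the analytic gradient is `8 .* x`, which gives a reference value to compare `dx` against.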

