NonlinearSolveBaseEnzymeExt: accumulate non-Tunable cotangent fields (#935)#936
Conversation
When SciMLSensitivity's `steadystatebackpass` runs under `diff_tunables = Val(false)` — which is the deliberate choice when `MTKParameters` carries SCC `caches` so those buffers participate in the adjoint — the cotangent it returns is a structured `MTKParameters` with non-zero contributions in `caches` (and potentially other non-Tunable portions), not just `tunable`. The previous `_accum_tangent!` SciMLStructure→SciMLStructure branch only round-tripped the Tunable slice via `SciMLStructures.canonicalize(Tunable(), ...)` / `replace!`, silently dropping every non-Tunable field of `darg`. That manifested downstream (see SciML#935) as Enzyme gradients of `solve` on `SCCNonlinearProblem`-backed initialization producing zero when the meaningful gradient flowed through `caches`-mediated coupling. Walk the remaining fields of `darg` via the existing `_accum_nested!` helper so caches/initials/discrete/constant/nonnumeric contributions are accumulated as well. Tests added with a minimal `MockMTKParams` SciMLStructure-compliant type that exercises caches accumulation directly through the extension's internal helper. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Replace the unconditional "walk every non-Tunable field whenever both sides are SciMLStructures" with an explicit `diff_tunables::Bool` kwarg on `_accum_tangent!` that mirrors the sensealg field of the same name. `diff_tunables = true` (default) — backpass returned a Tunable-only cotangent; leave non-Tunable fields of any structured `darg` alone. `diff_tunables = false` — backpass returned a structured cotangent whose meaningful contribution may live in `caches`/`initials`/etc.; walk every non-Tunable field in. The `reverse` rule derives this from `sensealg.val.diff_tunables`. Test extended to exercise both directions. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
|
Restructured per review: Tests cover both directions: dbefada. |
The Enzyme `reverse` rule's `sensealg` arg is whatever the user passed to `solve` — typically `Nothing`. The *inner* sensealg that ends up producing the cotangent is resolved later inside `_concrete_solve_adjoint`'s call to `automatic_sensealg_choice`, which picks `SteadyStateAdjoint(diff_tunables = Val(false))` whenever `prob.p` is a SciMLStructure with non-empty `caches` (e.g. an MTKParameters tied to an SCCNonlinearProblem's `explicitfuns!` coupling). The previous wiring only consulted the outer sensealg and therefore defaulted to `diff_tunables = true` in the SCC case, so the non-Tunable walk in `_accum_tangent!` was suppressed exactly when it was supposed to fire. Reproduce the same predicate locally (`isscimlstructure(p) && hasfield(:caches) && !isempty(p.caches)`) when the outer sensealg lacks an explicit `diff_tunables` field, so the rule mirrors what the inner adjoint will actually do. Verified with a temporary `@info` probe that `_accum_tangent!` now receives `diff_tunables=false` and accumulates the structured cotangent through `darg.caches`. (The desauty `use_scc=true` end-to-end gradient remains zero, but for an unrelated upstream reason: the inner sensealg backpass returns `darg.tunable = 0`, which is independent of this code path.) Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
|
End-to-end test results — partial success, blockage moved upstream. What worksUnit tests (28/28) confirm
Verified with diagnostic probeA temporary What doesn't work — but isn't this PRThe final gradient is still zero. Root cause is upstream: in the same diagnostic, The sub-problem's This PR is necessary but not sufficient for SCC. It is correctly-scoped to the Tunable-vs-Caches accumulation question, and the caches-walk machinery will be required once the upstream Tests still 28/28 green. |
|
Adding 72aff4d as an independent finding — not a fix for the upstream issue, but it surfaced during the investigation. Separate alias bug in the return-value shadowThe augmented_primal builds the return shadow with Fix: a thin Unit test in Tests: `Pkg.test("NonlinearSolveBase")` 38/38 in 42s. Polyalg
|
72aff4d to
f78b70e
Compare
Note
Draft — please ignore until reviewed by @ChrisRackauckas.
Summary
When
SciMLSensitivity.steadystatebackpassruns underdiff_tunables = Val(false)— which is the deliberate sensealg choice when the parameter object carriescaches(e.g.MTKParametersfrom anSCCNonlinearProbleminitialization, where SCC sub-problem solutions are propagated viaexplicitfuns!) — the cotangent it returns is a structuredMTKParameterswhose meaningful gradient contribution lives in thecachesfield, not intunable.The previous
_accum_tangent!'s SciMLStructure→SciMLStructure branch only round-tripped theTunableportion viaSciMLStructures.canonicalize(Tunable(), …)/replace!, so every non-Tunable field ofdarg(caches/initials/discrete/constant/nonnumeric) was silently dropped on accumulation.This PR walks every non-Tunable field of
dargvia the existing_accum_nested!helper when the inner adjoint produced a structured cotangent. Addresses the accumulation half of #935.Changes (3 commits)
_accum_tangent!SciMLStructure→SciMLStructure branch walks every non-Tunable field ofdargafter the Tunable accumulation, via_accum_nested!.diff_tunables::Boolkwarg gates the walk.true(default) preserves prior behavior;falsewalks the non-Tunable fields.SciMLSensitivity.automatic_sensealg_choicelocally: if the outersensealgisnothing(default) butprob.pis a SciMLStructure with non-emptycaches, the inner adjoint will pickdiff_tunables = Val(false), so this rule passesdiff_tunables = falseinto_accum_tangent!. Otherwise honors the explicitsensealg.diff_tunablesif present.Diagnosis
Trace, with
MTKParametershaving non-emptycaches:SciMLSensitivity.automatic_sensealg_choice(concrete_solve.jl:312-336) detects_has_caches=trueand choosesSteadyStateAdjoint(diff_tunables = Val(false)).steadystatebackpass(concrete_solve.jl:2601-2611) callsadjoint_sensitivitiesand, underEnzymeOriginator + Val(false), returns the rawdp_full::MTKParameters(preserving the structured tangent includingcaches).MTKParametersflows into Enzyme's reverse rule onsolve_upand lands asdargin_accum_tangent!.Tunable. The caches portion of the cotangent — the actual gradient signal in the SCC case — was discarded silently.Verification
Pkg.test("NonlinearSolveBase")passes — 28/28 on Julia 1.12.4._accum_tangent!now receivesdiff_tunables = falseand walksdarg.cachesinto the shadow on the desauty SCC MWE. The walk persists across calls (verified by inspectingdval.caches_beforeon the second call).Not in scope for this PR
The desauty
use_scc=trueend-to-end Enzyme test still produces a zero gradient. The cotangent arriving at_accum_tangent!hasdarg.tunable = 0because in MTK's SCC encoding the parent's tunable parameters enter sub-problems viacaches(theCacheWriter/explicitfuns!copies them in), so the sub-problem's adjoint structurally returns tunable=0. Enzyme's native reverse pass through MTK'sCacheWriter(aRuntimeGeneratedFunction-wrapped explicitfun) doesn't propagate the caches cotangent back to the parent tunable shadow. That's an upstream Enzyme / MTK / SciMLSensitivity issue (needs either a_concrete_solve_adjoint(::SCCNonlinearProblem, …)dispatch or a working Enzyme reverse forCacheWriter) — separate from this fix. Mooncake works on the same MWE because it natively differentiates throughscc_solve_upandCacheWriterrather than relying on the IFT/structured-cotangent path.A separately-discovered alias bug in the return-value shadow (
make_zeroallocating a freshsol.prob.pinstead of preserving alias to the outer shadow) is broken out into #937, independent of this PR.Bumps
NonlinearSolveBase2.26.0→2.26.1(patch — bugfix only, no API change).