Skip to content

NonlinearSolveBaseEnzymeExt: accumulate non-Tunable cotangent fields (#935)#936

Merged
ChrisRackauckas merged 3 commits into
SciML:masterfrom
ChrisRackauckas-Claude:accum-tangent-non-tunable-portions
May 24, 2026
Merged

NonlinearSolveBaseEnzymeExt: accumulate non-Tunable cotangent fields (#935)#936
ChrisRackauckas merged 3 commits into
SciML:masterfrom
ChrisRackauckas-Claude:accum-tangent-non-tunable-portions

Conversation

@ChrisRackauckas-Claude
Copy link
Copy Markdown
Contributor

@ChrisRackauckas-Claude ChrisRackauckas-Claude commented May 24, 2026

Note

Draft — please ignore until reviewed by @ChrisRackauckas.

Summary

When SciMLSensitivity.steadystatebackpass runs under diff_tunables = Val(false) — which is the deliberate sensealg choice when the parameter object carries caches (e.g. MTKParameters from an SCCNonlinearProblem initialization, where SCC sub-problem solutions are propagated via explicitfuns!) — the cotangent it returns is a structured MTKParameters whose meaningful gradient contribution lives in the caches field, not in tunable.

The previous _accum_tangent!'s SciMLStructure→SciMLStructure branch only round-tripped the Tunable portion via SciMLStructures.canonicalize(Tunable(), …) / replace!, so every non-Tunable field of darg (caches/initials/discrete/constant/nonnumeric) was silently dropped on accumulation.

This PR walks every non-Tunable field of darg via the existing _accum_nested! helper when the inner adjoint produced a structured cotangent. Addresses the accumulation half of #935.

Changes (3 commits)

  1. _accum_tangent! SciMLStructure→SciMLStructure branch walks every non-Tunable field of darg after the Tunable accumulation, via _accum_nested!.
  2. diff_tunables::Bool kwarg gates the walk. true (default) preserves prior behavior; false walks the non-Tunable fields.
  3. Reverse-rule predicate mirrors SciMLSensitivity.automatic_sensealg_choice locally: if the outer sensealg is nothing (default) but prob.p is a SciMLStructure with non-empty caches, the inner adjoint will pick diff_tunables = Val(false), so this rule passes diff_tunables = false into _accum_tangent!. Otherwise honors the explicit sensealg.diff_tunables if present.

Diagnosis

Trace, with MTKParameters having non-empty caches:

  1. SciMLSensitivity.automatic_sensealg_choice (concrete_solve.jl:312-336) detects _has_caches=true and chooses SteadyStateAdjoint(diff_tunables = Val(false)).
  2. steadystatebackpass (concrete_solve.jl:2601-2611) calls adjoint_sensitivities and, under EnzymeOriginator + Val(false), returns the raw dp_full::MTKParameters (preserving the structured tangent including caches).
  3. That MTKParameters flows into Enzyme's reverse rule on solve_up and lands as darg in _accum_tangent!.
  4. Bug: the SciMLStructure→SciMLStructure branch only handled Tunable. The caches portion of the cotangent — the actual gradient signal in the SCC case — was discarded silently.

Verification

  • Pkg.test("NonlinearSolveBase") passes — 28/28 on Julia 1.12.4.
  • Diagnostic instrumentation confirmed _accum_tangent! now receives diff_tunables = false and walks darg.caches into the shadow on the desauty SCC MWE. The walk persists across calls (verified by inspecting dval.caches_before on the second call).
  • CI green on NonlinearSolveBase.

Not in scope for this PR

The desauty use_scc=true end-to-end Enzyme test still produces a zero gradient. The cotangent arriving at _accum_tangent! has darg.tunable = 0 because in MTK's SCC encoding the parent's tunable parameters enter sub-problems via caches (the CacheWriter / explicitfuns! copies them in), so the sub-problem's adjoint structurally returns tunable=0. Enzyme's native reverse pass through MTK's CacheWriter (a RuntimeGeneratedFunction-wrapped explicitfun) doesn't propagate the caches cotangent back to the parent tunable shadow. That's an upstream Enzyme / MTK / SciMLSensitivity issue (needs either a _concrete_solve_adjoint(::SCCNonlinearProblem, …) dispatch or a working Enzyme reverse for CacheWriter) — separate from this fix. Mooncake works on the same MWE because it natively differentiates through scc_solve_up and CacheWriter rather than relying on the IFT/structured-cotangent path.

A separately-discovered alias bug in the return-value shadow (make_zero allocating a fresh sol.prob.p instead of preserving alias to the outer shadow) is broken out into #937, independent of this PR.

Bumps

NonlinearSolveBase 2.26.02.26.1 (patch — bugfix only, no API change).

When SciMLSensitivity's `steadystatebackpass` runs under `diff_tunables = Val(false)` — which is the deliberate choice when `MTKParameters` carries SCC `caches` so those buffers participate in the adjoint — the cotangent it returns is a structured `MTKParameters` with non-zero contributions in `caches` (and potentially other non-Tunable portions), not just `tunable`.

The previous `_accum_tangent!` SciMLStructure→SciMLStructure branch only round-tripped the Tunable slice via `SciMLStructures.canonicalize(Tunable(), ...)` / `replace!`, silently dropping every non-Tunable field of `darg`. That manifested downstream (see SciML#935) as Enzyme gradients of `solve` on `SCCNonlinearProblem`-backed initialization producing zero when the meaningful gradient flowed through `caches`-mediated coupling.

Walk the remaining fields of `darg` via the existing `_accum_nested!` helper so caches/initials/discrete/constant/nonnumeric contributions are accumulated as well. Tests added with a minimal `MockMTKParams` SciMLStructure-compliant type that exercises caches accumulation directly through the extension's internal helper.

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Replace the unconditional "walk every non-Tunable field whenever both
sides are SciMLStructures" with an explicit `diff_tunables::Bool` kwarg
on `_accum_tangent!` that mirrors the sensealg field of the same name.

`diff_tunables = true` (default) — backpass returned a Tunable-only
cotangent; leave non-Tunable fields of any structured `darg` alone.

`diff_tunables = false` — backpass returned a structured cotangent
whose meaningful contribution may live in `caches`/`initials`/etc.;
walk every non-Tunable field in.

The `reverse` rule derives this from `sensealg.val.diff_tunables`. Test
extended to exercise both directions.

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
@ChrisRackauckas-Claude
Copy link
Copy Markdown
Contributor Author

Restructured per review: _accum_tangent! now takes a diff_tunables::Bool kwarg matching the sensealg field name. true (default) → Tunable-only accumulation. false → walk every non-Tunable field of darg (caches, initials, …). The reverse rule mirrors sensealg.val.diff_tunables into the kwarg.

Tests cover both directions: dbefada. Pkg.test("NonlinearSolveBase") 28/28 pass.

The Enzyme `reverse` rule's `sensealg` arg is whatever the user passed
to `solve` — typically `Nothing`. The *inner* sensealg that ends up
producing the cotangent is resolved later inside
`_concrete_solve_adjoint`'s call to `automatic_sensealg_choice`, which
picks `SteadyStateAdjoint(diff_tunables = Val(false))` whenever
`prob.p` is a SciMLStructure with non-empty `caches` (e.g. an
MTKParameters tied to an SCCNonlinearProblem's `explicitfuns!`
coupling).

The previous wiring only consulted the outer sensealg and therefore
defaulted to `diff_tunables = true` in the SCC case, so the
non-Tunable walk in `_accum_tangent!` was suppressed exactly when it
was supposed to fire. Reproduce the same predicate locally
(`isscimlstructure(p) && hasfield(:caches) && !isempty(p.caches)`)
when the outer sensealg lacks an explicit `diff_tunables` field, so
the rule mirrors what the inner adjoint will actually do.

Verified with a temporary `@info` probe that
`_accum_tangent!` now receives `diff_tunables=false` and accumulates
the structured cotangent through `darg.caches`. (The desauty
`use_scc=true` end-to-end gradient remains zero, but for an unrelated
upstream reason: the inner sensealg backpass returns `darg.tunable = 0`,
which is independent of this code path.)

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
@ChrisRackauckas-Claude
Copy link
Copy Markdown
Contributor Author

End-to-end test results — partial success, blockage moved upstream.

What works

Unit tests (28/28) confirm _accum_tangent! correctly gates on diff_tunables:

  • diff_tunables=false → caches walked into shadow alongside Tunable
  • diff_tunables=true → only Tunable touched

f78b70e then fixes a wiring bug in the reverse rule: the outer sensealg is whatever the user passed (typically Nothing), but the inner sensealg that produces the cotangent is resolved by automatic_sensealg_choice. To predict it, the rule now reproduces the same predicate locally — checking isscimlstructure(prob.p) && hasfield(:caches) && !isempty(prob.p.caches) — so diff_tunables=false fires correctly under MTKParameters-with-caches.

Verified with diagnostic probe

A temporary @info confirmed end-to-end through the desauty use_scc=true MWE that _accum_tangent! now receives diff_tunables=false and walks the structured cotangent into the shadow:

_accum_tangent! ENTRY:
  darg.caches             = ([-0.31006],)        ← non-zero
  dval.caches_before      = ([-1.5],)             ← Enzyme had accumulated prior
  (after my walk: dval.caches = -1.81)            ← persisted to next call

What doesn't work — but isn't this PR

The final gradient is still zero. Root cause is upstream: in the same diagnostic,

darg.tunable from steadystatebackpass = [-0.0, -0.0, ..., -0.0]   (all zero)

The sub-problem's SteadyStateAdjoint(diff_tunables=Val(false)) backpass — running ZygoteVJP after ReverseDiffVJP failed — produces a structured cotangent whose caches slice is populated but whose tunable slice is silently zero. Physically the sub-problems' f does depend on p.tunable (b, c), so this should be non-zero. My fix has nothing to walk in tunable because the cotangent isn't created upstream.

This PR is necessary but not sufficient for SCC. It is correctly-scoped to the Tunable-vs-Caches accumulation question, and the caches-walk machinery will be required once the upstream darg.tunable = 0 issue is fixed (separate investigation needed in SciMLSensitivity's steadystate_adjoint.jl / adjointdiffcache / ZygoteVJP pullback for MTKParameters under use_full_p=true).

Tests still 28/28 green.

@ChrisRackauckas-Claude
Copy link
Copy Markdown
Contributor Author

Adding 72aff4d as an independent finding — not a fix for the upstream issue, but it surfaced during the investigation.

Separate alias bug in the return-value shadow

The augmented_primal builds the return shadow with Enzyme.make_zero(res[1]). That recursively allocates fresh zero buffers for every mutable field of the NonlinearSolution, including sol.prob.p and sol.prob.u0 — which alias the outer caller's active p/u0 shadows. Confirmed with objectid(make_zero(sol).prob.p) != objectid(sol.prob.p). Severing that aliasing means any cotangent a downstream consumer writes into sol.prob.p or .u0 lands in a dangling buffer the outer tape isn't tracking, silently dropping the contribution.

Fix: a thin _make_solution_zero that pre-seeds the make_zero IdDict so prob.p and prob.u0 map to themselves, short-circuiting recursion. sol.u (the derivative-carrying field) still gets a fresh zero buffer. Guards nothing parameters and non-mutable values.

Unit test in lib/NonlinearSolveBase/test/enzyme_make_solution_zero.jl asserts naive make_zero breaks aliasing, the helper preserves it (`===` + `objectid`), sol.u is fresh-zero, and the nothing-p path doesn't crash.

Tests: `Pkg.test("NonlinearSolveBase")` 38/38 in 42s.

Polyalg MixedDuplicated is upstream Enzyme

Tracing the polyalg MethodError: no method matching MixedDuplicated(::NonlinearSolution, ::NonlinearSolution) led into Enzyme's own `create_activity_wrapper` at `Enzyme/.../rules/jitrules.jl:14`, invoked from `runtime_generic_augfwd` on the downstream `SciMLBase.wrap_sol(::NonlinearSolution, …)` call. The wrapper emits `MixedDuplicated(primarg, shadowarg)` with `shadowarg::NonlinearSolution` instead of `Base.RefValue{NonlinearSolution}`. The `Ref` wrap is supposed to be the caller's responsibility but isn't done on the type-unstable generic path that fires when the solution carries `AutoSpecializeCallable{FunctionWrappersWrapper{…}}` in its type parameters. Pinning a concrete algorithm like `NewtonRaphson()` removes those wrappers and the activity analysis no longer classifies the return as `MixedState`, side-stepping the bug. Not fixable from `NonlinearSolveBase` regardless of what we do with `make_zero` — needs an EnzymeAD/Enzyme.jl issue.

@ChrisRackauckas-Claude ChrisRackauckas-Claude force-pushed the accum-tangent-non-tunable-portions branch from 72aff4d to f78b70e Compare May 24, 2026 18:48
@ChrisRackauckas ChrisRackauckas marked this pull request as ready for review May 24, 2026 19:21
@ChrisRackauckas ChrisRackauckas merged commit a32ebd6 into SciML:master May 24, 2026
196 of 242 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants