Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,11 @@ Lux = "1"
Markdown = "1.10"
ModelingToolkit = "10, 11"
Mooncake = "0.5.24"
Reactant = "0.2.22"
NLsolve = "4.5.1"
NonlinearSolve = "3.0.1, 4"
SCCNonlinearSolve = "1"
Optimization = "4, 5"
OptimizationNLopt = "0.3"
OptimizationOptimisers = "0.3"
OrdinaryDiffEq = "6.108, 7"
OrdinaryDiffEqCore = "3.26, 4, 5"
Expand All @@ -107,6 +107,7 @@ PreallocationTools = "1.1.1"
QuadGK = "2.9.1"
Random = "1.10"
RandomNumbers = "1.5.3"
Reactant = "0.2.22"
RecursiveArrayTools = "3.27.2, 4"
Reexport = "1.0"
ReverseDiff = "1.15.1"
Expand Down Expand Up @@ -141,10 +142,10 @@ Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
SCCNonlinearSolve = "9dfe8606-65a1-4bb3-9748-cb89d1561431"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationNLopt = "4e6fcdb7-1186-4e1f-a706-475e75c168bb"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
OrdinaryDiffEqFIRK = "5960d6e9-dd7a-4743-88e7-cf307b64f125"
Expand All @@ -154,6 +155,7 @@ OrdinaryDiffEqRosenbrock = "43230ef6-c299-4910-a778-202eb28ce4ce"
OrdinaryDiffEqSDIRK = "2d112036-d095-4a1e-ab9a-08536f3ecdbf"
OrdinaryDiffEqStabilizedRK = "358294b1-0aab-51c3-aafe-ad5ab194a2ad"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Expand All @@ -162,4 +164,4 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AlgebraicMultigrid", "Aqua", "Calculus", "ComponentArrays", "DelayDiffEq", "DifferentiationInterface", "Distributed", "ExplicitImports", "Lux", "ModelingToolkit", "Mooncake", "NLsolve", "NonlinearSolve", "Optimization", "OptimizationOptimisers", "OrdinaryDiffEq", "OrdinaryDiffEqFIRK", "OrdinaryDiffEqLowOrderRK", "OrdinaryDiffEqNonlinearSolve", "OrdinaryDiffEqRosenbrock", "OrdinaryDiffEqSDIRK", "OrdinaryDiffEqStabilizedRK", "Pkg", "Reactant", "SCCNonlinearSolve", "SafeTestsets", "SparseArrays", "StableRNGs", "SteadyStateDiffEq", "StochasticDiffEq", "Test"]
test = ["AlgebraicMultigrid", "Aqua", "Calculus", "ComponentArrays", "DelayDiffEq", "DifferentiationInterface", "Distributed", "ExplicitImports", "Lux", "ModelingToolkit", "Mooncake", "NLsolve", "NonlinearSolve", "Optimization", "OptimizationOptimisers", "OptimizationNLopt", "OrdinaryDiffEq", "OrdinaryDiffEqFIRK", "OrdinaryDiffEqLowOrderRK", "OrdinaryDiffEqNonlinearSolve", "OrdinaryDiffEqRosenbrock", "OrdinaryDiffEqSDIRK", "OrdinaryDiffEqStabilizedRK", "Pkg", "Reactant", "SCCNonlinearSolve", "SafeTestsets", "SparseArrays", "StableRNGs", "SteadyStateDiffEq", "StochasticDiffEq", "Test"]
1 change: 1 addition & 0 deletions docs/pages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ pages = [
"Manual and APIs" => Any[
"manual/differential_equation_sensitivities.md",
"manual/nonlinear_solve_sensitivities.md",
"manual/optimization_sensitivities.md",
"manual/direct_forward_sensitivity.md",
"manual/direct_adjoint_sensitivities.md",
],
Expand Down
20 changes: 20 additions & 0 deletions docs/src/manual/optimization_sensitivities.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# [Sensitivity Algorithms for Optimization Problems](@id sensitivity_optimization)

SciMLSensitivity provides adjoint algorithms for differentiating through the optimum
`u*(p)` of a parameterized [`OptimizationProblem`](https://docs.sciml.ai/Optimization/stable/),
giving `dG/dp` for any downstream loss `G(u*(p))` via implicit differentiation rather
than by differentiating through the iterations of the optimizer.

- `UnconstrainedOptimizationAdjoint` handles unconstrained problems by treating the
stationarity condition `∇f(u*, p) = 0` as a steady-state nonlinear system and
reusing the `SteadyStateAdjoint` machinery.
- `OptimizationAdjoint` handles problems with equality, two-sided inequality, and
variable-bound constraints by implicit differentiation of the KKT first-order
optimality conditions. It detects the active inequality set at `u*`, recovers
multipliers from the stationarity equation, and solves a single symmetric KKT
linear system to produce the adjoint.

```@docs
UnconstrainedOptimizationAdjoint
OptimizationAdjoint
```
1 change: 1 addition & 0 deletions ext/SciMLSensitivityMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module SciMLSensitivityMooncakeExt
using SciMLSensitivity: SciMLSensitivity, FakeIntegrator
using Mooncake: Mooncake
import SciMLSensitivity: get_paramjac_config, get_cb_paramjac_config, mooncake_run_ad,

MooncakeVJP, MooncakeLoaded,
DiffEqBase, MooncakeAdjoint, _init_originator_gradient
using SciMLSensitivity: SciMLBase, SciMLStructures, canonicalize, Tunable, isscimlstructure,
Expand Down
3 changes: 2 additions & 1 deletion src/SciMLSensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ include("interpolating_adjoint.jl")
include("quadrature_adjoint.jl")
include("gauss_adjoint.jl")
include("callback_tracking.jl")
include("optimization_adjoint.jl")
include("concrete_solve.jl")
include("second_order.jl")
include("steadystate_adjoint.jl")
Expand All @@ -103,7 +104,7 @@ export BacksolveAdjoint, QuadratureAdjoint, GaussAdjoint, GaussKronrodAdjoint,
TrackerAdjoint, ZygoteAdjoint, ReverseDiffAdjoint, MooncakeAdjoint,
EnzymeAdjoint, ForwardSensitivity, ForwardDiffSensitivity,
ForwardDiffOverAdjoint,
SteadyStateAdjoint, UnconstrainedOptimizationAdjoint,
SteadyStateAdjoint, UnconstrainedOptimizationAdjoint, OptimizationAdjoint,
ForwardLSS, AdjointLSS, NILSS, NILSAS

export second_order_sensitivities, second_order_sensitivity_product
Expand Down
1 change: 1 addition & 0 deletions src/adjoint_common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,7 @@ function mooncake_run_ad(paramjac_config, y, p, t, λ)
error(msg)
end


function get_pf(::ReactantVJP, prob, _f)
isinplace = DiffEqBase.isinplace(prob)
isRODE = isa(prob, RODEProblem)
Expand Down
105 changes: 102 additions & 3 deletions src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2687,10 +2687,12 @@ function SciMLBase._concrete_solve_adjoint(
opt_f = _prob.f

if opt_f.grad === nothing
grad_fn = if sensealg.objective_ad isa Bool && !sensealg.objective_ad
(G, u, p) -> FiniteDiff.finite_difference_gradient!(G, Base.Fix2(opt_f, p), u)
else
grad_fn = if alg_autodiff(sensealg)
(G, u, p) -> ForwardDiff.gradient!(G, Base.Fix2(opt_f, p), u)
else
(G, u, p) -> FiniteDiff.finite_difference_gradient!(
G, Base.Fix2(opt_f, p), u, diff_type(sensealg)
)
end
nlprob = NonlinearProblem(grad_fn, opt_sol.u, p)
else
Expand Down Expand Up @@ -2808,6 +2810,103 @@ function SciMLBase._concrete_solve_adjoint(
return out, steadystatebackpass
end

# Adjoint rule for solving an `AbstractOptimizationProblem` with an
# `OptimizationAdjoint` sensitivity algorithm.
#
# The problem is rebuilt with the supplied `u0` and `p`, solved with `alg`, and the
# solution is returned together with a pullback (`optimizationbackpass`) that maps an
# incoming cotangent `Δ` to a tuple of cotangents in which only the parameter slot is
# non-trivial. The parameter cotangent is obtained from `adjoint_sensitivities`.
#
# Arguments visible here:
#   - `prob`, `alg`, `u0`, `p`: problem, optimizer and the values used to `remake` it.
#   - `sensealg`: forwarded unchanged to `adjoint_sensitivities`.
#   - `originator`: selects the layout of the returned cotangent tuple.
#   - `save_idxs`: optional indices; when given, the returned solution is wrapped via
#     `SciMLBase.sensitivity_solution` and the cotangent is restricted to those indices.
#   - `args...`, `kwargs...`: forwarded to `solve`.
# The type parameters `CS`, `AD`, `FDT` are not referenced in the body.
function SciMLBase._concrete_solve_adjoint(
prob::AbstractOptimizationProblem,
alg, sensealg::OptimizationAdjoint{CS, AD, FDT},
u0, p, originator::SciMLBase.ADOriginator,
args...; save_idxs = nothing, kwargs...
) where {CS, AD, FDT}

# Forward pass: solve the optimization problem at the requested `u0` / `p`.
_prob = remake(prob, u0 = u0, p = p)
opt_sol = solve(_prob, alg, args...; kwargs...)

# `Colon()` stands in for "all indices" when no `save_idxs` was requested.
_save_idxs = save_idxs === nothing ? Colon() : save_idxs
out = if save_idxs === nothing
opt_sol
else
SciMLBase.sensitivity_solution(opt_sol, opt_sol[_save_idxs])
end

# Build `repack_adjoint`, which maps a flat parameter gradient back to a 1-tuple
# holding a cotangent shaped like `p`. Three cases: SciMLStructures parameters
# (pullback of the `Tunable` canonicalization), Functors-compatible parameters
# (the functor's reconstructor), and everything else (identity).
_, repack_adjoint = if isscimlstructure(p)
Zygote.pullback(p) do p
t, _, _ = canonicalize(Tunable(), p)
t
end
elseif isfunctor(p)
ps, re = Functors.functor(p)
ps, x -> (re(x),)
else
nothing, x -> (x,)
end

# Pullback: converts the cotangent `Δ` of the returned solution into cotangents
# for the arguments of the original solve call.
function optimizationbackpass(Δ)
Δ = Δ isa AbstractThunk ? unthunk(Δ) : Δ
# `df` writes the cotangent with respect to the solution into `_out`, restricted to
# `_save_idxs`. `Δ` may be a number, an array, or an object carrying the values in
# its `.u` field. The arguments `_u`, `_p`, `_t`, `_i` are not used in the body.
function df(_out, _u, _p, _t, _i)
return if _save_idxs isa Number
_out[_save_idxs] = Δ isa AbstractArray ? Δ[_save_idxs] : Δ.u[_save_idxs]
elseif Δ isa Number
@. _out[_save_idxs] = Δ
elseif Δ isa AbstractArray
@. _out[_save_idxs] = Δ[_save_idxs]
elseif isnothing(_out)
_out
else
@. _out[_save_idxs] = Δ.u[_save_idxs]
end
end
# Parameter gradient from the adjoint method; `df` is supplied as `dgdu` and
# `nothing` is passed as the second positional argument.
dp = adjoint_sensitivities(opt_sol, nothing; sensealg = sensealg, dgdu = df)

# Flatten `dp` to its tunable part and collect any cotangent that `Δ` itself carries
# for the parameters (`Δ.prob.p`). A plain array/number `Δ` carries none, so
# `Δtunables` is `nothing` in that branch.
dp, Δtunables = if Δ isa AbstractArray || Δ isa Number
dp, Δtunables = if isscimlstructure(dp)
dp, _, _ = canonicalize(Tunable(), dp)
dp, nothing
elseif isfunctor(dp)
dp, _ = Functors.functor(dp)
dp, nothing
else
dp, nothing
end
else
dp, Δtunables = if isscimlstructure(p)
# No direct parameter cotangent on `Δ`: only the adjoint contribution remains.
if (Δ.prob.p == ZeroTangent() || Δ.prob.p == NoTangent())
dp, _, _ = canonicalize(Tunable(), dp)
dp, nothing
else
# Overlay the fields of `Δ.prob.p` onto a `dp`-shaped structure so that the
# direct cotangent can be canonicalized the same way as `dp`.
Δp = setproperties(dp, to_nt(Δ.prob.p))
Δtunables, _, _ = canonicalize(Tunable(), Δp)
dp, _, _ = canonicalize(Tunable(), dp)
dp, Δtunables
end
elseif isfunctor(p)
dp, _ = Functors.functor(dp)
Δtunables, _ = Functors.functor(Δ.prob.p)
dp, Δtunables
else
dp, Δ.prob.p
end
end

# Add the direct parameter cotangent to the adjoint one; a `nothing` or empty
# `Δtunables` is passed to `Zygote.accum` as `nothing`.
dp = Zygote.accum(
dp, (isnothing(Δtunables) || isempty(Δtunables)) ? nothing : Δtunables
)

# The two layouts differ by one leading `NoTangent()`: for Tracker / ReverseDiff
# originators the parameter cotangent sits in the 4th slot, otherwise in the 5th.
# One `NoTangent()` is appended per extra positional argument in `args`.
return if originator isa SciMLBase.TrackerOriginator ||
originator isa SciMLBase.ReverseDiffOriginator
(
NoTangent(), NoTangent(), NoTangent(), repack_adjoint(dp)[1], NoTangent(),
ntuple(_ -> NoTangent(), length(args))...,
)
else
(
NoTangent(), NoTangent(), NoTangent(),
NoTangent(), repack_adjoint(dp)[1], NoTangent(),
ntuple(_ -> NoTangent(), length(args))...,
)
end
end

return out, optimizationbackpass
end

function fix_endpoints(sensealg, sol, ts)
@warn "Endpoints do not match. Return code: $(sol.retcode). Likely your time range is not a multiple of `saveat`. sol.t[end]: $(last(current_time(sol))), ts[end]: $(ts[end])"
Expand Down
36 changes: 30 additions & 6 deletions src/derivative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,28 @@ function jacobian(
return J
end

"""
    gradient(f, x::AbstractArray{<:Number}, alg::AbstractOverloadingSensitivityAlgorithm)

Return the gradient of `f` evaluated at `x`.

When `alg_autodiff(alg)` is true, the gradient is computed with `ForwardDiff.gradient`
on `unwrapped_f(f)`. Otherwise it is computed with
`FiniteDiff.finite_difference_gradient` on `f` itself, using `diff_type(alg)` as the
finite-difference type.
"""
function gradient(
        f, x::AbstractArray{<:Number},
        alg::AbstractOverloadingSensitivityAlgorithm
    )
    if alg_autodiff(alg)
        # Automatic differentiation path: differentiate the unwrapped function.
        raw_f = unwrapped_f(f)
        return ForwardDiff.gradient(raw_f, x)
    end
    # Finite-difference path: `f` is passed as-is.
    fd_kind = diff_type(alg)
    return FiniteDiff.finite_difference_gradient(f, x, fd_kind)
end

"""
    hessian(f, x::AbstractArray{<:Number}, alg::AbstractOverloadingSensitivityAlgorithm)

Return the Hessian of `f` evaluated at `x`.

When `alg_autodiff(alg)` is true, the Hessian is computed with `ForwardDiff.hessian`
on `unwrapped_f(f)`. Otherwise it is computed with
`FiniteDiff.finite_difference_hessian` on `f` itself, with that routine's default
finite-difference settings.
"""
function hessian(
        f, x::AbstractArray{<:Number},
        alg::AbstractOverloadingSensitivityAlgorithm
    )
    if alg_autodiff(alg)
        # Automatic differentiation path: differentiate the unwrapped function.
        raw_f = unwrapped_f(f)
        return ForwardDiff.hessian(raw_f, x)
    end
    # Finite-difference path: `f` is passed as-is and no difference type is forwarded.
    return FiniteDiff.finite_difference_hessian(f, x)
end

function jacobian!(
J::Nothing, f, x::AbstractArray{<:Number},
fx::Union{Nothing, AbstractArray{<:Number}},
Expand Down Expand Up @@ -677,10 +699,11 @@ function _vecjacobian!(
(; tunables, repack) = S.diffcache
end

u0 = state_values(prob)
if prob isa AbstractNonlinearProblem ||
if prob isa AbstractNonlinearProblem || prob isa SciMLBase.AbstractOptimizationCache ||
(
eltype(λ) <: eltype(u0) && t isa eltype(u0) &&
let u0 = state_values(prob)
eltype(λ) <: eltype(u0) && t isa eltype(u0)
end &&
compile_tape(sensealg.autojacvec)
)
tape = S.diffcache.paramjac_config
Expand Down Expand Up @@ -731,7 +754,8 @@ function _vecjacobian!(
end
end

if prob isa AbstractNonlinearProblem
_no_time = prob isa AbstractNonlinearProblem || prob isa SciMLBase.AbstractOptimizationCache
if _no_time
tu, tp = ReverseDiff.input_hook(tape)
else
if W === nothing
Expand All @@ -743,13 +767,13 @@ function _vecjacobian!(
output = ReverseDiff.output_hook(tape)
ReverseDiff.unseed!(tu) # clear any "leftover" derivatives from previous calls
ReverseDiff.unseed!(tp)
if !(prob isa AbstractNonlinearProblem)
if !_no_time
ReverseDiff.unseed!(tt)
end
W !== nothing && ReverseDiff.unseed!(tW)
ReverseDiff.value!(tu, y)
p isa SciMLBase.NullParameters || ReverseDiff.value!(tp, tunables)
if !(prob isa AbstractNonlinearProblem)
if !_no_time
ReverseDiff.value!(tt, [t])
end
W !== nothing && ReverseDiff.value!(tW, W)
Expand Down
Loading
Loading