From 715806056200166c2da2eaf95a92373ae76b9d2c Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 30 Mar 2026 15:08:21 -0400 Subject: [PATCH 01/19] add OptimizationNLopt to test dependencies --- Project.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ff19afeb9..96a5b4f00 100644 --- a/Project.toml +++ b/Project.toml @@ -146,6 +146,7 @@ NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" SCCNonlinearSolve = "9dfe8606-65a1-4bb3-9748-cb89d1561431" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" +OptimizationNLopt = "4e6fcdb7-1186-4e1f-a706-475e75c168bb" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" OrdinaryDiffEqFIRK = "5960d6e9-dd7a-4743-88e7-cf307b64f125" OrdinaryDiffEqLowOrderRK = "1344f307-1e59-4825-a18e-ace9aa3fa4c6" @@ -162,4 +163,4 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["AlgebraicMultigrid", "Aqua", "Calculus", "ComponentArrays", "DelayDiffEq", "DifferentiationInterface", "Distributed", "ExplicitImports", "Lux", "ModelingToolkit", "Mooncake", "NLsolve", "NonlinearSolve", "Optimization", "OptimizationOptimisers", "OrdinaryDiffEq", "OrdinaryDiffEqFIRK", "OrdinaryDiffEqLowOrderRK", "OrdinaryDiffEqNonlinearSolve", "OrdinaryDiffEqRosenbrock", "OrdinaryDiffEqSDIRK", "OrdinaryDiffEqStabilizedRK", "Pkg", "Reactant", "SCCNonlinearSolve", "SafeTestsets", "SparseArrays", "StableRNGs", "SteadyStateDiffEq", "StochasticDiffEq", "Test"] +test = ["AlgebraicMultigrid", "Aqua", "Calculus", "ComponentArrays", "DelayDiffEq", "DifferentiationInterface", "Distributed", "ExplicitImports", "Lux", "ModelingToolkit", "Mooncake", "NLsolve", "NonlinearSolve", "Optimization", "OptimizationOptimisers", "OptimizationNLopt", "OrdinaryDiffEq", "OrdinaryDiffEqFIRK", "OrdinaryDiffEqLowOrderRK", "OrdinaryDiffEqNonlinearSolve", "OrdinaryDiffEqRosenbrock", "OrdinaryDiffEqSDIRK", "OrdinaryDiffEqStabilizedRK", "Pkg", "Reactant", "SCCNonlinearSolve", "SafeTestsets", "SparseArrays", "StableRNGs", "SteadyStateDiffEq", "StochasticDiffEq", "Test"] From d25635b2d20e6e6fef38f3aa5cb7e78815a69662 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 30 Mar 2026 15:08:41 -0400 Subject: [PATCH 02/19] add optimization_adjoint file --- src/SciMLSensitivity.jl | 3 +- src/optimization_adjoint.jl | 144 ++++++++++++++++++++++++++++++++++++ 2 files changed, 146 insertions(+), 1 deletion(-) create mode 100644 src/optimization_adjoint.jl diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl index bd130ae2c..d6173e0ea 100644 --- a/src/SciMLSensitivity.jl +++ b/src/SciMLSensitivity.jl @@ -82,6 +82,7 @@ include("interpolating_adjoint.jl") include("quadrature_adjoint.jl") include("gauss_adjoint.jl") include("callback_tracking.jl") +include("optimization_adjoint.jl") include("concrete_solve.jl") include("second_order.jl") include("steadystate_adjoint.jl") @@ -103,7 +104,7 @@ export BacksolveAdjoint, QuadratureAdjoint, GaussAdjoint, GaussKronrodAdjoint, TrackerAdjoint, ZygoteAdjoint, ReverseDiffAdjoint, MooncakeAdjoint, EnzymeAdjoint, ForwardSensitivity, ForwardDiffSensitivity, ForwardDiffOverAdjoint, - SteadyStateAdjoint, UnconstrainedOptimizationAdjoint, + SteadyStateAdjoint, UnconstrainedOptimizationAdjoint, OptimizationAdjoint, ForwardLSS, AdjointLSS, NILSS, NILSAS export second_order_sensitivities, second_order_sensitivity_product diff --git a/src/optimization_adjoint.jl b/src/optimization_adjoint.jl new file mode 100644 index 000000000..d7b0640e1 --- /dev/null +++ b/src/optimization_adjoint.jl @@ -0,0 +1,144 @@ +# Differentiation helpers: dispatch on autodiff type parameter (Val{true} = ForwardDiff, +# Val{false} = FiniteDiff with the given FDT scheme) +_optimization_grad(f, x, ::Val{true}, ::FDT) where {FDT} = ForwardDiff.gradient(f, x) +function _optimization_grad(f, x, ::Val{false}, ::FDT) where {FDT} + FiniteDiff.finite_difference_gradient(f, x, FDT()) +end + +_optimization_jac(f, x, ::Val{true}, ::FDT) where {FDT} = ForwardDiff.jacobian(f, x) +function _optimization_jac(f, x, ::Val{false}, ::FDT) where {FDT} + FiniteDiff.finite_difference_jacobian(f, x, FDT()) +end + +_optimization_hess(f, x, ::Val{true}, ::FDT) where {FDT} = ForwardDiff.hessian(f, x) +function _optimization_hess(f, x, ::Val{false}, ::FDT) where {FDT} + FiniteDiff.finite_difference_jacobian( + y -> FiniteDiff.finite_difference_gradient(f, y, FDT()), x, FDT()) +end + +""" + OptimizationAdjointProblem(prob, opt_sol, sensealg, p) -> Jpx + +Compute the KKT-based parameter Jacobian `Jpx` (n_x × n_p) for a constrained +`OptimizationProblem`, where `Jpx[i,j] = ∂x*[i]/∂p[j]`. + +Uses the implicit function theorem applied to the KKT conditions: + + [∇²_xx L, J_x g^T, J_x h_I^T] [J_p x ] [∇²_xp L] + [J_x g, 0, 0 ] [J_p y ] = -[J_p g ] + [J_x h_I, 0, 0 ] [J_p z_I] [J_p h_I ] + +where g are equality constraints, h_I are active inequality constraints, and +y*, z_I* are the corresponding dual variables. +""" +function OptimizationAdjointProblem( + prob::AbstractOptimizationProblem, + opt_sol, + sensealg::OptimizationAdjoint{CS, AD, FDT}, + p + ) where {CS, AD, FDT} + x_star = opt_sol.u + ad_val = Val{AD}() + fdt_val = FDT() + + lcons = prob.lcons + ucons = prob.ucons + n_cons = length(lcons) + + # Wrap in-place cons!(res, x, p) into an out-of-place helper. + # promote_type handles ForwardDiff Dual propagation when either x or q contains duals. + function eval_cons(x, q) + T = promote_type(eltype(x), eltype(q)) + res = zeros(T, n_cons) + prob.f.cons(res, x, q) + return res + end + + # Classify constraints: equality where lcons[i] == ucons[i] + eq_idx = findall(i -> lcons[i] == ucons[i], eachindex(lcons)) + ineq_idx = findall(i -> lcons[i] != ucons[i], eachindex(lcons)) + + # Evaluate constraints at solution + c_val = eval_cons(x_star, p) + + # Find active inequality constraints + atol = sensealg.active_tol === nothing ? sqrt(eps(eltype(x_star))) : sensealg.active_tol + active_lb = filter(i -> abs(c_val[i] - lcons[i]) <= atol, ineq_idx) + active_ub = filter(i -> abs(c_val[i] - ucons[i]) <= atol, ineq_idx) + + # Constraint residual functions shifted to = 0 at optimum + # Equality: g(x,p) = cons(x,p)[eq_idx] - lcons[eq_idx] + # Active ineq lower bound: h_lb(x,p) = lcons[i] - cons(x,p)[i] (= 0 when active) + # Active ineq upper bound: h_ub(x,p) = cons(x,p)[i] - ucons[i] (= 0 when active) + g(x, q) = eval_cons(x, q)[eq_idx] .- lcons[eq_idx] + h_I(x, q) = vcat( + isempty(active_lb) ? eltype(x_star)[] : lcons[active_lb] .- eval_cons(x, q)[active_lb], + isempty(active_ub) ? eltype(x_star)[] : eval_cons(x, q)[active_ub] .- ucons[active_ub] + ) + + n_eq = length(eq_idx) + n_act = length(active_lb) + length(active_ub) + + # Jacobians of constraints w.r.t. x (needed for dual variables and KKT matrix) + Jxg = isempty(eq_idx) ? zeros(eltype(x_star), 0, length(x_star)) : + _optimization_jac(x -> g(x, p), x_star, ad_val, fdt_val) + Jxhι = n_act == 0 ? zeros(eltype(x_star), 0, length(x_star)) : + _optimization_jac(x -> h_I(x, p), x_star, ad_val, fdt_val) + + # Dual variables from stationarity condition: constraint_jac^T * [y*; z_I*] = -∇f(x*) + ∇f = _optimization_grad(x -> prob.f(x, p), x_star, ad_val, fdt_val) + constraint_jac = vcat(Jxg, Jxhι) # (n_eq + n_act) × n_x + # Solve overdetermined stationarity system via QR (n_x equations, n_eq+n_act unknowns) + dual_vars = if n_eq + n_act == 0 + eltype(x_star)[] + else + dual_prob = LinearProblem(Matrix(constraint_jac'), -∇f) + solve(dual_prob, LinearSolve.QRFactorization(); sensealg.linsolve_kwargs...).u + end + y_star = n_eq > 0 ? dual_vars[1:n_eq] : eltype(x_star)[] + zI_star = n_act > 0 ? dual_vars[(n_eq + 1):end] : eltype(x_star)[] + + # Lagrangian with fixed multipliers + function L(x, q) + val = prob.f(x, q) + n_eq > 0 && (val += dot(y_star, g(x, q))) + n_act > 0 && (val += dot(zI_star, h_I(x, q))) + return val + end + + # Assemble KKT matrix + Lxx = _optimization_hess(x -> L(x, p), x_star, ad_val, fdt_val) + + n_x = length(x_star) + N = n_x + n_eq + n_act + KKT = zeros(eltype(x_star), N, N) + KKT[1:n_x, 1:n_x] = Lxx + if n_eq > 0 + KKT[1:n_x, (n_x + 1):(n_x + n_eq)] = Jxg' + KKT[(n_x + 1):(n_x + n_eq), 1:n_x] = Jxg + end + if n_act > 0 + KKT[1:n_x, (n_x + n_eq + 1):N] = Jxhι' + KKT[(n_x + n_eq + 1):N, 1:n_x] = Jxhι + end + + # RHS: parameter Jacobians + Lxp = _optimization_jac( + q -> _optimization_grad(x -> L(x, q), x_star, ad_val, fdt_val), p, ad_val, fdt_val) + Jpg = n_eq > 0 ? _optimization_jac(q -> g(x_star, q), p, ad_val, fdt_val) : + zeros(eltype(x_star), 0, length(p)) + Jphι = n_act > 0 ? _optimization_jac(q -> h_I(x_star, q), p, ad_val, fdt_val) : + zeros(eltype(x_star), 0, length(p)) + RHS_p = vcat(Lxp, Jpg, Jphι) # (N × n_p) + + # Solve KKT system column-by-column, reusing the factorization via the cache interface + n_p = size(RHS_p, 2) + Jpx = zeros(eltype(x_star), n_x, n_p) + kkt_cache = LinearSolve.init(LinearProblem(KKT, -RHS_p[:, 1]), sensealg.linsolve; + sensealg.linsolve_kwargs...) + for j in 1:n_p + kkt_cache.b = -RHS_p[:, j] + Jpx[:, j] = LinearSolve.solve!(kkt_cache).u[1:n_x] + end + return Jpx # (n_x × n_p) +end From 96fa26b5eb5921dd74e9e67da467a40b247f43cd Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 30 Mar 2026 15:09:07 -0400 Subject: [PATCH 03/19] add OptimizationAdjoint and implement adjoint --- src/concrete_solve.jl | 95 +++++++++++++++++++++++++++++++++++ src/sensitivity_algorithms.jl | 30 +++++++++++ 2 files changed, 125 insertions(+) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index b4773c0b3..ea77139c0 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -2808,6 +2808,101 @@ function SciMLBase._concrete_solve_adjoint( return out, steadystatebackpass end +function SciMLBase._concrete_solve_adjoint( + prob::AbstractOptimizationProblem, + alg, sensealg::OptimizationAdjoint{CS, AD, FDT}, + u0, p, originator::SciMLBase.ADOriginator, + args...; save_idxs = nothing, kwargs... + ) where {CS, AD, FDT} + if prob.lcons === nothing + error("OptimizationAdjoint requires a constrained OptimizationProblem (lcons/ucons). " * + "For unconstrained problems, use UnconstrainedOptimizationAdjoint instead.") + end + + _prob = remake(prob, u0 = u0, p = p) + opt_sol = solve(_prob, alg, args...; kwargs...) + x_star = opt_sol.u + + _save_idxs = save_idxs === nothing ? Colon() : save_idxs + out = if save_idxs === nothing + opt_sol + else + SciMLBase.sensitivity_solution(opt_sol, opt_sol[_save_idxs]) + end + + Jpx = OptimizationAdjointProblem(_prob, opt_sol, sensealg, p) + + _, repack_adjoint = if isscimlstructure(p) + Zygote.pullback(p) do p + t, _, _ = canonicalize(Tunable(), p) + t + end + elseif isfunctor(p) + ps, re = Functors.functor(p) + ps, x -> (re(x),) + else + nothing, x -> (x,) + end + + function optimizationbackpass(Δ) + Δ = Δ isa AbstractThunk ? unthunk(Δ) : Δ + Δu = if Δ isa AbstractArray + Δ + else + Δ.u + end + dp = Jpx' * Δu[_save_idxs] + + dp, Δtunables = if Δ isa AbstractArray || Δ isa Number + dp, Δtunables = if isscimlstructure(dp) + dp, _, _ = canonicalize(Tunable(), dp) + dp, nothing + elseif isfunctor(dp) + dp, _ = Functors.functor(dp) + dp, nothing + else + dp, nothing + end + else + dp, Δtunables = if isscimlstructure(p) + if (Δ.prob.p == ZeroTangent() || Δ.prob.p == NoTangent()) + dp, _, _ = canonicalize(Tunable(), dp) + dp, nothing + else + Δp = setproperties(dp, to_nt(Δ.prob.p)) + Δtunables, _, _ = canonicalize(Tunable(), Δp) + dp, _, _ = canonicalize(Tunable(), dp) + dp, Δtunables + end + elseif isfunctor(p) + dp, _ = Functors.functor(dp) + Δtunables, _ = Functors.functor(Δ.prob.p) + dp, Δtunables + else + dp, Δ.prob.p + end + end + + dp = Zygote.accum( + dp, (isnothing(Δtunables) || isempty(Δtunables)) ? nothing : Δtunables) + + return if originator isa SciMLBase.TrackerOriginator || + originator isa SciMLBase.ReverseDiffOriginator + ( + NoTangent(), NoTangent(), NoTangent(), repack_adjoint(dp)[1], NoTangent(), + ntuple(_ -> NoTangent(), length(args))..., + ) + else + ( + NoTangent(), NoTangent(), NoTangent(), + NoTangent(), repack_adjoint(dp)[1], NoTangent(), + ntuple(_ -> NoTangent(), length(args))..., + ) + end + end + + return out, optimizationbackpass +end function fix_endpoints(sensealg, sol, ts) @warn "Endpoints do not match. Return code: $(sol.retcode). Likely your time range is not a multiple of `saveat`. sol.t[end]: $(last(current_time(sol))), ts[end]: $(ts[end])" diff --git a/src/sensitivity_algorithms.jl b/src/sensitivity_algorithms.jl index a3d3f7868..a47f79008 100644 --- a/src/sensitivity_algorithms.jl +++ b/src/sensitivity_algorithms.jl @@ -1423,6 +1423,36 @@ function setvjp( ) end +struct OptimizationAdjoint{CS, AD, FDT, VJP, LS, LK, OAD, AT} <: + AbstractAdjointSensitivityAlgorithm{CS, AD, FDT} + autojacvec::VJP + linsolve::LS + linsolve_kwargs::LK + objective_ad::OAD + active_tol::AT # tolerance for active inequality constraint detection; nothing = sqrt(eps(eltype(x*))) +end + +function OptimizationAdjoint(; + chunk_size = 0, autodiff = true, + diff_type = Val{:central}, objective_ad = true, autojacvec = nothing, + linsolve = nothing, linsolve_kwargs = (;), active_tol = nothing + ) + return OptimizationAdjoint{ + chunk_size, autodiff, diff_type, typeof(autojacvec), + typeof(linsolve), typeof(linsolve_kwargs), typeof(objective_ad), typeof(active_tol), + }(autojacvec, linsolve, linsolve_kwargs, objective_ad, active_tol) +end + +function setvjp( + sensealg::OptimizationAdjoint{CS, AD, FDT, VJP, LS, LK, OAD, AT}, + vjp + ) where {CS, AD, FDT, VJP, LS, LK, OAD, AT} + return OptimizationAdjoint{CS, AD, FDT, typeof(vjp), LS, LK, OAD, AT}( + vjp, sensealg.linsolve, sensealg.linsolve_kwargs, sensealg.objective_ad, + sensealg.active_tol + ) +end + abstract type VJPChoice end """ From 4463185782c6c556e77f098447af8b3b1ca49d1e Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 30 Mar 2026 15:09:13 -0400 Subject: [PATCH 04/19] add test --- test/optimization_adjoint.jl | 140 ++++++++++++++++++++++++++++++++++- 1 file changed, 138 insertions(+), 2 deletions(-) diff --git a/test/optimization_adjoint.jl b/test/optimization_adjoint.jl index 5e9942a11..c38101fe7 100644 --- a/test/optimization_adjoint.jl +++ b/test/optimization_adjoint.jl @@ -1,6 +1,6 @@ using Test, LinearAlgebra -using SciMLSensitivity, Optimization, OptimizationOptimisers, SciMLBase -using Mooncake, ForwardDiff +using SciMLSensitivity, Optimization, OptimizationOptimisers, OptimizationNLopt, SciMLBase +using Mooncake, ForwardDiff, Zygote using SciMLSensitivity: MooncakeVJP # Helper: build a NonlinearSolution from an optimization solve using the gradient as the residual, @@ -186,3 +186,139 @@ end @test dp[1] ≈ -0.5 rtol = 1.0e-2 end end + +@testset "OptimizationAdjoint: constrained optimization sensitivities" begin + @testset "Equality constraint" begin + let + # Minimize (u1-1)^2 + (u2-1)^2 s.t. u1 + u2 = p[1] + # Optimal solution: u1* = u2* = p[1]/2 + # du1*/dp[1] = 0.5, du2*/dp[1] = 0.5 + f = (u, p) -> (u[1] - 1)^2 + (u[2] - 1)^2 + # Constraint: u1 + u2 - p[1] = 0 (p flows through cons for correct adjoint) + cons = (res, u, p) -> (res[1] = u[1] + u[2] - p[1]) + + u0 = [1.5, 1.5] # feasible starting point: u1 + u2 = p[1] = 3 + p = [3.0] + + opt_f = OptimizationFunction(f, Optimization.AutoForwardDiff(); cons = cons) + prob = OptimizationProblem(opt_f, u0, p; lcons = [0.0], ucons = [0.0]) + + # Verify the forward solve + opt_sol = solve(prob, NLopt.LD_SLSQP()) + @test opt_sol.u[1] ≈ p[1] / 2 rtol = 1e-4 + @test opt_sol.u[2] ≈ p[1] / 2 rtol = 1e-4 + @test opt_sol.u[1] + opt_sol.u[2] ≈ p[1] rtol = 1e-6 # constraint satisfied + + # d(u1* + u2*)/dp[1] = d(p[1])/dp[1] = 1 + dp = Zygote.gradient(p) do p + _prob = remake(prob; p = p) + sol = solve(_prob, NLopt.LD_SLSQP(); sensealg = OptimizationAdjoint()) + sol.u[1] + sol.u[2] + end[1] + @test dp[1] ≈ 1.0 rtol = 1e-4 + + # du1*/dp[1] = 0.5 + dp1 = Zygote.gradient(p) do p + _prob = remake(prob; p = p) + sol = solve(_prob, NLopt.LD_SLSQP(); sensealg = OptimizationAdjoint()) + sol.u[1] + end[1] + @test dp1[1] ≈ 0.5 rtol = 1e-4 + end + end + + @testset "Active inequality constraint" begin + let + # Minimize (u - p[1])^2 s.t. u <= p[2] where p[2] < p[1] (constraint active) + # Optimal solution: u* = p[2] + # du*/dp[1] = 0, du*/dp[2] = 1 + f = (u, p) -> (u[1] - p[1])^2 + # Constraint: u[1] - p[2] <= 0 (p[2] flows through cons for correct adjoint) + cons = (res, u, p) -> (res[1] = u[1] - p[2]) + + u0 = [0.0] + p = [3.0, 1.0] # unconstrained min at u=3, constraint forces u<=1 + + opt_f = OptimizationFunction(f, Optimization.AutoForwardDiff(); cons = cons) + prob = OptimizationProblem(opt_f, u0, p; lcons = [-Inf], ucons = [0.0]) + + opt_sol = solve(prob, NLopt.LD_SLSQP()) + @test opt_sol.u[1] ≈ p[2] rtol = 1e-4 + @test opt_sol.u[1] <= p[2] + 1e-6 # constraint satisfied: u <= p[2] + + dp = Zygote.gradient(p) do p + _prob = remake(prob; p = p) + sol = solve(_prob, NLopt.LD_SLSQP(); sensealg = OptimizationAdjoint()) + sol.u[1] + end[1] + @test dp[1] ≈ 0.0 atol = 1e-4 # du*/dp[1] = 0 (u* doesn't depend on p[1]) + @test dp[2] ≈ 1.0 rtol = 1e-4 # du*/dp[2] = 1 (u* = p[2]) + end + end + + @testset "FiniteDiff vs ForwardDiff consistency" begin + let + # Equality-constrained problem, compare autodiff=true vs autodiff=false + f = (u, p) -> (u[1] - p[1])^2 + (u[2] - p[2])^2 + cons = (res, u, p) -> (res[1] = u[1] + u[2] - p[3]) + + u0 = [0.5, 0.5] + p = [1.0, 2.0, 3.0] + + opt_f = OptimizationFunction(f, Optimization.AutoForwardDiff(); cons = cons) + prob = OptimizationProblem(opt_f, u0, p; lcons = [0.0], ucons = [0.0]) + + opt_sol = solve(prob, NLopt.LD_SLSQP()) + @test opt_sol.u[1] + opt_sol.u[2] ≈ p[3] rtol = 1e-6 # constraint satisfied + + dp_fd = Zygote.gradient(p) do p + _prob = remake(prob; p = p) + sol = solve(_prob, NLopt.LD_SLSQP(); sensealg = OptimizationAdjoint(autodiff = false)) + sol.u[1] + end[1] + dp_fwd = Zygote.gradient(p) do p + _prob = remake(prob; p = p) + sol = solve(_prob, NLopt.LD_SLSQP(); sensealg = OptimizationAdjoint(autodiff = true)) + sol.u[1] + end[1] + @test dp_fd ≈ dp_fwd rtol = 1e-3 + end + end + + @testset "Lemma 4.2 (Gould et al.): L2 projection onto hyperplane" begin + let + # Minimize (1/2)||u - p||^2 s.t. u1 + u2 + u3 = 1 + # Analytical solution (Lemma 4.2 with A = [1 1 1], H = I): + # g'(p) = I - A^T(AA^T)^{-1}A = I - (1/3)J + # dg_i/dp_j = δ_ij - 1/3 + f = (u, p) -> sum((u .- p) .^ 2) / 2 + cons = (res, u, p) -> (res[1] = u[1] + u[2] + u[3] - 1) + + p = [2.0, 0.0, 0.0] + u0 = [1.0 / 3, 1.0 / 3, 1.0 / 3] # feasible starting point + + opt_f = OptimizationFunction(f, Optimization.AutoForwardDiff(); cons = cons) + prob = OptimizationProblem(opt_f, u0, p; lcons = [0.0], ucons = [0.0]) + + # Verify forward solve: u* = p - (sum(p)-1)/3 * [1,1,1] = [5/3, -1/3, -1/3] + opt_sol = solve(prob, NLopt.LD_SLSQP()) + @test opt_sol.u[1] ≈ 5.0 / 3 rtol = 1e-4 + @test opt_sol.u[2] ≈ -1.0 / 3 rtol = 1e-4 + @test opt_sol.u[3] ≈ -1.0 / 3 rtol = 1e-4 + @test sum(opt_sol.u) ≈ 1.0 rtol = 1e-6 # constraint satisfied + + # Verify adjoint: dg_i/dp_j = δ_ij - 1/3 + for i in 1:3 + dp = Zygote.gradient(p) do p + _prob = remake(prob; p = p) + sol = solve(_prob, NLopt.LD_SLSQP(); sensealg = OptimizationAdjoint()) + sol.u[i] + end[1] + + expected = [-1.0 / 3, -1.0 / 3, -1.0 / 3] + expected[i] += 1.0 # δ_ij term + @test dp ≈ expected rtol = 1e-3 + end + end + end +end From 6e800755941a2d486ac3dfa8be00c9bcfa6d617c Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 30 Mar 2026 15:50:27 -0400 Subject: [PATCH 05/19] add more tests --- test/optimization_adjoint.jl | 173 +++++++++++++++++++++++++++++++++++ 1 file changed, 173 insertions(+) diff --git a/test/optimization_adjoint.jl b/test/optimization_adjoint.jl index c38101fe7..fde586a89 100644 --- a/test/optimization_adjoint.jl +++ b/test/optimization_adjoint.jl @@ -285,6 +285,179 @@ end end end + @testset "p only in objective (sensitivity via ∇²_xp L, J_p g = 0)" begin + let + # Minimize p[1]*u[1] + u[1]^2 + u[2]^2 s.t. u[1] + u[2] = 1 (no p in constraint) + # J_p g = 0; sensitivity flows entirely through ∇²_xp L = [1, 0]. + # KKT → u1* = (2 - p[1])/4, u2* = (2 + p[1])/4 + # du1*/dp[1] = -1/4, du2*/dp[1] = 1/4 + f = (u, p) -> p[1] * u[1] + u[1]^2 + u[2]^2 + cons = (res, u, p) -> (res[1] = u[1] + u[2] - 1) + + p = [2.0] + u0 = [0.0, 1.0] # feasible: u1+u2 = 1, equals u* at p[1]=2 + + opt_f = OptimizationFunction(f, Optimization.AutoForwardDiff(); cons = cons) + prob = OptimizationProblem(opt_f, u0, p; lcons = [0.0], ucons = [0.0]) + + opt_sol = solve(prob, NLopt.LD_SLSQP()) + @test opt_sol.u[1] ≈ (2 - p[1]) / 4 rtol = 1e-4 + @test opt_sol.u[2] ≈ (2 + p[1]) / 4 rtol = 1e-4 + @test opt_sol.u[1] + opt_sol.u[2] ≈ 1.0 rtol = 1e-6 + + dp1 = Zygote.gradient(p) do p + _prob = remake(prob; p = p) + sol = solve(_prob, NLopt.LD_SLSQP(); sensealg = OptimizationAdjoint()) + sol.u[1] + end[1] + @test dp1[1] ≈ -0.25 rtol = 1e-3 + + dp2 = Zygote.gradient(p) do p + _prob = remake(prob; p = p) + sol = solve(_prob, NLopt.LD_SLSQP(); sensealg = OptimizationAdjoint()) + sol.u[2] + end[1] + @test dp2[1] ≈ 0.25 rtol = 1e-3 + end + end + + @testset "Inactive inequality constraint" begin + let + # Minimize (u - p[1])^2 s.t. u <= p[2] where p[2] > p[1] (constraint NOT active) + # Optimal solution: u* = p[1] (unconstrained min, inequality slack) + # du*/dp[1] = 1, du*/dp[2] = 0 + f = (u, p) -> (u[1] - p[1])^2 + cons = (res, u, p) -> (res[1] = u[1] - p[2]) + + p = [1.0, 5.0] # unconstrained min at u=1, well inside bound u<=5 + u0 = [0.0] + + opt_f = OptimizationFunction(f, Optimization.AutoForwardDiff(); cons = cons) + prob = OptimizationProblem(opt_f, u0, p; lcons = [-Inf], ucons = [0.0]) + + opt_sol = solve(prob, NLopt.LD_SLSQP()) + @test opt_sol.u[1] ≈ p[1] rtol = 1e-4 + @test opt_sol.u[1] <= p[2] + 1e-6 # constraint satisfied (slack) + + dp = Zygote.gradient(p) do p + _prob = remake(prob; p = p) + sol = solve(_prob, NLopt.LD_SLSQP(); sensealg = OptimizationAdjoint()) + sol.u[1] + end[1] + @test dp[1] ≈ 1.0 rtol = 1e-3 # u* = p[1], so du*/dp[1] = 1 + @test dp[2] ≈ 0.0 atol = 1e-3 # inactive constraint, no dependence on p[2] + end + end + + @testset "Mixed equality + active inequality" begin + let + # Minimize (u1-3)^2 + (u2-3)^2 s.t. u1+u2 = p[1] and u1 <= p[2] + # At p=[4,1]: unconstrained-on-line solution is u1=u2=2, but u1<=1 is active + # → u1* = p[2] = 1, u2* = p[1] - p[2] = 3 + # du1*/dp[1] = 0, du1*/dp[2] = 1 + # du2*/dp[1] = 1, du2*/dp[2] = -1 + f = (u, p) -> (u[1] - 3)^2 + (u[2] - 3)^2 + cons = (res, u, p) -> (res[1] = u[1] + u[2] - p[1]; res[2] = u[1] - p[2]) + + p = [4.0, 1.0] + u0 = [1.0, 3.0] # feasible: u1+u2=4, u1=1<=1 + + opt_f = OptimizationFunction(f, Optimization.AutoForwardDiff(); cons = cons) + prob = OptimizationProblem(opt_f, u0, p; lcons = [0.0, -Inf], ucons = [0.0, 0.0]) + + opt_sol = solve(prob, NLopt.LD_SLSQP()) + @test opt_sol.u[1] ≈ p[2] rtol = 1e-4 + @test opt_sol.u[2] ≈ p[1] - p[2] rtol = 1e-4 + @test opt_sol.u[1] + opt_sol.u[2] ≈ p[1] rtol = 1e-6 # equality satisfied + @test opt_sol.u[1] <= p[2] + 1e-6 # inequality satisfied + + dp = Zygote.gradient(p) do p + _prob = remake(prob; p = p) + sol = solve(_prob, NLopt.LD_SLSQP(); sensealg = OptimizationAdjoint()) + sol.u[1] + end[1] + @test dp[1] ≈ 0.0 atol = 1e-3 + @test dp[2] ≈ 1.0 rtol = 1e-3 + + dp2 = Zygote.gradient(p) do p + _prob = remake(prob; p = p) + sol = solve(_prob, NLopt.LD_SLSQP(); sensealg = OptimizationAdjoint()) + sol.u[2] + end[1] + @test dp2[1] ≈ 1.0 rtol = 1e-3 + @test dp2[2] ≈ -1.0 rtol = 1e-3 + end + end + + @testset "Multiple equality constraints" begin + let + # Minimize (1/2)||u||^2 s.t. u1+u2 = p[1], u2+u3 = p[2] + # Analytical solution: u* = [(2p[1]-p[2])/3, (p[1]+p[2])/3, (-p[1]+2p[2])/3] + # Jacobian: du_i/dp_j — 3x2 matrix + f = (u, p) -> sum(u .^ 2) / 2 + cons = (res, u, p) -> (res[1] = u[1] + u[2] - p[1]; res[2] = u[2] + u[3] - p[2]) + + p = [1.0, 1.0] + u0 = [1.0 / 3, 2.0 / 3, 1.0 / 3] # feasible + + opt_f = OptimizationFunction(f, Optimization.AutoForwardDiff(); cons = cons) + prob = OptimizationProblem(opt_f, u0, p; lcons = [0.0, 0.0], ucons = [0.0, 0.0]) + + opt_sol = solve(prob, NLopt.LD_SLSQP()) + @test opt_sol.u[1] ≈ (2p[1] - p[2]) / 3 rtol = 1e-4 + @test opt_sol.u[2] ≈ (p[1] + p[2]) / 3 rtol = 1e-4 + @test opt_sol.u[3] ≈ (-p[1] + 2p[2]) / 3 rtol = 1e-4 + @test opt_sol.u[1] + opt_sol.u[2] ≈ p[1] rtol = 1e-6 + @test opt_sol.u[2] + opt_sol.u[3] ≈ p[2] rtol = 1e-6 + + # du1/dp = [2/3, -1/3], du2/dp = [1/3, 1/3], du3/dp = [-1/3, 2/3] + expected = [[2/3, -1/3], [1/3, 1/3], [-1/3, 2/3]] + for i in 1:3 + dp = Zygote.gradient(p) do p + _prob = remake(prob; p = p) + sol = solve(_prob, NLopt.LD_SLSQP(); sensealg = OptimizationAdjoint()) + sol.u[i] + end[1] + @test dp ≈ expected[i] rtol = 1e-3 + end + end + end + + @testset "p in both objective and constraint (both ∇²_xp L and J_p g nonzero)" begin + let + # Minimize (u1 - p[1])^2 + u2^2 s.t. u1 + u2 = p[2] + # KKT → u1* = (p[1]+p[2])/2, u2* = (p[2]-p[1])/2 + # du1*/dp = [1/2, 1/2], du2*/dp = [-1/2, 1/2] + f = (u, p) -> (u[1] - p[1])^2 + u[2]^2 + cons = (res, u, p) -> (res[1] = u[1] + u[2] - p[2]) + + p = [1.0, 3.0] + u0 = [1.5, 1.5] # feasible: u1+u2 = 3 = p[2] + + opt_f = OptimizationFunction(f, Optimization.AutoForwardDiff(); cons = cons) + prob = OptimizationProblem(opt_f, u0, p; lcons = [0.0], ucons = [0.0]) + + opt_sol = solve(prob, NLopt.LD_SLSQP()) + @test opt_sol.u[1] ≈ (p[1] + p[2]) / 2 rtol = 1e-4 + @test opt_sol.u[2] ≈ (p[2] - p[1]) / 2 rtol = 1e-4 + @test opt_sol.u[1] + opt_sol.u[2] ≈ p[2] rtol = 1e-6 + + dp1 = Zygote.gradient(p) do p + _prob = remake(prob; p = p) + sol = solve(_prob, NLopt.LD_SLSQP(); sensealg = OptimizationAdjoint()) + sol.u[1] + end[1] + @test dp1 ≈ [0.5, 0.5] rtol = 1e-3 + + dp2 = Zygote.gradient(p) do p + _prob = remake(prob; p = p) + sol = solve(_prob, NLopt.LD_SLSQP(); sensealg = OptimizationAdjoint()) + sol.u[2] + end[1] + @test dp2 ≈ [-0.5, 0.5] rtol = 1e-3 + end + end + @testset "Lemma 4.2 (Gould et al.): L2 projection onto hyperplane" begin let # Minimize (1/2)||u - p||^2 s.t. u1 + u2 + u3 = 1 From 6df29b7132249922aa741a944236f72ce4f0d355 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Tue, 31 Mar 2026 10:22:51 -0400 Subject: [PATCH 06/19] implement adjoint_sensitivities interface --- src/concrete_solve.jl | 21 ++-- src/optimization_adjoint.jl | 37 +++++-- src/sensitivity_interface.jl | 36 +++++++ test/optimization_adjoint.jl | 193 ++++++++++------------------------- 4 files changed, 131 insertions(+), 156 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index ea77139c0..938ebdb5e 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -2821,7 +2821,6 @@ function SciMLBase._concrete_solve_adjoint( _prob = remake(prob, u0 = u0, p = p) opt_sol = solve(_prob, alg, args...; kwargs...) - x_star = opt_sol.u _save_idxs = save_idxs === nothing ? Colon() : save_idxs out = if save_idxs === nothing @@ -2830,8 +2829,6 @@ function SciMLBase._concrete_solve_adjoint( SciMLBase.sensitivity_solution(opt_sol, opt_sol[_save_idxs]) end - Jpx = OptimizationAdjointProblem(_prob, opt_sol, sensealg, p) - _, repack_adjoint = if isscimlstructure(p) Zygote.pullback(p) do p t, _, _ = canonicalize(Tunable(), p) @@ -2846,12 +2843,20 @@ function SciMLBase._concrete_solve_adjoint( function optimizationbackpass(Δ) Δ = Δ isa AbstractThunk ? unthunk(Δ) : Δ - Δu = if Δ isa AbstractArray - Δ - else - Δ.u + function df(_out, _u, _p, _t, _i) + if _save_idxs isa Number + _out[_save_idxs] = Δ isa AbstractArray ? Δ[_save_idxs] : Δ.u[_save_idxs] + elseif Δ isa Number + @. _out[_save_idxs] = Δ + elseif Δ isa AbstractArray + @. _out[_save_idxs] = Δ[_save_idxs] + elseif isnothing(_out) + _out + else + @. _out[_save_idxs] = Δ.u[_save_idxs] + end end - dp = Jpx' * Δu[_save_idxs] + dp = adjoint_sensitivities(opt_sol, nothing; sensealg = sensealg, dgdu = df) dp, Δtunables = if Δ isa AbstractArray || Δ isa Number dp, Δtunables = if isscimlstructure(dp) diff --git a/src/optimization_adjoint.jl b/src/optimization_adjoint.jl index d7b0640e1..1b94093cc 100644 --- a/src/optimization_adjoint.jl +++ b/src/optimization_adjoint.jl @@ -32,7 +32,7 @@ where g are equality constraints, h_I are active inequality constraints, and y*, z_I* are the corresponding dual variables. """ function OptimizationAdjointProblem( - prob::AbstractOptimizationProblem, + prob, opt_sol, sensealg::OptimizationAdjoint{CS, AD, FDT}, p @@ -43,20 +43,39 @@ function OptimizationAdjointProblem( lcons = prob.lcons ucons = prob.ucons - n_cons = length(lcons) + has_cons = lcons !== nothing && ucons !== nothing # Wrap in-place cons!(res, x, p) into an out-of-place helper. # promote_type handles ForwardDiff Dual propagation when either x or q contains duals. - function eval_cons(x, q) - T = promote_type(eltype(x), eltype(q)) - res = zeros(T, n_cons) - prob.f.cons(res, x, q) - return res + # When prob is an OptimizationCache, prob.f.cons is a 2-arg closure + # `(res, x) -> f.cons(res, x, captured_p)` from OptimizationBase.instantiate_function. + # The captured field names are mangled (e.g. `#95#f`), so we search by type to find + # the captured OptimizationFunction, regardless of field ordering. + if has_cons + n_cons = length(lcons) + _cons3 = if applicable(prob.f.cons, zeros(n_cons), x_star, p) + prob.f.cons # AbstractOptimizationProblem: already (res, x, p) + else + captured_f = let cl = prob.f.cons + getfield(cl, only(fname for fname in fieldnames(typeof(cl)) + if getfield(cl, fname) isa SciMLBase.AbstractOptimizationFunction)) + end + captured_f.cons + end + eval_cons = function (x, q) + T = promote_type(eltype(x), eltype(q)) + res = zeros(T, n_cons) + _cons3(res, x, q) + return res + end + else + n_cons = 0 + eval_cons = (_, _) -> eltype(x_star)[] end # Classify constraints: equality where lcons[i] == ucons[i] - eq_idx = findall(i -> lcons[i] == ucons[i], eachindex(lcons)) - ineq_idx = findall(i -> lcons[i] != ucons[i], eachindex(lcons)) + eq_idx = has_cons ? findall(i -> lcons[i] == ucons[i], eachindex(lcons)) : Int[] + ineq_idx = has_cons ? findall(i -> lcons[i] != ucons[i], eachindex(lcons)) : Int[] # Evaluate constraints at solution c_val = eval_cons(x_star, p) diff --git a/src/sensitivity_interface.jl b/src/sensitivity_interface.jl index e9ab9ba5d..b03a5a129 100644 --- a/src/sensitivity_interface.jl +++ b/src/sensitivity_interface.jl @@ -423,6 +423,28 @@ function adjoint_sensitivities( end end +function adjoint_sensitivities( + sol::SciMLBase.AbstractOptimizationSolution, + alg::Nothing, args...; + sensealg::OptimizationAdjoint, + verbose = true, kwargs... + ) + return _adjoint_sensitivities(sol, sensealg, alg, args...; verbose, kwargs...) +end + +function _adjoint_sensitivities( + sol, sensealg::OptimizationAdjoint, ::Nothing; + dgdu = nothing, kwargs... + ) + dgdu === nothing && + error("dgdu must be specified for OptimizationAdjoint") + p = SymbolicIndexingInterface.parameter_values(sol) + Jpx = OptimizationAdjointProblem(sol.prob, sol, sensealg, p) + Δu = zero(sol.u) + dgdu(Δu, sol.u, p, nothing, nothing) + return Jpx' * Δu +end + function _adjoint_sensitivities( sol, sensealg, alg; t = nothing, @@ -526,6 +548,20 @@ function _adjoint_sensitivities( return SteadyStateAdjointProblem(sol, sensealg, alg, dgdu, dgdp, g; kwargs...) end +function _adjoint_sensitivities( + sol, sensealg::OptimizationAdjoint, alg; + dgdu = nothing, kwargs... + ) + dgdu === nothing && + error("dgdu must be specified for OptimizationAdjoint") + prob = sol.prob + p = prob.p + Jpx = OptimizationAdjointProblem(prob, sol, sensealg, p) + Δu = zero(sol.u) + dgdu(Δu, sol.u, p, nothing, nothing) + return Jpx' * Δu +end + @doc doc""" ```julia H = second_order_sensitivities(loss,prob,alg,args...; diff --git a/test/optimization_adjoint.jl b/test/optimization_adjoint.jl index fde586a89..4eb533193 100644 --- a/test/optimization_adjoint.jl +++ b/test/optimization_adjoint.jl @@ -1,6 +1,6 @@ using Test, LinearAlgebra using SciMLSensitivity, Optimization, OptimizationOptimisers, OptimizationNLopt, SciMLBase -using Mooncake, ForwardDiff, Zygote +using Mooncake, ForwardDiff using SciMLSensitivity: MooncakeVJP # Helper: build a NonlinearSolution from an optimization solve using the gradient as the residual, @@ -194,36 +194,25 @@ end # Optimal solution: u1* = u2* = p[1]/2 # du1*/dp[1] = 0.5, du2*/dp[1] = 0.5 f = (u, p) -> (u[1] - 1)^2 + (u[2] - 1)^2 - # Constraint: u1 + u2 - p[1] = 0 (p flows through cons for correct adjoint) cons = (res, u, p) -> (res[1] = u[1] + u[2] - p[1]) - u0 = [1.5, 1.5] # feasible starting point: u1 + u2 = p[1] = 3 + u0 = [1.5, 1.5] # feasible: u1+u2 = p[1] = 3 p = [3.0] opt_f = OptimizationFunction(f, Optimization.AutoForwardDiff(); cons = cons) prob = OptimizationProblem(opt_f, u0, p; lcons = [0.0], ucons = [0.0]) - # Verify the forward solve opt_sol = solve(prob, NLopt.LD_SLSQP()) @test opt_sol.u[1] ≈ p[1] / 2 rtol = 1e-4 @test opt_sol.u[2] ≈ p[1] / 2 rtol = 1e-4 @test opt_sol.u[1] + opt_sol.u[2] ≈ p[1] rtol = 1e-6 # constraint satisfied - # d(u1* + u2*)/dp[1] = d(p[1])/dp[1] = 1 - dp = Zygote.gradient(p) do p - _prob = remake(prob; p = p) - sol = solve(_prob, NLopt.LD_SLSQP(); sensealg = OptimizationAdjoint()) - sol.u[1] + sol.u[2] - end[1] - @test dp[1] ≈ 1.0 rtol = 1e-4 - - # du1*/dp[1] = 0.5 - dp1 = Zygote.gradient(p) do p - _prob = remake(prob; p = p) - sol = solve(_prob, NLopt.LD_SLSQP(); sensealg = OptimizationAdjoint()) - sol.u[1] - end[1] - @test dp1[1] ≈ 0.5 rtol = 1e-4 + dgdu1!(out, _, _, _, _) = (out[1] = 1.0; out[2] = 0.0) + dgdu2!(out, _, _, _, _) = (out[1] = 0.0; out[2] = 1.0) + dp1 = adjoint_sensitivities(opt_sol, nothing; sensealg = OptimizationAdjoint(), dgdu = dgdu1!) + dp2 = adjoint_sensitivities(opt_sol, nothing; sensealg = OptimizationAdjoint(), dgdu = dgdu2!) + @test dp1[1] ≈ 0.5 rtol = 1e-4 # du1*/dp[1] + @test dp2[1] ≈ 0.5 rtol = 1e-4 # du2*/dp[1] end end @@ -233,7 +222,6 @@ end # Optimal solution: u* = p[2] # du*/dp[1] = 0, du*/dp[2] = 1 f = (u, p) -> (u[1] - p[1])^2 - # Constraint: u[1] - p[2] <= 0 (p[2] flows through cons for correct adjoint) cons = (res, u, p) -> (res[1] = u[1] - p[2]) u0 = [0.0] @@ -246,13 +234,10 @@ end @test opt_sol.u[1] ≈ p[2] rtol = 1e-4 @test opt_sol.u[1] <= p[2] + 1e-6 # constraint satisfied: u <= p[2] - dp = Zygote.gradient(p) do p - _prob = remake(prob; p = p) - sol = solve(_prob, NLopt.LD_SLSQP(); sensealg = OptimizationAdjoint()) - sol.u[1] - end[1] - @test dp[1] ≈ 0.0 atol = 1e-4 # du*/dp[1] = 0 (u* doesn't depend on p[1]) - @test dp[2] ≈ 1.0 rtol = 1e-4 # du*/dp[2] = 1 (u* = p[2]) + dgdu!(out, _, _, _, _) = (out[1] = 1.0) + dp = adjoint_sensitivities(opt_sol, nothing; sensealg = OptimizationAdjoint(), dgdu = dgdu!) + @test dp[1] ≈ 0.0 atol = 1e-4 # du*/dp[1] = 0 + @test dp[2] ≈ 1.0 rtol = 1e-4 # du*/dp[2] = 1 end end @@ -271,16 +256,11 @@ end opt_sol = solve(prob, NLopt.LD_SLSQP()) @test opt_sol.u[1] + opt_sol.u[2] ≈ p[3] rtol = 1e-6 # constraint satisfied - dp_fd = Zygote.gradient(p) do p - _prob = remake(prob; p = p) - sol = solve(_prob, NLopt.LD_SLSQP(); sensealg = OptimizationAdjoint(autodiff = false)) - sol.u[1] - end[1] - dp_fwd = Zygote.gradient(p) do p - _prob = remake(prob; p = p) - sol = solve(_prob, NLopt.LD_SLSQP(); sensealg = OptimizationAdjoint(autodiff = true)) - sol.u[1] - end[1] + dgdu!(out, _, _, _, _) = (out[1] = 1.0; out[2] = 0.0) + dp_fd = adjoint_sensitivities(opt_sol, nothing; + sensealg = OptimizationAdjoint(autodiff = false), dgdu = dgdu!) + dp_fwd = adjoint_sensitivities(opt_sol, nothing; + sensealg = OptimizationAdjoint(autodiff = true), dgdu = dgdu!) @test dp_fd ≈ dp_fwd rtol = 1e-3 end end @@ -295,7 +275,7 @@ end cons = (res, u, p) -> (res[1] = u[1] + u[2] - 1) p = [2.0] - u0 = [0.0, 1.0] # feasible: u1+u2 = 1, equals u* at p[1]=2 + u0 = [0.0, 1.0] # feasible: u1+u2 = 1 opt_f = OptimizationFunction(f, Optimization.AutoForwardDiff(); cons = cons) prob = OptimizationProblem(opt_f, u0, p; lcons = [0.0], ucons = [0.0]) @@ -303,21 +283,14 @@ end opt_sol = solve(prob, NLopt.LD_SLSQP()) @test opt_sol.u[1] ≈ (2 - p[1]) / 4 rtol = 1e-4 @test opt_sol.u[2] ≈ (2 + p[1]) / 4 rtol = 1e-4 - @test opt_sol.u[1] + opt_sol.u[2] ≈ 1.0 rtol = 1e-6 - - dp1 = Zygote.gradient(p) do p - _prob = remake(prob; p = p) - sol = solve(_prob, NLopt.LD_SLSQP(); sensealg = OptimizationAdjoint()) - sol.u[1] - end[1] - @test dp1[1] ≈ -0.25 rtol = 1e-3 - - dp2 = Zygote.gradient(p) do p - _prob = remake(prob; p = p) - sol = solve(_prob, NLopt.LD_SLSQP(); sensealg = OptimizationAdjoint()) - sol.u[2] - end[1] - @test dp2[1] ≈ 0.25 rtol = 1e-3 + @test opt_sol.u[1] + opt_sol.u[2] ≈ 1.0 rtol = 1e-6 # constraint satisfied + + dgdu1!(out, _, _, _, _) = (out[1] = 1.0; out[2] = 0.0) + dgdu2!(out, _, _, _, _) = (out[1] = 0.0; out[2] = 1.0) + dp1 = adjoint_sensitivities(opt_sol, nothing; sensealg = OptimizationAdjoint(), dgdu = dgdu1!) + dp2 = adjoint_sensitivities(opt_sol, nothing; sensealg = OptimizationAdjoint(), dgdu = dgdu2!) + @test dp1[1] ≈ -0.25 rtol = 1e-3 # du1*/dp[1] + @test dp2[1] ≈ 0.25 rtol = 1e-3 # du2*/dp[1] end end @@ -339,23 +312,18 @@ end @test opt_sol.u[1] ≈ p[1] rtol = 1e-4 @test opt_sol.u[1] <= p[2] + 1e-6 # constraint satisfied (slack) - dp = Zygote.gradient(p) do p - _prob = remake(prob; p = p) - sol = solve(_prob, NLopt.LD_SLSQP(); sensealg = OptimizationAdjoint()) - sol.u[1] - end[1] - @test dp[1] ≈ 1.0 rtol = 1e-3 # u* = p[1], so du*/dp[1] = 1 - @test dp[2] ≈ 0.0 atol = 1e-3 # inactive constraint, no dependence on p[2] + dgdu!(out, _, _, _, _) = (out[1] = 1.0) + dp = adjoint_sensitivities(opt_sol, nothing; sensealg = OptimizationAdjoint(), dgdu = dgdu!) + @test dp[1] ≈ 1.0 rtol = 1e-3 # du*/dp[1] = 1 + @test dp[2] ≈ 0.0 atol = 1e-3 # du*/dp[2] = 0 (inactive) end end @testset "Mixed equality + active inequality" begin let # Minimize (u1-3)^2 + (u2-3)^2 s.t. u1+u2 = p[1] and u1 <= p[2] - # At p=[4,1]: unconstrained-on-line solution is u1=u2=2, but u1<=1 is active - # → u1* = p[2] = 1, u2* = p[1] - p[2] = 3 - # du1*/dp[1] = 0, du1*/dp[2] = 1 - # du2*/dp[1] = 1, du2*/dp[2] = -1 + # At p=[4,1]: u1* = p[2] = 1, u2* = p[1] - p[2] = 3 + # du1*/dp = [0, 1], du2*/dp = [1, -1] f = (u, p) -> (u[1] - 3)^2 + (u[2] - 3)^2 cons = (res, u, p) -> (res[1] = u[1] + u[2] - p[1]; res[2] = u[1] - p[2]) @@ -371,21 +339,14 @@ end @test opt_sol.u[1] + opt_sol.u[2] ≈ p[1] rtol = 1e-6 # equality satisfied @test opt_sol.u[1] <= p[2] + 1e-6 # inequality satisfied - dp = Zygote.gradient(p) do p - _prob = remake(prob; p = p) - sol = solve(_prob, NLopt.LD_SLSQP(); sensealg = OptimizationAdjoint()) - sol.u[1] - end[1] - @test dp[1] ≈ 0.0 atol = 1e-3 - @test dp[2] ≈ 1.0 rtol = 1e-3 - - dp2 = Zygote.gradient(p) do p - _prob = remake(prob; p = p) - sol = solve(_prob, NLopt.LD_SLSQP(); sensealg = OptimizationAdjoint()) - sol.u[2] - end[1] - @test dp2[1] ≈ 1.0 rtol = 1e-3 - @test dp2[2] ≈ -1.0 rtol = 1e-3 + dgdu1!(out, _, _, _, _) = (out[1] = 1.0; out[2] = 0.0) + dgdu2!(out, _, _, _, _) = (out[1] = 0.0; out[2] = 1.0) + dp1 = adjoint_sensitivities(opt_sol, nothing; sensealg = OptimizationAdjoint(), dgdu = dgdu1!) + dp2 = adjoint_sensitivities(opt_sol, nothing; sensealg = OptimizationAdjoint(), dgdu = dgdu2!) + @test dp1[1] ≈ 0.0 atol = 1e-3 # du1*/dp[1] + @test dp1[2] ≈ 1.0 rtol = 1e-3 # du1*/dp[2] + @test dp2[1] ≈ 1.0 rtol = 1e-3 # du2*/dp[1] + @test dp2[2] ≈ -1.0 rtol = 1e-3 # du2*/dp[2] end end @@ -393,7 +354,7 @@ end let # Minimize (1/2)||u||^2 s.t. u1+u2 = p[1], u2+u3 = p[2] # Analytical solution: u* = [(2p[1]-p[2])/3, (p[1]+p[2])/3, (-p[1]+2p[2])/3] - # Jacobian: du_i/dp_j — 3x2 matrix + # du1/dp = [2/3, -1/3], du2/dp = [1/3, 1/3], du3/dp = [-1/3, 2/3] f = (u, p) -> sum(u .^ 2) / 2 cons = (res, u, p) -> (res[1] = u[1] + u[2] - p[1]; res[2] = u[2] + u[3] - p[2]) @@ -410,15 +371,13 @@ end @test opt_sol.u[1] + opt_sol.u[2] ≈ p[1] rtol = 1e-6 @test opt_sol.u[2] + opt_sol.u[3] ≈ p[2] rtol = 1e-6 - # du1/dp = [2/3, -1/3], du2/dp = [1/3, 1/3], du3/dp = [-1/3, 2/3] expected = [[2/3, -1/3], [1/3, 1/3], [-1/3, 2/3]] - for i in 1:3 - dp = Zygote.gradient(p) do p - _prob = remake(prob; p = p) - sol = solve(_prob, NLopt.LD_SLSQP(); sensealg = OptimizationAdjoint()) - sol.u[i] - end[1] - @test dp ≈ expected[i] rtol = 1e-3 + for (i, exp_row) in enumerate(expected) + e = zeros(3); e[i] = 1.0 + dgdui!(out, _, _, _, _) = copyto!(out, e) + dp = adjoint_sensitivities(opt_sol, nothing; + sensealg = OptimizationAdjoint(), dgdu = dgdui!) + @test dp ≈ exp_row rtol = 1e-3 end end end @@ -440,58 +399,14 @@ end opt_sol = solve(prob, NLopt.LD_SLSQP()) @test opt_sol.u[1] ≈ (p[1] + p[2]) / 2 rtol = 1e-4 @test opt_sol.u[2] ≈ (p[2] - p[1]) / 2 rtol = 1e-4 - @test opt_sol.u[1] + opt_sol.u[2] ≈ p[2] rtol = 1e-6 - - dp1 = Zygote.gradient(p) do p - _prob = remake(prob; p = p) - sol = solve(_prob, NLopt.LD_SLSQP(); sensealg = OptimizationAdjoint()) - sol.u[1] - end[1] - @test dp1 ≈ [0.5, 0.5] rtol = 1e-3 - - dp2 = Zygote.gradient(p) do p - _prob = remake(prob; p = p) - sol = solve(_prob, NLopt.LD_SLSQP(); sensealg = OptimizationAdjoint()) - sol.u[2] - end[1] - @test dp2 ≈ [-0.5, 0.5] rtol = 1e-3 - end - end - - @testset "Lemma 4.2 (Gould et al.): L2 projection onto hyperplane" begin - let - # Minimize (1/2)||u - p||^2 s.t. u1 + u2 + u3 = 1 - # Analytical solution (Lemma 4.2 with A = [1 1 1], H = I): - # g'(p) = I - A^T(AA^T)^{-1}A = I - (1/3)J - # dg_i/dp_j = δ_ij - 1/3 - f = (u, p) -> sum((u .- p) .^ 2) / 2 - cons = (res, u, p) -> (res[1] = u[1] + u[2] + u[3] - 1) + @test opt_sol.u[1] + opt_sol.u[2] ≈ p[2] rtol = 1e-6 # constraint satisfied - p = [2.0, 0.0, 0.0] - u0 = [1.0 / 3, 1.0 / 3, 1.0 / 3] # feasible starting point - - opt_f = OptimizationFunction(f, Optimization.AutoForwardDiff(); cons = cons) - prob = OptimizationProblem(opt_f, u0, p; lcons = [0.0], ucons = [0.0]) - - # Verify forward solve: u* = p - (sum(p)-1)/3 * [1,1,1] = [5/3, -1/3, -1/3] - opt_sol = solve(prob, NLopt.LD_SLSQP()) - @test opt_sol.u[1] ≈ 5.0 / 3 rtol = 1e-4 - @test opt_sol.u[2] ≈ -1.0 / 3 rtol = 1e-4 - @test opt_sol.u[3] ≈ -1.0 / 3 rtol = 1e-4 - @test sum(opt_sol.u) ≈ 1.0 rtol = 1e-6 # constraint satisfied - - # Verify adjoint: dg_i/dp_j = δ_ij - 1/3 - for i in 1:3 - dp = Zygote.gradient(p) do p - _prob = remake(prob; p = p) - sol = solve(_prob, NLopt.LD_SLSQP(); sensealg = OptimizationAdjoint()) - sol.u[i] - end[1] - - expected = [-1.0 / 3, -1.0 / 3, -1.0 / 3] - expected[i] += 1.0 # δ_ij term - @test dp ≈ expected rtol = 1e-3 - end + dgdu1!(out, _, _, _, _) = (out[1] = 1.0; out[2] = 0.0) + dgdu2!(out, _, _, _, _) = (out[1] = 0.0; out[2] = 1.0) + dp1 = adjoint_sensitivities(opt_sol, nothing; sensealg = OptimizationAdjoint(), dgdu = dgdu1!) + dp2 = adjoint_sensitivities(opt_sol, nothing; sensealg = OptimizationAdjoint(), dgdu = dgdu2!) + @test dp1 ≈ [0.5, 0.5] rtol = 1e-3 + @test dp2 ≈ [-0.5, 0.5] rtol = 1e-3 end end end From 325406f0383f195faef2631ebeaaf3a9a5c9903d Mon Sep 17 00:00:00 2001 From: jClugstor Date: Tue, 31 Mar 2026 11:11:09 -0400 Subject: [PATCH 07/19] use appropriate derivatives from OptimizationFunction if available --- src/optimization_adjoint.jl | 89 ++++++++++++++++++++++++++++++++----- 1 file changed, 77 insertions(+), 12 deletions(-) diff --git a/src/optimization_adjoint.jl b/src/optimization_adjoint.jl index 1b94093cc..e791d9e83 100644 --- a/src/optimization_adjoint.jl +++ b/src/optimization_adjoint.jl @@ -16,6 +16,37 @@ function _optimization_hess(f, x, ::Val{false}, ::FDT) where {FDT} y -> FiniteDiff.finite_difference_gradient(f, y, FDT()), x, FDT()) end +# Evaluate OptimizationFunction auxiliary fields (grad, hess, cons_j, lag_h). +# Dispatched on: +# Val{iip} — from OptimizationFunction{iip}: true = in-place (leading buffer), false = oop +# Val{has_p} — true = AbstractOptimizationProblem (p explicit), false = OptimizationCache (p baked in) +function _opt_eval_vec(fn, n, x, p, ::Val{true}, ::Val{true}) + out = zeros(eltype(x), n); fn(out, x, p); out +end +function _opt_eval_vec(fn, n, x, _, ::Val{true}, ::Val{false}) + out = zeros(eltype(x), n); fn(out, x); out +end +_opt_eval_vec(fn, _, x, p, ::Val{false}, ::Val{true}) = fn(x, p) +_opt_eval_vec(fn, _, x, _, ::Val{false}, ::Val{false}) = fn(x) + +function _opt_eval_mat(fn, m, n, x, p, ::Val{true}, ::Val{true}) + out = zeros(eltype(x), m, n); fn(out, x, p); out +end +function _opt_eval_mat(fn, m, n, x, _, ::Val{true}, ::Val{false}) + out = zeros(eltype(x), m, n); fn(out, x); out +end +_opt_eval_mat(fn, _, _, x, p, ::Val{false}, ::Val{true}) = fn(x, p) +_opt_eval_mat(fn, _, _, x, _, ::Val{false}, ::Val{false}) = fn(x) + +function _opt_eval_lag_h(fn, n, x, σ, μ, p, ::Val{true}, ::Val{true}) + H = zeros(eltype(x), n, n); fn(H, x, σ, μ, p); H +end +function _opt_eval_lag_h(fn, n, x, σ, μ, _, ::Val{true}, ::Val{false}) + H = zeros(eltype(x), n, n); fn(H, x, σ, μ); H +end +_opt_eval_lag_h(fn, _, x, σ, μ, p, ::Val{false}, ::Val{true}) = fn(x, σ, μ, p) +_opt_eval_lag_h(fn, _, x, σ, μ, _, ::Val{false}, ::Val{false}) = fn(x, σ, μ) + """ OptimizationAdjointProblem(prob, opt_sol, sensealg, p) -> Jpx @@ -97,15 +128,35 @@ function OptimizationAdjointProblem( n_eq = length(eq_idx) n_act = length(active_lb) + length(active_ub) + n_x = length(x_star) + opt_f = prob.f + iip_val = Val{SciMLBase.isinplace(opt_f)}() + has_p_val = Val{prob isa SciMLBase.AbstractOptimizationProblem}() + + # ---- ∇f at x_star: use stored gradient if available ---- + ∇f = if opt_f.grad !== nothing + _opt_eval_vec(opt_f.grad, n_x, x_star, p, iip_val, has_p_val) + else + _optimization_grad(x -> prob.f(x, p), x_star, ad_val, fdt_val) + end - # Jacobians of constraints w.r.t. x (needed for dual variables and KKT matrix) - Jxg = isempty(eq_idx) ? zeros(eltype(x_star), 0, length(x_star)) : - _optimization_jac(x -> g(x, p), x_star, ad_val, fdt_val) - Jxhι = n_act == 0 ? zeros(eltype(x_star), 0, length(x_star)) : - _optimization_jac(x -> h_I(x, p), x_star, ad_val, fdt_val) + # ---- Constraint Jacobians w.r.t. x: use cons_j if available ---- + # cons_j gives the full (n_cons × n_x) Jacobian in one call; slice for eq/active ineq. + # Sign convention: active_lb rows are negated because h_lb = lcons - cons(x,p). + if has_cons && opt_f.cons_j !== nothing + J_full = _opt_eval_mat(opt_f.cons_j, n_cons, n_x, x_star, p, iip_val, has_p_val) + Jxg = isempty(eq_idx) ? zeros(eltype(x_star), 0, n_x) : J_full[eq_idx, :] + Jxhι = n_act == 0 ? zeros(eltype(x_star), 0, n_x) : + vcat(isempty(active_lb) ? zeros(eltype(x_star), 0, n_x) : -J_full[active_lb, :], + isempty(active_ub) ? zeros(eltype(x_star), 0, n_x) : J_full[active_ub, :]) + else + Jxg = isempty(eq_idx) ? zeros(eltype(x_star), 0, n_x) : + _optimization_jac(x -> g(x, p), x_star, ad_val, fdt_val) + Jxhι = n_act == 0 ? zeros(eltype(x_star), 0, n_x) : + _optimization_jac(x -> h_I(x, p), x_star, ad_val, fdt_val) + end # Dual variables from stationarity condition: constraint_jac^T * [y*; z_I*] = -∇f(x*) - ∇f = _optimization_grad(x -> prob.f(x, p), x_star, ad_val, fdt_val) constraint_jac = vcat(Jxg, Jxhι) # (n_eq + n_act) × n_x # Solve overdetermined stationarity system via QR (n_x equations, n_eq+n_act unknowns) dual_vars = if n_eq + n_act == 0 @@ -117,19 +168,33 @@ function OptimizationAdjointProblem( y_star = n_eq > 0 ? dual_vars[1:n_eq] : eltype(x_star)[] zI_star = n_act > 0 ? dual_vars[(n_eq + 1):end] : eltype(x_star)[] - # Lagrangian with fixed multipliers - function L(x, q) + # Lagrangian with fixed multipliers (used for p-derivative computations below) + L = function(x, q) val = prob.f(x, q) n_eq > 0 && (val += dot(y_star, g(x, q))) n_act > 0 && (val += dot(zI_star, h_I(x, q))) return val end - # Assemble KKT matrix - Lxx = _optimization_hess(x -> L(x, p), x_star, ad_val, fdt_val) + # ---- Lagrangian Hessian w.r.t. x: use lag_h if available, else hess (unconstrained), else AD ---- + # lag_h(H, u, σ, μ, p) computes Hessian of σ*f + Σ μᵢ*consᵢ. + # Mapping from our dual vars to the full μ vector: + # μ[eq_idx[j]] = y_star[j] (g = cons[eq] - lcons, same sign as cons) + # μ[active_lb[j]] = -zI_star[j] (h_lb = lcons - cons → -cons contribution) + # μ[active_ub[j]] = zI_star[n_lb + j] (h_ub = cons - ucons → +cons contribution) + Lxx = if opt_f.lag_h !== nothing + mu_full = zeros(eltype(x_star), n_cons) + for (j, i) in enumerate(eq_idx); mu_full[i] = y_star[j] end + for (j, i) in enumerate(active_lb); mu_full[i] -= zI_star[j] end + for (j, i) in enumerate(active_ub); mu_full[i] += zI_star[length(active_lb) + j] end + _opt_eval_lag_h(opt_f.lag_h, n_x, x_star, one(eltype(x_star)), mu_full, p, iip_val, has_p_val) + elseif !has_cons && opt_f.hess !== nothing + _opt_eval_mat(opt_f.hess, n_x, n_x, x_star, p, iip_val, has_p_val) + else + _optimization_hess(x -> L(x, p), x_star, ad_val, fdt_val) + end - n_x = length(x_star) - N = n_x + n_eq + n_act + N = n_x + n_eq + n_act KKT = zeros(eltype(x_star), N, N) KKT[1:n_x, 1:n_x] = Lxx if n_eq > 0 From 4c4dbbba37ab245adeac838337073752561feccf Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 6 Apr 2026 14:38:02 -0400 Subject: [PATCH 08/19] account for lb / ub in optimization problem --- src/optimization_adjoint.jl | 41 ++++++++++++++++++++++++++++-------- test/optimization_adjoint.jl | 23 ++++++++++++++++++++ 2 files changed, 55 insertions(+), 9 deletions(-) diff --git a/src/optimization_adjoint.jl b/src/optimization_adjoint.jl index e791d9e83..6c323a2f7 100644 --- a/src/optimization_adjoint.jl +++ b/src/optimization_adjoint.jl @@ -129,6 +129,15 @@ function OptimizationAdjointProblem( n_eq = length(eq_idx) n_act = length(active_lb) + length(active_ub) n_x = length(x_star) + + # Variable bounds (lb/ub) as additional active inequality constraints. + # h_lb_var: lb[i] - x[i] = 0 when active → ∂/∂x = -e_i, ∂/∂p = 0 + # h_ub_var: x[i] - ub[i] = 0 when active → ∂/∂x = +e_i, ∂/∂p = 0 + lb = prob.lb + ub = prob.ub + active_lb_var = lb !== nothing ? findall(i -> abs(x_star[i] - lb[i]) <= atol, 1:n_x) : Int[] + active_ub_var = ub !== nothing ? findall(i -> abs(x_star[i] - ub[i]) <= atol, 1:n_x) : Int[] + n_bound = length(active_lb_var) + length(active_ub_var) opt_f = prob.f iip_val = Val{SciMLBase.isinplace(opt_f)}() has_p_val = Val{prob isa SciMLBase.AbstractOptimizationProblem}() @@ -156,17 +165,30 @@ function OptimizationAdjointProblem( _optimization_jac(x -> h_I(x, p), x_star, ad_val, fdt_val) end - # Dual variables from stationarity condition: constraint_jac^T * [y*; z_I*] = -∇f(x*) - constraint_jac = vcat(Jxg, Jxhι) # (n_eq + n_act) × n_x - # Solve overdetermined stationarity system via QR (n_x equations, n_eq+n_act unknowns) - dual_vars = if n_eq + n_act == 0 + # Append trivial Jacobian rows for active variable bounds + if n_bound > 0 + Jx_bound = zeros(eltype(x_star), n_bound, n_x) + for (j, i) in enumerate(active_lb_var) + Jx_bound[j, i] = -one(eltype(x_star)) + end + for (j, i) in enumerate(active_ub_var) + Jx_bound[length(active_lb_var) + j, i] = one(eltype(x_star)) + end + Jxhι = vcat(Jxhι, Jx_bound) + end + n_act_total = n_act + n_bound + + # Dual variables from stationarity condition: constraint_jac^T * [y*; z_I*; z_bound] = -∇f(x*) + constraint_jac = vcat(Jxg, Jxhι) # (n_eq + n_act_total) × n_x + # Solve overdetermined stationarity system via QR (n_x equations, n_eq+n_act_total unknowns) + dual_vars = if n_eq + n_act_total == 0 eltype(x_star)[] else dual_prob = LinearProblem(Matrix(constraint_jac'), -∇f) solve(dual_prob, LinearSolve.QRFactorization(); sensealg.linsolve_kwargs...).u end - y_star = n_eq > 0 ? dual_vars[1:n_eq] : eltype(x_star)[] - zI_star = n_act > 0 ? dual_vars[(n_eq + 1):end] : eltype(x_star)[] + y_star = n_eq > 0 ? dual_vars[1:n_eq] : eltype(x_star)[] + zI_star = n_act > 0 ? dual_vars[(n_eq + 1):(n_eq + n_act)] : eltype(x_star)[] # Lagrangian with fixed multipliers (used for p-derivative computations below) L = function(x, q) @@ -194,26 +216,27 @@ function OptimizationAdjointProblem( _optimization_hess(x -> L(x, p), x_star, ad_val, fdt_val) end - N = n_x + n_eq + n_act + N = n_x + n_eq + n_act_total KKT = zeros(eltype(x_star), N, N) KKT[1:n_x, 1:n_x] = Lxx if n_eq > 0 KKT[1:n_x, (n_x + 1):(n_x + n_eq)] = Jxg' KKT[(n_x + 1):(n_x + n_eq), 1:n_x] = Jxg end - if n_act > 0 + if n_act_total > 0 KKT[1:n_x, (n_x + n_eq + 1):N] = Jxhι' KKT[(n_x + n_eq + 1):N, 1:n_x] = Jxhι end # RHS: parameter Jacobians + # Variable bounds don't depend on p, so their p-Jacobian rows are zero. Lxp = _optimization_jac( q -> _optimization_grad(x -> L(x, q), x_star, ad_val, fdt_val), p, ad_val, fdt_val) Jpg = n_eq > 0 ? _optimization_jac(q -> g(x_star, q), p, ad_val, fdt_val) : zeros(eltype(x_star), 0, length(p)) Jphι = n_act > 0 ? _optimization_jac(q -> h_I(x_star, q), p, ad_val, fdt_val) : zeros(eltype(x_star), 0, length(p)) - RHS_p = vcat(Lxp, Jpg, Jphι) # (N × n_p) + RHS_p = vcat(Lxp, Jpg, Jphι, zeros(eltype(x_star), n_bound, length(p))) # (N × n_p) # Solve KKT system column-by-column, reusing the factorization via the cache interface n_p = size(RHS_p, 2) diff --git a/test/optimization_adjoint.jl b/test/optimization_adjoint.jl index 4eb533193..b2e65328a 100644 --- a/test/optimization_adjoint.jl +++ b/test/optimization_adjoint.jl @@ -382,6 +382,29 @@ end end end + @testset "Active variable bound (lb/ub)" begin + let + # Minimize (u - p[1])^2 s.t. u >= p[2] where p[2] > p[1] (lb active) + # Optimal solution: u* = p[2] + # du*/dp[1] = 0, du*/dp[2] = 1 + f = (u, p) -> (u[1] - p[1])^2 + + p = [1.0, 3.0] # unconstrained min at u=1, lb forces u>=3 + u0 = [3.0] + + opt_f = OptimizationFunction(f, Optimization.AutoForwardDiff()) + prob = OptimizationProblem(opt_f, u0, p; lb = [p[2]], ub = [Inf]) + + opt_sol = solve(prob, NLopt.LD_SLSQP()) + @test opt_sol.u[1] ≈ p[2] rtol = 1e-4 + + dgdu!(out, _, _, _, _) = (out[1] = 1.0) + dp = adjoint_sensitivities(opt_sol, nothing; sensealg = OptimizationAdjoint(), dgdu = dgdu!) + @test dp[1] ≈ 0.0 atol = 1e-4 # du*/dp[1] = 0 + @test dp[2] ≈ 1.0 rtol = 1e-4 # du*/dp[2] = 1 + end + end + @testset "p in both objective and constraint (both ∇²_xp L and J_p g nonzero)" begin let # Minimize (u1 - p[1])^2 + u2^2 s.t. u1 + u2 = p[2] From 80e2a8004fabaad4a9838b0110645ef0ab6658cf Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 8 Apr 2026 11:39:52 -0400 Subject: [PATCH 09/19] better test --- test/optimization_adjoint.jl | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/test/optimization_adjoint.jl b/test/optimization_adjoint.jl index b2e65328a..5eaedb011 100644 --- a/test/optimization_adjoint.jl +++ b/test/optimization_adjoint.jl @@ -384,24 +384,27 @@ end @testset "Active variable bound (lb/ub)" begin let - # Minimize (u - p[1])^2 s.t. u >= p[2] where p[2] > p[1] (lb active) - # Optimal solution: u* = p[2] - # du*/dp[1] = 0, du*/dp[2] = 1 - f = (u, p) -> (u[1] - p[1])^2 + # Minimize (u1-p)^2 + (u2-p)^2 s.t. u1 >= 2 (active lb, since p=0 < 2), u2 free + # u1* = 2 (pinned at bound) → du1*/dp = 0 (without lb in KKT this incorrectly gives 1) + # u2* = p = 0 (unconstrained) → du2*/dp = 1 + f = (u, p) -> (u[1] - p[1])^2 + (u[2] - p[1])^2 - p = [1.0, 3.0] # unconstrained min at u=1, lb forces u>=3 - u0 = [3.0] + p = [0.0] + u0 = [2.0, 0.0] opt_f = OptimizationFunction(f, Optimization.AutoForwardDiff()) - prob = OptimizationProblem(opt_f, u0, p; lb = [p[2]], ub = [Inf]) + prob = OptimizationProblem(opt_f, u0, p; lb = [2.0, -Inf], ub = [Inf, Inf]) opt_sol = solve(prob, NLopt.LD_SLSQP()) - @test opt_sol.u[1] ≈ p[2] rtol = 1e-4 + @test opt_sol.u[1] ≈ 2.0 rtol = 1e-4 # pinned at lb + @test opt_sol.u[2] ≈ p[1] rtol = 1e-4 # free, at unconstrained min - dgdu!(out, _, _, _, _) = (out[1] = 1.0) - dp = adjoint_sensitivities(opt_sol, nothing; sensealg = OptimizationAdjoint(), dgdu = dgdu!) - @test dp[1] ≈ 0.0 atol = 1e-4 # du*/dp[1] = 0 - @test dp[2] ≈ 1.0 rtol = 1e-4 # du*/dp[2] = 1 + dgdu1!(out, _, _, _, _) = (out[1] = 1.0; out[2] = 0.0) + dgdu2!(out, _, _, _, _) = (out[1] = 0.0; out[2] = 1.0) + dp1 = adjoint_sensitivities(opt_sol, nothing; sensealg = OptimizationAdjoint(), dgdu = dgdu1!) + dp2 = adjoint_sensitivities(opt_sol, nothing; sensealg = OptimizationAdjoint(), dgdu = dgdu2!) + @test dp1[1] ≈ 0.0 atol = 1e-4 # du1*/dp = 0 (pinned at bound) + @test dp2[1] ≈ 1.0 rtol = 1e-4 # du2*/dp = 1 (free variable) end end From ee152c29b8c6ac384f4f41cf49faec6e1fd3d55b Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 8 Apr 2026 11:42:09 -0400 Subject: [PATCH 10/19] finding dual variables doesn't need kwargs --- src/optimization_adjoint.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/optimization_adjoint.jl b/src/optimization_adjoint.jl index 6c323a2f7..8bf554374 100644 --- a/src/optimization_adjoint.jl +++ b/src/optimization_adjoint.jl @@ -185,7 +185,7 @@ function OptimizationAdjointProblem( eltype(x_star)[] else dual_prob = LinearProblem(Matrix(constraint_jac'), -∇f) - solve(dual_prob, LinearSolve.QRFactorization(); sensealg.linsolve_kwargs...).u + solve(dual_prob, LinearSolve.QRFactorization()).u end y_star = n_eq > 0 ? dual_vars[1:n_eq] : eltype(x_star)[] zI_star = n_act > 0 ? dual_vars[(n_eq + 1):(n_eq + n_act)] : eltype(x_star)[] From e4cf67045ce6ab0a615bbf4301219bfff770037e Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 22 Apr 2026 12:33:43 -0400 Subject: [PATCH 11/19] DifferentiationInterface gradients, with vecjacobian --- Project.toml | 7 +- ext/SciMLSensitivityMooncakeExt.jl | 1 + src/SciMLSensitivity.jl | 1 + src/adjoint_common.jl | 1 + src/concrete_solve.jl | 4 - src/derivative_wrappers.jl | 25 +- src/optimization_adjoint.jl | 384 +++++++++++++++++++++-------- src/sensitivity_algorithms.jl | 10 +- src/sensitivity_interface.jl | 18 +- 9 files changed, 310 insertions(+), 141 deletions(-) diff --git a/Project.toml b/Project.toml index 96a5b4f00..2316e988c 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" @@ -88,7 +89,6 @@ Lux = "1" Markdown = "1.10" ModelingToolkit = "10, 11" Mooncake = "0.5.24" -Reactant = "0.2.22" NLsolve = "4.5.1" NonlinearSolve = "3.0.1, 4" SCCNonlinearSolve = "1" @@ -107,6 +107,7 @@ PreallocationTools = "1.1.1" QuadGK = "2.9.1" Random = "1.10" RandomNumbers = "1.5.3" +Reactant = "0.2.22" RecursiveArrayTools = "3.27.2, 4" Reexport = "1.0" ReverseDiff = "1.15.1" @@ -141,12 +142,11 @@ Lux = "b2108857-7c20-44ae-9111-449ecde12c47" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" -Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" SCCNonlinearSolve = "9dfe8606-65a1-4bb3-9748-cb89d1561431" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" -OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" OptimizationNLopt = "4e6fcdb7-1186-4e1f-a706-475e75c168bb" +OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" OrdinaryDiffEqFIRK = "5960d6e9-dd7a-4743-88e7-cf307b64f125" OrdinaryDiffEqLowOrderRK = "1344f307-1e59-4825-a18e-ace9aa3fa4c6" @@ -155,6 +155,7 @@ OrdinaryDiffEqRosenbrock = "43230ef6-c299-4910-a778-202eb28ce4ce" OrdinaryDiffEqSDIRK = "2d112036-d095-4a1e-ab9a-08536f3ecdbf" OrdinaryDiffEqStabilizedRK = "358294b1-0aab-51c3-aafe-ad5ab194a2ad" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" diff --git a/ext/SciMLSensitivityMooncakeExt.jl b/ext/SciMLSensitivityMooncakeExt.jl index 83ff34c48..0d8c47555 100644 --- a/ext/SciMLSensitivityMooncakeExt.jl +++ b/ext/SciMLSensitivityMooncakeExt.jl @@ -3,6 +3,7 @@ module SciMLSensitivityMooncakeExt using SciMLSensitivity: SciMLSensitivity, FakeIntegrator using Mooncake: Mooncake import SciMLSensitivity: get_paramjac_config, get_cb_paramjac_config, mooncake_run_ad, + MooncakeVJP, MooncakeLoaded, DiffEqBase, MooncakeAdjoint, _init_originator_gradient using SciMLSensitivity: SciMLBase, SciMLStructures, canonicalize, Tunable, isscimlstructure, diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl index d6173e0ea..68ed640da 100644 --- a/src/SciMLSensitivity.jl +++ b/src/SciMLSensitivity.jl @@ -47,6 +47,7 @@ using OrdinaryDiffEqCore: OrdinaryDiffEqCore, BrownFullBasicInit, DefaultInit, # AD Backends using ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, ZeroTangent, AbstractThunk, AbstractTangent +import DifferentiationInterface as DI using Enzyme: Enzyme using FiniteDiff: FiniteDiff using ForwardDiff: ForwardDiff diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl index b9df97fcb..2ab91bc74 100644 --- a/src/adjoint_common.jl +++ b/src/adjoint_common.jl @@ -614,6 +614,7 @@ function mooncake_run_ad(paramjac_config, y, p, t, λ) error(msg) end + function get_pf(::ReactantVJP, prob, _f) isinplace = DiffEqBase.isinplace(prob) isRODE = isa(prob, RODEProblem) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 938ebdb5e..389755011 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -2814,10 +2814,6 @@ function SciMLBase._concrete_solve_adjoint( u0, p, originator::SciMLBase.ADOriginator, args...; save_idxs = nothing, kwargs... ) where {CS, AD, FDT} - if prob.lcons === nothing - error("OptimizationAdjoint requires a constrained OptimizationProblem (lcons/ucons). " * - "For unconstrained problems, use UnconstrainedOptimizationAdjoint instead.") - end _prob = remake(prob, u0 = u0, p = p) opt_sol = solve(_prob, alg, args...; kwargs...) diff --git a/src/derivative_wrappers.jl b/src/derivative_wrappers.jl index c76a65ece..d215dd7dc 100644 --- a/src/derivative_wrappers.jl +++ b/src/derivative_wrappers.jl @@ -146,6 +146,17 @@ function jacobian( return J end +function gradient( + f, x::AbstractArray{<:Number}, + alg::AbstractOverloadingSensitivityAlgorithm + ) + if alg_autodiff(alg) + ForwardDiff.gradient(unwrapped_f(f), x) + else + FiniteDiff.finite_difference_gradient(f, x, diff_type(alg)) + end +end + function jacobian!( J::Nothing, f, x::AbstractArray{<:Number}, fx::Union{Nothing, AbstractArray{<:Number}}, @@ -677,10 +688,11 @@ function _vecjacobian!( (; tunables, repack) = S.diffcache end - u0 = state_values(prob) - if prob isa AbstractNonlinearProblem || + if prob isa AbstractNonlinearProblem || prob isa SciMLBase.AbstractOptimizationCache || ( - eltype(λ) <: eltype(u0) && t isa eltype(u0) && + let u0 = state_values(prob) + eltype(λ) <: eltype(u0) && t isa eltype(u0) + end && compile_tape(sensealg.autojacvec) ) tape = S.diffcache.paramjac_config @@ -731,7 +743,8 @@ function _vecjacobian!( end end - if prob isa AbstractNonlinearProblem + _no_time = prob isa AbstractNonlinearProblem || prob isa SciMLBase.AbstractOptimizationCache + if _no_time tu, tp = ReverseDiff.input_hook(tape) else if W === nothing @@ -743,13 +756,13 @@ function _vecjacobian!( output = ReverseDiff.output_hook(tape) ReverseDiff.unseed!(tu) # clear any "leftover" derivatives from previous calls ReverseDiff.unseed!(tp) - if !(prob isa AbstractNonlinearProblem) + if !_no_time ReverseDiff.unseed!(tt) end W !== nothing && ReverseDiff.unseed!(tW) ReverseDiff.value!(tu, y) p isa SciMLBase.NullParameters || ReverseDiff.value!(tp, tunables) - if !(prob isa AbstractNonlinearProblem) + if !_no_time ReverseDiff.value!(tt, [t]) end W !== nothing && ReverseDiff.value!(tW, W) diff --git a/src/optimization_adjoint.jl b/src/optimization_adjoint.jl index 8bf554374..435b901bb 100644 --- a/src/optimization_adjoint.jl +++ b/src/optimization_adjoint.jl @@ -1,20 +1,28 @@ -# Differentiation helpers: dispatch on autodiff type parameter (Val{true} = ForwardDiff, -# Val{false} = FiniteDiff with the given FDT scheme) -_optimization_grad(f, x, ::Val{true}, ::FDT) where {FDT} = ForwardDiff.gradient(f, x) -function _optimization_grad(f, x, ::Val{false}, ::FDT) where {FDT} - FiniteDiff.finite_difference_gradient(f, x, FDT()) +# SensitivityFunction subtype for the OptimizationAdjoint VJP path. +# f = (_, q_full, _) -> F(x*, q_full) = [∇_x L; g; h_I], OOP, output size M = n_x + n_eq + n_act. +# y is a zeros(M) dummy state used to size AD buffers in vecjacobian! backends. +# λ holds the adjoint cotangent (λ_full[1:M]) from the KKT solve. +# dp is the pre-allocated output gradient buffer (size n_p), written by vecjacobian!. +struct OptimizationAdjointSensitivityFunction{ + C <: AdjointDiffCache, + Alg <: OptimizationAdjoint, + F, + SolType, + yType, + λType, + dpType, + } <: SensitivityFunction + diffcache::C + sensealg::Alg + f::F + sol::SolType + y::yType + λ::λType + dp::dpType end -_optimization_jac(f, x, ::Val{true}, ::FDT) where {FDT} = ForwardDiff.jacobian(f, x) -function _optimization_jac(f, x, ::Val{false}, ::FDT) where {FDT} - FiniteDiff.finite_difference_jacobian(f, x, FDT()) -end - -_optimization_hess(f, x, ::Val{true}, ::FDT) where {FDT} = ForwardDiff.hessian(f, x) -function _optimization_hess(f, x, ::Val{false}, ::FDT) where {FDT} - FiniteDiff.finite_difference_jacobian( - y -> FiniteDiff.finite_difference_gradient(f, y, FDT()), x, FDT()) -end +# Override inplace_sensitivity: f is always OOP for optimization +inplace_sensitivity(::OptimizationAdjointSensitivityFunction) = false # Evaluate OptimizationFunction auxiliary fields (grad, hess, cons_j, lag_h). # Dispatched on: @@ -47,30 +55,14 @@ end _opt_eval_lag_h(fn, _, x, σ, μ, p, ::Val{false}, ::Val{true}) = fn(x, σ, μ, p) _opt_eval_lag_h(fn, _, x, σ, μ, _, ::Val{false}, ::Val{false}) = fn(x, σ, μ) -""" - OptimizationAdjointProblem(prob, opt_sol, sensealg, p) -> Jpx - -Compute the KKT-based parameter Jacobian `Jpx` (n_x × n_p) for a constrained -`OptimizationProblem`, where `Jpx[i,j] = ∂x*[i]/∂p[j]`. - -Uses the implicit function theorem applied to the KKT conditions: - - [∇²_xx L, J_x g^T, J_x h_I^T] [J_p x ] [∇²_xp L] - [J_x g, 0, 0 ] [J_p y ] = -[J_p g ] - [J_x h_I, 0, 0 ] [J_p z_I] [J_p h_I ] - -where g are equality constraints, h_I are active inequality constraints, and -y*, z_I* are the corresponding dual variables. -""" -function OptimizationAdjointProblem( +function OptimizationAdjointSensitivityFunction( prob, opt_sol, - sensealg::OptimizationAdjoint{CS, AD, FDT}, - p - ) where {CS, AD, FDT} + sensealg::OptimizationAdjoint{CS, AD, FDT, VJP, LS, LK, OAD, AT}, + p, + Δu + ) where {CS, AD, FDT, VJP, LS, LK, OAD, AT} x_star = opt_sol.u - ad_val = Val{AD}() - fdt_val = FDT() lcons = prob.lcons ucons = prob.ucons @@ -82,6 +74,9 @@ function OptimizationAdjointProblem( # `(res, x) -> f.cons(res, x, captured_p)` from OptimizationBase.instantiate_function. # The captured field names are mangled (e.g. `#95#f`), so we search by type to find # the captured OptimizationFunction, regardless of field ordering. + # Very unfortunate that this is needed, but sensitivity requires the three arg version of + # the constraint function. To make this stable and not hacky there needs to be a way to access + # it without going inside of the closure. if has_cons n_cons = length(lcons) _cons3 = if applicable(prob.f.cons, zeros(n_cons), x_star, p) @@ -93,9 +88,18 @@ function OptimizationAdjointProblem( end captured_f.cons end + cons_cache = LazyBufferCache(_ -> (n_cons,)) eval_cons = function (x, q) - T = promote_type(eltype(x), eltype(q)) - res = zeros(T, n_cons) + # When eltype(x) and eltype(q) match, reuse the cache. Otherwise infer the + # arithmetic result type via promote_op (compile-time inference) — ForwardDiff's + # @define_binary_dual_op bypasses promote_type, so Dual + TrackedReal returns + # Dual{Tag, TrackedReal, N} which promote_type would not predict. + res = if eltype(x) === eltype(q) + cons_cache[q] + else + T = Base.promote_op(+, eltype(x), eltype(q)) + Vector{T}(undef, n_cons) + end _cons3(res, x, q) return res end @@ -111,84 +115,144 @@ function OptimizationAdjointProblem( # Evaluate constraints at solution c_val = eval_cons(x_star, p) - # Find active inequality constraints + # Find active inequality constraints (proximity-based initial estimate). + # Refined below via multiplier sign check to avoid spurious active constraints. atol = sensealg.active_tol === nothing ? sqrt(eps(eltype(x_star))) : sensealg.active_tol active_lb = filter(i -> abs(c_val[i] - lcons[i]) <= atol, ineq_idx) active_ub = filter(i -> abs(c_val[i] - ucons[i]) <= atol, ineq_idx) - # Constraint residual functions shifted to = 0 at optimum - # Equality: g(x,p) = cons(x,p)[eq_idx] - lcons[eq_idx] - # Active ineq lower bound: h_lb(x,p) = lcons[i] - cons(x,p)[i] (= 0 when active) - # Active ineq upper bound: h_ub(x,p) = cons(x,p)[i] - ucons[i] (= 0 when active) - g(x, q) = eval_cons(x, q)[eq_idx] .- lcons[eq_idx] - h_I(x, q) = vcat( - isempty(active_lb) ? eltype(x_star)[] : lcons[active_lb] .- eval_cons(x, q)[active_lb], - isempty(active_ub) ? eltype(x_star)[] : eval_cons(x, q)[active_ub] .- ucons[active_ub] - ) + # Equality constraint residual: g(x,p) = cons(x,p)[eq_idx] - lcons[eq_idx] + g(x, q) = eval_cons(x, q)[eq_idx] .- lcons[eq_idx] - n_eq = length(eq_idx) - n_act = length(active_lb) + length(active_ub) - n_x = length(x_star) + n_eq = length(eq_idx) + n_x = length(x_star) - # Variable bounds (lb/ub) as additional active inequality constraints. - # h_lb_var: lb[i] - x[i] = 0 when active → ∂/∂x = -e_i, ∂/∂p = 0 - # h_ub_var: x[i] - ub[i] = 0 when active → ∂/∂x = +e_i, ∂/∂p = 0 lb = prob.lb ub = prob.ub active_lb_var = lb !== nothing ? findall(i -> abs(x_star[i] - lb[i]) <= atol, 1:n_x) : Int[] active_ub_var = ub !== nothing ? findall(i -> abs(x_star[i] - ub[i]) <= atol, 1:n_x) : Int[] - n_bound = length(active_lb_var) + length(active_ub_var) - opt_f = prob.f + + opt_f = prob.f iip_val = Val{SciMLBase.isinplace(opt_f)}() has_p_val = Val{prob isa SciMLBase.AbstractOptimizationProblem}() + objective_ad = sensealg.objective_ad + # ---- ∇f at x_star: use stored gradient if available ---- ∇f = if opt_f.grad !== nothing _opt_eval_vec(opt_f.grad, n_x, x_star, p, iip_val, has_p_val) else - _optimization_grad(x -> prob.f(x, p), x_star, ad_val, fdt_val) + DI.gradient(x -> prob.f(x, p), objective_ad, x_star) end - # ---- Constraint Jacobians w.r.t. x: use cons_j if available ---- - # cons_j gives the full (n_cons × n_x) Jacobian in one call; slice for eq/active ineq. - # Sign convention: active_lb rows are negated because h_lb = lcons - cons(x,p). - if has_cons && opt_f.cons_j !== nothing - J_full = _opt_eval_mat(opt_f.cons_j, n_cons, n_x, x_star, p, iip_val, has_p_val) - Jxg = isempty(eq_idx) ? zeros(eltype(x_star), 0, n_x) : J_full[eq_idx, :] - Jxhι = n_act == 0 ? zeros(eltype(x_star), 0, n_x) : - vcat(isempty(active_lb) ? zeros(eltype(x_star), 0, n_x) : -J_full[active_lb, :], - isempty(active_ub) ? zeros(eltype(x_star), 0, n_x) : J_full[active_ub, :]) + # Precompute full constraint Jacobian once if cons_j is available (reused across passes). + J_full = has_cons && opt_f.cons_j !== nothing ? + _opt_eval_mat(opt_f.cons_j, n_cons, n_x, x_star, p, iip_val, has_p_val) : nothing + + # Equality constraint Jacobian (fixed; independent of active set). + Jxg = if isempty(eq_idx) + zeros(eltype(x_star), 0, n_x) + elseif J_full !== nothing + J_full[eq_idx, :] else - Jxg = isempty(eq_idx) ? zeros(eltype(x_star), 0, n_x) : - _optimization_jac(x -> g(x, p), x_star, ad_val, fdt_val) - Jxhι = n_act == 0 ? zeros(eltype(x_star), 0, n_x) : - _optimization_jac(x -> h_I(x, p), x_star, ad_val, fdt_val) + DI.jacobian(x -> g(x, p), objective_ad, x_star) end - # Append trivial Jacobian rows for active variable bounds - if n_bound > 0 - Jx_bound = zeros(eltype(x_star), n_bound, n_x) - for (j, i) in enumerate(active_lb_var) - Jx_bound[j, i] = -one(eltype(x_star)) - end - for (j, i) in enumerate(active_ub_var) - Jx_bound[length(active_lb_var) + j, i] = one(eltype(x_star)) - end - Jxhι = vcat(Jxhι, Jx_bound) + # Active ineq lower bound: h_lb(x,p) = lcons[i] - cons(x,p)[i] (= 0 when active) + # Active ineq upper bound: h_ub(x,p) = cons(x,p)[i] - ucons[i] (= 0 when active) + # Variable bound active lower: h_lb_var = lb[i] - x[i] (∂/∂x = -eᵢ, ∂/∂p = 0) + # Variable bound active upper: h_ub_var = x[i] - ub[i] (∂/∂x = +eᵢ, ∂/∂p = 0) + n_act = length(active_lb) + length(active_ub) + n_bound = length(active_lb_var) + length(active_ub_var) + + h_I = (x, q) -> begin + c = eval_cons(x, q) + vcat( + isempty(active_lb) ? eltype(x_star)[] : lcons[active_lb] .- c[active_lb], + isempty(active_ub) ? eltype(x_star)[] : c[active_ub] .- ucons[active_ub] + ) end - n_act_total = n_act + n_bound - # Dual variables from stationarity condition: constraint_jac^T * [y*; z_I*; z_bound] = -∇f(x*) - constraint_jac = vcat(Jxg, Jxhι) # (n_eq + n_act_total) × n_x - # Solve overdetermined stationarity system via QR (n_x equations, n_eq+n_act_total unknowns) + Jxhι_cons = if n_act == 0 + zeros(eltype(x_star), 0, n_x) + elseif J_full !== nothing + vcat( + isempty(active_lb) ? zeros(eltype(x_star), 0, n_x) : -J_full[active_lb, :], + isempty(active_ub) ? zeros(eltype(x_star), 0, n_x) : J_full[active_ub, :] + ) + else + DI.jacobian(x -> h_I(x, p), objective_ad, x_star) + end + + Jx_bound = zeros(eltype(x_star), n_bound, n_x) + for (j, i) in enumerate(active_lb_var); Jx_bound[j, i] = -one(eltype(x_star)); end + for (j, i) in enumerate(active_ub_var) + Jx_bound[length(active_lb_var) + j, i] = one(eltype(x_star)) + end + Jxhι = vcat(Jxhι_cons, Jx_bound) + + # Dual variables from stationarity: [Jxg; Jxhι]' * [y*; z_I*; z_bound] = -∇f + n_act_total = n_act + n_bound dual_vars = if n_eq + n_act_total == 0 eltype(x_star)[] else - dual_prob = LinearProblem(Matrix(constraint_jac'), -∇f) - solve(dual_prob, LinearSolve.QRFactorization()).u + solve(LinearProblem(Matrix(vcat(Jxg, Jxhι)'), -∇f), LinearSolve.QRFactorization()).u + end + y_star = n_eq > 0 ? dual_vars[1:n_eq] : eltype(x_star)[] + zI_star = n_act > 0 ? dual_vars[(n_eq+1):(n_eq+n_act)] : eltype(x_star)[] + z_bound_star = n_bound > 0 ? dual_vars[(n_eq+n_act+1):end] : eltype(x_star)[] + + # Multiplier sign check: KKT requires all inequality multipliers ≥ 0 at a minimum. + # Negative multipliers indicate spuriously-included constraints (close to bound but inactive). + # Drop those and redo only the Jxhι build and dual solve — no extra cost if all signs are good. + mtol = sqrt(eps(eltype(x_star))) + if (n_act > 0 && any(<(-mtol), zI_star)) || + (n_bound > 0 && any(<(-mtol), z_bound_star)) + n_lb = length(active_lb) + n_lb_var = length(active_lb_var) + active_lb = active_lb[findall(j -> zI_star[j] >= -mtol, 1:n_lb)] + active_ub = active_ub[findall(j -> zI_star[n_lb+j] >= -mtol, 1:length(active_ub))] + active_lb_var = active_lb_var[findall(j -> z_bound_star[j] >= -mtol, 1:n_lb_var)] + active_ub_var = active_ub_var[findall(j -> z_bound_star[n_lb_var+j] >= -mtol, + 1:length(active_ub_var))] + n_act = length(active_lb) + length(active_ub) + n_bound = length(active_lb_var) + length(active_ub_var) + + h_I = (x, q) -> begin + c = eval_cons(x, q) + vcat( + isempty(active_lb) ? eltype(x_star)[] : lcons[active_lb] .- c[active_lb], + isempty(active_ub) ? eltype(x_star)[] : c[active_ub] .- ucons[active_ub] + ) + end + + Jxhι_cons = if n_act == 0 + zeros(eltype(x_star), 0, n_x) + elseif J_full !== nothing + vcat( + isempty(active_lb) ? zeros(eltype(x_star), 0, n_x) : -J_full[active_lb, :], + isempty(active_ub) ? zeros(eltype(x_star), 0, n_x) : J_full[active_ub, :] + ) + else + DI.jacobian(x -> h_I(x, p), objective_ad, x_star) + end + + Jx_bound = zeros(eltype(x_star), n_bound, n_x) + for (j, i) in enumerate(active_lb_var); Jx_bound[j, i] = -one(eltype(x_star)); end + for (j, i) in enumerate(active_ub_var) + Jx_bound[length(active_lb_var) + j, i] = one(eltype(x_star)) + end + Jxhι = vcat(Jxhι_cons, Jx_bound) + + n_act_total = n_act + n_bound + dual_vars = if n_eq + n_act_total == 0 + eltype(x_star)[] + else + solve(LinearProblem(Matrix(vcat(Jxg, Jxhι)'), -∇f), LinearSolve.QRFactorization()).u + end + y_star = n_eq > 0 ? dual_vars[1:n_eq] : eltype(x_star)[] + zI_star = n_act > 0 ? dual_vars[(n_eq+1):(n_eq+n_act)] : eltype(x_star)[] end - y_star = n_eq > 0 ? dual_vars[1:n_eq] : eltype(x_star)[] - zI_star = n_act > 0 ? dual_vars[(n_eq + 1):(n_eq + n_act)] : eltype(x_star)[] # Lagrangian with fixed multipliers (used for p-derivative computations below) L = function(x, q) @@ -213,7 +277,7 @@ function OptimizationAdjointProblem( elseif !has_cons && opt_f.hess !== nothing _opt_eval_mat(opt_f.hess, n_x, n_x, x_star, p, iip_val, has_p_val) else - _optimization_hess(x -> L(x, p), x_star, ad_val, fdt_val) + DI.hessian(x -> L(x, p), objective_ad, x_star) end N = n_x + n_eq + n_act_total @@ -228,24 +292,128 @@ function OptimizationAdjointProblem( KKT[(n_x + n_eq + 1):N, 1:n_x] = Jxhι end - # RHS: parameter Jacobians - # Variable bounds don't depend on p, so their p-Jacobian rows are zero. - Lxp = _optimization_jac( - q -> _optimization_grad(x -> L(x, q), x_star, ad_val, fdt_val), p, ad_val, fdt_val) - Jpg = n_eq > 0 ? _optimization_jac(q -> g(x_star, q), p, ad_val, fdt_val) : - zeros(eltype(x_star), 0, length(p)) - Jphι = n_act > 0 ? _optimization_jac(q -> h_I(x_star, q), p, ad_val, fdt_val) : - zeros(eltype(x_star), 0, length(p)) - RHS_p = vcat(Lxp, Jpg, Jphι, zeros(eltype(x_star), n_bound, length(p))) # (N × n_p) - - # Solve KKT system column-by-column, reusing the factorization via the cache interface - n_p = size(RHS_p, 2) - Jpx = zeros(eltype(x_star), n_x, n_p) - kkt_cache = LinearSolve.init(LinearProblem(KKT, -RHS_p[:, 1]), sensealg.linsolve; - sensealg.linsolve_kwargs...) - for j in 1:n_p - kkt_cache.b = -RHS_p[:, j] - Jpx[:, j] = LinearSolve.solve!(kkt_cache).u[1:n_x] + # KKT is symmetric, so KKT' = KKT. Solve KKT * λ_full = [Δu; 0; ...; 0] once for all parameters. + rhs_adj = vcat(Δu, zeros(eltype(x_star), n_eq + n_act_total)) + λ_full = solve(LinearProblem(KKT, rhs_adj), sensealg.linsolve; + sensealg.linsolve_kwargs...).u + + if p === nothing || p isa SciMLBase.NullParameters + tunables, repack = p, identity + elseif isscimlstructure(p) + tunables, repack, _ = canonicalize(Tunable(), p) + else + tunables, repack = p, identity + end + + autojacvec = sensealg.autojacvec + + # f_F: OOP function (_, q_full, _) -> F(x*, q_full), the KKT residual as a function of p. + # F = [∇_x L(x*, q); g(x*, q); h_I(x*, q)] — output size M = n_x + n_eq + n_act. + # Variable-bound rows are omitted since ∂(lb - x)/∂p = 0, so they don't contribute to dp. + # y = zeros(M) is passed as the dummy state; f_F ignores it, but backends use it to + # size their buffers, so buffers will be M-sized and match the output of f_F. + M = n_x + n_eq + n_act + y = zeros(eltype(x_star), M) + f_F = let L = L, g = g, h_I = h_I, x_star = x_star, + objective_ad = objective_ad, n_eq = n_eq, n_act = n_act + function(_, q_full, _) + grad_L = DI.gradient(x -> L(x, q_full), objective_ad, x_star) + n_eq == 0 && n_act == 0 && return grad_L + n_eq > 0 && n_act == 0 && return vcat(grad_L, g(x_star, q_full)) + n_eq == 0 && n_act > 0 && return vcat(grad_L, h_I(x_star, q_full)) + vcat(grad_L, g(x_star, q_full), h_I(x_star, q_full)) + end end - return Jpx # (n_x × n_p) + + # λ: adjoint cotangent for f_F — drop the variable-bound rows of λ_full (∂/∂p = 0). + λ = λ_full[1:M] + + # Build pf and paramjac_config via the same adjointdiffcache machinery used by + # SteadyStateAdjoint. f_F is OOP, output size M, no time argument. + # For Bool dispatch: pf = ParamGradientWrapper(f_F, nothing, y), pJ = M × n_p matrix. + # For VJP backends: pf/paramjac_config built by get_pf/get_paramjac_config as usual. + _needs_repack = isscimlstructure(p) && !(p isa AbstractArray) + pf, paramjac_config, pJ = if autojacvec isa ReverseDiffVJP + # 2-input tape (no time): mirrors the AbstractNonlinearProblem branch in adjointdiffcache. + # get_paramjac_config always builds a 3-input (y, p, [t]) tape, which fails for t=nothing. + _tape = ReverseDiff.GradientTape((y, tunables)) do u, q + vec(f_F(u, _needs_repack ? repack(q) : q, nothing)) + end + _config = compile_tape(autojacvec) ? ReverseDiff.compile(_tape) : _tape + nothing, _config, nothing + elseif autojacvec isa EnzymeVJP + _pf = f_F # OOP: Enzyme.make_zero(f) called inline in _vecjacobian! + _needs_shadow = _needs_repack + _shadow_p = _needs_shadow ? repack(zero(tunables)) : nothing + _config = get_paramjac_config(autojacvec, p, f_F, y, tunables, nothing; + numindvar = M, alg = nothing) + _config = (_config..., Enzyme.make_zero(_pf), _shadow_p) + _pf, _config, nothing + elseif autojacvec isa MooncakeVJP + _pf = let f_F = f_F + (out, _, q_full, _) -> (out .= f_F(nothing, q_full, nothing); out) + end + _pf = if _needs_repack + let _pf = _pf, repack = repack + (out, u, q_t, t) -> _pf(out, u, repack(q_t), t) + end + else + _pf + end + _config = get_paramjac_config(MooncakeLoaded(), autojacvec, _pf, tunables, f_F, y, nothing) + _pf, _config, nothing + elseif autojacvec isa ZygoteVJP + nothing, nothing, nothing + elseif autojacvec isa Bool + # Bool dispatch: ParamGradientWrapper (OOP) + pJ matrix + _pgrad_f = _needs_repack ? + (u, q_t, t) -> f_F(u, repack(q_t), t) : + f_F + _pf = ParamGradientWrapper(_pgrad_f, nothing, y) + _pJ = zeros(eltype(x_star), M, length(tunables)) + _pf, nothing, _pJ + else + nothing, nothing, nothing + end + + diffcache = AdjointDiffCache( + nothing, pf, nothing, nothing, pJ, # uf, pf, g, J, pJ + nothing, # dg_val + nothing, nothing, # jac_config, g_grad_config + paramjac_config, + nothing, nothing, nothing, # jac_noise, paramjac_noise, f_cache + nothing, nothing, # dgdu, dgdp + nothing, nothing, nothing, # diffvar_idxs, algevar_idxs, factorized_mass_matrix + false, # issemiexplicitdae + tunables, repack + ) + + dp = zeros(eltype(x_star), length(tunables)) + + return OptimizationAdjointSensitivityFunction(diffcache, sensealg, f_F, opt_sol, y, λ, dp) +end + +""" + OptimizationAdjointProblem(prob, opt_sol, sensealg, p, Δu) -> dp + +Compute the parameter sensitivity `dp = dG/dp` for a scalar loss `G` via one adjoint KKT solve. + +Given `Δu = dG/dx* ∈ Rⁿˣ` (the cotangent of the optimal solution supplied by the caller), +solves the adjoint system `KKT * λ_full = [Δu; 0; ...; 0]` (one linear solve, exploiting +KKT symmetry), then returns `dp = -(∂F/∂p)' · λ` via a single `vecjacobian!` call, where +`F(x*, p) = [∇_x L(x*, p); g(x*, p); h_I(x*, p)]` is the KKT residual and `λ = λ_full[1:M]` +(variable-bound rows are dropped since `∂(lb-x)/∂p = 0`). + +The VJP is computed via `sensealg.autojacvec` (falls back to ForwardDiff otherwise). +""" +function OptimizationAdjointProblem( + prob, + opt_sol, + sensealg::OptimizationAdjoint, + p, + Δu + ) + S = OptimizationAdjointSensitivityFunction(prob, opt_sol, sensealg, p, Δu) + vecjacobian!(nothing, S.y, S.λ, S.diffcache.tunables, nothing, S; dgrad = S.dp) + return -S.dp end diff --git a/src/sensitivity_algorithms.jl b/src/sensitivity_algorithms.jl index a47f79008..31667d0b0 100644 --- a/src/sensitivity_algorithms.jl +++ b/src/sensitivity_algorithms.jl @@ -1404,11 +1404,12 @@ end function UnconstrainedOptimizationAdjoint(; chunk_size = 0, autodiff = true, - diff_type = Val{:central}, objective_ad = true, autojacvec = nothing, linsolve = nothing, + objective_ad = AutoForwardDiff(), + autojacvec = nothing, linsolve = nothing, linsolve_kwargs = (;) ) return UnconstrainedOptimizationAdjoint{ - chunk_size, autodiff, diff_type, typeof(autojacvec), + chunk_size, autodiff, Val{:central}, typeof(autojacvec), typeof(linsolve), typeof(linsolve_kwargs), typeof(objective_ad), }(autojacvec, linsolve, linsolve_kwargs, objective_ad) end @@ -1434,11 +1435,12 @@ end function OptimizationAdjoint(; chunk_size = 0, autodiff = true, - diff_type = Val{:central}, objective_ad = true, autojacvec = nothing, + objective_ad = AutoForwardDiff(), + autojacvec = nothing, linsolve = nothing, linsolve_kwargs = (;), active_tol = nothing ) return OptimizationAdjoint{ - chunk_size, autodiff, diff_type, typeof(autojacvec), + chunk_size, autodiff, Val{:central}, typeof(autojacvec), typeof(linsolve), typeof(linsolve_kwargs), typeof(objective_ad), typeof(active_tol), }(autojacvec, linsolve, linsolve_kwargs, objective_ad, active_tol) end diff --git a/src/sensitivity_interface.jl b/src/sensitivity_interface.jl index b03a5a129..708b5e527 100644 --- a/src/sensitivity_interface.jl +++ b/src/sensitivity_interface.jl @@ -438,11 +438,10 @@ function _adjoint_sensitivities( ) dgdu === nothing && error("dgdu must be specified for OptimizationAdjoint") - p = SymbolicIndexingInterface.parameter_values(sol) - Jpx = OptimizationAdjointProblem(sol.prob, sol, sensealg, p) + p = SymbolicIndexingInterface.parameter_values(sol) Δu = zero(sol.u) dgdu(Δu, sol.u, p, nothing, nothing) - return Jpx' * Δu + return OptimizationAdjointProblem(sol.prob, sol, sensealg, p, Δu) end function _adjoint_sensitivities( @@ -548,19 +547,6 @@ function _adjoint_sensitivities( return SteadyStateAdjointProblem(sol, sensealg, alg, dgdu, dgdp, g; kwargs...) end -function _adjoint_sensitivities( - sol, sensealg::OptimizationAdjoint, alg; - dgdu = nothing, kwargs... - ) - dgdu === nothing && - error("dgdu must be specified for OptimizationAdjoint") - prob = sol.prob - p = prob.p - Jpx = OptimizationAdjointProblem(prob, sol, sensealg, p) - Δu = zero(sol.u) - dgdu(Δu, sol.u, p, nothing, nothing) - return Jpx' * Δu -end @doc doc""" ```julia From 9aede5cb575e58d3bafca2677bb642c897a0f05b Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 13 May 2026 11:20:54 -0400 Subject: [PATCH 12/19] introduce wrapper types, use type parameters for AD choosing --- Project.toml | 1 - src/SciMLSensitivity.jl | 1 - src/concrete_solve.jl | 7 +++--- src/derivative_wrappers.jl | 11 +++++++++ src/optimization_adjoint.jl | 39 ++++++++++++++++++++---------- src/sensitivity_algorithms.jl | 45 ++++++++++++++--------------------- src/sensitivity_interface.jl | 5 ++-- 7 files changed, 62 insertions(+), 47 deletions(-) diff --git a/Project.toml b/Project.toml index 2316e988c..112274cff 100644 --- a/Project.toml +++ b/Project.toml @@ -13,7 +13,6 @@ ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503" -DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl index 68ed640da..d6173e0ea 100644 --- a/src/SciMLSensitivity.jl +++ b/src/SciMLSensitivity.jl @@ -47,7 +47,6 @@ using OrdinaryDiffEqCore: OrdinaryDiffEqCore, BrownFullBasicInit, DefaultInit, # AD Backends using ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, ZeroTangent, AbstractThunk, AbstractTangent -import DifferentiationInterface as DI using Enzyme: Enzyme using FiniteDiff: FiniteDiff using ForwardDiff: ForwardDiff diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 389755011..f68331871 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -2687,10 +2687,11 @@ function SciMLBase._concrete_solve_adjoint( opt_f = _prob.f if opt_f.grad === nothing - grad_fn = if sensealg.objective_ad isa Bool && !sensealg.objective_ad - (G, u, p) -> FiniteDiff.finite_difference_gradient!(G, Base.Fix2(opt_f, p), u) - else + grad_fn = if alg_autodiff(sensealg) (G, u, p) -> ForwardDiff.gradient!(G, Base.Fix2(opt_f, p), u) + else + (G, u, p) -> FiniteDiff.finite_difference_gradient!( + G, Base.Fix2(opt_f, p), u, diff_type(sensealg)) end nlprob = NonlinearProblem(grad_fn, opt_sol.u, p) else diff --git a/src/derivative_wrappers.jl b/src/derivative_wrappers.jl index d215dd7dc..f9217e319 100644 --- a/src/derivative_wrappers.jl +++ b/src/derivative_wrappers.jl @@ -157,6 +157,17 @@ function gradient( end end +function hessian( + f, x::AbstractArray{<:Number}, + alg::AbstractOverloadingSensitivityAlgorithm + ) + if alg_autodiff(alg) + ForwardDiff.hessian(unwrapped_f(f), x) + else + FiniteDiff.finite_difference_hessian(f, x) + end +end + function jacobian!( J::Nothing, f, x::AbstractArray{<:Number}, fx::Union{Nothing, AbstractArray{<:Number}}, diff --git a/src/optimization_adjoint.jl b/src/optimization_adjoint.jl index 435b901bb..2a1a93de2 100644 --- a/src/optimization_adjoint.jl +++ b/src/optimization_adjoint.jl @@ -24,6 +24,21 @@ end # Override inplace_sensitivity: f is always OOP for optimization inplace_sensitivity(::OptimizationAdjointSensitivityFunction) = false +# Override getprob: OptimizationSolution has no `prob` field; its cache plays the role. +getprob(S::OptimizationAdjointSensitivityFunction) = S.sol.cache + +# Wrapper for the KKT residual closure. Named struct so we can declare the SciMLFunction +# traits that `_vecjacobian!` queries (otherwise a raw closure errors with MethodError). +struct OptimizationKKTResidual{F} + f::F +end +(o::OptimizationKKTResidual)(args...) = o.f(args...) +SciMLBase.has_paramjac(::OptimizationKKTResidual) = false +SciMLBase.has_jac(::OptimizationKKTResidual) = false +SciMLBase.has_vjp(::OptimizationKKTResidual) = false +SciMLBase.has_vjp_p(::OptimizationKKTResidual) = false +SciMLBase.unwrapped_f(o::OptimizationKKTResidual) = o.f + # Evaluate OptimizationFunction auxiliary fields (grad, hess, cons_j, lag_h). # Dispatched on: # Val{iip} — from OptimizationFunction{iip}: true = in-place (leading buffer), false = oop @@ -58,10 +73,10 @@ _opt_eval_lag_h(fn, _, x, σ, μ, _, ::Val{false}, ::Val{false}) = fn(x, σ, μ) function OptimizationAdjointSensitivityFunction( prob, opt_sol, - sensealg::OptimizationAdjoint{CS, AD, FDT, VJP, LS, LK, OAD, AT}, + sensealg::OptimizationAdjoint{CS, AD, FDT, VJP, LS, LK, AT}, p, Δu - ) where {CS, AD, FDT, VJP, LS, LK, OAD, AT} + ) where {CS, AD, FDT, VJP, LS, LK, AT} x_star = opt_sol.u lcons = prob.lcons @@ -136,13 +151,11 @@ function OptimizationAdjointSensitivityFunction( iip_val = Val{SciMLBase.isinplace(opt_f)}() has_p_val = Val{prob isa SciMLBase.AbstractOptimizationProblem}() - objective_ad = sensealg.objective_ad - # ---- ∇f at x_star: use stored gradient if available ---- ∇f = if opt_f.grad !== nothing _opt_eval_vec(opt_f.grad, n_x, x_star, p, iip_val, has_p_val) else - DI.gradient(x -> prob.f(x, p), objective_ad, x_star) + gradient(x -> prob.f(x, p), x_star, sensealg) end # Precompute full constraint Jacobian once if cons_j is available (reused across passes). @@ -155,7 +168,7 @@ function OptimizationAdjointSensitivityFunction( elseif J_full !== nothing J_full[eq_idx, :] else - DI.jacobian(x -> g(x, p), objective_ad, x_star) + jacobian(x -> g(x, p), x_star, sensealg) end # Active ineq lower bound: h_lb(x,p) = lcons[i] - cons(x,p)[i] (= 0 when active) @@ -181,7 +194,7 @@ function OptimizationAdjointSensitivityFunction( isempty(active_ub) ? zeros(eltype(x_star), 0, n_x) : J_full[active_ub, :] ) else - DI.jacobian(x -> h_I(x, p), objective_ad, x_star) + jacobian(x -> h_I(x, p), x_star, sensealg) end Jx_bound = zeros(eltype(x_star), n_bound, n_x) @@ -234,7 +247,7 @@ function OptimizationAdjointSensitivityFunction( isempty(active_ub) ? zeros(eltype(x_star), 0, n_x) : J_full[active_ub, :] ) else - DI.jacobian(x -> h_I(x, p), objective_ad, x_star) + jacobian(x -> h_I(x, p), x_star, sensealg) end Jx_bound = zeros(eltype(x_star), n_bound, n_x) @@ -277,7 +290,7 @@ function OptimizationAdjointSensitivityFunction( elseif !has_cons && opt_f.hess !== nothing _opt_eval_mat(opt_f.hess, n_x, n_x, x_star, p, iip_val, has_p_val) else - DI.hessian(x -> L(x, p), objective_ad, x_star) + hessian(x -> L(x, p), x_star, sensealg) end N = n_x + n_eq + n_act_total @@ -314,16 +327,16 @@ function OptimizationAdjointSensitivityFunction( # size their buffers, so buffers will be M-sized and match the output of f_F. M = n_x + n_eq + n_act y = zeros(eltype(x_star), M) - f_F = let L = L, g = g, h_I = h_I, x_star = x_star, - objective_ad = objective_ad, n_eq = n_eq, n_act = n_act + f_F = OptimizationKKTResidual(let L = L, g = g, h_I = h_I, x_star = x_star, + sensealg = sensealg, n_eq = n_eq, n_act = n_act function(_, q_full, _) - grad_L = DI.gradient(x -> L(x, q_full), objective_ad, x_star) + grad_L = gradient(x -> L(x, q_full), x_star, sensealg) n_eq == 0 && n_act == 0 && return grad_L n_eq > 0 && n_act == 0 && return vcat(grad_L, g(x_star, q_full)) n_eq == 0 && n_act > 0 && return vcat(grad_L, h_I(x_star, q_full)) vcat(grad_L, g(x_star, q_full), h_I(x_star, q_full)) end - end + end) # λ: adjoint cotangent for f_F — drop the variable-bound rows of λ_full (∂/∂p = 0). λ = λ_full[1:M] diff --git a/src/sensitivity_algorithms.jl b/src/sensitivity_algorithms.jl index 31667d0b0..60a0697ec 100644 --- a/src/sensitivity_algorithms.jl +++ b/src/sensitivity_algorithms.jl @@ -1337,7 +1337,7 @@ end """ ```julia -UnconstrainedOptimizationAdjoint{CS, AD, FDT, VJP, LS, LK, OAD} <: AbstractAdjointSensitivityAlgorithm{CS, AD, FDT} +UnconstrainedOptimizationAdjoint{CS, AD, FDT, VJP, LS, LK} <: AbstractAdjointSensitivityAlgorithm{CS, AD, FDT} ``` An implementation of adjoint differentiation for unconstrained optimization problems. @@ -1352,22 +1352,18 @@ steady-state adjoint method. ```julia UnconstrainedOptimizationAdjoint(; chunk_size = 0, autodiff = true, - diff_type = Val{:central}, objective_ad = true, autojacvec = nothing, linsolve = nothing, linsolve_kwargs = (;)) ``` ## Keyword Arguments - - `autodiff`: Use automatic differentiation for constructing the Jacobian - if the Jacobian needs to be constructed. Defaults to `true`. + - `autodiff`: Use automatic differentiation (ForwardDiff) for the objective gradient + and Jacobians when needed. If `false`, FiniteDiff is used with `diff_type=Val{:central}`. + Defaults to `true`. - `chunk_size`: Chunk size for forward-mode differentiation if full Jacobians are built (`autojacvec=false` and `autodiff=true`). Default is `0` for automatic choice of chunk size. - - `diff_type`: The method used by FiniteDiff.jl for constructing the Jacobian - if the full Jacobian is required with `autodiff=false`. - - `objective_ad`: Use automatic differentiation for computing the gradient of the - objective function when not provided. Defaults to `true`. - `autojacvec`: Calculate the vector-Jacobian product (`J'*v`) via automatic differentiation with special seeding. The total set of choices are: @@ -1394,63 +1390,58 @@ documentation page or the docstrings of the vjp types. Johnson, S. G., Notes on Adjoint Methods for 18.336, Online at http://math.mit.edu/stevenj/18.336/adjoint.pdf (2007) """ -struct UnconstrainedOptimizationAdjoint{CS, AD, FDT, VJP, LS, LK, OAD} <: +struct UnconstrainedOptimizationAdjoint{CS, AD, FDT, VJP, LS, LK} <: AbstractAdjointSensitivityAlgorithm{CS, AD, FDT} autojacvec::VJP linsolve::LS linsolve_kwargs::LK - objective_ad::OAD end function UnconstrainedOptimizationAdjoint(; chunk_size = 0, autodiff = true, - objective_ad = AutoForwardDiff(), autojacvec = nothing, linsolve = nothing, linsolve_kwargs = (;) ) return UnconstrainedOptimizationAdjoint{ chunk_size, autodiff, Val{:central}, typeof(autojacvec), - typeof(linsolve), typeof(linsolve_kwargs), typeof(objective_ad), - }(autojacvec, linsolve, linsolve_kwargs, objective_ad) + typeof(linsolve), typeof(linsolve_kwargs), + }(autojacvec, linsolve, linsolve_kwargs) end function setvjp( - sensealg::UnconstrainedOptimizationAdjoint{CS, AD, FDT, VJP, LS, LK, OAD}, + sensealg::UnconstrainedOptimizationAdjoint{CS, AD, FDT, VJP, LS, LK}, vjp - ) where {CS, AD, FDT, VJP, LS, LK, OAD} - return UnconstrainedOptimizationAdjoint{CS, AD, FDT, typeof(vjp), LS, LK, OAD}( - vjp, sensealg.linsolve, - sensealg.linsolve_kwargs, sensealg.objective_ad + ) where {CS, AD, FDT, VJP, LS, LK} + return UnconstrainedOptimizationAdjoint{CS, AD, FDT, typeof(vjp), LS, LK}( + vjp, sensealg.linsolve, sensealg.linsolve_kwargs ) end -struct OptimizationAdjoint{CS, AD, FDT, VJP, LS, LK, OAD, AT} <: +struct OptimizationAdjoint{CS, AD, FDT, VJP, LS, LK, AT} <: AbstractAdjointSensitivityAlgorithm{CS, AD, FDT} autojacvec::VJP linsolve::LS linsolve_kwargs::LK - objective_ad::OAD active_tol::AT # tolerance for active inequality constraint detection; nothing = sqrt(eps(eltype(x*))) end function OptimizationAdjoint(; chunk_size = 0, autodiff = true, - objective_ad = AutoForwardDiff(), autojacvec = nothing, linsolve = nothing, linsolve_kwargs = (;), active_tol = nothing ) return OptimizationAdjoint{ chunk_size, autodiff, Val{:central}, typeof(autojacvec), - typeof(linsolve), typeof(linsolve_kwargs), typeof(objective_ad), typeof(active_tol), - }(autojacvec, linsolve, linsolve_kwargs, objective_ad, active_tol) + typeof(linsolve), typeof(linsolve_kwargs), typeof(active_tol), + }(autojacvec, linsolve, linsolve_kwargs, active_tol) end function setvjp( - sensealg::OptimizationAdjoint{CS, AD, FDT, VJP, LS, LK, OAD, AT}, + sensealg::OptimizationAdjoint{CS, AD, FDT, VJP, LS, LK, AT}, vjp - ) where {CS, AD, FDT, VJP, LS, LK, OAD, AT} - return OptimizationAdjoint{CS, AD, FDT, typeof(vjp), LS, LK, OAD, AT}( - vjp, sensealg.linsolve, sensealg.linsolve_kwargs, sensealg.objective_ad, + ) where {CS, AD, FDT, VJP, LS, LK, AT} + return OptimizationAdjoint{CS, AD, FDT, typeof(vjp), LS, LK, AT}( + vjp, sensealg.linsolve, sensealg.linsolve_kwargs, sensealg.active_tol ) end diff --git a/src/sensitivity_interface.jl b/src/sensitivity_interface.jl index 708b5e527..75712f504 100644 --- a/src/sensitivity_interface.jl +++ b/src/sensitivity_interface.jl @@ -429,7 +429,8 @@ function adjoint_sensitivities( sensealg::OptimizationAdjoint, verbose = true, kwargs... ) - return _adjoint_sensitivities(sol, sensealg, alg, args...; verbose, kwargs...) + _sensealg = sensealg.autojacvec === nothing ? setvjp(sensealg, true) : sensealg + return _adjoint_sensitivities(sol, _sensealg, alg, args...; verbose, kwargs...) end function _adjoint_sensitivities( @@ -441,7 +442,7 @@ function _adjoint_sensitivities( p = SymbolicIndexingInterface.parameter_values(sol) Δu = zero(sol.u) dgdu(Δu, sol.u, p, nothing, nothing) - return OptimizationAdjointProblem(sol.prob, sol, sensealg, p, Δu) + return OptimizationAdjointProblem(sol.cache, sol, sensealg, p, Δu) end function _adjoint_sensitivities( From db67acd50e03858187b03217f27abe2d8902336e Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 18 May 2026 12:22:09 -0400 Subject: [PATCH 13/19] format --- ext/SciMLSensitivityMooncakeExt.jl | 2 +- src/concrete_solve.jl | 10 +- src/derivative_wrappers.jl | 4 +- src/optimization_adjoint.jl | 140 +++++++++++++++---------- src/sensitivity_algorithms.jl | 2 +- src/sensitivity_interface.jl | 2 +- test/optimization_adjoint.jl | 158 +++++++++++++++-------------- 7 files changed, 178 insertions(+), 140 deletions(-) diff --git a/ext/SciMLSensitivityMooncakeExt.jl b/ext/SciMLSensitivityMooncakeExt.jl index 0d8c47555..a5de1b689 100644 --- a/ext/SciMLSensitivityMooncakeExt.jl +++ b/ext/SciMLSensitivityMooncakeExt.jl @@ -3,7 +3,7 @@ module SciMLSensitivityMooncakeExt using SciMLSensitivity: SciMLSensitivity, FakeIntegrator using Mooncake: Mooncake import SciMLSensitivity: get_paramjac_config, get_cb_paramjac_config, mooncake_run_ad, - + MooncakeVJP, MooncakeLoaded, DiffEqBase, MooncakeAdjoint, _init_originator_gradient using SciMLSensitivity: SciMLBase, SciMLStructures, canonicalize, Tunable, isscimlstructure, diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index f68331871..87bafc2d5 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -2691,7 +2691,8 @@ function SciMLBase._concrete_solve_adjoint( (G, u, p) -> ForwardDiff.gradient!(G, Base.Fix2(opt_f, p), u) else (G, u, p) -> FiniteDiff.finite_difference_gradient!( - G, Base.Fix2(opt_f, p), u, diff_type(sensealg)) + G, Base.Fix2(opt_f, p), u, diff_type(sensealg) + ) end nlprob = NonlinearProblem(grad_fn, opt_sol.u, p) else @@ -2841,7 +2842,7 @@ function SciMLBase._concrete_solve_adjoint( function optimizationbackpass(Δ) Δ = Δ isa AbstractThunk ? unthunk(Δ) : Δ function df(_out, _u, _p, _t, _i) - if _save_idxs isa Number + return if _save_idxs isa Number _out[_save_idxs] = Δ isa AbstractArray ? Δ[_save_idxs] : Δ.u[_save_idxs] elseif Δ isa Number @. _out[_save_idxs] = Δ @@ -2886,10 +2887,11 @@ function SciMLBase._concrete_solve_adjoint( end dp = Zygote.accum( - dp, (isnothing(Δtunables) || isempty(Δtunables)) ? nothing : Δtunables) + dp, (isnothing(Δtunables) || isempty(Δtunables)) ? nothing : Δtunables + ) return if originator isa SciMLBase.TrackerOriginator || - originator isa SciMLBase.ReverseDiffOriginator + originator isa SciMLBase.ReverseDiffOriginator ( NoTangent(), NoTangent(), NoTangent(), repack_adjoint(dp)[1], NoTangent(), ntuple(_ -> NoTangent(), length(args))..., diff --git a/src/derivative_wrappers.jl b/src/derivative_wrappers.jl index f9217e319..ddba83eda 100644 --- a/src/derivative_wrappers.jl +++ b/src/derivative_wrappers.jl @@ -150,7 +150,7 @@ function gradient( f, x::AbstractArray{<:Number}, alg::AbstractOverloadingSensitivityAlgorithm ) - if alg_autodiff(alg) + return if alg_autodiff(alg) ForwardDiff.gradient(unwrapped_f(f), x) else FiniteDiff.finite_difference_gradient(f, x, diff_type(alg)) @@ -161,7 +161,7 @@ function hessian( f, x::AbstractArray{<:Number}, alg::AbstractOverloadingSensitivityAlgorithm ) - if alg_autodiff(alg) + return if alg_autodiff(alg) ForwardDiff.hessian(unwrapped_f(f), x) else FiniteDiff.finite_difference_hessian(f, x) diff --git a/src/optimization_adjoint.jl b/src/optimization_adjoint.jl index 2a1a93de2..73d55e2ca 100644 --- a/src/optimization_adjoint.jl +++ b/src/optimization_adjoint.jl @@ -44,30 +44,36 @@ SciMLBase.unwrapped_f(o::OptimizationKKTResidual) = o.f # Val{iip} — from OptimizationFunction{iip}: true = in-place (leading buffer), false = oop # Val{has_p} — true = AbstractOptimizationProblem (p explicit), false = OptimizationCache (p baked in) function _opt_eval_vec(fn, n, x, p, ::Val{true}, ::Val{true}) - out = zeros(eltype(x), n); fn(out, x, p); out + out = zeros(eltype(x), n); fn(out, x, p) + return out end function _opt_eval_vec(fn, n, x, _, ::Val{true}, ::Val{false}) - out = zeros(eltype(x), n); fn(out, x); out + out = zeros(eltype(x), n); fn(out, x) + return out end -_opt_eval_vec(fn, _, x, p, ::Val{false}, ::Val{true}) = fn(x, p) +_opt_eval_vec(fn, _, x, p, ::Val{false}, ::Val{true}) = fn(x, p) _opt_eval_vec(fn, _, x, _, ::Val{false}, ::Val{false}) = fn(x) function _opt_eval_mat(fn, m, n, x, p, ::Val{true}, ::Val{true}) - out = zeros(eltype(x), m, n); fn(out, x, p); out + out = zeros(eltype(x), m, n); fn(out, x, p) + return out end function _opt_eval_mat(fn, m, n, x, _, ::Val{true}, ::Val{false}) - out = zeros(eltype(x), m, n); fn(out, x); out + out = zeros(eltype(x), m, n); fn(out, x) + return out end -_opt_eval_mat(fn, _, _, x, p, ::Val{false}, ::Val{true}) = fn(x, p) +_opt_eval_mat(fn, _, _, x, p, ::Val{false}, ::Val{true}) = fn(x, p) _opt_eval_mat(fn, _, _, x, _, ::Val{false}, ::Val{false}) = fn(x) function _opt_eval_lag_h(fn, n, x, σ, μ, p, ::Val{true}, ::Val{true}) - H = zeros(eltype(x), n, n); fn(H, x, σ, μ, p); H + H = zeros(eltype(x), n, n); fn(H, x, σ, μ, p) + return H end function _opt_eval_lag_h(fn, n, x, σ, μ, _, ::Val{true}, ::Val{false}) - H = zeros(eltype(x), n, n); fn(H, x, σ, μ); H + H = zeros(eltype(x), n, n); fn(H, x, σ, μ) + return H end -_opt_eval_lag_h(fn, _, x, σ, μ, p, ::Val{false}, ::Val{true}) = fn(x, σ, μ, p) +_opt_eval_lag_h(fn, _, x, σ, μ, p, ::Val{false}, ::Val{true}) = fn(x, σ, μ, p) _opt_eval_lag_h(fn, _, x, σ, μ, _, ::Val{false}, ::Val{false}) = fn(x, σ, μ) function OptimizationAdjointSensitivityFunction( @@ -98,8 +104,12 @@ function OptimizationAdjointSensitivityFunction( prob.f.cons # AbstractOptimizationProblem: already (res, x, p) else captured_f = let cl = prob.f.cons - getfield(cl, only(fname for fname in fieldnames(typeof(cl)) - if getfield(cl, fname) isa SciMLBase.AbstractOptimizationFunction)) + getfield( + cl, only( + fname for fname in fieldnames(typeof(cl)) + if getfield(cl, fname) isa SciMLBase.AbstractOptimizationFunction + ) + ) end captured_f.cons end @@ -124,7 +134,7 @@ function OptimizationAdjointSensitivityFunction( end # Classify constraints: equality where lcons[i] == ucons[i] - eq_idx = has_cons ? findall(i -> lcons[i] == ucons[i], eachindex(lcons)) : Int[] + eq_idx = has_cons ? findall(i -> lcons[i] == ucons[i], eachindex(lcons)) : Int[] ineq_idx = has_cons ? findall(i -> lcons[i] != ucons[i], eachindex(lcons)) : Int[] # Evaluate constraints at solution @@ -140,15 +150,15 @@ function OptimizationAdjointSensitivityFunction( g(x, q) = eval_cons(x, q)[eq_idx] .- lcons[eq_idx] n_eq = length(eq_idx) - n_x = length(x_star) + n_x = length(x_star) lb = prob.lb ub = prob.ub active_lb_var = lb !== nothing ? findall(i -> abs(x_star[i] - lb[i]) <= atol, 1:n_x) : Int[] active_ub_var = ub !== nothing ? findall(i -> abs(x_star[i] - ub[i]) <= atol, 1:n_x) : Int[] - opt_f = prob.f - iip_val = Val{SciMLBase.isinplace(opt_f)}() + opt_f = prob.f + iip_val = Val{SciMLBase.isinplace(opt_f)}() has_p_val = Val{prob isa SciMLBase.AbstractOptimizationProblem}() # ---- ∇f at x_star: use stored gradient if available ---- @@ -175,7 +185,7 @@ function OptimizationAdjointSensitivityFunction( # Active ineq upper bound: h_ub(x,p) = cons(x,p)[i] - ucons[i] (= 0 when active) # Variable bound active lower: h_lb_var = lb[i] - x[i] (∂/∂x = -eᵢ, ∂/∂p = 0) # Variable bound active upper: h_ub_var = x[i] - ub[i] (∂/∂x = +eᵢ, ∂/∂p = 0) - n_act = length(active_lb) + length(active_ub) + n_act = length(active_lb) + length(active_ub) n_bound = length(active_lb_var) + length(active_ub_var) h_I = (x, q) -> begin @@ -191,14 +201,16 @@ function OptimizationAdjointSensitivityFunction( elseif J_full !== nothing vcat( isempty(active_lb) ? zeros(eltype(x_star), 0, n_x) : -J_full[active_lb, :], - isempty(active_ub) ? zeros(eltype(x_star), 0, n_x) : J_full[active_ub, :] + isempty(active_ub) ? zeros(eltype(x_star), 0, n_x) : J_full[active_ub, :] ) else jacobian(x -> h_I(x, p), x_star, sensealg) end Jx_bound = zeros(eltype(x_star), n_bound, n_x) - for (j, i) in enumerate(active_lb_var); Jx_bound[j, i] = -one(eltype(x_star)); end + for (j, i) in enumerate(active_lb_var) + Jx_bound[j, i] = -one(eltype(x_star)) + end for (j, i) in enumerate(active_ub_var) Jx_bound[length(active_lb_var) + j, i] = one(eltype(x_star)) end @@ -211,24 +223,28 @@ function OptimizationAdjointSensitivityFunction( else solve(LinearProblem(Matrix(vcat(Jxg, Jxhι)'), -∇f), LinearSolve.QRFactorization()).u end - y_star = n_eq > 0 ? dual_vars[1:n_eq] : eltype(x_star)[] - zI_star = n_act > 0 ? dual_vars[(n_eq+1):(n_eq+n_act)] : eltype(x_star)[] - z_bound_star = n_bound > 0 ? dual_vars[(n_eq+n_act+1):end] : eltype(x_star)[] + y_star = n_eq > 0 ? dual_vars[1:n_eq] : eltype(x_star)[] + zI_star = n_act > 0 ? dual_vars[(n_eq + 1):(n_eq + n_act)] : eltype(x_star)[] + z_bound_star = n_bound > 0 ? dual_vars[(n_eq + n_act + 1):end] : eltype(x_star)[] # Multiplier sign check: KKT requires all inequality multipliers ≥ 0 at a minimum. # Negative multipliers indicate spuriously-included constraints (close to bound but inactive). # Drop those and redo only the Jxhι build and dual solve — no extra cost if all signs are good. mtol = sqrt(eps(eltype(x_star))) if (n_act > 0 && any(<(-mtol), zI_star)) || - (n_bound > 0 && any(<(-mtol), z_bound_star)) - n_lb = length(active_lb) + (n_bound > 0 && any(<(-mtol), z_bound_star)) + n_lb = length(active_lb) n_lb_var = length(active_lb_var) - active_lb = active_lb[findall(j -> zI_star[j] >= -mtol, 1:n_lb)] - active_ub = active_ub[findall(j -> zI_star[n_lb+j] >= -mtol, 1:length(active_ub))] - active_lb_var = active_lb_var[findall(j -> z_bound_star[j] >= -mtol, 1:n_lb_var)] - active_ub_var = active_ub_var[findall(j -> z_bound_star[n_lb_var+j] >= -mtol, - 1:length(active_ub_var))] - n_act = length(active_lb) + length(active_ub) + active_lb = active_lb[findall(j -> zI_star[j] >= -mtol, 1:n_lb)] + active_ub = active_ub[findall(j -> zI_star[n_lb + j] >= -mtol, 1:length(active_ub))] + active_lb_var = active_lb_var[findall(j -> z_bound_star[j] >= -mtol, 1:n_lb_var)] + active_ub_var = active_ub_var[ + findall( + j -> z_bound_star[n_lb_var + j] >= -mtol, + 1:length(active_ub_var) + ), + ] + n_act = length(active_lb) + length(active_ub) n_bound = length(active_lb_var) + length(active_ub_var) h_I = (x, q) -> begin @@ -244,14 +260,16 @@ function OptimizationAdjointSensitivityFunction( elseif J_full !== nothing vcat( isempty(active_lb) ? zeros(eltype(x_star), 0, n_x) : -J_full[active_lb, :], - isempty(active_ub) ? zeros(eltype(x_star), 0, n_x) : J_full[active_ub, :] + isempty(active_ub) ? zeros(eltype(x_star), 0, n_x) : J_full[active_ub, :] ) else jacobian(x -> h_I(x, p), x_star, sensealg) end Jx_bound = zeros(eltype(x_star), n_bound, n_x) - for (j, i) in enumerate(active_lb_var); Jx_bound[j, i] = -one(eltype(x_star)); end + for (j, i) in enumerate(active_lb_var) + Jx_bound[j, i] = -one(eltype(x_star)) + end for (j, i) in enumerate(active_ub_var) Jx_bound[length(active_lb_var) + j, i] = one(eltype(x_star)) end @@ -263,14 +281,14 @@ function OptimizationAdjointSensitivityFunction( else solve(LinearProblem(Matrix(vcat(Jxg, Jxhι)'), -∇f), LinearSolve.QRFactorization()).u end - y_star = n_eq > 0 ? dual_vars[1:n_eq] : eltype(x_star)[] - zI_star = n_act > 0 ? dual_vars[(n_eq+1):(n_eq+n_act)] : eltype(x_star)[] + y_star = n_eq > 0 ? dual_vars[1:n_eq] : eltype(x_star)[] + zI_star = n_act > 0 ? dual_vars[(n_eq + 1):(n_eq + n_act)] : eltype(x_star)[] end # Lagrangian with fixed multipliers (used for p-derivative computations below) - L = function(x, q) + L = function (x, q) val = prob.f(x, q) - n_eq > 0 && (val += dot(y_star, g(x, q))) + n_eq > 0 && (val += dot(y_star, g(x, q))) n_act > 0 && (val += dot(zI_star, h_I(x, q))) return val end @@ -283,9 +301,15 @@ function OptimizationAdjointSensitivityFunction( # μ[active_ub[j]] = zI_star[n_lb + j] (h_ub = cons - ucons → +cons contribution) Lxx = if opt_f.lag_h !== nothing mu_full = zeros(eltype(x_star), n_cons) - for (j, i) in enumerate(eq_idx); mu_full[i] = y_star[j] end - for (j, i) in enumerate(active_lb); mu_full[i] -= zI_star[j] end - for (j, i) in enumerate(active_ub); mu_full[i] += zI_star[length(active_lb) + j] end + for (j, i) in enumerate(eq_idx) + mu_full[i] = y_star[j] + end + for (j, i) in enumerate(active_lb) + mu_full[i] -= zI_star[j] + end + for (j, i) in enumerate(active_ub) + mu_full[i] += zI_star[length(active_lb) + j] + end _opt_eval_lag_h(opt_f.lag_h, n_x, x_star, one(eltype(x_star)), mu_full, p, iip_val, has_p_val) elseif !has_cons && opt_f.hess !== nothing _opt_eval_mat(opt_f.hess, n_x, n_x, x_star, p, iip_val, has_p_val) @@ -297,18 +321,20 @@ function OptimizationAdjointSensitivityFunction( KKT = zeros(eltype(x_star), N, N) KKT[1:n_x, 1:n_x] = Lxx if n_eq > 0 - KKT[1:n_x, (n_x + 1):(n_x + n_eq)] = Jxg' - KKT[(n_x + 1):(n_x + n_eq), 1:n_x] = Jxg + KKT[1:n_x, (n_x + 1):(n_x + n_eq)] = Jxg' + KKT[(n_x + 1):(n_x + n_eq), 1:n_x] = Jxg end if n_act_total > 0 - KKT[1:n_x, (n_x + n_eq + 1):N] = Jxhι' - KKT[(n_x + n_eq + 1):N, 1:n_x] = Jxhι + KKT[1:n_x, (n_x + n_eq + 1):N] = Jxhι' + KKT[(n_x + n_eq + 1):N, 1:n_x] = Jxhι end # KKT is symmetric, so KKT' = KKT. Solve KKT * λ_full = [Δu; 0; ...; 0] once for all parameters. rhs_adj = vcat(Δu, zeros(eltype(x_star), n_eq + n_act_total)) - λ_full = solve(LinearProblem(KKT, rhs_adj), sensealg.linsolve; - sensealg.linsolve_kwargs...).u + λ_full = solve( + LinearProblem(KKT, rhs_adj), sensealg.linsolve; + sensealg.linsolve_kwargs... + ).u if p === nothing || p isa SciMLBase.NullParameters tunables, repack = p, identity @@ -327,16 +353,18 @@ function OptimizationAdjointSensitivityFunction( # size their buffers, so buffers will be M-sized and match the output of f_F. M = n_x + n_eq + n_act y = zeros(eltype(x_star), M) - f_F = OptimizationKKTResidual(let L = L, g = g, h_I = h_I, x_star = x_star, - sensealg = sensealg, n_eq = n_eq, n_act = n_act - function(_, q_full, _) - grad_L = gradient(x -> L(x, q_full), x_star, sensealg) - n_eq == 0 && n_act == 0 && return grad_L - n_eq > 0 && n_act == 0 && return vcat(grad_L, g(x_star, q_full)) - n_eq == 0 && n_act > 0 && return vcat(grad_L, h_I(x_star, q_full)) - vcat(grad_L, g(x_star, q_full), h_I(x_star, q_full)) + f_F = OptimizationKKTResidual( + let L = L, g = g, h_I = h_I, x_star = x_star, + sensealg = sensealg, n_eq = n_eq, n_act = n_act + function (_, q_full, _) + grad_L = gradient(x -> L(x, q_full), x_star, sensealg) + n_eq == 0 && n_act == 0 && return grad_L + n_eq > 0 && n_act == 0 && return vcat(grad_L, g(x_star, q_full)) + n_eq == 0 && n_act > 0 && return vcat(grad_L, h_I(x_star, q_full)) + return vcat(grad_L, g(x_star, q_full), h_I(x_star, q_full)) + end end - end) + ) # λ: adjoint cotangent for f_F — drop the variable-bound rows of λ_full (∂/∂p = 0). λ = λ_full[1:M] @@ -358,8 +386,10 @@ function OptimizationAdjointSensitivityFunction( _pf = f_F # OOP: Enzyme.make_zero(f) called inline in _vecjacobian! _needs_shadow = _needs_repack _shadow_p = _needs_shadow ? repack(zero(tunables)) : nothing - _config = get_paramjac_config(autojacvec, p, f_F, y, tunables, nothing; - numindvar = M, alg = nothing) + _config = get_paramjac_config( + autojacvec, p, f_F, y, tunables, nothing; + numindvar = M, alg = nothing + ) _config = (_config..., Enzyme.make_zero(_pf), _shadow_p) _pf, _config, nothing elseif autojacvec isa MooncakeVJP diff --git a/src/sensitivity_algorithms.jl b/src/sensitivity_algorithms.jl index 60a0697ec..b34fae923 100644 --- a/src/sensitivity_algorithms.jl +++ b/src/sensitivity_algorithms.jl @@ -1418,7 +1418,7 @@ function setvjp( end struct OptimizationAdjoint{CS, AD, FDT, VJP, LS, LK, AT} <: - AbstractAdjointSensitivityAlgorithm{CS, AD, FDT} + AbstractAdjointSensitivityAlgorithm{CS, AD, FDT} autojacvec::VJP linsolve::LS linsolve_kwargs::LK diff --git a/src/sensitivity_interface.jl b/src/sensitivity_interface.jl index 75712f504..7f4b20f33 100644 --- a/src/sensitivity_interface.jl +++ b/src/sensitivity_interface.jl @@ -439,7 +439,7 @@ function _adjoint_sensitivities( ) dgdu === nothing && error("dgdu must be specified for OptimizationAdjoint") - p = SymbolicIndexingInterface.parameter_values(sol) + p = SymbolicIndexingInterface.parameter_values(sol) Δu = zero(sol.u) dgdu(Δu, sol.u, p, nothing, nothing) return OptimizationAdjointProblem(sol.cache, sol, sensealg, p, Δu) diff --git a/test/optimization_adjoint.jl b/test/optimization_adjoint.jl index 5eaedb011..2dbe08841 100644 --- a/test/optimization_adjoint.jl +++ b/test/optimization_adjoint.jl @@ -193,26 +193,26 @@ end # Minimize (u1-1)^2 + (u2-1)^2 s.t. u1 + u2 = p[1] # Optimal solution: u1* = u2* = p[1]/2 # du1*/dp[1] = 0.5, du2*/dp[1] = 0.5 - f = (u, p) -> (u[1] - 1)^2 + (u[2] - 1)^2 + f = (u, p) -> (u[1] - 1)^2 + (u[2] - 1)^2 cons = (res, u, p) -> (res[1] = u[1] + u[2] - p[1]) u0 = [1.5, 1.5] # feasible: u1+u2 = p[1] = 3 - p = [3.0] + p = [3.0] opt_f = OptimizationFunction(f, Optimization.AutoForwardDiff(); cons = cons) - prob = OptimizationProblem(opt_f, u0, p; lcons = [0.0], ucons = [0.0]) + prob = OptimizationProblem(opt_f, u0, p; lcons = [0.0], ucons = [0.0]) opt_sol = solve(prob, NLopt.LD_SLSQP()) - @test opt_sol.u[1] ≈ p[1] / 2 rtol = 1e-4 - @test opt_sol.u[2] ≈ p[1] / 2 rtol = 1e-4 - @test opt_sol.u[1] + opt_sol.u[2] ≈ p[1] rtol = 1e-6 # constraint satisfied + @test opt_sol.u[1] ≈ p[1] / 2 rtol = 1.0e-4 + @test opt_sol.u[2] ≈ p[1] / 2 rtol = 1.0e-4 + @test opt_sol.u[1] + opt_sol.u[2] ≈ p[1] rtol = 1.0e-6 # constraint satisfied dgdu1!(out, _, _, _, _) = (out[1] = 1.0; out[2] = 0.0) dgdu2!(out, _, _, _, _) = (out[1] = 0.0; out[2] = 1.0) dp1 = adjoint_sensitivities(opt_sol, nothing; sensealg = OptimizationAdjoint(), dgdu = dgdu1!) dp2 = adjoint_sensitivities(opt_sol, nothing; sensealg = OptimizationAdjoint(), dgdu = dgdu2!) - @test dp1[1] ≈ 0.5 rtol = 1e-4 # du1*/dp[1] - @test dp2[1] ≈ 0.5 rtol = 1e-4 # du2*/dp[1] + @test dp1[1] ≈ 0.5 rtol = 1.0e-4 # du1*/dp[1] + @test dp2[1] ≈ 0.5 rtol = 1.0e-4 # du2*/dp[1] end end @@ -221,47 +221,51 @@ end # Minimize (u - p[1])^2 s.t. u <= p[2] where p[2] < p[1] (constraint active) # Optimal solution: u* = p[2] # du*/dp[1] = 0, du*/dp[2] = 1 - f = (u, p) -> (u[1] - p[1])^2 + f = (u, p) -> (u[1] - p[1])^2 cons = (res, u, p) -> (res[1] = u[1] - p[2]) u0 = [0.0] - p = [3.0, 1.0] # unconstrained min at u=3, constraint forces u<=1 + p = [3.0, 1.0] # unconstrained min at u=3, constraint forces u<=1 opt_f = OptimizationFunction(f, Optimization.AutoForwardDiff(); cons = cons) - prob = OptimizationProblem(opt_f, u0, p; lcons = [-Inf], ucons = [0.0]) + prob = OptimizationProblem(opt_f, u0, p; lcons = [-Inf], ucons = [0.0]) opt_sol = solve(prob, NLopt.LD_SLSQP()) - @test opt_sol.u[1] ≈ p[2] rtol = 1e-4 - @test opt_sol.u[1] <= p[2] + 1e-6 # constraint satisfied: u <= p[2] + @test opt_sol.u[1] ≈ p[2] rtol = 1.0e-4 + @test opt_sol.u[1] <= p[2] + 1.0e-6 # constraint satisfied: u <= p[2] dgdu!(out, _, _, _, _) = (out[1] = 1.0) dp = adjoint_sensitivities(opt_sol, nothing; sensealg = OptimizationAdjoint(), dgdu = dgdu!) - @test dp[1] ≈ 0.0 atol = 1e-4 # du*/dp[1] = 0 - @test dp[2] ≈ 1.0 rtol = 1e-4 # du*/dp[2] = 1 + @test dp[1] ≈ 0.0 atol = 1.0e-4 # du*/dp[1] = 0 + @test dp[2] ≈ 1.0 rtol = 1.0e-4 # du*/dp[2] = 1 end end @testset "FiniteDiff vs ForwardDiff consistency" begin let # Equality-constrained problem, compare autodiff=true vs autodiff=false - f = (u, p) -> (u[1] - p[1])^2 + (u[2] - p[2])^2 + f = (u, p) -> (u[1] - p[1])^2 + (u[2] - p[2])^2 cons = (res, u, p) -> (res[1] = u[1] + u[2] - p[3]) u0 = [0.5, 0.5] - p = [1.0, 2.0, 3.0] + p = [1.0, 2.0, 3.0] opt_f = OptimizationFunction(f, Optimization.AutoForwardDiff(); cons = cons) - prob = OptimizationProblem(opt_f, u0, p; lcons = [0.0], ucons = [0.0]) + prob = OptimizationProblem(opt_f, u0, p; lcons = [0.0], ucons = [0.0]) opt_sol = solve(prob, NLopt.LD_SLSQP()) - @test opt_sol.u[1] + opt_sol.u[2] ≈ p[3] rtol = 1e-6 # constraint satisfied + @test opt_sol.u[1] + opt_sol.u[2] ≈ p[3] rtol = 1.0e-6 # constraint satisfied dgdu!(out, _, _, _, _) = (out[1] = 1.0; out[2] = 0.0) - dp_fd = adjoint_sensitivities(opt_sol, nothing; - sensealg = OptimizationAdjoint(autodiff = false), dgdu = dgdu!) - dp_fwd = adjoint_sensitivities(opt_sol, nothing; - sensealg = OptimizationAdjoint(autodiff = true), dgdu = dgdu!) - @test dp_fd ≈ dp_fwd rtol = 1e-3 + dp_fd = adjoint_sensitivities( + opt_sol, nothing; + sensealg = OptimizationAdjoint(autodiff = false), dgdu = dgdu! + ) + dp_fwd = adjoint_sensitivities( + opt_sol, nothing; + sensealg = OptimizationAdjoint(autodiff = true), dgdu = dgdu! + ) + @test dp_fd ≈ dp_fwd rtol = 1.0e-3 end end @@ -271,26 +275,26 @@ end # J_p g = 0; sensitivity flows entirely through ∇²_xp L = [1, 0]. # KKT → u1* = (2 - p[1])/4, u2* = (2 + p[1])/4 # du1*/dp[1] = -1/4, du2*/dp[1] = 1/4 - f = (u, p) -> p[1] * u[1] + u[1]^2 + u[2]^2 + f = (u, p) -> p[1] * u[1] + u[1]^2 + u[2]^2 cons = (res, u, p) -> (res[1] = u[1] + u[2] - 1) - p = [2.0] + p = [2.0] u0 = [0.0, 1.0] # feasible: u1+u2 = 1 opt_f = OptimizationFunction(f, Optimization.AutoForwardDiff(); cons = cons) - prob = OptimizationProblem(opt_f, u0, p; lcons = [0.0], ucons = [0.0]) + prob = OptimizationProblem(opt_f, u0, p; lcons = [0.0], ucons = [0.0]) opt_sol = solve(prob, NLopt.LD_SLSQP()) - @test opt_sol.u[1] ≈ (2 - p[1]) / 4 rtol = 1e-4 - @test opt_sol.u[2] ≈ (2 + p[1]) / 4 rtol = 1e-4 - @test opt_sol.u[1] + opt_sol.u[2] ≈ 1.0 rtol = 1e-6 # constraint satisfied + @test opt_sol.u[1] ≈ (2 - p[1]) / 4 rtol = 1.0e-4 + @test opt_sol.u[2] ≈ (2 + p[1]) / 4 rtol = 1.0e-4 + @test opt_sol.u[1] + opt_sol.u[2] ≈ 1.0 rtol = 1.0e-6 # constraint satisfied dgdu1!(out, _, _, _, _) = (out[1] = 1.0; out[2] = 0.0) dgdu2!(out, _, _, _, _) = (out[1] = 0.0; out[2] = 1.0) dp1 = adjoint_sensitivities(opt_sol, nothing; sensealg = OptimizationAdjoint(), dgdu = dgdu1!) dp2 = adjoint_sensitivities(opt_sol, nothing; sensealg = OptimizationAdjoint(), dgdu = dgdu2!) - @test dp1[1] ≈ -0.25 rtol = 1e-3 # du1*/dp[1] - @test dp2[1] ≈ 0.25 rtol = 1e-3 # du2*/dp[1] + @test dp1[1] ≈ -0.25 rtol = 1.0e-3 # du1*/dp[1] + @test dp2[1] ≈ 0.25 rtol = 1.0e-3 # du2*/dp[1] end end @@ -299,23 +303,23 @@ end # Minimize (u - p[1])^2 s.t. u <= p[2] where p[2] > p[1] (constraint NOT active) # Optimal solution: u* = p[1] (unconstrained min, inequality slack) # du*/dp[1] = 1, du*/dp[2] = 0 - f = (u, p) -> (u[1] - p[1])^2 + f = (u, p) -> (u[1] - p[1])^2 cons = (res, u, p) -> (res[1] = u[1] - p[2]) - p = [1.0, 5.0] # unconstrained min at u=1, well inside bound u<=5 + p = [1.0, 5.0] # unconstrained min at u=1, well inside bound u<=5 u0 = [0.0] opt_f = OptimizationFunction(f, Optimization.AutoForwardDiff(); cons = cons) - prob = OptimizationProblem(opt_f, u0, p; lcons = [-Inf], ucons = [0.0]) + prob = OptimizationProblem(opt_f, u0, p; lcons = [-Inf], ucons = [0.0]) opt_sol = solve(prob, NLopt.LD_SLSQP()) - @test opt_sol.u[1] ≈ p[1] rtol = 1e-4 - @test opt_sol.u[1] <= p[2] + 1e-6 # constraint satisfied (slack) + @test opt_sol.u[1] ≈ p[1] rtol = 1.0e-4 + @test opt_sol.u[1] <= p[2] + 1.0e-6 # constraint satisfied (slack) dgdu!(out, _, _, _, _) = (out[1] = 1.0) dp = adjoint_sensitivities(opt_sol, nothing; sensealg = OptimizationAdjoint(), dgdu = dgdu!) - @test dp[1] ≈ 1.0 rtol = 1e-3 # du*/dp[1] = 1 - @test dp[2] ≈ 0.0 atol = 1e-3 # du*/dp[2] = 0 (inactive) + @test dp[1] ≈ 1.0 rtol = 1.0e-3 # du*/dp[1] = 1 + @test dp[2] ≈ 0.0 atol = 1.0e-3 # du*/dp[2] = 0 (inactive) end end @@ -324,29 +328,29 @@ end # Minimize (u1-3)^2 + (u2-3)^2 s.t. u1+u2 = p[1] and u1 <= p[2] # At p=[4,1]: u1* = p[2] = 1, u2* = p[1] - p[2] = 3 # du1*/dp = [0, 1], du2*/dp = [1, -1] - f = (u, p) -> (u[1] - 3)^2 + (u[2] - 3)^2 + f = (u, p) -> (u[1] - 3)^2 + (u[2] - 3)^2 cons = (res, u, p) -> (res[1] = u[1] + u[2] - p[1]; res[2] = u[1] - p[2]) - p = [4.0, 1.0] + p = [4.0, 1.0] u0 = [1.0, 3.0] # feasible: u1+u2=4, u1=1<=1 opt_f = OptimizationFunction(f, Optimization.AutoForwardDiff(); cons = cons) - prob = OptimizationProblem(opt_f, u0, p; lcons = [0.0, -Inf], ucons = [0.0, 0.0]) + prob = OptimizationProblem(opt_f, u0, p; lcons = [0.0, -Inf], ucons = [0.0, 0.0]) opt_sol = solve(prob, NLopt.LD_SLSQP()) - @test opt_sol.u[1] ≈ p[2] rtol = 1e-4 - @test opt_sol.u[2] ≈ p[1] - p[2] rtol = 1e-4 - @test opt_sol.u[1] + opt_sol.u[2] ≈ p[1] rtol = 1e-6 # equality satisfied - @test opt_sol.u[1] <= p[2] + 1e-6 # inequality satisfied + @test opt_sol.u[1] ≈ p[2] rtol = 1.0e-4 + @test opt_sol.u[2] ≈ p[1] - p[2] rtol = 1.0e-4 + @test opt_sol.u[1] + opt_sol.u[2] ≈ p[1] rtol = 1.0e-6 # equality satisfied + @test opt_sol.u[1] <= p[2] + 1.0e-6 # inequality satisfied dgdu1!(out, _, _, _, _) = (out[1] = 1.0; out[2] = 0.0) dgdu2!(out, _, _, _, _) = (out[1] = 0.0; out[2] = 1.0) dp1 = adjoint_sensitivities(opt_sol, nothing; sensealg = OptimizationAdjoint(), dgdu = dgdu1!) dp2 = adjoint_sensitivities(opt_sol, nothing; sensealg = OptimizationAdjoint(), dgdu = dgdu2!) - @test dp1[1] ≈ 0.0 atol = 1e-3 # du1*/dp[1] - @test dp1[2] ≈ 1.0 rtol = 1e-3 # du1*/dp[2] - @test dp2[1] ≈ 1.0 rtol = 1e-3 # du2*/dp[1] - @test dp2[2] ≈ -1.0 rtol = 1e-3 # du2*/dp[2] + @test dp1[1] ≈ 0.0 atol = 1.0e-3 # du1*/dp[1] + @test dp1[2] ≈ 1.0 rtol = 1.0e-3 # du1*/dp[2] + @test dp2[1] ≈ 1.0 rtol = 1.0e-3 # du2*/dp[1] + @test dp2[2] ≈ -1.0 rtol = 1.0e-3 # du2*/dp[2] end end @@ -355,29 +359,31 @@ end # Minimize (1/2)||u||^2 s.t. u1+u2 = p[1], u2+u3 = p[2] # Analytical solution: u* = [(2p[1]-p[2])/3, (p[1]+p[2])/3, (-p[1]+2p[2])/3] # du1/dp = [2/3, -1/3], du2/dp = [1/3, 1/3], du3/dp = [-1/3, 2/3] - f = (u, p) -> sum(u .^ 2) / 2 + f = (u, p) -> sum(u .^ 2) / 2 cons = (res, u, p) -> (res[1] = u[1] + u[2] - p[1]; res[2] = u[2] + u[3] - p[2]) - p = [1.0, 1.0] + p = [1.0, 1.0] u0 = [1.0 / 3, 2.0 / 3, 1.0 / 3] # feasible opt_f = OptimizationFunction(f, Optimization.AutoForwardDiff(); cons = cons) - prob = OptimizationProblem(opt_f, u0, p; lcons = [0.0, 0.0], ucons = [0.0, 0.0]) + prob = OptimizationProblem(opt_f, u0, p; lcons = [0.0, 0.0], ucons = [0.0, 0.0]) opt_sol = solve(prob, NLopt.LD_SLSQP()) - @test opt_sol.u[1] ≈ (2p[1] - p[2]) / 3 rtol = 1e-4 - @test opt_sol.u[2] ≈ (p[1] + p[2]) / 3 rtol = 1e-4 - @test opt_sol.u[3] ≈ (-p[1] + 2p[2]) / 3 rtol = 1e-4 - @test opt_sol.u[1] + opt_sol.u[2] ≈ p[1] rtol = 1e-6 - @test opt_sol.u[2] + opt_sol.u[3] ≈ p[2] rtol = 1e-6 + @test opt_sol.u[1] ≈ (2p[1] - p[2]) / 3 rtol = 1.0e-4 + @test opt_sol.u[2] ≈ (p[1] + p[2]) / 3 rtol = 1.0e-4 + @test opt_sol.u[3] ≈ (-p[1] + 2p[2]) / 3 rtol = 1.0e-4 + @test opt_sol.u[1] + opt_sol.u[2] ≈ p[1] rtol = 1.0e-6 + @test opt_sol.u[2] + opt_sol.u[3] ≈ p[2] rtol = 1.0e-6 - expected = [[2/3, -1/3], [1/3, 1/3], [-1/3, 2/3]] + expected = [[2 / 3, -1 / 3], [1 / 3, 1 / 3], [-1 / 3, 2 / 3]] for (i, exp_row) in enumerate(expected) e = zeros(3); e[i] = 1.0 dgdui!(out, _, _, _, _) = copyto!(out, e) - dp = adjoint_sensitivities(opt_sol, nothing; - sensealg = OptimizationAdjoint(), dgdu = dgdui!) - @test dp ≈ exp_row rtol = 1e-3 + dp = adjoint_sensitivities( + opt_sol, nothing; + sensealg = OptimizationAdjoint(), dgdu = dgdui! + ) + @test dp ≈ exp_row rtol = 1.0e-3 end end end @@ -389,22 +395,22 @@ end # u2* = p = 0 (unconstrained) → du2*/dp = 1 f = (u, p) -> (u[1] - p[1])^2 + (u[2] - p[1])^2 - p = [0.0] + p = [0.0] u0 = [2.0, 0.0] opt_f = OptimizationFunction(f, Optimization.AutoForwardDiff()) - prob = OptimizationProblem(opt_f, u0, p; lb = [2.0, -Inf], ub = [Inf, Inf]) + prob = OptimizationProblem(opt_f, u0, p; lb = [2.0, -Inf], ub = [Inf, Inf]) opt_sol = solve(prob, NLopt.LD_SLSQP()) - @test opt_sol.u[1] ≈ 2.0 rtol = 1e-4 # pinned at lb - @test opt_sol.u[2] ≈ p[1] rtol = 1e-4 # free, at unconstrained min + @test opt_sol.u[1] ≈ 2.0 rtol = 1.0e-4 # pinned at lb + @test opt_sol.u[2] ≈ p[1] rtol = 1.0e-4 # free, at unconstrained min dgdu1!(out, _, _, _, _) = (out[1] = 1.0; out[2] = 0.0) dgdu2!(out, _, _, _, _) = (out[1] = 0.0; out[2] = 1.0) dp1 = adjoint_sensitivities(opt_sol, nothing; sensealg = OptimizationAdjoint(), dgdu = dgdu1!) dp2 = adjoint_sensitivities(opt_sol, nothing; sensealg = OptimizationAdjoint(), dgdu = dgdu2!) - @test dp1[1] ≈ 0.0 atol = 1e-4 # du1*/dp = 0 (pinned at bound) - @test dp2[1] ≈ 1.0 rtol = 1e-4 # du2*/dp = 1 (free variable) + @test dp1[1] ≈ 0.0 atol = 1.0e-4 # du1*/dp = 0 (pinned at bound) + @test dp2[1] ≈ 1.0 rtol = 1.0e-4 # du2*/dp = 1 (free variable) end end @@ -413,26 +419,26 @@ end # Minimize (u1 - p[1])^2 + u2^2 s.t. u1 + u2 = p[2] # KKT → u1* = (p[1]+p[2])/2, u2* = (p[2]-p[1])/2 # du1*/dp = [1/2, 1/2], du2*/dp = [-1/2, 1/2] - f = (u, p) -> (u[1] - p[1])^2 + u[2]^2 + f = (u, p) -> (u[1] - p[1])^2 + u[2]^2 cons = (res, u, p) -> (res[1] = u[1] + u[2] - p[2]) - p = [1.0, 3.0] + p = [1.0, 3.0] u0 = [1.5, 1.5] # feasible: u1+u2 = 3 = p[2] opt_f = OptimizationFunction(f, Optimization.AutoForwardDiff(); cons = cons) - prob = OptimizationProblem(opt_f, u0, p; lcons = [0.0], ucons = [0.0]) + prob = OptimizationProblem(opt_f, u0, p; lcons = [0.0], ucons = [0.0]) opt_sol = solve(prob, NLopt.LD_SLSQP()) - @test opt_sol.u[1] ≈ (p[1] + p[2]) / 2 rtol = 1e-4 - @test opt_sol.u[2] ≈ (p[2] - p[1]) / 2 rtol = 1e-4 - @test opt_sol.u[1] + opt_sol.u[2] ≈ p[2] rtol = 1e-6 # constraint satisfied + @test opt_sol.u[1] ≈ (p[1] + p[2]) / 2 rtol = 1.0e-4 + @test opt_sol.u[2] ≈ (p[2] - p[1]) / 2 rtol = 1.0e-4 + @test opt_sol.u[1] + opt_sol.u[2] ≈ p[2] rtol = 1.0e-6 # constraint satisfied dgdu1!(out, _, _, _, _) = (out[1] = 1.0; out[2] = 0.0) dgdu2!(out, _, _, _, _) = (out[1] = 0.0; out[2] = 1.0) dp1 = adjoint_sensitivities(opt_sol, nothing; sensealg = OptimizationAdjoint(), dgdu = dgdu1!) dp2 = adjoint_sensitivities(opt_sol, nothing; sensealg = OptimizationAdjoint(), dgdu = dgdu2!) - @test dp1 ≈ [0.5, 0.5] rtol = 1e-3 - @test dp2 ≈ [-0.5, 0.5] rtol = 1e-3 + @test dp1 ≈ [0.5, 0.5] rtol = 1.0e-3 + @test dp2 ≈ [-0.5, 0.5] rtol = 1.0e-3 end end end From f425eb00fde3c3e3aeb63bb7c487fecd0a05b649 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 18 May 2026 13:37:46 -0400 Subject: [PATCH 14/19] drop the lazybuffer cache --- src/optimization_adjoint.jl | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/src/optimization_adjoint.jl b/src/optimization_adjoint.jl index 73d55e2ca..4b5ff53ab 100644 --- a/src/optimization_adjoint.jl +++ b/src/optimization_adjoint.jl @@ -28,7 +28,7 @@ inplace_sensitivity(::OptimizationAdjointSensitivityFunction) = false getprob(S::OptimizationAdjointSensitivityFunction) = S.sol.cache # Wrapper for the KKT residual closure. Named struct so we can declare the SciMLFunction -# traits that `_vecjacobian!` queries (otherwise a raw closure errors with MethodError). +# traits that `_vecjacobian!` queries struct OptimizationKKTResidual{F} f::F end @@ -113,18 +113,13 @@ function OptimizationAdjointSensitivityFunction( end captured_f.cons end - cons_cache = LazyBufferCache(_ -> (n_cons,)) eval_cons = function (x, q) - # When eltype(x) and eltype(q) match, reuse the cache. Otherwise infer the - # arithmetic result type via promote_op (compile-time inference) — ForwardDiff's - # @define_binary_dual_op bypasses promote_type, so Dual + TrackedReal returns - # Dual{Tag, TrackedReal, N} which promote_type would not predict. - res = if eltype(x) === eltype(q) - cons_cache[q] - else - T = Base.promote_op(+, eltype(x), eltype(q)) - Vector{T}(undef, n_cons) - end + # promote_op gives the inferred result eltype of `+(eltype(x), eltype(q))`. + # Preferred over promote_type because ForwardDiff's @define_binary_dual_op + # bypasses promote_type — e.g. Dual + TrackedReal returns Dual{Tag, TrackedReal, N} + # which promote_type would not predict. + T = Base.promote_op(+, eltype(x), eltype(q)) + res = Vector{T}(undef, n_cons) _cons3(res, x, q) return res end From 493a1f52039af2e36a4393f462ad52359f7c7a77 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 18 May 2026 16:40:37 -0400 Subject: [PATCH 15/19] fix tests --- test/optimization_adjoint.jl | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/test/optimization_adjoint.jl b/test/optimization_adjoint.jl index 2dbe08841..93ac40a52 100644 --- a/test/optimization_adjoint.jl +++ b/test/optimization_adjoint.jl @@ -1,7 +1,7 @@ using Test, LinearAlgebra using SciMLSensitivity, Optimization, OptimizationOptimisers, OptimizationNLopt, SciMLBase -using Mooncake, ForwardDiff -using SciMLSensitivity: MooncakeVJP +using Mooncake, ForwardDiff, FiniteDiff +using SciMLSensitivity: MooncakeVJP, alg_autodiff, diff_type # Helper: build a NonlinearSolution from an optimization solve using the gradient as the residual, # and the corresponding SteadyStateAdjoint, matching what _concrete_solve_adjoint does internally. @@ -10,10 +10,11 @@ function build_opt_adjoint_sol(prob, alg, sensealg; kwargs...) opt_f = prob.f grad_fn = if opt_f.grad !== nothing opt_f.grad - elseif sensealg.objective_ad isa Bool && !sensealg.objective_ad - (G, u, p) -> FiniteDiff.finite_difference_gradient!(G, Base.Fix2(opt_f, p), u) - else + elseif alg_autodiff(sensealg) (G, u, p) -> ForwardDiff.gradient!(G, Base.Fix2(opt_f, p), u) + else + (G, u, p) -> FiniteDiff.finite_difference_gradient!( + G, Base.Fix2(opt_f, p), u, diff_type(sensealg)) end nlprob = NonlinearProblem(grad_fn, opt_sol.u, prob.p) sol = SciMLBase.build_solution( From 40c5b9ec3f691bfce004576fa95f788a0eadecfb Mon Sep 17 00:00:00 2001 From: jClugstor Date: Tue, 19 May 2026 11:41:31 -0400 Subject: [PATCH 16/19] fix compat for QA --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 112274cff..6b349511f 100644 --- a/Project.toml +++ b/Project.toml @@ -92,6 +92,7 @@ NLsolve = "4.5.1" NonlinearSolve = "3.0.1, 4" SCCNonlinearSolve = "1" Optimization = "4, 5" +OptimizationNLopt = "0.3" OptimizationOptimisers = "0.3" OrdinaryDiffEq = "6.108, 7" OrdinaryDiffEqCore = "3.26, 4, 5" From 6083542a7daf1802d0b40e1607159d9bc73df831 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Tue, 19 May 2026 13:57:30 -0400 Subject: [PATCH 17/19] add docs page --- docs/pages.jl | 1 + docs/src/manual/optimization_sensitivities.md | 20 +++++ src/sensitivity_algorithms.jl | 75 +++++++++++++++++++ 3 files changed, 96 insertions(+) create mode 100644 docs/src/manual/optimization_sensitivities.md diff --git a/docs/pages.jl b/docs/pages.jl index 3e1b1be84..c2f4adbc7 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -46,6 +46,7 @@ pages = [ "Manual and APIs" => Any[ "manual/differential_equation_sensitivities.md", "manual/nonlinear_solve_sensitivities.md", + "manual/optimization_sensitivities.md", "manual/direct_forward_sensitivity.md", "manual/direct_adjoint_sensitivities.md", ], diff --git a/docs/src/manual/optimization_sensitivities.md b/docs/src/manual/optimization_sensitivities.md new file mode 100644 index 000000000..86cf06552 --- /dev/null +++ b/docs/src/manual/optimization_sensitivities.md @@ -0,0 +1,20 @@ +# [Sensitivity Algorithms for Optimization Problems](@id sensitivity_optimization) + +SciMLSensitivity provides adjoint algorithms for differentiating through the optimum +`u*(p)` of a parameterized [`OptimizationProblem`](https://docs.sciml.ai/Optimization/stable/), +giving `dG/dp` for any downstream loss `G(u*(p))` via implicit differentiation rather +than by differentiating through the iterations of the optimizer. + + - `UnconstrainedOptimizationAdjoint` handles unconstrained problems by treating the + stationarity condition `∇f(u*, p) = 0` as a steady-state nonlinear system and + reusing the `SteadyStateAdjoint` machinery. + - `OptimizationAdjoint` handles problems with equality, two-sided inequality, and + variable-bound constraints by implicit differentiation of the KKT first-order + optimality conditions. It detects the active inequality set at `u*`, recovers + multipliers from the stationarity equation, and solves a single symmetric KKT + linear system to produce the adjoint. + +```@docs +UnconstrainedOptimizationAdjoint +OptimizationAdjoint +``` diff --git a/src/sensitivity_algorithms.jl b/src/sensitivity_algorithms.jl index b34fae923..17d5ccfb9 100644 --- a/src/sensitivity_algorithms.jl +++ b/src/sensitivity_algorithms.jl @@ -1417,6 +1417,81 @@ function setvjp( ) end +""" +```julia +OptimizationAdjoint{CS, AD, FDT, VJP, LS, LK, AT} <: AbstractAdjointSensitivityAlgorithm{CS, AD, FDT} +``` + +An implementation of adjoint differentiation for constrained optimization problems. +Uses implicit differentiation of the KKT first-order optimality conditions to compute +derivatives of the optimal solution u* with respect to parameters p, given a cotangent +`dgdu = dG/du*` for some downstream loss `G`. + +Handles equality constraints (`lcons == ucons`), two-sided inequality constraints +(`lcons ≤ cons(u, p) ≤ ucons`), and variable box bounds (`lb ≤ u ≤ ub`). The active +inequality set is detected by proximity at the optimum and refined by multiplier-sign +checks (KKT requires inequality multipliers to be non-negative). + +Given the cotangent `Δu`, the algorithm solves the symmetric KKT system + +``` +[ L_xx J_g' J_h' ] [ λ_x ] [ Δu ] +[ J_g 0 0 ] · [ λ_y ] = [ 0 ] +[ J_h 0 0 ] [ λ_z ] [ 0 ] +``` + +once, then computes `dG/dp = -λ' · ∂F/∂p` as a single VJP through the KKT residual +`F(p) = [∇_x L(u*, p); g(u*, p); h_I(u*, p)]`. No re-optimization is required. + +## Constructor + +```julia +OptimizationAdjoint(; chunk_size = 0, autodiff = true, + autojacvec = nothing, + linsolve = nothing, linsolve_kwargs = (;), + active_tol = nothing) +``` + +## Keyword Arguments + + - `autodiff`: Use automatic differentiation (ForwardDiff) for the inner derivatives + at `u*` — gradient of the objective, Jacobian of the constraints, Hessian of the + Lagrangian — when not supplied by `OptimizationFunction`. If `false`, FiniteDiff + is used with `diff_type = Val{:central}`. Defaults to `true`. This is independent + of `autojacvec`, which controls the *outer* VJP. + - `chunk_size`: Chunk size for forward-mode differentiation if full Jacobians are + built (`autojacvec=false` and `autodiff=true`). Default is `0` for automatic + choice of chunk size. + - `autojacvec`: Calculate the vector-Jacobian product (`λ' · ∂F/∂p`) through the + KKT residual via automatic differentiation with special seeding. Choices: + + + `nothing`: chooses an automatic algorithm. Defaults to `true` (ForwardDiff + via materialized Jacobian) and is recommended for most users. + + `false`: the Jacobian is constructed via FiniteDiff.jl. + + `true`: the Jacobian is constructed via ForwardDiff.jl. + + `ZygoteVJP`: Uses Zygote.jl for the vjp. + + `EnzymeVJP`: Uses Enzyme.jl for the vjp. + + `ReverseDiffVJP(compile=false)`: Uses ReverseDiff.jl for the vjp. `compile` + is a boolean for whether to precompile the tape, which should only be done + if there are no branches (`if` or `while` statements) in the `f` function. + + `MooncakeVJP`: Uses Mooncake.jl for the vjp. + - `linsolve`: the linear solver used in the KKT solve. Defaults to `nothing`, + which uses a polyalgorithm to choose an efficient algorithm automatically. + - `linsolve_kwargs`: keyword arguments to be passed to the linear solver. + - `active_tol`: proximity tolerance for active inequality / variable-bound + detection. A constraint or bound is considered active at `u*` when + `|c(u*) - bound| ≤ active_tol`. Defaults to `sqrt(eps(eltype(u*)))` when + `nothing`. + +For more details on the vjp choices, please consult the sensitivity algorithms +documentation page or the docstrings of the vjp types. + +## References + +Gould, S., Fernando, B., Cherian, A., Anderson, P., Cruz, R. S., & Guo, E., +On Differentiating Parameterized Argmin and Argmax Problems with Application to +Bi-level Optimization (2016), https://arxiv.org/abs/1607.05447 +""" struct OptimizationAdjoint{CS, AD, FDT, VJP, LS, LK, AT} <: AbstractAdjointSensitivityAlgorithm{CS, AD, FDT} autojacvec::VJP From 6661a11f0d06e7843933b93d15ca6ff858249ad7 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 20 May 2026 13:20:22 -0400 Subject: [PATCH 18/19] Iterate active-set refinement until no negative multipliers --- src/optimization_adjoint.jl | 94 ++++++++++++++----------------------- 1 file changed, 36 insertions(+), 58 deletions(-) diff --git a/src/optimization_adjoint.jl b/src/optimization_adjoint.jl index 4b5ff53ab..38fa683f0 100644 --- a/src/optimization_adjoint.jl +++ b/src/optimization_adjoint.jl @@ -180,65 +180,20 @@ function OptimizationAdjointSensitivityFunction( # Active ineq upper bound: h_ub(x,p) = cons(x,p)[i] - ucons[i] (= 0 when active) # Variable bound active lower: h_lb_var = lb[i] - x[i] (∂/∂x = -eᵢ, ∂/∂p = 0) # Variable bound active upper: h_ub_var = x[i] - ub[i] (∂/∂x = +eᵢ, ∂/∂p = 0) - n_act = length(active_lb) + length(active_ub) - n_bound = length(active_lb_var) + length(active_ub_var) - - h_I = (x, q) -> begin - c = eval_cons(x, q) - vcat( - isempty(active_lb) ? eltype(x_star)[] : lcons[active_lb] .- c[active_lb], - isempty(active_ub) ? eltype(x_star)[] : c[active_ub] .- ucons[active_ub] - ) - end - - Jxhι_cons = if n_act == 0 - zeros(eltype(x_star), 0, n_x) - elseif J_full !== nothing - vcat( - isempty(active_lb) ? zeros(eltype(x_star), 0, n_x) : -J_full[active_lb, :], - isempty(active_ub) ? zeros(eltype(x_star), 0, n_x) : J_full[active_ub, :] - ) - else - jacobian(x -> h_I(x, p), x_star, sensealg) - end - - Jx_bound = zeros(eltype(x_star), n_bound, n_x) - for (j, i) in enumerate(active_lb_var) - Jx_bound[j, i] = -one(eltype(x_star)) - end - for (j, i) in enumerate(active_ub_var) - Jx_bound[length(active_lb_var) + j, i] = one(eltype(x_star)) - end - Jxhι = vcat(Jxhι_cons, Jx_bound) - - # Dual variables from stationarity: [Jxg; Jxhι]' * [y*; z_I*; z_bound] = -∇f - n_act_total = n_act + n_bound - dual_vars = if n_eq + n_act_total == 0 - eltype(x_star)[] - else - solve(LinearProblem(Matrix(vcat(Jxg, Jxhι)'), -∇f), LinearSolve.QRFactorization()).u - end - y_star = n_eq > 0 ? dual_vars[1:n_eq] : eltype(x_star)[] - zI_star = n_act > 0 ? dual_vars[(n_eq + 1):(n_eq + n_act)] : eltype(x_star)[] - z_bound_star = n_bound > 0 ? dual_vars[(n_eq + n_act + 1):end] : eltype(x_star)[] - - # Multiplier sign check: KKT requires all inequality multipliers ≥ 0 at a minimum. - # Negative multipliers indicate spuriously-included constraints (close to bound but inactive). - # Drop those and redo only the Jxhι build and dual solve — no extra cost if all signs are good. + # + # Active-set iteration: KKT requires inequality multipliers ≥ 0 at any optimum. + # Build Jxhι and recover multipliers; if any are negative, the offending constraints + # were spuriously flagged by proximity detection. Drop them and re-solve. Dropped + # indices never come back, so this terminates in ≤ |initial active set| iterations. + # The iteration cap is defense-in-depth; the mathematical bound is the same. mtol = sqrt(eps(eltype(x_star))) - if (n_act > 0 && any(<(-mtol), zI_star)) || - (n_bound > 0 && any(<(-mtol), z_bound_star)) - n_lb = length(active_lb) - n_lb_var = length(active_lb_var) - active_lb = active_lb[findall(j -> zI_star[j] >= -mtol, 1:n_lb)] - active_ub = active_ub[findall(j -> zI_star[n_lb + j] >= -mtol, 1:length(active_ub))] - active_lb_var = active_lb_var[findall(j -> z_bound_star[j] >= -mtol, 1:n_lb_var)] - active_ub_var = active_ub_var[ - findall( - j -> z_bound_star[n_lb_var + j] >= -mtol, - 1:length(active_ub_var) - ), - ] + max_iters = length(active_lb) + length(active_ub) + + length(active_lb_var) + length(active_ub_var) + 1 + local h_I, n_act, n_bound, y_star, zI_star + has_negatives = true + iter = 0 + while has_negatives && iter < max_iters + iter += 1 n_act = length(active_lb) + length(active_ub) n_bound = length(active_lb_var) + length(active_ub_var) @@ -270,6 +225,7 @@ function OptimizationAdjointSensitivityFunction( end Jxhι = vcat(Jxhι_cons, Jx_bound) + # Dual variables from stationarity: [Jxg; Jxhι]' * [y*; z_I*; z_bound] = -∇f n_act_total = n_act + n_bound dual_vars = if n_eq + n_act_total == 0 eltype(x_star)[] @@ -278,6 +234,28 @@ function OptimizationAdjointSensitivityFunction( end y_star = n_eq > 0 ? dual_vars[1:n_eq] : eltype(x_star)[] zI_star = n_act > 0 ? dual_vars[(n_eq + 1):(n_eq + n_act)] : eltype(x_star)[] + z_bound_star = n_bound > 0 ? dual_vars[(n_eq + n_act + 1):end] : eltype(x_star)[] + + neg_in_zI = n_act > 0 && any(<(-mtol), zI_star) + neg_in_bound = n_bound > 0 && any(<(-mtol), z_bound_star) + has_negatives = neg_in_zI || neg_in_bound + + if has_negatives + # Filter offenders out of each active set; they never re-enter. + n_lb = length(active_lb) + n_lb_var = length(active_lb_var) + active_lb = active_lb[findall(j -> zI_star[j] >= -mtol, 1:n_lb)] + active_ub = active_ub[findall(j -> zI_star[n_lb + j] >= -mtol, + 1:length(active_ub))] + active_lb_var = active_lb_var[findall(j -> z_bound_star[j] >= -mtol, + 1:n_lb_var)] + active_ub_var = active_ub_var[ + findall( + j -> z_bound_star[n_lb_var + j] >= -mtol, + 1:length(active_ub_var) + ), + ] + end end # Lagrangian with fixed multipliers (used for p-derivative computations below) From b6aef6b01fcea09179849c9c807632a40853e155 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 20 May 2026 15:18:36 -0400 Subject: [PATCH 19/19] fix local variable declarations --- src/optimization_adjoint.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/optimization_adjoint.jl b/src/optimization_adjoint.jl index 38fa683f0..b14e7f858 100644 --- a/src/optimization_adjoint.jl +++ b/src/optimization_adjoint.jl @@ -189,7 +189,7 @@ function OptimizationAdjointSensitivityFunction( mtol = sqrt(eps(eltype(x_star))) max_iters = length(active_lb) + length(active_ub) + length(active_lb_var) + length(active_ub_var) + 1 - local h_I, n_act, n_bound, y_star, zI_star + local h_I, n_act, n_bound, n_act_total, y_star, zI_star, Jxhι has_negatives = true iter = 0 while has_negatives && iter < max_iters