Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,18 @@ BNKJuMP = "JuMP"
ExaModels = "0.8.3"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OpenCL = "08131aa3-fb12-5dee-8b74-c09406e224a2"
pocl_jll = "627d6b7a-bbe6-5189-83e7-98cc0a5aeadd"
AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OpenCL = "08131aa3-fb12-5dee-8b74-c09406e224a2"
PGLib = "07a8691f-3d11-4330-951b-3c50f98338be"
PowerModels = "c36e90e8-916a-50a6-bd94-075b64ef4655"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
pocl_jll = "627d6b7a-bbe6-5189-83e7-98cc0a5aeadd"

[targets]
test = [
"Test", "LinearAlgebra",
"OpenCL", "pocl_jll", "AcceleratedKernels",
"DifferentiationInterface", "FiniteDifferences", "Zygote"
]
test = ["Test", "CUDA", "GPUArraysCore", "LinearAlgebra", "OpenCL", "pocl_jll", "AcceleratedKernels", "DifferentiationInterface", "FiniteDifferences", "Zygote", "PGLib", "PowerModels"]
46 changes: 46 additions & 0 deletions ext/BNKChainRulesCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,50 @@
return y, cons_nln_batch_pullback
end


function ChainRulesCore.rrule(::typeof(BatchNLPKernels.constraint_violations!), bm::BatchModel, V)
Vc = BatchNLPKernels.constraint_violations!(bm, V)

function constraint_violations_pullback(Ȳ)
Ȳ = ChainRulesCore.unthunk(Ȳ)

# violation(v, s) = max(s.l - v, v - s.u, 0)
# ∂violation/∂v = -1 if v < s.l, +1 if v > s.u, 0 otherwise

V̄ = if isempty(bm.viols_cons)
zeros(eltype(V), size(V))

Check warning on line 70 in ext/BNKChainRulesCore.jl

View check run for this annotation

Codecov / codecov/patch

ext/BNKChainRulesCore.jl#L70

Added line #L70 was not covered by tests
else
lower_viols = V .< bm.viols_cons.l
upper_viols = V .> bm.viols_cons.u
lower_viols .* (-Ȳ) .+ upper_viols .* Ȳ
end

return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), V̄
end

return Vc, constraint_violations_pullback
end
function ChainRulesCore.rrule(::typeof(BatchNLPKernels.bound_violations!), bm::BatchModel, X)
Vb = BatchNLPKernels.bound_violations!(bm, X)

function bound_violations_pullback(Ȳ)
Ȳ = ChainRulesCore.unthunk(Ȳ)

# violation(x, s) = max(s.l - x, x - s.u, 0)
# ∂violation/∂x = -1 if x < s.l, +1 if x > s.u, 0 otherwise

X̄ = if isempty(bm.viols_vars)
zeros(eltype(X), size(X))

Check warning on line 92 in ext/BNKChainRulesCore.jl

View check run for this annotation

Codecov / codecov/patch

ext/BNKChainRulesCore.jl#L92

Added line #L92 was not covered by tests
else
lower_viols = X .< bm.viols_vars.l
upper_viols = X .> bm.viols_vars.u
lower_viols .* (-Ȳ) .+ upper_viols .* Ȳ
end

return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), X̄
end

return Vb, bound_violations_pullback
end

end # module BNKChainRulesCore
2 changes: 2 additions & 0 deletions src/BatchNLPKernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using KernelAbstractions
const ExaKA = Base.get_extension(ExaModels, :ExaModelsKernelAbstractions)
const KAExtension = ExaKA.KAExtension

include("interval.jl")
include("batch_model.jl")

const BOI = BatchNLPKernels
Expand All @@ -22,5 +23,6 @@ include("api/jac.jl")
include("api/obj.jl")
include("api/jprod.jl")
include("api/hprod.jl")
include("api/viols.jl")

end # module BatchNLPKernels
57 changes: 57 additions & 0 deletions src/api/viols.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""
all_violations!(bm::BatchModel, X::AbstractMatrix)

Compute all constraint and variable violations for a batch of solutions.
"""
function all_violations!(bm::BatchModel, X::AbstractMatrix)
Comment thread
andrewrosemberg marked this conversation as resolved.
V = cons_nln_batch!(bm, X)

Vc = constraint_violations!(bm, V)
Vb = bound_violations!(bm, X)

return Vc, Vb
end

"""
all_violations!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix)

Compute all constraint and variable violations for a batch of solutions and parameters.
"""
function all_violations!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix)
V = cons_nln_batch!(bm, X, Θ)

Vc = constraint_violations!(bm, V)
Vb = bound_violations!(bm, X)

return Vc, Vb
end

"""
constraint_violations!(bm::BatchModel, V::AbstractMatrix)

Compute constraint violations for a batch of constraint primal values.
"""
function constraint_violations!(bm::BatchModel, V::AbstractMatrix)
viols_cons_out = _maybe_view(bm, :viols_cons_out, V)

_violation!.(eachcol(viols_cons_out), eachcol(V), bm.viols_cons)

return viols_cons_out
end

"""
bound_violations!(bm::BatchModel, X::AbstractMatrix)

Compute variable violations for a batch of variable primal values.
"""
function bound_violations!(bm::BatchModel, X::AbstractMatrix)
viols_vars_out = _maybe_view(bm, :viols_vars_out, X)

_violation!.(eachcol(viols_vars_out), eachcol(X), bm.viols_vars)

return viols_vars_out
end

@inline _violation!(d, v, s::S) where {S} = begin
d .= _violation(v, s)
end
67 changes: 51 additions & 16 deletions src/batch_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
- `jprod::Bool`: Allocate jacobian-vector product buffer (default: false)
- `jtprod::Bool`: Allocate jacobian transpose-vector product buffer (default: false)
- `hprod::Bool`: Allocate hessian-vector product buffer (default: false)
- `viols::Bool`: Allocate constraint and variable violation buffers (default: false)
"""
struct BatchModelConfig
obj::Bool
Expand All @@ -22,37 +23,52 @@
jprod::Bool
jtprod::Bool
hprod::Bool
viols::Bool
end

"""
BatchModelConfig(; obj=true, cons=true, grad=false, jac=false, hess=false, jprod=false, jtprod=false, hprod=false)
BatchModelConfig(; obj=true, cons=true, grad=false, jac=false, hess=false, jprod=false, jtprod=false, hprod=false, viols=false)

Create a BatchModelConfig with specified buffer allocations.
"""
function BatchModelConfig(; obj=true, cons=true, grad=false, jac=false, hess=false, jprod=false, jtprod=false, hprod=false)
return BatchModelConfig(obj, cons, grad, jac, hess, jprod, jtprod, hprod)
function BatchModelConfig(; obj=true, cons=true, grad=false, jac=false, hess=false, jprod=false, jtprod=false, hprod=false, viols=false)
return BatchModelConfig(obj, cons, grad, jac, hess, jprod, jtprod, hprod, viols)
end

"""
BatchModelConfig(:minimal)

Minimal configuration with only objective and constraint buffers.
"""
BatchModelConfig(::Val{:minimal}) = BatchModelConfig(obj=true, cons=true, grad=false, jac=false, hess=false, jprod=false, jtprod=false, hprod=false)
BatchModelConfig(::Val{:minimal}) = BatchModelConfig(obj=true, cons=true, grad=false, jac=false, hess=false, jprod=false, jtprod=false, hprod=false, viols=false)

"""
BatchModelConfig(:gradients)

Configuration to support obj, cons, and their gradients (grad, jtprod).
"""
BatchModelConfig(::Val{:gradients}) = BatchModelConfig(obj=true, cons=true, grad=true, jac=false, hess=false, jprod=false, jtprod=true, hprod=false)
BatchModelConfig(::Val{:gradients}) = BatchModelConfig(obj=true, cons=true, grad=true, jac=true, hess=false, jprod=false, jtprod=true, hprod=false, viols=false)

Check warning on line 50 in src/batch_model.jl

View check run for this annotation

Codecov / codecov/patch

src/batch_model.jl#L50

Added line #L50 was not covered by tests

"""
BatchModelConfig(:violations)

Configuration to support obj, cons, and constraint/variable violations.
"""
BatchModelConfig(::Val{:violations}) = BatchModelConfig(obj=true, cons=true, grad=false, jac=false, hess=false, jprod=false, jtprod=false, hprod=false, viols=true)

"""
BatchModelConfig(:viol_grad)

Configuration to support obj, cons, constraint/variable violations, and their gradients.
"""
BatchModelConfig(::Val{:viol_grad}) = BatchModelConfig(obj=true, cons=true, grad=true, jac=true, hess=false, jprod=false, jtprod=true, hprod=false, viols=true)

"""
BatchModelConfig(:full)

Full configuration with all buffers allocated.
"""
BatchModelConfig(::Val{:full}) = BatchModelConfig(obj=true, cons=true, grad=true, jac=true, hess=true, jprod=true, jtprod=true, hprod=true)
BatchModelConfig(::Val{:full}) = BatchModelConfig(obj=true, cons=true, grad=true, jac=true, hess=true, jprod=true, jtprod=true, hprod=true, viols=true)

BatchModelConfig(s::Symbol) = BatchModelConfig(Val(s))

Expand All @@ -65,16 +81,20 @@
## Fields
- `model::ExaModel`: The underlying ExaModel
- `batch_size::Int`: Number of points to evaluate simultaneously
- `obj_work::MT`: Batch objective values (nobj × batch_size), (0 × batch_size) if not allocated
- `cons_work::MT`: Batch constraint values (nconaug × batch_size), (0 × batch_size) if not allocated
- `cons_out::MT`: Dense constraint output buffer (ncon × batch_size), (0 × batch_size) if not allocated
- `grad_work::MT`: Batch gradient values (nnzg × batch_size), (0 × batch_size) if not allocated
- `grad_out::MT`: Dense gradient output buffer (nvar × batch_size), (0 × batch_size) if not allocated
- `jprod_work::MT`: Batch jacobian values (nnzj × batch_size), (0 × batch_size) if not allocated
- `hprod_work::MT`: Batch hessian values (nnzh × batch_size), (0 × batch_size) if not allocated
- `jprod_out::MT`: Batch jacobian-vector product buffer (ncon × batch_size), (0 × batch_size) if not allocated
- `jtprod_out::MT`: Batch jacobian transpose-vector product buffer (nvar × batch_size), (0 × batch_size) if not allocated
- `hprod_out::MT`: Batch hessian-vector product buffer (nvar × batch_size), (0 × batch_size) if not allocated
- `obj_work::MT`: Batch objective values (nobj × batch_size)
- `cons_work::MT`: Batch constraint values (nconaug × batch_size)
- `cons_out::MT`: Dense constraint output buffer (ncon × batch_size)
- `grad_work::MT`: Batch gradient values (nnzg × batch_size)
- `grad_out::MT`: Dense gradient output buffer (nvar × batch_size)
- `jprod_work::MT`: Batch jacobian values (nnzj × batch_size)
- `hprod_work::MT`: Batch hessian values (nnzh × batch_size)
- `jprod_out::MT`: Batch jacobian-vector product buffer (ncon × batch_size)
- `jtprod_out::MT`: Batch jacobian transpose-vector product buffer (nvar × batch_size)
- `hprod_out::MT`: Batch hessian-vector product buffer (nvar × batch_size)
- `viols_cons_out::MT`: Constraint violation output buffer (ncon × batch_size)
- `viols_vars_out::MT`: Variable violation output buffer (nvar × batch_size)
- `viols_cons::Interval`: Constraint bounds as interval set
- `viols_vars::Interval`: Variable bounds as interval set
"""
struct BatchModel{MT,E}
model::E
Expand All @@ -90,6 +110,11 @@
jprod_out::MT
jtprod_out::MT
hprod_out::MT

viols_cons_out::MT
viols_vars_out::MT
viols_cons::Interval
viols_vars::Interval
end

"""
Expand Down Expand Up @@ -133,6 +158,12 @@
jprod_out = config.jprod ? similar(o, ncon, batch_size) : similar(o, 0, batch_size)
jtprod_out = config.jtprod ? similar(o, nvar, batch_size) : similar(o, 0, batch_size)
hprod_out = config.hprod ? similar(o, nvar, batch_size) : similar(o, 0, batch_size)

# FIXME: consider don't allocate vars if there are no bound constraints
viols_cons_out = config.viols ? similar(o, ncon, batch_size) : similar(o, 0, batch_size)
viols_vars_out = config.viols ? similar(o, nvar, batch_size) : similar(o, 0, batch_size)
viols_cons = config.viols ? Interval(model.meta.lcon, model.meta.ucon) : Interval()
viols_vars = config.viols ? Interval(model.meta.lvar, model.meta.uvar) : Interval()

return BatchModel(
model,
Expand All @@ -147,6 +178,10 @@
jprod_out,
jtprod_out,
hprod_out,
viols_cons_out,
viols_vars_out,
viols_cons,
viols_vars,
)
end

Expand Down
21 changes: 21 additions & 0 deletions src/interval.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""
Interval{VT}

Represents the RHS of M constraints g(xᵢ) ∈ [lᵢ, uᵢ] ∀i ∈ 1:M.
"""
struct Interval{VT}
l::VT
u::VT
end
@inline _violation(v, s::Interval{VT}) where {VT} = begin
@. max(s.l - v, v - s.u, zero(v))
end

Base.broadcastable(s::Interval) = Ref(s)
Base.isempty(s::Interval{VT}) where {VT} = isempty(s.l) || isempty(s.u)

# empty support (unconstrained)
Interval(::Nothing) = Interval()

Check warning on line 18 in src/interval.jl

View check run for this annotation

Codecov / codecov/patch

src/interval.jl#L18

Added line #L18 was not covered by tests
Interval() = Interval(nothing, nothing)
Base.isempty(::Interval{Nothing}) = true
@inline _violation(v, ::Interval{Nothing}) = zero(v)

Check warning on line 21 in src/interval.jl

View check run for this annotation

Codecov / codecov/patch

src/interval.jl#L20-L21

Added lines #L20 - L21 were not covered by tests
Loading