Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,30 @@ version = "1.0.0-DEV"
[deps]
ExaModels = "1037b233-b668-4ce9-9b63-f9f681f55dd2"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"

[extensions]
BNKChainRulesCore = "ChainRulesCore"
BNKJuMP = "JuMP"
BNKReactant = "Reactant"

[compat]
ExaModels = "0.8.3"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OpenCL = "08131aa3-fb12-5dee-8b74-c09406e224a2"
pocl_jll = "627d6b7a-bbe6-5189-83e7-98cc0a5aeadd"
AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OpenCL = "08131aa3-fb12-5dee-8b74-c09406e224a2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
pocl_jll = "627d6b7a-bbe6-5189-83e7-98cc0a5aeadd"

[targets]
test = [
"Test", "LinearAlgebra",
"OpenCL", "pocl_jll", "AcceleratedKernels",
"DifferentiationInterface", "FiniteDifferences", "Zygote"
]
test = ["Test", "LinearAlgebra", "OpenCL", "pocl_jll", "AcceleratedKernels", "DifferentiationInterface", "FiniteDifferences", "Zygote"]
48 changes: 48 additions & 0 deletions ext/BNKReactant.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
module BNKReactant

using BatchNLPKernels
using Reactant, KernelAbstractions
using ExaModels


function to_reactant_KA(bm::BNK.BatchModel)
RKA = Base.get_extension(Reactant, :ReactantKernelAbstractionsExt)
if !occursin("CUDA", string(bm.model.ext.backend))
error("ExaModel must be built with CUDABackend")
end
return BNK.BatchModel(
bm.model,
bm.batch_size,
Reactant.to_rarray(bm.obj_work),
Reactant.to_rarray(bm.cons_work),
Reactant.to_rarray(bm.cons_out),
Reactant.to_rarray(bm.grad_work),
Reactant.to_rarray(bm.grad_out),
Reactant.to_rarray(bm.jprod_work),
Reactant.to_rarray(bm.hprod_work),
Reactant.to_rarray(bm.jprod_out),
Reactant.to_rarray(bm.jtprod_out),
Reactant.to_rarray(bm.hprod_out),
RKA.ReactantBackend(),
)
end

function to_reactant(bm::BNK.BatchModel)
return BNK.BatchModel(
bm.model,
bm.batch_size,
Reactant.to_rarray(bm.obj_work),
Reactant.to_rarray(bm.cons_work),
Reactant.to_rarray(bm.cons_out),
Reactant.to_rarray(bm.grad_work),
Reactant.to_rarray(bm.grad_out),
Reactant.to_rarray(bm.jprod_work),
Reactant.to_rarray(bm.hprod_work),
Reactant.to_rarray(bm.jprod_out),
Reactant.to_rarray(bm.jtprod_out),
Reactant.to_rarray(bm.hprod_out),
nothing,
)
end

end # module BNKReactant
5 changes: 3 additions & 2 deletions src/BatchNLPKernels.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
module BatchNLPKernels

using ExaModels
using LinearAlgebra
using KernelAbstractions

const ExaKA = Base.get_extension(ExaModels, :ExaModelsKernelAbstractions)
const KAExtension = ExaKA.KAExtension

include("batch_model.jl")

const BOI = BatchNLPKernels
export BOI, BatchModel, BatchModelConfig
const BNK = BatchNLPKernels
export BNK, BatchModel, BatchModelConfig
export obj_batch!, grad_batch!, cons_nln_batch!, jac_coord_batch!, hess_coord_batch!
export jprod_nln_batch!, jtprod_nln_batch!, hprod_batch!

Expand Down
24 changes: 18 additions & 6 deletions src/api/cons.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ function cons_nln_batch!(
@lencheck length(bm.model.θ) eachrow(Θ)
@lencheck bm.model.meta.ncon eachrow(C)
_assert_batch_size(batch_size, bm.batch_size)
backend = _get_backend(bm.model)
backend = _get_backend(bm)

_cons_nln_batch!(backend, C, bm.model.cons, X, Θ)

Expand All @@ -41,14 +41,14 @@ function cons_nln_batch!(
_conaugs_batch!(backend, conbuffers_batch, bm.model.cons, X, Θ)

if length(bm.model.ext.conaugptr) > 1
compress_to_dense_batch(backend)(
_run_compress_to_dense_batch!(
backend,
C,
conbuffers_batch,
bm.model.ext.conaugptr,
bm.model.ext.conaugsparsity;
ndrange = (length(bm.model.ext.conaugptr) - 1, batch_size),
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why this change?

bm.model.ext.conaugsparsity,
batch_size,
)
synchronize(backend)
end
return C
end
Expand All @@ -61,6 +61,12 @@ function _cons_nln_batch!(backend, C, con::ExaModels.Constraint, X, Θ)
_cons_nln_batch!(backend, C, con.inner, X, Θ)
synchronize(backend)
end
function _cons_nln_batch!(::Nothing, C, con::ExaModels.Constraint, X, Θ)
if !isempty(con.itr)
kerf_batch_cpu!(C, con.f, con.itr, X, Θ)
end
_cons_nln_batch!(backend, C, con.inner, X, Θ)
end
function _cons_nln_batch!(backend, C, con::ExaModels.ConstraintNull, X, Θ) end
function _cons_nln_batch!(backend, C, con::ExaModels.ConstraintAug, X, Θ)
_cons_nln_batch!(backend, C, con.inner, X, Θ)
Expand All @@ -74,7 +80,13 @@ function _conaugs_batch!(backend, conbuffers, con::ExaModels.ConstraintAug, X,
_conaugs_batch!(backend, conbuffers, con.inner, X, Θ)
synchronize(backend)
end
function _conaugs_batch!(::Nothing, conbuffers, con::ExaModels.ConstraintAug, X, Θ)
if !isempty(con.itr)
kerf2_batch_cpu!(conbuffers, con.f, con.itr, X, Θ, con.oa)
end
_conaugs_batch!(backend, conbuffers, con.inner, X, Θ)
end
function _conaugs_batch!(backend, conbuffers, con::ExaModels.Constraint, X, Θ)
_conaugs_batch!(backend, conbuffers, con.inner, X, Θ)
end
function _conaugs_batch!(backend, conbuffers, con::ExaModels.ConstraintNull, X, Θ) end
function _conaugs_batch!(backend, conbuffers, con::ExaModels.ConstraintNull, X, Θ) end
24 changes: 18 additions & 6 deletions src/api/grad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,18 @@ function sgradient_batch!(
kerg_batch(backend)(Y, f.f, f.itr, X, Θ, adj; ndrange = (length(f.itr), batch_size))
end
end
function sgradient_batch!(
::Nothing,
Y,
f,
X,
Θ,
adj,
)
if !isempty(f.itr)
kerg_batch_cpu!(Y, f.f, f.itr, X, Θ, adj)
end
end

"""
grad_batch!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix, G::AbstractMatrix)
Expand All @@ -56,7 +68,7 @@ function grad_batch!(
@lencheck bm.model.meta.nvar eachrow(X) eachrow(G)
@lencheck length(bm.model.θ) eachrow(Θ) # FIXME
_assert_batch_size(batch_size, bm.batch_size)
backend = _get_backend(bm.model)
backend = _get_backend(bm)

grad_work = _maybe_view(bm, :grad_work, X)

Expand All @@ -66,15 +78,15 @@ function grad_batch!(
_grad_batch!(backend, grad_work, bm.model.objs, X, Θ)

fill!(G, zero(eltype(G)))
compress_to_dense_batch(backend)(
_run_compress_to_dense_batch!(
backend,
G,
grad_work,
bm.model.ext.gptr,
bm.model.ext.gsparsity;
ndrange = (length(bm.model.ext.gptr) - 1, batch_size),
bm.model.ext.gsparsity,
batch_size,
)
synchronize(backend)
end

return G
end
end
32 changes: 31 additions & 1 deletion src/api/hess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ function hess_coord_batch!(
@lencheck bm.model.meta.ncon eachrow(Y)
@lencheck bm.model.meta.nnzh eachrow(H)
_assert_batch_size(batch_size, bm.batch_size)
backend = _get_backend(bm.model)
backend = _get_backend(bm)

fill!(H, zero(eltype(H)))
_obj_hess_coord_batch!(backend, H, bm.model.objs, X, Θ, obj_weight)
Expand Down Expand Up @@ -87,3 +87,33 @@ function shessian_batch!(
kerh2_batch(backend)(y1, y2, f.f, f.itr, X, Θ, adj, adj2; ndrange = (length(f.itr), batch_size))
end
end

function shessian_batch!(
backend::Nothing,
y1,
y2,
f,
X,
Θ,
adj,
adj2,
)
if !isempty(f.itr)
kerh_batch_cpu!(y1, y2, f.f, f.itr, X, Θ, adj, adj2)
end
end

function shessian_batch!(
backend::Nothing,
y1,
y2,
f,
X,
Θ,
adj::AbstractMatrix,
adj2,
)
if !isempty(f.itr)
kerh2_batch_cpu!(y1, y2, f.f, f.itr, X, Θ, adj, adj2)
end
end
20 changes: 10 additions & 10 deletions src/api/hprod.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,33 +35,33 @@ function hprod_batch!(
@lencheck length(bm.model.θ) eachrow(Θ)
@lencheck bm.model.meta.ncon eachrow(Y)
_assert_batch_size(batch_size, bm.batch_size)
backend = _get_backend(bm.model)
backend = _get_backend(bm)
ph = _get_prodhelper(bm.model)

H_batch = _maybe_view(bm, :hprod_work, X)

hess_coord_batch!(bm, X, Θ, Y, H_batch; obj_weight=obj_weight)

fill!(Hv, zero(eltype(Hv)))
kersyspmv_batch(backend)(
_run_kersyspmv_batch!(
backend,
Hv,
V,
ph.hesssparsityi,
H_batch,
ph.hessptri;
ndrange = (length(ph.hessptri) - 1, batch_size),
ph.hessptri,
batch_size,
)
synchronize(backend)

kersyspmv2_batch(backend)(

_run_kersyspmv2_batch!(
backend,
Hv,
V,
ph.hesssparsityj,
H_batch,
ph.hessptrj;
ndrange = (length(ph.hessptrj) - 1, batch_size),
ph.hessptrj,
batch_size,
)
synchronize(backend)

return Hv
end
16 changes: 15 additions & 1 deletion src/api/jac.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ function jac_coord_batch!(
@lencheck length(bm.model.θ) eachrow(Θ)
@lencheck bm.model.meta.nnzj eachrow(J)
_assert_batch_size(batch_size, bm.batch_size)
backend = _get_backend(bm.model)
backend = _get_backend(bm)

fill!(J, zero(eltype(J)))
_jac_coord_batch!(backend, J, bm.model.cons, X, Θ)
Expand Down Expand Up @@ -59,3 +59,17 @@ function sjacobian_batch!(
kerj_batch(backend)(y1, y2, f.f, f.itr, X, Θ, adj; ndrange = (length(f.itr), batch_size))
end
end

function sjacobian_batch!(
::Nothing,
y1,
y2,
f,
X,
Θ,
adj,
)
if !isempty(f.itr)
kerj_batch_cpu!(y1, y2, f.f, f.itr, X, Θ, adj)
end
end
20 changes: 6 additions & 14 deletions src/api/jprod.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,7 @@ function jprod_nln_batch!(
jac_coord_batch!(bm, X, Θ, J_batch)

fill!(Jv, zero(eltype(Jv)))
kerspmv_batch(bm.model.ext.backend)(
Jv,
V,
ph.jacsparsityi,
J_batch,
ph.jacptri;
ndrange = (length(ph.jacptri) - 1, batch_size),
)
synchronize(bm.model.ext.backend)
_run_kerspmv_batch!(bm.model.ext.backend, Jv, V, ph.jacsparsityi, J_batch, ph.jacptri, batch_size)

return Jv
end
Expand Down Expand Up @@ -99,23 +91,23 @@ function jtprod_nln_batch!(
@lencheck length(bm.model.θ) eachrow(Θ)
@lencheck bm.model.meta.ncon eachrow(V)
_assert_batch_size(batch_size, bm.batch_size)
backend = _get_backend(bm.model)
backend = _get_backend(bm)
ph = _get_prodhelper(bm.model)

J_batch = _maybe_view(bm, :jprod_work, X)

jac_coord_batch!(bm, X, Θ, J_batch)

fill!(Jtv, zero(eltype(Jtv)))
kerspmv2_batch(backend)(
_run_kerspmv2_batch!(
backend,
Jtv,
V,
ph.jacsparsityj,
J_batch,
ph.jacptrj;
ndrange = (length(ph.jacptrj) - 1, batch_size),
ph.jacptrj,
batch_size,
)
synchronize(backend)

return Jtv
end
8 changes: 7 additions & 1 deletion src/api/obj.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ function obj_batch!(
@lencheck bm.model.meta.nvar eachrow(X)
@lencheck length(bm.model.θ) eachrow(Θ)
_assert_batch_size(batch_size, bm.batch_size)
backend = _get_backend(bm.model)
backend = _get_backend(bm)

_obj_batch(backend, obj_work, bm.model.objs, X, Θ)
return vec(sum(obj_work, dims=1)) # FIXME
Expand All @@ -44,4 +44,10 @@ function _obj_batch(backend, obj_work, obj, X, Θ)
_obj_batch(backend, obj_work, obj.inner, X, Θ)
synchronize(backend)
end
function _obj_batch(::Nothing, obj_work, obj, X, Θ)
if !isempty(obj.itr)
kerf_batch_cpu!(obj_work, obj.f, obj.itr, X, Θ)
end
_obj_batch(backend, obj_work, obj.inner, X, Θ)
end
function _obj_batch(backend, obj_work, f::ExaModels.ObjectiveNull, X, Θ) end
Loading