diff --git a/Project.toml b/Project.toml index f73a6d5..965e0be 100644 --- a/Project.toml +++ b/Project.toml @@ -6,31 +6,30 @@ version = "1.0.0-DEV" [deps] ExaModels = "1037b233-b668-4ce9-9b63-f9f681f55dd2" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" JuMP = "4076af6c-e467-56ae-b986-b466b2749572" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" [extensions] BNKChainRulesCore = "ChainRulesCore" BNKJuMP = "JuMP" +BNKReactant = "Reactant" [compat] ExaModels = "0.8.3" [extras] -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -OpenCL = "08131aa3-fb12-5dee-8b74-c09406e224a2" -pocl_jll = "627d6b7a-bbe6-5189-83e7-98cc0a5aeadd" AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +OpenCL = "08131aa3-fb12-5dee-8b74-c09406e224a2" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +pocl_jll = "627d6b7a-bbe6-5189-83e7-98cc0a5aeadd" [targets] -test = [ - "Test", "LinearAlgebra", - "OpenCL", "pocl_jll", "AcceleratedKernels", - "DifferentiationInterface", "FiniteDifferences", "Zygote" -] +test = ["Test", "LinearAlgebra", "OpenCL", "pocl_jll", "AcceleratedKernels", "DifferentiationInterface", "FiniteDifferences", "Zygote"] diff --git a/ext/BNKReactant.jl b/ext/BNKReactant.jl new file mode 100644 index 0000000..c2cdf00 --- /dev/null +++ b/ext/BNKReactant.jl @@ -0,0 +1,48 @@ +module BNKReactant + +using BatchNLPKernels +using Reactant, KernelAbstractions +using ExaModels + + +function to_reactant_KA(bm::BNK.BatchModel) + RKA = Base.get_extension(Reactant, :ReactantKernelAbstractionsExt) + if !occursin("CUDA", string(bm.model.ext.backend)) + error("ExaModel must be built with CUDABackend") + end + return BNK.BatchModel( + bm.model, + bm.batch_size, + Reactant.to_rarray(bm.obj_work), + Reactant.to_rarray(bm.cons_work), + Reactant.to_rarray(bm.cons_out), + Reactant.to_rarray(bm.grad_work), + Reactant.to_rarray(bm.grad_out), + Reactant.to_rarray(bm.jprod_work), + Reactant.to_rarray(bm.hprod_work), + Reactant.to_rarray(bm.jprod_out), + Reactant.to_rarray(bm.jtprod_out), + Reactant.to_rarray(bm.hprod_out), + RKA.ReactantBackend(), + ) +end + +function to_reactant(bm::BNK.BatchModel) + return BNK.BatchModel( + bm.model, + bm.batch_size, + Reactant.to_rarray(bm.obj_work), + Reactant.to_rarray(bm.cons_work), + Reactant.to_rarray(bm.cons_out), + Reactant.to_rarray(bm.grad_work), + Reactant.to_rarray(bm.grad_out), + Reactant.to_rarray(bm.jprod_work), + Reactant.to_rarray(bm.hprod_work), + Reactant.to_rarray(bm.jprod_out), + Reactant.to_rarray(bm.jtprod_out), + Reactant.to_rarray(bm.hprod_out), + nothing, + ) +end + +end # module BNKReactant \ No newline at end of file diff --git a/src/BatchNLPKernels.jl b/src/BatchNLPKernels.jl index a18b73f..eb48a1b 100644 --- a/src/BatchNLPKernels.jl +++ b/src/BatchNLPKernels.jl @@ -1,6 +1,7 @@ module BatchNLPKernels using ExaModels +using LinearAlgebra using KernelAbstractions const ExaKA = Base.get_extension(ExaModels, :ExaModelsKernelAbstractions) @@ -8,8 +9,8 @@ const KAExtension = ExaKA.KAExtension include("batch_model.jl") -const BOI = BatchNLPKernels -export BOI, BatchModel, BatchModelConfig +const BNK = BatchNLPKernels +export BNK, BatchModel, BatchModelConfig export obj_batch!, grad_batch!, cons_nln_batch!, jac_coord_batch!, hess_coord_batch! export jprod_nln_batch!, jtprod_nln_batch!, hprod_batch! diff --git a/src/api/cons.jl b/src/api/cons.jl index 46e2002..58c5268 100644 --- a/src/api/cons.jl +++ b/src/api/cons.jl @@ -32,7 +32,7 @@ function cons_nln_batch!( @lencheck length(bm.model.θ) eachrow(Θ) @lencheck bm.model.meta.ncon eachrow(C) _assert_batch_size(batch_size, bm.batch_size) - backend = _get_backend(bm.model) + backend = _get_backend(bm) _cons_nln_batch!(backend, C, bm.model.cons, X, Θ) @@ -41,14 +41,14 @@ function cons_nln_batch!( _conaugs_batch!(backend, conbuffers_batch, bm.model.cons, X, Θ) if length(bm.model.ext.conaugptr) > 1 - compress_to_dense_batch(backend)( + _run_compress_to_dense_batch!( + backend, C, conbuffers_batch, bm.model.ext.conaugptr, - bm.model.ext.conaugsparsity; - ndrange = (length(bm.model.ext.conaugptr) - 1, batch_size), + bm.model.ext.conaugsparsity, + batch_size, ) - synchronize(backend) end return C end @@ -61,6 +61,12 @@ function _cons_nln_batch!(backend, C, con::ExaModels.Constraint, X, Θ) _cons_nln_batch!(backend, C, con.inner, X, Θ) synchronize(backend) end +function _cons_nln_batch!(::Nothing, C, con::ExaModels.Constraint, X, Θ) + if !isempty(con.itr) + kerf_batch_cpu!(C, con.f, con.itr, X, Θ) + end + _cons_nln_batch!(backend, C, con.inner, X, Θ) +end function _cons_nln_batch!(backend, C, con::ExaModels.ConstraintNull, X, Θ) end function _cons_nln_batch!(backend, C, con::ExaModels.ConstraintAug, X, Θ) _cons_nln_batch!(backend, C, con.inner, X, Θ) @@ -74,7 +80,13 @@ function _conaugs_batch!(backend, conbuffers, con::ExaModels.ConstraintAug, X, _conaugs_batch!(backend, conbuffers, con.inner, X, Θ) synchronize(backend) end +function _conaugs_batch!(::Nothing, conbuffers, con::ExaModels.ConstraintAug, X, Θ) + if !isempty(con.itr) + kerf2_batch_cpu!(conbuffers, con.f, con.itr, X, Θ, con.oa) + end + _conaugs_batch!(backend, conbuffers, con.inner, X, Θ) +end function _conaugs_batch!(backend, conbuffers, con::ExaModels.Constraint, X, Θ) _conaugs_batch!(backend, conbuffers, con.inner, X, Θ) end -function _conaugs_batch!(backend, conbuffers, con::ExaModels.ConstraintNull, X, Θ) end +function _conaugs_batch!(backend, conbuffers, con::ExaModels.ConstraintNull, X, Θ) end \ No newline at end of file diff --git a/src/api/grad.jl b/src/api/grad.jl index 0021eda..58f3beb 100644 --- a/src/api/grad.jl +++ b/src/api/grad.jl @@ -39,6 +39,18 @@ function sgradient_batch!( kerg_batch(backend)(Y, f.f, f.itr, X, Θ, adj; ndrange = (length(f.itr), batch_size)) end end +function sgradient_batch!( + ::Nothing, + Y, + f, + X, + Θ, + adj, +) + if !isempty(f.itr) + kerg_batch_cpu!(Y, f.f, f.itr, X, Θ, adj) + end +end """ grad_batch!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix, G::AbstractMatrix) @@ -56,7 +68,7 @@ function grad_batch!( @lencheck bm.model.meta.nvar eachrow(X) eachrow(G) @lencheck length(bm.model.θ) eachrow(Θ) # FIXME _assert_batch_size(batch_size, bm.batch_size) - backend = _get_backend(bm.model) + backend = _get_backend(bm) grad_work = _maybe_view(bm, :grad_work, X) @@ -66,15 +78,15 @@ function grad_batch!( _grad_batch!(backend, grad_work, bm.model.objs, X, Θ) fill!(G, zero(eltype(G))) - compress_to_dense_batch(backend)( + _run_compress_to_dense_batch!( + backend, G, grad_work, bm.model.ext.gptr, - bm.model.ext.gsparsity; - ndrange = (length(bm.model.ext.gptr) - 1, batch_size), + bm.model.ext.gsparsity, + batch_size, ) - synchronize(backend) end return G -end +end \ No newline at end of file diff --git a/src/api/hess.jl b/src/api/hess.jl index 8b52565..04a9b1a 100644 --- a/src/api/hess.jl +++ b/src/api/hess.jl @@ -34,7 +34,7 @@ function hess_coord_batch!( @lencheck bm.model.meta.ncon eachrow(Y) @lencheck bm.model.meta.nnzh eachrow(H) _assert_batch_size(batch_size, bm.batch_size) - backend = _get_backend(bm.model) + backend = _get_backend(bm) fill!(H, zero(eltype(H))) _obj_hess_coord_batch!(backend, H, bm.model.objs, X, Θ, obj_weight) @@ -87,3 +87,33 @@ function shessian_batch!( kerh2_batch(backend)(y1, y2, f.f, f.itr, X, Θ, adj, adj2; ndrange = (length(f.itr), batch_size)) end end + +function shessian_batch!( + backend::Nothing, + y1, + y2, + f, + X, + Θ, + adj, + adj2, +) + if !isempty(f.itr) + kerh_batch_cpu!(y1, y2, f.f, f.itr, X, Θ, adj, adj2) + end +end + +function shessian_batch!( + backend::Nothing, + y1, + y2, + f, + X, + Θ, + adj::AbstractMatrix, + adj2, +) + if !isempty(f.itr) + kerh2_batch_cpu!(y1, y2, f.f, f.itr, X, Θ, adj, adj2) + end +end diff --git a/src/api/hprod.jl b/src/api/hprod.jl index 36a9d67..bae9356 100644 --- a/src/api/hprod.jl +++ b/src/api/hprod.jl @@ -35,7 +35,7 @@ function hprod_batch!( @lencheck length(bm.model.θ) eachrow(Θ) @lencheck bm.model.meta.ncon eachrow(Y) _assert_batch_size(batch_size, bm.batch_size) - backend = _get_backend(bm.model) + backend = _get_backend(bm) ph = _get_prodhelper(bm.model) H_batch = _maybe_view(bm, :hprod_work, X) @@ -43,25 +43,25 @@ function hprod_batch!( hess_coord_batch!(bm, X, Θ, Y, H_batch; obj_weight=obj_weight) fill!(Hv, zero(eltype(Hv))) - kersyspmv_batch(backend)( + _run_kersyspmv_batch!( + backend, Hv, V, ph.hesssparsityi, H_batch, - ph.hessptri; - ndrange = (length(ph.hessptri) - 1, batch_size), + ph.hessptri, + batch_size, ) - synchronize(backend) - - kersyspmv2_batch(backend)( + + _run_kersyspmv2_batch!( + backend, Hv, V, ph.hesssparsityj, H_batch, - ph.hessptrj; - ndrange = (length(ph.hessptrj) - 1, batch_size), + ph.hessptrj, + batch_size, ) - synchronize(backend) return Hv end diff --git a/src/api/jac.jl b/src/api/jac.jl index 72bffc9..accc824 100644 --- a/src/api/jac.jl +++ b/src/api/jac.jl @@ -31,7 +31,7 @@ function jac_coord_batch!( @lencheck length(bm.model.θ) eachrow(Θ) @lencheck bm.model.meta.nnzj eachrow(J) _assert_batch_size(batch_size, bm.batch_size) - backend = _get_backend(bm.model) + backend = _get_backend(bm) fill!(J, zero(eltype(J))) _jac_coord_batch!(backend, J, bm.model.cons, X, Θ) @@ -59,3 +59,17 @@ function sjacobian_batch!( kerj_batch(backend)(y1, y2, f.f, f.itr, X, Θ, adj; ndrange = (length(f.itr), batch_size)) end end + +function sjacobian_batch!( + ::Nothing, + y1, + y2, + f, + X, + Θ, + adj, +) + if !isempty(f.itr) + kerj_batch_cpu!(y1, y2, f.f, f.itr, X, Θ, adj) + end +end diff --git a/src/api/jprod.jl b/src/api/jprod.jl index 60dc0de..01c8381 100644 --- a/src/api/jprod.jl +++ b/src/api/jprod.jl @@ -45,15 +45,7 @@ function jprod_nln_batch!( jac_coord_batch!(bm, X, Θ, J_batch) fill!(Jv, zero(eltype(Jv))) - kerspmv_batch(bm.model.ext.backend)( - Jv, - V, - ph.jacsparsityi, - J_batch, - ph.jacptri; - ndrange = (length(ph.jacptri) - 1, batch_size), - ) - synchronize(bm.model.ext.backend) + _run_kerspmv_batch!(bm.model.ext.backend, Jv, V, ph.jacsparsityi, J_batch, ph.jacptri, batch_size) return Jv end @@ -99,7 +91,7 @@ function jtprod_nln_batch!( @lencheck length(bm.model.θ) eachrow(Θ) @lencheck bm.model.meta.ncon eachrow(V) _assert_batch_size(batch_size, bm.batch_size) - backend = _get_backend(bm.model) + backend = _get_backend(bm) ph = _get_prodhelper(bm.model) J_batch = _maybe_view(bm, :jprod_work, X) @@ -107,15 +99,15 @@ function jtprod_nln_batch!( jac_coord_batch!(bm, X, Θ, J_batch) fill!(Jtv, zero(eltype(Jtv))) - kerspmv2_batch(backend)( + _run_kerspmv2_batch!( + backend, Jtv, V, ph.jacsparsityj, J_batch, - ph.jacptrj; - ndrange = (length(ph.jacptrj) - 1, batch_size), + ph.jacptrj, + batch_size, ) - synchronize(backend) return Jtv end diff --git a/src/api/obj.jl b/src/api/obj.jl index 1b2ce21..61f8dc4 100644 --- a/src/api/obj.jl +++ b/src/api/obj.jl @@ -30,7 +30,7 @@ function obj_batch!( @lencheck bm.model.meta.nvar eachrow(X) @lencheck length(bm.model.θ) eachrow(Θ) _assert_batch_size(batch_size, bm.batch_size) - backend = _get_backend(bm.model) + backend = _get_backend(bm) _obj_batch(backend, obj_work, bm.model.objs, X, Θ) return vec(sum(obj_work, dims=1)) # FIXME @@ -44,4 +44,10 @@ function _obj_batch(backend, obj_work, obj, X, Θ) _obj_batch(backend, obj_work, obj.inner, X, Θ) synchronize(backend) end +function _obj_batch(::Nothing, obj_work, obj, X, Θ) + if !isempty(obj.itr) + kerf_batch_cpu!(obj_work, obj.f, obj.itr, X, Θ) + end + _obj_batch(backend, obj_work, obj.inner, X, Θ) +end function _obj_batch(backend, obj_work, f::ExaModels.ObjectiveNull, X, Θ) end diff --git a/src/batch_model.jl b/src/batch_model.jl index dccb6b4..cda4b5c 100644 --- a/src/batch_model.jl +++ b/src/batch_model.jl @@ -75,8 +75,9 @@ Allows efficient evaluation of multiple points simultaneously. - `jprod_out::MT`: Batch jacobian-vector product buffer (ncon × batch_size), (0 × batch_size) if not allocated - `jtprod_out::MT`: Batch jacobian transpose-vector product buffer (nvar × batch_size), (0 × batch_size) if not allocated - `hprod_out::MT`: Batch hessian-vector product buffer (nvar × batch_size), (0 × batch_size) if not allocated +- `backend::B`: The backend used to evaluate the model. """ -struct BatchModel{MT,E} +struct BatchModel{MT,E,B} model::E batch_size::Int @@ -90,6 +91,8 @@ struct BatchModel{MT,E} jprod_out::MT jtprod_out::MT hprod_out::MT + + backend::B end """ @@ -147,6 +150,7 @@ function BatchModel(model::ExaModels.ExaModel{T,VT,E,O,C}, batch_size::Int; conf jprod_out, jtprod_out, hprod_out, + _get_backend(model), ) end diff --git a/src/kernels.jl b/src/kernels.jl index 743f75e..3b6698c 100644 --- a/src/kernels.jl +++ b/src/kernels.jl @@ -2,11 +2,36 @@ I, batch_idx = @index(Global, NTuple) @inbounds Y[ExaModels.offset0(f, itr, I), batch_idx] = f.f(itr[I], view(X, :, batch_idx), view(Θ, :, batch_idx)) end +function kerf_batch_cpu!(Y::AbstractMatrix, f, itr, X::AbstractMatrix, Θ::AbstractMatrix) + @assert size(X, 2) == size(Θ, 2) == size(Y, 2) "Batch dimension mismatch" + @inbounds for batch_idx in axes(X, 2) + x_batch = view(X, :, batch_idx) + θ_batch = view(Θ, :, batch_idx) + @simd for I in eachindex(itr) + row = ExaModels.offset0(f, itr, I) + Y[row, batch_idx] = f.f(itr[I], x_batch, θ_batch) + end + end + return Y +end + @kernel function kerf2_batch(Y, @Const(f), @Const(itr), @Const(X), @Const(Θ), @Const(oa)) I, batch_idx = @index(Global, NTuple) @inbounds Y[oa+I, batch_idx] = f.f(itr[I], view(X, :, batch_idx), view(Θ, :, batch_idx)) end +function kerf2_batch_cpu!(Y::AbstractMatrix, f, itr, X::AbstractMatrix, Θ::AbstractMatrix, oa::Integer) + @assert size(X, 2) == size(Θ, 2) == size(Y, 2) "Batch dimension mismatch" + @inbounds for batch_idx in axes(X, 2) + x_batch = view(X, :, batch_idx) + θ_batch = view(Θ, :, batch_idx) + @simd for I in eachindex(itr) + Y[oa + I, batch_idx] = f.f(itr[I], x_batch, θ_batch) + end + end + return Y +end + @kernel function kerg_batch(Y, @Const(f), @Const(itr), @Const(X), @Const(Θ), @Const(adj)) I, batch_idx = @index(Global, NTuple) @@ -19,6 +44,26 @@ end adj, ) end +function kerg_batch_cpu!(Y::AbstractMatrix, f, itr, X::AbstractMatrix, Θ::AbstractMatrix, adj) + @assert size(X, 2) == size(Θ, 2) == size(Y, 2) "Batch dimension mismatch" + @inbounds for batch_idx in axes(X, 2) + x_batch = view(X, :, batch_idx) + θ_batch = view(Θ, :, batch_idx) + y_batch = view(Y, :, batch_idx) + @simd for I in eachindex(itr) + ExaModels.grpass( + f.f(itr[I], ExaModels.AdjointNodeSource(x_batch), θ_batch), + f.comp1, + y_batch, + ExaModels.offset1(f, I), + 0, + adj, + ) + end + end + return Y +end + @kernel function kerj_batch(Y1, Y2, @Const(f), @Const(itr), @Const(X), @Const(Θ), @Const(adj)) I, batch_idx = @index(Global, NTuple) @@ -33,6 +78,31 @@ end adj, ) end +function kerj_batch_cpu!(Y1, Y2, f, itr, X::AbstractMatrix, Θ::AbstractMatrix, adj) + @assert size(X, 2) == size(Θ, 2) "Batch dimension mismatch" + nbatch = size(X, 2) + @inbounds for batch_idx in 1:nbatch + x_batch = view(X, :, batch_idx) + θ_batch = view(Θ, :, batch_idx) + y1_view = isnothing(Y1) ? nothing : view(Y1, :, batch_idx) + y2_view = isnothing(Y2) ? nothing : view(Y2, :, batch_idx) + @simd for I in eachindex(itr) + ExaModels.jrpass( + f.f(itr[I], ExaModels.AdjointNodeSource(x_batch), θ_batch), + f.comp1, + ExaModels.offset0(f, itr, I), + y1_view, + y2_view, + ExaModels.offset1(f, I), + 0, + adj, + ) + end + end + return nothing +end + + @kernel function kerh_batch(Y1, Y2, @Const(f), @Const(itr), @Const(X), @Const(Θ), @Const(adj1), @Const(adj2)) I, batch_idx = @index(Global, NTuple) @@ -47,6 +117,30 @@ end adj2, ) end +function kerh_batch_cpu!(Y1, Y2, f, itr, X::AbstractMatrix, Θ::AbstractMatrix, adj1, adj2) + @assert size(X, 2) == size(Θ, 2) "Batch dimension mismatch" + nbatch = size(X, 2) + @inbounds for batch_idx in 1:nbatch + x_batch = view(X, :, batch_idx) + θ_batch = view(Θ, :, batch_idx) + y1_view = isnothing(Y1) ? nothing : view(Y1, :, batch_idx) + y2_view = isnothing(Y2) ? nothing : view(Y2, :, batch_idx) + @simd for I in eachindex(itr) + ExaModels.hrpass0( + f.f(itr[I], ExaModels.SecondAdjointNodeSource(x_batch), θ_batch), + f.comp2, + y1_view, + y2_view, + ExaModels.offset2(f, I), + 0, + adj1, + adj2, + ) + end + end + return nothing +end + @kernel function kerh2_batch(Y1, Y2, @Const(f), @Const(itr), @Const(X), @Const(Θ), @Const(adjs1), @Const(adj2)) I, batch_idx = @index(Global, NTuple) @@ -61,6 +155,30 @@ end adj2, ) end +function kerh2_batch_cpu!(Y1, Y2, f, itr, X::AbstractMatrix, Θ::AbstractMatrix, adjs1, adj2) + @assert size(X, 2) == size(Θ, 2) == size(adjs1, 2) "Batch dimension mismatch" + nbatch = size(X, 2) + @inbounds for batch_idx in 1:nbatch + x_batch = view(X, :, batch_idx) + θ_batch = view(Θ, :, batch_idx) + y1_view = isnothing(Y1) ? nothing : view(Y1, :, batch_idx) + y2_view = isnothing(Y2) ? nothing : view(Y2, :, batch_idx) + @simd for I in eachindex(itr) + ExaModels.hrpass0( + f.f(itr[I], ExaModels.SecondAdjointNodeSource(x_batch), θ_batch), + f.comp2, + y1_view, + y2_view, + ExaModels.offset2(f, I), + 0, + adjs1[ExaModels.offset0(f, itr, I), batch_idx], + adj2, + ) + end + end + return nothing +end + @kernel function compress_to_dense_batch(Y, @Const(Y0), @Const(ptr), @Const(sparsity)) I, batch_idx = @index(Global, NTuple) @@ -69,6 +187,29 @@ end Y[k, batch_idx] += Y0[l, batch_idx] end end +function compress_to_dense_batch_cpu!(Y::AbstractMatrix, Y0::AbstractMatrix, ptr, sparsity) + @assert size(Y, 2) == size(Y0, 2) "Batch dimension mismatch" + nbatch = size(Y, 2) + @inbounds for batch_idx in 1:nbatch + @simd for I in 1:(length(ptr) - 1) + @simd for j in ptr[I]:(ptr[I + 1] - 1) + k, l = sparsity[j] + Y[k, batch_idx] += Y0[l, batch_idx] + end + end + end + return Y +end +@inline function _run_compress_to_dense_batch!(backend, Y, Y0, ptr, sparsity, batch_size) + compress_to_dense_batch(backend)(Y, Y0, ptr, sparsity; ndrange = (length(ptr) - 1, batch_size)) + synchronize(backend) + return Y +end +@inline function _run_compress_to_dense_batch!(::Nothing, Y, Y0, ptr, sparsity, batch_size) + compress_to_dense_batch_cpu!(Y, Y0, ptr, sparsity) + return Y +end + @kernel function kerspmv_batch(Y, @Const(X), @Const(coord), @Const(V), @Const(ptr)) idx, batch_idx = @index(Global, NTuple) @@ -77,6 +218,30 @@ end Y[i, batch_idx] += V[ind, batch_idx] * X[j, batch_idx] end end +function kerspmv_batch_cpu!(Y::AbstractMatrix, X::AbstractMatrix, coord, V::AbstractMatrix, ptr) + @assert size(Y, 2) == size(X, 2) == size(V, 2) "Batch dimension mismatch" + nbatch = size(Y, 2) + nidx = length(ptr) - 1 + @inbounds for batch_idx in 1:nbatch + for idx in 1:nidx + @simd for l in ptr[idx]:(ptr[idx + 1] - 1) + ((i, j), ind) = coord[l] + Y[i, batch_idx] += V[ind, batch_idx] * X[j, batch_idx] + end + end + end + return Y +end +@inline function _run_kerspmv_batch!(backend, Y, X, coord, V, ptr, batch_size) + kerspmv_batch(backend)(Y, X, coord, V, ptr; ndrange = (length(ptr) - 1, batch_size)) + synchronize(backend) + return Y +end +@inline function _run_kerspmv_batch!(::Nothing, Y, X, coord, V, ptr, batch_size) + kerspmv_batch_cpu!(Y, X, coord, V, ptr) + return Y +end + @kernel function kerspmv2_batch(Y, @Const(X), @Const(coord), @Const(V), @Const(ptr)) idx, batch_idx = @index(Global, NTuple) @@ -85,6 +250,30 @@ end Y[j, batch_idx] += V[ind, batch_idx] * X[i, batch_idx] end end +function kerspmv2_batch_cpu!(Y::AbstractMatrix, X::AbstractMatrix, coord, V::AbstractMatrix, ptr) + @assert size(Y, 2) == size(X, 2) == size(V, 2) "Batch dimension mismatch" + nbatch = size(Y, 2) + nidx = length(ptr) - 1 + @inbounds @simd for batch_idx in 1:nbatch + for idx in 1:nidx + @simd for l in ptr[idx]:(ptr[idx + 1] - 1) + ((i, j), ind) = coord[l] + Y[j, batch_idx] += V[ind, batch_idx] * X[i, batch_idx] + end + end + end + return Y +end +@inline function _run_kerspmv2_batch!(backend, Y, X, coord, V, ptr, batch_size) + kerspmv2_batch(backend)(Y, X, coord, V, ptr; ndrange = (length(ptr) - 1, batch_size)) + synchronize(backend) + return Y +end +@inline function _run_kerspmv2_batch!(::Nothing, Y, X, coord, V, ptr, batch_size) + kerspmv2_batch_cpu!(Y, X, coord, V, ptr) + return Y +end + @kernel function kersyspmv_batch(Y, @Const(X), @Const(coord), @Const(V), @Const(ptr)) idx, batch_idx = @index(Global, NTuple) @@ -93,6 +282,30 @@ end Y[i, batch_idx] += V[ind, batch_idx] * X[j, batch_idx] end end +function kersyspmv_batch_cpu!(Y::AbstractMatrix, X::AbstractMatrix, coord, V::AbstractMatrix, ptr) + @assert size(Y, 2) == size(X, 2) == size(V, 2) "Batch dimension mismatch" + nbatch = size(Y, 2) + nidx = length(ptr) - 1 + @inbounds @simd for batch_idx in 1:nbatch + for idx in 1:nidx + @simd for l in ptr[idx]:(ptr[idx + 1] - 1) + ((i, j), ind) = coord[l] + Y[i, batch_idx] += V[ind, batch_idx] * X[j, batch_idx] + end + end + end + return Y +end +@inline function _run_kersyspmv_batch!(backend, Y, X, coord, V, ptr, batch_size) + kersyspmv_batch(backend)(Y, X, coord, V, ptr; ndrange = (length(ptr) - 1, batch_size)) + synchronize(backend) + return Y +end +@inline function _run_kersyspmv_batch!(::Nothing, Y, X, coord, V, ptr, batch_size) + kersyspmv_batch_cpu!(Y, X, coord, V, ptr) + return Y +end + @kernel function kersyspmv2_batch(Y, @Const(X), @Const(coord), @Const(V), @Const(ptr)) idx, batch_idx = @index(Global, NTuple) @@ -102,4 +315,29 @@ end Y[j, batch_idx] += V[ind, batch_idx] * X[i, batch_idx] end end +end +function kersyspmv2_batch_cpu!(Y::AbstractMatrix, X::AbstractMatrix, coord, V::AbstractMatrix, ptr) + @assert size(Y, 2) == size(X, 2) == size(V, 2) "Batch dimension mismatch" + nbatch = size(Y, 2) + nidx = length(ptr) - 1 + @inbounds for batch_idx in 1:nbatch + for idx in 1:nidx + @simd for l in ptr[idx]:(ptr[idx + 1] - 1) + ((i, j), ind) = coord[l] + if i != j + Y[j, batch_idx] += V[ind, batch_idx] * X[i, batch_idx] + end + end + end + end + return Y +end +@inline function _run_kersyspmv2_batch!(backend, Y, X, coord, V, ptr, batch_size) + kersyspmv2_batch(backend)(Y, X, coord, V, ptr; ndrange = (length(ptr) - 1, batch_size)) + synchronize(backend) + return Y +end +@inline function _run_kersyspmv2_batch!(::Nothing, Y, X, coord, V, ptr, batch_size) + kersyspmv2_batch_cpu!(Y, X, coord, V, ptr) + return Y end \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl index bfca20f..04b85f8 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -31,7 +31,7 @@ end _get_prodhelper(bm::BatchModel) = _get_prodhelper(bm.model) _get_prodhelper(model::ExaModels.ExaModel) = model.ext.prodhelper -_get_backend(bm::BatchModel) = _get_backend(bm.model) +_get_backend(bm::BatchModel) = bm.backend _get_backend(model::ExaModels.ExaModel) = model.ext.backend function _check_buffer_available(buffer, buffer_name::Symbol) diff --git a/test/api.jl b/test/api.jl index 230a019..0d33d7b 100644 --- a/test/api.jl +++ b/test/api.jl @@ -1,7 +1,7 @@ function test_batch_model(model::ExaModel, batch_size::Int; atol::Float64=1e-10, rtol::Float64=1e-10) - bm = BOI.BatchModel(model, batch_size, config=BOI.BatchModelConfig(:full)) + bm = BNK.BatchModel(model, batch_size, config=BNK.BatchModelConfig(:full)) nvar = model.meta.nvar ncon = model.meta.ncon @@ -12,7 +12,7 @@ function test_batch_model(model::ExaModel, batch_size::Int; @testset "Model Info: $(nvar) vars, $(ncon) cons, $(nθ) params" begin @testset "Objective" begin - obj_vals = BOI.obj_batch!(bm, X, Θ) + obj_vals = BNK.obj_batch!(bm, X, Θ) @test length(obj_vals) == batch_size @test all(isfinite, obj_vals) for i in 1:batch_size @@ -23,7 +23,7 @@ function test_batch_model(model::ExaModel, batch_size::Int; @testset "Constraint" begin if ncon > 0 - cons_vals = BOI.cons_nln_batch!(bm, X, Θ) + cons_vals = BNK.cons_nln_batch!(bm, X, Θ) @test size(cons_vals) == (ncon, batch_size) @test all(isfinite, cons_vals) for i in 1:batch_size @@ -36,7 +36,7 @@ function test_batch_model(model::ExaModel, batch_size::Int; end @testset "Gradient" begin - grad_vals = BOI.grad_batch!(bm, X, Θ) + grad_vals = BNK.grad_batch!(bm, X, Θ) @test size(grad_vals) == (nvar, batch_size) @test all(isfinite, grad_vals) for i in 1:batch_size @@ -50,7 +50,7 @@ function test_batch_model(model::ExaModel, batch_size::Int; @testset "Jacobian-Vector Product" begin if ncon > 0 V = OpenCL.randn(nvar, batch_size) - jprod_vals = BOI.jprod_nln_batch!(bm, X, Θ, V) + jprod_vals = BNK.jprod_nln_batch!(bm, X, Θ, V) @test size(jprod_vals) == (ncon, batch_size) @test all(isfinite, jprod_vals) for i in 1:batch_size @@ -65,7 +65,7 @@ function test_batch_model(model::ExaModel, batch_size::Int; @testset "Jacobian-Transpose-Vector Product" begin if ncon > 0 V = OpenCL.randn(ncon, batch_size) - jtprod_vals = BOI.jtprod_nln_batch!(bm, X, Θ, V) + jtprod_vals = BNK.jtprod_nln_batch!(bm, X, Θ, V) @test size(jtprod_vals) == (nvar, batch_size) @test all(isfinite, jtprod_vals) for i in 1:batch_size @@ -81,7 +81,7 @@ function test_batch_model(model::ExaModel, batch_size::Int; V = OpenCL.randn(nvar, batch_size) if ncon > 0 Y = OpenCL.randn(ncon, batch_size) - hprod_vals = BOI.hprod_batch!(bm, X, Θ, Y, V) + hprod_vals = BNK.hprod_batch!(bm, X, Θ, Y, V) @test size(hprod_vals) == (nvar, batch_size) @test all(isfinite, hprod_vals) for i in 1:batch_size @@ -92,7 +92,7 @@ function test_batch_model(model::ExaModel, batch_size::Int; end else Y = OpenCL.zeros(ncon, batch_size) - hprod_vals = BOI.hprod_batch!(bm, X, Θ, Y, V) + hprod_vals = BNK.hprod_batch!(bm, X, Θ, Y, V) @test size(hprod_vals) == (nvar, batch_size) @test all(isfinite, hprod_vals) for i in 1:batch_size @@ -106,67 +106,67 @@ function test_batch_model(model::ExaModel, batch_size::Int; @testset "Batch Size Validation" begin X_large = OpenCL.randn(nvar, batch_size + 1) - @test_throws AssertionError BOI.obj_batch!(bm, X_large) + @test_throws AssertionError BNK.obj_batch!(bm, X_large) if ncon > 0 - @test_throws AssertionError BOI.cons_nln_batch!(bm, X_large) + @test_throws AssertionError BNK.cons_nln_batch!(bm, X_large) end - @test_throws AssertionError BOI.grad_batch!(bm, X_large) + @test_throws AssertionError BNK.grad_batch!(bm, X_large) if ncon > 0 V_jprod = OpenCL.randn(nvar, batch_size + 1) - @test_throws AssertionError BOI.jprod_nln_batch!(bm, X_large, V_jprod) + @test_throws AssertionError BNK.jprod_nln_batch!(bm, X_large, V_jprod) V_jtprod = OpenCL.randn(ncon, batch_size + 1) - @test_throws AssertionError BOI.jtprod_nln_batch!(bm, X_large, V_jtprod) + @test_throws AssertionError BNK.jtprod_nln_batch!(bm, X_large, V_jtprod) end V_hprod = OpenCL.randn(nvar, batch_size + 1) if ncon > 0 Y_large = OpenCL.randn(ncon, batch_size + 1) - @test_throws AssertionError BOI.hprod_batch!(bm, X_large, Y_large, V_hprod) + @test_throws AssertionError BNK.hprod_batch!(bm, X_large, Y_large, V_hprod) else Y_large = OpenCL.zeros(ncon, batch_size + 1) - @test_throws AssertionError BOI.hprod_batch!(bm, X_large, Y_large, V_hprod) + @test_throws AssertionError BNK.hprod_batch!(bm, X_large, Y_large, V_hprod) end end @testset "Dimension Validation" begin X_wrong = OpenCL.randn(nvar + 1, batch_size) - @test_throws DimensionMismatch BOI.obj_batch!(bm, X_wrong) + @test_throws DimensionMismatch BNK.obj_batch!(bm, X_wrong) if nθ > 0 Θ_wrong = OpenCL.randn(nθ + 1, batch_size) - @test_throws DimensionMismatch BOI.obj_batch!(bm, X, Θ_wrong) + @test_throws DimensionMismatch BNK.obj_batch!(bm, X, Θ_wrong) end if ncon > 0 V_jprod_wrong = OpenCL.randn(nvar + 1, batch_size) - @test_throws DimensionMismatch BOI.jprod_nln_batch!(bm, X, V_jprod_wrong) + @test_throws DimensionMismatch BNK.jprod_nln_batch!(bm, X, V_jprod_wrong) V_jtprod_wrong = OpenCL.randn(ncon + 1, batch_size) - @test_throws DimensionMismatch BOI.jtprod_nln_batch!(bm, X, V_jtprod_wrong) + @test_throws DimensionMismatch BNK.jtprod_nln_batch!(bm, X, V_jtprod_wrong) Y_wrong = OpenCL.randn(ncon + 1, batch_size) V_hprod = OpenCL.randn(nvar, batch_size) - @test_throws DimensionMismatch BOI.hprod_batch!(bm, X, Y_wrong, V_hprod) + @test_throws DimensionMismatch BNK.hprod_batch!(bm, X, Y_wrong, V_hprod) end V_hprod_wrong = OpenCL.randn(nvar + 1, batch_size) if ncon > 0 Y = OpenCL.randn(ncon, batch_size) - @test_throws DimensionMismatch BOI.hprod_batch!(bm, X, Y, V_hprod_wrong) + @test_throws DimensionMismatch BNK.hprod_batch!(bm, X, Y, V_hprod_wrong) else Y = OpenCL.zeros(ncon, batch_size) - @test_throws DimensionMismatch BOI.hprod_batch!(bm, X, Y, V_hprod_wrong) + @test_throws DimensionMismatch BNK.hprod_batch!(bm, X, Y, V_hprod_wrong) end end end end @testset "API" begin - models, names = create_luksan_models() + models, names = create_luksan_models(OpenCLBackend()) for (name, model) in zip(names, models) @testset "$name Model" begin diff --git a/test/config.jl b/test/config.jl index 163c0bc..16fbaa5 100644 --- a/test/config.jl +++ b/test/config.jl @@ -9,44 +9,44 @@ Θ = OpenCL.randn(nθ, batch_size) @testset "Minimal" begin - bm_minimal = BOI.BatchModel(model, batch_size, config=BOI.BatchModelConfig(:minimal)) - @test_throws ArgumentError BOI.grad_batch!(bm_minimal, X, Θ) - @test_throws ArgumentError BOI.jac_coord_batch!(bm_minimal, X, Θ) + bm_minimal = BNK.BatchModel(model, batch_size, config=BNK.BatchModelConfig(:minimal)) + @test_throws ArgumentError BNK.grad_batch!(bm_minimal, X, Θ) + @test_throws ArgumentError BNK.jac_coord_batch!(bm_minimal, X, Θ) if ncon > 0 Y = OpenCL.randn(ncon, batch_size) - @test_throws ArgumentError BOI.hess_coord_batch!(bm_minimal, X, Θ, Y) + @test_throws ArgumentError BNK.hess_coord_batch!(bm_minimal, X, Θ, Y) end end @testset "Mat-vec" begin model_with_prod = create_luksan_vlcek_model(5; M = 1, prod = true) - bm_partial = BOI.BatchModel(model_with_prod, batch_size, config=BOI.BatchModelConfig(obj=true, cons=true, grad=true, jac=true, hess=true, jprod=false, jtprod=false, hprod=false)) + bm_partial = BNK.BatchModel(model_with_prod, batch_size, config=BNK.BatchModelConfig(obj=true, cons=true, grad=true, jac=true, hess=true, jprod=false, jtprod=false, hprod=false)) if ncon > 0 V = OpenCL.randn(nvar, batch_size) - @test_throws ArgumentError BOI.jprod_nln_batch!(bm_partial, X, Θ, V) + @test_throws ArgumentError BNK.jprod_nln_batch!(bm_partial, X, Θ, V) V_t = OpenCL.randn(ncon, batch_size) - @test_throws ArgumentError BOI.jtprod_nln_batch!(bm_partial, X, Θ, V_t) + @test_throws ArgumentError BNK.jtprod_nln_batch!(bm_partial, X, Θ, V_t) end V_h = OpenCL.randn(nvar, batch_size) if ncon > 0 Y = OpenCL.randn(ncon, batch_size) - @test_throws ArgumentError BOI.hprod_batch!(bm_partial, X, Θ, Y, V_h) + @test_throws ArgumentError BNK.hprod_batch!(bm_partial, X, Θ, Y, V_h) else Y = OpenCL.zeros(ncon, batch_size) - @test_throws ArgumentError BOI.hprod_batch!(bm_partial, X, Θ, Y, V_h) + @test_throws ArgumentError BNK.hprod_batch!(bm_partial, X, Θ, Y, V_h) end end @testset "Output" begin - bm_no_gradout = BOI.BatchModel(model, batch_size, config=BOI.BatchModelConfig(obj=true, cons=true, grad=false, jac=false, hess=false, jprod=false, jtprod=false, hprod=false)) + bm_no_gradout = BNK.BatchModel(model, batch_size, config=BNK.BatchModelConfig(obj=true, cons=true, grad=false, jac=false, hess=false, jprod=false, jtprod=false, hprod=false)) G = OpenCL.randn(nvar, batch_size) - @test_throws ArgumentError BOI.grad_batch!(bm_no_gradout, X, Θ, G) - @test_throws ArgumentError BOI.grad_batch!(bm_no_gradout, X, Θ) + @test_throws ArgumentError BNK.grad_batch!(bm_no_gradout, X, Θ, G) + @test_throws ArgumentError BNK.grad_batch!(bm_no_gradout, X, Θ) - bm_no_cons = BOI.BatchModel(model, batch_size, config=BOI.BatchModelConfig(obj=true, cons=false, grad=false, jac=false, hess=false, jprod=false, jtprod=false, hprod=false)) - @test_throws ArgumentError BOI.cons_nln_batch!(bm_no_cons, X, Θ) + bm_no_cons = BNK.BatchModel(model, batch_size, config=BNK.BatchModelConfig(obj=true, cons=false, grad=false, jac=false, hess=false, jprod=false, jtprod=false, hprod=false)) + @test_throws ArgumentError BNK.cons_nln_batch!(bm_no_cons, X, Θ) end end diff --git a/test/test_diff.jl b/test/test_diff.jl index 0539b96..0d7b9a4 100644 --- a/test/test_diff.jl +++ b/test/test_diff.jl @@ -6,7 +6,7 @@ import FiniteDifferences function test_diff_gpu(model::ExaModel, batch_size::Int) - bm = BOI.BatchModel(model, batch_size, config=BOI.BatchModelConfig(:full)) + bm = BNK.BatchModel(model, batch_size, config=BNK.BatchModelConfig(:full)) nvar = model.meta.nvar ncon = model.meta.ncon @@ -19,14 +19,14 @@ function test_diff_gpu(model::ExaModel, batch_size::Int) Θ_gpu = CLArray(Θ_cpu) @testset "obj_batch! CLArray" begin - y = BOI.obj_batch!(bm, X_gpu, Θ_gpu) + y = BNK.obj_batch!(bm, X_gpu, Θ_gpu) @test y isa CLArray @test size(y) == (batch_size,) function f_gpu(params) X = params[1:nvar, :] Θ = params[nvar+1:end, :] - return sum(BOI.obj_batch!(bm, X, Θ)) + return sum(BNK.obj_batch!(bm, X, Θ)) end params = vcat(X_gpu, Θ_gpu) @@ -38,14 +38,14 @@ function test_diff_gpu(model::ExaModel, batch_size::Int) ncon == 0 && return @testset "cons_nln_batch! CLArray" begin - y = BOI.cons_nln_batch!(bm, X_gpu, Θ_gpu) + y = BNK.cons_nln_batch!(bm, X_gpu, Θ_gpu) @test y isa CLArray @test size(y) == (ncon, batch_size) function f_gpu(params) X = params[1:nvar, :] Θ = params[nvar+1:end, :] - return sum(BOI.cons_nln_batch!(bm, X, Θ)) + return sum(BNK.cons_nln_batch!(bm, X, Θ)) end params = vcat(X_gpu, Θ_gpu) @@ -56,7 +56,7 @@ function test_diff_gpu(model::ExaModel, batch_size::Int) end function test_diff_cpu(model::ExaModel, batch_size::Int) - bm = BOI.BatchModel(model, batch_size, config=BOI.BatchModelConfig(:full)) + bm = BNK.BatchModel(model, batch_size, config=BNK.BatchModelConfig(:full)) nvar = model.meta.nvar ncon = model.meta.ncon @@ -66,13 +66,13 @@ function test_diff_cpu(model::ExaModel, batch_size::Int) Θ_cpu = randn(nθ, batch_size) @testset "obj_batch! CPU" begin - y = BOI.obj_batch!(bm, X_cpu, Θ_cpu) + y = BNK.obj_batch!(bm, X_cpu, Θ_cpu) @test size(y) == (batch_size,) function f_cpu(params) X = params[1:nvar, :] Θ = params[nvar+1:end, :] - return sum(BOI.obj_batch!(bm, X, Θ)) + return sum(BNK.obj_batch!(bm, X, Θ)) end params = vcat(X_cpu, Θ_cpu) @@ -89,13 +89,13 @@ function test_diff_cpu(model::ExaModel, batch_size::Int) ncon == 0 && return @testset "cons_nln_batch! CPU" begin - y = BOI.cons_nln_batch!(bm, X_cpu, Θ_cpu) + y = BNK.cons_nln_batch!(bm, X_cpu, Θ_cpu) @test size(y) == (ncon, batch_size) function f_cpu(params) X = params[1:nvar, :] Θ = params[nvar+1:end, :] - return sum(BOI.cons_nln_batch!(bm, X, Θ)) + return sum(BNK.cons_nln_batch!(bm, X, Θ)) end params = vcat(X_cpu, Θ_cpu)