diff --git a/Project.toml b/Project.toml index 28518343..091a37f9 100644 --- a/Project.toml +++ b/Project.toml @@ -10,10 +10,12 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" [weakdeps] +LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" [extensions] +SciMLOperatorsLoopVectorizationExt = "LoopVectorization" SciMLOperatorsSparseArraysExt = "SparseArrays" SciMLOperatorsStaticArraysCoreExt = "StaticArraysCore" @@ -22,6 +24,7 @@ Accessors = "0.1.42" ArrayInterface = "7.19" DocStringExtensions = "0.9.4" LinearAlgebra = "1.10" +LoopVectorization = "0.12" SparseArrays = "1.10" StaticArraysCore = "1" julia = "1.10" diff --git a/ext/SciMLOperatorsLoopVectorizationExt.jl b/ext/SciMLOperatorsLoopVectorizationExt.jl new file mode 100644 index 00000000..343c7240 --- /dev/null +++ b/ext/SciMLOperatorsLoopVectorizationExt.jl @@ -0,0 +1,46 @@ +module SciMLOperatorsLoopVectorizationExt + +import LoopVectorization: @turbo +import SciMLOperators + +const StridedMatrixOperator = SciMLOperators.MatrixOperator{<:Any, <:StridedMatrix} + +SciMLOperators._has_tensor_outer_mul_fast(::StridedMatrixOperator) = true + +function SciMLOperators._tensor_outer_mul_fast!( + w, outer::StridedMatrixOperator, C, mi::Int, mo::Int, no::Int, k::Int + ) + A = outer.A + C = reshape(C, (mi, no, k)) + W = reshape(w, (mi, mo, k)) + + @turbo for j in 1:k, m in 1:mo, i in 1:mi + acc = zero(eltype(w)) + for o in 1:no + acc += A[m, o] * C[i, o, j] + end + W[i, m, j] = acc + end + + return w +end + +function SciMLOperators._tensor_outer_mul_fast!( + w, outer::StridedMatrixOperator, C, mi::Int, mo::Int, no::Int, k::Int, α, β + ) + A = outer.A + C = reshape(C, (mi, no, k)) + W = reshape(w, (mi, mo, k)) + + @turbo for j in 1:k, m in 1:mo, i in 1:mi + acc = zero(eltype(w)) + for o in 1:no + acc += A[m, o] * C[i, o, j] + end + W[i, m, j] = α * acc + β * W[i, m, j] + end + + return w +end + +end diff --git a/src/tensor.jl b/src/tensor.jl index 50089bf0..11f144b6 100644 --- a/src/tensor.jl +++ b/src/tensor.jl @@ -413,6 +413,9 @@ end # helper functions const PERM = (2, 1, 3) +_has_tensor_outer_mul_fast(outer) = false +function _tensor_outer_mul_fast! end + function outer_mul(L::TensorProductOperator, v::AbstractVecOrMat, C::AbstractVecOrMat) outer, inner = L.ops @@ -465,6 +468,11 @@ function outer_mul!(w::AbstractVecOrMat, L::TensorProductOperator, v::AbstractVe return w end + if _has_tensor_outer_mul_fast(outer) + _tensor_outer_mul_fast!(w, outer, C1, mi, mo, no, k) + return w + end + C2, C3 = L.cache[2:3] C1 = reshape(C1, (mi, no, k)) @@ -503,6 +511,11 @@ function outer_mul!( return w end + if _has_tensor_outer_mul_fast(outer) + _tensor_outer_mul_fast!(w, outer, v, mi, mo, no, k, α, β) + return w + end + C2, C3, c4 = L.cache[2:4] C = reshape(v, (mi, no, k)) diff --git a/test/Project.toml b/test/Project.toml index 342ab975..8d16f7b1 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,7 @@ [deps] FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" @@ -10,5 +11,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] FFTW = "1.10.0" +LoopVectorization = "0.12" SafeTestsets = "0.1.0" Zygote = "0.7.10" diff --git a/test/matrix.jl b/test/matrix.jl index 3d363e56..355b1f36 100644 --- a/test/matrix.jl +++ b/test/matrix.jl @@ -2,6 +2,7 @@ using SciMLOperators, LinearAlgebra using SparseArrays using Random using Test +using LoopVectorization using SciMLOperators: InvertibleOperator, InvertedOperator, ⊗, AbstractSciMLOperator using FFTW