Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ version = "0.0.1"
[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e"
OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[compat]
julia = "1"
LinearMaps = "3"
OhMyThreads = "0.8.3"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Expand Down
1 change: 1 addition & 0 deletions src/BlockSparseMatrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module BlockSparseMatrices
using LinearAlgebra
using SparseArrays
using LinearMaps
using OhMyThreads

include("abstractblockmatrix.jl")
include("matrixblock/abstractmatrixblock.jl")
Expand Down
61 changes: 53 additions & 8 deletions src/blockmatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ struct BlockSparseMatrix{T,M,D} <: AbstractBlockMatrix{T}
buffer::Vector{T}
rowindexdict::D
colindexdict::D
threadsafecolors::Vector{Vector{Int}}
ntasks::Int
end

function BlockSparseMatrix(
blocks::Vector{M}, rowindices::V, colindices::V, size::Tuple{Int,Int}
blocks::Vector{M}, rowindices::V, colindices::V, size::Tuple{Int,Int}; ntasks=1
) where {M,V}
denseblockmatrices = Vector{DenseMatrixBlock{eltype(M),M,eltype(rowindices)}}(
undef, length(blocks)
Expand All @@ -19,14 +21,14 @@ function BlockSparseMatrix(
denseblockmatrices[i] = DenseMatrixBlock(blocks[i], rowindices[i], colindices[i])
end

return BlockSparseMatrix(denseblockmatrices, size)
return BlockSparseMatrix(denseblockmatrices, size; ntasks=ntasks)
end

function BlockSparseMatrix(blocks::Vector{M}, size::Tuple{Int,Int}) where {M}
return BlockSparseMatrix(blocks, size[1], size[2])
function BlockSparseMatrix(blocks::Vector{M}, size::Tuple{Int,Int}; ntasks=1) where {M}
return BlockSparseMatrix(blocks, size[1], size[2]; ntasks=ntasks)
end

function BlockSparseMatrix(blocks::Vector{M}, rows::Int, cols::Int) where {M}
function BlockSparseMatrix(blocks::Vector{M}, rows::Int, cols::Int; ntasks=1) where {M}
forwardbuffer, adjointbuffer, buffer = buffers(eltype(M), rows, cols)

sort!(blocks; lt=islessinordering)
Expand All @@ -39,6 +41,23 @@ function BlockSparseMatrix(blocks::Vector{M}, rows::Int, cols::Int) where {M}
_appendindexdict!(colindexdict, block.colindices, i)
end

#TODO: Pessimistic choice -> check performance
if blocks != M[]
threadsafecolors = [
Int[] for _ in
1:(maximum(length.(values(rowindexdict))) + maximum(
length.(values(colindexdict))
))
]
else
threadsafecolors = Vector{Int}[]
end
colorperm = Vector(1:length(threadsafecolors))
for i in eachindex(blocks)
findcolor!(i, view(threadsafecolors, colorperm), blocks)
sortperm!(colorperm, length.(threadsafecolors))
end

return BlockSparseMatrix{eltype(M),M,typeof(rowindexdict)}(
blocks,
(rows, cols),
Expand All @@ -47,6 +66,8 @@ function BlockSparseMatrix(blocks::Vector{M}, rows::Int, cols::Int) where {M}
buffer,
rowindexdict,
colindexdict,
threadsafecolors,
ntasks,
)
end

Expand Down Expand Up @@ -74,6 +95,26 @@ function block(A::M, i) where {Z<:BlockSparseMatrix,T,M<:LinearMaps.TransposeMap
return transpose(block(A.lmap, i))
end

ntasks(A::BlockSparseMatrix) = A.ntasks

function ntasks(
A::M
) where {
Z<:BlockSparseMatrix,T,M<:Union{LinearMaps.AdjointMap{T,Z},LinearMaps.TransposeMap{T,Z}}
}
return ntasks(A.lmap)
end

threadsafecolors(A::BlockSparseMatrix) = A.threadsafecolors

function threadsafecolors(
A::M
) where {
Z<:BlockSparseMatrix,T,M<:Union{LinearMaps.AdjointMap{T,Z},LinearMaps.TransposeMap{T,Z}}
}
return threadsafecolors(A.lmap)
end

function SparseArrays.nnz(A::BlockSparseMatrix)
nonzeros = 0
for blockid in eachblockindex(A)
Expand All @@ -90,10 +131,14 @@ function LinearMaps._unsafe_mul!(
M<:Union{Z,LinearMaps.AdjointMap{<:Any,Z},LinearMaps.TransposeMap{<:Any,Z}},
}
y .= zero(eltype(y))
for blockid in eachblockindex(A)
b = block(A, blockid)
@inbounds LinearAlgebra.mul!(view(y, rowindices(b)), b, view(x, colindices(b)))
for color in threadsafecolors(A)
@tasks for blockid in color
@set ntasks = ntasks(A)
b = block(A, blockid)
@inbounds LinearAlgebra.mul!(view(y, rowindices(b)), b, view(x, colindices(b)))
end
end

return y
end

Expand Down
59 changes: 59 additions & 0 deletions src/matrixblock/abstractmatrixblock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,62 @@ function islessinordering(blocka::AbstractMatrixBlock, blockb::AbstractMatrixBlo
return maximum(colindices(blocka)) < maximum(colindices(blockb))
end
end

struct isthreadsafe end
struct issymthreadsafe
isthreadsafe::isthreadsafe
end
issymthreadsafe() = issymthreadsafe(isthreadsafe())

function (ists::issymthreadsafe)(blocka, blockb)
return ists.isthreadsafe(blocka, blockb) &&
ists.isthreadsafe(transpose(blocka), blockb) &&
ists.isthreadsafe(blocka, transpose(blockb))
end

function (::isthreadsafe)(blocka, blockb)
if size(blocka, 1) >= size(blockb, 1)
issubset(rowindices(blockb), rowindices(blocka)) && (return false)

if size(blocka, 2) >= size(blockb, 2)
return !issubset(colindices(blockb), colindices(blocka))
else
return !issubset(colindices(blocka), colindices(blockb))
end
else
issubset(rowindices(blocka), rowindices(blockb)) && (return false)

if size(blocka, 2) >= size(blockb, 2)
return !issubset(colindices(blockb), colindices(blocka))
else
return !issubset(colindices(blocka), colindices(blockb))
end
end
end

function findcolor!(
blockid::Int,
threadsafecolors::AbstractArray,
blocks::Vector{M};
threadsafecheck=isthreadsafe(),
color=1,
) where {M}
#This case should not appear, it is a rescue measure
(length(threadsafecolors) < color) && return push!(threadsafecolors, [blockid])

for testblockid in threadsafecolors[color]
if threadsafecheck(blocks[testblockid], blocks[blockid])
return push!(threadsafecolors[color], blockid)
else
return findcolor!(
blockid,
threadsafecolors,
blocks;
threadsafecheck=threadsafecheck,
color=color + 1,
)
end
end

return push!(threadsafecolors[color], blockid)
end
85 changes: 71 additions & 14 deletions src/symmetricblockmatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ struct SymmetricBlockMatrix{T,DM,M,D} <: AbstractBlockMatrix{T}
diagonalscolindexdict::D
offdiagonalsrowindexdict::D
offdiagonalscolindexdict::D
threadsafecolors::Vector{Vector{Int}}
ntasks::Int
end

function SymmetricBlockMatrix(
Expand All @@ -18,7 +20,8 @@ function SymmetricBlockMatrix(
offdiagonals::Vector{M},
rowidcs::V,
colidcs::V,
size::Tuple{Int,Int},
size::Tuple{Int,Int};
ntasks=1,
) where {DM,M,V}
offdiagonalblocks = Vector{DenseMatrixBlock{eltype(M),M,eltype(rowidcs)}}(
undef, length(offdiagonals)
Expand All @@ -35,17 +38,17 @@ function SymmetricBlockMatrix(
diagonalblocks[i] = DenseMatrixBlock(diagonals[i], drowidcs[i], dcolidcs[i])
end

return SymmetricBlockMatrix(diagonalblocks, offdiagonalblocks, size)
return SymmetricBlockMatrix(diagonalblocks, offdiagonalblocks, size; ntasks=ntasks)
end

function SymmetricBlockMatrix(
diagonals::Vector{DM}, offdiagonals::Vector{M}, size::Tuple{Int,Int}
diagonals::Vector{DM}, offdiagonals::Vector{M}, size::Tuple{Int,Int}; ntasks=1
) where {DM,M}
return SymmetricBlockMatrix(diagonals, offdiagonals, size[1], size[2])
return SymmetricBlockMatrix(diagonals, offdiagonals, size[1], size[2]; ntasks=ntasks)
end

function SymmetricBlockMatrix(
diagonals::Vector{DM}, offdiagonals::Vector{M}, rows::Int, cols::Int
diagonals::Vector{DM}, offdiagonals::Vector{M}, rows::Int, cols::Int; ntasks=1
) where {DM,M}
forwardbuffer, adjointbuffer, buffer = buffers(eltype(M), rows, cols)

Expand All @@ -66,6 +69,27 @@ function SymmetricBlockMatrix(
_appendindexdict!(offdiagonalscolindexdict, block.colindices, i)
end

if offdiagonals != M[]
threadsafecolors = [
Int[] for _ in
1:(2 * (maximum(length.(values(offdiagonalsrowindexdict))) + maximum(
length.(values(offdiagonalscolindexdict))
)))
]
else
threadsafecolors = Vector{Int}[]
end
colorperm = Vector(1:length(threadsafecolors))
for i in eachindex(offdiagonals)
findcolor!(
i,
view(threadsafecolors, colorperm),
offdiagonals;
threadsafecheck=issymthreadsafe(),
)
sortperm!(colorperm, length.(threadsafecolors))
end

return SymmetricBlockMatrix{eltype(M),DM,M,typeof(diagonalsrowindexdict)}(
diagonals,
offdiagonals,
Expand All @@ -77,6 +101,8 @@ function SymmetricBlockMatrix(
diagonalscolindexdict,
offdiagonalsrowindexdict,
offdiagonalscolindexdict,
threadsafecolors,
ntasks,
)
end

Expand Down Expand Up @@ -116,6 +142,8 @@ function diagonal(A::SymmetricBlockMatrix, i)
return A.diagonals[i]
end

ntasks(A::SymmetricBlockMatrix) = A.ntasks

function offdiagonal(
A::M, i
) where {Z<:SymmetricBlockMatrix,T,M<:LinearMaps.AdjointMap{T,Z}}
Expand All @@ -126,6 +154,15 @@ function diagonal(A::M, i) where {Z<:SymmetricBlockMatrix,T,M<:LinearMaps.Adjoin
return adjoint(diagonal(A.lmap, i))
end

function ntasks(
A::M
) where {
Z<:SymmetricBlockMatrix,
M<:Union{Z,LinearMaps.AdjointMap{<:Any,Z},LinearMaps.TransposeMap{<:Any,Z}},
}
return ntasks(A.lmap)
end

function offdiagonal(
A::M, i
) where {Z<:SymmetricBlockMatrix,T,M<:LinearMaps.TransposeMap{T,Z}}
Expand Down Expand Up @@ -167,23 +204,43 @@ function offdiagonalcolindices(A::SymmetricBlockMatrix, j::Integer)
return A.offdiagonalscolindexdict[j]
end

threadsafecolors(A::SymmetricBlockMatrix) = A.threadsafecolors

function threadsafecolors(
A::M
) where {
Z<:SymmetricBlockMatrix,
T,
M<:Union{LinearMaps.AdjointMap{T,Z},LinearMaps.TransposeMap{T,Z}},
}
return threadsafecolors(A.lmap)
end

function LinearMaps._unsafe_mul!(
y::AbstractVector, A::M, x::AbstractVector
) where {
Z<:SymmetricBlockMatrix,
M<:Union{Z,LinearMaps.AdjointMap{<:Any,Z},LinearMaps.TransposeMap{<:Any,Z}},
}
y .= zero(eltype(y))
for blockid in eachoffdiagonalindex(A)
b = offdiagonal(A, blockid)
LinearAlgebra.mul!(
view(y, rowindices(b)), matrix(b), view(x, colindices(b)), true, true
)
LinearAlgebra.mul!(
view(y, colindices(b)), transpose(matrix(b)), view(x, rowindices(b)), true, true
)
for color in threadsafecolors(A)
@tasks for blockid in color
@set ntasks = ntasks(A)
b = offdiagonal(A, blockid)
LinearAlgebra.mul!(
view(y, rowindices(b)), matrix(b), view(x, colindices(b)), true, true
)
LinearAlgebra.mul!(
view(y, colindices(b)),
transpose(matrix(b)),
view(x, rowindices(b)),
true,
true,
)
end
end
for blockid in eachdiagonalindex(A)
@tasks for blockid in eachdiagonalindex(A)
@set ntasks = ntasks(A)
b = diagonal(A, blockid)
@inbounds LinearAlgebra.mul!(view(y, rowindices(b)), b, view(x, colindices(b)))
end
Expand Down
12 changes: 12 additions & 0 deletions test/test_blockmatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,24 @@ block1 = BlockSparseMatrices.DenseMatrixBlock(mat1, 1:2, 1:2)
block2 = BlockSparseMatrices.DenseMatrixBlock(mat2, 3:5, 3:5)
block3 = BlockSparseMatrices.DenseMatrixBlock(mat3, 1:2, 3:5)
block4 = BlockSparseMatrices.DenseMatrixBlock(mat4, 3:5, 1:2)
block5 = BlockSparseMatrices.DenseMatrixBlock(mat1, 6:7, 6:7)

@test BlockSparseMatrices.isthreadsafe()(block1, block2)
@test BlockSparseMatrices.isthreadsafe()(block2, block1)
@test BlockSparseMatrices.isthreadsafe()(block3, block4)
@test BlockSparseMatrices.isthreadsafe()(block4, block3)
@test !BlockSparseMatrices.isthreadsafe()(block1, block3)
@test !BlockSparseMatrices.isthreadsafe()(block3, block1)
@test !BlockSparseMatrices.isthreadsafe()(block1, block4)
@test !BlockSparseMatrices.isthreadsafe()(block4, block1)

blockmatrix = BlockSparseMatrix([block1, block2, block3, block4], 5, 5)
blockmatrix2 = BlockSparseMatrix([block1, block2, block3, block4], (5, 5))
blockmatrix3 = BlockSparseMatrix(
[mat1, mat2, mat3, mat4], [1:2, 3:5, 1:2, 3:5], [1:2, 3:5, 3:5, 1:2], (5, 5)
)
blockmatrix4 = BlockSparseMatrix([block1, block2, block5], 7, 7)
@test length(blockmatrix4.threadsafecolors) == 2

@test blockmatrix.blocks == blockmatrix2.blocks == blockmatrix3.blocks

Expand Down
Loading