Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 24 additions & 20 deletions lib/JLArrays/src/JLArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ mutable struct JLSparseVector{Tv, Ti} <: GPUArrays.AbstractGPUSparseVector{Tv, T
new{Tv, Ti}(iPtr, nzVal, len, length(nzVal))
end
end
SparseArrays.nnz(x::JLSparseVector) = x.nnz
SparseArrays.nnz(x::JLSparseVector) = x.nnz
SparseArrays.nonzeroinds(x::JLSparseVector) = x.iPtr
SparseArrays.nonzeros(x::JLSparseVector) = x.nzVal

Expand Down Expand Up @@ -181,7 +181,7 @@ end
function JLSparseMatrixCSR(rowPtr::JLArray{Ti, 1}, colVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer}
return JLSparseMatrixCSR{Tv, Ti}(rowPtr, colVal, nzVal, dims)
end
function SparseArrays.SparseMatrixCSC(x::JLSparseMatrixCSR)
function SparseArrays.SparseMatrixCSC(x::JLSparseMatrixCSR)
x_transpose = SparseMatrixCSC(size(x, 2), size(x, 1), Array(x.rowPtr), Array(x.colVal), Array(x.nzVal))
return SparseMatrixCSC(transpose(x_transpose))
end
Expand Down Expand Up @@ -230,12 +230,12 @@ GPUArrays.sparse_array_type(::Type{<:JLSparseMatrixCSR}) = JLSparseMatrixCSR
GPUArrays.sparse_array_type(sa::JLSparseVector) = JLSparseVector
GPUArrays.sparse_array_type(::Type{<:JLSparseVector}) = JLSparseVector

GPUArrays.dense_array_type(sa::JLSparseVector) = JLArray
GPUArrays.dense_array_type(::Type{<:JLSparseVector}) = JLArray
GPUArrays.dense_array_type(sa::JLSparseMatrixCSC) = JLArray
GPUArrays.dense_array_type(::Type{<:JLSparseMatrixCSC}) = JLArray
GPUArrays.dense_array_type(sa::JLSparseMatrixCSR) = JLArray
GPUArrays.dense_array_type(::Type{<:JLSparseMatrixCSR}) = JLArray
GPUArrays.dense_array_type(sa::JLSparseVector) = JLArray
GPUArrays.dense_array_type(::Type{<:JLSparseVector}) = JLArray
GPUArrays.dense_array_type(sa::JLSparseMatrixCSC) = JLArray
GPUArrays.dense_array_type(::Type{<:JLSparseMatrixCSC}) = JLArray
GPUArrays.dense_array_type(sa::JLSparseMatrixCSR) = JLArray
GPUArrays.dense_array_type(::Type{<:JLSparseMatrixCSR}) = JLArray

GPUArrays.csc_type(sa::JLSparseMatrixCSR) = JLSparseMatrixCSC
GPUArrays.csr_type(sa::JLSparseMatrixCSC) = JLSparseMatrixCSR
Expand All @@ -246,14 +246,14 @@ Base.similar(Mat::JLSparseMatrixCSR, T::Type) = JLSparseMatrixCSR(copy(Mat.rowPt
Base.similar(Mat::JLSparseMatrixCSC, T::Type, N::Int, M::Int) = JLSparseMatrixCSC(JLVector([zero(Int32)]), JLVector{Int32}(undef, 0), JLVector{T}(undef, 0), (N, M))
Base.similar(Mat::JLSparseMatrixCSR, T::Type, N::Int, M::Int) = JLSparseMatrixCSR(JLVector([zero(Int32)]), JLVector{Int32}(undef, 0), JLVector{T}(undef, 0), (N, M))

Base.similar(Mat::JLSparseMatrixCSC{Tv, Ti}, N::Int, M::Int) where {Tv, Ti} = similar(Mat, Tv, N, M)
Base.similar(Mat::JLSparseMatrixCSR{Tv, Ti}, N::Int, M::Int) where {Tv, Ti} = similar(Mat, Tv, N, M)
Base.similar(Mat::JLSparseMatrixCSC{Tv, Ti}, N::Int, M::Int) where {Tv, Ti} = similar(Mat, Tv, N, M)
Base.similar(Mat::JLSparseMatrixCSR{Tv, Ti}, N::Int, M::Int) where {Tv, Ti} = similar(Mat, Tv, N, M)

Base.similar(Mat::JLSparseMatrixCSC, T::Type, dims::Tuple{Int, Int}) = similar(Mat, T, dims...)
Base.similar(Mat::JLSparseMatrixCSR, T::Type, dims::Tuple{Int, Int}) = similar(Mat, T, dims...)
Base.similar(Mat::JLSparseMatrixCSC, T::Type, dims::Tuple{Int, Int}) = similar(Mat, T, dims...)
Base.similar(Mat::JLSparseMatrixCSR, T::Type, dims::Tuple{Int, Int}) = similar(Mat, T, dims...)

Base.similar(Mat::JLSparseMatrixCSC, dims::Tuple{Int, Int}) = similar(Mat, dims...)
Base.similar(Mat::JLSparseMatrixCSR, dims::Tuple{Int, Int}) = similar(Mat, dims...)
Base.similar(Mat::JLSparseMatrixCSC, dims::Tuple{Int, Int}) = similar(Mat, dims...)
Base.similar(Mat::JLSparseMatrixCSR, dims::Tuple{Int, Int}) = similar(Mat, dims...)

JLArray(x::JLSparseVector) = JLArray(collect(SparseVector(x)))
JLArray(x::JLSparseMatrixCSC) = JLArray(collect(SparseMatrixCSC(x)))
Expand Down Expand Up @@ -323,9 +323,11 @@ StridedJLVector{T} = StridedJLArray{T,1}
StridedJLMatrix{T} = StridedJLArray{T,2}
StridedJLVecOrMat{T} = Union{StridedJLVector{T}, StridedJLMatrix{T}}

Base.pointer(x::StridedJLArray{T}) where {T} = Base.unsafe_convert(Ptr{T}, x)
@inline function Base.pointer(x::StridedJLArray{T}, i::Integer) where T
Base.unsafe_convert(Ptr{T}, x) + Base._memory_offset(x, i)
# Pointer access is only available for callers that explicitly want a pointer.
Base.pointer(x::JLArray{T}) where {T} =
convert(Ptr{T}, pointer(x.data[])) + x.offset
@inline function Base.pointer(x::JLArray{T}, i::Integer) where T
pointer(x) + (i - 1) * sizeof(T)
end

# anything that's (secretly) backed by a JLArray
Expand All @@ -342,8 +344,10 @@ Base.elsize(::Type{<:JLArray{T}}) where {T} = sizeof(T)
Base.size(x::JLArray) = x.dims
Base.sizeof(x::JLArray) = Base.elsize(x) * length(x)

Base.unsafe_convert(::Type{Ptr{T}}, x::JLArray{T}) where {T} =
convert(Ptr{T}, pointer(x.data[])) + x.offset
# Reject implicit conversions to a pointer, i.e., by passing to a CPU ccall.
function Base.unsafe_convert(::Type{Ptr{T}}, x::JLArray{T}) where {T}
error("Illegal conversion of a JLArray to a Ptr")
end


## interop with Julia arrays
Expand Down Expand Up @@ -378,7 +382,7 @@ function JLSparseMatrixCSC(xs::SparseMatrixCSC{Tv, Ti}) where {Ti, Tv}
copyto!(nzVal, convert(Vector{Tv}, xs.nzval))
return JLSparseMatrixCSC{Tv, Ti}(colPtr, rowVal, nzVal, (xs.m, xs.n))
end
JLSparseMatrixCSC(xs::SparseVector) = JLSparseMatrixCSC(SparseMatrixCSC(xs))
JLSparseMatrixCSC(xs::SparseVector) = JLSparseMatrixCSC(SparseMatrixCSC(xs))
Base.length(x::JLSparseMatrixCSC) = prod(x.dims)
Base.size(x::JLSparseMatrixCSC) = x.dims

Expand Down
Loading