diff --git a/lib/JLArrays/src/JLArrays.jl b/lib/JLArrays/src/JLArrays.jl index bf671ff2..4d2dad41 100644 --- a/lib/JLArrays/src/JLArrays.jl +++ b/lib/JLArrays/src/JLArrays.jl @@ -134,7 +134,7 @@ mutable struct JLSparseVector{Tv, Ti} <: GPUArrays.AbstractGPUSparseVector{Tv, T new{Tv, Ti}(iPtr, nzVal, len, length(nzVal)) end end -SparseArrays.nnz(x::JLSparseVector) = x.nnz +SparseArrays.nnz(x::JLSparseVector) = x.nnz SparseArrays.nonzeroinds(x::JLSparseVector) = x.iPtr SparseArrays.nonzeros(x::JLSparseVector) = x.nzVal @@ -181,7 +181,7 @@ end function JLSparseMatrixCSR(rowPtr::JLArray{Ti, 1}, colVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer} return JLSparseMatrixCSR{Tv, Ti}(rowPtr, colVal, nzVal, dims) end -function SparseArrays.SparseMatrixCSC(x::JLSparseMatrixCSR) +function SparseArrays.SparseMatrixCSC(x::JLSparseMatrixCSR) x_transpose = SparseMatrixCSC(size(x, 2), size(x, 1), Array(x.rowPtr), Array(x.colVal), Array(x.nzVal)) return SparseMatrixCSC(transpose(x_transpose)) end @@ -230,12 +230,12 @@ GPUArrays.sparse_array_type(::Type{<:JLSparseMatrixCSR}) = JLSparseMatrixCSR GPUArrays.sparse_array_type(sa::JLSparseVector) = JLSparseVector GPUArrays.sparse_array_type(::Type{<:JLSparseVector}) = JLSparseVector -GPUArrays.dense_array_type(sa::JLSparseVector) = JLArray -GPUArrays.dense_array_type(::Type{<:JLSparseVector}) = JLArray -GPUArrays.dense_array_type(sa::JLSparseMatrixCSC) = JLArray -GPUArrays.dense_array_type(::Type{<:JLSparseMatrixCSC}) = JLArray -GPUArrays.dense_array_type(sa::JLSparseMatrixCSR) = JLArray -GPUArrays.dense_array_type(::Type{<:JLSparseMatrixCSR}) = JLArray +GPUArrays.dense_array_type(sa::JLSparseVector) = JLArray +GPUArrays.dense_array_type(::Type{<:JLSparseVector}) = JLArray +GPUArrays.dense_array_type(sa::JLSparseMatrixCSC) = JLArray +GPUArrays.dense_array_type(::Type{<:JLSparseMatrixCSC}) = JLArray +GPUArrays.dense_array_type(sa::JLSparseMatrixCSR) = JLArray +GPUArrays.dense_array_type(::Type{<:JLSparseMatrixCSR}) = JLArray GPUArrays.csc_type(sa::JLSparseMatrixCSR) = JLSparseMatrixCSC GPUArrays.csr_type(sa::JLSparseMatrixCSC) = JLSparseMatrixCSR @@ -246,14 +246,14 @@ Base.similar(Mat::JLSparseMatrixCSR, T::Type) = JLSparseMatrixCSR(copy(Mat.rowPt Base.similar(Mat::JLSparseMatrixCSC, T::Type, N::Int, M::Int) = JLSparseMatrixCSC(JLVector([zero(Int32)]), JLVector{Int32}(undef, 0), JLVector{T}(undef, 0), (N, M)) Base.similar(Mat::JLSparseMatrixCSR, T::Type, N::Int, M::Int) = JLSparseMatrixCSR(JLVector([zero(Int32)]), JLVector{Int32}(undef, 0), JLVector{T}(undef, 0), (N, M)) -Base.similar(Mat::JLSparseMatrixCSC{Tv, Ti}, N::Int, M::Int) where {Tv, Ti} = similar(Mat, Tv, N, M) -Base.similar(Mat::JLSparseMatrixCSR{Tv, Ti}, N::Int, M::Int) where {Tv, Ti} = similar(Mat, Tv, N, M) +Base.similar(Mat::JLSparseMatrixCSC{Tv, Ti}, N::Int, M::Int) where {Tv, Ti} = similar(Mat, Tv, N, M) +Base.similar(Mat::JLSparseMatrixCSR{Tv, Ti}, N::Int, M::Int) where {Tv, Ti} = similar(Mat, Tv, N, M) -Base.similar(Mat::JLSparseMatrixCSC, T::Type, dims::Tuple{Int, Int}) = similar(Mat, T, dims...) -Base.similar(Mat::JLSparseMatrixCSR, T::Type, dims::Tuple{Int, Int}) = similar(Mat, T, dims...) +Base.similar(Mat::JLSparseMatrixCSC, T::Type, dims::Tuple{Int, Int}) = similar(Mat, T, dims...) +Base.similar(Mat::JLSparseMatrixCSR, T::Type, dims::Tuple{Int, Int}) = similar(Mat, T, dims...) -Base.similar(Mat::JLSparseMatrixCSC, dims::Tuple{Int, Int}) = similar(Mat, dims...) -Base.similar(Mat::JLSparseMatrixCSR, dims::Tuple{Int, Int}) = similar(Mat, dims...) +Base.similar(Mat::JLSparseMatrixCSC, dims::Tuple{Int, Int}) = similar(Mat, dims...) +Base.similar(Mat::JLSparseMatrixCSR, dims::Tuple{Int, Int}) = similar(Mat, dims...) JLArray(x::JLSparseVector) = JLArray(collect(SparseVector(x))) JLArray(x::JLSparseMatrixCSC) = JLArray(collect(SparseMatrixCSC(x))) @@ -323,9 +323,11 @@ StridedJLVector{T} = StridedJLArray{T,1} StridedJLMatrix{T} = StridedJLArray{T,2} StridedJLVecOrMat{T} = Union{StridedJLVector{T}, StridedJLMatrix{T}} -Base.pointer(x::StridedJLArray{T}) where {T} = Base.unsafe_convert(Ptr{T}, x) -@inline function Base.pointer(x::StridedJLArray{T}, i::Integer) where T - Base.unsafe_convert(Ptr{T}, x) + Base._memory_offset(x, i) +# Pointer access is only available for callers that explicitly want a pointer. +Base.pointer(x::JLArray{T}) where {T} = + convert(Ptr{T}, pointer(x.data[])) + x.offset +@inline function Base.pointer(x::JLArray{T}, i::Integer) where T + pointer(x) + (i - 1) * sizeof(T) end # anything that's (secretly) backed by a JLArray @@ -342,8 +344,10 @@ Base.elsize(::Type{<:JLArray{T}}) where {T} = sizeof(T) Base.size(x::JLArray) = x.dims Base.sizeof(x::JLArray) = Base.elsize(x) * length(x) -Base.unsafe_convert(::Type{Ptr{T}}, x::JLArray{T}) where {T} = - convert(Ptr{T}, pointer(x.data[])) + x.offset +# Reject implicit conversions to a pointer, i.e., by passing to a CPU ccall. +function Base.unsafe_convert(::Type{Ptr{T}}, x::JLArray{T}) where {T} + error("Illegal conversion of a JLArray to a Ptr") +end ## interop with Julia arrays @@ -378,7 +382,7 @@ function JLSparseMatrixCSC(xs::SparseMatrixCSC{Tv, Ti}) where {Ti, Tv} copyto!(nzVal, convert(Vector{Tv}, xs.nzval)) return JLSparseMatrixCSC{Tv, Ti}(colPtr, rowVal, nzVal, (xs.m, xs.n)) end -JLSparseMatrixCSC(xs::SparseVector) = JLSparseMatrixCSC(SparseMatrixCSC(xs)) +JLSparseMatrixCSC(xs::SparseVector) = JLSparseMatrixCSC(SparseMatrixCSC(xs)) Base.length(x::JLSparseMatrixCSC) = prod(x.dims) Base.size(x::JLSparseMatrixCSC) = x.dims