diff --git a/lib/cusparse/array.jl b/lib/cusparse/array.jl index 98a898f111..9f74683bae 100644 --- a/lib/cusparse/array.jl +++ b/lib/cusparse/array.jl @@ -53,6 +53,9 @@ mutable struct CuSparseMatrixCSC{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSC new{Tv, Ti}(colPtr, rowVal, nzVal, dims, length(nzVal)) end end +function GPUArrays.GPUSparseMatrixCSC(colPtr::CuVector{Ti}, rowVal::CuVector{Ti}, nzVal::CuVector{Tv}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer} + return CuSparseMatrixCSC{Tv, Ti}(colPtr, rowVal, nzVal, dims) +end CuSparseMatrixCSC{Tv, Ti}(csc::CuSparseMatrixCSC{Tv, Ti}) where {Tv, Ti} = csc SparseArrays.rowvals(g::T) where {T<:CuSparseVector} = nonzeroinds(g) @@ -94,7 +97,9 @@ mutable struct CuSparseMatrixCSR{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSR new{Tv, Ti}(rowPtr, colVal, nzVal, dims, length(nzVal)) end end - +function GPUArrays.GPUSparseMatrixCSR(rowPtr::CuVector{Ti}, colVal::CuVector{Ti}, nzVal::CuVector{Tv}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer} + return CuSparseMatrixCSR{Tv, Ti}(rowPtr, colVal, nzVal, dims) +end CuSparseMatrixCSR{Tv, Ti}(csr::CuSparseMatrixCSR{Tv, Ti}) where {Tv, Ti} = csr CuSparseMatrixCSR(A::CuSparseMatrixCSR) = A @@ -147,6 +152,9 @@ mutable struct CuSparseMatrixBSR{Tv, Ti} <: AbstractCuSparseMatrix{Tv, Ti} new{Tv, Ti}(rowPtr, colVal, nzVal, dims, blockDim, dir, nnz) end end +function GPUArrays.GPUSparseMatrixBSR(rowPtr::CuVector{Ti}, colVal::CuVector{Ti}, nzVal::CuVector{Tv}, dims::NTuple{2,<:Integer}, blockDim::Integer, args...) where {Tv, Ti <: Integer} + return CuSparseMatrixBSR{Tv, Ti}(rowPtr, colVal, nzVal, dims, blockDim, args...) +end CuSparseMatrixBSR(A::CuSparseMatrixBSR) = A @@ -177,7 +185,9 @@ mutable struct CuSparseMatrixCOO{Tv, Ti} <: AbstractCuSparseMatrix{Tv, Ti} new{Tv, Ti}(rowInd,colInd,nzVal,dims,nnz) end end - +function GPUArrays.GPUSparseMatrixCOO(rowInd::CuVector{Ti}, colInd::CuVector{Ti}, nzVal::CuVector{Tv}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer} + return CuSparseMatrixCOO{Tv, Ti}(rowInd, colInd, nzVal, dims) +end CuSparseMatrixCOO(A::CuSparseMatrixCOO) = A mutable struct CuSparseArrayCSR{Tv, Ti, N} <: GPUArrays.AbstractGPUSparseArray{Tv, Ti, N}