diff --git a/lib/JLArrays/src/JLArrays.jl b/lib/JLArrays/src/JLArrays.jl index 14a4d296..7bba4b95 100644 --- a/lib/JLArrays/src/JLArrays.jl +++ b/lib/JLArrays/src/JLArrays.jl @@ -65,7 +65,7 @@ Adapt.adapt_structure(to::Adaptor, r::Base.RefValue) = JlRefValue(adapt(to, r[]) Base.sizeof(x::JLDeviceArray) = Base.elsize(x) * length(x) Base.unsafe_convert(::Type{Ptr{T}}, x::JLDeviceArray{T}) where {T} = - convert(Ptr{T}, pointer(x.data)) + x.offset*Base.elsize(x) + convert(Ptr{T}, pointer(x.data)) + x.offset # conversion of untyped data to a typed Array function typed_data(x::JLDeviceArray{T}) where {T} @@ -92,7 +92,7 @@ end mutable struct JLArray{T, N} <: AbstractGPUArray{T, N} data::DataRef{Vector{UInt8}} - offset::Int # offset of the data in the buffer, in number of elements + offset::Int # offset of the data in the buffer, in bytes dims::Dims{N} @@ -266,7 +266,7 @@ end function GPUArrays.derive(::Type{T}, a::JLArray, dims::Dims{N}, offset::Int) where {T,N} ref = copy(a.data) - offset = (a.offset * Base.elsize(a)) รท sizeof(T) + offset + offset = a.offset + offset * sizeof(T) JLArray{T,N}(ref, dims; offset) end @@ -343,7 +343,7 @@ Base.size(x::JLArray) = x.dims Base.sizeof(x::JLArray) = Base.elsize(x) * length(x) Base.unsafe_convert(::Type{Ptr{T}}, x::JLArray{T}) where {T} = - convert(Ptr{T}, pointer(x.data[])) + x.offset*Base.elsize(x) + convert(Ptr{T}, pointer(x.data[])) + x.offset ## interop with Julia arrays diff --git a/test/testsuite/base.jl b/test/testsuite/base.jl index a5571244..9c35898e 100644 --- a/test/testsuite/base.jl +++ b/test/testsuite/base.jl @@ -442,6 +442,13 @@ end @test collect(reinterpret(Int32, AT(fill(1f0))))[] == reinterpret(Int32, 1f0) + @testset "reinterpret of view with non-aligned offset" begin + a = AT(Int32[1,2,3,4,5,6,7,8,9]) + v = view(a, 2:7) # offset of 1 Int32 = 4 bytes + r = reinterpret(Int64, v) # Int64 = 8 bytes; 4 is not a multiple of 8 + @test Array(r) == reinterpret(Int64, @view Array(a)[2:7]) + end + @testset "reinterpret(reshape)" begin a = AT(ComplexF32[1.0f0+2.0f0*im, 2.0f0im, 3.0f0im]) b = reinterpret(reshape, Float32, a)