Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions src/host/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,15 @@ function Base._unsafe_getindex!(dest::AbstractGPUArray, src::AbstractArray, Is::
end
# Similar for `setindex!`, its default fallback is equivalent to `copyto!`.
# We only dispatch the `copyto!` part (`Base._unsafe_setindex!`) to our implement.
function Base._unsafe_setindex!(::IndexStyle, A::WrappedGPUArray, x, Is::Vararg{Union{Real,AbstractArray}, N}) where N
return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...)
end
# And allow one more `ReshapedArray` wrapper to handle the `_maybe_reshape` optimization.
function Base._unsafe_setindex!(::IndexStyle, A::Base.ReshapedArray{<:Any, <:Any, <:WrappedGPUArray}, x, Is::Vararg{Union{Real,AbstractArray}, N}) where N
# Also cover the outer `ReshapedArray` that `_maybe_reshape` produces when the parent
# is a `WrappedGPUArray`. Keeping this in the same `Union` as `WrappedGPUArray` (rather
# than as a second method) avoids the dispatch ambiguity from #587: the two signatures
# would otherwise overlap (`WrappedGPUArray` already includes some `ReshapedArray`s)
# without either being a strict subtype of the other.
function Base._unsafe_setindex!(::IndexStyle, A::Union{
WrappedGPUArray,
Base.ReshapedArray{<:Any, <:Any, <:WrappedGPUArray},
}, x, Is::Vararg{Union{Real,AbstractArray}, N}) where N
return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...)
end

Expand Down
41 changes: 41 additions & 0 deletions test/testsuite/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -284,3 +284,44 @@ end
@test compare(argmin, AT, -rand(Int, 10))
end
end

@testsuite "indexing reshaped wrappers" (AT, eltypes) -> begin
# Regression for #587: `Q[:] = …` on an `IndexCartesian` SubArray of a GPU array
# threw a `MethodError` due to ambiguous `_unsafe_setindex!` on the
# `ReshapedArray{…,<:WrappedGPUArray}` that `_maybe_reshape` produces. A
# strided view triggers the same dispatch path as the original `BitVector`
# MWE without depending on non-isbits indices that some backends reject.
@testset "issue #587 with $T" for T in eltypes
@test compare(AT, ones(T, 8, 8, 8)) do P
Q = view(P, 1:2:7, :, :)
Q[:] = fill(T(2), length(Q))
P
end
end

@testset "reshape(view) with $T" for T in eltypes
@test compare(AT, ones(T, 4, 4, 4)) do A
R = reshape(view(A, 1:2, 1:2, 1:2), :)
R[:] = fill(T(2), length(R))
A
end
end

@testset "reshape(PermutedDimsArray) with $T" for T in eltypes
@test compare(AT, ones(T, 4, 4)) do A
R = reshape(PermutedDimsArray(A, (2, 1)), :)
R[1:2] = fill(zero(T), 2)
A
end
end

@testset "reshape(reinterpret) with $T" for T in eltypes
T <: Complex || continue
IT = Complex{Int16}
@test compare(AT, ones(T, 4, 4)) do A
R = reshape(reinterpret(IT, A), :)
R[1:2] = fill(zero(IT), 2)
A
end
end
end
Loading