diff --git a/src/host/indexing.jl b/src/host/indexing.jl index 401780c6..7c652f50 100644 --- a/src/host/indexing.jl +++ b/src/host/indexing.jl @@ -164,11 +164,15 @@ function Base._unsafe_getindex!(dest::AbstractGPUArray, src::AbstractArray, Is:: end # Similar for `setindex!`, its default fallback is equivalent to `copyto!`. # We only dispatch the `copyto!` part (`Base._unsafe_setindex!`) to our implement. -function Base._unsafe_setindex!(::IndexStyle, A::WrappedGPUArray, x, Is::Vararg{Union{Real,AbstractArray}, N}) where N - return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...) -end -# And allow one more `ReshapedArray` wrapper to handle the `_maybe_reshape` optimization. -function Base._unsafe_setindex!(::IndexStyle, A::Base.ReshapedArray{<:Any, <:Any, <:WrappedGPUArray}, x, Is::Vararg{Union{Real,AbstractArray}, N}) where N +# Also cover the outer `ReshapedArray` that `_maybe_reshape` produces when the parent +# is a `WrappedGPUArray`. Keeping this in the same `Union` as `WrappedGPUArray` (rather +# than as a second method) avoids the dispatch ambiguity from #587: the two signatures +# would otherwise overlap (`WrappedGPUArray` already includes some `ReshapedArray`s) +# without either being a strict subtype of the other. +function Base._unsafe_setindex!(::IndexStyle, A::Union{ + WrappedGPUArray, + Base.ReshapedArray{<:Any, <:Any, <:WrappedGPUArray}, + }, x, Is::Vararg{Union{Real,AbstractArray}, N}) where N return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...) end diff --git a/test/testsuite/indexing.jl b/test/testsuite/indexing.jl index 2c44d21a..8a336291 100644 --- a/test/testsuite/indexing.jl +++ b/test/testsuite/indexing.jl @@ -284,3 +284,44 @@ end @test compare(argmin, AT, -rand(Int, 10)) end end + +@testsuite "indexing reshaped wrappers" (AT, eltypes) -> begin + # Regression for #587: `Q[:] = …` on an `IndexCartesian` SubArray of a GPU array + # threw a `MethodError` due to ambiguous `_unsafe_setindex!` on the + # `ReshapedArray{…,<:WrappedGPUArray}` that `_maybe_reshape` produces. A + # strided view triggers the same dispatch path as the original `BitVector` + # MWE without depending on non-isbits indices that some backends reject. + @testset "issue #587 with $T" for T in eltypes + @test compare(AT, ones(T, 8, 8, 8)) do P + Q = view(P, 1:2:7, :, :) + Q[:] = fill(T(2), length(Q)) + P + end + end + + @testset "reshape(view) with $T" for T in eltypes + @test compare(AT, ones(T, 4, 4, 4)) do A + R = reshape(view(A, 1:2, 1:2, 1:2), :) + R[:] = fill(T(2), length(R)) + A + end + end + + @testset "reshape(PermutedDimsArray) with $T" for T in eltypes + @test compare(AT, ones(T, 4, 4)) do A + R = reshape(PermutedDimsArray(A, (2, 1)), :) + R[1:2] = fill(zero(T), 2) + A + end + end + + @testset "reshape(reinterpret) with $T" for T in eltypes + T <: Complex || continue + IT = Complex{Int16} + @test compare(AT, ones(T, 4, 4)) do A + R = reshape(reinterpret(IT, A), :) + R[1:2] = fill(zero(IT), 2) + A + end + end +end