From 0dff2ec6e71f746910e83b16ed5c7a105b228c51 Mon Sep 17 00:00:00 2001 From: shreyas-omkar Date: Tue, 20 Jan 2026 17:05:48 +0530 Subject: [PATCH 1/3] fix: Specialized ReshapedArray dispatch to resolve setindex! ambiguities --- src/host/indexing.jl | 10 ++++-- test/testsuite/indexing.jl | 70 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 2 deletions(-) diff --git a/src/host/indexing.jl b/src/host/indexing.jl index 401780c6..898170f0 100644 --- a/src/host/indexing.jl +++ b/src/host/indexing.jl @@ -167,8 +167,14 @@ end function Base._unsafe_setindex!(::IndexStyle, A::WrappedGPUArray, x, Is::Vararg{Union{Real,AbstractArray}, N}) where N return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...) end -# And allow one more `ReshapedArray` wrapper to handle the `_maybe_reshape` optimization. -function Base._unsafe_setindex!(::IndexStyle, A::Base.ReshapedArray{<:Any, <:Any, <:WrappedGPUArray}, x, Is::Vararg{Union{Real,AbstractArray}, N}) where N + +#Implementation for ReshapedArrays using Cartesian indexing to resolve dispatch ties. +function Base._unsafe_setindex!(::Base.IndexCartesian, A::Base.ReshapedArray{T, N, <:WrappedGPUArray}, x, Is::Vararg{Union{Real, AbstractArray}, M}) where {T, N, M} + return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...) +end + +#Implementation for ReshapedArrays using Linear indexing to resolve dispatch ties. +function Base._unsafe_setindex!(::Base.IndexLinear, A::Base.ReshapedArray{T, N, <:WrappedGPUArray}, x, Is::Vararg{Union{Real, AbstractArray}, M}) where {T, N, M} return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...) end diff --git a/test/testsuite/indexing.jl b/test/testsuite/indexing.jl index 2c44d21a..aaae72b2 100644 --- a/test/testsuite/indexing.jl +++ b/test/testsuite/indexing.jl @@ -284,3 +284,73 @@ end @test compare(argmin, AT, -rand(Int, 10)) end end + +@testsuite "indexing combinatorial" (AT, eltypes) -> begin + @testset "Reshaped SubArray dispatch" for T in eltypes + @testset "3D slice assignment" begin + A = AT(ones(T, 4, 4, 4)) + @views V = A[:, :, 1:2] + @allowscalar begin + @test_nowarn V .= zero(T) + @test all(Array(V) .== zero(T)) + end + end + + @testset "Logical mask view (dim = 3) — GPU safe" begin + A = AT(ones(T, 4, 4, 4)) + idx = findall(Bool[true, false, true, false]) + @views V = A[:, :, idx] + @allowscalar begin + @test_nowarn V .+= T(2) + @test all(Array(V) .== T(3)) + end + end + + @testset "Nested Reshape" begin + A = AT(ones(T, 4, 4, 4)) + V = view(A, 1:2, 1:2, 1:2) + R1 = reshape(V, 4, 2) + R2 = reshape(R1, :) + @allowscalar begin + @test_nowarn R2 .+= one(T) + @test all(Array(R2) .== T(2)) + end + end + end + + @testset "Permuted and Reinterpreted Views" for T in eltypes + @testset "Reshaped PermutedDims" begin + A = AT(ones(T, 4, 4)) + P = PermutedDimsArray(A, (2, 1)) + R = reshape(P, :) + @allowscalar begin + @test_nowarn R[1:2] .= zero(T) + # Check the full assigned range. + @test all(Array(R)[1:2] .== zero(T)) + end + end + + @testset "Reshaped Reinterpreted" begin + T_base = real(T) + if T <: Complex + A = AT(ones(T, 4, 4)) + IT = Complex{Int16} + R = reshape(reinterpret(IT, A), :) + @allowscalar begin + @test_nowarn R[1:2] .= zero(IT) + @test all(Array(R)[1:2] .== zero(IT)) + end + end + end + end + + @testset "Data parity with compare() — GPU safe" for T in eltypes + idx = 2:4 + @test compare(AT, rand(T, 8, 8, 8)) do A + # compare() handles CPU/GPU execution no @allowscalar needed here + V = view(A, :, idx, :) + V .+= one(T) + A + end + end +end \ No newline at end of file From bced345886a408c79eb85acfad924ce62efb9799 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 15 May 2026 11:35:54 +0200 Subject: [PATCH 2/3] Simplify. --- src/host/indexing.jl | 20 ++++----- test/testsuite/indexing.jl | 86 +++++++++++++------------------------- 2 files changed, 37 insertions(+), 69 deletions(-) diff --git a/src/host/indexing.jl b/src/host/indexing.jl index 898170f0..7c652f50 100644 --- a/src/host/indexing.jl +++ b/src/host/indexing.jl @@ -164,17 +164,15 @@ function Base._unsafe_getindex!(dest::AbstractGPUArray, src::AbstractArray, Is:: end # Similar for `setindex!`, its default fallback is equivalent to `copyto!`. # We only dispatch the `copyto!` part (`Base._unsafe_setindex!`) to our implement. -function Base._unsafe_setindex!(::IndexStyle, A::WrappedGPUArray, x, Is::Vararg{Union{Real,AbstractArray}, N}) where N - return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...) -end - -#Implementation for ReshapedArrays using Cartesian indexing to resolve dispatch ties. -function Base._unsafe_setindex!(::Base.IndexCartesian, A::Base.ReshapedArray{T, N, <:WrappedGPUArray}, x, Is::Vararg{Union{Real, AbstractArray}, M}) where {T, N, M} - return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...) -end - -#Implementation for ReshapedArrays using Linear indexing to resolve dispatch ties. -function Base._unsafe_setindex!(::Base.IndexLinear, A::Base.ReshapedArray{T, N, <:WrappedGPUArray}, x, Is::Vararg{Union{Real, AbstractArray}, M}) where {T, N, M} +# Also cover the outer `ReshapedArray` that `_maybe_reshape` produces when the parent +# is a `WrappedGPUArray`. Keeping this in the same `Union` as `WrappedGPUArray` (rather +# than as a second method) avoids the dispatch ambiguity from #587: the two signatures +# would otherwise overlap (`WrappedGPUArray` already includes some `ReshapedArray`s) +# without either being a strict subtype of the other. +function Base._unsafe_setindex!(::IndexStyle, A::Union{ + WrappedGPUArray, + Base.ReshapedArray{<:Any, <:Any, <:WrappedGPUArray}, + }, x, Is::Vararg{Union{Real,AbstractArray}, N}) where N return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...) end diff --git a/test/testsuite/indexing.jl b/test/testsuite/indexing.jl index aaae72b2..c6bb2e76 100644 --- a/test/testsuite/indexing.jl +++ b/test/testsuite/indexing.jl @@ -285,72 +285,42 @@ end end end -@testsuite "indexing combinatorial" (AT, eltypes) -> begin - @testset "Reshaped SubArray dispatch" for T in eltypes - @testset "3D slice assignment" begin - A = AT(ones(T, 4, 4, 4)) - @views V = A[:, :, 1:2] - @allowscalar begin - @test_nowarn V .= zero(T) - @test all(Array(V) .== zero(T)) - end - end - - @testset "Logical mask view (dim = 3) — GPU safe" begin - A = AT(ones(T, 4, 4, 4)) - idx = findall(Bool[true, false, true, false]) - @views V = A[:, :, idx] - @allowscalar begin - @test_nowarn V .+= T(2) - @test all(Array(V) .== T(3)) - end - end - - @testset "Nested Reshape" begin - A = AT(ones(T, 4, 4, 4)) - V = view(A, 1:2, 1:2, 1:2) - R1 = reshape(V, 4, 2) - R2 = reshape(R1, :) - @allowscalar begin - @test_nowarn R2 .+= one(T) - @test all(Array(R2) .== T(2)) - end +@testsuite "indexing reshaped wrappers" (AT, eltypes) -> begin + # Regression for #587: `Q[:] = …` where `Q = @view P[…]` of a GPU array used to + # throw a `MethodError` due to ambiguous `_unsafe_setindex!` on the + # `ReshapedArray{…,<:WrappedGPUArray}` that `_maybe_reshape` produces. + @testset "issue #587 with $T" for T in eltypes + @test compare(AT, ones(T, 16, 16, 16)) do P + active = (1:16) .< 12 + Q = @view P[:, :, active] + Q[:] = Q .+ one(T) + P end end - @testset "Permuted and Reinterpreted Views" for T in eltypes - @testset "Reshaped PermutedDims" begin - A = AT(ones(T, 4, 4)) - P = PermutedDimsArray(A, (2, 1)) - R = reshape(P, :) - @allowscalar begin - @test_nowarn R[1:2] .= zero(T) - # Check the full assigned range. - @test all(Array(R)[1:2] .== zero(T)) - end + @testset "reshape(view) with $T" for T in eltypes + @test compare(AT, ones(T, 4, 4, 4)) do A + R = reshape(view(A, 1:2, 1:2, 1:2), :) + R[:] = fill(T(2), length(R)) + A end + end - @testset "Reshaped Reinterpreted" begin - T_base = real(T) - if T <: Complex - A = AT(ones(T, 4, 4)) - IT = Complex{Int16} - R = reshape(reinterpret(IT, A), :) - @allowscalar begin - @test_nowarn R[1:2] .= zero(IT) - @test all(Array(R)[1:2] .== zero(IT)) - end - end + @testset "reshape(PermutedDimsArray) with $T" for T in eltypes + @test compare(AT, ones(T, 4, 4)) do A + R = reshape(PermutedDimsArray(A, (2, 1)), :) + R[1:2] = fill(zero(T), 2) + A end end - @testset "Data parity with compare() — GPU safe" for T in eltypes - idx = 2:4 - @test compare(AT, rand(T, 8, 8, 8)) do A - # compare() handles CPU/GPU execution no @allowscalar needed here - V = view(A, :, idx, :) - V .+= one(T) + @testset "reshape(reinterpret) with $T" for T in eltypes + T <: Complex || continue + IT = Complex{Int16} + @test compare(AT, ones(T, 4, 4)) do A + R = reshape(reinterpret(IT, A), :) + R[1:2] = fill(zero(IT), 2) A end end -end \ No newline at end of file +end From 022292765f07339e132d2ab8c0a2983b6b9dd45a Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 15 May 2026 14:39:01 +0200 Subject: [PATCH 3/3] Use isbits-friendly indices in #587 regression test. The original `BitVector` view caused `findall` to embed a `Vector{Int}` inside the SubArray's indices, which OpenCL's broadcast kernel rejects as non-isbits. A strided view exercises the same `IndexCartesian` + `_maybe_reshape` dispatch path with `StepRange` indices instead. Co-Authored-By: Claude Opus 4.7 --- test/testsuite/indexing.jl | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/test/testsuite/indexing.jl b/test/testsuite/indexing.jl index c6bb2e76..8a336291 100644 --- a/test/testsuite/indexing.jl +++ b/test/testsuite/indexing.jl @@ -286,14 +286,15 @@ end end @testsuite "indexing reshaped wrappers" (AT, eltypes) -> begin - # Regression for #587: `Q[:] = …` where `Q = @view P[…]` of a GPU array used to - # throw a `MethodError` due to ambiguous `_unsafe_setindex!` on the - # `ReshapedArray{…,<:WrappedGPUArray}` that `_maybe_reshape` produces. + # Regression for #587: `Q[:] = …` on an `IndexCartesian` SubArray of a GPU array + # threw a `MethodError` due to ambiguous `_unsafe_setindex!` on the + # `ReshapedArray{…,<:WrappedGPUArray}` that `_maybe_reshape` produces. A + # strided view triggers the same dispatch path as the original `BitVector` + # MWE without depending on non-isbits indices that some backends reject. @testset "issue #587 with $T" for T in eltypes - @test compare(AT, ones(T, 16, 16, 16)) do P - active = (1:16) .< 12 - Q = @view P[:, :, active] - Q[:] = Q .+ one(T) + @test compare(AT, ones(T, 8, 8, 8)) do P + Q = view(P, 1:2:7, :, :) + Q[:] = fill(T(2), length(Q)) P end end