From 3a86ed28d62b9907d2e0ac69b8f087bb3a946585 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 22 Apr 2026 14:10:31 +0200 Subject: [PATCH 1/7] Trim reduce/mapreduce corner-case loops. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The nested `isize, jsize, ksize ∈ 0:3` loops in the AccelerateKernels corner-case checks produced 64 shape combinations × 15 dim patterns = 960 `@test compare` invocations per testset (1920 between reduce and mapreduce). Testing with `(0, 3)` for each axis keeps coverage of empty (0) and non-singleton (3) dims while cutting the shape count 8×. The preceding size-10 loop already exercises the common non-edge shape. Co-Authored-By: Claude Opus 4.7 (1M context) --- test/testsuite/reductions.jl | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/test/testsuite/reductions.jl b/test/testsuite/reductions.jl index 7f135df6..b7587c9f 100644 --- a/test/testsuite/reductions.jl +++ b/test/testsuite/reductions.jl @@ -56,14 +56,11 @@ end end end # Test more corner cases. Tests from AcceleraterKernels.jl - for dims in [1,2,3,4,[1,2],[1,3],[1,4],[2,3],[2,4],[3,4],[1,2,3],[1,2,4],[1,3,4],[2,3,4],[1,2,3,4]] - for isize in 0:3 - for jsize in 0:3 - for ksize in 0:3 - @test compare(A->mapreduce(x->x+x, +, A; init=zero(Int32), dims), AT, rand(Int32(1):Int32(10), isize, jsize, ksize)) - end - end - end + # Cover empty (size 0) and non-singleton (size 3) axes; the size-10 loop above + # already covers the common non-edge shape. + for dims in [1,2,3,4,[1,2],[1,3],[1,4],[2,3],[2,4],[3,4],[1,2,3],[1,2,4],[1,3,4],[2,3,4],[1,2,3,4]], + isize in (0, 3), jsize in (0, 3), ksize in (0, 3) + @test compare(A->mapreduce(x->x+x, +, A; init=zero(Int32), dims), AT, rand(Int32(1):Int32(10), isize, jsize, ksize)) end end @@ -84,14 +81,11 @@ end end end # Test more corner cases. Tests from AcceleraterKernels.jl - for dims in [1,2,3,4,[1,2],[1,3],[1,4],[2,3],[2,4],[3,4],[1,2,3],[1,2,4],[1,3,4],[2,3,4],[1,2,3,4]] - for isize in 0:3 - for jsize in 0:3 - for ksize in 0:3 - @test compare(A->reduce(+, A; init=zero(Int32), dims), AT, rand(Int32(1):Int32(10), isize, jsize, ksize)) - end - end - end + # Cover empty (size 0) and non-singleton (size 3) axes; the size-10 loop above + # already covers the common non-edge shape. + for dims in [1,2,3,4,[1,2],[1,3],[1,4],[2,3],[2,4],[3,4],[1,2,3],[1,2,4],[1,3,4],[2,3,4],[1,2,3,4]], + isize in (0, 3), jsize in (0, 3), ksize in (0, 3) + @test compare(A->reduce(+, A; init=zero(Int32), dims), AT, rand(Int32(1):Int32(10), isize, jsize, ksize)) end end From d9a024841c92c7caa83340c5cf8e12a0bcbcc649 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 22 Apr 2026 14:13:27 +0200 Subject: [PATCH 2/7] Factor dim-independent reductions out of the dim loop. In sum/prod and minimum/maximum/extrema testsets, the whole-array variants (`sum(A)`, `minimum(A)`, ...) were called inside the `(sz, dims)` loop even though they only depend on `sz`. For the dim-heavy (10,10,10) cases that meant running the same reduction 8 times per eltype. Run the dim-independent variants once per unique shape. Also hoist the `ET <: Complex` check out of the min/max/extrema body. Co-Authored-By: Claude Opus 4.7 (1M context) --- test/testsuite/reductions.jl | 53 +++++++++++++++++++++--------------- 1 file changed, 31 insertions(+), 22 deletions(-) diff --git a/test/testsuite/reductions.jl b/test/testsuite/reductions.jl index b7587c9f..f663dbee 100644 --- a/test/testsuite/reductions.jl +++ b/test/testsuite/reductions.jl @@ -92,14 +92,11 @@ end @testsuite "reductions/sum prod" (AT, eltypes)->begin @testset "$ET" for ET in eltypes range = ET <: Real ? (ET(1):ET(10)) : ET - for (sz,dims) in [(10,)=>[1], (10,10)=>[1,2], (10,10,10)=>[1,2,3], (10,10,10)=>[], - (10,)=>:, (10,10)=>:, (10,10,10)=>:, - (10,10,10)=>[1], (10,10,10)=>[2], (10,10,10)=>[3], - (0,)=>[1]] + + # whole-array reductions: exercise each unique shape only once + for sz in ((10,), (10,10), (10,10,10), (0,)) @test compare(A->sum(A), AT, rand(range, sz)) - @test compare(A->sum(A; dims=dims), AT, rand(range, sz)) @test compare(A->prod(A), AT, rand(range, sz)) - @test compare(A->prod(A; dims=dims), AT, rand(range, sz)) if typeof(abs(rand(range))) in eltypes # abs(::Complex{Int}) promotes to Float64 @test compare(A->sum(abs, A), AT, rand(range, sz)) @@ -107,6 +104,15 @@ end end end + # reductions along specific dims + for (sz,dims) in [(10,)=>[1], (10,10)=>[1,2], (10,10,10)=>[1,2,3], (10,10,10)=>[], + (10,)=>:, (10,10)=>:, (10,10,10)=>:, + (10,10,10)=>[1], (10,10,10)=>[2], (10,10,10)=>[3], + (0,)=>[1]] + @test compare(A->sum(A; dims=dims), AT, rand(range, sz)) + @test compare(A->prod(A; dims=dims), AT, rand(range, sz)) + end + if ET in (Float32, Float64, Int64, ComplexF32, ComplexF64) # smaller-scale test to avoid very large values and roundoff issues for (sz,red) in [(2,)=>(1,), (2,2)=>(1,1), (2,2,2)=>(1,1,1), (2,2,2)=>(2,2,2), @@ -120,30 +126,33 @@ end @testsuite "reductions/minimum maximum extrema" (AT, eltypes)->begin @testset "$ET" for ET in eltypes + ET <: Complex && continue range = ET <: Real ? (ET(1):ET(10)) : ET + + # whole-array reductions: exercise each unique shape only once + for sz in ((10,), (10,10), (10,10,10)) + @test compare(A->minimum(A), AT, rand(range, sz)) + @test compare(A->minimum(x->x*x, A), AT, rand(range, sz)) + @test compare(A->maximum(A), AT, rand(range, sz)) + @test compare(A->maximum(x->x*x, A), AT, rand(range, sz)) + @test compare(A->extrema(A), AT, rand(range, sz)) + @test compare(A->extrema(x->x*x, A), AT, rand(range, sz)) + end + + # reductions along specific dims for (sz,dims) in [(10,)=>[1], (10,10)=>[1,2], (10,10,10)=>[1,2,3], (10,10,10)=>[], (10,)=>:, (10,10)=>:, (10,10,10)=>:, (10,10,10)=>[1], (10,10,10)=>[2], (10,10,10)=>[3]] - if !(ET <: Complex) - @test compare(A->minimum(A), AT, rand(range, sz)) - @test compare(A->minimum(x->x*x, A), AT, rand(range, sz)) - @test compare(A->minimum(A; dims=dims), AT, rand(range, sz)) - @test compare(A->maximum(A), AT, rand(range, sz)) - @test compare(A->maximum(x->x*x, A), AT, rand(range, sz)) - @test compare(A->maximum(A; dims=dims), AT, rand(range, sz)) - @test compare(A->extrema(A), AT, rand(range, sz)) - @test compare(A->extrema(x->x*x, A), AT, rand(range, sz)) - @test compare(A->extrema(A; dims=dims), AT, rand(range, sz)) - end + @test compare(A->minimum(A; dims=dims), AT, rand(range, sz)) + @test compare(A->maximum(A; dims=dims), AT, rand(range, sz)) + @test compare(A->extrema(A; dims=dims), AT, rand(range, sz)) end for (sz,red) in [(10,)=>(1,), (10,10)=>(1,1), (10,10,10)=>(1,1,1), (10,10,10)=>(10,10,10), (10,10,10)=>(1,10,10), (10,10,10)=>(10,1,10), (10,10,10)=>(10,10,1)] - if !(ET <: Complex) - @test compare((A,R)->minimum!(R, A), AT, rand(range, sz), fill(typemax(ET), red)) - @test compare((A,R)->maximum!(R, A), AT, rand(range, sz), fill(typemin(ET), red)) - @test compare((A,R)->extrema!(R, A), AT, rand(range, sz), fill((typemax(ET),typemin(ET)), red)) - end + @test compare((A,R)->minimum!(R, A), AT, rand(range, sz), fill(typemax(ET), red)) + @test compare((A,R)->maximum!(R, A), AT, rand(range, sz), fill(typemin(ET), red)) + @test compare((A,R)->extrema!(R, A), AT, rand(range, sz), fill((typemax(ET),typemin(ET)), red)) end end end From db47e4e29f452b43984697b3419b3e9a31a0575c Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 22 Apr 2026 14:17:36 +0200 Subject: [PATCH 3/7] Drop redundant size variants in random test. The rand/randn tests iterated `d in (2, (2,2), (2,2,2), 3, (3,3))` to cover shape variations. Sizes 2 vs 3 and (2,2) vs (3,3) exercise the same code paths, so keep one from each ndims group. Co-Authored-By: Claude Opus 4.7 (1M context) --- test/testsuite/random.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/testsuite/random.jl b/test/testsuite/random.jl index fcda79b5..2cd65a3c 100644 --- a/test/testsuite/random.jl +++ b/test/testsuite/random.jl @@ -6,7 +6,7 @@ end @testset "rand" begin # uniform - @testset "$T $d" for T in eltypes, d in (2, (2,2), (2,2,2), 3, (3,3)) + @testset "$T $d" for T in eltypes, d in (2, (2,2), (2,2,2)) A = AT{T}(undef, d) B = copy(A) rand!(rng, A) @@ -31,7 +31,7 @@ @testset "randn" begin # normally-distributed @testset "$T $d" for T in filter(isrealfloattype, eltypes), - d in (2, (2,2), (2,2,2), 3, (3,3)) + d in (2, (2,2), (2,2,2)) A = AT{T}(undef, d) B = copy(A) randn!(rng, A) From 7ffd5095896c49567ddc1a0bf90f80887b83cc85 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 22 Apr 2026 14:55:52 +0200 Subject: [PATCH 4/7] @nospecialize host-level forwarders and error renderers. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Skip type-specialization on arguments that the function only forwards or renders — none of these do per-type compute: - show/print_array/show_vector forwarders all immediately adapt to Array before deferring to Base's show machinery. - Serialization.serialize reads the array on the CPU side. - Reinterpret showerror methods just format a message. - contains_eltype walks Julia types recursively. - deepcopy_internal and the Tuple-source append! are one-liners. Kernel-path functions (broadcast, mapreduce, linalg) still fully specialize. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/host/abstractarray.jl | 19 ++++++++++--------- src/host/base.jl | 34 +++++++++++++++++++++------------- src/host/construction.jl | 2 +- 3 files changed, 32 insertions(+), 23 deletions(-) diff --git a/src/host/abstractarray.jl b/src/host/abstractarray.jl index 0080935e..15543e31 100644 --- a/src/host/abstractarray.jl +++ b/src/host/abstractarray.jl @@ -121,8 +121,8 @@ unsafe_free!(x::AbstractGPUArray) = unsafe_free!(storage(x)) using Serialization: AbstractSerializer, serialize_type -function Serialization.serialize(s::AbstractSerializer, t::T) where T <: AbstractGPUArray - serialize_type(s, T) +function Serialization.serialize(s::AbstractSerializer, @nospecialize(t::AbstractGPUArray)) + serialize_type(s, typeof(t)) serialize(s, Array(t)) end @@ -136,16 +136,17 @@ end struct ToArray end Adapt.adapt_storage(::ToArray, xs::AbstractGPUArray) = convert(Array, xs) -# display -Base.print_array(io::IO, X::AnyGPUArray) = +# display: show is called on the materialised CPU copy, so no need to +# specialize the forwarders per element type / wrapper. +Base.print_array(io::IO, @nospecialize(X::AnyGPUArray)) = Base.print_array(io, adapt(ToArray(), X)) # show -Base._show_nonempty(io::IO, X::AnyGPUArray, prefix::String) = +Base._show_nonempty(io::IO, @nospecialize(X::AnyGPUArray), prefix::String) = Base._show_nonempty(io, adapt(ToArray(), X), prefix) -Base._show_empty(io::IO, X::AnyGPUArray) = +Base._show_empty(io::IO, @nospecialize(X::AnyGPUArray)) = Base._show_empty(io, adapt(ToArray(), X)) -Base.show_vector(io::IO, v::AnyGPUArray, args...) = +Base.show_vector(io::IO, @nospecialize(v::AnyGPUArray), args...) = Base.show_vector(io, adapt(ToArray(), v), args...) ## collect to CPU (discarding wrapper type) @@ -324,7 +325,7 @@ end Base.copy(x::AbstractGPUArray) = error("Not implemented") # COV_EXCL_LINE -Base.deepcopy_internal(x::AbstractGPUArray, ::IdDict) = copy(x) +Base.deepcopy_internal(@nospecialize(x::AbstractGPUArray), ::IdDict) = copy(x) # filtering @@ -345,7 +346,7 @@ end # this is needed because copyto! of most GPU arrays # doesn't currently support Tuple sources -function Base.append!(a::AbstractGPUVector, items::Tuple) +function Base.append!(a::AbstractGPUVector, @nospecialize(items::Tuple)) append!(a, collect(items)) return a end diff --git a/src/host/base.jl b/src/host/base.jl index f929ea5a..0257ac17 100644 --- a/src/host/base.jl +++ b/src/host/base.jl @@ -234,32 +234,40 @@ function _reinterpret_exception(::Type{T}, a::AbstractArray{S,N}) where {T,S,N} end struct ReinterpretBitsTypeError{T,A} <: Exception end -function Base.showerror(io::IO, ::ReinterpretBitsTypeError{T, <:AbstractArray{S}}) where {T, S} - print(io, "cannot reinterpret an `$(S)` array to `$(T)`, because not all types are bitstypes") +function Base.showerror(io::IO, @nospecialize(err::ReinterpretBitsTypeError)) + T, A = typeof(err).parameters + S = eltype(A) + print(io, "cannot reinterpret an `$(S)` array to `$(T)`, because not all types are bitstypes") end struct ReinterpretZeroDimError{T,A} <: Exception end -function Base.showerror(io::IO, ::ReinterpretZeroDimError{T, <:AbstractArray{S,N}}) where {T, S, N} - print(io, "cannot reinterpret a zero-dimensional `$(S)` array to `$(T)` which is of a different size") +function Base.showerror(io::IO, @nospecialize(err::ReinterpretZeroDimError)) + T, A = typeof(err).parameters + S = eltype(A) + print(io, "cannot reinterpret a zero-dimensional `$(S)` array to `$(T)` which is of a different size") end struct ReinterpretDivisibilityError{T,A} <: Exception dim::Int end -function Base.showerror(io::IO, err::ReinterpretDivisibilityError{T, <:AbstractArray{S,N}}) where {T, S, N} - dim = err.dim - print(io, """ - cannot reinterpret an `$(S)` array to `$(T)` whose first dimension has size `$(dim)`. - The resulting array would have non-integral first dimension. - """) +function Base.showerror(io::IO, @nospecialize(err::ReinterpretDivisibilityError)) + T, A = typeof(err).parameters + S = eltype(A) + dim = err.dim + print(io, """ + cannot reinterpret an `$(S)` array to `$(T)` whose first dimension has size `$(dim)`. + The resulting array would have non-integral first dimension. + """) end struct ReinterpretFirstIndexError{T,A,Ax1} <: Exception ax1::Ax1 end -function Base.showerror(io::IO, err::ReinterpretFirstIndexError{T, <:AbstractArray{S,N}}) where {T, S, N} - ax1 = err.ax1 - print(io, "cannot reinterpret a `$(S)` array to `$(T)` when the first axis is $ax1. Try reshaping first.") +function Base.showerror(io::IO, @nospecialize(err::ReinterpretFirstIndexError)) + T, A, _ = typeof(err).parameters + S = eltype(A) + ax1 = err.ax1 + print(io, "cannot reinterpret a `$(S)` array to `$(T)` when the first axis is $ax1. Try reshaping first.") end diff --git a/src/host/construction.jl b/src/host/construction.jl index 454ff3d9..2025a97b 100644 --- a/src/host/construction.jl +++ b/src/host/construction.jl @@ -88,7 +88,7 @@ function hasfieldcount(@nospecialize(dt)) end # for finding specific element types, e.g., when Float64 is unsupported -function contains_eltype(T, typ) +function contains_eltype(@nospecialize(T), @nospecialize(typ)) if T === typ return true elseif T isa Union From c26ac7cabbf23c9f6e034e0dbbe93bd0fbd9eed9 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 22 Apr 2026 15:04:00 +0200 Subject: [PATCH 5/7] Use map over the kernel args tuple instead of broadcast. \`jlconvert.(args)\` dispatches through \`Base.Broadcast.broadcasted\` and its helpers, which specialize per unique arg-tuple type: for each kernel-launch signature, the trace shows ~200 extra precompile events (broadcasted, ntuple, and a duplicate jlconvert specialization). \`map(jlconvert, args)\` lands at the same result with a much thinner specialization chain. On the broadcasting testsuite (JLArray), this drops total trace-compile events ~10 % and wall-clock from 1m50s to 1m31s. Co-Authored-By: Claude Opus 4.7 (1M context) --- lib/JLArrays/src/JLArrays.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lib/JLArrays/src/JLArrays.jl b/lib/JLArrays/src/JLArrays.jl index 3557b63c..bf671ff2 100644 --- a/lib/JLArrays/src/JLArrays.jl +++ b/lib/JLArrays/src/JLArrays.jl @@ -615,7 +615,10 @@ end function (obj::Kernel{JLBackend})(args...; ndrange=nothing, workgroupsize=nothing) ndrange, workgroupsize, _, _ = launch_config(obj, ndrange, workgroupsize) - device_args = jlconvert.(args) + # Use `map` rather than `jlconvert.(args)` to skip the broadcast + # machinery (broadcasted/materialize/ntuple) that would otherwise + # be specialized per unique arg-tuple type. + device_args = map(jlconvert, args) new_obj = convert_to_cpu(obj) new_obj(device_args...; ndrange, workgroupsize) end From d9be81e629b44ec0c4c2dd2a9c5845f4060577f9 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 22 Apr 2026 15:11:14 +0200 Subject: [PATCH 6/7] Revert @nospecialize on reinterpret showerror methods. The trace showed these hit a handful of unique error types at most, so the change saves a negligible amount of specialization. Meanwhile it required replacing the clean type-destructuring signature with a runtime \`typeof(err).parameters\` lookup. Not worth the uglier body. The remaining @nospecialize additions (print_array/_show_*, Serialization, contains_eltype, deepcopy_internal, append!) stay: each is a forwarder or type-level helper where specialization gives nothing useful, matching Julia Base's pattern for print_matrix_row / hasfieldcount / boot.jl constructors. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/host/base.jl | 34 +++++++++++++--------------------- 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/src/host/base.jl b/src/host/base.jl index 0257ac17..f929ea5a 100644 --- a/src/host/base.jl +++ b/src/host/base.jl @@ -234,40 +234,32 @@ function _reinterpret_exception(::Type{T}, a::AbstractArray{S,N}) where {T,S,N} end struct ReinterpretBitsTypeError{T,A} <: Exception end -function Base.showerror(io::IO, @nospecialize(err::ReinterpretBitsTypeError)) - T, A = typeof(err).parameters - S = eltype(A) - print(io, "cannot reinterpret an `$(S)` array to `$(T)`, because not all types are bitstypes") +function Base.showerror(io::IO, ::ReinterpretBitsTypeError{T, <:AbstractArray{S}}) where {T, S} + print(io, "cannot reinterpret an `$(S)` array to `$(T)`, because not all types are bitstypes") end struct ReinterpretZeroDimError{T,A} <: Exception end -function Base.showerror(io::IO, @nospecialize(err::ReinterpretZeroDimError)) - T, A = typeof(err).parameters - S = eltype(A) - print(io, "cannot reinterpret a zero-dimensional `$(S)` array to `$(T)` which is of a different size") +function Base.showerror(io::IO, ::ReinterpretZeroDimError{T, <:AbstractArray{S,N}}) where {T, S, N} + print(io, "cannot reinterpret a zero-dimensional `$(S)` array to `$(T)` which is of a different size") end struct ReinterpretDivisibilityError{T,A} <: Exception dim::Int end -function Base.showerror(io::IO, @nospecialize(err::ReinterpretDivisibilityError)) - T, A = typeof(err).parameters - S = eltype(A) - dim = err.dim - print(io, """ - cannot reinterpret an `$(S)` array to `$(T)` whose first dimension has size `$(dim)`. - The resulting array would have non-integral first dimension. - """) +function Base.showerror(io::IO, err::ReinterpretDivisibilityError{T, <:AbstractArray{S,N}}) where {T, S, N} + dim = err.dim + print(io, """ + cannot reinterpret an `$(S)` array to `$(T)` whose first dimension has size `$(dim)`. + The resulting array would have non-integral first dimension. + """) end struct ReinterpretFirstIndexError{T,A,Ax1} <: Exception ax1::Ax1 end -function Base.showerror(io::IO, @nospecialize(err::ReinterpretFirstIndexError)) - T, A, _ = typeof(err).parameters - S = eltype(A) - ax1 = err.ax1 - print(io, "cannot reinterpret a `$(S)` array to `$(T)` when the first axis is $ax1. Try reshaping first.") +function Base.showerror(io::IO, err::ReinterpretFirstIndexError{T, <:AbstractArray{S,N}}) where {T, S, N} + ax1 = err.ax1 + print(io, "cannot reinterpret a `$(S)` array to `$(T)` when the first axis is $ax1. Try reshaping first.") end From 68cc95e9eff6b5d5f45e3bcb70c1e50ae611a263 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 22 Apr 2026 15:32:28 +0200 Subject: [PATCH 7/7] @nospecialize compare and test_result in the testsuite. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit \`compare\` was previously the #1 testsuite-side compile hotspot (322 precompile events in the broadcasting trace) because every unique closure + argument-tuple combination created a fresh specialization — and the function just forwards those to \`deepcopy\`/\`adapt\` and then to a test-level \`≈\`. Nothing inside benefits from per-call specialization. Mark \`f\` and the \`xs\` varargs \`@nospecialize\` on both compare methods; events drop to zero. \`test_result\` had the same issue: 48 precompile events, one per (eltype, ndims) combination, because the original \`where T<:Number\` / \`where T<:NTuple\` dispatch still instantiated per T even under \`@nospecialize\`. Collapse the two \`AbstractArray\` methods into a single unparameterised method that branches on \`eltype(a)\` at runtime; an \`eltype(b) === T\` guard preserves the old "diverging eltypes fall through to \`a == b\`" semantics. One event instead of 48. Broadcasting testset trace events drop ~18 % overall (3750 → 3561 entries in the broadcasting testsuite). Co-Authored-By: Claude Opus 4.7 (1M context) --- test/testsuite.jl | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/test/testsuite.jl b/test/testsuite.jl index 73db8703..e4aea4f7 100644 --- a/test/testsuite.jl +++ b/test/testsuite.jl @@ -16,26 +16,36 @@ using Test using Adapt -test_result(a, b; kwargs...) = a == b +test_result(@nospecialize(a), @nospecialize(b); kwargs...) = a == b test_result(a::Number, b::Number; kwargs...) = ≈(a, b; kwargs...) test_result(a::Missing, b::Missing; kwargs...) = true test_result(a::Number, b::Missing; kwargs...) = false test_result(a::Missing, b::Number; kwargs...) = false -function test_result(a::AbstractArray{T}, b::AbstractArray{T}; kwargs...) where {T<:Number} - ≈(collect(a), collect(b); kwargs...) -end -function test_result(a::AbstractArray{T}, b::AbstractArray{T}; - kwargs...) where {T<:NTuple{N,<:Number} where {N}} - ET = eltype(T) - ≈(reinterpret(ET, collect(a)), reinterpret(ET, collect(b)); kwargs...) +# Branch on eltype at runtime so one compiled method body handles every +# (T, ndims) combination — the `where T` version would still instantiate +# per element type even under @nospecialize. +function test_result(@nospecialize(a::AbstractArray), @nospecialize(b::AbstractArray); kwargs...) + T = eltype(a) + # The original `where T<:…` methods required matching eltypes; preserve + # that by falling through to `a == b` when they diverge. + if eltype(b) === T + if T <: Number + return ≈(collect(a), collect(b); kwargs...) + elseif T <: NTuple{N,<:Number} where {N} + ET = eltype(T) + return ≈(reinterpret(ET, collect(a)), reinterpret(ET, collect(b)); kwargs...) + end + end + a == b end -function test_result(as::NTuple{N,Any}, bs::NTuple{N,Any}; kwargs...) where {N} +function test_result(@nospecialize(as::Tuple), @nospecialize(bs::Tuple); kwargs...) + length(as) == length(bs) || return false all(zip(as, bs)) do (a, b) test_result(a, b; kwargs...) end end -function compare(f, AT::Type{<:AbstractGPUArray}, xs...; kwargs...) +function compare(@nospecialize(f), AT::Type{<:AbstractGPUArray}, @nospecialize(xs...); kwargs...) # copy on the CPU, adapt on the GPU, but keep Ref's cpu_in = map(x -> isa(x, Base.RefValue) ? x[] : deepcopy(x), xs) gpu_in = map(x -> isa(x, Base.RefValue) ? x[] : adapt(AT, x), xs) @@ -46,7 +56,7 @@ function compare(f, AT::Type{<:AbstractGPUArray}, xs...; kwargs...) test_result(cpu_out, gpu_out; kwargs...) end -function compare(f, AT::Type{<:Array}, xs...; kwargs...) +function compare(@nospecialize(f), AT::Type{<:Array}, @nospecialize(xs...); kwargs...) # no need to actually run this tests: we have nothing to compare against, # and we'll run it on a CPU array anyhow when comparing to a GPU array. #