diff --git a/src/host/broadcast.jl b/src/host/broadcast.jl
index e5cbaffe..85d9ecdc 100644
--- a/src/host/broadcast.jl
+++ b/src/host/broadcast.jl
@@ -4,6 +4,18 @@ using Base.Broadcast
 
 using Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle, instantiate
 
+# Broadcasting a `ComposedFunction` (`outer ∘ inner`) is rewritten as a nested
+# broadcast: `inner` is broadcast over the arguments first, and `outer` is then
+# broadcast over the result of that call. This keeps the kernel closure from
+# reaching the kwarg-accepting call of `ComposedFunction`, whose kwsorter
+# dispatch GPUCompiler can't resolve statically.
+@inline function Broadcast.broadcasted(
+    S::AbstractGPUArrayStyle, c::ComposedFunction, args...
+)
+    inner = Broadcast.broadcasted(S, c.inner, args...)
+    return Broadcast.broadcasted(S, c.outer, inner)
+end
+
 # but make sure we don't dispatch to the optimized copy method that directly indexes
 function Broadcast.copy(bc::Broadcasted{<:AbstractGPUArrayStyle{0}})
     ElType = Broadcast.combine_eltypes(bc.f, bc.args)
diff --git a/test/testsuite/broadcasting.jl b/test/testsuite/broadcasting.jl
index df7cb013..e2a2d994 100644
--- a/test/testsuite/broadcasting.jl
+++ b/test/testsuite/broadcasting.jl
@@ -2,6 +2,7 @@
     broadcasting(AT, eltypes)
     vec3(AT, eltypes)
     unknown_wrapper(AT, eltypes)
+    composed_function(AT, eltypes)
 end
 
 test_idx(idx, A::AbstractArray{T}) where T = A[idx] * T(2)
@@ -228,3 +229,27 @@ function unknown_wrapper(AT, eltypes)
         end
     end
 end
+
+# Broadcasting composed functions (`f ∘ g`) over arrays of type `AT` must agree
+# with the same computation carried out on their `Array` copies.
+function composed_function(AT, eltypes)
+    square(v) = v * v
+    @testset "ComposedFunction $ET" for ET in eltypes
+        x = AT(rand(ET, 8))
+        y = AT(rand(ET, 8))
+        hx = Array(x)
+        hy = Array(y)
+
+        # explicit `broadcast` call, then the equivalent dot-call syntax
+        @test Array(broadcast(square ∘ (+), x, y)) ≈ (hx .+ hy).^2
+        @test Array((square ∘ (+)).(x, y)) ≈ (hx .+ hy).^2
+        # chain of three composed functions
+        @test Array((square ∘ square ∘ (+)).(x, y)) ≈ ((hx .+ hy).^2).^2
+        # single-argument inner function
+        @test Array((square ∘ identity).(x)) ≈ hx.^2
+        # scalar argument wrapped in `Ref`
+        @test Array((square ∘ (+)).(x, Ref(ET(2)))) ≈ (hx .+ ET(2)).^2
+        # `identity` as the outer function
+        @test Array((identity ∘ (-)).(x, y)) ≈ hx .- hy
+    end
+end