From 2ac41d2a525cebaacd3ceebdcd9308a2da02fc20 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Wed, 6 May 2026 05:28:05 +0000 Subject: [PATCH 1/2] broadcast: split ComposedFunction on AbstractGPUArrayStyle MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `(f ∘ g).(args...)` is rewritten to `f.(g.(args...))` for any GPU broadcast style. Semantics are unchanged — broadcast fusion produces the same per-element call — but the kernel closure no longer carries a `ComposedFunction` value, so its `(c)(args...; kw...)` entry never has to be resolved by GPUCompiler. That kwsorter dispatch is what fails: with a non-trivial inner body (e.g. `NNlib.tanh_fast`'s polynomial branch with `sign`/`ifelse`) GPUCompiler cannot statically resolve `var"#_#NN"(kw, c, args...)` and emits `InvalidIRError: unsupported dynamic function invocation`. Reproducer (pre-fix, CUDA + NNlib): a = CUDA.rand(Float32, 5); b = CUDA.rand(Float32, 5) broadcast(NNlib.tanh_fast ∘ (+), a, b) # InvalidIRError broadcast(tanh ∘ (+), a, b) # OK NNlib.tanh_fast.(a .+ b) # OK (split form) The same shape on Metal was patched in NNlib v0.9.32 with a per-function `@device_override` for `tanh_fast`. Doing it at the broadcast layer here covers every `f ∘ g` (and arbitrarily nested compositions, since the rewritten call re-dispatches to this method whenever `c.outer` or `c.inner` is itself a `ComposedFunction`) for every GPU backend, instead of chasing fast activations one by one. Verified: - 920/920 tests pass on JLArray and Array backends. - Original `tanh_fast ∘ (+)` MWE compiles and matches CPU on H100 (CUDA 13.2, NNlib 0.9.34, Julia 1.11.9). 
--- src/host/broadcast.jl | 10 ++++++++++ test/testsuite/broadcasting.jl | 23 +++++++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/src/host/broadcast.jl b/src/host/broadcast.jl index e5cbaffe..92e58f5a 100644 --- a/src/host/broadcast.jl +++ b/src/host/broadcast.jl @@ -4,6 +4,16 @@ using Base.Broadcast using Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle, instantiate +# Rewrite `(f ∘ g).(args...)` as `f.(g.(args...))` whenever the broadcast style is +# a GPU style. The fused `Broadcasted` tree is identical in semantics, but the +# kernel closure no longer has to call `ComposedFunction`'s `(c)(args...; kw...)` +# entry — which forces a kwsorter dispatch GPUCompiler cannot resolve statically +# when the inner function has a non-trivial body (e.g. `NNlib.tanh_fast`'s +# polynomial+sign branch produces an `InvalidIRError` on CUDA). The CPU +# broadcast path is unaffected because dispatch is gated on `AbstractGPUArrayStyle`. +@inline Broadcast.broadcasted(S::AbstractGPUArrayStyle, c::ComposedFunction, args...) 
= + Broadcast.broadcasted(S, c.outer, Broadcast.broadcasted(S, c.inner, args...)) + # but make sure we don't dispatch to the optimized copy method that directly indexes function Broadcast.copy(bc::Broadcasted{<:AbstractGPUArrayStyle{0}}) ElType = Broadcast.combine_eltypes(bc.f, bc.args) diff --git a/test/testsuite/broadcasting.jl b/test/testsuite/broadcasting.jl index df7cb013..81e0a705 100644 --- a/test/testsuite/broadcasting.jl +++ b/test/testsuite/broadcasting.jl @@ -2,6 +2,7 @@ broadcasting(AT, eltypes) vec3(AT, eltypes) unknown_wrapper(AT, eltypes) + composed_function(AT, eltypes) end test_idx(idx, A::AbstractArray{T}) where T = A[idx] * T(2) @@ -228,3 +229,25 @@ function unknown_wrapper(AT, eltypes) end end end + +# `(f ∘ g).(args...)` is rewritten to `f.(g.(args...))` for AbstractGPUArrayStyle so +# that the kernel closure never carries a `ComposedFunction` value — the latter +# forces a kwsorter dispatch that GPUCompiler cannot resolve when the inner body +# has non-trivial control flow (e.g. NNlib.tanh_fast). On CPU this is a no-op. +function composed_function(AT, eltypes) + sq(x) = x*x + for ET in eltypes + @testset "ComposedFunction $ET" begin + a = AT(rand(ET, 8)) + b = AT(rand(ET, 8)) + ca, cb = Array(a), Array(b) + + @test Array(broadcast(sq ∘ (+), a, b)) ≈ (ca .+ cb).^2 + @test Array((sq ∘ (+)).(a, b)) ≈ (ca .+ cb).^2 + @test Array((sq ∘ sq ∘ (+)).(a, b)) ≈ ((ca .+ cb).^2).^2 + @test Array((sq ∘ identity).(a)) ≈ ca.^2 + @test Array((sq ∘ (+)).(a, Ref(ET(2)))) ≈ (ca .+ ET(2)).^2 + @test Array((identity ∘ (-)).(a, b)) ≈ ca .- cb + end + end +end From e7be961e896c6aa0174e6bddd403faf7a6e13a94 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Thu, 14 May 2026 13:22:40 +0200 Subject: [PATCH 2/2] minimize comment. 
[ci skip] --- src/host/broadcast.jl | 9 ++------- test/testsuite/broadcasting.jl | 4 ---- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/src/host/broadcast.jl b/src/host/broadcast.jl index 92e58f5a..85d9ecdc 100644 --- a/src/host/broadcast.jl +++ b/src/host/broadcast.jl @@ -4,13 +4,8 @@ using Base.Broadcast using Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle, instantiate -# Rewrite `(f ∘ g).(args...)` as `f.(g.(args...))` whenever the broadcast style is -# a GPU style. The fused `Broadcasted` tree is identical in semantics, but the -# kernel closure no longer has to call `ComposedFunction`'s `(c)(args...; kw...)` -# entry — which forces a kwsorter dispatch GPUCompiler cannot resolve statically -# when the inner function has a non-trivial body (e.g. `NNlib.tanh_fast`'s -# polynomial+sign branch produces an `InvalidIRError` on CUDA). The CPU -# broadcast path is unaffected because dispatch is gated on `AbstractGPUArrayStyle`. +# split `ComposedFunction` so the kernel closure doesn't hit its kwarg-accepting +# call, whose kwsorter dispatch GPUCompiler can't resolve statically. @inline Broadcast.broadcasted(S::AbstractGPUArrayStyle, c::ComposedFunction, args...) = Broadcast.broadcasted(S, c.outer, Broadcast.broadcasted(S, c.inner, args...)) diff --git a/test/testsuite/broadcasting.jl b/test/testsuite/broadcasting.jl index 81e0a705..e2a2d994 100644 --- a/test/testsuite/broadcasting.jl +++ b/test/testsuite/broadcasting.jl @@ -230,10 +230,6 @@ function unknown_wrapper(AT, eltypes) end end -# `(f ∘ g).(args...)` is rewritten to `f.(g.(args...))` for AbstractGPUArrayStyle so -# that the kernel closure never carries a `ComposedFunction` value — the latter -# forces a kwsorter dispatch that GPUCompiler cannot resolve when the inner body -# has non-trivial control flow (e.g. NNlib.tanh_fast). On CPU this is a no-op. function composed_function(AT, eltypes) sq(x) = x*x for ET in eltypes