@@ -21,6 +21,7 @@ using CUDA
2121using CUDA: AS
2222using BenchmarkTools
2323using PythonCall
24+ using Lux, Zygote
2425
2526# -------------------------------------------------------------------------
2627# Hardcoded CUDA.jl path
@@ -146,18 +147,25 @@ function _gemm_simt!(C::CuArray{Float32,2}, transA::Char, A::CuArray{Float32,2},
146147 ldb = max (1 , stride (B, 2 ))
147148 ldc = max (1 , stride (C, 2 ))
148149 # CUDA.jl puts the cuBLAS handle in CUBLAS_POINTER_MODE_DEVICE, so alpha/beta
149- # MUST be device pointers. Passing host Ref{Float32} causes UVA fault handling
150- # per kernel launch (~100× slowdown but eventually-correct values).
150+ # MUST be device pointers (host Ref triggers UVA fault handling — 100× slowdown).
151151 α = CUDA. CuRef {Float32} (alpha); β = CUDA. CuRef {Float32} (beta)
152- CUDA. CUBLAS. cublasGemmEx (
153- CUDA. CUBLAS. handle (),
154- transA, transB, m, n, k,
155- α, A, Float32, lda,
156- B, Float32, ldb,
157- β, C, Float32, ldc,
158- CUDA. CUBLAS. CUBLAS_COMPUTE_32F,
159- CUDA. CUBLAS. CUBLAS_GEMM_DEFAULT,
160- )
152+ h = CUDA. CUBLAS. handle ()
153+ # Under FAST_MATH the handle's math mode is CUBLAS_TF32_TENSOR_OP_MATH, which
154+ # forces TF32 tensor cores even when we ask for CUBLAS_COMPUTE_32F. Flip it to
155+ # DEFAULT_MATH for this call so cuBLAS picks a SIMT FP32 kernel.
156+ CUDA. CUBLAS. math_mode! (h, CUDA. DEFAULT_MATH)
157+ try
158+ CUDA. CUBLAS. cublasGemmEx (
159+ h, transA, transB, m, n, k,
160+ α, A, Float32, lda,
161+ B, Float32, ldb,
162+ β, C, Float32, ldc,
163+ CUDA. CUBLAS. CUBLAS_COMPUTE_32F,
164+ CUDA. CUBLAS. CUBLAS_GEMM_DEFAULT,
165+ )
166+ finally
167+ CUDA. CUBLAS. math_mode! (h, CUDA. math_mode ()) # restore (FAST_MATH → TF32 tensor op)
168+ end
161169 return C
162170end
163171
@@ -185,6 +193,38 @@ function reverse_diff_v5(W1, W2, X, y)
185193 return result
186194end
187195
196+ # -------------------------------------------------------------------------
197+ # Lux + Zygote path
198+ #
199+ # Builds an equivalent 2-layer MLP `Y = W2 * tanh(W1 * X)` (no bias) using
200+ # Lux, plugs in the *same* CuArray weights so the gradient is comparable,
201+ # and lets Zygote source-to-source the backward. This goes through the same
202+ # CUDA.jl + cuBLAS stack as `reverse_diff`, so we expect similar kernels —
203+ # the interesting thing is the AD/dispatch overhead Lux+Zygote add on top.
204+ # -------------------------------------------------------------------------
205+ function build_lux (W1g:: CuArray{Float32,2} , W2g:: CuArray{Float32,2} )
206+ h, d = size (W1g)
207+ model = Lux. Chain (
208+ Lux. Dense (d => h, tanh; use_bias = false ),
209+ Lux. Dense (h => 1 , identity; use_bias = false ),
210+ )
211+ ps = (
212+ layer_1 = (weight = W1g,),
213+ layer_2 = (weight = W2g,),
214+ )
215+ st = Lux. initialstates (Random. default_rng (), model)
216+ return model, ps, st
217+ end
218+
219+ function lux_grad (model, ps, st, Xg:: CuArray , yg:: CuArray )
220+ function loss_fn (p)
221+ y_hat, _ = model (Xg, p, st)
222+ return sum ((y_hat .- yg) .^ 2 ) / size (yg, 2 )
223+ end
224+ ∂ps = first (Zygote. gradient (loss_fn, ps))
225+ return ∂ps. layer_1. weight
226+ end
227+
188228# -------------------------------------------------------------------------
189229# PyTorch path
190230# -------------------------------------------------------------------------
@@ -314,6 +354,17 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3)
314354 grad_julia_v5 = Array (reverse_diff_v5 (W1g, W2g, Xg, yg))
315355 CUDA. synchronize ()
316356
357+ # Lux + Zygote warmup (first call compiles Zygote's pullback for this shape)
358+ print (" Lux+Zygote compile warmup for h=$h ... " ); flush (stdout )
359+ lux_model, lux_ps, lux_st = build_lux (W1g, W2g)
360+ t_lux_compile = @elapsed begin
361+ lux_grad (lux_model, lux_ps, lux_st, Xg, yg)
362+ CUDA. synchronize ()
363+ end
364+ @printf " %.2f s\n " t_lux_compile
365+ grad_lux = Array (lux_grad (lux_model, lux_ps, lux_st, Xg, yg))
366+ CUDA. synchronize ()
367+
317368 # ----- PyTorch -----
318369 W1t, W2t, Xt, yt = build_torch_tensors (W1, W2, X, y)
319370 grad_pytorch_eager = torch_to_julia (pytorch_grad_eager (W1t, W2t, Xt, yt))
@@ -333,6 +384,7 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3)
333384 # ----- Numerical equivalence -----
334385 for (name, g) in [(" Julia v4 (vec=4) " , grad_julia_v4),
335386 (" Julia v5 (vec=4+SIMT)" , grad_julia_v5),
387+ (" Lux + Zygote " , grad_lux),
336388 (" PyTorch eager " , grad_pytorch_eager),
337389 (" PyTorch compiled " , grad_pytorch_compiled)]
338390 maxdiff = maximum (abs .(grad_julia .- g))
@@ -357,6 +409,10 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3)
357409 reverse_diff_v5 ($ W1g, $ W2g, $ Xg, $ yg)
358410 CUDA. synchronize ()
359411 end samples= 30 evals= 1 seconds= 10
412+ bjlux = @benchmark begin
413+ lux_grad ($ lux_model, $ lux_ps, $ lux_st, $ Xg, $ yg)
414+ CUDA. synchronize ()
415+ end samples= 30 evals= 1 seconds= 10
360416 be = @benchmark begin
361417 pytorch_grad_eager ($ W1t, $ W2t, $ Xt, $ yt)
362418 $ torch. cuda. synchronize ()
@@ -368,6 +424,7 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3)
368424 @printf " Julia broadcast : median %8.3f µs\n " 1e-3 * median (bj). time
369425 @printf " Julia vec=4 : median %8.3f µs\n " 1e-3 * median (bj4). time
370426 @printf " Julia vec=4 + SIMT : median %8.3f µs\n " 1e-3 * median (bj5). time
427+ @printf " Lux + Zygote : median %8.3f µs\n " 1e-3 * median (bjlux). time
371428 @printf " PyTorch eager : median %8.3f µs\n " 1e-3 * median (be). time
372429 @printf " PyTorch compiled : median %8.3f µs\n " 1e-3 * median (bc). time
373430
@@ -381,6 +438,9 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3)
381438 println (" \n --- CUDA trace: Julia vec=4 + SIMT ---" )
382439 summarize_julia_trace (stdout , julia_trace (() -> reverse_diff_v5 (W1g, W2g, Xg, yg)))
383440
441+ println (" \n --- CUDA trace: Lux + Zygote ---" )
442+ summarize_julia_trace (stdout , julia_trace (() -> lux_grad (lux_model, lux_ps, lux_st, Xg, yg)))
443+
384444 println (" \n --- CUDA trace: PyTorch eager ---" )
385445 println (pytorch_trace (() -> pytorch_grad_eager (W1t, W2t, Xt, yt)))
386446
0 commit comments