@@ -21,7 +21,8 @@ using CUDA
2121using CUDA: AS
2222using BenchmarkTools
2323using PythonCall
24- using Lux, Zygote
24+ using Lux
25+ import Mooncake
2526
2627# -------------------------------------------------------------------------
2728# Hardcoded CUDA.jl path
@@ -194,15 +195,182 @@ function reverse_diff_v5(W1, W2, X, y)
194195end
195196
196197# -------------------------------------------------------------------------
197- # Lux + Zygote path
198+ # v6: vec=4 elementwise + cuBLASLt with per-shape heuristic-picked algo
199+ #
200+ # cuBLAS's standard heuristic, even with CUBLAS_COMPUTE_32F + DEFAULT_MATH,
201+ # picks `cutlass_80_simt_sgemm_*` for our awkward shapes. PyTorch's process
202+ # happens to land on `magma_sgemmEx_kernel` for the same compute type — same
203+ # library, different choice. cuBLASLt exposes a fuller heuristic API with a
204+ # workspace budget that often unlocks better algos. We build a matmul
205+ # descriptor + layouts per (transA, transB, m, n, k) shape, ask cuBLASLt for
206+ # its best algo, and reuse the cached plan on every call.
207+ # -------------------------------------------------------------------------
208+ const _LT_WS_BYTES = Csize_t (32 * 1024 * 1024 ) # 32 MiB workspace
209+
210+ # Lazy: created on first use, kept alive for the process.
211+ const _LT_STATE = Ref {Any} (nothing )
212+
213+ function _lt_state ()
214+ s = _LT_STATE[]
215+ if s === nothing
216+ h_ref = Ref {CUDA.CUBLAS.cublasLtHandle_t} (C_NULL )
217+ CUDA. CUBLAS. cublasLtCreate (h_ref)
218+ ws = CUDA. CuArray {UInt8} (undef, Int (_LT_WS_BYTES))
219+ s = (handle = h_ref[], ws = ws)
220+ _LT_STATE[] = s
221+ end
222+ return s:: NamedTuple{(:handle, :ws)}
223+ end
224+
225+ mutable struct LtPlan
226+ desc:: CUDA.CUBLAS.cublasLtMatmulDesc_t
227+ Adesc:: CUDA.CUBLAS.cublasLtMatrixLayout_t
228+ Bdesc:: CUDA.CUBLAS.cublasLtMatrixLayout_t
229+ Cdesc:: CUDA.CUBLAS.cublasLtMatrixLayout_t
230+ algo:: CUDA.CUBLAS.cublasLtMatmulAlgo_t
231+ end
232+
233+ function _build_lt_plan (transA:: Char , transB:: Char ,
234+ m:: Int , n:: Int , k:: Int ,
235+ lda:: Int , ldb:: Int , ldc:: Int )
236+ state = _lt_state ()
237+ handle = state. handle
238+ R32 = CUDA. CUDACore. R_32F # cudaDataType for Float32
239+
240+ desc_ref = Ref {CUDA.CUBLAS.cublasLtMatmulDesc_t} (C_NULL )
241+ CUDA. CUBLAS. cublasLtMatmulDescCreate (desc_ref, CUDA. CUBLAS. CUBLAS_COMPUTE_32F, R32)
242+ desc = desc_ref[]
243+
244+ # Set transpose attributes.
245+ tA = (transA == ' N' ) ? CUDA. CUBLAS. CUBLAS_OP_N : CUDA. CUBLAS. CUBLAS_OP_T
246+ tB = (transB == ' N' ) ? CUDA. CUBLAS. CUBLAS_OP_N : CUDA. CUBLAS. CUBLAS_OP_T
247+ let r = Ref (tA)
248+ CUDA. CUBLAS. cublasLtMatmulDescSetAttribute (
249+ desc, CUDA. CUBLAS. CUBLASLT_MATMUL_DESC_TRANSA, r, sizeof (tA))
250+ end
251+ let r = Ref (tB)
252+ CUDA. CUBLAS. cublasLtMatmulDescSetAttribute (
253+ desc, CUDA. CUBLAS. CUBLASLT_MATMUL_DESC_TRANSB, r, sizeof (tB))
254+ end
255+
256+ # Layout shape is the *storage* shape (pre-transpose).
257+ Arows = transA == ' N' ? m : k
258+ Acols = transA == ' N' ? k : m
259+ Brows = transB == ' N' ? k : n
260+ Bcols = transB == ' N' ? n : k
261+
262+ Aref = Ref {CUDA.CUBLAS.cublasLtMatrixLayout_t} (C_NULL )
263+ Bref = Ref {CUDA.CUBLAS.cublasLtMatrixLayout_t} (C_NULL )
264+ Cref = Ref {CUDA.CUBLAS.cublasLtMatrixLayout_t} (C_NULL )
265+ CUDA. CUBLAS. cublasLtMatrixLayoutCreate (Aref, R32, UInt64 (Arows), UInt64 (Acols), Int64 (lda))
266+ CUDA. CUBLAS. cublasLtMatrixLayoutCreate (Bref, R32, UInt64 (Brows), UInt64 (Bcols), Int64 (ldb))
267+ CUDA. CUBLAS. cublasLtMatrixLayoutCreate (Cref, R32, UInt64 (m), UInt64 (n), Int64 (ldc))
268+
269+ # Preference: tell the heuristic how much workspace it can use.
270+ pref_ref = Ref {CUDA.CUBLAS.cublasLtMatmulPreference_t} (C_NULL )
271+ CUDA. CUBLAS. cublasLtMatmulPreferenceCreate (pref_ref)
272+ pref = pref_ref[]
273+ let r = Ref (_LT_WS_BYTES)
274+ CUDA. CUBLAS. cublasLtMatmulPreferenceSetAttribute (
275+ pref, CUDA. CUBLAS. CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
276+ r, sizeof (_LT_WS_BYTES))
277+ end
278+
279+ # Heuristic: top-1 algorithm.
280+ heur = Vector {CUDA.CUBLAS.cublasLtMatmulHeuristicResult_t} (undef, 1 )
281+ returned = Ref {Cint} (0 )
282+ CUDA. CUBLAS. cublasLtMatmulAlgoGetHeuristic (
283+ handle, desc, Aref[], Bref[], Cref[], Cref[],
284+ pref, Cint (1 ), heur, returned)
285+ returned[] < 1 && error (" cuBLASLt has no algo for shape (m=$m ,n=$n ,k=$k ,trans=$transA$transB )" )
286+
287+ return LtPlan (desc, Aref[], Bref[], Cref[], heur[1 ]. algo)
288+ end
289+
290+ function _gemm_lt! (plan:: LtPlan ,
291+ C:: CuArray{Float32,2} , A:: CuArray{Float32,2} , B:: CuArray{Float32,2} ;
292+ alpha:: Float32 = 1f0 , beta:: Float32 = 0f0 )
293+ state = _lt_state ()
294+ # cuBLASLt's matmul descriptor defaults to CUBLASLT_POINTER_MODE_HOST
295+ # (independent of the cuBLAS handle's pointer mode), so alpha/beta are
296+ # plain host Refs here — using CuRef would trigger UVA faults.
297+ α = Ref {Float32} (alpha)
298+ β = Ref {Float32} (beta)
299+ algo_ref = Ref (plan. algo)
300+ CUDA. CUBLAS. cublasLtMatmul (
301+ state. handle, plan. desc,
302+ α, A, plan. Adesc,
303+ B, plan. Bdesc,
304+ β, C, plan. Cdesc,
305+ C, plan. Cdesc, # D = C in place
306+ algo_ref,
307+ state. ws, sizeof (state. ws),
308+ CUDA. stream (),
309+ )
310+ return C
311+ end
312+
313+ # Three plans for our specific 2-layer MLP shape.
314+ struct LtPlans
315+ p1:: LtPlan # W1 * X : (h,d) * (d,n) → (h,n)
316+ p2:: LtPlan # W2' * J_2 : store (1,h),'T' * (1,n) → (h,n)
317+ p3:: LtPlan # out * X' : (h,n) * store (d,n),'T' → (h,d)
318+ end
319+
320+ function build_lt_plans (W1:: CuArray{Float32,2} , W2:: CuArray{Float32,2} ,
321+ X:: CuArray{Float32,2} )
322+ h, d = size (W1)
323+ nn = size (X, 2 )
324+ p1 = _build_lt_plan (' N' , ' N' , h, nn, d, h, d, h)
325+ p2 = _build_lt_plan (' T' , ' N' , h, nn, 1 , 1 , 1 , h)
326+ p3 = _build_lt_plan (' N' , ' T' , h, d, nn, h, d, h)
327+ return LtPlans (p1, p2, p3)
328+ end
329+
330+ function reverse_diff_v6 (plans:: LtPlans , W1, W2, X, y)
331+ h, d = size (W1)
332+ nn = size (X, 2 )
333+
334+ Z1 = CuArray {Float32} (undef, h, nn)
335+ _gemm_lt! (plans. p1, Z1, W1, X)
336+
337+ y_1 = similar (Z1)
338+ J_1 = similar (Z1)
339+ tanh_and_jac! (y_1, J_1, Z1)
340+
341+ J_2 = 2 .* (W2 * y_1 .- y) ./ size (y, 2 )
342+
343+ tmp = CuArray {Float32} (undef, h, nn)
344+ _gemm_lt! (plans. p2, tmp, W2, J_2)
345+
346+ out = similar (tmp)
347+ vmul! (out, J_1, tmp)
348+
349+ result = CuArray {Float32} (undef, h, d)
350+ _gemm_lt! (plans. p3, result, out, X)
351+ return result
352+ end
353+
354+ # -------------------------------------------------------------------------
355+ # Lux + Mooncake path
198356#
199357# Builds an equivalent 2-layer MLP `Y = W2 * tanh(W1 * X)` (no bias) using
200358# Lux, plugs in the *same* CuArray weights so the gradient is comparable,
201- # and lets Zygote source-to-source the backward. This goes through the same
202- # CUDA.jl + cuBLAS stack as `reverse_diff`, so we expect similar kernels —
203- # the interesting thing is the AD/dispatch overhead Lux+Zygote add on top.
359+ # and uses Mooncake (the modern Julia 1.12-friendly reverse-mode AD) for the
360+ # backward. Goes through the same CUDA.jl + cuBLAS stack as `reverse_diff`,
361+ # so kernels should look similar — what we're measuring is the AD/dispatch
362+ # overhead Lux+Mooncake add on top.
204363# -------------------------------------------------------------------------
205- function build_lux (W1g:: CuArray{Float32,2} , W2g:: CuArray{Float32,2} )
364+ struct LuxMooncake{M,P,S,L,R}
365+ model:: M
366+ ps:: P
367+ st:: S
368+ loss_fn:: L
369+ rule:: R
370+ end
371+
372+ function build_lux (W1g:: CuArray{Float32,2} , W2g:: CuArray{Float32,2} ,
373+ Xg:: CuArray , yg:: CuArray )
206374 h, d = size (W1g)
207375 model = Lux. Chain (
208376 Lux. Dense (d => h, tanh; use_bias = false ),
@@ -213,15 +381,24 @@ function build_lux(W1g::CuArray{Float32,2}, W2g::CuArray{Float32,2})
213381 layer_2 = (weight = W2g,),
214382 )
215383 st = Lux. initialstates (Random. default_rng (), model)
216- return model, ps, st
217- end
218384
219- function lux_grad (model, ps, st, Xg:: CuArray , yg:: CuArray )
220- function loss_fn (p)
221- y_hat, _ = model (Xg, p, st)
222- return sum ((y_hat .- yg) .^ 2 ) / size (yg, 2 )
385+ # Closure captures Xg, yg, model, st — only `p` is the differentiated arg.
386+ loss_fn = let model = model, st = st, Xg = Xg, yg = yg
387+ p -> begin
388+ y_hat, _ = model (Xg, p, st)
389+ return sum ((y_hat .- yg) .^ 2 ) / size (yg, 2 )
390+ end
223391 end
224- ∂ps = first (Zygote. gradient (loss_fn, ps))
392+
393+ # build_rrule is the expensive step (compiles the reverse pass for these
394+ # types) — do it once at setup so the per-call cost in the benchmark is
395+ # just the actual fwd+bwd execution.
396+ rule = Mooncake. build_rrule (loss_fn, ps)
397+ return LuxMooncake (model, ps, st, loss_fn, rule)
398+ end
399+
400+ function lux_grad (lm:: LuxMooncake )
401+ _, (_, ∂ps) = Mooncake. value_and_gradient!! (lm. rule, lm. loss_fn, lm. ps)
225402 return ∂ps. layer_1. weight
226403end
227404
@@ -352,17 +529,24 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3)
352529 grad_julia = Array (reverse_diff (W1g, W2g, Xg, yg))
353530 grad_julia_v4 = Array (reverse_diff_v4 (W1g, W2g, Xg, yg))
354531 grad_julia_v5 = Array (reverse_diff_v5 (W1g, W2g, Xg, yg))
532+ print (" cuBLASLt build_lt_plans for h=$h ... " ); flush (stdout )
533+ t_lt_build = @elapsed lt_plans = build_lt_plans (W1g, W2g, Xg)
534+ @printf " %.3f s\n " t_lt_build
535+ grad_julia_v6 = Array (reverse_diff_v6 (lt_plans, W1g, W2g, Xg, yg))
355536 CUDA. synchronize ()
356537
357- # Lux + Zygote warmup (first call compiles Zygote's pullback for this shape)
358- print (" Lux+Zygote compile warmup for h=$h ... " ); flush (stdout )
359- lux_model, lux_ps, lux_st = build_lux (W1g, W2g)
360- t_lux_compile = @elapsed begin
361- lux_grad (lux_model, lux_ps, lux_st, Xg, yg)
362- CUDA. synchronize ()
538+ # Lux + Mooncake setup. build_rrule compiles the reverse pass for these
539+ # types (one-time cost per shape); first call afterwards still does some
540+ # JIT, so we time both separately.
541+ print (" Lux+Mooncake build_rrule for h=$h ... " ); flush (stdout )
542+ t_lux_build = @elapsed lm = build_lux (W1g, W2g, Xg, yg)
543+ @printf " %.2f s, " t_lux_build
544+ print (" first call ... " ); flush (stdout )
545+ t_lux_first = @elapsed begin
546+ lux_grad (lm); CUDA. synchronize ()
363547 end
364- @printf " %.2f s\n " t_lux_compile
365- grad_lux = Array (lux_grad (lux_model, lux_ps, lux_st, Xg, yg ))
548+ @printf " %.2f s\n " t_lux_first
549+ grad_lux = Array (lux_grad (lm ))
366550 CUDA. synchronize ()
367551
368552 # ----- PyTorch -----
@@ -384,7 +568,8 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3)
384568 # ----- Numerical equivalence -----
385569 for (name, g) in [(" Julia v4 (vec=4) " , grad_julia_v4),
386570 (" Julia v5 (vec=4+SIMT)" , grad_julia_v5),
387- (" Lux + Zygote " , grad_lux),
571+ (" Julia v6 (vec=4+Lt)" , grad_julia_v6),
572+ (" Lux + Mooncake " , grad_lux),
388573 (" PyTorch eager " , grad_pytorch_eager),
389574 (" PyTorch compiled " , grad_pytorch_compiled)]
390575 maxdiff = maximum (abs .(grad_julia .- g))
@@ -409,8 +594,12 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3)
409594 reverse_diff_v5 ($ W1g, $ W2g, $ Xg, $ yg)
410595 CUDA. synchronize ()
411596 end samples= 30 evals= 1 seconds= 10
597+ bj6 = @benchmark begin
598+ reverse_diff_v6 ($ lt_plans, $ W1g, $ W2g, $ Xg, $ yg)
599+ CUDA. synchronize ()
600+ end samples= 30 evals= 1 seconds= 10
412601 bjlux = @benchmark begin
413- lux_grad ($ lux_model, $ lux_ps, $ lux_st, $ Xg, $ yg )
602+ lux_grad ($ lm )
414603 CUDA. synchronize ()
415604 end samples= 30 evals= 1 seconds= 10
416605 be = @benchmark begin
@@ -424,7 +613,8 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3)
424613 @printf " Julia broadcast : median %8.3f µs\n " 1e-3 * median (bj). time
425614 @printf " Julia vec=4 : median %8.3f µs\n " 1e-3 * median (bj4). time
426615 @printf " Julia vec=4 + SIMT : median %8.3f µs\n " 1e-3 * median (bj5). time
427- @printf " Lux + Zygote : median %8.3f µs\n " 1e-3 * median (bjlux). time
616+ @printf " Julia vec=4 + cuBLASLt: median %8.3f µs\n " 1e-3 * median (bj6). time
617+ @printf " Lux + Mooncake : median %8.3f µs\n " 1e-3 * median (bjlux). time
428618 @printf " PyTorch eager : median %8.3f µs\n " 1e-3 * median (be). time
429619 @printf " PyTorch compiled : median %8.3f µs\n " 1e-3 * median (bc). time
430620
@@ -438,8 +628,11 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3)
438628 println (" \n --- CUDA trace: Julia vec=4 + SIMT ---" )
439629 summarize_julia_trace (stdout , julia_trace (() -> reverse_diff_v5 (W1g, W2g, Xg, yg)))
440630
441- println (" \n --- CUDA trace: Lux + Zygote ---" )
442- summarize_julia_trace (stdout , julia_trace (() -> lux_grad (lux_model, lux_ps, lux_st, Xg, yg)))
631+ println (" \n --- CUDA trace: Julia vec=4 + cuBLASLt ---" )
632+ summarize_julia_trace (stdout , julia_trace (() -> reverse_diff_v6 (lt_plans, W1g, W2g, Xg, yg)))
633+
634+ println (" \n --- CUDA trace: Lux + Mooncake ---" )
635+ summarize_julia_trace (stdout , julia_trace (() -> lux_grad (lm)))
443636
444637 println (" \n --- CUDA trace: PyTorch eager ---" )
445638 println (pytorch_trace (() -> pytorch_grad_eager (W1t, W2t, Xt, yt)))
0 commit comments