Skip to content

Commit 374fcda

Browse files
committed
Add more
1 parent ac1c6bd commit 374fcda

2 files changed

Lines changed: 220 additions & 27 deletions

File tree

perf/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
99
CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
12+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
1213
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1314
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
1415
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
15-
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

perf/cuda_vs_pytorch.jl

Lines changed: 219 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ using CUDA
2121
using CUDA: AS
2222
using BenchmarkTools
2323
using PythonCall
24-
using Lux, Zygote
24+
using Lux
25+
import Mooncake
2526

2627
# -------------------------------------------------------------------------
2728
# Hardcoded CUDA.jl path
@@ -194,15 +195,182 @@ function reverse_diff_v5(W1, W2, X, y)
194195
end
195196

196197
# -------------------------------------------------------------------------
197-
# Lux + Zygote path
198+
# v6: vec=4 elementwise + cuBLASLt with per-shape heuristic-picked algo
199+
#
200+
# cuBLAS's standard heuristic, even with CUBLAS_COMPUTE_32F + DEFAULT_MATH,
201+
# picks `cutlass_80_simt_sgemm_*` for our awkward shapes. PyTorch's process
202+
# happens to land on `magma_sgemmEx_kernel` for the same compute type — same
203+
# library, different choice. cuBLASLt exposes a fuller heuristic API with a
204+
# workspace budget that often unlocks better algos. We build a matmul
205+
# descriptor + layouts per (transA, transB, m, n, k) shape, ask cuBLASLt for
206+
# its best algo, and reuse the cached plan on every call.
207+
# -------------------------------------------------------------------------
208+
const _LT_WS_BYTES = Csize_t(32 * 1024 * 1024) # 32 MiB workspace
209+
210+
# Lazy: created on first use, kept alive for the process.
211+
const _LT_STATE = Ref{Any}(nothing)
212+
213+
function _lt_state()
214+
s = _LT_STATE[]
215+
if s === nothing
216+
h_ref = Ref{CUDA.CUBLAS.cublasLtHandle_t}(C_NULL)
217+
CUDA.CUBLAS.cublasLtCreate(h_ref)
218+
ws = CUDA.CuArray{UInt8}(undef, Int(_LT_WS_BYTES))
219+
s = (handle = h_ref[], ws = ws)
220+
_LT_STATE[] = s
221+
end
222+
return s::NamedTuple{(:handle, :ws)}
223+
end
224+
225+
mutable struct LtPlan
226+
desc::CUDA.CUBLAS.cublasLtMatmulDesc_t
227+
Adesc::CUDA.CUBLAS.cublasLtMatrixLayout_t
228+
Bdesc::CUDA.CUBLAS.cublasLtMatrixLayout_t
229+
Cdesc::CUDA.CUBLAS.cublasLtMatrixLayout_t
230+
algo::CUDA.CUBLAS.cublasLtMatmulAlgo_t
231+
end
232+
233+
function _build_lt_plan(transA::Char, transB::Char,
234+
m::Int, n::Int, k::Int,
235+
lda::Int, ldb::Int, ldc::Int)
236+
state = _lt_state()
237+
handle = state.handle
238+
R32 = CUDA.CUDACore.R_32F # cudaDataType for Float32
239+
240+
desc_ref = Ref{CUDA.CUBLAS.cublasLtMatmulDesc_t}(C_NULL)
241+
CUDA.CUBLAS.cublasLtMatmulDescCreate(desc_ref, CUDA.CUBLAS.CUBLAS_COMPUTE_32F, R32)
242+
desc = desc_ref[]
243+
244+
# Set transpose attributes.
245+
tA = (transA == 'N') ? CUDA.CUBLAS.CUBLAS_OP_N : CUDA.CUBLAS.CUBLAS_OP_T
246+
tB = (transB == 'N') ? CUDA.CUBLAS.CUBLAS_OP_N : CUDA.CUBLAS.CUBLAS_OP_T
247+
let r = Ref(tA)
248+
CUDA.CUBLAS.cublasLtMatmulDescSetAttribute(
249+
desc, CUDA.CUBLAS.CUBLASLT_MATMUL_DESC_TRANSA, r, sizeof(tA))
250+
end
251+
let r = Ref(tB)
252+
CUDA.CUBLAS.cublasLtMatmulDescSetAttribute(
253+
desc, CUDA.CUBLAS.CUBLASLT_MATMUL_DESC_TRANSB, r, sizeof(tB))
254+
end
255+
256+
# Layout shape is the *storage* shape (pre-transpose).
257+
Arows = transA == 'N' ? m : k
258+
Acols = transA == 'N' ? k : m
259+
Brows = transB == 'N' ? k : n
260+
Bcols = transB == 'N' ? n : k
261+
262+
Aref = Ref{CUDA.CUBLAS.cublasLtMatrixLayout_t}(C_NULL)
263+
Bref = Ref{CUDA.CUBLAS.cublasLtMatrixLayout_t}(C_NULL)
264+
Cref = Ref{CUDA.CUBLAS.cublasLtMatrixLayout_t}(C_NULL)
265+
CUDA.CUBLAS.cublasLtMatrixLayoutCreate(Aref, R32, UInt64(Arows), UInt64(Acols), Int64(lda))
266+
CUDA.CUBLAS.cublasLtMatrixLayoutCreate(Bref, R32, UInt64(Brows), UInt64(Bcols), Int64(ldb))
267+
CUDA.CUBLAS.cublasLtMatrixLayoutCreate(Cref, R32, UInt64(m), UInt64(n), Int64(ldc))
268+
269+
# Preference: tell the heuristic how much workspace it can use.
270+
pref_ref = Ref{CUDA.CUBLAS.cublasLtMatmulPreference_t}(C_NULL)
271+
CUDA.CUBLAS.cublasLtMatmulPreferenceCreate(pref_ref)
272+
pref = pref_ref[]
273+
let r = Ref(_LT_WS_BYTES)
274+
CUDA.CUBLAS.cublasLtMatmulPreferenceSetAttribute(
275+
pref, CUDA.CUBLAS.CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
276+
r, sizeof(_LT_WS_BYTES))
277+
end
278+
279+
# Heuristic: top-1 algorithm.
280+
heur = Vector{CUDA.CUBLAS.cublasLtMatmulHeuristicResult_t}(undef, 1)
281+
returned = Ref{Cint}(0)
282+
CUDA.CUBLAS.cublasLtMatmulAlgoGetHeuristic(
283+
handle, desc, Aref[], Bref[], Cref[], Cref[],
284+
pref, Cint(1), heur, returned)
285+
returned[] < 1 && error("cuBLASLt has no algo for shape (m=$m,n=$n,k=$k,trans=$transA$transB)")
286+
287+
return LtPlan(desc, Aref[], Bref[], Cref[], heur[1].algo)
288+
end
289+
290+
function _gemm_lt!(plan::LtPlan,
291+
C::CuArray{Float32,2}, A::CuArray{Float32,2}, B::CuArray{Float32,2};
292+
alpha::Float32 = 1f0, beta::Float32 = 0f0)
293+
state = _lt_state()
294+
# cuBLASLt's matmul descriptor defaults to CUBLASLT_POINTER_MODE_HOST
295+
# (independent of the cuBLAS handle's pointer mode), so alpha/beta are
296+
# plain host Refs here — using CuRef would trigger UVA faults.
297+
α = Ref{Float32}(alpha)
298+
β = Ref{Float32}(beta)
299+
algo_ref = Ref(plan.algo)
300+
CUDA.CUBLAS.cublasLtMatmul(
301+
state.handle, plan.desc,
302+
α, A, plan.Adesc,
303+
B, plan.Bdesc,
304+
β, C, plan.Cdesc,
305+
C, plan.Cdesc, # D = C in place
306+
algo_ref,
307+
state.ws, sizeof(state.ws),
308+
CUDA.stream(),
309+
)
310+
return C
311+
end
312+
313+
# Three plans for our specific 2-layer MLP shape.
314+
struct LtPlans
315+
p1::LtPlan # W1 * X : (h,d) * (d,n) → (h,n)
316+
p2::LtPlan # W2' * J_2 : store (1,h),'T' * (1,n) → (h,n)
317+
p3::LtPlan # out * X' : (h,n) * store (d,n),'T' → (h,d)
318+
end
319+
320+
function build_lt_plans(W1::CuArray{Float32,2}, W2::CuArray{Float32,2},
321+
X::CuArray{Float32,2})
322+
h, d = size(W1)
323+
nn = size(X, 2)
324+
p1 = _build_lt_plan('N', 'N', h, nn, d, h, d, h)
325+
p2 = _build_lt_plan('T', 'N', h, nn, 1, 1, 1, h)
326+
p3 = _build_lt_plan('N', 'T', h, d, nn, h, d, h)
327+
return LtPlans(p1, p2, p3)
328+
end
329+
330+
function reverse_diff_v6(plans::LtPlans, W1, W2, X, y)
331+
h, d = size(W1)
332+
nn = size(X, 2)
333+
334+
Z1 = CuArray{Float32}(undef, h, nn)
335+
_gemm_lt!(plans.p1, Z1, W1, X)
336+
337+
y_1 = similar(Z1)
338+
J_1 = similar(Z1)
339+
tanh_and_jac!(y_1, J_1, Z1)
340+
341+
J_2 = 2 .* (W2 * y_1 .- y) ./ size(y, 2)
342+
343+
tmp = CuArray{Float32}(undef, h, nn)
344+
_gemm_lt!(plans.p2, tmp, W2, J_2)
345+
346+
out = similar(tmp)
347+
vmul!(out, J_1, tmp)
348+
349+
result = CuArray{Float32}(undef, h, d)
350+
_gemm_lt!(plans.p3, result, out, X)
351+
return result
352+
end
353+
354+
# -------------------------------------------------------------------------
355+
# Lux + Mooncake path
198356
#
199357
# Builds an equivalent 2-layer MLP `Y = W2 * tanh(W1 * X)` (no bias) using
200358
# Lux, plugs in the *same* CuArray weights so the gradient is comparable,
201-
# and lets Zygote source-to-source the backward. This goes through the same
202-
# CUDA.jl + cuBLAS stack as `reverse_diff`, so we expect similar kernels —
203-
# the interesting thing is the AD/dispatch overhead Lux+Zygote add on top.
359+
# and uses Mooncake (the modern Julia 1.12-friendly reverse-mode AD) for the
360+
# backward. Goes through the same CUDA.jl + cuBLAS stack as `reverse_diff`,
361+
# so kernels should look similar — what we're measuring is the AD/dispatch
362+
# overhead Lux+Mooncake add on top.
204363
# -------------------------------------------------------------------------
205-
function build_lux(W1g::CuArray{Float32,2}, W2g::CuArray{Float32,2})
364+
struct LuxMooncake{M,P,S,L,R}
365+
model::M
366+
ps::P
367+
st::S
368+
loss_fn::L
369+
rule::R
370+
end
371+
372+
function build_lux(W1g::CuArray{Float32,2}, W2g::CuArray{Float32,2},
373+
Xg::CuArray, yg::CuArray)
206374
h, d = size(W1g)
207375
model = Lux.Chain(
208376
Lux.Dense(d => h, tanh; use_bias = false),
@@ -213,15 +381,24 @@ function build_lux(W1g::CuArray{Float32,2}, W2g::CuArray{Float32,2})
213381
layer_2 = (weight = W2g,),
214382
)
215383
st = Lux.initialstates(Random.default_rng(), model)
216-
return model, ps, st
217-
end
218384

219-
function lux_grad(model, ps, st, Xg::CuArray, yg::CuArray)
220-
function loss_fn(p)
221-
y_hat, _ = model(Xg, p, st)
222-
return sum((y_hat .- yg) .^ 2) / size(yg, 2)
385+
# Closure captures Xg, yg, model, st — only `p` is the differentiated arg.
386+
loss_fn = let model = model, st = st, Xg = Xg, yg = yg
387+
p -> begin
388+
y_hat, _ = model(Xg, p, st)
389+
return sum((y_hat .- yg) .^ 2) / size(yg, 2)
390+
end
223391
end
224-
∂ps = first(Zygote.gradient(loss_fn, ps))
392+
393+
# build_rrule is the expensive step (compiles the reverse pass for these
394+
# types) — do it once at setup so the per-call cost in the benchmark is
395+
# just the actual fwd+bwd execution.
396+
rule = Mooncake.build_rrule(loss_fn, ps)
397+
return LuxMooncake(model, ps, st, loss_fn, rule)
398+
end
399+
400+
function lux_grad(lm::LuxMooncake)
401+
_, (_, ∂ps) = Mooncake.value_and_gradient!!(lm.rule, lm.loss_fn, lm.ps)
225402
return ∂ps.layer_1.weight
226403
end
227404

@@ -352,17 +529,24 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3)
352529
grad_julia = Array(reverse_diff(W1g, W2g, Xg, yg))
353530
grad_julia_v4 = Array(reverse_diff_v4(W1g, W2g, Xg, yg))
354531
grad_julia_v5 = Array(reverse_diff_v5(W1g, W2g, Xg, yg))
532+
print("cuBLASLt build_lt_plans for h=$h ... "); flush(stdout)
533+
t_lt_build = @elapsed lt_plans = build_lt_plans(W1g, W2g, Xg)
534+
@printf "%.3f s\n" t_lt_build
535+
grad_julia_v6 = Array(reverse_diff_v6(lt_plans, W1g, W2g, Xg, yg))
355536
CUDA.synchronize()
356537

357-
# Lux + Zygote warmup (first call compiles Zygote's pullback for this shape)
358-
print("Lux+Zygote compile warmup for h=$h ... "); flush(stdout)
359-
lux_model, lux_ps, lux_st = build_lux(W1g, W2g)
360-
t_lux_compile = @elapsed begin
361-
lux_grad(lux_model, lux_ps, lux_st, Xg, yg)
362-
CUDA.synchronize()
538+
# Lux + Mooncake setup. build_rrule compiles the reverse pass for these
539+
# types (one-time cost per shape); first call afterwards still does some
540+
# JIT, so we time both separately.
541+
print("Lux+Mooncake build_rrule for h=$h ... "); flush(stdout)
542+
t_lux_build = @elapsed lm = build_lux(W1g, W2g, Xg, yg)
543+
@printf "%.2f s, " t_lux_build
544+
print("first call ... "); flush(stdout)
545+
t_lux_first = @elapsed begin
546+
lux_grad(lm); CUDA.synchronize()
363547
end
364-
@printf "%.2f s\n" t_lux_compile
365-
grad_lux = Array(lux_grad(lux_model, lux_ps, lux_st, Xg, yg))
548+
@printf "%.2f s\n" t_lux_first
549+
grad_lux = Array(lux_grad(lm))
366550
CUDA.synchronize()
367551

368552
# ----- PyTorch -----
@@ -384,7 +568,8 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3)
384568
# ----- Numerical equivalence -----
385569
for (name, g) in [("Julia v4 (vec=4) ", grad_julia_v4),
386570
("Julia v5 (vec=4+SIMT)", grad_julia_v5),
387-
("Lux + Zygote ", grad_lux),
571+
("Julia v6 (vec=4+Lt)", grad_julia_v6),
572+
("Lux + Mooncake ", grad_lux),
388573
("PyTorch eager ", grad_pytorch_eager),
389574
("PyTorch compiled ", grad_pytorch_compiled)]
390575
maxdiff = maximum(abs.(grad_julia .- g))
@@ -409,8 +594,12 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3)
409594
reverse_diff_v5($W1g, $W2g, $Xg, $yg)
410595
CUDA.synchronize()
411596
end samples=30 evals=1 seconds=10
597+
bj6 = @benchmark begin
598+
reverse_diff_v6($lt_plans, $W1g, $W2g, $Xg, $yg)
599+
CUDA.synchronize()
600+
end samples=30 evals=1 seconds=10
412601
bjlux = @benchmark begin
413-
lux_grad($lux_model, $lux_ps, $lux_st, $Xg, $yg)
602+
lux_grad($lm)
414603
CUDA.synchronize()
415604
end samples=30 evals=1 seconds=10
416605
be = @benchmark begin
@@ -424,7 +613,8 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3)
424613
@printf "Julia broadcast : median %8.3f µs\n" 1e-3 * median(bj).time
425614
@printf "Julia vec=4 : median %8.3f µs\n" 1e-3 * median(bj4).time
426615
@printf "Julia vec=4 + SIMT : median %8.3f µs\n" 1e-3 * median(bj5).time
427-
@printf "Lux + Zygote : median %8.3f µs\n" 1e-3 * median(bjlux).time
616+
@printf "Julia vec=4 + cuBLASLt: median %8.3f µs\n" 1e-3 * median(bj6).time
617+
@printf "Lux + Mooncake : median %8.3f µs\n" 1e-3 * median(bjlux).time
428618
@printf "PyTorch eager : median %8.3f µs\n" 1e-3 * median(be).time
429619
@printf "PyTorch compiled : median %8.3f µs\n" 1e-3 * median(bc).time
430620

@@ -438,8 +628,11 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3)
438628
println("\n--- CUDA trace: Julia vec=4 + SIMT ---")
439629
summarize_julia_trace(stdout, julia_trace(() -> reverse_diff_v5(W1g, W2g, Xg, yg)))
440630

441-
println("\n--- CUDA trace: Lux + Zygote ---")
442-
summarize_julia_trace(stdout, julia_trace(() -> lux_grad(lux_model, lux_ps, lux_st, Xg, yg)))
631+
println("\n--- CUDA trace: Julia vec=4 + cuBLASLt ---")
632+
summarize_julia_trace(stdout, julia_trace(() -> reverse_diff_v6(lt_plans, W1g, W2g, Xg, yg)))
633+
634+
println("\n--- CUDA trace: Lux + Mooncake ---")
635+
summarize_julia_trace(stdout, julia_trace(() -> lux_grad(lm)))
443636

444637
println("\n--- CUDA trace: PyTorch eager ---")
445638
println(pytorch_trace(() -> pytorch_grad_eager(W1t, W2t, Xt, yt)))

0 commit comments

Comments
 (0)