Skip to content

Commit ac1c6bd

Browse files
committed
Lux
1 parent 8b7e2e8 commit ac1c6bd

2 files changed

Lines changed: 73 additions & 11 deletions

File tree

perf/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
88
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
99
CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
11+
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
1112
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1213
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
1314
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
15+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

perf/cuda_vs_pytorch.jl

Lines changed: 71 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ using CUDA
2121
using CUDA: AS
2222
using BenchmarkTools
2323
using PythonCall
24+
using Lux, Zygote
2425

2526
# -------------------------------------------------------------------------
2627
# Hardcoded CUDA.jl path
@@ -146,18 +147,25 @@ function _gemm_simt!(C::CuArray{Float32,2}, transA::Char, A::CuArray{Float32,2},
146147
ldb = max(1, stride(B, 2))
147148
ldc = max(1, stride(C, 2))
148149
# CUDA.jl puts the cuBLAS handle in CUBLAS_POINTER_MODE_DEVICE, so alpha/beta
149-
# MUST be device pointers. Passing host Ref{Float32} causes UVA fault handling
150-
# per kernel launch (~100× slowdown but eventually-correct values).
150+
# MUST be device pointers (host Ref triggers UVA fault handling — 100× slowdown).
151151
α = CUDA.CuRef{Float32}(alpha); β = CUDA.CuRef{Float32}(beta)
152-
CUDA.CUBLAS.cublasGemmEx(
153-
CUDA.CUBLAS.handle(),
154-
transA, transB, m, n, k,
155-
α, A, Float32, lda,
156-
B, Float32, ldb,
157-
β, C, Float32, ldc,
158-
CUDA.CUBLAS.CUBLAS_COMPUTE_32F,
159-
CUDA.CUBLAS.CUBLAS_GEMM_DEFAULT,
160-
)
152+
h = CUDA.CUBLAS.handle()
153+
# Under FAST_MATH the handle's math mode is CUBLAS_TF32_TENSOR_OP_MATH, which
154+
# forces TF32 tensor cores even when we ask for CUBLAS_COMPUTE_32F. Flip it to
155+
# DEFAULT_MATH for this call so cuBLAS picks a SIMT FP32 kernel.
156+
CUDA.CUBLAS.math_mode!(h, CUDA.DEFAULT_MATH)
157+
try
158+
CUDA.CUBLAS.cublasGemmEx(
159+
h, transA, transB, m, n, k,
160+
α, A, Float32, lda,
161+
B, Float32, ldb,
162+
β, C, Float32, ldc,
163+
CUDA.CUBLAS.CUBLAS_COMPUTE_32F,
164+
CUDA.CUBLAS.CUBLAS_GEMM_DEFAULT,
165+
)
166+
finally
167+
CUDA.CUBLAS.math_mode!(h, CUDA.math_mode()) # restore (FAST_MATH → TF32 tensor op)
168+
end
161169
return C
162170
end
163171

@@ -185,6 +193,38 @@ function reverse_diff_v5(W1, W2, X, y)
185193
return result
186194
end
187195

196+
# -------------------------------------------------------------------------
197+
# Lux + Zygote path
198+
#
199+
# Builds an equivalent 2-layer MLP `Y = W2 * tanh(W1 * X)` (no bias) using
200+
# Lux, plugs in the *same* CuArray weights so the gradient is comparable,
201+
# and lets Zygote source-to-source the backward. This goes through the same
202+
# CUDA.jl + cuBLAS stack as `reverse_diff`, so we expect similar kernels —
203+
# the interesting thing is the AD/dispatch overhead Lux+Zygote add on top.
204+
# -------------------------------------------------------------------------
205+
function build_lux(W1g::CuArray{Float32,2}, W2g::CuArray{Float32,2})
206+
h, d = size(W1g)
207+
model = Lux.Chain(
208+
Lux.Dense(d => h, tanh; use_bias = false),
209+
Lux.Dense(h => 1, identity; use_bias = false),
210+
)
211+
ps = (
212+
layer_1 = (weight = W1g,),
213+
layer_2 = (weight = W2g,),
214+
)
215+
st = Lux.initialstates(Random.default_rng(), model)
216+
return model, ps, st
217+
end
218+
219+
function lux_grad(model, ps, st, Xg::CuArray, yg::CuArray)
220+
function loss_fn(p)
221+
y_hat, _ = model(Xg, p, st)
222+
return sum((y_hat .- yg) .^ 2) / size(yg, 2)
223+
end
224+
∂ps = first(Zygote.gradient(loss_fn, ps))
225+
return ∂ps.layer_1.weight
226+
end
227+
188228
# -------------------------------------------------------------------------
189229
# PyTorch path
190230
# -------------------------------------------------------------------------
@@ -314,6 +354,17 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3)
314354
grad_julia_v5 = Array(reverse_diff_v5(W1g, W2g, Xg, yg))
315355
CUDA.synchronize()
316356

357+
# Lux + Zygote warmup (first call compiles Zygote's pullback for this shape)
358+
print("Lux+Zygote compile warmup for h=$h ... "); flush(stdout)
359+
lux_model, lux_ps, lux_st = build_lux(W1g, W2g)
360+
t_lux_compile = @elapsed begin
361+
lux_grad(lux_model, lux_ps, lux_st, Xg, yg)
362+
CUDA.synchronize()
363+
end
364+
@printf "%.2f s\n" t_lux_compile
365+
grad_lux = Array(lux_grad(lux_model, lux_ps, lux_st, Xg, yg))
366+
CUDA.synchronize()
367+
317368
# ----- PyTorch -----
318369
W1t, W2t, Xt, yt = build_torch_tensors(W1, W2, X, y)
319370
grad_pytorch_eager = torch_to_julia(pytorch_grad_eager(W1t, W2t, Xt, yt))
@@ -333,6 +384,7 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3)
333384
# ----- Numerical equivalence -----
334385
for (name, g) in [("Julia v4 (vec=4) ", grad_julia_v4),
335386
("Julia v5 (vec=4+SIMT)", grad_julia_v5),
387+
("Lux + Zygote ", grad_lux),
336388
("PyTorch eager ", grad_pytorch_eager),
337389
("PyTorch compiled ", grad_pytorch_compiled)]
338390
maxdiff = maximum(abs.(grad_julia .- g))
@@ -357,6 +409,10 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3)
357409
reverse_diff_v5($W1g, $W2g, $Xg, $yg)
358410
CUDA.synchronize()
359411
end samples=30 evals=1 seconds=10
412+
bjlux = @benchmark begin
413+
lux_grad($lux_model, $lux_ps, $lux_st, $Xg, $yg)
414+
CUDA.synchronize()
415+
end samples=30 evals=1 seconds=10
360416
be = @benchmark begin
361417
pytorch_grad_eager($W1t, $W2t, $Xt, $yt)
362418
$torch.cuda.synchronize()
@@ -368,6 +424,7 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3)
368424
@printf "Julia broadcast : median %8.3f µs\n" 1e-3 * median(bj).time
369425
@printf "Julia vec=4 : median %8.3f µs\n" 1e-3 * median(bj4).time
370426
@printf "Julia vec=4 + SIMT : median %8.3f µs\n" 1e-3 * median(bj5).time
427+
@printf "Lux + Zygote : median %8.3f µs\n" 1e-3 * median(bjlux).time
371428
@printf "PyTorch eager : median %8.3f µs\n" 1e-3 * median(be).time
372429
@printf "PyTorch compiled : median %8.3f µs\n" 1e-3 * median(bc).time
373430

@@ -381,6 +438,9 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3)
381438
println("\n--- CUDA trace: Julia vec=4 + SIMT ---")
382439
summarize_julia_trace(stdout, julia_trace(() -> reverse_diff_v5(W1g, W2g, Xg, yg)))
383440

441+
println("\n--- CUDA trace: Lux + Zygote ---")
442+
summarize_julia_trace(stdout, julia_trace(() -> lux_grad(lux_model, lux_ps, lux_st, Xg, yg)))
443+
384444
println("\n--- CUDA trace: PyTorch eager ---")
385445
println(pytorch_trace(() -> pytorch_grad_eager(W1t, W2t, Xt, yt)))
386446

0 commit comments

Comments
 (0)