Skip to content

Commit d131cbb

Browse files
committed
Fix format
1 parent 66a5f5b commit d131cbb

1 file changed

Lines changed: 12 additions & 9 deletions

File tree

perf/arraydiff_gpu.jl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,10 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float64 = 1e-6)
141141
CUDA.synchronize()
142142

143143
# ArrayDiff CPU.
144-
print("ArrayDiff CPU build (h=$h) ... "); flush(stdout)
145-
t_cpu_build = @elapsed ev_cpu =
146-
build_arraydiff(W2, X, y, h, d, n, ArrayDiff.Mode())
144+
print("ArrayDiff CPU build (h=$h) ... ");
145+
flush(stdout)
146+
t_cpu_build =
147+
@elapsed ev_cpu = build_arraydiff(W2, X, y, h, d, n, ArrayDiff.Mode())
147148
@printf "%.2f s\n" t_cpu_build
148149
x_cpu = vec(W1)
149150
g_cpu = zeros(Float64, length(x_cpu))
@@ -154,7 +155,8 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float64 = 1e-6)
154155
# GPU-resident solver (e.g. one whose ADAM step is on `CuVector`) would
155156
# use: the AD tape, the input vector, and the gradient buffer all stay
156157
# on the device, so there's no D2H round-trip on the gradient hot path.
157-
print("ArrayDiff GPU build (h=$h) ... "); flush(stdout)
158+
print("ArrayDiff GPU build (h=$h) ... ");
159+
flush(stdout)
158160
t_gpu_build = @elapsed ev_gpu = build_arraydiff(
159161
W2,
160162
X,
@@ -172,10 +174,7 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float64 = 1e-6)
172174
grad_gpu = reshape(Array(g_gpu), h, d)
173175

174176
# Numerical equivalence.
175-
for (name, g) in [
176-
("ArrayDiff CPU", grad_cpu),
177-
("ArrayDiff GPU", grad_gpu),
178-
]
177+
for (name, g) in [("ArrayDiff CPU", grad_cpu), ("ArrayDiff GPU", grad_gpu)]
179178
maxdiff = maximum(abs.(grad_ref .- g))
180179
relmag = maxdiff / max(maximum(abs.(grad_ref)), eps(Float64))
181180
ok = isapprox(grad_ref, g; rtol = rtol)
@@ -212,7 +211,11 @@ function main()
212211
error("CUDA is not functional in this environment.")
213212
end
214213
CUDA.math_mode!(CUDA.FAST_MATH)
215-
println("CUDA.jl device : ", CUDA.name(CUDA.device()), " (math_mode=FAST_MATH)")
214+
println(
215+
"CUDA.jl device : ",
216+
CUDA.name(CUDA.device()),
217+
" (math_mode=FAST_MATH)",
218+
)
216219
for h in (16, 256, 4096)
217220
run_one(; h = h)
218221
GC.gc(true)

0 commit comments

Comments
 (0)