@@ -141,9 +141,10 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float64 = 1e-6)
141141 CUDA. synchronize ()
142142
143143 # ArrayDiff CPU.
144- print (" ArrayDiff CPU build (h=$h ) ... " ); flush (stdout )
145- t_cpu_build = @elapsed ev_cpu =
146- build_arraydiff (W2, X, y, h, d, n, ArrayDiff. Mode ())
144+ print (" ArrayDiff CPU build (h=$h ) ... " );
145+ flush (stdout )
146+ t_cpu_build =
147+ @elapsed ev_cpu = build_arraydiff (W2, X, y, h, d, n, ArrayDiff. Mode ())
147148 @printf " %.2f s\n " t_cpu_build
148149 x_cpu = vec (W1)
149150 g_cpu = zeros (Float64, length (x_cpu))
@@ -154,7 +155,8 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float64 = 1e-6)
154155 # GPU-resident solver (e.g. one whose ADAM step is on `CuVector`) would
155156 # use: the AD tape, the input vector, and the gradient buffer all stay
156157 # on the device, so there's no D2H round-trip on the gradient hot path.
157- print (" ArrayDiff GPU build (h=$h ) ... " ); flush (stdout )
158+ print (" ArrayDiff GPU build (h=$h ) ... " );
159+ flush (stdout )
158160 t_gpu_build = @elapsed ev_gpu = build_arraydiff (
159161 W2,
160162 X,
@@ -172,10 +174,7 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float64 = 1e-6)
172174 grad_gpu = reshape (Array (g_gpu), h, d)
173175
174176 # Numerical equivalence.
175- for (name, g) in [
176- (" ArrayDiff CPU" , grad_cpu),
177- (" ArrayDiff GPU" , grad_gpu),
178- ]
177+ for (name, g) in [(" ArrayDiff CPU" , grad_cpu), (" ArrayDiff GPU" , grad_gpu)]
179178 maxdiff = maximum (abs .(grad_ref .- g))
180179 relmag = maxdiff / max (maximum (abs .(grad_ref)), eps (Float64))
181180 ok = isapprox (grad_ref, g; rtol = rtol)
@@ -212,7 +211,11 @@ function main()
212211 error (" CUDA is not functional in this environment." )
213212 end
214213 CUDA. math_mode! (CUDA. FAST_MATH)
215- println (" CUDA.jl device : " , CUDA. name (CUDA. device ()), " (math_mode=FAST_MATH)" )
214+ println (
215+ " CUDA.jl device : " ,
216+ CUDA. name (CUDA. device ()),
217+ " (math_mode=FAST_MATH)" ,
218+ )
216219 for h in (16 , 256 , 4096 )
217220 run_one (; h = h)
218221 GC. gc (true )
0 commit comments