From e4c6bed4f969eba971020cef58c0846e59092e0f Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Wed, 25 Mar 2026 22:59:46 +0100 Subject: [PATCH] fix AutoZygote gradient Use GPUArrays caching allocator in train! Wraps each training iteration in `GPUArrays.@cached` so temporary GPU allocations (gradients, activations) are pooled and reused across steps instead of being freed and reallocated each iteration, reducing GC pressure. Closes #2636 Co-Authored-By: Claude Sonnet 4.6 cleanup cleanup --- Project.toml | 2 ++ src/train.jl | 15 ++++++++++----- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index fe38383b18..ab2e158dc1 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLCore = "c2834f40-e789-41da-a90e-33b280584a8c" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" @@ -58,6 +59,7 @@ Enzyme = "0.13" EnzymeCore = "0.7.7, 0.8.4" FiniteDifferences = "0.12" Functors = "0.5" +GPUArrays = "11.2" MLCore = "1.0.0" MLDataDevices = "1.4.2" MLUtils = "0.4" diff --git a/src/train.jl b/src/train.jl index 5ab9e6e11a..f66f06422d 100644 --- a/src/train.jl +++ b/src/train.jl @@ -4,6 +4,7 @@ using LinearAlgebra using Optimisers: Optimisers using Functors: fmap, fmapstructure using ..Flux: Flux +using GPUArrays: GPUArrays using ProgressLogging: @progress, @withprogress, @logprogress using EnzymeCore: Duplicated @@ -110,19 +111,23 @@ It adds only a few features to the loop above: function train!(loss, adtype::AbstractADType, model, data, opt; cb = nothing) isnothing(cb) || error("""train! does not support callback functions. For more control use a loop with `gradient` and `update!`.""") + cache = GPUArrays.AllocCache() @withprogress for (i,d) in enumerate(data) d_splat = d isa Tuple ? d : (d,) - l, gs = Flux.withgradient(m -> loss(m, d_splat...), adtype, model) + GPUArrays.@cached cache begin + l, gs = Flux.withgradient(m -> loss(m, d_splat...), adtype, model) - if !isfinite(l) - throw(DomainError(lazy"Loss is $l on data item $i, stopping training")) - end + if !isfinite(l) + throw(DomainError(lazy"Loss is $l on data item $i, stopping training")) + end - opt, model = _update!(opt, model, gs[1]) + opt, model = _update!(opt, model, gs[1]) + end @logprogress Base.haslength(data) ? i/length(data) : nothing end + GPUArrays.unsafe_free!(cache) end _update!(opt_state, model, grads) = Optimisers.update!(opt_state, model, grads)