Skip to content
8 changes: 8 additions & 0 deletions src/config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
friendly_tangents::Bool=false,
chunk_size::Union{Nothing,Int}=nothing,
enable_nfwd::Bool=true,
empty_cache::Bool=false,
)

Configuration struct for use with `ADTypes.AutoMooncake`.
Expand Down Expand Up @@ -32,11 +33,18 @@ Configuration struct for use with `ADTypes.AutoMooncake`.
`prepare_hvp_cache` and `prepare_hessian_cache`. When left enabled, cache
construction stays passive, but `value_and_derivative!!` / `value_and_gradient!!`
may still error at runtime if `nfwd` turns out not to support the function.
- `empty_cache::Bool=false`: if `true`, all internal Mooncake caches (compiled OpaqueClosures,
CodeInstances, and type-inference results) are cleared before building the new rule. This
allows the garbage collector to reclaim memory held by previously compiled rules, and is
useful in long-running sessions where many distinct functions have been differentiated.
Note that only Julia-level (GC-managed) objects are freed; JIT-compiled native machine
code is held permanently by the Julia runtime and cannot be reclaimed.
"""
# Configuration struct for use with `ADTypes.AutoMooncake`; the docstring above
# documents each field in full.
@kwdef struct Config
    debug_mode::Bool = false
    silence_debug_messages::Bool = false
    friendly_tangents::Bool = false
    # `nothing` means a chunk size is chosen automatically — TODO(review): confirm
    # against `prepare_derivative_cache`, which reads this field.
    chunk_size::Union{Nothing,Int} = nothing
    # When left enabled, cache construction stays passive, but runtime calls may
    # still error if `nfwd` turns out not to support the function (see docstring).
    enable_nfwd::Bool = true
    # If `true`, all internal Mooncake caches (compiled OpaqueClosures, CodeInstances,
    # and type-inference results) are cleared before building a new rule, letting the
    # GC reclaim memory from previously compiled rules. Only Julia-level (GC-managed)
    # objects are freed; JIT-compiled native machine code is held permanently.
    empty_cache::Bool = false
end
5 changes: 5 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,9 @@ The API guarantees that tangents are initialized at zero before the first autodi
"""
@unstable function prepare_pullback_cache(fx...; config=Config())

# Clear global caches if requested.
config.empty_cache && empty_mooncake_caches!()

# Check that the output of `fx` is supported.
__exclude_func_with_unsupported_output(fx)

Expand Down Expand Up @@ -772,6 +775,7 @@ The API guarantees that tangents are initialized at zero before the first autodi
Calls `f(x...)` once during cache preparation.
"""
@unstable function prepare_gradient_cache(fx...; config=Config())
config.empty_cache && empty_mooncake_caches!()
rule = build_rrule(fx...; config.debug_mode, config.silence_debug_messages)
tangents = map(zero_tangent, fx)
y, rvs!! = __call_rule(rule, map((x, dx) -> CoDual(x, fdata(dx)), fx, tangents))
Expand Down Expand Up @@ -1678,6 +1682,7 @@ Returns a cache used with [`value_and_derivative!!`](@ref). See that function fo
@unstable @inline function prepare_derivative_cache(
f, x::Vararg{Any,N}; config=Config()
) where {N}
config.empty_cache && empty_mooncake_caches!()
fx = (f, x...)
requested_chunk_size = getfield(config, :chunk_size)
requested_chunk_size = if isnothing(requested_chunk_size)
Expand Down
26 changes: 26 additions & 0 deletions src/interpreter/abstract_interpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ struct MooncakeCache
end

# Construct an empty `MooncakeCache` backed by a fresh
# `MethodInstance -> CodeInstance` identity map.
function MooncakeCache()
    return MooncakeCache(IdDict{Core.MethodInstance,Core.CodeInstance}())
end

# Drop every cached entry and return the (now empty) cache, mirroring the
# `Base.empty!` convention of returning the emptied container.
function Base.empty!(c::MooncakeCache)
    empty!(c.dict)
    return c
end

# The method table used by `Mooncake.@mooncake_overlay`.
Base.Experimental.@MethodTable mooncake_method_table
Expand Down Expand Up @@ -348,3 +349,28 @@ function get_interpreter(mode::Type{<:Mode})
end
return GLOBAL_INTERPRETERS[mode]
end

"""
empty_mooncake_caches!()

This is an internal function and not part of the public API. Called by `prepare_pullback_cache`,
`prepare_gradient_cache`, and `prepare_derivative_cache` when `Config(empty_cache=true)`
is passed.

Empties all three per-interpreter caches for both `ForwardMode` and `ReverseMode`:
- `oc_cache` : compiled `DerivedRule` / `OpaqueClosures`
- `code_cache` : `CodeInstance` objects (Julia IR per `MethodInstance`)
- `inf_cache` : `InferenceResult` objects from type inference

After clearing, Mooncake re-derives rules from scratch on the next use. Only Julia-level
(GC-managed) objects are freed; JIT-compiled native machine code allocated by LLVM
is held permanently by the Julia runtime.
"""
function empty_mooncake_caches!()
for interp in values(GLOBAL_INTERPRETERS)
empty!(interp.oc_cache)
empty!(interp.code_cache)
empty!(interp.inf_cache)
end
return nothing
end
6 changes: 2 additions & 4 deletions src/rules/threads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
return zero_dual(_foreigncall_(name, tuple_map(primal, args)...))
end

@inline function _threading_foreigncall_rrule(f, name, args...)
function _threading_foreigncall_rrule()
throw(
ErrorException(
"Differentiating through threading is not safe and is unsupported " *
Expand Down Expand Up @@ -31,9 +31,7 @@ for name in [
Val($(QuoteNode(name))), args...
)

@eval rrule!!(f::CoDual{typeof(_foreigncall_)}, name::CoDual{Val{$(QuoteNode(name))}}, args...) = _threading_foreigncall_rrule(
f, primal(name), args...
)
@eval rrule!!(::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{$(QuoteNode(name))}}, args...) = _threading_foreigncall_rrule()
end

@is_primitive MinimalCtx ForwardMode Tuple{
Expand Down
3 changes: 1 addition & 2 deletions src/tangents/tangents.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1607,8 +1607,7 @@ end
# NamedTuple destination: recurse field-wise.
# For NamedTuple primals, tangents are plain NamedTuples and are indexed directly.
# For immutable struct primals, tangents are Tangent wrappers whose `.fields` entries are
# plain tangents or `PossiblyUninitTangent` values. If `tangent isa NoTangent` at runtime,
# return zero-tangent friendly values per field instead of erroring.
# plain tangents or `PossiblyUninitTangent` values.
# Mutable structs use the AsMutableFields path above instead.
# When `tangent isa NoTangent` the primal type has no differentiable fields according to
# the runtime world (e.g. because an extension declared tangent_type(P) == NoTangent after
Expand Down
1 change: 1 addition & 0 deletions test/config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
@test !Mooncake.Config().silence_debug_messages
@test isnothing(Mooncake.Config().chunk_size)
@test Mooncake.Config().enable_nfwd
@test !Mooncake.Config().empty_cache
end
22 changes: 22 additions & 0 deletions test/interpreter/abstract_interpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,4 +124,26 @@ end
@test grad[2] == [2.0, 0.0, 1.0]
end
end

@testset "Config(empty_cache=true)" begin
    # Fixed target function and input for the post-clear correctness check.
    f = x -> sin(x[1]) + x[2]^2
    x = [1.0, 2.0]

    # Build up the cache with several distinct functions, then clear it.
    for g in [x -> sum(x .^ 2), x -> prod(x), x -> sum(exp.(x))]
        Mooncake.prepare_gradient_cache(g, randn(10))
    end
    n_before = length(Mooncake.GLOBAL_INTERPRETERS[Mooncake.ReverseMode].oc_cache)
    @test n_before > 0

    # Preparing with `empty_cache=true` must shrink the OpaqueClosure cache:
    # everything cached above is cleared before the single new rule is built.
    cache = Mooncake.prepare_gradient_cache(
        f, x; config=Mooncake.Config(empty_cache=true)
    )
    @test length(Mooncake.GLOBAL_INTERPRETERS[Mooncake.ReverseMode].oc_cache) < n_before

    # AD still correct after clearing: value and gradient (grad[2] is the
    # tangent w.r.t. `x`; grad[1] is w.r.t. `f` itself).
    val, grad = Mooncake.value_and_gradient!!(cache, f, x)
    @test val ≈ sin(x[1]) + x[2]^2
    @test grad[2] ≈ [cos(x[1]), 2x[2]]
end
end
Loading