diff --git a/src/gradient.jl b/src/gradient.jl index 66c37b318c..29d6b6535e 100644 --- a/src/gradient.jl +++ b/src/gradient.jl @@ -54,6 +54,8 @@ function gradient(f, adtype::AbstractADType, args...) Supported backends are $SUPPORTED_AD_BACKENDS.") end +gradient(f, adtype::AutoZygote, args...) = Zygote.gradient(f, args...) + # Default gradient using Zygote function gradient(f, args...; zero::Bool=true) for a in args diff --git a/test/gradient.jl b/test/gradient.jl new file mode 100644 index 0000000000..33318d58e3 --- /dev/null +++ b/test/gradient.jl @@ -0,0 +1,7 @@ +@testset "AutoADTypes gradient" begin + m = Dense(2 => 2) + x = rand(Float32, 2) + g_zygote = Flux.gradient(m -> sum(m(x)), AutoZygote(), m)[1] + g_mooncake = Flux.gradient(m -> sum(m(x)), AutoMooncake(), m)[1] + @test g_zygote.weight ≈ g_mooncake.weight +end