diff --git a/test/ext_cuda/models.jl b/test/ext_cuda/models.jl new file mode 100644 index 0000000000..36b04fcbb5 --- /dev/null +++ b/test/ext_cuda/models.jl @@ -0,0 +1,13 @@ +@testset "models' gradients on CUDA" begin + for (model, x, name) in TEST_MODELS + println("$name") + @testset "Zygote grad check $name" begin + @test test_gradients(model, x; test_gpu=true, test_cpu=false, + reference=AutoZygote(), compare=AutoZygote()) + end + @testset "Mooncake grad check $name" begin + @test test_gradients(model, x; test_gpu=true, test_cpu=false, + reference=AutoZygote(), compare=AutoMooncake()) + end + end +end