From b4a642b9ed9832dd7d4dec1cc5094d2007cb4765 Mon Sep 17 00:00:00 2001
From: Leni Aniva
Date: Sun, 11 Jan 2026 17:51:32 -0800
Subject: [PATCH 1/3] fix: Zero case in gather
---
src/gather.jl | 1 +
1 file changed, 1 insertion(+)
diff --git a/src/gather.jl b/src/gather.jl
index d75f89a2c..7997f8784 100644
--- a/src/gather.jl
+++ b/src/gather.jl
@@ -110,6 +110,7 @@ function gather!(dst::AbstractArray, src::AbstractArray, idx::AbstractArray)
end
function gather!(dst::AnyGPUArray, src::AnyGPUArray, idx::AnyGPUArray)
+ isempty(dst) && return dst
n_dims = scatter_dims(src, dst, idx)
dims = size(src)[1:n_dims]
max_dims_idx = prod(dims)
From 44577d6c72771703d2dc7419c742599c1c887bf8 Mon Sep 17 00:00:00 2001
From: Leni Aniva
Date: Sun, 11 Jan 2026 20:03:34 -0800
Subject: [PATCH 2/3] test: Test zero-sized array
---
test/ext_cuda/gather.jl | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/test/ext_cuda/gather.jl b/test/ext_cuda/gather.jl
index 9fa30efa8..1c2b1384d 100644
--- a/test/ext_cuda/gather.jl
+++ b/test/ext_cuda/gather.jl
@@ -103,4 +103,10 @@
outv2 = NNlib.gather(v2, i)
@test collect(outv2) == NNlib.gather(collect(v2), collect(i))
end
+
+ # Zero-sized
+ x = CT([1,2,3])
+ i = CT(Int[])
+ y = NNlib.gather(x, i)
+ @test isempty(y)
end
From 6b54cc89ebd9db36803bb8f03bad58471011b4e9 Mon Sep 17 00:00:00 2001
From: Leni Aniva
Date: Sun, 11 Jan 2026 21:35:18 -0800
Subject: [PATCH 3/3] fix: scatter! skip empty
---
ext/NNlibCUDAExt/scatter.jl | 2 ++
1 file changed, 2 insertions(+)
diff --git a/ext/NNlibCUDAExt/scatter.jl b/ext/NNlibCUDAExt/scatter.jl
index 9b323d504..a71fdc115 100644
--- a/ext/NNlibCUDAExt/scatter.jl
+++ b/ext/NNlibCUDAExt/scatter.jl
@@ -48,6 +48,8 @@ end
function NNlib.scatter!(op::OP, dst::Union{AnyCuArray,AbstractCuSparseArray},
src::Union{AnyCuArray,AbstractCuSparseArray},
idx::Union{AnyCuArray,AbstractCuSparseArray}) where OP
+ isempty(idx) && return dst
+
dims = NNlib.scatter_dims(dst, src, idx)
args = if dims == 0
max_idx = length(idx)