From b4a642b9ed9832dd7d4dec1cc5094d2007cb4765 Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Sun, 11 Jan 2026 17:51:32 -0800 Subject: [PATCH 1/3] fix: Zero case in gather --- src/gather.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/gather.jl b/src/gather.jl index d75f89a2c..7997f8784 100644 --- a/src/gather.jl +++ b/src/gather.jl @@ -110,6 +110,7 @@ function gather!(dst::AbstractArray, src::AbstractArray, idx::AbstractArray) end function gather!(dst::AnyGPUArray, src::AnyGPUArray, idx::AnyGPUArray) + isempty(dst) && return dst n_dims = scatter_dims(src, dst, idx) dims = size(src)[1:n_dims] max_dims_idx = prod(dims) From 44577d6c72771703d2dc7419c742599c1c887bf8 Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Sun, 11 Jan 2026 20:03:34 -0800 Subject: [PATCH 2/3] test: Test zero-sized array --- test/ext_cuda/gather.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/ext_cuda/gather.jl b/test/ext_cuda/gather.jl index 9fa30efa8..1c2b1384d 100644 --- a/test/ext_cuda/gather.jl +++ b/test/ext_cuda/gather.jl @@ -103,4 +103,10 @@ outv2 = NNlib.gather(v2, i) @test collect(outv2) == NNlib.gather(collect(v2), collect(i)) end + + # Zero-sized + x = CT([1,2,3]) + i = CT(Int[]) + y = NNlib.gather(x, i) + @test isempty(y) end From 6b54cc89ebd9db36803bb8f03bad58471011b4e9 Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Sun, 11 Jan 2026 21:35:18 -0800 Subject: [PATCH 3/3] fix: scatter! skip empty --- ext/NNlibCUDAExt/scatter.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ext/NNlibCUDAExt/scatter.jl b/ext/NNlibCUDAExt/scatter.jl index 9b323d504..a71fdc115 100644 --- a/ext/NNlibCUDAExt/scatter.jl +++ b/ext/NNlibCUDAExt/scatter.jl @@ -48,6 +48,8 @@ end function NNlib.scatter!(op::OP, dst::Union{AnyCuArray,AbstractCuSparseArray}, src::Union{AnyCuArray,AbstractCuSparseArray}, idx::Union{AnyCuArray,AbstractCuSparseArray}) where OP + isempty(idx) && return dst + dims = NNlib.scatter_dims(dst, src, idx) args = if dims == 0 max_idx = length(idx)