From 2d06881bc1b4668e227ed04cf73c6bf71c9a3704 Mon Sep 17 00:00:00 2001
From: Tim Besard
Date: Fri, 15 May 2026 10:57:06 +0200
Subject: [PATCH] Generalize norm and normalize to AnyGPUArray.

Lets `norm`/`normalize` short-circuit the LinearAlgebra dispatch for
SubArrays/wrappers of GPU arrays, which would otherwise fall back to
`norm_recursive_check` (Julia 1.12+) or `generic_norm2`, both of which
iterate and trigger scalar indexing.

Co-Authored-By: Claude Opus 4.7 (1M context)
---
 src/host/linalg.jl       |  4 ++--
 test/testsuite/linalg.jl | 23 +++++++++++++++++++++++
 2 files changed, 25 insertions(+), 2 deletions(-)

diff --git a/src/host/linalg.jl b/src/host/linalg.jl
index 44cb137d..a46b85d2 100644
--- a/src/host/linalg.jl
+++ b/src/host/linalg.jl
@@ -757,7 +757,7 @@ end
 
 ## norm
 
-function LinearAlgebra.norm(v::AbstractGPUArray{T}, p::Real=2) where {T}
+function LinearAlgebra.norm(v::AnyGPUArray{T}, p::Real=2) where {T}
     result_type, sum_type, promote_ = _normtypes(T)
     isempty(v) && return zero(result_type)
     p == 0 && return convert(result_type, count(!iszero, v))
@@ -805,7 +805,7 @@ end
 ## normalize
 
 # Avoid `first(a)` scalar indexing in LinearAlgebra.normalize (JuliaGPU/CUDA.jl#3097)
-function LinearAlgebra.normalize(a::AbstractGPUArray, p::Real=2)
+function LinearAlgebra.normalize(a::AnyGPUArray, p::Real=2)
     nrm = norm(a, p)
     if !isempty(a)
         T = typeof(zero(eltype(a))/nrm)
diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl
index 266ac021..e8171233 100644
--- a/test/testsuite/linalg.jl
+++ b/test/testsuite/linalg.jl
@@ -542,6 +542,29 @@ end
         @test compare(normalize, AT, arr)
         @test compare(normalize, AT, arr, Ref(1))
     end
+    # Wrapped GPU arrays (e.g. SubArray) must also avoid scalar iteration.
+    @testset "$p-norm(view, $sz x $T)" for sz in [(5,), (5, 5), (4, 4, 4)],
+                                           p in Any[0, 1, 2, Inf],
+                                           T in eltypes
+        if T == Int8
+            continue
+        end
+        if !in(float(real(T)), eltypes)
+            continue
+        end
+        range = real(T) <: Integer ? (T.(1:10)) : T
+        arr = rand(range, sz)
+        indices = map(d -> 2:d-1, sz)
+        @test compare(x -> norm(view(x, indices...), p), AT, arr)
+    end
+    @testset "normalize(view, $T)" for T in eltypes
+        if !in(float(real(T)), eltypes)
+            continue
+        end
+        range = real(T) <: Integer ? (T.(1:10)) : T
+        arr = rand(range, 10)
+        @test compare(x -> normalize(view(x, 2:9)), AT, arr)
+    end
 end
 
 @testsuite "linalg/NaN_false" (AT, eltypes)->begin