From a5dac76a00939a95f15b6b1b7d2660682ef8cefb Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 15 Apr 2026 09:36:34 +0200 Subject: [PATCH] Override ^(Float, Int64) to avoid Float64 widening Base.`^(::Union{Float16,Float32,Float64}, ::Int64)` widens Float32/16 through Float64 (via `power_by_squaring`), which breaks on backends without FP64 support like oneAPI on some devices, and leaves a runtime `pown` call in the generated code in all cases. --- lib/intrinsics/Project.toml | 2 +- lib/intrinsics/src/math.jl | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/lib/intrinsics/Project.toml b/lib/intrinsics/Project.toml index 5f2dafaa..48f65a77 100644 --- a/lib/intrinsics/Project.toml +++ b/lib/intrinsics/Project.toml @@ -1,7 +1,7 @@ name = "SPIRVIntrinsics" uuid = "71d1d633-e7e8-4a92-83a1-de8814b09ba8" authors = ["Tim Besard "] -version = "0.5.7" +version = "0.5.8" [deps] ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" diff --git a/lib/intrinsics/src/math.jl b/lib/intrinsics/src/math.jl index 39117bb1..18993055 100644 --- a/lib/intrinsics/src/math.jl +++ b/lib/intrinsics/src/math.jl @@ -1,5 +1,7 @@ # Math Functions +using Base: @assume_effects + # TODO: vector types const generic_types = [Float16, Float32, Float64] const generic_types_float = [Float32] @@ -183,6 +185,36 @@ end # pown(x::Float64{n}, y::Int32{n}) = @builtin_ccall("pown", Float64{n}, (Float64{n}, Int32{n}), x, y) @device_override Base.:(^)(x::Float64, y::Int32) = @builtin_ccall("pown", Float64, (Float64, Int32), x, y) +# Base's `^(::Union{Float16,Float32,Float64}, ::Int64)` widens Float32/16 +# through Float64 (broken on backends without FP64) and in all cases leaves a +# runtime `pown` in the generated code. Mark the override `:foldable` so +# literal expressions like `Float32(2)^(-32)` const-fold to a compile-time +# constant, and recurse into the existing `::Int32` overrides for the tail. +@device_override @assume_effects :foldable @inline function Base.:(^)(x::Float16, y::Int64) + y == -1 && return inv(x) + y == 0 && return one(x) + y == 1 && return x + y == 2 && return x * x + y == 3 && return x * x * x + x ^ (y % Int32) +end +@device_override @assume_effects :foldable @inline function Base.:(^)(x::Float32, y::Int64) + y == -1 && return inv(x) + y == 0 && return one(x) + y == 1 && return x + y == 2 && return x * x + y == 3 && return x * x * x + x ^ (y % Int32) +end +@device_override @assume_effects :foldable @inline function Base.:(^)(x::Float64, y::Int64) + y == -1 && return inv(x) + y == 0 && return one(x) + y == 1 && return x + y == 2 && return x * x + y == 3 && return x * x * x + x ^ (y % Int32) +end + # remquo(x::Float32{n}, y::Float32{n}, Int32{n} *quo) = @builtin_ccall("remquo", Float32{n}, (Float32{n}, Float32{n}, Int32{n} *), x, y, quo) # remquo(x::Float32, y::Float32, Int32 *quo) = @builtin_ccall("remquo", Float32, (Float32, Float32, Int32 *), x::Float32, y, quo) # remquo(x::Float64{n}, y::Float64{n}, Int32{n} *quo) = @builtin_ccall("remquo", Float64{n}, (Float64{n}, Float64{n}, Int32{n} *), x, y, quo)