Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "GPUCompiler"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
version = "1.13.0"
version = "1.13.1"
authors = ["Tim Besard <tim.besard@gmail.com>"]

[workspace]
Expand Down
99 changes: 85 additions & 14 deletions src/ptx.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,9 +263,11 @@ function optimize_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
@dispose pb=NewPMPassBuilder() begin
register!(pb, NVVMReflectPass())
register!(pb, PTXFDivFastPass())
register!(pb, PTXFSqrtFastPass())

add!(pb, NVVMReflectPass())
add!(pb, PTXFDivFastPass())
add!(pb, PTXFSqrtFastPass())

add!(pb, NewPMFunctionPassManager()) do fpm
# needed by GemmKernels.jl-like code
Expand Down Expand Up @@ -569,16 +571,27 @@ function f32_ftz(f::LLVM.Function)
return false
end

# Rewrite `afn`-flagged `fdiv` to NVPTX' fast lowerings. `apply_fastmath!`
# propagates job-wide `target.fastmath=true` as per-instruction `afn`, so the
# single flag check covers both per-call `@fastmath` and the job toggle. We
# emit NVPTX intrinsics directly (rather than libdevice `__nv_*`) so this
# doesn't depend on which libdevice symbols got linked in.
# Both passes below rewrite `afn`-flagged ops to NVPTX' fast lowerings.
# `apply_fastmath!` propagates job-wide `target.fastmath=true` as per-
# instruction `afn`, so a single flag check covers both per-call `@fastmath`
# and the job toggle. We emit NVPTX intrinsics by name (rather than libdevice
# `__nv_*`) so this doesn't depend on which libdevice symbols got linked in.
#
# - f32 → `llvm.nvvm.div.approx{,.ftz}.f`. Redundant on LLVM 21+, where
# `getDivF32Level` honors `afn`; LLVM 18 only consults
# `TargetMachine.Options.UnsafeFPMath`, which is unreachable through LLVM.jl.
# - f64 → `rcp.approx.ftz.d` + one Newton step (NVPTX has no fast f64 fdiv).
# Both passes are temporary backports for LLVM 18:
# - `PTXFSqrtFastPass` is fully redundant on LLVM 21+: `usePrecSqrtF32` then
# honors the per-instruction `afn` and the function `unsafe-fp-math`
# attribute, so `DAGCombiner::visitFSQRT` → `NVPTXTargetLowering::getSqrtEstimate`
# emits the f32 `sqrt.approx{,.ftz}` and f64 `rcp(rsqrt(x))` sequences
# itself. LLVM 18's `usePrecSqrtF32` only consults `TargetMachine.Options.UnsafeFPMath`,
# which is unreachable through LLVM.jl.
# - `PTXFDivFastPass`'s f32 path is similarly redundant on LLVM 21+;
# `getDivF32Level` there honors `afn` + the function attribute. The f64
# path stays needed until NVPTX gains a `getRecipEstimate` hook (filed
# upstream).

# Rewrite `afn`-flagged `fdiv`:
# - f32 → `llvm.nvvm.div.approx{,.ftz}.f`.
# - f64 → `rcp.approx.ftz.d` + one Newton step (no native fast f64 fdiv).
function ptx_fdiv_fast!(mod::LLVM.Module)
changed = false
@tracepoint "ptx-fdiv-fast" begin
Expand All @@ -605,9 +618,9 @@ function ptx_fdiv_fast!(mod::LLVM.Module)
f32_ft = LLVM.FunctionType(f32, [f32, f32])
div_f32 = declare("llvm.nvvm.div.approx.f", f32_ft)
div_f32_ftz = declare("llvm.nvvm.div.approx.ftz.f", f32_ft)
rcp_ft = LLVM.FunctionType(f64, [f64])
rcp_f64 = declare("llvm.nvvm.rcp.approx.ftz.d", rcp_ft)
fma_ft = LLVM.FunctionType(f64, [f64, f64, f64])
f64_ft1 = LLVM.FunctionType(f64, [f64])
rcp_f64 = declare("llvm.nvvm.rcp.approx.ftz.d", f64_ft1)
fma_ft = LLVM.FunctionType(f64, [f64, f64, f64])
fma_f64 = declare("llvm.fma.f64", fma_ft)
one_f64 = ConstantFP(f64, 1.0)

Expand All @@ -617,11 +630,10 @@ function ptx_fdiv_fast!(mod::LLVM.Module)
position!(builder, inst)

replacement = if is_f32
# TODO: drop f32 path once we require LLVM 21+.
f = LLVM.parent(LLVM.parent(inst))
call!(builder, f32_ft, f32_ftz(f) ? div_f32_ftz : div_f32, [lhs, rhs])
else
inv_y = call!(builder, rcp_ft, rcp_f64, [rhs])
inv_y = call!(builder, f64_ft1, rcp_f64, [rhs])
neg_rhs = fneg!(builder, rhs)
# Newton refinement, matching CUDA.jl's `FastMath.inv_fast(::Float64)`
e = call!(builder, fma_ft, fma_f64, [inv_y, neg_rhs, one_f64])
Expand All @@ -640,3 +652,62 @@ function ptx_fdiv_fast!(mod::LLVM.Module)
return changed
end
# Wrap `ptx_fdiv_fast!` as a module pass named "ptx-fdiv-fast", so it can be
# handed to `register!`/`add!` on a pass builder.
function PTXFDivFastPass()
    return NewPMModulePass("ptx-fdiv-fast", ptx_fdiv_fast!)
end

# Rewrite `afn`-flagged `llvm.sqrt.f{32,64}`:
# - f32 → `llvm.nvvm.sqrt.approx{,.ftz}.f`.
# - f64 → `rcp.approx.ftz.d(rsqrt.approx.d(x))` (no native fast f64 sqrt).
# Returns `true` when at least one `llvm.sqrt.*` call was rewritten, and
# `false` when the module was left untouched.
function ptx_fsqrt_fast!(mod::LLVM.Module)
    changed = false
    @tracepoint "ptx-fsqrt-fast" begin

    float_ty = LLVM.FloatType()
    double_ty = LLVM.DoubleType()

    # Gather the candidates up front; the rewrite further down erases
    # instructions, which is kept out of this iteration. The `Bool` records
    # whether the call is the f32 (`true`) or the f64 (`false`) variant.
    worklist = Tuple{LLVM.CallInst, Bool}[]
    for fn in functions(mod), block in blocks(fn), inst in instructions(block)
        if !(inst isa LLVM.CallInst)
            continue
        end
        target = LLVM.called_operand(inst)
        # only calls whose callee is a plain function are considered
        target isa LLVM.Function || continue
        intrinsic = LLVM.name(target)
        single = intrinsic == "llvm.sqrt.f32"
        if !single && intrinsic != "llvm.sqrt.f64"
            continue
        end
        # only calls carrying the `afn` fast-math flag are rewritten
        if LLVM.fast_math(inst).afn
            push!(worklist, (inst, single))
        end
    end
    isempty(worklist) && return false

    # Look up an already-present function by name, or add it to the module.
    existing = functions(mod)
    function declare(name, ft)
        haskey(existing, name) && return existing[name]
        return LLVM.Function(mod, name, ft)
    end

    unary_f32 = LLVM.FunctionType(float_ty, [float_ty])
    sqrt_f32 = declare("llvm.nvvm.sqrt.approx.f", unary_f32)
    sqrt_f32_ftz = declare("llvm.nvvm.sqrt.approx.ftz.f", unary_f32)
    unary_f64 = LLVM.FunctionType(double_ty, [double_ty])
    rcp_f64 = declare("llvm.nvvm.rcp.approx.ftz.d", unary_f64)
    rsqrt_f64 = declare("llvm.nvvm.rsqrt.approx.d", unary_f64)

    @dispose builder=IRBuilder() begin
        for (old, single) in worklist
            arg = operands(old)[1]
            # emit the replacement right where the original call sits
            position!(builder, old)

            if single
                # the flush-to-zero variant is selected per enclosing function
                enclosing = LLVM.parent(LLVM.parent(old))
                approx = f32_ftz(enclosing) ? sqrt_f32_ftz : sqrt_f32
                replacement = call!(builder, unary_f32, approx, [arg])
            else
                # No native fast f64 sqrt; emit the same `rcp(rsqrt(x))`
                # sequence NVPTX' `getSqrtEstimate` would have used.
                recip_root = call!(builder, unary_f64, rsqrt_f64, [arg])
                replacement = call!(builder, unary_f64, rcp_f64, [recip_root])
            end

            replace_uses!(old, replacement)
            erase!(old)
            changed = true
        end
    end

    end # @tracepoint
    return changed
end
# Wrap `ptx_fsqrt_fast!` as a module pass named "ptx-fsqrt-fast", so it can be
# handed to `register!`/`add!` on a pass builder.
function PTXFSqrtFastPass()
    return NewPMModulePass("ptx-fsqrt-fast", ptx_fsqrt_fast!)
end
66 changes: 46 additions & 20 deletions test/ptx.jl
Original file line number Diff line number Diff line change
Expand Up @@ -446,11 +446,12 @@ end
PTX.code_native(mod.kernel, Tuple{Ptr{Float32},Ptr{Float32}})
end

# with fastmath, the entry function carries the attributes, the sqrt call
# picks up fast-math flags, and PTX selects the approx+ftz variant.
# with fastmath, the entry function carries the attributes, and
# `PTXFSqrtFastPass` rewrites the `afn`-flagged `llvm.sqrt.f32` to the
# NVPTX approx-ftz intrinsic; PTX then selects the approx+ftz variant.
@test @filecheck begin
@check_label "define void @{{(julia|j)_kernel_[0-9]+}}"
@check "call fast float @llvm.sqrt.f32"
@check "call float @llvm.nvvm.sqrt.approx.ftz.f"
@check "\"denormal-fp-math-f32\"=\"preserve-sign,preserve-sign\""
@check "\"unsafe-fp-math\"=\"true\""
PTX.code_llvm(mod.kernel, Tuple{Ptr{Float32},Ptr{Float32}};
Expand All @@ -462,36 +463,61 @@ end
end
end

@testset "fastmath division" begin
# `PTXFDivFastPass` rewrites `afn`-flagged fdiv. f32 → `div.approx{,.ftz}.f32`
# (filling in for LLVM 18, whose `getDivF32Level` doesn't honor per-call
# `afn`); f64 → `rcp.approx.ftz.d` + Newton refinement (NVPTX has no fast
# f64 fdiv lowering). Job-wide `fastmath=true` reaches this through the
# per-instruction flags `apply_fastmath!` stamps in `finish_linked_module!`.
mod_fast = @eval module $(gensym())
kernel_f32(x::Float32, y::Float32) = @fastmath x / y
kernel_f64(x::Float64, y::Float64) = @fastmath x / y
end
mod_precise = @eval module $(gensym())
kernel_f32(x::Float32, y::Float32) = x / y
kernel_f64(x::Float64, y::Float64) = x / y
@testset "fastmath fdiv/fsqrt" begin
# `PTXFDivFastPass` and `PTXFSqrtFastPass` rewrite `afn`-flagged fdiv and
# `llvm.sqrt.*` calls, respectively. f32 ops →
# `*.approx{,.ftz}.f32` (filling in for LLVM 18, whose
# `getDivF32Level`/`usePrecSqrtF32` don't honor per-call `afn`). f64
# fdiv → `rcp.approx.ftz.d` + Newton refinement; f64 sqrt →
# `rcp.approx.ftz.d(rsqrt.approx.d(x))` (NVPTX has no native fast f64
# fdiv/sqrt). Job-wide `fastmath=true` reaches this through the per-
# instruction `afn` flags `apply_fastmath!` stamps in
# `finish_linked_module!`.
mod = @eval module $(gensym())
fdiv32_fast(x::Float32, y::Float32) = @fastmath x / y
fdiv32(x::Float32, y::Float32) = x / y
fdiv64_fast(x::Float64, y::Float64) = @fastmath x / y
fdiv64(x::Float64, y::Float64) = x / y
fsqrt32_fast(x::Float32) = @fastmath sqrt(x)
fsqrt32(x::Float32) = sqrt(x)
fsqrt64_fast(x::Float64) = @fastmath sqrt(x)
fsqrt64(x::Float64) = sqrt(x)
end

@test @filecheck begin
@check "div.approx.f32"
PTX.code_native(mod_fast.kernel_f32, Tuple{Float32, Float32})
PTX.code_native(mod.fdiv32_fast, Tuple{Float32, Float32})
end
@test @filecheck begin
@check_not "div.approx"
PTX.code_native(mod_precise.kernel_f32, Tuple{Float32, Float32})
PTX.code_native(mod.fdiv32, Tuple{Float32, Float32})
end
@test @filecheck begin
@check "rcp.approx.ftz.f64"
PTX.code_native(mod_fast.kernel_f64, Tuple{Float64, Float64})
PTX.code_native(mod.fdiv64_fast, Tuple{Float64, Float64})
end
@test @filecheck begin
@check_not "rcp.approx"
PTX.code_native(mod_precise.kernel_f64, Tuple{Float64, Float64})
PTX.code_native(mod.fdiv64, Tuple{Float64, Float64})
end

@test @filecheck begin
@check "sqrt.approx.f32"
PTX.code_native(mod.fsqrt32_fast, Tuple{Float32})
end
@test @filecheck begin
@check "sqrt.rn.f32"
@check_not "sqrt.approx"
PTX.code_native(mod.fsqrt32, Tuple{Float32})
end
@test @filecheck begin
@check "rsqrt.approx.f64"
@check "rcp.approx.ftz.f64"
PTX.code_native(mod.fsqrt64_fast, Tuple{Float64})
end
@test @filecheck begin
@check "sqrt.rn.f64"
@check_not "rsqrt"
PTX.code_native(mod.fsqrt64, Tuple{Float64})
end
end

Expand Down
Loading