diff --git a/Project.toml b/Project.toml index ddaab1f9..9147b6f7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "GPUCompiler" uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "1.13.0" +version = "1.13.1" authors = ["Tim Besard "] [workspace] diff --git a/src/ptx.jl b/src/ptx.jl index ebfce4f7..d2c97b21 100644 --- a/src/ptx.jl +++ b/src/ptx.jl @@ -263,9 +263,11 @@ function optimize_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}), @dispose pb=NewPMPassBuilder() begin register!(pb, NVVMReflectPass()) register!(pb, PTXFDivFastPass()) + register!(pb, PTXFSqrtFastPass()) add!(pb, NVVMReflectPass()) add!(pb, PTXFDivFastPass()) + add!(pb, PTXFSqrtFastPass()) add!(pb, NewPMFunctionPassManager()) do fpm # needed by GemmKernels.jl-like code @@ -569,16 +571,27 @@ function f32_ftz(f::LLVM.Function) return false end -# Rewrite `afn`-flagged `fdiv` to NVPTX' fast lowerings. `apply_fastmath!` -# propagates job-wide `target.fastmath=true` as per-instruction `afn`, so the -# single flag check covers both per-call `@fastmath` and the job toggle. We -# emit NVPTX intrinsics directly (rather than libdevice `__nv_*`) so this -# doesn't depend on which libdevice symbols got linked in. +# Both passes below rewrite `afn`-flagged ops to NVPTX' fast lowerings. +# `apply_fastmath!` propagates job-wide `target.fastmath=true` as per- +# instruction `afn`, so a single flag check covers both per-call `@fastmath` +# and the job toggle. We emit NVPTX intrinsics by name (rather than libdevice +# `__nv_*`) so this doesn't depend on which libdevice symbols got linked in. # -# - f32 → `llvm.nvvm.div.approx{,.ftz}.f`. Redundant on LLVM 21+, where -# `getDivF32Level` honors `afn`; LLVM 18 only consults -# `TargetMachine.Options.UnsafeFPMath`, which is unreachable through LLVM.jl. -# - f64 → `rcp.approx.ftz.d` + one Newton step (NVPTX has no fast f64 fdiv). +# Both passes are temporary backports for LLVM 18: +# - `PTXFSqrtFastPass` is fully redundant on LLVM 21+: `usePrecSqrtF32` then +# honors the per-instruction `afn` and the function `unsafe-fp-math` +# attribute, so `DAGCombiner::visitFSQRT` → `NVPTXTargetLowering::getSqrtEstimate` +# emits the f32 `sqrt.approx{,.ftz}` and f64 `rcp(rsqrt(x))` sequences +# itself. LLVM 18's `usePrecSqrtF32` only consults `TargetMachine.Options.UnsafeFPMath`, +# which is unreachable through LLVM.jl. +# - `PTXFDivFastPass`'s f32 path is similarly redundant on LLVM 21+; +# `getDivF32Level` there honors `afn` + the function attribute. The f64 +# path stays needed until NVPTX gains a `getRecipEstimate` hook (filed +# upstream). + +# Rewrite `afn`-flagged `fdiv`: +# - f32 → `llvm.nvvm.div.approx{,.ftz}.f`. +# - f64 → `rcp.approx.ftz.d` + one Newton step (no native fast f64 fdiv). function ptx_fdiv_fast!(mod::LLVM.Module) changed = false @tracepoint "ptx-fdiv-fast" begin @@ -605,9 +618,9 @@ function ptx_fdiv_fast!(mod::LLVM.Module) f32_ft = LLVM.FunctionType(f32, [f32, f32]) div_f32 = declare("llvm.nvvm.div.approx.f", f32_ft) div_f32_ftz = declare("llvm.nvvm.div.approx.ftz.f", f32_ft) - rcp_ft = LLVM.FunctionType(f64, [f64]) - rcp_f64 = declare("llvm.nvvm.rcp.approx.ftz.d", rcp_ft) - fma_ft = LLVM.FunctionType(f64, [f64, f64, f64]) + f64_ft1 = LLVM.FunctionType(f64, [f64]) + rcp_f64 = declare("llvm.nvvm.rcp.approx.ftz.d", f64_ft1) + fma_ft = LLVM.FunctionType(f64, [f64, f64, f64]) fma_f64 = declare("llvm.fma.f64", fma_ft) one_f64 = ConstantFP(f64, 1.0) @@ -617,11 +630,10 @@ function ptx_fdiv_fast!(mod::LLVM.Module) position!(builder, inst) replacement = if is_f32 - # TODO: drop f32 path once we require LLVM 21+. f = LLVM.parent(LLVM.parent(inst)) call!(builder, f32_ft, f32_ftz(f) ? div_f32_ftz : div_f32, [lhs, rhs]) else - inv_y = call!(builder, rcp_ft, rcp_f64, [rhs]) + inv_y = call!(builder, f64_ft1, rcp_f64, [rhs]) neg_rhs = fneg!(builder, rhs) # Newton refinement, matching CUDA.jl's `FastMath.inv_fast(::Float64)` e = call!(builder, fma_ft, fma_f64, [inv_y, neg_rhs, one_f64]) @@ -640,3 +652,62 @@ function ptx_fdiv_fast!(mod::LLVM.Module) return changed end PTXFDivFastPass() = NewPMModulePass("ptx-fdiv-fast", ptx_fdiv_fast!) + +# Rewrite `afn`-flagged `llvm.sqrt.f{32,64}`: +# - f32 → `llvm.nvvm.sqrt.approx{,.ftz}.f`. +# - f64 → `rcp.approx.ftz.d(rsqrt.approx.d(x))` (no native fast f64 sqrt). +function ptx_fsqrt_fast!(mod::LLVM.Module) + changed = false + @tracepoint "ptx-fsqrt-fast" begin + + f32 = LLVM.FloatType() + f64 = LLVM.DoubleType() + + to_replace = Tuple{LLVM.CallInst, Bool}[] + for f in functions(mod), bb in blocks(f), inst in instructions(bb) + inst isa LLVM.CallInst || continue + callee = LLVM.called_operand(inst) + callee isa LLVM.Function || continue + name = LLVM.name(callee) + is_f32 = name == "llvm.sqrt.f32" + is_f64 = name == "llvm.sqrt.f64" + (is_f32 || is_f64) || continue + LLVM.fast_math(inst).afn || continue + push!(to_replace, (inst, is_f32)) + end + isempty(to_replace) && return false + + fns = functions(mod) + declare(name, ft) = haskey(fns, name) ? fns[name] : LLVM.Function(mod, name, ft) + f32_ft = LLVM.FunctionType(f32, [f32]) + sqrt_f32 = declare("llvm.nvvm.sqrt.approx.f", f32_ft) + sqrt_f32_ftz = declare("llvm.nvvm.sqrt.approx.ftz.f", f32_ft) + f64_ft = LLVM.FunctionType(f64, [f64]) + rcp_f64 = declare("llvm.nvvm.rcp.approx.ftz.d", f64_ft) + rsqrt_f64 = declare("llvm.nvvm.rsqrt.approx.d", f64_ft) + + @dispose builder=IRBuilder() begin + for (inst, is_f32) in to_replace + x = operands(inst)[1] + position!(builder, inst) + + replacement = if is_f32 + f = LLVM.parent(LLVM.parent(inst)) + call!(builder, f32_ft, f32_ftz(f) ? sqrt_f32_ftz : sqrt_f32, [x]) + else + # No native fast f64 sqrt; emit the same `rcp(rsqrt(x))` + # sequence NVPTX' `getSqrtEstimate` would have used. + rsqrt = call!(builder, f64_ft, rsqrt_f64, [x]) + call!(builder, f64_ft, rcp_f64, [rsqrt]) + end + + replace_uses!(inst, replacement) + erase!(inst) + changed = true + end + end + + end # @tracepoint + return changed +end +PTXFSqrtFastPass() = NewPMModulePass("ptx-fsqrt-fast", ptx_fsqrt_fast!) diff --git a/test/ptx.jl b/test/ptx.jl index 34466ef3..b411e533 100644 --- a/test/ptx.jl +++ b/test/ptx.jl @@ -446,11 +446,12 @@ end PTX.code_native(mod.kernel, Tuple{Ptr{Float32},Ptr{Float32}}) end - # with fastmath, the entry function carries the attributes, the sqrt call - # picks up fast-math flags, and PTX selects the approx+ftz variant. + # with fastmath, the entry function carries the attributes, and + # `PTXFSqrtFastPass` rewrites the `afn`-flagged `llvm.sqrt.f32` to the + # NVPTX approx-ftz intrinsic; PTX then selects the approx+ftz variant. @test @filecheck begin @check_label "define void @{{(julia|j)_kernel_[0-9]+}}" - @check "call fast float @llvm.sqrt.f32" + @check "call float @llvm.nvvm.sqrt.approx.ftz.f" @check "\"denormal-fp-math-f32\"=\"preserve-sign,preserve-sign\"" @check "\"unsafe-fp-math\"=\"true\"" PTX.code_llvm(mod.kernel, Tuple{Ptr{Float32},Ptr{Float32}}; @@ -462,36 +463,61 @@ end end end -@testset "fastmath division" begin - # `PTXFDivFastPass` rewrites `afn`-flagged fdiv. f32 → `div.approx{,.ftz}.f32` - # (filling in for LLVM 18, whose `getDivF32Level` doesn't honor per-call - # `afn`); f64 → `rcp.approx.ftz.d` + Newton refinement (NVPTX has no fast - # f64 fdiv lowering). Job-wide `fastmath=true` reaches this through the - # per-instruction flags `apply_fastmath!` stamps in `finish_linked_module!`. - mod_fast = @eval module $(gensym()) - kernel_f32(x::Float32, y::Float32) = @fastmath x / y - kernel_f64(x::Float64, y::Float64) = @fastmath x / y - end - mod_precise = @eval module $(gensym()) - kernel_f32(x::Float32, y::Float32) = x / y - kernel_f64(x::Float64, y::Float64) = x / y +@testset "fastmath fdiv/fsqrt" begin + # `PTXFDivFastPass` rewrites `afn`-flagged fdiv/`llvm.sqrt.*`. f32 ops → + # `*.approx{,.ftz}.f32` (filling in for LLVM 18, whose + # `getDivF32Level`/`usePrecSqrtF32` don't honor per-call `afn`). f64 + # fdiv → `rcp.approx.ftz.d` + Newton refinement; f64 sqrt → + # `rcp.approx.ftz.d(rsqrt.approx.d(x))` (NVPTX has no native fast f64 + # fdiv/sqrt). Job-wide `fastmath=true` reaches this through the per- + # instruction `afn` flags `apply_fastmath!` stamps in + # `finish_linked_module!`. + mod = @eval module $(gensym()) + fdiv32_fast(x::Float32, y::Float32) = @fastmath x / y + fdiv32(x::Float32, y::Float32) = x / y + fdiv64_fast(x::Float64, y::Float64) = @fastmath x / y + fdiv64(x::Float64, y::Float64) = x / y + fsqrt32_fast(x::Float32) = @fastmath sqrt(x) + fsqrt32(x::Float32) = sqrt(x) + fsqrt64_fast(x::Float64) = @fastmath sqrt(x) + fsqrt64(x::Float64) = sqrt(x) end @test @filecheck begin @check "div.approx.f32" - PTX.code_native(mod_fast.kernel_f32, Tuple{Float32, Float32}) + PTX.code_native(mod.fdiv32_fast, Tuple{Float32, Float32}) end @test @filecheck begin @check_not "div.approx" - PTX.code_native(mod_precise.kernel_f32, Tuple{Float32, Float32}) + PTX.code_native(mod.fdiv32, Tuple{Float32, Float32}) end @test @filecheck begin @check "rcp.approx.ftz.f64" - PTX.code_native(mod_fast.kernel_f64, Tuple{Float64, Float64}) + PTX.code_native(mod.fdiv64_fast, Tuple{Float64, Float64}) end @test @filecheck begin @check_not "rcp.approx" - PTX.code_native(mod_precise.kernel_f64, Tuple{Float64, Float64}) + PTX.code_native(mod.fdiv64, Tuple{Float64, Float64}) + end + + @test @filecheck begin + @check "sqrt.approx.f32" + PTX.code_native(mod.fsqrt32_fast, Tuple{Float32}) + end + @test @filecheck begin + @check "sqrt.rn.f32" + @check_not "sqrt.approx" + PTX.code_native(mod.fsqrt32, Tuple{Float32}) + end + @test @filecheck begin + @check "rsqrt.approx.f64" + @check "rcp.approx.ftz.f64" + PTX.code_native(mod.fsqrt64_fast, Tuple{Float64}) + end + @test @filecheck begin + @check "sqrt.rn.f64" + @check_not "rsqrt" + PTX.code_native(mod.fsqrt64, Tuple{Float64}) end end