diff --git a/src/ptx.jl b/src/ptx.jl index 66880850..2082335f 100644 --- a/src/ptx.jl +++ b/src/ptx.jl @@ -154,8 +154,10 @@ function optimize_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}), # TODO: Use the registered target passes (JuliaGPU/GPUCompiler.jl#450) @dispose pb=NewPMPassBuilder() begin register!(pb, NVVMReflectPass()) + register!(pb, PTXFDivFastPass()) add!(pb, NVVMReflectPass()) + add!(pb, PTXFDivFastPass()) add!(pb, NewPMFunctionPassManager()) do fpm # needed by GemmKernels.jl-like code @@ -486,3 +488,72 @@ function nvvm_reflect!(mod::LLVM.Module) return changed end NVVMReflectPass() = NewPMModulePass("custom-nvvm-reflect", nvvm_reflect!) + +# Triggered by the per-instruction `afn` fast-math flag or by target.fastmath=true. +# Float32 → __nv_fast_fdividef; Float64 → rcp.approx.ftz.d + Newton refinement. +function ptx_fdiv_fast!(mod::LLVM.Module) + job = current_job::CompilerJob + global_fastmath = job.config.target.fastmath + changed = false + @tracepoint "ptx-fdiv-fast" begin + + f32 = LLVM.FloatType() + f64 = LLVM.DoubleType() + + # Collect first to avoid mutation-during-iteration + to_replace = Tuple{LLVM.FDivInst, Bool}[] + for f in functions(mod), bb in blocks(f), inst in instructions(bb) + inst isa LLVM.FDivInst || continue + typ = LLVM.value_type(inst) + is_f32 = typ == f32 + is_f64 = typ == f64 + (is_f32 || is_f64) || continue + fmf = LLVM.fast_math(inst) + (fmf.afn || global_fastmath) || continue + push!(to_replace, (inst, is_f32)) + end + + isempty(to_replace) && return false + + # Hoist all declarations and constants — looked up once, reused across all replacements. + fns = functions(mod) + f32_ft = LLVM.FunctionType(f32, [f32, f32]) + f32_fn = haskey(fns, "__nv_fast_fdividef") ? + fns["__nv_fast_fdividef"] : LLVM.Function(mod, "__nv_fast_fdividef", f32_ft) + # Declare by name so LLVM keeps the exact (non-overloaded) intrinsic name; + # LLVM.Intrinsic + type params would mangle to *.f64, unrecognized by the NVPTX backend. + rcp_ft = LLVM.FunctionType(f64, [f64]) + rcp_fn = haskey(fns, "llvm.nvvm.rcp.approx.ftz.d") ? + fns["llvm.nvvm.rcp.approx.ftz.d"] : LLVM.Function(mod, "llvm.nvvm.rcp.approx.ftz.d", rcp_ft) + fma_ft = LLVM.FunctionType(f64, [f64, f64, f64]) + fma_fn = haskey(fns, "llvm.fma.f64") ? + fns["llvm.fma.f64"] : LLVM.Function(mod, "llvm.fma.f64", fma_ft) + one_f64 = ConstantFP(f64, 1.0) + + @dispose builder=IRBuilder() begin + for (inst, is_f32) in to_replace + lhs, rhs = operands(inst)[1], operands(inst)[2] + position!(builder, inst) + + replacement = if is_f32 + call!(builder, f32_ft, f32_fn, [lhs, rhs]) + else + inv_y = call!(builder, rcp_ft, rcp_fn, [rhs]) + neg_rhs = fneg!(builder, rhs) + # Newton refinement matching CUDA.jl's inv_fast(::Float64) + e = call!(builder, fma_ft, fma_fn, [inv_y, neg_rhs, one_f64]) + e = call!(builder, fma_ft, fma_fn, [e, e, e]) + inv_ref = call!(builder, fma_ft, fma_fn, [e, inv_y, inv_y]) + fmul!(builder, lhs, inv_ref) + end + + replace_uses!(inst, replacement) + erase!(inst) + changed = true + end + end + + end # @tracepoint + return changed +end +PTXFDivFastPass() = NewPMModulePass("ptx-fdiv-fast", ptx_fdiv_fast!) diff --git a/test/helpers/ptx.jl b/test/helpers/ptx.jl index e82416bc..a12dc01c 100644 --- a/test/helpers/ptx.jl +++ b/test/helpers/ptx.jl @@ -38,10 +38,11 @@ GPUCompiler.runtime_module(::PTXCompilerJob) = PTXTestRuntime function create_job(@nospecialize(func), @nospecialize(types); minthreads=nothing, maxthreads=nothing, blocks_per_sm=nothing, maxregs=nothing, + fastmath=false, kwargs...) config_kwargs, kwargs = split_kwargs(kwargs, GPUCompiler.CONFIG_KWARGS) source = methodinstance(typeof(func), Base.to_tuple_type(types), Base.get_world_counter()) - target = PTXCompilerTarget(; cap=v"7.0", minthreads, maxthreads, blocks_per_sm, maxregs) + target = PTXCompilerTarget(; cap=v"7.0", minthreads, maxthreads, blocks_per_sm, maxregs, fastmath) params = CompilerParams() config = CompilerConfig(target, params; kernel=false, config_kwargs...) CompilerJob(source, config), kwargs diff --git a/test/ptx.jl b/test/ptx.jl index 7010917a..eb5821ef 100644 --- a/test/ptx.jl +++ b/test/ptx.jl @@ -419,5 +419,46 @@ end PTX.code_native(devnull, mod.kernel, Tuple{Float32,Ptr{Float32}}) end +@testset "fastmath division" begin + mod_fast = @eval module $(gensym()) + function kernel_f32(x::Float32, y::Float32) + @fastmath x / y + end + function kernel_f64(x::Float64, y::Float64) + @fastmath x / y + end + end + + mod_precise = @eval module $(gensym()) + function kernel_f32(x::Float32, y::Float32) + x / y + end + function kernel_f64(x::Float64, y::Float64) + x / y + end + end + + @test @filecheck begin + @check "__nv_fast_fdividef" + PTX.code_native(mod_fast.kernel_f32, Tuple{Float32, Float32}) + end + + @test @filecheck begin + @check_not "__nv_fast_fdividef" + @check_not "div.approx" + PTX.code_native(mod_precise.kernel_f32, Tuple{Float32, Float32}) + end + + @test @filecheck begin + @check "rcp.approx.ftz.f64" + PTX.code_native(mod_fast.kernel_f64, Tuple{Float64, Float64}) + end + + @test @filecheck begin + @check_not "rcp.approx" + PTX.code_native(mod_precise.kernel_f64, Tuple{Float64, Float64}) + end +end + end end # NVPTX in LLVM.backends()