Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions src/ptx.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,10 @@ function optimize_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
# TODO: Use the registered target passes (JuliaGPU/GPUCompiler.jl#450)
@dispose pb=NewPMPassBuilder() begin
register!(pb, NVVMReflectPass())
register!(pb, PTXFDivFastPass())

add!(pb, NVVMReflectPass())
add!(pb, PTXFDivFastPass())

add!(pb, NewPMFunctionPassManager()) do fpm
# needed by GemmKernels.jl-like code
Expand Down Expand Up @@ -486,3 +488,72 @@ function nvvm_reflect!(mod::LLVM.Module)
return changed
end
NVVMReflectPass() = NewPMModulePass("custom-nvvm-reflect", nvvm_reflect!)

# Triggered by the per-instruction `afn` fast-math flag or by target.fastmath=true.
# Float32 → __nv_fast_fdividef; Float64 → rcp.approx.ftz.d + Newton refinement.
function ptx_fdiv_fast!(mod::LLVM.Module)
job = current_job::CompilerJob
global_fastmath = job.config.target.fastmath
changed = false
@tracepoint "ptx-fdiv-fast" begin

f32 = LLVM.FloatType()
f64 = LLVM.DoubleType()

# Collect first to avoid mutation-during-iteration
to_replace = Tuple{LLVM.FDivInst, Bool}[]
for f in functions(mod), bb in blocks(f), inst in instructions(bb)
inst isa LLVM.FDivInst || continue
typ = LLVM.value_type(inst)
is_f32 = typ == f32
is_f64 = typ == f64
(is_f32 || is_f64) || continue
fmf = LLVM.fast_math(inst)
(fmf.afn || global_fastmath) || continue
push!(to_replace, (inst, is_f32))
end

isempty(to_replace) && return false

# Hoist all declarations and constants — looked up once, reused across all replacements.
fns = functions(mod)
f32_ft = LLVM.FunctionType(f32, [f32, f32])
f32_fn = haskey(fns, "__nv_fast_fdividef") ?
fns["__nv_fast_fdividef"] : LLVM.Function(mod, "__nv_fast_fdividef", f32_ft)
# Declare by name so LLVM keeps the exact (non-overloaded) intrinsic name;
# LLVM.Intrinsic + type params would mangle to *.f64, unrecognized by the NVPTX backend.
rcp_ft = LLVM.FunctionType(f64, [f64])
rcp_fn = haskey(fns, "llvm.nvvm.rcp.approx.ftz.d") ?
fns["llvm.nvvm.rcp.approx.ftz.d"] : LLVM.Function(mod, "llvm.nvvm.rcp.approx.ftz.d", rcp_ft)
fma_ft = LLVM.FunctionType(f64, [f64, f64, f64])
fma_fn = haskey(fns, "llvm.fma.f64") ?
fns["llvm.fma.f64"] : LLVM.Function(mod, "llvm.fma.f64", fma_ft)
one_f64 = ConstantFP(f64, 1.0)

@dispose builder=IRBuilder() begin
for (inst, is_f32) in to_replace
lhs, rhs = operands(inst)[1], operands(inst)[2]
position!(builder, inst)

replacement = if is_f32
call!(builder, f32_ft, f32_fn, [lhs, rhs])
else
inv_y = call!(builder, rcp_ft, rcp_fn, [rhs])
neg_rhs = fneg!(builder, rhs)
# Newton refinement matching CUDA.jl's inv_fast(::Float64)
e = call!(builder, fma_ft, fma_fn, [inv_y, neg_rhs, one_f64])
e = call!(builder, fma_ft, fma_fn, [e, e, e])
inv_ref = call!(builder, fma_ft, fma_fn, [e, inv_y, inv_y])
fmul!(builder, lhs, inv_ref)
end

replace_uses!(inst, replacement)
erase!(inst)
changed = true
end
end

end # @tracepoint
return changed
end
PTXFDivFastPass() = NewPMModulePass("ptx-fdiv-fast", ptx_fdiv_fast!)
3 changes: 2 additions & 1 deletion test/helpers/ptx.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,11 @@ GPUCompiler.runtime_module(::PTXCompilerJob) = PTXTestRuntime
function create_job(@nospecialize(func), @nospecialize(types);
minthreads=nothing, maxthreads=nothing,
blocks_per_sm=nothing, maxregs=nothing,
fastmath=false,
kwargs...)
config_kwargs, kwargs = split_kwargs(kwargs, GPUCompiler.CONFIG_KWARGS)
source = methodinstance(typeof(func), Base.to_tuple_type(types), Base.get_world_counter())
target = PTXCompilerTarget(; cap=v"7.0", minthreads, maxthreads, blocks_per_sm, maxregs)
target = PTXCompilerTarget(; cap=v"7.0", minthreads, maxthreads, blocks_per_sm, maxregs, fastmath)
params = CompilerParams()
config = CompilerConfig(target, params; kernel=false, config_kwargs...)
CompilerJob(source, config), kwargs
Expand Down
41 changes: 41 additions & 0 deletions test/ptx.jl
Original file line number Diff line number Diff line change
Expand Up @@ -419,5 +419,46 @@ end
PTX.code_native(devnull, mod.kernel, Tuple{Float32,Ptr{Float32}})
end

@testset "fastmath division" begin
mod_fast = @eval module $(gensym())
function kernel_f32(x::Float32, y::Float32)
@fastmath x / y
end
function kernel_f64(x::Float64, y::Float64)
@fastmath x / y
end
end

mod_precise = @eval module $(gensym())
function kernel_f32(x::Float32, y::Float32)
x / y
end
function kernel_f64(x::Float64, y::Float64)
x / y
end
end

@test @filecheck begin
@check "__nv_fast_fdividef"
PTX.code_native(mod_fast.kernel_f32, Tuple{Float32, Float32})
end

@test @filecheck begin
@check_not "__nv_fast_fdividef"
@check_not "div.approx"
PTX.code_native(mod_precise.kernel_f32, Tuple{Float32, Float32})
end

@test @filecheck begin
@check "rcp.approx.ftz.f64"
PTX.code_native(mod_fast.kernel_f64, Tuple{Float64, Float64})
end

@test @filecheck begin
@check_not "rcp.approx"
PTX.code_native(mod_precise.kernel_f64, Tuple{Float64, Float64})
end
end

end
end # NVPTX in LLVM.backends()