Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 69 additions & 110 deletions test/ptx.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
@testset "IR" begin

@testset "exceptions" begin
mod = @eval module $(gensym())
foobar() = throw(DivideError())
end
@test @filecheck begin
@check_label "define void @{{(julia|j)_foobar_[0-9]+}}"
# plain exceptions should get lowered to a call to the GPU run-time
# not a jl_throw referencing a jl_value_t representing the exception
# plain exceptions should get lowered to a call to the GPU run-time, not a
# jl_throw referencing a jl_value_t representing the exception
@test @filecheck PTX.code_llvm(Tuple{}; dump_module=true) do
@check_not "jl_throw"
@check "gpu_report_exception"

PTX.code_llvm(mod.foobar, Tuple{}; dump_module=true)
throw(DivideError())
end
end

Expand Down Expand Up @@ -41,62 +36,48 @@ end
end

@testset "property_annotations" begin
mod = @eval module $(gensym())
kernel() = return
end

@test @filecheck begin
@test @filecheck PTX.code_llvm(Tuple{}; dump_module=true) do
@check_not "nvvm.annotations"
PTX.code_llvm(mod.kernel, Tuple{}; dump_module=true)
return
end

@test @filecheck begin
@test @filecheck PTX.code_llvm(Tuple{}; dump_module=true, kernel=true) do
@check_not "maxntid"
@check_not "reqntid"
@check_not "minctasm"
@check_not "maxnreg"
@check "nvvm.annotations"
PTX.code_llvm(mod.kernel, Tuple{}; dump_module=true, kernel=true)
return
end

@test @filecheck begin
@test @filecheck PTX.code_llvm(Tuple{}; dump_module=true, kernel=true, maxthreads=42) do
@check "maxntidx\", i32 42"
@check "maxntidy\", i32 1"
@check "maxntidz\", i32 1"
PTX.code_llvm(mod.kernel, Tuple{}; dump_module=true, kernel=true, maxthreads=42)
return
end

@test @filecheck begin
@test @filecheck PTX.code_llvm(Tuple{}; dump_module=true, kernel=true, minthreads=42) do
@check "reqntidx\", i32 42"
@check "reqntidy\", i32 1"
@check "reqntidz\", i32 1"
PTX.code_llvm(mod.kernel, Tuple{}; dump_module=true, kernel=true, minthreads=42)
return
end

@test @filecheck begin
@test @filecheck PTX.code_llvm(Tuple{}; dump_module=true, kernel=true, blocks_per_sm=42) do
@check "minctasm\", i32 42"
PTX.code_llvm(mod.kernel, Tuple{}; dump_module=true, kernel=true, blocks_per_sm=42)
return
end

@test @filecheck begin
@test @filecheck PTX.code_llvm(Tuple{}; dump_module=true, kernel=true, maxregs=42) do
@check "maxnreg\", i32 42"
PTX.code_llvm(mod.kernel, Tuple{}; dump_module=true, kernel=true, maxregs=42)
return
end
end

LLVM.version() >= v"8" && @testset "calling convention" begin
mod = @eval module $(gensym())
kernel() = return
end

@test @filecheck begin
@test @filecheck PTX.code_llvm(Tuple{}; dump_module=true) do
@check_not "ptx_kernel"
PTX.code_llvm(mod.kernel, Tuple{}; dump_module=true)
return
end

@test @filecheck begin
@test @filecheck PTX.code_llvm(Tuple{}; dump_module=true, kernel=true) do
@check "ptx_kernel"
PTX.code_llvm(mod.kernel, Tuple{}; dump_module=true, kernel=true)
return
end
end

Expand Down Expand Up @@ -423,101 +404,79 @@ end
end

@testset "fastmath" begin
# `fastmath=true` on the target should call `apply_fastmath!` from
# `fastmath=true` on the target calls `apply_fastmath!` from
# `finish_linked_module!`, stamping `unsafe-fp-math` + fast-math flags on
# every FP op, and additionally setting `denormal-fp-math-f32` so NVPTX
# picks the FTZ variants. Verify both pieces — IR-level attributes and
# PTX-level instruction selection — with and without the flag.
mod = @eval module $(gensym())
kernel(x, out) = (unsafe_store!(out, sqrt(unsafe_load(x))); return)
end

# without fastmath, no unsafe-fp-math / f32-FTZ, and sqrt stays precise
@test @filecheck begin
@check_label "define void @{{(julia|j)_kernel_[0-9]+}}"
@check_not "unsafe-fp-math"
@check_not "denormal-fp-math-f32"
@check "call float @llvm.sqrt.f32"
PTX.code_llvm(mod.kernel, Tuple{Ptr{Float32},Ptr{Float32}}; dump_module=true)
end
@test @filecheck begin
@check "sqrt.rn.f32"
@check_not "sqrt.approx"
PTX.code_native(mod.kernel, Tuple{Ptr{Float32},Ptr{Float32}})
end

# with fastmath, the entry function carries the attributes, and
# `PTXFSqrtFastPass` rewrites the `afn`-flagged `llvm.sqrt.f32` to the
# NVPTX approx-ftz intrinsic; PTX then selects the approx+ftz variant.
@test @filecheck begin
@check_label "define void @{{(julia|j)_kernel_[0-9]+}}"
@check "call float @llvm.nvvm.sqrt.approx.ftz.f"
@check "\"denormal-fp-math-f32\"=\"preserve-sign,preserve-sign\""
@check "\"unsafe-fp-math\"=\"true\""
PTX.code_llvm(mod.kernel, Tuple{Ptr{Float32},Ptr{Float32}};
dump_module=true, fastmath=true)
end
@test @filecheck begin
@check "sqrt.approx.ftz.f32"
PTX.code_native(mod.kernel, Tuple{Ptr{Float32},Ptr{Float32}}; fastmath=true)
# every FP op and setting `denormal-fp-math-f32` so NVPTX picks the FTZ
# variants. Sweep both axes (IR attributes + PTX selection) across the
# flag.
for fastmath in (false, true)
# IR attributes + sqrt rewrite by PTXFSqrtFastPass when fastmath.
@test @filecheck PTX.code_llvm(Tuple{Ptr{Float32},Ptr{Float32}};
dump_module=true, fastmath) do x, out
@check_not cond=!fastmath "unsafe-fp-math"
@check_not cond=!fastmath "denormal-fp-math-f32"
@check cond=!fastmath "call float @llvm.sqrt.f32"
@check cond=fastmath "call float @llvm.nvvm.sqrt.approx.ftz.f"
@check cond=fastmath "\"denormal-fp-math-f32\"=\"preserve-sign,preserve-sign\""
@check cond=fastmath "\"unsafe-fp-math\"=\"true\""
unsafe_store!(out, sqrt(unsafe_load(x)))
return
end
# PTX-level selection.
@test @filecheck PTX.code_native(Tuple{Ptr{Float32},Ptr{Float32}}; fastmath) do x, out
@check cond=fastmath "sqrt.approx.ftz.f32"
@check cond=!fastmath "sqrt.rn.f32"
@check_not cond=!fastmath "sqrt.approx"
unsafe_store!(out, sqrt(unsafe_load(x)))
return
end
end
end

@testset "fastmath fdiv/fsqrt" begin
# `PTXFDivFastPass` rewrites `afn`-flagged fdiv/`llvm.sqrt.*`. f32 ops →
# `*.approx{,.ftz}.f32` (filling in for LLVM 18, whose
# `getDivF32Level`/`usePrecSqrtF32` don't honor per-call `afn`). f64
# fdiv → `rcp.approx.ftz.d` + Newton refinement; f64 sqrt →
# `rcp.approx.ftz.d(rsqrt.approx.d(x))` (NVPTX has no native fast f64
# fdiv/sqrt). Job-wide `fastmath=true` reaches this through the per-
# instruction `afn` flags `apply_fastmath!` stamps in
# `finish_linked_module!`.
mod = @eval module $(gensym())
fdiv32_fast(x::Float32, y::Float32) = @fastmath x / y
fdiv32(x::Float32, y::Float32) = x / y
fdiv64_fast(x::Float64, y::Float64) = @fastmath x / y
fdiv64(x::Float64, y::Float64) = x / y
fsqrt32_fast(x::Float32) = @fastmath sqrt(x)
fsqrt32(x::Float32) = sqrt(x)
fsqrt64_fast(x::Float64) = @fastmath sqrt(x)
fsqrt64(x::Float64) = sqrt(x)
end

@test @filecheck begin
# `PTXFDivFastPass` and `PTXFSqrtFastPass` rewrite `afn`-flagged fdiv /
# `llvm.sqrt.*`. f32 ops → `*.approx{,.ftz}.f32` (filling in for LLVM 18,
# whose `getDivF32Level`/`usePrecSqrtF32` don't honor per-call `afn`).
# f64 fdiv → `rcp.approx.ftz.d` + Newton refinement; f64 sqrt →
# `rcp.approx.ftz.d(rsqrt.approx.d(x))` (NVPTX has no native fast f64).
# Per-call `@fastmath` is what we test here; the job-wide path is
# covered by the testset above.

@test @filecheck PTX.code_native(Tuple{Float32, Float32}) do x, y
@check "div.approx.f32"
PTX.code_native(mod.fdiv32_fast, Tuple{Float32, Float32})
@fastmath x / y
end
@test @filecheck begin
@test @filecheck PTX.code_native(Tuple{Float32, Float32}) do x, y
@check_not "div.approx"
PTX.code_native(mod.fdiv32, Tuple{Float32, Float32})
x / y
end
@test @filecheck begin
@test @filecheck PTX.code_native(Tuple{Float64, Float64}) do x, y
@check "rcp.approx.ftz.f64"
PTX.code_native(mod.fdiv64_fast, Tuple{Float64, Float64})
@fastmath x / y
end
@test @filecheck begin
@test @filecheck PTX.code_native(Tuple{Float64, Float64}) do x, y
@check_not "rcp.approx"
PTX.code_native(mod.fdiv64, Tuple{Float64, Float64})
x / y
end

@test @filecheck begin
@test @filecheck PTX.code_native(Tuple{Float32}) do x
@check "sqrt.approx.f32"
PTX.code_native(mod.fsqrt32_fast, Tuple{Float32})
@fastmath sqrt(x)
end
@test @filecheck begin
@test @filecheck PTX.code_native(Tuple{Float32}) do x
@check "sqrt.rn.f32"
@check_not "sqrt.approx"
PTX.code_native(mod.fsqrt32, Tuple{Float32})
sqrt(x)
end
@test @filecheck begin
@test @filecheck PTX.code_native(Tuple{Float64}) do x
@check "rsqrt.approx.f64"
@check "rcp.approx.ftz.f64"
PTX.code_native(mod.fsqrt64_fast, Tuple{Float64})
@fastmath sqrt(x)
end
@test @filecheck begin
@test @filecheck PTX.code_native(Tuple{Float64}) do x
@check "sqrt.rn.f64"
@check_not "rsqrt"
PTX.code_native(mod.fsqrt64, Tuple{Float64})
sqrt(x)
end
end

Expand Down
Loading