Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions src/ptx.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# implementation of the GPUCompiler interfaces for generating PTX code

const NVPTX_LLVM_Backend_jll =
LazyModule("NVPTX_LLVM_Backend_jll",
UUID("ef6e0fe3-e6ef-59c0-bde6-4989574699e0"))


## target

export PTXCompilerTarget
Expand Down Expand Up @@ -289,6 +294,44 @@ function finish_ir!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
return entry
end

@unlocked function mcgen(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
mod::LLVM.Module, format=LLVM.API.LLVMAssemblyFile)
if !isavailable(NVPTX_LLVM_Backend_jll) || !NVPTX_LLVM_Backend_jll.is_available()
error("NVPTX LLVM back-end not loaded; cannot compile to PTX.")
end

target = job.config.target
filetype = if format == LLVM.API.LLVMAssemblyFile
"asm"
elseif format == LLVM.API.LLVMObjectFile
"obj"
else
error("Unsupported PTX output format $format")
end

input = tempname(cleanup=false) * ".bc"
output = tempname(cleanup=false) * (filetype == "asm" ? ".ptx" : ".cubin")
write(input, mod)

cmd = `$(NVPTX_LLVM_Backend_jll.llc()) $input
-mtriple=$(llvm_triple(target))
-mcpu=$(cpu_name(target))
-mattr=+ptx$(target.ptx.major)$(target.ptx.minor)
-filetype=$filetype
-o $output`
try
run(cmd)
catch
error("""Failed to compile to PTX with external llc.
If you think this is a bug, please file an issue and attach $(input).""")
end

code = filetype == "asm" ? read(output, String) : String(read(output))
rm(input)
rm(output)
return code
end

function llvm_debug_info(@nospecialize(job::CompilerJob{PTXCompilerTarget}))
# allow overriding the debug info from CUDA.jl
if job.config.target.debuginfo
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
LLVM_jll = "86de99a1-58d6-5da7-8064-bd56ce2e322c"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
NVPTX_LLVM_Backend_jll = "ef6e0fe3-e6ef-59c0-bde6-4989574699e0"
ParallelTestRunner = "d3525ed8-44d0-4b2c-a655-542cee43accc"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
3 changes: 1 addition & 2 deletions test/ptx.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,7 @@ if :NVPTX in LLVM.backends()
@test @filecheck begin
@check_label ".visible .func {{(julia|j)_parent[0-9_]*}}"
@check "call.uni"
@check_same cond=(LLVM.version() >= v"21") "{{(julia|j)_child_}}"
@check_next cond=(LLVM.version() < v"21") "{{(julia|j)_child_}}"
@check_same "{{(julia|j)_child_}}"
PTX.code_native(mod.parent, Tuple{Int64})
end
end
Expand Down
18 changes: 18 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
using ParallelTestRunner
import GPUCompiler, LLVM
using GPUCompiler, LLVM
using SPIRV_LLVM_Backend_jll, SPIRV_LLVM_Translator_jll, SPIRV_Tools_jll
using NVPTX_LLVM_Backend_jll

const init_code = quote
using GPUCompiler, LLVM
using SPIRV_LLVM_Backend_jll, SPIRV_LLVM_Translator_jll, SPIRV_Tools_jll
using NVPTX_LLVM_Backend_jll

# include all helpers
include(joinpath(@__DIR__, "helpers", "runtime.jl"))
Expand All @@ -28,12 +32,26 @@ if filter_tests!(testsuite, args)
end

if LLVM.is_asserts()
@warn "LLVM with assertions; skipping GCN tests"
delete!(testsuite, "gcn")
end
if VERSION < v"1.11"
@warn "Julia 1.11+ required for precompile tests; skipping"
delete!(testsuite, "ptx/precompile")
delete!(testsuite, "native/precompile")
end
if !SPIRV_LLVM_Backend_jll.is_available() || !SPIRV_LLVM_Translator_jll.is_available() || !SPIRV_Tools_jll.is_available()
@warn "SPIRV back-end not available; skipping SPIRV tests"
for key in collect(keys(testsuite))
startswith(key, "spirv") && delete!(testsuite, key)
end
end
if !NVPTX_LLVM_Backend_jll.is_available()
@warn "NVPTX back-end not available; skipping PTX tests"
for key in collect(keys(testsuite))
startswith(key, "ptx") && delete!(testsuite, key)
end
end
end

runtests(GPUCompiler, args; testsuite, init_code)
Loading