diff --git a/Project.toml b/Project.toml index b982cf5d..c42cf446 100644 --- a/Project.toml +++ b/Project.toml @@ -23,7 +23,7 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [compat] ExprTools = "0.1" InteractiveUtils = "1" -LLVM = "9.3" +LLVM = "9.5" Libdl = "1" Logging = "1" PrecompileTools = "1" diff --git a/src/interface.jl b/src/interface.jl index fc65c888..d7eb3406 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -37,6 +37,10 @@ end llvm_datalayout(target::AbstractCompilerTarget) = DataLayout(llvm_machine(target)) +# a custom `TargetTransformInfo` for targets that don't have (or can't rely on) a +# `TargetMachine`-supplied TTI. Return `nothing` to fall back to LLVM's native TTI. +llvm_targetinfo(@nospecialize(target::AbstractCompilerTarget)) = nothing + # the target's datalayout, with Julia's non-integral address spaces added to it function julia_datalayout(@nospecialize(target::AbstractCompilerTarget)) dl = llvm_datalayout(target) diff --git a/src/metal.jl b/src/metal.jl index 569a6cd1..77c9ff89 100644 --- a/src/metal.jl +++ b/src/metal.jl @@ -1,5 +1,41 @@ # implementation of the GPUCompiler interfaces for generating Metal code + +## target info + +# Metal has no target machine, so provide our own TTI +struct MetalTTI <: LLVM.AbstractTargetTransformInfo end + +# teach LLVM about Metal's address-space hierarchy: +# 0: Generic 1: Device 2: Constant +# 3: ThreadGroup 4: Thread 5: ThreadGroup_ImgBlock 6: Ray +# AS 0 is the flat/generic space; only casts involving it are legal, and the +# specific spaces are mutually disjoint. +LLVM.flat_address_space(::MetalTTI) = UInt(0) +LLVM.is_noop_addr_space_cast(::MetalTTI, from::Unsigned, to::Unsigned) = + from == 0 || to == 0 +LLVM.is_valid_addr_space_cast(::MetalTTI, from::Unsigned, to::Unsigned) = + from == to || from == 0 || to == 0 + +# distinct specific address spaces are disjoint; only the generic AS overlaps. 
+LLVM.addrspaces_may_alias(::MetalTTI, a::Unsigned, b::Unsigned) = + a == b || a == 0 || b == 0 + +# used as a coarse "this is a GPU target" switch by several IR passes (e.g. +# JumpThreading and non-trivial SimpleLoopUnswitch become no-ops), not just +# UniformityAnalysis — which we don't have consumers for anyway. +LLVM.has_branch_divergence(::MetalTTI) = true + +# deliberately not overriding `is_single_threaded`: a kernel is multi-lane, and +# returning `true` would let LICM sink stores onto paths that didn't store, +# producing races across lanes. + +# only the spaces backed by static storage admit non-undef initializers; thread, +# threadgroup and ray-payload spaces are populated at dispatch/invocation time. +LLVM.can_have_non_undef_global_initializer_in_address_space(::MetalTTI, as::Unsigned) = + as == 0 || as == 1 || as == 2 + + ## target export MetalCompilerTarget @@ -35,6 +71,8 @@ llvm_datalayout(target::MetalCompilerTarget) = "-v16:16:16-v24:32:32-v32:32:32-v48:64:64-v64:64:64-v96:128:128-v128:128:128-v192:256:256-v256:256:256-v512:512:512-v1024:1024:1024"* "-n8:16:32" +llvm_targetinfo(::MetalCompilerTarget) = MetalTTI() + pass_by_value(job::CompilerJob{MetalCompilerTarget}) = false @@ -76,6 +114,7 @@ function finish_linked_module!(@nospecialize(job::CompilerJob{MetalCompilerTarge # we emit properties (of the air and metal version) as private global constants, # so run the optimizer so that they are inlined before the rest of the optimizer runs. 
@dispose pb=NewPMPassBuilder() begin + LLVM.target_transform_info!(pb, MetalTTI()) add!(pb, RecomputeGlobalsAAPass()) add!(pb, GlobalOptPass()) run!(pb, mod) @@ -141,6 +180,7 @@ function hide_noreturn!(mod::LLVM.Module) any_noreturn || return false @dispose pb=NewPMPassBuilder() begin + LLVM.target_transform_info!(pb, MetalTTI()) add!(pb, AlwaysInlinerPass()) add!(pb, NewPMFunctionPassManager()) do fpm add!(fpm, SimplifyCFGPass()) @@ -166,6 +206,22 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L entry = add_parameter_address_spaces!(job, mod, entry) entry = add_global_address_spaces!(job, mod, entry) + # propagate specific address spaces through addrspacecast chains introduced + # by the rewrites above, so that loads/stores happen in the right address + # space (e.g. constant globals in addrspace 2 rather than via a cast to 0, + # which Metal's backend cannot handle correctly for dynamic indices). + @dispose pb=NewPMPassBuilder() begin + LLVM.target_transform_info!(pb, MetalTTI()) + add!(pb, NewPMFunctionPassManager()) do fpm + add!(fpm, InferAddressSpacesPass()) + add!(fpm, SROAPass()) + add!(fpm, InstCombinePass()) + add!(fpm, EarlyCSEPass()) + add!(fpm, SimplifyCFGPass()) + end + run!(pb, mod) + end + add_argument_metadata!(job, mod, entry) add_module_metadata!(job, mod) diff --git a/src/optim.jl b/src/optim.jl index 282127f2..95834f0b 100644 --- a/src/optim.jl +++ b/src/optim.jl @@ -2,11 +2,14 @@ function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module; opt_level=2) tm = llvm_machine(job.config.target) + tti = llvm_targetinfo(job.config.target) global current_job current_job = job @dispose pb=NewPMPassBuilder() begin + tti === nothing || LLVM.target_transform_info!(pb, tti) + register!(pb, GPULowerCPUFeaturesPass()) register!(pb, GPULowerPTLSPass()) register!(pb, GPULowerGCFramePass()) diff --git a/test/metal.jl b/test/metal.jl index d35ada93..87539b26 100644 --- a/test/metal.jl +++ b/test/metal.jl @@ -181,6 
+181,31 @@ end end end +# Tuples with a dynamic index are lowered to an addrspace(2) constant plus a +# GEP+load. Without InferAddressSpaces propagating AS 2 through the cast to +# the generic AS introduced during `add_global_address_spaces!`, the load +# would end up in AS 0 and Metal's back-end miscompiles it into zeroes. +@testset "dynamic constant global access" begin + mod = @eval module $(gensym()) + function kernel(ptr, i) + t = (1.0f0, 2.0f0, 3.0f0, 4.0f0) + @inbounds unsafe_store!(ptr, t[i]) + return + end + end + + @test @filecheck begin + @check "@{{.+}} ={{.*}} addrspace(2) constant [4 x float]" + @check_label "define void @_Z6kernel7LLVMPtrI7Float32Li1EE5Int64" + # the load must happen in addrspace(2); Metal miscompiles loads of + # the constant if they occur via an addrspacecast to the generic AS + @check cond=opaque_ptrs "load float, ptr addrspace(2)" + @check cond=typed_ptrs "load float, float addrspace(2)*" + Metal.code_llvm(mod.kernel, Tuple{Core.LLVMPtr{Float32,1}, Int}; + dump_module=true, kernel=true) + end +end + end end