diff --git a/src/spirv.jl b/src/spirv.jl index ea552e70..ef96f4c4 100644 --- a/src/spirv.jl +++ b/src/spirv.jl @@ -27,6 +27,7 @@ Base.@kwdef struct SPIRVCompilerTarget <: AbstractCompilerTarget extensions::Vector{String} = [] supports_fp16::Bool = true supports_fp64::Bool = true + supports_bfloat16::Bool = false backend::Symbol = isavailable(SPIRV_LLVM_Backend_jll) ? :llvm : :khronos # XXX: these don't really belong in the _target_ struct @@ -86,6 +87,9 @@ function validate_ir(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module) if !job.config.target.supports_fp64 append!(errors, check_ir_values(mod, LLVM.DoubleType())) end + if !job.config.target.supports_bfloat16 && isdefined(LLVM, :BFloatType) + append!(errors, check_ir_values(mod, LLVM.BFloatType())) + end return errors end diff --git a/test/helpers/spirv.jl b/test/helpers/spirv.jl index 0144cd6a..5e765bb7 100644 --- a/test/helpers/spirv.jl +++ b/test/helpers/spirv.jl @@ -7,12 +7,12 @@ struct CompilerParams <: AbstractCompilerParams end GPUCompiler.runtime_module(::CompilerJob{<:Any,CompilerParams}) = TestRuntime function create_job(@nospecialize(func), @nospecialize(types); - supports_fp16=true, supports_fp64=true, backend::Symbol, - kwargs...) + supports_fp16=true, supports_fp64=true, supports_bfloat16=false, + backend::Symbol, kwargs...) config_kwargs, kwargs = split_kwargs(kwargs, GPUCompiler.CONFIG_KWARGS) source = methodinstance(typeof(func), Base.to_tuple_type(types), Base.get_world_counter()) target = SPIRVCompilerTarget(; backend, validate=true, optimize=true, - supports_fp16, supports_fp64) + supports_fp16, supports_fp64, supports_bfloat16) params = CompilerParams() config = CompilerConfig(target, params; kernel=false, config_kwargs...) CompilerJob(source, config), kwargs diff --git a/test/spirv.jl b/test/spirv.jl index c737b6f4..f5ef53fa 100644 --- a/test/spirv.jl +++ b/test/spirv.jl @@ -88,6 +88,23 @@ end occursin("[1] unsafe_store!", msg) && occursin(r"\[\d+\] kernel", msg) end + + @static if isdefined(Core, :BFloat16) + @test @filecheck begin + @check_label "define void @{{(julia|j)_kernel_[0-9]+}}" + @check "store bfloat" + SPIRV.code_llvm(mod.kernel, Tuple{Ptr{Core.BFloat16}, Core.BFloat16}; + backend, supports_bfloat16=true) + end + + @test_throws_message(InvalidIRError, + SPIRV.code_execution(mod.kernel, Tuple{Ptr{Core.BFloat16}, Core.BFloat16}; + backend, supports_bfloat16=false)) do msg + occursin("unsupported use of bfloat value", msg) && + occursin("[1] unsafe_store!", msg) && + occursin(r"\[\d+\] kernel", msg) + end + end end end