diff --git a/CUDACore/src/device/intrinsics/wmma.jl b/CUDACore/src/device/intrinsics/wmma.jl
index 1963d9ee4a..fc2ee40bd3 100644
--- a/CUDACore/src/device/intrinsics/wmma.jl
+++ b/CUDACore/src/device/intrinsics/wmma.jl
@@ -17,7 +17,8 @@ const map_ptx_to_jl_array = Dict(
     "s32" => Int32,
     "f16" => Float16,
     "bf16" => BFloat16,
-    "f32" => Float32
+    "f32" => Float32,
+    "f64" => Float64
 )
 
 # Maps PTX types to Julia fragment types
@@ -27,7 +28,8 @@ const map_ptx_to_jl_frag = Dict(
     "s32" => Int32,
     "f16" => NTuple{2, VecElement{Float16}},
     "bf16" => UInt32,
-    "f32" => Float32
+    "f32" => Float32,
+    "f64" => Float64
 )
 
 # Maps matrix & PTX types to fragment sizes
@@ -48,6 +50,8 @@ const map_frag_sizes = Dict(
     "a.bf16.m16n16k16" => 4,
     "a.bf16.m8n32k16" => 2,
     "a.bf16.m32n8k16" => 8,
+
+    "a.f64.m8n8k4" => 1,
     # B
     "b.u8.m16n16k16" => 2,
     "b.u8.m8n32k16" => 4,
@@ -64,6 +68,8 @@ const map_frag_sizes = Dict(
     "b.bf16.m16n16k16" => 4,
     "b.bf16.m8n32k16" => 8,
     "b.bf16.m32n8k16" => 2,
+
+    "b.f64.m8n8k4" => 1,
     # C
     "c.s32.m16n16k16" => 8,
     "c.s32.m8n32k16" => 8,
@@ -76,6 +82,12 @@ const map_frag_sizes = Dict(
     "c.f32.m16n16k16" => 8,
     "c.f32.m8n32k16" => 8,
     "c.f32.m32n8k16" => 8,
+
+    # NB: the PTX docs disagree with themselves on the m8n8k4 f64 accumulator
+    # fragment size. The mma-m8n8k4 page is correct ("two .f64 elements from
+    # the matrix C"); the wmma matrix-fragments table erroneously lists 1.
+    # See https://docs.nvidia.com/cuda/parallel-thread-execution/#matrix-fragments-for-mma-m8n8k4-with-f64-floating-point-type
+    "c.f64.m8n8k4" => 2,
     # D
     "d.s32.m16n16k16" => 8,
     "d.s32.m8n32k16" => 8,
@@ -88,6 +100,8 @@ const map_frag_sizes = Dict(
     "d.f32.m16n16k16" => 8,
     "d.f32.m8n32k16" => 8,
     "d.f32.m32n8k16" => 8,
+
+    "d.f64.m8n8k4" => 2,
 )
 
 # Maps PTX AS to CUDA.AS
@@ -110,13 +124,20 @@ const wmma_int_ops = [(16,16,16), (32,8,16), (8,32,16)], ["s8", "u8"], ["s32"
 # BFloat16 (requires Ampere+, only f32 accumulator supported)
 const ldst_bf16_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["bf16"]
 const wmma_bf16_ops = [(16,16,16), (32,8,16), (8,32,16)], ["bf16"], ["f32"], ["f32"]
+# Double-precision (requires sm_80+; only m8n8k4, only f64 accumulator)
+const ldst_double_ab_ops = [(8, 8, 4)], ["a", "b"], ["f64"]
+const ldst_double_cd_ops = [(8, 8, 4)], ["c", "d"], ["f64"]
+const wmma_double_ops = [(8, 8, 4)], ["f64"], ["f64"], ["f64"]
 
 const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops,
-                          ldst_int_ab_ops, ldst_int_cd_ops, ldst_bf16_ab_ops)
+                          ldst_int_ab_ops, ldst_int_cd_ops, ldst_bf16_ab_ops,
+                          ldst_double_ab_ops, ldst_double_cd_ops)
+# f64 MMA ops are generated by a separate loop because their wrappers bind
+# intrinsic names that carry an explicit rounding modifier.
 const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops, wmma_bf16_ops)
 
 # Valid WMMA operation shapes
-const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16)]
+const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16), (8, 8, 4)]
 
 ################################################################################
 # HELPER FUNCTIONS
@@ -370,6 +391,63 @@ for ops in all_wmma_ops,
     @eval @doc (@doc llvm_wmma_mma) $func_name
 end
 
+# Float64 MMA. Every intrinsic name bound here carries a rounding modifier
+# (rn/rz/rm/rp); no unsuffixed intrinsic is called. We generate one wrapper
+# per (a_layout, b_layout, rnd), then add a 3-arg alias and four RoundingMode-
+# dispatched 4-arg aliases that forward to the corresponding suffixed wrapper.
+for ops in [wmma_double_ops],
+    a_layout in ["col", "row"],
+    b_layout in ["col", "row"],
+    mnk in ops[1],
+    rnd in ("rn", "rz", "rm", "rp")
+
+    shape = get_hl_shape(mnk[1], mnk[2], mnk[3])
+
+    llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$rnd.f64"
+    func_name = Symbol(join(["llvm", "wmma", "mma", a_layout, b_layout, shape, "f64", rnd], "_"))
+
+    a_arr_ty, a_frag_ty, a_sz = get_frag_info("a", "f64", shape)
+    b_arr_ty, b_frag_ty, b_sz = get_frag_info("b", "f64", shape)
+    c_arr_ty, c_frag_ty, c_sz = get_frag_info("c", "f64", shape)
+    d_arr_ty, d_frag_ty, d_sz = get_frag_info("d", "f64", shape)
+
+    a_types = ntuple(i -> a_frag_ty, a_sz)
+    b_types = ntuple(i -> b_frag_ty, b_sz)
+    c_types = ntuple(i -> c_frag_ty, c_sz)
+
+    a_vars = ntuple(i -> :(a[$i]), a_sz)
+    b_vars = ntuple(i -> :(b[$i]), b_sz)
+    c_vars = ntuple(i -> :(c[$i]), c_sz)
+
+    if d_sz == 1
+        @eval @device_function $func_name(a, b, c) = tuple(ccall($llvm_intr, llvmcall, $d_frag_ty, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)))
+    else
+        struct_ty = Symbol("LLVMStruct$d_sz")
+        @eval @device_function $func_name(a, b, c) = convert(NTuple{$d_sz, $d_frag_ty}, ccall($llvm_intr, llvmcall, $struct_ty{$d_frag_ty}, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)))
+    end
+    @eval export $func_name
+    @eval @doc (@doc llvm_wmma_mma) $func_name
+end
+
+# RoundingMode-dispatched aliases for f64 MMA. The bare-3-arg form forwards
+# to round-to-nearest, matching PTX's default-rnd convention.
+for a_layout in ("col", "row"), b_layout in ("col", "row"), mnk in ((8, 8, 4),)
+    shape = get_hl_shape(mnk...)
+    base = Symbol(join(["llvm", "wmma", "mma", a_layout, b_layout, shape, "f64"], "_"))
+    rn_name = Symbol(base, :_rn)
+    rz_name = Symbol(base, :_rz)
+    rm_name = Symbol(base, :_rm)
+    rp_name = Symbol(base, :_rp)
+
+    @eval @device_function $base(a, b, c) = $rn_name(a, b, c)
+    @eval @device_function $base(a, b, c, ::RoundingMode{:Nearest}) = $rn_name(a, b, c)
+    @eval @device_function $base(a, b, c, ::RoundingMode{:ToZero}) = $rz_name(a, b, c)
+    @eval @device_function $base(a, b, c, ::RoundingMode{:Down}) = $rm_name(a, b, c)
+    @eval @device_function $base(a, b, c, ::RoundingMode{:Up}) = $rp_name(a, b, c)
+    @eval export $base
+    @eval @doc (@doc llvm_wmma_mma) $base
+end
+
 ################################################################################
 # FLATTENING/UNFLATTENING LOGIC
 ################################################################################
@@ -517,9 +595,13 @@ Type that contains all information for WMMA operations that cannot be inferred f
 WMMA instructions calculate the matrix multiply-accumulate operation ``D = A \\cdot B + C``, where ``A`` is a ``M \\times K`` matrix,
 ``B`` a ``K \\times N`` matrix, and ``C`` and ``D`` are ``M \\times N`` matrices.
 
-`d_type` refers to the type of the elements of matrix ``D``, and can be either `Float16` or `Float32`.
+`d_type` refers to the type of the elements of matrix ``D``, and can be `Float16`,
+`Float32`, or `Float64` (the latter only with shape `(8, 8, 4)`).
 
-All WMMA operations take a `Config` as their final argument.
+All WMMA operations take a `Config` as their final argument. For `Float64`
+configurations, [`WMMA.mma`](@ref) additionally accepts a
+[`Base.RoundingMode`](https://docs.julialang.org/en/v1/base/math/#Base.Rounding.RoundingMode)
+to select the hardware rounding mode.
 
 # Examples
 ```jldoctest
@@ -554,7 +636,12 @@ const map_num_elems = Dict(
     ("c", Float16) => 8,
     ("c", Float32) => 8,
     ("d", Float16) => 8,
-    ("d", Float32) => 8
+    ("d", Float32) => 8,
+    # f64 m8n8k4: per-thread frag is already flat (no VecElement)
+    ("a", Float64) => 1,
+    ("b", Float64) => 1,
+    ("c", Float64) => 2,
+    ("d", Float64) => 2,
 )
 
 # Maps matrix to its use
@@ -669,6 +756,7 @@ export mma
 
 """
     WMMA.mma(a, b, c, conf)
+    WMMA.mma(a, b, c, conf, rounding)
 
 Perform the matrix multiply-accumulate operation ``D = A \\cdot B + C``.
 
@@ -678,6 +766,10 @@ Perform the matrix multiply-accumulate operation ``D = A \\cdot B + C``.
 - `b`: The [`WMMA.Fragment`](@ref) corresponding to the matrix ``B``.
 - `c`: The [`WMMA.Fragment`](@ref) corresponding to the matrix ``C``.
 - `conf`: The [`WMMA.Config`](@ref) that should be used in this WMMA operation.
+- `rounding`: A `Base.RoundingMode` selecting the hardware rounding mode. Only
+  `Float64` configurations support modes other than `RoundNearest`; for other
+  element types, passing a non-`RoundNearest` mode raises an error. Defaults to
+  `RoundNearest`.
 
 !!! warning
 
@@ -686,10 +778,14 @@ Perform the matrix multiply-accumulate operation ``D = A \\cdot B + C``.
 """
 mma
 
+mma(a::Fragment, b::Fragment, c::Fragment,
+    config::Type{<:Config}) = mma(a, b, c, config, RoundNearest)
+
 @generated function mma(a::Fragment{M, N, K, A_SZ, A_T, A_L, MatrixA},
                         b::Fragment{M, N, K, B_SZ, B_T, B_L, MatrixB},
                         c::Fragment{M, N, K, C_SZ, C_T, Unspecified, Accumulator},
-                        config::Type{Config{M, N, K, D_T}}) where {M, N, K, A_SZ, A_T, A_L, B_SZ, B_T, B_L, C_SZ, C_T, D_T}
+                        config::Type{Config{M, N, K, D_T}},
+                        ::RoundingMode{R}) where {M, N, K, A_SZ, A_T, A_L, B_SZ, B_T, B_L, C_SZ, C_T, D_T, R}
 
     a_layout = get_hl_layout(A_L)
     b_layout = get_hl_layout(B_L)
@@ -700,9 +796,29 @@ mma
     _, c_frag_sz, c_frag_ty, c_arr_str = get_hl_frag_info("c", C_T, shape)
     d_num_els, _, _, d_arr_str = get_hl_frag_info("d", D_T, shape)
 
+    # Only Float64 MMA has hardware-selectable rounding modes; for other types,
+    # the only valid mode is the hardware's implicit round-to-nearest.
+    rnd_suffix = ""
+    if D_T === Float64
+        rnd_suffix = R === :Nearest ? "rn" :
+                     R === :ToZero ? "rz" :
+                     R === :Up ? "rp" :
+                     R === :Down ? "rm" :
+                     error("WMMA.mma: unsupported RoundingMode :$R for Float64")
+    elseif R !== :Nearest
+        error("WMMA.mma: RoundingMode :$R is only supported for Float64 configurations; got D_T = $D_T")
+    end
+
     names = ["llvm", "wmma", "mma", a_layout, b_layout, shape]
-    # bf16 uses input type in intrinsic name, f16 uses d/c types
-    A_T === BFloat16 ? push!(names, a_arr_str) : push!(names, d_arr_str, c_arr_str)
+    # bf16 uses A/B element type in the intrinsic name; f16/f32 use D/C types;
+    # f64 uses its single element type plus the rounding suffix.
+    if A_T === BFloat16
+        push!(names, a_arr_str)
+    elseif D_T === Float64
+        push!(names, d_arr_str, rnd_suffix)
+    else
+        push!(names, d_arr_str, c_arr_str)
+    end
     wrapper = Symbol(join(filter(!isempty, names), "_"))
 
     a_unfl_expr = A_T === BFloat16 ? :(unflatten_bf16(a.x)) : :(unflatten(NTuple{$a_frag_sz, $a_frag_ty}, a.x))
diff --git a/test/core/device/intrinsics/wmma.jl b/test/core/device/intrinsics/wmma.jl
index f4ef5fdf4c..deb997b21b 100644
--- a/test/core/device/intrinsics/wmma.jl
+++ b/test/core/device/intrinsics/wmma.jl
@@ -249,6 +249,108 @@ end
             end
         end
     end
+
+    # m8n8k4 f64 WMMA requires sm_80+ (Ampere) and is the only WMMA shape/type
+    # that exposes hardware rounding-mode selection.
+    if capability(device()) >= v"8.0"
+        @testset "llvm_wmma_mma (Float64 rounding)" begin
+            m, n, k = 8, 8, 4
+            round_modes = ((RoundNearest, "rn"), (RoundToZero, "rz"),
+                           (RoundDown, "rm"), (RoundUp, "rp"))
+
+            @testset "$(a_layout)_$(b_layout)" for a_layout in ("col", "row"),
+                                                   b_layout in ("col", "row")
+                lda = getfield(@__MODULE__, Symbol("llvm_wmma_load_a_$(a_layout)_m8n8k4_global_stride_f64"))
+                ldb = getfield(@__MODULE__, Symbol("llvm_wmma_load_b_$(b_layout)_m8n8k4_global_stride_f64"))
+                ldc = getfield(@__MODULE__, Symbol("llvm_wmma_load_c_col_m8n8k4_global_stride_f64"))
+                std = getfield(@__MODULE__, Symbol("llvm_wmma_store_d_col_m8n8k4_global_stride_f64"))
+                mma = getfield(@__MODULE__, Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_m8n8k4_f64"))
+
+                a_shape = a_layout == "col" ? (m, k) : (k, m)
+                b_shape = b_layout == "col" ? (k, n) : (n, k)
+                cd_shape = (m, n)
+
+                # Inputs chosen so that A*B + C has bits below Float64 precision,
+                # making the four rounding modes produce different results.
+                a = rand(Float64, a_shape) .+ Float64(π)
+                b = rand(Float64, b_shape) .+ Float64(ℯ)
+                c = rand(Float64, cd_shape) .* Float64(π)
+
+                a_dev = CuArray(a); b_dev = CuArray(b); c_dev = CuArray(c)
+
+                # One kernel that takes the rounding mode as a (compile-time) singleton.
+                @eval function kernel_rmode(a_dev, b_dev, c_dev, d_dev, rmode)
+                    a_frag = $lda(pointer(a_dev), $(a_shape[1]))
+                    b_frag = $ldb(pointer(b_dev), $(b_shape[1]))
+                    c_frag = $ldc(pointer(c_dev), $(cd_shape[1]))
+                    d_frag = $mma(a_frag, b_frag, c_frag, rmode)
+                    $std(pointer(d_dev), d_frag, $(cd_shape[1]))
+                    return
+                end
+
+                run_mode(rmode) = begin
+                    d_dev = CuArray{Float64}(undef, cd_shape)
+                    @cuda threads=32 kernel_rmode(a_dev, b_dev, c_dev, d_dev, rmode)
+                    Array(d_dev)
+                end
+
+                d_rn = run_mode(RoundNearest)
+                d_rz = run_mode(RoundToZero)
+                d_rp = run_mode(RoundUp)
+                d_rm = run_mode(RoundDown)
+
+                new_a = a_layout == "col" ? a : transpose(a)
+                new_b = b_layout == "col" ? b : transpose(b)
+                expected = new_a * new_b + c
+
+                # All modes are close to the BLAS reference within Float64 precision.
+                @test d_rn ≈ expected rtol=Base.rtoldefault(Float64)
+                @test d_rz ≈ expected rtol=Base.rtoldefault(Float64)
+                @test d_rp ≈ expected rtol=Base.rtoldefault(Float64)
+                @test d_rm ≈ expected rtol=Base.rtoldefault(Float64)
+
+                # Rounding ordering: RoundDown ≤ RoundNearest ≤ RoundUp, elementwise.
+                @test all(d_rm .<= d_rn)
+                @test all(d_rn .<= d_rp)
+
+                # Inputs are positive, so RoundToZero matches RoundDown.
+                @test d_rz == d_rm
+
+                # Sanity: at least one element actually differs between Down and Up,
+                # i.e. we are genuinely exercising directed rounding.
+                @test any(d_rm .!= d_rp)
+
+                # The bare 3-arg form must equal the explicit RoundNearest form.
+                @eval function kernel_default(a_dev, b_dev, c_dev, d_dev)
+                    a_frag = $lda(pointer(a_dev), $(a_shape[1]))
+                    b_frag = $ldb(pointer(b_dev), $(b_shape[1]))
+                    c_frag = $ldc(pointer(c_dev), $(cd_shape[1]))
+                    d_frag = $mma(a_frag, b_frag, c_frag)
+                    $std(pointer(d_dev), d_frag, $(cd_shape[1]))
+                    return
+                end
+                d_default_dev = CuArray{Float64}(undef, cd_shape)
+                @cuda threads=32 kernel_default(a_dev, b_dev, c_dev, d_default_dev)
+                @test Array(d_default_dev) == d_rn
+
+                # Suffixed wrappers (..._rn/_rz/_rm/_rp) must equal the dispatched form.
+                @testset "suffixed wrapper $(suffix)" for (rmode, suffix) in round_modes
+                    mma_sfx = getfield(@__MODULE__, Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_m8n8k4_f64_$(suffix)"))
+                    @eval function kernel_sfx(a_dev, b_dev, c_dev, d_dev)
+                        a_frag = $lda(pointer(a_dev), $(a_shape[1]))
+                        b_frag = $ldb(pointer(b_dev), $(b_shape[1]))
+                        c_frag = $ldc(pointer(c_dev), $(cd_shape[1]))
+                        d_frag = $mma_sfx(a_frag, b_frag, c_frag)
+                        $std(pointer(d_dev), d_frag, $(cd_shape[1]))
+                        return
+                    end
+                    d_sfx_dev = CuArray{Float64}(undef, cd_shape)
+                    @cuda threads=32 kernel_sfx(a_dev, b_dev, c_dev, d_sfx_dev)
+                    @test Array(d_sfx_dev) == run_mode(rmode)
+                end
+            end
+        end
+    end
 end
 
 ################################################################################
@@ -476,6 +578,74 @@ end
 
 ################################################################################
 
+# m8n8k4 f64 WMMA requires sm_80+ (Ampere)
+if capability(device()) >= v"8.0"
+@testset "CUDA C-style API (Float64 rounding)" begin
+    @testset "$(do_mac ? "MAC" : "MUL"), A: $a_layout, B: $b_layout, C: $c_layout, D: $d_layout, rmode: $rmode" for
+        a_layout in [ColMajor, RowMajor],
+        b_layout in [ColMajor, RowMajor],
+        c_layout in [ColMajor, RowMajor],
+        d_layout in [ColMajor, RowMajor],
+        rmode in [RoundNearest, RoundToZero, RoundDown, RoundUp],
+        do_mac in [true, false]
+
+        a = rand(Float64, (8, 4)) .+ Float64(π)
+        b = rand(Float64, (4, 8)) .+ Float64(ℯ)
+        c = rand(Float64, (8, 8)) .* Float64(π)
+        d = Array{Float64}(undef, (8, 8))
+
+        a_dev = CuArray(a_layout == ColMajor ? a : collect(transpose(a)))
+        b_dev = CuArray(b_layout == ColMajor ? b : collect(transpose(b)))
+        c_dev = CuArray(c_layout == ColMajor ? c : collect(transpose(c)))
+        d_dev = CuArray(d)
+
+        @eval function kernel_f64(a_dev, b_dev, c_dev, d_dev)
+            conf = Config{8, 8, 4, Float64}
+
+            # Leading dimension follows the storage layout: A is 8x4 (M x K) and
+            # B is 4x8 (K x N), so a row-major operand is strided by its column count.
+            a_frag = load_a(pointer(a_dev), $(a_layout == ColMajor ? 8 : 4), $a_layout, conf)
+            b_frag = load_b(pointer(b_dev), $(b_layout == ColMajor ? 4 : 8), $b_layout, conf)
+
+            if $do_mac
+                c_frag = load_c(pointer(c_dev), 8, $c_layout, conf)
+            else
+                c_frag = fill_c(Float64(0), conf)
+            end
+
+            d_frag = mma(a_frag, b_frag, c_frag, conf, $rmode)
+
+            store_d(pointer(d_dev), d_frag, 8, $d_layout, conf)
+
+            return
+        end
+
+        @cuda threads=32 kernel_f64(a_dev, b_dev, c_dev, d_dev)
+        d = Array(d_dev)
+
+        new_d = (d_layout == ColMajor) ? d : transpose(d)
+
+        if do_mac
+            @test a * b + c ≈ new_d rtol=Base.rtoldefault(Float64)
+        else
+            @test a * b ≈ new_d rtol=Base.rtoldefault(Float64)
+        end
+    end
+
+    # Non-Float64 configurations must reject rounding modes other than RoundNearest.
+    # The check lives in the @generated body, so the error surfaces at method
+    # expansion time — we can exercise it directly from the host without @cuda.
+    @testset "rounding mode rejected for non-Float64" begin
+        a = Fragment{16, 16, 16, 16, Float16, ColMajor, MatrixA}(ntuple(_ -> Float16(0), 16))
+        b = Fragment{16, 16, 16, 16, Float16, ColMajor, MatrixB}(ntuple(_ -> Float16(0), 16))
+        c = Fragment{16, 16, 16, 8, Float32, Unspecified, Accumulator}(ntuple(_ -> Float32(0), 8))
+        @test_throws ErrorException mma(a, b, c, Config{16, 16, 16, Float32}, RoundDown)
+    end
+end
+end
+
+################################################################################
+
 @testset "Codegen addressing" begin
     @testset "Global" begin
         function kernel(d)