Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 126 additions & 10 deletions CUDACore/src/device/intrinsics/wmma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ const map_ptx_to_jl_array = Dict(
"s32" => Int32,
"f16" => Float16,
"bf16" => BFloat16,
"f32" => Float32
"f32" => Float32,
"f64" => Float64
)

# Maps PTX types to Julia fragment types
Expand All @@ -27,7 +28,8 @@ const map_ptx_to_jl_frag = Dict(
"s32" => Int32,
"f16" => NTuple{2, VecElement{Float16}},
"bf16" => UInt32,
"f32" => Float32
"f32" => Float32,
"f64" => Float64
)

# Maps matrix & PTX types to fragment sizes
Expand All @@ -48,6 +50,8 @@ const map_frag_sizes = Dict(
"a.bf16.m16n16k16" => 4,
"a.bf16.m8n32k16" => 2,
"a.bf16.m32n8k16" => 8,

"a.f64.m8n8k4" => 1,
# B
"b.u8.m16n16k16" => 2,
"b.u8.m8n32k16" => 4,
Expand All @@ -64,6 +68,8 @@ const map_frag_sizes = Dict(
"b.bf16.m16n16k16" => 4,
"b.bf16.m8n32k16" => 8,
"b.bf16.m32n8k16" => 2,

"b.f64.m8n8k4" => 1,
# C
"c.s32.m16n16k16" => 8,
"c.s32.m8n32k16" => 8,
Expand All @@ -76,6 +82,12 @@ const map_frag_sizes = Dict(
"c.f32.m16n16k16" => 8,
"c.f32.m8n32k16" => 8,
"c.f32.m32n8k16" => 8,

# NB: the PTX docs disagree with themselves on the m8n8k4 f64 accumulator
# fragment size. The mma-m8n8k4 page is correct ("two .f64 elements from
# the matrix C"); the wmma matrix-fragments table erroneously lists 1.
# See https://docs.nvidia.com/cuda/parallel-thread-execution/#matrix-fragments-for-mma-m8n8k4-with-f64-floating-point-type
"c.f64.m8n8k4" => 2,
# D
"d.s32.m16n16k16" => 8,
"d.s32.m8n32k16" => 8,
Expand All @@ -88,6 +100,8 @@ const map_frag_sizes = Dict(
"d.f32.m16n16k16" => 8,
"d.f32.m8n32k16" => 8,
"d.f32.m32n8k16" => 8,

"d.f64.m8n8k4" => 2,
)

# Maps PTX AS to CUDA.AS
Expand All @@ -110,13 +124,20 @@ const wmma_int_ops = [(16,16,16), (32,8,16), (8,32,16)], ["s8", "u8"], ["s32"
# BFloat16 (requires Ampere+, only f32 accumulator supported)
const ldst_bf16_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["bf16"]
const wmma_bf16_ops = [(16,16,16), (32,8,16), (8,32,16)], ["bf16"], ["f32"], ["f32"]
# Double-precision (requires sm_80+; only m8n8k4, only f64 accumulator)
const ldst_double_ab_ops = [(8, 8, 4)], ["a", "b"], ["f64"]
const ldst_double_cd_ops = [(8, 8, 4)], ["c", "d"], ["f64"]
const wmma_double_ops = [(8, 8, 4)], ["f64"], ["f64"], ["f64"]

const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops,
ldst_int_ab_ops, ldst_int_cd_ops, ldst_bf16_ab_ops)
ldst_int_ab_ops, ldst_int_cd_ops, ldst_bf16_ab_ops,
ldst_double_ab_ops, ldst_double_cd_ops)
# f64 MMA ops are generated by a separate loop because PTX requires an explicit
# rounding modifier in their intrinsic name.
const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops, wmma_bf16_ops)

# Valid WMMA operation shapes
const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16)]
const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16), (8, 8, 4)]

################################################################################
# HELPER FUNCTIONS
Expand Down Expand Up @@ -370,6 +391,63 @@ for ops in all_wmma_ops,
@eval @doc (@doc llvm_wmma_mma) $func_name
end

# Float64 MMA. The PTX/LLVM intrinsic name always carries a rounding modifier
# (rn/rz/rm/rp); there is no implicit-default form. We generate one wrapper
# per (a_layout, b_layout, rnd), then add a 3-arg alias and four RoundingMode-
# dispatched 4-arg aliases that forward to the corresponding suffixed wrapper.
for ops in [wmma_double_ops],
a_layout in ["col", "row"],
b_layout in ["col", "row"],
mnk in ops[1],
rnd in ("rn", "rz", "rm", "rp")

shape = get_hl_shape(mnk[1], mnk[2], mnk[3])

llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$rnd.f64"
func_name = Symbol(join(["llvm", "wmma", "mma", a_layout, b_layout, shape, "f64", rnd], "_"))

a_arr_ty, a_frag_ty, a_sz = get_frag_info("a", "f64", shape)
b_arr_ty, b_frag_ty, b_sz = get_frag_info("b", "f64", shape)
c_arr_ty, c_frag_ty, c_sz = get_frag_info("c", "f64", shape)
d_arr_ty, d_frag_ty, d_sz = get_frag_info("d", "f64", shape)

a_types = ntuple(i -> a_frag_ty, a_sz)
b_types = ntuple(i -> b_frag_ty, b_sz)
c_types = ntuple(i -> c_frag_ty, c_sz)

a_vars = ntuple(i -> :(a[$i]), a_sz)
b_vars = ntuple(i -> :(b[$i]), b_sz)
c_vars = ntuple(i -> :(c[$i]), c_sz)

if d_sz == 1
@eval @device_function $func_name(a, b, c) = tuple(ccall($llvm_intr, llvmcall, $d_frag_ty, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)))
else
struct_ty = Symbol("LLVMStruct$d_sz")
@eval @device_function $func_name(a, b, c) = convert(NTuple{$d_sz, $d_frag_ty}, ccall($llvm_intr, llvmcall, $struct_ty{$d_frag_ty}, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)))
end
@eval export $func_name
@eval @doc (@doc llvm_wmma_mma) $func_name
end

# RoundingMode-dispatched aliases for f64 MMA. The bare-3-arg form forwards
# to round-to-nearest, matching PTX's default-rnd convention.
for a_layout in ("col", "row"), b_layout in ("col", "row"), mnk in ((8, 8, 4),)
shape = get_hl_shape(mnk...)
base = Symbol(join(["llvm", "wmma", "mma", a_layout, b_layout, shape, "f64"], "_"))
rn_name = Symbol(base, :_rn)
rz_name = Symbol(base, :_rz)
rm_name = Symbol(base, :_rm)
rp_name = Symbol(base, :_rp)

@eval @device_function $base(a, b, c) = $rn_name(a, b, c)
@eval @device_function $base(a, b, c, ::RoundingMode{:Nearest}) = $rn_name(a, b, c)
@eval @device_function $base(a, b, c, ::RoundingMode{:ToZero}) = $rz_name(a, b, c)
@eval @device_function $base(a, b, c, ::RoundingMode{:Down}) = $rm_name(a, b, c)
@eval @device_function $base(a, b, c, ::RoundingMode{:Up}) = $rp_name(a, b, c)
@eval export $base
@eval @doc (@doc llvm_wmma_mma) $base
end

################################################################################
# FLATTENING/UNFLATTENING LOGIC
################################################################################
Expand Down Expand Up @@ -517,9 +595,13 @@ Type that contains all information for WMMA operations that cannot be inferred f
WMMA instructions calculate the matrix multiply-accumulate operation ``D = A \\cdot B + C``, where ``A`` is a ``M \\times K`` matrix,
``B`` a ``K \\times N`` matrix, and ``C`` and ``D`` are ``M \\times N`` matrices.

`d_type` refers to the type of the elements of matrix ``D``, and can be either `Float16` or `Float32`.
`d_type` refers to the type of the elements of matrix ``D``, and can be `Float16`,
`Float32`, or `Float64` (the latter only with shape `(8, 8, 4)`).

All WMMA operations take a `Config` as their final argument.
All WMMA operations take a `Config` as their final argument. For `Float64`
configurations, [`WMMA.mma`](@ref) additionally accepts a
[`Base.RoundingMode`](https://docs.julialang.org/en/v1/base/math/#Base.Rounding.RoundingMode)
to select the hardware rounding mode.

# Examples
```jldoctest
Expand Down Expand Up @@ -554,7 +636,12 @@ const map_num_elems = Dict(
("c", Float16) => 8,
("c", Float32) => 8,
("d", Float16) => 8,
("d", Float32) => 8
("d", Float32) => 8,
# f64 m8n8k4: per-thread frag is already flat (no VecElement)
("a", Float64) => 1,
("b", Float64) => 1,
("c", Float64) => 2,
("d", Float64) => 2,
)

# Maps matrix to its use
Expand Down Expand Up @@ -669,6 +756,7 @@ export mma

"""
WMMA.mma(a, b, c, conf)
WMMA.mma(a, b, c, conf, rounding)

Perform the matrix multiply-accumulate operation ``D = A \\cdot B + C``.

Expand All @@ -678,6 +766,10 @@ Perform the matrix multiply-accumulate operation ``D = A \\cdot B + C``.
- `b`: The [`WMMA.Fragment`](@ref) corresponding to the matrix ``B``.
- `c`: The [`WMMA.Fragment`](@ref) corresponding to the matrix ``C``.
- `conf`: The [`WMMA.Config`](@ref) that should be used in this WMMA operation.
- `rounding`: A `Base.RoundingMode` selecting the hardware rounding mode. Only
`Float64` configurations support modes other than `RoundNearest`; for other
element types, passing a non-`RoundNearest` mode raises an error. Defaults to
`RoundNearest`.

!!! warning

Expand All @@ -686,10 +778,14 @@ Perform the matrix multiply-accumulate operation ``D = A \\cdot B + C``.
"""
mma

mma(a::Fragment, b::Fragment, c::Fragment,
config::Type{<:Config}) = mma(a, b, c, config, RoundNearest)

@generated function mma(a::Fragment{M, N, K, A_SZ, A_T, A_L, MatrixA},
b::Fragment{M, N, K, B_SZ, B_T, B_L, MatrixB},
c::Fragment{M, N, K, C_SZ, C_T, Unspecified, Accumulator},
config::Type{Config{M, N, K, D_T}}) where {M, N, K, A_SZ, A_T, A_L, B_SZ, B_T, B_L, C_SZ, C_T, D_T}
config::Type{Config{M, N, K, D_T}},
::RoundingMode{R}) where {M, N, K, A_SZ, A_T, A_L, B_SZ, B_T, B_L, C_SZ, C_T, D_T, R}

a_layout = get_hl_layout(A_L)
b_layout = get_hl_layout(B_L)
Expand All @@ -700,9 +796,29 @@ mma
_, c_frag_sz, c_frag_ty, c_arr_str = get_hl_frag_info("c", C_T, shape)
d_num_els, _, _, d_arr_str = get_hl_frag_info("d", D_T, shape)

# Only Float64 MMA has hardware-selectable rounding modes; for other types,
# the only valid mode is the hardware's implicit round-to-nearest.
rnd_suffix = ""
if D_T === Float64
rnd_suffix = R === :Nearest ? "rn" :
R === :ToZero ? "rz" :
R === :Up ? "rp" :
R === :Down ? "rm" :
error("WMMA.mma: unsupported RoundingMode :$R for Float64")
elseif R !== :Nearest
error("WMMA.mma: RoundingMode :$R is only supported for Float64 configurations; got D_T = $D_T")
end

names = ["llvm", "wmma", "mma", a_layout, b_layout, shape]
# bf16 uses input type in intrinsic name, f16 uses d/c types
A_T === BFloat16 ? push!(names, a_arr_str) : push!(names, d_arr_str, c_arr_str)
# bf16 uses A/B element type in the intrinsic name; f16/f32 use D/C types;
# f64 uses its single element type plus the rounding suffix.
if A_T === BFloat16
push!(names, a_arr_str)
elseif D_T === Float64
push!(names, d_arr_str, rnd_suffix)
else
push!(names, d_arr_str, c_arr_str)
end
wrapper = Symbol(join(filter(!isempty, names), "_"))

a_unfl_expr = A_T === BFloat16 ? :(unflatten_bf16(a.x)) : :(unflatten(NTuple{$a_frag_sz, $a_frag_ty}, a.x))
Expand Down
Loading