Skip to content

Expose directed rounding for Float64 WMMA tensor cores#3143

Draft
orkolorko wants to merge 1 commit into
JuliaGPU:mainfrom
orkolorko:wmma-f64-rounding
Draft

Expose directed rounding for Float64 WMMA tensor cores#3143
orkolorko wants to merge 1 commit into
JuliaGPU:mainfrom
orkolorko:wmma-f64-rounding

Conversation

@orkolorko
Copy link
Copy Markdown
Contributor

Summary

  • Adds the m8n8k4 Float64 WMMA path (load/store + mma) and routes the NVPTX mma.sync.aligned.m8n8k4.{rn,rz,rm,rp}.f64 rounding modifiers through both the low-level llvm_wmma_mma_* wrappers and the high-level WMMA.mma API.
  • On either layer, the 4-arg form accepts a Base.RoundingMode (RoundNearest, RoundToZero, RoundDown, RoundUp); the bare 3-arg form forwards to round-to-nearest, matching PTX's default-rnd convention.
  • Non-Float64 configurations reject any non-RoundNearest mode at codegen time, since f16/f32/bf16/int WMMA has no hardware-selectable rounding modifier.

Continues the directed-rounding theme started in #2576 — there the scalar add_rn/mul_rz/etc. intrinsics were exposed; here the same idea extends to the f64 tensor-core path, which is the only WMMA shape/type where PTX exposes per-call rounding control.

Example

using CUDA
using CUDA.WMMA

function kernel_low_level(a_dev, b_dev, c_dev, d_dev)
    a_frag = WMMA.llvm_wmma_load_a_col_m8n8k4_global_stride_f64(pointer(a_dev), 8)
    b_frag = WMMA.llvm_wmma_load_b_col_m8n8k4_global_stride_f64(pointer(b_dev), 4)
    c_frag = WMMA.llvm_wmma_load_c_col_m8n8k4_global_stride_f64(pointer(c_dev), 8)

    d_frag = WMMA.llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, RoundDown)

    WMMA.llvm_wmma_store_d_col_m8n8k4_global_stride_f64(pointer(d_dev), d_frag, 8)
    return nothing
end

function kernel_high_level(a_dev, b_dev, c_dev, d_dev)
    conf = WMMA.Config{8, 8, 4, Float64}

    a_frag = WMMA.load_a(pointer(a_dev), 8, WMMA.ColMajor, conf)
    b_frag = WMMA.load_b(pointer(b_dev), 4, WMMA.ColMajor, conf)
    c_frag = WMMA.load_c(pointer(c_dev), 8, WMMA.ColMajor, conf)

    d_frag = WMMA.mma(a_frag, b_frag, c_frag, conf, RoundDown)

    WMMA.store_d(pointer(d_dev), d_frag, 8, WMMA.ColMajor, conf)
    return nothing
end

Notes for reviewers

  • The PTX manual disagrees with itself on the m8n8k4 f64 C/D fragment size: the wmma matrix-fragments table says 1, but the mma-m8n8k4 page correctly says 2. The latter is what hardware does; map_frag_sizes reflects 2 with an inline citation.
  • LLVM exposes the f64 WMMA mma intrinsics only with explicit .rn/.rz/.rm/.rp modifiers — there is no implicit-default form. The bare-3-arg Julia wrapper therefore forwards to _rn, matching PTX's documented default.
  • f64 WMMA requires sm_80+; tests are gated accordingly (same pattern as the existing BFloat16 WMMA tests).

Test plan

  • Existing WMMA tests still pass.
  • New `llvm_wmma_mma (Float64 rounding)` testset passes on Ampere+, exercising all 4 layouts × 4 rounding modes; asserts `RoundDown ≤ RoundNearest ≤ RoundUp` elementwise and that suffixed wrappers (`..._rn`/`_rz`/`_rm`/`_rp`) match the dispatched form.
  • New `CUDA C-style API (Float64 rounding)` testset passes on Ampere+, exercising the full `Config`/`load_*`/`mma`/`store_d` path with all rounding modes and both MAC and MUL variants.
  • Negative test confirms `WMMA.mma(..., Config{16,16,16,Float32}, RoundDown)` raises an `ErrorException` from the @generated body.

🤖 Generated with Claude Code

Add the m8n8k4 Float64 WMMA path (load/store + mma) and route the
NVPTX `mma.sync.aligned.m8n8k4.{rn,rz,rm,rp}.f64` rounding modifiers
through both the low-level `llvm_wmma_mma_*` wrappers and the
high-level `WMMA.mma` API. On either layer the 4-arg form accepts a
`Base.RoundingMode` (`RoundNearest`, `RoundToZero`, `RoundDown`,
`RoundUp`); the bare 3-arg form forwards to round-to-nearest, matching
PTX's default-rnd convention. Non-Float64 configurations reject any
non-`RoundNearest` mode at codegen time, since f16/f32/bf16/int WMMA
has no hardware-selectable rounding modifier.

The WMMA matrix-fragments table in the PTX manual erroneously lists
the m8n8k4 f64 C/D fragment size as 1; the mma-m8n8k4 page (correct,
2) is the one reflected in `map_frag_sizes`, with an inline note.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Copy link
Copy Markdown
Contributor

@github-actions github-actions Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CUDA.jl Benchmarks

Details
Benchmark suite Current: abd66ed Previous: e550cee Ratio
array/accumulate/Float32/1d 100706 ns 101415 ns 0.99
array/accumulate/Float32/dims=1 75956 ns 77151 ns 0.98
array/accumulate/Float32/dims=1L 1585415 ns 1586551.5 ns 1.00
array/accumulate/Float32/dims=2 142923 ns 144125.5 ns 0.99
array/accumulate/Float32/dims=2L 656944 ns 658127 ns 1.00
array/accumulate/Int64/1d 118315.5 ns 118535 ns 1.00
array/accumulate/Int64/dims=1 79639.5 ns 80097 ns 0.99
array/accumulate/Int64/dims=1L 1694511.5 ns 1694711.5 ns 1.00
array/accumulate/Int64/dims=2 155389.5 ns 156189 ns 0.99
array/accumulate/Int64/dims=2L 961609 ns 962516 ns 1.00
array/broadcast 20678 ns 20532 ns 1.01
array/construct 1271.5 ns 1270 ns 1.00
array/copy 17918 ns 18019 ns 0.99
array/copyto!/cpu_to_gpu 213811 ns 215242 ns 0.99
array/copyto!/gpu_to_cpu 280447 ns 283251 ns 0.99
array/copyto!/gpu_to_gpu 10608 ns 10886 ns 0.97
array/iteration/findall/bool 134407 ns 134914 ns 1.00
array/iteration/findall/int 148383 ns 149150 ns 0.99
array/iteration/findfirst/bool 80689 ns 81669 ns 0.99
array/iteration/findfirst/int 82922 ns 83894 ns 0.99
array/iteration/findmin/1d 83881.5 ns 84687.5 ns 0.99
array/iteration/findmin/2d 114898 ns 114576 ns 1.00
array/iteration/logical 197998.5 ns 202045 ns 0.98
array/iteration/scalar 67108 ns 66531 ns 1.01
array/permutedims/2d 52030.5 ns 52383 ns 0.99
array/permutedims/3d 52185 ns 52705 ns 0.99
array/permutedims/4d 51517 ns 51404 ns 1.00
array/random/rand/Float32 13238 ns 13062 ns 1.01
array/random/rand/Int64 24403 ns 24415.5 ns 1.00
array/random/rand!/Float32 8400.125 ns 8848.333333333334 ns 0.95
array/random/rand!/Int64 21068 ns 21621 ns 0.97
array/random/randn/Float32 38979 ns 42790.5 ns 0.91
array/random/randn!/Float32 30348 ns 30709 ns 0.99
array/reductions/mapreduce/Float32/1d 33800 ns 34002.5 ns 0.99
array/reductions/mapreduce/Float32/dims=1 43459 ns 49231 ns 0.88
array/reductions/mapreduce/Float32/dims=1L 51152 ns 51127 ns 1.00
array/reductions/mapreduce/Float32/dims=2 57744 ns 58010 ns 1.00
array/reductions/mapreduce/Float32/dims=2L 67045.5 ns 67462 ns 0.99
array/reductions/mapreduce/Int64/1d 42062 ns 42403 ns 0.99
array/reductions/mapreduce/Int64/dims=1 42243 ns 51369.5 ns 0.82
array/reductions/mapreduce/Int64/dims=1L 86956 ns 87141 ns 1.00
array/reductions/mapreduce/Int64/dims=2 60600 ns 60667 ns 1.00
array/reductions/mapreduce/Int64/dims=2L 84056.5 ns 83968.5 ns 1.00
array/reductions/reduce/Float32/1d 34021 ns 34433 ns 0.99
array/reductions/reduce/Float32/dims=1 40364 ns 39665 ns 1.02
array/reductions/reduce/Float32/dims=1L 51118.5 ns 51437 ns 0.99
array/reductions/reduce/Float32/dims=2 58062 ns 57991 ns 1.00
array/reductions/reduce/Float32/dims=2L 68048 ns 68174 ns 1.00
array/reductions/reduce/Int64/1d 41684 ns 42107 ns 0.99
array/reductions/reduce/Int64/dims=1 50730 ns 42289 ns 1.20
array/reductions/reduce/Int64/dims=1L 86984 ns 86974 ns 1.00
array/reductions/reduce/Int64/dims=2 60523.5 ns 60553 ns 1.00
array/reductions/reduce/Int64/dims=2L 83642 ns 83815 ns 1.00
array/reverse/1d 17564 ns 17742 ns 0.99
array/reverse/1dL 68165 ns 68341 ns 1.00
array/reverse/1dL_inplace 65596 ns 65689 ns 1.00
array/reverse/1d_inplace 10159 ns 8403.666666666666 ns 1.21
array/reverse/2d 20782 ns 20834 ns 1.00
array/reverse/2dL 72797 ns 72716 ns 1.00
array/reverse/2dL_inplace 65685 ns 65725 ns 1.00
array/reverse/2d_inplace 9806 ns 9895 ns 0.99
array/sorting/1d 2734127.5 ns 2736362 ns 1.00
array/sorting/2d 1067513 ns 1069216 ns 1.00
array/sorting/by 3303607 ns 3305075 ns 1.00
cuda/synchronization/context/auto 1181.2 ns 1164.8 ns 1.01
cuda/synchronization/context/blocking 930.3333333333334 ns 888.625 ns 1.05
cuda/synchronization/context/nonblocking 6790.4 ns 8179 ns 0.83
cuda/synchronization/stream/auto 1000.9166666666666 ns 1008.75 ns 0.99
cuda/synchronization/stream/blocking 824.3315789473684 ns 794.4795918367347 ns 1.04
cuda/synchronization/stream/nonblocking 7051.9 ns 7027.8 ns 1.00
integration/byval/reference 143661 ns 143758 ns 1.00
integration/byval/slices=1 145573 ns 145752 ns 1.00
integration/byval/slices=2 284416 ns 284484 ns 1.00
integration/byval/slices=3 422978 ns 423202 ns 1.00
integration/cudadevrt 102335 ns 102333 ns 1.00
integration/volumerhs 9945119.5 ns 23449772.5 ns 0.42
kernel/indexing 13050 ns 13234 ns 0.99
kernel/indexing_checked 13756 ns 13897 ns 0.99
kernel/launch 2144.5 ns 2186.777777777778 ns 0.98
kernel/occupancy 666.6477987421383 ns 690.9602649006622 ns 0.96
kernel/rand 17874 ns 14864 ns 1.20
latency/import 3846837228.5 ns 3857912982.5 ns 1.00
latency/precompile 4658516091 ns 4640798175 ns 1.00
latency/ttfp 4458051461.5 ns 4434643422.5 ns 1.01

This comment was automatically generated by workflow using github-action-benchmark.

@maleadt maleadt marked this pull request as draft May 18, 2026 16:01
@maleadt
Copy link
Copy Markdown
Member

maleadt commented May 18, 2026

FYI, have you looked at cuTile.jl? It should automatically target tensor cores, while featuring a much more user friendly rounding mode API.

@orkolorko
Copy link
Copy Markdown
Contributor Author

Thank you @maleadt, really interesting!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants