Expose directed rounding for Float64 WMMA tensor cores#3143
Draft
orkolorko wants to merge 1 commit into
Draft
Conversation
Add the m8n8k4 Float64 WMMA path (load/store + mma) and route the
NVPTX `mma.sync.aligned.m8n8k4.{rn,rz,rm,rp}.f64` rounding modifiers
through both the low-level `llvm_wmma_mma_*` wrappers and the
high-level `WMMA.mma` API. On either layer the 4-arg form accepts a
`Base.RoundingMode` (`RoundNearest`, `RoundToZero`, `RoundDown`,
`RoundUp`); the bare 3-arg form forwards to round-to-nearest, matching
PTX's default-rnd convention. Non-Float64 configurations reject any
non-`RoundNearest` mode at codegen time, since f16/f32/bf16/int WMMA
has no hardware-selectable rounding modifier.
The WMMA matrix-fragments table in the PTX manual erroneously lists
the m8n8k4 f64 C/D fragment size as 1; the mma-m8n8k4 page (correct,
2) is the one reflected in `map_frag_sizes`, with an inline note.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Contributor
There was a problem hiding this comment.
CUDA.jl Benchmarks
Details
| Benchmark suite | Current: abd66ed | Previous: e550cee | Ratio |
|---|---|---|---|
array/accumulate/Float32/1d |
100706 ns |
101415 ns |
0.99 |
array/accumulate/Float32/dims=1 |
75956 ns |
77151 ns |
0.98 |
array/accumulate/Float32/dims=1L |
1585415 ns |
1586551.5 ns |
1.00 |
array/accumulate/Float32/dims=2 |
142923 ns |
144125.5 ns |
0.99 |
array/accumulate/Float32/dims=2L |
656944 ns |
658127 ns |
1.00 |
array/accumulate/Int64/1d |
118315.5 ns |
118535 ns |
1.00 |
array/accumulate/Int64/dims=1 |
79639.5 ns |
80097 ns |
0.99 |
array/accumulate/Int64/dims=1L |
1694511.5 ns |
1694711.5 ns |
1.00 |
array/accumulate/Int64/dims=2 |
155389.5 ns |
156189 ns |
0.99 |
array/accumulate/Int64/dims=2L |
961609 ns |
962516 ns |
1.00 |
array/broadcast |
20678 ns |
20532 ns |
1.01 |
array/construct |
1271.5 ns |
1270 ns |
1.00 |
array/copy |
17918 ns |
18019 ns |
0.99 |
array/copyto!/cpu_to_gpu |
213811 ns |
215242 ns |
0.99 |
array/copyto!/gpu_to_cpu |
280447 ns |
283251 ns |
0.99 |
array/copyto!/gpu_to_gpu |
10608 ns |
10886 ns |
0.97 |
array/iteration/findall/bool |
134407 ns |
134914 ns |
1.00 |
array/iteration/findall/int |
148383 ns |
149150 ns |
0.99 |
array/iteration/findfirst/bool |
80689 ns |
81669 ns |
0.99 |
array/iteration/findfirst/int |
82922 ns |
83894 ns |
0.99 |
array/iteration/findmin/1d |
83881.5 ns |
84687.5 ns |
0.99 |
array/iteration/findmin/2d |
114898 ns |
114576 ns |
1.00 |
array/iteration/logical |
197998.5 ns |
202045 ns |
0.98 |
array/iteration/scalar |
67108 ns |
66531 ns |
1.01 |
array/permutedims/2d |
52030.5 ns |
52383 ns |
0.99 |
array/permutedims/3d |
52185 ns |
52705 ns |
0.99 |
array/permutedims/4d |
51517 ns |
51404 ns |
1.00 |
array/random/rand/Float32 |
13238 ns |
13062 ns |
1.01 |
array/random/rand/Int64 |
24403 ns |
24415.5 ns |
1.00 |
array/random/rand!/Float32 |
8400.125 ns |
8848.333333333334 ns |
0.95 |
array/random/rand!/Int64 |
21068 ns |
21621 ns |
0.97 |
array/random/randn/Float32 |
38979 ns |
42790.5 ns |
0.91 |
array/random/randn!/Float32 |
30348 ns |
30709 ns |
0.99 |
array/reductions/mapreduce/Float32/1d |
33800 ns |
34002.5 ns |
0.99 |
array/reductions/mapreduce/Float32/dims=1 |
43459 ns |
49231 ns |
0.88 |
array/reductions/mapreduce/Float32/dims=1L |
51152 ns |
51127 ns |
1.00 |
array/reductions/mapreduce/Float32/dims=2 |
57744 ns |
58010 ns |
1.00 |
array/reductions/mapreduce/Float32/dims=2L |
67045.5 ns |
67462 ns |
0.99 |
array/reductions/mapreduce/Int64/1d |
42062 ns |
42403 ns |
0.99 |
array/reductions/mapreduce/Int64/dims=1 |
42243 ns |
51369.5 ns |
0.82 |
array/reductions/mapreduce/Int64/dims=1L |
86956 ns |
87141 ns |
1.00 |
array/reductions/mapreduce/Int64/dims=2 |
60600 ns |
60667 ns |
1.00 |
array/reductions/mapreduce/Int64/dims=2L |
84056.5 ns |
83968.5 ns |
1.00 |
array/reductions/reduce/Float32/1d |
34021 ns |
34433 ns |
0.99 |
array/reductions/reduce/Float32/dims=1 |
40364 ns |
39665 ns |
1.02 |
array/reductions/reduce/Float32/dims=1L |
51118.5 ns |
51437 ns |
0.99 |
array/reductions/reduce/Float32/dims=2 |
58062 ns |
57991 ns |
1.00 |
array/reductions/reduce/Float32/dims=2L |
68048 ns |
68174 ns |
1.00 |
array/reductions/reduce/Int64/1d |
41684 ns |
42107 ns |
0.99 |
array/reductions/reduce/Int64/dims=1 |
50730 ns |
42289 ns |
1.20 |
array/reductions/reduce/Int64/dims=1L |
86984 ns |
86974 ns |
1.00 |
array/reductions/reduce/Int64/dims=2 |
60523.5 ns |
60553 ns |
1.00 |
array/reductions/reduce/Int64/dims=2L |
83642 ns |
83815 ns |
1.00 |
array/reverse/1d |
17564 ns |
17742 ns |
0.99 |
array/reverse/1dL |
68165 ns |
68341 ns |
1.00 |
array/reverse/1dL_inplace |
65596 ns |
65689 ns |
1.00 |
array/reverse/1d_inplace |
10159 ns |
8403.666666666666 ns |
1.21 |
array/reverse/2d |
20782 ns |
20834 ns |
1.00 |
array/reverse/2dL |
72797 ns |
72716 ns |
1.00 |
array/reverse/2dL_inplace |
65685 ns |
65725 ns |
1.00 |
array/reverse/2d_inplace |
9806 ns |
9895 ns |
0.99 |
array/sorting/1d |
2734127.5 ns |
2736362 ns |
1.00 |
array/sorting/2d |
1067513 ns |
1069216 ns |
1.00 |
array/sorting/by |
3303607 ns |
3305075 ns |
1.00 |
cuda/synchronization/context/auto |
1181.2 ns |
1164.8 ns |
1.01 |
cuda/synchronization/context/blocking |
930.3333333333334 ns |
888.625 ns |
1.05 |
cuda/synchronization/context/nonblocking |
6790.4 ns |
8179 ns |
0.83 |
cuda/synchronization/stream/auto |
1000.9166666666666 ns |
1008.75 ns |
0.99 |
cuda/synchronization/stream/blocking |
824.3315789473684 ns |
794.4795918367347 ns |
1.04 |
cuda/synchronization/stream/nonblocking |
7051.9 ns |
7027.8 ns |
1.00 |
integration/byval/reference |
143661 ns |
143758 ns |
1.00 |
integration/byval/slices=1 |
145573 ns |
145752 ns |
1.00 |
integration/byval/slices=2 |
284416 ns |
284484 ns |
1.00 |
integration/byval/slices=3 |
422978 ns |
423202 ns |
1.00 |
integration/cudadevrt |
102335 ns |
102333 ns |
1.00 |
integration/volumerhs |
9945119.5 ns |
23449772.5 ns |
0.42 |
kernel/indexing |
13050 ns |
13234 ns |
0.99 |
kernel/indexing_checked |
13756 ns |
13897 ns |
0.99 |
kernel/launch |
2144.5 ns |
2186.777777777778 ns |
0.98 |
kernel/occupancy |
666.6477987421383 ns |
690.9602649006622 ns |
0.96 |
kernel/rand |
17874 ns |
14864 ns |
1.20 |
latency/import |
3846837228.5 ns |
3857912982.5 ns |
1.00 |
latency/precompile |
4658516091 ns |
4640798175 ns |
1.00 |
latency/ttfp |
4458051461.5 ns |
4434643422.5 ns |
1.01 |
This comment was automatically generated by workflow using github-action-benchmark.
Member
|
FYI, have you looked at cuTile.jl? It should automatically target tensor cores, while featuring a much more user friendly rounding mode API. |
Contributor
Author
|
Thank you @maleadt, really interesting! |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
mma.sync.aligned.m8n8k4.{rn,rz,rm,rp}.f64rounding modifiers through both the low-levelllvm_wmma_mma_*wrappers and the high-levelWMMA.mmaAPI.Base.RoundingMode(RoundNearest,RoundToZero,RoundDown,RoundUp); the bare 3-arg form forwards to round-to-nearest, matching PTX's default-rnd convention.RoundNearestmode at codegen time, since f16/f32/bf16/int WMMA has no hardware-selectable rounding modifier.Continues the directed-rounding theme started in #2576 — there the scalar
add_rn/mul_rz/etc. intrinsics were exposed; here the same idea extends to the f64 tensor-core path, which is the only WMMA shape/type where PTX exposes per-call rounding control.Example
Notes for reviewers
map_frag_sizesreflects 2 with an inline citation..rn/.rz/.rm/.rpmodifiers — there is no implicit-default form. The bare-3-arg Julia wrapper therefore forwards to_rn, matching PTX's documented default.Test plan
🤖 Generated with Claude Code