Conversation
spolifroni-amd
left a comment
The readme needs some work.
> # you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
> ../script/cmake-ck-dev.sh ../ <arch>
> make tile_example_unified_attention -j
> ```
Add an empty line below this
> ## kernel
> The kernel template is `unified_attention.hpp`, this is the grid-wise op in old ck_tile's terminology. We put it here purposely, to demonstrate one can construct a kernel by using various internal component from ck_tile. We may still have an implementation under ck_tile's include path (in the future) for the kernel template.
>
> There are 2 template parameters for this kernel template.
Add an empty line below
> ```
> This will result in an executable `build/bin/tile_example_unified_attention`
>
> ## kernel
Put an empty line beneath all headers
> Underneath, we unify the mask expression into `generic attention mask coordinate`, providing an uniformed approach for each batch to locate the corresponding pixel need to be masked out.
>
> Since FA/xformer style with window_size_left/right is more popular, we accept window_size as parameter and convert that internally to our generic coordinate(this coordinate can express more cases). Below shows some example of how to achieve different kind of mask through cmdline.
Suggested change:

> Since FA/xformer style with window_size_left/right is more popular, we accept window_size as parameter and convert that internally to our generic coordinate. The following are some examples of how to achieve different kind of mask through cmdline:
> Note FA use bottom-right by default to express swa case, here we require you explicitly specify top-left/bottom-right.
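For context, the `generic attention mask coordinate` quoted above can be read as a pair (y, x) measured from one corner of the seqlen_q x seqlen_k score matrix. Below is a rough sketch of the resulting keep-predicate under a top-left convention; it illustrates the idea only and is not the ck_tile implementation (function and variable names are made up):

```cpp
#include <cassert>
#include <cstdint>

// Rough illustration of a (y, x) "generic attention mask coordinate", measured
// from the top-left corner of the seqlen_q x seqlen_k score matrix: row i may
// attend to columns in [i - y + 1, i + x). Sketch of the idea only, not the
// ck_tile implementation.
inline bool keep(int32_t i, int32_t j, int32_t y, int32_t x)
{
    return (j >= i - y + 1) && (j < i + x);
}

int main()
{
    // causal (top-left): x = 1, y = seqlen_q keeps exactly j <= i
    const int32_t seqlen_q = 8;
    assert(keep(3, 3, seqlen_q, 1) && !keep(3, 4, seqlen_q, 1));

    // sliding window with window_size_left = 2, window_size_right = 0
    // (top-left convention): y = 3, x = 1 keeps j in [i - 2, i]
    assert(keep(5, 3, 3, 1) && !keep(5, 2, 3, 1) && !keep(5, 6, 3, 1));
    return 0;
}
```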
> ### dropout
> TBD
>
> ### sequence padding and variable length support
> We support sequence padding and variable-length processing in both batch and group modes fmha forward to handle real-world scenarios where sequences have different lengths.
>
> **Group Mode Padding**: Use `-s_qpad` and `-s_kpad` to specify physical stride between batches, enabling padded layouts. Each batch can have different logical sequence lengths (`-s`, `-s_k`) but use larger physical strides for memory alignment.
Suggested change:

> **Group Mode Padding**: Use `-s_qpad` and `-s_kpad` to specify physical strides between batches and enable padded layouts. Each batch can have different logical sequence lengths (`-s`, `-s_k`). Use larger physical strides for memory alignment.
> Both approaches optimize memory access patterns while supporting flexible sequence length requirements commonly found in transformer inference scenarios.
>
> ## FP8 experimental support
Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>
> ```
> This will result in an executable `build/bin/tile_example_unified_attention`
Please remove fmha_bwd_known_fails_gfxxxx.txt since you don't need them.

Please resolve the merge conflicts.
Pull request overview
This pull request implements a unified attention kernel in CK Tile that optimizes fused multi-head attention for AMD GPUs. The implementation reduces launch grid size by using a variable-length approach with binary search for sequence indexing, and groups query tokens in the head dimension to improve decode performance.
Changes:
- Implements unified attention pipeline with optimized grid sizing and binary search for sequence indexing
- Adds support for variable-length sequences and paged KV cache (see the sketch after this list)
- Provides example implementations with fp16/bf16 support and causal masking
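To illustrate the paged KV cache mentioned in the list above, here is a minimal sketch of the usual block-table lookup, assuming a per-sequence block table and the `page_blk_size` parameter from the kernel arguments; the names `block_table`, `KvLocation`, and `locate_kv` are hypothetical and not taken from this PR:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative paged-KV lookup (made-up names): a per-sequence block table maps
// a logical KV block to a physical block in the paged cache, so a token's K/V
// rows are found via (block_table[token / page_blk_size], token % page_blk_size).
struct KvLocation
{
    int32_t physical_block; // which physical page of the KV cache
    int32_t slot;           // row inside that page
};

inline KvLocation locate_kv(const std::vector<int32_t>& block_table, // pages of one sequence
                            int32_t token_idx,                       // logical KV position
                            int32_t page_blk_size)                   // tokens per page
{
    return KvLocation{block_table[token_idx / page_blk_size], token_idx % page_blk_size};
}

int main()
{
    const std::vector<int32_t> block_table{7, 3, 12}; // logical pages 0..2 -> physical pages
    const KvLocation loc = locate_kv(block_table, /*token_idx=*/20, /*page_blk_size=*/16);
    assert(loc.physical_block == 3 && loc.slot == 4); // token 20 lives in logical page 1, row 4
    return 0;
}
```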
Reviewed changes
Copilot reviewed 26 out of 33 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| script/cmake-ck-dev.sh | Modified GPU target (should not be committed) |
| include/ck_tile/ops/unified_attention/pipeline/*.hpp | Pipeline infrastructure for unified attention |
| include/ck_tile/ops/unified_attention/kernel/*.hpp | Kernel implementation with binary search and masking |
| include/ck_tile/ops/unified_attention/block/*.hpp | Generic attention mask implementation |
| example/ck_tile/42_unified_attention/* | Example code, instances, and build files |
> namespace ck_tile {
>
> template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
>           bool kPadHeadDim_ /* paddding for hdim_v */,
Typo in comment: "paddding" should be "padding" (one 'd' removed).
Suggested change:

> bool kPadHeadDim_ /* padding for hdim_v */,
> template <unified_attention_args::data_type_enum DataType, bool IsMasking>
> struct unified_attention_kernel_traits
> {
>     static constexpr auto date_type = DataType;
Typo in variable name: "date_type" should be "data_type". This naming inconsistency is used throughout the struct definition and should be corrected for clarity.
> fi
>
> GPU_TARGETS="gfx908;gfx90a;gfx942"
> GPU_TARGETS="gfx950"
The GPU_TARGETS has been hardcoded to "gfx950" in the development script, but this change should not be committed to the repository. Development scripts should maintain the default configuration with multiple GPU targets to ensure compatibility across different architectures.
>           typename Traits_>
> struct UnifiedAttentionPipelineProblem
> {
>     // TODO kM0 and KN1??
The comment "TODO kM0 and KN1??" suggests incomplete work or unclear naming. If this is a genuine TODO item that needs to be addressed, it should be properly documented with what needs to be done. Otherwise, remove the comment.
> {
>     using namespace ck_tile;
>
>     /// FIXME: rename the kNPerBlock & kKPerBlock since the kN1 is congtigous dimension
Typo in comment: "congtigous" should be "contiguous".
> function print_log_header(){
>     rm -f $1;
>     echo 'On branch ' $3 &> $1;
>     echo 'Node name: ' $4 >> $1;
>     #get GPU_arch and number of compute units from rocminfo
>     echo -n "GPU_arch: " >> $1; rocminfo | grep "Name:" | grep "gfx" >> $1;
>     rocminfo | grep "Compute Unit:" >> $1;
>     hipcc --version | grep -e 'HIP version' >> $1;
>     echo 'Environment type: ' $2 >> $1;
>     /opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> $1;
> }
The function `print_log_header` interpolates untrusted shell arguments directly into commands without quoting (e.g., `rm -f $1;` and `echo 'On branch ' $3 &> $1;`), which allows command injection if any of the positional parameters contain shell metacharacters (such as `;`, `&`, `|`, or spaces). Because `$1`, `$2`, `$3`, and `$4` are derived from script arguments (env_type, branch, host_name, GPU_arch), an attacker who can influence these values (for example via CI variables or crafted branch names) could execute arbitrary commands with the script's privileges. To fix this, treat all positional parameters as data by consistently quoting them (e.g., `rm -f "$1"`, `echo 'On branch ' "$3" > "$1"`) and avoid letting user-controlled values be expanded in a way that the shell can interpret as additional syntax.
> index_t kBlockQ = Kernel::kBlockQ;
> assert(args.num_queries_per_kv == Kernel::num_queries_per_kv &&
>        "argument num_queries_per_kv must equal compiled num_queries_per_kv");
> assert(args.BLOCK_SIZE == Kernel::BLOCK_SIZE &&
`unified_attention_args` does not have a `BLOCK_SIZE` attribute. Please review whether this `assert()` is necessary.
| "argument num_queries_per_kv must equal compiled num_queries_per_kv"); | ||
| assert(args.BLOCK_SIZE == Kernel::BLOCK_SIZE && | ||
| "argument BLOCK_SIZE must equal compiled BLOCK_SIZE"); | ||
| assert(kBlockQ == kBlockM / args.num_queries_per_kv && |
Add the missing qualifier `Kernel::` when accessing these attributes, e.g. `Kernel::kBlockM`.
> args.num_head_q = problem.nhead_q;
> args.num_queries_per_kv = num_queries_per_kv;
> args.page_blk_size = problem.page_blk_size;
> args.mask_type = 2;
The mask_type is hardcoded and ignores the command-line option, so we cannot tell whether the unmasked kernels work correctly. Please update this.
> -repeat number of iterations to benchmark the kernel (default:20)
> -json 0: No Json, 1: Dump Results in Json format (default:0)
> -jsonfile json file name to dump results (default:fmha_fwd.json)
> -q_eff_lens Batch-mode only: per-batch effective seqlen for Q (exclude PAD) (default:"")
Please remove q_eff_lens/kv_eff_lens from the README.md if they are not intended to be supported.
Error importing due to merge conflicts – please reopen the PR on ROCm/rocm-libraries
Authors: @Chi-Chu319 @juuso-oskari
This PR implements a unified attention kernel written in CK Tile. It builds on top of fmha_v3 (composable_kernel/example/ck_tile/01_fmha), with the pipeline largely remaining the same, and adds the following features introduced in the Triton unified attention kernel:

Reduced launch grid size (see composable_kernel/example/ck_tile/01_unified_attention/unified_attention_impl.hpp). This launches significantly fewer programs than the previous grid = (num_seqs, max_seqlen // BLOCK_M, num_q_heads), which contained many empty programs (not all sequences are of length max_seqlen). Since the current sequence index can no longer be taken directly from the program id, we do a binary search at the beginning of the kernel to find the sequence index (used to look up the sequence length, which determines the inner-loop length). This is implemented in composable_kernel/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp.
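A minimal sketch of that indexing scheme, assuming a prefix sum of per-sequence tile counts (the names `cu_tile_counts`, `TileCoord`, and `find_tile_coord` are illustrative, not the actual kernel code): each workgroup binary-searches its flattened tile id to recover the sequence it belongs to and its tile position within that sequence.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative only: map a flattened workgroup/tile id back to its sequence.
// cu_tile_counts is a prefix sum of per-sequence tile counts, i.e.
// cu_tile_counts[b] = number of tiles belonging to sequences 0..b-1,
// with cu_tile_counts.back() == total number of launched tiles.
struct TileCoord
{
    int32_t seq_idx;     // which sequence this tile works on
    int32_t tile_in_seq; // tile index within that sequence
};

inline TileCoord find_tile_coord(const std::vector<int32_t>& cu_tile_counts, int32_t tile_id)
{
    // binary search for the largest b such that cu_tile_counts[b] <= tile_id
    int32_t lo = 0, hi = static_cast<int32_t>(cu_tile_counts.size()) - 1;
    while(lo + 1 < hi)
    {
        const int32_t mid = (lo + hi) / 2;
        if(cu_tile_counts[mid] <= tile_id) { lo = mid; } else { hi = mid; }
    }
    return TileCoord{lo, tile_id - cu_tile_counts[lo]};
}

int main()
{
    // three sequences needing 4, 1 and 3 query tiles -> 8 tiles launched in total
    const std::vector<int32_t> cu_tile_counts{0, 4, 5, 8};
    const TileCoord c = find_tile_coord(cu_tile_counts, 6);
    assert(c.seq_idx == 2 && c.tile_in_seq == 1);
    return 0;
}
```

In this sketch, the last prefix-sum entry is also the total number of tiles to launch, which is how a grid proportional to the real sequence lengths (rather than to max_seqlen) can be obtained.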
To process more query tokens per load in decode settings (where the query sequence length is small, often just 1), we group query tokens along the head dimension: up to num_queries_per_kv query tokens share the same key/value token (the GQA setting), and the total number of grouped tokens for a tile load is BLOCK_M = BLOCK_Q * num_queries_per_kv. We do this in the kernel implementation by transforming the tensor view for Q in DRAM. This way, the pipeline can remain untouched and use BLOCK_M as its tile size.
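A rough sketch of the index mapping implied by this grouping (illustrative only; the real implementation expresses it as a tensor-view transform over Q, and the exact row ordering is an assumption here):

```cpp
#include <cassert>
#include <cstdint>

// Illustrative mapping for the grouped-Q layout described above (names are
// made up). A tile loads BLOCK_M = BLOCK_Q * num_queries_per_kv rows of Q at
// once: consecutive rows first walk the query heads that share one KV head,
// then advance to the next query token.
struct QCoord
{
    int32_t token;  // query token position within the sequence
    int32_t q_head; // global query-head index
};

inline QCoord grouped_row_to_q(int32_t m,                  // packed row in [0, BLOCK_M)
                               int32_t tile_q_start,       // first query token of this tile
                               int32_t kv_head,            // KV head handled by this tile
                               int32_t num_queries_per_kv) // query heads per KV head
{
    return QCoord{tile_q_start + m / num_queries_per_kv,
                  kv_head * num_queries_per_kv + m % num_queries_per_kv};
}

int main()
{
    // example: 4 query heads per KV head, tile starting at query token 8, KV head 2
    const QCoord c = grouped_row_to_q(/*m=*/5, /*tile_q_start=*/8, /*kv_head=*/2,
                                      /*num_queries_per_kv=*/4);
    assert(c.token == 9 && c.q_head == 9); // row 5 -> token 8 + 1, head 2*4 + 1
    return 0;
}
```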
build
Benchmark comparison against FAv3 and Triton Unified Attention

Prefill batch: b=1, h=64, d=128, seqlen_q/k=16384, causal=1
- Ours: 548 TFLOPS
- fmha_pagedkv: 531 TFLOPS
- fmha_v3: 654 TFLOPS
- FAv3 Triton blogpost: 654 TFLOPS
- Unified Attention Triton (at branch bench-unified-attention):

Mixed batch: b=8, h_q/h_k=128/8, d=128, seqlen_q=1024,1,32,256,1,1,4096,4096, seqlen_k=16384, causal=1
- Ours: 674 TFLOPS
- Unified Attention Triton (at branch bench-unified-attention): 562 TFLOPS