Unified attention CK Tile kernel #3128

Open
juuso-oskari wants to merge 105 commits into develop from tianxing/unified-attention

Conversation

@juuso-oskari juuso-oskari commented Oct 30, 2025

Authors: @Chi-Chu319 @juuso-oskari

This PR implements a unified attention kernel written in CK Tile. It builds on top of fmha_v3 (composable_kernel/example/ck_tile/01_fmha), with the pipeline remaining largely the same. This PR implements the following features introduced in the Triton unified attention kernel:

Reduced launch grid size, implemented in composable_kernel/example/ck_tile/01_unified_attention/unified_attention_impl.hpp:

// args.num_tokens is the cumulative amount of tokens from all sequences
index_t total_num_q_blocks = args.num_tokens / BLOCK_Q + args.num_seqs;
dim3 grids            = Kernel::GridSize2D(args.num_kv_heads, total_num_q_blocks);
return launch_kernel(config, make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));

This launches significantly fewer programs than the previous grid = (num_seqs, max_seqlen // BLOCK_M, num_q_heads), which contained many empty programs because not all sequences are of length max_seqlen.
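
For a rough sense of the difference, here is a standalone host-side sketch (not part of this PR; the batch and tile sizes are made up) that compares the two program counts for a variable-length batch:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    // hypothetical mixed batch: per-sequence query lengths (made-up values)
    const std::vector<int64_t> query_lens = {1024, 1, 32, 256, 1, 1, 4096, 4096};
    const int64_t BLOCK_M      = 128; // assumed old tile size along seqlen_q
    const int64_t BLOCK_Q      = 16;  // assumed new tile size along query tokens
    const int64_t num_q_heads  = 128;
    const int64_t num_kv_heads = 8;

    const int64_t num_seqs = static_cast<int64_t>(query_lens.size());
    int64_t max_seqlen = 0, num_tokens = 0;
    for(int64_t len : query_lens)
    {
        max_seqlen = std::max(max_seqlen, len);
        num_tokens += len;
    }

    // old launch: grid = (num_seqs, max_seqlen / BLOCK_M rounded up, num_q_heads)
    const int64_t old_programs = num_seqs * ((max_seqlen + BLOCK_M - 1) / BLOCK_M) * num_q_heads;
    // new launch: grid = (num_kv_heads, num_tokens / BLOCK_Q + num_seqs)
    const int64_t new_programs = num_kv_heads * (num_tokens / BLOCK_Q + num_seqs);

    std::cout << "old grid: " << old_programs << " programs, new grid: " << new_programs << " programs\n";
}

The gap grows with the spread of sequence lengths (the old grid pads every sequence to max_seqlen) and with the ratio of query heads to KV heads, since the new grid is indexed by KV head.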

Since the sequence index can no longer be read directly from the program id, we do a binary search at the beginning of the kernel to find the sequence index (used to look up the sequence length, which determines the inner-loop length).

This is implemented at composable_kernel/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp:

// Binary search to find the sequence index for a given global index
CK_TILE_DEVICE static constexpr ck_tile::index_t
find_seq_idx(const int32_t* query_start_len_ptr,
                ck_tile::index_t target_idx,
                ck_tile::index_t num_seqs,
                ck_tile::index_t block_q,
                bool use_q_block_mode)
{
    ck_tile::index_t left = 0;
    ck_tile::index_t right = num_seqs;
    while (left < right)
    {
        ck_tile::index_t mid = (left + right) / 2;
        ck_tile::index_t val = query_start_len_ptr[mid];
        ck_tile::index_t mid_val = use_q_block_mode ? (val / block_q + mid) : val;
        
        if (mid_val <= target_idx)
        {
            left = mid + 1;
        }
        else
        {
            right = mid;
        }
    }
    return left - 1;
}
// usage inside the kernel
const auto [kv_head_idx, q_block_global_idx] = GetTileIndex(pid, kargs);
// grid size is (num_kv_heads, total_num_q_blocks)
// total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs
// q.shape[0] is total number of query tokens across all batches
const index_t seq_idx = find_seq_idx(
    kargs.query_start_len_ptr, q_block_global_idx, kargs.num_seqs, BLOCK_Q, true
); // which seq am I
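
To make the mapping concrete, the same search can be replayed on the host. The sketch below is standalone and not part of the PR; it assumes query_start_len_ptr holds cumulative query-token offsets [0, len_0, len_0 + len_1, ...], and prints which sequence owns each global q block:

#include <cstdint>
#include <iostream>
#include <vector>

// host-side copy of the device-side binary search above (same logic, plain C++)
static int64_t find_seq_idx_host(const std::vector<int32_t>& query_start_len,
                                 int64_t target_idx,
                                 int64_t num_seqs,
                                 int64_t block_q,
                                 bool use_q_block_mode)
{
    int64_t left = 0, right = num_seqs;
    while(left < right)
    {
        const int64_t mid     = (left + right) / 2;
        const int64_t val     = query_start_len[mid];
        const int64_t mid_val = use_q_block_mode ? (val / block_q + mid) : val;
        if(mid_val <= target_idx)
            left = mid + 1;
        else
            right = mid;
    }
    return left - 1;
}

int main()
{
    // assumed layout: cumulative offsets for three sequences of lengths 40, 1, 100
    const std::vector<int32_t> query_start_len = {0, 40, 41, 141};
    const int64_t num_seqs = 3, BLOCK_Q = 16;
    const int64_t total_num_q_blocks = 141 / BLOCK_Q + num_seqs; // matches the grid formula

    for(int64_t blk = 0; blk < total_num_q_blocks; ++blk)
        std::cout << "q block " << blk << " -> seq "
                  << find_seq_idx_host(query_start_len, blk, num_seqs, BLOCK_Q, true) << '\n';
    return 0;
}

Each sequence gets at least ceil(len_i / BLOCK_Q) q blocks under this mapping, which is why the grid formula adds num_seqs on top of num_tokens / BLOCK_Q: it leaves room for one partially filled block per sequence.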

In order to process more query tokens per load in decode settings (where the query length is small, often only 1), we group query tokens along the query-head dimension: up to num_queries_per_kv query tokens share the same key/value token (GQA setting). The total number of grouped rows per tile load is BLOCK_M = BLOCK_Q * num_queries_per_kv.
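
As a concrete sizing example (using the mixed-batch benchmark shape from below, h_q/h_k = 128/8; BLOCK_Q here is an assumed value, not the kernel's default):

// hypothetical sizes, just to illustrate the BLOCK_M arithmetic
constexpr int num_q_heads        = 128;
constexpr int num_kv_heads       = 8;
constexpr int num_queries_per_kv = num_q_heads / num_kv_heads;   // 16 query heads share each KV head
constexpr int BLOCK_Q            = 16;                           // query tokens per tile (assumed)
constexpr int BLOCK_M            = BLOCK_Q * num_queries_per_kv; // 256 grouped rows per tile load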

In the kernel implementation, this grouping is done by transforming the tensor view for Q in DRAM:

const auto q_dram = [&]() {
    const auto q_dram_base = make_naive_tensor_view<address_space_enum::global>(
        q_ptr,
        make_tuple(cur_batch_query_len, num_queries_per_kv, HEAD_SIZE),
        make_tuple(kargs.query_stride_0, kargs.query_stride_1, 1),
        number<UnifiedAttentionPipeline::kAlignmentQ>{},
        number<1>{});

    const auto q_dram_pad = pad_tensor_view( // align seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED
        q_dram_base,
        // block sizes
        make_tuple(BLOCK_Q, 1, HEAD_SIZE_PADDED),
        sequence<true, false, kPadHeadDimQ>{}
    ); // pads to (query_len_padded, num_queries_per_kv, HEAD_SIZE_PADDED)

    const auto q_dram_merged = transform_tensor_view(
                q_dram_pad,
                make_tuple(
                    make_merge_transform(
                        make_tuple(query_len_padded, num_queries_per_kv)
                    ),
                    make_pass_through_transform(HEAD_SIZE_PADDED)
                ),
                make_tuple(sequence<0, 1>{}, sequence<2>{}),
                make_tuple(sequence<0>{}, sequence<1>{})
    ); // flattens the first two dims, head idx is the fastest changing dim in the merged dim
    return q_dram_merged;
}();

This way, the pipeline can remain untouched and use BLOCK_M as its tile size.
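
To illustrate what the merge means for indexing, here is a standalone sketch (not from the PR) of how a row of the merged BLOCK_M dimension maps back to a (query token, query head within its KV group) pair, with the head index as the fastest-changing dimension, as the comment above notes:

#include <cstdio>

int main()
{
    // small made-up sizes, just to show the merged indexing
    const int num_queries_per_kv = 4; // query heads sharing one KV head
    const int BLOCK_Q            = 4; // query tokens per tile
    const int BLOCK_M            = BLOCK_Q * num_queries_per_kv;

    for(int row = 0; row < BLOCK_M; ++row)
    {
        const int token = row / num_queries_per_kv; // which query token in the tile
        const int head  = row % num_queries_per_kv; // which query head within its KV group
        std::printf("merged row %2d -> token %d, head %d\n", row, token, head);
    }
    return 0;
}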

Build

cd ~/composable_kernel
# in the root
mkdir build && cd build
# you can replace "gfx950" with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
../script/cmake-ck-dev.sh .. "gfx950" -G Ninja
ninja tile_example_unified_attention -j64
./bin/tile_example_unified_attention

Benchmark comparison against FAv3 and Triton Unified Attention

Prefill batch: b=1, h=64, d=128, seqlen_q/k=16384, causal=1

Ours: 548 TFLOPS

# need to manually change num_queries_per_kv in the example files to 1
jukorhon@asrock-1w300-e0-3:~/composable_kernel/build$ ./bin/tile_example_unified_attention -h_k=64 -query_lens=16384 -kv_lens=16384 -varlen=0
[bf16|] b:1, h:64/64, d:128, scale_s:0.0883883, query_lens:[16384], kv_lens:[16384], mask:causal mask, 8.01874065 ms, 548.50 TFlops, 0.07 TB/s

fmha_pagedkv: 531 TFLOPS

jukorhon@asrock-1w300-e0-3:~/composable_kernel/build$ ./bin/tile_example_fmha_fwd -h=64 -b=1 -mask='b' -s=16384 -v=0 -prec=bf16 -page_block_size=128
[bf16|batch|bhsd] b:1, h:64/64, s:16384/16384, d:128/128, scale_s:0.0883883, bias:n, p_drop:0, lse:0, qscale:n, mask:b(-1:0), v:r, page_block_size:128, 8.273 ms, 531.63 TFlops, 129.78 GB/s

fmha_v3: 654 TFLOPS

jukorhon@asrock-1w300-e0-3:~/composable_kernel/build$ ./bin/tile_example_fmha_fwd_v3 -h=64 -q_eff_lens=16384 -kv_eff_lens=16384 -b=1 -causal=1 -s=16384 -v=0
[fp16|bshd] b:1, h:64/64, s:16384/16384, d:128, scale_s:0.0883883, mask:b(-1:0), 6.719 ms, 654.66 TFlops

FAv3 Triton blogpost: 654 TFLOPS

Unified Attention Triton (at branch bench-unified-attention): 472 TFLOPS

(py_3.10) tianxiwu@asrock-1w300-e0-3:~/rocm/aiter$ python op_tests/op_benchmarks/triton/bench_unified_attention.py  -hq 64 -hk 64 -b 1 -sq 16384 -sk 16384 -d 128 -unified_attention -causal true --layout "thd"
[aiter] import [module_aiter_enum] under /home/tianxiwu/rocm/aiter/aiter/jit/module_aiter_enum.so
bench_unified_attention:
   BATCH    HQ    HK  N_CTX_Q  N_CTX_K  D_HEAD  D_HEAD_V  fwd(TFLOPS)
0    1.0  64.0  64.0  16384.0  16384.0   128.0     128.0   472.475302

Mixed batch: b=8, h_q / h_k=128/8, d=128, seqlen_q=1024,1,32,256,1,1,4096,4096, seqlen_k=16384, causal=1

Ours: 674 TFLOPS

# need to manually change num_queries_per_kv in the example files to 16
jukorhon@asrock-1w300-e0-3:~/composable_kernel/build$ ./bin/tile_example_unified_attention -h_k=8 -query_lens=1024,1,32,256,1,1,4096,4096 -kv_lens=16384,16384,16384,16384,16384,16384,16384,16384 -varlen=0
[0/2] Re-checking globbed directories...
[7/8] Linking CXX executable bin/tile_example_unified_attention
[bf16|] b:8, h:128/8, d:128, scale_s:0.0883883, query_lens:[1024,1,32,256,1,1,4096,4096], kv_lens:[16384,16384,16384,16384,16384,16384,16384,16384], mask:causal mask, 13.45901966 ms, 674.07 TFlops, 0.05 TB/s

Unified Attention Triton (at branch bench-unified-attention): 562 TFLOPS

(py_3.10) jukorhon@asrock-1w300-e0-3:~/aiter$ python op_tests/op_benchmarks/triton/bench_unified_attention.py -causal true -d 128 -hq 32 -hk 8 -sq 128,1,1,1024,1,1,4096,4096 -sk 16384 -unified_attention --layout "thd"
[aiter] import [module_aiter_enum] under /home/jukorhon/aiter/aiter/jit/module_aiter_enum.so
bench_unified_attention:
   BATCH  HQ  HK                     N_CTX_Q  N_CTX_K  D_HEAD  D_HEAD_V  fwd(TFLOPS) (TFLOPS)
0      8  32   8  128,1,1,1024,1,1,4096,4096    16384     128       128            562.924279

Contributor

@spolifroni-amd spolifroni-amd left a comment

The readme needs some work.

# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
../script/cmake-ck-dev.sh ../ <arch>
make tile_example_unified_attention -j
```

Add an empty line below this

## kernel
The kernel template is `unified_attention.hpp`; this is the grid-wise op in old ck_tile's terminology. We put it here purposely, to demonstrate that one can construct a kernel using various internal components from ck_tile. We may still provide an implementation of the kernel template under ck_tile's include path in the future.

There are 2 template parameters for this kernel template.

Add an empty line below

```
This will result in an executable `build/bin/tile_example_unified_attention`

## kernel

Put an empty line beneath all headers

Underneath, we unify the mask expression into `generic attention mask coordinate`, providing a uniform approach for each batch to locate the corresponding pixels that need to be masked out.
![](misc/gamc.png)

Since FA/xformer style with window_size_left/right is more popular, we accept window_size as parameter and convert that internally to our generic coordinate(this coordinate can express more cases). Below shows some example of how to achieve different kind of mask through cmdline.

Suggested change
Since FA/xformer style with window_size_left/right is more popular, we accept window_size as parameter and convert that internally to our generic coordinate(this coordinate can express more cases). Below shows some example of how to achieve different kind of mask through cmdline.
Since FA/xformer style with window_size_left/right is more popular, we accept window_size as parameter and convert that internally to our generic coordinate. The following are some examples of how to achieve different kind of mask through cmdline:


Note that FA uses bottom-right by default to express the SWA case; here we require you to explicitly specify top-left/bottom-right.

### dropout

newline

### dropout
TBD

### sequence padding and variable length support

newline

### sequence padding and variable length support
We support sequence padding and variable-length processing in both batch-mode and group-mode fmha forward to handle real-world scenarios where sequences have different lengths.

**Group Mode Padding**: Use `-s_qpad` and `-s_kpad` to specify physical stride between batches, enabling padded layouts. Each batch can have different logical sequence lengths (`-s`, `-s_k`) but use larger physical strides for memory alignment.

Suggested change
**Group Mode Padding**: Use `-s_qpad` and `-s_kpad` to specify physical stride between batches, enabling padded layouts. Each batch can have different logical sequence lengths (`-s`, `-s_k`) but use larger physical strides for memory alignment.
**Group Mode Padding**: Use `-s_qpad` and `-s_kpad` to specify physical strides between batches and enable padded layouts. Each batch can have different logical sequence lengths (`-s`, `-s_k`). Use larger physical strides for memory alignment.


Both approaches optimize memory access patterns while supporting flexible sequence length requirements commonly found in transformer inference scenarios.

## FP8 experimental support

newline


Chi-Chu319 and others added 3 commits January 2, 2026 14:12
Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>
Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>
Comment on lines +12 to +13
```
This will result in an executable `build/bin/tile_example_unified_attention`

Suggested change
```
This will result in an executable `build/bin/tile_example_unified_attention`

This will result in an executable build/bin/tile_example_unified_attention

@asleepzzz
Contributor

please remove fmha_bwd_known_fails_gfxxxx.txt since you don't need them

@poyenc
Contributor

poyenc commented Feb 3, 2026

please resolve the merge conflicts


Copilot AI left a comment

Pull request overview

This pull request implements a unified attention kernel in CK Tile that optimizes fused multi-head attention for AMD GPUs. The implementation reduces launch grid size by using a variable-length approach with binary search for sequence indexing, and groups query tokens in the head dimension to improve decode performance.

Changes:

  • Implements unified attention pipeline with optimized grid sizing and binary search for sequence indexing
  • Adds support for variable-length sequences and paged KV cache
  • Provides example implementations with fp16/bf16 support and causal masking

Reviewed changes

Copilot reviewed 26 out of 33 changed files in this pull request and generated 6 comments.

Summary per file:

| File | Description |
| --- | --- |
| script/cmake-ck-dev.sh | Modified GPU target (should not be committed) |
| include/ck_tile/ops/unified_attention/pipeline/*.hpp | Pipeline infrastructure for unified attention |
| include/ck_tile/ops/unified_attention/kernel/*.hpp | Kernel implementation with binary search and masking |
| include/ck_tile/ops/unified_attention/block/*.hpp | Generic attention mask implementation |
| example/ck_tile/42_unified_attention/* | Example code, instances, and build files |


namespace ck_tile {

template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadHeadDim_ /* paddding for hdim_v */,

Copilot AI Feb 3, 2026

Typo in comment: "paddding" should be "padding" (one 'd' removed).

Suggested change
bool kPadHeadDim_ /* paddding for hdim_v */,
bool kPadHeadDim_ /* padding for hdim_v */,

template <unified_attention_args::data_type_enum DataType, bool IsMasking>
struct unified_attention_kernel_traits
{
static constexpr auto date_type = DataType;

Copilot AI Feb 3, 2026

Typo in variable name: "date_type" should be "data_type". This naming inconsistency is used throughout the struct definition and should be corrected for clarity.

fi

GPU_TARGETS="gfx908;gfx90a;gfx942"
GPU_TARGETS="gfx950"

Copilot AI Feb 3, 2026

The GPU_TARGETS has been hardcoded to "gfx950" in the development script, but this change should not be committed to the repository. Development scripts should maintain the default configuration with multiple GPU targets to ensure compatibility across different architectures.

typename Traits_>
struct UnifiedAttentionPipelineProblem
{
// TODO kM0 and KN1??

Copilot AI Feb 3, 2026

The comment "TODO kM0 and KN1??" suggests incomplete work or unclear naming. If this is a genuine TODO item that needs to be addressed, it should be properly documented with what needs to be done. Otherwise, remove the comment.

{
using namespace ck_tile;

/// FIXME: rename the kNPerBlock & kKPerBlock since the kN1 is congtigous dimension

Copilot AI Feb 3, 2026

Typo in comment: "congtigous" should be "contiguous".

Comment on lines +24 to +34
function print_log_header(){
rm -f $1;
echo 'On branch ' $3 &> $1;
echo 'Node name: ' $4 >> $1;
#get GPU_arch and number of compute units from rocminfo
echo -n "GPU_arch: " >> $1; rocminfo | grep "Name:" | grep "gfx" >> $1;
rocminfo | grep "Compute Unit:" >> $1;
hipcc --version | grep -e 'HIP version' >> $1;
echo 'Environment type: ' $2 >> $1;
/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> $1;
}

Copilot AI Feb 3, 2026

The function print_log_header interpolates untrusted shell arguments directly into commands without quoting (e.g., rm -f $1; and echo 'On branch ' $3 &> $1;), which allows command injection if any of the positional parameters contain shell metacharacters (such as ;, &, |, or spaces). Because $1, $2, $3, and $4 are derived from script arguments (env_type, branch, host_name, GPU_arch), an attacker who can influence these values (for example via CI variables or crafted branch names) could execute arbitrary commands with the script's privileges. To fix this, treat all positional parameters as data by consistently quoting them (e.g., rm -f "$1", echo 'On branch ' "$3" > "$1") and avoid letting user-controlled values be expanded in a way that the shell can interpret as additional syntax.

index_t kBlockQ = Kernel::kBlockQ;
assert(args.num_queries_per_kv == Kernel::num_queries_per_kv &&
"argument num_queries_per_kv must equal compiled num_queries_per_kv");
assert(args.BLOCK_SIZE == Kernel::BLOCK_SIZE &&
Contributor

@poyenc poyenc Feb 3, 2026

The unified_attention_args does not have the BLOCK_SIZE attribute. Please review if this assert() is necessary.

"argument num_queries_per_kv must equal compiled num_queries_per_kv");
assert(args.BLOCK_SIZE == Kernel::BLOCK_SIZE &&
"argument BLOCK_SIZE must equal compiled BLOCK_SIZE");
assert(kBlockQ == kBlockM / args.num_queries_per_kv &&

Add missing qualifier Kernel:: for accessing those attributes: Kernel::kBlockM

args.num_head_q = problem.nhead_q;
args.num_queries_per_kv = num_queries_per_kv;
args.page_blk_size = problem.page_blk_size;
args.mask_type = 2;
Contributor

@poyenc poyenc Feb 3, 2026

The mask_type is hardcoded and ignores the command-line option, so we cannot tell whether the unmasked kernels work correctly. Please update this.

-repeat number of iterations to benchmark the kernel (default:20)
-json 0: No Json, 1: Dump Results in Json format (default:0)
-jsonfile json file name to dump results (default:fmha_fwd.json)
-q_eff_lens Batch-mode only: per-batch effective seqlen for Q (exclude PAD) (default:"")
Contributor

@poyenc poyenc Feb 3, 2026

Please remove q_eff_lens/kv_eff_lens from the README.md if they are not intended to be supported.

@ammallya
Contributor

ammallya commented Feb 3, 2026

Error importing due to merge conflicts – please reopen the PR on ROCm/rocm-libraries
