18 changes: 13 additions & 5 deletions README.md
@@ -61,7 +61,7 @@ kernel for non-attention ops). **AITER** = ROCm AITER backend.
| **Cascade attention** | ✅ | — | HIP | Two-level shared-prefix attention; a fused single-kernel HIP variant is gated behind `FLASHINFER_HIP_FUSED_CASCADE=1` |
| **MLA (Multi-Latent Attention)** | — | ✅ | **AITER** (no HIP fallback) | DeepSeek-style 192/128 head-dim split; bf16 + `page_size=1`; `backend="auto"` (default) resolves to `"aiter"` |
| **POD attention** | ✅ `fa2` | — | HIP | MHA / GQA / MQA; single + batch variants (`PODWithPagedKVCacheWrapper`, `BatchPODWithPagedKVCacheWrapper`); JIT-only (excluded from AOT, same as upstream CUDA) |
-| **RoPE (positional encoding)** | ✅ | — | HIP | LLaMA-style + LLaMA 3.1 scaling; fused RoPE + fp8 quant + paged-KV append (E4M3FNUZ, E5M2FNUZ) |
+| **RoPE (positional encoding)** | ✅ `native` | ✅ | **AITER** for the cos/sin-cache path when `fp16` + `>= 2048` tokens + gfx942/gfx950; else **HIP `native`** | LLaMA-style + LLaMA 3.1 scaling; fused RoPE + fp8 quant + paged-KV append (E4M3FNUZ, E5M2FNUZ). AITER backend covers `apply_rope_with_cos_sin_cache` and its inplace variant (CK `rope_cached_positions_2c`); ~1.3–2.8x over native at large nnz (zero-copy via the `_impl` entry point); bf16 stays native (slightly lower precision) |
| **Paged KV-cache append** | ✅ `native` | ✅ | **AITER** when `fp16/bf16` + `NHD` + gfx942/gfx950 + AITER importable; else **HIP `native`** | `append_paged_kv_cache`; fp8 KV-cache supported on the HIP path |
| **RMSNorm** | ✅ `native` | ✅ | **HIP `native`** (auto stays on HIP — AITER is opt-in via `backend="aiter"`) | AITER path is fp16/bf16, 2-D only; slightly lower precision at `hidden_size >= 1024` |
| **Fused add RMSNorm** | ✅ `native` | ✅ | **AITER** when 2-D + `>= 4M` elements + gfx942/gfx950 + AITER importable; else **HIP `native`** | `fused_add_rmsnorm`; AITER (CK `rmsnorm2d_fwd_with_add`) wins on large bandwidth-bound shapes; 2-D only, slightly lower precision at `hidden_size >= 1024` |
@@ -309,8 +309,8 @@ pytest -n auto --reruns 2 -m "slow"

FlashInfer+ROCm can dispatch the `single_prefill`, `batch_prefill`
(paged and ragged), `batch_decode`, `append_paged_kv_cache`, `rmsnorm`,
-`fused_add_rmsnorm`, `silu_and_mul`, and `MLA` paths to
-[AITER](https://github.com/ROCm/aiter). MLA on ROCm
+`fused_add_rmsnorm`, `silu_and_mul`, `rope` (cos/sin-cache path), and
+`MLA` paths to [AITER](https://github.com/ROCm/aiter). MLA on ROCm
is **AITER-only** — there is no in-tree HIP MLA kernel yet, so
`backend="auto"` (the default for the MLA wrapper) resolves directly
to `"aiter"`.
@@ -323,8 +323,8 @@ explicitly, or pass the in-tree backend string to skip it:
`backend="fa2"` for the attention wrappers (single/batch
prefill/decode), `backend="native"` for non-attention ops
(`append_paged_kv_cache`, `rmsnorm`, `fused_add_rmsnorm`,
-`silu_and_mul`). Four backend-specific exceptions to "auto picks AITER
-when supported":
+`silu_and_mul`, `rope`). Five backend-specific exceptions to "auto picks
+AITER when supported":

* `rmsnorm`: `backend="auto"` stays on the HIP `native` kernel; the
AITER path is opt-in via `backend="aiter"`.
@@ -335,6 +335,14 @@ when supported":
`>= 33M`-element inputs (where it is faster and matches native
precision) and otherwise stays on HIP `native`; the AITER path is also
available explicitly via `backend="aiter"`.
* `rope` (`apply_rope_with_cos_sin_cache` / `_inplace`): `backend="auto"`
picks AITER for `fp16` inputs with `>= 2048` tokens (where the AITER
kernel is ~1.3–2.8x faster and fp16 precision stays inside tolerance)
and otherwise stays on HIP `native`. Both the inplace and out-of-place
wrappers benefit — the helper uses AITER's `..._impl` entry point with
distinct in/out tensors, so the out-of-place path needs no Q/K copy.
bf16 always stays native (slightly lower precision). The AITER path is
also available explicitly via `backend="aiter"`.
* `batch_decode`: `use_cuda_graph=True` or `use_tensor_cores=True`
force `auto` back to `fa2` (AITER decode does not support either),
and `pos_encoding_mode != "NONE"` raises under `backend="aiter"`.
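
The `rope` rule above can be read as a short decision chain. The following is a minimal sketch of that chain in plain Python, not the code in this change: `RopeCall`, `pick_rope_backend`, and `aiter_usable` are hypothetical names, dtypes are modelled as strings, and the key-dtype check reflects the guard added in `flashinfer/rope.py` rather than anything stated in this README.

```python
from dataclasses import dataclass

# Threshold quoted in the README text; the constant name is hypothetical.
MIN_TOKENS = 2048


@dataclass
class RopeCall:
    query_dtype: str    # e.g. "fp16" or "bf16"
    key_dtype: str
    num_tokens: int     # nnz, the leading dimension of query
    aiter_usable: bool  # supported arch and the aiter package imports cleanly


def pick_rope_backend(call: RopeCall) -> str:
    """Return "aiter" or "native"; every unmet condition falls back to native."""
    if call.query_dtype != "fp16":
        return "native"  # bf16 stays native for precision
    if call.key_dtype != call.query_dtype:
        return "native"  # mixed Q/K dtypes are left to the native kernel
    if call.num_tokens < MIN_TOKENS:
        return "native"  # small batches: native launch overhead is lower
    if not call.aiter_usable:
        return "native"  # auto never raises, it only falls back
    return "aiter"
```

Under these assumptions a 4096-token fp16 call resolves to `"aiter"`, while the same call in bf16, or with 512 tokens, resolves to `"native"`.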
245 changes: 245 additions & 0 deletions flashinfer/rope.py
@@ -15,10 +15,12 @@
"""

import functools
import weakref
from typing import Optional, Tuple

import torch

from .device_utils import IS_HIP
from .jit.rope import gen_rope_module
from .utils import register_custom_op, register_fake_op

@@ -28,6 +30,180 @@ def get_rope_module():
    return gen_rope_module().build_and_load()


if IS_HIP:

    @functools.cache
    def _aiter_rope_ops():
        import aiter as _aiter

        return _aiter

    # Token count above which AITER's cos/sin-cache rope beats the native JIT
    # kernel. Measured on gfx942 (fp16, q32/k8, hd128): AITER crosses over around
    # nnz~1024-1536 and reaches ~2.6-2.8x at large nnz, while native's lower
    # launch overhead wins below it. 2048 leaves headroom over launch jitter.
    _AITER_ROPE_MIN_TOKENS = 2048

    def _auto_select_rope_backend(query: torch.Tensor, key: torch.Tensor) -> str:
        # AITER's cos/sin-cache rope consumes the cos/sin tables in the query
        # dtype, whereas the native JIT kernel rotates in float32. For bf16 this
        # pushes max abs error to ~5e-2 (vs native ~3e-2), at the edge of the
        # rope test tolerance, so auto never picks AITER for bf16 — only fp16,
        # whose AITER error (~7e-3) stays comfortably inside tolerance.
        #
        # AITER wins on both the inplace and out-of-place paths once nnz is large
        # enough: the helper uses the ..._impl entry point with distinct in/out
        # tensors, so out-of-place needs no Q/K copy. Below the token threshold,
        # native's lower launch overhead wins.
        if query.dtype != torch.float16:
            return "native"
        if key.dtype != query.dtype:
            # AITER rotates Q and K with one cos/sin table in the query dtype and
            # rejects mismatched dtypes; native handles them. auto must not raise,
            # so fall back rather than route into the helper's dtype guard.
            return "native"
        if query.shape[0] < _AITER_ROPE_MIN_TOKENS:
            return "native"
        from .aiter_utils import is_aiter_supported

        if not is_aiter_supported(query.device):
            return "native"
        try:
            # Best-effort probe: a supported arch can still lack a usable aiter
            # (not installed, or its compiled extension fails to load). auto must
            # always fall back to native, never raise — matching norm/activation.
            _aiter_rope_ops()
        except Exception:
            return "native"
        return "aiter"

    # Memoize the AITER-format cos/sin tables. ``cos_sin_cache`` is a fixed
    # precomputed float32 table (typically a persistent module buffer) reused
    # across every forward pass, so converting the whole table to the query
    # dtype on each call is pure overhead — significant during decode, where
    # nnz is tiny but max_seq_len is large. Keyed by id() (tensors aren't
    # value-hashable) with a finalizer that evicts the entry when the cache is
    # GC'd, so the cached tables never outlive their source. A weakref to the
    # source is stored alongside the entry so a hit is only honored when it
    # still resolves to the *same* tensor — guarding against id() reuse if a
    # transient cache is freed and a new tensor lands on the same address.
    _aiter_cos_sin_tables: dict = {}

    def _aiter_rope_cos_sin(
        cos_sin_cache: torch.Tensor, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        key = id(cos_sin_cache)
        cached = _aiter_cos_sin_tables.get(key)
        if cached is not None and cached[0] == dtype and cached[1]() is cos_sin_cache:
            return cached[2], cached[3]
        half = cos_sin_cache.shape[-1] // 2
        cos = cos_sin_cache[:, :half].unsqueeze(1).unsqueeze(1).to(dtype).contiguous()
        sin = cos_sin_cache[:, half:].unsqueeze(1).unsqueeze(1).to(dtype).contiguous()
        if cached is None:
            weakref.finalize(cos_sin_cache, _aiter_cos_sin_tables.pop, key, None)
        _aiter_cos_sin_tables[key] = (dtype, weakref.ref(cos_sin_cache), cos, sin)
        return cos, sin

    def _apply_rope_cos_sin_cache_aiter(
        query: torch.Tensor,
        key: torch.Tensor,
        query_out: torch.Tensor,
        key_out: torch.Tensor,
        cos_sin_cache: torch.Tensor,
        positions: torch.Tensor,
        head_size: int,
        is_neox: bool,
    ) -> None:
        r"""Dispatch the cos/sin-cache rope to AITER's rope_cached_positions_2c kernel.

        FlashInfer stores ``cos_sin_cache`` as ``(max_seq_len, rotary_dim)`` float32
        with cosine in the first half and sine in the second half. AITER wants two
        separate ``(max_seq_len, 1, 1, rotary_dim // 2)`` tables in the query dtype
        with ``reuse_freqs_front_part=True``. Q/K are reshaped to AITER's
        ``(1, nnz, num_heads, head_dim)`` layout and only the leading ``rotary_dim``
        slice is rotated (matching ``nope_first=False``).

        Uses the ``..._impl`` entry point with distinct input/output tensors so the
        out-of-place case writes straight into ``query_out``/``key_out`` with no Q/K
        copy (measured ~1.3-2.8x faster than copy-then-rotate). The kernel only
        writes the rotated ``[:rotary_dim]`` slice, so for partial rotary
        (``rotary_dim < head_size``) the untouched nope tail is copied across when
        the output is a fresh tensor. When ``query_out``/``key_out`` alias the
        inputs (the inplace variant) this is ``_impl(x, y, x, y, ...)``.
        """
        from .aiter_utils import is_aiter_supported

        if not is_aiter_supported(query.device):
            raise ValueError(
                "AITER rope backend requires an AMD gfx942/gfx950 device; "
                "use backend='native' instead."
            )
        try:
            aiter_ops = _aiter_rope_ops()
        except Exception as e:
            raise ValueError(
                "backend='aiter' requires the aiter package, which failed to "
                f"import: {e}"
            ) from e
        if key.dtype != query.dtype:
            # AITER rotates Q and K with a single cos/sin table built in the
            # query dtype; the native path tolerates mixed dtypes by rotating
            # in float32, but AITER cannot.
            raise ValueError(
                "AITER rope backend requires query and key to share a dtype; "
                f"got query={query.dtype}, key={key.dtype}. Use backend='native'."
            )

        rotary_dim = cos_sin_cache.shape[-1]
        # cos_sin_cache stacks cos||sin on its last dim, so rotary_dim must be
        # even, and the rotated slice q[..., :rotary_dim] must fit in head_size.
        if rotary_dim % 2 != 0:
            raise ValueError(
                f"cos_sin_cache last dim must be even (cos||sin); got {rotary_dim}."
            )
        if rotary_dim > head_size:
            raise ValueError(
                f"rotary_dim ({rotary_dim}) from cos_sin_cache exceeds head_size "
                f"({head_size})."
            )

        nnz = query.shape[0]
        cos, sin = _aiter_rope_cos_sin(cos_sin_cache, query.dtype)

        q_in = query.view(1, nnz, -1, head_size)
        k_in = key.view(1, nnz, -1, head_size)
        q_out = query_out.view(1, nnz, -1, head_size)
        k_out = key_out.view(1, nnz, -1, head_size)

        # The kernel only writes the rotated [:rotary_dim] slice. For partial
        # rotary into a fresh output, copy the untouched nope tail [rotary_dim:]
        # across first (skipped when output aliases input, and when rotary_dim
        # covers the full head_size).
        if rotary_dim < head_size:
            if query_out.data_ptr() != query.data_ptr():
                q_out[..., rotary_dim:].copy_(q_in[..., rotary_dim:])
            if key_out.data_ptr() != key.data_ptr():
                k_out[..., rotary_dim:].copy_(k_in[..., rotary_dim:])

        # AITER's HIP kernel asserts int64, contiguous positions of shape
        # (1, nnz) (stride(1) == 1) — a strided/non-int64 positions tensor
        # otherwise trips a C assert that aborts the process.
        pos = positions.to(torch.int64).contiguous().view(1, nnz)

        aiter_ops.rope_cached_positions_2c_fwd_impl(
            q_out[..., :rotary_dim],
            k_out[..., :rotary_dim],
            q_in[..., :rotary_dim],
            k_in[..., :rotary_dim],
            cos,
            sin,
            pos,
            0 if is_neox else 1,  # rotate_style: 0=NEOX, 1=GPT-J
            True,  # reuse_freqs_front_part (cos/sin are rotary_dim//2 sized)
            False,  # nope_first
        )
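
The helper's docstring describes a `cos||sin` row layout and a rotation that touches only the leading `rotary_dim` elements of each head. As a reference for reading it, here is a minimal pure-Python sketch of that arithmetic for a single head vector. It assumes the standard RoPE formulation (NEOX pairs element `i` with `i + rotary_dim // 2`, GPT-J pairs `2i` with `2i + 1`); `split_cos_sin` and `rotate_head` are hypothetical names, and nothing here calls AITER or FlashInfer.

```python
def split_cos_sin(row):
    """Split one cos_sin_cache row (cos in the first half, sin in the second)."""
    half = len(row) // 2
    return row[:half], row[half:]


def rotate_head(x, row, is_neox=True):
    """Rotate the leading rotary_dim elements of x; the tail passes through."""
    cos, sin = split_cos_sin(row)
    half = len(cos)  # rotary_dim // 2
    out = list(x)    # elements at index >= rotary_dim stay as copied here
    for i in range(half):
        # NEOX: pair (i, i + half). GPT-J: interleaved pair (2i, 2i + 1).
        a, b = (i, i + half) if is_neox else (2 * i, 2 * i + 1)
        out[a] = x[a] * cos[i] - x[b] * sin[i]
        out[b] = x[b] * cos[i] + x[a] * sin[i]
    return out
```

With `row = [0, 0, 1, 1]` (a quarter turn for both frequencies) and `x = [1, 2, 3, 4, 9]`, the NEOX style gives `[-3, -4, 1, 2, 9]` and the interleaved style gives `[-2, 1, -4, 3, 9]`; the fifth element lies past `rotary_dim = 4` and is unchanged in both.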


@register_custom_op("flashinfer::apply_rope", mutates_args=("q_rope", "k_rope"))
def _apply_rope(
    q: torch.Tensor,
@@ -1138,6 +1314,7 @@ def apply_rope_with_cos_sin_cache(
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool = True,
    backend: str = "auto",
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Apply rotary embedding to keys and queries with precomputed cos/sin values.
@@ -1164,6 +1341,18 @@ def apply_rope_with_cos_sin_cache(
        * If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.

    backend : str
        Kernel backend to use. ``"auto"`` (default) selects the best backend for
        the call: on ROCm (gfx942/gfx950) it picks AITER for fp16 inputs with at
        least 2048 tokens (where the AITER kernel is measurably faster), and stays
        on ``"native"`` otherwise: for bf16 (precision), small token counts
        (launch overhead), and non-ROCm platforms.
        ``"native"`` uses the FlashInfer JIT kernel on all platforms.
        ``"aiter"`` uses AMD AITER's rope_cached kernel; it is ROCm (gfx942/gfx950)
        only and requires the ``aiter`` package. Precision is slightly lower than
        ``"native"`` for bfloat16 (max abs error ~5e-2 vs ~3e-2) because AITER
        consumes the cos/sin tables in the query dtype rather than float32.

    Returns
    -------
    query_out : torch.Tensor
@@ -1181,6 +1370,27 @@ def apply_rope_with_cos_sin_cache(
    query_out = torch.empty_like(query)
    key_out = torch.empty_like(key)

    if IS_HIP:
        _backend = (
            backend if backend != "auto" else _auto_select_rope_backend(query, key)
        )
        if _backend == "aiter":
            _apply_rope_cos_sin_cache_aiter(
                query=query,
                key=key,
                query_out=query_out,
                key_out=key_out,
                cos_sin_cache=cos_sin_cache,
                positions=positions,
                head_size=head_size,
                is_neox=is_neox,
            )
            return query_out, key_out
        if _backend != "native":
            raise ValueError(
                f"Unknown backend {backend!r}; expected one of 'auto', 'native', 'aiter'."
            )

    _apply_rope_pos_ids_cos_sin_cache(
        q=query.view(query.shape[0], -1, head_size),
        k=key.view(key.shape[0], -1, head_size),
@@ -1201,6 +1411,7 @@ def apply_rope_with_cos_sin_cache_inplace(
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool = True,
    backend: str = "auto",
) -> None:
    r"""
    Apply rotary embedding to keys and queries with precomputed cos/sin values.
@@ -1227,13 +1438,47 @@ def apply_rope_with_cos_sin_cache_inplace(

        * If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.

    backend : str
        Kernel backend to use. ``"auto"`` (default) selects the best backend for
        the call: on ROCm (gfx942/gfx950) it picks AITER for fp16 inputs with at
        least 2048 tokens (where the AITER kernel is measurably faster), and stays
        on ``"native"`` otherwise: for bf16 (precision), small token counts
        (launch overhead), and non-ROCm platforms.
        ``"native"`` uses the FlashInfer JIT kernel on all platforms.
        ``"aiter"`` uses AMD AITER's rope_cached kernel; it is ROCm (gfx942/gfx950)
        only and requires the ``aiter`` package. Precision is slightly lower than
        ``"native"`` for bfloat16 (max abs error ~5e-2 vs ~3e-2) because AITER
        consumes the cos/sin tables in the query dtype rather than float32.

    Note
    ----
    The rotary dimension is determined by the cosine cache and sine cache.
    """
    if cos_sin_cache.dtype != torch.float32:
        raise ValueError("cos_sin_cache should be float32")

    if IS_HIP:
        _backend = (
            backend if backend != "auto" else _auto_select_rope_backend(query, key)
        )
        if _backend == "aiter":
            _apply_rope_cos_sin_cache_aiter(
                query=query,
                key=key,
                query_out=query,
                key_out=key,
                cos_sin_cache=cos_sin_cache,
                positions=positions,
                head_size=head_size,
                is_neox=is_neox,
            )
            return
        if _backend != "native":
            raise ValueError(
                f"Unknown backend {backend!r}; expected one of 'auto', 'native', 'aiter'."
            )

    # pass q_rope and k_rope as q and k to perform inplace operation
    _apply_rope_pos_ids_cos_sin_cache(
        q=query.view(query.shape[0], -1, head_size),