From cfecfc3f51ca44e2173167da820d70432752b608 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Mon, 15 Jun 2026 19:26:00 +0000 Subject: [PATCH 1/6] feat(hip,aiter): add AITER backend to rope cos/sin-cache ops Route apply_rope_with_cos_sin_cache and its inplace variant through AITER's rope_cached_positions_2c kernel on ROCm via a new backend="aiter" opt-in, mirroring the rmsnorm/silu_and_mul AITER backend pattern. The helper adapts FlashInfer's formats to AITER's: splits the (max_seq_len, rotary_dim) float32 cos||sin cache into two (max_seq_len, 1, 1, rotary_dim//2) tables in the query dtype (memoized per cache tensor to avoid re-converting the full table every forward pass), reshapes Q/K to AITER's (1, nnz, heads, dim) layout, and rotates only the leading rotary_dim slice. backend="auto" stays on the native kernel: AITER consumes the cos/sin tables in the query dtype rather than float32, raising bf16 max abs error to ~5e-2 (vs native ~3e-2), at the edge of the rope test tolerance. "aiter" is explicit opt-in. Guards added to the AITER path: device must be gfx942/gfx950, query/key must share a dtype (AITER rotates both with one cos/sin table), and positions are coerced to contiguous int64 (a strided positions tensor otherwise trips a C assert that aborts the process). Co-Authored-By: Claude Opus 4.7 --- flashinfer/rope.py | 164 +++++++++++++++++++ tests/rocm_tests/test_rope_aiter_hip.py | 199 ++++++++++++++++++++++++ 2 files changed, 363 insertions(+) create mode 100644 tests/rocm_tests/test_rope_aiter_hip.py diff --git a/flashinfer/rope.py b/flashinfer/rope.py index 19ee2011fe..04df6c56cf 100644 --- a/flashinfer/rope.py +++ b/flashinfer/rope.py @@ -15,10 +15,12 @@ """ import functools +import weakref from typing import Optional, Tuple import torch +from .device_utils import IS_HIP from .jit.rope import gen_rope_module from .utils import register_custom_op, register_fake_op @@ -28,6 +30,111 @@ def get_rope_module(): return gen_rope_module().build_and_load() +if IS_HIP: + + @functools.cache + def _aiter_rope_ops(): + import aiter as _aiter + + return _aiter + + def _auto_select_rope_backend(query: torch.Tensor) -> str: + # AITER's cos/sin-cache rope consumes the cos/sin tables in the query + # dtype (bf16/fp16), whereas the native JIT kernel rotates in float32. + # For bf16 this pushes max abs error to ~5e-2 (vs native ~3e-2), at the + # edge of the flashinfer rope test tolerance. Keep auto on the native + # kernel; pass backend="aiter" to opt in explicitly. + return "native" + + # Memoize the AITER-format cos/sin tables. ``cos_sin_cache`` is a fixed + # precomputed float32 table (typically a persistent module buffer) reused + # across every forward pass, so converting the whole table to the query + # dtype on each call is pure overhead — significant during decode, where + # nnz is tiny but max_seq_len is large. Keyed by id() (tensors aren't + # value-hashable) with a finalizer that evicts the entry when the cache is + # GC'd, so the cached tables never outlive their source. + _aiter_cos_sin_tables: dict = {} + + def _aiter_rope_cos_sin( + cos_sin_cache: torch.Tensor, dtype: torch.dtype + ) -> Tuple[torch.Tensor, torch.Tensor]: + key = id(cos_sin_cache) + cached = _aiter_cos_sin_tables.get(key) + if cached is not None and cached[0] == dtype: + return cached[1], cached[2] + half = cos_sin_cache.shape[-1] // 2 + cos = cos_sin_cache[:, :half].unsqueeze(1).unsqueeze(1).to(dtype) + sin = cos_sin_cache[:, half:].unsqueeze(1).unsqueeze(1).to(dtype) + if cached is None: + weakref.finalize(cos_sin_cache, _aiter_cos_sin_tables.pop, key, None) + _aiter_cos_sin_tables[key] = (dtype, cos, sin) + return cos, sin + + def _apply_rope_cos_sin_cache_aiter( + query: torch.Tensor, + key: torch.Tensor, + query_out: torch.Tensor, + key_out: torch.Tensor, + cos_sin_cache: torch.Tensor, + positions: torch.Tensor, + head_size: int, + is_neox: bool, + ) -> None: + r"""Dispatch the cos/sin-cache rope to AITER's rope_cached_positions_2c kernel. + + FlashInfer stores ``cos_sin_cache`` as ``(max_seq_len, rotary_dim)`` float32 + with cosine in the first half and sine in the second half. AITER wants two + separate ``(max_seq_len, 1, 1, rotary_dim // 2)`` tables in the query dtype + with ``reuse_freqs_front_part=True``. Q/K are reshaped to AITER's + ``(1, nnz, num_heads, head_dim)`` layout and only the leading ``rotary_dim`` + slice is rotated (matching ``nope_first=False``). Writes through views, so + ``query_out``/``key_out`` are updated in place (alias the inputs for the + inplace variant). + """ + from .aiter_utils import is_aiter_supported + + if not is_aiter_supported(query.device): + raise ValueError( + "AITER rope backend requires an AMD gfx942/gfx950 device; " + "use backend='native' instead." + ) + if key.dtype != query.dtype: + # AITER rotates Q and K with a single cos/sin table built in the + # query dtype; the native path tolerates mixed dtypes by rotating + # in float32, but AITER cannot. + raise ValueError( + "AITER rope backend requires query and key to share a dtype; " + f"got query={query.dtype}, key={key.dtype}. Use backend='native'." + ) + + nnz = query.shape[0] + rotary_dim = cos_sin_cache.shape[-1] + cos, sin = _aiter_rope_cos_sin(cos_sin_cache, query.dtype) + + q_view = query_out.view(1, nnz, -1, head_size) + k_view = key_out.view(1, nnz, -1, head_size) + if query_out.data_ptr() != query.data_ptr(): + q_view.copy_(query.view(1, nnz, -1, head_size)) + if key_out.data_ptr() != key.data_ptr(): + k_view.copy_(key.view(1, nnz, -1, head_size)) + + # AITER's HIP kernel asserts int64, contiguous positions of shape + # (1, nnz) (stride(1) == 1) — a strided/non-int64 positions tensor + # otherwise trips a C assert that aborts the process. + pos = positions.to(torch.int64).contiguous().view(1, nnz) + + _aiter_rope_ops().rope_cached_positions_2c_fwd_inplace( + q_view[..., :rotary_dim], + k_view[..., :rotary_dim], + cos, + sin, + pos, + 0 if is_neox else 1, # rotate_style: 0=NEOX, 1=GPT-J + True, # reuse_freqs_front_part (cos/sin are rotary_dim//2 sized) + False, # nope_first + ) + + @register_custom_op("flashinfer::apply_rope", mutates_args=("q_rope", "k_rope")) def _apply_rope( q: torch.Tensor, @@ -1138,6 +1245,7 @@ def apply_rope_with_cos_sin_cache( head_size: int, cos_sin_cache: torch.Tensor, is_neox: bool = True, + backend: str = "auto", ) -> Tuple[torch.Tensor, torch.Tensor]: r""" Apply rotary embedding to keys and queries with precomputed cos/sin values. @@ -1164,6 +1272,14 @@ def apply_rope_with_cos_sin_cache( * If ``False``, the last dimension of the query/key tensor is interleaved, i.e., we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``. + backend : str + Kernel backend to use. ``"auto"`` (default) selects the best available backend. + ``"native"`` uses the FlashInfer JIT kernel on all platforms. + ``"aiter"`` uses AMD AITER's rope_cached kernel — ROCm (gfx942/gfx950) only; + requires the ``aiter`` package. Precision is slightly lower than ``"native"`` + for bfloat16 (max abs error ~5e-2 vs ~3e-2) because AITER consumes the cos/sin + tables in the query dtype rather than float32. + Returns ------- query_out : torch.Tensor @@ -1181,6 +1297,25 @@ def apply_rope_with_cos_sin_cache( query_out = torch.empty_like(query) key_out = torch.empty_like(key) + if IS_HIP: + _backend = backend if backend != "auto" else _auto_select_rope_backend(query) + if _backend == "aiter": + _apply_rope_cos_sin_cache_aiter( + query=query, + key=key, + query_out=query_out, + key_out=key_out, + cos_sin_cache=cos_sin_cache, + positions=positions, + head_size=head_size, + is_neox=is_neox, + ) + return query_out, key_out + if _backend != "native": + raise ValueError( + f"Unknown backend {backend!r}; expected one of 'auto', 'native', 'aiter'." + ) + _apply_rope_pos_ids_cos_sin_cache( q=query.view(query.shape[0], -1, head_size), k=key.view(key.shape[0], -1, head_size), @@ -1201,6 +1336,7 @@ def apply_rope_with_cos_sin_cache_inplace( head_size: int, cos_sin_cache: torch.Tensor, is_neox: bool = True, + backend: str = "auto", ) -> None: r""" Apply rotary embedding to keys and queries with precomputed cos/sin values. @@ -1227,6 +1363,15 @@ def apply_rope_with_cos_sin_cache_inplace( * If ``False``, the last dimension of the query/key tensor is interleaved, i.e., we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``. + + backend : str + Kernel backend to use. ``"auto"`` (default) selects the best available backend. + ``"native"`` uses the FlashInfer JIT kernel on all platforms. + ``"aiter"`` uses AMD AITER's rope_cached kernel — ROCm (gfx942/gfx950) only; + requires the ``aiter`` package. Precision is slightly lower than ``"native"`` + for bfloat16 (max abs error ~5e-2 vs ~3e-2) because AITER consumes the cos/sin + tables in the query dtype rather than float32. + Note ---- The rotary dimension is determined by the cosine cache and sine cache. @@ -1234,6 +1379,25 @@ def apply_rope_with_cos_sin_cache_inplace( if cos_sin_cache.dtype != torch.float32: raise ValueError("cos_sin_cache should be float32") + if IS_HIP: + _backend = backend if backend != "auto" else _auto_select_rope_backend(query) + if _backend == "aiter": + _apply_rope_cos_sin_cache_aiter( + query=query, + key=key, + query_out=query, + key_out=key, + cos_sin_cache=cos_sin_cache, + positions=positions, + head_size=head_size, + is_neox=is_neox, + ) + return + if _backend != "native": + raise ValueError( + f"Unknown backend {backend!r}; expected one of 'auto', 'native', 'aiter'." + ) + # pass q_rope and k_rope as q and k to perform inplace operation _apply_rope_pos_ids_cos_sin_cache( q=query.view(query.shape[0], -1, head_size), diff --git a/tests/rocm_tests/test_rope_aiter_hip.py b/tests/rocm_tests/test_rope_aiter_hip.py new file mode 100644 index 0000000000..57445e8179 --- /dev/null +++ b/tests/rocm_tests/test_rope_aiter_hip.py @@ -0,0 +1,199 @@ +# SPDX-FileCopyrightText: 2026 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Tests for the AITER rope backend exposed via +# flashinfer.apply_rope_with_cos_sin_cache(backend="aiter") and its inplace variant. +# +# Note on tolerances: AITER's rope_cached kernel consumes the cos/sin tables in +# the query dtype, whereas the native JIT kernel rotates in float32. For bfloat16 +# this raises max abs error to ~5e-2 (native ~3e-2); fp16 stays at ~7e-3. The +# tolerances below reflect AITER's actual precision. + +import pytest +import torch + +import flashinfer +from flashinfer.aiter_utils import is_aiter_supported +from tests.test_helpers.rope_reference import RotaryEmbedding + +pytestmark = pytest.mark.skipif( + not is_aiter_supported(torch.device("cuda:0")), + reason="AITER backend requires gfx942/gfx950", +) + + +@pytest.mark.parametrize("is_neox_style", [True, False]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize( + "head_size, rotary_dim, num_q_heads, num_kv_heads", + [ + (64, 64, 8, 8), + (128, 128, 8, 2), + (128, 64, 8, 2), # partial rotary (rotary_dim < head_size) + (256, 128, 4, 2), + ], +) +def test_rope_cos_sin_cache_aiter_vs_ref( + is_neox_style, dtype, head_size, rotary_dim, num_q_heads, num_kv_heads +): + torch.manual_seed(0x4011) + device = torch.device("cuda:0") + batch_size, seq_len = 4, 33 + + rope = RotaryEmbedding( + head_size, rotary_dim, 4096, 10000, is_neox_style, dtype, device + ) + cos_sin_cache = rope.cos_sin_cache # float32 + + pos_ids = torch.arange(seq_len, device=device).repeat(batch_size) + query = torch.randn( + batch_size * seq_len, num_q_heads * head_size, dtype=dtype, device=device + ) + key = torch.randn( + batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device + ) + + query_ref, key_ref = rope.forward_native(pos_ids, query.clone(), key.clone()) + query_aiter, key_aiter = flashinfer.apply_rope_with_cos_sin_cache( + pos_ids, + query.clone(), + key.clone(), + head_size, + cos_sin_cache, + is_neox=is_neox_style, + backend="aiter", + ) + + rtol, atol = (7e-2, 7e-2) if dtype == torch.bfloat16 else (1e-2, 1e-2) + torch.testing.assert_close( + query_aiter.float(), query_ref.float(), rtol=rtol, atol=atol + ) + torch.testing.assert_close(key_aiter.float(), key_ref.float(), rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("is_neox_style", [True, False]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_rope_cos_sin_cache_aiter_inplace(is_neox_style, dtype): + """Inplace AITER backend matches its non-inplace counterpart and mutates inputs.""" + torch.manual_seed(0x4012) + device = torch.device("cuda:0") + head_size, rotary_dim = 128, 64 + batch_size, seq_len, num_q_heads, num_kv_heads = 4, 32, 8, 4 + + rope = RotaryEmbedding( + head_size, rotary_dim, 4096, 10000, is_neox_style, dtype, device + ) + cos_sin_cache = rope.cos_sin_cache + + pos_ids = torch.arange(seq_len, device=device).repeat(batch_size) + query = torch.randn( + batch_size * seq_len, num_q_heads * head_size, dtype=dtype, device=device + ) + key = torch.randn( + batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device + ) + + query_out, key_out = flashinfer.apply_rope_with_cos_sin_cache( + pos_ids, + query.clone(), + key.clone(), + head_size, + cos_sin_cache, + is_neox=is_neox_style, + backend="aiter", + ) + + query_inplace = query.clone() + key_inplace = key.clone() + flashinfer.apply_rope_with_cos_sin_cache_inplace( + pos_ids, + query_inplace, + key_inplace, + head_size, + cos_sin_cache, + is_neox=is_neox_style, + backend="aiter", + ) + + # inplace must mutate the inputs + assert not torch.equal(query_inplace, query) + # and must agree with the non-inplace result + torch.testing.assert_close(query_inplace, query_out, rtol=0, atol=0) + torch.testing.assert_close(key_inplace, key_out, rtol=0, atol=0) + + +def test_rope_auto_backend_stays_native(): + """auto backend on gfx942/950 stays on the native kernel (precision parity).""" + from flashinfer.rope import _auto_select_rope_backend + + device = torch.device("cuda:0") + q = torch.randn(8, 128, dtype=torch.bfloat16, device=device) + assert _auto_select_rope_backend(q) == "native" + + +def test_rope_unknown_backend_raises(): + device = torch.device("cuda:0") + cos_sin_cache = torch.randn(64, 64, dtype=torch.float32, device=device) + pos_ids = torch.arange(8, device=device) + query = torch.randn(8, 8 * 128, dtype=torch.float16, device=device) + key = torch.randn(8, 8 * 128, dtype=torch.float16, device=device) + with pytest.raises(ValueError, match="Unknown backend"): + flashinfer.apply_rope_with_cos_sin_cache( + pos_ids, query, key, 128, cos_sin_cache, backend="nonsense" + ) + + +def test_rope_aiter_mixed_dtype_raises(): + """AITER rotates Q/K with one cos/sin table, so mismatched dtypes must error + clearly rather than crash inside the kernel.""" + device = torch.device("cuda:0") + cos_sin_cache = torch.randn(64, 64, dtype=torch.float32, device=device) + pos_ids = torch.arange(8, device=device) + query = torch.randn(8, 8 * 128, dtype=torch.bfloat16, device=device) + key = torch.randn(8, 8 * 128, dtype=torch.float16, device=device) + with pytest.raises(ValueError, match="share a dtype"): + flashinfer.apply_rope_with_cos_sin_cache( + pos_ids, query, key, 128, cos_sin_cache, backend="aiter" + ) + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_rope_aiter_noncontiguous_positions(dtype): + """A strided positions tensor must be normalized before reaching the AITER + kernel, whose C assert (stride(1) == 1) would otherwise abort the process. + The result must match the contiguous-positions result.""" + device = torch.device("cuda:0") + head_size, rotary_dim = 128, 64 + batch_size, seq_len, num_q_heads, num_kv_heads = 2, 16, 8, 2 + + rope = RotaryEmbedding(head_size, rotary_dim, 4096, 10000, True, dtype, device) + cos_sin_cache = rope.cos_sin_cache + nnz = batch_size * seq_len + + # Build a non-contiguous (stride-2) positions tensor. + pos_strided = torch.arange(2 * seq_len, device=device, dtype=torch.int64).repeat( + batch_size + )[::2] + assert not pos_strided.is_contiguous() + + query = torch.randn(nnz, num_q_heads * head_size, dtype=dtype, device=device) + key = torch.randn(nnz, num_kv_heads * head_size, dtype=dtype, device=device) + + q_strided, k_strided = flashinfer.apply_rope_with_cos_sin_cache( + pos_strided, + query.clone(), + key.clone(), + head_size, + cos_sin_cache, + backend="aiter", + ) + q_contig, k_contig = flashinfer.apply_rope_with_cos_sin_cache( + pos_strided.contiguous(), + query.clone(), + key.clone(), + head_size, + cos_sin_cache, + backend="aiter", + ) + torch.testing.assert_close(q_strided, q_contig, rtol=0, atol=0) + torch.testing.assert_close(k_strided, k_contig, rtol=0, atol=0) From 2897bb74378f4ab860b600dcce276c32eab33526 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Mon, 15 Jun 2026 19:46:47 +0000 Subject: [PATCH 2/6] feat(hip,aiter): shape-aware auto for rope, document perf matrix Make backend="auto" pick AITER for the rope cos/sin-cache path where it is both faster and precise enough: inplace + fp16 + >= 2048 tokens on gfx942/gfx950. Everything else stays native. Rationale from gfx942 benchmarks (bf16/fp16, q32/k8, hd128): - Out-of-place AITER never wins (0.6-0.9x): its kernel is in-place only, so the wrapper must copy Q/K first, erasing the throughput gain. - Inplace AITER crosses over around 1024-1536 tokens and reaches ~1.65x at 32K; native's lower launch overhead wins below ~2048. - Decode (small nnz) is firmly native territory (~0.6x). - bf16 stays native on precision grounds (~5e-2 vs native ~3e-2); fp16 AITER error (~7e-3) is comfortably inside tolerance. _auto_select_rope_backend now takes an `inplace` flag (out-of-place can never benefit) and gates on dtype/nnz/device. README feature matrix and AITER Support section document the policy. Co-Authored-By: Claude Opus 4.7 --- README.md | 18 +++++--- flashinfer/rope.py | 57 ++++++++++++++++++++----- tests/rocm_tests/test_rope_aiter_hip.py | 27 +++++++++--- 3 files changed, 82 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index 5d2ec6b22f..11da96ba0e 100644 --- a/README.md +++ b/README.md @@ -61,7 +61,7 @@ kernel for non-attention ops). **AITER** = ROCm AITER backend. | **Cascade attention** | ✅ | — | HIP | Two-level shared-prefix attention; a fused single-kernel HIP variant is gated behind `FLASHINFER_HIP_FUSED_CASCADE=1` | | **MLA (Multi-Latent Attention)** | — | ✅ | **AITER** (no HIP fallback) | DeepSeek-style 192/128 head-dim split; bf16 + `page_size=1`; `backend="auto"` (default) resolves to `"aiter"` | | **POD attention** | ✅ `fa2` | — | HIP | MHA / GQA / MQA; single + batch variants (`PODWithPagedKVCacheWrapper`, `BatchPODWithPagedKVCacheWrapper`); JIT-only (excluded from AOT, same as upstream CUDA) | -| **RoPE (positional encoding)** | ✅ | — | HIP | LLaMA-style + LLaMA 3.1 scaling; fused RoPE + fp8 quant + paged-KV append (E4M3FNUZ, E5M2FNUZ) | +| **RoPE (positional encoding)** | ✅ `native` | ✅ | **AITER** for the cos/sin-cache path when inplace + `fp16` + `>= 2048` tokens + gfx942/gfx950; else **HIP `native`** | LLaMA-style + LLaMA 3.1 scaling; fused RoPE + fp8 quant + paged-KV append (E4M3FNUZ, E5M2FNUZ). AITER backend covers `apply_rope_with_cos_sin_cache` (CK `rope_cached_positions_2c`); ~1.2–1.65x over native on large-batch inplace prefill; bf16 stays native (slightly lower precision) | | **Paged KV-cache append** | ✅ `native` | ✅ | **AITER** when `fp16/bf16` + `NHD` + gfx942/gfx950 + AITER importable; else **HIP `native`** | `append_paged_kv_cache`; fp8 KV-cache supported on the HIP path | | **RMSNorm** | ✅ `native` | ✅ | **HIP `native`** (auto stays on HIP — AITER is opt-in via `backend="aiter"`) | AITER path is fp16/bf16, 2-D only; slightly lower precision at `hidden_size >= 1024` | | **Fused add RMSNorm** | ✅ `native` | ✅ | **AITER** when 2-D + `>= 4M` elements + gfx942/gfx950 + AITER importable; else **HIP `native`** | `fused_add_rmsnorm`; AITER (CK `rmsnorm2d_fwd_with_add`) wins on large bandwidth-bound shapes; 2-D only, slightly lower precision at `hidden_size >= 1024` | @@ -309,8 +309,8 @@ pytest -n auto --reruns 2 -m "slow" FlashInfer+ROCm can dispatch the `single_prefill`, `batch_prefill` (paged and ragged), `batch_decode`, `append_paged_kv_cache`, `rmsnorm`, -`fused_add_rmsnorm`, `silu_and_mul`, and `MLA` paths to -[AITER](https://github.com/ROCm/aiter). MLA on ROCm +`fused_add_rmsnorm`, `silu_and_mul`, `rope` (cos/sin-cache path), and +`MLA` paths to [AITER](https://github.com/ROCm/aiter). MLA on ROCm is **AITER-only** — there is no in-tree HIP MLA kernel yet, so `backend="auto"` (the default for the MLA wrapper) resolves directly to `"aiter"`. @@ -323,8 +323,8 @@ explicitly, or pass the in-tree backend string to skip it: `backend="fa2"` for the attention wrappers (single/batch prefill/decode), `backend="native"` for non-attention ops (`append_paged_kv_cache`, `rmsnorm`, `fused_add_rmsnorm`, -`silu_and_mul`). Four backend-specific exceptions to "auto picks AITER -when supported": +`silu_and_mul`, `rope`). Five backend-specific exceptions to "auto picks +AITER when supported": * `rmsnorm`: `backend="auto"` stays on the HIP `native` kernel; the AITER path is opt-in via `backend="aiter"`. @@ -335,6 +335,14 @@ when supported": `>= 33M`-element inputs (where it is faster and matches native precision) and otherwise stays on HIP `native`; the AITER path is also available explicitly via `backend="aiter"`. +* `rope` (`apply_rope_with_cos_sin_cache` / `_inplace`): `backend="auto"` + picks AITER only on the **inplace** path for `fp16` inputs with + `>= 2048` tokens (where the AITER kernel is ~1.2–1.65x faster and fp16 + precision stays inside tolerance) and otherwise stays on HIP `native` — + the out-of-place path always stays native (AITER's kernel is in-place + only, so the wrapper's Q/K copy erases the speedup), and bf16 always + stays native (slightly lower precision). The AITER path is also + available explicitly via `backend="aiter"`. * `batch_decode`: `use_cuda_graph=True` or `use_tensor_cores=True` force `auto` back to `fa2` (AITER decode does not support either), and `pos_encoding_mode != "NONE"` raises under `backend="aiter"`. diff --git a/flashinfer/rope.py b/flashinfer/rope.py index 04df6c56cf..37f12b1d8e 100644 --- a/flashinfer/rope.py +++ b/flashinfer/rope.py @@ -38,13 +38,35 @@ def _aiter_rope_ops(): return _aiter - def _auto_select_rope_backend(query: torch.Tensor) -> str: + # Token count above which AITER's cos/sin-cache rope beats the native JIT + # kernel. Measured on gfx942 (bf16/fp16, q32/k8, hd128): the inplace AITER + # kernel crosses over around nnz~1024-1536 and reaches ~1.65x at 32K, while + # native's lower launch overhead wins below it. 2048 leaves headroom over + # launch-time jitter. + _AITER_ROPE_MIN_TOKENS = 2048 + + def _auto_select_rope_backend(query: torch.Tensor, inplace: bool) -> str: # AITER's cos/sin-cache rope consumes the cos/sin tables in the query - # dtype (bf16/fp16), whereas the native JIT kernel rotates in float32. - # For bf16 this pushes max abs error to ~5e-2 (vs native ~3e-2), at the - # edge of the flashinfer rope test tolerance. Keep auto on the native - # kernel; pass backend="aiter" to opt in explicitly. - return "native" + # dtype, whereas the native JIT kernel rotates in float32. For bf16 this + # pushes max abs error to ~5e-2 (vs native ~3e-2), at the edge of the + # rope test tolerance, so auto never picks AITER for bf16 — only fp16, + # whose AITER error (~7e-3) stays comfortably inside tolerance. + # + # AITER also only wins on the inplace path: its kernel is in-place-only, + # so the out-of-place wrapper must copy Q/K first, which erases the + # throughput gain (measured <1x even at 32K). And it only wins at large + # token counts. Outside that envelope, stay native. + if not inplace: + return "native" + if query.dtype != torch.float16: + return "native" + if query.shape[0] < _AITER_ROPE_MIN_TOKENS: + return "native" + from .aiter_utils import is_aiter_supported + + if not is_aiter_supported(query.device): + return "native" + return "aiter" # Memoize the AITER-format cos/sin tables. ``cos_sin_cache`` is a fixed # precomputed float32 table (typically a persistent module buffer) reused @@ -1273,7 +1295,10 @@ def apply_rope_with_cos_sin_cache( we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``. backend : str - Kernel backend to use. ``"auto"`` (default) selects the best available backend. + Kernel backend to use. ``"auto"`` (default) selects the best backend for + the call; for this out-of-place variant that is always ``"native"`` — + AITER's kernel is in-place only, so the out-of-place wrapper must copy + Q/K first, which erases AITER's throughput advantage. ``"native"`` uses the FlashInfer JIT kernel on all platforms. ``"aiter"`` uses AMD AITER's rope_cached kernel — ROCm (gfx942/gfx950) only; requires the ``aiter`` package. Precision is slightly lower than ``"native"`` @@ -1298,7 +1323,11 @@ def apply_rope_with_cos_sin_cache( key_out = torch.empty_like(key) if IS_HIP: - _backend = backend if backend != "auto" else _auto_select_rope_backend(query) + _backend = ( + backend + if backend != "auto" + else _auto_select_rope_backend(query, inplace=False) + ) if _backend == "aiter": _apply_rope_cos_sin_cache_aiter( query=query, @@ -1365,7 +1394,11 @@ def apply_rope_with_cos_sin_cache_inplace( we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``. backend : str - Kernel backend to use. ``"auto"`` (default) selects the best available backend. + Kernel backend to use. ``"auto"`` (default) selects the best backend for + the call: on ROCm (gfx942/gfx950) it picks AITER for fp16 inputs with at + least ~2048 tokens (where AITER's kernel is measurably faster), and stays + on ``"native"`` otherwise — for bf16 (precision), small token counts + (launch overhead), and non-ROCm platforms. ``"native"`` uses the FlashInfer JIT kernel on all platforms. ``"aiter"`` uses AMD AITER's rope_cached kernel — ROCm (gfx942/gfx950) only; requires the ``aiter`` package. Precision is slightly lower than ``"native"`` @@ -1380,7 +1413,11 @@ def apply_rope_with_cos_sin_cache_inplace( raise ValueError("cos_sin_cache should be float32") if IS_HIP: - _backend = backend if backend != "auto" else _auto_select_rope_backend(query) + _backend = ( + backend + if backend != "auto" + else _auto_select_rope_backend(query, inplace=True) + ) if _backend == "aiter": _apply_rope_cos_sin_cache_aiter( query=query, diff --git a/tests/rocm_tests/test_rope_aiter_hip.py b/tests/rocm_tests/test_rope_aiter_hip.py index 57445e8179..e4c1c84f78 100644 --- a/tests/rocm_tests/test_rope_aiter_hip.py +++ b/tests/rocm_tests/test_rope_aiter_hip.py @@ -122,13 +122,30 @@ def test_rope_cos_sin_cache_aiter_inplace(is_neox_style, dtype): torch.testing.assert_close(key_inplace, key_out, rtol=0, atol=0) -def test_rope_auto_backend_stays_native(): - """auto backend on gfx942/950 stays on the native kernel (precision parity).""" - from flashinfer.rope import _auto_select_rope_backend +def test_rope_auto_backend_selection(): + """auto picks AITER only for the inplace + fp16 + large-nnz envelope where it + is both faster (measured ~1.2-1.65x) and precise enough (fp16 err ~7e-3); + bf16, small nnz, and the out-of-place path all stay native.""" + from flashinfer.rope import _AITER_ROPE_MIN_TOKENS, _auto_select_rope_backend device = torch.device("cuda:0") - q = torch.randn(8, 128, dtype=torch.bfloat16, device=device) - assert _auto_select_rope_backend(q) == "native" + big = _AITER_ROPE_MIN_TOKENS + small = _AITER_ROPE_MIN_TOKENS - 1 + + # The one case auto routes to AITER: inplace, fp16, nnz >= threshold. + q_fp16_big = torch.randn(big, 128, dtype=torch.float16, device=device) + assert _auto_select_rope_backend(q_fp16_big, inplace=True) == "aiter" + + # Out-of-place always native (the Q/K copy erases AITER's throughput win). + assert _auto_select_rope_backend(q_fp16_big, inplace=False) == "native" + + # bf16 always native (precision: ~5e-2 vs native ~3e-2). + q_bf16_big = torch.randn(big, 128, dtype=torch.bfloat16, device=device) + assert _auto_select_rope_backend(q_bf16_big, inplace=True) == "native" + + # Below the token threshold, native's lower launch overhead wins. + q_fp16_small = torch.randn(small, 128, dtype=torch.float16, device=device) + assert _auto_select_rope_backend(q_fp16_small, inplace=True) == "native" def test_rope_unknown_backend_raises(): From cf8e5d5c11e2641f3cdb4295819e300d6b32dfe1 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Mon, 15 Jun 2026 21:03:08 +0000 Subject: [PATCH 3/6] fix(hip,aiter): validate rotary_dim and wrap aiter import for rope backend Address Copilot review on PR #252: - Validate that cos_sin_cache's last dim (rotary_dim) is even (cos||sin) and does not exceed head_size before splitting/slicing, raising a clear ValueError instead of letting mismatched shapes reach the AITER kernel. - Wrap the _aiter_rope_ops() import in the explicit backend="aiter" path so a missing/broken aiter package surfaces as a helpful ValueError rather than a bare ModuleNotFoundError, matching activation.py. Add tests for both new guards. Co-Authored-By: Claude Opus 4.7 --- flashinfer/rope.py | 23 +++++++++++++++++++-- tests/rocm_tests/test_rope_aiter_hip.py | 27 +++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/flashinfer/rope.py b/flashinfer/rope.py index 37f12b1d8e..df8b1d5428 100644 --- a/flashinfer/rope.py +++ b/flashinfer/rope.py @@ -120,6 +120,13 @@ def _apply_rope_cos_sin_cache_aiter( "AITER rope backend requires an AMD gfx942/gfx950 device; " "use backend='native' instead." ) + try: + aiter_ops = _aiter_rope_ops() + except Exception as e: + raise ValueError( + "backend='aiter' requires the aiter package, which failed to " + f"import: {e}" + ) from e if key.dtype != query.dtype: # AITER rotates Q and K with a single cos/sin table built in the # query dtype; the native path tolerates mixed dtypes by rotating @@ -129,8 +136,20 @@ def _apply_rope_cos_sin_cache_aiter( f"got query={query.dtype}, key={key.dtype}. Use backend='native'." ) - nnz = query.shape[0] rotary_dim = cos_sin_cache.shape[-1] + # cos_sin_cache stacks cos||sin on its last dim, so rotary_dim must be + # even, and the rotated slice q[..., :rotary_dim] must fit in head_size. + if rotary_dim % 2 != 0: + raise ValueError( + f"cos_sin_cache last dim must be even (cos||sin); got {rotary_dim}." + ) + if rotary_dim > head_size: + raise ValueError( + f"rotary_dim ({rotary_dim}) from cos_sin_cache exceeds head_size " + f"({head_size})." + ) + + nnz = query.shape[0] cos, sin = _aiter_rope_cos_sin(cos_sin_cache, query.dtype) q_view = query_out.view(1, nnz, -1, head_size) @@ -145,7 +164,7 @@ def _apply_rope_cos_sin_cache_aiter( # otherwise trips a C assert that aborts the process. pos = positions.to(torch.int64).contiguous().view(1, nnz) - _aiter_rope_ops().rope_cached_positions_2c_fwd_inplace( + aiter_ops.rope_cached_positions_2c_fwd_inplace( q_view[..., :rotary_dim], k_view[..., :rotary_dim], cos, diff --git a/tests/rocm_tests/test_rope_aiter_hip.py b/tests/rocm_tests/test_rope_aiter_hip.py index e4c1c84f78..bcd5d57945 100644 --- a/tests/rocm_tests/test_rope_aiter_hip.py +++ b/tests/rocm_tests/test_rope_aiter_hip.py @@ -174,6 +174,33 @@ def test_rope_aiter_mixed_dtype_raises(): ) +def test_rope_aiter_odd_rotary_dim_raises(): + """An odd cos_sin_cache last dim cannot split into equal cos||sin halves.""" + device = torch.device("cuda:0") + cos_sin_cache = torch.randn(64, 63, dtype=torch.float32, device=device) + pos_ids = torch.arange(8, device=device) + query = torch.randn(8, 8 * 128, dtype=torch.float16, device=device) + key = torch.randn(8, 8 * 128, dtype=torch.float16, device=device) + with pytest.raises(ValueError, match="even"): + flashinfer.apply_rope_with_cos_sin_cache( + pos_ids, query, key, 128, cos_sin_cache, backend="aiter" + ) + + +def test_rope_aiter_rotary_dim_exceeds_head_size_raises(): + """rotary_dim derived from cos_sin_cache must fit within head_size.""" + device = torch.device("cuda:0") + head_size = 64 + cos_sin_cache = torch.randn(64, 128, dtype=torch.float32, device=device) + pos_ids = torch.arange(8, device=device) + query = torch.randn(8, 8 * head_size, dtype=torch.float16, device=device) + key = torch.randn(8, 8 * head_size, dtype=torch.float16, device=device) + with pytest.raises(ValueError, match="exceeds head_size"): + flashinfer.apply_rope_with_cos_sin_cache( + pos_ids, query, key, head_size, cos_sin_cache, backend="aiter" + ) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) def test_rope_aiter_noncontiguous_positions(dtype): """A strided positions tensor must be normalized before reaching the AITER From bb17a0215f34c2f2e9344b89f0a49d215e47ce01 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Mon, 15 Jun 2026 22:58:12 +0000 Subject: [PATCH 4/6] perf(hip,aiter): zero-copy out-of-place rope via _impl, widen auto MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Switch the AITER rope dispatch from rope_cached_positions_2c_fwd_inplace to the lower-level rope_cached_positions_2c_fwd_impl with distinct in/out tensors. AITER's _inplace wrapper is just _impl(x,y,x,y,...), so using _impl directly lets the out-of-place path write straight into the fresh output with no Q/K copy (measured ~1.3-2.8x faster than copy-then-rotate; inplace is unchanged). For partial rotary the untouched nope tail is copied across only when the output is a fresh tensor. Because out-of-place no longer pays a copy, backend="auto" now routes both the inplace and out-of-place wrappers to AITER for fp16 + >=2048 tokens (previously inplace-only). _auto_select_rope_backend drops its inplace arg accordingly. Robustness (from code review): - auto falls back to native on mismatched q/k dtype instead of routing into the helper's dtype guard and raising — auto must never raise. - memoized cos/sin entry carries a weakref to its source tensor; a cache hit is only honored when it still resolves to the same tensor, guarding against id() reuse for transient caches. - memoized cos/sin tables are made contiguous before reaching the kernel. Update README perf/auto description and tests accordingly. Co-Authored-By: Claude Opus 4.7 --- README.md | 16 ++-- flashinfer/rope.py | 101 +++++++++++++++--------- tests/rocm_tests/test_rope_aiter_hip.py | 25 +++--- 3 files changed, 86 insertions(+), 56 deletions(-) diff --git a/README.md b/README.md index 11da96ba0e..1818091359 100644 --- a/README.md +++ b/README.md @@ -61,7 +61,7 @@ kernel for non-attention ops). **AITER** = ROCm AITER backend. | **Cascade attention** | ✅ | — | HIP | Two-level shared-prefix attention; a fused single-kernel HIP variant is gated behind `FLASHINFER_HIP_FUSED_CASCADE=1` | | **MLA (Multi-Latent Attention)** | — | ✅ | **AITER** (no HIP fallback) | DeepSeek-style 192/128 head-dim split; bf16 + `page_size=1`; `backend="auto"` (default) resolves to `"aiter"` | | **POD attention** | ✅ `fa2` | — | HIP | MHA / GQA / MQA; single + batch variants (`PODWithPagedKVCacheWrapper`, `BatchPODWithPagedKVCacheWrapper`); JIT-only (excluded from AOT, same as upstream CUDA) | -| **RoPE (positional encoding)** | ✅ `native` | ✅ | **AITER** for the cos/sin-cache path when inplace + `fp16` + `>= 2048` tokens + gfx942/gfx950; else **HIP `native`** | LLaMA-style + LLaMA 3.1 scaling; fused RoPE + fp8 quant + paged-KV append (E4M3FNUZ, E5M2FNUZ). AITER backend covers `apply_rope_with_cos_sin_cache` (CK `rope_cached_positions_2c`); ~1.2–1.65x over native on large-batch inplace prefill; bf16 stays native (slightly lower precision) | +| **RoPE (positional encoding)** | ✅ `native` | ✅ | **AITER** for the cos/sin-cache path when `fp16` + `>= 2048` tokens + gfx942/gfx950; else **HIP `native`** | LLaMA-style + LLaMA 3.1 scaling; fused RoPE + fp8 quant + paged-KV append (E4M3FNUZ, E5M2FNUZ). AITER backend covers `apply_rope_with_cos_sin_cache` and its inplace variant (CK `rope_cached_positions_2c`); ~1.3–2.8x over native at large nnz (zero-copy via the `_impl` entry point); bf16 stays native (slightly lower precision) | | **Paged KV-cache append** | ✅ `native` | ✅ | **AITER** when `fp16/bf16` + `NHD` + gfx942/gfx950 + AITER importable; else **HIP `native`** | `append_paged_kv_cache`; fp8 KV-cache supported on the HIP path | | **RMSNorm** | ✅ `native` | ✅ | **HIP `native`** (auto stays on HIP — AITER is opt-in via `backend="aiter"`) | AITER path is fp16/bf16, 2-D only; slightly lower precision at `hidden_size >= 1024` | | **Fused add RMSNorm** | ✅ `native` | ✅ | **AITER** when 2-D + `>= 4M` elements + gfx942/gfx950 + AITER importable; else **HIP `native`** | `fused_add_rmsnorm`; AITER (CK `rmsnorm2d_fwd_with_add`) wins on large bandwidth-bound shapes; 2-D only, slightly lower precision at `hidden_size >= 1024` | @@ -336,13 +336,13 @@ AITER when supported": precision) and otherwise stays on HIP `native`; the AITER path is also available explicitly via `backend="aiter"`. * `rope` (`apply_rope_with_cos_sin_cache` / `_inplace`): `backend="auto"` - picks AITER only on the **inplace** path for `fp16` inputs with - `>= 2048` tokens (where the AITER kernel is ~1.2–1.65x faster and fp16 - precision stays inside tolerance) and otherwise stays on HIP `native` — - the out-of-place path always stays native (AITER's kernel is in-place - only, so the wrapper's Q/K copy erases the speedup), and bf16 always - stays native (slightly lower precision). The AITER path is also - available explicitly via `backend="aiter"`. + picks AITER for `fp16` inputs with `>= 2048` tokens (where the AITER + kernel is ~1.3–2.8x faster and fp16 precision stays inside tolerance) + and otherwise stays on HIP `native`. Both the inplace and out-of-place + wrappers benefit — the helper uses AITER's `..._impl` entry point with + distinct in/out tensors, so the out-of-place path needs no Q/K copy. + bf16 always stays native (slightly lower precision). The AITER path is + also available explicitly via `backend="aiter"`. * `batch_decode`: `use_cuda_graph=True` or `use_tensor_cores=True` force `auto` back to `fa2` (AITER decode does not support either), and `pos_encoding_mode != "NONE"` raises under `backend="aiter"`. diff --git a/flashinfer/rope.py b/flashinfer/rope.py index df8b1d5428..6a57b89f72 100644 --- a/flashinfer/rope.py +++ b/flashinfer/rope.py @@ -39,33 +39,42 @@ def _aiter_rope_ops(): return _aiter # Token count above which AITER's cos/sin-cache rope beats the native JIT - # kernel. Measured on gfx942 (bf16/fp16, q32/k8, hd128): the inplace AITER - # kernel crosses over around nnz~1024-1536 and reaches ~1.65x at 32K, while - # native's lower launch overhead wins below it. 2048 leaves headroom over - # launch-time jitter. + # kernel. Measured on gfx942 (fp16, q32/k8, hd128): AITER crosses over around + # nnz~1024-1536 and reaches ~2.6-2.8x at large nnz, while native's lower + # launch overhead wins below it. 2048 leaves headroom over launch jitter. _AITER_ROPE_MIN_TOKENS = 2048 - def _auto_select_rope_backend(query: torch.Tensor, inplace: bool) -> str: + def _auto_select_rope_backend(query: torch.Tensor, key: torch.Tensor) -> str: # AITER's cos/sin-cache rope consumes the cos/sin tables in the query # dtype, whereas the native JIT kernel rotates in float32. For bf16 this # pushes max abs error to ~5e-2 (vs native ~3e-2), at the edge of the # rope test tolerance, so auto never picks AITER for bf16 — only fp16, # whose AITER error (~7e-3) stays comfortably inside tolerance. # - # AITER also only wins on the inplace path: its kernel is in-place-only, - # so the out-of-place wrapper must copy Q/K first, which erases the - # throughput gain (measured <1x even at 32K). And it only wins at large - # token counts. Outside that envelope, stay native. - if not inplace: - return "native" + # AITER wins on both the inplace and out-of-place paths once nnz is large + # enough: the helper uses the ..._impl entry point with distinct in/out + # tensors, so out-of-place needs no Q/K copy. Below the token threshold, + # native's lower launch overhead wins. if query.dtype != torch.float16: return "native" + if key.dtype != query.dtype: + # AITER rotates Q and K with one cos/sin table in the query dtype and + # rejects mismatched dtypes; native handles them. auto must not raise, + # so fall back rather than route into the helper's dtype guard. + return "native" if query.shape[0] < _AITER_ROPE_MIN_TOKENS: return "native" from .aiter_utils import is_aiter_supported if not is_aiter_supported(query.device): return "native" + try: + # Best-effort probe: a supported arch can still lack a usable aiter + # (not installed, or its compiled extension fails to load). auto must + # always fall back to native, never raise — matching norm/activation. + _aiter_rope_ops() + except Exception: + return "native" return "aiter" # Memoize the AITER-format cos/sin tables. ``cos_sin_cache`` is a fixed @@ -74,7 +83,10 @@ def _auto_select_rope_backend(query: torch.Tensor, inplace: bool) -> str: # dtype on each call is pure overhead — significant during decode, where # nnz is tiny but max_seq_len is large. Keyed by id() (tensors aren't # value-hashable) with a finalizer that evicts the entry when the cache is - # GC'd, so the cached tables never outlive their source. + # GC'd, so the cached tables never outlive their source. A weakref to the + # source is stored alongside the entry so a hit is only honored when it + # still resolves to the *same* tensor — guarding against id() reuse if a + # transient cache is freed and a new tensor lands on the same address. _aiter_cos_sin_tables: dict = {} def _aiter_rope_cos_sin( @@ -82,14 +94,14 @@ def _aiter_rope_cos_sin( ) -> Tuple[torch.Tensor, torch.Tensor]: key = id(cos_sin_cache) cached = _aiter_cos_sin_tables.get(key) - if cached is not None and cached[0] == dtype: - return cached[1], cached[2] + if cached is not None and cached[0] == dtype and cached[1]() is cos_sin_cache: + return cached[2], cached[3] half = cos_sin_cache.shape[-1] // 2 - cos = cos_sin_cache[:, :half].unsqueeze(1).unsqueeze(1).to(dtype) - sin = cos_sin_cache[:, half:].unsqueeze(1).unsqueeze(1).to(dtype) + cos = cos_sin_cache[:, :half].unsqueeze(1).unsqueeze(1).to(dtype).contiguous() + sin = cos_sin_cache[:, half:].unsqueeze(1).unsqueeze(1).to(dtype).contiguous() if cached is None: weakref.finalize(cos_sin_cache, _aiter_cos_sin_tables.pop, key, None) - _aiter_cos_sin_tables[key] = (dtype, cos, sin) + _aiter_cos_sin_tables[key] = (dtype, weakref.ref(cos_sin_cache), cos, sin) return cos, sin def _apply_rope_cos_sin_cache_aiter( @@ -109,9 +121,15 @@ def _apply_rope_cos_sin_cache_aiter( separate ``(max_seq_len, 1, 1, rotary_dim // 2)`` tables in the query dtype with ``reuse_freqs_front_part=True``. Q/K are reshaped to AITER's ``(1, nnz, num_heads, head_dim)`` layout and only the leading ``rotary_dim`` - slice is rotated (matching ``nope_first=False``). Writes through views, so - ``query_out``/``key_out`` are updated in place (alias the inputs for the - inplace variant). + slice is rotated (matching ``nope_first=False``). + + Uses the ``..._impl`` entry point with distinct input/output tensors so the + out-of-place case writes straight into ``query_out``/``key_out`` with no Q/K + copy (measured ~1.3-2.8x faster than copy-then-rotate). The kernel only + writes the rotated ``[:rotary_dim]`` slice, so for partial rotary + (``rotary_dim < head_size``) the untouched nope tail is copied across when + the output is a fresh tensor. When ``query_out``/``key_out`` alias the + inputs (the inplace variant) this is ``_impl(x, y, x, y, ...)``. """ from .aiter_utils import is_aiter_supported @@ -152,21 +170,31 @@ def _apply_rope_cos_sin_cache_aiter( nnz = query.shape[0] cos, sin = _aiter_rope_cos_sin(cos_sin_cache, query.dtype) - q_view = query_out.view(1, nnz, -1, head_size) - k_view = key_out.view(1, nnz, -1, head_size) - if query_out.data_ptr() != query.data_ptr(): - q_view.copy_(query.view(1, nnz, -1, head_size)) - if key_out.data_ptr() != key.data_ptr(): - k_view.copy_(key.view(1, nnz, -1, head_size)) + q_in = query.view(1, nnz, -1, head_size) + k_in = key.view(1, nnz, -1, head_size) + q_out = query_out.view(1, nnz, -1, head_size) + k_out = key_out.view(1, nnz, -1, head_size) + + # The kernel only writes the rotated [:rotary_dim] slice. For partial + # rotary into a fresh output, copy the untouched nope tail [rotary_dim:] + # across first (skipped when output aliases input, and when rotary_dim + # covers the full head_size). + if rotary_dim < head_size: + if query_out.data_ptr() != query.data_ptr(): + q_out[..., rotary_dim:].copy_(q_in[..., rotary_dim:]) + if key_out.data_ptr() != key.data_ptr(): + k_out[..., rotary_dim:].copy_(k_in[..., rotary_dim:]) # AITER's HIP kernel asserts int64, contiguous positions of shape # (1, nnz) (stride(1) == 1) — a strided/non-int64 positions tensor # otherwise trips a C assert that aborts the process. pos = positions.to(torch.int64).contiguous().view(1, nnz) - aiter_ops.rope_cached_positions_2c_fwd_inplace( - q_view[..., :rotary_dim], - k_view[..., :rotary_dim], + aiter_ops.rope_cached_positions_2c_fwd_impl( + q_out[..., :rotary_dim], + k_out[..., :rotary_dim], + q_in[..., :rotary_dim], + k_in[..., :rotary_dim], cos, sin, pos, @@ -1315,9 +1343,10 @@ def apply_rope_with_cos_sin_cache( backend : str Kernel backend to use. ``"auto"`` (default) selects the best backend for - the call; for this out-of-place variant that is always ``"native"`` — - AITER's kernel is in-place only, so the out-of-place wrapper must copy - Q/K first, which erases AITER's throughput advantage. + the call: on ROCm (gfx942/gfx950) it picks AITER for fp16 inputs with at + least ~2048 tokens (where the AITER kernel is measurably faster), and stays + on ``"native"`` otherwise — for bf16 (precision), small token counts + (launch overhead), and non-ROCm platforms. ``"native"`` uses the FlashInfer JIT kernel on all platforms. ``"aiter"`` uses AMD AITER's rope_cached kernel — ROCm (gfx942/gfx950) only; requires the ``aiter`` package. Precision is slightly lower than ``"native"`` @@ -1343,9 +1372,7 @@ def apply_rope_with_cos_sin_cache( if IS_HIP: _backend = ( - backend - if backend != "auto" - else _auto_select_rope_backend(query, inplace=False) + backend if backend != "auto" else _auto_select_rope_backend(query, key) ) if _backend == "aiter": _apply_rope_cos_sin_cache_aiter( @@ -1433,9 +1460,7 @@ def apply_rope_with_cos_sin_cache_inplace( if IS_HIP: _backend = ( - backend - if backend != "auto" - else _auto_select_rope_backend(query, inplace=True) + backend if backend != "auto" else _auto_select_rope_backend(query, key) ) if _backend == "aiter": _apply_rope_cos_sin_cache_aiter( diff --git a/tests/rocm_tests/test_rope_aiter_hip.py b/tests/rocm_tests/test_rope_aiter_hip.py index bcd5d57945..503d18ee49 100644 --- a/tests/rocm_tests/test_rope_aiter_hip.py +++ b/tests/rocm_tests/test_rope_aiter_hip.py @@ -123,29 +123,34 @@ def test_rope_cos_sin_cache_aiter_inplace(is_neox_style, dtype): def test_rope_auto_backend_selection(): - """auto picks AITER only for the inplace + fp16 + large-nnz envelope where it - is both faster (measured ~1.2-1.65x) and precise enough (fp16 err ~7e-3); - bf16, small nnz, and the out-of-place path all stay native.""" + """auto picks AITER for fp16 + large-nnz (both inplace and out-of-place, since + the _impl path needs no Q/K copy), where it is both faster (~1.3-2.8x) and + precise enough (fp16 err ~7e-3); bf16 and small nnz stay native.""" from flashinfer.rope import _AITER_ROPE_MIN_TOKENS, _auto_select_rope_backend device = torch.device("cuda:0") big = _AITER_ROPE_MIN_TOKENS small = _AITER_ROPE_MIN_TOKENS - 1 - # The one case auto routes to AITER: inplace, fp16, nnz >= threshold. + # fp16 + nnz >= threshold routes to AITER (selection is shape/dtype-based and + # backend dispatch is shared by both the inplace and out-of-place wrappers). q_fp16_big = torch.randn(big, 128, dtype=torch.float16, device=device) - assert _auto_select_rope_backend(q_fp16_big, inplace=True) == "aiter" - - # Out-of-place always native (the Q/K copy erases AITER's throughput win). - assert _auto_select_rope_backend(q_fp16_big, inplace=False) == "native" + k_fp16_big = torch.randn(big, 128, dtype=torch.float16, device=device) + assert _auto_select_rope_backend(q_fp16_big, k_fp16_big) == "aiter" # bf16 always native (precision: ~5e-2 vs native ~3e-2). q_bf16_big = torch.randn(big, 128, dtype=torch.bfloat16, device=device) - assert _auto_select_rope_backend(q_bf16_big, inplace=True) == "native" + k_bf16_big = torch.randn(big, 128, dtype=torch.bfloat16, device=device) + assert _auto_select_rope_backend(q_bf16_big, k_bf16_big) == "native" # Below the token threshold, native's lower launch overhead wins. q_fp16_small = torch.randn(small, 128, dtype=torch.float16, device=device) - assert _auto_select_rope_backend(q_fp16_small, inplace=True) == "native" + k_fp16_small = torch.randn(small, 128, dtype=torch.float16, device=device) + assert _auto_select_rope_backend(q_fp16_small, k_fp16_small) == "native" + + # Mixed q/k dtype falls back to native rather than raising: AITER can't rotate + # both with one cos/sin table, but auto must never raise. + assert _auto_select_rope_backend(q_fp16_big, k_bf16_big) == "native" def test_rope_unknown_backend_raises(): From 97f9b2d904fcd4ccfa8228a327a62d4df48e02c0 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Mon, 15 Jun 2026 23:10:07 +0000 Subject: [PATCH 5/6] test(hip,aiter): cover rope auto fallback when aiter unimportable MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address Copilot review on PR #252: add a unit test that monkeypatches flashinfer.rope._aiter_rope_ops to raise and asserts _auto_select_rope_backend(query, key) returns "native" — verifying the best-effort import probe falls back gracefully on a supported arch with a missing/broken aiter install. Co-Authored-By: Claude Opus 4.7 --- tests/rocm_tests/test_rope_aiter_hip.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/rocm_tests/test_rope_aiter_hip.py b/tests/rocm_tests/test_rope_aiter_hip.py index 503d18ee49..3ab875c816 100644 --- a/tests/rocm_tests/test_rope_aiter_hip.py +++ b/tests/rocm_tests/test_rope_aiter_hip.py @@ -246,3 +246,25 @@ def test_rope_aiter_noncontiguous_positions(dtype): ) torch.testing.assert_close(q_strided, q_contig, rtol=0, atol=0) torch.testing.assert_close(k_strided, k_contig, rtol=0, atol=0) + + +def test_rope_auto_falls_back_when_aiter_unimportable(monkeypatch): + """On a supported arch with a missing/broken aiter install, auto must fall + back to native rather than raise — _auto_select_rope_backend probes the + import and returns 'native' on failure.""" + from flashinfer import rope + from flashinfer.rope import _AITER_ROPE_MIN_TOKENS, _auto_select_rope_backend + + device = torch.device("cuda:0") + n = _AITER_ROPE_MIN_TOKENS + q = torch.randn(n, 128, dtype=torch.float16, device=device) + k = torch.randn(n, 128, dtype=torch.float16, device=device) + + # Sanity: with aiter importable this shape selects aiter. + assert _auto_select_rope_backend(q, k) == "aiter" + + def _boom(): + raise ImportError("simulated missing aiter") + + monkeypatch.setattr(rope, "_aiter_rope_ops", _boom) + assert _auto_select_rope_backend(q, k) == "native" From 390a1ae37d9c10e27d5e18dbea26845c550832fd Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Mon, 15 Jun 2026 23:20:49 +0000 Subject: [PATCH 6/6] test(hip,aiter): skip rope aiter tests when aiter package missing Address Copilot review on PR #252: switch the module pytestmark from a bare is_aiter_supported (arch-only) check to tests.test_helpers.requires_aiter, which skips when the arch is unsupported OR the aiter package is not importable. On a supported GPU without aiter installed the old guard let the tests run and then fail; this matches the other test_*_aiter_hip.py modules. Co-Authored-By: Claude Opus 4.7 --- tests/rocm_tests/test_rope_aiter_hip.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/rocm_tests/test_rope_aiter_hip.py b/tests/rocm_tests/test_rope_aiter_hip.py index 3ab875c816..4d4c35ca1d 100644 --- a/tests/rocm_tests/test_rope_aiter_hip.py +++ b/tests/rocm_tests/test_rope_aiter_hip.py @@ -13,13 +13,13 @@ import torch import flashinfer -from flashinfer.aiter_utils import is_aiter_supported from tests.test_helpers.rope_reference import RotaryEmbedding +from tests.test_helpers.test_helpers import requires_aiter -pytestmark = pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +# Skip when the arch is unsupported OR the aiter package is missing — checking +# arch alone would let these tests run and then fail on a supported GPU without +# aiter installed. Matches the other test_*_aiter_hip.py modules. +pytestmark = requires_aiter @pytest.mark.parametrize("is_neox_style", [True, False])