feat(hip,aiter): add AITER backend to rope cos/sin-cache ops #252
Merged
demandal25 merged 6 commits on Jun 15, 2026
Route apply_rope_with_cos_sin_cache and its inplace variant through AITER's rope_cached_positions_2c kernel on ROCm via a new backend="aiter" opt-in, mirroring the rmsnorm/silu_and_mul AITER backend pattern.

The helper adapts FlashInfer's formats to AITER's: it splits the (max_seq_len, rotary_dim) float32 cos||sin cache into two (max_seq_len, 1, 1, rotary_dim//2) tables in the query dtype (memoized per cache tensor to avoid re-converting the full table every forward pass), reshapes Q/K to AITER's (1, nnz, heads, dim) layout, and rotates only the leading rotary_dim slice.

backend="auto" stays on the native kernel: AITER consumes the cos/sin tables in the query dtype rather than float32, raising bf16 max abs error to ~5e-2 (vs native ~3e-2), at the edge of the rope test tolerance. "aiter" is explicit opt-in.

Guards added to the AITER path: device must be gfx942/gfx950, query/key must share a dtype (AITER rotates both with one cos/sin table), and positions are coerced to contiguous int64 (a strided positions tensor otherwise trips a C assert that aborts the process).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
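The cache split described in this commit can be illustrated with a minimal pure-Python sketch. It assumes the cos||sin cache stores the cos half first and the sin half second in each row, and it uses nested lists in place of tensors; `split_cos_sin_cache` is a hypothetical name, not the helper in `flashinfer/rope.py`, and the dtype conversion and memoization are left out.

```python
from typing import List, Tuple

# Nested-list stand-in for a (max_seq_len, 1, 1, rotary_dim // 2) table.
Table = List[List[List[List[float]]]]


def split_cos_sin_cache(cache: List[List[float]]) -> Tuple[Table, Table]:
    """Split a (max_seq_len, rotary_dim) cos||sin cache into cos and sin tables.

    Hypothetical illustration: assumes rotary_dim is even and that each row
    holds rotary_dim // 2 cos values followed by rotary_dim // 2 sin values.
    """
    rotary_dim = len(cache[0])
    half = rotary_dim // 2
    # Wrapping each half-row twice adds the two singleton axes, so the
    # result has shape (max_seq_len, 1, 1, rotary_dim // 2).
    cos = [[[row[:half]]] for row in cache]
    sin = [[[row[half:]]] for row in cache]
    return cos, sin


if __name__ == "__main__":
    cache = [[1.0, 0.5, 0.0, 0.25], [0.0, 1.0, 1.0, 0.0]]
    cos, sin = split_cos_sin_cache(cache)
    print(len(cos), len(cos[0]), len(cos[0][0]), len(cos[0][0][0]))
```

With real tensors the same split would typically be a chunk along the last dimension followed by a dtype cast and a reshape; that variant is not shown here because the exact calls used by the PR are not quoted in this page.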
Pull request overview
Adds an opt-in AITER backend for the HIP/ROCm RoPE cos/sin-cache entry points, so users on gfx942/gfx950 can route apply_rope_with_cos_sin_cache* through AITER’s cached-kernel while keeping "auto" on the existing native JIT kernel for precision parity.
Changes:
- Extend `apply_rope_with_cos_sin_cache` and `_inplace` with `backend: str = "auto"` and route `backend="aiter"` to a new HIP-only AITER adaptor.
- Introduce helpers to reshape Q/K and convert/split the cos||sin cache into AITER’s expected table format (with memoization).
- Add ROCm tests validating AITER vs native reference (including partial rotary), inplace semantics, backend selection, and positions normalization.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| flashinfer/rope.py | Adds the backend parameter and implements the HIP AITER dispatch path + cos/sin table memoization. |
| tests/rocm_tests/test_rope_aiter_hip.py | New ROCm tests covering AITER backend correctness, backend selection behavior, and guardrails. |
Make backend="auto" pick AITER for the rope cos/sin-cache path where it is both faster and precise enough: inplace + fp16 + >= 2048 tokens on gfx942/gfx950. Everything else stays native.

Rationale from gfx942 benchmarks (bf16/fp16, q32/k8, hd128):
- Out-of-place AITER never wins (0.6-0.9x): its kernel is in-place only, so the wrapper must copy Q/K first, erasing the throughput gain.
- Inplace AITER crosses over around 1024-1536 tokens and reaches ~1.65x at 32K; native's lower launch overhead wins below ~2048.
- Decode (small nnz) is firmly native territory (~0.6x).
- bf16 stays native on precision grounds (~5e-2 vs native ~3e-2); fp16 AITER error (~7e-3) is comfortably inside tolerance.

_auto_select_rope_backend now takes an `inplace` flag (out-of-place can never benefit) and gates on dtype/nnz/device. README feature matrix and AITER Support section document the policy.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…ckend

Address Copilot review on PR ROCm#252:
- Validate that cos_sin_cache's last dim (rotary_dim) is even (cos||sin) and does not exceed head_size before splitting/slicing, raising a clear ValueError instead of letting mismatched shapes reach the AITER kernel.
- Wrap the _aiter_rope_ops() import in the explicit backend="aiter" path so a missing/broken aiter package surfaces as a helpful ValueError rather than a bare ModuleNotFoundError, matching activation.py.

Add tests for both new guards.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
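The two guards from this commit can be sketched as follows. This is an illustration under stated assumptions, not the code in `flashinfer/rope.py`: `validate_rotary_dim` and `load_backend_module` are hypothetical names, and the sketch only converts `ImportError` (which includes `ModuleNotFoundError`), whereas the real path may handle other failure modes of a broken install.

```python
import importlib
from types import ModuleType


def validate_rotary_dim(rotary_dim: int, head_size: int) -> None:
    """Reject cache widths that cannot be split into equal cos and sin halves.

    Hypothetical helper: rotary_dim is the last dimension of the cos||sin cache.
    """
    if rotary_dim % 2 != 0:
        raise ValueError(
            f"rotary_dim must be even to split cos||sin, got {rotary_dim}"
        )
    if rotary_dim > head_size:
        raise ValueError(
            f"rotary_dim ({rotary_dim}) must not exceed head_size ({head_size})"
        )


def load_backend_module(name: str) -> ModuleType:
    """Import an optional backend, turning import failures into ValueError.

    Hypothetical helper: the caller asked for this backend explicitly, so a
    missing package is reported as a bad argument rather than a bare
    ModuleNotFoundError.
    """
    try:
        return importlib.import_module(name)
    except ImportError as exc:
        raise ValueError(
            f"backend={name!r} was requested but the package could not be imported"
        ) from exc


if __name__ == "__main__":
    validate_rotary_dim(64, 128)  # passes silently
    try:
        load_backend_module("definitely_not_installed_backend_xyz")
    except ValueError as err:
        print(err)
```

Raising before any reshaping or slicing keeps the error message close to the user's mistake, which is the stated goal of the commit.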
Switch the AITER rope dispatch from rope_cached_positions_2c_fwd_inplace to the lower-level rope_cached_positions_2c_fwd_impl with distinct in/out tensors. AITER's _inplace wrapper is just _impl(x,y,x,y,...), so using _impl directly lets the out-of-place path write straight into the fresh output with no Q/K copy (measured ~1.3-2.8x faster than copy-then-rotate; inplace is unchanged). For partial rotary the untouched nope tail is copied across only when the output is a fresh tensor.

Because out-of-place no longer pays a copy, backend="auto" now routes both the inplace and out-of-place wrappers to AITER for fp16 + >=2048 tokens (previously inplace-only). _auto_select_rope_backend drops its inplace arg accordingly.

Robustness (from code review):
- auto falls back to native on mismatched q/k dtype instead of routing into the helper's dtype guard and raising; auto must never raise.
- memoized cos/sin entry carries a weakref to its source tensor; a cache hit is only honored when it still resolves to the same tensor, guarding against id() reuse for transient caches.
- memoized cos/sin tables are made contiguous before reaching the kernel.

Update README perf/auto description and tests accordingly.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
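The selection policy stated in this commit (fp16, at least 2048 tokens, gfx942/gfx950, matching q/k dtypes, and never raising) can be sketched as a plain function. Everything below is a hypothetical stand-in: the real `_auto_select_rope_backend(query, key)` takes tensors, whereas this sketch takes the arch name, dtype names as strings, the token count, and a flag for whether the aiter import probe succeeded.

```python
SUPPORTED_ARCHS = ("gfx942", "gfx950")  # archs named in the PR
MIN_TOKENS_FOR_AITER = 2048             # threshold named in the PR


def auto_select_rope_backend(
    arch: str,
    q_dtype: str,
    k_dtype: str,
    nnz: int,
    aiter_importable: bool,
) -> str:
    """Return "aiter" only when every condition holds; otherwise "native".

    Hypothetical sketch: each early return is a fallback, so the function
    never raises for an unsupported combination.
    """
    if arch not in SUPPORTED_ARCHS:
        return "native"
    if not aiter_importable:
        return "native"  # best-effort import probe failed
    if q_dtype != k_dtype:
        return "native"  # AITER rotates q and k with one cos/sin table
    if q_dtype != "float16":
        return "native"  # bf16 stays native on precision grounds
    if nnz < MIN_TOKENS_FOR_AITER:
        return "native"  # native's lower launch overhead wins for small nnz
    return "aiter"


if __name__ == "__main__":
    print(auto_select_rope_backend("gfx942", "float16", "float16", 4096, True))
    print(auto_select_rope_backend("gfx942", "bfloat16", "bfloat16", 4096, True))
```

Ordering the checks from cheapest and most general (arch, import) to most specific (token count) keeps the fallback path short for unsupported devices.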
Address Copilot review on PR ROCm#252: add a unit test that monkeypatches flashinfer.rope._aiter_rope_ops to raise and asserts _auto_select_rope_backend(query, key) returns "native", verifying the best-effort import probe falls back gracefully on a supported arch with a missing/broken aiter install.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Address Copilot review on PR ROCm#252: switch the module pytestmark from a bare is_aiter_supported (arch-only) check to tests.test_helpers.requires_aiter, which skips when the arch is unsupported OR the aiter package is not importable. On a supported GPU without aiter installed the old guard let the tests run and then fail; this matches the other test_*_aiter_hip.py modules.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Summary
Adds an AITER backend to the cos/sin-cache RoPE entry points on ROCm (gfx942/gfx950), mirroring the existing `rmsnorm`/`fused_add_rmsnorm`/`silu_and_mul` AITER backends. `backend="auto"` (default) is shape-aware: it picks AITER only where it is both faster and precise enough, and stays on the native JIT kernel otherwise.

What changed
`flashinfer/rope.py`
- `apply_rope_with_cos_sin_cache` and `apply_rope_with_cos_sin_cache_inplace`: gain a `backend: str = "auto"` parameter. On HIP, `backend="aiter"` routes to AITER's `rope_cached_positions_2c` kernel; `"native"` keeps the existing JIT path; `"auto"` selects per the policy below. Unknown backends raise `ValueError`.
- `_auto_select_rope_backend(query, key)` (new): shape/dtype/device-aware selector (see policy table).
- `_apply_rope_cos_sin_cache_aiter` (new helper): adapts FlashInfer's tensor formats to AITER's:
  - splits the `(max_seq_len, rotary_dim)` float32 `cos||sin` cache into two separate `(max_seq_len, 1, 1, rotary_dim//2)` tables in the query dtype with `reuse_freqs_front_part=True`;
  - reshapes Q/K to AITER's `(1, nnz, num_heads, head_dim)` layout and rotates only the leading `rotary_dim` slice (`nope_first=False`), so partial-rotary (`rotary_dim < head_size`) works;
  - calls AITER's lower-level `rope_cached_positions_2c_fwd_impl` with distinct in/out tensors so the out-of-place path writes straight into the fresh output with no Q/K copy (AITER's `_inplace` wrapper is just `_impl(x,y,x,y,...)`). For partial rotary the untouched nope tail is copied across only when the output is a fresh tensor.
- `_aiter_rope_cos_sin` (new helper): memoizes the dtype-converted cos/sin tables, keyed by the cache tensor's `id()` with a `weakref.finalize` evictor (and a `weakref.ref` identity check on hit to guard against `id()` reuse) so they never outlive the source and the full-table conversion runs once.

`README.md`

`tests/rocm_tests/test_rope_aiter_hip.py` (new)
aiter, and non-contiguous positions are normalized (no process abort).

Benchmark results
Measured on gfx942 (MI300X), `q32/k8`, `hd128`, `rot128`, CUDA-event timed, warm JIT + warm cos/sin memo. Time is mean µs per call; ratio is native/aiter (>1 = AITER faster).

Out-of-place (`apply_rope_with_cos_sin_cache`): zero-copy via the `_impl` entry point makes AITER win across the board at large nnz (bf16 shown; the copy in the original `_inplace`-based approach had made these <1x):

Inplace (`apply_rope_with_cos_sin_cache_inplace`, bf16): crosses over around 1024–1536 tokens:

`max_seq` (4K→128K) has no measurable effect at fixed nnz: the cos/sin memoization removes the per-call table conversion, so cache size is irrelevant.

Architecture / design notes
`backend="auto"` policy (`_auto_select_rope_backend`):

(Policy table: four rows resolve to `native`, one row to `aiter`.)

Why the `_impl` entry point. AITER's public `_inplace`/`_fwd` wrappers force either out==in or an internal allocation. `_impl` (which both wrappers delegate to) exposes distinct in/out tensors, giving zero-copy out-of-place without reaching below AITER's stable kernel boundary.

Why only the cos/sin-cache functions. AITER's `rope_cached_*` API takes a precomputed cos/sin cache, matching `apply_rope_with_cos_sin_cache` 1:1. The other rope entry points compute frequencies on the fly and have no corresponding AITER cached kernel.

Guards on the explicit `backend="aiter"` path (the native path tolerates these; AITER does not):
- `aiter` import wrapped: raises `ValueError` if the package is missing.
- `rotary_dim` must be even and ≤ head_size.
- positions coerced to contiguous int64: a strided positions tensor otherwise trips a C `assert` that aborts the process.

Test plan
- `pytest tests/rocm_tests/test_rope_aiter_hip.py` (passed on gfx942)
- `pytest tests/rocm_tests/test_rope_hip.py` (native path unbroken; native ↔ aiter agree within tolerance)
- `pytest tests/rocm_tests/test_rope_hip.py tests/rocm_tests/test_rope_aiter_hip.py -n auto --reruns 2 -m "not slow"` (0 failures)
- `pre-commit run -a`