
feat(hip,aiter): add AITER backend to rope cos/sin-cache ops#252

Merged
demandal25 merged 6 commits into ROCm:amd-integration from demandal25:feat/rope-aiter-backend
Jun 15, 2026

Conversation

demandal25 commented Jun 15, 2026

Collaborator

Summary

Adds an AITER backend to the cos/sin-cache RoPE entry points on ROCm (gfx942/gfx950), mirroring the existing rmsnorm / fused_add_rmsnorm / silu_and_mul AITER backends. backend="auto" (default) is shape-aware: it picks AITER only where it is both faster and precise enough, and stays on the native JIT kernel otherwise.

What changed

flashinfer/rope.py

  • apply_rope_with_cos_sin_cache and apply_rope_with_cos_sin_cache_inplace gain a backend: str = "auto" parameter. On HIP, backend="aiter" routes to AITER's rope_cached_positions_2c kernel; "native" keeps the existing JIT path; "auto" selects per the policy below. Unknown backends raise ValueError.
  • _auto_select_rope_backend(query, key) (new): shape/dtype/device-aware selector (see policy table).
  • _apply_rope_cos_sin_cache_aiter (new helper): adapts FlashInfer's tensor formats to AITER's. It splits the (max_seq_len, rotary_dim) float32 cos||sin cache into two separate (max_seq_len, 1, 1, rotary_dim//2) tables in the query dtype, used with reuse_freqs_front_part=True. It reshapes Q/K to AITER's (1, nnz, num_heads, head_dim) layout and rotates only the leading rotary_dim slice (nope_first=False), so partial rotary (rotary_dim < head_size) works. It calls AITER's lower-level rope_cached_positions_2c_fwd_impl with distinct in/out tensors, so the out-of-place path writes straight into the fresh output with no Q/K copy (AITER's _inplace wrapper is just _impl(x,y,x,y,...)). For partial rotary, the untouched nope tail is copied across only when the output is a fresh tensor.
  • _aiter_rope_cos_sin (new helper): memoizes the dtype-converted cos/sin tables, keyed by the cache tensor's id(). A weakref.finalize evictor removes the entry when the source tensor dies, and a weakref.ref identity check on each hit guards against id() reuse, so the tables never outlive the source and the full-table conversion runs once.

README.md

  • Feature matrix RoPE row + AITER Support section document the auto policy.

tests/rocm_tests/test_rope_aiter_hip.py (new)

  • AITER-vs-native-reference across dtypes / shapes / NEOX styles (incl. partial rotary), inplace == non-inplace equivalence, the auto selector (incl. mixed-dtype fallback), unknown-backend raises, mixed-dtype / odd-rotary / oversized-rotary raise under explicit aiter, and non-contiguous positions are normalized (no process abort).

Benchmark results

Measured on gfx942 (MI300X), q32/k8, hd128, rot128, CUDA-event timed, warm JIT + warm cos/sin memo. Time is mean µs per call; ratio is native/aiter (>1 = AITER faster).

Out-of-place (apply_rope_with_cos_sin_cache, bf16): zero-copy via the _impl entry point makes AITER faster at every measured nnz, with the largest gains at large nnz. The copy in the original _inplace-based approach had made these ratios <1x.

| nnz | native (µs) | aiter (µs) | native/aiter |
|---|---|---|---|
| 8 (decode) | 17.0 | 13.1 | 1.29x |
| 2048 | 34.1 | 13.0 | 2.63x |
| 8192 | 93.8 | 39.6 | 2.37x |
| 32768 | 488.4 | 176.3 | 2.77x |

Inplace (apply_rope_with_cos_sin_cache_inplace, bf16) — crosses over around 1024–1536 tokens:

| nnz | native (µs) | aiter (µs) | native/aiter | winner |
|---|---|---|---|---|
| 8 (decode) | 7.2 | 10.8 | 0.67x | native |
| 1024 | 12.8 | 13.0 | 0.98x | ~tie |
| 2048 | 24.6 | 19.6 | 1.26x | AITER |
| 32768 | 364.6 | 220.9 | 1.65x | AITER |

max_seq (4K→128K) has no measurable effect at fixed nnz — the cos/sin memoization removes the per-call table conversion, so cache size is irrelevant.

Architecture / design notes

backend="auto" policy (_auto_select_rope_backend):

| Condition | Resolves to | Why |
|---|---|---|
| dtype != fp16 | native | bf16 AITER error ~5e-2 vs native ~3e-2 (tolerance edge); fp16 ~7e-3 is safe |
| query/key dtype mismatch | native | AITER rotates both with one cos/sin table; native handles mixed dtypes, and auto must not raise |
| nnz < 2048 | native | below crossover, native's lower launch overhead wins |
| aiter unsupported/unimportable | native | best-effort probe; auto never raises |
| else (fp16 + nnz≥2048 + supported) | aiter | measured ~1.3–2.8x, both inplace and out-of-place |
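
The policy table can be read as a chain of early returns. The sketch below is illustrative only and is not the body of _auto_select_rope_backend: `TensorInfo`, `select_rope_backend`, and `probe_aiter` are hypothetical names, and dtypes and architectures are modeled as plain strings so the example runs without a GPU.

```python
from dataclasses import dataclass
from typing import Callable

MIN_AITER_TOKENS = 2048                  # crossover from the policy table
SUPPORTED_ARCHS = ("gfx942", "gfx950")   # architectures named in this PR


@dataclass
class TensorInfo:
    """Hypothetical stand-in for the tensor metadata the selector reads."""
    dtype: str   # e.g. "fp16" or "bf16"
    nnz: int     # number of tokens
    arch: str    # GPU architecture name


def select_rope_backend(
    query: TensorInfo,
    key: TensorInfo,
    probe_aiter: Callable[[], None] = lambda: None,
) -> str:
    """Return "aiter" or "native" following the policy table; never raises."""
    if query.dtype != "fp16":
        return "native"      # bf16 sits at the edge of the test tolerance
    if key.dtype != query.dtype:
        return "native"      # AITER uses one cos/sin table for both tensors
    if query.nnz < MIN_AITER_TOKENS:
        return "native"      # below the measured crossover
    if query.arch not in SUPPORTED_ARCHS:
        return "native"
    try:
        probe_aiter()        # hypothetical import probe; any failure falls back
    except Exception:
        return "native"
    return "aiter"
```

Putting the cheap metadata checks before the import probe means the common decode case returns without touching the AITER package at all.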

Why the _impl entry point. AITER's public _inplace/_fwd wrappers force either out==in or an internal allocation. _impl (which both wrappers delegate to) exposes distinct in/out tensors, giving zero-copy out-of-place without reaching below AITER's stable kernel boundary.
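
The relationship between an _impl-style entry point and its in-place wrapper can be shown with a toy rotation on Python lists. This is a sketch under stated assumptions: `rope_impl`, `rope_inplace`, and `rope_out_of_place` are hypothetical names, and the signature does not match AITER's real rope_cached_positions_2c_fwd_impl.

```python
def rope_impl(x, y, out_x, out_y, cos, sin):
    """Toy kernel: rotate x and y (NEOX-style halves) into out_x and out_y."""
    half = len(cos)
    for src, dst in ((x, out_x), (y, out_y)):
        for i in range(half):
            # Read both elements before writing, so dst may alias src.
            a, b = src[i], src[i + half]
            dst[i] = a * cos[i] - b * sin[i]
            dst[i + half] = b * cos[i] + a * sin[i]


def rope_inplace(x, y, cos, sin):
    """In-place wrapper: the same kernel with outputs aliased to inputs."""
    rope_impl(x, y, x, y, cos, sin)


def rope_out_of_place(x, y, cos, sin):
    """Out-of-place: write into fresh buffers without copying x or y first."""
    out_x, out_y = [0.0] * len(x), [0.0] * len(y)
    rope_impl(x, y, out_x, out_y, cos, sin)
    return out_x, out_y
```

With only an in-place wrapper available, the out-of-place path would have to copy x and y into the outputs and then rotate them, which is the extra copy the PR removes.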

Why only the cos/sin-cache functions. AITER's rope_cached_* API takes a precomputed cos/sin cache, matching apply_rope_with_cos_sin_cache 1:1. The other rope entry points compute frequencies on the fly and have no corresponding AITER cached kernel.

Guards on the explicit backend="aiter" path (the native path tolerates these; AITER does not):

| Guard | Reason |
|---|---|
| device must be gfx942/gfx950 | otherwise a cryptic import/JIT failure |
| aiter import wrapped | clear ValueError if the package is missing |
| query/key must share a dtype | AITER rotates both with one cos/sin table built in the query dtype |
| rotary_dim even and ≤ head_size | the cos |
| positions → contiguous int64 | a strided positions tensor trips a C assert that aborts the process |
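
A minimal sketch of the shape, dtype, and device guards as up-front validation follows. The function name `check_aiter_rope_args` and its string-based arguments are hypothetical; the real checks operate on tensors inside flashinfer/rope.py and may be ordered or worded differently.

```python
def check_aiter_rope_args(arch, query_dtype, key_dtype, rotary_dim, head_size):
    """Raise ValueError for inputs the explicit AITER path cannot handle."""
    if arch not in ("gfx942", "gfx950"):
        raise ValueError(f"AITER rope backend is not supported on {arch}")
    if query_dtype != key_dtype:
        raise ValueError(
            f"query and key must share a dtype, got {query_dtype} and {key_dtype}"
        )
    if rotary_dim % 2 != 0:
        # The cache's last dimension holds the cos and sin halves side by side.
        raise ValueError(f"rotary_dim must be even, got {rotary_dim}")
    if rotary_dim > head_size:
        raise ValueError(
            f"rotary_dim ({rotary_dim}) must not exceed head_size ({head_size})"
        )
    # The positions guard is a normalization rather than a check; with
    # PyTorch it could be written as positions.to(torch.int64).contiguous().
```

Failing early with ValueError keeps malformed inputs from reaching the kernel, where the PR reports the failure modes are a cryptic import/JIT error or a process abort.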

Test plan

  • pytest tests/rocm_tests/test_rope_aiter_hip.py (passed on gfx942)
  • pytest tests/rocm_tests/test_rope_hip.py (native path unbroken; native ↔ aiter agree within tolerance)
  • pytest tests/rocm_tests/test_rope_hip.py tests/rocm_tests/test_rope_aiter_hip.py -n auto --reruns 2 -m "not slow" (0 failures)
  • pre-commit run -a

Route apply_rope_with_cos_sin_cache and its inplace variant through AITER's
rope_cached_positions_2c kernel on ROCm via a new backend="aiter" opt-in,
mirroring the rmsnorm/silu_and_mul AITER backend pattern.

The helper adapts FlashInfer's formats to AITER's: splits the (max_seq_len,
rotary_dim) float32 cos||sin cache into two (max_seq_len, 1, 1, rotary_dim//2)
tables in the query dtype (memoized per cache tensor to avoid re-converting the
full table every forward pass), reshapes Q/K to AITER's (1, nnz, heads, dim)
layout, and rotates only the leading rotary_dim slice.

backend="auto" stays on the native kernel: AITER consumes the cos/sin tables in
the query dtype rather than float32, raising bf16 max abs error to ~5e-2 (vs
native ~3e-2), at the edge of the rope test tolerance. "aiter" is explicit
opt-in.

Guards added to the AITER path: device must be gfx942/gfx950, query/key must
share a dtype (AITER rotates both with one cos/sin table), and positions are
coerced to contiguous int64 (a strided positions tensor otherwise trips a C
assert that aborts the process).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Copilot AI review requested due to automatic review settings June 15, 2026 19:27

Copilot AI left a comment


Pull request overview

Adds an opt-in AITER backend for the HIP/ROCm RoPE cos/sin-cache entry points, so users on gfx942/gfx950 can route apply_rope_with_cos_sin_cache* through AITER’s cached-kernel while keeping "auto" on the existing native JIT kernel for precision parity.

Changes:

  • Extend apply_rope_with_cos_sin_cache and _inplace with backend: str = "auto" and route backend="aiter" to a new HIP-only AITER adaptor.
  • Introduce helpers to reshape Q/K and convert/split the cos||sin cache into AITER’s expected table format (with memoization).
  • Add ROCm tests validating AITER vs native reference (including partial rotary), inplace semantics, backend selection, and positions normalization.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.

| File | Description |
|---|---|
| flashinfer/rope.py | Adds the backend parameter and implements the HIP AITER dispatch path + cos/sin table memoization. |
| tests/rocm_tests/test_rope_aiter_hip.py | New ROCm tests covering AITER backend correctness, backend selection behavior, and guardrails. |


Comment thread flashinfer/rope.py
Comment thread flashinfer/rope.py
Comment thread flashinfer/rope.py Outdated
Comment thread flashinfer/rope.py Outdated
demandal25 and others added 2 commits June 15, 2026 19:46
Make backend="auto" pick AITER for the rope cos/sin-cache path where it
is both faster and precise enough: inplace + fp16 + >= 2048 tokens on
gfx942/gfx950. Everything else stays native.

Rationale from gfx942 benchmarks (bf16/fp16, q32/k8, hd128):
- Out-of-place AITER never wins (0.6-0.9x): its kernel is in-place only,
  so the wrapper must copy Q/K first, erasing the throughput gain.
- Inplace AITER crosses over around 1024-1536 tokens and reaches ~1.65x
  at 32K; native's lower launch overhead wins below ~2048.
- Decode (small nnz) is firmly native territory (~0.6x).
- bf16 stays native on precision grounds (~5e-2 vs native ~3e-2); fp16
  AITER error (~7e-3) is comfortably inside tolerance.

_auto_select_rope_backend now takes an `inplace` flag (out-of-place can
never benefit) and gates on dtype/nnz/device. README feature matrix and
AITER Support section document the policy.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…ckend

Address Copilot review on PR ROCm#252:
- Validate that cos_sin_cache's last dim (rotary_dim) is even (cos||sin)
  and does not exceed head_size before splitting/slicing, raising a clear
  ValueError instead of letting mismatched shapes reach the AITER kernel.
- Wrap the _aiter_rope_ops() import in the explicit backend="aiter" path
  so a missing/broken aiter package surfaces as a helpful ValueError
  rather than a bare ModuleNotFoundError, matching activation.py.

Add tests for both new guards.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Copilot AI review requested due to automatic review settings June 15, 2026 21:03

Copilot AI left a comment


Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

Comment thread flashinfer/rope.py
Comment thread flashinfer/rope.py
Comment thread tests/rocm_tests/test_rope_aiter_hip.py Outdated
demandal25 and others added 2 commits June 15, 2026 22:58
Switch the AITER rope dispatch from rope_cached_positions_2c_fwd_inplace
to the lower-level rope_cached_positions_2c_fwd_impl with distinct in/out
tensors. AITER's _inplace wrapper is just _impl(x,y,x,y,...), so using
_impl directly lets the out-of-place path write straight into the fresh
output with no Q/K copy (measured ~1.3-2.8x faster than copy-then-rotate;
inplace is unchanged). For partial rotary the untouched nope tail is
copied across only when the output is a fresh tensor.

Because out-of-place no longer pays a copy, backend="auto" now routes
both the inplace and out-of-place wrappers to AITER for fp16 + >=2048
tokens (previously inplace-only). _auto_select_rope_backend drops its
inplace arg accordingly.

Robustness (from code review):
- auto falls back to native on mismatched q/k dtype instead of routing
  into the helper's dtype guard and raising — auto must never raise.
- memoized cos/sin entry carries a weakref to its source tensor; a cache
  hit is only honored when it still resolves to the same tensor, guarding
  against id() reuse for transient caches.
- memoized cos/sin tables are made contiguous before reaching the kernel.

Update README perf/auto description and tests accordingly.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Address Copilot review on PR ROCm#252: add a unit test that monkeypatches
flashinfer.rope._aiter_rope_ops to raise and asserts
_auto_select_rope_backend(query, key) returns "native" — verifying the
best-effort import probe falls back gracefully on a supported arch with a
missing/broken aiter install.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Copilot AI review requested due to automatic review settings June 15, 2026 23:10

Copilot AI left a comment


Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

Comment thread tests/rocm_tests/test_rope_aiter_hip.py Outdated
Address Copilot review on PR ROCm#252: switch the module pytestmark from a
bare is_aiter_supported (arch-only) check to tests.test_helpers.requires_aiter,
which skips when the arch is unsupported OR the aiter package is not
importable. On a supported GPU without aiter installed the old guard let
the tests run and then fail; this matches the other test_*_aiter_hip.py
modules.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@demandal25 demandal25 merged commit 40fb6a4 into ROCm:amd-integration Jun 15, 2026
1 check passed

2 participants