Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ kernel for non-attention ops). **AITER** = ROCm AITER backend.
| **POD attention** | ✅ `fa2` | — | HIP | MHA / GQA / MQA; single + batch variants (`PODWithPagedKVCacheWrapper`, `BatchPODWithPagedKVCacheWrapper`); JIT-only (excluded from AOT, same as upstream CUDA) |
| **RoPE (positional encoding)** | ✅ `native` | ✅ | **AITER** for the cos/sin-cache path on gfx942/gfx950; else **HIP `native`** | LLaMA-style + LLaMA 3.1 scaling; fused RoPE + fp8 quant + paged-KV append (E4M3FNUZ, E5M2FNUZ). AITER backend covers `apply_rope_with_cos_sin_cache` and its inplace variant via AITER's C++ `rope_cached_positions_2c_fwd_impl` (linked at the C++ level, no runtime `import aiter`); cos/sin passed as float32 |
| **Paged KV-cache append** | ✅ `native` | ✅ | **AITER** when `fp16/bf16` + `NHD` + gfx942/gfx950 + AITER importable; else **HIP `native`** | `append_paged_kv_cache`; fp8 KV-cache supported on the HIP path |
| **RMSNorm** | ✅ `native` | ✅ | **HIP `native`** (auto stays on HIP — AITER is opt-in via `backend="aiter"`) | AITER path is fp16/bf16, 2-D only; slightly lower precision at `hidden_size >= 1024` |
| **RMSNorm** | ✅ `native` | ✅ | **AITER** for 2-D inputs on gfx942/gfx950; else **HIP `native`** (3-D inputs or AITER unavailable) | `rmsnorm`; AITER's C++ `rms_norm` (linked at the C++ level, no runtime `import aiter`); fp16/bf16, 2-D only, slightly lower precision at `hidden_size >= 1024` |
| **Fused add RMSNorm** | ✅ `native` | ✅ | **AITER** on gfx942/gfx950; else **HIP `native`** | `fused_add_rmsnorm`; AITER's C++ CK `rmsnorm2d_with_add` (linked at the C++ level, no runtime `import aiter`); 2-D only, slightly lower precision at `hidden_size >= 1024` |
| **LayerNorm / Gemma RMSNorm** | ✅ | — | HIP | |
| **Sampling** | ✅ | — | HIP | Top-K / Top-P / Min-P / OnlineSoftmax / SamplingFromLogits |
Expand Down Expand Up @@ -325,23 +325,25 @@ prefill/decode), `backend="native"` for non-attention ops
(`append_paged_kv_cache`, `rmsnorm`, `fused_add_rmsnorm`,
`silu_and_mul`, `rope`).

The `fused_add_rmsnorm`, `silu_and_mul`, and `rope` (cos/sin-cache) AITER
backends are integrated at the **C++ level**: FlashInfer's JIT compiles a
small HIP shim that calls AITER's C++ kernels
(`rmsnorm2d_with_add`, `aiter::silu_and_mul`,
The `rmsnorm`, `fused_add_rmsnorm`, `silu_and_mul`, and `rope`
(cos/sin-cache) AITER backends are integrated at the **C++ level**:
FlashInfer's JIT compiles a small HIP shim that calls AITER's C++ kernels
(`rms_norm`, `rmsnorm2d_with_add`, `aiter::silu_and_mul`,
`rope_cached_positions_2c_fwd_impl`) directly and links a symbol-visible
AITER `.so` — there is no runtime `import aiter` on these paths. The first
JIT build of each op builds the corresponding AITER module once with
Comment on lines +330 to 334
`AITER_SYMBOL_VISIBLE=1` and caches it under
`~/.cache/flashinfer/aiter_libs/` (the CK `module_rmsnorm` build is large
and can take many minutes the first time). For these three ops,
and can take many minutes the first time). For these ops,
`backend="auto"` resolves to AITER on gfx942/gfx950 and to HIP `native`
elsewhere; a later performance pass may re-introduce shape-based gating.
elsewhere, subject to op-specific constraints (e.g. `rmsnorm` auto only
routes 2-D inputs to AITER); a later performance pass may extend this
shape-based gating to the other ops.

Backend-specific exceptions to "auto picks AITER when supported":

* `rmsnorm`: `backend="auto"` stays on the HIP `native` kernel; the
AITER C++ path (`rms_norm`) is opt-in via `backend="aiter"`.
* `rmsnorm`: `backend="auto"` picks the AITER C++ path (`rms_norm`) only
for 2-D inputs; 3-D inputs fall back to the HIP `native` kernel.
* `batch_decode`: `use_cuda_graph=True` or `use_tensor_cores=True`
force `auto` back to `fa2` (AITER decode does not support either),
and `pos_encoding_mode != "NONE"` raises under `backend="aiter"`.
Expand Down
21 changes: 13 additions & 8 deletions flashinfer/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,16 @@ def get_norm_aiter_module():

return gen_norm_aiter_module().build_and_load()

def _auto_select_norm_backend(device: torch.device) -> str:
# auto routes plain rmsnorm to native: AITER's rms_norm uses lower-precision
# reductions that exceed the flashinfer test tolerance at hidden_size >= 1024.
# Pass backend="aiter" to opt in explicitly.
def _auto_select_norm_backend(input: torch.Tensor) -> str:
# auto routes plain rmsnorm to the C++ AITER kernel on supported devices
# (2D inputs only) and falls back to native everywhere else (3D inputs, or
# when AITER is not installed, so auto never raises). Note: AITER's rms_norm
# uses lower-precision reductions that exceed the flashinfer test tolerance
# at hidden_size >= 1024 (fp16 atol ~4e-3, bf16 ~7e-2).
from .aiter_utils import is_aiter_available

if input.ndim == 2 and is_aiter_available(input.device):
return "aiter"
return "native"

def _auto_select_fused_add_rmsnorm_backend(input: torch.Tensor) -> str:
Expand Down Expand Up @@ -79,7 +85,8 @@ def rmsnorm(
Whether to enable `programmatic dependent launch
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
backend: str
Kernel backend to use. ``"auto"`` (default) selects the best available backend.
Kernel backend to use. ``"auto"`` (default) selects the best available backend:
the AITER C++ kernel for 2D inputs on supported ROCm devices, else native.
``"native"`` uses the FlashInfer JIT kernel on all platforms.
``"aiter"`` uses AMD AITER's ``rms_norm`` C++ kernel — ROCm (gfx942/gfx950)
only, 2D inputs only. Precision is slightly lower than ``"native"`` at
Expand All @@ -91,9 +98,7 @@ def rmsnorm(
Normalized tensor, 2D shape (batch_size, hidden_size) or 3D shape (batch_size, num_heads, hidden_size).
"""
if IS_HIP:
_backend = (
backend if backend != "auto" else _auto_select_norm_backend(input.device)
)
_backend = backend if backend != "auto" else _auto_select_norm_backend(input)
if _backend == "aiter":
Comment thread
demandal25 marked this conversation as resolved.
from .aiter_utils import require_aiter

Expand Down
7 changes: 5 additions & 2 deletions tests/rocm_tests/test_norm_hip.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,14 @@ def test_norm(batch_size, hidden_size, dtype, specify_out, enable_pdl, contiguou
w = torch.randn(hidden_size).to(0).to(dtype)

y_ref = llama_rms_norm(x, w)
# Pin to native: this test checks the native kernel against a tight float32
# reference. On ROCm, backend="auto" routes 2D inputs to AITER's lower-precision
# rms_norm, whose accuracy is validated separately in test_rmsnorm_aiter_hip.py.
if specify_out:
y = torch.empty_like(x)
flashinfer.norm.rmsnorm(x, w, out=y, enable_pdl=enable_pdl)
flashinfer.norm.rmsnorm(x, w, out=y, enable_pdl=enable_pdl, backend="native")
else:
y = flashinfer.norm.rmsnorm(x, w, enable_pdl=enable_pdl)
y = flashinfer.norm.rmsnorm(x, w, enable_pdl=enable_pdl, backend="native")

rtol, atol = (1.6e-2, 1.6e-2) if dtype == torch.bfloat16 else (1e-3, 1e-3)
torch.testing.assert_close(y_ref, y, rtol=rtol, atol=atol)
Expand Down
9 changes: 6 additions & 3 deletions tests/rocm_tests/test_rmsnorm_aiter_hip.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,15 @@ def test_rmsnorm_aiter_vs_ref(dtype, hidden_size, batch_size):


@requires_aiter
def test_rmsnorm_auto_backend_stays_native():
"""auto backend on gfx942/950 should stay on native kernel (precision parity with tests)."""
def test_rmsnorm_auto_backend_selects_aiter_for_2d():
"""auto backend on gfx942/950 routes 2D inputs to AITER and 3D inputs to native."""
from flashinfer.norm import _auto_select_norm_backend

device = torch.device("cuda:0")
assert _auto_select_norm_backend(device) == "native"
x2d = torch.randn(8, 128, dtype=torch.float16, device=device)
x3d = torch.randn(8, 4, 128, dtype=torch.float16, device=device)
assert _auto_select_norm_backend(x2d) == "aiter"
assert _auto_select_norm_backend(x3d) == "native"
Comment on lines 49 to +53


@requires_aiter
Expand Down
Loading