feat(hip,aiter): route rmsnorm auto backend to AITER for 2D inputs #256
Merged
backend="auto" now selects AITER's C++ rms_norm kernel for 2D inputs on gfx942/gfx950, falling back to native for 3D inputs or when AITER is unavailable, so auto never raises. Aligns rmsnorm auto behavior with the other C++-level AITER ops (fused_add_rmsnorm, silu_and_mul, rope). Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Pull request overview
Routes flashinfer.norm.rmsnorm(backend="auto") to prefer the ROCm AITER C++ rms_norm kernel for 2D inputs on supported AMD GPUs (gfx942/gfx950), aligning RMSNorm’s auto-backend behavior with other AITER-integrated ops.
Changes:
- Update RMSNorm auto-backend selection to choose AITER for 2D inputs when AITER is available; otherwise fall back to native.
- Update RMSNorm backend documentation in the Python docstring and README feature/support descriptions.
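The shape-gated routing described in the list above can be sketched in a few lines. This is an illustration only, not the code in `flashinfer/norm.py`: the `TensorLike` stand-in, the `select_norm_backend` name, and the explicit `arch` and `aiter_available` parameters are all hypothetical, chosen so the sketch runs without a GPU or PyTorch.

```python
from dataclasses import dataclass

# Architectures the PR names as supported by the AITER rms_norm path.
AITER_ARCHS = {"gfx942", "gfx950"}


@dataclass
class TensorLike:
    """Hypothetical stand-in for a tensor; only `ndim` matters here."""
    ndim: int


def select_norm_backend(x: TensorLike, arch: str, aiter_available: bool) -> str:
    """Hypothetical sketch of the auto routing: AITER for 2D, native otherwise.

    The function always returns a backend name, mirroring the PR's goal
    that "auto" never raises.
    """
    if aiter_available and arch in AITER_ARCHS and x.ndim == 2:
        return "aiter"
    # 3D inputs, unsupported architectures, or missing AITER all fall back.
    return "native"


print(select_norm_backend(TensorLike(ndim=2), "gfx942", True))
print(select_norm_backend(TensorLike(ndim=3), "gfx942", True))
print(select_norm_backend(TensorLike(ndim=2), "gfx942", False))
```

Keeping the fallback as the final unconditional `return` is what guarantees that every input resolves to some backend.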
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| flashinfer/norm.py | Changes the backend="auto" selection logic for RMSNorm to route 2D inputs to AITER when available. |
| README.md | Updates the feature matrix and AITER integration notes to reflect new RMSNorm auto-routing behavior. |
Address Copilot review on ROCm#256:
- Rewrite the stale test_rmsnorm_auto_backend_stays_native (it passed a torch.device, which now breaks on .ndim, and asserted the old native-only behavior) to verify 2D->aiter / 3D->native routing.
- Reword the README C++-integration note so it no longer implies rmsnorm auto is unconditionally AITER; it is now shape-gated to 2D inputs.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
test_norm validates the native kernel against a tight float32 reference. With rmsnorm backend="auto" now routing 2D inputs to AITER's lower-precision rms_norm, the default-backend call broke those tolerances (124 failures). Pin to backend="native"; AITER auto-path precision is covered separately in test_rmsnorm_aiter_hip.py. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
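To make the tolerance issue concrete, here is a minimal sketch of a reference RMSNorm and an absolute-tolerance check. It is not the repository's test code: `rmsnorm_reference` and `within_atol` are hypothetical helpers, Python floats (double precision) stand in for the float32 reference, the `eps` default is an assumption, and the tight tolerance of `1e-4` is made up for illustration. Only the `4e-3` figure comes from the PR description (the approximate fp16 atol for the AITER path).

```python
import math


def rmsnorm_reference(x, weight, eps=1e-6):
    """Row-wise RMSNorm: y = x / sqrt(mean(x^2) + eps) * weight."""
    out = []
    for row in x:
        mean_sq = sum(v * v for v in row) / len(row)
        inv_rms = 1.0 / math.sqrt(mean_sq + eps)
        out.append([v * inv_rms * w for v, w in zip(row, weight)])
    return out


def within_atol(candidate, reference, atol):
    """True if every element differs from the reference by at most atol."""
    return all(
        abs(c - r) <= atol
        for crow, rrow in zip(candidate, reference)
        for c, r in zip(crow, rrow)
    )


reference = rmsnorm_reference([[3.0, 4.0]], [1.0, 1.0], eps=0.0)

# Simulate a lower-precision kernel by perturbing the reference by 1e-3.
perturbed = [[v + 1e-3 for v in row] for row in reference]

print(within_atol(perturbed, reference, atol=4e-3))  # looser, AITER-style bound
print(within_atol(perturbed, reference, atol=1e-4))  # hypothetical tight bound
```

An output that is acceptable under the looser bound can still fail a tight reference check, which is why a test written for the native kernel needs its backend pinned.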
| **RoPE (positional encoding)** | ✅ `native` | ✅ | **AITER** for the cos/sin-cache path on gfx942/gfx950; else **HIP `native`** | LLaMA-style + LLaMA 3.1 scaling; fused RoPE + fp8 quant + paged-KV append (E4M3FNUZ, E5M2FNUZ). AITER backend covers `apply_rope_with_cos_sin_cache` and its inplace variant via AITER's C++ `rope_cached_positions_2c_fwd_impl` (linked at the C++ level, no runtime `import aiter`); cos/sin passed as float32 |
| **Paged KV-cache append** | ✅ `native` | ✅ | **AITER** when `fp16/bf16` + `NHD` + gfx942/gfx950 + AITER importable; else **HIP `native`** | `append_paged_kv_cache`; fp8 KV-cache supported on the HIP path |
| **RMSNorm** | ✅ `native` | ✅ | **HIP `native`** (auto stays on HIP — AITER is opt-in via `backend="aiter"`) | AITER path is fp16/bf16, 2-D only; slightly lower precision at `hidden_size >= 1024` |
| **RMSNorm** | ✅ `native` | ✅ | **AITER** for 2-D inputs on gfx942/gfx950; else **HIP `native`** (3-D inputs or AITER unavailable) | `rmsnorm`; AITER's C++ `rms_norm` (linked at the C++ level, no runtime `import aiter`); fp16/bf16, 2-D only, slightly lower precision at `hidden_size >= 1024` |
Comment on lines +330 to 334
FlashInfer's JIT compiles a small HIP shim that calls AITER's C++ kernels
(`rms_norm`, `rmsnorm2d_with_add`, `aiter::silu_and_mul`,
`rope_cached_positions_2c_fwd_impl`) directly and links a symbol-visible
AITER `.so` — there is no runtime `import aiter` on these paths. The first
JIT build of each op builds the corresponding AITER module once with
Comment on lines 49 to +53
device = torch.device("cuda:0")
assert _auto_select_norm_backend(device) == "native"
x2d = torch.randn(8, 128, dtype=torch.float16, device=device)
x3d = torch.randn(8, 4, 128, dtype=torch.float16, device=device)
assert _auto_select_norm_backend(x2d) == "aiter"
assert _auto_select_norm_backend(x3d) == "native"
Summary
Route the `rmsnorm` `backend="auto"` path to AITER's C++ `rms_norm` kernel for 2D inputs on gfx942/gfx950, instead of always staying on the in-tree HIP `native` kernel. This aligns `rmsnorm` auto behavior with the other C++-level AITER ops (`fused_add_rmsnorm`, `silu_and_mul`, `rope`).

What changed
- `flashinfer/norm.py`: `_auto_select_norm_backend` now takes the input tensor and returns `"aiter"` for 2D inputs when AITER is available, falling back to `"native"` for 3D inputs or when AITER is not installed, so `auto` never raises. Docstring updated to match.
- `README.md`: feature matrix RMSNorm row, the C++-integration paragraph, and the backend-specific exceptions list updated to reflect the new routing.
- `tests/rocm_tests/test_rmsnorm_aiter_hip.py`: rewrote the stale `test_rmsnorm_auto_backend_stays_native` (it passed a `torch.device` and asserted the old native-only behavior) to verify 2D→aiter / 3D→native.
- `tests/rocm_tests/test_norm_hip.py`: pinned `test_norm` to `backend="native"`; it validates the native kernel against a tight float32 reference and must not pick up AITER's lower-precision path via `auto`.

Architecture / design notes
| Case | `backend="auto"` resolves to |
|---|---|
| 2D input, AITER available | `aiter` (C++ `rms_norm`) |
| 2D input, AITER unavailable | `native` |
| 3D input | `native` (AITER rmsnorm is 2D-only) |

Note: AITER's `rms_norm` uses lower-precision reductions that exceed the native kernel's test tolerance at `hidden_size >= 1024` (fp16 atol ~4e-3, bf16 ~7e-2). The AITER-path accuracy is validated separately at AITER tolerances in `test_rmsnorm_aiter_hip.py`.

Test plan
- `pytest tests/rocm_tests/test_norm_hip.py tests/rocm_tests/test_rmsnorm_aiter_hip.py -n auto --reruns 2`: 766 passed
- `pre-commit run --files` on all changed files: clean