feat(hip,aiter): use CK rmsnorm2d for the plain rmsnorm AITER backend #258
Merged
demandal25 merged 2 commits on Jun 16, 2026
Switch the standalone rmsnorm AITER shim from AITER's vLLM-style rms_norm symbol to its CK rmsnorm2d symbol (the rmsnorm2d_fwd entry point), aligning it with the fused_add_rmsnorm path. The shim reshapes weight to [1, n] as CK requires; the Python API and JIT wiring are unchanged.

CK rmsnorm2d only accepts fp16/bf16 (the old rms_norm dispatch also accepted fp32), so gate backend="auto" to 2D fp16/bf16 (fp32 now routes to native instead of hitting a CK-specific error) and reject fp32 with a clear Python error under backend="aiter". Note that fp32 rmsnorm is unsupported by the native ROCm kernel too, so this only changes which error surfaces.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Pull request overview
This PR updates the ROCm AITER RMSNorm backend to call AITER’s CK rmsnorm2d (via the rmsnorm2d_fwd entry point) instead of the older rms_norm symbol, aligning the standalone RMSNorm path with the existing fused-add RMSNorm path that already uses CK kernels. It also adjusts Python-side backend selection/validation and updates tests/docs accordingly.
Changes:
- Switch AITER plain `rmsnorm` C++ entry point to reshape `weight` to `[1, n]` and call CK `rmsnorm2d`.
- Tighten `backend="auto"` selection to 2D fp16/bf16 (and add a clearer error for `backend="aiter"` with unsupported dtypes).
- Add ROCm tests for fp32 routing/rejection and update README references to the CK kernel name.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `flashinfer/csrc_rocm/norm_aiter.cu` | Routes plain AITER RMSNorm through CK `rmsnorm2d` with required weight reshaping. |
| `flashinfer/norm.py` | Updates auto-backend gating and adds Python-level dtype validation for `backend="aiter"`. |
| `tests/rocm_tests/test_rmsnorm_aiter_hip.py` | Adds coverage for fp32 auto-routing and explicit AITER fp32 rejection; updates header text. |
| `README.md` | Updates documentation references from `rms_norm` to CK `rmsnorm2d`. |
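For orientation, the operation being rerouted scales each row by its reciprocal root mean square and then by a per-column weight. The sketch below is a plain-Python reference for that arithmetic, not the CK kernel. The names `as_row_matrix` and `rmsnorm_ref` are hypothetical, and the `[1, n]` weight reshape is modeled as a one-row nested list purely for illustration.

```python
import math


def as_row_matrix(weight):
    """Model the [n] -> [1, n] reshape as a one-row nested list (illustrative only)."""
    return [list(weight)]


def rmsnorm_ref(rows, weight2d, eps=1e-6):
    """Reference RMSNorm: y[i][j] = x[i][j] / sqrt(mean(x[i]^2) + eps) * w[0][j]."""
    if len(weight2d) != 1:
        raise ValueError("weight2d must have shape [1, n]")
    w = weight2d[0]
    out = []
    for row in rows:
        if len(row) != len(w):
            raise ValueError("row length must match weight length")
        mean_sq = sum(v * v for v in row) / len(row)
        inv_rms = 1.0 / math.sqrt(mean_sq + eps)
        out.append([v * inv_rms * wj for v, wj in zip(row, w)])
    return out


# Two rows of two elements, with a per-column weight of [1, 2].
x = [[3.0, 4.0], [1.0, 1.0]]
y = rmsnorm_ref(x, as_row_matrix([1.0, 2.0]), eps=0.0)
```

The single weight row is applied to every input row, which is why a `[1, n]` weight carries the same information as the original `[n]` vector.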
Address Copilot review on ROCm#258:
- CK rmsnorm2d derives a single dtype from input and reads weight with it, so a mismatched weight dtype silently produced NaN/garbage. Reject weight.dtype != input.dtype under backend="aiter", and have backend="auto" fall back to native on a mismatch (native handles fp16 input + fp32 weight fine).
- Update the README RMSNorm matrix row and backend-exception bullet to reflect the fp16/bf16 + matching-weight-dtype auto gating.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
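The gating rules from the two commits can be summarized in a small sketch. This is an illustration of the described behavior, not the code in `flashinfer/norm.py`: `TensorMeta`, `AITER_DTYPES`, and `select_rmsnorm_backend` are hypothetical names, dtypes are modeled as strings, and accepting `"native"` as an explicit backend value is an assumption.

```python
from collections import namedtuple

# Hypothetical stand-in for the tensor properties the gating inspects.
TensorMeta = namedtuple("TensorMeta", ["ndim", "dtype"])

# Dtypes the PR says CK rmsnorm2d accepts.
AITER_DTYPES = {"float16", "bfloat16"}


def select_rmsnorm_backend(inp, weight, backend="auto"):
    """Return "aiter" or "native" following the gating described in the PR."""
    if backend == "auto":
        # auto picks AITER only for 2D fp16/bf16 input with a matching weight dtype.
        aiter_ok = (
            inp.ndim == 2
            and inp.dtype in AITER_DTYPES
            and weight.dtype == inp.dtype
        )
        return "aiter" if aiter_ok else "native"
    if backend == "aiter":
        # Explicit AITER fails early in Python instead of inside the kernel.
        # The PR does not describe non-2D handling here, so it is left out.
        if inp.dtype not in AITER_DTYPES:
            raise ValueError(f"AITER rmsnorm needs fp16/bf16 input, got {inp.dtype}")
        if weight.dtype != inp.dtype:
            raise ValueError("AITER rmsnorm needs weight.dtype == input.dtype")
        return "aiter"
    if backend == "native":  # assumed to be a valid explicit choice
        return "native"
    raise ValueError(f"unknown backend: {backend!r}")


# fp16 input with an fp32 weight falls back to native under auto.
choice = select_rmsnorm_backend(TensorMeta(2, "float16"), TensorMeta(1, "float32"))
```

Keeping the checks in Python means the mismatched-dtype case surfaces as an error or a fallback rather than as silent NaN output.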
Summary
Switch the standalone `rmsnorm` AITER backend from AITER's vLLM-style `rms_norm` C++ symbol to its CK `rmsnorm2d` symbol (the `rmsnorm2d_fwd` entry point), aligning it with the `fused_add_rmsnorm` path, which already uses the CK `rmsnorm2d_with_add` kernel. The Python API and JIT wiring are unchanged.

What changed
Kernel
- `flashinfer/csrc_rocm/norm_aiter.cu`: `rmsnorm_aiter` now forward-declares and calls CK `rmsnorm2d(out, input, weight2d, eps, 0)` instead of `rms_norm`. It reshapes `weight` to `[1, n]` as CK requires (same as the fused-add path).

Dtype gating (regression fix)
- `flashinfer/norm.py`: CK `rmsnorm2d` only accepts fp16/bf16, whereas the old `rms_norm` dispatch also accepted fp32. `backend="auto"` is now gated to 2D fp16/bf16 inputs so fp32 routes to native instead of hitting a CK-specific error, and `backend="aiter"` rejects fp32 with a clear Python-level `ValueError`. (fp32 rmsnorm is unsupported by the native ROCm kernel too, so this only changes which error surfaces; it does not remove working functionality.)

Tests / docs
- `tests/rocm_tests/test_rmsnorm_aiter_hip.py`: added fp32 coverage (auto→native selection, explicit-aiter rejection); updated kernel name in the header.
- `README.md`: feature matrix, C++-integration note, and backend-exceptions bullet now reference CK `rmsnorm2d`; auto criteria note fp16/bf16.

Notes on correctness verification
- `rmsnorm2d` passes `input.stride(0)` / `out.stride(0)` to the kernel, so sliced/non-contiguous inputs and outputs produce correct results (verified: max abs err 0.0, untouched regions preserved). This makes the path more stride-aware than the old `rms_norm` kernel, which assumed a packed layout.

Test plan
- `pytest tests/rocm_tests/test_norm_hip.py tests/rocm_tests/test_rmsnorm_aiter_hip.py -n auto --reruns 2`: 767 passed
- `rmsnorm2d` symbol (`nm`)
- `pre-commit run --files` on all changed files: clean
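The stride note above can be illustrated with a small model of row-strided access. This is a plain-Python sketch under stated assumptions, not the CK kernel or the PR's test: `rmsnorm_strided` is a hypothetical name, tensors are modeled as flat lists, and only the row stride is considered, mirroring the `stride(0)` arguments mentioned in the notes.

```python
import math


def rmsnorm_strided(out, out_stride, inp, in_stride, weight, m, n, eps=1e-6):
    """Normalize m rows of n elements; row i starts at i * stride in each flat buffer."""
    for i in range(m):
        src = i * in_stride
        dst = i * out_stride
        row = inp[src:src + n]
        inv_rms = 1.0 / math.sqrt(sum(v * v for v in row) / n + eps)
        for j in range(n):
            out[dst + j] = row[j] * inv_rms * weight[j]


# A [2, 4] buffer viewed as its first two columns: n = 2, row stride = 4.
inp = [3.0, 4.0, 9.0, 9.0,
       1.0, 1.0, 9.0, 9.0]
out = [-1.0] * 8  # sentinel marks elements the routine must not touch
rmsnorm_strided(out, 4, inp, 4, [1.0, 1.0], m=2, n=2, eps=0.0)
```

Because each row is addressed through its stride, the padding columns of `inp` are never read and the sentinel values in `out` outside the view remain in place, which is the "untouched regions preserved" property checked in the notes.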