
feat(hip,aiter): use CK rmsnorm2d for the plain rmsnorm AITER backend #258

Merged
demandal25 merged 2 commits into ROCm:amd-integration from demandal25:rmsnorm-change-aiter-api on Jun 16, 2026


demandal25 (Collaborator) commented Jun 16, 2026

Summary

Switch the standalone rmsnorm AITER backend from AITER's vLLM-style rms_norm C++ symbol to its CK rmsnorm2d symbol (the rmsnorm2d_fwd entry point), aligning it with the fused_add_rmsnorm path which already uses the CK rmsnorm2d_with_add kernel. The Python API and JIT wiring are unchanged.

What changed

Kernel

  • flashinfer/csrc_rocm/norm_aiter.cu: rmsnorm_aiter now forward-declares and calls CK rmsnorm2d(out, input, weight2d, eps, 0) instead of rms_norm. It reshapes weight to [1, n] as CK requires (same as the fused-add path).
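
For readers unfamiliar with the [1, n] weight layout, here is a minimal pure-Python reference of RMSNorm over the last dimension. It is only a sketch of the math, not the CK kernel or the shim: `rmsnorm_rows` and `reshape_weight` are hypothetical helper names, and plain lists stand in for tensors.

```python
import math


def reshape_weight(weight):
    """Turn a 1D weight of length n into the [1, n] layout (hypothetical helper)."""
    return [list(weight)]


def rmsnorm_rows(x, weight2d, eps=1e-6):
    """Reference RMSNorm over the last dimension of a 2D list of lists.

    weight2d has shape [1, n]; its single row is applied to every input row.
    """
    assert len(weight2d) == 1, "weight is expected in [1, n] layout"
    w = weight2d[0]
    out = []
    for row in x:
        assert len(row) == len(w), "hidden size mismatch"
        # y = x / sqrt(mean(x^2) + eps) * w
        inv_rms = 1.0 / math.sqrt(sum(v * v for v in row) / len(row) + eps)
        out.append([v * inv_rms * wi for v, wi in zip(row, w)])
    return out


x = [[3.0, 4.0], [1.0, 1.0]]
y = rmsnorm_rows(x, reshape_weight([1.0, 2.0]), eps=0.0)
```

With `eps=0.0`, the second row `[1.0, 1.0]` already has unit RMS, so its output is just the weight row.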

Dtype gating (regression fix)

  • flashinfer/norm.py — CK rmsnorm2d only accepts fp16/bf16, whereas the old rms_norm dispatch also accepted fp32. backend="auto" is now gated to 2D fp16/bf16 inputs so fp32 routes to native instead of hitting a CK-specific error, and backend="aiter" rejects fp32 with a clear Python-level ValueError. (fp32 rmsnorm is unsupported by the native ROCm kernel too, so this only changes which error surfaces — it does not remove working functionality.)
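
The gating described above can be sketched roughly as follows. This is an illustration of the rule as stated in this description, not the code in flashinfer/norm.py: `TensorInfo`, `select_backend`, and the string dtype names are hypothetical stand-ins, and the error message text is made up.

```python
from dataclasses import dataclass

# Dtypes this PR says CK rmsnorm2d accepts (names here are illustrative strings).
AITER_DTYPES = {"float16", "bfloat16"}


@dataclass
class TensorInfo:
    """Hypothetical stand-in for the tensor metadata the gating inspects."""
    ndim: int
    dtype: str


def select_backend(x: TensorInfo, backend: str = "auto") -> str:
    """Return the backend that would run, or raise for an unsupported request."""
    if backend == "aiter":
        # Explicit AITER request: fail early at the Python level.
        if x.dtype not in AITER_DTYPES:
            raise ValueError(
                f"backend='aiter' supports only fp16/bf16 inputs, got {x.dtype}"
            )
        return "aiter"
    if backend == "auto":
        # Auto picks AITER only for 2D fp16/bf16 inputs; everything else goes native.
        if x.ndim == 2 and x.dtype in AITER_DTYPES:
            return "aiter"
        return "native"
    if backend == "native":
        return "native"
    raise ValueError(f"unknown backend: {backend!r}")
```

Under this sketch, `select_backend(TensorInfo(2, "float32"))` falls through to native, while passing `backend="aiter"` for the same input raises `ValueError`.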

Tests / docs

  • tests/rocm_tests/test_rmsnorm_aiter_hip.py — added fp32 coverage (auto→native selection, explicit-aiter rejection); updated kernel name in the header.
  • README.md — feature matrix, C++-integration note, and backend-exceptions bullet now reference CK rmsnorm2d; auto criteria note fp16/bf16.

Notes on correctness verification

  • Non-contiguous tensors are handled correctly. CK rmsnorm2d passes input.stride(0) / out.stride(0) to the kernel, so sliced/non-contiguous inputs and outputs produce correct results (verified: max abs err 0.0, untouched regions preserved). This is actually more stride-aware than the old packed-layout rms_norm kernel.
  • Precision is unchanged from the old path (fp16 ~3.9e-3, bf16 ~6.25e-2 vs an fp32 reference), so the existing AITER tolerances stay.
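
To make the stride point concrete, the sketch below runs a reference RMSNorm over rows stored in flat buffers with an explicit row stride, the way a sliced 2D view exposes stride(0). It is an assumption-laden illustration in pure Python, not the CK kernel: `rmsnorm_strided` and the buffer layout are hypothetical.

```python
import math


def rmsnorm_strided(src, dst, rows, n, src_stride, dst_stride, weight, eps=1e-6):
    """RMSNorm on `rows` rows of length n held in flat buffers with row strides.

    Only the first n elements of each strided row are read or written, so
    padding between rows in dst is left untouched.
    """
    for r in range(rows):
        s, d = r * src_stride, r * dst_stride
        row = src[s:s + n]
        inv_rms = 1.0 / math.sqrt(sum(v * v for v in row) / n + eps)
        for j in range(n):
            dst[d + j] = row[j] * inv_rms * weight[j]


# Two rows of width 2 sliced out of parent buffers whose rows are 4 wide.
src = [1.0, 1.0, 9.0, 9.0,
       2.0, 2.0, 9.0, 9.0]
dst = [-1.0] * 8  # sentinel values mark regions that must stay untouched
rmsnorm_strided(src, dst, rows=2, n=2, src_stride=4, dst_stride=4,
                weight=[1.0, 1.0], eps=0.0)
```

After the call, the sentinel `-1.0` values outside the sliced columns of `dst` are unchanged, which is the "untouched regions preserved" property the verification above checks for.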

Test plan

  • pytest tests/rocm_tests/test_norm_hip.py tests/rocm_tests/test_rmsnorm_aiter_hip.py -n auto --reruns 2 — 767 passed
  • Verified the rebuilt JIT shim links the CK rmsnorm2d symbol (nm)
  • pre-commit run --files on all changed files — clean

Switch the standalone rmsnorm AITER shim from AITER's vLLM-style rms_norm
symbol to its CK rmsnorm2d symbol (the rmsnorm2d_fwd entry point), aligning
it with the fused_add_rmsnorm path. The shim reshapes weight to [1, n] as CK
requires; the Python API and JIT wiring are unchanged.

CK rmsnorm2d only accepts fp16/bf16 (the old rms_norm dispatch also accepted
fp32), so gate backend="auto" to 2D fp16/bf16 — fp32 now routes to native
instead of hitting a CK-specific error — and reject fp32 with a clear Python
error under backend="aiter". Note fp32 rmsnorm is unsupported by the native
ROCm kernel too, so this only changes which error surfaces.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Copilot AI review requested due to automatic review settings June 16, 2026 19:25

Copilot AI left a comment


Pull request overview

This PR updates the ROCm AITER RMSNorm backend to call AITER’s CK rmsnorm2d (via the rmsnorm2d_fwd entry point) instead of the older rms_norm symbol, aligning the standalone RMSNorm path with the existing fused-add RMSNorm path that already uses CK kernels. It also adjusts Python-side backend selection/validation and updates tests/docs accordingly.

Changes:

  • Switch AITER plain rmsnorm C++ entry point to reshape weight to [1, n] and call CK rmsnorm2d.
  • Tighten backend="auto" selection to 2D fp16/bf16 (and add a clearer error for backend="aiter" with unsupported dtypes).
  • Add ROCm tests for fp32 routing/rejection and update README references to the CK kernel name.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| flashinfer/csrc_rocm/norm_aiter.cu | Routes plain AITER RMSNorm through CK rmsnorm2d with required weight reshaping. |
| flashinfer/norm.py | Updates auto-backend gating and adds Python-level dtype validation for backend="aiter". |
| tests/rocm_tests/test_rmsnorm_aiter_hip.py | Adds coverage for fp32 auto-routing and explicit AITER fp32 rejection; updates header text. |
| README.md | Updates documentation references from rms_norm to CK rmsnorm2d. |


Comment thread flashinfer/norm.py
Comment thread README.md Outdated
Comment thread README.md Outdated
Address Copilot review on ROCm#258:
- CK rmsnorm2d derives a single dtype from input and reads weight with it, so a
  mismatched weight dtype silently produced NaN/garbage. Reject weight.dtype !=
  input.dtype under backend="aiter", and have backend="auto" fall back to native
  on a mismatch (native handles fp16 input + fp32 weight fine).
- Update the README RMSNorm matrix row and backend-exception bullet to reflect
  the fp16/bf16 + matching-weight-dtype auto gating.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
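
As a general illustration of why reading a buffer with the wrong element type yields garbage, the snippet below reinterprets the bytes of fp32 values as fp16 using Python's `struct` module. It shows only the byte-level effect under a little-endian layout; it makes no claim about how CK actually indexes the weight buffer, and `reinterpret_fp32_as_fp16` is a hypothetical helper.

```python
import struct


def reinterpret_fp32_as_fp16(values):
    """Pack floats as little-endian fp32, then read the same bytes back as fp16."""
    raw = struct.pack(f"<{len(values)}f", *values)
    # Each 4-byte fp32 value is seen as two unrelated 2-byte fp16 values.
    return list(struct.unpack(f"<{len(raw) // 2}e", raw))


# fp32 1.0 is the bytes 00 00 80 3f; as fp16 those are 0x0000 and 0x3f80.
halves = reinterpret_fp32_as_fp16([1.0])
```

A weight of 1.0 stored as fp32 does not read back as 1.0 in fp16, which is why rejecting a dtype mismatch at the Python level is safer than letting it reach the kernel.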
demandal25 merged commit 06a7a30 into ROCm:amd-integration on Jun 16, 2026
1 check passed