From 532a3ad062daf93e7c6d5ed7aad12f6b629b6cc0 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Mon, 15 Jun 2026 18:11:29 +0000 Subject: [PATCH 1/6] feat(hip,aiter): add AITER backend for silu_and_mul Route flashinfer.activation.silu_and_mul through AMD AITER's silu_and_mul on ROCm via a backend="auto"|"native"|"aiter" parameter, mirroring the existing norm.py AITER-backend idiom. "auto" stays on the native JIT kernel except for large (>=64M element) 2D fp16 inputs, where AITER is ~5-10% faster and matches native precision. bf16 is excluded from the auto path (AITER max err ~6e-2 vs native ~4e-3); "aiter" remains available as an explicit opt-in. Adds tests/rocm_tests/test_activation_aiter_hip.py covering correctness across shapes/dtypes, out= handling, backend auto-selection, and the unknown-backend error. Co-Authored-By: Claude Opus 4.7 --- flashinfer/activation.py | 71 +++++++++++- tests/rocm_tests/test_activation_aiter_hip.py | 103 ++++++++++++++++++ 2 files changed, 170 insertions(+), 4 deletions(-) create mode 100644 tests/rocm_tests/test_activation_aiter_hip.py diff --git a/flashinfer/activation.py b/flashinfer/activation.py index 2813ae6ad9..d9febf693d 100644 --- a/flashinfer/activation.py +++ b/flashinfer/activation.py @@ -20,7 +20,7 @@ import torch -from .device_utils import IS_CUDA +from .device_utils import IS_CUDA, IS_HIP from .jit import gen_act_and_mul_module from .utils import ( device_support_pdl, @@ -33,6 +33,44 @@ from .fp4_quantization import get_fp4_quantization_module +if IS_HIP: + + @functools.cache + def _aiter_act_ops(): + import aiter as _aiter + + return _aiter + + # AITER's silu_and_mul only overtakes the native kernel on large, + # bandwidth-bound shapes (~5-10% faster); below that a fixed ~0.7us launch + # overhead makes it slower. It also matches the native kernel's precision + # only in fp16 (bf16 is ~6e-2 vs ~4e-3 max err). The cutoff counts elements + # of the full input (rows x 2*hidden); the measured break-even is ~33M input + # elements (e.g. 2048 x 16384), so 64M is a safe 2x margin. fp16 only. + _AITER_SILU_AND_MUL_MIN_ELEMS = 64 * 1024 * 1024 + + def _auto_select_silu_and_mul_backend(input: torch.Tensor) -> str: + # Cheapest guards first so the common small/medium case exits early. + if input.dtype != torch.float16: + return "native" + if input.ndim != 2: + return "native" + if input.numel() < _AITER_SILU_AND_MUL_MIN_ELEMS: + return "native" + from .aiter_utils import is_aiter_supported + + if not is_aiter_supported(input.device): + return "native" + try: + # Best-effort probe: a supported arch can still lack a usable aiter + # (not installed, or its compiled extension fails to load). auto must + # always be able to fall back to native, so catch any import failure. + _aiter_act_ops() + except Exception: + return "native" + return "aiter" + + @functools.cache def get_act_and_mul_module(act_func_name: str): module = gen_act_and_mul_module(act_func_name).build_and_load() @@ -70,7 +108,10 @@ def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None: def silu_and_mul( - input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None + input: torch.Tensor, + out: torch.Tensor = None, + enable_pdl: Optional[bool] = None, + backend: str = "auto", ) -> torch.Tensor: r"""Fused SiLU and Mul operation. @@ -88,13 +129,22 @@ def silu_and_mul( Whether to enable `programmatic dependent launch `_ + backend: str + Kernel backend to use. ``"auto"`` (default) uses the native kernel for small + and medium inputs, and switches to AITER on ROCm for large (>= 64M element) + 2D fp16 inputs where its kernel is faster and matches native precision; it + falls back to native whenever AITER is unavailable. + ``"native"`` uses the FlashInfer JIT kernel on all platforms. + ``"aiter"`` uses AMD AITER's ``silu_and_mul`` — ROCm (gfx942/gfx950) only; + requires the ``aiter`` package. Precision matches ``"native"`` in fp16 but is + lower in bf16 (max err ~6e-2 vs ~4e-3), which is why ``"auto"`` restricts the + AITER path to fp16. + Returns ------- output: torch.Tensor Output tensor, shape (..., hidden_size). """ - if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) if input.shape[-1] * input.dtype.itemsize % 16 != 0: raise ValueError("The pointers must be multiple of 16 bytes.") if out is not None: @@ -105,6 +155,19 @@ def silu_and_mul( device=input.device, dtype=input.dtype, ) + if IS_HIP: + _backend = ( + backend if backend != "auto" else _auto_select_silu_and_mul_backend(input) + ) + if _backend == "aiter": + _aiter_act_ops().silu_and_mul(out, input) + return out + if _backend != "native": + raise ValueError( + f"Unknown backend {backend!r}; expected one of 'auto', 'native', 'aiter'." + ) + if enable_pdl is None: + enable_pdl = device_support_pdl(input.device) get_act_and_mul_module("silu").silu_and_mul( out, input, diff --git a/tests/rocm_tests/test_activation_aiter_hip.py b/tests/rocm_tests/test_activation_aiter_hip.py new file mode 100644 index 0000000000..ca3aa99bdb --- /dev/null +++ b/tests/rocm_tests/test_activation_aiter_hip.py @@ -0,0 +1,103 @@ +# SPDX-FileCopyrightText: 2026 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Tests for the AITER silu_and_mul backend exposed via +# flashinfer.activation.silu_and_mul(backend="aiter"). +# +# Note on tolerances: AITER silu_and_mul matches the native JIT kernel exactly in +# fp16, but uses lower-precision arithmetic in bf16 (max err ~6e-2 vs the native +# kernel's ~4e-3). The tolerances below reflect AITER's actual precision. + +import pytest +import torch + +import flashinfer +from flashinfer.aiter_utils import is_aiter_supported + + +def _silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + x_f32 = x.float() + gate, up = x_f32[..., :d], x_f32[..., d:] + return (gate / (1.0 + torch.exp(-gate)) * up).to(x.dtype) + + +@pytest.mark.skipif( + not is_aiter_supported(torch.device("cuda:0")), + reason="AITER backend requires gfx942/gfx950", +) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("d", [128, 512, 4096, 8192, 14336]) +@pytest.mark.parametrize("num_tokens", [1, 8, 256]) +def test_silu_and_mul_aiter_vs_ref(dtype, d, num_tokens): + torch.manual_seed(0xA17E2) + device = torch.device("cuda:0") + x = torch.randn(num_tokens, 2 * d, dtype=dtype, device=device) + + ref = _silu_and_mul_ref(x) + got = flashinfer.activation.silu_and_mul(x, backend="aiter") + + # AITER precision: fp16 matches native; bf16 ~6e-2 observed across shapes. + rtol, atol = (7e-2, 7e-2) if dtype == torch.bfloat16 else (1e-3, 1e-3) + torch.testing.assert_close(got.float(), ref.float(), rtol=rtol, atol=atol) + + +@pytest.mark.skipif( + not is_aiter_supported(torch.device("cuda:0")), + reason="AITER backend requires gfx942/gfx950", +) +def test_silu_and_mul_auto_backend_selection(): + """auto stays native for small/bf16 inputs and picks AITER for large fp16 2D inputs.""" + from flashinfer.activation import ( + _AITER_SILU_AND_MUL_MIN_ELEMS, + _auto_select_silu_and_mul_backend, + ) + + device = torch.device("cuda:0") + + # Small inputs: native regardless of dtype. + for dtype in (torch.float16, torch.bfloat16): + small = torch.empty(8, 256, dtype=dtype, device=device) + assert _auto_select_silu_and_mul_backend(small) == "native" + + # Large enough fp16 2D input (>= cutoff). cols is a multiple of 8 so the + # shape also clears silu_and_mul's 16-byte alignment guard, i.e. it is a + # shape that could actually flow through the public function to AITER. + rows = 8192 + cols = -(-_AITER_SILU_AND_MUL_MIN_ELEMS // rows // 8) * 8 # round up to mult of 8 + large_fp16 = torch.empty(rows, cols, dtype=torch.float16, device=device) + assert large_fp16.numel() >= _AITER_SILU_AND_MUL_MIN_ELEMS + assert _auto_select_silu_and_mul_backend(large_fp16) == "aiter" + + # Same large shape in bf16: native (precision guard). + large_bf16 = torch.empty(rows, cols, dtype=torch.bfloat16, device=device) + assert _auto_select_silu_and_mul_backend(large_bf16) == "native" + + # Large fp16 but 3D: native (2D guard). + large_3d = torch.empty(2, rows, cols, dtype=torch.float16, device=device) + assert _auto_select_silu_and_mul_backend(large_3d) == "native" + + +@pytest.mark.skipif( + not is_aiter_supported(torch.device("cuda:0")), + reason="AITER backend requires gfx942/gfx950", +) +def test_silu_and_mul_aiter_with_out_tensor(): + """backend='aiter' respects the out= argument.""" + device = torch.device("cuda:0") + x = torch.randn(8, 256, dtype=torch.float16, device=device) + out = torch.empty(8, 128, dtype=torch.float16, device=device) + ret = flashinfer.activation.silu_and_mul(x, out=out, backend="aiter") + assert ret.data_ptr() == out.data_ptr() + assert not torch.all(out == 0) + + +@pytest.mark.skipif( + not is_aiter_supported(torch.device("cuda:0")), + reason="AITER backend requires gfx942/gfx950", +) +def test_silu_and_mul_unknown_backend_raises(): + device = torch.device("cuda:0") + x = torch.randn(8, 256, dtype=torch.float16, device=device) + with pytest.raises(ValueError, match="Unknown backend"): + flashinfer.activation.silu_and_mul(x, backend="nope") From 18706712b521eb23728eaf814d853b2076dd3061 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Mon, 15 Jun 2026 18:32:07 +0000 Subject: [PATCH 2/6] fix(activation): validate backend on all platforms; clarify test ceil-div Address Copilot review on PR #251: - Validate the backend argument unconditionally so an unknown value or an explicit backend="aiter" off ROCm/unsupported arch raises ValueError instead of silently falling through to the native kernel. - Use the clearer ceil-to-multiple-of-8 form in the auto-selection test. Co-Authored-By: Claude Opus 4.7 --- flashinfer/activation.py | 25 +++++++++++++------ tests/rocm_tests/test_activation_aiter_hip.py | 18 +++++++------ 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/flashinfer/activation.py b/flashinfer/activation.py index d9febf693d..c45ef5a93d 100644 --- a/flashinfer/activation.py +++ b/flashinfer/activation.py @@ -136,15 +136,30 @@ def silu_and_mul( falls back to native whenever AITER is unavailable. ``"native"`` uses the FlashInfer JIT kernel on all platforms. ``"aiter"`` uses AMD AITER's ``silu_and_mul`` — ROCm (gfx942/gfx950) only; - requires the ``aiter`` package. Precision matches ``"native"`` in fp16 but is - lower in bf16 (max err ~6e-2 vs ~4e-3), which is why ``"auto"`` restricts the - AITER path to fp16. + requires the ``aiter`` package, and raises ``ValueError`` on any other + platform. Precision matches ``"native"`` in fp16 but is lower in bf16 + (max err ~6e-2 vs ~4e-3), which is why ``"auto"`` restricts the AITER path + to fp16. Returns ------- output: torch.Tensor Output tensor, shape (..., hidden_size). """ + if backend not in ("auto", "native", "aiter"): + raise ValueError( + f"Unknown backend {backend!r}; expected one of 'auto', 'native', 'aiter'." + ) + if backend == "aiter": + # Validate the explicit opt-in on every platform so a misconfiguration + # surfaces here instead of silently running native off ROCm. + from .aiter_utils import is_aiter_supported + + if not (IS_HIP and is_aiter_supported(input.device)): + raise ValueError( + "backend='aiter' requires a ROCm gfx942/gfx950 device with the " + "aiter package installed." + ) if input.shape[-1] * input.dtype.itemsize % 16 != 0: raise ValueError("The pointers must be multiple of 16 bytes.") if out is not None: @@ -162,10 +177,6 @@ def silu_and_mul( if _backend == "aiter": _aiter_act_ops().silu_and_mul(out, input) return out - if _backend != "native": - raise ValueError( - f"Unknown backend {backend!r}; expected one of 'auto', 'native', 'aiter'." - ) if enable_pdl is None: enable_pdl = device_support_pdl(input.device) get_act_and_mul_module("silu").silu_and_mul( diff --git a/tests/rocm_tests/test_activation_aiter_hip.py b/tests/rocm_tests/test_activation_aiter_hip.py index ca3aa99bdb..499d072fb2 100644 --- a/tests/rocm_tests/test_activation_aiter_hip.py +++ b/tests/rocm_tests/test_activation_aiter_hip.py @@ -64,7 +64,7 @@ def test_silu_and_mul_auto_backend_selection(): # shape also clears silu_and_mul's 16-byte alignment guard, i.e. it is a # shape that could actually flow through the public function to AITER. rows = 8192 - cols = -(-_AITER_SILU_AND_MUL_MIN_ELEMS // rows // 8) * 8 # round up to mult of 8 + cols = -(-_AITER_SILU_AND_MUL_MIN_ELEMS // (rows * 8)) * 8 # ceil to mult of 8 large_fp16 = torch.empty(rows, cols, dtype=torch.float16, device=device) assert large_fp16.numel() >= _AITER_SILU_AND_MUL_MIN_ELEMS assert _auto_select_silu_and_mul_backend(large_fp16) == "aiter" @@ -92,12 +92,16 @@ def test_silu_and_mul_aiter_with_out_tensor(): assert not torch.all(out == 0) -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) def test_silu_and_mul_unknown_backend_raises(): - device = torch.device("cuda:0") - x = torch.randn(8, 256, dtype=torch.float16, device=device) + # Backend validation is platform-independent, so this needs no aiter device. + x = torch.randn(8, 256, dtype=torch.float16) with pytest.raises(ValueError, match="Unknown backend"): flashinfer.activation.silu_and_mul(x, backend="nope") + + +def test_silu_and_mul_aiter_backend_rejected_when_unsupported(): + """Explicit backend='aiter' raises (not silently falls back) on an unsupported device.""" + cpu_x = torch.randn(8, 256, dtype=torch.float16) + if not is_aiter_supported(cpu_x.device): + with pytest.raises(ValueError, match="requires a ROCm"): + flashinfer.activation.silu_and_mul(cpu_x, backend="aiter") From 094e8734c045834acf0c31eb0e707bba0fbfe9ec Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Mon, 15 Jun 2026 18:36:56 +0000 Subject: [PATCH 3/6] fix(activation): lower aiter silu_and_mul auto cutoff to 33M; document in README Set the auto-selection threshold to the measured ~33M-element break-even (was a conservative 64M). Update the README feature matrix and AITER Support section to list silu_and_mul's AITER backend and its auto-routing criteria. Co-Authored-By: Claude Opus 4.7 --- README.md | 13 +++++++++---- flashinfer/activation.py | 8 ++++---- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index de3d3b3b3e..a0ed808fce 100644 --- a/README.md +++ b/README.md @@ -68,7 +68,7 @@ kernel for non-attention ops). **AITER** = ROCm AITER backend. | **LayerNorm / Gemma RMSNorm** | ✅ | — | HIP | | | **Sampling** | ✅ | — | HIP | Top-K / Top-P / Min-P / OnlineSoftmax / SamplingFromLogits | | **Logits processor** | ✅ | — | HIP | Composable processor pipeline (cap, mask, temperature, …) | -| **Activation** | ✅ | — | HIP | SiLU / GELU with fused gating | +| **Activation** | ✅ `native` | ✅ | **HIP `native`**, except `silu_and_mul` resolves to **AITER** for `fp16` + 2-D + `>= 33M`-element inputs; else HIP | SiLU / GELU with fused gating. AITER path (`silu_and_mul` only) is opt-in via `backend="aiter"`; matches native precision in fp16, lower in bf16 | | **Quantization** | ✅ | — | HIP | `packbits`, `segment_packbits` | | **`torch.compile`** | ✅ (opt-in) | n/a | n/a | Set `FLASHINFER_USE_TORCH_CUSTOM_OPS=1` **before** importing `flashinfer`; requires PyTorch ≥ 2.4. Without it, `torch.compile` raises a clear error if it traces into a flashinfer op | @@ -309,7 +309,7 @@ pytest -n auto --reruns 2 -m "slow" FlashInfer+ROCm can dispatch the `single_prefill`, `batch_prefill` (paged and ragged), `batch_decode`, `append_paged_kv_cache`, `rmsnorm`, -`fused_add_rmsnorm`, and `MLA` paths to +`fused_add_rmsnorm`, `silu_and_mul`, and `MLA` paths to [AITER](https://github.com/ROCm/aiter). MLA on ROCm is **AITER-only** — there is no in-tree HIP MLA kernel yet, so `backend="auto"` (the default for the MLA wrapper) resolves directly @@ -322,14 +322,19 @@ a one-time `logger.warning`. Pass `backend="aiter"` to require AITER explicitly, or pass the in-tree backend string to skip it: `backend="fa2"` for the attention wrappers (single/batch prefill/decode), `backend="native"` for non-attention ops -(`append_paged_kv_cache`, `rmsnorm`, `fused_add_rmsnorm`). Three -backend-specific exceptions to "auto picks AITER when supported": +(`append_paged_kv_cache`, `rmsnorm`, `fused_add_rmsnorm`, +`silu_and_mul`). Four backend-specific exceptions to "auto picks AITER +when supported": * `rmsnorm`: `backend="auto"` stays on the HIP `native` kernel; the AITER path is opt-in via `backend="aiter"`. * `fused_add_rmsnorm`: `backend="auto"` is shape-gated — it picks AITER only for 2-D inputs with `>= 4M` elements (where the CK kernel is faster) and stays on the HIP `native` kernel otherwise. +* `silu_and_mul`: `backend="auto"` picks AITER only for `fp16` + 2-D + + `>= 33M`-element inputs (where it is faster and matches native + precision) and otherwise stays on HIP `native`; the AITER path is also + available explicitly via `backend="aiter"`. * `batch_decode`: `use_cuda_graph=True` or `use_tensor_cores=True` force `auto` back to `fa2` (AITER decode does not support either), and `pos_encoding_mode != "NONE"` raises under `backend="aiter"`. diff --git a/flashinfer/activation.py b/flashinfer/activation.py index c45ef5a93d..ea084e7c8a 100644 --- a/flashinfer/activation.py +++ b/flashinfer/activation.py @@ -45,9 +45,9 @@ def _aiter_act_ops(): # bandwidth-bound shapes (~5-10% faster); below that a fixed ~0.7us launch # overhead makes it slower. It also matches the native kernel's precision # only in fp16 (bf16 is ~6e-2 vs ~4e-3 max err). The cutoff counts elements - # of the full input (rows x 2*hidden); the measured break-even is ~33M input - # elements (e.g. 2048 x 16384), so 64M is a safe 2x margin. fp16 only. - _AITER_SILU_AND_MUL_MIN_ELEMS = 64 * 1024 * 1024 + # of the full input (rows x 2*hidden); ~33M is the measured break-even + # (e.g. 2048 x 16384). fp16 only. + _AITER_SILU_AND_MUL_MIN_ELEMS = 33 * 1024 * 1024 def _auto_select_silu_and_mul_backend(input: torch.Tensor) -> str: # Cheapest guards first so the common small/medium case exits early. @@ -131,7 +131,7 @@ def silu_and_mul( backend: str Kernel backend to use. ``"auto"`` (default) uses the native kernel for small - and medium inputs, and switches to AITER on ROCm for large (>= 64M element) + and medium inputs, and switches to AITER on ROCm for large (>= 33M element) 2D fp16 inputs where its kernel is faster and matches native precision; it falls back to native whenever AITER is unavailable. ``"native"`` uses the FlashInfer JIT kernel on all platforms. From f5ef7a4dd86a031744a836262ae5fd39429553b0 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Mon, 15 Jun 2026 18:44:29 +0000 Subject: [PATCH 4/6] fix(activation): probe aiter import on opt-in; assert out= correctness Address second Copilot review on PR #251: - backend="aiter" now probes _aiter_act_ops() and re-raises a clear ValueError (chaining the original) when the aiter package is missing or fails to import, instead of surfacing a cryptic ImportError at the call. - The out= test seeds the tensor with NaN and asserts numerical correctness against the reference, so a no-op write can no longer pass. Co-Authored-By: Claude Opus 4.7 --- flashinfer/activation.py | 7 +++++++ tests/rocm_tests/test_activation_aiter_hip.py | 8 +++++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/flashinfer/activation.py b/flashinfer/activation.py index ea084e7c8a..89a963647a 100644 --- a/flashinfer/activation.py +++ b/flashinfer/activation.py @@ -160,6 +160,13 @@ def silu_and_mul( "backend='aiter' requires a ROCm gfx942/gfx950 device with the " "aiter package installed." ) + try: + _aiter_act_ops() + except Exception as e: + raise ValueError( + "backend='aiter' requires the aiter package, which failed to " + f"import: {e}" + ) from e if input.shape[-1] * input.dtype.itemsize % 16 != 0: raise ValueError("The pointers must be multiple of 16 bytes.") if out is not None: diff --git a/tests/rocm_tests/test_activation_aiter_hip.py b/tests/rocm_tests/test_activation_aiter_hip.py index 499d072fb2..160955697b 100644 --- a/tests/rocm_tests/test_activation_aiter_hip.py +++ b/tests/rocm_tests/test_activation_aiter_hip.py @@ -83,13 +83,15 @@ def test_silu_and_mul_auto_backend_selection(): reason="AITER backend requires gfx942/gfx950", ) def test_silu_and_mul_aiter_with_out_tensor(): - """backend='aiter' respects the out= argument.""" + """backend='aiter' writes the correct result into the supplied out= tensor.""" device = torch.device("cuda:0") x = torch.randn(8, 256, dtype=torch.float16, device=device) - out = torch.empty(8, 128, dtype=torch.float16, device=device) + # Seed out with a sentinel the kernel must overwrite, so a no-op write fails. + out = torch.full((8, 128), float("nan"), dtype=torch.float16, device=device) ret = flashinfer.activation.silu_and_mul(x, out=out, backend="aiter") assert ret.data_ptr() == out.data_ptr() - assert not torch.all(out == 0) + ref = _silu_and_mul_ref(x) + torch.testing.assert_close(out.float(), ref.float(), rtol=1e-3, atol=1e-3) def test_silu_and_mul_unknown_backend_raises(): From 068f1ea34cdad24f98b29c639540d2ce48c8864c Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Mon, 15 Jun 2026 19:01:13 +0000 Subject: [PATCH 5/6] test(rocm): share a single requires_aiter skip decorator Add requires_aiter to tests/test_helpers/test_helpers.py (gating on arch + aiter importability) and import it from every AITER rocm test, replacing the per-file copies of the @pytest.mark.skipif(not is_aiter_supported...) decorator. One definition, no duplicates. Co-Authored-By: Claude Opus 4.7 --- tests/rocm_tests/test_activation_aiter_hip.py | 16 +++------- .../test_append_paged_kv_cache_aiter_hip.py | 12 ++----- .../rocm_tests/test_batch_decode_aiter_hip.py | 27 ++++------------ tests/rocm_tests/test_mla_aiter_hip.py | 32 ++++--------------- tests/rocm_tests/test_norm_hip.py | 16 +--------- tests/rocm_tests/test_rmsnorm_aiter_hip.py | 17 +++------- tests/test_helpers/test_helpers.py | 15 +++++++++ 7 files changed, 40 insertions(+), 95 deletions(-) diff --git a/tests/rocm_tests/test_activation_aiter_hip.py b/tests/rocm_tests/test_activation_aiter_hip.py index 160955697b..c86dc3054c 100644 --- a/tests/rocm_tests/test_activation_aiter_hip.py +++ b/tests/rocm_tests/test_activation_aiter_hip.py @@ -13,6 +13,7 @@ import flashinfer from flashinfer.aiter_utils import is_aiter_supported +from tests.test_helpers.test_helpers import requires_aiter def _silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor: @@ -22,10 +23,7 @@ def _silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor: return (gate / (1.0 + torch.exp(-gate)) * up).to(x.dtype) -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("d", [128, 512, 4096, 8192, 14336]) @pytest.mark.parametrize("num_tokens", [1, 8, 256]) @@ -42,10 +40,7 @@ def test_silu_and_mul_aiter_vs_ref(dtype, d, num_tokens): torch.testing.assert_close(got.float(), ref.float(), rtol=rtol, atol=atol) -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter def test_silu_and_mul_auto_backend_selection(): """auto stays native for small/bf16 inputs and picks AITER for large fp16 2D inputs.""" from flashinfer.activation import ( @@ -78,10 +73,7 @@ def test_silu_and_mul_auto_backend_selection(): assert _auto_select_silu_and_mul_backend(large_3d) == "native" -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter def test_silu_and_mul_aiter_with_out_tensor(): """backend='aiter' writes the correct result into the supplied out= tensor.""" device = torch.device("cuda:0") diff --git a/tests/rocm_tests/test_append_paged_kv_cache_aiter_hip.py b/tests/rocm_tests/test_append_paged_kv_cache_aiter_hip.py index dc8df004f3..b0ff18ed81 100644 --- a/tests/rocm_tests/test_append_paged_kv_cache_aiter_hip.py +++ b/tests/rocm_tests/test_append_paged_kv_cache_aiter_hip.py @@ -10,7 +10,7 @@ import torch import flashinfer -from flashinfer.aiter_utils import is_aiter_supported +from tests.test_helpers.test_helpers import requires_aiter from flashinfer.jit.core import logger logger.setLevel(logging.ERROR) @@ -63,10 +63,7 @@ def _build_append_inputs(append_lens, page_size, num_kv_heads, head_dim, dtype, ) -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("page_size", [16, 32]) @pytest.mark.parametrize("num_kv_heads,head_dim", [(4, 64), (8, 128), (16, 128)]) @@ -133,10 +130,7 @@ def test_append_paged_kv_cache_aiter_vs_native( torch.testing.assert_close(v_aiter, v_native, rtol=0, atol=0) -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter def test_append_paged_kv_cache_aiter_auto_routes_on_nhd_fp16(): """auto backend should pick aiter when device + dtype + layout match constraints.""" from flashinfer.page import _auto_select_kv_append_backend diff --git a/tests/rocm_tests/test_batch_decode_aiter_hip.py b/tests/rocm_tests/test_batch_decode_aiter_hip.py index d59467b947..63def3f2c7 100644 --- a/tests/rocm_tests/test_batch_decode_aiter_hip.py +++ b/tests/rocm_tests/test_batch_decode_aiter_hip.py @@ -9,7 +9,7 @@ import torch import flashinfer -from flashinfer.aiter_utils import is_aiter_supported +from tests.test_helpers.test_helpers import requires_aiter from flashinfer.jit.core import logger logger.setLevel(logging.ERROR) @@ -57,10 +57,7 @@ def _build_paged_kv( ) -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("batch_size", [1, 4, 17]) @pytest.mark.parametrize("page_size", [16, 32]) @@ -127,10 +124,7 @@ def test_batch_decode_aiter_vs_fa2( torch.testing.assert_close(o_cand.float(), o_ref.float(), rtol=rtol, atol=atol) -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter def test_batch_decode_aiter_rejects_invalid_config(): """plan() should reject unsupported configs with a clear error.""" device = torch.device("cuda:0") @@ -217,10 +211,7 @@ def test_batch_decode_aiter_rejects_invalid_config(): ) -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("batch_size", [1, 8]) @pytest.mark.parametrize("page_size", [16]) @@ -295,10 +286,7 @@ def test_batch_decode_aiter_sliding_window_vs_fa2( torch.testing.assert_close(o_cand.float(), o_ref.float(), rtol=rtol, atol=atol) -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("window_left", [-1, 31]) def test_batch_decode_aiter_return_lse_via_fa2(dtype, window_left): @@ -367,10 +355,7 @@ def test_batch_decode_aiter_return_lse_via_fa2(dtype, window_left): torch.testing.assert_close(lse_cand, lse_ref, rtol=1e-3, atol=1e-3) -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter def test_batch_decode_auto_routes_cuda_graph_to_fa2(): """backend='auto' with use_cuda_graph=True must route to fa2 (AITER doesn't support graph capture).""" diff --git a/tests/rocm_tests/test_mla_aiter_hip.py b/tests/rocm_tests/test_mla_aiter_hip.py index b0b84c6111..1a8e9be0d9 100644 --- a/tests/rocm_tests/test_mla_aiter_hip.py +++ b/tests/rocm_tests/test_mla_aiter_hip.py @@ -9,7 +9,7 @@ import pytest import torch -from flashinfer.aiter_utils import is_aiter_supported +from tests.test_helpers.test_helpers import requires_aiter def _paged_mla_ref( @@ -101,10 +101,7 @@ def _build_paged_kv( return ckv_cache, kpe_cache, kv_indptr, kv_indices, kv_last_page_len -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("page_size", [1]) @pytest.mark.parametrize("num_heads,head_dim_ckv,head_dim_kpe", [(16, 512, 64)]) @@ -185,10 +182,7 @@ def test_mla_decode_vs_ref( torch.testing.assert_close(got.float(), ref.float(), rtol=rtol, atol=atol) -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter def test_mla_decode_out_tensor(): """run() respects a pre-allocated out= tensor.""" from flashinfer.mla_rocm import BatchMLAPagedAttentionWrapper @@ -237,10 +231,7 @@ def test_mla_decode_out_tensor(): assert not torch.all(out == 0) -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter def test_mla_plan_validation(): """plan() raises on invalid arguments.""" from flashinfer.mla_rocm import BatchMLAPagedAttentionWrapper @@ -278,10 +269,7 @@ def test_mla_plan_validation(): ) -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter def test_mla_plan_kv_len_inconsistent_with_paging(): """Passing last-page counts as kv_len_arr must fail (was accepted pre-conversion).""" from flashinfer.mla_rocm import BatchMLAPagedAttentionWrapper @@ -308,10 +296,7 @@ def test_mla_plan_kv_len_inconsistent_with_paging(): ) -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter def test_mla_run_before_plan_raises(): """run() before plan() raises RuntimeError.""" from flashinfer.mla_rocm import BatchMLAPagedAttentionWrapper @@ -328,10 +313,7 @@ def test_mla_run_before_plan_raises(): ) -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter @pytest.mark.parametrize("backend", ["auto", "aiter"]) def test_mla_backend_accepts_auto_and_aiter(backend): """The ROCm MLA wrapper accepts both 'auto' (default) and 'aiter'. diff --git a/tests/rocm_tests/test_norm_hip.py b/tests/rocm_tests/test_norm_hip.py index 6d879b82f5..fd92d985da 100644 --- a/tests/rocm_tests/test_norm_hip.py +++ b/tests/rocm_tests/test_norm_hip.py @@ -15,25 +15,11 @@ limitations under the License. """ -import importlib.util - import pytest import torch import flashinfer -from flashinfer.aiter_utils import is_aiter_supported - -# is_aiter_supported only checks the GPU arch; AITER is a separate source install, -# so a supported board can still lack the package. Require both so these tests skip -# (rather than error/fail) when the arch matches but aiter isn't importable. -_aiter_available = ( - is_aiter_supported(torch.device("cuda:0")) - and importlib.util.find_spec("aiter") is not None -) -requires_aiter = pytest.mark.skipif( - not _aiter_available, - reason="AITER backend requires gfx942/gfx950 and the aiter package", -) +from tests.test_helpers.test_helpers import requires_aiter def llama_rms_norm(x, w, eps=1e-6): diff --git a/tests/rocm_tests/test_rmsnorm_aiter_hip.py b/tests/rocm_tests/test_rmsnorm_aiter_hip.py index 841f065d60..9105cfa23d 100644 --- a/tests/rocm_tests/test_rmsnorm_aiter_hip.py +++ b/tests/rocm_tests/test_rmsnorm_aiter_hip.py @@ -12,7 +12,7 @@ import torch import flashinfer -from flashinfer.aiter_utils import is_aiter_supported +from tests.test_helpers.test_helpers import requires_aiter def _rms_norm_ref(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: @@ -23,10 +23,7 @@ def _rms_norm_ref(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch. return (x * torch.rsqrt(variance + eps) * w.float()).to(orig) -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [128, 512, 1024, 4096]) @pytest.mark.parametrize("batch_size", [1, 32, 256]) @@ -44,10 +41,7 @@ def test_rmsnorm_aiter_vs_ref(dtype, hidden_size, batch_size): torch.testing.assert_close(got.float(), ref.float(), rtol=rtol, atol=atol) -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter def test_rmsnorm_auto_backend_stays_native(): """auto backend on gfx942/950 should stay on native kernel (precision parity with tests).""" from flashinfer.norm import _auto_select_norm_backend @@ -59,10 +53,7 @@ def test_rmsnorm_auto_backend_stays_native(): assert _auto_select_norm_backend(device, torch.float32) == "native" -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter def test_rmsnorm_aiter_with_out_tensor(): """backend='aiter' respects the out= argument.""" device = torch.device("cuda:0") diff --git a/tests/test_helpers/test_helpers.py b/tests/test_helpers/test_helpers.py index 1b4ac043bc..a65804998e 100644 --- a/tests/test_helpers/test_helpers.py +++ b/tests/test_helpers/test_helpers.py @@ -1,11 +1,26 @@ import torch import functools +import importlib.util import os from flashinfer.utils import GPUArchitectureError +from flashinfer.aiter_utils import is_aiter_supported import pytest import gc +# is_aiter_supported only checks the GPU arch; AITER is a separate source install, +# so a supported board can still lack the package. Require both so AITER tests skip +# (rather than error/fail) when the arch matches but aiter isn't importable. +_aiter_available = ( + is_aiter_supported(torch.device("cuda:0")) + and importlib.util.find_spec("aiter") is not None +) +requires_aiter = pytest.mark.skipif( + not _aiter_available, + reason="AITER backend requires gfx942/gfx950 and the aiter package", +) + + @functools.cache def get_device_properties(device: torch.device): return torch.cuda.get_device_properties(device) From 917a104ba0882bd004c05805ef43c0d740607b12 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Mon, 15 Jun 2026 19:08:13 +0000 Subject: [PATCH 6/6] fix(activation): tidy out= annotation, aiter error message, README cell Address third Copilot review on PR #251: - Type silu_and_mul's out= as Optional[torch.Tensor] to match the enable_pdl: Optional[bool] style. - Make the backend="aiter" arch-check error strictly about the ROCm/arch requirement and include the actual device; the missing-package case is already reported separately by the import probe below. - Rephrase the README Activation matrix cell to the "AITER when ...; else HIP native" pattern used by the other rows. Co-Authored-By: Claude Opus 4.7 --- README.md | 2 +- flashinfer/activation.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index a0ed808fce..5d2ec6b22f 100644 --- a/README.md +++ b/README.md @@ -68,7 +68,7 @@ kernel for non-attention ops). **AITER** = ROCm AITER backend. | **LayerNorm / Gemma RMSNorm** | ✅ | — | HIP | | | **Sampling** | ✅ | — | HIP | Top-K / Top-P / Min-P / OnlineSoftmax / SamplingFromLogits | | **Logits processor** | ✅ | — | HIP | Composable processor pipeline (cap, mask, temperature, …) | -| **Activation** | ✅ `native` | ✅ | **HIP `native`**, except `silu_and_mul` resolves to **AITER** for `fp16` + 2-D + `>= 33M`-element inputs; else HIP | SiLU / GELU with fused gating. AITER path (`silu_and_mul` only) is opt-in via `backend="aiter"`; matches native precision in fp16, lower in bf16 | +| **Activation** | ✅ `native` | ✅ | **AITER** for `silu_and_mul` when `fp16` + 2-D + `>= 33M` elements; else **HIP `native`** | SiLU / GELU with fused gating. AITER path (`silu_and_mul` only) is opt-in via `backend="aiter"`; matches native precision in fp16, lower in bf16 | | **Quantization** | ✅ | — | HIP | `packbits`, `segment_packbits` | | **`torch.compile`** | ✅ (opt-in) | n/a | n/a | Set `FLASHINFER_USE_TORCH_CUSTOM_OPS=1` **before** importing `flashinfer`; requires PyTorch ≥ 2.4. Without it, `torch.compile` raises a clear error if it traces into a flashinfer op | diff --git a/flashinfer/activation.py b/flashinfer/activation.py index 89a963647a..119fd2722a 100644 --- a/flashinfer/activation.py +++ b/flashinfer/activation.py @@ -109,7 +109,7 @@ def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None: def silu_and_mul( input: torch.Tensor, - out: torch.Tensor = None, + out: Optional[torch.Tensor] = None, enable_pdl: Optional[bool] = None, backend: str = "auto", ) -> torch.Tensor: @@ -157,8 +157,8 @@ def silu_and_mul( if not (IS_HIP and is_aiter_supported(input.device)): raise ValueError( - "backend='aiter' requires a ROCm gfx942/gfx950 device with the " - "aiter package installed." + f"backend='aiter' requires a ROCm gfx942/gfx950 device; got " + f"device {input.device}." ) try: _aiter_act_ops()