Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ kernel for non-attention ops). **AITER** = ROCm AITER backend.
| **LayerNorm / Gemma RMSNorm** | ✅ | — | HIP | |
| **Sampling** | ✅ | — | HIP | Top-K / Top-P / Min-P / OnlineSoftmax / SamplingFromLogits |
| **Logits processor** | ✅ | — | HIP | Composable processor pipeline (cap, mask, temperature, …) |
| **Activation** | ✅ | — | HIP | SiLU / GELU with fused gating |
| **Activation** | ✅ `native` | ✅ | **AITER** for `silu_and_mul` when `fp16` + 2-D + `>= 33M` elements; else **HIP `native`** | SiLU / GELU with fused gating. The AITER path (`silu_and_mul` only) is auto-selected only for the large fp16 2-D case in the previous column and is otherwise available explicitly via `backend="aiter"`; it matches native precision in fp16 and is less precise in bf16 |
| **Quantization** | ✅ | — | HIP | `packbits`, `segment_packbits` |
| **`torch.compile`** | ✅ (opt-in) | n/a | n/a | Set `FLASHINFER_USE_TORCH_CUSTOM_OPS=1` **before** importing `flashinfer`; requires PyTorch ≥ 2.4. Without it, `torch.compile` raises a clear error if it traces into a flashinfer op |

Expand Down Expand Up @@ -309,7 +309,7 @@ pytest -n auto --reruns 2 -m "slow"

FlashInfer+ROCm can dispatch the `single_prefill`, `batch_prefill`
(paged and ragged), `batch_decode`, `append_paged_kv_cache`, `rmsnorm`,
`fused_add_rmsnorm`, and `MLA` paths to
`fused_add_rmsnorm`, `silu_and_mul`, and `MLA` paths to
[AITER](https://github.com/ROCm/aiter). MLA on ROCm
is **AITER-only** — there is no in-tree HIP MLA kernel yet, so
`backend="auto"` (the default for the MLA wrapper) resolves directly
Expand All @@ -322,14 +322,19 @@ a one-time `logger.warning`. Pass `backend="aiter"` to require AITER
explicitly, or pass the in-tree backend string to skip it:
`backend="fa2"` for the attention wrappers (single/batch
prefill/decode), `backend="native"` for non-attention ops
(`append_paged_kv_cache`, `rmsnorm`, `fused_add_rmsnorm`). Three
backend-specific exceptions to "auto picks AITER when supported":
(`append_paged_kv_cache`, `rmsnorm`, `fused_add_rmsnorm`,
`silu_and_mul`). Four backend-specific exceptions to "auto picks AITER
when supported":

* `rmsnorm`: `backend="auto"` stays on the HIP `native` kernel; the
AITER path is opt-in via `backend="aiter"`.
* `fused_add_rmsnorm`: `backend="auto"` is shape-gated — it picks AITER
only for 2-D inputs with `>= 4M` elements (where the CK kernel is
faster) and stays on the HIP `native` kernel otherwise.
* `silu_and_mul`: `backend="auto"` picks AITER only for `fp16` + 2-D +
`>= 33M`-element inputs (where it is faster and matches native
precision) and otherwise stays on HIP `native`; the AITER path is also
available explicitly via `backend="aiter"`.
* `batch_decode`: `use_cuda_graph=True` or `use_tensor_cores=True`
force `auto` back to `fa2` (AITER decode does not support either),
and `pos_encoding_mode != "NONE"` raises under `backend="aiter"`.
Expand Down
89 changes: 85 additions & 4 deletions flashinfer/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import torch

from .device_utils import IS_CUDA
from .device_utils import IS_CUDA, IS_HIP
from .jit import gen_act_and_mul_module
from .utils import (
device_support_pdl,
Expand All @@ -33,6 +33,44 @@
from .fp4_quantization import get_fp4_quantization_module


if IS_HIP:

@functools.cache
def _aiter_act_ops():
    """Lazily import the ``aiter`` package and return the module object.

    The import is deferred to the first call so that merely importing this
    module does not require ``aiter``; ``functools.cache`` then hands back
    the same module object on later calls. Any import failure propagates to
    the caller.
    """
    import aiter

    return aiter

# AITER's silu_and_mul only overtakes the native kernel on large,
# bandwidth-bound shapes (~5-10% faster); below that a fixed ~0.7us launch
# overhead makes it slower. It also matches the native kernel's precision
# only in fp16 (bf16 is ~6e-2 vs ~4e-3 max err), so the auto path is limited
# to fp16. The cutoff counts elements of the full input (rows x 2*hidden);
# ~33M is the measured break-even (e.g. 2048 x 16384).
# NOTE(review): 2048 x 16384 is 33,554,432 elements, which is below the
# 34,603,008 (33 * 1024 * 1024) cutoff, so that example shape still selects
# native under auto -- confirm whether 32 * 1024 * 1024 was intended.
_AITER_SILU_AND_MUL_MIN_ELEMS = 33 * 1024 * 1024

def _auto_select_silu_and_mul_backend(input: torch.Tensor) -> str:
# Cheapest guards first so the common small/medium case exits early.
if input.dtype != torch.float16:
return "native"
if input.ndim != 2:
return "native"
if input.numel() < _AITER_SILU_AND_MUL_MIN_ELEMS:
return "native"
from .aiter_utils import is_aiter_supported

if not is_aiter_supported(input.device):
return "native"
try:
# Best-effort probe: a supported arch can still lack a usable aiter
# (not installed, or its compiled extension fails to load). auto must
# always be able to fall back to native, so catch any import failure.
_aiter_act_ops()
except Exception:
return "native"
return "aiter"


@functools.cache
def get_act_and_mul_module(act_func_name: str):
module = gen_act_and_mul_module(act_func_name).build_and_load()
Expand Down Expand Up @@ -70,7 +108,10 @@ def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:


def silu_and_mul(
input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None
input: torch.Tensor,
out: Optional[torch.Tensor] = None,
enable_pdl: Optional[bool] = None,
backend: str = "auto",
) -> torch.Tensor:
r"""Fused SiLU and Mul operation.

Expand All @@ -88,13 +129,44 @@ def silu_and_mul(
Whether to enable `programmatic dependent launch
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_

backend: str
Kernel backend to use. ``"auto"`` (default) uses the native kernel for small
and medium inputs, and switches to AITER on ROCm for large (>= 33M element)
2D fp16 inputs where its kernel is faster and matches native precision; it
falls back to native whenever AITER is unavailable.
``"native"`` uses the FlashInfer JIT kernel on all platforms.
``"aiter"`` uses AMD AITER's ``silu_and_mul`` — ROCm (gfx942/gfx950) only;
requires the ``aiter`` package, and raises ``ValueError`` on any other
platform. Precision matches ``"native"`` in fp16 but is lower in bf16
(max err ~6e-2 vs ~4e-3), which is why ``"auto"`` restricts the AITER path
to fp16.

Returns
-------
output: torch.Tensor
Output tensor, shape (..., hidden_size).
"""
if enable_pdl is None:
enable_pdl = device_support_pdl(input.device)
if backend not in ("auto", "native", "aiter"):
raise ValueError(
f"Unknown backend {backend!r}; expected one of 'auto', 'native', 'aiter'."
)
if backend == "aiter":
# Validate the explicit opt-in on every platform so a misconfiguration
# surfaces here instead of silently running native off ROCm.
from .aiter_utils import is_aiter_supported

if not (IS_HIP and is_aiter_supported(input.device)):
raise ValueError(
f"backend='aiter' requires a ROCm gfx942/gfx950 device; got "
f"device {input.device}."
)
try:
_aiter_act_ops()
except Exception as e:
raise ValueError(
"backend='aiter' requires the aiter package, which failed to "
f"import: {e}"
) from e
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
raise ValueError("The pointers must be multiple of 16 bytes.")
if out is not None:
Expand All @@ -105,6 +177,15 @@ def silu_and_mul(
device=input.device,
dtype=input.dtype,
)
if IS_HIP:
_backend = (
backend if backend != "auto" else _auto_select_silu_and_mul_backend(input)
)
if _backend == "aiter":
_aiter_act_ops().silu_and_mul(out, input)
return out
if enable_pdl is None:
enable_pdl = device_support_pdl(input.device)
get_act_and_mul_module("silu").silu_and_mul(
out,
input,
Expand Down
101 changes: 101 additions & 0 deletions tests/rocm_tests/test_activation_aiter_hip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# SPDX-FileCopyrightText: 2026 Advanced Micro Devices, Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Tests for the AITER silu_and_mul backend exposed via
# flashinfer.activation.silu_and_mul(backend="aiter").
#
# Note on tolerances: AITER silu_and_mul matches the native JIT kernel exactly in
# fp16, but uses lower-precision arithmetic in bf16 (max err ~6e-2 vs the native
# kernel's ~4e-3). The tolerances below reflect AITER's actual precision.

import pytest
import torch

import flashinfer
from flashinfer.aiter_utils import is_aiter_supported
from tests.test_helpers.test_helpers import requires_aiter


def _silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
x_f32 = x.float()
gate, up = x_f32[..., :d], x_f32[..., d:]
return (gate / (1.0 + torch.exp(-gate)) * up).to(x.dtype)


@requires_aiter
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("d", [128, 512, 4096, 8192, 14336])
@pytest.mark.parametrize("num_tokens", [1, 8, 256])
def test_silu_and_mul_aiter_vs_ref(dtype, d, num_tokens):
    """Check ``backend="aiter"`` against the fp32 reference formula."""
    torch.manual_seed(0xA17E2)
    gpu = torch.device("cuda:0")
    inputs = torch.randn(num_tokens, 2 * d, dtype=dtype, device=gpu)

    expected = _silu_and_mul_ref(inputs)
    actual = flashinfer.activation.silu_and_mul(inputs, backend="aiter")

    # AITER precision: fp16 matches native; bf16 ~6e-2 observed across shapes.
    if dtype == torch.bfloat16:
        tolerance = 7e-2
    else:
        tolerance = 1e-3
    torch.testing.assert_close(
        actual.float(), expected.float(), rtol=tolerance, atol=tolerance
    )


@requires_aiter
def test_silu_and_mul_auto_backend_selection():
    """auto stays native for small/bf16 inputs and picks AITER for large fp16 2D inputs."""
    from flashinfer.activation import (
        _AITER_SILU_AND_MUL_MIN_ELEMS as min_elems,
        _auto_select_silu_and_mul_backend as pick_backend,
    )

    gpu = torch.device("cuda:0")

    # Below the element cutoff the dtype is irrelevant: always native.
    for small_dtype in (torch.float16, torch.bfloat16):
        tiny = torch.empty(8, 256, dtype=small_dtype, device=gpu)
        assert pick_backend(tiny) == "native"

    # Build a 2-D fp16 shape at or above the cutoff. The column count is
    # rounded up to a multiple of 8 so the shape also satisfies
    # silu_and_mul's 16-byte alignment guard, i.e. it is a shape that could
    # really reach AITER through the public function.
    rows = 8192
    groups_of_eight = -(-min_elems // (rows * 8))  # ceiling division
    cols = groups_of_eight * 8
    big_fp16 = torch.empty(rows, cols, dtype=torch.float16, device=gpu)
    assert big_fp16.numel() >= min_elems
    assert pick_backend(big_fp16) == "aiter"

    # Identical shape in bf16 must stay native (precision guard).
    big_bf16 = torch.empty(rows, cols, dtype=torch.bfloat16, device=gpu)
    assert pick_backend(big_bf16) == "native"

    # Large fp16 but 3-D must stay native (2-D guard).
    big_3d = torch.empty(2, rows, cols, dtype=torch.float16, device=gpu)
    assert pick_backend(big_3d) == "native"


@requires_aiter
def test_silu_and_mul_aiter_with_out_tensor():
    """backend='aiter' writes the correct result into the supplied out= tensor."""
    gpu = torch.device("cuda:0")
    inputs = torch.randn(8, 256, dtype=torch.float16, device=gpu)
    # Pre-fill the destination with NaN: if the kernel never writes to it,
    # the comparison below fails instead of passing by accident.
    destination = torch.full(
        (8, 128), float("nan"), dtype=torch.float16, device=gpu
    )

    returned = flashinfer.activation.silu_and_mul(
        inputs, out=destination, backend="aiter"
    )

    # The returned tensor must share storage with the caller's buffer.
    assert returned.data_ptr() == destination.data_ptr()
    expected = _silu_and_mul_ref(inputs)
    torch.testing.assert_close(
        destination.float(), expected.float(), rtol=1e-3, atol=1e-3
    )


def test_silu_and_mul_unknown_backend_raises():
    """An unrecognised backend string is rejected with ``ValueError``."""
    # Backend validation is platform-independent, so a CPU tensor is enough
    # and no aiter-capable device is required.
    cpu_input = torch.randn(8, 256, dtype=torch.float16)
    expect_error = pytest.raises(ValueError, match="Unknown backend")
    with expect_error:
        flashinfer.activation.silu_and_mul(cpu_input, backend="nope")


def test_silu_and_mul_aiter_backend_rejected_when_unsupported():
    """Explicit backend='aiter' raises (not silently falls back) on an unsupported device.

    The rejection path is only reachable when ``is_aiter_supported`` reports
    the tensor's device as unsupported. When it does not, the test is
    reported as skipped rather than passing without asserting anything.
    """
    cpu_x = torch.randn(8, 256, dtype=torch.float16)
    if is_aiter_supported(cpu_x.device):
        pytest.skip(
            "is_aiter_supported() accepts this device; the rejection path "
            "cannot be exercised here"
        )
    with pytest.raises(ValueError, match="requires a ROCm"):
        flashinfer.activation.silu_and_mul(cpu_x, backend="aiter")
12 changes: 3 additions & 9 deletions tests/rocm_tests/test_append_paged_kv_cache_aiter_hip.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch

import flashinfer
from flashinfer.aiter_utils import is_aiter_supported
from tests.test_helpers.test_helpers import requires_aiter
from flashinfer.jit.core import logger

logger.setLevel(logging.ERROR)
Expand Down Expand Up @@ -63,10 +63,7 @@ def _build_append_inputs(append_lens, page_size, num_kv_heads, head_dim, dtype,
)


@pytest.mark.skipif(
not is_aiter_supported(torch.device("cuda:0")),
reason="AITER backend requires gfx942/gfx950",
)
@requires_aiter
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("page_size", [16, 32])
@pytest.mark.parametrize("num_kv_heads,head_dim", [(4, 64), (8, 128), (16, 128)])
Expand Down Expand Up @@ -133,10 +130,7 @@ def test_append_paged_kv_cache_aiter_vs_native(
torch.testing.assert_close(v_aiter, v_native, rtol=0, atol=0)


@pytest.mark.skipif(
not is_aiter_supported(torch.device("cuda:0")),
reason="AITER backend requires gfx942/gfx950",
)
@requires_aiter
def test_append_paged_kv_cache_aiter_auto_routes_on_nhd_fp16():
"""auto backend should pick aiter when device + dtype + layout match constraints."""
from flashinfer.page import _auto_select_kv_append_backend
Expand Down
27 changes: 6 additions & 21 deletions tests/rocm_tests/test_batch_decode_aiter_hip.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch

import flashinfer
from flashinfer.aiter_utils import is_aiter_supported
from tests.test_helpers.test_helpers import requires_aiter
from flashinfer.jit.core import logger

logger.setLevel(logging.ERROR)
Expand Down Expand Up @@ -57,10 +57,7 @@ def _build_paged_kv(
)


@pytest.mark.skipif(
not is_aiter_supported(torch.device("cuda:0")),
reason="AITER backend requires gfx942/gfx950",
)
@requires_aiter
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("batch_size", [1, 4, 17])
@pytest.mark.parametrize("page_size", [16, 32])
Expand Down Expand Up @@ -127,10 +124,7 @@ def test_batch_decode_aiter_vs_fa2(
torch.testing.assert_close(o_cand.float(), o_ref.float(), rtol=rtol, atol=atol)


@pytest.mark.skipif(
not is_aiter_supported(torch.device("cuda:0")),
reason="AITER backend requires gfx942/gfx950",
)
@requires_aiter
def test_batch_decode_aiter_rejects_invalid_config():
"""plan() should reject unsupported configs with a clear error."""
device = torch.device("cuda:0")
Expand Down Expand Up @@ -217,10 +211,7 @@ def test_batch_decode_aiter_rejects_invalid_config():
)


@pytest.mark.skipif(
not is_aiter_supported(torch.device("cuda:0")),
reason="AITER backend requires gfx942/gfx950",
)
@requires_aiter
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("batch_size", [1, 8])
@pytest.mark.parametrize("page_size", [16])
Expand Down Expand Up @@ -295,10 +286,7 @@ def test_batch_decode_aiter_sliding_window_vs_fa2(
torch.testing.assert_close(o_cand.float(), o_ref.float(), rtol=rtol, atol=atol)


@pytest.mark.skipif(
not is_aiter_supported(torch.device("cuda:0")),
reason="AITER backend requires gfx942/gfx950",
)
@requires_aiter
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("window_left", [-1, 31])
def test_batch_decode_aiter_return_lse_via_fa2(dtype, window_left):
Expand Down Expand Up @@ -367,10 +355,7 @@ def test_batch_decode_aiter_return_lse_via_fa2(dtype, window_left):
torch.testing.assert_close(lse_cand, lse_ref, rtol=1e-3, atol=1e-3)


@pytest.mark.skipif(
not is_aiter_supported(torch.device("cuda:0")),
reason="AITER backend requires gfx942/gfx950",
)
@requires_aiter
def test_batch_decode_auto_routes_cuda_graph_to_fa2():
"""backend='auto' with use_cuda_graph=True must route to fa2 (AITER doesn't
support graph capture)."""
Expand Down
Loading
Loading