diff --git a/README.md b/README.md index de3d3b3b3e..5d2ec6b22f 100644 --- a/README.md +++ b/README.md @@ -68,7 +68,7 @@ kernel for non-attention ops). **AITER** = ROCm AITER backend. | **LayerNorm / Gemma RMSNorm** | ✅ | — | HIP | | | **Sampling** | ✅ | — | HIP | Top-K / Top-P / Min-P / OnlineSoftmax / SamplingFromLogits | | **Logits processor** | ✅ | — | HIP | Composable processor pipeline (cap, mask, temperature, …) | -| **Activation** | ✅ | — | HIP | SiLU / GELU with fused gating | +| **Activation** | ✅ `native` | ✅ | **AITER** for `silu_and_mul` when `fp16` + 2-D + `>= 33M` elements; else **HIP `native`** | SiLU / GELU with fused gating. AITER path (`silu_and_mul` only) is opt-in via `backend="aiter"`; matches native precision in fp16, lower in bf16 | | **Quantization** | ✅ | — | HIP | `packbits`, `segment_packbits` | | **`torch.compile`** | ✅ (opt-in) | n/a | n/a | Set `FLASHINFER_USE_TORCH_CUSTOM_OPS=1` **before** importing `flashinfer`; requires PyTorch ≥ 2.4. Without it, `torch.compile` raises a clear error if it traces into a flashinfer op | @@ -309,7 +309,7 @@ pytest -n auto --reruns 2 -m "slow" FlashInfer+ROCm can dispatch the `single_prefill`, `batch_prefill` (paged and ragged), `batch_decode`, `append_paged_kv_cache`, `rmsnorm`, -`fused_add_rmsnorm`, and `MLA` paths to +`fused_add_rmsnorm`, `silu_and_mul`, and `MLA` paths to [AITER](https://github.com/ROCm/aiter). MLA on ROCm is **AITER-only** — there is no in-tree HIP MLA kernel yet, so `backend="auto"` (the default for the MLA wrapper) resolves directly @@ -322,14 +322,19 @@ a one-time `logger.warning`. Pass `backend="aiter"` to require AITER explicitly, or pass the in-tree backend string to skip it: `backend="fa2"` for the attention wrappers (single/batch prefill/decode), `backend="native"` for non-attention ops -(`append_paged_kv_cache`, `rmsnorm`, `fused_add_rmsnorm`). Three -backend-specific exceptions to "auto picks AITER when supported": +(`append_paged_kv_cache`, `rmsnorm`, `fused_add_rmsnorm`, +`silu_and_mul`). Four backend-specific exceptions to "auto picks AITER +when supported": * `rmsnorm`: `backend="auto"` stays on the HIP `native` kernel; the AITER path is opt-in via `backend="aiter"`. * `fused_add_rmsnorm`: `backend="auto"` is shape-gated — it picks AITER only for 2-D inputs with `>= 4M` elements (where the CK kernel is faster) and stays on the HIP `native` kernel otherwise. +* `silu_and_mul`: `backend="auto"` picks AITER only for `fp16` + 2-D + + `>= 33M`-element inputs (where it is faster and matches native + precision) and otherwise stays on HIP `native`; the AITER path is also + available explicitly via `backend="aiter"`. * `batch_decode`: `use_cuda_graph=True` or `use_tensor_cores=True` force `auto` back to `fa2` (AITER decode does not support either), and `pos_encoding_mode != "NONE"` raises under `backend="aiter"`. diff --git a/flashinfer/activation.py b/flashinfer/activation.py index 2813ae6ad9..119fd2722a 100644 --- a/flashinfer/activation.py +++ b/flashinfer/activation.py @@ -20,7 +20,7 @@ import torch -from .device_utils import IS_CUDA +from .device_utils import IS_CUDA, IS_HIP from .jit import gen_act_and_mul_module from .utils import ( device_support_pdl, @@ -33,6 +33,44 @@ from .fp4_quantization import get_fp4_quantization_module +if IS_HIP: + + @functools.cache + def _aiter_act_ops(): + import aiter as _aiter + + return _aiter + + # AITER's silu_and_mul only overtakes the native kernel on large, + # bandwidth-bound shapes (~5-10% faster); below that a fixed ~0.7us launch + # overhead makes it slower. It also matches the native kernel's precision + # only in fp16 (bf16 is ~6e-2 vs ~4e-3 max err). The cutoff counts elements + # of the full input (rows x 2*hidden); ~33M is the measured break-even + # (e.g. 2048 x 16384). fp16 only. + _AITER_SILU_AND_MUL_MIN_ELEMS = 33 * 1024 * 1024 + + def _auto_select_silu_and_mul_backend(input: torch.Tensor) -> str: + # Cheapest guards first so the common small/medium case exits early. + if input.dtype != torch.float16: + return "native" + if input.ndim != 2: + return "native" + if input.numel() < _AITER_SILU_AND_MUL_MIN_ELEMS: + return "native" + from .aiter_utils import is_aiter_supported + + if not is_aiter_supported(input.device): + return "native" + try: + # Best-effort probe: a supported arch can still lack a usable aiter + # (not installed, or its compiled extension fails to load). auto must + # always be able to fall back to native, so catch any import failure. + _aiter_act_ops() + except Exception: + return "native" + return "aiter" + + @functools.cache def get_act_and_mul_module(act_func_name: str): module = gen_act_and_mul_module(act_func_name).build_and_load() @@ -70,7 +108,10 @@ def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None: def silu_and_mul( - input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None + input: torch.Tensor, + out: Optional[torch.Tensor] = None, + enable_pdl: Optional[bool] = None, + backend: str = "auto", ) -> torch.Tensor: r"""Fused SiLU and Mul operation. @@ -88,13 +129,44 @@ def silu_and_mul( Whether to enable `programmatic dependent launch `_ + backend: str + Kernel backend to use. ``"auto"`` (default) uses the native kernel for small + and medium inputs, and switches to AITER on ROCm for large (>= 33M element) + 2D fp16 inputs where its kernel is faster and matches native precision; it + falls back to native whenever AITER is unavailable. + ``"native"`` uses the FlashInfer JIT kernel on all platforms. + ``"aiter"`` uses AMD AITER's ``silu_and_mul`` — ROCm (gfx942/gfx950) only; + requires the ``aiter`` package, and raises ``ValueError`` on any other + platform. Precision matches ``"native"`` in fp16 but is lower in bf16 + (max err ~6e-2 vs ~4e-3), which is why ``"auto"`` restricts the AITER path + to fp16. + Returns ------- output: torch.Tensor Output tensor, shape (..., hidden_size). """ - if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) + if backend not in ("auto", "native", "aiter"): + raise ValueError( + f"Unknown backend {backend!r}; expected one of 'auto', 'native', 'aiter'." + ) + if backend == "aiter": + # Validate the explicit opt-in on every platform so a misconfiguration + # surfaces here instead of silently running native off ROCm. + from .aiter_utils import is_aiter_supported + + if not (IS_HIP and is_aiter_supported(input.device)): + raise ValueError( + f"backend='aiter' requires a ROCm gfx942/gfx950 device; got " + f"device {input.device}." + ) + try: + _aiter_act_ops() + except Exception as e: + raise ValueError( + "backend='aiter' requires the aiter package, which failed to " + f"import: {e}" + ) from e if input.shape[-1] * input.dtype.itemsize % 16 != 0: raise ValueError("The pointers must be multiple of 16 bytes.") if out is not None: @@ -105,6 +177,15 @@ def silu_and_mul( device=input.device, dtype=input.dtype, ) + if IS_HIP: + _backend = ( + backend if backend != "auto" else _auto_select_silu_and_mul_backend(input) + ) + if _backend == "aiter": + _aiter_act_ops().silu_and_mul(out, input) + return out + if enable_pdl is None: + enable_pdl = device_support_pdl(input.device) get_act_and_mul_module("silu").silu_and_mul( out, input, diff --git a/tests/rocm_tests/test_activation_aiter_hip.py b/tests/rocm_tests/test_activation_aiter_hip.py new file mode 100644 index 0000000000..c86dc3054c --- /dev/null +++ b/tests/rocm_tests/test_activation_aiter_hip.py @@ -0,0 +1,101 @@ +# SPDX-FileCopyrightText: 2026 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Tests for the AITER silu_and_mul backend exposed via +# flashinfer.activation.silu_and_mul(backend="aiter"). +# +# Note on tolerances: AITER silu_and_mul matches the native JIT kernel exactly in +# fp16, but uses lower-precision arithmetic in bf16 (max err ~6e-2 vs the native +# kernel's ~4e-3). The tolerances below reflect AITER's actual precision. + +import pytest +import torch + +import flashinfer +from flashinfer.aiter_utils import is_aiter_supported +from tests.test_helpers.test_helpers import requires_aiter + + +def _silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + x_f32 = x.float() + gate, up = x_f32[..., :d], x_f32[..., d:] + return (gate / (1.0 + torch.exp(-gate)) * up).to(x.dtype) + + +@requires_aiter +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("d", [128, 512, 4096, 8192, 14336]) +@pytest.mark.parametrize("num_tokens", [1, 8, 256]) +def test_silu_and_mul_aiter_vs_ref(dtype, d, num_tokens): + torch.manual_seed(0xA17E2) + device = torch.device("cuda:0") + x = torch.randn(num_tokens, 2 * d, dtype=dtype, device=device) + + ref = _silu_and_mul_ref(x) + got = flashinfer.activation.silu_and_mul(x, backend="aiter") + + # AITER precision: fp16 matches native; bf16 ~6e-2 observed across shapes. + rtol, atol = (7e-2, 7e-2) if dtype == torch.bfloat16 else (1e-3, 1e-3) + torch.testing.assert_close(got.float(), ref.float(), rtol=rtol, atol=atol) + + +@requires_aiter +def test_silu_and_mul_auto_backend_selection(): + """auto stays native for small/bf16 inputs and picks AITER for large fp16 2D inputs.""" + from flashinfer.activation import ( + _AITER_SILU_AND_MUL_MIN_ELEMS, + _auto_select_silu_and_mul_backend, + ) + + device = torch.device("cuda:0") + + # Small inputs: native regardless of dtype. + for dtype in (torch.float16, torch.bfloat16): + small = torch.empty(8, 256, dtype=dtype, device=device) + assert _auto_select_silu_and_mul_backend(small) == "native" + + # Large enough fp16 2D input (>= cutoff). cols is a multiple of 8 so the + # shape also clears silu_and_mul's 16-byte alignment guard, i.e. it is a + # shape that could actually flow through the public function to AITER. + rows = 8192 + cols = -(-_AITER_SILU_AND_MUL_MIN_ELEMS // (rows * 8)) * 8 # ceil to mult of 8 + large_fp16 = torch.empty(rows, cols, dtype=torch.float16, device=device) + assert large_fp16.numel() >= _AITER_SILU_AND_MUL_MIN_ELEMS + assert _auto_select_silu_and_mul_backend(large_fp16) == "aiter" + + # Same large shape in bf16: native (precision guard). + large_bf16 = torch.empty(rows, cols, dtype=torch.bfloat16, device=device) + assert _auto_select_silu_and_mul_backend(large_bf16) == "native" + + # Large fp16 but 3D: native (2D guard). + large_3d = torch.empty(2, rows, cols, dtype=torch.float16, device=device) + assert _auto_select_silu_and_mul_backend(large_3d) == "native" + + +@requires_aiter +def test_silu_and_mul_aiter_with_out_tensor(): + """backend='aiter' writes the correct result into the supplied out= tensor.""" + device = torch.device("cuda:0") + x = torch.randn(8, 256, dtype=torch.float16, device=device) + # Seed out with a sentinel the kernel must overwrite, so a no-op write fails. + out = torch.full((8, 128), float("nan"), dtype=torch.float16, device=device) + ret = flashinfer.activation.silu_and_mul(x, out=out, backend="aiter") + assert ret.data_ptr() == out.data_ptr() + ref = _silu_and_mul_ref(x) + torch.testing.assert_close(out.float(), ref.float(), rtol=1e-3, atol=1e-3) + + +def test_silu_and_mul_unknown_backend_raises(): + # Backend validation is platform-independent, so this needs no aiter device. + x = torch.randn(8, 256, dtype=torch.float16) + with pytest.raises(ValueError, match="Unknown backend"): + flashinfer.activation.silu_and_mul(x, backend="nope") + + +def test_silu_and_mul_aiter_backend_rejected_when_unsupported(): + """Explicit backend='aiter' raises (not silently falls back) on an unsupported device.""" + cpu_x = torch.randn(8, 256, dtype=torch.float16) + if not is_aiter_supported(cpu_x.device): + with pytest.raises(ValueError, match="requires a ROCm"): + flashinfer.activation.silu_and_mul(cpu_x, backend="aiter") diff --git a/tests/rocm_tests/test_append_paged_kv_cache_aiter_hip.py b/tests/rocm_tests/test_append_paged_kv_cache_aiter_hip.py index dc8df004f3..b0ff18ed81 100644 --- a/tests/rocm_tests/test_append_paged_kv_cache_aiter_hip.py +++ b/tests/rocm_tests/test_append_paged_kv_cache_aiter_hip.py @@ -10,7 +10,7 @@ import torch import flashinfer -from flashinfer.aiter_utils import is_aiter_supported +from tests.test_helpers.test_helpers import requires_aiter from flashinfer.jit.core import logger logger.setLevel(logging.ERROR) @@ -63,10 +63,7 @@ def _build_append_inputs(append_lens, page_size, num_kv_heads, head_dim, dtype, ) -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("page_size", [16, 32]) @pytest.mark.parametrize("num_kv_heads,head_dim", [(4, 64), (8, 128), (16, 128)]) @@ -133,10 +130,7 @@ def test_append_paged_kv_cache_aiter_vs_native( torch.testing.assert_close(v_aiter, v_native, rtol=0, atol=0) -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter def test_append_paged_kv_cache_aiter_auto_routes_on_nhd_fp16(): """auto backend should pick aiter when device + dtype + layout match constraints.""" from flashinfer.page import _auto_select_kv_append_backend diff --git a/tests/rocm_tests/test_batch_decode_aiter_hip.py b/tests/rocm_tests/test_batch_decode_aiter_hip.py index d59467b947..63def3f2c7 100644 --- a/tests/rocm_tests/test_batch_decode_aiter_hip.py +++ b/tests/rocm_tests/test_batch_decode_aiter_hip.py @@ -9,7 +9,7 @@ import torch import flashinfer -from flashinfer.aiter_utils import is_aiter_supported +from tests.test_helpers.test_helpers import requires_aiter from flashinfer.jit.core import logger logger.setLevel(logging.ERROR) @@ -57,10 +57,7 @@ def _build_paged_kv( ) -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("batch_size", [1, 4, 17]) @pytest.mark.parametrize("page_size", [16, 32]) @@ -127,10 +124,7 @@ def test_batch_decode_aiter_vs_fa2( torch.testing.assert_close(o_cand.float(), o_ref.float(), rtol=rtol, atol=atol) -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter def test_batch_decode_aiter_rejects_invalid_config(): """plan() should reject unsupported configs with a clear error.""" device = torch.device("cuda:0") @@ -217,10 +211,7 @@ def test_batch_decode_aiter_rejects_invalid_config(): ) -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("batch_size", [1, 8]) @pytest.mark.parametrize("page_size", [16]) @@ -295,10 +286,7 @@ def test_batch_decode_aiter_sliding_window_vs_fa2( torch.testing.assert_close(o_cand.float(), o_ref.float(), rtol=rtol, atol=atol) -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("window_left", [-1, 31]) def test_batch_decode_aiter_return_lse_via_fa2(dtype, window_left): @@ -367,10 +355,7 @@ def test_batch_decode_aiter_return_lse_via_fa2(dtype, window_left): torch.testing.assert_close(lse_cand, lse_ref, rtol=1e-3, atol=1e-3) -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter def test_batch_decode_auto_routes_cuda_graph_to_fa2(): """backend='auto' with use_cuda_graph=True must route to fa2 (AITER doesn't support graph capture).""" diff --git a/tests/rocm_tests/test_mla_aiter_hip.py b/tests/rocm_tests/test_mla_aiter_hip.py index b0b84c6111..1a8e9be0d9 100644 --- a/tests/rocm_tests/test_mla_aiter_hip.py +++ b/tests/rocm_tests/test_mla_aiter_hip.py @@ -9,7 +9,7 @@ import pytest import torch -from flashinfer.aiter_utils import is_aiter_supported +from tests.test_helpers.test_helpers import requires_aiter def _paged_mla_ref( @@ -101,10 +101,7 @@ def _build_paged_kv( return ckv_cache, kpe_cache, kv_indptr, kv_indices, kv_last_page_len -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("page_size", [1]) @pytest.mark.parametrize("num_heads,head_dim_ckv,head_dim_kpe", [(16, 512, 64)]) @@ -185,10 +182,7 @@ def test_mla_decode_vs_ref( torch.testing.assert_close(got.float(), ref.float(), rtol=rtol, atol=atol) -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter def test_mla_decode_out_tensor(): """run() respects a pre-allocated out= tensor.""" from flashinfer.mla_rocm import BatchMLAPagedAttentionWrapper @@ -237,10 +231,7 @@ def test_mla_decode_out_tensor(): assert not torch.all(out == 0) -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter def test_mla_plan_validation(): """plan() raises on invalid arguments.""" from flashinfer.mla_rocm import BatchMLAPagedAttentionWrapper @@ -278,10 +269,7 @@ def test_mla_plan_validation(): ) -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter def test_mla_plan_kv_len_inconsistent_with_paging(): """Passing last-page counts as kv_len_arr must fail (was accepted pre-conversion).""" from flashinfer.mla_rocm import BatchMLAPagedAttentionWrapper @@ -308,10 +296,7 @@ def test_mla_plan_kv_len_inconsistent_with_paging(): ) -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter def test_mla_run_before_plan_raises(): """run() before plan() raises RuntimeError.""" from flashinfer.mla_rocm import BatchMLAPagedAttentionWrapper @@ -328,10 +313,7 @@ def test_mla_run_before_plan_raises(): ) -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter @pytest.mark.parametrize("backend", ["auto", "aiter"]) def test_mla_backend_accepts_auto_and_aiter(backend): """The ROCm MLA wrapper accepts both 'auto' (default) and 'aiter'. diff --git a/tests/rocm_tests/test_norm_hip.py b/tests/rocm_tests/test_norm_hip.py index 6d879b82f5..fd92d985da 100644 --- a/tests/rocm_tests/test_norm_hip.py +++ b/tests/rocm_tests/test_norm_hip.py @@ -15,25 +15,11 @@ limitations under the License. """ -import importlib.util - import pytest import torch import flashinfer -from flashinfer.aiter_utils import is_aiter_supported - -# is_aiter_supported only checks the GPU arch; AITER is a separate source install, -# so a supported board can still lack the package. Require both so these tests skip -# (rather than error/fail) when the arch matches but aiter isn't importable. -_aiter_available = ( - is_aiter_supported(torch.device("cuda:0")) - and importlib.util.find_spec("aiter") is not None -) -requires_aiter = pytest.mark.skipif( - not _aiter_available, - reason="AITER backend requires gfx942/gfx950 and the aiter package", -) +from tests.test_helpers.test_helpers import requires_aiter def llama_rms_norm(x, w, eps=1e-6): diff --git a/tests/rocm_tests/test_rmsnorm_aiter_hip.py b/tests/rocm_tests/test_rmsnorm_aiter_hip.py index 841f065d60..9105cfa23d 100644 --- a/tests/rocm_tests/test_rmsnorm_aiter_hip.py +++ b/tests/rocm_tests/test_rmsnorm_aiter_hip.py @@ -12,7 +12,7 @@ import torch import flashinfer -from flashinfer.aiter_utils import is_aiter_supported +from tests.test_helpers.test_helpers import requires_aiter def _rms_norm_ref(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: @@ -23,10 +23,7 @@ def _rms_norm_ref(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch. return (x * torch.rsqrt(variance + eps) * w.float()).to(orig) -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [128, 512, 1024, 4096]) @pytest.mark.parametrize("batch_size", [1, 32, 256]) @@ -44,10 +41,7 @@ def test_rmsnorm_aiter_vs_ref(dtype, hidden_size, batch_size): torch.testing.assert_close(got.float(), ref.float(), rtol=rtol, atol=atol) -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter def test_rmsnorm_auto_backend_stays_native(): """auto backend on gfx942/950 should stay on native kernel (precision parity with tests).""" from flashinfer.norm import _auto_select_norm_backend @@ -59,10 +53,7 @@ def test_rmsnorm_auto_backend_stays_native(): assert _auto_select_norm_backend(device, torch.float32) == "native" -@pytest.mark.skipif( - not is_aiter_supported(torch.device("cuda:0")), - reason="AITER backend requires gfx942/gfx950", -) +@requires_aiter def test_rmsnorm_aiter_with_out_tensor(): """backend='aiter' respects the out= argument.""" device = torch.device("cuda:0") diff --git a/tests/test_helpers/test_helpers.py b/tests/test_helpers/test_helpers.py index 1b4ac043bc..a65804998e 100644 --- a/tests/test_helpers/test_helpers.py +++ b/tests/test_helpers/test_helpers.py @@ -1,11 +1,26 @@ import torch import functools +import importlib.util import os from flashinfer.utils import GPUArchitectureError +from flashinfer.aiter_utils import is_aiter_supported import pytest import gc +# is_aiter_supported only checks the GPU arch; AITER is a separate source install, +# so a supported board can still lack the package. Require both so AITER tests skip +# (rather than error/fail) when the arch matches but aiter isn't importable. +_aiter_available = ( + is_aiter_supported(torch.device("cuda:0")) + and importlib.util.find_spec("aiter") is not None +) +requires_aiter = pytest.mark.skipif( + not _aiter_available, + reason="AITER backend requires gfx942/gfx950 and the aiter package", +) + + @functools.cache def get_device_properties(device: torch.device): return torch.cuda.get_device_properties(device)