feat(hip,aiter): add AITER backend for silu_and_mul #251
Merged
demandal25 merged 6 commits on Jun 15, 2026
Pull request overview
Adds an AMD AITER-backed implementation path for flashinfer.activation.silu_and_mul on ROCm/HIP via a new backend="auto"|"native"|"aiter" parameter, with auto-selection for large fp16 2D inputs and ROCm-only tests covering correctness and selection behavior.
Changes:
- Added HIP-only AITER integration and `backend` parameter to `flashinfer.activation.silu_and_mul`, including an auto-selector with a size cutoff.
- Added ROCm tests validating AITER correctness vs a reference, `out=` behavior, auto-selection branches, and unknown-backend errors.
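The reference used by the tests is not shown on this page. As a minimal sketch of the usual definition of the operation (the first half of the last dimension is passed through SiLU and multiplied elementwise by the second half), a pure-Python version for a single row might look like the following. The function names are hypothetical and the code is illustrative only; it is not the test suite's reference.

```python
import math


def silu(x: float) -> float:
    # SiLU (swish): x * sigmoid(x), written as x / (1 + exp(-x)).
    return x / (1.0 + math.exp(-x))


def silu_and_mul_reference(row: list[float]) -> list[float]:
    """Hypothetical reference for one row of length 2 * hidden.

    Returns silu(row[:hidden]) * row[hidden:], elementwise.
    """
    if len(row) % 2 != 0:
        raise ValueError("last dimension must be even (2 * hidden)")
    hidden = len(row) // 2
    gate, up = row[:hidden], row[hidden:]
    return [silu(g) * u for g, u in zip(gate, up)]
```

For a row of length 4, the output has length 2: `silu_and_mul_reference([0.0, 1.0, 2.0, 3.0])` pairs `0.0` with `2.0` and `1.0` with `3.0`.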
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| flashinfer/activation.py | Adds backend parameter and HIP AITER routing/auto-selection for silu_and_mul. |
| tests/rocm_tests/test_activation_aiter_hip.py | Adds ROCm-only tests for AITER backend correctness and selection logic. |
Route flashinfer.activation.silu_and_mul through AMD AITER's silu_and_mul on ROCm via a backend="auto"|"native"|"aiter" parameter, mirroring the existing norm.py AITER-backend idiom.

"auto" stays on the native JIT kernel except for large (>=64M element) 2D fp16 inputs, where AITER is ~5-10% faster and matches native precision. bf16 is excluded from the auto path (AITER max err ~6e-2 vs native ~4e-3); "aiter" remains available as an explicit opt-in.

Adds tests/rocm_tests/test_activation_aiter_hip.py covering correctness across shapes/dtypes, out= handling, backend auto-selection, and the unknown-backend error.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…-div

Address Copilot review on PR ROCm#251:
- Validate the backend argument unconditionally so an unknown value or an explicit backend="aiter" off ROCm/unsupported arch raises ValueError instead of silently falling through to the native kernel.
- Use the clearer ceil-to-multiple-of-8 form in the auto-selection test.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
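The two points in this commit can be sketched in a few lines. This is a hedged illustration, not the code in `flashinfer/activation.py`: `validate_backend`, its `aiter_supported` flag, and `ceil_to_multiple_of_8` are hypothetical names, and the error messages are invented for the example.

```python
_VALID_BACKENDS = ("auto", "native", "aiter")


def validate_backend(backend: str, aiter_supported: bool) -> str:
    """Validate unconditionally, before any platform-specific routing."""
    if backend not in _VALID_BACKENDS:
        raise ValueError(
            f"Unknown backend {backend!r}; expected one of {_VALID_BACKENDS}"
        )
    # An explicit "aiter" request must fail loudly instead of silently
    # falling through to the native kernel.
    if backend == "aiter" and not aiter_supported:
        raise ValueError('backend="aiter" requires ROCm on a supported arch')
    return backend


def ceil_to_multiple_of_8(n: int) -> int:
    # Round a non-negative integer up to the next multiple of 8.
    return (n + 7) // 8 * 8
```

Because the check runs before any routing decision, a typo such as `backend="aitr"` raises on every platform instead of quietly selecting the native path.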
…t in README

Set the auto-selection threshold to the measured ~33M-element break-even (was a conservative 64M). Update the README feature matrix and AITER Support section to list silu_and_mul's AITER backend and its auto-routing criteria.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Address second Copilot review on PR ROCm#251:
- backend="aiter" now probes _aiter_act_ops() and re-raises a clear ValueError (chaining the original) when the aiter package is missing or fails to import, instead of surfacing a cryptic ImportError at the call.
- The out= test seeds the tensor with NaN and asserts numerical correctness against the reference, so a no-op write can no longer pass.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
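Both ideas in this commit are general patterns and can be sketched without flashinfer or AITER. In the sketch below, `probe_backend_module` stands in for the import probe (the real code calls `_aiter_act_ops()`, whose body is not shown here), and `check_out_written` stands in for the NaN-seeded `out=` test using plain lists instead of tensors. Both names are hypothetical, and only `ImportError` is caught, which is an assumption about how the import fails.

```python
import importlib
import math


def probe_backend_module(module_name: str):
    """Import a backend module, converting ImportError to a clear ValueError."""
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:  # includes ModuleNotFoundError
        # "from exc" chains the original error so the root cause stays visible.
        raise ValueError(
            f"requested backend needs the {module_name!r} package, "
            "which could not be imported"
        ) from exc


def check_out_written(op, inputs, expected, tol=1e-6) -> bool:
    """Seed the output with NaN, run op(inputs, out), compare to expected.

    Any element left as NaN fails the comparison, so an op that never
    writes to `out` cannot pass.
    """
    out = [math.nan] * len(expected)
    op(inputs, out)
    return all(abs(o - e) <= tol for o, e in zip(out, expected))
```

The NaN seed matters because an uninitialized or zero-filled buffer can match the expected values by accident; NaN never compares as within tolerance of anything.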
4259b0b to f5ef7a4
Add requires_aiter to tests/test_helpers/test_helpers.py (gating on arch + aiter importability) and import it from every AITER rocm test, replacing the per-file copies of the @pytest.mark.skipif(not is_aiter_supported...) decorator. One definition, no duplicates.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Address third Copilot review on PR ROCm#251:
- Type silu_and_mul's out= as Optional[torch.Tensor] to match the enable_pdl: Optional[bool] style.
- Make the backend="aiter" arch-check error strictly about the ROCm/arch requirement and include the actual device; the missing-package case is already reported separately by the import probe below.
- Rephrase the README Activation matrix cell to the "AITER when ...; else HIP native" pattern used by the other rows.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Summary
Routes `flashinfer.activation.silu_and_mul` through AMD AITER's `silu_and_mul` on ROCm via a new `backend="auto"|"native"|"aiter"` parameter, mirroring the existing AITER-backend idiom in `norm.py`.

What changed
- `flashinfer/activation.py`: `silu_and_mul` gains a `backend` param. On ROCm, a `_auto_select_silu_and_mul_backend` selector routes large 2D fp16 inputs to AITER; everything else stays on the native JIT kernel. `"aiter"` is available as an explicit opt-in; `"native"` forces the JIT kernel. The `backend` argument is validated on all platforms (unknown values and an off-ROCm/unsupported `"aiter"` raise `ValueError` rather than silently falling back).
- `tests/rocm_tests/test_activation_aiter_hip.py`: new tests covering correctness vs reference across shapes/dtypes, `out=` handling, auto-selection branches, unknown-backend error, and unsupported-`aiter` rejection.
- `README.md`: feature matrix + AITER Support section updated for the new `silu_and_mul` backend.

Architecture / design notes
`auto` backend selection (ROCm only):

The cutoff (`33 * 1024 * 1024` input elements, i.e. rows x 2*hidden) is the measured break-even, e.g. 2048 x 16384.
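The routing criteria named in this description (ROCm with AITER available, 2D input, fp16, element count at or above the cutoff) can be sketched as a small pure function. This is an illustration under assumptions, not the body of `_auto_select_silu_and_mul_backend`: the function name, the string dtype, the boolean platform flag, and the use of `>=` at the cutoff are all hypothetical.

```python
import math

# Cutoff quoted in the PR description, counted in input elements.
AITER_MIN_ELEMENTS = 33 * 1024 * 1024


def auto_select_backend(shape, dtype: str, on_rocm_with_aiter: bool) -> str:
    """Hypothetical sketch of the "auto" routing decision."""
    if not on_rocm_with_aiter:
        return "native"
    if len(shape) != 2:  # only 2D inputs are routed to AITER
        return "native"
    if dtype != "float16":  # bf16 is excluded from the auto path
        return "native"
    # Assumption: the comparison is inclusive at the cutoff.
    if math.prod(shape) < AITER_MIN_ELEMENTS:
        return "native"
    return "aiter"
```

Keeping the decision a pure function of shape, dtype, and platform makes each branch testable without a GPU, which is the kind of coverage the auto-selection tests describe.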
Benchmark results
silu_and_mul, gfx942, CUDA-event timed (20 warmup / 200 iters):
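The "20 warmup / 200 iters" methodology can be sketched as a small harness. The benchmark above was timed with CUDA events on the device; the sketch below swaps in a host-side clock (`time.perf_counter`) so it runs anywhere, which is only a stand-in: on a GPU the timed region would also need device synchronization. `time_op` and its parameters are hypothetical names.

```python
import time


def time_op(fn, warmup: int = 20, iters: int = 200, clock=time.perf_counter):
    """Return the mean time per call of fn(), excluding warmup calls."""
    # Warmup calls absorb one-time costs (JIT compilation, cache fills).
    for _ in range(warmup):
        fn()
    start = clock()
    for _ in range(iters):
        fn()
    return (clock() - start) / iters
```

Passing the clock as a parameter keeps the harness deterministic under test and makes the timing source explicit.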
Test plan
- `pytest tests/rocm_tests/test_activation_hip.py tests/rocm_tests/test_activation_aiter_hip.py -m "not slow"` -> 94 passed
- `FLASHINFER_TEST_TORCH_COMPILE=1 pytest tests/rocm_tests/test_torch_compile_hip.py` -> 3 passed, 1 skipped
- `pre-commit run -a` (changed files: ruff + markdownlint)