[runtime] feat: support variant mode#17
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces a "variant" mode for Flash Attention, allowing for dynamic num_splits selection to improve throughput when batch-invariance is not required. The changes include new CLI arguments, configuration validation, and logic within the inferencer to toggle batch-invariant patches. Review feedback suggests extending this support to Blackwell hardware by adding a fa-variant-cute implementation and recommends explicitly disabling batch-invariant mode when the variant path is selected to ensure consistent global state.
|
|
||
| # Non-invariant variant: same kernel but without the num_splits=1 lock, | ||
| # so the kernel is free to pick Split-KV counts for best throughput. | ||
| flash_attention_forward_variant = partial(flash_attention_forward, lock_num_splits=False) |
There was a problem hiding this comment.
The fa-variant implementation currently defaults to use_cute=False, which will cause a crash on Blackwell (SM100) hardware due to the assertion at line 167 (which requires SM90 for the non-cute path). Adding a fa-variant-cute implementation allows variant mode to be used on newer hardware, following the existing pattern for invariant mode.
| flash_attention_forward_variant = partial(flash_attention_forward, lock_num_splits=False) | |
| flash_attention_forward_variant = partial(flash_attention_forward, lock_num_splits=False) | |
| flash_attention_forward_variant_cute = partial(flash_attention_forward, lock_num_splits=False, use_cute=True) |
| choices=["fa-invariant", "fa-invariant-cute", "flex"], | ||
| default="fa-invariant", | ||
| help="Attention implementation (default: fa-invariant, i.e. flash attn).", | ||
| choices=["fa-invariant", "fa-invariant-cute", "flex", "fa-variant"], |
There was a problem hiding this comment.
Add fa-variant-cute to the choices to support variant mode on Blackwell (SM100) hardware, consistent with the fa-invariant-cute option.
| choices=["fa-invariant", "fa-invariant-cute", "flex", "fa-variant"], | |
| choices=["fa-invariant", "fa-invariant-cute", "flex", "fa-variant", "fa-variant-cute"], |
| from .flash_attention import ( | ||
| flash_attention_forward, | ||
| flash_attention_forward_cute, | ||
| flash_attention_forward_variant, |
| "AttentionBlockSize", | ||
| "flash_attention_forward", | ||
| "flash_attention_forward_cute", | ||
| "flash_attention_forward_variant", |
|
|
||
| # Validate attention implementation | ||
| valid_attn_impls = ["fa-invariant", "fa-invariant-cute", "flex"] | ||
| valid_attn_impls = ["fa-invariant", "fa-invariant-cute", "flex", "fa-variant"] |
There was a problem hiding this comment.
| ) | ||
| from vexact.batch_invariant_ops import flash_attention_forward as flash_attention_forward_impl | ||
| from vexact.batch_invariant_ops import flash_attention_forward_cute as flash_attention_forward_cute_impl | ||
| from vexact.batch_invariant_ops import flash_attention_forward_variant as flash_attention_forward_variant_impl |
There was a problem hiding this comment.
Import disable_batch_invariant_mode to allow explicitly disabling the global ATen patches when variant mode is requested. Also import the new cute variant for Blackwell support.
| from vexact.batch_invariant_ops import flash_attention_forward_variant as flash_attention_forward_variant_impl | |
| from vexact.batch_invariant_ops import ( | |
| disable_batch_invariant_mode, | |
| flash_attention_forward_variant as flash_attention_forward_variant_impl, | |
| flash_attention_forward_variant_cute as flash_attention_forward_variant_cute_impl, | |
| ) |
| ALL_ATTENTION_FUNCTIONS["fa-invariant"] = flash_attention_forward_impl | ||
| ALL_ATTENTION_FUNCTIONS["fa-invariant-cute"] = flash_attention_forward_cute_impl | ||
| # Non-invariant variant: same kernel path as fa-invariant but without num_splits=1. | ||
| ALL_ATTENTION_FUNCTIONS["fa-variant"] = flash_attention_forward_variant_impl |
There was a problem hiding this comment.
Register the fa-variant-cute implementation.
| ALL_ATTENTION_FUNCTIONS["fa-variant"] = flash_attention_forward_variant_impl | |
| ALL_ATTENTION_FUNCTIONS["fa-variant"] = flash_attention_forward_variant_impl | |
| ALL_ATTENTION_FUNCTIONS["fa-variant-cute"] = flash_attention_forward_variant_cute_impl |
| if enable_batch_invariant and not is_batch_invariant_mode_enabled(): | ||
| enable_batch_invariant_mode() | ||
| if not enable_batch_invariant: | ||
| logger.info("[VEXACT] Inferencer: batch invariant mode DISABLED (variant)") |
There was a problem hiding this comment.
The current implementation only enables batch-invariant mode if requested, but it doesn't disable it if it was previously enabled (e.g., in a shared process or test environment). To ensure "variant" mode truly disables the patches as described in the CLI help, disable_batch_invariant_mode() should be called when enable_batch_invariant is False.
if enable_batch_invariant:
if not is_batch_invariant_mode_enabled():
enable_batch_invariant_mode()
else:
if is_batch_invariant_mode_enabled():
disable_batch_invariant_mode()
logger.info("[VEXACT] Inferencer: batch invariant mode DISABLED (variant)")Document VeXact and vllm throughput cost of batch-invariance (256/512 requests, Qwen3-1.7B and Qwen3-30B-A3B, on H100). Ignore the local benchmark runner scripts and result logs. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
No description provided.