[runtime] feat: support variant mode#17

Draft
Neawhen wants to merge 2 commits into main from feat/switch_mode

Neawhen (Collaborator) commented Apr 22, 2026

No description provided.

@Neawhen Neawhen requested review from Luosuu and pengwu22 and removed request for Luosuu and pengwu22 April 22, 2026 21:20
@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces a "variant" mode for Flash Attention, allowing for dynamic num_splits selection to improve throughput when batch-invariance is not required. The changes include new CLI arguments, configuration validation, and logic within the inferencer to toggle batch-invariant patches. Review feedback suggests extending this support to Blackwell hardware by adding a fa-variant-cute implementation and recommends explicitly disabling batch-invariant mode when the variant path is selected to ensure consistent global state.


# Non-invariant variant: same kernel but without the num_splits=1 lock,
# so the kernel is free to pick Split-KV counts for best throughput.
flash_attention_forward_variant = partial(flash_attention_forward, lock_num_splits=False)

high

The fa-variant implementation currently defaults to use_cute=False, which will cause a crash on Blackwell (SM100) hardware due to the assertion at line 167 (which requires SM90 for the non-cute path). Adding a fa-variant-cute implementation allows variant mode to be used on newer hardware, following the existing pattern for invariant mode.

Suggested change
- flash_attention_forward_variant = partial(flash_attention_forward, lock_num_splits=False)
+ flash_attention_forward_variant = partial(flash_attention_forward, lock_num_splits=False)
+ flash_attention_forward_variant_cute = partial(flash_attention_forward, lock_num_splits=False, use_cute=True)
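
The pattern behind this suggestion can be sketched as follows. This is a minimal illustration, not the project's code: `flash_attention_forward` here is a stand-in that only reports the options it received, and the returned dict is an assumption made for demonstration. The `lock_num_splits` and `use_cute` parameter names are taken from the PR.

```python
from functools import partial


def flash_attention_forward(query, *, lock_num_splits=True, use_cute=False):
    """Stand-in for the real kernel wrapper; it only echoes its options."""
    # Assumption: invariant mode pins num_splits to 1, while None means
    # "let the kernel choose the Split-KV count".
    num_splits = 1 if lock_num_splits else None
    return {"query": query, "num_splits": num_splits, "use_cute": use_cute}


# Each variant is the same function with some keyword arguments pre-bound.
flash_attention_forward_variant = partial(flash_attention_forward, lock_num_splits=False)
flash_attention_forward_variant_cute = partial(
    flash_attention_forward, lock_num_splits=False, use_cute=True
)

invariant = flash_attention_forward("q")
variant = flash_attention_forward_variant("q")
variant_cute = flash_attention_forward_variant_cute("q")
```

Because `partial` only pre-binds keyword arguments, both variants share one implementation, so the cute and non-cute paths cannot drift apart.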

Comment thread benchmarks/throughput.py
choices=["fa-invariant", "fa-invariant-cute", "flex"],
default="fa-invariant",
help="Attention implementation (default: fa-invariant, i.e. flash attn).",
choices=["fa-invariant", "fa-invariant-cute", "flex", "fa-variant"],

medium

Add fa-variant-cute to the choices to support variant mode on Blackwell (SM100) hardware, consistent with the fa-invariant-cute option.

Suggested change
- choices=["fa-invariant", "fa-invariant-cute", "flex", "fa-variant"],
+ choices=["fa-invariant", "fa-invariant-cute", "flex", "fa-variant", "fa-variant-cute"],

from .flash_attention import (
flash_attention_forward,
flash_attention_forward_cute,
flash_attention_forward_variant,

medium

Export flash_attention_forward_variant_cute to support variant mode on Blackwell hardware.

    flash_attention_forward_variant,
    flash_attention_forward_variant_cute,

"AttentionBlockSize",
"flash_attention_forward",
"flash_attention_forward_cute",
"flash_attention_forward_variant",

medium

Add flash_attention_forward_variant_cute to __all__.

Suggested change
- "flash_attention_forward_variant",
+ "flash_attention_forward_variant",
+ "flash_attention_forward_variant_cute",

Comment thread vexact/config.py

# Validate attention implementation
valid_attn_impls = ["fa-invariant", "fa-invariant-cute", "flex"]
valid_attn_impls = ["fa-invariant", "fa-invariant-cute", "flex", "fa-variant"]

medium

Include fa-variant-cute in the list of valid attention implementations.

Suggested change
- valid_attn_impls = ["fa-invariant", "fa-invariant-cute", "flex", "fa-variant"]
+ valid_attn_impls = ["fa-invariant", "fa-invariant-cute", "flex", "fa-variant", "fa-variant-cute"]
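
Since the same list of names appears in both the benchmark CLI and the config validation, one way to keep them in sync is to share a single constant. The sketch below assumes a hypothetical `--attn-impl` flag and a hypothetical `validate_attn_impl` helper; neither name is confirmed by the PR, and only the implementation names come from it.

```python
import argparse

# Assumption: one shared tuple feeds both the CLI choices and the config check.
VALID_ATTN_IMPLS = (
    "fa-invariant",
    "fa-invariant-cute",
    "flex",
    "fa-variant",
    "fa-variant-cute",
)


def validate_attn_impl(name):
    """Hypothetical config-side check: reject names outside the shared tuple."""
    if name not in VALID_ATTN_IMPLS:
        raise ValueError(
            f"Unknown attention implementation {name!r}; expected one of {VALID_ATTN_IMPLS}"
        )
    return name


def build_parser():
    """Hypothetical benchmark CLI using the same tuple for its choices."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--attn-impl",
        choices=VALID_ATTN_IMPLS,
        default="fa-invariant",
        help="Attention implementation (default: fa-invariant).",
    )
    return parser


args = build_parser().parse_args(["--attn-impl", "fa-variant-cute"])
selected = validate_attn_impl(args.attn_impl)
```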

)
from vexact.batch_invariant_ops import flash_attention_forward as flash_attention_forward_impl
from vexact.batch_invariant_ops import flash_attention_forward_cute as flash_attention_forward_cute_impl
from vexact.batch_invariant_ops import flash_attention_forward_variant as flash_attention_forward_variant_impl

medium

Import disable_batch_invariant_mode to allow explicitly disabling the global ATen patches when variant mode is requested. Also import the new cute variant for Blackwell support.

Suggested change
- from vexact.batch_invariant_ops import flash_attention_forward_variant as flash_attention_forward_variant_impl
+ from vexact.batch_invariant_ops import (
+     disable_batch_invariant_mode,
+     flash_attention_forward_variant as flash_attention_forward_variant_impl,
+     flash_attention_forward_variant_cute as flash_attention_forward_variant_cute_impl,
+ )

ALL_ATTENTION_FUNCTIONS["fa-invariant"] = flash_attention_forward_impl
ALL_ATTENTION_FUNCTIONS["fa-invariant-cute"] = flash_attention_forward_cute_impl
# Non-invariant variant: same kernel path as fa-invariant but without num_splits=1.
ALL_ATTENTION_FUNCTIONS["fa-variant"] = flash_attention_forward_variant_impl

medium

Register the fa-variant-cute implementation.

Suggested change
- ALL_ATTENTION_FUNCTIONS["fa-variant"] = flash_attention_forward_variant_impl
+ ALL_ATTENTION_FUNCTIONS["fa-variant"] = flash_attention_forward_variant_impl
+ ALL_ATTENTION_FUNCTIONS["fa-variant-cute"] = flash_attention_forward_variant_cute_impl

Comment on lines +95 to +98
        if enable_batch_invariant and not is_batch_invariant_mode_enabled():
            enable_batch_invariant_mode()
        if not enable_batch_invariant:
            logger.info("[VEXACT] Inferencer: batch invariant mode DISABLED (variant)")

medium

The current implementation only enables batch-invariant mode if requested, but it doesn't disable it if it was previously enabled (e.g., in a shared process or test environment). To ensure "variant" mode truly disables the patches as described in the CLI help, disable_batch_invariant_mode() should be called when enable_batch_invariant is False.

        if enable_batch_invariant:
            if not is_batch_invariant_mode_enabled():
                enable_batch_invariant_mode()
        else:
            if is_batch_invariant_mode_enabled():
                disable_batch_invariant_mode()
            logger.info("[VEXACT] Inferencer: batch invariant mode DISABLED (variant)")
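
To see why the explicit disable matters, the toggle can be exercised against a stand-in for the global patch state. In this sketch the three mode functions are simplified stubs around a module-level flag (the real ones patch ATen ops), and `configure_batch_invariance` is a hypothetical wrapper name for the logic suggested above.

```python
import logging

logger = logging.getLogger("vexact.sketch")

# Stand-in for the library's process-wide patch state (assumption).
_batch_invariant_enabled = False


def is_batch_invariant_mode_enabled():
    return _batch_invariant_enabled


def enable_batch_invariant_mode():
    global _batch_invariant_enabled
    _batch_invariant_enabled = True


def disable_batch_invariant_mode():
    global _batch_invariant_enabled
    _batch_invariant_enabled = False


def configure_batch_invariance(enable_batch_invariant):
    """Hypothetical wrapper: make the global state match the requested mode."""
    if enable_batch_invariant:
        if not is_batch_invariant_mode_enabled():
            enable_batch_invariant_mode()
    else:
        if is_batch_invariant_mode_enabled():
            disable_batch_invariant_mode()
        logger.info("[VEXACT] Inferencer: batch invariant mode DISABLED (variant)")


# Simulate a shared process where an earlier caller left the patches enabled.
enable_batch_invariant_mode()
configure_batch_invariance(False)
state_after_variant = is_batch_invariant_mode_enabled()

configure_batch_invariance(True)
state_after_invariant = is_batch_invariant_mode_enabled()
```

With the original two-`if` version, the first call would log "DISABLED" while leaving the patches active, which is the inconsistency the comment points out.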

neiwen.ling and others added 2 commits May 29, 2026 02:56
Document VeXact and vllm throughput cost of batch-invariance (256/512
requests, Qwen3-1.7B and Qwen3-30B-A3B, on H100). Ignore the local
benchmark runner scripts and result logs.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@Neawhen Neawhen force-pushed the feat/switch_mode branch from 4835504 to ea02ac7 May 28, 2026 23:44
@Neawhen Neawhen marked this pull request as draft May 28, 2026 23:54