[runtime] feat: support variant mode #17
base: main
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,183 @@ | ||
| # Batch-Invariant vs Variant Throughput | ||
|
|
||
| Throughput cost of batch-invariance, measured on a single H100. Covers two | ||
| engines: | ||
|
|
||
| - **VeXact** — its `variant` mode (commit `support variant mode`) disables the | ||
| deterministic ATen patches and unlocks FA Split-KV. | ||
| - **vllm 0.14.1** — its native batch-invariant mode, toggled by the | ||
| `VLLM_BATCH_INVARIANT` env var. | ||
|
|
||
| In both engines, the variant (non-invariant) setting trades bit-level reproducibility for throughput. The VeXact and vllm numbers | ||
| are **not** directly comparable (different engines / graph settings); each | ||
| engine's own on/off delta is the meaningful result. | ||
|
|
||
| ## Summary (256 requests, 1× H100, Qwen3-1.7B) | ||
|
|
||
| | Engine | batch-invariant | Throughput (req/s) | Total tokens/s | Wall time | Wall-time slowdown (ON vs OFF) | | ||
| | ------ | --------------- | ------------------ | -------------- | --------- | ------------------------------ | | ||
| | VeXact | ON (invariant) | 9.25 | 4015.19 | 27.66s | **+82%** (1.82×) | | ||
| | VeXact | OFF (variant) | 16.86 | 7314.25 | 15.19s | — (baseline) | | ||
| | vllm | ON (invariant) | 10.83 | 4698.53 | 23.64s | **+127%** (2.27×) | | ||
| | vllm | OFF (variant) | 24.59 | 10666.40 | 10.41s | — (baseline) | | ||
|
|
||
| Turning batch-invariance ON costs **+82% wall time on VeXact** (1.82×) and | ||
| **+127% on vllm** (2.27×). Cross-engine absolute numbers are not comparable | ||
| (VeXact runs with CUDA graph; vllm runs with `enforce_eager`). Per-engine | ||
| details and metric definitions follow below. | ||
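The slowdown column can be rechecked from the wall times alone. The snippet below is a small arithmetic sketch, not part of the benchmark scripts; the helper name `slowdown` is made up for illustration, and the wall times are copied from the 256-request table above.

```python
def slowdown(wall_on_s: float, wall_off_s: float) -> tuple:
    """Return (ratio, percent increase) of invariant wall time over variant wall time."""
    ratio = wall_on_s / wall_off_s
    return ratio, (ratio - 1.0) * 100.0


# (ON, OFF) wall times in seconds, copied from the 256-request summary table.
runs = {"VeXact": (27.66, 15.19), "vllm": (23.64, 10.41)}

for engine, (on_s, off_s) in runs.items():
    ratio, pct = slowdown(on_s, off_s)
    # Share of wall time saved when switching from ON to OFF.
    saved = (1.0 - off_s / on_s) * 100.0
    print(f"{engine}: {ratio:.2f}x slower (+{pct:.0f}%), variant saves {saved:.0f}% wall time")
```

The same ratio explains why a +82% slowdown corresponds to only a 45% reduction in wall time: the two percentages use different baselines (OFF for the slowdown, ON for the saving).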
|
|
||
| ### Summary — 512 requests | ||
|
|
||
| | Engine | batch-invariant | Throughput (req/s) | Total tokens/s | Wall time | Wall-time slowdown (ON vs OFF) | | ||
| | ------ | --------------- | ------------------ | -------------- | --------- | ------------------------------ | | ||
| | VeXact | ON (invariant) | 11.82 | 5030.05 | 43.30s | **+77%** (1.77×) | | ||
| | VeXact | OFF (variant) | 20.88 | 8881.09 | 24.52s | — (baseline) | | ||
| | vllm | ON (invariant) | 18.95 | 8060.05 | 27.02s | **+111%** (2.11×) | | ||
| | vllm | OFF (variant) | 40.03 | 17028.39 | 12.79s | — (baseline) | | ||
|
|
||
| ## Setup | ||
|
|
||
| | Item | Value | | ||
| | --------- | --------------------------------------------------------------------- | | ||
| | Model | `Qwen3-1.7B` | | ||
| | Dataset | `ShareGPT_V3_unfiltered_cleaned_split.json` | | ||
| | Requests | 256 (512 for the 512-request tables) | | ||
| | GPU | 1× H100-SXM-80GB (`CUDA_VISIBLE_DEVICES=0`) | | ||
| | Scheduler | `max_num_batched_tokens=2048`, `max_num_seqs=512`, chunked prefill on | | ||
| | Branch | `feat/switch_mode` (rebased on `main`) | | ||
|
|
||
| Reproduce with `benchmarks/run_invariance_compare.sh` (knobs: `NUM_REQUESTS`, `GPU`). | ||
|
|
||
| # VeXact | ||
|
|
||
| ## Results | ||
|
|
||
| ### 256 requests | ||
|
|
||
| | Metric | invariant (default) | variant | Δ | | ||
| | ------------------ | ------------------- | ------- | -------- | | ||
| | Throughput (req/s) | 9.25 | 16.86 | **+82%** | | ||
| | Total tokens/s | 4015.19 | 7314.25 | **+82%** | | ||
| | Output tokens/s | 2017.25 | 3674.71 | **+82%** | | ||
| | Wall time | 27.66s | 15.19s | **−45%** | | ||
| | Avg latency | 15.31s | 8.44s | −45% | | ||
| | P50 latency | 16.26s | 9.01s | −45% | | ||
| | P95 latency | 24.28s | 13.39s | −45% | | ||
|
|
||
| ### 512 requests | ||
|
|
||
| | Metric | invariant (default) | variant | Δ | | ||
| | ------------------ | ------------------- | ------- | -------- | | ||
| | Throughput (req/s) | 11.82 | 20.88 | **+77%** | | ||
| | Total tokens/s | 5030.05 | 8881.09 | **+77%** | | ||
| | Output tokens/s | 2597.63 | 4586.38 | **+77%** | | ||
| | Wall time | 43.30s | 24.52s | **−43%** | | ||
| | Avg latency | 24.72s | 13.83s | −44% | | ||
| | P50 latency | 26.31s | 14.72s | −44% | | ||
| | P95 latency | 39.21s | 22.21s | −43% | | ||
|
|
||
| The speedup holds across scales (+82% at 256 req, +77% at 512 req) — it is not | ||
| a small-sample artifact. Larger batches lift both modes (fuller batches), so the | ||
| relative gap stays roughly constant. | ||
|
|
||
| > Token counts differ slightly between runs because sampling draws different | ||
| > ShareGPT samples; throughput/latency are the comparable metrics. | ||
|
|
||
| ## Metric definitions | ||
|
|
||
| All requests are submitted concurrently via `asyncio.gather`; **wall time** is | ||
| the clock time wrapping that gather (`benchmarks/throughput.py:147-149`): | ||
|
|
||
| ``` | ||
| wall_time = t_end − t_start # around gather() of all concurrent requests | ||
| throughput = completed / wall_time # requests/s | ||
| total_tps = (prompt_tokens + output_tokens) / wall_time | ||
| output_tps = output_tokens / wall_time | ||
| latency_i = t_done_i − t_submit_i # per-request, end-to-end | ||
| avg/P50/P95 = mean / 50th / 95th percentile over latency_i | ||
| ``` | ||
|
|
||
| Because requests run concurrently, `wall_time` is the makespan of the whole | ||
| batch — close to `max(latency_i)`, not their sum. | ||
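The definitions above can be sketched as a runnable toy. This is not the code in `benchmarks/throughput.py`: `fake_request`, `percentile`, and `run_benchmark` are hypothetical names, the request is simulated with `asyncio.sleep`, and the nearest-rank percentile is an assumption about how P50/P95 might be computed.

```python
import asyncio
import math
import time


async def fake_request(delay_s: float, prompt_tokens: int, output_tokens: int) -> dict:
    """Hypothetical stand-in for one engine request: sleeps instead of generating."""
    t_submit = time.perf_counter()
    await asyncio.sleep(delay_s)
    return {
        "latency": time.perf_counter() - t_submit,  # latency_i = t_done_i - t_submit_i
        "prompt_tokens": prompt_tokens,
        "output_tokens": output_tokens,
    }


def percentile(values: list, q: float) -> float:
    """Nearest-rank percentile; the real benchmark may interpolate differently."""
    ordered = sorted(values)
    rank = max(1, math.ceil(q / 100.0 * len(ordered)))
    return ordered[rank - 1]


async def run_benchmark(delays: list) -> dict:
    t_start = time.perf_counter()
    # All requests are in flight at once, so wall_time is the makespan.
    results = await asyncio.gather(*(fake_request(d, 100, 50) for d in delays))
    wall_time = time.perf_counter() - t_start

    latencies = [r["latency"] for r in results]
    prompt = sum(r["prompt_tokens"] for r in results)
    output = sum(r["output_tokens"] for r in results)
    return {
        "completed": len(results),
        "wall_time": wall_time,
        "throughput": len(results) / wall_time,
        "total_tps": (prompt + output) / wall_time,
        "output_tps": output / wall_time,
        "avg_latency": sum(latencies) / len(latencies),
        "p50_latency": percentile(latencies, 50),
        "p95_latency": percentile(latencies, 95),
    }


metrics = asyncio.run(run_benchmark([0.01, 0.02, 0.03, 0.04]))
print(metrics)
```

Because the timer wraps the whole `gather`, every per-request latency is bounded by `wall_time`, which is why the wall time in the tables sits just above the P95 latency rather than near the sum of latencies.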
|
|
||
| ## What changes between modes | ||
|
|
||
| | Layer | invariant | variant | | ||
| | ----------------------------------- | ----------------------------------------------------------------------- | -------------------------------------------------------------------------- | | ||
| | ATen ops (`enable_batch_invariant`) | deterministic Triton kernels patched in (matmul, rms_norm, …) | patches off — native CUDA kernels | | ||
| | Attention (`attn_impl`) | `fa-invariant` — FA forced to `num_splits=1`, fixed LSE reduction order | `fa-variant` — `num_splits` unlocked, kernel picks Split-KV by batch shape | | ||
|
|
||
| Both effects are confirmed in the run logs (`batch invariant mode DISABLED (variant)`, | ||
| `attn_impl: fa-variant`). | ||
|
|
||
| ## Takeaway | ||
|
|
||
| Disabling batch-invariance is **~1.8× faster** end-to-end and cuts wall time | ||
| and latency by roughly **45%**. Put the other way, the invariant mode pays about | ||
| **80% extra latency** to buy bit-level reproducibility / batch invariance, which is useful when training and | ||
| inference must produce identical logits regardless of how requests are batched. | ||
|
|
||
| ## How to switch | ||
|
|
||
| ```bash | ||
| # invariant (default) | ||
| python benchmarks/throughput.py --model-path <model> --dataset-path <data> --mode invariant | ||
|
|
||
| # variant | ||
| python benchmarks/throughput.py --model-path <model> --dataset-path <data> --mode variant | ||
| ``` | ||
|
|
||
| In code (`ModelConfig`): `enable_batch_invariant=True/False` paired with | ||
| `attn_impl="fa-invariant"` / `"fa-variant"`. | ||
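The pairing described above can be sketched as a small lookup. This is an illustration only: `ModeSettings` and `settings_for_mode` are hypothetical names, not part of VeXact's `ModelConfig`; only the field names `enable_batch_invariant` and `attn_impl` and the two `attn_impl` strings come from this PR.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ModeSettings:
    """Hypothetical container for the two knobs that must change together."""
    enable_batch_invariant: bool
    attn_impl: str


# The two flags are paired: invariant ATen patches go with the locked attention path.
_MODES = {
    "invariant": ModeSettings(enable_batch_invariant=True, attn_impl="fa-invariant"),
    "variant": ModeSettings(enable_batch_invariant=False, attn_impl="fa-variant"),
}


def settings_for_mode(mode: str) -> ModeSettings:
    """Map a CLI-style mode name to the paired settings, rejecting unknown modes."""
    try:
        return _MODES[mode]
    except KeyError:
        raise ValueError(f"mode must be one of {sorted(_MODES)}, got '{mode}'") from None


print(settings_for_mode("variant"))
```

Keeping the two values in one table avoids a mixed configuration, such as invariant ATen patches combined with the unlocked attention path, which would pay part of the cost without guaranteeing batch invariance.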
|
|
||
| # vllm (native batch-invariant) | ||
|
|
||
| vllm ships its own batch-invariant mode (`vllm.model_executor.layers.batch_invariant`), | ||
| independent of VeXact's. It is toggled by the `VLLM_BATCH_INVARIANT` env var and | ||
| **requires** an explicit attention backend (`FLASH_ATTN` / `FLASHINFER` / | ||
| `*_MLA`) — enabling it with the default `None` backend raises at engine init. | ||
|
|
||
| ## Setup | ||
|
|
||
| | Item | Value | | ||
| | --------------- | ----------------------------------------------------------------------------------------- | | ||
| | vllm | 0.14.1 | | ||
| | Model / Dataset | `Qwen3-1.7B` / same ShareGPT file as above | | ||
| | Requests | 256 (identical samples — same loader + fixed seed) | | ||
| | Output length | fixed via `ignore_eos`, so on/off do identical token work | | ||
| | Engine | `enforce_eager=True` for **both** runs (cudagraph off, isolates the batch-invariant cost) | | ||
| | Backend | `VLLM_ATTENTION_BACKEND=FLASH_ATTN` for both | | ||
| | GPU | 1× H100-SXM-80GB | | ||
|
|
||
| Reproduce with `benchmarks/run_vllm_invariance_compare.sh` | ||
| (benchmark: `benchmarks/vllm_throughput.py`). | ||
|
|
||
| ## Results (256 requests) | ||
|
|
||
| | Metric | invariant (`=1`) | variant (`=0`) | Δ | | ||
| | -------------------- | ---------------- | -------------- | --------- | | ||
| | Throughput (req/s) | 10.83 | 24.59 | **+127%** | | ||
| | Total tokens/s | 4698.53 | 10666.40 | **+127%** | | ||
| | Output tokens/s | 2360.56 | 5358.84 | **+127%** | | ||
| | Wall time (makespan) | 23.64s | 10.41s | **−56%** | | ||
|
|
||
| Prompt/output token counts were identical across runs (55265 / 55799). | ||
| Per-request latency percentiles are omitted: offline `LLM.generate` in vllm | ||
| 0.14.1 does not populate `RequestOutput.metrics`; measuring them needs the | ||
| `AsyncLLMEngine` submit-and-await path. | ||
|
|
||
| ## Takeaway | ||
|
|
||
| vllm's batch-invariant mode is **~2.3× slower** (variant +127%) under eager | ||
| mode, a steeper penalty than VeXact's ~1.8×. A likely cause is that vllm replaces more ATen ops | ||
| (addmm/bmm/mm/rms_norm/softmax/log_softmax/mean) with deterministic kernels. | ||
|
|
||
| ## How to switch | ||
|
|
||
| ```bash | ||
| # variant (default) — batch-invariant off | ||
| VLLM_BATCH_INVARIANT=0 VLLM_ATTENTION_BACKEND=FLASH_ATTN python ... | ||
|
|
||
| # invariant — batch-invariant on (attention backend is mandatory) | ||
| VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python ... | ||
| ``` |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,7 +25,11 @@ | |
| set_batch_invariant_mode, | ||
| triton_bmm, | ||
| ) | ||
| from .flash_attention import flash_attention_forward, flash_attention_forward_cute | ||
| from .flash_attention import ( | ||
| flash_attention_forward, | ||
| flash_attention_forward_cute, | ||
| flash_attention_forward_variant, | ||
|
||
| ) | ||
| from .flex_attention import flex_attention_forward | ||
|
|
||
|
|
||
|
|
@@ -43,6 +47,7 @@ | |
| "AttentionBlockSize", | ||
| "flash_attention_forward", | ||
| "flash_attention_forward_cute", | ||
| "flash_attention_forward_variant", | ||
|
||
| "flex_attention_forward", | ||
| "triton_bmm", | ||
| "batch_invariant_rms_norm", | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -33,6 +33,7 @@ def flash_attention_forward( | |||||||
| dropout: float = 0.0, | ||||||||
| sliding_window: Optional[int] = None, | ||||||||
| use_cute: bool = False, | ||||||||
| lock_num_splits: bool = True, | ||||||||
| **kwargs, | ||||||||
| ): | ||||||||
| """ | ||||||||
|
|
@@ -139,6 +140,11 @@ def flash_attention_forward( | |||||||
| # max_seqlen_k = context_lens.max().item() | ||||||||
| from vexact.utils.device import DEVICE_MAJOR | ||||||||
|
|
||||||||
| # num_splits=1 keeps the FA Split-KV reduction in a single group, which is | ||||||||
| # necessary for batch-invariance (otherwise the kernel picks split counts | ||||||||
| # based on batch shape, changing LSE reduction order). | ||||||||
| # When lock_num_splits=False, let the kernel pick splits dynamically. | ||||||||
| split_kwargs = {"num_splits": 1} if lock_num_splits else {} | ||||||||
| if use_cute: | ||||||||
| assert DEVICE_MAJOR >= 9, f"FA4 (flash_attn.cute) requires SM90+, got SM{DEVICE_MAJOR}0" | ||||||||
| from flash_attn.cute import flash_attn_varlen_func | ||||||||
|
|
@@ -155,7 +161,7 @@ def flash_attention_forward( | |||||||
| page_table=block_tables, # (B, max_num_blocks_per_seq) int32 | ||||||||
| softmax_scale=scaling, | ||||||||
| causal=True, | ||||||||
| num_splits=1, | ||||||||
| **split_kwargs, | ||||||||
| ) | ||||||||
| else: | ||||||||
| assert DEVICE_MAJOR == 9, f"FA3 requires SM90, got SM{DEVICE_MAJOR}0" | ||||||||
|
|
@@ -174,11 +180,15 @@ def flash_attention_forward( | |||||||
| causal=True, | ||||||||
| page_table=block_tables, | ||||||||
| # window_size=window_size | ||||||||
| num_splits=1, | ||||||||
| **split_kwargs, | ||||||||
| ) | ||||||||
| # Output shape: (H, total_query_tokens, D) | ||||||||
| return attn_output, None | ||||||||
|
|
||||||||
|
|
||||||||
| # FA4 (flash_attn.cute) variant — forces use_cute=True on both Hopper and Blackwell | ||||||||
| flash_attention_forward_cute = partial(flash_attention_forward, use_cute=True) | ||||||||
|
|
||||||||
| # Non-invariant variant: same kernel but without the num_splits=1 lock, | ||||||||
| # so the kernel is free to pick Split-KV counts for best throughput. | ||||||||
| flash_attention_forward_variant = partial(flash_attention_forward, lock_num_splits=False) | ||||||||
|
The
Suggested change
|
||||||||
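The `lock_num_splits` pattern in this diff can be shown in isolation. The sketch below assumes a stand-in kernel (`fake_kernel`, hypothetical) that simply echoes the keyword arguments it receives; it does not call FlashAttention, and `attention_forward` is a simplified placeholder for `flash_attention_forward`.

```python
from functools import partial


def fake_kernel(**kwargs) -> dict:
    """Hypothetical stand-in for the attention kernel: returns the kwargs it was given."""
    return dict(kwargs)


def attention_forward(*, causal: bool = True, lock_num_splits: bool = True) -> dict:
    # Pass num_splits=1 only when locked; otherwise omit the argument entirely
    # so the kernel falls back to its own default split heuristic.
    split_kwargs = {"num_splits": 1} if lock_num_splits else {}
    return fake_kernel(causal=causal, **split_kwargs)


# Same function with the lock pre-bound to False, mirroring the partial in the diff.
attention_forward_variant = partial(attention_forward, lock_num_splits=False)

print(attention_forward())          # locked call forwards num_splits=1
print(attention_forward_variant())  # variant call forwards no num_splits
```

Omitting the keyword, rather than passing some sentinel value, means the variant path does not depend on what value the kernel treats as "choose automatically".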
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -135,7 +135,7 @@ def __post_init__(self): | |
| raise ValueError("model_path is required!\nUsage: python script.py model.model_path=/path/to/model") | ||
|
|
||
| # Validate attention implementation | ||
| valid_attn_impls = ["fa-invariant", "fa-invariant-cute", "flex"] | ||
| valid_attn_impls = ["fa-invariant", "fa-invariant-cute", "flex", "fa-variant"] | ||
|
||
| if self.attn_impl not in valid_attn_impls: | ||
| raise ValueError(f"attn_impl must be one of {valid_attn_impls}, got '{self.attn_impl}'") | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -30,6 +30,7 @@ | |||||||||||||
| ) | ||||||||||||||
| from vexact.batch_invariant_ops import flash_attention_forward as flash_attention_forward_impl | ||||||||||||||
| from vexact.batch_invariant_ops import flash_attention_forward_cute as flash_attention_forward_cute_impl | ||||||||||||||
| from vexact.batch_invariant_ops import flash_attention_forward_variant as flash_attention_forward_variant_impl | ||||||||||||||
|
Import
Suggested change
|
||||||||||||||
| from vexact.batch_invariant_ops.kv_cache_context import ( | ||||||||||||||
| KVCacheStore, | ||||||||||||||
| set_kv_cache_context, | ||||||||||||||
|
|
@@ -47,10 +48,12 @@ | |||||||||||||
| # Module-level logger | ||||||||||||||
| logger = logging.getLogger(__name__) | ||||||||||||||
|
|
||||||||||||||
| # Register two invariant attention implementations | ||||||||||||||
| # Register invariant attention implementations | ||||||||||||||
| ALL_ATTENTION_FUNCTIONS["flex"] = flex_attention_forward | ||||||||||||||
| ALL_ATTENTION_FUNCTIONS["fa-invariant"] = flash_attention_forward_impl | ||||||||||||||
| ALL_ATTENTION_FUNCTIONS["fa-invariant-cute"] = flash_attention_forward_cute_impl | ||||||||||||||
| # Non-invariant variant: same kernel path as fa-invariant but without num_splits=1. | ||||||||||||||
| ALL_ATTENTION_FUNCTIONS["fa-variant"] = flash_attention_forward_variant_impl | ||||||||||||||
|
Register the
Suggested change
|
||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| class Inferencer: | ||||||||||||||
|
|
@@ -88,8 +91,11 @@ def __init__( | |||||||||||||
| # TODO: remove them here, create block table helper to get slot mappings | ||||||||||||||
| self.page_size = self.cache_config.page_size | ||||||||||||||
|
|
||||||||||||||
| if not is_batch_invariant_mode_enabled(): | ||||||||||||||
| self.enable_batch_invariant = enable_batch_invariant | ||||||||||||||
| if enable_batch_invariant and not is_batch_invariant_mode_enabled(): | ||||||||||||||
| enable_batch_invariant_mode() | ||||||||||||||
| if not enable_batch_invariant: | ||||||||||||||
| logger.info("[VEXACT] Inferencer: batch invariant mode DISABLED (variant)") | ||||||||||||||
|
Comment on lines +95 to +98:

The current implementation only enables batch-invariant mode when requested; it does not disable it if it was previously enabled (e.g., in a shared process or test environment). To ensure "variant" mode truly disables the patches, as described in the CLI help, handle both branches explicitly:

    if enable_batch_invariant:
        if not is_batch_invariant_mode_enabled():
            enable_batch_invariant_mode()
    else:
        if is_batch_invariant_mode_enabled():
            disable_batch_invariant_mode()
        logger.info("[VEXACT] Inferencer: batch invariant mode DISABLED (variant)")
||||||||||||||
|
|
||||||||||||||
| self.input_buffers = InputBuffers( | ||||||||||||||
| device=self.device, | ||||||||||||||
|
|
||||||||||||||
Add `fa-variant-cute` to the choices to support variant mode on Blackwell (SM100) hardware, consistent with the `fa-invariant-cute` option.