7 changes: 7 additions & 0 deletions .gitignore
@@ -81,3 +81,10 @@ logs/
inference_outputs*
outputs/
verl_rollout_profile_gsm8k

# local batch-invariance benchmark scripts + results (keep batch_invariance_results.md)
benchmarks/run_invariance_compare.sh
benchmarks/run_vllm_invariance_compare.sh
benchmarks/vllm_throughput.py
benchmarks/invariance_results*/
benchmarks/vllm_invariance_results*/
183 changes: 183 additions & 0 deletions benchmarks/batch_invariance_results.md
@@ -0,0 +1,183 @@
# Batch-Invariant vs Variant Throughput

Throughput cost of batch-invariance, measured on a single H100. Covers two
engines:

- **VeXact** — its `variant` mode (commit `support variant mode`) disables the
deterministic ATen patches and unlocks FA Split-KV.
- **vllm 0.14.1** — its native batch-invariant mode, toggled by the
`VLLM_BATCH_INVARIANT` env var.

In both engines, switching batch-invariance off trades bit-level reproducibility
for throughput. The VeXact and vllm numbers are **not** directly comparable
(different engines and graph settings); each engine's own on/off delta is the
meaningful result.

## Summary (256 requests, 1× H100, Qwen3-1.7B)

| Engine | batch-invariant | Throughput (req/s) | Total tokens/s | Wall time | Wall-time slowdown (ON vs OFF) |
| ------ | --------------- | ------------------ | -------------- | --------- | ------------------------------ |
| VeXact | ON (invariant) | 9.25 | 4015.19 | 27.66s | **+82%** (1.82×) |
| VeXact | OFF (variant) | 16.86 | 7314.25 | 15.19s | — (baseline) |
| vllm | ON (invariant) | 10.83 | 4698.53 | 23.64s | **+127%** (2.27×) |
| vllm | OFF (variant) | 24.59 | 10666.40 | 10.41s | — (baseline) |

Turning batch-invariance ON costs **+82% wall time on VeXact** (1.82×) and
**+127% on vllm** (2.27×). Cross-engine absolute numbers are not comparable
(VeXact runs with CUDA graph; vllm runs with `enforce_eager`). Per-engine
details and metric definitions follow below.

### Summary — 512 requests

| Engine | batch-invariant | Throughput (req/s) | Total tokens/s | Wall time | Wall-time slowdown (ON vs OFF) |
| ------ | --------------- | ------------------ | -------------- | --------- | ------------------------------ |
| VeXact | ON (invariant) | 11.82 | 5030.05 | 43.30s | **+77%** (1.77×) |
| VeXact | OFF (variant) | 20.88 | 8881.09 | 24.52s | — (baseline) |
| vllm | ON (invariant) | 18.95 | 8060.05 | 27.02s | **+111%** (2.11×) |
| vllm | OFF (variant) | 40.03 | 17028.39 | 12.79s | — (baseline) |
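
The slowdown column can be re-derived from the wall times in the two summary tables. A minimal sketch, assuming the slowdown is defined as `wall_on / wall_off`; the helper name `slowdown` is illustrative and not part of the benchmark scripts:

```python
# Wall times in seconds, copied from the summary tables: (ON, OFF) per run.
WALL_TIMES = {
    ("VeXact", 256): (27.66, 15.19),
    ("vllm", 256): (23.64, 10.41),
    ("VeXact", 512): (43.30, 24.52),
    ("vllm", 512): (27.02, 12.79),
}


def slowdown(wall_on, wall_off):
    """Return (ratio, extra_percent) of batch-invariant ON relative to OFF."""
    ratio = wall_on / wall_off
    return round(ratio, 2), round((ratio - 1.0) * 100)


for (engine, num_requests), (wall_on, wall_off) in WALL_TIMES.items():
    ratio, pct = slowdown(wall_on, wall_off)
    print(f"{engine:6s} {num_requests:3d} req: +{pct}% ({ratio:.2f}x)")
```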

## Setup

| Item | Value |
| --------- | --------------------------------------------------------------------- |
| Model | `Qwen3-1.7B` |
| Dataset | `ShareGPT_V3_unfiltered_cleaned_split.json` |
| Requests | 256 and 512 (both reported under Results) |
| GPU | 1× H100-SXM-80GB (`CUDA_VISIBLE_DEVICES=0`) |
| Scheduler | `max_num_batched_tokens=2048`, `max_num_seqs=512`, chunked prefill on |
| Branch | `feat/switch_mode` (rebased on `main`) |

Reproduce with `benchmarks/run_invariance_compare.sh` (knobs: `NUM_REQUESTS`, `GPU`).

# VeXact

## Results

### 256 requests

| Metric | invariant (default) | variant | Δ |
| ------------------ | ------------------- | ------- | -------- |
| Throughput (req/s) | 9.25 | 16.86 | **+82%** |
| Total tokens/s | 4015.19 | 7314.25 | **+82%** |
| Output tokens/s | 2017.25 | 3674.71 | **+82%** |
| Wall time | 27.66s | 15.19s | **−45%** |
| Avg latency | 15.31s | 8.44s | −45% |
| P50 latency | 16.26s | 9.01s | −45% |
| P95 latency | 24.28s | 13.39s | −45% |

### 512 requests

| Metric | invariant (default) | variant | Δ |
| ------------------ | ------------------- | ------- | -------- |
| Throughput (req/s) | 11.82 | 20.88 | **+77%** |
| Total tokens/s | 5030.05 | 8881.09 | **+77%** |
| Output tokens/s | 2597.63 | 4586.38 | **+77%** |
| Wall time | 43.30s | 24.52s | **−43%** |
| Avg latency | 24.72s | 13.83s | −44% |
| P50 latency | 26.31s | 14.72s | −44% |
| P95 latency | 39.21s | 22.21s | −43% |

The speedup holds at both scales (+82% at 256 requests, +77% at 512), so it is
not a small-sample artifact. The larger run raises throughput in both modes
because batches are fuller, and the relative gap stays roughly constant.

> Token counts differ slightly between runs because each run draws a different
> set of ShareGPT samples; throughput and latency are the comparable metrics.

## Metric definitions

All requests are submitted concurrently via `asyncio.gather`; **wall time** is
the clock time wrapping that gather (`benchmarks/throughput.py:147-149`):

```
wall_time = t_end − t_start # around gather() of all concurrent requests
throughput = completed / wall_time # requests/s
total_tps = (prompt_tokens + output_tokens) / wall_time
output_tps = output_tokens / wall_time
latency_i = t_done_i − t_submit_i # per-request, end-to-end
avg/P50/P95 = mean / 50th / 95th percentile over latency_i
```

Because requests run concurrently, `wall_time` is the makespan of the whole
batch — close to `max(latency_i)`, not their sum.
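
The definitions above can be sketched as a small self-contained script. This is an illustration only: `fake_request` stands in for a real engine request, and the nearest-rank `percentile` helper is an assumption, since the percentile method used by `benchmarks/throughput.py` is not shown here.

```python
import asyncio
import math
import time


async def fake_request(delay_s):
    """Stand-in for one engine request; returns its end-to-end latency."""
    t_submit = time.perf_counter()
    await asyncio.sleep(delay_s)  # placeholder for real generation work
    return time.perf_counter() - t_submit


def percentile(sorted_values, q):
    """Nearest-rank percentile (assumed method; the benchmark may differ)."""
    rank = max(1, math.ceil(q / 100 * len(sorted_values)))
    return sorted_values[rank - 1]


async def run_benchmark(delays):
    t_start = time.perf_counter()
    # All requests are submitted at once; gather returns when the last one ends.
    latencies = await asyncio.gather(*(fake_request(d) for d in delays))
    wall_time = time.perf_counter() - t_start
    latencies = sorted(latencies)
    return {
        "completed": len(latencies),
        "wall_time": wall_time,
        "throughput": len(latencies) / wall_time,
        "avg": sum(latencies) / len(latencies),
        "p50": percentile(latencies, 50),
        "p95": percentile(latencies, 95),
        "max": latencies[-1],
    }


stats = asyncio.run(run_benchmark([0.01 * i for i in range(1, 9)]))
print(stats)
```

Because the requests overlap, `stats["wall_time"]` lands just above `stats["max"]` and well below the sum of the individual latencies.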

## What changes between modes

| Layer | invariant | variant |
| ----------------------------------- | ----------------------------------------------------------------------- | -------------------------------------------------------------------------- |
| ATen ops (`enable_batch_invariant`) | deterministic Triton kernels patched in (matmul, rms_norm, …) | patches off — native CUDA kernels |
| Attention (`attn_impl`) | `fa-invariant` — FA forced to `num_splits=1`, fixed LSE reduction order | `fa-variant` — `num_splits` unlocked, kernel picks Split-KV by batch shape |

Both effects are confirmed in the run logs (`batch invariant mode DISABLED (variant)`,
`attn_impl: fa-variant`).
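
Why the split count matters: floating-point addition is not associative, so combining partial results in a different grouping can change the final value. A toy sketch in plain Python, assuming a simple chunked sum as a stand-in for a Split-KV style reduction (`reduce_in_splits` is a hypothetical helper, not the FlashAttention kernel):

```python
import math


def reduce_in_splits(values, num_splits):
    """Sum `values` by first reducing each chunk, then combining the partials."""
    chunk = math.ceil(len(values) / num_splits)
    partials = []
    for start in range(0, len(values), chunk):
        acc = 0.0
        for v in values[start:start + chunk]:
            acc += v
        partials.append(acc)
    total = 0.0
    for p in partials:
        total += p
    return total


# Values chosen so that rounding depends on the grouping of the additions.
VALUES = [1e16, 1.0, -1e16, 1.0]

single = reduce_in_splits(VALUES, num_splits=1)  # one group, fixed order
split = reduce_in_splits(VALUES, num_splits=2)   # two partial sums, then combine
print(single, split, single == split)
```

With one split the running sum cancels `1e16` before adding the final `1.0`; with two splits each `1.0` is absorbed into a large partial first. The inputs are identical, but the results differ, which is the effect that pinning `num_splits=1` avoids.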

## Takeaway

Disabling batch-invariance is **~1.8× faster** end-to-end and nearly halves
wall time (about 45% lower wall time and latency). Put the other way, invariant
mode pays roughly **80% extra wall time and latency** to buy bit-level
reproducibility / batch invariance, which is useful when training and
inference must produce identical logits regardless of how requests are batched.

## How to switch

```bash
# invariant (default)
python benchmarks/throughput.py --model-path <model> --dataset-path <data> --mode invariant

# variant
python benchmarks/throughput.py --model-path <model> --dataset-path <data> --mode variant
```

In code (`ModelConfig`): `enable_batch_invariant=True/False` paired with
`attn_impl="fa-invariant"` / `"fa-variant"`.
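
The pairing of the two settings can be sketched as a small resolver that mirrors the `--mode` handling added to `benchmarks/throughput.py` in this change. The function name `resolve_mode` is hypothetical; the accepted values are the ones listed in this PR.

```python
VALID_ATTN_IMPLS = ("fa-invariant", "fa-invariant-cute", "flex", "fa-variant")


def resolve_mode(mode, attn_impl=None):
    """Map a benchmark mode to the two ModelConfig settings it controls."""
    if mode not in ("invariant", "variant"):
        raise ValueError(f"mode must be 'invariant' or 'variant', got '{mode}'")
    if attn_impl is None:
        # An unset attention implementation follows the mode.
        attn_impl = "fa-invariant" if mode == "invariant" else "fa-variant"
    if attn_impl not in VALID_ATTN_IMPLS:
        raise ValueError(f"attn_impl must be one of {VALID_ATTN_IMPLS}, got '{attn_impl}'")
    return {
        "enable_batch_invariant": mode == "invariant",
        "attn_impl": attn_impl,
    }


print(resolve_mode("invariant"))
print(resolve_mode("variant"))
```

An explicit `attn_impl` still overrides the default, so `resolve_mode("variant", "flex")` keeps the ATen patches off while using the flex attention path.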

# vllm (native batch-invariant)

vllm ships its own batch-invariant mode (`vllm.model_executor.layers.batch_invariant`),
independent of VeXact's. It is toggled by the `VLLM_BATCH_INVARIANT` env var and
**requires** an explicit attention backend (`FLASH_ATTN` / `FLASHINFER` /
`*_MLA`) — enabling it with the default `None` backend raises at engine init.

## Setup

| Item | Value |
| --------------- | ----------------------------------------------------------------------------------------- |
| vllm | 0.14.1 |
| Model / Dataset | `Qwen3-1.7B` / same ShareGPT file as above |
| Requests | 256 (identical samples — same loader + fixed seed) |
| Output length | fixed via `ignore_eos`, so on/off do identical token work |
| Engine | `enforce_eager=True` for **both** runs (cudagraph off, isolates the batch-invariant cost) |
| Backend | `VLLM_ATTENTION_BACKEND=FLASH_ATTN` for both |
| GPU | 1× H100-SXM-80GB |

Reproduce with `benchmarks/run_vllm_invariance_compare.sh`
(benchmark: `benchmarks/vllm_throughput.py`).

## Results (256 requests)

| Metric | invariant (`=1`) | variant (`=0`) | Δ |
| -------------------- | ---------------- | -------------- | --------- |
| Throughput (req/s) | 10.83 | 24.59 | **+127%** |
| Total tokens/s | 4698.53 | 10666.40 | **+127%** |
| Output tokens/s | 2360.56 | 5358.84 | **+127%** |
| Wall time (makespan) | 23.64s | 10.41s | **−56%** |

Prompt/output token counts were identical across runs (55265 / 55799).
Per-request latency percentiles are omitted: offline `LLM.generate` in vllm
0.14.1 does not populate `RequestOutput.metrics`; measuring them needs the
`AsyncLLMEngine` submit-and-await path.

## Takeaway

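vllm's batch-invariant mode is **~2.3× slower** (variant +127%) under eager
mode. That is a steeper penalty than VeXact's ~1.8×, plausibly because vllm
replaces more ATen ops (addmm/bmm/mm/rms_norm/softmax/log_softmax/mean) with
deterministic kernels. The two engines also differ in graph settings (CUDA
graph vs `enforce_eager`), so the cross-engine comparison is indicative only.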

## How to switch

```bash
# variant (default) — batch-invariant off
VLLM_BATCH_INVARIANT=0 VLLM_ATTENTION_BACKEND=FLASH_ATTN python ...

# invariant — batch-invariant on (attention backend is mandatory)
VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python ...
```
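
When the two runs are launched from Python rather than a shell, the same toggles can be passed through the child process environment. A minimal sketch, assuming only the two environment variables named above; `vllm_invariance_env` is a hypothetical helper and the benchmark command itself is left to the caller:

```python
import os


def vllm_invariance_env(invariant, backend="FLASH_ATTN", base_env=None):
    """Return a copy of the environment with the batch-invariance toggles set."""
    env = dict(os.environ if base_env is None else base_env)
    env["VLLM_BATCH_INVARIANT"] = "1" if invariant else "0"
    # The attention backend is set in both cases so the two runs stay comparable.
    env["VLLM_ATTENTION_BACKEND"] = backend
    return env


on_env = vllm_invariance_env(True, base_env={})
off_env = vllm_invariance_env(False, base_env={})
print(on_env)
print(off_env)
```

The returned dictionary can be handed to `subprocess.run(..., env=env)` for each run.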
23 changes: 19 additions & 4 deletions benchmarks/throughput.py
@@ -47,12 +47,13 @@ def build_vexact_engine(
profiler_delay_iterations: int = 0,
profiler_max_iterations: int = 0,
attn_impl: str = "fa-invariant",
enable_batch_invariant: bool = True,
):
config = VeXactConfig(
model=ModelConfig(
model_path=model_path,
attn_impl=attn_impl,
-enable_batch_invariant=True,
+enable_batch_invariant=enable_batch_invariant,
enable_memory_saver=False,
enforce_eager=False,
use_fp32_logits=True,
@@ -223,11 +224,21 @@ def _parse_args():
default="INFO",
help="Python logging level (e.g., DEBUG, INFO, WARNING).",
)
parser.add_argument(
"--mode",
choices=["invariant", "variant"],
default="invariant",
help=(
"Batch-invariance mode. 'invariant' (default) keeps the batch-invariant "
"ATen patches on and defaults attn to fa-invariant. 'variant' disables "
"the patches and defaults attn to fa-variant (num_splits unlocked)."
),
)
parser.add_argument(
"--attn-impl",
-choices=["fa-invariant", "fa-invariant-cute", "flex"],
-default="fa-invariant",
-help="Attention implementation (default: fa-invariant, i.e. flash attn).",
+choices=["fa-invariant", "fa-invariant-cute", "flex", "fa-variant"],

medium

Add fa-variant-cute to the choices to support variant mode on Blackwell (SM100) hardware, consistent with the fa-invariant-cute option.

Suggested change
-choices=["fa-invariant", "fa-invariant-cute", "flex", "fa-variant"],
+choices=["fa-invariant", "fa-invariant-cute", "flex", "fa-variant", "fa-variant-cute"],

+default=None,
+help=("Attention implementation. If unset, follows --mode: invariant -> fa-invariant, variant -> fa-variant."),
)
parser.add_argument(
"--profile-backend",
@@ -261,6 +272,8 @@ def main():
level=getattr(logging, args.log_level.upper(), logging.INFO),
force=True,
)
if args.attn_impl is None:
args.attn_impl = "fa-invariant" if args.mode == "invariant" else "fa-variant"
engine = build_vexact_engine(
model_path=args.model_path,
pp_size=args.pp_size,
@@ -272,6 +285,7 @@ def main():
profiler_delay_iterations=args.profile_delay_iterations,
profiler_max_iterations=args.profile_max_iterations,
attn_impl=args.attn_impl,
enable_batch_invariant=(args.mode == "invariant"),
)
try:
total_prompt_tokens, total_output_tokens, latencies, errors, completed, total_time = run_throughput(
@@ -303,6 +317,7 @@ def main():
print(f"\n{'=' * 60}")
print("THROUGHPUT SUMMARY")
print(f"{'=' * 60}")
print(f"Mode: {args.mode} | attn_impl: {args.attn_impl}")
print(f"Throughput: {rps:.2f} requests/s, {total_tps:.2f} total tokens/s, {output_tps:.2f} output tokens/s")
print(f"Total num prompt tokens: {total_prompt_tokens}")
print(f"Total num output tokens: {total_output_tokens}")
7 changes: 6 additions & 1 deletion vexact/batch_invariant_ops/__init__.py
@@ -25,7 +25,11 @@
set_batch_invariant_mode,
triton_bmm,
)
-from .flash_attention import flash_attention_forward, flash_attention_forward_cute
+from .flash_attention import (
+    flash_attention_forward,
+    flash_attention_forward_cute,
+    flash_attention_forward_variant,

medium

Export flash_attention_forward_variant_cute to support variant mode on Blackwell hardware.

    flash_attention_forward_variant,
    flash_attention_forward_variant_cute,

+)
from .flex_attention import flex_attention_forward


@@ -43,6 +47,7 @@
"AttentionBlockSize",
"flash_attention_forward",
"flash_attention_forward_cute",
"flash_attention_forward_variant",

medium

Add flash_attention_forward_variant_cute to __all__.

Suggested change
-"flash_attention_forward_variant",
+"flash_attention_forward_variant",
+"flash_attention_forward_variant_cute",

"flex_attention_forward",
"triton_bmm",
"batch_invariant_rms_norm",
14 changes: 12 additions & 2 deletions vexact/batch_invariant_ops/flash_attention.py
@@ -33,6 +33,7 @@ def flash_attention_forward(
dropout: float = 0.0,
sliding_window: Optional[int] = None,
use_cute: bool = False,
lock_num_splits: bool = True,
**kwargs,
):
"""
@@ -139,6 +140,11 @@ def flash_attention_forward(
# max_seqlen_k = context_lens.max().item()
from vexact.utils.device import DEVICE_MAJOR

# num_splits=1 keeps the FA Split-KV reduction in a single group, which is
# necessary for batch-invariance (otherwise the kernel picks split counts
# based on batch shape, changing LSE reduction order).
# When lock_num_splits=False, let the kernel pick splits dynamically.
split_kwargs = {"num_splits": 1} if lock_num_splits else {}
if use_cute:
assert DEVICE_MAJOR >= 9, f"FA4 (flash_attn.cute) requires SM90+, got SM{DEVICE_MAJOR}0"
from flash_attn.cute import flash_attn_varlen_func
@@ -155,7 +161,7 @@ def flash_attention_forward(
page_table=block_tables, # (B, max_num_blocks_per_seq) int32
softmax_scale=scaling,
causal=True,
-num_splits=1,
+**split_kwargs,
)
else:
assert DEVICE_MAJOR == 9, f"FA3 requires SM90, got SM{DEVICE_MAJOR}0"
@@ -174,11 +180,15 @@ def flash_attention_forward(
causal=True,
page_table=block_tables,
# window_size=window_size
-num_splits=1,
+**split_kwargs,
)
# Output shape: (H, total_query_tokens, D)
return attn_output, None


# FA4 (flash_attn.cute) variant — forces use_cute=True on both Hopper and Blackwell
flash_attention_forward_cute = partial(flash_attention_forward, use_cute=True)

# Non-invariant variant: same kernel but without the num_splits=1 lock,
# so the kernel is free to pick Split-KV counts for best throughput.
flash_attention_forward_variant = partial(flash_attention_forward, lock_num_splits=False)

high

The fa-variant implementation currently defaults to use_cute=False, which will cause a crash on Blackwell (SM100) hardware due to the assertion at line 167 (which requires SM90 for the non-cute path). Adding a fa-variant-cute implementation allows variant mode to be used on newer hardware, following the existing pattern for invariant mode.

Suggested change
-flash_attention_forward_variant = partial(flash_attention_forward, lock_num_splits=False)
+flash_attention_forward_variant = partial(flash_attention_forward, lock_num_splits=False)
+flash_attention_forward_variant_cute = partial(flash_attention_forward, lock_num_splits=False, use_cute=True)

2 changes: 1 addition & 1 deletion vexact/config.py
@@ -135,7 +135,7 @@ def __post_init__(self):
raise ValueError("model_path is required!\nUsage: python script.py model.model_path=/path/to/model")

# Validate attention implementation
-valid_attn_impls = ["fa-invariant", "fa-invariant-cute", "flex"]
+valid_attn_impls = ["fa-invariant", "fa-invariant-cute", "flex", "fa-variant"]

medium

Include fa-variant-cute in the list of valid attention implementations.

Suggested change
-valid_attn_impls = ["fa-invariant", "fa-invariant-cute", "flex", "fa-variant"]
+valid_attn_impls = ["fa-invariant", "fa-invariant-cute", "flex", "fa-variant", "fa-variant-cute"]

if self.attn_impl not in valid_attn_impls:
raise ValueError(f"attn_impl must be one of {valid_attn_impls}, got '{self.attn_impl}'")

10 changes: 8 additions & 2 deletions vexact/inferencer/inferencer.py
@@ -30,6 +30,7 @@
)
from vexact.batch_invariant_ops import flash_attention_forward as flash_attention_forward_impl
from vexact.batch_invariant_ops import flash_attention_forward_cute as flash_attention_forward_cute_impl
from vexact.batch_invariant_ops import flash_attention_forward_variant as flash_attention_forward_variant_impl

medium

Import disable_batch_invariant_mode to allow explicitly disabling the global ATen patches when variant mode is requested. Also import the new cute variant for Blackwell support.

Suggested change
-from vexact.batch_invariant_ops import flash_attention_forward_variant as flash_attention_forward_variant_impl
+from vexact.batch_invariant_ops import (
+    disable_batch_invariant_mode,
+    flash_attention_forward_variant as flash_attention_forward_variant_impl,
+    flash_attention_forward_variant_cute as flash_attention_forward_variant_cute_impl,
+)

from vexact.batch_invariant_ops.kv_cache_context import (
KVCacheStore,
set_kv_cache_context,
@@ -47,10 +48,12 @@
# Module-level logger
logger = logging.getLogger(__name__)

-# Register two invariant attention implementations
+# Register invariant attention implementations
ALL_ATTENTION_FUNCTIONS["flex"] = flex_attention_forward
ALL_ATTENTION_FUNCTIONS["fa-invariant"] = flash_attention_forward_impl
ALL_ATTENTION_FUNCTIONS["fa-invariant-cute"] = flash_attention_forward_cute_impl
# Non-invariant variant: same kernel path as fa-invariant but without num_splits=1.
ALL_ATTENTION_FUNCTIONS["fa-variant"] = flash_attention_forward_variant_impl

medium

Register the fa-variant-cute implementation.

Suggested change
-ALL_ATTENTION_FUNCTIONS["fa-variant"] = flash_attention_forward_variant_impl
+ALL_ATTENTION_FUNCTIONS["fa-variant"] = flash_attention_forward_variant_impl
+ALL_ATTENTION_FUNCTIONS["fa-variant-cute"] = flash_attention_forward_variant_cute_impl



class Inferencer:
@@ -88,8 +91,11 @@ def __init__(
# TODO: remove them here, create block table helper to get slot mappings
self.page_size = self.cache_config.page_size

-if not is_batch_invariant_mode_enabled():
+self.enable_batch_invariant = enable_batch_invariant
+if enable_batch_invariant and not is_batch_invariant_mode_enabled():
     enable_batch_invariant_mode()
+if not enable_batch_invariant:
+    logger.info("[VEXACT] Inferencer: batch invariant mode DISABLED (variant)")
Comment on lines +95 to +98

medium

The current implementation only enables batch-invariant mode if requested, but it doesn't disable it if it was previously enabled (e.g., in a shared process or test environment). To ensure "variant" mode truly disables the patches as described in the CLI help, disable_batch_invariant_mode() should be called when enable_batch_invariant is False.

        if enable_batch_invariant:
            if not is_batch_invariant_mode_enabled():
                enable_batch_invariant_mode()
        else:
            if is_batch_invariant_mode_enabled():
                disable_batch_invariant_mode()
            logger.info("[VEXACT] Inferencer: batch invariant mode DISABLED (variant)")


self.input_buffers = InputBuffers(
device=self.device,