From ee8a29511fc69e3f0f6291fa6ff1cf6e47f7750d Mon Sep 17 00:00:00 2001 From: vllmellm Date: Sat, 7 Mar 2026 17:26:59 +0800 Subject: [PATCH 01/19] [Bugfix] Fix compressed-tensors quantization failure for DeepSeek-R1 on MI300x (#36247) Signed-off-by: vllmellm --- vllm/model_executor/models/deepseek_v2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 5dd883f222e5..8277e99fdc37 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -756,7 +756,7 @@ def _min_latency_fused_qkv_a_proj_fake( ) -class DeepSeekV2FusedQkvAProj(MergedColumnParallelLinear): +class DeepSeekV2FusedQkvAProjLinear(MergedColumnParallelLinear): def __init__( self, input_size: int, @@ -848,7 +848,7 @@ def __init__( self.max_position_embeddings = max_position_embeddings if self.q_lora_rank is not None: - self.fused_qkv_a_proj = DeepSeekV2FusedQkvAProj( + self.fused_qkv_a_proj = DeepSeekV2FusedQkvAProjLinear( self.hidden_size, [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], quant_config=quant_config, From 00b814ba5a4139910c0824619a8dc6af547e178a Mon Sep 17 00:00:00 2001 From: lif <1835304752@qq.com> Date: Sat, 7 Mar 2026 22:09:55 +0800 Subject: [PATCH 02/19] [V0 Deprecation] Remove unused swap_space parameter (#36216) Signed-off-by: majiayu000 <1835304752@qq.com> Co-authored-by: mcelrath --- .buildkite/performance-benchmarks/README.md | 1 - .../tests/serving-tests-hpu.json | 4 --- .../tests/serving-tests.json | 4 --- benchmarks/attention_benchmarks/mla_runner.py | 1 - benchmarks/attention_benchmarks/runner.py | 1 - docs/design/metrics.md | 8 ++--- docs/serving/integrations/llamaindex.md | 2 +- tests/conftest.py | 2 -- tests/distributed/test_torchrun_example.py | 3 +- .../distributed/test_torchrun_example_moe.py | 3 +- tests/lora/test_worker.py | 1 - tests/v1/attention/utils.py | 1 - tests/v1/core/test_scheduler.py | 2 -- tests/v1/core/utils.py | 1 - tests/v1/engine/test_engine_core.py | 1 - .../unit/test_moriio_connector.py | 1 - tests/v1/kv_connector/unit/utils.py | 1 - tests/v1/worker/test_gpu_model_runner.py | 3 -- vllm/config/cache.py | 34 +------------------ vllm/config/vllm.py | 2 -- vllm/engine/arg_utils.py | 3 -- vllm/entrypoints/llm.py | 19 ++++++----- 22 files changed, 19 insertions(+), 79 deletions(-) diff --git a/.buildkite/performance-benchmarks/README.md b/.buildkite/performance-benchmarks/README.md index 289877e504bb..3a321c0fefdf 100644 --- a/.buildkite/performance-benchmarks/README.md +++ b/.buildkite/performance-benchmarks/README.md @@ -83,7 +83,6 @@ We test the throughput by using `vllm bench serve` with request rate = inf to co "server_parameters": { "model": "meta-llama/Meta-Llama-3-8B", "tensor_parallel_size": 1, - "swap_space": 16, "disable_log_stats": "", "load_format": "dummy" }, diff --git a/.buildkite/performance-benchmarks/tests/serving-tests-hpu.json b/.buildkite/performance-benchmarks/tests/serving-tests-hpu.json index a2e42aa16fd3..3929aa5fbbe0 100644 --- a/.buildkite/performance-benchmarks/tests/serving-tests-hpu.json +++ b/.buildkite/performance-benchmarks/tests/serving-tests-hpu.json @@ -10,7 +10,6 @@ "server_parameters": { "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", "tensor_parallel_size": 1, - "swap_space": 16, "disable_log_stats": "", "load_format": "dummy", "max-model-len": 2048, @@ -37,7 +36,6 @@ "server_parameters": { "model": "meta-llama/Meta-Llama-3.1-70B-Instruct", "tensor_parallel_size": 4, - "swap_space": 16, "disable_log_stats": "", "load_format": "dummy", "max-model-len": 2048, @@ -64,7 +62,6 @@ "server_parameters": { "model": "mistralai/Mixtral-8x7B-Instruct-v0.1", "tensor_parallel_size": 2, - "swap_space": 16, "disable_log_stats": "", "load_format": "dummy", "max-model-len": 2048, @@ -91,7 +88,6 @@ "server_parameters": { "model": "deepseek-ai/DeepSeek-R1", "tensor_parallel_size": 8, - "swap_space": 16, "disable_log_stats": "", "load_format": "dummy", "max-model-len": 2048, diff --git a/.buildkite/performance-benchmarks/tests/serving-tests.json b/.buildkite/performance-benchmarks/tests/serving-tests.json index a6d4141d5c2d..66d52abc1206 100644 --- a/.buildkite/performance-benchmarks/tests/serving-tests.json +++ b/.buildkite/performance-benchmarks/tests/serving-tests.json @@ -5,7 +5,6 @@ "server_parameters": { "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", "tensor_parallel_size": 1, - "swap_space": 16, "disable_log_stats": "", "load_format": "dummy" }, @@ -23,7 +22,6 @@ "server_parameters": { "model": "meta-llama/Meta-Llama-3.1-70B-Instruct", "tensor_parallel_size": 4, - "swap_space": 16, "disable_log_stats": "", "load_format": "dummy" }, @@ -41,7 +39,6 @@ "server_parameters": { "model": "mistralai/Mixtral-8x7B-Instruct-v0.1", "tensor_parallel_size": 2, - "swap_space": 16, "disable_log_stats": "", "load_format": "dummy" }, @@ -59,7 +56,6 @@ "server_parameters": { "model": "meta-llama/Meta-Llama-3.1-70B-Instruct", "tensor_parallel_size": 4, - "swap_space": 16, "speculative_config": { "model": "turboderp/Qwama-0.5B-Instruct", "num_speculative_tokens": 4, diff --git a/benchmarks/attention_benchmarks/mla_runner.py b/benchmarks/attention_benchmarks/mla_runner.py index 867f55fa9ef7..110f580fb7bd 100644 --- a/benchmarks/attention_benchmarks/mla_runner.py +++ b/benchmarks/attention_benchmarks/mla_runner.py @@ -145,7 +145,6 @@ def create_minimal_vllm_config( cache_config = CacheConfig( block_size=block_size, gpu_memory_utilization=0.9, - swap_space=0, cache_dtype="auto", enable_prefix_caching=False, ) diff --git a/benchmarks/attention_benchmarks/runner.py b/benchmarks/attention_benchmarks/runner.py index 9744b857d96b..7f968cfec148 100644 --- a/benchmarks/attention_benchmarks/runner.py +++ b/benchmarks/attention_benchmarks/runner.py @@ -141,7 +141,6 @@ def _create_vllm_config( cache_config = CacheConfig( block_size=config.block_size, cache_dtype="auto", - swap_space=0, ) cache_config.num_gpu_blocks = max_num_blocks cache_config.num_cpu_blocks = 0 diff --git a/docs/design/metrics.md b/docs/design/metrics.md index a977ce9b9bb2..b24ff64b6783 100644 --- a/docs/design/metrics.md +++ b/docs/design/metrics.md @@ -507,10 +507,10 @@ longer relevant in v1: - `vllm:num_requests_swapped` - `vllm:cpu_cache_usage_perc` -In this mode, when a request is preempted (e.g. to make room in KV -cache to complete other requests), we swap kv cache blocks out to CPU -memory. This is also known as "KV cache offloading" and is configured -with `--swap-space` and `--preemption-mode`. +In this mode, when a request was preempted (e.g. to make room in KV +cache to complete other requests), kv cache blocks were swapped out to +CPU memory. The `--swap-space` flag has been removed as this feature +is no longer used in V1. Historically, [vLLM has long supported beam search](https://github.com/vllm-project/vllm/issues/6226). The SequenceGroup encapsulated the idea of N Sequences which diff --git a/docs/serving/integrations/llamaindex.md b/docs/serving/integrations/llamaindex.md index 4b838cbcaa9d..3d669f169e01 100644 --- a/docs/serving/integrations/llamaindex.md +++ b/docs/serving/integrations/llamaindex.md @@ -17,7 +17,7 @@ llm = Vllm( model="microsoft/Orca-2-7b", tensor_parallel_size=4, max_new_tokens=100, - vllm_kwargs={"swap_space": 1, "gpu_memory_utilization": 0.5}, + vllm_kwargs={"gpu_memory_utilization": 0.5}, ) ``` diff --git a/tests/conftest.py b/tests/conftest.py index 1e9d46d3c169..4b907b7dd760 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -794,7 +794,6 @@ def __init__( tensor_parallel_size: int = 1, block_size: int = 16 if not torch.xpu.is_available() else 64, enable_chunked_prefill: bool | None = False, - swap_space: int = 4, enforce_eager: bool | None = False, # Set this to avoid hanging issue default_torch_num_threads: int | None = None, @@ -831,7 +830,6 @@ def __init__( trust_remote_code=trust_remote_code, dtype=dtype, seed=seed, - swap_space=swap_space, enforce_eager=enforce_eager, disable_log_stats=disable_log_stats, tensor_parallel_size=tensor_parallel_size, diff --git a/tests/distributed/test_torchrun_example.py b/tests/distributed/test_torchrun_example.py index f415409d7b37..8c9898ca20f3 100644 --- a/tests/distributed/test_torchrun_example.py +++ b/tests/distributed/test_torchrun_example.py @@ -22,7 +22,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) -# set different `gpu_memory_utilization` and `swap_space` for different ranks, +# set different `gpu_memory_utilization` for different ranks, # to test if all ranks agree on the same kv cache configuration. llm = LLM( model="facebook/opt-125m", @@ -30,7 +30,6 @@ pipeline_parallel_size=int(os.getenv("PP_SIZE", 1)), distributed_executor_backend="external_launcher", gpu_memory_utilization=random.uniform(0.7, 0.9), - swap_space=random.randint(1, 4), seed=0, ) diff --git a/tests/distributed/test_torchrun_example_moe.py b/tests/distributed/test_torchrun_example_moe.py index 1aa7f1793570..a6298d1b6739 100644 --- a/tests/distributed/test_torchrun_example_moe.py +++ b/tests/distributed/test_torchrun_example_moe.py @@ -28,7 +28,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) -# set different `gpu_memory_utilization` and `swap_space` for different ranks, +# set different `gpu_memory_utilization` for different ranks, # to test if all ranks agree on the same kv cache configuration. llm = LLM( model="microsoft/Phi-mini-MoE-instruct", @@ -37,7 +37,6 @@ enable_expert_parallel=int(os.getenv("ENABLE_EP", "0")) == 1, distributed_executor_backend="external_launcher", gpu_memory_utilization=random.uniform(0.7, 0.9), - swap_space=random.randint(1, 4), seed=0, ) diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index 274142e8d66e..4af3ccf893ff 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -64,7 +64,6 @@ def set_active_loras(worker: Worker, lora_requests: list[LoRARequest]): device_config=DeviceConfig("cuda"), cache_config=CacheConfig( block_size=16, - swap_space=0, cache_dtype="auto", ), lora_config=LoRAConfig( diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index 3cff52929146..91decf6658a5 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -182,7 +182,6 @@ def create_vllm_config( cache_config = CacheConfig( block_size=block_size, cache_dtype="auto", - swap_space=0, ) # Set cache blocks for testing # (these may be set during initialization normally) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 24edfadb9b53..bbeca6ef7dba 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1776,7 +1776,6 @@ def create_scheduler_with_priority( cache_config = CacheConfig( block_size=block_size, gpu_memory_utilization=0.9, - swap_space=0, cache_dtype="auto", enable_prefix_caching=enable_prefix_caching, ) @@ -3726,7 +3725,6 @@ def _create_encoder_decoder_scheduler( cache_config = CacheConfig( block_size=block_size, gpu_memory_utilization=0.9, - swap_space=0, cache_dtype="auto", enable_prefix_caching=False, ) diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py index 90c174adf8c8..92122bcb0ba4 100644 --- a/tests/v1/core/utils.py +++ b/tests/v1/core/utils.py @@ -94,7 +94,6 @@ def create_scheduler( cache_config = CacheConfig( block_size=block_size, gpu_memory_utilization=0.9, - swap_space=0, cache_dtype="auto", enable_prefix_caching=enable_prefix_caching, ) diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index 8d7377c286ac..ae674919ae91 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -506,7 +506,6 @@ def test_encoder_instance_zero_kv_cache( cache_config = CacheConfig( block_size=16, gpu_memory_utilization=gpu_memory_utilization, - swap_space=0, cache_dtype="auto", enable_prefix_caching=enable_prefix_caching, ) diff --git a/tests/v1/kv_connector/unit/test_moriio_connector.py b/tests/v1/kv_connector/unit/test_moriio_connector.py index 7aa824609b7e..2ee224013131 100644 --- a/tests/v1/kv_connector/unit/test_moriio_connector.py +++ b/tests/v1/kv_connector/unit/test_moriio_connector.py @@ -206,7 +206,6 @@ def create_vllm_config( cache_config = CacheConfig( block_size=block_size, gpu_memory_utilization=0.9, - swap_space=0, cache_dtype="auto", enable_prefix_caching=True, ) diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index d267299815a6..f03d7c479eb2 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -118,7 +118,6 @@ def create_vllm_config( cache_config = CacheConfig( block_size=block_size, gpu_memory_utilization=0.9, - swap_space=0, cache_dtype=cache_dtype, enable_prefix_caching=True, ) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index a2c1466ca61a..c8a6c1301444 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -96,7 +96,6 @@ def get_vllm_config(): cache_config = CacheConfig( block_size=BLOCK_SIZE, gpu_memory_utilization=0.9, - swap_space=0, cache_dtype="auto", ) parallel_config = ParallelConfig() @@ -809,7 +808,6 @@ def test_hybrid_attention_mamba_tensor_shapes(): cache_config = CacheConfig( block_size=BLOCK_SIZE, gpu_memory_utilization=0.9, - swap_space=0, cache_dtype="auto", ) parallel_config = ParallelConfig() @@ -1242,7 +1240,6 @@ def test_cudagraph_sizes_capped_for_mamba_cache(): cache_config = CacheConfig( block_size=BLOCK_SIZE, gpu_memory_utilization=0.9, - swap_space=0, cache_dtype="auto", ) parallel_config = ParallelConfig() diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 8a94141c91b6..71603d8c883e 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -1,21 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import math from dataclasses import field -from typing import TYPE_CHECKING, Any, Literal +from typing import Literal from pydantic import Field, SkipValidation, field_validator from vllm.config.utils import config from vllm.logger import init_logger -from vllm.utils.mem_constants import GiB_bytes -from vllm.utils.mem_utils import format_gib, get_cpu_memory - -if TYPE_CHECKING: - from vllm.config.parallel import ParallelConfig -else: - ParallelConfig = Any logger = init_logger(__name__) @@ -53,8 +45,6 @@ class CacheConfig: not matter if you have another vLLM instance running on the same GPU. For example, if you have two vLLM instances running on the same GPU, you can set the GPU memory utilization to 0.5 for each instance.""" - swap_space: float = Field(default=4, ge=0) - """Size of the CPU swap space per GPU (in GiB).""" cache_dtype: CacheDType = "auto" """Data type for kv cache storage. If "auto", will use model data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports @@ -173,7 +163,6 @@ def compute_hash(self) -> str: ignored_factors = { # Runtime/derived knobs that don't affect compiled graph shape "gpu_memory_utilization", - "swap_space", "is_attention_free", "num_gpu_blocks_override", "enable_prefix_caching", @@ -208,24 +197,3 @@ def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType: "scaling factor." ) return cache_dtype - - def verify_with_parallel_config( - self, - parallel_config: ParallelConfig, - ) -> None: - swap_space_bytes = math.ceil(self.swap_space * GiB_bytes) - total_cpu_memory = get_cpu_memory() - # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel - # group are in the same node. However, the GPUs may span multiple nodes. - num_gpus_per_node = parallel_config.tensor_parallel_size - cpu_memory_usage = swap_space_bytes * num_gpus_per_node - - msg = ( - f"{format_gib(cpu_memory_usage)} GiB out of the " - f"{format_gib(total_cpu_memory)} GiB total CPU memory " - "is allocated for the swap space." - ) - if cpu_memory_usage > 0.7 * total_cpu_memory: - raise ValueError("Too large swap space. " + msg) - elif cpu_memory_usage > 0.4 * total_cpu_memory: - logger.warning("Possibly too large swap space. %s", msg) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 34c668362d40..d5b60a566fd3 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -674,8 +674,6 @@ def __post_init__(self): self.parallel_config.is_moe_model = self.model_config.is_moe - self.cache_config.verify_with_parallel_config(self.parallel_config) - if self.lora_config is not None: self.lora_config.verify_with_model_config(self.model_config) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 09ffd5e121cc..dc1735a01788 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -447,7 +447,6 @@ class EngineArgs: ) disable_sliding_window: bool = ModelConfig.disable_sliding_window disable_cascade_attn: bool = ModelConfig.disable_cascade_attn - swap_space: float = CacheConfig.swap_space offload_backend: str = OffloadConfig.offload_backend cpu_offload_gb: float = UVAOffloadConfig.cpu_offload_gb cpu_offload_params: set[str] = get_field(UVAOffloadConfig, "cpu_offload_params") @@ -961,7 +960,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: cache_group.add_argument( "--kv-cache-memory-bytes", **cache_kwargs["kv_cache_memory_bytes"] ) - cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"]) cache_group.add_argument("--kv-cache-dtype", **cache_kwargs["cache_dtype"]) cache_group.add_argument( "--num-gpu-blocks-override", **cache_kwargs["num_gpu_blocks_override"] @@ -1526,7 +1524,6 @@ def create_engine_config( block_size=self.block_size, gpu_memory_utilization=self.gpu_memory_utilization, kv_cache_memory_bytes=self.kv_cache_memory_bytes, - swap_space=self.swap_space, cache_dtype=resolved_cache_dtype, # type: ignore[arg-type] is_attention_free=model_config.is_attention_free, num_gpu_blocks_override=self.num_gpu_blocks_override, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index eb1d4dbeb365..9c6d6ddcdf75 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -164,12 +164,6 @@ class LLM: compared with using gpu_memory_utilization. Note that kv_cache_memory_bytes (when not-None) ignores gpu_memory_utilization - swap_space: The size (GiB) of CPU memory per GPU to use as swap space. - This can be used for temporarily storing the states of the requests - when their `best_of` sampling parameters are larger than 1. If all - requests will have `best_of=1`, you can safely set this to 0. - Noting that `best_of` is only supported in V0. Otherwise, too small - values may cause out-of-memory (OOM) errors. cpu_offload_gb: The size (GiB) of CPU memory to use for offloading the model weights. This virtually increases the GPU memory space you can use to hold the model weights, at the cost of CPU-GPU data @@ -240,7 +234,6 @@ def __init__( chat_template: Path | str | None = None, seed: int = 0, gpu_memory_utilization: float = 0.9, - swap_space: float = 4, cpu_offload_gb: float = 0, offload_group_size: int = 0, offload_num_in_group: int = 1, @@ -265,6 +258,17 @@ def __init__( ) -> None: """LLM constructor.""" + if "swap_space" in kwargs: + kwargs.pop("swap_space") + import warnings + + warnings.warn( + "The 'swap_space' parameter is deprecated and ignored. " + "It will be removed in a future version.", + DeprecationWarning, + stacklevel=2, + ) + if "disable_log_stats" not in kwargs: kwargs["disable_log_stats"] = True @@ -353,7 +357,6 @@ def _make_config(value: Any, cls: type[_R]) -> _R: seed=seed, gpu_memory_utilization=gpu_memory_utilization, kv_cache_memory_bytes=kv_cache_memory_bytes, - swap_space=swap_space, cpu_offload_gb=cpu_offload_gb, offload_group_size=offload_group_size, offload_num_in_group=offload_num_in_group, From 5261223c2d1082fa3facc99c52fc96c0ebcc041b Mon Sep 17 00:00:00 2001 From: Taneem Ibrahim Date: Sat, 7 Mar 2026 08:37:01 -0600 Subject: [PATCH 03/19] [Misc] Remove duplicate parser registration (#36303) Signed-off-by: Taneem Ibrahim --- vllm/parser/__init__.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/vllm/parser/__init__.py b/vllm/parser/__init__.py index 8bce3e912cc5..dc256daaa7e2 100644 --- a/vllm/parser/__init__.py +++ b/vllm/parser/__init__.py @@ -22,13 +22,6 @@ ), } -# Register lazy parsers -ParserManager.register_lazy_module( - name="minimax_m2", - module_path="vllm.parser.minimax_m2_parser", - class_name="MiniMaxM2Parser", -) - def register_lazy_parsers(): for name, (file_name, class_name) in _PARSERS_TO_REGISTER.items(): From 85f50eb41fa43783b64e07d768ba3ac6d4ed7a5a Mon Sep 17 00:00:00 2001 From: rahul-sarvam <140298821+rahul-sarvam@users.noreply.github.com> Date: Sun, 8 Mar 2026 01:16:24 +0800 Subject: [PATCH 04/19] Adding support to Sarvam's MoE models (#33942) Signed-off-by: rahul-sarvam <140298821+rahul-sarvam@users.noreply.github.com> --- docs/models/supported_models.md | 2 + tests/models/registry.py | 12 + vllm/model_executor/models/registry.py | 2 + vllm/model_executor/models/sarvam.py | 786 +++++++++++++++++++++++++ 4 files changed, 802 insertions(+) create mode 100644 vllm/model_executor/models/sarvam.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 967f3cfb6ddb..5ceea6228d9e 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -469,6 +469,8 @@ th { | `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | ✅︎ | ✅︎ | | `Qwen3NextForCausalLM` | Qwen3NextMoE | `Qwen/Qwen3-Next-80B-A3B-Instruct`, etc. | ✅︎ | ✅︎ | | `RWForCausalLM` | Falcon RW | `tiiuae/falcon-40b`, etc. | | ✅︎ | +| `SarvamMoEForCausalLM` | Sarvam 2 | `sarvamai/sarvam2-30b-a3b`, etc. | ✅︎ | ✅︎ | +| `SarvamMLAForCausalLM` | Sarvam 2 | `sarvamai/sarvam2-105b-a9b`, etc. | | ✅︎ | | `SeedOssForCausalLM` | SeedOss | `ByteDance-Seed/Seed-OSS-36B-Instruct`, etc. | ✅︎ | ✅︎ | | `SolarForCausalLM` | Solar Pro | `upstage/solar-pro-preview-instruct`, etc. | ✅︎ | ✅︎ | | `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | diff --git a/tests/models/registry.py b/tests/models/registry.py index 40c4d0d311bc..48e5c251d7a6 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -480,6 +480,18 @@ def check_available_online( min_transformers_version="4.56.3", ), "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"), + "SarvamMoEForCausalLM": _HfExamplesInfo( + "sarvamai/sarvam-30b", + trust_remote_code=True, + max_model_len=4096, + is_available_online=True, + ), + "SarvamMLAForCausalLM": _HfExamplesInfo( + "sarvamai/sarvam-105b", + trust_remote_code=True, + max_model_len=4096, + is_available_online=True, + ), "SeedOssForCausalLM": _HfExamplesInfo( "ByteDance-Seed/Seed-OSS-36B-Instruct", trust_remote_code=True, diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 274b18f35a42..29ca31875324 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -191,6 +191,8 @@ "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"), "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"), "RWForCausalLM": ("falcon", "FalconForCausalLM"), + "SarvamMoEForCausalLM": ("sarvam", "SarvamMoEForCausalLM"), + "SarvamMLAForCausalLM": ("sarvam", "SarvamMLAForCausalLM"), "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"), "Step1ForCausalLM": ("step1", "Step1ForCausalLM"), "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"), diff --git a/vllm/model_executor/models/sarvam.py b/vllm/model_executor/models/sarvam.py new file mode 100644 index 000000000000..fa5ec44d7e72 --- /dev/null +++ b/vllm/model_executor/models/sarvam.py @@ -0,0 +1,786 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# Copyright 2026 Sarvam AI team. All rights reserved. +# +# This code is based on Llama, Deepseek, and Bailing MoE implementations +# in this library. It has been modified from its original forms to +# accommodate Sarvam's MoE architectures. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import math +from collections.abc import Iterable, Iterator +from itertools import islice + +import torch +from torch import nn + +from vllm.config import CacheConfig, ParallelConfig, VllmConfig +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.sequence import IntermediateTensors + +from .bailing_moe import BailingMoeForCausalLM +from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +def _is_gate_expert_bias_name(name: str) -> bool: + return name.endswith(".mlp.gate.e_score_correction_bias") or name.endswith( + ".gate.e_score_correction_bias" + ) + + +def _zero_mean_tensor(t: torch.Tensor) -> torch.Tensor: + if t.numel() == 0: + return t + return t - t.mean() + + +def _normalized_weights( + weights: Iterable[tuple[str, torch.Tensor]], +) -> Iterator[tuple[str, torch.Tensor]]: + for name, w in weights: + if _is_gate_expert_bias_name(name): + yield name, _zero_mean_tensor(w) + else: + yield name, w + + +class SarvamMLAAttention(nn.Module): + def __init__( + self, + vllm_config: VllmConfig, + config, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + self.hidden_size = config.hidden_size + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_rope_head_dim = config.qk_rope_head_dim + self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim + self.v_head_dim = config.v_head_dim + + self.q_lora_rank = getattr(config, "q_lora_rank", None) + self.kv_lora_rank = config.kv_lora_rank + + self.total_num_heads = config.num_attention_heads + tp_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_size == 0 + self.num_local_heads = self.total_num_heads // tp_size + + self.scaling = self.qk_head_dim**-0.5 + self.max_position_embeddings = config.max_position_embeddings + + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear( + self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj", + ) + self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear( + self.q_lora_rank, + self.total_num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj", + ) + self.q_proj = None # type: ignore + else: + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.total_num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + ) + self.q_a_proj = None # type: ignore + self.q_a_layernorm = None # type: ignore + self.q_b_proj = None # type: ignore + + # KV latent (MQA-style) A-proj + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa", + ) + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) + + # KV B-proj produces per-head K_nope and V + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.total_num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj", + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.qk_rope_head_dim, + # rotary_dim=self.qk_rope_head_dim, + max_position=config.max_position_embeddings, + rope_parameters=config.rope_parameters, + is_neox_style=False, + ) + + if config.rope_parameters.get("rope_type", None) == "deepseek_yarn": + mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False) + scaling_factor = config.rope_parameters["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + mla_modules = MLAModules( + kv_a_layernorm=self.kv_a_layernorm, + kv_b_proj=self.kv_b_proj, + rotary_emb=self.rotary_emb, + o_proj=self.o_proj, + fused_qkv_a_proj=None, + kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, + q_a_layernorm=self.q_a_layernorm if self.q_lora_rank is not None else None, + q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None, + q_proj=self.q_proj if self.q_lora_rank is None else None, + indexer=None, + indexer_rotary_emb=None, + is_sparse=False, + topk_indices_buffer=None, + ) + + self.mla_attn = MultiHeadLatentAttentionWrapper( + self.hidden_size, + self.num_local_heads, + self.scaling, + self.qk_nope_head_dim, + self.qk_rope_head_dim, + self.v_head_dim, + self.q_lora_rank, + self.kv_lora_rank, + mla_modules, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + return self.mla_attn(positions, hidden_states, llama_4_scaling=None) + + +class SarvamMLAMLP(nn.Module): + def __init__( + self, + intermediate_size: int, + config, + quant_config: QuantizationConfig | None = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + + self.gate_up_proj = MergedColumnParallelLinear( + config.hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + config.hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class SarvamMLAMoE(nn.Module): + def __init__( + self, + config, + parallel_config: ParallelConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.hidden_size = config.hidden_size + + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 2.5) + + self.n_group = getattr(config, "n_group", None) + self.topk_group = getattr(config, "topk_group", None) + self.use_grouped_topk = self.n_group is not None and self.topk_group is not None + + self.norm_expert_prob = getattr(config, "norm_topk_prob", True) + + router_dtype_cfg = getattr(config, "router_dtype", "fp32") + if router_dtype_cfg is None: + self.router_dtype = None + elif router_dtype_cfg == "fp32": + self.router_dtype = torch.float32 + else: + self.router_dtype = torch.bfloat16 + + self.gate = nn.Linear( + self.hidden_size, + self.num_experts, + bias=False, + dtype=self.router_dtype, + ) + + if getattr(config, "moe_router_enable_expert_bias", True): + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty( + (self.num_experts,), + dtype=torch.float32, + ) + ) + else: + self.gate.e_score_correction_bias = None + + self.score_function = getattr(config, "score_function", "sigmoid") + self.num_shared_experts = getattr(config, "num_shared_experts", 1) + if self.num_shared_experts > 0: + if hasattr(config, "moe_shared_expert_intermediate_size"): + shared_int = config.moe_shared_expert_intermediate_size + else: + shared_int = config.moe_intermediate_size + shared_int *= self.num_shared_experts + self.shared_experts = SarvamMLAMLP( + intermediate_size=shared_int, + config=config, + quant_config=quant_config, + reduce_results=False, + prefix=f"{prefix}.shared_experts", + ) + else: + self.shared_experts = None + + self.experts = SharedFusedMoE( + shared_experts=self.shared_experts, + num_experts=self.num_experts, + top_k=self.top_k, + hidden_size=self.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=self.norm_expert_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + scoring_func=self.score_function, + e_score_correction_bias=self.gate.e_score_correction_bias, + num_expert_group=self.n_group, + topk_group=self.topk_group, + use_grouped_topk=self.use_grouped_topk, + routed_scaling_factor=self.routed_scaling_factor, + ) + + def maybe_get_fused_moe(self) -> SharedFusedMoE: + return self.experts + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + router_logits = self.gate( + hidden_states.to(self.router_dtype) + if self.router_dtype is not None + else hidden_states + ) + router_logits = router_logits.to(hidden_states.dtype) + final_hidden = self.experts( + hidden_states=hidden_states, + router_logits=router_logits, + ) + + if self.shared_experts is not None: + shared_output, expert_output = final_hidden + else: + shared_output, expert_output = None, final_hidden + + if shared_output is not None: + expert_output = expert_output + shared_output + + if self.tp_size > 1: + expert_output = self.experts.maybe_all_reduce_tensor_model_parallel( + expert_output + ) + + return expert_output.view(num_tokens, hidden_dim) + + +class SarvamMLABlock(nn.Module): + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ) -> None: + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + parallel_config = vllm_config.parallel_config + layer_idx = int(prefix.split(".")[-1]) + hidden_size = config.hidden_size + dense_intermediate = getattr(config, "intermediate_size", 16384) + + self.input_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps) + self.self_attn = SarvamMLAAttention( + vllm_config=vllm_config, + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.post_attention_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps) + use_moe = hasattr(config, "num_experts") and config.num_experts is not None + first_k_dense = getattr(config, "first_k_dense_replace", 1) + moe_layer_freq = getattr(config, "moe_layer_freq", 1) + if use_moe: + is_moe_layer = layer_idx >= first_k_dense and ( + (layer_idx - first_k_dense) % moe_layer_freq == 0 + ) + else: + is_moe_layer = False + + if is_moe_layer: + self.mlp = SarvamMLAMoE( + config=config, + parallel_config=parallel_config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + else: + self.mlp = SarvamMLAMLP( + intermediate_size=dense_intermediate, + config=config, + quant_config=quant_config, + reduce_results=True, + prefix=f"{prefix}.mlp", + ) + + def forward( + self, + hidden_states: torch.Tensor, + positions: torch.Tensor, + residual: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class SarvamMLAModel(nn.Module): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + ) -> None: + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.config = config + self.vocab_size = config.vocab_size + self.embed_dim = config.hidden_size + self.tie_word_embeddings = getattr(config, "tie_word_embeddings", False) + if get_pp_group().is_first_rank or ( + self.tie_word_embeddings and get_pp_group().is_last_rank + ): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + self.embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", + ) + else: + self.embed_tokens = PPMissingLayer() + + self.embedding_dropout = torch.nn.Dropout( + getattr(config, "embedding_dropout", 0.0) + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: SarvamMLABlock( + vllm_config=vllm_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_input_ids(input_ids) + hidden_states = self.embedding_dropout(hidden_states) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual = layer( + hidden_states, + positions, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + if residual is None: + hidden_states = self.norm(hidden_states) + else: + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return SharedFusedMoE.make_expert_params_mapping( + self, + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts, + ) + + def load_weights( + self, + weights: Iterable[tuple[str, torch.Tensor]], + ) -> set[str]: + """Load weights with stacked gate+up and MoE expert remapping.""" + weights = _normalized_weights(weights) + stacked_params_mapping = [ + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() + + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + if "mlp.experts" in name: + continue + new_name = name.replace(weight_name, param_name) + if new_name.endswith(".bias") and new_name not in params_dict: + continue + if new_name not in params_dict: + continue + if is_pp_missing_parameter(new_name, self): + continue + + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight, shard_id) + loaded_params.add(new_name) + break + else: + mapped = False + for ( + param_name, + weight_name, + expert_id, + shard_id, + ) in expert_params_mapping: + if weight_name not in name: + continue + + new_name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(new_name, self): + continue + if new_name not in params_dict: + continue + + param = params_dict[new_name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) + loaded_params.add(new_name) + mapped = True + break + + if mapped: + continue + + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + + +class SarvamMixtureOfExperts(MixtureOfExperts): + def extract_moe_parameters(self, example_moe: SarvamMLAMoE | None) -> None: + if example_moe is None: + raise RuntimeError("No SarvamMLAMoE layer found in model.layers.") + + self.num_logical_experts = example_moe.num_experts + self.num_routed_experts = example_moe.num_experts # routed pool size + self.num_shared_experts = getattr(example_moe.config, "num_shared_experts", 1) + + self.num_physical_experts = self.num_logical_experts + self.num_local_physical_experts = self.num_logical_experts + self.num_redundant_experts = 0 + + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = num_physical_experts - self.num_logical_experts + + for moe in self.moe_mlp_layers: + moe.n_physical_experts = num_physical_experts + moe.n_local_physical_experts = num_local_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + + fused = moe.experts + if hasattr(fused, "n_local_physical_experts"): + fused.n_local_physical_experts = num_local_physical_experts + if hasattr(fused, "n_physical_experts"): + fused.n_physical_experts = num_physical_experts + if hasattr(fused, "n_redundant_experts"): + fused.n_redundant_experts = self.num_redundant_experts + if hasattr(fused, "update_expert_map"): + fused.update_expert_map() + + def set_eplb_state(self, eplb_state) -> None: + self.eplb_state = eplb_state + for moe in self.moe_layers: + if hasattr(moe, "set_eplb_state"): + moe.set_eplb_state(eplb_state) + + +class SarvamMLAForCausalLM(nn.Module, SupportsPP, SupportsLoRA, SarvamMixtureOfExperts): + packed_modules_mapping = { + "q_proj": ["q_proj"], + "q_a_proj": ["q_a_proj"], + "q_b_proj": ["q_b_proj"], + "kv_a_proj_with_mqa": ["kv_a_proj_with_mqa"], + "kv_b_proj": ["kv_b_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + + self.model = SarvamMLAModel( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + ) + + self.tie_word_embeddings = getattr(config, "tie_word_embeddings", False) + if get_pp_group().is_last_rank: + if self.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + else: + self.lm_head = PPMissingLayer() + self.logits_processor = None # type: ignore + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + self.expert_weights = [] + self.num_moe_layers = 0 + + self.moe_layers = [] + self.moe_mlp_layers = [] + + example_moe = None + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + if isinstance(layer.mlp, SarvamMLAMoE): + example_moe = layer.mlp + self.moe_mlp_layers.append(layer.mlp) + self.moe_layers.append(layer.mlp.experts) + self.num_moe_layers += 1 + + self.extract_moe_parameters(example_moe) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + return self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + if not get_pp_group().is_last_rank: + return None + logits = self.logits_processor(self.lm_head, hidden_states) + return logits + + def load_weights( + self, + weights: Iterable[tuple[str, torch.Tensor]], + ) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] if self.tie_word_embeddings else None), + ) + return loader.load_weights(weights) + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() + + +class SarvamMoEForCausalLM(BailingMoeForCausalLM): + """Same as BailingMoeForCausalLM, but normalizes gate expert_bias pre-load.""" + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + return super().load_weights(_normalized_weights(weights)) From ebb9cc5f2b26d73222c08e42b32fcf59e831386c Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Sat, 7 Mar 2026 16:49:23 -0500 Subject: [PATCH 05/19] [UX][Startup] Account for CUDA graphs during memory profiling (#30515) --- vllm/compilation/cuda_graph.py | 20 +- vllm/envs.py | 7 + vllm/v1/cudagraph_dispatcher.py | 7 +- vllm/v1/worker/gpu_model_runner.py | 279 ++++++++++++++++++++++----- vllm/v1/worker/gpu_ubatch_wrapper.py | 13 +- vllm/v1/worker/gpu_worker.py | 95 ++++++++- 6 files changed, 360 insertions(+), 61 deletions(-) diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index 41db70155e38..13e88448c0f1 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -2,10 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import dataclasses +import weakref from collections import Counter from collections.abc import Callable from contextlib import ExitStack -from typing import Any +from typing import Any, ClassVar from unittest.mock import patch import torch @@ -162,6 +163,14 @@ class CUDAGraphWrapper: guaranteed when VLLM_LOGGING_LEVEL == "DEBUG". """ + _all_instances: ClassVar[weakref.WeakSet["CUDAGraphWrapper"]] = weakref.WeakSet() + + @classmethod + def clear_all_graphs(cls) -> None: + """Clear captured graphs from all CUDAGraphWrapper instances.""" + for instance in list(cls._all_instances): + instance.clear_graphs() + def __init__( self, runnable: Callable[..., Any], @@ -192,6 +201,8 @@ def __init__( # cudagraphs for. self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry] = {} + CUDAGraphWrapper._all_instances.add(self) + def __getattr__(self, key: str) -> Any: # allow accessing the attributes of the runnable. if hasattr(self.runnable, key): @@ -205,6 +216,13 @@ def unwrap(self) -> Callable[..., Any]: # in case we need to access the original runnable. return self.runnable + @property + def cudagraph_wrapper(self) -> "CUDAGraphWrapper": + return self + + def clear_graphs(self) -> None: + self.concrete_cudagraph_entries.clear() + def __call__(self, *args: Any, **kwargs: Any) -> Any | None: forward_context = get_forward_context() batch_descriptor = forward_context.batch_descriptor diff --git a/vllm/envs.py b/vllm/envs.py index 66ddd7918768..716810da1c27 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -244,6 +244,7 @@ VLLM_CUDA_COMPATIBILITY_PATH: str | None = None VLLM_ELASTIC_EP_SCALE_UP_LAUNCH: bool = False VLLM_ELASTIC_EP_DRAIN_REQUESTS: bool = False + VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS: bool = False def get_default_cache_root(): @@ -1628,6 +1629,12 @@ def _get_or_set_default() -> str: "VLLM_ELASTIC_EP_DRAIN_REQUESTS": lambda: bool( int(os.getenv("VLLM_ELASTIC_EP_DRAIN_REQUESTS", "0")) ), + # If set to 1, enable CUDA graph memory estimation during memory profiling. + # This profiles CUDA graph memory usage to provide more accurate KV cache + # memory allocation. Disabled by default to preserve existing behavior. + "VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS": lambda: bool( + int(os.getenv("VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS", "0")) + ), } diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index b852808ecd43..701c97d6de42 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -334,8 +334,11 @@ def get_capture_descs(self) -> list[tuple[CUDAGraphMode, list[BatchDescriptor]]] for mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]: descs = list(self.cudagraph_keys[mode]) if descs: - # Sort by num_tokens descending (largest first) - descs.sort(key=lambda d: d.num_tokens, reverse=True) + # Sort by (num_tokens, num_active_loras) descending + descs.sort( + key=lambda d: (d.num_tokens, d.num_active_loras), + reverse=True, + ) result.append((mode, descs)) return result diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index abeb10735129..cf08c13db062 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -29,6 +29,7 @@ CUDAGraphMode, VllmConfig, get_layers_from_vllm_config, + set_current_vllm_config, update_config, ) from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer @@ -94,6 +95,7 @@ PlaceholderRange, ) from vllm.multimodal.utils import group_and_batch_mm_kwargs +from vllm.platforms import current_platform from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors @@ -596,6 +598,17 @@ def __init__( self.async_output_copy_stream = torch.cuda.Stream() self.prepare_inputs_event = torch.Event() + # self.cudagraph_batch_sizes sorts in ascending order. + if ( + self.compilation_config.cudagraph_capture_sizes + and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + ): + self.cudagraph_batch_sizes = sorted( + self.compilation_config.cudagraph_capture_sizes + ) + else: + self.cudagraph_batch_sizes = [] + # Cache the device properties. self._init_device_properties() @@ -4727,6 +4740,7 @@ def _dummy_run( remove_lora: bool = True, is_graph_capturing: bool = False, num_active_loras: int = 0, + profile_seq_lens: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Run a dummy forward pass to warm up/profile run or capture the @@ -4751,6 +4765,9 @@ def _dummy_run( remove_lora: If False, dummy LoRAs are not destroyed after the run num_active_loras: Number of distinct active LoRAs to capture for. LoRA is activated when num_active_loras > 0. + profile_seq_lens: If provided, use this value for seq_lens instead + of max_query_len. Used to profile attention workspace that + scales with context length. """ mm_config = self.vllm_config.model_config.multimodal_config if mm_config and mm_config.mm_encoder_only: @@ -4881,11 +4898,13 @@ def _dummy_run( # If force_attention is True, we always capture attention. # Otherwise, it only happens for cudagraph_runtime_mode=FULL. if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: - if create_mixed_batch: + if profile_seq_lens is not None: + seq_lens = profile_seq_lens # type: ignore[assignment] + elif create_mixed_batch: # In the mixed batch mode (used for FI warmup), we use # shorter sequence lengths to run faster. # TODO(luka) better system for describing dummy batches - seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] + seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] # type: ignore[assignment] else: seq_lens = max_query_len # type: ignore[assignment] self.seq_lens.np[:num_reqs] = seq_lens @@ -5298,6 +5317,167 @@ def profile_run(self) -> None: self.encoder_cache.clear() gc.collect() + def _init_minimal_kv_cache_for_profiling(self) -> None: + from vllm.v1.core.kv_cache_utils import ( + get_kv_cache_config_from_groups, + get_kv_cache_groups, + ) + + kv_cache_spec = self.get_kv_cache_spec() + kv_cache_groups = get_kv_cache_groups(self.vllm_config, kv_cache_spec) + min_blocks = self.compilation_config.max_cudagraph_capture_size or 1 + if kv_cache_groups: + page_size = kv_cache_groups[0].kv_cache_spec.page_size_bytes + group_size = max(len(g.layer_names) for g in kv_cache_groups) + available_memory = min_blocks * page_size * group_size + else: + available_memory = 1 # Attention-free model + + minimal_config = get_kv_cache_config_from_groups( + self.vllm_config, kv_cache_groups, available_memory=available_memory + ) + + self.initialize_kv_cache(minimal_config) + self.cache_config.num_gpu_blocks = minimal_config.num_blocks + + logger.debug("Initialized minimal KV cache for CUDA graph profiling") + + @staticmethod + @contextmanager + def _freeze_gc(): + gc.collect() + should_freeze = not envs.VLLM_ENABLE_CUDAGRAPH_GC + if should_freeze: + gc.freeze() + try: + yield + finally: + if should_freeze: + gc.unfreeze() + gc.collect() + + def _cleanup_profiling_kv_cache(self) -> None: + torch.accelerator.synchronize() + if hasattr(self, "kv_caches") and self.kv_caches: + for i in range(len(self.kv_caches)): + self.kv_caches[i] = None # type: ignore + self.kv_caches.clear() + if hasattr(self, "cross_layers_kv_cache"): + self.cross_layers_kv_cache = None + self.cross_layers_attn_backend = None + if hasattr(self, "attn_groups"): + self.attn_groups.clear() + if hasattr(self, "kv_cache_config"): + delattr(self, "kv_cache_config") + self.cache_config.num_gpu_blocks = None + + for layer in self.compilation_config.static_forward_context.values(): + if hasattr(layer, "kv_cache"): + layer.kv_cache = [] + + gc.collect() + torch.accelerator.empty_cache() + + logger.debug("Cleaned up profiling KV cache and CUDA graphs") + + @torch.inference_mode() + def profile_cudagraph_memory(self) -> int: + with set_current_vllm_config(self.vllm_config): + self._init_minimal_kv_cache_for_profiling() + + saved_num_cudagraph_captured = compilation_counter.num_cudagraph_captured + + capture_descs = self.cudagraph_dispatcher.get_capture_descs() + + total_graphs = sum(len(descs) for _, descs in capture_descs) + if total_graphs == 0: + logger.debug("No CUDA graphs will be captured, skipping profiling") + self._cleanup_profiling_kv_cache() + return 0 + + logger.info( + "Profiling CUDA graph memory: %s", + ", ".join( + f"{mode.name}={len(descs)} (largest={descs[0].num_tokens})" + for mode, descs in capture_descs + if descs + ), + ) + + # Use a temporary pool for profiling to avoid fragmentation in the main pool. + profiling_pool = current_platform.graph_pool_handle() + original_pools: dict[int, Any] = {} + for instance in list(CUDAGraphWrapper._all_instances): + original_pools[id(instance)] = instance.graph_pool + instance.graph_pool = profiling_pool + + set_cudagraph_capturing_enabled(True) + with self._freeze_gc(), graph_capture(device=self.device): + shared_memory_estimate = {} + per_graph_estimate = {} + torch.accelerator.synchronize() + torch.accelerator.empty_cache() + + for mode, descs in capture_descs: + profile_descs = descs[:2] + mem_samples: list[int] = [] + + for i, desc in enumerate(profile_descs): + mem_before = torch.cuda.mem_get_info()[0] + self._warmup_and_capture( + desc, + cudagraph_runtime_mode=mode, + profile_seq_lens=( + min( + self.max_model_len, + self.max_num_tokens // desc.num_tokens, + ) + if mode == CUDAGraphMode.FULL and i == 0 + else None + ), + ) + torch.accelerator.synchronize() + free_after = torch.cuda.mem_get_info()[0] + mem_samples.append(mem_before - free_after) + + first_capture = mem_samples[0] + # Use at least 1 MiB per graph for driver overhead + per_graph = max(mem_samples[1] if len(mem_samples) > 1 else 0, 1 << 20) + + shared_memory_estimate[mode] = first_capture + per_graph_estimate[mode] = per_graph * (len(descs) - 1) + + logger.debug( + "Estimated %s CUDA graph memory: " + "%.2f MiB first-capture + (%d-1) × %.2f MiB per-graph", + mode.name, + first_capture / (1 << 20), + len(descs), + per_graph / (1 << 20), + ) + + set_cudagraph_capturing_enabled(False) + CUDAGraphWrapper.clear_all_graphs() + for instance in list(CUDAGraphWrapper._all_instances): + if id(instance) in original_pools: + instance.graph_pool = original_pools[id(instance)] + self.maybe_remove_all_loras(self.lora_config) + self._cleanup_profiling_kv_cache() + compilation_counter.num_cudagraph_captured = saved_num_cudagraph_captured + + # FULL and PIECEWISE graphs share the global pool at runtime and are + # never replayed concurrently, so the pool overlays their memory. + # Take the max to avoid double-counting the overlap. + total_estimate = max(shared_memory_estimate.values()) + sum( + per_graph_estimate.values() + ) + logger.info( + "Estimated CUDA graph memory: %.2f GiB total", + total_estimate / (1 << 30), + ) + + return int(total_estimate) + @instrument(span_name="Capture model") def capture_model(self) -> int: if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE: @@ -5311,27 +5491,13 @@ def capture_model(self) -> int: start_time = time.perf_counter() - @contextmanager - def freeze_gc(): - # Optimize garbage collection during CUDA graph capture. - # Clean up, then freeze all remaining objects from being included - # in future collections. - gc.collect() - should_freeze = not envs.VLLM_ENABLE_CUDAGRAPH_GC - if should_freeze: - gc.freeze() - try: - yield - finally: - if should_freeze: - gc.unfreeze() - gc.collect() - # Trigger CUDA graph capture for specific shapes. # Capture the large shapes first so that the smaller shapes # can reuse the memory pool allocated for the large shapes. set_cudagraph_capturing_enabled(True) - with freeze_gc(), graph_capture(device=self.device): + with self._freeze_gc(), graph_capture(device=self.device): + torch.accelerator.synchronize() + torch.accelerator.empty_cache() start_free_gpu_memory = torch.cuda.mem_get_info()[0] for ( @@ -5342,6 +5508,7 @@ def freeze_gc(): batch_descriptors=batch_descs, cudagraph_runtime_mode=runtime_mode, ) + torch.accelerator.synchronize() torch.accelerator.synchronize() end_free_gpu_memory = torch.cuda.mem_get_info()[0] @@ -5353,6 +5520,9 @@ def freeze_gc(): # after here. set_cudagraph_capturing_enabled(False) + torch.accelerator.synchronize() + torch.accelerator.empty_cache() + # Lock workspace to prevent resizing during execution. # Max workspace sizes should have been captured during warmup/profiling. lock_workspace() @@ -5369,6 +5539,40 @@ def freeze_gc(): ) return cuda_graph_size + def _warmup_and_capture( + self, + desc: BatchDescriptor, + cudagraph_runtime_mode: CUDAGraphMode, + profile_seq_lens: int | None = None, + allow_microbatching: bool = False, + num_warmups: int | None = None, + ): + if num_warmups is None: + num_warmups = self.compilation_config.cudagraph_num_of_warmups + force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL + for _ in range(num_warmups): + self._dummy_run( + desc.num_tokens, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + force_attention=force_attention, + uniform_decode=desc.uniform, + allow_microbatching=allow_microbatching, + skip_eplb=True, + remove_lora=False, + num_active_loras=desc.num_active_loras, + ) + self._dummy_run( + desc.num_tokens, + cudagraph_runtime_mode=cudagraph_runtime_mode, + uniform_decode=desc.uniform, + allow_microbatching=allow_microbatching, + skip_eplb=True, + remove_lora=False, + num_active_loras=desc.num_active_loras, + is_graph_capturing=True, + profile_seq_lens=profile_seq_lens, + ) + def _capture_cudagraphs( self, batch_descriptors: list[BatchDescriptor], @@ -5383,15 +5587,6 @@ def _capture_cudagraphs( return uniform_decode = batch_descriptors[0].uniform - force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL - - dummy_run = functools.partial( - self._dummy_run, - uniform_decode=uniform_decode, - skip_eplb=True, - remove_lora=False, - force_attention=force_attention, - ) # Only rank 0 should print progress bar during capture if is_global_first_rank(): @@ -5406,9 +5601,6 @@ def _capture_cudagraphs( # We skip EPLB here since we don't want to record dummy metrics for batch_desc in batch_descriptors: - num_tokens = batch_desc.num_tokens - num_active_loras = batch_desc.num_active_loras - # We currently only capture ubatched graphs when its a FULL # cudagraph, a uniform decode batch, and the number of tokens # is above the threshold. Otherwise we just capture a non-ubatched @@ -5419,33 +5611,16 @@ def _capture_cudagraphs( and uniform_decode and check_ubatch_thresholds( config=self.vllm_config.parallel_config, - num_tokens=num_tokens, + num_tokens=batch_desc.num_tokens, uniform_decode=uniform_decode, ) ) - - for _ in range(self.compilation_config.cudagraph_num_of_warmups): - # Use CUDAGraphRuntimeStyle.NONE (default) for warmup. - # But be careful, warm up with `NONE` is orthogonal to - # if we want to warm up attention or not. This is - # different from the case where `FULL` implies capture - # attention while `PIECEWISE` implies no attention. - - dummy_run( - num_tokens, - cudagraph_runtime_mode=CUDAGraphMode.NONE, - allow_microbatching=allow_microbatching, - num_active_loras=num_active_loras, - ) - - # Capture run - dummy_run( - num_tokens, + self._warmup_and_capture( + batch_desc, cudagraph_runtime_mode=cudagraph_runtime_mode, allow_microbatching=allow_microbatching, - num_active_loras=num_active_loras, - is_graph_capturing=True, ) + torch.accelerator.synchronize() self.maybe_remove_all_loras(self.lora_config) def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py index 754f2981c9f2..c4cbfff5a42c 100644 --- a/vllm/v1/worker/gpu_ubatch_wrapper.py +++ b/vllm/v1/worker/gpu_ubatch_wrapper.py @@ -112,16 +112,25 @@ def __init__( self.cudagraphs: dict[int, CUDAGraphMetaData] = {} self.cudagraph_wrapper = None - self.graph_pool = None if runtime_mode is not CUDAGraphMode.NONE: self.cudagraph_wrapper = CUDAGraphWrapper( runnable, vllm_config, runtime_mode=runtime_mode ) - self.graph_pool = current_platform.get_global_graph_pool() self.sm_control = self._create_sm_control_context(vllm_config) self.device = device + @property + def graph_pool(self): + if self.cudagraph_wrapper is not None: + return self.cudagraph_wrapper.graph_pool + return None + + def clear_graphs(self) -> None: + self.cudagraphs.clear() + if self.cudagraph_wrapper is not None: + self.cudagraph_wrapper.clear_graphs() + @staticmethod def _create_sm_control_context(vllm_config: VllmConfig): comm_sms: int = envs.VLLM_DBO_COMM_SMS diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index e56905fe763d..929474e4f1f1 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -44,6 +44,7 @@ from vllm.sequence import IntermediateTensors from vllm.tasks import SupportedTask from vllm.tracing import instrument +from vllm.utils.mem_constants import GiB_bytes from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling from vllm.utils.torch_utils import set_random_seed from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput @@ -390,8 +391,36 @@ def determine_available_memory(self) -> int: ) as profile_result: self.model_runner.profile_run() + profile_torch_peak = current_platform.memory_stats(self.device).get( + "allocated_bytes.all.peak", 0 + ) + + # Profile CUDA graph memory if graphs will be captured. + cudagraph_memory_estimate = 0 + if not self.model_config.enforce_eager: + cudagraph_memory_estimate = self.model_runner.profile_cudagraph_memory() + + # Use the pre-cudagraph torch peak to avoid double-counting. + profile_result.torch_peak_increase = ( + profile_torch_peak - profile_result.before_profile.torch_peak + ) + profile_result.non_kv_cache_memory = ( + profile_result.non_torch_increase + + profile_result.torch_peak_increase + + profile_result.weights_memory + ) + + cudagraph_memory_estimate_applied = ( + cudagraph_memory_estimate + if envs.VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS + else 0 + ) + self.non_torch_memory = profile_result.non_torch_increase - self.peak_activation_memory = profile_result.torch_peak_increase + self.peak_activation_memory = ( + profile_result.torch_peak_increase + cudagraph_memory_estimate_applied + ) + self.cudagraph_memory_estimate = cudagraph_memory_estimate free_gpu_memory = profile_result.after_profile.free_memory # NOTE(woosuk): Here we assume that the other processes using the same @@ -406,7 +435,9 @@ def determine_available_memory(self) -> int: "isolate vLLM in its own container." ) self.available_kv_cache_memory_bytes = ( - self.requested_memory - profile_result.non_kv_cache_memory + self.requested_memory + - profile_result.non_kv_cache_memory + - cudagraph_memory_estimate_applied ) unrequested_memory = self.init_snapshot.free_memory - self.requested_memory @@ -428,6 +459,46 @@ def determine_available_memory(self) -> int: scope="local", ) + if cudagraph_memory_estimate > 0: + total_mem = self.init_snapshot.total_memory + current_util = self.cache_config.gpu_memory_utilization + cg_util_delta = cudagraph_memory_estimate / total_mem + if envs.VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS: + equiv_util = round(current_util - cg_util_delta, 4) + suggested_util = min( + round(current_util + cg_util_delta, 4), + 1.0, + ) + logger.info( + "CUDA graph memory profiling is enabled " + "(VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1). " + "This will become the default in v0.19. " + "The current --gpu-memory-utilization=%.4f is equivalent " + "to --gpu-memory-utilization=%.4f without CUDA graph " + "memory profiling. To maintain the same effective KV " + "cache size as before, increase " + "--gpu-memory-utilization to %.4f.", + current_util, + equiv_util, + suggested_util, + ) + else: + suggested_util = min( + round(current_util + cg_util_delta, 4), + 1.0, + ) + logger.info( + "In v0.19, CUDA graph memory profiling will be enabled " + "by default (VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1), " + "which more accurately accounts for CUDA graph memory " + "during KV cache allocation. To try it now, set " + "VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1 and increase " + "--gpu-memory-utilization from %.4f to %.4f to maintain " + "the same effective KV cache size.", + current_util, + suggested_util, + ) + return int(self.available_kv_cache_memory_bytes) def get_kv_connector_handshake_metadata(self) -> dict | None: @@ -487,14 +558,14 @@ def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: @instrument(span_name="Warmup (GPU)") def compile_or_warm_up_model(self) -> float: - warmup_sizes = [] + warmup_sizes: list[int] = [] if self.vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE: # warm up sizes that are not in cudagraph capture sizes, # but users still want to compile for better performance, # e.g. for the max-num-batched token size in chunked prefill. compile_sizes = self.vllm_config.compilation_config.compile_sizes - warmup_sizes = compile_sizes.copy() if compile_sizes is not None else [] + warmup_sizes = compile_sizes.copy() if compile_sizes is not None else [] # type: ignore[assignment] cg_capture_sizes: list[int] = [] if self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: @@ -526,6 +597,22 @@ def compile_or_warm_up_model(self) -> float: if not self.model_config.enforce_eager: cuda_graph_memory_bytes = self.model_runner.capture_model() + # Compare actual vs estimated CUDA graph memory (if we did profiling) + if ( + hasattr(self, "cudagraph_memory_estimate") + and self.cudagraph_memory_estimate > 0 + ): + GiB = lambda b: round(b / GiB_bytes, 2) + diff = abs(cuda_graph_memory_bytes - self.cudagraph_memory_estimate) + logger.info( + "CUDA graph pool memory: %s GiB (actual), %s GiB (estimated), " + "difference: %s GiB (%.1f%%).", + GiB(cuda_graph_memory_bytes), + GiB(self.cudagraph_memory_estimate), + GiB(diff), + 100 * diff / max(cuda_graph_memory_bytes, 1), + ) + if self.cache_config.kv_cache_memory_bytes is None and hasattr( self, "peak_activation_memory" ): From eebd14651f7618eddda5e79eab2d4ea0cdcc1770 Mon Sep 17 00:00:00 2001 From: qli88 Date: Sat, 7 Mar 2026 15:49:56 -0600 Subject: [PATCH 06/19] [CI] Enable Crosslayer KV layout tests for ROCm platforms (#35416) --- .buildkite/test-amd.yaml | 28 ++++++++++++++++++ .../config_sweep_accuracy_test.sh | 29 ++++++++++--------- 2 files changed, 44 insertions(+), 13 deletions(-) diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index f69713a335df..9323310b411b 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -1486,6 +1486,20 @@ steps: - uv pip install --system -r /vllm-workspace/requirements/kv_connectors_rocm.txt - DP_EP=1 ROCM_ATTN=1 bash v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh +- label: CrossLayer KV layout Distributed NixlConnector PD accuracy tests (4 GPUs) + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_4 + # grade: Blocking + timeout_in_minutes: 30 + working_dir: "/vllm-workspace/tests" + num_devices: 4 + source_file_dependencies: + - vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py + - tests/v1/kv_connector/nixl_integration/ + commands: + - uv pip install --system -r /vllm-workspace/requirements/kv_connectors_rocm.txt + - CROSS_LAYERS_BLOCKS=1 ROCM_ATTN=1 bash v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh + ##### multi gpus test ##### ##### A100 test ##### @@ -3136,6 +3150,20 @@ steps: - uv pip install --system -r /vllm-workspace/requirements/kv_connectors_rocm.txt - DP_EP=1 ROCM_ATTN=1 bash v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh +- label: CrossLayer KV layout Distributed NixlConnector PD accuracy tests (4 GPUs) + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi355_4 + # grade: Blocking + timeout_in_minutes: 30 + working_dir: "/vllm-workspace/tests" + num_devices: 4 + source_file_dependencies: + - vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py + - tests/v1/kv_connector/nixl_integration/ + commands: + - uv pip install --system -r /vllm-workspace/requirements/kv_connectors_rocm.txt + - CROSS_LAYERS_BLOCKS=1 ROCM_ATTN=1 bash v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh + ##### multi gpus test ##### ##### A100 test ##### diff --git a/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh index c35f4bfe8890..684e2ec4d7b9 100755 --- a/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh @@ -56,24 +56,27 @@ run_tests() { echo "✅ All ${label} tests passed!" } -# Run tests +# Set backend +label="default backend" +cmdline_args="" if [[ -n "${ROCM_ATTN:-}" ]]; then echo "ROCM_ATTN is set, running with --attention-backend ROCM_ATTN" - run_tests "ROCM_ATTN backend" "--attention-backend ROCM_ATTN" -else - run_tests "default backend" "" -fi - -# Check if FLASHINFER is set (non-empty) -if [[ -n "${FLASHINFER:-}" ]]; then - echo "FLASHINFER is set, rerunning with --attention-backend FLASHINFER" - run_tests "FLASHINFER backend" "--attention-backend FLASHINFER" + label="ROCM_ATTN backend" + cmdline_args=" --attention-backend ROCM_ATTN " +elif [[ -n "${FLASHINFER:-}" ]]; then + echo "FLASHINFER is set, running with --attention-backend FLASHINFER" + label="FLASHINFER backend" + cmdline_args=" --attention-backend FLASHINFER " else - echo "FLASHINFER not set, skipping FLASHINFER runs." + echo "running with default attention backend" fi # Check if cross-layers is enabled (non-empty) if [[ -n "${CROSS_LAYERS_BLOCKS:-}" ]]; then - echo "CROSS_LAYERS_BLOCKS is set, rerunning with --enable-cross-layers" - run_tests "default backend" "--enable-cross-layers" + echo "CROSS_LAYERS_BLOCKS is set, running with --enable-cross-layers" + label+=" - CROSS_LAYERS_BLOCKS enabled" + cmdline_args+=" --enable-cross-layers " fi + +# Run tests +run_tests "${label}" "${cmdline_args}" From fc4657756ff01fec770433530a5dd2a238e7e034 Mon Sep 17 00:00:00 2001 From: Micah Williamson Date: Sat, 7 Mar 2026 15:50:17 -0600 Subject: [PATCH 07/19] [ROCm][CI] Enable AITER for failing `test_gpt_oss` test case on MI355 (#36174) --- tests/models/quantization/test_gpt_oss.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/models/quantization/test_gpt_oss.py b/tests/models/quantization/test_gpt_oss.py index 7599a5a5ee4c..21cc9555bfde 100644 --- a/tests/models/quantization/test_gpt_oss.py +++ b/tests/models/quantization/test_gpt_oss.py @@ -21,6 +21,7 @@ import pytest from packaging import version +from vllm.platforms.rocm import on_gfx950 from vllm.utils.torch_utils import cuda_device_count_stateless MODEL_ACCURACIES = { @@ -83,11 +84,17 @@ def get_model_args(self, tp_size: int): @pytest.mark.parametrize("tp_size", [1, 2, 4, 8]) @pytest.mark.parametrize("model_name, expected_accuracy", MODEL_ACCURACIES.items()) def test_gpt_oss_attention_quantization( - model_name: str, tp_size: int, expected_accuracy: float + model_name: str, + tp_size: int, + expected_accuracy: float, + monkeypatch: pytest.MonkeyPatch, ): if tp_size > cuda_device_count_stateless(): pytest.skip("Not enough GPUs to run this test case") + if "amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-FP8-KV-FP8" in model_name and on_gfx950(): + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + model_args = EvaluationConfig(model_name).get_model_args(tp_size) extra_run_kwargs = { From ee54f9cdb91f04350bba0cf11890b02b12c62baa Mon Sep 17 00:00:00 2001 From: Micah Williamson Date: Sat, 7 Mar 2026 15:50:52 -0600 Subject: [PATCH 08/19] [ROCm][CI] Accept Different But Valid Output for `test_olmoe_tp` (#35224) --- tests/lora/test_olmoe_tp.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/tests/lora/test_olmoe_tp.py b/tests/lora/test_olmoe_tp.py index 5e38638b9b6f..492716b46451 100644 --- a/tests/lora/test_olmoe_tp.py +++ b/tests/lora/test_olmoe_tp.py @@ -3,6 +3,7 @@ import shutil +from collections.abc import Sequence import pytest import torch @@ -15,7 +16,7 @@ MODEL_PATH = "allenai/OLMoE-1B-7B-0125-Instruct" -PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request. +PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me. Do not return any additional explanation. Below is an instruction that describes a task, Write a response that appropriately completes the request. " ##Instruction: candidate_poll contains tables such as candidate, people. Table candidate has columns such as Candidate_ID, People_ID, Poll_Source, Date, Support_rate, Consider_rate, Oppose_rate, Unsure_rate. Candidate_ID is the primary key. @@ -39,10 +40,20 @@ "SELECT COUNT(Candidate_ID) FROM candidate", "SELECT COUNT(Candidate_ID) FROM candidate", "SELECT Candidate_ID, COUNT(*) as Total_Candidates\nFROM candidate\nINNER JOIN people ON candidate.People_ID = people.People_ID", # noqa: E501 - "SELECT Candidate_ID, Poll_Source FROM candidate WHERE People_ID IN (SELECT People_ID FROM people) ORDER BY COUNT(*) DESC LIMIT 1", # noqa: E501 + # There are multiple acceptable responses + ( + "SELECT Candidate_ID, Poll_Source FROM candidate WHERE People_ID IN (SELECT People_ID FROM people) ORDER BY COUNT(*) DESC LIMIT 1", # noqa: E501 + "SELECT Candidate_ID, Poll_Source FROM candidate WHERE COUNT(People_ID) = (SELECT COUNT(People_ID) FROM people) ORDER BY Candidate_ID DESC LIMIT 1", # noqa: E501 + ), ] +def _output_matches(generated: str, accepted: str | Sequence[str]) -> bool: + if isinstance(accepted, str): + accepted = (accepted,) + return any(generated.startswith(s) for s in accepted) + + def generate_and_test( llm: vllm.LLM, lora_path: str, @@ -90,9 +101,13 @@ def generate_and_test( if compare_lower: generated_text = generated_text.lower() - expected_output = expected_output.lower() - - assert generated_text.startswith(expected_output) + if isinstance(expected_output, str): + expected_output = (expected_output.lower(),) + else: + expected_output = tuple(s.lower() for s in expected_output) + assert _output_matches(generated_text, expected_output), ( + f"Output {i}: {generated_text!r} does not match any of {expected_output!r}" + ) def test_olmoe_lora(olmoe_lora_files): From a6be75dbd2a8dd1886da725727ee178f42e3f84f Mon Sep 17 00:00:00 2001 From: PatchyTIS <58251192+PatchouliTIS@users.noreply.github.com> Date: Sun, 8 Mar 2026 05:51:37 +0800 Subject: [PATCH 09/19] [Core] NGram GPU Implementation compatible with Async Scheduler (#29184) --- tests/v1/e2e/test_async_scheduling.py | 43 +- tests/v1/e2e/test_spec_decode.py | 28 + vllm/compilation/backends.py | 7 + vllm/config/speculative.py | 10 +- vllm/config/vllm.py | 7 +- vllm/tool_parsers/hermes_tool_parser.py | 2 + vllm/v1/spec_decode/ngram_proposer_gpu.py | 660 ++++++++++++++++++++++ vllm/v1/worker/gpu_input_batch.py | 8 +- vllm/v1/worker/gpu_model_runner.py | 187 +++++- 9 files changed, 940 insertions(+), 12 deletions(-) create mode 100644 vllm/v1/spec_decode/ngram_proposer_gpu.py diff --git a/tests/v1/e2e/test_async_scheduling.py b/tests/v1/e2e/test_async_scheduling.py index 042e953866cf..c703d6aae9f9 100644 --- a/tests/v1/e2e/test_async_scheduling.py +++ b/tests/v1/e2e/test_async_scheduling.py @@ -98,7 +98,7 @@ def test_without_spec_decoding( @single_gpu_only @large_gpu_mark(min_gb=16) -def test_with_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch): +def test_with_eagle3_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch): """Test consistency and acceptance rates with some different combos of preemption, executor, async scheduling, prefill chunking, spec decoding model length. @@ -154,6 +154,42 @@ def test_with_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch) ) +def test_with_ngram_gpu_spec_decoding(monkeypatch: pytest.MonkeyPatch): + """Test ngram_gpu speculative decoding with different configurations. + + This test specifically validates ngram_gpu behavior with various: + - Number of speculative tokens (2-6) + - Prompt lookup window sizes (min/max) + - Async scheduling enabled (as in production) + - Different executors and chunking settings + """ + + # Variant with larger speculation window + ngram_gpu_config = { + "method": "ngram_gpu", + "num_speculative_tokens": 3, + "prompt_lookup_max": 3, + "prompt_lookup_min": 2, + } + + # Test configurations covering various scenarios + # test_preemption, executor, async_scheduling, + # spec_config, test_prefill_chunking + test_configs = [ + (False, "mp", False, None, False), + (False, "mp", False, ngram_gpu_config, False), + (True, "mp", False, ngram_gpu_config, True), + (False, "mp", True, ngram_gpu_config, False), + (True, "mp", True, ngram_gpu_config, False), + (True, "uni", True, ngram_gpu_config, False), + (True, "mp", True, ngram_gpu_config, True), + ] + + # Use MODEL (Qwen) for ngram_gpu tests as it's lighter weight + # and ngram_gpu doesn't require a specific draft model + run_tests(monkeypatch, MODEL, test_configs, [{}]) + + @dynamo_config.patch(cache_size_limit=16) def run_tests( monkeypatch: pytest.MonkeyPatch, @@ -282,11 +318,12 @@ def run_test( else dict(gpu_memory_utilization=0.9) ) spec_mml = (spec_config or {}).get("max_model_len") + spec_method = (spec_config or {}).get("method", "none") test_config = ( f"executor={executor}, preemption={test_preemption}, " f"async_sched={async_scheduling}, " f"chunk_prefill={test_prefill_chunking}, " - f"spec_decoding={spec_decoding}, spec_mml={spec_mml}" + f"spec_decoding={spec_decoding}, spec_method={spec_method}, spec_mml={spec_mml}" ) print("-" * 80) print(f"---- TESTING {test_str}: {test_config}") @@ -294,7 +331,7 @@ def run_test( with VllmRunner( model, - max_model_len=512, + max_model_len=4096, enable_chunked_prefill=test_prefill_chunking, # Force prefill chunking max_num_batched_tokens=48 if test_prefill_chunking else None, diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 4066dfe9e34d..3988070ca759 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -183,6 +183,34 @@ def test_ngram_and_suffix_correctness( cleanup_dist_env_and_memory() +@pytest.mark.parametrize("async_scheduling", [True], ids=["async"]) +@single_gpu_only +@large_gpu_mark(min_gb=20) +def test_ngram_gpu_default_with_async_scheduling( + async_scheduling: bool, +): + """ + Test ngram_gpu speculative decoding (k=3) correctness with and without + async scheduling, validated via GSM8K accuracy. + Uses Qwen/Qwen3-8B (ref GSM8K accuracy: 87%-92%). + """ + qwen3_model = "Qwen/Qwen3-8B" + spec_llm = LLM( + model=qwen3_model, + speculative_config={ + "method": "ngram_gpu", + "prompt_lookup_max": 3, + "prompt_lookup_min": 2, + "num_speculative_tokens": 2, + }, + max_model_len=4096, + async_scheduling=async_scheduling, + ) + evaluate_llm_for_gsm8k(spec_llm, expected_accuracy_threshold=0.8) + del spec_llm + cleanup_dist_env_and_memory() + + @single_gpu_only @large_gpu_mark(min_gb=20) def test_suffix_decoding_acceptance( diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 9d37a5331c96..2bf53a7fad74 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -907,6 +907,13 @@ def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any: # Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE. disable_cache = not is_compile_cache_enabled(self.inductor_config) + # TODO(patchy): ngram gpu kernel will cause vllm torch compile cache errors. + is_ngram_gpu_enabled = ( + vllm_config.speculative_config is not None + and vllm_config.speculative_config.use_ngram_gpu() + ) + disable_cache = disable_cache or is_ngram_gpu_enabled + if disable_cache: logger.info_once("vLLM's torch.compile cache is disabled.", scope="local") else: diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index a950ba531ad2..27b5188eb52d 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -47,6 +47,7 @@ "step3p5_mtp", ] EagleModelTypes = Literal["eagle", "eagle3", "extract_hidden_states", MTPModelTypes] +NgramGPUTypes = Literal["ngram_gpu"] SpeculativeMethod = Literal[ "ngram", "medusa", @@ -54,6 +55,7 @@ "draft_model", "suffix", EagleModelTypes, + NgramGPUTypes, ] @@ -364,6 +366,8 @@ def __post_init__(self): self.quantization = self.target_model_config.quantization elif self.method in ("ngram", "[ngram]"): self.model = "ngram" + elif self.method == "ngram_gpu": + self.model = "ngram_gpu" elif self.method == "suffix": self.model = "suffix" elif self.method == "extract_hidden_states": @@ -374,8 +378,9 @@ def __post_init__(self): ) if self.method in ("ngram", "[ngram]"): - # Unified to "ngram" internally self.method = "ngram" + + if self.method in ("ngram", "ngram_gpu"): # Set default values if not provided if self.prompt_lookup_min is None and self.prompt_lookup_max is None: # TODO(woosuk): Tune these values. They are arbitrarily chosen. @@ -832,6 +837,9 @@ def uses_draft_model(self) -> bool: def uses_extract_hidden_states(self) -> bool: return self.method == "extract_hidden_states" + def use_ngram_gpu(self) -> bool: + return self.method == "ngram_gpu" + def __repr__(self) -> str: method = self.method model = ( diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index d5b60a566fd3..16f2c375d5fd 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -41,7 +41,7 @@ from .parallel import ParallelConfig from .profiler import ProfilerConfig from .scheduler import SchedulerConfig -from .speculative import EagleModelTypes, SpeculativeConfig +from .speculative import EagleModelTypes, NgramGPUTypes, SpeculativeConfig from .structured_outputs import StructuredOutputsConfig from .utils import SupportsHash, config, replace from .weight_transfer import WeightTransferConfig @@ -696,11 +696,13 @@ def __post_init__(self): if self.speculative_config is not None: if ( self.speculative_config.method not in get_args(EagleModelTypes) + and self.speculative_config.method not in get_args(NgramGPUTypes) and self.speculative_config.method != "draft_model" ): raise ValueError( "Currently, async scheduling is only supported " - "with EAGLE/MTP/Draft Model kind of speculative decoding." + "with EAGLE/MTP/Draft Model/NGram GPU kind of " + "speculative decoding" ) if self.speculative_config.disable_padded_drafter_batch: raise ValueError( @@ -718,6 +720,7 @@ def __post_init__(self): if ( self.speculative_config is not None and self.speculative_config.method not in get_args(EagleModelTypes) + and self.speculative_config.method not in get_args(NgramGPUTypes) ): logger.warning_once( "Async scheduling not supported with %s-based " diff --git a/vllm/tool_parsers/hermes_tool_parser.py b/vllm/tool_parsers/hermes_tool_parser.py index b9b1dcda6f68..5bde5b2c07ab 100644 --- a/vllm/tool_parsers/hermes_tool_parser.py +++ b/vllm/tool_parsers/hermes_tool_parser.py @@ -385,6 +385,7 @@ def extract_tool_calls_streaming( prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( "arguments" ) + assert current_tool_call is not None cur_arguments = current_tool_call.get("arguments") logger.debug("diffing old arguments: %s", prev_arguments) @@ -489,6 +490,7 @@ def extract_tool_calls_streaming( # handle saving the state for the current tool into # the "prev" list for use in diffing for the next iteration + assert isinstance(current_tool_call, dict) if self.current_tool_id == len(self.prev_tool_call_arr) - 1: self.prev_tool_call_arr[self.current_tool_id] = current_tool_call else: diff --git a/vllm/v1/spec_decode/ngram_proposer_gpu.py b/vllm/v1/spec_decode/ngram_proposer_gpu.py new file mode 100644 index 000000000000..3ff84180463d --- /dev/null +++ b/vllm/v1/spec_decode/ngram_proposer_gpu.py @@ -0,0 +1,660 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +GPU-accelerated N-gram proposer using fully async PyTorch tensor operations. + +This version uses a fully vectorized approach with unfold and argmax for +finding the first match across all sequences in parallel. +""" + +import torch +from torch import nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import ( + CompilationConfig, + CompilationMode, + CUDAGraphMode, + VllmConfig, +) +from vllm.forward_context import set_forward_context +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.utils import record_function_or_nullcontext +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch + + +@support_torch_compile() +class NgramGPUKernel(nn.Module): + """GPU-accelerated N-gram proposer using fully async tensor operations.""" + + def __init__( + self, vllm_config: VllmConfig, prefix: str = "", device: torch.device = "cuda" + ): + super().__init__() + + assert vllm_config.speculative_config is not None + assert vllm_config.speculative_config.prompt_lookup_min is not None + assert vllm_config.speculative_config.prompt_lookup_max is not None + + self.min_n = vllm_config.speculative_config.prompt_lookup_min + self.max_n = vllm_config.speculative_config.prompt_lookup_max + self.k = vllm_config.speculative_config.num_speculative_tokens + self.max_model_len = vllm_config.model_config.max_model_len + self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs + self.device = device + + def _find_first_and_extract_all_n_parallel( + self, + token_ids: torch.Tensor, + seq_lengths: torch.Tensor, + min_ngram_len: int, + max_ngram_len: int, + num_draft_tokens: int, + ) -> torch.Tensor: + """ + Find suffix n-gram matches and extract following tokens. + Searches for the earliest prior occurrence of the trailing n-gram, + tries multiple lengths, and picks the longest valid match. + + Args: + token_ids: Token IDs for each sequence + seq_lengths: Actual length of each sequence (excluding padding) + min_ngram_len: Minimum n-gram size to search for (e.g., 2) + max_ngram_len: Maximum n-gram size to search for (e.g., 5) + num_draft_tokens: Number of tokens to extract after match (k) + + Returns: + Draft token predictions; -1 means invalid/no match. + """ + batch_size = token_ids.shape[0] + max_seq_len = token_ids.shape[1] + device = token_ids.device + num_ngram_sizes = max_ngram_len - min_ngram_len + 1 + + # All n-gram sizes to try. + ngram_lengths = torch.arange(min_ngram_len, max_ngram_len + 1, device=device) + batch_indices = torch.arange(batch_size, device=device) + + # Earliest match per (sequence, ngram_len); -1 means no match. + first_match_positions = torch.full( + (batch_size, num_ngram_sizes), -1, dtype=torch.long, device=device + ) + + for i, ngram_len in enumerate(range(min_ngram_len, max_ngram_len + 1)): + # Sliding windows of size ngram_len; unfold is O(1) view. + search_windows = token_ids.unfold(1, ngram_len, 1) + num_windows = search_windows.shape[1] + + # Trailing suffix (last ngram_len tokens) for each sequence. + suffix_starts = seq_lengths - ngram_len + suffix_indices = suffix_starts.unsqueeze(1) + torch.arange( + ngram_len, device=device + ) + suffix = torch.gather(token_ids, 1, suffix_indices.clamp(min=0)) + + # Window matches for each sequence. + matches = (search_windows == suffix.unsqueeze(1)).all(dim=-1) + + # Match must leave room for at least one draft token. + max_valid_suffix_start = seq_lengths - ngram_len - 1 + window_positions = torch.arange(num_windows, device=device) + valid_mask = window_positions <= max_valid_suffix_start.unsqueeze(1) + final_matches = matches & valid_mask + + # Find earliest match (argmax=0 when empty; verify with has_match). + first_match_idx = torch.argmax(final_matches.int(), dim=1) + has_match = final_matches[batch_indices, first_match_idx] + + # Store valid match positions (window index = position). + first_match_positions[:, i] = torch.where(has_match, first_match_idx, -1) + + # Select the longest n-gram with a match. + best_ngram_idx = (first_match_positions >= 0).int().flip(dims=[1]).argmax(dim=1) + best_ngram_idx = num_ngram_sizes - 1 - best_ngram_idx # Flip back + + # Match position for the best n-gram. + best_match_pos = first_match_positions[batch_indices, best_ngram_idx] + + # Avoid data-dependent branching. + has_any_match = best_match_pos >= 0 + + # Length of the best matching n-gram. + best_ngram_lengths = ngram_lengths[best_ngram_idx] + + # Start position right after the matched suffix. + draft_start = torch.where( + has_any_match, + best_match_pos + best_ngram_lengths, + torch.zeros_like(best_match_pos), + ) + tokens_available = seq_lengths - draft_start + + # Gather indices for draft tokens. + draft_indices = draft_start.unsqueeze(1) + torch.arange( + num_draft_tokens, device=device + ) + draft_indices = draft_indices.clamp(min=0, max=max_seq_len - 1) + + # Extract draft tokens; gather always runs. + draft_tokens = torch.gather(token_ids, 1, draft_indices) + + # Mask positions beyond available tokens. + position_indices = torch.arange(num_draft_tokens, device=device).unsqueeze(0) + valid_positions = position_indices < tokens_available.unsqueeze(1) + + draft_tokens = torch.where( + valid_positions, + draft_tokens, + torch.full_like(draft_tokens, -1), + ) + + # If no match, mask all positions. + draft_tokens = torch.where( + has_any_match.unsqueeze(1), + draft_tokens, + torch.full_like(draft_tokens, -1), + ) + + return draft_tokens + + def forward( + self, + num_tokens_no_spec: torch.Tensor, + token_ids_gpu: torch.Tensor, + combined_mask: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for N-gram proposal using GPU tensor operations. + + Args: + num_tokens_no_spec: Number of tokens for each sequence [batch_size] + token_ids_gpu: Token IDs [batch_size, max_len] + combined_mask: Whether each sequence is valid for spec decode [batch_size] + + Returns: + draft_tokens: [batch_size, k] on GPU + num_valid_draft_tokens: [batch_size] int32 on GPU, count of + leading valid (non -1) tokens per request. + """ + + device = token_ids_gpu.device + + # Infer batch size to preserve dynamic shape. + actual_batch_size = token_ids_gpu.shape[0] + + # Allocate in forward so torch.compile can optimize. + # NOTE(patchy): Do NOT pre-allocate this as a buffer + # it breaks torch.compile + draft_tokens = torch.full( + (actual_batch_size, self.k), -1, dtype=torch.int32, device=device + ) + + results = self._find_first_and_extract_all_n_parallel( + token_ids_gpu, + num_tokens_no_spec, + min_ngram_len=self.min_n, + max_ngram_len=self.max_n, + num_draft_tokens=self.k, + ) + + draft_tokens = torch.where(combined_mask.unsqueeze(1), results, -1) + + # Count leading contiguous valid (non -1) tokens per request. + is_valid = draft_tokens != -1 # [batch, k] + cum_valid = is_valid.int().cumsum(dim=1) # [batch, k] + positions = torch.arange(1, self.k + 1, device=device).unsqueeze(0) + num_valid_draft_tokens = (cum_valid == positions).int().sum(dim=1) + + return draft_tokens, num_valid_draft_tokens + + def load_model(self, *args, **kwargs): + """No model to load for N-gram proposer.""" + pass + + +class NgramProposerGPU: + def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None): + assert vllm_config.speculative_config is not None + assert vllm_config.speculative_config.prompt_lookup_min is not None + assert vllm_config.speculative_config.prompt_lookup_max is not None + + compilation_config = CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + custom_ops=["none"], + splitting_ops=[], + compile_sizes=[], + inductor_compile_config={ + "enable_auto_functionalized_v2": False, + "max_autotune": True, + "aggressive_fusion": True, + "triton.autotune_pointwise": True, + "coordinate_descent_tuning": True, + "use_mixed_mm": False, + }, + cudagraph_mode=CUDAGraphMode.NONE, + ) + model_config = vllm_config.model_config + speculative_config = vllm_config.speculative_config + scheduler_config = vllm_config.scheduler_config + + self.vllm_config = VllmConfig( + compilation_config=compilation_config, + model_config=model_config, + speculative_config=speculative_config, + scheduler_config=scheduler_config, + ) + + self.min_n = vllm_config.speculative_config.prompt_lookup_min + self.max_n = vllm_config.speculative_config.prompt_lookup_max + self.k = vllm_config.speculative_config.num_speculative_tokens + self.max_model_len = vllm_config.model_config.max_model_len + self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs + self.device = device + + self.kernel = NgramGPUKernel( + vllm_config=self.vllm_config, prefix="ngram_gpu_kernel", device=device + ) + self.kernel.to(device) + self.kernel.eval() + + self._dummy_run() + + def _dummy_run(self): + token_ids, num_tokens, sampled_flags, valid_mask = self._generate_dummy_data( + batch_size=self.max_num_seqs, + max_seq_len=self.max_model_len, + pattern_len=self.k, + device=self.device, + ) + + combined_mask = sampled_flags & valid_mask & (num_tokens >= self.min_n) + + for _ in range(3): + with set_forward_context(None, self.vllm_config): + _, _ = self.kernel(num_tokens, token_ids, combined_mask) + + def _generate_dummy_data( + self, + batch_size: int, + max_seq_len: int, + pattern_len: int, + device: str = "cuda", + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Generate random test data with n-gram repetitions. + + Args: + batch_size: Number of sequences in the batch + max_seq_len: Maximum sequence length + pattern_len: Length of patterns to inject for matching + device: Device to place tensors on + + Returns: + token_ids: [batch_size, max_seq_len] tensor + num_tokens: [batch_size] tensor + sampled_flags: [batch_size] bool tensor + valid_mask: [batch_size] bool tensor + """ + token_ids = torch.zeros( + batch_size, + max_seq_len, + dtype=torch.int32, + device=device, + ) + + num_tokens = torch.randint( + pattern_len, max_seq_len, (batch_size,), dtype=torch.int32, device=device + ) + + sampled_flags = torch.ones(batch_size, dtype=torch.bool, device=device) + valid_mask = torch.ones(batch_size, dtype=torch.bool, device=device) + + return token_ids, num_tokens, sampled_flags, valid_mask + + def propose( + self, + num_tokens_no_spec: torch.Tensor, # [batch_size] + token_ids_gpu: torch.Tensor, # [batch_size, max_len] + valid_sampled_token_ids_gpu: torch.Tensor, # [batch_size, num_spec_tokens + 1] + valid_sampled_tokens_count: torch.Tensor, # [batch_size] + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Propose draft tokens using GPU-accelerated n-gram matching. + + Scatter sampled tokens into `token_ids_gpu`, compute temporary + updated lengths, then run the kernel. + + Args: + num_tokens_no_spec: Number of tokens per sequence (read-only) + token_ids_gpu: Token IDs tensor (modified in-place with new tokens) + valid_sampled_token_ids_gpu: Newly sampled tokens to scatter + valid_sampled_tokens_count: Count of valid tokens per sequence + + Returns: + draft_tokens: Proposed draft token IDs [batch_size, k] + num_valid_draft_tokens: Count of leading valid draft tokens + per request [batch_size] + """ + assert token_ids_gpu.device == self.device + assert num_tokens_no_spec.device == self.device + + batch_size = num_tokens_no_spec.shape[0] + max_seq_len = token_ids_gpu.shape[1] + max_new_tokens = valid_sampled_token_ids_gpu.shape[1] # num_spec_tokens + 1 + + # Scatter newly sampled tokens into token_ids_gpu. + offsets = torch.arange(max_new_tokens, device=self.device) + write_positions = num_tokens_no_spec.unsqueeze(1) + offsets.unsqueeze(0) + valid_write_mask = offsets.unsqueeze(0) < valid_sampled_tokens_count.unsqueeze( + 1 + ) + in_bounds = write_positions < max_seq_len + scatter_mask = ( + valid_write_mask & (valid_sampled_token_ids_gpu != -1) & in_bounds + ) + + write_positions_long = write_positions.clamp(max=max_seq_len - 1).long() + existing_values = token_ids_gpu.gather(1, write_positions_long) + + tokens_cast = valid_sampled_token_ids_gpu.to(token_ids_gpu.dtype) + tokens_to_scatter = torch.where( + scatter_mask, + tokens_cast, + existing_values, + ) + token_ids_gpu.scatter_(1, write_positions_long, tokens_to_scatter) + + num_tokens_tmp = num_tokens_no_spec + valid_sampled_tokens_count + + # Compute validity masks. + sampled_flags = valid_sampled_tokens_count > 0 + valid_mask = torch.ones(batch_size, dtype=torch.bool, device=self.device) + + with set_forward_context(None, self.vllm_config): + combined_mask = sampled_flags & valid_mask & (num_tokens_tmp >= self.min_n) + + with record_function_or_nullcontext("ngram_proposer_gpu: kernel"): + draft_tokens, num_valid_draft_tokens = self.kernel( + num_tokens_tmp, + token_ids_gpu, + combined_mask, + ) + + return draft_tokens, num_valid_draft_tokens + + def update_token_ids_ngram( + self, + sampled_token_ids: torch.Tensor | list[list[int]], + gpu_input_batch: InputBatch, + token_ids_gpu: torch.Tensor, + num_tokens_no_spec: torch.Tensor, + discard_request_mask: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Prepare speculative decoding inputs on device: + compute next token ids and valid counts, honoring discarded requests + and rejected tokens, without CPU-GPU sync. + """ + num_reqs = gpu_input_batch.num_reqs + + if isinstance(sampled_token_ids, list): + # When disable_padded_drafter_batch=True, sampled_token_ids is + # an irregular list[list[int]] where sublists may have different + # lengths (including empty lists for discarded requests). + # Pad all sublists to the same length with -1 before converting + # to tensor. + max_len = max( + (len(sublist) for sublist in sampled_token_ids), + default=0, + ) + # Ensure at least length 1 for tensor creation + max_len = max(max_len, 1) + padded_list = [ + sublist + [-1] * (max_len - len(sublist)) + for sublist in sampled_token_ids + ] + sampled_token_ids = torch.tensor( + padded_list, dtype=torch.int32, device=self.device + ) + assert isinstance(sampled_token_ids, torch.Tensor), ( + "sampled_token_ids should be a torch.Tensor for ngram_gpu" + ) + + # Backup last valid token before speculative tokens. + backup_indices = (num_tokens_no_spec[:num_reqs] - 1).clamp(min=0).long() + backup_next_token_ids = torch.gather( + token_ids_gpu[:num_reqs], dim=1, index=backup_indices.unsqueeze(1) + ).squeeze(1) + + valid_sampled_token_ids_gpu = sampled_token_ids.clone() + # Invalidate sampled tokens for discarded requests. + discard_mask_expanded = discard_request_mask[:num_reqs].unsqueeze(1) + valid_sampled_token_ids_gpu.masked_fill_(discard_mask_expanded, -1) + + # Mask valid tokens within each request. + valid_mask = (valid_sampled_token_ids_gpu != -1) & ( + valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size + ) + + # Count valid tokens per request. + valid_sampled_tokens_count = valid_mask.sum(dim=1) + + # Rightmost valid index per row. + last_valid_indices = valid_sampled_tokens_count - 1 + last_valid_indices_safe = torch.clamp(last_valid_indices, min=0) + + # Last valid token from each row; undefined if none. + selected_tokens = torch.gather( + valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1) + ).squeeze(1) + + # Use last token if valid; otherwise fallback to backup. + next_token_ids = torch.where( + last_valid_indices != -1, + selected_tokens, + backup_next_token_ids, + ) + + return next_token_ids, valid_sampled_tokens_count, valid_sampled_token_ids_gpu + + def load_model(self, *args, **kwargs): + self.kernel.load_model(*args, **kwargs) + + +def update_scheduler_for_invalid_drafts( + num_valid_draft_tokens_event: torch.cuda.Event, + num_valid_draft_tokens_cpu: torch.Tensor, + scheduler_output: "SchedulerOutput", + req_id_to_index: dict[str, int], +) -> None: + """Trim invalid speculative slots using per-request valid draft counts. + + Args: + num_valid_draft_tokens_event: Event for async D2H completion. + num_valid_draft_tokens_cpu: CPU buffer of valid draft counts. + scheduler_output: Scheduler metadata to update in-place. + req_id_to_index: Request-id to batch-index mapping. + """ + req_data = scheduler_output.scheduled_cached_reqs + num_valid_draft_tokens_event.synchronize() + + for req_id in req_data.req_ids: + req_index = req_id_to_index.get(req_id) + if req_index is None: + continue + + spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(req_id) + if spec_token_ids is None: + continue + + scheduled_k = len(spec_token_ids) + + valid_k = int(num_valid_draft_tokens_cpu[req_index].item()) + valid_k = max(0, min(valid_k, scheduled_k)) + + tokens_to_trim = scheduled_k - valid_k + scheduler_output.total_num_scheduled_tokens -= tokens_to_trim + scheduler_output.num_scheduled_tokens[req_id] -= tokens_to_trim + + if valid_k == 0: + scheduler_output.scheduled_spec_decode_tokens.pop(req_id, None) + else: + scheduler_output.scheduled_spec_decode_tokens[req_id] = spec_token_ids[ + :valid_k + ] + + +def update_ngram_gpu_tensors_incremental( + input_batch: InputBatch, + token_ids_gpu_tensor: torch.Tensor, + num_tokens_no_spec_gpu: torch.Tensor, + new_reqs: list[CachedRequestState], + device: torch.device, + _pinned_idx_buf: torch.Tensor, + _pinned_val_buf: torch.Tensor, +) -> None: + """Incrementally update token_ids_gpu_tensor and num_tokens_no_spec_gpu + for ngram GPU proposer. + """ + prev_req_id_to_index = input_batch.prev_req_id_to_index + curr_req_id_to_index = input_batch.req_id_to_index + + if not curr_req_id_to_index: + return + + active_indices = list(curr_req_id_to_index.values()) + n_active = len(active_indices) + + # Use resident pinned buffers to avoid per-call allocation. + active_idx_cpu = _pinned_idx_buf[:n_active] + active_idx_cpu.copy_(torch.as_tensor(active_indices, dtype=torch.long)) + + active_idx_gpu = active_idx_cpu.to(device=device, non_blocking=True) + + new_req_ids = {req.req_id for req in new_reqs} + + # First run, no previous state. + if prev_req_id_to_index is None: + for idx in active_indices: + num_tokens = input_batch.num_tokens_no_spec[idx] + if num_tokens > 0: + token_ids_gpu_tensor[idx, :num_tokens].copy_( + input_batch.token_ids_cpu_tensor[idx, :num_tokens], + non_blocking=True, + ) + + _sync_num_tokens( + input_batch, + num_tokens_no_spec_gpu, + active_idx_cpu, + active_idx_gpu, + n_active, + device, + _pinned_val_buf, + ) + return + + # Detect index changes for reorder. + reorder_src: list[int] = [] + reorder_dst: list[int] = [] + + for req_id, curr_idx in curr_req_id_to_index.items(): + if req_id in new_req_ids: + continue + prev_idx = prev_req_id_to_index.get(req_id) + if prev_idx is not None and prev_idx != curr_idx: + reorder_src.append(prev_idx) + reorder_dst.append(curr_idx) + + if reorder_src: + src_tensor = torch.tensor(reorder_src, dtype=torch.long, device=device) + dst_tensor = torch.tensor(reorder_dst, dtype=torch.long, device=device) + + temp_token_ids = token_ids_gpu_tensor[src_tensor].clone() + temp_num_tokens = num_tokens_no_spec_gpu[src_tensor].clone() + + token_ids_gpu_tensor[dst_tensor] = temp_token_ids + num_tokens_no_spec_gpu[dst_tensor] = temp_num_tokens + + # Full copy for new/resumed requests. + for req_state in new_reqs: + new_req_idx = curr_req_id_to_index.get(req_state.req_id) + if new_req_idx is None: + continue + + num_tokens = input_batch.num_tokens_no_spec[new_req_idx] + if num_tokens > 0: + token_ids_gpu_tensor[new_req_idx, :num_tokens].copy_( + input_batch.token_ids_cpu_tensor[new_req_idx, :num_tokens], + non_blocking=True, + ) + + # Always batch-sync sequence lengths from CPU for ALL active requests. + _sync_num_tokens( + input_batch, + num_tokens_no_spec_gpu, + active_idx_cpu, + active_idx_gpu, + n_active, + device, + _pinned_val_buf, + ) + + +def _sync_num_tokens( + input_batch: InputBatch, + num_tokens_no_spec_gpu: torch.Tensor, + active_idx_cpu: torch.Tensor, + active_idx_gpu: torch.Tensor, + n_active: int, + device: torch.device, + _pinned_val_buf: torch.Tensor, +) -> None: + """Batch-sync GPU sequence lengths from CPU source of truth. + + Inputs: + input_batch: Batch container with CPU length tensor. + num_tokens_no_spec_gpu: Destination GPU length tensor. + active_idx_cpu: Active request indices on CPU. + active_idx_gpu: Active request indices on GPU. + n_active: Number of active requests. + device: Target CUDA device. + _pinned_val_buf: Resident pinned int32 staging buffer. + Outputs: + None (updates num_tokens_no_spec_gpu in-place). + """ + src_cpu = input_batch.num_tokens_no_spec_cpu_tensor + vals = _pinned_val_buf[:n_active] + vals.copy_(src_cpu.index_select(0, active_idx_cpu)) + + num_tokens_no_spec_gpu.index_copy_( + 0, + active_idx_gpu, + vals.to(device=device, non_blocking=True), + ) + + +def copy_num_valid_draft_tokens( + num_valid_draft_tokens_cpu: torch.Tensor, + num_valid_draft_tokens_copy_stream: torch.cuda.Stream, + num_valid_draft_tokens_event: torch.cuda.Event, + num_valid_draft_tokens: torch.Tensor | None, + batch_size: int, +) -> None: + """ + Async D2H copy of per-request valid draft counts. + """ + if num_valid_draft_tokens is None: + return + + num_reqs_to_copy = min(batch_size, num_valid_draft_tokens.shape[0]) + if num_reqs_to_copy <= 0: + return + + default_stream = torch.cuda.current_stream() + with torch.cuda.stream(num_valid_draft_tokens_copy_stream): + num_valid_draft_tokens_copy_stream.wait_stream(default_stream) + num_valid_draft_tokens_cpu[:num_reqs_to_copy].copy_( + num_valid_draft_tokens[:num_reqs_to_copy], non_blocking=True + ) + num_valid_draft_tokens_event.record() diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index c70970fdc06e..579c9b7a5acc 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -127,7 +127,13 @@ def __init__( # allocation if max_model_len is big. # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size) self.req_prompt_embeds: dict[int, torch.Tensor] = {} - self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) + self.num_tokens_no_spec_cpu_tensor = torch.zeros( + (max_num_reqs,), + device="cpu", + dtype=torch.int32, + pin_memory=pin_memory, + ) + self.num_tokens_no_spec = self.num_tokens_no_spec_cpu_tensor.numpy() self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_computed_tokens_cpu_tensor = torch.zeros( (max_num_reqs,), diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index cf08c13db062..08dbd614fdcf 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -10,7 +10,7 @@ from collections.abc import Iterable, Iterator, Sequence from contextlib import contextmanager from copy import copy, deepcopy -from dataclasses import dataclass +from dataclasses import dataclass, replace from functools import reduce from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast @@ -164,6 +164,12 @@ from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.spec_decode.ngram_proposer_gpu import ( + NgramProposerGPU, + copy_num_valid_draft_tokens, + update_ngram_gpu_tensors_incremental, + update_scheduler_for_invalid_drafts, +) from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext @@ -424,7 +430,7 @@ def __init__( # Broadcast PP output for external_launcher (torchrun) # to make sure we are synced across pp ranks - # TODO: Support overlapping mirco-batches + # TODO: Support overlapping micro-batches # https://github.com/vllm-project/vllm/issues/18019 self.broadcast_pp_output = ( self.parallel_config.distributed_executor_backend == "external_launcher" @@ -493,6 +499,7 @@ def __init__( if self.speculative_config and get_pp_group().is_last_rank: self.drafter: ( NgramProposer # noqa: F823 + | NgramProposerGPU | SuffixDecodingProposer | EagleProposer | DraftModelProposer @@ -509,6 +516,23 @@ def __init__( device=self.device, runner=self, ) + elif self.speculative_config.use_ngram_gpu(): + self.drafter = NgramProposerGPU(self.vllm_config, self.device, self) + self.num_tokens_no_spec_gpu = torch.zeros( + self.max_num_reqs, dtype=torch.int32, device=device + ) + self.token_ids_gpu_tensor = torch.zeros( + self.max_num_reqs, + self.max_model_len, + dtype=torch.int32, + device=device, + ) + self._ngram_pinned_idx_buf = torch.zeros( + self.max_num_reqs, dtype=torch.long, pin_memory=True + ) + self._ngram_pinned_val_buf = torch.zeros( + self.max_num_reqs, dtype=torch.int32, pin_memory=True + ) elif self.speculative_config.method == "suffix": self.drafter = SuffixDecodingProposer(self.vllm_config) elif self.speculative_config.use_eagle(): @@ -564,7 +588,7 @@ def __init__( ) self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, - # We need to use the encoder length for encoder-decoer + # We need to use the encoder length for encoder-decoder # because of KV cache for cross-attention. max_model_len=max(self.max_model_len, self.max_encoder_len), max_num_batched_tokens=self.max_num_tokens, @@ -721,6 +745,21 @@ def __init__( # Cached outputs. self._draft_token_ids: list[list[int]] | torch.Tensor | None = None + # N-gram GPU path: async D2H buffer/event for per-request valid draft counts. + self._num_valid_draft_tokens: torch.Tensor | None = None + self._num_valid_draft_tokens_cpu: torch.Tensor | None = None + self._num_valid_draft_tokens_event: torch.cuda.Event | None = None + self._num_valid_draft_tokens_copy_stream: torch.cuda.Stream | None = None + if ( + self.speculative_config is not None + and self.speculative_config.use_ngram_gpu() + ): + self._num_valid_draft_tokens_cpu = torch.empty( + self.max_num_reqs, dtype=torch.int32, pin_memory=self.pin_memory + ) + self._num_valid_draft_tokens_event = torch.cuda.Event() + self._num_valid_draft_tokens_copy_stream = torch.cuda.Stream() + self._draft_token_req_ids: list[str] | None = None self.transfer_event = torch.Event() self.sampled_token_ids_pinned_cpu = torch.empty( @@ -992,6 +1031,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: for req_id in unscheduled_req_ids: self.input_batch.remove_request(req_id) + is_ngram_gpu = ( + self.speculative_config is not None + and self.speculative_config.use_ngram_gpu() + ) + if is_ngram_gpu: + ngram_gpu_new_reqs: list[CachedRequestState] = [] + reqs_to_add: list[CachedRequestState] = [] # Add new requests to the cached states. for new_req_data in scheduler_output.scheduled_new_reqs: @@ -1054,12 +1100,31 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self._init_xdrope_positions(req_state) reqs_to_add.append(req_state) + # Track new requests for ngram_gpu full tensor copy + if is_ngram_gpu: + ngram_gpu_new_reqs.append(req_state) # Update the states of the running/resumed requests. is_last_rank = get_pp_group().is_last_rank req_data = scheduler_output.scheduled_cached_reqs scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens + # Save scheduler-allocated spec lengths before trimming so + # prev_num_draft_len keeps the optimistic count for rejection correction. + original_num_spec_per_req: dict[str, int] = {} + if ( + self.speculative_config is not None + and self.speculative_config.use_ngram_gpu() + ): + for req_id, toks in scheduled_spec_tokens.items(): + original_num_spec_per_req[req_id] = len(toks) + update_scheduler_for_invalid_drafts( + self._num_valid_draft_tokens_event, + self._num_valid_draft_tokens_cpu, + scheduler_output, + self.input_batch.req_id_to_index, + ) + # Wait until valid_sampled_tokens_count is copied to cpu, # then use it to update actual num_computed_tokens of each request. valid_sampled_token_count = self._get_valid_sampled_token_count() @@ -1076,13 +1141,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # prev_num_draft_len is used in async scheduling mode with # spec decode. it indicates if need to update num_computed_tokens # of the request. for example: - # fist step: num_computed_tokens = 0, spec_tokens = [], + # first step: num_computed_tokens = 0, spec_tokens = [], # prev_num_draft_len = 0. # second step: num_computed_tokens = 100(prompt length), # spec_tokens = [a,b], prev_num_draft_len = 0. # third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d], # prev_num_draft_len = 2. - # num_computed_tokens in first step and second step does't contain + # num_computed_tokens in first step and second step doesn't contain # the spec tokens length, but in third step it contains the # spec tokens length. we only need to update num_computed_tokens # when prev_num_draft_len > 0. @@ -1096,6 +1161,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: num_computed_tokens -= num_rejected req_state.output_token_ids.extend([-1] * num_accepted) + if is_ngram_gpu and num_accepted > 0 and req_index is not None: + self.input_batch.num_tokens_no_spec[req_index] += num_accepted + # Update the cached states. req_state.num_computed_tokens = num_computed_tokens @@ -1156,6 +1224,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: req_state.output_token_ids = resumed_token_ids[-num_output_tokens:] reqs_to_add.append(req_state) + # Track resumed requests for ngram_gpu full tensor copy + if is_ngram_gpu: + ngram_gpu_new_reqs.append(req_state) continue # Update the persistent batch. @@ -1176,6 +1247,11 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Add spec_token_ids to token_ids_cpu. self.input_batch.update_req_spec_token_ids(req_state, scheduled_spec_tokens) + # Restore scheduler-side draft count after ngram trimming. + if original_num_spec_per_req: + orig = original_num_spec_per_req.get(req_id, 0) + if orig != req_state.prev_num_draft_len: + req_state.prev_num_draft_len = orig # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. @@ -1190,6 +1266,18 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Refresh batch metadata with any pending updates. self.input_batch.refresh_metadata() + # Incrementally update ngram_gpu tensors after batch is stable + if is_ngram_gpu: + update_ngram_gpu_tensors_incremental( + self.input_batch, + self.token_ids_gpu_tensor, + self.num_tokens_no_spec_gpu, + ngram_gpu_new_reqs, + self.device, + _pinned_idx_buf=self._ngram_pinned_idx_buf, + _pinned_val_buf=self._ngram_pinned_val_buf, + ) + def _update_states_after_model_execute( self, output_token_ids: torch.Tensor, scheduler_output: "SchedulerOutput" ) -> None: @@ -3412,6 +3500,23 @@ def execute_model( else: logger.error("RoutedExpertsCapturer not initialized.") + # If ngram_gpu is used, we need to copy the scheduler_output to avoid + # the modification has influence on the scheduler_output in engine core process. + # The replace is much faster than deepcopy. + if ( + self.speculative_config is not None + and self.speculative_config.use_ngram_gpu() + ): + num_scheduled_tokens_copy = scheduler_output.num_scheduled_tokens.copy() + spec_decode_tokens_copy = ( + scheduler_output.scheduled_spec_decode_tokens.copy() + ) + scheduler_output = replace( + scheduler_output, + num_scheduled_tokens=num_scheduled_tokens_copy, + scheduled_spec_decode_tokens=spec_decode_tokens_copy, + ) + if scheduler_output.preempted_req_ids and has_kv_transfer_group(): get_kv_transfer_group().handle_preemptions( scheduler_output.preempted_req_ids @@ -3825,6 +3930,32 @@ def propose_draft_token_ids(sampled_token_ids): self._copy_valid_sampled_token_count( next_token_ids, valid_sampled_tokens_count ) + self._draft_token_ids = torch.zeros( + 1, device=self.device, dtype=torch.int32 + ).expand(len(self.input_batch.req_ids), self.num_spec_tokens) + self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True) + elif ( + spec_config.use_ngram_gpu() + and not spec_config.disable_padded_drafter_batch + ): + assert isinstance(self.drafter, NgramProposerGPU) + sampled_token_ids = sampler_output.sampled_token_ids + if input_fits_in_drafter: + propose_draft_token_ids(sampled_token_ids) + elif self.valid_sampled_token_count_event is not None: + assert spec_decode_common_attn_metadata is not None + next_token_ids, valid_sampled_tokens_count, _ = ( + self.drafter.update_token_ids_ngram( + sampled_token_ids, + self.input_batch, + self.token_ids_gpu_tensor, + self.num_tokens_no_spec_gpu, + self.discard_request_mask.gpu, + ) + ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count + ) # Since we couldn't run the drafter, # just use zeros for the draft tokens. self._draft_token_ids = torch.zeros( @@ -4064,6 +4195,52 @@ def propose_draft_token_ids( self.input_batch.token_ids_cpu, slot_mappings=slot_mappings, ) + if isinstance(self.drafter, NgramProposer): + assert isinstance(sampled_token_ids, list), ( + "sampled_token_ids should be a python list when ngram is used." + ) + draft_token_ids = self.drafter.propose( + sampled_token_ids, + self.input_batch.num_tokens_no_spec, + self.input_batch.token_ids_cpu, + ) + elif spec_config.use_ngram_gpu(): + assert isinstance(self.drafter, NgramProposerGPU) + ( + next_token_ids, + valid_sampled_tokens_count, + valid_sampled_token_ids_gpu, + ) = self.drafter.update_token_ids_ngram( + sampled_token_ids, + self.input_batch, + self.token_ids_gpu_tensor, + self.num_tokens_no_spec_gpu, + self.discard_request_mask.gpu, + ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count + ) + + batch_size = next_token_ids.shape[0] + + draft_token_ids, num_valid_draft_tokens = self.drafter.propose( + self.num_tokens_no_spec_gpu[:batch_size], + self.token_ids_gpu_tensor[:batch_size], + valid_sampled_token_ids_gpu, + valid_sampled_tokens_count, + ) + + # Cache valid draft counts for scheduler-side trimming. + self._num_valid_draft_tokens = num_valid_draft_tokens + + # Async D2H copy on a dedicated stream. + copy_num_valid_draft_tokens( + self._num_valid_draft_tokens_cpu, + self._num_valid_draft_tokens_copy_stream, + self._num_valid_draft_tokens_event, + self._num_valid_draft_tokens, + self.input_batch.num_reqs, + ) elif spec_config.method == "suffix": assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, SuffixDecodingProposer) From 379689d533642cfc1d3ab2cf4dc02f09a8318a5f Mon Sep 17 00:00:00 2001 From: Wei Zhao <51183510+wzhao18@users.noreply.github.com> Date: Sat, 7 Mar 2026 16:51:54 -0500 Subject: [PATCH 10/19] [Perf] Support FP8 KV cache for Flashinfer MLA Sparse (#35891) --- docs/design/attention_backends.md | 2 +- tests/v1/attention/test_mla_backends.py | 20 +++++++++-- .../v1/attention/test_sparse_mla_backends.py | 12 ++++++- .../generate_attention_backend_docs.py | 16 ++++++++- .../layers/attention/mla_attention.py | 35 ++++++++++++++++--- vllm/model_executor/models/config.py | 7 ---- .../backends/mla/flashinfer_mla_sparse.py | 7 ++++ .../attention/backends/mla/flashmla_sparse.py | 7 ++++ 8 files changed, 89 insertions(+), 17 deletions(-) diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md index e7170babb6c9..a2079e70d7e8 100644 --- a/docs/design/attention_backends.md +++ b/docs/design/attention_backends.md @@ -206,7 +206,7 @@ configuration. |---------|--------|-----------|-------------|------------|------|--------|-----------|-----|-----------------|--------------| | `CUTLASS_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 128 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 10.x | | `FLASHINFER_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x | -| `FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | 32, 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x | +| `FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x | | `FLASHMLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 64 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x | | `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x | | `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x | diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index 32c0b9064275..86efefc3740f 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -327,6 +327,12 @@ def __init__( self._k_scale_float = 1.0 self._v_scale_float = 1.0 + self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8( + static=True, + group_shape=GroupShape.PER_TENSOR, + compile_native=True, + ) + def forward_impl( self, q: torch.Tensor, @@ -338,6 +344,7 @@ def forward_impl( ) -> torch.Tensor: """Forward for sparse MLA - uses forward_mqa for all tokens.""" kv_cache_dtype = getattr(self.impl, "kv_cache_dtype", "auto") + fp8_attention = kv_cache_dtype.startswith("fp8") # Write to KV cache if kv_cache.numel() > 0: @@ -350,6 +357,9 @@ def forward_impl( scale=self._k_scale, ) + if fp8_attention and kv_cache_dtype != "fp8_ds_mla": + kv_cache = kv_cache.view(current_platform.fp8_dtype()) + num_tokens = q.shape[0] # Sparse MLA uses forward_mqa for all tokens @@ -367,8 +377,14 @@ def forward_impl( # Convert from (N, B, L) to (B, N, L) mqa_ql_nope = mqa_ql_nope.transpose(0, 1) - # Pass as tuple to forward_mqa - mqa_q = (mqa_ql_nope, mqa_q_pe) + if fp8_attention and self.impl.supports_quant_query_input: + assert mqa_ql_nope.shape[0] == mqa_q_pe.shape[0] + assert mqa_ql_nope.shape[1] == mqa_q_pe.shape[1] + mqa_q = self._decode_concat_quant_fp8_op( + mqa_ql_nope, mqa_q_pe, self._q_scale + ) + else: + mqa_q = (mqa_ql_nope, mqa_q_pe) attn_out, _ = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self) diff --git a/tests/v1/attention/test_sparse_mla_backends.py b/tests/v1/attention/test_sparse_mla_backends.py index 86cefa036b40..0fd0ba6fab0d 100644 --- a/tests/v1/attention/test_sparse_mla_backends.py +++ b/tests/v1/attention/test_sparse_mla_backends.py @@ -191,6 +191,16 @@ def test_sparse_backend_decode_correctness( if kv_cache_dtype not in backend_cls.supported_kv_cache_dtypes: pytest.skip(f"{backend_cls.get_name()} does not support {kv_cache_dtype}") + if ( + backend_cls == FlashMLASparseBackend + and kv_cache_dtype.startswith("fp8") + and kv_cache_dtype != "fp8_ds_mla" + ): + pytest.skip( + "FlashMLA Sparse Attention backend fp8 only supports " + "fp8_ds_mla kv-cache dtype" + ) + supported_block_sizes = backend_cls.get_supported_kernel_block_sizes() if block_size not in supported_block_sizes: pytest.skip( @@ -419,7 +429,7 @@ def test_sparse_backend_decode_correctness( num_blocks=vllm_config.cache_config.num_gpu_blocks, common_attn_metadata=common_attn_metadata, randomize_blocks=False, - kv_cache_dtype=kv_cache_dtype if use_fp8_ds_mla_quantization else "auto", + kv_cache_dtype=kv_cache_dtype, scale=kv_cache_scale, ) diff --git a/tools/pre_commit/generate_attention_backend_docs.py b/tools/pre_commit/generate_attention_backend_docs.py index 628656f0df1a..3ec2248a82a4 100644 --- a/tools/pre_commit/generate_attention_backend_docs.py +++ b/tools/pre_commit/generate_attention_backend_docs.py @@ -49,6 +49,11 @@ # Backends to skip during doc generation SKIP_BACKENDS = {"CUSTOM", "TORCH_SDPA"} +BACKEND_KV_DTYPE_EXCLUDES: dict[str, set[str]] = { + # fp8 is an alias for fp8_ds_mla for FlashMLA Sparse + "FLASHMLA_SPARSE": {"fp8"}, +} + def is_relevant_file(filepath: str) -> bool: """Check if a file matches any of the relevant patterns.""" @@ -546,10 +551,19 @@ def analyze_backend(backend_name: str, class_path: str) -> dict[str, Any] | None tree, impl_class_name, "can_return_lse_for_decode", False, file_path ) + kv_cache_dtypes = parse_kv_cache_dtypes(class_node) + if backend_name in BACKEND_KV_DTYPE_EXCLUDES: + excluded = BACKEND_KV_DTYPE_EXCLUDES[backend_name] + kv_cache_dtypes = ", ".join( + d + for d in (d.strip() for d in kv_cache_dtypes.split(",")) + if d not in excluded + ) + return { "name": backend_name, "dtypes": parse_supported_dtypes(class_node), - "kv_cache_dtypes": parse_kv_cache_dtypes(class_node), + "kv_cache_dtypes": kv_cache_dtypes, "block_sizes": parse_block_sizes(class_node), "head_sizes": parse_head_sizes(class_node), "attn_types": parse_attention_types(class_node), diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index b0e16fa5240d..97ae3ef1b9d7 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -331,11 +331,6 @@ def __init__( calculate_kv_scales = False self.quant_config = quant_config - # Initialize KV cache quantization attributes - self.kv_cache_dtype = kv_cache_dtype - self.calculate_kv_scales = calculate_kv_scales - _init_kv_cache_quant(self, quant_config, prefix) - dtype = torch.get_default_dtype() self.attn_backend = get_attn_backend( self.head_size, @@ -347,6 +342,36 @@ def __init__( num_heads=self.num_heads, ) + # FlashMLA Sparse Attention fp8 backend uses "fp8_ds_mla" kv-cache format + # Automatically convert fp8 kv-cache format to "fp8_ds_mla" + if ( + self.attn_backend.get_name() == "FLASHMLA_SPARSE" + and kv_cache_dtype.startswith("fp8") + and kv_cache_dtype != "fp8_ds_mla" + ): + assert cache_config is not None + cache_config.cache_dtype = "fp8_ds_mla" + kv_cache_dtype = "fp8_ds_mla" + logger.info_once( + "Using DeepSeek's fp8_ds_mla KV cache format. To use standard " + "fp8 kv-cache format, please set `--attention-backend " + "FLASHINFER_MLA_SPARSE`" + ) + + if ( + self.attn_backend.get_name() == "FLASHINFER_MLA_SPARSE" + and kv_cache_dtype.startswith("fp8") + ): + logger.info_once( + "Using standard fp8 KV cache format. To use DeepSeek's fp8_ds_mla " + "KV cache format, please set `--attention-backend FLASHMLA_SPARSE`" + ) + + # Initialize KV cache quantization attributes + self.kv_cache_dtype = kv_cache_dtype + self.calculate_kv_scales = calculate_kv_scales + _init_kv_cache_quant(self, quant_config, prefix) + if ( cache_config is not None and cache_config.enable_prefix_caching diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 734e3ad2339f..0e35bedbc99f 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -31,20 +31,13 @@ def verify_and_update_model_config(model_config: "ModelConfig") -> None: class DeepseekV32ForCausalLM(VerifyAndUpdateConfig): @classmethod def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: - """ - Updated fp8 cache to custom "fp8_ds_mla" format for DeepSeekV32 - """ hf_config = vllm_config.model_config.hf_config # Mirror the check in vllm/model_executor/models/deepseek_v2.py is_v32 = hasattr(hf_config, "index_topk") assert is_v32 - # For DeepSeekV3.2, a custom fp8 format is used when fp8 kv-cache is enabled. cache_config = vllm_config.cache_config - if cache_config.cache_dtype.startswith("fp8"): - cache_config.cache_dtype = "fp8_ds_mla" - logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2") if cache_config.cache_dtype == "bfloat16": cache_config.cache_dtype = "auto" logger.info("Using bfloat16 kv-cache for DeepSeekV3.2") diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py b/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py index 21a0d99c20c5..34683d3f699a 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py @@ -63,6 +63,8 @@ class FlashInferMLASparseBackend(AttentionBackend): supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ "auto", "bfloat16", + "fp8", + "fp8_e4m3", ] @staticmethod @@ -304,6 +306,11 @@ def __init__( self.bmm1_scale: float | None = None self.bmm2_scale: float | None = None + # fp8 query quantization is required when using fp8 kv_cache, + # as the TRTLLM-GEN sparse MLA kernel requires matching dtypes + # for query and kv_cache (mixed bf16+fp8 is not supported). + self.supports_quant_query_input = True + def forward_mqa( self, q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index c8a78af4a97d..c0cdc204d2df 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -83,6 +83,7 @@ class FlashMLASparseBackend(AttentionBackend): "auto", "bfloat16", "fp8_ds_mla", + "fp8", # alias for fp8_ds_mla ] @staticmethod @@ -567,6 +568,12 @@ def __init__( ) self.fp8_decode_padded_heads = self._compute_fp8_decode_padded_heads(num_heads) + if kv_cache_dtype.startswith("fp8"): + assert kv_cache_dtype == "fp8_ds_mla", ( + "FlashMLA Sparse Attention backend fp8 only supports " + "fp8_ds_mla kv-cache dtype" + ) + if kv_cache_dtype == "fp8_ds_mla": # Reserve workspace during initialization vllm_config = get_current_vllm_config() From 2dde535df1b736315e56eace0fa1923fe0beffc5 Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Sat, 7 Mar 2026 16:52:11 -0500 Subject: [PATCH 11/19] [compile] Split compile/warmup monitoring (#36098) --- vllm/compilation/caching.py | 26 ++++++++++- vllm/compilation/decorators.py | 68 ++++++++++++++++------------ vllm/compilation/monitor.py | 81 +++++++++++++++++++++++++--------- 3 files changed, 125 insertions(+), 50 deletions(-) diff --git a/vllm/compilation/caching.py b/vllm/compilation/caching.py index 3eda948b693f..70fbaabb4aac 100644 --- a/vllm/compilation/caching.py +++ b/vllm/compilation/caching.py @@ -189,13 +189,13 @@ def __init__( self.shape_env = None self.vllm_backend = vllm_backend self.sym_tensor_indices = sym_tensor_indices + self._fake_mode: Any | None = None import torch._functorch.config as functorch_config self.aot_autograd_config = ( aot_autograd_config or functorch_config.save_config_portable() ) - sym_input = next( (i for i in self.example_inputs if isinstance(i, torch.SymInt)), None ) @@ -217,6 +217,7 @@ def serialize_compile_artifacts( state.pop("optimized_call") state.pop("shape_env") state.pop("vllm_backend", None) + state.pop("_fake_mode", None) for node in state["graph_module"].graph.nodes: node.meta.pop("source_fn_stack", None) node.meta.pop("nn_module_stack", None) @@ -351,8 +352,31 @@ def optimized_call(*example_inputs: Any) -> Any: return fn.optimized_call(*example_inputs) fn = cls(**state, optimized_call=optimized_call) + fn._fake_mode = fake_mode return fn + def finalize_loading(self, vllm_config: VllmConfig) -> None: + """Eagerly initialize the compiled backend and perform all loading. + + Must be called after _verify_source_unchanged has populated + compilation_config.traced_files, which is needed for cache dir + computation. + """ + if self._fake_mode is None: + return # Already finalized, or mega path (no _fake_mode set) + + from torch._guards import TracingContext, tracing + + from vllm.compilation.backends import VllmBackend + + vllm_backend = VllmBackend(vllm_config, self.prefix, self.is_encoder) + with tracing(TracingContext(self._fake_mode)): + result = vllm_backend(self.graph_module, list(self.example_inputs)) + self.optimized_call = result.optimized_call + self.vllm_backend = vllm_backend + + self._fake_mode = None + @property def co_name(self) -> Literal["VllmSerializableFunction"]: """ diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index fe0984baf97c..f8629be34b53 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -30,7 +30,7 @@ from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.torch_utils import is_torch_equal_or_newer -from .monitor import start_monitoring_torch_compile +from .monitor import monitor_profiling_run, monitor_torch_compile if TYPE_CHECKING: # Only added on nightly/2.10 so wrap @@ -434,17 +434,24 @@ def __call__(self: type[_T], *args: Any, **kwargs: Any) -> Any: cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}") aot_compilation_path = os.path.join(cache_dir, "model") try: - with ( - set_current_vllm_config(self.vllm_config), - open(aot_compilation_path, "rb") as f, - ): - start_monitoring_torch_compile(self.vllm_config) - loaded_fn = torch.compiler.load_compiled_function( - f, f_globals=self.forward.__globals__ - ) - _verify_source_unchanged(loaded_fn.source_info(), self.vllm_config) - if not self.compilation_config.dynamic_shapes_config.evaluate_guards: - loaded_fn.disable_guard_check() + with monitor_torch_compile(self.vllm_config): + with ( + set_current_vllm_config(self.vllm_config), + open(aot_compilation_path, "rb") as f, + ): + loaded_fn = torch.compiler.load_compiled_function( + f, f_globals=self.forward.__globals__ + ) + _verify_source_unchanged(loaded_fn.source_info(), self.vllm_config) + ds_config = self.compilation_config.dynamic_shapes_config + if not ds_config.evaluate_guards: + loaded_fn.disable_guard_check() + # Eagerly load compiled artifacts now that traced_files + # is populated by _verify_source_unchanged. + with maybe_use_cudagraph_partition_wrapper(self.vllm_config): + loaded_fn._artifacts.compiled_fn.finalize_loading( + self.vllm_config + ) self.aot_compiled_fn = loaded_fn self.was_aot_compile_fn_loaded_from_disk = True except Exception as e: @@ -465,12 +472,11 @@ def __call__(self: type[_T], *args: Any, **kwargs: Any) -> Any: logger.info( "Directly load AOT compilation from path %s", aot_compilation_path ) - # Apply partition wrapper context for proper CUDA graph capture - from .monitor import end_monitoring_torch_compile - - with maybe_use_cudagraph_partition_wrapper(self.vllm_config): + with ( + monitor_profiling_run(), + maybe_use_cudagraph_partition_wrapper(self.vllm_config), + ): output = self.aot_compiled_fn(self, *args, **kwargs) - end_monitoring_torch_compile(self.vllm_config) return output if self.compiled: @@ -489,8 +495,6 @@ def __call__(self: type[_T], *args: Any, **kwargs: Any) -> Any: **kwargs, ) - # here, it is the starting point of the `torch.compile` process - start_monitoring_torch_compile(self.vllm_config) original_code_object = self.original_code_object() logger.debug("Start compiling function %s", original_code_object) @@ -559,16 +563,26 @@ def patched_inline_call(self_: Any) -> Any: # store the path for saving after warmup self._aot_compilation_path = aot_compilation_path self._aot_cache_dir = cache_dir - self.aot_compiled_fn = self.aot_compile(*args, **kwargs) - # All compilation is done at this point, save the AOT artifact. - self.save_aot_compiled_function() - output = self.aot_compiled_fn(self, *args, **kwargs) - else: - output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # type: ignore[arg-type] + with monitor_torch_compile(self.vllm_config): + self.aot_compiled_fn = self.aot_compile(*args, **kwargs) + # All compilation is done at this point, save the + # AOT artifact. + self.save_aot_compiled_function() - from .monitor import end_monitoring_torch_compile + with monitor_profiling_run(): + output = self.aot_compiled_fn(self, *args, **kwargs) + else: + with monitor_torch_compile( + self.vllm_config, + "torch.compile and initial profiling/warmup " + "run together took %.2f s in total", + ): + output = TorchCompileWithNoGuardsWrapper.__call__( + self, # type: ignore[arg-type] + *args, + **kwargs, + ) - end_monitoring_torch_compile(self.vllm_config) self.compiled = True return output diff --git a/vllm/compilation/monitor.py b/vllm/compilation/monitor.py index fb9dfa3ac127..f584f526f08f 100644 --- a/vllm/compilation/monitor.py +++ b/vllm/compilation/monitor.py @@ -1,46 +1,83 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib import time +from collections.abc import Generator -from vllm.config import CompilationConfig, CompilationMode, VllmConfig +from vllm.config import CompilationMode, VllmConfig from vllm.logger import init_logger logger = init_logger(__name__) -context_manager = None +# Shared global so backends.py can read the start time for Dynamo timing. torch_compile_start_time: float = 0.0 -def start_monitoring_torch_compile(vllm_config: VllmConfig) -> None: +@contextlib.contextmanager +def monitor_torch_compile( + vllm_config: VllmConfig, + message: str = "torch.compile took %.2f s in total", +) -> Generator[None, None, None]: + """Context manager that times torch.compile and manages depyf debugging. + + On normal exit: logs the compile time and exits depyf. + On exception: cleans up depyf without logging (compilation failed). + """ global torch_compile_start_time torch_compile_start_time = time.perf_counter() - compilation_config: CompilationConfig = vllm_config.compilation_config + compilation_config = vllm_config.compilation_config + depyf_cm = None path = vllm_config.compile_debug_dump_path() if compilation_config.mode == CompilationMode.VLLM_COMPILE and path: import depyf path.mkdir(parents=True, exist_ok=True) logger.debug("Dumping depyf output to %s", path) - global context_manager - context_manager = depyf.prepare_debug(path.as_posix()) - context_manager.__enter__() - - -def end_monitoring_torch_compile(vllm_config: VllmConfig) -> None: - compilation_config: CompilationConfig = vllm_config.compilation_config - total_compile_time: float = time.perf_counter() - torch_compile_start_time - if compilation_config.mode == CompilationMode.VLLM_COMPILE: - logger.info_once( - "torch.compile and initial profiling run took %.2f s in total", - total_compile_time, - scope="local", - ) - global context_manager - if context_manager is not None: - context_manager.__exit__(None, None, None) - context_manager = None + depyf_cm = depyf.prepare_debug(path.as_posix()) + depyf_cm.__enter__() + + try: + yield + except Exception: + raise + else: + total_compile_time = time.perf_counter() - torch_compile_start_time + if compilation_config.mode == CompilationMode.VLLM_COMPILE: + logger.info_once(message, total_compile_time, scope="local") + finally: + if depyf_cm is not None: + try: + depyf_cm.__exit__(None, None, None) + except Exception: + logger.warning("Exception during depyf cleanup.", exc_info=True) + + +@contextlib.contextmanager +def monitor_profiling_run() -> Generator[None, None, None]: + """Context manager that times the initial profiling run. + + Asserts that no backend compilation occurs during the profiling run + (all compilation should have completed before this point). + """ + from vllm.compilation.counter import compilation_counter + + backend_compilations_before = compilation_counter.num_backend_compilations + start = time.perf_counter() + yield + elapsed = time.perf_counter() - start + assert ( + compilation_counter.num_backend_compilations == backend_compilations_before + ), ( + "backend compilation occurred during the initial profiling run; " + "all compilation should be complete before the profiling run starts." + ) + logger.info_once( + "Initial profiling/warmup run took %.2f s", + elapsed, + scope="local", + ) cudagraph_capturing_enabled: bool = True From 63298ee17350e4eda3f574eab16286bc405b23a6 Mon Sep 17 00:00:00 2001 From: Roy Huang Date: Sat, 7 Mar 2026 13:52:35 -0800 Subject: [PATCH 12/19] [Bugfix][LMCache][KVConnector] fix potential memory leak in LMCache multiprocess mode (#35931) --- .../kv_connector/v1/lmcache_mp_connector.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py index fc31836aa7e1..db1d34ca15c3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py @@ -721,6 +721,34 @@ def update_state_after_alloc( # Clean up lookup future in scheduler adapter self.scheduler_adapter.cleanup_lookup_result(request.request_id) + # Free locks on chunks that vLLM already computed and won't + # retrieve from LMCache. + if tracker.num_lmcache_hit_blocks > 0: + if not condition: + # No retrieve needed — free ALL locked chunks + free_end = tracker.num_lmcache_hit_blocks * self.vllm_block_size + else: + # Note(Roy): Boundary misalignment between vLLM blocks and LMCache + # blocks is handled in free_lookup_locks. It makes sure that if + # the last vLLM computed block ends in the middle of a LMCache + # block, the end LMCache block is not freed (i.e., floor division) + # since it will still be needed by vLLM and such block's lock will + # be freed by vLLM's retrieve. + free_end = tracker.num_vllm_hit_blocks * self.vllm_block_size + + if free_end > 0: + self.scheduler_adapter.free_lookup_locks( + token_ids=list(tracker.all_token_ids), + start=0, + end=free_end, + request_id=request.request_id, + ) + logger.debug( + "Free locks of tokens %d-%d since it is cached by vLLM.", + 0, + free_end, + ) + def build_connector_meta( self, scheduler_output: SchedulerOutput ) -> KVConnectorMetadata: From 5d6aae4577590cd6b6a604f9e74c17c5f234271d Mon Sep 17 00:00:00 2001 From: Samuel Shen Date: Sat, 7 Mar 2026 13:52:48 -0800 Subject: [PATCH 13/19] [LMCache MP Patch]: Race Condition + Duplicated Block Ids (#35831) --- .../kv_connector/v1/lmcache_mp_connector.py | 27 ++++++++++++++++--- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py index db1d34ca15c3..38dd980c62d6 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py @@ -336,11 +336,21 @@ def GetRetrieveMetadata( start_token_idx = start * vllm_block_size end_token_idx = end * vllm_block_size token_ids = list(tracker.all_token_ids) + + # Compute how many tokens at the start of the retrieve range + # overlap with APC-shared blocks. The server must skip writing + # to these positions to avoid a cross-stream data race: the + # retrieve writes on the LMCache CUDA stream while concurrent + # requests may read these APC-shared blocks on the vLLM stream. + apc_overlap_blocks = tracker.num_vllm_hit_blocks - start + skip_first_n_tokens = apc_overlap_blocks * vllm_block_size + op = LoadStoreOp( token_ids=token_ids, block_ids=block_ids, start=start_token_idx, end=end_token_idx, + skip_first_n_tokens=skip_first_n_tokens, ) ret = LMCacheMPRequestMetadata( @@ -700,13 +710,22 @@ def update_state_after_alloc( num_external_tokens (int): the number of tokens that will be loaded from the external KV cache. """ - # NOTE: the `blocks` are NEW BLOCKS allocated for this request. + # NOTE: `blocks` comes from kv_cache_manager.get_blocks(request_id), + # which returns ALL blocks for the request (not just newly allocated). + # This function may be called twice for async-load requests: + # 1st call: blocks = initial allocation (APC + fresh) + # 2nd call: blocks = all blocks + # (initial + newly allocated for remaining tokens) + # We must only append the NEW blocks beyond what's already tracked + # to avoid duplication, which would corrupt the store path's block indexing. tracker = self._get_request_tracker(request.request_id) block_ids = reformat_block_ids(blocks.get_block_ids()) - # No matter we need to retrieve or not, we need to update - # the block ids into the tracker - tracker.append_block_ids(block_ids) + # Only append blocks beyond what's already tracked + existing_count = len(tracker.allocated_block_ids) + new_block_ids = block_ids[existing_count:] + if new_block_ids: + tracker.append_block_ids(new_block_ids) # Update the state of the tracker condition = tracker.needs_retrieve() From 40077ea3defdf2b0997245ca8999097eede2308f Mon Sep 17 00:00:00 2001 From: Andreas Karatzas Date: Sun, 8 Mar 2026 00:42:24 -0600 Subject: [PATCH 14/19] [CI] fix flaky empty responses and add diagnostic assertions in vision chat tests (#36341) Signed-off-by: Andreas Karatzas --- .../openai/test_transcription_validation.py | 106 +++-- tests/entrypoints/openai/test_vision.py | 403 +++++++++++------- 2 files changed, 317 insertions(+), 192 deletions(-) diff --git a/tests/entrypoints/openai/test_transcription_validation.py b/tests/entrypoints/openai/test_transcription_validation.py index cbab74145433..58742f186851 100644 --- a/tests/entrypoints/openai/test_transcription_validation.py +++ b/tests/entrypoints/openai/test_transcription_validation.py @@ -6,7 +6,7 @@ import pytest -from ...utils import RemoteOpenAIServer +from ...utils import ROCM_ENV_OVERRIDES, ROCM_EXTRA_ARGS, RemoteOpenAIServer from .conftest import add_attention_backend MISTRAL_FORMAT_ARGS = [ @@ -19,12 +19,55 @@ ] +async def transcribe_and_check( + client, + model_name: str, + file, + *, + language: str, + expected_text: str, + expected_seconds: int | None = None, + case_sensitive: bool = False, +): + """Run a transcription request and assert the output contains + *expected_text* and optionally that usage reports *expected_seconds*. + + Provides detailed failure messages with the actual transcription output. + """ + transcription = await client.audio.transcriptions.create( + model=model_name, + file=file, + language=language, + response_format="text", + temperature=0.0, + ) + out = json.loads(transcription) + out_text = out["text"] + out_usage = out["usage"] + + if case_sensitive: + assert expected_text in out_text, ( + f"Expected {expected_text!r} in transcription output, got: {out_text!r}" + ) + else: + assert expected_text.lower() in out_text.lower(), ( + f"Expected {expected_text!r} (case-insensitive) in transcription " + f"output, got: {out_text!r}" + ) + + if expected_seconds is not None: + assert out_usage["seconds"] == expected_seconds, ( + f"Expected {expected_seconds}s of audio, " + f"got {out_usage['seconds']}s. Full usage: {out_usage!r}" + ) + + @pytest.mark.asyncio @pytest.mark.parametrize( "model_name", ["mistralai/Voxtral-Mini-3B-2507", "Qwen/Qwen3-ASR-0.6B"] ) async def test_basic_audio(mary_had_lamb, model_name, rocm_aiter_fa_attention): - server_args = ["--enforce-eager"] + server_args = ["--enforce-eager", *ROCM_EXTRA_ARGS] if model_name.startswith("mistralai"): server_args += MISTRAL_FORMAT_ARGS @@ -32,20 +75,18 @@ async def test_basic_audio(mary_had_lamb, model_name, rocm_aiter_fa_attention): add_attention_backend(server_args, rocm_aiter_fa_attention) # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb. - with RemoteOpenAIServer(model_name, server_args) as remote_server: + with RemoteOpenAIServer( + model_name, server_args, env_dict=ROCM_ENV_OVERRIDES + ) as remote_server: client = remote_server.get_async_client() - transcription = await client.audio.transcriptions.create( - model=model_name, - file=mary_had_lamb, + await transcribe_and_check( + client, + model_name, + mary_had_lamb, language="en", - response_format="text", - temperature=0.0, + expected_text="Mary had a little lamb", + expected_seconds=16, ) - out = json.loads(transcription) - out_text = out["text"] - out_usage = out["usage"] - assert "Mary had a little lamb" in out_text - assert out_usage["seconds"] == 16, out_usage["seconds"] @pytest.mark.asyncio @@ -74,20 +115,18 @@ async def test_basic_audio_with_lora(mary_had_lamb, rocm_aiter_fa_attention): add_attention_backend(server_args, rocm_aiter_fa_attention) # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb. - with RemoteOpenAIServer(model_name, server_args) as remote_server: + with RemoteOpenAIServer( + model_name, server_args, env_dict=ROCM_ENV_OVERRIDES + ) as remote_server: client = remote_server.get_async_client() - transcription = await client.audio.transcriptions.create( - model=lora_model_name, - file=mary_had_lamb, + await transcribe_and_check( + client, + lora_model_name, + mary_had_lamb, language="en", - response_format="text", - temperature=0.0, + expected_text="mary had a little lamb", + expected_seconds=16, ) - out = json.loads(transcription) - out_text = out["text"] - out_usage = out["usage"] - assert "mary had a little lamb" in out_text - assert out_usage["seconds"] == 16, out_usage["seconds"] @pytest.mark.asyncio @@ -97,20 +136,21 @@ async def test_basic_audio_with_lora(mary_had_lamb, rocm_aiter_fa_attention): async def test_basic_audio_foscolo(foscolo, rocm_aiter_fa_attention, model_name): # Gemma accuracy on some of the audio samples we use is particularly bad, # hence we use a different one here. WER is evaluated separately. - server_args = ["--enforce-eager"] + server_args = ["--enforce-eager", *ROCM_EXTRA_ARGS] add_attention_backend(server_args, rocm_aiter_fa_attention) with RemoteOpenAIServer( - model_name, server_args, max_wait_seconds=480 + model_name, + server_args, + max_wait_seconds=480, + env_dict=ROCM_ENV_OVERRIDES, ) as remote_server: client = remote_server.get_async_client() - transcription = await client.audio.transcriptions.create( - model=model_name, - file=foscolo, + await transcribe_and_check( + client, + model_name, + foscolo, language="it", - response_format="text", - temperature=0.0, + expected_text="ove il mio corpo fanciulletto giacque", ) - out = json.loads(transcription)["text"] - assert "ove il mio corpo fanciulletto giacque" in out diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index 6c5a08ae2f91..c0d8b0532830 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -12,7 +12,7 @@ from vllm.multimodal.utils import encode_image_url, fetch_image from vllm.platforms import current_platform -from ...utils import RemoteOpenAIServer +from ...utils import ROCM_ENV_OVERRIDES, ROCM_EXTRA_ARGS, RemoteOpenAIServer MODEL_NAME = "microsoft/Phi-3.5-vision-instruct" MAXIMUM_IMAGES = 2 @@ -48,10 +48,37 @@ def check_output_matches_terms(content: str, term_groups: list[list[str]]) -> bo All term groups must be satisfied. """ content_lower = content.lower() - for group in term_groups: - if not any(term.lower() in content_lower for term in group): - return False - return True + return all( + any(term.lower() in content_lower for term in group) for group in term_groups + ) + + +def assert_non_empty_content(chat_completion, *, context: str = "") -> str: + """Assert the first choice has non-empty string content; return it. + + Provides a detailed failure message including the full ChatCompletion + response so flaky / model-quality issues are easy to diagnose. + """ + prefix = f"[{context}] " if context else "" + choice = chat_completion.choices[0] + content = choice.message.content + + assert content is not None, ( + f"{prefix}Expected non-None content but got None. " + f"finish_reason={choice.finish_reason!r}, " + f"full message={choice.message!r}, " + f"usage={chat_completion.usage!r}" + ) + assert isinstance(content, str), ( + f"{prefix}Expected str content, got {type(content).__name__}: {content!r}" + ) + assert len(content) > 0, ( + f"{prefix}Expected non-empty content but got empty string. " + f"finish_reason={choice.finish_reason!r}, " + f"full message={choice.message!r}, " + f"usage={chat_completion.usage!r}" + ) + return content @pytest.fixture(scope="module") @@ -67,16 +94,22 @@ def server(): "--trust-remote-code", "--limit-mm-per-prompt", json.dumps({"image": MAXIMUM_IMAGES}), + *ROCM_EXTRA_ARGS, ] # ROCm: Increase timeouts to handle potential network delays and slower # video processing when downloading multiple videos from external sources - env_overrides = {} - if current_platform.is_rocm(): - env_overrides = { - "VLLM_VIDEO_FETCH_TIMEOUT": "120", - "VLLM_ENGINE_ITERATION_TIMEOUT_S": "300", - } + env_overrides = { + **ROCM_ENV_OVERRIDES, + **( + { + "VLLM_VIDEO_FETCH_TIMEOUT": "120", + "VLLM_ENGINE_ITERATION_TIMEOUT_S": "300", + } + if current_platform.is_rocm() + else {} + ), + } with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_overrides) as remote_server: yield remote_server @@ -117,6 +150,51 @@ def dummy_messages_from_image_url( ] +def describe_image_messages( + image_url: str, *, extra_image_fields: dict | None = None +) -> list[dict]: + """Build the system + user messages used by the completions-with-image + family of tests. *extra_image_fields* is merged into the top-level + image content block (for uuid / bad-key tests).""" + image_block: dict = { + "type": "image_url", + "image_url": {"url": image_url}, + } + if extra_image_fields: + image_block.update(extra_image_fields) + + return [ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image."}, + image_block, + ], + }, + ] + + +async def complete_and_check( + client: openai.AsyncOpenAI, + model_name: str, + messages: list[dict], + *, + context: str, + max_completion_tokens: int = 50, + temperature: float = 0.0, +) -> str: + """Run a chat completion and assert the output is non-empty. + Returns the content string.""" + chat_completion = await client.chat.completions.create( + model=model_name, + messages=messages, + max_completion_tokens=max_completion_tokens, + temperature=temperature, + ) + return assert_non_empty_content(chat_completion, context=context) + + def get_hf_prompt_tokens(model_name, content, image_url): processor = AutoProcessor.from_pretrained( model_name, trust_remote_code=True, num_crops=4 @@ -153,7 +231,6 @@ async def test_single_chat_session_image( messages = dummy_messages_from_image_url(image_url, content_text) max_completion_tokens = 10 - # test single completion chat_completion = await client.chat.completions.create( model=model_name, messages=messages, @@ -162,32 +239,46 @@ async def test_single_chat_session_image( temperature=0.0, top_logprobs=5, ) - assert len(chat_completion.choices) == 1 + assert len(chat_completion.choices) == 1, ( + f"Expected 1 choice, got {len(chat_completion.choices)}" + ) choice = chat_completion.choices[0] - assert choice.finish_reason == "length" + assert choice.finish_reason == "length", ( + f"Expected finish_reason='length' (capped at {max_completion_tokens} " + f"tokens), got {choice.finish_reason!r}. " + f"content={choice.message.content!r}" + ) + hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text, image_url) - assert chat_completion.usage == openai.types.CompletionUsage( + expected_usage = openai.types.CompletionUsage( completion_tokens=max_completion_tokens, prompt_tokens=hf_prompt_tokens, total_tokens=hf_prompt_tokens + max_completion_tokens, ) + assert chat_completion.usage == expected_usage, ( + f"Usage mismatch: got {chat_completion.usage!r}, expected {expected_usage!r}" + ) message = choice.message - message = chat_completion.choices[0].message - assert message.content is not None and len(message.content) >= 10 - assert message.role == "assistant" + assert message.content is not None and len(message.content) >= 10, ( + f"Expected content with >=10 chars, got {message.content!r}" + ) + assert message.role == "assistant", ( + f"Expected role='assistant', got {message.role!r}" + ) + messages.append({"role": "assistant", "content": message.content}) # test multi-turn dialogue messages.append({"role": "user", "content": "express your result in json"}) - chat_completion = await client.chat.completions.create( - model=model_name, - messages=messages, + await complete_and_check( + client, + model_name, + messages, + context=f"multi-turn follow-up for {image_url}", max_completion_tokens=10, ) - message = chat_completion.choices[0].message - assert message.content is not None and len(message.content) >= 0 @pytest.mark.asyncio @@ -209,7 +300,7 @@ async def test_error_on_invalid_image_url_type( # image_url should be a dict {"url": "some url"}, not directly a string with pytest.raises(openai.BadRequestError): - _ = await client.chat.completions.create( + await client.chat.completions.create( model=model_name, messages=messages, max_completion_tokens=10, @@ -235,10 +326,15 @@ async def test_single_chat_session_image_beamsearch( top_logprobs=5, extra_body=dict(use_beam_search=True), ) - assert len(chat_completion.choices) == 2 - assert ( - chat_completion.choices[0].message.content - != chat_completion.choices[1].message.content + assert len(chat_completion.choices) == 2, ( + f"Expected 2 beam search choices, got {len(chat_completion.choices)}" + ) + + content_0 = chat_completion.choices[0].message.content + content_1 = chat_completion.choices[1].message.content + assert content_0 != content_1, ( + f"Beam search should produce different outputs for {image_url}, " + f"but both returned: {content_0!r}" ) @@ -269,33 +365,46 @@ async def test_single_chat_session_image_base64encoded( temperature=0.0, top_logprobs=5, ) - assert len(chat_completion.choices) == 1 + assert len(chat_completion.choices) == 1, ( + f"Expected 1 choice, got {len(chat_completion.choices)}" + ) choice = chat_completion.choices[0] - assert choice.finish_reason == "length" + assert choice.finish_reason == "length", ( + f"Expected finish_reason='length', got {choice.finish_reason!r}. " + f"content={choice.message.content!r}" + ) + hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text, image_url) - assert chat_completion.usage == openai.types.CompletionUsage( + expected_usage = openai.types.CompletionUsage( completion_tokens=max_completion_tokens, prompt_tokens=hf_prompt_tokens, total_tokens=hf_prompt_tokens + max_completion_tokens, ) + assert chat_completion.usage == expected_usage, ( + f"Usage mismatch: got {chat_completion.usage!r}, expected {expected_usage!r}" + ) message = choice.message - message = chat_completion.choices[0].message - assert message.content is not None and len(message.content) >= 10 - assert message.role == "assistant" + assert message.content is not None and len(message.content) >= 10, ( + f"Expected content with >=10 chars, got {message.content!r}" + ) + assert message.role == "assistant", ( + f"Expected role='assistant', got {message.role!r}" + ) + messages.append({"role": "assistant", "content": message.content}) # test multi-turn dialogue messages.append({"role": "user", "content": "express your result in json"}) - chat_completion = await client.chat.completions.create( - model=model_name, - messages=messages, + await complete_and_check( + client, + model_name, + messages, + context=f"multi-turn base64 follow-up for {raw_image_url}", max_completion_tokens=10, temperature=0.0, ) - message = chat_completion.choices[0].message - assert message.content is not None and len(message.content) >= 0 @pytest.mark.asyncio @@ -321,7 +430,10 @@ async def test_single_chat_session_image_base64encoded_beamsearch( temperature=0.0, extra_body=dict(use_beam_search=True), ) - assert len(chat_completion.choices) == 2 + assert len(chat_completion.choices) == 2, ( + f"Expected 2 beam search choices for image {image_idx} " + f"({raw_image_url}), got {len(chat_completion.choices)}" + ) # Verify beam search produces two different non-empty outputs content_0 = chat_completion.choices[0].message.content @@ -333,18 +445,28 @@ async def test_single_chat_session_image_base64encoded_beamsearch( f"Output 0: {content_0!r}, Output 1: {content_1!r}" ) - assert content_0, "First beam search output should not be empty" - assert content_1, "Second beam search output should not be empty" - assert content_0 != content_1, "Beam search should produce different outputs" + assert content_0, ( + f"First beam output is empty for image {image_idx} ({raw_image_url}). " + f"finish_reason={chat_completion.choices[0].finish_reason!r}" + ) + assert content_1, ( + f"Second beam output is empty for image {image_idx} " + f"({raw_image_url}). " + f"finish_reason={chat_completion.choices[1].finish_reason!r}" + ) + assert content_0 != content_1, ( + f"Beam search produced identical outputs for image {image_idx} " + f"({raw_image_url}): {content_0!r}" + ) # Verify each output contains the required terms for this image for i, content in enumerate([content_0, content_1]): - if not check_output_matches_terms(content, required_terms): - pytest.fail( - f"Output {i} '{content}' doesn't contain required terms. " - f"Expected all of these term groups (at least one from each): " - f"{required_terms}" - ) + assert check_output_matches_terms(content, required_terms), ( + f"Beam output {i} for image {image_idx} ({raw_image_url}) " + f"doesn't match required terms.\n" + f" content: {content!r}\n" + f" required (all groups, >=1 per group): {required_terms}" + ) @pytest.mark.asyncio @@ -378,16 +500,29 @@ async def test_chat_streaming_image( async for chunk in stream: delta = chunk.choices[0].delta if delta.role: - assert delta.role == "assistant" + assert delta.role == "assistant", ( + f"Expected role='assistant' in stream delta, got {delta.role!r}" + ) if delta.content: chunks.append(delta.content) if chunk.choices[0].finish_reason is not None: finish_reason_count += 1 # finish reason should only return in last block - assert finish_reason_count == 1 - assert chunk.choices[0].finish_reason == stop_reason - assert delta.content - assert "".join(chunks) == output + assert finish_reason_count == 1, ( + f"Expected exactly 1 finish_reason across stream chunks, " + f"got {finish_reason_count}" + ) + assert chunk.choices[0].finish_reason == stop_reason, ( + f"Stream finish_reason={chunk.choices[0].finish_reason!r} " + f"doesn't match non-stream finish_reason={stop_reason!r}" + ) + + streamed_text = "".join(chunks) + assert streamed_text == output, ( + f"Streamed output doesn't match non-streamed for {image_url}.\n" + f" streamed: {streamed_text!r}\n" + f" non-streamed: {output!r}" + ) @pytest.mark.asyncio @@ -418,17 +553,19 @@ async def test_multi_image_input( max_tokens=5, temperature=0.0, ) - completion = completion.choices[0].text - assert completion is not None and len(completion) >= 0 + assert completion.choices[0].text is not None, ( + "Server failed to produce output after rejecting over-limit " + "multi-image request" + ) else: - chat_completion = await client.chat.completions.create( - model=model_name, - messages=messages, + await complete_and_check( + client, + model_name, + messages, + context=f"multi-image input ({len(image_urls)} images)", max_completion_tokens=10, temperature=0.0, ) - message = chat_completion.choices[0].message - assert message.content is not None and len(message.content) >= 0 @pytest.mark.asyncio @@ -444,30 +581,13 @@ async def test_completions_with_image( image_urls: list[str], ): for image_url in image_urls: - chat_completion = await client.chat.completions.create( - messages=[ - {"role": "system", "content": "You are a helpful assistant."}, - { - "role": "user", - "content": [ - { - "type": "text", - "text": "Describe this image.", - }, - { - "type": "image_url", - "image_url": { - "url": image_url, - }, - }, - ], - }, - ], - model=model_name, + messages = describe_image_messages(image_url) + await complete_and_check( + client, + model_name, + messages, + context=f"completions_with_image url={image_url}", ) - assert chat_completion.choices[0].message.content is not None - assert isinstance(chat_completion.choices[0].message.content, str) - assert len(chat_completion.choices[0].message.content) > 0 @pytest.mark.asyncio @@ -483,54 +603,33 @@ async def test_completions_with_image_with_uuid( image_urls: list[str], ): for image_url in image_urls: - chat_completion = await client.chat.completions.create( - messages=[ - {"role": "system", "content": "You are a helpful assistant."}, - { - "role": "user", - "content": [ - { - "type": "text", - "text": "Describe this image.", - }, - { - "type": "image_url", - "image_url": { - "url": image_url, - }, - "uuid": image_url, - }, - ], - }, - ], - model=model_name, + messages = describe_image_messages( + image_url, + extra_image_fields={"uuid": image_url}, ) - assert chat_completion.choices[0].message.content is not None - assert isinstance(chat_completion.choices[0].message.content, str) - assert len(chat_completion.choices[0].message.content) > 0 - - # Second request, with empty image but the same uuid. - chat_completion_with_empty_image = await client.chat.completions.create( - messages=[ - {"role": "system", "content": "You are a helpful assistant."}, - { - "role": "user", - "content": [ - { - "type": "text", - "text": "Describe this image.", - }, - {"type": "image_url", "image_url": {}, "uuid": image_url}, - ], - }, - ], - model=model_name, + await complete_and_check( + client, + model_name, + messages, + context=f"uuid first request url={image_url}", ) - assert chat_completion_with_empty_image.choices[0].message.content is not None - assert isinstance( - chat_completion_with_empty_image.choices[0].message.content, str + + cached_messages: list[dict] = [ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image."}, + {"type": "image_url", "image_url": {}, "uuid": image_url}, + ], + }, + ] + await complete_and_check( + client, + model_name, + cached_messages, + context=f"uuid cached (empty image) uuid={image_url}", ) - assert len(chat_completion_with_empty_image.choices[0].message.content) > 0 @pytest.mark.asyncio @@ -540,16 +639,13 @@ async def test_completions_with_empty_image_with_uuid_without_cache_hit( model_name: str, ): with pytest.raises(openai.BadRequestError): - _ = await client.chat.completions.create( + await client.chat.completions.create( messages=[ {"role": "system", "content": "You are a helpful assistant."}, { "role": "user", "content": [ - { - "type": "text", - "text": "Describe this image.", - }, + {"type": "text", "text": "Describe this image."}, { "type": "image_url", "image_url": {}, @@ -575,29 +671,18 @@ async def test_completions_with_image_with_incorrect_uuid_format( image_urls: list[str], ): for image_url in image_urls: - chat_completion = await client.chat.completions.create( - messages=[ - {"role": "system", "content": "You are a helpful assistant."}, - { - "role": "user", - "content": [ - { - "type": "text", - "text": "Describe this image.", - }, - { - "type": "image_url", - "image_url": { - "url": image_url, - "incorrect_uuid_key": image_url, - }, - "also_incorrect_uuid_key": image_url, - }, - ], - }, - ], - model=model_name, + messages = describe_image_messages( + image_url, + extra_image_fields={ + "also_incorrect_uuid_key": image_url, + }, + ) + # Inject the bad key inside image_url dict too + messages[1]["content"][1]["image_url"]["incorrect_uuid_key"] = image_url + + await complete_and_check( + client, + model_name, + messages, + context=f"incorrect uuid format url={image_url}", ) - assert chat_completion.choices[0].message.content is not None - assert isinstance(chat_completion.choices[0].message.content, str) - assert len(chat_completion.choices[0].message.content) > 0 From b7332b058c3b0d8533395b49dea9273aa0973b4e Mon Sep 17 00:00:00 2001 From: nvnbagrov Date: Sun, 8 Mar 2026 12:04:05 +0200 Subject: [PATCH 15/19] [Model] Nano Nemotron VL - fast media preprocessing (#35657) Signed-off-by: Natan Bagrov --- .../model_executor/models/nano_nemotron_vl.py | 141 ++++++++++-------- 1 file changed, 80 insertions(+), 61 deletions(-) diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 9b9beadc099e..b32067557622 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -17,11 +17,11 @@ from typing import Annotated, Any, Literal, TypeAlias, TypeVar import einops +import numpy as np import numpy.typing as npt import regex as re import torch import torch.nn as nn -import torchvision.transforms as T from PIL import Image from transformers import BatchFeature, PretrainedConfig, TensorType @@ -214,7 +214,12 @@ class NanoNemotronVLVideoEmbeddingInputs(TensorSchema): def dynamic_preprocess( - image, *, image_size=512, max_num_tiles=12, use_thumbnail=True, idx=0 + image, + *, + image_size=512, + max_num_tiles=12, + use_thumbnail=True, + idx=0, ): orig_width, orig_height = image.size @@ -227,35 +232,44 @@ def dynamic_preprocess( image_size=image_size, use_thumbnail=False, ) - # resize the image - resized_img = image.resize((target_width, target_height)) - processed_images = [] - for i in range(blocks): - box = ( - (i % (target_width // image_size)) * image_size, - (i // (target_width // image_size)) * image_size, - ((i % (target_width // image_size)) + 1) * image_size, - ((i // (target_width // image_size)) + 1) * image_size, - ) - # split the image - split_img = resized_img.crop(box) - processed_images.append(split_img) - assert len(processed_images) == blocks - if use_thumbnail and len(processed_images) != 1: - thumbnail_img = image.resize((image_size, image_size)) - processed_images.append(thumbnail_img) - - processed_images = [ - img.convert("RGB") if img.mode != "RGB" else img for img in processed_images - ] - processed_images = [ - T.Resize((image_size, image_size), interpolation=T.InterpolationMode.BICUBIC)( - img + + image = np.asarray( + image.convert("RGB") if image.mode != "RGB" else image, dtype=np.uint8 + ) + + image = torch.from_numpy(image).unsqueeze(0) # (1, H, W, 3) + image = image.permute(0, 3, 1, 2) # (1, 3, H, W) + + resized_img = torch.nn.functional.interpolate( + image, + size=(target_height, target_width), + mode="bicubic", + align_corners=False, + antialias=True, + ) + B, C, H, W = resized_img.shape + hp, wp = H // image_size, W // image_size + patches = ( + resized_img.reshape(B, C, hp, image_size, wp, image_size) + .permute(0, 2, 4, 1, 3, 5) + .reshape(B * hp * wp, C, image_size, image_size) + / 255.0 + ) + + if use_thumbnail and patches.shape[0] > 1: + thumb = ( + torch.nn.functional.interpolate( + image, + size=(image_size, image_size), + mode="bicubic", + align_corners=False, + antialias=True, + ) + / 255.0 ) - for img in processed_images - ] - processed_images = [T.ToTensor()(img) for img in processed_images] - return processed_images + patches = torch.cat([patches, thumb], dim=0) + + return list(patches) def image_to_pixel_values( @@ -287,22 +301,21 @@ def video_to_pixel_values( ) -> torch.Tensor: assert max_num_tiles == 1, "Video modality always uses one tile" - # Convert each frame to a single resized tile tensor consistent - # with image path - frames_tensors: list[torch.Tensor] = [] - for frame in video: - pil_frame = dynamic_preprocess( - Image.fromarray(frame, mode="RGB"), - image_size=input_size, - max_num_tiles=max_num_tiles, - use_thumbnail=use_thumbnail, - idx=0, + # (num_frames, H, W, C) -> (num_frames, C, H, W) + video_tensor = torch.from_numpy(video).permute(0, 3, 1, 2) + + if video_tensor.shape[2] != input_size or video_tensor.shape[3] != input_size: + video_tensor = torch.nn.functional.interpolate( + video_tensor, + size=(input_size, input_size), + mode="bicubic", + align_corners=False, + antialias=True, ) - # dynamic_preprocess returns tensors already; take the single tile - assert len(pil_frame) >= 1 - frames_tensors.append(pil_frame[-1]) - return torch.stack(frames_tensors) + video_tensor = video_tensor / 255.0 + + return video_tensor def input_conditioner(x, norm_mean, norm_std): @@ -346,12 +359,6 @@ def __init__( self._factor_max = factor_max self.norm_mean = torch.tensor(norm_mean).reshape(3, 1, 1) self.norm_std = torch.tensor(norm_std).reshape(3, 1, 1) - self._transform = T.Compose( - [ - T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), - T.ToTensor(), - ] - ) assert downsample_ratio < 1 reduction_factor = 1 / downsample_ratio assert reduction_factor == 2.0 @@ -441,15 +448,25 @@ class DynamicResolutionParams: patch_size: tuple[int, int] def apply_params(self, params: DynamicResolutionParams) -> list[torch.Tensor]: - resized_img = params.media.resize( - ( - params.patch_size[0] * self._patch_size, - params.patch_size[1] * self._patch_size, + target_size = ( + params.patch_size[1] * self._patch_size, + params.patch_size[0] * self._patch_size, + ) + image = np.asarray( + params.media.convert("RGB") if params.media.mode != "RGB" else params.media, + dtype=np.uint8, + ) + resized_img = ( + torch.nn.functional.interpolate( + torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2), + size=target_size, + mode="bicubic", + align_corners=False, + antialias=True, ) + / 255.0 ) - processed_images = [resized_img] - - return [self._transform(img) for img in processed_images] + return list(resized_img) def process_media( self, @@ -803,6 +820,7 @@ def _preprocess_image( image_repl = self.get_image_repl(feature_size, num_patches) parts[i] = parts[i].replace("", image_repl.full) text = ["".join(parts)] + return text, image_inputs def _make_batch_input(self, input_item: Any | list[Any] | None = None): @@ -922,14 +940,14 @@ def _preprocess_video( frames_indices_lst = [ metadata["frames_indices"] for metadata in video_metadata_lst ] - + video_num_patches = torch.tensor( + [len(item) for item in pixel_values_lst_video] + ) video_inputs = { "pixel_values_flat_video": input_conditioner( torch.cat(pixel_values_lst_video), self.norm_mean, self.norm_std ), - "video_num_patches": torch.tensor( - [len(item) for item in pixel_values_lst_video] - ), + "video_num_patches": video_num_patches, "frames_indices": frames_indices_lst, "frame_duration_ms": torch.tensor(frame_duration_ms_lst), } @@ -985,6 +1003,7 @@ def _preprocess_video( video_repl.full, skip_special_tokens=False ) text = [t.replace("