From 0e535001fa3bfd698875a450774849c6afc2e45a Mon Sep 17 00:00:00 2001
From: shervin
Date: Wed, 13 May 2026 16:00:30 +0800
Subject: [PATCH] perf: skip streaming sampling params revalidation

Streaming-input chunks that do not carry their own sampling params fall
back to the request-level params. Pass validate_params=False to
process_inputs() for those chunks so _validate_params() is not re-run
on them; chunks that supply their own sampling params are still
validated. validate_params defaults to True, so callers that do not
pass it keep the existing behaviour.

Also switch the per-chunk check from truthiness to `is not None`, and
add tests covering both the reused-params and per-chunk-params paths.

Assisted-by: OpenAI Codex
Signed-off-by: shervin
---
 .../test_async_llm_streaming.py               | 104 ++++++++++++++++++
 vllm/v1/engine/async_llm.py                   |   6 +-
 vllm/v1/engine/input_processor.py             |   4 +-
 3 files changed, 111 insertions(+), 3 deletions(-)

diff --git a/tests/v1/streaming_input/test_async_llm_streaming.py b/tests/v1/streaming_input/test_async_llm_streaming.py
index b532eed15f38..06a4a4a7543d 100644
--- a/tests/v1/streaming_input/test_async_llm_streaming.py
+++ b/tests/v1/streaming_input/test_async_llm_streaming.py
@@ -8,11 +8,15 @@
 import pytest
 
 from vllm.engine.protocol import StreamingInput
+from vllm.inputs import TokensPrompt
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import RequestOutputKind, SamplingParams
+from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.async_llm import AsyncLLM
 from vllm.v1.engine.output_processor import RequestOutputCollector
 
+pytestmark = pytest.mark.skip_global_cleanup
+
 
 @pytest.fixture
 def mock_async_llm():
@@ -106,6 +110,106 @@ def make_output(request_id: str, finished: bool) -> RequestOutput:
     )
 
 
+def make_engine_core_request(
+    request_id: str,
+    sampling_params: SamplingParams,
+) -> EngineCoreRequest:
+    return EngineCoreRequest(
+        request_id=request_id,
+        prompt_token_ids=[1],
+        mm_features=None,
+        sampling_params=sampling_params,
+        pooling_params=None,
+        arrival_time=0.0,
+        lora_request=None,
+        cache_salt=None,
+        data_parallel_rank=None,
+    )
+
+
+def make_streaming_llm() -> AsyncLLM:
+    llm = MagicMock(spec=AsyncLLM)
+    llm.get_supported_tasks = AsyncMock(return_value=("generate",))
+    llm.input_processor = MagicMock()
+
+    def process_inputs(*_args, **kwargs):
+        return make_engine_core_request(
+            kwargs["request_id"],
+            kwargs["params"],
+        )
+
+    llm.input_processor.process_inputs.side_effect = process_inputs
+    llm.input_processor.assign_request_id.side_effect = lambda request: setattr(
+        request, "request_id", f"{request.request_id}-internal"
+    )
+    llm.model_config = MagicMock()
+    llm.model_config.is_encoder_decoder = False
+    llm._add_request = AsyncMock()
+    llm._run_output_handler = MagicMock()
+    llm.log_requests = False
+    llm._validate_streaming_input_sampling_params = (
+        AsyncLLM._validate_streaming_input_sampling_params
+    )
+    llm._add_streaming_input_request = AsyncLLM._add_streaming_input_request.__get__(
+        llm,
+        AsyncLLM,
+    )
+    return llm
+
+
+@pytest.mark.asyncio
+async def test_streaming_input_reused_sampling_params_skip_validation():
+    sampling_params = SamplingParams(max_tokens=10)
+    llm = make_streaming_llm()
+
+    async def input_generator() -> AsyncGenerator[StreamingInput, None]:
+        yield StreamingInput(prompt=TokensPrompt(prompt_token_ids=[1]))
+        yield StreamingInput(prompt=TokensPrompt(prompt_token_ids=[2]))
+
+    queue = await llm._add_streaming_input_request(
+        "test",
+        input_generator(),
+        sampling_params,
+    )
+    task = queue._input_stream_task
+    assert task is not None
+    await task
+
+    validate_params = [
+        call.kwargs.get("validate_params", True)
+        for call in llm.input_processor.process_inputs.call_args_list
+    ]
+    assert validate_params == [True, False, False]
+
+
+@pytest.mark.asyncio
+async def test_streaming_input_per_chunk_sampling_params_validate():
+    sampling_params = SamplingParams(max_tokens=10)
+    chunk_params = SamplingParams(max_tokens=5)
+    llm = make_streaming_llm()
+
+    async def input_generator() -> AsyncGenerator[StreamingInput, None]:
+        yield StreamingInput(
+            prompt=TokensPrompt(prompt_token_ids=[1]),
+            sampling_params=chunk_params,
+        )
+
+    queue = await llm._add_streaming_input_request(
+        "test",
+        input_generator(),
+        sampling_params,
+    )
+    task = queue._input_stream_task
+    assert task is not None
+    await task
+
+    validate_params = [
+        call.kwargs.get("validate_params", True)
+        for call in llm.input_processor.process_inputs.call_args_list
+    ]
+    assert validate_params == [True, True]
+
+
 @pytest.mark.asyncio
 async def test_generate_with_async_generator():
     """Test generate with an async input generator.
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index 32cc3c0d2d05..ce6df76f9e4f 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -460,16 +460,18 @@ async def handle_inputs():
             try:
                 async for input_chunk in input_stream:
                     sp = input_chunk.sampling_params
-                    if sp:
+                    if sp is not None:
                         self._validate_streaming_input_sampling_params(sp)
+                        validate_params = True
                     else:
                         sp = sampling_params
-                    # TODO(nick): Avoid re-validating reused sampling parameters
+                        validate_params = False
                     req = self.input_processor.process_inputs(
                         request_id=internal_req_id,
                         prompt=input_chunk.prompt,
                         params=sp,
                         resumable=True,
+                        validate_params=validate_params,
                         **inputs,  # type: ignore[arg-type]
                     )
                     req.external_req_id = request_id
diff --git a/vllm/v1/engine/input_processor.py b/vllm/v1/engine/input_processor.py
index c579c92baf37..2aa793a63bff 100644
--- a/vllm/v1/engine/input_processor.py
+++ b/vllm/v1/engine/input_processor.py
@@ -244,8 +244,10 @@ def process_inputs(
         priority: int = 0,
         data_parallel_rank: int | None = None,
         resumable: bool = False,
+        validate_params: bool = True,
     ) -> EngineCoreRequest:
-        self._validate_params(params, supported_tasks)
+        if validate_params:
+            self._validate_params(params, supported_tasks)
         self._validate_lora(lora_request)
 
         parallel_config = self.vllm_config.parallel_config