Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions tests/v1/streaming_input/test_async_llm_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@
import pytest

from vllm.engine.protocol import StreamingInput
from vllm.inputs import TokensPrompt
from vllm.outputs import RequestOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.output_processor import RequestOutputCollector

pytestmark = pytest.mark.skip_global_cleanup


@pytest.fixture
def mock_async_llm():
Expand Down Expand Up @@ -106,6 +110,106 @@ def make_output(request_id: str, finished: bool) -> RequestOutput:
)


def make_engine_core_request(
request_id: str,
sampling_params: SamplingParams,
) -> EngineCoreRequest:
return EngineCoreRequest(
request_id=request_id,
prompt_token_ids=[1],
mm_features=None,
sampling_params=sampling_params,
pooling_params=None,
arrival_time=0.0,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
)


def make_streaming_llm() -> AsyncLLM:
llm = MagicMock(spec=AsyncLLM)
llm.get_supported_tasks = AsyncMock(return_value=("generate",))
llm.input_processor = MagicMock()

def process_inputs(*_args, **kwargs):
return make_engine_core_request(
kwargs["request_id"],
kwargs["params"],
)

llm.input_processor.process_inputs.side_effect = process_inputs
llm.input_processor.assign_request_id.side_effect = lambda request: setattr(
request, "request_id", f"{request.request_id}-internal"
)
llm.model_config = MagicMock()
llm.model_config.is_encoder_decoder = False
llm._add_request = AsyncMock()
llm._run_output_handler = MagicMock()
llm.log_requests = False
llm._validate_streaming_input_sampling_params = (
AsyncLLM._validate_streaming_input_sampling_params
)
llm._add_streaming_input_request = AsyncLLM._add_streaming_input_request.__get__(
llm,
AsyncLLM,
)
return llm


@pytest.mark.asyncio
async def test_streaming_input_reused_sampling_params_skip_validation():
sampling_params = SamplingParams(max_tokens=10)
llm = make_streaming_llm()

async def input_generator() -> AsyncGenerator[StreamingInput, None]:
yield StreamingInput(prompt=TokensPrompt(prompt_token_ids=[1]))
yield StreamingInput(prompt=TokensPrompt(prompt_token_ids=[2]))

queue = await llm._add_streaming_input_request(
"test",
input_generator(),
sampling_params,
)
task = queue._input_stream_task
assert task is not None
await task

validate_params = [
call.kwargs.get("validate_params", True)
for call in llm.input_processor.process_inputs.call_args_list
]
assert validate_params == [True, False, False]


@pytest.mark.asyncio
async def test_streaming_input_per_chunk_sampling_params_validate():
sampling_params = SamplingParams(max_tokens=10)
chunk_params = SamplingParams(max_tokens=5)
llm = make_streaming_llm()

async def input_generator() -> AsyncGenerator[StreamingInput, None]:
yield StreamingInput(
prompt=TokensPrompt(prompt_token_ids=[1]),
sampling_params=chunk_params,
)

queue = await llm._add_streaming_input_request(
"test",
input_generator(),
sampling_params,
)
task = queue._input_stream_task
assert task is not None
await task

validate_params = [
call.kwargs.get("validate_params", True)
for call in llm.input_processor.process_inputs.call_args_list
]
assert validate_params == [True, True]


@pytest.mark.asyncio
async def test_generate_with_async_generator():
"""Test generate with an async input generator.
Expand Down
6 changes: 4 additions & 2 deletions vllm/v1/engine/async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,16 +460,18 @@ async def handle_inputs():
try:
async for input_chunk in input_stream:
sp = input_chunk.sampling_params
if sp:
if sp is not None:
self._validate_streaming_input_sampling_params(sp)
validate_params = True
else:
sp = sampling_params
# TODO(nick): Avoid re-validating reused sampling parameters
validate_params = False
req = self.input_processor.process_inputs(
request_id=internal_req_id,
prompt=input_chunk.prompt,
params=sp,
resumable=True,
validate_params=validate_params,
**inputs, # type: ignore[arg-type]
)
req.external_req_id = request_id
Expand Down
4 changes: 3 additions & 1 deletion vllm/v1/engine/input_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,10 @@ def process_inputs(
priority: int = 0,
data_parallel_rank: int | None = None,
resumable: bool = False,
validate_params: bool = True,
) -> EngineCoreRequest:
self._validate_params(params, supported_tasks)
if validate_params:
self._validate_params(params, supported_tasks)
self._validate_lora(lora_request)

parallel_config = self.vllm_config.parallel_config
Expand Down
Loading