diff --git a/requirements/common.txt b/requirements/common.txt index 4fe34456b8..cdbd26fb4f 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -52,4 +52,5 @@ anthropic >= 0.71.0 model-hosting-container-standards >= 0.1.13, < 1.0.0 mcp grpcio -grpcio-reflection \ No newline at end of file +grpcio-reflection +spnl >= 0.21.0 diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 6de41a9e73..9cb062688b 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -263,6 +263,127 @@ def load_log_config(log_config_file: str | None) -> dict | None: ) return None +if envs.VLLM_V1_SPANS_ENABLED: + import spnl + import time + from fastapi import Body + from vllm import SamplingParams + from vllm.inputs import TokensPrompt + from vllm.outputs import RequestOutput + from vllm.entrypoints.openai.protocol import (ChatMessage,ChatCompletionStreamResponse,ChatCompletionResponseStreamChoice,ChatCompletionResponseChoice,DeltaMessage,UsageInfo) + spnl_state = spnl.init(10) + PAD_TOKEN = 27 + PLUS_TOKEN = envs.VLLM_V1_SPANS_TOKEN_PLUS if envs.VLLM_V1_SPANS_TOKEN_PLUS >= 0 else None + CROSS_TOKEN = envs.VLLM_V1_SPANS_TOKEN_CROSS if envs.VLLM_V1_SPANS_TOKEN_CROSS >= 0 else None + def wrap(prompt: str | list[str]) -> TokensPrompt: + if isinstance(prompt[0], list): + return [TokensPrompt(prompt_token_ids=p) for p in prompt] + return TokensPrompt(prompt_token_ids=prompt) + @router.post("/v1/query/execute") + @with_cancellation + @load_aware_call + async def execute_query(raw_request: Request, + query: str = Body(..., media_type="text/plain"), + stream: bool = False): + req = spnl.tokenize_query( + spnl_state, + query, + PAD_TOKEN, + CROSS_TOKEN, + PLUS_TOKEN, + raw_request.app.state.vllm_config.cache_config.block_size + ) + + match req: + case spnl.TokenizedQuery.TokenizedChatCompletionQuery(q): + req = q # intentional fall-through + case spnl.TokenizedQuery.CompletionRequest(q): + request = CompletionRequest(model=q.model, max_tokens=q.max_tokens, temperature=q.temperature, prompt=q.inputs, stream=stream) + # what we want to do, but this is a fastapi endpoint... return create_completion(request, raw_request) + handler = completion(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Completions API") + + try: + generator = await handler.create_completion(request, raw_request) + except OverflowError as e: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, + detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.error.code) + elif isinstance(generator, CompletionResponse): + return JSONResponse(content=generator.model_dump()) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + request_id = raw_request.headers.get( + "X-Request-Id") or uuid.uuid4().hex + client = engine_client(raw_request) + generator = client.generate(wrap(req.messages), SamplingParams(n=1 if req.n <= 0 else req.n,temperature=req.temperature if req.temperature is not None else 0,max_tokens=req.max_tokens if req.max_tokens is not None and req.max_tokens != 0 else 2048), request_id) + + if stream: + async def sgen(): + output_idx: List[int] = [0 for _ in range(req.n)] + async for res in generator: + yield ChatCompletionStreamResponse( + id=request_id, + object="chat.completion.chunk", + created=int(time.time()), + model=req.model, + choices=[ + ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage(role="assistant", content=output.text[output_idx[index]:]), + logprobs=output.logprobs, + finish_reason=output.finish_reason, + stop_reason=output.stop_reason, + ) + for index, output in enumerate(res.outputs) + ] + ).model_dump_json(exclude_unset=True) + for index, output in enumerate(res.outputs): + output_idx[index] = len(output.text) + return StreamingResponse(content=sgen(), media_type="text/event-stream") + + outputs: List[Optional[CompletionOutput]] = [None for _ in range(req.n)] + async for res in generator: + for output in res.outputs: + outputs[output.index] = output + choices = [ + ChatCompletionResponseChoice( + index=index, + message=ChatMessage(role="assistant", content=output.text), + logprobs=output.logprobs, + finish_reason=output.finish_reason, + stop_reason=output.stop_reason, + ) + for index, output in enumerate(outputs) + ] + num_prompt_tokens=0 # TODO + num_generated_tokens=0 # TODO + usage = UsageInfo(prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + + num_generated_tokens) + response = ChatCompletionResponse( + id=request_id, + created=int(time.time()), + model=req.model, + choices=choices, + usage=usage + ) + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.error.code) + return JSONResponse(content=response.model_dump()) + class AuthenticationMiddleware: """ diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 2caed04937..1ec8f526e6 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -116,6 +116,7 @@ def __init__( # potential configs we could expose in the future. self.prefix_cache_stats = PrefixCacheStats() if log_stats else None + self.hash_block_size = hash_block_size self.coordinator = get_kv_cache_coordinator( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, @@ -192,6 +193,9 @@ def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]: request.block_hashes, max_cache_hit_length ) ) + if len(request.block_hashes) > 0: + bs = self.hash_block_size # spnl: cache hit rate logging + print(f"vLLMCacheHitRate {100*(num_new_computed_tokens/(len(request.block_hashes)*bs)):.2f}% computed={num_new_computed_tokens} requested={len(request.block_hashes)*bs}", flush=True) if self.log_stats: assert self.prefix_cache_stats is not None