Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion requirements/common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,5 @@ anthropic >= 0.71.0
model-hosting-container-standards >= 0.1.13, < 1.0.0
mcp
grpcio
grpcio-reflection
grpcio-reflection
spnl >= 0.21.0
121 changes: 121 additions & 0 deletions vllm/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,127 @@ def load_log_config(log_config_file: str | None) -> dict | None:
)
return None

if envs.VLLM_V1_SPANS_ENABLED:
import spnl
import time
from fastapi import Body
from vllm import SamplingParams
from vllm.inputs import TokensPrompt
from vllm.outputs import RequestOutput
from vllm.entrypoints.openai.protocol import (ChatMessage,ChatCompletionStreamResponse,ChatCompletionResponseStreamChoice,ChatCompletionResponseChoice,DeltaMessage,UsageInfo)
spnl_state = spnl.init(10)
PAD_TOKEN = 27
PLUS_TOKEN = envs.VLLM_V1_SPANS_TOKEN_PLUS if envs.VLLM_V1_SPANS_TOKEN_PLUS >= 0 else None
CROSS_TOKEN = envs.VLLM_V1_SPANS_TOKEN_CROSS if envs.VLLM_V1_SPANS_TOKEN_CROSS >= 0 else None
def wrap(prompt: str | list[str]) -> TokensPrompt:
if isinstance(prompt[0], list):
return [TokensPrompt(prompt_token_ids=p) for p in prompt]
return TokensPrompt(prompt_token_ids=prompt)
@router.post("/v1/query/execute")
@with_cancellation
@load_aware_call
async def execute_query(raw_request: Request,
query: str = Body(..., media_type="text/plain"),
stream: bool = False):
req = spnl.tokenize_query(
spnl_state,
query,
PAD_TOKEN,
CROSS_TOKEN,
PLUS_TOKEN,
raw_request.app.state.vllm_config.cache_config.block_size
)

match req:
case spnl.TokenizedQuery.TokenizedChatCompletionQuery(q):
req = q # intentional fall-through
case spnl.TokenizedQuery.CompletionRequest(q):
request = CompletionRequest(model=q.model, max_tokens=q.max_tokens, temperature=q.temperature, prompt=q.inputs, stream=stream)
# what we want to do, but this is a fastapi endpoint... return create_completion(request, raw_request)
handler = completion(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Completions API")

try:
generator = await handler.create_completion(request, raw_request)
except OverflowError as e:
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value,
detail=str(e)) from e
except Exception as e:
raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
detail=str(e)) from e

if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.error.code)
elif isinstance(generator, CompletionResponse):
return JSONResponse(content=generator.model_dump())

return StreamingResponse(content=generator, media_type="text/event-stream")

request_id = raw_request.headers.get(
"X-Request-Id") or uuid.uuid4().hex
client = engine_client(raw_request)
generator = client.generate(wrap(req.messages), SamplingParams(n=1 if req.n <= 0 else req.n,temperature=req.temperature if req.temperature is not None else 0,max_tokens=req.max_tokens if req.max_tokens is not None and req.max_tokens != 0 else 2048), request_id)

if stream:
async def sgen():
output_idx: List[int] = [0 for _ in range(req.n)]
async for res in generator:
yield ChatCompletionStreamResponse(
id=request_id,
object="chat.completion.chunk",
created=int(time.time()),
model=req.model,
choices=[
ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(role="assistant", content=output.text[output_idx[index]:]),
logprobs=output.logprobs,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
)
for index, output in enumerate(res.outputs)
]
).model_dump_json(exclude_unset=True)
for index, output in enumerate(res.outputs):
output_idx[index] = len(output.text)
return StreamingResponse(content=sgen(), media_type="text/event-stream")

outputs: List[Optional[CompletionOutput]] = [None for _ in range(req.n)]
async for res in generator:
for output in res.outputs:
outputs[output.index] = output
choices = [
ChatCompletionResponseChoice(
index=index,
message=ChatMessage(role="assistant", content=output.text),
logprobs=output.logprobs,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
)
for index, output in enumerate(outputs)
]
num_prompt_tokens=0 # TODO
num_generated_tokens=0 # TODO
usage = UsageInfo(prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens +
num_generated_tokens)
response = ChatCompletionResponse(
id=request_id,
created=int(time.time()),
model=req.model,
choices=choices,
usage=usage
)

if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.error.code)
return JSONResponse(content=response.model_dump())


class AuthenticationMiddleware:
"""
Expand Down
4 changes: 4 additions & 0 deletions vllm/v1/core/kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def __init__(
# potential configs we could expose in the future.
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None

self.hash_block_size = hash_block_size
self.coordinator = get_kv_cache_coordinator(
kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len,
Expand Down Expand Up @@ -192,6 +193,9 @@ def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]:
request.block_hashes, max_cache_hit_length
)
)
if len(request.block_hashes) > 0:
bs = self.hash_block_size # spnl: cache hit rate logging
print(f"vLLMCacheHitRate {100*(num_new_computed_tokens/(len(request.block_hashes)*bs)):.2f}% computed={num_new_computed_tokens} requested={len(request.block_hashes)*bs}", flush=True)

if self.log_stats:
assert self.prefix_cache_stats is not None
Expand Down