Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 51 additions & 40 deletions lmcache/v1/multiprocess/blend_server_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
"""

# Standard
from dataclasses import dataclass
from typing import Any
import threading
import time
Expand Down Expand Up @@ -101,6 +102,18 @@
logger = init_logger(__name__)


@dataclass
class _CBRegisteredContext:
    """Bundle a registered CB GPU context with its model metadata.

    Attributes:
        model_name: Model identity used to resolve the CB layout desc
            during lookup.
        world_size: World size that is matched together with
            ``model_name`` when selecting the CB layout.
        gpu_context: GPU context used for CB store/retrieve operations.
    """

    model_name: str
    world_size: int
    gpu_context: PlainGPUCacheContext


class BlendTokenRangeMatcher:
"""Fast token-range matcher using polynomial rolling/chunk hashes and a
direct-address lookup table.
Expand Down Expand Up @@ -390,11 +403,9 @@ def __init__(
storage_manager_config, chunk_size, hash_algorithm=hash_algorithm
)

self._cb_gpu_contexts: dict[int, PlainGPUCacheContext] = {}

# CB GPU ID -> (model name, world size) as metadata
# CB instance ID -> registered context
# NOTE: This is mainly for determining the layout desc during prefetch
self._cb_gpu_context_meta: dict[int, tuple[str, int]] = {}
self._cb_contexts: dict[int, _CBRegisteredContext] = {}

# Fast local matcher: indexes pre-computed chunk hashes for sub-sequence lookup
self._token_range_matcher = BlendTokenRangeMatcher(chunk_size)
Expand All @@ -418,8 +429,11 @@ def cb_register_kv_cache(
world_size: The world size associated with this KV cache.
"""
gpu_context = PlainGPUCacheContext(kv_caches, self.chunk_size)
self._cb_gpu_contexts[instance_id] = gpu_context
self._cb_gpu_context_meta[instance_id] = (model_name, world_size)
self._cb_contexts[instance_id] = _CBRegisteredContext(
model_name=model_name,
world_size=world_size,
gpu_context=gpu_context,
)
logger.info(
"Registered CB KV cache for instance_id %d with %d layers",
instance_id,
Expand Down Expand Up @@ -451,9 +465,8 @@ def cb_unregister_kv_cache(self, instance_id: int) -> None:
Args:
instance_id: Unique identifier for the blend engine instance to unregister
"""
if instance_id in self._cb_gpu_contexts:
del self._cb_gpu_contexts[instance_id]
del self._cb_gpu_context_meta[instance_id]
context = self._cb_contexts.pop(instance_id, None)
if context is not None:
logger.info("Unregistered CB KV cache for instance_id %d", instance_id)
else:
logger.warning(
Expand All @@ -471,30 +484,28 @@ def report_status(self) -> dict:
status = super().report_status()

cb_gpu_context_meta: dict[str, dict] = {}
for gpu_id, meta in self._cb_gpu_context_meta.items():
model_name, world_size = meta
for instance_id, context in self._cb_contexts.items():
entry: dict = {
"model_name": model_name,
"world_size": world_size,
"model_name": context.model_name,
"world_size": context.world_size,
}
ctx = self._cb_gpu_contexts.get(gpu_id)
if ctx is not None:
# bytes per token = 2 (K+V) * num_layers * hidden_dim_size *
# itemsize; num_tokens is the cache capacity, not a per-token
# cost.
cache_size_per_token = (
2 * ctx.num_layers * ctx.hidden_dim_size * ctx.dtype.itemsize
)
entry["kv_cache_layout"] = {
"num_layers": ctx.num_layers,
"num_tokens": ctx.num_tokens,
"hidden_dim_size": ctx.hidden_dim_size,
"dtype": str(ctx.dtype),
"cache_size_per_token": cache_size_per_token,
}
cb_gpu_context_meta[str(gpu_id)] = entry

status["registered_cb_gpu_ids"] = list(self._cb_gpu_contexts.keys())
ctx = context.gpu_context
# bytes per token = 2 (K+V) * num_layers * hidden_dim_size *
# itemsize; num_tokens is the cache capacity, not a per-token
# cost.
cache_size_per_token = (
2 * ctx.num_layers * ctx.hidden_dim_size * ctx.dtype.itemsize
)
entry["kv_cache_layout"] = {
"num_layers": ctx.num_layers,
"num_tokens": ctx.num_tokens,
"hidden_dim_size": ctx.hidden_dim_size,
"dtype": str(ctx.dtype),
"cache_size_per_token": cache_size_per_token,
}
cb_gpu_context_meta[str(instance_id)] = entry

status["registered_cb_gpu_ids"] = list(self._cb_contexts.keys())
status["cb_gpu_context_meta"] = cb_gpu_context_meta
return status

Expand Down Expand Up @@ -593,9 +604,9 @@ def cb_lookup_pre_computed(self, key: IPCCacheEngineKey) -> list[CBMatchResult]:

# Find the cb gpu context and calculate the layout desc
layout_desc: MemoryLayoutDesc | None = None
for gpu_id, (m_name, w_size) in self._cb_gpu_context_meta.items():
if m_name == model_name and w_size == world_size:
cb_ctx = self._cb_gpu_contexts[gpu_id]
for context in self._cb_contexts.values():
if context.model_name == model_name and context.world_size == world_size:
cb_ctx = context.gpu_context
layout_desc = MemoryLayoutDesc(
shapes=[cb_ctx.get_kv_buffer_shape(self.chunk_size)],
dtypes=[cb_ctx.dtype],
Expand Down Expand Up @@ -845,10 +856,10 @@ def cb_store_pre_computed(
"""
num_tokens = key.end - key.start

assert instance_id in self._cb_gpu_contexts, (
assert instance_id in self._cb_contexts, (
f"Instance ID {instance_id} not registered for CB KV cache"
)
gpu_context = self._cb_gpu_contexts[instance_id]
gpu_context = self._cb_contexts[instance_id].gpu_context

# CPU-synchronous sentinel: GPU store is about to be enqueued.
self._event_bus.publish(
Expand Down Expand Up @@ -968,10 +979,10 @@ def cb_retrieve_pre_computed(
Note:
We must call `cb_lookup_pre_computed` first before calling this function
"""
assert instance_id in self._cb_gpu_contexts, (
assert instance_id in self._cb_contexts, (
f"Instance ID {instance_id} not registered for CB KV cache"
)
gpu_context = self._cb_gpu_contexts[instance_id]
gpu_context = self._cb_contexts[instance_id].gpu_context

# One obj_key per match_result, in cur_st order
cb_match_result = sorted(cb_match_result, key=lambda r: r.cur_st)
Expand Down Expand Up @@ -1113,10 +1124,10 @@ def cb_store_final(
num_tokens = key.end - key.start

# Get GPU context
assert instance_id in self._cb_gpu_contexts, (
assert instance_id in self._cb_contexts, (
f"Instance ID {instance_id} not registered for CB KV cache"
)
gpu_context = self._cb_gpu_contexts[instance_id]
gpu_context = self._cb_contexts[instance_id].gpu_context

# CPU-synchronous sentinels: SUBMITTED before SESSION_END so the
# tracing subscriber's in-flight counter is non-zero when SESSION_END
Expand Down
11 changes: 11 additions & 0 deletions lmcache/v1/multiprocess/http_apis/cache_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,17 @@ async def kvcache_check(
)

gpu_ctxs = getattr(engine, "gpu_contexts", None)
if gpu_ctxs is None:
contexts = getattr(engine, "contexts", None)
# Unified registry fallback: contexts is expected to be
# dict[int, RegisteredContext]-like, where each value may expose a
# nullable ``gpu_context`` attribute.
if isinstance(contexts, dict):
gpu_ctxs = {
instance_id: context.gpu_context
for instance_id, context in contexts.items()
if hasattr(context, "gpu_context") and context.gpu_context is not None
}
if gpu_ctxs is None:
return JSONResponse(
status_code=501,
Expand Down
Loading