Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 127 additions & 65 deletions lmcache/v1/multiprocess/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,45 @@ class _PrefetchJob:
cache_salt: str = ""


@dataclass
class RegisteredContext:
"""Registered context metadata for a single worker instance.

At least one of ``gpu_context`` or ``non_cuda_metadata`` is expected to be
populated for valid registrations.
"""

model_name: str
world_size: int
gpu_context: GPUCacheContext | None = None
non_cuda_metadata: NonGpuContextMetadata | None = None

@property
def is_gpu(self) -> bool:
"""Return whether this registration uses a GPU transfer context."""
return self.gpu_context is not None

def get_layout_desc(self, chunk_size: int) -> MemoryLayoutDesc:
"""Return the layout descriptor for this registration.

Args:
chunk_size: Chunk size in tokens used for GPU layout derivation.

Returns:
The resolved memory layout descriptor.

Raises:
ValueError: If no GPU context or non-CUDA metadata is configured.
"""
if self.gpu_context is not None:
return get_layout_desc(self.gpu_context, chunk_size)
if self.non_cuda_metadata is None:
raise ValueError(
"Invalid RegisteredContext: no GPU or non-CUDA metadata configured"
)
return self.non_cuda_metadata.layout_desc


# Main class for the mp cache engine
class MPCacheEngine:
def __init__(
Expand All @@ -184,16 +223,8 @@ def __init__(
chunk_size: int = 256,
hash_algorithm: str = "blake3",
):
# GPU ID -> KV cache tensors
self.gpu_contexts: dict[int, GPUCacheContext] = {}

# GPU ID -> (model name, world size) as metadata
# NOTE: This is mainly for determining the layout desc during prefetch
# We assume that if the (model name, world size) is the same, then
# the layout desc returned by the gpu context is the same.
self.gpu_context_meta: dict[int, tuple[str, int]] = {}
self.non_cuda_contexts: dict[int, NonGpuContextMetadata] = {}
self.non_cuda_context_meta: dict[int, tuple[str, int]] = {}
# Worker instance ID -> registered context metadata
self.contexts: dict[int, RegisteredContext] = {}

# chunk size
self.chunk_size = chunk_size
Expand Down Expand Up @@ -221,6 +252,15 @@ def __init__(

self._setup_metrics()

@property
def gpu_contexts(self) -> dict[int, GPUCacheContext]:
"""Return GPU-only context mapping for backward compatibility."""
return {
instance_id: ctx.gpu_context
for instance_id, ctx in self.contexts.items()
if ctx.gpu_context is not None
}

def register_kv_cache(
self,
instance_id: int,
Expand All @@ -244,7 +284,7 @@ def register_kv_cache(
layout_hints: See :class:`LayoutHints`. Forwarded to
:class:`GPUCacheContext` for GPU KV format detection.
"""
if instance_id in self.gpu_contexts:
if instance_id in self.contexts:
logger.warning(
"Instance %s's KV cache is already registered, "
"skipping the new registration",
Expand All @@ -258,8 +298,11 @@ def register_kv_cache(
layout_hints=layout_hints or None,
engine_type=engine_type,
)
self.gpu_contexts[instance_id] = gpu_context
self.gpu_context_meta[instance_id] = (model_name, world_size)
self.contexts[instance_id] = RegisteredContext(
model_name=model_name,
world_size=world_size,
gpu_context=gpu_context,
)
logger.info(
"Registered KV cache for GPU ID %d with %d layers",
instance_id,
Expand All @@ -273,17 +316,18 @@ def unregister_kv_cache(self, instance_id: int) -> None:
Args:
instance_id (int): The GPU instance ID (such as PID).
"""
if instance_id in self.gpu_contexts:
del self.gpu_contexts[instance_id]
del self.gpu_context_meta[instance_id]
context = self.contexts.pop(instance_id, None)
if context is None:
logger.warning(
"No registered context found for instance ID %d", instance_id
)
return

if context.is_gpu:
logger.info("Unregistered KV cache for GPU ID %d", instance_id)
torch_dev.empty_cache()
elif instance_id in self.non_cuda_contexts:
del self.non_cuda_contexts[instance_id]
del self.non_cuda_context_meta[instance_id]
logger.info("Unregistered non-CUDA context for instance ID %d", instance_id)
else:
logger.warning("No KV cache found for GPU ID %d to unregister", instance_id)
logger.info("Unregistered non-CUDA context for instance ID %d", instance_id)

def register_kv_cache_non_gpu_context(
self,
Expand Down Expand Up @@ -311,6 +355,14 @@ def register_kv_cache_non_gpu_context(
Raises:
ValueError: If ``dtype_str`` is not a valid torch dtype name.
"""
if instance_id in self.contexts:
logger.warning(
"Instance %s's KV cache is already registered, "
"skipping the new registration",
instance_id,
)
return

dtype = getattr(torch, dtype_str, None)
if dtype is None or not isinstance(dtype, torch.dtype):
raise ValueError(
Expand All @@ -325,12 +377,15 @@ def register_kv_cache_non_gpu_context(
else torch.Size([2, num_layers, self.chunk_size, hidden_dim_size])
)
layout_desc = MemoryLayoutDesc(shapes=[shape], dtypes=[dtype])
self.non_cuda_contexts[instance_id] = NonGpuContextMetadata(
layout_desc=layout_desc,
block_size=block_size,
use_mla=use_mla,
self.contexts[instance_id] = RegisteredContext(
model_name=model_name,
world_size=world_size,
non_cuda_metadata=NonGpuContextMetadata(
layout_desc=layout_desc,
block_size=block_size,
use_mla=use_mla,
),
)
self.non_cuda_context_meta[instance_id] = (model_name, world_size)

def _resolve_obj_keys(self, key: IPCCacheEngineKey) -> list[ObjectKey]:
"""Resolve object keys from an IPC cache key.
Expand Down Expand Up @@ -375,11 +430,12 @@ def store_cpu_chunks(
"""
obj_keys = self._resolve_obj_keys(key)

if instance_id not in self.non_cuda_contexts:
context = self.contexts.get(instance_id)
if context is None or context.non_cuda_metadata is None:
raise ValueError(
f"non-CUDA context not registered for instance ID {instance_id}"
)
ctx = self.non_cuda_contexts[instance_id]
ctx = context.non_cuda_metadata
chunks: list[torch.Tensor] = pickle.loads(cpu_data)
reserved_dict = self.storage_manager.reserve_write(
obj_keys, ctx.layout_desc, "new"
Expand Down Expand Up @@ -426,7 +482,8 @@ def retrieve_cpu_chunks(
"""
obj_keys = self._resolve_obj_keys(key)

if instance_id not in self.non_cuda_contexts:
context = self.contexts.get(instance_id)
if context is None or context.non_cuda_metadata is None:
raise ValueError(
f"non-CUDA context not registered for instance ID {instance_id}"
)
Expand Down Expand Up @@ -473,11 +530,15 @@ def store(
st = time.perf_counter()
obj_keys = self._resolve_obj_keys(key)

assert instance_id in self.gpu_contexts, (
f"KV cache not registered for GPU ID {instance_id}"
context = self.contexts.get(instance_id)
assert context is not None, (
f"No context registered for instance ID {instance_id}"
)
assert context.gpu_context is not None, (
f"GPU context not registered for instance ID {instance_id}"
)
gpu_context = self.gpu_contexts[instance_id]
model_name = self.gpu_context_meta[instance_id][0]
gpu_context = context.gpu_context
model_name = context.model_name

# ``blocks_per_chunk`` is counted in inference-engine-side
# blocks (each block addresses
Expand Down Expand Up @@ -656,11 +717,15 @@ def retrieve(
st = time.perf_counter()
obj_keys = self._resolve_obj_keys(key)

assert instance_id in self.gpu_contexts, (
f"KV cache not registered for GPU ID {instance_id}"
context = self.contexts.get(instance_id)
assert context is not None, (
f"No context registered for instance ID {instance_id}"
)
assert context.gpu_context is not None, (
f"GPU context not registered for instance ID {instance_id}"
)
gpu_context = self.gpu_contexts[instance_id]
model_name = self.gpu_context_meta[instance_id][0]
gpu_context = context.gpu_context
model_name = context.model_name

# CPU-synchronous sentinel: a GPU retrieve is about to be enqueued.
# Must be published via publish() (not publish_on_stream) so the
Expand Down Expand Up @@ -837,15 +902,9 @@ def _find_layout_desc(
``(model_name, world_size)``. GPU contexts are checked first,
then CPU contexts.
"""
for gpu_id, (m, w) in self.gpu_context_meta.items():
if m == model_name and w == world_size:
return get_layout_desc(
self.gpu_contexts[gpu_id],
self.chunk_size,
)
for instance_id, (m, w) in self.non_cuda_context_meta.items():
if m == model_name and w == world_size:
return self.non_cuda_contexts[instance_id].layout_desc
for context in self.contexts.values():
if context.model_name == model_name and context.world_size == world_size:
return context.get_layout_desc(self.chunk_size)
return None

def lookup(
Expand Down Expand Up @@ -1161,13 +1220,14 @@ def report_status(self) -> dict:
sm = self.storage_manager.report_status()

gpu_context_meta: dict[str, dict] = {}
for gpu_id, meta in self.gpu_context_meta.items():
non_cuda_context_meta: dict[str, dict] = {}
for instance_id, context in self.contexts.items():
entry: dict = {
"model_name": meta[0],
"world_size": meta[1],
"model_name": context.model_name,
"world_size": context.world_size,
}
ctx = self.gpu_contexts.get(gpu_id)
if ctx is not None:
if context.gpu_context is not None:
ctx = context.gpu_context
entry["kv_cache_layout"] = {
"num_layers": ctx.num_layers,
"inference_engine_logical_block_size": (
Expand All @@ -1185,7 +1245,15 @@ def report_status(self) -> dict:
"attention_backend": ctx.attention_backend,
"cache_size_per_token": ctx.cache_size_per_token(),
}
gpu_context_meta[str(gpu_id)] = entry
gpu_context_meta[str(instance_id)] = entry
continue

if context.non_cuda_metadata is not None:
non_cuda_context_meta[str(instance_id)] = {
**entry,
"block_size": context.non_cuda_metadata.block_size,
"use_mla": context.non_cuda_metadata.use_mla,
}

return {
"is_healthy": sm["is_healthy"],
Expand All @@ -1194,18 +1262,12 @@ def report_status(self) -> dict:
"hash_algorithm": self.token_hasher.hash_algorithm_name,
"registered_gpu_ids": list(self.gpu_contexts.keys()),
"gpu_context_meta": gpu_context_meta,
"registered_non_cuda_instance_ids": list(self.non_cuda_contexts.keys()),
"non_cuda_context_meta": {
str(instance_id): {
"model_name": model_name,
"world_size": world_size,
"block_size": self.non_cuda_contexts[instance_id].block_size,
"use_mla": self.non_cuda_contexts[instance_id].use_mla,
}
for instance_id, (model_name, world_size) in (
self.non_cuda_context_meta.items()
)
},
"registered_non_cuda_instance_ids": [
instance_id
for instance_id, context in self.contexts.items()
if context.non_cuda_metadata is not None
],
"non_cuda_context_meta": non_cuda_context_meta,
"active_sessions": self.session_manager.active_count(),
"active_prefetch_jobs": self._active_prefetch_count(),
"storage_manager": sm,
Expand Down Expand Up @@ -1257,7 +1319,7 @@ def close(self) -> None:
logger.info("MPCacheEngine closed")

# Release GPU contexts
self.gpu_contexts.clear()
self.contexts.clear()

def _active_prefetch_count(self) -> int:
"""Return the number of active prefetch jobs (thread-safe)."""
Expand Down