Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions lmcache/v1/multiprocess/protocols/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,8 @@ def get_protocol_definitions() -> dict[str, ProtocolDefinition]:
int,
str,
int,
EngineType,
LayoutHints,
bytes,
int,
int,
int,
str,
bool,
],
response_class=None,
Expand Down
102 changes: 36 additions & 66 deletions lmcache/v1/multiprocess/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,62 +289,60 @@ def register_kv_cache_cpu_context(
instance_id: int,
model_name: str,
world_size: int,
engine_type: EngineType,
layout_hints: LayoutHints,
layout_desc_bytes: bytes,
block_size: int,
num_layers: int,
hidden_dim_size: int,
dtype_str: str,
use_mla: bool,
) -> None:
"""Register non-CUDA KV layout metadata for CPU context mode.
"""Register non-CUDA KV context metadata for CPU context mode.

Args:
instance_id: Worker instance identifier (typically PID).
model_name: Model name associated with this worker.
world_size: Worker world size used in cache keys.
engine_type: Serving engine type (kept for protocol compatibility).
layout_hints: Optional engine layout hints (protocol compatibility).
layout_desc_bytes: Pickled :class:`MemoryLayoutDesc` carrying
KV cache tensor shape/dtype metadata (for both MLA and
non-MLA layouts).
block_size: Tokens per paged block.
num_layers: Number of model layers.
hidden_dim_size: Flattened hidden dimension per token.
dtype_str: Torch dtype name (for example ``"float16"``).
use_mla: Whether the worker KV format is MLA.
MLA stores one latent vector per token with shape
``[num_layers, chunk_size, hidden_dim_size]``; non-MLA stores
separate K/V planes with shape
``[2, num_layers, chunk_size, hidden_dim_size]``.

Raises:
ValueError: If ``dtype_str`` is not a valid torch dtype name.
ValueError: If ``layout_desc_bytes`` does not deserialize to
:class:`MemoryLayoutDesc`.
"""
# Third Party
import torch

# Keep these for protocol compatibility with register_kv_cache().
del engine_type, layout_hints
dtype = getattr(torch, dtype_str, None)
if dtype is None or not isinstance(dtype, torch.dtype):
layout_desc = pickle.loads(layout_desc_bytes)
if not isinstance(layout_desc, MemoryLayoutDesc):
raise ValueError(
f"Invalid dtype_str '{dtype_str}': expected a torch.dtype name "
"(e.g. 'float16', 'bfloat16', 'float32')."
"Invalid layout_desc_bytes: expected pickled MemoryLayoutDesc"
)

shape = (
# MLA has one latent state per token (no separate K/V axis),
# while non-MLA stores separate K and V tensors at dim 0.
torch.Size([num_layers, self.chunk_size, hidden_dim_size])
if use_mla
else torch.Size([2, num_layers, self.chunk_size, hidden_dim_size])
)
layout_desc = MemoryLayoutDesc(shapes=[shape], dtypes=[dtype])
self.cpu_contexts[instance_id] = CPUContextMetadata(
layout_desc=layout_desc,
block_size=block_size,
use_mla=use_mla,
)
self.cpu_context_meta[instance_id] = (model_name, world_size)

def _resolve_obj_keys(self, key: IPCCacheEngineKey) -> list[ObjectKey]:
    """Resolve object keys from request/session token hashes.

    The session for ``key.request_id`` is fetched (or created) from
    ``self.session_manager``, loaded with ``key.token_ids``, and asked for
    the hashes covering ``key.start`` .. ``key.end``. Each hash is converted
    to bytes and handed, together with ``key``, to
    ``ipc_key_to_object_keys``.

    Args:
        key: IPC cache key carrying request/session metadata.

    Returns:
        Object keys derived from hashed chunk ranges.

    Raises:
        ValueError: If ``key.worker_id`` is ``None``. This method is
            only valid for worker-originated IPC operations.
    """
    # Reject invalid keys up front so a malformed request never creates a
    # session or overwrites the tokens of an existing one.
    if key.worker_id is None:
        raise ValueError("worker_id must not be None for this operation")

    session = self.session_manager.get_or_create(key.request_id)
    # Copy into a list so the session owns its own token sequence.
    session.set_tokens(list(key.token_ids))
    chunk_hashes = [
        TokenHasher.hash_to_bytes(h) for h in session.get_hashes(key.start, key.end)
    ]
    return ipc_key_to_object_keys(key, chunk_hashes)

@_lmcache_nvtx_annotate
def store_cpu_chunks(
self,
Expand All @@ -368,14 +366,7 @@ def store_cpu_chunks(
# Third Party
import torch

session = self.session_manager.get_or_create(key.request_id)
session.set_tokens(list(key.token_ids))
chunk_hashes = [
TokenHasher.hash_to_bytes(h) for h in session.get_hashes(key.start, key.end)
]
if key.worker_id is None:
raise ValueError("Must store with worker_id != None")
obj_keys = ipc_key_to_object_keys(key, chunk_hashes)
obj_keys = self._resolve_obj_keys(key)

if instance_id not in self.cpu_contexts:
raise ValueError(
Expand Down Expand Up @@ -426,14 +417,7 @@ def retrieve_cpu_chunks(
Raises:
ValueError: If the instance has no registered cpu context.
"""
session = self.session_manager.get_or_create(key.request_id)
session.set_tokens(list(key.token_ids))
chunk_hashes = [
TokenHasher.hash_to_bytes(h) for h in session.get_hashes(key.start, key.end)
]
if key.worker_id is None:
raise ValueError("Must retrieve with worker_id != None")
obj_keys = ipc_key_to_object_keys(key, chunk_hashes)
obj_keys = self._resolve_obj_keys(key)

if instance_id not in self.cpu_contexts:
raise ValueError(
Expand Down Expand Up @@ -479,17 +463,10 @@ def store(
that signals the completion of the store operation. The second
element indicates whether the store operation was successful.
"""
session = self.session_manager.get_or_create(key.request_id)
session.set_tokens(list(key.token_ids))
chunk_hashes = [
TokenHasher.hash_to_bytes(h) for h in session.get_hashes(key.start, key.end)
]
obj_keys = self._resolve_obj_keys(key)

st = time.perf_counter()

assert key.worker_id is not None, "Must store with worker_id != None"
obj_keys = ipc_key_to_object_keys(key, chunk_hashes)

assert instance_id in self.gpu_contexts, (
f"KV cache not registered for GPU ID {instance_id}"
)
Expand Down Expand Up @@ -655,17 +632,10 @@ def retrieve(
that signals the completion of the retrieve operation. The second
element indicates whether the key was successfully retrieved.
"""
session = self.session_manager.get_or_create(key.request_id)
session.set_tokens(list(key.token_ids))
chunk_hashes = [
TokenHasher.hash_to_bytes(h) for h in session.get_hashes(key.start, key.end)
]
obj_keys = self._resolve_obj_keys(key)

st = time.perf_counter()

assert key.worker_id is not None, "Must retrieve with worker_id != None"
obj_keys = ipc_key_to_object_keys(key, chunk_hashes)

assert instance_id in self.gpu_contexts, (
f"KV cache not registered for GPU ID {instance_id}"
)
Expand Down
30 changes: 14 additions & 16 deletions lmcache/v1/multiprocess/transfer_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""Transfer context abstractions for LMCache multiprocess worker adapters."""

# Standard
import pickle
from abc import ABC, abstractmethod
from typing import Any, Callable, Protocol

Expand Down Expand Up @@ -330,6 +331,16 @@ def register(
) = compute_kv_layout(kv_caches, layout_hints=layout_hints)
self._layout_hints = layout_hints
self._gpu_kv_format = gpu_kv_format
use_mla_flag = is_mla(gpu_kv_format)
shape = (
torch.Size([num_layers, blocks_in_chunk * block_size, hidden_dim_size])
if use_mla_flag
else torch.Size(
[2, num_layers, blocks_in_chunk * block_size, hidden_dim_size]
)
)
dtype = getattr(torch, dtype_str)
layout_desc = MemoryLayoutDesc(shapes=[shape], dtypes=[dtype])

future = send_request(
mq_client,
Expand All @@ -338,27 +349,14 @@ def register(
instance_id,
model_name,
world_size,
EngineType.VLLM,
layout_hints,
pickle.dumps(layout_desc),
block_size,
num_layers,
hidden_dim_size,
dtype_str,
is_mla(gpu_kv_format),
use_mla_flag,
],
)

use_mla_flag = is_mla(gpu_kv_format)
shape = (
torch.Size([num_layers, blocks_in_chunk * block_size, hidden_dim_size])
if use_mla_flag
else torch.Size(
[2, num_layers, blocks_in_chunk * block_size, hidden_dim_size]
)
)
dtype = getattr(torch, dtype_str)
metadata = CPUContextMetadata(
layout_desc=MemoryLayoutDesc(shapes=[shape], dtypes=[dtype]),
layout_desc=layout_desc,
block_size=block_size,
use_mla=use_mla_flag,
)
Expand Down
26 changes: 16 additions & 10 deletions tests/v1/multiprocess/test_cpu_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
import pytest
import torch

# First Party
from lmcache.v1.distributed.api import MemoryLayoutDesc


def _make_kv_caches(
num_layers: int = 2,
Expand Down Expand Up @@ -345,16 +348,18 @@ def test_server_register_and_find_cpu_context_layout(
patch("lmcache.v1.multiprocess.server.get_event_bus"),
):
engine = MPCacheEngine(storage_manager_config=MagicMock(), chunk_size=16)

engine.register_kv_cache_cpu_context(
instance_id=1,
model_name="m",
world_size=1,
engine_type=MagicMock(),
layout_hints={},
layout_desc_bytes=pickle.dumps(
MemoryLayoutDesc(
shapes=[torch.Size([2, 2, 16, 16])],
dtypes=[torch.float32],
)
),
block_size=4,
num_layers=2,
hidden_dim_size=16,
dtype_str="float32",
use_mla=False,
)

Expand Down Expand Up @@ -402,12 +407,13 @@ def _read_prefetched_results(_keys: Any) -> Any:
instance_id=2,
model_name="m",
world_size=1,
engine_type=MagicMock(),
layout_hints={},
layout_desc_bytes=pickle.dumps(
MemoryLayoutDesc(
shapes=[torch.Size([2, 2, 8, 16])],
dtypes=[torch.float32],
)
),
block_size=4,
num_layers=2,
hidden_dim_size=16,
dtype_str="float32",
use_mla=False,
)
payload = torch.ones(2, 2, 8, 16)
Expand Down
12 changes: 12 additions & 0 deletions tests/v1/test_vllm_mp_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"""

# Standard
import pickle
from unittest.mock import MagicMock

# Third Party
Expand All @@ -23,6 +24,7 @@
LoadStoreOp,
ParallelStrategy,
)
from lmcache.v1.distributed.api import MemoryLayoutDesc
from lmcache.v1.multiprocess.protocol import RequestType


Expand Down Expand Up @@ -120,6 +122,16 @@ def test_register_kv_caches_cpu_submits_cpu_context_registration(
assert send_mock.call_count == 1
args, _kwargs = send_mock.call_args
assert args[1] == RequestType.REGISTER_KV_CACHE_CPU_CONTEXT
payload = args[2]
assert len(payload) == 6
assert payload[0] == adapter.instance_id
assert payload[1] == "test-model"
assert payload[2] == 1
assert isinstance(payload[3], bytes)

assert isinstance(pickle.loads(payload[3]), MemoryLayoutDesc)
assert payload[4] == 4
assert payload[5] is False


def test_submit_store_request_passes_no_transport_kwargs(fake_adapter, monkeypatch):
Expand Down