Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions lmcache/v1/multiprocess/custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,30 @@ def no_worker_id_version(self) -> "IPCCacheEngineKey":
KVCache = list[CudaIPCWrapper]


class RegisterNonGpuContextPayload(msgspec.Struct):
    """Typed payload of the REGISTER_KV_CACHE_NON_GPU_CONTEXT message.

    Bundles every piece of worker metadata needed for a non-GPU KV cache
    registration into one struct.

    Attributes:
        instance_id: Identifier of the worker instance (typically its PID).
        model_name: Name of the model this worker is associated with.
        world_size: World size of the worker, as used in cache keys.
        block_size: Number of tokens in one paged block.
        num_layers: How many layers the model has.
        hidden_dim_size: Per-token hidden dimension, flattened.
        dtype_str: Name of the torch dtype (e.g. ``"float16"``).
        use_mla: True when the worker's KV format is MLA.
    """

    # Worker identity.
    instance_id: int
    model_name: str
    world_size: int

    # KV layout description.
    block_size: int
    num_layers: int
    hidden_dim_size: int
    dtype_str: str
    use_mla: bool


@dataclass
class CustomizedSerdeConfig:
serializer: Callable[[Any], bytes]
Expand Down
16 changes: 6 additions & 10 deletions lmcache/v1/multiprocess/protocols/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from lmcache.v1.multiprocess.custom_types import (
IPCCacheEngineKey,
KVCache,
RegisterNonGpuContextPayload,
)
from lmcache.v1.multiprocess.protocols.base import HandlerType, ProtocolDefinition

Expand Down Expand Up @@ -149,17 +150,12 @@ def get_protocol_definitions() -> dict[str, ProtocolDefinition]:
response_class=None,
handler_type=HandlerType.BLOCKING,
),
# Register non-GPU KV cache context
# Payload:
# - RegisterNonGpuContextPayload - all metadata fields in one struct
# Returns: None
"REGISTER_KV_CACHE_NON_GPU_CONTEXT": ProtocolDefinition(
payload_classes=[
int,
str,
int,
int,
int,
int,
str,
bool,
],
payload_classes=[RegisterNonGpuContextPayload],
response_class=None,
handler_type=HandlerType.SYNC,
),
Expand Down
49 changes: 20 additions & 29 deletions lmcache/v1/multiprocess/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
BlockAllocationRecord,
IPCCacheEngineKey,
KVCache,
RegisterNonGpuContextPayload,
)
from lmcache.v1.multiprocess.gpu_context import (
GPUCacheContext,
Expand Down Expand Up @@ -331,59 +332,49 @@ def unregister_kv_cache(self, instance_id: int) -> None:

def register_kv_cache_non_gpu_context(
self,
instance_id: int,
model_name: str,
world_size: int,
block_size: int,
num_layers: int,
hidden_dim_size: int,
dtype_str: str,
use_mla: bool,
payload: RegisterNonGpuContextPayload,
) -> None:
"""Register non-CUDA KV layout metadata for non-GPU context mode.

Args:
instance_id: Worker instance identifier (typically PID).
model_name: Model name associated with this worker.
world_size: Worker world size used in cache keys.
block_size: Tokens per paged block.
num_layers: Number of model layers.
hidden_dim_size: Flattened hidden dimension per token.
dtype_str: Torch dtype name (for example ``"float16"``).
use_mla: Whether the worker KV format is MLA.
payload: Struct containing all registration fields
(instance_id, model_name, world_size, block_size,
num_layers, hidden_dim_size, dtype_str, use_mla).

Raises:
ValueError: If ``dtype_str`` is not a valid torch dtype name.
ValueError: If ``payload.dtype_str`` is not a valid torch dtype name.
"""
if instance_id in self.contexts:
if payload.instance_id in self.contexts:
logger.warning(
"Instance %s's KV cache is already registered, "
"skipping the new registration",
instance_id,
payload.instance_id,
)
return

dtype = getattr(torch, dtype_str, None)
dtype = getattr(torch, payload.dtype_str, None)
if dtype is None or not isinstance(dtype, torch.dtype):
raise ValueError(
f"Invalid dtype_str '{dtype_str}': must be a valid torch dtype "
f"Invalid dtype_str '{payload.dtype_str}': must be a valid torch dtype "
"attribute name (e.g. 'float16' for torch.float16, "
"'bfloat16' for torch.bfloat16, 'float32' for torch.float32)."
)

shape = (
torch.Size([num_layers, self.chunk_size, hidden_dim_size])
if use_mla
else torch.Size([2, num_layers, self.chunk_size, hidden_dim_size])
torch.Size([payload.num_layers, self.chunk_size, payload.hidden_dim_size])
if payload.use_mla
else torch.Size(
[2, payload.num_layers, self.chunk_size, payload.hidden_dim_size]
)
)
layout_desc = MemoryLayoutDesc(shapes=[shape], dtypes=[dtype])
self.contexts[instance_id] = RegisteredContext(
model_name=model_name,
world_size=world_size,
self.contexts[payload.instance_id] = RegisteredContext(
model_name=payload.model_name,
world_size=payload.world_size,
non_cuda_metadata=NonGpuContextMetadata(
layout_desc=layout_desc,
block_size=block_size,
use_mla=use_mla,
block_size=payload.block_size,
use_mla=payload.use_mla,
),
)

Expand Down
19 changes: 11 additions & 8 deletions lmcache/v1/multiprocess/transfer_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from lmcache.utils import EngineType, init_logger
from lmcache.v1.distributed.api import MemoryLayoutDesc
from lmcache.v1.gpu_connector.utils import is_mla
from lmcache.v1.multiprocess.custom_types import RegisterNonGpuContextPayload
from lmcache.v1.multiprocess.futures import MessagingFuture
from lmcache.v1.multiprocess.mq import MessageQueueClient
from lmcache.v1.multiprocess.non_gpu_context import (
Expand Down Expand Up @@ -287,14 +288,16 @@ def register(
mq_client,
RequestType.REGISTER_KV_CACHE_NON_GPU_CONTEXT,
[
instance_id,
model_name,
world_size,
block_size,
num_layers,
hidden_dim_size,
dtype_str,
use_mla_flag,
RegisterNonGpuContextPayload(
instance_id=instance_id,
model_name=model_name,
world_size=world_size,
block_size=block_size,
num_layers=num_layers,
hidden_dim_size=hidden_dim_size,
dtype_str=dtype_str,
use_mla=use_mla_flag,
)
],
)

Expand Down
42 changes: 25 additions & 17 deletions tests/v1/multiprocess/test_non_cuda_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ def test_server_register_and_find_non_cuda_context_layout(
) -> None:
"""Ensure non-CUDA registration stores metadata and lookup finds layout."""
# First Party
from lmcache.v1.multiprocess.custom_types import RegisterNonGpuContextPayload
from lmcache.v1.multiprocess.server import MPCacheEngine

with (
Expand All @@ -366,14 +367,16 @@ def test_server_register_and_find_non_cuda_context_layout(
):
engine = MPCacheEngine(storage_manager_config=MagicMock(), chunk_size=16)
engine.register_kv_cache_non_gpu_context(
instance_id=1,
model_name="m",
world_size=1,
block_size=4,
num_layers=2,
hidden_dim_size=16,
dtype_str="float32",
use_mla=False,
RegisterNonGpuContextPayload(
instance_id=1,
model_name="m",
world_size=1,
block_size=4,
num_layers=2,
hidden_dim_size=16,
dtype_str="float32",
use_mla=False,
)
)

layout = engine._find_layout_desc("m", 1)
Expand All @@ -384,7 +387,10 @@ def test_server_register_and_find_non_cuda_context_layout(
def test_server_store_and_retrieve_cpu_chunks(stub_native_storage_ops: Any) -> None:
"""Validate mocked server-side CPU chunk store and retrieve behavior."""
# First Party
from lmcache.v1.multiprocess.custom_types import IPCCacheEngineKey
from lmcache.v1.multiprocess.custom_types import (
IPCCacheEngineKey,
RegisterNonGpuContextPayload,
)
from lmcache.v1.multiprocess.server import MPCacheEngine

mock_storage = MagicMock()
Expand Down Expand Up @@ -417,14 +423,16 @@ def _read_prefetched_results(_keys: Any) -> Any:
engine = MPCacheEngine(storage_manager_config=MagicMock(), chunk_size=8)

engine.register_kv_cache_non_gpu_context(
instance_id=2,
model_name="m",
world_size=1,
block_size=4,
num_layers=2,
hidden_dim_size=16,
dtype_str="float32",
use_mla=False,
RegisterNonGpuContextPayload(
instance_id=2,
model_name="m",
world_size=1,
block_size=4,
num_layers=2,
hidden_dim_size=16,
dtype_str="float32",
use_mla=False,
)
)
payload = torch.ones(2, 2, 8, 16)
key = IPCCacheEngineKey.from_token_ids(
Expand Down
2 changes: 1 addition & 1 deletion tests/v1/test_vllm_mp_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_register_kv_caches_cpu_submits_non_gpu_context_registration(
assert send_mock.call_count == 1
args, _kwargs = send_mock.call_args
assert args[1] == RequestType.REGISTER_KV_CACHE_NON_GPU_CONTEXT
assert len(args[2]) == 8
assert len(args[2]) == 1


def test_submit_store_request_tracks_returned_future(fake_adapter, monkeypatch):
Expand Down