diff --git a/lmcache/v1/multiprocess/custom_types.py b/lmcache/v1/multiprocess/custom_types.py index ec82b8bbd47..28cc2e3a85a 100644 --- a/lmcache/v1/multiprocess/custom_types.py +++ b/lmcache/v1/multiprocess/custom_types.py @@ -315,6 +315,30 @@ def no_worker_id_version(self) -> "IPCCacheEngineKey": KVCache = list[CudaIPCWrapper] +class RegisterNonGpuContextPayload(msgspec.Struct): + """Payload for the REGISTER_KV_CACHE_NON_GPU_CONTEXT protocol message. + + Attributes: + instance_id: Worker instance identifier (typically PID). + model_name: Model name associated with this worker. + world_size: Worker world size used in cache keys. + block_size: Tokens per paged block. + num_layers: Number of model layers. + hidden_dim_size: Flattened hidden dimension per token. + dtype_str: Torch dtype name (e.g. ``"float16"``). + use_mla: Whether the worker KV format is MLA. + """ + + instance_id: int + model_name: str + world_size: int + block_size: int + num_layers: int + hidden_dim_size: int + dtype_str: str + use_mla: bool + + @dataclass class CustomizedSerdeConfig: serializer: Callable[[Any], bytes] diff --git a/lmcache/v1/multiprocess/protocols/engine.py b/lmcache/v1/multiprocess/protocols/engine.py index e3f80b6a6c7..cb3126c2ad9 100644 --- a/lmcache/v1/multiprocess/protocols/engine.py +++ b/lmcache/v1/multiprocess/protocols/engine.py @@ -18,6 +18,7 @@ from lmcache.v1.multiprocess.custom_types import ( IPCCacheEngineKey, KVCache, + RegisterNonGpuContextPayload, ) from lmcache.v1.multiprocess.protocols.base import HandlerType, ProtocolDefinition @@ -149,17 +150,12 @@ def get_protocol_definitions() -> dict[str, ProtocolDefinition]: response_class=None, handler_type=HandlerType.BLOCKING, ), + # Register non-GPU KV cache context + # Payload: + # - RegisterNonGpuContextPayload - all metadata fields in one struct + # Returns: None "REGISTER_KV_CACHE_NON_GPU_CONTEXT": ProtocolDefinition( - payload_classes=[ - int, - str, - int, - int, - int, - int, - str, - bool, - ], + payload_classes=[RegisterNonGpuContextPayload], response_class=None, handler_type=HandlerType.SYNC, ), diff --git a/lmcache/v1/multiprocess/server.py b/lmcache/v1/multiprocess/server.py index 8b467e6c4c1..5dad97e4f2a 100644 --- a/lmcache/v1/multiprocess/server.py +++ b/lmcache/v1/multiprocess/server.py @@ -57,6 +57,7 @@ BlockAllocationRecord, IPCCacheEngineKey, KVCache, + RegisterNonGpuContextPayload, ) from lmcache.v1.multiprocess.gpu_context import ( GPUCacheContext, @@ -331,59 +332,49 @@ def unregister_kv_cache(self, instance_id: int) -> None: def register_kv_cache_non_gpu_context( self, - instance_id: int, - model_name: str, - world_size: int, - block_size: int, - num_layers: int, - hidden_dim_size: int, - dtype_str: str, - use_mla: bool, + payload: RegisterNonGpuContextPayload, ) -> None: """Register non-CUDA KV layout metadata for non-GPU context mode. Args: - instance_id: Worker instance identifier (typically PID). - model_name: Model name associated with this worker. - world_size: Worker world size used in cache keys. - block_size: Tokens per paged block. - num_layers: Number of model layers. - hidden_dim_size: Flattened hidden dimension per token. - dtype_str: Torch dtype name (for example ``"float16"``). - use_mla: Whether the worker KV format is MLA. + payload: Struct containing all registration fields + (instance_id, model_name, world_size, block_size, + num_layers, hidden_dim_size, dtype_str, use_mla). Raises: - ValueError: If ``dtype_str`` is not a valid torch dtype name. + ValueError: If ``payload.dtype_str`` is not a valid torch dtype name. """ - if instance_id in self.contexts: + if payload.instance_id in self.contexts: logger.warning( "Instance %s's KV cache is already registered, " "skipping the new registration", - instance_id, + payload.instance_id, ) return - dtype = getattr(torch, dtype_str, None) + dtype = getattr(torch, payload.dtype_str, None) if dtype is None or not isinstance(dtype, torch.dtype): raise ValueError( - f"Invalid dtype_str '{dtype_str}': must be a valid torch dtype " + f"Invalid dtype_str '{payload.dtype_str}': must be a valid torch dtype " "attribute name (e.g. 'float16' for torch.float16, " "'bfloat16' for torch.bfloat16, 'float32' for torch.float32)." ) shape = ( - torch.Size([num_layers, self.chunk_size, hidden_dim_size]) - if use_mla - else torch.Size([2, num_layers, self.chunk_size, hidden_dim_size]) + torch.Size([payload.num_layers, self.chunk_size, payload.hidden_dim_size]) + if payload.use_mla + else torch.Size( + [2, payload.num_layers, self.chunk_size, payload.hidden_dim_size] + ) ) layout_desc = MemoryLayoutDesc(shapes=[shape], dtypes=[dtype]) - self.contexts[instance_id] = RegisteredContext( - model_name=model_name, - world_size=world_size, + self.contexts[payload.instance_id] = RegisteredContext( + model_name=payload.model_name, + world_size=payload.world_size, non_cuda_metadata=NonGpuContextMetadata( layout_desc=layout_desc, - block_size=block_size, - use_mla=use_mla, + block_size=payload.block_size, + use_mla=payload.use_mla, ), ) diff --git a/lmcache/v1/multiprocess/transfer_context.py b/lmcache/v1/multiprocess/transfer_context.py index 7d59ab323ed..a2b1f8ba7a2 100644 --- a/lmcache/v1/multiprocess/transfer_context.py +++ b/lmcache/v1/multiprocess/transfer_context.py @@ -13,6 +13,7 @@ from lmcache.utils import EngineType, init_logger from lmcache.v1.distributed.api import MemoryLayoutDesc from lmcache.v1.gpu_connector.utils import is_mla +from lmcache.v1.multiprocess.custom_types import RegisterNonGpuContextPayload from lmcache.v1.multiprocess.futures import MessagingFuture from lmcache.v1.multiprocess.mq import MessageQueueClient from lmcache.v1.multiprocess.non_gpu_context import ( @@ -287,14 +288,16 @@ def register( mq_client, RequestType.REGISTER_KV_CACHE_NON_GPU_CONTEXT, [ - instance_id, - model_name, - world_size, - block_size, - num_layers, - hidden_dim_size, - dtype_str, - use_mla_flag, + RegisterNonGpuContextPayload( + instance_id=instance_id, + model_name=model_name, + world_size=world_size, + block_size=block_size, + num_layers=num_layers, + hidden_dim_size=hidden_dim_size, + dtype_str=dtype_str, + use_mla=use_mla_flag, + ) ], ) diff --git a/tests/v1/multiprocess/test_non_cuda_context.py b/tests/v1/multiprocess/test_non_cuda_context.py index b7cb30dcd32..31783e46903 100644 --- a/tests/v1/multiprocess/test_non_cuda_context.py +++ b/tests/v1/multiprocess/test_non_cuda_context.py @@ -356,6 +356,7 @@ def test_server_register_and_find_non_cuda_context_layout( ) -> None: """Ensure non-CUDA registration stores metadata and lookup finds layout.""" # First Party + from lmcache.v1.multiprocess.custom_types import RegisterNonGpuContextPayload from lmcache.v1.multiprocess.server import MPCacheEngine with ( @@ -366,14 +367,16 @@ def test_server_register_and_find_non_cuda_context_layout( ): engine = MPCacheEngine(storage_manager_config=MagicMock(), chunk_size=16) engine.register_kv_cache_non_gpu_context( - instance_id=1, - model_name="m", - world_size=1, - block_size=4, - num_layers=2, - hidden_dim_size=16, - dtype_str="float32", - use_mla=False, + RegisterNonGpuContextPayload( + instance_id=1, + model_name="m", + world_size=1, + block_size=4, + num_layers=2, + hidden_dim_size=16, + dtype_str="float32", + use_mla=False, + ) ) layout = engine._find_layout_desc("m", 1) @@ -384,7 +387,10 @@ def test_server_register_and_find_non_cuda_context_layout( def test_server_store_and_retrieve_cpu_chunks(stub_native_storage_ops: Any) -> None: """Validate mocked server-side CPU chunk store and retrieve behavior.""" # First Party - from lmcache.v1.multiprocess.custom_types import IPCCacheEngineKey + from lmcache.v1.multiprocess.custom_types import ( + IPCCacheEngineKey, + RegisterNonGpuContextPayload, + ) from lmcache.v1.multiprocess.server import MPCacheEngine mock_storage = MagicMock() @@ -417,14 +423,16 @@ def _read_prefetched_results(_keys: Any) -> Any: engine = MPCacheEngine(storage_manager_config=MagicMock(), chunk_size=8) engine.register_kv_cache_non_gpu_context( - instance_id=2, - model_name="m", - world_size=1, - block_size=4, - num_layers=2, - hidden_dim_size=16, - dtype_str="float32", - use_mla=False, + RegisterNonGpuContextPayload( + instance_id=2, + model_name="m", + world_size=1, + block_size=4, + num_layers=2, + hidden_dim_size=16, + dtype_str="float32", + use_mla=False, + ) ) payload = torch.ones(2, 2, 8, 16) key = IPCCacheEngineKey.from_token_ids( diff --git a/tests/v1/test_vllm_mp_adapter.py b/tests/v1/test_vllm_mp_adapter.py index f18404053e6..e4e9c64b33f 100644 --- a/tests/v1/test_vllm_mp_adapter.py +++ b/tests/v1/test_vllm_mp_adapter.py @@ -124,7 +124,7 @@ def test_register_kv_caches_cpu_submits_non_gpu_context_registration( assert send_mock.call_count == 1 args, _kwargs = send_mock.call_args assert args[1] == RequestType.REGISTER_KV_CACHE_NON_GPU_CONTEXT - assert len(args[2]) == 8 + assert len(args[2]) == 1 def test_submit_store_request_tracks_returned_future(fake_adapter, monkeypatch):