From 7bb415714650600e049d6029b230d73856e5f83e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 13 May 2026 06:16:23 +0000 Subject: [PATCH 1/6] Initial plan From ace643dfea542d2552b80ae21fe417ee6fe264b9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 13 May 2026 06:32:05 +0000 Subject: [PATCH 2/6] Refactor transfer context to return adapter-compatible futures Agent-Logs-Url: https://github.com/hlin99/LMCache/sessions/b277cbef-7a95-47da-b1b5-93a69834a331 Co-authored-by: hlin99 <73271530+hlin99@users.noreply.github.com> --- .../vllm/vllm_multi_process_adapter.py | 106 +++++--- lmcache/v1/multiprocess/protocols/engine.py | 6 +- lmcache/v1/multiprocess/server.py | 93 ++----- lmcache/v1/multiprocess/transfer_context.py | 228 ++++-------------- tests/v1/multiprocess/test_cpu_context.py | 39 +-- tests/v1/test_vllm_mp_adapter.py | 16 +- 6 files changed, 174 insertions(+), 314 deletions(-) diff --git a/lmcache/integration/vllm/vllm_multi_process_adapter.py b/lmcache/integration/vllm/vllm_multi_process_adapter.py index 10fec34cab7..02fcb031485 100644 --- a/lmcache/integration/vllm/vllm_multi_process_adapter.py +++ b/lmcache/integration/vllm/vllm_multi_process_adapter.py @@ -36,11 +36,7 @@ DEFAULT_HEARTBEAT_INTERVAL: float = 10.0 -def wrap_kv_caches( - kv_caches: dict[str, torch.Tensor], use_cpu_context: bool = False -) -> KVCache: - if use_cpu_context: - return [] +def wrap_kv_caches(kv_caches: dict[str, torch.Tensor]) -> KVCache: logger.info("KV caches keys are %s", list(kv_caches.keys())) return [CudaIPCWrapper(tensor) for tensor in kv_caches.values()] @@ -704,8 +700,13 @@ def __init__( # Transport context for transfer operations. self.transfer_ctx: TransferContext | None = None - # Store requests submitted but not yet finished by transfer context. - self._pending_store_request_ids: set[str] = set() + + # Request futures + self.store_futures: dict[str, MessagingFuture[StoreResult]] = {} + # request_id -> (future, block_ids) + self.retrieve_futures: dict[ + str, tuple[MessagingFuture[RetrieveResult], list[int]] + ] = {} # Block IDs that failed due to retrieve timeout self.error_block_ids: set[int] = set() @@ -933,7 +934,7 @@ def submit_store_request( "Transfer context is not initialized. " "Call register_kv_caches() before submitting store requests." ) - self.transfer_ctx.submit_store( + future = self.transfer_ctx.submit_store( request_id, key, self.instance_id, @@ -942,7 +943,7 @@ def submit_store_request( event, self.blocks_in_chunk, ) - self._pending_store_request_ids.add(request_id) + self.store_futures[request_id] = future @_lmcache_nvtx_annotate def submit_retrieve_request( @@ -981,7 +982,7 @@ def submit_retrieve_request( "Transfer context is not initialized. " "Call register_kv_caches() before submitting retrieve requests." ) - self.transfer_ctx.submit_retrieve( + future = self.transfer_ctx.submit_retrieve( request_id, key, self.instance_id, @@ -991,6 +992,7 @@ def submit_retrieve_request( self.blocks_in_chunk, skip_first_n_tokens=op.skip_first_n_tokens, ) + self.retrieve_futures[request_id] = (future, list(op.block_ids)) @_lmcache_nvtx_annotate def batched_submit_store_requests( @@ -1055,10 +1057,7 @@ def _process_finished_stores( for req_id in finished_req_ids_from_engine: if req_id in self._returned_finished: continue - if ( - req_id in self.finished_stores - or req_id in self._pending_store_request_ids - ): + if req_id in self.finished_stores or req_id in self.store_futures: self.previously_finished.add(req_id) else: ret_stores.add(req_id) @@ -1090,35 +1089,72 @@ def get_finished( take care of deduplicating the request IDs and only return the request IDs that have not been returned before. """ - if self.transfer_ctx is None: - return set(), set() - - unhealthy = not self.is_healthy - if unhealthy: - finished_stores, finished_retrieves, error_block_ids = ( - self.transfer_ctx.drain_all() - ) - else: - finished_stores, finished_retrieves, error_block_ids = ( - self.transfer_ctx.poll_finished() + # If unhealthy, drain all pending futures immediately + if not self.is_healthy: + finished_stores = set(self.store_futures.keys()) + finished_retrieves = set() + for request_id, ( + _r_future, + r_block_ids, + ) in self.retrieve_futures.items(): + finished_retrieves.add(request_id) + self.error_block_ids.update(r_block_ids) + self.store_futures.clear() + self.retrieve_futures.clear() + + ret_stores = self._process_finished_stores( + finished_stores, finished_req_ids_from_engine ) + # A request may have a pending retrieve AND appear in + # finished_req_ids_from_engine (it ran without loading KV after + # the server died). The scheduler processes finished_recving + # first and deletes the request, so we must not also report it + # in finished_sending. + ret_stores -= finished_retrieves + return ret_stores, finished_retrieves - self.error_block_ids.update(error_block_ids) - self._pending_store_request_ids.difference_update(finished_stores) + finished_stores = set() + finished_retrieves = set() + for request_id, s_future in self.store_futures.items(): + if not s_future.query(): + continue + + s_result = s_future.result() + finished_stores.add(request_id) + + if not s_result: + logger.error( + "Something went wrong when processing the " + "store request for request_id=%s", + request_id, + ) + + for request_id, (r_future, _) in self.retrieve_futures.items(): + if not r_future.query(): + continue + + r_result = r_future.result() + finished_retrieves.add(request_id) + + if not r_result: + logger.error( + "Something went wrong when processing the " + "retrieve request for request_id=%s, result=%s", + request_id, + r_result, + ) + + # Remove the finished requests from the tracking dicts + for request_id in finished_stores: + self.store_futures.pop(request_id, None) + for request_id in finished_retrieves: + self.retrieve_futures.pop(request_id, None) # Update the internal states ret_stores = self._process_finished_stores( finished_stores, finished_req_ids_from_engine ) - if unhealthy: - # A request may have a pending retrieve AND appear in - # finished_req_ids_from_engine (it ran without loading KV after - # the server died). The scheduler processes finished_recving - # first and deletes the request, so we must not also report it - # in finished_sending. - ret_stores -= finished_retrieves - # the invocation of `get_finished` means that # these requests' KV caches are already fully stored. # or the requests normally ends without any store. diff --git a/lmcache/v1/multiprocess/protocols/engine.py b/lmcache/v1/multiprocess/protocols/engine.py index df84c788ce4..41d96200eba 100644 --- a/lmcache/v1/multiprocess/protocols/engine.py +++ b/lmcache/v1/multiprocess/protocols/engine.py @@ -154,12 +154,8 @@ def get_protocol_definitions() -> dict[str, ProtocolDefinition]: int, str, int, - EngineType, - LayoutHints, + bytes, int, - int, - int, - str, bool, ], response_class=None, diff --git a/lmcache/v1/multiprocess/server.py b/lmcache/v1/multiprocess/server.py index 2f932490b59..d176d61d690 100644 --- a/lmcache/v1/multiprocess/server.py +++ b/lmcache/v1/multiprocess/server.py @@ -10,6 +10,7 @@ import time # Third Party +import torch import zmq # First Party @@ -289,12 +290,8 @@ def register_kv_cache_cpu_context( instance_id: int, model_name: str, world_size: int, - engine_type: EngineType, - layout_hints: LayoutHints, + layout_desc_bytes: bytes, block_size: int, - num_layers: int, - hidden_dim_size: int, - dtype_str: str, use_mla: bool, ) -> None: """Register non-CUDA KV layout metadata for CPU context mode. @@ -303,41 +300,11 @@ def register_kv_cache_cpu_context( instance_id: Worker instance identifier (typically PID). model_name: Model name associated with this worker. world_size: Worker world size used in cache keys. - engine_type: Serving engine type (kept for protocol compatibility). - layout_hints: Optional engine layout hints (protocol compatibility). + layout_desc_bytes: Pickled :class:`MemoryLayoutDesc`. block_size: Tokens per paged block. - num_layers: Number of model layers. - hidden_dim_size: Flattened hidden dimension per token. - dtype_str: Torch dtype name (for example ``"float16"``). use_mla: Whether the worker KV format is MLA. - MLA stores one latent vector per token with shape - ``[num_layers, chunk_size, hidden_dim_size]``; non-MLA stores - separate K/V planes with shape - ``[2, num_layers, chunk_size, hidden_dim_size]``. - - Raises: - ValueError: If ``dtype_str`` is not a valid torch dtype name. """ - # Third Party - import torch - - # Keep these for protocol compatibility with register_kv_cache(). - del engine_type, layout_hints - dtype = getattr(torch, dtype_str, None) - if dtype is None or not isinstance(dtype, torch.dtype): - raise ValueError( - f"Invalid dtype_str '{dtype_str}': expected a torch.dtype name " - "(e.g. 'float16', 'bfloat16', 'float32')." - ) - - shape = ( - # MLA has one latent state per token (no separate K/V axis), - # while non-MLA stores separate K and V tensors at dim 0. - torch.Size([num_layers, self.chunk_size, hidden_dim_size]) - if use_mla - else torch.Size([2, num_layers, self.chunk_size, hidden_dim_size]) - ) - layout_desc = MemoryLayoutDesc(shapes=[shape], dtypes=[dtype]) + layout_desc = pickle.loads(layout_desc_bytes) self.cpu_contexts[instance_id] = CPUContextMetadata( layout_desc=layout_desc, block_size=block_size, @@ -345,6 +312,17 @@ def register_kv_cache_cpu_context( ) self.cpu_context_meta[instance_id] = (model_name, world_size) + def _resolve_obj_keys(self, key: IPCCacheEngineKey) -> list[ObjectKey]: + """Resolve object keys from an IPC cache key.""" + session = self.session_manager.get_or_create(key.request_id) + session.set_tokens(list(key.token_ids)) + chunk_hashes = [ + TokenHasher.hash_to_bytes(h) for h in session.get_hashes(key.start, key.end) + ] + if key.worker_id is None: + raise ValueError("Must resolve keys with worker_id != None") + return ipc_key_to_object_keys(key, chunk_hashes) + @_lmcache_nvtx_annotate def store_cpu_chunks( self, @@ -365,17 +343,7 @@ def store_cpu_chunks( Raises: ValueError: If the instance has no registered cpu context. """ - # Third Party - import torch - - session = self.session_manager.get_or_create(key.request_id) - session.set_tokens(list(key.token_ids)) - chunk_hashes = [ - TokenHasher.hash_to_bytes(h) for h in session.get_hashes(key.start, key.end) - ] - if key.worker_id is None: - raise ValueError("Must store with worker_id != None") - obj_keys = ipc_key_to_object_keys(key, chunk_hashes) + obj_keys = self._resolve_obj_keys(key) if instance_id not in self.cpu_contexts: raise ValueError( @@ -426,14 +394,7 @@ def retrieve_cpu_chunks( Raises: ValueError: If the instance has no registered cpu context. """ - session = self.session_manager.get_or_create(key.request_id) - session.set_tokens(list(key.token_ids)) - chunk_hashes = [ - TokenHasher.hash_to_bytes(h) for h in session.get_hashes(key.start, key.end) - ] - if key.worker_id is None: - raise ValueError("Must retrieve with worker_id != None") - obj_keys = ipc_key_to_object_keys(key, chunk_hashes) + obj_keys = self._resolve_obj_keys(key) if instance_id not in self.cpu_contexts: raise ValueError( @@ -479,16 +440,8 @@ def store( that signals the completion of the store operation. The second element indicates whether the store operation was successful. """ - session = self.session_manager.get_or_create(key.request_id) - session.set_tokens(list(key.token_ids)) - chunk_hashes = [ - TokenHasher.hash_to_bytes(h) for h in session.get_hashes(key.start, key.end) - ] - st = time.perf_counter() - - assert key.worker_id is not None, "Must store with worker_id != None" - obj_keys = ipc_key_to_object_keys(key, chunk_hashes) + obj_keys = self._resolve_obj_keys(key) assert instance_id in self.gpu_contexts, ( f"KV cache not registered for GPU ID {instance_id}" @@ -655,16 +608,8 @@ def retrieve( that signals the completion of the retrieve operation. The second element indicates whether the key was successfully retrieved. """ - session = self.session_manager.get_or_create(key.request_id) - session.set_tokens(list(key.token_ids)) - chunk_hashes = [ - TokenHasher.hash_to_bytes(h) for h in session.get_hashes(key.start, key.end) - ] - st = time.perf_counter() - - assert key.worker_id is not None, "Must retrieve with worker_id != None" - obj_keys = ipc_key_to_object_keys(key, chunk_hashes) + obj_keys = self._resolve_obj_keys(key) assert instance_id in self.gpu_contexts, ( f"KV cache not registered for GPU ID {instance_id}" diff --git a/lmcache/v1/multiprocess/transfer_context.py b/lmcache/v1/multiprocess/transfer_context.py index 12362d35923..29fd71b581e 100644 --- a/lmcache/v1/multiprocess/transfer_context.py +++ b/lmcache/v1/multiprocess/transfer_context.py @@ -4,6 +4,7 @@ # Standard from abc import ABC, abstractmethod from typing import Any, Callable, Protocol +import pickle # Third Party import torch @@ -21,7 +22,8 @@ gather_paged_kv_to_cpu, scatter_cpu_to_paged_kv, ) -from lmcache.v1.multiprocess.mq import MessageQueueClient, MessagingFuture +from lmcache.v1.multiprocess.futures import MessagingFuture +from lmcache.v1.multiprocess.mq import MessageQueueClient from lmcache.v1.multiprocess.protocol import RequestType logger = init_logger(__name__) @@ -34,18 +36,11 @@ def ipc_handle(self) -> object: """Return an IPC handle consumable by the multiprocess server.""" -SendRequest = Callable[ - [MessageQueueClient, RequestType, list[object]], MessagingFuture[object] -] +SendRequest = Callable[[MessageQueueClient, RequestType, list[object]], MessagingFuture] class TransferContext(ABC): - """Abstract transport layer for worker-side KV transfer. - - Concrete implementations encapsulate how worker-side store/retrieve - operations are transmitted to the multiprocess server (for example, - CUDA IPC futures or CPU-context gather/scatter flows). - """ + """Abstract transport layer for worker-side KV transfer.""" @abstractmethod def register( @@ -59,18 +54,7 @@ def register( mq_timeout: float, send_request: SendRequest, ) -> None: - """Register KV caches with the server and wait for ACK. - - Args: - instance_id: Worker process instance id. - kv_caches: Worker KV cache tensors keyed by layer name. - model_name: Model name used by cache keys. - world_size: KV world size. - blocks_in_chunk: Number of vLLM blocks in one LMCache chunk. - mq_client: Message queue client used to communicate with server. - mq_timeout: Timeout in seconds for synchronous request wait. - send_request: Request sender callable used to issue MQ requests. - """ + """Register KV caches with the server and wait for ACK.""" @abstractmethod def submit_store( @@ -82,18 +66,8 @@ def submit_store( block_ids: list[int], event: IPCEvent, blocks_in_chunk: int, - ) -> None: - """Submit a store request. - - Args: - request_id: Request identifier. - key: LMCache key object. - instance_id: Worker process instance id. - kv_caches: Worker KV cache tensors keyed by layer name. - block_ids: vLLM block ids to store. - event: Synchronization event object. - blocks_in_chunk: Number of vLLM blocks in one LMCache chunk. - """ + ) -> MessagingFuture: + """Submit a store request and return a future.""" @abstractmethod def submit_retrieve( @@ -106,35 +80,8 @@ def submit_retrieve( event: IPCEvent, blocks_in_chunk: int, skip_first_n_tokens: int = 0, - ) -> None: - """Submit a retrieve request. - - Args: - request_id: Request identifier. - key: LMCache key object. - instance_id: Worker process instance id. - kv_caches: Worker KV cache tensors keyed by layer name. - block_ids: vLLM block ids to retrieve. - event: Synchronization event object. - blocks_in_chunk: Number of vLLM blocks in one LMCache chunk. - skip_first_n_tokens: Number of tokens to skip for partial scatter. - """ - - @abstractmethod - def poll_finished(self) -> tuple[set[str], set[str], set[int]]: - """Poll completed requests. - - Returns: - Tuple of ``(finished_store_ids, finished_retrieve_ids, error_block_ids)``. - """ - - @abstractmethod - def drain_all(self) -> tuple[set[str], set[str], set[int]]: - """Drain all pending requests. - - Returns: - Tuple of ``(finished_store_ids, finished_retrieve_ids, error_block_ids)``. - """ + ) -> MessagingFuture: + """Submit a retrieve request and return a future.""" @abstractmethod def close(self) -> None: @@ -145,10 +92,7 @@ class CudaTransferContext(TransferContext): """CUDA IPC + MQ future transport context.""" def __init__(self) -> None: - self._store_futures: dict[str, Any] = {} - self._retrieve_futures: dict[str, tuple[Any, list[int]]] = {} self._mq_client: MessageQueueClient | None = None - self._mq_timeout: float = 0.0 self._send_request: SendRequest | None = None def register( @@ -167,7 +111,6 @@ def register( from lmcache.integration.vllm.vllm_multi_process_adapter import wrap_kv_caches self._mq_client = mq_client - self._mq_timeout = mq_timeout self._send_request = send_request layout_hints = vllm_layout_hints() future = send_request( @@ -186,33 +129,28 @@ def register( def submit_store( self, - request_id: str, + _request_id: str, key: Any, instance_id: int, _kv_caches: dict[str, torch.Tensor], block_ids: list[int], event: IPCEvent, _blocks_in_chunk: int, - ) -> None: - if ( - self._mq_client is None - or self._send_request is None - or self._mq_timeout < 0 - ): + ) -> MessagingFuture: + if self._mq_client is None or self._send_request is None: raise RuntimeError( "CUDA transfer context is not registered. " "Call register() before submit_store()." ) - future = self._send_request( + return self._send_request( self._mq_client, RequestType.STORE, [key, instance_id, block_ids, event.ipc_handle()], ).to_cuda_future() - self._store_futures[request_id] = future def submit_retrieve( self, - request_id: str, + _request_id: str, key: Any, instance_id: int, _kv_caches: dict[str, torch.Tensor], @@ -220,73 +158,20 @@ def submit_retrieve( event: IPCEvent, _blocks_in_chunk: int, skip_first_n_tokens: int = 0, - ) -> None: - if ( - self._mq_client is None - or self._send_request is None - or self._mq_timeout < 0 - ): + ) -> MessagingFuture: + if self._mq_client is None or self._send_request is None: raise RuntimeError( "CUDA transfer context is not registered. " "Call register() before submit_retrieve()." ) - future = self._send_request( + return self._send_request( self._mq_client, RequestType.RETRIEVE, [key, instance_id, block_ids, event.ipc_handle(), skip_first_n_tokens], ).to_cuda_future() - self._retrieve_futures[request_id] = (future, list(block_ids)) - - def poll_finished(self) -> tuple[set[str], set[str], set[int]]: - finished_stores: set[str] = set() - finished_retrieves: set[str] = set() - error_block_ids: set[int] = set() - - for request_id, s_future in list(self._store_futures.items()): - if not s_future.query(): - continue - s_result = s_future.result() - finished_stores.add(request_id) - if not s_result: - logger.error( - "Something went wrong when processing the store request " - "for request_id=%s", - request_id, - ) - self._store_futures.pop(request_id, None) - - for request_id, (r_future, r_block_ids) in list(self._retrieve_futures.items()): - if not r_future.query(): - continue - r_result = r_future.result() - finished_retrieves.add(request_id) - if not r_result: - logger.error( - "Something went wrong when processing the retrieve request " - "for request_id=%s, result=%s", - request_id, - r_result, - ) - error_block_ids.update(r_block_ids) - self._retrieve_futures.pop(request_id, None) - - return finished_stores, finished_retrieves, error_block_ids - - def drain_all(self) -> tuple[set[str], set[str], set[int]]: - finished_stores = set(self._store_futures.keys()) - finished_retrieves = set(self._retrieve_futures.keys()) - error_block_ids: set[int] = set() - for _request_id, (_r_future, block_ids) in self._retrieve_futures.items(): - error_block_ids.update(block_ids) - self._store_futures.clear() - self._retrieve_futures.clear() - return finished_stores, finished_retrieves, error_block_ids def close(self) -> None: - self._store_futures.clear() - self._retrieve_futures.clear() self._mq_client = None - self._mq_timeout = 0.0 self._send_request = None @@ -297,11 +182,6 @@ def __init__(self) -> None: self._cpu_context: CPUContext | None = None self._layout_hints: Any = None self._gpu_kv_format: Any = None - self._store_done: dict[str, bool] = {} - self._retrieve_done: dict[str, tuple[bool, list[int]]] = {} - self._mq_client: MessageQueueClient | None = None - self._mq_timeout: float = 0.0 - self._send_request: SendRequest | None = None def register( self, @@ -317,9 +197,6 @@ def register( # First Party from lmcache.integration.vllm.utils import vllm_layout_hints - self._mq_client = mq_client - self._mq_timeout = mq_timeout - self._send_request = send_request layout_hints = vllm_layout_hints() ( block_size, @@ -331,6 +208,17 @@ def register( self._layout_hints = layout_hints self._gpu_kv_format = gpu_kv_format + use_mla_flag = is_mla(gpu_kv_format) + shape = ( + torch.Size([num_layers, blocks_in_chunk * block_size, hidden_dim_size]) + if use_mla_flag + else torch.Size( + [2, num_layers, blocks_in_chunk * block_size, hidden_dim_size] + ) + ) + dtype = getattr(torch, dtype_str) + layout_desc = MemoryLayoutDesc(shapes=[shape], dtypes=[dtype]) + future = send_request( mq_client, RequestType.REGISTER_KV_CACHE_CPU_CONTEXT, @@ -338,27 +226,14 @@ def register( instance_id, model_name, world_size, - EngineType.VLLM, - layout_hints, + pickle.dumps(layout_desc), block_size, - num_layers, - hidden_dim_size, - dtype_str, - is_mla(gpu_kv_format), + use_mla_flag, ], ) - use_mla_flag = is_mla(gpu_kv_format) - shape = ( - torch.Size([num_layers, blocks_in_chunk * block_size, hidden_dim_size]) - if use_mla_flag - else torch.Size( - [2, num_layers, blocks_in_chunk * block_size, hidden_dim_size] - ) - ) - dtype = getattr(torch, dtype_str) metadata = CPUContextMetadata( - layout_desc=MemoryLayoutDesc(shapes=[shape], dtypes=[dtype]), + layout_desc=layout_desc, block_size=block_size, use_mla=use_mla_flag, ) @@ -367,19 +242,20 @@ def register( def submit_store( self, - request_id: str, + _request_id: str, key: Any, instance_id: int, kv_caches: dict[str, torch.Tensor], block_ids: list[int], _event: IPCEvent, blocks_in_chunk: int, - ) -> None: + ) -> MessagingFuture: if self._cpu_context is None: raise RuntimeError( "CPU transfer context is not registered. " "Call register() before submit_store()." ) + torch_dev.synchronize() cpu_chunks = gather_paged_kv_to_cpu( kv_caches, @@ -390,11 +266,14 @@ def submit_store( ) handle = self._cpu_context.prepare_store(key, instance_id, cpu_chunks) ok = self._cpu_context.commit_store(handle) - self._store_done[request_id] = ok + + future: MessagingFuture[bool] = MessagingFuture() + future.set_result(ok) + return future def submit_retrieve( self, - request_id: str, + _request_id: str, key: Any, instance_id: int, kv_caches: dict[str, torch.Tensor], @@ -402,12 +281,13 @@ def submit_retrieve( _event: IPCEvent, blocks_in_chunk: int, skip_first_n_tokens: int = 0, - ) -> None: + ) -> MessagingFuture: if self._cpu_context is None: raise RuntimeError( "CPU transfer context is not registered. " "Call register() before submit_retrieve()." ) + handle, chunks = self._cpu_context.prepare_retrieve(key, instance_id) ok = chunks is not None if chunks is not None: @@ -425,31 +305,15 @@ def submit_retrieve( logger.exception("Failed to scatter retrieved CPU context chunks") ok = False self._cpu_context.commit_retrieve(handle) - self._retrieve_done[request_id] = (ok, list(block_ids)) - - def poll_finished(self) -> tuple[set[str], set[str], set[int]]: - finished_stores = set(self._store_done.keys()) - finished_retrieves = set(self._retrieve_done.keys()) - error_block_ids: set[int] = set() - for ok, block_ids in self._retrieve_done.values(): - if not ok: - error_block_ids.update(block_ids) - self._store_done.clear() - self._retrieve_done.clear() - return finished_stores, finished_retrieves, error_block_ids - - def drain_all(self) -> tuple[set[str], set[str], set[int]]: - return self.poll_finished() + + future: MessagingFuture[bool] = MessagingFuture() + future.set_result(ok) + return future def close(self) -> None: if self._cpu_context is not None: self._cpu_context.close() self._cpu_context = None - self._store_done.clear() - self._retrieve_done.clear() - self._mq_client = None - self._mq_timeout = 0.0 - self._send_request = None def create_transfer_context( diff --git a/tests/v1/multiprocess/test_cpu_context.py b/tests/v1/multiprocess/test_cpu_context.py index 90b7d9d9b65..ff38859e3c5 100644 --- a/tests/v1/multiprocess/test_cpu_context.py +++ b/tests/v1/multiprocess/test_cpu_context.py @@ -10,6 +10,9 @@ import pytest import torch +# First Party +from lmcache.v1.distributed.api import MemoryLayoutDesc + def _make_kv_caches( num_layers: int = 2, @@ -83,12 +86,20 @@ def _make_hnd_flashinfer_kv_caches( return kv_caches -def test_wrap_kv_caches_cpu_context_returns_empty() -> None: - """Verify wrap_kv_caches returns no IPC wrappers in cpu context mode.""" +def test_wrap_kv_caches_wraps_all_tensors(monkeypatch: Any) -> None: + """Verify wrap_kv_caches wraps all provided KV tensors.""" # First Party - from lmcache.integration.vllm.vllm_multi_process_adapter import wrap_kv_caches + from lmcache.integration.vllm import vllm_multi_process_adapter as adapter_mod + + kv_caches = _make_kv_caches() + monkeypatch.setattr( + adapter_mod, + "CudaIPCWrapper", + lambda tensor: ("wrapped", tensor), + ) - assert wrap_kv_caches(_make_kv_caches(), use_cpu_context=True) == [] + wrapped = adapter_mod.wrap_kv_caches(kv_caches) + assert len(wrapped) == len(kv_caches) def test_compute_kv_layout_and_gather_scatter_roundtrip() -> None: @@ -345,16 +356,16 @@ def test_server_register_and_find_cpu_context_layout( patch("lmcache.v1.multiprocess.server.get_event_bus"), ): engine = MPCacheEngine(storage_manager_config=MagicMock(), chunk_size=16) + expected_layout_desc = MemoryLayoutDesc( + shapes=[torch.Size([2, 2, 16, 16])], + dtypes=[torch.float32], + ) engine.register_kv_cache_cpu_context( instance_id=1, model_name="m", world_size=1, - engine_type=MagicMock(), - layout_hints={}, + layout_desc_bytes=pickle.dumps(expected_layout_desc), block_size=4, - num_layers=2, - hidden_dim_size=16, - dtype_str="float32", use_mla=False, ) @@ -398,16 +409,16 @@ def _read_prefetched_results(_keys: Any) -> Any: session_cls.return_value.get_or_create.return_value = mock_session engine = MPCacheEngine(storage_manager_config=MagicMock(), chunk_size=8) + layout_desc = MemoryLayoutDesc( + shapes=[torch.Size([2, 2, 8, 16])], + dtypes=[torch.float32], + ) engine.register_kv_cache_cpu_context( instance_id=2, model_name="m", world_size=1, - engine_type=MagicMock(), - layout_hints={}, + layout_desc_bytes=pickle.dumps(layout_desc), block_size=4, - num_layers=2, - hidden_dim_size=16, - dtype_str="float32", use_mla=False, ) payload = torch.ones(2, 2, 8, 16) diff --git a/tests/v1/test_vllm_mp_adapter.py b/tests/v1/test_vllm_mp_adapter.py index 84edb347bcb..ef7d7c0b93b 100644 --- a/tests/v1/test_vllm_mp_adapter.py +++ b/tests/v1/test_vllm_mp_adapter.py @@ -120,16 +120,20 @@ def test_register_kv_caches_cpu_submits_cpu_context_registration( assert send_mock.call_count == 1 args, _kwargs = send_mock.call_args assert args[1] == RequestType.REGISTER_KV_CACHE_CPU_CONTEXT + assert len(args[2]) == 6 + assert isinstance(args[2][3], bytes) -def test_submit_store_request_passes_no_transport_kwargs(fake_adapter, monkeypatch): - """submit_store_request should not pass mq/send kwargs after registration.""" +def test_submit_store_request_tracks_returned_future(fake_adapter, monkeypatch): + """submit_store_request stores the returned future in store_futures.""" adapter, _send_mock, _ = fake_adapter monkeypatch.setattr(adapter, "_ensure_heartbeat_started", lambda: None) fake_tensor = MagicMock() fake_tensor.device.type = "cuda" adapter.kv_caches = {"layer.0": fake_tensor} transfer_ctx = MagicMock() + fake_future = MagicMock() + transfer_ctx.submit_store.return_value = fake_future adapter.transfer_ctx = transfer_ctx op = LoadStoreOp(token_ids=[1, 2, 3, 4], block_ids=[0], start=0, end=4) @@ -137,16 +141,19 @@ def test_submit_store_request_passes_no_transport_kwargs(fake_adapter, monkeypat assert transfer_ctx.submit_store.called assert transfer_ctx.submit_store.call_args.kwargs == {} + assert adapter.store_futures["req-1"] is fake_future -def test_submit_retrieve_request_passes_no_transport_kwargs(fake_adapter, monkeypatch): - """submit_retrieve_request should not pass mq/send kwargs after registration.""" +def test_submit_retrieve_request_tracks_returned_future(fake_adapter, monkeypatch): + """submit_retrieve_request stores returned future and block IDs.""" adapter, _send_mock, _ = fake_adapter monkeypatch.setattr(adapter, "_ensure_heartbeat_started", lambda: None) fake_tensor = MagicMock() fake_tensor.device.type = "cuda" adapter.kv_caches = {"layer.0": fake_tensor} transfer_ctx = MagicMock() + fake_future = MagicMock() + transfer_ctx.submit_retrieve.return_value = fake_future adapter.transfer_ctx = transfer_ctx op = LoadStoreOp( token_ids=[1, 2, 3, 4], @@ -160,3 +167,4 @@ def test_submit_retrieve_request_passes_no_transport_kwargs(fake_adapter, monkey assert transfer_ctx.submit_retrieve.called assert transfer_ctx.submit_retrieve.call_args.kwargs == {"skip_first_n_tokens": 1} + assert adapter.retrieve_futures["req-1"] == (fake_future, [0]) From 22f3dbc1cab93e48d5399cc6acee8c2b14354038 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 13 May 2026 06:34:11 +0000 Subject: [PATCH 3/6] Improve docstrings for transfer context and key resolver Agent-Logs-Url: https://github.com/hlin99/LMCache/sessions/b277cbef-7a95-47da-b1b5-93a69834a331 Co-authored-by: hlin99 <73271530+hlin99@users.noreply.github.com> --- lmcache/v1/multiprocess/server.py | 12 +++- lmcache/v1/multiprocess/transfer_context.py | 63 +++++++++++++++++++-- 2 files changed, 70 insertions(+), 5 deletions(-) diff --git a/lmcache/v1/multiprocess/server.py b/lmcache/v1/multiprocess/server.py index d176d61d690..61431750979 100644 --- a/lmcache/v1/multiprocess/server.py +++ b/lmcache/v1/multiprocess/server.py @@ -313,7 +313,17 @@ def register_kv_cache_cpu_context( self.cpu_context_meta[instance_id] = (model_name, world_size) def _resolve_obj_keys(self, key: IPCCacheEngineKey) -> list[ObjectKey]: - """Resolve object keys from an IPC cache key.""" + """Resolve object keys from an IPC cache key. + + Args: + key: IPC cache key describing model/session/token range. + + Returns: + Resolved object keys for the requested token range. + + Raises: + ValueError: If ``key.worker_id`` is ``None``. + """ session = self.session_manager.get_or_create(key.request_id) session.set_tokens(list(key.token_ids)) chunk_hashes = [ diff --git a/lmcache/v1/multiprocess/transfer_context.py b/lmcache/v1/multiprocess/transfer_context.py index 29fd71b581e..94f42e338f0 100644 --- a/lmcache/v1/multiprocess/transfer_context.py +++ b/lmcache/v1/multiprocess/transfer_context.py @@ -40,7 +40,13 @@ def ipc_handle(self) -> object: class TransferContext(ABC): - """Abstract transport layer for worker-side KV transfer.""" + """Abstract transport layer for worker-side KV transfer. + + Concrete implementations encapsulate how worker-side store/retrieve + operations are transmitted to the multiprocess server. CUDA paths return + CUDA-aware futures backed by MQ requests, while CPU paths may perform + gather/scatter synchronously and return already-resolved futures. + """ @abstractmethod def register( @@ -54,7 +60,23 @@ def register( mq_timeout: float, send_request: SendRequest, ) -> None: - """Register KV caches with the server and wait for ACK.""" + """Register KV caches with the server and wait for ACK. + + Args: + instance_id: Worker process instance identifier. + kv_caches: Worker KV cache tensors keyed by layer name. + model_name: Model name used by cache keys. + world_size: KV world size. + blocks_in_chunk: Number of vLLM blocks per LMCache chunk. + mq_client: Message queue client used to communicate with server. + mq_timeout: Timeout in seconds for synchronous request wait. + send_request: Request sender callable used to issue MQ requests. + + Raises: + TimeoutError: If server registration does not complete before + ``mq_timeout``. + RuntimeError: If a concrete context cannot initialize. + """ @abstractmethod def submit_store( @@ -67,7 +89,23 @@ def submit_store( event: IPCEvent, blocks_in_chunk: int, ) -> MessagingFuture: - """Submit a store request and return a future.""" + """Submit a store request and return a completion future. + + Args: + request_id: External request identifier. + key: LMCache key object for the store range. + instance_id: Worker process instance identifier. + kv_caches: Worker KV cache tensors keyed by layer name. + block_ids: vLLM block IDs to store. + event: Synchronization event object. + blocks_in_chunk: Number of vLLM blocks per LMCache chunk. + + Returns: + A future compatible with adapter-side ``query()``/``result()`` flow. + + Raises: + RuntimeError: If register() was not called first. + """ @abstractmethod def submit_retrieve( @@ -81,7 +119,24 @@ def submit_retrieve( blocks_in_chunk: int, skip_first_n_tokens: int = 0, ) -> MessagingFuture: - """Submit a retrieve request and return a future.""" + """Submit a retrieve request and return a completion future. + + Args: + request_id: External request identifier. + key: LMCache key object for the retrieve range. + instance_id: Worker process instance identifier. + kv_caches: Worker KV cache tensors keyed by layer name. + block_ids: vLLM block IDs to retrieve into. + event: Synchronization event object. + blocks_in_chunk: Number of vLLM blocks per LMCache chunk. + skip_first_n_tokens: Number of initial tokens to skip when writing. + + Returns: + A future compatible with adapter-side ``query()``/``result()`` flow. + + Raises: + RuntimeError: If register() was not called first. + """ @abstractmethod def close(self) -> None: From 8dd355871fa5bcd6d6cabfb4eeca3d8edcf5b15a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 13 May 2026 07:00:45 +0000 Subject: [PATCH 4/6] Revert CPU registration payload from pickle bytes to scalar fields Agent-Logs-Url: https://github.com/hlin99/LMCache/sessions/8efcb792-420d-460c-98c8-97775dcee47e Co-authored-by: hlin99 <73271530+hlin99@users.noreply.github.com> --- lmcache/v1/multiprocess/protocols/engine.py | 4 +++- lmcache/v1/multiprocess/server.py | 26 ++++++++++++++++++--- lmcache/v1/multiprocess/transfer_context.py | 5 ++-- tests/v1/multiprocess/test_cpu_context.py | 20 +++++----------- tests/v1/test_vllm_mp_adapter.py | 3 +-- 5 files changed, 36 insertions(+), 22 deletions(-) diff --git a/lmcache/v1/multiprocess/protocols/engine.py b/lmcache/v1/multiprocess/protocols/engine.py index 41d96200eba..3fdcdd9edbe 100644 --- a/lmcache/v1/multiprocess/protocols/engine.py +++ b/lmcache/v1/multiprocess/protocols/engine.py @@ -154,8 +154,10 @@ def get_protocol_definitions() -> dict[str, ProtocolDefinition]: int, str, int, - bytes, int, + int, + int, + str, bool, ], response_class=None, diff --git a/lmcache/v1/multiprocess/server.py b/lmcache/v1/multiprocess/server.py index 61431750979..ee00604ddb2 100644 --- a/lmcache/v1/multiprocess/server.py +++ b/lmcache/v1/multiprocess/server.py @@ -290,8 +290,10 @@ def register_kv_cache_cpu_context( instance_id: int, model_name: str, world_size: int, - layout_desc_bytes: bytes, block_size: int, + num_layers: int, + hidden_dim_size: int, + dtype_str: str, use_mla: bool, ) -> None: """Register non-CUDA KV layout metadata for CPU context mode. @@ -300,11 +302,29 @@ def register_kv_cache_cpu_context( instance_id: Worker instance identifier (typically PID). model_name: Model name associated with this worker. world_size: Worker world size used in cache keys. - layout_desc_bytes: Pickled :class:`MemoryLayoutDesc`. block_size: Tokens per paged block. + num_layers: Number of model layers. + hidden_dim_size: Flattened hidden dimension per token. + dtype_str: Torch dtype name (for example ``"float16"``). use_mla: Whether the worker KV format is MLA. + + Raises: + ValueError: If ``dtype_str`` is not a valid torch dtype name. """ - layout_desc = pickle.loads(layout_desc_bytes) + dtype = getattr(torch, dtype_str, None) + if dtype is None or not isinstance(dtype, torch.dtype): + raise ValueError( + f"Invalid dtype_str '{dtype_str}': must be a valid torch dtype " + "attribute name (e.g. 'float16' for torch.float16, " + "'bfloat16' for torch.bfloat16, 'float32' for torch.float32)." + ) + + shape = ( + torch.Size([num_layers, self.chunk_size, hidden_dim_size]) + if use_mla + else torch.Size([2, num_layers, self.chunk_size, hidden_dim_size]) + ) + layout_desc = MemoryLayoutDesc(shapes=[shape], dtypes=[dtype]) self.cpu_contexts[instance_id] = CPUContextMetadata( layout_desc=layout_desc, block_size=block_size, diff --git a/lmcache/v1/multiprocess/transfer_context.py b/lmcache/v1/multiprocess/transfer_context.py index 94f42e338f0..89047a3983d 100644 --- a/lmcache/v1/multiprocess/transfer_context.py +++ b/lmcache/v1/multiprocess/transfer_context.py @@ -4,7 +4,6 @@ # Standard from abc import ABC, abstractmethod from typing import Any, Callable, Protocol -import pickle # Third Party import torch @@ -281,8 +280,10 @@ def register( instance_id, model_name, world_size, - pickle.dumps(layout_desc), block_size, + num_layers, + hidden_dim_size, + dtype_str, use_mla_flag, ], ) diff --git a/tests/v1/multiprocess/test_cpu_context.py b/tests/v1/multiprocess/test_cpu_context.py index ff38859e3c5..d48189e06c0 100644 --- a/tests/v1/multiprocess/test_cpu_context.py +++ b/tests/v1/multiprocess/test_cpu_context.py @@ -10,10 +10,6 @@ import pytest import torch -# First Party -from lmcache.v1.distributed.api import MemoryLayoutDesc - - def _make_kv_caches( num_layers: int = 2, num_blocks: int = 6, @@ -356,16 +352,14 @@ def test_server_register_and_find_cpu_context_layout( patch("lmcache.v1.multiprocess.server.get_event_bus"), ): engine = MPCacheEngine(storage_manager_config=MagicMock(), chunk_size=16) - expected_layout_desc = MemoryLayoutDesc( - shapes=[torch.Size([2, 2, 16, 16])], - dtypes=[torch.float32], - ) engine.register_kv_cache_cpu_context( instance_id=1, model_name="m", world_size=1, - layout_desc_bytes=pickle.dumps(expected_layout_desc), block_size=4, + num_layers=2, + hidden_dim_size=16, + dtype_str="float32", use_mla=False, ) @@ -409,16 +403,14 @@ def _read_prefetched_results(_keys: Any) -> Any: session_cls.return_value.get_or_create.return_value = mock_session engine = MPCacheEngine(storage_manager_config=MagicMock(), chunk_size=8) - layout_desc = MemoryLayoutDesc( - shapes=[torch.Size([2, 2, 8, 16])], - dtypes=[torch.float32], - ) engine.register_kv_cache_cpu_context( instance_id=2, model_name="m", world_size=1, - layout_desc_bytes=pickle.dumps(layout_desc), block_size=4, + num_layers=2, + hidden_dim_size=16, + dtype_str="float32", use_mla=False, ) payload = torch.ones(2, 2, 8, 16) diff --git a/tests/v1/test_vllm_mp_adapter.py b/tests/v1/test_vllm_mp_adapter.py index ef7d7c0b93b..85d31f1d69b 100644 --- a/tests/v1/test_vllm_mp_adapter.py +++ b/tests/v1/test_vllm_mp_adapter.py @@ -120,8 +120,7 @@ def test_register_kv_caches_cpu_submits_cpu_context_registration( assert send_mock.call_count == 1 args, _kwargs = send_mock.call_args assert args[1] == RequestType.REGISTER_KV_CACHE_CPU_CONTEXT - assert len(args[2]) == 6 - assert isinstance(args[2][3], bytes) + assert len(args[2]) == 8 def test_submit_store_request_tracks_returned_future(fake_adapter, monkeypatch): From fbba72eaabeb2639b769a86c1ea174674e963082 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 13 May 2026 07:29:34 +0000 Subject: [PATCH 5/6] Sync CPU context design doc with scalar registration and future flow Agent-Logs-Url: https://github.com/hlin99/LMCache/sessions/11d24763-3f7b-4dda-946b-e0c7b35c73d2 Co-authored-by: hlin99 <73271530+hlin99@users.noreply.github.com> --- .../v1/multiprocess/cpu_context_design.md | 55 +++++++++++++------ 1 file changed, 37 insertions(+), 18 deletions(-) diff --git a/docs/design/v1/multiprocess/cpu_context_design.md b/docs/design/v1/multiprocess/cpu_context_design.md index b74d27185b2..128ce843eb7 100644 --- a/docs/design/v1/multiprocess/cpu_context_design.md +++ b/docs/design/v1/multiprocess/cpu_context_design.md @@ -36,6 +36,21 @@ cpu context design): These are registered in the MP server dispatch and have corresponding payload/response contracts in the multiprocess protocol definitions. +Current registration payload for `REGISTER_KV_CACHE_CPU_CONTEXT` is: + +```python +[ + instance_id, # int + model_name, # str + world_size, # int + block_size, # int + num_layers, # int + hidden_dim_size, # int + dtype_str, # str + use_mla, # bool +] +``` + ## File structure ``` @@ -138,7 +153,11 @@ and calls: - `transfer_ctx.register(...)` - `transfer_ctx.submit_store(...)` - `transfer_ctx.submit_retrieve(...)` -- `transfer_ctx.poll_finished()` (healthy) or `transfer_ctx.drain_all()` (unhealthy) +- `transfer_ctx.close()` + +The adapter owns request completion tracking via +`self.store_futures` / `self.retrieve_futures`. It polls each future through +`query()` / `result()` in `get_finished()`. ### Store path (non-CUDA) @@ -147,11 +166,11 @@ and calls: cpu_chunks = gather_paged_kv_to_cpu(kv_caches, block_ids, blocks_in_chunk, ...) handle = self._cpu_context.prepare_store(key, instance_id, cpu_chunks) ok = self._cpu_context.commit_store(handle) # synchronous; blocks for server ack -self._store_done[request_id] = ok +future = MessagingFuture() +future.set_result(ok) +return future ``` -`CPUTransferContext.poll_finished()` drains `_store_done` on each call. - ### Retrieve path (non-CUDA) ```python @@ -165,15 +184,15 @@ if chunks is not None: except (RuntimeError, ValueError, TypeError, IndexError): ok = False self._cpu_context.commit_retrieve(handle) -self._retrieve_done[request_id] = (ok, block_ids) +future = MessagingFuture() +future.set_result(ok) +return future ``` - -`CPUTransferContext.poll_finished()` drains `_retrieve_done` on each call. The adapter passes `op.skip_first_n_tokens` into `transfer_ctx.submit_retrieve(..., skip_first_n_tokens=...)`. The retrieve is **synchronous inside `CPUTransferContext.submit_retrieve`**; -`poll_finished()` just drains request ids recorded by submit methods. +the returned future is already resolved by the time the method returns. ## Server integration @@ -211,7 +230,7 @@ Additional integration points: | | v v CudaTransferContext.register() CPUTransferContext.register() - REGISTER_KV_CACHE (CUDA IPC) REGISTER_KV_CACHE_CPU_CONTEXT (CPU metadata) + REGISTER_KV_CACHE (CUDA IPC) REGISTER_KV_CACHE_CPU_CONTEXT (scalar metadata) | + create_cpu_context() +----------------+----------------+ | @@ -224,10 +243,10 @@ Additional integration points: transfer_ctx.submit_store() transfer_ctx.submit_store() | | v v - STORE (GPU -> L1) gather_paged_kv_to_cpu() - | + _cpu_context.prepare_store() - v + _cpu_context.commit_store() [sync] - [READY] _store_done[id] = ok + STORE (GPU -> L1) gather_paged_kv_to_cpu() + | + _cpu_context.prepare_store() + v + _cpu_context.commit_store() [sync] + [READY] return resolved future | | +----------------+----------------+ | @@ -237,10 +256,10 @@ Additional integration points: +----------------+----------------+ | | v v - RETRIEVE (L1 -> GPU) _cpu_context.prepare_retrieve() [sync] - [async future] + scatter_cpu_to_paged_kv() - + _cpu_context.commit_retrieve() - _retrieve_done[id] = (ok, block_ids) + RETRIEVE (L1 -> GPU) _cpu_context.prepare_retrieve() [sync] + [async future] + scatter_cpu_to_paged_kv() + + _cpu_context.commit_retrieve() + return resolved future | | +----------------+----------------+ | @@ -273,7 +292,7 @@ back to pickle when SHM is unavailable. `tests/v1/multiprocess/test_cpu_context.py` covers: -- CPU wrapper behavior (`wrap_kv_caches` with cpu context mode) +- CPU wrapper behavior (`wrap_kv_caches`) - NHD and MLA gather/scatter round-trip - HND round-trip for both HND formats - `skip_first_n_tokens` behavior From 95a009cf17520733ad2b9d6a97abc236131fb447 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 13 May 2026 07:47:26 +0000 Subject: [PATCH 6/6] Streamline CPU context design doc to high-level guidance Agent-Logs-Url: https://github.com/hlin99/LMCache/sessions/c045a593-284a-42f6-b68f-83ebe9bfad8b Co-authored-by: hlin99 <73271530+hlin99@users.noreply.github.com> --- .../v1/multiprocess/cpu_context_design.md | 335 ++++-------------- 1 file changed, 63 insertions(+), 272 deletions(-) diff --git a/docs/design/v1/multiprocess/cpu_context_design.md b/docs/design/v1/multiprocess/cpu_context_design.md index 128ce843eb7..2e15342c7ae 100644 --- a/docs/design/v1/multiprocess/cpu_context_design.md +++ b/docs/design/v1/multiprocess/cpu_context_design.md @@ -2,307 +2,98 @@ ## Scope -This document describes the non-CUDA CPU-based KV transfer path for LMCache -multiprocess mode. +This document describes the high-level design of the non-CUDA KV transfer path +for LMCache multiprocess mode. -The goal is to support KV transfer for non-CUDA devices (for example CPU, -XPU, HPU) without changing the existing CUDA IPC path, while providing a -clean abstraction layer that makes it easy to add alternative transport -mechanisms (e.g. shared memory) in a future PR. +The purpose of this path is to support KV transfer on non-CUDA devices (for +example CPU, XPU, and HPU) while preserving existing CUDA IPC behavior. ## Why this path exists -The CUDA path uses IPC wrappers around GPU tensors and the existing -`REGISTER_KV_CACHE` / `STORE` / `RETRIEVE` request flow. +CUDA IPC is only available for CUDA tensors. For non-CUDA tensors, workers use +a CPU-context path that: -For non-CUDA tensors, CUDA IPC is not available. The CPU context path -provides a generic protocol where workers: +1. gathers KV blocks into CPU chunks, +2. transfers those chunks to/from the server through `CPUContext`, +3. scatters retrieved chunks back to worker KV tensors. -1. Gather KV blocks into CPU chunk(or memory obj) tensors. -2. Transport those CPU chunks to the server storage through a concrete - `CPUContext` implementation. -3. Retrieve CPU chunks(or memory obj) from the server and scatter them back into device KV - tensors. +## Protocol overview -## Protocol additions - -Three request types are used for non-CUDA mode (unchanged from the original -cpu context design): +Non-CUDA mode adds three request types: - `REGISTER_KV_CACHE_CPU_CONTEXT` - `STORE_CPU_CHUNKS` - `RETRIEVE_CPU_CHUNKS` -These are registered in the MP server dispatch and have corresponding -payload/response contracts in the multiprocess protocol definitions. - -Current registration payload for `REGISTER_KV_CACHE_CPU_CONTEXT` is: - -```python -[ - instance_id, # int - model_name, # str - world_size, # int - block_size, # int - num_layers, # int - hidden_dim_size, # int - dtype_str, # str - use_mla, # bool -] -``` - -## File structure - -``` -lmcache/v1/multiprocess/ -├── cpu_context.py # CPUContextMetadata, CPUContext(ABC), factory, gather/scatter utils -└── cpu_context_pickle.py # CPUContextPickle — pickle-based concrete implementation -``` - -### `cpu_context.py` - -Provides: - -- **`CPUContextMetadata`** dataclass — layout metadata: - - ```python - @dataclass - class CPUContextMetadata: - layout_desc: MemoryLayoutDesc - block_size: int - use_mla: bool - ``` - -- **`CPUContext(ABC)`** — abstract base class with `mq_client` as a common - dependency. All concrete implementations share the same two-phase - `prepare/commit` interface: - - ```python - class CPUContext(ABC): - def __init__(self, metadata: CPUContextMetadata, mq_client, mq_timeout: float): ... - - @abstractmethod - def prepare_store(self, key, instance_id, chunks: list[torch.Tensor]) -> Any: ... - @abstractmethod - def commit_store(self, handle: Any) -> bool: ... - @abstractmethod - def prepare_retrieve(self, key, instance_id) -> tuple[Any, list[torch.Tensor] | None]: ... - @abstractmethod - def commit_retrieve(self, handle: Any) -> None: ... - @abstractmethod - def close(self) -> None: ... - ``` - -- **`create_cpu_context()`** factory — currently always returns a - `CPUContextPickle` instance; a future SHM-capable PR can extend this to - probe for shared-memory availability and fall back to pickle. - -- **Shared utility functions** used by all concrete implementations: - - `compute_kv_layout` — extract block size, layer count, hidden dim and - dtype from live KV tensors. - - `gather_paged_kv_to_cpu` — gather paged KV blocks into a list of CPU - tensors (one per LMCache chunk). - - `scatter_cpu_to_paged_kv` — scatter CPU chunk tensors back into paged - KV tensors. - -### `cpu_context_pickle.py` - -Provides **`CPUContextPickle(CPUContext)`**: - -| Phase | What happens | -|---|---| -| `prepare_store` | `pickle.dumps(chunks)` → returns `(key, instance_id, bytes)` as opaque handle | -| `commit_store` | sends `STORE_CPU_CHUNKS` via `mq_client`, blocks for server ack, returns `bool` | -| `prepare_retrieve` | sends `RETRIEVE_CPU_CHUNKS` via `mq_client`, blocks for response, `pickle.loads` → returns `(None, chunks)` or `(None, None)` on miss | -| `commit_retrieve` | no-op (pickle path holds no server-side locks) | -| `close` | no-op | - -## Tensor/chunk contracts +CPU-context registration uses scalar metadata (for example: `instance_id`, +`model_name`, `world_size`, `block_size`, `num_layers`, `hidden_dim_size`, +`dtype_str`, and `use_mla`) so server-side layout can be reconstructed without +transmitting pickled layout objects, reducing serialization coupling and +allowing server-side validation from explicit fields. -Chunk formats are unchanged: +## Main components -- non-MLA: `[2, num_layers, chunk_tokens, hidden_dim]` -- MLA: `[num_layers, chunk_tokens, hidden_dim]` +- `cpu_context.py` + - defines `CPUContextMetadata`, the `CPUContext` abstraction, and shared + gather/scatter helpers. +- `cpu_context_pickle.py` + - current concrete `CPUContext` implementation. +- `transfer_context.py` + - dispatches between CUDA and CPU transfer paths. +- `vllm_multi_process_adapter.py` + - owns request lifecycle and future polling. +- `server.py` + - stores per-instance CPU metadata and handles CPU chunk store/retrieve + requests. -Internal gather/scatter uses block-level indexing to avoid token-level slot -expansion and token-wise select/copy operations. +## Worker-side behavior -## Layout handling +`create_transfer_context(kv_caches)` selects transport by device type: -Supported KV formats in CPU gather/scatter: +- CUDA tensors -> `CudaTransferContext` +- non-CUDA tensors -> `CPUTransferContext` -- `NL_X_TWO_NB_BS_NH_HS` (NHD) -- `NL_X_NB_TWO_BS_NH_HS` (NHD flashinfer) -- `NL_X_TWO_NB_NH_BS_HS` (HND) -- `NL_X_NB_TWO_NH_BS_HS` (HND flashinfer) -- `NL_X_NB_BS_HS` (MLA) +The adapter keeps ownership of request completion tracking via +`store_futures` and `retrieve_futures`. -## Worker adapter integration +For CPU mode, store/retrieve execution is synchronous inside +`CPUTransferContext` (gather/scatter plus MQ interaction), and the transfer +methods return resolved futures so adapter-side completion flow stays uniform +across CUDA and non-CUDA modes. Here, "resolved futures" means the futures are +already completed when returned (no background async work pending in the CPU +path). -The adapter now delegates transport behavior to -`lmcache/v1/multiprocess/transfer_context.py`. +## Server-side behavior -`create_transfer_context(kv_caches)` centralizes device dispatch: +`MPCacheEngine` maintains CPU-context metadata per worker instance and uses that +metadata to resolve layout for CPU chunk writes/reads. -- all CUDA → existing CUDA IPC registration and store/retrieve path -- all non-CUDA → `CPUTransferContext` with cpu context registration and CPU context store/retrieve path +Server handlers: -`LMCacheMPSchedulerAdapter` now holds `self.transfer_ctx: TransferContext | None` -and calls: +- register CPU-context metadata, +- store worker-provided CPU chunks into storage, +- retrieve CPU chunks from storage and return them to workers. -- `transfer_ctx.register(...)` -- `transfer_ctx.submit_store(...)` -- `transfer_ctx.submit_retrieve(...)` -- `transfer_ctx.close()` +Cleanup removes CPU-context state on unregister. -The adapter owns request completion tracking via -`self.store_futures` / `self.retrieve_futures`. It polls each future through -`query()` / `result()` in `get_finished()`. +## Format and compatibility notes -### Store path (non-CUDA) - -```python -# CPUTransferContext.submit_store -cpu_chunks = gather_paged_kv_to_cpu(kv_caches, block_ids, blocks_in_chunk, ...) -handle = self._cpu_context.prepare_store(key, instance_id, cpu_chunks) -ok = self._cpu_context.commit_store(handle) # synchronous; blocks for server ack -future = MessagingFuture() -future.set_result(ok) -return future -``` - -### Retrieve path (non-CUDA) - -```python -# CPUTransferContext.submit_retrieve -handle, chunks = self._cpu_context.prepare_retrieve(key, instance_id) # synchronous -ok = chunks is not None -if chunks is not None: - try: - scatter_cpu_to_paged_kv(kv_caches, block_ids, chunks, blocks_in_chunk, - skip_first_n_tokens=skip_first_n_tokens, ...) - except (RuntimeError, ValueError, TypeError, IndexError): - ok = False -self._cpu_context.commit_retrieve(handle) -future = MessagingFuture() -future.set_result(ok) -return future -``` -The adapter passes `op.skip_first_n_tokens` into -`transfer_ctx.submit_retrieve(..., skip_first_n_tokens=...)`. - -The retrieve is **synchronous inside `CPUTransferContext.submit_retrieve`**; -the returned future is already resolved by the time the method returns. - -## Server integration - -`MPCacheEngine` holds: - -- `cpu_contexts: dict[int, CPUContextMetadata]` — per-instance metadata. -- `cpu_context_meta: dict[int, tuple[str, int]]` — per-instance - `(model_name, world_size)` for layout resolution. - -Server-side handler methods are unchanged: -- `register_kv_cache_cpu_context` — stores `CPUContextMetadata` in `cpu_contexts`. -- `store_cpu_chunks` — unpickles payload, copies tensors into storage. -- `retrieve_cpu_chunks` — reads from storage, pickles tensors, returns bytes. - -Additional integration points: - -- Unregister cleanup removes both `cpu_contexts` and `cpu_context_meta`. -- Layout lookup via `_find_layout_desc` resolves both GPU and CPU context - registrations. -- Status reporting (`report_status`) includes `registered_cpu_instance_ids` - and `cpu_context_meta`. - -## CUDA vs non-CUDA state machine - -```text - register_kv_caches() - | - v - create_transfer_context(kv_caches) - | - +----------------+----------------+ - | | - v v - [device == cuda] [device != cuda] - | | - v v - CudaTransferContext.register() CPUTransferContext.register() - REGISTER_KV_CACHE (CUDA IPC) REGISTER_KV_CACHE_CPU_CONTEXT (scalar metadata) - | + create_cpu_context() - +----------------+----------------+ - | - v - [READY / SERVING] - | - +----------------+----------------+ - | | - v v - transfer_ctx.submit_store() transfer_ctx.submit_store() - | | - v v - STORE (GPU -> L1) gather_paged_kv_to_cpu() - | + _cpu_context.prepare_store() - v + _cpu_context.commit_store() [sync] - [READY] return resolved future - | | - +----------------+----------------+ - | - v - transfer_ctx.submit_retrieve() + get_finished() - | - +----------------+----------------+ - | | - v v - RETRIEVE (L1 -> GPU) _cpu_context.prepare_retrieve() [sync] - [async future] + scatter_cpu_to_paged_kv() - + _cpu_context.commit_retrieve() - return resolved future - | | - +----------------+----------------+ - | - v - [READY / SERVING] - | - v - unregister_kv_cache() - | - v - [TERMINATED] -``` - -## Future extension: CPUContextShm - -The `CPUContext` base class is designed to accommodate a shared-memory -implementation in a future PR with minimal changes: - -| Phase | Pickle | SHM (future) | -|---|---|---| -| `prepare_store` | `pickle.dumps` | MQ `PREPARE_STORE` → slot metadata → memcpy | -| `commit_store` | MQ `STORE_CPU_CHUNKS` | MQ `COMMIT_STORE` | -| `prepare_retrieve` | MQ `RETRIEVE_CPU_CHUNKS` + `pickle.loads` | MQ `PREPARE_RETRIEVE` → tensor views from SHM | -| `commit_retrieve` | no-op | MQ `FINISH_READ` (release read lock) | - -The `create_cpu_context()` factory will probe for SHM availability and fall -back to pickle when SHM is unavailable. +- Chunk tensor layout remains consistent with gather/scatter contracts: + non-MLA chunks are 4D (`[2, num_layers, chunk_tokens, hidden_dim]`) and MLA + chunks are 3D (`[num_layers, chunk_tokens, hidden_dim]`). +- Existing CUDA IPC semantics are unchanged. +- CPU-context logic remains isolated from shared GPU connector utilities. ## Validation coverage -`tests/v1/multiprocess/test_cpu_context.py` covers: - -- CPU wrapper behavior (`wrap_kv_caches`) -- NHD and MLA gather/scatter round-trip -- HND round-trip for both HND formats -- `skip_first_n_tokens` behavior -- Server-side register/store/retrieve flow +Tests cover: -`tests/v1/test_vllm_mp_adapter.py` covers transfer-context integration, -including CPU registration path (`REGISTER_KV_CACHE_CPU_CONTEXT`) and -store/retrieve submit delegation. +- CPU gather/scatter correctness across supported layouts, +- CPU registration and server store/retrieve flow, +- adapter integration with transfer-context submit/get-finished behavior. -## Non-goals +## Future extension -- No change to existing CUDA IPC path semantics. -- No CPU-specific logic added to shared `gpu_connector/utils.py`. +The `CPUContext` abstraction is designed to support additional transports +(e.g. shared-memory-based implementations) with minimal adapter/server flow +changes.