From 99119a283cdeabb1a501624810d3bdb77d60910a Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Tue, 12 May 2026 06:11:08 +0000 Subject: [PATCH 01/23] feat(mp): CPU Context by pickle Signed-off-by: Tony Lin --- .../v1/multiprocess/cpu_context_design.md | 273 +++++++++++ .../vllm/vllm_multi_process_adapter.py | 226 +++++++-- lmcache/non_cuda_equivalents.py | 12 +- lmcache/v1/multiprocess/cpu_context.py | 433 ++++++++++++++++++ lmcache/v1/multiprocess/cpu_context_pickle.py | 121 +++++ lmcache/v1/multiprocess/protocols/base.py | 3 + lmcache/v1/multiprocess/protocols/engine.py | 29 ++ lmcache/v1/multiprocess/server.py | 216 ++++++++- tests/v1/multiprocess/test_cpu_context.py | 433 ++++++++++++++++++ 9 files changed, 1701 insertions(+), 45 deletions(-) create mode 100644 docs/design/v1/multiprocess/cpu_context_design.md create mode 100644 lmcache/v1/multiprocess/cpu_context.py create mode 100644 lmcache/v1/multiprocess/cpu_context_pickle.py create mode 100644 tests/v1/multiprocess/test_cpu_context.py diff --git a/docs/design/v1/multiprocess/cpu_context_design.md b/docs/design/v1/multiprocess/cpu_context_design.md new file mode 100644 index 00000000000..0f83205b411 --- /dev/null +++ b/docs/design/v1/multiprocess/cpu_context_design.md @@ -0,0 +1,273 @@ +# CPU Context Design (MP mode, non-CUDA) + +## Scope + +This document describes the non-CUDA CPU-based KV transfer path for LMCache +multiprocess mode. + +The goal is to support KV transfer for non-CUDA devices (for example CPU, +XPU, HPU) without changing the existing CUDA IPC path, while providing a +clean abstraction layer that makes it easy to add alternative transport +mechanisms (e.g. shared memory) in a future PR. + +## Why this path exists + +The CUDA path uses IPC wrappers around GPU tensors and the existing +`REGISTER_KV_CACHE` / `STORE` / `RETRIEVE` request flow. + +For non-CUDA tensors, CUDA IPC is not available. The CPU context path +provides a generic protocol where workers: + +1. Gather KV blocks into CPU chunk tensors. +2. Transport those CPU chunks to the server storage through a concrete + `CPUContext` implementation. +3. Retrieve CPU chunks from the server and scatter them back into device KV + tensors. + +## Protocol additions + +Three request types are used for non-CUDA mode (unchanged from the original +bounce-buffer design): + +- `REGISTER_KV_CACHE_BOUNCE` +- `STORE_CPU_CHUNKS` +- `RETRIEVE_CPU_CHUNKS` + +These are registered in the MP server dispatch and have corresponding +payload/response contracts in the multiprocess protocol definitions. + +## File structure + +``` +lmcache/v1/multiprocess/ +├── cpu_context.py # CPUContextMetadata, CPUContext(ABC), factory, gather/scatter utils +└── cpu_context_pickle.py # CPUContextPickle — pickle-based concrete implementation +``` + +### `cpu_context.py` + +Provides: + +- **`CPUContextMetadata`** dataclass — layout metadata (replaces the old + `CPUBounceContext` dataclass): + + ```python + @dataclass + class CPUContextMetadata: + layout_desc: MemoryLayoutDesc + block_size: int + use_mla: bool + ``` + +- **`CPUContext(ABC)`** — abstract base class with `mq_client` as a common + dependency. All concrete implementations share the same two-phase + `prepare/commit` interface: + + ```python + class CPUContext(ABC): + def __init__(self, metadata: CPUContextMetadata, mq_client, mq_timeout: float): ... + + @abstractmethod + def prepare_store(self, key, instance_id, chunks: list[torch.Tensor]) -> Any: ... + @abstractmethod + def commit_store(self, handle: Any) -> bool: ... + @abstractmethod + def prepare_retrieve(self, key, instance_id) -> tuple[Any, list[torch.Tensor] | None]: ... + @abstractmethod + def commit_retrieve(self, handle: Any) -> None: ... + @abstractmethod + def close(self) -> None: ... + ``` + +- **`create_cpu_context()`** factory — currently always returns a + `CPUContextPickle` instance; a future SHM-capable PR can extend this to + probe for shared-memory availability and fall back to pickle. + +- **Shared utility functions** used by all concrete implementations: + - `compute_kv_layout` — extract block size, layer count, hidden dim and + dtype from live KV tensors. + - `gather_chunks_to_cpu` — gather paged KV blocks into a list of CPU + tensors (one per LMCache chunk). + - `scatter_cpu_chunks_to_kv` — scatter CPU chunk tensors back into paged + KV tensors. + +### `cpu_context_pickle.py` + +Provides **`CPUContextPickle(CPUContext)`**: + +| Phase | What happens | +|---|---| +| `prepare_store` | `pickle.dumps(chunks)` → returns `(key, instance_id, bytes)` as opaque handle | +| `commit_store` | sends `STORE_CPU_CHUNKS` via `mq_client`, blocks for server ack, returns `bool` | +| `prepare_retrieve` | sends `RETRIEVE_CPU_CHUNKS` via `mq_client`, blocks for response, `pickle.loads` → returns `(None, chunks)` or `(None, None)` on miss | +| `commit_retrieve` | no-op (pickle path holds no server-side locks) | +| `close` | no-op | + +## Tensor/chunk contracts + +Chunk formats are unchanged: + +- non-MLA: `[2, num_layers, chunk_tokens, hidden_dim]` +- MLA: `[num_layers, chunk_tokens, hidden_dim]` + +Internal gather/scatter uses block-level indexing to avoid token-level slot +expansion and token-wise select/copy operations. + +## Layout handling + +Supported KV formats in CPU gather/scatter: + +- `NL_X_TWO_NB_BS_NH_HS` (NHD) +- `NL_X_NB_TWO_BS_NH_HS` (NHD flashinfer) +- `NL_X_TWO_NB_NH_BS_HS` (HND) +- `NL_X_NB_TWO_NH_BS_HS` (HND flashinfer) +- `NL_X_NB_BS_HS` (MLA) + +## Worker adapter integration + +`lmcache/integration/vllm/vllm_multi_process_adapter.py` chooses the path +by tensor `device.type`: + +- all CUDA → existing CUDA IPC registration and store/retrieve path +- all non-CUDA → bounce registration and CPU context store/retrieve path + +The adapter holds a `cpu_context: CPUContext` instance and uses the uniform +`prepare/commit` interface for both store and retrieve. + +### Store path (non-CUDA) + +```python +# submit_store_request +cpu_chunks = gather_chunks_to_cpu(kv_caches, block_ids, blocks_in_chunk, ...) +handle = self.cpu_context.prepare_store(key, instance_id, cpu_chunks) +ok = self.cpu_context.commit_store(handle) # synchronous; blocks for server ack +self._cpu_store_done[request_id] = ok +``` + +`get_finished` drains `_cpu_store_done` on each call. + +### Retrieve path (non-CUDA) + +```python +# submit_retrieve_request +handle, chunks = self.cpu_context.prepare_retrieve(key, instance_id) # synchronous +if chunks is not None: + scatter_cpu_chunks_to_kv(kv_caches, block_ids, chunks, blocks_in_chunk, + skip_first_n_tokens=op.skip_first_n_tokens, ...) +self.cpu_context.commit_retrieve(handle) +self._cpu_retrieve_done[request_id] = (chunks is not None, block_ids) +``` + +`get_finished` drains `_cpu_retrieve_done` on each call. + +The retrieve is now **synchronous in `submit_retrieve_request`**; there is no +separate future to poll. This simplifies `get_finished` which no longer +needs a `if self._use_cpu_context:` branch for retrieve futures. + +## Server integration + +`MPCacheEngine` holds: + +- `cpu_contexts: dict[int, CPUContextMetadata]` — per-instance metadata. +- `cpu_context_meta: dict[int, tuple[str, int]]` — per-instance + `(model_name, world_size)` for layout resolution. + +Server-side handler methods are unchanged: +- `register_kv_cache_bounce` — stores `CPUContextMetadata` in `cpu_contexts`. +- `store_cpu_chunks` — unpickles payload, copies tensors into storage. +- `retrieve_cpu_chunks` — reads from storage, pickles tensors, returns bytes. + +Additional integration points: + +- Unregister cleanup removes both `cpu_contexts` and `cpu_context_meta`. +- Layout lookup via `_find_layout_desc` resolves both GPU and CPU context + registrations. +- Status reporting (`report_status`) includes `registered_cpu_instance_ids` + and `cpu_context_meta`. + +## CUDA vs non-CUDA state machine + +```text + register_kv_caches() + | + v + [Inspect device.type] + | + +----------------+----------------+ + | | + v v + [device == cuda] [device != cuda] + | | + v v + REGISTER_KV_CACHE (CUDA IPC) REGISTER_KV_CACHE_BOUNCE (CPU metadata) + | + create_cpu_context() + +----------------+----------------+ + | + v + [READY / SERVING] + | + +----------------+----------------+ + | | + v v + submit_store() submit_store() + | | + v v + STORE (GPU -> L1) gather_chunks_to_cpu() + | + cpu_context.prepare_store() + v + cpu_context.commit_store() [sync] + [READY] _cpu_store_done[id] = ok + | | + +----------------+----------------+ + | + v + submit_retrieve() + get_finished() + | + +----------------+----------------+ + | | + v v + RETRIEVE (L1 -> GPU) cpu_context.prepare_retrieve() [sync] + [async future] + scatter_cpu_chunks_to_kv() + + cpu_context.commit_retrieve() + _cpu_retrieve_done[id] = (ok, block_ids) + | | + +----------------+----------------+ + | + v + [READY / SERVING] + | + v + unregister_kv_cache() + | + v + [TERMINATED] +``` + +## Future extension: CPUContextShm + +The `CPUContext` base class is designed to accommodate a shared-memory +implementation in a future PR with minimal changes: + +| Phase | Pickle | SHM (future) | +|---|---|---| +| `prepare_store` | `pickle.dumps` | MQ `PREPARE_STORE` → slot metadata → memcpy | +| `commit_store` | MQ `STORE_CPU_CHUNKS` | MQ `COMMIT_STORE` | +| `prepare_retrieve` | MQ `RETRIEVE_CPU_CHUNKS` + `pickle.loads` | MQ `PREPARE_RETRIEVE` → tensor views from SHM | +| `commit_retrieve` | no-op | MQ `FINISH_READ` (release read lock) | + +The `create_cpu_context()` factory will probe for SHM availability and fall +back to pickle when SHM is unavailable. + +## Validation coverage + +`tests/v1/multiprocess/test_cpu_context.py` covers: + +- CPU wrapper behavior (`wrap_kv_caches` with bounce mode) +- NHD and MLA gather/scatter round-trip +- HND round-trip for both HND formats +- `skip_first_n_tokens` behavior +- Server-side register/store/retrieve flow + +## Non-goals + +- No change to existing CUDA IPC path semantics. +- No CPU-specific logic added to shared `gpu_connector/utils.py`. diff --git a/lmcache/integration/vllm/vllm_multi_process_adapter.py b/lmcache/integration/vllm/vllm_multi_process_adapter.py index 790adabd431..f0ed611175e 100644 --- a/lmcache/integration/vllm/vllm_multi_process_adapter.py +++ b/lmcache/integration/vllm/vllm_multi_process_adapter.py @@ -11,8 +11,15 @@ import zmq # First Party +from lmcache import torch_dev from lmcache.integration.request_telemetry.factory import RequestTelemetryFactory from lmcache.utils import EngineType, _lmcache_nvtx_annotate, init_logger +from lmcache.v1.multiprocess.cpu_context import ( + CPUContext, + compute_kv_layout, + gather_chunks_to_cpu, + scatter_cpu_chunks_to_kv, +) from lmcache.v1.multiprocess.custom_types import ( BlockAllocationRecord, CudaIPCWrapper, @@ -32,7 +39,11 @@ DEFAULT_HEARTBEAT_INTERVAL: float = 10.0 -def wrap_kv_caches(kv_caches: dict[str, torch.Tensor]) -> KVCache: +def wrap_kv_caches( + kv_caches: dict[str, torch.Tensor], use_cpu_context: bool = False +) -> KVCache: + if use_cpu_context: + return [] logger.info("KV caches keys are %s", list(kv_caches.keys())) return [CudaIPCWrapper(tensor) for tensor in kv_caches.values()] @@ -701,6 +712,17 @@ def __init__( str, tuple[MessagingFuture[RetrieveResult], list[int]] ] = {} + # Non-CUDA (bounce-buffer) mode state + self._use_cpu_context: bool = False + self._device_type: str = "cuda" + # CPU context for non-CUDA (bounce-buffer) mode + self.cpu_context: CPUContext | None = None + self._bounce_layout_hints: Any = None + self._bounce_gpu_kv_format: Any = None + # Completed synchronous CPU store/retrieve results, keyed by request_id + self._cpu_store_done: dict[str, bool] = {} + self._cpu_retrieve_done: dict[str, tuple[bool, list[int]]] = {} + # Block IDs that failed due to retrieve timeout self.error_block_ids: set[int] = set() @@ -824,18 +846,90 @@ def _send_register_kv_caches_request( from lmcache.integration.vllm.utils import vllm_layout_hints layout_hints = vllm_layout_hints() - future = send_lmcache_request( - self.mq_client, - RequestType.REGISTER_KV_CACHE, - [ - self.instance_id, - wrap_kv_caches(kv_caches), - self.model_name, - self.world_size, - EngineType.VLLM, - layout_hints, - ], + self.kv_caches = kv_caches + + if not kv_caches: + raise ValueError("kv_caches is empty") + device_types = {tensor.device.type for tensor in kv_caches.values()} + if len(device_types) != 1: + raise ValueError( + f"All KV cache tensors must share one device type, got {device_types}" + ) + self._device_type = next(iter(device_types)) + self._use_cpu_context = self._device_type != "cuda" + logger.info( + "Registering kv caches (device_type=%s, bounce=%s)", + self._device_type, + self._use_cpu_context, ) + + if self._use_cpu_context: + # First Party + from lmcache.v1.distributed.api import MemoryLayoutDesc + from lmcache.v1.gpu_connector.utils import is_mla + from lmcache.v1.multiprocess.cpu_context import ( + CPUContextMetadata, + create_cpu_context, + ) + + ( + block_size, + num_layers, + hidden_dim_size, + dtype_str, + gpu_kv_format, + ) = compute_kv_layout(kv_caches, layout_hints=layout_hints) + self._bounce_layout_hints = layout_hints + self._bounce_gpu_kv_format = gpu_kv_format + future = send_lmcache_request( + self.mq_client, + RequestType.REGISTER_KV_CACHE_BOUNCE, + [ + self.instance_id, + self.model_name, + self.world_size, + EngineType.VLLM, + layout_hints, + block_size, + num_layers, + hidden_dim_size, + dtype_str, + is_mla(gpu_kv_format), + ], + ) + # Build the layout descriptor so we can construct the CPUContext. + use_mla_flag = is_mla(gpu_kv_format) + shape = ( + torch.Size( + [num_layers, self.blocks_in_chunk * block_size, hidden_dim_size] + ) + if use_mla_flag + else torch.Size( + [2, num_layers, self.blocks_in_chunk * block_size, hidden_dim_size] + ) + ) + dtype = getattr(torch, dtype_str) + metadata = CPUContextMetadata( + layout_desc=MemoryLayoutDesc(shapes=[shape], dtypes=[dtype]), + block_size=block_size, + use_mla=use_mla_flag, + ) + self.cpu_context = create_cpu_context( + metadata, self.mq_client, self._mq_timeout + ) + else: + future = send_lmcache_request( + self.mq_client, + RequestType.REGISTER_KV_CACHE, + [ + self.instance_id, + wrap_kv_caches(kv_caches), + self.model_name, + self.world_size, + EngineType.VLLM, + layout_hints, + ], + ) try: future.result(timeout=self._mq_timeout) except TimeoutError: @@ -927,12 +1021,26 @@ def submit_store_request( request_id=request_id, cache_salt=cache_salt, ) - future = send_lmcache_request( - self.mq_client, - RequestType.STORE, - [key, self.instance_id, op.block_ids, event.ipc_handle()], - ).to_cuda_future() - self.store_futures[request_id] = future + if self._use_cpu_context: + assert self.cpu_context is not None + torch_dev.synchronize() + cpu_chunks = gather_chunks_to_cpu( + self.kv_caches, + op.block_ids, + self.blocks_in_chunk, + layout_hints=self._bounce_layout_hints, + gpu_kv_format=self._bounce_gpu_kv_format, + ) + handle = self.cpu_context.prepare_store(key, self.instance_id, cpu_chunks) + ok = self.cpu_context.commit_store(handle) + self._cpu_store_done[request_id] = ok + else: + future = send_lmcache_request( + self.mq_client, + RequestType.STORE, + [key, self.instance_id, op.block_ids, event.ipc_handle()], + ).to_cuda_future() + self.store_futures[request_id] = future @_lmcache_nvtx_annotate def submit_retrieve_request( @@ -966,18 +1074,39 @@ def submit_retrieve_request( request_id=request_id, cache_salt=cache_salt, ) - future = send_lmcache_request( - self.mq_client, - RequestType.RETRIEVE, - [ - key, - self.instance_id, - op.block_ids, - event.ipc_handle(), - op.skip_first_n_tokens, - ], - ).to_cuda_future() - self.retrieve_futures[request_id] = (future, list(op.block_ids)) + if self._use_cpu_context: + assert self.cpu_context is not None + handle, chunks = self.cpu_context.prepare_retrieve(key, self.instance_id) + ok = chunks is not None + if chunks is not None: + try: + scatter_cpu_chunks_to_kv( + self.kv_caches, + op.block_ids, + chunks, + self.blocks_in_chunk, + skip_first_n_tokens=op.skip_first_n_tokens, + layout_hints=self._bounce_layout_hints, + gpu_kv_format=self._bounce_gpu_kv_format, + ) + except Exception: + logger.exception("Failed to scatter retrieved CPU context chunks") + ok = False + self.cpu_context.commit_retrieve(handle) + self._cpu_retrieve_done[request_id] = (ok, list(op.block_ids)) + else: + future = send_lmcache_request( + self.mq_client, + RequestType.RETRIEVE, + [ + key, + self.instance_id, + op.block_ids, + event.ipc_handle(), + op.skip_first_n_tokens, + ], + ).to_cuda_future() + self.retrieve_futures[request_id] = (future, list(op.block_ids)) @_lmcache_nvtx_annotate def batched_submit_store_requests( @@ -1076,7 +1205,9 @@ def get_finished( """ # If unhealthy, drain all pending futures immediately if not self.is_healthy: - finished_stores = set(self.store_futures.keys()) + finished_stores = set(self.store_futures.keys()) | set( + self._cpu_store_done.keys() + ) finished_retrieves = set() for request_id, ( _r_future, @@ -1084,8 +1215,14 @@ def get_finished( ) in self.retrieve_futures.items(): finished_retrieves.add(request_id) self.error_block_ids.update(r_block_ids) + for request_id, (ok, r_block_ids) in self._cpu_retrieve_done.items(): + finished_retrieves.add(request_id) + if not ok: + self.error_block_ids.update(r_block_ids) self.store_futures.clear() self.retrieve_futures.clear() + self._cpu_store_done.clear() + self._cpu_retrieve_done.clear() ret_stores = self._process_finished_stores( finished_stores, finished_req_ids_from_engine @@ -1100,6 +1237,31 @@ def get_finished( finished_stores = set() finished_retrieves = set() + + # Drain completed synchronous CPU store results + for request_id, ok in list(self._cpu_store_done.items()): + finished_stores.add(request_id) + if not ok: + logger.error( + "Something went wrong when processing the " + "store request for request_id=%s", + request_id, + ) + self._cpu_store_done.clear() + + # Drain completed synchronous CPU retrieve results + for request_id, (ok, r_block_ids) in list(self._cpu_retrieve_done.items()): + finished_retrieves.add(request_id) + if not ok: + logger.error( + "Something went wrong when processing the " + "retrieve request for request_id=%s, result=%s", + request_id, + ok, + ) + self.error_block_ids.update(r_block_ids) + self._cpu_retrieve_done.clear() + for request_id, s_future in self.store_futures.items(): if not s_future.query(): continue @@ -1114,7 +1276,7 @@ def get_finished( request_id, ) - for request_id, (r_future, _) in self.retrieve_futures.items(): + for request_id, (r_future, r_block_ids) in self.retrieve_futures.items(): if not r_future.query(): continue diff --git a/lmcache/non_cuda_equivalents.py b/lmcache/non_cuda_equivalents.py index 68146cdfc54..25be5d98e91 100644 --- a/lmcache/non_cuda_equivalents.py +++ b/lmcache/non_cuda_equivalents.py @@ -17,7 +17,7 @@ import torch # First Party -from lmcache import torch_dev, torch_device_type +from lmcache import torch_dev # Store the tensor objects in memory so that they can be accessed # outside the scope of this file @@ -305,10 +305,7 @@ def alloc_pinned_numa_ptr(size: int, numa_id: int = 0) -> int: Note: NUMA node selection is not supported on non-CUDA.""" # Create a 1D uint8 CPU tensor, as uint8 == 1 byte - # On XPU (Intel GPU), PyTorch 2.4+ supports pin_memory=True via SYCL USM - # host allocation, enabling fast DMA for XPU<->CPU transfers. - pin_memory = torch_device_type == "xpu" - tensor = torch.empty(size, dtype=torch.uint8, pin_memory=pin_memory) + tensor = torch.empty(size, dtype=torch.uint8, pin_memory=False) # First-touch initialization (forces physical allocation) tensor.fill_(0) @@ -336,10 +333,7 @@ def alloc_pinned_ptr(size: int, device_id: int = 0) -> int: fast DMA transfers. On other non-CUDA platforms, pinning is not supported.""" # Create a 1D uint8 CPU tensor, as uint8 == 1 byte - # On XPU (Intel GPU), PyTorch 2.4+ supports pin_memory=True via SYCL USM - # host allocation, enabling fast DMA for XPU<->CPU transfers. - pin_memory = torch_device_type == "xpu" - tensor = torch.empty(size, dtype=torch.uint8, pin_memory=pin_memory) + tensor = torch.empty(size, dtype=torch.uint8, pin_memory=False) # First-touch initialization (forces physical allocation) tensor.fill_(0) diff --git a/lmcache/v1/multiprocess/cpu_context.py b/lmcache/v1/multiprocess/cpu_context.py new file mode 100644 index 00000000000..126b8778986 --- /dev/null +++ b/lmcache/v1/multiprocess/cpu_context.py @@ -0,0 +1,433 @@ +# SPDX-License-Identifier: Apache-2.0 +"""CPU context abstractions and utilities for multiprocess mode. + +This module provides: +- ``CPUContextMetadata``: layout metadata dataclass for non-CUDA workers. +- ``CPUContext``: abstract base class with a two-phase prepare/commit interface + for CPU-side KV data transfer. Concrete implementations (e.g. + ``CPUContextPickle``) each decide *how* data is serialised and transported. +- ``create_cpu_context()``: factory that returns the appropriate + ``CPUContext`` subclass (currently always ``CPUContextPickle``). +- ``compute_kv_layout``, ``gather_chunks_to_cpu``, ``scatter_cpu_chunks_to_kv``: + shared gather/scatter utilities used by all concrete implementations. +""" + +# Standard +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, cast + +# Third Party +import torch + +# First Party +from lmcache.utils import EngineType +from lmcache.v1.distributed.api import MemoryLayoutDesc + + +@dataclass +class CPUContextMetadata: + """CPU context layout metadata for non-CUDA workers. + + Attributes: + layout_desc: Memory layout descriptor used to interpret chunk payloads. + block_size: Number of tokens per paged block. + use_mla: Whether the worker KV format is MLA. + """ + + layout_desc: MemoryLayoutDesc + block_size: int + use_mla: bool + + +class CPUContext(ABC): + """Abstract base class for CPU-side KV data transfer contexts. + + All concrete implementations share a common message-queue client and + expose a uniform two-phase ``prepare/commit`` interface so that the + worker adapter is implementation-agnostic. + + Args: + metadata: Layout metadata describing the chunk format. + mq_client: Message-queue client used for server communication. + mq_timeout: Timeout in seconds for blocking MQ requests. + """ + + def __init__( + self, + metadata: CPUContextMetadata, + mq_client: Any, + mq_timeout: float, + ) -> None: + self.metadata = metadata + self.mq_client = mq_client + self.mq_timeout = mq_timeout + + @property + def layout_desc(self) -> MemoryLayoutDesc: + """The memory layout descriptor for this context.""" + return self.metadata.layout_desc + + @abstractmethod + def prepare_store( + self, key: Any, instance_id: int, chunks: list[torch.Tensor] + ) -> Any: + """Prepare a store operation. + + Args: + key: Cache key for the token range to store. + instance_id: Worker instance identifier. + chunks: CPU chunk tensors to store. + + Returns: + An opaque handle to be passed to :meth:`commit_store`. + """ + ... + + @abstractmethod + def commit_store(self, handle: Any) -> bool: + """Commit a prepared store operation. + + Args: + handle: The opaque handle returned by :meth:`prepare_store`. + + Returns: + ``True`` on success, ``False`` otherwise. + """ + ... + + @abstractmethod + def prepare_retrieve( + self, key: Any, instance_id: int + ) -> tuple[Any, list[torch.Tensor] | None]: + """Prepare a retrieve operation. + + Args: + key: Cache key for the token range to retrieve. + instance_id: Worker instance identifier. + + Returns: + A ``(handle, chunks)`` pair. ``chunks`` is a list of CPU tensors + on cache hit, or ``None`` on cache miss. The handle must be + passed to :meth:`commit_retrieve`. + """ + ... + + @abstractmethod + def commit_retrieve(self, handle: Any) -> None: + """Finalise a retrieve operation (release locks, cleanup, etc.). + + Args: + handle: The opaque handle returned by :meth:`prepare_retrieve`. + """ + ... + + @abstractmethod + def close(self) -> None: + """Release any resources held by this context.""" + ... + + +def create_cpu_context( + metadata: CPUContextMetadata, + mq_client: Any, + mq_timeout: float, +) -> CPUContext: + """Factory that returns the appropriate :class:`CPUContext` implementation. + + Currently always returns a :class:`~lmcache.v1.multiprocess.\ +cpu_context_pickle.CPUContextPickle` instance. A future SHM-capable PR + may probe for shared-memory availability and fall back to pickle. + + Args: + metadata: Layout metadata for the CPU context. + mq_client: Message-queue client for server communication. + mq_timeout: Timeout in seconds for blocking MQ requests. + + Returns: + A concrete :class:`CPUContext` instance. + """ + # Local + from .cpu_context_pickle import CPUContextPickle + + return CPUContextPickle(metadata, mq_client, mq_timeout) + + +# --------------------------------------------------------------------------- +# Shared gather / scatter utilities +# --------------------------------------------------------------------------- + + +def compute_kv_layout( + kv_caches: dict[str, torch.Tensor], + layout_hints: Any | None = None, +) -> tuple[int, int, int, str, Any]: + """Compute KV layout metadata from KV tensors. + + Args: + kv_caches: Per-layer KV tensor mapping. + layout_hints: Optional engine layout hints. + + Returns: + Tuple of ``(block_size, num_layers, hidden_dim_size, dtype_str,`` + ``gpu_kv_format)``. + + Raises: + ValueError: If ``kv_caches`` is empty. + """ + # First Party + from lmcache.v1.gpu_connector.utils import ( + get_block_size, + get_hidden_dim_size, + get_num_layers, + normalize_kv_and_discover_format, + ) + + tensors = list(kv_caches.values()) + if not tensors: + raise ValueError("kv_caches is empty. Cannot compute KV layout.") + + gpu_kv_format, normalized = normalize_kv_and_discover_format( + tensors, EngineType.VLLM, layout_hints=layout_hints + ) + block_size = get_block_size(normalized, gpu_kv_format) + num_layers = get_num_layers(normalized, gpu_kv_format) + hidden_dim_size = get_hidden_dim_size(normalized, gpu_kv_format) + dtype_str = str(tensors[0].dtype).replace("torch.", "") + return block_size, num_layers, hidden_dim_size, dtype_str, gpu_kv_format + + +def gather_chunks_to_cpu( + kv_caches: dict[str, torch.Tensor], + block_ids: list[int], + blocks_per_chunk: int, + layout_hints: Any | None = None, + gpu_kv_format: Any | None = None, +) -> list[torch.Tensor]: + """Gather paged KV blocks into CPU chunk tensors. + + Args: + kv_caches: Per-layer KV tensor mapping. + block_ids: Flattened block IDs for all chunks. + blocks_per_chunk: Number of paged blocks in one LMCache chunk. + layout_hints: Optional engine layout hints. + gpu_kv_format: Optional pre-detected KV format. + + Returns: + List of CPU tensors, one per chunk. For non-MLA each chunk has shape + ``[2, num_layers, chunk_tokens, hidden_dim]`` where dimension ``0`` + stores ``(K, V)``. For MLA (multi-head latent attention) each chunk + has shape ``[num_layers, chunk_tokens, hidden_dim]``. + """ + # First Party + from lmcache.v1.gpu_connector.utils import ( + get_block_size, + is_mla, + normalize_kv_and_discover_format, + ) + import lmcache.c_ops as lmc_ops + + tensors = list(kv_caches.values()) + fmt, normalized = normalize_kv_and_discover_format( + tensors, EngineType.VLLM, layout_hints=layout_hints + ) + if gpu_kv_format is None: + gpu_kv_format = fmt + use_mla = is_mla(gpu_kv_format) + is_hnd = gpu_kv_format in ( + lmc_ops.GPUKVFormat.NL_X_TWO_NB_NH_BS_HS, + lmc_ops.GPUKVFormat.NL_X_NB_TWO_NH_BS_HS, + ) + + block_size = get_block_size(normalized, gpu_kv_format) + num_chunks = len(block_ids) // blocks_per_chunk + + # After normalization the structure is always a list of per-layer + # tensors. Cast once so all downstream indexing is typed correctly. + layer_tensors = cast(list[torch.Tensor], normalized) + + chunks: list[torch.Tensor] = [] + for chunk_idx in range(num_chunks): + chunk_block_ids = block_ids[ + chunk_idx * blocks_per_chunk : (chunk_idx + 1) * blocks_per_chunk + ] + if use_mla: + mla_layers: list[torch.Tensor] = [] + for layer in layer_tensors: + layer_blocks = layer[torch.tensor(chunk_block_ids, dtype=torch.long)] + mla_layers.append( + layer_blocks.reshape( + len(chunk_block_ids) * block_size, layer_blocks.shape[-1] + ) + ) + chunks.append(torch.stack(mla_layers, dim=0).cpu()) + else: + k_layers: list[torch.Tensor] = [] + v_layers: list[torch.Tensor] = [] + for layer in layer_tensors: + if is_hnd: + if gpu_kv_format == lmc_ops.GPUKVFormat.NL_X_TWO_NB_NH_BS_HS: + k_t = layer[0] + v_t = layer[1] + else: + k_t = layer[:, 0] + v_t = layer[:, 1] + _num_blocks, num_heads, _block_size, head_size = k_t.shape + k_blocks = k_t[torch.tensor(chunk_block_ids, dtype=torch.long)] + v_blocks = v_t[torch.tensor(chunk_block_ids, dtype=torch.long)] + # HND blocks are [NB, NH, BS, HS]; convert to token-major + # [NB, BS, NH, HS] before flattening to [tokens, NH*HS]. + k_layers.append( + k_blocks.permute(0, 2, 1, 3).reshape( + len(chunk_block_ids) * block_size, num_heads * head_size + ) + ) + v_layers.append( + v_blocks.permute(0, 2, 1, 3).reshape( + len(chunk_block_ids) * block_size, num_heads * head_size + ) + ) + else: + if gpu_kv_format == lmc_ops.GPUKVFormat.NL_X_TWO_NB_BS_NH_HS: + k_t = layer[0] + v_t = layer[1] + else: + k_t = layer[:, 0] + v_t = layer[:, 1] + _num_blocks, _block_size, num_heads, head_size = k_t.shape + k_blocks = k_t[torch.tensor(chunk_block_ids, dtype=torch.long)] + v_blocks = v_t[torch.tensor(chunk_block_ids, dtype=torch.long)] + k_layers.append( + k_blocks.reshape( + len(chunk_block_ids) * block_size, num_heads * head_size + ) + ) + v_layers.append( + v_blocks.reshape( + len(chunk_block_ids) * block_size, num_heads * head_size + ) + ) + k_stacked = torch.stack(k_layers, dim=0) + v_stacked = torch.stack(v_layers, dim=0) + chunks.append(torch.stack([k_stacked, v_stacked], dim=0).cpu()) + return chunks + + +def scatter_cpu_chunks_to_kv( + kv_caches: dict[str, torch.Tensor], + block_ids: list[int], + chunks: list[torch.Tensor], + blocks_per_chunk: int, + skip_first_n_tokens: int = 0, + layout_hints: Any | None = None, + gpu_kv_format: Any | None = None, +) -> None: + """Scatter CPU chunk tensors back into paged KV tensors. + + Args: + kv_caches: Per-layer KV tensor mapping to write into. + block_ids: Flattened destination block IDs for all chunks. + chunks: List of CPU chunk tensors (as returned by + :func:`gather_chunks_to_cpu`). + blocks_per_chunk: Number of paged blocks in one LMCache chunk. + skip_first_n_tokens: Token prefix to skip when scattering. + layout_hints: Optional engine layout hints. + gpu_kv_format: Optional pre-detected KV format. + """ + # First Party + from lmcache.v1.gpu_connector.utils import ( + get_block_size, + is_mla, + normalize_kv_and_discover_format, + ) + import lmcache.c_ops as lmc_ops + + if not chunks: + return + + tensors = list(kv_caches.values()) + fmt, normalized = normalize_kv_and_discover_format( + tensors, EngineType.VLLM, layout_hints=layout_hints + ) + if gpu_kv_format is None: + gpu_kv_format = fmt + use_mla = is_mla(gpu_kv_format) + + block_size = get_block_size(normalized, gpu_kv_format) + device = tensors[0].device + is_hnd = gpu_kv_format in ( + lmc_ops.GPUKVFormat.NL_X_TWO_NB_NH_BS_HS, + lmc_ops.GPUKVFormat.NL_X_NB_TWO_NH_BS_HS, + ) + + # After normalization the structure is always a list of per-layer + # tensors. Cast once so all downstream indexing is typed correctly. + layer_tensors = cast(list[torch.Tensor], normalized) + + for chunk_idx, chunk_cpu in enumerate(chunks): + chunk_block_ids = block_ids[ + chunk_idx * blocks_per_chunk : (chunk_idx + 1) * blocks_per_chunk + ] + if not chunk_block_ids: + continue + + chunk_start_token = chunk_idx * blocks_per_chunk * block_size + chunk_end_token = chunk_start_token + len(chunk_block_ids) * block_size + effective_start = max(chunk_start_token, skip_first_n_tokens) + if effective_start >= chunk_end_token: + continue + + skip_blocks_in_chunk = (effective_start - chunk_start_token) // block_size + effective_block_ids = chunk_block_ids[skip_blocks_in_chunk:] + if not effective_block_ids: + continue + + skip_tokens = skip_blocks_in_chunk * block_size + chunk_device = chunk_cpu.to(device) + + if use_mla: + for layer_idx, layer in enumerate(layer_tensors): + mla_src = chunk_device[layer_idx, skip_tokens:] + hidden_size = layer.shape[-1] + mla_src_3d = mla_src.reshape( + len(effective_block_ids), block_size, hidden_size + ) + layer[effective_block_ids] = mla_src_3d + elif is_hnd: + for layer_idx, layer in enumerate(layer_tensors): + k_src = chunk_device[0, layer_idx, skip_tokens:] + v_src = chunk_device[1, layer_idx, skip_tokens:] + if gpu_kv_format == lmc_ops.GPUKVFormat.NL_X_TWO_NB_NH_BS_HS: + k_t = layer[0] + v_t = layer[1] + else: + k_t = layer[:, 0] + v_t = layer[:, 1] + _nb, nh, _bs, hs = k_t.shape + k_blocks = k_src.reshape( + len(effective_block_ids), block_size, nh, hs + ).permute(0, 2, 1, 3) + v_blocks = v_src.reshape( + len(effective_block_ids), block_size, nh, hs + ).permute(0, 2, 1, 3) + k_t[effective_block_ids] = k_blocks + v_t[effective_block_ids] = v_blocks + else: + for layer_idx, layer in enumerate(layer_tensors): + k_src = chunk_device[0, layer_idx, skip_tokens:] + v_src = chunk_device[1, layer_idx, skip_tokens:] + if gpu_kv_format == lmc_ops.GPUKVFormat.NL_X_TWO_NB_BS_NH_HS: + k_t = layer[0] + v_t = layer[1] + else: + k_t = layer[:, 0] + v_t = layer[:, 1] + _num_blocks, _block_size, num_heads, head_size = k_t.shape + k_src_4d = k_src.reshape( + len(effective_block_ids), block_size, num_heads, head_size + ) + v_src_4d = v_src.reshape( + len(effective_block_ids), block_size, num_heads, head_size + ) + k_t[effective_block_ids] = k_src_4d + v_t[effective_block_ids] = v_src_4d diff --git a/lmcache/v1/multiprocess/cpu_context_pickle.py b/lmcache/v1/multiprocess/cpu_context_pickle.py new file mode 100644 index 00000000000..95bf7f6127c --- /dev/null +++ b/lmcache/v1/multiprocess/cpu_context_pickle.py @@ -0,0 +1,121 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Pickle-based CPUContext implementation for multiprocess mode.""" + +# Standard +from typing import Any +import pickle + +# Third Party +import torch + +# First Party +from lmcache.v1.multiprocess.cpu_context import CPUContext, CPUContextMetadata +from lmcache.v1.multiprocess.protocol import RequestType, get_response_class + + +class CPUContextPickle(CPUContext): + """Pickle-based implementation of :class:`CPUContext`. + + Transport mechanism: + - **Store**: ``prepare_store`` serialises chunks with ``pickle.dumps``; \ +``commit_store`` sends a ``STORE_CPU_CHUNKS`` message and waits for the \ +server acknowledgment. + - **Retrieve**: ``prepare_retrieve`` sends a ``RETRIEVE_CPU_CHUNKS`` \ +message, waits for the response, and deserialises the returned bytes with \ +``pickle.loads``; ``commit_retrieve`` is a no-op (no locks to release). + + Args: + metadata: Layout metadata for the CPU context. + mq_client: Message-queue client for server communication. + mq_timeout: Timeout in seconds for blocking MQ requests. + """ + + def __init__( + self, + metadata: CPUContextMetadata, + mq_client: Any, + mq_timeout: float, + ) -> None: + super().__init__(metadata, mq_client, mq_timeout) + + def prepare_store( + self, key: Any, instance_id: int, chunks: list[torch.Tensor] + ) -> Any: + """Serialise *chunks* with ``pickle.dumps``. + + Args: + key: Cache key for the token range to store. + instance_id: Worker instance identifier. + chunks: CPU chunk tensors to serialise. + + Returns: + Opaque handle ``(key, instance_id, serialised_bytes)`` to be + passed to :meth:`commit_store`. + """ + serialised = pickle.dumps(chunks) + return (key, instance_id, serialised) + + def commit_store(self, handle: Any) -> bool: + """Send pickled chunks to the server via ``STORE_CPU_CHUNKS``. + + Blocks until the server acknowledges the write. + + Args: + handle: The ``(key, instance_id, bytes)`` tuple returned by + :meth:`prepare_store`. + + Returns: + ``True`` on success, ``False`` on failure or timeout. + """ + key, instance_id, serialised = handle + future = self.mq_client.submit_request( + RequestType.STORE_CPU_CHUNKS, + [key, instance_id, serialised], + get_response_class(RequestType.STORE_CPU_CHUNKS), + ) + try: + return bool(future.result(timeout=self.mq_timeout)) + except TimeoutError: + return False + + def prepare_retrieve( + self, key: Any, instance_id: int + ) -> tuple[Any, list[torch.Tensor] | None]: + """Fetch serialised chunks from the server via ``RETRIEVE_CPU_CHUNKS``. + + Blocks until the server responds with the cached data (or reports a + miss). + + Args: + key: Cache key for the token range to retrieve. + instance_id: Worker instance identifier. + + Returns: + ``(None, chunks)`` on cache hit where *chunks* is the + deserialised list of CPU tensors, or ``(None, None)`` on cache + miss or timeout. The handle is ``None`` because the pickle path + has no resources to release in :meth:`commit_retrieve`. + """ + future = self.mq_client.submit_request( + RequestType.RETRIEVE_CPU_CHUNKS, + [key, instance_id], + get_response_class(RequestType.RETRIEVE_CPU_CHUNKS), + ) + try: + success, cpu_data_bytes = future.result(timeout=self.mq_timeout) + except TimeoutError: + return (None, None) + if not success or not cpu_data_bytes: + return (None, None) + chunks: list[torch.Tensor] = pickle.loads(cpu_data_bytes) + return (None, chunks) + + def commit_retrieve(self, handle: Any) -> None: + """No-op: the pickle path holds no server-side locks. + + Args: + handle: Ignored. + """ + + def close(self) -> None: + """No-op: the pickle path holds no persistent resources.""" diff --git a/lmcache/v1/multiprocess/protocols/base.py b/lmcache/v1/multiprocess/protocols/base.py index 383a41ff8c3..cd00322c1fd 100644 --- a/lmcache/v1/multiprocess/protocols/base.py +++ b/lmcache/v1/multiprocess/protocols/base.py @@ -48,6 +48,9 @@ class RequestType(enum.Enum): QUERY_PREFETCH_LOOKUP_HITS = enum.auto() FREE_LOOKUP_LOCKS = enum.auto() END_SESSION = enum.auto() + REGISTER_KV_CACHE_BOUNCE = enum.auto() + STORE_CPU_CHUNKS = enum.auto() + RETRIEVE_CPU_CHUNKS = enum.auto() # Controller operations CLEAR = enum.auto() diff --git a/lmcache/v1/multiprocess/protocols/engine.py b/lmcache/v1/multiprocess/protocols/engine.py index e9f37fd422f..63df2ac656f 100644 --- a/lmcache/v1/multiprocess/protocols/engine.py +++ b/lmcache/v1/multiprocess/protocols/engine.py @@ -32,6 +32,9 @@ "QUERY_PREFETCH_LOOKUP_HITS", "FREE_LOOKUP_LOCKS", "END_SESSION", + "REGISTER_KV_CACHE_BOUNCE", + "STORE_CPU_CHUNKS", + "RETRIEVE_CPU_CHUNKS", ] # Type alias for cache keys @@ -146,4 +149,30 @@ def get_protocol_definitions() -> dict[str, ProtocolDefinition]: response_class=None, handler_type=HandlerType.BLOCKING, ), + "REGISTER_KV_CACHE_BOUNCE": ProtocolDefinition( + payload_classes=[ + int, + str, + int, + EngineType, + LayoutHints, + int, + int, + int, + str, + bool, + ], + response_class=None, + handler_type=HandlerType.SYNC, + ), + "STORE_CPU_CHUNKS": ProtocolDefinition( + payload_classes=[KeyType, int, bytes], + response_class=bool, + handler_type=HandlerType.BLOCKING, + ), + "RETRIEVE_CPU_CHUNKS": ProtocolDefinition( + payload_classes=[KeyType, int], + response_class=tuple[bool, bytes], + handler_type=HandlerType.BLOCKING, + ), } diff --git a/lmcache/v1/multiprocess/server.py b/lmcache/v1/multiprocess/server.py index 4748d6c5d66..b8dd0d7dbbf 100644 --- a/lmcache/v1/multiprocess/server.py +++ b/lmcache/v1/multiprocess/server.py @@ -5,6 +5,7 @@ from itertools import islice from typing import Generator import argparse +import pickle import threading import time @@ -51,6 +52,7 @@ add_mp_server_args, parse_args_to_mp_server_config, ) +from lmcache.v1.multiprocess.cpu_context import CPUContextMetadata from lmcache.v1.multiprocess.custom_types import ( BlockAllocationRecord, IPCCacheEngineKey, @@ -189,6 +191,8 @@ def __init__( # We assume that if the (model name, world size) is the same, then # the layout desc returned by the gpu context is the same. self.gpu_context_meta: dict[int, tuple[str, int]] = {} + self.cpu_contexts: dict[int, CPUContextMetadata] = {} + self.cpu_context_meta: dict[int, tuple[str, int]] = {} # chunk size self.chunk_size = chunk_size @@ -273,9 +277,185 @@ def unregister_kv_cache(self, instance_id: int) -> None: del self.gpu_context_meta[instance_id] logger.info("Unregistered KV cache for GPU ID %d", instance_id) torch_dev.empty_cache() + elif instance_id in self.cpu_contexts: + del self.cpu_contexts[instance_id] + del self.cpu_context_meta[instance_id] + logger.info("Unregistered CPU context for instance ID %d", instance_id) else: logger.warning("No KV cache found for GPU ID %d to unregister", instance_id) + def register_kv_cache_bounce( + self, + instance_id: int, + model_name: str, + world_size: int, + engine_type: EngineType, + layout_hints: LayoutHints, + block_size: int, + num_layers: int, + hidden_dim_size: int, + dtype_str: str, + use_mla: bool, + ) -> None: + """Register non-CUDA KV layout metadata for CPU bounce-buffer mode. + + Args: + instance_id: Worker instance identifier (typically PID). + model_name: Model name associated with this worker. + world_size: Worker world size used in cache keys. + engine_type: Serving engine type (kept for protocol compatibility). + layout_hints: Optional engine layout hints (protocol compatibility). + block_size: Tokens per paged block. + num_layers: Number of model layers. + hidden_dim_size: Flattened hidden dimension per token. + dtype_str: Torch dtype name (for example ``"float16"``). + use_mla: Whether the worker KV format is MLA. + MLA stores one latent vector per token with shape + ``[num_layers, chunk_size, hidden_dim_size]``; non-MLA stores + separate K/V planes with shape + ``[2, num_layers, chunk_size, hidden_dim_size]``. + + Raises: + ValueError: If ``dtype_str`` is not a valid torch dtype name. + """ + # Third Party + import torch + + # Keep these for protocol compatibility with register_kv_cache(). + del engine_type, layout_hints + dtype = getattr(torch, dtype_str, None) + if dtype is None or not isinstance(dtype, torch.dtype): + raise ValueError( + f"Invalid dtype_str '{dtype_str}': expected a torch.dtype name " + "(e.g. 'float16', 'bfloat16', 'float32')." + ) + + shape = ( + # MLA has one latent state per token (no separate K/V axis), + # while non-MLA stores separate K and V tensors at dim 0. + torch.Size([num_layers, self.chunk_size, hidden_dim_size]) + if use_mla + else torch.Size([2, num_layers, self.chunk_size, hidden_dim_size]) + ) + layout_desc = MemoryLayoutDesc(shapes=[shape], dtypes=[dtype]) + self.cpu_contexts[instance_id] = CPUContextMetadata( + layout_desc=layout_desc, + block_size=block_size, + use_mla=use_mla, + ) + self.cpu_context_meta[instance_id] = (model_name, world_size) + + @_lmcache_nvtx_annotate + def store_cpu_chunks( + self, + key: IPCCacheEngineKey, + instance_id: int, + cpu_data: bytes, + ) -> bool: + """Store worker-provided CPU chunks for non-CUDA bounce-buffer mode. + + Args: + key: Cache key for the token range to store. + instance_id: Worker instance identifier. + cpu_data: Pickled list of CPU tensors produced by the worker. + + Returns: + ``True`` when all reserved objects are written, otherwise ``False``. + + Raises: + ValueError: If the instance has no registered bounce context. + """ + # Third Party + import torch + + session = self.session_manager.get_or_create(key.request_id) + session.set_tokens(list(key.token_ids)) + chunk_hashes = [ + TokenHasher.hash_to_bytes(h) for h in session.get_hashes(key.start, key.end) + ] + if key.worker_id is None: + raise ValueError("Must store with worker_id != None") + obj_keys = ipc_key_to_object_keys(key, chunk_hashes) + + if instance_id not in self.cpu_contexts: + raise ValueError( + f"CPU context not registered for instance ID {instance_id}" + ) + ctx = self.cpu_contexts[instance_id] + chunks: list[torch.Tensor] = pickle.loads(cpu_data) + reserved_dict = self.storage_manager.reserve_write( + obj_keys, ctx.layout_desc, "new" + ) + written_keys: list[ObjectKey] = [] + try: + for idx, obj_key in enumerate(obj_keys): + if obj_key not in reserved_dict: + continue + if idx >= len(chunks): + continue + memory_obj = reserved_dict[obj_key] + if memory_obj.tensor is None: + continue + chunk_cpu = chunks[idx] + if chunk_cpu.shape != memory_obj.tensor.shape: + continue + memory_obj.tensor.copy_(chunk_cpu) + written_keys.append(obj_key) + finally: + if written_keys: + self.storage_manager.finish_write(written_keys) + + return len(written_keys) == len(reserved_dict) + + @_lmcache_nvtx_annotate + def retrieve_cpu_chunks( + self, + key: IPCCacheEngineKey, + instance_id: int, + ) -> tuple[bool, bytes]: + """Retrieve prefetched chunks and return serialized CPU tensors. + + Args: + key: Cache key for the token range to retrieve. + instance_id: Worker instance identifier. + + Returns: + Tuple ``(success, payload)`` where ``payload`` is a pickled + list of CPU chunk tensors. + + Raises: + ValueError: If the instance has no registered bounce context. + """ + session = self.session_manager.get_or_create(key.request_id) + session.set_tokens(list(key.token_ids)) + chunk_hashes = [ + TokenHasher.hash_to_bytes(h) for h in session.get_hashes(key.start, key.end) + ] + if key.worker_id is None: + raise ValueError("Must retrieve with worker_id != None") + obj_keys = ipc_key_to_object_keys(key, chunk_hashes) + + if instance_id not in self.cpu_contexts: + raise ValueError( + f"CPU context not registered for instance ID {instance_id}" + ) + + prefetched_keys: list[ObjectKey] = [] + try: + with self.storage_manager.read_prefetched_results(obj_keys) as memory_objs: + if not memory_objs or len(memory_objs) != len(obj_keys): + return False, b"" + prefetched_keys = obj_keys[: len(memory_objs)] + chunks = [] + for memory_obj in memory_objs: + if memory_obj.tensor is None: + return False, b"" + chunks.append(memory_obj.tensor.cpu().clone()) + return True, pickle.dumps(chunks) + finally: + if prefetched_keys: + self.storage_manager.finish_read_prefetched(prefetched_keys) + @_lmcache_nvtx_annotate def store( self, @@ -649,11 +829,12 @@ def _find_layout_desc( model_name: str, world_size: int, ) -> MemoryLayoutDesc | None: - """Find layout desc from a matching GPU context. + """Find layout desc from a matching GPU or CPU context. Returns: - The layout descriptor, or None if no context - matches (model_name, world_size). + The layout descriptor, or None if no context matches + ``(model_name, world_size)``. GPU contexts are checked first, + then CPU contexts. """ for gpu_id, (m, w) in self.gpu_context_meta.items(): if m == model_name and w == world_size: @@ -661,6 +842,9 @@ def _find_layout_desc( self.gpu_contexts[gpu_id], self.chunk_size, ) + for instance_id, (m, w) in self.cpu_context_meta.items(): + if m == model_name and w == world_size: + return self.cpu_contexts[instance_id].layout_desc return None def lookup( @@ -1005,6 +1189,18 @@ def report_status(self) -> dict: "hash_algorithm": self.token_hasher.hash_algorithm_name, "registered_gpu_ids": list(self.gpu_contexts.keys()), "gpu_context_meta": gpu_context_meta, + "registered_cpu_instance_ids": list(self.cpu_contexts.keys()), + "cpu_context_meta": { + str(instance_id): { + "model_name": model_name, + "world_size": world_size, + "block_size": self.cpu_contexts[instance_id].block_size, + "use_mla": self.cpu_contexts[instance_id].use_mla, + } + for instance_id, (model_name, world_size) in ( + self.cpu_context_meta.items() + ) + }, "active_sessions": self.session_manager.active_count(), "active_prefetch_jobs": self._active_prefetch_count(), "storage_manager": sm, @@ -1140,6 +1336,10 @@ def run_cache_server( server, RequestType.UNREGISTER_KV_CACHE, engine.unregister_kv_cache ) add_handler_helper(server, RequestType.STORE, engine.store) + add_handler_helper( + server, RequestType.REGISTER_KV_CACHE_BOUNCE, engine.register_kv_cache_bounce + ) + add_handler_helper(server, RequestType.STORE_CPU_CHUNKS, engine.store_cpu_chunks) add_handler_helper(server, RequestType.LOOKUP, engine.lookup) add_handler_helper( server, RequestType.QUERY_PREFETCH_STATUS, engine.query_prefetch_status @@ -1151,6 +1351,9 @@ def run_cache_server( ) add_handler_helper(server, RequestType.FREE_LOOKUP_LOCKS, engine.free_lookup_locks) add_handler_helper(server, RequestType.RETRIEVE, engine.retrieve) + add_handler_helper( + server, RequestType.RETRIEVE_CPU_CHUNKS, engine.retrieve_cpu_chunks + ) add_handler_helper(server, RequestType.CLEAR, engine.clear) add_handler_helper(server, RequestType.GET_CHUNK_SIZE, engine.get_chunk_size) add_handler_helper(server, RequestType.PING, engine.ping) @@ -1164,7 +1367,12 @@ def run_cache_server( # Assign thread pools server.add_affinity_thread_pool( - [RequestType.STORE, RequestType.RETRIEVE], + [ + RequestType.STORE, + RequestType.RETRIEVE, + RequestType.STORE_CPU_CHUNKS, + RequestType.RETRIEVE_CPU_CHUNKS, + ], max_workers=mp_config.max_gpu_workers, ) server.add_normal_thread_pool( diff --git a/tests/v1/multiprocess/test_cpu_context.py b/tests/v1/multiprocess/test_cpu_context.py new file mode 100644 index 00000000000..baa983ee5d7 --- /dev/null +++ b/tests/v1/multiprocess/test_cpu_context.py @@ -0,0 +1,433 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +from contextlib import contextmanager +from typing import Any, Callable +from unittest.mock import MagicMock, patch +import pickle +import sys + +# Third Party +import pytest +import torch + + +def _make_kv_caches( + num_layers: int = 2, + num_blocks: int = 6, + block_size: int = 4, + num_heads: int = 2, + head_size: int = 8, +) -> dict[str, torch.Tensor]: + """Build per-layer NHD KV tensors for CPU bounce-buffer tests.""" + kv_caches = {} + for i in range(num_layers): + kv_caches[f"layer_{i}"] = torch.randn( + 2, num_blocks, block_size, num_heads, head_size + ) + return kv_caches + + +def _make_mla_kv_caches( + num_layers: int = 2, + num_blocks: int = 6, + block_size: int = 4, + hidden_size: int = 16, +) -> dict[str, torch.Tensor]: + """Build per-layer MLA KV tensors for CPU bounce-buffer tests. + + Args: + num_layers: Number of KV layers to generate. + num_blocks: Number of paged blocks per layer. + block_size: Number of tokens per block. + hidden_size: Hidden size per token. + + Returns: + Mapping from layer name to MLA KV tensor with shape + ``[num_blocks, block_size, hidden_size]``. + """ + kv_caches = {} + for i in range(num_layers): + kv_caches[f"layer_{i}"] = torch.randn(num_blocks, block_size, hidden_size) + return kv_caches + + +def _make_hnd_kv_caches( + num_layers: int = 2, + num_blocks: int = 6, + block_size: int = 4, + num_heads: int = 2, + head_size: int = 8, +) -> dict[str, torch.Tensor]: + """Build per-layer HND KV tensors for CPU bounce-buffer tests.""" + kv_caches = {} + for i in range(num_layers): + kv_caches[f"layer_{i}"] = torch.randn( + 2, num_blocks, num_heads, block_size, head_size + ) + return kv_caches + + +def _make_hnd_flashinfer_kv_caches( + num_layers: int = 2, + num_blocks: int = 6, + block_size: int = 4, + num_heads: int = 2, + head_size: int = 8, +) -> dict[str, torch.Tensor]: + """Build per-layer HND flash-infer KV tensors for CPU bounce-buffer tests.""" + kv_caches = {} + for i in range(num_layers): + kv_caches[f"layer_{i}"] = torch.randn( + num_blocks, 2, num_heads, block_size, head_size + ) + return kv_caches + + +def test_wrap_kv_caches_bounce_returns_empty() -> None: + """Verify wrap_kv_caches returns no IPC wrappers in bounce-buffer mode.""" + # First Party + from lmcache.integration.vllm.vllm_multi_process_adapter import wrap_kv_caches + + assert wrap_kv_caches(_make_kv_caches(), use_cpu_context=True) == [] + + +def test_compute_kv_layout_and_gather_scatter_roundtrip() -> None: + """Validate layout extraction and gather/scatter round-trip on CPU tensors.""" + # First Party + from lmcache.integration.vllm.vllm_multi_process_adapter import ( + compute_kv_layout, + gather_chunks_to_cpu, + scatter_cpu_chunks_to_kv, + ) + + source = _make_kv_caches(num_layers=2, num_blocks=8, block_size=4) + ( + block_size, + num_layers, + hidden_dim, + dtype_str, + detected_kv_format, + ) = compute_kv_layout(source) + assert block_size == 4 + assert num_layers == 2 + assert hidden_dim == 16 + assert dtype_str == "float32" + assert detected_kv_format is not None + + blocks_per_chunk = 2 + gathered = gather_chunks_to_cpu(source, [0, 1], blocks_per_chunk) + destination = {name: torch.zeros_like(tensor) for name, tensor in source.items()} + scatter_cpu_chunks_to_kv(destination, [4, 5], gathered, blocks_per_chunk) + + for name in source: + assert torch.allclose(source[name][:, 0], destination[name][:, 4]) + assert torch.allclose(source[name][:, 1], destination[name][:, 5]) + + +@pytest.mark.parametrize( + ("hnd_builder", "expected_format"), + [ + (_make_hnd_kv_caches, "NL_X_TWO_NB_NH_BS_HS"), + (_make_hnd_flashinfer_kv_caches, "NL_X_NB_TWO_NH_BS_HS"), + ], +) +def test_gather_scatter_roundtrip_hnd_layout( + hnd_builder: Callable[[int, int, int, int, int], dict[str, torch.Tensor]], + expected_format: str, +) -> None: + """Validate gather/scatter round-trip for HND vLLM KV layout.""" + # First Party + from lmcache.integration.vllm.vllm_multi_process_adapter import ( + compute_kv_layout, + gather_chunks_to_cpu, + scatter_cpu_chunks_to_kv, + ) + import lmcache.c_ops as lmc_ops + + source = hnd_builder(2, 8, 4, 2, 8) + layout_hints = {"kv_layout": "HND"} + ( + block_size, + num_layers, + hidden_dim, + dtype_str, + detected_kv_format, + ) = compute_kv_layout(source, layout_hints=layout_hints) + assert block_size == 4 + assert num_layers == 2 + assert hidden_dim == 16 + assert dtype_str == "float32" + assert detected_kv_format == getattr(lmc_ops.GPUKVFormat, expected_format) + + blocks_per_chunk = 2 + gathered = gather_chunks_to_cpu( + source, + [0, 1], + blocks_per_chunk, + layout_hints=layout_hints, + gpu_kv_format=detected_kv_format, + ) + destination = {name: torch.zeros_like(tensor) for name, tensor in source.items()} + scatter_cpu_chunks_to_kv( + destination, + [4, 5], + gathered, + blocks_per_chunk, + layout_hints=layout_hints, + gpu_kv_format=detected_kv_format, + ) + + for name in source: + if detected_kv_format == lmc_ops.GPUKVFormat.NL_X_TWO_NB_NH_BS_HS: + assert torch.allclose(source[name][:, 0], destination[name][:, 4]) + assert torch.allclose(source[name][:, 1], destination[name][:, 5]) + else: + assert torch.allclose(source[name][0], destination[name][4]) + assert torch.allclose(source[name][1], destination[name][5]) + + +def test_scatter_respects_skip_first_n_tokens() -> None: + """Ensure scatter honors skip_first_n_tokens and preserves skipped blocks.""" + # First Party + from lmcache.integration.vllm.vllm_multi_process_adapter import ( + gather_chunks_to_cpu, + scatter_cpu_chunks_to_kv, + ) + + source = _make_kv_caches(num_layers=2, num_blocks=8, block_size=4) + destination = { + name: torch.full_like(tensor, 999.0) for name, tensor in source.items() + } + gathered = gather_chunks_to_cpu(source, [0, 1, 2, 3], blocks_per_chunk=4) + scatter_cpu_chunks_to_kv( + destination, + [0, 1, 2, 3], + gathered, + blocks_per_chunk=4, + skip_first_n_tokens=8, + ) + + for name in destination: + assert torch.all(destination[name][:, 0] == 999.0) + assert torch.all(destination[name][:, 1] == 999.0) + assert torch.allclose(destination[name][:, 2], source[name][:, 2]) + assert torch.allclose(destination[name][:, 3], source[name][:, 3]) + + +def test_compute_kv_layout_and_gather_scatter_roundtrip_mla() -> None: + """Validate gather/scatter round-trip for MLA KV tensors.""" + # First Party + from lmcache.integration.vllm.vllm_multi_process_adapter import ( + compute_kv_layout, + gather_chunks_to_cpu, + scatter_cpu_chunks_to_kv, + ) + + source = _make_mla_kv_caches( + num_layers=2, num_blocks=8, block_size=4, hidden_size=16 + ) + ( + block_size, + num_layers, + hidden_dim, + dtype_str, + detected_kv_format, + ) = compute_kv_layout(source) + assert block_size == 4 + assert num_layers == 2 + assert hidden_dim == 16 + assert dtype_str == "float32" + assert detected_kv_format is not None + + blocks_per_chunk = 2 + gathered = gather_chunks_to_cpu(source, [0, 1], blocks_per_chunk) + destination = {name: torch.zeros_like(tensor) for name, tensor in source.items()} + scatter_cpu_chunks_to_kv(destination, [4, 5], gathered, blocks_per_chunk) + + for name in source: + assert torch.allclose(source[name][0], destination[name][4]) + assert torch.allclose(source[name][1], destination[name][5]) + + +def test_compute_kv_layout_empty_raises_value_error() -> None: + """Ensure compute_kv_layout rejects empty KV cache input.""" + # First Party + from lmcache.integration.vllm.vllm_multi_process_adapter import compute_kv_layout + + with pytest.raises(ValueError, match="kv_caches is empty"): + compute_kv_layout({}) + + +def test_scatter_mla_respects_skip_first_n_tokens() -> None: + """Ensure MLA scatter honors skip_first_n_tokens and preserves skipped blocks.""" + # First Party + from lmcache.integration.vllm.vllm_multi_process_adapter import ( + gather_chunks_to_cpu, + scatter_cpu_chunks_to_kv, + ) + + source = _make_mla_kv_caches( + num_layers=2, num_blocks=8, block_size=4, hidden_size=16 + ) + destination = { + name: torch.full_like(tensor, 999.0) for name, tensor in source.items() + } + gathered = gather_chunks_to_cpu(source, [0, 1, 2, 3], blocks_per_chunk=4) + scatter_cpu_chunks_to_kv( + destination, + [0, 1, 2, 3], + gathered, + blocks_per_chunk=4, + skip_first_n_tokens=8, + ) + + for name in destination: + assert torch.all(destination[name][0] == 999.0) + assert torch.all(destination[name][1] == 999.0) + assert torch.allclose(destination[name][2], source[name][2]) + assert torch.allclose(destination[name][3], source[name][3]) + + +def test_scatter_mla_skip_past_chunk_keeps_destination_unchanged() -> None: + """Ensure MLA scatter is a no-op when skip_first_n_tokens exceeds chunk tokens.""" + # First Party + from lmcache.integration.vllm.vllm_multi_process_adapter import ( + gather_chunks_to_cpu, + scatter_cpu_chunks_to_kv, + ) + + source = _make_mla_kv_caches( + num_layers=2, num_blocks=8, block_size=4, hidden_size=16 + ) + destination = { + name: torch.full_like(tensor, 123.0) for name, tensor in source.items() + } + gathered = gather_chunks_to_cpu(source, [0, 1, 2, 3], blocks_per_chunk=4) + scatter_cpu_chunks_to_kv( + destination, + [0, 1, 2, 3], + gathered, + blocks_per_chunk=4, + skip_first_n_tokens=40, + ) + + for name in destination: + assert torch.all(destination[name] == 123.0) + + +@pytest.fixture +def stub_native_storage_ops() -> Any: + """Stub native modules so server imports work in source-only test runs.""" + module = type(sys)("lmcache.native_storage_ops") + module.TTLLock = type("TTLLock", (), {}) # type: ignore[attr-defined] + module.Bitmap = type("Bitmap", (), {}) # type: ignore[attr-defined] + with patch.dict( + sys.modules, + { + "lmcache.native_storage_ops": module, + "cupy": MagicMock(), + }, + ): + yield + + +def test_server_register_and_find_bounce_layout(stub_native_storage_ops: Any) -> None: + """Ensure bounce registration stores metadata and lookup finds its layout.""" + # First Party + from lmcache.v1.multiprocess.server import MPCacheEngine + + with ( + patch("lmcache.v1.multiprocess.server.StorageManager"), + patch("lmcache.v1.multiprocess.server.TokenHasher"), + patch("lmcache.v1.multiprocess.server.SessionManager"), + patch("lmcache.v1.multiprocess.server.get_event_bus"), + ): + engine = MPCacheEngine(storage_manager_config=MagicMock(), chunk_size=16) + engine.register_kv_cache_bounce( + instance_id=1, + model_name="m", + world_size=1, + engine_type=MagicMock(), + layout_hints={}, + block_size=4, + num_layers=2, + hidden_dim_size=16, + dtype_str="float32", + use_mla=False, + ) + + layout = engine._find_layout_desc("m", 1) + assert layout is not None + assert layout.shapes[0] == torch.Size([2, 2, 16, 16]) + + +def test_server_store_and_retrieve_cpu_chunks(stub_native_storage_ops: Any) -> None: + """Validate mocked server-side CPU chunk store and retrieve behavior.""" + # First Party + from lmcache.v1.multiprocess.custom_types import IPCCacheEngineKey + from lmcache.v1.multiprocess.server import MPCacheEngine + + mock_storage = MagicMock() + target_tensor = torch.zeros(2, 2, 8, 16) + mock_memory_obj = MagicMock() + mock_memory_obj.tensor = target_tensor + mock_storage.reserve_write.return_value = {"obj": mock_memory_obj} + + @contextmanager + def _read_prefetched_results(_keys: Any) -> Any: + yield [mock_memory_obj] + + mock_storage.read_prefetched_results.side_effect = _read_prefetched_results + mock_session = MagicMock() + mock_session.get_hashes.return_value = [b"h"] + with ( + patch( + "lmcache.v1.multiprocess.server.StorageManager", + return_value=mock_storage, + ), + patch("lmcache.v1.multiprocess.server.TokenHasher"), + patch("lmcache.v1.multiprocess.server.SessionManager") as session_cls, + patch("lmcache.v1.multiprocess.server.get_event_bus"), + patch( + "lmcache.v1.multiprocess.server.ipc_key_to_object_keys", + return_value=["obj"], + ), + ): + session_cls.return_value.get_or_create.return_value = mock_session + engine = MPCacheEngine(storage_manager_config=MagicMock(), chunk_size=8) + + engine.register_kv_cache_bounce( + instance_id=2, + model_name="m", + world_size=1, + engine_type=MagicMock(), + layout_hints={}, + block_size=4, + num_layers=2, + hidden_dim_size=16, + dtype_str="float32", + use_mla=False, + ) + payload = torch.ones(2, 2, 8, 16) + key = IPCCacheEngineKey.from_token_ids( + "m", + 1, + 0, + [1] * 8, + start=0, + end=8, + request_id="req", + ) + with patch( + "lmcache.v1.multiprocess.server.ipc_key_to_object_keys", + return_value=["obj"], + ): + store_ok = engine.store_cpu_chunks(key, 2, pickle.dumps([payload])) + success, cpu_data = engine.retrieve_cpu_chunks(key, 2) + assert isinstance(store_ok, bool) + assert torch.allclose(mock_memory_obj.tensor, payload) + + assert success is True + recovered_chunks: list[torch.Tensor] = pickle.loads(cpu_data) + assert len(recovered_chunks) == 1 + assert torch.allclose(recovered_chunks[0], payload) From 9a12d7ae85624ca99b2bee97fad83bdf58f9635d Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Tue, 12 May 2026 08:18:37 +0000 Subject: [PATCH 02/23] renaming bounce keyword to cpu context Signed-off-by: Tony Lin --- .../v1/multiprocess/cpu_context_design.md | 8 +++---- .../vllm/vllm_multi_process_adapter.py | 24 +++++++++---------- lmcache/v1/multiprocess/protocols/base.py | 2 +- lmcache/v1/multiprocess/protocols/engine.py | 4 ++-- lmcache/v1/multiprocess/server.py | 14 ++++++----- tests/v1/multiprocess/test_cpu_context.py | 22 +++++++++-------- 6 files changed, 39 insertions(+), 35 deletions(-) diff --git a/docs/design/v1/multiprocess/cpu_context_design.md b/docs/design/v1/multiprocess/cpu_context_design.md index 0f83205b411..ce48d5b94fe 100644 --- a/docs/design/v1/multiprocess/cpu_context_design.md +++ b/docs/design/v1/multiprocess/cpu_context_design.md @@ -27,7 +27,7 @@ provides a generic protocol where workers: ## Protocol additions Three request types are used for non-CUDA mode (unchanged from the original -bounce-buffer design): +cpu context design): - `REGISTER_KV_CACHE_BOUNCE` - `STORE_CPU_CHUNKS` @@ -129,7 +129,7 @@ Supported KV formats in CPU gather/scatter: by tensor `device.type`: - all CUDA → existing CUDA IPC registration and store/retrieve path -- all non-CUDA → bounce registration and CPU context store/retrieve path +- all non-CUDA → cpu context registration and CPU context store/retrieve path The adapter holds a `cpu_context: CPUContext` instance and uses the uniform `prepare/commit` interface for both store and retrieve. @@ -173,7 +173,7 @@ needs a `if self._use_cpu_context:` branch for retrieve futures. `(model_name, world_size)` for layout resolution. Server-side handler methods are unchanged: -- `register_kv_cache_bounce` — stores `CPUContextMetadata` in `cpu_contexts`. +- `register_kv_cache_cpu_context` — stores `CPUContextMetadata` in `cpu_contexts`. - `store_cpu_chunks` — unpickles payload, copies tensors into storage. - `retrieve_cpu_chunks` — reads from storage, pickles tensors, returns bytes. @@ -261,7 +261,7 @@ back to pickle when SHM is unavailable. `tests/v1/multiprocess/test_cpu_context.py` covers: -- CPU wrapper behavior (`wrap_kv_caches` with bounce mode) +- CPU wrapper behavior (`wrap_kv_caches` with cpu context mode) - NHD and MLA gather/scatter round-trip - HND round-trip for both HND formats - `skip_first_n_tokens` behavior diff --git a/lmcache/integration/vllm/vllm_multi_process_adapter.py b/lmcache/integration/vllm/vllm_multi_process_adapter.py index f0ed611175e..1c6f3e8334f 100644 --- a/lmcache/integration/vllm/vllm_multi_process_adapter.py +++ b/lmcache/integration/vllm/vllm_multi_process_adapter.py @@ -712,13 +712,13 @@ def __init__( str, tuple[MessagingFuture[RetrieveResult], list[int]] ] = {} - # Non-CUDA (bounce-buffer) mode state + # Non-CUDA (cpu context) mode state self._use_cpu_context: bool = False self._device_type: str = "cuda" - # CPU context for non-CUDA (bounce-buffer) mode + # CPU context for non-CUDA (cpu context) mode self.cpu_context: CPUContext | None = None - self._bounce_layout_hints: Any = None - self._bounce_gpu_kv_format: Any = None + self._cpu_layout_hints: Any = None + self._cpu_gpu_kv_format: Any = None # Completed synchronous CPU store/retrieve results, keyed by request_id self._cpu_store_done: dict[str, bool] = {} self._cpu_retrieve_done: dict[str, tuple[bool, list[int]]] = {} @@ -858,7 +858,7 @@ def _send_register_kv_caches_request( self._device_type = next(iter(device_types)) self._use_cpu_context = self._device_type != "cuda" logger.info( - "Registering kv caches (device_type=%s, bounce=%s)", + "Registering kv caches (device_type=%s, use_cpu_context=%s)", self._device_type, self._use_cpu_context, ) @@ -879,11 +879,11 @@ def _send_register_kv_caches_request( dtype_str, gpu_kv_format, ) = compute_kv_layout(kv_caches, layout_hints=layout_hints) - self._bounce_layout_hints = layout_hints - self._bounce_gpu_kv_format = gpu_kv_format + self._cpu_layout_hints = layout_hints + self._cpu_gpu_kv_format = gpu_kv_format future = send_lmcache_request( self.mq_client, - RequestType.REGISTER_KV_CACHE_BOUNCE, + RequestType.REGISTER_KV_CACHE_CPU_CONTEXT, [ self.instance_id, self.model_name, @@ -1028,8 +1028,8 @@ def submit_store_request( self.kv_caches, op.block_ids, self.blocks_in_chunk, - layout_hints=self._bounce_layout_hints, - gpu_kv_format=self._bounce_gpu_kv_format, + layout_hints=self._cpu_layout_hints, + gpu_kv_format=self._cpu_gpu_kv_format, ) handle = self.cpu_context.prepare_store(key, self.instance_id, cpu_chunks) ok = self.cpu_context.commit_store(handle) @@ -1086,8 +1086,8 @@ def submit_retrieve_request( chunks, self.blocks_in_chunk, skip_first_n_tokens=op.skip_first_n_tokens, - layout_hints=self._bounce_layout_hints, - gpu_kv_format=self._bounce_gpu_kv_format, + layout_hints=self._cpu_layout_hints, + gpu_kv_format=self._cpu_gpu_kv_format, ) except Exception: logger.exception("Failed to scatter retrieved CPU context chunks") diff --git a/lmcache/v1/multiprocess/protocols/base.py b/lmcache/v1/multiprocess/protocols/base.py index cd00322c1fd..85d033cec60 100644 --- a/lmcache/v1/multiprocess/protocols/base.py +++ b/lmcache/v1/multiprocess/protocols/base.py @@ -48,7 +48,7 @@ class RequestType(enum.Enum): QUERY_PREFETCH_LOOKUP_HITS = enum.auto() FREE_LOOKUP_LOCKS = enum.auto() END_SESSION = enum.auto() - REGISTER_KV_CACHE_BOUNCE = enum.auto() + REGISTER_KV_CACHE_CPU_CONTEXT = enum.auto() STORE_CPU_CHUNKS = enum.auto() RETRIEVE_CPU_CHUNKS = enum.auto() diff --git a/lmcache/v1/multiprocess/protocols/engine.py b/lmcache/v1/multiprocess/protocols/engine.py index 63df2ac656f..df84c788ce4 100644 --- a/lmcache/v1/multiprocess/protocols/engine.py +++ b/lmcache/v1/multiprocess/protocols/engine.py @@ -32,7 +32,7 @@ "QUERY_PREFETCH_LOOKUP_HITS", "FREE_LOOKUP_LOCKS", "END_SESSION", - "REGISTER_KV_CACHE_BOUNCE", + "REGISTER_KV_CACHE_CPU_CONTEXT", "STORE_CPU_CHUNKS", "RETRIEVE_CPU_CHUNKS", ] @@ -149,7 +149,7 @@ def get_protocol_definitions() -> dict[str, ProtocolDefinition]: response_class=None, handler_type=HandlerType.BLOCKING, ), - "REGISTER_KV_CACHE_BOUNCE": ProtocolDefinition( + "REGISTER_KV_CACHE_CPU_CONTEXT": ProtocolDefinition( payload_classes=[ int, str, diff --git a/lmcache/v1/multiprocess/server.py b/lmcache/v1/multiprocess/server.py index b8dd0d7dbbf..2f932490b59 100644 --- a/lmcache/v1/multiprocess/server.py +++ b/lmcache/v1/multiprocess/server.py @@ -284,7 +284,7 @@ def unregister_kv_cache(self, instance_id: int) -> None: else: logger.warning("No KV cache found for GPU ID %d to unregister", instance_id) - def register_kv_cache_bounce( + def register_kv_cache_cpu_context( self, instance_id: int, model_name: str, @@ -297,7 +297,7 @@ def register_kv_cache_bounce( dtype_str: str, use_mla: bool, ) -> None: - """Register non-CUDA KV layout metadata for CPU bounce-buffer mode. + """Register non-CUDA KV layout metadata for CPU context mode. Args: instance_id: Worker instance identifier (typically PID). @@ -352,7 +352,7 @@ def store_cpu_chunks( instance_id: int, cpu_data: bytes, ) -> bool: - """Store worker-provided CPU chunks for non-CUDA bounce-buffer mode. + """Store worker-provided CPU chunks for non-CUDA cpu context mode. Args: key: Cache key for the token range to store. @@ -363,7 +363,7 @@ def store_cpu_chunks( ``True`` when all reserved objects are written, otherwise ``False``. Raises: - ValueError: If the instance has no registered bounce context. + ValueError: If the instance has no registered cpu context. """ # Third Party import torch @@ -424,7 +424,7 @@ def retrieve_cpu_chunks( list of CPU chunk tensors. Raises: - ValueError: If the instance has no registered bounce context. + ValueError: If the instance has no registered cpu context. """ session = self.session_manager.get_or_create(key.request_id) session.set_tokens(list(key.token_ids)) @@ -1337,7 +1337,9 @@ def run_cache_server( ) add_handler_helper(server, RequestType.STORE, engine.store) add_handler_helper( - server, RequestType.REGISTER_KV_CACHE_BOUNCE, engine.register_kv_cache_bounce + server, + RequestType.REGISTER_KV_CACHE_CPU_CONTEXT, + engine.register_kv_cache_cpu_context, ) add_handler_helper(server, RequestType.STORE_CPU_CHUNKS, engine.store_cpu_chunks) add_handler_helper(server, RequestType.LOOKUP, engine.lookup) diff --git a/tests/v1/multiprocess/test_cpu_context.py b/tests/v1/multiprocess/test_cpu_context.py index baa983ee5d7..6fb03e4181b 100644 --- a/tests/v1/multiprocess/test_cpu_context.py +++ b/tests/v1/multiprocess/test_cpu_context.py @@ -18,7 +18,7 @@ def _make_kv_caches( num_heads: int = 2, head_size: int = 8, ) -> dict[str, torch.Tensor]: - """Build per-layer NHD KV tensors for CPU bounce-buffer tests.""" + """Build per-layer NHD KV tensors for CPU cpu context tests.""" kv_caches = {} for i in range(num_layers): kv_caches[f"layer_{i}"] = torch.randn( @@ -33,7 +33,7 @@ def _make_mla_kv_caches( block_size: int = 4, hidden_size: int = 16, ) -> dict[str, torch.Tensor]: - """Build per-layer MLA KV tensors for CPU bounce-buffer tests. + """Build per-layer MLA KV tensors for CPU cpu context tests. Args: num_layers: Number of KV layers to generate. @@ -58,7 +58,7 @@ def _make_hnd_kv_caches( num_heads: int = 2, head_size: int = 8, ) -> dict[str, torch.Tensor]: - """Build per-layer HND KV tensors for CPU bounce-buffer tests.""" + """Build per-layer HND KV tensors for CPU cpu context tests.""" kv_caches = {} for i in range(num_layers): kv_caches[f"layer_{i}"] = torch.randn( @@ -74,7 +74,7 @@ def _make_hnd_flashinfer_kv_caches( num_heads: int = 2, head_size: int = 8, ) -> dict[str, torch.Tensor]: - """Build per-layer HND flash-infer KV tensors for CPU bounce-buffer tests.""" + """Build per-layer HND flash-infer KV tensors for CPU cpu context tests.""" kv_caches = {} for i in range(num_layers): kv_caches[f"layer_{i}"] = torch.randn( @@ -83,8 +83,8 @@ def _make_hnd_flashinfer_kv_caches( return kv_caches -def test_wrap_kv_caches_bounce_returns_empty() -> None: - """Verify wrap_kv_caches returns no IPC wrappers in bounce-buffer mode.""" +def test_wrap_kv_caches_cpu_context_returns_empty() -> None: + """Verify wrap_kv_caches returns no IPC wrappers in cpu context mode.""" # First Party from lmcache.integration.vllm.vllm_multi_process_adapter import wrap_kv_caches @@ -331,8 +331,10 @@ def stub_native_storage_ops() -> Any: yield -def test_server_register_and_find_bounce_layout(stub_native_storage_ops: Any) -> None: - """Ensure bounce registration stores metadata and lookup finds its layout.""" +def test_server_register_and_find_cpu_context_layout( + stub_native_storage_ops: Any, +) -> None: + """Ensure cpu context registration stores metadata and lookup finds its layout.""" # First Party from lmcache.v1.multiprocess.server import MPCacheEngine @@ -343,7 +345,7 @@ def test_server_register_and_find_bounce_layout(stub_native_storage_ops: Any) -> patch("lmcache.v1.multiprocess.server.get_event_bus"), ): engine = MPCacheEngine(storage_manager_config=MagicMock(), chunk_size=16) - engine.register_kv_cache_bounce( + engine.register_kv_cache_cpu_context( instance_id=1, model_name="m", world_size=1, @@ -396,7 +398,7 @@ def _read_prefetched_results(_keys: Any) -> Any: session_cls.return_value.get_or_create.return_value = mock_session engine = MPCacheEngine(storage_manager_config=MagicMock(), chunk_size=8) - engine.register_kv_cache_bounce( + engine.register_kv_cache_cpu_context( instance_id=2, model_name="m", world_size=1, From 08a4d10766dd7bed89d3cac3339460a4dc971c73 Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Tue, 12 May 2026 21:01:29 +0800 Subject: [PATCH 03/23] fix unit test failures Signed-off-by: Tony Lin --- tests/v1/test_vllm_mp_adapter.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/v1/test_vllm_mp_adapter.py b/tests/v1/test_vllm_mp_adapter.py index 08e2dc77de6..854dc6137ab 100644 --- a/tests/v1/test_vllm_mp_adapter.py +++ b/tests/v1/test_vllm_mp_adapter.py @@ -77,7 +77,9 @@ def fake_adapter(monkeypatch): def test_register_kv_caches_updates_kv_caches_and_submits(fake_adapter): """Public register_kv_caches stores the dict and submits one request.""" adapter, send_mock, _ = fake_adapter - new_caches = {"layer.0": object(), "layer.1": object()} + fake_tensor = MagicMock() + fake_tensor.device.type = "cuda" + new_caches = {"layer.0": fake_tensor, "layer.1": fake_tensor} adapter.register_kv_caches(new_caches) @@ -93,4 +95,6 @@ def test_register_kv_caches_raises_connection_error_on_timeout(fake_adapter): future.result.side_effect = TimeoutError("server down") with pytest.raises(ConnectionError, match="did not respond"): - adapter.register_kv_caches({"layer.0": object()}) + fake_tensor = MagicMock() + fake_tensor.device.type = "cuda" + adapter.register_kv_caches({"layer.0": fake_tensor}) From 715ac57dbc9cae8cd80eec46a8740f546e386fff Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Tue, 12 May 2026 21:19:18 +0800 Subject: [PATCH 04/23] address bot review comment Signed-off-by: Tony Lin --- lmcache/v1/multiprocess/cpu_context.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lmcache/v1/multiprocess/cpu_context.py b/lmcache/v1/multiprocess/cpu_context.py index 126b8778986..853a89f9c18 100644 --- a/lmcache/v1/multiprocess/cpu_context.py +++ b/lmcache/v1/multiprocess/cpu_context.py @@ -253,8 +253,9 @@ def gather_chunks_to_cpu( ] if use_mla: mla_layers: list[torch.Tensor] = [] + idx = torch.tensor(chunk_block_ids, dtype=torch.long) for layer in layer_tensors: - layer_blocks = layer[torch.tensor(chunk_block_ids, dtype=torch.long)] + layer_blocks = layer[idx] mla_layers.append( layer_blocks.reshape( len(chunk_block_ids) * block_size, layer_blocks.shape[-1] @@ -386,13 +387,14 @@ def scatter_cpu_chunks_to_kv( chunk_device = chunk_cpu.to(device) if use_mla: + eff_idx = torch.tensor(effective_block_ids, dtype=torch.long) for layer_idx, layer in enumerate(layer_tensors): mla_src = chunk_device[layer_idx, skip_tokens:] hidden_size = layer.shape[-1] mla_src_3d = mla_src.reshape( len(effective_block_ids), block_size, hidden_size ) - layer[effective_block_ids] = mla_src_3d + layer[eff_idx] = mla_src_3d elif is_hnd: for layer_idx, layer in enumerate(layer_tensors): k_src = chunk_device[0, layer_idx, skip_tokens:] From 96adc9440cf14e78cbc1a83248c8a7d704bd5460 Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Wed, 13 May 2026 08:51:45 +0800 Subject: [PATCH 05/23] refactor: standardize cpu context naming conventions Signed-off-by: Tony Lin --- .../v1/multiprocess/cpu_context_design.md | 23 +++++---- .../vllm/vllm_multi_process_adapter.py | 8 ++-- lmcache/v1/multiprocess/cpu_context.py | 8 ++-- tests/v1/multiprocess/test_cpu_context.py | 48 +++++++++---------- 4 files changed, 43 insertions(+), 44 deletions(-) diff --git a/docs/design/v1/multiprocess/cpu_context_design.md b/docs/design/v1/multiprocess/cpu_context_design.md index ce48d5b94fe..651339063ff 100644 --- a/docs/design/v1/multiprocess/cpu_context_design.md +++ b/docs/design/v1/multiprocess/cpu_context_design.md @@ -18,10 +18,10 @@ The CUDA path uses IPC wrappers around GPU tensors and the existing For non-CUDA tensors, CUDA IPC is not available. The CPU context path provides a generic protocol where workers: -1. Gather KV blocks into CPU chunk tensors. +1. Gather KV blocks into CPU chunk(or memory obj) tensors. 2. Transport those CPU chunks to the server storage through a concrete `CPUContext` implementation. -3. Retrieve CPU chunks from the server and scatter them back into device KV +3. Retrieve CPU chunks(or memory obj) from the server and scatter them back into device KV tensors. ## Protocol additions @@ -29,7 +29,7 @@ provides a generic protocol where workers: Three request types are used for non-CUDA mode (unchanged from the original cpu context design): -- `REGISTER_KV_CACHE_BOUNCE` +- `REGISTER_KV_CACHE_CPU_CONTEXT` - `STORE_CPU_CHUNKS` - `RETRIEVE_CPU_CHUNKS` @@ -48,8 +48,7 @@ lmcache/v1/multiprocess/ Provides: -- **`CPUContextMetadata`** dataclass — layout metadata (replaces the old - `CPUBounceContext` dataclass): +- **`CPUContextMetadata`** dataclass — layout metadata: ```python @dataclass @@ -86,9 +85,9 @@ Provides: - **Shared utility functions** used by all concrete implementations: - `compute_kv_layout` — extract block size, layer count, hidden dim and dtype from live KV tensors. - - `gather_chunks_to_cpu` — gather paged KV blocks into a list of CPU + - `gather_paged_kv_to_cpu` — gather paged KV blocks into a list of CPU tensors (one per LMCache chunk). - - `scatter_cpu_chunks_to_kv` — scatter CPU chunk tensors back into paged + - `scatter_cpu_to_paged_kv` — scatter CPU chunk tensors back into paged KV tensors. ### `cpu_context_pickle.py` @@ -138,7 +137,7 @@ The adapter holds a `cpu_context: CPUContext` instance and uses the uniform ```python # submit_store_request -cpu_chunks = gather_chunks_to_cpu(kv_caches, block_ids, blocks_in_chunk, ...) +cpu_chunks = gather_paged_kv_to_cpu(kv_caches, block_ids, blocks_in_chunk, ...) handle = self.cpu_context.prepare_store(key, instance_id, cpu_chunks) ok = self.cpu_context.commit_store(handle) # synchronous; blocks for server ack self._cpu_store_done[request_id] = ok @@ -152,7 +151,7 @@ self._cpu_store_done[request_id] = ok # submit_retrieve_request handle, chunks = self.cpu_context.prepare_retrieve(key, instance_id) # synchronous if chunks is not None: - scatter_cpu_chunks_to_kv(kv_caches, block_ids, chunks, blocks_in_chunk, + scatter_cpu_to_paged_kv(kv_caches, block_ids, chunks, blocks_in_chunk, skip_first_n_tokens=op.skip_first_n_tokens, ...) self.cpu_context.commit_retrieve(handle) self._cpu_retrieve_done[request_id] = (chunks is not None, block_ids) @@ -199,7 +198,7 @@ Additional integration points: [device == cuda] [device != cuda] | | v v - REGISTER_KV_CACHE (CUDA IPC) REGISTER_KV_CACHE_BOUNCE (CPU metadata) + REGISTER_KV_CACHE (CUDA IPC) REGISTER_KV_CACHE_CPU_CONTEXT (CPU metadata) | + create_cpu_context() +----------------+----------------+ | @@ -212,7 +211,7 @@ Additional integration points: submit_store() submit_store() | | v v - STORE (GPU -> L1) gather_chunks_to_cpu() + STORE (GPU -> L1) gather_paged_kv_to_cpu() | + cpu_context.prepare_store() v + cpu_context.commit_store() [sync] [READY] _cpu_store_done[id] = ok @@ -226,7 +225,7 @@ Additional integration points: | | v v RETRIEVE (L1 -> GPU) cpu_context.prepare_retrieve() [sync] - [async future] + scatter_cpu_chunks_to_kv() + [async future] + scatter_cpu_to_paged_kv() + cpu_context.commit_retrieve() _cpu_retrieve_done[id] = (ok, block_ids) | | diff --git a/lmcache/integration/vllm/vllm_multi_process_adapter.py b/lmcache/integration/vllm/vllm_multi_process_adapter.py index 1c6f3e8334f..1c6f050548a 100644 --- a/lmcache/integration/vllm/vllm_multi_process_adapter.py +++ b/lmcache/integration/vllm/vllm_multi_process_adapter.py @@ -17,8 +17,8 @@ from lmcache.v1.multiprocess.cpu_context import ( CPUContext, compute_kv_layout, - gather_chunks_to_cpu, - scatter_cpu_chunks_to_kv, + gather_paged_kv_to_cpu, + scatter_cpu_to_paged_kv, ) from lmcache.v1.multiprocess.custom_types import ( BlockAllocationRecord, @@ -1024,7 +1024,7 @@ def submit_store_request( if self._use_cpu_context: assert self.cpu_context is not None torch_dev.synchronize() - cpu_chunks = gather_chunks_to_cpu( + cpu_chunks = gather_paged_kv_to_cpu( self.kv_caches, op.block_ids, self.blocks_in_chunk, @@ -1080,7 +1080,7 @@ def submit_retrieve_request( ok = chunks is not None if chunks is not None: try: - scatter_cpu_chunks_to_kv( + scatter_cpu_to_paged_kv( self.kv_caches, op.block_ids, chunks, diff --git a/lmcache/v1/multiprocess/cpu_context.py b/lmcache/v1/multiprocess/cpu_context.py index 853a89f9c18..f627f86920d 100644 --- a/lmcache/v1/multiprocess/cpu_context.py +++ b/lmcache/v1/multiprocess/cpu_context.py @@ -8,7 +8,7 @@ ``CPUContextPickle``) each decide *how* data is serialised and transported. - ``create_cpu_context()``: factory that returns the appropriate ``CPUContext`` subclass (currently always ``CPUContextPickle``). -- ``compute_kv_layout``, ``gather_chunks_to_cpu``, ``scatter_cpu_chunks_to_kv``: +- ``compute_kv_layout``, ``gather_paged_kv_to_cpu``, ``scatter_cpu_to_paged_kv``: shared gather/scatter utilities used by all concrete implementations. """ @@ -197,7 +197,7 @@ def compute_kv_layout( return block_size, num_layers, hidden_dim_size, dtype_str, gpu_kv_format -def gather_chunks_to_cpu( +def gather_paged_kv_to_cpu( kv_caches: dict[str, torch.Tensor], block_ids: list[int], blocks_per_chunk: int, @@ -314,7 +314,7 @@ def gather_chunks_to_cpu( return chunks -def scatter_cpu_chunks_to_kv( +def scatter_cpu_to_paged_kv( kv_caches: dict[str, torch.Tensor], block_ids: list[int], chunks: list[torch.Tensor], @@ -329,7 +329,7 @@ def scatter_cpu_chunks_to_kv( kv_caches: Per-layer KV tensor mapping to write into. block_ids: Flattened destination block IDs for all chunks. chunks: List of CPU chunk tensors (as returned by - :func:`gather_chunks_to_cpu`). + :func:`gather_paged_kv_to_cpu`). blocks_per_chunk: Number of paged blocks in one LMCache chunk. skip_first_n_tokens: Token prefix to skip when scattering. layout_hints: Optional engine layout hints. diff --git a/tests/v1/multiprocess/test_cpu_context.py b/tests/v1/multiprocess/test_cpu_context.py index 6fb03e4181b..ea39a45c9c7 100644 --- a/tests/v1/multiprocess/test_cpu_context.py +++ b/tests/v1/multiprocess/test_cpu_context.py @@ -96,8 +96,8 @@ def test_compute_kv_layout_and_gather_scatter_roundtrip() -> None: # First Party from lmcache.integration.vllm.vllm_multi_process_adapter import ( compute_kv_layout, - gather_chunks_to_cpu, - scatter_cpu_chunks_to_kv, + gather_paged_kv_to_cpu, + scatter_cpu_to_paged_kv, ) source = _make_kv_caches(num_layers=2, num_blocks=8, block_size=4) @@ -115,9 +115,9 @@ def test_compute_kv_layout_and_gather_scatter_roundtrip() -> None: assert detected_kv_format is not None blocks_per_chunk = 2 - gathered = gather_chunks_to_cpu(source, [0, 1], blocks_per_chunk) + gathered = gather_paged_kv_to_cpu(source, [0, 1], blocks_per_chunk) destination = {name: torch.zeros_like(tensor) for name, tensor in source.items()} - scatter_cpu_chunks_to_kv(destination, [4, 5], gathered, blocks_per_chunk) + scatter_cpu_to_paged_kv(destination, [4, 5], gathered, blocks_per_chunk) for name in source: assert torch.allclose(source[name][:, 0], destination[name][:, 4]) @@ -139,8 +139,8 @@ def test_gather_scatter_roundtrip_hnd_layout( # First Party from lmcache.integration.vllm.vllm_multi_process_adapter import ( compute_kv_layout, - gather_chunks_to_cpu, - scatter_cpu_chunks_to_kv, + gather_paged_kv_to_cpu, + scatter_cpu_to_paged_kv, ) import lmcache.c_ops as lmc_ops @@ -160,7 +160,7 @@ def test_gather_scatter_roundtrip_hnd_layout( assert detected_kv_format == getattr(lmc_ops.GPUKVFormat, expected_format) blocks_per_chunk = 2 - gathered = gather_chunks_to_cpu( + gathered = gather_paged_kv_to_cpu( source, [0, 1], blocks_per_chunk, @@ -168,7 +168,7 @@ def test_gather_scatter_roundtrip_hnd_layout( gpu_kv_format=detected_kv_format, ) destination = {name: torch.zeros_like(tensor) for name, tensor in source.items()} - scatter_cpu_chunks_to_kv( + scatter_cpu_to_paged_kv( destination, [4, 5], gathered, @@ -190,16 +190,16 @@ def test_scatter_respects_skip_first_n_tokens() -> None: """Ensure scatter honors skip_first_n_tokens and preserves skipped blocks.""" # First Party from lmcache.integration.vllm.vllm_multi_process_adapter import ( - gather_chunks_to_cpu, - scatter_cpu_chunks_to_kv, + gather_paged_kv_to_cpu, + scatter_cpu_to_paged_kv, ) source = _make_kv_caches(num_layers=2, num_blocks=8, block_size=4) destination = { name: torch.full_like(tensor, 999.0) for name, tensor in source.items() } - gathered = gather_chunks_to_cpu(source, [0, 1, 2, 3], blocks_per_chunk=4) - scatter_cpu_chunks_to_kv( + gathered = gather_paged_kv_to_cpu(source, [0, 1, 2, 3], blocks_per_chunk=4) + scatter_cpu_to_paged_kv( destination, [0, 1, 2, 3], gathered, @@ -219,8 +219,8 @@ def test_compute_kv_layout_and_gather_scatter_roundtrip_mla() -> None: # First Party from lmcache.integration.vllm.vllm_multi_process_adapter import ( compute_kv_layout, - gather_chunks_to_cpu, - scatter_cpu_chunks_to_kv, + gather_paged_kv_to_cpu, + scatter_cpu_to_paged_kv, ) source = _make_mla_kv_caches( @@ -240,9 +240,9 @@ def test_compute_kv_layout_and_gather_scatter_roundtrip_mla() -> None: assert detected_kv_format is not None blocks_per_chunk = 2 - gathered = gather_chunks_to_cpu(source, [0, 1], blocks_per_chunk) + gathered = gather_paged_kv_to_cpu(source, [0, 1], blocks_per_chunk) destination = {name: torch.zeros_like(tensor) for name, tensor in source.items()} - scatter_cpu_chunks_to_kv(destination, [4, 5], gathered, blocks_per_chunk) + scatter_cpu_to_paged_kv(destination, [4, 5], gathered, blocks_per_chunk) for name in source: assert torch.allclose(source[name][0], destination[name][4]) @@ -262,8 +262,8 @@ def test_scatter_mla_respects_skip_first_n_tokens() -> None: """Ensure MLA scatter honors skip_first_n_tokens and preserves skipped blocks.""" # First Party from lmcache.integration.vllm.vllm_multi_process_adapter import ( - gather_chunks_to_cpu, - scatter_cpu_chunks_to_kv, + gather_paged_kv_to_cpu, + scatter_cpu_to_paged_kv, ) source = _make_mla_kv_caches( @@ -272,8 +272,8 @@ def test_scatter_mla_respects_skip_first_n_tokens() -> None: destination = { name: torch.full_like(tensor, 999.0) for name, tensor in source.items() } - gathered = gather_chunks_to_cpu(source, [0, 1, 2, 3], blocks_per_chunk=4) - scatter_cpu_chunks_to_kv( + gathered = gather_paged_kv_to_cpu(source, [0, 1, 2, 3], blocks_per_chunk=4) + scatter_cpu_to_paged_kv( destination, [0, 1, 2, 3], gathered, @@ -292,8 +292,8 @@ def test_scatter_mla_skip_past_chunk_keeps_destination_unchanged() -> None: """Ensure MLA scatter is a no-op when skip_first_n_tokens exceeds chunk tokens.""" # First Party from lmcache.integration.vllm.vllm_multi_process_adapter import ( - gather_chunks_to_cpu, - scatter_cpu_chunks_to_kv, + gather_paged_kv_to_cpu, + scatter_cpu_to_paged_kv, ) source = _make_mla_kv_caches( @@ -302,8 +302,8 @@ def test_scatter_mla_skip_past_chunk_keeps_destination_unchanged() -> None: destination = { name: torch.full_like(tensor, 123.0) for name, tensor in source.items() } - gathered = gather_chunks_to_cpu(source, [0, 1, 2, 3], blocks_per_chunk=4) - scatter_cpu_chunks_to_kv( + gathered = gather_paged_kv_to_cpu(source, [0, 1, 2, 3], blocks_per_chunk=4) + scatter_cpu_to_paged_kv( destination, [0, 1, 2, 3], gathered, From e88862001b61646e5b33901d5ff306ac7cdd29d4 Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Wed, 13 May 2026 09:56:35 +0800 Subject: [PATCH 06/23] small fix on error handling Signed-off-by: Tony Lin --- lmcache/integration/vllm/vllm_multi_process_adapter.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lmcache/integration/vllm/vllm_multi_process_adapter.py b/lmcache/integration/vllm/vllm_multi_process_adapter.py index 1c6f050548a..65faa96f19c 100644 --- a/lmcache/integration/vllm/vllm_multi_process_adapter.py +++ b/lmcache/integration/vllm/vllm_multi_process_adapter.py @@ -1290,6 +1290,7 @@ def get_finished( request_id, r_result, ) + self.error_block_ids.update(r_block_ids) # Remove the finished requests from the tracking dicts for request_id in finished_stores: From cf70122e0e8a25921577d26440300ec4f6761109 Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Wed, 13 May 2026 05:20:02 +0000 Subject: [PATCH 07/23] refactor: polymorphic TransferContext for MP adapter transport layer Replace scattered if/else CPU/CUDA branches in vllm_multi_process_adapter with a TransferContext abstraction (ABC + CudaTransferContext + CPUTransferContext). - Add transfer_context.py with unified register/store/retrieve/poll interface - Device-type dispatch centralized in create_transfer_context() factory - Adapter delegates all transport logic via polymorphism, no branching - Future transports (e.g. SHM) only need a new subclass, zero adapter changes Signed-off-by: Tony Lin --- .../v1/multiprocess/cpu_context_design.md | 81 +-- .../vllm/vllm_multi_process_adapter.py | 332 +++--------- lmcache/v1/multiprocess/transfer_context.py | 484 ++++++++++++++++++ tests/v1/multiprocess/test_cpu_context.py | 14 +- tests/v1/test_vllm_mp_adapter.py | 62 +++ 5 files changed, 678 insertions(+), 295 deletions(-) create mode 100644 lmcache/v1/multiprocess/transfer_context.py diff --git a/docs/design/v1/multiprocess/cpu_context_design.md b/docs/design/v1/multiprocess/cpu_context_design.md index 651339063ff..b74d27185b2 100644 --- a/docs/design/v1/multiprocess/cpu_context_design.md +++ b/docs/design/v1/multiprocess/cpu_context_design.md @@ -124,44 +124,56 @@ Supported KV formats in CPU gather/scatter: ## Worker adapter integration -`lmcache/integration/vllm/vllm_multi_process_adapter.py` chooses the path -by tensor `device.type`: +The adapter now delegates transport behavior to +`lmcache/v1/multiprocess/transfer_context.py`. + +`create_transfer_context(kv_caches)` centralizes device dispatch: - all CUDA → existing CUDA IPC registration and store/retrieve path -- all non-CUDA → cpu context registration and CPU context store/retrieve path +- all non-CUDA → `CPUTransferContext` with cpu context registration and CPU context store/retrieve path + +`LMCacheMPSchedulerAdapter` now holds `self.transfer_ctx: TransferContext | None` +and calls: -The adapter holds a `cpu_context: CPUContext` instance and uses the uniform -`prepare/commit` interface for both store and retrieve. +- `transfer_ctx.register(...)` +- `transfer_ctx.submit_store(...)` +- `transfer_ctx.submit_retrieve(...)` +- `transfer_ctx.poll_finished()` (healthy) or `transfer_ctx.drain_all()` (unhealthy) ### Store path (non-CUDA) ```python -# submit_store_request +# CPUTransferContext.submit_store cpu_chunks = gather_paged_kv_to_cpu(kv_caches, block_ids, blocks_in_chunk, ...) -handle = self.cpu_context.prepare_store(key, instance_id, cpu_chunks) -ok = self.cpu_context.commit_store(handle) # synchronous; blocks for server ack -self._cpu_store_done[request_id] = ok +handle = self._cpu_context.prepare_store(key, instance_id, cpu_chunks) +ok = self._cpu_context.commit_store(handle) # synchronous; blocks for server ack +self._store_done[request_id] = ok ``` -`get_finished` drains `_cpu_store_done` on each call. +`CPUTransferContext.poll_finished()` drains `_store_done` on each call. ### Retrieve path (non-CUDA) ```python -# submit_retrieve_request -handle, chunks = self.cpu_context.prepare_retrieve(key, instance_id) # synchronous +# CPUTransferContext.submit_retrieve +handle, chunks = self._cpu_context.prepare_retrieve(key, instance_id) # synchronous +ok = chunks is not None if chunks is not None: - scatter_cpu_to_paged_kv(kv_caches, block_ids, chunks, blocks_in_chunk, - skip_first_n_tokens=op.skip_first_n_tokens, ...) -self.cpu_context.commit_retrieve(handle) -self._cpu_retrieve_done[request_id] = (chunks is not None, block_ids) + try: + scatter_cpu_to_paged_kv(kv_caches, block_ids, chunks, blocks_in_chunk, + skip_first_n_tokens=skip_first_n_tokens, ...) + except (RuntimeError, ValueError, TypeError, IndexError): + ok = False +self._cpu_context.commit_retrieve(handle) +self._retrieve_done[request_id] = (ok, block_ids) ``` -`get_finished` drains `_cpu_retrieve_done` on each call. +`CPUTransferContext.poll_finished()` drains `_retrieve_done` on each call. +The adapter passes `op.skip_first_n_tokens` into +`transfer_ctx.submit_retrieve(..., skip_first_n_tokens=...)`. -The retrieve is now **synchronous in `submit_retrieve_request`**; there is no -separate future to poll. This simplifies `get_finished` which no longer -needs a `if self._use_cpu_context:` branch for retrieve futures. +The retrieve is **synchronous inside `CPUTransferContext.submit_retrieve`**; +`poll_finished()` just drains request ids recorded by submit methods. ## Server integration @@ -190,7 +202,7 @@ Additional integration points: register_kv_caches() | v - [Inspect device.type] + create_transfer_context(kv_caches) | +----------------+----------------+ | | @@ -198,7 +210,8 @@ Additional integration points: [device == cuda] [device != cuda] | | v v - REGISTER_KV_CACHE (CUDA IPC) REGISTER_KV_CACHE_CPU_CONTEXT (CPU metadata) + CudaTransferContext.register() CPUTransferContext.register() + REGISTER_KV_CACHE (CUDA IPC) REGISTER_KV_CACHE_CPU_CONTEXT (CPU metadata) | + create_cpu_context() +----------------+----------------+ | @@ -208,26 +221,26 @@ Additional integration points: +----------------+----------------+ | | v v - submit_store() submit_store() + transfer_ctx.submit_store() transfer_ctx.submit_store() | | v v - STORE (GPU -> L1) gather_paged_kv_to_cpu() - | + cpu_context.prepare_store() - v + cpu_context.commit_store() [sync] - [READY] _cpu_store_done[id] = ok + STORE (GPU -> L1) gather_paged_kv_to_cpu() + | + _cpu_context.prepare_store() + v + _cpu_context.commit_store() [sync] + [READY] _store_done[id] = ok | | +----------------+----------------+ | v - submit_retrieve() + get_finished() + transfer_ctx.submit_retrieve() + get_finished() | +----------------+----------------+ | | v v - RETRIEVE (L1 -> GPU) cpu_context.prepare_retrieve() [sync] - [async future] + scatter_cpu_to_paged_kv() - + cpu_context.commit_retrieve() - _cpu_retrieve_done[id] = (ok, block_ids) + RETRIEVE (L1 -> GPU) _cpu_context.prepare_retrieve() [sync] + [async future] + scatter_cpu_to_paged_kv() + + _cpu_context.commit_retrieve() + _retrieve_done[id] = (ok, block_ids) | | +----------------+----------------+ | @@ -266,6 +279,10 @@ back to pickle when SHM is unavailable. - `skip_first_n_tokens` behavior - Server-side register/store/retrieve flow +`tests/v1/test_vllm_mp_adapter.py` covers transfer-context integration, +including CPU registration path (`REGISTER_KV_CACHE_CPU_CONTEXT`) and +store/retrieve submit delegation. + ## Non-goals - No change to existing CUDA IPC path semantics. diff --git a/lmcache/integration/vllm/vllm_multi_process_adapter.py b/lmcache/integration/vllm/vllm_multi_process_adapter.py index 65faa96f19c..10fec34cab7 100644 --- a/lmcache/integration/vllm/vllm_multi_process_adapter.py +++ b/lmcache/integration/vllm/vllm_multi_process_adapter.py @@ -11,15 +11,8 @@ import zmq # First Party -from lmcache import torch_dev from lmcache.integration.request_telemetry.factory import RequestTelemetryFactory -from lmcache.utils import EngineType, _lmcache_nvtx_annotate, init_logger -from lmcache.v1.multiprocess.cpu_context import ( - CPUContext, - compute_kv_layout, - gather_paged_kv_to_cpu, - scatter_cpu_to_paged_kv, -) +from lmcache.utils import _lmcache_nvtx_annotate, init_logger from lmcache.v1.multiprocess.custom_types import ( BlockAllocationRecord, CudaIPCWrapper, @@ -28,6 +21,10 @@ ) from lmcache.v1.multiprocess.mq import MessageQueueClient, MessagingFuture from lmcache.v1.multiprocess.protocol import RequestType, get_response_class +from lmcache.v1.multiprocess.transfer_context import ( + TransferContext, + create_transfer_context, +) from lmcache.v1.periodic_thread import PeriodicThread, ThreadLevel, ThreadRunSummary logger = init_logger(__name__) @@ -705,23 +702,10 @@ def __init__( # Registered kv caches from vLLM self.kv_caches: dict[str, torch.Tensor] = {} - # Request futures - self.store_futures: dict[str, MessagingFuture[StoreResult]] = {} - # request_id -> (future, block_ids) - self.retrieve_futures: dict[ - str, tuple[MessagingFuture[RetrieveResult], list[int]] - ] = {} - - # Non-CUDA (cpu context) mode state - self._use_cpu_context: bool = False - self._device_type: str = "cuda" - # CPU context for non-CUDA (cpu context) mode - self.cpu_context: CPUContext | None = None - self._cpu_layout_hints: Any = None - self._cpu_gpu_kv_format: Any = None - # Completed synchronous CPU store/retrieve results, keyed by request_id - self._cpu_store_done: dict[str, bool] = {} - self._cpu_retrieve_done: dict[str, tuple[bool, list[int]]] = {} + # Transport context for transfer operations. + self.transfer_ctx: TransferContext | None = None + # Store requests submitted but not yet finished by transfer context. + self._pending_store_request_ids: set[str] = set() # Block IDs that failed due to retrieve timeout self.error_block_ids: set[int] = set() @@ -842,96 +826,19 @@ def _send_register_kv_caches_request( ConnectionError: if the server does not respond within mq_timeout. """ - # First Party - from lmcache.integration.vllm.utils import vllm_layout_hints - - layout_hints = vllm_layout_hints() self.kv_caches = kv_caches - - if not kv_caches: - raise ValueError("kv_caches is empty") - device_types = {tensor.device.type for tensor in kv_caches.values()} - if len(device_types) != 1: - raise ValueError( - f"All KV cache tensors must share one device type, got {device_types}" - ) - self._device_type = next(iter(device_types)) - self._use_cpu_context = self._device_type != "cuda" - logger.info( - "Registering kv caches (device_type=%s, use_cpu_context=%s)", - self._device_type, - self._use_cpu_context, - ) - - if self._use_cpu_context: - # First Party - from lmcache.v1.distributed.api import MemoryLayoutDesc - from lmcache.v1.gpu_connector.utils import is_mla - from lmcache.v1.multiprocess.cpu_context import ( - CPUContextMetadata, - create_cpu_context, - ) - - ( - block_size, - num_layers, - hidden_dim_size, - dtype_str, - gpu_kv_format, - ) = compute_kv_layout(kv_caches, layout_hints=layout_hints) - self._cpu_layout_hints = layout_hints - self._cpu_gpu_kv_format = gpu_kv_format - future = send_lmcache_request( - self.mq_client, - RequestType.REGISTER_KV_CACHE_CPU_CONTEXT, - [ - self.instance_id, - self.model_name, - self.world_size, - EngineType.VLLM, - layout_hints, - block_size, - num_layers, - hidden_dim_size, - dtype_str, - is_mla(gpu_kv_format), - ], - ) - # Build the layout descriptor so we can construct the CPUContext. - use_mla_flag = is_mla(gpu_kv_format) - shape = ( - torch.Size( - [num_layers, self.blocks_in_chunk * block_size, hidden_dim_size] - ) - if use_mla_flag - else torch.Size( - [2, num_layers, self.blocks_in_chunk * block_size, hidden_dim_size] - ) - ) - dtype = getattr(torch, dtype_str) - metadata = CPUContextMetadata( - layout_desc=MemoryLayoutDesc(shapes=[shape], dtypes=[dtype]), - block_size=block_size, - use_mla=use_mla_flag, - ) - self.cpu_context = create_cpu_context( - metadata, self.mq_client, self._mq_timeout - ) - else: - future = send_lmcache_request( + self.transfer_ctx = create_transfer_context(kv_caches) + try: + self.transfer_ctx.register( + self.instance_id, + kv_caches, + self.model_name, + self.world_size, + self.blocks_in_chunk, self.mq_client, - RequestType.REGISTER_KV_CACHE, - [ - self.instance_id, - wrap_kv_caches(kv_caches), - self.model_name, - self.world_size, - EngineType.VLLM, - layout_hints, - ], + self._mq_timeout, + send_request=send_lmcache_request, ) - try: - future.result(timeout=self._mq_timeout) except TimeoutError: raise ConnectionError( "LMCache server did not respond to " @@ -1021,26 +928,21 @@ def submit_store_request( request_id=request_id, cache_salt=cache_salt, ) - if self._use_cpu_context: - assert self.cpu_context is not None - torch_dev.synchronize() - cpu_chunks = gather_paged_kv_to_cpu( - self.kv_caches, - op.block_ids, - self.blocks_in_chunk, - layout_hints=self._cpu_layout_hints, - gpu_kv_format=self._cpu_gpu_kv_format, + if self.transfer_ctx is None: + raise RuntimeError( + "Transfer context is not initialized. " + "Call register_kv_caches() before submitting store requests." ) - handle = self.cpu_context.prepare_store(key, self.instance_id, cpu_chunks) - ok = self.cpu_context.commit_store(handle) - self._cpu_store_done[request_id] = ok - else: - future = send_lmcache_request( - self.mq_client, - RequestType.STORE, - [key, self.instance_id, op.block_ids, event.ipc_handle()], - ).to_cuda_future() - self.store_futures[request_id] = future + self.transfer_ctx.submit_store( + request_id, + key, + self.instance_id, + self.kv_caches, + op.block_ids, + event, + self.blocks_in_chunk, + ) + self._pending_store_request_ids.add(request_id) @_lmcache_nvtx_annotate def submit_retrieve_request( @@ -1074,39 +976,21 @@ def submit_retrieve_request( request_id=request_id, cache_salt=cache_salt, ) - if self._use_cpu_context: - assert self.cpu_context is not None - handle, chunks = self.cpu_context.prepare_retrieve(key, self.instance_id) - ok = chunks is not None - if chunks is not None: - try: - scatter_cpu_to_paged_kv( - self.kv_caches, - op.block_ids, - chunks, - self.blocks_in_chunk, - skip_first_n_tokens=op.skip_first_n_tokens, - layout_hints=self._cpu_layout_hints, - gpu_kv_format=self._cpu_gpu_kv_format, - ) - except Exception: - logger.exception("Failed to scatter retrieved CPU context chunks") - ok = False - self.cpu_context.commit_retrieve(handle) - self._cpu_retrieve_done[request_id] = (ok, list(op.block_ids)) - else: - future = send_lmcache_request( - self.mq_client, - RequestType.RETRIEVE, - [ - key, - self.instance_id, - op.block_ids, - event.ipc_handle(), - op.skip_first_n_tokens, - ], - ).to_cuda_future() - self.retrieve_futures[request_id] = (future, list(op.block_ids)) + if self.transfer_ctx is None: + raise RuntimeError( + "Transfer context is not initialized. " + "Call register_kv_caches() before submitting retrieve requests." + ) + self.transfer_ctx.submit_retrieve( + request_id, + key, + self.instance_id, + self.kv_caches, + op.block_ids, + event, + self.blocks_in_chunk, + skip_first_n_tokens=op.skip_first_n_tokens, + ) @_lmcache_nvtx_annotate def batched_submit_store_requests( @@ -1171,7 +1055,10 @@ def _process_finished_stores( for req_id in finished_req_ids_from_engine: if req_id in self._returned_finished: continue - if req_id in self.finished_stores or req_id in self.store_futures: + if ( + req_id in self.finished_stores + or req_id in self._pending_store_request_ids + ): self.previously_finished.add(req_id) else: ret_stores.add(req_id) @@ -1203,106 +1090,35 @@ def get_finished( take care of deduplicating the request IDs and only return the request IDs that have not been returned before. """ - # If unhealthy, drain all pending futures immediately - if not self.is_healthy: - finished_stores = set(self.store_futures.keys()) | set( - self._cpu_store_done.keys() + if self.transfer_ctx is None: + return set(), set() + + unhealthy = not self.is_healthy + if unhealthy: + finished_stores, finished_retrieves, error_block_ids = ( + self.transfer_ctx.drain_all() ) - finished_retrieves = set() - for request_id, ( - _r_future, - r_block_ids, - ) in self.retrieve_futures.items(): - finished_retrieves.add(request_id) - self.error_block_ids.update(r_block_ids) - for request_id, (ok, r_block_ids) in self._cpu_retrieve_done.items(): - finished_retrieves.add(request_id) - if not ok: - self.error_block_ids.update(r_block_ids) - self.store_futures.clear() - self.retrieve_futures.clear() - self._cpu_store_done.clear() - self._cpu_retrieve_done.clear() - - ret_stores = self._process_finished_stores( - finished_stores, finished_req_ids_from_engine + else: + finished_stores, finished_retrieves, error_block_ids = ( + self.transfer_ctx.poll_finished() ) - # A request may have a pending retrieve AND appear in - # finished_req_ids_from_engine (it ran without loading KV after - # the server died). The scheduler processes finished_recving - # first and deletes the request, so we must not also report it - # in finished_sending. - ret_stores -= finished_retrieves - return ret_stores, finished_retrieves - - finished_stores = set() - finished_retrieves = set() - - # Drain completed synchronous CPU store results - for request_id, ok in list(self._cpu_store_done.items()): - finished_stores.add(request_id) - if not ok: - logger.error( - "Something went wrong when processing the " - "store request for request_id=%s", - request_id, - ) - self._cpu_store_done.clear() - - # Drain completed synchronous CPU retrieve results - for request_id, (ok, r_block_ids) in list(self._cpu_retrieve_done.items()): - finished_retrieves.add(request_id) - if not ok: - logger.error( - "Something went wrong when processing the " - "retrieve request for request_id=%s, result=%s", - request_id, - ok, - ) - self.error_block_ids.update(r_block_ids) - self._cpu_retrieve_done.clear() - - for request_id, s_future in self.store_futures.items(): - if not s_future.query(): - continue - - s_result = s_future.result() - finished_stores.add(request_id) - if not s_result: - logger.error( - "Something went wrong when processing the " - "store request for request_id=%s", - request_id, - ) - - for request_id, (r_future, r_block_ids) in self.retrieve_futures.items(): - if not r_future.query(): - continue - - r_result = r_future.result() - finished_retrieves.add(request_id) - - if not r_result: - logger.error( - "Something went wrong when processing the " - "retrieve request for request_id=%s, result=%s", - request_id, - r_result, - ) - self.error_block_ids.update(r_block_ids) - - # Remove the finished requests from the tracking dicts - for request_id in finished_stores: - self.store_futures.pop(request_id, None) - for request_id in finished_retrieves: - self.retrieve_futures.pop(request_id, None) + self.error_block_ids.update(error_block_ids) + self._pending_store_request_ids.difference_update(finished_stores) # Update the internal states ret_stores = self._process_finished_stores( finished_stores, finished_req_ids_from_engine ) + if unhealthy: + # A request may have a pending retrieve AND appear in + # finished_req_ids_from_engine (it ran without loading KV after + # the server died). The scheduler processes finished_recving + # first and deletes the request, so we must not also report it + # in finished_sending. + ret_stores -= finished_retrieves + # the invocation of `get_finished` means that # these requests' KV caches are already fully stored. # or the requests normally ends without any store. @@ -1350,6 +1166,10 @@ def shutdown(self): self._mq_timeout, ) + if self.transfer_ctx is not None: + self.transfer_ctx.close() + self.transfer_ctx = None + self.mq_client.close() self.request_telemetry.close() diff --git a/lmcache/v1/multiprocess/transfer_context.py b/lmcache/v1/multiprocess/transfer_context.py new file mode 100644 index 00000000000..12362d35923 --- /dev/null +++ b/lmcache/v1/multiprocess/transfer_context.py @@ -0,0 +1,484 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Transfer context abstractions for LMCache multiprocess worker adapters.""" + +# Standard +from abc import ABC, abstractmethod +from typing import Any, Callable, Protocol + +# Third Party +import torch + +# First Party +from lmcache import torch_dev +from lmcache.utils import EngineType, init_logger +from lmcache.v1.distributed.api import MemoryLayoutDesc +from lmcache.v1.gpu_connector.utils import is_mla +from lmcache.v1.multiprocess.cpu_context import ( + CPUContext, + CPUContextMetadata, + compute_kv_layout, + create_cpu_context, + gather_paged_kv_to_cpu, + scatter_cpu_to_paged_kv, +) +from lmcache.v1.multiprocess.mq import MessageQueueClient, MessagingFuture +from lmcache.v1.multiprocess.protocol import RequestType + +logger = init_logger(__name__) + + +class IPCEvent(Protocol): + """Protocol for IPC-capable CUDA events used by transport operations.""" + + def ipc_handle(self) -> object: + """Return an IPC handle consumable by the multiprocess server.""" + + +SendRequest = Callable[ + [MessageQueueClient, RequestType, list[object]], MessagingFuture[object] +] + + +class TransferContext(ABC): + """Abstract transport layer for worker-side KV transfer. + + Concrete implementations encapsulate how worker-side store/retrieve + operations are transmitted to the multiprocess server (for example, + CUDA IPC futures or CPU-context gather/scatter flows). + """ + + @abstractmethod + def register( + self, + instance_id: int, + kv_caches: dict[str, torch.Tensor], + model_name: str, + world_size: int, + blocks_in_chunk: int, + mq_client: MessageQueueClient, + mq_timeout: float, + send_request: SendRequest, + ) -> None: + """Register KV caches with the server and wait for ACK. + + Args: + instance_id: Worker process instance id. + kv_caches: Worker KV cache tensors keyed by layer name. + model_name: Model name used by cache keys. + world_size: KV world size. + blocks_in_chunk: Number of vLLM blocks in one LMCache chunk. + mq_client: Message queue client used to communicate with server. + mq_timeout: Timeout in seconds for synchronous request wait. + send_request: Request sender callable used to issue MQ requests. + """ + + @abstractmethod + def submit_store( + self, + request_id: str, + key: Any, + instance_id: int, + kv_caches: dict[str, torch.Tensor], + block_ids: list[int], + event: IPCEvent, + blocks_in_chunk: int, + ) -> None: + """Submit a store request. + + Args: + request_id: Request identifier. + key: LMCache key object. + instance_id: Worker process instance id. + kv_caches: Worker KV cache tensors keyed by layer name. + block_ids: vLLM block ids to store. + event: Synchronization event object. + blocks_in_chunk: Number of vLLM blocks in one LMCache chunk. + """ + + @abstractmethod + def submit_retrieve( + self, + request_id: str, + key: Any, + instance_id: int, + kv_caches: dict[str, torch.Tensor], + block_ids: list[int], + event: IPCEvent, + blocks_in_chunk: int, + skip_first_n_tokens: int = 0, + ) -> None: + """Submit a retrieve request. + + Args: + request_id: Request identifier. + key: LMCache key object. + instance_id: Worker process instance id. + kv_caches: Worker KV cache tensors keyed by layer name. + block_ids: vLLM block ids to retrieve. + event: Synchronization event object. + blocks_in_chunk: Number of vLLM blocks in one LMCache chunk. + skip_first_n_tokens: Number of tokens to skip for partial scatter. + """ + + @abstractmethod + def poll_finished(self) -> tuple[set[str], set[str], set[int]]: + """Poll completed requests. + + Returns: + Tuple of ``(finished_store_ids, finished_retrieve_ids, error_block_ids)``. + """ + + @abstractmethod + def drain_all(self) -> tuple[set[str], set[str], set[int]]: + """Drain all pending requests. + + Returns: + Tuple of ``(finished_store_ids, finished_retrieve_ids, error_block_ids)``. + """ + + @abstractmethod + def close(self) -> None: + """Release resources held by this context.""" + + +class CudaTransferContext(TransferContext): + """CUDA IPC + MQ future transport context.""" + + def __init__(self) -> None: + self._store_futures: dict[str, Any] = {} + self._retrieve_futures: dict[str, tuple[Any, list[int]]] = {} + self._mq_client: MessageQueueClient | None = None + self._mq_timeout: float = 0.0 + self._send_request: SendRequest | None = None + + def register( + self, + instance_id: int, + kv_caches: dict[str, torch.Tensor], + model_name: str, + world_size: int, + _blocks_in_chunk: int, + mq_client: MessageQueueClient, + mq_timeout: float, + send_request: SendRequest, + ) -> None: + # First Party + from lmcache.integration.vllm.utils import vllm_layout_hints + from lmcache.integration.vllm.vllm_multi_process_adapter import wrap_kv_caches + + self._mq_client = mq_client + self._mq_timeout = mq_timeout + self._send_request = send_request + layout_hints = vllm_layout_hints() + future = send_request( + mq_client, + RequestType.REGISTER_KV_CACHE, + [ + instance_id, + wrap_kv_caches(kv_caches), + model_name, + world_size, + EngineType.VLLM, + layout_hints, + ], + ) + future.result(timeout=mq_timeout) + + def submit_store( + self, + request_id: str, + key: Any, + instance_id: int, + _kv_caches: dict[str, torch.Tensor], + block_ids: list[int], + event: IPCEvent, + _blocks_in_chunk: int, + ) -> None: + if ( + self._mq_client is None + or self._send_request is None + or self._mq_timeout < 0 + ): + raise RuntimeError( + "CUDA transfer context is not registered. " + "Call register() before submit_store()." + ) + future = self._send_request( + self._mq_client, + RequestType.STORE, + [key, instance_id, block_ids, event.ipc_handle()], + ).to_cuda_future() + self._store_futures[request_id] = future + + def submit_retrieve( + self, + request_id: str, + key: Any, + instance_id: int, + _kv_caches: dict[str, torch.Tensor], + block_ids: list[int], + event: IPCEvent, + _blocks_in_chunk: int, + skip_first_n_tokens: int = 0, + ) -> None: + if ( + self._mq_client is None + or self._send_request is None + or self._mq_timeout < 0 + ): + raise RuntimeError( + "CUDA transfer context is not registered. " + "Call register() before submit_retrieve()." + ) + future = self._send_request( + self._mq_client, + RequestType.RETRIEVE, + [key, instance_id, block_ids, event.ipc_handle(), skip_first_n_tokens], + ).to_cuda_future() + self._retrieve_futures[request_id] = (future, list(block_ids)) + + def poll_finished(self) -> tuple[set[str], set[str], set[int]]: + finished_stores: set[str] = set() + finished_retrieves: set[str] = set() + error_block_ids: set[int] = set() + + for request_id, s_future in list(self._store_futures.items()): + if not s_future.query(): + continue + s_result = s_future.result() + finished_stores.add(request_id) + if not s_result: + logger.error( + "Something went wrong when processing the store request " + "for request_id=%s", + request_id, + ) + self._store_futures.pop(request_id, None) + + for request_id, (r_future, r_block_ids) in list(self._retrieve_futures.items()): + if not r_future.query(): + continue + r_result = r_future.result() + finished_retrieves.add(request_id) + if not r_result: + logger.error( + "Something went wrong when processing the retrieve request " + "for request_id=%s, result=%s", + request_id, + r_result, + ) + error_block_ids.update(r_block_ids) + self._retrieve_futures.pop(request_id, None) + + return finished_stores, finished_retrieves, error_block_ids + + def drain_all(self) -> tuple[set[str], set[str], set[int]]: + finished_stores = set(self._store_futures.keys()) + finished_retrieves = set(self._retrieve_futures.keys()) + error_block_ids: set[int] = set() + for _request_id, (_r_future, block_ids) in self._retrieve_futures.items(): + error_block_ids.update(block_ids) + self._store_futures.clear() + self._retrieve_futures.clear() + return finished_stores, finished_retrieves, error_block_ids + + def close(self) -> None: + self._store_futures.clear() + self._retrieve_futures.clear() + self._mq_client = None + self._mq_timeout = 0.0 + self._send_request = None + + +class CPUTransferContext(TransferContext): + """CPU context transport for non-CUDA workers.""" + + def __init__(self) -> None: + self._cpu_context: CPUContext | None = None + self._layout_hints: Any = None + self._gpu_kv_format: Any = None + self._store_done: dict[str, bool] = {} + self._retrieve_done: dict[str, tuple[bool, list[int]]] = {} + self._mq_client: MessageQueueClient | None = None + self._mq_timeout: float = 0.0 + self._send_request: SendRequest | None = None + + def register( + self, + instance_id: int, + kv_caches: dict[str, torch.Tensor], + model_name: str, + world_size: int, + blocks_in_chunk: int, + mq_client: MessageQueueClient, + mq_timeout: float, + send_request: SendRequest, + ) -> None: + # First Party + from lmcache.integration.vllm.utils import vllm_layout_hints + + self._mq_client = mq_client + self._mq_timeout = mq_timeout + self._send_request = send_request + layout_hints = vllm_layout_hints() + ( + block_size, + num_layers, + hidden_dim_size, + dtype_str, + gpu_kv_format, + ) = compute_kv_layout(kv_caches, layout_hints=layout_hints) + self._layout_hints = layout_hints + self._gpu_kv_format = gpu_kv_format + + future = send_request( + mq_client, + RequestType.REGISTER_KV_CACHE_CPU_CONTEXT, + [ + instance_id, + model_name, + world_size, + EngineType.VLLM, + layout_hints, + block_size, + num_layers, + hidden_dim_size, + dtype_str, + is_mla(gpu_kv_format), + ], + ) + + use_mla_flag = is_mla(gpu_kv_format) + shape = ( + torch.Size([num_layers, blocks_in_chunk * block_size, hidden_dim_size]) + if use_mla_flag + else torch.Size( + [2, num_layers, blocks_in_chunk * block_size, hidden_dim_size] + ) + ) + dtype = getattr(torch, dtype_str) + metadata = CPUContextMetadata( + layout_desc=MemoryLayoutDesc(shapes=[shape], dtypes=[dtype]), + block_size=block_size, + use_mla=use_mla_flag, + ) + self._cpu_context = create_cpu_context(metadata, mq_client, mq_timeout) + future.result(timeout=mq_timeout) + + def submit_store( + self, + request_id: str, + key: Any, + instance_id: int, + kv_caches: dict[str, torch.Tensor], + block_ids: list[int], + _event: IPCEvent, + blocks_in_chunk: int, + ) -> None: + if self._cpu_context is None: + raise RuntimeError( + "CPU transfer context is not registered. " + "Call register() before submit_store()." + ) + torch_dev.synchronize() + cpu_chunks = gather_paged_kv_to_cpu( + kv_caches, + block_ids, + blocks_in_chunk, + layout_hints=self._layout_hints, + gpu_kv_format=self._gpu_kv_format, + ) + handle = self._cpu_context.prepare_store(key, instance_id, cpu_chunks) + ok = self._cpu_context.commit_store(handle) + self._store_done[request_id] = ok + + def submit_retrieve( + self, + request_id: str, + key: Any, + instance_id: int, + kv_caches: dict[str, torch.Tensor], + block_ids: list[int], + _event: IPCEvent, + blocks_in_chunk: int, + skip_first_n_tokens: int = 0, + ) -> None: + if self._cpu_context is None: + raise RuntimeError( + "CPU transfer context is not registered. " + "Call register() before submit_retrieve()." + ) + handle, chunks = self._cpu_context.prepare_retrieve(key, instance_id) + ok = chunks is not None + if chunks is not None: + try: + scatter_cpu_to_paged_kv( + kv_caches, + block_ids, + chunks, + blocks_in_chunk, + skip_first_n_tokens=skip_first_n_tokens, + layout_hints=self._layout_hints, + gpu_kv_format=self._gpu_kv_format, + ) + except (RuntimeError, ValueError, TypeError, IndexError): + logger.exception("Failed to scatter retrieved CPU context chunks") + ok = False + self._cpu_context.commit_retrieve(handle) + self._retrieve_done[request_id] = (ok, list(block_ids)) + + def poll_finished(self) -> tuple[set[str], set[str], set[int]]: + finished_stores = set(self._store_done.keys()) + finished_retrieves = set(self._retrieve_done.keys()) + error_block_ids: set[int] = set() + for ok, block_ids in self._retrieve_done.values(): + if not ok: + error_block_ids.update(block_ids) + self._store_done.clear() + self._retrieve_done.clear() + return finished_stores, finished_retrieves, error_block_ids + + def drain_all(self) -> tuple[set[str], set[str], set[int]]: + return self.poll_finished() + + def close(self) -> None: + if self._cpu_context is not None: + self._cpu_context.close() + self._cpu_context = None + self._store_done.clear() + self._retrieve_done.clear() + self._mq_client = None + self._mq_timeout = 0.0 + self._send_request = None + + +def create_transfer_context( + kv_caches: dict[str, torch.Tensor], + **_kwargs: Any, +) -> TransferContext: + """Create a transfer context from KV cache device type. + + The device check is intentionally centralized here. + + Args: + kv_caches: Worker KV cache tensors keyed by layer name. + **kwargs: Unused placeholder for forward-compatible factory extension. + + Returns: + A concrete :class:`TransferContext` implementation. + + Raises: + ValueError: If ``kv_caches`` is empty or has mixed device types. + """ + if not kv_caches: + raise ValueError("kv_caches is empty") + device_types = {tensor.device.type for tensor in kv_caches.values()} + if len(device_types) != 1: + raise ValueError( + f"All KV cache tensors must share one device type, got {device_types}" + ) + device_type = next(iter(device_types)) + logger.info("Creating transfer context (device_type=%s)", device_type) + if device_type == "cuda": + return CudaTransferContext() + return CPUTransferContext() diff --git a/tests/v1/multiprocess/test_cpu_context.py b/tests/v1/multiprocess/test_cpu_context.py index ea39a45c9c7..90b7d9d9b65 100644 --- a/tests/v1/multiprocess/test_cpu_context.py +++ b/tests/v1/multiprocess/test_cpu_context.py @@ -94,7 +94,7 @@ def test_wrap_kv_caches_cpu_context_returns_empty() -> None: def test_compute_kv_layout_and_gather_scatter_roundtrip() -> None: """Validate layout extraction and gather/scatter round-trip on CPU tensors.""" # First Party - from lmcache.integration.vllm.vllm_multi_process_adapter import ( + from lmcache.v1.multiprocess.cpu_context import ( compute_kv_layout, gather_paged_kv_to_cpu, scatter_cpu_to_paged_kv, @@ -137,7 +137,7 @@ def test_gather_scatter_roundtrip_hnd_layout( ) -> None: """Validate gather/scatter round-trip for HND vLLM KV layout.""" # First Party - from lmcache.integration.vllm.vllm_multi_process_adapter import ( + from lmcache.v1.multiprocess.cpu_context import ( compute_kv_layout, gather_paged_kv_to_cpu, scatter_cpu_to_paged_kv, @@ -189,7 +189,7 @@ def test_gather_scatter_roundtrip_hnd_layout( def test_scatter_respects_skip_first_n_tokens() -> None: """Ensure scatter honors skip_first_n_tokens and preserves skipped blocks.""" # First Party - from lmcache.integration.vllm.vllm_multi_process_adapter import ( + from lmcache.v1.multiprocess.cpu_context import ( gather_paged_kv_to_cpu, scatter_cpu_to_paged_kv, ) @@ -217,7 +217,7 @@ def test_scatter_respects_skip_first_n_tokens() -> None: def test_compute_kv_layout_and_gather_scatter_roundtrip_mla() -> None: """Validate gather/scatter round-trip for MLA KV tensors.""" # First Party - from lmcache.integration.vllm.vllm_multi_process_adapter import ( + from lmcache.v1.multiprocess.cpu_context import ( compute_kv_layout, gather_paged_kv_to_cpu, scatter_cpu_to_paged_kv, @@ -252,7 +252,7 @@ def test_compute_kv_layout_and_gather_scatter_roundtrip_mla() -> None: def test_compute_kv_layout_empty_raises_value_error() -> None: """Ensure compute_kv_layout rejects empty KV cache input.""" # First Party - from lmcache.integration.vllm.vllm_multi_process_adapter import compute_kv_layout + from lmcache.v1.multiprocess.cpu_context import compute_kv_layout with pytest.raises(ValueError, match="kv_caches is empty"): compute_kv_layout({}) @@ -261,7 +261,7 @@ def test_compute_kv_layout_empty_raises_value_error() -> None: def test_scatter_mla_respects_skip_first_n_tokens() -> None: """Ensure MLA scatter honors skip_first_n_tokens and preserves skipped blocks.""" # First Party - from lmcache.integration.vllm.vllm_multi_process_adapter import ( + from lmcache.v1.multiprocess.cpu_context import ( gather_paged_kv_to_cpu, scatter_cpu_to_paged_kv, ) @@ -291,7 +291,7 @@ def test_scatter_mla_respects_skip_first_n_tokens() -> None: def test_scatter_mla_skip_past_chunk_keeps_destination_unchanged() -> None: """Ensure MLA scatter is a no-op when skip_first_n_tokens exceeds chunk tokens.""" # First Party - from lmcache.integration.vllm.vllm_multi_process_adapter import ( + from lmcache.v1.multiprocess.cpu_context import ( gather_paged_kv_to_cpu, scatter_cpu_to_paged_kv, ) diff --git a/tests/v1/test_vllm_mp_adapter.py b/tests/v1/test_vllm_mp_adapter.py index 854dc6137ab..84edb347bcb 100644 --- a/tests/v1/test_vllm_mp_adapter.py +++ b/tests/v1/test_vllm_mp_adapter.py @@ -14,11 +14,13 @@ # Third Party import pytest +import torch # First Party from lmcache.integration.vllm import vllm_multi_process_adapter as adapter_mod from lmcache.integration.vllm.vllm_multi_process_adapter import ( LMCacheMPWorkerAdapter, + LoadStoreOp, ParallelStrategy, ) from lmcache.v1.multiprocess.protocol import RequestType @@ -98,3 +100,63 @@ def test_register_kv_caches_raises_connection_error_on_timeout(fake_adapter): fake_tensor = MagicMock() fake_tensor.device.type = "cuda" adapter.register_kv_caches({"layer.0": fake_tensor}) + + +def test_register_kv_caches_cpu_submits_cpu_context_registration( + fake_adapter, monkeypatch +): + """CPU KV cache registration routes to REGISTER_KV_CACHE_CPU_CONTEXT.""" + adapter, send_mock, _ = fake_adapter + monkeypatch.setattr( + "lmcache.integration.vllm.utils.vllm_layout_hints", + lambda: {}, + raising=False, + ) + cpu_kv = {"layer.0": torch.randn(2, 8, 4, 2, 8)} + + adapter.register_kv_caches(cpu_kv) + + assert adapter.kv_caches is cpu_kv + assert send_mock.call_count == 1 + args, _kwargs = send_mock.call_args + assert args[1] == RequestType.REGISTER_KV_CACHE_CPU_CONTEXT + + +def test_submit_store_request_passes_no_transport_kwargs(fake_adapter, monkeypatch): + """submit_store_request should not pass mq/send kwargs after registration.""" + adapter, _send_mock, _ = fake_adapter + monkeypatch.setattr(adapter, "_ensure_heartbeat_started", lambda: None) + fake_tensor = MagicMock() + fake_tensor.device.type = "cuda" + adapter.kv_caches = {"layer.0": fake_tensor} + transfer_ctx = MagicMock() + adapter.transfer_ctx = transfer_ctx + op = LoadStoreOp(token_ids=[1, 2, 3, 4], block_ids=[0], start=0, end=4) + + adapter.submit_store_request("req-1", op, event=MagicMock()) + + assert transfer_ctx.submit_store.called + assert transfer_ctx.submit_store.call_args.kwargs == {} + + +def test_submit_retrieve_request_passes_no_transport_kwargs(fake_adapter, monkeypatch): + """submit_retrieve_request should not pass mq/send kwargs after registration.""" + adapter, _send_mock, _ = fake_adapter + monkeypatch.setattr(adapter, "_ensure_heartbeat_started", lambda: None) + fake_tensor = MagicMock() + fake_tensor.device.type = "cuda" + adapter.kv_caches = {"layer.0": fake_tensor} + transfer_ctx = MagicMock() + adapter.transfer_ctx = transfer_ctx + op = LoadStoreOp( + token_ids=[1, 2, 3, 4], + block_ids=[0], + start=0, + end=4, + skip_first_n_tokens=1, + ) + + adapter.submit_retrieve_request("req-1", op, event=MagicMock()) + + assert transfer_ctx.submit_retrieve.called + assert transfer_ctx.submit_retrieve.call_args.kwargs == {"skip_first_n_tokens": 1} From b665d41a7a9dbb361d76e9aed33d1b243a7f4013 Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Wed, 13 May 2026 06:44:45 +0000 Subject: [PATCH 08/23] restore unnecessary changes Signed-off-by: Tony Lin --- .../vllm/vllm_multi_process_adapter.py | 106 +++++--- lmcache/v1/multiprocess/protocols/engine.py | 6 +- lmcache/v1/multiprocess/server.py | 103 +++----- lmcache/v1/multiprocess/transfer_context.py | 229 ++++++------------ tests/v1/multiprocess/test_cpu_context.py | 39 +-- tests/v1/test_vllm_mp_adapter.py | 16 +- 6 files changed, 212 insertions(+), 287 deletions(-) diff --git a/lmcache/integration/vllm/vllm_multi_process_adapter.py b/lmcache/integration/vllm/vllm_multi_process_adapter.py index 10fec34cab7..02fcb031485 100644 --- a/lmcache/integration/vllm/vllm_multi_process_adapter.py +++ b/lmcache/integration/vllm/vllm_multi_process_adapter.py @@ -36,11 +36,7 @@ DEFAULT_HEARTBEAT_INTERVAL: float = 10.0 -def wrap_kv_caches( - kv_caches: dict[str, torch.Tensor], use_cpu_context: bool = False -) -> KVCache: - if use_cpu_context: - return [] +def wrap_kv_caches(kv_caches: dict[str, torch.Tensor]) -> KVCache: logger.info("KV caches keys are %s", list(kv_caches.keys())) return [CudaIPCWrapper(tensor) for tensor in kv_caches.values()] @@ -704,8 +700,13 @@ def __init__( # Transport context for transfer operations. self.transfer_ctx: TransferContext | None = None - # Store requests submitted but not yet finished by transfer context. - self._pending_store_request_ids: set[str] = set() + + # Request futures + self.store_futures: dict[str, MessagingFuture[StoreResult]] = {} + # request_id -> (future, block_ids) + self.retrieve_futures: dict[ + str, tuple[MessagingFuture[RetrieveResult], list[int]] + ] = {} # Block IDs that failed due to retrieve timeout self.error_block_ids: set[int] = set() @@ -933,7 +934,7 @@ def submit_store_request( "Transfer context is not initialized. " "Call register_kv_caches() before submitting store requests." ) - self.transfer_ctx.submit_store( + future = self.transfer_ctx.submit_store( request_id, key, self.instance_id, @@ -942,7 +943,7 @@ def submit_store_request( event, self.blocks_in_chunk, ) - self._pending_store_request_ids.add(request_id) + self.store_futures[request_id] = future @_lmcache_nvtx_annotate def submit_retrieve_request( @@ -981,7 +982,7 @@ def submit_retrieve_request( "Transfer context is not initialized. " "Call register_kv_caches() before submitting retrieve requests." ) - self.transfer_ctx.submit_retrieve( + future = self.transfer_ctx.submit_retrieve( request_id, key, self.instance_id, @@ -991,6 +992,7 @@ def submit_retrieve_request( self.blocks_in_chunk, skip_first_n_tokens=op.skip_first_n_tokens, ) + self.retrieve_futures[request_id] = (future, list(op.block_ids)) @_lmcache_nvtx_annotate def batched_submit_store_requests( @@ -1055,10 +1057,7 @@ def _process_finished_stores( for req_id in finished_req_ids_from_engine: if req_id in self._returned_finished: continue - if ( - req_id in self.finished_stores - or req_id in self._pending_store_request_ids - ): + if req_id in self.finished_stores or req_id in self.store_futures: self.previously_finished.add(req_id) else: ret_stores.add(req_id) @@ -1090,35 +1089,72 @@ def get_finished( take care of deduplicating the request IDs and only return the request IDs that have not been returned before. """ - if self.transfer_ctx is None: - return set(), set() - - unhealthy = not self.is_healthy - if unhealthy: - finished_stores, finished_retrieves, error_block_ids = ( - self.transfer_ctx.drain_all() - ) - else: - finished_stores, finished_retrieves, error_block_ids = ( - self.transfer_ctx.poll_finished() + # If unhealthy, drain all pending futures immediately + if not self.is_healthy: + finished_stores = set(self.store_futures.keys()) + finished_retrieves = set() + for request_id, ( + _r_future, + r_block_ids, + ) in self.retrieve_futures.items(): + finished_retrieves.add(request_id) + self.error_block_ids.update(r_block_ids) + self.store_futures.clear() + self.retrieve_futures.clear() + + ret_stores = self._process_finished_stores( + finished_stores, finished_req_ids_from_engine ) + # A request may have a pending retrieve AND appear in + # finished_req_ids_from_engine (it ran without loading KV after + # the server died). The scheduler processes finished_recving + # first and deletes the request, so we must not also report it + # in finished_sending. + ret_stores -= finished_retrieves + return ret_stores, finished_retrieves - self.error_block_ids.update(error_block_ids) - self._pending_store_request_ids.difference_update(finished_stores) + finished_stores = set() + finished_retrieves = set() + for request_id, s_future in self.store_futures.items(): + if not s_future.query(): + continue + + s_result = s_future.result() + finished_stores.add(request_id) + + if not s_result: + logger.error( + "Something went wrong when processing the " + "store request for request_id=%s", + request_id, + ) + + for request_id, (r_future, _) in self.retrieve_futures.items(): + if not r_future.query(): + continue + + r_result = r_future.result() + finished_retrieves.add(request_id) + + if not r_result: + logger.error( + "Something went wrong when processing the " + "retrieve request for request_id=%s, result=%s", + request_id, + r_result, + ) + + # Remove the finished requests from the tracking dicts + for request_id in finished_stores: + self.store_futures.pop(request_id, None) + for request_id in finished_retrieves: + self.retrieve_futures.pop(request_id, None) # Update the internal states ret_stores = self._process_finished_stores( finished_stores, finished_req_ids_from_engine ) - if unhealthy: - # A request may have a pending retrieve AND appear in - # finished_req_ids_from_engine (it ran without loading KV after - # the server died). The scheduler processes finished_recving - # first and deletes the request, so we must not also report it - # in finished_sending. - ret_stores -= finished_retrieves - # the invocation of `get_finished` means that # these requests' KV caches are already fully stored. # or the requests normally ends without any store. diff --git a/lmcache/v1/multiprocess/protocols/engine.py b/lmcache/v1/multiprocess/protocols/engine.py index df84c788ce4..41d96200eba 100644 --- a/lmcache/v1/multiprocess/protocols/engine.py +++ b/lmcache/v1/multiprocess/protocols/engine.py @@ -154,12 +154,8 @@ def get_protocol_definitions() -> dict[str, ProtocolDefinition]: int, str, int, - EngineType, - LayoutHints, + bytes, int, - int, - int, - str, bool, ], response_class=None, diff --git a/lmcache/v1/multiprocess/server.py b/lmcache/v1/multiprocess/server.py index 2f932490b59..61431750979 100644 --- a/lmcache/v1/multiprocess/server.py +++ b/lmcache/v1/multiprocess/server.py @@ -10,6 +10,7 @@ import time # Third Party +import torch import zmq # First Party @@ -289,12 +290,8 @@ def register_kv_cache_cpu_context( instance_id: int, model_name: str, world_size: int, - engine_type: EngineType, - layout_hints: LayoutHints, + layout_desc_bytes: bytes, block_size: int, - num_layers: int, - hidden_dim_size: int, - dtype_str: str, use_mla: bool, ) -> None: """Register non-CUDA KV layout metadata for CPU context mode. @@ -303,41 +300,11 @@ def register_kv_cache_cpu_context( instance_id: Worker instance identifier (typically PID). model_name: Model name associated with this worker. world_size: Worker world size used in cache keys. - engine_type: Serving engine type (kept for protocol compatibility). - layout_hints: Optional engine layout hints (protocol compatibility). + layout_desc_bytes: Pickled :class:`MemoryLayoutDesc`. block_size: Tokens per paged block. - num_layers: Number of model layers. - hidden_dim_size: Flattened hidden dimension per token. - dtype_str: Torch dtype name (for example ``"float16"``). use_mla: Whether the worker KV format is MLA. - MLA stores one latent vector per token with shape - ``[num_layers, chunk_size, hidden_dim_size]``; non-MLA stores - separate K/V planes with shape - ``[2, num_layers, chunk_size, hidden_dim_size]``. - - Raises: - ValueError: If ``dtype_str`` is not a valid torch dtype name. """ - # Third Party - import torch - - # Keep these for protocol compatibility with register_kv_cache(). - del engine_type, layout_hints - dtype = getattr(torch, dtype_str, None) - if dtype is None or not isinstance(dtype, torch.dtype): - raise ValueError( - f"Invalid dtype_str '{dtype_str}': expected a torch.dtype name " - "(e.g. 'float16', 'bfloat16', 'float32')." - ) - - shape = ( - # MLA has one latent state per token (no separate K/V axis), - # while non-MLA stores separate K and V tensors at dim 0. - torch.Size([num_layers, self.chunk_size, hidden_dim_size]) - if use_mla - else torch.Size([2, num_layers, self.chunk_size, hidden_dim_size]) - ) - layout_desc = MemoryLayoutDesc(shapes=[shape], dtypes=[dtype]) + layout_desc = pickle.loads(layout_desc_bytes) self.cpu_contexts[instance_id] = CPUContextMetadata( layout_desc=layout_desc, block_size=block_size, @@ -345,6 +312,27 @@ def register_kv_cache_cpu_context( ) self.cpu_context_meta[instance_id] = (model_name, world_size) + def _resolve_obj_keys(self, key: IPCCacheEngineKey) -> list[ObjectKey]: + """Resolve object keys from an IPC cache key. + + Args: + key: IPC cache key describing model/session/token range. + + Returns: + Resolved object keys for the requested token range. + + Raises: + ValueError: If ``key.worker_id`` is ``None``. + """ + session = self.session_manager.get_or_create(key.request_id) + session.set_tokens(list(key.token_ids)) + chunk_hashes = [ + TokenHasher.hash_to_bytes(h) for h in session.get_hashes(key.start, key.end) + ] + if key.worker_id is None: + raise ValueError("Must resolve keys with worker_id != None") + return ipc_key_to_object_keys(key, chunk_hashes) + @_lmcache_nvtx_annotate def store_cpu_chunks( self, @@ -365,17 +353,7 @@ def store_cpu_chunks( Raises: ValueError: If the instance has no registered cpu context. """ - # Third Party - import torch - - session = self.session_manager.get_or_create(key.request_id) - session.set_tokens(list(key.token_ids)) - chunk_hashes = [ - TokenHasher.hash_to_bytes(h) for h in session.get_hashes(key.start, key.end) - ] - if key.worker_id is None: - raise ValueError("Must store with worker_id != None") - obj_keys = ipc_key_to_object_keys(key, chunk_hashes) + obj_keys = self._resolve_obj_keys(key) if instance_id not in self.cpu_contexts: raise ValueError( @@ -426,14 +404,7 @@ def retrieve_cpu_chunks( Raises: ValueError: If the instance has no registered cpu context. """ - session = self.session_manager.get_or_create(key.request_id) - session.set_tokens(list(key.token_ids)) - chunk_hashes = [ - TokenHasher.hash_to_bytes(h) for h in session.get_hashes(key.start, key.end) - ] - if key.worker_id is None: - raise ValueError("Must retrieve with worker_id != None") - obj_keys = ipc_key_to_object_keys(key, chunk_hashes) + obj_keys = self._resolve_obj_keys(key) if instance_id not in self.cpu_contexts: raise ValueError( @@ -479,16 +450,8 @@ def store( that signals the completion of the store operation. The second element indicates whether the store operation was successful. """ - session = self.session_manager.get_or_create(key.request_id) - session.set_tokens(list(key.token_ids)) - chunk_hashes = [ - TokenHasher.hash_to_bytes(h) for h in session.get_hashes(key.start, key.end) - ] - st = time.perf_counter() - - assert key.worker_id is not None, "Must store with worker_id != None" - obj_keys = ipc_key_to_object_keys(key, chunk_hashes) + obj_keys = self._resolve_obj_keys(key) assert instance_id in self.gpu_contexts, ( f"KV cache not registered for GPU ID {instance_id}" @@ -655,16 +618,8 @@ def retrieve( that signals the completion of the retrieve operation. The second element indicates whether the key was successfully retrieved. """ - session = self.session_manager.get_or_create(key.request_id) - session.set_tokens(list(key.token_ids)) - chunk_hashes = [ - TokenHasher.hash_to_bytes(h) for h in session.get_hashes(key.start, key.end) - ] - st = time.perf_counter() - - assert key.worker_id is not None, "Must retrieve with worker_id != None" - obj_keys = ipc_key_to_object_keys(key, chunk_hashes) + obj_keys = self._resolve_obj_keys(key) assert instance_id in self.gpu_contexts, ( f"KV cache not registered for GPU ID {instance_id}" diff --git a/lmcache/v1/multiprocess/transfer_context.py b/lmcache/v1/multiprocess/transfer_context.py index 12362d35923..94f42e338f0 100644 --- a/lmcache/v1/multiprocess/transfer_context.py +++ b/lmcache/v1/multiprocess/transfer_context.py @@ -4,6 +4,7 @@ # Standard from abc import ABC, abstractmethod from typing import Any, Callable, Protocol +import pickle # Third Party import torch @@ -21,7 +22,8 @@ gather_paged_kv_to_cpu, scatter_cpu_to_paged_kv, ) -from lmcache.v1.multiprocess.mq import MessageQueueClient, MessagingFuture +from lmcache.v1.multiprocess.futures import MessagingFuture +from lmcache.v1.multiprocess.mq import MessageQueueClient from lmcache.v1.multiprocess.protocol import RequestType logger = init_logger(__name__) @@ -34,17 +36,16 @@ def ipc_handle(self) -> object: """Return an IPC handle consumable by the multiprocess server.""" -SendRequest = Callable[ - [MessageQueueClient, RequestType, list[object]], MessagingFuture[object] -] +SendRequest = Callable[[MessageQueueClient, RequestType, list[object]], MessagingFuture] class TransferContext(ABC): """Abstract transport layer for worker-side KV transfer. Concrete implementations encapsulate how worker-side store/retrieve - operations are transmitted to the multiprocess server (for example, - CUDA IPC futures or CPU-context gather/scatter flows). + operations are transmitted to the multiprocess server. CUDA paths return + CUDA-aware futures backed by MQ requests, while CPU paths may perform + gather/scatter synchronously and return already-resolved futures. """ @abstractmethod @@ -62,14 +63,19 @@ def register( """Register KV caches with the server and wait for ACK. Args: - instance_id: Worker process instance id. + instance_id: Worker process instance identifier. kv_caches: Worker KV cache tensors keyed by layer name. model_name: Model name used by cache keys. world_size: KV world size. - blocks_in_chunk: Number of vLLM blocks in one LMCache chunk. + blocks_in_chunk: Number of vLLM blocks per LMCache chunk. mq_client: Message queue client used to communicate with server. mq_timeout: Timeout in seconds for synchronous request wait. send_request: Request sender callable used to issue MQ requests. + + Raises: + TimeoutError: If server registration does not complete before + ``mq_timeout``. + RuntimeError: If a concrete context cannot initialize. """ @abstractmethod @@ -82,17 +88,23 @@ def submit_store( block_ids: list[int], event: IPCEvent, blocks_in_chunk: int, - ) -> None: - """Submit a store request. + ) -> MessagingFuture: + """Submit a store request and return a completion future. Args: - request_id: Request identifier. - key: LMCache key object. - instance_id: Worker process instance id. + request_id: External request identifier. + key: LMCache key object for the store range. + instance_id: Worker process instance identifier. kv_caches: Worker KV cache tensors keyed by layer name. - block_ids: vLLM block ids to store. + block_ids: vLLM block IDs to store. event: Synchronization event object. - blocks_in_chunk: Number of vLLM blocks in one LMCache chunk. + blocks_in_chunk: Number of vLLM blocks per LMCache chunk. + + Returns: + A future compatible with adapter-side ``query()``/``result()`` flow. + + Raises: + RuntimeError: If register() was not called first. """ @abstractmethod @@ -106,34 +118,24 @@ def submit_retrieve( event: IPCEvent, blocks_in_chunk: int, skip_first_n_tokens: int = 0, - ) -> None: - """Submit a retrieve request. + ) -> MessagingFuture: + """Submit a retrieve request and return a completion future. Args: - request_id: Request identifier. - key: LMCache key object. - instance_id: Worker process instance id. + request_id: External request identifier. + key: LMCache key object for the retrieve range. + instance_id: Worker process instance identifier. kv_caches: Worker KV cache tensors keyed by layer name. - block_ids: vLLM block ids to retrieve. + block_ids: vLLM block IDs to retrieve into. event: Synchronization event object. - blocks_in_chunk: Number of vLLM blocks in one LMCache chunk. - skip_first_n_tokens: Number of tokens to skip for partial scatter. - """ - - @abstractmethod - def poll_finished(self) -> tuple[set[str], set[str], set[int]]: - """Poll completed requests. + blocks_in_chunk: Number of vLLM blocks per LMCache chunk. + skip_first_n_tokens: Number of initial tokens to skip when writing. Returns: - Tuple of ``(finished_store_ids, finished_retrieve_ids, error_block_ids)``. - """ + A future compatible with adapter-side ``query()``/``result()`` flow. - @abstractmethod - def drain_all(self) -> tuple[set[str], set[str], set[int]]: - """Drain all pending requests. - - Returns: - Tuple of ``(finished_store_ids, finished_retrieve_ids, error_block_ids)``. + Raises: + RuntimeError: If register() was not called first. """ @abstractmethod @@ -145,10 +147,7 @@ class CudaTransferContext(TransferContext): """CUDA IPC + MQ future transport context.""" def __init__(self) -> None: - self._store_futures: dict[str, Any] = {} - self._retrieve_futures: dict[str, tuple[Any, list[int]]] = {} self._mq_client: MessageQueueClient | None = None - self._mq_timeout: float = 0.0 self._send_request: SendRequest | None = None def register( @@ -167,7 +166,6 @@ def register( from lmcache.integration.vllm.vllm_multi_process_adapter import wrap_kv_caches self._mq_client = mq_client - self._mq_timeout = mq_timeout self._send_request = send_request layout_hints = vllm_layout_hints() future = send_request( @@ -186,33 +184,28 @@ def register( def submit_store( self, - request_id: str, + _request_id: str, key: Any, instance_id: int, _kv_caches: dict[str, torch.Tensor], block_ids: list[int], event: IPCEvent, _blocks_in_chunk: int, - ) -> None: - if ( - self._mq_client is None - or self._send_request is None - or self._mq_timeout < 0 - ): + ) -> MessagingFuture: + if self._mq_client is None or self._send_request is None: raise RuntimeError( "CUDA transfer context is not registered. " "Call register() before submit_store()." ) - future = self._send_request( + return self._send_request( self._mq_client, RequestType.STORE, [key, instance_id, block_ids, event.ipc_handle()], ).to_cuda_future() - self._store_futures[request_id] = future def submit_retrieve( self, - request_id: str, + _request_id: str, key: Any, instance_id: int, _kv_caches: dict[str, torch.Tensor], @@ -220,73 +213,20 @@ def submit_retrieve( event: IPCEvent, _blocks_in_chunk: int, skip_first_n_tokens: int = 0, - ) -> None: - if ( - self._mq_client is None - or self._send_request is None - or self._mq_timeout < 0 - ): + ) -> MessagingFuture: + if self._mq_client is None or self._send_request is None: raise RuntimeError( "CUDA transfer context is not registered. " "Call register() before submit_retrieve()." ) - future = self._send_request( + return self._send_request( self._mq_client, RequestType.RETRIEVE, [key, instance_id, block_ids, event.ipc_handle(), skip_first_n_tokens], ).to_cuda_future() - self._retrieve_futures[request_id] = (future, list(block_ids)) - - def poll_finished(self) -> tuple[set[str], set[str], set[int]]: - finished_stores: set[str] = set() - finished_retrieves: set[str] = set() - error_block_ids: set[int] = set() - - for request_id, s_future in list(self._store_futures.items()): - if not s_future.query(): - continue - s_result = s_future.result() - finished_stores.add(request_id) - if not s_result: - logger.error( - "Something went wrong when processing the store request " - "for request_id=%s", - request_id, - ) - self._store_futures.pop(request_id, None) - - for request_id, (r_future, r_block_ids) in list(self._retrieve_futures.items()): - if not r_future.query(): - continue - r_result = r_future.result() - finished_retrieves.add(request_id) - if not r_result: - logger.error( - "Something went wrong when processing the retrieve request " - "for request_id=%s, result=%s", - request_id, - r_result, - ) - error_block_ids.update(r_block_ids) - self._retrieve_futures.pop(request_id, None) - - return finished_stores, finished_retrieves, error_block_ids - - def drain_all(self) -> tuple[set[str], set[str], set[int]]: - finished_stores = set(self._store_futures.keys()) - finished_retrieves = set(self._retrieve_futures.keys()) - error_block_ids: set[int] = set() - for _request_id, (_r_future, block_ids) in self._retrieve_futures.items(): - error_block_ids.update(block_ids) - self._store_futures.clear() - self._retrieve_futures.clear() - return finished_stores, finished_retrieves, error_block_ids def close(self) -> None: - self._store_futures.clear() - self._retrieve_futures.clear() self._mq_client = None - self._mq_timeout = 0.0 self._send_request = None @@ -297,11 +237,6 @@ def __init__(self) -> None: self._cpu_context: CPUContext | None = None self._layout_hints: Any = None self._gpu_kv_format: Any = None - self._store_done: dict[str, bool] = {} - self._retrieve_done: dict[str, tuple[bool, list[int]]] = {} - self._mq_client: MessageQueueClient | None = None - self._mq_timeout: float = 0.0 - self._send_request: SendRequest | None = None def register( self, @@ -317,9 +252,6 @@ def register( # First Party from lmcache.integration.vllm.utils import vllm_layout_hints - self._mq_client = mq_client - self._mq_timeout = mq_timeout - self._send_request = send_request layout_hints = vllm_layout_hints() ( block_size, @@ -331,6 +263,17 @@ def register( self._layout_hints = layout_hints self._gpu_kv_format = gpu_kv_format + use_mla_flag = is_mla(gpu_kv_format) + shape = ( + torch.Size([num_layers, blocks_in_chunk * block_size, hidden_dim_size]) + if use_mla_flag + else torch.Size( + [2, num_layers, blocks_in_chunk * block_size, hidden_dim_size] + ) + ) + dtype = getattr(torch, dtype_str) + layout_desc = MemoryLayoutDesc(shapes=[shape], dtypes=[dtype]) + future = send_request( mq_client, RequestType.REGISTER_KV_CACHE_CPU_CONTEXT, @@ -338,27 +281,14 @@ def register( instance_id, model_name, world_size, - EngineType.VLLM, - layout_hints, + pickle.dumps(layout_desc), block_size, - num_layers, - hidden_dim_size, - dtype_str, - is_mla(gpu_kv_format), + use_mla_flag, ], ) - use_mla_flag = is_mla(gpu_kv_format) - shape = ( - torch.Size([num_layers, blocks_in_chunk * block_size, hidden_dim_size]) - if use_mla_flag - else torch.Size( - [2, num_layers, blocks_in_chunk * block_size, hidden_dim_size] - ) - ) - dtype = getattr(torch, dtype_str) metadata = CPUContextMetadata( - layout_desc=MemoryLayoutDesc(shapes=[shape], dtypes=[dtype]), + layout_desc=layout_desc, block_size=block_size, use_mla=use_mla_flag, ) @@ -367,19 +297,20 @@ def register( def submit_store( self, - request_id: str, + _request_id: str, key: Any, instance_id: int, kv_caches: dict[str, torch.Tensor], block_ids: list[int], _event: IPCEvent, blocks_in_chunk: int, - ) -> None: + ) -> MessagingFuture: if self._cpu_context is None: raise RuntimeError( "CPU transfer context is not registered. " "Call register() before submit_store()." ) + torch_dev.synchronize() cpu_chunks = gather_paged_kv_to_cpu( kv_caches, @@ -390,11 +321,14 @@ def submit_store( ) handle = self._cpu_context.prepare_store(key, instance_id, cpu_chunks) ok = self._cpu_context.commit_store(handle) - self._store_done[request_id] = ok + + future: MessagingFuture[bool] = MessagingFuture() + future.set_result(ok) + return future def submit_retrieve( self, - request_id: str, + _request_id: str, key: Any, instance_id: int, kv_caches: dict[str, torch.Tensor], @@ -402,12 +336,13 @@ def submit_retrieve( _event: IPCEvent, blocks_in_chunk: int, skip_first_n_tokens: int = 0, - ) -> None: + ) -> MessagingFuture: if self._cpu_context is None: raise RuntimeError( "CPU transfer context is not registered. " "Call register() before submit_retrieve()." ) + handle, chunks = self._cpu_context.prepare_retrieve(key, instance_id) ok = chunks is not None if chunks is not None: @@ -425,31 +360,15 @@ def submit_retrieve( logger.exception("Failed to scatter retrieved CPU context chunks") ok = False self._cpu_context.commit_retrieve(handle) - self._retrieve_done[request_id] = (ok, list(block_ids)) - - def poll_finished(self) -> tuple[set[str], set[str], set[int]]: - finished_stores = set(self._store_done.keys()) - finished_retrieves = set(self._retrieve_done.keys()) - error_block_ids: set[int] = set() - for ok, block_ids in self._retrieve_done.values(): - if not ok: - error_block_ids.update(block_ids) - self._store_done.clear() - self._retrieve_done.clear() - return finished_stores, finished_retrieves, error_block_ids - - def drain_all(self) -> tuple[set[str], set[str], set[int]]: - return self.poll_finished() + + future: MessagingFuture[bool] = MessagingFuture() + future.set_result(ok) + return future def close(self) -> None: if self._cpu_context is not None: self._cpu_context.close() self._cpu_context = None - self._store_done.clear() - self._retrieve_done.clear() - self._mq_client = None - self._mq_timeout = 0.0 - self._send_request = None def create_transfer_context( diff --git a/tests/v1/multiprocess/test_cpu_context.py b/tests/v1/multiprocess/test_cpu_context.py index 90b7d9d9b65..ff38859e3c5 100644 --- a/tests/v1/multiprocess/test_cpu_context.py +++ b/tests/v1/multiprocess/test_cpu_context.py @@ -10,6 +10,9 @@ import pytest import torch +# First Party +from lmcache.v1.distributed.api import MemoryLayoutDesc + def _make_kv_caches( num_layers: int = 2, @@ -83,12 +86,20 @@ def _make_hnd_flashinfer_kv_caches( return kv_caches -def test_wrap_kv_caches_cpu_context_returns_empty() -> None: - """Verify wrap_kv_caches returns no IPC wrappers in cpu context mode.""" +def test_wrap_kv_caches_wraps_all_tensors(monkeypatch: Any) -> None: + """Verify wrap_kv_caches wraps all provided KV tensors.""" # First Party - from lmcache.integration.vllm.vllm_multi_process_adapter import wrap_kv_caches + from lmcache.integration.vllm import vllm_multi_process_adapter as adapter_mod + + kv_caches = _make_kv_caches() + monkeypatch.setattr( + adapter_mod, + "CudaIPCWrapper", + lambda tensor: ("wrapped", tensor), + ) - assert wrap_kv_caches(_make_kv_caches(), use_cpu_context=True) == [] + wrapped = adapter_mod.wrap_kv_caches(kv_caches) + assert len(wrapped) == len(kv_caches) def test_compute_kv_layout_and_gather_scatter_roundtrip() -> None: @@ -345,16 +356,16 @@ def test_server_register_and_find_cpu_context_layout( patch("lmcache.v1.multiprocess.server.get_event_bus"), ): engine = MPCacheEngine(storage_manager_config=MagicMock(), chunk_size=16) + expected_layout_desc = MemoryLayoutDesc( + shapes=[torch.Size([2, 2, 16, 16])], + dtypes=[torch.float32], + ) engine.register_kv_cache_cpu_context( instance_id=1, model_name="m", world_size=1, - engine_type=MagicMock(), - layout_hints={}, + layout_desc_bytes=pickle.dumps(expected_layout_desc), block_size=4, - num_layers=2, - hidden_dim_size=16, - dtype_str="float32", use_mla=False, ) @@ -398,16 +409,16 @@ def _read_prefetched_results(_keys: Any) -> Any: session_cls.return_value.get_or_create.return_value = mock_session engine = MPCacheEngine(storage_manager_config=MagicMock(), chunk_size=8) + layout_desc = MemoryLayoutDesc( + shapes=[torch.Size([2, 2, 8, 16])], + dtypes=[torch.float32], + ) engine.register_kv_cache_cpu_context( instance_id=2, model_name="m", world_size=1, - engine_type=MagicMock(), - layout_hints={}, + layout_desc_bytes=pickle.dumps(layout_desc), block_size=4, - num_layers=2, - hidden_dim_size=16, - dtype_str="float32", use_mla=False, ) payload = torch.ones(2, 2, 8, 16) diff --git a/tests/v1/test_vllm_mp_adapter.py b/tests/v1/test_vllm_mp_adapter.py index 84edb347bcb..ef7d7c0b93b 100644 --- a/tests/v1/test_vllm_mp_adapter.py +++ b/tests/v1/test_vllm_mp_adapter.py @@ -120,16 +120,20 @@ def test_register_kv_caches_cpu_submits_cpu_context_registration( assert send_mock.call_count == 1 args, _kwargs = send_mock.call_args assert args[1] == RequestType.REGISTER_KV_CACHE_CPU_CONTEXT + assert len(args[2]) == 6 + assert isinstance(args[2][3], bytes) -def test_submit_store_request_passes_no_transport_kwargs(fake_adapter, monkeypatch): - """submit_store_request should not pass mq/send kwargs after registration.""" +def test_submit_store_request_tracks_returned_future(fake_adapter, monkeypatch): + """submit_store_request stores the returned future in store_futures.""" adapter, _send_mock, _ = fake_adapter monkeypatch.setattr(adapter, "_ensure_heartbeat_started", lambda: None) fake_tensor = MagicMock() fake_tensor.device.type = "cuda" adapter.kv_caches = {"layer.0": fake_tensor} transfer_ctx = MagicMock() + fake_future = MagicMock() + transfer_ctx.submit_store.return_value = fake_future adapter.transfer_ctx = transfer_ctx op = LoadStoreOp(token_ids=[1, 2, 3, 4], block_ids=[0], start=0, end=4) @@ -137,16 +141,19 @@ def test_submit_store_request_passes_no_transport_kwargs(fake_adapter, monkeypat assert transfer_ctx.submit_store.called assert transfer_ctx.submit_store.call_args.kwargs == {} + assert adapter.store_futures["req-1"] is fake_future -def test_submit_retrieve_request_passes_no_transport_kwargs(fake_adapter, monkeypatch): - """submit_retrieve_request should not pass mq/send kwargs after registration.""" +def test_submit_retrieve_request_tracks_returned_future(fake_adapter, monkeypatch): + """submit_retrieve_request stores returned future and block IDs.""" adapter, _send_mock, _ = fake_adapter monkeypatch.setattr(adapter, "_ensure_heartbeat_started", lambda: None) fake_tensor = MagicMock() fake_tensor.device.type = "cuda" adapter.kv_caches = {"layer.0": fake_tensor} transfer_ctx = MagicMock() + fake_future = MagicMock() + transfer_ctx.submit_retrieve.return_value = fake_future adapter.transfer_ctx = transfer_ctx op = LoadStoreOp( token_ids=[1, 2, 3, 4], @@ -160,3 +167,4 @@ def test_submit_retrieve_request_passes_no_transport_kwargs(fake_adapter, monkey assert transfer_ctx.submit_retrieve.called assert transfer_ctx.submit_retrieve.call_args.kwargs == {"skip_first_n_tokens": 1} + assert adapter.retrieve_futures["req-1"] == (fake_future, [0]) From 1b5fa31bac99b4e23dfc06a218c873406e26baa4 Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Wed, 13 May 2026 07:19:31 +0000 Subject: [PATCH 09/23] Revert CPU registration payload from pickle bytes to scalar fields Signed-off-by: Tony Lin --- lmcache/v1/multiprocess/protocols/engine.py | 4 +++- lmcache/v1/multiprocess/server.py | 26 ++++++++++++++++++--- lmcache/v1/multiprocess/transfer_context.py | 5 ++-- tests/v1/multiprocess/test_cpu_context.py | 19 +++++---------- tests/v1/test_vllm_mp_adapter.py | 3 +-- 5 files changed, 36 insertions(+), 21 deletions(-) diff --git a/lmcache/v1/multiprocess/protocols/engine.py b/lmcache/v1/multiprocess/protocols/engine.py index 41d96200eba..3fdcdd9edbe 100644 --- a/lmcache/v1/multiprocess/protocols/engine.py +++ b/lmcache/v1/multiprocess/protocols/engine.py @@ -154,8 +154,10 @@ def get_protocol_definitions() -> dict[str, ProtocolDefinition]: int, str, int, - bytes, int, + int, + int, + str, bool, ], response_class=None, diff --git a/lmcache/v1/multiprocess/server.py b/lmcache/v1/multiprocess/server.py index 61431750979..ee00604ddb2 100644 --- a/lmcache/v1/multiprocess/server.py +++ b/lmcache/v1/multiprocess/server.py @@ -290,8 +290,10 @@ def register_kv_cache_cpu_context( instance_id: int, model_name: str, world_size: int, - layout_desc_bytes: bytes, block_size: int, + num_layers: int, + hidden_dim_size: int, + dtype_str: str, use_mla: bool, ) -> None: """Register non-CUDA KV layout metadata for CPU context mode. @@ -300,11 +302,29 @@ def register_kv_cache_cpu_context( instance_id: Worker instance identifier (typically PID). model_name: Model name associated with this worker. world_size: Worker world size used in cache keys. - layout_desc_bytes: Pickled :class:`MemoryLayoutDesc`. block_size: Tokens per paged block. + num_layers: Number of model layers. + hidden_dim_size: Flattened hidden dimension per token. + dtype_str: Torch dtype name (for example ``"float16"``). use_mla: Whether the worker KV format is MLA. + + Raises: + ValueError: If ``dtype_str`` is not a valid torch dtype name. """ - layout_desc = pickle.loads(layout_desc_bytes) + dtype = getattr(torch, dtype_str, None) + if dtype is None or not isinstance(dtype, torch.dtype): + raise ValueError( + f"Invalid dtype_str '{dtype_str}': must be a valid torch dtype " + "attribute name (e.g. 'float16' for torch.float16, " + "'bfloat16' for torch.bfloat16, 'float32' for torch.float32)." + ) + + shape = ( + torch.Size([num_layers, self.chunk_size, hidden_dim_size]) + if use_mla + else torch.Size([2, num_layers, self.chunk_size, hidden_dim_size]) + ) + layout_desc = MemoryLayoutDesc(shapes=[shape], dtypes=[dtype]) self.cpu_contexts[instance_id] = CPUContextMetadata( layout_desc=layout_desc, block_size=block_size, diff --git a/lmcache/v1/multiprocess/transfer_context.py b/lmcache/v1/multiprocess/transfer_context.py index 94f42e338f0..89047a3983d 100644 --- a/lmcache/v1/multiprocess/transfer_context.py +++ b/lmcache/v1/multiprocess/transfer_context.py @@ -4,7 +4,6 @@ # Standard from abc import ABC, abstractmethod from typing import Any, Callable, Protocol -import pickle # Third Party import torch @@ -281,8 +280,10 @@ def register( instance_id, model_name, world_size, - pickle.dumps(layout_desc), block_size, + num_layers, + hidden_dim_size, + dtype_str, use_mla_flag, ], ) diff --git a/tests/v1/multiprocess/test_cpu_context.py b/tests/v1/multiprocess/test_cpu_context.py index ff38859e3c5..5a6a4c0ddc4 100644 --- a/tests/v1/multiprocess/test_cpu_context.py +++ b/tests/v1/multiprocess/test_cpu_context.py @@ -10,9 +10,6 @@ import pytest import torch -# First Party -from lmcache.v1.distributed.api import MemoryLayoutDesc - def _make_kv_caches( num_layers: int = 2, @@ -356,16 +353,14 @@ def test_server_register_and_find_cpu_context_layout( patch("lmcache.v1.multiprocess.server.get_event_bus"), ): engine = MPCacheEngine(storage_manager_config=MagicMock(), chunk_size=16) - expected_layout_desc = MemoryLayoutDesc( - shapes=[torch.Size([2, 2, 16, 16])], - dtypes=[torch.float32], - ) engine.register_kv_cache_cpu_context( instance_id=1, model_name="m", world_size=1, - layout_desc_bytes=pickle.dumps(expected_layout_desc), block_size=4, + num_layers=2, + hidden_dim_size=16, + dtype_str="float32", use_mla=False, ) @@ -409,16 +404,14 @@ def _read_prefetched_results(_keys: Any) -> Any: session_cls.return_value.get_or_create.return_value = mock_session engine = MPCacheEngine(storage_manager_config=MagicMock(), chunk_size=8) - layout_desc = MemoryLayoutDesc( - shapes=[torch.Size([2, 2, 8, 16])], - dtypes=[torch.float32], - ) engine.register_kv_cache_cpu_context( instance_id=2, model_name="m", world_size=1, - layout_desc_bytes=pickle.dumps(layout_desc), block_size=4, + num_layers=2, + hidden_dim_size=16, + dtype_str="float32", use_mla=False, ) payload = torch.ones(2, 2, 8, 16) diff --git a/tests/v1/test_vllm_mp_adapter.py b/tests/v1/test_vllm_mp_adapter.py index ef7d7c0b93b..85d31f1d69b 100644 --- a/tests/v1/test_vllm_mp_adapter.py +++ b/tests/v1/test_vllm_mp_adapter.py @@ -120,8 +120,7 @@ def test_register_kv_caches_cpu_submits_cpu_context_registration( assert send_mock.call_count == 1 args, _kwargs = send_mock.call_args assert args[1] == RequestType.REGISTER_KV_CACHE_CPU_CONTEXT - assert len(args[2]) == 6 - assert isinstance(args[2][3], bytes) + assert len(args[2]) == 8 def test_submit_store_request_tracks_returned_future(fake_adapter, monkeypatch): From 8008ed3a40600bf62a1355cd53fa4fafc54c1482 Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Thu, 14 May 2026 00:51:42 +0000 Subject: [PATCH 10/23] update design doc Signed-off-by: Tony Lin --- .../v1/multiprocess/cpu_context_design.md | 379 ++++++++---------- 1 file changed, 170 insertions(+), 209 deletions(-) diff --git a/docs/design/v1/multiprocess/cpu_context_design.md b/docs/design/v1/multiprocess/cpu_context_design.md index b74d27185b2..c2345e40958 100644 --- a/docs/design/v1/multiprocess/cpu_context_design.md +++ b/docs/design/v1/multiprocess/cpu_context_design.md @@ -1,202 +1,83 @@ # CPU Context Design (MP mode, non-CUDA) -## Scope +## 1. Motivation -This document describes the non-CUDA CPU-based KV transfer path for LMCache -multiprocess mode. +LMCache multiprocess mode relies on **CUDA IPC** to transfer KV cache data +between vLLM worker processes and the LMCache cache server. The existing +path wraps GPU tensors in `CudaIPCWrapper`, exchanges IPC handles via ZMQ +messages, and uses CUDA events for cross-process synchronisation. -The goal is to support KV transfer for non-CUDA devices (for example CPU, -XPU, HPU) without changing the existing CUDA IPC path, while providing a -clean abstraction layer that makes it easy to add alternative transport -mechanisms (e.g. shared memory) in a future PR. +This design is fundamentally tied to the CUDA programming model: -## Why this path exists - -The CUDA path uses IPC wrappers around GPU tensors and the existing -`REGISTER_KV_CACHE` / `STORE` / `RETRIEVE` request flow. - -For non-CUDA tensors, CUDA IPC is not available. The CPU context path -provides a generic protocol where workers: - -1. Gather KV blocks into CPU chunk(or memory obj) tensors. -2. Transport those CPU chunks to the server storage through a concrete - `CPUContext` implementation. -3. Retrieve CPU chunks(or memory obj) from the server and scatter them back into device KV - tensors. - -## Protocol additions - -Three request types are used for non-CUDA mode (unchanged from the original -cpu context design): - -- `REGISTER_KV_CACHE_CPU_CONTEXT` -- `STORE_CPU_CHUNKS` -- `RETRIEVE_CPU_CHUNKS` - -These are registered in the MP server dispatch and have corresponding -payload/response contracts in the multiprocess protocol definitions. - -## File structure - -``` -lmcache/v1/multiprocess/ -├── cpu_context.py # CPUContextMetadata, CPUContext(ABC), factory, gather/scatter utils -└── cpu_context_pickle.py # CPUContextPickle — pickle-based concrete implementation -``` - -### `cpu_context.py` - -Provides: - -- **`CPUContextMetadata`** dataclass — layout metadata: - - ```python - @dataclass - class CPUContextMetadata: - layout_desc: MemoryLayoutDesc - block_size: int - use_mla: bool - ``` - -- **`CPUContext(ABC)`** — abstract base class with `mq_client` as a common - dependency. All concrete implementations share the same two-phase - `prepare/commit` interface: - - ```python - class CPUContext(ABC): - def __init__(self, metadata: CPUContextMetadata, mq_client, mq_timeout: float): ... - - @abstractmethod - def prepare_store(self, key, instance_id, chunks: list[torch.Tensor]) -> Any: ... - @abstractmethod - def commit_store(self, handle: Any) -> bool: ... - @abstractmethod - def prepare_retrieve(self, key, instance_id) -> tuple[Any, list[torch.Tensor] | None]: ... - @abstractmethod - def commit_retrieve(self, handle: Any) -> None: ... - @abstractmethod - def close(self) -> None: ... - ``` - -- **`create_cpu_context()`** factory — currently always returns a - `CPUContextPickle` instance; a future SHM-capable PR can extend this to - probe for shared-memory availability and fall back to pickle. - -- **Shared utility functions** used by all concrete implementations: - - `compute_kv_layout` — extract block size, layer count, hidden dim and - dtype from live KV tensors. - - `gather_paged_kv_to_cpu` — gather paged KV blocks into a list of CPU - tensors (one per LMCache chunk). - - `scatter_cpu_to_paged_kv` — scatter CPU chunk tensors back into paged - KV tensors. - -### `cpu_context_pickle.py` - -Provides **`CPUContextPickle(CPUContext)`**: - -| Phase | What happens | +| CUDA IPC dependency | Why it blocks non-CUDA devices | |---|---| -| `prepare_store` | `pickle.dumps(chunks)` → returns `(key, instance_id, bytes)` as opaque handle | -| `commit_store` | sends `STORE_CPU_CHUNKS` via `mq_client`, blocks for server ack, returns `bool` | -| `prepare_retrieve` | sends `RETRIEVE_CPU_CHUNKS` via `mq_client`, blocks for response, `pickle.loads` → returns `(None, chunks)` or `(None, None)` on miss | -| `commit_retrieve` | no-op (pickle path holds no server-side locks) | -| `close` | no-op | - -## Tensor/chunk contracts +| `CudaIPCWrapper` / `cudaIpcGetMemHandle` | Only works on NVIDIA CUDA tensors | +| `torch.cuda.Event(interprocess=True)` | CUDA-specific IPC event API | +| `cupy.cuda.ExternalStream` | CUDA stream wrapper | +| GPU pointer arithmetic in C++ kernels | Assumes CUDA device pointers | -Chunk formats are unchanged: +For non-CUDA accelerators — **CPU, Intel XPU, Habana HPU**, or any future +device — none of these primitives are available. -- non-MLA: `[2, num_layers, chunk_tokens, hidden_dim]` -- MLA: `[num_layers, chunk_tokens, hidden_dim]` +The **CPU context** path introduces a device-agnostic KV transfer mechanism: -Internal gather/scatter uses block-level indexing to avoid token-level slot -expansion and token-wise select/copy operations. +1. Workers **gather** paged KV blocks into contiguous CPU chunk tensors. +2. CPU chunks are **transported** to the server through a pluggable + serialisation layer (pickle today, shared memory in the future). +3. On retrieve, the server returns CPU chunks and workers **scatter** them + back into device-local paged KV tensors. -## Layout handling +The existing CUDA IPC path is **untouched** — the two paths coexist behind a +polymorphic `TransferContext` abstraction. -Supported KV formats in CPU gather/scatter: +### Transport comparison -- `NL_X_TWO_NB_BS_NH_HS` (NHD) -- `NL_X_NB_TWO_BS_NH_HS` (NHD flashinfer) -- `NL_X_TWO_NB_NH_BS_HS` (HND) -- `NL_X_NB_TWO_NH_BS_HS` (HND flashinfer) -- `NL_X_NB_BS_HS` (MLA) +**Store (worker → server storage):** -## Worker adapter integration +| Transport | Copies | Data flow | +|---|---|---| +| CUDA IPC | 2 | GPU KV → GPU staging buffer → CPU memory obj | +| Pickle | 4 | GPU KV → CPU chunk → pickle.dumps → pickle.loads → CPU memory obj | +| SHM (TODO) | 1 | GPU KV → CPU memory obj (SHM mapped) | -The adapter now delegates transport behavior to -`lmcache/v1/multiprocess/transfer_context.py`. +**Retrieve (server storage → worker):** -`create_transfer_context(kv_caches)` centralizes device dispatch: +| Transport | Copies | Data flow | +|---|---|---| +| CUDA IPC | 2 | CPU memory obj → GPU staging buffer → GPU KV | +| Pickle | 4 | CPU memory obj → pickle.dumps → pickle.loads → CPU chunk → GPU KV | +| SHM (TODO) | 1 | CPU memory obj (SHM mapped) → GPU KV | -- all CUDA → existing CUDA IPC registration and store/retrieve path -- all non-CUDA → `CPUTransferContext` with cpu context registration and CPU context store/retrieve path +**Applicability:** -`LMCacheMPSchedulerAdapter` now holds `self.transfer_ctx: TransferContext | None` -and calls: +| Transport | Platform requirement | Pros | Cons | +|---|---|---|---| +| CUDA IPC | NVIDIA CUDA devices only | Async GPU streams, mature path | CUDA-only | +| Pickle | Any device, no dependencies | Generally available, zero setup | 4 copies + serialisation overhead | +| SHM (TODO) | `/dev/shm` capacity ≥ L1 cache size | Fewest copies (1), no serialisation | Requires sufficient shared memory | -- `transfer_ctx.register(...)` -- `transfer_ctx.submit_store(...)` -- `transfer_ctx.submit_retrieve(...)` -- `transfer_ctx.poll_finished()` (healthy) or `transfer_ctx.drain_all()` (unhealthy) +## 2. Architecture Overview -### Store path (non-CUDA) +### 2.1 Layered architecture -```python -# CPUTransferContext.submit_store -cpu_chunks = gather_paged_kv_to_cpu(kv_caches, block_ids, blocks_in_chunk, ...) -handle = self._cpu_context.prepare_store(key, instance_id, cpu_chunks) -ok = self._cpu_context.commit_store(handle) # synchronous; blocks for server ack -self._store_done[request_id] = ok ``` - -`CPUTransferContext.poll_finished()` drains `_store_done` on each call. - -### Retrieve path (non-CUDA) - -```python -# CPUTransferContext.submit_retrieve -handle, chunks = self._cpu_context.prepare_retrieve(key, instance_id) # synchronous -ok = chunks is not None -if chunks is not None: - try: - scatter_cpu_to_paged_kv(kv_caches, block_ids, chunks, blocks_in_chunk, - skip_first_n_tokens=skip_first_n_tokens, ...) - except (RuntimeError, ValueError, TypeError, IndexError): - ok = False -self._cpu_context.commit_retrieve(handle) -self._retrieve_done[request_id] = (ok, block_ids) +vllm_multi_process_adapter.py ← Engine adapter, device-agnostic + └── TransferContext ← Worker-side transport abstraction (§3) + ├── CudaTransferContext ← CUDA IPC + MQ future path + └── CPUTransferContext ← Synchronous gather/scatter path + └── CPUContext ← Serialisation abstraction (§4.2) + ├── CPUContextPickle ← pickle.dumps/loads (§4.3) + └── CPUContextShm ← shared memory (§4.4, TODO) ``` -`CPUTransferContext.poll_finished()` drains `_retrieve_done` on each call. -The adapter passes `op.skip_first_n_tokens` into -`transfer_ctx.submit_retrieve(..., skip_first_n_tokens=...)`. - -The retrieve is **synchronous inside `CPUTransferContext.submit_retrieve`**; -`poll_finished()` just drains request ids recorded by submit methods. - -## Server integration - -`MPCacheEngine` holds: +Two layers of abstraction serve different purposes: -- `cpu_contexts: dict[int, CPUContextMetadata]` — per-instance metadata. -- `cpu_context_meta: dict[int, tuple[str, int]]` — per-instance - `(model_name, world_size)` for layout resolution. +- **TransferContext** (§3) — decides **CUDA vs non-CUDA** routing at the + worker adapter level. +- **CPUContext** (§4.2) — decides **how** CPU chunk data is serialised and + transported (pickle vs SHM). Only used inside `CPUTransferContext`. -Server-side handler methods are unchanged: -- `register_kv_cache_cpu_context` — stores `CPUContextMetadata` in `cpu_contexts`. -- `store_cpu_chunks` — unpickles payload, copies tensors into storage. -- `retrieve_cpu_chunks` — reads from storage, pickles tensors, returns bytes. - -Additional integration points: - -- Unregister cleanup removes both `cpu_contexts` and `cpu_context_meta`. -- Layout lookup via `_find_layout_desc` resolves both GPU and CPU context - registrations. -- Status reporting (`report_status`) includes `registered_cpu_instance_ids` - and `cpu_context_meta`. - -## CUDA vs non-CUDA state machine +### 2.2 State machine (worker ↔ server) ```text register_kv_caches() @@ -210,8 +91,9 @@ Additional integration points: [device == cuda] [device != cuda] | | v v - CudaTransferContext.register() CPUTransferContext.register() - REGISTER_KV_CACHE (CUDA IPC) REGISTER_KV_CACHE_CPU_CONTEXT (CPU metadata) + CudaTransferContext.register() CPUTransferContext.register() + → REGISTER_KV_CACHE → REGISTER_KV_CACHE_CPU_CONTEXT + (CUDA IPC handles) (scalar metadata fields) | + create_cpu_context() +----------------+----------------+ | @@ -221,31 +103,30 @@ Additional integration points: +----------------+----------------+ | | v v - transfer_ctx.submit_store() transfer_ctx.submit_store() + transfer_ctx.submit_store() transfer_ctx.submit_store() | | v v - STORE (GPU -> L1) gather_paged_kv_to_cpu() - | + _cpu_context.prepare_store() - v + _cpu_context.commit_store() [sync] - [READY] _store_done[id] = ok - | | + STORE (GPU → L1) gather_paged_kv_to_cpu() + [async MQ future] + _cpu_context.prepare_store() + | + _cpu_context.commit_store() [sync] + v _store_done[id] = ok + [READY] | +----------------+----------------+ | v - transfer_ctx.submit_retrieve() + get_finished() + transfer_ctx.submit_retrieve() + poll_finished() | +----------------+----------------+ | | v v - RETRIEVE (L1 -> GPU) _cpu_context.prepare_retrieve() [sync] - [async future] + scatter_cpu_to_paged_kv() - + _cpu_context.commit_retrieve() - _retrieve_done[id] = (ok, block_ids) - | | + RETRIEVE (L1 → GPU) _cpu_context.prepare_retrieve() [sync] + [async MQ future] + scatter_cpu_to_paged_kv() + | + _cpu_context.commit_retrieve() + v _retrieve_done[id] = (ok, block_ids) +----------------+----------------+ | v - [READY / SERVING] + [READY / SERVING] | v unregister_kv_cache() @@ -254,36 +135,116 @@ Additional integration points: [TERMINATED] ``` -## Future extension: CPUContextShm +## 3. Worker-side: TransferContext Abstraction -The `CPUContext` base class is designed to accommodate a shared-memory -implementation in a future PR with minimal changes: +### 3.1 Problem -| Phase | Pickle | SHM (future) | +Before this refactoring, `vllm_multi_process_adapter.py` contained +`if self._use_cpu_context:` branches in every method — `register_kv_caches`, +`submit_store_request`, `submit_retrieve_request`, `get_finished`, and the +unhealthy drain path. Adding a third transport would require touching every +branch. + +### 3.2 Solution + +`transfer_context.py` defines the `TransferContext` ABC with six methods: +`register`, `submit_store`, `submit_retrieve`, `poll_finished`, `drain_all`, +and `close`. The adapter holds a single `TransferContext` and delegates — +no `if/else` anywhere. + +### 3.3 `create_transfer_context()` factory + +Inspects device types of all KV cache tensors **exactly once**. CUDA → +`CudaTransferContext`; otherwise → `CPUTransferContext`. Mixed device types +are rejected. + +### 3.4 `CudaTransferContext` + +Wraps the original CUDA IPC path. Sends `REGISTER_KV_CACHE` / `STORE` / +`RETRIEVE` messages with IPC handles, tracks async MQ futures. +`poll_finished` queries futures; `drain_all` marks all pending as finished +for unhealthy shutdown. Semantics identical to pre-refactoring. + +### 3.5 `CPUTransferContext` + +Holds a `CPUContext` instance internally. Sends +`REGISTER_KV_CACHE_CPU_CONTEXT` with scalar metadata. Store and retrieve +are **synchronous**: gather → prepare/commit, then record result in +`_store_done` / `_retrieve_done`. `poll_finished` simply drains these dicts. + +## 4. Server-side: CPU Context Protocol + +### 4.1 Why GPU context and CPU context need different protocols + +| | GPU context | CPU context | |---|---|---| -| `prepare_store` | `pickle.dumps` | MQ `PREPARE_STORE` → slot metadata → memcpy | -| `commit_store` | MQ `STORE_CPU_CHUNKS` | MQ `COMMIT_STORE` | -| `prepare_retrieve` | MQ `RETRIEVE_CPU_CHUNKS` + `pickle.loads` | MQ `PREPARE_RETRIEVE` → tensor views from SHM | -| `commit_retrieve` | no-op | MQ `FINISH_READ` (release read lock) | +| Registration | `REGISTER_KV_CACHE` — IPC handles | `REGISTER_KV_CACHE_CPU_CONTEXT` — scalar fields | +| Store | `STORE` — event handle + block IDs, server reads GPU directly | `STORE_CPU_CHUNKS` — serialised CPU tensors | +| Retrieve | `RETRIEVE` — event handle + block IDs, server writes GPU directly | `RETRIEVE_CPU_CHUNKS` — key lookup, returns CPU tensors | + +Registration uses **scalar fields** (`block_size`, `num_layers`, +`hidden_dim_size`, `dtype_str`, `use_mla`) instead of pickled objects +to avoid cross-process pickle security and compatibility concerns. The +server reconstructs `MemoryLayoutDesc` from the scalars internally. + +### 4.2 `CPUContext` ABC: two-phase prepare/commit -The `create_cpu_context()` factory will probe for SHM availability and fall -back to pickle when SHM is unavailable. +The serialisation layer is abstracted behind `CPUContext` so that pickle +and SHM can be swapped without touching `CPUTransferContext` or the server. + +The ABC defines: `prepare_store`, `commit_store`, `prepare_retrieve`, +`commit_retrieve`, `close`. + +Why two phases? Pickle can do everything in one step (prepare serialises, +commit sends). SHM needs prepare to allocate a slot, then the worker writes +into mapped memory, then commit tells the server "ready". The split +accommodates both without forcing unnecessary round-trips on pickle. + +| Phase | Pickle | SHM (TODO) | +|---|---|---| +| `prepare_store` | `pickle.dumps(chunks)` → opaque handle | MQ `PREPARE_STORE` → get SHM offset → `memcpy` into SHM | +| `commit_store` | MQ `STORE_CPU_CHUNKS`, block for ack | MQ `COMMIT_STORE` → server reads from SHM | +| `prepare_retrieve` | MQ `RETRIEVE_CPU_CHUNKS` → `pickle.loads` | MQ `PREPARE_RETRIEVE` → server writes to SHM → map tensor views | +| `commit_retrieve` | no-op | MQ `FINISH_READ` → release SHM read lock | + +`create_cpu_context()` factory currently always returns `CPUContextPickle`. +Future: probe `/dev/shm` availability and capacity, fall back to pickle if +insufficient. + +## 5. Data Path: Gather / Scatter + +### 5.1 Chunk format + +- **Non-MLA**: `[2, num_layers, chunk_tokens, hidden_dim]` — dim 0 = `(K, V)`. +- **MLA**: `[num_layers, chunk_tokens, hidden_dim]` — single latent vector. + +Where `chunk_tokens = blocks_per_chunk × block_size`. + +### 5.2 Supported KV layouts + +| Format enum | Layout | Shape per layer | +|---|---|---| +| `NL_X_TWO_NB_BS_NH_HS` | NHD | `[2, NB, BS, NH, HS]` | +| `NL_X_NB_TWO_BS_NH_HS` | NHD (flashinfer) | `[NB, 2, BS, NH, HS]` | +| `NL_X_TWO_NB_NH_BS_HS` | HND | `[2, NB, NH, BS, HS]` | +| `NL_X_NB_TWO_NH_BS_HS` | HND (flashinfer) | `[NB, 2, NH, BS, HS]` | +| `NL_X_NB_BS_HS` | MLA | `[NB, BS, HS]` | -## Validation coverage +### 5.3 Block-level indexing -`tests/v1/multiprocess/test_cpu_context.py` covers: +Gather and scatter operate at **block granularity** (`tensor[block_ids]`) +rather than per-token `index_select` / `index_copy_`. For HND layouts, a +`permute(0, 2, 1, 3)` converts between head-major and token-major order. -- CPU wrapper behavior (`wrap_kv_caches` with cpu context mode) -- NHD and MLA gather/scatter round-trip -- HND round-trip for both HND formats -- `skip_first_n_tokens` behavior -- Server-side register/store/retrieve flow +### 5.4 Utility functions -`tests/v1/test_vllm_mp_adapter.py` covers transfer-context integration, -including CPU registration path (`REGISTER_KV_CACHE_CPU_CONTEXT`) and -store/retrieve submit delegation. +- **`compute_kv_layout`** — extracts `(block_size, num_layers, hidden_dim_size, dtype_str, gpu_kv_format)` from live KV tensors. +- **`gather_paged_kv_to_cpu`** — gathers paged blocks into CPU chunk tensors. +- **`scatter_cpu_to_paged_kv`** — scatters CPU chunks back into device paged KV tensors. Respects `skip_first_n_tokens` for partial-prefix retrieval. ## Non-goals - No change to existing CUDA IPC path semantics. - No CPU-specific logic added to shared `gpu_connector/utils.py`. +- No wire-protocol incompatibility between CUDA and CPU context workers in + the same cluster. From f1f18240eea1f391b8bde6221aa9d7230bd380c6 Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Fri, 15 May 2026 03:35:24 +0000 Subject: [PATCH 11/23] rebase dsv4: propagate vllm_logical_block_size through TransferContext.register() to restore DeepSeek V4 compress_ratio Signed-off-by: Tony Lin --- lmcache/integration/vllm/vllm_multi_process_adapter.py | 1 + lmcache/v1/multiprocess/transfer_context.py | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/lmcache/integration/vllm/vllm_multi_process_adapter.py b/lmcache/integration/vllm/vllm_multi_process_adapter.py index 8b60a71ceb2..6b06f814838 100644 --- a/lmcache/integration/vllm/vllm_multi_process_adapter.py +++ b/lmcache/integration/vllm/vllm_multi_process_adapter.py @@ -862,6 +862,7 @@ def _send_register_kv_caches_request( self.mq_client, self._mq_timeout, send_request=send_lmcache_request, + vllm_logical_block_size=self.vllm_logical_block_size, ) except TimeoutError: raise ConnectionError( diff --git a/lmcache/v1/multiprocess/transfer_context.py b/lmcache/v1/multiprocess/transfer_context.py index 89047a3983d..d16dc157cdd 100644 --- a/lmcache/v1/multiprocess/transfer_context.py +++ b/lmcache/v1/multiprocess/transfer_context.py @@ -58,6 +58,7 @@ def register( mq_client: MessageQueueClient, mq_timeout: float, send_request: SendRequest, + vllm_logical_block_size: int = 0, ) -> None: """Register KV caches with the server and wait for ACK. @@ -70,6 +71,8 @@ def register( mq_client: Message queue client used to communicate with server. mq_timeout: Timeout in seconds for synchronous request wait. send_request: Request sender callable used to issue MQ requests. + vllm_logical_block_size: vLLM logical block size used to derive + per-layer-group compression ratios on the server side. Raises: TimeoutError: If server registration does not complete before @@ -159,6 +162,7 @@ def register( mq_client: MessageQueueClient, mq_timeout: float, send_request: SendRequest, + vllm_logical_block_size: int = 0, ) -> None: # First Party from lmcache.integration.vllm.utils import vllm_layout_hints @@ -167,6 +171,7 @@ def register( self._mq_client = mq_client self._send_request = send_request layout_hints = vllm_layout_hints() + layout_hints["inference_engine_logical_block_size"] = vllm_logical_block_size future = send_request( mq_client, RequestType.REGISTER_KV_CACHE, @@ -247,11 +252,13 @@ def register( mq_client: MessageQueueClient, mq_timeout: float, send_request: SendRequest, + vllm_logical_block_size: int = 0, ) -> None: # First Party from lmcache.integration.vllm.utils import vllm_layout_hints layout_hints = vllm_layout_hints() + layout_hints["inference_engine_logical_block_size"] = vllm_logical_block_size ( block_size, num_layers, From f8d93b9e456a863a4ec87ff786c5578c80028c9f Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Fri, 15 May 2026 05:29:13 +0000 Subject: [PATCH 12/23] rename to more general names: non_gpu_context & non_cuda_transfer_context Signed-off-by: Tony Lin --- ...xt_design.md => non_gpu_context_design.md} | 58 +++++++++---------- .../{cpu_context.py => non_gpu_context.py} | 54 ++++++++--------- ...xt_pickle.py => non_gpu_context_pickle.py} | 15 +++-- lmcache/v1/multiprocess/protocols/base.py | 2 +- lmcache/v1/multiprocess/protocols/engine.py | 4 +- lmcache/v1/multiprocess/server.py | 50 ++++++++-------- lmcache/v1/multiprocess/transfer_context.py | 48 +++++++-------- ...pu_context.py => test_non_cuda_context.py} | 42 +++++++++----- tests/v1/test_vllm_mp_adapter.py | 6 +- 9 files changed, 147 insertions(+), 132 deletions(-) rename docs/design/v1/multiprocess/{cpu_context_design.md => non_gpu_context_design.md} (82%) rename lmcache/v1/multiprocess/{cpu_context.py => non_gpu_context.py} (89%) rename lmcache/v1/multiprocess/{cpu_context_pickle.py => non_gpu_context_pickle.py} (91%) rename tests/v1/multiprocess/{test_cpu_context.py => test_non_cuda_context.py} (90%) diff --git a/docs/design/v1/multiprocess/cpu_context_design.md b/docs/design/v1/multiprocess/non_gpu_context_design.md similarity index 82% rename from docs/design/v1/multiprocess/cpu_context_design.md rename to docs/design/v1/multiprocess/non_gpu_context_design.md index c2345e40958..24f9d19fc2e 100644 --- a/docs/design/v1/multiprocess/cpu_context_design.md +++ b/docs/design/v1/multiprocess/non_gpu_context_design.md @@ -1,4 +1,4 @@ -# CPU Context Design (MP mode, non-CUDA) +# Non-GPU Context Design (MP mode, non-CUDA) ## 1. Motivation @@ -19,7 +19,7 @@ This design is fundamentally tied to the CUDA programming model: For non-CUDA accelerators — **CPU, Intel XPU, Habana HPU**, or any future device — none of these primitives are available. -The **CPU context** path introduces a device-agnostic KV transfer mechanism: +The **non-GPU context** path introduces a device-agnostic KV transfer mechanism: 1. Workers **gather** paged KV blocks into contiguous CPU chunk tensors. 2. CPU chunks are **transported** to the server through a pluggable @@ -64,18 +64,18 @@ polymorphic `TransferContext` abstraction. vllm_multi_process_adapter.py ← Engine adapter, device-agnostic └── TransferContext ← Worker-side transport abstraction (§3) ├── CudaTransferContext ← CUDA IPC + MQ future path - └── CPUTransferContext ← Synchronous gather/scatter path - └── CPUContext ← Serialisation abstraction (§4.2) - ├── CPUContextPickle ← pickle.dumps/loads (§4.3) - └── CPUContextShm ← shared memory (§4.4, TODO) + └── NonCudaTransferContext ← Synchronous gather/scatter path + └── NonGpuContext ← Serialisation abstraction (§4.2) + ├── NonGpuContextPickle ← pickle.dumps/loads (§4.3) + └── NonGpuContextShm ← shared memory (§4.4, TODO) ``` Two layers of abstraction serve different purposes: - **TransferContext** (§3) — decides **CUDA vs non-CUDA** routing at the worker adapter level. -- **CPUContext** (§4.2) — decides **how** CPU chunk data is serialised and - transported (pickle vs SHM). Only used inside `CPUTransferContext`. +- **NonGpuContext** (§4.2) — decides **how** CPU chunk data is serialised and + transported (pickle vs SHM). Only used inside `NonCudaTransferContext`. ### 2.2 State machine (worker ↔ server) @@ -91,10 +91,10 @@ Two layers of abstraction serve different purposes: [device == cuda] [device != cuda] | | v v - CudaTransferContext.register() CPUTransferContext.register() - → REGISTER_KV_CACHE → REGISTER_KV_CACHE_CPU_CONTEXT + CudaTransferContext.register() NonCudaTransferContext.register() + → REGISTER_KV_CACHE → REGISTER_KV_CACHE_NON_GPU_CONTEXT (CUDA IPC handles) (scalar metadata fields) - | + create_cpu_context() + | + create_non_gpu_context() +----------------+----------------+ | v @@ -107,8 +107,8 @@ Two layers of abstraction serve different purposes: | | v v STORE (GPU → L1) gather_paged_kv_to_cpu() - [async MQ future] + _cpu_context.prepare_store() - | + _cpu_context.commit_store() [sync] + [async MQ future] + _non_gpu_context.prepare_store() + | + _non_gpu_context.commit_store() [sync] v _store_done[id] = ok [READY] | +----------------+----------------+ @@ -119,9 +119,9 @@ Two layers of abstraction serve different purposes: +----------------+----------------+ | | v v - RETRIEVE (L1 → GPU) _cpu_context.prepare_retrieve() [sync] + RETRIEVE (L1 → GPU) _non_gpu_context.prepare_retrieve() [sync] [async MQ future] + scatter_cpu_to_paged_kv() - | + _cpu_context.commit_retrieve() + | + _non_gpu_context.commit_retrieve() v _retrieve_done[id] = (ok, block_ids) +----------------+----------------+ | @@ -140,7 +140,7 @@ Two layers of abstraction serve different purposes: ### 3.1 Problem Before this refactoring, `vllm_multi_process_adapter.py` contained -`if self._use_cpu_context:` branches in every method — `register_kv_caches`, +non-CUDA-specific branching in every method — `register_kv_caches`, `submit_store_request`, `submit_retrieve_request`, `get_finished`, and the unhealthy drain path. Adding a third transport would require touching every branch. @@ -155,7 +155,7 @@ no `if/else` anywhere. ### 3.3 `create_transfer_context()` factory Inspects device types of all KV cache tensors **exactly once**. CUDA → -`CudaTransferContext`; otherwise → `CPUTransferContext`. Mixed device types +`CudaTransferContext`; otherwise → `NonCudaTransferContext`. Mixed device types are rejected. ### 3.4 `CudaTransferContext` @@ -165,20 +165,20 @@ Wraps the original CUDA IPC path. Sends `REGISTER_KV_CACHE` / `STORE` / `poll_finished` queries futures; `drain_all` marks all pending as finished for unhealthy shutdown. Semantics identical to pre-refactoring. -### 3.5 `CPUTransferContext` +### 3.5 `NonCudaTransferContext` -Holds a `CPUContext` instance internally. Sends -`REGISTER_KV_CACHE_CPU_CONTEXT` with scalar metadata. Store and retrieve +Holds a `NonGpuContext` instance internally. Sends +`REGISTER_KV_CACHE_NON_GPU_CONTEXT` with scalar metadata. Store and retrieve are **synchronous**: gather → prepare/commit, then record result in `_store_done` / `_retrieve_done`. `poll_finished` simply drains these dicts. -## 4. Server-side: CPU Context Protocol +## 4. Server-side: Non-GPU Context Protocol -### 4.1 Why GPU context and CPU context need different protocols +### 4.1 Why GPU context and non-GPU context need different protocols -| | GPU context | CPU context | +| | GPU context | non-GPU context | |---|---|---| -| Registration | `REGISTER_KV_CACHE` — IPC handles | `REGISTER_KV_CACHE_CPU_CONTEXT` — scalar fields | +| Registration | `REGISTER_KV_CACHE` — IPC handles | `REGISTER_KV_CACHE_NON_GPU_CONTEXT` — scalar fields | | Store | `STORE` — event handle + block IDs, server reads GPU directly | `STORE_CPU_CHUNKS` — serialised CPU tensors | | Retrieve | `RETRIEVE` — event handle + block IDs, server writes GPU directly | `RETRIEVE_CPU_CHUNKS` — key lookup, returns CPU tensors | @@ -187,10 +187,10 @@ Registration uses **scalar fields** (`block_size`, `num_layers`, to avoid cross-process pickle security and compatibility concerns. The server reconstructs `MemoryLayoutDesc` from the scalars internally. -### 4.2 `CPUContext` ABC: two-phase prepare/commit +### 4.2 `NonGpuContext` ABC: two-phase prepare/commit -The serialisation layer is abstracted behind `CPUContext` so that pickle -and SHM can be swapped without touching `CPUTransferContext` or the server. +The serialisation layer is abstracted behind `NonGpuContext` so that pickle +and SHM can be swapped without touching `NonCudaTransferContext` or the server. The ABC defines: `prepare_store`, `commit_store`, `prepare_retrieve`, `commit_retrieve`, `close`. @@ -207,7 +207,7 @@ accommodates both without forcing unnecessary round-trips on pickle. | `prepare_retrieve` | MQ `RETRIEVE_CPU_CHUNKS` → `pickle.loads` | MQ `PREPARE_RETRIEVE` → server writes to SHM → map tensor views | | `commit_retrieve` | no-op | MQ `FINISH_READ` → release SHM read lock | -`create_cpu_context()` factory currently always returns `CPUContextPickle`. +`create_non_gpu_context()` factory currently always returns `NonGpuContextPickle`. Future: probe `/dev/shm` availability and capacity, fall back to pickle if insufficient. @@ -246,5 +246,5 @@ rather than per-token `index_select` / `index_copy_`. For HND layouts, a - No change to existing CUDA IPC path semantics. - No CPU-specific logic added to shared `gpu_connector/utils.py`. -- No wire-protocol incompatibility between CUDA and CPU context workers in +- No wire-protocol incompatibility between CUDA and non-GPU context workers in the same cluster. diff --git a/lmcache/v1/multiprocess/cpu_context.py b/lmcache/v1/multiprocess/non_gpu_context.py similarity index 89% rename from lmcache/v1/multiprocess/cpu_context.py rename to lmcache/v1/multiprocess/non_gpu_context.py index f627f86920d..32f2d50e381 100644 --- a/lmcache/v1/multiprocess/cpu_context.py +++ b/lmcache/v1/multiprocess/non_gpu_context.py @@ -1,13 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 -"""CPU context abstractions and utilities for multiprocess mode. +"""Non-GPU context abstractions and utilities for multiprocess mode. This module provides: -- ``CPUContextMetadata``: layout metadata dataclass for non-CUDA workers. -- ``CPUContext``: abstract base class with a two-phase prepare/commit interface - for CPU-side KV data transfer. Concrete implementations (e.g. - ``CPUContextPickle``) each decide *how* data is serialised and transported. -- ``create_cpu_context()``: factory that returns the appropriate - ``CPUContext`` subclass (currently always ``CPUContextPickle``). +- ``NonGpuContextMetadata``: layout metadata dataclass for non-CUDA workers. +- ``NonGpuContext``: abstract base class with a two-phase prepare/commit + interface for CPU-side KV data transfer. Concrete implementations (e.g. + ``NonGpuContextPickle``) each decide *how* data is serialised and transported. +- ``create_non_gpu_context()``: factory that returns the appropriate + ``NonGpuContext`` subclass (currently always ``NonGpuContextPickle``). - ``compute_kv_layout``, ``gather_paged_kv_to_cpu``, ``scatter_cpu_to_paged_kv``: shared gather/scatter utilities used by all concrete implementations. """ @@ -26,8 +26,8 @@ @dataclass -class CPUContextMetadata: - """CPU context layout metadata for non-CUDA workers. +class NonGpuContextMetadata: + """Non-GPU context layout metadata for non-CUDA workers. Attributes: layout_desc: Memory layout descriptor used to interpret chunk payloads. @@ -40,7 +40,7 @@ class CPUContextMetadata: use_mla: bool -class CPUContext(ABC): +class NonGpuContext(ABC): """Abstract base class for CPU-side KV data transfer contexts. All concrete implementations share a common message-queue client and @@ -55,7 +55,7 @@ class CPUContext(ABC): def __init__( self, - metadata: CPUContextMetadata, + metadata: NonGpuContextMetadata, mq_client: Any, mq_timeout: float, ) -> None: @@ -107,8 +107,8 @@ def prepare_retrieve( instance_id: Worker instance identifier. Returns: - A ``(handle, chunks)`` pair. ``chunks`` is a list of CPU tensors - on cache hit, or ``None`` on cache miss. The handle must be + A ``(handle, chunks)`` pair. ``chunks`` is a list of CPU tensors + on cache hit, or ``None`` on cache miss. The handle must be passed to :meth:`commit_retrieve`. """ ... @@ -128,29 +128,29 @@ def close(self) -> None: ... -def create_cpu_context( - metadata: CPUContextMetadata, +def create_non_gpu_context( + metadata: NonGpuContextMetadata, mq_client: Any, mq_timeout: float, -) -> CPUContext: - """Factory that returns the appropriate :class:`CPUContext` implementation. +) -> NonGpuContext: + """Factory that returns the appropriate :class:`NonGpuContext` implementation. - Currently always returns a :class:`~lmcache.v1.multiprocess.\ -cpu_context_pickle.CPUContextPickle` instance. A future SHM-capable PR + Currently always returns a pickle-based implementation + (``NonGpuContextPickle``). A future SHM-capable PR may probe for shared-memory availability and fall back to pickle. Args: - metadata: Layout metadata for the CPU context. + metadata: Layout metadata for the non-GPU context. mq_client: Message-queue client for server communication. mq_timeout: Timeout in seconds for blocking MQ requests. Returns: - A concrete :class:`CPUContext` instance. + A concrete :class:`NonGpuContext` instance. """ # Local - from .cpu_context_pickle import CPUContextPickle + from .non_gpu_context_pickle import NonGpuContextPickle - return CPUContextPickle(metadata, mq_client, mq_timeout) + return NonGpuContextPickle(metadata, mq_client, mq_timeout) # --------------------------------------------------------------------------- @@ -214,9 +214,9 @@ def gather_paged_kv_to_cpu( gpu_kv_format: Optional pre-detected KV format. Returns: - List of CPU tensors, one per chunk. For non-MLA each chunk has shape + List of CPU tensors, one per chunk. For non-MLA each chunk has shape ``[2, num_layers, chunk_tokens, hidden_dim]`` where dimension ``0`` - stores ``(K, V)``. For MLA (multi-head latent attention) each chunk + stores ``(K, V)``. For MLA (multi-head latent attention) each chunk has shape ``[num_layers, chunk_tokens, hidden_dim]``. """ # First Party @@ -243,7 +243,7 @@ def gather_paged_kv_to_cpu( num_chunks = len(block_ids) // blocks_per_chunk # After normalization the structure is always a list of per-layer - # tensors. Cast once so all downstream indexing is typed correctly. + # tensors. Cast once so all downstream indexing is typed correctly. layer_tensors = cast(list[torch.Tensor], normalized) chunks: list[torch.Tensor] = [] @@ -362,7 +362,7 @@ def scatter_cpu_to_paged_kv( ) # After normalization the structure is always a list of per-layer - # tensors. Cast once so all downstream indexing is typed correctly. + # tensors. Cast once so all downstream indexing is typed correctly. layer_tensors = cast(list[torch.Tensor], normalized) for chunk_idx, chunk_cpu in enumerate(chunks): diff --git a/lmcache/v1/multiprocess/cpu_context_pickle.py b/lmcache/v1/multiprocess/non_gpu_context_pickle.py similarity index 91% rename from lmcache/v1/multiprocess/cpu_context_pickle.py rename to lmcache/v1/multiprocess/non_gpu_context_pickle.py index 95bf7f6127c..a78f27b4ebd 100644 --- a/lmcache/v1/multiprocess/cpu_context_pickle.py +++ b/lmcache/v1/multiprocess/non_gpu_context_pickle.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -"""Pickle-based CPUContext implementation for multiprocess mode.""" +"""Pickle-based NonGpuContext implementation for multiprocess mode.""" # Standard from typing import Any @@ -9,12 +9,15 @@ import torch # First Party -from lmcache.v1.multiprocess.cpu_context import CPUContext, CPUContextMetadata +from lmcache.v1.multiprocess.non_gpu_context import ( + NonGpuContext, + NonGpuContextMetadata, +) from lmcache.v1.multiprocess.protocol import RequestType, get_response_class -class CPUContextPickle(CPUContext): - """Pickle-based implementation of :class:`CPUContext`. +class NonGpuContextPickle(NonGpuContext): + """Pickle-based implementation of :class:`NonGpuContext`. Transport mechanism: - **Store**: ``prepare_store`` serialises chunks with ``pickle.dumps``; \ @@ -25,14 +28,14 @@ class CPUContextPickle(CPUContext): ``pickle.loads``; ``commit_retrieve`` is a no-op (no locks to release). Args: - metadata: Layout metadata for the CPU context. + metadata: Layout metadata for the non-GPU context. mq_client: Message-queue client for server communication. mq_timeout: Timeout in seconds for blocking MQ requests. """ def __init__( self, - metadata: CPUContextMetadata, + metadata: NonGpuContextMetadata, mq_client: Any, mq_timeout: float, ) -> None: diff --git a/lmcache/v1/multiprocess/protocols/base.py b/lmcache/v1/multiprocess/protocols/base.py index 85d033cec60..0d82b2754c3 100644 --- a/lmcache/v1/multiprocess/protocols/base.py +++ b/lmcache/v1/multiprocess/protocols/base.py @@ -48,7 +48,7 @@ class RequestType(enum.Enum): QUERY_PREFETCH_LOOKUP_HITS = enum.auto() FREE_LOOKUP_LOCKS = enum.auto() END_SESSION = enum.auto() - REGISTER_KV_CACHE_CPU_CONTEXT = enum.auto() + REGISTER_KV_CACHE_NON_GPU_CONTEXT = enum.auto() STORE_CPU_CHUNKS = enum.auto() RETRIEVE_CPU_CHUNKS = enum.auto() diff --git a/lmcache/v1/multiprocess/protocols/engine.py b/lmcache/v1/multiprocess/protocols/engine.py index 3fdcdd9edbe..e3f80b6a6c7 100644 --- a/lmcache/v1/multiprocess/protocols/engine.py +++ b/lmcache/v1/multiprocess/protocols/engine.py @@ -32,7 +32,7 @@ "QUERY_PREFETCH_LOOKUP_HITS", "FREE_LOOKUP_LOCKS", "END_SESSION", - "REGISTER_KV_CACHE_CPU_CONTEXT", + "REGISTER_KV_CACHE_NON_GPU_CONTEXT", "STORE_CPU_CHUNKS", "RETRIEVE_CPU_CHUNKS", ] @@ -149,7 +149,7 @@ def get_protocol_definitions() -> dict[str, ProtocolDefinition]: response_class=None, handler_type=HandlerType.BLOCKING, ), - "REGISTER_KV_CACHE_CPU_CONTEXT": ProtocolDefinition( + "REGISTER_KV_CACHE_NON_GPU_CONTEXT": ProtocolDefinition( payload_classes=[ int, str, diff --git a/lmcache/v1/multiprocess/server.py b/lmcache/v1/multiprocess/server.py index 4b8f5afa226..07cb129bba5 100644 --- a/lmcache/v1/multiprocess/server.py +++ b/lmcache/v1/multiprocess/server.py @@ -53,7 +53,6 @@ add_mp_server_args, parse_args_to_mp_server_config, ) -from lmcache.v1.multiprocess.cpu_context import CPUContextMetadata from lmcache.v1.multiprocess.custom_types import ( BlockAllocationRecord, IPCCacheEngineKey, @@ -63,6 +62,7 @@ GPUCacheContext, ) from lmcache.v1.multiprocess.mq import MessageQueueServer +from lmcache.v1.multiprocess.non_gpu_context import NonGpuContextMetadata from lmcache.v1.multiprocess.protocol import ( RequestType, get_handler_type, @@ -192,8 +192,8 @@ def __init__( # We assume that if the (model name, world size) is the same, then # the layout desc returned by the gpu context is the same. self.gpu_context_meta: dict[int, tuple[str, int]] = {} - self.cpu_contexts: dict[int, CPUContextMetadata] = {} - self.cpu_context_meta: dict[int, tuple[str, int]] = {} + self.non_cuda_contexts: dict[int, NonGpuContextMetadata] = {} + self.non_cuda_context_meta: dict[int, tuple[str, int]] = {} # chunk size self.chunk_size = chunk_size @@ -278,14 +278,14 @@ def unregister_kv_cache(self, instance_id: int) -> None: del self.gpu_context_meta[instance_id] logger.info("Unregistered KV cache for GPU ID %d", instance_id) torch_dev.empty_cache() - elif instance_id in self.cpu_contexts: - del self.cpu_contexts[instance_id] - del self.cpu_context_meta[instance_id] - logger.info("Unregistered CPU context for instance ID %d", instance_id) + elif instance_id in self.non_cuda_contexts: + del self.non_cuda_contexts[instance_id] + del self.non_cuda_context_meta[instance_id] + logger.info("Unregistered non-CUDA context for instance ID %d", instance_id) else: logger.warning("No KV cache found for GPU ID %d to unregister", instance_id) - def register_kv_cache_cpu_context( + def register_kv_cache_non_gpu_context( self, instance_id: int, model_name: str, @@ -296,7 +296,7 @@ def register_kv_cache_cpu_context( dtype_str: str, use_mla: bool, ) -> None: - """Register non-CUDA KV layout metadata for CPU context mode. + """Register non-CUDA KV layout metadata for non-GPU context mode. Args: instance_id: Worker instance identifier (typically PID). @@ -325,12 +325,12 @@ def register_kv_cache_cpu_context( else torch.Size([2, num_layers, self.chunk_size, hidden_dim_size]) ) layout_desc = MemoryLayoutDesc(shapes=[shape], dtypes=[dtype]) - self.cpu_contexts[instance_id] = CPUContextMetadata( + self.non_cuda_contexts[instance_id] = NonGpuContextMetadata( layout_desc=layout_desc, block_size=block_size, use_mla=use_mla, ) - self.cpu_context_meta[instance_id] = (model_name, world_size) + self.non_cuda_context_meta[instance_id] = (model_name, world_size) def _resolve_obj_keys(self, key: IPCCacheEngineKey) -> list[ObjectKey]: """Resolve object keys from an IPC cache key. @@ -375,11 +375,11 @@ def store_cpu_chunks( """ obj_keys = self._resolve_obj_keys(key) - if instance_id not in self.cpu_contexts: + if instance_id not in self.non_cuda_contexts: raise ValueError( - f"CPU context not registered for instance ID {instance_id}" + f"non-CUDA context not registered for instance ID {instance_id}" ) - ctx = self.cpu_contexts[instance_id] + ctx = self.non_cuda_contexts[instance_id] chunks: list[torch.Tensor] = pickle.loads(cpu_data) reserved_dict = self.storage_manager.reserve_write( obj_keys, ctx.layout_desc, "new" @@ -426,9 +426,9 @@ def retrieve_cpu_chunks( """ obj_keys = self._resolve_obj_keys(key) - if instance_id not in self.cpu_contexts: + if instance_id not in self.non_cuda_contexts: raise ValueError( - f"CPU context not registered for instance ID {instance_id}" + f"non-CUDA context not registered for instance ID {instance_id}" ) prefetched_keys: list[ObjectKey] = [] @@ -843,9 +843,9 @@ def _find_layout_desc( self.gpu_contexts[gpu_id], self.chunk_size, ) - for instance_id, (m, w) in self.cpu_context_meta.items(): + for instance_id, (m, w) in self.non_cuda_context_meta.items(): if m == model_name and w == world_size: - return self.cpu_contexts[instance_id].layout_desc + return self.non_cuda_contexts[instance_id].layout_desc return None def lookup( @@ -1194,16 +1194,16 @@ def report_status(self) -> dict: "hash_algorithm": self.token_hasher.hash_algorithm_name, "registered_gpu_ids": list(self.gpu_contexts.keys()), "gpu_context_meta": gpu_context_meta, - "registered_cpu_instance_ids": list(self.cpu_contexts.keys()), - "cpu_context_meta": { + "registered_non_cuda_instance_ids": list(self.non_cuda_contexts.keys()), + "non_cuda_context_meta": { str(instance_id): { "model_name": model_name, "world_size": world_size, - "block_size": self.cpu_contexts[instance_id].block_size, - "use_mla": self.cpu_contexts[instance_id].use_mla, + "block_size": self.non_cuda_contexts[instance_id].block_size, + "use_mla": self.non_cuda_contexts[instance_id].use_mla, } for instance_id, (model_name, world_size) in ( - self.cpu_context_meta.items() + self.non_cuda_context_meta.items() ) }, "active_sessions": self.session_manager.active_count(), @@ -1343,8 +1343,8 @@ def run_cache_server( add_handler_helper(server, RequestType.STORE, engine.store) add_handler_helper( server, - RequestType.REGISTER_KV_CACHE_CPU_CONTEXT, - engine.register_kv_cache_cpu_context, + RequestType.REGISTER_KV_CACHE_NON_GPU_CONTEXT, + engine.register_kv_cache_non_gpu_context, ) add_handler_helper(server, RequestType.STORE_CPU_CHUNKS, engine.store_cpu_chunks) add_handler_helper(server, RequestType.LOOKUP, engine.lookup) diff --git a/lmcache/v1/multiprocess/transfer_context.py b/lmcache/v1/multiprocess/transfer_context.py index d16dc157cdd..24593ec80cf 100644 --- a/lmcache/v1/multiprocess/transfer_context.py +++ b/lmcache/v1/multiprocess/transfer_context.py @@ -13,16 +13,16 @@ from lmcache.utils import EngineType, init_logger from lmcache.v1.distributed.api import MemoryLayoutDesc from lmcache.v1.gpu_connector.utils import is_mla -from lmcache.v1.multiprocess.cpu_context import ( - CPUContext, - CPUContextMetadata, +from lmcache.v1.multiprocess.futures import MessagingFuture +from lmcache.v1.multiprocess.mq import MessageQueueClient +from lmcache.v1.multiprocess.non_gpu_context import ( + NonGpuContext, + NonGpuContextMetadata, compute_kv_layout, - create_cpu_context, + create_non_gpu_context, gather_paged_kv_to_cpu, scatter_cpu_to_paged_kv, ) -from lmcache.v1.multiprocess.futures import MessagingFuture -from lmcache.v1.multiprocess.mq import MessageQueueClient from lmcache.v1.multiprocess.protocol import RequestType logger = init_logger(__name__) @@ -234,11 +234,11 @@ def close(self) -> None: self._send_request = None -class CPUTransferContext(TransferContext): - """CPU context transport for non-CUDA workers.""" +class NonCudaTransferContext(TransferContext): + """Non-CUDA context transport for non-CUDA workers.""" def __init__(self) -> None: - self._cpu_context: CPUContext | None = None + self._non_gpu_context: NonGpuContext | None = None self._layout_hints: Any = None self._gpu_kv_format: Any = None @@ -282,7 +282,7 @@ def register( future = send_request( mq_client, - RequestType.REGISTER_KV_CACHE_CPU_CONTEXT, + RequestType.REGISTER_KV_CACHE_NON_GPU_CONTEXT, [ instance_id, model_name, @@ -295,12 +295,12 @@ def register( ], ) - metadata = CPUContextMetadata( + metadata = NonGpuContextMetadata( layout_desc=layout_desc, block_size=block_size, use_mla=use_mla_flag, ) - self._cpu_context = create_cpu_context(metadata, mq_client, mq_timeout) + self._non_gpu_context = create_non_gpu_context(metadata, mq_client, mq_timeout) future.result(timeout=mq_timeout) def submit_store( @@ -313,9 +313,9 @@ def submit_store( _event: IPCEvent, blocks_in_chunk: int, ) -> MessagingFuture: - if self._cpu_context is None: + if self._non_gpu_context is None: raise RuntimeError( - "CPU transfer context is not registered. " + "Non-CUDA transfer context is not registered. " "Call register() before submit_store()." ) @@ -327,8 +327,8 @@ def submit_store( layout_hints=self._layout_hints, gpu_kv_format=self._gpu_kv_format, ) - handle = self._cpu_context.prepare_store(key, instance_id, cpu_chunks) - ok = self._cpu_context.commit_store(handle) + handle = self._non_gpu_context.prepare_store(key, instance_id, cpu_chunks) + ok = self._non_gpu_context.commit_store(handle) future: MessagingFuture[bool] = MessagingFuture() future.set_result(ok) @@ -345,13 +345,13 @@ def submit_retrieve( blocks_in_chunk: int, skip_first_n_tokens: int = 0, ) -> MessagingFuture: - if self._cpu_context is None: + if self._non_gpu_context is None: raise RuntimeError( - "CPU transfer context is not registered. " + "Non-CUDA transfer context is not registered. " "Call register() before submit_retrieve()." ) - handle, chunks = self._cpu_context.prepare_retrieve(key, instance_id) + handle, chunks = self._non_gpu_context.prepare_retrieve(key, instance_id) ok = chunks is not None if chunks is not None: try: @@ -367,16 +367,16 @@ def submit_retrieve( except (RuntimeError, ValueError, TypeError, IndexError): logger.exception("Failed to scatter retrieved CPU context chunks") ok = False - self._cpu_context.commit_retrieve(handle) + self._non_gpu_context.commit_retrieve(handle) future: MessagingFuture[bool] = MessagingFuture() future.set_result(ok) return future def close(self) -> None: - if self._cpu_context is not None: - self._cpu_context.close() - self._cpu_context = None + if self._non_gpu_context is not None: + self._non_gpu_context.close() + self._non_gpu_context = None def create_transfer_context( @@ -408,4 +408,4 @@ def create_transfer_context( logger.info("Creating transfer context (device_type=%s)", device_type) if device_type == "cuda": return CudaTransferContext() - return CPUTransferContext() + return NonCudaTransferContext() diff --git a/tests/v1/multiprocess/test_cpu_context.py b/tests/v1/multiprocess/test_non_cuda_context.py similarity index 90% rename from tests/v1/multiprocess/test_cpu_context.py rename to tests/v1/multiprocess/test_non_cuda_context.py index 5a6a4c0ddc4..b7cb30dcd32 100644 --- a/tests/v1/multiprocess/test_cpu_context.py +++ b/tests/v1/multiprocess/test_non_cuda_context.py @@ -18,7 +18,7 @@ def _make_kv_caches( num_heads: int = 2, head_size: int = 8, ) -> dict[str, torch.Tensor]: - """Build per-layer NHD KV tensors for CPU cpu context tests.""" + """Build per-layer NHD KV tensors for non-CUDA context tests.""" kv_caches = {} for i in range(num_layers): kv_caches[f"layer_{i}"] = torch.randn( @@ -33,7 +33,7 @@ def _make_mla_kv_caches( block_size: int = 4, hidden_size: int = 16, ) -> dict[str, torch.Tensor]: - """Build per-layer MLA KV tensors for CPU cpu context tests. + """Build per-layer MLA KV tensors for non-CUDA context tests. Args: num_layers: Number of KV layers to generate. @@ -58,7 +58,7 @@ def _make_hnd_kv_caches( num_heads: int = 2, head_size: int = 8, ) -> dict[str, torch.Tensor]: - """Build per-layer HND KV tensors for CPU cpu context tests.""" + """Build per-layer HND KV tensors for non-CUDA context tests.""" kv_caches = {} for i in range(num_layers): kv_caches[f"layer_{i}"] = torch.randn( @@ -74,7 +74,7 @@ def _make_hnd_flashinfer_kv_caches( num_heads: int = 2, head_size: int = 8, ) -> dict[str, torch.Tensor]: - """Build per-layer HND flash-infer KV tensors for CPU cpu context tests.""" + """Build per-layer HND flash-infer KV tensors for non-CUDA context tests.""" kv_caches = {} for i in range(num_layers): kv_caches[f"layer_{i}"] = torch.randn( @@ -99,10 +99,22 @@ def test_wrap_kv_caches_wraps_all_tensors(monkeypatch: Any) -> None: assert len(wrapped) == len(kv_caches) +def test_create_transfer_context_uses_non_cuda_context_on_cpu() -> None: + """Ensure transfer context factory returns NonCudaTransferContext for CPU KV.""" + # First Party + from lmcache.v1.multiprocess.transfer_context import ( + NonCudaTransferContext, + create_transfer_context, + ) + + context = create_transfer_context({"layer_0": torch.randn(2, 2)}) + assert isinstance(context, NonCudaTransferContext) + + def test_compute_kv_layout_and_gather_scatter_roundtrip() -> None: """Validate layout extraction and gather/scatter round-trip on CPU tensors.""" # First Party - from lmcache.v1.multiprocess.cpu_context import ( + from lmcache.v1.multiprocess.non_gpu_context import ( compute_kv_layout, gather_paged_kv_to_cpu, scatter_cpu_to_paged_kv, @@ -145,7 +157,7 @@ def test_gather_scatter_roundtrip_hnd_layout( ) -> None: """Validate gather/scatter round-trip for HND vLLM KV layout.""" # First Party - from lmcache.v1.multiprocess.cpu_context import ( + from lmcache.v1.multiprocess.non_gpu_context import ( compute_kv_layout, gather_paged_kv_to_cpu, scatter_cpu_to_paged_kv, @@ -197,7 +209,7 @@ def test_gather_scatter_roundtrip_hnd_layout( def test_scatter_respects_skip_first_n_tokens() -> None: """Ensure scatter honors skip_first_n_tokens and preserves skipped blocks.""" # First Party - from lmcache.v1.multiprocess.cpu_context import ( + from lmcache.v1.multiprocess.non_gpu_context import ( gather_paged_kv_to_cpu, scatter_cpu_to_paged_kv, ) @@ -225,7 +237,7 @@ def test_scatter_respects_skip_first_n_tokens() -> None: def test_compute_kv_layout_and_gather_scatter_roundtrip_mla() -> None: """Validate gather/scatter round-trip for MLA KV tensors.""" # First Party - from lmcache.v1.multiprocess.cpu_context import ( + from lmcache.v1.multiprocess.non_gpu_context import ( compute_kv_layout, gather_paged_kv_to_cpu, scatter_cpu_to_paged_kv, @@ -260,7 +272,7 @@ def test_compute_kv_layout_and_gather_scatter_roundtrip_mla() -> None: def test_compute_kv_layout_empty_raises_value_error() -> None: """Ensure compute_kv_layout rejects empty KV cache input.""" # First Party - from lmcache.v1.multiprocess.cpu_context import compute_kv_layout + from lmcache.v1.multiprocess.non_gpu_context import compute_kv_layout with pytest.raises(ValueError, match="kv_caches is empty"): compute_kv_layout({}) @@ -269,7 +281,7 @@ def test_compute_kv_layout_empty_raises_value_error() -> None: def test_scatter_mla_respects_skip_first_n_tokens() -> None: """Ensure MLA scatter honors skip_first_n_tokens and preserves skipped blocks.""" # First Party - from lmcache.v1.multiprocess.cpu_context import ( + from lmcache.v1.multiprocess.non_gpu_context import ( gather_paged_kv_to_cpu, scatter_cpu_to_paged_kv, ) @@ -299,7 +311,7 @@ def test_scatter_mla_respects_skip_first_n_tokens() -> None: def test_scatter_mla_skip_past_chunk_keeps_destination_unchanged() -> None: """Ensure MLA scatter is a no-op when skip_first_n_tokens exceeds chunk tokens.""" # First Party - from lmcache.v1.multiprocess.cpu_context import ( + from lmcache.v1.multiprocess.non_gpu_context import ( gather_paged_kv_to_cpu, scatter_cpu_to_paged_kv, ) @@ -339,10 +351,10 @@ def stub_native_storage_ops() -> Any: yield -def test_server_register_and_find_cpu_context_layout( +def test_server_register_and_find_non_cuda_context_layout( stub_native_storage_ops: Any, ) -> None: - """Ensure cpu context registration stores metadata and lookup finds its layout.""" + """Ensure non-CUDA registration stores metadata and lookup finds layout.""" # First Party from lmcache.v1.multiprocess.server import MPCacheEngine @@ -353,7 +365,7 @@ def test_server_register_and_find_cpu_context_layout( patch("lmcache.v1.multiprocess.server.get_event_bus"), ): engine = MPCacheEngine(storage_manager_config=MagicMock(), chunk_size=16) - engine.register_kv_cache_cpu_context( + engine.register_kv_cache_non_gpu_context( instance_id=1, model_name="m", world_size=1, @@ -404,7 +416,7 @@ def _read_prefetched_results(_keys: Any) -> Any: session_cls.return_value.get_or_create.return_value = mock_session engine = MPCacheEngine(storage_manager_config=MagicMock(), chunk_size=8) - engine.register_kv_cache_cpu_context( + engine.register_kv_cache_non_gpu_context( instance_id=2, model_name="m", world_size=1, diff --git a/tests/v1/test_vllm_mp_adapter.py b/tests/v1/test_vllm_mp_adapter.py index 944a19077fe..f18404053e6 100644 --- a/tests/v1/test_vllm_mp_adapter.py +++ b/tests/v1/test_vllm_mp_adapter.py @@ -106,10 +106,10 @@ def test_register_kv_caches_raises_connection_error_on_timeout(fake_adapter): adapter.register_kv_caches({"layer.0": fake_tensor}) -def test_register_kv_caches_cpu_submits_cpu_context_registration( +def test_register_kv_caches_cpu_submits_non_gpu_context_registration( fake_adapter, monkeypatch ): - """CPU KV cache registration routes to REGISTER_KV_CACHE_CPU_CONTEXT.""" + """CPU KV cache registration routes to REGISTER_KV_CACHE_NON_GPU_CONTEXT.""" adapter, send_mock, _ = fake_adapter monkeypatch.setattr( "lmcache.integration.vllm.utils.vllm_layout_hints", @@ -123,7 +123,7 @@ def test_register_kv_caches_cpu_submits_cpu_context_registration( assert adapter.kv_caches is cpu_kv assert send_mock.call_count == 1 args, _kwargs = send_mock.call_args - assert args[1] == RequestType.REGISTER_KV_CACHE_CPU_CONTEXT + assert args[1] == RequestType.REGISTER_KV_CACHE_NON_GPU_CONTEXT assert len(args[2]) == 8 From 598c30fe51b48f6f9c37678a224e054859875010 Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Fri, 15 May 2026 06:26:30 +0000 Subject: [PATCH 13/23] add todo note for deepseek v4 on non-cuda path Signed-off-by: Tony Lin --- lmcache/v1/multiprocess/transfer_context.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lmcache/v1/multiprocess/transfer_context.py b/lmcache/v1/multiprocess/transfer_context.py index 24593ec80cf..7d59ab323ed 100644 --- a/lmcache/v1/multiprocess/transfer_context.py +++ b/lmcache/v1/multiprocess/transfer_context.py @@ -258,6 +258,9 @@ def register( from lmcache.integration.vllm.utils import vllm_layout_hints layout_hints = vllm_layout_hints() + + # TODO: inference_engine_logical_block_size is used by deepseek v4 + # which is implemented in cuda path, non cuda path is to be implemented layout_hints["inference_engine_logical_block_size"] = vllm_logical_block_size ( block_size, From b4371e3bed9586596b2ed2c424eaad2bf5340853 Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Fri, 15 May 2026 07:13:03 +0000 Subject: [PATCH 14/23] Consolidate MPCacheEngine context state into unified registry Signed-off-by: Tony Lin --- lmcache/v1/multiprocess/server.py | 195 ++++++++++++++++++++---------- 1 file changed, 129 insertions(+), 66 deletions(-) diff --git a/lmcache/v1/multiprocess/server.py b/lmcache/v1/multiprocess/server.py index 07cb129bba5..8b467e6c4c1 100644 --- a/lmcache/v1/multiprocess/server.py +++ b/lmcache/v1/multiprocess/server.py @@ -176,6 +176,45 @@ class _PrefetchJob: cache_salt: str = "" +@dataclass +class RegisteredContext: + """Registered context metadata for a single worker instance. + + At least one of ``gpu_context`` or ``non_cuda_metadata`` is expected to be + populated for valid registrations. + """ + + model_name: str + world_size: int + gpu_context: GPUCacheContext | None = None + non_cuda_metadata: NonGpuContextMetadata | None = None + + @property + def is_gpu(self) -> bool: + """Return whether this registration uses a GPU transfer context.""" + return self.gpu_context is not None + + def get_layout_desc(self, chunk_size: int) -> MemoryLayoutDesc: + """Return the layout descriptor for this registration. + + Args: + chunk_size: Chunk size in tokens used for GPU layout derivation. + + Returns: + The resolved memory layout descriptor. + + Raises: + ValueError: If no GPU context or non-CUDA metadata is configured. + """ + if self.gpu_context is not None: + return get_layout_desc(self.gpu_context, chunk_size) + if self.non_cuda_metadata is None: + raise ValueError( + "Invalid RegisteredContext: no GPU or non-CUDA metadata configured" + ) + return self.non_cuda_metadata.layout_desc + + # Main class for the mp cache engine class MPCacheEngine: def __init__( @@ -184,16 +223,8 @@ def __init__( chunk_size: int = 256, hash_algorithm: str = "blake3", ): - # GPU ID -> KV cache tensors - self.gpu_contexts: dict[int, GPUCacheContext] = {} - - # GPU ID -> (model name, world size) as metadata - # NOTE: This is mainly for determining the layout desc during prefetch - # We assume that if the (model name, world size) is the same, then - # the layout desc returned by the gpu context is the same. - self.gpu_context_meta: dict[int, tuple[str, int]] = {} - self.non_cuda_contexts: dict[int, NonGpuContextMetadata] = {} - self.non_cuda_context_meta: dict[int, tuple[str, int]] = {} + # Worker instance ID -> registered context metadata + self.contexts: dict[int, RegisteredContext] = {} # chunk size self.chunk_size = chunk_size @@ -221,6 +252,15 @@ def __init__( self._setup_metrics() + @property + def gpu_contexts(self) -> dict[int, GPUCacheContext]: + """Return GPU-only context mapping for backward compatibility.""" + return { + instance_id: ctx.gpu_context + for instance_id, ctx in self.contexts.items() + if ctx.gpu_context is not None + } + def register_kv_cache( self, instance_id: int, @@ -244,7 +284,7 @@ def register_kv_cache( layout_hints: See :class:`LayoutHints`. Forwarded to :class:`GPUCacheContext` for GPU KV format detection. """ - if instance_id in self.gpu_contexts: + if instance_id in self.contexts: logger.warning( "Instance %s's KV cache is already registered, " "skipping the new registration", @@ -258,8 +298,11 @@ def register_kv_cache( layout_hints=layout_hints or None, engine_type=engine_type, ) - self.gpu_contexts[instance_id] = gpu_context - self.gpu_context_meta[instance_id] = (model_name, world_size) + self.contexts[instance_id] = RegisteredContext( + model_name=model_name, + world_size=world_size, + gpu_context=gpu_context, + ) logger.info( "Registered KV cache for GPU ID %d with %d layers", instance_id, @@ -273,17 +316,18 @@ def unregister_kv_cache(self, instance_id: int) -> None: Args: instance_id (int): The GPU instance ID (such as PID). """ - if instance_id in self.gpu_contexts: - del self.gpu_contexts[instance_id] - del self.gpu_context_meta[instance_id] + context = self.contexts.pop(instance_id, None) + if context is None: + logger.warning( + "No registered context found for instance ID %d", instance_id + ) + return + + if context.is_gpu: logger.info("Unregistered KV cache for GPU ID %d", instance_id) torch_dev.empty_cache() - elif instance_id in self.non_cuda_contexts: - del self.non_cuda_contexts[instance_id] - del self.non_cuda_context_meta[instance_id] - logger.info("Unregistered non-CUDA context for instance ID %d", instance_id) else: - logger.warning("No KV cache found for GPU ID %d to unregister", instance_id) + logger.info("Unregistered non-CUDA context for instance ID %d", instance_id) def register_kv_cache_non_gpu_context( self, @@ -311,6 +355,14 @@ def register_kv_cache_non_gpu_context( Raises: ValueError: If ``dtype_str`` is not a valid torch dtype name. """ + if instance_id in self.contexts: + logger.warning( + "Instance %s's KV cache is already registered, " + "skipping the new registration", + instance_id, + ) + return + dtype = getattr(torch, dtype_str, None) if dtype is None or not isinstance(dtype, torch.dtype): raise ValueError( @@ -325,12 +377,15 @@ def register_kv_cache_non_gpu_context( else torch.Size([2, num_layers, self.chunk_size, hidden_dim_size]) ) layout_desc = MemoryLayoutDesc(shapes=[shape], dtypes=[dtype]) - self.non_cuda_contexts[instance_id] = NonGpuContextMetadata( - layout_desc=layout_desc, - block_size=block_size, - use_mla=use_mla, + self.contexts[instance_id] = RegisteredContext( + model_name=model_name, + world_size=world_size, + non_cuda_metadata=NonGpuContextMetadata( + layout_desc=layout_desc, + block_size=block_size, + use_mla=use_mla, + ), ) - self.non_cuda_context_meta[instance_id] = (model_name, world_size) def _resolve_obj_keys(self, key: IPCCacheEngineKey) -> list[ObjectKey]: """Resolve object keys from an IPC cache key. @@ -375,11 +430,12 @@ def store_cpu_chunks( """ obj_keys = self._resolve_obj_keys(key) - if instance_id not in self.non_cuda_contexts: + context = self.contexts.get(instance_id) + if context is None or context.non_cuda_metadata is None: raise ValueError( f"non-CUDA context not registered for instance ID {instance_id}" ) - ctx = self.non_cuda_contexts[instance_id] + ctx = context.non_cuda_metadata chunks: list[torch.Tensor] = pickle.loads(cpu_data) reserved_dict = self.storage_manager.reserve_write( obj_keys, ctx.layout_desc, "new" @@ -426,7 +482,8 @@ def retrieve_cpu_chunks( """ obj_keys = self._resolve_obj_keys(key) - if instance_id not in self.non_cuda_contexts: + context = self.contexts.get(instance_id) + if context is None or context.non_cuda_metadata is None: raise ValueError( f"non-CUDA context not registered for instance ID {instance_id}" ) @@ -473,11 +530,15 @@ def store( st = time.perf_counter() obj_keys = self._resolve_obj_keys(key) - assert instance_id in self.gpu_contexts, ( - f"KV cache not registered for GPU ID {instance_id}" + context = self.contexts.get(instance_id) + assert context is not None, ( + f"No context registered for instance ID {instance_id}" ) - gpu_context = self.gpu_contexts[instance_id] - model_name = self.gpu_context_meta[instance_id][0] + assert context.gpu_context is not None, ( + f"GPU context not registered for instance ID {instance_id}" + ) + gpu_context = context.gpu_context + model_name = context.model_name # ``blocks_per_chunk`` is counted in inference-engine-side # blocks (each block addresses @@ -656,11 +717,15 @@ def retrieve( st = time.perf_counter() obj_keys = self._resolve_obj_keys(key) - assert instance_id in self.gpu_contexts, ( - f"KV cache not registered for GPU ID {instance_id}" + context = self.contexts.get(instance_id) + assert context is not None, ( + f"No context registered for instance ID {instance_id}" + ) + assert context.gpu_context is not None, ( + f"GPU context not registered for instance ID {instance_id}" ) - gpu_context = self.gpu_contexts[instance_id] - model_name = self.gpu_context_meta[instance_id][0] + gpu_context = context.gpu_context + model_name = context.model_name # CPU-synchronous sentinel: a GPU retrieve is about to be enqueued. # Must be published via publish() (not publish_on_stream) so the @@ -837,15 +902,9 @@ def _find_layout_desc( ``(model_name, world_size)``. GPU contexts are checked first, then CPU contexts. """ - for gpu_id, (m, w) in self.gpu_context_meta.items(): - if m == model_name and w == world_size: - return get_layout_desc( - self.gpu_contexts[gpu_id], - self.chunk_size, - ) - for instance_id, (m, w) in self.non_cuda_context_meta.items(): - if m == model_name and w == world_size: - return self.non_cuda_contexts[instance_id].layout_desc + for context in self.contexts.values(): + if context.model_name == model_name and context.world_size == world_size: + return context.get_layout_desc(self.chunk_size) return None def lookup( @@ -1161,13 +1220,18 @@ def report_status(self) -> dict: sm = self.storage_manager.report_status() gpu_context_meta: dict[str, dict] = {} - for gpu_id, meta in self.gpu_context_meta.items(): + non_cuda_context_meta: dict[str, dict] = {} + registered_gpu_ids: list[int] = [] + registered_non_cuda_ids: list[int] = [] + + for instance_id, context in self.contexts.items(): entry: dict = { - "model_name": meta[0], - "world_size": meta[1], + "model_name": context.model_name, + "world_size": context.world_size, } - ctx = self.gpu_contexts.get(gpu_id) - if ctx is not None: + if context.gpu_context is not None: + registered_gpu_ids.append(instance_id) + ctx = context.gpu_context entry["kv_cache_layout"] = { "num_layers": ctx.num_layers, "inference_engine_logical_block_size": ( @@ -1185,27 +1249,26 @@ def report_status(self) -> dict: "attention_backend": ctx.attention_backend, "cache_size_per_token": ctx.cache_size_per_token(), } - gpu_context_meta[str(gpu_id)] = entry + gpu_context_meta[str(instance_id)] = entry + continue + + if context.non_cuda_metadata is not None: + registered_non_cuda_ids.append(instance_id) + non_cuda_context_meta[str(instance_id)] = { + **entry, + "block_size": context.non_cuda_metadata.block_size, + "use_mla": context.non_cuda_metadata.use_mla, + } return { "is_healthy": sm["is_healthy"], "engine_type": self.__class__.__name__, "chunk_size": self.chunk_size, "hash_algorithm": self.token_hasher.hash_algorithm_name, - "registered_gpu_ids": list(self.gpu_contexts.keys()), + "registered_gpu_ids": registered_gpu_ids, "gpu_context_meta": gpu_context_meta, - "registered_non_cuda_instance_ids": list(self.non_cuda_contexts.keys()), - "non_cuda_context_meta": { - str(instance_id): { - "model_name": model_name, - "world_size": world_size, - "block_size": self.non_cuda_contexts[instance_id].block_size, - "use_mla": self.non_cuda_contexts[instance_id].use_mla, - } - for instance_id, (model_name, world_size) in ( - self.non_cuda_context_meta.items() - ) - }, + "registered_non_cuda_instance_ids": registered_non_cuda_ids, + "non_cuda_context_meta": non_cuda_context_meta, "active_sessions": self.session_manager.active_count(), "active_prefetch_jobs": self._active_prefetch_count(), "storage_manager": sm, @@ -1257,7 +1320,7 @@ def close(self) -> None: logger.info("MPCacheEngine closed") # Release GPU contexts - self.gpu_contexts.clear() + self.contexts.clear() def _active_prefetch_count(self) -> int: """Return the number of active prefetch jobs (thread-safe).""" From a129e9dad7da409bf29c6c95ea6cbeec83a14bec Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Fri, 15 May 2026 07:33:19 +0000 Subject: [PATCH 15/23] use dataclass for payload Signed-off-by: Tony Lin --- lmcache/v1/multiprocess/custom_types.py | 24 +++++++++ lmcache/v1/multiprocess/protocols/engine.py | 16 +++--- lmcache/v1/multiprocess/server.py | 49 ++++++++----------- lmcache/v1/multiprocess/transfer_context.py | 19 ++++--- .../v1/multiprocess/test_non_cuda_context.py | 42 +++++++++------- tests/v1/test_vllm_mp_adapter.py | 2 +- 6 files changed, 87 insertions(+), 65 deletions(-) diff --git a/lmcache/v1/multiprocess/custom_types.py b/lmcache/v1/multiprocess/custom_types.py index ec82b8bbd47..28cc2e3a85a 100644 --- a/lmcache/v1/multiprocess/custom_types.py +++ b/lmcache/v1/multiprocess/custom_types.py @@ -315,6 +315,30 @@ def no_worker_id_version(self) -> "IPCCacheEngineKey": KVCache = list[CudaIPCWrapper] +class RegisterNonGpuContextPayload(msgspec.Struct): + """Payload for the REGISTER_KV_CACHE_NON_GPU_CONTEXT protocol message. + + Attributes: + instance_id: Worker instance identifier (typically PID). + model_name: Model name associated with this worker. + world_size: Worker world size used in cache keys. + block_size: Tokens per paged block. + num_layers: Number of model layers. + hidden_dim_size: Flattened hidden dimension per token. + dtype_str: Torch dtype name (e.g. ``"float16"``). + use_mla: Whether the worker KV format is MLA. + """ + + instance_id: int + model_name: str + world_size: int + block_size: int + num_layers: int + hidden_dim_size: int + dtype_str: str + use_mla: bool + + @dataclass class CustomizedSerdeConfig: serializer: Callable[[Any], bytes] diff --git a/lmcache/v1/multiprocess/protocols/engine.py b/lmcache/v1/multiprocess/protocols/engine.py index e3f80b6a6c7..cb3126c2ad9 100644 --- a/lmcache/v1/multiprocess/protocols/engine.py +++ b/lmcache/v1/multiprocess/protocols/engine.py @@ -18,6 +18,7 @@ from lmcache.v1.multiprocess.custom_types import ( IPCCacheEngineKey, KVCache, + RegisterNonGpuContextPayload, ) from lmcache.v1.multiprocess.protocols.base import HandlerType, ProtocolDefinition @@ -149,17 +150,12 @@ def get_protocol_definitions() -> dict[str, ProtocolDefinition]: response_class=None, handler_type=HandlerType.BLOCKING, ), + # Register non-GPU KV cache context + # Payload: + # - RegisterNonGpuContextPayload - all metadata fields in one struct + # Returns: None "REGISTER_KV_CACHE_NON_GPU_CONTEXT": ProtocolDefinition( - payload_classes=[ - int, - str, - int, - int, - int, - int, - str, - bool, - ], + payload_classes=[RegisterNonGpuContextPayload], response_class=None, handler_type=HandlerType.SYNC, ), diff --git a/lmcache/v1/multiprocess/server.py b/lmcache/v1/multiprocess/server.py index 8b467e6c4c1..5dad97e4f2a 100644 --- a/lmcache/v1/multiprocess/server.py +++ b/lmcache/v1/multiprocess/server.py @@ -57,6 +57,7 @@ BlockAllocationRecord, IPCCacheEngineKey, KVCache, + RegisterNonGpuContextPayload, ) from lmcache.v1.multiprocess.gpu_context import ( GPUCacheContext, @@ -331,59 +332,49 @@ def unregister_kv_cache(self, instance_id: int) -> None: def register_kv_cache_non_gpu_context( self, - instance_id: int, - model_name: str, - world_size: int, - block_size: int, - num_layers: int, - hidden_dim_size: int, - dtype_str: str, - use_mla: bool, + payload: RegisterNonGpuContextPayload, ) -> None: """Register non-CUDA KV layout metadata for non-GPU context mode. Args: - instance_id: Worker instance identifier (typically PID). - model_name: Model name associated with this worker. - world_size: Worker world size used in cache keys. - block_size: Tokens per paged block. - num_layers: Number of model layers. - hidden_dim_size: Flattened hidden dimension per token. - dtype_str: Torch dtype name (for example ``"float16"``). - use_mla: Whether the worker KV format is MLA. + payload: Struct containing all registration fields + (instance_id, model_name, world_size, block_size, + num_layers, hidden_dim_size, dtype_str, use_mla). Raises: - ValueError: If ``dtype_str`` is not a valid torch dtype name. + ValueError: If ``payload.dtype_str`` is not a valid torch dtype name. """ - if instance_id in self.contexts: + if payload.instance_id in self.contexts: logger.warning( "Instance %s's KV cache is already registered, " "skipping the new registration", - instance_id, + payload.instance_id, ) return - dtype = getattr(torch, dtype_str, None) + dtype = getattr(torch, payload.dtype_str, None) if dtype is None or not isinstance(dtype, torch.dtype): raise ValueError( - f"Invalid dtype_str '{dtype_str}': must be a valid torch dtype " + f"Invalid dtype_str '{payload.dtype_str}': must be a valid torch dtype " "attribute name (e.g. 'float16' for torch.float16, " "'bfloat16' for torch.bfloat16, 'float32' for torch.float32)." ) shape = ( - torch.Size([num_layers, self.chunk_size, hidden_dim_size]) - if use_mla - else torch.Size([2, num_layers, self.chunk_size, hidden_dim_size]) + torch.Size([payload.num_layers, self.chunk_size, payload.hidden_dim_size]) + if payload.use_mla + else torch.Size( + [2, payload.num_layers, self.chunk_size, payload.hidden_dim_size] + ) ) layout_desc = MemoryLayoutDesc(shapes=[shape], dtypes=[dtype]) - self.contexts[instance_id] = RegisteredContext( - model_name=model_name, - world_size=world_size, + self.contexts[payload.instance_id] = RegisteredContext( + model_name=payload.model_name, + world_size=payload.world_size, non_cuda_metadata=NonGpuContextMetadata( layout_desc=layout_desc, - block_size=block_size, - use_mla=use_mla, + block_size=payload.block_size, + use_mla=payload.use_mla, ), ) diff --git a/lmcache/v1/multiprocess/transfer_context.py b/lmcache/v1/multiprocess/transfer_context.py index 7d59ab323ed..a2b1f8ba7a2 100644 --- a/lmcache/v1/multiprocess/transfer_context.py +++ b/lmcache/v1/multiprocess/transfer_context.py @@ -13,6 +13,7 @@ from lmcache.utils import EngineType, init_logger from lmcache.v1.distributed.api import MemoryLayoutDesc from lmcache.v1.gpu_connector.utils import is_mla +from lmcache.v1.multiprocess.custom_types import RegisterNonGpuContextPayload from lmcache.v1.multiprocess.futures import MessagingFuture from lmcache.v1.multiprocess.mq import MessageQueueClient from lmcache.v1.multiprocess.non_gpu_context import ( @@ -287,14 +288,16 @@ def register( mq_client, RequestType.REGISTER_KV_CACHE_NON_GPU_CONTEXT, [ - instance_id, - model_name, - world_size, - block_size, - num_layers, - hidden_dim_size, - dtype_str, - use_mla_flag, + RegisterNonGpuContextPayload( + instance_id=instance_id, + model_name=model_name, + world_size=world_size, + block_size=block_size, + num_layers=num_layers, + hidden_dim_size=hidden_dim_size, + dtype_str=dtype_str, + use_mla=use_mla_flag, + ) ], ) diff --git a/tests/v1/multiprocess/test_non_cuda_context.py b/tests/v1/multiprocess/test_non_cuda_context.py index b7cb30dcd32..31783e46903 100644 --- a/tests/v1/multiprocess/test_non_cuda_context.py +++ b/tests/v1/multiprocess/test_non_cuda_context.py @@ -356,6 +356,7 @@ def test_server_register_and_find_non_cuda_context_layout( ) -> None: """Ensure non-CUDA registration stores metadata and lookup finds layout.""" # First Party + from lmcache.v1.multiprocess.custom_types import RegisterNonGpuContextPayload from lmcache.v1.multiprocess.server import MPCacheEngine with ( @@ -366,14 +367,16 @@ def test_server_register_and_find_non_cuda_context_layout( ): engine = MPCacheEngine(storage_manager_config=MagicMock(), chunk_size=16) engine.register_kv_cache_non_gpu_context( - instance_id=1, - model_name="m", - world_size=1, - block_size=4, - num_layers=2, - hidden_dim_size=16, - dtype_str="float32", - use_mla=False, + RegisterNonGpuContextPayload( + instance_id=1, + model_name="m", + world_size=1, + block_size=4, + num_layers=2, + hidden_dim_size=16, + dtype_str="float32", + use_mla=False, + ) ) layout = engine._find_layout_desc("m", 1) @@ -384,7 +387,10 @@ def test_server_register_and_find_non_cuda_context_layout( def test_server_store_and_retrieve_cpu_chunks(stub_native_storage_ops: Any) -> None: """Validate mocked server-side CPU chunk store and retrieve behavior.""" # First Party - from lmcache.v1.multiprocess.custom_types import IPCCacheEngineKey + from lmcache.v1.multiprocess.custom_types import ( + IPCCacheEngineKey, + RegisterNonGpuContextPayload, + ) from lmcache.v1.multiprocess.server import MPCacheEngine mock_storage = MagicMock() @@ -417,14 +423,16 @@ def _read_prefetched_results(_keys: Any) -> Any: engine = MPCacheEngine(storage_manager_config=MagicMock(), chunk_size=8) engine.register_kv_cache_non_gpu_context( - instance_id=2, - model_name="m", - world_size=1, - block_size=4, - num_layers=2, - hidden_dim_size=16, - dtype_str="float32", - use_mla=False, + RegisterNonGpuContextPayload( + instance_id=2, + model_name="m", + world_size=1, + block_size=4, + num_layers=2, + hidden_dim_size=16, + dtype_str="float32", + use_mla=False, + ) ) payload = torch.ones(2, 2, 8, 16) key = IPCCacheEngineKey.from_token_ids( diff --git a/tests/v1/test_vllm_mp_adapter.py b/tests/v1/test_vllm_mp_adapter.py index f18404053e6..e4e9c64b33f 100644 --- a/tests/v1/test_vllm_mp_adapter.py +++ b/tests/v1/test_vllm_mp_adapter.py @@ -124,7 +124,7 @@ def test_register_kv_caches_cpu_submits_non_gpu_context_registration( assert send_mock.call_count == 1 args, _kwargs = send_mock.call_args assert args[1] == RequestType.REGISTER_KV_CACHE_NON_GPU_CONTEXT - assert len(args[2]) == 8 + assert len(args[2]) == 1 def test_submit_store_request_tracks_returned_future(fake_adapter, monkeypatch): From f030c2b8f5f215754887982d7514bdf6582ff863 Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Fri, 15 May 2026 07:35:56 +0000 Subject: [PATCH 16/23] Auto-disable l1-use-lazy on non-CUDA backends Signed-off-by: Tony Lin --- lmcache/v1/distributed/config.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/lmcache/v1/distributed/config.py b/lmcache/v1/distributed/config.py index 5bdf503ef4d..2690b043970 100644 --- a/lmcache/v1/distributed/config.py +++ b/lmcache/v1/distributed/config.py @@ -10,12 +10,16 @@ import argparse # First Party +from lmcache import torch_dev +from lmcache.logging import init_logger from lmcache.v1.distributed.l2_adapters.config import ( L2AdaptersConfig, add_l2_adapters_args, parse_args_to_l2_adapters_config, ) +logger = init_logger(__name__) + @dataclass class L1MemoryManagerConfig: @@ -38,6 +42,15 @@ class L1MemoryManagerConfig: def __post_init__(self): self.init_size_in_bytes = min(self.init_size_in_bytes, self.size_in_bytes) + # LazyMemoryAllocator requires cudart (CUDA host-pinned memory). + # Auto-disable on non-CUDA backends to avoid a RuntimeError. + if self.use_lazy and not hasattr(torch_dev, "cudart"): + logger.warning( + "LazyMemoryAllocator requires cudart which is not available " + "on the current backend. Disabling l1-use-lazy." + ) + self.use_lazy = False + @dataclass class L1ManagerConfig: From 2009b4120e12a6e50b87ca06b79bc49cb0b04f30 Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Fri, 15 May 2026 11:45:51 +0000 Subject: [PATCH 17/23] Lift layout hints to the caller layer to avoid redundant computation in transfer-context registration Signed-off-by: Tony Lin --- .../vllm/vllm_multi_process_adapter.py | 7 ++++- lmcache/v1/multiprocess/transfer_context.py | 27 +++++++------------ 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/lmcache/integration/vllm/vllm_multi_process_adapter.py b/lmcache/integration/vllm/vllm_multi_process_adapter.py index 6b06f814838..e8b1389612b 100644 --- a/lmcache/integration/vllm/vllm_multi_process_adapter.py +++ b/lmcache/integration/vllm/vllm_multi_process_adapter.py @@ -12,6 +12,7 @@ # First Party from lmcache.integration.request_telemetry.factory import RequestTelemetryFactory +from lmcache.integration.vllm.utils import vllm_layout_hints from lmcache.utils import _lmcache_nvtx_annotate, init_logger from lmcache.v1.multiprocess.custom_types import ( BlockAllocationRecord, @@ -852,6 +853,10 @@ def _send_register_kv_caches_request( """ self.kv_caches = kv_caches self.transfer_ctx = create_transfer_context(kv_caches) + layout_hints = vllm_layout_hints() + layout_hints["inference_engine_logical_block_size"] = ( + self.vllm_logical_block_size + ) try: self.transfer_ctx.register( self.instance_id, @@ -862,7 +867,7 @@ def _send_register_kv_caches_request( self.mq_client, self._mq_timeout, send_request=send_lmcache_request, - vllm_logical_block_size=self.vllm_logical_block_size, + layout_hints=layout_hints, ) except TimeoutError: raise ConnectionError( diff --git a/lmcache/v1/multiprocess/transfer_context.py b/lmcache/v1/multiprocess/transfer_context.py index a2b1f8ba7a2..e9a97cfeff4 100644 --- a/lmcache/v1/multiprocess/transfer_context.py +++ b/lmcache/v1/multiprocess/transfer_context.py @@ -12,7 +12,7 @@ from lmcache import torch_dev from lmcache.utils import EngineType, init_logger from lmcache.v1.distributed.api import MemoryLayoutDesc -from lmcache.v1.gpu_connector.utils import is_mla +from lmcache.v1.gpu_connector.utils import LayoutHints, is_mla from lmcache.v1.multiprocess.custom_types import RegisterNonGpuContextPayload from lmcache.v1.multiprocess.futures import MessagingFuture from lmcache.v1.multiprocess.mq import MessageQueueClient @@ -59,7 +59,7 @@ def register( mq_client: MessageQueueClient, mq_timeout: float, send_request: SendRequest, - vllm_logical_block_size: int = 0, + layout_hints: LayoutHints | None = None, ) -> None: """Register KV caches with the server and wait for ACK. @@ -72,8 +72,7 @@ def register( mq_client: Message queue client used to communicate with server. mq_timeout: Timeout in seconds for synchronous request wait. send_request: Request sender callable used to issue MQ requests. - vllm_logical_block_size: vLLM logical block size used to derive - per-layer-group compression ratios on the server side. + layout_hints: Optional inference-engine-provided layout hints. Raises: TimeoutError: If server registration does not complete before @@ -163,16 +162,13 @@ def register( mq_client: MessageQueueClient, mq_timeout: float, send_request: SendRequest, - vllm_logical_block_size: int = 0, + layout_hints: LayoutHints | None = None, ) -> None: # First Party - from lmcache.integration.vllm.utils import vllm_layout_hints from lmcache.integration.vllm.vllm_multi_process_adapter import wrap_kv_caches self._mq_client = mq_client self._send_request = send_request - layout_hints = vllm_layout_hints() - layout_hints["inference_engine_logical_block_size"] = vllm_logical_block_size future = send_request( mq_client, RequestType.REGISTER_KV_CACHE, @@ -240,7 +236,7 @@ class NonCudaTransferContext(TransferContext): def __init__(self) -> None: self._non_gpu_context: NonGpuContext | None = None - self._layout_hints: Any = None + self._layout_hints: LayoutHints | None = None self._gpu_kv_format: Any = None def register( @@ -253,16 +249,11 @@ def register( mq_client: MessageQueueClient, mq_timeout: float, send_request: SendRequest, - vllm_logical_block_size: int = 0, + layout_hints: LayoutHints | None = None, ) -> None: - # First Party - from lmcache.integration.vllm.utils import vllm_layout_hints - - layout_hints = vllm_layout_hints() - - # TODO: inference_engine_logical_block_size is used by deepseek v4 - # which is implemented in cuda path, non cuda path is to be implemented - layout_hints["inference_engine_logical_block_size"] = vllm_logical_block_size + # TODO: inference_engine_logical_block_size is currently used by + # DeepSeek V4 on the CUDA path. The non-CUDA path is yet to be + # implemented. ( block_size, num_layers, From ec68ee33d7c2753b32b4c2abd32d86c947ac3f19 Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Tue, 19 May 2026 22:47:32 +0800 Subject: [PATCH 18/23] [refactor] reserve zero-copy buffer allocation interface to NonGpuContext for shm solution Signed-off-by: Tony Lin --- lmcache/v1/multiprocess/non_gpu_context.py | 67 ++++------ .../v1/multiprocess/non_gpu_context_pickle.py | 117 ++++++++---------- lmcache/v1/multiprocess/protocols/base.py | 6 +- lmcache/v1/multiprocess/protocols/engine.py | 46 ++++++- lmcache/v1/multiprocess/server.py | 81 ++++++++---- lmcache/v1/multiprocess/transfer_context.py | 15 +-- 6 files changed, 184 insertions(+), 148 deletions(-) diff --git a/lmcache/v1/multiprocess/non_gpu_context.py b/lmcache/v1/multiprocess/non_gpu_context.py index 32f2d50e381..e782c76d0c3 100644 --- a/lmcache/v1/multiprocess/non_gpu_context.py +++ b/lmcache/v1/multiprocess/non_gpu_context.py @@ -69,57 +69,25 @@ def layout_desc(self) -> MemoryLayoutDesc: return self.metadata.layout_desc @abstractmethod - def prepare_store( - self, key: Any, instance_id: int, chunks: list[torch.Tensor] - ) -> Any: - """Prepare a store operation. - - Args: - key: Cache key for the token range to store. - instance_id: Worker instance identifier. - chunks: CPU chunk tensors to store. - - Returns: - An opaque handle to be passed to :meth:`commit_store`. - """ + def prepare_store(self, key: Any, instance_id: int) -> list[torch.Tensor] | None: + """Prepare store. Returns pre-allocated out buffers (shm) or None (pickle).""" ... @abstractmethod - def commit_store(self, handle: Any) -> bool: - """Commit a prepared store operation. - - Args: - handle: The opaque handle returned by :meth:`prepare_store`. - - Returns: - ``True`` on success, ``False`` otherwise. - """ + def commit_store( + self, key: Any, instance_id: int, chunks: list[torch.Tensor] + ) -> bool: + """Commit store. Pickle: serialize and send. Shm: notify server.""" ... @abstractmethod - def prepare_retrieve( - self, key: Any, instance_id: int - ) -> tuple[Any, list[torch.Tensor] | None]: - """Prepare a retrieve operation. - - Args: - key: Cache key for the token range to retrieve. - instance_id: Worker instance identifier. - - Returns: - A ``(handle, chunks)`` pair. ``chunks`` is a list of CPU tensors - on cache hit, or ``None`` on cache miss. The handle must be - passed to :meth:`commit_retrieve`. - """ + def prepare_retrieve(self, key: Any, instance_id: int) -> list[torch.Tensor] | None: + """Prepare retrieve. Returns chunks or shm views, or None on miss.""" ... @abstractmethod - def commit_retrieve(self, handle: Any) -> None: - """Finalise a retrieve operation (release locks, cleanup, etc.). - - Args: - handle: The opaque handle returned by :meth:`prepare_retrieve`. - """ + def commit_retrieve(self, key: Any, instance_id: int) -> bool: + """Commit retrieve. Pickle: no-op. Shm: release read locks.""" ... @abstractmethod @@ -203,6 +171,7 @@ def gather_paged_kv_to_cpu( blocks_per_chunk: int, layout_hints: Any | None = None, gpu_kv_format: Any | None = None, + out: list[torch.Tensor] | None = None, ) -> list[torch.Tensor]: """Gather paged KV blocks into CPU chunk tensors. @@ -246,7 +215,7 @@ def gather_paged_kv_to_cpu( # tensors. Cast once so all downstream indexing is typed correctly. layer_tensors = cast(list[torch.Tensor], normalized) - chunks: list[torch.Tensor] = [] + chunks: list[torch.Tensor] = [] if out is None else out for chunk_idx in range(num_chunks): chunk_block_ids = block_ids[ chunk_idx * blocks_per_chunk : (chunk_idx + 1) * blocks_per_chunk @@ -261,7 +230,11 @@ def gather_paged_kv_to_cpu( len(chunk_block_ids) * block_size, layer_blocks.shape[-1] ) ) - chunks.append(torch.stack(mla_layers, dim=0).cpu()) + chunk_tensor = torch.stack(mla_layers, dim=0) + if out is not None: + out[chunk_idx].copy_(chunk_tensor, non_blocking=True) + else: + chunks.append(chunk_tensor.cpu()) else: k_layers: list[torch.Tensor] = [] v_layers: list[torch.Tensor] = [] @@ -310,7 +283,11 @@ def gather_paged_kv_to_cpu( ) k_stacked = torch.stack(k_layers, dim=0) v_stacked = torch.stack(v_layers, dim=0) - chunks.append(torch.stack([k_stacked, v_stacked], dim=0).cpu()) + chunk_tensor = torch.stack([k_stacked, v_stacked], dim=0) + if out is not None: + out[chunk_idx].copy_(chunk_tensor, non_blocking=True) + else: + chunks.append(chunk_tensor.cpu()) return chunks diff --git a/lmcache/v1/multiprocess/non_gpu_context_pickle.py b/lmcache/v1/multiprocess/non_gpu_context_pickle.py index a78f27b4ebd..d310b9c65d1 100644 --- a/lmcache/v1/multiprocess/non_gpu_context_pickle.py +++ b/lmcache/v1/multiprocess/non_gpu_context_pickle.py @@ -20,17 +20,12 @@ class NonGpuContextPickle(NonGpuContext): """Pickle-based implementation of :class:`NonGpuContext`. Transport mechanism: - - **Store**: ``prepare_store`` serialises chunks with ``pickle.dumps``; \ -``commit_store`` sends a ``STORE_CPU_CHUNKS`` message and waits for the \ -server acknowledgment. - - **Retrieve**: ``prepare_retrieve`` sends a ``RETRIEVE_CPU_CHUNKS`` \ -message, waits for the response, and deserialises the returned bytes with \ -``pickle.loads``; ``commit_retrieve`` is a no-op (no locks to release). - - Args: - metadata: Layout metadata for the non-GPU context. - mq_client: Message-queue client for server communication. - mq_timeout: Timeout in seconds for blocking MQ requests. + - **Store**: ``prepare_store`` sends ``PREPARE_STORE`` (returns empty slots + for pickle mode); ``commit_store`` serialises chunks and sends + ``COMMIT_STORE``. + - **Retrieve**: ``prepare_retrieve`` sends ``PREPARE_RETRIEVE`` and + deserialises the returned bytes; ``commit_retrieve`` sends + ``COMMIT_RETRIEVE`` (no-op for pickle). """ def __init__( @@ -41,84 +36,70 @@ def __init__( ) -> None: super().__init__(metadata, mq_client, mq_timeout) - def prepare_store( - self, key: Any, instance_id: int, chunks: list[torch.Tensor] - ) -> Any: - """Serialise *chunks* with ``pickle.dumps``. - - Args: - key: Cache key for the token range to store. - instance_id: Worker instance identifier. - chunks: CPU chunk tensors to serialise. - - Returns: - Opaque handle ``(key, instance_id, serialised_bytes)`` to be - passed to :meth:`commit_store`. - """ - serialised = pickle.dumps(chunks) - return (key, instance_id, serialised) - - def commit_store(self, handle: Any) -> bool: - """Send pickled chunks to the server via ``STORE_CPU_CHUNKS``. - - Blocks until the server acknowledges the write. + def prepare_store(self, key: Any, instance_id: int) -> list[torch.Tensor] | None: + """Send PREPARE_STORE RPC. For pickle, returns no pre-allocated buffers.""" + future = self.mq_client.submit_request( + RequestType.PREPARE_STORE, + [key, instance_id], + get_response_class(RequestType.PREPARE_STORE), + ) + try: + future.result(timeout=self.mq_timeout) + except TimeoutError: + pass + return None - Args: - handle: The ``(key, instance_id, bytes)`` tuple returned by - :meth:`prepare_store`. + def commit_store( + self, key: Any, instance_id: int, chunks: list[torch.Tensor] + ) -> bool: + """Serialize chunks and send via COMMIT_STORE. Returns: ``True`` on success, ``False`` on failure or timeout. """ - key, instance_id, serialised = handle + serialised = pickle.dumps(chunks) future = self.mq_client.submit_request( - RequestType.STORE_CPU_CHUNKS, + RequestType.COMMIT_STORE, [key, instance_id, serialised], - get_response_class(RequestType.STORE_CPU_CHUNKS), + get_response_class(RequestType.COMMIT_STORE), ) try: return bool(future.result(timeout=self.mq_timeout)) except TimeoutError: return False - def prepare_retrieve( - self, key: Any, instance_id: int - ) -> tuple[Any, list[torch.Tensor] | None]: - """Fetch serialised chunks from the server via ``RETRIEVE_CPU_CHUNKS``. - - Blocks until the server responds with the cached data (or reports a - miss). - - Args: - key: Cache key for the token range to retrieve. - instance_id: Worker instance identifier. + def prepare_retrieve(self, key: Any, instance_id: int) -> list[torch.Tensor] | None: + """Send PREPARE_RETRIEVE and deserialize the response data. Returns: - ``(None, chunks)`` on cache hit where *chunks* is the - deserialised list of CPU tensors, or ``(None, None)`` on cache - miss or timeout. The handle is ``None`` because the pickle path - has no resources to release in :meth:`commit_retrieve`. + Chunks on hit, or None on miss/timeout. """ future = self.mq_client.submit_request( - RequestType.RETRIEVE_CPU_CHUNKS, + RequestType.PREPARE_RETRIEVE, [key, instance_id], - get_response_class(RequestType.RETRIEVE_CPU_CHUNKS), + get_response_class(RequestType.PREPARE_RETRIEVE), ) try: - success, cpu_data_bytes = future.result(timeout=self.mq_timeout) + response = future.result(timeout=self.mq_timeout) except TimeoutError: - return (None, None) - if not success or not cpu_data_bytes: - return (None, None) - chunks: list[torch.Tensor] = pickle.loads(cpu_data_bytes) - return (None, chunks) - - def commit_retrieve(self, handle: Any) -> None: - """No-op: the pickle path holds no server-side locks. - - Args: - handle: Ignored. - """ + return None + if not response.success or not response.data: + return None + chunks: list[torch.Tensor] = pickle.loads(response.data) + return chunks + + def commit_retrieve(self, key: Any, instance_id: int) -> bool: + """Send COMMIT_RETRIEVE (no-op for pickle path).""" + future = self.mq_client.submit_request( + RequestType.COMMIT_RETRIEVE, + [key, instance_id], + get_response_class(RequestType.COMMIT_RETRIEVE), + ) + try: + future.result(timeout=self.mq_timeout) + except TimeoutError: + pass + return True def close(self) -> None: """No-op: the pickle path holds no persistent resources.""" diff --git a/lmcache/v1/multiprocess/protocols/base.py b/lmcache/v1/multiprocess/protocols/base.py index 0d82b2754c3..777ce029f29 100644 --- a/lmcache/v1/multiprocess/protocols/base.py +++ b/lmcache/v1/multiprocess/protocols/base.py @@ -49,8 +49,10 @@ class RequestType(enum.Enum): FREE_LOOKUP_LOCKS = enum.auto() END_SESSION = enum.auto() REGISTER_KV_CACHE_NON_GPU_CONTEXT = enum.auto() - STORE_CPU_CHUNKS = enum.auto() - RETRIEVE_CPU_CHUNKS = enum.auto() + PREPARE_STORE = enum.auto() + COMMIT_STORE = enum.auto() + PREPARE_RETRIEVE = enum.auto() + COMMIT_RETRIEVE = enum.auto() # Controller operations CLEAR = enum.auto() diff --git a/lmcache/v1/multiprocess/protocols/engine.py b/lmcache/v1/multiprocess/protocols/engine.py index cb3126c2ad9..62ec4926cd8 100644 --- a/lmcache/v1/multiprocess/protocols/engine.py +++ b/lmcache/v1/multiprocess/protocols/engine.py @@ -12,6 +12,9 @@ - END_SESSION: End a session and clean up associated resources """ +# Standard +from dataclasses import dataclass, field + # First Party from lmcache.utils import EngineType from lmcache.v1.gpu_connector.utils import LayoutHints @@ -22,6 +25,27 @@ ) from lmcache.v1.multiprocess.protocols.base import HandlerType, ProtocolDefinition + +@dataclass +class PrepareStoreResponse: + """Response for PREPARE_STORE.""" + + context: dict = field( + default_factory=dict + ) # pickle: {}, shm will put slot info here + + +@dataclass +class PrepareRetrieveResponse: + """Response for PREPARE_RETRIEVE.""" + + success: bool + data: bytes = b"" + context: dict = field( + default_factory=dict + ) # pickle: {}, shm will put slot info here + + # Define request names for this protocol group REQUEST_NAMES = [ "REGISTER_KV_CACHE", @@ -34,8 +58,10 @@ "FREE_LOOKUP_LOCKS", "END_SESSION", "REGISTER_KV_CACHE_NON_GPU_CONTEXT", - "STORE_CPU_CHUNKS", - "RETRIEVE_CPU_CHUNKS", + "PREPARE_STORE", + "COMMIT_STORE", + "PREPARE_RETRIEVE", + "COMMIT_RETRIEVE", ] # Type alias for cache keys @@ -159,14 +185,24 @@ def get_protocol_definitions() -> dict[str, ProtocolDefinition]: response_class=None, handler_type=HandlerType.SYNC, ), - "STORE_CPU_CHUNKS": ProtocolDefinition( + "PREPARE_STORE": ProtocolDefinition( + payload_classes=[KeyType, int], + response_class=PrepareStoreResponse, + handler_type=HandlerType.BLOCKING, + ), + "COMMIT_STORE": ProtocolDefinition( payload_classes=[KeyType, int, bytes], response_class=bool, handler_type=HandlerType.BLOCKING, ), - "RETRIEVE_CPU_CHUNKS": ProtocolDefinition( + "PREPARE_RETRIEVE": ProtocolDefinition( payload_classes=[KeyType, int], - response_class=tuple[bool, bytes], + response_class=PrepareRetrieveResponse, + handler_type=HandlerType.BLOCKING, + ), + "COMMIT_RETRIEVE": ProtocolDefinition( + payload_classes=[KeyType, int], + response_class=bool, handler_type=HandlerType.BLOCKING, ), } diff --git a/lmcache/v1/multiprocess/server.py b/lmcache/v1/multiprocess/server.py index 5dad97e4f2a..bb67d07bf54 100644 --- a/lmcache/v1/multiprocess/server.py +++ b/lmcache/v1/multiprocess/server.py @@ -69,6 +69,10 @@ get_handler_type, get_payload_classes, ) +from lmcache.v1.multiprocess.protocols.engine import ( + PrepareRetrieveResponse, + PrepareStoreResponse, +) from lmcache.v1.multiprocess.session import SessionManager from lmcache.v1.multiprocess.token_hasher import TokenHasher import lmcache.c_ops as lmc_ops @@ -400,13 +404,31 @@ def _resolve_obj_keys(self, key: IPCCacheEngineKey) -> list[ObjectKey]: return ipc_key_to_object_keys(key, chunk_hashes) @_lmcache_nvtx_annotate - def store_cpu_chunks( + def prepare_store( + self, + key: IPCCacheEngineKey, + instance_id: int, + ) -> PrepareStoreResponse: + """Prepare a store operation. For pickle mode, returns empty slots. + + Args: + key: Cache key for the token range to store. + instance_id: Worker instance identifier. + + Returns: + PrepareStoreResponse with empty slots for pickle mode. + """ + + return PrepareStoreResponse(context={}) + + @_lmcache_nvtx_annotate + def commit_store( self, key: IPCCacheEngineKey, instance_id: int, cpu_data: bytes, ) -> bool: - """Store worker-provided CPU chunks for non-CUDA cpu context mode. + """Commit serialized CPU chunks to storage. Args: key: Cache key for the token range to store. @@ -415,9 +437,6 @@ def store_cpu_chunks( Returns: ``True`` when all reserved objects are written, otherwise ``False``. - - Raises: - ValueError: If the instance has no registered cpu context. """ obj_keys = self._resolve_obj_keys(key) @@ -453,11 +472,11 @@ def store_cpu_chunks( return len(written_keys) == len(reserved_dict) @_lmcache_nvtx_annotate - def retrieve_cpu_chunks( + def prepare_retrieve( self, key: IPCCacheEngineKey, instance_id: int, - ) -> tuple[bool, bytes]: + ) -> PrepareRetrieveResponse: """Retrieve prefetched chunks and return serialized CPU tensors. Args: @@ -465,12 +484,9 @@ def retrieve_cpu_chunks( instance_id: Worker instance identifier. Returns: - Tuple ``(success, payload)`` where ``payload`` is a pickled - list of CPU chunk tensors. - - Raises: - ValueError: If the instance has no registered cpu context. + PrepareRetrieveResponse with serialized data on hit. """ + obj_keys = self._resolve_obj_keys(key) context = self.contexts.get(instance_id) @@ -483,18 +499,39 @@ def retrieve_cpu_chunks( try: with self.storage_manager.read_prefetched_results(obj_keys) as memory_objs: if not memory_objs or len(memory_objs) != len(obj_keys): - return False, b"" + return PrepareRetrieveResponse(success=False, data=b"", context={}) prefetched_keys = obj_keys[: len(memory_objs)] chunks = [] for memory_obj in memory_objs: if memory_obj.tensor is None: - return False, b"" + return PrepareRetrieveResponse( + success=False, data=b"", context={} + ) chunks.append(memory_obj.tensor.cpu().clone()) - return True, pickle.dumps(chunks) + return PrepareRetrieveResponse( + success=True, data=pickle.dumps(chunks), context={} + ) finally: if prefetched_keys: self.storage_manager.finish_read_prefetched(prefetched_keys) + @_lmcache_nvtx_annotate + def commit_retrieve( + self, + key: IPCCacheEngineKey, + instance_id: int, + ) -> bool: + """Finalize a retrieve operation. No-op for pickle mode. + + Args: + key: Cache key (unused for pickle). + instance_id: Worker instance identifier (unused for pickle). + + Returns: + Always ``True``. + """ + return True + @_lmcache_nvtx_annotate def store( self, @@ -1400,7 +1437,7 @@ def run_cache_server( RequestType.REGISTER_KV_CACHE_NON_GPU_CONTEXT, engine.register_kv_cache_non_gpu_context, ) - add_handler_helper(server, RequestType.STORE_CPU_CHUNKS, engine.store_cpu_chunks) + add_handler_helper(server, RequestType.PREPARE_STORE, engine.prepare_store) add_handler_helper(server, RequestType.LOOKUP, engine.lookup) add_handler_helper( server, RequestType.QUERY_PREFETCH_STATUS, engine.query_prefetch_status @@ -1412,9 +1449,9 @@ def run_cache_server( ) add_handler_helper(server, RequestType.FREE_LOOKUP_LOCKS, engine.free_lookup_locks) add_handler_helper(server, RequestType.RETRIEVE, engine.retrieve) - add_handler_helper( - server, RequestType.RETRIEVE_CPU_CHUNKS, engine.retrieve_cpu_chunks - ) + add_handler_helper(server, RequestType.COMMIT_STORE, engine.commit_store) + add_handler_helper(server, RequestType.PREPARE_RETRIEVE, engine.prepare_retrieve) + add_handler_helper(server, RequestType.COMMIT_RETRIEVE, engine.commit_retrieve) add_handler_helper(server, RequestType.CLEAR, engine.clear) add_handler_helper(server, RequestType.GET_CHUNK_SIZE, engine.get_chunk_size) add_handler_helper(server, RequestType.PING, engine.ping) @@ -1431,8 +1468,10 @@ def run_cache_server( [ RequestType.STORE, RequestType.RETRIEVE, - RequestType.STORE_CPU_CHUNKS, - RequestType.RETRIEVE_CPU_CHUNKS, + RequestType.PREPARE_STORE, + RequestType.COMMIT_STORE, + RequestType.PREPARE_RETRIEVE, + RequestType.COMMIT_RETRIEVE, ], max_workers=mp_config.max_gpu_workers, ) diff --git a/lmcache/v1/multiprocess/transfer_context.py b/lmcache/v1/multiprocess/transfer_context.py index e9a97cfeff4..225019836b1 100644 --- a/lmcache/v1/multiprocess/transfer_context.py +++ b/lmcache/v1/multiprocess/transfer_context.py @@ -317,15 +317,16 @@ def submit_store( ) torch_dev.synchronize() + out_buffers = self._non_gpu_context.prepare_store(key, instance_id) cpu_chunks = gather_paged_kv_to_cpu( kv_caches, block_ids, blocks_in_chunk, layout_hints=self._layout_hints, gpu_kv_format=self._gpu_kv_format, + out=out_buffers, ) - handle = self._non_gpu_context.prepare_store(key, instance_id, cpu_chunks) - ok = self._non_gpu_context.commit_store(handle) + ok = self._non_gpu_context.commit_store(key, instance_id, cpu_chunks) future: MessagingFuture[bool] = MessagingFuture() future.set_result(ok) @@ -348,14 +349,14 @@ def submit_retrieve( "Call register() before submit_retrieve()." ) - handle, chunks = self._non_gpu_context.prepare_retrieve(key, instance_id) - ok = chunks is not None - if chunks is not None: + src_buffers = self._non_gpu_context.prepare_retrieve(key, instance_id) + ok = src_buffers is not None + if src_buffers is not None: try: scatter_cpu_to_paged_kv( kv_caches, block_ids, - chunks, + src_buffers, blocks_in_chunk, skip_first_n_tokens=skip_first_n_tokens, layout_hints=self._layout_hints, @@ -364,7 +365,7 @@ def submit_retrieve( except (RuntimeError, ValueError, TypeError, IndexError): logger.exception("Failed to scatter retrieved CPU context chunks") ok = False - self._non_gpu_context.commit_retrieve(handle) + self._non_gpu_context.commit_retrieve(key, instance_id) future: MessagingFuture[bool] = MessagingFuture() future.set_result(ok) From bb60eefc535c641314c0113ab035554e8f71d4fc Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Tue, 19 May 2026 22:55:57 +0800 Subject: [PATCH 19/23] fix: update test to use new unified protocol methods Signed-off-by: Tony Lin --- tests/v1/multiprocess/test_non_cuda_context.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/v1/multiprocess/test_non_cuda_context.py b/tests/v1/multiprocess/test_non_cuda_context.py index 31783e46903..f5feadeaa4c 100644 --- a/tests/v1/multiprocess/test_non_cuda_context.py +++ b/tests/v1/multiprocess/test_non_cuda_context.py @@ -448,8 +448,10 @@ def _read_prefetched_results(_keys: Any) -> Any: "lmcache.v1.multiprocess.server.ipc_key_to_object_keys", return_value=["obj"], ): - store_ok = engine.store_cpu_chunks(key, 2, pickle.dumps([payload])) - success, cpu_data = engine.retrieve_cpu_chunks(key, 2) + store_ok = engine.commit_store(key, 2, pickle.dumps([payload])) + response = engine.prepare_retrieve(key, 2) + success = response.success + cpu_data = response.data assert isinstance(store_ok, bool) assert torch.allclose(mock_memory_obj.tensor, payload) From 91ed22b115fd05f00198f4fc153fb7f60ee1f9e7 Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Wed, 20 May 2026 04:44:06 +0000 Subject: [PATCH 20/23] refactor: rename transfer context classes to handle/data semantics Signed-off-by: Tony Lin --- .../v1/multiprocess/non_gpu_context_design.md | 16 +++++++-------- lmcache/v1/multiprocess/transfer_context.py | 20 +++++++++---------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/docs/design/v1/multiprocess/non_gpu_context_design.md b/docs/design/v1/multiprocess/non_gpu_context_design.md index 24f9d19fc2e..dcb4a08a470 100644 --- a/docs/design/v1/multiprocess/non_gpu_context_design.md +++ b/docs/design/v1/multiprocess/non_gpu_context_design.md @@ -63,8 +63,8 @@ polymorphic `TransferContext` abstraction. ``` vllm_multi_process_adapter.py ← Engine adapter, device-agnostic └── TransferContext ← Worker-side transport abstraction (§3) - ├── CudaTransferContext ← CUDA IPC + MQ future path - └── NonCudaTransferContext ← Synchronous gather/scatter path + ├── HandleTransferContext ← CUDA IPC + MQ future path + └── DataTransferContext ← Synchronous gather/scatter path └── NonGpuContext ← Serialisation abstraction (§4.2) ├── NonGpuContextPickle ← pickle.dumps/loads (§4.3) └── NonGpuContextShm ← shared memory (§4.4, TODO) @@ -75,7 +75,7 @@ Two layers of abstraction serve different purposes: - **TransferContext** (§3) — decides **CUDA vs non-CUDA** routing at the worker adapter level. - **NonGpuContext** (§4.2) — decides **how** CPU chunk data is serialised and - transported (pickle vs SHM). Only used inside `NonCudaTransferContext`. + transported (pickle vs SHM). Only used inside `DataTransferContext`. ### 2.2 State machine (worker ↔ server) @@ -91,7 +91,7 @@ Two layers of abstraction serve different purposes: [device == cuda] [device != cuda] | | v v - CudaTransferContext.register() NonCudaTransferContext.register() + HandleTransferContext.register() DataTransferContext.register() → REGISTER_KV_CACHE → REGISTER_KV_CACHE_NON_GPU_CONTEXT (CUDA IPC handles) (scalar metadata fields) | + create_non_gpu_context() @@ -155,17 +155,17 @@ no `if/else` anywhere. ### 3.3 `create_transfer_context()` factory Inspects device types of all KV cache tensors **exactly once**. CUDA → -`CudaTransferContext`; otherwise → `NonCudaTransferContext`. Mixed device types +`HandleTransferContext`; otherwise → `DataTransferContext`. Mixed device types are rejected. -### 3.4 `CudaTransferContext` +### 3.4 `HandleTransferContext` Wraps the original CUDA IPC path. Sends `REGISTER_KV_CACHE` / `STORE` / `RETRIEVE` messages with IPC handles, tracks async MQ futures. `poll_finished` queries futures; `drain_all` marks all pending as finished for unhealthy shutdown. Semantics identical to pre-refactoring. -### 3.5 `NonCudaTransferContext` +### 3.5 `DataTransferContext` Holds a `NonGpuContext` instance internally. Sends `REGISTER_KV_CACHE_NON_GPU_CONTEXT` with scalar metadata. Store and retrieve @@ -190,7 +190,7 @@ server reconstructs `MemoryLayoutDesc` from the scalars internally. ### 4.2 `NonGpuContext` ABC: two-phase prepare/commit The serialisation layer is abstracted behind `NonGpuContext` so that pickle -and SHM can be swapped without touching `NonCudaTransferContext` or the server. +and SHM can be swapped without touching `DataTransferContext` or the server. The ABC defines: `prepare_store`, `commit_store`, `prepare_retrieve`, `commit_retrieve`, `close`. diff --git a/lmcache/v1/multiprocess/transfer_context.py b/lmcache/v1/multiprocess/transfer_context.py index 225019836b1..2a598791bb6 100644 --- a/lmcache/v1/multiprocess/transfer_context.py +++ b/lmcache/v1/multiprocess/transfer_context.py @@ -145,8 +145,8 @@ def close(self) -> None: """Release resources held by this context.""" -class CudaTransferContext(TransferContext): - """CUDA IPC + MQ future transport context.""" +class HandleTransferContext(TransferContext): + """Handle-based IPC + MQ future transport context.""" def __init__(self) -> None: self._mq_client: MessageQueueClient | None = None @@ -195,7 +195,7 @@ def submit_store( ) -> MessagingFuture: if self._mq_client is None or self._send_request is None: raise RuntimeError( - "CUDA transfer context is not registered. " + "Handle transfer context is not registered. " "Call register() before submit_store()." ) return self._send_request( @@ -217,7 +217,7 @@ def submit_retrieve( ) -> MessagingFuture: if self._mq_client is None or self._send_request is None: raise RuntimeError( - "CUDA transfer context is not registered. " + "Handle transfer context is not registered. " "Call register() before submit_retrieve()." ) return self._send_request( @@ -231,8 +231,8 @@ def close(self) -> None: self._send_request = None -class NonCudaTransferContext(TransferContext): - """Non-CUDA context transport for non-CUDA workers.""" +class DataTransferContext(TransferContext): + """Data transfer context for non-CUDA workers.""" def __init__(self) -> None: self._non_gpu_context: NonGpuContext | None = None @@ -312,7 +312,7 @@ def submit_store( ) -> MessagingFuture: if self._non_gpu_context is None: raise RuntimeError( - "Non-CUDA transfer context is not registered. " + "Data transfer context is not registered. " "Call register() before submit_store()." ) @@ -345,7 +345,7 @@ def submit_retrieve( ) -> MessagingFuture: if self._non_gpu_context is None: raise RuntimeError( - "Non-CUDA transfer context is not registered. " + "Data transfer context is not registered. " "Call register() before submit_retrieve()." ) @@ -405,5 +405,5 @@ def create_transfer_context( device_type = next(iter(device_types)) logger.info("Creating transfer context (device_type=%s)", device_type) if device_type == "cuda": - return CudaTransferContext() - return NonCudaTransferContext() + return HandleTransferContext() + return DataTransferContext() From 6d8ff15bae743f43e813aac0bff7758c1fba6b90 Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Wed, 20 May 2026 05:25:40 +0000 Subject: [PATCH 21/23] update test Signed-off-by: Tony Lin --- tests/v1/multiprocess/test_non_cuda_context.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/v1/multiprocess/test_non_cuda_context.py b/tests/v1/multiprocess/test_non_cuda_context.py index f5feadeaa4c..5da7dc47aca 100644 --- a/tests/v1/multiprocess/test_non_cuda_context.py +++ b/tests/v1/multiprocess/test_non_cuda_context.py @@ -100,15 +100,15 @@ def test_wrap_kv_caches_wraps_all_tensors(monkeypatch: Any) -> None: def test_create_transfer_context_uses_non_cuda_context_on_cpu() -> None: - """Ensure transfer context factory returns NonCudaTransferContext for CPU KV.""" + """Ensure transfer context factory returns DataTransferContext for CPU KV.""" # First Party from lmcache.v1.multiprocess.transfer_context import ( - NonCudaTransferContext, + DataTransferContext, create_transfer_context, ) context = create_transfer_context({"layer_0": torch.randn(2, 2)}) - assert isinstance(context, NonCudaTransferContext) + assert isinstance(context, DataTransferContext) def test_compute_kv_layout_and_gather_scatter_roundtrip() -> None: From b07b7089796dabc00c40be16c5618ebf3cd5145d Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Wed, 20 May 2026 05:30:33 +0000 Subject: [PATCH 22/23] update docs Signed-off-by: Tony Lin --- .../v1/multiprocess/non_gpu_context_design.md | 375 ++++++++---------- 1 file changed, 171 insertions(+), 204 deletions(-) diff --git a/docs/design/v1/multiprocess/non_gpu_context_design.md b/docs/design/v1/multiprocess/non_gpu_context_design.md index dcb4a08a470..5ed541d30ed 100644 --- a/docs/design/v1/multiprocess/non_gpu_context_design.md +++ b/docs/design/v1/multiprocess/non_gpu_context_design.md @@ -1,250 +1,217 @@ -# Non-GPU Context Design (MP mode, non-CUDA) +# Non-GPU Context Design (Multiprocess Mode) ## 1. Motivation -LMCache multiprocess mode relies on **CUDA IPC** to transfer KV cache data -between vLLM worker processes and the LMCache cache server. The existing -path wraps GPU tensors in `CudaIPCWrapper`, exchanges IPC handles via ZMQ -messages, and uses CUDA events for cross-process synchronisation. +LMCache multiprocess mode originally depended on CUDA IPC: workers send IPC handles, +and the server reads/writes worker GPU memory directly. That path works well on +CUDA, but the required primitives are CUDA-specific (IPC memory handles, +interprocess CUDA events, CUDA stream semantics). -This design is fundamentally tied to the CUDA programming model: +For **CPU, XPU, HPU, and other non-CUDA devices**, those primitives do not exist. +The non-GPU context design introduces a device-agnostic path where workers move KV +data through CPU chunks instead of CUDA IPC handles. -| CUDA IPC dependency | Why it blocks non-CUDA devices | -|---|---| -| `CudaIPCWrapper` / `cudaIpcGetMemHandle` | Only works on NVIDIA CUDA tensors | -| `torch.cuda.Event(interprocess=True)` | CUDA-specific IPC event API | -| `cupy.cuda.ExternalStream` | CUDA stream wrapper | -| GPU pointer arithmetic in C++ kernels | Assumes CUDA device pointers | +Goal: keep the existing CUDA path unchanged while adding a second path that works +across non-CUDA backends. -For non-CUDA accelerators — **CPU, Intel XPU, Habana HPU**, or any future -device — none of these primitives are available. +## 2. Design -The **non-GPU context** path introduces a device-agnostic KV transfer mechanism: +### 2.1 Architecture Overview -1. Workers **gather** paged KV blocks into contiguous CPU chunk tensors. -2. CPU chunks are **transported** to the server through a pluggable - serialisation layer (pickle today, shared memory in the future). -3. On retrieve, the server returns CPU chunks and workers **scatter** them - back into device-local paged KV tensors. - -The existing CUDA IPC path is **untouched** — the two paths coexist behind a -polymorphic `TransferContext` abstraction. - -### Transport comparison - -**Store (worker → server storage):** - -| Transport | Copies | Data flow | -|---|---|---| -| CUDA IPC | 2 | GPU KV → GPU staging buffer → CPU memory obj | -| Pickle | 4 | GPU KV → CPU chunk → pickle.dumps → pickle.loads → CPU memory obj | -| SHM (TODO) | 1 | GPU KV → CPU memory obj (SHM mapped) | - -**Retrieve (server storage → worker):** - -| Transport | Copies | Data flow | -|---|---|---| -| CUDA IPC | 2 | CPU memory obj → GPU staging buffer → GPU KV | -| Pickle | 4 | CPU memory obj → pickle.dumps → pickle.loads → CPU chunk → GPU KV | -| SHM (TODO) | 1 | CPU memory obj (SHM mapped) → GPU KV | - -**Applicability:** - -| Transport | Platform requirement | Pros | Cons | -|---|---|---|---| -| CUDA IPC | NVIDIA CUDA devices only | Async GPU streams, mature path | CUDA-only | -| Pickle | Any device, no dependencies | Generally available, zero setup | 4 copies + serialisation overhead | -| SHM (TODO) | `/dev/shm` capacity ≥ L1 cache size | Fewest copies (1), no serialisation | Requires sufficient shared memory | - -## 2. Architecture Overview - -### 2.1 Layered architecture - -``` -vllm_multi_process_adapter.py ← Engine adapter, device-agnostic - └── TransferContext ← Worker-side transport abstraction (§3) - ├── HandleTransferContext ← CUDA IPC + MQ future path - └── DataTransferContext ← Synchronous gather/scatter path - └── NonGpuContext ← Serialisation abstraction (§4.2) - ├── NonGpuContextPickle ← pickle.dumps/loads (§4.3) - └── NonGpuContextShm ← shared memory (§4.4, TODO) +```text +Worker adapter (vLLM MP adapter) + └─ TransferContext + ├─ HandleTransferContext (CUDA IPC path) + └─ DataTransferContext (non-CUDA data path) + └─ NonGpuContext + ├─ NonGpuContextPickle + └─ NonGpuContextShm (TODO) ``` -Two layers of abstraction serve different purposes: - -- **TransferContext** (§3) — decides **CUDA vs non-CUDA** routing at the - worker adapter level. -- **NonGpuContext** (§4.2) — decides **how** CPU chunk data is serialised and - transported (pickle vs SHM). Only used inside `DataTransferContext`. - -### 2.2 State machine (worker ↔ server) +State machine overview (worker-side): ```text - register_kv_caches() - | - v - create_transfer_context(kv_caches) - | - +----------------+----------------+ - | | - v v - [device == cuda] [device != cuda] - | | - v v - HandleTransferContext.register() DataTransferContext.register() - → REGISTER_KV_CACHE → REGISTER_KV_CACHE_NON_GPU_CONTEXT - (CUDA IPC handles) (scalar metadata fields) - | + create_non_gpu_context() - +----------------+----------------+ - | - v - [READY / SERVING] - | - +----------------+----------------+ - | | - v v - transfer_ctx.submit_store() transfer_ctx.submit_store() - | | - v v - STORE (GPU → L1) gather_paged_kv_to_cpu() - [async MQ future] + _non_gpu_context.prepare_store() - | + _non_gpu_context.commit_store() [sync] - v _store_done[id] = ok - [READY] | - +----------------+----------------+ - | - v - transfer_ctx.submit_retrieve() + poll_finished() - | - +----------------+----------------+ - | | - v v - RETRIEVE (L1 → GPU) _non_gpu_context.prepare_retrieve() [sync] - [async MQ future] + scatter_cpu_to_paged_kv() - | + _non_gpu_context.commit_retrieve() - v _retrieve_done[id] = (ok, block_ids) - +----------------+----------------+ - | - v - [READY / SERVING] - | - v - unregister_kv_cache() - | - v - [TERMINATED] + create_transfer_context() + | + +---------------+---------------+ + | | + v v + HandleTransferContext DataTransferContext + (device == CUDA) (device != CUDA) + | | + v v + register() register() + | | + +---------------+---------------+ + | + v + READY + | + +---------------+-------------------------------+ + | | + v v + submit_store (handle path) submit_store (data path) + -> STORE request (async) -> prepare_store -> gather -> commit_store + | | + +---------------+-------------------------------+ + | + v + READY + | + +---------------+-------------------------------+ + | | + v v + submit_retrieve (handle path) submit_retrieve (data path) + -> RETRIEVE request (async) -> prepare_retrieve -> scatter -> commit_retrieve + | | + +---------------+-------------------------------+ + | + v + READY + | + v + close() ``` -## 3. Worker-side: TransferContext Abstraction +Overall data flow: +- **CUDA path**: worker sends a handle, server pulls/pushes data directly. +- **Non-CUDA path**: worker gathers/scatters paged KV and exchanges CPU-side data + via a transport-specific `NonGpuContext` implementation. -### 3.1 Problem +### 2.2 Worker Side: TransferContext -Before this refactoring, `vllm_multi_process_adapter.py` contained -non-CUDA-specific branching in every method — `register_kv_caches`, -`submit_store_request`, `submit_retrieve_request`, `get_finished`, and the -unhealthy drain path. Adding a third transport would require touching every -branch. +`TransferContext` is the worker-side transport abstraction with four methods: +`register`, `submit_store`, `submit_retrieve`, and `close`. +The contract is intentionally minimal so worker adapters only depend on these +four lifecycle and transfer operations. -### 3.2 Solution +- **HandleTransferContext** keeps the original CUDA IPC behavior: + worker sends a handle and server performs direct GPU-side transfer. +- **DataTransferContext** is the non-CUDA path: + worker transfers actual data chunks through `NonGpuContext`. -`transfer_context.py` defines the `TransferContext` ABC with six methods: -`register`, `submit_store`, `submit_retrieve`, `poll_finished`, `drain_all`, -and `close`. The adapter holds a single `TransferContext` and delegates — -no `if/else` anywhere. +`DataTransferContext` flows: +- **submit_store**: `prepare_store` → `gather_paged_kv_to_cpu` → `commit_store` +- **submit_retrieve**: `prepare_retrieve` → `scatter_cpu_to_paged_kv` → `commit_retrieve` -### 3.3 `create_transfer_context()` factory +Why `prepare → data operation → commit`: +- `prepare_*`: set up transport state (for SHM this allocates/returns shared buffers; + for pickle it is a protocol RPC that does not allocate transfer buffers). +- gather/scatter: worker-local data movement between paged KV and contiguous + CPU chunks, performed between protocol phases. +- `commit_*`: finalize and notify server to consume or release transfer state. -Inspects device types of all KV cache tensors **exactly once**. CUDA → -`HandleTransferContext`; otherwise → `DataTransferContext`. Mixed device types -are rejected. +`create_transfer_context()` selects the implementation once based on device type +(CUDA → `HandleTransferContext`, otherwise → `DataTransferContext`). +It also validates that all KV cache tensors share one device type and rejects +mixed-device configurations by raising an error. -### 3.4 `HandleTransferContext` +| Context | What is transferred | Who performs copy work | Completion style | +|---|---|---|---| +| HandleTransferContext | Device handle/reference | Server pulls/pushes via IPC | Async MQ future | +| DataTransferContext | Actual CPU chunk data | Worker gather/scatter + transport commit | Synchronous worker-side flow | -Wraps the original CUDA IPC path. Sends `REGISTER_KV_CACHE` / `STORE` / -`RETRIEVE` messages with IPC handles, tracks async MQ futures. -`poll_finished` queries futures; `drain_all` marks all pending as finished -for unhealthy shutdown. Semantics identical to pre-refactoring. +### 2.3 Server Side: GPU Context vs Non-GPU Context -### 3.5 `DataTransferContext` +- **GPU Context (existing path):** server uses CUDA IPC handles to access worker + device memory directly. +- **Non-GPU Context:** server participates in two separate two-phase protocols + exposed by `NonGpuContext`: `prepare_store/commit_store` for store, and + `prepare_retrieve/commit_retrieve` for retrieve, plus lifecycle cleanup via + `close`. -Holds a `NonGpuContext` instance internally. Sends -`REGISTER_KV_CACHE_NON_GPU_CONTEXT` with scalar metadata. Store and retrieve -are **synchronous**: gather → prepare/commit, then record result in -`_store_done` / `_retrieve_done`. `poll_finished` simply drains these dicts. +`NonGpuContext` implementations: +- **NonGpuContextPickle**: serialize/deserialize chunk payloads with pickle. +- **NonGpuContextShm**: shared-memory transport (planned/TODO). -## 4. Server-side: Non-GPU Context Protocol +This split keeps server protocol stable while allowing transport-specific behavior +behind one interface contract. -### 4.1 Why GPU context and non-GPU context need different protocols +### 2.4 Transport Comparison -| | GPU context | non-GPU context | -|---|---|---| -| Registration | `REGISTER_KV_CACHE` — IPC handles | `REGISTER_KV_CACHE_NON_GPU_CONTEXT` — scalar fields | -| Store | `STORE` — event handle + block IDs, server reads GPU directly | `STORE_CPU_CHUNKS` — serialised CPU tensors | -| Retrieve | `RETRIEVE` — event handle + block IDs, server writes GPU directly | `RETRIEVE_CPU_CHUNKS` — key lookup, returns CPU tensors | +**Store (worker → server storage):** -Registration uses **scalar fields** (`block_size`, `num_layers`, -`hidden_dim_size`, `dtype_str`, `use_mla`) instead of pickled objects -to avoid cross-process pickle security and compatibility concerns. The -server reconstructs `MemoryLayoutDesc` from the scalars internally. +| Transport | Copies | Data flow | +|---|---|---| +| Handle (CUDA IPC) | 2 | GPU KV → GPU staging buffer → CPU memory object | +| Pickle | 4 | GPU KV → CPU chunk → serialize → deserialize → CPU memory object | +| SHM (TODO) | 1 | GPU KV → CPU memory object (SHM mapped) | -### 4.2 `NonGpuContext` ABC: two-phase prepare/commit +**Retrieve (server storage → worker):** -The serialisation layer is abstracted behind `NonGpuContext` so that pickle -and SHM can be swapped without touching `DataTransferContext` or the server. +| Transport | Copies | Data flow | +|---|---|---| +| Handle (CUDA IPC) | 2 | CPU memory object → GPU staging buffer → GPU KV | +| Pickle | 4 | CPU memory object → serialize → deserialize → CPU chunk → GPU KV | +| SHM (TODO) | 1 | CPU memory object (SHM mapped) → GPU KV | -The ABC defines: `prepare_store`, `commit_store`, `prepare_retrieve`, -`commit_retrieve`, `close`. +| Transport | Pros | Cons | Best fit | +|---|---|---|---| +| Handle (CUDA IPC) | Mature path, good async overlap | CUDA-only | NVIDIA CUDA deployments | +| Pickle | Works everywhere, no SHM setup | Extra serialization + copy overhead | Universal fallback | +| SHM (TODO) | Lowest copy count, no serialization | Requires enough `/dev/shm` and synchronization | High-throughput non-CUDA setups | -Why two phases? Pickle can do everything in one step (prepare serialises, -commit sends). SHM needs prepare to allocate a slot, then the worker writes -into mapped memory, then commit tells the server "ready". The split -accommodates both without forcing unnecessary round-trips on pickle. +## 3. Protocol & Data Flow -| Phase | Pickle | SHM (TODO) | -|---|---|---| -| `prepare_store` | `pickle.dumps(chunks)` → opaque handle | MQ `PREPARE_STORE` → get SHM offset → `memcpy` into SHM | -| `commit_store` | MQ `STORE_CPU_CHUNKS`, block for ack | MQ `COMMIT_STORE` → server reads from SHM | -| `prepare_retrieve` | MQ `RETRIEVE_CPU_CHUNKS` → `pickle.loads` | MQ `PREPARE_RETRIEVE` → server writes to SHM → map tensor views | -| `commit_retrieve` | no-op | MQ `FINISH_READ` → release SHM read lock | +### 3.1 MQ Request Types Used by Non-GPU Path -`create_non_gpu_context()` factory currently always returns `NonGpuContextPickle`. -Future: probe `/dev/shm` availability and capacity, fall back to pickle if -insufficient. +The non-GPU path uses five request types: -## 5. Data Path: Gather / Scatter +1. `REGISTER_KV_CACHE_NON_GPU_CONTEXT` + Worker registers non-CUDA KV layout metadata so the server can reconstruct + the worker KV memory layout for store/retrieve operations. -### 5.1 Chunk format +2. `PREPARE_STORE` + Worker asks server/transport to prepare store-side transfer state. -- **Non-MLA**: `[2, num_layers, chunk_tokens, hidden_dim]` — dim 0 = `(K, V)`. -- **MLA**: `[num_layers, chunk_tokens, hidden_dim]` — single latent vector. +3. `COMMIT_STORE` + Worker commits store data so server can persist it into storage. -Where `chunk_tokens = blocks_per_chunk × block_size`. +4. `PREPARE_RETRIEVE` + Worker asks server to prepare retrieval payload/state for a key. -### 5.2 Supported KV layouts +5. `COMMIT_RETRIEVE` + Worker acknowledges retrieval completion so transport state can be finalized. -| Format enum | Layout | Shape per layer | -|---|---|---| -| `NL_X_TWO_NB_BS_NH_HS` | NHD | `[2, NB, BS, NH, HS]` | -| `NL_X_NB_TWO_BS_NH_HS` | NHD (flashinfer) | `[NB, 2, BS, NH, HS]` | -| `NL_X_TWO_NB_NH_BS_HS` | HND | `[2, NB, NH, BS, HS]` | -| `NL_X_NB_TWO_NH_BS_HS` | HND (flashinfer) | `[NB, 2, NH, BS, HS]` | -| `NL_X_NB_BS_HS` | MLA | `[NB, BS, HS]` | +### 3.2 Data Flow: Pickle Path -### 5.3 Block-level indexing +Store: +1. Worker `prepare_store` RPC. +2. Worker gathers paged KV into CPU chunks. +3. Worker `commit_store` sends serialized bytes. +4. Server deserializes and writes to storage. -Gather and scatter operate at **block granularity** (`tensor[block_ids]`) -rather than per-token `index_select` / `index_copy_`. For HND layouts, a -`permute(0, 2, 1, 3)` converts between head-major and token-major order. +Retrieve: +1. Worker `prepare_retrieve` RPC. +2. Server reads from storage and returns serialized bytes. +3. Worker deserializes to CPU chunks. +4. Worker scatters chunks back to paged KV. +5. Worker `commit_retrieve` finalizes protocol state. -### 5.4 Utility functions +```text +Store (pickle) +Worker: prepare_store --> Server +Worker: gather paged KV -> CPU chunks +Worker: commit_store(serialized bytes) --> Server +Server: deserialize -> storage write + +Retrieve (pickle) +Worker: prepare_retrieve --> Server +Server: read storage -> serialize bytes +Server: serialized bytes --> Worker +Worker: deserialize -> scatter to paged KV +Worker: commit_retrieve --> Server +``` -- **`compute_kv_layout`** — extracts `(block_size, num_layers, hidden_dim_size, dtype_str, gpu_kv_format)` from live KV tensors. -- **`gather_paged_kv_to_cpu`** — gathers paged blocks into CPU chunk tensors. -- **`scatter_cpu_to_paged_kv`** — scatters CPU chunks back into device paged KV tensors. Respects `skip_first_n_tokens` for partial-prefix retrieval. +### 3.3 Data Flow: SHM Path (TODO) -## Non-goals +Store: +1. Worker `prepare_store` obtains SHM slot/offset. +2. Worker gathers directly into SHM-backed buffers. +3. Worker `commit_store` notifies server to consume SHM data. -- No change to existing CUDA IPC path semantics. -- No CPU-specific logic added to shared `gpu_connector/utils.py`. -- No wire-protocol incompatibility between CUDA and non-GPU context workers in - the same cluster. +Retrieve: +1. Worker `prepare_retrieve` asks server to populate SHM. +2. Server writes retrieved chunks into SHM. +3. Worker scatters from SHM-backed buffers into paged KV. +4. Worker `commit_retrieve` releases/read-completes SHM state. From 09415fd16cb3ffe978f2990f36c52a83210ca711 Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Wed, 20 May 2026 20:03:21 +0800 Subject: [PATCH 23/23] rename test file Signed-off-by: Tony Lin --- ...non_cuda_context.py => test_non_cuda_data_transfer.py} | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) rename tests/v1/multiprocess/{test_non_cuda_context.py => test_non_cuda_data_transfer.py} (97%) diff --git a/tests/v1/multiprocess/test_non_cuda_context.py b/tests/v1/multiprocess/test_non_cuda_data_transfer.py similarity index 97% rename from tests/v1/multiprocess/test_non_cuda_context.py rename to tests/v1/multiprocess/test_non_cuda_data_transfer.py index 5da7dc47aca..f8a281a8785 100644 --- a/tests/v1/multiprocess/test_non_cuda_context.py +++ b/tests/v1/multiprocess/test_non_cuda_data_transfer.py @@ -18,7 +18,7 @@ def _make_kv_caches( num_heads: int = 2, head_size: int = 8, ) -> dict[str, torch.Tensor]: - """Build per-layer NHD KV tensors for non-CUDA context tests.""" + """Build per-layer NHD KV tensors for non-CUDA data transfer tests.""" kv_caches = {} for i in range(num_layers): kv_caches[f"layer_{i}"] = torch.randn( @@ -33,7 +33,7 @@ def _make_mla_kv_caches( block_size: int = 4, hidden_size: int = 16, ) -> dict[str, torch.Tensor]: - """Build per-layer MLA KV tensors for non-CUDA context tests. + """Build per-layer MLA KV tensors for non-CUDA data transfer tests. Args: num_layers: Number of KV layers to generate. @@ -58,7 +58,7 @@ def _make_hnd_kv_caches( num_heads: int = 2, head_size: int = 8, ) -> dict[str, torch.Tensor]: - """Build per-layer HND KV tensors for non-CUDA context tests.""" + """Build per-layer HND KV tensors for non-CUDA data transfer tests.""" kv_caches = {} for i in range(num_layers): kv_caches[f"layer_{i}"] = torch.randn( @@ -74,7 +74,7 @@ def _make_hnd_flashinfer_kv_caches( num_heads: int = 2, head_size: int = 8, ) -> dict[str, torch.Tensor]: - """Build per-layer HND flash-infer KV tensors for non-CUDA context tests.""" + """Build per-layer HND flash-infer KV tensors for non-CUDA data transfer tests.""" kv_caches = {} for i in range(num_layers): kv_caches[f"layer_{i}"] = torch.randn(