diff --git a/docs/design/v1/multiprocess/non_gpu_context_design.md b/docs/design/v1/multiprocess/non_gpu_context_design.md
new file mode 100644
index 00000000000..5ed541d30ed
--- /dev/null
+++ b/docs/design/v1/multiprocess/non_gpu_context_design.md
@@ -0,0 +1,217 @@
+# Non-GPU Context Design (Multiprocess Mode)
+
+## 1. Motivation
+
+LMCache multiprocess mode originally depended on CUDA IPC: workers send IPC handles,
+and the server reads/writes worker GPU memory directly. That path works well on
+CUDA, but the required primitives are CUDA-specific (IPC memory handles,
+interprocess CUDA events, CUDA stream semantics).
+
+For **CPU, XPU, HPU, and other non-CUDA devices**, those primitives do not exist.
+The non-GPU context design introduces a device-agnostic path where workers move KV
+data through CPU chunks instead of CUDA IPC handles.
+
+Goal: keep the existing CUDA path unchanged while adding a second path that works
+across non-CUDA backends.
+
+## 2. Design
+
+### 2.1 Architecture Overview
+
+```text
+Worker adapter (vLLM MP adapter)
+  └─ TransferContext
+      ├─ HandleTransferContext (CUDA IPC path)
+      └─ DataTransferContext (non-CUDA data path)
+          └─ NonGpuContext
+              ├─ NonGpuContextPickle
+              └─ NonGpuContextShm (TODO)
+```
+
+State machine overview (worker-side):
+
+```text
+                create_transfer_context()
+                            |
+            +---------------+---------------+
+            |                               |
+            v                               v
+  HandleTransferContext            DataTransferContext
+    (device == CUDA)                (device != CUDA)
+            |                               |
+            v                               v
+       register()                      register()
+            |                               |
+            +---------------+---------------+
+                            |
+                            v
+                          READY
+                            |
+            +---------------+-------------------------------+
+            |                                               |
+            v                                               v
+submit_store (handle path)                      submit_store (data path)
+-> STORE request (async)                -> prepare_store -> gather -> commit_store
+            |                                               |
+            +---------------+-------------------------------+
+                            |
+                            v
+                          READY
+                            |
+            +---------------+-------------------------------+
+            |                                               |
+            v                                               v
+submit_retrieve (handle path)                   submit_retrieve (data path)
+-> RETRIEVE request (async)             -> prepare_retrieve -> scatter -> commit_retrieve
+            |                                               |
+            +---------------+-------------------------------+
+                            |
+                            v
+                          READY
+                            |
+                            v
+                         close()
+```
+
+Overall data flow:
+- **CUDA path**: worker sends a handle, server pulls/pushes data directly.
+- **Non-CUDA path**: worker gathers/scatters paged KV and exchanges CPU-side data
+  via a transport-specific `NonGpuContext` implementation.
+
+### 2.2 Worker Side: TransferContext
+
+`TransferContext` is the worker-side transport abstraction with four methods:
+`register`, `submit_store`, `submit_retrieve`, and `close`.
+The contract is intentionally minimal so worker adapters only depend on these
+four lifecycle and transfer operations.
+
+- **HandleTransferContext** keeps the original CUDA IPC behavior:
+  worker sends a handle and server performs direct GPU-side transfer.
+- **DataTransferContext** is the non-CUDA path:
+  worker transfers actual data chunks through `NonGpuContext`.
+
+`DataTransferContext` flows:
+- **submit_store**: `prepare_store` → `gather_paged_kv_to_cpu` → `commit_store`
+- **submit_retrieve**: `prepare_retrieve` → `scatter_cpu_to_paged_kv` → `commit_retrieve`
+
+Why `prepare → data operation → commit`:
+- `prepare_*`: set up transport state (for SHM this allocates/returns shared buffers;
+  for pickle it is a protocol RPC that does not allocate transfer buffers).
+- gather/scatter: worker-local data movement between paged KV and contiguous
+  CPU chunks, performed between protocol phases.
+- `commit_*`: finalize and notify server to consume or release transfer state.
+
+`create_transfer_context()` selects the implementation once based on device type
+(CUDA → `HandleTransferContext`, otherwise → `DataTransferContext`).
+It also validates that all KV cache tensors share one device type and rejects
+mixed-device configurations by raising an error.
+
+| Context | What is transferred | Who performs copy work | Completion style |
+|---|---|---|---|
+| HandleTransferContext | Device handle/reference | Server pulls/pushes via IPC | Async MQ future |
+| DataTransferContext | Actual CPU chunk data | Worker gather/scatter + transport commit | Synchronous worker-side flow |
+
+### 2.3 Server Side: GPU Context vs Non-GPU Context
+
+- **GPU Context (existing path):** server uses CUDA IPC handles to access worker
+  device memory directly.
+- **Non-GPU Context:** server participates in two separate two-phase protocols
+  exposed by `NonGpuContext`: `prepare_store/commit_store` for store, and
+  `prepare_retrieve/commit_retrieve` for retrieve, plus lifecycle cleanup via
+  `close`.
+
+`NonGpuContext` implementations:
+- **NonGpuContextPickle**: serialize/deserialize chunk payloads with pickle.
+- **NonGpuContextShm**: shared-memory transport (planned/TODO).
+
+This split keeps server protocol stable while allowing transport-specific behavior
+behind one interface contract.
+
+### 2.4 Transport Comparison
+
+**Store (worker → server storage):**
+
+| Transport | Copies | Data flow |
+|---|---|---|
+| Handle (CUDA IPC) | 2 | GPU KV → GPU staging buffer → CPU memory object |
+| Pickle | 4 | GPU KV → CPU chunk → serialize → deserialize → CPU memory object |
+| SHM (TODO) | 1 | GPU KV → CPU memory object (SHM mapped) |
+
+**Retrieve (server storage → worker):**
+
+| Transport | Copies | Data flow |
+|---|---|---|
+| Handle (CUDA IPC) | 2 | CPU memory object → GPU staging buffer → GPU KV |
+| Pickle | 4 | CPU memory object → serialize → deserialize → CPU chunk → GPU KV |
+| SHM (TODO) | 1 | CPU memory object (SHM mapped) → GPU KV |
+
+| Transport | Pros | Cons | Best fit |
+|---|---|---|---|
+| Handle (CUDA IPC) | Mature path, good async overlap | CUDA-only | NVIDIA CUDA deployments |
+| Pickle | Works everywhere, no SHM setup | Extra serialization + copy overhead | Universal fallback |
+| SHM (TODO) | Lowest copy count, no serialization | Requires enough `/dev/shm` and synchronization | High-throughput non-CUDA setups |
+
+## 3. Protocol & Data Flow
+
+### 3.1 MQ Request Types Used by Non-GPU Path
+
+The non-GPU path uses five request types:
+
+1. `REGISTER_KV_CACHE_NON_GPU_CONTEXT`
+   Worker registers non-CUDA KV layout metadata so the server can reconstruct
+   the worker KV memory layout for store/retrieve operations.
+
+2. `PREPARE_STORE`
+   Worker asks server/transport to prepare store-side transfer state.
+
+3. `COMMIT_STORE`
+   Worker commits store data so server can persist it into storage.
+
+4. `PREPARE_RETRIEVE`
+   Worker asks server to prepare retrieval payload/state for a key.
+
+5. `COMMIT_RETRIEVE`
+   Worker acknowledges retrieval completion so transport state can be finalized.
+
+### 3.2 Data Flow: Pickle Path
+
+Store:
+1. Worker `prepare_store` RPC.
+2. Worker gathers paged KV into CPU chunks.
+3. Worker `commit_store` sends serialized bytes.
+4. Server deserializes and writes to storage.
+
+Retrieve:
+1. Worker `prepare_retrieve` RPC.
+2. Server reads from storage and returns serialized bytes.
+3. Worker deserializes to CPU chunks.
+4. Worker scatters chunks back to paged KV.
+5. Worker `commit_retrieve` finalizes protocol state.
+
+```text
+Store (pickle)
+Worker: prepare_store --> Server
+Worker: gather paged KV -> CPU chunks
+Worker: commit_store(serialized bytes) --> Server
+Server: deserialize -> storage write
+
+Retrieve (pickle)
+Worker: prepare_retrieve --> Server
+Server: read storage -> serialize bytes
+Server: serialized bytes --> Worker
+Worker: deserialize -> scatter to paged KV
+Worker: commit_retrieve --> Server
+```
+
+### 3.3 Data Flow: SHM Path (TODO)
+
+Store:
+1. Worker `prepare_store` obtains SHM slot/offset.
+2. Worker gathers directly into SHM-backed buffers.
+3. Worker `commit_store` notifies server to consume SHM data.
+
+Retrieve:
+1. Worker `prepare_retrieve` asks server to populate SHM.
+2. Server writes retrieved chunks into SHM.
+3. Worker scatters from SHM-backed buffers into paged KV.
+4. Worker `commit_retrieve` releases/read-completes SHM state.
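Reviewer's sketch (not part of the patch): the pickle store/retrieve ordering described in section 3.2 of the design doc can be illustrated with a toy, in-process stand-in for the server. Everything named here (`FakeServer`, `toy_store`, `toy_retrieve`) is hypothetical and exists only for illustration. Plain Python lists stand in for `torch.Tensor` blocks, and direct method calls stand in for the MQ requests. Like `gather_paged_kv_to_cpu`, the sketch only forms full chunks (`len(block_ids) // blocks_per_chunk`). It shows the `prepare_*` → data operation → `commit_*` sequence only, not the real LMCache transport.

```python
import pickle


class FakeServer:
    """Hypothetical in-process stand-in for the LMCache server (illustration only)."""

    def __init__(self):
        self.storage = {}  # key -> deserialized list of chunks

    def prepare_store(self, key):
        # Pickle mode: nothing is pre-allocated, so the reply carries no slots.
        return {}

    def commit_store(self, key, payload):
        # Deserialize the worker's bytes and "persist" them.
        self.storage[key] = pickle.loads(payload)
        return True

    def prepare_retrieve(self, key):
        # Reply with (success, serialized bytes); a miss yields (False, b"").
        if key not in self.storage:
            return False, b""
        return True, pickle.dumps(self.storage[key])

    def commit_retrieve(self, key):
        # Pickle mode: no transport state to release.
        return True


def toy_store(server, key, paged_kv, block_ids, blocks_per_chunk):
    """prepare_store -> gather blocks into chunks -> commit_store."""
    server.prepare_store(key)
    num_chunks = len(block_ids) // blocks_per_chunk
    chunks = []
    for i in range(num_chunks):
        ids = block_ids[i * blocks_per_chunk:(i + 1) * blocks_per_chunk]
        chunks.append([paged_kv[b] for b in ids])
    return server.commit_store(key, pickle.dumps(chunks))


def toy_retrieve(server, key, paged_kv, block_ids, blocks_per_chunk):
    """prepare_retrieve -> scatter chunks into blocks -> commit_retrieve."""
    ok, payload = server.prepare_retrieve(key)
    if not ok:
        return False
    for chunk_idx, chunk in enumerate(pickle.loads(payload)):
        start = chunk_idx * blocks_per_chunk
        dst_ids = block_ids[start:start + blocks_per_chunk]
        for dst, block in zip(dst_ids, chunk):
            paged_kv[dst] = block
    return server.commit_retrieve(key)
```

Storing four blocks with `blocks_per_chunk=2` and retrieving them into a different set of block IDs should reproduce the same block contents at the new IDs, while retrieving an unknown key returns `False` without touching the destination.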
diff --git a/lmcache/integration/vllm/vllm_multi_process_adapter.py b/lmcache/integration/vllm/vllm_multi_process_adapter.py
index b10ed5a4e56..5335cd259d1 100644
--- a/lmcache/integration/vllm/vllm_multi_process_adapter.py
+++ b/lmcache/integration/vllm/vllm_multi_process_adapter.py
@@ -13,7 +13,8 @@
 
 # First Party
 from lmcache.integration.request_telemetry.factory import RequestTelemetryFactory
-from lmcache.utils import EngineType, _lmcache_nvtx_annotate, init_logger
+from lmcache.integration.vllm.utils import vllm_layout_hints
+from lmcache.utils import _lmcache_nvtx_annotate, init_logger
 from lmcache.v1.multiprocess.custom_types import (
     BlockAllocationRecord,
     CudaIPCWrapper,
@@ -22,6 +23,10 @@
 )
 from lmcache.v1.multiprocess.mq import MessageQueueClient, MessagingFuture
 from lmcache.v1.multiprocess.protocol import RequestType, get_response_class
+from lmcache.v1.multiprocess.transfer_context import (
+    TransferContext,
+    create_transfer_context,
+)
 from lmcache.v1.periodic_thread import PeriodicThread, ThreadLevel, ThreadRunSummary
 
 logger = init_logger(__name__)
@@ -803,6 +808,9 @@ def __init__(
         # Registered kv caches from vLLM
         self.kv_caches: dict[str, torch.Tensor] = {}
 
+        # Transport context for transfer operations.
+        self.transfer_ctx: TransferContext | None = None
+
         # Request futures
         self.store_futures: dict[str, MessagingFuture[StoreResult]] = {}
         # request_id -> (future, block_ids)
@@ -939,27 +947,24 @@ def _send_register_kv_caches_request(
             ConnectionError: if the server does not respond within mq_timeout.
         """
-        # First Party
-        from lmcache.integration.vllm.utils import vllm_layout_hints
-
+        self.kv_caches = kv_caches
+        self.transfer_ctx = create_transfer_context(kv_caches)
         layout_hints = vllm_layout_hints()
         layout_hints["inference_engine_logical_block_size"] = (
             self.vllm_logical_block_size
         )
-        future = send_lmcache_request(
-            self.mq_client,
-            RequestType.REGISTER_KV_CACHE,
-            [
+        try:
+            self.transfer_ctx.register(
                 self.instance_id,
-                wrap_kv_caches(kv_caches),
+                kv_caches,
                 self.model_name,
                 self.world_size,
-                EngineType.VLLM,
-                layout_hints,
-            ],
-        )
-        try:
-            future.result(timeout=self._mq_timeout)
+                self.blocks_in_chunk,
+                self.mq_client,
+                self._mq_timeout,
+                send_request=send_lmcache_request,
+                layout_hints=layout_hints,
+            )
         except TimeoutError:
             raise ConnectionError(
                 "LMCache server did not respond to "
@@ -1049,11 +1054,20 @@ def submit_store_request(
             request_id=request_id,
             cache_salt=cache_salt,
         )
-        future = send_lmcache_request(
-            self.mq_client,
-            RequestType.STORE,
-            [key, self.instance_id, op.block_ids, event.ipc_handle()],
-        ).to_cuda_future()
+        if self.transfer_ctx is None:
+            raise RuntimeError(
+                "Transfer context is not initialized. "
+                "Call register_kv_caches() before submitting store requests."
+            )
+        future = self.transfer_ctx.submit_store(
+            request_id,
+            key,
+            self.instance_id,
+            self.kv_caches,
+            op.block_ids,
+            event,
+            self.blocks_in_chunk,
+        )
         self.store_futures[request_id] = future
 
     @_lmcache_nvtx_annotate
@@ -1088,17 +1102,21 @@ def submit_retrieve_request(
             request_id=request_id,
             cache_salt=cache_salt,
         )
-        future = send_lmcache_request(
-            self.mq_client,
-            RequestType.RETRIEVE,
-            [
-                key,
-                self.instance_id,
-                op.block_ids,
-                event.ipc_handle(),
-                op.skip_first_n_tokens,
-            ],
-        ).to_cuda_future()
+        if self.transfer_ctx is None:
+            raise RuntimeError(
+                "Transfer context is not initialized. "
+                "Call register_kv_caches() before submitting retrieve requests."
+            )
+        future = self.transfer_ctx.submit_retrieve(
+            request_id,
+            key,
+            self.instance_id,
+            self.kv_caches,
+            op.block_ids,
+            event,
+            self.blocks_in_chunk,
+            skip_first_n_tokens=op.skip_first_n_tokens,
+        )
         self.retrieve_futures[request_id] = (future, list(op.block_ids))
 
     @_lmcache_nvtx_annotate
@@ -1309,6 +1327,10 @@ def shutdown(self):
                 self._mq_timeout,
             )
 
+        if self.transfer_ctx is not None:
+            self.transfer_ctx.close()
+            self.transfer_ctx = None
+
         self.mq_client.close()
         self.request_telemetry.close()
 
diff --git a/lmcache/python_ops_fallback.py b/lmcache/python_ops_fallback.py
index 2c92df067ff..10fce1d4afa 100644
--- a/lmcache/python_ops_fallback.py
+++ b/lmcache/python_ops_fallback.py
@@ -17,7 +17,7 @@
 import torch
 
 # First Party
-from lmcache import torch_dev, torch_device_type
+from lmcache import torch_dev
 
 # Store the tensor objects in memory so that they can be accessed
 # outside the scope of this file
@@ -327,10 +327,7 @@ def alloc_pinned_numa_ptr(size: int, numa_id: int = 0) -> int:
     Note: NUMA node selection is not supported on non-CUDA."""
 
     # Create a 1D uint8 CPU tensor, as uint8 == 1 byte
-    # On XPU (Intel GPU), PyTorch 2.4+ supports pin_memory=True via SYCL USM
-    # host allocation, enabling fast DMA for XPU<->CPU transfers.
-    pin_memory = torch_device_type == "xpu"
-    tensor = torch.empty(size, dtype=torch.uint8, pin_memory=pin_memory)
+    tensor = torch.empty(size, dtype=torch.uint8, pin_memory=False)
 
     # First-touch initialization (forces physical allocation)
     tensor.fill_(0)
@@ -358,10 +355,7 @@ def alloc_pinned_ptr(size: int, device_id: int = 0) -> int:
     fast DMA transfers. On other non-CUDA platforms, pinning is not supported."""
 
     # Create a 1D uint8 CPU tensor, as uint8 == 1 byte
-    # On XPU (Intel GPU), PyTorch 2.4+ supports pin_memory=True via SYCL USM
-    # host allocation, enabling fast DMA for XPU<->CPU transfers.
-    pin_memory = torch_device_type == "xpu"
-    tensor = torch.empty(size, dtype=torch.uint8, pin_memory=pin_memory)
+    tensor = torch.empty(size, dtype=torch.uint8, pin_memory=False)
 
     # First-touch initialization (forces physical allocation)
     tensor.fill_(0)
diff --git a/lmcache/v1/distributed/config.py b/lmcache/v1/distributed/config.py
index 5bdf503ef4d..2690b043970 100644
--- a/lmcache/v1/distributed/config.py
+++ b/lmcache/v1/distributed/config.py
@@ -10,12 +10,16 @@
 import argparse
 
 # First Party
+from lmcache import torch_dev
+from lmcache.logging import init_logger
 from lmcache.v1.distributed.l2_adapters.config import (
     L2AdaptersConfig,
     add_l2_adapters_args,
     parse_args_to_l2_adapters_config,
 )
 
+logger = init_logger(__name__)
+
 
 @dataclass
 class L1MemoryManagerConfig:
@@ -38,6 +42,15 @@ class L1MemoryManagerConfig:
     def __post_init__(self):
         self.init_size_in_bytes = min(self.init_size_in_bytes, self.size_in_bytes)
 
+        # LazyMemoryAllocator requires cudart (CUDA host-pinned memory).
+        # Auto-disable on non-CUDA backends to avoid a RuntimeError.
+        if self.use_lazy and not hasattr(torch_dev, "cudart"):
+            logger.warning(
+                "LazyMemoryAllocator requires cudart which is not available "
+                "on the current backend. Disabling l1-use-lazy."
+            )
+            self.use_lazy = False
+
 
 @dataclass
 class L1ManagerConfig:
diff --git a/lmcache/v1/multiprocess/custom_types.py b/lmcache/v1/multiprocess/custom_types.py
index ec82b8bbd47..28cc2e3a85a 100644
--- a/lmcache/v1/multiprocess/custom_types.py
+++ b/lmcache/v1/multiprocess/custom_types.py
@@ -315,6 +315,30 @@ def no_worker_id_version(self) -> "IPCCacheEngineKey":
 KVCache = list[CudaIPCWrapper]
 
 
+class RegisterNonGpuContextPayload(msgspec.Struct):
+    """Payload for the REGISTER_KV_CACHE_NON_GPU_CONTEXT protocol message.
+
+    Attributes:
+        instance_id: Worker instance identifier (typically PID).
+        model_name: Model name associated with this worker.
+        world_size: Worker world size used in cache keys.
+        block_size: Tokens per paged block.
+        num_layers: Number of model layers.
+        hidden_dim_size: Flattened hidden dimension per token.
+        dtype_str: Torch dtype name (e.g. ``"float16"``).
+        use_mla: Whether the worker KV format is MLA.
+    """
+
+    instance_id: int
+    model_name: str
+    world_size: int
+    block_size: int
+    num_layers: int
+    hidden_dim_size: int
+    dtype_str: str
+    use_mla: bool
+
+
 @dataclass
 class CustomizedSerdeConfig:
     serializer: Callable[[Any], bytes]
diff --git a/lmcache/v1/multiprocess/non_gpu_context.py b/lmcache/v1/multiprocess/non_gpu_context.py
new file mode 100644
index 00000000000..e782c76d0c3
--- /dev/null
+++ b/lmcache/v1/multiprocess/non_gpu_context.py
@@ -0,0 +1,412 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Non-GPU context abstractions and utilities for multiprocess mode.
+
+This module provides:
+- ``NonGpuContextMetadata``: layout metadata dataclass for non-CUDA workers.
+- ``NonGpuContext``: abstract base class with a two-phase prepare/commit
+  interface for CPU-side KV data transfer. Concrete implementations (e.g.
+  ``NonGpuContextPickle``) each decide *how* data is serialised and transported.
+- ``create_non_gpu_context()``: factory that returns the appropriate
+  ``NonGpuContext`` subclass (currently always ``NonGpuContextPickle``).
+- ``compute_kv_layout``, ``gather_paged_kv_to_cpu``, ``scatter_cpu_to_paged_kv``:
+  shared gather/scatter utilities used by all concrete implementations.
+"""
+
+# Standard
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, cast
+
+# Third Party
+import torch
+
+# First Party
+from lmcache.utils import EngineType
+from lmcache.v1.distributed.api import MemoryLayoutDesc
+
+
+@dataclass
+class NonGpuContextMetadata:
+    """Non-GPU context layout metadata for non-CUDA workers.
+
+    Attributes:
+        layout_desc: Memory layout descriptor used to interpret chunk payloads.
+        block_size: Number of tokens per paged block.
+        use_mla: Whether the worker KV format is MLA.
+    """
+
+    layout_desc: MemoryLayoutDesc
+    block_size: int
+    use_mla: bool
+
+
+class NonGpuContext(ABC):
+    """Abstract base class for CPU-side KV data transfer contexts.
+
+    All concrete implementations share a common message-queue client and
+    expose a uniform two-phase ``prepare/commit`` interface so that the
+    worker adapter is implementation-agnostic.
+
+    Args:
+        metadata: Layout metadata describing the chunk format.
+        mq_client: Message-queue client used for server communication.
+        mq_timeout: Timeout in seconds for blocking MQ requests.
+    """
+
+    def __init__(
+        self,
+        metadata: NonGpuContextMetadata,
+        mq_client: Any,
+        mq_timeout: float,
+    ) -> None:
+        self.metadata = metadata
+        self.mq_client = mq_client
+        self.mq_timeout = mq_timeout
+
+    @property
+    def layout_desc(self) -> MemoryLayoutDesc:
+        """The memory layout descriptor for this context."""
+        return self.metadata.layout_desc
+
+    @abstractmethod
+    def prepare_store(self, key: Any, instance_id: int) -> list[torch.Tensor] | None:
+        """Prepare store. Returns pre-allocated out buffers (shm) or None (pickle)."""
+        ...
+
+    @abstractmethod
+    def commit_store(
+        self, key: Any, instance_id: int, chunks: list[torch.Tensor]
+    ) -> bool:
+        """Commit store. Pickle: serialize and send. Shm: notify server."""
+        ...
+
+    @abstractmethod
+    def prepare_retrieve(self, key: Any, instance_id: int) -> list[torch.Tensor] | None:
+        """Prepare retrieve. Returns chunks or shm views, or None on miss."""
+        ...
+
+    @abstractmethod
+    def commit_retrieve(self, key: Any, instance_id: int) -> bool:
+        """Commit retrieve. Pickle: no-op. Shm: release read locks."""
+        ...
+
+    @abstractmethod
+    def close(self) -> None:
+        """Release any resources held by this context."""
+        ...
+
+
+def create_non_gpu_context(
+    metadata: NonGpuContextMetadata,
+    mq_client: Any,
+    mq_timeout: float,
+) -> NonGpuContext:
+    """Factory that returns the appropriate :class:`NonGpuContext` implementation.
+
+    Currently always returns a pickle-based implementation
+    (``NonGpuContextPickle``). A future SHM-capable PR
+    may probe for shared-memory availability and fall back to pickle.
+
+    Args:
+        metadata: Layout metadata for the non-GPU context.
+        mq_client: Message-queue client for server communication.
+        mq_timeout: Timeout in seconds for blocking MQ requests.
+
+    Returns:
+        A concrete :class:`NonGpuContext` instance.
+    """
+    # Local
+    from .non_gpu_context_pickle import NonGpuContextPickle
+
+    return NonGpuContextPickle(metadata, mq_client, mq_timeout)
+
+
+# ---------------------------------------------------------------------------
+# Shared gather / scatter utilities
+# ---------------------------------------------------------------------------
+
+
+def compute_kv_layout(
+    kv_caches: dict[str, torch.Tensor],
+    layout_hints: Any | None = None,
+) -> tuple[int, int, int, str, Any]:
+    """Compute KV layout metadata from KV tensors.
+
+    Args:
+        kv_caches: Per-layer KV tensor mapping.
+        layout_hints: Optional engine layout hints.
+
+    Returns:
+        Tuple of ``(block_size, num_layers, hidden_dim_size, dtype_str,``
+        ``gpu_kv_format)``.
+
+    Raises:
+        ValueError: If ``kv_caches`` is empty.
+    """
+    # First Party
+    from lmcache.v1.gpu_connector.utils import (
+        get_block_size,
+        get_hidden_dim_size,
+        get_num_layers,
+        normalize_kv_and_discover_format,
+    )
+
+    tensors = list(kv_caches.values())
+    if not tensors:
+        raise ValueError("kv_caches is empty. Cannot compute KV layout.")
+
+    gpu_kv_format, normalized = normalize_kv_and_discover_format(
+        tensors, EngineType.VLLM, layout_hints=layout_hints
+    )
+    block_size = get_block_size(normalized, gpu_kv_format)
+    num_layers = get_num_layers(normalized, gpu_kv_format)
+    hidden_dim_size = get_hidden_dim_size(normalized, gpu_kv_format)
+    dtype_str = str(tensors[0].dtype).replace("torch.", "")
+    return block_size, num_layers, hidden_dim_size, dtype_str, gpu_kv_format
+
+
+def gather_paged_kv_to_cpu(
+    kv_caches: dict[str, torch.Tensor],
+    block_ids: list[int],
+    blocks_per_chunk: int,
+    layout_hints: Any | None = None,
+    gpu_kv_format: Any | None = None,
+    out: list[torch.Tensor] | None = None,
+) -> list[torch.Tensor]:
+    """Gather paged KV blocks into CPU chunk tensors.
+
+    Args:
+        kv_caches: Per-layer KV tensor mapping.
+        block_ids: Flattened block IDs for all chunks.
+        blocks_per_chunk: Number of paged blocks in one LMCache chunk.
+        layout_hints: Optional engine layout hints.
+        gpu_kv_format: Optional pre-detected KV format.
+
+    Returns:
+        List of CPU tensors, one per chunk. For non-MLA each chunk has shape
+        ``[2, num_layers, chunk_tokens, hidden_dim]`` where dimension ``0``
+        stores ``(K, V)``. For MLA (multi-head latent attention) each chunk
+        has shape ``[num_layers, chunk_tokens, hidden_dim]``.
+    """
+    # First Party
+    from lmcache.v1.gpu_connector.utils import (
+        get_block_size,
+        is_mla,
+        normalize_kv_and_discover_format,
+    )
+    import lmcache.c_ops as lmc_ops
+
+    tensors = list(kv_caches.values())
+    fmt, normalized = normalize_kv_and_discover_format(
+        tensors, EngineType.VLLM, layout_hints=layout_hints
+    )
+    if gpu_kv_format is None:
+        gpu_kv_format = fmt
+    use_mla = is_mla(gpu_kv_format)
+    is_hnd = gpu_kv_format in (
+        lmc_ops.GPUKVFormat.NL_X_TWO_NB_NH_BS_HS,
+        lmc_ops.GPUKVFormat.NL_X_NB_TWO_NH_BS_HS,
+    )
+
+    block_size = get_block_size(normalized, gpu_kv_format)
+    num_chunks = len(block_ids) // blocks_per_chunk
+
+    # After normalization the structure is always a list of per-layer
+    # tensors. Cast once so all downstream indexing is typed correctly.
+    layer_tensors = cast(list[torch.Tensor], normalized)
+
+    chunks: list[torch.Tensor] = [] if out is None else out
+    for chunk_idx in range(num_chunks):
+        chunk_block_ids = block_ids[
+            chunk_idx * blocks_per_chunk : (chunk_idx + 1) * blocks_per_chunk
+        ]
+        if use_mla:
+            mla_layers: list[torch.Tensor] = []
+            idx = torch.tensor(chunk_block_ids, dtype=torch.long)
+            for layer in layer_tensors:
+                layer_blocks = layer[idx]
+                mla_layers.append(
+                    layer_blocks.reshape(
+                        len(chunk_block_ids) * block_size, layer_blocks.shape[-1]
+                    )
+                )
+            chunk_tensor = torch.stack(mla_layers, dim=0)
+            if out is not None:
+                out[chunk_idx].copy_(chunk_tensor, non_blocking=True)
+            else:
+                chunks.append(chunk_tensor.cpu())
+        else:
+            k_layers: list[torch.Tensor] = []
+            v_layers: list[torch.Tensor] = []
+            for layer in layer_tensors:
+                if is_hnd:
+                    if gpu_kv_format == lmc_ops.GPUKVFormat.NL_X_TWO_NB_NH_BS_HS:
+                        k_t = layer[0]
+                        v_t = layer[1]
+                    else:
+                        k_t = layer[:, 0]
+                        v_t = layer[:, 1]
+                    _num_blocks, num_heads, _block_size, head_size = k_t.shape
+                    k_blocks = k_t[torch.tensor(chunk_block_ids, dtype=torch.long)]
+                    v_blocks = v_t[torch.tensor(chunk_block_ids, dtype=torch.long)]
+                    # HND blocks are [NB, NH, BS, HS]; convert to token-major
+                    # [NB, BS, NH, HS] before flattening to [tokens, NH*HS].
+                    k_layers.append(
+                        k_blocks.permute(0, 2, 1, 3).reshape(
+                            len(chunk_block_ids) * block_size, num_heads * head_size
+                        )
+                    )
+                    v_layers.append(
+                        v_blocks.permute(0, 2, 1, 3).reshape(
+                            len(chunk_block_ids) * block_size, num_heads * head_size
+                        )
+                    )
+                else:
+                    if gpu_kv_format == lmc_ops.GPUKVFormat.NL_X_TWO_NB_BS_NH_HS:
+                        k_t = layer[0]
+                        v_t = layer[1]
+                    else:
+                        k_t = layer[:, 0]
+                        v_t = layer[:, 1]
+                    _num_blocks, _block_size, num_heads, head_size = k_t.shape
+                    k_blocks = k_t[torch.tensor(chunk_block_ids, dtype=torch.long)]
+                    v_blocks = v_t[torch.tensor(chunk_block_ids, dtype=torch.long)]
+                    k_layers.append(
+                        k_blocks.reshape(
+                            len(chunk_block_ids) * block_size, num_heads * head_size
+                        )
+                    )
+                    v_layers.append(
+                        v_blocks.reshape(
+                            len(chunk_block_ids) * block_size, num_heads * head_size
+                        )
+                    )
+            k_stacked = torch.stack(k_layers, dim=0)
+            v_stacked = torch.stack(v_layers, dim=0)
+            chunk_tensor = torch.stack([k_stacked, v_stacked], dim=0)
+            if out is not None:
+                out[chunk_idx].copy_(chunk_tensor, non_blocking=True)
+            else:
+                chunks.append(chunk_tensor.cpu())
+    return chunks
+
+
+def scatter_cpu_to_paged_kv(
+    kv_caches: dict[str, torch.Tensor],
+    block_ids: list[int],
+    chunks: list[torch.Tensor],
+    blocks_per_chunk: int,
+    skip_first_n_tokens: int = 0,
+    layout_hints: Any | None = None,
+    gpu_kv_format: Any | None = None,
+) -> None:
+    """Scatter CPU chunk tensors back into paged KV tensors.
+
+    Args:
+        kv_caches: Per-layer KV tensor mapping to write into.
+        block_ids: Flattened destination block IDs for all chunks.
+        chunks: List of CPU chunk tensors (as returned by
+            :func:`gather_paged_kv_to_cpu`).
+        blocks_per_chunk: Number of paged blocks in one LMCache chunk.
+        skip_first_n_tokens: Token prefix to skip when scattering.
+        layout_hints: Optional engine layout hints.
+        gpu_kv_format: Optional pre-detected KV format.
+    """
+    # First Party
+    from lmcache.v1.gpu_connector.utils import (
+        get_block_size,
+        is_mla,
+        normalize_kv_and_discover_format,
+    )
+    import lmcache.c_ops as lmc_ops
+
+    if not chunks:
+        return
+
+    tensors = list(kv_caches.values())
+    fmt, normalized = normalize_kv_and_discover_format(
+        tensors, EngineType.VLLM, layout_hints=layout_hints
+    )
+    if gpu_kv_format is None:
+        gpu_kv_format = fmt
+    use_mla = is_mla(gpu_kv_format)
+
+    block_size = get_block_size(normalized, gpu_kv_format)
+    device = tensors[0].device
+    is_hnd = gpu_kv_format in (
+        lmc_ops.GPUKVFormat.NL_X_TWO_NB_NH_BS_HS,
+        lmc_ops.GPUKVFormat.NL_X_NB_TWO_NH_BS_HS,
+    )
+
+    # After normalization the structure is always a list of per-layer
+    # tensors. Cast once so all downstream indexing is typed correctly.
+    layer_tensors = cast(list[torch.Tensor], normalized)
+
+    for chunk_idx, chunk_cpu in enumerate(chunks):
+        chunk_block_ids = block_ids[
+            chunk_idx * blocks_per_chunk : (chunk_idx + 1) * blocks_per_chunk
+        ]
+        if not chunk_block_ids:
+            continue
+
+        chunk_start_token = chunk_idx * blocks_per_chunk * block_size
+        chunk_end_token = chunk_start_token + len(chunk_block_ids) * block_size
+        effective_start = max(chunk_start_token, skip_first_n_tokens)
+        if effective_start >= chunk_end_token:
+            continue
+
+        skip_blocks_in_chunk = (effective_start - chunk_start_token) // block_size
+        effective_block_ids = chunk_block_ids[skip_blocks_in_chunk:]
+        if not effective_block_ids:
+            continue
+
+        skip_tokens = skip_blocks_in_chunk * block_size
+        chunk_device = chunk_cpu.to(device)
+
+        if use_mla:
+            eff_idx = torch.tensor(effective_block_ids, dtype=torch.long)
+            for layer_idx, layer in enumerate(layer_tensors):
+                mla_src = chunk_device[layer_idx, skip_tokens:]
+                hidden_size = layer.shape[-1]
+                mla_src_3d = mla_src.reshape(
+                    len(effective_block_ids), block_size, hidden_size
+                )
+                layer[eff_idx] = mla_src_3d
+        elif is_hnd:
+            for layer_idx, layer in enumerate(layer_tensors):
+                k_src = chunk_device[0, layer_idx, skip_tokens:]
+                v_src = chunk_device[1, layer_idx, skip_tokens:]
+                if gpu_kv_format == lmc_ops.GPUKVFormat.NL_X_TWO_NB_NH_BS_HS:
+                    k_t = layer[0]
+                    v_t = layer[1]
+                else:
+                    k_t = layer[:, 0]
+                    v_t = layer[:, 1]
+                _nb, nh, _bs, hs = k_t.shape
+                k_blocks = k_src.reshape(
+                    len(effective_block_ids), block_size, nh, hs
+                ).permute(0, 2, 1, 3)
+                v_blocks = v_src.reshape(
+                    len(effective_block_ids), block_size, nh, hs
+                ).permute(0, 2, 1, 3)
+                k_t[effective_block_ids] = k_blocks
+                v_t[effective_block_ids] = v_blocks
+        else:
+            for layer_idx, layer in enumerate(layer_tensors):
+                k_src = chunk_device[0, layer_idx, skip_tokens:]
+                v_src = chunk_device[1, layer_idx, skip_tokens:]
+                if gpu_kv_format == lmc_ops.GPUKVFormat.NL_X_TWO_NB_BS_NH_HS:
+                    k_t = layer[0]
+                    v_t = layer[1]
+                else:
+                    k_t = layer[:, 0]
+                    v_t = layer[:, 1]
+                _num_blocks, _block_size, num_heads, head_size = k_t.shape
+                k_src_4d = k_src.reshape(
+                    len(effective_block_ids), block_size, num_heads, head_size
+                )
+                v_src_4d = v_src.reshape(
+                    len(effective_block_ids), block_size, num_heads, head_size
+                )
+                k_t[effective_block_ids] = k_src_4d
+                v_t[effective_block_ids] = v_src_4d
diff --git a/lmcache/v1/multiprocess/non_gpu_context_pickle.py b/lmcache/v1/multiprocess/non_gpu_context_pickle.py
new file mode 100644
index 00000000000..d310b9c65d1
--- /dev/null
+++ b/lmcache/v1/multiprocess/non_gpu_context_pickle.py
@@ -0,0 +1,105 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Pickle-based NonGpuContext implementation for multiprocess mode."""
+
+# Standard
+from typing import Any
+import pickle
+
+# Third Party
+import torch
+
+# First Party
+from lmcache.v1.multiprocess.non_gpu_context import (
+    NonGpuContext,
+    NonGpuContextMetadata,
+)
+from lmcache.v1.multiprocess.protocol import RequestType, get_response_class
+
+
+class NonGpuContextPickle(NonGpuContext):
+    """Pickle-based implementation of :class:`NonGpuContext`.
+
+    Transport mechanism:
+    - **Store**: ``prepare_store`` sends ``PREPARE_STORE`` (returns empty slots
+      for pickle mode); ``commit_store`` serialises chunks and sends
+      ``COMMIT_STORE``.
+    - **Retrieve**: ``prepare_retrieve`` sends ``PREPARE_RETRIEVE`` and
+      deserialises the returned bytes; ``commit_retrieve`` sends
+      ``COMMIT_RETRIEVE`` (no-op for pickle).
+    """
+
+    def __init__(
+        self,
+        metadata: NonGpuContextMetadata,
+        mq_client: Any,
+        mq_timeout: float,
+    ) -> None:
+        super().__init__(metadata, mq_client, mq_timeout)
+
+    def prepare_store(self, key: Any, instance_id: int) -> list[torch.Tensor] | None:
+        """Send PREPARE_STORE RPC. For pickle, returns no pre-allocated buffers."""
+        future = self.mq_client.submit_request(
+            RequestType.PREPARE_STORE,
+            [key, instance_id],
+            get_response_class(RequestType.PREPARE_STORE),
+        )
+        try:
+            future.result(timeout=self.mq_timeout)
+        except TimeoutError:
+            pass
+        return None
+
+    def commit_store(
+        self, key: Any, instance_id: int, chunks: list[torch.Tensor]
+    ) -> bool:
+        """Serialize chunks and send via COMMIT_STORE.
+
+        Returns:
+            ``True`` on success, ``False`` on failure or timeout.
+        """
+        serialised = pickle.dumps(chunks)
+        future = self.mq_client.submit_request(
+            RequestType.COMMIT_STORE,
+            [key, instance_id, serialised],
+            get_response_class(RequestType.COMMIT_STORE),
+        )
+        try:
+            return bool(future.result(timeout=self.mq_timeout))
+        except TimeoutError:
+            return False
+
+    def prepare_retrieve(self, key: Any, instance_id: int) -> list[torch.Tensor] | None:
+        """Send PREPARE_RETRIEVE and deserialize the response data.
+
+        Returns:
+            Chunks on hit, or None on miss/timeout.
+        """
+        future = self.mq_client.submit_request(
+            RequestType.PREPARE_RETRIEVE,
+            [key, instance_id],
+            get_response_class(RequestType.PREPARE_RETRIEVE),
+        )
+        try:
+            response = future.result(timeout=self.mq_timeout)
+        except TimeoutError:
+            return None
+        if not response.success or not response.data:
+            return None
+        chunks: list[torch.Tensor] = pickle.loads(response.data)
+        return chunks
+
+    def commit_retrieve(self, key: Any, instance_id: int) -> bool:
+        """Send COMMIT_RETRIEVE (no-op for pickle path)."""
+        future = self.mq_client.submit_request(
+            RequestType.COMMIT_RETRIEVE,
+            [key, instance_id],
+            get_response_class(RequestType.COMMIT_RETRIEVE),
+        )
+        try:
+            future.result(timeout=self.mq_timeout)
+        except TimeoutError:
+            pass
+        return True
+
+    def close(self) -> None:
+        """No-op: the pickle path holds no persistent resources."""
diff --git a/lmcache/v1/multiprocess/protocols/base.py b/lmcache/v1/multiprocess/protocols/base.py
index 383a41ff8c3..777ce029f29 100644
--- a/lmcache/v1/multiprocess/protocols/base.py
+++ b/lmcache/v1/multiprocess/protocols/base.py
@@ -48,6 +48,11 @@ class RequestType(enum.Enum):
     QUERY_PREFETCH_LOOKUP_HITS = enum.auto()
     FREE_LOOKUP_LOCKS = enum.auto()
     END_SESSION = enum.auto()
+    REGISTER_KV_CACHE_NON_GPU_CONTEXT = enum.auto()
+    PREPARE_STORE = enum.auto()
+    COMMIT_STORE = enum.auto()
+    PREPARE_RETRIEVE = enum.auto()
+    COMMIT_RETRIEVE = enum.auto()
 
     # Controller operations
     CLEAR = enum.auto()
diff --git a/lmcache/v1/multiprocess/protocols/engine.py b/lmcache/v1/multiprocess/protocols/engine.py
index e9f37fd422f..62ec4926cd8 100644
--- a/lmcache/v1/multiprocess/protocols/engine.py
+++ b/lmcache/v1/multiprocess/protocols/engine.py
@@ -12,15 +12,40 @@
 - END_SESSION: End a session and clean up associated resources
 """
 
+# Standard
+from dataclasses import dataclass, field
+
 # First Party
 from lmcache.utils import EngineType
 from lmcache.v1.gpu_connector.utils import LayoutHints
 from lmcache.v1.multiprocess.custom_types import (
IPCCacheEngineKey, KVCache, + RegisterNonGpuContextPayload, ) from lmcache.v1.multiprocess.protocols.base import HandlerType, ProtocolDefinition + +@dataclass +class PrepareStoreResponse: + """Response for PREPARE_STORE.""" + + context: dict = field( + default_factory=dict + ) # pickle: {}, shm will put slot info here + + +@dataclass +class PrepareRetrieveResponse: + """Response for PREPARE_RETRIEVE.""" + + success: bool + data: bytes = b"" + context: dict = field( + default_factory=dict + ) # pickle: {}, shm will put slot info here + + # Define request names for this protocol group REQUEST_NAMES = [ "REGISTER_KV_CACHE", @@ -32,6 +57,11 @@ "QUERY_PREFETCH_LOOKUP_HITS", "FREE_LOOKUP_LOCKS", "END_SESSION", + "REGISTER_KV_CACHE_NON_GPU_CONTEXT", + "PREPARE_STORE", + "COMMIT_STORE", + "PREPARE_RETRIEVE", + "COMMIT_RETRIEVE", ] # Type alias for cache keys @@ -146,4 +176,33 @@ def get_protocol_definitions() -> dict[str, ProtocolDefinition]: response_class=None, handler_type=HandlerType.BLOCKING, ), + # Register non-GPU KV cache context + # Payload: + # - RegisterNonGpuContextPayload - all metadata fields in one struct + # Returns: None + "REGISTER_KV_CACHE_NON_GPU_CONTEXT": ProtocolDefinition( + payload_classes=[RegisterNonGpuContextPayload], + response_class=None, + handler_type=HandlerType.SYNC, + ), + "PREPARE_STORE": ProtocolDefinition( + payload_classes=[KeyType, int], + response_class=PrepareStoreResponse, + handler_type=HandlerType.BLOCKING, + ), + "COMMIT_STORE": ProtocolDefinition( + payload_classes=[KeyType, int, bytes], + response_class=bool, + handler_type=HandlerType.BLOCKING, + ), + "PREPARE_RETRIEVE": ProtocolDefinition( + payload_classes=[KeyType, int], + response_class=PrepareRetrieveResponse, + handler_type=HandlerType.BLOCKING, + ), + "COMMIT_RETRIEVE": ProtocolDefinition( + payload_classes=[KeyType, int], + response_class=bool, + handler_type=HandlerType.BLOCKING, + ), } diff --git a/lmcache/v1/multiprocess/server.py 
b/lmcache/v1/multiprocess/server.py index 1ff8f6ca173..bb67d07bf54 100644 --- a/lmcache/v1/multiprocess/server.py +++ b/lmcache/v1/multiprocess/server.py @@ -5,10 +5,12 @@ from itertools import islice from typing import Generator import argparse +import pickle import threading import time # Third Party +import torch import zmq # First Party @@ -55,16 +57,22 @@ BlockAllocationRecord, IPCCacheEngineKey, KVCache, + RegisterNonGpuContextPayload, ) from lmcache.v1.multiprocess.gpu_context import ( GPUCacheContext, ) from lmcache.v1.multiprocess.mq import MessageQueueServer +from lmcache.v1.multiprocess.non_gpu_context import NonGpuContextMetadata from lmcache.v1.multiprocess.protocol import ( RequestType, get_handler_type, get_payload_classes, ) +from lmcache.v1.multiprocess.protocols.engine import ( + PrepareRetrieveResponse, + PrepareStoreResponse, +) from lmcache.v1.multiprocess.session import SessionManager from lmcache.v1.multiprocess.token_hasher import TokenHasher import lmcache.c_ops as lmc_ops @@ -173,6 +181,45 @@ class _PrefetchJob: cache_salt: str = "" +@dataclass +class RegisteredContext: + """Registered context metadata for a single worker instance. + + At least one of ``gpu_context`` or ``non_cuda_metadata`` is expected to be + populated for valid registrations. + """ + + model_name: str + world_size: int + gpu_context: GPUCacheContext | None = None + non_cuda_metadata: NonGpuContextMetadata | None = None + + @property + def is_gpu(self) -> bool: + """Return whether this registration uses a GPU transfer context.""" + return self.gpu_context is not None + + def get_layout_desc(self, chunk_size: int) -> MemoryLayoutDesc: + """Return the layout descriptor for this registration. + + Args: + chunk_size: Chunk size in tokens used for GPU layout derivation. + + Returns: + The resolved memory layout descriptor. + + Raises: + ValueError: If no GPU context or non-CUDA metadata is configured. 
+ """ + if self.gpu_context is not None: + return get_layout_desc(self.gpu_context, chunk_size) + if self.non_cuda_metadata is None: + raise ValueError( + "Invalid RegisteredContext: no GPU or non-CUDA metadata configured" + ) + return self.non_cuda_metadata.layout_desc + + # Main class for the mp cache engine class MPCacheEngine: def __init__( @@ -181,14 +228,8 @@ def __init__( chunk_size: int = 256, hash_algorithm: str = "blake3", ): - # GPU ID -> KV cache tensors - self.gpu_contexts: dict[int, GPUCacheContext] = {} - - # GPU ID -> (model name, world size) as metadata - # NOTE: This is mainly for determining the layout desc during prefetch - # We assume that if the (model name, world size) is the same, then - # the layout desc returned by the gpu context is the same. - self.gpu_context_meta: dict[int, tuple[str, int]] = {} + # Worker instance ID -> registered context metadata + self.contexts: dict[int, RegisteredContext] = {} # chunk size self.chunk_size = chunk_size @@ -216,6 +257,15 @@ def __init__( self._setup_metrics() + @property + def gpu_contexts(self) -> dict[int, GPUCacheContext]: + """Return GPU-only context mapping for backward compatibility.""" + return { + instance_id: ctx.gpu_context + for instance_id, ctx in self.contexts.items() + if ctx.gpu_context is not None + } + def register_kv_cache( self, instance_id: int, @@ -239,7 +289,7 @@ def register_kv_cache( layout_hints: See :class:`LayoutHints`. Forwarded to :class:`GPUCacheContext` for GPU KV format detection. 
""" - if instance_id in self.gpu_contexts: + if instance_id in self.contexts: logger.warning( "Instance %s's KV cache is already registered, " "skipping the new registration", @@ -253,8 +303,11 @@ def register_kv_cache( layout_hints=layout_hints or None, engine_type=engine_type, ) - self.gpu_contexts[instance_id] = gpu_context - self.gpu_context_meta[instance_id] = (model_name, world_size) + self.contexts[instance_id] = RegisteredContext( + model_name=model_name, + world_size=world_size, + gpu_context=gpu_context, + ) logger.info( "Registered KV cache for GPU ID %d with %d layers", instance_id, @@ -268,13 +321,216 @@ def unregister_kv_cache(self, instance_id: int) -> None: Args: instance_id (int): The GPU instance ID (such as PID). """ - if instance_id in self.gpu_contexts: - del self.gpu_contexts[instance_id] - del self.gpu_context_meta[instance_id] + context = self.contexts.pop(instance_id, None) + if context is None: + logger.warning( + "No registered context found for instance ID %d", instance_id + ) + return + + if context.is_gpu: logger.info("Unregistered KV cache for GPU ID %d", instance_id) torch_dev.empty_cache() else: - logger.warning("No KV cache found for GPU ID %d to unregister", instance_id) + logger.info("Unregistered non-CUDA context for instance ID %d", instance_id) + + def register_kv_cache_non_gpu_context( + self, + payload: RegisterNonGpuContextPayload, + ) -> None: + """Register non-CUDA KV layout metadata for non-GPU context mode. + + Args: + payload: Struct containing all registration fields + (instance_id, model_name, world_size, block_size, + num_layers, hidden_dim_size, dtype_str, use_mla). + + Raises: + ValueError: If ``payload.dtype_str`` is not a valid torch dtype name. 
+ """ + if payload.instance_id in self.contexts: + logger.warning( + "Instance %s's KV cache is already registered, " + "skipping the new registration", + payload.instance_id, + ) + return + + dtype = getattr(torch, payload.dtype_str, None) + if dtype is None or not isinstance(dtype, torch.dtype): + raise ValueError( + f"Invalid dtype_str '{payload.dtype_str}': must be a valid torch dtype " + "attribute name (e.g. 'float16' for torch.float16, " + "'bfloat16' for torch.bfloat16, 'float32' for torch.float32)." + ) + + shape = ( + torch.Size([payload.num_layers, self.chunk_size, payload.hidden_dim_size]) + if payload.use_mla + else torch.Size( + [2, payload.num_layers, self.chunk_size, payload.hidden_dim_size] + ) + ) + layout_desc = MemoryLayoutDesc(shapes=[shape], dtypes=[dtype]) + self.contexts[payload.instance_id] = RegisteredContext( + model_name=payload.model_name, + world_size=payload.world_size, + non_cuda_metadata=NonGpuContextMetadata( + layout_desc=layout_desc, + block_size=payload.block_size, + use_mla=payload.use_mla, + ), + ) + + def _resolve_obj_keys(self, key: IPCCacheEngineKey) -> list[ObjectKey]: + """Resolve object keys from an IPC cache key. + + Args: + key: IPC cache key describing model/session/token range. + + Returns: + Resolved object keys for the requested token range. + + Raises: + ValueError: If ``key.worker_id`` is ``None``. + """ + session = self.session_manager.get_or_create(key.request_id) + session.set_tokens(list(key.token_ids)) + chunk_hashes = [ + TokenHasher.hash_to_bytes(h) for h in session.get_hashes(key.start, key.end) + ] + if key.worker_id is None: + raise ValueError("Must resolve keys with worker_id != None") + return ipc_key_to_object_keys(key, chunk_hashes) + + @_lmcache_nvtx_annotate + def prepare_store( + self, + key: IPCCacheEngineKey, + instance_id: int, + ) -> PrepareStoreResponse: + """Prepare a store operation. For pickle mode, returns empty slots. + + Args: + key: Cache key for the token range to store. 
+ instance_id: Worker instance identifier. + + Returns: + PrepareStoreResponse with empty slots for pickle mode. + """ + + return PrepareStoreResponse(context={}) + + @_lmcache_nvtx_annotate + def commit_store( + self, + key: IPCCacheEngineKey, + instance_id: int, + cpu_data: bytes, + ) -> bool: + """Commit serialized CPU chunks to storage. + + Args: + key: Cache key for the token range to store. + instance_id: Worker instance identifier. + cpu_data: Pickled list of CPU tensors produced by the worker. + + Returns: + ``True`` when all reserved objects are written, otherwise ``False``. + """ + obj_keys = self._resolve_obj_keys(key) + + context = self.contexts.get(instance_id) + if context is None or context.non_cuda_metadata is None: + raise ValueError( + f"non-CUDA context not registered for instance ID {instance_id}" + ) + ctx = context.non_cuda_metadata + chunks: list[torch.Tensor] = pickle.loads(cpu_data) + reserved_dict = self.storage_manager.reserve_write( + obj_keys, ctx.layout_desc, "new" + ) + written_keys: list[ObjectKey] = [] + try: + for idx, obj_key in enumerate(obj_keys): + if obj_key not in reserved_dict: + continue + if idx >= len(chunks): + continue + memory_obj = reserved_dict[obj_key] + if memory_obj.tensor is None: + continue + chunk_cpu = chunks[idx] + if chunk_cpu.shape != memory_obj.tensor.shape: + continue + memory_obj.tensor.copy_(chunk_cpu) + written_keys.append(obj_key) + finally: + if written_keys: + self.storage_manager.finish_write(written_keys) + + return len(written_keys) == len(reserved_dict) + + @_lmcache_nvtx_annotate + def prepare_retrieve( + self, + key: IPCCacheEngineKey, + instance_id: int, + ) -> PrepareRetrieveResponse: + """Retrieve prefetched chunks and return serialized CPU tensors. + + Args: + key: Cache key for the token range to retrieve. + instance_id: Worker instance identifier. + + Returns: + PrepareRetrieveResponse with serialized data on hit. 
+ """ + + obj_keys = self._resolve_obj_keys(key) + + context = self.contexts.get(instance_id) + if context is None or context.non_cuda_metadata is None: + raise ValueError( + f"non-CUDA context not registered for instance ID {instance_id}" + ) + + prefetched_keys: list[ObjectKey] = [] + try: + with self.storage_manager.read_prefetched_results(obj_keys) as memory_objs: + if not memory_objs or len(memory_objs) != len(obj_keys): + return PrepareRetrieveResponse(success=False, data=b"", context={}) + prefetched_keys = obj_keys[: len(memory_objs)] + chunks = [] + for memory_obj in memory_objs: + if memory_obj.tensor is None: + return PrepareRetrieveResponse( + success=False, data=b"", context={} + ) + chunks.append(memory_obj.tensor.cpu().clone()) + return PrepareRetrieveResponse( + success=True, data=pickle.dumps(chunks), context={} + ) + finally: + if prefetched_keys: + self.storage_manager.finish_read_prefetched(prefetched_keys) + + @_lmcache_nvtx_annotate + def commit_retrieve( + self, + key: IPCCacheEngineKey, + instance_id: int, + ) -> bool: + """Finalize a retrieve operation. No-op for pickle mode. + + Args: + key: Cache key (unused for pickle). + instance_id: Worker instance identifier (unused for pickle). + + Returns: + Always ``True``. + """ + return True @_lmcache_nvtx_annotate def store( @@ -299,22 +555,18 @@ def store( that signals the completion of the store operation. The second element indicates whether the store operation was successful. 
""" - session = self.session_manager.get_or_create(key.request_id) - session.set_tokens(list(key.token_ids)) - chunk_hashes = [ - TokenHasher.hash_to_bytes(h) for h in session.get_hashes(key.start, key.end) - ] - st = time.perf_counter() + obj_keys = self._resolve_obj_keys(key) - assert key.worker_id is not None, "Must store with worker_id != None" - obj_keys = ipc_key_to_object_keys(key, chunk_hashes) - - assert instance_id in self.gpu_contexts, ( - f"KV cache not registered for GPU ID {instance_id}" + context = self.contexts.get(instance_id) + assert context is not None, ( + f"No context registered for instance ID {instance_id}" + ) + assert context.gpu_context is not None, ( + f"GPU context not registered for instance ID {instance_id}" ) - gpu_context = self.gpu_contexts[instance_id] - model_name = self.gpu_context_meta[instance_id][0] + gpu_context = context.gpu_context + model_name = context.model_name # ``blocks_per_chunk`` is counted in inference-engine-side # blocks (each block addresses @@ -490,22 +742,18 @@ def retrieve( that signals the completion of the retrieve operation. The second element indicates whether the key was successfully retrieved. 
+ """ - session = self.session_manager.get_or_create(key.request_id) - session.set_tokens(list(key.token_ids)) - chunk_hashes = [ - TokenHasher.hash_to_bytes(h) for h in session.get_hashes(key.start, key.end) - ] - st = time.perf_counter() + obj_keys = self._resolve_obj_keys(key) - assert key.worker_id is not None, "Must retrieve with worker_id != None" - obj_keys = ipc_key_to_object_keys(key, chunk_hashes) - - assert instance_id in self.gpu_contexts, ( - f"KV cache not registered for GPU ID {instance_id}" + context = self.contexts.get(instance_id) + assert context is not None, ( + f"No context registered for instance ID {instance_id}" + ) + assert context.gpu_context is not None, ( + f"GPU context not registered for instance ID {instance_id}" ) - gpu_context = self.gpu_contexts[instance_id] - model_name = self.gpu_context_meta[instance_id][0] + gpu_context = context.gpu_context + model_name = context.model_name # CPU-synchronous sentinel: a GPU retrieve is about to be enqueued. # Must be published via publish() (not publish_on_stream) so the @@ -675,18 +923,16 @@ def _find_layout_desc( model_name: str, world_size: int, ) -> MemoryLayoutDesc | None: - """Find layout desc from a matching GPU context. + """Find layout desc from a matching GPU or non-CUDA context. Returns: - The layout descriptor, or None if no context - matches (model_name, world_size). + The layout descriptor, or None if no context matches + ``(model_name, world_size)``. Contexts are scanned in registration + order, and the first match is returned.
""" - for gpu_id, (m, w) in self.gpu_context_meta.items(): - if m == model_name and w == world_size: - return get_layout_desc( - self.gpu_contexts[gpu_id], - self.chunk_size, - ) + for context in self.contexts.values(): + if context.model_name == model_name and context.world_size == world_size: + return context.get_layout_desc(self.chunk_size) return None def lookup( @@ -1002,13 +1248,18 @@ def report_status(self) -> dict: sm = self.storage_manager.report_status() gpu_context_meta: dict[str, dict] = {} - for gpu_id, meta in self.gpu_context_meta.items(): + non_cuda_context_meta: dict[str, dict] = {} + registered_gpu_ids: list[int] = [] + registered_non_cuda_ids: list[int] = [] + + for instance_id, context in self.contexts.items(): entry: dict = { - "model_name": meta[0], - "world_size": meta[1], + "model_name": context.model_name, + "world_size": context.world_size, } - ctx = self.gpu_contexts.get(gpu_id) - if ctx is not None: + if context.gpu_context is not None: + registered_gpu_ids.append(instance_id) + ctx = context.gpu_context entry["kv_cache_layout"] = { "num_layers": ctx.num_layers, "inference_engine_logical_block_size": ( @@ -1026,15 +1277,26 @@ def report_status(self) -> dict: "attention_backend": ctx.attention_backend, "cache_size_per_token": ctx.cache_size_per_token(), } - gpu_context_meta[str(gpu_id)] = entry + gpu_context_meta[str(instance_id)] = entry + continue + + if context.non_cuda_metadata is not None: + registered_non_cuda_ids.append(instance_id) + non_cuda_context_meta[str(instance_id)] = { + **entry, + "block_size": context.non_cuda_metadata.block_size, + "use_mla": context.non_cuda_metadata.use_mla, + } return { "is_healthy": sm["is_healthy"], "engine_type": self.__class__.__name__, "chunk_size": self.chunk_size, "hash_algorithm": self.token_hasher.hash_algorithm_name, - "registered_gpu_ids": list(self.gpu_contexts.keys()), + "registered_gpu_ids": registered_gpu_ids, "gpu_context_meta": gpu_context_meta, + "registered_non_cuda_instance_ids": 
registered_non_cuda_ids, + "non_cuda_context_meta": non_cuda_context_meta, "active_sessions": self.session_manager.active_count(), "active_prefetch_jobs": self._active_prefetch_count(), "storage_manager": sm, @@ -1086,7 +1348,7 @@ def close(self) -> None: logger.info("MPCacheEngine closed") # Release GPU contexts - self.gpu_contexts.clear() + self.contexts.clear() def _active_prefetch_count(self) -> int: """Return the number of active prefetch jobs (thread-safe).""" @@ -1170,6 +1432,12 @@ def run_cache_server( server, RequestType.UNREGISTER_KV_CACHE, engine.unregister_kv_cache ) add_handler_helper(server, RequestType.STORE, engine.store) + add_handler_helper( + server, + RequestType.REGISTER_KV_CACHE_NON_GPU_CONTEXT, + engine.register_kv_cache_non_gpu_context, + ) + add_handler_helper(server, RequestType.PREPARE_STORE, engine.prepare_store) add_handler_helper(server, RequestType.LOOKUP, engine.lookup) add_handler_helper( server, RequestType.QUERY_PREFETCH_STATUS, engine.query_prefetch_status @@ -1181,6 +1449,9 @@ def run_cache_server( ) add_handler_helper(server, RequestType.FREE_LOOKUP_LOCKS, engine.free_lookup_locks) add_handler_helper(server, RequestType.RETRIEVE, engine.retrieve) + add_handler_helper(server, RequestType.COMMIT_STORE, engine.commit_store) + add_handler_helper(server, RequestType.PREPARE_RETRIEVE, engine.prepare_retrieve) + add_handler_helper(server, RequestType.COMMIT_RETRIEVE, engine.commit_retrieve) add_handler_helper(server, RequestType.CLEAR, engine.clear) add_handler_helper(server, RequestType.GET_CHUNK_SIZE, engine.get_chunk_size) add_handler_helper(server, RequestType.PING, engine.ping) @@ -1194,7 +1465,14 @@ def run_cache_server( # Assign thread pools server.add_affinity_thread_pool( - [RequestType.STORE, RequestType.RETRIEVE], + [ + RequestType.STORE, + RequestType.RETRIEVE, + RequestType.PREPARE_STORE, + RequestType.COMMIT_STORE, + RequestType.PREPARE_RETRIEVE, + RequestType.COMMIT_RETRIEVE, + ], max_workers=mp_config.max_gpu_workers, 
) server.add_normal_thread_pool( diff --git a/lmcache/v1/multiprocess/transfer_context.py b/lmcache/v1/multiprocess/transfer_context.py new file mode 100644 index 00000000000..2a598791bb6 --- /dev/null +++ b/lmcache/v1/multiprocess/transfer_context.py @@ -0,0 +1,409 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Transfer context abstractions for LMCache multiprocess worker adapters.""" + +# Standard +from abc import ABC, abstractmethod +from typing import Any, Callable, Protocol + +# Third Party +import torch + +# First Party +from lmcache import torch_dev +from lmcache.utils import EngineType, init_logger +from lmcache.v1.distributed.api import MemoryLayoutDesc +from lmcache.v1.gpu_connector.utils import LayoutHints, is_mla +from lmcache.v1.multiprocess.custom_types import RegisterNonGpuContextPayload +from lmcache.v1.multiprocess.futures import MessagingFuture +from lmcache.v1.multiprocess.mq import MessageQueueClient +from lmcache.v1.multiprocess.non_gpu_context import ( + NonGpuContext, + NonGpuContextMetadata, + compute_kv_layout, + create_non_gpu_context, + gather_paged_kv_to_cpu, + scatter_cpu_to_paged_kv, +) +from lmcache.v1.multiprocess.protocol import RequestType + +logger = init_logger(__name__) + + +class IPCEvent(Protocol): + """Protocol for IPC-capable CUDA events used by transport operations.""" + + def ipc_handle(self) -> object: + """Return an IPC handle consumable by the multiprocess server.""" + + +SendRequest = Callable[[MessageQueueClient, RequestType, list[object]], MessagingFuture] + + +class TransferContext(ABC): + """Abstract transport layer for worker-side KV transfer. + + Concrete implementations encapsulate how worker-side store/retrieve + operations are transmitted to the multiprocess server. CUDA paths return + CUDA-aware futures backed by MQ requests, while CPU paths may perform + gather/scatter synchronously and return already-resolved futures. 
+ """ + + @abstractmethod + def register( + self, + instance_id: int, + kv_caches: dict[str, torch.Tensor], + model_name: str, + world_size: int, + blocks_in_chunk: int, + mq_client: MessageQueueClient, + mq_timeout: float, + send_request: SendRequest, + layout_hints: LayoutHints | None = None, + ) -> None: + """Register KV caches with the server and wait for ACK. + + Args: + instance_id: Worker process instance identifier. + kv_caches: Worker KV cache tensors keyed by layer name. + model_name: Model name used by cache keys. + world_size: KV world size. + blocks_in_chunk: Number of vLLM blocks per LMCache chunk. + mq_client: Message queue client used to communicate with server. + mq_timeout: Timeout in seconds for synchronous request wait. + send_request: Request sender callable used to issue MQ requests. + layout_hints: Optional inference-engine-provided layout hints. + + Raises: + TimeoutError: If server registration does not complete before + ``mq_timeout``. + RuntimeError: If a concrete context cannot initialize. + """ + + @abstractmethod + def submit_store( + self, + request_id: str, + key: Any, + instance_id: int, + kv_caches: dict[str, torch.Tensor], + block_ids: list[int], + event: IPCEvent, + blocks_in_chunk: int, + ) -> MessagingFuture: + """Submit a store request and return a completion future. + + Args: + request_id: External request identifier. + key: LMCache key object for the store range. + instance_id: Worker process instance identifier. + kv_caches: Worker KV cache tensors keyed by layer name. + block_ids: vLLM block IDs to store. + event: Synchronization event object. + blocks_in_chunk: Number of vLLM blocks per LMCache chunk. + + Returns: + A future compatible with adapter-side ``query()``/``result()`` flow. + + Raises: + RuntimeError: If register() was not called first. 
+ """ + + @abstractmethod + def submit_retrieve( + self, + request_id: str, + key: Any, + instance_id: int, + kv_caches: dict[str, torch.Tensor], + block_ids: list[int], + event: IPCEvent, + blocks_in_chunk: int, + skip_first_n_tokens: int = 0, + ) -> MessagingFuture: + """Submit a retrieve request and return a completion future. + + Args: + request_id: External request identifier. + key: LMCache key object for the retrieve range. + instance_id: Worker process instance identifier. + kv_caches: Worker KV cache tensors keyed by layer name. + block_ids: vLLM block IDs to retrieve into. + event: Synchronization event object. + blocks_in_chunk: Number of vLLM blocks per LMCache chunk. + skip_first_n_tokens: Number of initial tokens to skip when writing. + + Returns: + A future compatible with adapter-side ``query()``/``result()`` flow. + + Raises: + RuntimeError: If register() was not called first. + """ + + @abstractmethod + def close(self) -> None: + """Release resources held by this context.""" + + +class HandleTransferContext(TransferContext): + """Handle-based IPC + MQ future transport context.""" + + def __init__(self) -> None: + self._mq_client: MessageQueueClient | None = None + self._send_request: SendRequest | None = None + + def register( + self, + instance_id: int, + kv_caches: dict[str, torch.Tensor], + model_name: str, + world_size: int, + _blocks_in_chunk: int, + mq_client: MessageQueueClient, + mq_timeout: float, + send_request: SendRequest, + layout_hints: LayoutHints | None = None, + ) -> None: + # First Party + from lmcache.integration.vllm.vllm_multi_process_adapter import wrap_kv_caches + + self._mq_client = mq_client + self._send_request = send_request + future = send_request( + mq_client, + RequestType.REGISTER_KV_CACHE, + [ + instance_id, + wrap_kv_caches(kv_caches), + model_name, + world_size, + EngineType.VLLM, + layout_hints, + ], + ) + future.result(timeout=mq_timeout) + + def submit_store( + self, + _request_id: str, + key: Any, + 
instance_id: int, + _kv_caches: dict[str, torch.Tensor], + block_ids: list[int], + event: IPCEvent, + _blocks_in_chunk: int, + ) -> MessagingFuture: + if self._mq_client is None or self._send_request is None: + raise RuntimeError( + "Handle transfer context is not registered. " + "Call register() before submit_store()." + ) + return self._send_request( + self._mq_client, + RequestType.STORE, + [key, instance_id, block_ids, event.ipc_handle()], + ).to_cuda_future() + + def submit_retrieve( + self, + _request_id: str, + key: Any, + instance_id: int, + _kv_caches: dict[str, torch.Tensor], + block_ids: list[int], + event: IPCEvent, + _blocks_in_chunk: int, + skip_first_n_tokens: int = 0, + ) -> MessagingFuture: + if self._mq_client is None or self._send_request is None: + raise RuntimeError( + "Handle transfer context is not registered. " + "Call register() before submit_retrieve()." + ) + return self._send_request( + self._mq_client, + RequestType.RETRIEVE, + [key, instance_id, block_ids, event.ipc_handle(), skip_first_n_tokens], + ).to_cuda_future() + + def close(self) -> None: + self._mq_client = None + self._send_request = None + + +class DataTransferContext(TransferContext): + """Data transfer context for non-CUDA workers.""" + + def __init__(self) -> None: + self._non_gpu_context: NonGpuContext | None = None + self._layout_hints: LayoutHints | None = None + self._gpu_kv_format: Any = None + + def register( + self, + instance_id: int, + kv_caches: dict[str, torch.Tensor], + model_name: str, + world_size: int, + blocks_in_chunk: int, + mq_client: MessageQueueClient, + mq_timeout: float, + send_request: SendRequest, + layout_hints: LayoutHints | None = None, + ) -> None: + # TODO: inference_engine_logical_block_size is currently used by + # DeepSeek V4 on the CUDA path. The non-CUDA path is yet to be + # implemented. 
+ ( + block_size, + num_layers, + hidden_dim_size, + dtype_str, + gpu_kv_format, + ) = compute_kv_layout(kv_caches, layout_hints=layout_hints) + self._layout_hints = layout_hints + self._gpu_kv_format = gpu_kv_format + + use_mla_flag = is_mla(gpu_kv_format) + shape = ( + torch.Size([num_layers, blocks_in_chunk * block_size, hidden_dim_size]) + if use_mla_flag + else torch.Size( + [2, num_layers, blocks_in_chunk * block_size, hidden_dim_size] + ) + ) + dtype = getattr(torch, dtype_str) + layout_desc = MemoryLayoutDesc(shapes=[shape], dtypes=[dtype]) + + future = send_request( + mq_client, + RequestType.REGISTER_KV_CACHE_NON_GPU_CONTEXT, + [ + RegisterNonGpuContextPayload( + instance_id=instance_id, + model_name=model_name, + world_size=world_size, + block_size=block_size, + num_layers=num_layers, + hidden_dim_size=hidden_dim_size, + dtype_str=dtype_str, + use_mla=use_mla_flag, + ) + ], + ) + + metadata = NonGpuContextMetadata( + layout_desc=layout_desc, + block_size=block_size, + use_mla=use_mla_flag, + ) + self._non_gpu_context = create_non_gpu_context(metadata, mq_client, mq_timeout) + future.result(timeout=mq_timeout) + + def submit_store( + self, + _request_id: str, + key: Any, + instance_id: int, + kv_caches: dict[str, torch.Tensor], + block_ids: list[int], + _event: IPCEvent, + blocks_in_chunk: int, + ) -> MessagingFuture: + if self._non_gpu_context is None: + raise RuntimeError( + "Data transfer context is not registered. " + "Call register() before submit_store()." 
+ ) + + torch_dev.synchronize() + out_buffers = self._non_gpu_context.prepare_store(key, instance_id) + cpu_chunks = gather_paged_kv_to_cpu( + kv_caches, + block_ids, + blocks_in_chunk, + layout_hints=self._layout_hints, + gpu_kv_format=self._gpu_kv_format, + out=out_buffers, + ) + ok = self._non_gpu_context.commit_store(key, instance_id, cpu_chunks) + + future: MessagingFuture[bool] = MessagingFuture() + future.set_result(ok) + return future + + def submit_retrieve( + self, + _request_id: str, + key: Any, + instance_id: int, + kv_caches: dict[str, torch.Tensor], + block_ids: list[int], + _event: IPCEvent, + blocks_in_chunk: int, + skip_first_n_tokens: int = 0, + ) -> MessagingFuture: + if self._non_gpu_context is None: + raise RuntimeError( + "Data transfer context is not registered. " + "Call register() before submit_retrieve()." + ) + + src_buffers = self._non_gpu_context.prepare_retrieve(key, instance_id) + ok = src_buffers is not None + if src_buffers is not None: + try: + scatter_cpu_to_paged_kv( + kv_caches, + block_ids, + src_buffers, + blocks_in_chunk, + skip_first_n_tokens=skip_first_n_tokens, + layout_hints=self._layout_hints, + gpu_kv_format=self._gpu_kv_format, + ) + except (RuntimeError, ValueError, TypeError, IndexError): + logger.exception("Failed to scatter retrieved CPU context chunks") + ok = False + self._non_gpu_context.commit_retrieve(key, instance_id) + + future: MessagingFuture[bool] = MessagingFuture() + future.set_result(ok) + return future + + def close(self) -> None: + if self._non_gpu_context is not None: + self._non_gpu_context.close() + self._non_gpu_context = None + + +def create_transfer_context( + kv_caches: dict[str, torch.Tensor], + **_kwargs: Any, +) -> TransferContext: + """Create a transfer context from KV cache device type. + + The device check is intentionally centralized here. + + Args: + kv_caches: Worker KV cache tensors keyed by layer name. + **_kwargs: Unused placeholder for forward-compatible factory extension.
+ + Returns: + A concrete :class:`TransferContext` implementation. + + Raises: + ValueError: If ``kv_caches`` is empty or has mixed device types. + """ + if not kv_caches: + raise ValueError("kv_caches is empty") + device_types = {tensor.device.type for tensor in kv_caches.values()} + if len(device_types) != 1: + raise ValueError( + f"All KV cache tensors must share one device type, got {device_types}" + ) + device_type = next(iter(device_types)) + logger.info("Creating transfer context (device_type=%s)", device_type) + if device_type == "cuda": + return HandleTransferContext() + return DataTransferContext() diff --git a/tests/v1/multiprocess/test_non_cuda_data_transfer.py b/tests/v1/multiprocess/test_non_cuda_data_transfer.py new file mode 100644 index 00000000000..f8a281a8785 --- /dev/null +++ b/tests/v1/multiprocess/test_non_cuda_data_transfer.py @@ -0,0 +1,461 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +from contextlib import contextmanager +from typing import Any, Callable +from unittest.mock import MagicMock, patch +import pickle +import sys + +# Third Party +import pytest +import torch + + +def _make_kv_caches( + num_layers: int = 2, + num_blocks: int = 6, + block_size: int = 4, + num_heads: int = 2, + head_size: int = 8, +) -> dict[str, torch.Tensor]: + """Build per-layer NHD KV tensors for non-CUDA data transfer tests.""" + kv_caches = {} + for i in range(num_layers): + kv_caches[f"layer_{i}"] = torch.randn( + 2, num_blocks, block_size, num_heads, head_size + ) + return kv_caches + + +def _make_mla_kv_caches( + num_layers: int = 2, + num_blocks: int = 6, + block_size: int = 4, + hidden_size: int = 16, +) -> dict[str, torch.Tensor]: + """Build per-layer MLA KV tensors for non-CUDA data transfer tests. + + Args: + num_layers: Number of KV layers to generate. + num_blocks: Number of paged blocks per layer. + block_size: Number of tokens per block. + hidden_size: Hidden size per token. 
+ + Returns: + Mapping from layer name to MLA KV tensor with shape + ``[num_blocks, block_size, hidden_size]``. + """ + kv_caches = {} + for i in range(num_layers): + kv_caches[f"layer_{i}"] = torch.randn(num_blocks, block_size, hidden_size) + return kv_caches + + +def _make_hnd_kv_caches( + num_layers: int = 2, + num_blocks: int = 6, + block_size: int = 4, + num_heads: int = 2, + head_size: int = 8, +) -> dict[str, torch.Tensor]: + """Build per-layer HND KV tensors for non-CUDA data transfer tests.""" + kv_caches = {} + for i in range(num_layers): + kv_caches[f"layer_{i}"] = torch.randn( + 2, num_blocks, num_heads, block_size, head_size + ) + return kv_caches + + +def _make_hnd_flashinfer_kv_caches( + num_layers: int = 2, + num_blocks: int = 6, + block_size: int = 4, + num_heads: int = 2, + head_size: int = 8, +) -> dict[str, torch.Tensor]: + """Build per-layer HND flash-infer KV tensors for non-CUDA data transfer tests.""" + kv_caches = {} + for i in range(num_layers): + kv_caches[f"layer_{i}"] = torch.randn( + num_blocks, 2, num_heads, block_size, head_size + ) + return kv_caches + + +def test_wrap_kv_caches_wraps_all_tensors(monkeypatch: Any) -> None: + """Verify wrap_kv_caches wraps all provided KV tensors.""" + # First Party + from lmcache.integration.vllm import vllm_multi_process_adapter as adapter_mod + + kv_caches = _make_kv_caches() + monkeypatch.setattr( + adapter_mod, + "CudaIPCWrapper", + lambda tensor: ("wrapped", tensor), + ) + + wrapped = adapter_mod.wrap_kv_caches(kv_caches) + assert len(wrapped) == len(kv_caches) + + +def test_create_transfer_context_uses_non_cuda_context_on_cpu() -> None: + """Ensure transfer context factory returns DataTransferContext for CPU KV.""" + # First Party + from lmcache.v1.multiprocess.transfer_context import ( + DataTransferContext, + create_transfer_context, + ) + + context = create_transfer_context({"layer_0": torch.randn(2, 2)}) + assert isinstance(context, DataTransferContext) + + +def 
test_compute_kv_layout_and_gather_scatter_roundtrip() -> None: + """Validate layout extraction and gather/scatter round-trip on CPU tensors.""" + # First Party + from lmcache.v1.multiprocess.non_gpu_context import ( + compute_kv_layout, + gather_paged_kv_to_cpu, + scatter_cpu_to_paged_kv, + ) + + source = _make_kv_caches(num_layers=2, num_blocks=8, block_size=4) + ( + block_size, + num_layers, + hidden_dim, + dtype_str, + detected_kv_format, + ) = compute_kv_layout(source) + assert block_size == 4 + assert num_layers == 2 + assert hidden_dim == 16 + assert dtype_str == "float32" + assert detected_kv_format is not None + + blocks_per_chunk = 2 + gathered = gather_paged_kv_to_cpu(source, [0, 1], blocks_per_chunk) + destination = {name: torch.zeros_like(tensor) for name, tensor in source.items()} + scatter_cpu_to_paged_kv(destination, [4, 5], gathered, blocks_per_chunk) + + for name in source: + assert torch.allclose(source[name][:, 0], destination[name][:, 4]) + assert torch.allclose(source[name][:, 1], destination[name][:, 5]) + + +@pytest.mark.parametrize( + ("hnd_builder", "expected_format"), + [ + (_make_hnd_kv_caches, "NL_X_TWO_NB_NH_BS_HS"), + (_make_hnd_flashinfer_kv_caches, "NL_X_NB_TWO_NH_BS_HS"), + ], +) +def test_gather_scatter_roundtrip_hnd_layout( + hnd_builder: Callable[[int, int, int, int, int], dict[str, torch.Tensor]], + expected_format: str, +) -> None: + """Validate gather/scatter round-trip for HND vLLM KV layout.""" + # First Party + from lmcache.v1.multiprocess.non_gpu_context import ( + compute_kv_layout, + gather_paged_kv_to_cpu, + scatter_cpu_to_paged_kv, + ) + import lmcache.c_ops as lmc_ops + + source = hnd_builder(2, 8, 4, 2, 8) + layout_hints = {"kv_layout": "HND"} + ( + block_size, + num_layers, + hidden_dim, + dtype_str, + detected_kv_format, + ) = compute_kv_layout(source, layout_hints=layout_hints) + assert block_size == 4 + assert num_layers == 2 + assert hidden_dim == 16 + assert dtype_str == "float32" + assert detected_kv_format == 
getattr(lmc_ops.GPUKVFormat, expected_format) + + blocks_per_chunk = 2 + gathered = gather_paged_kv_to_cpu( + source, + [0, 1], + blocks_per_chunk, + layout_hints=layout_hints, + gpu_kv_format=detected_kv_format, + ) + destination = {name: torch.zeros_like(tensor) for name, tensor in source.items()} + scatter_cpu_to_paged_kv( + destination, + [4, 5], + gathered, + blocks_per_chunk, + layout_hints=layout_hints, + gpu_kv_format=detected_kv_format, + ) + + for name in source: + if detected_kv_format == lmc_ops.GPUKVFormat.NL_X_TWO_NB_NH_BS_HS: + assert torch.allclose(source[name][:, 0], destination[name][:, 4]) + assert torch.allclose(source[name][:, 1], destination[name][:, 5]) + else: + assert torch.allclose(source[name][0], destination[name][4]) + assert torch.allclose(source[name][1], destination[name][5]) + + +def test_scatter_respects_skip_first_n_tokens() -> None: + """Ensure scatter honors skip_first_n_tokens and preserves skipped blocks.""" + # First Party + from lmcache.v1.multiprocess.non_gpu_context import ( + gather_paged_kv_to_cpu, + scatter_cpu_to_paged_kv, + ) + + source = _make_kv_caches(num_layers=2, num_blocks=8, block_size=4) + destination = { + name: torch.full_like(tensor, 999.0) for name, tensor in source.items() + } + gathered = gather_paged_kv_to_cpu(source, [0, 1, 2, 3], blocks_per_chunk=4) + scatter_cpu_to_paged_kv( + destination, + [0, 1, 2, 3], + gathered, + blocks_per_chunk=4, + skip_first_n_tokens=8, + ) + + for name in destination: + assert torch.all(destination[name][:, 0] == 999.0) + assert torch.all(destination[name][:, 1] == 999.0) + assert torch.allclose(destination[name][:, 2], source[name][:, 2]) + assert torch.allclose(destination[name][:, 3], source[name][:, 3]) + + +def test_compute_kv_layout_and_gather_scatter_roundtrip_mla() -> None: + """Validate gather/scatter round-trip for MLA KV tensors.""" + # First Party + from lmcache.v1.multiprocess.non_gpu_context import ( + compute_kv_layout, + gather_paged_kv_to_cpu, + 
scatter_cpu_to_paged_kv, + ) + + source = _make_mla_kv_caches( + num_layers=2, num_blocks=8, block_size=4, hidden_size=16 + ) + ( + block_size, + num_layers, + hidden_dim, + dtype_str, + detected_kv_format, + ) = compute_kv_layout(source) + assert block_size == 4 + assert num_layers == 2 + assert hidden_dim == 16 + assert dtype_str == "float32" + assert detected_kv_format is not None + + blocks_per_chunk = 2 + gathered = gather_paged_kv_to_cpu(source, [0, 1], blocks_per_chunk) + destination = {name: torch.zeros_like(tensor) for name, tensor in source.items()} + scatter_cpu_to_paged_kv(destination, [4, 5], gathered, blocks_per_chunk) + + for name in source: + assert torch.allclose(source[name][0], destination[name][4]) + assert torch.allclose(source[name][1], destination[name][5]) + + +def test_compute_kv_layout_empty_raises_value_error() -> None: + """Ensure compute_kv_layout rejects empty KV cache input.""" + # First Party + from lmcache.v1.multiprocess.non_gpu_context import compute_kv_layout + + with pytest.raises(ValueError, match="kv_caches is empty"): + compute_kv_layout({}) + + +def test_scatter_mla_respects_skip_first_n_tokens() -> None: + """Ensure MLA scatter honors skip_first_n_tokens and preserves skipped blocks.""" + # First Party + from lmcache.v1.multiprocess.non_gpu_context import ( + gather_paged_kv_to_cpu, + scatter_cpu_to_paged_kv, + ) + + source = _make_mla_kv_caches( + num_layers=2, num_blocks=8, block_size=4, hidden_size=16 + ) + destination = { + name: torch.full_like(tensor, 999.0) for name, tensor in source.items() + } + gathered = gather_paged_kv_to_cpu(source, [0, 1, 2, 3], blocks_per_chunk=4) + scatter_cpu_to_paged_kv( + destination, + [0, 1, 2, 3], + gathered, + blocks_per_chunk=4, + skip_first_n_tokens=8, + ) + + for name in destination: + assert torch.all(destination[name][0] == 999.0) + assert torch.all(destination[name][1] == 999.0) + assert torch.allclose(destination[name][2], source[name][2]) + assert 
torch.allclose(destination[name][3], source[name][3]) + + +def test_scatter_mla_skip_past_chunk_keeps_destination_unchanged() -> None: + """Ensure MLA scatter is a no-op when skip_first_n_tokens exceeds chunk tokens.""" + # First Party + from lmcache.v1.multiprocess.non_gpu_context import ( + gather_paged_kv_to_cpu, + scatter_cpu_to_paged_kv, + ) + + source = _make_mla_kv_caches( + num_layers=2, num_blocks=8, block_size=4, hidden_size=16 + ) + destination = { + name: torch.full_like(tensor, 123.0) for name, tensor in source.items() + } + gathered = gather_paged_kv_to_cpu(source, [0, 1, 2, 3], blocks_per_chunk=4) + scatter_cpu_to_paged_kv( + destination, + [0, 1, 2, 3], + gathered, + blocks_per_chunk=4, + skip_first_n_tokens=40, + ) + + for name in destination: + assert torch.all(destination[name] == 123.0) + + +@pytest.fixture +def stub_native_storage_ops() -> Any: + """Stub native modules so server imports work in source-only test runs.""" + module = type(sys)("lmcache.native_storage_ops") + module.TTLLock = type("TTLLock", (), {}) # type: ignore[attr-defined] + module.Bitmap = type("Bitmap", (), {}) # type: ignore[attr-defined] + with patch.dict( + sys.modules, + { + "lmcache.native_storage_ops": module, + "cupy": MagicMock(), + }, + ): + yield + + +def test_server_register_and_find_non_cuda_context_layout( + stub_native_storage_ops: Any, +) -> None: + """Ensure non-CUDA registration stores metadata and lookup finds layout.""" + # First Party + from lmcache.v1.multiprocess.custom_types import RegisterNonGpuContextPayload + from lmcache.v1.multiprocess.server import MPCacheEngine + + with ( + patch("lmcache.v1.multiprocess.server.StorageManager"), + patch("lmcache.v1.multiprocess.server.TokenHasher"), + patch("lmcache.v1.multiprocess.server.SessionManager"), + patch("lmcache.v1.multiprocess.server.get_event_bus"), + ): + engine = MPCacheEngine(storage_manager_config=MagicMock(), chunk_size=16) + engine.register_kv_cache_non_gpu_context( + 
RegisterNonGpuContextPayload( + instance_id=1, + model_name="m", + world_size=1, + block_size=4, + num_layers=2, + hidden_dim_size=16, + dtype_str="float32", + use_mla=False, + ) + ) + + layout = engine._find_layout_desc("m", 1) + assert layout is not None + assert layout.shapes[0] == torch.Size([2, 2, 16, 16]) + + +def test_server_store_and_retrieve_cpu_chunks(stub_native_storage_ops: Any) -> None: + """Validate mocked server-side CPU chunk store and retrieve behavior.""" + # First Party + from lmcache.v1.multiprocess.custom_types import ( + IPCCacheEngineKey, + RegisterNonGpuContextPayload, + ) + from lmcache.v1.multiprocess.server import MPCacheEngine + + mock_storage = MagicMock() + target_tensor = torch.zeros(2, 2, 8, 16) + mock_memory_obj = MagicMock() + mock_memory_obj.tensor = target_tensor + mock_storage.reserve_write.return_value = {"obj": mock_memory_obj} + + @contextmanager + def _read_prefetched_results(_keys: Any) -> Any: + yield [mock_memory_obj] + + mock_storage.read_prefetched_results.side_effect = _read_prefetched_results + mock_session = MagicMock() + mock_session.get_hashes.return_value = [b"h"] + with ( + patch( + "lmcache.v1.multiprocess.server.StorageManager", + return_value=mock_storage, + ), + patch("lmcache.v1.multiprocess.server.TokenHasher"), + patch("lmcache.v1.multiprocess.server.SessionManager") as session_cls, + patch("lmcache.v1.multiprocess.server.get_event_bus"), + patch( + "lmcache.v1.multiprocess.server.ipc_key_to_object_keys", + return_value=["obj"], + ), + ): + session_cls.return_value.get_or_create.return_value = mock_session + engine = MPCacheEngine(storage_manager_config=MagicMock(), chunk_size=8) + + engine.register_kv_cache_non_gpu_context( + RegisterNonGpuContextPayload( + instance_id=2, + model_name="m", + world_size=1, + block_size=4, + num_layers=2, + hidden_dim_size=16, + dtype_str="float32", + use_mla=False, + ) + ) + payload = torch.ones(2, 2, 8, 16) + key = IPCCacheEngineKey.from_token_ids( + "m", + 1, + 0, + [1] 
* 8, + start=0, + end=8, + request_id="req", + ) + with patch( + "lmcache.v1.multiprocess.server.ipc_key_to_object_keys", + return_value=["obj"], + ): + store_ok = engine.commit_store(key, 2, pickle.dumps([payload])) + response = engine.prepare_retrieve(key, 2) + success = response.success + cpu_data = response.data + assert isinstance(store_ok, bool) + assert torch.allclose(mock_memory_obj.tensor, payload) + + assert success is True + recovered_chunks: list[torch.Tensor] = pickle.loads(cpu_data) + assert len(recovered_chunks) == 1 + assert torch.allclose(recovered_chunks[0], payload) diff --git a/tests/v1/test_vllm_mp_adapter.py b/tests/v1/test_vllm_mp_adapter.py index 91d0b1d120a..ef30d25327b 100644 --- a/tests/v1/test_vllm_mp_adapter.py +++ b/tests/v1/test_vllm_mp_adapter.py @@ -14,11 +14,13 @@ # Third Party import pytest +import torch # First Party from lmcache.integration.vllm import vllm_multi_process_adapter as adapter_mod from lmcache.integration.vllm.vllm_multi_process_adapter import ( LMCacheMPWorkerAdapter, + LoadStoreOp, ParallelStrategy, ) from lmcache.v1.multiprocess.protocol import RequestType @@ -81,7 +83,9 @@ def fake_adapter(monkeypatch): def test_register_kv_caches_updates_kv_caches_and_submits(fake_adapter): """Public register_kv_caches stores the dict and submits one request.""" adapter, send_mock, _ = fake_adapter - new_caches = {"layer.0": object(), "layer.1": object()} + fake_tensor = MagicMock() + fake_tensor.device.type = "cuda" + new_caches = {"layer.0": fake_tensor, "layer.1": fake_tensor} adapter.register_kv_caches(new_caches) @@ -97,4 +101,73 @@ def test_register_kv_caches_raises_connection_error_on_timeout(fake_adapter): future.result.side_effect = TimeoutError("server down") with pytest.raises(ConnectionError, match="did not respond"): - adapter.register_kv_caches({"layer.0": object()}) + fake_tensor = MagicMock() + fake_tensor.device.type = "cuda" + adapter.register_kv_caches({"layer.0": fake_tensor}) + + +def 
test_register_kv_caches_cpu_submits_non_gpu_context_registration( + fake_adapter, monkeypatch +): + """CPU KV cache registration routes to REGISTER_KV_CACHE_NON_GPU_CONTEXT.""" + adapter, send_mock, _ = fake_adapter + monkeypatch.setattr( + "lmcache.integration.vllm.utils.vllm_layout_hints", + lambda: {}, + raising=False, + ) + cpu_kv = {"layer.0": torch.randn(2, 8, 4, 2, 8)} + + adapter.register_kv_caches(cpu_kv) + + assert adapter.kv_caches is cpu_kv + assert send_mock.call_count == 1 + args, _kwargs = send_mock.call_args + assert args[1] == RequestType.REGISTER_KV_CACHE_NON_GPU_CONTEXT + assert len(args[2]) == 1 + + +def test_submit_store_request_tracks_returned_future(fake_adapter, monkeypatch): + """submit_store_request stores the returned future in store_futures.""" + adapter, _send_mock, _ = fake_adapter + monkeypatch.setattr(adapter, "_ensure_heartbeat_started", lambda: None) + fake_tensor = MagicMock() + fake_tensor.device.type = "cuda" + adapter.kv_caches = {"layer.0": fake_tensor} + transfer_ctx = MagicMock() + fake_future = MagicMock() + transfer_ctx.submit_store.return_value = fake_future + adapter.transfer_ctx = transfer_ctx + op = LoadStoreOp(token_ids=[1, 2, 3, 4], block_ids=[0], start=0, end=4) + + adapter.submit_store_request("req-1", op, event=MagicMock()) + + assert transfer_ctx.submit_store.called + assert transfer_ctx.submit_store.call_args.kwargs == {} + assert adapter.store_futures["req-1"] is fake_future + + +def test_submit_retrieve_request_tracks_returned_future(fake_adapter, monkeypatch): + """submit_retrieve_request stores returned future and block IDs.""" + adapter, _send_mock, _ = fake_adapter + monkeypatch.setattr(adapter, "_ensure_heartbeat_started", lambda: None) + fake_tensor = MagicMock() + fake_tensor.device.type = "cuda" + adapter.kv_caches = {"layer.0": fake_tensor} + transfer_ctx = MagicMock() + fake_future = MagicMock() + transfer_ctx.submit_retrieve.return_value = fake_future + adapter.transfer_ctx = transfer_ctx + op = 
LoadStoreOp( + token_ids=[1, 2, 3, 4], + block_ids=[0], + start=0, + end=4, + skip_first_n_tokens=1, + ) + + adapter.submit_retrieve_request("req-1", op, event=MagicMock()) + + assert transfer_ctx.submit_retrieve.called + assert transfer_ctx.submit_retrieve.call_args.kwargs == {"skip_first_n_tokens": 1} + assert adapter.retrieve_futures["req-1"] == (fake_future, [0])
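Reviewer note: the device-type dispatch that `create_transfer_context` performs in this diff can be exercised without torch or the LMCache package. The sketch below is only an illustration under stated assumptions: `pick_transfer_context` and `fake_tensor` are hypothetical names, and the two context classes are empty stand-ins for the real `HandleTransferContext` and `DataTransferContext`. It mirrors the validation order shown in the diff (empty input, then mixed device types, then CUDA versus everything else).

```python
from types import SimpleNamespace
from typing import Any


class HandleTransferContext:
    """Empty stand-in for the CUDA IPC path (not the real class)."""


class DataTransferContext:
    """Empty stand-in for the non-CUDA data path (not the real class)."""


def fake_tensor(device_type: str) -> Any:
    """Hypothetical helper: an object exposing only `.device.type`."""
    return SimpleNamespace(device=SimpleNamespace(type=device_type))


def pick_transfer_context(kv_caches: dict[str, Any]) -> Any:
    """Hypothetical mirror of the factory's device check from the diff."""
    if not kv_caches:
        raise ValueError("kv_caches is empty")
    # Collect the distinct device types across all layers.
    device_types = {tensor.device.type for tensor in kv_caches.values()}
    if len(device_types) != 1:
        raise ValueError(
            f"All KV cache tensors must share one device type, got {device_types}"
        )
    device_type = next(iter(device_types))
    # CUDA keeps the handle path; every other device uses the data path.
    if device_type == "cuda":
        return HandleTransferContext()
    return DataTransferContext()


if __name__ == "__main__":
    cuda_ctx = pick_transfer_context({"layer_0": fake_tensor("cuda")})
    xpu_ctx = pick_transfer_context({"layer_0": fake_tensor("xpu")})
    print(type(cuda_ctx).__name__, type(xpu_ctx).__name__)
```

The same `.device.type` stubbing appears in the adapter tests above via `MagicMock`, so a mixed-device rejection case could probably be added there in the same style.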