From cf82347acaafd922befe8dbefca6645a99408ddf Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 20 May 2026 05:43:57 +0000 Subject: [PATCH 1/9] Initial plan From 0403c5be5d32db41eaa8bb698d3d3fcb1a6822b1 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 20 May 2026 05:57:48 +0000 Subject: [PATCH 2/9] Add SHM NonGpuContext and server/storage plumbing Agent-Logs-Url: https://github.com/hlin99/LMCache/sessions/07c7d0ab-d21a-4245-9109-006f91352b6c Co-authored-by: hlin99 <73271530+hlin99@users.noreply.github.com> --- lmcache/v1/distributed/config.py | 4 + lmcache/v1/distributed/l1_manager.py | 4 + lmcache/v1/distributed/memory_manager.py | 57 +++++++ lmcache/v1/distributed/storage_manager.py | 19 +++ lmcache/v1/memory_management.py | 10 ++ lmcache/v1/multiprocess/non_gpu_context.py | 15 +- .../v1/multiprocess/non_gpu_context_shm.py | 129 ++++++++++++++ lmcache/v1/multiprocess/protocols/engine.py | 12 +- lmcache/v1/multiprocess/server.py | 132 +++++++++++++-- lmcache/v1/multiprocess/transfer_context.py | 16 +- tests/v1/distributed/test_shm_l1_pool.py | 160 ++++++++++++++++++ 11 files changed, 541 insertions(+), 17 deletions(-) create mode 100644 lmcache/v1/multiprocess/non_gpu_context_shm.py create mode 100644 tests/v1/distributed/test_shm_l1_pool.py diff --git a/lmcache/v1/distributed/config.py b/lmcache/v1/distributed/config.py index 2690b043970..bc39392a54b 100644 --- a/lmcache/v1/distributed/config.py +++ b/lmcache/v1/distributed/config.py @@ -8,6 +8,7 @@ from dataclasses import dataclass, field from typing import Literal import argparse +import os # First Party from lmcache import torch_dev @@ -39,6 +40,9 @@ class L1MemoryManagerConfig: align_bytes: int = field(default=0x1000) """ The alignment size in bytes. Default is 4KB. """ + shm_name: str = field(default_factory=lambda: f"lmcache_l1_pool_{os.getpid()}") + """ POSIX shared-memory segment name for L1 pool. Empty disables SHM. """ + def __post_init__(self): self.init_size_in_bytes = min(self.init_size_in_bytes, self.size_in_bytes) diff --git a/lmcache/v1/distributed/l1_manager.py b/lmcache/v1/distributed/l1_manager.py index e4e4379f30f..f85eaabfa41 100644 --- a/lmcache/v1/distributed/l1_manager.py +++ b/lmcache/v1/distributed/l1_manager.py @@ -803,6 +803,10 @@ def get_l1_memory_desc(self): """Return an L1MemoryDesc describing the underlying L1 memory buffer.""" return self._memory_manager.get_l1_memory_desc() + def get_shm_pool_info(self) -> dict: + """Return SHM pool metadata for non-GPU SHM transport.""" + return self._memory_manager.get_shm_pool_info() + def close(self) -> None: """Close the L1Manager and free all resources.""" with self._lock: diff --git a/lmcache/v1/distributed/memory_manager.py b/lmcache/v1/distributed/memory_manager.py index e2bfff45c4e..ea8f0244da8 100644 --- a/lmcache/v1/distributed/memory_manager.py +++ b/lmcache/v1/distributed/memory_manager.py @@ -1,5 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 +# Standard +import os +import shutil + # First Party from lmcache.logging import init_logger from lmcache.v1.distributed.api import MemoryLayoutDesc @@ -17,6 +21,31 @@ # HELPER FUNCTIONS +def _check_shm_capacity(required_bytes: int) -> bool: + """Return whether ``/dev/shm`` has enough free space.""" + if required_bytes <= 0: + return True + try: + free_bytes = shutil.disk_usage("/dev/shm").free + except OSError: + return False + return free_bytes >= required_bytes + + +def _unlink_stale_shm(shm_name: str) -> None: + """Remove a stale LMCache shm segment if it exists.""" + normalized = shm_name.lstrip("/") + if not normalized.startswith("lmcache_l1_pool_"): + return + shm_path = os.path.join("/dev/shm", normalized) + try: + os.unlink(shm_path) + except FileNotFoundError: + return + except OSError: + logger.warning("Failed to remove stale shm segment %s", shm_path, exc_info=True) + + def create_memory_allocator(config: L1MemoryManagerConfig) -> MemoryAllocatorInterface: """ Create a memory allocator based on the provided configuration. @@ -45,6 +74,23 @@ def create_memory_allocator(config: L1MemoryManagerConfig) -> MemoryAllocatorInt config.size_in_bytes, config.align_bytes, ) + shm_name = config.shm_name + if shm_name: + try: + if not _check_shm_capacity(config.size_in_bytes): + raise RuntimeError("insufficient /dev/shm capacity") + _unlink_stale_shm(shm_name) + return MixedMemoryAllocator( + config.size_in_bytes, + align_bytes=config.align_bytes, + shm_name=shm_name, + ) + except (RuntimeError, OSError, ValueError): + logger.warning( + "Failed to initialize SHM pool (%s), falling back to pickle path", + shm_name, + exc_info=True, + ) return MixedMemoryAllocator( config.size_in_bytes, align_bytes=config.align_bytes, @@ -65,6 +111,13 @@ def __init__(self, config: L1MemoryManagerConfig): self._allocator = create_memory_allocator(config) self._size_in_bytes = config.size_in_bytes self._align_bytes = config.align_bytes + self._shm_pool_info = {"shm_name": "", "pool_size": 0} + if isinstance(self._allocator, MixedMemoryAllocator): + if self._allocator.shm_name: + self._shm_pool_info = { + "shm_name": self._allocator.shm_name, + "pool_size": self._size_in_bytes, + } def allocate( self, layout_desc: MemoryLayoutDesc, count: int @@ -174,6 +227,10 @@ def close(self) -> None: """ self._allocator.close() + def get_shm_pool_info(self) -> dict: + """Return SHM pool metadata for non-GPU SHM transport.""" + return dict(self._shm_pool_info) + # Debugging APIs def memcheck(self): return self._allocator.memcheck() diff --git a/lmcache/v1/distributed/storage_manager.py b/lmcache/v1/distributed/storage_manager.py index 6b4f95bd996..82df94b4fa1 100644 --- a/lmcache/v1/distributed/storage_manager.py +++ b/lmcache/v1/distributed/storage_manager.py @@ -559,6 +559,25 @@ def touch_l1_keys(self, keys: list[ObjectKey]): """ self._l1_manager.touch_keys(keys) + def get_shm_pool_info(self) -> dict: + """Return SHM pool metadata from the L1 memory manager.""" + return self._l1_manager.get_shm_pool_info() + + def unsafe_read( + self, keys: list[ObjectKey] + ) -> tuple[list[ObjectKey], list[MemoryObj]]: + """Read already read-locked objects without acquiring new read locks.""" + read_results = self._l1_manager.unsafe_read(keys) + good_keys: list[ObjectKey] = [] + good_objs: list[MemoryObj] = [] + for key in keys: + err, obj = read_results.get(key, (L1Error.KEY_NOT_EXIST, None)) + if err != L1Error.SUCCESS or obj is None: + continue + good_keys.append(key) + good_objs.append(obj) + return good_keys, good_objs + @property def quota_manager(self) -> QuotaManager: """Per-cache_salt quota registry. diff --git a/lmcache/v1/memory_management.py b/lmcache/v1/memory_management.py index 164cb058f6f..c3a6018ed41 100644 --- a/lmcache/v1/memory_management.py +++ b/lmcache/v1/memory_management.py @@ -299,6 +299,16 @@ def get_num_tokens(self) -> int: """ raise NotImplementedError + @property + def shm_offset(self) -> int: + """Return the byte offset of this object inside the SHM pool.""" + return self.meta.address + + @property + def shm_byte_length(self) -> int: + """Return the byte length of this object inside the SHM pool.""" + return self.get_size() + @property @abc.abstractmethod def metadata(self) -> MemoryObjMetadata: diff --git a/lmcache/v1/multiprocess/non_gpu_context.py b/lmcache/v1/multiprocess/non_gpu_context.py index e782c76d0c3..bc3e36deb71 100644 --- a/lmcache/v1/multiprocess/non_gpu_context.py +++ b/lmcache/v1/multiprocess/non_gpu_context.py @@ -100,21 +100,30 @@ def create_non_gpu_context( metadata: NonGpuContextMetadata, mq_client: Any, mq_timeout: float, + shm_name: str = "", + pool_size: int = 0, ) -> NonGpuContext: """Factory that returns the appropriate :class:`NonGpuContext` implementation. - Currently always returns a pickle-based implementation - (``NonGpuContextPickle``). A future SHM-capable PR - may probe for shared-memory availability and fall back to pickle. + Returns SHM-based implementation when shared-memory pool information is + available; otherwise falls back to the pickle-based implementation. Args: metadata: Layout metadata for the non-GPU context. mq_client: Message-queue client for server communication. mq_timeout: Timeout in seconds for blocking MQ requests. + shm_name: Shared-memory segment name. Empty means pickle mode. + pool_size: Shared-memory pool size in bytes. Non-positive means pickle mode. Returns: A concrete :class:`NonGpuContext` instance. """ + if shm_name and pool_size > 0: + # Local + from .non_gpu_context_shm import NonGpuContextShm + + return NonGpuContextShm(metadata, mq_client, mq_timeout, shm_name, pool_size) + # Local from .non_gpu_context_pickle import NonGpuContextPickle diff --git a/lmcache/v1/multiprocess/non_gpu_context_shm.py b/lmcache/v1/multiprocess/non_gpu_context_shm.py new file mode 100644 index 00000000000..8f41cb8970f --- /dev/null +++ b/lmcache/v1/multiprocess/non_gpu_context_shm.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Shared-memory NonGpuContext implementation for multiprocess mode.""" + +# Standard +from typing import Any +import mmap +import os + +# Third Party +import torch + +# First Party +from lmcache.v1.multiprocess.non_gpu_context import ( + NonGpuContext, + NonGpuContextMetadata, +) +from lmcache.v1.multiprocess.protocol import RequestType, get_response_class + + +class NonGpuContextShm(NonGpuContext): + """Shared-memory implementation of :class:`NonGpuContext`.""" + + def __init__( + self, + metadata: NonGpuContextMetadata, + mq_client: Any, + mq_timeout: float, + shm_name: str, + pool_size: int, + ) -> None: + super().__init__(metadata, mq_client, mq_timeout) + if not shm_name or pool_size <= 0: + raise ValueError("shm_name must be non-empty and pool_size must be > 0") + + self._shm_name = shm_name + self._pool_size = pool_size + shm_path = os.path.join("/dev/shm", shm_name.lstrip("/")) + self._shm_fd = os.open(shm_path, os.O_RDWR) + self._mmap_obj = mmap.mmap( + self._shm_fd, self._pool_size, access=mmap.ACCESS_WRITE + ) + + def _make_tensor_view( + self, + offset: int, + length: int, + shape: list[int], + dtype_str: str, + ) -> torch.Tensor: + """Create a tensor view over a SHM slot via ``torch.frombuffer``.""" + dtype = getattr(torch, dtype_str) + itemsize = torch.empty((), dtype=dtype).element_size() + if itemsize <= 0: + raise ValueError(f"Invalid dtype size for {dtype_str}") + count = length // itemsize + tensor_1d = torch.frombuffer( + self._mmap_obj, dtype=dtype, count=count, offset=offset + ) + return tensor_1d.view(torch.Size(shape)) + + def _build_slot_tensors(self, slots: list[dict[str, Any]]) -> list[torch.Tensor]: + return [ + self._make_tensor_view( + offset=int(slot["offset"]), + length=int(slot["length"]), + shape=list(slot["shape"]), + dtype_str=str(slot["dtype"]), + ) + for slot in slots + ] + + def prepare_store(self, key: Any, instance_id: int) -> list[torch.Tensor] | None: + future = self.mq_client.submit_request( + RequestType.PREPARE_STORE, + [key, instance_id], + get_response_class(RequestType.PREPARE_STORE), + ) + try: + response = future.result(timeout=self.mq_timeout) + except TimeoutError: + return None + slots = response.context.get("slots", []) + return self._build_slot_tensors(slots) if slots else None + + def commit_store( + self, key: Any, instance_id: int, chunks: list[torch.Tensor] + ) -> bool: + del chunks + future = self.mq_client.submit_request( + RequestType.COMMIT_STORE, + [key, instance_id, b""], + get_response_class(RequestType.COMMIT_STORE), + ) + try: + return bool(future.result(timeout=self.mq_timeout)) + except TimeoutError: + return False + + def prepare_retrieve(self, key: Any, instance_id: int) -> list[torch.Tensor] | None: + future = self.mq_client.submit_request( + RequestType.PREPARE_RETRIEVE, + [key, instance_id], + get_response_class(RequestType.PREPARE_RETRIEVE), + ) + try: + response = future.result(timeout=self.mq_timeout) + except TimeoutError: + return None + if not response.success: + return None + slots = response.context.get("slots", []) + return self._build_slot_tensors(slots) if slots else None + + def commit_retrieve(self, key: Any, instance_id: int) -> bool: + future = self.mq_client.submit_request( + RequestType.COMMIT_RETRIEVE, + [key, instance_id], + get_response_class(RequestType.COMMIT_RETRIEVE), + ) + try: + return bool(future.result(timeout=self.mq_timeout)) + except TimeoutError: + return False + + def close(self) -> None: + try: + self._mmap_obj.close() + finally: + os.close(self._shm_fd) diff --git a/lmcache/v1/multiprocess/protocols/engine.py b/lmcache/v1/multiprocess/protocols/engine.py index 62ec4926cd8..bea9a1723c2 100644 --- a/lmcache/v1/multiprocess/protocols/engine.py +++ b/lmcache/v1/multiprocess/protocols/engine.py @@ -46,6 +46,14 @@ class PrepareRetrieveResponse: ) # pickle: {}, shm will put slot info here +@dataclass +class RegisterNonGpuContextResponse: + """Response for REGISTER_KV_CACHE_NON_GPU_CONTEXT.""" + + shm_name: str = "" + pool_size: int = 0 + + # Define request names for this protocol group REQUEST_NAMES = [ "REGISTER_KV_CACHE", @@ -179,10 +187,10 @@ def get_protocol_definitions() -> dict[str, ProtocolDefinition]: # Register non-GPU KV cache context # Payload: # - RegisterNonGpuContextPayload - all metadata fields in one struct - # Returns: None + # Returns: RegisterNonGpuContextResponse "REGISTER_KV_CACHE_NON_GPU_CONTEXT": ProtocolDefinition( payload_classes=[RegisterNonGpuContextPayload], - response_class=None, + response_class=RegisterNonGpuContextResponse, handler_type=HandlerType.SYNC, ), "PREPARE_STORE": ProtocolDefinition( diff --git a/lmcache/v1/multiprocess/server.py b/lmcache/v1/multiprocess/server.py index bb67d07bf54..88c271f27bc 100644 --- a/lmcache/v1/multiprocess/server.py +++ b/lmcache/v1/multiprocess/server.py @@ -72,6 +72,7 @@ from lmcache.v1.multiprocess.protocols.engine import ( PrepareRetrieveResponse, PrepareStoreResponse, + RegisterNonGpuContextResponse, ) from lmcache.v1.multiprocess.session import SessionManager from lmcache.v1.multiprocess.token_hasher import TokenHasher @@ -254,6 +255,8 @@ def __init__( # for crash resilience (e.g., client calls lookup but never queries) self._prefetch_jobs: dict[str, _PrefetchJob] = {} self._prefetch_job_lock = threading.Lock() + self._pending_shm_writes: dict[tuple[object, ...], list[ObjectKey]] = {} + self._pending_shm_reads: dict[tuple[object, ...], list[ObjectKey]] = {} self._setup_metrics() @@ -333,11 +336,17 @@ def unregister_kv_cache(self, instance_id: int) -> None: torch_dev.empty_cache() else: logger.info("Unregistered non-CUDA context for instance ID %d", instance_id) + self._pending_shm_writes = { + k: v for k, v in self._pending_shm_writes.items() if k[0] != instance_id + } + self._pending_shm_reads = { + k: v for k, v in self._pending_shm_reads.items() if k[0] != instance_id + } def register_kv_cache_non_gpu_context( self, payload: RegisterNonGpuContextPayload, - ) -> None: + ) -> RegisterNonGpuContextResponse: """Register non-CUDA KV layout metadata for non-GPU context mode. Args: @@ -354,7 +363,7 @@ def register_kv_cache_non_gpu_context( "skipping the new registration", payload.instance_id, ) - return + return RegisterNonGpuContextResponse() dtype = getattr(torch, payload.dtype_str, None) if dtype is None or not isinstance(dtype, torch.dtype): @@ -381,6 +390,38 @@ def register_kv_cache_non_gpu_context( use_mla=payload.use_mla, ), ) + shm_pool_info = self.storage_manager.get_shm_pool_info() + if not isinstance(shm_pool_info, dict): + return RegisterNonGpuContextResponse() + return RegisterNonGpuContextResponse( + shm_name=str(shm_pool_info.get("shm_name", "")), + pool_size=int(shm_pool_info.get("pool_size", 0)), + ) + + @staticmethod + def _make_non_gpu_transfer_key( + key: IPCCacheEngineKey, instance_id: int + ) -> tuple[object, ...]: + return ( + instance_id, + key.model_name, + key.world_size, + key.worker_id, + key.token_ids, + key.start, + key.end, + key.request_id, + key.cache_salt, + ) + + def _is_shm_active(self) -> bool: + shm_pool_info = self.storage_manager.get_shm_pool_info() + if not isinstance(shm_pool_info, dict): + return False + return ( + bool(shm_pool_info.get("shm_name")) + and int(shm_pool_info.get("pool_size", 0)) > 0 + ) def _resolve_obj_keys(self, key: IPCCacheEngineKey) -> list[ObjectKey]: """Resolve object keys from an IPC cache key. @@ -419,7 +460,36 @@ def prepare_store( PrepareStoreResponse with empty slots for pickle mode. """ - return PrepareStoreResponse(context={}) + if not self._is_shm_active(): + return PrepareStoreResponse(context={}) + + obj_keys = self._resolve_obj_keys(key) + context = self.contexts.get(instance_id) + if context is None or context.non_cuda_metadata is None: + raise ValueError( + f"non-CUDA context not registered for instance ID {instance_id}" + ) + reserved = self.storage_manager.reserve_write( + obj_keys, context.non_cuda_metadata.layout_desc, "new" + ) + slots: list[dict] = [] + reserved_keys: list[ObjectKey] = [] + for obj_key in obj_keys: + memory_obj = reserved.get(obj_key) + if memory_obj is None or memory_obj.tensor is None: + continue + slots.append( + { + "offset": memory_obj.shm_offset, + "length": memory_obj.shm_byte_length, + "shape": list(memory_obj.tensor.shape), + "dtype": str(memory_obj.tensor.dtype).replace("torch.", ""), + } + ) + reserved_keys.append(obj_key) + transfer_key = self._make_non_gpu_transfer_key(key, instance_id) + self._pending_shm_writes[transfer_key] = reserved_keys + return PrepareStoreResponse(context={"slots": slots}) @_lmcache_nvtx_annotate def commit_store( @@ -438,6 +508,14 @@ def commit_store( Returns: ``True`` when all reserved objects are written, otherwise ``False``. """ + if cpu_data == b"" and self._is_shm_active(): + transfer_key = self._make_non_gpu_transfer_key(key, instance_id) + reserved_keys = self._pending_shm_writes.pop(transfer_key, []) + if not reserved_keys: + return False + self.storage_manager.finish_write(reserved_keys) + return True + obj_keys = self._resolve_obj_keys(key) context = self.contexts.get(instance_id) @@ -495,14 +573,42 @@ def prepare_retrieve( f"non-CUDA context not registered for instance ID {instance_id}" ) - prefetched_keys: list[ObjectKey] = [] + if self._is_shm_active(): + shm_prefetched_keys, shm_memory_objs = self.storage_manager.unsafe_read( + obj_keys + ) + if not shm_memory_objs or len(shm_prefetched_keys) != len(obj_keys): + if shm_prefetched_keys: + self.storage_manager.finish_read_prefetched(shm_prefetched_keys) + return PrepareRetrieveResponse(success=False, data=b"", context={}) + slots: list[dict] = [] + for memory_obj in shm_memory_objs: + if memory_obj.tensor is None: + self.storage_manager.finish_read_prefetched(shm_prefetched_keys) + return PrepareRetrieveResponse(success=False, data=b"", context={}) + slots.append( + { + "offset": memory_obj.shm_offset, + "length": memory_obj.shm_byte_length, + "shape": list(memory_obj.tensor.shape), + "dtype": str(memory_obj.tensor.dtype).replace("torch.", ""), + } + ) + transfer_key = self._make_non_gpu_transfer_key(key, instance_id) + self._pending_shm_reads[transfer_key] = shm_prefetched_keys + return PrepareRetrieveResponse( + success=True, data=b"", context={"slots": slots} + ) + + prefetched_keys_pickle: list[ObjectKey] = [] try: - with self.storage_manager.read_prefetched_results(obj_keys) as memory_objs: - if not memory_objs or len(memory_objs) != len(obj_keys): + read_ctx = self.storage_manager.read_prefetched_results(obj_keys) + with read_ctx as maybe_memory_objs: + if not maybe_memory_objs or len(maybe_memory_objs) != len(obj_keys): return PrepareRetrieveResponse(success=False, data=b"", context={}) - prefetched_keys = obj_keys[: len(memory_objs)] + prefetched_keys_pickle = obj_keys[: len(maybe_memory_objs)] chunks = [] - for memory_obj in memory_objs: + for memory_obj in maybe_memory_objs: if memory_obj.tensor is None: return PrepareRetrieveResponse( success=False, data=b"", context={} @@ -512,8 +618,8 @@ def prepare_retrieve( success=True, data=pickle.dumps(chunks), context={} ) finally: - if prefetched_keys: - self.storage_manager.finish_read_prefetched(prefetched_keys) + if prefetched_keys_pickle: + self.storage_manager.finish_read_prefetched(prefetched_keys_pickle) @_lmcache_nvtx_annotate def commit_retrieve( @@ -530,6 +636,12 @@ def commit_retrieve( Returns: Always ``True``. """ + if self._is_shm_active(): + transfer_key = self._make_non_gpu_transfer_key(key, instance_id) + prefetched_keys = self._pending_shm_reads.pop(transfer_key, []) + if prefetched_keys: + self.storage_manager.finish_read_prefetched(prefetched_keys) + return True return True @_lmcache_nvtx_annotate diff --git a/lmcache/v1/multiprocess/transfer_context.py b/lmcache/v1/multiprocess/transfer_context.py index 2a598791bb6..610f867d242 100644 --- a/lmcache/v1/multiprocess/transfer_context.py +++ b/lmcache/v1/multiprocess/transfer_context.py @@ -25,6 +25,7 @@ scatter_cpu_to_paged_kv, ) from lmcache.v1.multiprocess.protocol import RequestType +from lmcache.v1.multiprocess.protocols.engine import RegisterNonGpuContextResponse logger = init_logger(__name__) @@ -291,14 +292,25 @@ def register( ) ], ) + response = future.result(timeout=mq_timeout) + shm_name = "" + pool_size = 0 + if isinstance(response, RegisterNonGpuContextResponse): + shm_name = response.shm_name + pool_size = response.pool_size metadata = NonGpuContextMetadata( layout_desc=layout_desc, block_size=block_size, use_mla=use_mla_flag, ) - self._non_gpu_context = create_non_gpu_context(metadata, mq_client, mq_timeout) - future.result(timeout=mq_timeout) + self._non_gpu_context = create_non_gpu_context( + metadata, + mq_client, + mq_timeout, + shm_name=shm_name, + pool_size=pool_size, + ) def submit_store( self, diff --git a/tests/v1/distributed/test_shm_l1_pool.py b/tests/v1/distributed/test_shm_l1_pool.py new file mode 100644 index 00000000000..51152dadf64 --- /dev/null +++ b/tests/v1/distributed/test_shm_l1_pool.py @@ -0,0 +1,160 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Standard +from unittest.mock import MagicMock +import mmap +import os + +# Third Party +import torch + +# First Party +from lmcache.v1.distributed.api import MemoryLayoutDesc +from lmcache.v1.distributed.config import L1MemoryManagerConfig +from lmcache.v1.distributed.memory_manager import create_memory_allocator +from lmcache.v1.memory_management import MixedMemoryAllocator +from lmcache.v1.multiprocess.non_gpu_context import NonGpuContextMetadata +from lmcache.v1.multiprocess.non_gpu_context_shm import NonGpuContextShm +from lmcache.v1.multiprocess.protocol import RequestType +from lmcache.v1.multiprocess.protocols.engine import ( + PrepareRetrieveResponse, + PrepareStoreResponse, +) + + +class _CompletedFuture: + def __init__(self, value): + self._value = value + + def result(self, timeout=None): # noqa: ARG002 + return self._value + + +def _create_shm_file(shm_name: str, size: int) -> str: + path = os.path.join("/dev/shm", shm_name.lstrip("/")) + fd = os.open(path, os.O_CREAT | os.O_RDWR, 0o600) + os.ftruncate(fd, size) + os.close(fd) + return path + + +def test_shm_segment_creation_and_cleanup() -> None: + shm_name = f"lmcache_l1_pool_test_{os.getpid()}" + cfg = L1MemoryManagerConfig( + size_in_bytes=1024 * 1024, + use_lazy=False, + shm_name=shm_name, + ) + allocator = create_memory_allocator(cfg) + assert isinstance(allocator, MixedMemoryAllocator) + assert allocator.shm_name == shm_name + shm_path = os.path.join("/dev/shm", shm_name) + assert os.path.exists(shm_path) + allocator.close() + assert not os.path.exists(shm_path) + + +def test_non_gpu_context_shm_tensor_view_from_buffer() -> None: + shm_name = f"lmcache_test_view_{os.getpid()}" + shm_path = _create_shm_file(shm_name, 4096) + try: + with open(shm_path, "r+b") as f: + mm = mmap.mmap(f.fileno(), 4096, access=mmap.ACCESS_WRITE) + src = torch.arange(8, dtype=torch.float32).reshape(2, 4) + mm[: src.numel() * src.element_size()] = src.numpy().tobytes() + mm.close() + + context = NonGpuContextShm( + metadata=NonGpuContextMetadata( + layout_desc=MemoryLayoutDesc( + shapes=[torch.Size([2, 4])], + dtypes=[torch.float32], + ), + block_size=1, + use_mla=False, + ), + mq_client=MagicMock(), + mq_timeout=1.0, + shm_name=shm_name, + pool_size=4096, + ) + try: + view = context._make_tensor_view( + offset=0, + length=src.numel() * src.element_size(), + shape=[2, 4], + dtype_str="float32", + ) + assert torch.equal(view, src) + finally: + context.close() + finally: + if os.path.exists(shm_path): + os.unlink(shm_path) + + +def test_non_gpu_context_shm_store_retrieve_flow_with_mocked_mq() -> None: + shm_name = f"lmcache_test_flow_{os.getpid()}" + shm_path = _create_shm_file(shm_name, 4096) + slots = [ + { + "offset": 0, + "length": 16, + "shape": [2, 2], + "dtype": "float32", + } + ] + + mq_client = MagicMock() + + def _submit_request(req_type, payload, response_cls): # noqa: ARG001 + if req_type == RequestType.PREPARE_STORE: + return _CompletedFuture(PrepareStoreResponse(context={"slots": slots})) + if req_type == RequestType.COMMIT_STORE: + assert payload[2] == b"" + return _CompletedFuture(True) + if req_type == RequestType.PREPARE_RETRIEVE: + return _CompletedFuture( + PrepareRetrieveResponse( + success=True, data=b"", context={"slots": slots} + ) + ) + if req_type == RequestType.COMMIT_RETRIEVE: + return _CompletedFuture(True) + raise AssertionError(f"Unexpected request type: {req_type}") + + mq_client.submit_request.side_effect = _submit_request + + context = NonGpuContextShm( + metadata=NonGpuContextMetadata( + layout_desc=MemoryLayoutDesc( + shapes=[torch.Size([2, 2])], + dtypes=[torch.float32], + ), + block_size=1, + use_mla=False, + ), + mq_client=mq_client, + mq_timeout=1.0, + shm_name=shm_name, + pool_size=4096, + ) + try: + store_views = context.prepare_store(key="k", instance_id=1) + assert store_views is not None + store_views[0].copy_( + torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32) + ) + assert context.commit_store("k", 1, store_views) + + retrieve_views = context.prepare_retrieve(key="k", instance_id=1) + assert retrieve_views is not None + assert torch.equal( + retrieve_views[0], + torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32), + ) + assert context.commit_retrieve("k", 1) + finally: + context.close() + if os.path.exists(shm_path): + os.unlink(shm_path) From 18eeae1f96874e1d07dc54db5d0f846cbda5cffc Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 20 May 2026 06:02:08 +0000 Subject: [PATCH 3/9] Harden SHM validation and address review feedback Agent-Logs-Url: https://github.com/hlin99/LMCache/sessions/07c7d0ab-d21a-4245-9109-006f91352b6c Co-authored-by: hlin99 <73271530+hlin99@users.noreply.github.com> --- lmcache/v1/distributed/memory_manager.py | 11 +++++++++-- lmcache/v1/multiprocess/non_gpu_context_shm.py | 7 ++++--- lmcache/v1/multiprocess/server.py | 12 ++++++++++-- tests/v1/distributed/test_shm_l1_pool.py | 3 ++- 4 files changed, 25 insertions(+), 8 deletions(-) diff --git a/lmcache/v1/distributed/memory_manager.py b/lmcache/v1/distributed/memory_manager.py index ea8f0244da8..b571113e316 100644 --- a/lmcache/v1/distributed/memory_manager.py +++ b/lmcache/v1/distributed/memory_manager.py @@ -35,6 +35,9 @@ def _check_shm_capacity(required_bytes: int) -> bool: def _unlink_stale_shm(shm_name: str) -> None: """Remove a stale LMCache shm segment if it exists.""" normalized = shm_name.lstrip("/") + if "/" in normalized or "\\" in normalized: + logger.warning("Refusing to unlink invalid shm name %s", shm_name) + return if not normalized.startswith("lmcache_l1_pool_"): return shm_path = os.path.join("/dev/shm", normalized) @@ -77,8 +80,12 @@ def create_memory_allocator(config: L1MemoryManagerConfig) -> MemoryAllocatorInt shm_name = config.shm_name if shm_name: try: - if not _check_shm_capacity(config.size_in_bytes): - raise RuntimeError("insufficient /dev/shm capacity") + free_bytes = shutil.disk_usage("/dev/shm").free + if free_bytes < config.size_in_bytes: + raise RuntimeError( + "insufficient /dev/shm capacity: " + f"need {config.size_in_bytes} bytes, have {free_bytes} bytes" + ) _unlink_stale_shm(shm_name) return MixedMemoryAllocator( config.size_in_bytes, diff --git a/lmcache/v1/multiprocess/non_gpu_context_shm.py b/lmcache/v1/multiprocess/non_gpu_context_shm.py index 8f41cb8970f..8f28effdd35 100644 --- a/lmcache/v1/multiprocess/non_gpu_context_shm.py +++ b/lmcache/v1/multiprocess/non_gpu_context_shm.py @@ -48,7 +48,9 @@ def _make_tensor_view( dtype_str: str, ) -> torch.Tensor: """Create a tensor view over a SHM slot via ``torch.frombuffer``.""" - dtype = getattr(torch, dtype_str) + dtype = getattr(torch, dtype_str, None) + if dtype is None or not isinstance(dtype, torch.dtype): + raise ValueError(f"Invalid torch dtype string: {dtype_str}") itemsize = torch.empty((), dtype=dtype).element_size() if itemsize <= 0: raise ValueError(f"Invalid dtype size for {dtype_str}") @@ -83,9 +85,8 @@ def prepare_store(self, key: Any, instance_id: int) -> list[torch.Tensor] | None return self._build_slot_tensors(slots) if slots else None def commit_store( - self, key: Any, instance_id: int, chunks: list[torch.Tensor] + self, key: Any, instance_id: int, _chunks: list[torch.Tensor] ) -> bool: - del chunks future = self.mq_client.submit_request( RequestType.COMMIT_STORE, [key, instance_id, b""], diff --git a/lmcache/v1/multiprocess/server.py b/lmcache/v1/multiprocess/server.py index 88c271f27bc..0c29f3cac6d 100644 --- a/lmcache/v1/multiprocess/server.py +++ b/lmcache/v1/multiprocess/server.py @@ -163,6 +163,11 @@ def batched_iteration(lst: list, batch_size: int) -> Generator[tuple, None, None yield batch +def _dtype_to_name(dtype: torch.dtype) -> str: + """Return a stable torch dtype name without module prefix.""" + return str(dtype).split(".")[-1] + + @dataclass class _PrefetchJob: handle: PrefetchHandle @@ -402,6 +407,7 @@ def register_kv_cache_non_gpu_context( def _make_non_gpu_transfer_key( key: IPCCacheEngineKey, instance_id: int ) -> tuple[object, ...]: + """Build a unique key for pending SHM write/read transfer tracking.""" return ( instance_id, key.model_name, @@ -483,7 +489,7 @@ def prepare_store( "offset": memory_obj.shm_offset, "length": memory_obj.shm_byte_length, "shape": list(memory_obj.tensor.shape), - "dtype": str(memory_obj.tensor.dtype).replace("torch.", ""), + "dtype": _dtype_to_name(memory_obj.tensor.dtype), } ) reserved_keys.append(obj_key) @@ -504,6 +510,8 @@ def commit_store( key: Cache key for the token range to store. instance_id: Worker instance identifier. cpu_data: Pickled list of CPU tensors produced by the worker. + In SHM mode, empty bytes (``b""``) indicate data is already + written to SHM and only lock finalization is required. Returns: ``True`` when all reserved objects are written, otherwise ``False``. @@ -591,7 +599,7 @@ def prepare_retrieve( "offset": memory_obj.shm_offset, "length": memory_obj.shm_byte_length, "shape": list(memory_obj.tensor.shape), - "dtype": str(memory_obj.tensor.dtype).replace("torch.", ""), + "dtype": _dtype_to_name(memory_obj.tensor.dtype), } ) transfer_key = self._make_non_gpu_transfer_key(key, instance_id) diff --git a/tests/v1/distributed/test_shm_l1_pool.py b/tests/v1/distributed/test_shm_l1_pool.py index 51152dadf64..3ac3102cf0d 100644 --- a/tests/v1/distributed/test_shm_l1_pool.py +++ b/tests/v1/distributed/test_shm_l1_pool.py @@ -111,7 +111,8 @@ def _submit_request(req_type, payload, response_cls): # noqa: ARG001 if req_type == RequestType.PREPARE_STORE: return _CompletedFuture(PrepareStoreResponse(context={"slots": slots})) if req_type == RequestType.COMMIT_STORE: - assert payload[2] == b"" + _, _, commit_cpu_data = payload + assert commit_cpu_data == b"" return _CompletedFuture(True) if req_type == RequestType.PREPARE_RETRIEVE: return _CompletedFuture( From 2b1f1d96c21b41f00ee4d1a01c0e095a64f61c77 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 20 May 2026 06:19:09 +0000 Subject: [PATCH 4/9] Align SHM mmap usage and add pickle/shm transport logs Agent-Logs-Url: https://github.com/hlin99/LMCache/sessions/6bb8fb82-c368-43a5-a4b1-83bfedecc1a6 Co-authored-by: hlin99 <73271530+hlin99@users.noreply.github.com> --- lmcache/v1/multiprocess/non_gpu_context.py | 9 ++++++++ .../v1/multiprocess/non_gpu_context_shm.py | 20 ++++++++++-------- lmcache/v1/multiprocess/server.py | 21 ++++++++++++++++++- 3 files changed, 40 insertions(+), 10 deletions(-) diff --git a/lmcache/v1/multiprocess/non_gpu_context.py b/lmcache/v1/multiprocess/non_gpu_context.py index bc3e36deb71..39ababd7ccb 100644 --- a/lmcache/v1/multiprocess/non_gpu_context.py +++ b/lmcache/v1/multiprocess/non_gpu_context.py @@ -21,9 +21,12 @@ import torch # First Party +from lmcache.logging import init_logger from lmcache.utils import EngineType from lmcache.v1.distributed.api import MemoryLayoutDesc +logger = init_logger(__name__) + @dataclass class NonGpuContextMetadata: @@ -122,11 +125,17 @@ def create_non_gpu_context( # Local from .non_gpu_context_shm import NonGpuContextShm + logger.info( + "Creating NonGpuContextShm (shm_name=%s, pool_size=%d)", + shm_name, + pool_size, + ) return NonGpuContextShm(metadata, mq_client, mq_timeout, shm_name, pool_size) # Local from .non_gpu_context_pickle import NonGpuContextPickle + logger.info("Creating NonGpuContextPickle (pickle transport)") return NonGpuContextPickle(metadata, mq_client, mq_timeout) diff --git a/lmcache/v1/multiprocess/non_gpu_context_shm.py b/lmcache/v1/multiprocess/non_gpu_context_shm.py index 8f28effdd35..ebf9d3fd0f1 100644 --- a/lmcache/v1/multiprocess/non_gpu_context_shm.py +++ b/lmcache/v1/multiprocess/non_gpu_context_shm.py @@ -35,10 +35,14 @@ def __init__( self._shm_name = shm_name self._pool_size = pool_size shm_path = os.path.join("/dev/shm", shm_name.lstrip("/")) - self._shm_fd = os.open(shm_path, os.O_RDWR) - self._mmap_obj = mmap.mmap( - self._shm_fd, self._pool_size, access=mmap.ACCESS_WRITE - ) + shm_fd = os.open(shm_path, os.O_RDWR) + try: + self._mmap_obj = mmap.mmap( + shm_fd, self._pool_size, access=mmap.ACCESS_WRITE + ) + finally: + os.close(shm_fd) + self._buffer: Any = self._mmap_obj def _make_tensor_view( self, @@ -56,7 +60,7 @@ def _make_tensor_view( raise ValueError(f"Invalid dtype size for {dtype_str}") count = length // itemsize tensor_1d = torch.frombuffer( - self._mmap_obj, dtype=dtype, count=count, offset=offset + self._buffer, dtype=dtype, count=count, offset=offset ) return tensor_1d.view(torch.Size(shape)) @@ -124,7 +128,5 @@ def commit_retrieve(self, key: Any, instance_id: int) -> bool: return False def close(self) -> None: - try: - self._mmap_obj.close() - finally: - os.close(self._shm_fd) + self._buffer = None + self._mmap_obj.close() diff --git a/lmcache/v1/multiprocess/server.py b/lmcache/v1/multiprocess/server.py index 0c29f3cac6d..e306da208dc 100644 --- a/lmcache/v1/multiprocess/server.py +++ b/lmcache/v1/multiprocess/server.py @@ -397,11 +397,30 @@ def register_kv_cache_non_gpu_context( ) shm_pool_info = self.storage_manager.get_shm_pool_info() if not isinstance(shm_pool_info, dict): + logger.info( + "Instance %s non-GPU context using pickle transport " + "(no SHM pool info returned)", + payload.instance_id, + ) return RegisterNonGpuContextResponse() - return RegisterNonGpuContextResponse( + response = RegisterNonGpuContextResponse( shm_name=str(shm_pool_info.get("shm_name", "")), pool_size=int(shm_pool_info.get("pool_size", 0)), ) + if response.shm_name and response.pool_size > 0: + logger.info( + "Instance %s non-GPU context using SHM transport " + "(shm_name=%s, pool_size=%d)", + payload.instance_id, + response.shm_name, + response.pool_size, + ) + else: + logger.info( + "Instance %s non-GPU context using pickle transport", + payload.instance_id, + ) + return response @staticmethod def _make_non_gpu_transfer_key( From 19727677b430357b329952337ba80b6c3b31d562 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 20 May 2026 06:20:55 +0000 Subject: [PATCH 5/9] Simplify SHM mmap buffer handling Agent-Logs-Url: https://github.com/hlin99/LMCache/sessions/6bb8fb82-c368-43a5-a4b1-83bfedecc1a6 Co-authored-by: hlin99 <73271530+hlin99@users.noreply.github.com> --- lmcache/v1/multiprocess/non_gpu_context_shm.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/lmcache/v1/multiprocess/non_gpu_context_shm.py b/lmcache/v1/multiprocess/non_gpu_context_shm.py index ebf9d3fd0f1..acbe1a5a0a6 100644 --- a/lmcache/v1/multiprocess/non_gpu_context_shm.py +++ b/lmcache/v1/multiprocess/non_gpu_context_shm.py @@ -42,7 +42,6 @@ def __init__( ) finally: os.close(shm_fd) - self._buffer: Any = self._mmap_obj def _make_tensor_view( self, @@ -60,7 +59,7 @@ def _make_tensor_view( raise ValueError(f"Invalid dtype size for {dtype_str}") count = length // itemsize tensor_1d = torch.frombuffer( - self._buffer, dtype=dtype, count=count, offset=offset + self._mmap_obj, dtype=dtype, count=count, offset=offset ) return tensor_1d.view(torch.Size(shape)) @@ -128,5 +127,4 @@ def commit_retrieve(self, key: Any, instance_id: int) -> bool: return False def close(self) -> None: - self._buffer = None self._mmap_obj.close() From ee5eed62c28d449f7eb9cbb5823fbad7750be899 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 20 May 2026 06:33:56 +0000 Subject: [PATCH 6/9] Handle SHM no-op commit on repeated prompt store Agent-Logs-Url: https://github.com/hlin99/LMCache/sessions/0b3178a9-b4e2-411e-9e1d-1c8617b05893 Co-authored-by: hlin99 <73271530+hlin99@users.noreply.github.com> --- lmcache/v1/multiprocess/server.py | 7 +- .../v1/multiprocess/test_non_cuda_context.py | 64 +++++++++++++++++++ 2 files changed, 68 insertions(+), 3 deletions(-) diff --git a/lmcache/v1/multiprocess/server.py b/lmcache/v1/multiprocess/server.py index e306da208dc..5c8842b490b 100644 --- a/lmcache/v1/multiprocess/server.py +++ b/lmcache/v1/multiprocess/server.py @@ -537,10 +537,11 @@ def commit_store( """ if cpu_data == b"" and self._is_shm_active(): transfer_key = self._make_non_gpu_transfer_key(key, instance_id) - reserved_keys = self._pending_shm_writes.pop(transfer_key, []) - if not reserved_keys: + reserved_keys = self._pending_shm_writes.pop(transfer_key, None) + if reserved_keys is None: return False - self.storage_manager.finish_write(reserved_keys) + if reserved_keys: + self.storage_manager.finish_write(reserved_keys) return True obj_keys = self._resolve_obj_keys(key) diff --git a/tests/v1/multiprocess/test_non_cuda_context.py b/tests/v1/multiprocess/test_non_cuda_context.py index 5da7dc47aca..a84ed479063 100644 --- a/tests/v1/multiprocess/test_non_cuda_context.py +++ b/tests/v1/multiprocess/test_non_cuda_context.py @@ -459,3 +459,67 @@ def _read_prefetched_results(_keys: Any) -> Any: recovered_chunks: list[torch.Tensor] = pickle.loads(cpu_data) assert len(recovered_chunks) == 1 assert torch.allclose(recovered_chunks[0], payload) + + +def test_server_shm_commit_store_allows_noop_when_all_keys_exist( + stub_native_storage_ops: Any, +) -> None: + """Ensure SHM commit succeeds when prepare_store reserves zero new objects.""" + # First Party + from lmcache.v1.multiprocess.custom_types import ( + IPCCacheEngineKey, + RegisterNonGpuContextPayload, + ) + from lmcache.v1.multiprocess.server import MPCacheEngine + + mock_storage = MagicMock() + mock_storage.get_shm_pool_info.return_value = { + "shm_name": "lmcache_test_pool", + "pool_size": 1024, + } + mock_storage.reserve_write.return_value = {} + mock_session = MagicMock() + mock_session.get_hashes.return_value = [b"h"] + with ( + patch( + "lmcache.v1.multiprocess.server.StorageManager", + return_value=mock_storage, + ), + patch("lmcache.v1.multiprocess.server.TokenHasher"), + patch("lmcache.v1.multiprocess.server.SessionManager") as session_cls, + patch("lmcache.v1.multiprocess.server.get_event_bus"), + patch( + "lmcache.v1.multiprocess.server.ipc_key_to_object_keys", + return_value=["obj"], + ), + ): + session_cls.return_value.get_or_create.return_value = mock_session + engine = MPCacheEngine(storage_manager_config=MagicMock(), chunk_size=8) + + engine.register_kv_cache_non_gpu_context( + RegisterNonGpuContextPayload( + instance_id=3, + model_name="m", + world_size=1, + block_size=4, + num_layers=2, + hidden_dim_size=16, + dtype_str="float32", + use_mla=False, + ) + ) + key = IPCCacheEngineKey.from_token_ids( + "m", + 1, + 0, + [1] * 8, + start=0, + end=8, + request_id="req", + ) + prepare_response = engine.prepare_store(key, 3) + assert prepare_response.context["slots"] == [] + + store_ok = engine.commit_store(key, 3, b"") + assert store_ok is True + mock_storage.finish_write.assert_not_called() From 26a0cc7ffdde29a9e2cc192d9787988aa0982b03 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 20 May 2026 06:36:02 +0000 Subject: [PATCH 7/9] Clarify SHM commit semantics for no-op writes Agent-Logs-Url: https://github.com/hlin99/LMCache/sessions/0b3178a9-b4e2-411e-9e1d-1c8617b05893 Co-authored-by: hlin99 <73271530+hlin99@users.noreply.github.com> --- lmcache/v1/multiprocess/server.py | 2 ++ tests/v1/multiprocess/test_non_cuda_context.py | 1 + 2 files changed, 3 insertions(+) diff --git a/lmcache/v1/multiprocess/server.py b/lmcache/v1/multiprocess/server.py index 5c8842b490b..922b287dd3f 100644 --- a/lmcache/v1/multiprocess/server.py +++ b/lmcache/v1/multiprocess/server.py @@ -538,8 +538,10 @@ def commit_store( if cpu_data == b"" and self._is_shm_active(): transfer_key = self._make_non_gpu_transfer_key(key, instance_id) reserved_keys = self._pending_shm_writes.pop(transfer_key, None) + # Missing transfer key means COMMIT arrived without matching PREPARE. if reserved_keys is None: return False + # Empty reservation is a valid no-op when all objects already exist. if reserved_keys: self.storage_manager.finish_write(reserved_keys) return True diff --git a/tests/v1/multiprocess/test_non_cuda_context.py b/tests/v1/multiprocess/test_non_cuda_context.py index a84ed479063..8e25d202a1e 100644 --- a/tests/v1/multiprocess/test_non_cuda_context.py +++ b/tests/v1/multiprocess/test_non_cuda_context.py @@ -477,6 +477,7 @@ def test_server_shm_commit_store_allows_noop_when_all_keys_exist( "shm_name": "lmcache_test_pool", "pool_size": 1024, } + # Empty reserve_write indicates all object keys already exist in cache. mock_storage.reserve_write.return_value = {} mock_session = MagicMock() mock_session.get_hashes.return_value = [b"h"] From 067f3e4ff8b617a0392ed51f61010b3c8640786b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 20 May 2026 06:38:07 +0000 Subject: [PATCH 8/9] Expand SHM no-op regression test coverage Agent-Logs-Url: https://github.com/hlin99/LMCache/sessions/0b3178a9-b4e2-411e-9e1d-1c8617b05893 Co-authored-by: hlin99 <73271530+hlin99@users.noreply.github.com> --- tests/v1/multiprocess/test_non_cuda_context.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/v1/multiprocess/test_non_cuda_context.py b/tests/v1/multiprocess/test_non_cuda_context.py index 8e25d202a1e..fe81826646d 100644 --- a/tests/v1/multiprocess/test_non_cuda_context.py +++ b/tests/v1/multiprocess/test_non_cuda_context.py @@ -464,7 +464,13 @@ def _read_prefetched_results(_keys: Any) -> Any: def test_server_shm_commit_store_allows_noop_when_all_keys_exist( stub_native_storage_ops: Any, ) -> None: - """Ensure SHM commit succeeds when prepare_store reserves zero new objects.""" + """Regression: repeated prompt after worker restart should no-op-store cleanly. + + When all object keys already exist in cache, SHM ``prepare_store`` reserves + no new objects and returns empty slots. ``commit_store`` must still succeed + as a valid no-op for that prepared transfer, but fail without a matching + prepare state. + """ # First Party from lmcache.v1.multiprocess.custom_types import ( IPCCacheEngineKey, @@ -524,3 +530,6 @@ def test_server_shm_commit_store_allows_noop_when_all_keys_exist( store_ok = engine.commit_store(key, 3, b"") assert store_ok is True mock_storage.finish_write.assert_not_called() + + # A second commit without a matching prepare must fail. + assert engine.commit_store(key, 3, b"") is False From 2a02acb471324c795329edd08a91a9ded1a8bb0e Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Wed, 20 May 2026 15:13:26 +0800 Subject: [PATCH 9/9] Port SHM non-GPU transport to cpu_context_pickle branch and fix correctness gaps from #278 review (#280) * Initial plan * Fix SHM non-GPU transport idempotency, locking, and cleanup issues Agent-Logs-Url: https://github.com/hlin99/LMCache/sessions/661cbeee-d0d4-40ef-9312-4044e4696a51 Co-authored-by: hlin99 <73271530+hlin99@users.noreply.github.com> * Polish SHM feedback fixes and align validation comments Agent-Logs-Url: https://github.com/hlin99/LMCache/sessions/661cbeee-d0d4-40ef-9312-4044e4696a51 Co-authored-by: hlin99 <73271530+hlin99@users.noreply.github.com> * Fix memory leak: early-return from prepare_store when all keys exist When reserve_write returns empty (all object keys already cached), return PrepareStoreResponse(context={}) immediately without storing an entry in _pending_shm_writes. This prevents leaked entries that would never be popped since the worker won't call commit_store. Agent-Logs-Url: https://github.com/hlin99/LMCache/sessions/182111d5-1737-49c0-be65-0287d5b9d6c5 Co-authored-by: hlin99 <73271530+hlin99@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: hlin99 <73271530+hlin99@users.noreply.github.com> --- lmcache/v1/distributed/memory_manager.py | 12 --- .../v1/multiprocess/non_gpu_context_shm.py | 29 ++++-- lmcache/v1/multiprocess/server.py | 90 +++++++++++------- tests/v1/distributed/test_shm_l1_pool.py | 68 +++++++++++++- .../v1/multiprocess/test_non_cuda_context.py | 93 +++++++++++++++++-- 5 files changed, 228 insertions(+), 64 deletions(-) diff --git a/lmcache/v1/distributed/memory_manager.py b/lmcache/v1/distributed/memory_manager.py index b571113e316..5cbede5c057 100644 --- a/lmcache/v1/distributed/memory_manager.py +++ b/lmcache/v1/distributed/memory_manager.py @@ -20,18 +20,6 @@ logger = init_logger(__name__) -# HELPER FUNCTIONS -def _check_shm_capacity(required_bytes: int) -> bool: - """Return whether ``/dev/shm`` has enough free space.""" - if required_bytes <= 0: - return True - try: - free_bytes = shutil.disk_usage("/dev/shm").free - except OSError: - return False - return free_bytes >= required_bytes - - def _unlink_stale_shm(shm_name: str) -> None: """Remove a stale LMCache shm segment if it exists.""" normalized = shm_name.lstrip("/") diff --git a/lmcache/v1/multiprocess/non_gpu_context_shm.py b/lmcache/v1/multiprocess/non_gpu_context_shm.py index acbe1a5a0a6..cfa7c58748c 100644 --- a/lmcache/v1/multiprocess/non_gpu_context_shm.py +++ b/lmcache/v1/multiprocess/non_gpu_context_shm.py @@ -16,6 +16,8 @@ ) from lmcache.v1.multiprocess.protocol import RequestType, get_response_class +INVALID_SHM_FD = -1 + class NonGpuContextShm(NonGpuContext): """Shared-memory implementation of :class:`NonGpuContext`.""" @@ -34,14 +36,17 @@ def __init__( self._shm_name = shm_name self._pool_size = pool_size + self._shm_fd = INVALID_SHM_FD shm_path = os.path.join("/dev/shm", shm_name.lstrip("/")) - shm_fd = os.open(shm_path, os.O_RDWR) + self._shm_fd = os.open(shm_path, os.O_RDWR) try: self._mmap_obj = mmap.mmap( - shm_fd, self._pool_size, access=mmap.ACCESS_WRITE + self._shm_fd, self._pool_size, access=mmap.ACCESS_WRITE ) - finally: - os.close(shm_fd) + except Exception: + os.close(self._shm_fd) + self._shm_fd = INVALID_SHM_FD + raise def _make_tensor_view( self, @@ -84,8 +89,11 @@ def prepare_store(self, key: Any, instance_id: int) -> list[torch.Tensor] | None response = future.result(timeout=self.mq_timeout) except TimeoutError: return None - slots = response.context.get("slots", []) - return self._build_slot_tensors(slots) if slots else None + context = response.context if isinstance(response.context, dict) else {} + slots = context.get("slots") + if not isinstance(slots, list) or not slots: + return None + return self._build_slot_tensors(slots) def commit_store( self, key: Any, instance_id: int, _chunks: list[torch.Tensor] @@ -127,4 +135,11 @@ def commit_retrieve(self, key: Any, instance_id: int) -> bool: return False def close(self) -> None: - self._mmap_obj.close() + if self._shm_fd == INVALID_SHM_FD: + return + try: + self._mmap_obj.close() + finally: + fd = self._shm_fd + self._shm_fd = INVALID_SHM_FD + os.close(fd) diff --git a/lmcache/v1/multiprocess/server.py b/lmcache/v1/multiprocess/server.py index 922b287dd3f..794fa7d79b5 100644 --- a/lmcache/v1/multiprocess/server.py +++ b/lmcache/v1/multiprocess/server.py @@ -260,8 +260,17 @@ def __init__( # for crash resilience (e.g., client calls lookup but never queries) self._prefetch_jobs: dict[str, _PrefetchJob] = {} self._prefetch_job_lock = threading.Lock() - self._pending_shm_writes: dict[tuple[object, ...], list[ObjectKey]] = {} - self._pending_shm_reads: dict[tuple[object, ...], list[ObjectKey]] = {} + # Pending SHM transfer tracking, keyed by (instance_id, IPC key). + # IPCCacheEngineKey is a frozen dataclass and hashable, so it is safe + # and efficient for dict lookups across pending transfer tracking. + self._pending_shm_writes: dict[ + tuple[int, IPCCacheEngineKey], list[ObjectKey] + ] = {} + self._pending_shm_reads: dict[ + tuple[int, IPCCacheEngineKey], list[ObjectKey] + ] = {} + self._pending_shm_lock = threading.Lock() + self._shm_active = False self._setup_metrics() @@ -341,12 +350,23 @@ def unregister_kv_cache(self, instance_id: int) -> None: torch_dev.empty_cache() else: logger.info("Unregistered non-CUDA context for instance ID %d", instance_id) - self._pending_shm_writes = { - k: v for k, v in self._pending_shm_writes.items() if k[0] != instance_id - } - self._pending_shm_reads = { - k: v for k, v in self._pending_shm_reads.items() if k[0] != instance_id - } + with self._pending_shm_lock: + stale_writes = { + k: v for k, v in self._pending_shm_writes.items() if k[0] == instance_id + } + for transfer_key in stale_writes: + del self._pending_shm_writes[transfer_key] + stale_reads = { + k: v for k, v in self._pending_shm_reads.items() if k[0] == instance_id + } + for transfer_key in stale_reads: + del self._pending_shm_reads[transfer_key] + for reserved_keys in stale_writes.values(): + if reserved_keys: + self.storage_manager.finish_write(reserved_keys) + for prefetched_keys in stale_reads.values(): + if prefetched_keys: + self.storage_manager.finish_read_prefetched(prefetched_keys) def register_kv_cache_non_gpu_context( self, @@ -397,17 +417,21 @@ def register_kv_cache_non_gpu_context( ) shm_pool_info = self.storage_manager.get_shm_pool_info() if not isinstance(shm_pool_info, dict): + self._shm_active = False logger.info( "Instance %s non-GPU context using pickle transport " "(no SHM pool info returned)", payload.instance_id, ) return RegisterNonGpuContextResponse() + shm_name = str(shm_pool_info.get("shm_name", "")) + pool_size = int(shm_pool_info.get("pool_size", 0)) + self._shm_active = bool(shm_name) and pool_size > 0 response = RegisterNonGpuContextResponse( - shm_name=str(shm_pool_info.get("shm_name", "")), - pool_size=int(shm_pool_info.get("pool_size", 0)), + shm_name=shm_name, + pool_size=pool_size, ) - if response.shm_name and response.pool_size > 0: + if self._shm_active: logger.info( "Instance %s non-GPU context using SHM transport " "(shm_name=%s, pool_size=%d)", @@ -425,28 +449,12 @@ def register_kv_cache_non_gpu_context( @staticmethod def _make_non_gpu_transfer_key( key: IPCCacheEngineKey, instance_id: int - ) -> tuple[object, ...]: + ) -> tuple[int, IPCCacheEngineKey]: """Build a unique key for pending SHM write/read transfer tracking.""" - return ( - instance_id, - key.model_name, - key.world_size, - key.worker_id, - key.token_ids, - key.start, - key.end, - key.request_id, - key.cache_salt, - ) + return (instance_id, key) def _is_shm_active(self) -> bool: - shm_pool_info = self.storage_manager.get_shm_pool_info() - if not isinstance(shm_pool_info, dict): - return False - return ( - bool(shm_pool_info.get("shm_name")) - and int(shm_pool_info.get("pool_size", 0)) > 0 - ) + return self._shm_active def _resolve_obj_keys(self, key: IPCCacheEngineKey) -> list[ObjectKey]: """Resolve object keys from an IPC cache key. @@ -512,8 +520,11 @@ def prepare_store( } ) reserved_keys.append(obj_key) + if not reserved_keys: + return PrepareStoreResponse(context={}) transfer_key = self._make_non_gpu_transfer_key(key, instance_id) - self._pending_shm_writes[transfer_key] = reserved_keys + with self._pending_shm_lock: + self._pending_shm_writes[transfer_key] = reserved_keys return PrepareStoreResponse(context={"slots": slots}) @_lmcache_nvtx_annotate @@ -537,7 +548,8 @@ def commit_store( """ if cpu_data == b"" and self._is_shm_active(): transfer_key = self._make_non_gpu_transfer_key(key, instance_id) - reserved_keys = self._pending_shm_writes.pop(transfer_key, None) + with self._pending_shm_lock: + reserved_keys = self._pending_shm_writes.pop(transfer_key, None) # Missing transfer key means COMMIT arrived without matching PREPARE. if reserved_keys is None: return False @@ -604,10 +616,16 @@ def prepare_retrieve( ) if self._is_shm_active(): + # Precondition: read locks for these keys were acquired during the + # lookup/prefetch phase before this retrieve call. shm_prefetched_keys, shm_memory_objs = self.storage_manager.unsafe_read( obj_keys ) - if not shm_memory_objs or len(shm_prefetched_keys) != len(obj_keys): + if ( + not shm_memory_objs + or len(shm_prefetched_keys) != len(obj_keys) + or len(shm_memory_objs) != len(obj_keys) + ): if shm_prefetched_keys: self.storage_manager.finish_read_prefetched(shm_prefetched_keys) return PrepareRetrieveResponse(success=False, data=b"", context={}) @@ -625,7 +643,8 @@ def prepare_retrieve( } ) transfer_key = self._make_non_gpu_transfer_key(key, instance_id) - self._pending_shm_reads[transfer_key] = shm_prefetched_keys + with self._pending_shm_lock: + self._pending_shm_reads[transfer_key] = shm_prefetched_keys return PrepareRetrieveResponse( success=True, data=b"", context={"slots": slots} ) @@ -668,7 +687,8 @@ def commit_retrieve( """ if self._is_shm_active(): transfer_key = self._make_non_gpu_transfer_key(key, instance_id) - prefetched_keys = self._pending_shm_reads.pop(transfer_key, []) + with self._pending_shm_lock: + prefetched_keys = self._pending_shm_reads.pop(transfer_key, []) if prefetched_keys: self.storage_manager.finish_read_prefetched(prefetched_keys) return True diff --git a/tests/v1/distributed/test_shm_l1_pool.py b/tests/v1/distributed/test_shm_l1_pool.py index 3ac3102cf0d..7919214f85d 100644 --- a/tests/v1/distributed/test_shm_l1_pool.py +++ b/tests/v1/distributed/test_shm_l1_pool.py @@ -6,6 +6,7 @@ import os # Third Party +import pytest import torch # First Party @@ -13,7 +14,11 @@ from lmcache.v1.distributed.config import L1MemoryManagerConfig from lmcache.v1.distributed.memory_manager import create_memory_allocator from lmcache.v1.memory_management import MixedMemoryAllocator -from lmcache.v1.multiprocess.non_gpu_context import NonGpuContextMetadata +from lmcache.v1.multiprocess.non_gpu_context import ( + NonGpuContextMetadata, + create_non_gpu_context, +) +from lmcache.v1.multiprocess.non_gpu_context_pickle import NonGpuContextPickle from lmcache.v1.multiprocess.non_gpu_context_shm import NonGpuContextShm from lmcache.v1.multiprocess.protocol import RequestType from lmcache.v1.multiprocess.protocols.engine import ( @@ -159,3 +164,64 @@ def _submit_request(req_type, payload, response_cls): # noqa: ARG001 context.close() if os.path.exists(shm_path): os.unlink(shm_path) + + +def test_non_gpu_context_shm_init_raises_when_segment_missing() -> None: + with pytest.raises(FileNotFoundError, match="No such file or directory"): + NonGpuContextShm( + metadata=NonGpuContextMetadata( + layout_desc=MemoryLayoutDesc( + shapes=[torch.Size([2, 2])], + dtypes=[torch.float32], + ), + block_size=1, + use_mla=False, + ), + mq_client=MagicMock(), + mq_timeout=1.0, + shm_name="lmcache_missing_shm_segment", + pool_size=4096, + ) + + +def test_create_non_gpu_context_falls_back_to_pickle_without_shm_info() -> None: + context = create_non_gpu_context( + metadata=NonGpuContextMetadata( + layout_desc=MemoryLayoutDesc( + shapes=[torch.Size([2, 2])], + dtypes=[torch.float32], + ), + block_size=1, + use_mla=False, + ), + mq_client=MagicMock(), + mq_timeout=1.0, + shm_name="", + pool_size=0, + ) + assert isinstance(context, NonGpuContextPickle) + + +def test_non_gpu_context_shm_close_is_idempotent() -> None: + shm_name = f"lmcache_test_close_{os.getpid()}" + shm_path = _create_shm_file(shm_name, 4096) + try: + context = NonGpuContextShm( + metadata=NonGpuContextMetadata( + layout_desc=MemoryLayoutDesc( + shapes=[torch.Size([2, 2])], + dtypes=[torch.float32], + ), + block_size=1, + use_mla=False, + ), + mq_client=MagicMock(), + mq_timeout=1.0, + shm_name=shm_name, + pool_size=4096, + ) + context.close() + context.close() + finally: + if os.path.exists(shm_path): + os.unlink(shm_path) diff --git a/tests/v1/multiprocess/test_non_cuda_context.py b/tests/v1/multiprocess/test_non_cuda_context.py index fe81826646d..05fb83def6a 100644 --- a/tests/v1/multiprocess/test_non_cuda_context.py +++ b/tests/v1/multiprocess/test_non_cuda_context.py @@ -467,9 +467,9 @@ def test_server_shm_commit_store_allows_noop_when_all_keys_exist( """Regression: repeated prompt after worker restart should no-op-store cleanly. When all object keys already exist in cache, SHM ``prepare_store`` reserves - no new objects and returns empty slots. ``commit_store`` must still succeed - as a valid no-op for that prepared transfer, but fail without a matching - prepare state. + no new objects and returns empty context (no "slots" key). The worker sees + no slots and does not call ``commit_store``, so no entry leaks in + ``_pending_shm_writes``. """ # First Party from lmcache.v1.multiprocess.custom_types import ( @@ -487,6 +487,7 @@ def test_server_shm_commit_store_allows_noop_when_all_keys_exist( mock_storage.reserve_write.return_value = {} mock_session = MagicMock() mock_session.get_hashes.return_value = [b"h"] + with ( patch( "lmcache.v1.multiprocess.server.StorageManager", @@ -525,11 +526,85 @@ def test_server_shm_commit_store_allows_noop_when_all_keys_exist( request_id="req", ) prepare_response = engine.prepare_store(key, 3) - assert prepare_response.context["slots"] == [] - - store_ok = engine.commit_store(key, 3, b"") - assert store_ok is True - mock_storage.finish_write.assert_not_called() + # Empty context means no slots reserved — worker won't call commit_store. + assert prepare_response.context == {} - # A second commit without a matching prepare must fail. + # commit_store without a matching prepare must fail (no entry leaked). assert engine.commit_store(key, 3, b"") is False + + +def test_server_unregister_non_gpu_context_releases_pending_shm_locks( + stub_native_storage_ops: Any, +) -> None: + """Ensure unregister releases pending SHM read/write reservations.""" + # First Party + from lmcache.v1.multiprocess.custom_types import ( + IPCCacheEngineKey, + RegisterNonGpuContextPayload, + ) + from lmcache.v1.multiprocess.server import MPCacheEngine + + mock_storage = MagicMock() + mock_storage.get_shm_pool_info.return_value = { + "shm_name": "lmcache_l1_pool_test", + "pool_size": 4096, + } + mock_memory_obj = MagicMock() + mock_memory_obj.tensor = torch.zeros(2, 2, 8, 16) + mock_memory_obj.shm_offset = 0 + mock_memory_obj.shm_byte_length = 2048 + mock_storage.reserve_write.side_effect = ( + lambda obj_keys, *_args, **_kwargs: { + obj_key: mock_memory_obj for obj_key in obj_keys + } + ) + mock_storage.unsafe_read.side_effect = ( + lambda obj_keys: (obj_keys, [mock_memory_obj for _ in obj_keys]) + ) + mock_session = MagicMock() + mock_session.get_hashes.return_value = [b"h"] + + with ( + patch( + "lmcache.v1.multiprocess.server.StorageManager", + return_value=mock_storage, + ), + patch("lmcache.v1.multiprocess.server.TokenHasher"), + patch("lmcache.v1.multiprocess.server.SessionManager") as session_cls, + patch("lmcache.v1.multiprocess.server.get_event_bus"), + patch( + "lmcache.v1.multiprocess.server.ipc_key_to_object_keys", + return_value=["obj"], + ), + ): + session_cls.return_value.get_or_create.return_value = mock_session + engine = MPCacheEngine(storage_manager_config=MagicMock(), chunk_size=8) + + engine.register_kv_cache_non_gpu_context( + RegisterNonGpuContextPayload( + instance_id=4, + model_name="m", + world_size=1, + block_size=4, + num_layers=2, + hidden_dim_size=16, + dtype_str="float32", + use_mla=False, + ) + ) + key = IPCCacheEngineKey.from_token_ids( + "m", + 1, + 0, + [1] * 8, + start=0, + end=8, + request_id="req", + ) + assert engine.prepare_store(key, 4).context.get("slots") + assert engine.prepare_retrieve(key, 4).success is True + + engine.unregister_kv_cache(4) + + mock_storage.finish_write.assert_called_once() + mock_storage.finish_read_prefetched.assert_called_once()