diff --git a/lmcache/v1/distributed/config.py b/lmcache/v1/distributed/config.py index 2690b043970..bc39392a54b 100644 --- a/lmcache/v1/distributed/config.py +++ b/lmcache/v1/distributed/config.py @@ -8,6 +8,7 @@ from dataclasses import dataclass, field from typing import Literal import argparse +import os # First Party from lmcache import torch_dev @@ -39,6 +40,9 @@ class L1MemoryManagerConfig: align_bytes: int = field(default=0x1000) """ The alignment size in bytes. Default is 4KB. """ + shm_name: str = field(default_factory=lambda: f"lmcache_l1_pool_{os.getpid()}") + """ POSIX shared-memory segment name for L1 pool. Empty disables SHM. """ + def __post_init__(self): self.init_size_in_bytes = min(self.init_size_in_bytes, self.size_in_bytes) diff --git a/lmcache/v1/distributed/l1_manager.py b/lmcache/v1/distributed/l1_manager.py index e4e4379f30f..f85eaabfa41 100644 --- a/lmcache/v1/distributed/l1_manager.py +++ b/lmcache/v1/distributed/l1_manager.py @@ -803,6 +803,10 @@ def get_l1_memory_desc(self): """Return an L1MemoryDesc describing the underlying L1 memory buffer.""" return self._memory_manager.get_l1_memory_desc() + def get_shm_pool_info(self) -> dict: + """Return SHM pool metadata for non-GPU SHM transport.""" + return self._memory_manager.get_shm_pool_info() + def close(self) -> None: """Close the L1Manager and free all resources.""" with self._lock: diff --git a/lmcache/v1/distributed/memory_manager.py b/lmcache/v1/distributed/memory_manager.py index e2bfff45c4e..5cbede5c057 100644 --- a/lmcache/v1/distributed/memory_manager.py +++ b/lmcache/v1/distributed/memory_manager.py @@ -1,5 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 +# Standard +import os +import shutil + # First Party from lmcache.logging import init_logger from lmcache.v1.distributed.api import MemoryLayoutDesc @@ -16,7 +20,23 @@ logger = init_logger(__name__) -# HELPER FUNCTIONS +def _unlink_stale_shm(shm_name: str) -> None: + """Remove a stale LMCache shm segment if it exists.""" + normalized = shm_name.lstrip("/") + if "/" in normalized or "\\" in normalized: + logger.warning("Refusing to unlink invalid shm name %s", shm_name) + return + if not normalized.startswith("lmcache_l1_pool_"): + return + shm_path = os.path.join("/dev/shm", normalized) + try: + os.unlink(shm_path) + except FileNotFoundError: + return + except OSError: + logger.warning("Failed to remove stale shm segment %s", shm_path, exc_info=True) + + def create_memory_allocator(config: L1MemoryManagerConfig) -> MemoryAllocatorInterface: """ Create a memory allocator based on the provided configuration. @@ -45,6 +65,27 @@ def create_memory_allocator(config: L1MemoryManagerConfig) -> MemoryAllocatorInt config.size_in_bytes, config.align_bytes, ) + shm_name = config.shm_name + if shm_name: + try: + free_bytes = shutil.disk_usage("/dev/shm").free + if free_bytes < config.size_in_bytes: + raise RuntimeError( + "insufficient /dev/shm capacity: " + f"need {config.size_in_bytes} bytes, have {free_bytes} bytes" + ) + _unlink_stale_shm(shm_name) + return MixedMemoryAllocator( + config.size_in_bytes, + align_bytes=config.align_bytes, + shm_name=shm_name, + ) + except (RuntimeError, OSError, ValueError): + logger.warning( + "Failed to initialize SHM pool (%s), falling back to pickle path", + shm_name, + exc_info=True, + ) return MixedMemoryAllocator( config.size_in_bytes, align_bytes=config.align_bytes, @@ -65,6 +106,13 @@ def __init__(self, config: L1MemoryManagerConfig): self._allocator = create_memory_allocator(config) self._size_in_bytes = config.size_in_bytes self._align_bytes = config.align_bytes + self._shm_pool_info = {"shm_name": "", "pool_size": 0} + if isinstance(self._allocator, MixedMemoryAllocator): + if self._allocator.shm_name: + self._shm_pool_info = { + "shm_name": self._allocator.shm_name, + "pool_size": self._size_in_bytes, + } def allocate( self, layout_desc: MemoryLayoutDesc, count: int @@ -174,6 +222,10 @@ def close(self) -> None: """ self._allocator.close() + def get_shm_pool_info(self) -> dict: + """Return SHM pool metadata for non-GPU SHM transport.""" + return dict(self._shm_pool_info) + # Debugging APIs def memcheck(self): return self._allocator.memcheck() diff --git a/lmcache/v1/distributed/storage_manager.py b/lmcache/v1/distributed/storage_manager.py index 6b4f95bd996..82df94b4fa1 100644 --- a/lmcache/v1/distributed/storage_manager.py +++ b/lmcache/v1/distributed/storage_manager.py @@ -559,6 +559,25 @@ def touch_l1_keys(self, keys: list[ObjectKey]): """ self._l1_manager.touch_keys(keys) + def get_shm_pool_info(self) -> dict: + """Return SHM pool metadata from the L1 memory manager.""" + return self._l1_manager.get_shm_pool_info() + + def unsafe_read( + self, keys: list[ObjectKey] + ) -> tuple[list[ObjectKey], list[MemoryObj]]: + """Read already read-locked objects without acquiring new read locks.""" + read_results = self._l1_manager.unsafe_read(keys) + good_keys: list[ObjectKey] = [] + good_objs: list[MemoryObj] = [] + for key in keys: + err, obj = read_results.get(key, (L1Error.KEY_NOT_EXIST, None)) + if err != L1Error.SUCCESS or obj is None: + continue + good_keys.append(key) + good_objs.append(obj) + return good_keys, good_objs + @property def quota_manager(self) -> QuotaManager: """Per-cache_salt quota registry. diff --git a/lmcache/v1/memory_management.py b/lmcache/v1/memory_management.py index 164cb058f6f..c3a6018ed41 100644 --- a/lmcache/v1/memory_management.py +++ b/lmcache/v1/memory_management.py @@ -299,6 +299,16 @@ def get_num_tokens(self) -> int: """ raise NotImplementedError + @property + def shm_offset(self) -> int: + """Return the byte offset of this object inside the SHM pool.""" + return self.meta.address + + @property + def shm_byte_length(self) -> int: + """Return the byte length of this object inside the SHM pool.""" + return self.get_size() + @property @abc.abstractmethod def metadata(self) -> MemoryObjMetadata: diff --git a/lmcache/v1/multiprocess/non_gpu_context.py b/lmcache/v1/multiprocess/non_gpu_context.py index e782c76d0c3..39ababd7ccb 100644 --- a/lmcache/v1/multiprocess/non_gpu_context.py +++ b/lmcache/v1/multiprocess/non_gpu_context.py @@ -21,9 +21,12 @@ import torch # First Party +from lmcache.logging import init_logger from lmcache.utils import EngineType from lmcache.v1.distributed.api import MemoryLayoutDesc +logger = init_logger(__name__) + @dataclass class NonGpuContextMetadata: @@ -100,24 +103,39 @@ def create_non_gpu_context( metadata: NonGpuContextMetadata, mq_client: Any, mq_timeout: float, + shm_name: str = "", + pool_size: int = 0, ) -> NonGpuContext: """Factory that returns the appropriate :class:`NonGpuContext` implementation. - Currently always returns a pickle-based implementation - (``NonGpuContextPickle``). A future SHM-capable PR - may probe for shared-memory availability and fall back to pickle. + Returns SHM-based implementation when shared-memory pool information is + available; otherwise falls back to the pickle-based implementation. Args: metadata: Layout metadata for the non-GPU context. mq_client: Message-queue client for server communication. mq_timeout: Timeout in seconds for blocking MQ requests. + shm_name: Shared-memory segment name. Empty means pickle mode. + pool_size: Shared-memory pool size in bytes. Non-positive means pickle mode. Returns: A concrete :class:`NonGpuContext` instance. """ + if shm_name and pool_size > 0: + # Local + from .non_gpu_context_shm import NonGpuContextShm + + logger.info( + "Creating NonGpuContextShm (shm_name=%s, pool_size=%d)", + shm_name, + pool_size, + ) + return NonGpuContextShm(metadata, mq_client, mq_timeout, shm_name, pool_size) + # Local from .non_gpu_context_pickle import NonGpuContextPickle + logger.info("Creating NonGpuContextPickle (pickle transport)") return NonGpuContextPickle(metadata, mq_client, mq_timeout) diff --git a/lmcache/v1/multiprocess/non_gpu_context_shm.py b/lmcache/v1/multiprocess/non_gpu_context_shm.py new file mode 100644 index 00000000000..cfa7c58748c --- /dev/null +++ b/lmcache/v1/multiprocess/non_gpu_context_shm.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Shared-memory NonGpuContext implementation for multiprocess mode.""" + +# Standard +from typing import Any +import mmap +import os + +# Third Party +import torch + +# First Party +from lmcache.v1.multiprocess.non_gpu_context import ( + NonGpuContext, + NonGpuContextMetadata, +) +from lmcache.v1.multiprocess.protocol import RequestType, get_response_class + +INVALID_SHM_FD = -1 + + +class NonGpuContextShm(NonGpuContext): + """Shared-memory implementation of :class:`NonGpuContext`.""" + + def __init__( + self, + metadata: NonGpuContextMetadata, + mq_client: Any, + mq_timeout: float, + shm_name: str, + pool_size: int, + ) -> None: + super().__init__(metadata, mq_client, mq_timeout) + if not shm_name or pool_size <= 0: + raise ValueError("shm_name must be non-empty and pool_size must be > 0") + + self._shm_name = shm_name + self._pool_size = pool_size + self._shm_fd = INVALID_SHM_FD + shm_path = os.path.join("/dev/shm", shm_name.lstrip("/")) + self._shm_fd = os.open(shm_path, os.O_RDWR) + try: + self._mmap_obj = mmap.mmap( + self._shm_fd, self._pool_size, access=mmap.ACCESS_WRITE + ) + except Exception: + os.close(self._shm_fd) + self._shm_fd = INVALID_SHM_FD + raise + + def _make_tensor_view( + self, + offset: int, + length: int, + shape: list[int], + dtype_str: str, + ) -> torch.Tensor: + """Create a tensor view over a SHM slot via ``torch.frombuffer``.""" + dtype = getattr(torch, dtype_str, None) + if dtype is None or not isinstance(dtype, torch.dtype): + raise ValueError(f"Invalid torch dtype string: {dtype_str}") + itemsize = torch.empty((), dtype=dtype).element_size() + if itemsize <= 0: + raise ValueError(f"Invalid dtype size for {dtype_str}") + count = length // itemsize + tensor_1d = torch.frombuffer( + self._mmap_obj, dtype=dtype, count=count, offset=offset + ) + return tensor_1d.view(torch.Size(shape)) + + def _build_slot_tensors(self, slots: list[dict[str, Any]]) -> list[torch.Tensor]: + return [ + self._make_tensor_view( + offset=int(slot["offset"]), + length=int(slot["length"]), + shape=list(slot["shape"]), + dtype_str=str(slot["dtype"]), + ) + for slot in slots + ] + + def prepare_store(self, key: Any, instance_id: int) -> list[torch.Tensor] | None: + future = self.mq_client.submit_request( + RequestType.PREPARE_STORE, + [key, instance_id], + get_response_class(RequestType.PREPARE_STORE), + ) + try: + response = future.result(timeout=self.mq_timeout) + except TimeoutError: + return None + context = response.context if isinstance(response.context, dict) else {} + slots = context.get("slots") + if not isinstance(slots, list) or not slots: + return None + return self._build_slot_tensors(slots) + + def commit_store( + self, key: Any, instance_id: int, _chunks: list[torch.Tensor] + ) -> bool: + future = self.mq_client.submit_request( + RequestType.COMMIT_STORE, + [key, instance_id, b""], + get_response_class(RequestType.COMMIT_STORE), + ) + try: + return bool(future.result(timeout=self.mq_timeout)) + except TimeoutError: + return False + + def prepare_retrieve(self, key: Any, instance_id: int) -> list[torch.Tensor] | None: + future = self.mq_client.submit_request( + RequestType.PREPARE_RETRIEVE, + [key, instance_id], + get_response_class(RequestType.PREPARE_RETRIEVE), + ) + try: + response = future.result(timeout=self.mq_timeout) + except TimeoutError: + return None + if not response.success: + return None + slots = response.context.get("slots", []) + return self._build_slot_tensors(slots) if slots else None + + def commit_retrieve(self, key: Any, instance_id: int) -> bool: + future = self.mq_client.submit_request( + RequestType.COMMIT_RETRIEVE, + [key, instance_id], + get_response_class(RequestType.COMMIT_RETRIEVE), + ) + try: + return bool(future.result(timeout=self.mq_timeout)) + except TimeoutError: + return False + + def close(self) -> None: + if self._shm_fd == INVALID_SHM_FD: + return + try: + self._mmap_obj.close() + finally: + fd = self._shm_fd + self._shm_fd = INVALID_SHM_FD + os.close(fd) diff --git a/lmcache/v1/multiprocess/protocols/engine.py b/lmcache/v1/multiprocess/protocols/engine.py index 62ec4926cd8..bea9a1723c2 100644 --- a/lmcache/v1/multiprocess/protocols/engine.py +++ b/lmcache/v1/multiprocess/protocols/engine.py @@ -46,6 +46,14 @@ class PrepareRetrieveResponse: ) # pickle: {}, shm will put slot info here +@dataclass +class RegisterNonGpuContextResponse: + """Response for REGISTER_KV_CACHE_NON_GPU_CONTEXT.""" + + shm_name: str = "" + pool_size: int = 0 + + # Define request names for this protocol group REQUEST_NAMES = [ "REGISTER_KV_CACHE", @@ -179,10 +187,10 @@ def get_protocol_definitions() -> dict[str, ProtocolDefinition]: # Register non-GPU KV cache context # Payload: # - RegisterNonGpuContextPayload - all metadata fields in one struct - # Returns: None + # Returns: RegisterNonGpuContextResponse "REGISTER_KV_CACHE_NON_GPU_CONTEXT": ProtocolDefinition( payload_classes=[RegisterNonGpuContextPayload], - response_class=None, + response_class=RegisterNonGpuContextResponse, handler_type=HandlerType.SYNC, ), "PREPARE_STORE": ProtocolDefinition( diff --git a/lmcache/v1/multiprocess/server.py b/lmcache/v1/multiprocess/server.py index bb67d07bf54..794fa7d79b5 100644 --- a/lmcache/v1/multiprocess/server.py +++ b/lmcache/v1/multiprocess/server.py @@ -72,6 +72,7 @@ from lmcache.v1.multiprocess.protocols.engine import ( PrepareRetrieveResponse, PrepareStoreResponse, + RegisterNonGpuContextResponse, ) from lmcache.v1.multiprocess.session import SessionManager from lmcache.v1.multiprocess.token_hasher import TokenHasher @@ -162,6 +163,11 @@ def batched_iteration(lst: list, batch_size: int) -> Generator[tuple, None, None yield batch +def _dtype_to_name(dtype: torch.dtype) -> str: + """Return a stable torch dtype name without module prefix.""" + return str(dtype).split(".")[-1] + + @dataclass class _PrefetchJob: handle: PrefetchHandle @@ -254,6 +260,17 @@ def __init__( # for crash resilience (e.g., client calls lookup but never queries) self._prefetch_jobs: dict[str, _PrefetchJob] = {} self._prefetch_job_lock = threading.Lock() + # Pending SHM transfer tracking, keyed by (instance_id, IPC key). + # IPCCacheEngineKey is a frozen dataclass and hashable, so it is safe + # and efficient for dict lookups across pending transfer tracking. + self._pending_shm_writes: dict[ + tuple[int, IPCCacheEngineKey], list[ObjectKey] + ] = {} + self._pending_shm_reads: dict[ + tuple[int, IPCCacheEngineKey], list[ObjectKey] + ] = {} + self._pending_shm_lock = threading.Lock() + self._shm_active = False self._setup_metrics() @@ -333,11 +350,28 @@ def unregister_kv_cache(self, instance_id: int) -> None: torch_dev.empty_cache() else: logger.info("Unregistered non-CUDA context for instance ID %d", instance_id) + with self._pending_shm_lock: + stale_writes = { + k: v for k, v in self._pending_shm_writes.items() if k[0] == instance_id + } + for transfer_key in stale_writes: + del self._pending_shm_writes[transfer_key] + stale_reads = { + k: v for k, v in self._pending_shm_reads.items() if k[0] == instance_id + } + for transfer_key in stale_reads: + del self._pending_shm_reads[transfer_key] + for reserved_keys in stale_writes.values(): + if reserved_keys: + self.storage_manager.finish_write(reserved_keys) + for prefetched_keys in stale_reads.values(): + if prefetched_keys: + self.storage_manager.finish_read_prefetched(prefetched_keys) def register_kv_cache_non_gpu_context( self, payload: RegisterNonGpuContextPayload, - ) -> None: + ) -> RegisterNonGpuContextResponse: """Register non-CUDA KV layout metadata for non-GPU context mode. Args: @@ -354,7 +388,7 @@ def register_kv_cache_non_gpu_context( "skipping the new registration", payload.instance_id, ) - return + return RegisterNonGpuContextResponse() dtype = getattr(torch, payload.dtype_str, None) if dtype is None or not isinstance(dtype, torch.dtype): @@ -381,6 +415,46 @@ def register_kv_cache_non_gpu_context( use_mla=payload.use_mla, ), ) + shm_pool_info = self.storage_manager.get_shm_pool_info() + if not isinstance(shm_pool_info, dict): + self._shm_active = False + logger.info( + "Instance %s non-GPU context using pickle transport " + "(no SHM pool info returned)", + payload.instance_id, + ) + return RegisterNonGpuContextResponse() + shm_name = str(shm_pool_info.get("shm_name", "")) + pool_size = int(shm_pool_info.get("pool_size", 0)) + self._shm_active = bool(shm_name) and pool_size > 0 + response = RegisterNonGpuContextResponse( + shm_name=shm_name, + pool_size=pool_size, + ) + if self._shm_active: + logger.info( + "Instance %s non-GPU context using SHM transport " + "(shm_name=%s, pool_size=%d)", + payload.instance_id, + response.shm_name, + response.pool_size, + ) + else: + logger.info( + "Instance %s non-GPU context using pickle transport", + payload.instance_id, + ) + return response + + @staticmethod + def _make_non_gpu_transfer_key( + key: IPCCacheEngineKey, instance_id: int + ) -> tuple[int, IPCCacheEngineKey]: + """Build a unique key for pending SHM write/read transfer tracking.""" + return (instance_id, key) + + def _is_shm_active(self) -> bool: + return self._shm_active def _resolve_obj_keys(self, key: IPCCacheEngineKey) -> list[ObjectKey]: """Resolve object keys from an IPC cache key. @@ -419,7 +493,39 @@ def prepare_store( PrepareStoreResponse with empty slots for pickle mode. """ - return PrepareStoreResponse(context={}) + if not self._is_shm_active(): + return PrepareStoreResponse(context={}) + + obj_keys = self._resolve_obj_keys(key) + context = self.contexts.get(instance_id) + if context is None or context.non_cuda_metadata is None: + raise ValueError( + f"non-CUDA context not registered for instance ID {instance_id}" + ) + reserved = self.storage_manager.reserve_write( + obj_keys, context.non_cuda_metadata.layout_desc, "new" + ) + slots: list[dict] = [] + reserved_keys: list[ObjectKey] = [] + for obj_key in obj_keys: + memory_obj = reserved.get(obj_key) + if memory_obj is None or memory_obj.tensor is None: + continue + slots.append( + { + "offset": memory_obj.shm_offset, + "length": memory_obj.shm_byte_length, + "shape": list(memory_obj.tensor.shape), + "dtype": _dtype_to_name(memory_obj.tensor.dtype), + } + ) + reserved_keys.append(obj_key) + if not reserved_keys: + return PrepareStoreResponse(context={}) + transfer_key = self._make_non_gpu_transfer_key(key, instance_id) + with self._pending_shm_lock: + self._pending_shm_writes[transfer_key] = reserved_keys + return PrepareStoreResponse(context={"slots": slots}) @_lmcache_nvtx_annotate def commit_store( @@ -434,10 +540,24 @@ def commit_store( key: Cache key for the token range to store. instance_id: Worker instance identifier. cpu_data: Pickled list of CPU tensors produced by the worker. + In SHM mode, empty bytes (``b""``) indicate data is already + written to SHM and only lock finalization is required. Returns: ``True`` when all reserved objects are written, otherwise ``False``. """ + if cpu_data == b"" and self._is_shm_active(): + transfer_key = self._make_non_gpu_transfer_key(key, instance_id) + with self._pending_shm_lock: + reserved_keys = self._pending_shm_writes.pop(transfer_key, None) + # Missing transfer key means COMMIT arrived without matching PREPARE. + if reserved_keys is None: + return False + # Empty reservation is a valid no-op when all objects already exist. + if reserved_keys: + self.storage_manager.finish_write(reserved_keys) + return True + obj_keys = self._resolve_obj_keys(key) context = self.contexts.get(instance_id) @@ -495,14 +615,49 @@ def prepare_retrieve( f"non-CUDA context not registered for instance ID {instance_id}" ) - prefetched_keys: list[ObjectKey] = [] + if self._is_shm_active(): + # Precondition: read locks for these keys were acquired during the + # lookup/prefetch phase before this retrieve call. + shm_prefetched_keys, shm_memory_objs = self.storage_manager.unsafe_read( + obj_keys + ) + if ( + not shm_memory_objs + or len(shm_prefetched_keys) != len(obj_keys) + or len(shm_memory_objs) != len(obj_keys) + ): + if shm_prefetched_keys: + self.storage_manager.finish_read_prefetched(shm_prefetched_keys) + return PrepareRetrieveResponse(success=False, data=b"", context={}) + slots: list[dict] = [] + for memory_obj in shm_memory_objs: + if memory_obj.tensor is None: + self.storage_manager.finish_read_prefetched(shm_prefetched_keys) + return PrepareRetrieveResponse(success=False, data=b"", context={}) + slots.append( + { + "offset": memory_obj.shm_offset, + "length": memory_obj.shm_byte_length, + "shape": list(memory_obj.tensor.shape), + "dtype": _dtype_to_name(memory_obj.tensor.dtype), + } + ) + transfer_key = self._make_non_gpu_transfer_key(key, instance_id) + with self._pending_shm_lock: + self._pending_shm_reads[transfer_key] = shm_prefetched_keys + return PrepareRetrieveResponse( + success=True, data=b"", context={"slots": slots} + ) + + prefetched_keys_pickle: list[ObjectKey] = [] try: - with self.storage_manager.read_prefetched_results(obj_keys) as memory_objs: - if not memory_objs or len(memory_objs) != len(obj_keys): + read_ctx = self.storage_manager.read_prefetched_results(obj_keys) + with read_ctx as maybe_memory_objs: + if not maybe_memory_objs or len(maybe_memory_objs) != len(obj_keys): return PrepareRetrieveResponse(success=False, data=b"", context={}) - prefetched_keys = obj_keys[: len(memory_objs)] + prefetched_keys_pickle = obj_keys[: len(maybe_memory_objs)] chunks = [] - for memory_obj in memory_objs: + for memory_obj in maybe_memory_objs: if memory_obj.tensor is None: return PrepareRetrieveResponse( success=False, data=b"", context={} @@ -512,8 +667,8 @@ def prepare_retrieve( success=True, data=pickle.dumps(chunks), context={} ) finally: - if prefetched_keys: - self.storage_manager.finish_read_prefetched(prefetched_keys) + if prefetched_keys_pickle: + self.storage_manager.finish_read_prefetched(prefetched_keys_pickle) @_lmcache_nvtx_annotate def commit_retrieve( @@ -530,6 +685,13 @@ def commit_retrieve( Returns: Always ``True``. """ + if self._is_shm_active(): + transfer_key = self._make_non_gpu_transfer_key(key, instance_id) + with self._pending_shm_lock: + prefetched_keys = self._pending_shm_reads.pop(transfer_key, []) + if prefetched_keys: + self.storage_manager.finish_read_prefetched(prefetched_keys) + return True return True @_lmcache_nvtx_annotate diff --git a/lmcache/v1/multiprocess/transfer_context.py b/lmcache/v1/multiprocess/transfer_context.py index 2a598791bb6..610f867d242 100644 --- a/lmcache/v1/multiprocess/transfer_context.py +++ b/lmcache/v1/multiprocess/transfer_context.py @@ -25,6 +25,7 @@ scatter_cpu_to_paged_kv, ) from lmcache.v1.multiprocess.protocol import RequestType +from lmcache.v1.multiprocess.protocols.engine import RegisterNonGpuContextResponse logger = init_logger(__name__) @@ -291,14 +292,25 @@ def register( ) ], ) + response = future.result(timeout=mq_timeout) + shm_name = "" + pool_size = 0 + if isinstance(response, RegisterNonGpuContextResponse): + shm_name = response.shm_name + pool_size = response.pool_size metadata = NonGpuContextMetadata( layout_desc=layout_desc, block_size=block_size, use_mla=use_mla_flag, ) - self._non_gpu_context = create_non_gpu_context(metadata, mq_client, mq_timeout) - future.result(timeout=mq_timeout) + self._non_gpu_context = create_non_gpu_context( + metadata, + mq_client, + mq_timeout, + shm_name=shm_name, + pool_size=pool_size, + ) def submit_store( self, diff --git a/tests/v1/distributed/test_shm_l1_pool.py b/tests/v1/distributed/test_shm_l1_pool.py new file mode 100644 index 00000000000..7919214f85d --- /dev/null +++ b/tests/v1/distributed/test_shm_l1_pool.py @@ -0,0 +1,227 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Standard +from unittest.mock import MagicMock +import mmap +import os + +# Third Party +import pytest +import torch + +# First Party +from lmcache.v1.distributed.api import MemoryLayoutDesc +from lmcache.v1.distributed.config import L1MemoryManagerConfig +from lmcache.v1.distributed.memory_manager import create_memory_allocator +from lmcache.v1.memory_management import MixedMemoryAllocator +from lmcache.v1.multiprocess.non_gpu_context import ( + NonGpuContextMetadata, + create_non_gpu_context, +) +from lmcache.v1.multiprocess.non_gpu_context_pickle import NonGpuContextPickle +from lmcache.v1.multiprocess.non_gpu_context_shm import NonGpuContextShm +from lmcache.v1.multiprocess.protocol import RequestType +from lmcache.v1.multiprocess.protocols.engine import ( + PrepareRetrieveResponse, + PrepareStoreResponse, +) + + +class _CompletedFuture: + def __init__(self, value): + self._value = value + + def result(self, timeout=None): # noqa: ARG002 + return self._value + + +def _create_shm_file(shm_name: str, size: int) -> str: + path = os.path.join("/dev/shm", shm_name.lstrip("/")) + fd = os.open(path, os.O_CREAT | os.O_RDWR, 0o600) + os.ftruncate(fd, size) + os.close(fd) + return path + + +def test_shm_segment_creation_and_cleanup() -> None: + shm_name = f"lmcache_l1_pool_test_{os.getpid()}" + cfg = L1MemoryManagerConfig( + size_in_bytes=1024 * 1024, + use_lazy=False, + shm_name=shm_name, + ) + allocator = create_memory_allocator(cfg) + assert isinstance(allocator, MixedMemoryAllocator) + assert allocator.shm_name == shm_name + shm_path = os.path.join("/dev/shm", shm_name) + assert os.path.exists(shm_path) + allocator.close() + assert not os.path.exists(shm_path) + + +def test_non_gpu_context_shm_tensor_view_from_buffer() -> None: + shm_name = f"lmcache_test_view_{os.getpid()}" + shm_path = _create_shm_file(shm_name, 4096) + try: + with open(shm_path, "r+b") as f: + mm = mmap.mmap(f.fileno(), 4096, access=mmap.ACCESS_WRITE) + src = torch.arange(8, dtype=torch.float32).reshape(2, 4) + mm[: src.numel() * src.element_size()] = src.numpy().tobytes() + mm.close() + + context = NonGpuContextShm( + metadata=NonGpuContextMetadata( + layout_desc=MemoryLayoutDesc( + shapes=[torch.Size([2, 4])], + dtypes=[torch.float32], + ), + block_size=1, + use_mla=False, + ), + mq_client=MagicMock(), + mq_timeout=1.0, + shm_name=shm_name, + pool_size=4096, + ) + try: + view = context._make_tensor_view( + offset=0, + length=src.numel() * src.element_size(), + shape=[2, 4], + dtype_str="float32", + ) + assert torch.equal(view, src) + finally: + context.close() + finally: + if os.path.exists(shm_path): + os.unlink(shm_path) + + +def test_non_gpu_context_shm_store_retrieve_flow_with_mocked_mq() -> None: + shm_name = f"lmcache_test_flow_{os.getpid()}" + shm_path = _create_shm_file(shm_name, 4096) + slots = [ + { + "offset": 0, + "length": 16, + "shape": [2, 2], + "dtype": "float32", + } + ] + + mq_client = MagicMock() + + def _submit_request(req_type, payload, response_cls): # noqa: ARG001 + if req_type == RequestType.PREPARE_STORE: + return _CompletedFuture(PrepareStoreResponse(context={"slots": slots})) + if req_type == RequestType.COMMIT_STORE: + _, _, commit_cpu_data = payload + assert commit_cpu_data == b"" + return _CompletedFuture(True) + if req_type == RequestType.PREPARE_RETRIEVE: + return _CompletedFuture( + PrepareRetrieveResponse( + success=True, data=b"", context={"slots": slots} + ) + ) + if req_type == RequestType.COMMIT_RETRIEVE: + return _CompletedFuture(True) + raise AssertionError(f"Unexpected request type: {req_type}") + + mq_client.submit_request.side_effect = _submit_request + + context = NonGpuContextShm( + metadata=NonGpuContextMetadata( + layout_desc=MemoryLayoutDesc( + shapes=[torch.Size([2, 2])], + dtypes=[torch.float32], + ), + block_size=1, + use_mla=False, + ), + mq_client=mq_client, + mq_timeout=1.0, + shm_name=shm_name, + pool_size=4096, + ) + try: + store_views = context.prepare_store(key="k", instance_id=1) + assert store_views is not None + store_views[0].copy_( + torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32) + ) + assert context.commit_store("k", 1, store_views) + + retrieve_views = context.prepare_retrieve(key="k", instance_id=1) + assert retrieve_views is not None + assert torch.equal( + retrieve_views[0], + torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32), + ) + assert context.commit_retrieve("k", 1) + finally: + context.close() + if os.path.exists(shm_path): + os.unlink(shm_path) + + +def test_non_gpu_context_shm_init_raises_when_segment_missing() -> None: + with pytest.raises(FileNotFoundError, match="No such file or directory"): + NonGpuContextShm( + metadata=NonGpuContextMetadata( + layout_desc=MemoryLayoutDesc( + shapes=[torch.Size([2, 2])], + dtypes=[torch.float32], + ), + block_size=1, + use_mla=False, + ), + mq_client=MagicMock(), + mq_timeout=1.0, + shm_name="lmcache_missing_shm_segment", + pool_size=4096, + ) + + +def test_create_non_gpu_context_falls_back_to_pickle_without_shm_info() -> None: + context = create_non_gpu_context( + metadata=NonGpuContextMetadata( + layout_desc=MemoryLayoutDesc( + shapes=[torch.Size([2, 2])], + dtypes=[torch.float32], + ), + block_size=1, + use_mla=False, + ), + mq_client=MagicMock(), + mq_timeout=1.0, + shm_name="", + pool_size=0, + ) + assert isinstance(context, NonGpuContextPickle) + + +def test_non_gpu_context_shm_close_is_idempotent() -> None: + shm_name = f"lmcache_test_close_{os.getpid()}" + shm_path = _create_shm_file(shm_name, 4096) + try: + context = NonGpuContextShm( + metadata=NonGpuContextMetadata( + layout_desc=MemoryLayoutDesc( + shapes=[torch.Size([2, 2])], + dtypes=[torch.float32], + ), + block_size=1, + use_mla=False, + ), + mq_client=MagicMock(), + mq_timeout=1.0, + shm_name=shm_name, + pool_size=4096, + ) + context.close() + context.close() + finally: + if os.path.exists(shm_path): + os.unlink(shm_path) diff --git a/tests/v1/multiprocess/test_non_cuda_context.py b/tests/v1/multiprocess/test_non_cuda_context.py index 5da7dc47aca..05fb83def6a 100644 --- a/tests/v1/multiprocess/test_non_cuda_context.py +++ b/tests/v1/multiprocess/test_non_cuda_context.py @@ -459,3 +459,152 @@ def _read_prefetched_results(_keys: Any) -> Any: recovered_chunks: list[torch.Tensor] = pickle.loads(cpu_data) assert len(recovered_chunks) == 1 assert torch.allclose(recovered_chunks[0], payload) + + +def test_server_shm_commit_store_allows_noop_when_all_keys_exist( + stub_native_storage_ops: Any, +) -> None: + """Regression: repeated prompt after worker restart should no-op-store cleanly. + + When all object keys already exist in cache, SHM ``prepare_store`` reserves + no new objects and returns empty context (no "slots" key). The worker sees + no slots and does not call ``commit_store``, so no entry leaks in + ``_pending_shm_writes``. + """ + # First Party + from lmcache.v1.multiprocess.custom_types import ( + IPCCacheEngineKey, + RegisterNonGpuContextPayload, + ) + from lmcache.v1.multiprocess.server import MPCacheEngine + + mock_storage = MagicMock() + mock_storage.get_shm_pool_info.return_value = { + "shm_name": "lmcache_test_pool", + "pool_size": 1024, + } + # Empty reserve_write indicates all object keys already exist in cache. + mock_storage.reserve_write.return_value = {} + mock_session = MagicMock() + mock_session.get_hashes.return_value = [b"h"] + + with ( + patch( + "lmcache.v1.multiprocess.server.StorageManager", + return_value=mock_storage, + ), + patch("lmcache.v1.multiprocess.server.TokenHasher"), + patch("lmcache.v1.multiprocess.server.SessionManager") as session_cls, + patch("lmcache.v1.multiprocess.server.get_event_bus"), + patch( + "lmcache.v1.multiprocess.server.ipc_key_to_object_keys", + return_value=["obj"], + ), + ): + session_cls.return_value.get_or_create.return_value = mock_session + engine = MPCacheEngine(storage_manager_config=MagicMock(), chunk_size=8) + + engine.register_kv_cache_non_gpu_context( + RegisterNonGpuContextPayload( + instance_id=3, + model_name="m", + world_size=1, + block_size=4, + num_layers=2, + hidden_dim_size=16, + dtype_str="float32", + use_mla=False, + ) + ) + key = IPCCacheEngineKey.from_token_ids( + "m", + 1, + 0, + [1] * 8, + start=0, + end=8, + request_id="req", + ) + prepare_response = engine.prepare_store(key, 3) + # Empty context means no slots reserved — worker won't call commit_store. + assert prepare_response.context == {} + + # commit_store without a matching prepare must fail (no entry leaked). + assert engine.commit_store(key, 3, b"") is False + + +def test_server_unregister_non_gpu_context_releases_pending_shm_locks( + stub_native_storage_ops: Any, +) -> None: + """Ensure unregister releases pending SHM read/write reservations.""" + # First Party + from lmcache.v1.multiprocess.custom_types import ( + IPCCacheEngineKey, + RegisterNonGpuContextPayload, + ) + from lmcache.v1.multiprocess.server import MPCacheEngine + + mock_storage = MagicMock() + mock_storage.get_shm_pool_info.return_value = { + "shm_name": "lmcache_l1_pool_test", + "pool_size": 4096, + } + mock_memory_obj = MagicMock() + mock_memory_obj.tensor = torch.zeros(2, 2, 8, 16) + mock_memory_obj.shm_offset = 0 + mock_memory_obj.shm_byte_length = 2048 + mock_storage.reserve_write.side_effect = ( + lambda obj_keys, *_args, **_kwargs: { + obj_key: mock_memory_obj for obj_key in obj_keys + } + ) + mock_storage.unsafe_read.side_effect = ( + lambda obj_keys: (obj_keys, [mock_memory_obj for _ in obj_keys]) + ) + mock_session = MagicMock() + mock_session.get_hashes.return_value = [b"h"] + + with ( + patch( + "lmcache.v1.multiprocess.server.StorageManager", + return_value=mock_storage, + ), + patch("lmcache.v1.multiprocess.server.TokenHasher"), + patch("lmcache.v1.multiprocess.server.SessionManager") as session_cls, + patch("lmcache.v1.multiprocess.server.get_event_bus"), + patch( + "lmcache.v1.multiprocess.server.ipc_key_to_object_keys", + return_value=["obj"], + ), + ): + session_cls.return_value.get_or_create.return_value = mock_session + engine = MPCacheEngine(storage_manager_config=MagicMock(), chunk_size=8) + + engine.register_kv_cache_non_gpu_context( + RegisterNonGpuContextPayload( + instance_id=4, + model_name="m", + world_size=1, + block_size=4, + num_layers=2, + hidden_dim_size=16, + dtype_str="float32", + use_mla=False, + ) + ) + key = IPCCacheEngineKey.from_token_ids( + "m", + 1, + 0, + [1] * 8, + start=0, + end=8, + request_id="req", + ) + assert engine.prepare_store(key, 4).context.get("slots") + assert engine.prepare_retrieve(key, 4).success is True + + engine.unregister_kv_cache(4) + + mock_storage.finish_write.assert_called_once() + mock_storage.finish_read_prefetched.assert_called_once()