diff --git a/lmcache/v1/distributed/memory_manager.py b/lmcache/v1/distributed/memory_manager.py index b571113e316..5cbede5c057 100644 --- a/lmcache/v1/distributed/memory_manager.py +++ b/lmcache/v1/distributed/memory_manager.py @@ -20,18 +20,6 @@ logger = init_logger(__name__) -# HELPER FUNCTIONS -def _check_shm_capacity(required_bytes: int) -> bool: - """Return whether ``/dev/shm`` has enough free space.""" - if required_bytes <= 0: - return True - try: - free_bytes = shutil.disk_usage("/dev/shm").free - except OSError: - return False - return free_bytes >= required_bytes - - def _unlink_stale_shm(shm_name: str) -> None: """Remove a stale LMCache shm segment if it exists.""" normalized = shm_name.lstrip("/") diff --git a/lmcache/v1/multiprocess/non_gpu_context_shm.py b/lmcache/v1/multiprocess/non_gpu_context_shm.py index acbe1a5a0a6..cfa7c58748c 100644 --- a/lmcache/v1/multiprocess/non_gpu_context_shm.py +++ b/lmcache/v1/multiprocess/non_gpu_context_shm.py @@ -16,6 +16,8 @@ ) from lmcache.v1.multiprocess.protocol import RequestType, get_response_class +INVALID_SHM_FD = -1 + class NonGpuContextShm(NonGpuContext): """Shared-memory implementation of :class:`NonGpuContext`.""" @@ -34,14 +36,17 @@ def __init__( self._shm_name = shm_name self._pool_size = pool_size + self._shm_fd = INVALID_SHM_FD shm_path = os.path.join("/dev/shm", shm_name.lstrip("/")) - shm_fd = os.open(shm_path, os.O_RDWR) + self._shm_fd = os.open(shm_path, os.O_RDWR) try: self._mmap_obj = mmap.mmap( - shm_fd, self._pool_size, access=mmap.ACCESS_WRITE + self._shm_fd, self._pool_size, access=mmap.ACCESS_WRITE ) - finally: - os.close(shm_fd) + except Exception: + os.close(self._shm_fd) + self._shm_fd = INVALID_SHM_FD + raise def _make_tensor_view( self, @@ -84,8 +89,11 @@ def prepare_store(self, key: Any, instance_id: int) -> list[torch.Tensor] | None response = future.result(timeout=self.mq_timeout) except TimeoutError: return None - slots = response.context.get("slots", []) - return self._build_slot_tensors(slots) if slots else None + context = response.context if isinstance(response.context, dict) else {} + slots = context.get("slots") + if not isinstance(slots, list) or not slots: + return None + return self._build_slot_tensors(slots) def commit_store( self, key: Any, instance_id: int, _chunks: list[torch.Tensor] @@ -127,4 +135,11 @@ def commit_retrieve(self, key: Any, instance_id: int) -> bool: return False def close(self) -> None: - self._mmap_obj.close() + if self._shm_fd == INVALID_SHM_FD: + return + try: + self._mmap_obj.close() + finally: + fd = self._shm_fd + self._shm_fd = INVALID_SHM_FD + os.close(fd) diff --git a/lmcache/v1/multiprocess/server.py b/lmcache/v1/multiprocess/server.py index 922b287dd3f..794fa7d79b5 100644 --- a/lmcache/v1/multiprocess/server.py +++ b/lmcache/v1/multiprocess/server.py @@ -260,8 +260,17 @@ def __init__( # for crash resilience (e.g., client calls lookup but never queries) self._prefetch_jobs: dict[str, _PrefetchJob] = {} self._prefetch_job_lock = threading.Lock() - self._pending_shm_writes: dict[tuple[object, ...], list[ObjectKey]] = {} - self._pending_shm_reads: dict[tuple[object, ...], list[ObjectKey]] = {} + # Pending SHM transfer tracking, keyed by (instance_id, IPC key). + # IPCCacheEngineKey is a frozen dataclass and hashable, so it is safe + # and efficient for dict lookups across pending transfer tracking. + self._pending_shm_writes: dict[ + tuple[int, IPCCacheEngineKey], list[ObjectKey] + ] = {} + self._pending_shm_reads: dict[ + tuple[int, IPCCacheEngineKey], list[ObjectKey] + ] = {} + self._pending_shm_lock = threading.Lock() + self._shm_active = False self._setup_metrics() @@ -341,12 +350,23 @@ def unregister_kv_cache(self, instance_id: int) -> None: torch_dev.empty_cache() else: logger.info("Unregistered non-CUDA context for instance ID %d", instance_id) - self._pending_shm_writes = { - k: v for k, v in self._pending_shm_writes.items() if k[0] != instance_id - } - self._pending_shm_reads = { - k: v for k, v in self._pending_shm_reads.items() if k[0] != instance_id - } + with self._pending_shm_lock: + stale_writes = { + k: v for k, v in self._pending_shm_writes.items() if k[0] == instance_id + } + for transfer_key in stale_writes: + del self._pending_shm_writes[transfer_key] + stale_reads = { + k: v for k, v in self._pending_shm_reads.items() if k[0] == instance_id + } + for transfer_key in stale_reads: + del self._pending_shm_reads[transfer_key] + for reserved_keys in stale_writes.values(): + if reserved_keys: + self.storage_manager.finish_write(reserved_keys) + for prefetched_keys in stale_reads.values(): + if prefetched_keys: + self.storage_manager.finish_read_prefetched(prefetched_keys) def register_kv_cache_non_gpu_context( self, @@ -397,17 +417,21 @@ def register_kv_cache_non_gpu_context( ) shm_pool_info = self.storage_manager.get_shm_pool_info() if not isinstance(shm_pool_info, dict): + self._shm_active = False logger.info( "Instance %s non-GPU context using pickle transport " "(no SHM pool info returned)", payload.instance_id, ) return RegisterNonGpuContextResponse() + shm_name = str(shm_pool_info.get("shm_name", "")) + pool_size = int(shm_pool_info.get("pool_size", 0)) + self._shm_active = bool(shm_name) and pool_size > 0 response = RegisterNonGpuContextResponse( - shm_name=str(shm_pool_info.get("shm_name", "")), - pool_size=int(shm_pool_info.get("pool_size", 0)), + shm_name=shm_name, + pool_size=pool_size, ) - if response.shm_name and response.pool_size > 0: + if self._shm_active: logger.info( "Instance %s non-GPU context using SHM transport " "(shm_name=%s, pool_size=%d)", @@ -425,28 +449,12 @@ def register_kv_cache_non_gpu_context( @staticmethod def _make_non_gpu_transfer_key( key: IPCCacheEngineKey, instance_id: int - ) -> tuple[object, ...]: + ) -> tuple[int, IPCCacheEngineKey]: """Build a unique key for pending SHM write/read transfer tracking.""" - return ( - instance_id, - key.model_name, - key.world_size, - key.worker_id, - key.token_ids, - key.start, - key.end, - key.request_id, - key.cache_salt, - ) + return (instance_id, key) def _is_shm_active(self) -> bool: - shm_pool_info = self.storage_manager.get_shm_pool_info() - if not isinstance(shm_pool_info, dict): - return False - return ( - bool(shm_pool_info.get("shm_name")) - and int(shm_pool_info.get("pool_size", 0)) > 0 - ) + return self._shm_active def _resolve_obj_keys(self, key: IPCCacheEngineKey) -> list[ObjectKey]: """Resolve object keys from an IPC cache key. @@ -512,8 +520,11 @@ def prepare_store( } ) reserved_keys.append(obj_key) + if not reserved_keys: + return PrepareStoreResponse(context={}) transfer_key = self._make_non_gpu_transfer_key(key, instance_id) - self._pending_shm_writes[transfer_key] = reserved_keys + with self._pending_shm_lock: + self._pending_shm_writes[transfer_key] = reserved_keys return PrepareStoreResponse(context={"slots": slots}) @_lmcache_nvtx_annotate @@ -537,7 +548,8 @@ def commit_store( """ if cpu_data == b"" and self._is_shm_active(): transfer_key = self._make_non_gpu_transfer_key(key, instance_id) - reserved_keys = self._pending_shm_writes.pop(transfer_key, None) + with self._pending_shm_lock: + reserved_keys = self._pending_shm_writes.pop(transfer_key, None) # Missing transfer key means COMMIT arrived without matching PREPARE. if reserved_keys is None: return False @@ -604,10 +616,16 @@ def prepare_retrieve( ) if self._is_shm_active(): + # Precondition: read locks for these keys were acquired during the + # lookup/prefetch phase before this retrieve call. shm_prefetched_keys, shm_memory_objs = self.storage_manager.unsafe_read( obj_keys ) - if not shm_memory_objs or len(shm_prefetched_keys) != len(obj_keys): + if ( + not shm_memory_objs + or len(shm_prefetched_keys) != len(obj_keys) + or len(shm_memory_objs) != len(obj_keys) + ): if shm_prefetched_keys: self.storage_manager.finish_read_prefetched(shm_prefetched_keys) return PrepareRetrieveResponse(success=False, data=b"", context={}) @@ -625,7 +643,8 @@ def prepare_retrieve( } ) transfer_key = self._make_non_gpu_transfer_key(key, instance_id) - self._pending_shm_reads[transfer_key] = shm_prefetched_keys + with self._pending_shm_lock: + self._pending_shm_reads[transfer_key] = shm_prefetched_keys return PrepareRetrieveResponse( success=True, data=b"", context={"slots": slots} ) @@ -668,7 +687,8 @@ def commit_retrieve( """ if self._is_shm_active(): transfer_key = self._make_non_gpu_transfer_key(key, instance_id) - prefetched_keys = self._pending_shm_reads.pop(transfer_key, []) + with self._pending_shm_lock: + prefetched_keys = self._pending_shm_reads.pop(transfer_key, []) if prefetched_keys: self.storage_manager.finish_read_prefetched(prefetched_keys) return True diff --git a/tests/v1/distributed/test_shm_l1_pool.py b/tests/v1/distributed/test_shm_l1_pool.py index 3ac3102cf0d..7919214f85d 100644 --- a/tests/v1/distributed/test_shm_l1_pool.py +++ b/tests/v1/distributed/test_shm_l1_pool.py @@ -6,6 +6,7 @@ import os # Third Party +import pytest import torch # First Party @@ -13,7 +14,11 @@ from lmcache.v1.distributed.config import L1MemoryManagerConfig from lmcache.v1.distributed.memory_manager import create_memory_allocator from lmcache.v1.memory_management import MixedMemoryAllocator -from lmcache.v1.multiprocess.non_gpu_context import NonGpuContextMetadata +from lmcache.v1.multiprocess.non_gpu_context import ( + NonGpuContextMetadata, + create_non_gpu_context, +) +from lmcache.v1.multiprocess.non_gpu_context_pickle import NonGpuContextPickle from lmcache.v1.multiprocess.non_gpu_context_shm import NonGpuContextShm from lmcache.v1.multiprocess.protocol import RequestType from lmcache.v1.multiprocess.protocols.engine import ( @@ -159,3 +164,64 @@ def _submit_request(req_type, payload, response_cls): # noqa: ARG001 context.close() if os.path.exists(shm_path): os.unlink(shm_path) + + +def test_non_gpu_context_shm_init_raises_when_segment_missing() -> None: + with pytest.raises(FileNotFoundError, match="No such file or directory"): + NonGpuContextShm( + metadata=NonGpuContextMetadata( + layout_desc=MemoryLayoutDesc( + shapes=[torch.Size([2, 2])], + dtypes=[torch.float32], + ), + block_size=1, + use_mla=False, + ), + mq_client=MagicMock(), + mq_timeout=1.0, + shm_name="lmcache_missing_shm_segment", + pool_size=4096, + ) + + +def test_create_non_gpu_context_falls_back_to_pickle_without_shm_info() -> None: + context = create_non_gpu_context( + metadata=NonGpuContextMetadata( + layout_desc=MemoryLayoutDesc( + shapes=[torch.Size([2, 2])], + dtypes=[torch.float32], + ), + block_size=1, + use_mla=False, + ), + mq_client=MagicMock(), + mq_timeout=1.0, + shm_name="", + pool_size=0, + ) + assert isinstance(context, NonGpuContextPickle) + + +def test_non_gpu_context_shm_close_is_idempotent() -> None: + shm_name = f"lmcache_test_close_{os.getpid()}" + shm_path = _create_shm_file(shm_name, 4096) + try: + context = NonGpuContextShm( + metadata=NonGpuContextMetadata( + layout_desc=MemoryLayoutDesc( + shapes=[torch.Size([2, 2])], + dtypes=[torch.float32], + ), + block_size=1, + use_mla=False, + ), + mq_client=MagicMock(), + mq_timeout=1.0, + shm_name=shm_name, + pool_size=4096, + ) + context.close() + context.close() + finally: + if os.path.exists(shm_path): + os.unlink(shm_path) diff --git a/tests/v1/multiprocess/test_non_cuda_context.py b/tests/v1/multiprocess/test_non_cuda_context.py index fe81826646d..05fb83def6a 100644 --- a/tests/v1/multiprocess/test_non_cuda_context.py +++ b/tests/v1/multiprocess/test_non_cuda_context.py @@ -467,9 +467,9 @@ def test_server_shm_commit_store_allows_noop_when_all_keys_exist( """Regression: repeated prompt after worker restart should no-op-store cleanly. When all object keys already exist in cache, SHM ``prepare_store`` reserves - no new objects and returns empty slots. ``commit_store`` must still succeed - as a valid no-op for that prepared transfer, but fail without a matching - prepare state. + no new objects and returns empty context (no "slots" key). The worker sees + no slots and does not call ``commit_store``, so no entry leaks in + ``_pending_shm_writes``. """ # First Party from lmcache.v1.multiprocess.custom_types import ( @@ -487,6 +487,7 @@ def test_server_shm_commit_store_allows_noop_when_all_keys_exist( mock_storage.reserve_write.return_value = {} mock_session = MagicMock() mock_session.get_hashes.return_value = [b"h"] + with ( patch( "lmcache.v1.multiprocess.server.StorageManager", @@ -525,11 +526,85 @@ def test_server_shm_commit_store_allows_noop_when_all_keys_exist( request_id="req", ) prepare_response = engine.prepare_store(key, 3) - assert prepare_response.context["slots"] == [] - - store_ok = engine.commit_store(key, 3, b"") - assert store_ok is True - mock_storage.finish_write.assert_not_called() + # Empty context means no slots reserved — worker won't call commit_store. + assert prepare_response.context == {} - # A second commit without a matching prepare must fail. + # commit_store without a matching prepare must fail (no entry leaked). assert engine.commit_store(key, 3, b"") is False + + +def test_server_unregister_non_gpu_context_releases_pending_shm_locks( + stub_native_storage_ops: Any, +) -> None: + """Ensure unregister releases pending SHM read/write reservations.""" + # First Party + from lmcache.v1.multiprocess.custom_types import ( + IPCCacheEngineKey, + RegisterNonGpuContextPayload, + ) + from lmcache.v1.multiprocess.server import MPCacheEngine + + mock_storage = MagicMock() + mock_storage.get_shm_pool_info.return_value = { + "shm_name": "lmcache_l1_pool_test", + "pool_size": 4096, + } + mock_memory_obj = MagicMock() + mock_memory_obj.tensor = torch.zeros(2, 2, 8, 16) + mock_memory_obj.shm_offset = 0 + mock_memory_obj.shm_byte_length = 2048 + mock_storage.reserve_write.side_effect = ( + lambda obj_keys, *_args, **_kwargs: { + obj_key: mock_memory_obj for obj_key in obj_keys + } + ) + mock_storage.unsafe_read.side_effect = ( + lambda obj_keys: (obj_keys, [mock_memory_obj for _ in obj_keys]) + ) + mock_session = MagicMock() + mock_session.get_hashes.return_value = [b"h"] + + with ( + patch( + "lmcache.v1.multiprocess.server.StorageManager", + return_value=mock_storage, + ), + patch("lmcache.v1.multiprocess.server.TokenHasher"), + patch("lmcache.v1.multiprocess.server.SessionManager") as session_cls, + patch("lmcache.v1.multiprocess.server.get_event_bus"), + patch( + "lmcache.v1.multiprocess.server.ipc_key_to_object_keys", + return_value=["obj"], + ), + ): + session_cls.return_value.get_or_create.return_value = mock_session + engine = MPCacheEngine(storage_manager_config=MagicMock(), chunk_size=8) + + engine.register_kv_cache_non_gpu_context( + RegisterNonGpuContextPayload( + instance_id=4, + model_name="m", + world_size=1, + block_size=4, + num_layers=2, + hidden_dim_size=16, + dtype_str="float32", + use_mla=False, + ) + ) + key = IPCCacheEngineKey.from_token_ids( + "m", + 1, + 0, + [1] * 8, + start=0, + end=8, + request_id="req", + ) + assert engine.prepare_store(key, 4).context.get("slots") + assert engine.prepare_retrieve(key, 4).success is True + + engine.unregister_kv_cache(4) + + mock_storage.finish_write.assert_called_once() + mock_storage.finish_read_prefetched.assert_called_once()