Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions lmcache/v1/distributed/memory_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,6 @@
logger = init_logger(__name__)


# HELPER FUNCTIONS
def _check_shm_capacity(required_bytes: int) -> bool:
"""Return whether ``/dev/shm`` has enough free space."""
if required_bytes <= 0:
return True
try:
free_bytes = shutil.disk_usage("/dev/shm").free
except OSError:
return False
return free_bytes >= required_bytes


def _unlink_stale_shm(shm_name: str) -> None:
"""Remove a stale LMCache shm segment if it exists."""
normalized = shm_name.lstrip("/")
Expand Down
29 changes: 22 additions & 7 deletions lmcache/v1/multiprocess/non_gpu_context_shm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
)
from lmcache.v1.multiprocess.protocol import RequestType, get_response_class

INVALID_SHM_FD = -1


class NonGpuContextShm(NonGpuContext):
"""Shared-memory implementation of :class:`NonGpuContext`."""
Expand All @@ -34,14 +36,17 @@ def __init__(

self._shm_name = shm_name
self._pool_size = pool_size
self._shm_fd = INVALID_SHM_FD
shm_path = os.path.join("/dev/shm", shm_name.lstrip("/"))
shm_fd = os.open(shm_path, os.O_RDWR)
self._shm_fd = os.open(shm_path, os.O_RDWR)
try:
self._mmap_obj = mmap.mmap(
shm_fd, self._pool_size, access=mmap.ACCESS_WRITE
self._shm_fd, self._pool_size, access=mmap.ACCESS_WRITE
)
finally:
os.close(shm_fd)
except Exception:
os.close(self._shm_fd)
self._shm_fd = INVALID_SHM_FD
raise

def _make_tensor_view(
self,
Expand Down Expand Up @@ -84,8 +89,11 @@ def prepare_store(self, key: Any, instance_id: int) -> list[torch.Tensor] | None
response = future.result(timeout=self.mq_timeout)
except TimeoutError:
return None
slots = response.context.get("slots", [])
return self._build_slot_tensors(slots) if slots else None
context = response.context if isinstance(response.context, dict) else {}
slots = context.get("slots")
if not isinstance(slots, list) or not slots:
return None
return self._build_slot_tensors(slots)

def commit_store(
self, key: Any, instance_id: int, _chunks: list[torch.Tensor]
Expand Down Expand Up @@ -127,4 +135,11 @@ def commit_retrieve(self, key: Any, instance_id: int) -> bool:
return False

def close(self) -> None:
self._mmap_obj.close()
if self._shm_fd == INVALID_SHM_FD:
return
try:
self._mmap_obj.close()
finally:
fd = self._shm_fd
self._shm_fd = INVALID_SHM_FD
os.close(fd)
90 changes: 55 additions & 35 deletions lmcache/v1/multiprocess/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,17 @@ def __init__(
# for crash resilience (e.g., client calls lookup but never queries)
self._prefetch_jobs: dict[str, _PrefetchJob] = {}
self._prefetch_job_lock = threading.Lock()
self._pending_shm_writes: dict[tuple[object, ...], list[ObjectKey]] = {}
self._pending_shm_reads: dict[tuple[object, ...], list[ObjectKey]] = {}
# Pending SHM transfer tracking, keyed by (instance_id, IPC key).
# IPCCacheEngineKey is a frozen dataclass and hashable, so it is safe
# and efficient for dict lookups across pending transfer tracking.
self._pending_shm_writes: dict[
tuple[int, IPCCacheEngineKey], list[ObjectKey]
] = {}
self._pending_shm_reads: dict[
tuple[int, IPCCacheEngineKey], list[ObjectKey]
] = {}
self._pending_shm_lock = threading.Lock()
self._shm_active = False

self._setup_metrics()

Expand Down Expand Up @@ -341,12 +350,23 @@ def unregister_kv_cache(self, instance_id: int) -> None:
torch_dev.empty_cache()
else:
logger.info("Unregistered non-CUDA context for instance ID %d", instance_id)
self._pending_shm_writes = {
k: v for k, v in self._pending_shm_writes.items() if k[0] != instance_id
}
self._pending_shm_reads = {
k: v for k, v in self._pending_shm_reads.items() if k[0] != instance_id
}
with self._pending_shm_lock:
stale_writes = {
k: v for k, v in self._pending_shm_writes.items() if k[0] == instance_id
}
for transfer_key in stale_writes:
del self._pending_shm_writes[transfer_key]
stale_reads = {
k: v for k, v in self._pending_shm_reads.items() if k[0] == instance_id
}
for transfer_key in stale_reads:
del self._pending_shm_reads[transfer_key]
for reserved_keys in stale_writes.values():
if reserved_keys:
self.storage_manager.finish_write(reserved_keys)
for prefetched_keys in stale_reads.values():
if prefetched_keys:
self.storage_manager.finish_read_prefetched(prefetched_keys)

def register_kv_cache_non_gpu_context(
self,
Expand Down Expand Up @@ -397,17 +417,21 @@ def register_kv_cache_non_gpu_context(
)
shm_pool_info = self.storage_manager.get_shm_pool_info()
if not isinstance(shm_pool_info, dict):
self._shm_active = False
logger.info(
"Instance %s non-GPU context using pickle transport "
"(no SHM pool info returned)",
payload.instance_id,
)
return RegisterNonGpuContextResponse()
shm_name = str(shm_pool_info.get("shm_name", ""))
pool_size = int(shm_pool_info.get("pool_size", 0))
self._shm_active = bool(shm_name) and pool_size > 0
response = RegisterNonGpuContextResponse(
shm_name=str(shm_pool_info.get("shm_name", "")),
pool_size=int(shm_pool_info.get("pool_size", 0)),
shm_name=shm_name,
pool_size=pool_size,
)
if response.shm_name and response.pool_size > 0:
if self._shm_active:
logger.info(
"Instance %s non-GPU context using SHM transport "
"(shm_name=%s, pool_size=%d)",
Expand All @@ -425,28 +449,12 @@ def register_kv_cache_non_gpu_context(
@staticmethod
def _make_non_gpu_transfer_key(
key: IPCCacheEngineKey, instance_id: int
) -> tuple[object, ...]:
) -> tuple[int, IPCCacheEngineKey]:
"""Build a unique key for pending SHM write/read transfer tracking."""
return (
instance_id,
key.model_name,
key.world_size,
key.worker_id,
key.token_ids,
key.start,
key.end,
key.request_id,
key.cache_salt,
)
return (instance_id, key)

def _is_shm_active(self) -> bool:
shm_pool_info = self.storage_manager.get_shm_pool_info()
if not isinstance(shm_pool_info, dict):
return False
return (
bool(shm_pool_info.get("shm_name"))
and int(shm_pool_info.get("pool_size", 0)) > 0
)
return self._shm_active

def _resolve_obj_keys(self, key: IPCCacheEngineKey) -> list[ObjectKey]:
"""Resolve object keys from an IPC cache key.
Expand Down Expand Up @@ -512,8 +520,11 @@ def prepare_store(
}
)
reserved_keys.append(obj_key)
if not reserved_keys:
return PrepareStoreResponse(context={})
transfer_key = self._make_non_gpu_transfer_key(key, instance_id)
self._pending_shm_writes[transfer_key] = reserved_keys
with self._pending_shm_lock:
self._pending_shm_writes[transfer_key] = reserved_keys
return PrepareStoreResponse(context={"slots": slots})

@_lmcache_nvtx_annotate
Expand All @@ -537,7 +548,8 @@ def commit_store(
"""
if cpu_data == b"" and self._is_shm_active():
transfer_key = self._make_non_gpu_transfer_key(key, instance_id)
reserved_keys = self._pending_shm_writes.pop(transfer_key, None)
with self._pending_shm_lock:
reserved_keys = self._pending_shm_writes.pop(transfer_key, None)
# Missing transfer key means COMMIT arrived without matching PREPARE.
if reserved_keys is None:
return False
Expand Down Expand Up @@ -604,10 +616,16 @@ def prepare_retrieve(
)

if self._is_shm_active():
# Precondition: read locks for these keys were acquired during the
# lookup/prefetch phase before this retrieve call.
shm_prefetched_keys, shm_memory_objs = self.storage_manager.unsafe_read(
obj_keys
)
if not shm_memory_objs or len(shm_prefetched_keys) != len(obj_keys):
if (
not shm_memory_objs
or len(shm_prefetched_keys) != len(obj_keys)
or len(shm_memory_objs) != len(obj_keys)
):
if shm_prefetched_keys:
self.storage_manager.finish_read_prefetched(shm_prefetched_keys)
return PrepareRetrieveResponse(success=False, data=b"", context={})
Expand All @@ -625,7 +643,8 @@ def prepare_retrieve(
}
)
transfer_key = self._make_non_gpu_transfer_key(key, instance_id)
self._pending_shm_reads[transfer_key] = shm_prefetched_keys
with self._pending_shm_lock:
self._pending_shm_reads[transfer_key] = shm_prefetched_keys
return PrepareRetrieveResponse(
success=True, data=b"", context={"slots": slots}
)
Expand Down Expand Up @@ -668,7 +687,8 @@ def commit_retrieve(
"""
if self._is_shm_active():
transfer_key = self._make_non_gpu_transfer_key(key, instance_id)
prefetched_keys = self._pending_shm_reads.pop(transfer_key, [])
with self._pending_shm_lock:
prefetched_keys = self._pending_shm_reads.pop(transfer_key, [])
if prefetched_keys:
self.storage_manager.finish_read_prefetched(prefetched_keys)
return True
Expand Down
68 changes: 67 additions & 1 deletion tests/v1/distributed/test_shm_l1_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,19 @@
import os

# Third Party
import pytest
import torch

# First Party
from lmcache.v1.distributed.api import MemoryLayoutDesc
from lmcache.v1.distributed.config import L1MemoryManagerConfig
from lmcache.v1.distributed.memory_manager import create_memory_allocator
from lmcache.v1.memory_management import MixedMemoryAllocator
from lmcache.v1.multiprocess.non_gpu_context import NonGpuContextMetadata
from lmcache.v1.multiprocess.non_gpu_context import (
NonGpuContextMetadata,
create_non_gpu_context,
)
from lmcache.v1.multiprocess.non_gpu_context_pickle import NonGpuContextPickle
from lmcache.v1.multiprocess.non_gpu_context_shm import NonGpuContextShm
from lmcache.v1.multiprocess.protocol import RequestType
from lmcache.v1.multiprocess.protocols.engine import (
Expand Down Expand Up @@ -159,3 +164,64 @@ def _submit_request(req_type, payload, response_cls): # noqa: ARG001
context.close()
if os.path.exists(shm_path):
os.unlink(shm_path)


def test_non_gpu_context_shm_init_raises_when_segment_missing() -> None:
with pytest.raises(FileNotFoundError, match="No such file or directory"):
NonGpuContextShm(
metadata=NonGpuContextMetadata(
layout_desc=MemoryLayoutDesc(
shapes=[torch.Size([2, 2])],
dtypes=[torch.float32],
),
block_size=1,
use_mla=False,
),
mq_client=MagicMock(),
mq_timeout=1.0,
shm_name="lmcache_missing_shm_segment",
pool_size=4096,
)


def test_create_non_gpu_context_falls_back_to_pickle_without_shm_info() -> None:
context = create_non_gpu_context(
metadata=NonGpuContextMetadata(
layout_desc=MemoryLayoutDesc(
shapes=[torch.Size([2, 2])],
dtypes=[torch.float32],
),
block_size=1,
use_mla=False,
),
mq_client=MagicMock(),
mq_timeout=1.0,
shm_name="",
pool_size=0,
)
assert isinstance(context, NonGpuContextPickle)


def test_non_gpu_context_shm_close_is_idempotent() -> None:
shm_name = f"lmcache_test_close_{os.getpid()}"
shm_path = _create_shm_file(shm_name, 4096)
try:
context = NonGpuContextShm(
metadata=NonGpuContextMetadata(
layout_desc=MemoryLayoutDesc(
shapes=[torch.Size([2, 2])],
dtypes=[torch.float32],
),
block_size=1,
use_mla=False,
),
mq_client=MagicMock(),
mq_timeout=1.0,
shm_name=shm_name,
pool_size=4096,
)
context.close()
context.close()
finally:
if os.path.exists(shm_path):
os.unlink(shm_path)
Loading