Skip to content
4 changes: 4 additions & 0 deletions lmcache/v1/distributed/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from dataclasses import dataclass, field
from typing import Literal
import argparse
import os

# First Party
from lmcache import torch_dev
Expand Down Expand Up @@ -39,6 +40,9 @@ class L1MemoryManagerConfig:
align_bytes: int = field(default=0x1000)
""" The alignment size in bytes. Default is 4KB. """

shm_name: str = field(default_factory=lambda: f"lmcache_l1_pool_{os.getpid()}")
""" POSIX shared-memory segment name for L1 pool. Empty disables SHM. """

def __post_init__(self):
self.init_size_in_bytes = min(self.init_size_in_bytes, self.size_in_bytes)

Expand Down
4 changes: 4 additions & 0 deletions lmcache/v1/distributed/l1_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,10 @@ def get_l1_memory_desc(self):
"""Return an L1MemoryDesc describing the underlying L1 memory buffer."""
return self._memory_manager.get_l1_memory_desc()

def get_shm_pool_info(self) -> dict:
    """Return the SHM pool metadata reported by the underlying memory manager.

    This is a pure pass-through: the value produced by
    ``self._memory_manager.get_shm_pool_info()`` is handed back unchanged.
    """
    pool_info = self._memory_manager.get_shm_pool_info()
    return pool_info

def close(self) -> None:
"""Close the L1Manager and free all resources."""
with self._lock:
Expand Down
54 changes: 53 additions & 1 deletion lmcache/v1/distributed/memory_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# SPDX-License-Identifier: Apache-2.0

# Standard
import os
import shutil

# First Party
from lmcache.logging import init_logger
from lmcache.v1.distributed.api import MemoryLayoutDesc
Expand All @@ -16,7 +20,23 @@
logger = init_logger(__name__)


# HELPER FUNCTIONS
def _unlink_stale_shm(shm_name: str) -> None:
"""Remove a stale LMCache shm segment if it exists."""
normalized = shm_name.lstrip("/")
if "/" in normalized or "\\" in normalized:
logger.warning("Refusing to unlink invalid shm name %s", shm_name)
return
if not normalized.startswith("lmcache_l1_pool_"):
return
shm_path = os.path.join("/dev/shm", normalized)
try:
os.unlink(shm_path)
except FileNotFoundError:
return
except OSError:
logger.warning("Failed to remove stale shm segment %s", shm_path, exc_info=True)


def create_memory_allocator(config: L1MemoryManagerConfig) -> MemoryAllocatorInterface:
"""
Create a memory allocator based on the provided configuration.
Expand Down Expand Up @@ -45,6 +65,27 @@ def create_memory_allocator(config: L1MemoryManagerConfig) -> MemoryAllocatorInt
config.size_in_bytes,
config.align_bytes,
)
shm_name = config.shm_name
if shm_name:
try:
free_bytes = shutil.disk_usage("/dev/shm").free
if free_bytes < config.size_in_bytes:
raise RuntimeError(
"insufficient /dev/shm capacity: "
f"need {config.size_in_bytes} bytes, have {free_bytes} bytes"
)
_unlink_stale_shm(shm_name)
return MixedMemoryAllocator(
config.size_in_bytes,
align_bytes=config.align_bytes,
shm_name=shm_name,
)
except (RuntimeError, OSError, ValueError):
logger.warning(
"Failed to initialize SHM pool (%s), falling back to pickle path",
shm_name,
exc_info=True,
)
return MixedMemoryAllocator(
config.size_in_bytes,
align_bytes=config.align_bytes,
Expand All @@ -65,6 +106,13 @@ def __init__(self, config: L1MemoryManagerConfig):
self._allocator = create_memory_allocator(config)
self._size_in_bytes = config.size_in_bytes
self._align_bytes = config.align_bytes
self._shm_pool_info = {"shm_name": "", "pool_size": 0}
if isinstance(self._allocator, MixedMemoryAllocator):
if self._allocator.shm_name:
self._shm_pool_info = {
"shm_name": self._allocator.shm_name,
"pool_size": self._size_in_bytes,
}

def allocate(
self, layout_desc: MemoryLayoutDesc, count: int
Expand Down Expand Up @@ -174,6 +222,10 @@ def close(self) -> None:
"""
self._allocator.close()

def get_shm_pool_info(self) -> dict:
    """Return SHM pool metadata for non-GPU SHM transport.

    The result is a fresh shallow copy of ``self._shm_pool_info``, so callers
    can mutate it without affecting the manager's cached record.
    """
    snapshot = {**self._shm_pool_info}
    return snapshot

# Debugging APIs
def memcheck(self):
    """Run the allocator's ``memcheck`` and return whatever it reports."""
    report = self._allocator.memcheck()
    return report
19 changes: 19 additions & 0 deletions lmcache/v1/distributed/storage_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,25 @@ def touch_l1_keys(self, keys: list[ObjectKey]):
"""
self._l1_manager.touch_keys(keys)

def get_shm_pool_info(self) -> dict:
    """Return the SHM pool metadata exposed by the L1 manager.

    Delegates to ``self._l1_manager.get_shm_pool_info()`` and returns its
    result unchanged.
    """
    pool_info = self._l1_manager.get_shm_pool_info()
    return pool_info

def unsafe_read(
    self, keys: list[ObjectKey]
) -> tuple[list[ObjectKey], list[MemoryObj]]:
    """Read already read-locked objects without acquiring new read locks.

    Delegates to ``self._l1_manager.unsafe_read`` and keeps only the keys
    whose lookup reported ``L1Error.SUCCESS`` with a non-``None`` object.
    Keys absent from the lookup result are treated as
    ``L1Error.KEY_NOT_EXIST``.

    Args:
        keys: Keys to read, in the order results should be reported.

    Returns:
        Two parallel lists, the readable keys and their memory objects,
        in the same relative order as ``keys``.
    """
    lookup = self._l1_manager.unsafe_read(keys)
    not_found = (L1Error.KEY_NOT_EXIST, None)

    found_keys: list[ObjectKey] = []
    found_objs: list[MemoryObj] = []
    for key in keys:
        status, mem_obj = lookup.get(key, not_found)
        if status != L1Error.SUCCESS:
            continue
        if mem_obj is None:
            continue
        found_keys.append(key)
        found_objs.append(mem_obj)
    return found_keys, found_objs

@property
def quota_manager(self) -> QuotaManager:
"""Per-cache_salt quota registry.
Expand Down
10 changes: 10 additions & 0 deletions lmcache/v1/memory_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,16 @@ def get_num_tokens(self) -> int:
"""
raise NotImplementedError

@property
def shm_offset(self) -> int:
    """Byte offset of this object inside the SHM pool.

    Read directly from ``self.meta.address``.
    """
    offset = self.meta.address
    return offset

@property
def shm_byte_length(self) -> int:
    """Byte length of this object inside the SHM pool.

    Taken from ``self.get_size()``.
    """
    length = self.get_size()
    return length

@property
@abc.abstractmethod
def metadata(self) -> MemoryObjMetadata:
Expand Down
24 changes: 21 additions & 3 deletions lmcache/v1/multiprocess/non_gpu_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@
import torch

# First Party
from lmcache.logging import init_logger
from lmcache.utils import EngineType
from lmcache.v1.distributed.api import MemoryLayoutDesc

logger = init_logger(__name__)


@dataclass
class NonGpuContextMetadata:
Expand Down Expand Up @@ -100,24 +103,39 @@ def create_non_gpu_context(
metadata: NonGpuContextMetadata,
mq_client: Any,
mq_timeout: float,
shm_name: str = "",
pool_size: int = 0,
) -> NonGpuContext:
"""Factory that returns the appropriate :class:`NonGpuContext` implementation.

Currently always returns a pickle-based implementation
(``NonGpuContextPickle``). A future SHM-capable PR
may probe for shared-memory availability and fall back to pickle.
Returns SHM-based implementation when shared-memory pool information is
available; otherwise falls back to the pickle-based implementation.

Args:
metadata: Layout metadata for the non-GPU context.
mq_client: Message-queue client for server communication.
mq_timeout: Timeout in seconds for blocking MQ requests.
shm_name: Shared-memory segment name. Empty means pickle mode.
pool_size: Shared-memory pool size in bytes. Non-positive means pickle mode.

Returns:
A concrete :class:`NonGpuContext` instance.
"""
if shm_name and pool_size > 0:
# Local
from .non_gpu_context_shm import NonGpuContextShm

logger.info(
"Creating NonGpuContextShm (shm_name=%s, pool_size=%d)",
shm_name,
pool_size,
)
return NonGpuContextShm(metadata, mq_client, mq_timeout, shm_name, pool_size)

# Local
from .non_gpu_context_pickle import NonGpuContextPickle

logger.info("Creating NonGpuContextPickle (pickle transport)")
return NonGpuContextPickle(metadata, mq_client, mq_timeout)


Expand Down
145 changes: 145 additions & 0 deletions lmcache/v1/multiprocess/non_gpu_context_shm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# SPDX-License-Identifier: Apache-2.0
"""Shared-memory NonGpuContext implementation for multiprocess mode."""

# Standard
from typing import Any
import mmap
import os

# Third Party
import torch

# First Party
from lmcache.v1.multiprocess.non_gpu_context import (
NonGpuContext,
NonGpuContextMetadata,
)
from lmcache.v1.multiprocess.protocol import RequestType, get_response_class

INVALID_SHM_FD = -1


class NonGpuContextShm(NonGpuContext):
    """Shared-memory implementation of :class:`NonGpuContext`.

    The constructor opens ``/dev/shm/<shm_name>`` read-write and maps
    ``pool_size`` bytes of it. ``prepare_store`` and ``prepare_retrieve`` ask
    the server (through ``mq_client``) for slot descriptors and return tensors
    created with ``torch.frombuffer`` over that mapping. ``commit_store``
    sends an empty payload instead of the chunk data.
    """

    # Private marker returned by ``_shm_request`` when the wait timed out, so
    # callers can tell a timeout apart from any legitimate response value.
    _SHM_TIMED_OUT = object()

    def __init__(
        self,
        metadata: NonGpuContextMetadata,
        mq_client: Any,
        mq_timeout: float,
        shm_name: str,
        pool_size: int,
    ) -> None:
        """Open and map the shared-memory pool.

        Args:
            metadata: Layout metadata forwarded to the base class.
            mq_client: Message-queue client forwarded to the base class.
            mq_timeout: Timeout in seconds for blocking MQ requests.
            shm_name: Segment name; leading "/" characters are ignored.
            pool_size: Number of bytes to map from the segment.

        Raises:
            ValueError: If ``shm_name`` is empty or ``pool_size`` is not
                positive.
            OSError: If the segment cannot be opened. Errors raised by
                ``mmap.mmap`` propagate after the descriptor is closed.
        """
        super().__init__(metadata, mq_client, mq_timeout)
        if not shm_name or pool_size <= 0:
            raise ValueError("shm_name must be non-empty and pool_size must be > 0")

        self._shm_name = shm_name
        self._pool_size = pool_size
        # Set before os.open so close() is safe even if the open fails.
        self._shm_fd = INVALID_SHM_FD
        shm_path = os.path.join("/dev/shm", shm_name.lstrip("/"))
        self._shm_fd = os.open(shm_path, os.O_RDWR)
        try:
            self._mmap_obj = mmap.mmap(
                self._shm_fd, self._pool_size, access=mmap.ACCESS_WRITE
            )
        except Exception:
            # Do not leak the descriptor when mapping fails.
            os.close(self._shm_fd)
            self._shm_fd = INVALID_SHM_FD
            raise

    def _shm_request(self, request_type: Any, payload: list[Any]) -> Any:
        """Submit one MQ request and wait up to ``mq_timeout`` for the result.

        Only the wait is guarded; an exception raised by ``submit_request``
        itself propagates to the caller.

        Returns:
            The response object, or ``_SHM_TIMED_OUT`` if the wait timed out.
        """
        future = self.mq_client.submit_request(
            request_type,
            payload,
            get_response_class(request_type),
        )
        try:
            return future.result(timeout=self.mq_timeout)
        except TimeoutError:
            # NOTE(review): if ``future`` is a concurrent.futures.Future, its
            # timeout exception is only the builtin TimeoutError on
            # Python >= 3.11 - confirm the supported interpreter range.
            return self._SHM_TIMED_OUT

    @staticmethod
    def _context_dict(response: Any) -> dict[str, Any]:
        """Return ``response.context`` if it is a dict, else an empty dict."""
        context = response.context
        return context if isinstance(context, dict) else {}

    def _make_tensor_view(
        self,
        offset: int,
        length: int,
        shape: list[int],
        dtype_str: str,
    ) -> torch.Tensor:
        """Create a tensor view over a SHM slot via ``torch.frombuffer``.

        Args:
            offset: Byte offset of the slot inside the mapped pool.
            length: Slot length in bytes; the element count is
                ``length // itemsize``.
            shape: Shape the flat view is reshaped to.
            dtype_str: Attribute name of a ``torch.dtype`` (e.g. "float16").

        Returns:
            A tensor of ``shape`` backed by the mapped pool.

        Raises:
            ValueError: If ``dtype_str`` does not name a ``torch.dtype``.
        """
        dtype = getattr(torch, dtype_str, None)
        if dtype is None or not isinstance(dtype, torch.dtype):
            raise ValueError(f"Invalid torch dtype string: {dtype_str}")
        itemsize = torch.empty((), dtype=dtype).element_size()
        if itemsize <= 0:
            raise ValueError(f"Invalid dtype size for {dtype_str}")
        count = length // itemsize
        tensor_1d = torch.frombuffer(
            self._mmap_obj, dtype=dtype, count=count, offset=offset
        )
        return tensor_1d.view(torch.Size(shape))

    def _build_slot_tensors(self, slots: list[dict[str, Any]]) -> list[torch.Tensor]:
        """Turn slot descriptors into tensor views, one per slot.

        Each slot must provide "offset", "length", "shape" and "dtype".
        """
        return [
            self._make_tensor_view(
                offset=int(slot["offset"]),
                length=int(slot["length"]),
                shape=list(slot["shape"]),
                dtype_str=str(slot["dtype"]),
            )
            for slot in slots
        ]

    def prepare_store(self, key: Any, instance_id: int) -> list[torch.Tensor] | None:
        """Request writable slots for ``key``.

        Returns:
            Tensor views over the slots, or ``None`` on timeout or when the
            response context carries no non-empty list under "slots".
        """
        response = self._shm_request(RequestType.PREPARE_STORE, [key, instance_id])
        if response is self._SHM_TIMED_OUT:
            return None
        slots = self._context_dict(response).get("slots")
        if not isinstance(slots, list) or not slots:
            return None
        return self._build_slot_tensors(slots)

    def commit_store(
        self, key: Any, instance_id: int, _chunks: list[torch.Tensor]
    ) -> bool:
        """Commit a store for ``key``.

        ``_chunks`` is ignored and an empty bytes payload is sent in its
        place.

        Returns:
            The truthiness of the server response, or ``False`` on timeout.
        """
        response = self._shm_request(
            RequestType.COMMIT_STORE, [key, instance_id, b""]
        )
        if response is self._SHM_TIMED_OUT:
            return False
        return bool(response)

    def prepare_retrieve(self, key: Any, instance_id: int) -> list[torch.Tensor] | None:
        """Request readable slots for ``key``.

        Returns:
            Tensor views over the slots, or ``None`` on timeout, when the
            response is unsuccessful, or when it carries no slots.
        """
        response = self._shm_request(
            RequestType.PREPARE_RETRIEVE, [key, instance_id]
        )
        if response is self._SHM_TIMED_OUT:
            return None
        if not response.success:
            return None
        # Guard the context type the same way prepare_store does, so a
        # non-dict context yields None instead of an AttributeError.
        slots = self._context_dict(response).get("slots", [])
        return self._build_slot_tensors(slots) if slots else None

    def commit_retrieve(self, key: Any, instance_id: int) -> bool:
        """Commit a retrieve for ``key``.

        Returns:
            The truthiness of the server response, or ``False`` on timeout.
        """
        response = self._shm_request(
            RequestType.COMMIT_RETRIEVE, [key, instance_id]
        )
        if response is self._SHM_TIMED_OUT:
            return False
        return bool(response)

    def close(self) -> None:
        """Unmap the pool and close the descriptor.

        Safe to call more than once: it returns immediately when no
        descriptor is open. The descriptor is closed even if closing the
        mapping raises.
        """
        if self._shm_fd == INVALID_SHM_FD:
            return
        try:
            self._mmap_obj.close()
        finally:
            fd = self._shm_fd
            self._shm_fd = INVALID_SHM_FD
            os.close(fd)
12 changes: 10 additions & 2 deletions lmcache/v1/multiprocess/protocols/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ class PrepareRetrieveResponse:
) # pickle: {}, shm will put slot info here


@dataclass
class RegisterNonGpuContextResponse:
    """Reply payload for the REGISTER_KV_CACHE_NON_GPU_CONTEXT request.

    Both fields have defaults, so the response can be built with no
    arguments.
    """

    # Shared-memory segment name; defaults to the empty string.
    shm_name: str = ""
    # Shared-memory pool size; defaults to 0 (presumably bytes - TODO confirm).
    pool_size: int = 0


# Define request names for this protocol group
REQUEST_NAMES = [
"REGISTER_KV_CACHE",
Expand Down Expand Up @@ -179,10 +187,10 @@ def get_protocol_definitions() -> dict[str, ProtocolDefinition]:
# Register non-GPU KV cache context
# Payload:
# - RegisterNonGpuContextPayload - all metadata fields in one struct
# Returns: None
# Returns: RegisterNonGpuContextResponse
"REGISTER_KV_CACHE_NON_GPU_CONTEXT": ProtocolDefinition(
payload_classes=[RegisterNonGpuContextPayload],
response_class=None,
response_class=RegisterNonGpuContextResponse,
handler_type=HandlerType.SYNC,
),
"PREPARE_STORE": ProtocolDefinition(
Expand Down
Loading
Loading