Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion lmcache/v1/multiprocess/modules/non_gpu_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
MemoryLayoutDesc,
ObjectKey,
)
from lmcache.v1.multiprocess.config import MPServerConfig
from lmcache.v1.multiprocess.custom_types import (
IPCCacheEngineKey,
RegisterNonGpuContextPayload,
Expand Down Expand Up @@ -72,10 +73,18 @@ class NonGPUTransferModule:

Args:
ctx: The shared engine context.
mp_config: Optional MP server config carrying non-GPU SHM overrides.
"""

def __init__(self, ctx: MPCacheEngineContext) -> None:
def __init__(
self,
ctx: MPCacheEngineContext,
mp_config: MPServerConfig | None = None,
) -> None:
self._ctx = ctx
self._shm_name_override = (
mp_config.shm_name if mp_config is not None else None
)
self._non_gpu_contexts: dict[int, NonGPUContextEntry] = {}
self._strategies: dict[int, TransferStrategy] = {}
self._pending_shm_writes: dict[
Expand Down Expand Up @@ -162,6 +171,17 @@ def close(self) -> None:

def _compute_shm_pool_info(self) -> ShmPoolInfo:
"""Compute SHM pool info from storage manager config."""
if self._shm_name_override is not None:
shm_name = self._shm_name_override
if not shm_name:
return {"shm_name": "", "pool_size": 0}
sm_config = self._ctx.storage_manager_config
mem_cfg = sm_config.l1_manager_config.memory_config
bare = shm_name.lstrip("/")
if not bare.startswith("lmcache_l1_pool_"):
shm_name = f"lmcache_l1_pool_{bare}"
return {"shm_name": shm_name, "pool_size": mem_cfg.size_in_bytes}

sm_config = self._ctx.storage_manager_config
mem_cfg = sm_config.l1_manager_config.memory_config
shm_name = mem_cfg.shm_name or ""
Expand Down
10 changes: 9 additions & 1 deletion lmcache/v1/multiprocess/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def _build_modules(
if mp_config.transfer_mode == "gpu":
modules.append(GPUTransferModule(ctx))
else:
modules.append(NonGPUTransferModule(ctx))
modules.append(NonGPUTransferModule(ctx, mp_config))

if mp_config.engine_type == "blend":
if mp_config.transfer_mode != "gpu":
Expand Down Expand Up @@ -191,6 +191,14 @@ def run_cache_server(

maybe_initialize_trace_recorder(event_bus, obs_config, storage_manager_config)

# Apply shm_name override from MP config so the allocator creates the
# segment with the user-specified name (must happen before StorageManager
# is instantiated inside MPCacheEngineContext).
if mp_config.shm_name is not None:
storage_manager_config.l1_manager_config.memory_config.shm_name = (
mp_config.shm_name
)

ctx = MPCacheEngineContext(
storage_manager_config=storage_manager_config,
chunk_size=mp_config.chunk_size,
Expand Down
88 changes: 88 additions & 0 deletions tests/v1/multiprocess/test_non_cuda_data_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,94 @@ def test_server_register_and_find_non_cuda_context_layout(
assert layout.shapes[0] == torch.Size([2, 2, 16, 16])


def test_build_modules_passes_empty_shm_override_to_non_gpu_module(
    stub_native_storage_ops: Any,
) -> None:
    """Check that an empty MP-level ``shm_name`` turns SHM off in non-GPU mode.

    The mocked storage config advertises ``"storage_default"`` with a
    4096-byte pool, while the MP server config passes ``shm_name=""``.
    The registration response is expected to carry an empty SHM name and
    a pool size of zero.

    Args:
        stub_native_storage_ops: Not referenced in the body; presumably a
            pytest fixture requested for its side effects -- confirm
            against the fixture definition.
    """
    # First Party
    from lmcache.v1.multiprocess.config import MPServerConfig
    from lmcache.v1.multiprocess.custom_types import RegisterNonGpuContextPayload
    from lmcache.v1.multiprocess.engine_context import MPCacheEngineContext
    from lmcache.v1.multiprocess.modules.non_gpu_transfer import NonGPUTransferModule
    from lmcache.v1.multiprocess.server import _build_modules

    # Storage-level settings that the MP-level override should win over.
    sm_config = MagicMock()
    memory_cfg = sm_config.l1_manager_config.memory_config
    memory_cfg.shm_name = "storage_default"
    memory_cfg.use_lazy = False
    memory_cfg.size_in_bytes = 4096

    # Replace the context's collaborators so building it needs no real backend.
    with (
        patch("lmcache.v1.multiprocess.engine_context.StorageManager"),
        patch("lmcache.v1.multiprocess.engine_context.TokenHasher"),
        patch("lmcache.v1.multiprocess.engine_context.SessionManager"),
        patch("lmcache.v1.multiprocess.engine_context.get_event_bus"),
    ):
        ctx = MPCacheEngineContext(storage_manager_config=sm_config, chunk_size=16)

    mp_config = MPServerConfig(transfer_mode="non_gpu", shm_name="")
    built = _build_modules(ctx, mp_config)
    module = next(
        candidate for candidate in built if isinstance(candidate, NonGPUTransferModule)
    )

    payload = RegisterNonGpuContextPayload(
        instance_id=3,
        model_name="m",
        world_size=1,
        block_size=4,
        num_layers=2,
        hidden_dim_size=16,
        dtype_str="float32",
        use_mla=False,
    )
    response = module.register_kv_cache_non_gpu_context(payload)

    assert response.shm_name == ""
    assert response.pool_size == 0


def test_non_gpu_transfer_module_uses_mp_config_shm_name_override(
    stub_native_storage_ops: Any,
) -> None:
    """Check that an MP-level ``shm_name`` is normalized and sent to workers.

    The mocked storage config has no SHM name of its own and an 8192-byte
    pool. With ``shm_name="worker_pool"`` in the MP server config, the
    registration response is expected to report
    ``"lmcache_l1_pool_worker_pool"`` together with the 8192-byte pool size.

    Args:
        stub_native_storage_ops: Not referenced in the body; presumably a
            pytest fixture requested for its side effects -- confirm
            against the fixture definition.
    """
    # First Party
    from lmcache.v1.multiprocess.config import MPServerConfig
    from lmcache.v1.multiprocess.custom_types import RegisterNonGpuContextPayload
    from lmcache.v1.multiprocess.engine_context import MPCacheEngineContext
    from lmcache.v1.multiprocess.modules.non_gpu_transfer import NonGPUTransferModule

    # The storage config provides only the pool size; the name comes from MP config.
    sm_config = MagicMock()
    memory_cfg = sm_config.l1_manager_config.memory_config
    memory_cfg.shm_name = None
    memory_cfg.use_lazy = True
    memory_cfg.size_in_bytes = 8192

    # Replace the context's collaborators so building it needs no real backend.
    with (
        patch("lmcache.v1.multiprocess.engine_context.StorageManager"),
        patch("lmcache.v1.multiprocess.engine_context.TokenHasher"),
        patch("lmcache.v1.multiprocess.engine_context.SessionManager"),
        patch("lmcache.v1.multiprocess.engine_context.get_event_bus"),
    ):
        ctx = MPCacheEngineContext(storage_manager_config=sm_config, chunk_size=16)

    mp_config = MPServerConfig(shm_name="worker_pool")
    module = NonGPUTransferModule(ctx, mp_config)

    payload = RegisterNonGpuContextPayload(
        instance_id=4,
        model_name="m",
        world_size=1,
        block_size=4,
        num_layers=2,
        hidden_dim_size=16,
        dtype_str="float32",
        use_mla=False,
    )
    response = module.register_kv_cache_non_gpu_context(payload)

    assert response.shm_name == "lmcache_l1_pool_worker_pool"
    assert response.pool_size == 8192


def test_server_store_and_retrieve_cpu_chunks(stub_native_storage_ops: Any) -> None:
"""Validate mocked server-side CPU chunk store and retrieve behavior."""
# First Party
Expand Down
Loading