Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions lmcache/v1/multiprocess/adapter_connector/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
from lmcache.logging import init_logger
from lmcache.utils import EngineType
from lmcache.v1.distributed.api import MemoryLayoutDesc
from lmcache.v1.gpu_connector.utils import LayoutHints
from lmcache.v1.multiprocess.custom_types import IPCCacheEngineKey
from lmcache.v1.multiprocess.mq import MessageQueueClient

logger = init_logger(__name__)

Expand Down Expand Up @@ -59,7 +62,7 @@ class NonGpuContext(ABC):
def __init__(
self,
metadata: NonGpuContextMetadata,
mq_client: Any,
mq_client: MessageQueueClient,
mq_timeout: float,
) -> None:
self.metadata = metadata
Expand All @@ -73,7 +76,7 @@ def layout_desc(self) -> MemoryLayoutDesc:

@abstractmethod
def prepare_store(
self, key: Any, instance_id: int
self, key: IPCCacheEngineKey, instance_id: int
) -> tuple[list[torch.Tensor], list[int]] | None:
"""Prepare SHM buffers for a store operation.

Expand All @@ -95,18 +98,20 @@ def prepare_store(

@abstractmethod
def commit_store(
self, key: Any, instance_id: int, chunks: list[torch.Tensor]
self, key: IPCCacheEngineKey, instance_id: int, chunks: list[torch.Tensor]
) -> bool:
"""Commit store. Pickle: serialize and send. Shm: notify server."""
...

@abstractmethod
def prepare_retrieve(self, key: Any, instance_id: int) -> list[torch.Tensor] | None:
def prepare_retrieve(
self, key: IPCCacheEngineKey, instance_id: int
) -> list[torch.Tensor] | None:
"""Prepare retrieve. Returns chunks or shm views, or None on miss."""
...

@abstractmethod
def commit_retrieve(self, key: Any, instance_id: int) -> bool:
def commit_retrieve(self, key: IPCCacheEngineKey, instance_id: int) -> bool:
"""Commit retrieve. Pickle: no-op. Shm: release read locks."""
...

Expand All @@ -118,7 +123,7 @@ def close(self) -> None:

def create_non_gpu_context(
metadata: NonGpuContextMetadata,
mq_client: Any,
mq_client: MessageQueueClient,
mq_timeout: float,
shm_name: str = "",
pool_size: int = 0,
Expand Down Expand Up @@ -175,7 +180,7 @@ def create_non_gpu_context(

def compute_kv_layout(
kv_caches: dict[str, torch.Tensor],
layout_hints: Any | None = None,
layout_hints: LayoutHints | None = None,
) -> tuple[int, int, int, str, Any]:
"""Compute KV layout metadata from KV tensors.

Expand Down Expand Up @@ -216,7 +221,7 @@ def gather_paged_kv_to_cpu(
kv_caches: dict[str, torch.Tensor],
block_ids: list[int],
blocks_per_chunk: int,
layout_hints: Any | None = None,
layout_hints: LayoutHints | None = None,
gpu_kv_format: Any | None = None,
out: list[torch.Tensor] | None = None,
chunk_indices: list[int] | None = None,
Expand Down Expand Up @@ -360,7 +365,7 @@ def scatter_cpu_to_paged_kv(
chunks: list[torch.Tensor],
blocks_per_chunk: int,
skip_first_n_tokens: int = 0,
layout_hints: Any | None = None,
layout_hints: LayoutHints | None = None,
gpu_kv_format: Any | None = None,
) -> None:
"""Scatter CPU chunk tensors back into paged KV tensors.
Expand Down
15 changes: 9 additions & 6 deletions lmcache/v1/multiprocess/adapter_connector/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
"""Pickle-based NonGpuContext implementation for multiprocess mode."""

# Standard
from typing import Any
import pickle

# Third Party
Expand All @@ -13,6 +12,8 @@
NonGpuContext,
NonGpuContextMetadata,
)
from lmcache.v1.multiprocess.custom_types import IPCCacheEngineKey
from lmcache.v1.multiprocess.mq import MessageQueueClient
from lmcache.v1.multiprocess.protocol import RequestType, get_response_class


Expand All @@ -31,13 +32,13 @@ class NonGpuContextPickle(NonGpuContext):
def __init__(
self,
metadata: NonGpuContextMetadata,
mq_client: Any,
mq_client: MessageQueueClient,
mq_timeout: float,
) -> None:
super().__init__(metadata, mq_client, mq_timeout)

def prepare_store(
self, key: Any, instance_id: int
self, key: IPCCacheEngineKey, instance_id: int
) -> tuple[list[torch.Tensor], list[int]] | None:
"""Send PREPARE_STORE RPC. For pickle, returns no pre-allocated buffers."""
future = self.mq_client.submit_request(
Expand All @@ -52,7 +53,7 @@ def prepare_store(
return None

def commit_store(
self, key: Any, instance_id: int, chunks: list[torch.Tensor]
self, key: IPCCacheEngineKey, instance_id: int, chunks: list[torch.Tensor]
) -> bool:
"""Serialize chunks and send via COMMIT_STORE.

Expand All @@ -70,7 +71,9 @@ def commit_store(
except TimeoutError:
return False

def prepare_retrieve(self, key: Any, instance_id: int) -> list[torch.Tensor] | None:
def prepare_retrieve(
self, key: IPCCacheEngineKey, instance_id: int
) -> list[torch.Tensor] | None:
"""Send PREPARE_RETRIEVE and deserialize the response data.

Returns:
Expand All @@ -90,7 +93,7 @@ def prepare_retrieve(self, key: Any, instance_id: int) -> list[torch.Tensor] | N
chunks: list[torch.Tensor] = pickle.loads(response.data)
return chunks

def commit_retrieve(self, key: Any, instance_id: int) -> bool:
def commit_retrieve(self, key: IPCCacheEngineKey, instance_id: int) -> bool:
"""Send COMMIT_RETRIEVE (no-op for pickle path)."""
future = self.mq_client.submit_request(
RequestType.COMMIT_RETRIEVE,
Expand Down
14 changes: 9 additions & 5 deletions lmcache/v1/multiprocess/adapter_connector/shm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
NonGpuContext,
NonGpuContextMetadata,
)
from lmcache.v1.multiprocess.custom_types import IPCCacheEngineKey
from lmcache.v1.multiprocess.mq import MessageQueueClient
from lmcache.v1.multiprocess.protocol import RequestType, get_response_class


Expand Down Expand Up @@ -77,7 +79,7 @@ class NonGpuContextShm(NonGpuContext):
def __init__(
self,
metadata: NonGpuContextMetadata,
mq_client: Any,
mq_client: MessageQueueClient,
mq_timeout: float,
shm_name: str,
pool_size: int,
Expand Down Expand Up @@ -141,7 +143,7 @@ def _build_slot_tensors(self, slots: list[dict[str, Any]]) -> list[torch.Tensor]
]

def prepare_store(
self, key: Any, instance_id: int
self, key: IPCCacheEngineKey, instance_id: int
) -> tuple[list[torch.Tensor], list[int]] | None:
future = self.mq_client.submit_request(
RequestType.PREPARE_STORE,
Expand All @@ -166,7 +168,7 @@ def prepare_store(
return self._build_slot_tensors(slots), chunk_indices

def commit_store(
self, key: Any, instance_id: int, _chunks: list[torch.Tensor]
self, key: IPCCacheEngineKey, instance_id: int, _chunks: list[torch.Tensor]
) -> bool:
future = self.mq_client.submit_request(
RequestType.COMMIT_STORE,
Expand All @@ -178,7 +180,9 @@ def commit_store(
except TimeoutError:
return False

def prepare_retrieve(self, key: Any, instance_id: int) -> list[torch.Tensor] | None:
def prepare_retrieve(
self, key: IPCCacheEngineKey, instance_id: int
) -> list[torch.Tensor] | None:
future = self.mq_client.submit_request(
RequestType.PREPARE_RETRIEVE,
[key, instance_id],
Expand All @@ -193,7 +197,7 @@ def prepare_retrieve(self, key: Any, instance_id: int) -> list[torch.Tensor] | N
slots = response.context.get("slots", [])
return self._build_slot_tensors(slots) if slots else None

def commit_retrieve(self, key: Any, instance_id: int) -> bool:
def commit_retrieve(self, key: IPCCacheEngineKey, instance_id: int) -> bool:
future = self.mq_client.submit_request(
RequestType.COMMIT_RETRIEVE,
[key, instance_id],
Expand Down
9 changes: 5 additions & 4 deletions tests/v1/multiprocess/test_non_cuda_data_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,21 +982,22 @@ def _submit_request(req_type, payload, response_cls): # noqa: ARG001
pool_size=4096,
)
try:
store_result = context.prepare_store(key="k", instance_id=1)
key = _default_key()
store_result = context.prepare_store(key=key, instance_id=1)
assert store_result is not None
store_views, _ = store_result
store_views[0].copy_(
torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32)
)
assert context.commit_store("k", 1, store_views)
assert context.commit_store(key, 1, store_views)

retrieve_views = context.prepare_retrieve(key="k", instance_id=1)
retrieve_views = context.prepare_retrieve(key=key, instance_id=1)
assert retrieve_views is not None
assert torch.equal(
retrieve_views[0],
torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32),
)
assert context.commit_retrieve("k", 1)
assert context.commit_retrieve(key, 1)
finally:
context.close()
if os.path.exists(shm_path):
Expand Down