Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 82 additions & 42 deletions lmcache/v1/storage_backend/pd_backend_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,17 +853,49 @@ async def _async_transfer_task(
alloc_response = await self._async_remote_allocate(
receiver_id, alloc_request
)
already_sent_indexes = set(alloc_response.already_sent_indexes)
remote_indexes = alloc_response.remote_indexes

num_keys = len(keys)
if already_sent_indexes:
if (
min(already_sent_indexes) < 0
or max(already_sent_indexes) >= num_keys
):
raise RuntimeError(
f"Invalid already_sent_indexes from receiver: "
f"{alloc_response.already_sent_indexes}, "
f"valid range [0, {num_keys})"
)

expected_send_count = num_keys - len(already_sent_indexes)
if len(remote_indexes) != expected_send_count:
raise RuntimeError(
f"AllocResponse inconsistency: total_keys={num_keys}, "
f"already_sent={len(already_sent_indexes)}, "
f"remote_indexes={len(remote_indexes)}, "
f"expected={expected_send_count}"
)

mem_objs_to_send: list[MemoryObj] = []
keys_to_send: list[CacheEngineKey] = []
for idx, (key, mem_obj) in enumerate(zip(keys, memory_objs, strict=True)):
if idx in already_sent_indexes:
mem_obj.ref_count_down()
completed_indexes.add(idx)
else:
mem_objs_to_send.append(mem_obj)
keys_to_send.append(key)

# Abort if any remote slot failed to allocate.
for idx, (mem_obj, remote_addr) in enumerate(
zip(memory_objs, remote_indexes, strict=True)
zip(mem_objs_to_send, remote_indexes, strict=True)
):
if remote_addr == -1:
logger.warning(
"Receiver allocation failed for key %s (idx=%d), "
"aborting entire request.",
keys[idx],
keys_to_send[idx],
idx,
)
for j, mo in enumerate(memory_objs):
Expand All @@ -874,32 +906,33 @@ async def _async_transfer_task(
await self._abort_request(req_id)
return

if memory_objs:
if mem_objs_to_send:
channel_transfer_spec = {
"receiver_id": receiver_id,
"remote_indexes": remote_indexes,
}

# Track sent keys for abort cleanup.
# Track all keys (including deduped) for abort cleanup.
if req_id:
sent = self._sent_keys.setdefault(req_id, [])
sent.extend(k.to_string() for k in keys)

await self.transfer_channel.async_batched_write(
objects=memory_objs,
objects=mem_objs_to_send,
transfer_spec=channel_transfer_spec,
)
for idx, mem_obj in enumerate(memory_objs):
if idx not in completed_indexes:
before = mem_obj.get_ref_count()
mem_obj.ref_count_down()
logger.debug(
"[SENDER] chunk %d ref_count: %d -> %d",
idx,
before,
before - 1,
)
completed_indexes.add(idx)
if idx in completed_indexes:
continue
before = mem_obj.get_ref_count()
mem_obj.ref_count_down()
logger.debug(
"[SENDER] chunk %d ref_count: %d -> %d",
idx,
before,
before - 1,
)
completed_indexes.add(idx)
logger.debug(
"[SENDER] req=%s batch done, freed %d chunks, free_chunks=%d",
req_id,
Expand Down Expand Up @@ -1314,11 +1347,19 @@ async def _async_allocate_and_put(
shape = list(alloc_request.shape)

alloc_indexes: list[int] = []
already_sent_indexes: list[int] = []
current_batch_keys: list[str] = []

try:
for idx, key_str in enumerate(alloc_request.keys):
key = CacheEngineKey.from_string(key_str)
with self.data_lock:
if key in self.data:
# Pin existing object so concurrent remove() cannot
# delete it before the deduped consumer retrieves it.
self.data[key].ref_count_up()
already_sent_indexes.append(idx)
continue

if idx == total_allocs - 1:
token_dim = fmt.token_dim()
Expand Down Expand Up @@ -1412,31 +1453,25 @@ async def _async_allocate_and_put(
self._req_allocated_keys.pop(req_id, None)
await self._recv_reservation_mgr.async_release_reservation(req_id)

return AllocResponse(remote_indexes=alloc_indexes)
return AllocResponse(
remote_indexes=alloc_indexes, already_sent_indexes=already_sent_indexes
)

def put(
self,
key: CacheEngineKey,
mem_obj: MemoryObj,
) -> None:
"""Store a memory object in the local data dictionary.

If a memory object already exists for the given key, the old object is
released (ref_count_down) to prevent memory leaks.

:param key: The cache engine key to associate with the memory object.
:param mem_obj: The memory object to store.
"""
"""Store a memory object in the local data dictionary."""
with self.data_lock:
old = self.data.pop(key, None)
if old is not None:
logger.debug(
"Overwriting existing MemoryObj for key %s in "
"PDBackendAsync.put(). "
"Releasing old object to prevent memory leak.",
if key in self.data:
logger.info(
"Duplicate put for key %s in PDBackendAsync.put(); "
"dropping new object.",
key,
)
old.ref_count_down()
mem_obj.ref_count_down()
return
self.data[key] = mem_obj

def get_blocking(self, key: CacheEngineKey) -> Optional[MemoryObj]:
Expand Down Expand Up @@ -1468,25 +1503,30 @@ def remove(
:param key: The key to remove.
"""
with self.data_lock:
mem_obj = self.data.pop(key, None)
mem_obj = self.data.get(key, None)
if mem_obj is not None:
logger.debug(
"[PD-FREE] remove key=%s, addr=%d, ref_count=%d, "
"data_size=%d, free_chunks_before=%d",
key,
mem_obj.meta.address,
mem_obj.get_ref_count(),
len(self.data),
self._get_free_chunks(),
)
before_rc = mem_obj.get_ref_count()
mem_obj.ref_count_down()
deleted = False
if mem_obj.get_ref_count() == 0:
del self.data[key]
logger.debug(
"[PD-FREE] remove key=%s, addr=%d, ref_count_before=%d, "
"data_size=%d, free_chunks_before=%d",
key,
mem_obj.meta.address,
before_rc,
len(self.data),
self._get_free_chunks(),
)
deleted = True
# Notify any coroutines blocked waiting for free memory.
# _alloc_freed_condition and _recv_loop only exist on the
# receiver; remove() is also called on the sender, so the
# hasattr guards are intentional. run_coroutine_threadsafe is
# used because remove() may be called from any OS thread while
# the receiver event loop runs on a dedicated thread.
if hasattr(self, "_alloc_freed_condition") and hasattr(
if deleted and hasattr(self, "_alloc_freed_condition") and hasattr(
self, "_recv_loop"
):
loop = self._recv_loop
Expand Down
Loading
Loading