From c276b334e5c0d5631ed29781856f105b4cae6917 Mon Sep 17 00:00:00 2001 From: sh <100367948+yeoshuheng@users.noreply.github.com> Date: Tue, 26 May 2026 21:48:38 +0800 Subject: [PATCH 1/5] [Doc] Add recipes for Mistral, Phi & Llama (#3270) * docs: add recipes for phi, mistral, llama Signed-off-by: yeoshuheng <100367948+yeoshuheng@users.noreply.github.com> * docs: update tool calling link Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: sh <100367948+yeoshuheng@users.noreply.github.com> * docs: update jinja template fp Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: sh <100367948+yeoshuheng@users.noreply.github.com> * docs: update jinja template fp Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: sh <100367948+yeoshuheng@users.noreply.github.com> --------- Signed-off-by: yeoshuheng <100367948+yeoshuheng@users.noreply.github.com> Signed-off-by: sh <100367948+yeoshuheng@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- docs/source/recipes/index.rst | 29 +++++++ docs/source/recipes/llama.rst | 132 ++++++++++++++++++++++++++++++++ docs/source/recipes/mixtral.rst | 92 ++++++++++++++++++++++ docs/source/recipes/phi3.rst | 90 ++++++++++++++++++++++ 4 files changed, 343 insertions(+) create mode 100644 docs/source/recipes/llama.rst create mode 100644 docs/source/recipes/mixtral.rst create mode 100644 docs/source/recipes/phi3.rst diff --git a/docs/source/recipes/index.rst b/docs/source/recipes/index.rst index 7d3a05eaa62..2491f188fee 100644 --- a/docs/source/recipes/index.rst +++ b/docs/source/recipes/index.rst @@ -42,30 +42,35 @@ Supported architectures - SGLang - TRT-LLM - Recipe + * - ``MiniMaxM2ForCausalLM`` - ``MiniMaxAI/MiniMax-M2`` - ✓ - — - — - :doc:`minimax_m2` + * - 
``Gemma4ForConditionalGeneration`` - ``google/gemma-4-31B-it`` - ✓ - — - — - :doc:`gemma4` + * - ``MistralForCausalLM`` - ``mistralai/Devstral-2-123B-Instruct-2512`` - ✓ - — - — - :doc:`devstral` + * - ``GptOssForCausalLM`` - ``openai/gpt-oss-120b`` - ✓ - — - — - :doc:`gpt_oss` + * - ``Qwen3MoeForCausalLM`` - ``Qwen/Qwen3-235B-A22B`` - ✓ @@ -73,6 +78,27 @@ Supported architectures - — - :doc:`qwen3` + * - ``LlamaForCausalLM`` + - ``meta-llama/Meta-Llama-3.1-70B-Instruct`` + - ✓ + - — + - — + - :doc:`llama` + + * - ``Phi3ForCausalLM`` + - ``microsoft/Phi-4-mini-instruct`` + - ✓ + - — + - — + - :doc:`phi3` + + * - ``MixtralForCausalLM`` + - ``mistralai/Mixtral-8x7B-Instruct-v0.1`` + - ✓ + - — + - — + - :doc:`mixtral` + Legend: ``✓`` validated, ``—`` not validated. Contributing a recipe @@ -96,3 +122,6 @@ To add a new architecture: devstral gpt_oss qwen3 + llama + phi3 + mixtral \ No newline at end of file diff --git a/docs/source/recipes/llama.rst b/docs/source/recipes/llama.rst new file mode 100644 index 00000000000..d7fafb5575b --- /dev/null +++ b/docs/source/recipes/llama.rst @@ -0,0 +1,132 @@ +.. _recipe_llama: + +LlamaForCausalLM +==================== + +Validated models +---------------- + +- `meta-llama/Meta-Llama-3.1-8B `_ +- `meta-llama/Meta-Llama-3.1-8B-Instruct `_ +- `meta-llama/Meta-Llama-3.1-70B `_ +- `meta-llama/Meta-Llama-3.1-70B-Instruct `_ + +.. tab-set:: + :sync-group: engine + + .. tab-item:: vLLM + + **Engine documentation:** + `LlamaForCausalLM in vLLM supported models + `_ + (architecture `LlamaForCausalLM`). + + **Status:** Validated with LMCache. + + Apply for access on the model card page and add your + `huggingface token `_ + as an environment variable: + + .. code-block:: bash + + export HUGGING_FACE_HUB_TOKEN=hf_xxxxxxxxxxxxxxxxx + + | + + Start the LMCache MP server: + + .. 
code-block:: bash + + lmcache server --l1-size-gb 100 --eviction-policy LRU + + | + + Get the chat templates for tool calling by following the `Llama tool calling guide `_ from vLLM. + + Start vLLM with the LMCache MP connector: + + **Meta-Llama-3.1-8B** (1 GPU): + + .. code-block:: bash + + vllm serve meta-llama/Meta-Llama-3.1-8B \ + --trust-remote-code \ + --kv-transfer-config \ + '{"kv_connector":"LMCacheMPConnector", "kv_role":"kv_both"}' + + | + + **Meta-Llama-3.1-8B-Instruct** (1 GPU): + + .. code-block:: bash + + vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \ + --trust-remote-code \ + --enable-auto-tool-choice \ + --tool-call-parser llama3_json \ + --chat-template \ + --kv-transfer-config \ + '{"kv_connector":"LMCacheMPConnector", "kv_role":"kv_both"}' + + | + + **Meta-Llama-3.1-70B** (4 GPUs): + + .. code-block:: bash + + vllm serve meta-llama/Meta-Llama-3.1-70B \ + --tensor-parallel-size 4 \ + --trust-remote-code \ + --kv-transfer-config \ + '{"kv_connector":"LMCacheMPConnector", "kv_role":"kv_both"}' + + | + + **Meta-Llama-3.1-70B-Instruct** (4 GPUs): + + .. code-block:: bash + + vllm serve meta-llama/Meta-Llama-3.1-70B-Instruct \ + --tensor-parallel-size 4 \ + --trust-remote-code \ + --enable-auto-tool-choice \ + --tool-call-parser llama3_json \ + --chat-template \ + --kv-transfer-config \ + '{"kv_connector":"LMCacheMPConnector", "kv_role":"kv_both"}' + + | + + Adjust ``--tensor-parallel-size`` to match your hardware. For the + generic LMCache + vLLM wiring (ports, remote hosts, in-process mode), + see :doc:`../mp/quickstart`. + + .. tab-item:: SGLang + + **Status:** Not validated with LMCache. + + .. tab-item:: TRT-LLM + + **Status:** Not supported. LMCache TRT-LLM integration is in progress. + +CacheBlend support +------------------ + +Compression support +------------------- + +.. 
list-table:: + :header-rows: 1 + :widths: 25 20 55 + + * - Method + - Status + - Notes + * - :doc:`CacheGen <../kv_cache_optimizations/compression/cachegen>` + - Not validated + - + +Caveats +------- + +None known. \ No newline at end of file diff --git a/docs/source/recipes/mixtral.rst b/docs/source/recipes/mixtral.rst new file mode 100644 index 00000000000..7e94c6feebe --- /dev/null +++ b/docs/source/recipes/mixtral.rst @@ -0,0 +1,92 @@ +.. _recipe_mixtral: + +MixtralForCausalLM +==================== + +Validated models +---------------- + +- `mistralai/Mixtral-8x7B-v0.1 `_ +- `mistralai/Mixtral-8x7B-Instruct-v0.1 `_ + +.. tab-set:: + :sync-group: engine + + .. tab-item:: vLLM + + **Engine documentation:** + `MixtralForCausalLM in vLLM supported models + `_ + (architecture `MixtralForCausalLM`). + + **Status:** Validated with LMCache. + + Start the LMCache MP server: + + .. code-block:: bash + + lmcache server --l1-size-gb 100 --eviction-policy LRU + + | + + Start vLLM with the LMCache MP connector: + + **Mixtral-8x7B-v0.1** (4 GPUs): + + .. code-block:: bash + + vllm serve mistralai/Mixtral-8x7B-v0.1 \ + --tensor-parallel-size 4 \ + --trust-remote-code \ + --kv-transfer-config \ + '{"kv_connector":"LMCacheMPConnector", "kv_role":"kv_both"}' + + | + + **Mixtral-8x7B-Instruct-v0.1** (4 GPUs): + + .. code-block:: bash + + vllm serve mistralai/Mixtral-8x7B-Instruct-v0.1 \ + --tensor-parallel-size 4 \ + --trust-remote-code \ + --enable-auto-tool-choice \ + --tool-call-parser mistral \ + --kv-transfer-config \ + '{"kv_connector":"LMCacheMPConnector", "kv_role":"kv_both"}' + + | + + Adjust ``--tensor-parallel-size`` to match your hardware. For the + generic LMCache + vLLM wiring (ports, remote hosts, in-process mode), + see :doc:`../mp/quickstart`. + + .. tab-item:: SGLang + + **Status:** Not validated with LMCache. + + .. tab-item:: TRT-LLM + + **Status:** Not supported. LMCache TRT-LLM integration is in progress. 
+ +CacheBlend support +------------------ + +Compression support +------------------- + +.. list-table:: + :header-rows: 1 + :widths: 25 20 55 + + * - Method + - Status + - Notes + * - :doc:`CacheGen <../kv_cache_optimizations/compression/cachegen>` + - Not validated + - + +Caveats +------- + +None known. \ No newline at end of file diff --git a/docs/source/recipes/phi3.rst b/docs/source/recipes/phi3.rst new file mode 100644 index 00000000000..743bf359060 --- /dev/null +++ b/docs/source/recipes/phi3.rst @@ -0,0 +1,90 @@ +.. _recipe_phi3: + +Phi3ForCausalLM +==================== + +Validated models +---------------- + +- `microsoft/Phi-4-mini-instruct `_ +- `microsoft/Phi-3-medium-128k-instruct `_ + +.. tab-set:: + :sync-group: engine + + .. tab-item:: vLLM + + **Engine documentation:** + `Phi3ForCausalLM in vLLM supported models + `_ + (architecture `Phi3ForCausalLM`). + + **Status:** Validated with LMCache. + + Start the LMCache MP server: + + .. code-block:: bash + + lmcache server --l1-size-gb 100 --eviction-policy LRU + + | + + Start vLLM with the LMCache MP connector: + + **Phi-4-mini-instruct** (1 GPU): + + .. code-block:: bash + + vllm serve microsoft/Phi-4-mini-instruct \ + --trust-remote-code \ + --enable-auto-tool-choice \ + --tool-call-parser phi4_mini_json \ + --kv-transfer-config \ + '{"kv_connector":"LMCacheMPConnector", "kv_role":"kv_both"}' + + | + + **Phi-3-medium-128k-instruct** (1 GPU): + + .. code-block:: bash + + vllm serve microsoft/Phi-3-medium-128k-instruct \ + --trust-remote-code \ + --kv-transfer-config \ + '{"kv_connector":"LMCacheMPConnector", "kv_role":"kv_both"}' + + | + + Adjust ``--tensor-parallel-size`` to match your hardware. For the + generic LMCache + vLLM wiring (ports, remote hosts, in-process mode), + see :doc:`../mp/quickstart`. + + .. tab-item:: SGLang + + **Status:** Not validated with LMCache. + + .. tab-item:: TRT-LLM + + **Status:** Not supported. LMCache TRT-LLM integration is in progress. 
+ +CacheBlend support +------------------ + +Compression support +------------------- + +.. list-table:: + :header-rows: 1 + :widths: 25 20 55 + + * - Method + - Status + - Notes + * - :doc:`CacheGen <../kv_cache_optimizations/compression/cachegen>` + - Not validated + - + +Caveats +------- + +None known. \ No newline at end of file From 4fb03710be6453e5050419172e8ec2ef8d1b2187 Mon Sep 17 00:00:00 2001 From: Zhengfei He <157287166+zhengfeihe@users.noreply.github.com> Date: Wed, 27 May 2026 00:59:21 +0900 Subject: [PATCH 2/5] [Perf] Replace Condvar polling with eventfd + epoll in iouring worker thread (#3271) * Replace Condvar polling with eventfd+epoll for io_uring worker Signed-off-by: zhengfeihe * Fix rust format check Signed-off-by: zhengfeihe * Add RAII guard for file descriptor in rust raw block Signed-off-by: zhengfeihe --------- Signed-off-by: zhengfeihe --- rust/raw_block/src/lib.rs | 985 ++++++++++++++++++++++++-------------- 1 file changed, 617 insertions(+), 368 deletions(-) diff --git a/rust/raw_block/src/lib.rs b/rust/raw_block/src/lib.rs index 62551e2efab..f170de532aa 100644 --- a/rust/raw_block/src/lib.rs +++ b/rust/raw_block/src/lib.rs @@ -20,6 +20,7 @@ use pyo3::prelude::*; use pyo3::types::PyAny; use std::collections::HashMap; use std::ffi::CString; +use std::io; use std::os::unix::io::RawFd; use std::slice; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; @@ -297,6 +298,174 @@ impl IoCompletion { } } +/// Manages io_uring worker thread notification, using one `epoll` instance +/// over two eventfds to wait on two event sources at once: a producer-side +/// eventfd signalled when Python pushes a new submission, and a CQ-side +/// eventfd signalled by the kernel when a completion is posted. +struct UringNotify { + /// Epoll instance watching both `producer_efd` and `cq_efd`. 
A single + /// `epoll_wait` on this fd blocks until either eventfd becomes readable, + /// so the worker can react to user-space queue pushes and kernel CQE + /// posts from the same call site. + epoll_fd: RawFd, + /// Eventfd written by Python producer threads after pushing into the + /// submission queue. Replaces the `Condvar::notify_one` used in the + /// pre-eventfd design. Also written by `do_close` so the worker can + /// break out of `epoll_wait` and observe the shutdown flag. + producer_efd: RawFd, + /// Eventfd registered with the io_uring instance via + /// `Submitter::register_eventfd`. The kernel writes to it whenever a + /// CQE is posted, so the worker is woken without having to drain the + /// completion queue speculatively. + cq_efd: RawFd, +} + +impl UringNotify { + /// Builds the three fds (two eventfds + one epoll) and wires the + /// eventfds into the epoll instance. Cleans up partially-built state + /// on any error path so no fd leaks if construction fails midway. + fn new() -> io::Result { + // Producer-side eventfd. Counter starts at 0 (no pending events). + // EFD_CLOEXEC prevents leaking the fd into a child process via + // execve. EFD_NONBLOCK makes read() return EAGAIN (instead of + // blocking) when the counter is already 0; wait() relies on this + // non-blocking behaviour to drain safely without hanging. + let producer_efd = unsafe { libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK) }; + if producer_efd < 0 { + // Nothing else has been allocated yet -- just bubble the error. + return Err(io::Error::last_os_error()); + } + + // CQ-side eventfd. The kernel writes to this one after + // register_eventfd() is called on the io_uring instance. + let cq_efd = unsafe { libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK) }; + if cq_efd < 0 { + // producer_efd is live -- close it before bubbling the error. + let e = io::Error::last_os_error(); + unsafe { libc::close(producer_efd) }; + return Err(e); + } + + // Epoll instance fd. 
EPOLL_CLOEXEC for the same hygiene reason as + // EFD_CLOEXEC above. + let epoll_fd = unsafe { libc::epoll_create1(libc::EPOLL_CLOEXEC) }; + if epoll_fd < 0 { + // Both eventfds are live; close them before returning. + let e = io::Error::last_os_error(); + unsafe { + libc::close(producer_efd); + libc::close(cq_efd); + } + return Err(e); + } + + // Register both eventfds with the epoll instance so epoll_wait + // can block on either source. + // + // We use EPOLLIN: wake when the counter becomes non-zero + // (someone signalled). Alternatives we deliberately don't use: + // - EPOLLOUT (writable): pointless -- eventfd is effectively + // always writable, so it would just busy-loop epoll_wait. + // - EPOLLET (edge-triggered): would require fully draining + // every wake-up in one go to avoid missing edges. Default + // level-triggered fits our drain pattern in wait(). + // - EPOLLONESHOT: would auto-disarm after each fire and force + // us to re-register; not worth the complexity. + // u64 stores the fd value itself so wait() can identify which + // fd fired without keeping a side map. + for fd in [producer_efd, cq_efd] { + let mut ev = libc::epoll_event { + events: libc::EPOLLIN as u32, + u64: fd as u64, + }; + let rc = unsafe { libc::epoll_ctl(epoll_fd, libc::EPOLL_CTL_ADD, fd, &mut ev) }; + if rc < 0 { + // All three fds are live; clean up all of them. + let e = io::Error::last_os_error(); + unsafe { + libc::close(epoll_fd); + libc::close(producer_efd); + libc::close(cq_efd); + } + return Err(e); + } + } + + // Ownership of all three fds moves into Self; Drop handles the + // happy-path close. + Ok(Self { + epoll_fd, + producer_efd, + cq_efd, + }) + } + + /// Wakes the worker thread by writing 1 to `producer_efd`. Called by + /// producers after pushing a submission into the queue and by `do_close` + /// to break the worker out of `epoll_wait` for shutdown. 
+ fn signal_producer(&self) { + let v: u64 = 1; + // Given EFD_NONBLOCK and counter < u64::MAX - 1, the 8-byte write + // always succeeds, so the return value is intentionally ignored. + unsafe { + libc::write( + self.producer_efd, + &v as *const u64 as *const libc::c_void, + 8, + ); + } + } + + /// Blocks the worker until either eventfd is readable, then drains + /// each fired fd. Drain is required because epoll is level-triggered: + /// without consuming the counter, the next epoll_wait would return + /// immediately on the same already-handled signal. + fn wait(&self) { + // A capacity of 2 is enough: only two fds are registered with this + // epoll instance, so at most two events can come back per call. + let mut events = [libc::epoll_event { events: 0, u64: 0 }; 2]; + + // Timeout = -1 means "block indefinitely". Shutdown wakes us by + // writing producer_efd from do_close, so we never need a timeout. + let n = unsafe { libc::epoll_wait(self.epoll_fd, events.as_mut_ptr(), 2, -1) }; + + // n < 0 is usually EINTR (signal interruption); we just return and + // the worker's outer loop will call wait() again. n == 0 should + // not happen with timeout=-1 but is handled defensively. + if n <= 0 { + return; + } + + // For each event reported, read 8 bytes from the corresponding + // eventfd to reset its counter to 0. The fd value was stashed in + // ev.u64 during epoll_ctl registration. We discard the read value + // (we only care that a signal arrived, not how many). + let mut buf = [0u8; 8]; + for ev in &events[..n as usize] { + let fd = ev.u64 as RawFd; + // Discard the result. The wake-up was already delivered by + // epoll_wait; this read only exists to reset the eventfd + // counter so epoll stops reporting the fd as readable. If it + // fails, the worst case is one spurious wake-up next + // iteration. No work is lost, because the real submissions + // live in the queue and CQ, not in the bytes we just read. 
+ unsafe { + libc::read(fd, buf.as_mut_ptr() as *mut libc::c_void, 8); + } + } + } +} + +impl Drop for UringNotify { + fn drop(&mut self) { + unsafe { + libc::close(self.epoll_fd); + libc::close(self.producer_efd); + libc::close(self.cq_efd); + } + } +} + /// Represents a single I/O submission to io_uring. /// /// This struct is sent from Python threads to the worker thread via a queue. @@ -381,8 +550,10 @@ struct RawBlockDevice { // Per-batch in-flight count tracking // Maps batch_id -> (in_flight_count, condition_variable) batch_in_flight: Arc>>, - // Signal to wake up worker when new requests are available - batch_ready: Option>, + // Worker wake-up: producer eventfd + ring CQ eventfd, both polled via + // a single epoll_fd. Replaces the previous `Arc` which couldn't + // be signaled from the kernel-side completion queue. + batch_ready: Option>, // Store Python buffer objects for writes, reads to keep them alive until they complete // This prevents premature garbage collection while io_uring is using the buffers // Keyed by batch_id to isolate concurrent batches @@ -394,6 +565,42 @@ struct RawBlockDevice { next_batch_id: Arc, } +/// RAII guard for a raw file descriptor +struct FdGuard { + fd: RawFd, +} + +impl FdGuard { + /// Takes ownership of an open file descriptor. + /// + /// Args: + /// - `fd`: a descriptor returned by a successful `open()`. + fn new(fd: RawFd) -> Self { + FdGuard { fd } + } + + /// Releases ownership of the fd to the caller, disarming the guard so the + /// descriptor is not closed on drop. + /// + /// Returns the raw descriptor, now owned by the caller. + fn disarm(self) -> RawFd { + let fd = self.fd; + std::mem::forget(self); + fd + } +} + +impl Drop for FdGuard { + fn drop(&mut self) { + // SAFETY: `fd` was returned by a successful `open()` and ownership has + // not been released via `disarm()`, so this is the only close of this + // descriptor. 
+ unsafe { + libc::close(self.fd); + } + } +} + impl RawBlockDevice { /// Internal constructor performs all low level setup. fn new_internal( @@ -421,6 +628,10 @@ impl RawBlockDevice { if fd < 0 { return Err(os_err("open failed")); } + // Take ownership of the fd so it is closed if any fallible setup step + // below returns early before the fd is moved into the RawBlockDevice. + // Disarmed once the struct is successfully constructed. + let fd_guard = FdGuard::new(fd); let size = fd_size_bytes(fd)?; let ( @@ -438,10 +649,18 @@ impl RawBlockDevice { ) = if use_iouring { let ring = IoUring::new(iouring_queue_depth as u32) .map_err(|e| PyRuntimeError::new_err(format!("io_uring init failed: {}", e)))?; + let notify = UringNotify::new() + .map_err(|e| PyRuntimeError::new_err(format!("UringNotify init failed: {}", e)))?; + // Register the CQ eventfd with the ring so the kernel writes to it + // whenever a CQE is posted. Must happen before the ring is wrapped + // in a Mutex / handed to the worker. + ring.submitter() + .register_eventfd(notify.cq_efd) + .map_err(|e| PyRuntimeError::new_err(format!("register_eventfd failed: {}", e)))?; let ring = Arc::new(Mutex::new(ring)); let queue = Arc::new(Mutex::new(Vec::::new())); let shutdown = Arc::new(AtomicBool::new(false)); - let batch_ready = Arc::new(Condvar::new()); + let batch_ready = Arc::new(notify); let in_flight_count = Arc::new(AtomicU64::new(0)); let in_flight_cvar = Arc::new(Condvar::new()); let batched_buffer_objs = Arc::new(Mutex::new(HashMap::>>::new())); @@ -476,48 +695,127 @@ impl RawBlockDevice { // - Reads from the submission queue // - Submits to io_uring // - Processes completions - let worker = thread::spawn(move || { - let mut in_flight: HashMap = HashMap::new(); - let mut next_user_data: u64 = 1; - - while !shutdown_clone.load(Ordering::Relaxed) { - // This drains all completed I/O operations from the completion queue (CQ). 
- // For each completion: - // - Remove the request from our in_flight tracking HashMap - // - Signal the waiting Python thread via IoCompletion - // - Decrement the in_flight_count atomic - // - Wake up any threads waiting for all I/O to complete - { - let mut ring = ring_clone.lock().unwrap(); - let completions: Vec<_> = ring.completion().collect(); - for cqe in completions { - let user_data = cqe.user_data(); - if let Some(mut sub) = in_flight.remove(&user_data) { - let batch_id = sub.batch_id; - if cqe.result() < 0 { - let code = -cqe.result(); - // Drop any bounce buffer associated with this submission. - let _ = sub.bounce.take(); - sub.completion - .set(Err(PyOSError::new_err((code, "io_uring I/O error")))); - } else { - let bytes_transferred = cqe.result() as usize; - if bytes_transferred < sub.len { - // Short read/write: update offset and length, then resubmit - sub.offset += bytes_transferred as u64; - sub.len -= bytes_transferred; - // Update buffer pointer for writes and direct reads - if sub.is_write || sub.bounce.is_none() { - sub.ptr_addr += bytes_transferred; + let worker = thread::Builder::new() + .name("rust-rawblock-uring".into()) + .spawn(move || { + let mut in_flight: HashMap = HashMap::new(); + let mut next_user_data: u64 = 1; + + while !shutdown_clone.load(Ordering::Relaxed) { + // This drains all completed I/O operations from the completion queue (CQ). 
+ // For each completion: + // - Remove the request from our in_flight tracking HashMap + // - Signal the waiting Python thread via IoCompletion + // - Decrement the in_flight_count atomic + // - Wake up any threads waiting for all I/O to complete + { + let mut ring = ring_clone.lock().unwrap(); + let completions: Vec<_> = ring.completion().collect(); + for cqe in completions { + let user_data = cqe.user_data(); + if let Some(mut sub) = in_flight.remove(&user_data) { + let batch_id = sub.batch_id; + if cqe.result() < 0 { + let code = -cqe.result(); + // Drop any bounce buffer associated with this submission. + let _ = sub.bounce.take(); + sub.completion.set(Err(PyOSError::new_err(( + code, + "io_uring I/O error", + )))); + } else { + let bytes_transferred = cqe.result() as usize; + if bytes_transferred < sub.len { + // Short read/write: update offset and length, then resubmit + sub.offset += bytes_transferred as u64; + sub.len -= bytes_transferred; + // Update buffer pointer for writes and direct reads + if sub.is_write || sub.bounce.is_none() { + sub.ptr_addr += bytes_transferred; + } + // For read with bounce buffer, copy partial data back + if !sub.is_write { + if let ( + Some(bounce), + Some(orig_ptr), + Some(payload_len), + ) = ( + sub.bounce.as_ref(), + sub.original_ptr, + sub.payload_len, + ) { + unsafe { + libc::memcpy( + orig_ptr as *mut libc::c_void, + bounce.as_ptr() as *const libc::c_void, + bytes_transferred.min(payload_len), + ); + } + sub.original_ptr = + Some(orig_ptr + bytes_transferred); + sub.payload_len = Some( + payload_len + .saturating_sub(bytes_transferred), + ); + } + } + // Re-insert into in_flight with updated values + // Don't decrement in_flight_count since we're resubmitting + in_flight.insert(user_data, sub.clone()); + // Push a new SQE for the remaining data + let ptr = sub.ptr_addr as *mut u8; + let sqe = if sub.is_write { + if let Some(idx) = sub.fixed_buffer_idx { + opcode::WriteFixed::new( + Fd(sub.fd), + ptr as *const u8, + 
sub.len as u32, + idx, + ) + .offset(sub.offset) + .build() + } else { + opcode::Write::new( + Fd(sub.fd), + ptr as *const u8, + sub.len as u32, + ) + .offset(sub.offset) + .build() + } + } else if let Some(idx) = sub.fixed_buffer_idx { + opcode::ReadFixed::new( + Fd(sub.fd), + ptr, + sub.len as u32, + idx, + ) + .offset(sub.offset) + .build() + } else { + opcode::Read::new(Fd(sub.fd), ptr, sub.len as u32) + .offset(sub.offset) + .build() + }; + let sqe = sqe.user_data(user_data); + unsafe { + ring.submission().push(&sqe).expect( + "failed to push sqe for short read/write", + ); + } + // Submit the new SQE to the kernel + let _ = ring.submitter().submit(); + continue; } - // For read with bounce buffer, copy partial data back + // Full completion + // For reads with bounce buffer, copy data back to original buffer if !sub.is_write { if let ( Some(bounce), Some(orig_ptr), Some(payload_len), ) = ( - sub.bounce.as_ref(), + sub.bounce.take(), sub.original_ptr, sub.payload_len, ) { @@ -525,262 +823,199 @@ impl RawBlockDevice { libc::memcpy( orig_ptr as *mut libc::c_void, bounce.as_ptr() as *const libc::c_void, - bytes_transferred.min(payload_len), + payload_len, ); } - sub.original_ptr = - Some(orig_ptr + bytes_transferred); - sub.payload_len = Some( - payload_len.saturating_sub(bytes_transferred), - ); } - } - // Re-insert into in_flight with updated values - // Don't decrement in_flight_count since we're resubmitting - in_flight.insert(user_data, sub.clone()); - // Push a new SQE for the remaining data - let ptr = sub.ptr_addr as *mut u8; - let sqe = if sub.is_write { - if let Some(idx) = sub.fixed_buffer_idx { - opcode::WriteFixed::new( - Fd(sub.fd), - ptr as *const u8, - sub.len as u32, - idx, - ) - .offset(sub.offset) - .build() - } else { - opcode::Write::new( - Fd(sub.fd), - ptr as *const u8, - sub.len as u32, - ) - .offset(sub.offset) - .build() - } - } else if let Some(idx) = sub.fixed_buffer_idx { - opcode::ReadFixed::new( - Fd(sub.fd), - ptr, - sub.len as 
u32, - idx, - ) - .offset(sub.offset) - .build() } else { - opcode::Read::new(Fd(sub.fd), ptr, sub.len as u32) - .offset(sub.offset) - .build() - }; - let sqe = sqe.user_data(user_data); - unsafe { - ring.submission() - .push(&sqe) - .expect("failed to push sqe for short read/write"); + // Drop any bounce buffer associated with this submission. + let _ = sub.bounce.take(); } - // Submit the new SQE to the kernel - let _ = ring.submitter().submit(); - continue; + sub.completion.set(Ok(())); + } + let prev = + in_flight_count_clone.fetch_sub(1, Ordering::Relaxed); + if prev == 1 { + in_flight_cvar_clone.notify_all(); } - // Full completion - // For reads with bounce buffer, copy data back to original buffer - if !sub.is_write { - if let (Some(bounce), Some(orig_ptr), Some(payload_len)) = - (sub.bounce.take(), sub.original_ptr, sub.payload_len) + // Decrement per-batch in-flight count and notify if batch is complete + if batch_id != 0 { + let batch_map = batch_in_flight_clone.lock().unwrap(); + if let Some((batch_count, batch_cvar)) = + batch_map.get(&batch_id) { - unsafe { - libc::memcpy( - orig_ptr as *mut libc::c_void, - bounce.as_ptr() as *const libc::c_void, - payload_len, - ); + let prev_batch = + batch_count.fetch_sub(1, Ordering::Relaxed); + if prev_batch == 1 { + batch_cvar.notify_all(); } } - } else { - // Drop any bounce buffer associated with this submission. 
- let _ = sub.bounce.take(); - } - sub.completion.set(Ok(())); - } - let prev = in_flight_count_clone.fetch_sub(1, Ordering::Relaxed); - if prev == 1 { - in_flight_cvar_clone.notify_all(); - } - // Decrement per-batch in-flight count and notify if batch is complete - if batch_id != 0 { - let batch_map = batch_in_flight_clone.lock().unwrap(); - if let Some((batch_count, batch_cvar)) = - batch_map.get(&batch_id) - { - let prev_batch = - batch_count.fetch_sub(1, Ordering::Relaxed); - if prev_batch == 1 { - batch_cvar.notify_all(); - } } } } + ring.submission().sync(); } - ring.submission().sync(); - } - // We use a condition variable with a short timeout (10 microseconds). - // This allows us to: - // - Quickly respond to new requests (batched from Python) - // - Periodically check for shutdown signal - // - Not spin aggressively (which would waste CPU) - let timeout = Duration::from_micros(10); - let q = queue_clone.lock().unwrap(); - let (mut q, _) = batch_ready_clone - .wait_timeout_while(q, timeout, |q| { - q.is_empty() && !shutdown_clone.load(Ordering::Relaxed) - }) - .unwrap(); - - if !q.is_empty() { - // Take all pending requests from our queue and submit them to io_uring. 
- // - // - Remove all pending requests from queue - // - Check how much space is available in the ring (max 256 entries) - // - If batch is larger than available space, put excess back in queue - // - Increment in_flight_count for each request we're about to submit - // - Build SQE (Submission Queue Entry) for each request - // - Push SQEs to the ring - // - Call submit() to send them to the kernel - // - // Fixed Buffer Support: - // - If the buffer was pre-registered with register_fixed_buffers(), - // we use ReadFixed/WriteFixed for true zero-copy I/O - // - Otherwise we use regular Read/Write with user-space pointers - let mut batch: Vec = std::mem::take(&mut *q); - let batch_len = batch.len(); - - let mut ring = ring_clone.lock().unwrap(); - - let available = ring_size - ring.submission().len(); - let to_submit_count = std::cmp::min(available, batch_len); - - if to_submit_count < batch_len { - let remaining: Vec<_> = batch[to_submit_count..].to_vec(); - if !remaining.is_empty() { - q.extend(remaining); - } + // Block on epoll only if there's truly nothing pending. The empty + + // shutdown checks short-circuit so we don't sleep when a producer or + // do_close() already left work for us. Race-free against a late + // signal_producer(): eventfd is a counter, so a wake-up between the + // check and wait() is buffered, not lost. 
+ if !shutdown_clone.load(Ordering::Relaxed) + && queue_clone.lock().unwrap().is_empty() + { + batch_ready_clone.wait(); } - drop(q); - - // Track user_data values for each submission to clean up in_flight entries - // if submit() fails or returns partial count - let mut user_data_list: Vec = Vec::with_capacity(to_submit_count); - for sub in batch.iter().take(to_submit_count) { - let user_data = next_user_data; - next_user_data = next_user_data.wrapping_add(1); - user_data_list.push(user_data); - in_flight.insert(user_data, sub.clone()); - - let ptr = sub.ptr_addr as *mut u8; - let sqe = if sub.is_write { - if let Some(idx) = sub.fixed_buffer_idx { - opcode::WriteFixed::new( - Fd(sub.fd), - ptr as *const u8, - sub.len as u32, - idx, - ) - .offset(sub.offset) - .build() - } else { - opcode::Write::new(Fd(sub.fd), ptr as *const u8, sub.len as u32) - .offset(sub.offset) - .build() + let mut q = queue_clone.lock().unwrap(); + if !q.is_empty() { + // Take all pending requests from our queue and submit them to io_uring. 
+ // + // - Remove all pending requests from queue + // - Check how much space is available in the ring (max 256 entries) + // - If batch is larger than available space, put excess back in queue + // - Increment in_flight_count for each request we're about to submit + // - Build SQE (Submission Queue Entry) for each request + // - Push SQEs to the ring + // - Call submit() to send them to the kernel + // + // Fixed Buffer Support: + // - If the buffer was pre-registered with register_fixed_buffers(), + // we use ReadFixed/WriteFixed for true zero-copy I/O + // - Otherwise we use regular Read/Write with user-space pointers + let mut batch: Vec = std::mem::take(&mut *q); + let batch_len = batch.len(); + + let mut ring = ring_clone.lock().unwrap(); + + let available = ring_size - ring.submission().len(); + let to_submit_count = std::cmp::min(available, batch_len); + + if to_submit_count < batch_len { + let remaining: Vec<_> = batch[to_submit_count..].to_vec(); + if !remaining.is_empty() { + q.extend(remaining); } - } else if let Some(idx) = sub.fixed_buffer_idx { - opcode::ReadFixed::new(Fd(sub.fd), ptr, sub.len as u32, idx) - .offset(sub.offset) - .build() - } else { - opcode::Read::new(Fd(sub.fd), ptr, sub.len as u32) - .offset(sub.offset) - .build() - }; - let sqe = sqe.user_data(user_data); - unsafe { - ring.submission().push(&sqe).expect("failed to push sqe"); } - } - let submit_result = ring.submitter().submit(); - // Handle EAGAIN (ring full) and EINTR (interrupted syscall) - match submit_result { - Ok(submitted) => { - // Any remaining requests in batch that weren't submitted - // will be retried in the next iteration of the loop - if submitted < to_submit_count { - // Remove in_flight entries for unsubmitted requests - for user_data in user_data_list[submitted..].iter() { - in_flight.remove(user_data); - } - // Put unsubmitted requests back in the queue for retry - let unsubmitted: Vec<_> = - batch[submitted..to_submit_count].to_vec(); - if 
!unsubmitted.is_empty() { - drop(ring); - let mut q = queue_clone.lock().unwrap(); - // Insert unsubmitted requests back at the front preserving order - q.splice(0..0, unsubmitted); + drop(q); + + // Track user_data values for each submission to clean up in_flight entries + // if submit() fails or returns partial count + let mut user_data_list: Vec = Vec::with_capacity(to_submit_count); + for sub in batch.iter().take(to_submit_count) { + let user_data = next_user_data; + next_user_data = next_user_data.wrapping_add(1); + user_data_list.push(user_data); + in_flight.insert(user_data, sub.clone()); + + let ptr = sub.ptr_addr as *mut u8; + let sqe = if sub.is_write { + if let Some(idx) = sub.fixed_buffer_idx { + opcode::WriteFixed::new( + Fd(sub.fd), + ptr as *const u8, + sub.len as u32, + idx, + ) + .offset(sub.offset) + .build() + } else { + opcode::Write::new( + Fd(sub.fd), + ptr as *const u8, + sub.len as u32, + ) + .offset(sub.offset) + .build() } + } else if let Some(idx) = sub.fixed_buffer_idx { + opcode::ReadFixed::new(Fd(sub.fd), ptr, sub.len as u32, idx) + .offset(sub.offset) + .build() + } else { + opcode::Read::new(Fd(sub.fd), ptr, sub.len as u32) + .offset(sub.offset) + .build() + }; + let sqe = sqe.user_data(user_data); + unsafe { + ring.submission().push(&sqe).expect("failed to push sqe"); } } - Err(e) => { - // Handle submission errors - let error_code = e.raw_os_error(); - match error_code { - Some(libc::EAGAIN) | Some(libc::EINTR) => { - // Ring is full, or the operation was interrupted due - // to signal. 
We need to wait for completions and then retry - // Remove in_flight entries for all submissions in this batch - for user_data in user_data_list.iter() { + + let submit_result = ring.submitter().submit(); + // Handle EAGAIN (ring full) and EINTR (interrupted syscall) + match submit_result { + Ok(submitted) => { + // Any remaining requests in batch that weren't submitted + // will be retried in the next iteration of the loop + if submitted < to_submit_count { + // Remove in_flight entries for unsubmitted requests + for user_data in user_data_list[submitted..].iter() { in_flight.remove(user_data); } - // Put unsubmitted requests back in queue for next iteration - if to_submit_count > 0 { - let unsubmitted: Vec<_> = - batch[..to_submit_count].to_vec(); + // Put unsubmitted requests back in the queue for retry + let unsubmitted: Vec<_> = + batch[submitted..to_submit_count].to_vec(); + if !unsubmitted.is_empty() { drop(ring); let mut q = queue_clone.lock().unwrap(); // Insert unsubmitted requests back at the front preserving order q.splice(0..0, unsubmitted); } } - _ => { - // Error: fail all pending submissions in this batch. - // Remove in_flight entries since these won't generate completions - for user_data in user_data_list.iter() { - in_flight.remove(user_data); + } + Err(e) => { + // Handle submission errors + let error_code = e.raw_os_error(); + match error_code { + Some(libc::EAGAIN) | Some(libc::EINTR) => { + // Ring is full, or the operation was interrupted due + // to signal. 
We need to wait for completions and then retry + // Remove in_flight entries for all submissions in this batch + for user_data in user_data_list.iter() { + in_flight.remove(user_data); + } + // Put unsubmitted requests back in queue for next iteration + if to_submit_count > 0 { + let unsubmitted: Vec<_> = + batch[..to_submit_count].to_vec(); + drop(ring); + let mut q = queue_clone.lock().unwrap(); + // Insert unsubmitted requests back at the front preserving order + q.splice(0..0, unsubmitted); + } } - for sub in batch.iter_mut().take(to_submit_count) { - let batch_id = sub.batch_id; - sub.completion.set(Err(PyRuntimeError::new_err( - format!("io_uring submit error: {:?}", e), - ))); - let _ = sub.bounce.take(); - let prev = in_flight_count_clone - .fetch_sub(1, Ordering::Relaxed); - if prev == 1 { - in_flight_cvar_clone.notify_all(); + _ => { + // Error: fail all pending submissions in this batch. + // Remove in_flight entries since these won't generate completions + for user_data in user_data_list.iter() { + in_flight.remove(user_data); } - // Decrement per-batch in-flight count and notify if batch is complete - if batch_id != 0 { - let batch_map = - batch_in_flight_clone.lock().unwrap(); - if let Some((batch_count, batch_cvar)) = - batch_map.get(&batch_id) - { - let prev_batch = - batch_count.fetch_sub(1, Ordering::Relaxed); - if prev_batch == 1 { - batch_cvar.notify_all(); + for sub in batch.iter_mut().take(to_submit_count) { + let batch_id = sub.batch_id; + sub.completion.set(Err(PyRuntimeError::new_err( + format!("io_uring submit error: {:?}", e), + ))); + let _ = sub.bounce.take(); + let prev = in_flight_count_clone + .fetch_sub(1, Ordering::Relaxed); + if prev == 1 { + in_flight_cvar_clone.notify_all(); + } + // Decrement per-batch in-flight count and notify if batch is complete + if batch_id != 0 { + let batch_map = + batch_in_flight_clone.lock().unwrap(); + if let Some((batch_count, batch_cvar)) = + batch_map.get(&batch_id) + { + let prev_batch = 
batch_count + .fetch_sub(1, Ordering::Relaxed); + if prev_batch == 1 { + batch_cvar.notify_all(); + } } } } @@ -790,131 +1025,140 @@ impl RawBlockDevice { } } } - } - // SHUTDOWN: Wake up all waiting Python threads - // Drain the queue and wake up all waiting threads with error - { - let mut q = queue_clone - .lock() - .expect("Worker: queue mutex poisoned during shutdown"); - while let Some(mut sub) = q.pop() { - let batch_id = sub.batch_id; - // Drop any bounce buffer associated with this submission. - let _ = sub.bounce.take(); - in_flight_count_clone.fetch_sub(1, Ordering::Relaxed); - sub.completion.set(Err(PyRuntimeError::new_err( - "io_uring worker shutting down", - ))); - // Decrement per-batch in-flight count and notify if batch is complete - if batch_id != 0 { - let batch_map = batch_in_flight_clone.lock().unwrap(); - if let Some((batch_count, batch_cvar)) = batch_map.get(&batch_id) { - let prev_batch = batch_count.fetch_sub(1, Ordering::Relaxed); - if prev_batch == 1 { - batch_cvar.notify_all(); + // SHUTDOWN: Wake up all waiting Python threads + // Drain the queue and wake up all waiting threads with error + { + let mut q = queue_clone + .lock() + .expect("Worker: queue mutex poisoned during shutdown"); + while let Some(mut sub) = q.pop() { + let batch_id = sub.batch_id; + // Drop any bounce buffer associated with this submission. 
+ let _ = sub.bounce.take(); + in_flight_count_clone.fetch_sub(1, Ordering::Relaxed); + sub.completion.set(Err(PyRuntimeError::new_err( + "io_uring worker shutting down", + ))); + // Decrement per-batch in-flight count and notify if batch is complete + if batch_id != 0 { + let batch_map = batch_in_flight_clone.lock().unwrap(); + if let Some((batch_count, batch_cvar)) = batch_map.get(&batch_id) { + let prev_batch = batch_count.fetch_sub(1, Ordering::Relaxed); + if prev_batch == 1 { + batch_cvar.notify_all(); + } } } } } - } - // Process any remaining in-flight requests - // Wait for kernel to complete the requests or force-cancel them - // Note: This 1000 milliseconds is a rough estimate - let graceful_shutdown = Duration::from_millis(1000); - thread::sleep(graceful_shutdown); - { - let mut ring = ring_clone - .lock() - .expect("Worker: ring mutex poisoned during shutdown"); - for cqe in ring.completion() { - let user_data = cqe.user_data(); - if let Some(mut sub) = in_flight.remove(&user_data) { - let batch_id = sub.batch_id; - if cqe.result() < 0 { - let code = -cqe.result(); - // Drop any bounce buffer associated with this submission. 
- let _ = sub.bounce.take(); - sub.completion - .set(Err(PyOSError::new_err((code, "io_uring I/O error")))); - } else { - let bytes_transferred = cqe.result() as usize; - if bytes_transferred < sub.len { - // Short read/write during shutdown: fail the request - // We cannot resubmit because the worker is about to exit + // Process any remaining in-flight requests + // Wait for kernel to complete the requests or force-cancel them + // Note: This 1000 milliseconds is a rough estimate + let graceful_shutdown = Duration::from_millis(1000); + thread::sleep(graceful_shutdown); + { + let mut ring = ring_clone + .lock() + .expect("Worker: ring mutex poisoned during shutdown"); + for cqe in ring.completion() { + let user_data = cqe.user_data(); + if let Some(mut sub) = in_flight.remove(&user_data) { + let batch_id = sub.batch_id; + if cqe.result() < 0 { + let code = -cqe.result(); // Drop any bounce buffer associated with this submission. let _ = sub.bounce.take(); - sub.completion.set(Err(PyRuntimeError::new_err( + sub.completion + .set(Err(PyOSError::new_err((code, "io_uring I/O error")))); + } else { + let bytes_transferred = cqe.result() as usize; + if bytes_transferred < sub.len { + // Short read/write during shutdown: fail the request + // We cannot resubmit because the worker is about to exit + // Drop any bounce buffer associated with this submission. 
+ let _ = sub.bounce.take(); + sub.completion.set(Err(PyRuntimeError::new_err( "io_uring worker shutting down - short I/O during shutdown", ))); - // Continue to decrement in_flight_count below - } else { - // Full completion - // For reads with bounce buffer, copy data back to original buffer - if !sub.is_write { - if let (Some(bounce), Some(orig_ptr), Some(payload_len)) = - (sub.bounce.take(), sub.original_ptr, sub.payload_len) - { - unsafe { - libc::memcpy( - orig_ptr as *mut libc::c_void, - bounce.as_ptr() as *const libc::c_void, - payload_len, - ); + // Continue to decrement in_flight_count below + } else { + // Full completion + // For reads with bounce buffer, copy data back to original buffer + if !sub.is_write { + if let ( + Some(bounce), + Some(orig_ptr), + Some(payload_len), + ) = ( + sub.bounce.take(), + sub.original_ptr, + sub.payload_len, + ) { + unsafe { + libc::memcpy( + orig_ptr as *mut libc::c_void, + bounce.as_ptr() as *const libc::c_void, + payload_len, + ); + } } + } else { + // Drop any bounce buffer associated with this submission. + let _ = sub.bounce.take(); } - } else { - // Drop any bounce buffer associated with this submission. 
- let _ = sub.bounce.take(); + sub.completion.set(Ok(())); } - sub.completion.set(Ok(())); } - } - let prev = in_flight_count_clone.fetch_sub(1, Ordering::Relaxed); - if prev == 1 { - in_flight_cvar_clone.notify_all(); - } - // Decrement per-batch in-flight count and notify if batch is complete - if batch_id != 0 { - let batch_map = batch_in_flight_clone.lock().unwrap(); - if let Some((batch_count, batch_cvar)) = batch_map.get(&batch_id) { - let prev_batch = batch_count.fetch_sub(1, Ordering::Relaxed); - if prev_batch == 1 { - batch_cvar.notify_all(); + let prev = in_flight_count_clone.fetch_sub(1, Ordering::Relaxed); + if prev == 1 { + in_flight_cvar_clone.notify_all(); + } + // Decrement per-batch in-flight count and notify if batch is complete + if batch_id != 0 { + let batch_map = batch_in_flight_clone.lock().unwrap(); + if let Some((batch_count, batch_cvar)) = + batch_map.get(&batch_id) + { + let prev_batch = + batch_count.fetch_sub(1, Ordering::Relaxed); + if prev_batch == 1 { + batch_cvar.notify_all(); + } } } } } + ring.submission().sync(); } - ring.submission().sync(); - } - // Any remaining in_flight requests, force wake with error - // (these were submitted to kernel but won't get completions) - for (_user_data, mut sub) in in_flight.drain() { - let batch_id = sub.batch_id; - // Drop any bounce buffer associated with this submission. 
- let _ = sub.bounce.take(); - in_flight_count_clone.fetch_sub(1, Ordering::Relaxed); - sub.completion.set(Err(PyRuntimeError::new_err( - "io_uring worker shutting down - request cancelled", - ))); - // Decrement per-batch in-flight count and notify if batch is complete - if batch_id != 0 { - let batch_map = batch_in_flight_clone.lock().unwrap(); - if let Some((batch_count, batch_cvar)) = batch_map.get(&batch_id) { - let prev_batch = batch_count.fetch_sub(1, Ordering::Relaxed); - if prev_batch == 1 { - batch_cvar.notify_all(); + // Any remaining in_flight requests, force wake with error + // (these were submitted to kernel but won't get completions) + for (_user_data, mut sub) in in_flight.drain() { + let batch_id = sub.batch_id; + // Drop any bounce buffer associated with this submission. + let _ = sub.bounce.take(); + in_flight_count_clone.fetch_sub(1, Ordering::Relaxed); + sub.completion.set(Err(PyRuntimeError::new_err( + "io_uring worker shutting down - request cancelled", + ))); + // Decrement per-batch in-flight count and notify if batch is complete + if batch_id != 0 { + let batch_map = batch_in_flight_clone.lock().unwrap(); + if let Some((batch_count, batch_cvar)) = batch_map.get(&batch_id) { + let prev_batch = batch_count.fetch_sub(1, Ordering::Relaxed); + if prev_batch == 1 { + batch_cvar.notify_all(); + } } } } - } - // Final notification in case any thread is waiting on in_flight_count - in_flight_cvar_clone.notify_all(); - }); + // Final notification in case any thread is waiting on in_flight_count + in_flight_cvar_clone.notify_all(); + }) + .expect("spawn rust-rawblock-uring worker"); ( Some(ring), @@ -935,6 +1179,11 @@ impl RawBlockDevice { ) }; + // All fallible setup succeeded: hand the fd over to the struct, whose + // Drop is now responsible for closing it. Disarm so the guard does not + // close the same descriptor a second time. 
+ let fd = fd_guard.disarm(); + Ok(Self { fd, size, @@ -1235,7 +1484,7 @@ impl RawBlockDevice { let mut q = queue.lock().unwrap(); q.push(sub); } - batch_ready.notify_one(); + batch_ready.signal_producer(); // Store completion for error checking in wait_iouring { @@ -1455,7 +1704,7 @@ impl RawBlockDevice { q.push(sub); } if let Some(batch_ready) = &self.batch_ready { - batch_ready.notify_one(); + batch_ready.signal_producer(); } py.allow_threads(move || comp.wait()) } else { @@ -1483,7 +1732,7 @@ impl RawBlockDevice { q.push(sub); } if let Some(batch_ready) = &self.batch_ready { - batch_ready.notify_one(); + batch_ready.signal_producer(); } py.allow_threads(move || comp.wait()) }; @@ -1671,7 +1920,7 @@ impl RawBlockDevice { let mut q = queue.lock().unwrap(); q.push(sub); } - batch_ready.notify_one(); + batch_ready.signal_producer(); // Store completion for error checking in wait_iouring { @@ -1971,7 +2220,7 @@ impl RawBlockDevice { shutdown.store(true, Ordering::Relaxed); } if let Some(batch_ready) = &self.batch_ready { - batch_ready.notify_all(); + batch_ready.signal_producer(); } let mutex = Mutex::new(()); From 3e83d5c2020dc5984269d5646c7be5a906cca595 Mon Sep 17 00:00:00 2001 From: Guy Ealey Morag Date: Tue, 26 May 2026 18:57:01 +0200 Subject: [PATCH 3/5] nixl_storage: Fix tests and remove from ignore list (#3200) * nixl_storage: naive support for files + dynamic static is not actually very usable with files: we bump into the OS limits on the open files very quickly, limiting the size of the cache. With tons of VRAM, having G3 of comparable size is not helping much, we really need it to be much bigger. This commit adds support for files in dynamic mode of nixl storage. There are some naive things done: 1. No support for sharing the cache storage. We assume the worker has exclusive access to the files and it is always safe to overwrite existing files. 
Note that with TP!=1 we will have several workers with the same target directory; that's why it's important to have a worker id as part of the key, so they don't affect each other. 2. No eviction support. As before, dynamic mode has no eviction support. This kind of made sense when it only worked with OBJ storages, but now it is something to be aware of. It's not hard to add eviction though. 3. Flat directory with all the cache files. There are going to be a lot of them, especially given that we don't have eviction. Most filesystems don't optimize for the case of a directory with millions of files, and we constantly open/close files there, which might be a source of additional latency. - There is an existing PR to support sharding the directory; we can merge it as a band-aid. - As a more advanced solution, we can do a multi-layer subdirectory structure, like the GDS backend does. 4. We switched directly from having _all_ files open at once to opening/closing _every_ single file on access. We might explore having an LRU cache of open files instead.
Signed-off-by: Ilya Yanok * Fix formatting Signed-off-by: Guy Ealey Morag * Fix PR Comments Signed-off-by: Guy Ealey Morag * Extract _build_descs Signed-off-by: Guy Ealey Morag * Fix test Signed-off-by: Guy Ealey Morag * Refactor to increase readability Signed-off-by: Guy Ealey Morag * Fix another leak issue Signed-off-by: Guy Ealey Morag * Fix release order Signed-off-by: Guy Ealey Morag * Treat file not found as a miss Signed-off-by: Guy Ealey Morag * Add docstring about file create mode 0o644 Signed-off-by: Guy Ealey Morag * Unlink files on failure Signed-off-by: Guy Ealey Morag * Add missing docs Signed-off-by: Guy Ealey Morag * Fix pre-commit issues Signed-off-by: Guy Ealey Morag * Fix mypy error Signed-off-by: Guy Ealey Morag * Fix test Signed-off-by: Guy Ealey Morag * Make tests pass Signed-off-by: Guy Ealey Morag * Fix formatting Signed-off-by: Guy Ealey Morag * Fix nixl tests and stop ignoring them in CI Signed-off-by: Guy Ealey Morag * Fix path to support GDS Signed-off-by: Guy Ealey Morag * Revert path creation in NixlDynamicStorageBackend Signed-off-by: Guy Ealey Morag * Fix code quality issue Signed-off-by: Guy Ealey Morag --------- Signed-off-by: Ilya Yanok Signed-off-by: Guy Ealey Morag Co-authored-by: Ilya Yanok --- .buildkite/k3_tests/unit/run.sh | 1 - .buildkite/pipeline.yml | 1 - .github/workflows/test.yml | 2 +- AGENTS.md | 9 ++- CONTRIBUTING.md | 1 - tests/v1/test_nixl_storage.py | 104 ++++++++++++++++++++++++++++---- 6 files changed, 96 insertions(+), 22 deletions(-) diff --git a/.buildkite/k3_tests/unit/run.sh b/.buildkite/k3_tests/unit/run.sh index 314e50d1ed3..052ed5931ce 100755 --- a/.buildkite/k3_tests/unit/run.sh +++ b/.buildkite/k3_tests/unit/run.sh @@ -35,7 +35,6 @@ pytest --maxfail=1 --cov=lmcache \ --cov-report term --cov-report=html:coverage-test \ --cov-report=xml:coverage-test.xml --html=durations/test.html \ --ignore=tests/disagg --ignore=tests/v1/test_pos_kernels.py \ - --ignore=tests/v1/test_nixl_storage.py \ 
--ignore=tests/skipped \ --ignore=tests/v1/storage_backend/test_eic.py diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 6d024a7d721..e4517632cee 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -37,7 +37,6 @@ steps: --cov-report term --cov-report=html:coverage-test \ --cov-report=xml:coverage-test.xml --html=durations/test.html \ --ignore=tests/disagg --ignore=tests/v1/test_pos_kernels.py \ - --ignore=tests/v1/test_nixl_storage.py \ --ignore=tests/v1/test_nixl_batched_contains.py \ --ignore=tests/v1/test_device_id_race.py \ --ignore=tests/skipped \ diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 62e36f2ff0c..87ab1c34335 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -99,7 +99,7 @@ jobs: - name: "Run non-CUDA unit tests" run: | - pytest --ignore=tests/disagg --ignore=tests/v1/test_nixl_storage.py \ + pytest --ignore=tests/disagg \ --ignore=tests/v1/multiprocess/ \ --ignore=tests/v1/distributed/ \ --ignore=tests/v1/mp_observability/ \ diff --git a/AGENTS.md b/AGENTS.md index aa1673d0449..aa239e53877 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -47,11 +47,10 @@ BUILD_WITH_HIP=1 pip install -e . ```bash # Run standard test suite (mirrors CI) pytest -xvs --ignore=tests/disagg \ - --ignore=tests/v1/test_nixl_storage.py \ - --ignore=tests/v1/multiprocess/ \ - --ignore=tests/v1/distributed/ \ - --ignore=tests/skipped \ - --ignore=tests/v1/storage_backend/test_eic.py + --ignore=tests/v1/multiprocess/ \ + --ignore=tests/v1/distributed/ \ + --ignore=tests/skipped \ + --ignore=tests/v1/storage_backend/test_eic.py # Run a single test file pytest -xvs tests/v1/test_cache_engine.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e37e0270b27..506a19bda2c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -47,7 +47,6 @@ BUILD_WITH_HIP=1 pip install -e . 
```bash # Run standard test suite (mirrors CI) pytest -xvs --ignore=tests/disagg \ - --ignore=tests/v1/test_nixl_storage.py \ --ignore=tests/v1/multiprocess/ \ --ignore=tests/v1/distributed/ \ --ignore=tests/skipped \ diff --git a/tests/v1/test_nixl_storage.py b/tests/v1/test_nixl_storage.py index 634c6484cb7..47499d82948 100644 --- a/tests/v1/test_nixl_storage.py +++ b/tests/v1/test_nixl_storage.py @@ -2,9 +2,14 @@ # Standard from pathlib import Path import asyncio +import contextlib +import functools import os +import shutil import sys +import tempfile import threading +import uuid # Third Party import pytest @@ -12,6 +17,10 @@ pytest.importorskip("nixl", reason="nixl package is required for nixl tests") +# Third Party +from nixl._api import nixl_agent as NixlAgent +from nixl._api import nixl_agent_config as NixlAgentConfig + # First Party from lmcache.utils import CacheEngineKey from lmcache.v1.config import LMCacheEngineConfig @@ -22,6 +31,52 @@ NixlStorageBackend, NixlStorageConfig, ) +from lmcache.v1.transfer_channel.transfer_utils import get_correct_device + +# cuFile-based backends (GDS, GDS_MT) need a GDS-capable filesystem +_TEST_TMPDIR = os.environ.get("LMCACHE_TEST_TMPDIR") or None + + +@functools.lru_cache(maxsize=None) +def _can_register_file_with_nixl_backend(backend: str) -> bool: + """Probe ``cuFileHandleRegister`` via NIXL on the test scratch dir.""" + + probe_dir = tempfile.mkdtemp(prefix="nixl_gds_probe_", dir=_TEST_TMPDIR) + probe_path = os.path.join(probe_dir, "probe.bin") + fd = -1 + try: + agent = NixlAgent( + f"NixlGdsProbe_{uuid.uuid4().hex}", + NixlAgentConfig(backends=[]), + ) + agent.create_backend(backend, {}) + fd = os.open(probe_path, os.O_CREAT | os.O_RDWR, 0o600) + os.write(fd, b"\x00" * 4096) + agent.register_memory([(0, 4096, fd, "")], mem_type="FILE") + return True + except Exception: + return False + finally: + if fd >= 0: + with contextlib.suppress(OSError): + os.close(fd) + shutil.rmtree(probe_dir, ignore_errors=True) + + 
+_GDS_SKIP_REASON = ( + "NIXL {backend} cannot register file handles in this environment; " + "set LMCACHE_TEST_TMPDIR to a GDS-capable mount (ext4/xfs) to enable." +) + + +@pytest.fixture +def nixl_tmp_path(): + """Per-test scratch dir, honoring ``LMCACHE_TEST_TMPDIR``.""" + path = tempfile.mkdtemp(prefix="nixl_test_", dir=_TEST_TMPDIR) + try: + yield path + finally: + shutil.rmtree(path, ignore_errors=True) def create_key(chunk_hash: str): @@ -70,7 +125,9 @@ def run(config: LMCacheEngineConfig, shape, dtype): config, metadata, thread_loop, - dst_device=config.nixl_buffer_device, # Pass the device directly + dst_device=get_correct_device( + config.nixl_buffer_device, metadata.worker_id + ), ) assert len(backends) == 2 # NixlStorageBackend + LocalCPUBackend assert BACKEND_NAME in backends @@ -201,62 +258,82 @@ def test_eviction(new_idx, old_idx): @pytest.mark.no_shared_allocator @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA") -def test_nixl_gds_mt_cuda_backend(): +@pytest.mark.skipif( + not _can_register_file_with_nixl_backend("GDS_MT"), + reason=_GDS_SKIP_REASON.format(backend="GDS_MT"), +) +def test_nixl_gds_mt_cuda_backend(nixl_tmp_path): BASE_DIR = Path(__file__).parent config = LMCacheEngineConfig.from_file(BASE_DIR / "data/nixl.yaml") dtype = torch.bfloat16 - shape = [2048, 2048] + shape = torch.Size([2048, 2048]) - config.nixl_buffer_device = "cuda:0" # Use explicit device + config.nixl_buffer_device = "cuda" config.extra_config["nixl_backend"] = "GDS_MT" config.extra_config["enable_cuda"] = True + config.extra_config["nixl_path"] = nixl_tmp_path run(config, shape, dtype) @pytest.mark.no_shared_allocator -def test_nixl_gds_mt_cpu_backend(): +@pytest.mark.skipif( + not _can_register_file_with_nixl_backend("GDS_MT"), + reason=_GDS_SKIP_REASON.format(backend="GDS_MT"), +) +def test_nixl_gds_mt_cpu_backend(nixl_tmp_path): BASE_DIR = Path(__file__).parent config = LMCacheEngineConfig.from_file(BASE_DIR / "data/nixl.yaml") dtype = 
torch.bfloat16 - shape = [2048, 2048] + shape = torch.Size([2048, 2048]) config.nixl_buffer_device = "cpu" config.extra_config["nixl_backend"] = "GDS_MT" config.extra_config["enable_cuda"] = False + config.extra_config["nixl_path"] = nixl_tmp_path run(config, shape, dtype) @pytest.mark.no_shared_allocator @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA") -def test_nixl_gds_cuda_backend(): +@pytest.mark.skipif( + not _can_register_file_with_nixl_backend("GDS"), + reason=_GDS_SKIP_REASON.format(backend="GDS"), +) +def test_nixl_gds_cuda_backend(nixl_tmp_path): BASE_DIR = Path(__file__).parent config = LMCacheEngineConfig.from_file(BASE_DIR / "data/nixl.yaml") dtype = torch.bfloat16 - shape = [2048, 2048] + shape = torch.Size([2048, 2048]) - config.nixl_buffer_device = "cuda:0" # Use explicit device + config.nixl_buffer_device = "cuda" config.extra_config["nixl_backend"] = "GDS" config.extra_config["enable_cuda"] = True + config.extra_config["nixl_path"] = nixl_tmp_path run(config, shape, dtype) @pytest.mark.no_shared_allocator -def test_nixl_gds_cpu_backend(): +@pytest.mark.skipif( + not _can_register_file_with_nixl_backend("GDS"), + reason=_GDS_SKIP_REASON.format(backend="GDS"), +) +def test_nixl_gds_cpu_backend(nixl_tmp_path): BASE_DIR = Path(__file__).parent config = LMCacheEngineConfig.from_file(BASE_DIR / "data/nixl.yaml") dtype = torch.bfloat16 - shape = [2048, 2048] + shape = torch.Size([2048, 2048]) config.nixl_buffer_device = "cpu" config.extra_config["nixl_backend"] = "GDS" config.extra_config["enable_cuda"] = False + config.extra_config["nixl_path"] = nixl_tmp_path run(config, shape, dtype) @@ -305,16 +382,17 @@ def test_nixl_endpoint_list_malformed_url_raises(): @pytest.mark.no_shared_allocator -def test_nixl_posix_backend(): +def test_nixl_posix_backend(nixl_tmp_path): BASE_DIR = Path(__file__).parent config = LMCacheEngineConfig.from_file(BASE_DIR / "data/nixl.yaml") dtype = torch.bfloat16 - shape = [2048, 2048] + shape = 
torch.Size([2048, 2048]) config.nixl_buffer_device = "cpu" config.extra_config["nixl_backend"] = "POSIX" config.extra_config["enable_cuda"] = False + config.extra_config["nixl_path"] = nixl_tmp_path run(config, shape, dtype) From a68bd0a0d1465f5a1a6b69d1ac8ca2e0be54932b Mon Sep 17 00:00:00 2001 From: Ilya Yanok Date: Tue, 26 May 2026 19:10:29 +0200 Subject: [PATCH 4/5] Iyanok/huge pages (#2643) * memory: huge page support (C bits) Add 3 pairs of alloc/free functions for huge pages. The 4th option is for SHM and it's missing, since normal SHM use cases don't support huge pages. Old API is left untouched, except for a small refactor: common mmap code is factored out. This doesn't affect the behavior. Signed-off-by: Ilya Yanok * memory_management: use new functions for huge pages Change two allocators to support the use_hugepages flag. Signed-off-by: Ilya Yanok * memory_management: log if allocating huge pages fails Signed-off-by: Ilya Yanok * non_cuda_equivalents: change memory functions to match cuda Adding use_hugepages to alloc and size to free. There is no actual huge page support though. It should be pretty easy to add to non-shm cases (mmap and give torch a buffer). For shm we need to look into the shared_memory module implementation. Signed-off-by: Ilya Yanok * local_cpu: support for huge pages Signed-off-by: Ilya Yanok * nixl_storage: add config option to alloc huge pages We want to be able to alloc huge pages for the NIXL buffer in DRAM.
Signed-off-by: Ilya Yanok * PR comments fixes, Code cleanup Signed-off-by: Guy Ealey Morag * Skip test if not enough free huge pages Signed-off-by: Guy Ealey Morag * Fix non-cuda fallbacks Signed-off-by: Guy Ealey Morag * Add docs Signed-off-by: Guy Ealey Morag * Remove unused func Signed-off-by: Guy Ealey Morag * Fix hugepage info to check for 2 MiB only Signed-off-by: Guy Ealey Morag * Fix pre-commit issues Signed-off-by: Guy Ealey Morag * Merge branch 'dev' into iyanok/huge-pages Signed-off-by: Guy Ealey Morag --------- Signed-off-by: Ilya Yanok Signed-off-by: Guy Ealey Morag Co-authored-by: Guy Ealey Morag Co-authored-by: Yihua Cheng --- csrc/mem_alloc.cpp | 95 ++++++++-- csrc/mem_alloc.h | 10 ++ csrc/pybind.cpp | 6 + docs/source/api_reference/configurations.rst | 5 + .../kv_cache/storage_backends/cpu_ram.rst | 50 ++++++ .../source/kv_cache/storage_backends/nixl.rst | 5 +- lmcache/python_ops_fallback.py | 37 ++++ lmcache/v1/config.py | 5 + lmcache/v1/memory_management.py | 166 ++++++++++++++---- .../v1/storage_backend/local_cpu_backend.py | 7 + .../storage_backend/nixl_storage_backend.py | 22 ++- tests/v1/test_memory_management.py | 58 ++++++ 12 files changed, 417 insertions(+), 49 deletions(-) diff --git a/csrc/mem_alloc.cpp b/csrc/mem_alloc.cpp index b659b9cae52..9ac21c1af76 100644 --- a/csrc/mem_alloc.cpp +++ b/csrc/mem_alloc.cpp @@ -1,8 +1,10 @@ #include #include #include +#include #include #include +#include #include #include #include @@ -10,6 +12,24 @@ #include // for MPOL_BIND, MPOL_MF_MOVE, MPOL_MF_STRICT #include "mem_alloc.h" +static constexpr size_t HUGEPAGE_SIZE = 2UL * 1024 * 1024; // MAP_HUGE_2MB + +static inline size_t _align_hugepage(size_t size) { + return (size + HUGEPAGE_SIZE - 1) & ~(HUGEPAGE_SIZE - 1); +} + +static void* _mmap_anon(size_t size, bool hugepages) { + int flags = MAP_PRIVATE | MAP_ANONYMOUS; + if (hugepages) { + flags |= MAP_HUGETLB | MAP_HUGE_2MB; + } + void* ptr = mmap(nullptr, size, PROT_READ | PROT_WRITE, flags, -1, 0); + 
if (ptr == MAP_FAILED) { + throw std::runtime_error(std::string("mmap failed: ") + strerror(errno)); + } + return ptr; +} + uintptr_t alloc_pinned_ptr(size_t size, unsigned int flags) { void* ptr = nullptr; cudaError_t err = cudaHostAlloc(&ptr, size, flags); @@ -26,6 +46,36 @@ void free_pinned_ptr(uintptr_t ptr) { } } +uintptr_t alloc_hugepage_pinned_ptr(size_t size, unsigned int flags) { + size = _align_hugepage(size); + void* ptr = _mmap_anon(size, true); + + cudaError_t st = cudaHostRegister(ptr, size, flags); + if (st != cudaSuccess) { + munmap(ptr, size); + throw std::runtime_error(std::string("cudaHostRegister failed: ") + + cudaGetErrorString(st)); + } + + return reinterpret_cast(ptr); +} + +void free_hugepage_pinned_ptr(uintptr_t ptr, size_t size) { + size = _align_hugepage(size); + void* p = reinterpret_cast(ptr); + + // Unpin first, then unmap. + cudaError_t st = cudaHostUnregister(p); + if (st != cudaSuccess) { + munmap(p, size); + throw std::runtime_error(std::string("cudaHostUnregister failed: ") + + cudaGetErrorString(st)); + } + if (munmap(p, size) != 0) { + throw std::runtime_error(std::string("munmap failed: ") + strerror(errno)); + } +} + void batched_memcpy(const std::vector& src_ptrs, const std::vector& dst_ptrs, const std::vector& sizes) { @@ -43,10 +93,11 @@ void batched_memcpy(const std::vector& src_ptrs, } } -static void first_touch(void* p, size_t size) { - const long ps = sysconf(_SC_PAGESIZE); +static void first_touch(void* p, size_t size, bool hugepages) { + const size_t ps = + hugepages ? HUGEPAGE_SIZE : static_cast(sysconf(_SC_PAGESIZE)); for (size_t off = 0; off < size; off += ps) { - volatile char* c = (volatile char*)p + off; + volatile char* c = static_cast(p) + off; *c = 0; } } @@ -58,11 +109,12 @@ static inline int mbind_sys(void* addr, unsigned long len, int mode, return (rc == -1) ? 
-errno : 0; } -uintptr_t alloc_numa_ptr(size_t size, int node) { - void* ptr = mmap(nullptr, size, PROT_READ | PROT_WRITE, - MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); - if (ptr == MAP_FAILED) - throw std::runtime_error(std::string("mmap failed: ") + strerror(errno)); +static uintptr_t _alloc_numa_impl(size_t size, int node, bool hugepages) { + if (hugepages) { + assert(size % HUGEPAGE_SIZE == 0); + } + + void* ptr = _mmap_anon(size, hugepages); // Maximum of 64 numa nodes unsigned long mask = 1UL << node; @@ -74,11 +126,15 @@ uintptr_t alloc_numa_ptr(size_t size, int node) { throw std::runtime_error(std::string("mbind failed: ") + strerror(err)); } - first_touch(ptr, size); + first_touch(ptr, size, hugepages); return reinterpret_cast(ptr); } +uintptr_t alloc_numa_ptr(size_t size, int node) { + return _alloc_numa_impl(size, node, false); +} + void free_numa_ptr(uintptr_t ptr, size_t size) { void* p = reinterpret_cast(ptr); if (munmap(p, size) != 0) { @@ -86,8 +142,9 @@ void free_numa_ptr(uintptr_t ptr, size_t size) { } } -uintptr_t alloc_pinned_numa_ptr(size_t size, int node) { - void* ptr = reinterpret_cast(alloc_numa_ptr(size, node)); +static uintptr_t _alloc_pinned_numa_impl(size_t size, int node, + bool hugepages) { + void* ptr = reinterpret_cast(_alloc_numa_impl(size, node, hugepages)); cudaError_t st = cudaHostRegister(ptr, size, 0); if (st != cudaSuccess) { @@ -99,6 +156,15 @@ uintptr_t alloc_pinned_numa_ptr(size_t size, int node) { return reinterpret_cast(ptr); } +uintptr_t alloc_pinned_numa_ptr(size_t size, int node) { + return _alloc_pinned_numa_impl(size, node, false); +} + +uintptr_t alloc_hugepage_pinned_numa_ptr(size_t size, int node) { + size = _align_hugepage(size); + return _alloc_pinned_numa_impl(size, node, true); +} + void free_pinned_numa_ptr(uintptr_t ptr, size_t size) { void* p = reinterpret_cast(ptr); // Unpin first, then unmap. 
@@ -113,6 +179,11 @@ void free_pinned_numa_ptr(uintptr_t ptr, size_t size) { } } +void free_hugepage_pinned_numa_ptr(uintptr_t ptr, size_t size) { + size = _align_hugepage(size); + free_pinned_numa_ptr(ptr, size); +} + uintptr_t alloc_shm_pinned_ptr(size_t size, const std::string& shm_name) { int fd = shm_open(shm_name.c_str(), O_CREAT | O_RDWR, 0600); if (fd < 0) @@ -133,7 +204,7 @@ uintptr_t alloc_shm_pinned_ptr(size_t size, const std::string& shm_name) { throw std::runtime_error(std::string("mmap failed: ") + strerror(errno)); } - first_touch(ptr, size); + first_touch(ptr, size, false); cudaError_t st = cudaHostRegister(ptr, size, 0); if (st != cudaSuccess) { diff --git a/csrc/mem_alloc.h b/csrc/mem_alloc.h index 3e189ef2c2f..58c83b7fa8e 100644 --- a/csrc/mem_alloc.h +++ b/csrc/mem_alloc.h @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 + #include #include #include @@ -15,3 +17,11 @@ void free_numa_ptr(uintptr_t ptr, size_t size); void free_pinned_numa_ptr(uintptr_t ptr, size_t size); void free_shm_pinned_ptr(uintptr_t ptr, size_t size, const std::string& shm_name); + +// Hugepage variants (MAP_HUGETLB). Not available for shm: /dev/shm usually +// uses tmpfs, and tmpfs does not support MAP_HUGETLB. 
+uintptr_t alloc_hugepage_pinned_ptr(size_t size, unsigned int flags); +uintptr_t alloc_hugepage_pinned_numa_ptr(size_t size, int node); + +void free_hugepage_pinned_ptr(uintptr_t ptr, size_t size); +void free_hugepage_pinned_numa_ptr(uintptr_t ptr, size_t size); diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 2228795d59a..817295bef5b 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -61,9 +61,15 @@ PYBIND11_MODULE(c_ops, m) { m.def("alloc_pinned_ptr", &alloc_pinned_ptr, py::call_guard()); m.def("free_pinned_ptr", &free_pinned_ptr); + m.def("alloc_hugepage_pinned_ptr", &alloc_hugepage_pinned_ptr, + py::call_guard()); + m.def("free_hugepage_pinned_ptr", &free_hugepage_pinned_ptr); m.def("alloc_pinned_numa_ptr", &alloc_pinned_numa_ptr, py::call_guard()); m.def("free_pinned_numa_ptr", &free_pinned_numa_ptr); + m.def("alloc_hugepage_pinned_numa_ptr", &alloc_hugepage_pinned_numa_ptr, + py::call_guard()); + m.def("free_hugepage_pinned_numa_ptr", &free_hugepage_pinned_numa_ptr); m.def("alloc_numa_ptr", &alloc_numa_ptr, py::call_guard()); m.def("free_numa_ptr", &free_numa_ptr); diff --git a/docs/source/api_reference/configurations.rst b/docs/source/api_reference/configurations.rst index 85b16f1bb9c..8c3506d7350 100644 --- a/docs/source/api_reference/configurations.rst +++ b/docs/source/api_reference/configurations.rst @@ -34,6 +34,9 @@ Basic cache settings that control the core functionality of LMCache. * - max_local_cpu_size - LMCACHE_MAX_LOCAL_CPU_SIZE - Maximum CPU cache size in GB. Default: 5.0 + * - local_cpu_use_hugepages + - LMCACHE_LOCAL_CPU_USE_HUGEPAGES + - Whether to use Linux hugepages (2 MB) for CPU-pinned KV cache memory. Not compatible with P2P mode or shared memory (multiprocess). Requires pre-allocated hugepages (``sysctl vm.nr_hugepages``). Values: true/false. Default: false * - local_disk - LMCACHE_LOCAL_DISK - Path (or comma-separated paths) to local disk cache directories. 
Format: ``"file:///path/to/cache"`` or ``"/path/a,/path/b"`` for multi-device I/O. See ``local_disk_path_sharding`` for how paths are assigned to GPUs. @@ -354,6 +357,8 @@ Settings for using Nixl as a storage backend instead of disaggregated prefill. T - Number of files or objects in the storage pool * - nixl_endpoint_list - List of object-storage endpoint URLs for per-worker distribution. Each TP worker selects an entry round-robin by ``local_worker_id``, overriding ``nixl_backend_params.endpoint_override``. Only applied when ``nixl_backend`` is ``"OBJ"`` (silently ignored otherwise). Each entry must start with ``http://`` or ``https://``; an empty list raises ``ValueError`` at engine init. + * - nixl_use_hugepages + - Whether to use Linux hugepages (2 MiB) for the NIXL CPU buffer. Requires pre-allocated hugepages (``sysctl vm.nr_hugepages``). Values: true/false. Default: false Additional Storage Configurations diff --git a/docs/source/kv_cache/storage_backends/cpu_ram.rst b/docs/source/kv_cache/storage_backends/cpu_ram.rst index ce1253615cd..c29d5c482c7 100644 --- a/docs/source/kv_cache/storage_backends/cpu_ram.rst +++ b/docs/source/kv_cache/storage_backends/cpu_ram.rst @@ -63,6 +63,56 @@ tokens into the pinned CPU RAM from the disk or remote storage (*if* the KV cach tokens are already stored there). This can preemptively avoid the latency of the disk and remote KV transfer if we predict these tokens will be requested soon (e.g. structured or agentic workflows). +.. _cpu_ram-hugepage-support: + +Hugepage Support +----------------- + +By default LMCache allocates CPU-pinned memory using regular 4 KiB pages. +For large KV cache buffers (multiple gigabytes), enabling **Linux hugepages** +(2 MiB pages) can reduce TLB (Translation Lookaside Buffer) pressure and +improve memory access performance. + +**System prerequisite** + +Hugepages must be pre-allocated at the OS level before LMCache starts. 
+To find the number of pages needed, divide the desired buffer size by 2 MiB and round up. +For example, 5 GB requires at least 2560 pages: + +.. code-block:: bash + + # Allocate 2560 hugepages (5 GB) + sudo sysctl -w vm.nr_hugepages=2560 + + # Make persistent across reboots + echo 'vm.nr_hugepages=2560' | sudo tee -a /etc/sysctl.conf + +Verify that pages are available: + +.. code-block:: bash + + grep HugePages /proc/meminfo + # HugePages_Total: 2560 + # HugePages_Free: 2560 + +**Configuration** + +.. code-block:: yaml + + local_cpu_use_hugepages: true + +Or via environment variable: + +.. code-block:: bash + + export LMCACHE_LOCAL_CPU_USE_HUGEPAGES=true + +**Restrictions** + +- Hugepages are **not compatible with P2P mode** (``enable_p2p: true``). +- Hugepages are **not compatible with shared memory** (``shm_name`` is set). +- On non-CUDA platforms, hugepages are not supported. Regular allocation will be used as fallback. + .. _cpu_ram-online-inference-example: Online Inference Example diff --git a/docs/source/kv_cache/storage_backends/nixl.rst b/docs/source/kv_cache/storage_backends/nixl.rst index 9f599ab1ef6..bc689b78b31 100644 --- a/docs/source/kv_cache/storage_backends/nixl.rst +++ b/docs/source/kv_cache/storage_backends/nixl.rst @@ -37,7 +37,8 @@ Example ``lmcache-config.yaml`` for POSIX backend: nixl_backend: POSIX nixl_pool_size: 64 nixl_path: /mnt/nixl/cache/ - use_direct_io: True + use_direct_io: true + nixl_use_hugepages: true # optional, requires pre-allocated hugepages Key settings: @@ -51,6 +52,8 @@ Key settings: - ``nixl_backend``: configuration of which nixl backend to use for storage. +- ``nixl_use_hugepages``: whether to use Linux hugepages (2 MiB) for the NIXL CPU buffer. Not supported for GPU buffers. Requires pre-allocated hugepages (``sysctl vm.nr_hugepages``). Default: ``false``. + .. note:: Supported backends are: ["GDS", "GDS_MT", "POSIX", "HF3FS", "OBJ", "AZURE_BLOB"].
diff --git a/lmcache/python_ops_fallback.py b/lmcache/python_ops_fallback.py index 10fce1d4afa..5c19d96f8b4 100644 --- a/lmcache/python_ops_fallback.py +++ b/lmcache/python_ops_fallback.py @@ -10,6 +10,7 @@ from typing import Optional, Tuple import ctypes import ctypes.util +import warnings # Third Party from numba import njit @@ -438,6 +439,42 @@ def free_shm_pinned_ptr(ptr: int, size: int = 0, shm_name: str = "") -> None: shm.unlink() +# Hugepage variants: non-CUDA platforms do not support hugepages, so these +# fall back to the same regular pinned allocation. + + +def alloc_hugepage_pinned_ptr(size: int, device_id: int = 0) -> int: + """Non-CUDA fallback for alloc_hugepage_pinned_ptr (no hugepage support).""" + warnings.warn( + "Hugepages requested but not available on non-CUDA platforms; " + "falling back to regular allocation.", + RuntimeWarning, + stacklevel=2, + ) + return alloc_pinned_ptr(size, device_id) + + +def free_hugepage_pinned_ptr(ptr: int, size: int = 0) -> None: + """Non-CUDA fallback for free_hugepage_pinned_ptr (no hugepage support).""" + free_pinned_ptr(ptr) + + +def alloc_hugepage_pinned_numa_ptr(size: int, numa_id: int = 0) -> int: + """Non-CUDA fallback for alloc_hugepage_pinned_numa_ptr (no hugepage support).""" + warnings.warn( + "Hugepages requested but not available on non-CUDA platforms; " + "falling back to regular allocation.", + RuntimeWarning, + stacklevel=2, + ) + return alloc_pinned_numa_ptr(size, numa_id) + + +def free_hugepage_pinned_numa_ptr(ptr: int, size: int = 0) -> None: + """Non-CUDA fallback for free_hugepage_pinned_numa_ptr (no hugepage support).""" + free_pinned_numa_ptr(ptr, size) + + def alloc_numa_ptr(size: int, numa_id: int = 0) -> int: """Non-CUDA equivalent of allocating numa memory and returning pointer to it. 
Note: Numa memory is not supported on non-CUDA.""" diff --git a/lmcache/v1/config.py b/lmcache/v1/config.py index c7b317d99ca..d52693cafed 100644 --- a/lmcache/v1/config.py +++ b/lmcache/v1/config.py @@ -75,6 +75,11 @@ "env_converter": _to_bool, }, "max_local_cpu_size": {"type": float, "default": 5.0, "env_converter": float}, + "local_cpu_use_hugepages": { + "type": bool, + "default": False, + "env_converter": _to_bool, + }, "reserve_local_cpu_size": {"type": float, "default": 0.0, "env_converter": float}, "local_disk": { "type": Optional[str], diff --git a/lmcache/v1/memory_management.py b/lmcache/v1/memory_management.py index 9e1a5182b3c..18509941c8e 100644 --- a/lmcache/v1/memory_management.py +++ b/lmcache/v1/memory_management.py @@ -372,25 +372,44 @@ def parent(self) -> Optional["MemoryAllocatorInterface"]: raise NotImplementedError +@dataclass +class PinnedAllocFree: + """Resolved alloc/free function pair for pinned CPU memory.""" + + alloc_fn: Any + alloc_args: tuple + free_fn: Any + free_args: tuple + + def alloc(self) -> int: + """Allocate pinned memory and return the raw pointer.""" + return self.alloc_fn(*self.alloc_args) + + def free(self, ptr: int) -> None: + """Free a previously allocated pinned-memory pointer.""" + self.free_fn(ptr, *self.free_args) + + def _resolve_pinned_alloc_free( numa_mapping: Optional[NUMAMapping] = None, shm_name: Optional[str] = None, size: Optional[int] = None, -) -> Tuple[ - tuple, # (alloc_fn, *alloc_args) - tuple, # (free_fn, *free_args_after_ptr) -]: + use_hugepages: bool = False, +) -> PinnedAllocFree: """Resolve the alloc/free function pair based on memory type. Returns: - A tuple of (alloc_info, free_info) where: - - alloc_info: (alloc_fn, *args) to call as alloc_fn(size, *args) - - free_info: (free_fn, *args) to call as free_fn(ptr, *args) + A PinnedAllocFree with the resolved functions and their extra + arguments. Call ``ptr = resolved.alloc()`` and ``resolved.free(ptr)``. 
""" if shm_name: - return ( - (lmc_ops.alloc_shm_pinned_ptr, shm_name), - (lmc_ops.free_shm_pinned_ptr, size, shm_name), + if use_hugepages: + raise ValueError("Hugepages are not supported with shared memory (shm)") + return PinnedAllocFree( + alloc_fn=lmc_ops.alloc_shm_pinned_ptr, + alloc_args=(size, shm_name), + free_fn=lmc_ops.free_shm_pinned_ptr, + free_args=(size, shm_name), ) elif numa_mapping: if torch_dev.is_available(): @@ -402,31 +421,103 @@ def _resolve_pinned_alloc_free( f"Current device {current_device_id} is not in the GPU NUMA mapping." ) numa_id = gpu_to_numa_mapping[current_device_id] - return ( - (lmc_ops.alloc_pinned_numa_ptr, numa_id), - (lmc_ops.free_pinned_numa_ptr, size), - ) + if use_hugepages: + return PinnedAllocFree( + alloc_fn=lmc_ops.alloc_hugepage_pinned_numa_ptr, + alloc_args=(size, numa_id), + free_fn=lmc_ops.free_hugepage_pinned_numa_ptr, + free_args=(size,), + ) + else: + return PinnedAllocFree( + alloc_fn=lmc_ops.alloc_pinned_numa_ptr, + alloc_args=(size, numa_id), + free_fn=lmc_ops.free_pinned_numa_ptr, + free_args=(size,), + ) else: - return ( - (lmc_ops.alloc_pinned_ptr, 0), - (lmc_ops.free_pinned_ptr,), - ) + flags = 0 + if use_hugepages: + return PinnedAllocFree( + alloc_fn=lmc_ops.alloc_hugepage_pinned_ptr, + alloc_args=(size, flags), + free_fn=lmc_ops.free_hugepage_pinned_ptr, + free_args=(size,), + ) + else: + return PinnedAllocFree( + alloc_fn=lmc_ops.alloc_pinned_ptr, + alloc_args=(size, flags), + free_fn=lmc_ops.free_pinned_ptr, + free_args=(), + ) + + +def _read_hugepage_info() -> Optional[Tuple[int, int, int]]: + """Read hugepage pool stats from sysfs. + + NOTE: We only use 2 MiB hugepages, so the pool stats are taken from + the 2 MiB pool directly rather than the system default pool reported in + ``/proc/meminfo`` (which can be 1 GiB on some hosts). + + Returns: + ``(nr_hugepages, free_hugepages, page_size_mb)`` for the hugepage + pool, or ``None`` if the sysfs entries are unavailable. 
+ """ + base = "/sys/kernel/mm/hugepages/hugepages-2048kB" + try: + with open(f"{base}/nr_hugepages") as f: + total = int(f.read().strip()) + with open(f"{base}/free_hugepages") as f: + free = int(f.read().strip()) + return total, free, 2 + except (OSError, ValueError): + return None def _allocate_cpu_memory( size: int, numa_mapping: Optional[NUMAMapping] = None, shm_name: Optional[str] = None, + use_hugepages: bool = False, ) -> torch.Tensor: if size == 0: return torch.empty(0, dtype=torch.uint8) - alloc_info, _ = _resolve_pinned_alloc_free( + resolved = _resolve_pinned_alloc_free( numa_mapping, shm_name, + size, + use_hugepages, ) - alloc_fn, *alloc_args = alloc_info - ptr = alloc_fn(size, *alloc_args) + + try: + ptr = resolved.alloc() + except RuntimeError as e: + if use_hugepages and "mmap failed" in str(e): + diag = _read_hugepage_info() + if diag is not None: + total, free, page_mb = diag + page_bytes = page_mb * 1024 * 1024 + needed = (size + page_bytes - 1) // page_bytes + logger.error( + "Failed to allocate huge pages. " + "Pool has %d pages (%d free, each %d MiB). " + "Requested %d bytes (%d pages). " + "Please grow the %d MiB hugepage pool.", + total, + free, + page_mb, + size, + needed, + page_mb, + ) + else: + logger.error( + "Failed to allocate huge pages. " + "Please grow the 2 MiB hugepage pool." 
+ ) + raise array_type = ctypes.c_uint8 * size buf = array_type.from_address(ptr) @@ -440,17 +531,18 @@ def _free_cpu_memory( size: int | None = None, numa_mapping: Optional[NUMAMapping] = None, shm_name: Optional[str] = None, + use_hugepages: bool = False, ) -> None: if torch_dev.is_available(): torch_dev.synchronize() - _, free_info = _resolve_pinned_alloc_free( + resolved = _resolve_pinned_alloc_free( numa_mapping, shm_name, - size=size, + size, + use_hugepages, ) - free_fn, *free_args = free_info - free_fn(buffer.data_ptr(), *free_args) + resolved.free(buffer.data_ptr()) def _allocate_gpu_memory( @@ -534,6 +626,7 @@ def get_shape(self) -> torch.Size: return self.meta.shape def get_dtype(self) -> torch.dtype: + assert self.meta.dtype is not None return self.meta.dtype def get_shapes(self) -> list[torch.Size]: @@ -1819,12 +1912,12 @@ def memcheck(self): def __str__(self): return "PagedTensorMemoryAllocator" - def get_paged_buffers(self) -> list[torch.Tensor]: + def get_paged_buffers(self) -> tuple[torch.Tensor, ...]: """ - Get the list of paged buffers for fixed buffer registration. + Get the paged buffers for fixed buffer registration. Returns: - List of paged buffer tensors that can be registered with io_uring + Tuple of paged buffer tensors that can be registered with io_uring for true zero copy operations. """ return self.paged_buffers @@ -2062,12 +2155,16 @@ class MixedMemoryAllocator(MemoryAllocatorInterface): (2) byte_array buffer memory. """ - def __init__(self, size: int, use_paging: bool = False, **kwargs): + def __init__( + self, size: int, use_paging: bool = False, use_hugepages: bool = False, **kwargs + ): """ :param int size: The size of the pinned memory in bytes. + :param bool use_hugepages: Whether to use hugepages. 
""" self.numa_mapping = kwargs.get("numa_mapping", None) + self.use_hugepages = use_hugepages self.align_bytes = kwargs.get("align_bytes", AddressManager.ALIGN_BYTES) if self.align_bytes <= 0 or self.align_bytes & (self.align_bytes - 1) != 0: raise ValueError("align_bytes must be a positive power of two") @@ -2083,7 +2180,9 @@ def __init__(self, size: int, use_paging: bool = False, **kwargs): self.size = size - self.buffer = _allocate_cpu_memory(size, self.numa_mapping, self.shm_name) + self.buffer = _allocate_cpu_memory( + size, self.numa_mapping, self.shm_name, use_hugepages=use_hugepages + ) self._unregistered = False @@ -2218,15 +2317,16 @@ def close(self): self.size, self.numa_mapping, self.shm_name, + use_hugepages=self.use_hugepages, ) self._unregistered = True - def get_paged_buffers(self) -> Optional[list[torch.Tensor]]: + def get_paged_buffers(self) -> Optional[tuple[torch.Tensor, ...]]: """ - Get the list of paged buffers for fixed buffer registration. + Get the paged buffers for fixed buffer registration. Returns: - List of paged buffer tensors if using paged allocator, None otherwise. + Tuple of paged buffer tensors if using paged allocator, None otherwise. These buffers can be registered with io_uring for true zero copy operations. 
""" if isinstance(self.pin_allocator, PagedTensorMemoryAllocator): diff --git a/lmcache/v1/storage_backend/local_cpu_backend.py b/lmcache/v1/storage_backend/local_cpu_backend.py index c072465db92..fbc2e461e98 100644 --- a/lmcache/v1/storage_backend/local_cpu_backend.py +++ b/lmcache/v1/storage_backend/local_cpu_backend.py @@ -350,6 +350,7 @@ def initialize_allocator( metadata: Optional[LMCacheMetadata] = None, ) -> MemoryAllocatorInterface: cpu_size = config.max_local_cpu_size + use_hugepages = config.local_cpu_use_hugepages if metadata is not None: # save_only_first_rank only works when use mla @@ -381,6 +382,9 @@ def initialize_allocator( ) if config.enable_p2p: + if use_hugepages: + raise ValueError("Hugepages are not supported with P2P mode") + # TODO(baoloongmao): Add lazy memory allocator support for P2P mode # For now, keep the original P2P implementation assert metadata is not None @@ -483,6 +487,7 @@ def initialize_allocator( return MixedMemoryAllocator( align_cpu_size_bytes, use_paging=True, + use_hugepages=False, **kwargs, ) @@ -492,11 +497,13 @@ def initialize_allocator( cpu_size_bytes, numa_mapping=numa_mapping, align_bytes=allocator_align_bytes, + use_hugepages=use_hugepages, ) return MixedMemoryAllocator( cpu_size_bytes, numa_mapping=numa_mapping, config=config, + use_hugepages=use_hugepages, ) @staticmethod diff --git a/lmcache/v1/storage_backend/nixl_storage_backend.py b/lmcache/v1/storage_backend/nixl_storage_backend.py index 856f351d8dc..a6895b31f8d 100644 --- a/lmcache/v1/storage_backend/nixl_storage_backend.py +++ b/lmcache/v1/storage_backend/nixl_storage_backend.py @@ -90,6 +90,7 @@ class NixlStorageConfig: enable_async_put: bool use_direct_io: bool path: str + use_hugepages: bool enable_prog_thread: bool sync_mode: Optional[Any] # nixl_thread_sync_t, None if unsupported @@ -149,6 +150,7 @@ def from_cache_engine_config( len(endpoint_list), ) path = extra_config.get("nixl_path") + use_hugepages = extra_config.get("nixl_use_hugepages", False) 
enable_prog_thread = extra_config.get("nixl_enable_prog_thread", True) sync_mode_str = extra_config.get("nixl_sync_mode", None) if sync_mode_str is not None and not _NIXL_SYNC_MODE_SUPPORTED: @@ -215,6 +217,7 @@ def from_cache_engine_config( enable_async_put=enable_async_put, use_direct_io=use_direct_io, path=path, + use_hugepages=use_hugepages, enable_prog_thread=enable_prog_thread, sync_mode=sync_mode, ) @@ -706,6 +709,7 @@ def __init__( self.progress_lock = threading.RLock() self.progress_set: Set[CacheEngineKey] = set() + self.nixl_config = nixl_config self.memory_allocator = self.initialize_allocator(config, metadata) def initialize_allocator( @@ -718,15 +722,23 @@ def initialize_allocator( "enable_nixl_storage" ) assert enable_nixl_storage + corrected_device = get_correct_device( config.nixl_buffer_device, metadata.worker_id, ) + self.use_hugepages = self.nixl_config.use_hugepages + self.buffer_size = config.nixl_buffer_size if corrected_device == "cpu": - self.buffer = _allocate_cpu_memory(config.nixl_buffer_size) + self.buffer = _allocate_cpu_memory( + config.nixl_buffer_size, use_hugepages=self.use_hugepages + ) self.free_pinned_buffer = True else: + if self.use_hugepages: + logger.warning("Hugepages are not supported for GPU memory allocation") + self.use_hugepages = False base_buffer, self.buffer = _allocate_gpu_memory( config.nixl_buffer_size, corrected_device ) @@ -1141,7 +1153,9 @@ def close(self) -> None: self.memory_allocator.close() if self.free_pinned_buffer: - _free_cpu_memory(self.buffer) + _free_cpu_memory( + self.buffer, self.buffer_size, use_hugepages=self.use_hugepages + ) class NixlDynamicStorageBackend(NixlStorageBackend): @@ -1862,4 +1876,6 @@ def close(self) -> None: self.memory_allocator.close() if self.free_pinned_buffer: - _free_cpu_memory(self.buffer) + _free_cpu_memory( + self.buffer, self.buffer_size, use_hugepages=self.use_hugepages + ) diff --git a/tests/v1/test_memory_management.py b/tests/v1/test_memory_management.py index 
aad1a0feb47..9f6429ce9ed 100644 --- a/tests/v1/test_memory_management.py +++ b/tests/v1/test_memory_management.py @@ -21,9 +21,14 @@ PinMemoryAllocator, TensorMemoryAllocator, TensorMemoryObj, + _allocate_cpu_memory, + _free_cpu_memory, + _read_hugepage_info, ) from lmcache.v1.pin_monitor import PinMonitor +HUGEPAGE_SIZE = 2 * 1024 * 1024 # MAP_HUGE_2MB + def check_allocator(allocator, max_size): # 512 * 512 * 4 = 1MB @@ -885,3 +890,56 @@ def test_allocation_and_free_interleaved(self, lazy_allocator_cls): assert allocator.memcheck() allocator.close() + + +def _get_num_free_hugepages() -> int: + """Return the number of free huge pages, or 0 if unknown.""" + info = _read_hugepage_info() + if info is None: + return 0 + _, free, _ = info + return free + + +@pytest.mark.skipif( + _get_num_free_hugepages() < 1, + reason="Requires at least 1 free huge page (sysctl vm.nr_hugepages)", +) +class TestHugepageAllocation: + """Tests for hugepage-backed CPU memory allocation. + + Skipped unless the system has pre-allocated huge pages. 
+ """ + + def test_allocate_and_free(self): + """Allocate one huge page worth of memory and free it.""" + buf = _allocate_cpu_memory(HUGEPAGE_SIZE, use_hugepages=True) + assert buf.numel() == HUGEPAGE_SIZE + assert buf.dtype == torch.uint8 + buf[0] = 42 + buf[-1] = 99 + assert buf[0].item() == 42 + assert buf[-1].item() == 99 + _free_cpu_memory(buf, size=HUGEPAGE_SIZE, use_hugepages=True) + + @pytest.mark.skipif( + _get_num_free_hugepages() < 4, + reason="Requires at least 4 free huge pages (sysctl vm.nr_hugepages)", + ) + def test_allocate_multiple_pages(self): + """Allocate several huge pages and verify the buffer is usable.""" + size = 4 * HUGEPAGE_SIZE + buf = _allocate_cpu_memory(size, use_hugepages=True) + assert buf.numel() == size + buf.fill_(7) + assert buf[size // 2].item() == 7 + _free_cpu_memory(buf, size=size, use_hugepages=True) + + def test_read_hugepage_info(self): + """_read_hugepage_info returns valid data on Linux.""" + info = _read_hugepage_info() + assert info is not None + total, free, page_mb = info + assert total > 0 + assert free >= 0 + assert page_mb == 2 From bda3af18db83a9e1a2da9d8ed8ee7ca064b890ef Mon Sep 17 00:00:00 2001 From: Yihua Cheng Date: Tue, 26 May 2026 12:43:04 -0500 Subject: [PATCH 5/5] [MP][Core] Refactor MPCacheEngine for better extendability (#3391) * [Add] refactoring for the LMCache MP cache engine Signed-off-by: Yihua Cheng --- lmcache/v1/multiprocess/config.py | 21 +- lmcache/v1/multiprocess/engine_context.py | 157 ++ lmcache/v1/multiprocess/engine_module.py | 65 + lmcache/v1/multiprocess/http_server.py | 7 +- lmcache/v1/multiprocess/modules/__init__.py | 1 + .../{blend_server_v2.py => modules/blend.py} | 699 +++----- .../v1/multiprocess/modules/gpu_transfer.py | 670 +++++++ lmcache/v1/multiprocess/modules/lookup.py | 467 +++++ lmcache/v1/multiprocess/modules/management.py | 129 ++ .../multiprocess/modules/non_gpu_transfer.py | 347 ++++ lmcache/v1/multiprocess/server.py | 1543 ++--------------- 
tests/v1/multiprocess/test_blend_server_v2.py | 75 +- tests/v1/multiprocess/test_cache_server.py | 2 +- tests/v1/multiprocess/test_free_locks.py | 47 +- .../test_non_cuda_data_transfer.py | 42 +- .../v1/multiprocess/test_query_lookup_hits.py | 63 +- 16 files changed, 2375 insertions(+), 1960 deletions(-) create mode 100644 lmcache/v1/multiprocess/engine_context.py create mode 100644 lmcache/v1/multiprocess/engine_module.py create mode 100644 lmcache/v1/multiprocess/modules/__init__.py rename lmcache/v1/multiprocess/{blend_server_v2.py => modules/blend.py} (68%) create mode 100644 lmcache/v1/multiprocess/modules/gpu_transfer.py create mode 100644 lmcache/v1/multiprocess/modules/lookup.py create mode 100644 lmcache/v1/multiprocess/modules/management.py create mode 100644 lmcache/v1/multiprocess/modules/non_gpu_transfer.py diff --git a/lmcache/v1/multiprocess/config.py b/lmcache/v1/multiprocess/config.py index 8873bd24d08..71670947a80 100644 --- a/lmcache/v1/multiprocess/config.py +++ b/lmcache/v1/multiprocess/config.py @@ -39,9 +39,13 @@ class MPServerConfig: engine_type: str = "default" """Cache engine backend type - ('default' for MPCacheEngine, 'blend' for BlendEngineV2). + ('default' for standard prefix caching, 'blend' when cacheblend is enabled). """ + transfer_mode: str = "gpu" + """Transfer mode: 'gpu' for GPU-based IPC transfer (STORE/RETRIEVE), + 'non_gpu' for non-GPU-based transfer (PREPARE/COMMIT).""" + runtime_plugin_config: "RuntimePluginConfig" = field( default_factory=lambda: RuntimePluginConfig() ) @@ -146,9 +150,17 @@ def add_mp_server_args( type=str, default="default", choices=["default", "blend"], - help="Cache engine backend type. 'default' uses MPCacheEngine, " - "'blend' uses BlendEngineV2 for cross-request KV reuse. " - "Default is 'default'.", + help="Cache engine backend type. 'default' uses standard prefix caching, " + "'blend' when cacheblend is enabled. 
Default is 'default'.", + ) + mp_group.add_argument( + "--transfer-mode", + type=str, + default="gpu", + choices=["gpu", "non_gpu"], + help="Transfer mode: 'gpu' for GPU-based IPC transfer " + "(STORE/RETRIEVE), 'non_gpu' for non-GPU-based transfer " + "(PREPARE/COMMIT). Default is 'gpu'.", ) mp_group.add_argument( "--runtime-plugin-locations", @@ -198,6 +210,7 @@ def parse_args_to_mp_server_config( max_cpu_workers=max_cpu, hash_algorithm=args.hash_algorithm, engine_type=args.engine_type, + transfer_mode=args.transfer_mode, runtime_plugin_config=RuntimePluginConfig( locations=(args.runtime_plugin_locations or []), extra_config=plugin_extra, diff --git a/lmcache/v1/multiprocess/engine_context.py b/lmcache/v1/multiprocess/engine_context.py new file mode 100644 index 00000000000..890a56ca8c4 --- /dev/null +++ b/lmcache/v1/multiprocess/engine_context.py @@ -0,0 +1,157 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Shared context and layout descriptor registry for engine modules.""" + +# Standard +import threading + +# First Party +from lmcache.logging import init_logger +from lmcache.v1.distributed.api import ( + MemoryLayoutDesc, + ObjectKey, + ipc_key_to_object_keys, +) +from lmcache.v1.distributed.config import StorageManagerConfig +from lmcache.v1.distributed.storage_manager import StorageManager +from lmcache.v1.mp_observability.event_bus import EventBus, get_event_bus +from lmcache.v1.multiprocess.custom_types import IPCCacheEngineKey +from lmcache.v1.multiprocess.session import SessionManager +from lmcache.v1.multiprocess.token_hasher import TokenHasher + +logger = init_logger(__name__) + + +class LayoutDescRegistry: + """Thread-safe registry mapping (model_name, world_size) to MemoryLayoutDesc. + + Modules write to this registry when KV caches are registered. + Consumers (e.g. LookupModule) read from it to find layout descriptors + for prefetch tasks. 
+ """ + + def __init__(self) -> None: + # Key: (model_name, world_size) -> MemoryLayoutDesc + self._registry: dict[tuple[str, int], MemoryLayoutDesc] = {} + self._lock = threading.Lock() + + def register( + self, + model_name: str, + world_size: int, + layout_desc: MemoryLayoutDesc, + ) -> None: + """Register a layout descriptor for a (model_name, world_size) pair. + + Args: + model_name: The model name. + world_size: The world size. + layout_desc: The memory layout descriptor. + """ + with self._lock: + self._registry[(model_name, world_size)] = layout_desc + + def unregister(self, model_name: str, world_size: int) -> None: + """Remove a layout descriptor for a (model_name, world_size) pair. + + Args: + model_name: The model name. + world_size: The world size. + """ + with self._lock: + self._registry.pop((model_name, world_size), None) + + def find(self, model_name: str, world_size: int) -> MemoryLayoutDesc | None: + """Look up a layout descriptor by (model_name, world_size). + + Args: + model_name: The model name. + world_size: The world size. + + Returns: + The layout descriptor if found, otherwise None. + """ + with self._lock: + return self._registry.get((model_name, world_size)) + + +class MPCacheEngineContext: + """Shared infrastructure for all engine modules. + + Holds the storage manager, token hasher, session manager, event bus, + and layout descriptor registry. Modules receive this context at init + and use it for shared operations. + + Args: + storage_manager_config: Configuration for the storage manager. + chunk_size: Chunk size for KV cache operations. + hash_algorithm: Hash algorithm for token hashing. 
+ """ + + def __init__( + self, + storage_manager_config: StorageManagerConfig, + chunk_size: int = 256, + hash_algorithm: str = "blake3", + ) -> None: + self._chunk_size = chunk_size + self._storage_manager = StorageManager(storage_manager_config) + self._token_hasher = TokenHasher( + chunk_size=chunk_size, hash_algorithm=hash_algorithm + ) + self._session_manager = SessionManager(self._token_hasher) + self._event_bus = get_event_bus() + self._layout_desc_registry = LayoutDescRegistry() + + @property + def chunk_size(self) -> int: + """Chunk size for KV cache operations.""" + return self._chunk_size + + @property + def storage_manager(self) -> StorageManager: + """The storage manager instance.""" + return self._storage_manager + + @property + def token_hasher(self) -> TokenHasher: + """The token hasher for computing chunk hashes.""" + return self._token_hasher + + @property + def session_manager(self) -> SessionManager: + """The session manager for request lifecycle tracking.""" + return self._session_manager + + @property + def event_bus(self) -> EventBus: + """The event bus for observability events.""" + return self._event_bus + + @property + def layout_desc_registry(self) -> LayoutDescRegistry: + """Registry mapping (model_name, world_size) to MemoryLayoutDesc.""" + return self._layout_desc_registry + + def resolve_obj_keys(self, key: IPCCacheEngineKey) -> list[ObjectKey]: + """Resolve object keys from an IPC cache key. + + Uses the session manager to track token state and the token hasher + to compute chunk hashes for the requested range. + + Args: + key: IPC cache key describing model/session/token range. + + Returns: + Resolved object keys for the requested token range. + + Raises: + ValueError: If ``key.worker_id`` is ``None``. 
+ """ + session = self.session_manager.get_or_create(key.request_id) + session.set_tokens(list(key.token_ids)) + chunk_hashes = [ + TokenHasher.hash_to_bytes(h) for h in session.get_hashes(key.start, key.end) + ] + if key.worker_id is None: + raise ValueError("Must resolve keys with worker_id != None") + return ipc_key_to_object_keys(key, chunk_hashes) diff --git a/lmcache/v1/multiprocess/engine_module.py b/lmcache/v1/multiprocess/engine_module.py new file mode 100644 index 00000000000..592d124adae --- /dev/null +++ b/lmcache/v1/multiprocess/engine_module.py @@ -0,0 +1,65 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Protocol and types for pluggable engine modules.""" + +# Future +from __future__ import annotations + +# Standard +from dataclasses import dataclass +from enum import Enum, auto +from typing import TYPE_CHECKING, Callable, Protocol + +# First Party +from lmcache.v1.multiprocess.protocol import RequestType + +if TYPE_CHECKING: + # First Party + from lmcache.v1.multiprocess.engine_context import MPCacheEngineContext + + +class ThreadPoolType(Enum): + """Declares which thread pool a handler should run in.""" + + SYNC = auto() + AFFINITY = auto() + NORMAL = auto() + + +@dataclass +class HandlerSpec: + """Specification for a single message queue handler. + + Args: + request_type: The ZMQ request type this handler serves. + handler: The callable that processes the request. + pool: Which thread pool the handler runs in. + """ + + request_type: RequestType + handler: Callable + pool: ThreadPoolType + + +class EngineModule(Protocol): + """Protocol for pluggable engine modules. + + Each module owns its internal state and exposes handlers + that the compositor registers with the message queue server. + """ + + @property + def context(self) -> MPCacheEngineContext: + """Return the shared engine context. Exposed for testing only.""" + ... + + def get_handlers(self) -> list[HandlerSpec]: + """Return handler specs for all request types this module serves.""" + ... 
+ + def report_status(self) -> dict: + """Return module-specific status information.""" + ... + + def close(self) -> None: + """Release resources owned by this module.""" + ... diff --git a/lmcache/v1/multiprocess/http_server.py b/lmcache/v1/multiprocess/http_server.py index c83540314ca..2e1f6d45f0d 100644 --- a/lmcache/v1/multiprocess/http_server.py +++ b/lmcache/v1/multiprocess/http_server.py @@ -35,6 +35,7 @@ from lmcache.v1.multiprocess.mp_runtime_plugin_launcher import ( MPRuntimePluginLauncher, ) +from lmcache.v1.multiprocess.server import run_cache_server logger = init_logger(__name__) @@ -61,12 +62,6 @@ async def lifespan(app: FastAPI): torch_dev.is_available(), ) mp_config = _configs["mp"] - if mp_config.engine_type == "blend": - # First Party - from lmcache.v1.multiprocess.blend_server_v2 import run_cache_server - else: - # First Party - from lmcache.v1.multiprocess.server import run_cache_server result = run_cache_server( mp_config=mp_config, diff --git a/lmcache/v1/multiprocess/modules/__init__.py b/lmcache/v1/multiprocess/modules/__init__.py new file mode 100644 index 00000000000..9881313609a --- /dev/null +++ b/lmcache/v1/multiprocess/modules/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: Apache-2.0 diff --git a/lmcache/v1/multiprocess/blend_server_v2.py b/lmcache/v1/multiprocess/modules/blend.py similarity index 68% rename from lmcache/v1/multiprocess/blend_server_v2.py rename to lmcache/v1/multiprocess/modules/blend.py index 1d355179a63..22b82419fac 100644 --- a/lmcache/v1/multiprocess/blend_server_v2.py +++ b/lmcache/v1/multiprocess/modules/blend.py @@ -1,41 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -""" -Overview --------- -This server enables KV cache reuse across requests that share token -sub-sequences at *arbitrary positions*, not only at a common prefix. - -Workflow (example: chunk_size = 3) ------------------------------------ -1. cb_store_pre_computed([1,2,3,4,5,6]) - Tokens are split into full chunks ([1,2,3] and [4,5,6]). 
Each chunk - is stored in the underlying storage under its normal rolling prefix - hash, and the chunk fingerprints are registered in - BlendTokenRangeMatcher for fast sub-sequence lookup. Because normal - hashes are used, these chunks are also accessible via the standard - lookup/retrieve path. - -2. cb_lookup_pre_computed([x,y,z, a,b,c, 4,5,6, m,n,p]) - BlendTokenRangeMatcher slides a rolling polynomial hash over the new - request's tokens and detects that the window at positions [6, 9) - matches the stored chunk [4,5,6]. A prefetch task is submitted for - that chunk using its stored hash as the storage key. Only chunks - confirmed present in storage are returned as CBMatchResult objects - (with cur_st/cur_ed pointing to their location in the new request). - -3. cb_retrieve_pre_computed(...) - The (prefetched) KV cache for each matched chunk is copied (CPU→GPU) - into the correct slot of the new request's KV cache buffer (at - cur_st + offset), so the LLM can skip recomputing those tokens. - -4. cb_store_final([x,y,z, a,b,c, 4,5,6, m,n,p]) - After inference completes on the new request, all its chunks are - stored under normal prefix hashes. Future requests sharing - any prefix of the new request will get standard prefix-cache hits. - Future requests sharing any prefix of the first request will also - get hits because cb_store_pre_computed already stored those chunks - under normal hashes. 
-""" +"""Blend (context-blend / cross-request KV reuse) module for MPCacheEngine.""" # Standard from typing import Any @@ -44,7 +8,6 @@ # Third Party import numpy as np -import zmq # First Party from lmcache import torch_dev, torch_device_type @@ -53,44 +16,26 @@ from lmcache.v1.distributed.api import ( MemoryLayoutDesc, ObjectKey, + PrefetchHandle, ipc_key_to_object_keys, ) -from lmcache.v1.distributed.config import ( - StorageManagerConfig, - parse_args_to_config, -) -from lmcache.v1.distributed.storage_manager import PrefetchHandle from lmcache.v1.gpu_connector.gpu_ops import ( lmcache_memcpy_async_d2h, lmcache_memcpy_async_h2d, ) -from lmcache.v1.mp_observability.config import ( - ObservabilityConfig, - init_observability, - parse_args_to_observability_config, -) from lmcache.v1.mp_observability.event import Event, EventType -from lmcache.v1.mp_observability.event_bus import get_event_bus -from lmcache.v1.mp_observability.trace import maybe_initialize_trace_recorder -from lmcache.v1.multiprocess.config import ( - MPServerConfig, - parse_args_to_mp_server_config, -) from lmcache.v1.multiprocess.custom_types import ( CBMatchResult, IPCCacheEngineKey, KVCache, ) -from lmcache.v1.multiprocess.gpu_context import ( - PlainGPUCacheContext, -) -from lmcache.v1.multiprocess.mq import MessageQueueServer -from lmcache.v1.multiprocess.protocol import ( - RequestType, - get_handler_type, - get_payload_classes, +from lmcache.v1.multiprocess.engine_context import MPCacheEngineContext +from lmcache.v1.multiprocess.engine_module import ( + HandlerSpec, + ThreadPoolType, ) -from lmcache.v1.multiprocess.server import MPCacheEngine, parse_args +from lmcache.v1.multiprocess.gpu_context import PlainGPUCacheContext +from lmcache.v1.multiprocess.protocols.base import RequestType from lmcache.v1.multiprocess.token_hasher import ( chunk_hash_windows_numba, rolling_hash_windows_numba, @@ -105,43 +50,46 @@ class BlendTokenRangeMatcher: """Fast token-range matcher using polynomial 
rolling/chunk hashes and a direct-address lookup table. - Table layout: poly_chunk_hash (u64) → compact_chunk_id (i64, sequential 0…N-1). + Table layout: poly_chunk_hash (u64) -> compact_chunk_id (i64, sequential 0...N-1). Because compact IDs are bounded by _TABLE_SIZE, unique_hits_direct_id_numba - can use a fixed `seen` array of _TABLE_SIZE bytes (~1 MB) rather than one - sized by an arbitrary max hash — no memory explosion. + can use a fixed ``seen`` array of _TABLE_SIZE bytes (~1 MB) rather than one + sized by an arbitrary max hash -- no memory explosion. Auxiliary storage: _chunk_token_hash[i] : token_hash for chunk i (None if evicted) - _token_hash_to_start : token_hash → start position in seq + _token_hash_to_start : token_hash -> start position in seq _compact_id_to_slot[i] : table slot for compact_id i - _token_hash_to_compact_id : token_hash → compact_chunk_id + _token_hash_to_compact_id : token_hash -> compact_chunk_id Methods: - on_new_token_hashes – register a sequence; builds fingerprints - and writes compact IDs. - match_sub_sequence – sliding-window probe → compact IDs → - token_hash → start. Skips evicted entries. - remove_chunks – lazily evict stale entries. Clears the - table slot and auxiliary maps. + on_new_token_hashes -- register a sequence; builds fingerprints + and writes compact IDs. + match_sub_sequence -- sliding-window probe -> compact IDs -> + token_hash -> start. Skips evicted entries. + remove_chunks -- lazily evict stale entries. Clears the + table slot and auxiliary maps. + + Args: + chunk_size: Number of tokens per chunk for fingerprint computation. 
""" - _TABLE_BITS: int = 20 # 2^20 ≈ 1 M entries + _TABLE_BITS: int = 20 # 2^20 ~ 1 M entries _TABLE_SIZE: int = 1 << _TABLE_BITS _BASE: np.uint64 = np.uint64(0x9E3779B97F4A7C15) # Fibonacci-hashing constant def __init__(self, chunk_size: int = 256): self.chunk_size = chunk_size - # poly_chunk_hash → compact_chunk_id; -1 = empty + # poly_chunk_hash -> compact_chunk_id; -1 = empty self._table_id = np.full(self._TABLE_SIZE, -1, dtype=np.int64) self._mask = np.uint64(self._TABLE_SIZE - 1) - # compact_chunk_id → caller-supplied token_hash (full bytes) + # compact_chunk_id -> caller-supplied token_hash (full bytes) self._chunk_token_hash: list[bytes | None] = [] - # token_hash → start position in its registered sequence + # token_hash -> start position in its registered sequence self._token_hash_to_start: dict[bytes, int] = {} - # compact_chunk_id → table slot index (for reverse lookup during eviction) + # compact_chunk_id -> table slot index (for reverse lookup during eviction) self._compact_id_to_slot = np.full(self._TABLE_SIZE, -1, dtype=np.int64) - # token_hash → compact_chunk_id (for eviction lookup) + # token_hash -> compact_chunk_id (for eviction lookup) self._token_hash_to_compact_id: dict[bytes, int] = {} self._lock = threading.Lock() @@ -149,7 +97,7 @@ def on_new_token_hashes( self, token_ids: list[int], token_hashes: list[bytes], - ): + ) -> None: """Register a new token sequence and index its non-overlapping chunks. 
Args: @@ -207,10 +155,10 @@ def on_new_token_hashes( ) compact_ids = np.arange(base_id, base_id + n_new, dtype=np.int64) - # Write table: poly_chunk_hash → compact_chunk_id + # Write table: poly_chunk_hash -> compact_chunk_id update_table_id_numba(new_chunk_hashes, self._table_id, compact_ids) - # Persist compact_id → token_hash, token_hash → start, and reverse maps + # Persist compact_id -> token_hash, token_hash -> start, and reverse maps for k, orig_i in enumerate(new_idxs): th = token_hashes[orig_i] cid = int(compact_ids[k]) @@ -363,28 +311,108 @@ def _unique_token_coverage(results: list[CBMatchResult]) -> int: return coverage -# Main class and main functions -class BlendEngineV2(MPCacheEngine): - def __init__( - self, - storage_manager_config: StorageManagerConfig, - chunk_size: int = 256, - hash_algorithm: str = "blake3", - ): - super().__init__( - storage_manager_config, chunk_size, hash_algorithm=hash_algorithm - ) +class BlendModule: + """Handles blend (context-blend / cross-request KV reuse) operations. - self._cb_gpu_contexts: dict[int, PlainGPUCacheContext] = {} + Owns CB-specific GPU context registrations and the token range matcher. + Provides handlers for CB register, unregister, store, retrieve, and lookup. + + Args: + ctx: The shared engine context. + """ - # CB GPU ID -> (model name, world size) as metadata - # NOTE: This is mainly for determining the layout desc during prefetch + def __init__(self, ctx: MPCacheEngineContext) -> None: + self._ctx = ctx + self._cb_gpu_contexts: dict[int, PlainGPUCacheContext] = {} self._cb_gpu_context_meta: dict[int, tuple[str, int]] = {} + self._token_range_matcher = BlendTokenRangeMatcher(ctx.chunk_size) + self._gpu_copy_lock = threading.Lock() + + @property + def context(self) -> MPCacheEngineContext: + """Return the shared engine context. Exposed for testing only.""" + return self._ctx + + def get_handlers(self) -> list[HandlerSpec]: + """Return handler specs for all request types this module serves. 
+ + Returns: + A list of HandlerSpec entries mapping request types to + their handler callables and thread pool assignments. + """ + return [ + HandlerSpec( + RequestType.CB_REGISTER_KV_CACHE, + self.cb_register_kv_cache, + ThreadPoolType.SYNC, + ), + HandlerSpec( + RequestType.CB_UNREGISTER_KV_CACHE, + self.cb_unregister_kv_cache, + ThreadPoolType.SYNC, + ), + HandlerSpec( + RequestType.CB_STORE_PRE_COMPUTED, + self.cb_store_pre_computed, + ThreadPoolType.AFFINITY, + ), + HandlerSpec( + RequestType.CB_RETRIEVE_PRE_COMPUTED_V2, + self.cb_retrieve_pre_computed, + ThreadPoolType.AFFINITY, + ), + HandlerSpec( + RequestType.CB_STORE_FINAL, + self.cb_store_final, + ThreadPoolType.AFFINITY, + ), + HandlerSpec( + RequestType.CB_LOOKUP_PRE_COMPUTED_V2, + self.cb_lookup_pre_computed, + ThreadPoolType.NORMAL, + ), + ] + + def report_status(self) -> dict: + """Return blend module status information. + + Returns: + A dict containing registered CB GPU instance IDs and + per-instance KV cache layout metadata. + """ + cb_gpu_context_meta: dict[str, dict] = {} + for gpu_id, meta in self._cb_gpu_context_meta.items(): + model_name, world_size = meta + entry: dict = { + "model_name": model_name, + "world_size": world_size, + } + ctx = self._cb_gpu_contexts.get(gpu_id) + if ctx is not None: + # bytes per token = 2 (K+V) * num_layers * hidden_dim_size * + # itemsize; num_tokens is the cache capacity, not a per-token + # cost. 
+ cache_size_per_token = ( + 2 * ctx.num_layers * ctx.hidden_dim_size * ctx.dtype.itemsize + ) + entry["kv_cache_layout"] = { + "num_layers": ctx.num_layers, + "num_tokens": ctx.num_tokens, + "hidden_dim_size": ctx.hidden_dim_size, + "dtype": str(ctx.dtype), + "cache_size_per_token": cache_size_per_token, + } + cb_gpu_context_meta[str(gpu_id)] = entry - # Fast local matcher: indexes pre-computed chunk hashes for sub-sequence lookup - self._token_range_matcher = BlendTokenRangeMatcher(chunk_size) + return { + "registered_cb_gpu_ids": list(self._cb_gpu_contexts.keys()), + "cb_gpu_context_meta": cb_gpu_context_meta, + } - self._event_bus = get_event_bus() + def close(self) -> None: + """Release resources owned by this module.""" + self._cb_gpu_contexts.clear() + self._cb_gpu_context_meta.clear() def cb_register_kv_cache( self, @@ -393,18 +421,24 @@ def cb_register_kv_cache( model_name: str, world_size: int, ) -> None: - """ - Register the KV cache buffer from the blend engine + """Register the KV cache buffer from the blend engine. Args: - instance_id: Unique identifier for the blend engine instance - kv_caches: KVCache object containing the GPU buffer pointers + instance_id: Unique identifier for the blend engine instance. + kv_caches: KVCache object containing the GPU buffer pointers. model_name: The name of the model associated with this KV cache. world_size: The world size associated with this KV cache. 
""" - gpu_context = PlainGPUCacheContext(kv_caches, self.chunk_size) + gpu_context = PlainGPUCacheContext(kv_caches, self._ctx.chunk_size) self._cb_gpu_contexts[instance_id] = gpu_context self._cb_gpu_context_meta[instance_id] = (model_name, world_size) + + layout_desc = MemoryLayoutDesc( + shapes=[gpu_context.get_kv_buffer_shape(self._ctx.chunk_size)], + dtypes=[gpu_context.dtype], + ) + self._ctx.layout_desc_registry.register(model_name, world_size, layout_desc) + logger.info( "Registered CB KV cache for instance_id %d with %d layers", instance_id, @@ -412,11 +446,11 @@ def cb_register_kv_cache( ) def cb_unregister_kv_cache(self, instance_id: int) -> None: - """ - Unregister the KV cache buffer for the given instance_id + """Unregister the KV cache buffer for the given instance_id. Args: - instance_id: Unique identifier for the blend engine instance to unregister + instance_id: Unique identifier for the blend engine instance + to unregister. """ if instance_id in self._cb_gpu_contexts: del self._cb_gpu_contexts[instance_id] @@ -428,46 +462,8 @@ def cb_unregister_kv_cache(self, instance_id: int) -> None: instance_id, ) - def report_status(self) -> dict: - """Return a status dict for the entire cache engine. - - Extends the base dict with ``registered_cb_gpu_ids`` and - ``cb_gpu_context_meta`` so CB-only deployments are distinguishable - from "no engine connected" to ``/api/status`` clients. - """ - status = super().report_status() - - cb_gpu_context_meta: dict[str, dict] = {} - for gpu_id, meta in self._cb_gpu_context_meta.items(): - model_name, world_size = meta - entry: dict = { - "model_name": model_name, - "world_size": world_size, - } - ctx = self._cb_gpu_contexts.get(gpu_id) - if ctx is not None: - # bytes per token = 2 (K+V) * num_layers * hidden_dim_size * - # itemsize; num_tokens is the cache capacity, not a per-token - # cost. 
- cache_size_per_token = ( - 2 * ctx.num_layers * ctx.hidden_dim_size * ctx.dtype.itemsize - ) - entry["kv_cache_layout"] = { - "num_layers": ctx.num_layers, - "num_tokens": ctx.num_tokens, - "hidden_dim_size": ctx.hidden_dim_size, - "dtype": str(ctx.dtype), - "cache_size_per_token": cache_size_per_token, - } - cb_gpu_context_meta[str(gpu_id)] = entry - - status["registered_cb_gpu_ids"] = list(self._cb_gpu_contexts.keys()) - status["cb_gpu_context_meta"] = cb_gpu_context_meta - return status - def cb_lookup_pre_computed(self, key: IPCCacheEngineKey) -> list[CBMatchResult]: - """ - Lookup the pre-computed chunks in the underlying storage. + """Lookup the pre-computed chunks in the underlying storage. Uses BlendTokenRangeMatcher for a fast local pre-filter, then submits prefetch tasks for matched chunks using their stored hashes directly. @@ -475,20 +471,20 @@ def cb_lookup_pre_computed(self, key: IPCCacheEngineKey) -> list[CBMatchResult]: storage are lazily evicted from the matcher via remove_chunks. Args: - key: IPCCacheEngineKey containing the token ids to lookup + key: IPCCacheEngineKey containing the token ids to lookup. Returns: List of CBMatchResult for chunks that were actually found in storage, ready to be passed to cb_retrieve_pre_computed. """ num_tokens = len(key.token_ids) - self._event_bus.publish( + self._ctx.event_bus.publish( Event( event_type=EventType.CB_REQUEST_START, session_id=key.request_id, ) ) - self._event_bus.publish( + self._ctx.event_bus.publish( Event( event_type=EventType.CB_LOOKUP_START, session_id=key.request_id, @@ -496,18 +492,11 @@ def cb_lookup_pre_computed(self, key: IPCCacheEngineKey) -> list[CBMatchResult]: ) ) - # Sub-sequence fingerprint match: find CB-stored chunks anywhere in the - # query using polynomial rolling hashes (context-independent, so a chunk - # stored at position 0 is found even when it appears at CHUNK_SIZE in the - # query). 
Only chunks registered via cb_store_pre_computed / cb_store_final - # are in the fingerprint table, which gives CB isolation for free — chunks - # stored via the normal STORE path are never registered and therefore - # never returned here. cb_match_result = self._token_range_matcher.match_sub_sequence( list(key.token_ids) ) if not cb_match_result: - self._event_bus.publish( + self._ctx.event_bus.publish( Event( event_type=EventType.CB_LOOKUP_END, session_id=key.request_id, @@ -519,12 +508,12 @@ def cb_lookup_pre_computed(self, key: IPCCacheEngineKey) -> list[CBMatchResult]: "stale_chunks": 0, "no_gpu_context": False, "hit_tokens": 0, - "requested_tokens": (num_tokens // self.chunk_size) - * self.chunk_size, + "requested_tokens": (num_tokens // self._ctx.chunk_size) + * self._ctx.chunk_size, }, ) ) - self._event_bus.publish( + self._ctx.event_bus.publish( Event( event_type=EventType.CB_REQUEST_END, session_id=key.request_id, @@ -564,7 +553,7 @@ def cb_lookup_pre_computed(self, key: IPCCacheEngineKey) -> list[CBMatchResult]: if m_name == model_name and w_size == world_size: cb_ctx = self._cb_gpu_contexts[gpu_id] layout_desc = MemoryLayoutDesc( - shapes=[cb_ctx.get_kv_buffer_shape(self.chunk_size)], + shapes=[cb_ctx.get_kv_buffer_shape(self._ctx.chunk_size)], dtypes=[cb_ctx.dtype], ) break @@ -576,7 +565,7 @@ def cb_lookup_pre_computed(self, key: IPCCacheEngineKey) -> list[CBMatchResult]: model_name, world_size, ) - self._event_bus.publish( + self._ctx.event_bus.publish( Event( event_type=EventType.CB_LOOKUP_END, session_id=key.request_id, @@ -588,12 +577,12 @@ def cb_lookup_pre_computed(self, key: IPCCacheEngineKey) -> list[CBMatchResult]: "stale_chunks": 0, "no_gpu_context": True, "hit_tokens": 0, - "requested_tokens": (num_tokens // self.chunk_size) - * self.chunk_size, + "requested_tokens": (num_tokens // self._ctx.chunk_size) + * self._ctx.chunk_size, }, ) ) - self._event_bus.publish( + self._ctx.event_bus.publish( Event( event_type=EventType.CB_REQUEST_END, 
session_id=key.request_id, @@ -607,7 +596,7 @@ def cb_lookup_pre_computed(self, key: IPCCacheEngineKey) -> list[CBMatchResult]: for group in groups: chunk_hashes = [r.hash for r in group] obj_keys = ipc_key_to_object_keys(key, chunk_hashes) - handle = self.storage_manager.submit_prefetch_task( + handle = self._ctx.storage_manager.submit_prefetch_task( obj_keys, layout_desc, external_request_id=key.request_id, @@ -620,17 +609,14 @@ def cb_lookup_pre_computed(self, key: IPCCacheEngineKey) -> list[CBMatchResult]: group[0].cur_st, ) - # TODO(Jiayi): We need to follow how lookup is handled in server.py - # to optimize performance. # Collect only the CBMatchResults for chunks actually found in storage stale_hashes: list[bytes] = [] for handle, group in zip(prefetch_handles, groups, strict=False): found_count = None while True: - found_count = self.storage_manager.query_prefetch_status(handle) + found_count = self._ctx.storage_manager.query_prefetch_status(handle) if found_count is not None: break - time.sleep(0.001) # Real found count after dedup the TP @@ -664,14 +650,14 @@ def cb_lookup_pre_computed(self, key: IPCCacheEngineKey) -> list[CBMatchResult]: "Evicted %d stale chunks from fingerprint table", len(stale_hashes), ) - self._event_bus.publish( + self._ctx.event_bus.publish( Event( event_type=EventType.CB_CHUNKS_EVICTED, metadata={"num_chunks": len(stale_hashes)}, ) ) - self._event_bus.publish( + self._ctx.event_bus.publish( Event( event_type=EventType.CB_LOOKUP_END, session_id=key.request_id, @@ -683,8 +669,8 @@ def cb_lookup_pre_computed(self, key: IPCCacheEngineKey) -> list[CBMatchResult]: "stale_chunks": len(stale_hashes), "no_gpu_context": False, "hit_tokens": _unique_token_coverage(found_cb_match_result), - "requested_tokens": (num_tokens // self.chunk_size) - * self.chunk_size, + "requested_tokens": (num_tokens // self._ctx.chunk_size) + * self._ctx.chunk_size, }, ) ) @@ -698,8 +684,7 @@ def _cb_store_gpu_copy( event_ipc_handle: bytes, start_event: Event | 
None = None, ) -> tuple[Any, dict]: - """ - Helper function to perform GPU-to-CPU copy operations for storing chunks. + """Helper function to perform GPU-to-CPU copy operations for storing chunks. Args: obj_keys: List of object keys to store. @@ -736,16 +721,18 @@ def _cb_store_gpu_copy( vllm_event.wait(stream=gpu_context.stream) if start_event is not None: - self._event_bus.publish_on_stream(gpu_context.cupy_stream, start_event) + self._ctx.event_bus.publish_on_stream( + gpu_context.cupy_stream, start_event + ) # Prepare for the copy - num_tokens = self.chunk_size + num_tokens = self._ctx.chunk_size cpu_shape = gpu_context.get_kv_buffer_shape(num_tokens) layout_desc = MemoryLayoutDesc( shapes=[cpu_shape], dtypes=[gpu_context.dtype] ) - reserved_dict = self.storage_manager.reserve_write( + reserved_dict = self._ctx.storage_manager.reserve_write( obj_keys, layout_desc, "new" ) @@ -755,22 +742,22 @@ def _cb_store_gpu_copy( else: continue - offset_start = idx * self.chunk_size + offset - offset_end = offset_start + self.chunk_size + offset_start = idx * self._ctx.chunk_size + offset + offset_end = offset_start + self._ctx.chunk_size # Copy from GPU to CPU tmp_buffer = gpu_context.get_tmp_gpu_buffer(offset_end - offset_start) gpu_kv_slice = gpu_context.slice_kv_cache_on_tokens( offset_start, offset_end ) - with self.lock: + with self._gpu_copy_lock: tmp_buffer.copy_(gpu_kv_slice, non_blocking=True) lmcache_memcpy_async_d2h(tmp_buffer, memory_obj) event.record() # Call finish_write after the copy is done gpu_context.cupy_stream.launch_host_func( - self.storage_manager.finish_write, + self._ctx.storage_manager.finish_write, list(reserved_dict.keys()), ) @@ -783,46 +770,49 @@ def cb_store_pre_computed( instance_id: int, event_ipc_handle: bytes, ) -> tuple[bytes, bool]: - """ - Store the pre-computed chunks in the underlying storage for later retrieval. + """Store the pre-computed chunks in the underlying storage for later retrieval. 
Args: - key: IPCCacheEngineKey containing the token ids for which the pre-computed - chunks are stored. + key: IPCCacheEngineKey containing the token ids for which the + pre-computed chunks are stored. offset: The starting offset in the CB KV cache buffer where the - pre-computed - instance_id: The instance_id of the blend engine instance to store the - pre-computed chunks for. + pre-computed chunks begin. + instance_id: The instance_id of the blend engine instance to store + the pre-computed chunks for. event_ipc_handle: The IPC handle for the CUDA event that signals the completion of LLM inference. Returns: - IPC handle bytes for the event that signals the completion of storing the - pre-computed chunks, and a boolean flag indicating if the store is - successful. + IPC handle bytes for the event that signals the completion of storing + the pre-computed chunks, and a boolean flag indicating if the store + is successful. + + Raises: + ValueError: If instance_id is not registered for CB KV cache. Note: - The input tokens should not have any separator in it. It should just be - one "paragraph". - This function will discard the last partial chunk and only store the full - chunks + The input tokens should not have any separator in it. It should just + be one "paragraph". + This function will discard the last partial chunk and only store the + full chunks. """ num_tokens = key.end - key.start - assert instance_id in self._cb_gpu_contexts, ( - f"Instance ID {instance_id} not registered for CB KV cache" - ) + if instance_id not in self._cb_gpu_contexts: + raise ValueError( + f"Instance ID {instance_id} not registered for CB KV cache" + ) gpu_context = self._cb_gpu_contexts[instance_id] # CPU-synchronous sentinel: GPU store is about to be enqueued. 
- self._event_bus.publish( + self._ctx.event_bus.publish( Event( event_type=EventType.CB_STORE_PRE_COMPUTED_SUBMITTED, session_id=key.request_id, metadata={"instance_id": instance_id}, ) ) - self._event_bus.publish_on_stream( + self._ctx.event_bus.publish_on_stream( gpu_context.cupy_stream, Event( event_type=EventType.CB_STORE_PRE_COMPUTED_START, @@ -833,7 +823,7 @@ def cb_store_pre_computed( # Compute normal prefix hashes so these chunks are accessible both via # the CB lookup path and via the standard lookup/retrieve path. - chunk_hashes = self.token_hasher.compute_chunk_hashes(list(key.token_ids)) + chunk_hashes = self._ctx.token_hasher.compute_chunk_hashes(list(key.token_ids)) # convert to object key obj_keys = ipc_key_to_object_keys(key, chunk_hashes) @@ -852,7 +842,7 @@ def cb_store_pre_computed( self._token_range_matcher.on_new_token_hashes( list(key.token_ids), token_hashes ) - self._event_bus.publish( + self._ctx.event_bus.publish( Event( event_type=EventType.CB_FINGERPRINTS_REGISTERED, session_id=key.request_id, @@ -870,7 +860,7 @@ def cb_store_pre_computed( ) except Exception: logger.exception("Cannot store pre-computed chunks due to exception") - self._event_bus.publish_on_stream( + self._ctx.event_bus.publish_on_stream( gpu_context.cupy_stream, Event( event_type=EventType.CB_STORE_PRE_COMPUTED_END, @@ -885,7 +875,7 @@ def cb_store_pre_computed( ) raise - self._event_bus.publish_on_stream( + self._ctx.event_bus.publish_on_stream( gpu_context.cupy_stream, Event( event_type=EventType.CB_STORE_PRE_COMPUTED_END, @@ -908,33 +898,37 @@ def cb_retrieve_pre_computed( instance_id: int, event_ipc_handle: bytes, ) -> tuple[bytes, bool]: - """ - Retrieve the pre-computed chunks from the underlying storage and copy them to - the CB KV cache buffer. + """Retrieve pre-computed chunks from storage and copy them to the CB KV buffer. Args: - key: IPCCacheEngineKey containing the token ids for which the pre-computed - chunks are retrieved. 
- cb_match_result: List of CBMatchResult returned by cb_lookup_pre_computed, - containing the per-chunk hashes and query positions. - offset: The starting offset in the CB KV cache buffer to copy the retrieved - chunks to. - instance_id: The instance_id of the blend engine instance to retrieve the - pre-computed chunks for. - event_ipc_handle: The IPC handle for the CUDA event that signals the - completion of LLM inference. + key: IPCCacheEngineKey containing the token ids for which the + pre-computed chunks are retrieved. + cb_match_result: List of CBMatchResult returned by + cb_lookup_pre_computed, containing the per-chunk hashes and + query positions. + offset: The starting offset in the CB KV cache buffer to copy the + retrieved chunks to. + instance_id: The instance_id of the blend engine instance to + retrieve the pre-computed chunks for. + event_ipc_handle: The IPC handle for the CUDA event that signals + the completion of LLM inference. Returns: - IPC handle bytes for the event that signals the completion of retrieving the - pre-computed chunks, and a boolean flag indicating if the retrieval is - successful. + IPC handle bytes for the event that signals the completion of + retrieving the pre-computed chunks, and a boolean flag indicating + if the retrieval is successful. + + Raises: + ValueError: If instance_id is not registered for CB KV cache. Note: - We must call `cb_lookup_pre_computed` first before calling this function + cb_lookup_pre_computed must be called first before calling this + function. 
""" - assert instance_id in self._cb_gpu_contexts, ( - f"Instance ID {instance_id} not registered for CB KV cache" - ) + if instance_id not in self._cb_gpu_contexts: + raise ValueError( + f"Instance ID {instance_id} not registered for CB KV cache" + ) gpu_context = self._cb_gpu_contexts[instance_id] # One obj_key per match_result, in cur_st order @@ -944,7 +938,7 @@ def cb_retrieve_pre_computed( all_obj_keys = ipc_key_to_object_keys(key, chunk_hashes) # CPU-synchronous sentinel: GPU retrieve is about to be enqueued. - self._event_bus.publish( + self._ctx.event_bus.publish( Event( event_type=EventType.CB_RETRIEVE_SUBMITTED, session_id=key.request_id, @@ -962,22 +956,25 @@ def cb_retrieve_pre_computed( check_interprocess_event_support() event = torch_dev.Event(interprocess=True) - self._event_bus.publish_on_stream( + self._ctx.event_bus.publish_on_stream( gpu_context.cupy_stream, Event( event_type=EventType.CB_RETRIEVE_START, session_id=key.request_id, - metadata={"instance_id": instance_id, "num_chunks": num_chunks}, + metadata={ + "instance_id": instance_id, + "num_chunks": num_chunks, + }, ), ) try: - with self.storage_manager.read_prefetched_results( + with self._ctx.storage_manager.read_prefetched_results( all_obj_keys ) as memory_objs: if memory_objs is None: logger.error("Some keys not found during CB retrieve!") - self._event_bus.publish_on_stream( + self._ctx.event_bus.publish_on_stream( gpu_context.cupy_stream, Event( event_type=EventType.CB_RETRIEVE_END, @@ -995,18 +992,20 @@ def cb_retrieve_pre_computed( cb_match_result, memory_objs, strict=False ): gpu_st = r.cur_st + offset - gpu_ed = gpu_st + self.chunk_size - tmp_buffer = gpu_context.get_tmp_gpu_buffer(self.chunk_size) + gpu_ed = gpu_st + self._ctx.chunk_size + tmp_buffer = gpu_context.get_tmp_gpu_buffer( + self._ctx.chunk_size + ) target_buffer = gpu_context.slice_kv_cache_on_tokens( gpu_st, gpu_ed ) - with self.lock: + with self._gpu_copy_lock: lmcache_memcpy_async_h2d(memory_obj, tmp_buffer) 
target_buffer.copy_(tmp_buffer, non_blocking=True) except Exception: logger.exception("Error during retrieving prefetched results") - self._event_bus.publish_on_stream( + self._ctx.event_bus.publish_on_stream( gpu_context.cupy_stream, Event( event_type=EventType.CB_RETRIEVE_END, @@ -1027,7 +1026,8 @@ def cb_retrieve_pre_computed( # We should consider not unlocking objects in read_prefetched_results # if error happens. gpu_context.cupy_stream.launch_host_func( - self.storage_manager.finish_read_prefetched, all_obj_keys + self._ctx.storage_manager.finish_read_prefetched, + all_obj_keys, ) logger.info( @@ -1035,7 +1035,7 @@ def cb_retrieve_pre_computed( len(cb_match_result), offset, ) - self._event_bus.publish_on_stream( + self._ctx.event_bus.publish_on_stream( gpu_context.cupy_stream, Event( event_type=EventType.CB_RETRIEVE_END, @@ -1056,43 +1056,48 @@ def cb_store_final( instance_id: int, event_ipc_handle: bytes, ) -> tuple[bytes, bool]: - """ - Store the final chunks in the underlying storage after processing. The stored - chunk should be accessible for normal mode LLMs. + """Store the final chunks in the underlying storage after processing. + + The stored chunks should be accessible for normal mode LLMs. Args: - key: IPCCacheEngineKey containing the token ids for which the final chunks - are stored. - offset: The starting offset in the CB KV cache buffer where the final - chunks are stored. - instance_id: The instance_id of the blend engine instance to store the final - chunks for. - event_ipc_handle: The IPC handle for the CUDA event that signals the - completion of LLM inference. + key: IPCCacheEngineKey containing the token ids for which the + final chunks are stored. + offset: The starting offset in the CB KV cache buffer where the + final chunks are stored. + instance_id: The instance_id of the blend engine instance to + store the final chunks for. + event_ipc_handle: The IPC handle for the CUDA event that signals + the completion of LLM inference. 
Returns: - IPC handle bytes for the event that signals the completion of storing the - final chunks, and a boolean flag indicating if the store is successful. + IPC handle bytes for the event that signals the completion of + storing the final chunks, and a boolean flag indicating if the + store is successful. + + Raises: + ValueError: If instance_id is not registered for CB KV cache. """ num_tokens = key.end - key.start # Get GPU context - assert instance_id in self._cb_gpu_contexts, ( - f"Instance ID {instance_id} not registered for CB KV cache" - ) + if instance_id not in self._cb_gpu_contexts: + raise ValueError( + f"Instance ID {instance_id} not registered for CB KV cache" + ) gpu_context = self._cb_gpu_contexts[instance_id] # CPU-synchronous sentinels: SUBMITTED before SESSION_END so the # tracing subscriber's in-flight counter is non-zero when SESSION_END # arrives and correctly defers root span closure. - self._event_bus.publish( + self._ctx.event_bus.publish( Event( event_type=EventType.CB_STORE_FINAL_SUBMITTED, session_id=key.request_id, metadata={"instance_id": instance_id}, ) ) - self._event_bus.publish( + self._ctx.event_bus.publish( Event( event_type=EventType.CB_REQUEST_END, session_id=key.request_id, @@ -1100,7 +1105,7 @@ def cb_store_final( ) # Compute normal hash for the keys - chunk_hashes = self.token_hasher.compute_chunk_hashes(list(key.token_ids)) + chunk_hashes = self._ctx.token_hasher.compute_chunk_hashes(list(key.token_ids)) # convert to object key obj_keys = ipc_key_to_object_keys(key, chunk_hashes) @@ -1115,7 +1120,10 @@ def cb_store_final( start_event=Event( event_type=EventType.CB_STORE_FINAL_START, session_id=key.request_id, - metadata={"instance_id": instance_id, "num_tokens": num_tokens}, + metadata={ + "instance_id": instance_id, + "num_tokens": num_tokens, + }, ), ) @@ -1126,7 +1134,7 @@ def cb_store_final( self._token_range_matcher.on_new_token_hashes( list(key.token_ids), list(chunk_hashes) ) - self._event_bus.publish( + 
self._ctx.event_bus.publish( Event( event_type=EventType.CB_FINGERPRINTS_REGISTERED, session_id=key.request_id, @@ -1144,7 +1152,7 @@ def cb_store_final( ) except Exception: logger.exception("Cannot store final chunks due to exception") - self._event_bus.publish_on_stream( + self._ctx.event_bus.publish_on_stream( gpu_context.cupy_stream, Event( event_type=EventType.CB_STORE_FINAL_END, @@ -1159,7 +1167,7 @@ def cb_store_final( ) raise - self._event_bus.publish_on_stream( + self._ctx.event_bus.publish_on_stream( gpu_context.cupy_stream, Event( event_type=EventType.CB_STORE_FINAL_END, @@ -1173,172 +1181,3 @@ def cb_store_final( ), ) return event.ipc_handle(), True - - -def add_handler_helper( - server: MessageQueueServer, request_type: RequestType, handler_function -): - payload_classes = get_payload_classes(request_type) - handler_type = get_handler_type(request_type) - server.add_handler( - request_type, - payload_classes, - handler_type, - handler_function, - ) - - -def run_cache_server( - mp_config: MPServerConfig, - storage_manager_config: StorageManagerConfig, - obs_config: ObservabilityConfig, - return_engine: bool = False, - start_prometheus_http_server: bool = True, -): - """ - Run the LMCache cache server with ZMQ message queue. - - Args: - mp_config: Configuration for the ZMQ multiprocess server - storage_manager_config: Configuration for the storage manager - obs_config: Configuration for the observability stack - return_engine: If True, return (server, engine) after starting; - if False, run blocking loop to keep server alive - start_prometheus_http_server: Whether to start a standalone - Prometheus HTTP server in a background thread. Set to - ``False`` when an external HTTP framework already serves - ``/metrics`` to avoid port conflicts or redundant servers. 
- - Returns: - If return_engine is True: tuple of (MessageQueueServer, BlendEngineV2) - If return_engine is False: None (blocks until interrupted) - """ - event_bus = init_observability( - obs_config, start_prometheus_http_server=start_prometheus_http_server - ) - - # Wire up the trace recorder (no-op when --trace-level is unset). - maybe_initialize_trace_recorder(event_bus, obs_config, storage_manager_config) - - # Initialize the engine (loggers self-register with the global controller) - engine = BlendEngineV2( - storage_manager_config=storage_manager_config, - chunk_size=mp_config.chunk_size, - hash_algorithm=mp_config.hash_algorithm, - ) - - # Initialize the message queue server - context = zmq.Context.instance() - server = MessageQueueServer( - bind_url=f"tcp://{mp_config.host}:{mp_config.port}", - context=context, - ) - - # Add handlers for original server - add_handler_helper(server, RequestType.REGISTER_KV_CACHE, engine.register_kv_cache) - add_handler_helper( - server, RequestType.UNREGISTER_KV_CACHE, engine.unregister_kv_cache - ) - add_handler_helper(server, RequestType.STORE, engine.store) - add_handler_helper(server, RequestType.LOOKUP, engine.lookup) - add_handler_helper( - server, RequestType.QUERY_PREFETCH_STATUS, engine.query_prefetch_status - ) - add_handler_helper(server, RequestType.FREE_LOOKUP_LOCKS, engine.free_lookup_locks) - add_handler_helper(server, RequestType.RETRIEVE, engine.retrieve) - add_handler_helper(server, RequestType.CLEAR, engine.clear) - add_handler_helper(server, RequestType.GET_CHUNK_SIZE, engine.get_chunk_size) - add_handler_helper(server, RequestType.END_SESSION, engine.end_session) - add_handler_helper(server, RequestType.NOOP, engine.debug) - add_handler_helper( - server, - RequestType.REPORT_BLOCK_ALLOCATION, - engine.report_block_allocations, - ) - # Add handler for blend operations - add_handler_helper( - server, RequestType.CB_REGISTER_KV_CACHE, engine.cb_register_kv_cache - ) - add_handler_helper( - server, 
RequestType.CB_UNREGISTER_KV_CACHE, engine.cb_unregister_kv_cache - ) - add_handler_helper( - server, RequestType.CB_LOOKUP_PRE_COMPUTED_V2, engine.cb_lookup_pre_computed - ) - add_handler_helper( - server, RequestType.CB_STORE_PRE_COMPUTED, engine.cb_store_pre_computed - ) - add_handler_helper( - server, RequestType.CB_RETRIEVE_PRE_COMPUTED_V2, engine.cb_retrieve_pre_computed - ) - add_handler_helper(server, RequestType.CB_STORE_FINAL, engine.cb_store_final) - add_handler_helper(server, RequestType.PING, engine.ping) - - # Assign thread pools - server.add_affinity_thread_pool( - [ - RequestType.STORE, - RequestType.RETRIEVE, - RequestType.CB_STORE_PRE_COMPUTED, - RequestType.CB_RETRIEVE_PRE_COMPUTED_V2, - RequestType.CB_STORE_FINAL, - ], - max_workers=mp_config.max_gpu_workers, - ) - server.add_normal_thread_pool( - [ - RequestType.LOOKUP, - RequestType.QUERY_PREFETCH_STATUS, - RequestType.FREE_LOOKUP_LOCKS, - RequestType.END_SESSION, - RequestType.CLEAR, - RequestType.CB_LOOKUP_PRE_COMPUTED_V2, - RequestType.PING, - RequestType.REPORT_BLOCK_ALLOCATION, - ], - max_workers=mp_config.max_cpu_workers, - ) - - logger.info( - "LMCache ZMQ cache server is running on tcp://%s:%d", - mp_config.host, - mp_config.port, - ) - # Start the ZMQ server - # Not all backends expose init(); some auto-initialize on first use - if not hasattr(torch_dev, "init"): - logger.warning( - "Backend '%s' does not support init(), skipping device init", - torch_device_type, - ) - else: - torch_dev.init() - server.start() - - logger.info("LMCache cache blend v2 server is running...") - - # Return server and engine if requested (for HTTP server integration) - if return_engine: - return server, engine - - # Dummy loop to keep the server running - try: - while True: - time.sleep(1) - except KeyboardInterrupt: - logger.info("Shutting down server...") - event_bus.stop() - server.close() - engine.close() - - -if __name__ == "__main__": - args = parse_args() - mp_config = 
parse_args_to_mp_server_config(args) - storage_manager_config = parse_args_to_config(args) - obs_config = parse_args_to_observability_config(args) - run_cache_server( - mp_config=mp_config, - storage_manager_config=storage_manager_config, - obs_config=obs_config, - ) diff --git a/lmcache/v1/multiprocess/modules/gpu_transfer.py b/lmcache/v1/multiprocess/modules/gpu_transfer.py new file mode 100644 index 00000000000..536cafb20c5 --- /dev/null +++ b/lmcache/v1/multiprocess/modules/gpu_transfer.py @@ -0,0 +1,670 @@ +# SPDX-License-Identifier: Apache-2.0 +"""GPU-based KV cache transfer operations for the MPCacheEngine.""" + +# Standard +from dataclasses import dataclass +from itertools import islice +from typing import Generator +import time + +# First Party +from lmcache import torch_dev, torch_device_type +from lmcache.logging import init_logger +from lmcache.utils import ( + EngineType, + _lmcache_nvtx_annotate, + check_interprocess_event_support, +) +from lmcache.v1.distributed.api import ( + MemoryLayoutDesc, + ObjectKey, +) +from lmcache.v1.gpu_connector.gpu_ops import ( + lmcache_memcpy_async_d2h, + lmcache_memcpy_async_h2d, +) +from lmcache.v1.gpu_connector.utils import LayoutHints +from lmcache.v1.memory_management import MemoryObj +from lmcache.v1.mp_observability.event import Event, EventType +from lmcache.v1.multiprocess.custom_types import ( + IPCCacheEngineKey, + KVCache, +) +from lmcache.v1.multiprocess.engine_context import MPCacheEngineContext +from lmcache.v1.multiprocess.engine_module import ( + HandlerSpec, + ThreadPoolType, +) +from lmcache.v1.multiprocess.gpu_context import GPUCacheContext +from lmcache.v1.multiprocess.native_completion import ( + DeviceHostFuncDispatcher, + submit_callback_to_stream, +) +from lmcache.v1.multiprocess.protocols.base import RequestType +import lmcache.c_ops as lmc_ops + +logger = init_logger(__name__) + + +def get_layout_desc(gpu_context: GPUCacheContext, num_tokens: int) -> MemoryLayoutDesc: + """Get the memory 
def batched_iteration(lst: list, batch_size: int) -> Generator[tuple, None, None]:
    """Yield consecutive fixed-size slices of ``lst`` as tuples.

    Every yielded tuple holds ``batch_size`` items except possibly the
    last one, which holds whatever remains. Nothing is yielded for an
    empty input.

    Args:
        lst: The items to split into batches.
        batch_size: Maximum number of items per batch.

    Yields:
        Tuples of at most ``batch_size`` items, in input order.

    Raises:
        ValueError: If batch_size is less than 1. Because this is a
            generator, the error surfaces on the first iteration step.
    """
    if batch_size < 1:
        raise ValueError("batch size must be at least one")
    source = iter(lst)
    while True:
        # islice consumes the shared iterator, so each pass picks up
        # exactly where the previous batch stopped.
        batch = tuple(islice(source, batch_size))
        if not batch:
            return
        yield batch
def report_status(self) -> dict:
    """Summarize the GPU contexts currently registered with this module.

    Returns:
        A dict with two entries: ``registered_gpu_ids`` lists the
        registered instance IDs, and ``gpu_context_meta`` maps each
        stringified instance ID to its model name, world size, and a
        ``kv_cache_layout`` description read from the GPU cache context.
    """

    def _describe_layout(kv_ctx) -> dict:
        # Snapshot of the layout attributes exposed by one GPU context.
        groups = kv_ctx.kv_layer_groups_manager
        return {
            "num_layers": kv_ctx.num_layers,
            "inference_engine_logical_block_size": (
                groups.inference_engine_logical_block_size
            ),
            "group_physical_block_sizes": kv_ctx.group_physical_block_sizes,
            "group_compress_ratios": kv_ctx.group_compress_ratios,
            # Stringified so the value is plain text in the report.
            "hidden_dim_sizes": str(kv_ctx.hidden_dim_sizes),
            "dtype": str(kv_ctx.dtype),
            "is_mla": kv_ctx.is_mla,
            "num_blocks": kv_ctx.num_blocks,
            "gpu_kv_format": kv_ctx.gpu_kv_format_name,
            "gpu_kv_shape": kv_ctx.gpu_kv_shape,
            "gpu_kv_concrete_shape": kv_ctx.concrete_gpu_kv_shape,
            "attention_backend": kv_ctx.attention_backend,
            "cache_size_per_token": kv_ctx.cache_size_per_token(),
        }

    instance_ids: list[int] = list(self._gpu_contexts)
    per_instance: dict[str, dict] = {
        str(instance_id): {
            "model_name": registration.model_name,
            "world_size": registration.world_size,
            "kv_cache_layout": _describe_layout(registration.gpu_context),
        }
        for instance_id, registration in self._gpu_contexts.items()
    }
    return {
        "registered_gpu_ids": instance_ids,
        "gpu_context_meta": per_instance,
    }
@_lmcache_nvtx_annotate
def store(
    self,
    key: IPCCacheEngineKey,
    instance_id: int,
    gpu_block_ids: list[int],
    event_ipc_handle: bytes,
) -> tuple[bytes, bool]:
    """Store the GPU KV cache blocks to CPU.

    Args:
        key: The IPC key for the KV cache blocks.
            Must have worker_id != None (worker store operation).
        instance_id: The GPU instance ID (such as PID).
        gpu_block_ids: The GPU block IDs to store.
        event_ipc_handle: The IPC handle of the event to wait on.

    Returns:
        A tuple where the first element is the IPC handle of the event
        that signals the completion of the store operation, and the second
        element indicates whether the store operation was successful. The
        flag is ``False`` when reserving or copying raised an exception
        (the exception is logged, not propagated).

    Raises:
        ValueError: If no GPU context is registered for the given instance ID.
        RuntimeError: If the backend does not support IPC event handles.
    """
    st = time.perf_counter()
    obj_keys = self._ctx.resolve_obj_keys(key)

    entry = self._gpu_contexts.get(instance_id)
    if entry is None:
        raise ValueError(f"No GPU context registered for instance ID {instance_id}")
    gpu_context = entry.gpu_context
    model_name = entry.model_name

    # ``blocks_per_chunk`` is counted in inference-engine-side
    # blocks (each block addresses
    # ``inference_engine_logical_block_size`` *logical* tokens).
    # For compressed groups the per-group physical slot count
    # differs, but the block-id indexing is shared with the engine
    # and therefore uses the engine logical block size here.
    blocks_per_chunk = (
        self._ctx.chunk_size
        // gpu_context.kv_layer_groups_manager.inference_engine_logical_block_size
    )

    # Flipped to False if the reserve/copy section raises, so the caller
    # is not told that a failed store succeeded.
    store_succeeded = True

    with (
        torch_dev.device(gpu_context.device),
        torch_dev.stream(gpu_context.stream),
    ):
        check_interprocess_event_support()
        event = torch_dev.Event(interprocess=True)

        all_block_ids_gpu = gpu_context.stage_block_ids(gpu_block_ids)

        if not hasattr(torch_dev.Event, "from_ipc_handle"):
            raise RuntimeError(
                f"Backend '{torch_device_type}' does not support IPC event "
                "handles (Event.from_ipc_handle not available). "
                "Multiprocess IPC requires CUDA."
            )
        vllm_event = torch_dev.Event.from_ipc_handle(
            gpu_context.device, event_ipc_handle
        )
        vllm_event.wait(stream=gpu_context.stream)

        # CPU-synchronous sentinel: a GPU store is about to be enqueued.
        # Must be published via publish() (not publish_on_stream) so the
        # drain thread sees it before MP_REQUEST_END can race MP_STORE_END.
        self._ctx.event_bus.publish(
            Event(
                event_type=EventType.MP_STORE_SUBMITTED,
                session_id=key.request_id,
                metadata={"device": str(gpu_context.device)},
            )
        )

        self._ctx.event_bus.publish_on_stream(
            gpu_context.cupy_stream,
            Event(
                event_type=EventType.MP_STORE_START,
                session_id=key.request_id,
                metadata={
                    "device": str(gpu_context.device),
                    "engine_id": instance_id,
                    "model_name": model_name,
                },
            ),
        )

        reserved_dict: dict[ObjectKey, MemoryObj] = {}
        try:
            layout_desc = get_layout_desc(gpu_context, self._ctx.chunk_size)
            reserved_dict = self._ctx.storage_manager.reserve_write(
                obj_keys, layout_desc, "new"
            )

            # NOTE: Store is not batched because some obj_keys may be
            # skipped (not in reserved_dict), making block_ids
            # non-contiguous. Batching would require torch.cat to
            # reassemble block_ids, negating the benefit.
            num_groups = gpu_context.kv_layer_groups_manager.num_groups
            for idx, obj_key in enumerate(obj_keys):
                memory_obj = reserved_dict.get(obj_key)
                if memory_obj is None:
                    # Key was not reserved for writing; nothing to copy.
                    continue

                chunk_block_ids_gpu = all_block_ids_gpu[
                    idx * blocks_per_chunk : (idx + 1) * blocks_per_chunk
                ]

                # Copy from GPU paged buffer to tmp buffer, then to CPU — per group
                for group_idx in range(num_groups):
                    tmp_buffer = gpu_context.get_tmp_chunk_gpu_buffer(group_idx)
                    group_kv_pointers = gpu_context.get_group_kv_pointers(group_idx)
                    # Kernel contract: ``group_lmcache_chunk_size`` here is the
                    # number of *physical* slots per chunk for this group
                    # (= logical chunk_size // compress_ratio).
                    group_lmcache_chunk_size = gpu_context.get_physical_chunk_size(
                        group_idx
                    )
                    lmc_ops.multi_layer_block_kv_transfer(
                        group_kv_pointers,
                        [tmp_buffer.data_ptr()],
                        chunk_block_ids_gpu,
                        gpu_context.device,
                        lmc_ops.TransferDirection.D2H,
                        gpu_context.get_shape_desc(group_idx),
                        group_lmcache_chunk_size,
                        gpu_context.gpu_kv_format_,
                        0,
                    )
                # Store is not batched, so we always use chunk_idx=0 (single slot)
                lmcache_memcpy_async_d2h(
                    gpu_context.get_tmp_gpu_buffer_flat(chunk_idx=0), memory_obj
                )
        except Exception:
            logger.exception("Cannot store keys due to exception")
            store_succeeded = False
        finally:
            # Always record the completion event and release reservations,
            # even on failure, so the caller and storage manager never wait
            # on work that will not arrive.
            event.record()
            if reserved_dict:
                submit_callback_to_stream(
                    gpu_context.cupy_stream,
                    "finish_write",
                    list(reserved_dict.keys()),
                )
            # All reserved MemoryObjs share one layout_desc, so per-object
            # size is identical — avoid summing N identical values.
            total_bytes = (
                next(iter(reserved_dict.values())).get_size() * len(reserved_dict)
                if reserved_dict
                else 0
            )
            self._ctx.event_bus.publish_on_stream(
                gpu_context.cupy_stream,
                Event(
                    event_type=EventType.MP_STORE_END,
                    session_id=key.request_id,
                    metadata={
                        "stored_count": len(reserved_dict),
                        "device": str(gpu_context.device),
                        "engine_id": instance_id,
                        "model_name": model_name,
                        "total_bytes": total_bytes,
                    },
                ),
            )

    ed = time.perf_counter()
    # Only report stored tokens when the copy section completed cleanly.
    if store_succeeded and (length := len(reserved_dict)):
        logger.info(
            "Stored %d tokens in %.3f seconds",
            length * self._ctx.chunk_size,
            ed - st,
        )
    return event.ipc_handle(), store_succeeded
+ gpu_block_ids: The GPU block IDs to retrieve into. + event_ipc_handle: The IPC handle of the event to wait on. + skip_first_n_tokens: Number of tokens to skip writing at + the start of the retrieve range. This avoids overwriting + APC-shared GPU blocks that may be read concurrently by other + requests. + + Returns: + A tuple where the first element is the IPC handle of the event + that signals the completion of the retrieve operation, and the + second element indicates whether the key was successfully retrieved. + + Raises: + ValueError: If no GPU context is registered for the given instance ID. + """ + st = time.perf_counter() + obj_keys = self._ctx.resolve_obj_keys(key) + + entry = self._gpu_contexts.get(instance_id) + if entry is None: + raise ValueError(f"No GPU context registered for instance ID {instance_id}") + gpu_context = entry.gpu_context + model_name = entry.model_name + + # CPU-synchronous sentinel: a GPU retrieve is about to be enqueued. + # Must be published via publish() (not publish_on_stream) so the + # drain thread sees it before MP_REQUEST_END can race MP_RETRIEVE_END. + self._ctx.event_bus.publish( + Event( + event_type=EventType.MP_RETRIEVE_SUBMITTED, + session_id=key.request_id, + metadata={"device": str(gpu_context.device)}, + ) + ) + + self._ctx.event_bus.publish_on_stream( + gpu_context.cupy_stream, + Event( + event_type=EventType.MP_RETRIEVE_START, + session_id=key.request_id, + metadata={ + "device": str(gpu_context.device), + "engine_id": instance_id, + "model_name": model_name, + }, + ), + ) + + # ``skip_*_in_chunk`` is expressed in engine-block units + # (logical tokens), which is what the kernel's + # ``skip_blocks_in_chunk`` argument expects regardless + # of per-group compression. 
+ ie_logical_block_size = ( + gpu_context.kv_layer_groups_manager.inference_engine_logical_block_size + ) + blocks_per_chunk = self._ctx.chunk_size // ie_logical_block_size + + def _retrieve_loop(keys: list[ObjectKey], memory_objs: list[MemoryObj]) -> None: + _BATCH_SIZE = gpu_context.max_batch_size + num_groups = gpu_context.kv_layer_groups_manager.num_groups + for batch_idx, memory_obj_batch in enumerate( + batched_iteration(memory_objs, batch_size=_BATCH_SIZE) + ): + batch_len = len(memory_obj_batch) + chunk_start = batch_idx * self._ctx.chunk_size * _BATCH_SIZE + chunk_end = chunk_start + self._ctx.chunk_size * batch_len + + effective_start = max(chunk_start, skip_first_n_tokens) + if effective_start >= chunk_end: + # Entire batch is within APC range, skip it + continue + + skip_tokens_in_chunk = max( + 0, + min( + effective_start - chunk_start, + self._ctx.chunk_size * batch_len - 1, + ), + ) + if skip_tokens_in_chunk % ie_logical_block_size != 0: + logger.error( + "skip_first_n_tokens (%d) is not aligned to " + "inference_engine_logical_block_size (%d), " + "rounding down from %d tokens to %d blocks", + skip_first_n_tokens, + ie_logical_block_size, + skip_tokens_in_chunk, + skip_tokens_in_chunk // ie_logical_block_size, + ) + skip_blocks_in_chunk = skip_tokens_in_chunk // ie_logical_block_size + + start_chunk_id = batch_idx * _BATCH_SIZE + end_chunk_id = start_chunk_id + batch_len + chunk_block_ids_gpu = all_block_ids_gpu[ + start_chunk_id * blocks_per_chunk : end_chunk_id * blocks_per_chunk + ] + + # Copy from CPU to GPU tmp buffers, then scatter to paged KV — per group + # H2D copy: each memory_obj maps to its own batch slot + for chunk_idx, memory_obj in enumerate(memory_obj_batch): + lmcache_memcpy_async_h2d( + memory_obj, + gpu_context.get_tmp_gpu_buffer_flat(chunk_idx=chunk_idx), + ) + for group_idx in range(num_groups): + tmp_buffers = gpu_context.get_tmp_chunk_gpu_buffer_batched( + batch_len, group_idx + ) + group_kv_pointers = 
gpu_context.get_group_kv_pointers(group_idx) + group_lmcache_chunk_size = gpu_context.get_physical_chunk_size( + group_idx + ) + + lmc_ops.multi_layer_block_kv_transfer( + group_kv_pointers, + [tb.data_ptr() for tb in tmp_buffers], + chunk_block_ids_gpu, + gpu_context.device, + lmc_ops.TransferDirection.H2D, + gpu_context.get_shape_desc(group_idx), + group_lmcache_chunk_size, + gpu_context.gpu_kv_format_, + skip_blocks_in_chunk, + ) + + with ( + torch_dev.device(gpu_context.device), + torch_dev.stream(gpu_context.stream), + ): + # Stage all block_ids to GPU once before the loop + all_block_ids_gpu = gpu_context.stage_block_ids(gpu_block_ids) + + check_interprocess_event_support() + event = torch_dev.Event(interprocess=True) + + prefetched_keys: list[ObjectKey] = [] + retrieve_succeeded = False + total_bytes = 0 + try: + with self._ctx.storage_manager.read_prefetched_results( + obj_keys + ) as memory_objs: + if not memory_objs or len(memory_objs) != len(obj_keys): + logger.error("Some keys not found during retrieve!") + return event.ipc_handle(), False + + prefetched_keys = obj_keys[: len(memory_objs)] + total_bytes = sum(mo.get_size() for mo in memory_objs) + _retrieve_loop(obj_keys, memory_objs) + # Only set True when with-block exits normally + retrieve_succeeded = True + except Exception: + logger.exception("Cannot retrieve keys due to exception") + return event.ipc_handle(), False + finally: + event.record() + if retrieve_succeeded: + submit_callback_to_stream( + gpu_context.cupy_stream, + "finish_read_prefetched", + prefetched_keys, + ) + self._ctx.event_bus.publish_on_stream( + gpu_context.cupy_stream, + Event( + event_type=EventType.MP_RETRIEVE_END, + session_id=key.request_id, + metadata={ + "retrieved_count": len(prefetched_keys), + "device": str(gpu_context.device), + "engine_id": instance_id, + "model_name": model_name, + "cache_salt": key.cache_salt, + "total_bytes": total_bytes, + }, + ), + ) + tokens_retrieved = len(obj_keys) * self._ctx.chunk_size + 
def compute_extra_count(
    tp_size: int,
    world_size: int,
) -> int:
    """Return how many additional readers share each object (MLA locking).

    Non-MLA: every TP worker owns its own KV shard, so each ObjectKey is
    retrieved by exactly one worker and no extra count is needed.

    MLA: TP does not split the KV cache, so all TP workers read the same
    object. vLLM sends a world_size that is already divided by tp_size
    (e.g. world_size=1 for TP=4 PP=1), so only one ObjectKey exists per
    chunk and the remaining ``tp_size - 1`` workers are extra readers.

    MLA is detected as ``tp > world_size``. Older vLLM (<= 0.8.5) does not
    send tp_size (it defaults to 1); world_size is used instead, which
    yields 0 (safe, but may under-lock for MLA).

    TODO: world_size currently carries an overloaded meaning (total ranks
    for non-MLA vs total/tp for MLA). Consider a dedicated field in the
    future.

    Args:
        tp_size: Tensor-parallel size from the client.
        world_size: World size from the cache key.

    Returns:
        The extra reader count; 0 for non-MLA.
    """
    # A tp_size of 1 or less means "not provided": fall back to world_size.
    effective_tp = world_size if tp_size <= 1 else tp_size
    if effective_tp <= world_size:
        return 0
    return effective_tp - 1
+ """ + + def __init__(self, ctx: MPCacheEngineContext) -> None: + self._ctx = ctx + self._prefetch_jobs: dict[str, _PrefetchJob] = {} + self._prefetch_job_lock = threading.Lock() + self._setup_metrics() + + @property + def context(self) -> MPCacheEngineContext: + """Return the shared engine context. Exposed for testing only.""" + return self._ctx + + def get_handlers(self) -> list[HandlerSpec]: + """Return handler specs for all request types this module serves. + + Returns: + List of handler specs for lookup-related request types. + """ + return [ + HandlerSpec(RequestType.LOOKUP, self.lookup, ThreadPoolType.NORMAL), + HandlerSpec( + RequestType.QUERY_PREFETCH_STATUS, + self.query_prefetch_status, + ThreadPoolType.NORMAL, + ), + HandlerSpec( + RequestType.QUERY_PREFETCH_LOOKUP_HITS, + self.query_prefetch_lookup_hits, + ThreadPoolType.NORMAL, + ), + HandlerSpec( + RequestType.FREE_LOOKUP_LOCKS, + self.free_lookup_locks, + ThreadPoolType.NORMAL, + ), + HandlerSpec( + RequestType.END_SESSION, + self.end_session, + ThreadPoolType.NORMAL, + ), + ] + + def report_status(self) -> dict[str, int]: + """Return module-specific status information. + + Returns: + Dictionary with the count of active prefetch jobs. + """ + return { + "active_prefetch_jobs": self._active_prefetch_count(), + } + + def close(self) -> None: + """Release resources owned by this module (no-op).""" + pass + + # ----------------------------------------------------------------- + # Handlers + # ----------------------------------------------------------------- + + def lookup( + self, + key: IPCCacheEngineKey, + tp_size: int, + ) -> None: + """Submit a prefix lookup. + + Hashes the key, submits a prefetch task to the storage manager, + and registers the job under ``key.request_id`` for later polling + via query_prefetch_status. + + Args: + key: Cache key with request_id embedded. + tp_size: Tensor-parallel size for MLA multi-reader locking. 
+ """ + model_name, world_size = key.model_name, key.world_size + self._ctx.event_bus.publish( + Event( + event_type=EventType.MP_REQUEST_START, + session_id=key.request_id, + ) + ) + self._ctx.event_bus.publish( + Event( + event_type=EventType.MP_LOOKUP_PREFETCH_START, + session_id=key.request_id, + ) + ) + + layout_desc = self._ctx.layout_desc_registry.find(model_name, world_size) + if layout_desc is None: + logger.error( + "No GPU context found for model %s with world size %d during lookup!", + model_name, + world_size, + ) + self._register_prefetch_job( + _PrefetchJob( + handle=PrefetchHandle( + prefetch_request_id=-1, + external_request_id=key.request_id, + l1_prefix_hit_count=0, + total_requested_keys=0, + submit_time=time.monotonic(), + ), + world_size=1, + request_id=key.request_id, + requested_tokens=0, + model_name=model_name, + cache_salt=key.cache_salt, + ) + ) + return + + extra_count = compute_extra_count(tp_size, world_size) + + chunk_hashes = self._ctx.token_hasher.compute_chunk_hashes(list(key.token_ids)) + if not chunk_hashes: + self._register_prefetch_job( + _PrefetchJob( + handle=PrefetchHandle( + prefetch_request_id=-1, + external_request_id=key.request_id, + l1_prefix_hit_count=0, + total_requested_keys=0, + submit_time=time.monotonic(), + ), + world_size=1, + request_id=key.request_id, + requested_tokens=0, + model_name=model_name, + cache_salt=key.cache_salt, + ) + ) + return + + # Total chunk-aligned tokens submitted for lookup; surfaces as the + # denominator of the L1+L2 token-level hit-rate via the + # ``requested_tokens`` field on ``MP_LOOKUP_PREFETCH_END``. Sub-chunk + # trailing tokens are intentionally excluded — they cannot hit at + # chunk granularity. + requested_tokens = len(chunk_hashes) * self._ctx.chunk_size + + # Guard with has_subscribers() to avoid allocating the metadata dict + # (including dtype/shape list comprehensions) when no subscriber is + # listening (e.g. lookup hash logger is disabled). 
+ if self._ctx.event_bus.has_subscribers(EventType.MP_LOOKUP): + self._ctx.event_bus.publish( + Event( + event_type=EventType.MP_LOOKUP, + session_id=key.request_id, + metadata={ + "request_id": key.request_id, + "chunk_hashes": chunk_hashes, + "model_name": model_name, + "chunk_size": self._ctx.chunk_size, + "seq_len": len(key.token_ids), + "dtypes": [str(d) for d in layout_desc.dtypes], + "shapes": [list(s) for s in layout_desc.shapes], + }, + ) + ) + + session = self._ctx.session_manager.get_or_create(key.request_id) + session.set_tokens(list(key.token_ids)) + session.lookup_ipc_key = key + + obj_keys = ipc_key_to_object_keys(key, chunk_hashes) + + handle = self._ctx.storage_manager.submit_prefetch_task( + obj_keys, + layout_desc, + extra_count=extra_count, + external_request_id=key.request_id, + ) + self._register_prefetch_job( + _PrefetchJob( + handle=handle, + world_size=key.world_size, + request_id=key.request_id, + requested_tokens=requested_tokens, + model_name=model_name, + cache_salt=key.cache_salt, + ) + ) + + def query_prefetch_lookup_hits( + self, + request_id: str, + ) -> int | None: + """Query the number of hits for a prefetch request before it's finished. + + Args: + request_id: The external request ID passed in the lookup key. + + Returns: + The number of hits for the prefetched keys if the lookup phase is + done. None if the lookup phase is still in progress. 0 if the + request_id is unknown (already completed and consumed, or invalid). 
+ """ + with self._prefetch_job_lock: + job = self._prefetch_jobs.get(request_id) + + if job is None: + logger.warning( + "Prefetch job for request %s not found (already completed or invalid)", + request_id, + ) + return 0 + + found_count = self._ctx.storage_manager.query_prefetch_lookup_hits(job.handle) + if found_count is None: + return None + + found_count = found_count // job.world_size + return found_count + + def query_prefetch_status( + self, + request_id: str, + ) -> int | None: + """Poll the status of a prefetch job by request_id. + + Returns the chunk count when the prefetch is complete, or None + if it is still in progress. The job entry is automatically + removed once a non-None result is returned (exactly-once + semantics). + + Args: + request_id: The external request ID passed in the lookup key. + + Returns: + Chunk count (int) when done, None if still in progress, + 0 if the request_id is unknown (already completed and consumed, + or invalid). + """ + with self._prefetch_job_lock: + job = self._prefetch_jobs.get(request_id) + if job is None: + logger.warning( + "Prefetch job for request %s not found (already completed or invalid)", + request_id, + ) + return 0 + + found_count = self._ctx.storage_manager.query_prefetch_status(job.handle) + if found_count is None: + return None + + # NOTE(Kuntai): this assumes two things: + # 1. the world size is the same between keys + # 2. 
the lookup sort the keys in prefix order and breaks at the + # first failure + found_count = found_count // job.world_size + + self._ctx.event_bus.publish( + Event( + event_type=EventType.MP_LOOKUP_PREFETCH_END, + session_id=job.request_id, + metadata={ + "found_count": found_count, + "requested_tokens": job.requested_tokens, + "hit_tokens": found_count * self._ctx.chunk_size, + "model_name": job.model_name, + "cache_salt": job.cache_salt, + }, + ) + ) + + with self._prefetch_job_lock: + self._prefetch_jobs.pop(request_id, None) + + return found_count + + def free_lookup_locks( + self, + key: IPCCacheEngineKey, + tp_size: int, + ) -> None: + """Release read locks acquired during lookup. + + Hashes are computed only for chunks in ``[start, end)`` to avoid + unnecessary work on tokens outside that range. + ``start`` and ``end`` must be aligned to ``chunk_size``; it is the + caller's responsibility to align the boundaries as desired. + + Computes the extra reader count from ``tp_size`` and + ``world_size`` the same way :meth:`lookup` does, so + the correct number of locks is released. + + Args: + key: Cache key whose read locks should be released. + tp_size: Tensor-parallel size for MLA + multi-reader locking. + """ + chunk_hashes = self._ctx.token_hasher.compute_chunk_hashes( + list(key.token_ids), start=key.start, end=key.end + ) + if not chunk_hashes: + return + obj_keys = ipc_key_to_object_keys(key, chunk_hashes) + + extra_count = compute_extra_count(tp_size, key.world_size) + + self._ctx.storage_manager.finish_read_prefetched( + obj_keys, extra_count=extra_count + ) + + def end_session(self, request_id: str) -> None: + """Remove the session for a finished request. + + Args: + request_id: The request ID whose session should be removed. 
+ """ + self._ctx.event_bus.publish( + Event( + event_type=EventType.MP_VLLM_END_SESSION, + metadata={"request_id": request_id}, + ) + ) + session = self._ctx.session_manager.remove(request_id) + self._ctx.event_bus.publish( + Event( + event_type=EventType.MP_REQUEST_END, + session_id=request_id, + ) + ) + if session is None: + logger.warning("Session %s not found, skipping touch", request_id) + return + if session.lookup_ipc_key is None: + logger.warning( + "Session %s has no lookup ipc key, skipping touch", + request_id, + ) + return + + chunk_hashes = [TokenHasher.hash_to_bytes(h) for h in session.get_hashes(0)] + obj_keys = ipc_key_to_object_keys(session.lookup_ipc_key, chunk_hashes) + # unified touch of all keys, which include retrieved and stored keys + # TODO(chunxiaozheng): when l2 is enabled, the prefetched keys from l2 are temp + # and will be deleted after finish_read_prefetched, when we touch all keys, + # these keys has been deleted and will not be touched. + self._ctx.storage_manager.touch_l1_keys(obj_keys) + + # ----------------------------------------------------------------- + # Internal helpers + # ----------------------------------------------------------------- + + def _register_prefetch_job(self, job: _PrefetchJob) -> None: + with self._prefetch_job_lock: + self._prefetch_jobs[job.request_id] = job + + def _active_prefetch_count(self) -> int: + """Return the number of active prefetch jobs (thread-safe).""" + with self._prefetch_job_lock: + return len(self._prefetch_jobs) + + def _setup_metrics(self) -> None: + """Register OTel observable gauges for lookup module metrics.""" + _gauge = partial(register_gauge, "lmcache.mp_engine") + _gauge( + "lmcache_mp.active_prefetch_jobs", + "Number of active prefetch jobs", + self._active_prefetch_count, + ) diff --git a/lmcache/v1/multiprocess/modules/management.py b/lmcache/v1/multiprocess/modules/management.py new file mode 100644 index 00000000000..4cdfcb2d164 --- /dev/null +++ 
b/lmcache/v1/multiprocess/modules/management.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Management and utility operations for the MPCacheEngine.""" + +# Standard +import threading + +# First Party +from lmcache.logging import init_logger +from lmcache.v1.mp_observability.event import Event, EventType +from lmcache.v1.multiprocess.custom_types import BlockAllocationRecord +from lmcache.v1.multiprocess.engine_context import MPCacheEngineContext +from lmcache.v1.multiprocess.engine_module import ( + HandlerSpec, + ThreadPoolType, +) +from lmcache.v1.multiprocess.protocols.base import RequestType + +logger = init_logger(__name__) + + +class ManagementModule: + """Handles management and utility operations for the cache engine. + + Owns the lock used during cache clearing and provides handlers for + ping, chunk-size queries, clear, debug, and block-allocation reporting. + + Args: + ctx: The shared engine context. + """ + + def __init__(self, ctx: MPCacheEngineContext) -> None: + self._ctx = ctx + self._clear_lock = threading.Lock() + + @property + def context(self) -> MPCacheEngineContext: + """Return the shared engine context. Exposed for testing only.""" + return self._ctx + + def get_handlers(self) -> list[HandlerSpec]: + """Return handler specs for all request types this module serves. + + Returns: + A list of HandlerSpec entries mapping request types to + their handler callables and thread pool assignments. + """ + return [ + HandlerSpec(RequestType.CLEAR, self.clear, ThreadPoolType.NORMAL), + HandlerSpec( + RequestType.GET_CHUNK_SIZE, + self.get_chunk_size, + ThreadPoolType.SYNC, + ), + HandlerSpec(RequestType.PING, self.ping, ThreadPoolType.NORMAL), + HandlerSpec(RequestType.NOOP, self.debug, ThreadPoolType.SYNC), + HandlerSpec( + RequestType.REPORT_BLOCK_ALLOCATION, + self.report_block_allocations, + ThreadPoolType.NORMAL, + ), + ] + + def report_status(self) -> dict: + """Return module-specific status information. 
+ + Returns: + An empty dict; management has no module-level metrics. + """ + return {} + + def close(self) -> None: + """Release resources owned by this module.""" + pass + + def ping(self) -> bool: + """Respond to a ping request. + + Returns: + Always True. + """ + return True + + def get_chunk_size(self) -> int: + """Return the chunk size used for KV cache operations. + + Returns: + The chunk size. + """ + return self._ctx.chunk_size + + def clear(self) -> None: + """Clear all stored KV cache data from the storage manager.""" + with self._clear_lock: + self._ctx.storage_manager.memcheck() + self._ctx.storage_manager.clear(force=True) + self._ctx.storage_manager.memcheck() + + def debug(self) -> str: + """Return a simple health-check string. + + Returns: + The literal string ``"OK"``. + """ + return "OK" + + def report_block_allocations( + self, + instance_id: int, + model_name: str, + records: list[BlockAllocationRecord], + ) -> None: + """Publish vLLM block allocation records to the EventBus. + + Args: + instance_id: The scheduler instance ID. + model_name: The model name from the adapter. + records: List of BlockAllocationRecord with per-request + block and token allocation deltas. 
+ """ + self._ctx.event_bus.publish( + Event( + event_type=EventType.MP_VLLM_BLOCK_ALLOCATION, + metadata={ + "instance_id": instance_id, + "model_name": model_name, + "records": records, + }, + ) + ) diff --git a/lmcache/v1/multiprocess/modules/non_gpu_transfer.py b/lmcache/v1/multiprocess/modules/non_gpu_transfer.py new file mode 100644 index 00000000000..f92291b230f --- /dev/null +++ b/lmcache/v1/multiprocess/modules/non_gpu_transfer.py @@ -0,0 +1,347 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Non-GPU (pickle-based) KV cache transfer operations for the MPCacheEngine.""" + +# Standard +from dataclasses import dataclass +import pickle + +# Third Party +import torch + +# First Party +from lmcache.logging import init_logger +from lmcache.utils import _lmcache_nvtx_annotate +from lmcache.v1.distributed.api import ( + MemoryLayoutDesc, + ObjectKey, +) +from lmcache.v1.multiprocess.custom_types import ( + IPCCacheEngineKey, + RegisterNonGpuContextPayload, +) +from lmcache.v1.multiprocess.engine_context import MPCacheEngineContext +from lmcache.v1.multiprocess.engine_module import ( + HandlerSpec, + ThreadPoolType, +) +from lmcache.v1.multiprocess.non_gpu_context import NonGpuContextMetadata +from lmcache.v1.multiprocess.protocols.base import RequestType +from lmcache.v1.multiprocess.protocols.engine import ( + PrepareRetrieveResponse, + PrepareStoreResponse, +) + +logger = init_logger(__name__) + + +@dataclass +class NonGPUContextEntry: + """Registered non-GPU context metadata for a single worker instance. + + Attributes: + metadata: Layout metadata describing the non-CUDA chunk format. + model_name: The name of the model associated with this context. + world_size: The world size associated with this context. + """ + + metadata: NonGpuContextMetadata + model_name: str + world_size: int + + +class NonGPUTransferModule: + """Handles non-GPU (pickle-based) KV cache transfer operations. 
+ + Owns non-GPU context registrations and provides handlers for + register, unregister, prepare/commit store, and prepare/commit retrieve + of CPU-serialized KV caches. + + Args: + ctx: The shared engine context. + """ + + def __init__(self, ctx: MPCacheEngineContext) -> None: + self._ctx = ctx + self._non_gpu_contexts: dict[int, NonGPUContextEntry] = {} + + @property + def context(self) -> MPCacheEngineContext: + """Return the shared engine context. Exposed for testing only.""" + return self._ctx + + def get_handlers(self) -> list[HandlerSpec]: + """Return handler specs for all request types this module serves. + + Returns: + A list of HandlerSpec entries mapping request types to + their handler callables and thread pool assignments. + """ + return [ + HandlerSpec( + RequestType.REGISTER_KV_CACHE_NON_GPU_CONTEXT, + self.register_kv_cache_non_gpu_context, + ThreadPoolType.SYNC, + ), + HandlerSpec( + RequestType.UNREGISTER_KV_CACHE, + self.unregister_kv_cache, + ThreadPoolType.SYNC, + ), + HandlerSpec( + RequestType.PREPARE_STORE, + self.prepare_store, + ThreadPoolType.AFFINITY, + ), + HandlerSpec( + RequestType.COMMIT_STORE, + self.commit_store, + ThreadPoolType.AFFINITY, + ), + HandlerSpec( + RequestType.PREPARE_RETRIEVE, + self.prepare_retrieve, + ThreadPoolType.AFFINITY, + ), + HandlerSpec( + RequestType.COMMIT_RETRIEVE, + self.commit_retrieve, + ThreadPoolType.AFFINITY, + ), + ] + + def report_status(self) -> dict: + """Return non-GPU transfer module status information. + + Returns: + A dict containing registered non-CUDA instance IDs and + per-instance context metadata. 
+ """ + registered_non_cuda_ids: list[int] = [] + non_cuda_context_meta: dict[str, dict] = {} + + for instance_id, entry in self._non_gpu_contexts.items(): + registered_non_cuda_ids.append(instance_id) + non_cuda_context_meta[str(instance_id)] = { + "model_name": entry.model_name, + "world_size": entry.world_size, + "block_size": entry.metadata.block_size, + "use_mla": entry.metadata.use_mla, + } + + return { + "registered_non_cuda_instance_ids": registered_non_cuda_ids, + "non_cuda_context_meta": non_cuda_context_meta, + } + + def close(self) -> None: + """Release resources owned by this module.""" + self._non_gpu_contexts.clear() + + def register_kv_cache_non_gpu_context( + self, + payload: RegisterNonGpuContextPayload, + ) -> None: + """Register non-CUDA KV layout metadata for non-GPU context mode. + + Args: + payload: Struct containing all registration fields + (instance_id, model_name, world_size, block_size, + num_layers, hidden_dim_size, dtype_str, use_mla). + + Raises: + ValueError: If ``payload.dtype_str`` is not a valid torch dtype name. + """ + if payload.instance_id in self._non_gpu_contexts: + logger.warning( + "Instance %s's KV cache is already registered, " + "skipping the new registration", + payload.instance_id, + ) + return + + dtype = getattr(torch, payload.dtype_str, None) + if dtype is None or not isinstance(dtype, torch.dtype): + raise ValueError( + f"Invalid dtype_str '{payload.dtype_str}': must be a valid torch dtype " + "attribute name (e.g. 'float16' for torch.float16, " + "'bfloat16' for torch.bfloat16, 'float32' for torch.float32)." 
+ ) + + shape = ( + torch.Size( + [payload.num_layers, self._ctx.chunk_size, payload.hidden_dim_size] + ) + if payload.use_mla + else torch.Size( + [2, payload.num_layers, self._ctx.chunk_size, payload.hidden_dim_size] + ) + ) + layout_desc = MemoryLayoutDesc(shapes=[shape], dtypes=[dtype]) + metadata = NonGpuContextMetadata( + layout_desc=layout_desc, + block_size=payload.block_size, + use_mla=payload.use_mla, + ) + self._non_gpu_contexts[payload.instance_id] = NonGPUContextEntry( + metadata=metadata, + model_name=payload.model_name, + world_size=payload.world_size, + ) + + self._ctx.layout_desc_registry.register( + payload.model_name, payload.world_size, layout_desc + ) + + def unregister_kv_cache(self, instance_id: int) -> None: + """Unregister a non-GPU KV cache context for the given instance ID. + + Args: + instance_id: The worker instance identifier. + """ + entry = self._non_gpu_contexts.pop(instance_id, None) + if entry is None: + logger.warning( + "No registered non-GPU context found for instance ID %d", + instance_id, + ) + return + + self._ctx.layout_desc_registry.unregister(entry.model_name, entry.world_size) + logger.info("Unregistered non-CUDA context for instance ID %d", instance_id) + + @_lmcache_nvtx_annotate + def prepare_store( + self, + key: IPCCacheEngineKey, + instance_id: int, + ) -> PrepareStoreResponse: + """Prepare a store operation. For pickle mode, returns empty slots. + + Args: + key: Cache key for the token range to store. + instance_id: Worker instance identifier. + + Returns: + PrepareStoreResponse with empty slots for pickle mode. + """ + return PrepareStoreResponse(context={}) + + @_lmcache_nvtx_annotate + def commit_store( + self, + key: IPCCacheEngineKey, + instance_id: int, + cpu_data: bytes, + ) -> bool: + """Commit serialized CPU chunks to storage. + + Args: + key: Cache key for the token range to store. + instance_id: Worker instance identifier. + cpu_data: Pickled list of CPU tensors produced by the worker. 
+ + Returns: + ``True`` when all reserved objects are written, otherwise ``False``. + + Raises: + ValueError: If no non-GPU context is registered for the given + instance ID. + """ + obj_keys = self._ctx.resolve_obj_keys(key) + + entry = self._non_gpu_contexts.get(instance_id) + if entry is None: + raise ValueError( + f"non-CUDA context not registered for instance ID {instance_id}" + ) + ctx = entry.metadata + chunks: list[torch.Tensor] = pickle.loads(cpu_data) + reserved_dict = self._ctx.storage_manager.reserve_write( + obj_keys, ctx.layout_desc, "new" + ) + written_keys: list[ObjectKey] = [] + try: + for idx, obj_key in enumerate(obj_keys): + if obj_key not in reserved_dict: + continue + if idx >= len(chunks): + continue + memory_obj = reserved_dict[obj_key] + if memory_obj.tensor is None: + continue + chunk_cpu = chunks[idx] + if chunk_cpu.shape != memory_obj.tensor.shape: + continue + memory_obj.tensor.copy_(chunk_cpu) + written_keys.append(obj_key) + finally: + if written_keys: + self._ctx.storage_manager.finish_write(written_keys) + + return len(written_keys) == len(reserved_dict) + + @_lmcache_nvtx_annotate + def prepare_retrieve( + self, + key: IPCCacheEngineKey, + instance_id: int, + ) -> PrepareRetrieveResponse: + """Retrieve prefetched chunks and return serialized CPU tensors. + + Args: + key: Cache key for the token range to retrieve. + instance_id: Worker instance identifier. + + Returns: + PrepareRetrieveResponse with serialized data on hit. + + Raises: + ValueError: If no non-GPU context is registered for the given + instance ID. 
+ """ + obj_keys = self._ctx.resolve_obj_keys(key) + + entry = self._non_gpu_contexts.get(instance_id) + if entry is None: + raise ValueError( + f"non-CUDA context not registered for instance ID {instance_id}" + ) + + prefetched_keys: list[ObjectKey] = [] + try: + with self._ctx.storage_manager.read_prefetched_results( + obj_keys + ) as memory_objs: + if not memory_objs or len(memory_objs) != len(obj_keys): + return PrepareRetrieveResponse(success=False, data=b"", context={}) + prefetched_keys = obj_keys[: len(memory_objs)] + chunks = [] + for memory_obj in memory_objs: + if memory_obj.tensor is None: + return PrepareRetrieveResponse( + success=False, data=b"", context={} + ) + chunks.append(memory_obj.tensor.cpu().clone()) + return PrepareRetrieveResponse( + success=True, data=pickle.dumps(chunks), context={} + ) + finally: + if prefetched_keys: + self._ctx.storage_manager.finish_read_prefetched(prefetched_keys) + + @_lmcache_nvtx_annotate + def commit_retrieve( + self, + key: IPCCacheEngineKey, + instance_id: int, + ) -> bool: + """Finalize a retrieve operation. No-op for pickle mode. + + Args: + key: Cache key (unused for pickle). + instance_id: Worker instance identifier (unused for pickle). + + Returns: + Always ``True``. 
+ """ + return True diff --git a/lmcache/v1/multiprocess/server.py b/lmcache/v1/multiprocess/server.py index b8cedd73256..01be515df74 100644 --- a/lmcache/v1/multiprocess/server.py +++ b/lmcache/v1/multiprocess/server.py @@ -1,1398 +1,116 @@ # SPDX-License-Identifier: Apache-2.0 +"""MPCacheEngine compositor and unified cache server entry point.""" + # Standard -from dataclasses import dataclass -from functools import partial -from itertools import islice -from typing import Generator import argparse -import pickle -import threading import time # Third Party -import torch import zmq # First Party from lmcache import torch_dev, torch_device_type from lmcache.logging import init_logger -from lmcache.utils import ( - EngineType, - _lmcache_nvtx_annotate, - check_interprocess_event_support, -) -from lmcache.v1.distributed.api import ( - MemoryLayoutDesc, - ObjectKey, - ipc_key_to_object_keys, -) from lmcache.v1.distributed.config import ( StorageManagerConfig, add_storage_manager_args, parse_args_to_config, ) -from lmcache.v1.distributed.storage_manager import PrefetchHandle, StorageManager -from lmcache.v1.gpu_connector.gpu_ops import ( - lmcache_memcpy_async_d2h, - lmcache_memcpy_async_h2d, -) -from lmcache.v1.gpu_connector.utils import LayoutHints -from lmcache.v1.memory_management import MemoryObj from lmcache.v1.mp_observability.config import ( ObservabilityConfig, add_observability_args, init_observability, parse_args_to_observability_config, ) -from lmcache.v1.mp_observability.event import Event, EventType -from lmcache.v1.mp_observability.event_bus import get_event_bus -from lmcache.v1.mp_observability.otel_init import register_gauge from lmcache.v1.mp_observability.trace import maybe_initialize_trace_recorder from lmcache.v1.multiprocess.config import ( MPServerConfig, add_mp_server_args, parse_args_to_mp_server_config, ) -from lmcache.v1.multiprocess.custom_types import ( - BlockAllocationRecord, - IPCCacheEngineKey, - KVCache, - RegisterNonGpuContextPayload, 
-) -from lmcache.v1.multiprocess.gpu_context import ( - GPUCacheContext, +from lmcache.v1.multiprocess.engine_context import MPCacheEngineContext +from lmcache.v1.multiprocess.engine_module import ( + EngineModule, + HandlerSpec, + ThreadPoolType, ) +from lmcache.v1.multiprocess.modules.gpu_transfer import GPUTransferModule +from lmcache.v1.multiprocess.modules.lookup import LookupModule +from lmcache.v1.multiprocess.modules.management import ManagementModule +from lmcache.v1.multiprocess.modules.non_gpu_transfer import NonGPUTransferModule from lmcache.v1.multiprocess.mq import MessageQueueServer -from lmcache.v1.multiprocess.native_completion import ( - DeviceHostFuncDispatcher, - submit_callback_to_stream, -) -from lmcache.v1.multiprocess.non_gpu_context import NonGpuContextMetadata from lmcache.v1.multiprocess.protocol import ( RequestType, get_handler_type, get_payload_classes, ) -from lmcache.v1.multiprocess.protocols.engine import ( - PrepareRetrieveResponse, - PrepareStoreResponse, -) -from lmcache.v1.multiprocess.session import SessionManager -from lmcache.v1.multiprocess.token_hasher import TokenHasher -import lmcache.c_ops as lmc_ops logger = init_logger(__name__) -# Helper functions -def compute_extra_count( - tp_size: int, - world_size: int, -) -> int: - """Compute extra count for MLA multi-reader locking. - - Non-MLA: each TP worker owns a distinct KV shard, - so each ObjectKey is retrieved by exactly 1 - worker -> extra_count = 0. - MLA: TP does not split KV caches, all TP workers - share the same object. vLLM passes world_size - already divided by tp_size (e.g. world_size=1 - for TP=4 PP=1), so ipc_keys_to_object_keys - only produces 1 ObjectKey per chunk. All TP - workers retrieve that same ObjectKey, hence - extra_count = tp_size - 1. - - Detection: tp > world_size means MLA (world_size - was divided by tp on the vLLM side). 
- - Fallback: old vLLM (<= 0.8.5) does not send - tp_size (defaults to 1); we fall back to - world_size which gives extra_count = 0 - (safe but may under-lock for MLA). - - TODO: world_size currently carries an overloaded - meaning (total ranks for non-MLA vs total/tp for - MLA). Consider a dedicated field in the future. - - Args: - tp_size: Tensor-parallel size from the client. - world_size: World size from the cache key. - - Returns: - Number of extra count (0 for non-MLA). - """ - tp = tp_size if tp_size > 1 else world_size - return tp - 1 if tp > world_size else 0 - - -def get_layout_desc(gpu_context: GPUCacheContext, num_tokens: int) -> MemoryLayoutDesc: - """Get the memory layout description for a given GPU context and number of tokens. - - Supports multiple KV layer groups with different shapes and dtypes. - - Args: - gpu_context: The GPU cache context containing the KV cache information. - num_tokens: The number of tokens to determine the layout for. - - Returns: - MemoryLayoutDesc: The memory layout description containing shapes and dtypes. - """ - num_groups = gpu_context.kv_layer_groups_manager.num_groups - shapes = [ - gpu_context.get_kv_buffer_shape(num_tokens, group_idx) - for group_idx in range(num_groups) - ] - dtypes = [ - gpu_context.kv_layer_groups_manager.kv_layer_groups[group_idx].dtype - for group_idx in range(num_groups) - ] - return MemoryLayoutDesc(shapes=shapes, dtypes=dtypes) - +class MPCacheEngine: + """Compositor that assembles pluggable engine modules. -def batched_iteration(lst: list, batch_size: int) -> Generator[tuple, None, None]: - """Utility function to iterate over a list in batches. + Holds the shared :class:`MPCacheEngineContext` and a list of + :class:`EngineModule` instances. Provides aggregated + ``report_status()`` and ``close()`` across all modules. Args: - lst: The list to iterate over. - batch_size: The size of each batch. - - Yields: - Batches of the list as tuples. + context: The shared engine context. 
+ modules: List of engine modules to compose. """ - if batch_size < 1: - raise ValueError("batch size must be at least one") - it = iter(lst) - while batch := tuple(islice(it, batch_size)): - yield batch - - -@dataclass -class _PrefetchJob: - handle: PrefetchHandle - world_size: int - request_id: str - # Number of tokens submitted for lookup (denominator for the L1+L2 - # token-level hit-rate metric). Equals ``len(chunk_hashes) * chunk_size`` - # on the happy path; 0 for early-exit paths (no GPU context matches - # or chunk_hashes is empty). Consumed at ``MP_LOOKUP_PREFETCH_END`` - # emission time in ``query_prefetch_status``. - requested_tokens: int - # Captured at lookup time so the ``MP_LOOKUP_PREFETCH_END`` event can - # carry them as labels. ``model_name`` lets dashboards slice hit rate - # per model in multi-model deployments; ``cache_salt`` slices per - # tenant / isolation domain (an empty string means no salt set). - model_name: str = "" - cache_salt: str = "" - - -@dataclass -class RegisteredContext: - """Registered context metadata for a single worker instance. - - At least one of ``gpu_context`` or ``non_cuda_metadata`` is expected to be - populated for valid registrations. - """ - - model_name: str - world_size: int - gpu_context: GPUCacheContext | None = None - non_cuda_metadata: NonGpuContextMetadata | None = None - - @property - def is_gpu(self) -> bool: - """Return whether this registration uses a GPU transfer context.""" - return self.gpu_context is not None - - def get_layout_desc(self, chunk_size: int) -> MemoryLayoutDesc: - """Return the layout descriptor for this registration. - - Args: - chunk_size: Chunk size in tokens used for GPU layout derivation. - - Returns: - The resolved memory layout descriptor. - Raises: - ValueError: If no GPU context or non-CUDA metadata is configured. 
- """ - if self.gpu_context is not None: - return get_layout_desc(self.gpu_context, chunk_size) - if self.non_cuda_metadata is None: - raise ValueError( - "Invalid RegisteredContext: no GPU or non-CUDA metadata configured" - ) - return self.non_cuda_metadata.layout_desc - - -# Main class for the mp cache engine -class MPCacheEngine: def __init__( self, - storage_manager_config: StorageManagerConfig, - chunk_size: int = 256, - hash_algorithm: str = "blake3", - ): - # Worker instance ID -> registered context metadata - self.contexts: dict[int, RegisteredContext] = {} - - # chunk size - self.chunk_size = chunk_size - - # Lock for clear() to avoid concurrent storage manager mutations - self.lock = threading.Lock() - - # storage manager - self.storage_manager = StorageManager(storage_manager_config) - - # Token hasher and session manager for token-based operations - self.token_hasher = TokenHasher( - chunk_size=chunk_size, hash_algorithm=hash_algorithm - ) - self.session_manager = SessionManager(self.token_hasher) - - # EventBus for observability - self._event_bus = get_event_bus() - - # Route finish_write / finish_read_prefetched through a C++ host - # callback so the driver thread doesn't acquire the GIL. - self._device_host_func_dispatcher = DeviceHostFuncDispatcher() - self._device_host_func_dispatcher.register( - "finish_write", - self.storage_manager.finish_write, - payload_type=list[ObjectKey], - ) - self._device_host_func_dispatcher.register( - "finish_read_prefetched", - self.storage_manager.finish_read_prefetched, - payload_type=list[ObjectKey], - ) - self._device_host_func_dispatcher.start() - - # Prefetch job tracking for two-phase lookup, keyed by request_id. 
- # TODO: implement periodic cleanup of stale _prefetch_jobs entries - # for crash resilience (e.g., client calls lookup but never queries) - self._prefetch_jobs: dict[str, _PrefetchJob] = {} - self._prefetch_job_lock = threading.Lock() - - self._setup_metrics() - - @property - def gpu_contexts(self) -> dict[int, GPUCacheContext]: - """Return GPU-only context mapping for backward compatibility.""" - return { - instance_id: ctx.gpu_context - for instance_id, ctx in self.contexts.items() - if ctx.gpu_context is not None - } - - def register_kv_cache( - self, - instance_id: int, - kv_caches: KVCache, - model_name: str, - world_size: int, - engine_type: EngineType, - layout_hints: LayoutHints, + context: MPCacheEngineContext, + modules: list[EngineModule], ) -> None: - """ - Registers the KV cache tensors for a given GPU instance ID. - - Args: - instance_id (int): The GPU instance ID (such as PID). - kv_caches (KVCache): The KV cache tensor wrappers from the - serving engine. - model_name (str): The name of the model associated with this KV cache. - world_size (int): The world size associated with this KV cache. - engine_type: Which serving engine produced the caches. - Forwarded to :class:`GPUCacheContext` for format detection. - layout_hints: See :class:`LayoutHints`. Forwarded to - :class:`GPUCacheContext` for GPU KV format detection. 
- """ - if instance_id in self.contexts: - logger.warning( - "Instance %s's KV cache is already registered, " - "skipping the new registration", - instance_id, - ) - return - - gpu_context = GPUCacheContext( - kv_caches, - self.chunk_size, - layout_hints=layout_hints or None, - engine_type=engine_type, - ) - self.contexts[instance_id] = RegisteredContext( - model_name=model_name, - world_size=world_size, - gpu_context=gpu_context, - ) - logger.info( - "Registered KV cache for GPU ID %d with %d layers", - instance_id, - gpu_context.num_layers, - ) - - def unregister_kv_cache(self, instance_id: int) -> None: - """ - Unregisters the KV cache tensors for a given GPU instance ID. + self._context = context + self._modules = modules - Args: - instance_id (int): The GPU instance ID (such as PID). - """ - context = self.contexts.pop(instance_id, None) - if context is None: - logger.warning( - "No registered context found for instance ID %d", instance_id - ) - return - - if context.is_gpu: - logger.info("Unregistered KV cache for GPU ID %d", instance_id) - torch_dev.empty_cache() - else: - logger.info("Unregistered non-CUDA context for instance ID %d", instance_id) - - def register_kv_cache_non_gpu_context( - self, - payload: RegisterNonGpuContextPayload, - ) -> None: - """Register non-CUDA KV layout metadata for non-GPU context mode. - - Args: - payload: Struct containing all registration fields - (instance_id, model_name, world_size, block_size, - num_layers, hidden_dim_size, dtype_str, use_mla). - - Raises: - ValueError: If ``payload.dtype_str`` is not a valid torch dtype name. 
- """ - if payload.instance_id in self.contexts: - logger.warning( - "Instance %s's KV cache is already registered, " - "skipping the new registration", - payload.instance_id, - ) - return - - dtype = getattr(torch, payload.dtype_str, None) - if dtype is None or not isinstance(dtype, torch.dtype): - raise ValueError( - f"Invalid dtype_str '{payload.dtype_str}': must be a valid torch dtype " - "attribute name (e.g. 'float16' for torch.float16, " - "'bfloat16' for torch.bfloat16, 'float32' for torch.float32)." - ) - - shape = ( - torch.Size([payload.num_layers, self.chunk_size, payload.hidden_dim_size]) - if payload.use_mla - else torch.Size( - [2, payload.num_layers, self.chunk_size, payload.hidden_dim_size] - ) - ) - layout_desc = MemoryLayoutDesc(shapes=[shape], dtypes=[dtype]) - self.contexts[payload.instance_id] = RegisteredContext( - model_name=payload.model_name, - world_size=payload.world_size, - non_cuda_metadata=NonGpuContextMetadata( - layout_desc=layout_desc, - block_size=payload.block_size, - use_mla=payload.use_mla, - ), - ) - - def _resolve_obj_keys(self, key: IPCCacheEngineKey) -> list[ObjectKey]: - """Resolve object keys from an IPC cache key. - - Args: - key: IPC cache key describing model/session/token range. - - Returns: - Resolved object keys for the requested token range. - - Raises: - ValueError: If ``key.worker_id`` is ``None``. - """ - session = self.session_manager.get_or_create(key.request_id) - session.set_tokens(list(key.token_ids)) - chunk_hashes = [ - TokenHasher.hash_to_bytes(h) for h in session.get_hashes(key.start, key.end) - ] - if key.worker_id is None: - raise ValueError("Must resolve keys with worker_id != None") - return ipc_key_to_object_keys(key, chunk_hashes) - - @_lmcache_nvtx_annotate - def prepare_store( - self, - key: IPCCacheEngineKey, - instance_id: int, - ) -> PrepareStoreResponse: - """Prepare a store operation. For pickle mode, returns empty slots. - - Args: - key: Cache key for the token range to store. 
- instance_id: Worker instance identifier. - - Returns: - PrepareStoreResponse with empty slots for pickle mode. - """ - - return PrepareStoreResponse(context={}) - - @_lmcache_nvtx_annotate - def commit_store( - self, - key: IPCCacheEngineKey, - instance_id: int, - cpu_data: bytes, - ) -> bool: - """Commit serialized CPU chunks to storage. - - Args: - key: Cache key for the token range to store. - instance_id: Worker instance identifier. - cpu_data: Pickled list of CPU tensors produced by the worker. - - Returns: - ``True`` when all reserved objects are written, otherwise ``False``. - """ - obj_keys = self._resolve_obj_keys(key) - - context = self.contexts.get(instance_id) - if context is None or context.non_cuda_metadata is None: - raise ValueError( - f"non-CUDA context not registered for instance ID {instance_id}" - ) - ctx = context.non_cuda_metadata - chunks: list[torch.Tensor] = pickle.loads(cpu_data) - reserved_dict = self.storage_manager.reserve_write( - obj_keys, ctx.layout_desc, "new" - ) - written_keys: list[ObjectKey] = [] - try: - for idx, obj_key in enumerate(obj_keys): - if obj_key not in reserved_dict: - continue - if idx >= len(chunks): - continue - memory_obj = reserved_dict[obj_key] - if memory_obj.tensor is None: - continue - chunk_cpu = chunks[idx] - if chunk_cpu.shape != memory_obj.tensor.shape: - continue - memory_obj.tensor.copy_(chunk_cpu) - written_keys.append(obj_key) - finally: - if written_keys: - self.storage_manager.finish_write(written_keys) - - return len(written_keys) == len(reserved_dict) - - @_lmcache_nvtx_annotate - def prepare_retrieve( - self, - key: IPCCacheEngineKey, - instance_id: int, - ) -> PrepareRetrieveResponse: - """Retrieve prefetched chunks and return serialized CPU tensors. - - Args: - key: Cache key for the token range to retrieve. - instance_id: Worker instance identifier. - - Returns: - PrepareRetrieveResponse with serialized data on hit. 
- """ - - obj_keys = self._resolve_obj_keys(key) - - context = self.contexts.get(instance_id) - if context is None or context.non_cuda_metadata is None: - raise ValueError( - f"non-CUDA context not registered for instance ID {instance_id}" - ) - - prefetched_keys: list[ObjectKey] = [] - try: - with self.storage_manager.read_prefetched_results(obj_keys) as memory_objs: - if not memory_objs or len(memory_objs) != len(obj_keys): - return PrepareRetrieveResponse(success=False, data=b"", context={}) - prefetched_keys = obj_keys[: len(memory_objs)] - chunks = [] - for memory_obj in memory_objs: - if memory_obj.tensor is None: - return PrepareRetrieveResponse( - success=False, data=b"", context={} - ) - chunks.append(memory_obj.tensor.cpu().clone()) - return PrepareRetrieveResponse( - success=True, data=pickle.dumps(chunks), context={} - ) - finally: - if prefetched_keys: - self.storage_manager.finish_read_prefetched(prefetched_keys) - - @_lmcache_nvtx_annotate - def commit_retrieve( - self, - key: IPCCacheEngineKey, - instance_id: int, - ) -> bool: - """Finalize a retrieve operation. No-op for pickle mode. - - Args: - key: Cache key (unused for pickle). - instance_id: Worker instance identifier (unused for pickle). - - Returns: - Always ``True``. - """ - return True - - @_lmcache_nvtx_annotate - def store( - self, - key: IPCCacheEngineKey, - instance_id: int, - gpu_block_ids: list[int], - event_ipc_handle: bytes, - ) -> tuple[bytes, bool]: - """ - Stores the GPU KV cache blocks to CPU. - - Args: - key (IPCCacheEngineKey): The IPC key for the KV cache blocks. - Must have worker_id != None (worker store operation). - instance_id (int): The GPU instance ID (such as PID). - gpu_block_ids (list[int]): The GPU block IDs to store. - event_ipc_handle (bytes): The IPC handle of the event to wait on. - - Returns: - tuple[bytes, bool]: The first element is the IPC handle of the event - that signals the completion of the store operation. 
The second - element indicates whether the store operation was successful. - """ - st = time.perf_counter() - obj_keys = self._resolve_obj_keys(key) - - context = self.contexts.get(instance_id) - assert context is not None, ( - f"No context registered for instance ID {instance_id}" - ) - assert context.gpu_context is not None, ( - f"GPU context not registered for instance ID {instance_id}" - ) - gpu_context = context.gpu_context - model_name = context.model_name - - # ``blocks_per_chunk`` is counted in inference-engine-side - # blocks (each block addresses - # ``inference_engine_logical_block_size`` *logical* tokens). - # For compressed groups the per-group physical slot count - # differs, but the block-id indexing is shared with the engine - # and therefore uses the engine logical block size here. - blocks_per_chunk = ( - self.chunk_size - // gpu_context.kv_layer_groups_manager.inference_engine_logical_block_size - ) - - with ( - torch_dev.device(gpu_context.device), - torch_dev.stream(gpu_context.stream), - ): - # Not all backends support interprocess Events (CUDA IPC specific) - check_interprocess_event_support() - event = torch_dev.Event(interprocess=True) - - # Stage all block_ids to GPU once before the loop - all_block_ids_gpu = gpu_context.stage_block_ids(gpu_block_ids) - - # Wait for vLLM to finish - # Not all backends support IPC event handles (CUDA IPC specific) - if not hasattr(torch_dev.Event, "from_ipc_handle"): - raise RuntimeError( - f"Backend '{torch_device_type}' does not support IPC event " - "handles (Event.from_ipc_handle not available). " - "Multiprocess IPC requires CUDA." - ) - vllm_event = torch_dev.Event.from_ipc_handle( - gpu_context.device, event_ipc_handle - ) - vllm_event.wait(stream=gpu_context.stream) - - # CPU-synchronous sentinel: a GPU store is about to be enqueued. - # Must be published via publish() (not publish_on_stream) so the - # drain thread sees it before MP_REQUEST_END can race MP_STORE_END. 
- self._event_bus.publish( - Event( - event_type=EventType.MP_STORE_SUBMITTED, - session_id=key.request_id, - metadata={"device": str(gpu_context.device)}, - ) - ) - - self._event_bus.publish_on_stream( - gpu_context.cupy_stream, - Event( - event_type=EventType.MP_STORE_START, - session_id=key.request_id, - metadata={ - "device": str(gpu_context.device), - "engine_id": instance_id, - "model_name": model_name, - }, - ), - ) - - reserved_dict: dict = {} - try: - layout_desc = get_layout_desc(gpu_context, self.chunk_size) - reserved_dict = self.storage_manager.reserve_write( - obj_keys, layout_desc, "new" - ) - - # NOTE: Store is not batched because some obj_keys may be - # skipped (not in reserved_dict), making block_ids - # non-contiguous. Batching would require torch.cat to - # reassemble block_ids, negating the benefit. - num_groups = gpu_context.kv_layer_groups_manager.num_groups - for idx, obj_key in enumerate(obj_keys): - if obj_key in reserved_dict: - memory_obj = reserved_dict[obj_key] - else: - continue - - chunk_block_ids_gpu = all_block_ids_gpu[ - idx * blocks_per_chunk : (idx + 1) * blocks_per_chunk - ] - - # Copy from GPU paged buffer to tmp buffer, then to CPU — per group - for group_idx in range(num_groups): - tmp_buffer = gpu_context.get_tmp_chunk_gpu_buffer(group_idx) - group_kv_pointers = gpu_context.get_group_kv_pointers(group_idx) - # Kernel contract: ``group_lmcache_chunk_size`` here is the - # number of *physical* slots per chunk for this group - # (= logical chunk_size // compress_ratio). 
- group_lmcache_chunk_size = gpu_context.get_physical_chunk_size( - group_idx - ) - lmc_ops.multi_layer_block_kv_transfer( - group_kv_pointers, - [tmp_buffer.data_ptr()], - chunk_block_ids_gpu, - gpu_context.device, - lmc_ops.TransferDirection.D2H, - gpu_context.get_shape_desc(group_idx), - group_lmcache_chunk_size, - gpu_context.gpu_kv_format_, - 0, - ) - # Store is not batched, so we always use chunk_idx=0 (single slot) - lmcache_memcpy_async_d2h( - gpu_context.get_tmp_gpu_buffer_flat(chunk_idx=0), memory_obj - ) - except Exception: - logger.exception("Cannot store keys due to exception") - finally: - event.record() - if reserved_dict: - submit_callback_to_stream( - gpu_context.cupy_stream, - "finish_write", - list(reserved_dict.keys()), - ) - # All reserved MemoryObjs share one layout_desc, so per-object - # size is identical — avoid summing N identical values. - total_bytes = ( - next(iter(reserved_dict.values())).get_size() * len(reserved_dict) - if reserved_dict - else 0 - ) - self._event_bus.publish_on_stream( - gpu_context.cupy_stream, - Event( - event_type=EventType.MP_STORE_END, - session_id=key.request_id, - metadata={ - "stored_count": len(reserved_dict), - "device": str(gpu_context.device), - "engine_id": instance_id, - "model_name": model_name, - "total_bytes": total_bytes, - }, - ), - ) - - ed = time.perf_counter() - if length := len(reserved_dict): - logger.info( - "Stored %d tokens in %.3f seconds", - length * self.chunk_size, - ed - st, - ) - return event.ipc_handle(), True - - @_lmcache_nvtx_annotate - def retrieve( - self, - key: IPCCacheEngineKey, - instance_id: int, - gpu_block_ids: list[int], - event_ipc_handle: bytes, - skip_first_n_tokens: int = 0, - ) -> tuple[bytes, bool]: - """ - Retrieves the CPU KV cache and put into GPU blocks. - - Args: - key (IPCCacheEngineKey): The IPC key for the KV cache blocks. - Must have worker_id != None (worker retrieve operation). - instance_id (int): The GPU instance ID (such as PID). 
- gpu_block_ids (list[int]): The GPU block IDs to retrieve into. - event_ipc_handle (bytes): The IPC handle of the event to wait on. - skip_first_n_tokens (int): Number of tokens to skip writing at - the start of the retrieve range. This avoids overwriting - APC-shared GPU blocks that may be read concurrently by other - requests. - - Returns: - tuple[bytes, bool]: The first element is the IPC handle of the event - that signals the completion of the retrieve operation. The second - element indicates whether the key was successfully retrieved. - """ - st = time.perf_counter() - obj_keys = self._resolve_obj_keys(key) - - context = self.contexts.get(instance_id) - assert context is not None, ( - f"No context registered for instance ID {instance_id}" - ) - assert context.gpu_context is not None, ( - f"GPU context not registered for instance ID {instance_id}" - ) - gpu_context = context.gpu_context - model_name = context.model_name - - # CPU-synchronous sentinel: a GPU retrieve is about to be enqueued. - # Must be published via publish() (not publish_on_stream) so the - # drain thread sees it before MP_REQUEST_END can race MP_RETRIEVE_END. - self._event_bus.publish( - Event( - event_type=EventType.MP_RETRIEVE_SUBMITTED, - session_id=key.request_id, - metadata={"device": str(gpu_context.device)}, - ) - ) - - self._event_bus.publish_on_stream( - gpu_context.cupy_stream, - Event( - event_type=EventType.MP_RETRIEVE_START, - session_id=key.request_id, - metadata={ - "device": str(gpu_context.device), - "engine_id": instance_id, - "model_name": model_name, - }, - ), - ) - - # ``skip_*_in_chunk`` is expressed in engine-block units - # (logical tokens), which is what the kernel's - # ``skip_blocks_in_chunk`` argument expects regardless - # of per-group compression. 
- ie_logical_block_size = ( - gpu_context.kv_layer_groups_manager.inference_engine_logical_block_size - ) - blocks_per_chunk = self.chunk_size // ie_logical_block_size - - def _retrieve_loop(keys: list[ObjectKey], memory_objs: list[MemoryObj]) -> None: - _BATCH_SIZE = gpu_context.max_batch_size - num_groups = gpu_context.kv_layer_groups_manager.num_groups - for batch_idx, memory_obj_batch in enumerate( - batched_iteration(memory_objs, batch_size=_BATCH_SIZE) - ): - batch_len = len(memory_obj_batch) - chunk_start = batch_idx * self.chunk_size * _BATCH_SIZE - chunk_end = chunk_start + self.chunk_size * batch_len - - effective_start = max(chunk_start, skip_first_n_tokens) - if effective_start >= chunk_end: - # Entire batch is within APC range, skip it - continue - - skip_tokens_in_chunk = max( - 0, - min( - effective_start - chunk_start, - self.chunk_size * batch_len - 1, - ), - ) - if skip_tokens_in_chunk % ie_logical_block_size != 0: - logger.error( - "skip_first_n_tokens (%d) is not aligned to " - "inference_engine_logical_block_size (%d), " - "rounding down from %d tokens to %d blocks", - skip_first_n_tokens, - ie_logical_block_size, - skip_tokens_in_chunk, - skip_tokens_in_chunk // ie_logical_block_size, - ) - skip_blocks_in_chunk = skip_tokens_in_chunk // ie_logical_block_size - - start_chunk_id = batch_idx * _BATCH_SIZE - end_chunk_id = start_chunk_id + batch_len - chunk_block_ids_gpu = all_block_ids_gpu[ - start_chunk_id * blocks_per_chunk : end_chunk_id * blocks_per_chunk - ] - - # Copy from CPU to GPU tmp buffers, then scatter to paged KV — per group - # H2D copy: each memory_obj maps to its own batch slot - for chunk_idx, memory_obj in enumerate(memory_obj_batch): - lmcache_memcpy_async_h2d( - memory_obj, - gpu_context.get_tmp_gpu_buffer_flat(chunk_idx=chunk_idx), - ) - for group_idx in range(num_groups): - tmp_buffers = gpu_context.get_tmp_chunk_gpu_buffer_batched( - batch_len, group_idx - ) - group_kv_pointers = 
gpu_context.get_group_kv_pointers(group_idx) - group_lmcache_chunk_size = gpu_context.get_physical_chunk_size( - group_idx - ) - - lmc_ops.multi_layer_block_kv_transfer( - group_kv_pointers, - [tb.data_ptr() for tb in tmp_buffers], - chunk_block_ids_gpu, - gpu_context.device, - lmc_ops.TransferDirection.H2D, - gpu_context.get_shape_desc(group_idx), - group_lmcache_chunk_size, - gpu_context.gpu_kv_format_, - skip_blocks_in_chunk, - ) - - with ( - torch_dev.device(gpu_context.device), - torch_dev.stream(gpu_context.stream), - ): - # Stage all block_ids to GPU once before the loop - all_block_ids_gpu = gpu_context.stage_block_ids(gpu_block_ids) - - # Not all backends support interprocess Events (CUDA IPC specific) - check_interprocess_event_support() - event = torch_dev.Event(interprocess=True) - - prefetched_keys: list[ObjectKey] = [] - retrieve_succeeded = False - total_bytes = 0 - try: - with self.storage_manager.read_prefetched_results( - obj_keys - ) as memory_objs: - if not memory_objs or len(memory_objs) != len(obj_keys): - logger.error("Some keys not found during retrieve!") - return event.ipc_handle(), False - - prefetched_keys = obj_keys[: len(memory_objs)] - total_bytes = sum(mo.get_size() for mo in memory_objs) - _retrieve_loop(obj_keys, memory_objs) - # Only set True when with-block exits normally - retrieve_succeeded = True - except Exception: - logger.exception("Cannot retrieve keys due to exception") - return event.ipc_handle(), False - finally: - event.record() - if retrieve_succeeded: - submit_callback_to_stream( - gpu_context.cupy_stream, - "finish_read_prefetched", - prefetched_keys, - ) - self._event_bus.publish_on_stream( - gpu_context.cupy_stream, - Event( - event_type=EventType.MP_RETRIEVE_END, - session_id=key.request_id, - metadata={ - "retrieved_count": len(prefetched_keys), - "device": str(gpu_context.device), - "engine_id": instance_id, - "model_name": model_name, - "cache_salt": key.cache_salt, - "total_bytes": total_bytes, - }, - ), - ) 
- tokens_retrieved = len(obj_keys) * self.chunk_size - ed = time.perf_counter() - logger.info( - "Retrieved %d tokens in %.3f seconds", - tokens_retrieved, - ed - st, - ) - - return event.ipc_handle(), True - - def _find_layout_desc( - self, - model_name: str, - world_size: int, - ) -> MemoryLayoutDesc | None: - """Find layout desc from a matching GPU or CPU context. - - Returns: - The layout descriptor, or None if no context matches - ``(model_name, world_size)``. GPU contexts are checked first, - then CPU contexts. - """ - for context in self.contexts.values(): - if context.model_name == model_name and context.world_size == world_size: - return context.get_layout_desc(self.chunk_size) - return None - - def lookup( - self, - key: IPCCacheEngineKey, - tp_size: int, - ) -> None: - """Submit a prefix lookup. - - Hashes the key, submits a prefetch task to the storage manager, - and registers the job under ``key.request_id`` for later polling - via query_prefetch_status. - - Args: - key: Cache key with request_id embedded. - tp_size: Tensor-parallel size for MLA multi-reader locking. 
- """ - model_name, world_size = key.model_name, key.world_size - self._event_bus.publish( - Event( - event_type=EventType.MP_REQUEST_START, - session_id=key.request_id, - ) - ) - self._event_bus.publish( - Event( - event_type=EventType.MP_LOOKUP_PREFETCH_START, - session_id=key.request_id, - ) - ) - - layout_desc = self._find_layout_desc(model_name, world_size) - if layout_desc is None: - logger.error( - "No GPU context found for model %s with world size %d during lookup!", - model_name, - world_size, - ) - self._register_prefetch_job( - _PrefetchJob( - handle=PrefetchHandle( - prefetch_request_id=-1, - external_request_id=key.request_id, - l1_prefix_hit_count=0, - total_requested_keys=0, - submit_time=time.monotonic(), - ), - world_size=1, - request_id=key.request_id, - requested_tokens=0, - model_name=model_name, - cache_salt=key.cache_salt, - ) - ) - return - - extra_count = compute_extra_count(tp_size, world_size) - - # Compute chunk hashes for all full chunks - chunk_hashes = self.token_hasher.compute_chunk_hashes(list(key.token_ids)) - if not chunk_hashes: - self._register_prefetch_job( - _PrefetchJob( - handle=PrefetchHandle( - prefetch_request_id=-1, - external_request_id=key.request_id, - l1_prefix_hit_count=0, - total_requested_keys=0, - submit_time=time.monotonic(), - ), - world_size=1, - request_id=key.request_id, - requested_tokens=0, - model_name=model_name, - cache_salt=key.cache_salt, - ) - ) - return - - # Total chunk-aligned tokens submitted for lookup; surfaces as the - # denominator of the L1+L2 token-level hit-rate via the - # ``requested_tokens`` field on ``MP_LOOKUP_PREFETCH_END``. Sub-chunk - # trailing tokens are intentionally excluded — they cannot hit at - # chunk granularity. - requested_tokens = len(chunk_hashes) * self.chunk_size - - # Publish lookup event via EventBus for observability subscribers. 
- # Guard with has_subscribers() to avoid allocating the metadata dict - # (including dtype/shape list comprehensions) when no subscriber is - # listening (e.g. lookup hash logger is disabled). - if self._event_bus.has_subscribers(EventType.MP_LOOKUP): - self._event_bus.publish( - Event( - event_type=EventType.MP_LOOKUP, - session_id=key.request_id, - metadata={ - "request_id": key.request_id, - "chunk_hashes": chunk_hashes, - "model_name": model_name, - "chunk_size": self.chunk_size, - "seq_len": len(key.token_ids), - "dtypes": [str(d) for d in layout_desc.dtypes], - "shapes": [list(s) for s in layout_desc.shapes], - }, - ) - ) - - # set lookup ipc key, for session manager to use and generate object keys - session = self.session_manager.get_or_create(key.request_id) - session.set_tokens(list(key.token_ids)) - session.lookup_ipc_key = key - - obj_keys = ipc_key_to_object_keys(key, chunk_hashes) - - handle = self.storage_manager.submit_prefetch_task( - obj_keys, - layout_desc, - extra_count=extra_count, - external_request_id=key.request_id, - ) - self._register_prefetch_job( - _PrefetchJob( - handle=handle, - world_size=key.world_size, - request_id=key.request_id, - requested_tokens=requested_tokens, - model_name=model_name, - cache_salt=key.cache_salt, - ) - ) - - def _register_prefetch_job(self, job: _PrefetchJob) -> None: - with self._prefetch_job_lock: - self._prefetch_jobs[job.request_id] = job - - def query_prefetch_lookup_hits( - self, - request_id: str, - ) -> int | None: - """Query the number of hits for a prefetch request before it's finished. - - Returns: - The number of hits for the prefetched keys if the lookup phase is - done. None if the lookup phase is still in progress. 0 if the - request_id is unknown (already completed and consumed, or invalid). 
- """ - with self._prefetch_job_lock: - job = self._prefetch_jobs.get(request_id) - - if job is None: - logger.warning( - "Prefetch job for request %s not found (already completed or invalid)", - request_id, - ) - return 0 - - found_count = self.storage_manager.query_prefetch_lookup_hits(job.handle) - if found_count is None: - return None - - found_count = found_count // job.world_size - return found_count - - def query_prefetch_status( - self, - request_id: str, - ) -> int | None: - """Poll the status of a prefetch job by request_id. - - Returns the chunk count when the prefetch is complete, or None - if it is still in progress. The job entry is automatically - removed once a non-None result is returned (exactly-once - semantics). - - Args: - request_id: The external request ID passed in the lookup key. - - Returns: - Chunk count (int) when done, None if still in progress, - 0 if the request_id is unknown (already completed and consumed, - or invalid). - """ - with self._prefetch_job_lock: - job = self._prefetch_jobs.get(request_id) - if job is None: - logger.warning( - "Prefetch job for request %s not found (already completed or invalid)", - request_id, - ) - return 0 - - found_count = self.storage_manager.query_prefetch_status(job.handle) - if found_count is None: - return None - - # NOTE(Kuntai): this assumes two things: - # 1. the world size is the same between keys - # 2. 
the lookup sort the keys in prefix order and breaks at the - # first failure - found_count = found_count // job.world_size - - self._event_bus.publish( - Event( - event_type=EventType.MP_LOOKUP_PREFETCH_END, - session_id=job.request_id, - metadata={ - "found_count": found_count, - "requested_tokens": job.requested_tokens, - "hit_tokens": found_count * self.chunk_size, - "model_name": job.model_name, - "cache_salt": job.cache_salt, - }, - ) - ) - - with self._prefetch_job_lock: - self._prefetch_jobs.pop(request_id, None) - - return found_count - - def free_lookup_locks( - self, - key: IPCCacheEngineKey, - tp_size: int, - ) -> None: - """Release read locks acquired during lookup. - - Hashes are computed only for chunks in ``[start, end)`` to avoid - unnecessary work on tokens outside that range. - ``start`` and ``end`` must be aligned to ``chunk_size``; it is the - caller's responsibility to align the boundaries as desired. - - Computes the extra reader count from ``tp_size`` and - ``world_size`` the same way :meth:`lookup` does, so - the correct number of locks is released. - - Args: - key: Cache key whose read locks should be released. - tp_size: Tensor-parallel size for MLA - multi-reader locking. - """ - chunk_hashes = self.token_hasher.compute_chunk_hashes( - list(key.token_ids), start=key.start, end=key.end - ) - if not chunk_hashes: - return - obj_keys = ipc_key_to_object_keys(key, chunk_hashes) - - extra_count = compute_extra_count(tp_size, key.world_size) - - self.storage_manager.finish_read_prefetched(obj_keys, extra_count=extra_count) - - # ========================================================================= - # Utility methods - # ========================================================================= - - def ping(self) -> bool: - """ - Respond to a ping request. - - Returns: - bool: Always True. 
- """ - return True + @property + def context(self) -> MPCacheEngineContext: + """Return the shared engine context.""" + return self._context - def get_chunk_size(self) -> int: - """ - Returns the chunk size used for KV cache operations. + def report_status(self) -> dict: + """Return an aggregated status dict from all modules. Returns: - int: The chunk size. + Combined status from the storage manager, engine metadata, + and each module's ``report_status()`` output. """ - return self.chunk_size - - def end_session(self, request_id: str) -> None: - """Remove the session for a finished request. - - Args: - request_id: The request ID whose session should be removed. - """ - self._event_bus.publish( - Event( - event_type=EventType.MP_VLLM_END_SESSION, - metadata={"request_id": request_id}, - ) - ) - session = self.session_manager.remove(request_id) - self._event_bus.publish( - Event( - event_type=EventType.MP_REQUEST_END, - session_id=request_id, - ) - ) - if session is None: - logger.warning("Session %s not found, skipping touch", request_id) - return - if session.lookup_ipc_key is None: - logger.warning( - "Session %s has no lookup ipc key, skipping touch", request_id - ) - return - - chunk_hashes = [TokenHasher.hash_to_bytes(h) for h in session.get_hashes(0)] - obj_keys = ipc_key_to_object_keys(session.lookup_ipc_key, chunk_hashes) - # unified touch of all keys, which include retrieved and stored keys - # TODO(chunxiaozheng): when l2 is enabled, the prefetched keys from l2 are temp - # and will be deleted after finish_read_prefetched, when we touch all keys, - # these keys has been deleted and will not be touched. 
- self.storage_manager.touch_l1_keys(obj_keys) - - def report_status(self) -> dict: - """Return a status dict for the entire cache engine.""" - sm = self.storage_manager.report_status() - - gpu_context_meta: dict[str, dict] = {} - non_cuda_context_meta: dict[str, dict] = {} - registered_gpu_ids: list[int] = [] - registered_non_cuda_ids: list[int] = [] - - for instance_id, context in self.contexts.items(): - entry: dict = { - "model_name": context.model_name, - "world_size": context.world_size, - } - if context.gpu_context is not None: - registered_gpu_ids.append(instance_id) - ctx = context.gpu_context - entry["kv_cache_layout"] = { - "num_layers": ctx.num_layers, - "inference_engine_logical_block_size": ( - ctx.kv_layer_groups_manager.inference_engine_logical_block_size - ), - "group_physical_block_sizes": ctx.group_physical_block_sizes, - "group_compress_ratios": ctx.group_compress_ratios, - "hidden_dim_sizes": str(ctx.hidden_dim_sizes), - "dtype": str(ctx.dtype), - "is_mla": ctx.is_mla, - "num_blocks": ctx.num_blocks, - "gpu_kv_format": ctx.gpu_kv_format_name, - "gpu_kv_shape": ctx.gpu_kv_shape, - "gpu_kv_concrete_shape": ctx.concrete_gpu_kv_shape, - "attention_backend": ctx.attention_backend, - "cache_size_per_token": ctx.cache_size_per_token(), - } - gpu_context_meta[str(instance_id)] = entry - continue - - if context.non_cuda_metadata is not None: - registered_non_cuda_ids.append(instance_id) - non_cuda_context_meta[str(instance_id)] = { - **entry, - "block_size": context.non_cuda_metadata.block_size, - "use_mla": context.non_cuda_metadata.use_mla, - } - - return { + sm = self._context.storage_manager.report_status() + status: dict = { "is_healthy": sm["is_healthy"], "engine_type": self.__class__.__name__, - "chunk_size": self.chunk_size, - "hash_algorithm": self.token_hasher.hash_algorithm_name, - "registered_gpu_ids": registered_gpu_ids, - "gpu_context_meta": gpu_context_meta, - "registered_non_cuda_instance_ids": registered_non_cuda_ids, - 
"non_cuda_context_meta": non_cuda_context_meta, - "active_sessions": self.session_manager.active_count(), - "active_prefetch_jobs": self._active_prefetch_count(), + "chunk_size": self._context.chunk_size, + "hash_algorithm": self._context.token_hasher.hash_algorithm_name, + "active_sessions": self._context.session_manager.active_count(), "storage_manager": sm, } - - def report_block_allocations( - self, - instance_id: int, - model_name: str, - records: list[BlockAllocationRecord], - ) -> None: - """Publish vLLM block allocation records to the EventBus. - - Args: - instance_id: The scheduler instance ID. - model_name: The model name from the adapter. - records: List of BlockAllocationRecord with per-request - block and token allocation deltas. - """ - self._event_bus.publish( - Event( - event_type=EventType.MP_VLLM_BLOCK_ALLOCATION, - metadata={ - "instance_id": instance_id, - "model_name": model_name, - "records": records, - }, - ) - ) - - def debug(self) -> str: - return "OK" - - def clear(self) -> None: - """ - Clears all stored KV cache data from the storage manager. - """ - with self.lock: - self.storage_manager.memcheck() - self.storage_manager.clear(force=True) - self.storage_manager.memcheck() + for module in self._modules: + status.update(module.report_status()) + return status def close(self) -> None: - """ - Closes the MPCacheEngine and releases all resources. - """ - # Stop the drain thread before storage_manager.close() so any - # in-flight completions reach a live storage manager. 
- self._device_host_func_dispatcher.stop() - - # Close storage manager - self.storage_manager.close() + """Close all modules and release shared resources.""" + for module in self._modules: + module.close() + self._context.storage_manager.close() logger.info("MPCacheEngine closed") - # Release GPU contexts - self.contexts.clear() - - def _active_prefetch_count(self) -> int: - """Return the number of active prefetch jobs (thread-safe).""" - with self._prefetch_job_lock: - return len(self._prefetch_jobs) - - def _setup_metrics(self) -> None: - """Register OTel observable gauges for MP engine metrics.""" - _gauge = partial(register_gauge, "lmcache.mp_engine") - _gauge( - "lmcache_mp.active_prefetch_jobs", - "Number of active prefetch jobs", - self._active_prefetch_count, - ) - def add_handler_helper( server: MessageQueueServer, request_type: RequestType, handler_function ): + """Register a handler with the message queue server. + + Args: + server: The message queue server. + request_type: The request type to handle. + handler_function: The handler callable. + """ payload_classes = get_payload_classes(request_type) handler_type = get_handler_type(request_type) server.add_handler( @@ -1403,124 +121,119 @@ def add_handler_helper( ) +def _build_modules( + ctx: MPCacheEngineContext, + mp_config: MPServerConfig, +) -> list[EngineModule]: + """Assemble the list of engine modules based on configuration. + + Args: + ctx: The shared engine context. + mp_config: Server configuration determining which modules to load. + + Returns: + List of initialized engine modules. + + Raises: + ValueError: If blend engine is requested with non-GPU transfer mode. 
+ """ + modules: list[EngineModule] = [ + LookupModule(ctx), + ManagementModule(ctx), + ] + + if mp_config.transfer_mode == "gpu": + modules.append(GPUTransferModule(ctx)) + else: + modules.append(NonGPUTransferModule(ctx)) + + if mp_config.engine_type == "blend": + if mp_config.transfer_mode != "gpu": + raise ValueError( + "Blend engine requires transfer_mode='gpu', " + f"got '{mp_config.transfer_mode}'" + ) + # First Party + from lmcache.v1.multiprocess.modules.blend import BlendModule + + modules.append(BlendModule(ctx)) + + return modules + + def run_cache_server( mp_config: MPServerConfig, storage_manager_config: StorageManagerConfig, obs_config: ObservabilityConfig, return_engine: bool = False, start_prometheus_http_server: bool = True, -): - """ - Run the LMCache cache server with ZMQ message queue. +) -> tuple[MessageQueueServer, MPCacheEngine] | None: + """Run the LMCache cache server with ZMQ message queue. Args: - mp_config: Configuration for the ZMQ multiprocess server - storage_manager_config: Configuration for the storage manager - obs_config: Configuration for the observability stack + mp_config: Configuration for the ZMQ multiprocess server. + storage_manager_config: Configuration for the storage manager. + obs_config: Configuration for the observability stack. return_engine: If True, return (server, engine) after starting; - if False, run blocking loop to keep server alive + if False, run blocking loop to keep server alive. start_prometheus_http_server: Whether to start a standalone Prometheus HTTP server in a background thread. Set to ``False`` when an external HTTP framework already serves ``/metrics`` to avoid port conflicts or redundant servers. Returns: - If return_engine is True: tuple of (MessageQueueServer, MPCacheEngine) - If return_engine is False: None (blocks until interrupted) + If return_engine is True: tuple of (MessageQueueServer, MPCacheEngine). + If return_engine is False: None (blocks until interrupted). 
""" event_bus = init_observability( obs_config, start_prometheus_http_server=start_prometheus_http_server ) - # Wire up the trace recorder (no-op when --trace-level is unset). - # Registered before the engine handlers are added so any - # storage-manager calls during engine init are captured too. maybe_initialize_trace_recorder(event_bus, obs_config, storage_manager_config) - # Initialize the engine (loggers self-register with the global controller) - engine = MPCacheEngine( + ctx = MPCacheEngineContext( storage_manager_config=storage_manager_config, chunk_size=mp_config.chunk_size, hash_algorithm=mp_config.hash_algorithm, ) - # Initialize the message queue server - context = zmq.Context.instance() + modules = _build_modules(ctx, mp_config) + engine = MPCacheEngine(ctx, modules) + + zmq_context = zmq.Context.instance() server = MessageQueueServer( bind_url=f"tcp://{mp_config.host}:{mp_config.port}", - context=context, + context=zmq_context, ) - # Add handlers - add_handler_helper(server, RequestType.REGISTER_KV_CACHE, engine.register_kv_cache) - add_handler_helper( - server, RequestType.UNREGISTER_KV_CACHE, engine.unregister_kv_cache - ) - add_handler_helper(server, RequestType.STORE, engine.store) - add_handler_helper( - server, - RequestType.REGISTER_KV_CACHE_NON_GPU_CONTEXT, - engine.register_kv_cache_non_gpu_context, - ) - add_handler_helper(server, RequestType.PREPARE_STORE, engine.prepare_store) - add_handler_helper(server, RequestType.LOOKUP, engine.lookup) - add_handler_helper( - server, RequestType.QUERY_PREFETCH_STATUS, engine.query_prefetch_status - ) - add_handler_helper( - server, - RequestType.QUERY_PREFETCH_LOOKUP_HITS, - engine.query_prefetch_lookup_hits, - ) - add_handler_helper(server, RequestType.FREE_LOOKUP_LOCKS, engine.free_lookup_locks) - add_handler_helper(server, RequestType.RETRIEVE, engine.retrieve) - add_handler_helper(server, RequestType.COMMIT_STORE, engine.commit_store) - add_handler_helper(server, RequestType.PREPARE_RETRIEVE, 
engine.prepare_retrieve) - add_handler_helper(server, RequestType.COMMIT_RETRIEVE, engine.commit_retrieve) - add_handler_helper(server, RequestType.CLEAR, engine.clear) - add_handler_helper(server, RequestType.GET_CHUNK_SIZE, engine.get_chunk_size) - add_handler_helper(server, RequestType.PING, engine.ping) - add_handler_helper(server, RequestType.END_SESSION, engine.end_session) - add_handler_helper(server, RequestType.NOOP, engine.debug) - add_handler_helper( - server, - RequestType.REPORT_BLOCK_ALLOCATION, - engine.report_block_allocations, - ) + all_specs: list[HandlerSpec] = [] + for module in modules: + all_specs.extend(module.get_handlers()) - # Assign thread pools - server.add_affinity_thread_pool( - [ - RequestType.STORE, - RequestType.RETRIEVE, - RequestType.PREPARE_STORE, - RequestType.COMMIT_STORE, - RequestType.PREPARE_RETRIEVE, - RequestType.COMMIT_RETRIEVE, - ], - max_workers=mp_config.max_gpu_workers, - ) - server.add_normal_thread_pool( - [ - RequestType.LOOKUP, - RequestType.QUERY_PREFETCH_STATUS, - RequestType.QUERY_PREFETCH_LOOKUP_HITS, - RequestType.FREE_LOOKUP_LOCKS, - RequestType.END_SESSION, - RequestType.CLEAR, - RequestType.PING, - RequestType.REPORT_BLOCK_ALLOCATION, - ], - max_workers=mp_config.max_cpu_workers, - ) + for spec in all_specs: + add_handler_helper(server, spec.request_type, spec.handler) + + affinity_types = [ + s.request_type for s in all_specs if s.pool == ThreadPoolType.AFFINITY + ] + normal_types = [ + s.request_type for s in all_specs if s.pool == ThreadPoolType.NORMAL + ] + if affinity_types: + server.add_affinity_thread_pool( + affinity_types, max_workers=mp_config.max_gpu_workers + ) + if normal_types: + server.add_normal_thread_pool( + normal_types, max_workers=mp_config.max_cpu_workers + ) logger.info( "LMCache ZMQ cache server is running on tcp://%s:%d", mp_config.host, mp_config.port, ) - # Start the ZMQ server - # Not all backends expose init(); some auto-initialize on first use + if not hasattr(torch_dev, 
"init"): logger.warning( "Backend '%s' does not support init(), skipping device init", @@ -1532,11 +245,9 @@ def run_cache_server( logger.info("LMCache cache server is running...") - # Return server and engine if requested (for HTTP server integration) if return_engine: return server, engine - # Dummy loop to keep the server running try: while True: time.sleep(1) @@ -1545,9 +256,15 @@ def run_cache_server( event_bus.stop() server.close() engine.close() + return None def parse_args(): + """Parse command line arguments for the cache server. + + Returns: + Parsed arguments namespace. + """ parser = argparse.ArgumentParser( description="LMCache ZMQ Cache Server (without HTTP)" ) diff --git a/tests/v1/multiprocess/test_blend_server_v2.py b/tests/v1/multiprocess/test_blend_server_v2.py index bdc91e68e17..715711df3f8 100644 --- a/tests/v1/multiprocess/test_blend_server_v2.py +++ b/tests/v1/multiprocess/test_blend_server_v2.py @@ -1,13 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 """ -Unit and integration tests for BlendTokenRangeMatcher and BlendEngineV2. +Unit and integration tests for BlendTokenRangeMatcher and BlendModule. Structure --------- Part 1 – BlendTokenRangeMatcher (pure unit tests, no GPU/server needed) Tests the rolling-hash sub-sequence matching logic in isolation. -Part 2 – BlendEngineV2 integration tests (two-process ZMQ architecture) +Part 2 – Blend integration tests (two-process ZMQ architecture) Uses CB_LOOKUP_PRE_COMPUTED_V2 / CB_RETRIEVE_PRE_COMPUTED_V2, which return/accept list[CBMatchResult] instead of list[tuple[int, int]]. 
@@ -42,17 +42,17 @@ StorageManagerConfig, ) from lmcache.v1.mp_observability.config import DEFAULT_OBSERVABILITY_CONFIG -from lmcache.v1.multiprocess.blend_server_v2 import ( - BlendEngineV2, - BlendTokenRangeMatcher, - _unique_token_coverage, -) from lmcache.v1.multiprocess.custom_types import ( CBMatchResult, CudaIPCWrapper, IPCCacheEngineKey, KVCache, ) +from lmcache.v1.multiprocess.modules.blend import ( + BlendModule, + BlendTokenRangeMatcher, + _unique_token_coverage, +) from lmcache.v1.multiprocess.mq import MessageQueueClient from lmcache.v1.multiprocess.protocol import ( RequestType, @@ -692,17 +692,17 @@ def expected_full_chunks(num_tokens: int, chunk_size: int = CHUNK_SIZE) -> int: # ============================================================================= -# Server Process Runner (BlendEngineV2) +# Server Process Runner (Blend Mode) # ============================================================================= def server_process_runner_v2( host: str, port: int, chunk_size: int, cpu_buffer_size: float ): - """Entry point for the server process running BlendEngineV2.""" + """Entry point for the server process running blend mode.""" # First Party - from lmcache.v1.multiprocess.blend_server_v2 import run_cache_server from lmcache.v1.multiprocess.config import MPServerConfig + from lmcache.v1.multiprocess.server import run_cache_server mp_config = MPServerConfig( host=host, @@ -733,7 +733,7 @@ def server_process_runner_v2( @pytest.fixture(scope="module") def server_process() -> Generator[mp.Process, None, None]: - """Start the BlendEngineV2 server in a separate process for the module.""" + """Start the blend mode server in a separate process for the module.""" mp.set_start_method("spawn", force=True) process = mp.Process( target=server_process_runner_v2, @@ -855,7 +855,7 @@ def registered_instance( # ============================================================================= -# Part 2: BlendEngineV2 Integration Tests +# Part 2: Blend Integration Tests 
# ============================================================================= # --------------------------------------------------------------------------- @@ -865,7 +865,7 @@ def registered_instance( def test_server_running_v2(server_process: mp.Process): """Server process should be alive.""" - assert server_process.is_alive(), "BlendEngineV2 server process should be running" + assert server_process.is_alive(), "Blend server process should be running" @pytest.mark.skipif( @@ -2001,19 +2001,28 @@ def test_cb_store_final_v2_visible_to_cb_lookup_v2( # 8. report_status() – CB-extended fields # --------------------------------------------------------------------------- # -# These tests construct a BlendEngineV2 directly in the test process because -# report_status() is invoked in-process by the FastAPI /api/status handler, -# not over ZMQ. CUDA IPC unwrapping is bypassed via a monkeypatched -# unwrap_kv_cache_tensors so the public cb_register_kv_cache path can run -# without a cross-process producer. +# These tests construct an MPCacheEngine with a BlendModule directly in the +# test process because report_status() is invoked in-process by the FastAPI +# /api/status handler, not over ZMQ. CUDA IPC unwrapping is bypassed via a +# monkeypatched unwrap_kv_cache_tensors so the public cb_register_kv_cache +# path can run without a cross-process producer. @pytest.fixture(scope="function") -def in_process_blend_engine() -> Generator[BlendEngineV2, None, None]: - """Build a BlendEngineV2 in-process for direct method-call tests.""" +def in_process_blend_engine() -> Generator[tuple, None, None]: + """Build an MPCacheEngine with BlendModule in-process for direct tests. + + Yields: + (engine, blend_module) tuple. 
+ """ if not torch.cuda.is_available(): pytest.skip("CUDA is not available") + # First Party + from lmcache.v1.multiprocess.engine_context import MPCacheEngineContext + from lmcache.v1.multiprocess.modules.gpu_transfer import GPUTransferModule + from lmcache.v1.multiprocess.server import MPCacheEngine + config = StorageManagerConfig( l1_manager_config=L1ManagerConfig( memory_config=L1MemoryManagerConfig( @@ -2023,12 +2032,15 @@ def in_process_blend_engine() -> Generator[BlendEngineV2, None, None]: ), eviction_config=EvictionConfig(eviction_policy="LRU"), ) - engine = BlendEngineV2( + ctx = MPCacheEngineContext( storage_manager_config=config, chunk_size=CHUNK_SIZE, ) + gpu_module = GPUTransferModule(ctx) + blend_module = BlendModule(ctx) + engine = MPCacheEngine(ctx, [gpu_module, blend_module]) try: - yield engine + yield engine, blend_module finally: engine.close() @@ -2055,10 +2067,11 @@ def _local_unwrap(_kv_caches): not torch.cuda.is_available(), reason="report_status CB tests require CUDA" ) def test_report_status_no_cb_registrations( - in_process_blend_engine: BlendEngineV2, + in_process_blend_engine: tuple, ): """Without any CB registration, the new fields are present and empty.""" - status = in_process_blend_engine.report_status() + engine, _blend_module = in_process_blend_engine + status = engine.report_status() assert "registered_cb_gpu_ids" in status assert "cb_gpu_context_meta" in status @@ -2070,15 +2083,15 @@ def test_report_status_no_cb_registrations( not torch.cuda.is_available(), reason="report_status CB tests require CUDA" ) def test_report_status_surfaces_cb_registration( - in_process_blend_engine: BlendEngineV2, + in_process_blend_engine: tuple, cb_client_context: CBClientContext, local_kv_cache_unwrap, ): """A CB-registered instance shows up in both new fields with correct shape.""" - engine = in_process_blend_engine + engine, blend_module = in_process_blend_engine instance_id = 4242 - engine.cb_register_kv_cache( + 
blend_module.cb_register_kv_cache( instance_id, cb_client_context.get_kv_cache(), "testmodel", @@ -2110,21 +2123,21 @@ def test_report_status_surfaces_cb_registration( not torch.cuda.is_available(), reason="report_status CB tests require CUDA" ) def test_report_status_unregister_clears_cb_fields( - in_process_blend_engine: BlendEngineV2, + in_process_blend_engine: tuple, cb_client_context: CBClientContext, local_kv_cache_unwrap, ): """Unregistering removes the entry from both new fields.""" - engine = in_process_blend_engine + engine, blend_module = in_process_blend_engine instance_id = 4243 - engine.cb_register_kv_cache( + blend_module.cb_register_kv_cache( instance_id, cb_client_context.get_kv_cache(), "testmodel", 1, ) - engine.cb_unregister_kv_cache(instance_id) + blend_module.cb_unregister_kv_cache(instance_id) status = engine.report_status() assert status["registered_cb_gpu_ids"] == [] diff --git a/tests/v1/multiprocess/test_cache_server.py b/tests/v1/multiprocess/test_cache_server.py index 61818b89536..62c84b94d92 100644 --- a/tests/v1/multiprocess/test_cache_server.py +++ b/tests/v1/multiprocess/test_cache_server.py @@ -38,7 +38,7 @@ SERVER_URL = f"tcp://{SERVER_HOST}:{SERVER_PORT}" CHUNK_SIZE = 256 CPU_BUFFER_SIZE = 5.0 -DEFAULT_TIMEOUT = 5.0 +DEFAULT_TIMEOUT = 10.0 def _has_working_new_shared_cuda() -> bool: diff --git a/tests/v1/multiprocess/test_free_locks.py b/tests/v1/multiprocess/test_free_locks.py index 4ba1936df45..e08a7911513 100644 --- a/tests/v1/multiprocess/test_free_locks.py +++ b/tests/v1/multiprocess/test_free_locks.py @@ -88,42 +88,42 @@ def test_mq_free_locks(): def test_server_free_lookup_locks_calls_finish_read_prefetched(): - """MPCacheEngine.free_lookup_locks should resolve hash keys and call + """LookupModule.free_lookup_locks should resolve hash keys and call finish_read_prefetched on the storage manager.""" # First Party - from lmcache.v1.multiprocess.server import MPCacheEngine + from lmcache.v1.multiprocess.modules.lookup import 
LookupModule - engine = MagicMock() - engine.token_hasher = MagicMock() - engine.token_hasher.chunk_size = 256 - engine.token_hasher.compute_chunk_hashes.return_value = [b"hash0"] + ctx = MagicMock() + ctx.token_hasher.chunk_size = 256 + ctx.token_hasher.compute_chunk_hashes.return_value = [b"hash0"] + + module = LookupModule(ctx) # Build a key key = create_cache_key(0).no_worker_id_version() sentinel_obj_keys = [MagicMock()] with patch( - "lmcache.v1.multiprocess.server.ipc_key_to_object_keys", + "lmcache.v1.multiprocess.modules.lookup.ipc_key_to_object_keys", return_value=sentinel_obj_keys, ): - # Call the real method on the mock - MPCacheEngine.free_lookup_locks(engine, key, 1) + module.free_lookup_locks(key, 1) - engine.storage_manager.finish_read_prefetched.assert_called_once_with( + module.context.storage_manager.finish_read_prefetched.assert_called_once_with( sentinel_obj_keys, extra_count=0 ) def test_server_free_lookup_locks_no_matching_chunks(): - """MPCacheEngine.free_lookup_locks with no chunks in range should be a no-op.""" + """LookupModule.free_lookup_locks with no chunks in range should be a no-op.""" # First Party - from lmcache.v1.multiprocess.server import MPCacheEngine + from lmcache.v1.multiprocess.modules.lookup import LookupModule + + ctx = MagicMock() + ctx.token_hasher.chunk_size = 256 + ctx.token_hasher.compute_chunk_hashes.return_value = [] - engine = MagicMock() - engine.token_hasher = MagicMock() - engine.token_hasher.chunk_size = 256 - # start=end=0 is passed to compute_chunk_hashes, which returns no hashes - engine.token_hasher.compute_chunk_hashes.return_value = [] + module = LookupModule(ctx) # Key with start == end means no chunks to free key = IPCCacheEngineKey( @@ -136,19 +136,18 @@ def test_server_free_lookup_locks_no_matching_chunks(): request_id="req-empty", ) - MPCacheEngine.free_lookup_locks(engine, key, 1) + module.free_lookup_locks(key, 1) - engine.storage_manager.finish_read_prefetched.assert_not_called() + 
module.context.storage_manager.finish_read_prefetched.assert_not_called() def test_server_handler_registered(): - """run_cache_server should register a FREE_LOOKUP_LOCKS handler.""" + """LookupModule should have a free_lookup_locks method.""" # First Party - from lmcache.v1.multiprocess.server import MPCacheEngine + from lmcache.v1.multiprocess.modules.lookup import LookupModule - engine = MPCacheEngine.__new__(MPCacheEngine) - assert hasattr(engine, "free_lookup_locks") - assert callable(engine.free_lookup_locks) + assert hasattr(LookupModule, "free_lookup_locks") + assert callable(LookupModule.free_lookup_locks) # ============================================================================ diff --git a/tests/v1/multiprocess/test_non_cuda_data_transfer.py b/tests/v1/multiprocess/test_non_cuda_data_transfer.py index f8a281a8785..a9b246a1dc6 100644 --- a/tests/v1/multiprocess/test_non_cuda_data_transfer.py +++ b/tests/v1/multiprocess/test_non_cuda_data_transfer.py @@ -357,16 +357,18 @@ def test_server_register_and_find_non_cuda_context_layout( """Ensure non-CUDA registration stores metadata and lookup finds layout.""" # First Party from lmcache.v1.multiprocess.custom_types import RegisterNonGpuContextPayload - from lmcache.v1.multiprocess.server import MPCacheEngine + from lmcache.v1.multiprocess.engine_context import MPCacheEngineContext + from lmcache.v1.multiprocess.modules.non_gpu_transfer import NonGPUTransferModule with ( - patch("lmcache.v1.multiprocess.server.StorageManager"), - patch("lmcache.v1.multiprocess.server.TokenHasher"), - patch("lmcache.v1.multiprocess.server.SessionManager"), - patch("lmcache.v1.multiprocess.server.get_event_bus"), + patch("lmcache.v1.multiprocess.engine_context.StorageManager"), + patch("lmcache.v1.multiprocess.engine_context.TokenHasher"), + patch("lmcache.v1.multiprocess.engine_context.SessionManager"), + patch("lmcache.v1.multiprocess.engine_context.get_event_bus"), ): - engine = 
MPCacheEngine(storage_manager_config=MagicMock(), chunk_size=16) - engine.register_kv_cache_non_gpu_context( + ctx = MPCacheEngineContext(storage_manager_config=MagicMock(), chunk_size=16) + module = NonGPUTransferModule(ctx) + module.register_kv_cache_non_gpu_context( RegisterNonGpuContextPayload( instance_id=1, model_name="m", @@ -379,7 +381,7 @@ def test_server_register_and_find_non_cuda_context_layout( ) ) - layout = engine._find_layout_desc("m", 1) + layout = ctx.layout_desc_registry.find("m", 1) assert layout is not None assert layout.shapes[0] == torch.Size([2, 2, 16, 16]) @@ -391,7 +393,8 @@ def test_server_store_and_retrieve_cpu_chunks(stub_native_storage_ops: Any) -> N IPCCacheEngineKey, RegisterNonGpuContextPayload, ) - from lmcache.v1.multiprocess.server import MPCacheEngine + from lmcache.v1.multiprocess.engine_context import MPCacheEngineContext + from lmcache.v1.multiprocess.modules.non_gpu_transfer import NonGPUTransferModule mock_storage = MagicMock() target_tensor = torch.zeros(2, 2, 8, 16) @@ -408,21 +411,22 @@ def _read_prefetched_results(_keys: Any) -> Any: mock_session.get_hashes.return_value = [b"h"] with ( patch( - "lmcache.v1.multiprocess.server.StorageManager", + "lmcache.v1.multiprocess.engine_context.StorageManager", return_value=mock_storage, ), - patch("lmcache.v1.multiprocess.server.TokenHasher"), - patch("lmcache.v1.multiprocess.server.SessionManager") as session_cls, - patch("lmcache.v1.multiprocess.server.get_event_bus"), + patch("lmcache.v1.multiprocess.engine_context.TokenHasher"), + patch("lmcache.v1.multiprocess.engine_context.SessionManager") as session_cls, + patch("lmcache.v1.multiprocess.engine_context.get_event_bus"), patch( - "lmcache.v1.multiprocess.server.ipc_key_to_object_keys", + "lmcache.v1.multiprocess.engine_context.ipc_key_to_object_keys", return_value=["obj"], ), ): session_cls.return_value.get_or_create.return_value = mock_session - engine = MPCacheEngine(storage_manager_config=MagicMock(), chunk_size=8) + ctx = 
MPCacheEngineContext(storage_manager_config=MagicMock(), chunk_size=8) - engine.register_kv_cache_non_gpu_context( + module = NonGPUTransferModule(ctx) + module.register_kv_cache_non_gpu_context( RegisterNonGpuContextPayload( instance_id=2, model_name="m", @@ -445,11 +449,11 @@ def _read_prefetched_results(_keys: Any) -> Any: request_id="req", ) with patch( - "lmcache.v1.multiprocess.server.ipc_key_to_object_keys", + "lmcache.v1.multiprocess.engine_context.ipc_key_to_object_keys", return_value=["obj"], ): - store_ok = engine.commit_store(key, 2, pickle.dumps([payload])) - response = engine.prepare_retrieve(key, 2) + store_ok = module.commit_store(key, 2, pickle.dumps([payload])) + response = module.prepare_retrieve(key, 2) success = response.success cpu_data = response.data assert isinstance(store_ok, bool) diff --git a/tests/v1/multiprocess/test_query_lookup_hits.py b/tests/v1/multiprocess/test_query_lookup_hits.py index 53a1d85cbc2..387f1b5061b 100644 --- a/tests/v1/multiprocess/test_query_lookup_hits.py +++ b/tests/v1/multiprocess/test_query_lookup_hits.py @@ -6,11 +6,11 @@ # Standard from unittest.mock import MagicMock -import threading import time # First Party from lmcache.v1.distributed.storage_manager import PrefetchHandle +from lmcache.v1.multiprocess.modules.lookup import LookupModule, _PrefetchJob from lmcache.v1.multiprocess.protocol import ( RequestType, get_handler_type, @@ -18,7 +18,6 @@ get_response_class, ) from lmcache.v1.multiprocess.protocols.base import HandlerType -from lmcache.v1.multiprocess.server import MPCacheEngine, _PrefetchJob # Test helpers from tests.v1.multiprocess.test_mq import ( @@ -107,16 +106,17 @@ def test_mq_query_prefetch_lookup_hits_none_response(): # ============================================================================ -def _make_engine_with_job( +def _make_module_with_job( world_size: int, storage_return: int | None -) -> tuple[MagicMock, str]: - """Create a mock MPCacheEngine with a single prefetch job. 
+) -> tuple[LookupModule, str]: + """Create a LookupModule with a mock context and a single prefetch job. Returns: - (engine_mock, request_id) + (module, request_id) """ - engine = MagicMock() - engine._prefetch_job_lock = threading.Lock() + ctx = MagicMock() + ctx.token_hasher.chunk_size = 256 + module = LookupModule(ctx) handle = PrefetchHandle( prefetch_request_id=0, @@ -132,50 +132,50 @@ def _make_engine_with_job( request_id=request_id, requested_tokens=0, ) - engine._prefetch_jobs = {request_id: job} - engine.storage_manager.query_prefetch_lookup_hits.return_value = storage_return + module._prefetch_jobs[request_id] = job + ctx.storage_manager.query_prefetch_lookup_hits.return_value = storage_return - return engine, request_id + return module, request_id def test_server_lookup_hits_returns_count(): """query_prefetch_lookup_hits returns chunk count when lookup is done.""" - engine, request_id = _make_engine_with_job(world_size=1, storage_return=5) + module, request_id = _make_module_with_job(world_size=1, storage_return=5) - result = MPCacheEngine.query_prefetch_lookup_hits(engine, request_id) + result = module.query_prefetch_lookup_hits(request_id) assert result == 5 - engine.storage_manager.query_prefetch_lookup_hits.assert_called_once() + module.context.storage_manager.query_prefetch_lookup_hits.assert_called_once() def test_server_lookup_hits_divides_by_world_size(): """Result should be divided by world_size for tensor parallelism.""" - engine, request_id = _make_engine_with_job(world_size=2, storage_return=10) + module, request_id = _make_module_with_job(world_size=2, storage_return=10) - result = MPCacheEngine.query_prefetch_lookup_hits(engine, request_id) + result = module.query_prefetch_lookup_hits(request_id) assert result == 5 # 10 // 2 def test_server_lookup_hits_returns_none_when_in_progress(): """Returns None when storage manager lookup is still in progress.""" - engine, request_id = _make_engine_with_job(world_size=1, storage_return=None) + 
module, request_id = _make_module_with_job(world_size=1, storage_return=None) - result = MPCacheEngine.query_prefetch_lookup_hits(engine, request_id) + result = module.query_prefetch_lookup_hits(request_id) assert result is None def test_server_lookup_hits_returns_zero_for_invalid_request(): """Returns 0 for a request_id that doesn't exist (prevents infinite spin).""" - engine = MagicMock() - engine._prefetch_job_lock = threading.Lock() - engine._prefetch_jobs = {} + ctx = MagicMock() + ctx.token_hasher.chunk_size = 256 + module = LookupModule(ctx) - result = MPCacheEngine.query_prefetch_lookup_hits(engine, "nonexistent-req") + result = module.query_prefetch_lookup_hits("nonexistent-req") assert result == 0 - engine.storage_manager.query_prefetch_lookup_hits.assert_not_called() + ctx.storage_manager.query_prefetch_lookup_hits.assert_not_called() def test_server_lookup_hits_returns_zero_after_prefetch_consumed(): @@ -183,28 +183,27 @@ def test_server_lookup_hits_returns_zero_after_prefetch_consumed(): This prevents the caller from spinning forever on a completed request. 
""" - engine, request_id = _make_engine_with_job(world_size=1, storage_return=5) + module, request_id = _make_module_with_job(world_size=1, storage_return=5) # Simulate query_prefetch_status consuming the job - del engine._prefetch_jobs[request_id] + module._prefetch_jobs.pop(request_id) - result = MPCacheEngine.query_prefetch_lookup_hits(engine, request_id) + result = module.query_prefetch_lookup_hits(request_id) assert result == 0 def test_server_lookup_hits_zero_count(): """Returns 0 when no keys matched (not None).""" - engine, request_id = _make_engine_with_job(world_size=1, storage_return=0) + module, request_id = _make_module_with_job(world_size=1, storage_return=0) - result = MPCacheEngine.query_prefetch_lookup_hits(engine, request_id) + result = module.query_prefetch_lookup_hits(request_id) assert result == 0 assert result is not None def test_server_handler_registered(): - """MPCacheEngine should have a query_prefetch_lookup_hits method.""" - engine = MPCacheEngine.__new__(MPCacheEngine) - assert hasattr(engine, "query_prefetch_lookup_hits") - assert callable(engine.query_prefetch_lookup_hits) + """LookupModule should have a query_prefetch_lookup_hits method.""" + assert hasattr(LookupModule, "query_prefetch_lookup_hits") + assert callable(LookupModule.query_prefetch_lookup_hits)