From 0cf4d717b58213df93472f2dbddaade300e14c6f Mon Sep 17 00:00:00 2001 From: "ziyi.wang" Date: Tue, 12 May 2026 07:01:53 +0800 Subject: [PATCH 1/5] feat: add prefix kv cache --- benchmarks/throughput.py | 53 ++- .../test_kv_cache_manager.py | 358 ++++++++++++++++++ tests/test_scheduler.py | 14 +- .../batch_invariant_ops/kv_cache_context.py | 291 +++++++++++--- vexact/core/request.py | 7 + vexact/core/scheduler.py | 104 ++++- vexact/engine.py | 11 + vexact/worker/driver_client.py | 9 + vexact/worker/driver_worker.py | 58 ++- vexact/worker/worker.py | 5 + vexact/worker/worker_proxy.py | 4 +- 11 files changed, 843 insertions(+), 71 deletions(-) create mode 100644 tests/batch_invariant_ops/test_kv_cache_manager.py diff --git a/benchmarks/throughput.py b/benchmarks/throughput.py index 12623a4..8302833 100644 --- a/benchmarks/throughput.py +++ b/benchmarks/throughput.py @@ -108,7 +108,7 @@ def _load_sharegpt_samples( return samples if wants_all else samples[:num_requests] -async def _run_test(vexact_engine, samples, timeout_s: float | None): +async def _run_test(vexact_engine, samples, timeout_s: float | None, system_prompt_ids: list[int] | None = None): total_prompt_tokens = 0 total_output_tokens = 0 latencies = [] @@ -120,6 +120,8 @@ async def submit_one(sample): prompt = sample.prompt try: input_ids = vexact_engine.tokenizer.encode(prompt, add_special_tokens=True) + if system_prompt_ids: + input_ids = list(system_prompt_ids) + input_ids gen_config = GenerationConfig( max_new_tokens=sample.expected_output_len, max_length=vexact_engine.config.model.max_model_len, @@ -150,6 +152,21 @@ async def submit_one(sample): return total_prompt_tokens, total_output_tokens, latencies, errors, completed, total_time +def _build_synthetic_system_prompt(tokenizer, target_len: int) -> list[int]: + """Build a deterministic token-id sequence of exactly target_len tokens. + + Used by --system-prompt-len to prepend a shared prefix to every request so + a prefix-cache implementation can be exercised. We tile a fixed instruction + string and truncate; exact text doesn't matter since cache hits are by + token-id match. + """ + base = "You are a helpful assistant. Follow the user's instructions carefully and answer concisely. " + ids: list[int] = [] + while len(ids) < target_len: + ids.extend(tokenizer.encode(base, add_special_tokens=False)) + return ids[:target_len] + + def run_throughput( vexact_engine, total_requests: int | None, @@ -157,6 +174,7 @@ def run_throughput( dataset_path: str, timeout_s: float | None, include_multimodal: bool = False, + system_prompt_len: int = 0, ): tokenizer = vexact_engine.tokenizer samples = _load_sharegpt_samples( @@ -167,10 +185,12 @@ def run_throughput( include_multimodal=include_multimodal, ) - total_prompt_tokens, total_output_tokens, latencies, errors, completed, total_time = asyncio.run( - _run_test(vexact_engine, samples, timeout_s) - ) - return total_prompt_tokens, total_output_tokens, latencies, errors, completed, total_time + system_prompt_ids: list[int] | None = None + if system_prompt_len > 0: + system_prompt_ids = _build_synthetic_system_prompt(tokenizer, system_prompt_len) + print(f"Prepending {system_prompt_len}-token shared prefix to every request") + + return asyncio.run(_run_test(vexact_engine, samples, timeout_s, system_prompt_ids=system_prompt_ids)) def _parse_args(): @@ -252,6 +272,13 @@ def _parse_args(): default=0, help="Number of steps to profile (0 = until manually stopped, default: 0).", ) + parser.add_argument( + "--system-prompt-len", + type=int, + default=0, + help="If > 0, prepend a deterministic shared prefix of this many tokens to " + "every request. Useful for measuring prefix-cache hit benefit.", + ) return parser.parse_args() @@ -281,7 +308,10 @@ def main(): dataset_path=args.dataset_path, timeout_s=args.timeout_s, include_multimodal=args.include_multimodal, + system_prompt_len=args.system_prompt_len, ) + # Snapshot before close() — stats live in the worker process. + prefix_cache_stats = engine.get_prefix_cache_stats() finally: engine.close() @@ -310,6 +340,19 @@ def main(): print(f"Avg latency: {avg_latency:.3f}s") print(f"P50 latency: {p50:.3f}s") print(f"P95 latency: {p95:.3f}s") + if args.system_prompt_len > 0: + print(f"Shared prefix: {args.system_prompt_len} tokens prepended to every request") + if prefix_cache_stats.get("prefix_cache_enabled"): + hit = prefix_cache_stats["hit_tokens"] + miss = prefix_cache_stats["miss_tokens"] + ratio = prefix_cache_stats["hit_ratio"] + print( + f"Prefix cache: {hit}/{hit + miss} tokens hit ({ratio * 100:.1f}%), " + f"cached_blocks={prefix_cache_stats['cached_blocks']}, " + f"free_blocks={prefix_cache_stats['free_blocks']}" + ) + else: + print("Prefix cache: disabled") if __name__ == "__main__": diff --git a/tests/batch_invariant_ops/test_kv_cache_manager.py b/tests/batch_invariant_ops/test_kv_cache_manager.py new file mode 100644 index 0000000..acbbdc9 --- /dev/null +++ b/tests/batch_invariant_ops/test_kv_cache_manager.py @@ -0,0 +1,358 @@ +# Copyright 2026 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for KVCacheManager: refcounting, LRU eviction, and prefix cache plan/commit/mark.""" + +import pytest + +from vexact.batch_invariant_ops.kv_cache_context import KVCacheManager +from vexact.config import CacheConfig + + +@pytest.fixture +def mgr(): + def _build(page_size: int = 4, max_blocks: int = 8, enable_prefix_cache: bool = True): + return KVCacheManager( + CacheConfig(page_size=page_size, max_cache_blocks=max_blocks), + enable_prefix_cache=enable_prefix_cache, + ) + + return _build + + +# ---------- construction & public surface ---------- + + +def test_construction_exposes_public_attrs(mgr): + m = mgr(page_size=8, max_blocks=16) + assert m.prefix_cache_enabled is True + assert m.page_size == 8 + assert m.num_free_blocks() == 16 + assert m.num_allocated_blocks() == 0 + assert m.num_cached_blocks() == 0 + + +def test_construction_disabled(mgr): + m = mgr(enable_prefix_cache=False) + assert m.prefix_cache_enabled is False + + +# ---------- plan_prefix_cache ---------- + + +def test_plan_empty_tokens(mgr): + m = mgr() + assert m.plan_prefix_cache([]) == ([], 0) + + +def test_plan_only_partial_block(mgr): + # page_size=4, 3 tokens → 0 full blocks + m = mgr() + assert m.plan_prefix_cache([1, 2, 3]) == ([], 0) + + +def test_plan_when_disabled(mgr): + m = mgr(enable_prefix_cache=False) + assert m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8]) == ([], 0) + + +def test_plan_cold_cache_returns_hashes_but_zero_hits(mgr): + m = mgr() + hashes, n_cached = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8]) + assert len(hashes) == 2 + assert n_cached == 0 + + +def test_chain_hash_diverges_on_content_difference(mgr): + m = mgr() + h_a, _ = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8]) + h_b, _ = m.plan_prefix_cache([1, 2, 3, 4, 9, 9, 9, 9]) + assert h_a[0] == h_b[0] # same first block content + assert h_a[1] != h_b[1] # diverged from block 1 onward + + +# ---------- commit_prefix_plan ---------- + + +def test_commit_cold_allocates_fresh(mgr): + m = mgr() + hashes, n_cached = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8]) + bids = m.commit_prefix_plan(hashes, n_cached, 8) + assert bids == [0, 1] + assert m.num_allocated_blocks() == 2 + + +def test_commit_partial_only_takes_one_block(mgr): + m = mgr() + hashes, n_cached = m.plan_prefix_cache([1, 2, 3]) + bids = m.commit_prefix_plan(hashes, n_cached, 3) + assert bids == [0] + assert m.num_allocated_blocks() == 1 + + +def test_commit_full_plus_partial(mgr): + # 1 full + 1 partial = 2 blocks total + m = mgr() + hashes, n_cached = m.plan_prefix_cache([1, 2, 3, 4, 5]) + assert len(hashes) == 1 + bids = m.commit_prefix_plan(hashes, n_cached, 5) + assert len(bids) == 2 + + +def test_commit_oom_rollback_keeps_pool_intact(mgr): + m = mgr(max_blocks=2) + hashes, n_cached = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) + assert m.commit_prefix_plan(hashes, n_cached, 12) is None + assert m.num_free_blocks() == 2 + assert m.num_allocated_blocks() == 0 + + +def test_commit_oom_releases_cached_increfs(mgr): + # Cached blocks incref'd during commit must be decref'd back on OOM, not stuck. + m = mgr(max_blocks=2) + # First request fills the cache index (2 full blocks). + h1, n1 = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8]) + bids1 = m.commit_prefix_plan(h1, n1, 8) + m.mark_blocks_filled(bids1, h1) + m.free_blocks(bids1) + + # Second request: same prefix (full hit) + a partial last block. OOM because + # the partial needs a fresh block but the only 2 blocks just got refcounted. + h2, n2 = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8, 99, 99, 99]) + assert n2 == 2 + assert m.commit_prefix_plan(h2, n2, 11) is None + assert m.num_free_blocks() == 2 + assert m.num_allocated_blocks() == 0 + # Cache index untouched by the failed commit + assert m.num_cached_blocks() == 2 + + +# ---------- mark_blocks_filled ---------- + + +def test_mark_blocks_filled_records_full_blocks_only(mgr): + m = mgr() + hashes, n_cached = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7]) # 1 full + partial + bids = m.commit_prefix_plan(hashes, n_cached, 7) + m.mark_blocks_filled(bids, hashes) + assert m.num_cached_blocks() == 1 # only the full block stamped + + +def test_mark_blocks_filled_idempotent(mgr): + m = mgr() + hashes, n_cached = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8]) + bids = m.commit_prefix_plan(hashes, n_cached, 8) + + m.mark_blocks_filled(bids, hashes) + state1 = (dict(m._block_hash_to_id), dict(m._block_id_to_hash)) + m.mark_blocks_filled(bids, hashes) + m.mark_blocks_filled(bids, hashes) + state2 = (dict(m._block_hash_to_id), dict(m._block_id_to_hash)) + assert state1 == state2 + + +def test_mark_blocks_filled_stamps_misses_after_partial_hit(mgr): + # Half-hit: first block hits cache, second block is fresh. The fresh block + # must end up correctly stamped using the precomputed chain hash. + m = mgr() + h1, n1 = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8]) + bids1 = m.commit_prefix_plan(h1, n1, 8) + m.mark_blocks_filled(bids1, h1) + m.free_blocks(bids1) + + h2, n2 = m.plan_prefix_cache([1, 2, 3, 4, 9, 9, 9, 9]) + assert n2 == 1 # only first block hits + bids2 = m.commit_prefix_plan(h2, n2, 8) + m.mark_blocks_filled(bids2, h2) + + assert m._block_id_to_hash[bids2[1]] == h2[1] + assert m._block_hash_to_id[h2[1]] == bids2[1] + + +def test_mark_blocks_filled_noop_when_empty_hashes(mgr): + m = mgr(enable_prefix_cache=False) + bids = m.commit_prefix_plan([], 0, 8) + m.mark_blocks_filled(bids, []) # empty hashes path + assert m.num_cached_blocks() == 0 + + +# ---------- full cycle: miss → hit ---------- + + +def test_full_cycle_miss_then_hit(mgr): + m = mgr() + toks = [1, 2, 3, 4, 5, 6, 7, 8] + + h1, n1 = m.plan_prefix_cache(toks) + assert n1 == 0 + bids1 = m.commit_prefix_plan(h1, n1, len(toks)) + m.mark_blocks_filled(bids1, h1) + m.free_blocks(bids1) + assert m.num_free_blocks() == 8 + assert m.num_cached_blocks() == 2 # hashes survive free + + # Same content again — full hit on the same physical blocks. + h2, n2 = m.plan_prefix_cache(toks) + assert h2 == h1 + assert n2 == 2 + bids2 = m.commit_prefix_plan(h2, n2, len(toks)) + assert bids2 == bids1 + + +def test_partial_last_block_does_not_block_prefix_hits(mgr): + # Two requests share the same first full block but differ in the partial tail — + # the full block should still hit. + m = mgr() + h1, n1 = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7]) + bids1 = m.commit_prefix_plan(h1, n1, 7) + m.mark_blocks_filled(bids1, h1) + m.free_blocks(bids1) + + h2, n2 = m.plan_prefix_cache([1, 2, 3, 4, 99, 99, 99]) + assert n2 == 1 + assert h2[0] == h1[0] + + +def test_contiguous_run_stops_at_first_miss(mgr): + # Fill cache: blocks [1..4][5..8][9..12]. + m = mgr() + h1, _ = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) + bids1 = m.commit_prefix_plan(h1, 0, 12) + m.mark_blocks_filled(bids1, h1) + m.free_blocks(bids1) + + # New: [1..4] hits, [99..] misses, [9..12] hashes ALWAYS diverge because the + # chain depends on the previous block's hash. So even the "same" 3rd block + # is a different chain hash → reported miss. + h2, n2 = m.plan_prefix_cache([1, 2, 3, 4, 99, 99, 99, 99, 9, 10, 11, 12]) + assert n2 == 1 + assert h2[0] == h1[0] + assert h2[2] != h1[2] + + +def test_refcount_shared_between_concurrent_requests(mgr): + m = mgr() + h_a, n_a = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8]) + bids_a = m.commit_prefix_plan(h_a, n_a, 8) + m.mark_blocks_filled(bids_a, h_a) + + # Second concurrent request hits both blocks. + h_b, n_b = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8]) + bids_b = m.commit_prefix_plan(h_b, n_b, 8) + assert bids_b == bids_a + for bid in bids_a: + assert m._refcount[bid] == 2 + + m.free_blocks(bids_b) + for bid in bids_a: + assert m._refcount[bid] == 1 + assert m.num_free_blocks() == 6 # blocks still owned by A + + m.free_blocks(bids_a) + assert m.num_free_blocks() == 8 + + +# ---------- LRU & eviction ---------- + + +def test_free_lru_oldest_first(mgr): + m = mgr(max_blocks=4) + a = m.allocate_blocks(total_tokens=4, num_current_blocks=0) + b = m.allocate_blocks(total_tokens=4, num_current_blocks=0) + c = m.allocate_blocks(total_tokens=4, num_current_blocks=0) + assert a == [0] and b == [1] and c == [2] + + m.free_blocks(a) # free order: 0 + m.free_blocks(b) # then 1 + m.free_blocks(c) # then 2 + # Pool currently: [3 (never used), 0, 1, 2]. Oldest first → 3, then 0, 1, 2. + + assert m.allocate_blocks(total_tokens=4, num_current_blocks=0) == [3] + assert m.allocate_blocks(total_tokens=4, num_current_blocks=0) == [0] + assert m.allocate_blocks(total_tokens=4, num_current_blocks=0) == [1] + assert m.allocate_blocks(total_tokens=4, num_current_blocks=0) == [2] + + +def test_eviction_drops_hash_entry(mgr): + m = mgr(max_blocks=2) + # Fill cache with 2 hashed blocks. + h1, _ = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8]) + bids1 = m.commit_prefix_plan(h1, 0, 8) + m.mark_blocks_filled(bids1, h1) + m.free_blocks(bids1) + assert m.num_cached_blocks() == 2 + + # Allocate 2 fresh blocks for unrelated content → both old hashes get evicted. + h2, _ = m.plan_prefix_cache([100, 101, 102, 103, 104, 105, 106, 107]) + bids2 = m.commit_prefix_plan(h2, 0, 8) + assert bids2 is not None + for old in h1: + assert old not in m._block_hash_to_id + + +# ---------- clear_cache_index ---------- + + +def test_clear_cache_index_drops_hashes_only(mgr): + m = mgr() + h, n = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8]) + bids = m.commit_prefix_plan(h, n, 8) + m.mark_blocks_filled(bids, h) + assert m.num_cached_blocks() == 2 + assert m.num_allocated_blocks() == 2 + + m.clear_cache_index() + assert m.num_cached_blocks() == 0 + assert m.num_allocated_blocks() == 2 # refcounts untouched + + # New plan with same content sees a cold cache. + _, n2 = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8]) + assert n2 == 0 + + +# ---------- disabled prefix cache ---------- + + +def test_disabled_cache_behaves_like_plain_allocator(mgr): + m = mgr(enable_prefix_cache=False) + h, n = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8]) + assert h == [] and n == 0 + bids = m.commit_prefix_plan(h, n, 8) + assert len(bids) == 2 + + m.mark_blocks_filled(bids, h) # no-op + assert m.num_cached_blocks() == 0 + + # Second identical request still misses. + m.free_blocks(bids) + _, n2 = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8]) + assert n2 == 0 + + +# ---------- allocate_blocks (decode-path) ---------- + + +def test_allocate_blocks_incremental(mgr): + m = mgr() + bids = m.commit_prefix_plan([], 0, 4) # 1 block via partial-only path + assert m.allocate_blocks(total_tokens=16, num_current_blocks=1) == [1, 2, 3] + # Already covers — no further allocation needed. + assert m.allocate_blocks(total_tokens=16, num_current_blocks=4) == [] + del bids + + +def test_allocate_blocks_oom_returns_none(mgr): + m = mgr(max_blocks=2) + assert m.allocate_blocks(total_tokens=12, num_current_blocks=0) is None + assert m.num_free_blocks() == 2 diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 56bf1a2..0615ab7 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -266,20 +266,25 @@ def test_scheduler_update_processes_outputs_and_finishes_requests(kv_cache_manag generation_config=generation_config, generated_tokens=[], input_ids_list=[234, 1897], - num_computed_tokens=2, ), InferenceRequest( request_id="req02", generation_config=generation_config, generated_tokens=[], input_ids_list=[15287, 344, 765], - num_computed_tokens=3, ), ] # Prepare requests by allocating KV cache and setting up state for request in infer_requests: - scheduler._activate_request(request) + block_hashes, num_prefix_hit_blocks = scheduler._kv_cache_manager.plan_prefix_cache( + request.input_ids_list + ) + scheduler._activate_request(request, block_hashes, num_prefix_hit_blocks) + # Simulate that prefill is already complete so update() takes the decode path. + # _activate_request sets num_computed_tokens to the cache-derived value (0 + # in this test); we want it at input_len for this unit test of update(). + request.num_computed_tokens = len(request.input_ids_list) # Set generated tokens after prepare (which resets them) infer_requests[1].generated_tokens = [78, 22] @@ -458,7 +463,8 @@ def test_scheduler_preempted_request_folds_back_generated_tokens(kv_cache_manage req = InferenceRequest(request_id="req-preempt", generation_config=gen_config, input_ids_list=[10, 20, 30]) # Set up request as if it had been running and generated 2 tokens - scheduler._activate_request(req) + block_hashes, num_prefix_hit_blocks = scheduler._kv_cache_manager.plan_prefix_cache(req.input_ids_list) + scheduler._activate_request(req, block_hashes, num_prefix_hit_blocks) req.generated_tokens = [40, 50] req.generated_logprobs = [0.1, 0.2] req.num_computed_tokens = 5 diff --git a/vexact/batch_invariant_ops/kv_cache_context.py b/vexact/batch_invariant_ops/kv_cache_context.py index eeef68d..89c4653 100644 --- a/vexact/batch_invariant_ops/kv_cache_context.py +++ b/vexact/batch_invariant_ops/kv_cache_context.py @@ -20,6 +20,7 @@ """ import threading +from collections import OrderedDict from dataclasses import dataclass from typing import Iterable, Optional @@ -139,82 +140,278 @@ def has_kv_cache_context() -> bool: class KVCacheManager: + """KV cache manager with optional content-addressed prefix cache. + + Two responsibilities: + 1. Block lifetime via reference counting. Free blocks live in an LRU queue + (least-recently-released first) and are reclaimed for new content when + the pool runs out. + 2. Prefix cache via chained block-content hashing. When `prefix_cache_enabled` + is True, the scheduler calls `plan_prefix_cache` per request to compute + per-full-block chain hashes and the leading hit count. After prefill + completes, `mark_blocks_filled` stamps the just-computed hashes onto the + allocated blocks so future requests with the same prefix can hit them. + + Lifecycle of a cached block: + A) plan_prefix_cache(token_ids) → (block_hashes, num_prefix_hit_blocks) + B) commit_prefix_plan(block_hashes, num_prefix_hit_blocks, total_tokens) → + incref hit blocks; take fresh blocks (LRU-oldest, evicting their old + hash if any) for the remaining full blocks and the partial last block + C) prefill completes → mark_blocks_filled(block_ids, block_hashes) stamps + the full blocks (idempotent — safe to call every decode step) + D) request finishes / preempts → free_blocks → decref; if 0 → push to free_lru + (block KEEPS its hash association — still cache-eligible until evicted) + E) future allocation runs out of fresh-tagged blocks → pops oldest from free_lru, + drops its hash entry from the index → block becomes fresh again + + Invalidation: `clear_cache_index()` drops all hash entries (refcounts / free_lru + untouched). It is the caller's responsibility to also preempt every in-flight + request that owns blocks under the old KV state — those blocks' KV is stale + and decode would read garbage. See `Scheduler.reset_for_state_change()`. + + For pp_size > 1 set prefix_cache_enabled=False — KV is replicated across PP + ranks but only this driver-side manager hashes; coordinating cache state across + ranks is out of scope. """ - KV Cache Manager for continuous batching with flex_attention support - This class manages block-based KV cache allocation and provides utilities - for setting up the context needed by flex_attention. - """ + # Constant seed for the chain hash. Doesn't need cryptographic strength — + # the chain lives within a single process and Python's int hash is fine for + # 1024-block scale (collision probability ~3e-14). + _SEED: int = 0 - def __init__(self, cache_config): + def __init__(self, cache_config, enable_prefix_cache: bool = True): """ - Initialize KV cache manager. - Args: cache_config: CacheConfig from vexact.config + enable_prefix_cache: when True, full blocks are hashed for content-addressed + reuse across requests. When False, behaves like the original allocator + (always fresh blocks, no lookups). Set False for pp_size > 1. """ self.cache_config = cache_config + self.page_size: int = cache_config.page_size + self.prefix_cache_enabled: bool = enable_prefix_cache + self._max_blocks = cache_config.max_cache_blocks + + # refcount[bid] == 0 ⇔ bid in free_lru. Both kept in sync by _incref/_decref. + self._refcount: dict[int, int] = {bid: 0 for bid in range(self._max_blocks)} + self._free_lru: OrderedDict[int, None] = OrderedDict( + (bid, None) for bid in range(self._max_blocks) + ) + + # Content-addressed prefix cache: chain hash → block_id. Only populated for + # full blocks (covering exactly page_size tokens) after their request has + # finished prefill. Both maps are cleared together (a block is either in + # both or neither). + self._block_hash_to_id: dict[int, int] = {} + self._block_id_to_hash: dict[int, int] = {} + + # ---- diagnostics ---- - # Block allocation tracking - self.allocated_blocks = set() - self.free_blocks_set = set(range(cache_config.max_cache_blocks)) + def num_free_blocks(self) -> int: + return len(self._free_lru) + + def num_allocated_blocks(self) -> int: + return self._max_blocks - len(self._free_lru) + + def num_cached_blocks(self) -> int: + """Blocks currently registered in the prefix cache index (in-use OR free-but-cached).""" + return len(self._block_hash_to_id) def _num_blocks_needed(self, num_tokens: int) -> int: - """ - Calculate the number of KV cache blocks needed for a given number of tokens. + return (num_tokens + self.page_size - 1) // self.page_size - Args: - num_tokens: Total number of tokens (e.g., max_length from generation config) + # ---- low-level ref / pool management ---- - Returns: - Number of blocks required - """ - return (num_tokens + self.cache_config.page_size - 1) // self.cache_config.page_size + def _incref(self, block_id: int) -> None: + if self._refcount[block_id] == 0: + self._free_lru.pop(block_id, None) + self._refcount[block_id] += 1 + + def _decref(self, block_id: int) -> None: + self._refcount[block_id] -= 1 + if self._refcount[block_id] == 0: + # Push as MOST recently freed (last). _take_free_block pops from front (oldest). + self._free_lru[block_id] = None + + def _take_free_block(self) -> Optional[int]: + """Pop the LRU-oldest free block; if it carried a cache hash, evict it first.""" + if not self._free_lru: + return None + bid = next(iter(self._free_lru)) + del self._free_lru[bid] + old_hash = self._block_id_to_hash.pop(bid, None) + if old_hash is not None: + # Defensive: only delete if it still maps back to us (it should). + if self._block_hash_to_id.get(old_hash) == bid: + del self._block_hash_to_id[old_hash] + return bid + + def _rollback(self, block_ids: list[int]) -> None: + for bid in block_ids: + self._decref(bid) + + # ---- public allocation paths ---- def allocate_blocks(self, total_tokens: int, num_current_blocks: int = 0) -> list[int] | None: - """ - Ensure KV cache coverage for total_tokens given num_current_blocks already allocated. + """Incremental allocation for an already-active request. - Args: - total_tokens: Total number of tokens that need block coverage - num_current_blocks: Number of blocks already allocated for this request + Used by the scheduler during decode and chunked-prefill steps to extend + block coverage as num_computed_tokens grows. New blocks are taken fresh + from the free pool — no prefix lookup, no stamping (decode tokens are + unique to this request and would never hit the cache anyway). - Returns: - List of newly allocated block IDs (may be empty if already sufficient), - or None if not enough free blocks (OOM). + Returns newly allocated block IDs (possibly empty), or None on OOM. """ num_needed = self._num_blocks_needed(total_tokens) delta = num_needed - num_current_blocks - if delta <= 0: return [] - - if len(self.free_blocks_set) < delta: + if len(self._free_lru) < delta: return None - - allocated = [] + allocated: list[int] = [] for _ in range(delta): - block_id = min(self.free_blocks_set) - self.free_blocks_set.remove(block_id) - self.allocated_blocks.add(block_id) - allocated.append(block_id) - + bid = self._take_free_block() + assert bid is not None # checked above + self._incref(bid) + allocated.append(bid) return allocated - def num_free_blocks(self) -> int: - """Return the number of free KV cache blocks.""" - return len(self.free_blocks_set) + def plan_prefix_cache(self, token_ids: list[int]) -> tuple[list[int], int]: + """Compute the prefix-cache plan for a token sequence. Stateless — no allocation. - def num_allocated_blocks(self) -> int: - """Return the number of allocated KV cache blocks.""" - return len(self.allocated_blocks) + For each FULL block (last partial excluded) compute the chained hash and + check whether it's currently in the cache index. The "hit" run stops at + the first miss — a later "hit" would still require prefill to fill the + gap, which would overwrite the cached blocks, so we don't try to exploit it. + + Returns (block_hashes, num_prefix_hit_blocks): + - block_hashes: one chain hash per full block (`len(token_ids) // page_size` + entries). Empty when prefix cache is disabled. + - num_prefix_hit_blocks: length of the leading contiguous hit run. + + The scheduler reuses block_hashes verbatim in `commit_prefix_plan` and + `mark_blocks_filled` — neither recomputes hashes. + """ + if not self.prefix_cache_enabled: + return [], 0 + + page_size = self.page_size + num_full = len(token_ids) // page_size + block_hashes: list[int] = [] + num_prefix_hit_blocks = 0 + contiguous = True + prev_hash = self._SEED + + for i in range(num_full): + block_hash = hash((prev_hash, tuple(token_ids[i * page_size : (i + 1) * page_size]))) + block_hashes.append(block_hash) + if contiguous and block_hash in self._block_hash_to_id: + num_prefix_hit_blocks += 1 + else: + contiguous = False + prev_hash = block_hash + + return block_hashes, num_prefix_hit_blocks + + def commit_prefix_plan( + self, + block_hashes: list[int], + num_prefix_hit_blocks: int, + total_tokens: int, + ) -> list[int] | None: + """Commit a plan from `plan_prefix_cache`: incref hit blocks, take fresh for the rest. + + Single-threaded scheduler ⇒ cache state can't change between plan and commit, + so the leading `num_prefix_hit_blocks` lookups via `_block_hash_to_id` are + guaranteed to hit (the chain hash is unique to the content). + + Args: + block_hashes: hashes from plan_prefix_cache (one per full block). May be + empty when prefix cache is disabled or `total_tokens < page_size`. + num_prefix_hit_blocks: leading hits from plan_prefix_cache. + total_tokens: total tokens this request needs coverage for. Determines + whether a partial last block is needed (always fresh). + + Returns the full block_ids list on success, or None on OOM (with full rollback). + """ + num_blocks_needed = self._num_blocks_needed(total_tokens) + if num_blocks_needed == 0: + return [] + + block_ids: list[int] = [] + + # Cached portion: refcount existing blocks. Guaranteed to hit per the + # single-threaded plan/commit invariant. + for i in range(num_prefix_hit_blocks): + cached_bid = self._block_hash_to_id[block_hashes[i]] + self._incref(cached_bid) + block_ids.append(cached_bid) + + # Remaining full blocks (cache misses) plus the partial last block: fresh from pool. + for _ in range(num_prefix_hit_blocks, num_blocks_needed): + bid = self._take_free_block() + if bid is None: + self._rollback(block_ids) + return None + self._incref(bid) + block_ids.append(bid) + + return block_ids def free_blocks(self, block_ids: list[int]): - """Free allocated blocks""" - for block_id in block_ids: - if block_id in self.allocated_blocks: - self.allocated_blocks.remove(block_id) - self.free_blocks_set.add(block_id) + """Decref the given blocks; those reaching zero rejoin free_lru (still hashed).""" + for bid in block_ids: + if self._refcount.get(bid, 0) > 0: + self._decref(bid) + + # ---- prefix cache management ---- + + def mark_blocks_filled(self, block_ids: list[int], block_hashes: list[int]) -> None: + """Stamp full blocks with the chain hashes computed by `plan_prefix_cache`. + + Called every decode step (and once at prefill completion) — the fast-path + check on the last full block (which is the last to be stamped, and refcounted + by us so no one else can rewrite it) makes repeat calls O(1). + + Why the partial last block is never stamped (intentional, not just an omission): + 1. Chain hash is sensitive to tuple length: `hash((prev, tuple(partial)))` + differs from `hash((prev, tuple(full)))`, so a future longer request + couldn't hit a partial-stamped block anyway. + 2. The partial slot keeps being written by decode, so any stamp written + at prefill completion would go stale the next decode step. + + Hash collisions and re-stamping of already-stamped blocks are both safe: + - same content → same hash → no-op write + - true collision → last-writer-wins; the displaced entry's defensive + cleanup in `_take_free_block` keeps `_block_hash_to_id` consistent + on eviction. + """ + if not block_hashes: + return + # Fast path: stamping is atomic per request, so if the last full block + # is already stamped with our hash, all earlier ones are too. The + # request owns these blocks (refcount > 0), so eviction can't clobber. + last_bid = block_ids[len(block_hashes) - 1] + if self._block_id_to_hash.get(last_bid) == block_hashes[-1]: + return + for bid, block_hash in zip(block_ids, block_hashes): + self._block_hash_to_id[block_hash] = bid + self._block_id_to_hash[bid] = block_hash + + def clear_cache_index(self) -> None: + """Drop all prefix-cache hash entries. + + Called on weight update or memory-saver sleep — the KV in cached blocks is + no longer correct under the new weights / restored memory, so future requests + must not hit them. + + DOES NOT preempt in-flight requests. Active requests still hold refcounts + on blocks whose KV is now stale; their decode would read garbage if they + continued. Caller is responsible for preempting them first — see + `Scheduler.reset_for_state_change()`. + """ + self._block_hash_to_id.clear() + self._block_id_to_hash.clear() class KVCacheStore: diff --git a/vexact/core/request.py b/vexact/core/request.py index 6f15fc4..cd3c03e 100644 --- a/vexact/core/request.py +++ b/vexact/core/request.py @@ -55,6 +55,12 @@ class InferenceRequest: block_ids: list[int] = field(default_factory=list) + # Chain hashes for each FULL block of input_ids_list, populated by the + # scheduler at activation via plan_prefix_cache. Reused by mark_blocks_filled + # so no module recomputes hashes. Reset on preempt — input_ids_list grows to + # include generated tokens, so the chain must be replanned on reactivation. + prefix_block_hashes: list[int] = field(default_factory=list) + # State: Outputs generated_tokens: list[int] = field(default_factory=list) generated_logits: list[torch.Tensor] = field(default_factory=list) # Store logits for each token @@ -137,6 +143,7 @@ def preempt(self) -> None: self.block_ids = [] self.num_computed_tokens = 0 self.tokens_this_step = 0 + self.prefix_block_hashes = [] self.status = RequestStatus.PENDING def finish(self): diff --git a/vexact/core/scheduler.py b/vexact/core/scheduler.py index 17bfc81..097c81a 100644 --- a/vexact/core/scheduler.py +++ b/vexact/core/scheduler.py @@ -71,6 +71,9 @@ def __init__( self.max_num_batched_tokens = config.max_num_batched_tokens self.total_requests = 0 self._kv_cache_manager = kv_cache_manager + # Prefix-cache stats (per Scheduler lifetime). Useful to confirm hits. + self.cache_hit_tokens_total = 0 + self.cache_miss_tokens_total = 0 self._request_queue: queue.Queue[InferenceRequest] = queue.Queue(maxsize=config.max_queue_size) @@ -159,9 +162,21 @@ def schedule(self) -> SchedulerOutput: break try: - tokens, next_num_comp = self._plan_tokens_for_request(request, available_token_budget) + # Plan prefix-cache hits once. block_hashes flow through to + # _activate_request → commit_prefix_plan (no rehash) and later to + # mark_blocks_filled (also no rehash). Eliminates the triple-hash + # that the original peek/allocate/stamp paths had. + block_hashes, num_prefix_hit_blocks = self._kv_cache_manager.plan_prefix_cache( + request.input_ids_list + ) + cached_tokens = num_prefix_hit_blocks * self._kv_cache_manager.page_size + # Pass num_comp explicitly so we don't mutate the request before + # we know it'll be activated. + tokens, next_num_comp = self._plan_tokens_for_request( + request, available_token_budget, num_comp=cached_tokens + ) if available_token_budget >= tokens: - self._activate_request(request) + self._activate_request(request, block_hashes, num_prefix_hit_blocks) # Set scheduling fields and add to the active batch available_token_budget -= tokens @@ -173,6 +188,7 @@ def schedule(self) -> SchedulerOutput: prefill_seqs_budget -= 1 seqs_budget -= 1 else: + # No mutation happened; just put back and stop. self._request_queue.put(request) break except RuntimeError as e: @@ -195,8 +211,17 @@ def schedule(self) -> SchedulerOutput: return SchedulerOutput(batch_to_infer=batch_to_infer, batch_to_update=batch_to_update) - def _plan_tokens_for_request(self, request: InferenceRequest, available_token_budget: int): - num_comp = request.num_computed_tokens + def _plan_tokens_for_request( + self, + request: InferenceRequest, + available_token_budget: int, + num_comp: int | None = None, + ): + # Allow caller to pass a prospective num_computed_tokens (e.g. the + # prefix-cache hit count for an about-to-be-activated request) without + # mutating request state ahead of the commit. + if num_comp is None: + num_comp = request.num_computed_tokens input_len = len(request.input_ids_list) prefill_remaining = max(0, input_len - num_comp) @@ -228,6 +253,13 @@ def update(self, requests: list[InferenceRequest], infer_result: InferencerOutpu if request.num_computed_tokens < len(request.input_ids_list): continue + # Prefill done. Stamp full blocks so concurrent / future requests with + # the same prefix can hit the cache. `mark_blocks_filled` is idempotent + # (O(1) fast-path on repeat calls), so we don't track a flag here. + self._kv_cache_manager.mark_blocks_filled( + request.block_ids, request.prefix_block_hashes + ) + token_id = self._process_generated_token(request, token_tensor, logits, logprobs) if request.should_finish(token_id): @@ -239,21 +271,43 @@ def update(self, requests: list[InferenceRequest], infer_result: InferencerOutpu # cuz we don't need to do streaming the partial states for now self._result_queue.put(finished_requests) - def _activate_request(self, request: InferenceRequest) -> None: - """Activate a request for processing: allocate resources incrementally. - - Only allocates blocks for the current input_ids_list length (prompt for new requests, - prompt+generated for re-activated preempted requests), not the full max_length. + def _activate_request( + self, + request: InferenceRequest, + block_hashes: list[int], + num_prefix_hit_blocks: int, + ) -> None: + """Activate a request: commit the prefix-cache plan, attach hashes, bump + num_computed_tokens past any cached prefix so prefill skips it. + + The plan (block_hashes + num_prefix_hit_blocks) comes from `plan_prefix_cache`; + we never recompute hashes here. With prefix cache disabled (pp_size > 1), + block_hashes is empty and every block is a fresh allocation — same behaviour + as the original allocator. """ - new_blocks = self._kv_cache_manager.allocate_blocks(len(request.input_ids_list)) - if new_blocks is None: + block_ids = self._kv_cache_manager.commit_prefix_plan( + block_hashes, num_prefix_hit_blocks, len(request.input_ids_list) + ) + if block_ids is None: raise RuntimeError( f"Not enough free blocks to activate request {request.request_id}: " f"need coverage for {len(request.input_ids_list)} tokens" ) - request.block_ids.extend(new_blocks) + num_cached_tokens = num_prefix_hit_blocks * self._kv_cache_manager.page_size + request.block_ids = block_ids + request.prefix_block_hashes = block_hashes + request.num_computed_tokens = num_cached_tokens request.activate() self._active_requests[request.request_id] = request + self.cache_hit_tokens_total += num_cached_tokens + self.cache_miss_tokens_total += max(0, len(request.input_ids_list) - num_cached_tokens) + if num_cached_tokens > 0: + logger.debug( + "[Scheduler] %s: prefix cache hit %d/%d tokens", + request.request_id, + num_cached_tokens, + len(request.input_ids_list), + ) def _extend_or_preempt(self, request: InferenceRequest) -> bool: """Try to extend KV cache blocks for request's planned tokens. If OOM, preempt least-progress request. @@ -292,6 +346,32 @@ def _preempt_request(self, request: InferenceRequest) -> None: self._active_requests.pop(request.request_id, None) self._request_queue.put(request) + def reset_for_state_change(self) -> None: + """Drop all KV-dependent state ahead of a weight update or memory-saver sleep. + + After this returns, the cache index is empty and every previously active + request has been preempted back to the queue with reset state — fresh + prefill under the new weights / restored memory. Lifetime hit/miss stats + are also reset so the post-change hit-ratio reflects the new run. + + Why preempt: blocks held by active requests carry KV computed under the + OLD weights. Continuing decode against those blocks reads stale KV. The + only safe move is to free them and re-prefill from scratch. + """ + for batch in self._inflight_batches: + for request in list(batch.active_requests.values()): + self._kv_cache_manager.free_blocks(request.block_ids) + request.preempt() + self._request_queue.put(request) + batch.active_requests.clear() + + self._kv_cache_manager.clear_cache_index() + + # Reset lifetime stats so post-reset hit-ratio isn't polluted by hits + # against the previous model's KV. + self.cache_hit_tokens_total = 0 + self.cache_miss_tokens_total = 0 + def _finalize_request(self, request: InferenceRequest) -> None: """Finalize a completed request: mark finished, release KV blocks, and remove from active set.""" self._kv_cache_manager.free_blocks(request.block_ids) diff --git a/vexact/engine.py b/vexact/engine.py index 5d3c065..468e175 100644 --- a/vexact/engine.py +++ b/vexact/engine.py @@ -83,6 +83,17 @@ def wake_up(self, tag: str = None): """ return self.driver_client.wake_up(tag=tag) + def get_prefix_cache_stats(self) -> dict: + """Snapshot of prefix-cache counters since the driver started (or last + weight update / sleep, which resets the counters). + + Returns a dict with: prefix_cache_enabled, hit_tokens, miss_tokens, + hit_ratio, cached_blocks, free_blocks. Returns + {"prefix_cache_enabled": False} on PP>1 (the prefix cache is disabled + in that configuration). + """ + return self.driver_client.get_prefix_cache_stats() + async def generate( self, request: DriverRequest, diff --git a/vexact/worker/driver_client.py b/vexact/worker/driver_client.py index 2153ebb..4357274 100644 --- a/vexact/worker/driver_client.py +++ b/vexact/worker/driver_client.py @@ -99,3 +99,12 @@ def wake_up(self, tag: str = None) -> list[Any]: def receive_weights(self) -> list[Any]: """Receive model weights via IPC on all workers.""" return self._execute("receive_weights") + + def get_prefix_cache_stats(self) -> dict: + """Return prefix-cache stats from the driver (rank 0). + + The collective hits every rank but only the driver owns the scheduler, + so non-driver responses are disabled markers we discard. + """ + results = self._execute("get_prefix_cache_stats") + return results[0] if results else {"prefix_cache_enabled": False} diff --git a/vexact/worker/driver_worker.py b/vexact/worker/driver_worker.py index 7545d21..a1c332c 100644 --- a/vexact/worker/driver_worker.py +++ b/vexact/worker/driver_worker.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +import os import time from vexact.batch_invariant_ops.kv_cache_context import KVCacheManager @@ -30,7 +31,17 @@ class DriverWorker(Worker): def __init__(self, config: VeXactConfig): super().__init__(config, rank=0) # CUDA graph capture happens here - self.kv_cache_manager = KVCacheManager(config.cache) + # Prefix cache only on pp_size==1 — coordinating per-block hash state + # across PP ranks (and ensuring KV slots agree on every rank) is out of + # scope for now. With it disabled, KVCacheManager behaves like the + # original allocator (always-fresh blocks, no lookups). + # VEXACT_DISABLE_PREFIX_CACHE=1 is a debug escape hatch — useful for + # measuring the cache's contribution by running the same workload twice. + enable_prefix_cache = ( + config.parallel.pipeline_parallel_size == 1 + and os.environ.get("VEXACT_DISABLE_PREFIX_CACHE", "") != "1" + ) + self.kv_cache_manager = KVCacheManager(config.cache, enable_prefix_cache=enable_prefix_cache) self.scheduler = Scheduler( config=config.scheduler, kv_cache_manager=self.kv_cache_manager, @@ -45,6 +56,51 @@ def poll_results(self, timeout: float = None) -> list[InferenceRequest]: """Get next batch of finished requests (blocking with optional timeout).""" return self.scheduler.poll_results(timeout=timeout) + def receive_weights(self): + """Receive new model weights. + + Active requests are preempted (their KV is stale under the new weights + and decode would read garbage) and the prefix cache index is dropped. + Re-prefill happens automatically on the next schedule(). + """ + self.scheduler.reset_for_state_change() + super().receive_weights() + + def sleep(self, tag: str = None): + """Pause memory-saving regions. + + The KV cache region is allocated with enable_cpu_backup=False, so its + bytes are dropped on sleep — any active block IDs become invalid. Active + requests are preempted; the prefix cache index is dropped. + """ + self.scheduler.reset_for_state_change() + super().sleep(tag=tag) + + def load_state_dict(self, state_dict): + """In-process weight load (e.g. tests). Same staleness as receive_weights().""" + self.scheduler.reset_for_state_change() + super().load_state_dict(state_dict) + + def get_prefix_cache_stats(self) -> dict: + """Snapshot prefix-cache counters since process start (or last state reset). + + hit_tokens / miss_tokens are summed across every activated request; the + ratio is hit_tokens / (hit_tokens + miss_tokens), guarded against div0. + Counters are reset on weight update / sleep so the ratio always reflects + the current model. + """ + hit = self.scheduler.cache_hit_tokens_total + miss = self.scheduler.cache_miss_tokens_total + total = hit + miss + return { + "prefix_cache_enabled": self.kv_cache_manager.prefix_cache_enabled, + "hit_tokens": hit, + "miss_tokens": miss, + "hit_ratio": (hit / total) if total > 0 else 0.0, + "cached_blocks": self.kv_cache_manager.num_cached_blocks(), + "free_blocks": self.kv_cache_manager.num_free_blocks(), + } + def _generation_loop(self): """Main continuous generation loop.""" # Start profiler here (not in __init__) so that _enable_profiler() and diff --git a/vexact/worker/worker.py b/vexact/worker/worker.py index 678bae0..4bc0060 100644 --- a/vexact/worker/worker.py +++ b/vexact/worker/worker.py @@ -159,6 +159,11 @@ def receive_weights(self): receiver = BucketedWeightReceiver(zmq_handle=self.zmq_handle, device=self.device) receiver.receive_weights(on_bucket_received=lambda weights: self.load_state_dict(dict(weights))) + def get_prefix_cache_stats(self) -> dict: + # Non-driver ranks don't own a scheduler or KVCacheManager. Returning a + # disabled marker keeps the collective RPC well-formed across all ranks. + return {"prefix_cache_enabled": False} + def run_worker(config: VeXactConfig, rank: int): """Run a worker (blocking).""" diff --git a/vexact/worker/worker_proxy.py b/vexact/worker/worker_proxy.py index 3d0bac4..bedfd63 100644 --- a/vexact/worker/worker_proxy.py +++ b/vexact/worker/worker_proxy.py @@ -53,7 +53,7 @@ def __init__(self, config: VeXactConfig, rank: int): self.control_server = ControlChannelServer( address=config.driver.control_addresses[rank], target=self.worker, - allowed_methods=["sleep", "wake_up", "receive_weights"], + allowed_methods=["sleep", "wake_up", "receive_weights", "get_prefix_cache_stats"], ) def start(self): @@ -82,7 +82,7 @@ def __init__(self, config: VeXactConfig, rank: int): self.control_server = ControlChannelServer( address=config.driver.control_addresses[rank], target=self.worker, - allowed_methods=["sleep", "wake_up", "receive_weights"], + allowed_methods=["sleep", "wake_up", "receive_weights", "get_prefix_cache_stats"], ) def start(self): From ea5f3f41c18a4f21a914dce055bb9adfd80587c6 Mon Sep 17 00:00:00 2001 From: "ziyi.wang" Date: Tue, 12 May 2026 07:51:32 +0800 Subject: [PATCH 2/5] style: apply ruff-format to prefix cache changes --- tests/test_scheduler.py | 4 +--- vexact/batch_invariant_ops/kv_cache_context.py | 4 +--- vexact/core/scheduler.py | 8 ++------ vexact/worker/driver_worker.py | 3 +-- 4 files changed, 5 insertions(+), 14 deletions(-) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 0615ab7..8de6e2c 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -277,9 +277,7 @@ def test_scheduler_update_processes_outputs_and_finishes_requests(kv_cache_manag # Prepare requests by allocating KV cache and setting up state for request in infer_requests: - block_hashes, num_prefix_hit_blocks = scheduler._kv_cache_manager.plan_prefix_cache( - request.input_ids_list - ) + block_hashes, num_prefix_hit_blocks = scheduler._kv_cache_manager.plan_prefix_cache(request.input_ids_list) scheduler._activate_request(request, block_hashes, num_prefix_hit_blocks) # Simulate that prefill is already complete so update() takes the decode path. # _activate_request sets num_computed_tokens to the cache-derived value (0 diff --git a/vexact/batch_invariant_ops/kv_cache_context.py b/vexact/batch_invariant_ops/kv_cache_context.py index 89c4653..811a7c8 100644 --- a/vexact/batch_invariant_ops/kv_cache_context.py +++ b/vexact/batch_invariant_ops/kv_cache_context.py @@ -194,9 +194,7 @@ def __init__(self, cache_config, enable_prefix_cache: bool = True): # refcount[bid] == 0 ⇔ bid in free_lru. Both kept in sync by _incref/_decref. self._refcount: dict[int, int] = {bid: 0 for bid in range(self._max_blocks)} - self._free_lru: OrderedDict[int, None] = OrderedDict( - (bid, None) for bid in range(self._max_blocks) - ) + self._free_lru: OrderedDict[int, None] = OrderedDict((bid, None) for bid in range(self._max_blocks)) # Content-addressed prefix cache: chain hash → block_id. Only populated for # full blocks (covering exactly page_size tokens) after their request has diff --git a/vexact/core/scheduler.py b/vexact/core/scheduler.py index 097c81a..1c65729 100644 --- a/vexact/core/scheduler.py +++ b/vexact/core/scheduler.py @@ -166,9 +166,7 @@ def schedule(self) -> SchedulerOutput: # _activate_request → commit_prefix_plan (no rehash) and later to # mark_blocks_filled (also no rehash). Eliminates the triple-hash # that the original peek/allocate/stamp paths had. - block_hashes, num_prefix_hit_blocks = self._kv_cache_manager.plan_prefix_cache( - request.input_ids_list - ) + block_hashes, num_prefix_hit_blocks = self._kv_cache_manager.plan_prefix_cache(request.input_ids_list) cached_tokens = num_prefix_hit_blocks * self._kv_cache_manager.page_size # Pass num_comp explicitly so we don't mutate the request before # we know it'll be activated. @@ -256,9 +254,7 @@ def update(self, requests: list[InferenceRequest], infer_result: InferencerOutpu # Prefill done. Stamp full blocks so concurrent / future requests with # the same prefix can hit the cache. `mark_blocks_filled` is idempotent # (O(1) fast-path on repeat calls), so we don't track a flag here. - self._kv_cache_manager.mark_blocks_filled( - request.block_ids, request.prefix_block_hashes - ) + self._kv_cache_manager.mark_blocks_filled(request.block_ids, request.prefix_block_hashes) token_id = self._process_generated_token(request, token_tensor, logits, logprobs) diff --git a/vexact/worker/driver_worker.py b/vexact/worker/driver_worker.py index a1c332c..d00ed24 100644 --- a/vexact/worker/driver_worker.py +++ b/vexact/worker/driver_worker.py @@ -38,8 +38,7 @@ def __init__(self, config: VeXactConfig): # VEXACT_DISABLE_PREFIX_CACHE=1 is a debug escape hatch — useful for # measuring the cache's contribution by running the same workload twice. enable_prefix_cache = ( - config.parallel.pipeline_parallel_size == 1 - and os.environ.get("VEXACT_DISABLE_PREFIX_CACHE", "") != "1" + config.parallel.pipeline_parallel_size == 1 and os.environ.get("VEXACT_DISABLE_PREFIX_CACHE", "") != "1" ) self.kv_cache_manager = KVCacheManager(config.cache, enable_prefix_cache=enable_prefix_cache) self.scheduler = Scheduler( From f4d2c0c4d16e941a7779ce26a0b841d4209b373e Mon Sep 17 00:00:00 2001 From: "ziyi.wang" Date: Tue, 12 May 2026 09:30:53 +0800 Subject: [PATCH 3/5] simplify weight-update path: rely on framework drain receive_weights / sleep / load_state_dict now call clear_cache_index() directly instead of routing through scheduler.reset_for_state_change(), which preempted active requests and reset stats. Drop the preempt-all helper entirely. Rationale: the framework (verl) drains all in-flight requests before triggering a weight update or sleep, so there is nothing to preempt at that point. clear_cache_index alone is enough to prevent future requests from hitting KV computed under old weights. Stats become pure lifetime counters, giving a stable trend line across weight updates instead of a jittery per-step ratio. Docstrings updated to spell out the drain assumption. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../batch_invariant_ops/kv_cache_context.py | 23 +++++++------- vexact/core/scheduler.py | 30 ++----------------- vexact/worker/driver_worker.py | 22 +++++++------- 3 files changed, 27 insertions(+), 48 deletions(-) diff --git a/vexact/batch_invariant_ops/kv_cache_context.py b/vexact/batch_invariant_ops/kv_cache_context.py index 811a7c8..07f46a6 100644 --- a/vexact/batch_invariant_ops/kv_cache_context.py +++ b/vexact/batch_invariant_ops/kv_cache_context.py @@ -165,9 +165,10 @@ class KVCacheManager: drops its hash entry from the index → block becomes fresh again Invalidation: `clear_cache_index()` drops all hash entries (refcounts / free_lru - untouched). It is the caller's responsibility to also preempt every in-flight - request that owns blocks under the old KV state — those blocks' KV is stale - and decode would read garbage. See `Scheduler.reset_for_state_change()`. + untouched). Caller (the engine, on weight update / sleep) must ensure there + are no in-flight requests at that moment — those would own blocks whose KV + is now stale, and continuing decode against them would read garbage. The + expected protocol is "drain → clear → swap" handled at the framework layer. For pp_size > 1 set prefix_cache_enabled=False — KV is replicated across PP ranks but only this driver-side manager hashes; coordinating cache state across @@ -400,13 +401,15 @@ def clear_cache_index(self) -> None: """Drop all prefix-cache hash entries. Called on weight update or memory-saver sleep — the KV in cached blocks is - no longer correct under the new weights / restored memory, so future requests - must not hit them. - - DOES NOT preempt in-flight requests. Active requests still hold refcounts - on blocks whose KV is now stale; their decode would read garbage if they - continued. Caller is responsible for preempting them first — see - `Scheduler.reset_for_state_change()`. + no longer correct under the new weights / restored memory, so future + requests must not hit them. + + DOES NOT preempt in-flight requests. The contract is that the engine + framework (e.g. verl) drains all in-flight requests before triggering a + weight update / sleep, so at the moment this is called there's nothing + to preempt. If called with active requests present, their decode will + continue reading stale KV — that's a framework-level bug, not handled + here. """ self._block_hash_to_id.clear() self._block_id_to_hash.clear() diff --git a/vexact/core/scheduler.py b/vexact/core/scheduler.py index 1c65729..3daacf9 100644 --- a/vexact/core/scheduler.py +++ b/vexact/core/scheduler.py @@ -71,7 +71,9 @@ def __init__( self.max_num_batched_tokens = config.max_num_batched_tokens self.total_requests = 0 self._kv_cache_manager = kv_cache_manager - # Prefix-cache stats (per Scheduler lifetime). Useful to confirm hits. + # Lifetime prefix-cache stats. Never reset — get_prefix_cache_stats reports + # the cumulative hit ratio since engine start, which gives a stable trend + # line across weight updates rather than a jittery per-step ratio. self.cache_hit_tokens_total = 0 self.cache_miss_tokens_total = 0 @@ -342,32 +344,6 @@ def _preempt_request(self, request: InferenceRequest) -> None: self._active_requests.pop(request.request_id, None) self._request_queue.put(request) - def reset_for_state_change(self) -> None: - """Drop all KV-dependent state ahead of a weight update or memory-saver sleep. - - After this returns, the cache index is empty and every previously active - request has been preempted back to the queue with reset state — fresh - prefill under the new weights / restored memory. Lifetime hit/miss stats - are also reset so the post-change hit-ratio reflects the new run. - - Why preempt: blocks held by active requests carry KV computed under the - OLD weights. Continuing decode against those blocks reads stale KV. The - only safe move is to free them and re-prefill from scratch. - """ - for batch in self._inflight_batches: - for request in list(batch.active_requests.values()): - self._kv_cache_manager.free_blocks(request.block_ids) - request.preempt() - self._request_queue.put(request) - batch.active_requests.clear() - - self._kv_cache_manager.clear_cache_index() - - # Reset lifetime stats so post-reset hit-ratio isn't polluted by hits - # against the previous model's KV. - self.cache_hit_tokens_total = 0 - self.cache_miss_tokens_total = 0 - def _finalize_request(self, request: InferenceRequest) -> None: """Finalize a completed request: mark finished, release KV blocks, and remove from active set.""" self._kv_cache_manager.free_blocks(request.block_ids) diff --git a/vexact/worker/driver_worker.py b/vexact/worker/driver_worker.py index d00ed24..706cfad 100644 --- a/vexact/worker/driver_worker.py +++ b/vexact/worker/driver_worker.py @@ -58,35 +58,35 @@ def poll_results(self, timeout: float = None) -> list[InferenceRequest]: def receive_weights(self): """Receive new model weights. - Active requests are preempted (their KV is stale under the new weights - and decode would read garbage) and the prefix cache index is dropped. - Re-prefill happens automatically on the next schedule(). + Drops the prefix cache index so future requests can't hit KV computed + under the old weights. Assumes the caller (e.g. verl) has already + drained all in-flight requests — we don't preempt anything here. """ - self.scheduler.reset_for_state_change() + self.kv_cache_manager.clear_cache_index() super().receive_weights() def sleep(self, tag: str = None): """Pause memory-saving regions. The KV cache region is allocated with enable_cpu_backup=False, so its - bytes are dropped on sleep — any active block IDs become invalid. Active - requests are preempted; the prefix cache index is dropped. + bytes are dropped on sleep. The prefix cache index is dropped so the + post-wake_up engine can't hit blocks whose bytes are gone. Caller is + expected to have drained in-flight requests first. """ - self.scheduler.reset_for_state_change() + self.kv_cache_manager.clear_cache_index() super().sleep(tag=tag) def load_state_dict(self, state_dict): """In-process weight load (e.g. tests). Same staleness as receive_weights().""" - self.scheduler.reset_for_state_change() + self.kv_cache_manager.clear_cache_index() super().load_state_dict(state_dict) def get_prefix_cache_stats(self) -> dict: - """Snapshot prefix-cache counters since process start (or last state reset). + """Snapshot prefix-cache counters since process start (lifetime). hit_tokens / miss_tokens are summed across every activated request; the ratio is hit_tokens / (hit_tokens + miss_tokens), guarded against div0. - Counters are reset on weight update / sleep so the ratio always reflects - the current model. + Counters are never reset, so the ratio is a stable lifetime trend line. """ hit = self.scheduler.cache_hit_tokens_total miss = self.scheduler.cache_miss_tokens_total From 3a0296e9c1840015a9014a5c87e68711c45ae6a0 Mon Sep 17 00:00:00 2001 From: "ziyi.wang" Date: Tue, 12 May 2026 09:32:40 +0800 Subject: [PATCH 4/5] move test_kv_cache_manager.py to tests/ top level KVCacheManager is a scheduler-level resource manager (refcount + LRU + prefix cache index), not a batch-invariant kernel. Its tests belong next to test_scheduler.py, not in tests/batch_invariant_ops/ which is for flex_attention / matmul / FA4 invariance checks. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/{batch_invariant_ops => }/test_kv_cache_manager.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{batch_invariant_ops => }/test_kv_cache_manager.py (100%) diff --git a/tests/batch_invariant_ops/test_kv_cache_manager.py b/tests/test_kv_cache_manager.py similarity index 100% rename from tests/batch_invariant_ops/test_kv_cache_manager.py rename to tests/test_kv_cache_manager.py From 2ada2699d6c1592866e0b0e523495b73a5bd9eee Mon Sep 17 00:00:00 2001 From: "ziyi.wang" Date: Tue, 12 May 2026 09:47:53 +0800 Subject: [PATCH 5/5] split plan_prefix_cache into compute + count, cache hashes on request MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Block hashes are a pure function of (token_ids, page_size, _SEED), so recomputing them every scheduling iteration for a requeued request is wasted work — gets noticeable at long contexts (page_size=16 + 32K prompts: ~1ms per replan × 100 retries = 100ms). Split the API: - compute_block_hashes(token_ids) → list[int] (stateless, deterministic) - count_prefix_hits(block_hashes) → int (cheap O(N) dict lookups) Scheduler caches the hashes on InferenceRequest.prefix_block_hashes after the first compute; subsequent rounds only re-run count_prefix_hits. The field already existed for mark_blocks_filled, so no new state is added. preempt() already clears it (input_ids_list grows on preempt, hashes need to be recomputed against the new sequence). plan_prefix_cache (the convenience wrapper) is removed entirely — no production caller after this change, so keeping it as a test-only wrapper would just hide that tests aren't exercising the real path. Addresses Gemini review comment on plan-per-iteration cost. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/test_kv_cache_manager.py | 136 ++++++++++-------- tests/test_scheduler.py | 6 +- .../batch_invariant_ops/kv_cache_context.py | 67 ++++----- vexact/core/request.py | 9 +- vexact/core/scheduler.py | 21 ++- 5 files changed, 134 insertions(+), 105 deletions(-) diff --git a/tests/test_kv_cache_manager.py b/tests/test_kv_cache_manager.py index acbbdc9..7e21ea8 100644 --- a/tests/test_kv_cache_manager.py +++ b/tests/test_kv_cache_manager.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unit tests for KVCacheManager: refcounting, LRU eviction, and prefix cache plan/commit/mark.""" +"""Unit tests for KVCacheManager: refcounting, LRU eviction, and prefix cache compute/commit/mark.""" import pytest @@ -48,36 +48,50 @@ def test_construction_disabled(mgr): assert m.prefix_cache_enabled is False -# ---------- plan_prefix_cache ---------- +# ---------- compute_block_hashes / count_prefix_hits ---------- -def test_plan_empty_tokens(mgr): +def test_compute_empty_tokens(mgr): m = mgr() - assert m.plan_prefix_cache([]) == ([], 0) + assert m.compute_block_hashes([]) == [] + assert m.count_prefix_hits([]) == 0 -def test_plan_only_partial_block(mgr): +def test_compute_only_partial_block(mgr): # page_size=4, 3 tokens → 0 full blocks m = mgr() - assert m.plan_prefix_cache([1, 2, 3]) == ([], 0) + assert m.compute_block_hashes([1, 2, 3]) == [] -def test_plan_when_disabled(mgr): +def test_compute_when_disabled(mgr): m = mgr(enable_prefix_cache=False) - assert m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8]) == ([], 0) + assert m.compute_block_hashes([1, 2, 3, 4, 5, 6, 7, 8]) == [] -def test_plan_cold_cache_returns_hashes_but_zero_hits(mgr): +def test_compute_cold_cache_returns_hashes_but_zero_hits(mgr): m = mgr() - hashes, n_cached = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8]) + hashes = m.compute_block_hashes([1, 2, 3, 4, 5, 6, 7, 8]) assert len(hashes) == 2 - assert n_cached == 0 + assert m.count_prefix_hits(hashes) == 0 + + +def test_compute_is_stateless(mgr): + # Contract: compute_block_hashes is a pure function of (token_ids, page_size, + # _SEED). The scheduler caches the result on the request — that's only safe if + # changing cache state doesn't change the hashes. + m = mgr() + h_first = m.compute_block_hashes([1, 2, 3, 4, 5, 6, 7, 8]) + bids = m.commit_prefix_plan(h_first, 0, 8) + m.mark_blocks_filled(bids, h_first) + m.free_blocks(bids) + # Same input → same hashes, even after cache state changed. + assert m.compute_block_hashes([1, 2, 3, 4, 5, 6, 7, 8]) == h_first def test_chain_hash_diverges_on_content_difference(mgr): m = mgr() - h_a, _ = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8]) - h_b, _ = m.plan_prefix_cache([1, 2, 3, 4, 9, 9, 9, 9]) + h_a = m.compute_block_hashes([1, 2, 3, 4, 5, 6, 7, 8]) + h_b = m.compute_block_hashes([1, 2, 3, 4, 9, 9, 9, 9]) assert h_a[0] == h_b[0] # same first block content assert h_a[1] != h_b[1] # diverged from block 1 onward @@ -87,16 +101,16 @@ def test_chain_hash_diverges_on_content_difference(mgr): def test_commit_cold_allocates_fresh(mgr): m = mgr() - hashes, n_cached = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8]) - bids = m.commit_prefix_plan(hashes, n_cached, 8) + hashes = m.compute_block_hashes([1, 2, 3, 4, 5, 6, 7, 8]) + bids = m.commit_prefix_plan(hashes, 0, 8) assert bids == [0, 1] assert m.num_allocated_blocks() == 2 def test_commit_partial_only_takes_one_block(mgr): m = mgr() - hashes, n_cached = m.plan_prefix_cache([1, 2, 3]) - bids = m.commit_prefix_plan(hashes, n_cached, 3) + hashes = m.compute_block_hashes([1, 2, 3]) + bids = m.commit_prefix_plan(hashes, 0, 3) assert bids == [0] assert m.num_allocated_blocks() == 1 @@ -104,16 +118,16 @@ def test_commit_partial_only_takes_one_block(mgr): def test_commit_full_plus_partial(mgr): # 1 full + 1 partial = 2 blocks total m = mgr() - hashes, n_cached = m.plan_prefix_cache([1, 2, 3, 4, 5]) + hashes = m.compute_block_hashes([1, 2, 3, 4, 5]) assert len(hashes) == 1 - bids = m.commit_prefix_plan(hashes, n_cached, 5) + bids = m.commit_prefix_plan(hashes, 0, 5) assert len(bids) == 2 def test_commit_oom_rollback_keeps_pool_intact(mgr): m = mgr(max_blocks=2) - hashes, n_cached = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) - assert m.commit_prefix_plan(hashes, n_cached, 12) is None + hashes = m.compute_block_hashes([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) + assert m.commit_prefix_plan(hashes, 0, 12) is None assert m.num_free_blocks() == 2 assert m.num_allocated_blocks() == 0 @@ -122,14 +136,15 @@ def test_commit_oom_releases_cached_increfs(mgr): # Cached blocks incref'd during commit must be decref'd back on OOM, not stuck. m = mgr(max_blocks=2) # First request fills the cache index (2 full blocks). - h1, n1 = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8]) - bids1 = m.commit_prefix_plan(h1, n1, 8) + h1 = m.compute_block_hashes([1, 2, 3, 4, 5, 6, 7, 8]) + bids1 = m.commit_prefix_plan(h1, 0, 8) m.mark_blocks_filled(bids1, h1) m.free_blocks(bids1) # Second request: same prefix (full hit) + a partial last block. OOM because # the partial needs a fresh block but the only 2 blocks just got refcounted. - h2, n2 = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8, 99, 99, 99]) + h2 = m.compute_block_hashes([1, 2, 3, 4, 5, 6, 7, 8, 99, 99, 99]) + n2 = m.count_prefix_hits(h2) assert n2 == 2 assert m.commit_prefix_plan(h2, n2, 11) is None assert m.num_free_blocks() == 2 @@ -143,16 +158,16 @@ def test_commit_oom_releases_cached_increfs(mgr): def test_mark_blocks_filled_records_full_blocks_only(mgr): m = mgr() - hashes, n_cached = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7]) # 1 full + partial - bids = m.commit_prefix_plan(hashes, n_cached, 7) + hashes = m.compute_block_hashes([1, 2, 3, 4, 5, 6, 7]) # 1 full + partial + bids = m.commit_prefix_plan(hashes, 0, 7) m.mark_blocks_filled(bids, hashes) assert m.num_cached_blocks() == 1 # only the full block stamped def test_mark_blocks_filled_idempotent(mgr): m = mgr() - hashes, n_cached = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8]) - bids = m.commit_prefix_plan(hashes, n_cached, 8) + hashes = m.compute_block_hashes([1, 2, 3, 4, 5, 6, 7, 8]) + bids = m.commit_prefix_plan(hashes, 0, 8) m.mark_blocks_filled(bids, hashes) state1 = (dict(m._block_hash_to_id), dict(m._block_id_to_hash)) @@ -166,14 +181,14 @@ def test_mark_blocks_filled_stamps_misses_after_partial_hit(mgr): # Half-hit: first block hits cache, second block is fresh. The fresh block # must end up correctly stamped using the precomputed chain hash. m = mgr() - h1, n1 = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8]) - bids1 = m.commit_prefix_plan(h1, n1, 8) + h1 = m.compute_block_hashes([1, 2, 3, 4, 5, 6, 7, 8]) + bids1 = m.commit_prefix_plan(h1, 0, 8) m.mark_blocks_filled(bids1, h1) m.free_blocks(bids1) - h2, n2 = m.plan_prefix_cache([1, 2, 3, 4, 9, 9, 9, 9]) - assert n2 == 1 # only first block hits - bids2 = m.commit_prefix_plan(h2, n2, 8) + h2 = m.compute_block_hashes([1, 2, 3, 4, 9, 9, 9, 9]) + assert m.count_prefix_hits(h2) == 1 # only first block hits + bids2 = m.commit_prefix_plan(h2, 1, 8) m.mark_blocks_filled(bids2, h2) assert m._block_id_to_hash[bids2[1]] == h2[1] @@ -194,16 +209,17 @@ def test_full_cycle_miss_then_hit(mgr): m = mgr() toks = [1, 2, 3, 4, 5, 6, 7, 8] - h1, n1 = m.plan_prefix_cache(toks) - assert n1 == 0 - bids1 = m.commit_prefix_plan(h1, n1, len(toks)) + h1 = m.compute_block_hashes(toks) + assert m.count_prefix_hits(h1) == 0 + bids1 = m.commit_prefix_plan(h1, 0, len(toks)) m.mark_blocks_filled(bids1, h1) m.free_blocks(bids1) assert m.num_free_blocks() == 8 assert m.num_cached_blocks() == 2 # hashes survive free # Same content again — full hit on the same physical blocks. - h2, n2 = m.plan_prefix_cache(toks) + h2 = m.compute_block_hashes(toks) + n2 = m.count_prefix_hits(h2) assert h2 == h1 assert n2 == 2 bids2 = m.commit_prefix_plan(h2, n2, len(toks)) @@ -214,20 +230,20 @@ def test_partial_last_block_does_not_block_prefix_hits(mgr): # Two requests share the same first full block but differ in the partial tail — # the full block should still hit. m = mgr() - h1, n1 = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7]) - bids1 = m.commit_prefix_plan(h1, n1, 7) + h1 = m.compute_block_hashes([1, 2, 3, 4, 5, 6, 7]) + bids1 = m.commit_prefix_plan(h1, 0, 7) m.mark_blocks_filled(bids1, h1) m.free_blocks(bids1) - h2, n2 = m.plan_prefix_cache([1, 2, 3, 4, 99, 99, 99]) - assert n2 == 1 + h2 = m.compute_block_hashes([1, 2, 3, 4, 99, 99, 99]) + assert m.count_prefix_hits(h2) == 1 assert h2[0] == h1[0] def test_contiguous_run_stops_at_first_miss(mgr): # Fill cache: blocks [1..4][5..8][9..12]. m = mgr() - h1, _ = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) + h1 = m.compute_block_hashes([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) bids1 = m.commit_prefix_plan(h1, 0, 12) m.mark_blocks_filled(bids1, h1) m.free_blocks(bids1) @@ -235,20 +251,21 @@ def test_contiguous_run_stops_at_first_miss(mgr): # New: [1..4] hits, [99..] misses, [9..12] hashes ALWAYS diverge because the # chain depends on the previous block's hash. So even the "same" 3rd block # is a different chain hash → reported miss. - h2, n2 = m.plan_prefix_cache([1, 2, 3, 4, 99, 99, 99, 99, 9, 10, 11, 12]) - assert n2 == 1 + h2 = m.compute_block_hashes([1, 2, 3, 4, 99, 99, 99, 99, 9, 10, 11, 12]) + assert m.count_prefix_hits(h2) == 1 assert h2[0] == h1[0] assert h2[2] != h1[2] def test_refcount_shared_between_concurrent_requests(mgr): m = mgr() - h_a, n_a = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8]) - bids_a = m.commit_prefix_plan(h_a, n_a, 8) + h_a = m.compute_block_hashes([1, 2, 3, 4, 5, 6, 7, 8]) + bids_a = m.commit_prefix_plan(h_a, 0, 8) m.mark_blocks_filled(bids_a, h_a) # Second concurrent request hits both blocks. - h_b, n_b = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8]) + h_b = m.compute_block_hashes([1, 2, 3, 4, 5, 6, 7, 8]) + n_b = m.count_prefix_hits(h_b) bids_b = m.commit_prefix_plan(h_b, n_b, 8) assert bids_b == bids_a for bid in bids_a: @@ -287,14 +304,14 @@ def test_free_lru_oldest_first(mgr): def test_eviction_drops_hash_entry(mgr): m = mgr(max_blocks=2) # Fill cache with 2 hashed blocks. - h1, _ = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8]) + h1 = m.compute_block_hashes([1, 2, 3, 4, 5, 6, 7, 8]) bids1 = m.commit_prefix_plan(h1, 0, 8) m.mark_blocks_filled(bids1, h1) m.free_blocks(bids1) assert m.num_cached_blocks() == 2 # Allocate 2 fresh blocks for unrelated content → both old hashes get evicted. - h2, _ = m.plan_prefix_cache([100, 101, 102, 103, 104, 105, 106, 107]) + h2 = m.compute_block_hashes([100, 101, 102, 103, 104, 105, 106, 107]) bids2 = m.commit_prefix_plan(h2, 0, 8) assert bids2 is not None for old in h1: @@ -306,8 +323,8 @@ def test_eviction_drops_hash_entry(mgr): def test_clear_cache_index_drops_hashes_only(mgr): m = mgr() - h, n = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8]) - bids = m.commit_prefix_plan(h, n, 8) + h = m.compute_block_hashes([1, 2, 3, 4, 5, 6, 7, 8]) + bids = m.commit_prefix_plan(h, 0, 8) m.mark_blocks_filled(bids, h) assert m.num_cached_blocks() == 2 assert m.num_allocated_blocks() == 2 @@ -316,9 +333,9 @@ def test_clear_cache_index_drops_hashes_only(mgr): assert m.num_cached_blocks() == 0 assert m.num_allocated_blocks() == 2 # refcounts untouched - # New plan with same content sees a cold cache. - _, n2 = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8]) - assert n2 == 0 + # New hashes for same content see a cold cache. + h2 = m.compute_block_hashes([1, 2, 3, 4, 5, 6, 7, 8]) + assert m.count_prefix_hits(h2) == 0 # ---------- disabled prefix cache ---------- @@ -326,9 +343,10 @@ def test_clear_cache_index_drops_hashes_only(mgr): def test_disabled_cache_behaves_like_plain_allocator(mgr): m = mgr(enable_prefix_cache=False) - h, n = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8]) - assert h == [] and n == 0 - bids = m.commit_prefix_plan(h, n, 8) + h = m.compute_block_hashes([1, 2, 3, 4, 5, 6, 7, 8]) + assert h == [] + assert m.count_prefix_hits(h) == 0 + bids = m.commit_prefix_plan(h, 0, 8) assert len(bids) == 2 m.mark_blocks_filled(bids, h) # no-op @@ -336,8 +354,8 @@ def test_disabled_cache_behaves_like_plain_allocator(mgr): # Second identical request still misses. m.free_blocks(bids) - _, n2 = m.plan_prefix_cache([1, 2, 3, 4, 5, 6, 7, 8]) - assert n2 == 0 + h2 = m.compute_block_hashes([1, 2, 3, 4, 5, 6, 7, 8]) + assert m.count_prefix_hits(h2) == 0 # ---------- allocate_blocks (decode-path) ---------- diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 8de6e2c..184b0ee 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -277,7 +277,8 @@ def test_scheduler_update_processes_outputs_and_finishes_requests(kv_cache_manag # Prepare requests by allocating KV cache and setting up state for request in infer_requests: - block_hashes, num_prefix_hit_blocks = scheduler._kv_cache_manager.plan_prefix_cache(request.input_ids_list) + block_hashes = scheduler._kv_cache_manager.compute_block_hashes(request.input_ids_list) + num_prefix_hit_blocks = scheduler._kv_cache_manager.count_prefix_hits(block_hashes) scheduler._activate_request(request, block_hashes, num_prefix_hit_blocks) # Simulate that prefill is already complete so update() takes the decode path. # _activate_request sets num_computed_tokens to the cache-derived value (0 @@ -461,7 +462,8 @@ def test_scheduler_preempted_request_folds_back_generated_tokens(kv_cache_manage req = InferenceRequest(request_id="req-preempt", generation_config=gen_config, input_ids_list=[10, 20, 30]) # Set up request as if it had been running and generated 2 tokens - block_hashes, num_prefix_hit_blocks = scheduler._kv_cache_manager.plan_prefix_cache(req.input_ids_list) + block_hashes = scheduler._kv_cache_manager.compute_block_hashes(req.input_ids_list) + num_prefix_hit_blocks = scheduler._kv_cache_manager.count_prefix_hits(block_hashes) scheduler._activate_request(req, block_hashes, num_prefix_hit_blocks) req.generated_tokens = [40, 50] req.generated_logprobs = [0.1, 0.2] diff --git a/vexact/batch_invariant_ops/kv_cache_context.py b/vexact/batch_invariant_ops/kv_cache_context.py index 07f46a6..684da26 100644 --- a/vexact/batch_invariant_ops/kv_cache_context.py +++ b/vexact/batch_invariant_ops/kv_cache_context.py @@ -147,13 +147,15 @@ class KVCacheManager: (least-recently-released first) and are reclaimed for new content when the pool runs out. 2. Prefix cache via chained block-content hashing. When `prefix_cache_enabled` - is True, the scheduler calls `plan_prefix_cache` per request to compute - per-full-block chain hashes and the leading hit count. After prefill - completes, `mark_blocks_filled` stamps the just-computed hashes onto the - allocated blocks so future requests with the same prefix can hit them. + is True, the scheduler calls `compute_block_hashes` once per request + (deterministic, cached on the request) and `count_prefix_hits` every + scheduling iteration (cheap, depends on current cache state). After + prefill completes, `mark_blocks_filled` stamps the just-computed hashes + onto the allocated blocks so future requests with the same prefix hit. Lifecycle of a cached block: - A) plan_prefix_cache(token_ids) → (block_hashes, num_prefix_hit_blocks) + A) compute_block_hashes(token_ids) → block_hashes + count_prefix_hits(block_hashes) → num_prefix_hit_blocks B) commit_prefix_plan(block_hashes, num_prefix_hit_blocks, total_tokens) → incref hit blocks; take fresh blocks (LRU-oldest, evicting their old hash if any) for the remaining full blocks and the partial last block @@ -275,42 +277,41 @@ def allocate_blocks(self, total_tokens: int, num_current_blocks: int = 0) -> lis allocated.append(bid) return allocated - def plan_prefix_cache(self, token_ids: list[int]) -> tuple[list[int], int]: - """Compute the prefix-cache plan for a token sequence. Stateless — no allocation. + def compute_block_hashes(self, token_ids: list[int]) -> list[int]: + """Chain hashes for all FULL blocks of `token_ids`. Stateless, deterministic. - For each FULL block (last partial excluded) compute the chained hash and - check whether it's currently in the cache index. The "hit" run stops at - the first miss — a later "hit" would still require prefill to fill the - gap, which would overwrite the cached blocks, so we don't try to exploit it. + Result is purely a function of (token_ids, page_size, _SEED) — does not + consult the cache index. The scheduler caches the result on the request + so requeued requests don't pay this cost every iteration. - Returns (block_hashes, num_prefix_hit_blocks): - - block_hashes: one chain hash per full block (`len(token_ids) // page_size` - entries). Empty when prefix cache is disabled. - - num_prefix_hit_blocks: length of the leading contiguous hit run. - - The scheduler reuses block_hashes verbatim in `commit_prefix_plan` and - `mark_blocks_filled` — neither recomputes hashes. + Returns empty when prefix cache is disabled or `len(token_ids) < page_size`. """ if not self.prefix_cache_enabled: - return [], 0 + return [] page_size = self.page_size num_full = len(token_ids) // page_size block_hashes: list[int] = [] - num_prefix_hit_blocks = 0 - contiguous = True prev_hash = self._SEED - for i in range(num_full): block_hash = hash((prev_hash, tuple(token_ids[i * page_size : (i + 1) * page_size]))) block_hashes.append(block_hash) - if contiguous and block_hash in self._block_hash_to_id: - num_prefix_hit_blocks += 1 - else: - contiguous = False prev_hash = block_hash + return block_hashes - return block_hashes, num_prefix_hit_blocks + def count_prefix_hits(self, block_hashes: list[int]) -> int: + """Number of leading contiguous hits in the current cache index. + + Cheap: O(N) dict-membership checks against `_block_hash_to_id`. Stops at the + first miss — a later "hit" would still require prefill to fill the gap (which + would overwrite the cached blocks), so we don't try to exploit scattered hits. + """ + n = 0 + for h in block_hashes: + if h not in self._block_hash_to_id: + break + n += 1 + return n def commit_prefix_plan( self, @@ -318,16 +319,16 @@ def commit_prefix_plan( num_prefix_hit_blocks: int, total_tokens: int, ) -> list[int] | None: - """Commit a plan from `plan_prefix_cache`: incref hit blocks, take fresh for the rest. + """Commit the plan: incref hit blocks, take fresh for the rest. - Single-threaded scheduler ⇒ cache state can't change between plan and commit, + Single-threaded scheduler ⇒ cache state can't change between count and commit, so the leading `num_prefix_hit_blocks` lookups via `_block_hash_to_id` are guaranteed to hit (the chain hash is unique to the content). Args: - block_hashes: hashes from plan_prefix_cache (one per full block). May be - empty when prefix cache is disabled or `total_tokens < page_size`. - num_prefix_hit_blocks: leading hits from plan_prefix_cache. + block_hashes: hashes from compute_block_hashes (one per full block). + May be empty when prefix cache is disabled or `total_tokens < page_size`. + num_prefix_hit_blocks: leading hits from count_prefix_hits. total_tokens: total tokens this request needs coverage for. Determines whether a partial last block is needed (always fresh). @@ -366,7 +367,7 @@ def free_blocks(self, block_ids: list[int]): # ---- prefix cache management ---- def mark_blocks_filled(self, block_ids: list[int], block_hashes: list[int]) -> None: - """Stamp full blocks with the chain hashes computed by `plan_prefix_cache`. + """Stamp full blocks with the chain hashes computed by `compute_block_hashes`. Called every decode step (and once at prefill completion) — the fast-path check on the last full block (which is the last to be stamped, and refcounted diff --git a/vexact/core/request.py b/vexact/core/request.py index cd3c03e..f040c66 100644 --- a/vexact/core/request.py +++ b/vexact/core/request.py @@ -55,10 +55,11 @@ class InferenceRequest: block_ids: list[int] = field(default_factory=list) - # Chain hashes for each FULL block of input_ids_list, populated by the - # scheduler at activation via plan_prefix_cache. Reused by mark_blocks_filled - # so no module recomputes hashes. Reset on preempt — input_ids_list grows to - # include generated tokens, so the chain must be replanned on reactivation. + # Chain hashes for each FULL block of input_ids_list. Populated lazily by the + # scheduler via compute_block_hashes (deterministic per token sequence, so we + # cache here to avoid rehashing if the request gets requeued). Consumed by + # commit_prefix_plan and mark_blocks_filled. Reset on preempt — input_ids_list + # grows to include generated tokens, so the chain must be recomputed. prefix_block_hashes: list[int] = field(default_factory=list) # State: Outputs diff --git a/vexact/core/scheduler.py b/vexact/core/scheduler.py index 3daacf9..d011ded 100644 --- a/vexact/core/scheduler.py +++ b/vexact/core/scheduler.py @@ -164,11 +164,17 @@ def schedule(self) -> SchedulerOutput: break try: - # Plan prefix-cache hits once. block_hashes flow through to - # _activate_request → commit_prefix_plan (no rehash) and later to - # mark_blocks_filled (also no rehash). Eliminates the triple-hash - # that the original peek/allocate/stamp paths had. - block_hashes, num_prefix_hit_blocks = self._kv_cache_manager.plan_prefix_cache(request.input_ids_list) + # Hashes are deterministic per token sequence — cache them on the + # request so a requeued request (budget-rejected this round) doesn't + # re-hash next round. Only `count_prefix_hits` (a few dict lookups) + # has to re-run, since the cache index may have changed. + if not request.prefix_block_hashes: + request.prefix_block_hashes = self._kv_cache_manager.compute_block_hashes( + request.input_ids_list + ) + num_prefix_hit_blocks = self._kv_cache_manager.count_prefix_hits( + request.prefix_block_hashes + ) cached_tokens = num_prefix_hit_blocks * self._kv_cache_manager.page_size # Pass num_comp explicitly so we don't mutate the request before # we know it'll be activated. @@ -176,7 +182,7 @@ def schedule(self) -> SchedulerOutput: request, available_token_budget, num_comp=cached_tokens ) if available_token_budget >= tokens: - self._activate_request(request, block_hashes, num_prefix_hit_blocks) + self._activate_request(request, request.prefix_block_hashes, num_prefix_hit_blocks) # Set scheduling fields and add to the active batch available_token_budget -= tokens @@ -278,7 +284,8 @@ def _activate_request( """Activate a request: commit the prefix-cache plan, attach hashes, bump num_computed_tokens past any cached prefix so prefill skips it. - The plan (block_hashes + num_prefix_hit_blocks) comes from `plan_prefix_cache`; + The plan (block_hashes + num_prefix_hit_blocks) comes from the two-step + `compute_block_hashes` + `count_prefix_hits` pattern in schedule(); we never recompute hashes here. With prefix cache disabled (pp_size > 1), block_hashes is empty and every block is a fresh allocation — same behaviour as the original allocator.