feat: prompt prefix caching for server endpoints #12

Open
eloe wants to merge 7 commits into main from feature/prompt-cache

Conversation

@eloe
Owner

@eloe eloe commented Apr 6, 2026

Summary

Wire PromptCacheState into server endpoints. TurboQuant-compatible prefix reuse via trim(). Probe filter, 50% prefix match threshold, cache_key routing, cached_tokens reporting. 3.9x speedup on warm turns.

eloe and others added 4 commits April 5, 2026 12:52
Wire the existing PromptCacheState from generate.py into both
/v1/responses and /chat/completions endpoints. On repeated requests,
the KV cache from the previous generation is reused for matching
prefix tokens, skipping redundant prefill computation.

This is especially impactful for agentic workflows where the system
prompt (~15K tokens) is the same across requests — only new user
messages need prefilling, reducing latency from ~35s to ~2-3s on
follow-up turns.

Changes:
- Import PromptCacheState from generate.py
- Add get_prompt_cache_state() keyed by model name
- Pass prompt_cache_state to all 4 generate/stream_generate call sites
- Clear prompt cache on model unload

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
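
The prefix reuse described in this commit can be illustrated with a small sketch. It is not the code from generate.py: `common_prefix_len` is a hypothetical stand-in for `PromptCacheState.find_prefix_length`, whose implementation is not shown in this PR, and the token ids are made up.

```python
from typing import Sequence


def common_prefix_len(cached: Sequence[int], new: Sequence[int]) -> int:
    """Return how many leading token ids the two sequences share."""
    n = 0
    for a, b in zip(cached, new):
        if a != b:
            break
        n += 1
    return n


# Token ids of the previous request (system prompt + first user turn).
cached_ids = [1, 10, 11, 12, 20, 21]
# The follow-up request repeats them and appends a new user message.
new_ids = [1, 10, 11, 12, 20, 21, 30, 31]

prefix = common_prefix_len(cached_ids, new_ids)
tokens_to_prefill = new_ids[prefix:]  # only these still need prefilling
```

With a ~15K-token system prompt the same idea leaves only the new user message to prefill, which is where the latency reduction comes from.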

The prompt cache prefix reuse code assumed all cache layers have
mx.array keys with .shape. TurboQuantKVCache stores keys as
TurboQuantMSEState objects which don't support slicing.

Now checks for .shape before attempting to trim, and falls back
to updating just the offset for quantized cache layers.

Fixes: 'TurboQuantMSEState' object has no attribute 'shape' error
when prompt caching is used with --kv-bits 3.5 --kv-quant-scheme turboquant.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

…_tokens

Improve prompt prefix caching to work with all KV cache types:

- Use trim(n) method for prefix reuse instead of manual array slicing.
  Works with both standard KVCache and TurboQuantKVCache.
- Accept optional cache_key parameter for per-session cache routing
  (supports OpenClaw prompt_cache_key and Hermes patterns).
- Add cached_tokens field to GenerationResult, populated from
  reused_prefix_len for cache hit reporting.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
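
A rough sketch of why `trim(n)` is the more portable hook: the caller only needs each layer to expose `offset` and `trim`, not the layout of its keys. `OffsetOnlyCache` below is a hypothetical stand-in, not `KVCache` or `TurboQuantKVCache`, and `trim_to_prefix` is an illustrative helper name.

```python
class OffsetOnlyCache:
    """Hypothetical cache layer that exposes only offset and trim(n)."""

    def __init__(self, offset: int):
        self.offset = offset

    def trim(self, n: int) -> int:
        # Drop the last n positions; never trim past the start.
        n = min(n, self.offset)
        self.offset -= n
        return n


def trim_to_prefix(layers, prefix_len: int) -> None:
    """Trim every layer that holds more than prefix_len positions."""
    for layer in layers:
        extra = layer.offset - prefix_len
        if extra > 0:
            layer.trim(extra)


# One layer still holds 20 generated tokens beyond the shared prefix.
layers = [OffsetOnlyCache(offset=120), OffsetOnlyCache(offset=100)]
trim_to_prefix(layers, prefix_len=100)
```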

…tches

Two fixes for prompt cache stability:

1. Require >= 50% prefix match (min 512 tokens) before reusing KV cache.
   Short matches on quantized caches (TurboQuant) produce corrupted
   repetitive output because trim() only adjusts offset without clearing
   stale quantized data.

2. Skip cache save for requests < 1024 tokens. Agent frameworks send
   short probe/capability-check requests that would evict the valuable
   cached system prompt KV state.

Also uses trim() method for cache trimming (TurboQuant compatible)
instead of manual array slicing.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
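
The two thresholds in this commit can be written as small predicates. This is a sketch under the numbers stated above (50% of cached tokens, minimum 512, save only at 1024 tokens or more); the function and constant names are illustrative, not the ones used in generate.py.

```python
MIN_REUSE_TOKENS = 512   # floor on the matched prefix
MIN_SAVE_TOKENS = 1024   # shorter prompts are treated as probes


def should_reuse(prefix_len: int, cached_total: int, prompt_len: int) -> bool:
    """Reuse only if the match covers >= 50% of the cached tokens (min 512)
    and at least one new token is left to prefill."""
    min_reuse = max(MIN_REUSE_TOKENS, cached_total // 2)
    return min_reuse <= prefix_len < prompt_len


def should_save(prompt_len: int) -> bool:
    """Skip saving short probe requests so they cannot evict a large cache."""
    return prompt_len >= MIN_SAVE_TOKENS
```

For example, a 15,000-token match against a 15,200-token cache qualifies, while a 300-token probe neither reuses nor overwrites the cached state.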

Copilot AI left a comment

Pull request overview

This PR wires prompt-prefix KV-cache reuse (via PromptCacheState) into the HTTP server endpoints and updates generation to support safer cache trimming/reuse for TurboQuant-style quantized caches, aiming to reduce prefill work on warm turns.

Changes:

  • Add server-side prompt cache state management and pass prompt_cache_state into generate() / stream_generate() calls.
  • Implement prefix-reuse gating (50% / min-512) and TurboQuant-compatible cache trimming via trim()/offset.
  • Add cached_tokens to GenerationResult and introduce a minimum prompt length threshold before saving cache state.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.

| File | Description |
| --- | --- |
| mlx_vlm/server.py | Introduces global prompt-cache state and enables reuse in /responses and /chat/completions paths. |
| mlx_vlm/generate.py | Adds cached_tokens, implements prefix matching threshold + TurboQuant trim handling, and filters when cache state is saved. |

Comment thread mlx_vlm/server.py Outdated
Comment on lines +126 to +144
# Prompt cache: reuse KV state across requests with the same prompt prefix.
# Keyed by model name — one PromptCacheState per loaded model.
_prompt_cache_states: dict[str, PromptCacheState] = {}


def get_prompt_cache_state(
    model_name: str,
    cache_key: Optional[str] = None,
) -> PromptCacheState:
    """Get or create a PromptCacheState for the given model and cache key.

    Args:
        model_name: The model identifier.
        cache_key: Optional routing key (e.g., prompt_cache_key from request).
    """
    key = f"{model_name}::{cache_key}" if cache_key else model_name
    if key not in _prompt_cache_states:
        _prompt_cache_states[key] = PromptCacheState()
    return _prompt_cache_states[key]

Copilot AI Apr 6, 2026

_prompt_cache_states is a process-global mutable cache, and the default key is just the model name. That means concurrent requests (or different users) will reuse/overwrite the same PromptCacheState, which can corrupt generations and can leak prior users’ prompt/response context via KV reuse. Scope the cache state per conversation/user (e.g., require a request-provided cache key or derive one from an authenticated user/session), and guard access with a per-key lock or avoid sharing mutable cache objects across overlapping requests.
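
One possible shape for the scoping suggested here, as a sketch only: keyless requests get no shared state, and each key carries its own lock. `CacheSlot` and `get_slot` are hypothetical names, and the `state` dict is a placeholder for a `PromptCacheState`.

```python
import threading
from typing import Dict, Optional


class CacheSlot:
    """Hypothetical container: one cache state plus the lock guarding it."""

    def __init__(self) -> None:
        self.lock = threading.Lock()
        self.state: dict = {}  # placeholder for a PromptCacheState


_slots: Dict[str, CacheSlot] = {}
_registry_lock = threading.Lock()


def get_slot(model_name: str, cache_key: Optional[str]) -> Optional[CacheSlot]:
    """Return the slot for (model, cache_key), or None when no key is given.

    Returning None for keyless requests means they never share KV state.
    """
    if not cache_key:
        return None
    key = f"{model_name}::{cache_key}"
    with _registry_lock:  # protects the registry dict itself
        slot = _slots.get(key)
        if slot is None:
            slot = _slots[key] = CacheSlot()
        return slot


slot_a = get_slot("model-x", "session-a")
slot_b = get_slot("model-x", "session-b")
```

A request handler would then wrap generation in `with slot.lock:` so that two overlapping requests with the same key cannot mutate the same cache at once.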

Comment thread mlx_vlm/server.py
Comment on lines +924 to +932
# Stream text deltas (with prompt cache reuse)
cache_state = get_prompt_cache_state(openai_request.model)
token_iterator = stream_generate(
    model=model,
    processor=processor,
    prompt=formatted_prompt,
    image=images,
    vision_cache=model_cache.get("vision_cache"),
    prompt_cache_state=cache_state,

Copilot AI Apr 6, 2026

get_prompt_cache_state() supports a cache_key, but the endpoints always call it with only the model name. This effectively enables cross-conversation cache sharing and also doesn’t implement the PR’s stated “cache_key routing”. Plumb a request field (e.g., prompt_cache_key) through OpenAIRequest/ChatRequest and pass it here, or drop the cache_key parameter if routing isn’t intended.

Comment thread mlx_vlm/server.py
Comment on lines +1150 to 1160
# Use stream_generate with prompt cache reuse
cache_state = get_prompt_cache_state(request.model)
token_iterator = stream_generate(
    model=model,
    processor=processor,
    prompt=formatted_prompt,
    image=images,
    audio=audio,
    vision_cache=model_cache.get("vision_cache"),
    prompt_cache_state=cache_state,
    **generation_kwargs,

Copilot AI Apr 6, 2026

Prompt cache reuse is enabled here, but the API usage payload doesn’t expose the new cached_tokens metric (even though GenerationResult now tracks it). If clients are expected to observe cache effectiveness (“cached_tokens reporting” per PR description), extend UsageStats/OpenAIUsage (or add an optional field) and populate it from chunk.cached_tokens / gen_result.cached_tokens.
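
A sketch of what exposing the metric could look like, assuming the OpenAI-style nesting `prompt_tokens_details.cached_tokens`. The `UsageStats` dataclass here is hypothetical and is not the model defined in server.py, which this PR does not show.

```python
from dataclasses import asdict, dataclass, field


@dataclass
class PromptTokensDetails:
    cached_tokens: int = 0


@dataclass
class UsageStats:
    """Hypothetical usage model, not the one defined in server.py."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    prompt_tokens_details: PromptTokensDetails = field(
        default_factory=PromptTokensDetails
    )


def usage_from_result(
    prompt_tokens: int, generation_tokens: int, cached_tokens: int
) -> UsageStats:
    """Copy the counters from a generation result into the usage payload."""
    return UsageStats(
        prompt_tokens=prompt_tokens,
        completion_tokens=generation_tokens,
        total_tokens=prompt_tokens + generation_tokens,
        prompt_tokens_details=PromptTokensDetails(cached_tokens=cached_tokens),
    )


payload = asdict(usage_from_result(15400, 120, 15000))
```

Because the new field defaults to zero, clients that ignore it see the same payload as before.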

Comment thread mlx_vlm/generate.py Outdated
Comment on lines +667 to +680
# corrupted output because trim() only adjusts the offset without
# clearing stale quantized data.
min_reuse = max(512, cached_total // 2)
if prefix_len >= min_reuse and prefix_len < input_ids.shape[1]:

Copilot AI Apr 6, 2026

The new prompt-cache behaviors (>=50% prefix reuse threshold, TurboQuant trim path, and the >=1024-token probe filter for saving cache state) aren’t covered by unit tests. Add tests that (1) verify reuse is skipped below the threshold, (2) verify reuse occurs above it and cached_tokens is reported, and (3) verify cache_state.update is only called when the prompt length meets the minimum.

Comment thread mlx_vlm/generate.py
Comment on lines 772 to 776
prompt_tokens=total_prompt_tokens,
generation_tokens=n + 1,
total_tokens=total_prompt_tokens + n + 1,
cached_tokens=reused_prefix_len,
prompt_tps=prompt_tps,

Copilot AI Apr 6, 2026

cached_tokens is set on the streamed GenerationResult here, but the terminal (post-finalize) GenerationResult later in this function doesn’t include cached_tokens, so callers that rely on the final result (including generate(), which returns the last response) will observe cached_tokens=0 even when reuse happened. Add cached_tokens=reused_prefix_len to the final yield as well to keep stats consistent.
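
The effect is easy to reproduce in miniature. In this sketch `Result` is a hypothetical, trimmed-down stand-in for `GenerationResult`, and `last_result` mimics a caller that keeps only the final streamed item.

```python
from dataclasses import dataclass
from typing import Iterator, List


@dataclass
class Result:
    """Hypothetical, trimmed-down stand-in for GenerationResult."""

    text: str = ""
    cached_tokens: int = 0


def stream(tokens: List[str], reused_prefix_len: int) -> Iterator[Result]:
    for tok in tokens:
        yield Result(text=tok, cached_tokens=reused_prefix_len)
    # Terminal result: it must repeat cached_tokens, otherwise the
    # dataclass default of 0 is what the caller sees.
    yield Result(text="", cached_tokens=reused_prefix_len)


def last_result(tokens: List[str], reused_prefix_len: int) -> Result:
    """Mimic a caller that returns only the final streamed item."""
    result = Result()
    for result in stream(tokens, reused_prefix_len):
        pass
    return result


final = last_result(["Hi", "!"], reused_prefix_len=2048)
```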

eloe and others added 2 commits April 6, 2026 20:33
Prompt cache fixes:
- Wrap cache reuse in try/except — if trim() fails (e.g.,
  broadcast_shapes from stale KV state), invalidate cache and
  fall back to fresh generation instead of crashing
- Add cache shape validation before generate_step — detect
  seq length mismatches early and rebuild cache
- Add PromptCacheState.invalidate() method to clear stale state

Error handling:
- Sanitize all error messages sent to clients — no more raw
  MLX errors like "[broadcast_shapes] Shapes (12716) and (1840)
  cannot be broadcast" leaking through to Telegram/API users
- Streaming and non-streaming paths both sanitized
- Full errors still logged server-side for debugging

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
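
The fallback and sanitization described in this commit might look roughly like the sketch below. All names are illustrative: `CacheState` stands in for `PromptCacheState`, `run_with_cache` and `run_fresh` are hypothetical callables, and the generic message wording is an assumption.

```python
import logging

logger = logging.getLogger("prompt_cache")

GENERIC_ERROR = "Generation failed. Please retry."  # hypothetical wording


class CacheState:
    """Hypothetical stand-in for PromptCacheState with invalidate()."""

    def __init__(self, cache=None, token_ids=None):
        self.cache = cache
        self.token_ids = token_ids

    def invalidate(self) -> None:
        self.cache = None
        self.token_ids = None


def generate_with_fallback(state, run_with_cache, run_fresh):
    """Try the cached path; on any failure drop the cache and run fresh."""
    if state.cache is not None:
        try:
            return run_with_cache(state.cache)
        except Exception:
            logger.exception("prompt cache reuse failed; invalidating")
            state.invalidate()
    return run_fresh()


def client_error_message(exc: Exception) -> str:
    """Log the full error server-side; return a generic message."""
    logger.error("generation error: %r", exc)
    return GENERIC_ERROR


def broken(cache):
    raise ValueError("[broadcast_shapes] Shapes (12716) and (1840) cannot be broadcast")


state = CacheState(cache=["kv"], token_ids=[1, 2, 3])
output = generate_with_fallback(state, broken, lambda: "fresh output")
```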

PromptCacheState now tracks last_used and created_at timestamps.
A background asyncio task runs every 60s and evicts entries idle
longer than --prompt-cache-ttl (default 300s, configurable via
CLI or PROMPT_CACHE_TTL env var). TTL=0 disables expiry.

This prevents stale KV caches from holding GPU memory indefinitely
(e.g., a 45K-context conversation cache sitting unused for hours).
Eviction triggers gc.collect() + mx.clear_cache() to reclaim VRAM.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
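
The TTL rule can be sketched as a pure function with an injectable clock, which also makes the boundary cases easy to check. `Entry` and `evict_stale` are hypothetical names, and the `gc.collect()` and `mx.clear_cache()` calls from the commit are left out.

```python
import time
from typing import Dict, List, Optional


class Entry:
    """Hypothetical cache entry that records when it was last used."""

    def __init__(self, last_used: float):
        self.last_used = last_used


def evict_stale(
    entries: Dict[str, Entry], ttl: float, now: Optional[float] = None
) -> List[str]:
    """Remove entries idle longer than ttl seconds; ttl of 0 disables expiry."""
    if ttl <= 0:
        return []
    now = time.time() if now is None else now
    stale = [key for key, e in entries.items() if now - e.last_used > ttl]
    for key in stale:
        del entries[key]
    return stale


now = 1_000_000.0
entries = {
    "active": Entry(last_used=now - 30),        # 30s between messages
    "boundary": Entry(last_used=now - 299),     # just under the 300s TTL
    "idle": Entry(last_used=now - 13 * 3600),   # 13-hour idle conversation
}
evicted = evict_stale(entries, ttl=300, now=now)
```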

Copilot AI review requested due to automatic review settings April 7, 2026 03:38

Simulates actual failure scenarios:
- 13-hour idle Telegram conversation (the broadcast_shapes bug)
- Active conversation (30s between messages, never evicted)
- Multiple users with different cache keys (only stale evicted)
- TTL boundary (299s idle on 300s TTL — not evicted)
- Invalidation verification (cache/token_ids set to None)
- Short TTL for dev/testing (5s evicts after 10s idle)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Copilot AI left a comment

Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.

Comments suppressed due to low confidence (1)

mlx_vlm/generate.py:345

  • GenerationResult adds cached_tokens, but nothing in the server layer consumes or exposes it (no API field, no usage stats). If “cached_tokens reporting” is a goal of this PR, thread this through to the response/usage models; otherwise consider removing the field to avoid dead/unused stats.
@dataclass
class GenerationResult:
    text: str = ""
    token: Optional[int] = None
    logprobs: Optional[List[float]] = None
    prompt_tokens: int = 0
    generation_tokens: int = 0
    total_tokens: int = 0
    cached_tokens: int = 0
    prompt_tps: float = 0.0
    generation_tps: float = 0.0
    peak_memory: float = 0.0


Comment thread mlx_vlm/server.py
Comment on lines +90 to +98
async def _prompt_cache_cleanup_loop():
    """Background task that periodically evicts stale prompt caches."""
    while True:
        await asyncio.sleep(60)
        try:
            evict_stale_prompt_caches()
        except Exception as e:
            print(f"[prompt_cache] Cleanup error: {e}")


Copilot AI Apr 7, 2026

_prompt_cache_cleanup_loop() evicts and calls gc.collect()/mx.clear_cache() concurrently with request handling and generation. Since _prompt_cache_states and PromptCacheState objects are shared across requests, this can race with in-flight generations (e.g., invalidating a cache while it’s being reused, or clearing MLX cache during active compute). Consider protecting prompt-cache access/eviction with a lock and/or running eviction only at safe points (e.g., after requests finish), and avoid calling mx.clear_cache() from a background task while generation may be running.
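
A minimal sketch of the locking idea, assuming generation and eviction both run as coroutines on the same event loop. `GuardedCache` is a hypothetical wrapper; a real server that runs generation in a worker thread would need a thread-safe lock instead, and a single global lock also serializes all requests.

```python
import asyncio


class GuardedCache:
    """Hypothetical wrapper: one asyncio.Lock serializes use and eviction."""

    def __init__(self) -> None:
        self.lock = asyncio.Lock()
        self.entries = {"session-a": "kv-state"}
        self.log = []

    async def generate(self, key: str) -> None:
        async with self.lock:
            self.log.append("generate:start")
            await asyncio.sleep(0.01)  # stands in for prefill + decode
            assert key in self.entries  # cache cannot vanish mid-request
            self.log.append("generate:end")

    async def evict(self, key: str) -> None:
        async with self.lock:  # waits until no generation holds the lock
            self.entries.pop(key, None)
            self.log.append("evict")


async def main():
    cache = GuardedCache()
    await asyncio.gather(cache.generate("session-a"), cache.evict("session-a"))
    return cache.log, cache.entries


log, remaining = asyncio.run(main())
```

The eviction still happens, but only after the in-flight generation has released the lock.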

Comment thread mlx_vlm/server.py
Comment on lines +162 to +177
def get_prompt_cache_state(
    model_name: str,
    cache_key: Optional[str] = None,
) -> PromptCacheState:
    """Get or create a PromptCacheState for the given model and cache key.

    Args:
        model_name: The model identifier.
        cache_key: Optional routing key (e.g., prompt_cache_key from request).
    """
    key = f"{model_name}::{cache_key}" if cache_key else model_name
    if key not in _prompt_cache_states:
        _prompt_cache_states[key] = PromptCacheState()
    state = _prompt_cache_states[key]
    state.touch()
    return state

Copilot AI Apr 7, 2026

get_prompt_cache_state() supports a cache_key parameter, but the request schemas don’t expose any prompt_cache_key/user/session field and all call sites pass only model name. That means all clients share a single PromptCacheState per model, which can cause cross-request interference and contradicts the PR description’s “cache_key routing”. Add a request field for a cache key (or use an existing per-user/session identifier) and pass it through, or drop cache_key from the API to avoid implying isolation that isn’t implemented.

Comment thread mlx_vlm/generate.py
Comment on lines 694 to +734
     if prompt_cache_state is not None and prompt_cache_state.cache is not None:
-        prefix_len = prompt_cache_state.find_prefix_length(full_input_ids_list)
-        if prefix_len > 0 and prefix_len < input_ids.shape[1]:
-            reused_prefix_len = prefix_len
-            # Trim to only new tokens
-            input_ids = input_ids[:, prefix_len:]
-            # Only skip vision if no image tokens in the new (trimmed) tokens
-            image_token_id = getattr(model.config, "image_token_id", None) or getattr(
-                model.config, "image_token_index", None
-            )
-            new_ids = input_ids.flatten().tolist()
-            has_image_in_new = image_token_id is not None and image_token_id in new_ids
-            if not has_image_in_new:
-                pixel_values = None
-                kwargs.pop("cached_image_features", None)
-            # Reuse the saved KV cache (trimmed to prefix length)
-            kv_cache = prompt_cache_state.cache
-            # Trim cache to prefix_len in case it includes generated tokens
-            for c in kv_cache:
-                if hasattr(c, "keys") and c.keys is not None:
-                    cached_len = c.keys.shape[2]
-                    if cached_len > prefix_len:
-                        c.keys = c.keys[:, :, :prefix_len, :]
-                        c.values = c.values[:, :, :prefix_len, :]
-                        if hasattr(c, "offset"):
+        try:
+            prefix_len = prompt_cache_state.find_prefix_length(full_input_ids_list)
+            cached_total = len(prompt_cache_state.token_ids) if prompt_cache_state.token_ids else 0
+            # Only reuse if a substantial prefix matches (>= 50% of cached tokens).
+            # Short matches on quantized KV caches (TurboQuant) can produce
+            # corrupted output because trim() only adjusts the offset without
+            # clearing stale quantized data.
+            min_reuse = max(512, cached_total // 2)
+            if prefix_len >= min_reuse and prefix_len < input_ids.shape[1]:
+                reused_prefix_len = prefix_len
+                # Trim to only new tokens
+                input_ids = input_ids[:, prefix_len:]
+                # Only skip vision if no image tokens in the new (trimmed) tokens
+                image_token_id = getattr(model.config, "image_token_id", None) or getattr(
+                    model.config, "image_token_index", None
+                )
+                new_ids = input_ids.flatten().tolist()
+                has_image_in_new = image_token_id is not None and image_token_id in new_ids
+                if not has_image_in_new:
+                    pixel_values = None
+                    kwargs.pop("cached_image_features", None)
+                # Reuse the saved KV cache (trimmed to prefix length).
+                # Works with both standard KVCache (mx.array keys) and
+                # quantized caches (TurboQuant) via their trim() method.
+                kv_cache = prompt_cache_state.cache
+                for c in kv_cache:
+                    if hasattr(c, "offset") and c.offset > prefix_len:
+                        trim_amount = c.offset - prefix_len
+                        if hasattr(c, "trim") and callable(c.trim):
+                            c.trim(trim_amount)
+                        elif hasattr(c, "keys") and c.keys is not None:
+                            keys = c.keys
+                            if hasattr(keys, "shape") and len(keys.shape) >= 3:
+                                c.keys = keys[:, :, :prefix_len, :]
+                                c.values = c.values[:, :, :prefix_len, :]
                             c.offset = prefix_len
+                        elif hasattr(c, "offset") and c.offset > prefix_len:
+                            # Quantized cache: just update offset if possible
+                            c.offset = prefix_len
-            kwargs["prompt_cache"] = kv_cache
+                kwargs["prompt_cache"] = kv_cache

Copilot AI Apr 7, 2026

Prompt-cache reuse trims KV state only when cache entries expose offset/trim/keys, but some models return cache structures without these (e.g., Florence2 make_cache() returns tuples of SimpleKVCache). In those cases, if prefix_len < cached_total, reuse can proceed without actually trimming the underlying cache, risking incorrect context or shape errors. Consider adding a guard: only allow partial-prefix reuse when all cache entries are trimmable to prefix_len; otherwise require prefix_len == cached_total (full match) or skip reuse/invalidate.
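
The guard proposed here could be sketched as follows. `is_trimmable` and `reusable_prefix` are hypothetical helpers, and the tuple in the example merely stands in for a cache entry that exposes neither `offset` nor `trim()`.

```python
def is_trimmable(layer) -> bool:
    """A layer counts as trimmable if it tracks an offset and offers trim()."""
    return hasattr(layer, "offset") and callable(getattr(layer, "trim", None))


def reusable_prefix(layers, prefix_len: int, cached_total: int) -> int:
    """Return the prefix length that is safe to reuse, or 0 to skip reuse."""
    if prefix_len == cached_total:
        return prefix_len  # full match: nothing needs trimming
    if all(is_trimmable(layer) for layer in layers):
        return prefix_len  # partial match, and every layer can be cut back
    return 0  # partial match with an untrimmable layer: do not reuse


class Trimmable:
    """Hypothetical layer with the offset/trim interface."""

    def __init__(self, offset: int):
        self.offset = offset

    def trim(self, n: int) -> int:
        self.offset -= n
        return n


opaque = ("keys", "values")  # stand-in for an entry without offset/trim
```

A caller that gets 0 back would fall through to a fresh prefill and could also invalidate the stored state.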

Comment on lines +154 to +159
import time
time.sleep(0.01)
state.touch()
assert state.last_used > old_time



Copilot AI Apr 7, 2026

This test uses time.sleep() to assert touch() advances last_used, which can be flaky and slows the suite. Prefer monkeypatching time.time() (or setting last_used to a known older value) and asserting touch() sets it to the patched “now”.

Suggested change
- import time
- time.sleep(0.01)
- state.touch()
- assert state.last_used > old_time
+ mocked_now = old_time + 10
+ with patch("mlx_vlm.generate.time.time", return_value=mocked_now):
+     state.touch()
+ assert state.last_used == mocked_now


2 participants