Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
1407088
feat: OpenAI Responses API compliance with tool calling support
eloe Apr 5, 2026
0739cb7
fix: suppress tool call tokens from streaming text deltas
eloe Apr 5, 2026
f3fa553
feat: add prompt prefix caching to server endpoints
eloe Apr 5, 2026
a4048eb
feat: combined OpenAI Responses + prompt prefix caching
eloe Apr 5, 2026
19a2860
feat: add concurrency guard and finish_reason=tool_calls to combined …
eloe Apr 5, 2026
5e66a84
test: comprehensive tests for prompt cache, concurrency guard, finish…
eloe Apr 5, 2026
0b954c9
feat: add stop sequences support for both endpoints
eloe Apr 5, 2026
7227505
fix: handle TurboQuant KV cache in prompt cache trimming
eloe Apr 5, 2026
f9e008e
Merge feature/stop-sequences into combined-bastion
eloe Apr 5, 2026
a9bdcf1
feat: add stop sequences, tool_choice, and TurboQuant cache fix
eloe Apr 5, 2026
415da71
fix: stop sequences pass strings not token IDs to stopping criteria
eloe Apr 5, 2026
4ee6bf0
feat: add JSON mode, context tracking, and request cancellation (#7, …
eloe Apr 6, 2026
949bdf8
feat: prompt cache key routing for OpenClaw and Hermes compatibility
eloe Apr 6, 2026
318c962
feat: report cached_tokens in usage for OC/Hermes prompt caching
eloe Apr 6, 2026
12116f2
fix: use trim() for KV cache prefix reuse (TurboQuant compatible)
eloe Apr 6, 2026
ef8e0c3
debug: add prompt cache logging
eloe Apr 6, 2026
5712dd0
debug: log token mismatch details
eloe Apr 6, 2026
2853726
fix: skip cache save for short probe requests (<1024 tokens)
eloe Apr 6, 2026
873ad47
feat: production-ready prompt caching with probe request filter
eloe Apr 6, 2026
792bb8a
fix: require substantial prefix match for KV cache reuse
eloe Apr 6, 2026
d49397f
style: apply black, isort, autoflake formatting
eloe Apr 6, 2026
5250deb
fix: stale KV cache recovery + TTL eviction + sanitized errors
eloe Apr 7, 2026
0396983
fix: use offset instead of keys.shape for TurboQuant cache validation
eloe Apr 7, 2026
617720c
fix: add default repetition penalty to prevent MoE degeneration loops
eloe Apr 7, 2026
2391fca
fix: normalize Responses API tool format for Jinja chat templates
eloe Apr 7, 2026
a398f44
debug: add request logging to responses endpoint for tool call invest…
eloe Apr 7, 2026
0a5bd0d
debug: log formatted prompt tail for tool call investigation
eloe Apr 7, 2026
c471efc
debug: log tool call detection in streaming responses
eloe Apr 7, 2026
cc2e09a
feat: add --default-max-tokens CLI flag for server-side token limit
eloe Apr 8, 2026
c89ce53
fix: address code review and security findings
eloe Apr 8, 2026
a2befaf
fix: address Copilot review findings
eloe Apr 8, 2026
eacb558
fix: address Copilot review round 3
eloe Apr 8, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 116 additions & 27 deletions mlx_vlm/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import contextlib
import functools
import json
import os
import time
from collections.abc import Sequence
from dataclasses import dataclass
Expand Down Expand Up @@ -338,6 +339,7 @@ class GenerationResult:
prompt_tokens: int = 0
generation_tokens: int = 0
total_tokens: int = 0
cached_tokens: int = 0
prompt_tps: float = 0.0
generation_tps: float = 0.0
peak_memory: float = 0.0
Expand All @@ -354,6 +356,13 @@ class PromptCacheState:
def __init__(self):
self.cache: Optional[List[Any]] = None
self.token_ids: Optional[List[int]] = None
self.last_used: float = time.time()
self.created_at: float = time.time()

@property
def token_count(self) -> int:
"""Number of tokens stored in the cache."""
return len(self.token_ids) if self.token_ids else 0

def find_prefix_length(self, new_ids: list) -> int:
"""Return the number of leading tokens that match the cached ids."""
Expand All @@ -365,10 +374,20 @@ def find_prefix_length(self, new_ids: list) -> int:
return i
return max_len

def touch(self):
"""Update last_used timestamp."""
self.last_used = time.time()

def update(self, token_ids: list, kv_cache: list):
"""Store the full token sequence and corresponding KV cache."""
self.token_ids = list(token_ids)
self.cache = kv_cache
self.last_used = time.time()

def invalidate(self):
"""Discard cached state, forcing a full prefill on next turn."""
self.cache = None
self.token_ids = None


def generate_step(
Expand Down Expand Up @@ -668,33 +687,65 @@ def stream_generate(
reused_prefix_len = 0
full_input_ids_list = input_ids.flatten().tolist()

# Save originals for fallback if cache reuse fails
_original_input_ids = input_ids
_original_pixel_values = pixel_values
_original_kwargs = {k: v for k, v in kwargs.items() if k == "cached_image_features"}

if prompt_cache_state is not None and prompt_cache_state.cache is not None:
prefix_len = prompt_cache_state.find_prefix_length(full_input_ids_list)
if prefix_len > 0 and prefix_len < input_ids.shape[1]:
reused_prefix_len = prefix_len
# Trim to only new tokens
input_ids = input_ids[:, prefix_len:]
# Only skip vision if no image tokens in the new (trimmed) tokens
image_token_id = getattr(model.config, "image_token_id", None) or getattr(
model.config, "image_token_index", None
try:
prefix_len = prompt_cache_state.find_prefix_length(full_input_ids_list)
cached_total = (
len(prompt_cache_state.token_ids) if prompt_cache_state.token_ids else 0
)
new_ids = input_ids.flatten().tolist()
has_image_in_new = image_token_id is not None and image_token_id in new_ids
if not has_image_in_new:
pixel_values = None
kwargs.pop("cached_image_features", None)
# Reuse the saved KV cache (trimmed to prefix length)
kv_cache = prompt_cache_state.cache
# Trim cache to prefix_len in case it includes generated tokens
for c in kv_cache:
if hasattr(c, "keys") and c.keys is not None:
cached_len = c.keys.shape[2]
if cached_len > prefix_len:
c.keys = c.keys[:, :, :prefix_len, :]
c.values = c.values[:, :, :prefix_len, :]
if hasattr(c, "offset"):
c.offset = prefix_len
kwargs["prompt_cache"] = kv_cache
# Only reuse if a substantial prefix matches (>= 50% of cached tokens).
# Short matches on quantized KV caches (TurboQuant) can produce
# corrupted output because trim() only adjusts the offset without
# clearing stale quantized data.
min_reuse = max(512, cached_total // 2)
if prefix_len >= min_reuse and prefix_len < input_ids.shape[1]:
reused_prefix_len = prefix_len
# Trim to only new tokens
input_ids = input_ids[:, prefix_len:]
# Only skip vision if no image tokens in the new (trimmed) tokens
image_token_id = getattr(
model.config, "image_token_id", None
) or getattr(model.config, "image_token_index", None)
new_ids = input_ids.flatten().tolist()
has_image_in_new = (
image_token_id is not None and image_token_id in new_ids
)
if not has_image_in_new:
pixel_values = None
kwargs.pop("cached_image_features", None)
# Reuse the saved KV cache (trimmed to prefix length).
# Works with both standard KVCache (mx.array keys) and
# quantized caches (TurboQuant) via their trim() method.
kv_cache = prompt_cache_state.cache
for c in kv_cache:
if hasattr(c, "offset") and c.offset > prefix_len:
trim_amount = c.offset - prefix_len
if hasattr(c, "trim") and callable(c.trim):
c.trim(trim_amount)
elif hasattr(c, "keys") and c.keys is not None:
keys = c.keys
if hasattr(keys, "shape") and len(keys.shape) >= 3:
c.keys = keys[:, :, :prefix_len, :]
c.values = c.values[:, :, :prefix_len, :]
c.offset = prefix_len
kwargs["prompt_cache"] = kv_cache
except Exception as e:
# Cache reuse failed (e.g., shape mismatch, stale KV state).
# Invalidate and fall back to fresh generation.
if os.environ.get("VERBOSE", "").lower() in ("1", "true", "yes"):
print(f"[prompt_cache] Cache reuse failed, invalidating: {e}")
prompt_cache_state.invalidate()
reused_prefix_len = 0
Comment on lines +737 to +743
Copy link

Copilot AI Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stream_generate() prints cache reuse/validation failures unconditionally. In production these paths can be hit during normal operation (TTL eviction, model reloads, cache mismatches), potentially spamming stdout and making logs hard to use. Consider routing these through the server’s verbosity/logging mechanism (e.g., logger.debug) or making them conditional on an explicit debug flag.

Copilot uses AI. Check for mistakes.
input_ids = _original_input_ids
pixel_values = _original_pixel_values
if "cached_image_features" in _original_kwargs:
kwargs["cached_image_features"] = _original_kwargs["cached_image_features"]
kwargs.pop("prompt_cache", None)

if thinking_budget is not None:
thinking_start_token_id = tokenizer.encode(
Expand All @@ -720,6 +771,35 @@ def stream_generate(
model.language_model,
max_kv_size=kwargs.get("max_kv_size", None),
)

# Validate cache shapes before generation. If the cached KV state has
# inconsistent shapes (e.g., stale after model reload), discard it and
# build a fresh cache to avoid broadcast_shapes errors during generation.
if reused_prefix_len > 0:
try:
for c in kwargs["prompt_cache"]:
if hasattr(c, "offset"):
# Use offset for all cache types (works for both standard
# KVCache and quantized TurboQuant caches).
if c.offset != reused_prefix_len:
raise ValueError(
f"Cache offset mismatch: expected {reused_prefix_len}, got {c.offset}"
)
except (ValueError, IndexError, AttributeError) as e:
if os.environ.get("VERBOSE", "").lower() in ("1", "true", "yes"):
print(f"[prompt_cache] Cache validation failed, rebuilding: {e}")
if prompt_cache_state is not None:
prompt_cache_state.invalidate()
reused_prefix_len = 0
input_ids = _original_input_ids
pixel_values = _original_pixel_values
if "cached_image_features" in _original_kwargs:
kwargs["cached_image_features"] = _original_kwargs["cached_image_features"]
kwargs["prompt_cache"] = cache.make_prompt_cache(
model.language_model,
max_kv_size=kwargs.get("max_kv_size", None),
)

tracked_cache = kwargs["prompt_cache"]

total_prompt_tokens = reused_prefix_len + input_ids.size
Expand Down Expand Up @@ -758,6 +838,7 @@ def stream_generate(
prompt_tokens=total_prompt_tokens,
generation_tokens=n + 1,
total_tokens=total_prompt_tokens + n + 1,
cached_tokens=reused_prefix_len,
prompt_tps=prompt_tps,
generation_tps=(n + 1) / (time.perf_counter() - tic),
peak_memory=mx.get_peak_memory() / 1e9,
Expand All @@ -771,13 +852,21 @@ def stream_generate(
prompt_tokens=total_prompt_tokens,
generation_tokens=n + 1,
total_tokens=total_prompt_tokens + n + 1,
cached_tokens=reused_prefix_len,
prompt_tps=prompt_tps,
generation_tps=(n + 1) / (time.perf_counter() - tic),
peak_memory=mx.get_peak_memory() / 1e9,
)

# Save cache state for potential reuse on next turn
if prompt_cache_state is not None:
# Save cache state for potential reuse on next turn.
# Only save if the prompt was substantial (>= 1024 tokens) to avoid
# polluting the cache with short probe/capability-check requests that
# some agent frameworks send before the real request.
_MIN_CACHE_TOKENS = 1024
if (
prompt_cache_state is not None
and len(full_input_ids_list) >= _MIN_CACHE_TOKENS
):
all_ids = full_input_ids_list + [
t.item() if hasattr(t, "item") else t for t in generated_tokens
]
Expand Down
38 changes: 38 additions & 0 deletions mlx_vlm/prompt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,39 @@ def get_message_json(
)


def _normalize_tool(tool):
"""Ensure a tool dict uses the Chat Completions nested format.

The OpenAI Responses API uses a flat format::

{"type": "function", "name": "...", "description": "...", "parameters": {...}}

While Chat Completions (and most Jinja chat templates) expect::

{"type": "function", "function": {"name": "...", "description": "...", "parameters": {...}}}

This helper converts the flat format to nested so that Jinja templates
referencing ``tool['function']`` (e.g. Gemma 4) work correctly.
"""
if not isinstance(tool, dict):
return tool
# Already in nested format
if "function" in tool and isinstance(tool["function"], dict):
return tool
# Flat Responses-API format → wrap in 'function' key
if tool.get("type") == "function" and "name" in tool:
fn = {k: v for k, v in tool.items() if k != "type"}
return {"type": "function", "function": fn}
return tool


def _normalize_tools(tools):
"""Normalize a list of tool dicts to the Chat Completions nested format."""
if not isinstance(tools, list):
return tools
return [_normalize_tool(t) for t in tools]


def get_chat_template(
processor,
messages: List[Dict[str, Any]],
Expand Down Expand Up @@ -615,6 +648,11 @@ def _missing_template_error(error: Exception) -> bool:
if template_processor is None:
return _messages_to_plain_prompt()

# Normalize tool dicts from flat Responses-API format to the nested
# Chat Completions format expected by Jinja chat templates.
if "tools" in kwargs and kwargs["tools"] is not None:
kwargs["tools"] = _normalize_tools(kwargs["tools"])

try:
return template_processor.apply_chat_template(
messages,
Expand Down
Loading