diff --git a/mlx_vlm/generate.py b/mlx_vlm/generate.py
index 1d94bd805..7067b3a55 100644
--- a/mlx_vlm/generate.py
+++ b/mlx_vlm/generate.py
@@ -3,6 +3,7 @@
import contextlib
import functools
import json
+import os
import time
from collections.abc import Sequence
from dataclasses import dataclass
@@ -338,6 +339,7 @@ class GenerationResult:
prompt_tokens: int = 0
generation_tokens: int = 0
total_tokens: int = 0
+ cached_tokens: int = 0
prompt_tps: float = 0.0
generation_tps: float = 0.0
peak_memory: float = 0.0
@@ -354,6 +356,13 @@ class PromptCacheState:
def __init__(self):
self.cache: Optional[List[Any]] = None
self.token_ids: Optional[List[int]] = None
+ self.last_used: float = time.time()
+ self.created_at: float = time.time()
+
+ @property
+ def token_count(self) -> int:
+ """Number of tokens stored in the cache."""
+ return len(self.token_ids) if self.token_ids else 0
def find_prefix_length(self, new_ids: list) -> int:
"""Return the number of leading tokens that match the cached ids."""
@@ -365,10 +374,20 @@ def find_prefix_length(self, new_ids: list) -> int:
return i
return max_len
+ def touch(self):
+ """Update last_used timestamp."""
+ self.last_used = time.time()
+
def update(self, token_ids: list, kv_cache: list):
"""Store the full token sequence and corresponding KV cache."""
self.token_ids = list(token_ids)
self.cache = kv_cache
+ self.last_used = time.time()
+
+ def invalidate(self):
+ """Discard cached state, forcing a full prefill on next turn."""
+ self.cache = None
+ self.token_ids = None
def generate_step(
@@ -668,33 +687,65 @@ def stream_generate(
reused_prefix_len = 0
full_input_ids_list = input_ids.flatten().tolist()
+ # Save originals for fallback if cache reuse fails
+ _original_input_ids = input_ids
+ _original_pixel_values = pixel_values
+ _original_kwargs = {k: v for k, v in kwargs.items() if k == "cached_image_features"}
+
if prompt_cache_state is not None and prompt_cache_state.cache is not None:
- prefix_len = prompt_cache_state.find_prefix_length(full_input_ids_list)
- if prefix_len > 0 and prefix_len < input_ids.shape[1]:
- reused_prefix_len = prefix_len
- # Trim to only new tokens
- input_ids = input_ids[:, prefix_len:]
- # Only skip vision if no image tokens in the new (trimmed) tokens
- image_token_id = getattr(model.config, "image_token_id", None) or getattr(
- model.config, "image_token_index", None
+ try:
+ prefix_len = prompt_cache_state.find_prefix_length(full_input_ids_list)
+ cached_total = (
+ len(prompt_cache_state.token_ids) if prompt_cache_state.token_ids else 0
)
- new_ids = input_ids.flatten().tolist()
- has_image_in_new = image_token_id is not None and image_token_id in new_ids
- if not has_image_in_new:
- pixel_values = None
- kwargs.pop("cached_image_features", None)
- # Reuse the saved KV cache (trimmed to prefix length)
- kv_cache = prompt_cache_state.cache
- # Trim cache to prefix_len in case it includes generated tokens
- for c in kv_cache:
- if hasattr(c, "keys") and c.keys is not None:
- cached_len = c.keys.shape[2]
- if cached_len > prefix_len:
- c.keys = c.keys[:, :, :prefix_len, :]
- c.values = c.values[:, :, :prefix_len, :]
- if hasattr(c, "offset"):
- c.offset = prefix_len
- kwargs["prompt_cache"] = kv_cache
+ # Only reuse if a substantial prefix matches: at least 512 tokens and
+ # >= 50% of the cached tokens. Short matches on quantized KV caches
+ # (TurboQuant) can produce corrupted output because trim() only adjusts
+ # the offset without clearing stale quantized data.
+ min_reuse = max(512, cached_total // 2)
+ if prefix_len >= min_reuse and prefix_len < input_ids.shape[1]:
+ reused_prefix_len = prefix_len
+ # Trim to only new tokens
+ input_ids = input_ids[:, prefix_len:]
+ # Only skip vision if no image tokens in the new (trimmed) tokens
+ image_token_id = getattr(
+ model.config, "image_token_id", None
+ ) or getattr(model.config, "image_token_index", None)
+ new_ids = input_ids.flatten().tolist()
+ has_image_in_new = (
+ image_token_id is not None and image_token_id in new_ids
+ )
+ if not has_image_in_new:
+ pixel_values = None
+ kwargs.pop("cached_image_features", None)
+ # Reuse the saved KV cache (trimmed to prefix length).
+ # Works with both standard KVCache (mx.array keys) and
+ # quantized caches (TurboQuant) via their trim() method.
+ kv_cache = prompt_cache_state.cache
+ for c in kv_cache:
+ if hasattr(c, "offset") and c.offset > prefix_len:
+ trim_amount = c.offset - prefix_len
+ if hasattr(c, "trim") and callable(c.trim):
+ c.trim(trim_amount)
+ elif hasattr(c, "keys") and c.keys is not None:
+ keys = c.keys
+ if hasattr(keys, "shape") and len(keys.shape) >= 3:
+ c.keys = keys[:, :, :prefix_len, :]
+ c.values = c.values[:, :, :prefix_len, :]
+ c.offset = prefix_len
+ kwargs["prompt_cache"] = kv_cache
+ except Exception as e:
+ # Cache reuse failed (e.g., shape mismatch, stale KV state).
+ # Invalidate and fall back to fresh generation.
+ if os.environ.get("VERBOSE", "").lower() in ("1", "true", "yes"):
+ print(f"[prompt_cache] Cache reuse failed, invalidating: {e}")
+ prompt_cache_state.invalidate()
+ reused_prefix_len = 0
+ input_ids = _original_input_ids
+ pixel_values = _original_pixel_values
+ if "cached_image_features" in _original_kwargs:
+ kwargs["cached_image_features"] = _original_kwargs["cached_image_features"]
+ kwargs.pop("prompt_cache", None)
if thinking_budget is not None:
thinking_start_token_id = tokenizer.encode(
@@ -720,6 +771,35 @@ def stream_generate(
model.language_model,
max_kv_size=kwargs.get("max_kv_size", None),
)
+
+ # Validate cache offsets before generation. If a reused cache entry's offset
+ # does not equal the reused prefix length (e.g., stale after model reload),
+ # discard it and build a fresh cache to avoid broadcast_shapes errors.
+ if reused_prefix_len > 0:
+ try:
+ for c in kwargs["prompt_cache"]:
+ if hasattr(c, "offset"):
+ # Use offset for all cache types (works for both standard
+ # KVCache and quantized TurboQuant caches).
+ if c.offset != reused_prefix_len:
+ raise ValueError(
+ f"Cache offset mismatch: expected {reused_prefix_len}, got {c.offset}"
+ )
+ except (ValueError, IndexError, AttributeError) as e:
+ if os.environ.get("VERBOSE", "").lower() in ("1", "true", "yes"):
+ print(f"[prompt_cache] Cache validation failed, rebuilding: {e}")
+ if prompt_cache_state is not None:
+ prompt_cache_state.invalidate()
+ reused_prefix_len = 0
+ input_ids = _original_input_ids
+ pixel_values = _original_pixel_values
+ if "cached_image_features" in _original_kwargs:
+ kwargs["cached_image_features"] = _original_kwargs["cached_image_features"]
+ kwargs["prompt_cache"] = cache.make_prompt_cache(
+ model.language_model,
+ max_kv_size=kwargs.get("max_kv_size", None),
+ )
+
tracked_cache = kwargs["prompt_cache"]
total_prompt_tokens = reused_prefix_len + input_ids.size
@@ -758,6 +838,7 @@ def stream_generate(
prompt_tokens=total_prompt_tokens,
generation_tokens=n + 1,
total_tokens=total_prompt_tokens + n + 1,
+ cached_tokens=reused_prefix_len,
prompt_tps=prompt_tps,
generation_tps=(n + 1) / (time.perf_counter() - tic),
peak_memory=mx.get_peak_memory() / 1e9,
@@ -771,13 +852,21 @@ def stream_generate(
prompt_tokens=total_prompt_tokens,
generation_tokens=n + 1,
total_tokens=total_prompt_tokens + n + 1,
+ cached_tokens=reused_prefix_len,
prompt_tps=prompt_tps,
generation_tps=(n + 1) / (time.perf_counter() - tic),
peak_memory=mx.get_peak_memory() / 1e9,
)
- # Save cache state for potential reuse on next turn
- if prompt_cache_state is not None:
+ # Save cache state for potential reuse on next turn.
+ # Only save if the prompt was substantial (>= 1024 tokens) to avoid
+ # polluting the cache with short probe/capability-check requests that
+ # some agent frameworks send before the real request.
+ _MIN_CACHE_TOKENS = 1024
+ if (
+ prompt_cache_state is not None
+ and len(full_input_ids_list) >= _MIN_CACHE_TOKENS
+ ):
all_ids = full_input_ids_list + [
t.item() if hasattr(t, "item") else t for t in generated_tokens
]
diff --git a/mlx_vlm/prompt_utils.py b/mlx_vlm/prompt_utils.py
index a4db38c47..52b68b3d8 100644
--- a/mlx_vlm/prompt_utils.py
+++ b/mlx_vlm/prompt_utils.py
@@ -460,6 +460,39 @@ def get_message_json(
)
+def _normalize_tool(tool):
+ """Ensure a tool dict uses the Chat Completions nested format.
+
+ The OpenAI Responses API uses a flat format::
+
+ {"type": "function", "name": "...", "description": "...", "parameters": {...}}
+
+ While Chat Completions (and most Jinja chat templates) expect::
+
+ {"type": "function", "function": {"name": "...", "description": "...", "parameters": {...}}}
+
+ This helper converts the flat format to nested so that Jinja templates
+ referencing ``tool['function']`` (e.g. Gemma 4) work correctly.
+ """
+ if not isinstance(tool, dict):
+ return tool
+ # Already in nested format
+ if "function" in tool and isinstance(tool["function"], dict):
+ return tool
+ # Flat Responses-API format → wrap in 'function' key
+ if tool.get("type") == "function" and "name" in tool:
+ fn = {k: v for k, v in tool.items() if k != "type"}
+ return {"type": "function", "function": fn}
+ return tool
+
+
+def _normalize_tools(tools):
+ """Normalize a list of tool dicts to the Chat Completions nested format."""
+ if not isinstance(tools, list):
+ return tools
+ return [_normalize_tool(t) for t in tools]
+
+
def get_chat_template(
processor,
messages: List[Dict[str, Any]],
@@ -615,6 +648,11 @@ def _missing_template_error(error: Exception) -> bool:
if template_processor is None:
return _messages_to_plain_prompt()
+ # Normalize tool dicts from flat Responses-API format to the nested
+ # Chat Completions format expected by Jinja chat templates.
+ if "tools" in kwargs and kwargs["tools"] is not None:
+ kwargs["tools"] = _normalize_tools(kwargs["tools"])
+
try:
return template_processor.apply_chat_template(
messages,
diff --git a/mlx_vlm/responses_models.py b/mlx_vlm/responses_models.py
new file mode 100644
index 000000000..507e09578
--- /dev/null
+++ b/mlx_vlm/responses_models.py
@@ -0,0 +1,515 @@
+"""Pydantic models for the OpenAI Responses API (/v1/responses).
+
+This module defines all request, response, and streaming event models
+for the OpenAI-compatible Responses endpoint. Models are self-contained
+to avoid circular imports with server.py.
+
+Reference: https://developers.openai.com/api/reference/resources/responses
+"""
+
+import uuid
+from typing import Any, List, Literal, Optional, Union
+
+from pydantic import BaseModel, ConfigDict, Field
+from typing_extensions import Required, TypeAlias, TypedDict
+
+# ---------------------------------------------------------------------------
+# Constants (mirrored from mlx_vlm.generate to avoid heavy imports)
+# ---------------------------------------------------------------------------
+
+DEFAULT_MAX_TOKENS = 4096
+DEFAULT_TEMPERATURE = 0.0
+DEFAULT_TOP_P = 1.0
+DEFAULT_THINKING_START_TOKEN = ""
+DEFAULT_THINKING_END_TOKEN = ""
+
+
+# ---------------------------------------------------------------------------
+# Base classes (duplicated from server.py for import independence)
+# ---------------------------------------------------------------------------
+
+
+class FlexibleBaseModel(BaseModel):
+ """Base model that silently accepts unknown fields for forward compatibility."""
+
+ model_config = ConfigDict(extra="allow")
+
+ def dump_kwargs(self, *fields: str) -> dict[str, Any]:
+ """Return a dict of the requested fields, omitting ``None`` values."""
+ return {
+ key: getattr(self, key)
+ for key in fields
+ if hasattr(self, key) and getattr(self, key) is not None
+ }
+
+
+class GenerationParams(FlexibleBaseModel):
+ """Sampling parameters shared across endpoints."""
+
+ temperature: float = Field(
+ DEFAULT_TEMPERATURE, description="Temperature for sampling."
+ )
+ top_p: float = Field(DEFAULT_TOP_P, description="Top-p sampling.")
+ top_k: Optional[int] = Field(None, description="Top-k sampling cutoff.")
+ min_p: Optional[float] = Field(None, description="Min-p sampling threshold.")
+ repetition_penalty: Optional[float] = Field(
+ None, description="Penalty applied to repeated tokens."
+ )
+ logit_bias: Optional[dict[int, float]] = Field(
+ None, description="Additive logit bias keyed by token id."
+ )
+
+ def shared_generation_kwargs(self) -> dict[str, Any]:
+ return self.dump_kwargs(
+ "temperature",
+ "top_p",
+ "top_k",
+ "min_p",
+ "repetition_penalty",
+ "logit_bias",
+ )
+
+
+class TemplateParams(FlexibleBaseModel):
+ """Chat template parameters (thinking mode, etc.)."""
+
+ enable_thinking: Optional[bool] = Field(
+ None, description="Enable thinking mode in the chat template."
+ )
+ thinking_budget: Optional[int] = Field(
+ None,
+ description="Maximum number of thinking tokens before forcing the end token.",
+ )
+ thinking_start_token: Optional[str] = Field(
+ DEFAULT_THINKING_START_TOKEN,
+ description="Token that marks the start of a thinking block.",
+ )
+ thinking_end_token: Optional[str] = Field(
+ DEFAULT_THINKING_END_TOKEN,
+ description="Token that marks the end of a thinking block.",
+ )
+
+ def template_kwargs(self) -> dict[str, Any]:
+ kwargs = self.dump_kwargs(
+ "enable_thinking",
+ "thinking_budget",
+ "thinking_start_token",
+ "thinking_end_token",
+ )
+ kwargs.setdefault("enable_thinking", False)
+ return kwargs
+
+
+# ---------------------------------------------------------------------------
+# Input content types (TypedDicts matching OpenAI SDK)
+# ---------------------------------------------------------------------------
+
+
+class ResponseInputTextParam(TypedDict, total=False):
+ """Text content item — accepts both ``input_text`` and ``text`` types."""
+
+ text: Required[str]
+ type: Required[Literal["input_text", "text"]]
+
+
+class ResponseInputImageParam(TypedDict, total=False):
+ """Image content item with a direct image URL."""
+
+ detail: Literal["high", "low", "auto"]
+ type: Required[Literal["input_image"]]
+ image_url: Required[str]
+ file_id: Optional[str]
+
+
+class InputAudio(TypedDict, total=False):
+ data: Required[str]
+ format: Required[str]
+
+
+class ResponseInputAudioParam(TypedDict, total=False):
+ """Audio content item."""
+
+ type: Required[Literal["input_audio"]]
+ input_audio: Required[InputAudio]
+
+
+class ImageUrl(TypedDict, total=False):
+ url: Required[str]
+
+
+class ResponseImageUrlParam(TypedDict, total=False):
+ """Image content item with nested ``image_url.url`` (chat/completions format)."""
+
+ type: Required[Literal["image_url"]]
+ image_url: Required[ImageUrl]
+
+
+class ResponseOutputText(TypedDict, total=False):
+ """Output text item used in multi-turn assistant messages."""
+
+ text: Required[str]
+ type: Required[Literal["output_text"]]
+
+
+ResponseInputContentParam: TypeAlias = Union[
+ ResponseInputTextParam,
+ ResponseInputImageParam,
+ ResponseImageUrlParam,
+ ResponseInputAudioParam,
+]
+
+ResponseInputMessageContentListParam: TypeAlias = List[ResponseInputContentParam]
+ResponseOutputMessageContentList: TypeAlias = List[ResponseOutputText]
+
+
+# ---------------------------------------------------------------------------
+# Chat message model
+# ---------------------------------------------------------------------------
+
+
+class ChatMessage(FlexibleBaseModel):
+ """A single message in the conversation input."""
+
+ role: Literal["user", "assistant", "system", "developer", "tool"] = Field(
+ ..., description="Role of the message sender."
+ )
+ content: Optional[
+ Union[
+ str,
+ ResponseInputMessageContentListParam,
+ ResponseOutputMessageContentList,
+ ]
+ ] = Field(None, description="Content of the message.")
+ tool_calls: List = Field(default_factory=list)
+
+
+# ---------------------------------------------------------------------------
+# Function tool definition
+# ---------------------------------------------------------------------------
+
+
+class ResponseFunctionTool(BaseModel):
+ """A function tool the model may call."""
+
+ type: Literal["function"] = "function"
+ name: str = Field(..., description="The name of the function.")
+ description: Optional[str] = Field(
+ None, description="A description of what the function does."
+ )
+ parameters: Optional[dict] = Field(
+ None, description="JSON Schema object describing the function parameters."
+ )
+ strict: Optional[bool] = Field(
+ None, description="Whether to enforce strict schema adherence."
+ )
+
+
+# ---------------------------------------------------------------------------
+# Function call input items (for multi-turn tool use)
+# ---------------------------------------------------------------------------
+
+
+class ResponseFunctionCallInputItem(BaseModel):
+ """A function call from a previous assistant turn, included in input."""
+
+ type: Literal["function_call"] = "function_call"
+ call_id: str = Field(..., description="Unique ID for this tool call.")
+ name: str = Field(..., description="The function name that was called.")
+ arguments: str = Field(..., description="JSON string of the function arguments.")
+ status: Optional[str] = "completed"
+
+
+class ResponseFunctionCallOutputInputItem(BaseModel):
+ """The output/result of a function call, sent back by the client."""
+
+ type: Literal["function_call_output"] = "function_call_output"
+ call_id: str = Field(
+ ..., description="The call_id of the function call this is a result for."
+ )
+ output: str = Field(..., description="The function output as a string.")
+
+
+# ---------------------------------------------------------------------------
+# Request model
+# ---------------------------------------------------------------------------
+
+
+class ResponsesRequest(GenerationParams, TemplateParams):
+ """OpenAI Responses API request body.
+
+ Reference: https://developers.openai.com/api/reference/resources/responses/create
+ """
+
+ input: Union[str, List[Any]] = Field(
+ ..., description="Input text or list of input items (messages, tool outputs)."
+ )
+ model: str = Field(..., description="The model to use for generation.")
+ max_output_tokens: Optional[int] = Field(
+ None, description="Maximum number of tokens to generate. Uses server default if not specified."
+ )
+ stream: bool = Field(
+ False, description="Whether to stream the response chunk by chunk."
+ )
+ tools: Optional[List[dict]] = Field(
+ None, description="Tool definitions the model may call."
+ )
+ tool_choice: Optional[Any] = Field(
+ "auto", description='Tool choice: "none", "auto", "required", or specific tool.'
+ )
+ parallel_tool_calls: bool = Field(True, description="Allow parallel tool calls.")
+ previous_response_id: Optional[str] = Field(
+ None,
+ description="ID of a previous response for multi-turn context replay.",
+ )
+ instructions: Optional[str] = Field(
+ None,
+ description="System/developer message inserted into context.",
+ )
+ metadata: Optional[dict] = Field(
+ None, description="Up to 16 key-value pairs of metadata."
+ )
+ stop: Optional[Union[str, List[str]]] = Field(
+ None,
+ description="Up to 4 sequences where the API will stop generating further tokens.",
+ )
+ response_format: Optional[dict] = Field(
+ None,
+ description='Output format: {"type": "text"} or {"type": "json_object"}.',
+ )
+ prompt_cache_key: Optional[str] = Field(
+ None,
+ description="Stable key for prompt cache routing across turns.",
+ )
+
+ def generation_kwargs(self) -> dict[str, Any]:
+ kwargs = self.dump_kwargs("max_output_tokens")
+ if "max_output_tokens" in kwargs:
+ kwargs["max_tokens"] = kwargs.pop("max_output_tokens")
+ return {**kwargs, **self.shared_generation_kwargs()}
+
+
+# ---------------------------------------------------------------------------
+# Output item models
+# ---------------------------------------------------------------------------
+
+
+class ContentPartOutputText(BaseModel):
+ """A text content part in an output message."""
+
+ type: Literal["output_text"] = "output_text"
+ text: str = ""
+ annotations: List[str] = Field(default_factory=list)
+
+
+class ResponseMessageItem(BaseModel):
+ """An assistant message output item."""
+
+ id: str = Field(default_factory=lambda: f"msg_{uuid.uuid4().hex[:24]}")
+ type: Literal["message"] = "message"
+ role: Literal["assistant"] = "assistant"
+ status: Literal["in_progress", "completed"] = "completed"
+ content: List[ContentPartOutputText] = Field(default_factory=list)
+
+
+class ResponseFunctionCallItem(BaseModel):
+ """A function call output item."""
+
+ type: Literal["function_call"] = "function_call"
+ id: str = Field(default_factory=lambda: f"fc_{uuid.uuid4().hex[:24]}")
+ call_id: str = Field(default_factory=lambda: f"call_{uuid.uuid4().hex[:24]}")
+ name: str = Field(..., description="The function name being called.")
+ arguments: str = Field(..., description="JSON string of the function arguments.")
+ status: Literal["completed"] = "completed"
+
+
+class ResponseIncompleteDetails(BaseModel):
+ """Details about why a response is incomplete."""
+
+ reason: Literal["max_output_tokens", "content_filter"]
+
+
+# ---------------------------------------------------------------------------
+# Usage and error models
+# ---------------------------------------------------------------------------
+
+
+class InputTokensDetails(BaseModel):
+ """Breakdown of input token usage."""
+
+ cached_tokens: int = 0
+
+
+class ResponseUsage(BaseModel):
+ """Token usage details with cache-awareness for OpenClaw/Hermes."""
+
+ input_tokens: int
+ output_tokens: int
+ total_tokens: int
+ input_tokens_details: Optional[InputTokensDetails] = None
+
+
+class ResponseErrorObject(BaseModel):
+ """Error object returned when the model fails to generate a Response."""
+
+ code: Optional[str] = None
+ message: Optional[str] = None
+ param: Optional[str] = None
+ type: Optional[str] = None
+
+
+# ---------------------------------------------------------------------------
+# Response object
+# ---------------------------------------------------------------------------
+
+
+class ResponseObject(BaseModel):
+ """The top-level Response object returned by /v1/responses.
+
+ Reference: https://developers.openai.com/api/reference/resources/responses/object
+ """
+
+ id: str = Field(
+ default_factory=lambda: f"resp_{uuid.uuid4().hex[:24]}",
+ description="Unique identifier for this Response.",
+ )
+ object: Literal["response"] = Field(
+ "response", description="The object type — always ``response``."
+ )
+ created_at: int = Field(..., description="Unix timestamp of creation.")
+ status: Literal["completed", "failed", "in_progress", "incomplete"] = Field(
+ "completed", description="The status of the response generation."
+ )
+ error: Optional[ResponseErrorObject] = Field(None)
+ incomplete_details: Optional[ResponseIncompleteDetails] = Field(None)
+ instructions: Optional[str] = Field(None)
+ max_output_tokens: Optional[int] = Field(None)
+ model: str = Field(..., description="Model ID used to generate the response.")
+ output: List[Union[ResponseMessageItem, ResponseFunctionCallItem]] = Field(
+ default_factory=list,
+ description="An array of content items generated by the model.",
+ )
+ parallel_tool_calls: bool = Field(True)
+ previous_response_id: Optional[str] = Field(None)
+ temperature: Optional[float] = Field(None, ge=0, le=2)
+ top_p: Optional[float] = Field(None, ge=0, le=1)
+ tools: List = Field(default_factory=list)
+ tool_choice: Optional[Any] = Field("auto")
+ truncation: Literal["auto", "disabled"] = Field("disabled")
+ metadata: Optional[dict] = Field(None)
+ usage: ResponseUsage = Field(..., description="Token usage details.")
+ user: Optional[str] = Field(None)
+
+ @property
+ def output_text(self) -> str:
+ """Aggregate text from all output_text content parts."""
+ parts = []
+ for item in self.output:
+ if isinstance(item, ResponseMessageItem):
+ for part in item.content:
+ if part.type == "output_text" and part.text:
+ parts.append(part.text)
+ return "".join(parts) or ""
+
+
+# ---------------------------------------------------------------------------
+# Streaming event models
+# ---------------------------------------------------------------------------
+
+
+class BaseStreamEvent(BaseModel):
+ """Base class for all SSE streaming events."""
+
+ type: str
+ sequence_number: int = 0
+
+
+class ResponseCreatedEvent(BaseStreamEvent):
+ type: Literal["response.created"] = "response.created"
+ response: ResponseObject
+
+
+class ResponseInProgressEvent(BaseStreamEvent):
+ type: Literal["response.in_progress"] = "response.in_progress"
+ response: ResponseObject
+
+
+class ResponseOutputItemAddedEvent(BaseStreamEvent):
+ type: Literal["response.output_item.added"] = "response.output_item.added"
+ output_index: int
+ item: Union[ResponseMessageItem, ResponseFunctionCallItem]
+
+
+class ResponseContentPartAddedEvent(BaseStreamEvent):
+ type: Literal["response.content_part.added"] = "response.content_part.added"
+ item_id: str
+ output_index: int
+ content_index: int
+ part: ContentPartOutputText
+
+
+class ResponseOutputTextDeltaEvent(BaseStreamEvent):
+ type: Literal["response.output_text.delta"] = "response.output_text.delta"
+ item_id: str
+ output_index: int
+ content_index: int
+ delta: str
+
+
+class ResponseOutputTextDoneEvent(BaseStreamEvent):
+ type: Literal["response.output_text.done"] = "response.output_text.done"
+ item_id: str
+ output_index: int
+ content_index: int
+ text: str
+
+
+class ResponseContentPartDoneEvent(BaseStreamEvent):
+ type: Literal["response.content_part.done"] = "response.content_part.done"
+ item_id: str
+ output_index: int
+ content_index: int
+ part: ContentPartOutputText
+
+
+class ResponseOutputItemDoneEvent(BaseStreamEvent):
+ type: Literal["response.output_item.done"] = "response.output_item.done"
+ output_index: int
+ item: Union[ResponseMessageItem, ResponseFunctionCallItem]
+
+
+class ResponseFunctionCallArgumentsDeltaEvent(BaseStreamEvent):
+ type: Literal["response.function_call_arguments.delta"] = (
+ "response.function_call_arguments.delta"
+ )
+ item_id: str
+ output_index: int
+ delta: str
+
+
+class ResponseFunctionCallArgumentsDoneEvent(BaseStreamEvent):
+ type: Literal["response.function_call_arguments.done"] = (
+ "response.function_call_arguments.done"
+ )
+ item_id: str
+ output_index: int
+ arguments: str
+
+
+class ResponseCompletedEvent(BaseStreamEvent):
+ type: Literal["response.completed"] = "response.completed"
+ response: ResponseObject
+
+
+StreamEvent = Union[
+ ResponseCreatedEvent,
+ ResponseInProgressEvent,
+ ResponseOutputItemAddedEvent,
+ ResponseContentPartAddedEvent,
+ ResponseOutputTextDeltaEvent,
+ ResponseOutputTextDoneEvent,
+ ResponseContentPartDoneEvent,
+ ResponseOutputItemDoneEvent,
+ ResponseFunctionCallArgumentsDeltaEvent,
+ ResponseFunctionCallArgumentsDoneEvent,
+ ResponseCompletedEvent,
+]
diff --git a/mlx_vlm/responses_store.py b/mlx_vlm/responses_store.py
new file mode 100644
index 000000000..a543cebe7
--- /dev/null
+++ b/mlx_vlm/responses_store.py
@@ -0,0 +1,127 @@
+"""LRU response store for OpenAI Responses API previous_response_id support."""
+
+import threading
+from collections import OrderedDict
+from typing import Any, Optional
+
+
+class ResponseStore:
+ """Bounded LRU store mapping response IDs to (input_items, response_object) pairs.
+
+ Used to support the ``previous_response_id`` parameter in the Responses API,
+ which allows clients to chain responses without resending full conversation
+ history.
+
+ Args:
+ maxsize: Maximum number of responses to store. When exceeded, the least
+ recently used entry is evicted. Defaults to 256.
+ """
+
+ def __init__(self, maxsize: int = 256):
+ self._store: OrderedDict[str, dict] = OrderedDict()
+ self._maxsize = maxsize
+ self._lock = threading.Lock()
+
+ def save(
+ self,
+ response_id: str,
+ input_items: Any,
+ response_output: list,
+ ) -> None:
+ """Save a response for later replay.
+
+ Args:
+ response_id: The unique response ID (e.g., ``"resp_abc123"``).
+ input_items: The original request input (string or list of input items).
+ response_output: The response output items list. replay_input() only replays dict items; other item types are skipped.
+ """
+ with self._lock:
+ if response_id in self._store:
+ self._store.move_to_end(response_id)
+ self._store[response_id] = {
+ "input": input_items,
+ "output": response_output,
+ }
+ while len(self._store) > self._maxsize:
+ self._store.popitem(last=False)
+
+ def get(self, response_id: str) -> Optional[dict]:
+ """Retrieve a stored response by ID.
+
+ Args:
+ response_id: The response ID to look up.
+
+ Returns:
+ Dict with ``"input"`` and ``"output"`` keys, or ``None`` if not found.
+ """
+ with self._lock:
+ entry = self._store.get(response_id)
+ if entry is not None:
+ self._store.move_to_end(response_id)
+ return entry
+
+ def replay_input(self, response_id: str) -> Optional[list]:
+ """Build conversation input by replaying a previous response.
+
+ Reconstructs input items from the stored response: the original input
+ items followed by the output items converted to input format.
+
+ Args:
+ response_id: The previous response ID to replay.
+
+ Returns:
+ List of input items suitable for prepending to the current request,
+ or ``None`` if the response ID is not found.
+ """
+ entry = self.get(response_id)
+ if entry is None:
+ return None
+
+ items = []
+
+ # Add original input items
+ original_input = entry["input"]
+ if isinstance(original_input, str):
+ items.append({"role": "user", "content": original_input})
+ elif isinstance(original_input, list):
+ items.extend(original_input)
+
+ # Convert output items to input format
+ for output_item in entry.get("output", []):
+ if isinstance(output_item, dict):
+ item_type = output_item.get("type", "")
+ if item_type == "message":
+ content = output_item.get("content", [])
+ for part in content:
+ if isinstance(part, dict) and part.get("type") == "output_text":
+ items.append(
+ {
+ "role": "assistant",
+ "content": [
+ {
+ "type": "output_text",
+ "text": part.get("text", ""),
+ }
+ ],
+ }
+ )
+ elif item_type == "function_call":
+ items.append(
+ {
+ "type": "function_call",
+ "call_id": output_item.get("call_id", ""),
+ "name": output_item.get("name", ""),
+ "arguments": output_item.get("arguments", ""),
+ }
+ )
+
+ return items
+
+ def __len__(self) -> int:
+ with self._lock:
+ return len(self._store)
+
+ def clear(self) -> None:
+ """Remove all stored responses."""
+ with self._lock:
+ self._store.clear()
diff --git a/mlx_vlm/server.py b/mlx_vlm/server.py
index 04c011f6a..6dcf55a99 100644
--- a/mlx_vlm/server.py
+++ b/mlx_vlm/server.py
@@ -1,4 +1,5 @@
import argparse
+import asyncio
import gc
import json
import os
@@ -7,7 +8,6 @@
import traceback
import uuid
from contextlib import asynccontextmanager
-from datetime import datetime
from typing import Any, List, Literal, Optional, Union
import mlx.core as mx
@@ -31,11 +31,45 @@
DEFAULT_THINKING_END_TOKEN,
DEFAULT_THINKING_START_TOKEN,
DEFAULT_TOP_P,
+ PromptCacheState,
generate,
normalize_resize_shape,
stream_generate,
)
from .prompt_utils import apply_chat_template
+from .responses_models import ContentPartOutputText as ResponseContentPartOutputText
+from .responses_models import InputTokensDetails
+from .responses_models import ResponseCompletedEvent as ResponsesCompletedEvent
+from .responses_models import (
+ ResponseContentPartAddedEvent as ResponsesContentPartAddedEvent,
+)
+from .responses_models import (
+ ResponseContentPartDoneEvent as ResponsesContentPartDoneEvent,
+)
+from .responses_models import ResponseCreatedEvent as ResponsesCreatedEvent
+from .responses_models import (
+ ResponseFunctionCallArgumentsDeltaEvent as ResponsesFunctionCallArgumentsDeltaEvent,
+)
+from .responses_models import (
+ ResponseFunctionCallArgumentsDoneEvent as ResponsesFunctionCallArgumentsDoneEvent,
+)
+from .responses_models import ResponseFunctionCallItem, ResponseIncompleteDetails
+from .responses_models import ResponseInProgressEvent as ResponsesInProgressEvent
+from .responses_models import ResponseMessageItem, ResponseObject
+from .responses_models import (
+ ResponseOutputItemAddedEvent as ResponsesOutputItemAddedEvent,
+)
+from .responses_models import (
+ ResponseOutputItemDoneEvent as ResponsesOutputItemDoneEvent,
+)
+from .responses_models import (
+ ResponseOutputTextDeltaEvent as ResponsesOutputTextDeltaEvent,
+)
+from .responses_models import (
+ ResponseOutputTextDoneEvent as ResponsesOutputTextDoneEvent,
+)
+from .responses_models import ResponsesRequest, ResponseUsage
+from .responses_store import ResponseStore
from .tool_parsers import _infer_tool_parser, load_tool_module
from .utils import load
from .version import __version__
@@ -44,11 +78,56 @@
DEFAULT_SERVER_HOST = "0.0.0.0"
DEFAULT_SERVER_PORT = 8080
+def _is_verbose() -> bool:
+ return os.environ.get("VERBOSE", "").lower() in ("1", "true", "yes")
+
+
+class _VerboseFlag:
+ """Lazy flag that checks env var on each access, so --verbose works after import."""
+ def __bool__(self) -> bool:
+ return _is_verbose()
+
+
+_verbose = _VerboseFlag()
+
+
+def get_default_max_tokens() -> int:
+ """Server-side default max tokens for API responses.
+ The upstream generate.py default (256) is too low for agentic use.
+ Configurable via --default-max-tokens CLI flag or DEFAULT_MAX_TOKENS env var."""
+ return int(os.environ.get("DEFAULT_MAX_TOKENS", DEFAULT_MAX_TOKENS))
+
+_responses_store = ResponseStore()
+
def get_prefill_step_size():
return int(os.environ.get("PREFILL_STEP_SIZE", DEFAULT_PREFILL_STEP_SIZE))
+def get_max_context_tokens() -> int:
+ """Maximum prompt tokens before rejecting a request. 0 = no limit."""
+ return int(os.environ.get("MAX_CONTEXT_TOKENS", 0))
+
+
+def get_request_timeout() -> int:
+ """Maximum seconds for a generation request. 0 = no timeout."""
+ return int(os.environ.get("REQUEST_TIMEOUT", 300))
+
+
+def check_context_length(prompt: str, processor, max_context: int) -> None:
+    """Reject prompts longer than *max_context* tokens with an HTTP 400.
+
+    A non-positive *max_context* disables the check entirely.
+    """
+    if max_context <= 0:
+        return
+    encoder = getattr(processor, "tokenizer", processor)
+    n_tokens = len(encoder.encode(prompt, add_special_tokens=False))
+    if n_tokens > max_context:
+        message = f"Prompt length ({n_tokens} tokens) exceeds maximum context window ({max_context} tokens)."
+        raise HTTPException(status_code=400, detail=message)
+
+
def get_quantized_kv_bits(model: str):
kv_bits = float(os.environ.get("KV_BITS", 0))
if kv_bits == 0:
@@ -85,6 +164,16 @@ def get_quantized_kv_start():
return int(os.environ.get("QUANTIZED_KV_START", DEFAULT_QUANTIZED_KV_START))
+async def _prompt_cache_cleanup_loop():
+    """Run forever, calling evict_stale_prompt_caches() once every 60 seconds."""
+    while True:
+        await asyncio.sleep(60)  # sleep first, so the first sweep runs 60s after start
+        try:
+            evict_stale_prompt_caches()
+        except Exception as e:  # log and keep looping so one failed sweep does not end the task
+            print(f"[prompt_cache] Cleanup error: {e}")
+
+
@asynccontextmanager
async def lifespan(app):
# Startup
@@ -97,7 +186,24 @@ async def lifespan(app):
except Exception as e:
print(f"Failed to preload model: {e}")
print("Server will continue without a preloaded model.")
+
+ # Start prompt cache cleanup task
+ ttl = get_prompt_cache_ttl()
+ cleanup_task = None
+ if ttl > 0:
+ cleanup_task = asyncio.create_task(_prompt_cache_cleanup_loop())
+ if _verbose:
+ print(f"[prompt_cache] Cleanup task started (TTL={ttl}s, check every 60s)")
+
yield
+
+ # Shutdown
+ if cleanup_task is not None:
+ cleanup_task.cancel()
+ try:
+ await cleanup_task
+ except asyncio.CancelledError:
+ pass
unload_model_sync()
@@ -122,6 +228,112 @@ async def lifespan(app):
model_cache = {}
+# Prompt cache: reuse KV state across requests with the same prompt prefix.
+# Keyed by (model_name, cache_key) — supports both:
+# - OpenClaw: sends `prompt_cache_key` for per-session routing
+# - Hermes: relies on stable system prompt prefix for automatic matching
+# When no cache_key is provided, falls back to model name only.
+DEFAULT_PROMPT_CACHE_TTL = 300 # seconds
+
+
+def get_prompt_cache_ttl() -> int:
+ """Prompt cache TTL in seconds. 0 = no expiry."""
+ return int(os.environ.get("PROMPT_CACHE_TTL", DEFAULT_PROMPT_CACHE_TTL))
+
+
+_PROMPT_CACHE_MAX_ENTRIES = 64
+_prompt_cache_states: dict[str, PromptCacheState] = {}
+
+# Concurrency guard: MLX generation is single-threaded on Metal.
+# Concurrent requests would corrupt shared GPU state. The semaphore
+# serializes access to the generation pipeline.
+_generation_semaphore: Optional[asyncio.Semaphore] = None
+
+
+def get_max_concurrent_requests() -> int:
+ return int(os.environ.get("MAX_CONCURRENT_REQUESTS", 1))
+
+
+def get_generation_semaphore() -> asyncio.Semaphore:
+    """Get or create the generation semaphore (slot count clamped to at least 1)."""
+    global _generation_semaphore
+    if _generation_semaphore is None:  # 0 slots could never be acquired; negatives raise ValueError
+        _generation_semaphore = asyncio.Semaphore(max(1, get_max_concurrent_requests()))
+    return _generation_semaphore
+
+
+async def acquire_semaphore() -> asyncio.Semaphore:
+    """Wait for a free generation slot and return the semaphore holding it.
+
+    The wait is capped at get_request_timeout() seconds (non-positive = no cap).
+
+    Raises:
+        HTTPException: 503 when no slot frees up before the timeout expires.
+    """
+    semaphore = get_generation_semaphore()
+    wait_limit = get_request_timeout()
+    try:
+        await asyncio.wait_for(semaphore.acquire(), timeout=wait_limit if wait_limit > 0 else None)
+    except asyncio.TimeoutError:
+        raise HTTPException(status_code=503, detail="Server busy: generation request timed out waiting for GPU.")
+    return semaphore
+
+
+def get_prompt_cache_state(
+    model_name: str,
+    cache_key: Optional[str] = None,
+) -> PromptCacheState:
+    """Return the PromptCacheState for a model / cache-key pair, creating it on demand.
+
+    States live in the module-level ``_prompt_cache_states`` dict. With a truthy
+    ``cache_key`` the entry is stored under ``"<model_name>::<cache_key>"``, so
+    every key gets a state of its own; otherwise a single entry named after the
+    model is shared by all callers.
+
+    When a new entry is needed and the dict already holds
+    ``_PROMPT_CACHE_MAX_ENTRIES`` states, the one with the oldest ``last_used``
+    timestamp is removed and invalidated first. The returned state is touched,
+    which refreshes its ``last_used`` timestamp.
+
+    Args:
+        model_name: The model identifier.
+        cache_key: Optional routing key (e.g. ``prompt_cache_key`` from the request).
+    """
+    slot = model_name if not cache_key else f"{model_name}::{cache_key}"
+    existing = _prompt_cache_states.get(slot)
+    if existing is not None:
+        existing.touch()
+        return existing
+    if len(_prompt_cache_states) >= _PROMPT_CACHE_MAX_ENTRIES:  # at capacity: evict the LRU entry
+        oldest = min(_prompt_cache_states.items(), key=lambda kv: kv[1].last_used)[0]
+        _prompt_cache_states.pop(oldest).invalidate()
+    fresh = _prompt_cache_states[slot] = PromptCacheState()
+    fresh.touch()
+    return fresh
+
+
+def evict_stale_prompt_caches() -> int:
+    """Drop every prompt cache entry idle for longer than the TTL.
+
+    Does nothing when the TTL is non-positive. Returns the number evicted.
+    """
+    max_idle = get_prompt_cache_ttl()
+    if max_idle <= 0:
+        return 0
+    checked_at = time.time()
+    idle_by_key = {name: checked_at - state.last_used for name, state in _prompt_cache_states.items()}
+    expired = [name for name, idle in idle_by_key.items() if idle > max_idle]
+    for name in expired:
+        state = _prompt_cache_states.pop(name)
+        n_tokens = state.token_count  # read before invalidate() clears token_ids
+        state.invalidate()
+        if _verbose:
+            print(f"[prompt_cache] Evicted '{name}' ({n_tokens} tokens, idle {idle_by_key[name]:.0f}s)")
+    if expired:
+        gc.collect()
+        mx.clear_cache()
+    return len(expired)
+
class FlexibleBaseModel(BaseModel):
"""Base model that ignores/accepts any unknown OpenAI SDK fields."""
@@ -204,6 +416,8 @@ def unload_model_sync():
if "vision_cache" in model_cache:
model_cache["vision_cache"].clear()
model_cache = {}
+ # Clear prompt cache states
+ _prompt_cache_states.clear()
# Force garbage collection
gc.collect()
mx.clear_cache()
@@ -376,12 +590,16 @@ class OpenAIRequest(GenerationParams, TemplateParams):
)
model: str = Field(..., description="The model to use for generation.")
max_output_tokens: int = Field(
- DEFAULT_MAX_TOKENS,
+ None,
description="Maximum number of tokens to generate.",
)
stream: bool = Field(
False, description="Whether to stream the response chunk by chunk."
)
+ stop: Optional[Union[str, List[str]]] = Field(
+ None,
+ description="Up to 4 sequences where the API will stop generating further tokens.",
+ )
def generation_kwargs(self) -> dict[str, Any]:
kwargs = self.dump_kwargs("max_output_tokens")
@@ -554,11 +772,15 @@ class VLMRequest(GenerationParams, TemplateParams):
adapter_path: Optional[str] = Field(
None, description="The path to the adapter weights."
)
- max_tokens: int = Field(
- DEFAULT_MAX_TOKENS,
- description="Maximum number of tokens to generate.",
+ max_tokens: Optional[int] = Field(
+ None,
+ description="Maximum number of tokens to generate. Uses server default if not specified.",
)
seed: int = Field(DEFAULT_SEED, description="Seed for random generation.")
+ stop: Optional[Union[str, List[str]]] = Field(
+ None,
+ description="Up to 4 sequences where the API will stop generating further tokens.",
+ )
resize_shape: Optional[ResizeShapeInput] = Field(
None,
description="Resize shape for the image. Provide one integer for a square resize or two integers for (height, width).",
@@ -602,11 +824,40 @@ class UsageStats(OpenAIUsage):
class ChatRequest(GenerationRequest):
messages: List[ChatMessage]
+ logprobs: Optional[bool] = Field(
+ None, description="Whether to return log probabilities."
+ )
+ top_logprobs: Optional[int] = Field(
+ None,
+ ge=0,
+ le=20,
+ description="Number of most likely tokens to return at each position.",
+ )
+
+ @field_validator("top_logprobs")
+ @classmethod
+ def validate_top_logprobs_supported(cls, value):
+ if value is not None:
+ raise ValueError(
+ "`top_logprobs` is not supported by this server and must be omitted."
+ )
+ return value
+
+
+class TokenLogprob(BaseModel):
+ token: str
+ logprob: float
+ bytes: Optional[List[int]] = None
+
+
+class ChoiceLogprobs(BaseModel):
+ content: Optional[List[TokenLogprob]] = None
class ChatChoice(BaseModel):
finish_reason: str
message: ChatMessage
+ logprobs: Optional[ChoiceLogprobs] = None
class ChatResponse(BaseModel):
@@ -619,6 +870,7 @@ class ChatStreamChoice(BaseModel):
index: int = 0
finish_reason: Optional[str] = None
delta: ChatMessage
+ logprobs: Optional[ChoiceLogprobs] = None
class ChatStreamChunk(BaseModel):
@@ -630,10 +882,106 @@ class ChatStreamChunk(BaseModel):
usage: Optional[UsageStats]
+def resolve_stop_sequences(
+    stop: Optional[Union[str, list]],
+) -> Optional[list]:
+    """Normalize stop sequences for the generation stopping criteria.
+
+    Strings are returned as-is (no tokenization here). Non-string and empty
+    entries are dropped before the four-sequence cap is applied.
+
+    Args:
+        stop: A single stop string or list of stop strings, or None.
+
+    Returns:
+        A list of at most 4 non-empty stop strings, or None if there are none.
+    """
+    if not stop:
+        return None
+    if isinstance(stop, str):
+        stop = [stop]
+    sequences = [s for s in stop if isinstance(s, str) and s][:4]
+    return sequences or None
+
+
+def resolve_tool_choice(
+    tools: Optional[list],
+    tool_choice: Optional[Any],
+) -> tuple[Optional[list], Optional[str]]:
+    """Apply tool_choice policy to the tools list.
+
+    A dict ``tool_choice`` may name the tool either nested
+    (``{"function": {"name": ...}}``) or flat (``{"name": ...}``), mirroring
+    the two tool shapes matched below. Non-dict tool entries never match.
+
+    Args:
+        tools: The original tools list from the request.
+        tool_choice: ``"none"``, ``"auto"``, ``"required"``, or a dict
+            specifying a particular tool.
+
+    Returns:
+        Tuple of ``(filtered_tools, system_instruction)``. If no tool matches
+        the requested name, the full tools list is kept.
+    """
+    if not tools or tool_choice is None or tool_choice == "auto":
+        return tools, None
+    if tool_choice == "none":
+        return None, None
+    if tool_choice == "required":
+        return tools, "You must call one of the available tools to answer this request."
+    if not isinstance(tool_choice, dict):
+        return tools, None
+
+    func = tool_choice.get("function")
+    name = (func.get("name") if isinstance(func, dict) else None) or tool_choice.get("name")
+    if not name:
+        return tools, None
+    filtered = [
+        t
+        for t in tools
+        if isinstance(t, dict)
+        and ((t.get("function") or {}).get("name") == name or t.get("name") == name)
+    ]
+    return filtered or tools, f'You must call the "{name}" tool to answer this request.'
+
+
+def resolve_response_format(
+    messages: list,
+    response_format: Optional[dict],
+) -> list:
+    """Prepend a JSON-only system instruction when json_object format is requested.
+
+    The caller's list is never mutated: when the instruction is needed a new
+    list is returned; otherwise *messages* is returned as-is.
+    """
+    if not response_format or response_format.get("type", "text") != "json_object":
+        return messages
+    json_instruction = {
+        "role": "system",
+        "content": "You must respond with valid JSON only. Do not include any text outside the JSON object.",
+    }
+    # Build a new list instead of insert(0, ...) so the request's messages stay intact.
+    return [json_instruction, *messages]
+
+
+DEFAULT_REPETITION_PENALTY = 1.1
+
+
def build_generation_kwargs(
request: Any,
template_kwargs: dict[str, Any],
) -> dict[str, Any]:
+ gen_kwargs = request.generation_kwargs()
+ # Apply server-side default max_tokens if not specified in request.
+ default_max = get_default_max_tokens()
+ if "max_tokens" not in gen_kwargs or gen_kwargs["max_tokens"] is None:
+ gen_kwargs["max_tokens"] = default_max
+ # Apply server-side default repetition penalty if not specified in request.
+ # Prevents MoE models from degenerating into repetition loops.
+ if "repetition_penalty" not in gen_kwargs or gen_kwargs["repetition_penalty"] is None:
+ default_rp = float(os.environ.get("DEFAULT_REPETITION_PENALTY", DEFAULT_REPETITION_PENALTY))
+ if default_rp > 0:
+ gen_kwargs["repetition_penalty"] = default_rp
return {
"prefill_step_size": get_prefill_step_size(),
"kv_bits": get_quantized_kv_bits(request.model),
@@ -641,7 +989,7 @@ def build_generation_kwargs(
"kv_quant_scheme": get_kv_quant_scheme(),
"max_kv_size": get_max_kv_size(request.model),
"quantized_kv_start": get_quantized_kv_start(),
- **request.generation_kwargs(),
+ **gen_kwargs,
**template_kwargs,
}
@@ -704,273 +1052,628 @@ class ModelsResponse(BaseModel):
data: List[ModelInfo]
-# OpenAI compatile endpoints
+# ---------------------------------------------------------------------------
+# Responses API helpers
+# ---------------------------------------------------------------------------
-@app.post("/responses")
-@app.post("/v1/responses", include_in_schema=False)
-async def responses_endpoint(openai_request: OpenAIRequest):
+def responses_input_to_messages(
+ input_items: Union[str, list],
+ instructions: Optional[str] = None,
+ previous_response_id: Optional[str] = None,
+) -> tuple[list[dict], list[str]]:
+ """Convert Responses API input items to chat messages and images.
+
+ Args:
+ input_items: String input or list of input items.
+ instructions: Optional system instructions to prepend.
+ previous_response_id: Optional previous response ID for context replay.
+
+ Returns:
+ Tuple of (chat_messages, image_urls).
"""
- OpenAI-compatible endpoint for generating text based on a prompt and optional images.
+ chat_messages: list[dict] = []
+ images: list[str] = []
+
+ # Replay previous response context
+ if previous_response_id:
+ replayed = _responses_store.replay_input(previous_response_id)
+ if replayed is None:
+ raise HTTPException(
+ status_code=404,
+ detail=f"Previous response not found: {previous_response_id}",
+ )
+ # Recursively process replayed items
+ prev_messages, prev_images = responses_input_to_messages(replayed)
+ chat_messages.extend(prev_messages)
+ images.extend(prev_images)
+
+ # Prepend instructions as system message
+ if instructions:
+ chat_messages.insert(0, {"role": "system", "content": instructions})
+
+ # Handle string input
+ if isinstance(input_items, str):
+ chat_messages.append({"role": "user", "content": input_items})
+ return chat_messages, images
+
+ # Handle list of input items
+ for item in input_items:
+ if isinstance(item, dict):
+ item_type = item.get("type", "")
+ role = item.get("role", "")
+
+ # Function call output item
+ if item_type == "function_call_output":
+ call_id = item.get("call_id", "unknown")
+ output = item.get("output", "")
+ chat_messages.append(
+ {
+ "role": "tool",
+ "content": output,
+ "tool_call_id": call_id,
+ }
+ )
+ continue
- using client.responses.create method.
+ # Function call item (from previous assistant turn)
+ if item_type == "function_call":
+ chat_messages.append(
+ {
+ "role": "assistant",
+ "content": None,
+ "tool_calls": [
+ {
+ "id": item.get("call_id", ""),
+ "type": "function",
+ "function": {
+ "name": item.get("name", ""),
+ "arguments": item.get("arguments", ""),
+ },
+ }
+ ],
+ }
+ )
+ continue
+
+ # Regular message with role and content
+ if role:
+ content = item.get("content", "")
+
+ # Normalize developer role to system
+ msg_role = "system" if role == "developer" else role
+
+ if isinstance(content, str):
+ chat_messages.append({"role": msg_role, "content": content})
+ elif isinstance(content, list):
+ # Process content items
+ text_parts = []
+ for ci in content:
+ if isinstance(ci, dict):
+ ci_type = ci.get("type", "")
+ if ci_type in ("input_text", "text"):
+ text_parts.append(ci.get("text", ""))
+ elif ci_type == "input_image":
+ images.append(ci.get("image_url", ""))
+ elif ci_type == "image_url":
+ img = ci.get("image_url", {})
+ if isinstance(img, dict):
+ images.append(img.get("url", ""))
+ elif isinstance(img, str):
+ images.append(img)
+ elif ci_type == "output_text":
+ # Multi-turn: previous assistant output
+ chat_messages.append(
+ {
+ "role": "assistant",
+ "content": ci.get("text", ""),
+ }
+ )
+ elif ci_type == "input_audio":
+ pass # Audio not yet supported in responses
+ else:
+ pass # Skip unsupported content types gracefully
+
+ if text_parts:
+ chat_messages.append(
+ {
+ "role": msg_role,
+ "content": "\n".join(text_parts),
+ }
+ )
+ else:
+ chat_messages.append(
+ {
+ "role": msg_role,
+ "content": str(content) if content else "",
+ }
+ )
+ continue
+
+ # Handle Pydantic ChatMessage objects
+ elif hasattr(item, "role"):
+ role = item.role
+ msg_role = "system" if role == "developer" else role
+ content = item.content
+
+ if content is None:
+ chat_messages.append({"role": msg_role, "content": ""})
+ elif isinstance(content, str):
+ chat_messages.append({"role": msg_role, "content": content})
+ elif isinstance(content, list):
+ text_parts = []
+ for ci in content:
+ if isinstance(ci, dict):
+ ci_type = ci.get("type", "")
+ if ci_type in ("input_text", "text"):
+ text_parts.append(ci.get("text", ""))
+ elif ci_type == "input_image":
+ images.append(ci.get("image_url", ""))
+ elif ci_type == "image_url":
+ img = ci.get("image_url", {})
+ if isinstance(img, dict):
+ images.append(img.get("url", ""))
+ elif isinstance(img, str):
+ images.append(img)
+ elif ci_type == "output_text":
+ chat_messages.append(
+ {
+ "role": "assistant",
+ "content": ci.get("text", ""),
+ }
+ )
- example:
+ if text_parts:
+ chat_messages.append(
+ {
+ "role": msg_role,
+ "content": "\n".join(text_parts),
+ }
+ )
- from openai import OpenAI
+ return chat_messages, images
- API_URL = "http://0.0.0.0:8000"
- API_KEY = 'any'
- def run_openai(prompt, img_url,system, stream=False, max_output_tokens=512, model="mlx-community/Qwen2.5-VL-3B-Instruct-8bit"):
- ''' Calls the OpenAI API
- '''
+def build_responses_output(
+ raw_text: str,
+ tool_parser_type: Optional[str],
+ tool_module: Optional[Any],
+ tools: Optional[list],
+) -> list[Union[ResponseMessageItem, ResponseFunctionCallItem]]:
+ """Build structured Responses API output items from raw model text.
- client = OpenAI(base_url=f"{API_URL}", api_key=API_KEY)
+ Parses tool calls from the raw text if a tool parser is available,
+ creating ResponseFunctionCallItem for each detected call and a
+ ResponseMessageItem for any remaining text.
- try :
- response = client.responses.create(
- model=model,
- input=[
- {"role":"system",
- "content": f"{system}"
- },
- {
- "role": "user",
- "content": [
- {"type": "input_text", "text": prompt},
- {"type": "input_image", "image_url": f"{img_url}"},
- ],
- }
- ],
- max_output_tokens=max_output_tokens,
- stream=stream
- )
- if not stream:
- print(response.output[0].content[0].text)
- print(response.usage)
- else:
- for event in response:
- # Process different event types if needed
- if hasattr(event, 'delta') and event.delta:
- print(event.delta, end="", flush=True)
- elif event.type == 'response.completed':
- print("\n--- Usage ---")
- print(event.response.usage)
+ Args:
+ raw_text: The raw text output from the model.
+ tool_parser_type: The detected tool parser type (e.g., "gemma4"), or None.
+ tool_module: The loaded tool parser module, or None.
+ tools: The tool definitions from the request, or None.
- except Exception as e:
- # building a response object to match the one returned when request is successful so that it can be processed in the same way
- return {"model - error":str(e),"content":{}, "model":model}
+ Returns:
+ List of output items (message items and/or function call items).
+ """
+ output_items: list[Union[ResponseMessageItem, ResponseFunctionCallItem]] = []
+ remaining_text = raw_text
+ # Try to parse tool calls
+ if tool_parser_type and tool_module and tools:
+ try:
+ result = process_tool_calls(raw_text, tool_module, tools)
+ if result["calls"]:
+ for call in result["calls"]:
+ func_info = call.get("function", {})
+ output_items.append(
+ ResponseFunctionCallItem(
+ name=func_info.get("name", ""),
+ arguments=func_info.get("arguments", "{}"),
+ call_id=call.get("id", f"call_{uuid.uuid4().hex[:24]}"),
+ )
+ )
+ remaining_text = result.get("remaining_text", "").strip()
+ except Exception:
+ # If tool parsing fails, fall through to plain text
+ remaining_text = raw_text
+
+ # Create message item for any remaining text
+ if remaining_text or not output_items:
+ msg_item = ResponseMessageItem(
+ content=(
+ [ResponseContentPartOutputText(text=remaining_text)]
+ if remaining_text
+ else []
+ ),
+ )
+ # Insert message before function calls (matching OpenAI ordering)
+ output_items.insert(0, msg_item)
+
+ return output_items
+
+
+# OpenAI compatible endpoints
+
+
+@app.post("/responses")
+@app.post("/v1/responses", include_in_schema=False)
+async def responses_endpoint(request: ResponsesRequest):
+ """OpenAI-compatible Responses API endpoint.
+
+ Supports tool calling, multi-turn via previous_response_id, and streaming
+ with proper SSE event sequences including function_call argument events.
"""
+ # Resolve default max tokens if not specified in request
+ if request.max_output_tokens is None:
+ request.max_output_tokens = get_default_max_tokens()
try:
# Get model, processor, config - loading if necessary
- model, processor, config = get_cached_model(openai_request.model)
+ model, processor, config = get_cached_model(request.model)
+
+ # Debug: log incoming request details
+ _tools_count = len(request.tools) if request.tools else 0
+ _tool_names = [t.get("name", t.get("function", {}).get("name", "?")) if isinstance(t, dict) else "?" for t in (request.tools or [])]
+ _input_len = len(str(request.input))
+ _instructions_len = len(request.instructions) if request.instructions else 0
+ if _verbose:
+ print(f"[responses] tools={_tools_count} names={_tool_names} stream={request.stream} input_chars={_input_len} instructions_chars={_instructions_len}")
+
+ # Convert input to chat messages
+ chat_messages, images = responses_input_to_messages(
+ request.input,
+ instructions=request.instructions,
+ previous_response_id=request.previous_response_id,
+ )
- chat_messages = []
- images = []
- instructions = None
- if openai_request.input:
- if isinstance(openai_request.input, str):
- # If input is a string, treat it as a single text message
- chat_messages.append({"role": "user", "content": openai_request.input})
- elif isinstance(openai_request.input, list):
- # If input is a list, treat it as a series of chat messages
- for message in openai_request.input:
- if isinstance(message, ChatMessage):
- if message.content is None:
- chat_messages.append({"role": message.role, "content": ""})
- elif isinstance(message.content, str):
- chat_messages.append(
- {"role": message.role, "content": message.content}
- )
- if message.role == "system":
- instructions = message.content
- elif isinstance(message.content, list):
- # Handle list of content items
- for item in message.content:
- if isinstance(item, dict):
- if item["type"] == "input_text":
- chat_messages.append(
- {
- "role": message.role,
- "content": item["text"],
- }
- )
- if message.role == "system":
- instructions = item["text"]
- # examples for multiple images (https://platform.openai.com/docs/guides/images?api-mode=responses)
- elif item["type"] == "input_image":
- images.append(item["image_url"])
- else:
- print(
- f"invalid input item type: {item['type']}"
- )
- raise HTTPException(
- status_code=400,
- detail="Invalid input item type.",
- )
- else:
- print(
- f"Invalid message content item format: {item}"
- )
- raise HTTPException(
- status_code=400,
- detail="Missing type in input item.",
- )
- else:
- print("Invalid message content format.")
- raise HTTPException(
- status_code=400, detail="Invalid input format."
- )
- else:
- print("not a ChatMessage")
- raise HTTPException(
- status_code=400, detail="Invalid input format."
- )
- else:
- print("neither string not list")
- raise HTTPException(status_code=400, detail="Invalid input format.")
+ # Apply JSON mode if requested
+ response_format = getattr(request, "response_format", None)
+ chat_messages = resolve_response_format(chat_messages, response_format)
- else:
- print("no input")
- raise HTTPException(status_code=400, detail="Missing input.")
+ # Set up tool parser (apply tool_choice policy)
+ tools = request.tools
+ tool_choice_val = getattr(request, "tool_choice", "auto")
+ tools, tool_instruction = resolve_tool_choice(tools, tool_choice_val)
+ if tool_instruction:
+ chat_messages.insert(0, {"role": "system", "content": tool_instruction})
+
+ tool_parser_type = None
+ tool_module = None
+ tokenizer = (
+ processor.tokenizer if hasattr(processor, "tokenizer") else processor
+ )
+ if hasattr(tokenizer, "chat_template") and tools:
+ tool_parser_type = _infer_tool_parser(tokenizer.chat_template)
+ if tool_parser_type is not None:
+ tool_module = load_tool_module(tool_parser_type)
+ if _verbose:
+ print(f"[responses] tool_parser={tool_parser_type} tool_module={'yes' if tool_module else 'no'} tools_after_choice={len(tools) if tools else 0}")
+
+ # Build template kwargs
+ template_kwargs = request.template_kwargs()
- template_kwargs = openai_request.template_kwargs()
+ # Apply chat template (pass tools so the template can include tool defs)
formatted_prompt = apply_chat_template(
processor,
config,
chat_messages,
num_images=len(images),
+ tools=tools,
**template_kwargs,
)
- generation_kwargs = build_generation_kwargs(openai_request, template_kwargs)
+ generation_kwargs = build_generation_kwargs(request, template_kwargs)
- generated_at = datetime.now().timestamp()
- response_id = f"resp_{uuid.uuid4().hex}"
- message_id = f"msg_{uuid.uuid4().hex}"
+ if _verbose:
+ _prompt_str = formatted_prompt if isinstance(formatted_prompt, str) else str(formatted_prompt)
+ print(f"[responses] prompt_chars={len(_prompt_str)} last_300=...{_prompt_str[-300:]!r}")
- if openai_request.stream:
+ check_context_length(formatted_prompt, processor, get_max_context_tokens())
+
+        # Normalize stop sequences (kept as strings) and pass them as extra EOS tokens
+ stop_seqs = resolve_stop_sequences(getattr(request, "stop", None))
+ if stop_seqs:
+ generation_kwargs["eos_tokens"] = stop_seqs
+
+ generated_at = int(time.time())
+ response_id = f"resp_{uuid.uuid4().hex[:24]}"
+ message_id = f"msg_{uuid.uuid4().hex[:24]}"
+
+ if request.stream:
+ # ----------------------------------------------------------
# Streaming response
- async def stream_generator():
- token_iterator = None
+ # ----------------------------------------------------------
+ async def stream_responses_generator():
+ seq = 0 # sequence_number counter
+
+ def _evt(event_type: str, event_obj) -> str:
+ nonlocal seq
+ event_obj.sequence_number = seq
+ seq += 1
+ return (
+ f"event: {event_type}\ndata: {event_obj.model_dump_json()}\n\n"
+ )
+
+ sem = None
try:
- # Create base response object (to match the openai pipeline)
- base_response = OpenAIResponse(
+ sem = await acquire_semaphore()
+ # Build base ResponseObject (in_progress, empty output)
+ base_response = ResponseObject(
id=response_id,
- object="response",
- created_at=int(generated_at),
+ created_at=generated_at,
status="in_progress",
- instructions=instructions,
- max_output_tokens=openai_request.max_output_tokens,
- model=openai_request.model,
+ model=request.model,
output=[],
- output_text="",
- temperature=openai_request.temperature,
- top_p=openai_request.top_p,
- usage={
- "input_tokens": 0, # get prompt tokens
- "output_tokens": 0,
- "total_tokens": 0,
- },
+ instructions=request.instructions,
+ max_output_tokens=request.max_output_tokens,
+ temperature=request.temperature,
+ top_p=request.top_p,
+ tools=tools or [],
+ tool_choice=request.tool_choice,
+ parallel_tool_calls=request.parallel_tool_calls,
+ previous_response_id=request.previous_response_id,
+ metadata=request.metadata,
+ usage=ResponseUsage(
+ input_tokens=0, output_tokens=0, total_tokens=0
+ ),
)
- # Send response.created event (to match the openai pipeline)
- yield f"event: response.created\ndata: {ResponseCreatedEvent(type='response.created', response=base_response).model_dump_json()}\n\n"
-
- # Send response.in_progress event (to match the openai pipeline)
- yield f"event: response.in_progress\ndata: {ResponseInProgressEvent(type='response.in_progress', response=base_response).model_dump_json()}\n\n"
-
- # Send response.output_item.added event (to match the openai pipeline)
- message_item = MessageItem(
- id=message_id,
- type="message",
- status="in_progress",
- role="assistant",
- content=[],
+ # response.created
+ yield _evt(
+ "response.created",
+ ResponsesCreatedEvent(response=base_response),
+ )
+ # response.in_progress
+ yield _evt(
+ "response.in_progress",
+ ResponsesInProgressEvent(response=base_response),
)
- yield f"event: response.output_item.added\ndata: {ResponseOutputItemAddedEvent(type='response.output_item.added', output_index=0, item=message_item).model_dump_json()}\n\n"
- # Send response.content_part.added event
- content_part = ContentPartOutputText(
- type="output_text", text="", annotations=[]
+ # output_item.added (message)
+ msg_item = ResponseMessageItem(
+ id=message_id, status="in_progress", content=[]
+ )
+ yield _evt(
+ "response.output_item.added",
+ ResponsesOutputItemAddedEvent(output_index=0, item=msg_item),
)
- yield f"event: response.content_part.added\ndata: {ResponseContentPartAddedEvent(type='response.content_part.added', item_id=message_id, output_index=0, content_index=0, part=content_part).model_dump_json()}\n\n"
- # Stream text deltas
+ # content_part.added
+ empty_part = ResponseContentPartOutputText(text="")
+ yield _evt(
+ "response.content_part.added",
+ ResponsesContentPartAddedEvent(
+ item_id=message_id,
+ output_index=0,
+ content_index=0,
+ part=empty_part,
+ ),
+ )
+ cache_state = get_prompt_cache_state(
+ request.model, getattr(request, "prompt_cache_key", None)
+ )
token_iterator = stream_generate(
model=model,
processor=processor,
prompt=formatted_prompt,
image=images,
vision_cache=model_cache.get("vision_cache"),
+ prompt_cache_state=cache_state,
**generation_kwargs,
)
full_text = ""
+ visible_text = ""
+ usage_stats = {"input_tokens": 0, "output_tokens": 0}
+ in_tool_call = False
+ tool_call_start_tag = (
+ tool_module.tool_call_start if tool_module else ""
+ )
+
for chunk in token_iterator:
if chunk is None or not hasattr(chunk, "text"):
continue
delta = chunk.text
full_text += delta
-
usage_stats = {
"input_tokens": chunk.prompt_tokens,
"output_tokens": chunk.generation_tokens,
}
- # Send response.output_text.delta event
- yield f"event: response.output_text.delta\ndata: {ResponseOutputTextDeltaEvent(type='response.output_text.delta', item_id=message_id, output_index=0, content_index=0, delta=delta).model_dump_json()}\n\n"
+ # Suppress tool call tokens from being streamed as text
+ if not in_tool_call and tool_call_start_tag in full_text:
+ in_tool_call = True
+ if in_tool_call:
+ continue
- # Send response.output_text.done event (to match the openai pipeline)
- yield f"event: response.output_text.done\ndata: {ResponseOutputTextDoneEvent(type='response.output_text.done', item_id=message_id, output_index=0, content_index=0, text=full_text).model_dump_json()}\n\n"
+ # Check if this delta starts a tool call tag
+ # (partial match: buffer might end with "")
+ if tools and tool_call_start_tag[:1] in delta:
+ pending = full_text[
+ -(len(delta) + len(tool_call_start_tag)) :
+ ]
+ if any(
+ tool_call_start_tag[:i] == pending[-i:]
+ for i in range(2, len(tool_call_start_tag) + 1)
+ ):
+ continue
+
+ visible_text += delta
+ yield _evt(
+ "response.output_text.delta",
+ ResponsesOutputTextDeltaEvent(
+ item_id=message_id,
+ output_index=0,
+ content_index=0,
+ delta=delta,
+ ),
+ )
- # Send response.content_part.done event (to match the openai pipeline)
- final_content_part = ContentPartOutputText(
- type="output_text", text=full_text, annotations=[]
+ # Determine finish reason
+ max_tok = request.max_output_tokens
+ is_length = usage_stats["output_tokens"] >= max_tok
+ status = "incomplete" if is_length else "completed"
+
+ # Use visible_text (sans tool call markup) for text events
+ display_text = visible_text
+
+ # output_text.done
+ yield _evt(
+ "response.output_text.done",
+ ResponsesOutputTextDoneEvent(
+ item_id=message_id,
+ output_index=0,
+ content_index=0,
+ text=display_text,
+ ),
)
- yield f"event: response.content_part.done\ndata: {ResponseContentPartDoneEvent(type='response.content_part.done', item_id=message_id, output_index=0, content_index=0, part=final_content_part).model_dump_json()}\n\n"
- # Send response.output_item.done event (to match the openai pipeline)
- final_message_item = MessageItem(
+ # content_part.done
+ final_part = ResponseContentPartOutputText(text=display_text)
+ yield _evt(
+ "response.content_part.done",
+ ResponsesContentPartDoneEvent(
+ item_id=message_id,
+ output_index=0,
+ content_index=0,
+ part=final_part,
+ ),
+ )
+
+ # output_item.done (message)
+ final_msg = ResponseMessageItem(
id=message_id,
- type="message",
status="completed",
- role="assistant",
- content=[final_content_part],
+ content=[final_part],
)
- yield f"event: response.output_item.done\ndata: {ResponseOutputItemDoneEvent(type='response.output_item.done', output_index=0, item=final_message_item).model_dump_json()}\n\n"
+ yield _evt(
+ "response.output_item.done",
+ ResponsesOutputItemDoneEvent(output_index=0, item=final_msg),
+ )
+
+ # Collect all output items for final response
+ all_output_items: list = [final_msg]
- # Send response.completed event (to match the openai pipeline)
+ # Parse tool calls from accumulated text
+ if tool_parser_type and tool_module and tools:
+ try:
+ tc_result = process_tool_calls(
+ full_text, tool_module, tools
+ )
+ if tc_result["calls"]:
+ for idx, call in enumerate(tc_result["calls"]):
+ func_info = call.get("function", {})
+ fc_item = ResponseFunctionCallItem(
+ name=func_info.get("name", ""),
+ arguments=func_info.get("arguments", "{}"),
+ call_id=call.get(
+ "id", f"call_{uuid.uuid4().hex[:24]}"
+ ),
+ )
+ out_idx = len(all_output_items)
+
+ # output_item.added (function_call)
+ yield _evt(
+ "response.output_item.added",
+ ResponsesOutputItemAddedEvent(
+ output_index=out_idx, item=fc_item
+ ),
+ )
+
+ # function_call_arguments.delta (full arguments in one shot)
+ yield _evt(
+ "response.function_call_arguments.delta",
+ ResponsesFunctionCallArgumentsDeltaEvent(
+ item_id=fc_item.id,
+ output_index=out_idx,
+ delta=fc_item.arguments,
+ ),
+ )
+
+ # function_call_arguments.done
+ yield _evt(
+ "response.function_call_arguments.done",
+ ResponsesFunctionCallArgumentsDoneEvent(
+ item_id=fc_item.id,
+ output_index=out_idx,
+ arguments=fc_item.arguments,
+ ),
+ )
+
+ # output_item.done (function_call)
+ yield _evt(
+ "response.output_item.done",
+ ResponsesOutputItemDoneEvent(
+ output_index=out_idx, item=fc_item
+ ),
+ )
+
+ all_output_items.append(fc_item)
+ except Exception:
+ pass # Tool parsing failure is non-fatal in streaming
+
+ # response.completed
+ total_tokens = (
+ usage_stats["input_tokens"] + usage_stats["output_tokens"]
+ )
completed_response = base_response.model_copy(
update={
- "status": "completed",
- "output": [final_message_item],
- "usage": {
- "input_tokens": usage_stats["input_tokens"],
- "output_tokens": usage_stats["output_tokens"],
- "total_tokens": usage_stats["input_tokens"]
- + usage_stats["output_tokens"],
- },
+ "status": status,
+ "output": all_output_items,
+ "incomplete_details": (
+ ResponseIncompleteDetails(reason="max_output_tokens")
+ if status == "incomplete"
+ else None
+ ),
+ "usage": ResponseUsage(
+ input_tokens=usage_stats["input_tokens"],
+ output_tokens=usage_stats["output_tokens"],
+ total_tokens=total_tokens,
+ ),
}
)
- yield f"event: response.completed\ndata: {ResponseCompletedEvent(type='response.completed', response=completed_response).model_dump_json()}\n\n"
+ yield _evt(
+ "response.completed",
+ ResponsesCompletedEvent(response=completed_response),
+ )
+
+ # Save to store for previous_response_id
+ _responses_store.save(
+ response_id,
+ (
+ request.input
+ if isinstance(request.input, str)
+ else [
+ (
+ item.model_dump()
+ if hasattr(item, "model_dump")
+ else item
+ )
+ for item in request.input
+ ]
+ ),
+ [item.model_dump() for item in all_output_items],
+ )
+
+ # Final sentinel
+ yield "data: [DONE]\n\n"
except Exception as e:
print(f"Error during stream generation: {e}")
traceback.print_exc()
- error_data = json.dumps({"error": str(e)})
+ error_data = json.dumps({"error": "Internal generation error"})
yield f"data: {error_data}\n\n"
finally:
mx.clear_cache()
gc.collect()
- print("Stream finished, cleared cache.")
+ if sem is not None:
+ sem.release()
+ if _verbose:
+ print("Stream finished, cleared cache.")
return StreamingResponse(
- stream_generator(),
+ stream_responses_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
@@ -980,64 +1683,102 @@ async def stream_generator():
)
else:
+ # ----------------------------------------------------------
# Non-streaming response
+ # ----------------------------------------------------------
+ sem = None
try:
- # Use generate from generate.py
+ sem = await acquire_semaphore()
+ cache_state = get_prompt_cache_state(
+ request.model, getattr(request, "prompt_cache_key", None)
+ )
result = generate(
model=model,
processor=processor,
prompt=formatted_prompt,
image=images,
- verbose=False, # stats are passed in the response
+ verbose=False,
+ vision_cache=model_cache.get("vision_cache"),
+ prompt_cache_state=cache_state,
**generation_kwargs,
)
- # Clean up resources
mx.clear_cache()
gc.collect()
- print("Generation finished, cleared cache.")
+ if _verbose:
+ print("Generation finished, cleared cache.")
+
+ # Build output items (with tool call parsing)
+ output_items = build_responses_output(
+ result.text,
+ tool_parser_type,
+ tool_module,
+ tools,
+ )
+
+ # Determine status
+ is_length = result.generation_tokens >= request.max_output_tokens
+ status = "incomplete" if is_length else "completed"
+ incomplete_details = (
+ ResponseIncompleteDetails(reason="max_output_tokens")
+ if status == "incomplete"
+ else None
+ )
- response = OpenAIResponse(
+ response_obj = ResponseObject(
id=response_id,
- object="response",
- created_at=int(generated_at),
- status="completed",
- instructions=instructions,
- max_output_tokens=openai_request.max_output_tokens,
- model=openai_request.model,
- output=[
- {
- "role": "assistant",
- "content": [
- {
- "type": "output_text",
- "text": result.text,
- }
- ],
- }
- ],
- output_text=result.text,
- temperature=openai_request.temperature,
- top_p=openai_request.top_p,
- usage={
- "input_tokens": result.prompt_tokens,
- "output_tokens": result.generation_tokens,
- "total_tokens": result.total_tokens,
- },
+ created_at=generated_at,
+ model=request.model,
+ output=output_items,
+ status=status,
+ incomplete_details=incomplete_details,
+ instructions=request.instructions,
+ max_output_tokens=request.max_output_tokens,
+ temperature=request.temperature,
+ top_p=request.top_p,
+ tools=tools or [],
+ tool_choice=request.tool_choice,
+ parallel_tool_calls=request.parallel_tool_calls,
+ previous_response_id=request.previous_response_id,
+ metadata=request.metadata,
+ usage=ResponseUsage(
+ input_tokens=result.prompt_tokens,
+ output_tokens=result.generation_tokens,
+ total_tokens=result.total_tokens,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=getattr(result, "cached_tokens", 0),
+ ),
+ ),
+ )
+
+ # Save to store for previous_response_id support
+ _responses_store.save(
+ response_obj.id,
+ (
+ request.input
+ if isinstance(request.input, str)
+ else [
+ item.model_dump() if hasattr(item, "model_dump") else item
+ for item in request.input
+ ]
+ ),
+ [item.model_dump() for item in output_items],
)
- return response
+
+ return response_obj.model_dump()
except Exception as e:
print(f"Error during generation: {e}")
traceback.print_exc()
mx.clear_cache()
gc.collect()
- raise HTTPException(status_code=500, detail=f"Generation failed: {e}")
+ raise HTTPException(status_code=500, detail="Generation failed. Check server logs for details.")
+ finally:
+ if sem is not None:
+ sem.release()
- except HTTPException as http_exc:
- # Re-raise HTTP exceptions (like model loading failure)
- raise http_exc
+ except HTTPException:
+ raise
except Exception as e:
- # Catch unexpected errors
print(f"Unexpected error in /responses endpoint: {e}")
traceback.print_exc()
mx.clear_cache()
@@ -1058,6 +1799,9 @@ async def chat_completions_endpoint(request: ChatRequest):
System message will be ignored if not already in the prompt.
Can operate in streaming or non-streaming mode.
"""
+ # Resolve default max tokens if not specified in request
+ if request.max_tokens is None:
+ request.max_tokens = get_default_max_tokens()
try:
# Get model, processor, config - loading if necessary
@@ -1095,6 +1839,14 @@ async def chat_completions_endpoint(request: ChatRequest):
if hasattr(request, "tools"):
tools = request.tools
+ # Apply tool_choice policy
+ tool_choice = getattr(request, "tool_choice", None)
+ tools, tool_instruction = resolve_tool_choice(tools, tool_choice)
+ if tool_instruction:
+ processed_messages.insert(
+ 0, {"role": "system", "content": tool_instruction}
+ )
+
tool_parser_type = None
tokenizer = (
processor.tokenizer if hasattr(processor, "tokenizer") else processor
@@ -1115,12 +1867,24 @@ async def chat_completions_endpoint(request: ChatRequest):
)
generation_kwargs = build_generation_kwargs(request, template_kwargs)
+ check_context_length(formatted_prompt, processor, get_max_context_tokens())
+
+ # Resolve stop sequences to token IDs
+ stop_seqs = resolve_stop_sequences(getattr(request, "stop", None))
+ if stop_seqs:
+ generation_kwargs["eos_tokens"] = stop_seqs
+
if request.stream:
# Streaming response
async def stream_generator():
+ sem = None
token_iterator = None
try:
- # Use stream_generate from utils
+ sem = await acquire_semaphore()
+ # Use stream_generate with prompt cache reuse
+ cache_state = get_prompt_cache_state(
+ request.model, getattr(request, "prompt_cache_key", None)
+ )
token_iterator = stream_generate(
model=model,
processor=processor,
@@ -1128,14 +1892,17 @@ async def stream_generator():
image=images,
audio=audio,
vision_cache=model_cache.get("vision_cache"),
+ prompt_cache_state=cache_state,
**generation_kwargs,
)
output_text = ""
request_id = f"chatcmpl-{uuid.uuid4()}"
+ want_logprobs = getattr(request, "logprobs", None)
for chunk in token_iterator:
if chunk is None or not hasattr(chunk, "text"):
- print("Warning: Received unexpected chunk format:", chunk)
+ if _verbose:
+ print("Warning: Received unexpected chunk format:", chunk)
continue
output_text += chunk.text
@@ -1151,9 +1918,24 @@ async def stream_generator():
"peak_memory": chunk.peak_memory,
}
+ chunk_logprobs = None
+ if want_logprobs and chunk.token is not None and chunk.logprobs is not None:
+ token_text = tokenizer.decode([chunk.token])
+ chosen_logprob = float(chunk.logprobs[chunk.token])
+ chunk_logprobs = ChoiceLogprobs(
+ content=[
+ TokenLogprob(
+ token=token_text,
+ logprob=chosen_logprob,
+ bytes=list(token_text.encode("utf-8")),
+ )
+ ]
+ )
+
choices = [
ChatStreamChoice(
- delta=ChatMessage(role="assistant", content=chunk.text)
+ delta=ChatMessage(role="assistant", content=chunk.text),
+ logprobs=chunk_logprobs,
)
]
chunk_data = ChatStreamChunk(
@@ -1176,10 +1958,11 @@ async def stream_generator():
tool_calls = {}
tool_calls["calls"] = []
- # Signal stream end
+ # Signal stream end with correct finish_reason
+ stream_finish = "tool_calls" if tool_calls.get("calls") else "stop"
choices = [
ChatStreamChoice(
- finish_reason="stop",
+ finish_reason=stream_finish,
delta=ChatMessage(
role="assistant",
content="",
@@ -1202,13 +1985,16 @@ async def stream_generator():
except Exception as e:
print(f"Error during stream generation: {e}")
traceback.print_exc()
- error_data = json.dumps({"error": str(e)})
+ error_data = json.dumps({"error": "Internal generation error"})
yield f"data: {error_data}\n\n"
finally:
mx.clear_cache()
gc.collect()
- print("Stream finished, cleared cache.")
+ if sem is not None:
+ sem.release()
+ if _verbose:
+ print("Stream finished, cleared cache.")
return StreamingResponse(
stream_generator(),
@@ -1222,22 +2008,60 @@ async def stream_generator():
else:
# Non-streaming response
+ sem = None
try:
- # Use generate from generate.py
- gen_result = generate(
- model=model,
- processor=processor,
- prompt=formatted_prompt,
- image=images,
- audio=audio,
- verbose=False, # Keep API output clean
- vision_cache=model_cache.get("vision_cache"),
- **generation_kwargs,
+ sem = await acquire_semaphore()
+ want_logprobs = getattr(request, "logprobs", None)
+ cache_state = get_prompt_cache_state(
+ request.model, getattr(request, "prompt_cache_key", None)
)
+ token_logprobs = []
+
+ if want_logprobs:
+ # Use stream_generate to collect per-token logprobs
+ full_text = ""
+ gen_result = None
+ for chunk in stream_generate(
+ model=model,
+ processor=processor,
+ prompt=formatted_prompt,
+ image=images,
+ audio=audio,
+ vision_cache=model_cache.get("vision_cache"),
+ prompt_cache_state=cache_state,
+ **generation_kwargs,
+ ):
+ if chunk is None or not hasattr(chunk, "text"):
+ continue
+ full_text += chunk.text
+ if chunk.token is not None and chunk.logprobs is not None:
+ token_text = tokenizer.decode([chunk.token])
+ chosen_logprob = float(chunk.logprobs[chunk.token])
+ token_logprobs.append(
+ TokenLogprob(
+ token=token_text,
+ logprob=chosen_logprob,
+ bytes=list(token_text.encode("utf-8")),
+ )
+ )
+ gen_result = chunk
+ gen_result.text = full_text
+ else:
+ gen_result = generate(
+ model=model,
+ processor=processor,
+ prompt=formatted_prompt,
+ image=images,
+ audio=audio,
+ verbose=False,
+ vision_cache=model_cache.get("vision_cache"),
+ prompt_cache_state=cache_state,
+ **generation_kwargs,
+ )
+
# Clean up resources
mx.clear_cache()
gc.collect()
- print("Generation finished, cleared cache.")
usage_stats = UsageStats(
input_tokens=gen_result.prompt_tokens,
@@ -1259,14 +2083,20 @@ async def stream_generator():
tool_calls["calls"] = []
tool_calls["remaining_text"] = gen_result.text
+ choice_logprobs = None
+ if want_logprobs and token_logprobs:
+ choice_logprobs = ChoiceLogprobs(content=token_logprobs)
+
+ finish = "tool_calls" if tool_calls.get("calls") else "stop"
choices = [
ChatChoice(
- finish_reason="stop",
+ finish_reason=finish,
message=ChatMessage(
role="assistant",
content=tool_calls["remaining_text"],
tool_calls=tool_calls["calls"],
),
+ logprobs=choice_logprobs,
)
]
@@ -1281,7 +2111,10 @@ async def stream_generator():
traceback.print_exc()
mx.clear_cache()
gc.collect()
- raise HTTPException(status_code=500, detail=f"Generation failed: {e}")
+ raise HTTPException(status_code=500, detail="Generation failed. Check server logs for details.")
+ finally:
+ if sem is not None:
+ sem.release()
except HTTPException as http_exc:
# Re-raise HTTP exceptions (like model loading failure)
@@ -1433,6 +2266,48 @@ def main():
default=DEFAULT_QUANTIZED_KV_START,
help="Start index (of token) for the quantized KV cache.",
)
+ parser.add_argument(
+ "--max-concurrent-requests",
+ type=int,
+ default=1,
+ help="Maximum number of concurrent generation requests. "
+ "MLX runs single-threaded on Metal; values > 1 may cause GPU errors. "
+ "(default: %(default)s)",
+ )
+ parser.add_argument(
+ "--max-context-tokens",
+ type=int,
+ default=0,
+ help="Maximum context window in tokens. 0 = no limit. (default: %(default)s)",
+ )
+ parser.add_argument(
+ "--request-timeout",
+ type=int,
+ default=300,
+ help="Maximum seconds per generation request. (default: %(default)s)",
+ )
+ parser.add_argument(
+ "--default-max-tokens",
+ type=int,
+ default=DEFAULT_MAX_TOKENS,
+ help="Default max tokens for API responses when not specified in the request. "
+ "The upstream default (256) is too low for agentic use. "
+ "(default: %(default)s)",
+ )
+ parser.add_argument(
+ "--prompt-cache-ttl",
+ type=int,
+ default=DEFAULT_PROMPT_CACHE_TTL,
+ help="Seconds of idle time before a prompt cache entry is evicted. "
+ "Frees GPU memory from stale KV caches. 0 = no expiry. "
+ "(default: %(default)s)",
+ )
+ parser.add_argument(
+ "--verbose",
+ action="store_true",
+ default=False,
+ help="Enable verbose debug logging (prompt content, cache state, tool detection).",
+ )
parser.add_argument(
"--reload",
action="store_true",
@@ -1454,6 +2329,13 @@ def main():
os.environ["KV_QUANT_SCHEME"] = args.kv_quant_scheme
os.environ["MAX_KV_SIZE"] = str(args.max_kv_size)
os.environ["QUANTIZED_KV_START"] = str(args.quantized_kv_start)
+ os.environ["MAX_CONCURRENT_REQUESTS"] = str(args.max_concurrent_requests)
+ os.environ["MAX_CONTEXT_TOKENS"] = str(args.max_context_tokens)
+ os.environ["REQUEST_TIMEOUT"] = str(args.request_timeout)
+ os.environ["PROMPT_CACHE_TTL"] = str(args.prompt_cache_ttl)
+ os.environ["DEFAULT_MAX_TOKENS"] = str(args.default_max_tokens)
+ if args.verbose:
+ os.environ["VERBOSE"] = "1"
uvicorn.run(
"mlx_vlm.server:app",
diff --git a/mlx_vlm/tests/test_responses_api.py b/mlx_vlm/tests/test_responses_api.py
new file mode 100644
index 000000000..7ef21d588
--- /dev/null
+++ b/mlx_vlm/tests/test_responses_api.py
@@ -0,0 +1,1073 @@
+"""Tests for the OpenAI Responses API (/v1/responses) compliance.
+
+Covers:
+ A. Model validation (pure unit tests, no server/mlx needed)
+ B. Response store (pure unit tests)
+ C. Functional endpoint tests (TestClient, mocked model)
+ D. Streaming endpoint tests (TestClient, mocked model)
+ E. Prompt cache tests (prompt_cache_state passed to generation calls)
+"""
+
+import importlib.util
+from pathlib import Path
+from types import SimpleNamespace
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+# ---------------------------------------------------------------------------
+# Helpers: load modules without triggering mlx_vlm.__init__ (no mlx needed)
+# ---------------------------------------------------------------------------
+
+
+def _load_module(name: str, filename: str):
+    """Load a sibling module by file path, bypassing package __init__.
+
+    Args:
+        name: Name given to the loaded module object.
+        filename: File name, resolved against the parent of this file's
+            directory.
+
+    Returns:
+        The executed module object.
+
+    Raises:
+        ImportError: If no import spec or loader can be built for the path.
+    """
+    mod_path = Path(__file__).parent.parent / filename
+    spec = importlib.util.spec_from_file_location(name, str(mod_path))
+    # spec_from_file_location returns None when no loader matches the path;
+    # fail with a clear message instead of an AttributeError further down.
+    if spec is None or spec.loader is None:
+        raise ImportError(f"Cannot load module {name!r} from {mod_path}")
+    mod = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(mod)
+    return mod
+
+
+# Load the two Responses API modules directly from their files so that the
+# pure unit tests below do not have to import the mlx_vlm package itself.
+responses_models = _load_module("responses_models", "responses_models.py")
+responses_store = _load_module("responses_store", "responses_store.py")
+
+# Module-level aliases for the classes exercised by the tests.
+ResponsesRequest = responses_models.ResponsesRequest
+ResponseObject = responses_models.ResponseObject
+ResponseMessageItem = responses_models.ResponseMessageItem
+ResponseFunctionCallItem = responses_models.ResponseFunctionCallItem
+ContentPartOutputText = responses_models.ContentPartOutputText
+ResponseUsage = responses_models.ResponseUsage
+FlexibleBaseModel = responses_models.FlexibleBaseModel
+BaseStreamEvent = responses_models.BaseStreamEvent
+ResponseStore = responses_store.ResponseStore
+
+
+# =========================================================================
+# A. Model Validation Tests
+# =========================================================================
+
+
+class TestResponsesModels:
+    """Unit tests for the Pydantic models defined in responses_models.py."""
+
+    def test_responses_request_accepts_string_input(self):
+        """A plain string is accepted as ``input`` and stored unchanged."""
+        request = ResponsesRequest(input="Hello", model="test-model")
+        assert request.input == "Hello"
+
+    def test_responses_request_accepts_message_list(self):
+        """A list of message dicts is accepted as ``input``."""
+        request = ResponsesRequest(
+            input=[{"role": "user", "content": "hello"}],
+            model="test-model",
+        )
+        assert isinstance(request.input, list)
+        assert len(request.input) == 1
+
+    def test_responses_request_accepts_tools(self):
+        """A function tool definition is kept on the request."""
+        weather_tool = {
+            "type": "function",
+            "name": "get_weather",
+            "description": "Get the weather",
+            "parameters": {"type": "object", "properties": {}},
+        }
+        request = ResponsesRequest(input="hi", model="m", tools=[weather_tool])
+        assert request.tools is not None
+        assert len(request.tools) == 1
+
+    def test_responses_request_default_tool_choice(self):
+        """``tool_choice`` is "auto" when the caller does not set it."""
+        assert ResponsesRequest(input="hi", model="m").tool_choice == "auto"
+
+    def test_responses_request_generation_kwargs(self):
+        """``generation_kwargs()`` exposes the limit under ``max_tokens`` only."""
+        request = ResponsesRequest(input="hi", model="m", max_output_tokens=128)
+        gen_kwargs = request.generation_kwargs()
+        assert "max_tokens" in gen_kwargs
+        assert gen_kwargs["max_tokens"] == 128
+        assert "max_output_tokens" not in gen_kwargs
+
+    def test_response_object_output_text_computed(self):
+        """``output_text`` joins the text of every output_text part."""
+        parts = [ContentPartOutputText(text=piece) for piece in ("Hello ", "world!")]
+        usage = ResponseUsage(input_tokens=1, output_tokens=2, total_tokens=3)
+        response = ResponseObject(
+            created_at=0,
+            model="m",
+            output=[ResponseMessageItem(content=parts)],
+            usage=usage,
+        )
+        assert response.output_text == "Hello world!"
+
+    def test_response_object_output_text_empty_when_only_function_calls(self):
+        """``output_text`` is empty when the output holds only function calls."""
+        call_item = ResponseFunctionCallItem(name="fn", arguments='{"a":1}')
+        usage = ResponseUsage(input_tokens=1, output_tokens=2, total_tokens=3)
+        response = ResponseObject(
+            created_at=0,
+            model="m",
+            output=[call_item],
+            usage=usage,
+        )
+        assert response.output_text == ""
+
+    def test_function_call_item_auto_ids(self):
+        """Function call items get prefixed ids that differ per instance."""
+        first = ResponseFunctionCallItem(name="fn", arguments="{}")
+        second = ResponseFunctionCallItem(name="fn", arguments="{}")
+        assert first.id.startswith("fc_")
+        assert first.call_id.startswith("call_")
+        # IDs should be unique per instance
+        assert first.id != second.id
+
+    def test_function_call_item_schema(self):
+        """Name, arguments and type are exposed as given / as defaulted."""
+        call_item = ResponseFunctionCallItem(
+            name="get_weather", arguments='{"city":"NYC"}'
+        )
+        assert call_item.name == "get_weather"
+        assert call_item.arguments == '{"city":"NYC"}'
+        assert call_item.type == "function_call"
+
+    def test_content_part_output_text_defaults(self):
+        """A bare output_text part has the expected default field values."""
+        default_part = ContentPartOutputText()
+        assert default_part.type == "output_text"
+        assert default_part.text == ""
+        assert default_part.annotations == []
+
+    def test_streaming_event_sequence_number(self):
+        """``sequence_number`` is stored when given and defaults to 0."""
+        explicit = BaseStreamEvent(type="test.event", sequence_number=42)
+        implicit = BaseStreamEvent(type="test.event")
+        assert explicit.sequence_number == 42
+        assert implicit.sequence_number == 0
+
+    def test_flexible_base_model_accepts_unknown_fields(self):
+        """Unknown request fields are accepted and kept in ``model_extra``."""
+        # Construction must not raise for a field the model does not declare.
+        request = ResponsesRequest(
+            input="hi", model="m", some_unknown_field="surprise"
+        )
+        assert request.model_extra.get("some_unknown_field") == "surprise"
+
+
+# =========================================================================
+# B. Response Store Tests
+# =========================================================================
+
+
+class TestResponseStore:
+    """Unit tests for the LRU ResponseStore: save/get, eviction, replay, clear."""
+
+    def test_store_save_and_get(self):
+        """A saved entry is returned with its input and output intact."""
+        rstore = ResponseStore()
+        rstore.save("resp_1", "hello", [{"type": "message"}])
+        saved = rstore.get("resp_1")
+        assert saved is not None
+        assert saved["input"] == "hello"
+        assert saved["output"] == [{"type": "message"}]
+
+    def test_store_get_missing_returns_none(self):
+        """Looking up an unknown id returns None."""
+        assert ResponseStore().get("resp_nonexistent") is None
+
+    def test_store_lru_eviction(self):
+        """With maxsize=2, saving a third entry drops the oldest one."""
+        rstore = ResponseStore(maxsize=2)
+        # The third save (resp_c) should evict resp_a.
+        for suffix in ("a", "b", "c"):
+            rstore.save(f"resp_{suffix}", suffix, [])
+        assert rstore.get("resp_a") is None
+        assert rstore.get("resp_b") is not None
+        assert rstore.get("resp_c") is not None
+
+    def test_store_replay_string_input(self):
+        """A string input is replayed as a single user message."""
+        rstore = ResponseStore()
+        rstore.save("resp_1", "hello", [])
+        replayed = rstore.replay_input("resp_1")
+        assert replayed is not None
+        assert len(replayed) == 1
+        only_item = replayed[0]
+        assert only_item["role"] == "user"
+        assert only_item["content"] == "hello"
+
+    def test_store_replay_message_list_input(self):
+        """A message-list input is replayed item by item, in order."""
+        rstore = ResponseStore()
+        rstore.save(
+            "resp_1",
+            [
+                {"role": "user", "content": "What is 2+2?"},
+                {"role": "system", "content": "You are helpful."},
+            ],
+            [],
+        )
+        replayed = rstore.replay_input("resp_1")
+        assert replayed is not None
+        assert len(replayed) == 2
+        assert replayed[0]["role"] == "user"
+        assert replayed[1]["role"] == "system"
+
+    def test_store_replay_function_call_output(self):
+        """A stored function_call output item appears in the replayed input."""
+        stored_call = {
+            "type": "function_call",
+            "call_id": "call_123",
+            "name": "get_weather",
+            "arguments": '{"city":"NYC"}',
+        }
+        rstore = ResponseStore()
+        rstore.save("resp_1", "hello", [stored_call])
+        replayed = rstore.replay_input("resp_1")
+        assert replayed is not None
+        # The original user input comes first; pick out the function call.
+        replayed_calls = [
+            item for item in replayed if item.get("type") == "function_call"
+        ]
+        assert len(replayed_calls) == 1
+        assert replayed_calls[0]["name"] == "get_weather"
+        assert replayed_calls[0]["call_id"] == "call_123"
+
+    def test_store_replay_missing_returns_none(self):
+        """Replaying an unknown id returns None."""
+        assert ResponseStore().replay_input("resp_nope") is None
+
+    def test_store_clear(self):
+        """``clear()`` empties the store and forgets earlier entries."""
+        rstore = ResponseStore()
+        for response_id, text in (("resp_1", "a"), ("resp_2", "b")):
+            rstore.save(response_id, text, [])
+        assert len(rstore) == 2
+        rstore.clear()
+        assert len(rstore) == 0
+        assert rstore.get("resp_1") is None
+
+
+# =========================================================================
+# C. Functional Endpoint Tests (require mlx for server import)
+# =========================================================================
+
+# Guard: skip functional/streaming tests if mlx is unavailable, but let
+# the pure-unit tests above run on any platform.
+_has_mlx = importlib.util.find_spec("mlx") is not None
+
+# Import the server module and the FastAPI test client only when mlx can be
+# found, so that importing this test file does not fail without it.
+if _has_mlx:
+ from fastapi.testclient import TestClient # noqa: E402
+
+ import mlx_vlm.server as server # noqa: E402
+
+# Decorator applied to the server-backed test classes below.
+_skip_no_mlx = pytest.mark.skipif(not _has_mlx, reason="mlx not installed")
+
+
+# Shared mock objects (safe to create even without mlx).
+# Together they form the (model, processor, config) triple that the patched
+# server.get_cached_model returns in the endpoint tests.
+mock_model = MagicMock()
+mock_processor = MagicMock()
+mock_processor.tokenizer = MagicMock()
+# NOTE(review): presumably an empty chat template keeps the server from
+# detecting a tool-call parser for the mocked tokenizer; confirm in server.py.
+mock_processor.tokenizer.chat_template = ""
+mock_config = SimpleNamespace(model_type="test")
+
+
+def _mock_result(text="Hello world!", prompt_tokens=10, gen_tokens=5):
+    """Return a SimpleNamespace standing in for a generate() result.
+
+    ``total_tokens`` is the sum of ``prompt_tokens`` and ``gen_tokens``; the
+    throughput and memory fields carry fixed placeholder values.
+    """
+    fields = {
+        "text": text,
+        "prompt_tokens": prompt_tokens,
+        "generation_tokens": gen_tokens,
+        "total_tokens": prompt_tokens + gen_tokens,
+        "prompt_tps": 100.0,
+        "generation_tps": 50.0,
+        "peak_memory": 1.0,
+    }
+    return SimpleNamespace(**fields)
+
+
+@pytest.fixture
+def client():
+    """Yield a TestClient for the server app, closed when the test ends."""
+    with TestClient(server.app) as http_client:
+        yield http_client
+
+
+def _patch_model():
+    """Patch server.get_cached_model to return the shared mock triple."""
+    mocked_triple = (mock_model, mock_processor, mock_config)
+    return patch.object(server, "get_cached_model", return_value=mocked_triple)
+
+
+def _patch_template():
+    """Patch server.apply_chat_template to return a fixed prompt string."""
+    fixed_prompt = "prompt"
+    return patch.object(server, "apply_chat_template", return_value=fixed_prompt)
+
+
+def _patch_generate(result=None):
+    """Patch server.generate to return ``result`` (a default mock if None)."""
+    canned = _mock_result() if result is None else result
+    return patch.object(server, "generate", return_value=canned)
+
+
+@_skip_no_mlx
+class TestResponsesEndpoint:
+    """Functional tests for POST /responses (non-streaming, mocked model)."""
+
+    def test_basic_text_response(self, client):
+        """A plain string input yields a well-formed, completed response."""
+        with _patch_model(), _patch_template(), _patch_generate():
+            resp = client.post(
+                "/responses",
+                json={"model": "demo", "input": "Hello"},
+            )
+        assert resp.status_code == 200
+        data = resp.json()
+        assert data["object"] == "response"
+        assert data["status"] == "completed"
+        assert "id" in data
+        assert "output" in data
+        assert "usage" in data
+
+    def test_response_with_message_list(self, client):
+        """A message-list input is accepted and completes."""
+        with _patch_model(), _patch_template(), _patch_generate():
+            resp = client.post(
+                "/responses",
+                json={
+                    "model": "demo",
+                    "input": [{"role": "user", "content": "hello"}],
+                },
+            )
+        assert resp.status_code == 200
+        data = resp.json()
+        assert data["status"] == "completed"
+
+    def test_instructions_field_echoed(self, client):
+        """The request's ``instructions`` value is echoed in the response."""
+        with _patch_model(), _patch_template(), _patch_generate():
+            resp = client.post(
+                "/responses",
+                json={
+                    "model": "demo",
+                    "input": [
+                        {"role": "system", "content": "Be brief."},
+                        {"role": "user", "content": "hi"},
+                    ],
+                    "instructions": "Be brief.",
+                },
+            )
+        assert resp.status_code == 200
+        data = resp.json()
+        # The handler copies request.instructions into the response object,
+        # so check the value itself rather than mere presence.
+        assert data.get("instructions") == "Be brief."
+
+    def test_tools_field_echoed(self, client):
+        """A request carrying a function tool definition is accepted."""
+        tools = [
+            {
+                "type": "function",
+                "name": "get_weather",
+                "description": "Get weather",
+                "parameters": {"type": "object", "properties": {}},
+            }
+        ]
+        with _patch_model(), _patch_template(), _patch_generate():
+            resp = client.post(
+                "/responses",
+                json={"model": "demo", "input": "hi", "tools": tools},
+            )
+        assert resp.status_code == 200
+
+    def test_previous_response_id_not_found(self, client):
+        """Referencing a non-existent previous_response_id should return an error."""
+        with _patch_model(), _patch_template(), _patch_generate():
+            resp = client.post(
+                "/responses",
+                json={
+                    "model": "demo",
+                    "input": "follow-up",
+                    "previous_response_id": "resp_nonexistent999",
+                },
+            )
+        # The server should either 404 or 200 (ignoring unknown ID).
+        # We just verify it doesn't crash with a 500.
+        assert resp.status_code in (200, 404)
+
+    def test_developer_role_mapped_to_system(self, client):
+        """developer role should be accepted (mapped to system internally)."""
+        with _patch_model(), _patch_template(), _patch_generate():
+            resp = client.post(
+                "/responses",
+                json={
+                    "model": "demo",
+                    "input": [
+                        {"role": "developer", "content": "You are helpful."},
+                        {"role": "user", "content": "hi"},
+                    ],
+                },
+            )
+        # Should not crash; accept 200 or 422 if server rejects developer role
+        assert resp.status_code in (200, 422)
+
+    def test_text_type_alias(self, client):
+        """'text' type should be accepted alongside 'input_text'."""
+        # Exercise both spellings; sending only "input_text" would not test
+        # the alias at all.
+        for part_type in ("input_text", "text"):
+            with _patch_model(), _patch_template(), _patch_generate():
+                resp = client.post(
+                    "/responses",
+                    json={
+                        "model": "demo",
+                        "input": [
+                            {
+                                "role": "user",
+                                "content": [{"type": part_type, "text": "hi"}],
+                            }
+                        ],
+                    },
+                )
+            assert resp.status_code == 200, part_type
+
+    def test_function_call_output_input(self, client):
+        """function_call_output items in input should not crash the server."""
+        with _patch_model(), _patch_template(), _patch_generate():
+            resp = client.post(
+                "/responses",
+                json={
+                    "model": "demo",
+                    "input": [
+                        {"role": "user", "content": "call a tool"},
+                        {
+                            "type": "function_call_output",
+                            "call_id": "call_abc",
+                            "output": '{"result": 42}',
+                        },
+                    ],
+                },
+            )
+        # May fail if server doesn't handle function_call_output yet;
+        # accept anything except unhandled 500
+        assert resp.status_code in (200, 400, 422)
+
+    def test_max_output_tokens_incomplete(self, client):
+        """Hitting max_output_tokens marks the response as 'incomplete'."""
+        # The mocked result reports 5 generated tokens and the request caps
+        # output at 5; the handler treats generation_tokens >= limit as
+        # truncation.
+        result = _mock_result(text="truncated...", gen_tokens=5)
+        with _patch_model(), _patch_template(), _patch_generate(result):
+            resp = client.post(
+                "/responses",
+                json={
+                    "model": "demo",
+                    "input": "Write a very long essay",
+                    "max_output_tokens": 5,
+                },
+            )
+        assert resp.status_code == 200
+        data = resp.json()
+        assert data["status"] == "incomplete"
+        assert data["incomplete_details"]["reason"] == "max_output_tokens"
+
+
+@_skip_no_mlx
+class TestResponsesStreaming:
+    """Streaming SSE tests for POST /responses with stream=true."""
+
+    def _stream_events(self, client, payload):
+        """POST ``payload`` with stream_generate mocked; return the response.
+
+        The mocked generator yields two text chunks ("Hello", " world").
+        """
+        with _patch_model(), _patch_template():
+            # Mock stream_generate to yield chunks
+            chunks = [
+                SimpleNamespace(text="Hello", prompt_tokens=10, generation_tokens=1),
+                SimpleNamespace(text=" world", prompt_tokens=10, generation_tokens=2),
+            ]
+
+            def mock_stream_gen(**kwargs):
+                return iter(chunks)
+
+            with patch.object(server, "stream_generate", side_effect=mock_stream_gen):
+                resp = client.post("/responses", json=payload)
+        return resp
+
+    def test_streaming_sse_events(self, client):
+        """The stream contains the created, text-delta and completed events."""
+        payload = {"model": "demo", "input": "Hello", "stream": True}
+        resp = self._stream_events(client, payload)
+        assert resp.status_code == 200
+        body = resp.text
+        # Should contain key event types
+        assert "event: response.created" in body
+        assert "event: response.output_text.delta" in body
+        assert "event: response.completed" in body
+
+    def test_streaming_done_sentinel(self, client):
+        """response.completed is the last named event, then the [DONE] sentinel."""
+        payload = {"model": "demo", "input": "Hello", "stream": True}
+        resp = self._stream_events(client, payload)
+        assert resp.status_code == 200
+        body = resp.text
+        # The last meaningful event should be response.completed
+        event_lines = [
+            line for line in body.strip().split("\n") if line.startswith("event:")
+        ]
+        assert event_lines[-1] == "event: response.completed"
+        # The generator closes the stream with a bare "data: [DONE]" line.
+        assert body.strip().endswith("data: [DONE]")
+
+
+# =========================================================================
+# E. Prompt Cache Tests
+# =========================================================================
+
+
+@_skip_no_mlx
+class TestPromptCache:
+    """Verify prompt_cache_state is wired into all generation entry points."""
+
+    def test_responses_non_streaming_passes_cache_state(self, client):
+        """Non-streaming /responses should pass prompt_cache_state to generate."""
+        captured = {}
+
+        def capture_generate(**kwargs):
+            captured["prompt_cache_state"] = kwargs.get("prompt_cache_state")
+            return _mock_result()
+
+        with (
+            _patch_model(),
+            _patch_template(),
+            patch.object(server, "generate", side_effect=capture_generate),
+        ):
+            resp = client.post("/responses", json={"model": "demo", "input": "hi"})
+        assert resp.status_code == 200
+        assert captured.get("prompt_cache_state") is not None
+        assert hasattr(captured["prompt_cache_state"], "find_prefix_length")
+
+    def test_responses_streaming_passes_cache_state(self, client):
+        """Streaming /responses should pass prompt_cache_state to stream_generate."""
+        captured = {}
+
+        def capture_stream(**kwargs):
+            captured["prompt_cache_state"] = kwargs.get("prompt_cache_state")
+            return iter(
+                [
+                    SimpleNamespace(text="Hi", prompt_tokens=5, generation_tokens=1),
+                ]
+            )
+
+        with (
+            _patch_model(),
+            _patch_template(),
+            patch.object(server, "stream_generate", side_effect=capture_stream),
+        ):
+            resp = client.post(
+                "/responses",
+                json={"model": "demo", "input": "hi", "stream": True},
+            )
+        assert resp.status_code == 200
+        assert captured.get("prompt_cache_state") is not None
+
+    def test_chat_completions_non_streaming_passes_cache_state(self, client):
+        """Non-streaming /chat/completions should pass prompt_cache_state."""
+        captured = {}
+
+        def capture_generate(**kwargs):
+            captured["prompt_cache_state"] = kwargs.get("prompt_cache_state")
+            return _mock_result()
+
+        with (
+            _patch_model(),
+            _patch_template(),
+            patch.object(server, "generate", side_effect=capture_generate),
+        ):
+            resp = client.post(
+                "/chat/completions",
+                json={
+                    "model": "demo",
+                    "messages": [{"role": "user", "content": "hello"}],
+                },
+            )
+        assert resp.status_code == 200
+        assert captured.get("prompt_cache_state") is not None
+
+    def test_chat_completions_streaming_passes_cache_state(self, client):
+        """Streaming /chat/completions should pass prompt_cache_state."""
+        captured = {}
+
+        def capture_stream(**kwargs):
+            captured["prompt_cache_state"] = kwargs.get("prompt_cache_state")
+            return iter(
+                [
+                    SimpleNamespace(text="Hi", prompt_tokens=5, generation_tokens=1),
+                ]
+            )
+
+        with (
+            _patch_model(),
+            _patch_template(),
+            patch.object(server, "stream_generate", side_effect=capture_stream),
+        ):
+            resp = client.post(
+                "/chat/completions",
+                json={
+                    "model": "demo",
+                    "messages": [{"role": "user", "content": "hello"}],
+                    "stream": True,
+                },
+            )
+        assert resp.status_code == 200
+        assert captured.get("prompt_cache_state") is not None
+
+    def test_cache_state_persists_across_requests(self, client):
+        """The same PromptCacheState should be reused for the same model."""
+        states = []
+
+        def capture_generate(**kwargs):
+            states.append(kwargs.get("prompt_cache_state"))
+            return _mock_result()
+
+        with (
+            _patch_model(),
+            _patch_template(),
+            patch.object(server, "generate", side_effect=capture_generate),
+        ):
+            client.post("/responses", json={"model": "demo", "input": "first"})
+            client.post("/responses", json={"model": "demo", "input": "second"})
+
+        assert len(states) == 2
+        assert states[0] is states[1], "Same model should reuse the same cache state"
+
+    def test_cache_state_isolated_per_model(self, client):
+        """Different models should get different PromptCacheState instances."""
+
+        def fake_generate(**kwargs):
+            return _mock_result()
+
+        # The states are read back from the server's store rather than from
+        # the kwargs handed to generate.
+        with (
+            _patch_model(),
+            _patch_template(),
+            patch.object(server, "generate", side_effect=fake_generate),
+        ):
+            client.post("/responses", json={"model": "model-a", "input": "hi"})
+            state_a = server.get_prompt_cache_state("model-a")
+            client.post("/responses", json={"model": "model-b", "input": "hi"})
+            state_b = server.get_prompt_cache_state("model-b")
+
+        assert (
+            state_a is not state_b
+        ), "Different models must have separate cache states"
+
+    def test_cache_state_has_correct_interface(self, client):
+        """PromptCacheState should expose find_prefix_length and update methods."""
+        state = server.get_prompt_cache_state("test-model")
+        assert hasattr(state, "find_prefix_length")
+        assert hasattr(state, "update")
+        assert hasattr(state, "cache")
+        assert hasattr(state, "token_ids")
+        # Initially empty
+        assert state.cache is None
+        assert state.token_ids is None
+        assert state.find_prefix_length([1, 2, 3]) == 0
+
+    def test_cache_state_cleared_on_unload(self, client):
+        """Unloading a model should clear all prompt cache states."""
+        server._prompt_cache_states["some-model"] = server.PromptCacheState()
+        assert "some-model" in server._prompt_cache_states
+        # Simulate unload against a stand-in model_cache.
+        with patch.object(server, "model_cache", {"model_path": "x"}):
+            server.unload_model_sync()
+        assert len(server._prompt_cache_states) == 0
+
+    def test_cache_state_prefix_matching(self):
+        """PromptCacheState.find_prefix_length should find common prefix."""
+        state = server.PromptCacheState()
+        state.token_ids = [10, 20, 30, 40, 50]
+        assert state.find_prefix_length([10, 20, 30, 40, 50]) == 5
+        assert state.find_prefix_length([10, 20, 30, 99, 50]) == 3
+        assert state.find_prefix_length([99, 20, 30]) == 0
+        assert state.find_prefix_length([10, 20]) == 2
+        assert state.find_prefix_length([]) == 0
+
+
+# =========================================================================
+# F. Concurrency Guard Tests
+# =========================================================================
+
+
+@_skip_no_mlx
+class TestConcurrencyGuard:
+    """Verify concurrency guard serializes Metal GPU access."""
+
+    def test_semaphore_exists_and_is_semaphore(self):
+        """get_generation_semaphore should return an asyncio.Semaphore."""
+        import asyncio
+
+        sem = server.get_generation_semaphore()
+        assert isinstance(sem, asyncio.Semaphore)
+
+    def test_semaphore_default_value_is_one(self):
+        """Default semaphore should allow exactly 1 concurrent request."""
+        import os
+
+        # patch.dict snapshots os.environ and restores it on exit, so a
+        # MAX_CONCURRENT_REQUESTS set by the caller survives this test.
+        with patch.dict(os.environ):
+            os.environ.pop("MAX_CONCURRENT_REQUESTS", None)
+            # Reset to force re-creation with the default.
+            server._generation_semaphore = None
+            try:
+                sem = server.get_generation_semaphore()
+                assert sem._value == 1
+            finally:
+                # Reset for other tests, even when the assertion fails.
+                server._generation_semaphore = None
+
+    def test_semaphore_respects_env_var(self):
+        """MAX_CONCURRENT_REQUESTS env var should configure semaphore value."""
+        import os
+
+        # The previous value (or absence) of the variable is restored on exit
+        # instead of being overwritten with a hard-coded "1".
+        with patch.dict(os.environ, {"MAX_CONCURRENT_REQUESTS": "3"}):
+            server._generation_semaphore = None
+            try:
+                sem = server.get_generation_semaphore()
+                assert sem._value == 3
+            finally:
+                server._generation_semaphore = None
+
+    def test_semaphore_singleton(self):
+        """Repeated calls should return the same semaphore instance."""
+        server._generation_semaphore = None
+        try:
+            sem1 = server.get_generation_semaphore()
+            sem2 = server.get_generation_semaphore()
+            assert sem1 is sem2
+        finally:
+            server._generation_semaphore = None
+
+    def test_responses_non_streaming_acquires_semaphore(self, client):
+        """Non-streaming /responses should acquire and release the semaphore."""
+
+        acquired = []
+        released = []
+        real_sem = server.get_generation_semaphore()
+
+        original_acquire = real_sem.acquire
+        original_release = real_sem.release
+
+        # The wrappers record each call and then delegate to the real
+        # methods, so the semaphore keeps working while being observed.
+        async def mock_acquire():
+            acquired.append(True)
+            return await original_acquire()
+
+        def mock_release():
+            released.append(True)
+            return original_release()
+
+        with (
+            _patch_model(),
+            _patch_template(),
+            _patch_generate(),
+            patch.object(real_sem, "acquire", side_effect=mock_acquire),
+            patch.object(real_sem, "release", side_effect=mock_release),
+        ):
+            resp = client.post("/responses", json={"model": "demo", "input": "hi"})
+
+        assert resp.status_code == 200
+        assert len(acquired) >= 1, "Semaphore should be acquired"
+        assert len(released) >= 1, "Semaphore should be released"
+
+    def test_concurrent_requests_both_succeed(self, client):
+        """Two sequential requests should both succeed (semaphore serializes)."""
+        with _patch_model(), _patch_template(), _patch_generate():
+            r1 = client.post("/responses", json={"model": "demo", "input": "first"})
+            r2 = client.post("/responses", json={"model": "demo", "input": "second"})
+        assert r1.status_code == 200
+        assert r2.status_code == 200
+
+
+# =========================================================================
+# G. finish_reason Tests
+# =========================================================================
+
+
+@_skip_no_mlx
+class TestFinishReason:
+    """Verify finish_reason is set correctly based on tool call detection."""
+
+    def test_chat_completions_finish_reason_stop_no_tools(self, client):
+        """finish_reason='stop' when no tools provided."""
+        with _patch_model(), _patch_template(), _patch_generate():
+            resp = client.post(
+                "/chat/completions",
+                json={"model": "demo", "messages": [{"role": "user", "content": "hi"}]},
+            )
+        assert resp.status_code == 200
+        assert resp.json()["choices"][0]["finish_reason"] == "stop"
+
+    def test_chat_completions_finish_reason_tool_calls(self, client):
+        """finish_reason='tool_calls' when tool calls detected."""
+        fake_calls = {
+            "calls": [
+                {
+                    "type": "function",
+                    "id": "c1",
+                    "function": {"name": "search", "arguments": "{}"},
+                }
+            ],
+            "remaining_text": "",
+        }
+        with (
+            _patch_model(),
+            _patch_template(),
+            _patch_generate(),
+            patch.object(server, "_infer_tool_parser", return_value="qwen3_coder"),
+            patch.object(server, "load_tool_module", return_value=SimpleNamespace()),
+            patch.object(server, "process_tool_calls", return_value=fake_calls),
+        ):
+            resp = client.post(
+                "/chat/completions",
+                json={
+                    "model": "demo",
+                    "messages": [{"role": "user", "content": "search"}],
+                    "tools": [
+                        {
+                            "type": "function",
+                            "function": {"name": "search", "parameters": {}},
+                        }
+                    ],
+                },
+            )
+        assert resp.status_code == 200
+        assert resp.json()["choices"][0]["finish_reason"] == "tool_calls"
+
+    def test_chat_completions_finish_reason_stop_tools_no_calls(self, client):
+        """finish_reason='stop' when tools defined but model doesn't call any."""
+        no_calls = {"calls": [], "remaining_text": "Just text, no tools."}
+        with (
+            _patch_model(),
+            _patch_template(),
+            _patch_generate(),
+            patch.object(server, "_infer_tool_parser", return_value="qwen3_coder"),
+            patch.object(server, "load_tool_module", return_value=SimpleNamespace()),
+            patch.object(server, "process_tool_calls", return_value=no_calls),
+        ):
+            resp = client.post(
+                "/chat/completions",
+                json={
+                    "model": "demo",
+                    "messages": [{"role": "user", "content": "hello"}],
+                    "tools": [
+                        {
+                            "type": "function",
+                            "function": {"name": "search", "parameters": {}},
+                        }
+                    ],
+                },
+            )
+        assert resp.status_code == 200
+        assert resp.json()["choices"][0]["finish_reason"] == "stop"
+
+    def test_chat_completions_streaming_finish_reason_tool_calls(self, client):
+        """Streaming finish_reason should be 'tool_calls' when tools detected."""
+        import json as json_mod
+
+        fake_calls = {
+            "calls": [
+                {
+                    "type": "function",
+                    "id": "c1",
+                    "function": {"name": "search", "arguments": "{}"},
+                }
+            ],
+            "remaining_text": "",
+        }
+        chunks = [
+            SimpleNamespace(
+                text="calling",
+                prompt_tokens=10,
+                generation_tokens=1,
+                prompt_tps=100.0,
+                generation_tps=50.0,
+                peak_memory=1.0,
+            ),
+        ]
+
+        def mock_stream(**kwargs):
+            return iter(chunks)
+
+        with (
+            _patch_model(),
+            _patch_template(),
+            patch.object(server, "stream_generate", side_effect=mock_stream),
+            patch.object(server, "_infer_tool_parser", return_value="qwen3_coder"),
+            patch.object(server, "load_tool_module", return_value=SimpleNamespace()),
+            patch.object(server, "process_tool_calls", return_value=fake_calls),
+        ):
+            resp = client.post(
+                "/chat/completions",
+                json={
+                    "model": "demo",
+                    "messages": [{"role": "user", "content": "search"}],
+                    "tools": [
+                        {
+                            "type": "function",
+                            "function": {"name": "search", "parameters": {}},
+                        }
+                    ],
+                    "stream": True,
+                },
+            )
+        assert resp.status_code == 200
+
+        data_lines = [
+            line
+            for line in resp.text.strip().split("\n")
+            if line.startswith("data:") and "[DONE]" not in line
+        ]
+        # Strip only the leading SSE field name. str.replace would also
+        # remove any "data: " occurring inside the JSON payload itself.
+        last_payload = data_lines[-1][len("data:") :].strip()
+        last_data = json_mod.loads(last_payload)
+        assert last_data["choices"][0]["finish_reason"] == "tool_calls"
+
+    def test_responses_status_completed_no_tools(self, client):
+        """Responses endpoint status='completed' for normal text."""
+        with _patch_model(), _patch_template(), _patch_generate():
+            resp = client.post("/responses", json={"model": "demo", "input": "hi"})
+        assert resp.status_code == 200
+        assert resp.json()["status"] == "completed"
+
+
+# =========================================================================
+# H. JSON Mode Tests
+# =========================================================================
+
+
+@_skip_no_mlx
+class TestJsonMode:
+    """Handling of the response_format request parameter."""
+
+    def test_resolve_response_format_json_adds_instruction(self):
+        """json_object mode yields a leading system message that mentions JSON."""
+        resolved = server.resolve_response_format(
+            [{"role": "user", "content": "hi"}], {"type": "json_object"}
+        )
+        first = resolved[0]
+        assert first["role"] == "system"
+        assert "json" in first["content"].lower()
+
+    def test_resolve_response_format_text_no_change(self):
+        """text mode leaves the message count at one."""
+        resolved = server.resolve_response_format(
+            [{"role": "user", "content": "hi"}], {"type": "text"}
+        )
+        assert len(resolved) == 1
+
+    def test_resolve_response_format_none_no_change(self):
+        """A missing response_format leaves the message count at one."""
+        resolved = server.resolve_response_format(
+            [{"role": "user", "content": "hi"}], None
+        )
+        assert len(resolved) == 1
+
+    def test_responses_json_mode_accepted(self, client):
+        """/responses accepts a json_object response_format."""
+        request_body = {
+            "model": "demo",
+            "input": "Give me JSON",
+            "response_format": {"type": "json_object"},
+        }
+        with _patch_model(), _patch_template(), _patch_generate():
+            resp = client.post("/responses", json=request_body)
+        assert resp.status_code == 200
+
+
+# =========================================================================
+# I. Context Tracking Tests
+# =========================================================================
+
+
+@_skip_no_mlx
+class TestContextTracking:
+    """Verify context window tracking and OOM prevention."""
+
+    def test_check_context_length_within_limit(self):
+        """A 10-token prompt under a 100-token limit does not raise."""
+        fake_proc = SimpleNamespace(
+            tokenizer=SimpleNamespace(
+                encode=lambda s, add_special_tokens=False: list(range(10))
+            ),
+        )
+        server.check_context_length("short", fake_proc, 100)
+
+    def test_check_context_length_exceeds_limit(self):
+        """A 200-token prompt over a 100-token limit raises HTTP 400."""
+        from fastapi import HTTPException as _Exc
+
+        fake_proc = SimpleNamespace(
+            tokenizer=SimpleNamespace(
+                encode=lambda s, add_special_tokens=False: list(range(200))
+            ),
+        )
+        with pytest.raises(_Exc) as exc_info:
+            server.check_context_length("long", fake_proc, 100)
+        # Must sit outside the `with`: code after the raising call inside the
+        # block would never run.
+        assert exc_info.value.status_code == 400
+
+    def test_check_context_length_zero_unlimited(self):
+        """A limit of 0 does not raise, even with no processor."""
+        server.check_context_length("anything", None, 0)
+
+    def test_get_max_context_tokens_default(self):
+        """Without MAX_CONTEXT_TOKENS set, the limit is 0."""
+        import os
+
+        # patch.dict restores os.environ on exit, so a value set by the
+        # caller's environment is not lost.
+        with patch.dict(os.environ):
+            os.environ.pop("MAX_CONTEXT_TOKENS", None)
+            assert server.get_max_context_tokens() == 0
+
+    def test_get_max_context_tokens_from_env(self):
+        """MAX_CONTEXT_TOKENS is read from the environment as an int."""
+        import os
+
+        # The override is undone even if the assertion fails.
+        with patch.dict(os.environ, {"MAX_CONTEXT_TOKENS": "16384"}):
+            assert server.get_max_context_tokens() == 16384
+
+
+# =========================================================================
+# J. Request Cancellation Tests
+# =========================================================================
+
+
+@_skip_no_mlx
+class TestRequestCancellation:
+    """Verify request timeout and cancellation support."""
+
+    def test_get_request_timeout_default(self):
+        """Without REQUEST_TIMEOUT set, the timeout is 300."""
+        import os
+
+        # patch.dict restores os.environ on exit, so a value set by the
+        # caller's environment is not lost.
+        with patch.dict(os.environ):
+            os.environ.pop("REQUEST_TIMEOUT", None)
+            assert server.get_request_timeout() == 300
+
+    def test_get_request_timeout_from_env(self):
+        """REQUEST_TIMEOUT is read from the environment."""
+        import os
+
+        # The override is undone even if the assertion fails.
+        with patch.dict(os.environ, {"REQUEST_TIMEOUT": "60"}):
+            assert server.get_request_timeout() == 60
+
+    def test_streaming_cleanup_on_normal_completion(self, client):
+        """Streaming should complete normally and clean up."""
+        chunks = [
+            SimpleNamespace(
+                text="Hi",
+                prompt_tokens=5,
+                generation_tokens=1,
+                prompt_tps=100.0,
+                generation_tps=50.0,
+                peak_memory=1.0,
+            ),
+        ]
+        with (
+            _patch_model(),
+            _patch_template(),
+            patch.object(server, "stream_generate", return_value=iter(chunks)),
+        ):
+            resp = client.post(
+                "/responses",
+                json={"model": "demo", "input": "hi", "stream": True},
+            )
+        assert resp.status_code == 200
+        assert "response.completed" in resp.text
+
+
+# =========================================================================
+# K. Prompt Cache Key Routing Tests
+# =========================================================================
+
+
+@_skip_no_mlx
+class TestPromptCacheKeyRouting:
+    """prompt_cache_key selects which PromptCacheState a request uses."""
+
+    def test_same_cache_key_same_state(self):
+        """The same model and cache_key resolve to one shared state object."""
+        first = server.get_prompt_cache_state("model-a", "session-1")
+        second = server.get_prompt_cache_state("model-a", "session-1")
+        assert first is second
+
+    def test_different_cache_key_different_state(self):
+        """Two cache_keys under one model resolve to distinct state objects."""
+        first = server.get_prompt_cache_state("model-a", "session-1")
+        second = server.get_prompt_cache_state("model-a", "session-2")
+        assert first is not second
+
+    def test_no_cache_key_falls_back_to_model(self):
+        """Omitting cache_key and passing None resolve to the same state."""
+        implicit = server.get_prompt_cache_state("model-b")
+        explicit_none = server.get_prompt_cache_state("model-b", None)
+        assert implicit is explicit_none
+
+    def test_cache_key_passed_from_request(self, client):
+        """A request's prompt_cache_key picks the state handed to generate."""
+        seen = {}
+
+        def capture_gen(**kwargs):
+            seen["state"] = kwargs.get("prompt_cache_state")
+            return _mock_result()
+
+        request_body = {
+            "model": "demo",
+            "input": "hi",
+            "prompt_cache_key": "my-session",
+        }
+        with (
+            _patch_model(),
+            _patch_template(),
+            patch.object(server, "generate", side_effect=capture_gen),
+        ):
+            client.post("/responses", json=request_body)
+        expected = server.get_prompt_cache_state("demo", "my-session")
+        assert seen["state"] is expected
diff --git a/mlx_vlm/tests/test_server.py b/mlx_vlm/tests/test_server.py
index 270a82d77..0ed3c20a5 100644
--- a/mlx_vlm/tests/test_server.py
+++ b/mlx_vlm/tests/test_server.py
@@ -130,3 +130,127 @@ def test_chat_completions_endpoint_forwards_explicit_sampling_args(client):
assert mock_generate.call_args.kwargs["repetition_penalty"] == 1.15
assert mock_generate.call_args.kwargs["logit_bias"] == {12: -1.5}
assert mock_generate.call_args.kwargs["resize_shape"] == (512, 512)
+
+
+# ---------------------------------------------------------------------------
+# Stop sequences tests
+# ---------------------------------------------------------------------------
+
+
+def test_chat_completions_stop_passed_as_eos_tokens(client):
+    """The request's stop list reaches generate() as eos_tokens strings."""
+    fake_processor = SimpleNamespace(tokenizer=SimpleNamespace(chat_template=""))
+    loaded = (SimpleNamespace(), fake_processor, SimpleNamespace(model_type="test"))
+    fake_result = SimpleNamespace(
+        text="Hello",
+        prompt_tokens=5,
+        generation_tokens=1,
+        total_tokens=6,
+        prompt_tps=100.0,
+        generation_tps=50.0,
+        peak_memory=1.0,
+    )
+    request_body = {
+        "model": "demo",
+        "messages": [{"role": "user", "content": "hello"}],
+        "stop": ["\n\n", ""],
+    }
+
+    with (
+        patch.object(server, "get_cached_model", return_value=loaded),
+        patch.object(server, "apply_chat_template", return_value="prompt"),
+        patch.object(server, "generate", return_value=fake_result) as mock_gen,
+    ):
+        resp = client.post("/chat/completions", json=request_body)
+    assert resp.status_code == 200
+    forwarded = mock_gen.call_args.kwargs
+    assert "eos_tokens" in forwarded
+    assert forwarded["eos_tokens"] == ["\n\n", ""]
+
+
+def test_chat_completions_no_stop_no_eos_tokens(client):
+    """With no stop parameter, generate() receives no eos_tokens kwarg."""
+    fake_processor = SimpleNamespace(tokenizer=SimpleNamespace(chat_template=""))
+    loaded = (SimpleNamespace(), fake_processor, SimpleNamespace(model_type="test"))
+    fake_result = SimpleNamespace(
+        text="Hi",
+        prompt_tokens=5,
+        generation_tokens=1,
+        total_tokens=6,
+        prompt_tps=100.0,
+        generation_tps=50.0,
+        peak_memory=1.0,
+    )
+    request_body = {"model": "demo", "messages": [{"role": "user", "content": "hi"}]}
+
+    with (
+        patch.object(server, "get_cached_model", return_value=loaded),
+        patch.object(server, "apply_chat_template", return_value="prompt"),
+        patch.object(server, "generate", return_value=fake_result) as mock_gen,
+    ):
+        resp = client.post("/chat/completions", json=request_body)
+    assert resp.status_code == 200
+    assert "eos_tokens" not in mock_gen.call_args.kwargs
+
+
+def test_responses_stop_passed_as_eos_tokens(client):
+    """A single stop string on /responses reaches generate() as a one-item list."""
+    fake_processor = SimpleNamespace(tokenizer=SimpleNamespace(chat_template=""))
+    loaded = (SimpleNamespace(), fake_processor, SimpleNamespace(model_type="test"))
+    fake_result = SimpleNamespace(
+        text="Hello",
+        prompt_tokens=5,
+        generation_tokens=1,
+        total_tokens=6,
+        prompt_tps=100.0,
+        generation_tps=50.0,
+        peak_memory=1.0,
+    )
+    request_body = {"model": "demo", "input": "hi", "stop": "STOP"}
+
+    with (
+        patch.object(server, "get_cached_model", return_value=loaded),
+        patch.object(server, "apply_chat_template", return_value="prompt"),
+        patch.object(server, "generate", return_value=fake_result) as mock_gen,
+    ):
+        resp = client.post("/responses", json=request_body)
+    assert resp.status_code == 200
+    forwarded = mock_gen.call_args.kwargs
+    assert "eos_tokens" in forwarded
+    assert forwarded["eos_tokens"] == ["STOP"]
+
+
+def test_resolve_stop_sequences_single_string():
+    """A bare string is wrapped into a one-element list."""
+    assert server.resolve_stop_sequences("hello") == ["hello"]
+
+
+def test_resolve_stop_sequences_list():
+    """A short list comes back equal to the input."""
+    sequences = server.resolve_stop_sequences(["a", "b"])
+    assert sequences == ["a", "b"]
+
+
+def test_resolve_stop_sequences_none():
+    """None in, None out."""
+    outcome = server.resolve_stop_sequences(None)
+    assert outcome is None
+
+
+def test_resolve_stop_sequences_limits_to_four():
+    """Six input sequences are cut down to four."""
+    six = ["a", "b", "c", "d", "e", "f"]
+    assert len(server.resolve_stop_sequences(six)) == 4
diff --git a/scripts/benchmark_models.py b/scripts/benchmark_models.py
new file mode 100644
index 000000000..50870c829
--- /dev/null
+++ b/scripts/benchmark_models.py
@@ -0,0 +1,211 @@
+#!/usr/bin/env python3
+"""Benchmark models for agentic use on bastion.
+
+Usage: python3 scripts/benchmark_models.py [--models model1,model2,...]
+
+Tests: generation speed, prefill speed, tool calling, code quality,
+and instruction following.
+"""
+
+import json
+import sys
+import time
+
+import requests
+
+# Base URL of the server under test; api_call() posts to f"{BASE}/v1/responses".
+# NOTE(review): hard-coded address — presumably a private host; confirm before reuse.
+BASE = "http://100.106.192.127:8080"
+
+# Models benchmarked when --models is not given on the command line.
+MODELS = [
+ "mlx-community/Qwen3.5-35B-A3B-4bit",
+ "inferencerlabs/Qwen3.5-35B-A3B-MLX-5.5bit",
+ "unsloth/gemma-4-26b-a4b-it-UD-MLX-4bit",
+]
+
+# Single function-tool schema sent with the request in test_tool_calling().
+TOOLS = [
+ {
+ "type": "function",
+ "function": {
+ "name": "web_search",
+ "description": "Search the web",
+ "parameters": {
+ "type": "object",
+ "properties": {"query": {"type": "string"}},
+ "required": ["query"],
+ },
+ },
+ }
+]
+
+
+def api_call(model, input_text, max_tokens=50, tools=None, temperature=0, timeout=300):
+    """POST one request to ``{BASE}/v1/responses`` and time it.
+
+    Args:
+        model: Model name placed in the request body.
+        input_text: Prompt sent as the ``input`` field.
+        max_tokens: Sent as ``max_output_tokens``.
+        tools: Optional tool schema list; omitted from the body when falsy.
+        temperature: Sampling temperature sent with the request.
+        timeout: Timeout in seconds passed to ``requests.post``.
+
+    Returns:
+        A ``(data, elapsed)`` tuple. ``data`` is the decoded JSON body on
+        HTTP 200. Otherwise it is a dict with an ``"error"`` key: the status
+        code plus the first 100 characters of the body for a non-200 reply,
+        or the exception text for a transport or JSON-decoding failure.
+        ``elapsed`` is the wall time of the call in seconds.
+    """
+    payload = {
+        "model": model,
+        "input": input_text,
+        "max_output_tokens": max_tokens,
+        "temperature": temperature,
+    }
+    if tools:
+        payload["tools"] = tools
+
+    # perf_counter is monotonic, so the measured interval is not affected
+    # by system clock adjustments the way time.time() is.
+    started = time.perf_counter()
+    try:
+        r = requests.post(f"{BASE}/v1/responses", json=payload, timeout=timeout)
+        elapsed = time.perf_counter() - started
+        if r.status_code != 200:
+            return {"error": r.status_code, "detail": r.text[:100]}, elapsed
+        return r.json(), elapsed
+    except (requests.RequestException, ValueError) as e:
+        # Transport failures and undecodable bodies are reported as results
+        # so one bad call does not abort the whole benchmark run; anything
+        # else is a programming error and is allowed to propagate.
+        return {"error": str(e)}, time.perf_counter() - started
+
+
+def test_generation_speed(model):
+    """Return (output tokens per second, input token count) for a short prompt.
+
+    The rate divides output tokens by the whole request time. Returns
+    (None, None) when the API call reports an error.
+    """
+    prompt = "Write a detailed story about a wizard. " * 5
+    response, seconds = api_call(model, prompt, max_tokens=200, temperature=0.7)
+    if "error" in response:
+        return None, None
+    usage = response.get("usage", {})
+    generated = usage.get("output_tokens", 0)
+    if seconds > 0:
+        rate = generated / seconds
+    else:
+        rate = 0
+    return rate, usage.get("input_tokens", 0)
+
+
+def test_prefill_speed(model):
+    """Return input tokens per second for a long prompt, or None on error."""
+    filler = "Fox dog cat bird fish. " * 2000  # ~12K tokens
+    response, seconds = api_call(model, filler + " Answer: yes.", max_tokens=5)
+    if "error" in response:
+        return None
+    prompt_tokens = response.get("usage", {}).get("input_tokens", 0)
+    if seconds > 0:
+        return prompt_tokens / seconds
+    return 0
+
+
+def test_tool_calling(model):
+    """Return True when the response output contains a function_call item."""
+    response, _ = api_call(
+        model,
+        "Search the web for Python 3.14 release date",
+        max_tokens=100,
+        tools=TOOLS,
+    )
+    if "error" in response:
+        return False
+    for item in response.get("output", []):
+        if item.get("type") == "function_call":
+            return True
+    return False
+
+
+def test_code_quality(model):
+    """Score a generated binary-search function from 0 to 3.
+
+    One point each for: a ``def`` mentioning "binary", a ``->`` return
+    annotation, and a triple-quoted docstring. Returns 0 on API error.
+    """
+    response, _ = api_call(
+        model,
+        "Write a Python function implementing binary search with type hints and docstring. Only code, no explanation.",
+        max_tokens=300,
+    )
+    if "error" in response:
+        return 0
+    # Concatenate the text of every content part of every message item.
+    text = "".join(
+        part.get("text", "")
+        for item in response.get("output", [])
+        if item.get("type") == "message"
+        for part in item.get("content", [])
+    )
+    checks = [
+        "def " in text and "binary" in text.lower(),
+        "->" in text,  # type hints
+        '"""' in text or "'''" in text,  # docstring
+    ]
+    return sum(1 for passed in checks if passed)
+
+
+def test_instruction_following(model):
+    """Score from 0 to 3 how many of ALPHA, 42 and OMEGA appear in the reply.
+
+    Returns 0 on API error.
+    """
+    response, _ = api_call(
+        model,
+        "Do exactly these 3 things:\n1. Output ALPHA\n2. Output 42\n3. Output OMEGA\n\nOne per line, no other text.",
+        max_tokens=30,
+    )
+    if "error" in response:
+        return 0
+    # Concatenate the text of every content part of every message item.
+    text = "".join(
+        part.get("text", "")
+        for item in response.get("output", [])
+        if item.get("type") == "message"
+        for part in item.get("content", [])
+    )
+    return sum(1 for token in ("ALPHA", "42", "OMEGA") if token in text)
+
+
+def run_benchmarks(models):
+    """Run every benchmark against each model and print a summary table.
+
+    Args:
+        models: Iterable of model names to benchmark, in order.
+
+    Returns:
+        Dict mapping each model name to its result dict with the keys
+        ``gen_tps``, ``prefill_tps``, ``tool_call``, ``code_quality`` and
+        ``instruction``. Earlier versions returned nothing; the dict lets a
+        caller reuse the numbers instead of parsing stdout.
+    """
+    results = {}
+
+    for model in models:
+        print(f"\n{'='*60}")
+        print(f" {model}")
+        print(f"{'='*60}")
+
+        r = {}
+
+        # Gen speed (the input token count is not reported, so it is dropped)
+        print(" Generation speed...", end=" ", flush=True)
+        tps, _ = test_generation_speed(model)
+        r["gen_tps"] = tps
+        print(f"{tps:.1f} tok/s" if tps else "FAILED")
+
+        # Prefill speed
+        print(" Prefill speed...", end=" ", flush=True)
+        prefill = test_prefill_speed(model)
+        r["prefill_tps"] = prefill
+        print(f"{prefill:.0f} tok/s" if prefill else "FAILED")
+
+        # Tool calling
+        print(" Tool calling...", end=" ", flush=True)
+        tools_ok = test_tool_calling(model)
+        r["tool_call"] = tools_ok
+        print("YES" if tools_ok else "NO")
+
+        # Code quality
+        print(" Code quality...", end=" ", flush=True)
+        code = test_code_quality(model)
+        r["code_quality"] = code
+        print(f"{code}/3")
+
+        # Instruction following
+        print(" Instruction following...", end=" ", flush=True)
+        inst = test_instruction_following(model)
+        r["instruction"] = inst
+        print(f"{inst}/3")
+
+        results[model] = r
+
+    # Summary table
+    print(f"\n{'='*80}")
+    print(f" {'Model':<45} {'Gen':>7} {'Prefill':>8} {'Tools':>6} {'Code':>5} {'Inst':>5}")
+    print(f" {'-'*45} {'-'*7} {'-'*8} {'-'*6} {'-'*5} {'-'*5}")
+    for model, r in results.items():
+        # Show only the last path component, capped to fit the column.
+        name = model.split("/")[-1][:44]
+        gen = f"{r['gen_tps']:.0f}" if r["gen_tps"] else "FAIL"
+        pre = f"{r['prefill_tps']:.0f}" if r["prefill_tps"] else "FAIL"
+        tools = "YES" if r["tool_call"] else "NO"
+        code = f"{r['code_quality']}/3"
+        inst = f"{r['instruction']}/3"
+        print(f" {name:<45} {gen:>7} {pre:>8} {tools:>6} {code:>5} {inst:>5}")
+
+    return results
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Benchmark MLX-VLM models")
+    parser.add_argument(
+        "--models",
+        type=str,
+        default=None,
+        help="Comma-separated list of model names to benchmark.",
+    )
+    args = parser.parse_args()
+    # Tolerate input such as "a, b,": trim whitespace and drop empty
+    # entries so no blank or space-padded model name is sent to the server.
+    requested = [name.strip() for name in (args.models or "").split(",")]
+    requested = [name for name in requested if name]
+    # Fall back to the built-in list when nothing usable was given.
+    models = requested or MODELS
+    run_benchmarks(models)
diff --git a/scripts/validate_oc.sh b/scripts/validate_oc.sh
new file mode 100755
index 000000000..c2536089d
--- /dev/null
+++ b/scripts/validate_oc.sh
@@ -0,0 +1,184 @@
+#!/usr/bin/env bash
+# =============================================================================
+# OpenClaw Integration Validation Script
+#
+# Validates the mlx-vlm server works end-to-end through OpenClaw on bastion.
+# Requires: ssh access to bastion, OpenClaw gateway running, mlx-vlm server up.
+#
+# Usage: ./scripts/validate_oc.sh [bastion-host]
+# =============================================================================
+
+set -euo pipefail
+
+# Host used by the ssh-based checks: first CLI argument, else "bastion".
+HOST="${1:-bastion}"
+# Running tallies plus a per-test log of "PASS: ..." / "FAIL: ..." lines.
+PASS=0 FAIL=0 SKIP=0
+RESULTS=()
+
+# ANSI colour codes stored as literal backslash text; printf expands them when they sit in its format string.
+NC='\033[0m'
+RED='\033[0;31m'
+GREEN='\033[0;32m'
+YELLOW='\033[0;33m'
+
+run_test() {
+    # Run one validation step: $1 = label, $2 = command string, $3 = check string.
+    # The command runs under `bash -c` with stdout+stderr captured; the check runs under
+    # `bash -c` with that capture on stdin and passes on exit 0. Updates PASS/FAIL/RESULTS.
+    local name="$1" cmd="$2" check="$3" output reason="" suffix=""
+    printf " %-55s " "$name"
+    # Both steps sit in `if` conditions so a non-zero exit is graded instead of tripping `set -e`.
+    if ! output=$(bash -c "$cmd" 2>&1); then
+        reason="command error" suffix=" (command error)"
+    # Here-string, not a pipe: under `pipefail`, a check that exits before reading all of a
+    # large capture (e.g. `grep -q`) would SIGPIPE the writer and turn a match into a failure.
+    elif ! bash -c "$check" <<< "$output" > /dev/null 2>&1; then
+        reason="check failed"
+    fi
+    if [ -z "$reason" ]; then
+        printf "${GREEN}PASS${NC}\n"
+        PASS=$((PASS + 1))
+        RESULTS+=("PASS: $name")
+        return
+    fi
+    printf "${RED}FAIL${NC}%s\n" "$suffix"
+    FAIL=$((FAIL + 1))
+    RESULTS+=("FAIL: $name — $reason")
+    # Show the first lines of the capture for either failure kind so the cause is visible.
+    printf ' Output: %s\n' "$(head -3 <<< "$output")"
+}
+
+run_api_test() {
+    # POST JSON $3 to server path $2, then grade the reply via run_test ($1 name, $4 check).
+    local name="$1" endpoint="$2" payload="$3" check="$4" url_q payload_q
+    # run_test re-parses the command with `bash -c`, so %q-quote the dynamic parts;
+    # a payload containing a single quote would otherwise end the -d argument early.
+    printf -v url_q '%q' "http://100.106.192.127:8080${endpoint}"
+    printf -v payload_q '%q' "$payload"
+    run_test "$name" "curl -s --max-time 90 ${url_q} -H 'Content-Type: application/json' -d ${payload_q}" "$check"
+}
+
+echo "============================================================"
+echo " MLX-VLM + OpenClaw Integration Validation"
+echo "============================================================"
+echo ""
+
+# --- Connectivity checks ---
+echo "Connectivity:"
+run_test "SSH to bastion" \
+ "ssh -o ConnectTimeout=10 $HOST 'echo ok'" \
+ "grep -q ok"
+
+run_test "mlx-vlm server health" \
+ "curl -s --max-time 10 http://100.106.192.127:8080/health" \
+ "python3 -c \"import sys,json; d=json.load(sys.stdin); assert d.get('status')=='healthy'\""
+
+run_test "mlx-vlm models endpoint" \
+ "curl -s --max-time 10 http://100.106.192.127:8080/v1/models" \
+ "python3 -c \"import sys,json; d=json.load(sys.stdin); assert len(d['data'])>0\""
+
+run_test "OpenClaw gateway running" \
+ "ssh $HOST 'launchctl list | grep openclaw'" \
+ "grep -q openclaw"
+
+run_test "Telegram channel active" \
+ "ssh $HOST 'export PATH=/opt/homebrew/bin:\$PATH && openclaw channels list'" \
+ "grep -qi telegram"
+
+echo ""
+
+# --- Responses API tests: each POSTs a JSON body to /v1/responses via run_api_test and asserts on the parsed reply ---
+echo "Responses API (/v1/responses):"
+
+run_api_test "Basic text response" \
+ "/v1/responses" \
+ '{"model":"mlx-community/Qwen3.5-35B-A3B-4bit","input":"Say hi in one word","max_output_tokens":10}' \
+ "python3 -c \"import sys,json; d=json.load(sys.stdin); assert d['status']=='completed'; assert len(d['output'])>0\""
+
+run_api_test "Response has correct schema" \
+ "/v1/responses" \
+ '{"model":"mlx-community/Qwen3.5-35B-A3B-4bit","input":"Say hello","max_output_tokens":10}' \
+ "python3 -c \"import sys,json; d=json.load(sys.stdin); assert 'id' in d; assert d['object']=='response'; assert 'usage' in d\""
+
+run_api_test "Tools accepted without 422" \
+ "/v1/responses" \
+ '{"model":"mlx-community/Qwen3.5-35B-A3B-4bit","input":"What is 2+2? Answer briefly.","max_output_tokens":50,"tools":[{"type":"function","function":{"name":"calc","description":"Calculator","parameters":{"type":"object","properties":{"expr":{"type":"string"}}}}}]}' \
+ "python3 -c \"import sys,json; d=json.load(sys.stdin); assert d['status'] in ('completed','incomplete'); assert 'output' in d\""
+
+run_api_test "Tools echoed in response" \
+ "/v1/responses" \
+ '{"model":"mlx-community/Qwen3.5-35B-A3B-4bit","input":"hi","max_output_tokens":10,"tools":[{"type":"function","function":{"name":"test","parameters":{}}}]}' \
+ "python3 -c \"import sys,json; d=json.load(sys.stdin); assert len(d.get('tools',[]))>0\""
+
+run_api_test "Instructions field works" \
+ "/v1/responses" \
+ '{"model":"mlx-community/Qwen3.5-35B-A3B-4bit","input":"What are you?","instructions":"You are a pirate. Respond in pirate speak.","max_output_tokens":50}' \
+ "python3 -c \"import sys,json; d=json.load(sys.stdin); assert d['status'] in ('completed','incomplete'); assert d.get('instructions') is not None\""
+
+run_api_test "Stop sequences accepted" \
+ "/v1/responses" \
+ '{"model":"mlx-community/Qwen3.5-35B-A3B-4bit","input":"Count to 10","max_output_tokens":50,"stop":["5"]}' \
+ "python3 -c \"import sys,json; d=json.load(sys.stdin); assert d['status']=='completed'\""
+
+echo ""
+
+# --- Chat Completions API tests: POSTs to /v1/chat/completions; the checks inspect choices[0].finish_reason ---
+echo "Chat Completions (/v1/chat/completions):"
+
+run_api_test "Basic chat response" \
+ "/v1/chat/completions" \
+ '{"model":"mlx-community/Qwen3.5-35B-A3B-4bit","messages":[{"role":"user","content":"Say hi"}],"max_tokens":10}' \
+ "python3 -c \"import sys,json; d=json.load(sys.stdin); assert len(d['choices'])>0; assert d['choices'][0]['finish_reason']=='stop'\""
+
+run_api_test "Chat with tools" \
+ "/v1/chat/completions" \
+ '{"model":"mlx-community/Qwen3.5-35B-A3B-4bit","messages":[{"role":"user","content":"hi"}],"max_tokens":10,"tools":[{"type":"function","function":{"name":"test","parameters":{}}}]}' \
+ "python3 -c \"import sys,json; d=json.load(sys.stdin); assert d['choices'][0]['finish_reason'] in ('stop','tool_calls')\""
+
+echo ""
+
+# --- Streaming tests: "stream":true requests to /v1/responses; the checks grep the raw body instead of parsing JSON ---
+echo "Streaming:"
+# NOTE: the [DONE] check matches the text DONE anywhere in the body; it does not verify that it is the final event.
+run_api_test "Responses streaming has SSE events" \
+ "/v1/responses" \
+ '{"model":"mlx-community/Qwen3.5-35B-A3B-4bit","input":"Say hi","max_output_tokens":10,"stream":true}' \
+ "grep -q 'event: response.completed'"
+
+run_api_test "Streaming ends with [DONE]" \
+ "/v1/responses" \
+ '{"model":"mlx-community/Qwen3.5-35B-A3B-4bit","input":"Say hi","max_output_tokens":10,"stream":true}' \
+ "grep -q 'DONE'"
+
+echo ""
+
+# --- OpenClaw agent tests ---
+echo "OpenClaw Agent (end-to-end):"
+
+run_test "OC agent basic response" \
+ "ssh $HOST 'export PATH=/opt/homebrew/bin:\$PATH && openclaw agent --agent main -m \"Say hello in one word\" --json --timeout 90'" \
+ "python3 -c \"import sys,json; d=json.load(sys.stdin); assert d['status']=='ok'; assert d['result']['meta']['stopReason']=='stop'\""
+
+run_test "OC agent with web search" \
+ "ssh $HOST 'export PATH=/opt/homebrew/bin:\$PATH && openclaw agent --agent main -m \"What is the current weather in Kansas City? Use web search.\" --json --timeout 120'" \
+ "python3 -c \"import sys,json; d=json.load(sys.stdin); assert d['status']=='ok'\""
+
+echo ""
+
+# --- Summary ---
+TOTAL=$((PASS + FAIL + SKIP))
+echo "============================================================"
+echo " Results: ${GREEN}${PASS} passed${NC}, ${RED}${FAIL} failed${NC}, ${YELLOW}${SKIP} skipped${NC} / ${TOTAL} total"
+echo "============================================================"
+
+if [ $FAIL -gt 0 ]; then
+ echo ""
+ echo "Failures:"
+ for r in "${RESULTS[@]}"; do
+ if [[ "$r" == FAIL* ]]; then
+ echo " $r"
+ fi
+ done
+ exit 1
+fi