diff --git a/mlx_vlm/generate.py b/mlx_vlm/generate.py index 1d94bd805..7067b3a55 100644 --- a/mlx_vlm/generate.py +++ b/mlx_vlm/generate.py @@ -3,6 +3,7 @@ import contextlib import functools import json +import os import time from collections.abc import Sequence from dataclasses import dataclass @@ -338,6 +339,7 @@ class GenerationResult: prompt_tokens: int = 0 generation_tokens: int = 0 total_tokens: int = 0 + cached_tokens: int = 0 prompt_tps: float = 0.0 generation_tps: float = 0.0 peak_memory: float = 0.0 @@ -354,6 +356,13 @@ class PromptCacheState: def __init__(self): self.cache: Optional[List[Any]] = None self.token_ids: Optional[List[int]] = None + self.last_used: float = time.time() + self.created_at: float = time.time() + + @property + def token_count(self) -> int: + """Number of tokens stored in the cache.""" + return len(self.token_ids) if self.token_ids else 0 def find_prefix_length(self, new_ids: list) -> int: """Return the number of leading tokens that match the cached ids.""" @@ -365,10 +374,20 @@ def find_prefix_length(self, new_ids: list) -> int: return i return max_len + def touch(self): + """Update last_used timestamp.""" + self.last_used = time.time() + def update(self, token_ids: list, kv_cache: list): """Store the full token sequence and corresponding KV cache.""" self.token_ids = list(token_ids) self.cache = kv_cache + self.last_used = time.time() + + def invalidate(self): + """Discard cached state, forcing a full prefill on next turn.""" + self.cache = None + self.token_ids = None def generate_step( @@ -668,33 +687,65 @@ def stream_generate( reused_prefix_len = 0 full_input_ids_list = input_ids.flatten().tolist() + # Save originals for fallback if cache reuse fails + _original_input_ids = input_ids + _original_pixel_values = pixel_values + _original_kwargs = {k: v for k, v in kwargs.items() if k == "cached_image_features"} + if prompt_cache_state is not None and prompt_cache_state.cache is not None: - prefix_len = 
prompt_cache_state.find_prefix_length(full_input_ids_list) - if prefix_len > 0 and prefix_len < input_ids.shape[1]: - reused_prefix_len = prefix_len - # Trim to only new tokens - input_ids = input_ids[:, prefix_len:] - # Only skip vision if no image tokens in the new (trimmed) tokens - image_token_id = getattr(model.config, "image_token_id", None) or getattr( - model.config, "image_token_index", None + try: + prefix_len = prompt_cache_state.find_prefix_length(full_input_ids_list) + cached_total = ( + len(prompt_cache_state.token_ids) if prompt_cache_state.token_ids else 0 ) - new_ids = input_ids.flatten().tolist() - has_image_in_new = image_token_id is not None and image_token_id in new_ids - if not has_image_in_new: - pixel_values = None - kwargs.pop("cached_image_features", None) - # Reuse the saved KV cache (trimmed to prefix length) - kv_cache = prompt_cache_state.cache - # Trim cache to prefix_len in case it includes generated tokens - for c in kv_cache: - if hasattr(c, "keys") and c.keys is not None: - cached_len = c.keys.shape[2] - if cached_len > prefix_len: - c.keys = c.keys[:, :, :prefix_len, :] - c.values = c.values[:, :, :prefix_len, :] - if hasattr(c, "offset"): - c.offset = prefix_len - kwargs["prompt_cache"] = kv_cache + # Only reuse if a substantial prefix matches (>= 50% of cached tokens). + # Short matches on quantized KV caches (TurboQuant) can produce + # corrupted output because trim() only adjusts the offset without + # clearing stale quantized data. 
+ min_reuse = max(512, cached_total // 2) + if prefix_len >= min_reuse and prefix_len < input_ids.shape[1]: + reused_prefix_len = prefix_len + # Trim to only new tokens + input_ids = input_ids[:, prefix_len:] + # Only skip vision if no image tokens in the new (trimmed) tokens + image_token_id = getattr( + model.config, "image_token_id", None + ) or getattr(model.config, "image_token_index", None) + new_ids = input_ids.flatten().tolist() + has_image_in_new = ( + image_token_id is not None and image_token_id in new_ids + ) + if not has_image_in_new: + pixel_values = None + kwargs.pop("cached_image_features", None) + # Reuse the saved KV cache (trimmed to prefix length). + # Works with both standard KVCache (mx.array keys) and + # quantized caches (TurboQuant) via their trim() method. + kv_cache = prompt_cache_state.cache + for c in kv_cache: + if hasattr(c, "offset") and c.offset > prefix_len: + trim_amount = c.offset - prefix_len + if hasattr(c, "trim") and callable(c.trim): + c.trim(trim_amount) + elif hasattr(c, "keys") and c.keys is not None: + keys = c.keys + if hasattr(keys, "shape") and len(keys.shape) >= 3: + c.keys = keys[:, :, :prefix_len, :] + c.values = c.values[:, :, :prefix_len, :] + c.offset = prefix_len + kwargs["prompt_cache"] = kv_cache + except Exception as e: + # Cache reuse failed (e.g., shape mismatch, stale KV state). + # Invalidate and fall back to fresh generation. 
+ if os.environ.get("VERBOSE", "").lower() in ("1", "true", "yes"): + print(f"[prompt_cache] Cache reuse failed, invalidating: {e}") + prompt_cache_state.invalidate() + reused_prefix_len = 0 + input_ids = _original_input_ids + pixel_values = _original_pixel_values + if "cached_image_features" in _original_kwargs: + kwargs["cached_image_features"] = _original_kwargs["cached_image_features"] + kwargs.pop("prompt_cache", None) if thinking_budget is not None: thinking_start_token_id = tokenizer.encode( @@ -720,6 +771,35 @@ def stream_generate( model.language_model, max_kv_size=kwargs.get("max_kv_size", None), ) + + # Validate cache shapes before generation. If the cached KV state has + # inconsistent shapes (e.g., stale after model reload), discard it and + # build a fresh cache to avoid broadcast_shapes errors during generation. + if reused_prefix_len > 0: + try: + for c in kwargs["prompt_cache"]: + if hasattr(c, "offset"): + # Use offset for all cache types (works for both standard + # KVCache and quantized TurboQuant caches). 
+ if c.offset != reused_prefix_len: + raise ValueError( + f"Cache offset mismatch: expected {reused_prefix_len}, got {c.offset}" + ) + except (ValueError, IndexError, AttributeError) as e: + if os.environ.get("VERBOSE", "").lower() in ("1", "true", "yes"): + print(f"[prompt_cache] Cache validation failed, rebuilding: {e}") + if prompt_cache_state is not None: + prompt_cache_state.invalidate() + reused_prefix_len = 0 + input_ids = _original_input_ids + pixel_values = _original_pixel_values + if "cached_image_features" in _original_kwargs: + kwargs["cached_image_features"] = _original_kwargs["cached_image_features"] + kwargs["prompt_cache"] = cache.make_prompt_cache( + model.language_model, + max_kv_size=kwargs.get("max_kv_size", None), + ) + tracked_cache = kwargs["prompt_cache"] total_prompt_tokens = reused_prefix_len + input_ids.size @@ -758,6 +838,7 @@ def stream_generate( prompt_tokens=total_prompt_tokens, generation_tokens=n + 1, total_tokens=total_prompt_tokens + n + 1, + cached_tokens=reused_prefix_len, prompt_tps=prompt_tps, generation_tps=(n + 1) / (time.perf_counter() - tic), peak_memory=mx.get_peak_memory() / 1e9, @@ -771,13 +852,21 @@ def stream_generate( prompt_tokens=total_prompt_tokens, generation_tokens=n + 1, total_tokens=total_prompt_tokens + n + 1, + cached_tokens=reused_prefix_len, prompt_tps=prompt_tps, generation_tps=(n + 1) / (time.perf_counter() - tic), peak_memory=mx.get_peak_memory() / 1e9, ) - # Save cache state for potential reuse on next turn - if prompt_cache_state is not None: + # Save cache state for potential reuse on next turn. + # Only save if the prompt was substantial (>= 1024 tokens) to avoid + # polluting the cache with short probe/capability-check requests that + # some agent frameworks send before the real request. 
def _normalize_tool(tool):
    """Return *tool* in the nested Chat Completions layout.

    Two layouts are recognised for function tools. The flat one used by the
    OpenAI Responses API::

        {"type": "function", "name": "...", "description": "...", "parameters": {...}}

    and the nested one that Chat Completions and most Jinja chat templates
    read via ``tool['function']``::

        {"type": "function", "function": {"name": "...", "description": "...", "parameters": {...}}}

    A flat function tool is converted into a new nested dict; every key
    except ``type`` moves under ``function``. Anything else (non-dict values,
    tools that already carry a dict under ``function``, non-function tools)
    is returned unchanged. The input dict is never mutated.
    """
    if not isinstance(tool, dict):
        return tool

    already_nested = isinstance(tool.get("function"), dict)
    is_flat_function = tool.get("type") == "function" and "name" in tool
    if already_nested or not is_flat_function:
        return tool

    # Copy, then drop the discriminator so only the function fields remain.
    payload = dict(tool)
    payload.pop("type")
    return {"type": "function", "function": payload}


def _normalize_tools(tools):
    """Apply :func:`_normalize_tool` to each entry of a ``list`` of tools.

    Values that are not a ``list`` (including ``None``) pass through as-is.
    """
    if isinstance(tools, list):
        return list(map(_normalize_tool, tools))
    return tools
class FlexibleBaseModel(BaseModel):
    """Pydantic base model configured with ``extra="allow"``.

    Unknown request fields are accepted rather than rejected, which keeps the
    models forward compatible with clients that send newer parameters.
    """

    model_config = ConfigDict(extra="allow")

    def dump_kwargs(self, *fields: str) -> dict[str, Any]:
        """Collect the named attributes whose value is set.

        Args:
            *fields: Attribute names to look up on this instance.

        Returns:
            A dict, in the order the names were given, holding every
            attribute that exists and is not ``None``.
        """
        selected: dict[str, Any] = {}
        for name in fields:
            # A missing attribute and an explicit None are treated alike.
            value = getattr(self, name, None)
            if value is not None:
                selected[name] = value
        return selected
class TemplateParams(FlexibleBaseModel):
    """Request fields that are forwarded to the chat template (thinking mode)."""

    enable_thinking: Optional[bool] = Field(
        None,
        description="Enable thinking mode in the chat template.",
    )
    thinking_budget: Optional[int] = Field(
        None,
        description="Maximum number of thinking tokens before forcing the end token.",
    )
    thinking_start_token: Optional[str] = Field(
        DEFAULT_THINKING_START_TOKEN,
        description="Token that marks the start of a thinking block.",
    )
    thinking_end_token: Optional[str] = Field(
        DEFAULT_THINKING_END_TOKEN,
        description="Token that marks the end of a thinking block.",
    )

    def template_kwargs(self) -> dict[str, Any]:
        """Return the non-``None`` template fields as keyword arguments.

        ``enable_thinking`` is always present in the result; it falls back to
        ``False`` when the request left it unset.
        """
        template_fields = (
            "enable_thinking",
            "thinking_budget",
            "thinking_start_token",
            "thinking_end_token",
        )
        present = self.dump_kwargs(*template_fields)
        if "enable_thinking" not in present:
            present["enable_thinking"] = False
        return present
type: Required[Literal["input_text", "text"]] + + +class ResponseInputImageParam(TypedDict, total=False): + """Image content item with a direct image URL.""" + + detail: Literal["high", "low", "auto"] + type: Required[Literal["input_image"]] + image_url: Required[str] + file_id: Optional[str] + + +class InputAudio(TypedDict, total=False): + data: Required[str] + format: Required[str] + + +class ResponseInputAudioParam(TypedDict, total=False): + """Audio content item.""" + + type: Required[Literal["input_audio"]] + input_audio: Required[InputAudio] + + +class ImageUrl(TypedDict, total=False): + url: Required[str] + + +class ResponseImageUrlParam(TypedDict, total=False): + """Image content item with nested ``image_url.url`` (chat/completions format).""" + + type: Required[Literal["image_url"]] + image_url: Required[ImageUrl] + + +class ResponseOutputText(TypedDict, total=False): + """Output text item used in multi-turn assistant messages.""" + + text: Required[str] + type: Required[Literal["output_text"]] + + +ResponseInputContentParam: TypeAlias = Union[ + ResponseInputTextParam, + ResponseInputImageParam, + ResponseImageUrlParam, + ResponseInputAudioParam, +] + +ResponseInputMessageContentListParam: TypeAlias = List[ResponseInputContentParam] +ResponseOutputMessageContentList: TypeAlias = List[ResponseOutputText] + + +# --------------------------------------------------------------------------- +# Chat message model +# --------------------------------------------------------------------------- + + +class ChatMessage(FlexibleBaseModel): + """A single message in the conversation input.""" + + role: Literal["user", "assistant", "system", "developer", "tool"] = Field( + ..., description="Role of the message sender." 
+ ) + content: Optional[ + Union[ + str, + ResponseInputMessageContentListParam, + ResponseOutputMessageContentList, + ] + ] = Field(None, description="Content of the message.") + tool_calls: List = Field(default_factory=list) + + +# --------------------------------------------------------------------------- +# Function tool definition +# --------------------------------------------------------------------------- + + +class ResponseFunctionTool(BaseModel): + """A function tool the model may call.""" + + type: Literal["function"] = "function" + name: str = Field(..., description="The name of the function.") + description: Optional[str] = Field( + None, description="A description of what the function does." + ) + parameters: Optional[dict] = Field( + None, description="JSON Schema object describing the function parameters." + ) + strict: Optional[bool] = Field( + None, description="Whether to enforce strict schema adherence." + ) + + +# --------------------------------------------------------------------------- +# Function call input items (for multi-turn tool use) +# --------------------------------------------------------------------------- + + +class ResponseFunctionCallInputItem(BaseModel): + """A function call from a previous assistant turn, included in input.""" + + type: Literal["function_call"] = "function_call" + call_id: str = Field(..., description="Unique ID for this tool call.") + name: str = Field(..., description="The function name that was called.") + arguments: str = Field(..., description="JSON string of the function arguments.") + status: Optional[str] = "completed" + + +class ResponseFunctionCallOutputInputItem(BaseModel): + """The output/result of a function call, sent back by the client.""" + + type: Literal["function_call_output"] = "function_call_output" + call_id: str = Field( + ..., description="The call_id of the function call this is a result for." 
class ResponsesRequest(GenerationParams, TemplateParams):
    """Request body accepted by the OpenAI-compatible ``/v1/responses`` endpoint.

    Combines the shared sampling fields (``GenerationParams``) and the chat
    template fields (``TemplateParams``) with the Responses-specific fields
    declared below.

    Reference: https://developers.openai.com/api/reference/resources/responses/create
    """

    input: Union[str, List[Any]] = Field(
        ...,
        description="Input text or list of input items (messages, tool outputs).",
    )
    model: str = Field(..., description="The model to use for generation.")
    max_output_tokens: Optional[int] = Field(
        None,
        description="Maximum number of tokens to generate. Uses server default if not specified.",
    )
    stream: bool = Field(
        False,
        description="Whether to stream the response chunk by chunk.",
    )
    tools: Optional[List[dict]] = Field(
        None,
        description="Tool definitions the model may call.",
    )
    tool_choice: Optional[Any] = Field(
        "auto",
        description='Tool choice: "none", "auto", "required", or specific tool.',
    )
    parallel_tool_calls: bool = Field(True, description="Allow parallel tool calls.")
    previous_response_id: Optional[str] = Field(
        None,
        description="ID of a previous response for multi-turn context replay.",
    )
    instructions: Optional[str] = Field(
        None,
        description="System/developer message inserted into context.",
    )
    metadata: Optional[dict] = Field(
        None,
        description="Up to 16 key-value pairs of metadata.",
    )
    stop: Optional[Union[str, List[str]]] = Field(
        None,
        description="Up to 4 sequences where the API will stop generating further tokens.",
    )
    response_format: Optional[dict] = Field(
        None,
        description='Output format: {"type": "text"} or {"type": "json_object"}.',
    )
    prompt_cache_key: Optional[str] = Field(
        None,
        description="Stable key for prompt cache routing across turns.",
    )

    def generation_kwargs(self) -> dict[str, Any]:
        """Build the keyword arguments handed to the generation call.

        ``max_output_tokens`` is exposed under the name ``max_tokens`` when it
        is set; the shared sampling parameters are merged in afterwards, so
        they take precedence on a key collision.
        """
        merged: dict[str, Any] = {}
        if self.max_output_tokens is not None:
            merged["max_tokens"] = self.max_output_tokens
        merged.update(self.shared_generation_kwargs())
        return merged
class ResponseObject(BaseModel):
    """Top-level payload returned by ``/v1/responses``.

    ``created_at``, ``model`` and ``usage`` are required; every other field
    has a default. ``id`` defaults to ``resp_`` followed by 24 hex characters
    taken from a fresh UUID4.

    Reference: https://developers.openai.com/api/reference/resources/responses/object
    """

    id: str = Field(
        default_factory=lambda: f"resp_{uuid.uuid4().hex[:24]}",
        description="Unique identifier for this Response.",
    )
    object: Literal["response"] = Field(
        "response",
        description="The object type — always ``response``.",
    )
    created_at: int = Field(..., description="Unix timestamp of creation.")
    status: Literal["completed", "failed", "in_progress", "incomplete"] = Field(
        "completed",
        description="The status of the response generation.",
    )
    error: Optional[ResponseErrorObject] = Field(None)
    incomplete_details: Optional[ResponseIncompleteDetails] = Field(None)
    instructions: Optional[str] = Field(None)
    max_output_tokens: Optional[int] = Field(None)
    model: str = Field(..., description="Model ID used to generate the response.")
    output: List[Union[ResponseMessageItem, ResponseFunctionCallItem]] = Field(
        default_factory=list,
        description="An array of content items generated by the model.",
    )
    parallel_tool_calls: bool = Field(True)
    previous_response_id: Optional[str] = Field(None)
    temperature: Optional[float] = Field(None, ge=0, le=2)
    top_p: Optional[float] = Field(None, ge=0, le=1)
    tools: List = Field(default_factory=list)
    tool_choice: Optional[Any] = Field("auto")
    truncation: Literal["auto", "disabled"] = Field("disabled")
    metadata: Optional[dict] = Field(None)
    usage: ResponseUsage = Field(..., description="Token usage details.")
    user: Optional[str] = Field(None)

    @property
    def output_text(self) -> str:
        """Concatenate the text of every non-empty ``output_text`` part.

        Only message items contribute; function-call items are skipped. The
        result is the empty string when there is no text at all.
        """
        return "".join(
            part.text
            for item in self.output
            if isinstance(item, ResponseMessageItem)
            for part in item.content
            if part.type == "output_text" and part.text
        )
"response.output_item.added" + output_index: int + item: Union[ResponseMessageItem, ResponseFunctionCallItem] + + +class ResponseContentPartAddedEvent(BaseStreamEvent): + type: Literal["response.content_part.added"] = "response.content_part.added" + item_id: str + output_index: int + content_index: int + part: ContentPartOutputText + + +class ResponseOutputTextDeltaEvent(BaseStreamEvent): + type: Literal["response.output_text.delta"] = "response.output_text.delta" + item_id: str + output_index: int + content_index: int + delta: str + + +class ResponseOutputTextDoneEvent(BaseStreamEvent): + type: Literal["response.output_text.done"] = "response.output_text.done" + item_id: str + output_index: int + content_index: int + text: str + + +class ResponseContentPartDoneEvent(BaseStreamEvent): + type: Literal["response.content_part.done"] = "response.content_part.done" + item_id: str + output_index: int + content_index: int + part: ContentPartOutputText + + +class ResponseOutputItemDoneEvent(BaseStreamEvent): + type: Literal["response.output_item.done"] = "response.output_item.done" + output_index: int + item: Union[ResponseMessageItem, ResponseFunctionCallItem] + + +class ResponseFunctionCallArgumentsDeltaEvent(BaseStreamEvent): + type: Literal["response.function_call_arguments.delta"] = ( + "response.function_call_arguments.delta" + ) + item_id: str + output_index: int + delta: str + + +class ResponseFunctionCallArgumentsDoneEvent(BaseStreamEvent): + type: Literal["response.function_call_arguments.done"] = ( + "response.function_call_arguments.done" + ) + item_id: str + output_index: int + arguments: str + + +class ResponseCompletedEvent(BaseStreamEvent): + type: Literal["response.completed"] = "response.completed" + response: ResponseObject + + +StreamEvent = Union[ + ResponseCreatedEvent, + ResponseInProgressEvent, + ResponseOutputItemAddedEvent, + ResponseContentPartAddedEvent, + ResponseOutputTextDeltaEvent, + ResponseOutputTextDoneEvent, + 
class ResponseStore:
    """Bounded LRU store mapping response IDs to (input_items, response_output) pairs.

    Used to support the ``previous_response_id`` parameter in the Responses API,
    which allows clients to chain responses without resending full conversation
    history.

    Every access to the underlying mapping is guarded by a lock.

    Args:
        maxsize: Maximum number of responses to store. When exceeded, the
            least recently used entry is evicted. Defaults to 256.
    """

    def __init__(self, maxsize: int = 256):
        self._store: OrderedDict[str, dict] = OrderedDict()
        self._maxsize = maxsize
        self._lock = threading.Lock()

    @staticmethod
    def _as_dict(item: Any) -> Optional[dict]:
        """Return *item* as a plain dict, or ``None`` if it cannot be converted.

        Dicts are returned as-is. Other objects are converted through a
        callable ``model_dump()`` (pydantic v2) or ``dict()`` (pydantic v1)
        method when they have one.
        """
        if isinstance(item, dict):
            return item
        for dumper_name in ("model_dump", "dict"):
            dumper = getattr(item, dumper_name, None)
            if callable(dumper):
                dumped = dumper()
                return dumped if isinstance(dumped, dict) else None
        return None

    def save(
        self,
        response_id: str,
        input_items: Any,
        response_output: list,
    ) -> None:
        """Save a response for later replay.

        Saving an ID that is already present replaces its entry and marks it
        as most recently used.

        Args:
            response_id: The unique response ID (e.g., ``"resp_abc123"``).
            input_items: The original request input (string or list of input items).
            response_output: The response output items list (dicts or model instances).
        """
        with self._lock:
            if response_id in self._store:
                self._store.move_to_end(response_id)
            self._store[response_id] = {
                "input": input_items,
                "output": response_output,
            }
            # Evict from the least-recently-used end until within bounds.
            while len(self._store) > self._maxsize:
                self._store.popitem(last=False)

    def get(self, response_id: str) -> Optional[dict]:
        """Retrieve a stored response by ID and mark it as recently used.

        Args:
            response_id: The response ID to look up.

        Returns:
            Dict with ``"input"`` and ``"output"`` keys, or ``None`` if not found.
        """
        with self._lock:
            entry = self._store.get(response_id)
            if entry is not None:
                self._store.move_to_end(response_id)
            return entry

    def replay_input(self, response_id: str) -> Optional[list]:
        """Build conversation input by replaying a previous response.

        Reconstructs input items from the stored response: the original input
        items followed by the output items converted to input format. Output
        items may be dicts or model instances; instances are converted with
        ``_as_dict`` so that assistant turns saved as models are not lost.
        Items that cannot be converted, and item types other than ``message``
        and ``function_call``, are skipped.

        Args:
            response_id: The previous response ID to replay.

        Returns:
            List of input items suitable for prepending to the current request,
            or ``None`` if the response ID is not found.
        """
        entry = self.get(response_id)
        if entry is None:
            return None

        items: list = []

        # Original request input: a bare string becomes a single user message.
        original_input = entry["input"]
        if isinstance(original_input, str):
            items.append({"role": "user", "content": original_input})
        elif isinstance(original_input, list):
            items.extend(original_input)

        # ``or []`` tolerates a stored output (or message content) of None.
        for raw_item in entry.get("output") or []:
            output_item = self._as_dict(raw_item)
            if output_item is None:
                continue
            item_type = output_item.get("type", "")
            if item_type == "message":
                for raw_part in output_item.get("content") or []:
                    part = self._as_dict(raw_part)
                    if part is not None and part.get("type") == "output_text":
                        items.append(
                            {
                                "role": "assistant",
                                "content": [
                                    {
                                        "type": "output_text",
                                        "text": part.get("text", ""),
                                    }
                                ],
                            }
                        )
            elif item_type == "function_call":
                items.append(
                    {
                        "type": "function_call",
                        "call_id": output_item.get("call_id", ""),
                        "name": output_item.get("name", ""),
                        "arguments": output_item.get("arguments", ""),
                    }
                )

        return items

    def __len__(self) -> int:
        with self._lock:
            return len(self._store)

    def clear(self) -> None:
        """Remove all stored responses."""
        with self._lock:
            self._store.clear()
gc import json import os @@ -7,7 +8,6 @@ import traceback import uuid from contextlib import asynccontextmanager -from datetime import datetime from typing import Any, List, Literal, Optional, Union import mlx.core as mx @@ -31,11 +31,45 @@ DEFAULT_THINKING_END_TOKEN, DEFAULT_THINKING_START_TOKEN, DEFAULT_TOP_P, + PromptCacheState, generate, normalize_resize_shape, stream_generate, ) from .prompt_utils import apply_chat_template +from .responses_models import ContentPartOutputText as ResponseContentPartOutputText +from .responses_models import InputTokensDetails +from .responses_models import ResponseCompletedEvent as ResponsesCompletedEvent +from .responses_models import ( + ResponseContentPartAddedEvent as ResponsesContentPartAddedEvent, +) +from .responses_models import ( + ResponseContentPartDoneEvent as ResponsesContentPartDoneEvent, +) +from .responses_models import ResponseCreatedEvent as ResponsesCreatedEvent +from .responses_models import ( + ResponseFunctionCallArgumentsDeltaEvent as ResponsesFunctionCallArgumentsDeltaEvent, +) +from .responses_models import ( + ResponseFunctionCallArgumentsDoneEvent as ResponsesFunctionCallArgumentsDoneEvent, +) +from .responses_models import ResponseFunctionCallItem, ResponseIncompleteDetails +from .responses_models import ResponseInProgressEvent as ResponsesInProgressEvent +from .responses_models import ResponseMessageItem, ResponseObject +from .responses_models import ( + ResponseOutputItemAddedEvent as ResponsesOutputItemAddedEvent, +) +from .responses_models import ( + ResponseOutputItemDoneEvent as ResponsesOutputItemDoneEvent, +) +from .responses_models import ( + ResponseOutputTextDeltaEvent as ResponsesOutputTextDeltaEvent, +) +from .responses_models import ( + ResponseOutputTextDoneEvent as ResponsesOutputTextDoneEvent, +) +from .responses_models import ResponsesRequest, ResponseUsage +from .responses_store import ResponseStore from .tool_parsers import _infer_tool_parser, load_tool_module from .utils import 
load from .version import __version__ @@ -44,11 +78,56 @@ DEFAULT_SERVER_HOST = "0.0.0.0" DEFAULT_SERVER_PORT = 8080 +def _is_verbose() -> bool: + return os.environ.get("VERBOSE", "").lower() in ("1", "true", "yes") + + +class _VerboseFlag: + """Lazy flag that checks env var on each access, so --verbose works after import.""" + def __bool__(self) -> bool: + return _is_verbose() + + +_verbose = _VerboseFlag() + + +def get_default_max_tokens() -> int: + """Server-side default max tokens for API responses. + The upstream generate.py default (256) is too low for agentic use. + Configurable via --default-max-tokens CLI flag or DEFAULT_MAX_TOKENS env var.""" + return int(os.environ.get("DEFAULT_MAX_TOKENS", DEFAULT_MAX_TOKENS)) + +_responses_store = ResponseStore() + def get_prefill_step_size(): return int(os.environ.get("PREFILL_STEP_SIZE", DEFAULT_PREFILL_STEP_SIZE)) +def get_max_context_tokens() -> int: + """Maximum prompt tokens before rejecting a request. 0 = no limit.""" + return int(os.environ.get("MAX_CONTEXT_TOKENS", 0)) + + +def get_request_timeout() -> int: + """Maximum seconds for a generation request. 
0 = no timeout.""" + return int(os.environ.get("REQUEST_TIMEOUT", 300)) + + +def check_context_length(prompt: str, processor, max_context: int) -> None: + """Raise HTTP 400 if the tokenized prompt exceeds *max_context* tokens.""" + if max_context <= 0: + return + tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor + token_count = len(tokenizer.encode(prompt, add_special_tokens=False)) + if token_count > max_context: + raise HTTPException( + status_code=400, + detail=f"Prompt length ({token_count} tokens) exceeds maximum context " + f"window ({max_context} tokens).", + ) + + def get_quantized_kv_bits(model: str): kv_bits = float(os.environ.get("KV_BITS", 0)) if kv_bits == 0: @@ -85,6 +164,16 @@ def get_quantized_kv_start(): return int(os.environ.get("QUANTIZED_KV_START", DEFAULT_QUANTIZED_KV_START)) +async def _prompt_cache_cleanup_loop(): + """Background task that periodically evicts stale prompt caches.""" + while True: + await asyncio.sleep(60) + try: + evict_stale_prompt_caches() + except Exception as e: + print(f"[prompt_cache] Cleanup error: {e}") + + @asynccontextmanager async def lifespan(app): # Startup @@ -97,7 +186,24 @@ async def lifespan(app): except Exception as e: print(f"Failed to preload model: {e}") print("Server will continue without a preloaded model.") + + # Start prompt cache cleanup task + ttl = get_prompt_cache_ttl() + cleanup_task = None + if ttl > 0: + cleanup_task = asyncio.create_task(_prompt_cache_cleanup_loop()) + if _verbose: + print(f"[prompt_cache] Cleanup task started (TTL={ttl}s, check every 60s)") + yield + + # Shutdown + if cleanup_task is not None: + cleanup_task.cancel() + try: + await cleanup_task + except asyncio.CancelledError: + pass unload_model_sync() @@ -122,6 +228,112 @@ async def lifespan(app): model_cache = {} +# Prompt cache: reuse KV state across requests with the same prompt prefix. 
+# Keyed by (model_name, cache_key) — supports both: +# - OpenClaw: sends `prompt_cache_key` for per-session routing +# - Hermes: relies on stable system prompt prefix for automatic matching +# When no cache_key is provided, falls back to model name only. +DEFAULT_PROMPT_CACHE_TTL = 300 # seconds + + +def get_prompt_cache_ttl() -> int: + """Prompt cache TTL in seconds. 0 = no expiry.""" + return int(os.environ.get("PROMPT_CACHE_TTL", DEFAULT_PROMPT_CACHE_TTL)) + + +_PROMPT_CACHE_MAX_ENTRIES = 64 +_prompt_cache_states: dict[str, PromptCacheState] = {} + +# Concurrency guard: MLX generation is single-threaded on Metal. +# Concurrent requests would corrupt shared GPU state. The semaphore +# serializes access to the generation pipeline. +_generation_semaphore: Optional[asyncio.Semaphore] = None + + +def get_max_concurrent_requests() -> int: + return int(os.environ.get("MAX_CONCURRENT_REQUESTS", 1)) + + +def get_generation_semaphore() -> asyncio.Semaphore: + """Get or create the generation semaphore.""" + global _generation_semaphore + if _generation_semaphore is None: + _generation_semaphore = asyncio.Semaphore(get_max_concurrent_requests()) + return _generation_semaphore + + +async def acquire_semaphore() -> asyncio.Semaphore: + """Acquire the generation semaphore with timeout. Returns the semaphore for release.""" + sem = get_generation_semaphore() + timeout = get_request_timeout() + try: + if timeout > 0: + await asyncio.wait_for(sem.acquire(), timeout=timeout) + else: + await sem.acquire() + except asyncio.TimeoutError: + raise HTTPException( + status_code=503, + detail="Server busy: generation request timed out waiting for GPU.", + ) + return sem + + +def get_prompt_cache_state( + model_name: str, + cache_key: Optional[str] = None, +) -> PromptCacheState: + """Get or create a PromptCacheState for the given model and cache key. 
+ + Supports two caching patterns: + + **OpenClaw**: Sends ``prompt_cache_key`` per session so that requests + from the same conversation share a KV cache. The system prompt prefix + is stable across turns, so prefix matching works. + + **Hermes**: Relies on a stable system prompt and ``cache_control`` + breakpoints. No ``prompt_cache_key`` is sent, so we fall back to + a single cache per model. The ``PromptCacheState.find_prefix_length`` + in generate.py will still match the common system-prompt prefix. + + Args: + model_name: The model identifier. + cache_key: Optional routing key (e.g., ``prompt_cache_key`` from the + request). When provided, each key gets its own cache state. + """ + key = f"{model_name}::{cache_key}" if cache_key else model_name + # Evict LRU entry if at capacity and key is new + if key not in _prompt_cache_states and len(_prompt_cache_states) >= _PROMPT_CACHE_MAX_ENTRIES: + lru_key = min(_prompt_cache_states, key=lambda k: _prompt_cache_states[k].last_used) + evicted = _prompt_cache_states.pop(lru_key) + evicted.invalidate() + state = _prompt_cache_states.setdefault(key, PromptCacheState()) + state.touch() + return state + + +def evict_stale_prompt_caches() -> int: + """Remove prompt cache entries that exceed the TTL. 
Returns count evicted.""" + ttl = get_prompt_cache_ttl() + if ttl <= 0: + return 0 + now = time.time() + stale_keys = [ + k for k, v in _prompt_cache_states.items() + if (now - v.last_used) > ttl + ] + for k in stale_keys: + entry = _prompt_cache_states.pop(k) + tokens = entry.token_count + idle = now - entry.last_used + entry.invalidate() + if _verbose: + print(f"[prompt_cache] Evicted '{k}' ({tokens} tokens, idle {idle:.0f}s)") + if stale_keys: + gc.collect() + mx.clear_cache() + return len(stale_keys) + class FlexibleBaseModel(BaseModel): """Base model that ignores/accepts any unknown OpenAI SDK fields.""" @@ -204,6 +416,8 @@ def unload_model_sync(): if "vision_cache" in model_cache: model_cache["vision_cache"].clear() model_cache = {} + # Clear prompt cache states + _prompt_cache_states.clear() # Force garbage collection gc.collect() mx.clear_cache() @@ -376,12 +590,16 @@ class OpenAIRequest(GenerationParams, TemplateParams): ) model: str = Field(..., description="The model to use for generation.") max_output_tokens: int = Field( - DEFAULT_MAX_TOKENS, + None, description="Maximum number of tokens to generate.", ) stream: bool = Field( False, description="Whether to stream the response chunk by chunk." ) + stop: Optional[Union[str, List[str]]] = Field( + None, + description="Up to 4 sequences where the API will stop generating further tokens.", + ) def generation_kwargs(self) -> dict[str, Any]: kwargs = self.dump_kwargs("max_output_tokens") @@ -554,11 +772,15 @@ class VLMRequest(GenerationParams, TemplateParams): adapter_path: Optional[str] = Field( None, description="The path to the adapter weights." ) - max_tokens: int = Field( - DEFAULT_MAX_TOKENS, - description="Maximum number of tokens to generate.", + max_tokens: Optional[int] = Field( + None, + description="Maximum number of tokens to generate. 
_MAX_STOP_SEQUENCES = 4  # the request field documents "up to 4 sequences"


def resolve_stop_sequences(
    stop: Optional[Union[str, list]],
) -> Optional[list]:
    """Normalize the request's ``stop`` value into a list of stop strings.

    Args:
        stop: A single stop string or list of stop strings, or None.

    Returns:
        The first four non-empty string entries of ``stop``, or None when
        there are none.
    """
    if not stop:
        return None
    if isinstance(stop, str):
        stop = [stop]
    # Filter before capping: slicing first would let empty or non-string
    # entries use up the four slots and silently drop valid sequences.
    sequences = [s for s in stop if isinstance(s, str) and s]
    return sequences[:_MAX_STOP_SEQUENCES] or None


def _tool_matches(tool: Any, name: str) -> bool:
    """Return True if ``tool`` is a dict declaring ``name``.

    Both layouts are accepted: nested ``{"function": {"name": ...}}`` and
    flat ``{"name": ...}``.
    """
    if not isinstance(tool, dict):
        return False
    func = tool.get("function")
    nested_name = func.get("name") if isinstance(func, dict) else None
    return nested_name == name or tool.get("name") == name


def resolve_tool_choice(
    tools: Optional[list],
    tool_choice: Optional[Any],
) -> tuple[Optional[list], Optional[str]]:
    """Apply tool_choice policy to the tools list.

    Args:
        tools: The original tools list from the request.
        tool_choice: ``"none"``, ``"auto"``, ``"required"``, or a dict
            naming a particular tool, either nested
            (``{"function": {"name": ...}}``) or flat (``{"name": ...}``).

    Returns:
        Tuple of ``(filtered_tools, system_instruction)``. The instruction
        is None unless a tool call is being forced. When a named tool is not
        in ``tools``, the full list is kept and the instruction is still
        returned.
    """
    if not tools or tool_choice is None or tool_choice == "auto":
        return tools, None

    if tool_choice == "none":
        return None, None

    if tool_choice == "required":
        return tools, "You must call one of the available tools to answer this request."

    if isinstance(tool_choice, dict):
        func = tool_choice.get("function")
        name = func.get("name") if isinstance(func, dict) else None
        if not name:
            # Flat form: the name sits at the top level, the same layout
            # _tool_matches accepts for tool definitions.
            name = tool_choice.get("name")
        if name:
            filtered = [t for t in tools if _tool_matches(t, name)]
            return (
                filtered or tools,
                f'You must call the "{name}" tool to answer this request.',
            )

    return tools, None
def _input_item_as_dict(item: Any) -> Optional[dict]:
    """Return ``item`` as a plain dict, or None if it cannot be interpreted.

    Dicts pass through. Objects exposing ``model_dump`` are dumped, so typed
    items take the same path as their dumped form. Any other object with a
    ``role`` attribute is reduced to its role and content.
    """
    if isinstance(item, dict):
        return item
    if hasattr(item, "model_dump"):
        dumped = item.model_dump()
        return dumped if isinstance(dumped, dict) else None
    if hasattr(item, "role"):
        return {"role": item.role, "content": getattr(item, "content", None)}
    return None


def _append_content_parts(
    parts: list,
    msg_role: str,
    chat_messages: list,
    images: list,
) -> None:
    """Fold one message's content parts into ``chat_messages`` / ``images``.

    Text parts are joined with newlines into a single ``msg_role`` message,
    appended after the loop. Image parts contribute their URL to ``images``.
    Each ``output_text`` part is appended immediately as an assistant message.
    Other part types (including ``input_audio``) are ignored.
    """
    text_parts = []
    for raw_part in parts:
        part = _input_item_as_dict(raw_part)
        if part is None:
            continue
        part_type = part.get("type", "")
        if part_type in ("input_text", "text"):
            # ``or ""`` keeps a None text from breaking the join below.
            text_parts.append(part.get("text") or "")
        elif part_type in ("input_image", "image_url"):
            ref = part.get("image_url", "")
            if isinstance(ref, dict):
                ref = ref.get("url", "")
            # Skip missing/empty references: callers count len(images), so
            # a placeholder "" would be counted as a real image.
            if isinstance(ref, str) and ref:
                images.append(ref)
        elif part_type == "output_text":
            # Multi-turn: previous assistant output.
            chat_messages.append(
                {"role": "assistant", "content": part.get("text", "")}
            )

    if text_parts:
        chat_messages.append({"role": msg_role, "content": "\n".join(text_parts)})


def responses_input_to_messages(
    input_items: Union[str, list],
    instructions: Optional[str] = None,
    previous_response_id: Optional[str] = None,
) -> tuple[list[dict], list[str]]:
    """Convert Responses API input items to chat messages and images.

    Args:
        input_items: String input or list of input items. Items may be dicts
            or objects; objects are converted to dicts first. ``None`` is
            treated as an empty list.
        instructions: Optional system instructions, inserted as the first
            message (ahead of any replayed context).
        previous_response_id: Optional previous response ID for context replay.

    Returns:
        Tuple of (chat_messages, image_urls).

    Raises:
        HTTPException: 404 when ``previous_response_id`` is given but the
            response store has no such entry.
    """
    chat_messages: list[dict] = []
    images: list[str] = []

    # Replay previous response context
    if previous_response_id:
        replayed = _responses_store.replay_input(previous_response_id)
        if replayed is None:
            raise HTTPException(
                status_code=404,
                detail=f"Previous response not found: {previous_response_id}",
            )
        # Replayed items are ordinary input items, so convert them the same way.
        prev_messages, prev_images = responses_input_to_messages(replayed)
        chat_messages.extend(prev_messages)
        images.extend(prev_images)

    # Prepend instructions as system message
    if instructions:
        chat_messages.insert(0, {"role": "system", "content": instructions})

    # Handle string input
    if isinstance(input_items, str):
        chat_messages.append({"role": "user", "content": input_items})
        return chat_messages, images

    for raw_item in input_items or []:
        item = _input_item_as_dict(raw_item)
        if item is None:
            continue

        item_type = item.get("type", "")

        # Function call output item
        if item_type == "function_call_output":
            chat_messages.append(
                {
                    "role": "tool",
                    "content": item.get("output", ""),
                    "tool_call_id": item.get("call_id", "unknown"),
                }
            )
            continue

        # Function call item (from previous assistant turn)
        if item_type == "function_call":
            chat_messages.append(
                {
                    "role": "assistant",
                    "content": None,
                    "tool_calls": [
                        {
                            "id": item.get("call_id", ""),
                            "type": "function",
                            "function": {
                                "name": item.get("name", ""),
                                "arguments": item.get("arguments", ""),
                            },
                        }
                    ],
                }
            )
            continue

        # Regular message with role and content; items without a role are skipped.
        role = item.get("role", "")
        if not role:
            continue

        # Normalize developer role to system
        msg_role = "system" if role == "developer" else role
        content = item.get("content", "")

        if isinstance(content, str):
            chat_messages.append({"role": msg_role, "content": content})
        elif isinstance(content, list):
            _append_content_parts(content, msg_role, chat_messages, images)
        else:
            chat_messages.append(
                {"role": msg_role, "content": str(content) if content else ""}
            )

    return chat_messages, images
- try : - response = client.responses.create( - model=model, - input=[ - {"role":"system", - "content": f"{system}" - }, - { - "role": "user", - "content": [ - {"type": "input_text", "text": prompt}, - {"type": "input_image", "image_url": f"{img_url}"}, - ], - } - ], - max_output_tokens=max_output_tokens, - stream=stream - ) - if not stream: - print(response.output[0].content[0].text) - print(response.usage) - else: - for event in response: - # Process different event types if needed - if hasattr(event, 'delta') and event.delta: - print(event.delta, end="", flush=True) - elif event.type == 'response.completed': - print("\n--- Usage ---") - print(event.response.usage) + Args: + raw_text: The raw text output from the model. + tool_parser_type: The detected tool parser type (e.g., "gemma4"), or None. + tool_module: The loaded tool parser module, or None. + tools: The tool definitions from the request, or None. - except Exception as e: - # building a response object to match the one returned when request is successful so that it can be processed in the same way - return {"model - error":str(e),"content":{}, "model":model} + Returns: + List of output items (message items and/or function call items). 
+ """ + output_items: list[Union[ResponseMessageItem, ResponseFunctionCallItem]] = [] + remaining_text = raw_text + # Try to parse tool calls + if tool_parser_type and tool_module and tools: + try: + result = process_tool_calls(raw_text, tool_module, tools) + if result["calls"]: + for call in result["calls"]: + func_info = call.get("function", {}) + output_items.append( + ResponseFunctionCallItem( + name=func_info.get("name", ""), + arguments=func_info.get("arguments", "{}"), + call_id=call.get("id", f"call_{uuid.uuid4().hex[:24]}"), + ) + ) + remaining_text = result.get("remaining_text", "").strip() + except Exception: + # If tool parsing fails, fall through to plain text + remaining_text = raw_text + + # Create message item for any remaining text + if remaining_text or not output_items: + msg_item = ResponseMessageItem( + content=( + [ResponseContentPartOutputText(text=remaining_text)] + if remaining_text + else [] + ), + ) + # Insert message before function calls (matching OpenAI ordering) + output_items.insert(0, msg_item) + + return output_items + + +# OpenAI compatible endpoints + + +@app.post("/responses") +@app.post("/v1/responses", include_in_schema=False) +async def responses_endpoint(request: ResponsesRequest): + """OpenAI-compatible Responses API endpoint. + + Supports tool calling, multi-turn via previous_response_id, and streaming + with proper SSE event sequences including function_call argument events. """ + # Resolve default max tokens if not specified in request + if request.max_output_tokens is None: + request.max_output_tokens = get_default_max_tokens() try: # Get model, processor, config - loading if necessary - model, processor, config = get_cached_model(openai_request.model) + model, processor, config = get_cached_model(request.model) + + # Debug: log incoming request details + _tools_count = len(request.tools) if request.tools else 0 + _tool_names = [t.get("name", t.get("function", {}).get("name", "?")) if isinstance(t, dict) else "?" 
for t in (request.tools or [])] + _input_len = len(str(request.input)) + _instructions_len = len(request.instructions) if request.instructions else 0 + if _verbose: + print(f"[responses] tools={_tools_count} names={_tool_names} stream={request.stream} input_chars={_input_len} instructions_chars={_instructions_len}") + + # Convert input to chat messages + chat_messages, images = responses_input_to_messages( + request.input, + instructions=request.instructions, + previous_response_id=request.previous_response_id, + ) - chat_messages = [] - images = [] - instructions = None - if openai_request.input: - if isinstance(openai_request.input, str): - # If input is a string, treat it as a single text message - chat_messages.append({"role": "user", "content": openai_request.input}) - elif isinstance(openai_request.input, list): - # If input is a list, treat it as a series of chat messages - for message in openai_request.input: - if isinstance(message, ChatMessage): - if message.content is None: - chat_messages.append({"role": message.role, "content": ""}) - elif isinstance(message.content, str): - chat_messages.append( - {"role": message.role, "content": message.content} - ) - if message.role == "system": - instructions = message.content - elif isinstance(message.content, list): - # Handle list of content items - for item in message.content: - if isinstance(item, dict): - if item["type"] == "input_text": - chat_messages.append( - { - "role": message.role, - "content": item["text"], - } - ) - if message.role == "system": - instructions = item["text"] - # examples for multiple images (https://platform.openai.com/docs/guides/images?api-mode=responses) - elif item["type"] == "input_image": - images.append(item["image_url"]) - else: - print( - f"invalid input item type: {item['type']}" - ) - raise HTTPException( - status_code=400, - detail="Invalid input item type.", - ) - else: - print( - f"Invalid message content item format: {item}" - ) - raise HTTPException( - 
status_code=400, - detail="Missing type in input item.", - ) - else: - print("Invalid message content format.") - raise HTTPException( - status_code=400, detail="Invalid input format." - ) - else: - print("not a ChatMessage") - raise HTTPException( - status_code=400, detail="Invalid input format." - ) - else: - print("neither string not list") - raise HTTPException(status_code=400, detail="Invalid input format.") + # Apply JSON mode if requested + response_format = getattr(request, "response_format", None) + chat_messages = resolve_response_format(chat_messages, response_format) - else: - print("no input") - raise HTTPException(status_code=400, detail="Missing input.") + # Set up tool parser (apply tool_choice policy) + tools = request.tools + tool_choice_val = getattr(request, "tool_choice", "auto") + tools, tool_instruction = resolve_tool_choice(tools, tool_choice_val) + if tool_instruction: + chat_messages.insert(0, {"role": "system", "content": tool_instruction}) + + tool_parser_type = None + tool_module = None + tokenizer = ( + processor.tokenizer if hasattr(processor, "tokenizer") else processor + ) + if hasattr(tokenizer, "chat_template") and tools: + tool_parser_type = _infer_tool_parser(tokenizer.chat_template) + if tool_parser_type is not None: + tool_module = load_tool_module(tool_parser_type) + if _verbose: + print(f"[responses] tool_parser={tool_parser_type} tool_module={'yes' if tool_module else 'no'} tools_after_choice={len(tools) if tools else 0}") + + # Build template kwargs + template_kwargs = request.template_kwargs() - template_kwargs = openai_request.template_kwargs() + # Apply chat template (pass tools so the template can include tool defs) formatted_prompt = apply_chat_template( processor, config, chat_messages, num_images=len(images), + tools=tools, **template_kwargs, ) - generation_kwargs = build_generation_kwargs(openai_request, template_kwargs) + generation_kwargs = build_generation_kwargs(request, template_kwargs) - generated_at = 
async def stream_responses_generator():
    """Yield the server-sent-event stream for one streaming Responses request.

    This is a closure: ``request``, ``model``, ``processor``,
    ``formatted_prompt``, ``images``, ``generation_kwargs``, ``tools``,
    ``tool_module``, ``tool_parser_type``, ``response_id``, ``message_id``
    and ``generated_at`` all come from the enclosing endpoint.

    Event order: ``response.created``, ``response.in_progress``,
    ``response.output_item.added`` and ``response.content_part.added`` for
    the message item, then one ``response.output_text.delta`` per piece of
    visible text, then the ``done`` events for the text, content part and
    message item. For every parsed tool call an added / arguments.delta /
    arguments.done / done group follows. The stream ends with
    ``response.completed`` and a ``data: [DONE]`` sentinel. The finished
    response is saved to ``_responses_store`` under ``response_id``.

    Any exception is logged and reported to the client as a generic error
    event. The generation semaphore is released in ``finally``.
    """
    seq = 0  # sequence_number counter shared by every event of this stream

    def _evt(event_type: str, event_obj) -> str:
        """Stamp ``event_obj`` with the next sequence number and frame it as SSE."""
        nonlocal seq
        event_obj.sequence_number = seq
        seq += 1
        return f"event: {event_type}\ndata: {event_obj.model_dump_json()}\n\n"

    def _text_delta_evt(text: str) -> str:
        """Build a ``response.output_text.delta`` event for the message item."""
        return _evt(
            "response.output_text.delta",
            ResponsesOutputTextDeltaEvent(
                item_id=message_id,
                output_index=0,
                content_index=0,
                delta=text,
            ),
        )

    def _held_back_len(text: str, tag: str) -> int:
        """Length of the longest proper prefix of ``tag`` that ``text`` ends with.

        That suffix may be the start of a tool-call tag split across chunks,
        so it must not be streamed yet. Returns 0 for an empty tag.
        """
        for size in range(min(len(tag) - 1, len(text)), 0, -1):
            if text.endswith(tag[:size]):
                return size
        return 0

    sem = None
    try:
        sem = await acquire_semaphore()
        # Build base ResponseObject (in_progress, empty output)
        base_response = ResponseObject(
            id=response_id,
            created_at=generated_at,
            status="in_progress",
            model=request.model,
            output=[],
            instructions=request.instructions,
            max_output_tokens=request.max_output_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            tools=tools or [],
            tool_choice=request.tool_choice,
            parallel_tool_calls=request.parallel_tool_calls,
            previous_response_id=request.previous_response_id,
            metadata=request.metadata,
            usage=ResponseUsage(input_tokens=0, output_tokens=0, total_tokens=0),
        )

        # response.created
        yield _evt(
            "response.created",
            ResponsesCreatedEvent(response=base_response),
        )
        # response.in_progress
        yield _evt(
            "response.in_progress",
            ResponsesInProgressEvent(response=base_response),
        )

        # output_item.added (message)
        msg_item = ResponseMessageItem(id=message_id, status="in_progress", content=[])
        yield _evt(
            "response.output_item.added",
            ResponsesOutputItemAddedEvent(output_index=0, item=msg_item),
        )

        # content_part.added
        empty_part = ResponseContentPartOutputText(text="")
        yield _evt(
            "response.content_part.added",
            ResponsesContentPartAddedEvent(
                item_id=message_id,
                output_index=0,
                content_index=0,
                part=empty_part,
            ),
        )

        cache_state = get_prompt_cache_state(
            request.model, getattr(request, "prompt_cache_key", None)
        )
        token_iterator = stream_generate(
            model=model,
            processor=processor,
            prompt=formatted_prompt,
            image=images,
            vision_cache=model_cache.get("vision_cache"),
            prompt_cache_state=cache_state,
            **generation_kwargs,
        )

        full_text = ""  # everything generated, tool-call markup included
        visible_text = ""  # what has actually been streamed as text
        streamed_upto = 0  # number of chars of full_text already streamed
        usage_stats = {"input_tokens": 0, "output_tokens": 0}
        in_tool_call = False
        # Empty when no tool parser is active. It must then be treated as
        # "no tag": `"" in full_text` is always True, which would flag a
        # tool call on the first chunk and suppress every text delta.
        tool_call_start_tag = (tool_module.tool_call_start if tool_module else "") or ""

        for chunk in token_iterator:
            if chunk is None or not hasattr(chunk, "text"):
                continue

            full_text += chunk.text
            usage_stats = {
                "input_tokens": chunk.prompt_tokens,
                "output_tokens": chunk.generation_tokens,
            }

            # Once a tool call has started, nothing further is streamed as text.
            if in_tool_call:
                continue

            if tool_call_start_tag:
                # Streamed text never contains a tag prefix (see hold-back
                # below), so the tag can only start at or after streamed_upto.
                tag_at = full_text.find(tool_call_start_tag, streamed_upto)
                if tag_at >= 0:
                    in_tool_call = True
                    safe_end = tag_at  # still stream text that precedes the tag
                else:
                    safe_end = len(full_text) - _held_back_len(
                        full_text[streamed_upto:], tool_call_start_tag
                    )
            else:
                safe_end = len(full_text)

            delta = full_text[streamed_upto:safe_end]
            streamed_upto = safe_end
            if not delta:
                continue

            visible_text += delta
            yield _text_delta_evt(delta)

        # A held-back suffix that never grew into a full tag is ordinary
        # text; flush it so it is not lost.
        if not in_tool_call and streamed_upto < len(full_text):
            tail = full_text[streamed_upto:]
            streamed_upto = len(full_text)
            visible_text += tail
            yield _text_delta_evt(tail)

        # Determine finish reason
        max_tok = request.max_output_tokens
        is_length = usage_stats["output_tokens"] >= max_tok
        status = "incomplete" if is_length else "completed"

        # Use visible_text (sans tool call markup) for text events
        display_text = visible_text

        # output_text.done
        yield _evt(
            "response.output_text.done",
            ResponsesOutputTextDoneEvent(
                item_id=message_id,
                output_index=0,
                content_index=0,
                text=display_text,
            ),
        )

        # content_part.done
        final_part = ResponseContentPartOutputText(text=display_text)
        yield _evt(
            "response.content_part.done",
            ResponsesContentPartDoneEvent(
                item_id=message_id,
                output_index=0,
                content_index=0,
                part=final_part,
            ),
        )

        # output_item.done (message)
        final_msg = ResponseMessageItem(
            id=message_id,
            status="completed",
            content=[final_part],
        )
        yield _evt(
            "response.output_item.done",
            ResponsesOutputItemDoneEvent(output_index=0, item=final_msg),
        )

        # Collect all output items for final response
        all_output_items: list = [final_msg]

        # Parse tool calls from accumulated text
        if tool_parser_type and tool_module and tools:
            try:
                tc_result = process_tool_calls(full_text, tool_module, tools)
                for call in tc_result["calls"] or []:
                    func_info = call.get("function", {})
                    fc_item = ResponseFunctionCallItem(
                        name=func_info.get("name", ""),
                        arguments=func_info.get("arguments", "{}"),
                        call_id=call.get("id", f"call_{uuid.uuid4().hex[:24]}"),
                    )
                    out_idx = len(all_output_items)

                    # output_item.added (function_call)
                    yield _evt(
                        "response.output_item.added",
                        ResponsesOutputItemAddedEvent(
                            output_index=out_idx, item=fc_item
                        ),
                    )

                    # function_call_arguments.delta (full arguments in one shot)
                    yield _evt(
                        "response.function_call_arguments.delta",
                        ResponsesFunctionCallArgumentsDeltaEvent(
                            item_id=fc_item.id,
                            output_index=out_idx,
                            delta=fc_item.arguments,
                        ),
                    )

                    # function_call_arguments.done
                    yield _evt(
                        "response.function_call_arguments.done",
                        ResponsesFunctionCallArgumentsDoneEvent(
                            item_id=fc_item.id,
                            output_index=out_idx,
                            arguments=fc_item.arguments,
                        ),
                    )

                    # output_item.done (function_call)
                    yield _evt(
                        "response.output_item.done",
                        ResponsesOutputItemDoneEvent(
                            output_index=out_idx, item=fc_item
                        ),
                    )

                    all_output_items.append(fc_item)
            except Exception as e:
                # Tool parsing failure is non-fatal in streaming, but it is
                # logged so a broken parser does not go unnoticed.
                print(f"Tool call parsing failed during streaming: {e}")

        # response.completed
        total_tokens = usage_stats["input_tokens"] + usage_stats["output_tokens"]
        completed_response = base_response.model_copy(
            update={
                "status": status,
                "output": all_output_items,
                "incomplete_details": (
                    ResponseIncompleteDetails(reason="max_output_tokens")
                    if status == "incomplete"
                    else None
                ),
                "usage": ResponseUsage(
                    input_tokens=usage_stats["input_tokens"],
                    output_tokens=usage_stats["output_tokens"],
                    total_tokens=total_tokens,
                ),
            }
        )
        yield _evt(
            "response.completed",
            ResponsesCompletedEvent(response=completed_response),
        )

        # Save to store for previous_response_id
        _responses_store.save(
            response_id,
            (
                request.input
                if isinstance(request.input, str)
                else [
                    (item.model_dump() if hasattr(item, "model_dump") else item)
                    for item in request.input
                ]
            ),
            [item.model_dump() for item in all_output_items],
        )

        # Final sentinel
        yield "data: [DONE]\n\n"

    except Exception as e:
        print(f"Error during stream generation: {e}")
        traceback.print_exc()
        # The client gets a generic message; details stay in the server log.
        error_data = json.dumps({"error": "Internal generation error"})
        yield f"data: {error_data}\n\n"

    finally:
        mx.clear_cache()
        gc.collect()
        if sem is not None:
            sem.release()
        if _verbose:
            print("Stream finished, cleared cache.")
prompt_cache_state=cache_state, **generation_kwargs, ) - # Clean up resources mx.clear_cache() gc.collect() - print("Generation finished, cleared cache.") + if _verbose: + print("Generation finished, cleared cache.") + + # Build output items (with tool call parsing) + output_items = build_responses_output( + result.text, + tool_parser_type, + tool_module, + tools, + ) + + # Determine status + is_length = result.generation_tokens >= request.max_output_tokens + status = "incomplete" if is_length else "completed" + incomplete_details = ( + ResponseIncompleteDetails(reason="max_output_tokens") + if status == "incomplete" + else None + ) - response = OpenAIResponse( + response_obj = ResponseObject( id=response_id, - object="response", - created_at=int(generated_at), - status="completed", - instructions=instructions, - max_output_tokens=openai_request.max_output_tokens, - model=openai_request.model, - output=[ - { - "role": "assistant", - "content": [ - { - "type": "output_text", - "text": result.text, - } - ], - } - ], - output_text=result.text, - temperature=openai_request.temperature, - top_p=openai_request.top_p, - usage={ - "input_tokens": result.prompt_tokens, - "output_tokens": result.generation_tokens, - "total_tokens": result.total_tokens, - }, + created_at=generated_at, + model=request.model, + output=output_items, + status=status, + incomplete_details=incomplete_details, + instructions=request.instructions, + max_output_tokens=request.max_output_tokens, + temperature=request.temperature, + top_p=request.top_p, + tools=tools or [], + tool_choice=request.tool_choice, + parallel_tool_calls=request.parallel_tool_calls, + previous_response_id=request.previous_response_id, + metadata=request.metadata, + usage=ResponseUsage( + input_tokens=result.prompt_tokens, + output_tokens=result.generation_tokens, + total_tokens=result.total_tokens, + input_tokens_details=InputTokensDetails( + cached_tokens=getattr(result, "cached_tokens", 0), + ), + ), + ) + + # Save to store 
for previous_response_id support + _responses_store.save( + response_obj.id, + ( + request.input + if isinstance(request.input, str) + else [ + item.model_dump() if hasattr(item, "model_dump") else item + for item in request.input + ] + ), + [item.model_dump() for item in output_items], ) - return response + + return response_obj.model_dump() except Exception as e: print(f"Error during generation: {e}") traceback.print_exc() mx.clear_cache() gc.collect() - raise HTTPException(status_code=500, detail=f"Generation failed: {e}") + raise HTTPException(status_code=500, detail="Generation failed. Check server logs for details.") + finally: + if sem is not None: + sem.release() - except HTTPException as http_exc: - # Re-raise HTTP exceptions (like model loading failure) - raise http_exc + except HTTPException: + raise except Exception as e: - # Catch unexpected errors print(f"Unexpected error in /responses endpoint: {e}") traceback.print_exc() mx.clear_cache() @@ -1058,6 +1799,9 @@ async def chat_completions_endpoint(request: ChatRequest): System message will be ignored if not already in the prompt. Can operate in streaming or non-streaming mode. 
""" + # Resolve default max tokens if not specified in request + if request.max_tokens is None: + request.max_tokens = get_default_max_tokens() try: # Get model, processor, config - loading if necessary @@ -1095,6 +1839,14 @@ async def chat_completions_endpoint(request: ChatRequest): if hasattr(request, "tools"): tools = request.tools + # Apply tool_choice policy + tool_choice = getattr(request, "tool_choice", None) + tools, tool_instruction = resolve_tool_choice(tools, tool_choice) + if tool_instruction: + processed_messages.insert( + 0, {"role": "system", "content": tool_instruction} + ) + tool_parser_type = None tokenizer = ( processor.tokenizer if hasattr(processor, "tokenizer") else processor @@ -1115,12 +1867,24 @@ async def chat_completions_endpoint(request: ChatRequest): ) generation_kwargs = build_generation_kwargs(request, template_kwargs) + check_context_length(formatted_prompt, processor, get_max_context_tokens()) + + # Resolve stop sequences to token IDs + stop_seqs = resolve_stop_sequences(getattr(request, "stop", None)) + if stop_seqs: + generation_kwargs["eos_tokens"] = stop_seqs + if request.stream: # Streaming response async def stream_generator(): + sem = None token_iterator = None try: - # Use stream_generate from utils + sem = await acquire_semaphore() + # Use stream_generate with prompt cache reuse + cache_state = get_prompt_cache_state( + request.model, getattr(request, "prompt_cache_key", None) + ) token_iterator = stream_generate( model=model, processor=processor, @@ -1128,14 +1892,17 @@ async def stream_generator(): image=images, audio=audio, vision_cache=model_cache.get("vision_cache"), + prompt_cache_state=cache_state, **generation_kwargs, ) output_text = "" request_id = f"chatcmpl-{uuid.uuid4()}" + want_logprobs = getattr(request, "logprobs", None) for chunk in token_iterator: if chunk is None or not hasattr(chunk, "text"): - print("Warning: Received unexpected chunk format:", chunk) + if _verbose: + print("Warning: Received unexpected 
chunk format:", chunk) continue output_text += chunk.text @@ -1151,9 +1918,24 @@ async def stream_generator(): "peak_memory": chunk.peak_memory, } + chunk_logprobs = None + if want_logprobs and chunk.token is not None and chunk.logprobs is not None: + token_text = tokenizer.decode([chunk.token]) + chosen_logprob = float(chunk.logprobs[chunk.token]) + chunk_logprobs = ChoiceLogprobs( + content=[ + TokenLogprob( + token=token_text, + logprob=chosen_logprob, + bytes=list(token_text.encode("utf-8")), + ) + ] + ) + choices = [ ChatStreamChoice( - delta=ChatMessage(role="assistant", content=chunk.text) + delta=ChatMessage(role="assistant", content=chunk.text), + logprobs=chunk_logprobs, ) ] chunk_data = ChatStreamChunk( @@ -1176,10 +1958,11 @@ async def stream_generator(): tool_calls = {} tool_calls["calls"] = [] - # Signal stream end + # Signal stream end with correct finish_reason + stream_finish = "tool_calls" if tool_calls.get("calls") else "stop" choices = [ ChatStreamChoice( - finish_reason="stop", + finish_reason=stream_finish, delta=ChatMessage( role="assistant", content="", @@ -1202,13 +1985,16 @@ async def stream_generator(): except Exception as e: print(f"Error during stream generation: {e}") traceback.print_exc() - error_data = json.dumps({"error": str(e)}) + error_data = json.dumps({"error": "Internal generation error"}) yield f"data: {error_data}\n\n" finally: mx.clear_cache() gc.collect() - print("Stream finished, cleared cache.") + if sem is not None: + sem.release() + if _verbose: + print("Stream finished, cleared cache.") return StreamingResponse( stream_generator(), @@ -1222,22 +2008,60 @@ async def stream_generator(): else: # Non-streaming response + sem = None try: - # Use generate from generate.py - gen_result = generate( - model=model, - processor=processor, - prompt=formatted_prompt, - image=images, - audio=audio, - verbose=False, # Keep API output clean - vision_cache=model_cache.get("vision_cache"), - **generation_kwargs, + sem = await 
acquire_semaphore() + want_logprobs = getattr(request, "logprobs", None) + cache_state = get_prompt_cache_state( + request.model, getattr(request, "prompt_cache_key", None) ) + token_logprobs = [] + + if want_logprobs: + # Use stream_generate to collect per-token logprobs + full_text = "" + gen_result = None + for chunk in stream_generate( + model=model, + processor=processor, + prompt=formatted_prompt, + image=images, + audio=audio, + vision_cache=model_cache.get("vision_cache"), + prompt_cache_state=cache_state, + **generation_kwargs, + ): + if chunk is None or not hasattr(chunk, "text"): + continue + full_text += chunk.text + if chunk.token is not None and chunk.logprobs is not None: + token_text = tokenizer.decode([chunk.token]) + chosen_logprob = float(chunk.logprobs[chunk.token]) + token_logprobs.append( + TokenLogprob( + token=token_text, + logprob=chosen_logprob, + bytes=list(token_text.encode("utf-8")), + ) + ) + gen_result = chunk + gen_result.text = full_text + else: + gen_result = generate( + model=model, + processor=processor, + prompt=formatted_prompt, + image=images, + audio=audio, + verbose=False, + vision_cache=model_cache.get("vision_cache"), + prompt_cache_state=cache_state, + **generation_kwargs, + ) + # Clean up resources mx.clear_cache() gc.collect() - print("Generation finished, cleared cache.") usage_stats = UsageStats( input_tokens=gen_result.prompt_tokens, @@ -1259,14 +2083,20 @@ async def stream_generator(): tool_calls["calls"] = [] tool_calls["remaining_text"] = gen_result.text + choice_logprobs = None + if want_logprobs and token_logprobs: + choice_logprobs = ChoiceLogprobs(content=token_logprobs) + + finish = "tool_calls" if tool_calls.get("calls") else "stop" choices = [ ChatChoice( - finish_reason="stop", + finish_reason=finish, message=ChatMessage( role="assistant", content=tool_calls["remaining_text"], tool_calls=tool_calls["calls"], ), + logprobs=choice_logprobs, ) ] @@ -1281,7 +2111,10 @@ async def stream_generator(): 
traceback.print_exc() mx.clear_cache() gc.collect() - raise HTTPException(status_code=500, detail=f"Generation failed: {e}") + raise HTTPException(status_code=500, detail="Generation failed. Check server logs for details.") + finally: + if sem is not None: + sem.release() except HTTPException as http_exc: # Re-raise HTTP exceptions (like model loading failure) @@ -1433,6 +2266,48 @@ def main(): default=DEFAULT_QUANTIZED_KV_START, help="Start index (of token) for the quantized KV cache.", ) + parser.add_argument( + "--max-concurrent-requests", + type=int, + default=1, + help="Maximum number of concurrent generation requests. " + "MLX runs single-threaded on Metal; values > 1 may cause GPU errors. " + "(default: %(default)s)", + ) + parser.add_argument( + "--max-context-tokens", + type=int, + default=0, + help="Maximum context window in tokens. 0 = no limit. (default: %(default)s)", + ) + parser.add_argument( + "--request-timeout", + type=int, + default=300, + help="Maximum seconds per generation request. (default: %(default)s)", + ) + parser.add_argument( + "--default-max-tokens", + type=int, + default=DEFAULT_MAX_TOKENS, + help="Default max tokens for API responses when not specified in the request. " + "The upstream default (256) is too low for agentic use. " + "(default: %(default)s)", + ) + parser.add_argument( + "--prompt-cache-ttl", + type=int, + default=DEFAULT_PROMPT_CACHE_TTL, + help="Seconds of idle time before a prompt cache entry is evicted. " + "Frees GPU memory from stale KV caches. 0 = no expiry. 
" + "(default: %(default)s)", + ) + parser.add_argument( + "--verbose", + action="store_true", + default=False, + help="Enable verbose debug logging (prompt content, cache state, tool detection).", + ) parser.add_argument( "--reload", action="store_true", @@ -1454,6 +2329,13 @@ def main(): os.environ["KV_QUANT_SCHEME"] = args.kv_quant_scheme os.environ["MAX_KV_SIZE"] = str(args.max_kv_size) os.environ["QUANTIZED_KV_START"] = str(args.quantized_kv_start) + os.environ["MAX_CONCURRENT_REQUESTS"] = str(args.max_concurrent_requests) + os.environ["MAX_CONTEXT_TOKENS"] = str(args.max_context_tokens) + os.environ["REQUEST_TIMEOUT"] = str(args.request_timeout) + os.environ["PROMPT_CACHE_TTL"] = str(args.prompt_cache_ttl) + os.environ["DEFAULT_MAX_TOKENS"] = str(args.default_max_tokens) + if args.verbose: + os.environ["VERBOSE"] = "1" uvicorn.run( "mlx_vlm.server:app", diff --git a/mlx_vlm/tests/test_responses_api.py b/mlx_vlm/tests/test_responses_api.py new file mode 100644 index 000000000..7ef21d588 --- /dev/null +++ b/mlx_vlm/tests/test_responses_api.py @@ -0,0 +1,1073 @@ +"""Tests for the OpenAI Responses API (/v1/responses) compliance. + +Covers: + A. Model validation (pure unit tests, no server/mlx needed) + B. Response store (pure unit tests) + C. Functional endpoint tests (TestClient, mocked model) + D. 
Streaming endpoint tests (TestClient, mocked model) +""" + +import importlib.util +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +# --------------------------------------------------------------------------- +# Helpers: load modules without triggering mlx_vlm.__init__ (no mlx needed) +# --------------------------------------------------------------------------- + + +def _load_module(name: str, filename: str): + """Load a sibling module by file path, bypassing package __init__.""" + mod_path = Path(__file__).parent.parent / filename + spec = importlib.util.spec_from_file_location(name, str(mod_path)) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +responses_models = _load_module("responses_models", "responses_models.py") +responses_store = _load_module("responses_store", "responses_store.py") + +ResponsesRequest = responses_models.ResponsesRequest +ResponseObject = responses_models.ResponseObject +ResponseMessageItem = responses_models.ResponseMessageItem +ResponseFunctionCallItem = responses_models.ResponseFunctionCallItem +ContentPartOutputText = responses_models.ContentPartOutputText +ResponseUsage = responses_models.ResponseUsage +FlexibleBaseModel = responses_models.FlexibleBaseModel +BaseStreamEvent = responses_models.BaseStreamEvent +ResponseStore = responses_store.ResponseStore + + +# ========================================================================= +# A. 
Model Validation Tests +# ========================================================================= + + +class TestResponsesModels: + """Pure unit tests for Pydantic models in responses_models.py.""" + + def test_responses_request_accepts_string_input(self): + req = ResponsesRequest(input="Hello", model="test-model") + assert req.input == "Hello" + + def test_responses_request_accepts_message_list(self): + msgs = [{"role": "user", "content": "hello"}] + req = ResponsesRequest(input=msgs, model="test-model") + assert isinstance(req.input, list) + assert len(req.input) == 1 + + def test_responses_request_accepts_tools(self): + tools = [ + { + "type": "function", + "name": "get_weather", + "description": "Get the weather", + "parameters": {"type": "object", "properties": {}}, + } + ] + req = ResponsesRequest(input="hi", model="m", tools=tools) + assert req.tools is not None + assert len(req.tools) == 1 + + def test_responses_request_default_tool_choice(self): + req = ResponsesRequest(input="hi", model="m") + assert req.tool_choice == "auto" + + def test_responses_request_generation_kwargs(self): + req = ResponsesRequest(input="hi", model="m", max_output_tokens=128) + kwargs = req.generation_kwargs() + assert "max_tokens" in kwargs + assert kwargs["max_tokens"] == 128 + assert "max_output_tokens" not in kwargs + + def test_response_object_output_text_computed(self): + msg = ResponseMessageItem( + content=[ + ContentPartOutputText(text="Hello "), + ContentPartOutputText(text="world!"), + ] + ) + resp = ResponseObject( + created_at=0, + model="m", + output=[msg], + usage=ResponseUsage(input_tokens=1, output_tokens=2, total_tokens=3), + ) + assert resp.output_text == "Hello world!" 
+ + def test_response_object_output_text_empty_when_only_function_calls(self): + fc = ResponseFunctionCallItem(name="fn", arguments='{"a":1}') + resp = ResponseObject( + created_at=0, + model="m", + output=[fc], + usage=ResponseUsage(input_tokens=1, output_tokens=2, total_tokens=3), + ) + assert resp.output_text == "" + + def test_function_call_item_auto_ids(self): + fc = ResponseFunctionCallItem(name="fn", arguments="{}") + assert fc.id.startswith("fc_") + assert fc.call_id.startswith("call_") + # IDs should be unique per instance + fc2 = ResponseFunctionCallItem(name="fn", arguments="{}") + assert fc.id != fc2.id + + def test_function_call_item_schema(self): + fc = ResponseFunctionCallItem(name="get_weather", arguments='{"city":"NYC"}') + assert fc.name == "get_weather" + assert fc.arguments == '{"city":"NYC"}' + assert fc.type == "function_call" + + def test_content_part_output_text_defaults(self): + part = ContentPartOutputText() + assert part.type == "output_text" + assert part.text == "" + assert part.annotations == [] + + def test_streaming_event_sequence_number(self): + evt = BaseStreamEvent(type="test.event", sequence_number=42) + assert evt.sequence_number == 42 + evt_default = BaseStreamEvent(type="test.event") + assert evt_default.sequence_number == 0 + + def test_flexible_base_model_accepts_unknown_fields(self): + req = ResponsesRequest(input="hi", model="m", some_unknown_field="surprise") + # Should not raise; extra field accessible via model_extra + assert req.model_extra.get("some_unknown_field") == "surprise" + + +# ========================================================================= +# B. 
Response Store Tests +# ========================================================================= + + +class TestResponseStore: + """Pure unit tests for the LRU ResponseStore.""" + + def test_store_save_and_get(self): + store = ResponseStore() + store.save("resp_1", "hello", [{"type": "message"}]) + entry = store.get("resp_1") + assert entry is not None + assert entry["input"] == "hello" + assert entry["output"] == [{"type": "message"}] + + def test_store_get_missing_returns_none(self): + store = ResponseStore() + assert store.get("resp_nonexistent") is None + + def test_store_lru_eviction(self): + store = ResponseStore(maxsize=2) + store.save("resp_a", "a", []) + store.save("resp_b", "b", []) + store.save("resp_c", "c", []) # should evict resp_a + assert store.get("resp_a") is None + assert store.get("resp_b") is not None + assert store.get("resp_c") is not None + + def test_store_replay_string_input(self): + store = ResponseStore() + store.save("resp_1", "hello", []) + items = store.replay_input("resp_1") + assert items is not None + assert len(items) == 1 + assert items[0]["role"] == "user" + assert items[0]["content"] == "hello" + + def test_store_replay_message_list_input(self): + original = [ + {"role": "user", "content": "What is 2+2?"}, + {"role": "system", "content": "You are helpful."}, + ] + store = ResponseStore() + store.save("resp_1", original, []) + items = store.replay_input("resp_1") + assert items is not None + assert len(items) == 2 + assert items[0]["role"] == "user" + assert items[1]["role"] == "system" + + def test_store_replay_function_call_output(self): + output = [ + { + "type": "function_call", + "call_id": "call_123", + "name": "get_weather", + "arguments": '{"city":"NYC"}', + } + ] + store = ResponseStore() + store.save("resp_1", "hello", output) + items = store.replay_input("resp_1") + assert items is not None + # First item is the original user input, second is the function call + fc_items = [i for i in items if i.get("type") == 
"function_call"] + assert len(fc_items) == 1 + assert fc_items[0]["name"] == "get_weather" + assert fc_items[0]["call_id"] == "call_123" + + def test_store_replay_missing_returns_none(self): + store = ResponseStore() + assert store.replay_input("resp_nope") is None + + def test_store_clear(self): + store = ResponseStore() + store.save("resp_1", "a", []) + store.save("resp_2", "b", []) + assert len(store) == 2 + store.clear() + assert len(store) == 0 + assert store.get("resp_1") is None + + +# ========================================================================= +# C. Functional Endpoint Tests (require mlx for server import) +# ========================================================================= + +# Guard: skip functional/streaming tests if mlx is unavailable, but let +# the pure-unit tests above run on any platform. +_has_mlx = importlib.util.find_spec("mlx") is not None + +if _has_mlx: + from fastapi.testclient import TestClient # noqa: E402 + + import mlx_vlm.server as server # noqa: E402 + +_skip_no_mlx = pytest.mark.skipif(not _has_mlx, reason="mlx not installed") + + +# Shared mock objects (safe to create even without mlx) +mock_model = MagicMock() +mock_processor = MagicMock() +mock_processor.tokenizer = MagicMock() +mock_processor.tokenizer.chat_template = "" +mock_config = SimpleNamespace(model_type="test") + + +def _mock_result(text="Hello world!", prompt_tokens=10, gen_tokens=5): + """Build a SimpleNamespace matching generate() return shape.""" + return SimpleNamespace( + text=text, + prompt_tokens=prompt_tokens, + generation_tokens=gen_tokens, + total_tokens=prompt_tokens + gen_tokens, + prompt_tps=100.0, + generation_tps=50.0, + peak_memory=1.0, + ) + + +@pytest.fixture +def client(): + with TestClient(server.app) as c: + yield c + + +def _patch_model(): + return patch.object( + server, + "get_cached_model", + return_value=(mock_model, mock_processor, mock_config), + ) + + +def _patch_template(): + return patch.object(server, 
"apply_chat_template", return_value="prompt") + + +def _patch_generate(result=None): + if result is None: + result = _mock_result() + return patch.object(server, "generate", return_value=result) + + +@_skip_no_mlx +class TestResponsesEndpoint: + """Functional tests for POST /responses.""" + + def test_basic_text_response(self, client): + with _patch_model(), _patch_template(), _patch_generate(): + resp = client.post( + "/responses", + json={"model": "demo", "input": "Hello"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["object"] == "response" + assert data["status"] == "completed" + assert "id" in data + assert "output" in data + assert "usage" in data + + def test_response_with_message_list(self, client): + with _patch_model(), _patch_template(), _patch_generate(): + resp = client.post( + "/responses", + json={ + "model": "demo", + "input": [{"role": "user", "content": "hello"}], + }, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "completed" + + def test_instructions_field_echoed(self, client): + with _patch_model(), _patch_template(), _patch_generate(): + resp = client.post( + "/responses", + json={ + "model": "demo", + "input": [ + {"role": "system", "content": "Be brief."}, + {"role": "user", "content": "hi"}, + ], + "instructions": "Be brief.", + }, + ) + assert resp.status_code == 200 + data = resp.json() + # The instructions field should be present in the response + assert data.get("instructions") is not None + + def test_tools_field_echoed(self, client): + tools = [ + { + "type": "function", + "name": "get_weather", + "description": "Get weather", + "parameters": {"type": "object", "properties": {}}, + } + ] + with _patch_model(), _patch_template(), _patch_generate(): + resp = client.post( + "/responses", + json={"model": "demo", "input": "hi", "tools": tools}, + ) + assert resp.status_code == 200 + + def test_previous_response_id_not_found(self, client): + """Referencing a non-existent 
previous_response_id should return an error.""" + with _patch_model(), _patch_template(), _patch_generate(): + resp = client.post( + "/responses", + json={ + "model": "demo", + "input": "follow-up", + "previous_response_id": "resp_nonexistent999", + }, + ) + # The server should either 404 or 200 (ignoring unknown ID). + # We just verify it doesn't crash with a 500. + assert resp.status_code in (200, 404) + + def test_developer_role_mapped_to_system(self, client): + """developer role should be accepted (mapped to system internally).""" + with _patch_model(), _patch_template(), _patch_generate(): + resp = client.post( + "/responses", + json={ + "model": "demo", + "input": [ + {"role": "developer", "content": "You are helpful."}, + {"role": "user", "content": "hi"}, + ], + }, + ) + # Should not crash; accept 200 or 422 if server rejects developer role + assert resp.status_code in (200, 422) + + def test_text_type_alias(self, client): + """'text' type should be accepted alongside 'input_text'.""" + with _patch_model(), _patch_template(), _patch_generate(): + resp = client.post( + "/responses", + json={ + "model": "demo", + "input": [ + { + "role": "user", + "content": [{"type": "input_text", "text": "hi"}], + } + ], + }, + ) + assert resp.status_code == 200 + + def test_function_call_output_input(self, client): + """function_call_output items in input should not crash the server.""" + with _patch_model(), _patch_template(), _patch_generate(): + resp = client.post( + "/responses", + json={ + "model": "demo", + "input": [ + {"role": "user", "content": "call a tool"}, + { + "type": "function_call_output", + "call_id": "call_abc", + "output": '{"result": 42}', + }, + ], + }, + ) + # May fail if server doesn't handle function_call_output yet; + # accept anything except unhandled 500 + assert resp.status_code in (200, 400, 422) + + def test_max_output_tokens_incomplete(self, client): + """When finish_reason is 'length', the response status should ideally be 'incomplete'.""" 
+ result = _mock_result(text="truncated...") + with _patch_model(), _patch_template(), _patch_generate(result): + resp = client.post( + "/responses", + json={ + "model": "demo", + "input": "Write a very long essay", + "max_output_tokens": 5, + }, + ) + assert resp.status_code == 200 + # Just verify the response is well-formed + data = resp.json() + assert data["status"] in ("completed", "incomplete") + + +@_skip_no_mlx +class TestResponsesStreaming: + """Streaming SSE tests for POST /responses with stream=true.""" + + def _stream_events(self, client, payload): + """Helper: POST with stream=True and collect SSE events.""" + with _patch_model(), _patch_template(): + # Mock stream_generate to yield chunks + chunks = [ + SimpleNamespace(text="Hello", prompt_tokens=10, generation_tokens=1), + SimpleNamespace(text=" world", prompt_tokens=10, generation_tokens=2), + ] + + def mock_stream_gen(**kwargs): + return iter(chunks) + + with patch.object(server, "stream_generate", side_effect=mock_stream_gen): + resp = client.post("/responses", json=payload) + return resp + + def test_streaming_sse_events(self, client): + payload = {"model": "demo", "input": "Hello", "stream": True} + resp = self._stream_events(client, payload) + assert resp.status_code == 200 + body = resp.text + # Should contain key event types + assert "event: response.created" in body + assert "event: response.output_text.delta" in body + assert "event: response.completed" in body + + def test_streaming_done_sentinel(self, client): + """The stream should end properly (response.completed is the last real event).""" + payload = {"model": "demo", "input": "Hello", "stream": True} + resp = self._stream_events(client, payload) + assert resp.status_code == 200 + body = resp.text + # The last meaningful event should be response.completed + lines = [l for l in body.strip().split("\n") if l.startswith("event:")] + assert lines[-1] == "event: response.completed" + + +# 
# =========================================================================
# E. Prompt Cache Tests
# =========================================================================


@_skip_no_mlx
class TestPromptCache:
    """Verify prompt_cache_state is wired into all generation entry points.

    The endpoint tests replace ``server.generate`` / ``server.stream_generate``
    with recorders, so the ``prompt_cache_state`` keyword handed over by the
    server can be inspected without running a model.
    """

    @staticmethod
    def _recording_generate(seen):
        """Build a ``generate`` stand-in that logs ``prompt_cache_state``.

        Args:
            seen: List receiving the ``prompt_cache_state`` keyword of every
                call (``None`` when the keyword is absent).

        Returns:
            A callable taking arbitrary keywords and returning
            ``_mock_result()``.
        """

        def fake_generate(**kwargs):
            seen.append(kwargs.get("prompt_cache_state"))
            return _mock_result()

        return fake_generate

    @staticmethod
    def _recording_stream(seen):
        """Build a ``stream_generate`` stand-in that logs ``prompt_cache_state``.

        Args:
            seen: List receiving the ``prompt_cache_state`` keyword of every
                call (``None`` when the keyword is absent).

        Returns:
            A callable taking arbitrary keywords and returning an iterator
            over a single one-token chunk.
        """

        def fake_stream(**kwargs):
            seen.append(kwargs.get("prompt_cache_state"))
            return iter(
                [SimpleNamespace(text="Hi", prompt_tokens=5, generation_tokens=1)]
            )

        return fake_stream

    def _post_recording(self, client, path, payloads, *, stream=False):
        """POST each payload to *path* while recording the cache state.

        Args:
            client: Test client fixture used to issue the requests.
            path: Endpoint path, e.g. ``"/responses"``.
            payloads: JSON bodies, posted in order under one set of patches.
            stream: Patch ``server.stream_generate`` instead of
                ``server.generate``.

        Returns:
            ``(responses, seen)``: the responses in request order and the
            recorded ``prompt_cache_state`` values in call order.
        """
        seen = []
        if stream:
            target, fake = "stream_generate", self._recording_stream(seen)
        else:
            target, fake = "generate", self._recording_generate(seen)
        with (
            _patch_model(),
            _patch_template(),
            patch.object(server, target, side_effect=fake),
        ):
            responses = [client.post(path, json=body) for body in payloads]
        return responses, seen

    def test_responses_non_streaming_passes_cache_state(self, client):
        """Non-streaming /responses should pass prompt_cache_state to generate."""
        (resp,), seen = self._post_recording(
            client, "/responses", [{"model": "demo", "input": "hi"}]
        )
        assert resp.status_code == 200
        # Explicit message instead of an IndexError if generate never ran.
        assert seen, "generate was never called"
        assert seen[-1] is not None
        assert hasattr(seen[-1], "find_prefix_length")

    def test_responses_streaming_passes_cache_state(self, client):
        """Streaming /responses should pass prompt_cache_state to stream_generate."""
        (resp,), seen = self._post_recording(
            client,
            "/responses",
            [{"model": "demo", "input": "hi", "stream": True}],
            stream=True,
        )
        assert resp.status_code == 200
        assert seen, "stream_generate was never called"
        assert seen[-1] is not None

    def test_chat_completions_non_streaming_passes_cache_state(self, client):
        """Non-streaming /chat/completions should pass prompt_cache_state."""
        (resp,), seen = self._post_recording(
            client,
            "/chat/completions",
            [
                {
                    "model": "demo",
                    "messages": [{"role": "user", "content": "hello"}],
                }
            ],
        )
        assert resp.status_code == 200
        assert seen, "generate was never called"
        assert seen[-1] is not None

    def test_chat_completions_streaming_passes_cache_state(self, client):
        """Streaming /chat/completions should pass prompt_cache_state."""
        (resp,), seen = self._post_recording(
            client,
            "/chat/completions",
            [
                {
                    "model": "demo",
                    "messages": [{"role": "user", "content": "hello"}],
                    "stream": True,
                }
            ],
            stream=True,
        )
        assert resp.status_code == 200
        assert seen, "stream_generate was never called"
        assert seen[-1] is not None

    def test_cache_state_persists_across_requests(self, client):
        """The same PromptCacheState should be reused for the same model."""
        _, states = self._post_recording(
            client,
            "/responses",
            [
                {"model": "demo", "input": "first"},
                {"model": "demo", "input": "second"},
            ],
        )

        assert len(states) == 2
        assert states[0] is states[1], "Same model should reuse the same cache state"

    def test_cache_state_isolated_per_model(self, client):
        """Different models should get different PromptCacheState instances."""
        # The states are read from the server's store directly rather than
        # from the arguments passed to generate.
        with (
            _patch_model(),
            _patch_template(),
            patch.object(server, "generate", side_effect=self._recording_generate([])),
        ):
            client.post("/responses", json={"model": "model-a", "input": "hi"})
            state_a = server.get_prompt_cache_state("model-a")
            client.post("/responses", json={"model": "model-b", "input": "hi"})
            state_b = server.get_prompt_cache_state("model-b")

        assert (
            state_a is not state_b
        ), "Different models must have separate cache states"

    def test_cache_state_has_correct_interface(self, client):
        """PromptCacheState should expose find_prefix_length and update methods."""
        state = server.get_prompt_cache_state("test-model")
        for attribute in ("find_prefix_length", "update", "cache", "token_ids"):
            assert hasattr(state, attribute)
        # Initially empty
        assert state.cache is None
        assert state.token_ids is None
        assert state.find_prefix_length([1, 2, 3]) == 0

    def test_cache_state_cleared_on_unload(self, client):
        """Unloading a model should clear all prompt cache states."""
        server._prompt_cache_states["some-model"] = server.PromptCacheState()
        assert "some-model" in server._prompt_cache_states
        # Simulate unload
        with patch.object(server, "model_cache", {"model_path": "x"}):
            server.unload_model_sync()
        assert len(server._prompt_cache_states) == 0

    def test_cache_state_prefix_matching(self):
        """PromptCacheState.find_prefix_length should find common prefix."""
        state = server.PromptCacheState()
        state.token_ids = [10, 20, 30, 40, 50]
        expectations = [
            ([10, 20, 30, 40, 50], 5),  # identical sequence
            ([10, 20, 30, 99, 50], 3),  # diverges at index 3
            ([99, 20, 30], 0),  # diverges immediately
            ([10, 20], 2),  # new ids shorter than cached ids
            ([], 0),  # empty input
        ]
        for new_ids, expected in expectations:
            assert state.find_prefix_length(new_ids) == expected
# =========================================================================
# F. Concurrency Guard Tests
# =========================================================================


@_skip_no_mlx
class TestConcurrencyGuard:
    """Verify concurrency guard serializes Metal GPU access."""

    @staticmethod
    def _fresh_semaphore_value(limit=None):
        """Return ``_value`` of a semaphore created under a controlled env.

        The cached ``server._generation_semaphore`` is dropped before and
        after the call, and ``os.environ`` is restored afterwards, so a
        failing assertion in the caller cannot leak state into other tests.

        Args:
            limit: String to export as ``MAX_CONCURRENT_REQUESTS``; ``None``
                removes the variable for the duration of the call.

        Returns:
            The ``_value`` attribute of the newly created semaphore.
        """
        import os

        server._generation_semaphore = None
        try:
            # patch.dict snapshots os.environ and restores it on exit.
            with patch.dict(os.environ):
                if limit is None:
                    os.environ.pop("MAX_CONCURRENT_REQUESTS", None)
                else:
                    os.environ["MAX_CONCURRENT_REQUESTS"] = limit
                return server.get_generation_semaphore()._value
        finally:
            server._generation_semaphore = None

    def test_semaphore_exists_and_is_semaphore(self):
        """get_generation_semaphore should return an asyncio.Semaphore."""
        import asyncio

        sem = server.get_generation_semaphore()
        assert isinstance(sem, asyncio.Semaphore)

    def test_semaphore_default_value_is_one(self):
        """Default semaphore should allow exactly 1 concurrent request."""
        assert self._fresh_semaphore_value() == 1

    def test_semaphore_respects_env_var(self):
        """MAX_CONCURRENT_REQUESTS env var should configure semaphore value."""
        assert self._fresh_semaphore_value("3") == 3

    def test_semaphore_singleton(self):
        """Repeated calls should return the same semaphore instance."""
        server._generation_semaphore = None
        try:
            first = server.get_generation_semaphore()
            second = server.get_generation_semaphore()
            assert first is second
        finally:
            # Reset even when the assertion fails.
            server._generation_semaphore = None

    def test_responses_non_streaming_acquires_semaphore(self, client):
        """Non-streaming /responses should acquire and release the semaphore."""
        acquired = []
        released = []
        real_sem = server.get_generation_semaphore()

        # Keep the real bound methods so the wrappers still perform the
        # actual acquire/release after recording the call.
        original_acquire = real_sem.acquire
        original_release = real_sem.release

        async def mock_acquire():
            acquired.append(True)
            return await original_acquire()

        def mock_release():
            released.append(True)
            return original_release()

        with (
            _patch_model(),
            _patch_template(),
            _patch_generate(),
            patch.object(real_sem, "acquire", side_effect=mock_acquire),
            patch.object(real_sem, "release", side_effect=mock_release),
        ):
            resp = client.post("/responses", json={"model": "demo", "input": "hi"})

        assert resp.status_code == 200
        assert len(acquired) >= 1, "Semaphore should be acquired"
        assert len(released) >= 1, "Semaphore should be released"

    def test_concurrent_requests_both_succeed(self, client):
        """Two back-to-back requests should both succeed.

        The requests are issued sequentially, so this only shows that the
        first request does not prevent the second from completing.
        """
        with _patch_model(), _patch_template(), _patch_generate():
            r1 = client.post("/responses", json={"model": "demo", "input": "first"})
            r2 = client.post("/responses", json={"model": "demo", "input": "second"})
        assert r1.status_code == 200
        assert r2.status_code == 200


# =========================================================================
# G. finish_reason Tests
# =========================================================================


@_skip_no_mlx
class TestFinishReason:
    """Verify finish_reason is set correctly based on tool call detection."""

    @staticmethod
    def _search_tools():
        """Return the ``tools`` request field declaring one ``search`` function."""
        return [
            {
                "type": "function",
                "function": {"name": "search", "parameters": {}},
            }
        ]

    @staticmethod
    def _search_call_result():
        """Return a ``process_tool_calls`` result holding one ``search`` call."""
        return {
            "calls": [
                {
                    "type": "function",
                    "id": "c1",
                    "function": {"name": "search", "arguments": "{}"},
                }
            ],
            "remaining_text": "",
        }

    def test_chat_completions_finish_reason_stop_no_tools(self, client):
        """finish_reason='stop' when no tools provided."""
        with _patch_model(), _patch_template(), _patch_generate():
            resp = client.post(
                "/chat/completions",
                json={"model": "demo", "messages": [{"role": "user", "content": "hi"}]},
            )
        assert resp.status_code == 200
        assert resp.json()["choices"][0]["finish_reason"] == "stop"

    def test_chat_completions_finish_reason_tool_calls(self, client):
        """finish_reason='tool_calls' when tool calls detected."""
        fake_calls = self._search_call_result()
        with (
            _patch_model(),
            _patch_template(),
            _patch_generate(),
            patch.object(server, "_infer_tool_parser", return_value="qwen3_coder"),
            patch.object(server, "load_tool_module", return_value=SimpleNamespace()),
            patch.object(server, "process_tool_calls", return_value=fake_calls),
        ):
            resp = client.post(
                "/chat/completions",
                json={
                    "model": "demo",
                    "messages": [{"role": "user", "content": "search"}],
                    "tools": self._search_tools(),
                },
            )
        assert resp.status_code == 200
        assert resp.json()["choices"][0]["finish_reason"] == "tool_calls"

    def test_chat_completions_finish_reason_stop_tools_no_calls(self, client):
        """finish_reason='stop' when tools defined but model doesn't call any."""
        no_calls = {"calls": [], "remaining_text": "Just text, no tools."}
        with (
            _patch_model(),
            _patch_template(),
            _patch_generate(),
            patch.object(server, "_infer_tool_parser", return_value="qwen3_coder"),
            patch.object(server, "load_tool_module", return_value=SimpleNamespace()),
            patch.object(server, "process_tool_calls", return_value=no_calls),
        ):
            resp = client.post(
                "/chat/completions",
                json={
                    "model": "demo",
                    "messages": [{"role": "user", "content": "hello"}],
                    "tools": self._search_tools(),
                },
            )
        assert resp.status_code == 200
        assert resp.json()["choices"][0]["finish_reason"] == "stop"

    def test_chat_completions_streaming_finish_reason_tool_calls(self, client):
        """Streaming finish_reason should be 'tool_calls' when tools detected."""
        import json as json_mod

        fake_calls = self._search_call_result()
        chunks = [
            SimpleNamespace(
                text="calling",
                prompt_tokens=10,
                generation_tokens=1,
                prompt_tps=100.0,
                generation_tps=50.0,
                peak_memory=1.0,
            ),
        ]

        def mock_stream(**kwargs):
            return iter(chunks)

        with (
            _patch_model(),
            _patch_template(),
            patch.object(server, "stream_generate", side_effect=mock_stream),
            patch.object(server, "_infer_tool_parser", return_value="qwen3_coder"),
            patch.object(server, "load_tool_module", return_value=SimpleNamespace()),
            patch.object(server, "process_tool_calls", return_value=fake_calls),
        ):
            resp = client.post(
                "/chat/completions",
                json={
                    "model": "demo",
                    "messages": [{"role": "user", "content": "search"}],
                    "tools": self._search_tools(),
                    "stream": True,
                },
            )
        assert resp.status_code == 200

        data_lines = [
            line
            for line in resp.text.strip().split("\n")
            if line.startswith("data:") and "[DONE]" not in line
        ]
        assert data_lines, "stream produced no data events"
        # Strip only the leading SSE field name; str.replace would also
        # rewrite any "data: " occurring inside the JSON payload itself.
        last_data = json_mod.loads(data_lines[-1][len("data:") :])
        assert last_data["choices"][0]["finish_reason"] == "tool_calls"

    def test_responses_status_completed_no_tools(self, client):
        """Responses endpoint status='completed' for normal text."""
        with _patch_model(), _patch_template(), _patch_generate():
            resp = client.post("/responses", json={"model": "demo", "input": "hi"})
        assert resp.status_code == 200
        assert resp.json()["status"] == "completed"
# =========================================================================
# H. JSON Mode Tests
# =========================================================================


@_skip_no_mlx
class TestJsonMode:
    """Verify response_format parameter handling."""

    def test_resolve_response_format_json_adds_instruction(self):
        """``json_object`` yields a leading system message that mentions JSON."""
        conversation = [{"role": "user", "content": "hi"}]
        resolved = server.resolve_response_format(
            conversation, {"type": "json_object"}
        )
        leading = resolved[0]
        assert leading["role"] == "system"
        assert "json" in leading["content"].lower()

    def test_resolve_response_format_text_no_change(self):
        """A ``text`` format leaves a one-message conversation at one message."""
        conversation = [{"role": "user", "content": "hi"}]
        resolved = server.resolve_response_format(conversation, {"type": "text"})
        assert len(resolved) == 1

    def test_resolve_response_format_none_no_change(self):
        """No format at all leaves a one-message conversation at one message."""
        conversation = [{"role": "user", "content": "hi"}]
        resolved = server.resolve_response_format(conversation, None)
        assert len(resolved) == 1

    def test_responses_json_mode_accepted(self, client):
        """/responses accepts a ``response_format`` of type ``json_object``."""
        body = {
            "model": "demo",
            "input": "Give me JSON",
            "response_format": {"type": "json_object"},
        }
        with _patch_model(), _patch_template(), _patch_generate():
            resp = client.post("/responses", json=body)
        assert resp.status_code == 200
# =========================================================================
# I. Context Tracking Tests
# =========================================================================


@_skip_no_mlx
class TestContextTracking:
    """Verify context window tracking and OOM prevention."""

    @staticmethod
    def _processor_with_token_count(count):
        """Return a stub processor whose tokenizer always yields *count* tokens.

        The stub exposes only ``tokenizer.encode(s, add_special_tokens=False)``,
        which ignores its input and returns ``list(range(count))``.
        """
        return SimpleNamespace(
            tokenizer=SimpleNamespace(
                encode=lambda s, add_special_tokens=False: list(range(count))
            ),
        )

    def test_check_context_length_within_limit(self):
        """A 10-token prompt under a limit of 100 must not raise."""
        server.check_context_length(
            "short", self._processor_with_token_count(10), 100
        )

    def test_check_context_length_exceeds_limit(self):
        """A 200-token prompt over a limit of 100 raises HTTP 400."""
        from fastapi import HTTPException as _Exc

        with pytest.raises(_Exc) as exc_info:
            server.check_context_length(
                "long", self._processor_with_token_count(200), 100
            )
        assert exc_info.value.status_code == 400

    def test_check_context_length_zero_unlimited(self):
        """A limit of 0 must not raise, even without a processor."""
        server.check_context_length("anything", None, 0)

    def test_get_max_context_tokens_default(self):
        """Without MAX_CONTEXT_TOKENS in the environment the result is 0."""
        import os

        # patch.dict restores os.environ on exit, so a value exported by the
        # developer's shell is not permanently removed by this test.
        with patch.dict(os.environ):
            os.environ.pop("MAX_CONTEXT_TOKENS", None)
            assert server.get_max_context_tokens() == 0

    def test_get_max_context_tokens_from_env(self):
        """MAX_CONTEXT_TOKENS=16384 is returned as the integer 16384."""
        import os

        # The variable is removed again even when the assertion fails.
        with patch.dict(os.environ, {"MAX_CONTEXT_TOKENS": "16384"}):
            assert server.get_max_context_tokens() == 16384
# =========================================================================
# J. Request Cancellation Tests
# =========================================================================


@_skip_no_mlx
class TestRequestCancellation:
    """Verify request timeout and cancellation support."""

    def test_get_request_timeout_default(self):
        """Without REQUEST_TIMEOUT in the environment the timeout is 300."""
        import os

        # patch.dict restores os.environ on exit, so a value exported by the
        # developer's shell is not permanently removed by this test.
        with patch.dict(os.environ):
            os.environ.pop("REQUEST_TIMEOUT", None)
            assert server.get_request_timeout() == 300

    def test_get_request_timeout_from_env(self):
        """REQUEST_TIMEOUT=60 is returned as the integer 60."""
        import os

        # The variable is removed again even when the assertion fails.
        with patch.dict(os.environ, {"REQUEST_TIMEOUT": "60"}):
            assert server.get_request_timeout() == 60

    def test_streaming_cleanup_on_normal_completion(self, client):
        """Streaming should complete normally and clean up."""
        chunks = [
            SimpleNamespace(
                text="Hi",
                prompt_tokens=5,
                generation_tokens=1,
                prompt_tps=100.0,
                generation_tps=50.0,
                peak_memory=1.0,
            ),
        ]
        with (
            _patch_model(),
            _patch_template(),
            patch.object(server, "stream_generate", return_value=iter(chunks)),
        ):
            resp = client.post(
                "/responses",
                json={"model": "demo", "input": "hi", "stream": True},
            )
        assert resp.status_code == 200
        assert "response.completed" in resp.text
# =========================================================================
# K. Prompt Cache Key Routing Tests
# =========================================================================


@_skip_no_mlx
class TestPromptCacheKeyRouting:
    """Verify prompt_cache_key routes to separate cache states."""

    def test_same_cache_key_same_state(self):
        """Same model + cache_key should return the same PromptCacheState."""
        s1 = server.get_prompt_cache_state("model-a", "session-1")
        s2 = server.get_prompt_cache_state("model-a", "session-1")
        assert s1 is s2

    def test_different_cache_key_different_state(self):
        """Different cache_keys should get isolated cache states."""
        s1 = server.get_prompt_cache_state("model-a", "session-1")
        s2 = server.get_prompt_cache_state("model-a", "session-2")
        assert s1 is not s2

    def test_no_cache_key_falls_back_to_model(self):
        """No cache_key should fall back to model-only keying."""
        s1 = server.get_prompt_cache_state("model-b")
        s2 = server.get_prompt_cache_state("model-b", None)
        assert s1 is s2

    def test_cache_key_passed_from_request(self, client):
        """prompt_cache_key from request should route to correct cache."""
        captured = {}

        def capture_gen(**kwargs):
            captured["state"] = kwargs.get("prompt_cache_state")
            return _mock_result()

        with (
            _patch_model(),
            _patch_template(),
            patch.object(server, "generate", side_effect=capture_gen),
        ):
            resp = client.post(
                "/responses",
                json={
                    "model": "demo",
                    "input": "hi",
                    "prompt_cache_key": "my-session",
                },
            )
        # Check the request itself first: a rejected request or a generate
        # that never ran would otherwise surface as a bare KeyError below.
        assert resp.status_code == 200
        assert "state" in captured, "generate was never called"
        assert captured["state"] is server.get_prompt_cache_state("demo", "my-session")
# ---------------------------------------------------------------------------
# Stop sequences tests
# ---------------------------------------------------------------------------


def _post_with_stub_generate(client, path, payload, *, text):
    """POST *payload* to *path* with model loading and generation stubbed.

    ``server.get_cached_model``, ``server.apply_chat_template`` and
    ``server.generate`` are patched for the duration of the request;
    ``generate`` returns a canned result whose ``text`` is *text*.

    Returns:
        ``(response, generate_mock)`` so the caller can inspect the keyword
        arguments ``generate`` was called with.
    """
    cached_model = (
        SimpleNamespace(),
        SimpleNamespace(tokenizer=SimpleNamespace(chat_template="")),
        SimpleNamespace(model_type="test"),
    )
    canned = SimpleNamespace(
        text=text,
        prompt_tokens=5,
        generation_tokens=1,
        total_tokens=6,
        prompt_tps=100.0,
        generation_tps=50.0,
        peak_memory=1.0,
    )
    with (
        patch.object(server, "get_cached_model", return_value=cached_model),
        patch.object(server, "apply_chat_template", return_value="prompt"),
        patch.object(server, "generate", return_value=canned) as generate_mock,
    ):
        response = client.post(path, json=payload)
    return response, generate_mock


def test_chat_completions_stop_passed_as_eos_tokens(client):
    """stop parameter should be passed as eos_tokens strings to generate."""
    resp, mock_gen = _post_with_stub_generate(
        client,
        "/chat/completions",
        {
            "model": "demo",
            "messages": [{"role": "user", "content": "hello"}],
            "stop": ["\n\n", ""],
        },
        text="Hello",
    )
    assert resp.status_code == 200
    forwarded = mock_gen.call_args.kwargs
    assert "eos_tokens" in forwarded
    assert forwarded["eos_tokens"] == ["\n\n", ""]


def test_chat_completions_no_stop_no_eos_tokens(client):
    """Without stop parameter, eos_tokens should not be in kwargs."""
    resp, mock_gen = _post_with_stub_generate(
        client,
        "/chat/completions",
        {"model": "demo", "messages": [{"role": "user", "content": "hi"}]},
        text="Hi",
    )
    assert resp.status_code == 200
    assert "eos_tokens" not in mock_gen.call_args.kwargs


def test_responses_stop_passed_as_eos_tokens(client):
    """stop parameter on /responses should pass strings as eos_tokens."""
    resp, mock_gen = _post_with_stub_generate(
        client,
        "/responses",
        {"model": "demo", "input": "hi", "stop": "STOP"},
        text="Hello",
    )
    assert resp.status_code == 200
    forwarded = mock_gen.call_args.kwargs
    assert "eos_tokens" in forwarded
    assert forwarded["eos_tokens"] == ["STOP"]


def test_resolve_stop_sequences_single_string():
    """resolve_stop_sequences should normalize a single string to a list."""
    assert server.resolve_stop_sequences("hello") == ["hello"]


def test_resolve_stop_sequences_list():
    """resolve_stop_sequences should pass through a list."""
    assert server.resolve_stop_sequences(["a", "b"]) == ["a", "b"]


def test_resolve_stop_sequences_none():
    """resolve_stop_sequences should return None for None input."""
    resolved = server.resolve_stop_sequences(None)
    assert resolved is None


def test_resolve_stop_sequences_limits_to_four():
    """resolve_stop_sequences should process at most 4 sequences."""
    six_sequences = ["a", "b", "c", "d", "e", "f"]
    assert len(server.resolve_stop_sequences(six_sequences)) == 4
#!/usr/bin/env python3
"""Benchmark models for agentic use on bastion.

Usage: python3 scripts/benchmark_models.py [--models model1,model2,...]

Tests: generation speed, prefill speed, tool calling, code quality,
and instruction following.

The server address defaults to the bastion host and can be overridden with
the ``MLX_VLM_BASE_URL`` environment variable.
"""

import argparse
import json
import os
import sys
import time

import requests

# Base URL of the mlx-vlm server; the default is the original hard-coded host.
BASE = os.environ.get("MLX_VLM_BASE_URL", "http://100.106.192.127:8080")

MODELS = [
    "mlx-community/Qwen3.5-35B-A3B-4bit",
    "inferencerlabs/Qwen3.5-35B-A3B-MLX-5.5bit",
    "unsloth/gemma-4-26b-a4b-it-UD-MLX-4bit",
]

TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "web_search",
            "description": "Search the web",
            "parameters": {
                "type": "object",
                "properties": {"query": {"type": "string"}},
                "required": ["query"],
            },
        },
    }
]


def api_call(model, input_text, max_tokens=50, tools=None, temperature=0, timeout=300):
    """POST one request to ``/v1/responses`` and time it.

    Args:
        model: Model name sent in the request body.
        input_text: Prompt sent as ``input``.
        max_tokens: Value sent as ``max_output_tokens``.
        tools: Optional tool definitions; omitted from the body when falsy.
        temperature: Sampling temperature.
        timeout: ``requests`` timeout in seconds.

    Returns:
        ``(data, elapsed_seconds)``. On HTTP 200 ``data`` is the decoded JSON
        body; otherwise it is a dict with a non-``None`` ``"error"`` entry
        (the status code, or the exception text for transport and decoding
        failures).
    """
    payload = {
        "model": model,
        "input": input_text,
        "max_output_tokens": max_tokens,
        "temperature": temperature,
    }
    if tools:
        payload["tools"] = tools

    # perf_counter is monotonic; time.time() can jump if the clock is adjusted.
    started = time.perf_counter()
    try:
        r = requests.post(f"{BASE}/v1/responses", json=payload, timeout=timeout)
        elapsed = time.perf_counter() - started
        if r.status_code == 200:
            return r.json(), elapsed
        return {"error": r.status_code, "detail": r.text[:100]}, elapsed
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a 200 response whose body is not valid JSON.
        return {"error": str(e)}, time.perf_counter() - started


def _failed(data):
    """Return True when *data* carries a non-``None`` ``"error"`` entry.

    Testing the value rather than key presence keeps a successful response
    that merely includes ``"error": null`` from being counted as a failure.
    """
    return data.get("error") is not None


def _message_text(data):
    """Concatenate the ``text`` parts of every ``message`` item in ``output``."""
    return "".join(
        part.get("text", "")
        for item in data.get("output", [])
        if item.get("type") == "message"
        for part in item.get("content", [])
    )


def test_generation_speed(model):
    """Measure tok/s on a simple prompt.

    Returns:
        ``(tokens_per_second, input_tokens)``, or ``(None, None)`` on failure.
        The rate is output tokens divided by the whole request's wall time.
    """
    prompt = "Write a detailed story about a wizard. " * 5
    d, elapsed = api_call(model, prompt, max_tokens=200, temperature=0.7)
    if _failed(d):
        return None, None
    u = d.get("usage", {})
    out = u.get("output_tokens", 0)
    tps = out / elapsed if elapsed > 0 else 0
    return tps, u.get("input_tokens", 0)


def test_prefill_speed(model):
    """Measure prefill tok/s on a large prompt.

    Returns:
        Input tokens divided by the request's wall time, or ``None`` on failure.
    """
    prompt = "Fox dog cat bird fish. " * 2000  # ~12K tokens
    d, elapsed = api_call(model, prompt + " Answer: yes.", max_tokens=5)
    if _failed(d):
        return None
    inp = d.get("usage", {}).get("input_tokens", 0)
    return inp / elapsed if elapsed > 0 else 0


def test_tool_calling(model):
    """Test structured tool call generation.

    Returns:
        True when the response output contains a ``function_call`` item.
    """
    d, _ = api_call(
        model,
        "Search the web for Python 3.14 release date",
        max_tokens=100,
        tools=TOOLS,
    )
    if _failed(d):
        return False
    return any(i.get("type") == "function_call" for i in d.get("output", []))


def test_code_quality(model):
    """Test code generation quality.

    Returns:
        A score from 0 to 3: one point each for a ``def`` mentioning
        "binary", a return annotation, and a docstring.
    """
    d, _ = api_call(
        model,
        "Write a Python function implementing binary search with type hints and docstring. Only code, no explanation.",
        max_tokens=300,
    )
    if _failed(d):
        return 0
    text = _message_text(d)
    score = 0
    if "def " in text and "binary" in text.lower():
        score += 1
    if "->" in text:
        score += 1  # type hints
    if '"""' in text or "'''" in text:
        score += 1  # docstring
    return score


def test_instruction_following(model):
    """Test precise instruction following.

    Returns:
        A score from 0 to 3: one point per requested token found in the reply.
    """
    d, _ = api_call(
        model,
        "Do exactly these 3 things:\n1. Output ALPHA\n2. Output 42\n3. Output OMEGA\n\nOne per line, no other text.",
        max_tokens=30,
    )
    if _failed(d):
        return 0
    text = _message_text(d)
    return sum(token in text for token in ("ALPHA", "42", "OMEGA"))


def run_benchmarks(models):
    """Run every benchmark against each model and print a summary table."""
    results = {}

    for model in models:
        print(f"\n{'='*60}")
        print(f" {model}")
        print(f"{'='*60}")

        r = {}

        # Gen speed
        print(" Generation speed...", end=" ", flush=True)
        tps, _ = test_generation_speed(model)
        r["gen_tps"] = tps
        # Compare against None: a measured rate of 0.0 is not a failed request.
        print(f"{tps:.1f} tok/s" if tps is not None else "FAILED")

        # Prefill speed
        print(" Prefill speed...", end=" ", flush=True)
        prefill = test_prefill_speed(model)
        r["prefill_tps"] = prefill
        print(f"{prefill:.0f} tok/s" if prefill is not None else "FAILED")

        # Tool calling
        print(" Tool calling...", end=" ", flush=True)
        tools_ok = test_tool_calling(model)
        r["tool_call"] = tools_ok
        print("YES" if tools_ok else "NO")

        # Code quality
        print(" Code quality...", end=" ", flush=True)
        code = test_code_quality(model)
        r["code_quality"] = code
        print(f"{code}/3")

        # Instruction following
        print(" Instruction following...", end=" ", flush=True)
        inst = test_instruction_following(model)
        r["instruction"] = inst
        print(f"{inst}/3")

        results[model] = r

    # Summary table
    print(f"\n{'='*80}")
    print(f" {'Model':<45} {'Gen':>7} {'Prefill':>8} {'Tools':>6} {'Code':>5} {'Inst':>5}")
    print(f" {'-'*45} {'-'*7} {'-'*8} {'-'*6} {'-'*5} {'-'*5}")
    for model, r in results.items():
        name = model.split("/")[-1][:44]
        gen = f"{r['gen_tps']:.0f}" if r["gen_tps"] is not None else "FAIL"
        pre = f"{r['prefill_tps']:.0f}" if r["prefill_tps"] is not None else "FAIL"
        tools = "YES" if r["tool_call"] else "NO"
        code = f"{r['code_quality']}/3"
        inst = f"{r['instruction']}/3"
        print(f" {name:<45} {gen:>7} {pre:>8} {tools:>6} {code:>5} {inst:>5}")


def main(argv=None):
    """Parse command-line options and run the benchmarks.

    Args:
        argv: Argument list to parse; ``None`` uses ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="Benchmark MLX-VLM models")
    parser.add_argument(
        "--models",
        type=str,
        default=None,
        help="Comma-separated list of model names to benchmark.",
    )
    args = parser.parse_args(argv)
    # Tolerate spaces and stray commas in the list, e.g. "a, b,".
    requested = [m.strip() for m in (args.models or "").split(",") if m.strip()]
    run_benchmarks(requested or MODELS)


if __name__ == "__main__":
    main()
#!/usr/bin/env bash
# =============================================================================
# OpenClaw Integration Validation Script
#
# Validates the mlx-vlm server works end-to-end through OpenClaw on bastion.
# Requires: ssh access to bastion, OpenClaw gateway running, mlx-vlm server up.
#
# Usage: ./scripts/validate_oc.sh [bastion-host]
#
# Environment:
#   MLX_VLM_BASE_URL  Base URL of the mlx-vlm server
#                     (default: http://100.106.192.127:8080)
# =============================================================================

set -euo pipefail

HOST="${1:-bastion}"
BASE_URL="${MLX_VLM_BASE_URL:-http://100.106.192.127:8080}"
PASS=0
FAIL=0
SKIP=0
RESULTS=()

# Shell-quoted copy of HOST for splicing into the command strings below,
# which are re-parsed by `bash -c`.
printf -v HOST_Q '%q' "$HOST"

# Colors. ANSI-C quoting ($'...') stores the real escape byte, so the values
# render with both printf '%s' and plain echo. They are blanked when stdout
# is not a terminal so redirected logs stay free of escape sequences.
if [ -t 1 ]; then
    GREEN=$'\033[0;32m'
    RED=$'\033[0;31m'
    YELLOW=$'\033[0;33m'
    NC=$'\033[0m'
else
    GREEN=''
    RED=''
    YELLOW=''
    NC=''
fi

# run_test NAME CMD CHECK
#   Runs CMD with `bash -c`, then feeds its combined stdout/stderr to CHECK
#   (also run with `bash -c`) on stdin. The test passes when both exit 0.
#   Updates PASS/FAIL and appends one line to RESULTS.
run_test() {
    local name="$1"
    local cmd="$2"
    local check="$3"

    printf " %-55s " "$name"

    local output
    if ! output=$(bash -c "$cmd" 2>&1); then
        printf '%sFAIL%s (command error)\n' "$RED" "$NC"
        FAIL=$((FAIL + 1))
        RESULTS+=("FAIL: $name — command error")
        return
    fi

    # A here-string is used instead of a pipe: under `pipefail`, a check that
    # exits before reading all input (e.g. grep -q) could kill the producer
    # with SIGPIPE and turn a passing check into a failure.
    if bash -c "$check" <<<"$output" > /dev/null 2>&1; then
        printf '%sPASS%s\n' "$GREEN" "$NC"
        PASS=$((PASS + 1))
        RESULTS+=("PASS: $name")
    else
        printf '%sFAIL%s\n' "$RED" "$NC"
        FAIL=$((FAIL + 1))
        RESULTS+=("FAIL: $name — check failed")
        printf ' Output: %s\n' "$(head -n 3 <<<"$output")"
    fi
}

# run_api_test NAME ENDPOINT PAYLOAD CHECK
#   POSTs the JSON PAYLOAD to BASE_URL + ENDPOINT with curl and hands the
#   response body to run_test together with CHECK.
run_api_test() {
    local name="$1"
    local endpoint="$2"
    local payload="$3"
    local check="$4"

    # %q quotes the URL and payload for the `bash -c` in run_test, so a
    # payload containing a single quote cannot break the command line.
    local cmd
    printf -v cmd "curl -s --max-time 90 %q -H 'Content-Type: application/json' -d %q" \
        "${BASE_URL}${endpoint}" "$payload"
    run_test "$name" "$cmd" "$check"
}

echo "============================================================"
echo " MLX-VLM + OpenClaw Integration Validation"
echo "============================================================"
echo ""

# --- Connectivity checks ---
echo "Connectivity:"
run_test "SSH to bastion" \
    "ssh -o ConnectTimeout=10 $HOST_Q 'echo ok'" \
    "grep -q ok"

run_test "mlx-vlm server health" \
    "curl -s --max-time 10 ${BASE_URL}/health" \
    "python3 -c \"import sys,json; d=json.load(sys.stdin); assert d.get('status')=='healthy'\""

run_test "mlx-vlm models endpoint" \
    "curl -s --max-time 10 ${BASE_URL}/v1/models" \
    "python3 -c \"import sys,json; d=json.load(sys.stdin); assert len(d['data'])>0\""

run_test "OpenClaw gateway running" \
    "ssh $HOST_Q 'launchctl list | grep openclaw'" \
    "grep -q openclaw"

run_test "Telegram channel active" \
    "ssh $HOST_Q 'export PATH=/opt/homebrew/bin:\$PATH && openclaw channels list'" \
    "grep -qi telegram"

echo ""

# --- Responses API tests ---
echo "Responses API (/v1/responses):"

run_api_test "Basic text response" \
    "/v1/responses" \
    '{"model":"mlx-community/Qwen3.5-35B-A3B-4bit","input":"Say hi in one word","max_output_tokens":10}' \
    "python3 -c \"import sys,json; d=json.load(sys.stdin); assert d['status']=='completed'; assert len(d['output'])>0\""

run_api_test "Response has correct schema" \
    "/v1/responses" \
    '{"model":"mlx-community/Qwen3.5-35B-A3B-4bit","input":"Say hello","max_output_tokens":10}' \
    "python3 -c \"import sys,json; d=json.load(sys.stdin); assert 'id' in d; assert d['object']=='response'; assert 'usage' in d\""

run_api_test "Tools accepted without 422" \
    "/v1/responses" \
    '{"model":"mlx-community/Qwen3.5-35B-A3B-4bit","input":"What is 2+2? Answer briefly.","max_output_tokens":50,"tools":[{"type":"function","function":{"name":"calc","description":"Calculator","parameters":{"type":"object","properties":{"expr":{"type":"string"}}}}}]}' \
    "python3 -c \"import sys,json; d=json.load(sys.stdin); assert d['status'] in ('completed','incomplete'); assert 'output' in d\""

run_api_test "Tools echoed in response" \
    "/v1/responses" \
    '{"model":"mlx-community/Qwen3.5-35B-A3B-4bit","input":"hi","max_output_tokens":10,"tools":[{"type":"function","function":{"name":"test","parameters":{}}}]}' \
    "python3 -c \"import sys,json; d=json.load(sys.stdin); assert len(d.get('tools',[]))>0\""

run_api_test "Instructions field works" \
    "/v1/responses" \
    '{"model":"mlx-community/Qwen3.5-35B-A3B-4bit","input":"What are you?","instructions":"You are a pirate. Respond in pirate speak.","max_output_tokens":50}' \
    "python3 -c \"import sys,json; d=json.load(sys.stdin); assert d['status'] in ('completed','incomplete'); assert d.get('instructions') is not None\""

run_api_test "Stop sequences accepted" \
    "/v1/responses" \
    '{"model":"mlx-community/Qwen3.5-35B-A3B-4bit","input":"Count to 10","max_output_tokens":50,"stop":["5"]}' \
    "python3 -c \"import sys,json; d=json.load(sys.stdin); assert d['status']=='completed'\""

echo ""

# --- Chat Completions API tests ---
echo "Chat Completions (/v1/chat/completions):"

run_api_test "Basic chat response" \
    "/v1/chat/completions" \
    '{"model":"mlx-community/Qwen3.5-35B-A3B-4bit","messages":[{"role":"user","content":"Say hi"}],"max_tokens":10}' \
    "python3 -c \"import sys,json; d=json.load(sys.stdin); assert len(d['choices'])>0; assert d['choices'][0]['finish_reason']=='stop'\""

run_api_test "Chat with tools" \
    "/v1/chat/completions" \
    '{"model":"mlx-community/Qwen3.5-35B-A3B-4bit","messages":[{"role":"user","content":"hi"}],"max_tokens":10,"tools":[{"type":"function","function":{"name":"test","parameters":{}}}]}' \
    "python3 -c \"import sys,json; d=json.load(sys.stdin); assert d['choices'][0]['finish_reason'] in ('stop','tool_calls')\""

echo ""

# --- Streaming tests ---
echo "Streaming:"

run_api_test "Responses streaming has SSE events" \
    "/v1/responses" \
    '{"model":"mlx-community/Qwen3.5-35B-A3B-4bit","input":"Say hi","max_output_tokens":10,"stream":true}' \
    "grep -q 'event: response.completed'"

run_api_test "Streaming ends with [DONE]" \
    "/v1/responses" \
    '{"model":"mlx-community/Qwen3.5-35B-A3B-4bit","input":"Say hi","max_output_tokens":10,"stream":true}' \
    "grep -q 'DONE'"

echo ""

# --- OpenClaw agent tests ---
echo "OpenClaw Agent (end-to-end):"

run_test "OC agent basic response" \
    "ssh $HOST_Q 'export PATH=/opt/homebrew/bin:\$PATH && openclaw agent --agent main -m \"Say hello in one word\" --json --timeout 90'" \
    "python3 -c \"import sys,json; d=json.load(sys.stdin); assert d['status']=='ok'; assert d['result']['meta']['stopReason']=='stop'\""

run_test "OC agent with web search" \
    "ssh $HOST_Q 'export PATH=/opt/homebrew/bin:\$PATH && openclaw agent --agent main -m \"What is the current weather in Kansas City? Use web search.\" --json --timeout 120'" \
    "python3 -c \"import sys,json; d=json.load(sys.stdin); assert d['status']=='ok'\""

echo ""

# --- Summary ---
TOTAL=$((PASS + FAIL + SKIP))
echo "============================================================"
# printf with %s arguments: plain `echo` does not interpret backslash escapes,
# and keeping variables out of the format string avoids stray % expansion.
printf ' Results: %s%d passed%s, %s%d failed%s, %s%d skipped%s / %d total\n' \
    "$GREEN" "$PASS" "$NC" "$RED" "$FAIL" "$NC" "$YELLOW" "$SKIP" "$NC" "$TOTAL"
echo "============================================================"

if (( FAIL > 0 )); then
    echo ""
    echo "Failures:"
    for r in "${RESULTS[@]}"; do
        if [[ "$r" == FAIL* ]]; then
            echo " $r"
        fi
    done
    exit 1
fi