feat: add request cancellation and timeout support#20
Conversation
…oints Detect client disconnects during SSE streaming via request.is_disconnected() and asyncio.CancelledError, closing the token iterator to stop generation. For non-streaming requests, wrap generation in asyncio.wait_for with a configurable --request-timeout (default 300s) returning HTTP 504 on timeout. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
There was a problem hiding this comment.
Pull request overview
Adds request cancellation handling for streaming responses and introduces a configurable timeout for non-streaming generation requests (plus a couple of new request/CLI behaviors around response formatting and prompt length limits).
Changes:
- Detects client disconnects during SSE streaming for
/responsesand/chat/completionsand aborts token iteration. - Adds non-streaming request timeout support (
--request-timeout, default 300s) and tests for 504 responses on timeout. - Introduces
response_formathandling (JSON-only instruction injection) and a--max-context-tokensprompt-length enforcement.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 11 comments.
| File | Description |
|---|---|
mlx_vlm/server.py |
Adds disconnect checks for streaming, timeout logic for non-streaming via executor + wait_for, and introduces response_format + --max-context-tokens enforcement. |
mlx_vlm/tests/test_server.py |
Adds tests for timeout/env-default behavior (but includes a couple of test-isolation/lint issues). |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
|
||
|
|
||
| def get_request_timeout(): | ||
| return int(os.environ.get("REQUEST_TIMEOUT", DEFAULT_REQUEST_TIMEOUT)) |
There was a problem hiding this comment.
get_request_timeout() directly casts the REQUEST_TIMEOUT env var to int; if the env var is set to a non-integer value (e.g. "300s"), this will raise ValueError and crash the request handler. Consider parsing defensively (try/except) and falling back to DEFAULT_REQUEST_TIMEOUT or raising a clear HTTPException/config error.
| return int(os.environ.get("REQUEST_TIMEOUT", DEFAULT_REQUEST_TIMEOUT)) | |
| request_timeout = os.environ.get("REQUEST_TIMEOUT") | |
| if request_timeout is None: | |
| return DEFAULT_REQUEST_TIMEOUT | |
| try: | |
| return int(request_timeout) | |
| except (TypeError, ValueError): | |
| return DEFAULT_REQUEST_TIMEOUT |
| # Use generate from generate.py, with request timeout | ||
| timeout = get_request_timeout() | ||
| loop = asyncio.get_event_loop() | ||
| try: | ||
| result = await asyncio.wait_for( | ||
| loop.run_in_executor( | ||
| None, | ||
| lambda: generate( | ||
| model=model, | ||
| processor=processor, | ||
| prompt=formatted_prompt, | ||
| image=images, | ||
| verbose=False, | ||
| **generation_kwargs, | ||
| ), | ||
| ), | ||
| timeout=timeout, | ||
| ) | ||
| except asyncio.TimeoutError: | ||
| print(f"[cancellation] /responses generation timed out after {timeout}s.") | ||
| mx.clear_cache() | ||
| gc.collect() | ||
| raise HTTPException( | ||
| status_code=504, | ||
| detail=f"Generation timed out after {timeout} seconds.", | ||
| ) |
There was a problem hiding this comment.
Using asyncio.wait_for() around loop.run_in_executor() will time out the awaiting task, but it will not stop the underlying generate() call running in the thread. That means timed-out requests may continue consuming compute, and mx.clear_cache()/gc.collect() here can run concurrently with generation, risking undefined behavior or crashes. Consider implementing cooperative cancellation/timeout inside generate(), or running generation in a separate process that can be terminated, and avoid clearing MLX caches until the generation task has actually stopped.
| ) | ||
| # Use generate from generate.py, with request timeout | ||
| timeout = get_request_timeout() | ||
| loop = asyncio.get_event_loop() |
There was a problem hiding this comment.
asyncio.get_event_loop() inside an async endpoint is deprecated-style usage on modern Python; prefer asyncio.get_running_loop() to reliably retrieve the active loop and avoid warnings/behavior changes across Python versions.
| loop = asyncio.get_event_loop() | |
| loop = asyncio.get_running_loop() |
| # Use generate from generate.py, with request timeout | ||
| timeout = get_request_timeout() | ||
| loop = asyncio.get_event_loop() | ||
| try: | ||
| gen_result = await asyncio.wait_for( | ||
| loop.run_in_executor( | ||
| None, | ||
| lambda: generate( | ||
| model=model, | ||
| processor=processor, | ||
| prompt=formatted_prompt, | ||
| image=images, | ||
| audio=audio, | ||
| verbose=False, | ||
| vision_cache=model_cache.get("vision_cache"), | ||
| **generation_kwargs, | ||
| ), | ||
| ), | ||
| timeout=timeout, | ||
| ) | ||
| except asyncio.TimeoutError: | ||
| print(f"[cancellation] /chat/completions generation timed out after {timeout}s.") | ||
| mx.clear_cache() | ||
| gc.collect() | ||
| raise HTTPException( | ||
| status_code=504, | ||
| detail=f"Generation timed out after {timeout} seconds.", | ||
| ) |
There was a problem hiding this comment.
Same issue as /responses: wait_for() around run_in_executor() won't stop the background generate() thread on timeout, and clearing MX caches while that thread may still be running is unsafe. Consider cooperative cancellation in generate() or isolating generation in a killable worker (e.g., separate process), and only clear caches after the generation has actually ended.
| ) | ||
| # Use generate from generate.py, with request timeout | ||
| timeout = get_request_timeout() | ||
| loop = asyncio.get_event_loop() |
There was a problem hiding this comment.
asyncio.get_event_loop() inside this async endpoint should be replaced with asyncio.get_running_loop() to avoid deprecated behavior and ensure the currently running loop is used.
| loop = asyncio.get_event_loop() | |
| loop = asyncio.get_running_loop() |
| def test_get_request_timeout_default(): | ||
| """Default timeout should be 300 seconds when env var is unset.""" | ||
| os.environ.pop("REQUEST_TIMEOUT", None) | ||
| assert server.get_request_timeout() == 300 | ||
|
|
||
|
|
||
| def test_get_request_timeout_from_env(): | ||
| """REQUEST_TIMEOUT env var should override the default.""" | ||
| os.environ["REQUEST_TIMEOUT"] = "60" | ||
| try: | ||
| assert server.get_request_timeout() == 60 | ||
| finally: | ||
| os.environ.pop("REQUEST_TIMEOUT", None) |
There was a problem hiding this comment.
This test forces REQUEST_TIMEOUT to "60" and then unsets it, which does not restore a pre-existing value if one was set before the test run. Use monkeypatch (auto-restores) or capture/restore the previous env var value to keep tests isolated.
| def test_get_request_timeout_default(): | |
| """Default timeout should be 300 seconds when env var is unset.""" | |
| os.environ.pop("REQUEST_TIMEOUT", None) | |
| assert server.get_request_timeout() == 300 | |
| def test_get_request_timeout_from_env(): | |
| """REQUEST_TIMEOUT env var should override the default.""" | |
| os.environ["REQUEST_TIMEOUT"] = "60" | |
| try: | |
| assert server.get_request_timeout() == 60 | |
| finally: | |
| os.environ.pop("REQUEST_TIMEOUT", None) | |
| def test_get_request_timeout_default(monkeypatch): | |
| """Default timeout should be 300 seconds when env var is unset.""" | |
| monkeypatch.delenv("REQUEST_TIMEOUT", raising=False) | |
| assert server.get_request_timeout() == 300 | |
| def test_get_request_timeout_from_env(monkeypatch): | |
| """REQUEST_TIMEOUT env var should override the default.""" | |
| monkeypatch.setenv("REQUEST_TIMEOUT", "60") | |
| assert server.get_request_timeout() == 60 |
| import asyncio | ||
|
|
There was a problem hiding this comment.
Unused import: asyncio is imported here but never used. This will be flagged by the repo's pre-commit autoflake hook (remove-all-unused-imports). Remove the import to keep the test file clean.
| import asyncio |
| parser.add_argument( | ||
| "--max-context-tokens", | ||
| type=int, | ||
| default=0, | ||
| help="Maximum context window in tokens. Requests exceeding this are rejected. " | ||
| "0 means no limit. (default: %(default)s)", | ||
| ) | ||
| parser.add_argument( | ||
| "--request-timeout", | ||
| type=int, | ||
| default=DEFAULT_REQUEST_TIMEOUT, | ||
| help="Maximum seconds for a single non-streaming generation request. " | ||
| "Streaming requests use client disconnect detection instead. " | ||
| "(default: %(default)s)", | ||
| ) |
There was a problem hiding this comment.
PR description mentions request cancellation + --request-timeout, but this PR also introduces response_format handling and a new --max-context-tokens flag/context-length enforcement. Please update the PR description to include these additional user-facing features (or split them into a separate PR if they’re out of scope).
| def check_context_length(prompt: str, processor, max_context: int) -> None: | ||
| """Raise HTTP 400 if the tokenized prompt exceeds *max_context* tokens.""" | ||
| if max_context <= 0: | ||
| return | ||
| tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor | ||
| token_count = len(tokenizer.encode(prompt, add_special_tokens=False)) | ||
| if token_count > max_context: | ||
| raise HTTPException( | ||
| status_code=400, | ||
| detail=f"Prompt length ({token_count} tokens) exceeds maximum context " | ||
| f"window ({max_context} tokens). Reduce your prompt or increase --max-context-tokens.", | ||
| ) |
There was a problem hiding this comment.
check_context_length() introduces new request-rejection behavior (HTTP 400 when prompt exceeds MAX_CONTEXT_TOKENS), but there are no tests covering this path. Add tests that set MAX_CONTEXT_TOKENS/--max-context-tokens and verify both /responses and /chat/completions reject over-limit prompts with the expected error detail.
| "You must respond with valid JSON only. " | ||
| "Do not include any text outside the JSON object." | ||
| ) | ||
| messages.insert(0, {"role": "system", "content": json_instruction}) |
There was a problem hiding this comment.
resolve_response_format() changes request semantics by injecting a system JSON-only instruction when response_format.type == "json_object", but there are no tests asserting the template receives the injected system message (and that it’s only injected once). Add unit/integration tests for both endpoints to lock in this behavior.
| messages.insert(0, {"role": "system", "content": json_instruction}) | |
| has_json_instruction = any( | |
| message.get("role") == "system" | |
| and message.get("content") == json_instruction | |
| for message in messages | |
| ) | |
| if not has_json_instruction: | |
| messages.insert(0, {"role": "system", "content": json_instruction}) |
- Validate REQUEST_TIMEOUT env var (catch ValueError, reject <= 0) - Validate MAX_CONTEXT_TOKENS >= 0 - Use asyncio.get_running_loop() instead of get_event_loop() - Add comments noting wait_for cannot cancel sync thread - Use monkeypatch for env var tests - Remove unused os import from tests - Add context length rejection tests Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Summary\nHandle client disconnects during streaming. --request-timeout flag (default 300s). 5 tests.