feat: concurrency guard for Metal GPU serialization #13
MLX generation runs single-threaded on Metal. Concurrent HTTP requests to the generate pipeline can corrupt shared GPU state or crash with Metal command buffer errors.

Add an asyncio.Semaphore that serializes access to the generation pipeline across both the /v1/responses and /v1/chat/completions endpoints, for both streaming and non-streaming paths.

New server flag:
--max-concurrent-requests N (default: 1)
env: MAX_CONCURRENT_REQUESTS

The semaphore is acquired before generation begins and released in a finally block (non-streaming) or at the end of the stream generator (streaming), ensuring cleanup on errors.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
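The streaming case described above can be illustrated with a minimal sketch. It is not the code from this PR: `stream_response` and `fake_token_stream` are hypothetical names, and the semaphore is passed in as an argument so the example stays self-contained. It shows one way to hold the permit for the lifetime of an async generator and release it in a `finally` block.

```python
import asyncio
from typing import AsyncIterator


async def fake_token_stream(prompt: str) -> AsyncIterator[str]:
    """Hypothetical token source; the real server would call into MLX here."""
    for word in prompt.split():
        await asyncio.sleep(0)  # yield control, as a real token loop would
        yield word


async def stream_response(sem: asyncio.Semaphore, prompt: str) -> AsyncIterator[str]:
    """Hold one semaphore permit for as long as the stream is being consumed."""
    await sem.acquire()
    try:
        async for token in fake_token_stream(prompt):
            yield token
    finally:
        # Runs when the stream finishes, when generation raises, and when
        # the consumer closes the generator early (e.g. client disconnect).
        sem.release()
```

Because the release lives in `finally`, closing the generator before it is exhausted also returns the permit, provided the generator is actually closed (for example via `aclose()`).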
Pull request overview
This PR introduces an application-level concurrency guard to serialize MLX generation on Metal, and exposes configuration for maximum concurrent generation requests via a CLI flag/environment variable.
Changes:
- Add a global asyncio.Semaphore to serialize access to the GPU generation pipeline.
- Wrap streaming and non-streaming generation paths with semaphore acquire/release.
- Add --max-concurrent-requests CLI flag and plumb it through MAX_CONCURRENT_REQUESTS.
| def get_max_concurrent_requests() -> int: | ||
| return int(os.environ.get("MAX_CONCURRENT_REQUESTS", 1)) | ||
| def get_generation_semaphore() -> asyncio.Semaphore: | ||
| """Get or create the generation semaphore.""" | ||
| global _generation_semaphore | ||
| if _generation_semaphore is None: | ||
| _generation_semaphore = asyncio.Semaphore(get_max_concurrent_requests()) | ||
| return _generation_semaphore |
get_max_concurrent_requests() casts the env var to int without validation. A value of 0 creates a semaphore with no permits, so every request deadlocks, and a negative value raises ValueError when the semaphore is created. Consider validating/clamping to >= 1 and raising a clear configuration error (ideally at startup / argument parsing) when the value is invalid.
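A minimal sketch of the validation suggested here, assuming the value is still read from MAX_CONCURRENT_REQUESTS. The helper name `parse_max_concurrent_requests` is hypothetical and not part of the PR. It rejects non-integer, zero, and negative values with an explicit message.

```python
import os


def parse_max_concurrent_requests(raw: str) -> int:
    """Parse and validate the configured limit (hypothetical helper)."""
    try:
        value = int(raw)
    except ValueError:
        raise ValueError(
            f"MAX_CONCURRENT_REQUESTS must be an integer, got {raw!r}"
        ) from None
    if value < 1:
        # 0 would create a semaphore with no permits; negatives are rejected
        # by asyncio.Semaphore itself, but with a less helpful message.
        raise ValueError(f"MAX_CONCURRENT_REQUESTS must be >= 1, got {value}")
    return value


def get_max_concurrent_requests() -> int:
    return parse_max_concurrent_requests(
        os.environ.get("MAX_CONCURRENT_REQUESTS", "1")
    )
```

Calling the parser once during argument handling would surface a bad value at startup instead of on the first request.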
| def get_generation_semaphore() -> asyncio.Semaphore: | ||
| """Get or create the generation semaphore.""" | ||
| global _generation_semaphore |
The semaphore only serializes within a single Python process. If the app is launched with multiple Uvicorn/Gunicorn workers (e.g., uvicorn ... --workers N), concurrent requests can still hit Metal in parallel across processes. Consider documenting/enforcing single-worker operation or adding a cross-process guard if multi-worker deployments are supported.
| def get_generation_semaphore() -> asyncio.Semaphore: | |
| """Get or create the generation semaphore.""" | |
| global _generation_semaphore | |
| def _get_configured_worker_count() -> int: | |
| """ | |
| Best-effort detection of a multi-worker deployment. | |
| The in-process asyncio semaphore only serializes requests within a | |
| single Python process, so running multiple Uvicorn/Gunicorn workers | |
| would still allow concurrent generation across processes. | |
| """ | |
| worker_counts = [] | |
| for env_name in ("WEB_CONCURRENCY", "UVICORN_WORKERS"): | |
| value = os.environ.get(env_name) | |
| if value and value.isdigit(): | |
| worker_counts.append(int(value)) | |
| gunicorn_args = os.environ.get("GUNICORN_CMD_ARGS", "") | |
| match = re.search(r"(?:^|\s)(?:-w|--workers)(?:\s+|=)(\d+)(?:\s|$)", gunicorn_args) | |
| if match: | |
| worker_counts.append(int(match.group(1))) | |
| return max(worker_counts, default=1) | |
| def get_generation_semaphore() -> asyncio.Semaphore: | |
| """Get or create the generation semaphore.""" | |
| global _generation_semaphore | |
| worker_count = _get_configured_worker_count() | |
| if worker_count > 1: | |
| raise RuntimeError( | |
| "MLX generation must run with a single Uvicorn/Gunicorn worker. " | |
| f"Detected worker count: {worker_count}. " | |
| "The generation semaphore is process-local and cannot prevent " | |
| "concurrent Metal access across multiple worker processes." | |
| ) |
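As an alternative to refusing multi-worker startup, the cross-process guard mentioned in this comment could be sketched with an advisory file lock. This is only an illustration under stated assumptions: it relies on `fcntl.flock` (available on macOS and Linux, not Windows), the lock path and the name `cross_process_generation_lock` are hypothetical, and cancellation while waiting for the lock is not handled.

```python
import asyncio
import fcntl
import os
import tempfile
from contextlib import asynccontextmanager

# Hypothetical lock file location shared by all worker processes.
LOCK_PATH = os.path.join(tempfile.gettempdir(), "mlx_generation.lock")


@asynccontextmanager
async def cross_process_generation_lock(path: str = LOCK_PATH):
    """Serialize a critical section across processes with an advisory lock."""
    fd = os.open(path, os.O_CREAT | os.O_RDWR, 0o600)
    try:
        # flock() blocks, so wait for it in a worker thread to keep the
        # event loop responsive while another process holds the lock.
        await asyncio.to_thread(fcntl.flock, fd, fcntl.LOCK_EX)
        try:
            yield
        finally:
            fcntl.flock(fd, fcntl.LOCK_UN)
    finally:
        os.close(fd)
```

Each waiting request occupies a thread from the default executor while it blocks, so this trades simplicity for some overhead under load.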
| sem = get_generation_semaphore() | ||
| await sem.acquire() | ||
| try: | ||
| # Use generate from generate.py | ||
| result = generate( |
New concurrency behavior isn’t covered by tests. Since mlx_vlm/tests/test_server.py exists, it would be good to add a test that concurrent requests are serialized (e.g., two overlapping requests where the patched generate blocks until an event is set, asserting the second request does not enter generation until the first completes).
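The test described in this comment might look roughly like the sketch below. It does not use the real server or mlx_vlm/tests/test_server.py: `handle_request` is a hypothetical stand-in for an endpoint handler, and `fake_generate` stands in for the patched generate, blocking on an event until the test releases it.

```python
import asyncio


async def handle_request(sem, generate, name):
    """Hypothetical stand-in for an endpoint handler guarded by the semaphore."""
    async with sem:
        return await generate(name)


async def check_requests_are_serialized():
    sem = asyncio.Semaphore(1)
    release_first = asyncio.Event()
    entered = []  # records the order in which requests reach generation

    async def fake_generate(name):
        entered.append(name)
        if name == "first":
            await release_first.wait()  # hold the permit until told to finish
        return f"{name}-done"

    first = asyncio.create_task(handle_request(sem, fake_generate, "first"))
    second = asyncio.create_task(handle_request(sem, fake_generate, "second"))

    # Give both tasks several chances to run until they block.
    for _ in range(5):
        await asyncio.sleep(0)
    assert entered == ["first"]  # second is still waiting on the semaphore

    release_first.set()
    results = await asyncio.gather(first, second)
    assert entered == ["first", "second"]
    return results
```

In a pytest suite the coroutine could be driven with `asyncio.run(check_requests_are_serialized())` or an async test plugin, depending on what the project already uses.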
| sem = get_generation_semaphore() | ||
| await sem.acquire() | ||
| token_iterator = None |
To reduce the risk of mismatched acquire/release and simplify the control flow, consider using an async context manager (async with get_generation_semaphore():) instead of manual acquire()/release() in the streaming and non-streaming paths.
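A minimal sketch of the `async with` form for the non-streaming path. `run_generation` is a hypothetical wrapper and `generate` is any synchronous callable passed in for illustration, not the actual function from generate.py.

```python
import asyncio
from typing import Callable


async def run_generation(
    sem: asyncio.Semaphore,
    generate: Callable[[str], str],
    prompt: str,
) -> str:
    """Run one generation call while holding a semaphore permit."""
    # Equivalent to `await sem.acquire()` followed by try/finally with
    # `sem.release()`, but the pairing cannot be mismatched.
    async with sem:
        return generate(prompt)
```

The permit is returned whether `generate` returns normally or raises, which is the same guarantee the manual `finally` block was providing.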
- Validate max_concurrent_requests >= 1 (prevent deadlock on 0)
- Add multi-worker uvicorn caveat comment
- Use async with for non-streaming semaphore paths
- Add 5 concurrency guard tests

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Summary
asyncio.Semaphore serializes Metal GPU access. New --max-concurrent-requests flag (default 1).