feat: concurrency guard for Metal GPU serialization #13

Open
eloe wants to merge 2 commits into main from feature/concurrency-guard

eloe commented Apr 6, 2026

Summary

asyncio.Semaphore serializes Metal GPU access. Adds a --max-concurrent-requests flag (default 1).

MLX generation runs single-threaded on Metal. Concurrent HTTP requests
to the generate pipeline can corrupt shared GPU state or crash with
Metal command buffer errors.

Add an asyncio.Semaphore that serializes access to the generation
pipeline across both /v1/responses and /v1/chat/completions endpoints,
for both streaming and non-streaming paths.

New server flag:
  --max-concurrent-requests N  (default: 1)
  env: MAX_CONCURRENT_REQUESTS

The semaphore is acquired before generation begins and released in a
finally block (non-streaming) or at the end of the stream generator
(streaming), ensuring cleanup on errors.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
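
The acquire/release pattern described above can be sketched as follows. This is a minimal illustration, not the code from mlx_vlm/server.py: `fake_generate`, `handle_non_streaming`, and `handle_streaming` are hypothetical stand-ins, the permit count is hard-coded to 1, and the streaming path here releases in a `finally` block, which is an assumption about how "at the end of the stream generator" could be made robust to errors.

```python
import asyncio

_generation_semaphore = None  # created lazily, inside the running event loop


def get_generation_semaphore() -> asyncio.Semaphore:
    """Return the shared semaphore, creating it on first use."""
    global _generation_semaphore
    if _generation_semaphore is None:
        _generation_semaphore = asyncio.Semaphore(1)  # one permit: fully serialized
    return _generation_semaphore


async def fake_generate(prompt: str) -> str:
    """Hypothetical stand-in for the real generation call."""
    await asyncio.sleep(0.01)
    return prompt.upper()


async def handle_non_streaming(prompt: str) -> str:
    sem = get_generation_semaphore()
    await sem.acquire()
    try:
        return await fake_generate(prompt)
    finally:
        sem.release()  # runs on success, error, or cancellation


async def handle_streaming(prompt: str):
    sem = get_generation_semaphore()
    await sem.acquire()
    try:
        for token in prompt.split():
            await asyncio.sleep(0)  # yield control, as a real token loop would
            yield token
    finally:
        sem.release()  # also runs if the generator is closed early
```

Because both handlers go through the same `get_generation_semaphore()`, a streaming request and a non-streaming request cannot be inside generation at the same time within one process.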

Copilot AI left a comment

Pull request overview

This PR introduces an application-level concurrency guard to serialize MLX generation on Metal, and exposes configuration for maximum concurrent generation requests via a CLI flag/environment variable.

Changes:

  • Add a global asyncio.Semaphore to serialize access to the GPU generation pipeline.
  • Wrap streaming and non-streaming generation paths with semaphore acquire/release.
  • Add --max-concurrent-requests CLI flag and plumb it through MAX_CONCURRENT_REQUESTS.

Comment thread mlx_vlm/server.py
Comment on lines +132 to +141
def get_max_concurrent_requests() -> int:
    return int(os.environ.get("MAX_CONCURRENT_REQUESTS", 1))


def get_generation_semaphore() -> asyncio.Semaphore:
    """Get or create the generation semaphore."""
    global _generation_semaphore
    if _generation_semaphore is None:
        _generation_semaphore = asyncio.Semaphore(get_max_concurrent_requests())
    return _generation_semaphore

Copilot AI Apr 6, 2026

get_max_concurrent_requests() casts the env var to int without validation. A value of 0 creates a semaphore with no permits, so every request waits forever; a negative value makes asyncio.Semaphore raise ValueError; a non-numeric value makes int() itself raise ValueError. Consider validating/clamping to >= 1 and raising a clear configuration error (ideally at startup / argument parsing) when the value is invalid.
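
One way to reject bad values at argument-parsing time is sketched below. It is an assumption-laden illustration: `parse_max_concurrent` and `build_parser` are hypothetical names, and using the environment variable as the flag's default is only one possible way to wire the two together, not necessarily how the PR does it.

```python
import argparse
import os


def parse_max_concurrent(raw: str) -> int:
    """Parse a permit count, rejecting non-integers and values below 1."""
    try:
        value = int(raw)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected an integer, got {raw!r}")
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {value}")
    return value


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--max-concurrent-requests",
        type=parse_max_concurrent,
        # argparse applies `type` to string defaults, so a bad env value
        # is rejected by the same check as a bad command-line value.
        default=os.environ.get("MAX_CONCURRENT_REQUESTS", "1"),
    )
    return parser
```

With this shape, an invalid value stops the server at startup with a usage error instead of surfacing later as a hung request.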

Comment thread mlx_vlm/server.py
Comment on lines +135 to +138

def get_generation_semaphore() -> asyncio.Semaphore:
    """Get or create the generation semaphore."""
    global _generation_semaphore

Copilot AI Apr 6, 2026

The semaphore only serializes within a single Python process. If the app is launched with multiple Uvicorn/Gunicorn workers (e.g., uvicorn ... --workers N), concurrent requests can still hit Metal in parallel across processes. Consider documenting/enforcing single-worker operation or adding a cross-process guard if multi-worker deployments are supported.

Suggested change
- def get_generation_semaphore() -> asyncio.Semaphore:
-     """Get or create the generation semaphore."""
-     global _generation_semaphore
+ def _get_configured_worker_count() -> int:
+     """
+     Best-effort detection of a multi-worker deployment.
+     The in-process asyncio semaphore only serializes requests within a
+     single Python process, so running multiple Uvicorn/Gunicorn workers
+     would still allow concurrent generation across processes.
+     """
+     worker_counts = []
+     for env_name in ("WEB_CONCURRENCY", "UVICORN_WORKERS"):
+         value = os.environ.get(env_name)
+         if value and value.isdigit():
+             worker_counts.append(int(value))
+     gunicorn_args = os.environ.get("GUNICORN_CMD_ARGS", "")
+     match = re.search(r"(?:^|\s)(?:-w|--workers)(?:\s+|=)(\d+)(?:\s|$)", gunicorn_args)
+     if match:
+         worker_counts.append(int(match.group(1)))
+     return max(worker_counts, default=1)
+
+
+ def get_generation_semaphore() -> asyncio.Semaphore:
+     """Get or create the generation semaphore."""
+     global _generation_semaphore
+     worker_count = _get_configured_worker_count()
+     if worker_count > 1:
+         raise RuntimeError(
+             "MLX generation must run with a single Uvicorn/Gunicorn worker. "
+             f"Detected worker count: {worker_count}. "
+             "The generation semaphore is process-local and cannot prevent "
+             "concurrent Metal access across multiple worker processes."
+         )

Comment thread mlx_vlm/server.py Outdated
Comment on lines 1005 to 1009
sem = get_generation_semaphore()
await sem.acquire()
try:
    # Use generate from generate.py
    result = generate(

Copilot AI Apr 6, 2026

New concurrency behavior isn’t covered by tests. Since mlx_vlm/tests/test_server.py exists, it would be good to add a test that concurrent requests are serialized (e.g., two overlapping requests where the patched generate blocks until an event is set, asserting the second request does not enter generation until the first completes).
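
The shape of such a test might look like the sketch below. It does not touch the real server or `generate`; `guarded` is a hypothetical stand-in for a handler that holds the semaphore while generation runs, and the event plays the role of the patched `generate` that blocks until released.

```python
import asyncio


async def run_serialization_check() -> list[str]:
    sem = asyncio.Semaphore(1)
    release_first = asyncio.Event()
    events: list[str] = []

    async def guarded(name: str, gate) -> None:
        # Stand-in for a request handler wrapping generation in the guard.
        async with sem:
            events.append(f"{name}:enter")
            if gate is not None:
                await gate.wait()  # simulate a generation call that blocks
            events.append(f"{name}:exit")

    first = asyncio.create_task(guarded("first", release_first))
    second = asyncio.create_task(guarded("second", None))

    await asyncio.sleep(0.01)  # give both tasks a chance to run
    # The second request must still be waiting on the semaphore.
    assert events == ["first:enter"]

    release_first.set()
    await asyncio.gather(first, second)
    return events
```

The returned event log makes the ordering explicit, so a real test could assert that the second request enters only after the first has exited.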

Comment thread mlx_vlm/server.py
Comment on lines +875 to 877
sem = get_generation_semaphore()
await sem.acquire()
token_iterator = None

Copilot AI Apr 6, 2026

To reduce the risk of mismatched acquire/release and simplify the control flow, consider using an async context manager (async with get_generation_semaphore():) instead of manual acquire()/release() in the streaming and non-streaming paths.
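
A minimal sketch of the context-manager form for a streaming path, under the assumption that the stream is an async generator; `stream_tokens` and `demo` are hypothetical names, and the semaphore is passed in explicitly only to keep the example self-contained.

```python
import asyncio
from typing import AsyncIterator


async def stream_tokens(sem: asyncio.Semaphore, tokens: list[str]) -> AsyncIterator[str]:
    # The semaphore is held for the whole lifetime of the stream and is
    # released when the generator finishes, raises, or is closed early.
    async with sem:
        for token in tokens:
            await asyncio.sleep(0)
            yield token


async def demo() -> bool:
    sem = asyncio.Semaphore(1)
    stream = stream_tokens(sem, ["a", "b", "c"])
    first = await stream.__anext__()
    held_mid_stream = sem.locked()  # permit is taken while streaming
    await stream.aclose()  # simulates a client disconnecting early
    return first == "a" and held_mid_stream and not sem.locked()
```

Closing the generator early still exits the `async with` block, so there is no separate release call to forget.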

eloe added a commit that referenced this pull request Apr 9, 2026
- Validate max_concurrent_requests >= 1 (prevent deadlock on 0)
- Add multi-worker uvicorn caveat comment
- Use async with for non-streaming semaphore paths
- Add 5 concurrency guard tests

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>