68 changes: 57 additions & 11 deletions pymllm/orchestrator/scheduler_process.py
@@ -29,6 +29,7 @@
 """

 import logging
+import os
 import queue as stdlib_queue
 import time
 from collections import deque
@@ -55,6 +56,27 @@
 _DEFAULT_MAX_TOTAL_TOKENS = 131072
 _DEFAULT_MAX_NEW_TOKENS = 32768

+# Brief poll timeout (ms) used between decode batches to avoid 100% CPU spin.
+# 1 ms is enough to yield the CPU core to the OS scheduler while adding
+# negligible latency (decode steps typically take >1 ms on the GPU anyway).
+# Override via MLLM_DECODE_POLL_TIMEOUT_MS env var for testing.
+def _read_decode_poll_timeout_ms() -> int:
+    raw = os.environ.get("MLLM_DECODE_POLL_TIMEOUT_MS", "1")
+    try:
+        val = int(raw)
+    except ValueError:
+        raise ValueError(
+            f"MLLM_DECODE_POLL_TIMEOUT_MS must be a non-negative integer, got {raw!r}"
+        )
Comment on lines +67 to +70
⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Scripts executed:

# Verify the file exists and count its lines
wc -l pymllm/orchestrator/scheduler_process.py

# Read lines 60-80 to see the context around lines 67-70
sed -n '60,80p' pymllm/orchestrator/scheduler_process.py

# Verify the exact line numbers by counting
head -70 pymllm/orchestrator/scheduler_process.py | tail -15 | cat -n

Repository: UbiquitousLearning/mllm

Preserve exception causality by chaining the original parse error.

The ValueError caught at line 68 should be chained when re-raised to preserve the original exception context for debugging. This aligns with Python best practices (PEP 3134) and helps maintain diagnostic information when the environment variable fails to parse.

🔧 Suggested fix

-    except ValueError:
+    except ValueError as exc:
         raise ValueError(
             f"MLLM_DECODE_POLL_TIMEOUT_MS must be a non-negative integer, got {raw!r}"
-        )
+        ) from exc
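As an aside, a minimal sketch (not part of the PR) of what PEP 3134 chaining buys here; parse_timeout is a hypothetical stand-in for _read_decode_poll_timeout_ms:

def parse_timeout(raw: str) -> int:
    # Hypothetical helper mirroring the PR's parsing logic.
    try:
        return int(raw)
    except ValueError as exc:
        # `from exc` sets __cause__, so the traceback reads "The above
        # exception was the direct cause of ..." instead of the weaker
        # implicit __context__ ("During handling of the above exception ...").
        raise ValueError(f"expected an integer, got {raw!r}") from exc

try:
    parse_timeout("abc")
except ValueError as err:
    assert isinstance(err.__cause__, ValueError)  # original parse error attached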

🧰 Tools
🪛 Ruff (0.15.6)

[warning] 68-70: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pymllm/orchestrator/scheduler_process.py` around lines 67-70, change the
bare exception re-raise to preserve causality: update the except clause from
"except ValueError:" to "except ValueError as exc:" and re-raise the new
ValueError with "from exc" so the original parse error is chained (keep the
existing message that uses raw and the MLLM_DECODE_POLL_TIMEOUT_MS context).

+    if val < 0:
+        raise ValueError(
+            f"MLLM_DECODE_POLL_TIMEOUT_MS must be >= 0, got {val}"
+        )
+    return val
+
+
+_DECODE_POLL_TIMEOUT_MS = _read_decode_poll_timeout_ms()
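
A quick way to exercise the override path — an illustrative pytest sketch (the test names and import are mine, not from the PR), assuming pytest's monkeypatch fixture:

import pytest

from pymllm.orchestrator.scheduler_process import _read_decode_poll_timeout_ms

def test_default_is_one_ms(monkeypatch):
    monkeypatch.delenv("MLLM_DECODE_POLL_TIMEOUT_MS", raising=False)
    assert _read_decode_poll_timeout_ms() == 1

def test_env_override(monkeypatch):
    monkeypatch.setenv("MLLM_DECODE_POLL_TIMEOUT_MS", "5")
    assert _read_decode_poll_timeout_ms() == 5

def test_rejects_non_integer(monkeypatch):
    monkeypatch.setenv("MLLM_DECODE_POLL_TIMEOUT_MS", "fast")
    with pytest.raises(ValueError):
        _read_decode_poll_timeout_ms()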

 # ======================================================================
 # IdleSleeper -- avoid busy-looping when no work is available

@@ -482,20 +504,30 @@ def init_model(self) -> None:
logger.info("In-process model runner initialised on GPU %d", self._gpu_id)

def event_loop(self) -> None:
"""Infinite scheduling loop."""
"""Infinite scheduling loop.

When decode batches are active the loop would otherwise spin at
100 % CPU doing non-blocking ZMQ polls between GPU forward passes.
We track whether the previous iteration ran a decode batch and, if
so, use a brief poll timeout (default 1 ms) in ``recv_requests``
so the OS can schedule other work on this core.
"""
logger.info(
"SchedulerProcess event loop started (shared_queue=%s, transport=%s)",
self._enable_shared_queue,
self._tensor_transport_mode,
)
_in_decode = False
while True:
self.recv_requests()
self.recv_requests(brief_poll=_in_decode)
self.process_input_requests()
batch = self.get_next_batch_to_run()
if batch is not None:
_in_decode = not batch.forward_mode.is_extend()
result = self.run_batch(batch)
self.process_batch_result(batch, result)
else:
_in_decode = False
# No work available -- sleep until a new request arrives
# on the ZMQ socket (or timeout). Avoids busy-looping.
self._idle_sleeper.sleep()
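
To make the control flow easier to see outside the diff, a stripped-down sketch of the same decode-aware polling (names simplified and hypothetical; poll, get_batch, and run stand in for the scheduler's real methods):

import time

def toy_event_loop(poll, get_batch, run):
    # poll(timeout_ms) ingests requests; get_batch() returns a batch or None.
    in_decode = False
    while True:
        # Brief 1 ms poll only right after a decode step, so an active
        # decode loop yields the core instead of spinning at 100% CPU.
        poll(1 if in_decode else 0)
        batch = get_batch()
        if batch is None:
            in_decode = False
            time.sleep(0.001)  # idle path: back off until work arrives
            continue
        in_decode = (batch.kind == "decode")  # hypothetical attribute
        run(batch)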
@@ -505,30 +537,40 @@ def event_loop(self) -> None:
# Step 1: receive tokenized requests (non-blocking)
# ------------------------------------------------------------------

-    def recv_requests(self) -> None:
+    def recv_requests(self, brief_poll: bool = False) -> None:
         """Non-blocking receive of tokenized requests from TokenizerProcess.

         Supports two modes:
         1. Legacy ZMQ: Uses ``zmq.Poller`` with a short timeout
         2. Shared queue: Non-blocking get from multiprocessing.Queue

+        When *brief_poll* is ``True`` (typically during active decode), the
+        first poll uses a small timeout (``_DECODE_POLL_TIMEOUT_MS``) instead
+        of zero. This yields the CPU core to the OS scheduler between decode
+        batches while adding negligible latency.
+
         Messages are either:
         * A :class:`~pymllm.engine.io_struct.TokenizedGenerateReqInput`
-          dataclass appended to ``_waiting_queue``.
-        * A plain abort sentinel dict ``{"rid": ..., "abort": True}`` handled
+          dataclass - appended to ``_waiting_queue``.
+        * A plain abort sentinel dict ``{"rid": ..., "abort": True}`` - handled
           inline by removing the matching rid from the waiting queue.
         """
         if self._enable_shared_queue and self._shared_queue is not None:
-            self._recv_from_shared_queue()
+            self._recv_from_shared_queue(brief_poll=brief_poll)
         else:
-            self._recv_from_zmq()
+            self._recv_from_zmq(brief_poll=brief_poll)

-    def _recv_from_zmq(self) -> None:
+    def _recv_from_zmq(self, brief_poll: bool = False) -> None:
         """Receive requests via legacy ZMQ path."""
+        # On the first poll, use a brief timeout if requested (decode path)
+        # to yield the CPU. After draining the first message, switch to
+        # non-blocking for any remaining queued messages.
+        poll_timeout = _DECODE_POLL_TIMEOUT_MS if brief_poll else 0
         while True:
-            events = dict(self._poller.poll(timeout=0))  # non-blocking
+            events = dict(self._poller.poll(timeout=poll_timeout))
             if self._recv_from_tokenizer not in events:
                 break
+            poll_timeout = 0  # drain remaining messages without blocking
             msg = self._recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
             # Abort sentinel: plain dict with "abort" key.
             if isinstance(msg, dict) and msg.get("abort"):
@@ -542,7 +584,7 @@ def _recv_from_zmq(self) -> None:
             else:
                 self._waiting_queue.append(msg)

-    def _recv_from_shared_queue(self) -> None:
+    def _recv_from_shared_queue(self, brief_poll: bool = False) -> None:
         """Receive requests via shared memory + shared queue fast path.

         After reading a ``(rid, shm_name, mm_inputs)`` tuple from the queue:
@@ -556,9 +598,13 @@ def _recv_from_shared_queue(self) -> None:
         3. A full ``TokenizedGenerateReqInput`` is assembled and appended to
            ``_waiting_queue``.
         """
+        # Use a slightly longer timeout on the first get when in decode mode
+        # to yield CPU; subsequent gets use a short timeout to drain the queue.
+        get_timeout = _DECODE_POLL_TIMEOUT_MS / 1000.0 if brief_poll else 0.002
         while True:
             try:
-                rid, shm_name, mm_inputs = self._shared_queue.get(timeout=0.002)
+                rid, shm_name, mm_inputs = self._shared_queue.get(timeout=get_timeout)
+                get_timeout = 0.002  # drain remaining without extra delay

                 # Read metadata from shared memory (and unlink immediately)
                 metadata: TokenizedGenerateReqInput = SharedMemoryManager.read_metadata(
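Finally, the first-wait-then-drain idea from _recv_from_shared_queue, as a self-contained sketch (a plain multiprocessing.Queue and illustrative names, not the PR's exact code):

import multiprocessing as mp
import queue as stdlib_queue

def drain(q, brief_poll: bool, poll_timeout_ms: int = 1) -> list:
    # First get may wait poll_timeout_ms to yield the CPU during decode;
    # subsequent gets fall back to the short 2 ms default to drain quickly.
    items = []
    timeout = poll_timeout_ms / 1000.0 if brief_poll else 0.002
    while True:
        try:
            items.append(q.get(timeout=timeout))
            timeout = 0.002
        except stdlib_queue.Empty:  # mp.Queue raises queue.Empty on timeout
            return items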