Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 38 additions & 4 deletions modelexpress_client/python/modelexpress/trtllm_live_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,13 +306,32 @@ def load_weights(
mpi_rank = device_id

# MPI workers' stdout is swallowed by TRT-LLM — write to per-rank file
# Use line-buffered mode and explicit handler config so messages are
# visible even if the worker exits before normal logging shutdown.
log_dir = os.environ.get("MX_TRANSFER_LOG_DIR", "/tmp/mx_logs")
os.makedirs(log_dir, exist_ok=True)
rank_log = os.path.join(log_dir, f"rank{mpi_rank}.log")
fh = logging.FileHandler(rank_log, mode="w")

mx_logger = logging.getLogger("modelexpress")
mx_logger.setLevel(logging.INFO)
# Avoid duplicate handlers across multiple load_weights() calls.
for h in list(mx_logger.handlers):
if isinstance(h, logging.FileHandler) and getattr(
h, "baseFilename", "") == os.path.abspath(rank_log):
mx_logger.removeHandler(h)
h.close()

# Open with line buffering so each log line hits disk immediately.
rank_log_stream = open(rank_log, "w", buffering=1)
fh = logging.StreamHandler(rank_log_stream)
fh.setLevel(logging.INFO)
fh.setFormatter(logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s"))
logging.getLogger("modelexpress").addHandler(fh)
fh.setFormatter(logging.Formatter(
"%(asctime)s %(name)s %(levelname)s %(message)s"))
mx_logger.addHandler(fh)
# Track for explicit flush before return (safety net in case Python's
# exit handlers don't run, e.g. when MPI rank gets killed by mpirun).
self._rank_log_handler = fh
self._rank_log_stream = rank_log_stream
Comment on lines +318 to +334
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Duplicate-handler dedup no longer matches; rank log stream is leaked.

Two related concerns on the new handler setup:

  1. The dedup loop at lines 318–322 still checks isinstance(h, logging.FileHandler), but the handler being added on line 326 is a plain logging.StreamHandler (not a FileHandler). On a second call to load_weights() the previously attached StreamHandler won't be matched/removed, so handlers (and underlying open file streams) accumulate on mx_logger and every log line gets written N times.
  2. rank_log_stream opened on line 325 is never closed. cleanup() only flushes; the previous handler/stream reference on self is overwritten on each call, leaking the file descriptor. Closing it in the dedup branch and in cleanup() would fix both.
Proposed fix
-        mx_logger = logging.getLogger("modelexpress")
-        mx_logger.setLevel(logging.INFO)
-        # Avoid duplicate handlers across multiple load_weights() calls.
-        for h in list(mx_logger.handlers):
-            if isinstance(h, logging.FileHandler) and getattr(
-                    h, "baseFilename", "") == os.path.abspath(rank_log):
-                mx_logger.removeHandler(h)
-                h.close()
-
-        # Open with line buffering so each log line hits disk immediately.
-        rank_log_stream = open(rank_log, "w", buffering=1)
-        fh = logging.StreamHandler(rank_log_stream)
+        mx_logger = logging.getLogger("modelexpress")
+        mx_logger.setLevel(logging.INFO)
+        # Avoid duplicate handlers across multiple load_weights() calls.
+        rank_log_abspath = os.path.abspath(rank_log)
+        for h in list(mx_logger.handlers):
+            stream = getattr(h, "stream", None)
+            if stream is not None and getattr(stream, "name", None) == rank_log_abspath:
+                mx_logger.removeHandler(h)
+                try:
+                    stream.close()
+                except Exception:
+                    pass
+
+        # Open with line buffering so each log line hits disk immediately.
+        rank_log_stream = open(rank_log, "w", buffering=1)
+        fh = logging.StreamHandler(rank_log_stream)

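And mirror the close in cleanup() after flushing: remove the handler from the logger, close it, and then close the underlying stream. Note that stream.name holds the path exactly as it was passed to open(), so the comparison above only matches when rank_log is already absolute; normalize both sides with os.path.abspath() if MX_TRANSFER_LOG_DIR may be a relative path.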

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelexpress_client/python/modelexpress/trtllm_live_transfer.py` around lines
318 - 334, The dedup loop in load_weights() currently only removes
logging.FileHandler instances so previously added StreamHandler and its open
file (rank_log_stream) are not removed or closed, causing duplicated handlers
and leaked file descriptors; update the dedup logic that iterates
mx_logger.handlers to remove both logging.StreamHandler and logging.FileHandler
whose base stream/file matches os.path.abspath(rank_log), and when removing a
matching handler close it and if it owns an underlying file-like object (the
rank_log_stream) close that stream as well; also ensure cleanup() not only
flushes self._rank_log_stream and handler but explicitly closes
self._rank_log_stream and removes/closes self._rank_log_handler (and set them to
None) to avoid overwriting references and leaking FDs on subsequent
load_weights() calls.


logger.info(
"Live transfer: loading '%s' rank %d (GPU %d)", model_name, mpi_rank, device_id
Expand Down Expand Up @@ -485,11 +504,26 @@ def load_weights(
except Exception as e:
logger.warning("PVC fallback failed: %s", e)

# Flush rank log so transfer metrics are visible even if MPI worker
# gets killed before normal logging shutdown runs.
if hasattr(self, "_rank_log_handler"):
try:
self._rank_log_handler.flush()
self._rank_log_stream.flush()
os.fsync(self._rank_log_stream.fileno())
except Exception:
pass

# Return fallback weights for TRT-LLM to apply; P2P weights are already in model params
return fallback_weights

def cleanup(self):
    """Flush, detach, and close the per-rank log handler and its stream.

    Releases the logging resources that ``load_weights()`` stores on the
    instance as ``_rank_log_handler`` and ``_rank_log_stream``. Flushing
    alone leaves the file descriptor open and the handler attached to the
    ``"modelexpress"`` logger, so both are released here.

    The method is best-effort and idempotent: it never raises, and calling
    it before ``load_weights()`` or more than once is a no-op.
    """
    import logging  # function-scope import keeps this block self-contained

    handler = getattr(self, "_rank_log_handler", None)
    stream = getattr(self, "_rank_log_stream", None)
    if handler is None and stream is None:
        return

    # Clear the references up front so a repeated cleanup() does nothing
    # and a later load_weights() starts from a clean state.
    self._rank_log_handler = None
    self._rank_log_stream = None

    if handler is not None:
        # Detach before closing the stream; otherwise later log records
        # would be emitted to a closed file.
        logging.getLogger("modelexpress").removeHandler(handler)
        try:
            handler.flush()
            handler.close()
        except Exception:
            # Best-effort: cleanup must never raise during shutdown.
            pass

    if stream is not None:
        try:
            # close() flushes any buffered data before releasing the fd.
            stream.close()
        except Exception:
            # Best-effort: cleanup must never raise during shutdown.
            pass

def _query_source(self, mx_server, model_name, timeout=600):
import grpc
Expand Down
Loading