From 2e7f60ecd199beebac814f3c41deaa18b3ce46ac Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Tue, 28 Apr 2026 20:28:38 -0700 Subject: [PATCH 1/2] fix(trtllm): flush rank log handler so MPI workers' transfer metrics persist MPI worker processes get killed via SIGKILL when the engine init completes, which skips Python's normal logging shutdown. The buffered FileHandler loses all logged Gbps/timing metrics. Switch to line-buffered stream and explicit flush+fsync at end of load_weights(). Signed-off-by: Kavin Krishnan Made-with: Cursor --- .../modelexpress/trtllm_live_transfer.py | 42 +++++++++++++++++-- 1 file changed, 38 insertions(+), 4 deletions(-) diff --git a/modelexpress_client/python/modelexpress/trtllm_live_transfer.py b/modelexpress_client/python/modelexpress/trtllm_live_transfer.py index e84d3f0b..20d4254e 100644 --- a/modelexpress_client/python/modelexpress/trtllm_live_transfer.py +++ b/modelexpress_client/python/modelexpress/trtllm_live_transfer.py @@ -306,13 +306,32 @@ def load_weights( mpi_rank = device_id # MPI workers' stdout is swallowed by TRT-LLM — write to per-rank file + # Use line-buffered mode and explicit handler config so messages are + # visible even if the worker exits before normal logging shutdown. log_dir = os.environ.get("MX_TRANSFER_LOG_DIR", "/tmp/mx_logs") os.makedirs(log_dir, exist_ok=True) rank_log = os.path.join(log_dir, f"rank{mpi_rank}.log") - fh = logging.FileHandler(rank_log, mode="w") + + mx_logger = logging.getLogger("modelexpress") + mx_logger.setLevel(logging.INFO) + # Avoid duplicate handlers across multiple load_weights() calls. + for h in list(mx_logger.handlers): + if isinstance(h, logging.FileHandler) and getattr( + h, "baseFilename", "") == os.path.abspath(rank_log): + mx_logger.removeHandler(h) + h.close() + + # Open with line buffering so each log line hits disk immediately. + rank_log_stream = open(rank_log, "w", buffering=1) + fh = logging.StreamHandler(rank_log_stream) fh.setLevel(logging.INFO) - fh.setFormatter(logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s")) - logging.getLogger("modelexpress").addHandler(fh) + fh.setFormatter(logging.Formatter( + "%(asctime)s %(name)s %(levelname)s %(message)s")) + mx_logger.addHandler(fh) + # Track for explicit flush before return (safety net in case Python's + # exit handlers don't run, e.g. when MPI rank gets killed by mpirun). + self._rank_log_handler = fh + self._rank_log_stream = rank_log_stream logger.info( "Live transfer: loading '%s' rank %d (GPU %d)", model_name, mpi_rank, device_id @@ -485,11 +504,26 @@ def load_weights( except Exception as e: logger.warning("PVC fallback failed: %s", e) + # Flush rank log so transfer metrics are visible even if MPI worker + # gets killed before normal logging shutdown runs. + if hasattr(self, "_rank_log_handler"): + try: + self._rank_log_handler.flush() + self._rank_log_stream.flush() + os.fsync(self._rank_log_stream.fileno()) + except Exception: + pass + # Return fallback weights for TRT-LLM to apply; P2P weights are already in model params return fallback_weights def cleanup(self): - pass + if hasattr(self, "_rank_log_handler"): + try: + self._rank_log_handler.flush() + self._rank_log_stream.flush() + except Exception: + pass def _query_source(self, mx_server, model_name, timeout=600): import grpc From 4f642cb1a63a6058cb1453b9afb500d5154c8d91 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Thu, 7 May 2026 19:11:23 -0700 Subject: [PATCH 2/2] chore: re-trigger CodeRabbit review Signed-off-by: Kavin Krishnan