diff --git a/modelexpress_client/python/modelexpress/trtllm_live_transfer.py b/modelexpress_client/python/modelexpress/trtllm_live_transfer.py index e84d3f0b..20d4254e 100644 --- a/modelexpress_client/python/modelexpress/trtllm_live_transfer.py +++ b/modelexpress_client/python/modelexpress/trtllm_live_transfer.py @@ -306,13 +306,32 @@ def load_weights( mpi_rank = device_id # MPI workers' stdout is swallowed by TRT-LLM — write to per-rank file + # Use line-buffered mode and explicit handler config so messages are + # visible even if the worker exits before normal logging shutdown. log_dir = os.environ.get("MX_TRANSFER_LOG_DIR", "/tmp/mx_logs") os.makedirs(log_dir, exist_ok=True) rank_log = os.path.join(log_dir, f"rank{mpi_rank}.log") - fh = logging.FileHandler(rank_log, mode="w") + + mx_logger = logging.getLogger("modelexpress") + mx_logger.setLevel(logging.INFO) + # Avoid duplicate handlers across multiple load_weights() calls. + for h in list(mx_logger.handlers): + if isinstance(h, logging.FileHandler) and getattr( + h, "baseFilename", "") == os.path.abspath(rank_log): + mx_logger.removeHandler(h) + h.close() + + # Open with line buffering so each log line hits disk immediately. + rank_log_stream = open(rank_log, "w", buffering=1) + fh = logging.StreamHandler(rank_log_stream) fh.setLevel(logging.INFO) - fh.setFormatter(logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s")) - logging.getLogger("modelexpress").addHandler(fh) + fh.setFormatter(logging.Formatter( + "%(asctime)s %(name)s %(levelname)s %(message)s")) + mx_logger.addHandler(fh) + # Track for explicit flush before return (safety net in case Python's + # exit handlers don't run, e.g. when MPI rank gets killed by mpirun). + self._rank_log_handler = fh + self._rank_log_stream = rank_log_stream logger.info( "Live transfer: loading '%s' rank %d (GPU %d)", model_name, mpi_rank, device_id @@ -485,11 +504,26 @@ def load_weights( except Exception as e: logger.warning("PVC fallback failed: %s", e) + # Flush rank log so transfer metrics are visible even if MPI worker + # gets killed before normal logging shutdown runs. + if hasattr(self, "_rank_log_handler"): + try: + self._rank_log_handler.flush() + self._rank_log_stream.flush() + os.fsync(self._rank_log_stream.fileno()) + except Exception: + pass + # Return fallback weights for TRT-LLM to apply; P2P weights are already in model params return fallback_weights def cleanup(self): - pass + if hasattr(self, "_rank_log_handler"): + try: + self._rank_log_handler.flush() + self._rank_log_stream.flush() + except Exception: + pass def _query_source(self, mx_server, model_name, timeout=600): import grpc