From 99f4ad04a1beb38a3b503020f7d6522dc7dbefb7 Mon Sep 17 00:00:00 2001 From: Zheng Luo Date: Tue, 12 May 2026 12:23:35 -0700 Subject: [PATCH] feat: retry metadata publish from heartbeat Signed-off-by: Zheng Luo --- .../python/modelexpress/metadata/heartbeat.py | 90 +++++++++++-- .../python/modelexpress/metadata/publish.py | 82 ++++++------ .../modelexpress/metadata/worker_server.py | 20 ++- .../python/tests/test_heartbeat.py | 53 ++++++++ .../python/tests/test_vllm_loader.py | 122 ++++++++++++------ 5 files changed, 269 insertions(+), 98 deletions(-) diff --git a/modelexpress_client/python/modelexpress/metadata/heartbeat.py b/modelexpress_client/python/modelexpress/metadata/heartbeat.py index 3f39ee82..11fc7bd2 100644 --- a/modelexpress_client/python/modelexpress/metadata/heartbeat.py +++ b/modelexpress_client/python/modelexpress/metadata/heartbeat.py @@ -15,7 +15,8 @@ import logging import os import threading -from typing import TYPE_CHECKING +import time +from typing import TYPE_CHECKING, Callable if TYPE_CHECKING: from ..client import MxClient @@ -23,14 +24,16 @@ logger = logging.getLogger("modelexpress.metadata.heartbeat") +PUBLISH_TIMEOUT_SECS_DEFAULT = 30 * 60 # 30 minutes + class HeartbeatThread: - """Background thread that signals source liveness via UpdateStatus RPCs. + """Background thread that publishes metadata and signals source liveness. - After PublishMetadata(status=INITIALIZING), the source spawns this thread. - Each tick: if NIXL agent is healthy, send UpdateStatus(READY) to refresh - updated_at. If not healthy, skip — the server-side reaper detects the - stale updated_at and marks the worker STALE. + A caller may pass either an existing ``mx_source_id`` or a ``publish_fn``. + With ``publish_fn``, the thread attempts initial metadata publication on + each tick until it succeeds or times out, then transitions to READY + heartbeats. On clean shutdown (SIGTERM), atexit handler sends UpdateStatus(STALE) for immediate detection without waiting for the reaper timeout. @@ -42,21 +45,43 @@ class HeartbeatThread: worker_rank: Model-shard rank used for metadata/status keying. nixl_manager: Optional NIXL transfer manager for agent health checks. Non-NIXL transports pass None and heartbeat unconditionally. + publish_fn: Optional callback for deferred initial metadata publish. + publish_timeout_secs: Seconds to keep retrying publish before giving up. """ def __init__( self, mx_client: MxClient, - mx_source_id: str, - worker_id: str, - worker_rank: int, - nixl_manager: NixlTransferManager | None, + mx_source_id: str | None = None, + worker_id: str | None = None, + worker_rank: int | None = None, + nixl_manager: NixlTransferManager | None = None, + publish_fn: Callable[[], str] | None = None, + publish_timeout_secs: int | None = None, ): + if mx_source_id is None and publish_fn is None: + raise ValueError("HeartbeatThread requires mx_source_id or publish_fn") + if worker_id is None or worker_rank is None: + raise ValueError("HeartbeatThread requires worker_id and worker_rank") + self._mx_client = mx_client self._mx_source_id = mx_source_id self._worker_id = worker_id self._worker_rank = worker_rank self._nixl_manager = nixl_manager + self._publish_fn = publish_fn + + if publish_timeout_secs is not None: + self._publish_timeout = publish_timeout_secs + else: + self._publish_timeout = int( + os.environ.get( + "MX_PUBLISH_TIMEOUT_SECS", + str(PUBLISH_TIMEOUT_SECS_DEFAULT), + ) + ) + self._publish_started_at: float | None = None + self._publish_given_up = False self._interval = int( os.environ.get("MX_HEARTBEAT_INTERVAL_SECS", "30") @@ -65,6 +90,10 @@ def __init__( self._started = False self._thread: threading.Thread | None = None + @property + def mx_source_id(self) -> str | None: + return self._mx_source_id + def start(self) -> None: """Start the heartbeat background thread.""" self._thread = threading.Thread( @@ -76,7 +105,8 @@ def start(self) -> None: atexit.register(self._on_exit) logger.info( f"[Worker {self._worker_rank}] Heartbeat started " - f"(interval={self._interval}s)" + f"(interval={self._interval}s, " + f"publish_timeout={self._publish_timeout}s)" ) def stop(self) -> None: @@ -117,10 +147,46 @@ def _update_status(self, status: int) -> None: status=status, ) + def _try_publish(self) -> None: + """Attempt metadata publish. On success, store mx_source_id.""" + if self._publish_fn is None: + return + if self._publish_started_at is None: + self._publish_started_at = time.monotonic() + + elapsed = time.monotonic() - self._publish_started_at + if elapsed > self._publish_timeout: + if not self._publish_given_up: + logger.warning( + f"[Worker {self._worker_rank}] Giving up on metadata publish " + f"after {elapsed:.0f}s (timeout={self._publish_timeout}s). " + f"Worker will continue without P2P serving." + ) + self._publish_given_up = True + return + + try: + self._mx_source_id = self._publish_fn() + self._started = True + logger.info( + f"[Worker {self._worker_rank}] Metadata published successfully " + f"(mx_source_id={self._mx_source_id})" + ) + except Exception as e: + logger.warning( + f"[Worker {self._worker_rank}] Metadata publish attempt failed " + f"({elapsed:.0f}s elapsed, timeout={self._publish_timeout}s), " + f"will retry next tick: {e}" + ) + def _tick(self) -> None: - """Single heartbeat tick: check health and send READY if healthy.""" + """Single tick: publish if needed, otherwise heartbeat.""" from .. import p2p_pb2 + if self._mx_source_id is None: + self._try_publish() + return + if self._nixl_manager is not None and not self._nixl_manager.is_healthy(): return diff --git a/modelexpress_client/python/modelexpress/metadata/publish.py b/modelexpress_client/python/modelexpress/metadata/publish.py index 16fe4655..54a39302 100644 --- a/modelexpress_client/python/modelexpress/metadata/publish.py +++ b/modelexpress_client/python/modelexpress/metadata/publish.py @@ -116,9 +116,13 @@ def publish_metadata_and_ready( identity: "p2p_pb2.SourceIdentity", worker_id: str, ) -> None: - """Publish tensor metadata and ready flag to the ModelExpress server.""" + """Prepare metadata publication and start the heartbeat thread. + + The heartbeat thread performs the initial PublishMetadata RPC so a + transient MX server failure can be retried after the model has loaded. + """ logger.info( - f"[Worker {worker_rank}] Publishing {len(tensors)} tensors for model '{identity.model_name}'" + f"[Worker {worker_rank}] Preparing {len(tensors)} tensors for model '{identity.model_name}'" ) tensor_protos = build_tensor_protos(tensors, device_id, worker_rank) @@ -126,28 +130,19 @@ def publish_metadata_and_ready( if _is_p2p_metadata_enabled(mx_client): from .worker_server import WorkerGrpcServer + if nixl_manager._listen_port is None: + raise RuntimeError( + "P2P metadata exchange requires a NIXL listen port, " + "but the NIXL manager was initialized without one." + ) + host = _get_worker_host() grpc_base = int(os.environ.get("MX_WORKER_GRPC_PORT", "6555")) worker_grpc_port = grpc_base + device_id - worker = p2p_pb2.WorkerMetadata( - worker_rank=worker_rank, - metadata_endpoint=f"{host}:{nixl_manager._listen_port}", - agent_name=nixl_manager.agent_name, - worker_grpc_endpoint="", - ) - mx_source_id = _publish_metadata_to_server( - mx_client=mx_client, - identity=identity, - worker=worker, - worker_id=worker_id, - worker_rank=worker_rank, - ) - grpc_server = WorkerGrpcServer( tensor_protos=tensor_protos, - mx_source_id=mx_source_id, port=worker_grpc_port, metadata_endpoint=f"{host}:{nixl_manager._listen_port}", agent_name=nixl_manager.agent_name, @@ -162,41 +157,48 @@ def publish_metadata_and_ready( agent_name=nixl_manager.agent_name, worker_grpc_endpoint=f"{host}:{actual_port}", ) - mx_source_id = _publish_metadata_to_server( - mx_client=mx_client, - identity=identity, - worker=worker, - worker_id=worker_id, - worker_rank=worker_rank, - ) - logger.info( - f"[Worker {worker_rank}] Published P2P metadata to MX server " - f"(mx_source_id={mx_source_id}, worker_grpc={host}:{actual_port})" - ) + + def publish_fn() -> str: + mx_source_id = _publish_metadata_to_server( + mx_client=mx_client, + identity=identity, + worker=worker, + worker_id=worker_id, + worker_rank=worker_rank, + ) + grpc_server.set_mx_source_id(mx_source_id) + logger.info( + f"[Worker {worker_rank}] Published P2P metadata to MX server " + f"(mx_source_id={mx_source_id}, worker_grpc={host}:{actual_port})" + ) + return mx_source_id else: worker = p2p_pb2.WorkerMetadata( worker_rank=worker_rank, nixl_metadata=nixl_manager.nixl_metadata, tensors=tensor_protos, ) - mx_source_id = _publish_metadata_to_server( - mx_client=mx_client, - identity=identity, - worker=worker, - worker_id=worker_id, - worker_rank=worker_rank, - ) - logger.info( - f"[Worker {worker_rank}] Published metadata to MX server " - f"(mx_source_id={mx_source_id}, worker_id={worker_id})" - ) + + def publish_fn() -> str: + mx_source_id = _publish_metadata_to_server( + mx_client=mx_client, + identity=identity, + worker=worker, + worker_id=worker_id, + worker_rank=worker_rank, + ) + logger.info( + f"[Worker {worker_rank}] Published metadata to MX server " + f"(mx_source_id={mx_source_id}, worker_id={worker_id})" + ) + return mx_source_id heartbeat = HeartbeatThread( mx_client=mx_client, - mx_source_id=mx_source_id, worker_id=worker_id, worker_rank=worker_rank, nixl_manager=nixl_manager, + publish_fn=publish_fn, ) heartbeat.start() _heartbeat_threads[worker_rank] = heartbeat diff --git a/modelexpress_client/python/modelexpress/metadata/worker_server.py b/modelexpress_client/python/modelexpress/metadata/worker_server.py index 9652a98e..6890d016 100644 --- a/modelexpress_client/python/modelexpress/metadata/worker_server.py +++ b/modelexpress_client/python/modelexpress/metadata/worker_server.py @@ -29,7 +29,7 @@ class WorkerServiceServicer(p2p_pb2_grpc.WorkerServiceServicer): def __init__( self, tensor_protos: list[p2p_pb2.TensorDescriptor], - mx_source_id: str, + mx_source_id: str | None = None, metadata_endpoint: str = "", agent_name: str = "", worker_rank: int = 0, @@ -40,6 +40,9 @@ def __init__( self._agent_name = agent_name self._worker_rank = worker_rank + def set_mx_source_id(self, mx_source_id: str) -> None: + self._mx_source_id = mx_source_id + def GetTensorManifest(self, request, context): if request.mx_source_id and request.mx_source_id != self._mx_source_id: context.abort( @@ -49,7 +52,7 @@ def GetTensorManifest(self, request, context): ) return p2p_pb2.GetTensorManifestResponse( tensors=self._tensor_protos, - mx_source_id=self._mx_source_id, + mx_source_id=self._mx_source_id or "", metadata_endpoint=self._metadata_endpoint, agent_name=self._agent_name, worker_rank=self._worker_rank, @@ -62,7 +65,7 @@ class WorkerGrpcServer: def __init__( self, tensor_protos: list[p2p_pb2.TensorDescriptor], - mx_source_id: str, + mx_source_id: str | None = None, port: int = 0, metadata_endpoint: str = "", agent_name: str = "", @@ -75,23 +78,30 @@ def __init__( self._agent_name = agent_name self._worker_rank = worker_rank self._server: grpc.Server | None = None + self._servicer: WorkerServiceServicer | None = None self._port: int | None = None @property def port(self) -> int | None: return self._port + def set_mx_source_id(self, mx_source_id: str) -> None: + if self._servicer is None: + raise RuntimeError("Server must be started before setting mx_source_id") + self._mx_source_id = mx_source_id + self._servicer.set_mx_source_id(mx_source_id) + def start(self) -> int: """Start the gRPC server. Returns the actual bound port.""" self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=4)) - servicer = WorkerServiceServicer( + self._servicer = WorkerServiceServicer( tensor_protos=self._tensor_protos, mx_source_id=self._mx_source_id, metadata_endpoint=self._metadata_endpoint, agent_name=self._agent_name, worker_rank=self._worker_rank, ) - p2p_pb2_grpc.add_WorkerServiceServicer_to_server(servicer, self._server) + p2p_pb2_grpc.add_WorkerServiceServicer_to_server(self._servicer, self._server) if self._requested_port: self._port = self._server.add_insecure_port(f"[::]:{self._requested_port}") diff --git a/modelexpress_client/python/tests/test_heartbeat.py b/modelexpress_client/python/tests/test_heartbeat.py index 45951be0..9dbc96c4 100644 --- a/modelexpress_client/python/tests/test_heartbeat.py +++ b/modelexpress_client/python/tests/test_heartbeat.py @@ -91,6 +91,59 @@ def test_multiple_ticks_refresh_updated_at(self, heartbeat, mx_client): assert len(ready_calls) >= 2 +class TestHeartbeatPublishAndReady: + def test_first_tick_publishes_then_ready(self, mx_client, nixl_manager): + with patch.dict("os.environ", {"MX_HEARTBEAT_INTERVAL_SECS": "1"}): + hb = HeartbeatThread( + mx_client=mx_client, + worker_id="w1", + worker_rank=0, + nixl_manager=nixl_manager, + publish_fn=lambda: "abc123", + ) + + hb.start() + time.sleep(2.5) + hb.stop() + + assert hb.mx_source_id == "abc123" + ready_calls = [ + c for c in mx_client.update_status.call_args_list + if c == call( + mx_source_id="abc123", + worker_id="w1", + worker_rank=0, + status=2, + ) + ] + assert len(ready_calls) >= 1 + + def test_publish_failure_retries(self, mx_client, nixl_manager): + attempt = {"count": 0} + + def failing_then_ok(): + attempt["count"] += 1 + if attempt["count"] < 3: + raise RuntimeError("server down") + return "abc123" + + with patch.dict("os.environ", {"MX_HEARTBEAT_INTERVAL_SECS": "1"}): + hb = HeartbeatThread( + mx_client=mx_client, + worker_id="w1", + worker_rank=0, + nixl_manager=nixl_manager, + publish_fn=failing_then_ok, + ) + + hb.start() + time.sleep(3.5) + hb.stop() + + assert hb.mx_source_id == "abc123" + assert attempt["count"] == 3 + + class TestHeartbeatStop: def test_stop_marks_stale(self, heartbeat, mx_client): heartbeat.start() diff --git a/modelexpress_client/python/tests/test_vllm_loader.py b/modelexpress_client/python/tests/test_vllm_loader.py index c29b7dc7..a7b610fd 100644 --- a/modelexpress_client/python/tests/test_vllm_loader.py +++ b/modelexpress_client/python/tests/test_vllm_loader.py @@ -887,7 +887,7 @@ def test_rollback_shuts_down_nixl_manager(self): class TestPublishMetadataAndReady: - def test_calls_publish_and_starts_heartbeat(self): + def test_starts_heartbeat_with_publish_fn(self): from modelexpress.metadata.publish import publish_metadata_and_ready mx_client = MagicMock() @@ -910,21 +910,24 @@ def test_calls_publish_and_starts_heartbeat(self): with patch("modelexpress.metadata.publish.HeartbeatThread", return_value=mock_hb) as hb_cls: publish_metadata_and_ready(mx_client, nixl_manager, tensors, worker_rank=2, device_id=0, identity=identity, worker_id="inst-uuid") + mx_client.publish_metadata.assert_not_called() + hb_cls.assert_called_once() + hb_kwargs = hb_cls.call_args.kwargs + assert hb_kwargs["mx_client"] is mx_client + assert hb_kwargs["worker_id"] == "inst-uuid" + assert hb_kwargs["worker_rank"] == 2 + assert hb_kwargs["nixl_manager"] is nixl_manager + assert callable(hb_kwargs["publish_fn"]) + mock_hb.start.assert_called_once() + + result = hb_kwargs["publish_fn"]() + assert result == "abc123def456abcd" mx_client.publish_metadata.assert_called_once() call_args = mx_client.publish_metadata.call_args assert call_args.args[0] is identity assert call_args.args[2] == "inst-uuid" - hb_cls.assert_called_once_with( - mx_client=mx_client, - mx_source_id="abc123def456abcd", - worker_id="inst-uuid", - worker_rank=2, - nixl_manager=nixl_manager, - ) - mock_hb.start.assert_called_once() - - def test_retries_publish_before_starting_heartbeat(self): + def test_publish_fn_retries_publish(self): from modelexpress.metadata.publish import publish_metadata_and_ready mx_client = MagicMock() @@ -951,19 +954,15 @@ def test_retries_publish_before_starting_heartbeat(self): worker_id="w-1", ) + publish_fn = hb_cls.call_args.kwargs["publish_fn"] + result = publish_fn() + + assert result == "abc123def456abcd" assert mx_client.publish_metadata.call_count == 3 assert sleep_mock.call_args_list == [call(1.0), call(2.0)] - hb_cls.assert_called_once_with( - mx_client=mx_client, - mx_source_id="abc123def456abcd", - worker_id="w-1", - worker_rank=0, - nixl_manager=nixl_manager, - ) mock_hb.start.assert_called_once() - def test_publish_failure_after_retries_raises_runtime_error(self): - """If publish_metadata keeps failing, heartbeat should not be started.""" + def test_publish_fn_failure_after_retries_raises_runtime_error(self): from modelexpress.metadata.publish import publish_metadata_and_ready mx_client = MagicMock() @@ -980,22 +979,24 @@ def test_publish_failure_after_retries_raises_runtime_error(self): mock_hb = MagicMock() with patch("modelexpress.metadata.publish.time.sleep") as sleep_mock, \ patch("modelexpress.metadata.publish.HeartbeatThread", return_value=mock_hb) as hb_cls: + publish_metadata_and_ready( + mx_client, + nixl_manager, + {}, + worker_rank=0, + device_id=0, + identity=identity, + worker_id="w-1", + ) + publish_fn = hb_cls.call_args.kwargs["publish_fn"] with pytest.raises(RuntimeError, match="Failed to publish metadata after 3 attempts"): - publish_metadata_and_ready( - mx_client, - nixl_manager, - {}, - worker_rank=0, - device_id=0, - identity=identity, - worker_id="w-1", - ) + publish_fn() assert mx_client.publish_metadata.call_count == 3 assert sleep_mock.call_args_list == [call(1.0), call(2.0)] - hb_cls.assert_not_called() + mock_hb.start.assert_called_once() - def test_non_retryable_grpc_failure_fails_immediately(self): + def test_publish_fn_non_retryable_grpc_failure_fails_immediately(self): from modelexpress.metadata.publish import publish_metadata_and_ready mx_client = MagicMock() @@ -1011,20 +1012,59 @@ def test_non_retryable_grpc_failure_fails_immediately(self): mock_hb = MagicMock() with patch("modelexpress.metadata.publish.time.sleep") as sleep_mock, \ patch("modelexpress.metadata.publish.HeartbeatThread", return_value=mock_hb) as hb_cls: + publish_metadata_and_ready( + mx_client, + nixl_manager, + {}, + worker_rank=0, + device_id=0, + identity=identity, + worker_id="w-1", + ) + publish_fn = hb_cls.call_args.kwargs["publish_fn"] with pytest.raises(_FakeRpcError, match="permission denied"): - publish_metadata_and_ready( - mx_client, - nixl_manager, - {}, - worker_rank=0, - device_id=0, - identity=identity, - worker_id="w-1", - ) + publish_fn() assert mx_client.publish_metadata.call_count == 1 assert sleep_mock.call_args_list == [] - hb_cls.assert_not_called() + mock_hb.start.assert_called_once() + + def test_p2p_mode_starts_grpc_server_before_publish(self): + from modelexpress.metadata.publish import publish_metadata_and_ready + + mx_client = MagicMock() + mx_client.publish_metadata.return_value = "abc123def456abcd" + + nixl_manager = MagicMock() + nixl_manager._listen_port = 5555 + nixl_manager.agent_name = "test-agent" + + mock_grpc_server = MagicMock() + mock_grpc_server.start.return_value = 6555 + mock_hb = MagicMock() + + with patch.dict("os.environ", {"MX_P2P_METADATA": "1", "MX_WORKER_HOST": "10.0.0.1"}), \ + patch("modelexpress.metadata.worker_server.WorkerGrpcServer", return_value=mock_grpc_server) as grpc_cls, \ + patch("modelexpress.metadata.publish.HeartbeatThread", return_value=mock_hb) as hb_cls: + publish_metadata_and_ready( + mx_client, + nixl_manager, + {}, + worker_rank=0, + device_id=0, + identity=_make_identity(), + worker_id="w-1", + ) + + mx_client.publish_metadata.assert_not_called() + grpc_cls.assert_called_once() + mock_grpc_server.start.assert_called_once() + + publish_fn = hb_cls.call_args.kwargs["publish_fn"] + publish_fn() + + mx_client.publish_metadata.assert_called_once() + mock_grpc_server.set_mx_source_id.assert_called_once_with("abc123def456abcd") # ---------------------------------------------------------------------------