Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 78 additions & 12 deletions modelexpress_client/python/modelexpress/metadata/heartbeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,25 @@
import logging
import os
import threading
from typing import TYPE_CHECKING
import time
from typing import TYPE_CHECKING, Callable

if TYPE_CHECKING:
from ..client import MxClient
from ..nixl_transfer import NixlTransferManager

logger = logging.getLogger("modelexpress.metadata.heartbeat")

PUBLISH_TIMEOUT_SECS_DEFAULT = 30 * 60 # 30 minutes


class HeartbeatThread:
"""Background thread that signals source liveness via UpdateStatus RPCs.
"""Background thread that publishes metadata and signals source liveness.

After PublishMetadata(status=INITIALIZING), the source spawns this thread.
Each tick: if NIXL agent is healthy, send UpdateStatus(READY) to refresh
updated_at. If not healthy, skip — the server-side reaper detects the
stale updated_at and marks the worker STALE.
A caller may pass either an existing ``mx_source_id`` or a ``publish_fn``.
With ``publish_fn``, the thread attempts initial metadata publication on
each tick until it succeeds or times out, then transitions to READY
heartbeats.

On clean shutdown (SIGTERM), atexit handler sends UpdateStatus(STALE)
for immediate detection without waiting for the reaper timeout.
Expand All @@ -42,21 +45,43 @@ class HeartbeatThread:
worker_rank: Model-shard rank used for metadata/status keying.
nixl_manager: Optional NIXL transfer manager for agent health checks.
Non-NIXL transports pass None and heartbeat unconditionally.
publish_fn: Optional callback for deferred initial metadata publish.
publish_timeout_secs: Seconds to keep retrying publish before giving up.
"""

def __init__(
self,
mx_client: MxClient,
mx_source_id: str,
worker_id: str,
worker_rank: int,
nixl_manager: NixlTransferManager | None,
mx_source_id: str | None = None,
worker_id: str | None = None,
worker_rank: int | None = None,
nixl_manager: NixlTransferManager | None = None,
publish_fn: Callable[[], str] | None = None,
publish_timeout_secs: int | None = None,
):
if mx_source_id is None and publish_fn is None:
raise ValueError("HeartbeatThread requires mx_source_id or publish_fn")
if worker_id is None or worker_rank is None:
raise ValueError("HeartbeatThread requires worker_id and worker_rank")

self._mx_client = mx_client
self._mx_source_id = mx_source_id
self._worker_id = worker_id
self._worker_rank = worker_rank
self._nixl_manager = nixl_manager
self._publish_fn = publish_fn

if publish_timeout_secs is not None:
self._publish_timeout = publish_timeout_secs
else:
self._publish_timeout = int(
os.environ.get(
"MX_PUBLISH_TIMEOUT_SECS",
str(PUBLISH_TIMEOUT_SECS_DEFAULT),
)
)
self._publish_started_at: float | None = None
self._publish_given_up = False

self._interval = int(
os.environ.get("MX_HEARTBEAT_INTERVAL_SECS", "30")
Expand All @@ -65,6 +90,10 @@ def __init__(
self._started = False
self._thread: threading.Thread | None = None

@property
def mx_source_id(self) -> str | None:
    """Source id currently held by this heartbeat.

    This is the value passed to the constructor, later overwritten by the
    return value of the publish callback once a deferred publish succeeds.
    It is ``None`` while no id has been supplied or published yet.
    """
    return self._mx_source_id

def start(self) -> None:
"""Start the heartbeat background thread."""
self._thread = threading.Thread(
Expand All @@ -76,7 +105,8 @@ def start(self) -> None:
atexit.register(self._on_exit)
logger.info(
f"[Worker {self._worker_rank}] Heartbeat started "
f"(interval={self._interval}s)"
f"(interval={self._interval}s, "
f"publish_timeout={self._publish_timeout}s)"
)

def stop(self) -> None:
Expand Down Expand Up @@ -117,10 +147,46 @@ def _update_status(self, status: int) -> None:
status=status,
)

def _try_publish(self) -> None:
    """Make one deferred-publish attempt and record the source id on success.

    Does nothing when no publish callback was supplied. The first call
    starts the retry clock. Once more than ``_publish_timeout`` seconds
    have elapsed since then, the callback is no longer invoked and a
    single warning is logged. Any exception raised by an attempt is
    logged and swallowed so that a later call can retry.
    """
    publish = self._publish_fn
    if publish is None:
        return

    # The retry window is measured from the first attempt, not from
    # construction of the thread object.
    if self._publish_started_at is None:
        self._publish_started_at = time.monotonic()
    waited = time.monotonic() - self._publish_started_at

    if waited > self._publish_timeout:
        if self._publish_given_up:
            # Already reported; stay quiet on subsequent calls.
            return
        logger.warning(
            f"[Worker {self._worker_rank}] Giving up on metadata publish "
            f"after {waited:.0f}s (timeout={self._publish_timeout}s). "
            f"Worker will continue without P2P serving."
        )
        self._publish_given_up = True
        return

    try:
        self._mx_source_id = publish()
        self._started = True
        logger.info(
            f"[Worker {self._worker_rank}] Metadata published successfully "
            f"(mx_source_id={self._mx_source_id})"
        )
    except Exception as exc:
        # Best-effort: a failed attempt must not kill the caller; the id
        # stays unset so the next call tries again.
        logger.warning(
            f"[Worker {self._worker_rank}] Metadata publish attempt failed "
            f"({waited:.0f}s elapsed, timeout={self._publish_timeout}s), "
            f"will retry next tick: {exc}"
        )
Comment thread
zhengluo-nv marked this conversation as resolved.

def _tick(self) -> None:
"""Single heartbeat tick: check health and send READY if healthy."""
"""Single tick: publish if needed, otherwise heartbeat."""
from .. import p2p_pb2

if self._mx_source_id is None:
self._try_publish()
return

if self._nixl_manager is not None and not self._nixl_manager.is_healthy():
return

Expand Down
82 changes: 42 additions & 40 deletions modelexpress_client/python/modelexpress/metadata/publish.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,38 +116,33 @@ def publish_metadata_and_ready(
identity: "p2p_pb2.SourceIdentity",
worker_id: str,
) -> None:
"""Publish tensor metadata and ready flag to the ModelExpress server."""
"""Prepare metadata publication and start the heartbeat thread.

The heartbeat thread performs the initial PublishMetadata RPC so a
transient MX server failure can be retried after the model has loaded.
"""
logger.info(
f"[Worker {worker_rank}] Publishing {len(tensors)} tensors for model '{identity.model_name}'"
f"[Worker {worker_rank}] Preparing {len(tensors)} tensors for model '{identity.model_name}'"
)

tensor_protos = build_tensor_protos(tensors, device_id, worker_rank)

if _is_p2p_metadata_enabled(mx_client):
from .worker_server import WorkerGrpcServer

if nixl_manager._listen_port is None:
raise RuntimeError(
"P2P metadata exchange requires a NIXL listen port, "
"but the NIXL manager was initialized without one."
)

host = _get_worker_host()

grpc_base = int(os.environ.get("MX_WORKER_GRPC_PORT", "6555"))
worker_grpc_port = grpc_base + device_id

worker = p2p_pb2.WorkerMetadata(
worker_rank=worker_rank,
metadata_endpoint=f"{host}:{nixl_manager._listen_port}",
agent_name=nixl_manager.agent_name,
worker_grpc_endpoint="",
)
mx_source_id = _publish_metadata_to_server(
mx_client=mx_client,
identity=identity,
worker=worker,
worker_id=worker_id,
worker_rank=worker_rank,
)

grpc_server = WorkerGrpcServer(
tensor_protos=tensor_protos,
mx_source_id=mx_source_id,
port=worker_grpc_port,
metadata_endpoint=f"{host}:{nixl_manager._listen_port}",
agent_name=nixl_manager.agent_name,
Expand All @@ -162,41 +157,48 @@ def publish_metadata_and_ready(
agent_name=nixl_manager.agent_name,
worker_grpc_endpoint=f"{host}:{actual_port}",
)
mx_source_id = _publish_metadata_to_server(
mx_client=mx_client,
identity=identity,
worker=worker,
worker_id=worker_id,
worker_rank=worker_rank,
)
logger.info(
f"[Worker {worker_rank}] Published P2P metadata to MX server "
f"(mx_source_id={mx_source_id}, worker_grpc={host}:{actual_port})"
)

def publish_fn() -> str:
mx_source_id = _publish_metadata_to_server(
mx_client=mx_client,
identity=identity,
worker=worker,
worker_id=worker_id,
worker_rank=worker_rank,
)
grpc_server.set_mx_source_id(mx_source_id)
logger.info(
f"[Worker {worker_rank}] Published P2P metadata to MX server "
f"(mx_source_id={mx_source_id}, worker_grpc={host}:{actual_port})"
)
return mx_source_id
else:
worker = p2p_pb2.WorkerMetadata(
worker_rank=worker_rank,
nixl_metadata=nixl_manager.nixl_metadata,
tensors=tensor_protos,
)
mx_source_id = _publish_metadata_to_server(
mx_client=mx_client,
identity=identity,
worker=worker,
worker_id=worker_id,
worker_rank=worker_rank,
)
logger.info(
f"[Worker {worker_rank}] Published metadata to MX server "
f"(mx_source_id={mx_source_id}, worker_id={worker_id})"
)

def publish_fn() -> str:
mx_source_id = _publish_metadata_to_server(
mx_client=mx_client,
identity=identity,
worker=worker,
worker_id=worker_id,
worker_rank=worker_rank,
)
logger.info(
f"[Worker {worker_rank}] Published metadata to MX server "
f"(mx_source_id={mx_source_id}, worker_id={worker_id})"
)
return mx_source_id

heartbeat = HeartbeatThread(
mx_client=mx_client,
mx_source_id=mx_source_id,
worker_id=worker_id,
worker_rank=worker_rank,
nixl_manager=nixl_manager,
publish_fn=publish_fn,
)
heartbeat.start()
_heartbeat_threads[worker_rank] = heartbeat
Expand Down
20 changes: 15 additions & 5 deletions modelexpress_client/python/modelexpress/metadata/worker_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class WorkerServiceServicer(p2p_pb2_grpc.WorkerServiceServicer):
def __init__(
self,
tensor_protos: list[p2p_pb2.TensorDescriptor],
mx_source_id: str,
mx_source_id: str | None = None,
metadata_endpoint: str = "",
agent_name: str = "",
worker_rank: int = 0,
Expand All @@ -40,6 +40,9 @@ def __init__(
self._agent_name = agent_name
self._worker_rank = worker_rank

def set_mx_source_id(self, mx_source_id: str) -> None:
    """Replace the source id stored on this servicer.

    ``GetTensorManifest`` compares incoming requests against this value
    and echoes it back in its response.
    """
    self._mx_source_id = mx_source_id

def GetTensorManifest(self, request, context):
if request.mx_source_id and request.mx_source_id != self._mx_source_id:
context.abort(
Expand All @@ -49,7 +52,7 @@ def GetTensorManifest(self, request, context):
)
return p2p_pb2.GetTensorManifestResponse(
tensors=self._tensor_protos,
mx_source_id=self._mx_source_id,
mx_source_id=self._mx_source_id or "",
metadata_endpoint=self._metadata_endpoint,
agent_name=self._agent_name,
worker_rank=self._worker_rank,
Expand All @@ -62,7 +65,7 @@ class WorkerGrpcServer:
def __init__(
self,
tensor_protos: list[p2p_pb2.TensorDescriptor],
mx_source_id: str,
mx_source_id: str | None = None,
port: int = 0,
metadata_endpoint: str = "",
agent_name: str = "",
Expand All @@ -75,23 +78,30 @@ def __init__(
self._agent_name = agent_name
self._worker_rank = worker_rank
self._server: grpc.Server | None = None
self._servicer: WorkerServiceServicer | None = None
self._port: int | None = None

@property
def port(self) -> int | None:
    """Port recorded by ``start()``; ``None`` while no port has been bound."""
    return self._port

def set_mx_source_id(self, mx_source_id: str) -> None:
    """Store a new source id on the server and forward it to the servicer.

    Raises:
        RuntimeError: If the servicer has not been created yet, i.e.
            ``start()`` has not run.
    """
    servicer = self._servicer
    if servicer is None:
        raise RuntimeError("Server must be started before setting mx_source_id")

    # Keep the server's own copy and the live servicer in agreement.
    self._mx_source_id = mx_source_id
    servicer.set_mx_source_id(mx_source_id)

def start(self) -> int:
"""Start the gRPC server. Returns the actual bound port."""
self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
servicer = WorkerServiceServicer(
self._servicer = WorkerServiceServicer(
tensor_protos=self._tensor_protos,
mx_source_id=self._mx_source_id,
metadata_endpoint=self._metadata_endpoint,
agent_name=self._agent_name,
worker_rank=self._worker_rank,
)
p2p_pb2_grpc.add_WorkerServiceServicer_to_server(servicer, self._server)
p2p_pb2_grpc.add_WorkerServiceServicer_to_server(self._servicer, self._server)

if self._requested_port:
self._port = self._server.add_insecure_port(f"[::]:{self._requested_port}")
Expand Down
Loading
Loading