Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 55 additions & 119 deletions transfer_queue/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import numpy as np
import ray
import torch
import zmq
from omegaconf import DictConfig
from torch import Tensor

Expand All @@ -42,9 +41,7 @@
ZMQMessage,
ZMQRequestType,
ZMQServerInfo,
create_zmq_socket,
format_zmq_address,
get_free_port,
ZMQServerTransport,
get_node_ip_address,
)

Expand Down Expand Up @@ -1003,8 +1000,7 @@ def __init__(
self.polling_mode = polling_mode
self.tq_config = None # global config for TransferQueue system

# Initialize ZMQ sockets for communication
self._init_zmq_socket()
self._init_zmq_transport()

# Partition management
self.partitions: dict[str, DataPartitionStatus] = {} # partition_id -> DataPartitionStatus
Expand All @@ -1020,9 +1016,7 @@ def __init__(
self._metrics_endpoint: str = ""

# Start background processing threads
self._start_process_handshake()
self._start_process_update_data_status()
self._start_process_request()
self._start_daemon_threads()

logger.info(f"TransferQueue Controller {self.controller_id} initialized")

Expand Down Expand Up @@ -1355,7 +1349,6 @@ def get_metadata(

elif mode == "force_fetch":
batch_global_indexes = self.index_manager.get_indexes_for_partition(partition_id)
consumed_indexes = []

# Package into metadata
metadata = self.generate_batch_meta(partition_id, batch_global_indexes, data_fields, mode)
Expand Down Expand Up @@ -1676,140 +1669,82 @@ def kv_retrieve_keys(

return keys

def _init_zmq_socket(self):
"""Initialize ZMQ sockets for communication."""
self.zmq_context = zmq.Context()
self._node_ip = get_node_ip_address()

while True:
try:
self._handshake_socket_port = get_free_port(ip=self._node_ip)
self._request_handle_socket_port = get_free_port(ip=self._node_ip)
self._data_status_update_socket_port = get_free_port(ip=self._node_ip)

self.handshake_socket = create_zmq_socket(
ctx=self.zmq_context,
socket_type=zmq.ROUTER,
ip=self._node_ip,
)
self.handshake_socket.bind(format_zmq_address(self._node_ip, self._handshake_socket_port))

self.request_handle_socket = create_zmq_socket(
ctx=self.zmq_context,
socket_type=zmq.ROUTER,
ip=self._node_ip,
)
self.request_handle_socket.bind(format_zmq_address(self._node_ip, self._request_handle_socket_port))

self.data_status_update_socket = create_zmq_socket(
ctx=self.zmq_context,
socket_type=zmq.ROUTER,
ip=self._node_ip,
)
self.data_status_update_socket.bind(
format_zmq_address(self._node_ip, self._data_status_update_socket_port)
)

break
except zmq.ZMQError:
logger.warning(f"[{self.controller_id}]: Try to bind ZMQ sockets failed, retrying...")
continue

self.zmq_server_info = ZMQServerInfo(
def _init_zmq_transport(self):
"""Initialize ZMQ transport layer."""
self._transport = ZMQServerTransport(node_ip=get_node_ip_address())
for socket_name in ("handshake_socket", "request_handle_socket", "data_status_update_socket"):
self._transport.create_router_socket(socket_name)
self.zmq_server_info = self._transport.build_server_info(
role=Role.CONTROLLER,
id=self.controller_id,
ip=self._node_ip,
ports={
"handshake_socket": self._handshake_socket_port,
"request_handle_socket": self._request_handle_socket_port,
"data_status_update_socket": self._data_status_update_socket_port,
},
)

def _wait_connection(self):
"""Wait for storage instances to complete handshake with retransmission support."""
poller = zmq.Poller()
poller.register(self.handshake_socket, zmq.POLLIN)
"""Wait for storage instances to complete handshake."""
handshake_socket = self._transport.get_socket("handshake_socket")

logger.debug(f"Controller {self.controller_id} started waiting for storage connections...")

while True:
socks = dict(poller.poll(1000))

if self.handshake_socket in socks:
try:
messages = self.handshake_socket.recv_multipart(copy=False)
identity = messages.pop(0)
serialized_msg = messages
request_msg = ZMQMessage.deserialize(serialized_msg)
try:
messages = handshake_socket.recv_multipart(copy=False)
identity = messages.pop(0)
serialized_msg = messages
request_msg = ZMQMessage.deserialize(serialized_msg)

if request_msg.request_type == ZMQRequestType.HANDSHAKE:
storage_manager_id = request_msg.sender_id
if request_msg.request_type == ZMQRequestType.HANDSHAKE:
storage_manager_id = request_msg.sender_id

# Always send ACK for HANDSHAKE
response_msg = ZMQMessage.create(
request_type=ZMQRequestType.HANDSHAKE_ACK,
sender_id=self.controller_id,
body={},
).serialize()
self.handshake_socket.send_multipart([identity, *response_msg])

# Track new connections
if storage_manager_id not in self._connected_storage_managers:
self._connected_storage_managers.add(storage_manager_id)
storage_manager_type = request_msg.body.get("storage_manager_type", "Unknown")
logger.debug(
f"[{self.controller_id}]: received handshake from "
f"storage manager {storage_manager_id} (type: {storage_manager_type}). "
f"Total connected: {len(self._connected_storage_managers)}"
)
else:
logger.debug(
f"[{self.controller_id}]: received duplicate handshake from "
f"storage manager {storage_manager_id}. Resending ACK."
)
response_msg = ZMQMessage.create(
request_type=ZMQRequestType.HANDSHAKE_ACK,
sender_id=self.controller_id,
body={},
).serialize()
handshake_socket.send_multipart([identity, *response_msg])

if storage_manager_id not in self._connected_storage_managers:
self._connected_storage_managers.add(storage_manager_id)
storage_manager_type = request_msg.body.get("storage_manager_type", "Unknown")
logger.debug(
f"[{self.controller_id}]: received handshake from "
f"storage manager {storage_manager_id} (type: {storage_manager_type}). "
f"Total connected: {len(self._connected_storage_managers)}"
)
else:
logger.debug(
f"[{self.controller_id}]: received duplicate handshake from "
f"storage manager {storage_manager_id}. Resending ACK."
)

except Exception as e:
logger.error(f"[{self.controller_id}]: error processing handshake: {e}")
except Exception as e:
logger.error(f"[{self.controller_id}]: error processing handshake: {e}")

def _start_process_handshake(self):
"""Start the handshake process thread."""
self.wait_connection_thread = Thread(
def _start_daemon_threads(self):
self._transport.start_daemon_thread(
target=self._wait_connection,
name="TransferQueueControllerWaitConnectionThread",
daemon=True,
name="TQControllerWaitConnectionThread",
)
self.wait_connection_thread.start()

def _start_process_update_data_status(self):
"""Start the data status update processing thread."""
self.process_update_data_status_thread = Thread(
self._transport.start_daemon_thread(
target=self._update_data_status,
name="TransferQueueControllerProcessUpdateDataStatusThread",
daemon=True,
name="TQControllerProcessUpdateDataStatusThread",
)
self.process_update_data_status_thread.start()

def _start_process_request(self):
"""Start the request processing thread."""
self.process_request_thread = Thread(
self._transport.start_daemon_thread(
target=self._process_request,
name="TransferQueueControllerProcessRequestThread",
daemon=True,
name="TQControllerProcessRequestThread",
)
self.process_request_thread.start()

def _process_request(self):
"""Main request processing loop - adapted for partition-based operations."""

logger.info(f"[{self.controller_id}]: start processing requests...")

request_handle_socket = self._transport.get_socket("request_handle_socket")
perf_monitor = IntervalPerfMonitor(caller_name=self.controller_id)

while True:
monitor = self._metrics if self._metrics is not None else perf_monitor

messages = self.request_handle_socket.recv_multipart(copy=False)
messages = request_handle_socket.recv_multipart(copy=False)
identity = messages.pop(0)
serialized_msg = messages
request_msg = ZMQMessage.deserialize(serialized_msg)
Expand Down Expand Up @@ -2045,18 +1980,18 @@ def _process_request(self):
body={"partition_info": partition_info, "message": message},
)

self.request_handle_socket.send_multipart([identity, *response_msg.serialize()])
request_handle_socket.send_multipart([identity, *response_msg.serialize()])

def _update_data_status(self):
"""Process data status update messages from storage units - adapted for partitions."""
logger.debug(f"[{self.controller_id}]: start receiving update_data_status requests...")

data_status_update_socket = self._transport.get_socket("data_status_update_socket")
perf_monitor = IntervalPerfMonitor(caller_name=self.controller_id)

while True:
monitor = self._metrics if self._metrics is not None else perf_monitor

messages = self.data_status_update_socket.recv_multipart(copy=False)
messages = data_status_update_socket.recv_multipart(copy=False)
identity = messages.pop(0)
serialized_msg = messages
request_msg = ZMQMessage.deserialize(serialized_msg)
Expand All @@ -2074,6 +2009,7 @@ def _update_data_status(self):
field_schema=message_data.get("field_schema", {}),
custom_backend_meta=message_data.get("custom_backend_meta", {}),
)

if success:
if self._metrics is not None:
self._metrics.record_samples("NOTIFY_DATA_UPDATE", len(global_indexes))
Expand All @@ -2089,7 +2025,7 @@ def _update_data_status(self):
"success": success,
},
)
self.data_status_update_socket.send_multipart([identity, *response_msg.serialize()])
data_status_update_socket.send_multipart([identity, *response_msg.serialize()])

def get_zmq_server_info(self) -> ZMQServerInfo:
"""Get ZMQ server connection information."""
Expand Down Expand Up @@ -2196,7 +2132,7 @@ def start_metrics(self, port: int = 0) -> str:
from transfer_queue.metrics import TQMetricsExporter

self._metrics = TQMetricsExporter()
self._metrics_endpoint = self._metrics.start(node_ip=self._node_ip, port=port)
self._metrics_endpoint = self._metrics.start(node_ip=get_node_ip_address(), port=port)
# Launch a daemon thread that periodically pushes controller state
# snapshots to the exporter, keeping them process-isolated.
self._metrics_snapshot_thread = Thread(
Expand Down
56 changes: 50 additions & 6 deletions transfer_queue/utils/zmq_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import socket
import threading
import time
from dataclasses import dataclass
from functools import wraps
Expand Down Expand Up @@ -103,9 +104,7 @@ class ZMQRequestType(ExplicitEnum):


class ZMQServerInfo:
"""
TransferQueue server info class.
"""
"""TransferQueue server info class."""

def __init__(self, role: Role, id: str, ip: str, ports: dict[str, int]):
self.role = role
Expand All @@ -132,9 +131,7 @@ def __str__(self) -> str:

@dataclass
class ZMQMessage:
"""
ZMQMessage class for TransferQueue communication.
"""
"""ZMQMessage class for TransferQueue communication."""

request_type: ZMQRequestType
sender_id: str
Expand Down Expand Up @@ -190,6 +187,53 @@ def deserialize(cls, frames: list) -> "ZMQMessage":
)


class ZMQServerTransport:
    """Manage the ROUTER sockets, bound ports and daemon threads of a ZMQ server.

    Sockets and their ports are registered under caller-chosen names so that the
    owning component can look them up later via :meth:`get_socket` and publish
    them through :meth:`build_server_info`.
    """

    def __init__(self, node_ip: str, ctx: zmq.Context | None = None):
        """Create an empty transport bound to ``node_ip``.

        Args:
            node_ip: IP address every socket created by this transport binds to.
            ctx: Existing ZMQ context to reuse; a new one is created when omitted.
        """
        self.node_ip = node_ip
        self.zmq_ctx = ctx if ctx is not None else zmq.Context()
        self.sockets: dict[str, zmq.Socket] = {}  # socket name -> bound socket
        self.ports: dict[str, int] = {}  # socket name -> bound port
        self.threads: list[threading.Thread] = []  # threads started via start_daemon_thread

    def create_router_socket(self, name: str) -> None:
        """Create a ROUTER socket, bind it to a free port and register it as ``name``.

        Binding is retried (with a fresh port and a fresh socket) for as long as
        ZMQ reports an error.

        Args:
            name: Key under which the socket and its port are stored.
        """
        while True:
            sock = None
            try:
                port = get_free_port(ip=self.node_ip)
                sock = create_zmq_socket(
                    ctx=self.zmq_ctx,
                    socket_type=zmq.ROUTER,
                    ip=self.node_ip,
                )
                sock.bind(format_zmq_address(self.node_ip, port))
            except zmq.ZMQError:
                # Release the socket of the failed attempt; otherwise every retry
                # would leave one more unbound socket alive in the context.
                if sock is not None:
                    sock.close(linger=0)
                logger.warning(f"ZMQ bind {name} failed, retrying...")
                continue

            self.sockets[name] = sock
            self.ports[name] = port
            return

    def get_socket(self, name: str) -> zmq.Socket:
        """Return the socket registered under ``name``.

        Raises:
            KeyError: If no socket with that name has been created.
        """
        return self.sockets[name]

    def start_daemon_thread(self, target, name: str) -> None:
        """Start a daemon thread running ``target`` and keep a reference to it.

        Args:
            target: Callable executed by the thread; it is invoked without arguments.
            name: Thread name.
        """
        t = threading.Thread(target=target, name=name, daemon=True)
        t.start()
        self.threads.append(t)

    def build_server_info(self, role: Role, id: str) -> ZMQServerInfo:
        """Build a ``ZMQServerInfo`` describing the sockets registered so far.

        Args:
            role: Role of the server owning this transport.
            id: Identifier of the server owning this transport.

        Returns:
            Server info carrying this transport's IP and a snapshot of its ports.
        """
        return ZMQServerInfo(
            role=role,
            id=id,
            ip=self.node_ip,
            # Pass a copy so sockets registered later do not mutate info objects
            # that have already been handed out.
            ports=dict(self.ports),
        )


def is_ipv6_address(ip: str) -> bool:
"""Check if the given IP address is an IPv6 address."""
try:
Expand Down
Loading