Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions recipe/simple_use_case/single_controller_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@

import transfer_queue as tq
from transfer_queue import KVBatchMeta
from transfer_queue.utils.logging_utils import get_logger

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

logger = get_logger(__name__)

os.environ["RAY_DEDUP_LOGS"] = "0"
os.environ["RAY_DEBUG"] = "1"
Expand Down
161 changes: 71 additions & 90 deletions transfer_queue/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,52 @@
TQ_CONTROLLER_GET_METADATA_TIMEOUT = int(os.environ.get("TQ_CONTROLLER_GET_METADATA_TIMEOUT", 1))
TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL = int(os.environ.get("TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL", 5))


class ZMQServerTransport:
"""
Unified management of ZMQ Router Sockets, port binding, daemon threads, and message I/O.
"""

def __init__(self, node_ip: str, ctx: zmq.Context | None = None):
self.node_ip = node_ip
self.zmq_ctx = ctx or zmq.Context()
self.sockets: dict[str, zmq.Socket] = {}
self.ports: dict[str, int] = {}
self.threads: list[Thread] = []

def create_router_socket(self, socket_name: str) -> None:
"""Create a ROUTER-type socket, automatically retrying port binding."""
while True:
try:
port = get_free_port(ip=self.node_ip)
sock = create_zmq_socket(
ctx=self.zmq_ctx,
socket_type=zmq.ROUTER,
ip=self.node_ip,
)
sock.bind(format_zmq_address(self.node_ip, port))
self.sockets[socket_name] = sock
self.ports[socket_name] = port
return
except zmq.ZMQError:
logger.warning(f"ZMQ bind {socket_name} failed, retrying...")

def get_socket(self, socket_name: str) -> zmq.Socket:
return self.sockets[socket_name]

def start_daemon_thread(self, target, name: str) -> None:
t = Thread(target=target, name=name, daemon=True)
t.start()
self.threads.append(t)

def build_server_info(self, role: TransferQueueRole, server_id: str) -> ZMQServerInfo:
return ZMQServerInfo(
role=role,
id=server_id,
ip=self.node_ip,
ports=self.ports.copy(),
)

# Sample pre-allocation for StreamingDataLoader compatibility.
# By pre-allocating sample indices (typically global_batch_size), consumers can accurately
# determine consumption status even before producers have generated the samples.
Expand Down Expand Up @@ -384,7 +430,6 @@ def allocated_samples_num(self) -> int:
return self.production_status.shape[0]

# ==================== Index Pre-Allocation Methods ====================

def register_pre_allocated_indexes(self, allocated_indexes: list[int]):
"""
Register pre-allocated sample indexes to this partition.
Expand Down Expand Up @@ -442,7 +487,6 @@ def activate_pre_allocated_indexes(self, sample_num: int) -> list[int]:
return global_index_to_allocate

# ==================== Dynamic Expansion Methods ====================

def ensure_samples_capacity(self, required_samples: int) -> None:
"""
Ensure the production status tensor has enough rows for the required samples.
Expand Down Expand Up @@ -492,7 +536,6 @@ def ensure_fields_capacity(self, required_fields: int) -> None:
logger.debug(f"Expanded partition {self.partition_id} from {current_fields} to {new_fields} fields")

# ==================== Production Status Interface ====================

def update_production_status(
self,
global_indices: list[int],
Expand Down Expand Up @@ -611,7 +654,6 @@ def mark_consumed(self, task_name: str, global_indices: list[int]):
)

# ==================== Consumption Status Interface ====================

def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Tensor, Tensor]:
"""
Get or create consumption status for a specific task.
Expand Down Expand Up @@ -713,7 +755,6 @@ def get_production_status_for_fields(
return partition_global_index, production_status

# ==================== Data Scanning and Query Methods ====================

def scan_data_status(self, field_names: list[str], task_name: str) -> list[int]:
"""
Scan data status to find samples ready for consumption.
Expand Down Expand Up @@ -761,7 +802,6 @@ def scan_data_status(self, field_names: list[str], task_name: str) -> list[int]:
return ready_sample_indices

# ==================== Metadata Methods ====================

def get_field_schema(
self, field_names: list[str], batch_global_indexes: list[int] | None = None
) -> dict[str, dict[str, Any]]:
Expand Down Expand Up @@ -830,7 +870,6 @@ def set_custom_meta(self, custom_meta: dict[int, dict]) -> None:
self.custom_meta.update(custom_meta)

# ==================== Statistics and Monitoring ====================

def get_statistics(self) -> dict[str, Any]:
"""Get detailed statistics for this partition."""
stats = {
Expand Down Expand Up @@ -877,7 +916,6 @@ def get_statistics(self) -> dict[str, Any]:
return stats

# ==================== Serialization ====================

def to_snapshot(self):
"""
Get a snapshot of partition status information.
Expand Down Expand Up @@ -1008,8 +1046,9 @@ def __init__(
self.polling_mode = polling_mode
self.tq_config = None # global config for TransferQueue system

# Initialize ZMQ sockets for communication
self._init_zmq_socket()
# Initialize ZMQ transport layer
self._transport = ZMQServerTransport(node_ip=get_node_ip_address_raw())
self._init_zmq_sockets()

# Partition management
self.partitions: dict[str, DataPartitionStatus] = {} # partition_id -> DataPartitionStatus
Expand All @@ -1021,14 +1060,11 @@ def __init__(
self._connected_storage_managers: set[str] = set()

# Start background processing threads
self._start_process_handshake()
self._start_process_update_data_status()
self._start_process_request()
self._start_daemon_threads()

logger.info(f"TransferQueue Controller {self.controller_id} initialized")

# ==================== Partition Management API ====================

def create_partition(self, partition_id: str) -> bool:
"""
Create a new data partition with pre-allocated sample indexes.
Expand Down Expand Up @@ -1098,7 +1134,6 @@ def list_partitions(self) -> list[str]:
return list(self.partitions.keys())

# ==================== Partition Index Management API ====================

def get_partition_index_range(self, partition_id: str) -> list[int]:
"""
Get all indexes for a specific partition.
Expand All @@ -1114,7 +1149,6 @@ def get_partition_index_range(self, partition_id: str) -> list[int]:
return self.index_manager.get_indexes_for_partition(partition_id)

# ==================== Data Production API ====================

def update_production_status(
self,
partition_id: str,
Expand Down Expand Up @@ -1150,7 +1184,6 @@ def update_production_status(
return success

# ==================== Data Consumption API ====================

def get_consumption_status(self, partition_id: str, task_name: str) -> tuple[Optional[Tensor], Optional[Tensor]]:
"""
Get or create consumption status for a specific task and partition.
Expand Down Expand Up @@ -1398,7 +1431,6 @@ def scan_data_status(
return ready_sample_indices

# ==================== Metadata Generation API ====================

def generate_batch_meta(
self,
partition_id: str,
Expand Down Expand Up @@ -1686,69 +1718,29 @@ def kv_retrieve_keys(

return keys

def _init_zmq_socket(self):
"""Initialize ZMQ sockets for communication."""
self.zmq_context = zmq.Context()
self._node_ip = get_node_ip_address_raw()

while True:
try:
self._handshake_socket_port = get_free_port(ip=self._node_ip)
self._request_handle_socket_port = get_free_port(ip=self._node_ip)
self._data_status_update_socket_port = get_free_port(ip=self._node_ip)

self.handshake_socket = create_zmq_socket(
ctx=self.zmq_context,
socket_type=zmq.ROUTER,
ip=self._node_ip,
)
self.handshake_socket.bind(format_zmq_address(self._node_ip, self._handshake_socket_port))

self.request_handle_socket = create_zmq_socket(
ctx=self.zmq_context,
socket_type=zmq.ROUTER,
ip=self._node_ip,
)
self.request_handle_socket.bind(format_zmq_address(self._node_ip, self._request_handle_socket_port))

self.data_status_update_socket = create_zmq_socket(
ctx=self.zmq_context,
socket_type=zmq.ROUTER,
ip=self._node_ip,
)
self.data_status_update_socket.bind(
format_zmq_address(self._node_ip, self._data_status_update_socket_port)
)

break
except zmq.ZMQError:
logger.warning(f"[{self.controller_id}]: Try to bind ZMQ sockets failed, retrying...")
continue

self.zmq_server_info = ZMQServerInfo(
def _init_zmq_sockets(self):
self._transport.create_router_socket("handshake_socket")
self._transport.create_router_socket("request_handle_socket")
self._transport.create_router_socket("data_status_update_socket")
self.zmq_server_info = self._transport.build_server_info(
role=TransferQueueRole.CONTROLLER,
id=self.controller_id,
ip=self._node_ip,
ports={
"handshake_socket": self._handshake_socket_port,
"request_handle_socket": self._request_handle_socket_port,
"data_status_update_socket": self._data_status_update_socket_port,
},
server_id=self.controller_id,
)

def _wait_connection(self):
"""Wait for storage instances to complete handshake with retransmission support."""
handshake_socket = self._transport.get_socket("handshake_socket")
poller = zmq.Poller()
poller.register(self.handshake_socket, zmq.POLLIN)
poller.register(handshake_socket, zmq.POLLIN)

logger.debug(f"Controller {self.controller_id} started waiting for storage connections...")

while True:
socks = dict(poller.poll(1000))

if self.handshake_socket in socks:
if handshake_socket in socks:
try:
messages = self.handshake_socket.recv_multipart(copy=False)
messages = handshake_socket.recv_multipart(copy=False)
identity = messages.pop(0)
serialized_msg = messages
request_msg = ZMQMessage.deserialize(serialized_msg)
Expand All @@ -1762,7 +1754,7 @@ def _wait_connection(self):
sender_id=self.controller_id,
body={},
).serialize()
self.handshake_socket.send_multipart([identity, *response_msg])
handshake_socket.send_multipart([identity, *response_msg])

# Track new connections
if storage_manager_id not in self._connected_storage_managers:
Expand All @@ -1782,42 +1774,30 @@ def _wait_connection(self):
except Exception as e:
logger.error(f"[{self.controller_id}]: error processing handshake: {e}")

def _start_process_handshake(self):
"""Start the handshake process thread."""
self.wait_connection_thread = Thread(
def _start_daemon_threads(self):
self._transport.start_daemon_thread(
target=self._wait_connection,
name="TransferQueueControllerWaitConnectionThread",
daemon=True,
)
self.wait_connection_thread.start()

def _start_process_update_data_status(self):
"""Start the data status update processing thread."""
self.process_update_data_status_thread = Thread(
self._transport.start_daemon_thread(
target=self._update_data_status,
name="TransferQueueControllerProcessUpdateDataStatusThread",
daemon=True,
)
self.process_update_data_status_thread.start()

def _start_process_request(self):
"""Start the request processing thread."""
self.process_request_thread = Thread(
self._transport.start_daemon_thread(
target=self._process_request,
name="TransferQueueControllerProcessRequestThread",
daemon=True,
)
self.process_request_thread.start()

def _process_request(self):
"""Main request processing loop - adapted for partition-based operations."""

logger.info(f"[{self.controller_id}]: start processing requests...")

request_handle_socket = self._transport.get_socket("request_handle_socket")
perf_monitor = IntervalPerfMonitor(caller_name=self.controller_id)

while True:
messages = self.request_handle_socket.recv_multipart(copy=False)
messages = request_handle_socket.recv_multipart(copy=False)
identity = messages.pop(0)
serialized_msg = messages
request_msg = ZMQMessage.deserialize(serialized_msg)
Expand Down Expand Up @@ -2051,16 +2031,17 @@ def _process_request(self):
body={"partition_info": partition_info, "message": message},
)

self.request_handle_socket.send_multipart([identity, *response_msg.serialize()])
request_handle_socket.send_multipart([identity, *response_msg.serialize()])

def _update_data_status(self):
"""Process data status update messages from storage units - adapted for partitions."""
logger.debug(f"[{self.controller_id}]: start receiving update_data_status requests...")

data_status_update_socket = self._transport.get_socket("data_status_update_socket")
perf_monitor = IntervalPerfMonitor(caller_name=self.controller_id)

while True:
messages = self.data_status_update_socket.recv_multipart(copy=False)
messages = data_status_update_socket.recv_multipart(copy=False)
identity = messages.pop(0)
serialized_msg = messages
request_msg = ZMQMessage.deserialize(serialized_msg)
Expand Down Expand Up @@ -2091,7 +2072,7 @@ def _update_data_status(self):
"success": success,
},
)
self.data_status_update_socket.send_multipart([identity, *response_msg.serialize()])
data_status_update_socket.send_multipart([identity, *response_msg.serialize()])

def get_zmq_server_info(self) -> ZMQServerInfo:
"""Get ZMQ server connection information."""
Expand Down
Loading
Loading