diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 212c52f..9b064c8 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -17,9 +17,7 @@ import logging import os import threading -from functools import wraps from typing import Any, Callable, Optional -from uuid import uuid4 import torch import zmq @@ -38,8 +36,7 @@ ZMQMessage, ZMQRequestType, ZMQServerInfo, - create_zmq_socket, - format_zmq_address, + with_zmq_socket, ) logger = logging.getLogger(__name__) @@ -53,6 +50,13 @@ TQ_NUM_THREADS = int(os.environ.get("TQ_NUM_THREADS", 8)) +# Pre-bound decorator for controller socket operations. +with_controller_socket = with_zmq_socket( + "request_handle_socket", + get_identity=lambda self: self.client_id, + get_peer=lambda self, target: self._controller, +) + class AsyncTransferQueueClient: """Asynchronous client for interacting with TransferQueue controller and storage systems. @@ -99,63 +103,8 @@ def initialize_storage_manager( manager_type, controller_info=self._controller, config=config ) - # TODO (TQStorage): Provide a general dynamic socket function for both Client & Storage @huazhong. - @staticmethod - def dynamic_socket(socket_name: str): - """Decorator to auto-manage ZMQ sockets for Controller/Storage servers. - - Handles socket lifecycle: create -> connect -> inject -> close. - - Args: - socket_name: Port name from server config to use for ZMQ connection (e.g., "data_req_port") - - Decorated Function Requirements: - 1. Must be an async class method (needs `self`) - 2. `self` must have: - - `_controller`: Server registry - - `client_id`: Unique client ID for socket identity - 3. Receives ZMQ socket via `socket` keyword argument (injected by decorator) - """ - - def decorator(func: Callable): - @wraps(func) - async def wrapper(self, *args, **kwargs): - server_info = self._controller - if not server_info: - raise RuntimeError("No controller registered") - - context = zmq.asyncio.Context() - address = format_zmq_address(server_info.ip, server_info.ports.get(socket_name)) - identity = f"{self.client_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode() - sock = create_zmq_socket(context, zmq.DEALER, identity=identity, ip=server_info.ip) - - try: - sock.connect(address) - logger.debug( - f"[{self.client_id}]: Connected to Controller {server_info.id} at {address} " - f"with identity {identity.decode()}" - ) - - kwargs["socket"] = sock - return await func(self, *args, **kwargs) - except Exception as e: - logger.error(f"[{self.client_id}]: Error in socket operation with Controller {server_info.id}: {e}") - raise - finally: - try: - if not sock.closed: - sock.close(linger=-1) - except Exception as e: - logger.warning(f"[{self.client_id}]: Error closing socket to Controller {server_info.id}: {e}") - - context.term() - - return wrapper - - return decorator - # ==================== Basic API ==================== - @dynamic_socket(socket_name="request_handle_socket") + @with_controller_socket async def async_get_meta( self, data_fields: list[str], @@ -245,7 +194,7 @@ async def async_get_meta( f"{response_msg.body.get('message', 'Unknown error')}" ) - @dynamic_socket(socket_name="request_handle_socket") + @with_controller_socket async def async_set_custom_meta( self, metadata: BatchMeta, @@ -545,7 +494,7 @@ async def async_clear_samples(self, metadata: BatchMeta): except Exception as e: raise RuntimeError(f"Error in clear_samples operation: {str(e)}") from e - @dynamic_socket(socket_name="request_handle_socket") + @with_controller_socket async def _clear_meta_in_controller(self, metadata: BatchMeta, socket=None): """Clear metadata in the controller. @@ -571,7 +520,7 @@ async def _clear_meta_in_controller(self, metadata: BatchMeta, socket=None): if response_msg.request_type != ZMQRequestType.CLEAR_META_RESPONSE: raise RuntimeError("Failed to clear samples metadata in controller.") - @dynamic_socket(socket_name="request_handle_socket") + @with_controller_socket async def _get_partition_meta(self, partition_id: str, socket=None) -> BatchMeta: """Get metadata required for the whole partition from controller. @@ -601,7 +550,7 @@ async def _get_partition_meta(self, partition_id: str, socket=None) -> BatchMeta return response_msg.body["metadata"] - @dynamic_socket(socket_name="request_handle_socket") + @with_controller_socket async def _clear_partition_in_controller(self, partition_id, socket=None): """Clear the whole partition in the controller. @@ -628,7 +577,7 @@ async def _clear_partition_in_controller(self, partition_id, socket=None): raise RuntimeError(f"Failed to clear partition {partition_id} in controller.") # ==================== Status Query API ==================== - @dynamic_socket(socket_name="request_handle_socket") + @with_controller_socket async def async_get_consumption_status( self, task_name: str, @@ -691,7 +640,7 @@ async def async_get_consumption_status( except Exception as e: raise RuntimeError(f"[{self.client_id}]: Error in get_consumption_status: {str(e)}") from e - @dynamic_socket(socket_name="request_handle_socket") + @with_controller_socket async def async_get_production_status( self, data_fields: list[str], @@ -823,7 +772,7 @@ async def async_check_production_status( return False return torch.all(production_status == 1).item() - @dynamic_socket(socket_name="request_handle_socket") + @with_controller_socket async def async_reset_consumption( self, partition_id: str, @@ -885,7 +834,7 @@ async def async_reset_consumption( except Exception as e: raise RuntimeError(f"[{self.client_id}]: Error in reset_consumption: {str(e)}") from e - @dynamic_socket(socket_name="request_handle_socket") + @with_controller_socket async def async_get_partition_list( self, socket: Optional[zmq.asyncio.Socket] = None, @@ -931,7 +880,7 @@ async def async_get_partition_list( raise RuntimeError(f"[{self.client_id}]: Error in get_partition_list: {str(e)}") from e # ==================== KV Interface API ==================== - @dynamic_socket(socket_name="request_handle_socket") + @with_controller_socket async def async_kv_retrieve_meta( self, keys: list[str] | str, @@ -997,7 +946,7 @@ async def async_kv_retrieve_meta( except Exception as e: raise RuntimeError(f"[{self.client_id}]: Error in kv_retrieve_keys: {str(e)}") from e - @dynamic_socket(socket_name="request_handle_socket") + @with_controller_socket async def async_kv_retrieve_keys( self, global_indexes: list[int] | int, @@ -1060,7 +1009,7 @@ async def async_kv_retrieve_keys( except Exception as e: raise RuntimeError(f"[{self.client_id}]: Error in kv_retrieve_indexes: {str(e)}") from e - @dynamic_socket(socket_name="request_handle_socket") + @with_controller_socket async def async_kv_list( self, partition_id: Optional[str] = None, diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 99c6c7d..005ca9a 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -19,10 +19,8 @@ import warnings from collections import defaultdict from collections.abc import Mapping -from functools import wraps from operator import itemgetter -from typing import Any, Callable, NamedTuple, Optional -from uuid import uuid4 +from typing import Any, Callable, NamedTuple import torch import zmq @@ -36,8 +34,7 @@ ZMQMessage, ZMQRequestType, ZMQServerInfo, - create_zmq_socket, - format_zmq_address, + with_zmq_socket, ) logger = logging.getLogger(__name__) @@ -51,6 +48,15 @@ TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT", 200)) # seconds +# Pre-bound decorator for storage-unit socket operations. +with_storage_unit_socket = with_zmq_socket( + "put_get_socket", + get_identity=lambda self: self.storage_manager_id, + get_peer=lambda self, target: self.storage_unit_infos[target], + resolve_target=lambda args, kwargs: kwargs.get("target_storage_unit"), + timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT, +) + class RoutingGroup(NamedTuple): """Routing result for a single storage unit.""" @@ -114,78 +120,6 @@ def _register_servers(self, server_infos: "ZMQServerInfo | dict[Any, ZMQServerIn return server_infos_transform - # TODO (TQStorage): Provide a general dynamic socket function for both Client & Storage @huazhong. - @staticmethod - def dynamic_storage_manager_socket(socket_name: str, timeout: int): - """Decorator to auto-manage ZMQ sockets for Controller/Storage servers (create -> connect -> inject -> close). - - Args: - socket_name (str): Port name (from server config) to use for ZMQ connection (e.g., "data_req_port"). - timeout (float): Timeout in seconds for ZMQ connection (in seconds). - - Decorated Function Rules: - 1. Must be an async class method (needs `self`). - 2. `self` requires: - - `storage_unit_infos: storage unit infos (ZMQServerInfo | dict[Any, ZMQServerInfo]). - 3. Specify target server via: - - `target_storage_unit` arg. - 4. Receives ZMQ socket via `socket` keyword arg (injected by decorator). - """ - - def decorator(func: Callable): - @wraps(func) - async def wrapper(self, *args, **kwargs): - server_key = kwargs.get("target_storage_unit") - if server_key is None: - for arg in args: - if isinstance(arg, str) and arg in self.storage_unit_infos.keys(): - server_key = arg - break - - server_info = self.storage_unit_infos.get(server_key) - - if not server_info: - raise RuntimeError(f"Server {server_key} not found in registered servers") - - context = zmq.asyncio.Context() - address = format_zmq_address(server_info.ip, server_info.ports.get(socket_name)) - identity = f"{self.storage_manager_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode() - sock = create_zmq_socket(context, zmq.DEALER, server_info.ip, identity) - - try: - sock.connect(address) - # Timeouts to avoid indefinite await on recv/send - sock.setsockopt(zmq.RCVTIMEO, timeout * 1000) - sock.setsockopt(zmq.SNDTIMEO, timeout * 1000) - logger.debug( - f"[{self.storage_manager_id}]: Connected to StorageUnit {server_info.id} at {address} " - f"with identity {identity.decode()}" - ) - - kwargs["socket"] = sock - return await func(self, *args, **kwargs) - except Exception as e: - logger.error( - f"[{self.storage_manager_id}]: Error in socket operation with " - f"StorageUnit {server_info.id} at {address}: " - f"{type(e).__name__}: {e}" - ) - raise - finally: - try: - if not sock.closed: - sock.close(linger=-1) - except Exception as e: - logger.warning( - f"[{self.storage_manager_id}]: Error closing socket to StorageUnit {server_info.id}: {e}" - ) - - context.term() - - return wrapper - - return decorator - def _group_by_hash(self, global_indexes: list[int]) -> dict[str, RoutingGroup]: """Group samples by global_idx % num_su, return {storage_id: RoutingGroup}. @@ -286,7 +220,7 @@ def _select_by_positions(field_data, positions: list[int]): return field_data[positions] async def put_data( - self, data: TensorDict, metadata: BatchMeta, data_parser: Optional[Callable[[Any], Any]] = None + self, data: TensorDict, metadata: BatchMeta, data_parser: Callable[[Any], Any] | None = None ) -> None: """ Send data to remote StorageUnit based on metadata. @@ -347,13 +281,13 @@ async def put_data( field_schema, ) - @dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT) + @with_storage_unit_socket async def _put_to_single_storage_unit( self, global_indexes: list[int], storage_data: dict[str, Any], target_storage_unit: str, - data_parser: Optional[Callable[[Any], Any]] = None, + data_parser: Callable[[Any], Any] | None = None, socket: zmq.Socket = None, ): """ @@ -483,7 +417,7 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict: return TensorDict(tensor_data, batch_size=len(metadata)) - @dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT) + @with_storage_unit_socket async def _get_from_single_storage_unit( self, global_indexes: list[int], @@ -555,7 +489,7 @@ async def clear_data(self, metadata: BatchMeta) -> None: if isinstance(result, Exception): logger.error(f"[{self.storage_manager_id}]: Error in clear operation task {i}: {result}") - @dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT) + @with_storage_unit_socket async def _clear_single_storage_unit(self, global_indexes, target_storage_unit=None, socket=None): try: request_msg = ZMQMessage.create( diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index 8afbb48..2858e28 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -18,12 +18,14 @@ import socket import time from dataclasses import dataclass -from typing import Any, Optional, TypeAlias +from functools import wraps +from typing import Any, Callable, Optional, TypeAlias from uuid import uuid4 import psutil import ray import zmq +import zmq.asyncio from ray.util import get_node_ip_address from transfer_queue.utils.enum_utils import ExplicitEnum, TransferQueueRole @@ -301,6 +303,80 @@ def create_zmq_socket( return socket +def with_zmq_socket( + socket_name: str, + *, + get_identity: Callable[[Any], str], + get_peer: Callable[[Any, Optional[str]], ZMQServerInfo], + resolve_target: Optional[Callable[[tuple, dict], Optional[str]]] = None, + timeout: Optional[int] = None, +): + """Create a reusable async decorator for request sockets. + + This decorator encapsulates the common socket lifecycle used by both + client-side and storage-manager-side request paths: + create context/socket -> connect -> inject socket -> close/term. + + Args: + socket_name: Socket port key in ``ZMQServerInfo.ports``. + get_identity: Callable that extracts owner identity from ``self``. + Example: ``lambda self: self.client_id`` + get_peer: Callable that returns ``ZMQServerInfo`` for the target. + For single-target scenarios, ignore the target parameter. + Example: ``lambda self, target: self.server_info`` + Example: ``lambda self, target: self.storage_unit_infos[target]`` + resolve_target: Optional callable that extracts target identifier from + function arguments. Receives (args, kwargs) and returns target name. + Example: ``lambda args, kwargs: kwargs.get("target_storage_unit")`` + timeout: Optional timeout (seconds) for both send/recv operations. + """ + + def decorator(func: Callable): + @wraps(func) + async def wrapper(self, *args, **kwargs): + owner_id = get_identity(self) + if owner_id is None: + raise RuntimeError("get_identity returned None") + + target_name: Optional[str] = None + if resolve_target is not None: + target_name = resolve_target(args, kwargs) + + server_info = get_peer(self, target_name) + if server_info is None: + raise RuntimeError(f"get_peer returned None for target '{target_name}'") + + port = server_info.ports.get(socket_name) + if port is None: + raise RuntimeError(f"Socket '{socket_name}' not configured for server '{server_info.id}'") + + context = zmq.asyncio.Context() + sock = None + try: + address = format_zmq_address(server_info.ip, port) + identity = f"{owner_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode() + sock = create_zmq_socket(context, zmq.DEALER, server_info.ip, identity=identity) + sock.connect(address) + if timeout is not None: + sock.setsockopt(zmq.RCVTIMEO, timeout * 1000) + sock.setsockopt(zmq.SNDTIMEO, timeout * 1000) + kwargs["socket"] = sock + return await func(self, *args, **kwargs) + finally: + if sock is not None: + try: + if not sock.closed: + sock.close(linger=-1) + finally: + context.term() + else: + context.term() + + return wrapper + + return decorator + + def process_zmq_server_info( handlers: dict[Any, Any] | Any, ): # noqa: UP007