From 942391adaa0133fc15b566b58935a3df4157782a Mon Sep 17 00:00:00 2001 From: ji-huazhong Date: Sun, 29 Mar 2026 20:57:49 +0800 Subject: [PATCH 1/5] unify dynamic ZMQ socket decorator between simple_backend_manager and client Signed-off-by: ji-huazhong --- transfer_queue/client.py | 93 ++++-------------- .../managers/simple_backend_manager.py | 94 +++--------------- transfer_queue/utils/zmq_utils.py | 95 ++++++++++++++++++- 3 files changed, 129 insertions(+), 153 deletions(-) diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 212c52f1..cee7f2e9 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -17,9 +17,7 @@ import logging import os import threading -from functools import wraps -from typing import Any, Callable, Optional -from uuid import uuid4 +from typing import Any, Optional import torch import zmq @@ -38,8 +36,7 @@ ZMQMessage, ZMQRequestType, ZMQServerInfo, - create_zmq_socket, - format_zmq_address, + dynamic_zmq_socket, ) logger = logging.getLogger(__name__) @@ -53,6 +50,13 @@ TQ_NUM_THREADS = int(os.environ.get("TQ_NUM_THREADS", 8)) +# Pre-bound decorator for controller socket operations. +_controller_socket = dynamic_zmq_socket( + "request_handle_socket", + owner_id_attr="client_id", + server_attr="_controller", +) + class AsyncTransferQueueClient: """Asynchronous client for interacting with TransferQueue controller and storage systems. @@ -99,63 +103,8 @@ def initialize_storage_manager( manager_type, controller_info=self._controller, config=config ) - # TODO (TQStorage): Provide a general dynamic socket function for both Client & Storage @huazhong. - @staticmethod - def dynamic_socket(socket_name: str): - """Decorator to auto-manage ZMQ sockets for Controller/Storage servers. - - Handles socket lifecycle: create -> connect -> inject -> close. - - Args: - socket_name: Port name from server config to use for ZMQ connection (e.g., "data_req_port") - - Decorated Function Requirements: - 1. Must be an async class method (needs `self`) - 2. `self` must have: - - `_controller`: Server registry - - `client_id`: Unique client ID for socket identity - 3. Receives ZMQ socket via `socket` keyword argument (injected by decorator) - """ - - def decorator(func: Callable): - @wraps(func) - async def wrapper(self, *args, **kwargs): - server_info = self._controller - if not server_info: - raise RuntimeError("No controller registered") - - context = zmq.asyncio.Context() - address = format_zmq_address(server_info.ip, server_info.ports.get(socket_name)) - identity = f"{self.client_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode() - sock = create_zmq_socket(context, zmq.DEALER, identity=identity, ip=server_info.ip) - - try: - sock.connect(address) - logger.debug( - f"[{self.client_id}]: Connected to Controller {server_info.id} at {address} " - f"with identity {identity.decode()}" - ) - - kwargs["socket"] = sock - return await func(self, *args, **kwargs) - except Exception as e: - logger.error(f"[{self.client_id}]: Error in socket operation with Controller {server_info.id}: {e}") - raise - finally: - try: - if not sock.closed: - sock.close(linger=-1) - except Exception as e: - logger.warning(f"[{self.client_id}]: Error closing socket to Controller {server_info.id}: {e}") - - context.term() - - return wrapper - - return decorator - # ==================== Basic API ==================== - @dynamic_socket(socket_name="request_handle_socket") + @_controller_socket async def async_get_meta( self, data_fields: list[str], @@ -245,7 +194,7 @@ async def async_get_meta( f"{response_msg.body.get('message', 'Unknown error')}" ) - @dynamic_socket(socket_name="request_handle_socket") + @_controller_socket async def async_set_custom_meta( self, metadata: BatchMeta, @@ -545,7 +494,7 @@ async def async_clear_samples(self, metadata: BatchMeta): except Exception as e: raise RuntimeError(f"Error in clear_samples operation: {str(e)}") from e - @dynamic_socket(socket_name="request_handle_socket") + @_controller_socket async def _clear_meta_in_controller(self, metadata: BatchMeta, socket=None): """Clear metadata in the controller. @@ -571,7 +520,7 @@ async def _clear_meta_in_controller(self, metadata: BatchMeta, socket=None): if response_msg.request_type != ZMQRequestType.CLEAR_META_RESPONSE: raise RuntimeError("Failed to clear samples metadata in controller.") - @dynamic_socket(socket_name="request_handle_socket") + @_controller_socket async def _get_partition_meta(self, partition_id: str, socket=None) -> BatchMeta: """Get metadata required for the whole partition from controller. @@ -601,7 +550,7 @@ async def _get_partition_meta(self, partition_id: str, socket=None) -> BatchMeta return response_msg.body["metadata"] - @dynamic_socket(socket_name="request_handle_socket") + @_controller_socket async def _clear_partition_in_controller(self, partition_id, socket=None): """Clear the whole partition in the controller. @@ -628,7 +577,7 @@ async def _clear_partition_in_controller(self, partition_id, socket=None): raise RuntimeError(f"Failed to clear partition {partition_id} in controller.") # ==================== Status Query API ==================== - @dynamic_socket(socket_name="request_handle_socket") + @_controller_socket async def async_get_consumption_status( self, task_name: str, @@ -691,7 +640,7 @@ async def async_get_consumption_status( except Exception as e: raise RuntimeError(f"[{self.client_id}]: Error in get_consumption_status: {str(e)}") from e - @dynamic_socket(socket_name="request_handle_socket") + @_controller_socket async def async_get_production_status( self, data_fields: list[str], @@ -823,7 +772,7 @@ async def async_check_production_status( return False return torch.all(production_status == 1).item() - @dynamic_socket(socket_name="request_handle_socket") + @_controller_socket async def async_reset_consumption( self, partition_id: str, @@ -885,7 +834,7 @@ async def async_reset_consumption( except Exception as e: raise RuntimeError(f"[{self.client_id}]: Error in reset_consumption: {str(e)}") from e - @dynamic_socket(socket_name="request_handle_socket") + @_controller_socket async def async_get_partition_list( self, socket: Optional[zmq.asyncio.Socket] = None, @@ -931,7 +880,7 @@ async def async_get_partition_list( raise RuntimeError(f"[{self.client_id}]: Error in get_partition_list: {str(e)}") from e # ==================== KV Interface API ==================== - @dynamic_socket(socket_name="request_handle_socket") + @_controller_socket async def async_kv_retrieve_meta( self, keys: list[str] | str, @@ -997,7 +946,7 @@ async def async_kv_retrieve_meta( except Exception as e: raise RuntimeError(f"[{self.client_id}]: Error in kv_retrieve_keys: {str(e)}") from e - @dynamic_socket(socket_name="request_handle_socket") + @_controller_socket async def async_kv_retrieve_keys( self, global_indexes: list[int] | int, @@ -1060,7 +1009,7 @@ async def async_kv_retrieve_keys( except Exception as e: raise RuntimeError(f"[{self.client_id}]: Error in kv_retrieve_indexes: {str(e)}") from e - @dynamic_socket(socket_name="request_handle_socket") + @_controller_socket async def async_kv_list( self, partition_id: Optional[str] = None, diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 99c6c7d2..3e7cf852 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -19,10 +19,8 @@ import warnings from collections import defaultdict from collections.abc import Mapping -from functools import wraps from operator import itemgetter -from typing import Any, Callable, NamedTuple, Optional -from uuid import uuid4 +from typing import Any, NamedTuple import torch import zmq @@ -36,8 +34,7 @@ ZMQMessage, ZMQRequestType, ZMQServerInfo, - create_zmq_socket, - format_zmq_address, + dynamic_zmq_socket, ) logger = logging.getLogger(__name__) @@ -51,6 +48,15 @@ TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT", 200)) # seconds +# Pre-bound decorator for storage-unit socket operations. +_storage_unit_socket = dynamic_zmq_socket( + "put_get_socket", + owner_id_attr="storage_manager_id", + server_attr="storage_unit_infos", + target_kwarg="target_storage_unit", + timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT, +) + class RoutingGroup(NamedTuple): """Routing result for a single storage unit.""" @@ -114,78 +120,6 @@ def _register_servers(self, server_infos: "ZMQServerInfo | dict[Any, ZMQServerIn return server_infos_transform - # TODO (TQStorage): Provide a general dynamic socket function for both Client & Storage @huazhong. - @staticmethod - def dynamic_storage_manager_socket(socket_name: str, timeout: int): - """Decorator to auto-manage ZMQ sockets for Controller/Storage servers (create -> connect -> inject -> close). - - Args: - socket_name (str): Port name (from server config) to use for ZMQ connection (e.g., "data_req_port"). - timeout (float): Timeout in seconds for ZMQ connection (in seconds). - - Decorated Function Rules: - 1. Must be an async class method (needs `self`). - 2. `self` requires: - - `storage_unit_infos: storage unit infos (ZMQServerInfo | dict[Any, ZMQServerInfo]). - 3. Specify target server via: - - `target_storage_unit` arg. - 4. Receives ZMQ socket via `socket` keyword arg (injected by decorator). - """ - - def decorator(func: Callable): - @wraps(func) - async def wrapper(self, *args, **kwargs): - server_key = kwargs.get("target_storage_unit") - if server_key is None: - for arg in args: - if isinstance(arg, str) and arg in self.storage_unit_infos.keys(): - server_key = arg - break - - server_info = self.storage_unit_infos.get(server_key) - - if not server_info: - raise RuntimeError(f"Server {server_key} not found in registered servers") - - context = zmq.asyncio.Context() - address = format_zmq_address(server_info.ip, server_info.ports.get(socket_name)) - identity = f"{self.storage_manager_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode() - sock = create_zmq_socket(context, zmq.DEALER, server_info.ip, identity) - - try: - sock.connect(address) - # Timeouts to avoid indefinite await on recv/send - sock.setsockopt(zmq.RCVTIMEO, timeout * 1000) - sock.setsockopt(zmq.SNDTIMEO, timeout * 1000) - logger.debug( - f"[{self.storage_manager_id}]: Connected to StorageUnit {server_info.id} at {address} " - f"with identity {identity.decode()}" - ) - - kwargs["socket"] = sock - return await func(self, *args, **kwargs) - except Exception as e: - logger.error( - f"[{self.storage_manager_id}]: Error in socket operation with " - f"StorageUnit {server_info.id} at {address}: " - f"{type(e).__name__}: {e}" - ) - raise - finally: - try: - if not sock.closed: - sock.close(linger=-1) - except Exception as e: - logger.warning( - f"[{self.storage_manager_id}]: Error closing socket to StorageUnit {server_info.id}: {e}" - ) - - context.term() - - return wrapper - - return decorator - def _group_by_hash(self, global_indexes: list[int]) -> dict[str, RoutingGroup]: """Group samples by global_idx % num_su, return {storage_id: RoutingGroup}. @@ -347,7 +281,7 @@ async def put_data( field_schema, ) - @dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT) + @_storage_unit_socket async def _put_to_single_storage_unit( self, global_indexes: list[int], @@ -483,7 +417,7 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict: return TensorDict(tensor_data, batch_size=len(metadata)) - @dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT) + @_storage_unit_socket async def _get_from_single_storage_unit( self, global_indexes: list[int], @@ -555,7 +489,7 @@ async def clear_data(self, metadata: BatchMeta) -> None: if isinstance(result, Exception): logger.error(f"[{self.storage_manager_id}]: Error in clear operation task {i}: {result}") - @dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT) + @_storage_unit_socket async def _clear_single_storage_unit(self, global_indexes, target_storage_unit=None, socket=None): try: request_msg = ZMQMessage.create( diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index 8afbb480..3b276fef 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -18,12 +18,14 @@ import socket import time from dataclasses import dataclass -from typing import Any, Optional, TypeAlias +from functools import wraps +from typing import Any, Callable, Mapping, Optional, TypeAlias from uuid import uuid4 import psutil import ray import zmq +import zmq.asyncio from ray.util import get_node_ip_address from transfer_queue.utils.enum_utils import ExplicitEnum, TransferQueueRole @@ -301,6 +303,97 @@ def create_zmq_socket( return socket +def dynamic_zmq_socket( + socket_name: str, + *, + owner_id_attr: str, + server_attr: str, + target_kwarg: Optional[str] = None, + timeout: Optional[int] = None, +): + """Create a reusable async decorator for request sockets. + + This decorator encapsulates the common socket lifecycle used by both + client-side and storage-manager-side request paths: + create context/socket -> connect -> inject socket -> close/term. + + Args: + socket_name: Socket port key in ``ZMQServerInfo.ports``. + owner_id_attr: Attribute name on ``self`` used in identity/log prefix + (e.g., ``client_id`` or ``storage_manager_id``). + server_attr: Attribute name on ``self`` that stores server info. + - ``ZMQServerInfo`` for single-target calls. + - ``Mapping[str, ZMQServerInfo]`` for multi-target calls. + target_kwarg: Optional kwarg name that provides target server id when + ``server_attr`` is a mapping. + timeout: Optional timeout (seconds) for both send/recv operations. + """ + + def decorator(func: Callable): + @wraps(func) + async def wrapper(self, *args, **kwargs): + owner_id = getattr(self, owner_id_attr, None) + if owner_id is None: + raise RuntimeError(f"Missing owner id attribute: {owner_id_attr}") + + server_obj = getattr(self, server_attr, None) + if server_obj is None: + raise RuntimeError(f"Missing server registry attribute: {server_attr}") + + target_name: Optional[str] = None + if target_kwarg is not None: + target_name = kwargs.get(target_kwarg) + if target_name is None: + for arg in args: + if isinstance(arg, str): + target_name = arg + break + + if isinstance(server_obj, ZMQServerInfo): + if target_name is not None and target_name != server_obj.id: + raise RuntimeError( + f"Target mismatch: target '{target_name}' does not match registered server '{server_obj.id}'" + ) + server_info = server_obj + elif isinstance(server_obj, Mapping): + if target_name is None: + raise RuntimeError(f"Missing target server identifier via '{target_kwarg}'") + server_info = server_obj.get(target_name) + if server_info is None: + raise RuntimeError(f"Server '{target_name}' not found in registered servers") + else: + raise RuntimeError( + f"Unsupported server registry type for '{server_attr}': {type(server_obj).__name__}" + ) + + port = server_info.ports.get(socket_name) + if port is None: + raise RuntimeError(f"Socket '{socket_name}' not configured for server '{server_info.id}'") + + context = zmq.asyncio.Context() + address = format_zmq_address(server_info.ip, port) + identity = f"{owner_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode() + sock = create_zmq_socket(context, zmq.DEALER, server_info.ip, identity=identity) + + try: + sock.connect(address) + if timeout is not None: + sock.setsockopt(zmq.RCVTIMEO, timeout * 1000) + sock.setsockopt(zmq.SNDTIMEO, timeout * 1000) + kwargs["socket"] = sock + return await func(self, *args, **kwargs) + finally: + try: + if not sock.closed: + sock.close(linger=-1) + finally: + context.term() + + return wrapper + + return decorator + + def process_zmq_server_info( handlers: dict[Any, Any] | Any, ): # noqa: UP007 From 736680f39b8b84e6a4f181562a1d71362a391ac8 Mon Sep 17 00:00:00 2001 From: ji-huazhong Date: Sun, 29 Mar 2026 21:05:17 +0800 Subject: [PATCH 2/5] fix Signed-off-by: ji-huazhong --- transfer_queue/utils/zmq_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index 3b276fef..f7d5ad16 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -362,9 +362,7 @@ async def wrapper(self, *args, **kwargs): if server_info is None: raise RuntimeError(f"Server '{target_name}' not found in registered servers") else: - raise RuntimeError( - f"Unsupported server registry type for '{server_attr}': {type(server_obj).__name__}" - ) + raise RuntimeError(f"Unsupported server registry type for '{server_attr}': {type(server_obj).__name__}") port = server_info.ports.get(socket_name) if port is None: From 20a5ece31e81358c4bf5b55f2b1b61a6afcffe2d Mon Sep 17 00:00:00 2001 From: ji-huazhong Date: Sun, 29 Mar 2026 21:46:37 +0800 Subject: [PATCH 3/5] apply review suggestion from Copilot Signed-off-by: ji-huazhong --- transfer_queue/utils/zmq_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index f7d5ad16..17d749c6 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -17,9 +17,10 @@ import os import socket import time +from collections.abc import Mapping from dataclasses import dataclass from functools import wraps -from typing import Any, Callable, Mapping, Optional, TypeAlias +from typing import Any, Callable, Optional, TypeAlias from uuid import uuid4 import psutil From 759aa40f95825f658980abababa507d70da944f2 Mon Sep 17 00:00:00 2001 From: ji-huazhong Date: Mon, 27 Apr 2026 11:32:30 +0800 Subject: [PATCH 4/5] update Signed-off-by: ji-huazhong --- transfer_queue/client.py | 34 +++++------ .../managers/simple_backend_manager.py | 22 +++---- transfer_queue/utils/zmq_utils.py | 61 +++++++------------ 3 files changed, 49 insertions(+), 68 deletions(-) diff --git a/transfer_queue/client.py b/transfer_queue/client.py index cee7f2e9..9b064c84 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -17,7 +17,7 @@ import logging import os import threading -from typing import Any, Optional +from typing import Any, Callable, Optional import torch import zmq @@ -36,7 +36,7 @@ ZMQMessage, ZMQRequestType, ZMQServerInfo, - dynamic_zmq_socket, + with_zmq_socket, ) logger = logging.getLogger(__name__) @@ -51,10 +51,10 @@ TQ_NUM_THREADS = int(os.environ.get("TQ_NUM_THREADS", 8)) # Pre-bound decorator for controller socket operations. -_controller_socket = dynamic_zmq_socket( +with_controller_socket = with_zmq_socket( "request_handle_socket", - owner_id_attr="client_id", - server_attr="_controller", + get_identity=lambda self: self.client_id, + get_peer=lambda self, target: self._controller, ) @@ -104,7 +104,7 @@ def initialize_storage_manager( ) # ==================== Basic API ==================== - @_controller_socket + @with_controller_socket async def async_get_meta( self, data_fields: list[str], @@ -194,7 +194,7 @@ async def async_get_meta( f"{response_msg.body.get('message', 'Unknown error')}" ) - @_controller_socket + @with_controller_socket async def async_set_custom_meta( self, metadata: BatchMeta, @@ -494,7 +494,7 @@ async def async_clear_samples(self, metadata: BatchMeta): except Exception as e: raise RuntimeError(f"Error in clear_samples operation: {str(e)}") from e - @_controller_socket + @with_controller_socket async def _clear_meta_in_controller(self, metadata: BatchMeta, socket=None): """Clear metadata in the controller. @@ -520,7 +520,7 @@ async def _clear_meta_in_controller(self, metadata: BatchMeta, socket=None): if response_msg.request_type != ZMQRequestType.CLEAR_META_RESPONSE: raise RuntimeError("Failed to clear samples metadata in controller.") - @_controller_socket + @with_controller_socket async def _get_partition_meta(self, partition_id: str, socket=None) -> BatchMeta: """Get metadata required for the whole partition from controller. @@ -550,7 +550,7 @@ async def _get_partition_meta(self, partition_id: str, socket=None) -> BatchMeta return response_msg.body["metadata"] - @_controller_socket + @with_controller_socket async def _clear_partition_in_controller(self, partition_id, socket=None): """Clear the whole partition in the controller. @@ -577,7 +577,7 @@ async def _clear_partition_in_controller(self, partition_id, socket=None): raise RuntimeError(f"Failed to clear partition {partition_id} in controller.") # ==================== Status Query API ==================== - @_controller_socket + @with_controller_socket async def async_get_consumption_status( self, task_name: str, @@ -640,7 +640,7 @@ async def async_get_consumption_status( except Exception as e: raise RuntimeError(f"[{self.client_id}]: Error in get_consumption_status: {str(e)}") from e - @_controller_socket + @with_controller_socket async def async_get_production_status( self, data_fields: list[str], @@ -772,7 +772,7 @@ async def async_check_production_status( return False return torch.all(production_status == 1).item() - @_controller_socket + @with_controller_socket async def async_reset_consumption( self, partition_id: str, @@ -834,7 +834,7 @@ async def async_reset_consumption( except Exception as e: raise RuntimeError(f"[{self.client_id}]: Error in reset_consumption: {str(e)}") from e - @_controller_socket + @with_controller_socket async def async_get_partition_list( self, socket: Optional[zmq.asyncio.Socket] = None, @@ -880,7 +880,7 @@ async def async_get_partition_list( raise RuntimeError(f"[{self.client_id}]: Error in get_partition_list: {str(e)}") from e # ==================== KV Interface API ==================== - @_controller_socket + @with_controller_socket async def async_kv_retrieve_meta( self, keys: list[str] | str, @@ -946,7 +946,7 @@ async def async_kv_retrieve_meta( except Exception as e: raise RuntimeError(f"[{self.client_id}]: Error in kv_retrieve_keys: {str(e)}") from e - @_controller_socket + @with_controller_socket async def async_kv_retrieve_keys( self, global_indexes: list[int] | int, @@ -1009,7 +1009,7 @@ async def async_kv_retrieve_keys( except Exception as e: raise RuntimeError(f"[{self.client_id}]: Error in kv_retrieve_indexes: {str(e)}") from e - @_controller_socket + @with_controller_socket async def async_kv_list( self, partition_id: Optional[str] = None, diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 3e7cf852..005ca9a8 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -20,7 +20,7 @@ from collections import defaultdict from collections.abc import Mapping from operator import itemgetter -from typing import Any, NamedTuple +from typing import Any, Callable, NamedTuple import torch import zmq @@ -34,7 +34,7 @@ ZMQMessage, ZMQRequestType, ZMQServerInfo, - dynamic_zmq_socket, + with_zmq_socket, ) logger = logging.getLogger(__name__) @@ -49,11 +49,11 @@ TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT", 200)) # seconds # Pre-bound decorator for storage-unit socket operations. -_storage_unit_socket = dynamic_zmq_socket( +with_storage_unit_socket = with_zmq_socket( "put_get_socket", - owner_id_attr="storage_manager_id", - server_attr="storage_unit_infos", - target_kwarg="target_storage_unit", + get_identity=lambda self: self.storage_manager_id, + get_peer=lambda self, target: self.storage_unit_infos[target], + resolve_target=lambda args, kwargs: kwargs.get("target_storage_unit"), timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT, ) @@ -220,7 +220,7 @@ def _select_by_positions(field_data, positions: list[int]): return field_data[positions] async def put_data( - self, data: TensorDict, metadata: BatchMeta, data_parser: Optional[Callable[[Any], Any]] = None + self, data: TensorDict, metadata: BatchMeta, data_parser: Callable[[Any], Any] | None = None ) -> None: """ Send data to remote StorageUnit based on metadata. @@ -281,13 +281,13 @@ async def put_data( field_schema, ) - @_storage_unit_socket + @with_storage_unit_socket async def _put_to_single_storage_unit( self, global_indexes: list[int], storage_data: dict[str, Any], target_storage_unit: str, - data_parser: Optional[Callable[[Any], Any]] = None, + data_parser: Callable[[Any], Any] | None = None, socket: zmq.Socket = None, ): """ @@ -417,7 +417,7 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict: return TensorDict(tensor_data, batch_size=len(metadata)) - @_storage_unit_socket + @with_storage_unit_socket async def _get_from_single_storage_unit( self, global_indexes: list[int], @@ -489,7 +489,7 @@ async def clear_data(self, metadata: BatchMeta) -> None: if isinstance(result, Exception): logger.error(f"[{self.storage_manager_id}]: Error in clear operation task {i}: {result}") - @_storage_unit_socket + @with_storage_unit_socket async def _clear_single_storage_unit(self, global_indexes, target_storage_unit=None, socket=None): try: request_msg = ZMQMessage.create( diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index 17d749c6..62cfb9c3 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -17,7 +17,6 @@ import os import socket import time -from collections.abc import Mapping from dataclasses import dataclass from functools import wraps from typing import Any, Callable, Optional, TypeAlias @@ -304,12 +303,12 @@ def create_zmq_socket( return socket -def dynamic_zmq_socket( +def with_zmq_socket( socket_name: str, *, - owner_id_attr: str, - server_attr: str, - target_kwarg: Optional[str] = None, + get_identity: Callable[[Any], str], + get_peer: Callable[[Any, Optional[str]], ZMQServerInfo], + resolve_target: Optional[Callable[[tuple, dict], Optional[str]]] = None, timeout: Optional[int] = None, ): """Create a reusable async decorator for request sockets. @@ -320,50 +319,32 @@ def dynamic_zmq_socket( Args: socket_name: Socket port key in ``ZMQServerInfo.ports``. - owner_id_attr: Attribute name on ``self`` used in identity/log prefix - (e.g., ``client_id`` or ``storage_manager_id``). - server_attr: Attribute name on ``self`` that stores server info. - - ``ZMQServerInfo`` for single-target calls. - - ``Mapping[str, ZMQServerInfo]`` for multi-target calls. - target_kwarg: Optional kwarg name that provides target server id when - ``server_attr`` is a mapping. + get_identity: Callable that extracts owner identity from ``self``. + Example: ``lambda self: self.client_id`` + get_peer: Callable that returns ``ZMQServerInfo`` for the target. + For single-target scenarios, ignore the target parameter. + Example: ``lambda self, target: self.server_info`` + Example: ``lambda self, target: self.storage_unit_infos[target]`` + resolve_target: Optional callable that extracts target identifier from + function arguments. Receives (args, kwargs) and returns target name. + Example: ``lambda args, kwargs: kwargs.get("target_storage_unit")`` timeout: Optional timeout (seconds) for both send/recv operations. """ def decorator(func: Callable): @wraps(func) async def wrapper(self, *args, **kwargs): - owner_id = getattr(self, owner_id_attr, None) + owner_id = get_identity(self) if owner_id is None: - raise RuntimeError(f"Missing owner id attribute: {owner_id_attr}") - - server_obj = getattr(self, server_attr, None) - if server_obj is None: - raise RuntimeError(f"Missing server registry attribute: {server_attr}") + raise RuntimeError("get_identity returned None") target_name: Optional[str] = None - if target_kwarg is not None: - target_name = kwargs.get(target_kwarg) - if target_name is None: - for arg in args: - if isinstance(arg, str): - target_name = arg - break - - if isinstance(server_obj, ZMQServerInfo): - if target_name is not None and target_name != server_obj.id: - raise RuntimeError( - f"Target mismatch: target '{target_name}' does not match registered server '{server_obj.id}'" - ) - server_info = server_obj - elif isinstance(server_obj, Mapping): - if target_name is None: - raise RuntimeError(f"Missing target server identifier via '{target_kwarg}'") - server_info = server_obj.get(target_name) - if server_info is None: - raise RuntimeError(f"Server '{target_name}' not found in registered servers") - else: - raise RuntimeError(f"Unsupported server registry type for '{server_attr}': {type(server_obj).__name__}") + if resolve_target is not None: + target_name = resolve_target(args, kwargs) + + server_info = get_peer(self, target_name) + if server_info is None: + raise RuntimeError(f"get_peer returned None for target '{target_name}'") port = server_info.ports.get(socket_name) if port is None: From 01a5a4aa6513a34001c2d73662dec42520cc1ef0 Mon Sep 17 00:00:00 2001 From: ji-huazhong Date: Mon, 27 Apr 2026 14:59:05 +0800 Subject: [PATCH 5/5] apply review suggestion from copilot Signed-off-by: ji-huazhong --- transfer_queue/utils/zmq_utils.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index 62cfb9c3..2858e285 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -351,11 +351,11 @@ async def wrapper(self, *args, **kwargs): raise RuntimeError(f"Socket '{socket_name}' not configured for server '{server_info.id}'") context = zmq.asyncio.Context() - address = format_zmq_address(server_info.ip, port) - identity = f"{owner_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode() - sock = create_zmq_socket(context, zmq.DEALER, server_info.ip, identity=identity) - + sock = None try: + address = format_zmq_address(server_info.ip, port) + identity = f"{owner_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode() + sock = create_zmq_socket(context, zmq.DEALER, server_info.ip, identity=identity) sock.connect(address) if timeout is not None: sock.setsockopt(zmq.RCVTIMEO, timeout * 1000) @@ -363,10 +363,13 @@ async def wrapper(self, *args, **kwargs): kwargs["socket"] = sock return await func(self, *args, **kwargs) finally: - try: - if not sock.closed: - sock.close(linger=-1) - finally: + if sock is not None: + try: + if not sock.closed: + sock.close(linger=-1) + finally: + context.term() + else: context.term() return wrapper