From cb2dc5f27a286f5c93d4dfdedae27041f3ea599c Mon Sep 17 00:00:00 2001 From: ji-huazhong Date: Sat, 2 May 2026 13:08:43 +0800 Subject: [PATCH 1/2] update --- .../simple_use_case/single_controller_demo.py | 5 +- transfer_queue/controller.py | 14 +- transfer_queue/interface.py | 161 ++++++++---------- transfer_queue/storage/simple_backend.py | 2 +- transfer_queue/utils/zmq_utils.py | 4 +- 5 files changed, 74 insertions(+), 112 deletions(-) diff --git a/recipe/simple_use_case/single_controller_demo.py b/recipe/simple_use_case/single_controller_demo.py index af8689ab..3fe09aa3 100644 --- a/recipe/simple_use_case/single_controller_demo.py +++ b/recipe/simple_use_case/single_controller_demo.py @@ -32,9 +32,10 @@ import transfer_queue as tq from transfer_queue import KVBatchMeta +from transfer_queue.utils.logging_utils import get_logger -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") -logger = logging.getLogger(__name__) + +logger = get_logger(__name__) os.environ["RAY_DEDUP_LOGS"] = "0" os.environ["RAY_DEBUG"] = "1" diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 49aa2410..dd6d60ff 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -384,7 +384,6 @@ def allocated_samples_num(self) -> int: return self.production_status.shape[0] # ==================== Index Pre-Allocation Methods ==================== - def register_pre_allocated_indexes(self, allocated_indexes: list[int]): """ Register pre-allocated sample indexes to this partition. @@ -442,7 +441,6 @@ def activate_pre_allocated_indexes(self, sample_num: int) -> list[int]: return global_index_to_allocate # ==================== Dynamic Expansion Methods ==================== - def ensure_samples_capacity(self, required_samples: int) -> None: """ Ensure the production status tensor has enough rows for the required samples. @@ -492,7 +490,6 @@ def ensure_fields_capacity(self, required_fields: int) -> None: logger.debug(f"Expanded partition {self.partition_id} from {current_fields} to {new_fields} fields") # ==================== Production Status Interface ==================== - def update_production_status( self, global_indices: list[int], @@ -611,7 +608,6 @@ def mark_consumed(self, task_name: str, global_indices: list[int]): ) # ==================== Consumption Status Interface ==================== - def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Tensor, Tensor]: """ Get or create consumption status for a specific task. @@ -713,7 +709,6 @@ def get_production_status_for_fields( return partition_global_index, production_status # ==================== Data Scanning and Query Methods ==================== - def scan_data_status(self, field_names: list[str], task_name: str) -> list[int]: """ Scan data status to find samples ready for consumption. @@ -761,7 +756,6 @@ def scan_data_status(self, field_names: list[str], task_name: str) -> list[int]: return ready_sample_indices # ==================== Metadata Methods ==================== - def get_field_schema( self, field_names: list[str], batch_global_indexes: list[int] | None = None ) -> dict[str, dict[str, Any]]: @@ -830,7 +824,6 @@ def set_custom_meta(self, custom_meta: dict[int, dict]) -> None: self.custom_meta.update(custom_meta) # ==================== Statistics and Monitoring ==================== - def get_statistics(self) -> dict[str, Any]: """Get detailed statistics for this partition.""" stats = { @@ -877,7 +870,6 @@ def get_statistics(self) -> dict[str, Any]: return stats # ==================== Serialization ==================== - def to_snapshot(self): """ Get a snapshot of partition status information. @@ -995,6 +987,7 @@ def __init__( - If True, the controller will return an empty BatchMeta when no enough data is available. The user side is responsible for handling this empty case (retrying later). """ + breakpoint() if isinstance(sampler, BaseSampler): self.sampler = sampler elif isinstance(sampler, type) and issubclass(sampler, BaseSampler): @@ -1028,7 +1021,6 @@ def __init__( logger.info(f"TransferQueue Controller {self.controller_id} initialized") # ==================== Partition Management API ==================== - def create_partition(self, partition_id: str) -> bool: """ Create a new data partition with pre-allocated sample indexes. @@ -1098,7 +1090,6 @@ def list_partitions(self) -> list[str]: return list(self.partitions.keys()) # ==================== Partition Index Management API ==================== - def get_partition_index_range(self, partition_id: str) -> list[int]: """ Get all indexes for a specific partition. @@ -1114,7 +1105,6 @@ def get_partition_index_range(self, partition_id: str) -> list[int]: return self.index_manager.get_indexes_for_partition(partition_id) # ==================== Data Production API ==================== - def update_production_status( self, partition_id: str, @@ -1150,7 +1140,6 @@ def update_production_status( return success # ==================== Data Consumption API ==================== - def get_consumption_status(self, partition_id: str, task_name: str) -> tuple[Optional[Tensor], Optional[Tensor]]: """ Get or create consumption status for a specific task and partition. @@ -1398,7 +1387,6 @@ def scan_data_status( return ready_sample_indices # ==================== Metadata Generation API ==================== - def generate_batch_meta( self, partition_id: str, diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 66141b8f..58ff5078 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -43,40 +43,38 @@ logger = get_logger(__name__) -_TRANSFER_QUEUE_CLIENT: Any = None -_TRANSFER_QUEUE_STORAGE: Any = None -_TRANSFER_QUEUE_CONTROLLER: Any = None +_TQ_CLIENT: Any = None +_TQ_STORAGE: Any = None +_TQ_CONTROLLER: Any = None -def _maybe_create_transferqueue_client( - conf: Optional[DictConfig] = None, -) -> TransferQueueClient: - global _TRANSFER_QUEUE_CLIENT - if _TRANSFER_QUEUE_CLIENT is None: +def _maybe_create_tq_client(conf: Optional[DictConfig] = None) -> TransferQueueClient: + global _TQ_CLIENT + if _TQ_CLIENT is None: if conf is None: _init_from_existing() - assert _TRANSFER_QUEUE_CLIENT is not None, ( + assert _TQ_CLIENT is not None, ( "TransferQueueController has not been initialized yet. Please call init() first." ) - return _TRANSFER_QUEUE_CLIENT + return _TQ_CLIENT pid = os.getpid() - _TRANSFER_QUEUE_CLIENT = TransferQueueClient( + _TQ_CLIENT = TransferQueueClient( client_id=f"TransferQueueClient_{pid}", controller_info=conf.controller.zmq_info ) - backend_name = conf.backend.storage_backend + # backend = conf.storage.backend + name = conf.backend.storage_backend + _TQ_CLIENT.initialize_storage_manager(manager_type=name, config=conf.backend[name]) - _TRANSFER_QUEUE_CLIENT.initialize_storage_manager(manager_type=backend_name, config=conf.backend[backend_name]) + return _TQ_CLIENT - return _TRANSFER_QUEUE_CLIENT - -def _maybe_create_transferqueue_storage(conf: DictConfig) -> DictConfig: - global _TRANSFER_QUEUE_STORAGE - - if _TRANSFER_QUEUE_STORAGE is None: - _TRANSFER_QUEUE_STORAGE = {} +# TODO(hz): Adopt registry pattern to manage storage backends for better scalability. +def _maybe_create_tq_storage(conf: DictConfig) -> DictConfig: + global _TQ_STORAGE + if _TQ_STORAGE is None: + _TQ_STORAGE = {} if conf.backend.storage_backend == "SimpleStorage": # initialize SimpleStorageUnit simple_storage_handles = {} @@ -98,7 +96,7 @@ def _maybe_create_transferqueue_storage(conf: DictConfig) -> DictConfig: storage_zmq_info = process_zmq_server_info(simple_storage_handles) backend_name = conf.backend.storage_backend conf.backend[backend_name].zmq_info = storage_zmq_info - _TRANSFER_QUEUE_STORAGE["SimpleStorage"] = simple_storage_handles + _TQ_STORAGE["SimpleStorage"] = simple_storage_handles if conf.backend.storage_backend == "MooncakeStore": if conf.backend.MooncakeStore.auto_init: # Try to kill existing mooncake_master processes before starting a new one to avoid potential conflicts @@ -186,9 +184,9 @@ def _maybe_create_transferqueue_storage(conf: DictConfig) -> DictConfig: f"mooncake_master exited with error. Check {log_file_path} for detailed logs. " f"Output:\n{error_msg}" ) - _TRANSFER_QUEUE_STORAGE["MooncakeStore"] = process + _TQ_STORAGE["MooncakeStore"] = process if conf.backend.storage_backend == "Yuanrong" and conf.backend.Yuanrong.auto_init: - _TRANSFER_QUEUE_STORAGE["Yuanrong"] = initialize_yuanrong_backend(conf) + _TQ_STORAGE["Yuanrong"] = initialize_yuanrong_backend(conf) return conf @@ -198,10 +196,10 @@ def _init_from_existing() -> bool: Returns: True if successfully initialized from existing controller, False otherwise. """ - global _TRANSFER_QUEUE_CONTROLLER + global _TQ_CONTROLLER try: - if _TRANSFER_QUEUE_CONTROLLER is None: - _TRANSFER_QUEUE_CONTROLLER = ray.get_actor("TransferQueueController") + if _TQ_CONTROLLER is None: + _TQ_CONTROLLER = ray.get_actor("TransferQueueController") except ValueError: logger.info("Called _init_from_existing() but TransferQueueController has not been initialized yet.") @@ -211,10 +209,9 @@ def _init_from_existing() -> bool: conf = None while conf is None: - conf = ray.get(_TRANSFER_QUEUE_CONTROLLER.get_config.remote()) + conf = ray.get(_TQ_CONTROLLER.get_config.remote()) if conf is not None: - _maybe_create_transferqueue_client(conf) - + _maybe_create_tq_client(conf) logger.info("TransferQueueClient initialized.") return True @@ -231,21 +228,16 @@ def init(conf: Optional[DictConfig] = None) -> Optional[DictConfig]: This function sets up the TransferQueue controller, distributed storage, and client. It should be called once at the beginning of the program before any data operations. - If a controller already exists (e.g., initialized by another process), this function - will retrieve the config from existing controller and initialize the TransferQueueClient. - In this case, the `conf` parameter will be ignored. + If a controller already exists, reuse it and only initialize the client; + the provided `conf` will be ignored in this case. Args: - conf: Optional configuration dictionary. If provided, it will be merged with - the default config from 'config.yaml'. This is only used for first-time - initializing. When connecting to an existing controller, this parameter - is ignored. + conf: Optional custom config merged with default `config.yaml`. + Only takes effect on first-time initialization, ignored when attaching + to an existing controller. Returns: The merged configuration dictionary. - Raises: - ValueError: If config is not valid or required configuration keys are missing. - Example: >>> # In process 0, node A >>> import transfer_queue as tq @@ -261,36 +253,29 @@ def init(conf: Optional[DictConfig] = None) -> Optional[DictConfig]: if _init_from_existing(): return conf - # First-time initialize TransferQueue logger.info("No TransferQueueController found. Starting first-time initialization...") - # create config final_conf = OmegaConf.create({}, flags={"allow_objects": True}) default_conf = OmegaConf.load(resources.files("transfer_queue") / "config.yaml") final_conf = OmegaConf.merge(final_conf, default_conf) if conf: final_conf = OmegaConf.merge(final_conf, conf) - # create controller + # TODO(hz): support load custom sampler class from external modules. try: sampler = final_conf.controller.sampler if isinstance(sampler, BaseSampler): - # user pass a pre-initialized sampler instance sampler = sampler elif isinstance(sampler, type) and issubclass(sampler, BaseSampler): - # user pass a sampler class sampler = sampler() elif isinstance(sampler, str): - # user pass a sampler name str - # try to convert as sampler class sampler = globals()[final_conf.controller.sampler] except KeyError: raise ValueError(f"Could not find sampler {final_conf.controller.sampler}") from None try: - # Ray will make sure actor with same name can only be created once - global _TRANSFER_QUEUE_CONTROLLER - _TRANSFER_QUEUE_CONTROLLER = TransferQueueController.options(name="TransferQueueController").remote( # type: ignore[attr-defined] + global _TQ_CONTROLLER + _TQ_CONTROLLER = TransferQueueController.options(name="TransferQueueController").remote( # type: ignore[attr-defined] sampler=sampler, polling_mode=final_conf.controller.polling_mode ) logger.info("TransferQueueController has been created.") @@ -299,18 +284,15 @@ def init(conf: Optional[DictConfig] = None) -> Optional[DictConfig]: _init_from_existing() return final_conf - controller_zmq_info = process_zmq_server_info(_TRANSFER_QUEUE_CONTROLLER) + controller_zmq_info = process_zmq_server_info(_TQ_CONTROLLER) final_conf.controller.zmq_info = controller_zmq_info - # create distributed storage backends - final_conf = _maybe_create_transferqueue_storage(final_conf) + final_conf = _maybe_create_tq_storage(final_conf) - # store the config into controller - ray.get(_TRANSFER_QUEUE_CONTROLLER.store_config.remote(final_conf)) + ray.get(_TQ_CONTROLLER.store_config.remote(final_conf)) logger.info(f"TransferQueue config: {final_conf}") - # create client - _maybe_create_transferqueue_client(final_conf) + _maybe_create_tq_client(final_conf) return final_conf @@ -321,17 +303,14 @@ def close(): - Closing the client and its associated resources - Cleaning up distributed storage (only for the process that initialized it) - Killing the controller actor - - Note: - This function should be called when the TransferQueue system is no longer needed. """ - global _TRANSFER_QUEUE_CLIENT - global _TRANSFER_QUEUE_STORAGE - global _TRANSFER_QUEUE_CONTROLLER + global _TQ_CLIENT + global _TQ_STORAGE + global _TQ_CONTROLLER try: - if _TRANSFER_QUEUE_STORAGE: - for key, value in _TRANSFER_QUEUE_STORAGE.items(): + if _TQ_STORAGE: + for key, value in _TQ_STORAGE.items(): if key == "SimpleStorage": # only the process that do first-time init can clean the distributed storage for storage in value.values(): @@ -345,9 +324,9 @@ def close(): f"Consider manually killing the mooncake_master." ) - if _TRANSFER_QUEUE_CLIENT: + if _TQ_CLIENT: try: - ret = _TRANSFER_QUEUE_CLIENT.storage_manager.storage_client._store.remove_all() + ret = _TQ_CLIENT.storage_manager.storage_client._store.remove_all() if ret < 0: logger.error("Failed to remove existing keys in mooncake_master.") else: @@ -357,28 +336,25 @@ def close(): elif key == "Yuanrong": cleanup_yuanrong_resources(value) else: - logger.warning(f"close for _TRANSFER_QUEUE_STORAGE with key {key} is not supported for now.") + logger.warning(f"close for _TQ_STORAGE with key {key} is not supported for now.") - _TRANSFER_QUEUE_STORAGE = None + _TQ_STORAGE = None except Exception: pass - if _TRANSFER_QUEUE_CLIENT: - _TRANSFER_QUEUE_CLIENT.close() - _TRANSFER_QUEUE_CLIENT = None - - if _TRANSFER_QUEUE_CONTROLLER: + if _TQ_CLIENT: try: - ray.kill(_TRANSFER_QUEUE_CONTROLLER) + _TQ_CLIENT.close() except Exception: pass - _TRANSFER_QUEUE_CONTROLLER = None + _TQ_CLIENT = None - try: - controller = ray.get_actor("TransferQueueController") - ray.kill(controller) - except Exception: - pass + if _TQ_CONTROLLER: + try: + ray.kill(_TQ_CONTROLLER) + except Exception: + pass + _TQ_CONTROLLER = None # ==================== High-Level KV Interface API ==================== @@ -440,7 +416,7 @@ def kv_put( if fields is None and tag is None: raise ValueError("Please provide at least one parameter of `fields` or `tag`.") - tq_client = _maybe_create_transferqueue_client() + tq_client = _maybe_create_tq_client() # 1. translate user-specified key to BatchMeta batch_meta = tq_client.kv_retrieve_meta(keys=[key], partition_id=partition_id, create=True) @@ -547,7 +523,7 @@ def kv_batch_put( f"batch_size {fields.batch_size[0]}" ) - tq_client = _maybe_create_transferqueue_client() + tq_client = _maybe_create_tq_client() # 1. translate user-specified key to BatchMeta batch_meta = tq_client.kv_retrieve_meta(keys=keys, partition_id=partition_id, create=True) @@ -668,7 +644,7 @@ def kv_batch_get( ... select_fields="input_ids" ... ) """ - tq_client = _maybe_create_transferqueue_client() + tq_client = _maybe_create_tq_client() batch_meta = tq_client.kv_retrieve_meta(keys=keys, partition_id=partition_id, create=False) @@ -724,7 +700,7 @@ def kv_list(partition_id: Optional[str] = None) -> dict[str, dict[str, Any]]: >>> for pid, keys in all_partitions.items(): >>> print(f"Partition: {pid}, Key count: {len(keys)}") """ - tq_client = _maybe_create_transferqueue_client() + tq_client = _maybe_create_tq_client() partition_info = tq_client.kv_list(partition_id) @@ -753,7 +729,7 @@ def kv_clear(keys: list[str] | str, partition_id: str) -> None: if isinstance(keys, str): keys = [keys] - tq_client = _maybe_create_transferqueue_client() + tq_client = _maybe_create_tq_client() batch_meta = tq_client.kv_retrieve_meta(keys=keys, partition_id=partition_id, create=False) if batch_meta.size > 0: @@ -820,7 +796,7 @@ async def async_kv_put( if fields is None and tag is None: raise ValueError("Please provide at least one parameter of fields or tag.") - tq_client = _maybe_create_transferqueue_client() + tq_client = _maybe_create_tq_client() # 1. translate user-specified key to BatchMeta batch_meta = await tq_client.async_kv_retrieve_meta(keys=[key], partition_id=partition_id, create=True) @@ -926,7 +902,7 @@ async def async_kv_batch_put( f"batch_size {fields.batch_size[0]}" ) - tq_client = _maybe_create_transferqueue_client() + tq_client = _maybe_create_tq_client() # 1. translate user-specified key to BatchMeta batch_meta = await tq_client.async_kv_retrieve_meta(keys=keys, partition_id=partition_id, create=True) @@ -1049,7 +1025,7 @@ async def async_kv_batch_get( ... select_fields="input_ids" ... ) """ - tq_client = _maybe_create_transferqueue_client() + tq_client = _maybe_create_tq_client() batch_meta = await tq_client.async_kv_retrieve_meta(keys=keys, partition_id=partition_id, create=False) @@ -1105,7 +1081,7 @@ async def async_kv_list(partition_id: Optional[str] = None) -> dict[str, dict[st >>> for pid, keys in all_partitions.items(): >>> print(f"Partition: {pid}, Key count: {len(keys)}") """ - tq_client = _maybe_create_transferqueue_client() + tq_client = _maybe_create_tq_client() partition_info = await tq_client.async_kv_list(partition_id) @@ -1134,7 +1110,7 @@ async def async_kv_clear(keys: list[str] | str, partition_id: str) -> None: if isinstance(keys, str): keys = [keys] - tq_client = _maybe_create_transferqueue_client() + tq_client = _maybe_create_tq_client() batch_meta = await tq_client.async_kv_retrieve_meta(keys=keys, partition_id=partition_id, create=False) if batch_meta.size > 0: @@ -1145,6 +1121,5 @@ async def async_kv_clear(keys: list[str] | str, partition_id: str) -> None: # For low-level API support, please refer to transfer_queue/client.py for details. def get_client(): """Get a TransferQueueClient for using low-level API""" - if _TRANSFER_QUEUE_CLIENT is None: - raise RuntimeError("Please initialize the TransferQueue first by calling `tq.init()`!") - return _TRANSFER_QUEUE_CLIENT + assert _TQ_CLIENT is not None, ("Please initialize the TransferQueue first by calling `tq.init()`!") + return _TQ_CLIENT diff --git a/transfer_queue/storage/simple_backend.py b/transfer_queue/storage/simple_backend.py index bf334efc..303e5882 100644 --- a/transfer_queue/storage/simple_backend.py +++ b/transfer_queue/storage/simple_backend.py @@ -129,7 +129,7 @@ class SimpleStorageUnit: """A storage unit that provides distributed data storage functionality. This class represents a storage unit that can store data in a 2D structure - (samples × data fields) and provides ZMQ-based communication for put/get/clear operations. + (samples, data_fields) and provides ZMQ-based communication for put/get/clear operations. Note: We use Ray decorator (@ray.remote) only for initialization purposes. We do NOT use Ray's .remote() call capabilities - the storage unit runs diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index 8af6aec3..42f79692 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -369,9 +369,7 @@ async def wrapper(self, *args, **kwargs): return decorator -def process_zmq_server_info( - handlers: dict[Any, Any] | Any, -): # noqa: UP007 +def process_zmq_server_info(handlers: dict[Any, Any] | Any): """Extract ZMQ server information from handler objects. Args: From 22ec2d57454043e56187a9d67017afa681110b5f Mon Sep 17 00:00:00 2001 From: ji-huazhong Date: Sat, 2 May 2026 14:00:56 +0800 Subject: [PATCH 2/2] =?UTF-8?q?[1/n]=20controller=E9=87=8D=E6=9E=84?= =?UTF-8?q?=EF=BC=8C=E6=8A=BD=E5=87=BAtransport=E5=B1=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- transfer_queue/controller.py | 149 +++++++++++++++++------------------ 1 file changed, 71 insertions(+), 78 deletions(-) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index dd6d60ff..cc6805f6 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -53,6 +53,52 @@ TQ_CONTROLLER_GET_METADATA_TIMEOUT = int(os.environ.get("TQ_CONTROLLER_GET_METADATA_TIMEOUT", 1)) TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL = int(os.environ.get("TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL", 5)) + +class ZMQServerTransport: + """ + Unified management of ZMQ Router Sockets, port binding, daemon threads, and message I/O. + """ + + def __init__(self, node_ip: str, ctx: zmq.Context | None = None): + self.node_ip = node_ip + self.zmq_ctx = ctx or zmq.Context() + self.sockets: dict[str, zmq.Socket] = {} + self.ports: dict[str, int] = {} + self.threads: list[Thread] = [] + + def create_router_socket(self, socket_name: str) -> None: + """Create a ROUTER-type socket, automatically retrying port binding.""" + while True: + try: + port = get_free_port(ip=self.node_ip) + sock = create_zmq_socket( + ctx=self.zmq_ctx, + socket_type=zmq.ROUTER, + ip=self.node_ip, + ) + sock.bind(format_zmq_address(self.node_ip, port)) + self.sockets[socket_name] = sock + self.ports[socket_name] = port + return + except zmq.ZMQError: + logger.warning(f"ZMQ bind {socket_name} failed, retrying...") + + def get_socket(self, socket_name: str) -> zmq.Socket: + return self.sockets[socket_name] + + def start_daemon_thread(self, target, name: str) -> None: + t = Thread(target=target, name=name, daemon=True) + t.start() + self.threads.append(t) + + def build_server_info(self, role: TransferQueueRole, server_id: str) -> ZMQServerInfo: + return ZMQServerInfo( + role=role, + id=server_id, + ip=self.node_ip, + ports=self.ports.copy(), + ) + # Sample pre-allocation for StreamingDataLoader compatibility. # By pre-allocating sample indices (typically global_batch_size), consumers can accurately # determine consumption status even before producers have generated the samples. @@ -987,7 +1033,6 @@ def __init__( - If True, the controller will return an empty BatchMeta when no enough data is available. The user side is responsible for handling this empty case (retrying later). """ - breakpoint() if isinstance(sampler, BaseSampler): self.sampler = sampler elif isinstance(sampler, type) and issubclass(sampler, BaseSampler): @@ -1001,8 +1046,9 @@ def __init__( self.polling_mode = polling_mode self.tq_config = None # global config for TransferQueue system - # Initialize ZMQ sockets for communication - self._init_zmq_socket() + # Initialize ZMQ transport layer + self._transport = ZMQServerTransport(node_ip=get_node_ip_address_raw()) + self._init_zmq_sockets() # Partition management self.partitions: dict[str, DataPartitionStatus] = {} # partition_id -> DataPartitionStatus @@ -1014,9 +1060,7 @@ def __init__( self._connected_storage_managers: set[str] = set() # Start background processing threads - self._start_process_handshake() - self._start_process_update_data_status() - self._start_process_request() + self._start_daemon_threads() logger.info(f"TransferQueue Controller {self.controller_id} initialized") @@ -1674,69 +1718,29 @@ def kv_retrieve_keys( return keys - def _init_zmq_socket(self): - """Initialize ZMQ sockets for communication.""" - self.zmq_context = zmq.Context() - self._node_ip = get_node_ip_address_raw() - - while True: - try: - self._handshake_socket_port = get_free_port(ip=self._node_ip) - self._request_handle_socket_port = get_free_port(ip=self._node_ip) - self._data_status_update_socket_port = get_free_port(ip=self._node_ip) - - self.handshake_socket = create_zmq_socket( - ctx=self.zmq_context, - socket_type=zmq.ROUTER, - ip=self._node_ip, - ) - self.handshake_socket.bind(format_zmq_address(self._node_ip, self._handshake_socket_port)) - - self.request_handle_socket = create_zmq_socket( - ctx=self.zmq_context, - socket_type=zmq.ROUTER, - ip=self._node_ip, - ) - self.request_handle_socket.bind(format_zmq_address(self._node_ip, self._request_handle_socket_port)) - - self.data_status_update_socket = create_zmq_socket( - ctx=self.zmq_context, - socket_type=zmq.ROUTER, - ip=self._node_ip, - ) - self.data_status_update_socket.bind( - format_zmq_address(self._node_ip, self._data_status_update_socket_port) - ) - - break - except zmq.ZMQError: - logger.warning(f"[{self.controller_id}]: Try to bind ZMQ sockets failed, retrying...") - continue - - self.zmq_server_info = ZMQServerInfo( + def _init_zmq_sockets(self): + self._transport.create_router_socket("handshake_socket") + self._transport.create_router_socket("request_handle_socket") + self._transport.create_router_socket("data_status_update_socket") + self.zmq_server_info = self._transport.build_server_info( role=TransferQueueRole.CONTROLLER, - id=self.controller_id, - ip=self._node_ip, - ports={ - "handshake_socket": self._handshake_socket_port, - "request_handle_socket": self._request_handle_socket_port, - "data_status_update_socket": self._data_status_update_socket_port, - }, + server_id=self.controller_id, ) def _wait_connection(self): """Wait for storage instances to complete handshake with retransmission support.""" + handshake_socket = self._transport.get_socket("handshake_socket") poller = zmq.Poller() - poller.register(self.handshake_socket, zmq.POLLIN) + poller.register(handshake_socket, zmq.POLLIN) logger.debug(f"Controller {self.controller_id} started waiting for storage connections...") while True: socks = dict(poller.poll(1000)) - if self.handshake_socket in socks: + if handshake_socket in socks: try: - messages = self.handshake_socket.recv_multipart(copy=False) + messages = handshake_socket.recv_multipart(copy=False) identity = messages.pop(0) serialized_msg = messages request_msg = ZMQMessage.deserialize(serialized_msg) @@ -1750,7 +1754,7 @@ def _wait_connection(self): sender_id=self.controller_id, body={}, ).serialize() - self.handshake_socket.send_multipart([identity, *response_msg]) + handshake_socket.send_multipart([identity, *response_msg]) # Track new connections if storage_manager_id not in self._connected_storage_managers: @@ -1770,42 +1774,30 @@ def _wait_connection(self): except Exception as e: logger.error(f"[{self.controller_id}]: error processing handshake: {e}") - def _start_process_handshake(self): - """Start the handshake process thread.""" - self.wait_connection_thread = Thread( + def _start_daemon_threads(self): + self._transport.start_daemon_thread( target=self._wait_connection, name="TransferQueueControllerWaitConnectionThread", - daemon=True, ) - self.wait_connection_thread.start() - - def _start_process_update_data_status(self): - """Start the data status update processing thread.""" - self.process_update_data_status_thread = Thread( + self._transport.start_daemon_thread( target=self._update_data_status, name="TransferQueueControllerProcessUpdateDataStatusThread", - daemon=True, ) - self.process_update_data_status_thread.start() - - def _start_process_request(self): - """Start the request processing thread.""" - self.process_request_thread = Thread( + self._transport.start_daemon_thread( target=self._process_request, name="TransferQueueControllerProcessRequestThread", - daemon=True, ) - self.process_request_thread.start() def _process_request(self): """Main request processing loop - adapted for partition-based operations.""" logger.info(f"[{self.controller_id}]: start processing requests...") + request_handle_socket = self._transport.get_socket("request_handle_socket") perf_monitor = IntervalPerfMonitor(caller_name=self.controller_id) while True: - messages = self.request_handle_socket.recv_multipart(copy=False) + messages = request_handle_socket.recv_multipart(copy=False) identity = messages.pop(0) serialized_msg = messages request_msg = ZMQMessage.deserialize(serialized_msg) @@ -2039,16 +2031,17 @@ def _process_request(self): body={"partition_info": partition_info, "message": message}, ) - self.request_handle_socket.send_multipart([identity, *response_msg.serialize()]) + request_handle_socket.send_multipart([identity, *response_msg.serialize()]) def _update_data_status(self): """Process data status update messages from storage units - adapted for partitions.""" logger.debug(f"[{self.controller_id}]: start receiving update_data_status requests...") + data_status_update_socket = self._transport.get_socket("data_status_update_socket") perf_monitor = IntervalPerfMonitor(caller_name=self.controller_id) while True: - messages = self.data_status_update_socket.recv_multipart(copy=False) + messages = data_status_update_socket.recv_multipart(copy=False) identity = messages.pop(0) serialized_msg = messages request_msg = ZMQMessage.deserialize(serialized_msg) @@ -2079,7 +2072,7 @@ def _update_data_status(self): "success": success, }, ) - self.data_status_update_socket.send_multipart([identity, *response_msg.serialize()]) + data_status_update_socket.send_multipart([identity, *response_msg.serialize()]) def get_zmq_server_info(self) -> ZMQServerInfo: """Get ZMQ server connection information."""