From 7892ac0a501c19e62fb256a5617e80dd55bf1af4 Mon Sep 17 00:00:00 2001 From: fy2462 Date: Sun, 17 May 2026 12:04:11 +0800 Subject: [PATCH 1/5] [refactor] Register storage backend for greater scalability Signed-off-by: fy2462 --- scripts/put_benchmark.py | 2 +- tests/test_metadata.py | 2 +- tests/test_simple_storage_unit.py | 4 +- transfer_queue/interface.py | 134 +--- transfer_queue/storage/__init__.py | 2 +- transfer_queue/storage/backends/__init__.py | 24 + transfer_queue/storage/backends/base.py | 40 ++ .../storage/backends/mooncake_storage.py | 127 ++++ .../storage/backends/simple_storage.py | 663 ++++++++++++++++++ .../storage/backends/yuanrong_storage.py | 427 +++++++++++ transfer_queue/utils/yuanrong_utils.py | 402 +---------- 11 files changed, 1300 insertions(+), 527 deletions(-) create mode 100644 transfer_queue/storage/backends/__init__.py create mode 100644 transfer_queue/storage/backends/base.py create mode 100644 transfer_queue/storage/backends/mooncake_storage.py create mode 100644 transfer_queue/storage/backends/simple_storage.py create mode 100644 transfer_queue/storage/backends/yuanrong_storage.py diff --git a/scripts/put_benchmark.py b/scripts/put_benchmark.py index c67bb54c..55289f6f 100644 --- a/scripts/put_benchmark.py +++ b/scripts/put_benchmark.py @@ -30,7 +30,7 @@ from transfer_queue import TransferQueueClient from transfer_queue.controller import TransferQueueController -from transfer_queue.storage.simple_storage import SimpleStorageUnit +from transfer_queue.storage.backends.simple_storage import SimpleStorageUnit from transfer_queue.utils.common import get_placement_group from transfer_queue.utils.zmq_utils import process_zmq_server_info diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 6fc49d2e..e9218de8 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -807,7 +807,7 @@ class TestStorageUnitDataStrict: def test_put_data_length_mismatch_raises(self): """put_data must raise when global_indexes and field values have different lengths.""" - from transfer_queue.storage.simple_storage import StorageUnitData + from transfer_queue.storage.backends.simple_storage import StorageUnitData sud = StorageUnitData(storage_size=10) # 3 indexes but only 2 values — must raise, not silently drop diff --git a/tests/test_simple_storage_unit.py b/tests/test_simple_storage_unit.py index 319a46e7..c0d084b2 100644 --- a/tests/test_simple_storage_unit.py +++ b/tests/test_simple_storage_unit.py @@ -21,7 +21,7 @@ import torch import zmq -from transfer_queue.storage.simple_storage import SimpleStorageUnit +from transfer_queue.storage.backends.simple_storage import SimpleStorageUnit from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType @@ -420,7 +420,7 @@ def test_storage_unit_data_direct(): def test_storage_unit_data_capacity_uses_active_keys(): """Capacity check must use _active_keys, not scan field_data.""" - from transfer_queue.storage.simple_storage import StorageUnitData + from transfer_queue.storage.backends.simple_storage import StorageUnitData storage = StorageUnitData(storage_size=3) diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index ba74a35c..0ee31508 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -13,13 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math import os import subprocess import time from importlib import resources from typing import Any, Callable -from urllib.parse import urlparse import ray import torch @@ -32,13 +30,9 @@ from transfer_queue.metadata import KVBatchMeta from transfer_queue.sampler import * # noqa: F401 from transfer_queue.sampler import BaseSampler -from transfer_queue.storage.simple_storage import SimpleStorageUnit -from transfer_queue.utils.common import get_placement_group +from transfer_queue.storage.backends.base import StorageBackendFactory from transfer_queue.utils.logging_utils import get_logger -from transfer_queue.utils.yuanrong_utils import ( - cleanup_yuanrong_resources, - initialize_yuanrong_backend, -) +from transfer_queue.utils.yuanrong_utils import cleanup_yuanrong_resources from transfer_queue.utils.zmq_utils import process_zmq_server_info logger = get_logger(__name__) @@ -70,125 +64,21 @@ def _maybe_create_tq_client(conf: DictConfig | None = None) -> TransferQueueClie return _TQ_CLIENT -# TODO(hz): Adopt registry pattern to manage storage backends for better scalability. def _maybe_create_tq_storage(conf: DictConfig) -> DictConfig: global _TQ_STORAGE if _TQ_STORAGE is None: _TQ_STORAGE = {} - if conf.backend.storage_backend == "SimpleStorage": - # initialize SimpleStorageUnit - simple_storage_handles = {} - num_data_storage_units = conf.backend.SimpleStorage.num_data_storage_units - total_storage_size = conf.backend.SimpleStorage.total_storage_size - storage_placement_group = get_placement_group(num_data_storage_units, num_cpus_per_actor=1) - - for storage_unit_rank in range(num_data_storage_units): - storage_node = SimpleStorageUnit.options( # type: ignore[attr-defined] - placement_group=storage_placement_group, - placement_group_bundle_index=storage_unit_rank, - name=f"TransferQueueStorageUnit#{storage_unit_rank}", - ).remote( - storage_unit_size=math.ceil(total_storage_size / num_data_storage_units), - ) - simple_storage_handles[f"TransferQueueStorageUnit#{storage_unit_rank}"] = storage_node - logger.info(f"TransferQueueStorageUnit#{storage_unit_rank} has been created.") - - storage_zmq_info = process_zmq_server_info(simple_storage_handles) - backend_name = conf.backend.storage_backend - conf.backend[backend_name].zmq_info = storage_zmq_info - _TQ_STORAGE["SimpleStorage"] = simple_storage_handles - if conf.backend.storage_backend == "MooncakeStore": - if conf.backend.MooncakeStore.auto_init: - # Try to kill existing mooncake_master processes before starting a new one to avoid potential conflicts - check = subprocess.run(["pgrep", "-f", "mooncake_master"], stdout=subprocess.PIPE, text=True) - if check.returncode == 0: - pids = check.stdout.strip().replace("\n", ", ") - logger.info(f"Find existing mooncake_master (PID: {pids}), try to kill first...") - - result = os.system('pkill -f "[m]ooncake_master"') - if result == 0: - logger.info("Successfully killed existing mooncake_master processes.") - else: - raise RuntimeError(f"Failed to kill existing mooncake_master processes (exit code: {result}).") - - # process metadata_server - metadata_server_raw_address = conf.backend.MooncakeStore.metadata_server - if "://" not in metadata_server_raw_address: - metadata_server_raw_address = "//" + metadata_server_raw_address - - metadata_server_parsed = urlparse(metadata_server_raw_address) - - if not metadata_server_parsed.hostname or metadata_server_parsed.port is None: - raise ValueError( - f"Invalid metadata_server '{conf.backend.MooncakeStore.metadata_server}'. " - f"Host and port are required (e.g., host:port)." - ) - - metadata_server_host = metadata_server_parsed.hostname - metadata_server_port = str(metadata_server_parsed.port) - - # process master_server - master_server_raw_address = conf.backend.MooncakeStore.master_server_address - if "://" not in master_server_raw_address: - master_server_raw_address = "//" + master_server_raw_address - - master_server_parsed = urlparse(master_server_raw_address) - - if not master_server_parsed.hostname or master_server_parsed.port is None: - raise ValueError( - f"Invalid master_server_address '{conf.backend.MooncakeStore.master_server_address}'. " - f"Host and port are required (e.g., host:port)." - ) - - master_server_port = str(master_server_parsed.port) - - cmd = [ - "mooncake_master", - "-client_ttl=30", - "-default_kv_lease_ttl=999999", - "-default_kv_soft_pin_ttl=999999", - "--eviction_high_watermark_ratio=1.0", - "--eviction_ratio=0.0", - "--enable_http_metadata_server=true", - "--allow_evict_soft_pinned_objects=false", - f"--http_metadata_server_host={metadata_server_host}", - f"--http_metadata_server_port={metadata_server_port}", - f"--rpc_port={master_server_port}", - ] - - log_file_path = "/tmp/mooncake_master.log" - with open(log_file_path, "w") as log_file: - process = subprocess.Popen( - cmd, - stdout=log_file, - stderr=subprocess.STDOUT, - text=True, - bufsize=1, - universal_newlines=True, - start_new_session=True, - ) - time.sleep(3) - - if process.poll() is None: - logger.info( - f"mooncake_master started, PID: {process.pid}. Logs are at: {os.path.abspath(log_file_path)}" - ) - else: - error_msg = "" - try: - with open(log_file_path) as f: - error_msg = f.read() - except Exception as e: - error_msg = f"Failed to read log file: {e}" - - raise RuntimeError( - f"mooncake_master exited with error. Check {log_file_path} for detailed logs. " - f"Output:\n{error_msg}" - ) - _TQ_STORAGE["MooncakeStore"] = process - if conf.backend.storage_backend == "Yuanrong" and conf.backend.Yuanrong.auto_init: - _TQ_STORAGE["Yuanrong"] = initialize_yuanrong_backend(conf) + backend_name = conf.backend.storage_backend + registered_backend_fn = StorageBackendFactory.get_backend(backend_name) + if registered_backend_fn: + backend_instance = registered_backend_fn(conf) + if backend_instance: + _TQ_STORAGE[backend_name] = backend_instance + else: + logger.error(f"Not found available {backend_name} storage backend instance, please check the config.") + else: + logger.error(f"Storage backend {backend_name} not registered. Please add it to the StorageBackendFactory.") return conf diff --git a/transfer_queue/storage/__init__.py b/transfer_queue/storage/__init__.py index 2fb1be46..809eeecc 100644 --- a/transfer_queue/storage/__init__.py +++ b/transfer_queue/storage/__init__.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .backends import SimpleStorageUnit, StorageUnitData from .managers import ( AsyncSimpleStorageManager, MooncakeStorageManager, @@ -21,7 +22,6 @@ StorageManagerFactory, YuanrongStorageManager, ) -from .simple_storage import SimpleStorageUnit, StorageUnitData __all__ = [ "SimpleStorageUnit", diff --git a/transfer_queue/storage/backends/__init__.py b/transfer_queue/storage/backends/__init__.py new file mode 100644 index 00000000..2056c9c7 --- /dev/null +++ b/transfer_queue/storage/backends/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2026 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2026 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import mooncake_storage, simple_storage, yuanrong_storage # noqa: F401, I001 +from .base import StorageBackendFactory +from .simple_storage import SimpleStorageUnit, StorageUnitData + +__all__ = [ + "StorageBackendFactory", + "SimpleStorageUnit", + "StorageUnitData", +] diff --git a/transfer_queue/storage/backends/base.py b/transfer_queue/storage/backends/base.py new file mode 100644 index 00000000..4efa00e7 --- /dev/null +++ b/transfer_queue/storage/backends/base.py @@ -0,0 +1,40 @@ +# Copyright 2026 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2026 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import wraps +from typing import Callable + + +class StorageBackendFactory: + _backends: dict[str, Callable] = {} + + @classmethod + def register_backend(cls, name: str): + """Decorator to register storage backend & returns function.""" + + def decorator(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + cls._backends[name.lower()] = wrapper + return wrapper + + return decorator + + @classmethod + def get_backend(cls, name: str) -> Callable | None: + """Get storage backend function by name.""" + return cls._backends.get(name.lower(), None) diff --git a/transfer_queue/storage/backends/mooncake_storage.py b/transfer_queue/storage/backends/mooncake_storage.py new file mode 100644 index 00000000..70116e8d --- /dev/null +++ b/transfer_queue/storage/backends/mooncake_storage.py @@ -0,0 +1,127 @@ +# Copyright 2026 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2026 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import subprocess +import time +from urllib.parse import urlparse + +from omegaconf import DictConfig + +from transfer_queue.storage.backends.base import StorageBackendFactory +from transfer_queue.utils.logging_utils import get_logger + +logger = get_logger(__name__) + + +@StorageBackendFactory.register_backend("MooncakeStore") +def initialize_mooncake_backend(conf: DictConfig) -> DictConfig: + """ + Initialize MooncakeStore backend. + Args: + conf (DictConfig): Configuration dictionary for the MooncakeStore backend. + Returns: + DictConfig: Initialized configuration dictionary for the MooncakeStore backend. + Raises: + ValueError: If the backend is not initialized successfully. + """ + if not conf.backend.MooncakeStore.auto_init: + return None + + # Try to kill existing mooncake_master processes before starting a new one to avoid potential conflicts + check = subprocess.run(["pgrep", "-f", "mooncake_master"], stdout=subprocess.PIPE, text=True) + if check.returncode == 0: + pids = check.stdout.strip().replace("\n", ", ") + logger.info(f"Find existing mooncake_master (PID: {pids}), try to kill first...") + + result = os.system('pkill -f "[m]ooncake_master"') + if result == 0: + logger.info("Successfully killed existing mooncake_master processes.") + else: + raise RuntimeError(f"Failed to kill existing mooncake_master processes (exit code: {result}).") + + # process metadata_server + metadata_server_raw_address = conf.backend.MooncakeStore.metadata_server + if "://" not in metadata_server_raw_address: + metadata_server_raw_address = "//" + metadata_server_raw_address + + metadata_server_parsed = urlparse(metadata_server_raw_address) + + if not metadata_server_parsed.hostname or metadata_server_parsed.port is None: + raise ValueError( + f"Invalid metadata_server '{conf.backend.MooncakeStore.metadata_server}'. " + f"Host and port are required (e.g., host:port)." + ) + + metadata_server_host = metadata_server_parsed.hostname + metadata_server_port = str(metadata_server_parsed.port) + + # process master_server + master_server_raw_address = conf.backend.MooncakeStore.master_server_address + if "://" not in master_server_raw_address: + master_server_raw_address = "//" + master_server_raw_address + + master_server_parsed = urlparse(master_server_raw_address) + + if not master_server_parsed.hostname or master_server_parsed.port is None: + raise ValueError( + f"Invalid master_server_address '{conf.backend.MooncakeStore.master_server_address}'. " + f"Host and port are required (e.g., host:port)." + ) + + master_server_port = str(master_server_parsed.port) + + cmd = [ + "mooncake_master", + "-client_ttl=30", + "-default_kv_lease_ttl=999999", + "-default_kv_soft_pin_ttl=999999", + "--eviction_high_watermark_ratio=1.0", + "--eviction_ratio=0.0", + "--enable_http_metadata_server=true", + "--allow_evict_soft_pinned_objects=false", + f"--http_metadata_server_host={metadata_server_host}", + f"--http_metadata_server_port={metadata_server_port}", + f"--rpc_port={master_server_port}", + ] + + log_file_path = "/tmp/mooncake_master.log" + with open(log_file_path, "w") as log_file: + process = subprocess.Popen( + cmd, + stdout=log_file, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + universal_newlines=True, + start_new_session=True, + ) + time.sleep(3) + + if process.poll() is None: + logger.info(f"mooncake_master started, PID: {process.pid}. Logs are at: {os.path.abspath(log_file_path)}") + else: + error_msg = "" + try: + with open(log_file_path) as f: + error_msg = f.read() + except Exception as e: + error_msg = f"Failed to read log file: {e}" + + raise RuntimeError( + f"mooncake_master exited with error. Check {log_file_path} for detailed logs. Output:\n{error_msg}" + ) + + return process diff --git a/transfer_queue/storage/backends/simple_storage.py b/transfer_queue/storage/backends/simple_storage.py new file mode 100644 index 00000000..bb4fea4a --- /dev/null +++ b/transfer_queue/storage/backends/simple_storage.py @@ -0,0 +1,663 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os +import time +import weakref +from threading import Event, Thread +from typing import TYPE_CHECKING, Any +from uuid import uuid4 + +import psutil +import ray +import zmq +from omegaconf import DictConfig + +from transfer_queue.storage.backends.base import StorageBackendFactory +from transfer_queue.utils.common import get_placement_group, limit_pytorch_auto_parallel_threads +from transfer_queue.utils.enum_utils import Role +from transfer_queue.utils.logging_utils import get_logger +from transfer_queue.utils.perf_utils import IntervalPerfMonitor +from transfer_queue.utils.zmq_utils import ( + ZMQMessage, + ZMQRequestType, + ZMQServerInfo, + create_zmq_socket, + format_zmq_address, + get_free_port, + get_node_ip_address, + process_zmq_server_info, +) + +if TYPE_CHECKING: + from transfer_queue.metrics import TQMetricsExporter + +logger = get_logger(__name__) + +TQ_STORAGE_POLLER_TIMEOUT = int(os.environ.get("TQ_STORAGE_POLLER_TIMEOUT", 5)) # in seconds +TQ_NUM_THREADS = int(os.environ.get("TQ_NUM_THREADS", 8)) + + +class StorageUnitData: + """Storage unit for managing 2D data structure (samples × fields). + + Uses dict-based storage keyed by global_index instead of pre-allocated list. + This allows O(1) insert/delete without index translation and avoids capacity bloat. + + Data Structure Example: + field_data = { + "field_name1": {global_index_0: item1, global_index_3: item2, ...}, + "field_name2": {global_index_0: item3, global_index_3: item4, ...}, + } + """ + + def __init__(self, storage_size: int): + # field_name -> {global_index: data} nested dict + self.field_data: dict[str, dict] = {} + # Capacity upper bound (not pre-allocated list length) + self.storage_size = storage_size + # Track active global_index keys for O(1) capacity checks + self._active_keys: set = set() + + @property + def active_key_count(self) -> int: + """Number of active keys currently stored.""" + return len(self._active_keys) + + def get_data(self, fields: list[str], global_indexes: list) -> dict[str, list]: + """Get data by global index keys. + + Args: + fields: Field names used for getting data. + global_indexes: Global indexes used as dict keys. + + Returns: + dict with field names as keys, corresponding data list as values. + """ + result: dict[str, list] = {} + for field in fields: + if field not in self.field_data: + raise ValueError( + f"StorageUnitData get_data: field '{field}' not found. Available: {list(self.field_data.keys())}" + ) + try: + result[field] = [self.field_data[field][k] for k in global_indexes] + except KeyError as e: + raise KeyError(f"StorageUnitData get_data: key {e} not found in field '{field}'") from e + return result + + def put_data(self, field_data: dict[str, Any], global_indexes: list) -> None: + """Put data into storage. + + Args: + field_data: Dict with field names as keys, data list as values. + global_indexes: Global indexes to use as dict keys. + """ + # Capacity is enforced per unique sample key, not counted per-field + new_global_keys = [k for k in global_indexes if k not in self._active_keys] + if len(self._active_keys) + len(new_global_keys) > self.storage_size: + raise ValueError( + f"Storage capacity exceeded: {len(self._active_keys)} existing + " + f"{len(new_global_keys)} new > {self.storage_size}" + ) + for f, values in field_data.items(): + if len(values) != len(global_indexes): + raise ValueError( + f"StorageUnitData put_data: field '{f}' values length {len(values)} " + f"!= global_indexes length {len(global_indexes)}, length mismatch" + ) + if f not in self.field_data: + self.field_data[f] = {} + field_dict = self.field_data[f] + for key, val in zip(global_indexes, values, strict=True): + field_dict[key] = val + self._active_keys.update(global_indexes) + + def clear(self, keys: list[int]) -> None: + """Remove data at given global index keys, immediately freeing memory. + + Args: + keys: Global indexes to remove. + """ + for f in self.field_data: + for key in keys: + self.field_data[f].pop(key, None) + self._active_keys -= set(keys) + + +@ray.remote(num_cpus=1) +class SimpleStorageUnit: + """A storage unit that provides distributed data storage functionality. + + This class represents a storage unit that can store data in a 2D structure + (samples, data_fields) and provides ZMQ-based communication for put/get/clear operations. + + Note: We use Ray decorator (@ray.remote) only for initialization purposes. + We do NOT use Ray's .remote() call capabilities - the storage unit runs + as a standalone process with its own ZMQ server socket. + + Attributes: + storage_unit_id: Unique identifier for this storage unit. + storage_unit_size: Maximum number of elements that can be stored. + storage_data: Internal StorageUnitData instance for data management. + zmq_server_info: ZMQ connection information for clients. + """ + + def __init__(self, storage_unit_size: int): + """Initialize a SimpleStorageUnit with the specified size. + + Args: + storage_unit_size: Maximum number of elements that can be stored in this storage unit. + """ + self.storage_unit_id = f"TQ_STORAGE_UNIT_{uuid4().hex[:8]}" + self.storage_unit_size = storage_unit_size + + self.storage_data = StorageUnitData(self.storage_unit_size) + + # Internal communication address for proxy and workers + self._inproc_addr = f"inproc://simple_storage_workers_{self.storage_unit_id}" + + # Shutdown event for graceful termination + self._shutdown_event = Event() + + # Placeholder for zmq_context, proxy_thread and worker_threads + self.zmq_context: zmq.Context | None = None + self.put_get_socket: zmq.Socket | None = None + self.proxy_thread: Thread | None = None + self.worker_thread: Thread | None = None + + self._metrics: TQMetricsExporter | None = None + + self._init_zmq_socket() + self._start_process_put_get() + + # Register finalizer for graceful cleanup when garbage collected + self._finalizer = weakref.finalize( + self, + self._shutdown_resources, + self._shutdown_event, + self.worker_thread, + self.proxy_thread, + self.zmq_context, + self.put_get_socket, + ) + + def _init_zmq_socket(self) -> None: + """ + Initialize ZMQ socket connections between storage unit and controller/clients: + - put_get_socket (ROUTER): Handle put/get requests from clients. + - worker_socket (DEALER): Backend socket for worker communication. + """ + self.zmq_context = zmq.Context() + self._node_ip = get_node_ip_address() + + # Frontend: ROUTER for receiving client requests + self.put_get_socket = create_zmq_socket(self.zmq_context, zmq.ROUTER, self._node_ip) + + while True: + try: + self._put_get_socket_port = get_free_port(ip=self._node_ip) + self.put_get_socket.bind(format_zmq_address(self._node_ip, self._put_get_socket_port)) + break + except zmq.ZMQError: + logger.warning(f"[{self.storage_unit_id}]: Try to bind ZMQ sockets failed, retrying...") + continue + + # Backend: DEALER for worker communication (connected via zmq.proxy) + self.worker_socket = create_zmq_socket(self.zmq_context, zmq.DEALER, self._node_ip) + self.worker_socket.bind(self._inproc_addr) + + self.zmq_server_info = ZMQServerInfo( + role=Role.STORAGE, + id=str(self.storage_unit_id), + ip=self._node_ip, + ports={"put_get_socket": self._put_get_socket_port}, + ) + + def _start_process_put_get(self) -> None: + """Start worker threads and ZMQ proxy for handling requests.""" + + # Start worker thread + self.worker_thread = Thread( + target=self._worker_routine, + name=f"StorageUnitWorkerThread-{self.storage_unit_id}", + daemon=True, + ) + self.worker_thread.start() + + time.sleep(0.5) # make sure worker thread is ready before zmq.proxy forwarding messages + + # Start proxy thread (ROUTER <-> DEALER) + self.proxy_thread = Thread( + target=self._proxy_routine, + name=f"StorageUnitProxyThread-{self.storage_unit_id}", + daemon=True, + ) + self.proxy_thread.start() + + def _proxy_routine(self) -> None: + """ZMQ proxy for message forwarding between frontend ROUTER and backend DEALER.""" + logger.info(f"[{self.storage_unit_id}]: start ZMQ proxy...") + try: + zmq.proxy(self.put_get_socket, self.worker_socket) + except zmq.ContextTerminated: + logger.info(f"[{self.storage_unit_id}]: ZMQ Proxy stopped gracefully (Context Terminated)") + except Exception as e: + if self._shutdown_event.is_set(): + logger.info(f"[{self.storage_unit_id}]: ZMQ Proxy shutting down...") + else: + logger.error(f"[{self.storage_unit_id}]: ZMQ Proxy unexpected error: {e}") + + def _worker_routine(self) -> None: + """Worker thread for processing requests.""" + + worker_socket = create_zmq_socket(self.zmq_context, zmq.DEALER, self._node_ip) + worker_socket.connect(self._inproc_addr) + + poller = zmq.Poller() + poller.register(worker_socket, zmq.POLLIN) + + logger.info(f"[{self.storage_unit_id}]: worker thread started...") + perf_monitor = IntervalPerfMonitor(caller_name=f"{self.storage_unit_id}") + + while not self._shutdown_event.is_set(): + monitor = self._metrics if self._metrics is not None else perf_monitor + try: + socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT * 1000)) + except zmq.error.ContextTerminated: + # ZMQ context was terminated, exit gracefully + logger.info(f"[{self.storage_unit_id}]: worker stopped gracefully (Context Terminated)") + break + except Exception as e: + logger.warning(f"[{self.storage_unit_id}]: worker poll error: {e}") + continue + + if self._shutdown_event.is_set(): + break + + if worker_socket in socks: + # Messages received from proxy: [identity, serialized_msg_frame1, ...] + messages = worker_socket.recv_multipart(copy=False) + identity = messages[0] + serialized_msg = messages[1:] + + request_msg = ZMQMessage.deserialize(serialized_msg) + operation = request_msg.request_type + + try: + logger.debug(f"[{self.storage_unit_id}]: worker received operation: {operation}") + + # Process request + if operation == ZMQRequestType.PUT_DATA: # type: ignore[arg-type] + with monitor.measure(op_type="PUT_DATA"): + response_msg = self._handle_put(request_msg) + elif operation == ZMQRequestType.GET_DATA: # type: ignore[arg-type] + with monitor.measure(op_type="GET_DATA"): + response_msg = self._handle_get(request_msg) + elif operation == ZMQRequestType.CLEAR_DATA: # type: ignore[arg-type] + with monitor.measure(op_type="CLEAR_DATA"): + response_msg = self._handle_clear(request_msg) + elif operation == ZMQRequestType.GET_METRICS: # type: ignore[arg-type] + response_msg = self._handle_get_metrics() + else: + response_msg = ZMQMessage.create( + request_type=ZMQRequestType.PUT_GET_OPERATION_ERROR, # type: ignore[arg-type] + sender_id=self.storage_unit_id, + body={ + "message": f"Storage unit id #{self.storage_unit_id} " + f"receive invalid operation: {operation}." + }, + ) + except Exception as e: + logger.error( + f"[{self.storage_unit_id}]: worker error during {operation} " + f"from sender={request_msg.sender_id}: {type(e).__name__}: {e}" + ) + response_msg = ZMQMessage.create( + request_type=ZMQRequestType.PUT_GET_ERROR, # type: ignore[arg-type] + sender_id=self.storage_unit_id, + body={ + "message": f"{self.storage_unit_id}, worker encountered error " + f"during operation {operation}: {str(e)}." + }, + ) + + # Send response back with identity for routing + worker_socket.send_multipart([identity] + response_msg.serialize(), copy=False) + + logger.info(f"[{self.storage_unit_id}]: worker stopped.") + poller.unregister(worker_socket) + worker_socket.close(linger=0) + + def _handle_put(self, data_parts: ZMQMessage) -> ZMQMessage: + """ + Handle put request, add or update data into storage unit. + + Args: + data_parts: ZMQMessage from client. + + Returns: + Put data success response ZMQMessage. + """ + try: + global_indexes = data_parts.body["global_indexes"] + field_data = data_parts.body["data"] # field_data should be a dict. + data_parser = data_parts.body.get("data_parser", None) + + with limit_pytorch_auto_parallel_threads( + target_num_threads=TQ_NUM_THREADS, info=f"[{self.storage_unit_id}] _handle_put" + ): + if data_parser is not None: + if not callable(data_parser): + raise TypeError(f"data_parser must be callable, got {type(data_parser).__name__}") + + original_keys = set(field_data.keys()) + original_lengths = {} + for k, v in field_data.items(): + if hasattr(v, "shape") and isinstance(v.shape, tuple | list) and len(v.shape) > 0: + original_lengths[k] = v.shape[0] + else: + try: + original_lengths[k] = len(v) + except Exception: + original_lengths[k] = None + + field_data = data_parser(field_data) + + if not isinstance(field_data, dict): + raise TypeError(f"data_parser must return a dict, got {type(field_data).__name__}") + + new_keys = set(field_data.keys()) + if new_keys != original_keys: + raise ValueError( + f"data_parser must not change dict keys. " + f"Original keys: {sorted(original_keys)}, got: {sorted(new_keys)}" + ) + + for k, v in field_data.items(): + if hasattr(v, "shape") and isinstance(v.shape, tuple | list) and len(v.shape) > 0: + new_len = v.shape[0] + else: + try: + new_len = len(v) + except Exception: + new_len = None + + orig_len = original_lengths[k] + if orig_len is not None and new_len is not None and orig_len != new_len: + raise ValueError( + f"data_parser changed the number of elements for key '{k}': " + f"expected {orig_len}, got {new_len}" + ) + self.storage_data.put_data(field_data, global_indexes) + + # After put operation finish, send a message to the client + response_msg = ZMQMessage.create( + request_type=ZMQRequestType.PUT_DATA_RESPONSE, # type: ignore[arg-type] + sender_id=self.storage_unit_id, + body={}, + ) + + return response_msg + except Exception as e: + return ZMQMessage.create( + request_type=ZMQRequestType.PUT_ERROR, # type: ignore[arg-type] + sender_id=self.storage_unit_id, + body={ + "message": f"Failed to put data into storage unit id " + f"#{self.storage_unit_id}, detail error message: {str(e)}" + }, + ) + + def _handle_get(self, data_parts: ZMQMessage) -> ZMQMessage: + """ + Handle get request, return data from storage unit. + + Args: + data_parts: ZMQMessage from client. + + Returns: + Get data success response ZMQMessage, containing target data. + """ + try: + fields = data_parts.body["fields"] + global_indexes = data_parts.body["global_indexes"] + + with limit_pytorch_auto_parallel_threads( + target_num_threads=TQ_NUM_THREADS, info=f"[{self.storage_unit_id}] _handle_get" + ): + result_data = self.storage_data.get_data(fields, global_indexes) + + response_msg = ZMQMessage.create( + request_type=ZMQRequestType.GET_DATA_RESPONSE, # type: ignore[arg-type] + sender_id=self.storage_unit_id, + body={ + "data": result_data, + }, + ) + except Exception as e: + logger.error( + f"[{self.storage_unit_id}]: _handle_get error, " + f"fields={fields}, global_indexes={global_indexes}: {type(e).__name__}: {e}" + ) + response_msg = ZMQMessage.create( + request_type=ZMQRequestType.GET_ERROR, # type: ignore[arg-type] + sender_id=self.storage_unit_id, + body={ + "message": f"Failed to get data from storage unit id #{self.storage_unit_id}, " + f"detail error message: {str(e)}" + }, + ) + return response_msg + + def _handle_clear(self, data_parts: ZMQMessage) -> ZMQMessage: + """ + Handle clear request, clear data in storage unit according to given global_indexes. + + Args: + data_parts: ZMQMessage from client, including target global_indexes. + + Returns: + Clear data success response ZMQMessage. + """ + try: + global_indexes = data_parts.body["global_indexes"] + + with limit_pytorch_auto_parallel_threads( + target_num_threads=TQ_NUM_THREADS, info=f"[{self.storage_unit_id}] _handle_clear" + ): + self.storage_data.clear(global_indexes) + + response_msg = ZMQMessage.create( + request_type=ZMQRequestType.CLEAR_DATA_RESPONSE, # type: ignore[arg-type] + sender_id=self.storage_unit_id, + body={"message": f"Clear data in storage unit id #{self.storage_unit_id} successfully."}, + ) + except Exception as e: + response_msg = ZMQMessage.create( + request_type=ZMQRequestType.CLEAR_DATA_ERROR, # type: ignore[arg-type] + sender_id=self.storage_unit_id, + body={ + "message": f"Failed to clear data in storage unit id #{self.storage_unit_id}, " + f"detail error message: {str(e)}" + }, + ) + return response_msg + + def _handle_get_metrics(self) -> ZMQMessage: + """Handle GET_METRICS request by returning storage unit statistics. + + Returns: + ZMQMessage containing storage unit ID, capacity, active keys, + process RSS memory, and per-operation request stats. + """ + try: + process_rss = psutil.Process().memory_info().rss + except Exception: + process_rss = 0 + + metrics = { + "storage_unit_id": self.storage_unit_id, + "capacity": self.storage_unit_size, + "active_keys": self.storage_data.active_key_count, + "process_rss_bytes": process_rss, + } + + # Include per-operation stats if Prometheus metrics are enabled + if self._metrics is not None: + op_stats = {} + for op_type in ("PUT_DATA", "GET_DATA", "CLEAR_DATA"): + try: + hist = self._metrics.request_duration.labels(op_type=op_type) + counter = self._metrics.request_total.labels(op_type=op_type) + duration_sum = hist._sum.get() + # Build cumulative counts once, reuse for total and quantiles + cumulative_counts = self._cumulative_bucket_counts(hist) + duration_count = cumulative_counts[-1] if cumulative_counts else 0 + op_stats[op_type] = { + "request_count": counter._value.get(), + "latency_avg": duration_sum / duration_count if duration_count > 0 else 0, + "latency_p50": self._quantile_from_cumulative(hist, cumulative_counts, 0.50), + "latency_p99": self._quantile_from_cumulative(hist, cumulative_counts, 0.99), + } + except (AttributeError, TypeError, ZeroDivisionError) as e: + logger.debug(f"[{self.storage_unit_id}]: Failed to extract metrics for {op_type}: {e}") + if op_stats: + metrics["op_stats"] = op_stats + + return ZMQMessage.create( + request_type=ZMQRequestType.METRICS_RESPONSE, + sender_id=self.storage_unit_id, + body=metrics, + ) + + @staticmethod + def _cumulative_bucket_counts(hist) -> list[float]: + """Build cumulative counts from a prometheus_client Histogram's non-cumulative buckets.""" + cumulative = 0.0 + counts = [] + for bucket in hist._buckets: + cumulative += bucket.get() + counts.append(cumulative) + return counts + + @staticmethod + def _quantile_from_cumulative(hist, cumulative_counts: list[float], q: float) -> float: + """Estimate a quantile using pre-computed cumulative bucket counts. + + Uses linear interpolation matching Prometheus histogram_quantile() logic. + """ + total = cumulative_counts[-1] if cumulative_counts else 0 + if total == 0: + return 0.0 + target = q * total + prev_bound = 0.0 + prev_cumulative = 0.0 + for bound, cum_count in zip(hist._upper_bounds, cumulative_counts, strict=False): + if cum_count >= target: + fraction = ( + (target - prev_cumulative) / (cum_count - prev_cumulative) if cum_count > prev_cumulative else 0 + ) + return prev_bound + (bound - prev_bound) * fraction + prev_bound = bound + prev_cumulative = cum_count + return prev_bound + + @staticmethod + def _shutdown_resources( + shutdown_event: Event, + worker_thread: Thread | None, + proxy_thread: Thread | None, + zmq_context: zmq.Context | None, + put_get_socket: zmq.Socket | None, + ) -> None: + """Clean up resources on garbage collection.""" + logger.info("Shutting down SimpleStorageUnit resources...") + + # Signal all threads to stop + shutdown_event.set() + + # Terminate put_get_socket + if put_get_socket: + put_get_socket.close(linger=0) + + # Terminate ZMQ context to unblock proxy and workers + if zmq_context: + zmq_context.term() + + # Wait for threads to finish (with timeout) + if worker_thread and worker_thread.is_alive(): + worker_thread.join(timeout=5) + if proxy_thread and proxy_thread.is_alive(): + proxy_thread.join(timeout=5) + + logger.info("SimpleStorageUnit resources shutdown complete.") + + def start_metrics(self, port: int = 0) -> str: + """Initialize and start the Prometheus metrics exporter for this storage unit. + + When enabled, replaces ``IntervalPerfMonitor`` for request latency/throughput + tracking with Prometheus counters and histograms. + + Args: + port: HTTP port for the /metrics endpoint (0 = auto-assign). + + Returns: + The metrics endpoint address in ``host:port`` format. + """ + if self._metrics is not None: + return self._metrics.endpoint + from transfer_queue.metrics import TQMetricsExporter + + self._metrics = TQMetricsExporter(role="storage") + endpoint = self._metrics.start(node_ip=self._node_ip, port=port) + logger.info(f"[{self.storage_unit_id}]: Prometheus metrics exporter started on {endpoint}") + return endpoint + + def get_zmq_server_info(self) -> ZMQServerInfo: + """Get the ZMQ server information for this storage unit. + + Returns: + ZMQServerInfo containing connection details for this storage unit. + """ + return self.zmq_server_info + + +@StorageBackendFactory.register_backend("SimpleStorage") +def initialize_simple_backend(conf: DictConfig) -> dict[str, Any]: + """Initialize Simple backend with metastore mode.""" + + simple_storage_handles = {} + num_data_storage_units = conf.backend.SimpleStorage.num_data_storage_units + total_storage_size = conf.backend.SimpleStorage.total_storage_size + storage_placement_group = get_placement_group(num_data_storage_units, num_cpus_per_actor=1) + + for storage_unit_rank in range(num_data_storage_units): + storage_node = SimpleStorageUnit.options( # type: ignore[attr-defined] + placement_group=storage_placement_group, + placement_group_bundle_index=storage_unit_rank, + name=f"TransferQueueStorageUnit#{storage_unit_rank}", + ).remote( + storage_unit_size=math.ceil(total_storage_size / num_data_storage_units), + ) + simple_storage_handles[f"TransferQueueStorageUnit#{storage_unit_rank}"] = storage_node + logger.info(f"TransferQueueStorageUnit#{storage_unit_rank} has been created.") + + storage_zmq_info = process_zmq_server_info(simple_storage_handles) + backend_name = conf.backend.storage_backend + conf.backend[backend_name].zmq_info = storage_zmq_info + + return simple_storage_handles diff --git a/transfer_queue/storage/backends/yuanrong_storage.py b/transfer_queue/storage/backends/yuanrong_storage.py new file mode 100644 index 00000000..57239768 --- /dev/null +++ b/transfer_queue/storage/backends/yuanrong_storage.py @@ -0,0 +1,427 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import shutil +import subprocess +from typing import Any + +import ray +from omegaconf import DictConfig + +from transfer_queue.storage.backends.base import StorageBackendFactory +from transfer_queue.utils.yuanrong_utils import get_local_ip_addresses, kill_actors_and_placement_group + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) + + +def _parse_remote_h2d_device_ids(worker_args: str) -> str | None: + """Parse --remote_h2d_device_ids parameter from worker_args string. + + Args: + worker_args: Worker arguments string, e.g., "--arg1 value1 --remote_h2d_device_ids 0,1,2,3" + + Returns: + The device IDs string if found and valid, None otherwise. + + Raises: + RuntimeError: If --remote_h2d_device_ids flag is found but has invalid format. + """ + if not worker_args: + return None + + args_list = worker_args.split() + + # Find the index of --remote_h2d_device_ids + try: + idx = args_list.index("--remote_h2d_device_ids") + except ValueError: + return None + + # Check if there's a value after the flag + if idx + 1 >= len(args_list): + raise RuntimeError("--remote_h2d_device_ids flag found but no value provided") + + device_ids = args_list[idx + 1] + + # Validate the format: comma-separated digits + if not device_ids: + raise RuntimeError("Empty device IDs value after --remote_h2d_device_ids") + + # Validate each segment is a digit + parts = device_ids.split(",") + for part in parts: + if not part.isdigit(): + raise RuntimeError( + f"Invalid device ID format: '{device_ids}'. Expected comma-separated digits (e.g., '0,1,2,3')." + ) + + return device_ids + + +def start_datasystem_worker( + worker_address: str, + metastore_address: str, + is_head: bool, + worker_args: str = "", +) -> None: + """Start Yuanrong datasystem worker in metastore mode. + + Args: + worker_address: Worker address in format host:port + metastore_address: Metastore address in format host:port + is_head: Whether this node should start metastore service + worker_args: Additional arguments to append to dscli start command + + Raises: + RuntimeError: If dscli command fails + """ + if not shutil.which("dscli"): + raise RuntimeError("dscli executable not found in PATH. Please run `pip install openyuanrong-datasystem`.") + + cmd = ["dscli", "start", "-w", "--worker_address", worker_address] + cmd.extend(["--metastore_address", metastore_address]) + if is_head: + cmd.extend(["--start_metastore_service", "true"]) + + # Built-in default options + cmd.extend(["--arena_per_tenant", "1", "--enable_worker_worker_batch_get", "true"]) + + # Append worker_args if provided + if worker_args: + cmd.extend(worker_args.split()) + + node_type = "head node" if is_head else "worker node" + logger.info(f"Starting Yuanrong datasystem ({node_type}) at {worker_address}, worker_args={worker_args}") + + # Build environment with ASCEND_RT_VISIBLE_DEVICES if specified + env = None + device_ids = _parse_remote_h2d_device_ids(worker_args) + if device_ids: + env = os.environ.copy() + env["ASCEND_RT_VISIBLE_DEVICES"] = device_ids + logger.info( + f"Setting ASCEND_RT_VISIBLE_DEVICES={device_ids} for dscli subprocess ({node_type} at {worker_address})" + ) + + try: + ds_result = subprocess.run( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + timeout=90, + env=env, + ) + except subprocess.TimeoutExpired as err: + raise RuntimeError(f"dscli start timed out: {err}") from err + + if ds_result.returncode == 0 and "[ OK ]" in ds_result.stdout: + logger.info( + f"dscli started Yuanrong datasystem ({node_type}, metastore mode) at {worker_address} successfully." + ) + else: + raise RuntimeError( + f"Failed to start datasystem ({node_type}, metastore mode) at {worker_address}. " + f"Return code: {ds_result.returncode}, Output: {ds_result.stdout}" + ) + + +def stop_datasystem_worker(worker_address: str) -> None: + """Stop Yuanrong datasystem worker. + + Args: + worker_address: Worker address in format host:port + """ + if worker_address: + try: + result = subprocess.run( + ["dscli", "stop", "--worker_address", worker_address], + timeout=90, + capture_output=True, + ) + if result.returncode == 0: + logger.info(f"Stopped datasystem worker at {worker_address} via dscli stop") + else: + error_msg = (result.stderr or result.stdout or b"").decode() + logger.warning( + f"Failed to stop datasystem worker at {worker_address}. " + f"Return code: {result.returncode}, Error: {error_msg}" + ) + except subprocess.TimeoutExpired as err: + logger.warning(f"dscli stop timed out for {worker_address}: {err}") + except Exception as e: + logger.warning(f"Failed to stop datasystem worker via dscli: {e}") + + +@ray.remote(num_cpus=0.1) +class YuanrongWorkerActor: + """Ray actor to manage Yuanrong datasystem worker on a node. + + This actor runs on each node in the Ray cluster and is responsible for + starting and stopping the Yuanrong datasystem worker process on that node. + + The actor determines its own rank and role (head or worker) by finding the + intersection of local IP addresses with the provided node IPs. + """ + + def __init__(self, node_ips: list[str], worker_port: int, metastore_port: int, worker_args: str = ""): + """Initialize the Yuanrong worker actor. + + Args: + node_ips: List of all node IPs in the Ray cluster + worker_port: Port for the datasystem worker + metastore_port: Port for the metastore service (on head node) + worker_args: Additional arguments to append to dscli start command + + Raises: + RuntimeError: If cannot determine this node's IP from node_ips + """ + local_ips = get_local_ip_addresses() + self.my_ip = None + + # Find the intersection between local IPs and node_ips + for ip in node_ips: + if ip in local_ips: + self.my_ip = ip + break + + if self.my_ip is None: + raise RuntimeError(f"Cannot determine local node IP. Local IPs: {local_ips}, Cluster node IPs: {node_ips}") + + self.node_ips = node_ips + self.worker_port = worker_port + self.metastore_port = metastore_port + self.worker_address = f"{self.my_ip}:{worker_port}" + self.worker_args = worker_args + + # First node in the list is assumed to be the head node. + # This assumption is based on how interface.py constructs node_ips from ray.nodes(). + self.head_node_ip = node_ips[0] + self.metastore_address = f"{self.head_node_ip}:{metastore_port}" + self.is_head = self.my_ip == self.head_node_ip + + logger.info( + f"YuanrongWorkerActor initialized on node {self.my_ip}: " + f"worker_address={self.worker_address}, " + f"metastore_address={self.metastore_address}, is_head={self.is_head}, worker_args={self.worker_args}" + ) + + def start(self) -> str: + """Start the datasystem worker on this node. + + Returns: + The worker address. + + Raises: + RuntimeError: If dscli command fails + """ + logger.info(f"Starting datasystem worker at {self.worker_address}...") + start_datasystem_worker( + self.worker_address, + metastore_address=self.metastore_address, + is_head=self.is_head, + worker_args=self.worker_args, + ) + logger.info(f"Datasystem worker started successfully at {self.worker_address}") + return self.worker_address + + def get_metastore_address(self) -> str: + """Get the metastore address. + + Returns: + The metastore address in format host:port + """ + return self.metastore_address + + def get_node_ip(self) -> str: + """Return the IP address of the node this actor is running on.""" + assert self.my_ip is not None + return self.my_ip + + def stop(self) -> None: + """Stop the datasystem worker on this node.""" + logger.info(f"Stopping datasystem worker at {self.worker_address}...") + stop_datasystem_worker(self.worker_address) + logger.info(f"Datasystem worker stopped successfully at {self.worker_address}") + + +@StorageBackendFactory.register_backend("Yuanrong") +def initialize_yuanrong_backend(conf: DictConfig) -> dict[str, Any] | None: + """Initialize Yuanrong backend with metastore mode. + + This function sets up the Yuanrong datasystem workers across all Ray nodes + using placement groups and actors. + + Args: + conf: Configuration containing Yuanrong settings + + Returns: + Dict containing worker_actors, metastore_address, and placement_group + + Raises: + RuntimeError: If Ray nodes not found or initialization fails + """ + if not conf.backend.Yuanrong.auto_init: + return None + + # Get Ray cluster information + nodes = ray.nodes() + if not nodes: + raise RuntimeError("No Ray nodes found. Is Ray initialized?") + + # Filter to only alive nodes and get their IPs + alive_nodes = [node for node in nodes if node.get("Alive", False)] + if not alive_nodes: + raise RuntimeError("No alive Ray nodes found") + + # Get driver node IP to use as head node + driver_ip = ray.util.get_node_ip_address() + head_node = None + other_nodes = [] + + # Separate head node (driver) from other nodes + for node in alive_nodes: + node_ip = node["NodeManagerAddress"] + if node_ip == driver_ip: + head_node = node + else: + other_nodes.append(node) + + if head_node is None: + raise RuntimeError(f"Driver node {driver_ip} not found in alive nodes") + + # Reorder nodes: head node first, then others + ordered_nodes = [head_node] + other_nodes + + # Extract node IPs in deterministic order + node_ips = [node["NodeManagerAddress"] for node in ordered_nodes] + worker_port = conf.backend.Yuanrong.worker_port + metastore_port = conf.backend.Yuanrong.metastore_port + worker_args = conf.backend.Yuanrong.get("worker_args", "") + + logger.info(f"Found {len(ordered_nodes)} alive Ray nodes: {node_ips}") + + # Create placement group using STRICT_SPREAD to ensure each bundle is on a distinct node + bundles = [{"CPU": 0.1} for _ in ordered_nodes] + + pg = ray.util.placement_group(bundles, strategy="STRICT_SPREAD") + try: + ray.get(pg.ready(), timeout=60) + except ray.exceptions.GetTimeoutError as e: + try: + ray.util.remove_placement_group(pg) + except Exception as cleanup_error: + logger.warning(f"Failed to remove placement group after readiness timeout: {cleanup_error}") + raise RuntimeError( + "Timed out waiting for Yuanrong placement group to become ready. " + f"Requested strategy=STRICT_SPREAD, bundles={bundles}. " + "This may be due to insufficient cluster capacity." + ) from e + except Exception as e: + try: + ray.util.remove_placement_group(pg) + except Exception as cleanup_error: + logger.warning(f"Failed to remove placement group after scheduling failure: {cleanup_error}") + raise RuntimeError( + f"Failed to create Yuanrong placement group. Requested strategy=STRICT_SPREAD, bundles={bundles}." + ) from e + + logger.info(f"Created placement group with {len(bundles)} bundles using STRICT_SPREAD") + + try: + # Create all worker actors using placement group + # Without node resources, actor scheduling order is not guaranteed to match node order + # We'll identify head node actor by checking which node it runs on + worker_actors = [] + for rank in range(len(ordered_nodes)): + actor = YuanrongWorkerActor.options( # type: ignore[attr-defined] + placement_group=pg, + placement_group_bundle_index=rank, + ).remote(node_ips, worker_port, metastore_port, worker_args) + worker_actors.append(actor) + + logger.info(f"Created {len(worker_actors)} YuanrongWorkerActor instances") + + # Find which actor is running on the head node (driver IP) + # The head node actor needs to start first to initialize metastore service + head_actor_index = None + for idx, actor in enumerate(worker_actors): + try: + node_ip = ray.get(actor.get_node_ip.remote()) + if node_ip == driver_ip: + head_actor_index = idx + break + except Exception: + pass + + if head_actor_index is None: + logger.warning("Could not identify head node actor, using actor 0 as default") + head_actor_index = 0 + + logger.info(f"Head node actor identified: actor {head_actor_index}") + + # Start head worker first to initialize metastore service + logger.info("Starting head worker to initialize metastore...") + ray.get(worker_actors[head_actor_index].start.remote()) + metastore_address = ray.get(worker_actors[head_actor_index].get_metastore_address.remote()) + logger.info(f"Head worker started, metastore address: {metastore_address}") + + # Start remaining worker actors in parallel + other_actors = [worker_actors[i] for i in range(len(worker_actors)) if i != head_actor_index] + if other_actors: + logger.info(f"Starting {len(other_actors)} worker actors in parallel...") + ray.get([actor.start.remote() for actor in other_actors]) + + logger.info( + f"Yuanrong backend started successfully: metastore at {metastore_address}, workers on {len(node_ips)} nodes" + ) + + return { + "worker_actors": worker_actors, + "metastore_address": metastore_address, + "placement_group": pg, + } + except Exception as e: + # Cleanup on initialization failure: attempt graceful stop of started workers first + logger.error(f"Failed to start Yuanrong workers: {e}, cleaning up...") + + # Try to gracefully stop workers that may have already started + if worker_actors: + stop_exceptions = [] + # Stop worker nodes (all except head node 0) first + if len(worker_actors) > 1: + stop_refs = [actor.stop.remote() for actor in worker_actors[1:]] + for idx, stop_ref in enumerate(stop_refs, start=1): + try: + ray.get(stop_ref, timeout=30) + except Exception as stop_e: + stop_exceptions.append(stop_e) + logger.warning(f"Failed to stop worker node actor {idx}: {stop_e}") + # Stop head node (actor 0) + try: + ray.get(worker_actors[0].stop.remote(), timeout=30) + except Exception as stop_e: + stop_exceptions.append(stop_e) + logger.warning(f"Failed to stop head node actor: {stop_e}") + + if stop_exceptions: + logger.warning(f"Encountered {len(stop_exceptions)} errors during graceful worker stop") + + # Then kill actors and remove placement group + kill_actors_and_placement_group(worker_actors, pg) + raise diff --git a/transfer_queue/utils/yuanrong_utils.py b/transfer_queue/utils/yuanrong_utils.py index 17d5ca82..5f8ddf6e 100644 --- a/transfer_queue/utils/yuanrong_utils.py +++ b/transfer_queue/utils/yuanrong_utils.py @@ -13,16 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. - import logging import os -import shutil import socket -import subprocess from typing import Any import ray -from omegaconf import DictConfig logger = logging.getLogger(__name__) logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) @@ -126,238 +122,7 @@ def find_reachable_host(port: int, timeout: float = 1.0) -> str | None: return None -def _parse_remote_h2d_device_ids(worker_args: str) -> str | None: - """Parse --remote_h2d_device_ids parameter from worker_args string. - - Args: - worker_args: Worker arguments string, e.g., "--arg1 value1 --remote_h2d_device_ids 0,1,2,3" - - Returns: - The device IDs string if found and valid, None otherwise. - - Raises: - RuntimeError: If --remote_h2d_device_ids flag is found but has invalid format. - """ - if not worker_args: - return None - - args_list = worker_args.split() - - # Find the index of --remote_h2d_device_ids - try: - idx = args_list.index("--remote_h2d_device_ids") - except ValueError: - return None - - # Check if there's a value after the flag - if idx + 1 >= len(args_list): - raise RuntimeError("--remote_h2d_device_ids flag found but no value provided") - - device_ids = args_list[idx + 1] - - # Validate the format: comma-separated digits - if not device_ids: - raise RuntimeError("Empty device IDs value after --remote_h2d_device_ids") - - # Validate each segment is a digit - parts = device_ids.split(",") - for part in parts: - if not part.isdigit(): - raise RuntimeError( - f"Invalid device ID format: '{device_ids}'. Expected comma-separated digits (e.g., '0,1,2,3')." - ) - - return device_ids - - -def start_datasystem_worker( - worker_address: str, - metastore_address: str, - is_head: bool, - worker_args: str = "", -) -> None: - """Start Yuanrong datasystem worker in metastore mode. - - Args: - worker_address: Worker address in format host:port - metastore_address: Metastore address in format host:port - is_head: Whether this node should start metastore service - worker_args: Additional arguments to append to dscli start command - - Raises: - RuntimeError: If dscli command fails - """ - if not shutil.which("dscli"): - raise RuntimeError("dscli executable not found in PATH. Please run `pip install openyuanrong-datasystem`.") - - cmd = ["dscli", "start", "-w", "--worker_address", worker_address] - cmd.extend(["--metastore_address", metastore_address]) - if is_head: - cmd.extend(["--start_metastore_service", "true"]) - - # Built-in default options - cmd.extend(["--arena_per_tenant", "1", "--enable_worker_worker_batch_get", "true"]) - - # Append worker_args if provided - if worker_args: - cmd.extend(worker_args.split()) - - node_type = "head node" if is_head else "worker node" - logger.info(f"Starting Yuanrong datasystem ({node_type}) at {worker_address}, worker_args={worker_args}") - - # Build environment with ASCEND_RT_VISIBLE_DEVICES if specified - env = None - device_ids = _parse_remote_h2d_device_ids(worker_args) - if device_ids: - env = os.environ.copy() - env["ASCEND_RT_VISIBLE_DEVICES"] = device_ids - logger.info( - f"Setting ASCEND_RT_VISIBLE_DEVICES={device_ids} for dscli subprocess ({node_type} at {worker_address})" - ) - - try: - ds_result = subprocess.run( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - text=True, - timeout=90, - env=env, - ) - except subprocess.TimeoutExpired as err: - raise RuntimeError(f"dscli start timed out: {err}") from err - - if ds_result.returncode == 0 and "[ OK ]" in ds_result.stdout: - logger.info( - f"dscli started Yuanrong datasystem ({node_type}, metastore mode) at {worker_address} successfully." - ) - else: - raise RuntimeError( - f"Failed to start datasystem ({node_type}, metastore mode) at {worker_address}. " - f"Return code: {ds_result.returncode}, Output: {ds_result.stdout}" - ) - - -def stop_datasystem_worker(worker_address: str) -> None: - """Stop Yuanrong datasystem worker. - - Args: - worker_address: Worker address in format host:port - """ - if worker_address: - try: - result = subprocess.run( - ["dscli", "stop", "--worker_address", worker_address], - timeout=90, - capture_output=True, - ) - if result.returncode == 0: - logger.info(f"Stopped datasystem worker at {worker_address} via dscli stop") - else: - error_msg = (result.stderr or result.stdout or b"").decode() - logger.warning( - f"Failed to stop datasystem worker at {worker_address}. " - f"Return code: {result.returncode}, Error: {error_msg}" - ) - except subprocess.TimeoutExpired as err: - logger.warning(f"dscli stop timed out for {worker_address}: {err}") - except Exception as e: - logger.warning(f"Failed to stop datasystem worker via dscli: {e}") - - -@ray.remote(num_cpus=0.1) -class YuanrongWorkerActor: - """Ray actor to manage Yuanrong datasystem worker on a node. - - This actor runs on each node in the Ray cluster and is responsible for - starting and stopping the Yuanrong datasystem worker process on that node. - - The actor determines its own rank and role (head or worker) by finding the - intersection of local IP addresses with the provided node IPs. - """ - - def __init__(self, node_ips: list[str], worker_port: int, metastore_port: int, worker_args: str = ""): - """Initialize the Yuanrong worker actor. - - Args: - node_ips: List of all node IPs in the Ray cluster - worker_port: Port for the datasystem worker - metastore_port: Port for the metastore service (on head node) - worker_args: Additional arguments to append to dscli start command - - Raises: - RuntimeError: If cannot determine this node's IP from node_ips - """ - local_ips = get_local_ip_addresses() - self.my_ip = None - - # Find the intersection between local IPs and node_ips - for ip in node_ips: - if ip in local_ips: - self.my_ip = ip - break - - if self.my_ip is None: - raise RuntimeError(f"Cannot determine local node IP. Local IPs: {local_ips}, Cluster node IPs: {node_ips}") - - self.node_ips = node_ips - self.worker_port = worker_port - self.metastore_port = metastore_port - self.worker_address = f"{self.my_ip}:{worker_port}" - self.worker_args = worker_args - - # First node in the list is assumed to be the head node. - # This assumption is based on how interface.py constructs node_ips from ray.nodes(). - self.head_node_ip = node_ips[0] - self.metastore_address = f"{self.head_node_ip}:{metastore_port}" - self.is_head = self.my_ip == self.head_node_ip - - logger.info( - f"YuanrongWorkerActor initialized on node {self.my_ip}: " - f"worker_address={self.worker_address}, " - f"metastore_address={self.metastore_address}, is_head={self.is_head}, worker_args={self.worker_args}" - ) - - def start(self) -> str: - """Start the datasystem worker on this node. - - Returns: - The worker address. - - Raises: - RuntimeError: If dscli command fails - """ - logger.info(f"Starting datasystem worker at {self.worker_address}...") - start_datasystem_worker( - self.worker_address, - metastore_address=self.metastore_address, - is_head=self.is_head, - worker_args=self.worker_args, - ) - logger.info(f"Datasystem worker started successfully at {self.worker_address}") - return self.worker_address - - def get_metastore_address(self) -> str: - """Get the metastore address. - - Returns: - The metastore address in format host:port - """ - return self.metastore_address - - def get_node_ip(self) -> str: - """Return the IP address of the node this actor is running on.""" - assert self.my_ip is not None - return self.my_ip - - def stop(self) -> None: - """Stop the datasystem worker on this node.""" - logger.info(f"Stopping datasystem worker at {self.worker_address}...") - stop_datasystem_worker(self.worker_address) - logger.info(f"Datasystem worker stopped successfully at {self.worker_address}") - - -def _kill_actors_and_placement_group(worker_actors: list, placement_group: Any) -> None: +def kill_actors_and_placement_group(worker_actors: list, placement_group: Any) -> None: """Kill actors and remove placement group without stopping workers. Args: @@ -420,169 +185,6 @@ def cleanup_yuanrong_resources(storage_value: Any) -> None: logger.warning(f"Encountered {len(stop_exceptions)} errors while stopping workers") finally: # Kill actors and remove placement group even if graceful stop fails. - _kill_actors_and_placement_group(worker_actors, placement_group) + kill_actors_and_placement_group(worker_actors, placement_group) if placement_group: logger.info("Removed Yuanrong placement group") - - -def initialize_yuanrong_backend(conf: DictConfig) -> dict[str, Any]: - """Initialize Yuanrong backend with metastore mode. - - This function sets up the Yuanrong datasystem workers across all Ray nodes - using placement groups and actors. - - Args: - conf: Configuration containing Yuanrong settings - - Returns: - Dict containing worker_actors, metastore_address, and placement_group - - Raises: - RuntimeError: If Ray nodes not found or initialization fails - """ - # Get Ray cluster information - nodes = ray.nodes() - if not nodes: - raise RuntimeError("No Ray nodes found. Is Ray initialized?") - - # Filter to only alive nodes and get their IPs - alive_nodes = [node for node in nodes if node.get("Alive", False)] - if not alive_nodes: - raise RuntimeError("No alive Ray nodes found") - - # Get driver node IP to use as head node - driver_ip = ray.util.get_node_ip_address() - head_node = None - other_nodes = [] - - # Separate head node (driver) from other nodes - for node in alive_nodes: - node_ip = node["NodeManagerAddress"] - if node_ip == driver_ip: - head_node = node - else: - other_nodes.append(node) - - if head_node is None: - raise RuntimeError(f"Driver node {driver_ip} not found in alive nodes") - - # Reorder nodes: head node first, then others - ordered_nodes = [head_node] + other_nodes - - # Extract node IPs in deterministic order - node_ips = [node["NodeManagerAddress"] for node in ordered_nodes] - worker_port = conf.backend.Yuanrong.worker_port - metastore_port = conf.backend.Yuanrong.metastore_port - worker_args = conf.backend.Yuanrong.get("worker_args", "") - - logger.info(f"Found {len(ordered_nodes)} alive Ray nodes: {node_ips}") - - # Create placement group using STRICT_SPREAD to ensure each bundle is on a distinct node - bundles = [{"CPU": 0.1} for _ in ordered_nodes] - - pg = ray.util.placement_group(bundles, strategy="STRICT_SPREAD") - try: - ray.get(pg.ready(), timeout=60) - except ray.exceptions.GetTimeoutError as e: - try: - ray.util.remove_placement_group(pg) - except Exception as cleanup_error: - logger.warning(f"Failed to remove placement group after readiness timeout: {cleanup_error}") - raise RuntimeError( - "Timed out waiting for Yuanrong placement group to become ready. " - f"Requested strategy=STRICT_SPREAD, bundles={bundles}. " - "This may be due to insufficient cluster capacity." - ) from e - except Exception as e: - try: - ray.util.remove_placement_group(pg) - except Exception as cleanup_error: - logger.warning(f"Failed to remove placement group after scheduling failure: {cleanup_error}") - raise RuntimeError( - f"Failed to create Yuanrong placement group. Requested strategy=STRICT_SPREAD, bundles={bundles}." - ) from e - - logger.info(f"Created placement group with {len(bundles)} bundles using STRICT_SPREAD") - - try: - # Create all worker actors using placement group - # Without node resources, actor scheduling order is not guaranteed to match node order - # We'll identify head node actor by checking which node it runs on - worker_actors = [] - for rank in range(len(ordered_nodes)): - actor = YuanrongWorkerActor.options( # type: ignore[attr-defined] - placement_group=pg, - placement_group_bundle_index=rank, - ).remote(node_ips, worker_port, metastore_port, worker_args) - worker_actors.append(actor) - - logger.info(f"Created {len(worker_actors)} YuanrongWorkerActor instances") - - # Find which actor is running on the head node (driver IP) - # The head node actor needs to start first to initialize metastore service - head_actor_index = None - for idx, actor in enumerate(worker_actors): - try: - node_ip = ray.get(actor.get_node_ip.remote()) - if node_ip == driver_ip: - head_actor_index = idx - break - except Exception: - pass - - if head_actor_index is None: - logger.warning("Could not identify head node actor, using actor 0 as default") - head_actor_index = 0 - - logger.info(f"Head node actor identified: actor {head_actor_index}") - - # Start head worker first to initialize metastore service - logger.info("Starting head worker to initialize metastore...") - ray.get(worker_actors[head_actor_index].start.remote()) - metastore_address = ray.get(worker_actors[head_actor_index].get_metastore_address.remote()) - logger.info(f"Head worker started, metastore address: {metastore_address}") - - # Start remaining worker actors in parallel - other_actors = [worker_actors[i] for i in range(len(worker_actors)) if i != head_actor_index] - if other_actors: - logger.info(f"Starting {len(other_actors)} worker actors in parallel...") - ray.get([actor.start.remote() for actor in other_actors]) - - logger.info( - f"Yuanrong backend started successfully: metastore at {metastore_address}, workers on {len(node_ips)} nodes" - ) - - return { - "worker_actors": worker_actors, - "metastore_address": metastore_address, - "placement_group": pg, - } - except Exception as e: - # Cleanup on initialization failure: attempt graceful stop of started workers first - logger.error(f"Failed to start Yuanrong workers: {e}, cleaning up...") - - # Try to gracefully stop workers that may have already started - if worker_actors: - stop_exceptions = [] - # Stop worker nodes (all except head node 0) first - if len(worker_actors) > 1: - stop_refs = [actor.stop.remote() for actor in worker_actors[1:]] - for idx, stop_ref in enumerate(stop_refs, start=1): - try: - ray.get(stop_ref, timeout=30) - except Exception as stop_e: - stop_exceptions.append(stop_e) - logger.warning(f"Failed to stop worker node actor {idx}: {stop_e}") - # Stop head node (actor 0) - try: - ray.get(worker_actors[0].stop.remote(), timeout=30) - except Exception as stop_e: - stop_exceptions.append(stop_e) - logger.warning(f"Failed to stop head node actor: {stop_e}") - - if stop_exceptions: - logger.warning(f"Encountered {len(stop_exceptions)} errors during graceful worker stop") - - # Then kill actors and remove placement group - _kill_actors_and_placement_group(worker_actors, pg) - raise From cbb1a562d3ae573c53412a66faeea0d0333d106b Mon Sep 17 00:00:00 2001 From: fy2462 Date: Sun, 17 May 2026 20:13:09 +0800 Subject: [PATCH 2/5] [refact] Renamed files and code variables for PR comments. Signed-off-by: fy2462 --- scripts/put_benchmark.py | 2 +- tests/test_metadata.py | 2 +- tests/test_simple_storage_unit.py | 4 +- transfer_queue/interface.py | 18 +- transfer_queue/storage/__init__.py | 2 +- .../storage/backends/simple_storage.py | 663 ------------------ .../{backends => bootstrap}/__init__.py | 9 +- .../mooncake_bootstrap.py} | 8 +- .../base.py => bootstrap/provider.py} | 16 +- .../storage/bootstrap/simple_bootstrap.py | 54 ++ .../yuanrong_bootstrap.py} | 4 +- 11 files changed, 86 insertions(+), 696 deletions(-) delete mode 100644 transfer_queue/storage/backends/simple_storage.py rename transfer_queue/storage/{backends => bootstrap}/__init__.py (71%) rename transfer_queue/storage/{backends/mooncake_storage.py => bootstrap/mooncake_bootstrap.py} (93%) rename transfer_queue/storage/{backends/base.py => bootstrap/provider.py} (69%) create mode 100644 transfer_queue/storage/bootstrap/simple_bootstrap.py rename transfer_queue/storage/{backends/yuanrong_storage.py => bootstrap/yuanrong_bootstrap.py} (99%) diff --git a/scripts/put_benchmark.py b/scripts/put_benchmark.py index 55289f6f..c67bb54c 100644 --- a/scripts/put_benchmark.py +++ b/scripts/put_benchmark.py @@ -30,7 +30,7 @@ from transfer_queue import TransferQueueClient from transfer_queue.controller import TransferQueueController -from transfer_queue.storage.backends.simple_storage import SimpleStorageUnit +from transfer_queue.storage.simple_storage import SimpleStorageUnit from transfer_queue.utils.common import get_placement_group from transfer_queue.utils.zmq_utils import process_zmq_server_info diff --git a/tests/test_metadata.py b/tests/test_metadata.py index e9218de8..6fc49d2e 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -807,7 +807,7 @@ class TestStorageUnitDataStrict: def test_put_data_length_mismatch_raises(self): """put_data must raise when global_indexes and field values have different lengths.""" - from transfer_queue.storage.backends.simple_storage import StorageUnitData + from transfer_queue.storage.simple_storage import StorageUnitData sud = StorageUnitData(storage_size=10) # 3 indexes but only 2 values — must raise, not silently drop diff --git a/tests/test_simple_storage_unit.py b/tests/test_simple_storage_unit.py index c0d084b2..319a46e7 100644 --- a/tests/test_simple_storage_unit.py +++ b/tests/test_simple_storage_unit.py @@ -21,7 +21,7 @@ import torch import zmq -from transfer_queue.storage.backends.simple_storage import SimpleStorageUnit +from transfer_queue.storage.simple_storage import SimpleStorageUnit from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType @@ -420,7 +420,7 @@ def test_storage_unit_data_direct(): def test_storage_unit_data_capacity_uses_active_keys(): """Capacity check must use _active_keys, not scan field_data.""" - from transfer_queue.storage.backends.simple_storage import StorageUnitData + from transfer_queue.storage.simple_storage import StorageUnitData storage = StorageUnitData(storage_size=3) diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 0ee31508..933b42c3 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -30,7 +30,7 @@ from transfer_queue.metadata import KVBatchMeta from transfer_queue.sampler import * # noqa: F401 from transfer_queue.sampler import BaseSampler -from transfer_queue.storage.backends.base import StorageBackendFactory +from transfer_queue.storage.bootstrap import StorageBootstrapProvider from transfer_queue.utils.logging_utils import get_logger from transfer_queue.utils.yuanrong_utils import cleanup_yuanrong_resources from transfer_queue.utils.zmq_utils import process_zmq_server_info @@ -70,15 +70,17 @@ def _maybe_create_tq_storage(conf: DictConfig) -> DictConfig: if _TQ_STORAGE is None: _TQ_STORAGE = {} backend_name = conf.backend.storage_backend - registered_backend_fn = StorageBackendFactory.get_backend(backend_name) - if registered_backend_fn: - backend_instance = registered_backend_fn(conf) - if backend_instance: - _TQ_STORAGE[backend_name] = backend_instance + provider_fn = StorageBootstrapProvider.get_provider(backend_name) + if provider_fn is not None: + backend_resources = provider_fn(conf) + if backend_resources is not None: + _TQ_STORAGE[backend_name] = backend_resources else: - logger.error(f"Not found available {backend_name} storage backend instance, please check the config.") + logger.error(f"Not found available {backend_name} storage resources, please check the config.") else: - logger.error(f"Storage backend {backend_name} not registered. Please add it to the StorageBackendFactory.") + logger.error( + f"Storage backend {backend_name} not registered. Please add it to the StorageBootstrapProvider." + ) return conf diff --git a/transfer_queue/storage/__init__.py b/transfer_queue/storage/__init__.py index 809eeecc..2fb1be46 100644 --- a/transfer_queue/storage/__init__.py +++ b/transfer_queue/storage/__init__.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .backends import SimpleStorageUnit, StorageUnitData from .managers import ( AsyncSimpleStorageManager, MooncakeStorageManager, @@ -22,6 +21,7 @@ StorageManagerFactory, YuanrongStorageManager, ) +from .simple_storage import SimpleStorageUnit, StorageUnitData __all__ = [ "SimpleStorageUnit", diff --git a/transfer_queue/storage/backends/simple_storage.py b/transfer_queue/storage/backends/simple_storage.py deleted file mode 100644 index bb4fea4a..00000000 --- a/transfer_queue/storage/backends/simple_storage.py +++ /dev/null @@ -1,663 +0,0 @@ -# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2025 The TransferQueue Team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -import os -import time -import weakref -from threading import Event, Thread -from typing import TYPE_CHECKING, Any -from uuid import uuid4 - -import psutil -import ray -import zmq -from omegaconf import DictConfig - -from transfer_queue.storage.backends.base import StorageBackendFactory -from transfer_queue.utils.common import get_placement_group, limit_pytorch_auto_parallel_threads -from transfer_queue.utils.enum_utils import Role -from transfer_queue.utils.logging_utils import get_logger -from transfer_queue.utils.perf_utils import IntervalPerfMonitor -from transfer_queue.utils.zmq_utils import ( - ZMQMessage, - ZMQRequestType, - ZMQServerInfo, - create_zmq_socket, - format_zmq_address, - get_free_port, - get_node_ip_address, - process_zmq_server_info, -) - -if TYPE_CHECKING: - from transfer_queue.metrics import TQMetricsExporter - -logger = get_logger(__name__) - -TQ_STORAGE_POLLER_TIMEOUT = int(os.environ.get("TQ_STORAGE_POLLER_TIMEOUT", 5)) # in seconds -TQ_NUM_THREADS = int(os.environ.get("TQ_NUM_THREADS", 8)) - - -class StorageUnitData: - """Storage unit for managing 2D data structure (samples × fields). - - Uses dict-based storage keyed by global_index instead of pre-allocated list. - This allows O(1) insert/delete without index translation and avoids capacity bloat. - - Data Structure Example: - field_data = { - "field_name1": {global_index_0: item1, global_index_3: item2, ...}, - "field_name2": {global_index_0: item3, global_index_3: item4, ...}, - } - """ - - def __init__(self, storage_size: int): - # field_name -> {global_index: data} nested dict - self.field_data: dict[str, dict] = {} - # Capacity upper bound (not pre-allocated list length) - self.storage_size = storage_size - # Track active global_index keys for O(1) capacity checks - self._active_keys: set = set() - - @property - def active_key_count(self) -> int: - """Number of active keys currently stored.""" - return len(self._active_keys) - - def get_data(self, fields: list[str], global_indexes: list) -> dict[str, list]: - """Get data by global index keys. - - Args: - fields: Field names used for getting data. - global_indexes: Global indexes used as dict keys. - - Returns: - dict with field names as keys, corresponding data list as values. - """ - result: dict[str, list] = {} - for field in fields: - if field not in self.field_data: - raise ValueError( - f"StorageUnitData get_data: field '{field}' not found. Available: {list(self.field_data.keys())}" - ) - try: - result[field] = [self.field_data[field][k] for k in global_indexes] - except KeyError as e: - raise KeyError(f"StorageUnitData get_data: key {e} not found in field '{field}'") from e - return result - - def put_data(self, field_data: dict[str, Any], global_indexes: list) -> None: - """Put data into storage. - - Args: - field_data: Dict with field names as keys, data list as values. - global_indexes: Global indexes to use as dict keys. - """ - # Capacity is enforced per unique sample key, not counted per-field - new_global_keys = [k for k in global_indexes if k not in self._active_keys] - if len(self._active_keys) + len(new_global_keys) > self.storage_size: - raise ValueError( - f"Storage capacity exceeded: {len(self._active_keys)} existing + " - f"{len(new_global_keys)} new > {self.storage_size}" - ) - for f, values in field_data.items(): - if len(values) != len(global_indexes): - raise ValueError( - f"StorageUnitData put_data: field '{f}' values length {len(values)} " - f"!= global_indexes length {len(global_indexes)}, length mismatch" - ) - if f not in self.field_data: - self.field_data[f] = {} - field_dict = self.field_data[f] - for key, val in zip(global_indexes, values, strict=True): - field_dict[key] = val - self._active_keys.update(global_indexes) - - def clear(self, keys: list[int]) -> None: - """Remove data at given global index keys, immediately freeing memory. - - Args: - keys: Global indexes to remove. - """ - for f in self.field_data: - for key in keys: - self.field_data[f].pop(key, None) - self._active_keys -= set(keys) - - -@ray.remote(num_cpus=1) -class SimpleStorageUnit: - """A storage unit that provides distributed data storage functionality. - - This class represents a storage unit that can store data in a 2D structure - (samples, data_fields) and provides ZMQ-based communication for put/get/clear operations. - - Note: We use Ray decorator (@ray.remote) only for initialization purposes. - We do NOT use Ray's .remote() call capabilities - the storage unit runs - as a standalone process with its own ZMQ server socket. - - Attributes: - storage_unit_id: Unique identifier for this storage unit. - storage_unit_size: Maximum number of elements that can be stored. - storage_data: Internal StorageUnitData instance for data management. - zmq_server_info: ZMQ connection information for clients. - """ - - def __init__(self, storage_unit_size: int): - """Initialize a SimpleStorageUnit with the specified size. - - Args: - storage_unit_size: Maximum number of elements that can be stored in this storage unit. - """ - self.storage_unit_id = f"TQ_STORAGE_UNIT_{uuid4().hex[:8]}" - self.storage_unit_size = storage_unit_size - - self.storage_data = StorageUnitData(self.storage_unit_size) - - # Internal communication address for proxy and workers - self._inproc_addr = f"inproc://simple_storage_workers_{self.storage_unit_id}" - - # Shutdown event for graceful termination - self._shutdown_event = Event() - - # Placeholder for zmq_context, proxy_thread and worker_threads - self.zmq_context: zmq.Context | None = None - self.put_get_socket: zmq.Socket | None = None - self.proxy_thread: Thread | None = None - self.worker_thread: Thread | None = None - - self._metrics: TQMetricsExporter | None = None - - self._init_zmq_socket() - self._start_process_put_get() - - # Register finalizer for graceful cleanup when garbage collected - self._finalizer = weakref.finalize( - self, - self._shutdown_resources, - self._shutdown_event, - self.worker_thread, - self.proxy_thread, - self.zmq_context, - self.put_get_socket, - ) - - def _init_zmq_socket(self) -> None: - """ - Initialize ZMQ socket connections between storage unit and controller/clients: - - put_get_socket (ROUTER): Handle put/get requests from clients. - - worker_socket (DEALER): Backend socket for worker communication. - """ - self.zmq_context = zmq.Context() - self._node_ip = get_node_ip_address() - - # Frontend: ROUTER for receiving client requests - self.put_get_socket = create_zmq_socket(self.zmq_context, zmq.ROUTER, self._node_ip) - - while True: - try: - self._put_get_socket_port = get_free_port(ip=self._node_ip) - self.put_get_socket.bind(format_zmq_address(self._node_ip, self._put_get_socket_port)) - break - except zmq.ZMQError: - logger.warning(f"[{self.storage_unit_id}]: Try to bind ZMQ sockets failed, retrying...") - continue - - # Backend: DEALER for worker communication (connected via zmq.proxy) - self.worker_socket = create_zmq_socket(self.zmq_context, zmq.DEALER, self._node_ip) - self.worker_socket.bind(self._inproc_addr) - - self.zmq_server_info = ZMQServerInfo( - role=Role.STORAGE, - id=str(self.storage_unit_id), - ip=self._node_ip, - ports={"put_get_socket": self._put_get_socket_port}, - ) - - def _start_process_put_get(self) -> None: - """Start worker threads and ZMQ proxy for handling requests.""" - - # Start worker thread - self.worker_thread = Thread( - target=self._worker_routine, - name=f"StorageUnitWorkerThread-{self.storage_unit_id}", - daemon=True, - ) - self.worker_thread.start() - - time.sleep(0.5) # make sure worker thread is ready before zmq.proxy forwarding messages - - # Start proxy thread (ROUTER <-> DEALER) - self.proxy_thread = Thread( - target=self._proxy_routine, - name=f"StorageUnitProxyThread-{self.storage_unit_id}", - daemon=True, - ) - self.proxy_thread.start() - - def _proxy_routine(self) -> None: - """ZMQ proxy for message forwarding between frontend ROUTER and backend DEALER.""" - logger.info(f"[{self.storage_unit_id}]: start ZMQ proxy...") - try: - zmq.proxy(self.put_get_socket, self.worker_socket) - except zmq.ContextTerminated: - logger.info(f"[{self.storage_unit_id}]: ZMQ Proxy stopped gracefully (Context Terminated)") - except Exception as e: - if self._shutdown_event.is_set(): - logger.info(f"[{self.storage_unit_id}]: ZMQ Proxy shutting down...") - else: - logger.error(f"[{self.storage_unit_id}]: ZMQ Proxy unexpected error: {e}") - - def _worker_routine(self) -> None: - """Worker thread for processing requests.""" - - worker_socket = create_zmq_socket(self.zmq_context, zmq.DEALER, self._node_ip) - worker_socket.connect(self._inproc_addr) - - poller = zmq.Poller() - poller.register(worker_socket, zmq.POLLIN) - - logger.info(f"[{self.storage_unit_id}]: worker thread started...") - perf_monitor = IntervalPerfMonitor(caller_name=f"{self.storage_unit_id}") - - while not self._shutdown_event.is_set(): - monitor = self._metrics if self._metrics is not None else perf_monitor - try: - socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT * 1000)) - except zmq.error.ContextTerminated: - # ZMQ context was terminated, exit gracefully - logger.info(f"[{self.storage_unit_id}]: worker stopped gracefully (Context Terminated)") - break - except Exception as e: - logger.warning(f"[{self.storage_unit_id}]: worker poll error: {e}") - continue - - if self._shutdown_event.is_set(): - break - - if worker_socket in socks: - # Messages received from proxy: [identity, serialized_msg_frame1, ...] - messages = worker_socket.recv_multipart(copy=False) - identity = messages[0] - serialized_msg = messages[1:] - - request_msg = ZMQMessage.deserialize(serialized_msg) - operation = request_msg.request_type - - try: - logger.debug(f"[{self.storage_unit_id}]: worker received operation: {operation}") - - # Process request - if operation == ZMQRequestType.PUT_DATA: # type: ignore[arg-type] - with monitor.measure(op_type="PUT_DATA"): - response_msg = self._handle_put(request_msg) - elif operation == ZMQRequestType.GET_DATA: # type: ignore[arg-type] - with monitor.measure(op_type="GET_DATA"): - response_msg = self._handle_get(request_msg) - elif operation == ZMQRequestType.CLEAR_DATA: # type: ignore[arg-type] - with monitor.measure(op_type="CLEAR_DATA"): - response_msg = self._handle_clear(request_msg) - elif operation == ZMQRequestType.GET_METRICS: # type: ignore[arg-type] - response_msg = self._handle_get_metrics() - else: - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.PUT_GET_OPERATION_ERROR, # type: ignore[arg-type] - sender_id=self.storage_unit_id, - body={ - "message": f"Storage unit id #{self.storage_unit_id} " - f"receive invalid operation: {operation}." - }, - ) - except Exception as e: - logger.error( - f"[{self.storage_unit_id}]: worker error during {operation} " - f"from sender={request_msg.sender_id}: {type(e).__name__}: {e}" - ) - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.PUT_GET_ERROR, # type: ignore[arg-type] - sender_id=self.storage_unit_id, - body={ - "message": f"{self.storage_unit_id}, worker encountered error " - f"during operation {operation}: {str(e)}." - }, - ) - - # Send response back with identity for routing - worker_socket.send_multipart([identity] + response_msg.serialize(), copy=False) - - logger.info(f"[{self.storage_unit_id}]: worker stopped.") - poller.unregister(worker_socket) - worker_socket.close(linger=0) - - def _handle_put(self, data_parts: ZMQMessage) -> ZMQMessage: - """ - Handle put request, add or update data into storage unit. - - Args: - data_parts: ZMQMessage from client. - - Returns: - Put data success response ZMQMessage. - """ - try: - global_indexes = data_parts.body["global_indexes"] - field_data = data_parts.body["data"] # field_data should be a dict. - data_parser = data_parts.body.get("data_parser", None) - - with limit_pytorch_auto_parallel_threads( - target_num_threads=TQ_NUM_THREADS, info=f"[{self.storage_unit_id}] _handle_put" - ): - if data_parser is not None: - if not callable(data_parser): - raise TypeError(f"data_parser must be callable, got {type(data_parser).__name__}") - - original_keys = set(field_data.keys()) - original_lengths = {} - for k, v in field_data.items(): - if hasattr(v, "shape") and isinstance(v.shape, tuple | list) and len(v.shape) > 0: - original_lengths[k] = v.shape[0] - else: - try: - original_lengths[k] = len(v) - except Exception: - original_lengths[k] = None - - field_data = data_parser(field_data) - - if not isinstance(field_data, dict): - raise TypeError(f"data_parser must return a dict, got {type(field_data).__name__}") - - new_keys = set(field_data.keys()) - if new_keys != original_keys: - raise ValueError( - f"data_parser must not change dict keys. " - f"Original keys: {sorted(original_keys)}, got: {sorted(new_keys)}" - ) - - for k, v in field_data.items(): - if hasattr(v, "shape") and isinstance(v.shape, tuple | list) and len(v.shape) > 0: - new_len = v.shape[0] - else: - try: - new_len = len(v) - except Exception: - new_len = None - - orig_len = original_lengths[k] - if orig_len is not None and new_len is not None and orig_len != new_len: - raise ValueError( - f"data_parser changed the number of elements for key '{k}': " - f"expected {orig_len}, got {new_len}" - ) - self.storage_data.put_data(field_data, global_indexes) - - # After put operation finish, send a message to the client - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.PUT_DATA_RESPONSE, # type: ignore[arg-type] - sender_id=self.storage_unit_id, - body={}, - ) - - return response_msg - except Exception as e: - return ZMQMessage.create( - request_type=ZMQRequestType.PUT_ERROR, # type: ignore[arg-type] - sender_id=self.storage_unit_id, - body={ - "message": f"Failed to put data into storage unit id " - f"#{self.storage_unit_id}, detail error message: {str(e)}" - }, - ) - - def _handle_get(self, data_parts: ZMQMessage) -> ZMQMessage: - """ - Handle get request, return data from storage unit. - - Args: - data_parts: ZMQMessage from client. - - Returns: - Get data success response ZMQMessage, containing target data. - """ - try: - fields = data_parts.body["fields"] - global_indexes = data_parts.body["global_indexes"] - - with limit_pytorch_auto_parallel_threads( - target_num_threads=TQ_NUM_THREADS, info=f"[{self.storage_unit_id}] _handle_get" - ): - result_data = self.storage_data.get_data(fields, global_indexes) - - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_DATA_RESPONSE, # type: ignore[arg-type] - sender_id=self.storage_unit_id, - body={ - "data": result_data, - }, - ) - except Exception as e: - logger.error( - f"[{self.storage_unit_id}]: _handle_get error, " - f"fields={fields}, global_indexes={global_indexes}: {type(e).__name__}: {e}" - ) - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_ERROR, # type: ignore[arg-type] - sender_id=self.storage_unit_id, - body={ - "message": f"Failed to get data from storage unit id #{self.storage_unit_id}, " - f"detail error message: {str(e)}" - }, - ) - return response_msg - - def _handle_clear(self, data_parts: ZMQMessage) -> ZMQMessage: - """ - Handle clear request, clear data in storage unit according to given global_indexes. - - Args: - data_parts: ZMQMessage from client, including target global_indexes. - - Returns: - Clear data success response ZMQMessage. - """ - try: - global_indexes = data_parts.body["global_indexes"] - - with limit_pytorch_auto_parallel_threads( - target_num_threads=TQ_NUM_THREADS, info=f"[{self.storage_unit_id}] _handle_clear" - ): - self.storage_data.clear(global_indexes) - - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.CLEAR_DATA_RESPONSE, # type: ignore[arg-type] - sender_id=self.storage_unit_id, - body={"message": f"Clear data in storage unit id #{self.storage_unit_id} successfully."}, - ) - except Exception as e: - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.CLEAR_DATA_ERROR, # type: ignore[arg-type] - sender_id=self.storage_unit_id, - body={ - "message": f"Failed to clear data in storage unit id #{self.storage_unit_id}, " - f"detail error message: {str(e)}" - }, - ) - return response_msg - - def _handle_get_metrics(self) -> ZMQMessage: - """Handle GET_METRICS request by returning storage unit statistics. - - Returns: - ZMQMessage containing storage unit ID, capacity, active keys, - process RSS memory, and per-operation request stats. - """ - try: - process_rss = psutil.Process().memory_info().rss - except Exception: - process_rss = 0 - - metrics = { - "storage_unit_id": self.storage_unit_id, - "capacity": self.storage_unit_size, - "active_keys": self.storage_data.active_key_count, - "process_rss_bytes": process_rss, - } - - # Include per-operation stats if Prometheus metrics are enabled - if self._metrics is not None: - op_stats = {} - for op_type in ("PUT_DATA", "GET_DATA", "CLEAR_DATA"): - try: - hist = self._metrics.request_duration.labels(op_type=op_type) - counter = self._metrics.request_total.labels(op_type=op_type) - duration_sum = hist._sum.get() - # Build cumulative counts once, reuse for total and quantiles - cumulative_counts = self._cumulative_bucket_counts(hist) - duration_count = cumulative_counts[-1] if cumulative_counts else 0 - op_stats[op_type] = { - "request_count": counter._value.get(), - "latency_avg": duration_sum / duration_count if duration_count > 0 else 0, - "latency_p50": self._quantile_from_cumulative(hist, cumulative_counts, 0.50), - "latency_p99": self._quantile_from_cumulative(hist, cumulative_counts, 0.99), - } - except (AttributeError, TypeError, ZeroDivisionError) as e: - logger.debug(f"[{self.storage_unit_id}]: Failed to extract metrics for {op_type}: {e}") - if op_stats: - metrics["op_stats"] = op_stats - - return ZMQMessage.create( - request_type=ZMQRequestType.METRICS_RESPONSE, - sender_id=self.storage_unit_id, - body=metrics, - ) - - @staticmethod - def _cumulative_bucket_counts(hist) -> list[float]: - """Build cumulative counts from a prometheus_client Histogram's non-cumulative buckets.""" - cumulative = 0.0 - counts = [] - for bucket in hist._buckets: - cumulative += bucket.get() - counts.append(cumulative) - return counts - - @staticmethod - def _quantile_from_cumulative(hist, cumulative_counts: list[float], q: float) -> float: - """Estimate a quantile using pre-computed cumulative bucket counts. - - Uses linear interpolation matching Prometheus histogram_quantile() logic. - """ - total = cumulative_counts[-1] if cumulative_counts else 0 - if total == 0: - return 0.0 - target = q * total - prev_bound = 0.0 - prev_cumulative = 0.0 - for bound, cum_count in zip(hist._upper_bounds, cumulative_counts, strict=False): - if cum_count >= target: - fraction = ( - (target - prev_cumulative) / (cum_count - prev_cumulative) if cum_count > prev_cumulative else 0 - ) - return prev_bound + (bound - prev_bound) * fraction - prev_bound = bound - prev_cumulative = cum_count - return prev_bound - - @staticmethod - def _shutdown_resources( - shutdown_event: Event, - worker_thread: Thread | None, - proxy_thread: Thread | None, - zmq_context: zmq.Context | None, - put_get_socket: zmq.Socket | None, - ) -> None: - """Clean up resources on garbage collection.""" - logger.info("Shutting down SimpleStorageUnit resources...") - - # Signal all threads to stop - shutdown_event.set() - - # Terminate put_get_socket - if put_get_socket: - put_get_socket.close(linger=0) - - # Terminate ZMQ context to unblock proxy and workers - if zmq_context: - zmq_context.term() - - # Wait for threads to finish (with timeout) - if worker_thread and worker_thread.is_alive(): - worker_thread.join(timeout=5) - if proxy_thread and proxy_thread.is_alive(): - proxy_thread.join(timeout=5) - - logger.info("SimpleStorageUnit resources shutdown complete.") - - def start_metrics(self, port: int = 0) -> str: - """Initialize and start the Prometheus metrics exporter for this storage unit. - - When enabled, replaces ``IntervalPerfMonitor`` for request latency/throughput - tracking with Prometheus counters and histograms. - - Args: - port: HTTP port for the /metrics endpoint (0 = auto-assign). - - Returns: - The metrics endpoint address in ``host:port`` format. - """ - if self._metrics is not None: - return self._metrics.endpoint - from transfer_queue.metrics import TQMetricsExporter - - self._metrics = TQMetricsExporter(role="storage") - endpoint = self._metrics.start(node_ip=self._node_ip, port=port) - logger.info(f"[{self.storage_unit_id}]: Prometheus metrics exporter started on {endpoint}") - return endpoint - - def get_zmq_server_info(self) -> ZMQServerInfo: - """Get the ZMQ server information for this storage unit. - - Returns: - ZMQServerInfo containing connection details for this storage unit. - """ - return self.zmq_server_info - - -@StorageBackendFactory.register_backend("SimpleStorage") -def initialize_simple_backend(conf: DictConfig) -> dict[str, Any]: - """Initialize Simple backend with metastore mode.""" - - simple_storage_handles = {} - num_data_storage_units = conf.backend.SimpleStorage.num_data_storage_units - total_storage_size = conf.backend.SimpleStorage.total_storage_size - storage_placement_group = get_placement_group(num_data_storage_units, num_cpus_per_actor=1) - - for storage_unit_rank in range(num_data_storage_units): - storage_node = SimpleStorageUnit.options( # type: ignore[attr-defined] - placement_group=storage_placement_group, - placement_group_bundle_index=storage_unit_rank, - name=f"TransferQueueStorageUnit#{storage_unit_rank}", - ).remote( - storage_unit_size=math.ceil(total_storage_size / num_data_storage_units), - ) - simple_storage_handles[f"TransferQueueStorageUnit#{storage_unit_rank}"] = storage_node - logger.info(f"TransferQueueStorageUnit#{storage_unit_rank} has been created.") - - storage_zmq_info = process_zmq_server_info(simple_storage_handles) - backend_name = conf.backend.storage_backend - conf.backend[backend_name].zmq_info = storage_zmq_info - - return simple_storage_handles diff --git a/transfer_queue/storage/backends/__init__.py b/transfer_queue/storage/bootstrap/__init__.py similarity index 71% rename from transfer_queue/storage/backends/__init__.py rename to transfer_queue/storage/bootstrap/__init__.py index 2056c9c7..874dd174 100644 --- a/transfer_queue/storage/backends/__init__.py +++ b/transfer_queue/storage/bootstrap/__init__.py @@ -13,12 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import mooncake_storage, simple_storage, yuanrong_storage # noqa: F401, I001 -from .base import StorageBackendFactory -from .simple_storage import SimpleStorageUnit, StorageUnitData +from . import mooncake_bootstrap, simple_bootstrap, yuanrong_bootstrap # noqa: F401, I001 +from .provider import StorageBootstrapProvider __all__ = [ - "StorageBackendFactory", - "SimpleStorageUnit", - "StorageUnitData", + "StorageBootstrapProvider", ] diff --git a/transfer_queue/storage/backends/mooncake_storage.py b/transfer_queue/storage/bootstrap/mooncake_bootstrap.py similarity index 93% rename from transfer_queue/storage/backends/mooncake_storage.py rename to transfer_queue/storage/bootstrap/mooncake_bootstrap.py index 70116e8d..31abe638 100644 --- a/transfer_queue/storage/backends/mooncake_storage.py +++ b/transfer_queue/storage/bootstrap/mooncake_bootstrap.py @@ -20,20 +20,20 @@ from omegaconf import DictConfig -from transfer_queue.storage.backends.base import StorageBackendFactory +from transfer_queue.storage.bootstrap.provider import StorageBootstrapProvider from transfer_queue.utils.logging_utils import get_logger logger = get_logger(__name__) -@StorageBackendFactory.register_backend("MooncakeStore") -def initialize_mooncake_backend(conf: DictConfig) -> DictConfig: +@StorageBootstrapProvider.register_provider("MooncakeStore") +def initialize_mooncake_backend(conf: DictConfig) -> subprocess.Popen | None: """ Initialize MooncakeStore backend. Args: conf (DictConfig): Configuration dictionary for the MooncakeStore backend. Returns: - DictConfig: Initialized configuration dictionary for the MooncakeStore backend. + subprocess.Popen | None: Process object for the MooncakeStore backend process. Raises: ValueError: If the backend is not initialized successfully. """ diff --git a/transfer_queue/storage/backends/base.py b/transfer_queue/storage/bootstrap/provider.py similarity index 69% rename from transfer_queue/storage/backends/base.py rename to transfer_queue/storage/bootstrap/provider.py index 4efa00e7..0c3399da 100644 --- a/transfer_queue/storage/backends/base.py +++ b/transfer_queue/storage/bootstrap/provider.py @@ -17,24 +17,24 @@ from typing import Callable -class StorageBackendFactory: - _backends: dict[str, Callable] = {} +class StorageBootstrapProvider: + _providers: dict[str, Callable] = {} @classmethod - def register_backend(cls, name: str): - """Decorator to register storage backend & returns function.""" + def register_provider(cls, name: str): + """Decorator to register storage provider & returns function.""" def decorator(fn): @wraps(fn) def wrapper(*args, **kwargs): return fn(*args, **kwargs) - cls._backends[name.lower()] = wrapper + cls._providers[name.lower()] = wrapper return wrapper return decorator @classmethod - def get_backend(cls, name: str) -> Callable | None: - """Get storage backend function by name.""" - return cls._backends.get(name.lower(), None) + def get_provider(cls, name: str) -> Callable | None: + """Get storage provider function by name.""" + return cls._providers.get(name.lower(), None) diff --git a/transfer_queue/storage/bootstrap/simple_bootstrap.py b/transfer_queue/storage/bootstrap/simple_bootstrap.py new file mode 100644 index 00000000..b2350fa8 --- /dev/null +++ b/transfer_queue/storage/bootstrap/simple_bootstrap.py @@ -0,0 +1,54 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Any + +from omegaconf import DictConfig + +from transfer_queue.storage.bootstrap.provider import StorageBootstrapProvider +from transfer_queue.storage.simple_storage import SimpleStorageUnit +from transfer_queue.utils.common import get_placement_group +from transfer_queue.utils.logging_utils import get_logger +from transfer_queue.utils.zmq_utils import process_zmq_server_info + +logger = get_logger(__name__) + + +@StorageBootstrapProvider.register_provider("SimpleStorage") +def initialize_simple_backend(conf: DictConfig) -> dict[str, Any]: + """Initialize Simple backend with metastore mode.""" + + simple_storage_handles = {} + num_data_storage_units = conf.backend.SimpleStorage.num_data_storage_units + total_storage_size = conf.backend.SimpleStorage.total_storage_size + storage_placement_group = get_placement_group(num_data_storage_units, num_cpus_per_actor=1) + + for storage_unit_rank in range(num_data_storage_units): + storage_node = SimpleStorageUnit.options( # type: ignore[attr-defined] + placement_group=storage_placement_group, + placement_group_bundle_index=storage_unit_rank, + name=f"TransferQueueStorageUnit#{storage_unit_rank}", + ).remote( + storage_unit_size=math.ceil(total_storage_size / num_data_storage_units), + ) + simple_storage_handles[f"TransferQueueStorageUnit#{storage_unit_rank}"] = storage_node + logger.info(f"TransferQueueStorageUnit#{storage_unit_rank} has been created.") + + storage_zmq_info = process_zmq_server_info(simple_storage_handles) + backend_name = conf.backend.storage_backend + conf.backend[backend_name].zmq_info = storage_zmq_info + + return simple_storage_handles diff --git a/transfer_queue/storage/backends/yuanrong_storage.py b/transfer_queue/storage/bootstrap/yuanrong_bootstrap.py similarity index 99% rename from transfer_queue/storage/backends/yuanrong_storage.py rename to transfer_queue/storage/bootstrap/yuanrong_bootstrap.py index 57239768..30b28ba6 100644 --- a/transfer_queue/storage/backends/yuanrong_storage.py +++ b/transfer_queue/storage/bootstrap/yuanrong_bootstrap.py @@ -22,7 +22,7 @@ import ray from omegaconf import DictConfig -from transfer_queue.storage.backends.base import StorageBackendFactory +from transfer_queue.storage.bootstrap.provider import StorageBootstrapProvider from transfer_queue.utils.yuanrong_utils import get_local_ip_addresses, kill_actors_and_placement_group logger = logging.getLogger(__name__) @@ -260,7 +260,7 @@ def stop(self) -> None: logger.info(f"Datasystem worker stopped successfully at {self.worker_address}") -@StorageBackendFactory.register_backend("Yuanrong") +@StorageBootstrapProvider.register_provider("Yuanrong") def initialize_yuanrong_backend(conf: DictConfig) -> dict[str, Any] | None: """Initialize Yuanrong backend with metastore mode. From a76b9bbcce94a03d4807dfe0ca0ceb547b8b6ce9 Mon Sep 17 00:00:00 2001 From: fy2462 Date: Mon, 18 May 2026 15:59:01 +0800 Subject: [PATCH 3/5] [refact] Fix license data error from CI. Signed-off-by: fy2462 --- transfer_queue/storage/bootstrap/__init__.py | 6 +++--- ...{mooncake_bootstrap.py => mooncake_storage_bootstrap.py} | 4 ++-- transfer_queue/storage/bootstrap/provider.py | 4 ++-- .../{simple_bootstrap.py => simple_storage_bootstrap.py} | 0 ...{yuanrong_bootstrap.py => yuanrong_storage_bootstrap.py} | 0 5 files changed, 7 insertions(+), 7 deletions(-) rename transfer_queue/storage/bootstrap/{mooncake_bootstrap.py => mooncake_storage_bootstrap.py} (97%) rename transfer_queue/storage/bootstrap/{simple_bootstrap.py => simple_storage_bootstrap.py} (100%) rename transfer_queue/storage/bootstrap/{yuanrong_bootstrap.py => yuanrong_storage_bootstrap.py} (100%) diff --git a/transfer_queue/storage/bootstrap/__init__.py b/transfer_queue/storage/bootstrap/__init__.py index 874dd174..c76840dc 100644 --- a/transfer_queue/storage/bootstrap/__init__.py +++ b/transfer_queue/storage/bootstrap/__init__.py @@ -1,5 +1,5 @@ -# Copyright 2026 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2026 The TransferQueue Team +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import mooncake_bootstrap, simple_bootstrap, yuanrong_bootstrap # noqa: F401, I001 +from . import mooncake_storage_bootstrap, simple_storage_bootstrap, yuanrong_storage_bootstrap # noqa: F401, I001 from .provider import StorageBootstrapProvider __all__ = [ diff --git a/transfer_queue/storage/bootstrap/mooncake_bootstrap.py b/transfer_queue/storage/bootstrap/mooncake_storage_bootstrap.py similarity index 97% rename from transfer_queue/storage/bootstrap/mooncake_bootstrap.py rename to transfer_queue/storage/bootstrap/mooncake_storage_bootstrap.py index 31abe638..51dc27e1 100644 --- a/transfer_queue/storage/bootstrap/mooncake_bootstrap.py +++ b/transfer_queue/storage/bootstrap/mooncake_storage_bootstrap.py @@ -1,5 +1,5 @@ -# Copyright 2026 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2026 The TransferQueue Team +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/transfer_queue/storage/bootstrap/provider.py b/transfer_queue/storage/bootstrap/provider.py index 0c3399da..74505ebb 100644 --- a/transfer_queue/storage/bootstrap/provider.py +++ b/transfer_queue/storage/bootstrap/provider.py @@ -1,5 +1,5 @@ -# Copyright 2026 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2026 The TransferQueue Team +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/transfer_queue/storage/bootstrap/simple_bootstrap.py b/transfer_queue/storage/bootstrap/simple_storage_bootstrap.py similarity index 100% rename from transfer_queue/storage/bootstrap/simple_bootstrap.py rename to transfer_queue/storage/bootstrap/simple_storage_bootstrap.py diff --git a/transfer_queue/storage/bootstrap/yuanrong_bootstrap.py b/transfer_queue/storage/bootstrap/yuanrong_storage_bootstrap.py similarity index 100% rename from transfer_queue/storage/bootstrap/yuanrong_bootstrap.py rename to transfer_queue/storage/bootstrap/yuanrong_storage_bootstrap.py From 8328083c387e31287c9610758a5462c42aa0de0b Mon Sep 17 00:00:00 2001 From: fy2462 Date: Mon, 18 May 2026 16:05:50 +0800 Subject: [PATCH 4/5] [refact] Fix license data error from CI. Signed-off-by: fy2462 --- transfer_queue/storage/bootstrap/__init__.py | 2 +- .../{mooncake_storage_bootstrap.py => mooncake_bootstrap.py} | 0 .../{yuanrong_storage_bootstrap.py => yuanrong_bootstrap.py} | 0 3 files changed, 1 insertion(+), 1 deletion(-) rename transfer_queue/storage/bootstrap/{mooncake_storage_bootstrap.py => mooncake_bootstrap.py} (100%) rename transfer_queue/storage/bootstrap/{yuanrong_storage_bootstrap.py => yuanrong_bootstrap.py} (100%) diff --git a/transfer_queue/storage/bootstrap/__init__.py b/transfer_queue/storage/bootstrap/__init__.py index c76840dc..e8ce25f1 100644 --- a/transfer_queue/storage/bootstrap/__init__.py +++ b/transfer_queue/storage/bootstrap/__init__.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import mooncake_storage_bootstrap, simple_storage_bootstrap, yuanrong_storage_bootstrap # noqa: F401, I001 +from . import mooncake_bootstrap, simple_storage_bootstrap, yuanrong_bootstrap # noqa: F401, I001 from .provider import StorageBootstrapProvider __all__ = [ diff --git a/transfer_queue/storage/bootstrap/mooncake_storage_bootstrap.py b/transfer_queue/storage/bootstrap/mooncake_bootstrap.py similarity index 100% rename from transfer_queue/storage/bootstrap/mooncake_storage_bootstrap.py rename to transfer_queue/storage/bootstrap/mooncake_bootstrap.py diff --git a/transfer_queue/storage/bootstrap/yuanrong_storage_bootstrap.py b/transfer_queue/storage/bootstrap/yuanrong_bootstrap.py similarity index 100% rename from transfer_queue/storage/bootstrap/yuanrong_storage_bootstrap.py rename to transfer_queue/storage/bootstrap/yuanrong_bootstrap.py From 6429907ec224932c102e3d3f2a940beca857a729 Mon Sep 17 00:00:00 2001 From: fy2462 Date: Mon, 18 May 2026 16:17:07 +0800 Subject: [PATCH 5/5] [refact] Fix some CI error. Signed-off-by: fy2462 --- transfer_queue/storage/bootstrap/mooncake_bootstrap.py | 10 +++++----- transfer_queue/storage/bootstrap/provider.py | 2 ++ .../storage/bootstrap/simple_storage_bootstrap.py | 4 ++-- transfer_queue/storage/bootstrap/yuanrong_bootstrap.py | 8 ++++---- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/transfer_queue/storage/bootstrap/mooncake_bootstrap.py b/transfer_queue/storage/bootstrap/mooncake_bootstrap.py index 51dc27e1..536599ca 100644 --- a/transfer_queue/storage/bootstrap/mooncake_bootstrap.py +++ b/transfer_queue/storage/bootstrap/mooncake_bootstrap.py @@ -27,15 +27,15 @@ @StorageBootstrapProvider.register_provider("MooncakeStore") -def initialize_mooncake_backend(conf: DictConfig) -> subprocess.Popen | None: +def initialize_mooncake_storage(conf: DictConfig) -> subprocess.Popen | None: """ - Initialize MooncakeStore backend. + Initialize Mooncake store backend. Args: - conf (DictConfig): Configuration dictionary for the MooncakeStore backend. + conf (DictConfig): Configuration dictionary for the Mooncake store backend. Returns: - subprocess.Popen | None: Process object for the MooncakeStore backend process. + subprocess.Popen | None: Process object for the Mooncake store backend process. Raises: - ValueError: If the backend is not initialized successfully. + ValueError: If the Mooncake store is not initialized successfully. """ if not conf.backend.MooncakeStore.auto_init: return None diff --git a/transfer_queue/storage/bootstrap/provider.py b/transfer_queue/storage/bootstrap/provider.py index 74505ebb..504900f8 100644 --- a/transfer_queue/storage/bootstrap/provider.py +++ b/transfer_queue/storage/bootstrap/provider.py @@ -18,6 +18,8 @@ class StorageBootstrapProvider: + """Registry for storage backend bootstrap functions.""" + _providers: dict[str, Callable] = {} @classmethod diff --git a/transfer_queue/storage/bootstrap/simple_storage_bootstrap.py b/transfer_queue/storage/bootstrap/simple_storage_bootstrap.py index b2350fa8..1ab2f6b6 100644 --- a/transfer_queue/storage/bootstrap/simple_storage_bootstrap.py +++ b/transfer_queue/storage/bootstrap/simple_storage_bootstrap.py @@ -28,8 +28,8 @@ @StorageBootstrapProvider.register_provider("SimpleStorage") -def initialize_simple_backend(conf: DictConfig) -> dict[str, Any]: - """Initialize Simple backend with metastore mode.""" +def initialize_simple_storage(conf: DictConfig) -> dict[str, Any]: + """Initialize Simple storage with metastore mode.""" simple_storage_handles = {} num_data_storage_units = conf.backend.SimpleStorage.num_data_storage_units diff --git a/transfer_queue/storage/bootstrap/yuanrong_bootstrap.py b/transfer_queue/storage/bootstrap/yuanrong_bootstrap.py index 30b28ba6..e114f939 100644 --- a/transfer_queue/storage/bootstrap/yuanrong_bootstrap.py +++ b/transfer_queue/storage/bootstrap/yuanrong_bootstrap.py @@ -261,14 +261,14 @@ def stop(self) -> None: @StorageBootstrapProvider.register_provider("Yuanrong") -def initialize_yuanrong_backend(conf: DictConfig) -> dict[str, Any] | None: - """Initialize Yuanrong backend with metastore mode. +def initialize_yuanrong_storage(conf: DictConfig) -> dict[str, Any] | None: + """Initialize Yuanrong storage with metastore mode. - This function sets up the Yuanrong datasystem workers across all Ray nodes + This function sets up the Yuanrong storage datasystem workers across all Ray nodes using placement groups and actors. Args: - conf: Configuration containing Yuanrong settings + conf: Configuration containing Yuanrong storage settings Returns: Dict containing worker_actors, metastore_address, and placement_group