From 04fd717d0b2b4cb468431022f8fa9dfbdec4234f Mon Sep 17 00:00:00 2001 From: Haichuan Hu Date: Tue, 7 Apr 2026 10:05:51 +0800 Subject: [PATCH] [feat] Support auto init YR backend based on metastore Signed-off-by: Haichuan Hu --- .github/workflows/run-tests.yml | 6 +- .../openyuanrong_datasystem.md | 136 ++-- scripts/performance_test/README_PERFTEST.md | 5 +- scripts/performance_test/perftest_config.yaml | 14 +- tests/e2e/test_e2e_lifecycle_consistency.py | 7 +- tests/e2e/test_kv_interface_e2e.py | 15 +- tests/test_kv_storage_manager.py | 3 +- tests/test_storage_client_factory.py | 2 +- tests/test_yuanrong_client_zero_copy.py | 2 +- tests/test_yuanrong_storage_client_e2e.py | 2 +- transfer_queue/config.yaml | 19 +- transfer_queue/interface.py | 179 +----- .../storage/clients/yuanrong_client.py | 114 +--- .../storage/managers/yuanrong_manager.py | 6 +- transfer_queue/utils/yuanrong_utils.py | 588 ++++++++++++++++++ 15 files changed, 723 insertions(+), 375 deletions(-) create mode 100644 transfer_queue/utils/yuanrong_utils.py diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index 01f3a9eb..2c4429ae 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -42,4 +42,8 @@ jobs: TQ_TEST_BACKEND=MooncakeStore pytest tests/e2e/test_e2e_lifecycle_consistency.py pkill -f "[m]ooncake_master" || true TQ_TEST_BACKEND=MooncakeStore pytest tests/e2e/test_kv_interface_e2e.py - pkill -f "[m]ooncake_master" || true \ No newline at end of file + pkill -f "[m]ooncake_master" || true + - name: Run Yuanrong Backend Specific E2E Tests + run: | + TQ_TEST_BACKEND=Yuanrong pytest tests/e2e/test_e2e_lifecycle_consistency.py + TQ_TEST_BACKEND=Yuanrong pytest tests/e2e/test_kv_interface_e2e.py \ No newline at end of file diff --git a/docs/storage_backends/openyuanrong_datasystem.md b/docs/storage_backends/openyuanrong_datasystem.md index e25c2f97..26244cbd 100644 --- a/docs/storage_backends/openyuanrong_datasystem.md +++ b/docs/storage_backends/openyuanrong_datasystem.md @@ -1,7 +1,7 @@ # OpenYuanrong-Datasystem Integration for TransferQueue -> Last updated: 01/26/2026 +> Last updated: 04/17/2026 ## 🎉 Overview @@ -58,27 +58,7 @@ pip install openyuanrong-datasystem dscli -h ``` -#### 3. Install etcd - -OpenYuanrong-datasystem relies on etcd for cluster coordination. -Download and install etcd from the official releases: [ETCD GitHub Releases](https://github.com/etcd-io/etcd/releases) - -```bash -# Example for Linux ARM64 (adjust for your architecture) -# Unpack and install etcd -ETCD_VERSION = "v3.6.5" # Replace with the desired version -tar -xvf etcd-${ETCD_VERSION}-linux-arm64.tar.gz -cd etcd-${ETCD_VERSION}-linux-arm64 - -# Copy the executable file to the system path -sudo cp etcd etcdctl /usr/local/bin/ - -# Verify installation -etcd --version -etcdctl version -``` - -#### 4. (Optional) Install CANN and torch-npu +#### 3. (Optional) Install CANN and torch-npu If you have NPU devices and want to accelerate the transmission of NPU tensor, you can install **Ascend-cann-toolkit** and **torch-npu**. @@ -106,19 +86,36 @@ pip install torch-npu==2.8.0 Next, we will provide deployment and code examples for single-node scenarios. For multi-node scenarios, please refer to [Appendix B](#B-deploy-multi-node-datasystem-for-multi-node-training-and-inference-scenarios). -Unlike using TransferQueue with its default backend, integrating OpenYuanrong-Datasystem requires **pre-launching** the datasystem services before running your Python application. - #### Deployment -```bash -# Deploy etcd (for example, port 2379) -etcd --listen-client-urls http://0.0.0.0:2379 \ - --advertise-client-urls http://localhost:2379 & -# Deploy datasystem -dscli start -w --worker_address "127.0.0.1:31501" --etcd_address "127.0.0.1:2379" +TransferQueue will automatically initialize the Yuanrong backend when `auto_init: True` is set. TransferQueue will: +- Create placement groups to ensure workers are spread across Ray nodes +- Launch YuanrongWorkerActor on each node to start datasystem workers +- Set up metastore service on the head node + +Configuration example: +```yaml +backend: + storage_backend: Yuanrong + Yuanrong: + auto_init: True + worker_port: 31501 + metastore_port: 2379 + enable_yr_npu_transport: true + worker_args: "--shared_memory_size_mb 8192 --remote_h2d_device_ids 0 --enable_huge_tlb true" ``` -Once the datasystem is up, you can run your TransferQueue + Datasystem application. +Configuration options: +- `auto_init`: Whether to automatically initialize Yuanrong backend. Currently only `True` is supported. +- `worker_port`: Port for Yuanrong datasystem worker on each node. +- `metastore_port`: Port for metastore service on the head node. +- `enable_yr_npu_transport`: Enable NPU transport for high-performance device-to-device data transfer. +- `worker_args`: Additional arguments passed to `dscli start` command: + - `--shared_memory_size_mb`: Shared memory size in MB for datasystem worker. + - `--remote_h2d_device_ids`: Enable RH2D (Remote Host-to-Device) for efficient cross-node data transfer. Specify NPU device IDs as comma-separated values (e.g., `0,1,2,3`). + - `--enable_huge_tlb`: Enable huge page memory, required for >21GB shared memory on Ascend 910B. + +Once the configuration is set, you can run your TransferQueue + Datasystem application directly. #### Demo You can associate `TransferQueueClient` with `YuanrongStorageManager` through the configuration dictionary when initializing the TransferQueue. @@ -137,7 +134,7 @@ from transfer_queue import ( config_str = """ manager_type: YuanrongStorageManager client_name: YuanrongStorageClient - port: 31501 + worker_port: 31501 """ dict_conf = OmegaConf.create(config_str, flags={"allow_objects": True}) ``` @@ -198,13 +195,7 @@ print("output: ", output) #### Shut down datasystem: -```bash -# shutdown datasystem on the node -dscli stop --worker_address "127.0.0.1:31501" - -# shutdown etcd -pkill -f etcd || true -``` +TransferQueue automatically handles cleanup when calling `tq.close()`, which stops all Yuanrong datasystem workers gracefully. ### Datasystem Logs @@ -251,57 +242,46 @@ If you need to uninstall, execute: ``` ### B: Deploy multi-node datasystem for multi-node training and inference scenarios -We can use etcd to connect to a datasystem backend across multiple nodes. -Let's take two nodes (for instance, 10.170.27.24 and 10.170.27.33) as an example. -#### Start etcd on head node - -```bash -# For example, using the port 2379 of head node -etcd \ - --name etcd-single \ - --data-dir /tmp/etcd-data \ - --listen-client-urls http://10.170.27.24:2379 \ - --advertise-client-urls http://10.170.27.24:2379 \ - --listen-peer-urls http://10.170.27.24:2380 \ - --initial-advertise-peer-urls http://10.170.27.24:2380 \ - --initial-cluster etcd-single=http://10.170.27.24:2380 & -``` +TransferQueue automatically initializes Yuanrong datasystem workers across all Ray cluster nodes. Just set `auto_init: True` in the configuration and TransferQueue will handle the multi-node deployment. +Let's take two nodes (for instance, 192.168.0.1 and 192.168.0.2) as an example. -#### Deploy multi-nodes datasystem -On each node, you need to connect to the etcd service on the head node using your local node's IP address. +#### Deploy ray ```bash -#on head node -dscli start -w --worker_address "10.170.27.24:31501" --etcd_address "10.170.27.24:2379" +# on head node +ray start --head --resources='{"node:192.168.0.1": 1}' + +# on worker node (assume ray port of head_node is 6379) +ray start --address="192.168.0.1:6379" --resources='{"node:192.168.0.2": 1}' ``` -```bash -#on work node -dscli start -w --worker_address "10.170.27.33:31501" --etcd_address "10.170.27.24:2379" +#### Configuration + +TransferQueue will detect all Ray nodes and deploy datasystem workers automatically: +```yaml +backend: + storage_backend: Yuanrong + Yuanrong: + auto_init: True + worker_port: 31501 + metastore_port: 2379 + enable_yr_npu_transport: true + worker_args: "--shared_memory_size_mb 65536 --remote_h2d_device_ids 0 --enable_huge_tlb true" ``` -Now you can use datasystem on head-node and work-node. > For more detailed deployment instructions, please refer to [yuanrong documents](https://gitcode.com/openeuler/yuanrong-datasystem/blob/master/README.md#%E9%83%A8%E7%BD%B2-openyuanrong-datasystem). > The configuration parameters for deploying the data system can refer [dscli config](https://gitcode.com/openeuler/yuanrong-datasystem/blob/master/docs/source_zh_cn/deployment/dscli.md#%E9%85%8D%E7%BD%AE%E9%A1%B9%E8%AF%B4%E6%98%8E). There is a demo with multi-node scenarios as fellow. -#### Deploy ray -```bash -# on head node -ray start --head --resources='{"node:10.170.27.24": 1}' - -# on worker node (assume ray port of head_node is 6379) -ray start --address="10.170.27.24:6379" --resources='{"node:10.170.27.33": 1}' -``` - #### Run demo -In the demo below, we use ray actors to implement distributed deployment of processes. +In the demo below, we use ray actors to implement distributed deployment of processes. The actor writer writes data to the head node, and the actor reader reads data from the worker nodes. ```python from omegaconf import OmegaConf from tensordict import TensorDict +import transfer_queue as tq from transfer_queue import ( TransferQueueClient, TransferQueueController, @@ -312,10 +292,10 @@ import ray ######################################################################## # Please set up Ray cluster before running this script -# e.g. ray start --head --resources='{"node:127.0.0.1": 1}' +# e.g. ray start --head --resources='{"node:192.168.0.1": 1}' ######################################################################## -HEAD_NODE_IP = "10.170.27.24" # Replace with your head node IP -WORKER_NODE_IP = "10.170.27.33" # Replace with your worker node IP +HEAD_NODE_IP = "192.168.0.1" # Replace with your head node IP +WORKER_NODE_IP = "192.168.0.2" # Replace with your worker node IP def initialize_controller(): @@ -357,10 +337,12 @@ class TransferQueueClientActor: def main(): + tq.init() + config_str = """ manager_type: YuanrongStorageManager client_name: YuanrongStorageClient - port: 31501 + worker_port: 31501 """ dict_conf = OmegaConf.create(config_str, flags={"allow_objects": True}) # It is important to pay attention to the controller's lifecycle. @@ -387,6 +369,8 @@ def main(): ) output = ray.get(output) + tq.close() + if __name__ == "__main__": main() diff --git a/scripts/performance_test/README_PERFTEST.md b/scripts/performance_test/README_PERFTEST.md index 8e150986..16b4b6bf 100644 --- a/scripts/performance_test/README_PERFTEST.md +++ b/scripts/performance_test/README_PERFTEST.md @@ -64,8 +64,11 @@ backend: backend: storage_backend: Yuanrong Yuanrong: - port: 31501 + auto_init: True + worker_port: 31501 + metastore_port: 2379 enable_yr_npu_transport: true + worker_args: "--shared_memory_size_mb 65536 --remote_h2d_device_ids 0 --enable_huge_tlb true" ``` For Yuanrong backend, writer runs on the head node and reader runs on the worker node. `--worker_node_ip` is required. diff --git a/scripts/performance_test/perftest_config.yaml b/scripts/performance_test/perftest_config.yaml index 39538bc1..1533d7dc 100644 --- a/scripts/performance_test/perftest_config.yaml +++ b/scripts/performance_test/perftest_config.yaml @@ -48,7 +48,17 @@ backend: # For Yuanrong: Yuanrong: - # Port of local yuanrong datasystem worker - port: 31501 + # Whether to let TQ automatically init yuanrong + auto_init: True + # Datasystem worker port + worker_port: 31501 + # Metastore service port + metastore_port: 2379 # If enable npu transport enable_yr_npu_transport: true + # Additional config for yuanrong worker. + # Recommended options for NPU environments: + # --remote_h2d_device_ids Enable RH2D for efficient cross-node data transfer. Specify NPU device IDs (comma-separated). + # --enable_huge_tlb Enable huge page memory to improve performance. Required for >21GB shared memory on 910B. + # Example: "--shared_memory_size_mb 16384 --remote_h2d_device_ids 0,1,2,3 --enable_huge_tlb true" + worker_args: "--shared_memory_size_mb 65536 --remote_h2d_device_ids 0 --enable_huge_tlb true" diff --git a/tests/e2e/test_e2e_lifecycle_consistency.py b/tests/e2e/test_e2e_lifecycle_consistency.py index 218ecba8..957c1e5c 100644 --- a/tests/e2e/test_e2e_lifecycle_consistency.py +++ b/tests/e2e/test_e2e_lifecycle_consistency.py @@ -80,8 +80,8 @@ "backend": { "storage_backend": "Yuanrong", "Yuanrong": { - "host": "127.0.0.1", - "port": 31501, + "worker_port": 31501, + "metastore_port": 2379, }, }, }, @@ -102,11 +102,12 @@ def backend_name(): """Get the backend name from environment variable. Environment variables: - TQ_TEST_BACKEND: Backend name (SimpleStorage or MooncakeStore) + TQ_TEST_BACKEND: Backend name (SimpleStorage, MooncakeStore, or Yuanrong) To run tests for a specific backend: TQ_TEST_BACKEND=SimpleStorage pytest tests/e2e/test_e2e_lifecycle_consistency.py TQ_TEST_BACKEND=MooncakeStore pytest tests/e2e/test_e2e_lifecycle_consistency.py + TQ_TEST_BACKEND=Yuanrong pytest tests/e2e/test_e2e_lifecycle_consistency.py """ return os.environ.get("TQ_TEST_BACKEND", "SimpleStorage") diff --git a/tests/e2e/test_kv_interface_e2e.py b/tests/e2e/test_kv_interface_e2e.py index fb12edbb..8b171717 100644 --- a/tests/e2e/test_kv_interface_e2e.py +++ b/tests/e2e/test_kv_interface_e2e.py @@ -94,6 +94,18 @@ def tq_api(request): }, }, }, + "Yuanrong": { + "controller": { + "polling_mode": True, + }, + "backend": { + "storage_backend": "Yuanrong", + "Yuanrong": { + "worker_port": 31501, + "metastore_port": 2379, + }, + }, + }, } @@ -112,11 +124,12 @@ def backend_name(): """Get the backend name from environment variable. Environment variables: - TQ_TEST_BACKEND: Backend name (SimpleStorage or MooncakeStore) + TQ_TEST_BACKEND: Backend name (SimpleStorage, MooncakeStore, or Yuanrong) To run tests for a specific backend: TQ_TEST_BACKEND=SimpleStorage pytest tests/e2e/test_kv_interface_e2e.py TQ_TEST_BACKEND=MooncakeStore pytest tests/e2e/test_kv_interface_e2e.py + TQ_TEST_BACKEND=Yuanrong pytest tests/e2e/test_kv_interface_e2e.py """ return os.environ.get("TQ_TEST_BACKEND", "SimpleStorage") diff --git a/tests/test_kv_storage_manager.py b/tests/test_kv_storage_manager.py index bfe9a3bf..e3cc5915 100644 --- a/tests/test_kv_storage_manager.py +++ b/tests/test_kv_storage_manager.py @@ -57,8 +57,7 @@ def test_data(): cfg = { "controller_info": MagicMock(), "client_name": "YuanrongStorageClient", - "host": "127.0.0.1", - "port": 31501, + "worker_port": 31501, "device_id": 0, } global_indexes = [8, 9, 10] diff --git a/tests/test_storage_client_factory.py b/tests/test_storage_client_factory.py index 537cd1f1..012c0cc5 100644 --- a/tests/test_storage_client_factory.py +++ b/tests/test_storage_client_factory.py @@ -25,7 +25,7 @@ class Test(unittest.TestCase): def setUp(self): - self.cfg = {"host": "127.0.0.1", "port": 31501, "device_id": 0} + self.cfg = {"worker_port": 31501, "device_id": 0} @pytest.mark.skipif(find_spec("datasystem") is None, reason="datasystem is not available") def test_create_client(self): diff --git a/tests/test_yuanrong_client_zero_copy.py b/tests/test_yuanrong_client_zero_copy.py index f3da2716..c100d39e 100644 --- a/tests/test_yuanrong_client_zero_copy.py +++ b/tests/test_yuanrong_client_zero_copy.py @@ -46,7 +46,7 @@ def mock_kv_client(self, mocker): @pytest.fixture def storage_client(self, mock_kv_client): - return GeneralKVClientAdapter({"host": "127.0.0.1", "port": 31501}) + return GeneralKVClientAdapter({"worker_port": 31501}) def test_mset_mget_p2p(self, storage_client, mocker): # Mock serialization/deserialization diff --git a/tests/test_yuanrong_storage_client_e2e.py b/tests/test_yuanrong_storage_client_e2e.py index 3cb1f997..17b2f885 100644 --- a/tests/test_yuanrong_storage_client_e2e.py +++ b/tests/test_yuanrong_storage_client_e2e.py @@ -125,7 +125,7 @@ def mock_find_reachable_host(port, timeout=1.0): @pytest.fixture def config(): - return {"host": "127.0.0.1", "port": 12345, "enable_yr_npu_optimization": True} + return {"worker_port": 12345, "enable_yr_npu_optimization": True} def assert_tensors_equal(a: torch.Tensor, b: torch.Tensor): diff --git a/transfer_queue/config.yaml b/transfer_queue/config.yaml index d5a3bac6..d711dba8 100644 --- a/transfer_queue/config.yaml +++ b/transfer_queue/config.yaml @@ -48,10 +48,17 @@ backend: # For Yuanrong: Yuanrong: - # Whether to let TQ automatically start etcd and datasystem services + # Whether to let TQ automatically init yuanrong auto_init: True - # etcd service address (used to start etcd when auto_init=true) - etcd_address: "127.0.0.1:2379" - # datasystem worker host and port (used to start dscli when auto_init=true) - host: "127.0.0.1" - port: 31501 + # Datasystem worker port + worker_port: 31501 + # Metastore service port + metastore_port: 2379 + # If enable npu transport + enable_yr_npu_transport: false + # Additional config for yuanrong worker. + # Recommended options for NPU environments: + # --remote_h2d_device_ids Enable RH2D for efficient cross-node data transfer. Specify NPU device IDs (comma-separated). + # --enable_huge_tlb Enable huge page memory to improve performance. Required for >21GB shared memory on 910B. + # Example: "--shared_memory_size_mb 16384 --remote_h2d_device_ids 0,1,2,3 --enable_huge_tlb true" + worker_args: "--shared_memory_size_mb 8192" diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 8b37ca33..21f1e8ae 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -16,10 +16,7 @@ import logging import math import os -import shutil -import socket import subprocess -import tempfile import time from importlib import resources from typing import Any, Optional @@ -38,6 +35,10 @@ from transfer_queue.sampler import BaseSampler from transfer_queue.storage.simple_backend import SimpleStorageUnit from transfer_queue.utils.common import get_placement_group +from transfer_queue.utils.yuanrong_utils import ( + cleanup_yuanrong_resources, + initialize_yuanrong_backend, +) from transfer_queue.utils.zmq_utils import process_zmq_server_info logger = logging.getLogger(__name__) @@ -187,129 +188,8 @@ def _maybe_create_transferqueue_storage(conf: DictConfig) -> DictConfig: f"Output:\n{error_msg}" ) _TRANSFER_QUEUE_STORAGE["MooncakeStore"] = process - if conf.backend.storage_backend == "Yuanrong": - if conf.backend.Yuanrong.auto_init: - etcd_process = None - etcd_data_dir = None - worker_address = None - if not shutil.which("etcd"): - raise RuntimeError( - "etcd executable not found in PATH. Please install etcd and make sure it's in the PATH." - ) - if not shutil.which("dscli"): - raise RuntimeError( - "dscli executable not found in PATH. Please run `pip install openyuanrong-datasystem`." - ) - try: - # ========== Start etcd ========== - etcd_address = "127.0.0.1:2379" - try: - etcd_address = conf.backend.Yuanrong.etcd_address - except Exception: - pass - - # Assume host:port format - parts = etcd_address.split(":") - if len(parts) != 2: - raise ValueError(f"Invalid etcd_address format: {etcd_address}. Expected host:port") - host = parts[0] - port = int(parts[1]) - - # Create temporary data directory - etcd_data_dir = tempfile.mkdtemp(prefix="tq_etcd_") - logger.info(f"Starting etcd with data directory: {etcd_data_dir}") - - cmd = [ - "etcd", - f"--data-dir={etcd_data_dir}", - f"--listen-client-urls=http://{host}:{port}", - f"--advertise-client-urls=http://{host}:{port}", - ] - - etcd_process = subprocess.Popen( - cmd, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - text=True, - bufsize=1, - universal_newlines=True, - start_new_session=True, - ) - time.sleep(3) # Wait for etcd to start - - if etcd_process.poll() is None: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(2) - result = sock.connect_ex((host, port)) - sock.close() - if result != 0: - raise RuntimeError(f"etcd process started but not listening on {host}:{port}") - else: - raise RuntimeError(f"etcd exited immediately with return code {etcd_process.returncode}") - - logger.info(f"etcd started, PID: {etcd_process.pid}") - time.sleep(2) - - # ========== Start datasystem worker ========== - # Assume host:port format - worker_host = conf.backend.Yuanrong.host - worker_port = conf.backend.Yuanrong.port - worker_address = worker_host + ":" + str(worker_port) - - cmd = [ - "dscli", - "start", - "-w", - "--worker_address", - worker_address, - "--etcd_address", - etcd_address, - ] - - try: - ds_result = subprocess.run( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - text=True, - timeout=90, - ) - except subprocess.TimeoutExpired as err: - raise RuntimeError(f"dscli start timed out: {err}") from err - # Wait for dscli to start and exit (it starts worker and exits) - if ds_result.returncode == 0 and "[ OK ]" in ds_result.stdout: - logger.info(f"dscli started Yuanrong datasystem worker at {worker_address} successfully.") - - else: - raise RuntimeError( - f"Failed to start datasystem worker at {worker_address}. " - f"Return code: {ds_result.returncode}, Output: {ds_result.stdout}" - ) - - # Store processes and data directory - _TRANSFER_QUEUE_STORAGE["Yuanrong"] = { - "etcd": etcd_process, - "etcd_data_dir": etcd_data_dir, - "worker_address": worker_address, - "etcd_address": etcd_address, - } - logger.info("Yuanrong backend (etcd + datasystem) started successfully.") - - except Exception as e: - # Clean up on failure - if etcd_process is not None and etcd_process.poll() is None: - etcd_process.terminate() - try: - etcd_process.wait(timeout=5) - except subprocess.TimeoutExpired: - etcd_process.kill() - etcd_process.wait() - if etcd_data_dir is not None: - try: - shutil.rmtree(etcd_data_dir, ignore_errors=True) - except Exception: - pass - raise RuntimeError(f"Failed to start Yuanrong backend: {e}") from e + if conf.backend.storage_backend == "Yuanrong" and conf.backend.Yuanrong.auto_init: + _TRANSFER_QUEUE_STORAGE["Yuanrong"] = initialize_yuanrong_backend(conf) return conf @@ -335,6 +215,7 @@ def _init_from_existing() -> bool: conf = ray.get(_TRANSFER_QUEUE_CONTROLLER.get_config.remote()) if conf is not None: _maybe_create_transferqueue_client(conf) + logger.info("TransferQueueClient initialized.") return True @@ -475,51 +356,7 @@ def close(): except Exception: pass elif key == "Yuanrong": - # Stop etcd process and clean up data directory, stop datasystem worker via dscli - if isinstance(value, dict): - etcd_process = value.get("etcd") - etcd_data_dir = value.get("etcd_data_dir") - worker_address = value.get("worker_address") - - # Stop etcd if running - if etcd_process is not None and etcd_process.poll() is None: - etcd_process.terminate() - try: - etcd_process.wait(timeout=5) - except subprocess.TimeoutExpired: - etcd_process.kill() - etcd_process.wait() - - # Clean up etcd data directory - if etcd_data_dir is not None and os.path.exists(etcd_data_dir): - try: - shutil.rmtree(etcd_data_dir, ignore_errors=True) - logger.info(f"Cleaned up etcd data directory: {etcd_data_dir}") - except Exception as e: - logger.warning(f"Failed to clean up etcd data directory {etcd_data_dir}: {e}") - - # Stop datasystem worker via dscli command - if worker_address: - try: - result = subprocess.run( - ["dscli", "stop", "--worker_address", worker_address], - timeout=90, - capture_output=True, - ) - if result.returncode == 0: - logger.info(f"Stopped datasystem worker at {worker_address} via dscli stop") - else: - error_msg = (result.stderr or result.stdout or b"").decode() - logger.warning( - f"Failed to stop datasystem worker at {worker_address}. " - f"Return code: {result.returncode}, Error: {error_msg}" - ) - except subprocess.TimeoutExpired as err: - logger.warning(f"dscli stop timed out for {worker_address}: {err}") - except Exception as e: - logger.warning(f"Failed to stop datasystem worker via dscli: {e}") - else: - logger.warning(f"Unexpected Yuanrong storage value: {value}") + cleanup_yuanrong_resources(value) else: logger.warning(f"close for _TRANSFER_QUEUE_STORAGE with key {key} is not supported for now.") diff --git a/transfer_queue/storage/clients/yuanrong_client.py b/transfer_queue/storage/clients/yuanrong_client.py index bdf48d62..fccd9a9b 100644 --- a/transfer_queue/storage/clients/yuanrong_client.py +++ b/transfer_queue/storage/clients/yuanrong_client.py @@ -15,7 +15,6 @@ import logging import os -import socket import struct from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor @@ -27,109 +26,12 @@ from transfer_queue.storage.clients.base import TransferQueueStorageKVClient from transfer_queue.storage.clients.factory import StorageClientFactory from transfer_queue.utils.serial_utils import _decoder, _encoder +from transfer_queue.utils.yuanrong_utils import find_reachable_host logger = logging.getLogger(__name__) logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) -def get_local_ip_addresses() -> list[str]: - """Get all local IP addresses including 127.0.0.1. - - Returns: - List of local IP addresses, with 127.0.0.1 first. - """ - ips = ["127.0.0.1"] - - try: - hostname = socket.gethostname() - # Add hostname resolution - try: - host_ip = socket.gethostbyname(hostname) - if host_ip not in ips: - ips.append(host_ip) - except socket.gaierror: - pass - - # Get all network interfaces - import netifaces - - for interface in netifaces.interfaces(): - try: - addrs = netifaces.ifaddresses(interface) - if netifaces.AF_INET in addrs: - for addr_info in addrs[netifaces.AF_INET]: - ip = addr_info.get("addr") - if ip and ip not in ips: - ips.append(ip) - except (ValueError, KeyError): - continue - except ImportError: - # Fallback if netifaces is not available - try: - # Try to get IP by connecting to an external address - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - try: - # Doesn't need to be reachable - s.connect(("8.8.8.8", 80)) - ip = s.getsockname()[0] - if ip not in ips: - ips.append(ip) - except Exception: - pass - finally: - s.close() - except Exception: - pass - - return ips - - -def check_port_connectivity(host: str, port: int, timeout: float = 2.0) -> bool: - """Check if a TCP port is reachable on the given host. - - Args: - host: Host IP address to check - port: Port number to check - timeout: Connection timeout in seconds - - Returns: - True if the port is reachable, False otherwise - """ - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(timeout) - result = sock.connect_ex((host, port)) - sock.close() - return result == 0 - except Exception: - return False - - -def find_reachable_host(port: int, timeout: float = 1.0) -> Optional[str]: - """Find a reachable local host IP address for the given port. - - Tries all local IP addresses in order and returns the first one - that has the given port open. - - Args: - port: Port number to check - timeout: Connection timeout in seconds per check - - Returns: - The first reachable host IP address, or None if none found. - """ - local_ips = get_local_ip_addresses() - logger.info(f"Checking port {port} on local IPs: {local_ips}") - - for ip in local_ips: - if check_port_connectivity(ip, port, timeout): - logger.info(f"Found reachable host: {ip}:{port}") - return ip - - logger.warning(f"No reachable host found for port {port}") - return None - - YUANRONG_DATASYSTEM_IMPORTED: bool = True try: @@ -183,10 +85,10 @@ class NPUTensorKVClientAdapter(StorageStrategy): KEYS_LIMIT: int = 10_000 def __init__(self, config: dict): - port = config.get("port") + port = config.get("worker_port") if port is None or not isinstance(port, int): - raise ValueError("Missing or invalid 'port' in config") + raise ValueError("Missing or invalid 'worker_port' in config") logger.info(f"Auto-detecting reachable host for Yuanrong port {port}...") host = find_reachable_host(port) @@ -274,7 +176,7 @@ def clear(self, keys: list[str]): # Todo(dpj): Test call clear when no (key,value) put in ds self._ds_client.delete(batch) - def _create_empty_npu_tensorlist(self, shapes, dtypes): + def _create_empty_npu_tensorlist(self, shapes: list, dtypes: list): """ Create a list of empty NPU tensors with given shapes and dtypes. @@ -310,10 +212,10 @@ class GeneralKVClientAdapter(StorageStrategy): DS_MAX_WORKERS: int = 16 def __init__(self, config: dict): - port = config.get("port") + port = config.get("worker_port") if port is None or not isinstance(port, int): - raise ValueError("Missing or invalid 'port' in config") + raise ValueError("Missing or invalid 'worker_port' in config") logger.info(f"Auto-detecting reachable host for Yuanrong port {port}...") host = find_reachable_host(port) @@ -479,10 +381,10 @@ def __init__(self, config: dict[str, Any]): if not YUANRONG_DATASYSTEM_IMPORTED: raise ImportError("YuanRong DataSystem not installed.") - port = config.get("port") + port = config.get("worker_port") if port is None or not isinstance(port, int): - raise ValueError("Missing or invalid 'port' in config") + raise ValueError("Missing or invalid 'worker_port' in config") super().__init__(config) diff --git a/transfer_queue/storage/managers/yuanrong_manager.py b/transfer_queue/storage/managers/yuanrong_manager.py index d5270401..69050acf 100644 --- a/transfer_queue/storage/managers/yuanrong_manager.py +++ b/transfer_queue/storage/managers/yuanrong_manager.py @@ -36,11 +36,11 @@ class YuanrongStorageManager(KVStorageManager): """Storage manager for Yuanrong backend.""" def __init__(self, controller_info: ZMQServerInfo, config: dict[str, Any]): - port = config.get("port", None) + worker_port = config.get("worker_port", None) client_name = config.get("client_name", None) - if port is None or not isinstance(port, int): - raise ValueError("Missing or invalid 'port' in config") + if worker_port is None or not isinstance(worker_port, int): + raise ValueError("Missing or invalid 'worker_port' in config") if client_name is None: logger.info("Missing 'client_name' in config, using default value('YuanrongStorageClient')") diff --git a/transfer_queue/utils/yuanrong_utils.py b/transfer_queue/utils/yuanrong_utils.py new file mode 100644 index 00000000..fd3afbd6 --- /dev/null +++ b/transfer_queue/utils/yuanrong_utils.py @@ -0,0 +1,588 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +import os +import shutil +import socket +import subprocess +from typing import Any, Optional + +import ray +from omegaconf import DictConfig + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) + + +def get_local_ip_addresses() -> list[str]: + """Get all local IP addresses including 127.0.0.1. + + Returns: + List of local IP addresses, with 127.0.0.1 first. + """ + ips = ["127.0.0.1"] + + try: + hostname = socket.gethostname() + # Add hostname resolution + try: + host_ip = socket.gethostbyname(hostname) + if host_ip not in ips: + ips.append(host_ip) + except socket.gaierror: + pass + + # Get all network interfaces + import netifaces + + for interface in netifaces.interfaces(): + try: + addrs = netifaces.ifaddresses(interface) + if netifaces.AF_INET in addrs: + for addr_info in addrs[netifaces.AF_INET]: + ip = addr_info.get("addr") + if ip and ip not in ips: + ips.append(ip) + except (ValueError, KeyError): + continue + except ImportError: + # Fallback if netifaces is not available + try: + # Try to get IP by connecting to an external address + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + # Doesn't need to be reachable + s.connect(("8.8.8.8", 80)) + ip = s.getsockname()[0] + if ip not in ips: + ips.append(ip) + except Exception: + pass + finally: + s.close() + except Exception: + pass + + return ips + + +def check_port_connectivity(host: str, port: int, timeout: float = 2.0) -> bool: + """Check if a TCP port is reachable on given host. + + Args: + host: Host IP address to check + port: Port number to check + timeout: Connection timeout in seconds + + Returns: + True if port is reachable, False otherwise + """ + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(timeout) + result = sock.connect_ex((host, port)) + sock.close() + return result == 0 + except Exception: + return False + + +def find_reachable_host(port: int, timeout: float = 1.0) -> Optional[str]: + """Find a reachable local host IP address for given port. + + Tries all local IP addresses in order and returns the first one + that has the given port open. + + Args: + port: Port number to check + timeout: Connection timeout in seconds per check + + Returns: + The first reachable host IP address, or None if none found. + """ + local_ips = get_local_ip_addresses() + logger.info(f"Checking port {port} on local IPs: {local_ips}") + + for ip in local_ips: + if check_port_connectivity(ip, port, timeout): + logger.info(f"Found reachable host: {ip}:{port}") + return ip + + logger.warning(f"No reachable host found for port {port}") + return None + + +def _parse_remote_h2d_device_ids(worker_args: str) -> Optional[str]: + """Parse --remote_h2d_device_ids parameter from worker_args string. + + Args: + worker_args: Worker arguments string, e.g., "--arg1 value1 --remote_h2d_device_ids 0,1,2,3" + + Returns: + The device IDs string if found and valid, None otherwise. + + Raises: + RuntimeError: If --remote_h2d_device_ids flag is found but has invalid format. + """ + if not worker_args: + return None + + args_list = worker_args.split() + + # Find the index of --remote_h2d_device_ids + try: + idx = args_list.index("--remote_h2d_device_ids") + except ValueError: + return None + + # Check if there's a value after the flag + if idx + 1 >= len(args_list): + raise RuntimeError("--remote_h2d_device_ids flag found but no value provided") + + device_ids = args_list[idx + 1] + + # Validate the format: comma-separated digits + if not device_ids: + raise RuntimeError("Empty device IDs value after --remote_h2d_device_ids") + + # Validate each segment is a digit + parts = device_ids.split(",") + for part in parts: + if not part.isdigit(): + raise RuntimeError( + f"Invalid device ID format: '{device_ids}'. Expected comma-separated digits (e.g., '0,1,2,3')." + ) + + return device_ids + + +def start_datasystem_worker( + worker_address: str, + metastore_address: str, + is_head: bool, + worker_args: str = "", +) -> None: + """Start Yuanrong datasystem worker in metastore mode. + + Args: + worker_address: Worker address in format host:port + metastore_address: Metastore address in format host:port + is_head: Whether this node should start metastore service + worker_args: Additional arguments to append to dscli start command + + Raises: + RuntimeError: If dscli command fails + """ + if not shutil.which("dscli"): + raise RuntimeError("dscli executable not found in PATH. Please run `pip install openyuanrong-datasystem`.") + + cmd = ["dscli", "start", "-w", "--worker_address", worker_address] + cmd.extend(["--metastore_address", metastore_address]) + if is_head: + cmd.extend(["--start_metastore_service", "true"]) + + # Built-in default options + cmd.extend(["--arena_per_tenant", "1", "--enable_worker_worker_batch_get", "true"]) + + # Append worker_args if provided + if worker_args: + cmd.extend(worker_args.split()) + + node_type = "head node" if is_head else "worker node" + logger.info(f"Starting Yuanrong datasystem ({node_type}) at {worker_address}, worker_args={worker_args}") + + # Build environment with ASCEND_RT_VISIBLE_DEVICES if specified + env = None + device_ids = _parse_remote_h2d_device_ids(worker_args) + if device_ids: + env = os.environ.copy() + env["ASCEND_RT_VISIBLE_DEVICES"] = device_ids + logger.info( + f"Setting ASCEND_RT_VISIBLE_DEVICES={device_ids} for dscli subprocess ({node_type} at {worker_address})" + ) + + try: + ds_result = subprocess.run( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + timeout=90, + env=env, + ) + except subprocess.TimeoutExpired as err: + raise RuntimeError(f"dscli start timed out: {err}") from err + + if ds_result.returncode == 0 and "[ OK ]" in ds_result.stdout: + logger.info( + f"dscli started Yuanrong datasystem ({node_type}, metastore mode) at {worker_address} successfully." + ) + else: + raise RuntimeError( + f"Failed to start datasystem ({node_type}, metastore mode) at {worker_address}. " + f"Return code: {ds_result.returncode}, Output: {ds_result.stdout}" + ) + + +def stop_datasystem_worker(worker_address: str) -> None: + """Stop Yuanrong datasystem worker. + + Args: + worker_address: Worker address in format host:port + """ + if worker_address: + try: + result = subprocess.run( + ["dscli", "stop", "--worker_address", worker_address], + timeout=90, + capture_output=True, + ) + if result.returncode == 0: + logger.info(f"Stopped datasystem worker at {worker_address} via dscli stop") + else: + error_msg = (result.stderr or result.stdout or b"").decode() + logger.warning( + f"Failed to stop datasystem worker at {worker_address}. " + f"Return code: {result.returncode}, Error: {error_msg}" + ) + except subprocess.TimeoutExpired as err: + logger.warning(f"dscli stop timed out for {worker_address}: {err}") + except Exception as e: + logger.warning(f"Failed to stop datasystem worker via dscli: {e}") + + +@ray.remote(num_cpus=0.1) +class YuanrongWorkerActor: + """Ray actor to manage Yuanrong datasystem worker on a node. + + This actor runs on each node in the Ray cluster and is responsible for + starting and stopping the Yuanrong datasystem worker process on that node. + + The actor determines its own rank and role (head or worker) by finding the + intersection of local IP addresses with the provided node IPs. + """ + + def __init__(self, node_ips: list[str], worker_port: int, metastore_port: int, worker_args: str = ""): + """Initialize the Yuanrong worker actor. + + Args: + node_ips: List of all node IPs in the Ray cluster + worker_port: Port for the datasystem worker + metastore_port: Port for the metastore service (on head node) + worker_args: Additional arguments to append to dscli start command + + Raises: + RuntimeError: If cannot determine this node's IP from node_ips + """ + local_ips = get_local_ip_addresses() + self.my_ip = None + + # Find the intersection between local IPs and node_ips + for ip in node_ips: + if ip in local_ips: + self.my_ip = ip + break + + if self.my_ip is None: + raise RuntimeError(f"Cannot determine local node IP. Local IPs: {local_ips}, Cluster node IPs: {node_ips}") + + self.node_ips = node_ips + self.worker_port = worker_port + self.metastore_port = metastore_port + self.worker_address = f"{self.my_ip}:{worker_port}" + self.worker_args = worker_args + + # First node in the list is assumed to be the head node. + # This assumption is based on how interface.py constructs node_ips from ray.nodes(). + self.head_node_ip = node_ips[0] + self.metastore_address = f"{self.head_node_ip}:{metastore_port}" + self.is_head = self.my_ip == self.head_node_ip + + logger.info( + f"YuanrongWorkerActor initialized on node {self.my_ip}: " + f"worker_address={self.worker_address}, " + f"metastore_address={self.metastore_address}, is_head={self.is_head}, worker_args={self.worker_args}" + ) + + def start(self) -> str: + """Start the datasystem worker on this node. + + Returns: + The worker address. + + Raises: + RuntimeError: If dscli command fails + """ + logger.info(f"Starting datasystem worker at {self.worker_address}...") + start_datasystem_worker( + self.worker_address, + metastore_address=self.metastore_address, + is_head=self.is_head, + worker_args=self.worker_args, + ) + logger.info(f"Datasystem worker started successfully at {self.worker_address}") + return self.worker_address + + def get_metastore_address(self) -> str: + """Get the metastore address. + + Returns: + The metastore address in format host:port + """ + return self.metastore_address + + def get_node_ip(self) -> str: + """Return the IP address of the node this actor is running on.""" + assert self.my_ip is not None + return self.my_ip + + def stop(self) -> None: + """Stop the datasystem worker on this node.""" + logger.info(f"Stopping datasystem worker at {self.worker_address}...") + stop_datasystem_worker(self.worker_address) + logger.info(f"Datasystem worker stopped successfully at {self.worker_address}") + + +def _kill_actors_and_placement_group(worker_actors: list, placement_group: Any) -> None: + """Kill actors and remove placement group without stopping workers. + + Args: + worker_actors: List of Yuanrong worker actors to kill + placement_group: Placement group to remove + """ + for actor in worker_actors: + try: + ray.kill(actor) + except Exception: + pass + if placement_group: + try: + ray.util.remove_placement_group(placement_group) + except Exception: + pass + + +def cleanup_yuanrong_resources(storage_value: Any) -> None: + """Stop Yuanrong workers and cleanup resources. + + Args: + storage_value: Yuanrong storage dict containing worker_actors and placement_group + """ + if not isinstance(storage_value, dict): + logger.warning(f"Unexpected Yuanrong storage value: {storage_value}") + return + + worker_actors = storage_value.get("worker_actors", []) + placement_group = storage_value.get("placement_group") + + try: + if worker_actors: + logger.info(f"Cleaning up Yuanrong backend (stopping {len(worker_actors)} workers)...") + + # Stop worker nodes (all except head node 0) in parallel first + stop_exceptions = [] + if len(worker_actors) > 1: + logger.info(f"Stopping {len(worker_actors) - 1} worker nodes (excluding head) in parallel...") + stop_refs = [actor.stop.remote() for actor in worker_actors[1:]] + for idx, stop_ref in enumerate(stop_refs, start=1): + try: + ray.get(stop_ref) + except Exception as e: + stop_exceptions.append(e) + logger.warning(f"Failed to stop worker node actor {idx}: {e}") + if len(stop_exceptions) < len(stop_refs): + logger.info("Completed stop requests for non-head worker nodes") + + # Then stop head node (actor 0) which runs metastore service + logger.info("Stopping head node with metastore service...") + try: + ray.get(worker_actors[0].stop.remote()) + logger.info("Head node stopped successfully") + except Exception as e: + stop_exceptions.append(e) + logger.warning(f"Failed to stop head node actor: {e}") + + if stop_exceptions: + logger.warning(f"Encountered {len(stop_exceptions)} errors while stopping workers") + finally: + # Kill actors and remove placement group even if graceful stop fails. + _kill_actors_and_placement_group(worker_actors, placement_group) + if placement_group: + logger.info("Removed Yuanrong placement group") + + +def initialize_yuanrong_backend(conf: DictConfig) -> dict[str, Any]: + """Initialize Yuanrong backend with metastore mode. + + This function sets up the Yuanrong datasystem workers across all Ray nodes + using placement groups and actors. + + Args: + conf: Configuration containing Yuanrong settings + + Returns: + Dict containing worker_actors, metastore_address, and placement_group + + Raises: + RuntimeError: If Ray nodes not found or initialization fails + """ + # Get Ray cluster information + nodes = ray.nodes() + if not nodes: + raise RuntimeError("No Ray nodes found. Is Ray initialized?") + + # Filter to only alive nodes and get their IPs + alive_nodes = [node for node in nodes if node.get("Alive", False)] + if not alive_nodes: + raise RuntimeError("No alive Ray nodes found") + + # Get driver node IP to use as head node + driver_ip = ray.util.get_node_ip_address() + head_node = None + other_nodes = [] + + # Separate head node (driver) from other nodes + for node in alive_nodes: + node_ip = node["NodeManagerAddress"] + if node_ip == driver_ip: + head_node = node + else: + other_nodes.append(node) + + if head_node is None: + raise RuntimeError(f"Driver node {driver_ip} not found in alive nodes") + + # Reorder nodes: head node first, then others + ordered_nodes = [head_node] + other_nodes + + # Extract node IPs in deterministic order + node_ips = [node["NodeManagerAddress"] for node in ordered_nodes] + worker_port = conf.backend.Yuanrong.worker_port + metastore_port = conf.backend.Yuanrong.metastore_port + worker_args = conf.backend.Yuanrong.get("worker_args", "") + + logger.info(f"Found {len(ordered_nodes)} alive Ray nodes: {node_ips}") + + # Create placement group using STRICT_SPREAD to ensure each bundle is on a distinct node + bundles = [{"CPU": 0.1} for _ in ordered_nodes] + + pg = ray.util.placement_group(bundles, strategy="STRICT_SPREAD") + try: + ray.get(pg.ready(), timeout=60) + except ray.exceptions.GetTimeoutError as e: + try: + ray.util.remove_placement_group(pg) + except Exception as cleanup_error: + logger.warning(f"Failed to remove placement group after readiness timeout: {cleanup_error}") + raise RuntimeError( + "Timed out waiting for Yuanrong placement group to become ready. " + f"Requested strategy=STRICT_SPREAD, bundles={bundles}. " + "This may be due to insufficient cluster capacity." + ) from e + except Exception as e: + try: + ray.util.remove_placement_group(pg) + except Exception as cleanup_error: + logger.warning(f"Failed to remove placement group after scheduling failure: {cleanup_error}") + raise RuntimeError( + f"Failed to create Yuanrong placement group. Requested strategy=STRICT_SPREAD, bundles={bundles}." + ) from e + + logger.info(f"Created placement group with {len(bundles)} bundles using STRICT_SPREAD") + + try: + # Create all worker actors using placement group + # Without node resources, actor scheduling order is not guaranteed to match node order + # We'll identify head node actor by checking which node it runs on + worker_actors = [] + for rank in range(len(ordered_nodes)): + actor = YuanrongWorkerActor.options( # type: ignore[attr-defined] + placement_group=pg, + placement_group_bundle_index=rank, + ).remote(node_ips, worker_port, metastore_port, worker_args) + worker_actors.append(actor) + + logger.info(f"Created {len(worker_actors)} YuanrongWorkerActor instances") + + # Find which actor is running on the head node (driver IP) + # The head node actor needs to start first to initialize metastore service + head_actor_index = None + for idx, actor in enumerate(worker_actors): + try: + node_ip = ray.get(actor.get_node_ip.remote()) + if node_ip == driver_ip: + head_actor_index = idx + break + except Exception: + pass + + if head_actor_index is None: + logger.warning("Could not identify head node actor, using actor 0 as default") + head_actor_index = 0 + + logger.info(f"Head node actor identified: actor {head_actor_index}") + + # Start head worker first to initialize metastore service + logger.info("Starting head worker to initialize metastore...") + ray.get(worker_actors[head_actor_index].start.remote()) + metastore_address = ray.get(worker_actors[head_actor_index].get_metastore_address.remote()) + logger.info(f"Head worker started, metastore address: {metastore_address}") + + # Start remaining worker actors in parallel + other_actors = [worker_actors[i] for i in range(len(worker_actors)) if i != head_actor_index] + if other_actors: + logger.info(f"Starting {len(other_actors)} worker actors in parallel...") + ray.get([actor.start.remote() for actor in other_actors]) + + logger.info( + f"Yuanrong backend started successfully: metastore at {metastore_address}, workers on {len(node_ips)} nodes" + ) + + return { + "worker_actors": worker_actors, + "metastore_address": metastore_address, + "placement_group": pg, + } + except Exception as e: + # Cleanup on initialization failure: attempt graceful stop of started workers first + logger.error(f"Failed to start Yuanrong workers: {e}, cleaning up...") + + # Try to gracefully stop workers that may have already started + if worker_actors: + stop_exceptions = [] + # Stop worker nodes (all except head node 0) first + if len(worker_actors) > 1: + stop_refs = [actor.stop.remote() for actor in worker_actors[1:]] + for idx, stop_ref in enumerate(stop_refs, start=1): + try: + ray.get(stop_ref, timeout=30) + except Exception as stop_e: + stop_exceptions.append(stop_e) + logger.warning(f"Failed to stop worker node actor {idx}: {stop_e}") + # Stop head node (actor 0) + try: + ray.get(worker_actors[0].stop.remote(), timeout=30) + except Exception as stop_e: + stop_exceptions.append(stop_e) + logger.warning(f"Failed to stop head node actor: {stop_e}") + + if stop_exceptions: + logger.warning(f"Encountered {len(stop_exceptions)} errors during graceful worker stop") + + # Then kill actors and remove placement group + _kill_actors_and_placement_group(worker_actors, pg) + raise