diff --git a/README.md b/README.md index ccc8b564..7f4f6da0 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,7 @@ TransferQueue offers **fine-grained, sub-sample-level** data management and **lo ### Control Plane: Panoramic Data Management -In the control plane, `TransferQueueController` tracks the **production status** and **consumption status** of each training sample as metadata. Once all required data fields are ready (i.e., written to the `TransferQueueStorageManager`), the data sample can be consumed by downstream tasks. +In the control plane, `TransferQueueController` tracks the **production status** and **consumption status** of each training sample as metadata. Once all required data fields are ready (i.e., written to the `StorageManager`), the data sample can be consumed by downstream tasks. We also track the consumption history for each computational task (e.g., `generate_sequences`, `compute_log_prob`, etc.). Therefore, even when different computational tasks require the same data field, they can consume the data independently without interfering with each other. @@ -66,7 +66,7 @@ To make the data retrieval process more customizable, we provide a `Sampler` cla In the data plane, we utilize a pluggable design, enabling TransferQueue to integrate with different storage backends based on user requirements. -Specifically, we provide a `TransferQueueStorageManager` abstraction class that defines the core APIs as follows: +Specifically, we provide a `StorageManager` abstraction class that defines the core APIs as follows: - `async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None` - `async def get_data(self, metadata: BatchMeta) -> TensorDict` @@ -298,21 +298,19 @@ The data plane is organized as follows: │ │── simple_backend.py # Default distributed storage backend (SimpleStorageUnit) by TQ │ ├── managers/ # Managers are upper level interfaces that encapsulate the interaction logic with TQ system. │ │ ├── __init__.py - │ │ ├──base.py # TransferQueueStorageManager, KVStorageManager - │ │ ├──simple_backend_manager.py # AsyncSimpleStorageManager + │ │ ├──base.py # StorageManager, KVStorageManager, StorageManagerFactory + │ │ ├──simple_storage_manager.py # AsyncSimpleStorageManager │ │ ├──yuanrong_manager.py # YuanrongStorageManager - │ │ ├──mooncake_manager.py # MooncakeStorageManager - │ │ └──factory.py # TransferQueueStorageManagerFactory + │ │ └──mooncake_manager.py # MooncakeStorageManager │ └── clients/ # Clients are lower level interfaces that directly manipulate the target storage backend. │ │ ├── __init__.py - │ │ ├── base.py # TransferQueueStorageKVClient + │ │ ├── base.py # StorageKVClient, StorageClientFactory │ │ ├── yuanrong_client.py # YuanrongStorageClient │ │ ├── mooncake_client.py # MooncakeStorageClient - │ │ ├── ray_storage_client.py # RayStorageClient - │ │ └── factory.py # TransferQueueStorageClientFactory + │ │ └── ray_storage_client.py # RayStorageClient ``` -To integrate TransferQueue with a custom storage backend, start by implementing a subclass that inherits from `TransferQueueStorageManager`. This subclass acts as an adapter between the TransferQueue system and the target storage backend. For KV-based storage backends, you can simply inherit from `KVStorageManager`, which can serve as the general manager for all KV-based backends. +To integrate TransferQueue with a custom storage backend, start by implementing a subclass that inherits from `StorageManager`. This subclass acts as an adapter between the TransferQueue system and the target storage backend. For KV-based storage backends, you can simply inherit from `KVStorageManager`, which can serve as the general manager for all KV-based backends. Distributed storage backends often come with their own native clients serving as the interface of the storage system. In such cases, a low-level adapter for this client can be written, following the examples provided in the `storage/clients` directory. diff --git a/pyproject.toml b/pyproject.toml index 0bbb2f0c..20b148be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ requires-python = ">=3.10" # Note: While the formatter will attempt to format lines such that they remain within the line-length, # it isn't a hard upper bound, and formatted lines may exceed the line-length. line-length = 120 +target-version = "py310" [tool.ruff.lint] isort = {known-first-party = ["transfer_queue"]} @@ -60,7 +61,7 @@ ignore = [ # `.log()` statement uses f-string "G004", # X | None for type annotations - "UP045", + # "UP045", # deprecated import "UP035", ] diff --git a/recipe/simple_use_case/single_controller_demo.py b/recipe/simple_use_case/single_controller_demo.py index af8689ab..3b1215b7 100644 --- a/recipe/simple_use_case/single_controller_demo.py +++ b/recipe/simple_use_case/single_controller_demo.py @@ -15,7 +15,6 @@ import argparse import asyncio -import logging import os import random import time @@ -32,9 +31,9 @@ import transfer_queue as tq from transfer_queue import KVBatchMeta +from transfer_queue.utils.logging_utils import get_logger -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") -logger = logging.getLogger(__name__) +logger = get_logger(__name__) os.environ["RAY_DEDUP_LOGS"] = "0" os.environ["RAY_DEBUG"] = "1" diff --git a/scripts/put_benchmark.py b/scripts/put_benchmark.py index 005572c2..c67bb54c 100644 --- a/scripts/put_benchmark.py +++ b/scripts/put_benchmark.py @@ -30,7 +30,7 @@ from transfer_queue import TransferQueueClient from transfer_queue.controller import TransferQueueController -from transfer_queue.storage.simple_backend import SimpleStorageUnit +from transfer_queue.storage.simple_storage import SimpleStorageUnit from transfer_queue.utils.common import get_placement_group from transfer_queue.utils.zmq_utils import process_zmq_server_info diff --git a/tests/test_async_simple_storage_manager.py b/tests/test_async_simple_storage_manager.py index 4d1419a6..606e9822 100644 --- a/tests/test_async_simple_storage_manager.py +++ b/tests/test_async_simple_storage_manager.py @@ -24,7 +24,7 @@ from transfer_queue.metadata import BatchMeta from transfer_queue.storage import AsyncSimpleStorageManager -from transfer_queue.utils.enum_utils import TransferQueueRole +from transfer_queue.utils.enum_utils import Role from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo @@ -35,13 +35,13 @@ async def mock_async_storage_manager(): # Mock storage unit infos storage_unit_infos = { "storage_0": ZMQServerInfo( - role=TransferQueueRole.STORAGE, + role=Role.STORAGE, id="storage_0", ip="127.0.0.1", ports={"put_get_socket": 12345}, ), "storage_1": ZMQServerInfo( - role=TransferQueueRole.STORAGE, + role=Role.STORAGE, id="storage_1", ip="127.0.0.1", ports={"put_get_socket": 12346}, @@ -50,7 +50,7 @@ async def mock_async_storage_manager(): # Mock controller info controller_info = ZMQServerInfo( - role=TransferQueueRole.CONTROLLER, + role=Role.CONTROLLER, id="controller_0", ip="127.0.0.1", ports={"handshake_socket": 12347, "data_status_update_socket": 12348}, @@ -61,9 +61,7 @@ async def mock_async_storage_manager(): } # Mock the handshake process entirely to avoid ZMQ complexity - with patch( - "transfer_queue.storage.managers.base.TransferQueueStorageManager._connect_to_controller" - ) as mock_connect: + with patch("transfer_queue.storage.managers.base.StorageManager._connect_to_controller") as mock_connect: # Mock the manager without actually connecting manager = AsyncSimpleStorageManager.__new__(AsyncSimpleStorageManager) manager.storage_manager_id = "test_storage_manager" @@ -148,7 +146,7 @@ async def test_async_storage_manager_error_handling(): # Mock storage unit infos storage_unit_infos = { "storage_0": ZMQServerInfo( - role=TransferQueueRole.STORAGE, + role=Role.STORAGE, id="storage_0", ip="127.0.0.1", ports={"put_get_socket": 12345}, @@ -157,7 +155,7 @@ async def test_async_storage_manager_error_handling(): # Mock controller info controller_info = ZMQServerInfo( - role=TransferQueueRole.CONTROLLER, + role=Role.CONTROLLER, id="controller_0", ip="127.0.0.1", ports={"handshake_socket": 12346, "data_status_update_socket": 12347}, @@ -242,19 +240,19 @@ async def test_get_data_routes_from_hash(): """get_data should route using global_idx % num_su (hash routing).""" storage_unit_infos = { "storage_0": ZMQServerInfo( - role=TransferQueueRole.STORAGE, + role=Role.STORAGE, id="storage_0", ip="127.0.0.1", ports={"put_get_socket": 19010}, ), "storage_1": ZMQServerInfo( - role=TransferQueueRole.STORAGE, + role=Role.STORAGE, id="storage_1", ip="127.0.0.1", ports={"put_get_socket": 19011}, ), } - with patch("transfer_queue.storage.managers.base.TransferQueueStorageManager._connect_to_controller"): + with patch("transfer_queue.storage.managers.base.StorageManager._connect_to_controller"): manager = AsyncSimpleStorageManager.__new__(AsyncSimpleStorageManager) manager.storage_manager_id = "test_get" manager.storage_unit_infos = storage_unit_infos @@ -295,19 +293,19 @@ async def test_clear_data_routes_from_hash(): """clear_data should route using global_idx % num_su (hash routing).""" storage_unit_infos = { "storage_0": ZMQServerInfo( - role=TransferQueueRole.STORAGE, + role=Role.STORAGE, id="storage_0", ip="127.0.0.1", ports={"put_get_socket": 19020}, ), "storage_1": ZMQServerInfo( - role=TransferQueueRole.STORAGE, + role=Role.STORAGE, id="storage_1", ip="127.0.0.1", ports={"put_get_socket": 19021}, ), } - with patch("transfer_queue.storage.managers.base.TransferQueueStorageManager._connect_to_controller"): + with patch("transfer_queue.storage.managers.base.StorageManager._connect_to_controller"): manager = AsyncSimpleStorageManager.__new__(AsyncSimpleStorageManager) manager.storage_manager_id = "test_clear" manager.storage_unit_infos = storage_unit_infos @@ -346,19 +344,19 @@ async def test_hash_routing_stable_across_batch_sizes(): """ storage_unit_infos = { "storage_0": ZMQServerInfo( - role=TransferQueueRole.STORAGE, + role=Role.STORAGE, id="storage_0", ip="127.0.0.1", ports={"put_get_socket": 19030}, ), "storage_1": ZMQServerInfo( - role=TransferQueueRole.STORAGE, + role=Role.STORAGE, id="storage_1", ip="127.0.0.1", ports={"put_get_socket": 19031}, ), } - with patch("transfer_queue.storage.managers.base.TransferQueueStorageManager._connect_to_controller"): + with patch("transfer_queue.storage.managers.base.StorageManager._connect_to_controller"): manager = AsyncSimpleStorageManager.__new__(AsyncSimpleStorageManager) manager.storage_manager_id = "test_hash_batch" manager.storage_unit_infos = storage_unit_infos @@ -407,19 +405,19 @@ async def test_hash_routing_stable_reversed_order(): """ storage_unit_infos = { "storage_0": ZMQServerInfo( - role=TransferQueueRole.STORAGE, + role=Role.STORAGE, id="storage_0", ip="127.0.0.1", ports={"put_get_socket": 19040}, ), "storage_1": ZMQServerInfo( - role=TransferQueueRole.STORAGE, + role=Role.STORAGE, id="storage_1", ip="127.0.0.1", ports={"put_get_socket": 19041}, ), } - with patch("transfer_queue.storage.managers.base.TransferQueueStorageManager._connect_to_controller"): + with patch("transfer_queue.storage.managers.base.StorageManager._connect_to_controller"): manager = AsyncSimpleStorageManager.__new__(AsyncSimpleStorageManager) manager.storage_manager_id = "test_hash_order" manager.storage_unit_infos = storage_unit_infos diff --git a/tests/test_client.py b/tests/test_client.py index 7ce8bcbf..2e7e37a0 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -24,7 +24,7 @@ from transfer_queue import TransferQueueClient from transfer_queue.metadata import BatchMeta -from transfer_queue.utils.enum_utils import TransferQueueRole +from transfer_queue.utils.enum_utils import Role from transfer_queue.utils.zmq_utils import ( ZMQMessage, ZMQRequestType, @@ -59,7 +59,7 @@ def __init__(self, controller_id="controller_0"): self.request_port = self._bind_to_random_port(self.request_socket) self.zmq_server_info = ZMQServerInfo( - role=TransferQueueRole.CONTROLLER, + role=Role.CONTROLLER, id=controller_id, ip="127.0.0.1", ports={ @@ -300,7 +300,7 @@ def __init__(self, storage_id="storage_0"): self.data_port = self._bind_to_random_port(self.data_socket) self.zmq_server_info = ZMQServerInfo( - role=TransferQueueRole.STORAGE, + role=Role.STORAGE, id=storage_id, ip="127.0.0.1", ports={ @@ -409,7 +409,7 @@ def client_setup(mock_controller, mock_storage): # Mock the storage manager to avoid handshake issues but mock all data operations with patch( - "transfer_queue.storage.managers.simple_backend_manager.AsyncSimpleStorageManager._connect_to_controller" + "transfer_queue.storage.managers.simple_storage_manager.AsyncSimpleStorageManager._connect_to_controller" ): config = { "controller_info": mock_controller.zmq_server_info, @@ -502,7 +502,7 @@ def test_single_controller_multiple_storages(): # Mock the storage manager to avoid handshake issues but mock all data operations with patch( - "transfer_queue.storage.managers.simple_backend_manager.AsyncSimpleStorageManager._connect_to_controller" + "transfer_queue.storage.managers.simple_storage_manager.AsyncSimpleStorageManager._connect_to_controller" ): config = { "controller_info": controller.zmq_server_info, diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 20ace48e..b9e80dac 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -678,7 +678,7 @@ class TestStorageUnitDataStrict: def test_put_data_length_mismatch_raises(self): """put_data must raise when global_indexes and field values have different lengths.""" - from transfer_queue.storage.simple_backend import StorageUnitData + from transfer_queue.storage.simple_storage import StorageUnitData sud = StorageUnitData(storage_size=10) # 3 indexes but only 2 values — must raise, not silently drop diff --git a/tests/test_ray_p2p.py b/tests/test_ray_p2p.py index b9f0b835..353bb926 100644 --- a/tests/test_ray_p2p.py +++ b/tests/test_ray_p2p.py @@ -23,8 +23,7 @@ from transfer_queue.client import TransferQueueClient from transfer_queue.metadata import BatchMeta -from transfer_queue.storage.managers.base import KVStorageManager -from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory +from transfer_queue.storage.managers.base import KVStorageManager, StorageManagerFactory from transfer_queue.utils.zmq_utils import ZMQServerInfo TEST_CONFIGS: list[tuple[tuple[int, int], torch.dtype]] = [ @@ -45,18 +44,18 @@ # Step 1: Mock Controller Role try: - from transfer_queue.role import TransferQueueRole + from transfer_queue.role import Role except ImportError: from enum import Enum - class TransferQueueRole(Enum): + class Role(Enum): CONTROLLER = "controller" STORAGE = "storage" def create_mock_controller(): return ZMQServerInfo( - role=TransferQueueRole.CONTROLLER, + role=Role.CONTROLLER, id="controller_0", ip="127.0.0.1", ports={ @@ -71,9 +70,9 @@ def create_mock_controller(): def ensure_mock_storage_manager_registered(): """Ensure MockKVStorageManager is registered in current process.""" - if "KV_MOCK" not in TransferQueueStorageManagerFactory._registry: + if "KV_MOCK" not in StorageManagerFactory._registry: - @TransferQueueStorageManagerFactory.register("KV_MOCK") + @StorageManagerFactory.register("KV_MOCK") class MockKVStorageManager(KVStorageManager): def _connect_to_controller(self): pass diff --git a/tests/test_simple_storage_unit.py b/tests/test_simple_storage_unit.py index 2519b975..319a46e7 100644 --- a/tests/test_simple_storage_unit.py +++ b/tests/test_simple_storage_unit.py @@ -21,7 +21,7 @@ import torch import zmq -from transfer_queue.storage.simple_backend import SimpleStorageUnit +from transfer_queue.storage.simple_storage import SimpleStorageUnit from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType @@ -420,7 +420,7 @@ def test_storage_unit_data_direct(): def test_storage_unit_data_capacity_uses_active_keys(): """Capacity check must use _active_keys, not scan field_data.""" - from transfer_queue.storage.simple_backend import StorageUnitData + from transfer_queue.storage.simple_storage import StorageUnitData storage = StorageUnitData(storage_size=3) diff --git a/tests/test_storage_client_factory.py b/tests/test_storage_client_factory.py index 012c0cc5..0e1a5f90 100644 --- a/tests/test_storage_client_factory.py +++ b/tests/test_storage_client_factory.py @@ -19,7 +19,7 @@ import pytest import torch -from transfer_queue.storage.clients.factory import StorageClientFactory +from transfer_queue.storage.clients.base import StorageClientFactory from transfer_queue.storage.clients.yuanrong_client import YuanrongStorageClient diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 5f3a4e3f..45e6303a 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -16,20 +16,15 @@ import asyncio import os import threading -from typing import Any, Callable, Optional +from typing import Any, Callable import torch import zmq import zmq.asyncio from tensordict import TensorDict -from torch import Tensor -from transfer_queue.metadata import ( - BatchMeta, -) -from transfer_queue.storage import ( - TransferQueueStorageManagerFactory, -) +from transfer_queue.metadata import BatchMeta +from transfer_queue.storage import StorageManagerFactory from transfer_queue.utils.common import limit_pytorch_auto_parallel_threads from transfer_queue.utils.logging_utils import get_logger from transfer_queue.utils.zmq_utils import ( @@ -92,7 +87,7 @@ def initialize_storage_manager( - zmq_info: ZMQ server information about the storage units """ - self.storage_manager = TransferQueueStorageManagerFactory.create( + self.storage_manager = StorageManagerFactory.create( manager_type, controller_info=self._controller, config=config ) @@ -104,9 +99,9 @@ async def async_get_meta( batch_size: int, partition_id: str, mode: str = "fetch", - task_name: Optional[str] = None, - sampling_config: Optional[dict[str, Any]] = None, - socket: Optional[zmq.asyncio.Socket] = None, + task_name: str | None = None, + sampling_config: dict[str, Any] | None = None, + socket: zmq.asyncio.Socket | None = None, ) -> BatchMeta: """Asynchronously fetch data metadata from the controller via ZMQ. @@ -191,7 +186,7 @@ async def async_get_meta( async def async_set_custom_meta( self, metadata: BatchMeta, - socket: Optional[zmq.asyncio.Socket] = None, + socket: zmq.asyncio.Socket | None = None, ) -> None: """ Asynchronously send custom metadata to the controller. @@ -264,9 +259,9 @@ async def async_set_custom_meta( async def async_put( self, data: TensorDict, - metadata: Optional[BatchMeta] = None, - partition_id: Optional[str] = None, - data_parser: Optional[Callable[[Any], Any]] = None, + metadata: BatchMeta | None = None, + partition_id: str | None = None, + data_parser: Callable[[Any], Any] | None = None, ) -> BatchMeta: """Asynchronously write data to storage units based on metadata. @@ -575,8 +570,8 @@ async def async_get_consumption_status( self, task_name: str, partition_id: str, - socket: Optional[zmq.asyncio.Socket] = None, - ) -> tuple[Optional[Tensor], Optional[Tensor]]: + socket: zmq.asyncio.Socket | None = None, + ) -> tuple[torch.Tensor | None, torch.Tensor | None]: """Get consumption status for current partition in a specific task. Args: @@ -638,8 +633,8 @@ async def async_get_production_status( self, data_fields: list[str], partition_id: str, - socket: Optional[zmq.asyncio.Socket] = None, - ) -> tuple[Optional[Tensor], Optional[Tensor]]: + socket: zmq.asyncio.Socket | None = None, + ) -> tuple[torch.Tensor | None, torch.Tensor | None]: """Get production status for specific data fields and partition. Args: @@ -769,8 +764,8 @@ async def async_check_production_status( async def async_reset_consumption( self, partition_id: str, - task_name: Optional[str] = None, - socket: Optional[zmq.asyncio.Socket] = None, + task_name: str | None = None, + socket: zmq.asyncio.Socket | None = None, ) -> bool: """Asynchronously reset consumption status for a partition. @@ -830,7 +825,7 @@ async def async_reset_consumption( @with_controller_socket async def async_get_partition_list( self, - socket: Optional[zmq.asyncio.Socket] = None, + socket: zmq.asyncio.Socket | None = None, ) -> list[str]: """Asynchronously fetch the list of partition ids from the controller. @@ -879,23 +874,22 @@ async def async_kv_retrieve_meta( keys: list[str] | str, partition_id: str, create: bool = False, - socket: Optional[zmq.asyncio.Socket] = None, + socket: zmq.asyncio.Socket | None = None, ) -> BatchMeta: - """Asynchronously retrieve BatchMeta from the controller using user-specified keys. + """Asynchronously retrieve BatchMeta by user-defined keys. + + Retrieves metadata for given keys from a specified partition. + If keys do not exist and `create=True`, they will be automatically registered. Args: - keys: List of keys to retrieve from the controller + keys: List of keys to retrieve. partition_id: The ID of the logical partition to search for keys. - create: Whether to register new keys if not found. - socket: ZMQ socket (injected by decorator) + create: If True, automatically create entries for missing keys. + socket: ZMQ socket injected by @with_controller_socket. Returns: - metadata: BatchMeta of the corresponding keys - - Raises: - TypeError: If `keys` is not a list of string or a string + BatchMeta: Metadata for the requested keys. """ - if isinstance(keys, str): keys = [keys] elif isinstance(keys, list): @@ -919,32 +913,30 @@ async def async_kv_retrieve_meta( ) try: - assert socket is not None + assert socket is not None, "Socket must be initialized before use" await socket.send_multipart(request_msg.serialize()) response_serialized = await socket.recv_multipart(copy=False) response_msg = ZMQMessage.deserialize(response_serialized) logger.debug( - f"[{self.client_id}]: Client get kv_retrieve_keys response: {response_msg} " + f"[{self.client_id}] Received KV_RETRIEVE_META response: {response_msg} " f"from controller {self._controller.id}" ) if response_msg.request_type == ZMQRequestType.KV_RETRIEVE_META_RESPONSE: - metadata = response_msg.body.get("metadata", BatchMeta.empty()) - return metadata - else: - raise RuntimeError( - f"[{self.client_id}]: Failed to retrieve keys from controller {self._controller.id}: " - f"{response_msg.body.get('message', 'Unknown error')}" - ) + return response_msg.body.get("metadata", BatchMeta.empty()) + + raise RuntimeError( + f"[{self.client_id}] Failed to retrieve metadata {response_msg.body.get('message', 'Unknown error')}" + ) except Exception as e: - raise RuntimeError(f"[{self.client_id}]: Error in kv_retrieve_keys: {str(e)}") from e + raise RuntimeError(f"[{self.client_id}] Failed in async_kv_retrieve_meta: {e}") from e @with_controller_socket async def async_kv_retrieve_keys( self, global_indexes: list[int] | int, partition_id: str, - socket: Optional[zmq.asyncio.Socket] = None, + socket: zmq.asyncio.Socket | None = None, ) -> list[str]: """Asynchronously retrieve keys according to global_indexes from the controller. @@ -1005,8 +997,8 @@ async def async_kv_retrieve_keys( @with_controller_socket async def async_kv_list( self, - partition_id: Optional[str] = None, - socket: Optional[zmq.asyncio.Socket] = None, + partition_id: str | None = None, + socket: zmq.asyncio.Socket | None = None, ) -> dict[str, dict[str, Any]]: """Asynchronously retrieve keys and custom_meta from the controller for one or all partitions. @@ -1145,8 +1137,8 @@ def get_meta( batch_size: int, partition_id: str, mode: str = "fetch", - task_name: Optional[str] = None, - sampling_config: Optional[dict[str, Any]] = None, + task_name: str | None = None, + sampling_config: dict[str, Any] | None = None, ) -> BatchMeta: """Synchronously fetch data metadata from the controller via ZMQ. @@ -1234,9 +1226,9 @@ def set_custom_meta(self, metadata: BatchMeta) -> None: def put( self, data: TensorDict, - metadata: Optional[BatchMeta] = None, - partition_id: Optional[str] = None, - data_parser: Optional[Callable[[Any], Any]] = None, + metadata: BatchMeta | None = None, + partition_id: str | None = None, + data_parser: Callable[[Any], Any] | None = None, ) -> BatchMeta: """Synchronously write data to storage units based on metadata. @@ -1356,7 +1348,7 @@ def get_consumption_status( self, task_name: str, partition_id: str, - ) -> tuple[Optional[Tensor], Optional[Tensor]]: + ) -> tuple[torch.Tensor | None, torch.Tensor | None]: """Synchronously get consumption status for a specific task and partition. Args: @@ -1384,7 +1376,7 @@ def get_production_status( self, data_fields: list[str], partition_id: str, - ) -> tuple[Optional[Tensor], Optional[Tensor]]: + ) -> tuple[torch.Tensor | None, torch.Tensor | None]: """Synchronously get production status for specific data fields and partition. Args: @@ -1454,7 +1446,7 @@ def check_production_status(self, data_fields: list[str], partition_id: str) -> """ return self._check_production_status(data_fields=data_fields, partition_id=partition_id) - def reset_consumption(self, partition_id: str, task_name: Optional[str] = None) -> bool: + def reset_consumption(self, partition_id: str, task_name: str | None = None) -> bool: """Synchronously reset consumption status for a partition. This allows the same data to be re-consumed, useful for debugging scenarios @@ -1501,20 +1493,22 @@ def kv_retrieve_meta( partition_id: str, create: bool = False, ) -> BatchMeta: - """Synchronously retrieve BatchMeta from the controller using user-specified keys. + """Synchronously retrieve BatchMeta by user-defined keys. + + Retrieves metadata for given keys from a specified partition. + If keys do not exist and `create=True`, they will be automatically registered. Args: - keys: List of keys to retrieve from the controller - partition_id: The ID of the logical partition to search for keys. - create: Whether to register new keys if not found. + keys: List of keys to retrieve from the controller. + partition_id: Logical partition to query. + create: If True, automatically create entries for non-existent keys. Returns: - metadata: BatchMeta of the corresponding keys + BatchMeta: Metadata for the requested keys. Raises: TypeError: If `keys` is not a list of string or a string """ - return self._kv_retrieve_meta(keys=keys, partition_id=partition_id, create=create) def kv_retrieve_keys( @@ -1540,7 +1534,7 @@ def kv_retrieve_keys( def kv_list( self, - partition_id: Optional[str] = None, + partition_id: str | None = None, ) -> dict[str, dict[str, Any]]: """Synchronously retrieve keys and custom_meta from the controller for one or all partitions. diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 49aa2410..1e74d3d7 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -21,7 +21,7 @@ from itertools import groupby from operator import itemgetter from threading import Lock, Thread -from typing import Any, Optional +from typing import Any from uuid import uuid4 import numpy as np @@ -35,7 +35,7 @@ BatchMeta, ) from transfer_queue.sampler import BaseSampler, SequentialSampler -from transfer_queue.utils.enum_utils import TransferQueueRole +from transfer_queue.utils.enum_utils import Role from transfer_queue.utils.logging_utils import get_logger from transfer_queue.utils.perf_utils import IntervalPerfMonitor from transfer_queue.utils.zmq_utils import ( @@ -45,7 +45,7 @@ create_zmq_socket, format_zmq_address, get_free_port, - get_node_ip_address_raw, + get_node_ip_address, ) logger = get_logger(__name__) @@ -195,10 +195,10 @@ class FieldMeta: """ global_indexes: set[int] = field(default_factory=set) - dtype: Optional[Any] = None - shape: Optional[tuple] = None # None when is_nested=True - is_nested: Optional[bool] = None - is_non_tensor: Optional[bool] = None + dtype: Any | None = None + shape: tuple | None = None # None when is_nested=True + is_nested: bool | None = None + is_non_tensor: bool | None = None per_sample_shapes: dict[int, tuple] = field(default_factory=dict) # {global_idx: shape} @@ -384,7 +384,6 @@ def allocated_samples_num(self) -> int: return self.production_status.shape[0] # ==================== Index Pre-Allocation Methods ==================== - def register_pre_allocated_indexes(self, allocated_indexes: list[int]): """ Register pre-allocated sample indexes to this partition. @@ -442,7 +441,6 @@ def activate_pre_allocated_indexes(self, sample_num: int) -> list[int]: return global_index_to_allocate # ==================== Dynamic Expansion Methods ==================== - def ensure_samples_capacity(self, required_samples: int) -> None: """ Ensure the production status tensor has enough rows for the required samples. @@ -492,13 +490,12 @@ def ensure_fields_capacity(self, required_fields: int) -> None: logger.debug(f"Expanded partition {self.partition_id} from {current_fields} to {new_fields} fields") # ==================== Production Status Interface ==================== - def update_production_status( self, global_indices: list[int], field_names: list[str], field_schema: dict[str, dict[str, Any]], - custom_backend_meta: Optional[dict[int, dict[str, Any]]] = None, + custom_backend_meta: dict[int, dict[str, Any]] | None = None, ) -> bool: """ Update production status for specific samples and fields. @@ -563,7 +560,7 @@ def _update_field_metadata( self, global_indexes: list[int], field_schema: dict[str, dict[str, Any]], - custom_backend_meta: Optional[dict[int, dict[str, Any]]] = None, + custom_backend_meta: dict[int, dict[str, Any]] | None = None, ): """Update field metadata from columnar field_schema.""" if not global_indexes: @@ -611,7 +608,6 @@ def mark_consumed(self, task_name: str, global_indices: list[int]): ) # ==================== Consumption Status Interface ==================== - def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Tensor, Tensor]: """ Get or create consumption status for a specific task. @@ -649,7 +645,7 @@ def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Te consumption_status = self.consumption_status[task_name] return partition_global_index, consumption_status - def reset_consumption(self, task_name: Optional[str] = None): + def reset_consumption(self, task_name: str | None = None): """ Reset consumption status for a specific task or all tasks. @@ -674,7 +670,7 @@ def reset_consumption(self, task_name: Optional[str] = None): # ==================== Production Status Interface ==================== def get_production_status_for_fields( self, field_names: list[str], mask: bool = False - ) -> tuple[Optional[Tensor], Optional[Tensor]]: + ) -> tuple[Tensor | None, Tensor | None]: """ Check if all samples for specified fields are fully produced and ready. @@ -713,7 +709,6 @@ def get_production_status_for_fields( return partition_global_index, production_status # ==================== Data Scanning and Query Methods ==================== - def scan_data_status(self, field_names: list[str], task_name: str) -> list[int]: """ Scan data status to find samples ready for consumption. @@ -761,7 +756,6 @@ def scan_data_status(self, field_names: list[str], task_name: str) -> list[int]: return ready_sample_indices # ==================== Metadata Methods ==================== - def get_field_schema( self, field_names: list[str], batch_global_indexes: list[int] | None = None ) -> dict[str, dict[str, Any]]: @@ -830,7 +824,6 @@ def set_custom_meta(self, custom_meta: dict[int, dict]) -> None: self.custom_meta.update(custom_meta) # ==================== Statistics and Monitoring ==================== - def get_statistics(self) -> dict[str, Any]: """Get detailed statistics for this partition.""" stats = { @@ -877,7 +870,6 @@ def get_statistics(self) -> dict[str, Any]: return stats # ==================== Serialization ==================== - def to_snapshot(self): """ Get a snapshot of partition status information. @@ -1028,7 +1020,6 @@ def __init__( logger.info(f"TransferQueue Controller {self.controller_id} initialized") # ==================== Partition Management API ==================== - def create_partition(self, partition_id: str) -> bool: """ Create a new data partition with pre-allocated sample indexes. @@ -1058,7 +1049,7 @@ def create_partition(self, partition_id: str) -> bool: logger.info(f"Created partition {partition_id} with {TQ_PRE_ALLOC_SAMPLE_NUM} pre-allocated indexes") return True - def _get_partition(self, partition_id: str) -> Optional[DataPartitionStatus]: + def _get_partition(self, partition_id: str) -> DataPartitionStatus | None: """ Get partition status information. @@ -1070,7 +1061,7 @@ def _get_partition(self, partition_id: str) -> Optional[DataPartitionStatus]: """ return self.partitions.get(partition_id) - def get_partition_snapshot(self, partition_id: str) -> Optional[DataPartitionStatus]: + def get_partition_snapshot(self, partition_id: str) -> DataPartitionStatus | None: """ Get a copy of partition status information, without threading.Lock(). @@ -1098,7 +1089,6 @@ def list_partitions(self) -> list[str]: return list(self.partitions.keys()) # ==================== Partition Index Management API ==================== - def get_partition_index_range(self, partition_id: str) -> list[int]: """ Get all indexes for a specific partition. @@ -1114,13 +1104,12 @@ def get_partition_index_range(self, partition_id: str) -> list[int]: return self.index_manager.get_indexes_for_partition(partition_id) # ==================== Data Production API ==================== - def update_production_status( self, partition_id: str, global_indexes: list[int], field_schema: dict[str, dict[str, Any]], - custom_backend_meta: Optional[dict[int, dict[str, Any]]] = None, + custom_backend_meta: dict[int, dict[str, Any]] | None = None, ) -> bool: """ Update production status for specific samples and fields in a partition. @@ -1150,8 +1139,7 @@ def update_production_status( return success # ==================== Data Consumption API ==================== - - def get_consumption_status(self, partition_id: str, task_name: str) -> tuple[Optional[Tensor], Optional[Tensor]]: + def get_consumption_status(self, partition_id: str, task_name: str) -> tuple[Tensor | None, Tensor | None]: """ Get or create consumption status for a specific task and partition. Delegates to the partition's own method. @@ -1171,9 +1159,7 @@ def get_consumption_status(self, partition_id: str, task_name: str) -> tuple[Opt return partition.get_consumption_status(task_name, mask=True) - def get_production_status( - self, partition_id: str, data_fields: list[str] - ) -> tuple[Optional[Tensor], Optional[Tensor]]: + def get_production_status(self, partition_id: str, data_fields: list[str]) -> tuple[Tensor | None, Tensor | None]: """ Check if all samples for specified fields are fully produced in a partition. @@ -1238,7 +1224,7 @@ def get_metadata( mode: str = "fetch", task_name: str | None = None, batch_size: int | None = None, - sampling_config: Optional[dict[str, Any]] = None, + sampling_config: dict[str, Any] | None = None, *args, **kwargs, ) -> BatchMeta: @@ -1398,7 +1384,6 @@ def scan_data_status( return ready_sample_indices # ==================== Metadata Generation API ==================== - def generate_batch_meta( self, partition_id: str, @@ -1507,7 +1492,7 @@ def clear_partition(self, partition_id: str, clear_consumption: bool = True): self.partitions.pop(partition_id) self.sampler.clear_cache(partition_id) - def reset_consumption(self, partition_id: str, task_name: Optional[str] = None): + def reset_consumption(self, partition_id: str, task_name: str | None = None): """ Reset consumption status for a partition without clearing the actual data. @@ -1592,17 +1577,17 @@ def kv_retrieve_meta( metadata: BatchMeta of the requested keys """ - logger.debug(f"[{self.controller_id}]: Retrieve keys {keys} in partition {partition_id}") + logger.debug(f"[{self.controller_id}] Retrieve keys {keys} in partition {partition_id}") + # Ensure partition exists partition = self._get_partition(partition_id) - if partition is None: if not create: - logger.warning(f"Partition {partition_id} were not found in controller!") + logger.warning(f"Partition {partition_id} not found!") return BatchMeta.empty() - else: - self.create_partition(partition_id) - partition = self._get_partition(partition_id) + + self.create_partition(partition_id) + partition = self._get_partition(partition_id) assert partition is not None global_indexes = partition.kv_retrieve_indexes(keys) @@ -1646,15 +1631,13 @@ def kv_retrieve_meta( if col_idx < len(col_mask) and col_mask[col_idx]: data_fields.append(field_name) - metadata = self.generate_batch_meta(partition_id, verified_global_indexes, data_fields, mode="force_fetch") - - return metadata + return self.generate_batch_meta(partition_id, verified_global_indexes, data_fields, mode="force_fetch") def kv_retrieve_keys( self, global_indexes: list[int], partition_id: str, - ) -> list[Optional[str]]: + ) -> list[str | None]: """ Retrieve keys from the controller using a list of global_indexes. @@ -1689,7 +1672,7 @@ def kv_retrieve_keys( def _init_zmq_socket(self): """Initialize ZMQ sockets for communication.""" self.zmq_context = zmq.Context() - self._node_ip = get_node_ip_address_raw() + self._node_ip = get_node_ip_address() while True: try: @@ -1726,7 +1709,7 @@ def _init_zmq_socket(self): continue self.zmq_server_info = ZMQServerInfo( - role=TransferQueueRole.CONTROLLER, + role=Role.CONTROLLER, id=self.controller_id, ip=self._node_ip, ports={ diff --git a/transfer_queue/dataloader/streaming_dataloader.py b/transfer_queue/dataloader/streaming_dataloader.py index 3b08cd26..13277e11 100644 --- a/transfer_queue/dataloader/streaming_dataloader.py +++ b/transfer_queue/dataloader/streaming_dataloader.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional import torch from tensordict import TensorDict @@ -80,7 +79,7 @@ def __init__( pin_memory: bool = False, worker_init_fn=None, multiprocessing_context=None, - prefetch_factor: Optional[int] = None, + prefetch_factor: int | None = None, persistent_workers: bool = False, pin_memory_device: str = "", ): diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 66141b8f..8e8700de 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -18,7 +18,7 @@ import subprocess import time from importlib import resources -from typing import Any, Callable, Optional +from typing import Any, Callable from urllib.parse import urlparse import ray @@ -32,7 +32,7 @@ from transfer_queue.metadata import KVBatchMeta from transfer_queue.sampler import * # noqa: F401 from transfer_queue.sampler import BaseSampler -from transfer_queue.storage.simple_backend import SimpleStorageUnit +from transfer_queue.storage.simple_storage import SimpleStorageUnit from transfer_queue.utils.common import get_placement_group from transfer_queue.utils.logging_utils import get_logger from transfer_queue.utils.yuanrong_utils import ( @@ -43,40 +43,39 @@ logger = get_logger(__name__) -_TRANSFER_QUEUE_CLIENT: Any = None -_TRANSFER_QUEUE_STORAGE: Any = None -_TRANSFER_QUEUE_CONTROLLER: Any = None +_TQ_CLIENT: Any = None +_TQ_STORAGE: Any = None +_TQ_CONTROLLER: Any = None -def _maybe_create_transferqueue_client( - conf: Optional[DictConfig] = None, -) -> TransferQueueClient: - global _TRANSFER_QUEUE_CLIENT - if _TRANSFER_QUEUE_CLIENT is None: +def _maybe_create_tq_client(conf: DictConfig | None = None) -> TransferQueueClient: + global _TQ_CLIENT + if _TQ_CLIENT is None: if conf is None: _init_from_existing() - assert _TRANSFER_QUEUE_CLIENT is not None, ( + assert _TQ_CLIENT is not None, ( "TransferQueueController has not been initialized yet. Please call init() first." ) - return _TRANSFER_QUEUE_CLIENT + return _TQ_CLIENT pid = os.getpid() - _TRANSFER_QUEUE_CLIENT = TransferQueueClient( + _TQ_CLIENT = TransferQueueClient( client_id=f"TransferQueueClient_{pid}", controller_info=conf.controller.zmq_info ) backend_name = conf.backend.storage_backend - _TRANSFER_QUEUE_CLIENT.initialize_storage_manager(manager_type=backend_name, config=conf.backend[backend_name]) + _TQ_CLIENT.initialize_storage_manager(manager_type=backend_name, config=conf.backend[backend_name]) - return _TRANSFER_QUEUE_CLIENT + return _TQ_CLIENT -def _maybe_create_transferqueue_storage(conf: DictConfig) -> DictConfig: - global _TRANSFER_QUEUE_STORAGE +# TODO(hz): Adopt registry pattern to manage storage backends for better scalability. +def _maybe_create_tq_storage(conf: DictConfig) -> DictConfig: + global _TQ_STORAGE - if _TRANSFER_QUEUE_STORAGE is None: - _TRANSFER_QUEUE_STORAGE = {} + if _TQ_STORAGE is None: + _TQ_STORAGE = {} if conf.backend.storage_backend == "SimpleStorage": # initialize SimpleStorageUnit simple_storage_handles = {} @@ -98,7 +97,7 @@ def _maybe_create_transferqueue_storage(conf: DictConfig) -> DictConfig: storage_zmq_info = process_zmq_server_info(simple_storage_handles) backend_name = conf.backend.storage_backend conf.backend[backend_name].zmq_info = storage_zmq_info - _TRANSFER_QUEUE_STORAGE["SimpleStorage"] = simple_storage_handles + _TQ_STORAGE["SimpleStorage"] = simple_storage_handles if conf.backend.storage_backend == "MooncakeStore": if conf.backend.MooncakeStore.auto_init: # Try to kill existing mooncake_master processes before starting a new one to avoid potential conflicts @@ -186,9 +185,9 @@ def _maybe_create_transferqueue_storage(conf: DictConfig) -> DictConfig: f"mooncake_master exited with error. Check {log_file_path} for detailed logs. " f"Output:\n{error_msg}" ) - _TRANSFER_QUEUE_STORAGE["MooncakeStore"] = process + _TQ_STORAGE["MooncakeStore"] = process if conf.backend.storage_backend == "Yuanrong" and conf.backend.Yuanrong.auto_init: - _TRANSFER_QUEUE_STORAGE["Yuanrong"] = initialize_yuanrong_backend(conf) + _TQ_STORAGE["Yuanrong"] = initialize_yuanrong_backend(conf) return conf @@ -198,10 +197,10 @@ def _init_from_existing() -> bool: Returns: True if successfully initialized from existing controller, False otherwise. """ - global _TRANSFER_QUEUE_CONTROLLER + global _TQ_CONTROLLER try: - if _TRANSFER_QUEUE_CONTROLLER is None: - _TRANSFER_QUEUE_CONTROLLER = ray.get_actor("TransferQueueController") + if _TQ_CONTROLLER is None: + _TQ_CONTROLLER = ray.get_actor("TransferQueueController") except ValueError: logger.info("Called _init_from_existing() but TransferQueueController has not been initialized yet.") @@ -211,9 +210,9 @@ def _init_from_existing() -> bool: conf = None while conf is None: - conf = ray.get(_TRANSFER_QUEUE_CONTROLLER.get_config.remote()) + conf = ray.get(_TQ_CONTROLLER.get_config.remote()) if conf is not None: - _maybe_create_transferqueue_client(conf) + _maybe_create_tq_client(conf) logger.info("TransferQueueClient initialized.") return True @@ -225,27 +224,22 @@ def _init_from_existing() -> bool: # ==================== Initialization API ==================== -def init(conf: Optional[DictConfig] = None) -> Optional[DictConfig]: +def init(conf: DictConfig | None = None) -> DictConfig | None: """Initialize the TransferQueue system. This function sets up the TransferQueue controller, distributed storage, and client. It should be called once at the beginning of the program before any data operations. - If a controller already exists (e.g., initialized by another process), this function - will retrieve the config from existing controller and initialize the TransferQueueClient. - In this case, the `conf` parameter will be ignored. + If a controller already exists, reuse it and only initialize the client; + the provided `conf` will be ignored in this case. Args: - conf: Optional configuration dictionary. If provided, it will be merged with - the default config from 'config.yaml'. This is only used for first-time - initializing. When connecting to an existing controller, this parameter - is ignored. + conf: Optional custom config merged with default `config.yaml`. + Only takes effect on first-time initialization, ignored when attaching + to an existing controller. Returns: The merged configuration dictionary. - Raises: - ValueError: If config is not valid or required configuration keys are missing. - Example: >>> # In process 0, node A >>> import transfer_queue as tq @@ -261,17 +255,15 @@ def init(conf: Optional[DictConfig] = None) -> Optional[DictConfig]: if _init_from_existing(): return conf - # First-time initialize TransferQueue logger.info("No TransferQueueController found. Starting first-time initialization...") - # create config final_conf = OmegaConf.create({}, flags={"allow_objects": True}) default_conf = OmegaConf.load(resources.files("transfer_queue") / "config.yaml") final_conf = OmegaConf.merge(final_conf, default_conf) if conf: final_conf = OmegaConf.merge(final_conf, conf) - # create controller + # TODO(hz): support load custom sampler class from external module. try: sampler = final_conf.controller.sampler if isinstance(sampler, BaseSampler): @@ -288,9 +280,8 @@ def init(conf: Optional[DictConfig] = None) -> Optional[DictConfig]: raise ValueError(f"Could not find sampler {final_conf.controller.sampler}") from None try: - # Ray will make sure actor with same name can only be created once - global _TRANSFER_QUEUE_CONTROLLER - _TRANSFER_QUEUE_CONTROLLER = TransferQueueController.options(name="TransferQueueController").remote( # type: ignore[attr-defined] + global _TQ_CONTROLLER + _TQ_CONTROLLER = TransferQueueController.options(name="TransferQueueController").remote( # type: ignore[attr-defined] sampler=sampler, polling_mode=final_conf.controller.polling_mode ) logger.info("TransferQueueController has been created.") @@ -299,18 +290,15 @@ def init(conf: Optional[DictConfig] = None) -> Optional[DictConfig]: _init_from_existing() return final_conf - controller_zmq_info = process_zmq_server_info(_TRANSFER_QUEUE_CONTROLLER) + controller_zmq_info = process_zmq_server_info(_TQ_CONTROLLER) final_conf.controller.zmq_info = controller_zmq_info - # create distributed storage backends - final_conf = _maybe_create_transferqueue_storage(final_conf) + final_conf = _maybe_create_tq_storage(final_conf) - # store the config into controller - ray.get(_TRANSFER_QUEUE_CONTROLLER.store_config.remote(final_conf)) + ray.get(_TQ_CONTROLLER.store_config.remote(final_conf)) logger.info(f"TransferQueue config: {final_conf}") - # create client - _maybe_create_transferqueue_client(final_conf) + _maybe_create_tq_client(final_conf) return final_conf @@ -325,13 +313,13 @@ def close(): Note: This function should be called when the TransferQueue system is no longer needed. """ - global _TRANSFER_QUEUE_CLIENT - global _TRANSFER_QUEUE_STORAGE - global _TRANSFER_QUEUE_CONTROLLER + global _TQ_CLIENT + global _TQ_STORAGE + global _TQ_CONTROLLER try: - if _TRANSFER_QUEUE_STORAGE: - for key, value in _TRANSFER_QUEUE_STORAGE.items(): + if _TQ_STORAGE: + for key, value in _TQ_STORAGE.items(): if key == "SimpleStorage": # only the process that do first-time init can clean the distributed storage for storage in value.values(): @@ -345,9 +333,9 @@ def close(): f"Consider manually killing the mooncake_master." ) - if _TRANSFER_QUEUE_CLIENT: + if _TQ_CLIENT: try: - ret = _TRANSFER_QUEUE_CLIENT.storage_manager.storage_client._store.remove_all() + ret = _TQ_CLIENT.storage_manager.storage_client._store.remove_all() if ret < 0: logger.error("Failed to remove existing keys in mooncake_master.") else: @@ -357,37 +345,31 @@ def close(): elif key == "Yuanrong": cleanup_yuanrong_resources(value) else: - logger.warning(f"close for _TRANSFER_QUEUE_STORAGE with key {key} is not supported for now.") + logger.warning(f"close for _TQ_STORAGE with key {key} is not supported for now.") - _TRANSFER_QUEUE_STORAGE = None + _TQ_STORAGE = None except Exception: pass - if _TRANSFER_QUEUE_CLIENT: - _TRANSFER_QUEUE_CLIENT.close() - _TRANSFER_QUEUE_CLIENT = None + if _TQ_CLIENT: + _TQ_CLIENT.close() + _TQ_CLIENT = None - if _TRANSFER_QUEUE_CONTROLLER: + if _TQ_CONTROLLER: try: - ray.kill(_TRANSFER_QUEUE_CONTROLLER) + ray.kill(_TQ_CONTROLLER) except Exception: pass - _TRANSFER_QUEUE_CONTROLLER = None - - try: - controller = ray.get_actor("TransferQueueController") - ray.kill(controller) - except Exception: - pass + _TQ_CONTROLLER = None # ==================== High-Level KV Interface API ==================== def kv_put( key: str, partition_id: str, - fields: Optional[TensorDict | dict[str, Any]] = None, - tag: Optional[dict[str, Any]] = None, - data_parser: Optional[Callable[[Any], Any]] = None, + fields: TensorDict | dict[str, Any] | None = None, + tag: dict[str, Any] | None = None, + data_parser: Callable[[Any], Any] | None = None, ) -> KVBatchMeta: """Put a single key-value pair to TransferQueue. @@ -440,7 +422,7 @@ def kv_put( if fields is None and tag is None: raise ValueError("Please provide at least one parameter of `fields` or `tag`.") - tq_client = _maybe_create_transferqueue_client() + tq_client = _maybe_create_tq_client() # 1. translate user-specified key to BatchMeta batch_meta = tq_client.kv_retrieve_meta(keys=[key], partition_id=partition_id, create=True) @@ -488,41 +470,36 @@ def kv_put( def kv_batch_put( keys: list[str], partition_id: str, - fields: Optional[TensorDict] = None, - tags: Optional[list[dict[str, Any]]] = None, - data_parser: Optional[Callable[[Any], Any]] = None, + fields: TensorDict | None = None, + tags: list[dict[str, Any]] | None = None, + data_parser: Callable[[Any], Any] | None = None, ) -> KVBatchMeta: - """Put multiple key-value pairs to TransferQueue in batch. + """Batch put multiple key-value pairs into the TransferQueue. - This method stores multiple key-value pairs in a single operation, which is more - efficient than calling kv_put multiple times. + This method stores multiple key-value entries in a single operation, + which is significantly more efficient than repeated calls to ``kv_put``. Args: - keys: List of user-specified keys for the data - partition_id: Logical partition to store the data in - fields: TensorDict containing data for all keys. Must have batch_size == len(keys). - If not provided, will only update the newly given tags to the keys. - tags: List of metadata tags, one for each key - data_parser: Optional callable to parse reference data (e.g., URLs) into real - content. The input is a slice of the `fields` parameter passed to - kv_put / kv_batch_put, in plain dict format (not TensorDict), - mapping field_name -> batched values. For a regular tensor column - the value is a batched tensor; for nested tensors (jagged or - strided) and NonTensorStack columns the values are extracted into - a list. It must modify values in-place based on the original keys; - do not add or remove keys. The number of elements per column must - also remain unchanged. Do not change the inner order of values - within each column. Only supported by SimpleStorage. + keys: List of user-defined unique keys for the data entries. + partition_id: Logical partition where the data will be stored. + fields: TensorDict containing batched data for all keys. Must have ``batch_size == len(keys)``. + If not provided, only the associated tags will be updated. + tags: List of metadata dictionaries, one per key. Length must match the number of keys. + data_parser: Optional callable to parse raw reference data (e.g., URLs) into real content + before storage. The input is a plain dict (not TensorDict) mapping field names to + batched values. The parser **must modify data in-place** without adding/removing + keys or changing element counts/order. Only supported by ``SimpleStorage`` backend. Returns: - KVBatchMeta: Metadata containing the keys, tags, partition_id, and fields. - The `fields` attribute includes all fields stored for these samples, - including any new fields written by this put operation. + KVBatchMeta: Metadata object containing stored keys, tags, partition ID, + and field information. The ``fields`` attribute includes all + persisted fields for the written samples. Raises: - ValueError: If neither `fields` nor `tags` is provided - ValueError: If length of `keys` doesn't match length of `tags` or the batch_size of `fields` TensorDict - RuntimeError: If retrieved BatchMeta size doesn't match length of `keys` + ValueError: When both ``fields`` and ``tags`` are empty. + ValueError: When ``fields`` batch size mismatches key count. + ValueError: When ``tags`` length mismatches key count. + RuntimeError: When retrieved metadata size mismatches input key count. Example: >>> import transfer_queue as tq @@ -535,54 +512,42 @@ def kv_batch_put( ... }, batch_size=3) >>> tags = [{"score": 0.9}, {"score": 0.85}, {"score": 0.95}] >>> meta = tq.kv_batch_put(keys=keys, partition_id="train", fields=fields, tags=tags) - >>> print(meta.fields) # ['input_ids', 'attention_mask'] + >>> print(meta.fields) """ + num_keys = len(keys) if fields is None and tags is None: raise ValueError("Please provide at least one parameter of fields or tag.") - if fields is not None and fields.batch_size[0] != len(keys): - raise ValueError( - f"`keys` with length {len(keys)} does not match the `fields` TensorDict with " - f"batch_size {fields.batch_size[0]}" - ) - - tq_client = _maybe_create_transferqueue_client() + if fields is not None and fields.batch_size[0] != num_keys: + raise ValueError(f"Length of `keys` ({num_keys}) does not match `fields` batch size ({fields.batch_size[0]}).") - # 1. translate user-specified key to BatchMeta + tq_client = _maybe_create_tq_client() batch_meta = tq_client.kv_retrieve_meta(keys=keys, partition_id=partition_id, create=True) - if batch_meta.size != len(keys): - raise RuntimeError( - f"Retrieved BatchMeta size {batch_meta.size} does not match with input `keys` size {len(keys)}!" - ) + if batch_meta.size != num_keys: + raise RuntimeError(f"Retrieved BatchMeta size {batch_meta.size} does not match input `keys` size {num_keys}.") - # 2. register the user-specified tags to BatchMeta if tags is not None: - if len(tags) != len(keys): - raise ValueError(f"keys with length {len(keys)} does not match length of tags {len(tags)}") + if len(tags) != num_keys: + raise ValueError(f"Length of `keys` ({num_keys}) does not match length of `tags` ({len(tags)}).") batch_meta.update_custom_meta(tags) - # 3. put data if fields is not None: - # After put, batch_meta.field_names will include the new fields written by user batch_meta = tq_client.put(fields, batch_meta, data_parser=data_parser) - else: - # Directly update custom_meta (tags) to controller + else: # tags is not None tq_client.set_custom_meta(batch_meta) - fields_to_return = batch_meta.field_names - return KVBatchMeta( keys=keys, tags=batch_meta.custom_meta, partition_id=partition_id, - fields=fields_to_return, + fields=batch_meta.field_names, extra_info=batch_meta.extra_info, ) -def kv_batch_get_by_meta(meta: KVBatchMeta, select_fields: Optional[list[str] | str] = None) -> TensorDict: +def kv_batch_get_by_meta(meta: KVBatchMeta, select_fields: list[str] | str | None = None) -> TensorDict: """Get data from TransferQueue using KVBatchMeta. This is a convenience method for retrieving data using KVBatchMeta returned @@ -622,7 +587,7 @@ def kv_batch_get_by_meta(meta: KVBatchMeta, select_fields: Optional[list[str] | raise ValueError("Must provide partition_id in the input KVBatchMeta.") if select_fields is not None: if isinstance(select_fields, str): - fields_to_fetch: Optional[list[str]] = [select_fields] + fields_to_fetch: list[str] | None = [select_fields] else: fields_to_fetch = select_fields @@ -637,9 +602,7 @@ def kv_batch_get_by_meta(meta: KVBatchMeta, select_fields: Optional[list[str] | return kv_batch_get(keys=meta.keys, partition_id=meta.partition_id, select_fields=fields_to_fetch) -def kv_batch_get( - keys: list[str] | str, partition_id: str, select_fields: Optional[list[str] | str] = None -) -> TensorDict: +def kv_batch_get(keys: list[str] | str, partition_id: str, select_fields: list[str] | str | None = None) -> TensorDict: """Get data from TransferQueue using user-specified keys. This is a convenience method for retrieving data using keys instead of indexes. @@ -668,7 +631,7 @@ def kv_batch_get( ... select_fields="input_ids" ... ) """ - tq_client = _maybe_create_transferqueue_client() + tq_client = _maybe_create_tq_client() batch_meta = tq_client.kv_retrieve_meta(keys=keys, partition_id=partition_id, create=False) @@ -690,7 +653,7 @@ def kv_batch_get( return data -def kv_list(partition_id: Optional[str] = None) -> dict[str, dict[str, Any]]: +def kv_list(partition_id: str | None = None) -> dict[str, dict[str, Any]]: """List all keys and their metadata in one or all partitions. Args: @@ -724,7 +687,7 @@ def kv_list(partition_id: Optional[str] = None) -> dict[str, dict[str, Any]]: >>> for pid, keys in all_partitions.items(): >>> print(f"Partition: {pid}, Key count: {len(keys)}") """ - tq_client = _maybe_create_transferqueue_client() + tq_client = _maybe_create_tq_client() partition_info = tq_client.kv_list(partition_id) @@ -753,7 +716,7 @@ def kv_clear(keys: list[str] | str, partition_id: str) -> None: if isinstance(keys, str): keys = [keys] - tq_client = _maybe_create_transferqueue_client() + tq_client = _maybe_create_tq_client() batch_meta = tq_client.kv_retrieve_meta(keys=keys, partition_id=partition_id, create=False) if batch_meta.size > 0: @@ -764,9 +727,9 @@ def kv_clear(keys: list[str] | str, partition_id: str) -> None: async def async_kv_put( key: str, partition_id: str, - fields: Optional[TensorDict | dict[str, Any]] = None, - tag: Optional[dict[str, Any]] = None, - data_parser: Optional[Callable[[Any], Any]] = None, + fields: TensorDict | dict[str, Any] | None = None, + tag: dict[str, Any] | None = None, + data_parser: Callable[[Any], Any] | None = None, ) -> KVBatchMeta: """Asynchronously put a single key-value pair to TransferQueue. @@ -820,7 +783,7 @@ async def async_kv_put( if fields is None and tag is None: raise ValueError("Please provide at least one parameter of fields or tag.") - tq_client = _maybe_create_transferqueue_client() + tq_client = _maybe_create_tq_client() # 1. translate user-specified key to BatchMeta batch_meta = await tq_client.async_kv_retrieve_meta(keys=[key], partition_id=partition_id, create=True) @@ -868,9 +831,9 @@ async def async_kv_put( async def async_kv_batch_put( keys: list[str], partition_id: str, - fields: Optional[TensorDict] = None, - tags: Optional[list[dict[str, Any]]] = None, - data_parser: Optional[Callable[[Any], Any]] = None, + fields: TensorDict | None = None, + tags: list[dict[str, Any]] | None = None, + data_parser: Callable[[Any], Any] | None = None, ) -> KVBatchMeta: """Asynchronously put multiple key-value pairs to TransferQueue in batch. @@ -926,7 +889,7 @@ async def async_kv_batch_put( f"batch_size {fields.batch_size[0]}" ) - tq_client = _maybe_create_transferqueue_client() + tq_client = _maybe_create_tq_client() # 1. translate user-specified key to BatchMeta batch_meta = await tq_client.async_kv_retrieve_meta(keys=keys, partition_id=partition_id, create=True) @@ -961,7 +924,7 @@ async def async_kv_batch_put( ) -async def async_kv_batch_get_by_meta(meta: KVBatchMeta, select_fields: Optional[list[str] | str] = None) -> TensorDict: +async def async_kv_batch_get_by_meta(meta: KVBatchMeta, select_fields: list[str] | str | None = None) -> TensorDict: """Asynchronously get data from TransferQueue using KVBatchMeta. This is a convenience method for retrieving data using KVBatchMeta returned @@ -1019,7 +982,7 @@ async def async_kv_batch_get_by_meta(meta: KVBatchMeta, select_fields: Optional[ async def async_kv_batch_get( - keys: list[str] | str, partition_id: str, select_fields: Optional[list[str] | str] = None + keys: list[str] | str, partition_id: str, select_fields: list[str] | str | None = None ) -> TensorDict: """Asynchronously get data from TransferQueue using user-specified keys. @@ -1049,7 +1012,7 @@ async def async_kv_batch_get( ... select_fields="input_ids" ... ) """ - tq_client = _maybe_create_transferqueue_client() + tq_client = _maybe_create_tq_client() batch_meta = await tq_client.async_kv_retrieve_meta(keys=keys, partition_id=partition_id, create=False) @@ -1070,7 +1033,7 @@ async def async_kv_batch_get( return data -async def async_kv_list(partition_id: Optional[str] = None) -> dict[str, dict[str, Any]]: +async def async_kv_list(partition_id: str | None = None) -> dict[str, dict[str, Any]]: """Asynchronously list all keys and their metadata in one or all partitions. Args: @@ -1105,7 +1068,7 @@ async def async_kv_list(partition_id: Optional[str] = None) -> dict[str, dict[st >>> for pid, keys in all_partitions.items(): >>> print(f"Partition: {pid}, Key count: {len(keys)}") """ - tq_client = _maybe_create_transferqueue_client() + tq_client = _maybe_create_tq_client() partition_info = await tq_client.async_kv_list(partition_id) @@ -1134,7 +1097,7 @@ async def async_kv_clear(keys: list[str] | str, partition_id: str) -> None: if isinstance(keys, str): keys = [keys] - tq_client = _maybe_create_transferqueue_client() + tq_client = _maybe_create_tq_client() batch_meta = await tq_client.async_kv_retrieve_meta(keys=keys, partition_id=partition_id, create=False) if batch_meta.size > 0: @@ -1145,6 +1108,5 @@ async def async_kv_clear(keys: list[str] | str, partition_id: str) -> None: # For low-level API support, please refer to transfer_queue/client.py for details. def get_client(): """Get a TransferQueueClient for using low-level API""" - if _TRANSFER_QUEUE_CLIENT is None: - raise RuntimeError("Please initialize the TransferQueue first by calling `tq.init()`!") - return _TRANSFER_QUEUE_CLIENT + assert _TQ_CLIENT is not None, "Please initialize the TransferQueue first by calling `tq.init()`!" + return _TQ_CLIENT diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 03784995..05cf3241 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -19,7 +19,7 @@ from collections import defaultdict from dataclasses import dataclass from types import MappingProxyType -from typing import Any, Optional +from typing import Any import numpy as np import torch @@ -30,11 +30,6 @@ logger = get_logger(__name__) -# --------------------------------------------------------------------------- -# Internal helpers -# --------------------------------------------------------------------------- - - def _extra_info_values_equal(a: Any, b: Any) -> bool: """Compare two extra_info values for equality. @@ -55,7 +50,7 @@ def _extra_info_values_equal(a: Any, b: Any) -> bool: class _SampleView: """Lazy read-only view of a single sample row in a columnar BatchMeta. - All returned dicts are ``MappingProxyType`` – attempts to mutate them + All returned dicts are ``MappingProxyType``, and attempts to mutate them raise ``TypeError``, making it obvious that this is a snapshot view. """ @@ -226,11 +221,11 @@ def __init__( self, global_indexes: list[int], partition_ids: list[str], - field_schema: Optional[dict[str, dict[str, Any]]] = None, - production_status: Optional[np.ndarray] = None, - extra_info: Optional[dict[str, Any]] = None, - custom_meta: Optional[list[dict[str, Any]]] = None, - _custom_backend_meta: Optional[list[dict[str, Any]]] = None, + field_schema: dict[str, dict[str, Any]] | None = None, + production_status: np.ndarray | None = None, + extra_info: dict[str, Any] | None = None, + custom_meta: list[dict[str, Any]] | None = None, + _custom_backend_meta: list[dict[str, Any]] | None = None, ) -> None: if field_schema is None: field_schema = {} @@ -379,7 +374,6 @@ def get_shapes(self, field_name: str) -> list: return [meta.get("shape")] * self.size # ==================== Extra Info Methods ==================== - def get_extra_info(self, key: str, default: Any = None) -> Any: """Get extra info by key""" return self.extra_info.get(key, default) @@ -417,7 +411,6 @@ def has_extra_info(self, key: str) -> bool: return key in self.extra_info # ==================== Custom Meta Methods (User Layer) ==================== - def get_all_custom_meta(self) -> list[dict[str, Any]]: """Get all custom_meta as a list of dictionary (one per sample, in global_indexes order). @@ -451,7 +444,6 @@ def clear_custom_meta(self) -> None: self.custom_meta = [{} for _ in range(self.size)] # ==================== Core BatchMeta Operations ==================== - def add_fields(self, tensor_dict: TensorDict, set_all_ready: bool = True) -> "BatchMeta": """Add new fields from a TensorDict to all samples in this batch. This modifies the batch in-place to include the new fields. @@ -793,7 +785,7 @@ def reorder(self, indices: list[int]): self._custom_backend_meta = [self._custom_backend_meta[i] for i in indices] @classmethod - def empty(cls, extra_info: Optional[dict[str, Any]] = None) -> "BatchMeta": + def empty(cls, extra_info: dict[str, Any] | None = None) -> "BatchMeta": """Create an empty BatchMeta with no samples. Args: @@ -836,13 +828,13 @@ class KVBatchMeta: tags: list[dict] = dataclasses.field(default_factory=list) # [optional] partition_id of this batch - partition_id: Optional[str] = None + partition_id: str | None = None # [optional] fields of each sample - fields: Optional[list[str]] = None + fields: list[str] | None = None # [optional] external information for batch-level information - extra_info: Optional[dict[str, Any]] = dataclasses.field(default_factory=dict) + extra_info: dict[str, Any] | None = dataclasses.field(default_factory=dict) def __post_init__(self): """Validate all the variables""" diff --git a/transfer_queue/sampler/base.py b/transfer_queue/sampler/base.py index 766afd62..c831ecba 100644 --- a/transfer_queue/sampler/base.py +++ b/transfer_queue/sampler/base.py @@ -14,7 +14,7 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any class BaseSampler(ABC): @@ -79,7 +79,7 @@ def has_cached_result( self, partition_id: str, task_name: str, - sampling_config: Optional[dict[str, Any]] = None, + sampling_config: dict[str, Any] | None = None, ) -> bool: """Check whether the sampler has a cached sampling result for the given context. diff --git a/transfer_queue/storage/__init__.py b/transfer_queue/storage/__init__.py index 6fbd3415..2fb1be46 100644 --- a/transfer_queue/storage/__init__.py +++ b/transfer_queue/storage/__init__.py @@ -17,17 +17,17 @@ AsyncSimpleStorageManager, MooncakeStorageManager, RayStorageManager, - TransferQueueStorageManager, - TransferQueueStorageManagerFactory, + StorageManager, + StorageManagerFactory, YuanrongStorageManager, ) -from .simple_backend import SimpleStorageUnit, StorageUnitData +from .simple_storage import SimpleStorageUnit, StorageUnitData __all__ = [ "SimpleStorageUnit", "StorageUnitData", - "TransferQueueStorageManager", - "TransferQueueStorageManagerFactory", + "StorageManager", + "StorageManagerFactory", "AsyncSimpleStorageManager", "MooncakeStorageManager", "YuanrongStorageManager", diff --git a/transfer_queue/storage/clients/__init__.py b/transfer_queue/storage/clients/__init__.py index 2b861166..d54a32ef 100644 --- a/transfer_queue/storage/clients/__init__.py +++ b/transfer_queue/storage/clients/__init__.py @@ -14,14 +14,13 @@ # limitations under the License. # This module is currently empty but reserved for future client implementations -from .base import TransferQueueStorageKVClient -from .factory import StorageClientFactory +from .base import StorageClientFactory, StorageKVClient from .mooncake_client import MooncakeStoreClient from .ray_storage_client import RayStorageClient from .yuanrong_client import YuanrongStorageClient __all__ = [ - "TransferQueueStorageKVClient", + "StorageKVClient", "StorageClientFactory", "RayStorageClient", "MooncakeStoreClient", diff --git a/transfer_queue/storage/clients/base.py b/transfer_queue/storage/clients/base.py index a6d63f6a..14b73e72 100644 --- a/transfer_queue/storage/clients/base.py +++ b/transfer_queue/storage/clients/base.py @@ -14,10 +14,10 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any -class TransferQueueStorageKVClient(ABC): +class StorageKVClient(ABC): """ Abstract base class for storage client. Subclasses must implement the core methods: put, get, and clear. @@ -32,7 +32,7 @@ def __init__(self, config: dict[str, Any]): self.config = config @abstractmethod - def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]: + def put(self, keys: list[str], values: list[Any]) -> list[Any] | None: """ Store key-value pairs in the storage backend. Args: @@ -68,3 +68,44 @@ def get(self, keys: list[str], shapes=None, dtypes=None, custom_backend_meta=Non def clear(self, keys: list[str], custom_backend_meta=None) -> None: """Clear key-value pairs in the storage backend.""" raise NotImplementedError("Subclasses must implement clear") + + +class StorageClientFactory: + """ + Factory class for creating storage client instances. + Uses a decorator-based registration mechanism to map client names to classes. + """ + + # Class variable: maps client names to their corresponding classes + _registry: dict[str, type[StorageKVClient]] = {} + + @classmethod + def register(cls, client_type: str): + """ + Decorator to register a concrete client class with the factory. + Args: + client_type (str): The name used to identify the client + Returns: + Callable: The decorator function that returns the original class + """ + + def decorator(client_class: type[StorageKVClient]) -> type[StorageKVClient]: + cls._registry[client_type] = client_class + return client_class + + return decorator + + @classmethod + def create(cls, client_type: str, config: dict) -> StorageKVClient: + """ + Create and return an instance of the storage client by name. + Args: + client_type (str): The registered name of the client + Returns: + StorageClientFactory: An instance of the requested client + Raises: + ValueError: If no client is registered with the given name + """ + if client_type not in cls._registry: + raise ValueError(f"Unknown StorageClient: {client_type}") + return cls._registry[client_type](config) diff --git a/transfer_queue/storage/clients/factory.py b/transfer_queue/storage/clients/factory.py deleted file mode 100644 index 73979bd5..00000000 --- a/transfer_queue/storage/clients/factory.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2025 The TransferQueue Team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from transfer_queue.storage.clients.base import TransferQueueStorageKVClient - - -class StorageClientFactory: - """ - Factory class for creating storage client instances. - Uses a decorator-based registration mechanism to map client names to classes. - """ - - # Class variable: maps client names to their corresponding classes - _registry: dict[str, type[TransferQueueStorageKVClient]] = {} - - @classmethod - def register(cls, client_type: str): - """ - Decorator to register a concrete client class with the factory. - Args: - client_type (str): The name used to identify the client - Returns: - Callable: The decorator function that returns the original class - """ - - def decorator(client_class: type[TransferQueueStorageKVClient]) -> type[TransferQueueStorageKVClient]: - cls._registry[client_type] = client_class - return client_class - - return decorator - - @classmethod - def create(cls, client_type: str, config: dict) -> TransferQueueStorageKVClient: - """ - Create and return an instance of the storage client by name. - Args: - client_type (str): The registered name of the client - Returns: - StorageClientFactory: An instance of the requested client - Raises: - ValueError: If no client is registered with the given name - """ - if client_type not in cls._registry: - raise ValueError(f"Unknown StorageClient: {client_type}") - return cls._registry[client_type](config) diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index 8f311b4d..28706a39 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -15,13 +15,12 @@ import pickle from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Any, Optional +from typing import Any import torch from torch import Tensor -from transfer_queue.storage.clients.base import TransferQueueStorageKVClient -from transfer_queue.storage.clients.factory import StorageClientFactory +from transfer_queue.storage.clients.base import StorageClientFactory, StorageKVClient from transfer_queue.utils.logging_utils import get_logger from transfer_queue.utils.tensor_utils import allocate_empty_tensors, get_nbytes, merge_contiguous_memory @@ -39,7 +38,7 @@ @StorageClientFactory.register("MooncakeStoreClient") -class MooncakeStoreClient(TransferQueueStorageKVClient): +class MooncakeStoreClient(StorageKVClient): """ Storage client for MooncakeStore. """ @@ -64,9 +63,9 @@ def __init__(self, config: dict[str, Any]): self.device_name = "" if self.local_hostname is None or self.local_hostname == "": - from transfer_queue.utils.zmq_utils import get_node_ip_address_raw + from transfer_queue.utils.zmq_utils import get_node_ip_address - ip = get_node_ip_address_raw() + ip = get_node_ip_address() logger.info(f"Try to use Ray IP ({ip}) as local hostname for MooncakeStore.") self.local_hostname = ip @@ -169,9 +168,9 @@ def _put_bytes_thread_worker(self, batch_keys: list[str], batch_values: list[Any def get( self, keys: list[str], - shapes: Optional[list[Any]] = None, - dtypes: Optional[list[Any]] = None, - custom_backend_meta: Optional[list[str]] = None, + shapes: list[Any] | None = None, + dtypes: list[Any] | None = None, + custom_backend_meta: list[str] | None = None, ) -> list[Any]: """Get multiple key-value pairs from MooncakeStore. @@ -259,7 +258,7 @@ def _get_bytes_thread_worker(self, batch_keys: list[str], indexes: list[int]) -> return results, indexes - def clear(self, keys: list[str], custom_backend_meta: Optional[list[Any]] = None) -> None: + def clear(self, keys: list[str], custom_backend_meta: list[Any] | None = None) -> None: """Deletes multiple keys from MooncakeStore. Args: diff --git a/transfer_queue/storage/clients/ray_storage_client.py b/transfer_queue/storage/clients/ray_storage_client.py index c290f6f2..c85fa438 100644 --- a/transfer_queue/storage/clients/ray_storage_client.py +++ b/transfer_queue/storage/clients/ray_storage_client.py @@ -14,13 +14,12 @@ # limitations under the License. import itertools -from typing import Any, Optional +from typing import Any import ray import torch -from transfer_queue.storage.clients.base import TransferQueueStorageKVClient -from transfer_queue.storage.clients.factory import StorageClientFactory +from transfer_queue.storage.clients.base import StorageClientFactory, StorageKVClient @ray.remote(max_concurrency=8) @@ -47,7 +46,7 @@ def clear_obj_ref(self, keys: list[str]): @StorageClientFactory.register("RayStorageClient") -class RayStorageClient(TransferQueueStorageKVClient): +class RayStorageClient(StorageKVClient): """ Storage client for Ray RDT. """ @@ -62,7 +61,7 @@ def __init__(self, config=None): except ValueError: self.storage_actor = RayObjectRefStorage.options(name="RayObjectRefStorage", get_if_exists=False).remote() - def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]: + def put(self, keys: list[str], values: list[Any]) -> list[Any] | None: """ Store tensors to remote storage. Args: diff --git a/transfer_queue/storage/clients/yuanrong_client.py b/transfer_queue/storage/clients/yuanrong_client.py index 9aeffdc1..399cb58a 100644 --- a/transfer_queue/storage/clients/yuanrong_client.py +++ b/transfer_queue/storage/clients/yuanrong_client.py @@ -21,8 +21,7 @@ import torch from torch import Tensor -from transfer_queue.storage.clients.base import TransferQueueStorageKVClient -from transfer_queue.storage.clients.factory import StorageClientFactory +from transfer_queue.storage.clients.base import StorageClientFactory, StorageKVClient from transfer_queue.utils.logging_utils import get_logger from transfer_queue.utils.serial_utils import _decoder, _encoder from transfer_queue.utils.yuanrong_utils import find_reachable_host @@ -63,7 +62,7 @@ def supports_get(self, strategy_tag: Any) -> bool: """Check if this strategy can retrieve data with given tag.""" @abstractmethod - def get(self, keys: list[str], **kwargs) -> list[Optional[Any]]: + def get(self, keys: list[str], **kwargs) -> list[Any | None]: """Retrieve values by keys; kwargs may include shapes/dtypes.""" @abstractmethod @@ -145,7 +144,7 @@ def supports_get(self, strategy_tag: str) -> bool: """Matches 'DsTensorClient' Strategy tag.""" return isinstance(strategy_tag, str) and strategy_tag == self.strategy_tag() - def get(self, keys: list[str], **kwargs) -> list[Optional[Any]]: + def get(self, keys: list[str], **kwargs) -> list[Any | None]: """Fetch NPU tensors using pre-allocated empty buffers.""" shapes = kwargs.get("shapes", None) dtypes = kwargs.get("dtypes", None) @@ -252,7 +251,7 @@ def supports_get(self, strategy_tag: str) -> bool: """Matches 'KVClient' strategy tag.""" return isinstance(strategy_tag, str) and strategy_tag == self.strategy_tag() - def get(self, keys: list[str], **kwargs) -> list[Optional[Any]]: + def get(self, keys: list[str], **kwargs) -> list[Any | None]: """Retrieve and deserialize objects in batches.""" results = [] for i in range(0, len(keys), self.GET_CLEAR_KEYS_LIMIT): @@ -365,7 +364,7 @@ def mget_zero_copy(self, keys: list[str]) -> list[Any]: @StorageClientFactory.register("YuanrongStorageClient") -class YuanrongStorageClient(TransferQueueStorageKVClient): +class YuanrongStorageClient(StorageKVClient): """ Storage client for YuanRong DataSystem. @@ -434,9 +433,9 @@ def put_task(strategy, indexes): def get( self, keys: list[str], - shapes: Optional[list[Any]] = None, - dtypes: Optional[list[Any]] = None, - custom_backend_meta: Optional[list[str]] = None, + shapes: list[Any] | None = None, + dtypes: list[Any] | None = None, + custom_backend_meta: list[str] | None = None, ) -> list[Any]: """Retrieves multiple values from remote storage with expected metadata. @@ -476,7 +475,7 @@ def get_task(strategy, indexes): results[original_index] = value return results - def clear(self, keys: list[str], custom_backend_meta: Optional[list[str]] = None) -> None: + def clear(self, keys: list[str], custom_backend_meta: list[str] | None = None) -> None: """Deletes multiple keys from remote storage. Args: diff --git a/transfer_queue/storage/managers/__init__.py b/transfer_queue/storage/managers/__init__.py index 77954bb3..d5e71abd 100644 --- a/transfer_queue/storage/managers/__init__.py +++ b/transfer_queue/storage/managers/__init__.py @@ -13,16 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .base import TransferQueueStorageManager -from .factory import TransferQueueStorageManagerFactory +from .base import StorageManager, StorageManagerFactory from .mooncake_manager import MooncakeStorageManager from .ray_storage_manager import RayStorageManager -from .simple_backend_manager import AsyncSimpleStorageManager +from .simple_storage_manager import AsyncSimpleStorageManager from .yuanrong_manager import YuanrongStorageManager __all__ = [ - "TransferQueueStorageManager", - "TransferQueueStorageManagerFactory", + "StorageManager", + "StorageManagerFactory", "AsyncSimpleStorageManager", "YuanrongStorageManager", "MooncakeStorageManager", diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 55901e31..1a5a8275 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -20,7 +20,7 @@ import weakref from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor -from typing import Any, Callable, Optional +from typing import Any, Callable from uuid import uuid4 import ray @@ -32,7 +32,7 @@ from torch import Tensor from transfer_queue.metadata import BatchMeta, extract_field_schema -from transfer_queue.storage.clients.factory import StorageClientFactory +from transfer_queue.storage.clients.base import StorageClientFactory from transfer_queue.utils.logging_utils import get_logger from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo, create_zmq_socket @@ -49,7 +49,7 @@ LIMIT_THREADS_PER_MANAGER_IN_RAY_ACTOR = 4 -class TransferQueueStorageManager(ABC): +class StorageManager(ABC): """Base class for storage layer. It defines the interface for data operations and generally provides handshake & notification capabilities.""" @@ -59,9 +59,9 @@ def __init__(self, controller_info: ZMQServerInfo, config: DictConfig): self.controller_info = controller_info # Handshake socket is sync (used only during initialization) - self.controller_handshake_socket: Optional[zmq.Socket] = None + self.controller_handshake_socket: zmq.Socket | None = None - self.zmq_context: Optional[zmq.asyncio.Context] = None + self.zmq_context: zmq.asyncio.Context | None = None self._connect_to_controller() def _connect_to_controller(self) -> None: @@ -189,7 +189,7 @@ async def notify_data_update( partition_id: str, global_indexes: list[int], field_schema: dict[str, dict[str, Any]], - custom_backend_meta: Optional[dict[int, dict[str, Any]]] = None, + custom_backend_meta: dict[int, dict[str, Any]] | None = None, ) -> None: """ Notify controller that new data is ready. @@ -300,7 +300,7 @@ async def notify_data_update( @abstractmethod async def put_data( - self, data: TensorDict, metadata: BatchMeta, data_parser: Optional[Callable[[Any], Any]] = None + self, data: TensorDict, metadata: BatchMeta, data_parser: Callable[[Any], Any] | None = None ) -> None: """ Put data into the storage backend. @@ -364,11 +364,36 @@ def __del__(self): logger.error(f"[{self.storage_manager_id}]: Exception during __del__: {str(e)}") -from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory # noqa: E402 +class StorageManagerFactory: + """Factory that creates a StorageManager instance.""" + _registry: dict[str, type[StorageManager]] = {} -@TransferQueueStorageManagerFactory.register("KVStorageManager") -class KVStorageManager(TransferQueueStorageManager): + @classmethod + def register(cls, manager_type: str): + """Register a StorageManager class.""" + + def decorator(manager_cls: type[StorageManager]): + if not issubclass(manager_cls, StorageManager): + raise TypeError( + f"manager_cls {getattr(manager_cls, '__name__', repr(manager_cls))} must be " + f"a subclass of StorageManager" + ) + cls._registry[manager_type] = manager_cls + return manager_cls + + return decorator + + @classmethod + def create(cls, manager_type: str, controller_info: ZMQServerInfo, config: dict[str, Any]) -> StorageManager: + """Create and return a StorageManager instance.""" + assert manager_type in cls._registry, ( + f"Unknown manager_type: {manager_type}. Supported managers include: {list(cls._registry.keys())}" + ) + return cls._registry[manager_type](controller_info, config) + + +class KVStorageManager(StorageManager): """ A storage manager that uses a key-value (KV) backend (e.g., YuanRong) to store and retrieve tensor data. It maps structured metadata (BatchMeta) to flat lists of keys and values for efficient KV operations. @@ -383,7 +408,7 @@ def __init__(self, controller_info: ZMQServerInfo, config: dict[str, Any]): raise ValueError("Missing client_name in config") super().__init__(controller_info, config) self.storage_client = StorageClientFactory.create(client_name, config) - self._multi_threads_executor: Optional[ThreadPoolExecutor] = None + self._multi_threads_executor: ThreadPoolExecutor | None = None self._executor_finalizer = weakref.finalize(self, self._shutdown_executor, self._multi_threads_executor) @staticmethod @@ -428,7 +453,7 @@ def _generate_values(data: TensorDict) -> list[Any]: return results @staticmethod - def _shutdown_executor(thread_executor: Optional[ThreadPoolExecutor]) -> None: + def _shutdown_executor(thread_executor: ThreadPoolExecutor | None) -> None: """ A static method to ensure no strong reference to 'self' is held within the finalizer's callback, enabling proper garbage collection. @@ -569,7 +594,7 @@ def _get_shape_type_custom_backend_meta_list( return shapes, dtypes, custom_backend_meta_list async def put_data( - self, data: TensorDict, metadata: BatchMeta, data_parser: Optional[Callable[[Any], Any]] = None + self, data: TensorDict, metadata: BatchMeta, data_parser: Callable[[Any], Any] | None = None ) -> None: """ Store tensor data in the backend storage and notify the controller. diff --git a/transfer_queue/storage/managers/factory.py b/transfer_queue/storage/managers/factory.py deleted file mode 100644 index 04d2cd03..00000000 --- a/transfer_queue/storage/managers/factory.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2025 The TransferQueue Team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings -from typing import Any - -from transfer_queue.storage.managers.base import TransferQueueStorageManager -from transfer_queue.utils.zmq_utils import ZMQServerInfo - - -class TransferQueueStorageManagerFactory: - """Factory that creates a StorageManager instance.""" - - _registry: dict[str, type[TransferQueueStorageManager]] = {} - - @classmethod - def register(cls, manager_type: str): - """Register a TransferQueueStorageManager class.""" - - def decorator(manager_cls: type[TransferQueueStorageManager]): - if not issubclass(manager_cls, TransferQueueStorageManager): - raise TypeError( - f"manager_cls {getattr(manager_cls, '__name__', repr(manager_cls))} must be " - f"a subclass of TransferQueueStorageManager" - ) - cls._registry[manager_type] = manager_cls - return manager_cls - - return decorator - - @classmethod - def create( - cls, manager_type: str, controller_info: ZMQServerInfo, config: dict[str, Any] - ) -> TransferQueueStorageManager: - """Create and return a TransferQueueStorageManager instance.""" - if manager_type not in cls._registry: - if manager_type == "AsyncSimpleStorageManager": - warnings.warn( - f"The manager_type {manager_type} will be deprecated in 0.1.7, please use SimpleStorage instead.", - category=DeprecationWarning, - stacklevel=2, - ) - manager_type = "SimpleStorage" - elif manager_type == "MooncakeStorageManager": - warnings.warn( - f"The manager_type {manager_type} will be deprecated in 0.1.7, please use MooncakeStore instead.", - category=DeprecationWarning, - stacklevel=2, - ) - manager_type = "MooncakeStore" - elif manager_type == "YuanrongStorageManager": - warnings.warn( - f"The manager_type {manager_type} will be deprecated in 0.1.7, please use Yuanrong instead.", - category=DeprecationWarning, - stacklevel=2, - ) - manager_type = "Yuanrong" - else: - raise ValueError( - f"Unknown manager_type: {manager_type}. Supported managers include: {list(cls._registry.keys())}" - ) - return cls._registry[manager_type](controller_info, config) diff --git a/transfer_queue/storage/managers/mooncake_manager.py b/transfer_queue/storage/managers/mooncake_manager.py index 7b43ca71..c1a354c9 100644 --- a/transfer_queue/storage/managers/mooncake_manager.py +++ b/transfer_queue/storage/managers/mooncake_manager.py @@ -15,15 +15,14 @@ from typing import Any -from transfer_queue.storage.managers.base import KVStorageManager -from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory +from transfer_queue.storage.managers.base import KVStorageManager, StorageManagerFactory from transfer_queue.utils.logging_utils import get_logger from transfer_queue.utils.zmq_utils import ZMQServerInfo logger = get_logger(__name__) -@TransferQueueStorageManagerFactory.register("MooncakeStore") +@StorageManagerFactory.register("MooncakeStore") class MooncakeStorageManager(KVStorageManager): """Storage manager for MooncakeStorage backend.""" diff --git a/transfer_queue/storage/managers/ray_storage_manager.py b/transfer_queue/storage/managers/ray_storage_manager.py index 78069c21..0cc2a09c 100644 --- a/transfer_queue/storage/managers/ray_storage_manager.py +++ b/transfer_queue/storage/managers/ray_storage_manager.py @@ -15,12 +15,11 @@ from typing import Any -from transfer_queue.storage.managers.base import KVStorageManager -from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory +from transfer_queue.storage.managers.base import KVStorageManager, StorageManagerFactory from transfer_queue.utils.zmq_utils import ZMQServerInfo -@TransferQueueStorageManagerFactory.register("RayStore") +@StorageManagerFactory.register("RayStore") class RayStorageManager(KVStorageManager): """Storage manager for Ray-RDT backend.""" diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_storage_manager.py similarity index 98% rename from transfer_queue/storage/managers/simple_backend_manager.py rename to transfer_queue/storage/managers/simple_storage_manager.py index 0771638b..184dcfdc 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_storage_manager.py @@ -27,8 +27,7 @@ from tensordict import NonTensorStack, TensorDict from transfer_queue.metadata import BatchMeta, extract_field_schema -from transfer_queue.storage.managers.base import TransferQueueStorageManager -from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory +from transfer_queue.storage.managers.base import StorageManager, StorageManagerFactory from transfer_queue.utils.logging_utils import get_logger from transfer_queue.utils.zmq_utils import ( ZMQMessage, @@ -58,8 +57,8 @@ class RoutingGroup(NamedTuple): batch_positions: list[int] # corresponding positions in the original batch -@TransferQueueStorageManagerFactory.register("SimpleStorage") -class AsyncSimpleStorageManager(TransferQueueStorageManager): +@StorageManagerFactory.register("SimpleStorage") +class AsyncSimpleStorageManager(StorageManager): """Asynchronous storage manager that handles multiple storage units. This manager provides async put/get/clear operations across multiple SimpleStorageUnit diff --git a/transfer_queue/storage/managers/yuanrong_manager.py b/transfer_queue/storage/managers/yuanrong_manager.py index 4c37cc73..f76b47b2 100644 --- a/transfer_queue/storage/managers/yuanrong_manager.py +++ b/transfer_queue/storage/managers/yuanrong_manager.py @@ -15,15 +15,14 @@ from typing import Any -from transfer_queue.storage.managers.base import KVStorageManager -from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory +from transfer_queue.storage.managers.base import KVStorageManager, StorageManagerFactory from transfer_queue.utils.logging_utils import get_logger from transfer_queue.utils.zmq_utils import ZMQServerInfo logger = get_logger(__name__) -@TransferQueueStorageManagerFactory.register("Yuanrong") +@StorageManagerFactory.register("Yuanrong") class YuanrongStorageManager(KVStorageManager): """Storage manager for Yuanrong backend.""" diff --git a/transfer_queue/storage/simple_backend.py b/transfer_queue/storage/simple_storage.py similarity index 96% rename from transfer_queue/storage/simple_backend.py rename to transfer_queue/storage/simple_storage.py index bf334efc..91b44935 100644 --- a/transfer_queue/storage/simple_backend.py +++ b/transfer_queue/storage/simple_storage.py @@ -17,14 +17,14 @@ import time import weakref from threading import Event, Thread -from typing import Any, Optional +from typing import Any from uuid import uuid4 import ray import zmq from transfer_queue.utils.common import limit_pytorch_auto_parallel_threads -from transfer_queue.utils.enum_utils import TransferQueueRole +from transfer_queue.utils.enum_utils import Role from transfer_queue.utils.logging_utils import get_logger from transfer_queue.utils.perf_utils import IntervalPerfMonitor from transfer_queue.utils.zmq_utils import ( @@ -34,7 +34,7 @@ create_zmq_socket, format_zmq_address, get_free_port, - get_node_ip_address_raw, + get_node_ip_address, ) logger = get_logger(__name__) @@ -129,7 +129,7 @@ class SimpleStorageUnit: """A storage unit that provides distributed data storage functionality. This class represents a storage unit that can store data in a 2D structure - (samples × data fields) and provides ZMQ-based communication for put/get/clear operations. + (samples, data_fields) and provides ZMQ-based communication for put/get/clear operations. Note: We use Ray decorator (@ray.remote) only for initialization purposes. We do NOT use Ray's .remote() call capabilities - the storage unit runs @@ -160,10 +160,10 @@ def __init__(self, storage_unit_size: int): self._shutdown_event = Event() # Placeholder for zmq_context, proxy_thread and worker_threads - self.zmq_context: Optional[zmq.Context] = None - self.put_get_socket: Optional[zmq.Socket] = None - self.proxy_thread: Optional[Thread] = None - self.worker_thread: Optional[Thread] = None + self.zmq_context: zmq.Context | None = None + self.put_get_socket: zmq.Socket | None = None + self.proxy_thread: Thread | None = None + self.worker_thread: Thread | None = None self._init_zmq_socket() self._start_process_put_get() @@ -186,7 +186,7 @@ def _init_zmq_socket(self) -> None: - worker_socket (DEALER): Backend socket for worker communication. """ self.zmq_context = zmq.Context() - self._node_ip = get_node_ip_address_raw() + self._node_ip = get_node_ip_address() # Frontend: ROUTER for receiving client requests self.put_get_socket = create_zmq_socket(self.zmq_context, zmq.ROUTER, self._node_ip) @@ -205,7 +205,7 @@ def _init_zmq_socket(self) -> None: self.worker_socket.bind(self._inproc_addr) self.zmq_server_info = ZMQServerInfo( - role=TransferQueueRole.STORAGE, + role=Role.STORAGE, id=str(self.storage_unit_id), ip=self._node_ip, ports={"put_get_socket": self._put_get_socket_port}, @@ -481,10 +481,10 @@ def _handle_clear(self, data_parts: ZMQMessage) -> ZMQMessage: @staticmethod def _shutdown_resources( shutdown_event: Event, - worker_thread: Optional[Thread], - proxy_thread: Optional[Thread], - zmq_context: Optional[zmq.Context], - put_get_socket: Optional[zmq.Socket], + worker_thread: Thread | None, + proxy_thread: Thread | None, + zmq_context: zmq.Context | None, + put_get_socket: zmq.Socket | None, ) -> None: """Clean up resources on garbage collection.""" logger.info("Shutting down SimpleStorageUnit resources...") diff --git a/transfer_queue/utils/common.py b/transfer_queue/utils/common.py index 56f0dda8..fadcb241 100644 --- a/transfer_queue/utils/common.py +++ b/transfer_queue/utils/common.py @@ -15,7 +15,6 @@ import os from contextlib import contextmanager -from typing import Optional import psutil import ray @@ -46,7 +45,7 @@ def get_placement_group(num_ray_actors: int, num_cpus_per_actor: int = 1): @contextmanager -def limit_pytorch_auto_parallel_threads(target_num_threads: Optional[int] = None, info: str = ""): +def limit_pytorch_auto_parallel_threads(target_num_threads: int | None = None, info: str = ""): """Prevent PyTorch from overdoing the automatic parallelism during torch.stack() operation""" pytorch_current_num_threads = torch.get_num_threads() physical_cores = psutil.cpu_count(logical=False) diff --git a/transfer_queue/utils/enum_utils.py b/transfer_queue/utils/enum_utils.py index 929889da..0af0df12 100644 --- a/transfer_queue/utils/enum_utils.py +++ b/transfer_queue/utils/enum_utils.py @@ -32,7 +32,7 @@ def _missing_(cls, value): ) -class TransferQueueRole(ExplicitEnum): +class Role(ExplicitEnum): """Available Roles of TransferQueue.""" CONTROLLER = "TransferQueueController" diff --git a/transfer_queue/utils/yuanrong_utils.py b/transfer_queue/utils/yuanrong_utils.py index fd3afbd6..17d5ca82 100644 --- a/transfer_queue/utils/yuanrong_utils.py +++ b/transfer_queue/utils/yuanrong_utils.py @@ -19,7 +19,7 @@ import shutil import socket import subprocess -from typing import Any, Optional +from typing import Any import ray from omegaconf import DictConfig @@ -101,7 +101,7 @@ def check_port_connectivity(host: str, port: int, timeout: float = 2.0) -> bool: return False -def find_reachable_host(port: int, timeout: float = 1.0) -> Optional[str]: +def find_reachable_host(port: int, timeout: float = 1.0) -> str | None: """Find a reachable local host IP address for given port. Tries all local IP addresses in order and returns the first one @@ -126,7 +126,7 @@ def find_reachable_host(port: int, timeout: float = 1.0) -> Optional[str]: return None -def _parse_remote_h2d_device_ids(worker_args: str) -> Optional[str]: +def _parse_remote_h2d_device_ids(worker_args: str) -> str | None: """Parse --remote_h2d_device_ids parameter from worker_args string. Args: diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index 8af6aec3..f9282099 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -17,16 +17,15 @@ import time from dataclasses import dataclass from functools import wraps -from typing import Any, Callable, Optional, TypeAlias +from typing import Any, Callable, TypeAlias from uuid import uuid4 import psutil import ray import zmq import zmq.asyncio -from ray.util import get_node_ip_address -from transfer_queue.utils.enum_utils import ExplicitEnum, TransferQueueRole +from transfer_queue.utils.enum_utils import ExplicitEnum, Role from transfer_queue.utils.logging_utils import get_logger from transfer_queue.utils.serial_utils import decode, encode @@ -104,7 +103,7 @@ class ZMQServerInfo: TransferQueue server info class. """ - def __init__(self, role: TransferQueueRole, id: str, ip: str, ports: dict[str, int]): + def __init__(self, role: Role, id: str, ip: str, ports: dict[str, int]): self.role = role self.id = id self.ip = ip @@ -146,7 +145,7 @@ def create( request_type: ZMQRequestType, sender_id: str, body: dict[str, Any], - receiver_id: Optional[str] = None, + receiver_id: str | None = None, ) -> "ZMQMessage": """Create ZMQMessage.""" return cls( @@ -217,13 +216,13 @@ def format_zmq_address(ip: str, port: int) -> str: return f"tcp://{ip}:{port}" -def get_node_ip_address_raw() -> str: +def get_node_ip_address() -> str: """A wrapper around Ray's get_node_ip_address(). This function intentionally returns a raw IPv4/IPv6 address WITHOUT brackets. """ - return get_node_ip_address().strip("[]") + return ray.util.get_node_ip_address().strip("[]") def get_free_port(ip: str) -> int: @@ -253,7 +252,7 @@ def create_zmq_socket( ctx: zmq.Context, socket_type: Any, ip: str, - identity: Optional[bytestr] = None, + identity: bytestr | None = None, ) -> zmq.Socket: """Create ZMQ socket. @@ -299,9 +298,9 @@ def with_zmq_socket( socket_name: str, *, get_identity: Callable[[Any], str], - get_peer: Callable[[Any, Optional[str]], ZMQServerInfo], - resolve_target: Optional[Callable[[tuple, dict], Optional[str]]] = None, - timeout: Optional[int] = None, + get_peer: Callable[[Any, str | None], ZMQServerInfo], + resolve_target: Callable[[tuple, dict], str | None] | None = None, + timeout: int | None = None, ): """Create a reusable async decorator for request sockets. @@ -330,7 +329,7 @@ async def wrapper(self, *args, **kwargs): if owner_id is None: raise RuntimeError("get_identity returned None") - target_name: Optional[str] = None + target_name: str | None = None if resolve_target is not None: target_name = resolve_target(args, kwargs) @@ -369,13 +368,11 @@ async def wrapper(self, *args, **kwargs): return decorator -def process_zmq_server_info( - handlers: dict[Any, Any] | Any, -): # noqa: UP007 +def process_zmq_server_info(handlers: dict[Any, Any] | Any): """Extract ZMQ server information from handler objects. Args: - handlers: Dictionary of handler objects (controllers, storage managers, or storage units), + handlers: Dictionary of handler objects (controllers, storage managers or storage units), or a single handler object Returns: @@ -390,11 +387,9 @@ def process_zmq_server_info( >>> # Multiple handlers >>> handlers = {"storage_0": storage_0, "storage_1": storage_1} >>> info_dict = process_zmq_server_info(handlers)""" - # Handle single handler object case if not isinstance(handlers, dict): return ray.get(handlers.get_zmq_server_info.remote()) # type: ignore[union-attr, attr-defined] else: - # Handle dictionary case server_info = {} for name, handler in handlers.items(): server_info[name] = ray.get(handler.get_zmq_server_info.remote()) # type: ignore[union-attr, attr-defined] diff --git a/tutorial/01_core_components.py b/tutorial/01_core_components.py index 9a66b8e4..bfee76d4 100644 --- a/tutorial/01_core_components.py +++ b/tutorial/01_core_components.py @@ -143,7 +143,7 @@ def demonstrate_storage_backend_options(): print(" - Leverage Ray's distributed object store to store data") print("5. Custom Storage Backends") - print(" - Implement your own storage manager by inheriting from `TransferQueueStorageManager` base class") + print(" - Implement your own storage manager by inheriting from `StorageManager` base class") print(" - For KV based storage, you only need to provide a storage client and integrate with `KVStorageManager`")