From 0f5649a1d04b4d8ca5c54c808555fb86d2848062 Mon Sep 17 00:00:00 2001 From: ji-huazhong Date: Fri, 1 May 2026 18:13:10 +0800 Subject: [PATCH 1/2] refactor: simplify StorageManager naming Signed-off-by: ji-huazhong --- README.md | 18 ++--- scripts/put_benchmark.py | 2 +- tests/test_async_simple_storage_manager.py | 40 +++++----- tests/test_client.py | 10 +-- tests/test_metadata.py | 2 +- tests/test_ray_p2p.py | 13 ++-- tests/test_simple_storage_unit.py | 4 +- tests/test_storage_client_factory.py | 2 +- transfer_queue/client.py | 4 +- transfer_queue/controller.py | 17 +---- transfer_queue/interface.py | 2 +- transfer_queue/metadata.py | 10 +-- transfer_queue/storage/__init__.py | 10 +-- transfer_queue/storage/clients/__init__.py | 5 +- transfer_queue/storage/clients/base.py | 43 ++++++++++- transfer_queue/storage/clients/factory.py | 57 -------------- .../storage/clients/mooncake_client.py | 5 +- .../storage/clients/ray_storage_client.py | 5 +- .../storage/clients/yuanrong_client.py | 5 +- transfer_queue/storage/managers/__init__.py | 9 +-- transfer_queue/storage/managers/base.py | 59 +++++++++++++-- transfer_queue/storage/managers/factory.py | 74 ------------------- .../storage/managers/mooncake_manager.py | 5 +- .../storage/managers/ray_storage_manager.py | 5 +- ...d_manager.py => simple_storage_manager.py} | 7 +- .../storage/managers/yuanrong_manager.py | 5 +- .../{simple_backend.py => simple_storage.py} | 4 +- transfer_queue/utils/enum_utils.py | 2 +- transfer_queue/utils/zmq_utils.py | 6 +- tutorial/01_core_components.py | 2 +- 30 files changed, 178 insertions(+), 254 deletions(-) delete mode 100644 transfer_queue/storage/clients/factory.py delete mode 100644 transfer_queue/storage/managers/factory.py rename transfer_queue/storage/managers/{simple_backend_manager.py => simple_storage_manager.py} (98%) rename transfer_queue/storage/{simple_backend.py => simple_storage.py} (99%) diff --git a/README.md b/README.md index ccc8b564..7f4f6da0 100644 --- a/README.md 
+++ b/README.md @@ -50,7 +50,7 @@ TransferQueue offers **fine-grained, sub-sample-level** data management and **lo ### Control Plane: Panoramic Data Management -In the control plane, `TransferQueueController` tracks the **production status** and **consumption status** of each training sample as metadata. Once all required data fields are ready (i.e., written to the `TransferQueueStorageManager`), the data sample can be consumed by downstream tasks. +In the control plane, `TransferQueueController` tracks the **production status** and **consumption status** of each training sample as metadata. Once all required data fields are ready (i.e., written to the `StorageManager`), the data sample can be consumed by downstream tasks. We also track the consumption history for each computational task (e.g., `generate_sequences`, `compute_log_prob`, etc.). Therefore, even when different computational tasks require the same data field, they can consume the data independently without interfering with each other. @@ -66,7 +66,7 @@ To make the data retrieval process more customizable, we provide a `Sampler` cla In the data plane, we utilize a pluggable design, enabling TransferQueue to integrate with different storage backends based on user requirements. -Specifically, we provide a `TransferQueueStorageManager` abstraction class that defines the core APIs as follows: +Specifically, we provide a `StorageManager` abstraction class that defines the core APIs as follows: - `async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None` - `async def get_data(self, metadata: BatchMeta) -> TensorDict` @@ -298,21 +298,19 @@ The data plane is organized as follows: │ │── simple_backend.py # Default distributed storage backend (SimpleStorageUnit) by TQ │ ├── managers/ # Managers are upper level interfaces that encapsulate the interaction logic with TQ system. 
│ │ ├── __init__.py - │ │ ├──base.py # TransferQueueStorageManager, KVStorageManager - │ │ ├──simple_backend_manager.py # AsyncSimpleStorageManager + │ │ ├──base.py # StorageManager, KVStorageManager, StorageManagerFactory + │ │ ├──simple_storage_manager.py # AsyncSimpleStorageManager │ │ ├──yuanrong_manager.py # YuanrongStorageManager - │ │ ├──mooncake_manager.py # MooncakeStorageManager - │ │ └──factory.py # TransferQueueStorageManagerFactory + │ │ └──mooncake_manager.py # MooncakeStorageManager │ └── clients/ # Clients are lower level interfaces that directly manipulate the target storage backend. │ │ ├── __init__.py - │ │ ├── base.py # TransferQueueStorageKVClient + │ │ ├── base.py # StorageKVClient, StorageClientFactory │ │ ├── yuanrong_client.py # YuanrongStorageClient │ │ ├── mooncake_client.py # MooncakeStorageClient - │ │ ├── ray_storage_client.py # RayStorageClient - │ │ └── factory.py # TransferQueueStorageClientFactory + │ │ └── ray_storage_client.py # RayStorageClient ``` -To integrate TransferQueue with a custom storage backend, start by implementing a subclass that inherits from `TransferQueueStorageManager`. This subclass acts as an adapter between the TransferQueue system and the target storage backend. For KV-based storage backends, you can simply inherit from `KVStorageManager`, which can serve as the general manager for all KV-based backends. +To integrate TransferQueue with a custom storage backend, start by implementing a subclass that inherits from `StorageManager`. This subclass acts as an adapter between the TransferQueue system and the target storage backend. For KV-based storage backends, you can simply inherit from `KVStorageManager`, which can serve as the general manager for all KV-based backends. Distributed storage backends often come with their own native clients serving as the interface of the storage system. 
In such cases, a low-level adapter for this client can be written, following the examples provided in the `storage/clients` directory. diff --git a/scripts/put_benchmark.py b/scripts/put_benchmark.py index 005572c2..c67bb54c 100644 --- a/scripts/put_benchmark.py +++ b/scripts/put_benchmark.py @@ -30,7 +30,7 @@ from transfer_queue import TransferQueueClient from transfer_queue.controller import TransferQueueController -from transfer_queue.storage.simple_backend import SimpleStorageUnit +from transfer_queue.storage.simple_storage import SimpleStorageUnit from transfer_queue.utils.common import get_placement_group from transfer_queue.utils.zmq_utils import process_zmq_server_info diff --git a/tests/test_async_simple_storage_manager.py b/tests/test_async_simple_storage_manager.py index 4d1419a6..606e9822 100644 --- a/tests/test_async_simple_storage_manager.py +++ b/tests/test_async_simple_storage_manager.py @@ -24,7 +24,7 @@ from transfer_queue.metadata import BatchMeta from transfer_queue.storage import AsyncSimpleStorageManager -from transfer_queue.utils.enum_utils import TransferQueueRole +from transfer_queue.utils.enum_utils import Role from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo @@ -35,13 +35,13 @@ async def mock_async_storage_manager(): # Mock storage unit infos storage_unit_infos = { "storage_0": ZMQServerInfo( - role=TransferQueueRole.STORAGE, + role=Role.STORAGE, id="storage_0", ip="127.0.0.1", ports={"put_get_socket": 12345}, ), "storage_1": ZMQServerInfo( - role=TransferQueueRole.STORAGE, + role=Role.STORAGE, id="storage_1", ip="127.0.0.1", ports={"put_get_socket": 12346}, @@ -50,7 +50,7 @@ async def mock_async_storage_manager(): # Mock controller info controller_info = ZMQServerInfo( - role=TransferQueueRole.CONTROLLER, + role=Role.CONTROLLER, id="controller_0", ip="127.0.0.1", ports={"handshake_socket": 12347, "data_status_update_socket": 12348}, @@ -61,9 +61,7 @@ async def mock_async_storage_manager(): } # Mock the 
handshake process entirely to avoid ZMQ complexity - with patch( - "transfer_queue.storage.managers.base.TransferQueueStorageManager._connect_to_controller" - ) as mock_connect: + with patch("transfer_queue.storage.managers.base.StorageManager._connect_to_controller") as mock_connect: # Mock the manager without actually connecting manager = AsyncSimpleStorageManager.__new__(AsyncSimpleStorageManager) manager.storage_manager_id = "test_storage_manager" @@ -148,7 +146,7 @@ async def test_async_storage_manager_error_handling(): # Mock storage unit infos storage_unit_infos = { "storage_0": ZMQServerInfo( - role=TransferQueueRole.STORAGE, + role=Role.STORAGE, id="storage_0", ip="127.0.0.1", ports={"put_get_socket": 12345}, @@ -157,7 +155,7 @@ async def test_async_storage_manager_error_handling(): # Mock controller info controller_info = ZMQServerInfo( - role=TransferQueueRole.CONTROLLER, + role=Role.CONTROLLER, id="controller_0", ip="127.0.0.1", ports={"handshake_socket": 12346, "data_status_update_socket": 12347}, @@ -242,19 +240,19 @@ async def test_get_data_routes_from_hash(): """get_data should route using global_idx % num_su (hash routing).""" storage_unit_infos = { "storage_0": ZMQServerInfo( - role=TransferQueueRole.STORAGE, + role=Role.STORAGE, id="storage_0", ip="127.0.0.1", ports={"put_get_socket": 19010}, ), "storage_1": ZMQServerInfo( - role=TransferQueueRole.STORAGE, + role=Role.STORAGE, id="storage_1", ip="127.0.0.1", ports={"put_get_socket": 19011}, ), } - with patch("transfer_queue.storage.managers.base.TransferQueueStorageManager._connect_to_controller"): + with patch("transfer_queue.storage.managers.base.StorageManager._connect_to_controller"): manager = AsyncSimpleStorageManager.__new__(AsyncSimpleStorageManager) manager.storage_manager_id = "test_get" manager.storage_unit_infos = storage_unit_infos @@ -295,19 +293,19 @@ async def test_clear_data_routes_from_hash(): """clear_data should route using global_idx % num_su (hash routing).""" 
storage_unit_infos = { "storage_0": ZMQServerInfo( - role=TransferQueueRole.STORAGE, + role=Role.STORAGE, id="storage_0", ip="127.0.0.1", ports={"put_get_socket": 19020}, ), "storage_1": ZMQServerInfo( - role=TransferQueueRole.STORAGE, + role=Role.STORAGE, id="storage_1", ip="127.0.0.1", ports={"put_get_socket": 19021}, ), } - with patch("transfer_queue.storage.managers.base.TransferQueueStorageManager._connect_to_controller"): + with patch("transfer_queue.storage.managers.base.StorageManager._connect_to_controller"): manager = AsyncSimpleStorageManager.__new__(AsyncSimpleStorageManager) manager.storage_manager_id = "test_clear" manager.storage_unit_infos = storage_unit_infos @@ -346,19 +344,19 @@ async def test_hash_routing_stable_across_batch_sizes(): """ storage_unit_infos = { "storage_0": ZMQServerInfo( - role=TransferQueueRole.STORAGE, + role=Role.STORAGE, id="storage_0", ip="127.0.0.1", ports={"put_get_socket": 19030}, ), "storage_1": ZMQServerInfo( - role=TransferQueueRole.STORAGE, + role=Role.STORAGE, id="storage_1", ip="127.0.0.1", ports={"put_get_socket": 19031}, ), } - with patch("transfer_queue.storage.managers.base.TransferQueueStorageManager._connect_to_controller"): + with patch("transfer_queue.storage.managers.base.StorageManager._connect_to_controller"): manager = AsyncSimpleStorageManager.__new__(AsyncSimpleStorageManager) manager.storage_manager_id = "test_hash_batch" manager.storage_unit_infos = storage_unit_infos @@ -407,19 +405,19 @@ async def test_hash_routing_stable_reversed_order(): """ storage_unit_infos = { "storage_0": ZMQServerInfo( - role=TransferQueueRole.STORAGE, + role=Role.STORAGE, id="storage_0", ip="127.0.0.1", ports={"put_get_socket": 19040}, ), "storage_1": ZMQServerInfo( - role=TransferQueueRole.STORAGE, + role=Role.STORAGE, id="storage_1", ip="127.0.0.1", ports={"put_get_socket": 19041}, ), } - with patch("transfer_queue.storage.managers.base.TransferQueueStorageManager._connect_to_controller"): + with 
patch("transfer_queue.storage.managers.base.StorageManager._connect_to_controller"): manager = AsyncSimpleStorageManager.__new__(AsyncSimpleStorageManager) manager.storage_manager_id = "test_hash_order" manager.storage_unit_infos = storage_unit_infos diff --git a/tests/test_client.py b/tests/test_client.py index 7ce8bcbf..2e7e37a0 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -24,7 +24,7 @@ from transfer_queue import TransferQueueClient from transfer_queue.metadata import BatchMeta -from transfer_queue.utils.enum_utils import TransferQueueRole +from transfer_queue.utils.enum_utils import Role from transfer_queue.utils.zmq_utils import ( ZMQMessage, ZMQRequestType, @@ -59,7 +59,7 @@ def __init__(self, controller_id="controller_0"): self.request_port = self._bind_to_random_port(self.request_socket) self.zmq_server_info = ZMQServerInfo( - role=TransferQueueRole.CONTROLLER, + role=Role.CONTROLLER, id=controller_id, ip="127.0.0.1", ports={ @@ -300,7 +300,7 @@ def __init__(self, storage_id="storage_0"): self.data_port = self._bind_to_random_port(self.data_socket) self.zmq_server_info = ZMQServerInfo( - role=TransferQueueRole.STORAGE, + role=Role.STORAGE, id=storage_id, ip="127.0.0.1", ports={ @@ -409,7 +409,7 @@ def client_setup(mock_controller, mock_storage): # Mock the storage manager to avoid handshake issues but mock all data operations with patch( - "transfer_queue.storage.managers.simple_backend_manager.AsyncSimpleStorageManager._connect_to_controller" + "transfer_queue.storage.managers.simple_storage_manager.AsyncSimpleStorageManager._connect_to_controller" ): config = { "controller_info": mock_controller.zmq_server_info, @@ -502,7 +502,7 @@ def test_single_controller_multiple_storages(): # Mock the storage manager to avoid handshake issues but mock all data operations with patch( - "transfer_queue.storage.managers.simple_backend_manager.AsyncSimpleStorageManager._connect_to_controller" + 
"transfer_queue.storage.managers.simple_storage_manager.AsyncSimpleStorageManager._connect_to_controller" ): config = { "controller_info": controller.zmq_server_info, diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 20ace48e..b9e80dac 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -678,7 +678,7 @@ class TestStorageUnitDataStrict: def test_put_data_length_mismatch_raises(self): """put_data must raise when global_indexes and field values have different lengths.""" - from transfer_queue.storage.simple_backend import StorageUnitData + from transfer_queue.storage.simple_storage import StorageUnitData sud = StorageUnitData(storage_size=10) # 3 indexes but only 2 values — must raise, not silently drop diff --git a/tests/test_ray_p2p.py b/tests/test_ray_p2p.py index b9f0b835..353bb926 100644 --- a/tests/test_ray_p2p.py +++ b/tests/test_ray_p2p.py @@ -23,8 +23,7 @@ from transfer_queue.client import TransferQueueClient from transfer_queue.metadata import BatchMeta -from transfer_queue.storage.managers.base import KVStorageManager -from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory +from transfer_queue.storage.managers.base import KVStorageManager, StorageManagerFactory from transfer_queue.utils.zmq_utils import ZMQServerInfo TEST_CONFIGS: list[tuple[tuple[int, int], torch.dtype]] = [ @@ -45,18 +44,18 @@ # Step 1: Mock Controller Role try: - from transfer_queue.role import TransferQueueRole + from transfer_queue.role import Role except ImportError: from enum import Enum - class TransferQueueRole(Enum): + class Role(Enum): CONTROLLER = "controller" STORAGE = "storage" def create_mock_controller(): return ZMQServerInfo( - role=TransferQueueRole.CONTROLLER, + role=Role.CONTROLLER, id="controller_0", ip="127.0.0.1", ports={ @@ -71,9 +70,9 @@ def create_mock_controller(): def ensure_mock_storage_manager_registered(): """Ensure MockKVStorageManager is registered in current process.""" - if "KV_MOCK" not in 
TransferQueueStorageManagerFactory._registry: + if "KV_MOCK" not in StorageManagerFactory._registry: - @TransferQueueStorageManagerFactory.register("KV_MOCK") + @StorageManagerFactory.register("KV_MOCK") class MockKVStorageManager(KVStorageManager): def _connect_to_controller(self): pass diff --git a/tests/test_simple_storage_unit.py b/tests/test_simple_storage_unit.py index 2519b975..319a46e7 100644 --- a/tests/test_simple_storage_unit.py +++ b/tests/test_simple_storage_unit.py @@ -21,7 +21,7 @@ import torch import zmq -from transfer_queue.storage.simple_backend import SimpleStorageUnit +from transfer_queue.storage.simple_storage import SimpleStorageUnit from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType @@ -420,7 +420,7 @@ def test_storage_unit_data_direct(): def test_storage_unit_data_capacity_uses_active_keys(): """Capacity check must use _active_keys, not scan field_data.""" - from transfer_queue.storage.simple_backend import StorageUnitData + from transfer_queue.storage.simple_storage import StorageUnitData storage = StorageUnitData(storage_size=3) diff --git a/tests/test_storage_client_factory.py b/tests/test_storage_client_factory.py index 012c0cc5..0e1a5f90 100644 --- a/tests/test_storage_client_factory.py +++ b/tests/test_storage_client_factory.py @@ -19,7 +19,7 @@ import pytest import torch -from transfer_queue.storage.clients.factory import StorageClientFactory +from transfer_queue.storage.clients.base import StorageClientFactory from transfer_queue.storage.clients.yuanrong_client import YuanrongStorageClient diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 5f3a4e3f..bd8b937e 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -28,7 +28,7 @@ BatchMeta, ) from transfer_queue.storage import ( - TransferQueueStorageManagerFactory, + StorageManagerFactory, ) from transfer_queue.utils.common import limit_pytorch_auto_parallel_threads from transfer_queue.utils.logging_utils import get_logger @@ 
-92,7 +92,7 @@ def initialize_storage_manager( - zmq_info: ZMQ server information about the storage units """ - self.storage_manager = TransferQueueStorageManagerFactory.create( + self.storage_manager = StorageManagerFactory.create( manager_type, controller_info=self._controller, config=config ) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 49aa2410..96f971c4 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -35,7 +35,7 @@ BatchMeta, ) from transfer_queue.sampler import BaseSampler, SequentialSampler -from transfer_queue.utils.enum_utils import TransferQueueRole +from transfer_queue.utils.enum_utils import Role from transfer_queue.utils.logging_utils import get_logger from transfer_queue.utils.perf_utils import IntervalPerfMonitor from transfer_queue.utils.zmq_utils import ( @@ -384,7 +384,6 @@ def allocated_samples_num(self) -> int: return self.production_status.shape[0] # ==================== Index Pre-Allocation Methods ==================== - def register_pre_allocated_indexes(self, allocated_indexes: list[int]): """ Register pre-allocated sample indexes to this partition. @@ -442,7 +441,6 @@ def activate_pre_allocated_indexes(self, sample_num: int) -> list[int]: return global_index_to_allocate # ==================== Dynamic Expansion Methods ==================== - def ensure_samples_capacity(self, required_samples: int) -> None: """ Ensure the production status tensor has enough rows for the required samples. 
@@ -492,7 +490,6 @@ def ensure_fields_capacity(self, required_fields: int) -> None: logger.debug(f"Expanded partition {self.partition_id} from {current_fields} to {new_fields} fields") # ==================== Production Status Interface ==================== - def update_production_status( self, global_indices: list[int], @@ -611,7 +608,6 @@ def mark_consumed(self, task_name: str, global_indices: list[int]): ) # ==================== Consumption Status Interface ==================== - def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Tensor, Tensor]: """ Get or create consumption status for a specific task. @@ -713,7 +709,6 @@ def get_production_status_for_fields( return partition_global_index, production_status # ==================== Data Scanning and Query Methods ==================== - def scan_data_status(self, field_names: list[str], task_name: str) -> list[int]: """ Scan data status to find samples ready for consumption. @@ -761,7 +756,6 @@ def scan_data_status(self, field_names: list[str], task_name: str) -> list[int]: return ready_sample_indices # ==================== Metadata Methods ==================== - def get_field_schema( self, field_names: list[str], batch_global_indexes: list[int] | None = None ) -> dict[str, dict[str, Any]]: @@ -830,7 +824,6 @@ def set_custom_meta(self, custom_meta: dict[int, dict]) -> None: self.custom_meta.update(custom_meta) # ==================== Statistics and Monitoring ==================== - def get_statistics(self) -> dict[str, Any]: """Get detailed statistics for this partition.""" stats = { @@ -877,7 +870,6 @@ def get_statistics(self) -> dict[str, Any]: return stats # ==================== Serialization ==================== - def to_snapshot(self): """ Get a snapshot of partition status information. 
@@ -1028,7 +1020,6 @@ def __init__( logger.info(f"TransferQueue Controller {self.controller_id} initialized") # ==================== Partition Management API ==================== - def create_partition(self, partition_id: str) -> bool: """ Create a new data partition with pre-allocated sample indexes. @@ -1098,7 +1089,6 @@ def list_partitions(self) -> list[str]: return list(self.partitions.keys()) # ==================== Partition Index Management API ==================== - def get_partition_index_range(self, partition_id: str) -> list[int]: """ Get all indexes for a specific partition. @@ -1114,7 +1104,6 @@ def get_partition_index_range(self, partition_id: str) -> list[int]: return self.index_manager.get_indexes_for_partition(partition_id) # ==================== Data Production API ==================== - def update_production_status( self, partition_id: str, @@ -1150,7 +1139,6 @@ def update_production_status( return success # ==================== Data Consumption API ==================== - def get_consumption_status(self, partition_id: str, task_name: str) -> tuple[Optional[Tensor], Optional[Tensor]]: """ Get or create consumption status for a specific task and partition. 
@@ -1398,7 +1386,6 @@ def scan_data_status( return ready_sample_indices # ==================== Metadata Generation API ==================== - def generate_batch_meta( self, partition_id: str, @@ -1726,7 +1713,7 @@ def _init_zmq_socket(self): continue self.zmq_server_info = ZMQServerInfo( - role=TransferQueueRole.CONTROLLER, + role=Role.CONTROLLER, id=self.controller_id, ip=self._node_ip, ports={ diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 66141b8f..4b154ab5 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -32,7 +32,7 @@ from transfer_queue.metadata import KVBatchMeta from transfer_queue.sampler import * # noqa: F401 from transfer_queue.sampler import BaseSampler -from transfer_queue.storage.simple_backend import SimpleStorageUnit +from transfer_queue.storage.simple_storage import SimpleStorageUnit from transfer_queue.utils.common import get_placement_group from transfer_queue.utils.logging_utils import get_logger from transfer_queue.utils.yuanrong_utils import ( diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 03784995..64fc18e2 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -30,11 +30,6 @@ logger = get_logger(__name__) -# --------------------------------------------------------------------------- -# Internal helpers -# --------------------------------------------------------------------------- - - def _extra_info_values_equal(a: Any, b: Any) -> bool: """Compare two extra_info values for equality. @@ -55,7 +50,7 @@ def _extra_info_values_equal(a: Any, b: Any) -> bool: class _SampleView: """Lazy read-only view of a single sample row in a columnar BatchMeta. - All returned dicts are ``MappingProxyType`` – attempts to mutate them + All returned dicts are ``MappingProxyType``, and attempts to mutate them raise ``TypeError``, making it obvious that this is a snapshot view. 
""" @@ -379,7 +374,6 @@ def get_shapes(self, field_name: str) -> list: return [meta.get("shape")] * self.size # ==================== Extra Info Methods ==================== - def get_extra_info(self, key: str, default: Any = None) -> Any: """Get extra info by key""" return self.extra_info.get(key, default) @@ -417,7 +411,6 @@ def has_extra_info(self, key: str) -> bool: return key in self.extra_info # ==================== Custom Meta Methods (User Layer) ==================== - def get_all_custom_meta(self) -> list[dict[str, Any]]: """Get all custom_meta as a list of dictionary (one per sample, in global_indexes order). @@ -451,7 +444,6 @@ def clear_custom_meta(self) -> None: self.custom_meta = [{} for _ in range(self.size)] # ==================== Core BatchMeta Operations ==================== - def add_fields(self, tensor_dict: TensorDict, set_all_ready: bool = True) -> "BatchMeta": """Add new fields from a TensorDict to all samples in this batch. This modifies the batch in-place to include the new fields. 
diff --git a/transfer_queue/storage/__init__.py b/transfer_queue/storage/__init__.py index 6fbd3415..2fb1be46 100644 --- a/transfer_queue/storage/__init__.py +++ b/transfer_queue/storage/__init__.py @@ -17,17 +17,17 @@ AsyncSimpleStorageManager, MooncakeStorageManager, RayStorageManager, - TransferQueueStorageManager, - TransferQueueStorageManagerFactory, + StorageManager, + StorageManagerFactory, YuanrongStorageManager, ) -from .simple_backend import SimpleStorageUnit, StorageUnitData +from .simple_storage import SimpleStorageUnit, StorageUnitData __all__ = [ "SimpleStorageUnit", "StorageUnitData", - "TransferQueueStorageManager", - "TransferQueueStorageManagerFactory", + "StorageManager", + "StorageManagerFactory", "AsyncSimpleStorageManager", "MooncakeStorageManager", "YuanrongStorageManager", diff --git a/transfer_queue/storage/clients/__init__.py b/transfer_queue/storage/clients/__init__.py index 2b861166..d54a32ef 100644 --- a/transfer_queue/storage/clients/__init__.py +++ b/transfer_queue/storage/clients/__init__.py @@ -14,14 +14,13 @@ # limitations under the License. # This module is currently empty but reserved for future client implementations -from .base import TransferQueueStorageKVClient -from .factory import StorageClientFactory +from .base import StorageClientFactory, StorageKVClient from .mooncake_client import MooncakeStoreClient from .ray_storage_client import RayStorageClient from .yuanrong_client import YuanrongStorageClient __all__ = [ - "TransferQueueStorageKVClient", + "StorageKVClient", "StorageClientFactory", "RayStorageClient", "MooncakeStoreClient", diff --git a/transfer_queue/storage/clients/base.py b/transfer_queue/storage/clients/base.py index a6d63f6a..9c475493 100644 --- a/transfer_queue/storage/clients/base.py +++ b/transfer_queue/storage/clients/base.py @@ -17,7 +17,7 @@ from typing import Any, Optional -class TransferQueueStorageKVClient(ABC): +class StorageKVClient(ABC): """ Abstract base class for storage client. 
Subclasses must implement the core methods: put, get, and clear. @@ -68,3 +68,44 @@ def get(self, keys: list[str], shapes=None, dtypes=None, custom_backend_meta=Non def clear(self, keys: list[str], custom_backend_meta=None) -> None: """Clear key-value pairs in the storage backend.""" raise NotImplementedError("Subclasses must implement clear") + + +class StorageClientFactory: + """ + Factory class for creating storage client instances. + Uses a decorator-based registration mechanism to map client names to classes. + """ + + # Class variable: maps client names to their corresponding classes + _registry: dict[str, type[StorageKVClient]] = {} + + @classmethod + def register(cls, client_type: str): + """ + Decorator to register a concrete client class with the factory. + Args: + client_type (str): The name used to identify the client + Returns: + Callable: The decorator function that returns the original class + """ + + def decorator(client_class: type[StorageKVClient]) -> type[StorageKVClient]: + cls._registry[client_type] = client_class + return client_class + + return decorator + + @classmethod + def create(cls, client_type: str, config: dict) -> StorageKVClient: + """ + Create and return an instance of the storage client by name. + Args: + client_type (str): The registered name of the client + Returns: + StorageKVClient: An instance of the requested client + Raises: + ValueError: If no client is registered with the given name + """ + if client_type not in cls._registry: + raise ValueError(f"Unknown StorageClient: {client_type}") + return cls._registry[client_type](config) diff --git a/transfer_queue/storage/clients/factory.py b/transfer_queue/storage/clients/factory.py deleted file mode 100644 index 73979bd5..00000000 --- a/transfer_queue/storage/clients/factory.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. 
-# Copyright 2025 The TransferQueue Team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from transfer_queue.storage.clients.base import TransferQueueStorageKVClient - - -class StorageClientFactory: - """ - Factory class for creating storage client instances. - Uses a decorator-based registration mechanism to map client names to classes. - """ - - # Class variable: maps client names to their corresponding classes - _registry: dict[str, type[TransferQueueStorageKVClient]] = {} - - @classmethod - def register(cls, client_type: str): - """ - Decorator to register a concrete client class with the factory. - Args: - client_type (str): The name used to identify the client - Returns: - Callable: The decorator function that returns the original class - """ - - def decorator(client_class: type[TransferQueueStorageKVClient]) -> type[TransferQueueStorageKVClient]: - cls._registry[client_type] = client_class - return client_class - - return decorator - - @classmethod - def create(cls, client_type: str, config: dict) -> TransferQueueStorageKVClient: - """ - Create and return an instance of the storage client by name. 
- Args: - client_type (str): The registered name of the client - Returns: - StorageClientFactory: An instance of the requested client - Raises: - ValueError: If no client is registered with the given name - """ - if client_type not in cls._registry: - raise ValueError(f"Unknown StorageClient: {client_type}") - return cls._registry[client_type](config) diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index 8f311b4d..7a47111c 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -20,8 +20,7 @@ import torch from torch import Tensor -from transfer_queue.storage.clients.base import TransferQueueStorageKVClient -from transfer_queue.storage.clients.factory import StorageClientFactory +from transfer_queue.storage.clients.base import StorageClientFactory, StorageKVClient from transfer_queue.utils.logging_utils import get_logger from transfer_queue.utils.tensor_utils import allocate_empty_tensors, get_nbytes, merge_contiguous_memory @@ -39,7 +38,7 @@ @StorageClientFactory.register("MooncakeStoreClient") -class MooncakeStoreClient(TransferQueueStorageKVClient): +class MooncakeStoreClient(StorageKVClient): """ Storage client for MooncakeStore. 
""" diff --git a/transfer_queue/storage/clients/ray_storage_client.py b/transfer_queue/storage/clients/ray_storage_client.py index c290f6f2..db901533 100644 --- a/transfer_queue/storage/clients/ray_storage_client.py +++ b/transfer_queue/storage/clients/ray_storage_client.py @@ -19,8 +19,7 @@ import ray import torch -from transfer_queue.storage.clients.base import TransferQueueStorageKVClient -from transfer_queue.storage.clients.factory import StorageClientFactory +from transfer_queue.storage.clients.base import StorageClientFactory, StorageKVClient @ray.remote(max_concurrency=8) @@ -47,7 +46,7 @@ def clear_obj_ref(self, keys: list[str]): @StorageClientFactory.register("RayStorageClient") -class RayStorageClient(TransferQueueStorageKVClient): +class RayStorageClient(StorageKVClient): """ Storage client for Ray RDT. """ diff --git a/transfer_queue/storage/clients/yuanrong_client.py b/transfer_queue/storage/clients/yuanrong_client.py index 9aeffdc1..aa644065 100644 --- a/transfer_queue/storage/clients/yuanrong_client.py +++ b/transfer_queue/storage/clients/yuanrong_client.py @@ -21,8 +21,7 @@ import torch from torch import Tensor -from transfer_queue.storage.clients.base import TransferQueueStorageKVClient -from transfer_queue.storage.clients.factory import StorageClientFactory +from transfer_queue.storage.clients.base import StorageClientFactory, StorageKVClient from transfer_queue.utils.logging_utils import get_logger from transfer_queue.utils.serial_utils import _decoder, _encoder from transfer_queue.utils.yuanrong_utils import find_reachable_host @@ -365,7 +364,7 @@ def mget_zero_copy(self, keys: list[str]) -> list[Any]: @StorageClientFactory.register("YuanrongStorageClient") -class YuanrongStorageClient(TransferQueueStorageKVClient): +class YuanrongStorageClient(StorageKVClient): """ Storage client for YuanRong DataSystem. 
diff --git a/transfer_queue/storage/managers/__init__.py b/transfer_queue/storage/managers/__init__.py index 77954bb3..d5e71abd 100644 --- a/transfer_queue/storage/managers/__init__.py +++ b/transfer_queue/storage/managers/__init__.py @@ -13,16 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .base import TransferQueueStorageManager -from .factory import TransferQueueStorageManagerFactory +from .base import StorageManager, StorageManagerFactory from .mooncake_manager import MooncakeStorageManager from .ray_storage_manager import RayStorageManager -from .simple_backend_manager import AsyncSimpleStorageManager +from .simple_storage_manager import AsyncSimpleStorageManager from .yuanrong_manager import YuanrongStorageManager __all__ = [ - "TransferQueueStorageManager", - "TransferQueueStorageManagerFactory", + "StorageManager", + "StorageManagerFactory", "AsyncSimpleStorageManager", "YuanrongStorageManager", "MooncakeStorageManager", diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 55901e31..d578873c 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -17,6 +17,7 @@ import itertools import os import time +import warnings import weakref from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor @@ -32,7 +33,7 @@ from torch import Tensor from transfer_queue.metadata import BatchMeta, extract_field_schema -from transfer_queue.storage.clients.factory import StorageClientFactory +from transfer_queue.storage.clients.base import StorageClientFactory from transfer_queue.utils.logging_utils import get_logger from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo, create_zmq_socket @@ -49,7 +50,7 @@ LIMIT_THREADS_PER_MANAGER_IN_RAY_ACTOR = 4 -class TransferQueueStorageManager(ABC): +class StorageManager(ABC): """Base class for storage layer. 
It defines the interface for data operations and generally provides handshake & notification capabilities.""" @@ -364,11 +365,59 @@ def __del__(self): logger.error(f"[{self.storage_manager_id}]: Exception during __del__: {str(e)}") -from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory # noqa: E402 +class StorageManagerFactory: + """Factory that creates a StorageManager instance.""" + _registry: dict[str, type[StorageManager]] = {} -@TransferQueueStorageManagerFactory.register("KVStorageManager") -class KVStorageManager(TransferQueueStorageManager): + @classmethod + def register(cls, manager_type: str): + """Register a StorageManager class.""" + + def decorator(manager_cls: type[StorageManager]): + if not issubclass(manager_cls, StorageManager): + raise TypeError( + f"manager_cls {getattr(manager_cls, '__name__', repr(manager_cls))} must be " + f"a subclass of StorageManager" + ) + cls._registry[manager_type] = manager_cls + return manager_cls + + return decorator + + @classmethod + def create(cls, manager_type: str, controller_info: ZMQServerInfo, config: dict[str, Any]) -> StorageManager: + """Create and return a StorageManager instance.""" + if manager_type not in cls._registry: + if manager_type == "AsyncSimpleStorageManager": + warnings.warn( + f"The manager_type {manager_type} will be deprecated in 0.1.7, please use SimpleStorage instead.", + category=DeprecationWarning, + stacklevel=2, + ) + manager_type = "SimpleStorage" + elif manager_type == "MooncakeStorageManager": + warnings.warn( + f"The manager_type {manager_type} will be deprecated in 0.1.7, please use MooncakeStore instead.", + category=DeprecationWarning, + stacklevel=2, + ) + manager_type = "MooncakeStore" + elif manager_type == "YuanrongStorageManager": + warnings.warn( + f"The manager_type {manager_type} will be deprecated in 0.1.7, please use Yuanrong instead.", + category=DeprecationWarning, + stacklevel=2, + ) + manager_type = "Yuanrong" + else: + raise 
ValueError( + f"Unknown manager_type: {manager_type}. Supported managers include: {list(cls._registry.keys())}" + ) + return cls._registry[manager_type](controller_info, config) + + +class KVStorageManager(StorageManager): """ A storage manager that uses a key-value (KV) backend (e.g., YuanRong) to store and retrieve tensor data. It maps structured metadata (BatchMeta) to flat lists of keys and values for efficient KV operations. diff --git a/transfer_queue/storage/managers/factory.py b/transfer_queue/storage/managers/factory.py deleted file mode 100644 index 04d2cd03..00000000 --- a/transfer_queue/storage/managers/factory.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2025 The TransferQueue Team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import warnings -from typing import Any - -from transfer_queue.storage.managers.base import TransferQueueStorageManager -from transfer_queue.utils.zmq_utils import ZMQServerInfo - - -class TransferQueueStorageManagerFactory: - """Factory that creates a StorageManager instance.""" - - _registry: dict[str, type[TransferQueueStorageManager]] = {} - - @classmethod - def register(cls, manager_type: str): - """Register a TransferQueueStorageManager class.""" - - def decorator(manager_cls: type[TransferQueueStorageManager]): - if not issubclass(manager_cls, TransferQueueStorageManager): - raise TypeError( - f"manager_cls {getattr(manager_cls, '__name__', repr(manager_cls))} must be " - f"a subclass of TransferQueueStorageManager" - ) - cls._registry[manager_type] = manager_cls - return manager_cls - - return decorator - - @classmethod - def create( - cls, manager_type: str, controller_info: ZMQServerInfo, config: dict[str, Any] - ) -> TransferQueueStorageManager: - """Create and return a TransferQueueStorageManager instance.""" - if manager_type not in cls._registry: - if manager_type == "AsyncSimpleStorageManager": - warnings.warn( - f"The manager_type {manager_type} will be deprecated in 0.1.7, please use SimpleStorage instead.", - category=DeprecationWarning, - stacklevel=2, - ) - manager_type = "SimpleStorage" - elif manager_type == "MooncakeStorageManager": - warnings.warn( - f"The manager_type {manager_type} will be deprecated in 0.1.7, please use MooncakeStore instead.", - category=DeprecationWarning, - stacklevel=2, - ) - manager_type = "MooncakeStore" - elif manager_type == "YuanrongStorageManager": - warnings.warn( - f"The manager_type {manager_type} will be deprecated in 0.1.7, please use Yuanrong instead.", - category=DeprecationWarning, - stacklevel=2, - ) - manager_type = "Yuanrong" - else: - raise ValueError( - f"Unknown manager_type: {manager_type}. 
Supported managers include: {list(cls._registry.keys())}" - ) - return cls._registry[manager_type](controller_info, config) diff --git a/transfer_queue/storage/managers/mooncake_manager.py b/transfer_queue/storage/managers/mooncake_manager.py index 7b43ca71..c1a354c9 100644 --- a/transfer_queue/storage/managers/mooncake_manager.py +++ b/transfer_queue/storage/managers/mooncake_manager.py @@ -15,15 +15,14 @@ from typing import Any -from transfer_queue.storage.managers.base import KVStorageManager -from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory +from transfer_queue.storage.managers.base import KVStorageManager, StorageManagerFactory from transfer_queue.utils.logging_utils import get_logger from transfer_queue.utils.zmq_utils import ZMQServerInfo logger = get_logger(__name__) -@TransferQueueStorageManagerFactory.register("MooncakeStore") +@StorageManagerFactory.register("MooncakeStore") class MooncakeStorageManager(KVStorageManager): """Storage manager for MooncakeStorage backend.""" diff --git a/transfer_queue/storage/managers/ray_storage_manager.py b/transfer_queue/storage/managers/ray_storage_manager.py index 78069c21..0cc2a09c 100644 --- a/transfer_queue/storage/managers/ray_storage_manager.py +++ b/transfer_queue/storage/managers/ray_storage_manager.py @@ -15,12 +15,11 @@ from typing import Any -from transfer_queue.storage.managers.base import KVStorageManager -from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory +from transfer_queue.storage.managers.base import KVStorageManager, StorageManagerFactory from transfer_queue.utils.zmq_utils import ZMQServerInfo -@TransferQueueStorageManagerFactory.register("RayStore") +@StorageManagerFactory.register("RayStore") class RayStorageManager(KVStorageManager): """Storage manager for Ray-RDT backend.""" diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_storage_manager.py similarity index 
98% rename from transfer_queue/storage/managers/simple_backend_manager.py rename to transfer_queue/storage/managers/simple_storage_manager.py index 0771638b..184dcfdc 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_storage_manager.py @@ -27,8 +27,7 @@ from tensordict import NonTensorStack, TensorDict from transfer_queue.metadata import BatchMeta, extract_field_schema -from transfer_queue.storage.managers.base import TransferQueueStorageManager -from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory +from transfer_queue.storage.managers.base import StorageManager, StorageManagerFactory from transfer_queue.utils.logging_utils import get_logger from transfer_queue.utils.zmq_utils import ( ZMQMessage, @@ -58,8 +57,8 @@ class RoutingGroup(NamedTuple): batch_positions: list[int] # corresponding positions in the original batch -@TransferQueueStorageManagerFactory.register("SimpleStorage") -class AsyncSimpleStorageManager(TransferQueueStorageManager): +@StorageManagerFactory.register("SimpleStorage") +class AsyncSimpleStorageManager(StorageManager): """Asynchronous storage manager that handles multiple storage units. 
This manager provides async put/get/clear operations across multiple SimpleStorageUnit diff --git a/transfer_queue/storage/managers/yuanrong_manager.py b/transfer_queue/storage/managers/yuanrong_manager.py index 4c37cc73..f76b47b2 100644 --- a/transfer_queue/storage/managers/yuanrong_manager.py +++ b/transfer_queue/storage/managers/yuanrong_manager.py @@ -15,15 +15,14 @@ from typing import Any -from transfer_queue.storage.managers.base import KVStorageManager -from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory +from transfer_queue.storage.managers.base import KVStorageManager, StorageManagerFactory from transfer_queue.utils.logging_utils import get_logger from transfer_queue.utils.zmq_utils import ZMQServerInfo logger = get_logger(__name__) -@TransferQueueStorageManagerFactory.register("Yuanrong") +@StorageManagerFactory.register("Yuanrong") class YuanrongStorageManager(KVStorageManager): """Storage manager for Yuanrong backend.""" diff --git a/transfer_queue/storage/simple_backend.py b/transfer_queue/storage/simple_storage.py similarity index 99% rename from transfer_queue/storage/simple_backend.py rename to transfer_queue/storage/simple_storage.py index bf334efc..492037d1 100644 --- a/transfer_queue/storage/simple_backend.py +++ b/transfer_queue/storage/simple_storage.py @@ -24,7 +24,7 @@ import zmq from transfer_queue.utils.common import limit_pytorch_auto_parallel_threads -from transfer_queue.utils.enum_utils import TransferQueueRole +from transfer_queue.utils.enum_utils import Role from transfer_queue.utils.logging_utils import get_logger from transfer_queue.utils.perf_utils import IntervalPerfMonitor from transfer_queue.utils.zmq_utils import ( @@ -205,7 +205,7 @@ def _init_zmq_socket(self) -> None: self.worker_socket.bind(self._inproc_addr) self.zmq_server_info = ZMQServerInfo( - role=TransferQueueRole.STORAGE, + role=Role.STORAGE, id=str(self.storage_unit_id), ip=self._node_ip, ports={"put_get_socket": 
self._put_get_socket_port}, diff --git a/transfer_queue/utils/enum_utils.py b/transfer_queue/utils/enum_utils.py index 929889da..0af0df12 100644 --- a/transfer_queue/utils/enum_utils.py +++ b/transfer_queue/utils/enum_utils.py @@ -32,7 +32,7 @@ def _missing_(cls, value): ) -class TransferQueueRole(ExplicitEnum): +class Role(ExplicitEnum): """Available Roles of TransferQueue.""" CONTROLLER = "TransferQueueController" diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index 8af6aec3..6216e7ec 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -26,7 +26,7 @@ import zmq.asyncio from ray.util import get_node_ip_address -from transfer_queue.utils.enum_utils import ExplicitEnum, TransferQueueRole +from transfer_queue.utils.enum_utils import ExplicitEnum, Role from transfer_queue.utils.logging_utils import get_logger from transfer_queue.utils.serial_utils import decode, encode @@ -104,7 +104,7 @@ class ZMQServerInfo: TransferQueue server info class. """ - def __init__(self, role: TransferQueueRole, id: str, ip: str, ports: dict[str, int]): + def __init__(self, role: Role, id: str, ip: str, ports: dict[str, int]): self.role = role self.id = id self.ip = ip @@ -371,7 +371,7 @@ async def wrapper(self, *args, **kwargs): def process_zmq_server_info( handlers: dict[Any, Any] | Any, -): # noqa: UP007 +): """Extract ZMQ server information from handler objects. Args: diff --git a/tutorial/01_core_components.py b/tutorial/01_core_components.py index 9a66b8e4..bfee76d4 100644 --- a/tutorial/01_core_components.py +++ b/tutorial/01_core_components.py @@ -143,7 +143,7 @@ def demonstrate_storage_backend_options(): print(" - Leverage Ray's distributed object store to store data") print("5. 
Custom Storage Backends") - print(" - Implement your own storage manager by inheriting from `TransferQueueStorageManager` base class") + print(" - Implement your own storage manager by inheriting from `StorageManager` base class") print(" - For KV based storage, you only need to provide a storage client and integrate with `KVStorageManager`") From 79d1e671a2906899e747c7c434f3b8548f24ccc5 Mon Sep 17 00:00:00 2001 From: ji-huazhong Date: Sat, 2 May 2026 11:57:13 +0800 Subject: [PATCH 2/2] =?UTF-8?q?set=20target-verion=20py310=20and=20enable?= =?UTF-8?q?=20modern=20type=20annotation=20rules=20(Optional/Union=20?= =?UTF-8?q?=E2=86=92=20X=20|=20None)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: ji-huazhong --- pyproject.toml | 3 +- transfer_queue/client.py | 56 +++++++++---------- transfer_queue/controller.py | 36 ++++++------ .../dataloader/streaming_dataloader.py | 3 +- transfer_queue/interface.py | 46 ++++++++------- transfer_queue/metadata.py | 20 +++---- transfer_queue/sampler/base.py | 4 +- transfer_queue/storage/clients/base.py | 4 +- .../storage/clients/mooncake_client.py | 10 ++-- .../storage/clients/ray_storage_client.py | 4 +- .../storage/clients/yuanrong_client.py | 14 ++--- transfer_queue/storage/managers/base.py | 16 +++--- transfer_queue/storage/simple_storage.py | 18 +++--- transfer_queue/utils/common.py | 3 +- transfer_queue/utils/yuanrong_utils.py | 6 +- transfer_queue/utils/zmq_utils.py | 14 ++--- 16 files changed, 126 insertions(+), 131 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0bbb2f0c..20b148be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ requires-python = ">=3.10" # Note: While the formatter will attempt to format lines such that they remain within the line-length, # it isn't a hard upper bound, and formatted lines may exceed the line-length. 
line-length = 120 +target-version = "py310" [tool.ruff.lint] isort = {known-first-party = ["transfer_queue"]} @@ -60,7 +61,7 @@ ignore = [ # `.log()` statement uses f-string "G004", # X | None for type annotations - "UP045", + # "UP045", # deprecated import "UP035", ] diff --git a/transfer_queue/client.py b/transfer_queue/client.py index bd8b937e..c997d750 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -16,7 +16,7 @@ import asyncio import os import threading -from typing import Any, Callable, Optional +from typing import Any, Callable import torch import zmq @@ -104,9 +104,9 @@ async def async_get_meta( batch_size: int, partition_id: str, mode: str = "fetch", - task_name: Optional[str] = None, - sampling_config: Optional[dict[str, Any]] = None, - socket: Optional[zmq.asyncio.Socket] = None, + task_name: str | None = None, + sampling_config: dict[str, Any] | None = None, + socket: zmq.asyncio.Socket | None = None, ) -> BatchMeta: """Asynchronously fetch data metadata from the controller via ZMQ. @@ -191,7 +191,7 @@ async def async_get_meta( async def async_set_custom_meta( self, metadata: BatchMeta, - socket: Optional[zmq.asyncio.Socket] = None, + socket: zmq.asyncio.Socket | None = None, ) -> None: """ Asynchronously send custom metadata to the controller. @@ -264,9 +264,9 @@ async def async_set_custom_meta( async def async_put( self, data: TensorDict, - metadata: Optional[BatchMeta] = None, - partition_id: Optional[str] = None, - data_parser: Optional[Callable[[Any], Any]] = None, + metadata: BatchMeta | None = None, + partition_id: str | None = None, + data_parser: Callable[[Any], Any] | None = None, ) -> BatchMeta: """Asynchronously write data to storage units based on metadata. 
@@ -575,8 +575,8 @@ async def async_get_consumption_status( self, task_name: str, partition_id: str, - socket: Optional[zmq.asyncio.Socket] = None, - ) -> tuple[Optional[Tensor], Optional[Tensor]]: + socket: zmq.asyncio.Socket | None = None, + ) -> tuple[Tensor | None, Tensor | None]: """Get consumption status for current partition in a specific task. Args: @@ -638,8 +638,8 @@ async def async_get_production_status( self, data_fields: list[str], partition_id: str, - socket: Optional[zmq.asyncio.Socket] = None, - ) -> tuple[Optional[Tensor], Optional[Tensor]]: + socket: zmq.asyncio.Socket | None = None, + ) -> tuple[Tensor | None, Tensor | None]: """Get production status for specific data fields and partition. Args: @@ -769,8 +769,8 @@ async def async_check_production_status( async def async_reset_consumption( self, partition_id: str, - task_name: Optional[str] = None, - socket: Optional[zmq.asyncio.Socket] = None, + task_name: str | None = None, + socket: zmq.asyncio.Socket | None = None, ) -> bool: """Asynchronously reset consumption status for a partition. @@ -830,7 +830,7 @@ async def async_reset_consumption( @with_controller_socket async def async_get_partition_list( self, - socket: Optional[zmq.asyncio.Socket] = None, + socket: zmq.asyncio.Socket | None = None, ) -> list[str]: """Asynchronously fetch the list of partition ids from the controller. @@ -879,7 +879,7 @@ async def async_kv_retrieve_meta( keys: list[str] | str, partition_id: str, create: bool = False, - socket: Optional[zmq.asyncio.Socket] = None, + socket: zmq.asyncio.Socket | None = None, ) -> BatchMeta: """Asynchronously retrieve BatchMeta from the controller using user-specified keys. @@ -944,7 +944,7 @@ async def async_kv_retrieve_keys( self, global_indexes: list[int] | int, partition_id: str, - socket: Optional[zmq.asyncio.Socket] = None, + socket: zmq.asyncio.Socket | None = None, ) -> list[str]: """Asynchronously retrieve keys according to global_indexes from the controller. 
@@ -1005,8 +1005,8 @@ async def async_kv_retrieve_keys( @with_controller_socket async def async_kv_list( self, - partition_id: Optional[str] = None, - socket: Optional[zmq.asyncio.Socket] = None, + partition_id: str | None = None, + socket: zmq.asyncio.Socket | None = None, ) -> dict[str, dict[str, Any]]: """Asynchronously retrieve keys and custom_meta from the controller for one or all partitions. @@ -1145,8 +1145,8 @@ def get_meta( batch_size: int, partition_id: str, mode: str = "fetch", - task_name: Optional[str] = None, - sampling_config: Optional[dict[str, Any]] = None, + task_name: str | None = None, + sampling_config: dict[str, Any] | None = None, ) -> BatchMeta: """Synchronously fetch data metadata from the controller via ZMQ. @@ -1234,9 +1234,9 @@ def set_custom_meta(self, metadata: BatchMeta) -> None: def put( self, data: TensorDict, - metadata: Optional[BatchMeta] = None, - partition_id: Optional[str] = None, - data_parser: Optional[Callable[[Any], Any]] = None, + metadata: BatchMeta | None = None, + partition_id: str | None = None, + data_parser: Callable[[Any], Any] | None = None, ) -> BatchMeta: """Synchronously write data to storage units based on metadata. @@ -1356,7 +1356,7 @@ def get_consumption_status( self, task_name: str, partition_id: str, - ) -> tuple[Optional[Tensor], Optional[Tensor]]: + ) -> tuple[Tensor | None, Tensor | None]: """Synchronously get consumption status for a specific task and partition. Args: @@ -1384,7 +1384,7 @@ def get_production_status( self, data_fields: list[str], partition_id: str, - ) -> tuple[Optional[Tensor], Optional[Tensor]]: + ) -> tuple[Tensor | None, Tensor | None]: """Synchronously get production status for specific data fields and partition. 
Args: @@ -1454,7 +1454,7 @@ def check_production_status(self, data_fields: list[str], partition_id: str) -> """ return self._check_production_status(data_fields=data_fields, partition_id=partition_id) - def reset_consumption(self, partition_id: str, task_name: Optional[str] = None) -> bool: + def reset_consumption(self, partition_id: str, task_name: str | None = None) -> bool: """Synchronously reset consumption status for a partition. This allows the same data to be re-consumed, useful for debugging scenarios @@ -1540,7 +1540,7 @@ def kv_retrieve_keys( def kv_list( self, - partition_id: Optional[str] = None, + partition_id: str | None = None, ) -> dict[str, dict[str, Any]]: """Synchronously retrieve keys and custom_meta from the controller for one or all partitions. diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 96f971c4..f1e1c1c9 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -21,7 +21,7 @@ from itertools import groupby from operator import itemgetter from threading import Lock, Thread -from typing import Any, Optional +from typing import Any from uuid import uuid4 import numpy as np @@ -195,10 +195,10 @@ class FieldMeta: """ global_indexes: set[int] = field(default_factory=set) - dtype: Optional[Any] = None - shape: Optional[tuple] = None # None when is_nested=True - is_nested: Optional[bool] = None - is_non_tensor: Optional[bool] = None + dtype: Any | None = None + shape: tuple | None = None # None when is_nested=True + is_nested: bool | None = None + is_non_tensor: bool | None = None per_sample_shapes: dict[int, tuple] = field(default_factory=dict) # {global_idx: shape} @@ -495,7 +495,7 @@ def update_production_status( global_indices: list[int], field_names: list[str], field_schema: dict[str, dict[str, Any]], - custom_backend_meta: Optional[dict[int, dict[str, Any]]] = None, + custom_backend_meta: dict[int, dict[str, Any]] | None = None, ) -> bool: """ Update production status for specific samples 
and fields. @@ -560,7 +560,7 @@ def _update_field_metadata( self, global_indexes: list[int], field_schema: dict[str, dict[str, Any]], - custom_backend_meta: Optional[dict[int, dict[str, Any]]] = None, + custom_backend_meta: dict[int, dict[str, Any]] | None = None, ): """Update field metadata from columnar field_schema.""" if not global_indexes: @@ -645,7 +645,7 @@ def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Te consumption_status = self.consumption_status[task_name] return partition_global_index, consumption_status - def reset_consumption(self, task_name: Optional[str] = None): + def reset_consumption(self, task_name: str | None = None): """ Reset consumption status for a specific task or all tasks. @@ -670,7 +670,7 @@ def reset_consumption(self, task_name: Optional[str] = None): # ==================== Production Status Interface ==================== def get_production_status_for_fields( self, field_names: list[str], mask: bool = False - ) -> tuple[Optional[Tensor], Optional[Tensor]]: + ) -> tuple[Tensor | None, Tensor | None]: """ Check if all samples for specified fields are fully produced and ready. @@ -1049,7 +1049,7 @@ def create_partition(self, partition_id: str) -> bool: logger.info(f"Created partition {partition_id} with {TQ_PRE_ALLOC_SAMPLE_NUM} pre-allocated indexes") return True - def _get_partition(self, partition_id: str) -> Optional[DataPartitionStatus]: + def _get_partition(self, partition_id: str) -> DataPartitionStatus | None: """ Get partition status information. @@ -1061,7 +1061,7 @@ def _get_partition(self, partition_id: str) -> Optional[DataPartitionStatus]: """ return self.partitions.get(partition_id) - def get_partition_snapshot(self, partition_id: str) -> Optional[DataPartitionStatus]: + def get_partition_snapshot(self, partition_id: str) -> DataPartitionStatus | None: """ Get a copy of partition status information, without threading.Lock(). 
@@ -1109,7 +1109,7 @@ def update_production_status( partition_id: str, global_indexes: list[int], field_schema: dict[str, dict[str, Any]], - custom_backend_meta: Optional[dict[int, dict[str, Any]]] = None, + custom_backend_meta: dict[int, dict[str, Any]] | None = None, ) -> bool: """ Update production status for specific samples and fields in a partition. @@ -1139,7 +1139,7 @@ def update_production_status( return success # ==================== Data Consumption API ==================== - def get_consumption_status(self, partition_id: str, task_name: str) -> tuple[Optional[Tensor], Optional[Tensor]]: + def get_consumption_status(self, partition_id: str, task_name: str) -> tuple[Tensor | None, Tensor | None]: """ Get or create consumption status for a specific task and partition. Delegates to the partition's own method. @@ -1159,9 +1159,7 @@ def get_consumption_status(self, partition_id: str, task_name: str) -> tuple[Opt return partition.get_consumption_status(task_name, mask=True) - def get_production_status( - self, partition_id: str, data_fields: list[str] - ) -> tuple[Optional[Tensor], Optional[Tensor]]: + def get_production_status(self, partition_id: str, data_fields: list[str]) -> tuple[Tensor | None, Tensor | None]: """ Check if all samples for specified fields are fully produced in a partition. @@ -1226,7 +1224,7 @@ def get_metadata( mode: str = "fetch", task_name: str | None = None, batch_size: int | None = None, - sampling_config: Optional[dict[str, Any]] = None, + sampling_config: dict[str, Any] | None = None, *args, **kwargs, ) -> BatchMeta: @@ -1494,7 +1492,7 @@ def clear_partition(self, partition_id: str, clear_consumption: bool = True): self.partitions.pop(partition_id) self.sampler.clear_cache(partition_id) - def reset_consumption(self, partition_id: str, task_name: Optional[str] = None): + def reset_consumption(self, partition_id: str, task_name: str | None = None): """ Reset consumption status for a partition without clearing the actual data. 
@@ -1641,7 +1639,7 @@ def kv_retrieve_keys( self, global_indexes: list[int], partition_id: str, - ) -> list[Optional[str]]: + ) -> list[str | None]: """ Retrieve keys from the controller using a list of global_indexes. diff --git a/transfer_queue/dataloader/streaming_dataloader.py b/transfer_queue/dataloader/streaming_dataloader.py index 3b08cd26..13277e11 100644 --- a/transfer_queue/dataloader/streaming_dataloader.py +++ b/transfer_queue/dataloader/streaming_dataloader.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional import torch from tensordict import TensorDict @@ -80,7 +79,7 @@ def __init__( pin_memory: bool = False, worker_init_fn=None, multiprocessing_context=None, - prefetch_factor: Optional[int] = None, + prefetch_factor: int | None = None, persistent_workers: bool = False, pin_memory_device: str = "", ): diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 4b154ab5..ea82c37f 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -18,7 +18,7 @@ import subprocess import time from importlib import resources -from typing import Any, Callable, Optional +from typing import Any, Callable from urllib.parse import urlparse import ray @@ -49,7 +49,7 @@ def _maybe_create_transferqueue_client( - conf: Optional[DictConfig] = None, + conf: DictConfig | None = None, ) -> TransferQueueClient: global _TRANSFER_QUEUE_CLIENT if _TRANSFER_QUEUE_CLIENT is None: @@ -225,7 +225,7 @@ def _init_from_existing() -> bool: # ==================== Initialization API ==================== -def init(conf: Optional[DictConfig] = None) -> Optional[DictConfig]: +def init(conf: DictConfig | None = None) -> DictConfig | None: """Initialize the TransferQueue system. This function sets up the TransferQueue controller, distributed storage, and client. 
@@ -385,9 +385,9 @@ def close(): def kv_put( key: str, partition_id: str, - fields: Optional[TensorDict | dict[str, Any]] = None, - tag: Optional[dict[str, Any]] = None, - data_parser: Optional[Callable[[Any], Any]] = None, + fields: TensorDict | dict[str, Any] | None = None, + tag: dict[str, Any] | None = None, + data_parser: Callable[[Any], Any] | None = None, ) -> KVBatchMeta: """Put a single key-value pair to TransferQueue. @@ -488,9 +488,9 @@ def kv_put( def kv_batch_put( keys: list[str], partition_id: str, - fields: Optional[TensorDict] = None, - tags: Optional[list[dict[str, Any]]] = None, - data_parser: Optional[Callable[[Any], Any]] = None, + fields: TensorDict | None = None, + tags: list[dict[str, Any]] | None = None, + data_parser: Callable[[Any], Any] | None = None, ) -> KVBatchMeta: """Put multiple key-value pairs to TransferQueue in batch. @@ -582,7 +582,7 @@ def kv_batch_put( ) -def kv_batch_get_by_meta(meta: KVBatchMeta, select_fields: Optional[list[str] | str] = None) -> TensorDict: +def kv_batch_get_by_meta(meta: KVBatchMeta, select_fields: list[str] | str | None = None) -> TensorDict: """Get data from TransferQueue using KVBatchMeta. 
This is a convenience method for retrieving data using KVBatchMeta returned @@ -622,7 +622,7 @@ def kv_batch_get_by_meta(meta: KVBatchMeta, select_fields: Optional[list[str] | raise ValueError("Must provide partition_id in the input KVBatchMeta.") if select_fields is not None: if isinstance(select_fields, str): - fields_to_fetch: Optional[list[str]] = [select_fields] + fields_to_fetch: list[str] | None = [select_fields] else: fields_to_fetch = select_fields @@ -637,9 +637,7 @@ def kv_batch_get_by_meta(meta: KVBatchMeta, select_fields: Optional[list[str] | return kv_batch_get(keys=meta.keys, partition_id=meta.partition_id, select_fields=fields_to_fetch) -def kv_batch_get( - keys: list[str] | str, partition_id: str, select_fields: Optional[list[str] | str] = None -) -> TensorDict: +def kv_batch_get(keys: list[str] | str, partition_id: str, select_fields: list[str] | str | None = None) -> TensorDict: """Get data from TransferQueue using user-specified keys. This is a convenience method for retrieving data using keys instead of indexes. @@ -690,7 +688,7 @@ def kv_batch_get( return data -def kv_list(partition_id: Optional[str] = None) -> dict[str, dict[str, Any]]: +def kv_list(partition_id: str | None = None) -> dict[str, dict[str, Any]]: """List all keys and their metadata in one or all partitions. Args: @@ -764,9 +762,9 @@ def kv_clear(keys: list[str] | str, partition_id: str) -> None: async def async_kv_put( key: str, partition_id: str, - fields: Optional[TensorDict | dict[str, Any]] = None, - tag: Optional[dict[str, Any]] = None, - data_parser: Optional[Callable[[Any], Any]] = None, + fields: TensorDict | dict[str, Any] | None = None, + tag: dict[str, Any] | None = None, + data_parser: Callable[[Any], Any] | None = None, ) -> KVBatchMeta: """Asynchronously put a single key-value pair to TransferQueue. 
@@ -868,9 +866,9 @@ async def async_kv_put( async def async_kv_batch_put( keys: list[str], partition_id: str, - fields: Optional[TensorDict] = None, - tags: Optional[list[dict[str, Any]]] = None, - data_parser: Optional[Callable[[Any], Any]] = None, + fields: TensorDict | None = None, + tags: list[dict[str, Any]] | None = None, + data_parser: Callable[[Any], Any] | None = None, ) -> KVBatchMeta: """Asynchronously put multiple key-value pairs to TransferQueue in batch. @@ -961,7 +959,7 @@ async def async_kv_batch_put( ) -async def async_kv_batch_get_by_meta(meta: KVBatchMeta, select_fields: Optional[list[str] | str] = None) -> TensorDict: +async def async_kv_batch_get_by_meta(meta: KVBatchMeta, select_fields: list[str] | str | None = None) -> TensorDict: """Asynchronously get data from TransferQueue using KVBatchMeta. This is a convenience method for retrieving data using KVBatchMeta returned @@ -1019,7 +1017,7 @@ async def async_kv_batch_get_by_meta(meta: KVBatchMeta, select_fields: Optional[ async def async_kv_batch_get( - keys: list[str] | str, partition_id: str, select_fields: Optional[list[str] | str] = None + keys: list[str] | str, partition_id: str, select_fields: list[str] | str | None = None ) -> TensorDict: """Asynchronously get data from TransferQueue using user-specified keys. @@ -1070,7 +1068,7 @@ async def async_kv_batch_get( return data -async def async_kv_list(partition_id: Optional[str] = None) -> dict[str, dict[str, Any]]: +async def async_kv_list(partition_id: str | None = None) -> dict[str, dict[str, Any]]: """Asynchronously list all keys and their metadata in one or all partitions. 
Args: diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 64fc18e2..05cf3241 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -19,7 +19,7 @@ from collections import defaultdict from dataclasses import dataclass from types import MappingProxyType -from typing import Any, Optional +from typing import Any import numpy as np import torch @@ -221,11 +221,11 @@ def __init__( self, global_indexes: list[int], partition_ids: list[str], - field_schema: Optional[dict[str, dict[str, Any]]] = None, - production_status: Optional[np.ndarray] = None, - extra_info: Optional[dict[str, Any]] = None, - custom_meta: Optional[list[dict[str, Any]]] = None, - _custom_backend_meta: Optional[list[dict[str, Any]]] = None, + field_schema: dict[str, dict[str, Any]] | None = None, + production_status: np.ndarray | None = None, + extra_info: dict[str, Any] | None = None, + custom_meta: list[dict[str, Any]] | None = None, + _custom_backend_meta: list[dict[str, Any]] | None = None, ) -> None: if field_schema is None: field_schema = {} @@ -785,7 +785,7 @@ def reorder(self, indices: list[int]): self._custom_backend_meta = [self._custom_backend_meta[i] for i in indices] @classmethod - def empty(cls, extra_info: Optional[dict[str, Any]] = None) -> "BatchMeta": + def empty(cls, extra_info: dict[str, Any] | None = None) -> "BatchMeta": """Create an empty BatchMeta with no samples. 
Args: @@ -828,13 +828,13 @@ class KVBatchMeta: tags: list[dict] = dataclasses.field(default_factory=list) # [optional] partition_id of this batch - partition_id: Optional[str] = None + partition_id: str | None = None # [optional] fields of each sample - fields: Optional[list[str]] = None + fields: list[str] | None = None # [optional] external information for batch-level information - extra_info: Optional[dict[str, Any]] = dataclasses.field(default_factory=dict) + extra_info: dict[str, Any] | None = dataclasses.field(default_factory=dict) def __post_init__(self): """Validate all the variables""" diff --git a/transfer_queue/sampler/base.py b/transfer_queue/sampler/base.py index 766afd62..c831ecba 100644 --- a/transfer_queue/sampler/base.py +++ b/transfer_queue/sampler/base.py @@ -14,7 +14,7 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any class BaseSampler(ABC): @@ -79,7 +79,7 @@ def has_cached_result( self, partition_id: str, task_name: str, - sampling_config: Optional[dict[str, Any]] = None, + sampling_config: dict[str, Any] | None = None, ) -> bool: """Check whether the sampler has a cached sampling result for the given context. diff --git a/transfer_queue/storage/clients/base.py b/transfer_queue/storage/clients/base.py index 9c475493..14b73e72 100644 --- a/transfer_queue/storage/clients/base.py +++ b/transfer_queue/storage/clients/base.py @@ -14,7 +14,7 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any class StorageKVClient(ABC): @@ -32,7 +32,7 @@ def __init__(self, config: dict[str, Any]): self.config = config @abstractmethod - def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]: + def put(self, keys: list[str], values: list[Any]) -> list[Any] | None: """ Store key-value pairs in the storage backend. 
Args: diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index 7a47111c..bcfe3848 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -15,7 +15,7 @@ import pickle from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Any, Optional +from typing import Any import torch from torch import Tensor @@ -168,9 +168,9 @@ def _put_bytes_thread_worker(self, batch_keys: list[str], batch_values: list[Any def get( self, keys: list[str], - shapes: Optional[list[Any]] = None, - dtypes: Optional[list[Any]] = None, - custom_backend_meta: Optional[list[str]] = None, + shapes: list[Any] | None = None, + dtypes: list[Any] | None = None, + custom_backend_meta: list[str] | None = None, ) -> list[Any]: """Get multiple key-value pairs from MooncakeStore. @@ -258,7 +258,7 @@ def _get_bytes_thread_worker(self, batch_keys: list[str], indexes: list[int]) -> return results, indexes - def clear(self, keys: list[str], custom_backend_meta: Optional[list[Any]] = None) -> None: + def clear(self, keys: list[str], custom_backend_meta: list[Any] | None = None) -> None: """Deletes multiple keys from MooncakeStore. Args: diff --git a/transfer_queue/storage/clients/ray_storage_client.py b/transfer_queue/storage/clients/ray_storage_client.py index db901533..c85fa438 100644 --- a/transfer_queue/storage/clients/ray_storage_client.py +++ b/transfer_queue/storage/clients/ray_storage_client.py @@ -14,7 +14,7 @@ # limitations under the License. 
import itertools -from typing import Any, Optional +from typing import Any import ray import torch @@ -61,7 +61,7 @@ def __init__(self, config=None): except ValueError: self.storage_actor = RayObjectRefStorage.options(name="RayObjectRefStorage", get_if_exists=False).remote() - def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]: + def put(self, keys: list[str], values: list[Any]) -> list[Any] | None: """ Store tensors to remote storage. Args: diff --git a/transfer_queue/storage/clients/yuanrong_client.py b/transfer_queue/storage/clients/yuanrong_client.py index aa644065..399cb58a 100644 --- a/transfer_queue/storage/clients/yuanrong_client.py +++ b/transfer_queue/storage/clients/yuanrong_client.py @@ -62,7 +62,7 @@ def supports_get(self, strategy_tag: Any) -> bool: """Check if this strategy can retrieve data with given tag.""" @abstractmethod - def get(self, keys: list[str], **kwargs) -> list[Optional[Any]]: + def get(self, keys: list[str], **kwargs) -> list[Any | None]: """Retrieve values by keys; kwargs may include shapes/dtypes.""" @abstractmethod @@ -144,7 +144,7 @@ def supports_get(self, strategy_tag: str) -> bool: """Matches 'DsTensorClient' Strategy tag.""" return isinstance(strategy_tag, str) and strategy_tag == self.strategy_tag() - def get(self, keys: list[str], **kwargs) -> list[Optional[Any]]: + def get(self, keys: list[str], **kwargs) -> list[Any | None]: """Fetch NPU tensors using pre-allocated empty buffers.""" shapes = kwargs.get("shapes", None) dtypes = kwargs.get("dtypes", None) @@ -251,7 +251,7 @@ def supports_get(self, strategy_tag: str) -> bool: """Matches 'KVClient' strategy tag.""" return isinstance(strategy_tag, str) and strategy_tag == self.strategy_tag() - def get(self, keys: list[str], **kwargs) -> list[Optional[Any]]: + def get(self, keys: list[str], **kwargs) -> list[Any | None]: """Retrieve and deserialize objects in batches.""" results = [] for i in range(0, len(keys), self.GET_CLEAR_KEYS_LIMIT): @@ -433,9 +433,9 
@@ def put_task(strategy, indexes): def get( self, keys: list[str], - shapes: Optional[list[Any]] = None, - dtypes: Optional[list[Any]] = None, - custom_backend_meta: Optional[list[str]] = None, + shapes: list[Any] | None = None, + dtypes: list[Any] | None = None, + custom_backend_meta: list[str] | None = None, ) -> list[Any]: """Retrieves multiple values from remote storage with expected metadata. @@ -475,7 +475,7 @@ def get_task(strategy, indexes): results[original_index] = value return results - def clear(self, keys: list[str], custom_backend_meta: Optional[list[str]] = None) -> None: + def clear(self, keys: list[str], custom_backend_meta: list[str] | None = None) -> None: """Deletes multiple keys from remote storage. Args: diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index d578873c..fe3d87bf 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -21,7 +21,7 @@ import weakref from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor -from typing import Any, Callable, Optional +from typing import Any, Callable from uuid import uuid4 import ray @@ -60,9 +60,9 @@ def __init__(self, controller_info: ZMQServerInfo, config: DictConfig): self.controller_info = controller_info # Handshake socket is sync (used only during initialization) - self.controller_handshake_socket: Optional[zmq.Socket] = None + self.controller_handshake_socket: zmq.Socket | None = None - self.zmq_context: Optional[zmq.asyncio.Context] = None + self.zmq_context: zmq.asyncio.Context | None = None self._connect_to_controller() def _connect_to_controller(self) -> None: @@ -190,7 +190,7 @@ async def notify_data_update( partition_id: str, global_indexes: list[int], field_schema: dict[str, dict[str, Any]], - custom_backend_meta: Optional[dict[int, dict[str, Any]]] = None, + custom_backend_meta: dict[int, dict[str, Any]] | None = None, ) -> None: """ Notify controller that new data 
is ready. @@ -301,7 +301,7 @@ async def notify_data_update( @abstractmethod async def put_data( - self, data: TensorDict, metadata: BatchMeta, data_parser: Optional[Callable[[Any], Any]] = None + self, data: TensorDict, metadata: BatchMeta, data_parser: Callable[[Any], Any] | None = None ) -> None: """ Put data into the storage backend. @@ -432,7 +432,7 @@ def __init__(self, controller_info: ZMQServerInfo, config: dict[str, Any]): raise ValueError("Missing client_name in config") super().__init__(controller_info, config) self.storage_client = StorageClientFactory.create(client_name, config) - self._multi_threads_executor: Optional[ThreadPoolExecutor] = None + self._multi_threads_executor: ThreadPoolExecutor | None = None self._executor_finalizer = weakref.finalize(self, self._shutdown_executor, self._multi_threads_executor) @staticmethod @@ -477,7 +477,7 @@ def _generate_values(data: TensorDict) -> list[Any]: return results @staticmethod - def _shutdown_executor(thread_executor: Optional[ThreadPoolExecutor]) -> None: + def _shutdown_executor(thread_executor: ThreadPoolExecutor | None) -> None: """ A static method to ensure no strong reference to 'self' is held within the finalizer's callback, enabling proper garbage collection. @@ -618,7 +618,7 @@ def _get_shape_type_custom_backend_meta_list( return shapes, dtypes, custom_backend_meta_list async def put_data( - self, data: TensorDict, metadata: BatchMeta, data_parser: Optional[Callable[[Any], Any]] = None + self, data: TensorDict, metadata: BatchMeta, data_parser: Callable[[Any], Any] | None = None ) -> None: """ Store tensor data in the backend storage and notify the controller. 
diff --git a/transfer_queue/storage/simple_storage.py b/transfer_queue/storage/simple_storage.py index 492037d1..771f22dc 100644 --- a/transfer_queue/storage/simple_storage.py +++ b/transfer_queue/storage/simple_storage.py @@ -17,7 +17,7 @@ import time import weakref from threading import Event, Thread -from typing import Any, Optional +from typing import Any from uuid import uuid4 import ray @@ -160,10 +160,10 @@ def __init__(self, storage_unit_size: int): self._shutdown_event = Event() # Placeholder for zmq_context, proxy_thread and worker_threads - self.zmq_context: Optional[zmq.Context] = None - self.put_get_socket: Optional[zmq.Socket] = None - self.proxy_thread: Optional[Thread] = None - self.worker_thread: Optional[Thread] = None + self.zmq_context: zmq.Context | None = None + self.put_get_socket: zmq.Socket | None = None + self.proxy_thread: Thread | None = None + self.worker_thread: Thread | None = None self._init_zmq_socket() self._start_process_put_get() @@ -481,10 +481,10 @@ def _handle_clear(self, data_parts: ZMQMessage) -> ZMQMessage: @staticmethod def _shutdown_resources( shutdown_event: Event, - worker_thread: Optional[Thread], - proxy_thread: Optional[Thread], - zmq_context: Optional[zmq.Context], - put_get_socket: Optional[zmq.Socket], + worker_thread: Thread | None, + proxy_thread: Thread | None, + zmq_context: zmq.Context | None, + put_get_socket: zmq.Socket | None, ) -> None: """Clean up resources on garbage collection.""" logger.info("Shutting down SimpleStorageUnit resources...") diff --git a/transfer_queue/utils/common.py b/transfer_queue/utils/common.py index 56f0dda8..fadcb241 100644 --- a/transfer_queue/utils/common.py +++ b/transfer_queue/utils/common.py @@ -15,7 +15,6 @@ import os from contextlib import contextmanager -from typing import Optional import psutil import ray @@ -46,7 +45,7 @@ def get_placement_group(num_ray_actors: int, num_cpus_per_actor: int = 1): @contextmanager -def 
limit_pytorch_auto_parallel_threads(target_num_threads: Optional[int] = None, info: str = ""): +def limit_pytorch_auto_parallel_threads(target_num_threads: int | None = None, info: str = ""): """Prevent PyTorch from overdoing the automatic parallelism during torch.stack() operation""" pytorch_current_num_threads = torch.get_num_threads() physical_cores = psutil.cpu_count(logical=False) diff --git a/transfer_queue/utils/yuanrong_utils.py b/transfer_queue/utils/yuanrong_utils.py index fd3afbd6..17d5ca82 100644 --- a/transfer_queue/utils/yuanrong_utils.py +++ b/transfer_queue/utils/yuanrong_utils.py @@ -19,7 +19,7 @@ import shutil import socket import subprocess -from typing import Any, Optional +from typing import Any import ray from omegaconf import DictConfig @@ -101,7 +101,7 @@ def check_port_connectivity(host: str, port: int, timeout: float = 2.0) -> bool: return False -def find_reachable_host(port: int, timeout: float = 1.0) -> Optional[str]: +def find_reachable_host(port: int, timeout: float = 1.0) -> str | None: """Find a reachable local host IP address for given port. Tries all local IP addresses in order and returns the first one @@ -126,7 +126,7 @@ def find_reachable_host(port: int, timeout: float = 1.0) -> Optional[str]: return None -def _parse_remote_h2d_device_ids(worker_args: str) -> Optional[str]: +def _parse_remote_h2d_device_ids(worker_args: str) -> str | None: """Parse --remote_h2d_device_ids parameter from worker_args string. 
Args: diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index 6216e7ec..606a4c32 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -17,7 +17,7 @@ import time from dataclasses import dataclass from functools import wraps -from typing import Any, Callable, Optional, TypeAlias +from typing import Any, Callable, TypeAlias from uuid import uuid4 import psutil @@ -146,7 +146,7 @@ def create( request_type: ZMQRequestType, sender_id: str, body: dict[str, Any], - receiver_id: Optional[str] = None, + receiver_id: str | None = None, ) -> "ZMQMessage": """Create ZMQMessage.""" return cls( @@ -253,7 +253,7 @@ def create_zmq_socket( ctx: zmq.Context, socket_type: Any, ip: str, - identity: Optional[bytestr] = None, + identity: bytestr | None = None, ) -> zmq.Socket: """Create ZMQ socket. @@ -299,9 +299,9 @@ def with_zmq_socket( socket_name: str, *, get_identity: Callable[[Any], str], - get_peer: Callable[[Any, Optional[str]], ZMQServerInfo], - resolve_target: Optional[Callable[[tuple, dict], Optional[str]]] = None, - timeout: Optional[int] = None, + get_peer: Callable[[Any, str | None], ZMQServerInfo], + resolve_target: Callable[[tuple, dict], str | None] | None = None, + timeout: int | None = None, ): """Create a reusable async decorator for request sockets. @@ -330,7 +330,7 @@ async def wrapper(self, *args, **kwargs): if owner_id is None: raise RuntimeError("get_identity returned None") - target_name: Optional[str] = None + target_name: str | None = None if resolve_target is not None: target_name = resolve_target(args, kwargs)