diff --git a/README.md b/README.md index b69176d8..2c7d4f0f 100644 --- a/README.md +++ b/README.md @@ -73,10 +73,10 @@ This class encapsulates the core interaction logic within the TransferQueue syst Currently, we support the following storage backends: -- SimpleStorageUnit: A basic CPU memory storage with minimal data format constraints and easy usability. +- SimpleStorage: A basic CPU memory storage with minimal data format constraints and easy usability. - [Yuanrong](https://gitee.com/openeuler/yuanrong-datasystem) (beta, [#PR107](https://github.com/TransferQueue/TransferQueue/pull/107), [#PR96](https://github.com/TransferQueue/TransferQueue/pull/96)): An Ascend native data system that provides hierarchical storage interfaces including HBM/DRAM/SSD. -- [Mooncake Store](https://github.com/kvcache-ai/Mooncake) (alpha, [#PR162](https://github.com/TransferQueue/TransferQueue/pull/162)): A high-performance, KV-based hierarchical storage that supports RDMA transport between GPU and DRAM. -- [Ray Direct Transport](https://docs.ray.io/en/master/ray-core/direct-transport.html) (alpha, [#PR167](https://github.com/TransferQueue/TransferQueue/pull/167)): Ray's new feature that allows Ray to store and pass objects directly between Ray actors. +- [MooncakeStore](https://github.com/kvcache-ai/Mooncake) (alpha, [#PR162](https://github.com/TransferQueue/TransferQueue/pull/162)): A high-performance, KV-based hierarchical storage that supports RDMA transport between GPU and DRAM. +- [RayRDT](https://docs.ray.io/en/master/ray-core/direct-transport.html) (alpha, [#PR167](https://github.com/TransferQueue/TransferQueue/pull/167)): Ray's new feature that allows Ray to store and pass objects directly between Ray actors. Among them, `SimpleStorageUnit` serves as our default storage backend, coordinated by the `AsyncSimpleStorageManager` class. Each storage unit can be deployed on a separate node, allowing for distributed data management. 
diff --git a/pyproject.toml b/pyproject.toml index f8539700..35d65242 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -126,5 +126,6 @@ yuanrong = [ # This is the rough equivalent of package_data={'': ['version/*']} [tool.setuptools.package-data] transfer_queue = [ - "version/*", + "version/*", + "*.yaml" ] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 1da090c8..8cfccacb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ pyzmq hydra-core numpy<2.0.0 msgspec -psutil \ No newline at end of file +psutil +omegaconf \ No newline at end of file diff --git a/tests/test_async_simple_storage_manager.py b/tests/test_async_simple_storage_manager.py index 3949a0dc..5187e169 100644 --- a/tests/test_async_simple_storage_manager.py +++ b/tests/test_async_simple_storage_manager.py @@ -62,8 +62,7 @@ async def mock_async_storage_manager(): ) config = { - "storage_unit_infos": storage_unit_infos, - "controller_info": controller_info, + "zmq_info": storage_unit_infos, } # Mock the handshake process entirely to avoid ZMQ complexity @@ -199,8 +198,7 @@ async def test_async_storage_manager_mapping_functions(): ) config = { - "storage_unit_infos": storage_unit_infos, - "controller_info": controller_info, + "zmq_info": storage_unit_infos, } # Mock ZMQ operations @@ -230,7 +228,7 @@ async def test_async_storage_manager_mapping_functions(): mock_socket.recv_multipart = Mock(return_value=handshake_response.serialize()) # Create manager - manager = AsyncSimpleStorageManager(config) + manager = AsyncSimpleStorageManager(controller_info, config) # Test round-robin mapping for 3 storage units # global_index -> storage_unit mapping: 0->storage_0, 1->storage_1, 2->storage_2, @@ -266,7 +264,7 @@ async def test_async_storage_manager_error_handling(): } # Mock controller info - controller_infos = ZMQServerInfo( + controller_info = ZMQServerInfo( role=TransferQueueRole.CONTROLLER, id="controller_0", ip="127.0.0.1", @@ -274,8 +272,7 @@ async def 
test_async_storage_manager_error_handling(): ) config = { - "storage_unit_infos": storage_unit_infos, - "controller_info": controller_infos, + "zmq_info": storage_unit_infos, } # Mock ZMQ operations @@ -305,7 +302,7 @@ async def test_async_storage_manager_error_handling(): mock_socket.recv_multipart = Mock(return_value=handshake_response.serialize()) # Create manager - manager = AsyncSimpleStorageManager(config) + manager = AsyncSimpleStorageManager(controller_info, config) # Mock operations that raise exceptions manager._put_to_single_storage_unit = AsyncMock(side_effect=RuntimeError("Mock PUT error")) diff --git a/tests/test_client.py b/tests/test_client.py index 38f140b0..42cd63bc 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -318,9 +318,9 @@ def client_setup(mock_controller, mock_storage): ): config = { "controller_info": mock_controller.zmq_server_info, - "storage_unit_infos": {mock_storage.storage_id: mock_storage.zmq_server_info}, + "zmq_info": {mock_storage.storage_id: mock_storage.zmq_server_info}, } - client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config) + client.initialize_storage_manager(manager_type="SimpleStorage", config=config) # Mock all storage manager methods to avoid real ZMQ operations async def mock_put_data(data, metadata): @@ -411,9 +411,9 @@ def test_single_controller_multiple_storages(): ): config = { "controller_info": controller.zmq_server_info, - "storage_unit_infos": {s.storage_id: s.zmq_server_info for s in storages}, + "zmq_info": {s.storage_id: s.zmq_server_info for s in storages}, } - client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config) + client.initialize_storage_manager(manager_type="SimpleStorage", config=config) # Mock all storage manager methods to avoid real ZMQ operations async def mock_put_data(data, metadata): diff --git a/tests/test_kv_storage_manager.py b/tests/test_kv_storage_manager.py index 41296dd5..083e16d6 100644 --- 
a/tests/test_kv_storage_manager.py +++ b/tests/test_kv_storage_manager.py @@ -117,7 +117,7 @@ def test_merge_tensors_to_tensordict(mock_create, test_data): mock_client = MagicMock() mock_create.return_value = mock_client - manager = KVStorageManager(test_data["cfg"]) + manager = KVStorageManager(controller_info=MagicMock(), config=test_data["cfg"]) assert manager.storage_client is mock_client assert manager._multi_threads_executor is None @@ -296,9 +296,9 @@ def test_put_data_with_custom_meta_from_storage_client(mock_notify, test_data_fo mock_storage_client.put.return_value = mock_custom_meta # Create manager with mocked dependencies - config = {"controller_info": MagicMock(), "client_name": "MockClient"} + config = {"client_name": "MockClient"} with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=mock_storage_client): - manager = KVStorageManager(config) + manager = KVStorageManager(controller_info=MagicMock(), config=config) # Run put_data asyncio.run(manager.put_data(test_data_for_put_data["data"], test_data_for_put_data["metadata"])) @@ -348,7 +348,7 @@ def test_put_data_without_custom_meta(mock_notify, test_data_for_put_data): # Create manager with mocked dependencies config = {"controller_info": MagicMock(), "client_name": "MockClient"} with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=mock_storage_client): - manager = KVStorageManager(config) + manager = KVStorageManager(controller_info=MagicMock(), config=config) # Run put_data asyncio.run(manager.put_data(test_data_for_put_data["data"], test_data_for_put_data["metadata"])) @@ -371,7 +371,7 @@ def test_put_data_custom_meta_length_mismatch_raises_error(test_data_for_put_dat # Create manager with mocked dependencies config = {"controller_info": MagicMock(), "client_name": "MockClient"} with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=mock_storage_client): - manager = KVStorageManager(config) + manager = KVStorageManager(controller_info=MagicMock(), config=config) # 
Run put_data and expect ValueError with pytest.raises(ValueError) as exc_info: diff --git a/transfer_queue/__init__.py b/transfer_queue/__init__.py index e9259024..592ef0d2 100644 --- a/transfer_queue/__init__.py +++ b/transfer_queue/__init__.py @@ -15,12 +15,25 @@ import os -from .client import ( - TransferQueueClient, - process_zmq_server_info, -) +from .client import TransferQueueClient from .controller import TransferQueueController from .dataloader import StreamingDataLoader, StreamingDataset +from .interface import ( + async_clear_partition, + async_clear_samples, + async_get_data, + async_get_meta, + async_put, + async_set_custom_meta, + clear_partition, + clear_samples, + close, + get_data, + get_meta, + init, + put, + set_custom_meta, +) from .metadata import BatchMeta from .sampler import BaseSampler from .sampler.grpo_group_n_sampler import GRPOGroupNSampler @@ -28,9 +41,24 @@ from .sampler.sequential_sampler import SequentialSampler from .storage import SimpleStorageUnit from .utils.common import get_placement_group -from .utils.zmq_utils import ZMQServerInfo +from .utils.zmq_utils import ZMQServerInfo, process_zmq_server_info __all__ = [ + "init", + "get_meta", + "get_data", + "put", + "set_custom_meta", + "clear_samples", + "clear_partition", + "async_get_meta", + "async_get_data", + "async_put", + "async_set_custom_meta", + "async_clear_samples", + "async_clear_partition", + "close", +] + [ "TransferQueueClient", "StreamingDataset", "StreamingDataLoader", diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 24b4e559..cd9c6e84 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -18,23 +18,19 @@ import os import threading from functools import wraps -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional from uuid import uuid4 -import ray import torch import zmq import zmq.asyncio from tensordict import TensorDict from torch import Tensor -from transfer_queue.controller import 
TransferQueueController from transfer_queue.metadata import ( BatchMeta, ) from transfer_queue.storage import ( - SimpleStorageUnit, - TransferQueueStorageManager, TransferQueueStorageManagerFactory, ) from transfer_queue.utils.common import limit_pytorch_auto_parallel_threads @@ -95,11 +91,12 @@ def initialize_storage_manager( AsyncSimpleStorageManager, KVStorageManager (under development), etc. config: Configuration dictionary for the storage manager. For AsyncSimpleStorageManager, must contain the following required keys: - - controller_info: ZMQ server information about the controller - - storage_unit_infos: ZMQ server information about the storage units + - zmq_info: ZMQ server information about the storage units """ - self.storage_manager = TransferQueueStorageManagerFactory.create(manager_type, config) + self.storage_manager = TransferQueueStorageManagerFactory.create( + manager_type, controller_info=self._controller, config=config + ) # TODO (TQStorage): Provide a general dynamic socket function for both Client & Storage @huazhong. @staticmethod @@ -1041,6 +1038,7 @@ def get_meta( data_fields: list[str], batch_size: int, partition_id: str, + mode: str = "fetch", task_name: Optional[str] = None, sampling_config: Optional[dict[str, Any]] = None, ) -> BatchMeta: @@ -1098,6 +1096,7 @@ def get_meta( data_fields=data_fields, batch_size=batch_size, partition_id=partition_id, + mode=mode, task_name=task_name, sampling_config=sampling_config, ) @@ -1304,36 +1303,3 @@ def close(self) -> None: logger.warning(f"[{self.client_id}]: Error closing event loop: {e}") super().close() - - -def process_zmq_server_info( - handlers: dict[Any, Union["TransferQueueController", "TransferQueueStorageManager", "SimpleStorageUnit"]] - | Union["TransferQueueController", "TransferQueueStorageManager", "SimpleStorageUnit"], -): # noqa: UP007 - """Extract ZMQ server information from handler objects. 
- - Args: - handlers: Dictionary of handler objects (controllers, storage managers, or storage units), - or a single handler object - - Returns: - If handlers is a dictionary: Dictionary mapping handler names to their ZMQ server information - If handlers is a single object: ZMQ server information for that object - - Examples: - >>> # Single handler - >>> controller = TransferQueueController.remote(...) - >>> info = process_zmq_server_info(controller) - >>> - >>> # Multiple handlers - >>> handlers = {"storage_0": storage_0, "storage_1": storage_1} - >>> info_dict = process_zmq_server_info(handlers)""" - # Handle single handler object case - if not isinstance(handlers, dict): - return ray.get(handlers.get_zmq_server_info.remote()) # type: ignore[union-attr, attr-defined] - else: - # Handle dictionary case - server_info = {} - for name, handler in handlers.items(): - server_info[name] = ray.get(handler.get_zmq_server_info.remote()) # type: ignore[union-attr, attr-defined] - return server_info diff --git a/transfer_queue/config.yaml b/transfer_queue/config.yaml new file mode 100644 index 00000000..9503afa4 --- /dev/null +++ b/transfer_queue/config.yaml @@ -0,0 +1,31 @@ +# This is the default configuration of TransferQueue. Users may modify the default values +# and use transfer_queue.init(conf) to override the config entries. + +controller: + # User-defined sampler. Users can pass a sampler instance to override this string config. + sampler: SequentialSampler + # Whether to return an empty BatchMeta to prevent request blocking when not enough data is available + polling_mode: False + # ZMQ Server IP & Ports (automatically generated during init) + zmq_info: null + + +backend: + # Pluggable storage/transport backend of TransferQueue. Choose from: + # SimpleStorage, Yuanrong, MooncakeStore, ... 
+ storage_backend: SimpleStorage + + # For SimpleStorage: + SimpleStorage: + # Total number of samples + total_storage_size: 100000 + # Number of distributed storage units for SimpleStorage backend + num_data_storage_units: 2 + # ZMQ Server IP & Ports (automatically generated during init) + zmq_info: null + + # For Yuanrong: + # TODO + + # For MooncakeStore: + # TODO \ No newline at end of file diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index c91eaf7c..763f3f49 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -28,6 +28,7 @@ import ray import torch import zmq +from omegaconf import DictConfig from ray.util import get_node_ip_address from torch import Tensor @@ -879,6 +880,7 @@ def __init__( self.controller_id = f"TQ_CONTROLLER_{uuid4().hex[:8]}" self.polling_mode = polling_mode + self.tq_config = None # global config for TransferQueue system # Initialize ZMQ sockets for communication self._init_zmq_socket() @@ -1772,3 +1774,35 @@ def _update_data_status(self): def get_zmq_server_info(self) -> ZMQServerInfo: """Get ZMQ server connection information.""" return self.zmq_server_info + + def store_config(self, conf: DictConfig) -> None: + """Store the global config of TransferQueue.""" + self.tq_config = conf + + def get_config(self) -> DictConfig: + """Retrieve the global config of TransferQueue.""" + return self.tq_config + + def register_sampler( + self, + sampler: BaseSampler | type[BaseSampler] = SequentialSampler, + ) -> None: + """ + Register a sampler instance or subclass after the controller is initialized. + + Args: + sampler: Sampler instance or sampler class to use for data sampling. 
+ - If a BaseSampler instance is provided, it will be used directly + - If a BaseSampler subclass is provided, it will be instantiated + - Defaults to SequentialSampler for simple sequential sampling + - Example: sampler=GRPOGroupNSampler() (instance) + - Example: sampler=SequentialSampler (class) + """ + if isinstance(sampler, BaseSampler): + self.sampler = sampler + elif isinstance(sampler, type) and issubclass(sampler, BaseSampler): + self.sampler = sampler() + else: + raise TypeError( + f"sampler {getattr(sampler, '__name__', repr(sampler))} must be an instance or subclass of BaseSampler" + ) diff --git a/transfer_queue/dataloader/streaming_dataset.py b/transfer_queue/dataloader/streaming_dataset.py index eb2afe49..1de90c5b 100644 --- a/transfer_queue/dataloader/streaming_dataset.py +++ b/transfer_queue/dataloader/streaming_dataset.py @@ -17,8 +17,10 @@ import os import time import uuid -from typing import Any, Callable, Iterator +import warnings +from typing import Callable, Iterator +from omegaconf import DictConfig from tensordict import TensorDict from torch.utils.data import IterableDataset @@ -68,7 +70,7 @@ class StreamingDataset(IterableDataset): def __init__( self, - config: dict[str, Any], + config: DictConfig, batch_size: int, micro_batch_size: int, data_fields: list[str], @@ -82,8 +84,8 @@ def __init__( Args: config: Configuration dictionary containing: - - controller_info: ZMQServerInfo for the TransferQueueController - - storage_backend: Storage backend type (e.g., "AsyncSimpleStorageManager") + - controller.zmq_info: ZMQServerInfo for the TransferQueueController + - backend.storage_backend: Storage backend type (e.g., "SimpleStorage") - Other backend-specific configuration batch_size: Batch size for data loading per iter. micro_batch_size: Number of samples per micro-batch. This is the batch size @@ -156,16 +158,44 @@ def _create_client(self): ValueError: If controller_info or storage_backend is missing or invalid. 
""" client_id = uuid.uuid4().hex[:8] - controller_info = self.config.get("controller_info", None) + + # TODO: DEPRECATE in future + controller_config = self.config.get("controller", None) + if controller_config: + controller_info = controller_config.get("zmq_info", None) + else: + controller_info = self.config.get("controller_info", None) + if controller_info: + warnings.warn( + "Config entry `controller_info` will be deprecated in 0.1.7, please " + "use `controller.zmq_info` instead.", + category=DeprecationWarning, + stacklevel=2, + ) + if not controller_info or not isinstance(controller_info, ZMQServerInfo): - raise ValueError("Invalid or missing controller_info in config") + raise ValueError("Invalid or missing controller.zmq_info in config") + + backend_config = self.config.get("backend", None) + if not backend_config: + storage_backend = self.config.get("storage_backend", None) + backend_config = self.config + if storage_backend: + warnings.warn( + "Config entry `storage_backend` will be deprecated in 0.1.7, please " + "use `backend.storage_backend` instead.", + category=DeprecationWarning, + stacklevel=2, + ) + else: + storage_backend = backend_config.get("storage_backend", None) + backend_config = self.config.backend[storage_backend] - storage_backend = self.config.get("storage_backend", None) if not storage_backend: raise ValueError("Missing storage_backend in config") self._tq_client = TransferQueueClient(client_id, controller_info) - self._tq_client.initialize_storage_manager(manager_type=storage_backend, config=self.config) + self._tq_client.initialize_storage_manager(manager_type=storage_backend, config=backend_config) def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]: """Iterate over the dataset, yielding batches of data. 
diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py new file mode 100644 index 00000000..37ebdb23 --- /dev/null +++ b/transfer_queue/interface.py @@ -0,0 +1,649 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.resources as pkg_resources +import logging +import math +import os +import time +from typing import Any, Optional + +import ray +from omegaconf import DictConfig, OmegaConf +from tensordict import TensorDict + +from transfer_queue.client import TransferQueueClient +from transfer_queue.controller import TransferQueueController +from transfer_queue.metadata import BatchMeta +from transfer_queue.sampler import * # noqa: F401 +from transfer_queue.sampler import BaseSampler +from transfer_queue.storage.simple_backend import SimpleStorageUnit +from transfer_queue.utils.common import get_placement_group +from transfer_queue.utils.zmq_utils import process_zmq_server_info + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) + +_TRANSFER_QUEUE_CLIENT: Any = None +_TRANSFER_QUEUE_STORAGE: Any = None + + +def _maybe_create_transferqueue_client( + conf: Optional[DictConfig] = None, +) -> TransferQueueClient: + global _TRANSFER_QUEUE_CLIENT + if _TRANSFER_QUEUE_CLIENT is None: + if conf is None: + raise ValueError("Missing config for initializing TransferQueueClient!") + pid = 
os.getpid() + _TRANSFER_QUEUE_CLIENT = TransferQueueClient( + client_id=f"TransferQueueClient_{pid}", controller_info=conf.controller.zmq_info + ) + + backend_name = conf.backend.storage_backend + + _TRANSFER_QUEUE_CLIENT.initialize_storage_manager(manager_type=backend_name, config=conf.backend[backend_name]) + + return _TRANSFER_QUEUE_CLIENT + + +def _maybe_create_transferqueue_storage(conf: DictConfig) -> DictConfig: + global _TRANSFER_QUEUE_STORAGE + + if _TRANSFER_QUEUE_STORAGE is None: + _TRANSFER_QUEUE_STORAGE = {} + if conf.backend.storage_backend == "SimpleStorage": + # initialize SimpleStorageUnit + num_data_storage_units = conf.backend.SimpleStorage.num_data_storage_units + total_storage_size = conf.backend.SimpleStorage.total_storage_size + storage_placement_group = get_placement_group(num_data_storage_units, num_cpus_per_actor=1) + + for storage_unit_rank in range(num_data_storage_units): + storage_node = SimpleStorageUnit.options( # type: ignore[attr-defined] + placement_group=storage_placement_group, + placement_group_bundle_index=storage_unit_rank, + name=f"TransferQueueStorageUnit#{storage_unit_rank}", + lifetime="detached", + ).remote(storage_unit_size=math.ceil(total_storage_size / num_data_storage_units)) + _TRANSFER_QUEUE_STORAGE[f"TransferQueueStorageUnit#{storage_unit_rank}"] = storage_node + logger.info(f"TransferQueueStorageUnit#{storage_unit_rank} has been created.") + + storage_zmq_info = process_zmq_server_info(_TRANSFER_QUEUE_STORAGE) + backend_name = conf.backend.storage_backend + conf.backend[backend_name].zmq_info = storage_zmq_info + + return conf + + +def _init_from_existing() -> None: + """Initialize the TransferQueueClient from existing controller.""" + + controller = ray.get_actor("TransferQueueController") + logger.info("Found existing TransferQueueController instance. 
Connecting...") + + conf = None + while conf is None: + remote_conf = ray.get(controller.get_config.remote()) + if remote_conf is not None: + _maybe_create_transferqueue_client(remote_conf) + logger.info("TransferQueueClient initialized.") + return + + logger.debug("Waiting for controller to initialize... Retrying in 1s") + time.sleep(1) + + +def init(conf: Optional[DictConfig] = None) -> None: + """Initialize the TransferQueue system. + + This function sets up the TransferQueue controller, distributed storage, and client. + It should be called once at the beginning of the program before any data operations. + + If a controller already exists (e.g., initialized by another process), this function + will retrieve the config from existing controller and initialize the TransferQueueClient. + In this case, the `conf` parameter will be ignored. + + Args: + conf: Optional configuration dictionary. If provided, it will be merged with + the default config from 'config.yaml'. This is only used for first-time + initializing. When connecting to an existing controller, this parameter + is ignored. + + Raises: + ValueError: If config is not valid or required configuration keys are missing. + + Example: + >>> # In process 0, node A + >>> import transfer_queue as tq + >>> tq.init() # Initialize the TransferQueue + >>> tq.put(...) # then you can use tq for data operations + >>> + >>> # In process 1, node B (with Ray connected to node A) + >>> import transfer_queue as tq + >>> tq.init() # This will only initialize a TransferQueueClient and link with existing TQ + >>> metadata = tq.get_meta(...) + >>> data = tq.get_data(metadata) + """ + try: + _init_from_existing() + except ValueError: + logger.info("No TransferQueueController found. 
Starting first-time initialization...") + else: + return + + # First-time initialize TransferQueue + + # create config + final_conf = OmegaConf.create({}, flags={"allow_objects": True}) + with pkg_resources.path("transfer_queue", "config.yaml") as p: + default_conf = OmegaConf.load(p) + final_conf = OmegaConf.merge(final_conf, default_conf) + if conf: + final_conf = OmegaConf.merge(final_conf, conf) + + # create controller + try: + sampler = final_conf.controller.sampler + if isinstance(sampler, BaseSampler): + # user pass a pre-initialized sampler instance + sampler = sampler + elif isinstance(sampler, type) and issubclass(sampler, BaseSampler): + # user pass a sampler class + sampler = sampler() + elif isinstance(sampler, str): + # user pass a sampler name str + # try to convert as sampler class + sampler = globals()[final_conf.controller.sampler] + except KeyError: + raise ValueError(f"Could not find sampler {final_conf.controller.sampler}") from None + + try: + # Ray will make sure actor with same name can only be created once + controller = TransferQueueController.options(name="TransferQueueController", lifetime="detached").remote( # type: ignore[attr-defined] + sampler=sampler, polling_mode=final_conf.controller.polling_mode + ) + logger.info("TransferQueueController has been created.") + except ValueError: + logger.info("Some other rank has initialized TransferQueueController. 
Try to connect to existing controller.") + _init_from_existing() + return + + controller_zmq_info = process_zmq_server_info(controller) + final_conf.controller.zmq_info = controller_zmq_info + + # create distributed storage backends + final_conf = _maybe_create_transferqueue_storage(final_conf) + + # store the config into controller + ray.get(controller.store_config.remote(final_conf)) + logger.info(f"TransferQueue config: {final_conf}") + + # create client + _maybe_create_transferqueue_client(final_conf) + + +def get_meta( + data_fields: list[str], + batch_size: int, + partition_id: str, + mode: str = "fetch", + task_name: Optional[str] = None, + sampling_config: Optional[dict[str, Any]] = None, +) -> BatchMeta: + """Synchronously fetch data metadata from the controller via ZMQ. + + Args: + data_fields: List of data field names to retrieve metadata for + batch_size: Number of samples to request in the batch + partition_id: Current data partition id + mode: Data fetch mode. Options: + - 'fetch': Get ready data only + - 'force_fetch': Get data regardless of readiness (may return unready samples) + - 'insert': Internal usage - should not be used by users + task_name: Optional task name associated with the request + sampling_config: Optional sampling configuration for custom samplers. + + + Returns: + BatchMeta: Metadata object containing data structure, sample information, and readiness status + + Raises: + RuntimeError: If communication fails or controller returns error response + + Example: + >>> import transfer_queue as tq + >>> tq.init() + >>> + >>> # Example 1: Basic fetch metadata + >>> batch_meta = tq.get_meta( + ... data_fields=["input_ids", "attention_mask"], + ... batch_size=4, + ... partition_id="train_0", + ... mode="fetch", + ... task_name="generate_sequences" + ... 
) + >>> print(batch_meta.is_ready) # True if all samples ready + >>> + >>> # Example 2: Fetch with self-defined samplers (using GRPOGroupNSampler as an example) + >>> batch_meta = tq.get_meta( + ... data_fields=["input_ids", "attention_mask"], + ... batch_size=8, + ... partition_id="train_0", + ... mode="fetch", + ... task_name="generate_sequences", + ... sampling_config={"n_samples_per_prompt": 4} + ... ) + >>> print(batch_meta.is_ready) # True if all samples ready + >>> + >>> # Example 3: Force fetch metadata (bypass production status check and Sampler, + >>> # so may include unready and already-consumed samples. No filtering by consumption status is applied.) + >>> batch_meta = tq.get_meta( + ... partition_id="train_0", # optional + ... mode="force_fetch", + ... ) + >>> print(batch_meta.is_ready) # May be False if some samples not ready + """ + + tq_client = _maybe_create_transferqueue_client() + return tq_client.get_meta(data_fields, batch_size, partition_id, mode, task_name, sampling_config) + + +async def async_get_meta( + data_fields: list[str], + batch_size: int, + partition_id: str, + mode: str = "fetch", + task_name: Optional[str] = None, + sampling_config: Optional[dict[str, Any]] = None, +) -> BatchMeta: + """Asynchronously fetch data metadata from the controller via ZMQ. + + Args: + data_fields: List of data field names to retrieve metadata for + batch_size: Number of samples to request in the batch + partition_id: Current data partition id + mode: Data fetch mode. Options: + - 'fetch': Get ready data only + - 'force_fetch': Get data regardless of readiness (may return unready samples) + - 'insert': Internal usage - should not be used by users + task_name: Optional task name associated with the request + sampling_config: Optional sampling configuration for custom samplers. 
+ + + Returns: + BatchMeta: Metadata object containing data structure, sample information, and readiness status + + Raises: + RuntimeError: If communication fails or controller returns error response + + Example: + >>> import transfer_queue as tq + >>> tq.init() + >>> + >>> # Example 1: Basic fetch metadata + >>> batch_meta = asyncio.run(tq.async_get_meta( + ... data_fields=["input_ids", "attention_mask"], + ... batch_size=4, + ... partition_id="train_0", + ... mode="fetch", + ... task_name="generate_sequences" + ... )) + >>> print(batch_meta.is_ready) # True if all samples ready + >>> + >>> # Example 2: Fetch with self-defined samplers (using GRPOGroupNSampler as an example) + >>> batch_meta = asyncio.run(tq.async_get_meta( + ... data_fields=["input_ids", "attention_mask"], + ... batch_size=8, + ... partition_id="train_0", + ... mode="fetch", + ... task_name="generate_sequences", + ... )) + >>> print(batch_meta.is_ready) # True if all samples ready + >>> + >>> # Example 3: Force fetch metadata (bypass production status check and Sampler, + >>> # so may include unready and already-consumed samples. No filtering by consumption status is applied.) + >>> batch_meta = asyncio.run(tq.async_get_meta( + ... partition_id="train_0", # optional + ... mode="force_fetch", + ... )) + >>> print(batch_meta.is_ready) # May be False if some samples not ready + """ + + tq_client = _maybe_create_transferqueue_client() + return await tq_client.async_get_meta(data_fields, batch_size, partition_id, mode, task_name, sampling_config) + + +def get_data(metadata: BatchMeta) -> TensorDict: + """Synchronously fetch data from storage units and organize into TensorDict. 
+ + Args: + metadata: Batch metadata containing data location information and global indexes + + Returns: + TensorDict containing: + - Requested data fields (e.g., "prompts", "attention_mask") + + Example: + >>> import transfer_queue as tq + >>> tq.init() + >>> + >>> batch_meta = tq.get_meta( + ... data_fields=["prompts", "attention_mask"], + ... batch_size=4, + ... partition_id="train_0", + ... mode="fetch", + ... task_name="generate_sequences", + ... ) + >>> batch = tq.get_data(batch_meta) + >>> print(batch) + >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes + """ + tq_client = _maybe_create_transferqueue_client() + return tq_client.get_data(metadata) + + +async def async_get_data(metadata: BatchMeta) -> TensorDict: + """Asynchronously fetch data from storage units and organize into TensorDict. + + Args: + metadata: Batch metadata containing data location information and global indexes + + Returns: + TensorDict containing: + - Requested data fields (e.g., "prompts", "attention_mask") + + Example: + >>> import transfer_queue as tq + >>> tq.init() + >>> + >>> batch_meta = asyncio.run(tq.async_get_meta( + ... data_fields=["prompts", "attention_mask"], + ... batch_size=4, + ... partition_id="train_0", + ... mode="fetch", + ... task_name="generate_sequences", + ... )) + >>> batch = asyncio.run(tq.async_get_data(batch_meta)) + >>> print(batch) + >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes + """ + tq_client = _maybe_create_transferqueue_client() + return await tq_client.async_get_data(metadata) + + +def put(data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None) -> BatchMeta: + """Synchronously write data to storage units based on metadata. + + If metadata is not provided, it will be created automatically using insert mode + with the provided data fields and partition_id. 
+ + During put, the custom_meta in metadata will update the corresponding custom_meta in + TransferQueue Controller. + + Note: + When using multiple workers for distributed execution, there may be data + ordering inconsistencies between workers during put operations. + + Args: + data: Data to write as TensorDict + metadata: Records the metadata of a batch of data samples, containing index and + storage unit information. If None, metadata will be auto-generated. + partition_id: Target data partition id (required if metadata is not provided) + + Returns: + BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved + metadata; will be updated in a future version to reflect the post-put state) + + Raises: + ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided + RuntimeError: If storage operation fails + + Example: + >>> import transfer_queue as tq + >>> tq.init() + >>> + >>> batch_size = 4 + >>> seq_len = 16 + >>> current_partition_id = "train_0" + >>> # Example 1: Normal usage with existing metadata + >>> batch_meta = tq.get_meta( + ... data_fields=["prompts", "attention_mask"], + ... batch_size=batch_size, + ... partition_id=current_partition_id, + ... mode="fetch", + ... task_name="generate_sequences", + ... ) + >>> batch = tq.get_data(batch_meta) + >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)}) + >>> tq.put(data=output, metadata=batch_meta) + >>> + >>> # Example 2: Initial data insertion without pre-existing metadata + >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id! + >>> # Please make sure the corresponding partition_id is empty before calling put() + >>> # without metadata. + >>> # Currently we only support putting all the data of the corresponding partition id at once. You should + >>> # repeat-interleave the initial data if n_samples > 1 before calling put(). 
+ >>> original_prompts = torch.randn(batch_size, seq_len) + >>> n_samples = 4 + >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0) + >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated}) + >>> # This will create metadata in "insert" mode internally. + >>> metadata = tq.put(data=prompts_repeated_batch, partition_id=current_partition_id) + """ + tq_client = _maybe_create_transferqueue_client() + return tq_client.put(data, metadata, partition_id) + + +async def async_put( + data: TensorDict, + metadata: Optional[BatchMeta] = None, + partition_id: Optional[str] = None, +) -> BatchMeta: + """Asynchronously write data to storage units based on metadata. + + If metadata is not provided, it will be created automatically using insert mode + with the provided data fields and partition_id. + + During put, the custom_meta in metadata will update the corresponding custom_meta in + TransferQueue Controller. + + Note: + When using multiple workers for distributed execution, there may be data + ordering inconsistencies between workers during put operations. + + Args: + data: Data to write as TensorDict + metadata: Records the metadata of a batch of data samples, containing index and + storage unit information. If None, metadata will be auto-generated. 
+ partition_id: Target data partition id (required if metadata is not provided) + + Returns: + BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved + metadata; will be updated in a future version to reflect the post-put state) + + Raises: + ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided + RuntimeError: If storage operation fails + + Example: + >>> import transfer_queue as tq + >>> tq.init() + >>> + >>> batch_size = 4 + >>> seq_len = 16 + >>> current_partition_id = "train_0" + >>> # Example 1: Normal usage with existing metadata + >>> batch_meta = asyncio.run(tq.async_get_meta( + ... data_fields=["prompts", "attention_mask"], + ... batch_size=batch_size, + ... partition_id=current_partition_id, + ... mode="fetch", + ... task_name="generate_sequences", + ... )) + >>> batch = asyncio.run(tq.async_get_data(batch_meta)) + >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)}) + >>> asyncio.run(tq.async_put(data=output, metadata=batch_meta)) + >>> + >>> # Example 2: Initial data insertion without pre-existing metadata + >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id! + >>> # Please make sure the corresponding partition_id is empty before calling the async_put() + >>> # without metadata. + >>> # Currently we only support putting all the data of the corresponding partition id at once. You should + >>> # repeat-interleave the initial data if n_samples > 1 before calling the async_put(). + >>> original_prompts = torch.randn(batch_size, seq_len) + >>> n_samples = 4 + >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0) + >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated}) + >>> # This will create metadata in "insert" mode internally. 
+ >>> metadata = asyncio.run(tq.async_put(data=prompts_repeated_batch, partition_id=current_partition_id)) + """ + tq_client = _maybe_create_transferqueue_client() + return await tq_client.async_put(data, metadata, partition_id) + + +def set_custom_meta(metadata: BatchMeta) -> None: + """Synchronously send custom metadata to the controller. + + This method sends per-sample custom metadata (custom_meta) to the controller. + The custom_meta is stored in the controller and can be retrieved along with + the BatchMeta in subsequent get_meta calls. + + Args: + metadata: BatchMeta containing the samples and their custom metadata to store. + The custom_meta should be set using BatchMeta.update_custom_meta() or + BatchMeta.set_custom_meta() before calling this method. + + Raises: + RuntimeError: If communication fails or controller returns error response + + Example: + >>> import transfer_queue as tq + >>> tq.init() + >>> + >>> # Create batch with custom metadata + >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=4, ...) + >>> batch_meta.update_custom_meta({0: {"score": 0.9}, 1: {"score": 0.8}}) + >>> tq.set_custom_meta(batch_meta) + """ + tq_client = _maybe_create_transferqueue_client() + return tq_client.set_custom_meta(metadata) + + +async def async_set_custom_meta( + metadata: BatchMeta, +) -> None: + """ + Asynchronously send custom metadata to the controller. + + This method sends per-sample custom metadata (custom_meta) to the controller. + The custom_meta is stored in the controller and can be retrieved along with + the BatchMeta in subsequent get_meta calls. + + Args: + metadata: BatchMeta containing the samples and their custom metadata to store. + The custom_meta should be set using BatchMeta.update_custom_meta() or + BatchMeta.set_custom_meta() before calling this method. 
+ socket: ZMQ async socket for message transmission (injected by decorator) + + Raises: + RuntimeError: If communication fails or controller returns error response + + Example: + >>> import transfer_queue as tq + >>> tq.init() + >>> + >>> # Create batch with custom metadata + >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=4, ...) + >>> batch_meta.update_custom_meta({0: {"score": 0.9}, 1: {"score": 0.8}}) + >>> asyncio.run(tq.async_set_custom_meta(batch_meta)) + """ + tq_client = _maybe_create_transferqueue_client() + return await tq_client.async_set_custom_meta(metadata) + + +def clear_samples(metadata: BatchMeta): + """Synchronously clear specific samples from all storage units and the controller. + + Args: + metadata: The BatchMeta of the corresponding data to be cleared + + Raises: + RuntimeError: If clear operation fails + """ + tq_client = _maybe_create_transferqueue_client() + return tq_client.clear_samples(metadata) + + +async def async_clear_samples(metadata: BatchMeta): + """Asynchronously clear specific samples from all storage units and the controller. + + Args: + metadata: The BatchMeta of the corresponding data to be cleared + + Raises: + RuntimeError: If clear operation fails + """ + tq_client = _maybe_create_transferqueue_client() + return await tq_client.async_clear_samples(metadata) + + +def clear_partition(partition_id: str): + """Synchronously clear the whole partition from all storage units and the controller. + + Args: + partition_id: The partition id to clear data for + + Raises: + RuntimeError: If clear operation fails + """ + tq_client = _maybe_create_transferqueue_client() + return tq_client.clear_partition(partition_id) + + +async def async_clear_partition(partition_id: str): + """Asynchronously clear the whole partition from all storage units and the controller. 
+ + Args: + partition_id: The partition id to clear data for + + Raises: + RuntimeError: If clear operation fails + """ + tq_client = _maybe_create_transferqueue_client() + return await tq_client.async_clear_partition(partition_id) + + +def close(): + """Close the TransferQueue system.""" + global _TRANSFER_QUEUE_CLIENT + global _TRANSFER_QUEUE_STORAGE + if _TRANSFER_QUEUE_CLIENT: + _TRANSFER_QUEUE_CLIENT.close() + _TRANSFER_QUEUE_CLIENT = None + + try: + if _TRANSFER_QUEUE_STORAGE: + # only the process that do first-time init can clean the distributed storage + for storage in _TRANSFER_QUEUE_STORAGE.values(): + ray.kill(storage) + _TRANSFER_QUEUE_STORAGE = None + except Exception: + pass + + try: + controller = ray.get_actor("TransferQueueController") + ray.kill(controller) + except Exception: + pass diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index c07ec1f9..352b3adc 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -28,6 +28,7 @@ import ray import torch import zmq +from omegaconf import DictConfig from tensordict import NonTensorStack, TensorDict from torch import Tensor @@ -59,12 +60,10 @@ class TransferQueueStorageManager(ABC): """Base class for storage layer. 
It defines the interface for data operations and generally provides handshake & notification capabilities.""" - def __init__(self, config: dict[str, Any]): + def __init__(self, controller_info: ZMQServerInfo, config: DictConfig): self.storage_manager_id = f"TQ_STORAGE_{uuid4().hex[:8]}" self.config = config - controller_info = config.get("controller_info") - assert controller_info is not None, "controller_info is required" - self.controller_info: ZMQServerInfo = controller_info + self.controller_info = controller_info self.data_status_update_socket: Optional[zmq.Socket[bytes]] = None self.controller_handshake_socket: Optional[zmq.Socket[bytes]] = None @@ -351,14 +350,14 @@ class KVStorageManager(TransferQueueStorageManager): It maps structured metadata (BatchMeta) to flat lists of keys and values for efficient KV operations. """ - def __init__(self, config: dict[str, Any]): + def __init__(self, controller_info: ZMQServerInfo, config: dict[str, Any]): """ Initialize the KVStorageManager with configuration. """ client_name = config.get("client_name", None) if client_name is None: raise ValueError("Missing client_name in config") - super().__init__(config) + super().__init__(controller_info, config) self.storage_client = StorageClientFactory.create(client_name, config) self._multi_threads_executor: Optional[ThreadPoolExecutor] = None # Register a cleanup function: automatically invoke shutdown when the instance is garbage collected. diff --git a/transfer_queue/storage/managers/factory.py b/transfer_queue/storage/managers/factory.py index e595ccd8..04d2cd03 100644 --- a/transfer_queue/storage/managers/factory.py +++ b/transfer_queue/storage/managers/factory.py @@ -13,9 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import warnings from typing import Any from transfer_queue.storage.managers.base import TransferQueueStorageManager +from transfer_queue.utils.zmq_utils import ZMQServerInfo class TransferQueueStorageManagerFactory: @@ -39,10 +41,34 @@ def decorator(manager_cls: type[TransferQueueStorageManager]): return decorator @classmethod - def create(cls, manager_type: str, config: dict[str, Any]) -> TransferQueueStorageManager: + def create( + cls, manager_type: str, controller_info: ZMQServerInfo, config: dict[str, Any] + ) -> TransferQueueStorageManager: """Create and return a TransferQueueStorageManager instance.""" if manager_type not in cls._registry: - raise ValueError( - f"Unknown manager_type: {manager_type}. Supported managers include: {list(cls._registry.keys())}" - ) - return cls._registry[manager_type](config) + if manager_type == "AsyncSimpleStorageManager": + warnings.warn( + f"The manager_type {manager_type} will be deprecated in 0.1.7, please use SimpleStorage instead.", + category=DeprecationWarning, + stacklevel=2, + ) + manager_type = "SimpleStorage" + elif manager_type == "MooncakeStorageManager": + warnings.warn( + f"The manager_type {manager_type} will be deprecated in 0.1.7, please use MooncakeStore instead.", + category=DeprecationWarning, + stacklevel=2, + ) + manager_type = "MooncakeStore" + elif manager_type == "YuanrongStorageManager": + warnings.warn( + f"The manager_type {manager_type} will be deprecated in 0.1.7, please use Yuanrong instead.", + category=DeprecationWarning, + stacklevel=2, + ) + manager_type = "Yuanrong" + else: + raise ValueError( + f"Unknown manager_type: {manager_type}. 
Supported managers include: {list(cls._registry.keys())}" + ) + return cls._registry[manager_type](controller_info, config) diff --git a/transfer_queue/storage/managers/mooncake_manager.py b/transfer_queue/storage/managers/mooncake_manager.py index ca555668..9f6f93a6 100644 --- a/transfer_queue/storage/managers/mooncake_manager.py +++ b/transfer_queue/storage/managers/mooncake_manager.py @@ -19,16 +19,17 @@ from transfer_queue.storage.managers.base import KVStorageManager from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory +from transfer_queue.utils.zmq_utils import ZMQServerInfo logger = logging.getLogger(__name__) logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) -@TransferQueueStorageManagerFactory.register("MooncakeStorageManager") +@TransferQueueStorageManagerFactory.register("MooncakeStore") class MooncakeStorageManager(KVStorageManager): """Storage manager for MooncakeStorage backend.""" - def __init__(self, config: dict[str, Any]): + def __init__(self, controller_info: ZMQServerInfo, config: dict[str, Any]): # Required: Address of the HTTP metadata server (e.g., "localhost:8080") metadata_server = config.get("metadata_server", None) # Required: Address of the master server RPC endpoint (e.g., "localhost:8081") @@ -45,4 +46,4 @@ def __init__(self, config: dict[str, Any]): config["client_name"] = "MooncakeStorageClient" elif client_name != "MooncakeStorageClient": raise ValueError(f"Invalid 'client_name': {client_name} in config. 
Expecting 'MooncakeStorageClient'") - super().__init__(config) + super().__init__(controller_info, config) diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 5c6e68f0..f1420690 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -16,6 +16,7 @@ import asyncio import logging import os +import warnings from collections.abc import Mapping from functools import wraps from operator import itemgetter @@ -24,6 +25,7 @@ import torch import zmq +from omegaconf import DictConfig from tensordict import NonTensorStack, TensorDict from transfer_queue.metadata import BatchMeta @@ -48,7 +50,7 @@ TQ_ZERO_COPY_SERIALIZATION = get_env_bool("TQ_ZERO_COPY_SERIALIZATION", default=False) -@TransferQueueStorageManagerFactory.register("AsyncSimpleStorageManager") +@TransferQueueStorageManagerFactory.register("SimpleStorage") class AsyncSimpleStorageManager(TransferQueueStorageManager): """Asynchronous storage manager that handles multiple storage units. @@ -56,14 +58,23 @@ class AsyncSimpleStorageManager(TransferQueueStorageManager): instances using ZMQ communication and dynamic socket management. 
""" - def __init__(self, config: dict[str, Any]): - super().__init__(config) + def __init__(self, controller_info: ZMQServerInfo, config: DictConfig): + super().__init__(controller_info, config) self.config = config - server_infos: ZMQServerInfo | dict[str, ZMQServerInfo] | None = config.get("storage_unit_infos", None) + server_infos: ZMQServerInfo | dict[str, ZMQServerInfo] | None = config.get("zmq_info", None) if server_infos is None: - raise ValueError("AsyncSimpleStorageManager requires non-empty 'storage_unit_infos' in config.") + server_infos = config.get("storage_unit_infos", None) + if server_infos is not None: + warnings.warn( + "The config entry `storage_unit_infos` will be deprecated in 0.1.7, please use `zmq_info` instead.", + category=DeprecationWarning, + stacklevel=2, + ) + + if server_infos is None: + raise ValueError("AsyncSimpleStorageManager requires non-empty 'zmq_info' in config.") self.storage_unit_infos = self._register_servers(server_infos) self._build_storage_mapping_functions() @@ -277,7 +288,7 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict: metadata, self.global_index_storage_unit_mapping, self.global_index_local_index_mapping ) - # retrive data + # retrieve data tasks = [ self._get_from_single_storage_unit(meta_group, target_storage_unit=storage_id) for storage_id, meta_group in storage_meta_groups.items() diff --git a/transfer_queue/storage/managers/yuanrong_manager.py b/transfer_queue/storage/managers/yuanrong_manager.py index bfb79e6c..54ac0942 100644 --- a/transfer_queue/storage/managers/yuanrong_manager.py +++ b/transfer_queue/storage/managers/yuanrong_manager.py @@ -19,6 +19,7 @@ from transfer_queue.storage.managers.base import KVStorageManager from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory +from transfer_queue.utils.zmq_utils import ZMQServerInfo logger = logging.getLogger(__name__) logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) @@ -30,11 +31,11 @@ 
logger.addHandler(handler) -@TransferQueueStorageManagerFactory.register("YuanrongStorageManager") +@TransferQueueStorageManagerFactory.register("Yuanrong") class YuanrongStorageManager(KVStorageManager): """Storage manager for Yuanrong backend.""" - def __init__(self, config: dict[str, Any]): + def __init__(self, controller_info: ZMQServerInfo, config: dict[str, Any]): host = config.get("host", None) port = config.get("port", None) client_name = config.get("client_name", None) @@ -48,4 +49,4 @@ def __init__(self, config: dict[str, Any]): config["client_name"] = "YuanrongStorageClient" elif client_name != "YuanrongStorageClient": raise ValueError(f"Invalid 'client_name': {client_name} in config. Expecting 'YuanrongStorageClient'") - super().__init__(config) + super().__init__(controller_info, config) diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index 1f6ed922..bf711c8f 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -23,6 +23,7 @@ from uuid import uuid4 import psutil +import ray import zmq from transfer_queue.utils.common import ( @@ -239,3 +240,35 @@ def create_zmq_socket( if identity is not None: socket.setsockopt(zmq.IDENTITY, identity) return socket + + +def process_zmq_server_info( + handlers: dict[Any, Any] | Any, +): # noqa: UP007 + """Extract ZMQ server information from handler objects. + + Args: + handlers: Dictionary of handler objects (controllers, storage managers, or storage units), + or a single handler object + + Returns: + If handlers is a dictionary: Dictionary mapping handler names to their ZMQ server information + If handlers is a single object: ZMQ server information for that object + + Examples: + >>> # Single handler + >>> controller = TransferQueueController.remote(...) 
+ >>> info = process_zmq_server_info(controller) + >>> + >>> # Multiple handlers + >>> handlers = {"storage_0": storage_0, "storage_1": storage_1} + >>> info_dict = process_zmq_server_info(handlers)""" + # Handle single handler object case + if not isinstance(handlers, dict): + return ray.get(handlers.get_zmq_server_info.remote()) # type: ignore[union-attr, attr-defined] + else: + # Handle dictionary case + server_info = {} + for name, handler in handlers.items(): + server_info[name] = ray.get(handler.get_zmq_server_info.remote()) # type: ignore[union-attr, attr-defined] + return server_info diff --git a/tutorial/01_core_components.py b/tutorial/01_core_components.py index 25b530c7..59159a6d 100644 --- a/tutorial/01_core_components.py +++ b/tutorial/01_core_components.py @@ -37,87 +37,23 @@ import ray # noqa: E402 import torch # noqa: E402 -from omegaconf import OmegaConf # noqa: E402 from tensordict import TensorDict # noqa: E402 # Add the parent directory to the path parent_dir = Path(__file__).resolve().parent.parent sys.path.append(str(parent_dir)) -from transfer_queue import ( # noqa: E402 - SimpleStorageUnit, - TransferQueueClient, - TransferQueueController, - process_zmq_server_info, -) +import transfer_queue as tq # noqa: E402 # Configure Ray os.environ["RAY_DEDUP_LOGS"] = "0" os.environ["RAY_DEBUG"] = "1" - -def demonstrate_basic_setup(): - """ - Demonstrate the basic setup of TransferQueue with three core components. 
- """ - - # Initialize Ray - if not ray.is_initialized(): - ray.init() - - # Configuration - config = OmegaConf.create( - { - "num_data_storage_units": 2, - } - ) - - print("[Step 1] Creating Storage Backend (using default SimpleStorageUnit)...") - storage_units = {} - for i in range(config["num_data_storage_units"]): - storage_units[i] = SimpleStorageUnit.remote(storage_unit_size=100) - print(f" ✓ Created SimpleStorageUnit #{i}") - - print("[Step 2] Creating TransferQueueController...") - controller = TransferQueueController.remote() - print(" ✓ Controller created - manages data state") - - # Get server information - controller_info = process_zmq_server_info(controller) - storage_unit_infos = process_zmq_server_info(storage_units) - - # Create Client (User-facing API) - print("[Step 3] Creating TransferQueueClient...") - client = TransferQueueClient( - client_id="TutorialClient", - controller_info=controller_info, - ) - print(" ✓ Client created - this is what users interact with!") - - # Initialize storage manager - tq_config = OmegaConf.create({}, flags={"allow_objects": True}) - tq_config.controller_info = controller_info - tq_config.storage_unit_infos = storage_unit_infos - config = OmegaConf.merge(tq_config, config) - - client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config) - print( - " ✓ Storage manager initialized. It is a class variable inside the client, acting as an adapter to " - "suit for various storage backends." - ) - - print("[Architecture Summary]") - print( - " - TransferQueueController: Tracking the production/consumption status as metadata (can define your own " - "data consumption logics)." 
- ) - print(" - SimpleStorageUnit: Distributed data storage that holds actual data (easily swap out by other backends).") - print(" - TransferQueueClient: User interface that allows you to put/get/clear data or metadata)") - - return controller, storage_units, client +if not ray.is_initialized(): + ray.init(namespace="TransferQueueTutorial") -def demonstrate_data_workflow(client): +def demonstrate_data_workflow(): """ Demonstrate basic data workflow: put → get → clear. """ @@ -148,12 +84,12 @@ def demonstrate_data_workflow(client): print(f" Created {data_batch.batch_size[0]} samples") partition_id = "tutorial_partition_0" - client.put(data=data_batch, partition_id=partition_id) + tq.put(data=data_batch, partition_id=partition_id) print(f" ✓ Data written to partition: {partition_id}") # Step 2: Get metadata print("[Step 2] Requesting data metadata...") - batch_meta = client.get_meta( + batch_meta = tq.get_meta( data_fields=["input_ids", "attention_mask"], batch_size=data_batch.batch_size[0], partition_id=partition_id, @@ -164,7 +100,7 @@ def demonstrate_data_workflow(client): # Step 3: Get actual data print("[Step 3] Retrieving actual data...") - retrieved_data = client.get_data(batch_meta) + retrieved_data = tq.get_data(batch_meta) print(" ✓ Data retrieved successfully") print(f" Keys: {list(retrieved_data.keys())}") @@ -176,7 +112,7 @@ def demonstrate_data_workflow(client): # Step 5: Clear print("[Step 5] Clearing partition... (you may also use clear_samples() to clear specific samples)") - client.clear_partition(partition_id=partition_id) + tq.clear_partition(partition_id=partition_id) print(" ✓ Partition cleared") @@ -189,16 +125,16 @@ def demonstrate_storage_backend_options(): print("=" * 80) print("TransferQueue supports multiple storage backends:") - print("1. SimpleStorageUnit (default)") + print("1. 
SimpleStorage (default)") print(" - In-memory storage, fast and simple") print(" - Leveraging ZMQ for communication, with zero-copy serialization and transfer") print(" - No extra dependencies, good for development and testing") - print("2. YuanrongStorage") + print("2. Yuanrong") print(" - Ascend native distributed storage solution") print(" - Hierarchical storage interfaces including HBM/DRAM/SSD") - print("3. MoonCakeStore (on the way)") + print("3. MooncakeStore (on the way)") print(" - Support multiple transmission protocols") print(" - RDMA between DRAM and HBM") @@ -234,10 +170,10 @@ def main(): try: print("Setting up TransferQueue...") - controller, storage_units, client = demonstrate_basic_setup() + tq.init() print("Demonstrating the user workflow...") - demonstrate_data_workflow(client) + demonstrate_data_workflow() demonstrate_storage_backend_options() @@ -253,7 +189,7 @@ def main(): print("3. You can swap out different storage backends easily") # Cleanup - client.close() + tq.close() ray.shutdown() print("\n✓ Cleanup complete") diff --git a/tutorial/02_metadata_concepts.py b/tutorial/02_metadata_concepts.py index 93e59ac7..d864db39 100644 --- a/tutorial/02_metadata_concepts.py +++ b/tutorial/02_metadata_concepts.py @@ -38,19 +38,13 @@ import ray # noqa: E402 import torch # noqa: E402 -from omegaconf import OmegaConf # noqa: E402 from tensordict import TensorDict # noqa: E402 # Add the parent directory to the path parent_dir = Path(__file__).resolve().parent.parent sys.path.append(str(parent_dir)) -from transfer_queue import ( # noqa: E402 - SimpleStorageUnit, - TransferQueueClient, - TransferQueueController, - process_zmq_server_info, -) +import transfer_queue as tq # noqa: E402 from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta # noqa: E402 from transfer_queue.utils.enum_utils import ProductionStatus # noqa: E402 @@ -308,32 +302,8 @@ def demonstrate_real_workflow(): if not ray.is_initialized(): ray.init() - # Setup TransferQueue - 
config = OmegaConf.create( - { - "num_data_storage_units": 2, - } - ) - - storage_units = {} - for i in range(config["num_data_storage_units"]): - storage_units[i] = SimpleStorageUnit.remote(storage_unit_size=100) - - controller = TransferQueueController.remote() - controller_info = process_zmq_server_info(controller) - storage_unit_infos = process_zmq_server_info(storage_units) - - client = TransferQueueClient( - client_id="TutorialClient", - controller_info=controller_info, - ) - - tq_config = OmegaConf.create({}, flags={"allow_objects": True}) - tq_config.controller_info = controller_info - tq_config.storage_unit_infos = storage_unit_infos - config = OmegaConf.merge(tq_config, config) - - client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config) + # Initialize TransferQueue + tq.init() print("[Step 1] Putting data into TransferQueue...") input_ids = torch.randint(0, 1000, (8, 512)) @@ -348,7 +318,7 @@ def demonstrate_real_workflow(): ) partition_id = "demo_partition" - batch_meta = client.put(data=data_batch, partition_id=partition_id) + batch_meta = tq.put(data=data_batch, partition_id=partition_id) print(f"✓ Put {data_batch.batch_size[0]} samples into partition '{partition_id}', got BatchMeta back {batch_meta}.") print("[Step 2] [Optional] Setting sample-level custom_meta...") @@ -360,11 +330,11 @@ def demonstrate_real_workflow(): batch_meta.update_custom_meta(custom_meta) print(f"✓ Set custom_meta into BatchMeta: {batch_meta.get_all_custom_meta()}") - client.set_custom_meta(batch_meta) + tq.set_custom_meta(batch_meta) print("✓ Successful to store custom_meta into TQ controller. 
Now you can retrieve the custom_meta from anywhere.") print("[Step 3] Try to get metadata from TransferQueue from other places...") - batch_meta = client.get_meta( + batch_meta = tq.get_meta( data_fields=["input_ids", "attention_mask"], batch_size=8, partition_id=partition_id, @@ -383,7 +353,7 @@ def demonstrate_real_workflow(): print("✓ Selected 'input_ids' field only:") print(f" New field names: {selected_meta.field_names}") print(f" Samples still have same global indexes: {selected_meta.global_indexes}") - retrieved_data = client.get_data(selected_meta) + retrieved_data = tq.get_data(selected_meta) print(f" Retrieved data keys: {list(retrieved_data.keys())}") print("[Step 5] Select specific samples from the retrieved BatchMeta...") @@ -391,7 +361,7 @@ def demonstrate_real_workflow(): print("✓ Selected samples at indices [0, 2, 4, 6]:") print(f" New global indexes: {partial_meta.global_indexes}") print(f" Number of samples: {len(partial_meta)}") - retrieved_data = client.get_data(partial_meta) + retrieved_data = tq.get_data(partial_meta) print(f" Retrieved data samples: {retrieved_data}, all the data samples: {data_batch}") print("[Step 6] Demonstrate chunk operation...") @@ -399,12 +369,12 @@ def demonstrate_real_workflow(): print(f"✓ Chunked into {len(chunks)} parts:") for i, chunk in enumerate(chunks): print(f" Chunk {i}: {len(chunk)} samples, indexes={chunk.global_indexes}") - chunk_data = client.get_data(chunk) + chunk_data = tq.get_data(chunk) print(f" Chunk {i}: Retrieved chunk data: {chunk_data}") # Cleanup - client.clear_partition(partition_id=partition_id) - client.close() + tq.clear_partition(partition_id=partition_id) + tq.close() ray.shutdown() print("✓ Partition cleared and resources cleaned up") diff --git a/tutorial/03_understanding_controller.py b/tutorial/03_understanding_controller.py index 8b7e6d04..4ca426d4 100644 --- a/tutorial/03_understanding_controller.py +++ b/tutorial/03_understanding_controller.py @@ -36,59 +36,19 @@ import ray # noqa: 
E402 import torch # noqa: E402 -from omegaconf import OmegaConf # noqa: E402 from tensordict import TensorDict # noqa: E402 # Add the parent directory to the path parent_dir = Path(__file__).resolve().parent.parent sys.path.append(str(parent_dir)) -from transfer_queue import ( # noqa: E402 - SimpleStorageUnit, - TransferQueueClient, - TransferQueueController, - process_zmq_server_info, -) +import transfer_queue as tq # noqa: E402 # Configure Ray os.environ["RAY_DEDUP_LOGS"] = "0" os.environ["RAY_DEBUG"] = "1" -def setup_transfer_queue(): - """Setup TransferQueue components.""" - if not ray.is_initialized(): - ray.init() - - config = OmegaConf.create( - { - "num_data_storage_units": 2, - } - ) - - storage_units = {} - for i in range(config["num_data_storage_units"]): - storage_units[i] = SimpleStorageUnit.remote(storage_unit_size=100) - - controller = TransferQueueController.remote() - controller_info = process_zmq_server_info(controller) - storage_unit_infos = process_zmq_server_info(storage_units) - - client = TransferQueueClient( - client_id="TutorialClient", - controller_info=controller_info, - ) - - tq_config = OmegaConf.create({}, flags={"allow_objects": True}) - tq_config.controller_info = controller_info - tq_config.storage_unit_infos = storage_unit_infos - config = OmegaConf.merge(tq_config, config) - - client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config) - - return controller, storage_units, client - - def demonstrate_partition_isolation(): """Feature 1: Different partitions are isolated - data doesn't interfere.""" print("=" * 80) @@ -97,7 +57,10 @@ def demonstrate_partition_isolation(): print("\nDifferent partitions are completely isolated - data doesn't interfere between partitions") - controller, storage_units, client = setup_transfer_queue() + if not ray.is_initialized(): + ray.init(namespace="TransferQueueTutorial") + + tq.init() # Partition 1: Training data print("\n[Partition 1] Putting training data...") @@ 
-108,7 +71,7 @@ def demonstrate_partition_isolation(): }, batch_size=2, ) - client.put(data=train_data, partition_id="train") + tq.put(data=train_data, partition_id="train") print(" ✓ Training data added to 'train' partition") # Partition 2: Validation data @@ -120,25 +83,23 @@ def demonstrate_partition_isolation(): }, batch_size=2, ) - client.put(data=val_data, partition_id="val") + tq.put(data=val_data, partition_id="val") print(" ✓ Validation data added to 'val' partition") # Get from train partition print("\n[Retrieving from 'train' partition]") - train_meta = client.get_meta( + train_meta = tq.get_meta( data_fields=["input_ids", "labels"], batch_size=2, partition_id="train", task_name="train_task" ) - retrieved_train_data = client.get_data(train_meta) + retrieved_train_data = tq.get_data(train_meta) print(f" ✓ Got BatchMeta={train_meta} from partition 'train'") print(f" ✓ Retrieved Data: input_ids={retrieved_train_data['input_ids']}, labels={retrieved_train_data['labels']}") # Get from val partition print("\n[Retrieving from 'val' partition]") - val_meta = client.get_meta( - data_fields=["input_ids", "labels"], batch_size=2, partition_id="val", task_name="val_task" - ) - retrieved_val_data = client.get_data(val_meta) + val_meta = tq.get_meta(data_fields=["input_ids", "labels"], batch_size=2, partition_id="val", task_name="val_task") + retrieved_val_data = tq.get_data(val_meta) print(f" ✓ Got BatchMeta={val_meta} from partition 'val'") print(f" ✓ Retrieved Data: input_ids={retrieved_val_data['input_ids']}, labels={retrieved_val_data['labels']}") @@ -146,9 +107,9 @@ def demonstrate_partition_isolation(): print(" ✓ Data isolation: 'train' and 'val' partitions are completely independent") # Cleanup - client.clear_partition(partition_id="train") - client.clear_partition(partition_id="val") - client.close() + tq.clear_partition(partition_id="train") + tq.clear_partition(partition_id="val") + tq.close() ray.shutdown() @@ -160,7 +121,10 @@ def 
demonstrate_dynamic_expansion(): print("\nPartitions dynamically expand to accommodate new data (rows and columns)") - controller, storage_units, client = setup_transfer_queue() + if not ray.is_initialized(): + ray.init(namespace="TransferQueueTutorial") + + tq.init() # Add first batch with 2 samples, 2 fields print("\n[Step 1] Adding initial data (2 samples, 2 fields)...") @@ -171,7 +135,7 @@ def demonstrate_dynamic_expansion(): }, batch_size=2, ) - meta1 = client.put(data=data1, partition_id="dynamic") + meta1 = tq.put(data=data1, partition_id="dynamic") print(" ✓ Added 2 samples") print(f" ✓ Got BatchMeta: {meta1} samples") @@ -184,9 +148,9 @@ def demonstrate_dynamic_expansion(): }, batch_size=3, ) - meta2 = client.put(data=data2, partition_id="dynamic") + meta2 = tq.put(data=data2, partition_id="dynamic") - all_meta = client.get_meta( + all_meta = tq.get_meta( data_fields=["field1", "field2"], batch_size=5, partition_id="dynamic", task_name="dynamic_task" ) print(" ✓ Added 3 more samples (total: 5)") @@ -201,7 +165,7 @@ def demonstrate_dynamic_expansion(): }, batch_size=2, ) - meta3 = client.put(data=data3, metadata=meta1) + meta3 = tq.put(data=data3, metadata=meta1) print(" ✓ Added 2 samples with new field 'field3'") print(f" ✓ Got BatchMeta: {meta3} for newly put data with new field") @@ -210,8 +174,8 @@ def demonstrate_dynamic_expansion(): print(" ✓ Columns auto-expand: Can add new fields anytime") # Cleanup - client.clear_partition(partition_id="dynamic") - client.close() + tq.clear_partition(partition_id="dynamic") + tq.close() ray.shutdown() @@ -221,7 +185,10 @@ def demonstrate_default_consumption_sample_strategy(): print("Feature 3: Default Sampling Strategy for Controller - No Duplicate, Sequential Samples") print("=" * 80) - controller, storage_units, client = setup_transfer_queue() + if not ray.is_initialized(): + ray.init(namespace="TransferQueueTutorial") + + tq.init() # Add 6 samples print("\n[Setup] Adding 6 samples...") @@ -231,22 +198,22 @@ def 
demonstrate_default_consumption_sample_strategy(): }, batch_size=6, ) - client.put(data=all_data, partition_id="sampling") + tq.put(data=all_data, partition_id="sampling") print(" ✓ 6 samples added") # First get - should get samples 0,1,2 print("\n[Task A, Get 1] Requesting 3 samples...") - meta1 = client.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A") + meta1 = tq.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A") print(f" ✓ Got samples: {meta1.global_indexes}") # Second get - should get samples 3,4,5 (no duplicates!) print("\n[Task A, Get 2] Requesting 3 more samples...") - meta2 = client.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A") + meta2 = tq.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A") print(f" ✓ Got samples: {meta2.global_indexes}") # Third get - should get samples 0,1 print("\n[Task B, Get 1] Requesting 2 samples...") - meta3 = client.get_meta(data_fields=["data"], batch_size=2, partition_id="sampling", task_name="B") + meta3 = tq.get_meta(data_fields=["data"], batch_size=2, partition_id="sampling", task_name="B") print(f" ✓ Got samples: {meta3.global_indexes}") print("\n[Verification]") @@ -257,8 +224,8 @@ def demonstrate_default_consumption_sample_strategy(): print(" ✓ Third get (Task B): samples 0,1") # Cleanup - client.clear_partition(partition_id="sampling") - client.close() + tq.clear_partition(partition_id="sampling") + tq.close() ray.shutdown() diff --git a/tutorial/04_custom_sampler.py b/tutorial/04_custom_sampler.py index 7bf13cd1..c35b4e0a 100644 --- a/tutorial/04_custom_sampler.py +++ b/tutorial/04_custom_sampler.py @@ -49,12 +49,7 @@ parent_dir = Path(__file__).resolve().parent.parent sys.path.append(str(parent_dir)) -from transfer_queue import ( # noqa: E402 - SimpleStorageUnit, - TransferQueueClient, - TransferQueueController, - process_zmq_server_info, -) +import transfer_queue as tq # noqa: 
E402 from transfer_queue.sampler import BaseSampler # noqa: E402 @@ -171,36 +166,14 @@ def sample( def setup_transfer_queue_with_sampler(sampler): """Setup TransferQueue with custom sampler.""" if not ray.is_initialized(): - ray.init() + ray.init(namespace="TransferQueueTutorial") config = OmegaConf.create( - { - "global_batch_size": 8, - "num_data_storage_units": 2, - } - ) - - storage_units = {} - for i in range(2): - storage_units[i] = SimpleStorageUnit.remote(storage_unit_size=100) - - controller = TransferQueueController.remote(sampler=sampler) - controller_info = process_zmq_server_info(controller) - storage_unit_infos = process_zmq_server_info(storage_units) - - client = TransferQueueClient( - client_id="TutorialClient", - controller_info=controller_info, + {"controller": {"sampler": sampler}, "backend": {"SimpleStorage": {"num_data_storage_units": 2}}}, + flags={"allow_objects": True}, ) - tq_config = OmegaConf.create({}, flags={"allow_objects": True}) - tq_config.controller_info = controller_info - tq_config.storage_unit_infos = storage_unit_infos - config = OmegaConf.merge(tq_config, config) - - client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config) - - return controller, storage_units, client + tq.init(config) def demonstrate_random_sampler_with_replacement(): @@ -211,7 +184,7 @@ def demonstrate_random_sampler_with_replacement(): print("\nSetup TransferQueue with RandomSamplerWithReplacement...") sampler = RandomSamplerWithReplacement() - controller, storage_units, client = setup_transfer_queue_with_sampler(sampler) + setup_transfer_queue_with_sampler(sampler) # Add 5 samples print("\n[Step 1] Adding 5 samples...") @@ -221,22 +194,22 @@ def demonstrate_random_sampler_with_replacement(): }, batch_size=5, ) - client.put(data=data, partition_id="test") + tq.put(data=data, partition_id="test") print(" ✓ 5 samples added") # Get batch 1 (should get 2 random samples) print("\n[Step 2] Get batch 1 (2 samples)...") - meta1 = 
client.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task") + meta1 = tq.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task") print(f" ✓ Got samples: {meta1.global_indexes}") # Get batch 2 (should get 1 random sample with replacement - may have duplicate with previous batch!) print("\n[Step 3] Get batch 2 (1 sample)...") - meta2 = client.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task") + meta2 = tq.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task") print(f" ✓ Got samples: {meta2.global_indexes}") # Get batch 3 (should get 2 random samples with replacement - may have duplicate with previous batches!) print("\n[Step 4] Get batch 3 (2 samples)...") - meta3 = client.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task") + meta3 = tq.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task") print(f" ✓ Got samples: {meta3.global_indexes}") print("\n[Verification]") @@ -246,8 +219,8 @@ def demonstrate_random_sampler_with_replacement(): print(f" ✓ All sampled: {all_sampled}") # Cleanup - client.clear_partition(partition_id="test") - client.close() + tq.clear_partition(partition_id="test") + tq.close() ray.shutdown() @@ -259,7 +232,7 @@ def demonstrate_random_sampler_without_replacement(): print("\nSetup TransferQueue with RandomSamplerWithoutReplacement...") sampler = RandomSamplerWithoutReplacement() - controller, storage_units, client = setup_transfer_queue_with_sampler(sampler) + setup_transfer_queue_with_sampler(sampler) # Add 6 samples print("\n[Step 1] Adding 6 samples...") @@ -269,22 +242,22 @@ def demonstrate_random_sampler_without_replacement(): }, batch_size=6, ) - client.put(data=data, partition_id="test") + tq.put(data=data, partition_id="test") print(" ✓ 6 samples added") # Get batch 1 (should get 3 random samples without replacement) 
print("\n[Step 2] Get batch 1 (3 samples)...") - meta1 = client.get_meta(data_fields=["input"], batch_size=3, partition_id="test", task_name="demo_task") + meta1 = tq.get_meta(data_fields=["input"], batch_size=3, partition_id="test", task_name="demo_task") print(f" ✓ Got samples: {meta1.global_indexes}") # Get batch 2 (should randomly get 1 sample that are different from previous batch) print("\n[Step 3] Get batch 2 (1 samples)...") - meta2 = client.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task") + meta2 = tq.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task") print(f" ✓ Got samples: {meta2.global_indexes}") # Get batch 3 (should randomly get 2 samples that are different from previous batch) print("\n[Step 4] Get batch 3 (2 samples)...") - meta3 = client.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task") + meta3 = tq.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task") print(f" ✓ Got samples: {meta3.global_indexes}") print("\n[Verification]") @@ -294,8 +267,8 @@ def demonstrate_random_sampler_without_replacement(): print(f" ✓ Batch 3: {meta3.global_indexes} (none left)") # Cleanup - client.clear_partition(partition_id="test") - client.close() + tq.clear_partition(partition_id="test") + tq.close() ray.shutdown() @@ -307,7 +280,7 @@ def demonstrate_priority_sampler(): print("\nSetup TransferQueue with PrioritySampler...") sampler = PrioritySampler() - controller, storage_units, client = setup_transfer_queue_with_sampler(sampler) + setup_transfer_queue_with_sampler(sampler) # Add 8 samples print("\n[Step 1] Adding 8 samples...") @@ -317,7 +290,7 @@ def demonstrate_priority_sampler(): }, batch_size=8, ) - client.put(data=data, partition_id="test") + tq.put(data=data, partition_id="test") print(" ✓ 8 samples added") time.sleep(1) @@ -330,7 +303,7 @@ def demonstrate_priority_sampler(): print(f"Priority scores: 
{priority_scores}") # Get batch using priority sampling - meta1 = client.get_meta( + meta1 = tq.get_meta( data_fields=["input"], batch_size=1, partition_id="test", @@ -342,7 +315,7 @@ def demonstrate_priority_sampler(): # Get another batch print("\n[Step 3] Get another batch (2 samples)...") - meta2 = client.get_meta( + meta2 = tq.get_meta( data_fields=["input"], batch_size=2, partition_id="test", @@ -358,8 +331,8 @@ def demonstrate_priority_sampler(): print(f" ✓ Batch 2 high-priority indices: {[i for i in meta2.global_indexes if priority_scores[i] >= 0.1]}") # Cleanup - client.clear_partition(partition_id="test") - client.close() + tq.clear_partition(partition_id="test") + tq.close() ray.shutdown() diff --git a/tutorial/05_streaming_dataloader.py b/tutorial/05_streaming_dataloader.py index 916be001..4d92aa26 100644 --- a/tutorial/05_streaming_dataloader.py +++ b/tutorial/05_streaming_dataloader.py @@ -57,39 +57,25 @@ import ray # noqa: E402 import torch # noqa: E402 -from omegaconf import DictConfig, OmegaConf # noqa: E402 +from omegaconf import OmegaConf # noqa: E402 from tensordict import TensorDict # noqa: E402 # Add the parent directory to the path for imports parent_dir = Path(__file__).resolve().parent.parent sys.path.append(str(parent_dir)) - +import transfer_queue as tq # noqa: E402 from transfer_queue import ( # noqa: E402 RankAwareSampler, - SimpleStorageUnit, StreamingDataLoader, StreamingDataset, - TransferQueueClient, - TransferQueueController, - process_zmq_server_info, ) def setup_transfer_queue(): """Setup TransferQueue components.""" if not ray.is_initialized(): - ray.init() - - config = OmegaConf.create( - { - "num_data_storage_units": 2, - } - ) - - storage_units = {} - for i in range(config["num_data_storage_units"]): - storage_units[i] = SimpleStorageUnit.remote(storage_unit_size=100) + ray.init(namespace="TransferQueueTutorial") print("[Setup]: Setup TransferQueue components") print( @@ -101,26 +87,23 @@ def setup_transfer_queue(): 
"TransferQueueController. In polling_mode, the controller will return empty BatchMeta when " "available data cannot meet the consumption requirements. User side need to retry later." ) - controller = TransferQueueController.remote( - sampler=RankAwareSampler, # RankAwareSampler enables consistent sampling for each DP rank - polling_mode=True, # Enable polling mode for streaming data retrieval - ) - - controller_info = process_zmq_server_info(controller) - storage_unit_infos = process_zmq_server_info(storage_units) - # Build the complete configuration - tq_config = OmegaConf.create({}, flags={"allow_objects": True}) - tq_config.controller_info = controller_info - tq_config.storage_unit_infos = storage_unit_infos - config.storage_backend = "AsyncSimpleStorageManager" - config = OmegaConf.merge(tq_config, config) + config = OmegaConf.create( + { + "controller": { + "sampler": RankAwareSampler, # RankAwareSampler enables consistent sampling for each DP rank + "polling_mode": True, # Enable polling mode for streaming data retrieval + }, + "backend": {"SimpleStorage": {"num_data_storage_units": 2}}, + }, + flags={"allow_objects": True}, + ) - return controller, storage_units, config + tq.init(config) @ray.remote(num_cpus=0.1) -def generate_worker(rank_id: int, config: DictConfig, num_samples: int = 20): +def generate_worker(rank_id: int, num_samples: int = 20): """ Generate actor that produces training samples. @@ -129,7 +112,6 @@ def generate_worker(rank_id: int, config: DictConfig, num_samples: int = 20): Args: rank_id: Unique identifier for this generator (used for sample indexing) - config: TransferQueue configuration num_samples: Number of samples to generate Note: @@ -137,13 +119,9 @@ def generate_worker(rank_id: int, config: DictConfig, num_samples: int = 20): This ensures global uniqueness across all generator actors. 
""" # Create a client for interacting with TransferQueue - client = TransferQueueClient( - client_id=f"gen_worker_{rank_id}", - controller_info=config.controller_info, - ) - # Initialize the storage manager for this client - client.initialize_storage_manager(manager_type=config.storage_backend, config=config) + # Need to call tq.init() in each process + tq.init() # Generate and put samples into the queue for i in range(num_samples): @@ -159,7 +137,7 @@ def generate_worker(rank_id: int, config: DictConfig, num_samples: int = 20): print(f"[Generate Worker@{rank_id}]: Putting sample {seq_id} into TransferQueue") # Put data into the specified partition - client.put(data, partition_id="train") + tq.put(data, partition_id="train") print(f"[Generate Worker@{rank_id}]: Complete putting samples into TransferQueue") @@ -168,7 +146,6 @@ def generate_worker(rank_id: int, config: DictConfig, num_samples: int = 20): def update_worker( rank_id: int, dp_rank: int, - config: DictConfig, max_steps: int = 5, ): """ @@ -182,7 +159,6 @@ def update_worker( rank_id: Global rank identifier for logging and display purposes dp_rank: Data parallel rank ID that this worker belongs to The same Ranks receive the same data samples - config: TransferQueue configuration max_steps: Maximum number of batches to consume Returns: @@ -200,8 +176,15 @@ def update_worker( - batch_meta: Metadata for TransferQueue coordination (contains global_indexes) """ + # Need to call tq.init() in each process + tq.init() + # Step 1: Create StreamingDataset # This dataset integrates with TransferQueue and handles batch retrieval + + controller = ray.get_actor("TransferQueueController") + config = ray.get(controller.get_config.remote()) + dataset = StreamingDataset( config=config, batch_size=2, @@ -253,7 +236,7 @@ def update_worker( } -def start_all_generate_actors(config): +def start_all_generate_actors(): """ Launch generate_actors for producing training samples. 
""" @@ -261,12 +244,12 @@ def start_all_generate_actors(config): handlers = [] for i in range(num_workers): - handlers.append(generate_worker.remote(rank_id=i, config=config, num_samples=20)) + handlers.append(generate_worker.remote(rank_id=i, num_samples=20)) return handlers -def start_all_update_actors(config): +def start_all_update_actors(): """ Launch update_actors for consuming training samples. """ @@ -285,7 +268,6 @@ def start_all_update_actors(config): update_worker.remote( rank_id=rank_ids[i], dp_rank=dp_rank[i], - config=config, ) ) @@ -331,15 +313,15 @@ def main(): "global_batch_size to make sure consumers can accurately determine consumption status even before " "producers have generated the samples." ) - controller, storage_units, config = setup_transfer_queue() + setup_transfer_queue() # Step 2: Launch data generation actors print("\n[Phase 2] Starting data generation...") - generate_worker_handlers = start_all_generate_actors(config) + generate_worker_handlers = start_all_generate_actors() # Step 3: Launch data consumption actors print("\n[Phase 3] Starting data consumption...") - update_worker_handlers = start_all_update_actors(config) + update_worker_handlers = start_all_update_actors() # Wait for completion print("\n[Phase 4] Waiting for actors to complete...")