From a2f70f1f7307d9ad53f1a1d345ca78746122fed3 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 4 Feb 2026 20:30:03 +0800 Subject: [PATCH 01/14] simplify init & interface Signed-off-by: 0oshowero0 --- requirements.txt | 3 +- transfer_queue/__init__.py | 41 +- transfer_queue/client.py | 39 +- transfer_queue/config.yaml | 29 + transfer_queue/controller.py | 10 + transfer_queue/interface.py | 512 ++++++++++++++++++ .../storage/managers/mooncake_manager.py | 2 +- .../managers/simple_backend_manager.py | 2 +- .../storage/managers/yuanrong_manager.py | 2 +- transfer_queue/utils/zmq_utils.py | 35 +- 10 files changed, 612 insertions(+), 63 deletions(-) create mode 100644 transfer_queue/config.yaml create mode 100644 transfer_queue/interface.py diff --git a/requirements.txt b/requirements.txt index 1da090c8..8cfccacb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ pyzmq hydra-core numpy<2.0.0 msgspec -psutil \ No newline at end of file +psutil +omegaconf \ No newline at end of file diff --git a/transfer_queue/__init__.py b/transfer_queue/__init__.py index e9259024..3bb10cbe 100644 --- a/transfer_queue/__init__.py +++ b/transfer_queue/__init__.py @@ -15,12 +15,10 @@ import os -from .client import ( - TransferQueueClient, - process_zmq_server_info, -) +from .client import TransferQueueClient from .controller import TransferQueueController from .dataloader import StreamingDataLoader, StreamingDataset +from .interface import * from .metadata import BatchMeta from .sampler import BaseSampler from .sampler.grpo_group_n_sampler import GRPOGroupNSampler @@ -28,23 +26,26 @@ from .sampler.sequential_sampler import SequentialSampler from .storage import SimpleStorageUnit from .utils.common import get_placement_group -from .utils.zmq_utils import ZMQServerInfo +from .utils.zmq_utils import ZMQServerInfo, process_zmq_server_info -__all__ = [ - "TransferQueueClient", - "StreamingDataset", - "StreamingDataLoader", - "BatchMeta", - "TransferQueueController", - "SimpleStorageUnit", - "ZMQServerInfo", - "process_zmq_server_info", - "get_placement_group", - "BaseSampler", - "GRPOGroupNSampler", - "SequentialSampler", - "RankAwareSampler", -] +__all__ = interface.__all__ +__all__.extend( + [ + "TransferQueueClient", + "StreamingDataset", + "StreamingDataLoader", + "BatchMeta", + "TransferQueueController", + "SimpleStorageUnit", + "ZMQServerInfo", + "process_zmq_server_info", + "get_placement_group", + "BaseSampler", + "GRPOGroupNSampler", + "SequentialSampler", + "RankAwareSampler", + ] +) version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__))) diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 24b4e559..68b746f7 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -18,23 +18,19 @@ import os import threading from functools import wraps -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional from uuid import uuid4 -import ray import torch import zmq import zmq.asyncio from tensordict import TensorDict from torch import Tensor -from transfer_queue.controller import TransferQueueController from transfer_queue.metadata import ( BatchMeta, ) from transfer_queue.storage import ( - SimpleStorageUnit, - TransferQueueStorageManager, TransferQueueStorageManagerFactory, ) from transfer_queue.utils.common import limit_pytorch_auto_parallel_threads @@ -1304,36 +1300,3 @@ def close(self) -> None: logger.warning(f"[{self.client_id}]: Error closing event loop: {e}") super().close() - - -def process_zmq_server_info( - handlers: dict[Any, Union["TransferQueueController", "TransferQueueStorageManager", "SimpleStorageUnit"]] - | Union["TransferQueueController", "TransferQueueStorageManager", "SimpleStorageUnit"], -): # noqa: UP007 - """Extract ZMQ server information from handler objects. - - Args: - handlers: Dictionary of handler objects (controllers, storage managers, or storage units), - or a single handler object - - Returns: - If handlers is a dictionary: Dictionary mapping handler names to their ZMQ server information - If handlers is a single object: ZMQ server information for that object - - Examples: - >>> # Single handler - >>> controller = TransferQueueController.remote(...) - >>> info = process_zmq_server_info(controller) - >>> - >>> # Multiple handlers - >>> handlers = {"storage_0": storage_0, "storage_1": storage_1} - >>> info_dict = process_zmq_server_info(handlers)""" - # Handle single handler object case - if not isinstance(handlers, dict): - return ray.get(handlers.get_zmq_server_info.remote()) # type: ignore[union-attr, attr-defined] - else: - # Handle dictionary case - server_info = {} - for name, handler in handlers.items(): - server_info[name] = ray.get(handler.get_zmq_server_info.remote()) # type: ignore[union-attr, attr-defined] - return server_info diff --git a/transfer_queue/config.yaml b/transfer_queue/config.yaml new file mode 100644 index 00000000..ea27d7f3 --- /dev/null +++ b/transfer_queue/config.yaml @@ -0,0 +1,29 @@ +# This is the default configuration of TransferQueue. Users may modify the default value +# and use transfer_queue.init(conf) to overwrite the config entries. + + +controller: + sampler: SequentialSampler + polling_mode: False + # ZMQ Server IP & Ports (automatically generated during init) + zmq_info: None + +backend: + # Pluggable storage/transport backend of TransferQueue. Choose from: + # SimpleStorage, Yuanrong, MooncakeStore, ... + storage_backend: SimpleStorage + + # For SimpleStorage: + SimpleStorage: + # Total number of samples + total_storage_size: 100000 + # Number of distributed storage units for SimpleStorage backend + num_data_storage_units: 2 + # ZMQ Server IP & Ports (automatically generated during init) + zmq_info: None + + # For Yuanrong: + # TODO + + # For MooncakeStore: + # TODO \ No newline at end of file diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index c91eaf7c..07f1abf8 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -28,6 +28,7 @@ import ray import torch import zmq +from omegaconf import DictConfig from ray.util import get_node_ip_address from torch import Tensor @@ -879,6 +880,7 @@ def __init__( self.controller_id = f"TQ_CONTROLLER_{uuid4().hex[:8]}" self.polling_mode = polling_mode + self.tq_config = None # global config for TransferQueue system # Initialize ZMQ sockets for communication self._init_zmq_socket() @@ -1772,3 +1774,11 @@ def _update_data_status(self): def get_zmq_server_info(self) -> ZMQServerInfo: """Get ZMQ server connection information.""" return self.zmq_server_info + + def store_config(self, conf: DictConfig) -> None: + """Storage the global config of TransferQueue.""" + self.tq_conf = conf + + def get_config(self) -> DictConfig: + """Retrieve the global config of TransferQueue.""" + return self.tq_conf diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py new file mode 100644 index 00000000..bc1482bf --- /dev/null +++ b/transfer_queue/interface.py @@ -0,0 +1,512 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os +from typing import Any, Optional + +import ray +from omegaconf import DictConfig, OmegaConf +from tensordict import TensorDict + +from transfer_queue.client import TransferQueueClient +from transfer_queue.controller import TransferQueueController +from transfer_queue.metadata import BatchMeta +from transfer_queue.sampler import * # noqa: F401 +from transfer_queue.storage.simple_backend import SimpleStorageUnit +from transfer_queue.utils.common import get_placement_group +from transfer_queue.utils.zmq_utils import process_zmq_server_info + +_TRANSFER_QUEUE_CLIENT = None +_TRANSFER_QUEUE_STORAGE = None + +__all__ = [ + "init", + "get_meta", + "get_data", + "put", + "set_custom_meta", + "clear_samples", + "async_get_meta", + "async_get_data", + "async_put", + "async_set_custom_meta", + "async_clear_samples", +] + + +def _maybe_create_transferqueue_client( + conf: Optional[DictConfig] = None, +) -> TransferQueueClient: + global _TRANSFER_QUEUE_CLIENT + if _TRANSFER_QUEUE_CLIENT is None: + if conf is None: + raise ValueError("Missing config for initialing TranferQueueClient!") + pid = os.getpid() + _TRANSFER_QUEUE_CLIENT = TransferQueueClient( + client_id=f"TQClient_{pid}", controller_info=conf.controller.zmq_info + ) + + backend_name = conf.backend.storage_backend + _TRANSFER_QUEUE_CLIENT.initialize_storage_manager(manager_type=backend_name, config=conf.backend[backend_name]) + + return _TRANSFER_QUEUE_CLIENT + + +def _maybe_create_transferqueue_storage(conf: DictConfig) -> DictConfig: + global _TRANSFER_QUEUE_STORAGE + + if _TRANSFER_QUEUE_STORAGE is None: + _TRANSFER_QUEUE_STORAGE = [] + if conf.backend.storage_backend == "SimpleStorage": + # initialize SimpleStorageUnit + num_data_storage_units = conf.backend.num_data_storage_units + total_storage_size = conf.backend.total_storage_size + storage_placement_group = get_placement_group(conf.backend.num_data_storage_units, num_cpus_per_actor=1) + + for storage_unit_rank in range(num_data_storage_units): + storage_node = SimpleStorageUnit.options( + placement_group=storage_placement_group, + placement_group_bundle_index=storage_unit_rank, + name=f"TQStorageUnit#{storage_unit_rank}", + lifetime="detached", + ).remote(storage_unit_size=math.ceil(total_storage_size / num_data_storage_units)) + _TRANSFER_QUEUE_STORAGE.append(storage_node) + print(f"TQStorageUnit#{storage_unit_rank} has been created.") + + # extract zmq info + storage_zmq_info = process_zmq_server_info(_TRANSFER_QUEUE_CLIENT) + conf.backend.zmq_info = storage_zmq_info + + return conf + + +def init(conf: Optional[DictConfig] = None) -> None: + try: + # Already initialize TransferQueue + controller = ray.get_actor("TransferQueueController") + conf = ray.get(controller.get_config()) + _maybe_create_transferqueue_client(conf) + + except ValueError: + # First-time initialize TransferQueue + + # create config + final_conf = OmegaConf.create({}, flags={"allow_objects": True}) + default_conf = OmegaConf.load("config.yaml") + final_conf = OmegaConf.merge(final_conf, default_conf) + final_conf = OmegaConf.merge(final_conf, conf) + + # create controller + try: + sampler = globals()[final_conf.controller.sampler] + except KeyError: + raise ValueError(f"Could not find sampler {final_conf.controller.sampler}") + + controller = TransferQueueController.options(name="TransferQueueController", lifetime="detached").remote( + sampler=sampler, polling_mode=final_conf.controller.polling_mode + ) + + controller_zmq_info = process_zmq_server_info(controller) + final_conf.controller.zmq_info = controller_zmq_info + + # create distributed storage backends + final_conf = _maybe_create_transferqueue_storage(final_conf) + + # storage the config into controller + controller.store_config(final_conf) + + # create client + _maybe_create_transferqueue_client(final_conf) + + +def get_meta( + data_fields: list[str], + batch_size: int, + partition_id: str, + task_name: Optional[str] = None, + sampling_config: Optional[dict[str, Any]] = None, +) -> BatchMeta: + """Synchronously fetch data metadata from the controller via ZMQ. + + Args: + data_fields: List of data field names to retrieve metadata for + batch_size: Number of samples to request in the batch + partition_id: Current data partition id + mode: Data fetch mode. Options: + - 'fetch': Get ready data only + - 'force_fetch': Get data regardless of readiness (may return unready samples) + - 'insert': Internal usage - should not be used by users + task_name: Optional task name associated with the request + sampling_config: Optional sampling configuration for custom samplers. + + + Returns: + BatchMeta: Metadata object containing data structure, sample information, and readiness status + + Raises: + RuntimeError: If communication fails or controller returns error response + + Example: + >>> # Example 1: Basic fetch metadata + >>> batch_meta = client.get_meta( + ... data_fields=["input_ids", "attention_mask"], + ... batch_size=4, + ... partition_id="train_0", + ... mode="fetch", + ... task_name="generate_sequences" + ... ) + >>> print(batch_meta.is_ready) # True if all samples ready + >>> + >>> # Example 2: Fetch with self-defined samplers (using GRPOGroupNSampler as an example) + >>> batch_meta = client.get_meta( + ... data_fields=["input_ids", "attention_mask"], + ... batch_size=8, + ... partition_id="train_0", + ... mode="fetch", + ... task_name="generate_sequences", + ... sampling_config={"n_samples_per_prompt": 4} + ... ) + >>> print(batch_meta.is_ready) # True if all samples ready + >>> + >>> # Example 3: Force fetch metadata (bypass production status check and Sampler, + >>> # so may include unready and already-consumed samples. No filtering by consumption status is applied.) + >>> batch_meta = client.get_meta( + ... partition_id="train_0", # optional + ... mode="force_fetch", + ... ) + >>> print(batch_meta.is_ready) # May be False if some samples not ready + """ + + tq_client = _maybe_create_transferqueue_client() + tq_client.get_meta(data_fields, batch_size, partition_id, task_name, sampling_config) + + +async def async_get_meta( + data_fields: list[str], + batch_size: int, + partition_id: str, + mode: str = "fetch", + task_name: Optional[str] = None, + sampling_config: Optional[dict[str, Any]] = None, +) -> BatchMeta: + """Asynchronously fetch data metadata from the controller via ZMQ. + + Args: + data_fields: List of data field names to retrieve metadata for + batch_size: Number of samples to request in the batch + partition_id: Current data partition id + mode: Data fetch mode. Options: + - 'fetch': Get ready data only + - 'force_fetch': Get data regardless of readiness (may return unready samples) + - 'insert': Internal usage - should not be used by users + task_name: Optional task name associated with the request + sampling_config: Optional sampling configuration for custom samplers. + socket: ZMQ async socket for message transmission (injected by decorator) + + Returns: + BatchMeta: Metadata object containing data structure, sample information, and readiness status + + Raises: + RuntimeError: If communication fails or controller returns error response + + Example: + >>> # Example 1: Basic fetch metadata + >>> batch_meta = asyncio.run(client.async_get_meta( + ... data_fields=["input_ids", "attention_mask"], + ... batch_size=4, + ... partition_id="train_0", + ... mode="fetch", + ... task_name="generate_sequences" + ... )) + >>> print(batch_meta.is_ready) # True if all samples ready + >>> + >>> # Example 2: Fetch with self-defined samplers (using GRPOGroupNSampler as an example) + >>> batch_meta = asyncio.run(client.async_get_meta( + ... data_fields=["input_ids", "attention_mask"], + ... batch_size=8, + ... partition_id="train_0", + ... mode="fetch", + ... task_name="generate_sequences", + ... )) + >>> print(batch_meta.is_ready) # True if all samples ready + >>> + >>> # Example 3: Force fetch metadata (bypass production status check and Sampler, + >>> # so may include unready and already-consumed samples. No filtering by consumption status is applied.) + >>> batch_meta = asyncio.run(client.async_get_meta( + ... partition_id="train_0", # optional + ... mode="force_fetch", + ... )) + >>> print(batch_meta.is_ready) # May be False if some samples not ready + """ + + tq_client = _maybe_create_transferqueue_client() + return await tq_client.async_get_meta(data_fields, batch_size, partition_id, mode, task_name, sampling_config) + + +def get_data(metadata: BatchMeta) -> TensorDict: + """Synchronously fetch data from storage units and organize into TensorDict. + + Args: + metadata: Batch metadata containing data location information and global indexes + + Returns: + TensorDict containing: + - Requested data fields (e.g., "prompts", "attention_mask") + + Example: + >>> batch_meta = client.get_data( + ... data_fields=["prompts", "attention_mask"], + ... batch_size=4, + ... partition_id="train_0", + ... mode="fetch", + ... task_name="generate_sequences", + ... ) + >>> batch = client.get_data(batch_meta) + >>> print(batch) + >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes + """ + tq_client = _maybe_create_transferqueue_client() + return tq_client.get_data(metadata) + + +async def async_get_data(metadata: BatchMeta) -> TensorDict: + """Asynchronously fetch data from storage units and organize into TensorDict. + + Args: + metadata: Batch metadata containing data location information and global indexes + + Returns: + TensorDict containing: + - Requested data fields (e.g., "prompts", "attention_mask") + + Example: + >>> batch_meta = asyncio.run(client.async_get_meta( + ... data_fields=["prompts", "attention_mask"], + ... batch_size=4, + ... partition_id="train_0", + ... mode="fetch", + ... task_name="generate_sequences", + ... )) + >>> batch = asyncio.run(client.async_get_data(batch_meta)) + >>> print(batch) + >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes + """ + tq_client = _maybe_create_transferqueue_client() + return await tq_client.async_get_data(metadata) + + +async def async_put( + data: TensorDict, + metadata: Optional[BatchMeta] = None, + partition_id: Optional[str] = None, +) -> BatchMeta: + """Asynchronously write data to storage units based on metadata. + + If metadata is not provided, it will be created automatically using insert mode + with the provided data fields and partition_id. + + During put, the custom_meta in metadata will update the corresponding custom_meta in + TransferQueue Controller. + + Note: + When using multiple workers for distributed execution, there may be data + ordering inconsistencies between workers during put operations. + + Args: + data: Data to write as TensorDict + metadata: Records the metadata of a batch of data samples, containing index and + storage unit information. If None, metadata will be auto-generated. + partition_id: Target data partition id (required if metadata is not provided) + + Returns: + BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved + metadata; will be updated in a future version to reflect the post-put state) + + Raises: + ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided + RuntimeError: If storage operation fails + + Example: + >>> batch_size = 4 + >>> seq_len = 16 + >>> current_partition_id = "train_0" + >>> # Example 1: Normal usage with existing metadata + >>> batch_meta = asyncio.run(client.async_get_meta( + ... data_fields=["prompts", "attention_mask"], + ... batch_size=batch_size, + ... partition_id=current_partition_id, + ... mode="fetch", + ... task_name="generate_sequences", + ... )) + >>> batch = asyncio.run(client.async_get_data(batch_meta)) + >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)}) + >>> asyncio.run(client.async_put(data=output, metadata=batch_meta)) + >>> + >>> # Example 2: Initial data insertion without pre-existing metadata + >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id! + >>> # Please make sure the corresponding partition_id is empty before calling the async_put() + >>> # without metadata. + >>> # Now we only support put all the data of the corresponding partition id in once. You should repeat with + >>> # interleave the initial data if n_sample > 1 before calling the async_put(). + >>> original_prompts = torch.randn(batch_size, seq_len) + >>> n_samples = 4 + >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0) + >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated}) + >>> # This will create metadata in "insert" mode internally. + >>> metadata = asyncio.run(client.async_put(data=prompts_repeated_batch, partition_id=current_partition_id)) + """ + tq_client = _maybe_create_transferqueue_client() + return await tq_client.async_put(data, metadata, partition_id) + + +def put(data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None) -> BatchMeta: + """Synchronously write data to storage units based on metadata. + + If metadata is not provided, it will be created automatically using insert mode + with the provided data fields and partition_id. + + During put, the custom_meta in metadata will update the corresponding custom_meta in + TransferQueue Controller. + + Note: + When using multiple workers for distributed execution, there may be data + ordering inconsistencies between workers during put operations. + + Args: + data: Data to write as TensorDict + metadata: Records the metadata of a batch of data samples, containing index and + storage unit information. If None, metadata will be auto-generated. + partition_id: Target data partition id (required if metadata is not provided) + + Returns: + BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved + metadata; will be updated in a future version to reflect the post-put state) + + Raises: + ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided + RuntimeError: If storage operation fails + + Example: + >>> batch_size = 4 + >>> seq_len = 16 + >>> current_partition_id = "train_0" + >>> # Example 1: Normal usage with existing metadata + >>> batch_meta = client.get_meta( + ... data_fields=["prompts", "attention_mask"], + ... batch_size=batch_size, + ... partition_id=current_partition_id, + ... mode="fetch", + ... task_name="generate_sequences", + ... ) + >>> batch = client.get_data(batch_meta) + >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)}) + >>> client.put(data=output, metadata=batch_meta) + >>> + >>> # Example 2: Initial data insertion without pre-existing metadata + >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id! + >>> # Please make sure the corresponding partition_id is empty before calling the async_put() + >>> # without metadata. + >>> # Now we only support put all the data of the corresponding partition id in once. You should repeat with + >>> # interleave the initial data if n_sample > 1 before calling the async_put(). + >>> original_prompts = torch.randn(batch_size, seq_len) + >>> n_samples = 4 + >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0) + >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated}) + >>> # This will create metadata in "insert" mode internally. + >>> metadata = client.put(data=prompts_repeated_batch, partition_id=current_partition_id) + """ + tq_client = _maybe_create_transferqueue_client() + return tq_client.put(data, metadata, partition_id) + + +def set_custom_meta(metadata: BatchMeta) -> None: + """Synchronously send custom metadata to the controller. + + This method sends per-sample custom metadata (custom_meta) to the controller. + The custom_meta is stored in the controller and can be retrieved along with + the BatchMeta in subsequent get_meta calls. + + Args: + metadata: BatchMeta containing the samples and their custom metadata to store. + The custom_meta should be set using BatchMeta.update_custom_meta() or + BatchMeta.set_custom_meta() before calling this method. + + Raises: + RuntimeError: If communication fails or controller returns error response + + Example: + >>> # Create batch with custom metadata + >>> batch_meta = client.get_meta(data_fields=["input_ids"], batch_size=4, ...) + >>> batch_meta.update_custom_meta({0: {"score": 0.9}, 1: {"score": 0.8}}) + >>> client.set_custom_meta(batch_meta) + """ + tq_client = _maybe_create_transferqueue_client() + return tq_client.set_custom_meta(metadata) + + +async def async_set_custom_meta( + metadata: BatchMeta, +) -> None: + """ + Asynchronously send custom metadata to the controller. + + This method sends per-sample custom metadata (custom_meta) to the controller. + The custom_meta is stored in the controller and can be retrieved along with + the BatchMeta in subsequent get_meta calls. + + Args: + metadata: BatchMeta containing the samples and their custom metadata to store. + The custom_meta should be set using BatchMeta.update_custom_meta() or + BatchMeta.set_custom_meta() before calling this method. + socket: ZMQ async socket for message transmission (injected by decorator) + + Raises: + RuntimeError: If communication fails or controller returns error response + + Example: + >>> # Create batch with custom metadata + >>> batch_meta = client.get_meta(data_fields=["input_ids"], batch_size=4, ...) + >>> batch_meta.update_custom_meta({0: {"score": 0.9}, 1: {"score": 0.8}}) + >>> asyncio.run(client.async_set_custom_meta(batch_meta)) + """ + tq_client = _maybe_create_transferqueue_client() + return await tq_client.async_set_custom_meta(metadata) + + +def clear_samples(metadata: BatchMeta): + """Synchronously clear specific samples from all storage units and the controller. + + Args: + metadata: The BatchMeta of the corresponding data to be cleared + + Raises: + RuntimeError: If clear operation fails + """ + tq_client = _maybe_create_transferqueue_client() + return tq_client.clear_samples(metadata) + + +async def async_clear_samples(metadata: BatchMeta): + """Asynchronously clear specific samples from all storage units and the controller. + + Args: + metadata: The BatchMeta of the corresponding data to be cleared + + Raises: + RuntimeError: If clear operation fails + """ + tq_client = _maybe_create_transferqueue_client() + return await tq_client.async_clear_samples(metadata) diff --git a/transfer_queue/storage/managers/mooncake_manager.py b/transfer_queue/storage/managers/mooncake_manager.py index ca555668..e787ffbb 100644 --- a/transfer_queue/storage/managers/mooncake_manager.py +++ b/transfer_queue/storage/managers/mooncake_manager.py @@ -24,7 +24,7 @@ logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) -@TransferQueueStorageManagerFactory.register("MooncakeStorageManager") +@TransferQueueStorageManagerFactory.register("MooncakeStore") class MooncakeStorageManager(KVStorageManager): """Storage manager for MooncakeStorage backend.""" diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 5c6e68f0..4fb78bc9 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -48,7 +48,7 @@ TQ_ZERO_COPY_SERIALIZATION = get_env_bool("TQ_ZERO_COPY_SERIALIZATION", default=False) -@TransferQueueStorageManagerFactory.register("AsyncSimpleStorageManager") +@TransferQueueStorageManagerFactory.register("SimpleStorage") class AsyncSimpleStorageManager(TransferQueueStorageManager): """Asynchronous storage manager that handles multiple storage units. diff --git a/transfer_queue/storage/managers/yuanrong_manager.py b/transfer_queue/storage/managers/yuanrong_manager.py index bfb79e6c..6935203c 100644 --- a/transfer_queue/storage/managers/yuanrong_manager.py +++ b/transfer_queue/storage/managers/yuanrong_manager.py @@ -30,7 +30,7 @@ logger.addHandler(handler) -@TransferQueueStorageManagerFactory.register("YuanrongStorageManager") +@TransferQueueStorageManagerFactory.register("Yuanrong") class YuanrongStorageManager(KVStorageManager): """Storage manager for Yuanrong backend.""" diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index 1f6ed922..b42b855e 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -19,7 +19,7 @@ import socket import time from dataclasses import dataclass -from typing import Any, Optional, TypeAlias +from typing import Any, Optional, TypeAlias, Union from uuid import uuid4 import psutil @@ -239,3 +239,36 @@ def create_zmq_socket( if identity is not None: socket.setsockopt(zmq.IDENTITY, identity) return socket + + +def process_zmq_server_info( + handlers: dict[Any, Union["TransferQueueController", "TransferQueueStorageManager", "SimpleStorageUnit"]] + | Union["TransferQueueController", "TransferQueueStorageManager", "SimpleStorageUnit"], +): # noqa: UP007 + """Extract ZMQ server information from handler objects. + + Args: + handlers: Dictionary of handler objects (controllers, storage managers, or storage units), + or a single handler object + + Returns: + If handlers is a dictionary: Dictionary mapping handler names to their ZMQ server information + If handlers is a single object: ZMQ server information for that object + + Examples: + >>> # Single handler + >>> controller = TransferQueueController.remote(...) + >>> info = process_zmq_server_info(controller) + >>> + >>> # Multiple handlers + >>> handlers = {"storage_0": storage_0, "storage_1": storage_1} + >>> info_dict = process_zmq_server_info(handlers)""" + # Handle single handler object case + if not isinstance(handlers, dict): + return ray.get(handlers.get_zmq_server_info.remote()) # type: ignore[union-attr, attr-defined] + else: + # Handle dictionary case + server_info = {} + for name, handler in handlers.items(): + server_info[name] = ray.get(handler.get_zmq_server_info.remote()) # type: ignore[union-attr, attr-defined] + return server_info From 27983549a14c5721b1a6e0409a69fe159da2ee8e Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 4 Feb 2026 20:52:18 +0800 Subject: [PATCH 02/14] fix pre-commit Signed-off-by: 0oshowero0 --- transfer_queue/__init__.py | 35 ++++++------- transfer_queue/interface.py | 83 ++++++++++++++++--------------- transfer_queue/utils/zmq_utils.py | 6 +-- 3 files changed, 63 insertions(+), 61 deletions(-) diff --git a/transfer_queue/__init__.py b/transfer_queue/__init__.py index 3bb10cbe..5800d7a5 100644 --- a/transfer_queue/__init__.py +++ b/transfer_queue/__init__.py @@ -15,10 +15,10 @@ import os +from . import interface from .client import TransferQueueClient from .controller import TransferQueueController from .dataloader import StreamingDataLoader, StreamingDataset -from .interface import * from .metadata import BatchMeta from .sampler import BaseSampler from .sampler.grpo_group_n_sampler import GRPOGroupNSampler @@ -28,24 +28,21 @@ from .utils.common import get_placement_group from .utils.zmq_utils import ZMQServerInfo, process_zmq_server_info -__all__ = interface.__all__ -__all__.extend( - [ - "TransferQueueClient", - "StreamingDataset", - "StreamingDataLoader", - "BatchMeta", - "TransferQueueController", - "SimpleStorageUnit", - "ZMQServerInfo", - "process_zmq_server_info", - "get_placement_group", - "BaseSampler", - "GRPOGroupNSampler", - "SequentialSampler", - "RankAwareSampler", - ] -) +__all__ = interface.__all__ + [ + "TransferQueueClient", + "StreamingDataset", + "StreamingDataLoader", + "BatchMeta", + "TransferQueueController", + "SimpleStorageUnit", + "ZMQServerInfo", + "process_zmq_server_info", + "get_placement_group", + "BaseSampler", + "GRPOGroupNSampler", + "SequentialSampler", + "RankAwareSampler", +] version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__))) diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index bc1482bf..68093960 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import math import os from typing import Any, Optional @@ -29,22 +30,11 @@ from transfer_queue.utils.common import get_placement_group from transfer_queue.utils.zmq_utils import process_zmq_server_info -_TRANSFER_QUEUE_CLIENT = None -_TRANSFER_QUEUE_STORAGE = None +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) -__all__ = [ - "init", - "get_meta", - "get_data", - "put", - "set_custom_meta", - "clear_samples", - "async_get_meta", - "async_get_data", - "async_put", - "async_set_custom_meta", - "async_clear_samples", -] +_TRANSFER_QUEUE_CLIENT: Any = None +_TRANSFER_QUEUE_STORAGE: Any = None def _maybe_create_transferqueue_client( @@ -77,7 +67,7 @@ def _maybe_create_transferqueue_storage(conf: DictConfig) -> DictConfig: storage_placement_group = get_placement_group(conf.backend.num_data_storage_units, num_cpus_per_actor=1) for storage_unit_rank in range(num_data_storage_units): - storage_node = SimpleStorageUnit.options( + storage_node = SimpleStorageUnit.options( # type: ignore[attr-defined] placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank, name=f"TQStorageUnit#{storage_unit_rank}", @@ -113,9 +103,9 @@ def init(conf: Optional[DictConfig] = None) -> None: try: sampler = globals()[final_conf.controller.sampler] except KeyError: - raise ValueError(f"Could not find sampler {final_conf.controller.sampler}") + raise ValueError(f"Could not find sampler {final_conf.controller.sampler}") from None - controller = TransferQueueController.options(name="TransferQueueController", lifetime="detached").remote( + controller = TransferQueueController.options(name="TransferQueueController", lifetime="detached").remote( # type: ignore[attr-defined] sampler=sampler, polling_mode=final_conf.controller.polling_mode ) @@ -191,7 +181,7 @@ def get_meta( """ tq_client = _maybe_create_transferqueue_client() - tq_client.get_meta(data_fields, batch_size, partition_id, task_name, sampling_config) + return tq_client.get_meta(data_fields, batch_size, partition_id, task_name, sampling_config) async def async_get_meta( @@ -308,12 +298,8 @@ async def async_get_data(metadata: BatchMeta) -> TensorDict: return await tq_client.async_get_data(metadata) -async def async_put( - data: TensorDict, - metadata: Optional[BatchMeta] = None, - partition_id: Optional[str] = None, -) -> BatchMeta: - """Asynchronously write data to storage units based on metadata. +def put(data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None) -> BatchMeta: + """Synchronously write data to storage units based on metadata. If metadata is not provided, it will be created automatically using insert mode with the provided data fields and partition_id. @@ -344,16 +330,16 @@ async def async_put( >>> seq_len = 16 >>> current_partition_id = "train_0" >>> # Example 1: Normal usage with existing metadata - >>> batch_meta = asyncio.run(client.async_get_meta( + >>> batch_meta = client.get_meta( ... data_fields=["prompts", "attention_mask"], ... batch_size=batch_size, ... partition_id=current_partition_id, ... mode="fetch", ... task_name="generate_sequences", - ... )) - >>> batch = asyncio.run(client.async_get_data(batch_meta)) + ... ) + >>> batch = client.get_data(batch_meta) >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)}) - >>> asyncio.run(client.async_put(data=output, metadata=batch_meta)) + >>> client.put(data=output, metadata=batch_meta) >>> >>> # Example 2: Initial data insertion without pre-existing metadata >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id! @@ -366,14 +352,18 @@ async def async_put( >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0) >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated}) >>> # This will create metadata in "insert" mode internally. - >>> metadata = asyncio.run(client.async_put(data=prompts_repeated_batch, partition_id=current_partition_id)) + >>> metadata = client.put(data=prompts_repeated_batch, partition_id=current_partition_id) """ tq_client = _maybe_create_transferqueue_client() - return await tq_client.async_put(data, metadata, partition_id) + return tq_client.put(data, metadata, partition_id) -def put(data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None) -> BatchMeta: - """Synchronously write data to storage units based on metadata. +async def async_put( + data: TensorDict, + metadata: Optional[BatchMeta] = None, + partition_id: Optional[str] = None, +) -> BatchMeta: + """Asynchronously write data to storage units based on metadata. If metadata is not provided, it will be created automatically using insert mode with the provided data fields and partition_id. @@ -404,16 +394,16 @@ def put(data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Op >>> seq_len = 16 >>> current_partition_id = "train_0" >>> # Example 1: Normal usage with existing metadata - >>> batch_meta = client.get_meta( + >>> batch_meta = asyncio.run(client.async_get_meta( ... data_fields=["prompts", "attention_mask"], ... batch_size=batch_size, ... partition_id=current_partition_id, ... mode="fetch", ... task_name="generate_sequences", - ... ) - >>> batch = client.get_data(batch_meta) + ... )) + >>> batch = asyncio.run(client.async_get_data(batch_meta)) >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)}) - >>> client.put(data=output, metadata=batch_meta) + >>> asyncio.run(client.async_put(data=output, metadata=batch_meta)) >>> >>> # Example 2: Initial data insertion without pre-existing metadata >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id! @@ -426,10 +416,10 @@ def put(data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Op >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0) >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated}) >>> # This will create metadata in "insert" mode internally. - >>> metadata = client.put(data=prompts_repeated_batch, partition_id=current_partition_id) + >>> metadata = asyncio.run(client.async_put(data=prompts_repeated_batch, partition_id=current_partition_id)) """ tq_client = _maybe_create_transferqueue_client() - return tq_client.put(data, metadata, partition_id) + return await tq_client.async_put(data, metadata, partition_id) def set_custom_meta(metadata: BatchMeta) -> None: @@ -510,3 +500,18 @@ async def async_clear_samples(metadata: BatchMeta): """ tq_client = _maybe_create_transferqueue_client() return await tq_client.async_clear_samples(metadata) + + +__all__ = [ + "init", + "get_meta", + "get_data", + "put", + "set_custom_meta", + "clear_samples", + "async_get_meta", + "async_get_data", + "async_put", + "async_set_custom_meta", + "async_clear_samples", +] diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index b42b855e..bf711c8f 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -19,10 +19,11 @@ import socket import time from dataclasses import dataclass -from typing import Any, Optional, TypeAlias, Union +from typing import Any, Optional, TypeAlias from uuid import uuid4 import psutil +import ray import zmq from transfer_queue.utils.common import ( @@ -242,8 +243,7 @@ def create_zmq_socket( def process_zmq_server_info( - handlers: dict[Any, Union["TransferQueueController", "TransferQueueStorageManager", "SimpleStorageUnit"]] - | Union["TransferQueueController", "TransferQueueStorageManager", "SimpleStorageUnit"], + handlers: dict[Any, Any] | Any, ): # noqa: UP007 """Extract ZMQ server information from handler objects. From 5eb6d702d283d2c0f87f980ff77ef8fb12585460 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 4 Feb 2026 21:04:52 +0800 Subject: [PATCH 03/14] fix pre-commit Signed-off-by: 0oshowero0 --- transfer_queue/__init__.py | 28 +++++++++++++++++++++-- transfer_queue/interface.py | 45 ++++++++++++++++++++++++------------- 2 files changed, 56 insertions(+), 17 deletions(-) diff --git a/transfer_queue/__init__.py b/transfer_queue/__init__.py index 5800d7a5..a09fa428 100644 --- a/transfer_queue/__init__.py +++ b/transfer_queue/__init__.py @@ -15,10 +15,22 @@ import os -from . import interface from .client import TransferQueueClient from .controller import TransferQueueController from .dataloader import StreamingDataLoader, StreamingDataset +from .interface import ( + async_clear_samples, + async_get_data, + async_get_meta, + async_put, + async_set_custom_meta, + clear_samples, + get_data, + get_meta, + init, + put, + set_custom_meta, +) from .metadata import BatchMeta from .sampler import BaseSampler from .sampler.grpo_group_n_sampler import GRPOGroupNSampler @@ -28,7 +40,19 @@ from .utils.common import get_placement_group from .utils.zmq_utils import ZMQServerInfo, process_zmq_server_info -__all__ = interface.__all__ + [ +__all__ = [ + "init", + "get_meta", + "get_data", + "put", + "set_custom_meta", + "clear_samples", + "async_get_meta", + "async_get_data", + "async_put", + "async_set_custom_meta", + "async_clear_samples", +] + [ "TransferQueueClient", "StreamingDataset", "StreamingDataLoader", diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 68093960..f2a066c9 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -84,6 +84,36 @@ def _maybe_create_transferqueue_storage(conf: DictConfig) -> DictConfig: def init(conf: Optional[DictConfig] = None) -> None: + """Initialize the TransferQueue system. + + This function sets up the TransferQueue controller, distributed storage, and client. + It should be called once at the beginning of the program before any data operations. + + If a controller already exists (e.g., initialized by another process), this function + will retrieve the config from existing controller and initialize the TransferQueueClient. + In this case, the `conf` parameter will be ignored. + + Args: + conf: Optional configuration dictionary. If provided, it will be merged with + the default config from 'config.yaml'. This is only used for first-time + initializing. When connecting to an existing controller, this parameter + is ignored. + + Raises: + ValueError: If config is not valid or required configuration keys are missing. + + Example: + >>> # In process 0, node A + >>> import transfer_queue as tq + >>> tq.init() # Initialize the TransferQueue + >>> tq.put(...) # then you can use tq for data operations + >>> + >>> # In process 1, node B (with Ray connected to node A) + >>> import transfer_queue as tq + >>> tq.init() # This will only initialize a TransferQueueClient and link with existing TQ + >>> metadata = tq.get_meta(...) + >>> data = tq.get_data(metadata) + """ try: # Already initialize TransferQueue controller = ray.get_actor("TransferQueueController") @@ -500,18 +530,3 @@ async def async_clear_samples(metadata: BatchMeta): """ tq_client = _maybe_create_transferqueue_client() return await tq_client.async_clear_samples(metadata) - - -__all__ = [ - "init", - "get_meta", - "get_data", - "put", - "set_custom_meta", - "clear_samples", - "async_get_meta", - "async_get_data", - "async_put", - "async_set_custom_meta", - "async_clear_samples", -] From 512e6595452b480572491024cbb749bcdc579670 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 4 Feb 2026 21:16:02 +0800 Subject: [PATCH 04/14] fix ut Signed-off-by: 0oshowero0 --- README.md | 6 ++--- tests/test_client.py | 4 +-- transfer_queue/storage/managers/factory.py | 29 +++++++++++++++++++--- 3 files changed, 31 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index b69176d8..2c7d4f0f 100644 --- a/README.md +++ b/README.md @@ -73,10 +73,10 @@ This class encapsulates the core interaction logic within the TransferQueue syst Currently, we support the following storage backends: -- SimpleStorageUnit: A basic CPU memory storage with minimal data format constraints and easy usability. +- SimpleStorage: A basic CPU memory storage with minimal data format constraints and easy usability. - [Yuanrong](https://gitee.com/openeuler/yuanrong-datasystem) (beta, [#PR107](https://github.com/TransferQueue/TransferQueue/pull/107), [#PR96](https://github.com/TransferQueue/TransferQueue/pull/96)): An Ascend native data system that provides hierarchical storage interfaces including HBM/DRAM/SSD. -- [Mooncake Store](https://github.com/kvcache-ai/Mooncake) (alpha, [#PR162](https://github.com/TransferQueue/TransferQueue/pull/162)): A high-performance, KV-based hierarchical storage that supports RDMA transport between GPU and DRAM. -- [Ray Direct Transport](https://docs.ray.io/en/master/ray-core/direct-transport.html) (alpha, [#PR167](https://github.com/TransferQueue/TransferQueue/pull/167)): Ray's new feature that allows Ray to store and pass objects directly between Ray actors. +- [MooncakeStore](https://github.com/kvcache-ai/Mooncake) (alpha, [#PR162](https://github.com/TransferQueue/TransferQueue/pull/162)): A high-performance, KV-based hierarchical storage that supports RDMA transport between GPU and DRAM. +- [RayRDT](https://docs.ray.io/en/master/ray-core/direct-transport.html) (alpha, [#PR167](https://github.com/TransferQueue/TransferQueue/pull/167)): Ray's new feature that allows Ray to store and pass objects directly between Ray actors. Among them, `SimpleStorageUnit` serves as our default storage backend, coordinated by the `AsyncSimpleStorageManager` class. Each storage unit can be deployed on a separate node, allowing for distributed data management. diff --git a/tests/test_client.py b/tests/test_client.py index 38f140b0..7a8037a1 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -320,7 +320,7 @@ def client_setup(mock_controller, mock_storage): "controller_info": mock_controller.zmq_server_info, "storage_unit_infos": {mock_storage.storage_id: mock_storage.zmq_server_info}, } - client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config) + client.initialize_storage_manager(manager_type="SimpleStorage", config=config) # Mock all storage manager methods to avoid real ZMQ operations async def mock_put_data(data, metadata): @@ -413,7 +413,7 @@ def test_single_controller_multiple_storages(): "controller_info": controller.zmq_server_info, "storage_unit_infos": {s.storage_id: s.zmq_server_info for s in storages}, } - client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config) + client.initialize_storage_manager(manager_type="SimpleStorage", config=config) # Mock all storage manager methods to avoid real ZMQ operations async def mock_put_data(data, metadata): diff --git a/transfer_queue/storage/managers/factory.py b/transfer_queue/storage/managers/factory.py index e595ccd8..84738e1f 100644 --- a/transfer_queue/storage/managers/factory.py +++ b/transfer_queue/storage/managers/factory.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from typing import Any from transfer_queue.storage.managers.base import TransferQueueStorageManager @@ -42,7 +43,29 @@ def decorator(manager_cls: type[TransferQueueStorageManager]): def create(cls, manager_type: str, config: dict[str, Any]) -> TransferQueueStorageManager: """Create and return a TransferQueueStorageManager instance.""" if manager_type not in cls._registry: - raise ValueError( - f"Unknown manager_type: {manager_type}. Supported managers include: {list(cls._registry.keys())}" - ) + if manager_type == "AsyncSimpleStorageManager": + warnings.warn( + f"The manager_type {manager_type} is deprecated, use SimpleStorage instead.", + category=DeprecationWarning, + stacklevel=2, + ) + manager_type = "SimpleStorage" + elif manager_type == "MooncakeStorageManager": + warnings.warn( + f"The manager_type {manager_type} is deprecated, use MooncakeStore instead.", + category=DeprecationWarning, + stacklevel=2, + ) + manager_type = "MooncakeStore" + elif manager_type == "YuanrongStorageManager": + warnings.warn( + f"The manager_type {manager_type} is deprecated, use Yuanrong instead.", + category=DeprecationWarning, + stacklevel=2, + ) + manager_type = "Yuanrong" + else: + raise ValueError( + f"Unknown manager_type: {manager_type}. Supported managers include: {list(cls._registry.keys())}" + ) return cls._registry[manager_type](config) From 5d6146efd8cef1b2073ac217fae59427a9e3d978 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 4 Feb 2026 21:46:25 +0800 Subject: [PATCH 05/14] fix Signed-off-by: 0oshowero0 --- pyproject.toml | 3 +- transfer_queue/interface.py | 63 +++++++++++--- transfer_queue/storage/managers/base.py | 3 +- .../managers/simple_backend_manager.py | 17 +++- tutorial/01_core_components.py | 85 +++---------------- 5 files changed, 77 insertions(+), 94 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f8539700..35d65242 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -126,5 +126,6 @@ yuanrong = [ # This is the rough equivalent of package_data={'': ['version/*']} [tool.setuptools.package-data] transfer_queue = [ - "version/*", + "version/*", + "*.yaml" ] \ No newline at end of file diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index f2a066c9..c5ec87e2 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -13,9 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import importlib.resources as pkg_resources import logging import math import os +import time from typing import Any, Optional import ray @@ -43,13 +45,14 @@ def _maybe_create_transferqueue_client( global _TRANSFER_QUEUE_CLIENT if _TRANSFER_QUEUE_CLIENT is None: if conf is None: - raise ValueError("Missing config for initialing TranferQueueClient!") + raise ValueError("Missing config for initialing TransferQueueClient!") pid = os.getpid() _TRANSFER_QUEUE_CLIENT = TransferQueueClient( - client_id=f"TQClient_{pid}", controller_info=conf.controller.zmq_info + client_id=f"TransferQueueClient_{pid}", controller_info=conf.controller.zmq_info ) backend_name = conf.backend.storage_backend + conf.backend[backend_name].controller_info = conf.controller.zmq_info _TRANSFER_QUEUE_CLIENT.initialize_storage_manager(manager_type=backend_name, config=conf.backend[backend_name]) return _TRANSFER_QUEUE_CLIENT @@ -59,25 +62,25 @@ def _maybe_create_transferqueue_storage(conf: DictConfig) -> DictConfig: global _TRANSFER_QUEUE_STORAGE if _TRANSFER_QUEUE_STORAGE is None: - _TRANSFER_QUEUE_STORAGE = [] + _TRANSFER_QUEUE_STORAGE = {} if conf.backend.storage_backend == "SimpleStorage": # initialize SimpleStorageUnit - num_data_storage_units = conf.backend.num_data_storage_units - total_storage_size = conf.backend.total_storage_size - storage_placement_group = get_placement_group(conf.backend.num_data_storage_units, num_cpus_per_actor=1) + num_data_storage_units = conf.backend.SimpleStorage.num_data_storage_units + total_storage_size = conf.backend.SimpleStorage.total_storage_size + storage_placement_group = get_placement_group(num_data_storage_units, num_cpus_per_actor=1) for storage_unit_rank in range(num_data_storage_units): storage_node = SimpleStorageUnit.options( # type: ignore[attr-defined] placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank, - name=f"TQStorageUnit#{storage_unit_rank}", + name=f"TransferQueueStorageUnit#{storage_unit_rank}", lifetime="detached", ).remote(storage_unit_size=math.ceil(total_storage_size / num_data_storage_units)) - _TRANSFER_QUEUE_STORAGE.append(storage_node) - print(f"TQStorageUnit#{storage_unit_rank} has been created.") + _TRANSFER_QUEUE_STORAGE[f"TransferQueueStorageUnit#{storage_unit_rank}"] = storage_node + print(f"TransferQueueStorageUnit#{storage_unit_rank} has been created.") # extract zmq info - storage_zmq_info = process_zmq_server_info(_TRANSFER_QUEUE_CLIENT) + storage_zmq_info = process_zmq_server_info(_TRANSFER_QUEUE_STORAGE) conf.backend.zmq_info = storage_zmq_info return conf @@ -118,16 +121,22 @@ def init(conf: Optional[DictConfig] = None) -> None: # Already initialize TransferQueue controller = ray.get_actor("TransferQueueController") conf = ray.get(controller.get_config()) + while conf is None: + print("Waiting for controller to initialize... Retrying") + time.sleep(1) + conf = ray.get_actor("TransferQueueController").get_config() _maybe_create_transferqueue_client(conf) except ValueError: # First-time initialize TransferQueue - + # TODO: 是否会有竞态?文件锁好像都没法解决,需要调查Ray是否有相关机制 # create config final_conf = OmegaConf.create({}, flags={"allow_objects": True}) - default_conf = OmegaConf.load("config.yaml") + with pkg_resources.path("transfer_queue", "config.yaml") as p: + default_conf = OmegaConf.load(p) final_conf = OmegaConf.merge(final_conf, default_conf) - final_conf = OmegaConf.merge(final_conf, conf) + if conf: + final_conf = OmegaConf.merge(final_conf, conf) # create controller try: @@ -146,7 +155,7 @@ def init(conf: Optional[DictConfig] = None) -> None: final_conf = _maybe_create_transferqueue_storage(final_conf) # storage the config into controller - controller.store_config(final_conf) + ray.get(controller.store_config.remote(final_conf)) # create client _maybe_create_transferqueue_client(final_conf) @@ -530,3 +539,29 @@ async def async_clear_samples(metadata: BatchMeta): """ tq_client = _maybe_create_transferqueue_client() return await tq_client.async_clear_samples(metadata) + + +def clear_partition(partition_id: str): + """Synchronously clear the whole partition from all storage units and the controller. + + Args: + partition_id: The partition id to clear data for + + Raises: + RuntimeError: If clear operation fails + """ + tq_client = _maybe_create_transferqueue_client() + return tq_client.clear_partition(partition_id) + + +async def async_clear_partition(partition_id: str): + """Asynchronously clear the whole partition from all storage units and the controller. + + Args: + partition_id: The partition id to clear data for + + Raises: + RuntimeError: If clear operation fails + """ + tq_client = _maybe_create_transferqueue_client() + return await tq_client.async_clear_partition(partition_id) diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index c07ec1f9..c48b5991 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -28,6 +28,7 @@ import ray import torch import zmq +from omegaconf import DictConfig from tensordict import NonTensorStack, TensorDict from torch import Tensor @@ -59,7 +60,7 @@ class TransferQueueStorageManager(ABC): """Base class for storage layer. It defines the interface for data operations and generally provides handshake & notification capabilities.""" - def __init__(self, config: dict[str, Any]): + def __init__(self, config: DictConfig[str, Any]): self.storage_manager_id = f"TQ_STORAGE_{uuid4().hex[:8]}" self.config = config controller_info = config.get("controller_info") diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 4fb78bc9..54eebb01 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -16,6 +16,7 @@ import asyncio import logging import os +import warnings from collections.abc import Mapping from functools import wraps from operator import itemgetter @@ -24,6 +25,7 @@ import torch import zmq +from omegaconf import DictConfig from tensordict import NonTensorStack, TensorDict from transfer_queue.metadata import BatchMeta @@ -56,14 +58,23 @@ class AsyncSimpleStorageManager(TransferQueueStorageManager): instances using ZMQ communication and dynamic socket management. """ - def __init__(self, config: dict[str, Any]): + def __init__(self, config: DictConfig[str, Any]): super().__init__(config) self.config = config - server_infos: ZMQServerInfo | dict[str, ZMQServerInfo] | None = config.get("storage_unit_infos", None) + server_infos: ZMQServerInfo | dict[str, ZMQServerInfo] | None = config.zmq_info if server_infos is None: - raise ValueError("AsyncSimpleStorageManager requires non-empty 'storage_unit_infos' in config.") + server_infos = config.storage_unit_infos + if server_infos is not None: + warnings.warn( + "The config entry `storage_unit_infos` is deprecated, use `zmq_info` instead.", + category=DeprecationWarning, + stacklevel=2, + ) + + if server_infos is None: + raise ValueError("AsyncSimpleStorageManager requires non-empty 'zmq_info' in config.") self.storage_unit_infos = self._register_servers(server_infos) self._build_storage_mapping_functions() diff --git a/tutorial/01_core_components.py b/tutorial/01_core_components.py index 25b530c7..b250d4cf 100644 --- a/tutorial/01_core_components.py +++ b/tutorial/01_core_components.py @@ -37,87 +37,23 @@ import ray # noqa: E402 import torch # noqa: E402 -from omegaconf import OmegaConf # noqa: E402 from tensordict import TensorDict # noqa: E402 # Add the parent directory to the path parent_dir = Path(__file__).resolve().parent.parent sys.path.append(str(parent_dir)) -from transfer_queue import ( # noqa: E402 - SimpleStorageUnit, - TransferQueueClient, - TransferQueueController, - process_zmq_server_info, -) +import transfer_queue as tq # noqa: E402 # Configure Ray os.environ["RAY_DEDUP_LOGS"] = "0" os.environ["RAY_DEBUG"] = "1" - -def demonstrate_basic_setup(): - """ - Demonstrate the basic setup of TransferQueue with three core components. - """ - - # Initialize Ray - if not ray.is_initialized(): - ray.init() - - # Configuration - config = OmegaConf.create( - { - "num_data_storage_units": 2, - } - ) - - print("[Step 1] Creating Storage Backend (using default SimpleStorageUnit)...") - storage_units = {} - for i in range(config["num_data_storage_units"]): - storage_units[i] = SimpleStorageUnit.remote(storage_unit_size=100) - print(f" ✓ Created SimpleStorageUnit #{i}") - - print("[Step 2] Creating TransferQueueController...") - controller = TransferQueueController.remote() - print(" ✓ Controller created - manages data state") - - # Get server information - controller_info = process_zmq_server_info(controller) - storage_unit_infos = process_zmq_server_info(storage_units) - - # Create Client (User-facing API) - print("[Step 3] Creating TransferQueueClient...") - client = TransferQueueClient( - client_id="TutorialClient", - controller_info=controller_info, - ) - print(" ✓ Client created - this is what users interact with!") - - # Initialize storage manager - tq_config = OmegaConf.create({}, flags={"allow_objects": True}) - tq_config.controller_info = controller_info - tq_config.storage_unit_infos = storage_unit_infos - config = OmegaConf.merge(tq_config, config) - - client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config) - print( - " ✓ Storage manager initialized. It is a class variable inside the client, acting as an adapter to " - "suit for various storage backends." - ) - - print("[Architecture Summary]") - print( - " - TransferQueueController: Tracking the production/consumption status as metadata (can define your own " - "data consumption logics)." - ) - print(" - SimpleStorageUnit: Distributed data storage that holds actual data (easily swap out by other backends).") - print(" - TransferQueueClient: User interface that allows you to put/get/clear data or metadata)") - - return controller, storage_units, client +if not ray.is_initialized(): + ray.init() -def demonstrate_data_workflow(client): +def demonstrate_data_workflow(): """ Demonstrate basic data workflow: put → get → clear. """ @@ -148,12 +84,12 @@ def demonstrate_data_workflow(client): print(f" Created {data_batch.batch_size[0]} samples") partition_id = "tutorial_partition_0" - client.put(data=data_batch, partition_id=partition_id) + tq.put(data=data_batch, partition_id=partition_id) print(f" ✓ Data written to partition: {partition_id}") # Step 2: Get metadata print("[Step 2] Requesting data metadata...") - batch_meta = client.get_meta( + batch_meta = tq.get_meta( data_fields=["input_ids", "attention_mask"], batch_size=data_batch.batch_size[0], partition_id=partition_id, @@ -164,7 +100,7 @@ def demonstrate_data_workflow(client): # Step 3: Get actual data print("[Step 3] Retrieving actual data...") - retrieved_data = client.get_data(batch_meta) + retrieved_data = tq.get_data(batch_meta) print(" ✓ Data retrieved successfully") print(f" Keys: {list(retrieved_data.keys())}") @@ -176,7 +112,7 @@ def demonstrate_data_workflow(client): # Step 5: Clear print("[Step 5] Clearing partition... (you may also use clear_samples() to clear specific samples)") - client.clear_partition(partition_id=partition_id) + tq.clear_partition(partition_id=partition_id) print(" ✓ Partition cleared") @@ -234,10 +170,10 @@ def main(): try: print("Setting up TransferQueue...") - controller, storage_units, client = demonstrate_basic_setup() + tq.init() print("Demonstrating the user workflow...") - demonstrate_data_workflow(client) + demonstrate_data_workflow() demonstrate_storage_backend_options() @@ -253,7 +189,6 @@ def main(): print("3. You can swap out different storage backends easily") # Cleanup - client.close() ray.shutdown() print("\n✓ Cleanup complete") From 812e55f8911221a6703e12ef0a7926344c69a1e2 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 4 Feb 2026 21:54:45 +0800 Subject: [PATCH 06/14] fix Signed-off-by: 0oshowero0 --- transfer_queue/interface.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index c5ec87e2..d7715eea 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -124,7 +124,7 @@ def init(conf: Optional[DictConfig] = None) -> None: while conf is None: print("Waiting for controller to initialize... Retrying") time.sleep(1) - conf = ray.get_actor("TransferQueueController").get_config() + conf = ray.get(ray.get_actor("TransferQueueController").get_config()) _maybe_create_transferqueue_client(conf) except ValueError: @@ -140,6 +140,7 @@ def init(conf: Optional[DictConfig] = None) -> None: # create controller try: + # TODO: support sampler instance sampler = globals()[final_conf.controller.sampler] except KeyError: raise ValueError(f"Could not find sampler {final_conf.controller.sampler}") from None From 3365f30d11c9c680d667f305a3cd275fa404107b Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 4 Feb 2026 22:20:47 +0800 Subject: [PATCH 07/14] fix Signed-off-by: 0oshowero0 --- transfer_queue/__init__.py | 5 ++ transfer_queue/config.yaml | 4 +- transfer_queue/interface.py | 80 +++++++++++-------- .../managers/simple_backend_manager.py | 4 +- tutorial/01_core_components.py | 2 +- 5 files changed, 55 insertions(+), 40 deletions(-) diff --git a/transfer_queue/__init__.py b/transfer_queue/__init__.py index a09fa428..2ed875da 100644 --- a/transfer_queue/__init__.py +++ b/transfer_queue/__init__.py @@ -19,11 +19,13 @@ from .controller import TransferQueueController from .dataloader import StreamingDataLoader, StreamingDataset from .interface import ( + async_clear_partition, async_clear_samples, async_get_data, async_get_meta, async_put, async_set_custom_meta, + clear_partition, clear_samples, get_data, get_meta, @@ -47,11 +49,14 @@ "put", "set_custom_meta", "clear_samples", + "clear_partition", + "async_clear_samples", "async_get_meta", "async_get_data", "async_put", "async_set_custom_meta", "async_clear_samples", + "async_clear_partition", ] + [ "TransferQueueClient", "StreamingDataset", diff --git a/transfer_queue/config.yaml b/transfer_queue/config.yaml index ea27d7f3..59cc369e 100644 --- a/transfer_queue/config.yaml +++ b/transfer_queue/config.yaml @@ -6,7 +6,7 @@ controller: sampler: SequentialSampler polling_mode: False # ZMQ Server IP & Ports (automatically generated during init) - zmq_info: None + zmq_info: null backend: # Pluggable storage/transport backend of TransferQueue. Choose from: @@ -20,7 +20,7 @@ backend: # Number of distributed storage units for SimpleStorage backend num_data_storage_units: 2 # ZMQ Server IP & Ports (automatically generated during init) - zmq_info: None + zmq_info: null # For Yuanrong: # TODO diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index d7715eea..e6a36a6a 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -53,6 +53,7 @@ def _maybe_create_transferqueue_client( backend_name = conf.backend.storage_backend conf.backend[backend_name].controller_info = conf.controller.zmq_info + _TRANSFER_QUEUE_CLIENT.initialize_storage_manager(manager_type=backend_name, config=conf.backend[backend_name]) return _TRANSFER_QUEUE_CLIENT @@ -77,11 +78,12 @@ def _maybe_create_transferqueue_storage(conf: DictConfig) -> DictConfig: lifetime="detached", ).remote(storage_unit_size=math.ceil(total_storage_size / num_data_storage_units)) _TRANSFER_QUEUE_STORAGE[f"TransferQueueStorageUnit#{storage_unit_rank}"] = storage_node - print(f"TransferQueueStorageUnit#{storage_unit_rank} has been created.") + logger.info(f"TransferQueueStorageUnit#{storage_unit_rank} has been created.") # extract zmq info storage_zmq_info = process_zmq_server_info(_TRANSFER_QUEUE_STORAGE) - conf.backend.zmq_info = storage_zmq_info + backend_name = conf.backend.storage_backend + conf.backend[backend_name].zmq_info = storage_zmq_info return conf @@ -120,46 +122,54 @@ def init(conf: Optional[DictConfig] = None) -> None: try: # Already initialize TransferQueue controller = ray.get_actor("TransferQueueController") - conf = ray.get(controller.get_config()) + logger.info("Found existing TransferQueueController instance. Connecting...") + while conf is None: - print("Waiting for controller to initialize... Retrying") - time.sleep(1) - conf = ray.get(ray.get_actor("TransferQueueController").get_config()) - _maybe_create_transferqueue_client(conf) + remote_conf = ray.get(controller.get_config()) + if remote_conf is not None: + _maybe_create_transferqueue_client(remote_conf) + logger.info("TransferQueueClient initialized.") + return + logger.debug("Waiting for controller to initialize... Retrying in 1s") + time.sleep(1) except ValueError: - # First-time initialize TransferQueue - # TODO: 是否会有竞态?文件锁好像都没法解决,需要调查Ray是否有相关机制 - # create config - final_conf = OmegaConf.create({}, flags={"allow_objects": True}) - with pkg_resources.path("transfer_queue", "config.yaml") as p: - default_conf = OmegaConf.load(p) - final_conf = OmegaConf.merge(final_conf, default_conf) - if conf: - final_conf = OmegaConf.merge(final_conf, conf) - - # create controller - try: - # TODO: support sampler instance - sampler = globals()[final_conf.controller.sampler] - except KeyError: - raise ValueError(f"Could not find sampler {final_conf.controller.sampler}") from None - - controller = TransferQueueController.options(name="TransferQueueController", lifetime="detached").remote( # type: ignore[attr-defined] - sampler=sampler, polling_mode=final_conf.controller.polling_mode - ) + logger.info("No TransferQueueController found. Starting first-time initialization...") + + # First-time initialize TransferQueue + # TODO: fix possible race condition, especially for cross-node scenario + + # create config + final_conf = OmegaConf.create({}, flags={"allow_objects": True}) + with pkg_resources.path("transfer_queue", "config.yaml") as p: + default_conf = OmegaConf.load(p) + final_conf = OmegaConf.merge(final_conf, default_conf) + if conf: + final_conf = OmegaConf.merge(final_conf, conf) + + # create controller + try: + # TODO: support sampler instance + sampler = globals()[final_conf.controller.sampler] + except KeyError: + raise ValueError(f"Could not find sampler {final_conf.controller.sampler}") from None + + controller = TransferQueueController.options(name="TransferQueueController", lifetime="detached").remote( # type: ignore[attr-defined] + sampler=sampler, polling_mode=final_conf.controller.polling_mode + ) + logger.info("TransferQueueController has been created.") - controller_zmq_info = process_zmq_server_info(controller) - final_conf.controller.zmq_info = controller_zmq_info + controller_zmq_info = process_zmq_server_info(controller) + final_conf.controller.zmq_info = controller_zmq_info - # create distributed storage backends - final_conf = _maybe_create_transferqueue_storage(final_conf) + # create distributed storage backends + final_conf = _maybe_create_transferqueue_storage(final_conf) - # storage the config into controller - ray.get(controller.store_config.remote(final_conf)) + # storage the config into controller + ray.get(controller.store_config.remote(final_conf)) - # create client - _maybe_create_transferqueue_client(final_conf) + # create client + _maybe_create_transferqueue_client(final_conf) def get_meta( diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 54eebb01..d1211ecf 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -62,10 +62,10 @@ def __init__(self, config: DictConfig[str, Any]): super().__init__(config) self.config = config - server_infos: ZMQServerInfo | dict[str, ZMQServerInfo] | None = config.zmq_info + server_infos: ZMQServerInfo | dict[str, ZMQServerInfo] | None = config.get("zmq_info", None) if server_infos is None: - server_infos = config.storage_unit_infos + server_infos = config.get("storage_unit_infos", None) if server_infos is not None: warnings.warn( "The config entry `storage_unit_infos` is deprecated, use `zmq_info` instead.", diff --git a/tutorial/01_core_components.py b/tutorial/01_core_components.py index b250d4cf..6d6f32ba 100644 --- a/tutorial/01_core_components.py +++ b/tutorial/01_core_components.py @@ -50,7 +50,7 @@ os.environ["RAY_DEBUG"] = "1" if not ray.is_initialized(): - ray.init() + ray.init(namespace="TransferQueueTutorial") def demonstrate_data_workflow(): From 3fd78d248de7ce018638e4d3c9abdda90b49e8c0 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 5 Feb 2026 09:57:03 +0800 Subject: [PATCH 08/14] fix race condition Signed-off-by: 0oshowero0 --- transfer_queue/client.py | 2 + transfer_queue/config.yaml | 1 + transfer_queue/controller.py | 24 +++++++ transfer_queue/interface.py | 132 +++++++++++++++++++++++------------ 4 files changed, 116 insertions(+), 43 deletions(-) diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 68b746f7..3c6f05f5 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -1037,6 +1037,7 @@ def get_meta( data_fields: list[str], batch_size: int, partition_id: str, + mode: str = "fetch", task_name: Optional[str] = None, sampling_config: Optional[dict[str, Any]] = None, ) -> BatchMeta: @@ -1094,6 +1095,7 @@ def get_meta( data_fields=data_fields, batch_size=batch_size, partition_id=partition_id, + mode=mode, task_name=task_name, sampling_config=sampling_config, ) diff --git a/transfer_queue/config.yaml b/transfer_queue/config.yaml index 59cc369e..3af85466 100644 --- a/transfer_queue/config.yaml +++ b/transfer_queue/config.yaml @@ -8,6 +8,7 @@ controller: # ZMQ Server IP & Ports (automatically generated during init) zmq_info: null + backend: # Pluggable storage/transport backend of TransferQueue. Choose from: # SimpleStorage, Yuanrong, MooncakeStore, ... diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 07f1abf8..eb1d49c5 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -1782,3 +1782,27 @@ def store_config(self, conf: DictConfig) -> None: def get_config(self) -> DictConfig: """Retrieve the global config of TransferQueue.""" return self.tq_conf + + def register_sampler( + self, + sampler: BaseSampler | type[BaseSampler] = SequentialSampler, + ) -> None: + """ + Register a sampler instance or subclass after the controller is initialized. + + Args: + sampler: Sampler instance or sampler class to use for data sampling. + - If a BaseSampler instance is provided, it will be used directly + - If a BaseSampler subclass is provided, it will be instantiated + - Defaults to SequentialSampler for simple sequential sampling + - Example: sampler=GRPOGroupNSampler() (instance) + - Example: sampler=SequentialSampler (class) + """ + if isinstance(sampler, BaseSampler): + self.sampler = sampler + elif isinstance(sampler, type) and issubclass(sampler, BaseSampler): + self.sampler = sampler() + else: + raise TypeError( + f"sampler {getattr(sampler, '__name__', repr(sampler))} must be an instance or subclass of BaseSampler" + ) diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index e6a36a6a..f3bbb3ed 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -28,6 +28,7 @@ from transfer_queue.controller import TransferQueueController from transfer_queue.metadata import BatchMeta from transfer_queue.sampler import * # noqa: F401 +from transfer_queue.sampler import BaseSampler from transfer_queue.storage.simple_backend import SimpleStorageUnit from transfer_queue.utils.common import get_placement_group from transfer_queue.utils.zmq_utils import process_zmq_server_info @@ -88,6 +89,24 @@ def _maybe_create_transferqueue_storage(conf: DictConfig) -> DictConfig: return conf +def _init_from_existing() -> None: + """Initialize the TransferQueueClient from existing controller.""" + + controller = ray.get_actor("TransferQueueController") + logger.info("Found existing TransferQueueController instance. Connecting...") + + conf = None + while conf is None: + remote_conf = ray.get(controller.get_config()) + if remote_conf is not None: + _maybe_create_transferqueue_client(remote_conf) + logger.info("TransferQueueClient initialized.") + return + + logger.debug("Waiting for controller to initialize... Retrying in 1s") + time.sleep(1) + + def init(conf: Optional[DictConfig] = None) -> None: """Initialize the TransferQueue system. @@ -120,24 +139,11 @@ def init(conf: Optional[DictConfig] = None) -> None: >>> data = tq.get_data(metadata) """ try: - # Already initialize TransferQueue - controller = ray.get_actor("TransferQueueController") - logger.info("Found existing TransferQueueController instance. Connecting...") - - while conf is None: - remote_conf = ray.get(controller.get_config()) - if remote_conf is not None: - _maybe_create_transferqueue_client(remote_conf) - logger.info("TransferQueueClient initialized.") - return - - logger.debug("Waiting for controller to initialize... Retrying in 1s") - time.sleep(1) + _init_from_existing() except ValueError: logger.info("No TransferQueueController found. Starting first-time initialization...") # First-time initialize TransferQueue - # TODO: fix possible race condition, especially for cross-node scenario # create config final_conf = OmegaConf.create({}, flags={"allow_objects": True}) @@ -149,15 +155,30 @@ def init(conf: Optional[DictConfig] = None) -> None: # create controller try: - # TODO: support sampler instance - sampler = globals()[final_conf.controller.sampler] + sampler = final_conf.conroller.sampler + if isinstance(sampler, BaseSampler): + # user pass a pre-initialized sampler instance + sampler = sampler + elif isinstance(sampler, type) and issubclass(sampler, BaseSampler): + # user pass a sampler class + sampler = sampler() + elif isinstance(sampler, str): + # user pass a sampler name str + # try to convert as sampler class + sampler = globals()[final_conf.controller.sampler] except KeyError: raise ValueError(f"Could not find sampler {final_conf.controller.sampler}") from None - controller = TransferQueueController.options(name="TransferQueueController", lifetime="detached").remote( # type: ignore[attr-defined] - sampler=sampler, polling_mode=final_conf.controller.polling_mode - ) - logger.info("TransferQueueController has been created.") + try: + # Ray will make sure actor with same name can only be created once + controller = TransferQueueController.options(name="TransferQueueController", lifetime="detached").remote( # type: ignore[attr-defined] + sampler=sampler, polling_mode=final_conf.controller.polling_mode + ) + logger.info("TransferQueueController has been created.") + except ValueError: + logger.info("Some other rank has initialized TransferQueueController. Try to connect to existing controller.") + _init_from_existing() + return controller_zmq_info = process_zmq_server_info(controller) final_conf.controller.zmq_info = controller_zmq_info @@ -176,6 +197,7 @@ def get_meta( data_fields: list[str], batch_size: int, partition_id: str, + mode: str = "fetch", task_name: Optional[str] = None, sampling_config: Optional[dict[str, Any]] = None, ) -> BatchMeta: @@ -200,8 +222,11 @@ def get_meta( RuntimeError: If communication fails or controller returns error response Example: + >>> import transfer_queue as tq + >>> tq.init() + >>> >>> # Example 1: Basic fetch metadata - >>> batch_meta = client.get_meta( + >>> batch_meta = tq.get_meta( ... data_fields=["input_ids", "attention_mask"], ... batch_size=4, ... partition_id="train_0", @@ -211,7 +236,7 @@ def get_meta( >>> print(batch_meta.is_ready) # True if all samples ready >>> >>> # Example 2: Fetch with self-defined samplers (using GRPOGroupNSampler as an example) - >>> batch_meta = client.get_meta( + >>> batch_meta = tq.get_meta( ... data_fields=["input_ids", "attention_mask"], ... batch_size=8, ... partition_id="train_0", @@ -223,7 +248,7 @@ def get_meta( >>> >>> # Example 3: Force fetch metadata (bypass production status check and Sampler, >>> # so may include unready and already-consumed samples. No filtering by consumption status is applied.) - >>> batch_meta = client.get_meta( + >>> batch_meta = tq.get_meta( ... partition_id="train_0", # optional ... mode="force_fetch", ... ) @@ -231,7 +256,7 @@ def get_meta( """ tq_client = _maybe_create_transferqueue_client() - return tq_client.get_meta(data_fields, batch_size, partition_id, task_name, sampling_config) + return tq_client.get_meta(data_fields, batch_size, partition_id, mode, task_name, sampling_config) async def async_get_meta( @@ -263,8 +288,11 @@ async def async_get_meta( RuntimeError: If communication fails or controller returns error response Example: + >>> import transfer_queue as tq + >>> tq.init() + >>> >>> # Example 1: Basic fetch metadata - >>> batch_meta = asyncio.run(client.async_get_meta( + >>> batch_meta = asyncio.run(tq.async_get_meta( ... data_fields=["input_ids", "attention_mask"], ... batch_size=4, ... partition_id="train_0", @@ -274,7 +302,7 @@ async def async_get_meta( >>> print(batch_meta.is_ready) # True if all samples ready >>> >>> # Example 2: Fetch with self-defined samplers (using GRPOGroupNSampler as an example) - >>> batch_meta = asyncio.run(client.async_get_meta( + >>> batch_meta = asyncio.run(tq.async_get_meta( ... data_fields=["input_ids", "attention_mask"], ... batch_size=8, ... partition_id="train_0", @@ -285,7 +313,7 @@ async def async_get_meta( >>> >>> # Example 3: Force fetch metadata (bypass production status check and Sampler, >>> # so may include unready and already-consumed samples. No filtering by consumption status is applied.) - >>> batch_meta = asyncio.run(client.async_get_meta( + >>> batch_meta = asyncio.run(tq.async_get_meta( ... partition_id="train_0", # optional ... mode="force_fetch", ... )) @@ -307,14 +335,17 @@ def get_data(metadata: BatchMeta) -> TensorDict: - Requested data fields (e.g., "prompts", "attention_mask") Example: - >>> batch_meta = client.get_data( + >>> import transfer_queue as tq + >>> tq.init() + >>> + >>> batch_meta = tq.get_data( ... data_fields=["prompts", "attention_mask"], ... batch_size=4, ... partition_id="train_0", ... mode="fetch", ... task_name="generate_sequences", ... ) - >>> batch = client.get_data(batch_meta) + >>> batch = tq.get_data(batch_meta) >>> print(batch) >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes """ @@ -333,14 +364,17 @@ async def async_get_data(metadata: BatchMeta) -> TensorDict: - Requested data fields (e.g., "prompts", "attention_mask") Example: - >>> batch_meta = asyncio.run(client.async_get_meta( + >>> import transfer_queue as tq + >>> tq.init() + >>> + >>> batch_meta = asyncio.run(tq.async_get_meta( ... data_fields=["prompts", "attention_mask"], ... batch_size=4, ... partition_id="train_0", ... mode="fetch", ... task_name="generate_sequences", ... )) - >>> batch = asyncio.run(client.async_get_data(batch_meta)) + >>> batch = asyncio.run(tq.async_get_data(batch_meta)) >>> print(batch) >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes """ @@ -376,20 +410,23 @@ def put(data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Op RuntimeError: If storage operation fails Example: + >>> import transfer_queue as tq + >>> tq.init() + >>> >>> batch_size = 4 >>> seq_len = 16 >>> current_partition_id = "train_0" >>> # Example 1: Normal usage with existing metadata - >>> batch_meta = client.get_meta( + >>> batch_meta = tq.get_meta( ... data_fields=["prompts", "attention_mask"], ... batch_size=batch_size, ... partition_id=current_partition_id, ... mode="fetch", ... task_name="generate_sequences", ... ) - >>> batch = client.get_data(batch_meta) + >>> batch = tq.get_data(batch_meta) >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)}) - >>> client.put(data=output, metadata=batch_meta) + >>> tq.put(data=output, metadata=batch_meta) >>> >>> # Example 2: Initial data insertion without pre-existing metadata >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id! @@ -402,7 +439,7 @@ def put(data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Op >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0) >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated}) >>> # This will create metadata in "insert" mode internally. - >>> metadata = client.put(data=prompts_repeated_batch, partition_id=current_partition_id) + >>> metadata = tq.put(data=prompts_repeated_batch, partition_id=current_partition_id) """ tq_client = _maybe_create_transferqueue_client() return tq_client.put(data, metadata, partition_id) @@ -440,20 +477,23 @@ async def async_put( RuntimeError: If storage operation fails Example: + >>> import transfer_queue as tq + >>> tq.init() + >>> >>> batch_size = 4 >>> seq_len = 16 >>> current_partition_id = "train_0" >>> # Example 1: Normal usage with existing metadata - >>> batch_meta = asyncio.run(client.async_get_meta( + >>> batch_meta = asyncio.run(tq.async_get_meta( ... data_fields=["prompts", "attention_mask"], ... batch_size=batch_size, ... partition_id=current_partition_id, ... mode="fetch", ... task_name="generate_sequences", ... )) - >>> batch = asyncio.run(client.async_get_data(batch_meta)) + >>> batch = asyncio.run(tq.async_get_data(batch_meta)) >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)}) - >>> asyncio.run(client.async_put(data=output, metadata=batch_meta)) + >>> asyncio.run(tq.async_put(data=output, metadata=batch_meta)) >>> >>> # Example 2: Initial data insertion without pre-existing metadata >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id! @@ -466,7 +506,7 @@ async def async_put( >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0) >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated}) >>> # This will create metadata in "insert" mode internally. - >>> metadata = asyncio.run(client.async_put(data=prompts_repeated_batch, partition_id=current_partition_id)) + >>> metadata = asyncio.run(tq.async_put(data=prompts_repeated_batch, partition_id=current_partition_id)) """ tq_client = _maybe_create_transferqueue_client() return await tq_client.async_put(data, metadata, partition_id) @@ -488,10 +528,13 @@ def set_custom_meta(metadata: BatchMeta) -> None: RuntimeError: If communication fails or controller returns error response Example: + >>> import transfer_queue as tq + >>> tq.init() + >>> >>> # Create batch with custom metadata - >>> batch_meta = client.get_meta(data_fields=["input_ids"], batch_size=4, ...) + >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=4, ...) >>> batch_meta.update_custom_meta({0: {"score": 0.9}, 1: {"score": 0.8}}) - >>> client.set_custom_meta(batch_meta) + >>> tq.set_custom_meta(batch_meta) """ tq_client = _maybe_create_transferqueue_client() return tq_client.set_custom_meta(metadata) @@ -517,10 +560,13 @@ async def async_set_custom_meta( RuntimeError: If communication fails or controller returns error response Example: + >>> import transfer_queue as tq + >>> tq.init() + >>> >>> # Create batch with custom metadata - >>> batch_meta = client.get_meta(data_fields=["input_ids"], batch_size=4, ...) + >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=4, ...) >>> batch_meta.update_custom_meta({0: {"score": 0.9}, 1: {"score": 0.8}}) - >>> asyncio.run(client.async_set_custom_meta(batch_meta)) + >>> asyncio.run(tq.async_set_custom_meta(batch_meta)) """ tq_client = _maybe_create_transferqueue_client() return await tq_client.async_set_custom_meta(metadata) From 5d51de3433346c172af8123fd7e66af0e7583abd Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 5 Feb 2026 10:31:30 +0800 Subject: [PATCH 09/14] fix typo Signed-off-by: 0oshowero0 --- transfer_queue/interface.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index f3bbb3ed..14108d38 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -155,7 +155,8 @@ def init(conf: Optional[DictConfig] = None) -> None: # create controller try: - sampler = final_conf.conroller.sampler + print(final_conf) + sampler = final_conf.controller.sampler if isinstance(sampler, BaseSampler): # user pass a pre-initialized sampler instance sampler = sampler From 1346318a665ae96ac0448d9a94d7026a610eb553 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 5 Feb 2026 13:52:31 +0800 Subject: [PATCH 10/14] simplify storage manager init Signed-off-by: 0oshowero0 --- tests/test_async_simple_storage_manager.py | 9 +- tests/test_kv_storage_manager.py | 10 +- transfer_queue/__init__.py | 2 + transfer_queue/client.py | 4 +- transfer_queue/config.yaml | 3 +- .../dataloader/streaming_dataset.py | 46 ++++++-- transfer_queue/interface.py | 33 ++++-- transfer_queue/storage/managers/base.py | 10 +- transfer_queue/storage/managers/factory.py | 13 ++- .../storage/managers/mooncake_manager.py | 5 +- .../managers/simple_backend_manager.py | 6 +- .../storage/managers/yuanrong_manager.py | 5 +- tutorial/02_metadata_concepts.py | 51 ++------- tutorial/03_understanding_controller.py | 101 ++++++------------ tutorial/04_custom_sampler.py | 77 +++++-------- tutorial/05_streaming_dataloader.py | 80 ++++++-------- 16 files changed, 200 insertions(+), 255 deletions(-) diff --git a/tests/test_async_simple_storage_manager.py b/tests/test_async_simple_storage_manager.py index 3949a0dc..a4990fe1 100644 --- a/tests/test_async_simple_storage_manager.py +++ b/tests/test_async_simple_storage_manager.py @@ -63,7 +63,6 @@ async def mock_async_storage_manager(): config = { "storage_unit_infos": storage_unit_infos, - "controller_info": controller_info, } # Mock the handshake process entirely to avoid ZMQ complexity @@ -200,7 +199,6 @@ async def test_async_storage_manager_mapping_functions(): config = { "storage_unit_infos": storage_unit_infos, - "controller_info": controller_info, } # Mock ZMQ operations @@ -230,7 +228,7 @@ async def test_async_storage_manager_mapping_functions(): mock_socket.recv_multipart = Mock(return_value=handshake_response.serialize()) # Create manager - manager = AsyncSimpleStorageManager(config) + manager = AsyncSimpleStorageManager(controller_info, config) # Test round-robin mapping for 3 storage units # global_index -> storage_unit mapping: 0->storage_0, 1->storage_1, 2->storage_2, @@ -266,7 +264,7 @@ async def test_async_storage_manager_error_handling(): } # Mock controller info - controller_infos = ZMQServerInfo( + controller_info = ZMQServerInfo( role=TransferQueueRole.CONTROLLER, id="controller_0", ip="127.0.0.1", @@ -275,7 +273,6 @@ async def test_async_storage_manager_error_handling(): config = { "storage_unit_infos": storage_unit_infos, - "controller_info": controller_infos, } # Mock ZMQ operations @@ -305,7 +302,7 @@ async def test_async_storage_manager_error_handling(): mock_socket.recv_multipart = Mock(return_value=handshake_response.serialize()) # Create manager - manager = AsyncSimpleStorageManager(config) + manager = AsyncSimpleStorageManager(controller_info, config) # Mock operations that raise exceptions manager._put_to_single_storage_unit = AsyncMock(side_effect=RuntimeError("Mock PUT error")) diff --git a/tests/test_kv_storage_manager.py b/tests/test_kv_storage_manager.py index 41296dd5..083e16d6 100644 --- a/tests/test_kv_storage_manager.py +++ b/tests/test_kv_storage_manager.py @@ -117,7 +117,7 @@ def test_merge_tensors_to_tensordict(mock_create, test_data): mock_client = MagicMock() mock_create.return_value = mock_client - manager = KVStorageManager(test_data["cfg"]) + manager = KVStorageManager(controller_info=MagicMock(), config=test_data["cfg"]) assert manager.storage_client is mock_client assert manager._multi_threads_executor is None @@ -296,9 +296,9 @@ def test_put_data_with_custom_meta_from_storage_client(mock_notify, test_data_fo mock_storage_client.put.return_value = mock_custom_meta # Create manager with mocked dependencies - config = {"controller_info": MagicMock(), "client_name": "MockClient"} + config = {"client_name": "MockClient"} with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=mock_storage_client): - manager = KVStorageManager(config) + manager = KVStorageManager(controller_info=MagicMock(), config=config) # Run put_data asyncio.run(manager.put_data(test_data_for_put_data["data"], test_data_for_put_data["metadata"])) @@ -348,7 +348,7 @@ def test_put_data_without_custom_meta(mock_notify, test_data_for_put_data): # Create manager with mocked dependencies config = {"controller_info": MagicMock(), "client_name": "MockClient"} with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=mock_storage_client): - manager = KVStorageManager(config) + manager = KVStorageManager(controller_info=MagicMock(), config=config) # Run put_data asyncio.run(manager.put_data(test_data_for_put_data["data"], test_data_for_put_data["metadata"])) @@ -371,7 +371,7 @@ def test_put_data_custom_meta_length_mismatch_raises_error(test_data_for_put_dat # Create manager with mocked dependencies config = {"controller_info": MagicMock(), "client_name": "MockClient"} with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=mock_storage_client): - manager = KVStorageManager(config) + manager = KVStorageManager(controller_info=MagicMock(), config=config) # Run put_data and expect ValueError with pytest.raises(ValueError) as exc_info: diff --git a/transfer_queue/__init__.py b/transfer_queue/__init__.py index 2ed875da..03f6ce9c 100644 --- a/transfer_queue/__init__.py +++ b/transfer_queue/__init__.py @@ -27,6 +27,7 @@ async_set_custom_meta, clear_partition, clear_samples, + close, get_data, get_meta, init, @@ -57,6 +58,7 @@ "async_set_custom_meta", "async_clear_samples", "async_clear_partition", + "close", ] + [ "TransferQueueClient", "StreamingDataset", diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 3c6f05f5..7b46bffb 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -95,7 +95,9 @@ def initialize_storage_manager( - storage_unit_infos: ZMQ server information about the storage units """ - self.storage_manager = TransferQueueStorageManagerFactory.create(manager_type, config) + self.storage_manager = TransferQueueStorageManagerFactory.create( + manager_type, controller_info=self._controller, config=config + ) # TODO (TQStorage): Provide a general dynamic socket function for both Client & Storage @huazhong. @staticmethod diff --git a/transfer_queue/config.yaml b/transfer_queue/config.yaml index 3af85466..9503afa4 100644 --- a/transfer_queue/config.yaml +++ b/transfer_queue/config.yaml @@ -1,9 +1,10 @@ # This is the default configuration of TransferQueue. Users may modify the default value # and use transfer_queue.init(conf) to overwrite the config entries. - controller: + # User-defined sampler. User can pass sampler instance to overwrite this string config. sampler: SequentialSampler + # Whether return an empty BatchMeta to prevent request blocking when no enough data is available polling_mode: False # ZMQ Server IP & Ports (automatically generated during init) zmq_info: null diff --git a/transfer_queue/dataloader/streaming_dataset.py b/transfer_queue/dataloader/streaming_dataset.py index eb2afe49..1de90c5b 100644 --- a/transfer_queue/dataloader/streaming_dataset.py +++ b/transfer_queue/dataloader/streaming_dataset.py @@ -17,8 +17,10 @@ import os import time import uuid -from typing import Any, Callable, Iterator +import warnings +from typing import Callable, Iterator +from omegaconf import DictConfig from tensordict import TensorDict from torch.utils.data import IterableDataset @@ -68,7 +70,7 @@ class StreamingDataset(IterableDataset): def __init__( self, - config: dict[str, Any], + config: DictConfig, batch_size: int, micro_batch_size: int, data_fields: list[str], @@ -82,8 +84,8 @@ def __init__( Args: config: Configuration dictionary containing: - - controller_info: ZMQServerInfo for the TransferQueueController - - storage_backend: Storage backend type (e.g., "AsyncSimpleStorageManager") + - controller.controller_info: ZMQServerInfo for the TransferQueueController + - backend.storage_backend: Storage backend type (e.g., "SimpleStorage") - Other backend-specific configuration batch_size: Batch size for data loading per iter. micro_batch_size: Number of samples per micro-batch. This is the batch size @@ -156,16 +158,44 @@ def _create_client(self): ValueError: If controller_info or storage_backend is missing or invalid. """ client_id = uuid.uuid4().hex[:8] - controller_info = self.config.get("controller_info", None) + + # TODO: DEPRECATE in future + controller_config = self.config.get("controller", None) + if controller_config: + controller_info = controller_config.get("zmq_info", None) + else: + controller_info = self.config.get("controller_info", None) + if controller_info: + warnings.warn( + "Config entry `controller_info` will be deprecated in 0.1.7, please " + "use `controller.zmq_info` instead.", + category=DeprecationWarning, + stacklevel=2, + ) + if not controller_info or not isinstance(controller_info, ZMQServerInfo): - raise ValueError("Invalid or missing controller_info in config") + raise ValueError("Invalid or missing controller.zmq_info in config") + + backend_config = self.config.get("backend", None) + if not backend_config: + storage_backend = self.config.get("storage_backend", None) + backend_config = self.config + if storage_backend: + warnings.warn( + "Config entry `storage_backend` will be deprecated in 0.1.7, please " + "use `backend.storage_backend` instead.", + category=DeprecationWarning, + stacklevel=2, + ) + else: + storage_backend = backend_config.get("storage_backend", None) + backend_config = self.config.backend[storage_backend] - storage_backend = self.config.get("storage_backend", None) if not storage_backend: raise ValueError("Missing storage_backend in config") self._tq_client = TransferQueueClient(client_id, controller_info) - self._tq_client.initialize_storage_manager(manager_type=storage_backend, config=self.config) + self._tq_client.initialize_storage_manager(manager_type=storage_backend, config=backend_config) def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]: """Iterate over the dataset, yielding batches of data. diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 14108d38..6b8c8920 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -81,11 +81,6 @@ def _maybe_create_transferqueue_storage(conf: DictConfig) -> DictConfig: _TRANSFER_QUEUE_STORAGE[f"TransferQueueStorageUnit#{storage_unit_rank}"] = storage_node logger.info(f"TransferQueueStorageUnit#{storage_unit_rank} has been created.") - # extract zmq info - storage_zmq_info = process_zmq_server_info(_TRANSFER_QUEUE_STORAGE) - backend_name = conf.backend.storage_backend - conf.backend[backend_name].zmq_info = storage_zmq_info - return conf @@ -97,7 +92,7 @@ def _init_from_existing() -> None: conf = None while conf is None: - remote_conf = ray.get(controller.get_config()) + remote_conf = ray.get(controller.get_config.remote()) if remote_conf is not None: _maybe_create_transferqueue_client(remote_conf) logger.info("TransferQueueClient initialized.") @@ -155,7 +150,6 @@ def init(conf: Optional[DictConfig] = None) -> None: # create controller try: - print(final_conf) sampler = final_conf.controller.sampler if isinstance(sampler, BaseSampler): # user pass a pre-initialized sampler instance @@ -189,6 +183,7 @@ def init(conf: Optional[DictConfig] = None) -> None: # storage the config into controller ray.get(controller.store_config.remote(final_conf)) + logger.info(f"TransferQueue config: {final_conf}") # create client _maybe_create_transferqueue_client(final_conf) @@ -623,3 +618,27 @@ async def async_clear_partition(partition_id: str): """ tq_client = _maybe_create_transferqueue_client() return await tq_client.async_clear_partition(partition_id) + + +def close(): + """Close the TransferQueue system.""" + global _TRANSFER_QUEUE_CLIENT + global _TRANSFER_QUEUE_STORAGE + if _TRANSFER_QUEUE_CLIENT: + _TRANSFER_QUEUE_CLIENT.close() + _TRANSFER_QUEUE_CLIENT = None + + try: + if _TRANSFER_QUEUE_STORAGE: + # only the process that do first-time init can clean the distributed storage + for storage in _TRANSFER_QUEUE_STORAGE.values(): + ray.kill(storage) + _TRANSFER_QUEUE_STORAGE = None + except Exception: + pass + + try: + controller = ray.get_actor("TransferQueueController") + ray.kill(controller) + except Exception: + pass diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index c48b5991..352b3adc 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -60,12 +60,10 @@ class TransferQueueStorageManager(ABC): """Base class for storage layer. It defines the interface for data operations and generally provides handshake & notification capabilities.""" - def __init__(self, config: DictConfig[str, Any]): + def __init__(self, controller_info: ZMQServerInfo, config: DictConfig): self.storage_manager_id = f"TQ_STORAGE_{uuid4().hex[:8]}" self.config = config - controller_info = config.get("controller_info") - assert controller_info is not None, "controller_info is required" - self.controller_info: ZMQServerInfo = controller_info + self.controller_info = controller_info self.data_status_update_socket: Optional[zmq.Socket[bytes]] = None self.controller_handshake_socket: Optional[zmq.Socket[bytes]] = None @@ -352,14 +350,14 @@ class KVStorageManager(TransferQueueStorageManager): It maps structured metadata (BatchMeta) to flat lists of keys and values for efficient KV operations. """ - def __init__(self, config: dict[str, Any]): + def __init__(self, controller_info: ZMQServerInfo, config: dict[str, Any]): """ Initialize the KVStorageManager with configuration. """ client_name = config.get("client_name", None) if client_name is None: raise ValueError("Missing client_name in config") - super().__init__(config) + super().__init__(controller_info, config) self.storage_client = StorageClientFactory.create(client_name, config) self._multi_threads_executor: Optional[ThreadPoolExecutor] = None # Register a cleanup function: automatically invoke shutdown when the instance is garbage collected. diff --git a/transfer_queue/storage/managers/factory.py b/transfer_queue/storage/managers/factory.py index 84738e1f..04d2cd03 100644 --- a/transfer_queue/storage/managers/factory.py +++ b/transfer_queue/storage/managers/factory.py @@ -17,6 +17,7 @@ from typing import Any from transfer_queue.storage.managers.base import TransferQueueStorageManager +from transfer_queue.utils.zmq_utils import ZMQServerInfo class TransferQueueStorageManagerFactory: @@ -40,26 +41,28 @@ def decorator(manager_cls: type[TransferQueueStorageManager]): return decorator @classmethod - def create(cls, manager_type: str, config: dict[str, Any]) -> TransferQueueStorageManager: + def create( + cls, manager_type: str, controller_info: ZMQServerInfo, config: dict[str, Any] + ) -> TransferQueueStorageManager: """Create and return a TransferQueueStorageManager instance.""" if manager_type not in cls._registry: if manager_type == "AsyncSimpleStorageManager": warnings.warn( - f"The manager_type {manager_type} is deprecated, use SimpleStorage instead.", + f"The manager_type {manager_type} will be deprecated in 0.1.7, please use SimpleStorage instead.", category=DeprecationWarning, stacklevel=2, ) manager_type = "SimpleStorage" elif manager_type == "MooncakeStorageManager": warnings.warn( - f"The manager_type {manager_type} is deprecated, use MooncakeStore instead.", + f"The manager_type {manager_type} will be deprecated in 0.1.7, please use MooncakeStore instead.", category=DeprecationWarning, stacklevel=2, ) manager_type = "MooncakeStore" elif manager_type == "YuanrongStorageManager": warnings.warn( - f"The manager_type {manager_type} is deprecated, use Yuanrong instead.", + f"The manager_type {manager_type} will be deprecated in 0.1.7, please use Yuanrong instead.", category=DeprecationWarning, stacklevel=2, ) @@ -68,4 +71,4 @@ def create(cls, manager_type: str, config: dict[str, Any]) -> TransferQueueStora raise ValueError( f"Unknown manager_type: {manager_type}. Supported managers include: {list(cls._registry.keys())}" ) - return cls._registry[manager_type](config) + return cls._registry[manager_type](controller_info, config) diff --git a/transfer_queue/storage/managers/mooncake_manager.py b/transfer_queue/storage/managers/mooncake_manager.py index e787ffbb..9f6f93a6 100644 --- a/transfer_queue/storage/managers/mooncake_manager.py +++ b/transfer_queue/storage/managers/mooncake_manager.py @@ -19,6 +19,7 @@ from transfer_queue.storage.managers.base import KVStorageManager from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory +from transfer_queue.utils.zmq_utils import ZMQServerInfo logger = logging.getLogger(__name__) logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) @@ -28,7 +29,7 @@ class MooncakeStorageManager(KVStorageManager): """Storage manager for MooncakeStorage backend.""" - def __init__(self, config: dict[str, Any]): + def __init__(self, controller_info: ZMQServerInfo, config: dict[str, Any]): # Required: Address of the HTTP metadata server (e.g., "localhost:8080") metadata_server = config.get("metadata_server", None) # Required: Address of the master server RPC endpoint (e.g., "localhost:8081") @@ -45,4 +46,4 @@ def __init__(self, config: dict[str, Any]): config["client_name"] = "MooncakeStorageClient" elif client_name != "MooncakeStorageClient": raise ValueError(f"Invalid 'client_name': {client_name} in config. Expecting 'MooncakeStorageClient'") - super().__init__(config) + super().__init__(controller_info, config) diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index d1211ecf..94c30f5e 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -58,8 +58,8 @@ class AsyncSimpleStorageManager(TransferQueueStorageManager): instances using ZMQ communication and dynamic socket management. """ - def __init__(self, config: DictConfig[str, Any]): - super().__init__(config) + def __init__(self, controller_info: ZMQServerInfo, config: DictConfig): + super().__init__(controller_info, config) self.config = config server_infos: ZMQServerInfo | dict[str, ZMQServerInfo] | None = config.get("zmq_info", None) @@ -68,7 +68,7 @@ def __init__(self, config: DictConfig[str, Any]): server_infos = config.get("storage_unit_infos", None) if server_infos is not None: warnings.warn( - "The config entry `storage_unit_infos` is deprecated, use `zmq_info` instead.", + "The config entry `storage_unit_infos` will be deprecated in 0.1.7, please use `zmq_info` instead.", category=DeprecationWarning, stacklevel=2, ) diff --git a/transfer_queue/storage/managers/yuanrong_manager.py b/transfer_queue/storage/managers/yuanrong_manager.py index 6935203c..54ac0942 100644 --- a/transfer_queue/storage/managers/yuanrong_manager.py +++ b/transfer_queue/storage/managers/yuanrong_manager.py @@ -19,6 +19,7 @@ from transfer_queue.storage.managers.base import KVStorageManager from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory +from transfer_queue.utils.zmq_utils import ZMQServerInfo logger = logging.getLogger(__name__) logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) @@ -34,7 +35,7 @@ class YuanrongStorageManager(KVStorageManager): """Storage manager for Yuanrong backend.""" - def __init__(self, config: dict[str, Any]): + def __init__(self, controller_info: ZMQServerInfo, config: dict[str, Any]): host = config.get("host", None) port = config.get("port", None) client_name = config.get("client_name", None) @@ -48,4 +49,4 @@ def __init__(self, config: dict[str, Any]): config["client_name"] = "YuanrongStorageClient" elif client_name != "YuanrongStorageClient": raise ValueError(f"Invalid 'client_name': {client_name} in config. Expecting 'YuanrongStorageClient'") - super().__init__(config) + super().__init__(controller_info, config) diff --git a/tutorial/02_metadata_concepts.py b/tutorial/02_metadata_concepts.py index 93e59ac7..ecff7ef9 100644 --- a/tutorial/02_metadata_concepts.py +++ b/tutorial/02_metadata_concepts.py @@ -38,19 +38,13 @@ import ray # noqa: E402 import torch # noqa: E402 -from omegaconf import OmegaConf # noqa: E402 from tensordict import TensorDict # noqa: E402 # Add the parent directory to the path parent_dir = Path(__file__).resolve().parent.parent sys.path.append(str(parent_dir)) -from transfer_queue import ( # noqa: E402 - SimpleStorageUnit, - TransferQueueClient, - TransferQueueController, - process_zmq_server_info, -) +import transfer_queue as tq # noqa: E402 from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta # noqa: E402 from transfer_queue.utils.enum_utils import ProductionStatus # noqa: E402 @@ -308,32 +302,8 @@ def demonstrate_real_workflow(): if not ray.is_initialized(): ray.init() - # Setup TransferQueue - config = OmegaConf.create( - { - "num_data_storage_units": 2, - } - ) - - storage_units = {} - for i in range(config["num_data_storage_units"]): - storage_units[i] = SimpleStorageUnit.remote(storage_unit_size=100) - - controller = TransferQueueController.remote() - controller_info = process_zmq_server_info(controller) - storage_unit_infos = process_zmq_server_info(storage_units) - - client = TransferQueueClient( - client_id="TutorialClient", - controller_info=controller_info, - ) - - tq_config = OmegaConf.create({}, flags={"allow_objects": True}) - tq_config.controller_info = controller_info - tq_config.storage_unit_infos = storage_unit_infos - config = OmegaConf.merge(tq_config, config) - - client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config) + # Initialize TransferQueue + tq.init() print("[Step 1] Putting data into TransferQueue...") input_ids = torch.randint(0, 1000, (8, 512)) @@ -348,7 +318,7 @@ def demonstrate_real_workflow(): ) partition_id = "demo_partition" - batch_meta = client.put(data=data_batch, partition_id=partition_id) + batch_meta = tq.put(data=data_batch, partition_id=partition_id) print(f"✓ Put {data_batch.batch_size[0]} samples into partition '{partition_id}', got BatchMeta back {batch_meta}.") print("[Step 2] [Optional] Setting sample-level custom_meta...") @@ -360,11 +330,11 @@ def demonstrate_real_workflow(): batch_meta.update_custom_meta(custom_meta) print(f"✓ Set custom_meta into BatchMeta: {batch_meta.get_all_custom_meta()}") - client.set_custom_meta(batch_meta) + tq.set_custom_meta(batch_meta) print("✓ Successful to store custom_meta into TQ controller. Now you can retrieve the custom_meta from anywhere.") print("[Step 3] Try to get metadata from TransferQueue from other places...") - batch_meta = client.get_meta( + batch_meta = tq.get_meta( data_fields=["input_ids", "attention_mask"], batch_size=8, partition_id=partition_id, @@ -383,7 +353,7 @@ def demonstrate_real_workflow(): print("✓ Selected 'input_ids' field only:") print(f" New field names: {selected_meta.field_names}") print(f" Samples still have same global indexes: {selected_meta.global_indexes}") - retrieved_data = client.get_data(selected_meta) + retrieved_data = tq.get_data(selected_meta) print(f" Retrieved data keys: {list(retrieved_data.keys())}") print("[Step 5] Select specific samples from the retrieved BatchMeta...") @@ -391,7 +361,7 @@ def demonstrate_real_workflow(): print("✓ Selected samples at indices [0, 2, 4, 6]:") print(f" New global indexes: {partial_meta.global_indexes}") print(f" Number of samples: {len(partial_meta)}") - retrieved_data = client.get_data(partial_meta) + retrieved_data = tq.get_data(partial_meta) print(f" Retrieved data samples: {retrieved_data}, all the data samples: {data_batch}") print("[Step 6] Demonstrate chunk operation...") @@ -399,12 +369,11 @@ def demonstrate_real_workflow(): print(f"✓ Chunked into {len(chunks)} parts:") for i, chunk in enumerate(chunks): print(f" Chunk {i}: {len(chunk)} samples, indexes={chunk.global_indexes}") - chunk_data = client.get_data(chunk) + chunk_data = tq.get_data(chunk) print(f" Chunk {i}: Retrieved chunk data: {chunk_data}") # Cleanup - client.clear_partition(partition_id=partition_id) - client.close() + tq.clear_partition(partition_id=partition_id) ray.shutdown() print("✓ Partition cleared and resources cleaned up") diff --git a/tutorial/03_understanding_controller.py b/tutorial/03_understanding_controller.py index 8b7e6d04..4ca426d4 100644 --- a/tutorial/03_understanding_controller.py +++ b/tutorial/03_understanding_controller.py @@ -36,59 +36,19 @@ import ray # noqa: E402 import torch # noqa: E402 -from omegaconf import OmegaConf # noqa: E402 from tensordict import TensorDict # noqa: E402 # Add the parent directory to the path parent_dir = Path(__file__).resolve().parent.parent sys.path.append(str(parent_dir)) -from transfer_queue import ( # noqa: E402 - SimpleStorageUnit, - TransferQueueClient, - TransferQueueController, - process_zmq_server_info, -) +import transfer_queue as tq # noqa: E402 # Configure Ray os.environ["RAY_DEDUP_LOGS"] = "0" os.environ["RAY_DEBUG"] = "1" -def setup_transfer_queue(): - """Setup TransferQueue components.""" - if not ray.is_initialized(): - ray.init() - - config = OmegaConf.create( - { - "num_data_storage_units": 2, - } - ) - - storage_units = {} - for i in range(config["num_data_storage_units"]): - storage_units[i] = SimpleStorageUnit.remote(storage_unit_size=100) - - controller = TransferQueueController.remote() - controller_info = process_zmq_server_info(controller) - storage_unit_infos = process_zmq_server_info(storage_units) - - client = TransferQueueClient( - client_id="TutorialClient", - controller_info=controller_info, - ) - - tq_config = OmegaConf.create({}, flags={"allow_objects": True}) - tq_config.controller_info = controller_info - tq_config.storage_unit_infos = storage_unit_infos - config = OmegaConf.merge(tq_config, config) - - client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config) - - return controller, storage_units, client - - def demonstrate_partition_isolation(): """Feature 1: Different partitions are isolated - data doesn't interfere.""" print("=" * 80) @@ -97,7 +57,10 @@ def demonstrate_partition_isolation(): print("\nDifferent partitions are completely isolated - data doesn't interfere between partitions") - controller, storage_units, client = setup_transfer_queue() + if not ray.is_initialized(): + ray.init(namespace="TransferQueueTutorial") + + tq.init() # Partition 1: Training data print("\n[Partition 1] Putting training data...") @@ -108,7 +71,7 @@ def demonstrate_partition_isolation(): }, batch_size=2, ) - client.put(data=train_data, partition_id="train") + tq.put(data=train_data, partition_id="train") print(" ✓ Training data added to 'train' partition") # Partition 2: Validation data @@ -120,25 +83,23 @@ def demonstrate_partition_isolation(): }, batch_size=2, ) - client.put(data=val_data, partition_id="val") + tq.put(data=val_data, partition_id="val") print(" ✓ Validation data added to 'val' partition") # Get from train partition print("\n[Retrieving from 'train' partition]") - train_meta = client.get_meta( + train_meta = tq.get_meta( data_fields=["input_ids", "labels"], batch_size=2, partition_id="train", task_name="train_task" ) - retrieved_train_data = client.get_data(train_meta) + retrieved_train_data = tq.get_data(train_meta) print(f" ✓ Got BatchMeta={train_meta} from partition 'train'") print(f" ✓ Retrieved Data: input_ids={retrieved_train_data['input_ids']}, labels={retrieved_train_data['labels']}") # Get from val partition print("\n[Retrieving from 'val' partition]") - val_meta = client.get_meta( - data_fields=["input_ids", "labels"], batch_size=2, partition_id="val", task_name="val_task" - ) - retrieved_val_data = client.get_data(val_meta) + val_meta = tq.get_meta(data_fields=["input_ids", "labels"], batch_size=2, partition_id="val", task_name="val_task") + retrieved_val_data = tq.get_data(val_meta) print(f" ✓ Got BatchMeta={val_meta} from partition 'val'") print(f" ✓ Retrieved Data: input_ids={retrieved_val_data['input_ids']}, labels={retrieved_val_data['labels']}") @@ -146,9 +107,9 @@ def demonstrate_partition_isolation(): print(" ✓ Data isolation: 'train' and 'val' partitions are completely independent") # Cleanup - client.clear_partition(partition_id="train") - client.clear_partition(partition_id="val") - client.close() + tq.clear_partition(partition_id="train") + tq.clear_partition(partition_id="val") + tq.close() ray.shutdown() @@ -160,7 +121,10 @@ def demonstrate_dynamic_expansion(): print("\nPartitions dynamically expand to accommodate new data (rows and columns)") - controller, storage_units, client = setup_transfer_queue() + if not ray.is_initialized(): + ray.init(namespace="TransferQueueTutorial") + + tq.init() # Add first batch with 2 samples, 2 fields print("\n[Step 1] Adding initial data (2 samples, 2 fields)...") @@ -171,7 +135,7 @@ def demonstrate_dynamic_expansion(): }, batch_size=2, ) - meta1 = client.put(data=data1, partition_id="dynamic") + meta1 = tq.put(data=data1, partition_id="dynamic") print(" ✓ Added 2 samples") print(f" ✓ Got BatchMeta: {meta1} samples") @@ -184,9 +148,9 @@ def demonstrate_dynamic_expansion(): }, batch_size=3, ) - meta2 = client.put(data=data2, partition_id="dynamic") + meta2 = tq.put(data=data2, partition_id="dynamic") - all_meta = client.get_meta( + all_meta = tq.get_meta( data_fields=["field1", "field2"], batch_size=5, partition_id="dynamic", task_name="dynamic_task" ) print(" ✓ Added 3 more samples (total: 5)") @@ -201,7 +165,7 @@ def demonstrate_dynamic_expansion(): }, batch_size=2, ) - meta3 = client.put(data=data3, metadata=meta1) + meta3 = tq.put(data=data3, metadata=meta1) print(" ✓ Added 2 samples with new field 'field3'") print(f" ✓ Got BatchMeta: {meta3} for newly put data with new field") @@ -210,8 +174,8 @@ def demonstrate_dynamic_expansion(): print(" ✓ Columns auto-expand: Can add new fields anytime") # Cleanup - client.clear_partition(partition_id="dynamic") - client.close() + tq.clear_partition(partition_id="dynamic") + tq.close() ray.shutdown() @@ -221,7 +185,10 @@ def demonstrate_default_consumption_sample_strategy(): print("Feature 3: Default Sampling Strategy for Controller - No Duplicate, Sequential Samples") print("=" * 80) - controller, storage_units, client = setup_transfer_queue() + if not ray.is_initialized(): + ray.init(namespace="TransferQueueTutorial") + + tq.init() # Add 6 samples print("\n[Setup] Adding 6 samples...") @@ -231,22 +198,22 @@ def demonstrate_default_consumption_sample_strategy(): }, batch_size=6, ) - client.put(data=all_data, partition_id="sampling") + tq.put(data=all_data, partition_id="sampling") print(" ✓ 6 samples added") # First get - should get samples 0,1,2 print("\n[Task A, Get 1] Requesting 3 samples...") - meta1 = client.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A") + meta1 = tq.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A") print(f" ✓ Got samples: {meta1.global_indexes}") # Second get - should get samples 3,4,5 (no duplicates!) print("\n[Task A, Get 2] Requesting 3 more samples...") - meta2 = client.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A") + meta2 = tq.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A") print(f" ✓ Got samples: {meta2.global_indexes}") # Third get - should get samples 0,1 print("\n[Task B, Get 1] Requesting 2 samples...") - meta3 = client.get_meta(data_fields=["data"], batch_size=2, partition_id="sampling", task_name="B") + meta3 = tq.get_meta(data_fields=["data"], batch_size=2, partition_id="sampling", task_name="B") print(f" ✓ Got samples: {meta3.global_indexes}") print("\n[Verification]") @@ -257,8 +224,8 @@ def demonstrate_default_consumption_sample_strategy(): print(" ✓ Third get (Task B): samples 0,1") # Cleanup - client.clear_partition(partition_id="sampling") - client.close() + tq.clear_partition(partition_id="sampling") + tq.close() ray.shutdown() diff --git a/tutorial/04_custom_sampler.py b/tutorial/04_custom_sampler.py index 7bf13cd1..c35b4e0a 100644 --- a/tutorial/04_custom_sampler.py +++ b/tutorial/04_custom_sampler.py @@ -49,12 +49,7 @@ parent_dir = Path(__file__).resolve().parent.parent sys.path.append(str(parent_dir)) -from transfer_queue import ( # noqa: E402 - SimpleStorageUnit, - TransferQueueClient, - TransferQueueController, - process_zmq_server_info, -) +import transfer_queue as tq # noqa: E402 from transfer_queue.sampler import BaseSampler # noqa: E402 @@ -171,36 +166,14 @@ def sample( def setup_transfer_queue_with_sampler(sampler): """Setup TransferQueue with custom sampler.""" if not ray.is_initialized(): - ray.init() + ray.init(namespace="TransferQueueTutorial") config = OmegaConf.create( - { - "global_batch_size": 8, - "num_data_storage_units": 2, - } - ) - - storage_units = {} - for i in range(2): - storage_units[i] = SimpleStorageUnit.remote(storage_unit_size=100) - - controller = TransferQueueController.remote(sampler=sampler) - controller_info = process_zmq_server_info(controller) - storage_unit_infos = process_zmq_server_info(storage_units) - - client = TransferQueueClient( - client_id="TutorialClient", - controller_info=controller_info, + {"controller": {"sampler": sampler}, "backend": {"SimpleStorage": {"num_data_storage_units": 2}}}, + flags={"allow_objects": True}, ) - tq_config = OmegaConf.create({}, flags={"allow_objects": True}) - tq_config.controller_info = controller_info - tq_config.storage_unit_infos = storage_unit_infos - config = OmegaConf.merge(tq_config, config) - - client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config) - - return controller, storage_units, client + tq.init(config) def demonstrate_random_sampler_with_replacement(): @@ -211,7 +184,7 @@ def demonstrate_random_sampler_with_replacement(): print("\nSetup TransferQueue with RandomSamplerWithReplacement...") sampler = RandomSamplerWithReplacement() - controller, storage_units, client = setup_transfer_queue_with_sampler(sampler) + setup_transfer_queue_with_sampler(sampler) # Add 5 samples print("\n[Step 1] Adding 5 samples...") @@ -221,22 +194,22 @@ def demonstrate_random_sampler_with_replacement(): }, batch_size=5, ) - client.put(data=data, partition_id="test") + tq.put(data=data, partition_id="test") print(" ✓ 5 samples added") # Get batch 1 (should get 2 random samples) print("\n[Step 2] Get batch 1 (2 samples)...") - meta1 = client.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task") + meta1 = tq.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task") print(f" ✓ Got samples: {meta1.global_indexes}") # Get batch 2 (should get 1 random sample with replacement - may have duplicate with previous batch!) print("\n[Step 3] Get batch 2 (1 sample)...") - meta2 = client.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task") + meta2 = tq.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task") print(f" ✓ Got samples: {meta2.global_indexes}") # Get batch 3 (should get 2 random samples with replacement - may have duplicate with previous batches!) print("\n[Step 4] Get batch 3 (2 samples)...") - meta3 = client.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task") + meta3 = tq.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task") print(f" ✓ Got samples: {meta3.global_indexes}") print("\n[Verification]") @@ -246,8 +219,8 @@ def demonstrate_random_sampler_with_replacement(): print(f" ✓ All sampled: {all_sampled}") # Cleanup - client.clear_partition(partition_id="test") - client.close() + tq.clear_partition(partition_id="test") + tq.close() ray.shutdown() @@ -259,7 +232,7 @@ def demonstrate_random_sampler_without_replacement(): print("\nSetup TransferQueue with RandomSamplerWithoutReplacement...") sampler = RandomSamplerWithoutReplacement() - controller, storage_units, client = setup_transfer_queue_with_sampler(sampler) + setup_transfer_queue_with_sampler(sampler) # Add 6 samples print("\n[Step 1] Adding 6 samples...") @@ -269,22 +242,22 @@ def demonstrate_random_sampler_without_replacement(): }, batch_size=6, ) - client.put(data=data, partition_id="test") + tq.put(data=data, partition_id="test") print(" ✓ 6 samples added") # Get batch 1 (should get 3 random samples without replacement) print("\n[Step 2] Get batch 1 (3 samples)...") - meta1 = client.get_meta(data_fields=["input"], batch_size=3, partition_id="test", task_name="demo_task") + meta1 = tq.get_meta(data_fields=["input"], batch_size=3, partition_id="test", task_name="demo_task") print(f" ✓ Got samples: {meta1.global_indexes}") # Get batch 2 (should randomly get 1 sample that are different from previous batch) print("\n[Step 3] Get batch 2 (1 samples)...") - meta2 = client.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task") + meta2 = tq.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task") print(f" ✓ Got samples: {meta2.global_indexes}") # Get batch 3 (should randomly get 2 samples that are different from previous batch) print("\n[Step 4] Get batch 3 (2 samples)...") - meta3 = client.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task") + meta3 = tq.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task") print(f" ✓ Got samples: {meta3.global_indexes}") print("\n[Verification]") @@ -294,8 +267,8 @@ def demonstrate_random_sampler_without_replacement(): print(f" ✓ Batch 3: {meta3.global_indexes} (none left)") # Cleanup - client.clear_partition(partition_id="test") - client.close() + tq.clear_partition(partition_id="test") + tq.close() ray.shutdown() @@ -307,7 +280,7 @@ def demonstrate_priority_sampler(): print("\nSetup TransferQueue with PrioritySampler...") sampler = PrioritySampler() - controller, storage_units, client = setup_transfer_queue_with_sampler(sampler) + setup_transfer_queue_with_sampler(sampler) # Add 8 samples print("\n[Step 1] Adding 8 samples...") @@ -317,7 +290,7 @@ def demonstrate_priority_sampler(): }, batch_size=8, ) - client.put(data=data, partition_id="test") + tq.put(data=data, partition_id="test") print(" ✓ 8 samples added") time.sleep(1) @@ -330,7 +303,7 @@ def demonstrate_priority_sampler(): print(f"Priority scores: {priority_scores}") # Get batch using priority sampling - meta1 = client.get_meta( + meta1 = tq.get_meta( data_fields=["input"], batch_size=1, partition_id="test", @@ -342,7 +315,7 @@ def demonstrate_priority_sampler(): # Get another batch print("\n[Step 3] Get another batch (2 samples)...") - meta2 = client.get_meta( + meta2 = tq.get_meta( data_fields=["input"], batch_size=2, partition_id="test", @@ -358,8 +331,8 @@ def demonstrate_priority_sampler(): print(f" ✓ Batch 2 high-priority indices: {[i for i in meta2.global_indexes if priority_scores[i] >= 0.1]}") # Cleanup - client.clear_partition(partition_id="test") - client.close() + tq.clear_partition(partition_id="test") + tq.close() ray.shutdown() diff --git a/tutorial/05_streaming_dataloader.py b/tutorial/05_streaming_dataloader.py index 916be001..4d92aa26 100644 --- a/tutorial/05_streaming_dataloader.py +++ b/tutorial/05_streaming_dataloader.py @@ -57,39 +57,25 @@ import ray # noqa: E402 import torch # noqa: E402 -from omegaconf import DictConfig, OmegaConf # noqa: E402 +from omegaconf import OmegaConf # noqa: E402 from tensordict import TensorDict # noqa: E402 # Add the parent directory to the path for imports parent_dir = Path(__file__).resolve().parent.parent sys.path.append(str(parent_dir)) - +import transfer_queue as tq # noqa: E402 from transfer_queue import ( # noqa: E402 RankAwareSampler, - SimpleStorageUnit, StreamingDataLoader, StreamingDataset, - TransferQueueClient, - TransferQueueController, - process_zmq_server_info, ) def setup_transfer_queue(): """Setup TransferQueue components.""" if not ray.is_initialized(): - ray.init() - - config = OmegaConf.create( - { - "num_data_storage_units": 2, - } - ) - - storage_units = {} - for i in range(config["num_data_storage_units"]): - storage_units[i] = SimpleStorageUnit.remote(storage_unit_size=100) + ray.init(namespace="TransferQueueTutorial") print("[Setup]: Setup TransferQueue components") print( @@ -101,26 +87,23 @@ def setup_transfer_queue(): "TransferQueueController. In polling_mode, the controller will return empty BatchMeta when " "available data cannot meet the consumption requirements. User side need to retry later." ) - controller = TransferQueueController.remote( - sampler=RankAwareSampler, # RankAwareSampler enables consistent sampling for each DP rank - polling_mode=True, # Enable polling mode for streaming data retrieval - ) - - controller_info = process_zmq_server_info(controller) - storage_unit_infos = process_zmq_server_info(storage_units) - # Build the complete configuration - tq_config = OmegaConf.create({}, flags={"allow_objects": True}) - tq_config.controller_info = controller_info - tq_config.storage_unit_infos = storage_unit_infos - config.storage_backend = "AsyncSimpleStorageManager" - config = OmegaConf.merge(tq_config, config) + config = OmegaConf.create( + { + "controller": { + "sampler": RankAwareSampler, # RankAwareSampler enables consistent sampling for each DP rank + "polling_mode": True, # Enable polling mode for streaming data retrieval + }, + "backend": {"SimpleStorage": {"num_data_storage_units": 2}}, + }, + flags={"allow_objects": True}, + ) - return controller, storage_units, config + tq.init(config) @ray.remote(num_cpus=0.1) -def generate_worker(rank_id: int, config: DictConfig, num_samples: int = 20): +def generate_worker(rank_id: int, num_samples: int = 20): """ Generate actor that produces training samples. @@ -129,7 +112,6 @@ def generate_worker(rank_id: int, config: DictConfig, num_samples: int = 20): Args: rank_id: Unique identifier for this generator (used for sample indexing) - config: TransferQueue configuration num_samples: Number of samples to generate Note: @@ -137,13 +119,9 @@ def generate_worker(rank_id: int, config: DictConfig, num_samples: int = 20): This ensures global uniqueness across all generator actors. """ # Create a client for interacting with TransferQueue - client = TransferQueueClient( - client_id=f"gen_worker_{rank_id}", - controller_info=config.controller_info, - ) - # Initialize the storage manager for this client - client.initialize_storage_manager(manager_type=config.storage_backend, config=config) + # Need to call tq.init() in each process + tq.init() # Generate and put samples into the queue for i in range(num_samples): @@ -159,7 +137,7 @@ def generate_worker(rank_id: int, config: DictConfig, num_samples: int = 20): print(f"[Generate Worker@{rank_id}]: Putting sample {seq_id} into TransferQueue") # Put data into the specified partition - client.put(data, partition_id="train") + tq.put(data, partition_id="train") print(f"[Generate Worker@{rank_id}]: Complete putting samples into TransferQueue") @@ -168,7 +146,6 @@ def generate_worker(rank_id: int, config: DictConfig, num_samples: int = 20): def update_worker( rank_id: int, dp_rank: int, - config: DictConfig, max_steps: int = 5, ): """ @@ -182,7 +159,6 @@ def update_worker( rank_id: Global rank identifier for logging and display purposes dp_rank: Data parallel rank ID that this worker belongs to The same Ranks receive the same data samples - config: TransferQueue configuration max_steps: Maximum number of batches to consume Returns: @@ -200,8 +176,15 @@ def update_worker( - batch_meta: Metadata for TransferQueue coordination (contains global_indexes) """ + # Need to call tq.init() in each process + tq.init() + # Step 1: Create StreamingDataset # This dataset integrates with TransferQueue and handles batch retrieval + + controller = ray.get_actor("TransferQueueController") + config = ray.get(controller.get_config.remote()) + dataset = StreamingDataset( config=config, batch_size=2, @@ -253,7 +236,7 @@ def update_worker( } -def start_all_generate_actors(config): +def start_all_generate_actors(): """ Launch generate_actors for producing training samples. """ @@ -261,12 +244,12 @@ def start_all_generate_actors(config): handlers = [] for i in range(num_workers): - handlers.append(generate_worker.remote(rank_id=i, config=config, num_samples=20)) + handlers.append(generate_worker.remote(rank_id=i, num_samples=20)) return handlers -def start_all_update_actors(config): +def start_all_update_actors(): """ Launch update_actors for consuming training samples. """ @@ -285,7 +268,6 @@ def start_all_update_actors(config): update_worker.remote( rank_id=rank_ids[i], dp_rank=dp_rank[i], - config=config, ) ) @@ -331,15 +313,15 @@ def main(): "global_batch_size to make sure consumers can accurately determine consumption status even before " "producers have generated the samples." ) - controller, storage_units, config = setup_transfer_queue() + setup_transfer_queue() # Step 2: Launch data generation actors print("\n[Phase 2] Starting data generation...") - generate_worker_handlers = start_all_generate_actors(config) + generate_worker_handlers = start_all_generate_actors() # Step 3: Launch data consumption actors print("\n[Phase 3] Starting data consumption...") - update_worker_handlers = start_all_update_actors(config) + update_worker_handlers = start_all_update_actors() # Wait for completion print("\n[Phase 4] Waiting for actors to complete...") From fecb07d6f4f090a44f5fb629102017377fbe5c9a Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 5 Feb 2026 14:09:50 +0800 Subject: [PATCH 11/14] fix Signed-off-by: 0oshowero0 --- transfer_queue/interface.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 6b8c8920..02a3858b 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -81,6 +81,10 @@ def _maybe_create_transferqueue_storage(conf: DictConfig) -> DictConfig: _TRANSFER_QUEUE_STORAGE[f"TransferQueueStorageUnit#{storage_unit_rank}"] = storage_node logger.info(f"TransferQueueStorageUnit#{storage_unit_rank} has been created.") + storage_zmq_info = process_zmq_server_info(_TRANSFER_QUEUE_STORAGE) + backend_name = conf.backend.storage_backend + conf.backend[backend_name].zmq_info = storage_zmq_info + return conf From 9217ca3f2aaf8a9f975da8a5734160ac97fdf51f Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 5 Feb 2026 15:29:37 +0800 Subject: [PATCH 12/14] fix minor comments Signed-off-by: 0oshowero0 --- transfer_queue/__init__.py | 1 - transfer_queue/controller.py | 6 +++--- transfer_queue/interface.py | 4 ++-- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/transfer_queue/__init__.py b/transfer_queue/__init__.py index 03f6ce9c..592ef0d2 100644 --- a/transfer_queue/__init__.py +++ b/transfer_queue/__init__.py @@ -51,7 +51,6 @@ "set_custom_meta", "clear_samples", "clear_partition", - "async_clear_samples", "async_get_meta", "async_get_data", "async_put", diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index eb1d49c5..763f3f49 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -1776,12 +1776,12 @@ def get_zmq_server_info(self) -> ZMQServerInfo: return self.zmq_server_info def store_config(self, conf: DictConfig) -> None: - """Storage the global config of TransferQueue.""" - self.tq_conf = conf + """Store the global config of TransferQueue.""" + self.tq_config = conf def get_config(self) -> DictConfig: """Retrieve the global config of TransferQueue.""" - return self.tq_conf + return self.tq_config def register_sampler( self, diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 02a3858b..aed5f2fc 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -46,7 +46,7 @@ def _maybe_create_transferqueue_client( global _TRANSFER_QUEUE_CLIENT if _TRANSFER_QUEUE_CLIENT is None: if conf is None: - raise ValueError("Missing config for initialing TransferQueueClient!") + raise ValueError("Missing config for initializing TransferQueueClient!") pid = os.getpid() _TRANSFER_QUEUE_CLIENT = TransferQueueClient( client_id=f"TransferQueueClient_{pid}", controller_info=conf.controller.zmq_info @@ -185,7 +185,7 @@ def init(conf: Optional[DictConfig] = None) -> None: # create distributed storage backends final_conf = _maybe_create_transferqueue_storage(final_conf) - # storage the config into controller + # store the config into controller ray.get(controller.store_config.remote(final_conf)) logger.info(f"TransferQueue config: {final_conf}") From 23ab1aee18210cac7eeeb5a06c58ba12c818de1c Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 5 Feb 2026 16:10:41 +0800 Subject: [PATCH 13/14] fix comments Signed-off-by: 0oshowero0 --- tests/test_async_simple_storage_manager.py | 6 +++--- tests/test_client.py | 4 ++-- transfer_queue/client.py | 3 +-- transfer_queue/interface.py | 3 ++- transfer_queue/storage/managers/simple_backend_manager.py | 2 +- tutorial/01_core_components.py | 1 + tutorial/02_metadata_concepts.py | 1 + 7 files changed, 11 insertions(+), 9 deletions(-) diff --git a/tests/test_async_simple_storage_manager.py b/tests/test_async_simple_storage_manager.py index a4990fe1..5187e169 100644 --- a/tests/test_async_simple_storage_manager.py +++ b/tests/test_async_simple_storage_manager.py @@ -62,7 +62,7 @@ async def mock_async_storage_manager(): ) config = { - "storage_unit_infos": storage_unit_infos, + "zmq_info": storage_unit_infos, } # Mock the handshake process entirely to avoid ZMQ complexity @@ -198,7 +198,7 @@ async def test_async_storage_manager_mapping_functions(): ) config = { - "storage_unit_infos": storage_unit_infos, + "zmq_info": storage_unit_infos, } # Mock ZMQ operations @@ -272,7 +272,7 @@ async def test_async_storage_manager_error_handling(): ) config = { - "storage_unit_infos": storage_unit_infos, + "zmq_info": storage_unit_infos, } # Mock ZMQ operations diff --git a/tests/test_client.py b/tests/test_client.py index 7a8037a1..42cd63bc 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -318,7 +318,7 @@ def client_setup(mock_controller, mock_storage): ): config = { "controller_info": mock_controller.zmq_server_info, - "storage_unit_infos": {mock_storage.storage_id: mock_storage.zmq_server_info}, + "zmq_info": {mock_storage.storage_id: mock_storage.zmq_server_info}, } client.initialize_storage_manager(manager_type="SimpleStorage", config=config) @@ -411,7 +411,7 @@ def test_single_controller_multiple_storages(): ): config = { "controller_info": controller.zmq_server_info, - "storage_unit_infos": {s.storage_id: s.zmq_server_info for s in storages}, + "zmq_info": {s.storage_id: s.zmq_server_info for s in storages}, } client.initialize_storage_manager(manager_type="SimpleStorage", config=config) diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 7b46bffb..cd9c6e84 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -91,8 +91,7 @@ def initialize_storage_manager( AsyncSimpleStorageManager, KVStorageManager (under development), etc. config: Configuration dictionary for the storage manager. For AsyncSimpleStorageManager, must contain the following required keys: - - controller_info: ZMQ server information about the controller - - storage_unit_infos: ZMQ server information about the storage units + - zmq_info: ZMQ server information about the storage units """ self.storage_manager = TransferQueueStorageManagerFactory.create( diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index aed5f2fc..37ebdb23 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -53,7 +53,6 @@ def _maybe_create_transferqueue_client( ) backend_name = conf.backend.storage_backend - conf.backend[backend_name].controller_info = conf.controller.zmq_info _TRANSFER_QUEUE_CLIENT.initialize_storage_manager(manager_type=backend_name, config=conf.backend[backend_name]) @@ -141,6 +140,8 @@ def init(conf: Optional[DictConfig] = None) -> None: _init_from_existing() except ValueError: logger.info("No TransferQueueController found. Starting first-time initialization...") + else: + return # First-time initialize TransferQueue diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 94c30f5e..f1420690 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -288,7 +288,7 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict: metadata, self.global_index_storage_unit_mapping, self.global_index_local_index_mapping ) - # retrive data + # retrieve data tasks = [ self._get_from_single_storage_unit(meta_group, target_storage_unit=storage_id) for storage_id, meta_group in storage_meta_groups.items() diff --git a/tutorial/01_core_components.py b/tutorial/01_core_components.py index 6d6f32ba..c7849e86 100644 --- a/tutorial/01_core_components.py +++ b/tutorial/01_core_components.py @@ -190,6 +190,7 @@ def main(): # Cleanup ray.shutdown() + tq.close() print("\n✓ Cleanup complete") except Exception as e: diff --git a/tutorial/02_metadata_concepts.py b/tutorial/02_metadata_concepts.py index ecff7ef9..d864db39 100644 --- a/tutorial/02_metadata_concepts.py +++ b/tutorial/02_metadata_concepts.py @@ -374,6 +374,7 @@ def demonstrate_real_workflow(): # Cleanup tq.clear_partition(partition_id=partition_id) + tq.close() ray.shutdown() print("✓ Partition cleared and resources cleaned up") From 4f0d3f80aa3718a6f47480aad6751420c4ac685a Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 5 Feb 2026 16:21:08 +0800 Subject: [PATCH 14/14] fix Signed-off-by: 0oshowero0 --- tutorial/01_core_components.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tutorial/01_core_components.py b/tutorial/01_core_components.py index c7849e86..59159a6d 100644 --- a/tutorial/01_core_components.py +++ b/tutorial/01_core_components.py @@ -125,16 +125,16 @@ def demonstrate_storage_backend_options(): print("=" * 80) print("TransferQueue supports multiple storage backends:") - print("1. SimpleStorageUnit (default)") + print("1. SimpleStorage (default)") print(" - In-memory storage, fast and simple") print(" - Leveraging ZMQ for communication, with zero-copy serialization and transfer") print(" - No extra dependencies, good for development and testing") - print("2. YuanrongStorage") + print("2. Yuanrong") print(" - Ascend native distributed storage solution") print(" - Hierarchical storage interfaces including HBM/DRAM/SSD") - print("3. MoonCakeStore (on the way)") + print("3. MooncakeStore (on the way)") print(" - Support multiple transmission protocols") print(" - RDMA between DRAM and HBM") @@ -189,8 +189,8 @@ def main(): print("3. You can swap out different storage backends easily") # Cleanup - ray.shutdown() tq.close() + ray.shutdown() print("\n✓ Cleanup complete") except Exception as e: