diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index d51ffeeb..14339bb8 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -37,11 +37,6 @@ jobs: flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - - name: Test with pytest (TQ_ZERO_COPY_SERIALIZATION=False) + - name: Test with pytest run: | - pytest - - name: Test with pytest (TQ_ZERO_COPY_SERIALIZATION=True) - run: | - ray stop --force - export TQ_ZERO_COPY_SERIALIZATION=True pytest \ No newline at end of file diff --git a/scripts/put_benchmark.py b/scripts/put_benchmark.py index f1c00dd2..96bad6ee 100644 --- a/scripts/put_benchmark.py +++ b/scripts/put_benchmark.py @@ -7,6 +7,7 @@ import sys import time from pathlib import Path + import numpy as np import ray import torch @@ -14,17 +15,16 @@ from tensordict import TensorDict from tensordict.utils import LinkedList - parent_dir = Path(__file__).resolve().parent.parent.parent sys.path.append(str(parent_dir)) -from transfer_queue import ( +from transfer_queue import ( # noqa: E402 AsyncTransferQueueClient, SimpleStorageUnit, TransferQueueController, process_zmq_server_info, ) -from transfer_queue.utils.utils import get_placement_group +from transfer_queue.utils.utils import get_placement_group # noqa: E402 logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -33,48 +33,13 @@ # Configuration Map # ========================================================= CONFIG_MAP = { - "debug": { - "global_batch_size": 32, - "seq_length": 128, - "field_num": 2, - "desc": "Debug (~32KB)" - }, - "tiny": { - "global_batch_size": 64, - "seq_length": 1024, - "field_num": 4, - "desc": "Tiny (~1MB)" - }, - "small": { - "global_batch_size": 512, - "seq_length": 12800, - "field_num": 4, - "desc": "Small 
(~100MB)" - }, - "medium": { - "global_batch_size": 1024, - "seq_length": 65536, - "field_num": 4, - "desc": "Medium (~1GB)" - }, - "large": { - "global_batch_size": 2048, - "seq_length": 128000, - "field_num": 5, - "desc": "Large (~5GB)" - }, - "xlarge": { - "global_batch_size": 4096, - "seq_length": 128000, - "field_num": 5, - "desc": "X-Large (~10GB)" - }, - "huge": { - "global_batch_size": 4096, - "seq_length": 128000, - "field_num": 10, - "desc": "Huge (~20GB)" - } + "debug": {"global_batch_size": 32, "seq_length": 128, "field_num": 2, "desc": "Debug (~32KB)"}, + "tiny": {"global_batch_size": 64, "seq_length": 1024, "field_num": 4, "desc": "Tiny (~1MB)"}, + "small": {"global_batch_size": 512, "seq_length": 12800, "field_num": 4, "desc": "Small (~100MB)"}, + "medium": {"global_batch_size": 1024, "seq_length": 65536, "field_num": 4, "desc": "Medium (~1GB)"}, + "large": {"global_batch_size": 2048, "seq_length": 128000, "field_num": 5, "desc": "Large (~5GB)"}, + "xlarge": {"global_batch_size": 4096, "seq_length": 128000, "field_num": 5, "desc": "X-Large (~10GB)"}, + "huge": {"global_batch_size": 4096, "seq_length": 128000, "field_num": 10, "desc": "Huge (~20GB)"}, } @@ -89,7 +54,7 @@ def calculate_stats(data: list) -> dict: "mean": float(np.mean(data)), "max": float(np.max(data)), "min": float(np.min(data)), - "p99": float(np.percentile(data, 99)) + "p99": float(np.percentile(data, 99)), } @@ -119,7 +84,7 @@ def _generate_nested_tensor(batch_size, total_elements, dtype): # Use Dirichlet distribution to generate random proportions summing to 1 proportions = np.random.dirichlet(np.ones(batch_size)) lengths = (proportions * total_elements).astype(int) - + # Fix rounding errors to ensure exact total element count diff = total_elements - lengths.sum() if diff != 0: @@ -127,10 +92,10 @@ def _generate_nested_tensor(batch_size, total_elements, dtype): indices = np.argsort(lengths)[::-1] for i in range(abs(diff)): lengths[indices[i % batch_size]] += 1 if diff > 0 else -1 - 
+ # Ensure each length is at least 1 lengths = np.maximum(lengths, 1) - + # Generate tensors with different lengths tensors = [] for length in lengths: @@ -138,7 +103,7 @@ def _generate_nested_tensor(batch_size, total_elements, dtype): tensors.append(torch.randint(0, 10000, (int(length),), dtype=dtype)) else: tensors.append(torch.randn(int(length), dtype=dtype)) - + return torch.nested.nested_tensor(tensors, dtype=dtype) @@ -168,7 +133,7 @@ def create_complex_test_case(batch_size, seq_length, field_num): fields[field_name] = tensor_data total_size_bytes += total_elements_per_field * bytes_per_elem - total_size_gb = total_size_bytes / (1024 ** 3) + total_size_gb = total_size_bytes / (1024**3) prompt_batch = TensorDict( fields, @@ -190,18 +155,18 @@ def _compare_nested_tensors(original, retrieved, path): # Unbind to list for element-wise comparison orig_tensors = original.unbind() retr_tensors = retrieved.unbind() - + if len(orig_tensors) != len(retr_tensors): return False, f"[{path}] NestedTensor batch size mismatch: {len(orig_tensors)} vs {len(retr_tensors)}" - - for idx, (orig, retr) in enumerate(zip(orig_tensors, retr_tensors)): + + for idx, (orig, retr) in enumerate(zip(orig_tensors, retr_tensors, strict=False)): if orig.shape != retr.shape: return False, f"[{path}][{idx}] Shape mismatch: {orig.shape} vs {retr.shape}" if orig.dtype != retr.dtype: return False, f"[{path}][{idx}] Dtype mismatch: {orig.dtype} vs {retr.dtype}" if not torch.equal(orig.cpu(), retr.cpu()): return False, f"[{path}][{idx}] Values mismatch" - + return True, "Passed" @@ -214,8 +179,8 @@ def check_data_consistency(original, retrieved, path="root"): retrieved = list(retrieved) # NestedTensor check (must be before regular Tensor since NestedTensor is also a Tensor) - if original.is_nested if hasattr(original, 'is_nested') else False: - if not (retrieved.is_nested if hasattr(retrieved, 'is_nested') else False): + if original.is_nested if hasattr(original, "is_nested") else False: + if not 
(retrieved.is_nested if hasattr(retrieved, "is_nested") else False): return False, f"[{path}] Type mismatch: NestedTensor vs non-NestedTensor" return _compare_nested_tensors(original, retrieved, path) @@ -255,10 +220,11 @@ def check_data_consistency(original, retrieved, path="root"): # Core Tester Class # ========================================================= + def sync_stage(flag_to_create, flag_to_wait): """Profile sync helper function for synchronizing with external profiler process""" - with open(flag_to_create, 'w') as f: - f.write('1') + with open(flag_to_create, "w") as f: + f.write("1") while not os.path.exists(flag_to_wait): time.sleep(0.05) try: @@ -283,12 +249,14 @@ def __init__(self, target_ip=None, storage_units=8, enable_profile=False): def initialize_system(self, config_dict): """Initialize TransferQueue system based on current configuration""" # Basic config conversion - self.tq_config = OmegaConf.create({ - "global_batch_size": config_dict["global_batch_size"], - "num_global_batch": 1, - "num_data_storage_units": self.num_storage_units, - "num_data_controllers": 1 - }) + self.tq_config = OmegaConf.create( + { + "global_batch_size": config_dict["global_batch_size"], + "num_global_batch": 1, + "num_data_storage_units": self.num_storage_units, + "num_data_controllers": 1, + } + ) total_storage_size = self.tq_config.global_batch_size * 2 @@ -301,9 +269,7 @@ def initialize_system(self, config_dict): num_cpus=1, resources={f"node:{self.target_ip}": 0.001}, runtime_env={"env_vars": {"OMP_NUM_THREADS": "2"}}, - ).remote( - storage_unit_size=math.ceil(total_storage_size / self.num_storage_units) - ) + ).remote(storage_unit_size=math.ceil(total_storage_size / self.num_storage_units)) else: # Local Mode: Use placement group self.storage_placement_group = get_placement_group(self.num_storage_units, num_cpus_per_actor=2) @@ -312,9 +278,7 @@ def initialize_system(self, config_dict): placement_group=self.storage_placement_group, 
placement_group_bundle_index=rank, runtime_env={"env_vars": {"OMP_NUM_THREADS": "2"}}, - ).remote( - storage_unit_size=math.ceil(total_storage_size / self.num_storage_units) - ) + ).remote(storage_unit_size=math.ceil(total_storage_size / self.num_storage_units)) # Controller Init self.data_system_controller = TransferQueueController.remote() @@ -331,11 +295,11 @@ def initialize_system(self, config_dict): # Client Init self.data_system_client = AsyncTransferQueueClient( - client_id='Trainer', - controller_info=self.data_system_controller_info + client_id="Trainer", controller_info=self.data_system_controller_info + ) + self.data_system_client.initialize_storage_manager( + manager_type="AsyncSimpleStorageManager", config=self.tq_config ) - self.data_system_client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", - config=self.tq_config) return self.data_system_client @@ -361,6 +325,7 @@ def cleanup(self): # 4. Force garbage collection import gc + gc.collect() # 5. Wait for Ray scheduler to update state (prevent race condition) @@ -370,9 +335,7 @@ def run_benchmark_rounds(self, config_name, config, rounds): """Run multiple rounds of PUT/GET bandwidth tests""" logger.info(f"Generating test data [{config_name}]...") big_input_ids, total_gb = create_complex_test_case( - batch_size=config["global_batch_size"], - seq_length=config["seq_length"], - field_num=config["field_num"] + batch_size=config["global_batch_size"], seq_length=config["seq_length"], field_num=config["field_num"] ) logger.info(f"Data Size: {total_gb:.4f} GB") @@ -387,27 +350,29 @@ def run_benchmark_rounds(self, config_name, config, rounds): # PUT operation start_put = time.time() if i == 0 and self.enable_profile: - sync_stage('init_ready.flag', 'put_start.flag') + sync_stage("init_ready.flag", "put_start.flag") asyncio.run(self.data_system_client.async_put(data=big_input_ids, partition_id=partition_key)) put_time = time.time() - start_put if i == 0 and self.enable_profile: - 
sync_stage('put_done.flag', 'get_prepare.flag') + sync_stage("put_done.flag", "get_prepare.flag") put_gbps = (total_gb * 8) / put_time put_speeds.append(put_gbps) time.sleep(2) # Get metadata (required step for TQ flow) - prompt_meta = asyncio.run(self.data_system_client.async_get_meta( - data_fields=list(big_input_ids.keys()), - batch_size=big_input_ids.size(0), - partition_id=partition_key, - task_name='generate_sequences', - )) + prompt_meta = asyncio.run( + self.data_system_client.async_get_meta( + data_fields=list(big_input_ids.keys()), + batch_size=big_input_ids.size(0), + partition_id=partition_key, + task_name="generate_sequences", + ) + ) # GET operation start_get = time.time() if i == 0 and self.enable_profile: - sync_stage('get_ready.flag', 'get_start.flag') + sync_stage("get_ready.flag", "get_start.flag") retrieved_data = asyncio.run(self.data_system_client.async_get_data(prompt_meta)) get_time = time.time() - start_get @@ -422,7 +387,7 @@ def run_benchmark_rounds(self, config_name, config, rounds): if not is_consistent: print(f" ❌ FAIL: {msg}") else: - print(f" ✅ PASS", end="") + print(" ✅ PASS", end="") asyncio.run(self.data_system_client.async_clear_partition(partition_id=partition_key)) print("\n") @@ -434,7 +399,7 @@ def make_result(op, speeds): "data_volume": f"{total_gb * 1024:.2f} MB" if total_gb * 1024 < 10 else f"{total_gb:.4f} GB", "operation": op, "payload_gb": total_gb, - "stats_gbps": calculate_stats(speeds) + "stats_gbps": calculate_stats(speeds), } return [make_result("PUT", put_speeds), make_result("GET", get_speeds)] @@ -446,8 +411,9 @@ def make_result(op, speeds): def main(): parser = argparse.ArgumentParser(description="TransferQueue Bandwidth Benchmark") parser.add_argument("--ip", type=str, default=None, help="Worker node IP, local test if not set") - parser.add_argument("--config", type=str, default=None, choices=list(CONFIG_MAP.keys()), - help="Specific config to run") + parser.add_argument( + "--config", type=str, default=None, 
choices=list(CONFIG_MAP.keys()), help="Specific config to run" + ) parser.add_argument("--output", type=str, default="tq_benchmark_result.json", help="Output JSON file") parser.add_argument("--rounds", type=int, default=20, help="Test rounds per config (default: 20)") parser.add_argument("--shards", type=int, default=8, help="Number of storage units (default: 8)") @@ -458,12 +424,9 @@ def main(): # Initialize Ray current_working_dir = os.getcwd() if not ray.is_initialized(): - ray.init( - address="auto" if args.ip else None, - runtime_env={"working_dir": current_working_dir} - ) + ray.init(address="auto" if args.ip else None, runtime_env={"working_dir": current_working_dir}) - target_address = args.ip if args.ip else '127.0.0.1' + target_address = args.ip if args.ip else "127.0.0.1" logger.info(f"Ray initialized. Target: {target_address}") # Create tester @@ -488,16 +451,12 @@ def main(): logger.info(f"💾 Results saved to {args.output}") except Exception as e: - logger.error(f"❌ Critical error: {e}", exc_info=True) + logger.exception(f"❌ Critical error: {e}") finally: if ray.is_initialized(): ray.shutdown() if __name__ == "__main__": - try: - from transfer_queue.utils import serial_utils - print(f'[Startup Check] TQ_ZERO_COPY_SERIALIZATION = {serial_utils.TQ_ZERO_COPY_SERIALIZATION}') - except ImportError: - print('[Startup Check] Could not import serial_utils to check flag') + print("[Startup Check]") main() diff --git a/tests/test_serial_utils_on_cpu.py b/tests/test_serial_utils_on_cpu.py index ababaf27..c7e83c57 100644 --- a/tests/test_serial_utils_on_cpu.py +++ b/tests/test_serial_utils_on_cpu.py @@ -574,7 +574,7 @@ def test_single_nested_tensor_serialization(): def test_large_string_serialization(): """Test serialization of large strings (>10KB). - + Note: msgpack natively handles str type, so enc_hook is not called for strings. This test verifies large strings are correctly serialized/deserialized. 
""" @@ -583,9 +583,9 @@ def test_large_string_serialization(): # Create a string larger than 10KB large_string = "x" * 11000 # ~11KB - + serialized = encoder.encode({"text": large_string}) - + # Verify content is correctly restored decoded = decoder.decode(serialized) assert decoded["text"] == large_string diff --git a/transfer_queue/__init__.py b/transfer_queue/__init__.py index 6569adab..0f03a122 100644 --- a/transfer_queue/__init__.py +++ b/transfer_queue/__init__.py @@ -21,6 +21,7 @@ process_zmq_server_info, ) from .controller import TransferQueueController +from .dataloader import StreamingDataLoader, StreamingDataset from .metadata import BatchMeta from .sampler import BaseSampler from .sampler.grpo_group_n_sampler import GRPOGroupNSampler @@ -32,8 +33,10 @@ __all__ = [ "AsyncTransferQueueClient", - "BatchMeta", "TransferQueueClient", + "StreamingDataset", + "StreamingDataLoader", + "BatchMeta", "TransferQueueController", "SimpleStorageUnit", "ZMQServerInfo", diff --git a/transfer_queue/client.py b/transfer_queue/client.py index df0a4f78..3bcfcc5e 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -509,7 +509,11 @@ async def _get_partition_meta(self, partition_id: str, socket=None) -> BatchMeta if response_msg.request_type != ZMQRequestType.GET_PARTITION_META_RESPONSE: raise RuntimeError("Failed to get metadata for clear operation.") - return BatchMeta.from_dict(response_msg.body["metadata"]) if isinstance(response_msg.body["metadata"], dict) else response_msg.body["metadata"] + return ( + BatchMeta.from_dict(response_msg.body["metadata"]) + if isinstance(response_msg.body["metadata"], dict) + else response_msg.body["metadata"] + ) @dynamic_socket(socket_name="request_handle_socket") async def _clear_partition_in_controller(self, partition_id, socket=None): diff --git a/transfer_queue/dataloader/__init__.py b/transfer_queue/dataloader/__init__.py new file mode 100644 index 00000000..87d3ed8a --- /dev/null +++ 
b/transfer_queue/dataloader/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .streaming_dataloader import StreamingDataLoader +from .streaming_dataset import StreamingDataset + +__all__ = ["StreamingDataset", "StreamingDataLoader"] diff --git a/transfer_queue/dataloader/streaming_dataloader.py b/transfer_queue/dataloader/streaming_dataloader.py new file mode 100644 index 00000000..374d5155 --- /dev/null +++ b/transfer_queue/dataloader/streaming_dataloader.py @@ -0,0 +1,139 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +import os +from typing import Optional + +import torch +from tensordict import TensorDict + +from transfer_queue.dataloader.streaming_dataset import StreamingDataset +from transfer_queue.metadata import BatchMeta + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) + +# Ensure logger has a handler +if not logger.hasHandlers(): + handler = logging.StreamHandler() + handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s")) + logger.addHandler(handler) + + +def _identity_collate_fn(data: tuple[TensorDict, BatchMeta]) -> tuple[TensorDict, BatchMeta]: + """Identity collate function for TransferQueue. + + This function acts as a pass-through, preserving the `(TensorDict, BatchMeta)` + structure yielded by `StreamingDataset`. It prevents PyTorch from attempting + to stack or modify the already-batched data. + """ + return data + + +class StreamingDataLoader(torch.utils.data.DataLoader): + """StreamingDataLoader interface for TransferQueue. + + This DataLoader wraps StreamingDataset and provides a PyTorch DataLoader + interface for distributed training with streaming data access. + + Key Features: + - Compatible with PyTorch training loops (for loop iteration) + - Works with StreamingDataset for streaming data access + - Supports distributed training via RankAwareSampler coordination + + + Note: + This DataLoader is typically used with StreamingDataset which manages + batch size internally. The standard PyTorch DataLoader batch_size + parameter is set to None because batching is handled by the dataset + in coordination with TransferQueue's sampling logic. + + Example: + >>> dataset = StreamingDataset( + ... config=config, + ... micro_batch_size=4, + ... required_fields=["input_ids", "attention_mask"], + ... partition_id="train", + ... task_name="update_actor", + ... data_replica_group=0, + ... data_replica_rank=0, + ... data_replica_world_size=1, + ... 
) + >>> dataloader = StreamingDataLoader(dataset, num_workers=0) + >>> for batch, batch_meta in dataloader: + ... # batch: TensorDict with requested fields + ... # batch_meta: Metadata for TransferQueue coordination + ... loss = model(batch) + ... loss.backward() + """ + + def __init__( + self, + dataset: StreamingDataset, + num_workers: int = 0, + collate_fn=None, + pin_memory: bool = False, + worker_init_fn=None, + multiprocessing_context=None, + prefetch_factor: Optional[int] = None, + persistent_workers: bool = False, + pin_memory_device: str = "", + ): + """Initialize the StreamingDataLoader. + + Args: + dataset: StreamingDataset instance. + num_workers: Number of subprocesses for data loading. + collate_fn: Function to collate samples into batches. + pin_memory: If True, pin memory for GPU transfer. + worker_init_fn: Worker initialization function. + multiprocessing_context: Multiprocessing context. + prefetch_factor: Number of batches to prefetch per worker. + persistent_workers: Keep workers alive between epochs. + pin_memory_device: Device for pin_memory. + + Note: + This DataLoader is designed to work with StreamingDataset which handles + batch size internally via the micro_batch_size parameter. The batch_size + parameter in PyTorch DataLoader is set to None because batching is managed + by the StreamingDataset in coordination with RankAwareSampler. 
+ """ + + if collate_fn is None: + # use identical collate function to directly return the self-defined + # [TensorDict, BatchMeta] output of StreamingDataset + final_collate_fn = _identity_collate_fn + else: + final_collate_fn = collate_fn + + super().__init__( + dataset=dataset, + batch_size=None, # Batch size is handled by the dataset + shuffle=None, + sampler=None, + batch_sampler=None, + num_workers=num_workers, + collate_fn=final_collate_fn, + pin_memory=pin_memory, + drop_last=False, + timeout=0, + worker_init_fn=worker_init_fn, + multiprocessing_context=multiprocessing_context, + generator=None, + prefetch_factor=prefetch_factor, + persistent_workers=persistent_workers, + pin_memory_device=pin_memory_device, + ) diff --git a/transfer_queue/dataloader/streaming_dataset.py b/transfer_queue/dataloader/streaming_dataset.py new file mode 100644 index 00000000..510a497a --- /dev/null +++ b/transfer_queue/dataloader/streaming_dataset.py @@ -0,0 +1,200 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +import os +import time +import uuid +from typing import Any, Iterator + +from tensordict import TensorDict +from torch.utils.data import IterableDataset + +from transfer_queue import TransferQueueClient +from transfer_queue.metadata import BatchMeta +from transfer_queue.utils.zmq_utils import ZMQServerInfo + +TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL = float( + os.environ.get("TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL", 1) +) # in seconds + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) + +# Ensure logger has a handler +if not logger.hasHandlers(): + handler = logging.StreamHandler() + handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s")) + logger.addHandler(handler) + + +class StreamingDataset(IterableDataset): + """Streaming dataset for distributed training with TransferQueue. + + This dataset is designed to work with RankAwareSampler for distributed training + scenarios where each rank independently retrieves data through TransferQueue. + + Usage Example: + >>> dataset = StreamingDataset( + ... config=config, + ... micro_batch_size=4, + ... required_fields=["input_ids", "attention_mask"], + ... partition_id="train", + ... task_name="update_actor", + ... data_replica_group=data_replica_group_id, # Same for all ranks in data replica group + ... data_replica_rank=local_rank, # local rank in data replica group + ... data_replica_world_size=world_size/dp_world_size, # size of data replica group + ... ) + >>> dataloader = StreamingDataLoader( + ... dataset, + ... num_workers=2, # num_workers for data retrieval, each has a TQ client for async data retrieval + ... prefetch_factor=2, # number of batches loaded in advance by each worker + ... ) + >>> for batch, batch_meta in dataloader: + ... # batch is a TensorDict with the requested fields + ... # batch_meta contains metadata for TransferQueue coordination + ... 
pass + """ + + def __init__( + self, + config: dict[str, Any], + micro_batch_size: int, + required_fields: list[str], + partition_id: str, + task_name: str, + data_replica_group: int, + data_replica_rank: int, + data_replica_world_size: int, + ): + """Initialize the StreamingDataset. + + Args: + config: Configuration dictionary containing: + - controller_info: ZMQServerInfo for the TransferQueueController + - storage_backend: Storage backend type (e.g., "AsyncSimpleStorageManager") + - Other backend-specific configuration + micro_batch_size: Number of samples per micro-batch. This is the batch size + that will be requested from TransferQueue for each iteration. + required_fields: List of field names to retrieve from storage. Only these + fields will be included in the returned batch. + partition_id: Partition ID for data versioning. Different partitions can + be used for different data versions or splits (e.g., "train", "val"). + task_name: Unique identifier for the training task. This is used to track + which samples have been consumed by which task. + data_replica_group: The group ID of the current data replica group. All + ranks with the same data_replica_group will receive identical samples. + data_replica_rank: Local rank index within the data_replica_group. Range: + [0, data_replica_world_size - 1] + data_replica_world_size: Total number of ranks in this data_replica_group. + Must be >= 1. + + Raises: + ValueError: If input parameters are invalid. 
+ """ + + if micro_batch_size < 1: + raise ValueError(f"micro_batch_size must be >= 1, got {micro_batch_size}") + + if len(required_fields) < 1: + raise ValueError(f"required_fields must be a list with at least one field name, got {required_fields}") + + if data_replica_world_size < 1: + raise ValueError(f"data_replica_world_size {data_replica_world_size} must >= 1") + + if data_replica_rank >= data_replica_world_size or data_replica_rank < 0: + raise ValueError( + f"data_replica_rank {data_replica_rank} must be greater than or equal to 0 and less than " + f"data_replica_world_size {data_replica_world_size}" + ) + + self.config = config + self.micro_batch_size = micro_batch_size + self.required_fields = required_fields + self.partition_id = partition_id + self.task_name = task_name + self.data_replica_group = data_replica_group + self.data_replica_rank = data_replica_rank + self.data_replica_world_size = data_replica_world_size + + # Build sampling config for controller + self.sampling_config = { + "data_replica_group": self.data_replica_group, + "data_replica_rank": self.data_replica_rank, + "data_replica_world_size": self.data_replica_world_size, + "task_name": self.task_name, + "partition_id": self.partition_id, + } + + self._tq_client = None + + super().__init__() + + def _create_client(self): + client_id = uuid.uuid4().hex[:8] + controller_info = self.config.get("controller_info", None) + if not controller_info or not isinstance(controller_info, ZMQServerInfo): + raise ValueError("Invalid or missing controller_info in config") + + storage_backend = self.config.get("storage_backend", None) + if not storage_backend: + raise ValueError("Missing storage_backend in config") + + self._tq_client = TransferQueueClient(client_id, controller_info) + self._tq_client.initialize_storage_manager(manager_type=storage_backend, config=self.config) + + def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]: + """Iterate over the dataset, yielding batches of data. 
+ + Yields: + Tuple[TensorDict, BatchMeta]: A tuple containing: + - TensorDict: Batch of data with the requested fields. + - BatchMeta: Corresponding metadata to interact with TransferQueue. + Note: + This iterator runs indefinitely until the data source is exhausted. + The caller should handle StopIteration when appropriate (e.g., when + all data has been consumed and no more data will be produced). + """ + if self._tq_client is None: + self._create_client() + + # TODO: need to consider async scenario where the samples in partition is dynamically increasing + while not self._tq_client.check_consumption_status(self.task_name, self.partition_id): + try: + # Get metadata from controller + batch_meta = self._tq_client.get_meta( + data_fields=self.required_fields, + batch_size=self.micro_batch_size, + partition_id=self.partition_id, + task_name=self.task_name, + sampling_config=self.sampling_config, + ) + + # Check if we got valid data + if batch_meta.size == 0: + logger.debug( + f"[StreamingDataset]: Received empty batch, waiting for more data... " + f"Required batch_size={self.micro_batch_size}, data_fields={self.required_fields}," + f"partition_id={self.partition_id}, task_name={self.task_name}." 
+ ) + + time.sleep(TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL) + else: + batch = self._tq_client.get_data(batch_meta) + yield (batch, batch_meta) + + except Exception as e: + logger.error(f"[StreamingDataset]: Error in data iteration: {e}") + raise diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 68a93505..af3f3218 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -68,7 +68,7 @@ def from_dict(cls, data: dict) -> "FieldMeta": dtype=data["dtype"], shape=data["shape"], production_status=ProductionStatus(str(data["production_status"])) - if isinstance(data["production_status"], (int, str)) + if isinstance(data["production_status"], int | str) else data["production_status"], ) diff --git a/transfer_queue/storage/simple_backend.py b/transfer_queue/storage/simple_backend.py index 0707a1d0..71399752 100644 --- a/transfer_queue/storage/simple_backend.py +++ b/transfer_queue/storage/simple_backend.py @@ -257,9 +257,7 @@ def _process_put_get(self) -> None: }, ) - self.put_get_socket.send_multipart( - [identity, *response_msg.serialize()], copy=False - ) + self.put_get_socket.send_multipart([identity, *response_msg.serialize()], copy=False) def _handle_put(self, data_parts: ZMQMessage) -> ZMQMessage: """ diff --git a/transfer_queue/utils/serial_utils.py b/transfer_queue/utils/serial_utils.py index a12f48b5..cd308a9d 100644 --- a/transfer_queue/utils/serial_utils.py +++ b/transfer_queue/utils/serial_utils.py @@ -21,7 +21,7 @@ import warnings from collections.abc import Sequence from types import FunctionType -from typing import Any, Optional, TypeAlias +from typing import Any, TypeAlias import cloudpickle import numpy as np @@ -30,7 +30,6 @@ from msgspec import msgpack from tensordict import TensorDictBase - CUSTOM_TYPE_PICKLE = 1 CUSTOM_TYPE_CLOUDPICKLE = 2 CUSTOM_TYPE_TENSOR = 3 # For tensor with buffer reference @@ -58,7 +57,7 @@ def __init__(self): # This is used as a local stash of buffers that we can then access 
from # our custom `msgspec` hook, `enc_hook`. We don't have a way to # pass custom data to the hook otherwise. - self.aux_buffers: Optional[list[bytestr]] = None + self.aux_buffers: list[bytestr] = [] def encode(self, obj: Any) -> Sequence[bytestr]: try: @@ -70,7 +69,7 @@ def encode(self, obj: Any) -> Sequence[bytestr]: # new buffer. return bufs finally: - self.aux_buffers = None + self.aux_buffers = [] def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]: try: @@ -79,7 +78,7 @@ def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]: self.encoder.encode_into(obj, buf) return bufs finally: - self.aux_buffers = None + self.aux_buffers = [] def enc_hook(self, obj: Any) -> Any: """Custom encoding hook for types msgspec doesn't natively support. @@ -95,7 +94,6 @@ def enc_hook(self, obj: Any) -> Any: # Handle TensorDict explicitly for recursive zero-copy if isinstance(obj, TensorDictBase): return self._encode_tensordict(obj) - # Handle numpy arrays by converting to tensor if isinstance(obj, np.ndarray): @@ -136,7 +134,7 @@ def _encode_tensor(self, obj: torch.Tensor) -> msgpack.Ext: Returns Ext type so decoding goes through ext_hook (which has buffer access). 
""" - assert self.aux_buffers is not None + assert len(self.aux_buffers) > 0 # Handle nested tensors (strided or jagged) via unbind if obj.is_nested: @@ -167,11 +165,12 @@ def _encode_nested_tensor(self, obj: torch.Tensor) -> msgpack.Ext: def _encode_regular_tensor_meta(self, obj: torch.Tensor) -> tuple: """Encode a regular tensor and return its metadata tuple.""" # Handle non-contiguous tensors + if not obj.is_contiguous(): obj = obj.contiguous() # Handle GPU tensors - if obj.device.type != 'cpu': + if obj.device.type != "cpu": obj = obj.cpu() # Zero-copy buffer extraction via uint8 view @@ -186,11 +185,12 @@ def _encode_regular_tensor_meta(self, obj: torch.Tensor) -> tuple: def _encode_regular_tensor(self, obj: torch.Tensor) -> msgpack.Ext: """Encode a regular (non-nested) tensor with zero-copy.""" # Handle non-contiguous tensors + if not obj.is_contiguous(): obj = obj.contiguous() # Handle GPU tensors - if obj.device.type != 'cpu': + if obj.device.type != "cpu": obj = obj.cpu() if obj.is_sparse: @@ -251,6 +251,7 @@ def _reconstruct_tensordict(self, obj: dict) -> Any: """Reconstruct TensorDict from marked dict structure.""" try: from tensordict import TensorDict + batch_size = obj["batch_size"] data = obj["data"] # Recursively process nested data @@ -305,4 +306,3 @@ def ext_hook(self, code: int, data: memoryview) -> Any: _encoder = MsgpackEncoder() _decoder = MsgpackDecoder() - diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index 543e34b9..3ece8edc 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -24,7 +24,7 @@ import psutil import zmq -from transfer_queue.utils.serial_utils import _encoder, _decoder +from transfer_queue.utils.serial_utils import _decoder, _encoder from transfer_queue.utils.utils import ( ExplicitEnum, TransferQueueRole, diff --git a/tutorial/05_streaming_dataloader.py b/tutorial/05_streaming_dataloader.py new file mode 100644 index 00000000..d532c80e --- /dev/null +++ 
# Source path (from the patch header): tutorial/05_streaming_dataloader.py
#
# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2025 The TransferQueue Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Tutorial 5: Streaming DataLoader for Distributed Training

This script demonstrates how to use StreamingDataset and StreamingDataLoader
for efficient streaming data loading in distributed training scenarios.

Key Components:
- StreamingDataset: PyTorch IterableDataset that integrates with TransferQueue
- StreamingDataLoader: DataLoader wrapper that yields (batch, batch_meta) tuples
- RankAwareSampler: Enables data replica group coordination for consistent
  sampling across multiple ranks

Use Cases:
- Distributed training with multiple data replica groups
- Fine-grained micro-batch-level data retrieval
"""

import os
import sys
import textwrap
import time
import warnings
from pathlib import Path

# Must be set before `ray` is imported below, hence the E402 suppressions on the later imports.
os.environ["RAY_DEDUP_LOGS"] = "0"

warnings.filterwarnings(
    action="ignore",
    message=r"The PyTorch API of nested tensors is in prototype stage*",
    category=UserWarning,
    module=r"torch\.nested",
)

warnings.filterwarnings(
    action="ignore",
    message=r"Tip: In future versions of Ray, Ray will no longer override accelerator visible "
    r"devices env var if num_gpus=0 or num_gpus=None.*",
    category=FutureWarning,
    module=r"ray\._private\.worker",
)


import ray  # noqa: E402
import torch  # noqa: E402
from omegaconf import DictConfig, OmegaConf  # noqa: E402
from tensordict import TensorDict  # noqa: E402

# Add the parent directory to the path for imports
parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))


from transfer_queue import (  # noqa: E402
    RankAwareSampler,
    SimpleStorageUnit,
    StreamingDataLoader,
    StreamingDataset,
    TransferQueueClient,
    TransferQueueController,
    process_zmq_server_info,
)


def setup_transfer_queue():
    """Start the TransferQueue storage units and controller and build the shared config.

    Initializes Ray if it is not already running, launches the configured number of
    ``SimpleStorageUnit`` actors and one ``TransferQueueController`` (using
    ``RankAwareSampler`` with ``polling_mode=True``), then merges their ZMQ server
    info with the base settings into a single OmegaConf config.

    Returns:
        tuple: ``(controller, storage_units, config)``. ``storage_units`` maps the unit
        index to its actor handle; ``config`` contains ``controller_info``,
        ``storage_unit_infos``, ``num_data_storage_units`` and ``storage_backend``.
    """
    if not ray.is_initialized():
        ray.init()

    config = OmegaConf.create(
        {
            "num_data_storage_units": 2,
        }
    )

    storage_units = {
        unit_idx: SimpleStorageUnit.remote(storage_unit_size=100)
        for unit_idx in range(config["num_data_storage_units"])
    }

    print("[Setup]: Setup TransferQueue components")
    print(
        "Note: Using RankAwareSampler when each rank retrieves data independently. It guarantees that "
        "all ranks within the same data replica group receive the same sample indices."
    )
    print(
        "Note: When using streaming data retrieval, please set polling_mode=True when initializing "
        "TransferQueueController. In polling_mode, the controller will return empty BatchMeta when "
        "available data cannot meet the consumption requirements. User side need to retry later."
    )
    controller = TransferQueueController.remote(
        sampler=RankAwareSampler,  # RankAwareSampler enables consistent sampling across ranks in same replica group
        polling_mode=True,  # Enable polling mode for streaming data retrieval
    )

    controller_info = process_zmq_server_info(controller)
    storage_unit_infos = process_zmq_server_info(storage_units)

    # Build the complete configuration. The server-info entries are arbitrary Python
    # objects, so the container that holds them is created with allow_objects enabled.
    tq_config = OmegaConf.create({}, flags={"allow_objects": True})
    tq_config.controller_info = controller_info
    tq_config.storage_unit_infos = storage_unit_infos
    config.storage_backend = "AsyncSimpleStorageManager"
    config = OmegaConf.merge(tq_config, config)

    return controller, storage_units, config


@ray.remote(num_cpus=0.1)
def generate_worker(rank_id: int, config: DictConfig, num_samples: int = 20):
    """
    Ray task that produces training samples.

    Simulates a data producer: it creates its own TransferQueue client, builds
    one-sample TensorDicts and puts them into the "train" partition for the
    update workers to consume.

    Args:
        rank_id: Unique identifier for this generator (used for sample indexing)
        config: TransferQueue configuration
        num_samples: Number of samples to generate

    Note:
        Each sample has a sequence ID calculated as: seq_id = i + (rank_id * 10000).
        IDs are therefore unique across generators as long as num_samples <= 10000.
    """
    # Create a client for interacting with TransferQueue
    client = TransferQueueClient(
        client_id=f"gen_worker_{rank_id}",
        controller_info=config.controller_info,
    )

    # Initialize the storage manager for this client
    client.initialize_storage_manager(manager_type=config.storage_backend, config=config)

    # Generate and put samples into the queue
    for i in range(num_samples):
        # Create unique sequence ID for this sample
        seq_id = i + (rank_id * 10000)

        # Create sample data as TensorDict; both fields carry seq_id so consumers can trace samples
        data = TensorDict(
            {"input_ids": torch.full((1, 16), seq_id, dtype=torch.long), "meta_idx": torch.tensor([seq_id])},
            batch_size=1,
        )

        print(f"[Generate Worker@{rank_id}]: Putting sample {seq_id} into TransferQueue")

        # Put data into the specified partition
        client.put(data, partition_id="train")

    print(f"[Generate Worker@{rank_id}]: Complete putting samples into TransferQueue")


@ray.remote(num_cpus=0.1)
def update_worker(
    rank_id: int,
    data_replica_group: int,
    data_replica_rank: int,
    data_replica_world_size: int,
    config: DictConfig,
    max_steps: int = 5,
):
    """
    Ray task that retrieves and processes training batches.

    Simulates a training worker that consumes data from the TransferQueue
    through StreamingDataset + StreamingDataLoader, stopping after ``max_steps``
    batches.

    Args:
        rank_id: Global rank identifier for logging and display purposes
        data_replica_group: ID of the data parallel group this rank belongs to
            Ranks in the same group receive the same data samples
        data_replica_rank: Local rank index within the data replica group
            Range: [0, data_replica_world_size - 1]
        data_replica_world_size: Total number of ranks in this data replica group
        config: TransferQueue configuration
        max_steps: Maximum number of batches to consume

    Returns:
        dict: Contains data_replica_rank, data_replica_group, and consumed_ids

    Example:
        For a setup with 2 data replica groups (0 and 1), each with 2 ranks:
        - Group 0: ranks [0, 1] receive identical samples
        - Group 1: ranks [2, 3] receive identical samples
        All ranks within the same group get the same global indexes.

    Note:
        The StreamingDataLoader yields tuples of (batch, batch_meta) where:
        - batch: TensorDict containing the requested data fields
        - batch_meta: Metadata for TransferQueue coordination (contains global_indexes)
    """

    # Step 1: Create StreamingDataset
    # This dataset integrates with TransferQueue and handles batch retrieval
    dataset = StreamingDataset(
        config=config,
        micro_batch_size=2,  # Number of samples per batch
        required_fields=["meta_idx"],  # Fields to retrieve from storage. We can retrieve partial fields!
        partition_id="train",  # Data partition to consume from
        task_name="update_task",  # Unique task identifier
        data_replica_group=data_replica_group,
        data_replica_rank=data_replica_rank,
        data_replica_world_size=data_replica_world_size,
    )
    print(f"[Update Worker@{rank_id}] StreamingDataset created successfully")

    # Step 2: Create StreamingDataLoader
    # Wraps the dataset and provides PyTorch DataLoader-compatible interface
    dataloader = StreamingDataLoader(
        dataset=dataset,
        num_workers=2,  # We can enable parallel data retrieval and data pre-fetching!
        prefetch_factor=2,
    )
    print(
        f"[Update Worker@{rank_id}] StreamingDataLoader ready, enabling data pre-fetching through num_workers "
        f"and prefetch_factor."
    )

    # Step 3: Consume data batches
    print(f"[Update Worker@{rank_id}] Starting data consumption...")
    consumed_ids = []
    step = 0

    # batch_meta is not needed by this demo; only the sample IDs in the batch are recorded.
    for batch, _batch_meta in dataloader:
        # Extract sample IDs from the batch
        ids = batch["meta_idx"].view(-1).tolist()

        print(
            f"[Update Worker@{rank_id}]: data_replica_rank {data_replica_rank} in "
            f"data_replica_group {data_replica_group} retrieved samples: {ids}"
        )
        consumed_ids.extend(ids)

        # Simulate processing time (remove in real training)
        time.sleep(5)

        step += 1
        if step >= max_steps:
            print(f"[Update Worker@{rank_id}] Reached max steps ({max_steps}), stopping...")
            break

    print(f"[Update Worker@{rank_id}] Completed {step} steps, consumed {len(consumed_ids)} samples")

    return {
        "data_replica_rank": data_replica_rank,
        "data_replica_group": data_replica_group,
        "consumed_ids": consumed_ids,
    }


def start_all_generate_actors(config):
    """
    Launch the generate_worker tasks that produce training samples.

    Args:
        config: TransferQueue configuration forwarded to every worker.

    Returns:
        list: One Ray object ref per launched generate_worker task.
    """
    num_workers = 2
    return [generate_worker.remote(rank_id=worker_idx, config=config) for worker_idx in range(num_workers)]


def start_all_update_actors(config):
    """
    Launch the update_worker tasks that consume training samples.

    Four ranks are split into two data replica groups of two ranks each; the
    rank within a group is derived as ``rank_id % data_replica_world_size``.

    Args:
        config: TransferQueue configuration forwarded to every worker.

    Returns:
        list: One Ray object ref per launched update_worker task.
    """

    # Define the distributed training topology
    rank_ids = [0, 1, 2, 3]
    data_replica_group = [0, 0, 1, 1]  # First two ranks in group 0, last two in group 1
    data_replica_world_size = 2  # Each group has 2 ranks

    # strict=True: rank and group lists of different lengths are a topology bug, so fail loudly
    # instead of silently dropping the unmatched entries.
    assignments = list(zip(rank_ids, data_replica_group, strict=True))

    print("Training topology configuration:")
    print(f" - Total ranks: {len(rank_ids)}")
    print(f" - Data replica groups: {len(set(data_replica_group))}")
    print(f" - World size per group: {data_replica_world_size}")
    print(f" - Group assignments: {dict(assignments)}")

    return [
        update_worker.remote(
            rank_id=rank_id,
            data_replica_group=group_id,
            data_replica_rank=rank_id % data_replica_world_size,
            data_replica_world_size=data_replica_world_size,
            config=config,
        )
        for rank_id, group_id in assignments
    ]


def main():
    """
    Main function demonstrating end-to-end streaming data loading.

    This tutorial showcases:
    1. Setting up TransferQueue with streaming capabilities
    2. Launching data generation workers
    3. Launching data consumption workers with distributed training topology
    4. Printing how many samples each rank consumed (the per-batch sample IDs
       are printed by the update workers themselves)
    """

    print("=" * 80)
    print(
        textwrap.dedent(
            """
            TransferQueue Tutorial 5: StreamingDataLoader for Distributed Training

            This tutorial demonstrates the StreamingDataLoader interface for distributed
            training scenarios. It showcases how to use StreamingDataset and StreamingDataLoader
            to efficiently consume micro-batch of samples from TransferQueue with proper coordination
            across multiple training ranks.

            Key Concepts:
            - StreamingDataset: PyTorch IterableDataset that integrates with TransferQueue
            - StreamingDataLoader: DataLoader wrapper yielding (batch, batch_meta) tuples
            - RankAwareSampler: Enables correct data consumption across data replica ranks
            - Data Replica Group: Ranks that should receive identical data samples (TP, PP, ...)

            """
        )
    )
    print("=" * 80)

    # Step 1: Setup TransferQueue infrastructure
    print("\n[Phase 1] Setting up TransferQueue infrastructure...")
    # NOTE(review): `controller` and `storage_units` look unused, but keep them bound until main()
    # returns -- Ray may reclaim non-detached actors once their last handle is dropped.
    controller, storage_units, config = setup_transfer_queue()

    # Step 2: Launch data generation workers
    print("\n[Phase 2] Starting data generation...")
    generate_worker_handlers = start_all_generate_actors(config)

    # Step 3: Launch data consumption workers
    print("\n[Phase 3] Starting data consumption...")
    update_worker_handlers = start_all_update_actors(config)

    # Wait for completion
    print("\n[Phase 4] Waiting for actors to complete...")
    print("=" * 80)

    # Wait for generation to complete
    ray.get(generate_worker_handlers)
    print("✓ All generation actors completed")

    # Wait for consumption to complete
    update_results = ray.get(update_worker_handlers)
    print("✓ All update actors completed")

    # Display results summary
    print("\n" + "=" * 80)
    print("Results Summary")
    print("=" * 80)
    for result in update_results:
        print(
            f" Rank {result['data_replica_rank']} (Group {result['data_replica_group']}): "
            f"consumed {len(result['consumed_ids'])} samples"
        )

    print("\n" + "=" * 80)
    print("Tutorial Complete!")
    print("=" * 80)
    print("Key Takeaways:")
    print("1. StreamingDataset provides PyTorch IterableDataset interface for TransferQueue")
    print("2. StreamingDataLoader wraps the dataset and yields (batch, batch_meta) tuples")
    print("3. Ranks in the same data_replica_group receive identical samples")
    print("4. The system enables efficient streaming capabilities")


if __name__ == "__main__":
    main()