From 1a75b60f85d0cb941997fbfe8c033a118f566970 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 22 Jan 2026 18:01:02 +0800 Subject: [PATCH 01/17] fix for async retrival and data preload Signed-off-by: 0oshowero0 --- transfer_queue/sampler/rank_aware_sampler.py | 85 +++++++++++--------- 1 file changed, 48 insertions(+), 37 deletions(-) diff --git a/transfer_queue/sampler/rank_aware_sampler.py b/transfer_queue/sampler/rank_aware_sampler.py index b369ca55..a44c4cb4 100644 --- a/transfer_queue/sampler/rank_aware_sampler.py +++ b/transfer_queue/sampler/rank_aware_sampler.py @@ -48,15 +48,18 @@ def __init__(self): within the same DP group. This state tracks which samples have been sampled and how many times they have been fetched. """ + super().__init__() def sample( self, ready_indexes: list[int], batch_size: int, - dp_group: int, - dp_world_size: int, - world_size: int, + data_replica_group: int, + data_replica_rank: int, + data_replica_world_size: int, + task_name: str, + partition_id: str, *args: Any, **kwargs: Any, ) -> tuple[list[int], list[int]]: @@ -67,66 +70,74 @@ def sample( from ``ready_indexes`` and caches the result. Subsequent ranks in the same DP group receive the cached indices directly. + Internal state structure (self._states): + + .. code-block:: python + + self._states = { + "partition_id": { + "task_name": { + data_replica_group: { + data_replica_rank: [sampled_indexes] # Cached sampled indices + } + } + } + } + + State lifecycle: + 1. First rank samples from ``ready_indexes``, caches results for other ranks + 2. Other ranks pop and retrieve the cached indices + Args: ready_indexes: List of global indices for which all required fields of the corresponding samples have been produced, and the samples are not labeled as consumed in the corresponding task. batch_size: Number of samples to select. If larger than available ready samples, all available samples will be returned. - dp_group: The group id of current data parallel group. Used to - identify which DP group this rank belongs to. - dp_world_size: Number of ranks in the data parallelism group. Used to - determine when all ranks have fetched their samples. - world_size: Total number of ranks across all parallelism dimensions. - Used to determine when all ranks have fetched their samples. + data_replica_group: The group id of current data replica group. Used to + identify which data replica group this rank belongs to. + data_replica_rank: Local rank inside this data_replica_group. + data_replica_world_size: Total number of ranks in this data_replica_group. + task_name: Identifier for the task. + partition_id: Partition ID for data management. *args: Additional positional arguments (ignored). **kwargs: Additional keyword arguments (ignored). Returns: - List of sampled global indices. Typically, has length `batch_size`, - or returns an empty list if samples are insufficient. + Tuple of two lists: + - List of sampled global indices. Typically, has length ``batch_size``, + or empty if samples are insufficient. + - List of global indices to mark as consumed (excluded from future + retrieval by other data_replica_groups). - List of global indices that should be labeled as consumed - (will never be retrieved by other dp_groups in the future). - - Raises: - RuntimeError: If ``world_size`` is not divisible by ``dp_world_size``. """ - # Check if this DP group already has sampled data cached - data_for_dp_group = self._states.get(dp_group, None) + if partition_id not in self._states: + self._states[partition_id] = {} - # Calculate how many times this batch should be fetched across all ranks - if dp_world_size <= 0 or world_size % dp_world_size != 0: - raise RuntimeError(f"world_size ({world_size}) is not divisible by dp_world_size ({dp_world_size})") + if task_name not in self._states[partition_id]: + self._states[partition_id][task_name] = {} - fetches_per_batch = world_size // dp_world_size + if data_replica_group not in self._states[partition_id][task_name]: + self._states[partition_id][task_name][data_replica_group] = {i: [] for i in range(data_replica_world_size)} - if data_for_dp_group is None: + if len(self._states[partition_id][task_name][data_replica_group][data_replica_rank]) == 0: # Select first batch_size indices from ready_indexes sampled_indexes = ready_indexes[:batch_size] if len(sampled_indexes) < batch_size: return [], [] - # Initialize state for this DP group - self._states[dp_group] = {} consumed_indexes = sampled_indexes - # Cache the sampled indices for other ranks in this DP group - self._states[dp_group]["index"] = sampled_indexes - self._states[dp_group]["fetch_count"] = 1 + # Cache the sampled indices for other ranks in this data replica group + for i in range(data_replica_world_size): + if i != data_replica_rank: + self._states[partition_id][task_name][data_replica_group][i].append(sampled_indexes) else: # Return the cached indices (identical to what first rank received) - sampled_indexes = self._states[dp_group]["index"] - consumed_indexes = self._states[dp_group]["index"] - - # Increment fetch count to track progress - self._states[dp_group]["fetch_count"] += 1 - - # Check if this was the last rank in the DP group to fetch - if self._states[dp_group]["fetch_count"] >= fetches_per_batch: - del self._states[dp_group] + sampled_indexes = self._states[partition_id][task_name][data_replica_group][data_replica_rank].pop() + consumed_indexes = sampled_indexes return sampled_indexes, consumed_indexes From 2f81d3f1cad4c3631a56f5036cbadaaadff9611f Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 22 Jan 2026 18:32:14 +0800 Subject: [PATCH 02/17] update CI Signed-off-by: 0oshowero0 --- tests/test_samplers.py | 249 ++++++++++++++++--- transfer_queue/sampler/rank_aware_sampler.py | 2 +- 2 files changed, 211 insertions(+), 40 deletions(-) diff --git a/tests/test_samplers.py b/tests/test_samplers.py index 0d665b62..4cf9864b 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -445,66 +445,93 @@ def test_rank_aware_sampler_initialization(self): assert sampler._states == {} def test_rank_aware_sampler_first_rank_sampling(self): - """Test that first rank in DP group performs actual sampling.""" + """Test that first rank in data replica group performs actual sampling.""" sampler = RankAwareSampler() ready_indexes = [0, 1, 2, 3, 4, 5] batch_size = 3 - # When world_size == dp_world_size, fetches_per_batch = 1 - # First rank samples and immediately marks consumed (no other ranks to wait for) - sampled, consumed = sampler.sample(ready_indexes, batch_size, dp_group=0, dp_world_size=2, world_size=2) + # Rank 0 (first in group) samples and caches for all ranks + # Since rank 1 will call next, state is kept until rank 1 fetches + sampled, consumed = sampler.sample( + ready_indexes, + batch_size, + data_replica_group=0, + data_replica_rank=0, + data_replica_world_size=2, + task_name="test", + partition_id="test", + ) assert sampled == [0, 1, 2] - # consumed is returned assert consumed == [0, 1, 2] assert len(sampled) == batch_size - # State should be cleaned up - assert sampler._states == {} + # State is kept for other ranks to fetch def test_rank_aware_sampler_second_rank_gets_cached(self): - """Test that second rank in DP group gets cached indices.""" + """Test that second rank in data replica group gets cached indices.""" sampler = RankAwareSampler() ready_indexes = [0, 1, 2, 3, 4, 5] batch_size = 3 - dp_world_size = 2 - world_size = 4 # Use world_size=4 so fetches_per_batch=2 - # Rank 0 (dp_group=0) samples first + # Rank 0 (first in group) samples first sampled1, consumed1 = sampler.sample( - ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size + ready_indexes, + batch_size, + data_replica_group=0, + data_replica_rank=0, + data_replica_world_size=2, + task_name="test", + partition_id="test", ) - # Rank 1 (dp_group=0) should get same cached indices + # Rank 1 (second in group) should get same cached indices sampled2, consumed2 = sampler.sample( - ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size + ready_indexes, + batch_size, + data_replica_group=0, + data_replica_rank=1, + data_replica_world_size=2, + task_name="test", + partition_id="test", ) assert sampled1 == sampled2 == [0, 1, 2] - # First rank already returns consumed indexes assert consumed1 == [0, 1, 2] - # Second rank also sees the same consumed indexes; state is then cleaned up assert consumed2 == [0, 1, 2] - # State should be cleaned up - assert sampler._states == {} + + # cache should be empty after all ranks fetch + assert len(sampler._states["test"]["test"][0][0]) == 0 + assert len(sampler._states["test"]["test"][0][1]) == 0 def test_rank_aware_sampler_multiple_dp_groups(self): - """Test that multiple DP groups work independently.""" + """Test that multiple data replica groups work independently.""" sampler = RankAwareSampler() ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7] batch_size = 2 - dp_world_size = 4 - world_size = 8 + data_replica_world_size = 2 # Each group has 2 ranks - # DP group 0: rank 0 samples first + # data replica group 0: rank 0 samples first sampled0_g0, consumed0_g0 = sampler.sample( - ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size + ready_indexes, + batch_size, + data_replica_group=0, + data_replica_rank=0, + data_replica_world_size=data_replica_world_size, + task_name="test", + partition_id="test", ) # mimic the consumption status update managed in TransferQueueController ready_indexes = [i for i in ready_indexes if i not in consumed0_g0] - # DP group 1: rank 0 samples first + # data replica group 1: rank 0 samples first sampled0_g1, consumed0_g1 = sampler.sample( - ready_indexes, batch_size, dp_group=1, dp_world_size=dp_world_size, world_size=world_size + ready_indexes, + batch_size, + data_replica_group=1, + data_replica_rank=0, + data_replica_world_size=data_replica_world_size, + task_name="test", + partition_id="test", ) ready_indexes = [i for i in ready_indexes if i not in consumed0_g1] @@ -514,39 +541,66 @@ def test_rank_aware_sampler_multiple_dp_groups(self): assert consumed0_g0 == [0, 1] assert consumed0_g1 == [2, 3] - # DP group 0: rank 1 fetches cached, and all the data should be labeled as consumed + # data replica group 0: rank 1 fetches cached sampled1_g0, consumed1_g0 = sampler.sample( - ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size + ready_indexes, + batch_size, + data_replica_group=0, + data_replica_rank=1, + data_replica_world_size=data_replica_world_size, + task_name="test", + partition_id="test", ) ready_indexes = [i for i in ready_indexes if i not in consumed1_g0] assert sampled1_g0 == [0, 1] assert consumed1_g0 == [0, 1] - # DP group 1: rank 1 fetches cached, and all the data should be labeled as consumed + # data replica group 1: rank 1 fetches cached sampled1_g1, consumed1_g1 = sampler.sample( - ready_indexes, batch_size, dp_group=1, dp_world_size=dp_world_size, world_size=world_size + ready_indexes, + batch_size, + data_replica_group=1, + data_replica_rank=1, + data_replica_world_size=data_replica_world_size, + task_name="test", + partition_id="test", ) ready_indexes = [i for i in ready_indexes if i not in consumed1_g1] assert sampled1_g1 == [2, 3] assert consumed1_g1 == [2, 3] - # DP group 0: rank 0 fetches again, this should return new data + # data replica group 0: rank 0 fetches again, this should return new data sampled2_g0, consumed2_g0 = sampler.sample( - ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size + ready_indexes, + batch_size, + data_replica_group=0, + data_replica_rank=0, + data_replica_world_size=data_replica_world_size, + task_name="test", + partition_id="test", ) ready_indexes = [i for i in ready_indexes if i not in consumed2_g0] assert sampled2_g0 == [4, 5] assert consumed2_g0 == [4, 5] - # DP group 0: rank 1 fetches cached + # data replica group 0: rank 1 fetches cached sampled3_g0, consumed3_g0 = sampler.sample( - ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size + ready_indexes, + batch_size, + data_replica_group=0, + data_replica_rank=1, + data_replica_world_size=data_replica_world_size, + task_name="test", + partition_id="test", ) assert sampled3_g0 == [4, 5] assert consumed3_g0 == [4, 5] - # Both groups should be cleaned up - assert sampler._states == {} + # examine the internal state to ensure proper caching and clearing + assert len(sampler._states["test"]["test"][0][0]) == 0 + assert len(sampler._states["test"]["test"][0][1]) == 0 + assert len(sampler._states["test"]["test"][1][0]) == 0 + assert len(sampler._states["test"]["test"][1][1]) == 0 def test_rank_aware_sampler_empty_ready_indexes(self): """Test behavior with empty ready indexes.""" @@ -554,7 +608,15 @@ def test_rank_aware_sampler_empty_ready_indexes(self): ready_indexes = [] batch_size = 3 - sampled, consumed = sampler.sample(ready_indexes, batch_size, dp_group=0, dp_world_size=2, world_size=2) + sampled, consumed = sampler.sample( + ready_indexes, + batch_size, + data_replica_group=0, + data_replica_rank=0, + data_replica_world_size=2, + task_name="test", + partition_id="test", + ) assert sampled == [] assert consumed == [] @@ -565,8 +627,15 @@ def test_rank_aware_sampler_batch_size_larger_than_ready(self): ready_indexes = [0, 1] batch_size = 5 - # When world_size == dp_world_size, fetches_per_batch=1, consumed returned immediately - sampled, consumed = sampler.sample(ready_indexes, batch_size, dp_group=0, dp_world_size=2, world_size=2) + sampled, consumed = sampler.sample( + ready_indexes, + batch_size, + data_replica_group=0, + data_replica_rank=0, + data_replica_world_size=2, + task_name="test", + partition_id="test", + ) assert sampled == [] assert consumed == [] @@ -577,11 +646,113 @@ def test_rank_aware_sampler_zero_batch_size(self): ready_indexes = [0, 1, 2, 3] batch_size = 0 - sampled, consumed = sampler.sample(ready_indexes, batch_size, dp_group=0, dp_world_size=2, world_size=2) + sampled, consumed = sampler.sample( + ready_indexes, + batch_size, + data_replica_group=0, + data_replica_rank=0, + data_replica_world_size=2, + task_name="test", + partition_id="test", + ) assert sampled == [] assert consumed == [] + def test_rank_aware_sampler_data_prefetch(self): + """Test behavior with data prefetch.""" + sampler = RankAwareSampler() + ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7] + batch_size = 2 + + sampled_rank0_time0, consumed_rank0_time0 = sampler.sample( + ready_indexes, + batch_size, + data_replica_group=0, + data_replica_rank=0, + data_replica_world_size=2, + task_name="test", + partition_id="test", + ) + + assert sampled_rank0_time0 == [0, 1] + assert consumed_rank0_time0 == [0, 1] + assert sampler._states["test"]["test"][0][0] == [] + assert sampler._states["test"]["test"][0][1] == [[0, 1]] + + ready_indexes = [i for i in ready_indexes if i not in consumed_rank0_time0] + + sampled_rank0_time1, consumed_rank0_time1 = sampler.sample( + ready_indexes, + batch_size, + data_replica_group=0, + data_replica_rank=0, + data_replica_world_size=2, + task_name="test", + partition_id="test", + ) + + assert sampled_rank0_time1 == [2, 3] + assert consumed_rank0_time1 == [2, 3] + assert sampler._states["test"]["test"][0][0] == [] + assert sampler._states["test"]["test"][0][1] == [[0, 1], [2, 3]] + + ready_indexes = [i for i in ready_indexes if i not in consumed_rank0_time1] + + sampled_rank1_time0, consumed_rank1_time0 = sampler.sample( + ready_indexes, + batch_size, + data_replica_group=0, + data_replica_rank=1, + data_replica_world_size=2, + task_name="test", + partition_id="test", + ) + assert sampled_rank1_time0 == [0, 1] + assert consumed_rank1_time0 == [0, 1] + ready_indexes = [i for i in ready_indexes if i not in consumed_rank1_time0] + + assert sampler._states["test"]["test"][0][0] == [] + assert sampler._states["test"]["test"][0][1] == [[2, 3]] + + def test_rank_aware_sampler_multiple_tasks(self): + """Test behavior with multiple tasks.""" + sampler = RankAwareSampler() + ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7] + batch_size = 2 + + sampled_rank0_task0, consumed_rank0_task0 = sampler.sample( + ready_indexes, + batch_size, + data_replica_group=0, + data_replica_rank=0, + data_replica_world_size=2, + task_name="task0", + partition_id="test", + ) + + assert sampled_rank0_task0 == [0, 1] + assert consumed_rank0_task0 == [0, 1] + assert sampler._states["test"]["task0"][0][0] == [] + assert sampler._states["test"]["task0"][0][1] == [[0, 1]] + + sampled_rank0_task1, consumed_rank0_task1 = sampler.sample( + ready_indexes, + batch_size, + data_replica_group=0, + data_replica_rank=1, + data_replica_world_size=2, + task_name="task1", + partition_id="test", + ) + + assert sampled_rank0_task1 == [0, 1] + assert consumed_rank0_task1 == [0, 1] + assert sampler._states["test"]["task0"][0][0] == [] + assert sampler._states["test"]["task0"][0][1] == [[0, 1]] + assert sampler._states["test"]["task1"][0][0] == [[0, 1]] + assert sampler._states["test"]["task1"][0][1] == [] + class TestSamplerIntegration: """Integration tests for samplers.""" diff --git a/transfer_queue/sampler/rank_aware_sampler.py b/transfer_queue/sampler/rank_aware_sampler.py index a44c4cb4..1984362c 100644 --- a/transfer_queue/sampler/rank_aware_sampler.py +++ b/transfer_queue/sampler/rank_aware_sampler.py @@ -137,7 +137,7 @@ def sample( else: # Return the cached indices (identical to what first rank received) - sampled_indexes = self._states[partition_id][task_name][data_replica_group][data_replica_rank].pop() + sampled_indexes = self._states[partition_id][task_name][data_replica_group][data_replica_rank].pop(0) consumed_indexes = sampled_indexes return sampled_indexes, consumed_indexes From 0d84f2582db8242ac6943ccdf1da1cb5d38ab9bd Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 22 Jan 2026 18:33:34 +0800 Subject: [PATCH 03/17] fix Signed-off-by: 0oshowero0 --- tests/test_samplers.py | 54 +++++++++++++++++++++--------------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/tests/test_samplers.py b/tests/test_samplers.py index 4cf9864b..399797c0 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -458,7 +458,7 @@ def test_rank_aware_sampler_first_rank_sampling(self): data_replica_group=0, data_replica_rank=0, data_replica_world_size=2, - task_name="test", + task_name="task", partition_id="test", ) @@ -480,7 +480,7 @@ def test_rank_aware_sampler_second_rank_gets_cached(self): data_replica_group=0, data_replica_rank=0, data_replica_world_size=2, - task_name="test", + task_name="task", partition_id="test", ) @@ -491,7 +491,7 @@ def test_rank_aware_sampler_second_rank_gets_cached(self): data_replica_group=0, data_replica_rank=1, data_replica_world_size=2, - task_name="test", + task_name="task", partition_id="test", ) @@ -500,8 +500,8 @@ def test_rank_aware_sampler_second_rank_gets_cached(self): assert consumed2 == [0, 1, 2] # cache should be empty after all ranks fetch - assert len(sampler._states["test"]["test"][0][0]) == 0 - assert len(sampler._states["test"]["test"][0][1]) == 0 + assert len(sampler._states["test"]["task"][0][0]) == 0 + assert len(sampler._states["test"]["task"][0][1]) == 0 def test_rank_aware_sampler_multiple_dp_groups(self): """Test that multiple data replica groups work independently.""" @@ -517,7 +517,7 @@ def test_rank_aware_sampler_multiple_dp_groups(self): data_replica_group=0, data_replica_rank=0, data_replica_world_size=data_replica_world_size, - task_name="test", + task_name="task", partition_id="test", ) # mimic the consumption status update managed in TransferQueueController @@ -530,7 +530,7 @@ def test_rank_aware_sampler_multiple_dp_groups(self): data_replica_group=1, data_replica_rank=0, data_replica_world_size=data_replica_world_size, - task_name="test", + task_name="task", partition_id="test", ) ready_indexes = [i for i in ready_indexes if i not in consumed0_g1] @@ -548,7 +548,7 @@ def test_rank_aware_sampler_multiple_dp_groups(self): data_replica_group=0, data_replica_rank=1, data_replica_world_size=data_replica_world_size, - task_name="test", + task_name="task", partition_id="test", ) ready_indexes = [i for i in ready_indexes if i not in consumed1_g0] @@ -562,7 +562,7 @@ def test_rank_aware_sampler_multiple_dp_groups(self): data_replica_group=1, data_replica_rank=1, data_replica_world_size=data_replica_world_size, - task_name="test", + task_name="task", partition_id="test", ) ready_indexes = [i for i in ready_indexes if i not in consumed1_g1] @@ -576,7 +576,7 @@ def test_rank_aware_sampler_multiple_dp_groups(self): data_replica_group=0, data_replica_rank=0, data_replica_world_size=data_replica_world_size, - task_name="test", + task_name="task", partition_id="test", ) ready_indexes = [i for i in ready_indexes if i not in consumed2_g0] @@ -590,17 +590,17 @@ def test_rank_aware_sampler_multiple_dp_groups(self): data_replica_group=0, data_replica_rank=1, data_replica_world_size=data_replica_world_size, - task_name="test", + task_name="task", partition_id="test", ) assert sampled3_g0 == [4, 5] assert consumed3_g0 == [4, 5] # examine the internal state to ensure proper caching and clearing - assert len(sampler._states["test"]["test"][0][0]) == 0 - assert len(sampler._states["test"]["test"][0][1]) == 0 - assert len(sampler._states["test"]["test"][1][0]) == 0 - assert len(sampler._states["test"]["test"][1][1]) == 0 + assert len(sampler._states["test"]["task"][0][0]) == 0 + assert len(sampler._states["test"]["task"][0][1]) == 0 + assert len(sampler._states["test"]["task"][1][0]) == 0 + assert len(sampler._states["test"]["task"][1][1]) == 0 def test_rank_aware_sampler_empty_ready_indexes(self): """Test behavior with empty ready indexes.""" @@ -614,7 +614,7 @@ def test_rank_aware_sampler_empty_ready_indexes(self): data_replica_group=0, data_replica_rank=0, data_replica_world_size=2, - task_name="test", + task_name="task", partition_id="test", ) @@ -633,7 +633,7 @@ def test_rank_aware_sampler_batch_size_larger_than_ready(self): data_replica_group=0, data_replica_rank=0, data_replica_world_size=2, - task_name="test", + task_name="task", partition_id="test", ) @@ -652,7 +652,7 @@ def test_rank_aware_sampler_zero_batch_size(self): data_replica_group=0, data_replica_rank=0, data_replica_world_size=2, - task_name="test", + task_name="task", partition_id="test", ) @@ -671,14 +671,14 @@ def test_rank_aware_sampler_data_prefetch(self): data_replica_group=0, data_replica_rank=0, data_replica_world_size=2, - task_name="test", + task_name="task", partition_id="test", ) assert sampled_rank0_time0 == [0, 1] assert consumed_rank0_time0 == [0, 1] - assert sampler._states["test"]["test"][0][0] == [] - assert sampler._states["test"]["test"][0][1] == [[0, 1]] + assert sampler._states["test"]["task"][0][0] == [] + assert sampler._states["test"]["task"][0][1] == [[0, 1]] ready_indexes = [i for i in ready_indexes if i not in consumed_rank0_time0] @@ -688,14 +688,14 @@ def test_rank_aware_sampler_data_prefetch(self): data_replica_group=0, data_replica_rank=0, data_replica_world_size=2, - task_name="test", + task_name="task", partition_id="test", ) assert sampled_rank0_time1 == [2, 3] assert consumed_rank0_time1 == [2, 3] - assert sampler._states["test"]["test"][0][0] == [] - assert sampler._states["test"]["test"][0][1] == [[0, 1], [2, 3]] + assert sampler._states["test"]["task"][0][0] == [] + assert sampler._states["test"]["task"][0][1] == [[0, 1], [2, 3]] ready_indexes = [i for i in ready_indexes if i not in consumed_rank0_time1] @@ -705,15 +705,15 @@ def test_rank_aware_sampler_data_prefetch(self): data_replica_group=0, data_replica_rank=1, data_replica_world_size=2, - task_name="test", + task_name="task", partition_id="test", ) assert sampled_rank1_time0 == [0, 1] assert consumed_rank1_time0 == [0, 1] ready_indexes = [i for i in ready_indexes if i not in consumed_rank1_time0] - assert sampler._states["test"]["test"][0][0] == [] - assert sampler._states["test"]["test"][0][1] == [[2, 3]] + assert sampler._states["test"]["task"][0][0] == [] + assert sampler._states["test"]["task"][0][1] == [[2, 3]] def test_rank_aware_sampler_multiple_tasks(self): """Test behavior with multiple tasks.""" From f808f78625d5647a1dc73778e6f1585e550f986b Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 22 Jan 2026 18:36:01 +0800 Subject: [PATCH 04/17] add check logics Signed-off-by: 0oshowero0 --- transfer_queue/sampler/rank_aware_sampler.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/transfer_queue/sampler/rank_aware_sampler.py b/transfer_queue/sampler/rank_aware_sampler.py index 1984362c..2b18dfa7 100644 --- a/transfer_queue/sampler/rank_aware_sampler.py +++ b/transfer_queue/sampler/rank_aware_sampler.py @@ -110,8 +110,17 @@ def sample( - List of global indices to mark as consumed (excluded from future retrieval by other data_replica_groups). + Raises: + ValueError: If ``data_replica_rank`` is invalid. + """ + if data_replica_rank >= data_replica_world_size or data_replica_rank < 0: + raise ValueError( + f"data_replica_rank {data_replica_rank} must bigger than 0 and less than " + f"data_replica_world_size {data_replica_world_size}" + ) + if partition_id not in self._states: self._states[partition_id] = {} From 1e8d0bb5cc83e14b07afa388013090ff5f0d79f6 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 22 Jan 2026 19:09:42 +0800 Subject: [PATCH 05/17] fix Signed-off-by: 0oshowero0 --- transfer_queue/sampler/rank_aware_sampler.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/transfer_queue/sampler/rank_aware_sampler.py b/transfer_queue/sampler/rank_aware_sampler.py index 2b18dfa7..fa2a115e 100644 --- a/transfer_queue/sampler/rank_aware_sampler.py +++ b/transfer_queue/sampler/rank_aware_sampler.py @@ -24,16 +24,15 @@ class RankAwareSampler(BaseSampler): This sampler is designed for distributed data parallel training scenarios where each rank retrieves data independently. - This sampler guarantees that all ranks within the same DP group receive + This sampler guarantees that all ranks within the same data replica group receive the same sample indices. - The sampler maintains per-DP-group state to coordinate sampling across ranks: + The sampler maintains inner state to coordinate sampling across ranks: - - First rank in a DP group to call :meth:`sample` performs actual sampling from - ``ready_indexes`` and caches the result - - Subsequent ranks in the same DP group retrieve the cached indices - - Once all ranks in the DP group have fetched their samples, the cached state is - cleaned up. + - First rank in a data replica group to call :meth:`sample` performs actual sampling from + ``ready_indexes`` and caches the result for other ranks in the same group + - Subsequent ranks in the same group retrieve the cached indices. + - If no cached indices are available, sampling is performed again and cached for others. Please refer to our roadmap for more details: From 6a9dfb34d8bf9cde467ee1c135ad4b3f3b911b9a Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 22 Jan 2026 19:16:58 +0800 Subject: [PATCH 06/17] fix Signed-off-by: 0oshowero0 --- tests/test_samplers.py | 7 +++---- transfer_queue/sampler/rank_aware_sampler.py | 9 ++++++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/test_samplers.py b/tests/test_samplers.py index 399797c0..d184bba8 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -710,7 +710,6 @@ def test_rank_aware_sampler_data_prefetch(self): ) assert sampled_rank1_time0 == [0, 1] assert consumed_rank1_time0 == [0, 1] - ready_indexes = [i for i in ready_indexes if i not in consumed_rank1_time0] assert sampler._states["test"]["task"][0][0] == [] assert sampler._states["test"]["task"][0][1] == [[2, 3]] @@ -740,7 +739,7 @@ def test_rank_aware_sampler_multiple_tasks(self): ready_indexes, batch_size, data_replica_group=0, - data_replica_rank=1, + data_replica_rank=0, data_replica_world_size=2, task_name="task1", partition_id="test", @@ -750,8 +749,8 @@ def test_rank_aware_sampler_multiple_tasks(self): assert consumed_rank0_task1 == [0, 1] assert sampler._states["test"]["task0"][0][0] == [] assert sampler._states["test"]["task0"][0][1] == [[0, 1]] - assert sampler._states["test"]["task1"][0][0] == [[0, 1]] - assert sampler._states["test"]["task1"][0][1] == [] + assert sampler._states["test"]["task1"][0][0] == [] + assert sampler._states["test"]["task1"][0][1] == [[0, 1]] class TestSamplerIntegration: diff --git a/transfer_queue/sampler/rank_aware_sampler.py b/transfer_queue/sampler/rank_aware_sampler.py index fa2a115e..c879556d 100644 --- a/transfer_queue/sampler/rank_aware_sampler.py +++ b/transfer_queue/sampler/rank_aware_sampler.py @@ -92,7 +92,7 @@ def sample( corresponding samples have been produced, and the samples are not labeled as consumed in the corresponding task. batch_size: Number of samples to select. If larger than available - ready samples, all available samples will be returned. + ready samples, no samples are returned and both lists are empty. data_replica_group: The group id of current data replica group. Used to identify which data replica group this rank belongs to. data_replica_rank: Local rank inside this data_replica_group. @@ -110,13 +110,16 @@ def sample( retrieval by other data_replica_groups). Raises: - ValueError: If ``data_replica_rank`` is invalid. + ValueError: If ``data_replica_rank`` or ``data_replica_world_size`` is invalid. """ + if data_replica_world_size < 1: + raise ValueError(f"data_replica_world_size {data_replica_world_size} must >= 1") + if data_replica_rank >= data_replica_world_size or data_replica_rank < 0: raise ValueError( - f"data_replica_rank {data_replica_rank} must bigger than 0 and less than " + f"data_replica_rank {data_replica_rank} must be greater than or equal to 0 and less than " f"data_replica_world_size {data_replica_world_size}" ) From 717b2ccb285c54fea8426522da53ac556cc04777 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 22 Jan 2026 19:52:10 +0800 Subject: [PATCH 07/17] implement streaming dataloader Signed-off-by: 0oshowero0 --- transfer_queue/__init__.py | 5 +- transfer_queue/dataloader/__init__.py | 19 ++ .../dataloader/streaming_dataloader.py | 91 +++++++++ .../dataloader/streaming_dataset.py | 188 ++++++++++++++++++ 4 files changed, 302 insertions(+), 1 deletion(-) create mode 100644 transfer_queue/dataloader/__init__.py create mode 100644 transfer_queue/dataloader/streaming_dataloader.py create mode 100644 transfer_queue/dataloader/streaming_dataset.py diff --git a/transfer_queue/__init__.py b/transfer_queue/__init__.py index 6569adab..0f03a122 100644 --- a/transfer_queue/__init__.py +++ b/transfer_queue/__init__.py @@ -21,6 +21,7 @@ process_zmq_server_info, ) from .controller import TransferQueueController +from .dataloader import StreamingDataLoader, StreamingDataset from .metadata import BatchMeta from .sampler import BaseSampler from .sampler.grpo_group_n_sampler import GRPOGroupNSampler @@ -32,8 +33,10 @@ __all__ = [ "AsyncTransferQueueClient", - "BatchMeta", "TransferQueueClient", + "StreamingDataset", + "StreamingDataLoader", + "BatchMeta", "TransferQueueController", "SimpleStorageUnit", "ZMQServerInfo", diff --git a/transfer_queue/dataloader/__init__.py b/transfer_queue/dataloader/__init__.py new file mode 100644 index 00000000..87d3ed8a --- /dev/null +++ b/transfer_queue/dataloader/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .streaming_dataloader import StreamingDataLoader +from .streaming_dataset import StreamingDataset + +__all__ = ["StreamingDataset", "StreamingDataLoader"] diff --git a/transfer_queue/dataloader/streaming_dataloader.py b/transfer_queue/dataloader/streaming_dataloader.py new file mode 100644 index 00000000..efadf2e7 --- /dev/null +++ b/transfer_queue/dataloader/streaming_dataloader.py @@ -0,0 +1,91 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from typing import Optional + +import torch + +from transfer_queue.dataloader.streaming_dataset import StreamingDataset + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) + +# Ensure logger has a handler +if not logger.hasHandlers(): + handler = logging.StreamHandler() + handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s")) + logger.addHandler(handler) + + +class StreamingDataLoader(torch.utils.data.DataLoader): + """StreamingDataLoader interface for TransferQueue. + + This DataLoader wraps StreamingDataset and provides a familiar PyTorch + DataLoader interface for distributed training. + + """ + + def __init__( + self, + dataset: StreamingDataset, + num_workers: int = 0, + collate_fn=None, + pin_memory: bool = False, + timeout: float = 0, + worker_init_fn=None, + multiprocessing_context=None, + prefetch_factor: Optional[int] = None, + persistent_workers: bool = False, + pin_memory_device: str = "", + ): + """Initialize the StreamDataLoader. + + Args: + dataset: StreamingDataset instance. + num_workers: Number of subprocesses for data loading. + collate_fn: Function to collate samples into batches. + pin_memory: If True, pin memory for GPU transfer. + drop_last: If True, drop last incomplete batch. + timeout: Timeout for data loading. + worker_init_fn: Worker initialization function. + multiprocessing_context: Multiprocessing context. + prefetch_factor: Number of batches to prefetch per worker. + persistent_workers: Keep workers alive between epochs. + pin_memory_device: Device for pin_memory. + """ + + # Store reference to dataset for data retrieval + self.dataset = dataset + + super().__init__( + dataset=dataset, + batch_size=None, # Batch size is handled by the dataset + shuffle=None, + sampler=None, + batch_sampler=None, + num_workers=num_workers, + collate_fn=collate_fn, + pin_memory=pin_memory, + drop_last=False, + timeout=timeout, + worker_init_fn=worker_init_fn, + multiprocessing_context=multiprocessing_context, + generator=None, + prefetch_factor=prefetch_factor, + persistent_workers=persistent_workers, + pin_memory_device=pin_memory_device, + ) diff --git a/transfer_queue/dataloader/streaming_dataset.py b/transfer_queue/dataloader/streaming_dataset.py new file mode 100644 index 00000000..53ee5e32 --- /dev/null +++ b/transfer_queue/dataloader/streaming_dataset.py @@ -0,0 +1,188 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import time +import uuid +from typing import Any, Iterator + +from torch.utils.data import IterableDataset + +from transfer_queue import ( + TransferQueueClient, + ZMQServerInfo, +) + +TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL = float( + os.environ.get("TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL", 1) +) # in seconds + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) + +# Ensure logger has a handler +if not logger.hasHandlers(): + handler = logging.StreamHandler() + handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s")) + logger.addHandler(handler) + + +class StreamingDataset(IterableDataset): + """Streaming dataset for distributed training with TransferQueue. + + This dataset is designed to work with RankAwareSampler for distributed training + scenarios where each rank independently retrieves data through TransferQueue. + + Each rank calls the sampler independently, passing its own rank information. + The RankAwareSampler guarantees that all ranks within the same DP group receive + the same sample indices, ensuring no data is duplicated or missed. + + Usage Example: + >>> # On each rank: + >>> dataset = StreamingDataset( + ... config=config, + ... micro_batch_size=4, + ... required_fields=["input_ids", "attention_mask"], + ... partition_id="train", + ... task_name="update_actor", + ... dp_group=rank // dp_world_size, # DP group ID for this rank + ... dp_world_size=dp_world_size, + ... world_size=total_world_size, + ... rank=rank, + ... ) + >>> dataloader = StreamingDataLoader( + ... dataset, + ... batch_size=micro_batch_size, + ... num_workers=0, + ... ) + >>> for batch in dataloader: + ... # batch is a TensorDict with the requested fields + ... pass + """ + + def __init__( + self, + config: dict[str, Any], + micro_batch_size: int, + required_fields: list[str], + partition_id: str, + task_name: str, + data_replica_group: int, + data_replica_rank: int, + data_replica_world_size: int, + ): + """Initialize the StreamingDataset. + + Args: + config: Configuration dictionary containing: + - controller_info: ZMQServerInfo for the TransferQueueController + - storage_backend: Storage backend type (e.g., "SimpleAsyncStorageManager") + micro_batch_size: Number of samples per micro-batch. + required_fields: List of field names to retrieve from storage. + partition_id: Partition ID for data versioning. + task_name: Unique identifier for the training task. + data_replica_group: The group id of current data replica group. Used to + identify which data replica group this rank belongs to. + data_replica_rank: Local rank inside this data_replica_group. + data_replica_world_size: Total number of ranks in this data_replica_group. + """ + + self.config = config + self.micro_batch_size = micro_batch_size + self.required_fields = required_fields + self.partition_id = partition_id + self.task_name = task_name + self.data_replica_group = data_replica_group + self.data_replica_rank = data_replica_rank + self.data_replica_world_size = data_replica_world_size + + if data_replica_world_size < 1: + raise ValueError(f"data_replica_world_size {data_replica_world_size} must >= 1") + + if data_replica_rank >= data_replica_world_size or data_replica_rank < 0: + raise ValueError( + f"data_replica_rank {data_replica_rank} must be greater than or equal to 0 and less than " + f"data_replica_world_size {data_replica_world_size}" + ) + + # Build sampling config for controller + self.sampling_config = { + "data_replica_group": self.data_replica_group, + "data_replica_rank": self.data_replica_rank, + "data_replica_world_size": self.data_replica_world_size, + "task_name": self.task_name, + "partition_id": self.partition_id, + } + + self._tq_client = None + + super().__init__() + + def _create_client(self): + client_id = uuid.uuid4().hex[:8] + controller_info = self.config.get("controller_info", None) + if not controller_info or not isinstance(controller_info, ZMQServerInfo): + raise ValueError("Invalid or missing controller_info in config") + + storage_backend = self.config.get("storage_backend", None) + if not storage_backend: + raise ValueError("Missing storage_backend in config") + + self._tq_client = TransferQueueClient(client_id, self.config[controller_info]) + self._tq_client.initialize_storage_manager(manager_type=self.config.storage_backend, config=self.config) + + def __iter__(self) -> Iterator[Any]: + """Iterate over the dataset, yielding batches of data. + + Yields: + Tuple[TensorDict, BatchMeta]: A tuple containing: + - TensorDict: Batch of data with the requested fields. + - BatchMeta: Corresponding metadata to interact with TransferQueue. + Note: + This iterator runs indefinitely until the data source is exhausted. + The caller should handle StopIteration when appropriate (e.g., when + all data has been consumed and no more data will be produced). + """ + if self._tq_client is None: + self._create_client() + + while not self._tq_client.check_consumption_status(self.task_name, self.partition_id): + try: + # Get metadata from controller + batch_meta = self._tq_client.get_meta( + data_fields=self.required_fields, + batch_size=self.micro_batch_size, + partition_id=self.partition_id, + task_name=self.task_name, + sampling_config=self.sampling_config, + ) + + # Check if we got valid data + if batch_meta.size == 0: + logger.debug( + f"[StreamingDataset]: Received empty batch, waiting for more data... " + f"Required batch_size={self.micro_batch_size}, data_fields={self.required_fields}," + f"partition_id={self.partition_id}, task_name={self.task_name}." + ) + + time.sleep(TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL) + else: + batch = self._tq_client.get_data(batch_meta) + yield (batch, batch_meta) + + except Exception as e: + logger.error(f"[StreamingDataset]: Error in data iteration: {e}") + raise From 85c1f992295568a3683d3a93e0ec78d0e00a0f16 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 22 Jan 2026 20:24:42 +0800 Subject: [PATCH 08/17] fix Signed-off-by: 0oshowero0 --- .../dataloader/streaming_dataloader.py | 40 ++++++++++++-- .../dataloader/streaming_dataset.py | 54 ++++++++++--------- 2 files changed, 67 insertions(+), 27 deletions(-) diff --git a/transfer_queue/dataloader/streaming_dataloader.py b/transfer_queue/dataloader/streaming_dataloader.py index efadf2e7..553bd2bd 100644 --- a/transfer_queue/dataloader/streaming_dataloader.py +++ b/transfer_queue/dataloader/streaming_dataloader.py @@ -34,9 +34,38 @@ class StreamingDataLoader(torch.utils.data.DataLoader): """StreamingDataLoader interface for TransferQueue. - This DataLoader wraps StreamingDataset and provides a familiar PyTorch - DataLoader interface for distributed training. + This DataLoader wraps StreamingDataset and provides a PyTorch DataLoader + interface for distributed training with streaming data access. + Key Features: + - Compatible with PyTorch training loops (for loop iteration) + - Works with StreamingDataset for streaming data access + - Supports distributed training via RankAwareSampler coordination + + + Note: + This DataLoader is typically used with StreamingDataset which manages + batch size internally. The standard PyTorch DataLoader batch_size + parameter is set to None because batching is handled by the dataset + in coordination with TransferQueue's sampling logic. + + Example: + >>> dataset = StreamingDataset( + ... config=config, + ... micro_batch_size=4, + ... required_fields=["input_ids", "attention_mask"], + ... partition_id="train", + ... task_name="update_actor", + ... data_replica_group=0, + ... data_replica_rank=0, + ... data_replica_world_size=1, + ... ) + >>> dataloader = StreamingDataLoader(dataset, num_workers=0) + >>> for batch, batch_meta in dataloader: + ... # batch: TensorDict with requested fields + ... # batch_meta: Metadata for TransferQueue coordination + ... loss = model(batch) + ... loss.backward() """ def __init__( @@ -59,13 +88,18 @@ def __init__( num_workers: Number of subprocesses for data loading. collate_fn: Function to collate samples into batches. pin_memory: If True, pin memory for GPU transfer. - drop_last: If True, drop last incomplete batch. timeout: Timeout for data loading. worker_init_fn: Worker initialization function. multiprocessing_context: Multiprocessing context. prefetch_factor: Number of batches to prefetch per worker. persistent_workers: Keep workers alive between epochs. pin_memory_device: Device for pin_memory. + + Note: + This DataLoader is designed to work with StreamingDataset which handles + batch size internally via the micro_batch_size parameter. The batch_size + parameter in PyTorch DataLoader is set to None because batching is managed + by the StreamingDataset in coordination with RankAwareSampler. """ # Store reference to dataset for data retrieval diff --git a/transfer_queue/dataloader/streaming_dataset.py b/transfer_queue/dataloader/streaming_dataset.py index 53ee5e32..f51febf0 100644 --- a/transfer_queue/dataloader/streaming_dataset.py +++ b/transfer_queue/dataloader/streaming_dataset.py @@ -19,12 +19,12 @@ import uuid from typing import Any, Iterator +from tensordict import TensorDict from torch.utils.data import IterableDataset -from transfer_queue import ( - TransferQueueClient, - ZMQServerInfo, -) +from transfer_queue import TransferQueueClient +from transfer_queue.metadata import BatchMeta +from transfer_queue.utils.zmq_utils import ZMQServerInfo TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL = float( os.environ.get("TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL", 1) @@ -46,30 +46,25 @@ class StreamingDataset(IterableDataset): This dataset is designed to work with RankAwareSampler for distributed training scenarios where each rank independently retrieves data through TransferQueue. - Each rank calls the sampler independently, passing its own rank information. - The RankAwareSampler guarantees that all ranks within the same DP group receive - the same sample indices, ensuring no data is duplicated or missed. - Usage Example: - >>> # On each rank: + >>> # On each rank (same data_replica_group, different data_replica_rank): >>> dataset = StreamingDataset( ... config=config, ... micro_batch_size=4, ... required_fields=["input_ids", "attention_mask"], ... partition_id="train", ... task_name="update_actor", - ... dp_group=rank // dp_world_size, # DP group ID for this rank - ... dp_world_size=dp_world_size, - ... world_size=total_world_size, - ... rank=rank, + ... data_replica_group=dp_group_id, # Same for all ranks in DP group + ... data_replica_rank=local_rank, # Different for each rank + ... data_replica_world_size=dp_world_size, ... ) >>> dataloader = StreamingDataLoader( ... dataset, - ... batch_size=micro_batch_size, ... num_workers=0, ... ) - >>> for batch in dataloader: + >>> for batch, batch_meta in dataloader: ... # batch is a TensorDict with the requested fields + ... # batch_meta contains metadata for TransferQueue coordination ... pass """ @@ -90,14 +85,25 @@ def __init__( config: Configuration dictionary containing: - controller_info: ZMQServerInfo for the TransferQueueController - storage_backend: Storage backend type (e.g., "SimpleAsyncStorageManager") - micro_batch_size: Number of samples per micro-batch. - required_fields: List of field names to retrieve from storage. - partition_id: Partition ID for data versioning. - task_name: Unique identifier for the training task. - data_replica_group: The group id of current data replica group. Used to - identify which data replica group this rank belongs to. - data_replica_rank: Local rank inside this data_replica_group. + - Other backend-specific configuration + micro_batch_size: Number of samples per micro-batch. This is the batch size + that will be requested from TransferQueue for each iteration. + required_fields: List of field names to retrieve from storage. Only these + fields will be included in the returned batch. + partition_id: Partition ID for data versioning. Different partitions can + be used for different data versions or splits (e.g., "train", "val"). + task_name: Unique identifier for the training task. This is used to track + which samples have been consumed by which task. + data_replica_group: The group ID of the current data replica group. All + ranks with the same data_replica_group will receive identical samples. + data_replica_rank: Local rank index within the data_replica_group. Range: + [0, data_replica_world_size - 1] data_replica_world_size: Total number of ranks in this data_replica_group. + Must be >= 1. + + Raises: + ValueError: If data_replica_world_size < 1 or data_replica_rank is out of + valid range [0, data_replica_world_size - 1]. """ self.config = config @@ -141,10 +147,10 @@ def _create_client(self): if not storage_backend: raise ValueError("Missing storage_backend in config") - self._tq_client = TransferQueueClient(client_id, self.config[controller_info]) + self._tq_client = TransferQueueClient(client_id, controller_info) self._tq_client.initialize_storage_manager(manager_type=self.config.storage_backend, config=self.config) - def __iter__(self) -> Iterator[Any]: + def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]: """Iterate over the dataset, yielding batches of data. Yields: From 13177e14f9ea89a4ba1b7d0189d9f0d17d27ac72 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 23 Jan 2026 15:41:11 +0800 Subject: [PATCH 09/17] fix Signed-off-by: 0oshowero0 --- transfer_queue/dataloader/streaming_dataloader.py | 3 +-- transfer_queue/dataloader/streaming_dataset.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/transfer_queue/dataloader/streaming_dataloader.py b/transfer_queue/dataloader/streaming_dataloader.py index 553bd2bd..c46652ea 100644 --- a/transfer_queue/dataloader/streaming_dataloader.py +++ b/transfer_queue/dataloader/streaming_dataloader.py @@ -74,7 +74,6 @@ def __init__( num_workers: int = 0, collate_fn=None, pin_memory: bool = False, - timeout: float = 0, worker_init_fn=None, multiprocessing_context=None, prefetch_factor: Optional[int] = None, @@ -115,7 +114,7 @@ def __init__( collate_fn=collate_fn, pin_memory=pin_memory, drop_last=False, - timeout=timeout, + timeout=0, worker_init_fn=worker_init_fn, multiprocessing_context=multiprocessing_context, generator=None, diff --git a/transfer_queue/dataloader/streaming_dataset.py b/transfer_queue/dataloader/streaming_dataset.py index f51febf0..300db516 100644 --- a/transfer_queue/dataloader/streaming_dataset.py +++ b/transfer_queue/dataloader/streaming_dataset.py @@ -84,7 +84,7 @@ def __init__( Args: config: Configuration dictionary containing: - controller_info: ZMQServerInfo for the TransferQueueController - - storage_backend: Storage backend type (e.g., "SimpleAsyncStorageManager") + - storage_backend: Storage backend type (e.g., "AsyncSimpleStorageManager") - Other backend-specific configuration micro_batch_size: Number of samples per micro-batch. This is the batch size that will be requested from TransferQueue for each iteration. @@ -148,7 +148,7 @@ def _create_client(self): raise ValueError("Missing storage_backend in config") self._tq_client = TransferQueueClient(client_id, controller_info) - self._tq_client.initialize_storage_manager(manager_type=self.config.storage_backend, config=self.config) + self._tq_client.initialize_storage_manager(manager_type=storage_backend, config=self.config) def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]: """Iterate over the dataset, yielding batches of data. From 3947fc3b419b20983a33d2a99bf4a1de9605765d Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sat, 24 Jan 2026 12:20:25 +0800 Subject: [PATCH 10/17] fix Signed-off-by: 0oshowero0 --- .../dataloader/streaming_dataloader.py | 15 +- .../dataloader/streaming_dataset.py | 2 +- tutorial/05_streaming_dataloader.py | 381 ++++++++++++++++++ 3 files changed, 396 insertions(+), 2 deletions(-) create mode 100644 tutorial/05_streaming_dataloader.py diff --git a/transfer_queue/dataloader/streaming_dataloader.py b/transfer_queue/dataloader/streaming_dataloader.py index c46652ea..79b66cef 100644 --- a/transfer_queue/dataloader/streaming_dataloader.py +++ b/transfer_queue/dataloader/streaming_dataloader.py @@ -15,11 +15,13 @@ import logging import os -from typing import Optional +from typing import Iterator, Optional import torch +from tensordict import TensorDict from transfer_queue.dataloader.streaming_dataset import StreamingDataset +from transfer_queue.metadata import BatchMeta logger = logging.getLogger(__name__) logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) @@ -122,3 +124,14 @@ def __init__( persistent_workers=persistent_workers, pin_memory_device=pin_memory_device, ) + + def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]: + """Iterate over the dataset, yielding batches with metadata. + + Yields: + Tuple[TensorDict, BatchMeta]: A tuple containing: + - TensorDict: Batch of data with the requested fields. + - BatchMeta: Corresponding metadata to interact with TransferQueue. + """ + # Directly iterate over the dataset, which yields (batch, batch_meta) tuples + return iter(self.dataset) diff --git a/transfer_queue/dataloader/streaming_dataset.py b/transfer_queue/dataloader/streaming_dataset.py index 300db516..b2427fcd 100644 --- a/transfer_queue/dataloader/streaming_dataset.py +++ b/transfer_queue/dataloader/streaming_dataset.py @@ -105,7 +105,6 @@ def __init__( ValueError: If data_replica_world_size < 1 or data_replica_rank is out of valid range [0, data_replica_world_size - 1]. """ - self.config = config self.micro_batch_size = micro_batch_size self.required_fields = required_fields @@ -165,6 +164,7 @@ def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]: if self._tq_client is None: self._create_client() + # TODO: need to consider async scenario where the samples in partition is dynamically increasing while not self._tq_client.check_consumption_status(self.task_name, self.partition_id): try: # Get metadata from controller diff --git a/tutorial/05_streaming_dataloader.py b/tutorial/05_streaming_dataloader.py new file mode 100644 index 00000000..3c5bbf5f --- /dev/null +++ b/tutorial/05_streaming_dataloader.py @@ -0,0 +1,381 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tutorial 5: Streaming DataLoader for Distributed Training + +This script demonstrates how to use StreamingDataset and StreamingDataLoader +for efficient streaming data loading in distributed training scenarios. + +Key Components: +- StreamingDataset: PyTorch IterableDataset that integrates with TransferQueue +- StreamingDataLoader: DataLoader wrapper that yields (batch, batch_meta) tuples +- RankAwareSampler: Enables data replica group coordination for consistent + sampling across multiple ranks + +Use Cases: +- Distributed training with multiple data replica groups +- Fine-grained micro-batch-level data retrieval +""" + +import sys +import textwrap +import time +import warnings +from pathlib import Path + +warnings.filterwarnings( + action="ignore", + message=r"The PyTorch API of nested tensors is in prototype stage*", + category=UserWarning, + module=r"torch\.nested", +) + +warnings.filterwarnings( + action="ignore", + message=r"Tip: In future versions of Ray, Ray will no longer override accelerator visible " + r"devices env var if num_gpus=0 or num_gpus=None.*", + category=FutureWarning, + module=r"ray\._private\.worker", +) + + +import ray # noqa: E402 +import torch # noqa: E402 +from omegaconf import DictConfig, OmegaConf # noqa: E402 +from tensordict import TensorDict # noqa: E402 + +# Add the parent directory to the path for imports +parent_dir = Path(__file__).resolve().parent.parent +sys.path.append(str(parent_dir)) + + +from transfer_queue import ( # noqa: E402 + RankAwareSampler, + SimpleStorageUnit, + StreamingDataLoader, + StreamingDataset, + TransferQueueClient, + TransferQueueController, + process_zmq_server_info, +) + + +def setup_transfer_queue(): + """Setup TransferQueue components.""" + if not ray.is_initialized(): + ray.init() + + config = OmegaConf.create( + { + "num_data_storage_units": 2, + } + ) + + storage_units = {} + for i in range(config["num_data_storage_units"]): + storage_units[i] = SimpleStorageUnit.remote(storage_unit_size=100) + + print("[Setup]: Setup TransferQueue components") + print( + "Note: Using RankAwareSampler when each rank retrieves data independently. It guarantees that " + "all ranks within the same data replica group receive the same sample indices." + ) + print( + "Note: When using streaming data retrieval, please set polling_mode=True when initializing " + "TransferQueueController. In polling_mode, the controller will return empty BatchMeta when " + "available data cannot meet the consumption requirements. User side need to retry later." + ) + controller = TransferQueueController.remote( + sampler=RankAwareSampler, # RankAwareSampler enables consistent sampling across ranks in same replica group + polling_mode=True, # Enable polling mode for streaming data retrieval + ) + + controller_info = process_zmq_server_info(controller) + storage_unit_infos = process_zmq_server_info(storage_units) + + # Build the complete configuration + tq_config = OmegaConf.create({}, flags={"allow_objects": True}) + tq_config.controller_info = controller_info + tq_config.storage_unit_infos = storage_unit_infos + config.storage_backend = "AsyncSimpleStorageManager" + config = OmegaConf.merge(tq_config, config) + + return controller, storage_units, config + + +@ray.remote(num_cpus=0.1) +def generate_worker(rank_id: int, config: DictConfig, num_samples: int = 20): + """ + Generate actor that produces training samples. + + This actor simulates a data producer that generates training samples + and puts them into the TransferQueue for consumption by training actors. + + Args: + rank_id: Unique identifier for this generator (used for sample indexing) + config: TransferQueue configuration + num_samples: Number of samples to generate (default: 10) + + Note: + Each sample has a unique sequence ID calculated as: seq_id = i + (rank_id * 10000) + This ensures global uniqueness across all generator actors. + """ + # Create a client for interacting with TransferQueue + client = TransferQueueClient( + client_id=f"gen_worker_{rank_id}", + controller_info=config.controller_info, + ) + + # Initialize the storage manager for this client + client.initialize_storage_manager(manager_type=config.storage_backend, config=config) + + # Generate and put samples into the queue + for i in range(num_samples): + # Create unique sequence ID for this sample + seq_id = i + (rank_id * 10000) + + # Create sample data as TensorDict + data = TensorDict( + {"input_ids": torch.full((1, 16), seq_id, dtype=torch.long), "meta_idx": torch.tensor([seq_id])}, + batch_size=1, + ) + + print(f"[Generate Worker@{rank_id}]: Putting sample {seq_id} into TransferQueue") + + # Put data into the specified partition + client.put(data, partition_id="train") + + print(f"[Generate Worker@{rank_id}]: Complete putting samples into TransferQueue") + + +@ray.remote(num_cpus=0.1) +def update_worker( + rank_id: int, + data_replica_group: int, + data_replica_rank: int, + data_replica_world_size: int, + config: DictConfig, + max_steps: int = 5, +): + """ + Update actor that retrieves and processes training batches. + + This actor simulates a training worker that consumes data from the + TransferQueue using StreamingDataLoader. It demonstrates how to use + the streaming data loading infrastructure in a distributed setting. + + Args: + rank_id: Global rank identifier for logging and display purposes + data_replica_group: ID of the data parallel group this rank belongs to + Ranks in the same group receive the same data samples + data_replica_rank: Local rank index within the data replica group + Range: [0, data_replica_world_size - 1] + data_replica_world_size: Total number of ranks in this data replica group + config: TransferQueue configuration + max_steps: Maximum number of batches to consume + + Returns: + dict: Contains data_replica_rank, data_replica_group, and consumed_ids + + Example: + For a setup with 2 data replica groups (0 and 1), each with 2 ranks: + - Group 0: ranks [0, 1] receive identical samples + - Group 1: ranks [2, 3] receive identical samples + All ranks within the same group get the same global indexes. + + Note: + The StreamingDataLoader yields tuples of (batch, batch_meta) where: + - batch: TensorDict containing the requested data fields + - batch_meta: Metadata for TransferQueue coordination (contains global_indexes) + """ + + # Step 1: Create StreamingDataset + # This dataset integrates with TransferQueue and handles batch retrieval + dataset = StreamingDataset( + config=config, + micro_batch_size=2, # Number of samples per batch + required_fields=["meta_idx"], # Fields to retrieve from storage. We can retrieve partial fields! + partition_id="train", # Data partition to consume from + task_name="update_task", # Unique task identifier + data_replica_group=data_replica_group, + data_replica_rank=data_replica_rank, + data_replica_world_size=data_replica_world_size, + ) + print(f"[Update Worker@{rank_id}] StreamingDataset created successfully") + + # Step 2: Create StreamingDataLoader + # Wraps the dataset and provides PyTorch DataLoader-compatible interface + dataloader = StreamingDataLoader( + dataset=dataset, + num_workers=0, # We can enable parallel data retrievel and data pre-fetch! + ) + print(f"[Update Worker@{rank_id}] StreamingDataLoader ready") + + # Step 3: Consume data batches + print(f"[Update Worker@{rank_id}] Starting data consumption...") + consumed_ids = [] + step = 0 + + for batch, batch_meta in dataloader: + # Extract sample IDs from the batch + ids = batch["meta_idx"].view(-1).tolist() + + print( + f"[Update Worker@{rank_id}]: data_replica_rank {data_replica_rank} in " + f"data_replica_group {data_replica_group} retrieved samples: {ids}" + ) + consumed_ids.extend(ids) + + # Simulate processing time (remove in real training) + time.sleep(5) + + step += 1 + if step >= max_steps: + print(f"[Update Worker@{rank_id}] Reached max steps ({max_steps}), stopping...") + break + + print(f"[Update Worker@{rank_id}] Completed {step} steps, consumed {len(consumed_ids)} samples") + + return { + "data_replica_rank": data_replica_rank, + "data_replica_group": data_replica_group, + "consumed_ids": consumed_ids, + } + + +def start_all_generate_actors(config): + """ + Launch generate_actors for producing training samples. + """ + num_workers = 2 + handlers = [] + + for i in range(num_workers): + handlers.append(generate_worker.remote(rank_id=i, config=config)) + + return handlers + + +def start_all_update_actors(config): + """ + Launch update_actors for consuming training samples. + """ + + # Define the distributed training topology + rank_ids = [0, 1, 2, 3] + data_replica_group = [0, 0, 1, 1] # First two ranks in group 0, last two in group 1 + data_replica_world_size = 2 # Each group has 2 ranks + + print("Training topology configuration:") + print(f" - Total ranks: {len(rank_ids)}") + print(f" - Data replica groups: {len(set(data_replica_group))}") + print(f" - World size per group: {data_replica_world_size}") + print(f" - Group assignments: {dict(zip(rank_ids, data_replica_group, strict=False))}") + + handlers = [] + for i in range(len(rank_ids)): + handlers.append( + update_worker.remote( + rank_id=rank_ids[i], + data_replica_group=data_replica_group[i], + data_replica_rank=rank_ids[i] % data_replica_world_size, + data_replica_world_size=data_replica_world_size, + config=config, + ) + ) + + return handlers + + +def main(): + """ + Main function demonstrating end-to-end streaming data loading. + + This tutorial showcases: + 1. Setting up TransferQueue with streaming capabilities + 2. Launching data generation actors + 3. Launching data consumption actors with distributed training topology + 4. Verifying that ranks in the same group receive identical samples + """ + + print("=" * 80) + print( + textwrap.dedent( + """ + TransferQueue Tutorial 5: StreamingDataLoader for Distributed Training + + This tutorial demonstrates the StreamingDataLoader interface for distributed + training scenarios. It showcases how to use StreamingDataset and StreamingDataLoader + to efficiently consume micro-batch of samples from TransferQueue with proper coordination + across multiple training ranks. + + Key Concepts: + - StreamingDataset: PyTorch IterableDataset that integrates with TransferQueue + - StreamingDataLoader: DataLoader wrapper yielding (batch, batch_meta) tuples + - RankAwareSampler: Enables correct data consumption across data replica ranks + - Data Replica Group: Ranks that should receive identical data samples (TP, PP, ...) + + """ + ) + ) + print("=" * 80) + + # Step 1: Setup TransferQueue infrastructure + print("\n[Phase 1] Setting up TransferQueue infrastructure...") + controller, storage_units, config = setup_transfer_queue() + + # Step 2: Launch data generation actors + print("\n[Phase 2] Starting data generation...") + generate_worker_handlers = start_all_generate_actors(config) + + # Step 3: Launch data consumption actors + print("\n[Phase 3] Starting data consumption...") + update_worker_handlers = start_all_update_actors(config) + + # Wait for completion + print("\n[Phase 4] Waiting for actors to complete...") + print("=" * 80) + + # Wait for generation to complete + ray.get(generate_worker_handlers) + print("✓ All generation actors completed") + + # Wait for consumption to complete + update_results = ray.get(update_worker_handlers) + print("✓ All update actors completed") + + # Display results summary + print("\n" + "=" * 80) + print("Results Summary") + print("=" * 80) + for result in update_results: + print( + f" Rank {result['data_replica_rank']} (Group {result['data_replica_group']}): " + f"consumed {len(result['consumed_ids'])} samples" + ) + + print("\n" + "=" * 80) + print("Tutorial Complete!") + print("=" * 80) + print("Key Takeaways:") + print("1. StreamingDataset provides PyTorch IterableDataset interface for TransferQueue") + print("2. StreamingDataLoader wraps the dataset and yields (batch, batch_meta) tuples") + print("3. Ranks in the same data_replica_group receive identical samples") + print("4. The system enables efficient streaming capabilities") + + +if __name__ == "__main__": + main() From 61acc862dcee57c6c04e705135c44cf8e4f03869 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sat, 24 Jan 2026 12:28:51 +0800 Subject: [PATCH 11/17] fix comments Signed-off-by: 0oshowero0 --- .../dataloader/streaming_dataloader.py | 6 +----- transfer_queue/dataloader/streaming_dataset.py | 6 ++++++ transfer_queue/sampler/rank_aware_sampler.py | 18 ------------------ 3 files changed, 7 insertions(+), 23 deletions(-) diff --git a/transfer_queue/dataloader/streaming_dataloader.py b/transfer_queue/dataloader/streaming_dataloader.py index 79b66cef..f8ed1c39 100644 --- a/transfer_queue/dataloader/streaming_dataloader.py +++ b/transfer_queue/dataloader/streaming_dataloader.py @@ -82,14 +82,13 @@ def __init__( persistent_workers: bool = False, pin_memory_device: str = "", ): - """Initialize the StreamDataLoader. + """Initialize the StreamingDataLoader. Args: dataset: StreamingDataset instance. num_workers: Number of subprocesses for data loading. collate_fn: Function to collate samples into batches. pin_memory: If True, pin memory for GPU transfer. - timeout: Timeout for data loading. worker_init_fn: Worker initialization function. multiprocessing_context: Multiprocessing context. prefetch_factor: Number of batches to prefetch per worker. @@ -103,9 +102,6 @@ def __init__( by the StreamingDataset in coordination with RankAwareSampler. """ - # Store reference to dataset for data retrieval - self.dataset = dataset - super().__init__( dataset=dataset, batch_size=None, # Batch size is handled by the dataset diff --git a/transfer_queue/dataloader/streaming_dataset.py b/transfer_queue/dataloader/streaming_dataset.py index b2427fcd..c0f3d497 100644 --- a/transfer_queue/dataloader/streaming_dataset.py +++ b/transfer_queue/dataloader/streaming_dataset.py @@ -114,6 +114,12 @@ def __init__( self.data_replica_rank = data_replica_rank self.data_replica_world_size = data_replica_world_size + if micro_batch_size < 1: + raise ValueError(f"micro_batch_size must be >= 1, got {micro_batch_size}") + + if len(self.required_fields) < 1: + raise ValueError(f"required_fields must be a list of more than one field name, got {self.required_fields}") + if data_replica_world_size < 1: raise ValueError(f"data_replica_world_size {data_replica_world_size} must >= 1") diff --git a/transfer_queue/sampler/rank_aware_sampler.py b/transfer_queue/sampler/rank_aware_sampler.py index 2ce8b2a1..64531fd5 100644 --- a/transfer_queue/sampler/rank_aware_sampler.py +++ b/transfer_queue/sampler/rank_aware_sampler.py @@ -87,24 +87,6 @@ def sample( 1. First rank samples from ``ready_indexes``, caches results for other ranks 2. Other ranks pop and retrieve the cached indices - Internal state structure (self._states): - - .. code-block:: python - - self._states = { - "partition_id": { - "task_name": { - data_replica_group: { - data_replica_rank: [sampled_indexes] # Cached sampled indices - } - } - } - } - - State lifecycle: - 1. First rank samples from ``ready_indexes``, caches results for other ranks - 2. Other ranks pop and retrieve the cached indices - Args: ready_indexes: List of global indices for which all required fields of the corresponding samples have been produced, and the samples are not labeled From 2a918bdb01e40325133af35c03049ec2e6a5bc49 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sat, 24 Jan 2026 12:46:03 +0800 Subject: [PATCH 12/17] fix comments Signed-off-by: 0oshowero0 --- .../dataloader/streaming_dataset.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/transfer_queue/dataloader/streaming_dataset.py b/transfer_queue/dataloader/streaming_dataset.py index c0f3d497..0fe9ab6a 100644 --- a/transfer_queue/dataloader/streaming_dataset.py +++ b/transfer_queue/dataloader/streaming_dataset.py @@ -47,20 +47,20 @@ class StreamingDataset(IterableDataset): scenarios where each rank independently retrieves data through TransferQueue. Usage Example: - >>> # On each rank (same data_replica_group, different data_replica_rank): >>> dataset = StreamingDataset( ... config=config, ... micro_batch_size=4, ... required_fields=["input_ids", "attention_mask"], ... partition_id="train", ... task_name="update_actor", - ... data_replica_group=dp_group_id, # Same for all ranks in DP group - ... data_replica_rank=local_rank, # Different for each rank - ... data_replica_world_size=dp_world_size, + ... data_replica_group=data_replica_group_id, # Same for all ranks in data replica group + ... data_replica_rank=local_rank, # local rank in data replica group + ... data_replica_world_size=world_size/dp_world_size, # size of data replica group ... ) >>> dataloader = StreamingDataLoader( ... dataset, - ... num_workers=0, + ... num_workers=2, # num_workers for data retrieval, each has a TQ client for async data retrieval + ... prefetch_factor=2, # number of batches loaded in advance by each worker ... ) >>> for batch, batch_meta in dataloader: ... # batch is a TensorDict with the requested fields @@ -102,17 +102,8 @@ def __init__( Must be >= 1. Raises: - ValueError: If data_replica_world_size < 1 or data_replica_rank is out of - valid range [0, data_replica_world_size - 1]. + ValueError: If input parameters are invalid. """ - self.config = config - self.micro_batch_size = micro_batch_size - self.required_fields = required_fields - self.partition_id = partition_id - self.task_name = task_name - self.data_replica_group = data_replica_group - self.data_replica_rank = data_replica_rank - self.data_replica_world_size = data_replica_world_size if micro_batch_size < 1: raise ValueError(f"micro_batch_size must be >= 1, got {micro_batch_size}") @@ -129,6 +120,15 @@ def __init__( f"data_replica_world_size {data_replica_world_size}" ) + self.config = config + self.micro_batch_size = micro_batch_size + self.required_fields = required_fields + self.partition_id = partition_id + self.task_name = task_name + self.data_replica_group = data_replica_group + self.data_replica_rank = data_replica_rank + self.data_replica_world_size = data_replica_world_size + # Build sampling config for controller self.sampling_config = { "data_replica_group": self.data_replica_group, From 1425e6576b0b193ebad08c3697bb14cba6df1961 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sat, 24 Jan 2026 12:58:17 +0800 Subject: [PATCH 13/17] update Signed-off-by: 0oshowero0 --- transfer_queue/dataloader/streaming_dataset.py | 2 +- tutorial/05_streaming_dataloader.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/transfer_queue/dataloader/streaming_dataset.py b/transfer_queue/dataloader/streaming_dataset.py index 0fe9ab6a..1a99950c 100644 --- a/transfer_queue/dataloader/streaming_dataset.py +++ b/transfer_queue/dataloader/streaming_dataset.py @@ -108,7 +108,7 @@ def __init__( if micro_batch_size < 1: raise ValueError(f"micro_batch_size must be >= 1, got {micro_batch_size}") - if len(self.required_fields) < 1: + if len(required_fields) < 1: raise ValueError(f"required_fields must be a list of more than one field name, got {self.required_fields}") if data_replica_world_size < 1: diff --git a/tutorial/05_streaming_dataloader.py b/tutorial/05_streaming_dataloader.py index 3c5bbf5f..cd3ef4c4 100644 --- a/tutorial/05_streaming_dataloader.py +++ b/tutorial/05_streaming_dataloader.py @@ -30,12 +30,15 @@ - Fine-grained micro-batch-level data retrieval """ +import os import sys import textwrap import time import warnings from pathlib import Path +os.environ["RAY_DEDUP_LOGS"] = "0" + warnings.filterwarnings( action="ignore", message=r"The PyTorch API of nested tensors is in prototype stage*", @@ -220,9 +223,13 @@ def update_worker( # Wraps the dataset and provides PyTorch DataLoader-compatible interface dataloader = StreamingDataLoader( dataset=dataset, - num_workers=0, # We can enable parallel data retrievel and data pre-fetch! + num_workers=2, # We can enable parallel data retrieval and data pre-fetching! + prefetch_factor=2, + ) + print( + f"[Update Worker@{rank_id}] StreamingDataLoader ready, enabling data pre-fetching through num_workers " + f"and prefetch_factor." ) - print(f"[Update Worker@{rank_id}] StreamingDataLoader ready") # Step 3: Consume data batches print(f"[Update Worker@{rank_id}] Starting data consumption...") From 8e308ef1c7f7725d07ec4d9a97bc3314fc4fde05 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sat, 24 Jan 2026 13:15:39 +0800 Subject: [PATCH 14/17] fix pre-commit introduced by https://gitcode.com/Ascend/TransferQueue/pull/7 Signed-off-by: 0oshowero0 --- scripts/put_benchmark.py | 162 +++++++++-------------- tests/test_serial_utils_on_cpu.py | 6 +- transfer_queue/client.py | 6 +- transfer_queue/metadata.py | 2 +- transfer_queue/storage/simple_backend.py | 4 +- transfer_queue/utils/serial_utils.py | 24 ++-- transfer_queue/utils/zmq_utils.py | 2 +- 7 files changed, 88 insertions(+), 118 deletions(-) diff --git a/scripts/put_benchmark.py b/scripts/put_benchmark.py index f1c00dd2..1de38893 100644 --- a/scripts/put_benchmark.py +++ b/scripts/put_benchmark.py @@ -7,6 +7,7 @@ import sys import time from pathlib import Path + import numpy as np import ray import torch @@ -14,17 +15,16 @@ from tensordict import TensorDict from tensordict.utils import LinkedList - parent_dir = Path(__file__).resolve().parent.parent.parent sys.path.append(str(parent_dir)) -from transfer_queue import ( +from transfer_queue import ( # noqa: E402 AsyncTransferQueueClient, SimpleStorageUnit, TransferQueueController, process_zmq_server_info, ) -from transfer_queue.utils.utils import get_placement_group +from transfer_queue.utils.utils import get_placement_group # noqa: E402 logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -33,48 +33,13 @@ # Configuration Map # ========================================================= CONFIG_MAP = { - "debug": { - "global_batch_size": 32, - "seq_length": 128, - "field_num": 2, - "desc": "Debug (~32KB)" - }, - "tiny": { - "global_batch_size": 64, - "seq_length": 1024, - "field_num": 4, - "desc": "Tiny (~1MB)" - }, - "small": { - "global_batch_size": 512, - "seq_length": 12800, - "field_num": 4, - "desc": "Small (~100MB)" - }, - "medium": { - "global_batch_size": 1024, - "seq_length": 65536, - "field_num": 4, - "desc": "Medium (~1GB)" - }, - "large": { - "global_batch_size": 2048, - "seq_length": 128000, - "field_num": 5, - "desc": "Large (~5GB)" - }, - "xlarge": { - "global_batch_size": 4096, - "seq_length": 128000, - "field_num": 5, - "desc": "X-Large (~10GB)" - }, - "huge": { - "global_batch_size": 4096, - "seq_length": 128000, - "field_num": 10, - "desc": "Huge (~20GB)" - } + "debug": {"global_batch_size": 32, "seq_length": 128, "field_num": 2, "desc": "Debug (~32KB)"}, + "tiny": {"global_batch_size": 64, "seq_length": 1024, "field_num": 4, "desc": "Tiny (~1MB)"}, + "small": {"global_batch_size": 512, "seq_length": 12800, "field_num": 4, "desc": "Small (~100MB)"}, + "medium": {"global_batch_size": 1024, "seq_length": 65536, "field_num": 4, "desc": "Medium (~1GB)"}, + "large": {"global_batch_size": 2048, "seq_length": 128000, "field_num": 5, "desc": "Large (~5GB)"}, + "xlarge": {"global_batch_size": 4096, "seq_length": 128000, "field_num": 5, "desc": "X-Large (~10GB)"}, + "huge": {"global_batch_size": 4096, "seq_length": 128000, "field_num": 10, "desc": "Huge (~20GB)"}, } @@ -89,7 +54,7 @@ def calculate_stats(data: list) -> dict: "mean": float(np.mean(data)), "max": float(np.max(data)), "min": float(np.min(data)), - "p99": float(np.percentile(data, 99)) + "p99": float(np.percentile(data, 99)), } @@ -119,7 +84,7 @@ def _generate_nested_tensor(batch_size, total_elements, dtype): # Use Dirichlet distribution to generate random proportions summing to 1 proportions = np.random.dirichlet(np.ones(batch_size)) lengths = (proportions * total_elements).astype(int) - + # Fix rounding errors to ensure exact total element count diff = total_elements - lengths.sum() if diff != 0: @@ -127,10 +92,10 @@ def _generate_nested_tensor(batch_size, total_elements, dtype): indices = np.argsort(lengths)[::-1] for i in range(abs(diff)): lengths[indices[i % batch_size]] += 1 if diff > 0 else -1 - + # Ensure each length is at least 1 lengths = np.maximum(lengths, 1) - + # Generate tensors with different lengths tensors = [] for length in lengths: @@ -138,7 +103,7 @@ def _generate_nested_tensor(batch_size, total_elements, dtype): tensors.append(torch.randint(0, 10000, (int(length),), dtype=dtype)) else: tensors.append(torch.randn(int(length), dtype=dtype)) - + return torch.nested.nested_tensor(tensors, dtype=dtype) @@ -168,7 +133,7 @@ def create_complex_test_case(batch_size, seq_length, field_num): fields[field_name] = tensor_data total_size_bytes += total_elements_per_field * bytes_per_elem - total_size_gb = total_size_bytes / (1024 ** 3) + total_size_gb = total_size_bytes / (1024**3) prompt_batch = TensorDict( fields, @@ -190,18 +155,18 @@ def _compare_nested_tensors(original, retrieved, path): # Unbind to list for element-wise comparison orig_tensors = original.unbind() retr_tensors = retrieved.unbind() - + if len(orig_tensors) != len(retr_tensors): return False, f"[{path}] NestedTensor batch size mismatch: {len(orig_tensors)} vs {len(retr_tensors)}" - - for idx, (orig, retr) in enumerate(zip(orig_tensors, retr_tensors)): + + for idx, (orig, retr) in enumerate(zip(orig_tensors, retr_tensors, strict=False)): if orig.shape != retr.shape: return False, f"[{path}][{idx}] Shape mismatch: {orig.shape} vs {retr.shape}" if orig.dtype != retr.dtype: return False, f"[{path}][{idx}] Dtype mismatch: {orig.dtype} vs {retr.dtype}" if not torch.equal(orig.cpu(), retr.cpu()): return False, f"[{path}][{idx}] Values mismatch" - + return True, "Passed" @@ -214,8 +179,8 @@ def check_data_consistency(original, retrieved, path="root"): retrieved = list(retrieved) # NestedTensor check (must be before regular Tensor since NestedTensor is also a Tensor) - if original.is_nested if hasattr(original, 'is_nested') else False: - if not (retrieved.is_nested if hasattr(retrieved, 'is_nested') else False): + if original.is_nested if hasattr(original, "is_nested") else False: + if not (retrieved.is_nested if hasattr(retrieved, "is_nested") else False): return False, f"[{path}] Type mismatch: NestedTensor vs non-NestedTensor" return _compare_nested_tensors(original, retrieved, path) @@ -255,10 +220,11 @@ def check_data_consistency(original, retrieved, path="root"): # Core Tester Class # ========================================================= + def sync_stage(flag_to_create, flag_to_wait): """Profile sync helper function for synchronizing with external profiler process""" - with open(flag_to_create, 'w') as f: - f.write('1') + with open(flag_to_create, "w") as f: + f.write("1") while not os.path.exists(flag_to_wait): time.sleep(0.05) try: @@ -283,12 +249,14 @@ def __init__(self, target_ip=None, storage_units=8, enable_profile=False): def initialize_system(self, config_dict): """Initialize TransferQueue system based on current configuration""" # Basic config conversion - self.tq_config = OmegaConf.create({ - "global_batch_size": config_dict["global_batch_size"], - "num_global_batch": 1, - "num_data_storage_units": self.num_storage_units, - "num_data_controllers": 1 - }) + self.tq_config = OmegaConf.create( + { + "global_batch_size": config_dict["global_batch_size"], + "num_global_batch": 1, + "num_data_storage_units": self.num_storage_units, + "num_data_controllers": 1, + } + ) total_storage_size = self.tq_config.global_batch_size * 2 @@ -301,9 +269,7 @@ def initialize_system(self, config_dict): num_cpus=1, resources={f"node:{self.target_ip}": 0.001}, runtime_env={"env_vars": {"OMP_NUM_THREADS": "2"}}, - ).remote( - storage_unit_size=math.ceil(total_storage_size / self.num_storage_units) - ) + ).remote(storage_unit_size=math.ceil(total_storage_size / self.num_storage_units)) else: # Local Mode: Use placement group self.storage_placement_group = get_placement_group(self.num_storage_units, num_cpus_per_actor=2) @@ -312,9 +278,7 @@ def initialize_system(self, config_dict): placement_group=self.storage_placement_group, placement_group_bundle_index=rank, runtime_env={"env_vars": {"OMP_NUM_THREADS": "2"}}, - ).remote( - storage_unit_size=math.ceil(total_storage_size / self.num_storage_units) - ) + ).remote(storage_unit_size=math.ceil(total_storage_size / self.num_storage_units)) # Controller Init self.data_system_controller = TransferQueueController.remote() @@ -331,11 +295,11 @@ def initialize_system(self, config_dict): # Client Init self.data_system_client = AsyncTransferQueueClient( - client_id='Trainer', - controller_info=self.data_system_controller_info + client_id="Trainer", controller_info=self.data_system_controller_info + ) + self.data_system_client.initialize_storage_manager( + manager_type="AsyncSimpleStorageManager", config=self.tq_config ) - self.data_system_client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", - config=self.tq_config) return self.data_system_client @@ -361,6 +325,7 @@ def cleanup(self): # 4. Force garbage collection import gc + gc.collect() # 5. Wait for Ray scheduler to update state (prevent race condition) @@ -370,9 +335,7 @@ def run_benchmark_rounds(self, config_name, config, rounds): """Run multiple rounds of PUT/GET bandwidth tests""" logger.info(f"Generating test data [{config_name}]...") big_input_ids, total_gb = create_complex_test_case( - batch_size=config["global_batch_size"], - seq_length=config["seq_length"], - field_num=config["field_num"] + batch_size=config["global_batch_size"], seq_length=config["seq_length"], field_num=config["field_num"] ) logger.info(f"Data Size: {total_gb:.4f} GB") @@ -387,27 +350,29 @@ def run_benchmark_rounds(self, config_name, config, rounds): # PUT operation start_put = time.time() if i == 0 and self.enable_profile: - sync_stage('init_ready.flag', 'put_start.flag') + sync_stage("init_ready.flag", "put_start.flag") asyncio.run(self.data_system_client.async_put(data=big_input_ids, partition_id=partition_key)) put_time = time.time() - start_put if i == 0 and self.enable_profile: - sync_stage('put_done.flag', 'get_prepare.flag') + sync_stage("put_done.flag", "get_prepare.flag") put_gbps = (total_gb * 8) / put_time put_speeds.append(put_gbps) time.sleep(2) # Get metadata (required step for TQ flow) - prompt_meta = asyncio.run(self.data_system_client.async_get_meta( - data_fields=list(big_input_ids.keys()), - batch_size=big_input_ids.size(0), - partition_id=partition_key, - task_name='generate_sequences', - )) + prompt_meta = asyncio.run( + self.data_system_client.async_get_meta( + data_fields=list(big_input_ids.keys()), + batch_size=big_input_ids.size(0), + partition_id=partition_key, + task_name="generate_sequences", + ) + ) # GET operation start_get = time.time() if i == 0 and self.enable_profile: - sync_stage('get_ready.flag', 'get_start.flag') + sync_stage("get_ready.flag", "get_start.flag") retrieved_data = asyncio.run(self.data_system_client.async_get_data(prompt_meta)) get_time = time.time() - start_get @@ -422,7 +387,7 @@ def run_benchmark_rounds(self, config_name, config, rounds): if not is_consistent: print(f" ❌ FAIL: {msg}") else: - print(f" ✅ PASS", end="") + print(" ✅ PASS", end="") asyncio.run(self.data_system_client.async_clear_partition(partition_id=partition_key)) print("\n") @@ -434,7 +399,7 @@ def make_result(op, speeds): "data_volume": f"{total_gb * 1024:.2f} MB" if total_gb * 1024 < 10 else f"{total_gb:.4f} GB", "operation": op, "payload_gb": total_gb, - "stats_gbps": calculate_stats(speeds) + "stats_gbps": calculate_stats(speeds), } return [make_result("PUT", put_speeds), make_result("GET", get_speeds)] @@ -446,8 +411,9 @@ def make_result(op, speeds): def main(): parser = argparse.ArgumentParser(description="TransferQueue Bandwidth Benchmark") parser.add_argument("--ip", type=str, default=None, help="Worker node IP, local test if not set") - parser.add_argument("--config", type=str, default=None, choices=list(CONFIG_MAP.keys()), - help="Specific config to run") + parser.add_argument( + "--config", type=str, default=None, choices=list(CONFIG_MAP.keys()), help="Specific config to run" + ) parser.add_argument("--output", type=str, default="tq_benchmark_result.json", help="Output JSON file") parser.add_argument("--rounds", type=int, default=20, help="Test rounds per config (default: 20)") parser.add_argument("--shards", type=int, default=8, help="Number of storage units (default: 8)") @@ -458,12 +424,9 @@ def main(): # Initialize Ray current_working_dir = os.getcwd() if not ray.is_initialized(): - ray.init( - address="auto" if args.ip else None, - runtime_env={"working_dir": current_working_dir} - ) + ray.init(address="auto" if args.ip else None, runtime_env={"working_dir": current_working_dir}) - target_address = args.ip if args.ip else '127.0.0.1' + target_address = args.ip if args.ip else "127.0.0.1" logger.info(f"Ray initialized. Target: {target_address}") # Create tester @@ -488,7 +451,7 @@ def main(): logger.info(f"💾 Results saved to {args.output}") except Exception as e: - logger.error(f"❌ Critical error: {e}", exc_info=True) + logger.exception(f"❌ Critical error: {e}") finally: if ray.is_initialized(): ray.shutdown() @@ -497,7 +460,8 @@ def main(): if __name__ == "__main__": try: from transfer_queue.utils import serial_utils - print(f'[Startup Check] TQ_ZERO_COPY_SERIALIZATION = {serial_utils.TQ_ZERO_COPY_SERIALIZATION}') + + print(f"[Startup Check] TQ_ZERO_COPY_SERIALIZATION = {serial_utils.TQ_ZERO_COPY_SERIALIZATION}") except ImportError: - print('[Startup Check] Could not import serial_utils to check flag') + print("[Startup Check] Could not import serial_utils to check flag") main() diff --git a/tests/test_serial_utils_on_cpu.py b/tests/test_serial_utils_on_cpu.py index ababaf27..c7e83c57 100644 --- a/tests/test_serial_utils_on_cpu.py +++ b/tests/test_serial_utils_on_cpu.py @@ -574,7 +574,7 @@ def test_single_nested_tensor_serialization(): def test_large_string_serialization(): """Test serialization of large strings (>10KB). - + Note: msgpack natively handles str type, so enc_hook is not called for strings. This test verifies large strings are correctly serialized/deserialized. """ @@ -583,9 +583,9 @@ def test_large_string_serialization(): # Create a string larger than 10KB large_string = "x" * 11000 # ~11KB - + serialized = encoder.encode({"text": large_string}) - + # Verify content is correctly restored decoded = decoder.decode(serialized) assert decoded["text"] == large_string diff --git a/transfer_queue/client.py b/transfer_queue/client.py index df0a4f78..3bcfcc5e 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -509,7 +509,11 @@ async def _get_partition_meta(self, partition_id: str, socket=None) -> BatchMeta if response_msg.request_type != ZMQRequestType.GET_PARTITION_META_RESPONSE: raise RuntimeError("Failed to get metadata for clear operation.") - return BatchMeta.from_dict(response_msg.body["metadata"]) if isinstance(response_msg.body["metadata"], dict) else response_msg.body["metadata"] + return ( + BatchMeta.from_dict(response_msg.body["metadata"]) + if isinstance(response_msg.body["metadata"], dict) + else response_msg.body["metadata"] + ) @dynamic_socket(socket_name="request_handle_socket") async def _clear_partition_in_controller(self, partition_id, socket=None): diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 68a93505..af3f3218 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -68,7 +68,7 @@ def from_dict(cls, data: dict) -> "FieldMeta": dtype=data["dtype"], shape=data["shape"], production_status=ProductionStatus(str(data["production_status"])) - if isinstance(data["production_status"], (int, str)) + if isinstance(data["production_status"], int | str) else data["production_status"], ) diff --git a/transfer_queue/storage/simple_backend.py b/transfer_queue/storage/simple_backend.py index 0707a1d0..71399752 100644 --- a/transfer_queue/storage/simple_backend.py +++ b/transfer_queue/storage/simple_backend.py @@ -257,9 +257,7 @@ def _process_put_get(self) -> None: }, ) - self.put_get_socket.send_multipart( - [identity, *response_msg.serialize()], copy=False - ) + self.put_get_socket.send_multipart([identity, *response_msg.serialize()], copy=False) def _handle_put(self, data_parts: ZMQMessage) -> ZMQMessage: """ diff --git a/transfer_queue/utils/serial_utils.py b/transfer_queue/utils/serial_utils.py index a12f48b5..5cd7275a 100644 --- a/transfer_queue/utils/serial_utils.py +++ b/transfer_queue/utils/serial_utils.py @@ -21,7 +21,7 @@ import warnings from collections.abc import Sequence from types import FunctionType -from typing import Any, Optional, TypeAlias +from typing import Any, TypeAlias import cloudpickle import numpy as np @@ -30,7 +30,6 @@ from msgspec import msgpack from tensordict import TensorDictBase - CUSTOM_TYPE_PICKLE = 1 CUSTOM_TYPE_CLOUDPICKLE = 2 CUSTOM_TYPE_TENSOR = 3 # For tensor with buffer reference @@ -58,7 +57,7 @@ def __init__(self): # This is used as a local stash of buffers that we can then access from # our custom `msgspec` hook, `enc_hook`. We don't have a way to # pass custom data to the hook otherwise. - self.aux_buffers: Optional[list[bytestr]] = None + self.aux_buffers: list[bytestr] = [] def encode(self, obj: Any) -> Sequence[bytestr]: try: @@ -70,7 +69,7 @@ def encode(self, obj: Any) -> Sequence[bytestr]: # new buffer. return bufs finally: - self.aux_buffers = None + self.aux_buffers = [] def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]: try: @@ -79,7 +78,7 @@ def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]: self.encoder.encode_into(obj, buf) return bufs finally: - self.aux_buffers = None + self.aux_buffers = [] def enc_hook(self, obj: Any) -> Any: """Custom encoding hook for types msgspec doesn't natively support. @@ -95,7 +94,6 @@ def enc_hook(self, obj: Any) -> Any: # Handle TensorDict explicitly for recursive zero-copy if isinstance(obj, TensorDictBase): return self._encode_tensordict(obj) - # Handle numpy arrays by converting to tensor if isinstance(obj, np.ndarray): @@ -136,7 +134,7 @@ def _encode_tensor(self, obj: torch.Tensor) -> msgpack.Ext: Returns Ext type so decoding goes through ext_hook (which has buffer access). """ - assert self.aux_buffers is not None + assert len(self.aux_buffers) > 0 # Handle nested tensors (strided or jagged) via unbind if obj.is_nested: @@ -167,11 +165,14 @@ def _encode_nested_tensor(self, obj: torch.Tensor) -> msgpack.Ext: def _encode_regular_tensor_meta(self, obj: torch.Tensor) -> tuple: """Encode a regular tensor and return its metadata tuple.""" # Handle non-contiguous tensors + + assert self.aux_buffers is not None + if not obj.is_contiguous(): obj = obj.contiguous() # Handle GPU tensors - if obj.device.type != 'cpu': + if obj.device.type != "cpu": obj = obj.cpu() # Zero-copy buffer extraction via uint8 view @@ -186,11 +187,14 @@ def _encode_regular_tensor_meta(self, obj: torch.Tensor) -> tuple: def _encode_regular_tensor(self, obj: torch.Tensor) -> msgpack.Ext: """Encode a regular (non-nested) tensor with zero-copy.""" # Handle non-contiguous tensors + + assert self.aux_buffers is not None + if not obj.is_contiguous(): obj = obj.contiguous() # Handle GPU tensors - if obj.device.type != 'cpu': + if obj.device.type != "cpu": obj = obj.cpu() if obj.is_sparse: @@ -251,6 +255,7 @@ def _reconstruct_tensordict(self, obj: dict) -> Any: """Reconstruct TensorDict from marked dict structure.""" try: from tensordict import TensorDict + batch_size = obj["batch_size"] data = obj["data"] # Recursively process nested data @@ -305,4 +310,3 @@ def ext_hook(self, code: int, data: memoryview) -> Any: _encoder = MsgpackEncoder() _decoder = MsgpackDecoder() - diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index 543e34b9..3ece8edc 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -24,7 +24,7 @@ import psutil import zmq -from transfer_queue.utils.serial_utils import _encoder, _decoder +from transfer_queue.utils.serial_utils import _decoder, _encoder from transfer_queue.utils.utils import ( ExplicitEnum, TransferQueueRole, From ce984ba749a1623b2c61a0d99505fca1d66b4220 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sat, 24 Jan 2026 16:23:01 +0800 Subject: [PATCH 15/17] fix dataloader pre-fetch Signed-off-by: 0oshowero0 --- .../dataloader/streaming_dataloader.py | 32 +++++++++++-------- .../dataloader/streaming_dataset.py | 2 +- tutorial/05_streaming_dataloader.py | 2 +- 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/transfer_queue/dataloader/streaming_dataloader.py b/transfer_queue/dataloader/streaming_dataloader.py index f8ed1c39..374d5155 100644 --- a/transfer_queue/dataloader/streaming_dataloader.py +++ b/transfer_queue/dataloader/streaming_dataloader.py @@ -15,7 +15,7 @@ import logging import os -from typing import Iterator, Optional +from typing import Optional import torch from tensordict import TensorDict @@ -33,6 +33,16 @@ logger.addHandler(handler) +def _identity_collate_fn(data: tuple[TensorDict, BatchMeta]) -> tuple[TensorDict, BatchMeta]: + """Identity collate function for TransferQueue. + + This function acts as a pass-through, preserving the `(TensorDict, BatchMeta)` + structure yielded by `StreamingDataset`. It prevents PyTorch from attempting + to stack or modify the already-batched data. + """ + return data + + class StreamingDataLoader(torch.utils.data.DataLoader): """StreamingDataLoader interface for TransferQueue. @@ -102,6 +112,13 @@ def __init__( by the StreamingDataset in coordination with RankAwareSampler. """ + if collate_fn is None: + # use identical collate function to directly return the self-defined + # [TensorDict, BatchMeta] output of StreamingDataset + final_collate_fn = _identity_collate_fn + else: + final_collate_fn = collate_fn + super().__init__( dataset=dataset, batch_size=None, # Batch size is handled by the dataset @@ -109,7 +126,7 @@ def __init__( sampler=None, batch_sampler=None, num_workers=num_workers, - collate_fn=collate_fn, + collate_fn=final_collate_fn, pin_memory=pin_memory, drop_last=False, timeout=0, @@ -120,14 +137,3 @@ def __init__( persistent_workers=persistent_workers, pin_memory_device=pin_memory_device, ) - - def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]: - """Iterate over the dataset, yielding batches with metadata. - - Yields: - Tuple[TensorDict, BatchMeta]: A tuple containing: - - TensorDict: Batch of data with the requested fields. - - BatchMeta: Corresponding metadata to interact with TransferQueue. - """ - # Directly iterate over the dataset, which yields (batch, batch_meta) tuples - return iter(self.dataset) diff --git a/transfer_queue/dataloader/streaming_dataset.py b/transfer_queue/dataloader/streaming_dataset.py index 1a99950c..510a497a 100644 --- a/transfer_queue/dataloader/streaming_dataset.py +++ b/transfer_queue/dataloader/streaming_dataset.py @@ -109,7 +109,7 @@ def __init__( raise ValueError(f"micro_batch_size must be >= 1, got {micro_batch_size}") if len(required_fields) < 1: - raise ValueError(f"required_fields must be a list of more than one field name, got {self.required_fields}") + raise ValueError(f"required_fields must be a list with at least one field name, got {required_fields}") if data_replica_world_size < 1: raise ValueError(f"data_replica_world_size {data_replica_world_size} must >= 1") diff --git a/tutorial/05_streaming_dataloader.py b/tutorial/05_streaming_dataloader.py index cd3ef4c4..d532c80e 100644 --- a/tutorial/05_streaming_dataloader.py +++ b/tutorial/05_streaming_dataloader.py @@ -130,7 +130,7 @@ def generate_worker(rank_id: int, config: DictConfig, num_samples: int = 20): Args: rank_id: Unique identifier for this generator (used for sample indexing) config: TransferQueue configuration - num_samples: Number of samples to generate (default: 10) + num_samples: Number of samples to generate Note: Each sample has a unique sequence ID calculated as: seq_id = i + (rank_id * 10000) From dc448cba469e19213163b17f79cf2320d666a640 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sat, 24 Jan 2026 16:25:55 +0800 Subject: [PATCH 16/17] delete TQ_ZERO_COPY_SERIALIZATION config Signed-off-by: 0oshowero0 --- .github/workflows/python-package.yml | 7 +------ scripts/put_benchmark.py | 7 +------ 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index d51ffeeb..14339bb8 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -37,11 +37,6 @@ jobs: flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - - name: Test with pytest (TQ_ZERO_COPY_SERIALIZATION=False) + - name: Test with pytest run: | - pytest - - name: Test with pytest (TQ_ZERO_COPY_SERIALIZATION=True) - run: | - ray stop --force - export TQ_ZERO_COPY_SERIALIZATION=True pytest \ No newline at end of file diff --git a/scripts/put_benchmark.py b/scripts/put_benchmark.py index 1de38893..96bad6ee 100644 --- a/scripts/put_benchmark.py +++ b/scripts/put_benchmark.py @@ -458,10 +458,5 @@ def main(): if __name__ == "__main__": - try: - from transfer_queue.utils import serial_utils - - print(f"[Startup Check] TQ_ZERO_COPY_SERIALIZATION = {serial_utils.TQ_ZERO_COPY_SERIALIZATION}") - except ImportError: - print("[Startup Check] Could not import serial_utils to check flag") + print("[Startup Check]") main() From c005fb574667502682036c019217e900dbbde215 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sat, 24 Jan 2026 16:35:37 +0800 Subject: [PATCH 17/17] fix Signed-off-by: 0oshowero0 --- transfer_queue/utils/serial_utils.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/transfer_queue/utils/serial_utils.py b/transfer_queue/utils/serial_utils.py index 5cd7275a..cd308a9d 100644 --- a/transfer_queue/utils/serial_utils.py +++ b/transfer_queue/utils/serial_utils.py @@ -166,8 +166,6 @@ def _encode_regular_tensor_meta(self, obj: torch.Tensor) -> tuple: """Encode a regular tensor and return its metadata tuple.""" # Handle non-contiguous tensors - assert self.aux_buffers is not None - if not obj.is_contiguous(): obj = obj.contiguous() @@ -188,8 +186,6 @@ def _encode_regular_tensor(self, obj: torch.Tensor) -> msgpack.Ext: """Encode a regular (non-nested) tensor with zero-copy.""" # Handle non-contiguous tensors - assert self.aux_buffers is not None - if not obj.is_contiguous(): obj = obj.contiguous()