From 145fea3a4deee26f8228df8ed8bb2ddf13be09e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=AE=81=E6=9C=AC=E5=93=B2?= Date: Mon, 2 Feb 2026 13:29:20 +0000 Subject: [PATCH] [StreamingDataLoader, 5/N] feat: Refactor the StreamDataLoader implementation to support fully asynchronous mode MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 宁本哲 --- README.md | 1 - recipe/simple_use_case/async_demo.py | 1 - recipe/simple_use_case/sync_demo.py | 1 - tests/test_samplers.py | 415 ++++++++---------- transfer_queue/client.py | 5 +- transfer_queue/controller.py | 10 +- .../dataloader/streaming_dataloader.py | 30 ++ .../dataloader/streaming_dataset.py | 230 +++++++--- transfer_queue/sampler/base.py | 13 + .../sampler/grpo_group_n_sampler.py | 74 +++- transfer_queue/sampler/rank_aware_sampler.py | 40 +- tutorial/05_streaming_dataloader.py | 49 +-- 12 files changed, 485 insertions(+), 384 deletions(-) diff --git a/README.md b/README.md index 7f46346b..b69176d8 100644 --- a/README.md +++ b/README.md @@ -247,7 +247,6 @@ batch_meta = client.get_meta( batch_size=8, partition_id="train_0", task_name="generate_sequences", - sampling_config={"n_samples_per_prompt": 4} # Put the required sampling parameters here ) ``` diff --git a/recipe/simple_use_case/async_demo.py b/recipe/simple_use_case/async_demo.py index a1431f1e..2bad38a2 100644 --- a/recipe/simple_use_case/async_demo.py +++ b/recipe/simple_use_case/async_demo.py @@ -226,7 +226,6 @@ def _initialize_data_system(self): # self.data_system_controller = TransferQueueController.remote(sampler=grpo_sampler) # Then use sampling_config in get_meta calls: - # sampling_config={"n_samples_per_prompt": 4} self.data_system_controller = TransferQueueController.remote() logger.info("TransferQueueController has been created.") diff --git a/recipe/simple_use_case/sync_demo.py b/recipe/simple_use_case/sync_demo.py index a4bf04d6..e6513274 100644 --- a/recipe/simple_use_case/sync_demo.py +++ b/recipe/simple_use_case/sync_demo.py @@ -68,7 +68,6 @@ def initialize_data_system(config): # data_system_controller = TransferQueueController.remote(sampler=grpo_sampler) # Then use sampling_config in get_meta calls: - # sampling_config={"n_samples_per_prompt": 4} data_system_controller = TransferQueueController.remote() logger.info("TransferQueueController has been created.") diff --git a/tests/test_samplers.py b/tests/test_samplers.py index d184bba8..da7aae1a 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -192,12 +192,11 @@ def test_grpo_sampler_initialization(self): def test_grpo_sampler_basic_functionality(self): """Test basic grouped sampling functionality.""" - sampler = GRPOGroupNSampler() + sampler = GRPOGroupNSampler(n_samples_per_prompt=4) ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7] # 8 indexes batch_size = 8 - n_samples_per_prompt = 4 # 2 groups of 4 - sampled, consumed = sampler.sample(ready_indexes, batch_size, n_samples_per_prompt) + sampled, consumed = sampler.sample(ready_indexes, batch_size) assert sampled == [0, 1, 2, 3, 4, 5, 6, 7] assert consumed == [0, 1, 2, 3, 4, 5, 6, 7] @@ -206,77 +205,58 @@ def test_grpo_sampler_basic_functionality(self): def test_grpo_sampler_partial_batch(self): """Test partial batch sampling.""" - sampler = GRPOGroupNSampler() + sampler = GRPOGroupNSampler(n_samples_per_prompt=4) ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] # 12 indexes batch_size = 8 # Want 8 samples total - n_samples_per_prompt = 4 # 2 groups of 4 + # 2 groups of 4 - sampled, consumed = sampler.sample(ready_indexes, batch_size, n_samples_per_prompt) + sampled, consumed = sampler.sample(ready_indexes, batch_size) assert sampled == [0, 1, 2, 3, 4, 5, 6, 7] assert consumed == [0, 1, 2, 3, 4, 5, 6, 7] assert len(sampled) == batch_size assert len(consumed) == batch_size - def test_grpo_sampler_different_group_sizes(self): - """Test different n_samples_per_prompt values.""" - sampler = GRPOGroupNSampler() - ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] - - # Test with 2 samples per prompt (8 groups) - sampled, consumed = sampler.sample(ready_indexes, 8, n_samples_per_prompt=2) - assert sampled == [0, 1, 2, 3, 4, 5, 6, 7] - assert consumed == [0, 1, 2, 3, 4, 5, 6, 7] - - # Test with 8 samples per prompt (2 groups) - sampled, consumed = sampler.sample(ready_indexes, 8, n_samples_per_prompt=8) - assert sampled == [0, 1, 2, 3, 4, 5, 6, 7] - assert consumed == [0, 1, 2, 3, 4, 5, 6, 7] - def test_grpo_sampler_batch_size_divisibility(self): """Test that batch_size must be divisible by n_samples_per_prompt.""" - sampler = GRPOGroupNSampler() + sampler = GRPOGroupNSampler(n_samples_per_prompt=4) ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7] # 8 indexes, sufficient for batch_size=7 batch_size = 7 - n_samples_per_prompt = 4 with pytest.raises(ValueError) as exc_info: - sampler.sample(ready_indexes, batch_size, n_samples_per_prompt) + sampler.sample(ready_indexes, batch_size) assert "must be a multiple of n_samples_per_prompt" in str(exc_info.value) def test_grpo_sampler_insufficient_ready_indexes(self): """Test behavior when not enough ready indexes are available.""" - sampler = GRPOGroupNSampler() + sampler = GRPOGroupNSampler(n_samples_per_prompt=4) ready_indexes = [0, 1, 2, 3] # Only 4 indexes, but need 8 for 2 groups of 4 batch_size = 8 - n_samples_per_prompt = 4 # Should return empty lists when insufficient complete groups - sampled, consumed = sampler.sample(ready_indexes, batch_size, n_samples_per_prompt) + sampled, consumed = sampler.sample(ready_indexes, batch_size) assert sampled == [] assert consumed == [] def test_grpo_sampler_exact_multiple_available(self): """Test when ready_indexes length is exactly a multiple of n_samples_per_prompt.""" - sampler = GRPOGroupNSampler() + sampler = GRPOGroupNSampler(n_samples_per_prompt=4) ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7] # 8 indexes batch_size = 8 - n_samples_per_prompt = 4 - sampled, consumed = sampler.sample(ready_indexes, batch_size, n_samples_per_prompt) + sampled, consumed = sampler.sample(ready_indexes, batch_size) assert sampled == [0, 1, 2, 3, 4, 5, 6, 7] assert consumed == [0, 1, 2, 3, 4, 5, 6, 7] def test_grpo_sampler_zero_batch_size(self): """Test behavior with zero batch size.""" - sampler = GRPOGroupNSampler() + sampler = GRPOGroupNSampler(n_samples_per_prompt=2) ready_indexes = [0, 1, 2, 3] batch_size = 0 - n_samples_per_prompt = 2 - sampled, consumed = sampler.sample(ready_indexes, batch_size, n_samples_per_prompt) + sampled, consumed = sampler.sample(ready_indexes, batch_size) assert sampled == [] assert consumed == [] @@ -286,73 +266,53 @@ def test_grpo_sampler_single_sample_per_prompt(self): sampler = GRPOGroupNSampler() ready_indexes = [0, 1, 2, 3, 4, 5] batch_size = 3 - n_samples_per_prompt = 1 - sampled, consumed = sampler.sample(ready_indexes, batch_size, n_samples_per_prompt) + sampled, consumed = sampler.sample(ready_indexes, batch_size) assert sampled == [0, 1, 2] assert consumed == [0, 1, 2] def test_grpo_sampler_large_group_size(self): """Test with large n_samples_per_prompt.""" - sampler = GRPOGroupNSampler() + sampler = GRPOGroupNSampler(n_samples_per_prompt=10) ready_indexes = list(range(20)) # 20 indexes batch_size = 20 - n_samples_per_prompt = 10 - sampled, consumed = sampler.sample(ready_indexes, batch_size, n_samples_per_prompt) + sampled, consumed = sampler.sample(ready_indexes, batch_size) assert sampled == list(range(20)) assert consumed == list(range(20)) def test_grpo_sampler_call_method(self): """Test that __call__ method works correctly.""" - sampler = GRPOGroupNSampler() + sampler = GRPOGroupNSampler(n_samples_per_prompt=2) ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7] batch_size = 4 - n_samples_per_prompt = 2 - sampled, consumed = sampler(ready_indexes, batch_size, n_samples_per_prompt=n_samples_per_prompt) + sampled, consumed = sampler(ready_indexes, batch_size) assert sampled == [0, 1, 2, 3] assert consumed == [0, 1, 2, 3] - def test_grpo_sampler_parameter_order_independence(self): - """Test that parameter order doesn't matter when using kwargs.""" - sampler = GRPOGroupNSampler() - ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7] - - # Try different parameter orders - sampled1, consumed1 = sampler.sample(n_samples_per_prompt=4, batch_size=8, ready_indexes=ready_indexes) - - sampled2, consumed2 = sampler.sample(batch_size=8, ready_indexes=ready_indexes, n_samples_per_prompt=4) - - assert sampled1 == sampled2 - assert consumed1 == consumed2 - def test_grpo_sampler_with_extra_kwargs(self): """Test that GRPOGroupNSampler accepts extra kwargs but ignores them.""" - sampler = GRPOGroupNSampler() + sampler = GRPOGroupNSampler(n_samples_per_prompt=4) ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7] batch_size = 8 - n_samples_per_prompt = 4 # GRPOGroupNSampler should accept extra kwargs but ignore them - sampled, consumed = sampler.sample( - ready_indexes, batch_size, n_samples_per_prompt, extra_param="ignored", another_param=42 - ) + sampled, consumed = sampler.sample(ready_indexes, batch_size, extra_param="ignored", another_param=42) assert sampled == [0, 1, 2, 3, 4, 5, 6, 7] assert consumed == [0, 1, 2, 3, 4, 5, 6, 7] def test_grpo_sampler_non_sequential_indexes(self): """Test with non-sequential ready indexes that get sorted.""" - sampler = GRPOGroupNSampler() + sampler = GRPOGroupNSampler(n_samples_per_prompt=4) ready_indexes = [3, 4, 5, 6, 9, 10, 11, 12] # Non-sequential order but has consecutive groups after sorting batch_size = 8 - n_samples_per_prompt = 4 - sampled, consumed = sampler.sample(ready_indexes, batch_size, n_samples_per_prompt) + sampled, consumed = sampler.sample(ready_indexes, batch_size) # Should find consecutive groups after sorting: [3,4,5,6] and [9,10,11,12] expected = [3, 4, 5, 6, 9, 10, 11, 12] @@ -361,52 +321,44 @@ def test_grpo_sampler_non_sequential_indexes(self): def test_grpo_sampler_invalid_n_samples_per_prompt(self): """Test behavior with invalid n_samples_per_prompt values.""" - sampler = GRPOGroupNSampler() - ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7] - batch_size = 8 - # Test zero n_samples_per_prompt with pytest.raises(ValueError) as exc_info: - sampler.sample(ready_indexes, batch_size, n_samples_per_prompt=0) + GRPOGroupNSampler(n_samples_per_prompt=0) assert "must be positive" in str(exc_info.value) - # Test negative n_samples_per_prompt with pytest.raises(ValueError) as exc_info: - sampler.sample(ready_indexes, batch_size, n_samples_per_prompt=-2) + GRPOGroupNSampler(n_samples_per_prompt=-2) assert "must be positive" in str(exc_info.value) def test_grpo_sampler_no_complete_groups(self): """Test behavior when no complete groups are available.""" - sampler = GRPOGroupNSampler() + sampler = GRPOGroupNSampler(n_samples_per_prompt=3) ready_indexes = [0, 1, 3, 4, 6, 7] # No consecutive groups of size 3 batch_size = 6 - n_samples_per_prompt = 3 # Should return empty lists when no complete groups found - sampled, consumed = sampler.sample(ready_indexes, batch_size, n_samples_per_prompt) + sampled, consumed = sampler.sample(ready_indexes, batch_size) assert sampled == [] assert consumed == [] def test_grpo_sampler_mixed_groups(self): """Test behavior with mixed complete and incomplete groups.""" - sampler = GRPOGroupNSampler() + sampler = GRPOGroupNSampler(n_samples_per_prompt=3) ready_indexes = [0, 1, 3, 4, 5, 6, 7, 9, 10, 11] # Mixed groups batch_size = 6 - n_samples_per_prompt = 3 # Should find the complete groups [3,4,5] and [9,10,11] - sampled, consumed = sampler.sample(ready_indexes, batch_size, n_samples_per_prompt) + sampled, consumed = sampler.sample(ready_indexes, batch_size) assert sampled == [3, 4, 5, 9, 10, 11] assert consumed == [3, 4, 5, 9, 10, 11] def test_grpo_sampler_sorting_functionality(self): """Test that ready_indexes are properly sorted before group detection.""" - sampler = GRPOGroupNSampler() + sampler = GRPOGroupNSampler(n_samples_per_prompt=4) ready_indexes = [10, 11, 12, 5, 6, 7, 8, 9] # Out of order but contains consecutive groups batch_size = 8 - n_samples_per_prompt = 4 - sampled, consumed = sampler.sample(ready_indexes, batch_size, n_samples_per_prompt) + sampled, consumed = sampler.sample(ready_indexes, batch_size) # After sorting: [5,6,7,8,9,10,11,12], should find [5,6,7,8] and [9,10,11,12] expected = [5, 6, 7, 8, 9, 10, 11, 12] @@ -415,19 +367,18 @@ def test_grpo_sampler_sorting_functionality(self): def test_grpo_sampler_insufficient_groups(self): """Test behavior when requesting more groups than available.""" - sampler = GRPOGroupNSampler() + sampler = GRPOGroupNSampler(n_samples_per_prompt=4) ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] # 4 groups of 4 batch_size = 12 # Requesting 3 groups of 4 - this should work - n_samples_per_prompt = 4 # This should actually work fine since we have 4 groups and request 3 - sampled, consumed = sampler.sample(ready_indexes, batch_size, n_samples_per_prompt) + sampled, consumed = sampler.sample(ready_indexes, batch_size) assert len(sampled) == 12 assert len(consumed) == 12 # Now test requesting more than available batch_size = 20 # Requesting 5 groups of 4, but only have 4 - sampled, consumed = sampler.sample(ready_indexes, batch_size, n_samples_per_prompt) + sampled, consumed = sampler.sample(ready_indexes, batch_size) # Should return empty lists when requesting more complete groups than available assert sampled == [] @@ -444,20 +395,17 @@ def test_rank_aware_sampler_initialization(self): assert hasattr(sampler, "_states") assert sampler._states == {} - def test_rank_aware_sampler_first_rank_sampling(self): - """Test that first rank in data replica group performs actual sampling.""" + def test_rank_aware_sampler_basic_sampling(self): + """Test basic sampling functionality.""" sampler = RankAwareSampler() ready_indexes = [0, 1, 2, 3, 4, 5] batch_size = 3 - # Rank 0 (first in group) samples and caches for all ranks - # Since rank 1 will call next, state is kept until rank 1 fetches sampled, consumed = sampler.sample( ready_indexes, batch_size, - data_replica_group=0, - data_replica_rank=0, - data_replica_world_size=2, + dp_rank=0, + batch_index=0, task_name="task", partition_id="test", ) @@ -465,142 +413,97 @@ def test_rank_aware_sampler_first_rank_sampling(self): assert sampled == [0, 1, 2] assert consumed == [0, 1, 2] assert len(sampled) == batch_size - # State is kept for other ranks to fetch - def test_rank_aware_sampler_second_rank_gets_cached(self): - """Test that second rank in data replica group gets cached indices.""" + def test_rank_aware_sampler_caching_on_same_batch_index(self): + """Test that same batch_index returns cached results.""" sampler = RankAwareSampler() ready_indexes = [0, 1, 2, 3, 4, 5] batch_size = 3 - # Rank 0 (first in group) samples first + # First call with batch_index=0 sampled1, consumed1 = sampler.sample( ready_indexes, batch_size, - data_replica_group=0, - data_replica_rank=0, - data_replica_world_size=2, + dp_rank=0, + batch_index=0, task_name="task", partition_id="test", ) - # Rank 1 (second in group) should get same cached indices + # Second call with same batch_index=0 should return cached result sampled2, consumed2 = sampler.sample( ready_indexes, batch_size, - data_replica_group=0, - data_replica_rank=1, - data_replica_world_size=2, + dp_rank=0, + batch_index=0, task_name="task", partition_id="test", ) assert sampled1 == sampled2 == [0, 1, 2] - assert consumed1 == [0, 1, 2] - assert consumed2 == [0, 1, 2] - - # cache should be empty after all ranks fetch - assert len(sampler._states["test"]["task"][0][0]) == 0 - assert len(sampler._states["test"]["task"][0][1]) == 0 + assert consumed1 == consumed2 == [0, 1, 2] - def test_rank_aware_sampler_multiple_dp_groups(self): - """Test that multiple data replica groups work independently.""" + def test_rank_aware_sampler_different_batch_indexes(self): + """Test that different batch_index values sample different data.""" sampler = RankAwareSampler() ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7] batch_size = 2 - data_replica_world_size = 2 # Each group has 2 ranks - # data replica group 0: rank 0 samples first - sampled0_g0, consumed0_g0 = sampler.sample( + # First batch + sampled1, consumed1 = sampler.sample( ready_indexes, batch_size, - data_replica_group=0, - data_replica_rank=0, - data_replica_world_size=data_replica_world_size, + dp_rank=0, + batch_index=0, task_name="task", partition_id="test", ) - # mimic the consumption status update managed in TransferQueueController - ready_indexes = [i for i in ready_indexes if i not in consumed0_g0] - # data replica group 1: rank 0 samples first - sampled0_g1, consumed0_g1 = sampler.sample( + # Second batch + ready_indexes = [2, 3, 4, 5, 6, 7] + sampled2, consumed2 = sampler.sample( ready_indexes, batch_size, - data_replica_group=1, - data_replica_rank=0, - data_replica_world_size=data_replica_world_size, + dp_rank=0, + batch_index=1, task_name="task", partition_id="test", ) - ready_indexes = [i for i in ready_indexes if i not in consumed0_g1] - - # Both should have sampled their first batch - assert sampled0_g0 == [0, 1] - assert sampled0_g1 == [2, 3] - assert consumed0_g0 == [0, 1] - assert consumed0_g1 == [2, 3] - # data replica group 0: rank 1 fetches cached - sampled1_g0, consumed1_g0 = sampler.sample( - ready_indexes, - batch_size, - data_replica_group=0, - data_replica_rank=1, - data_replica_world_size=data_replica_world_size, - task_name="task", - partition_id="test", - ) - ready_indexes = [i for i in ready_indexes if i not in consumed1_g0] - assert sampled1_g0 == [0, 1] - assert consumed1_g0 == [0, 1] + assert sampled1 == [0, 1] + assert sampled2 == [2, 3] + assert consumed1 == [0, 1] + assert consumed2 == [2, 3] - # data replica group 1: rank 1 fetches cached - sampled1_g1, consumed1_g1 = sampler.sample( - ready_indexes, - batch_size, - data_replica_group=1, - data_replica_rank=1, - data_replica_world_size=data_replica_world_size, - task_name="task", - partition_id="test", - ) - ready_indexes = [i for i in ready_indexes if i not in consumed1_g1] - assert sampled1_g1 == [2, 3] - assert consumed1_g1 == [2, 3] + def test_rank_aware_sampler_multiple_dp_ranks(self): + """Test that same dp_ranks reuse state cache.""" + sampler = RankAwareSampler() + ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7] + batch_size = 2 - # data replica group 0: rank 0 fetches again, this should return new data - sampled2_g0, consumed2_g0 = sampler.sample( + # DP rank 0 samples batch 0 + sampled_dp0_b0, consumed_dp0_b0 = sampler.sample( ready_indexes, batch_size, - data_replica_group=0, - data_replica_rank=0, - data_replica_world_size=data_replica_world_size, + dp_rank=0, + batch_index=0, task_name="task", partition_id="test", ) - ready_indexes = [i for i in ready_indexes if i not in consumed2_g0] - assert sampled2_g0 == [4, 5] - assert consumed2_g0 == [4, 5] - - # data replica group 0: rank 1 fetches cached - sampled3_g0, consumed3_g0 = sampler.sample( + ready_indexes = [2, 3, 4, 5, 6, 7] + # DP rank 0 samples batch 0 (should get same result as dp_rank=0) + sampled_dp1_b0, consumed_dp1_b0 = sampler.sample( ready_indexes, batch_size, - data_replica_group=0, - data_replica_rank=1, - data_replica_world_size=data_replica_world_size, + dp_rank=0, + batch_index=0, task_name="task", partition_id="test", ) - assert sampled3_g0 == [4, 5] - assert consumed3_g0 == [4, 5] - # examine the internal state to ensure proper caching and clearing - assert len(sampler._states["test"]["task"][0][0]) == 0 - assert len(sampler._states["test"]["task"][0][1]) == 0 - assert len(sampler._states["test"]["task"][1][0]) == 0 - assert len(sampler._states["test"]["task"][1][1]) == 0 + # Both should sample from the same ready_indexes + assert sampled_dp0_b0 == [0, 1] + assert sampled_dp1_b0 == [0, 1] def test_rank_aware_sampler_empty_ready_indexes(self): """Test behavior with empty ready indexes.""" @@ -611,9 +514,8 @@ def test_rank_aware_sampler_empty_ready_indexes(self): sampled, consumed = sampler.sample( ready_indexes, batch_size, - data_replica_group=0, - data_replica_rank=0, - data_replica_world_size=2, + dp_rank=0, + batch_index=0, task_name="task", partition_id="test", ) @@ -630,9 +532,8 @@ def test_rank_aware_sampler_batch_size_larger_than_ready(self): sampled, consumed = sampler.sample( ready_indexes, batch_size, - data_replica_group=0, - data_replica_rank=0, - data_replica_world_size=2, + dp_rank=0, + batch_index=0, task_name="task", partition_id="test", ) @@ -649,9 +550,8 @@ def test_rank_aware_sampler_zero_batch_size(self): sampled, consumed = sampler.sample( ready_indexes, batch_size, - data_replica_group=0, - data_replica_rank=0, - data_replica_world_size=2, + dp_rank=0, + batch_index=0, task_name="task", partition_id="test", ) @@ -659,98 +559,129 @@ def test_rank_aware_sampler_zero_batch_size(self): assert sampled == [] assert consumed == [] - def test_rank_aware_sampler_data_prefetch(self): - """Test behavior with data prefetch.""" + def test_rank_aware_sampler_multiple_tasks(self): + """Test behavior with multiple tasks.""" sampler = RankAwareSampler() ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7] batch_size = 2 - sampled_rank0_time0, consumed_rank0_time0 = sampler.sample( + sampled_task0, consumed_task0 = sampler.sample( ready_indexes, batch_size, - data_replica_group=0, - data_replica_rank=0, - data_replica_world_size=2, - task_name="task", + dp_rank=0, + batch_index=0, + task_name="task0", partition_id="test", ) - assert sampled_rank0_time0 == [0, 1] - assert consumed_rank0_time0 == [0, 1] - assert sampler._states["test"]["task"][0][0] == [] - assert sampler._states["test"]["task"][0][1] == [[0, 1]] - - ready_indexes = [i for i in ready_indexes if i not in consumed_rank0_time0] - - sampled_rank0_time1, consumed_rank0_time1 = sampler.sample( + sampled_task1, consumed_task1 = sampler.sample( ready_indexes, batch_size, - data_replica_group=0, - data_replica_rank=0, - data_replica_world_size=2, - task_name="task", + dp_rank=0, + batch_index=0, + task_name="task1", partition_id="test", ) - assert sampled_rank0_time1 == [2, 3] - assert consumed_rank0_time1 == [2, 3] - assert sampler._states["test"]["task"][0][0] == [] - assert sampler._states["test"]["task"][0][1] == [[0, 1], [2, 3]] + assert sampled_task0 == [0, 1] + assert consumed_task0 == [0, 1] + assert sampled_task1 == [0, 1] + assert consumed_task1 == [0, 1] - ready_indexes = [i for i in ready_indexes if i not in consumed_rank0_time1] + # Check that state is separate per task + assert sampler._states["test"]["task0"][0][0] == [0, 1] + assert sampler._states["test"]["task1"][0][0] == [0, 1] - sampled_rank1_time0, consumed_rank1_time0 = sampler.sample( + def test_rank_aware_sampler_multiple_partitions(self): + """Test behavior with multiple partitions.""" + sampler = RankAwareSampler() + ready_indexes = [0, 1, 2, 3, 4, 5] + batch_size = 2 + + sampled_part0, consumed_part0 = sampler.sample( ready_indexes, batch_size, - data_replica_group=0, - data_replica_rank=1, - data_replica_world_size=2, + dp_rank=0, + batch_index=0, task_name="task", - partition_id="test", + partition_id="partition0", ) - assert sampled_rank1_time0 == [0, 1] - assert consumed_rank1_time0 == [0, 1] - assert sampler._states["test"]["task"][0][0] == [] - assert sampler._states["test"]["task"][0][1] == [[2, 3]] + sampled_part1, consumed_part1 = sampler.sample( + ready_indexes, + batch_size, + dp_rank=0, + batch_index=0, + task_name="task", + partition_id="partition1", + ) - def test_rank_aware_sampler_multiple_tasks(self): - """Test behavior with multiple tasks.""" + assert sampled_part0 == [0, 1] + assert consumed_part0 == [0, 1] + assert sampled_part1 == [0, 1] + assert consumed_part1 == [0, 1] + + # Check that state is separate per partition + assert sampler._states["partition0"]["task"][0][0] == [0, 1] + assert sampler._states["partition1"]["task"][0][0] == [0, 1] + + def test_rank_aware_sampler_invalid_dp_rank(self): + """Test behavior with invalid dp_rank.""" sampler = RankAwareSampler() - ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7] + ready_indexes = [0, 1, 2, 3] batch_size = 2 - sampled_rank0_task0, consumed_rank0_task0 = sampler.sample( + with pytest.raises(ValueError) as exc_info: + sampler.sample( + ready_indexes, + batch_size, + dp_rank=-1, + batch_index=0, + task_name="task", + partition_id="test", + ) + + assert "dp_rank" in str(exc_info.value) + assert "greater than or equal to 0" in str(exc_info.value) + + def test_rank_aware_sampler_with_extra_kwargs(self): + """Test that RankAwareSampler accepts extra kwargs but ignores them.""" + sampler = RankAwareSampler() + ready_indexes = [0, 1, 2, 3, 4, 5] + batch_size = 2 + + # Should accept extra kwargs gracefully + sampled, consumed = sampler.sample( ready_indexes, batch_size, - data_replica_group=0, - data_replica_rank=0, - data_replica_world_size=2, - task_name="task0", + dp_rank=0, + batch_index=0, + task_name="task", partition_id="test", + extra_param="ignored", + another_param=42, ) - assert sampled_rank0_task0 == [0, 1] - assert consumed_rank0_task0 == [0, 1] - assert sampler._states["test"]["task0"][0][0] == [] - assert sampler._states["test"]["task0"][0][1] == [[0, 1]] + assert sampled == [0, 1] + assert consumed == [0, 1] - sampled_rank0_task1, consumed_rank0_task1 = sampler.sample( + def test_rank_aware_sampler_call_method(self): + """Test that __call__ method works correctly.""" + sampler = RankAwareSampler() + ready_indexes = [0, 1, 2, 3] + batch_size = 2 + + sampled, consumed = sampler( ready_indexes, batch_size, - data_replica_group=0, - data_replica_rank=0, - data_replica_world_size=2, - task_name="task1", + dp_rank=0, + batch_index=0, + task_name="task", partition_id="test", ) - assert sampled_rank0_task1 == [0, 1] - assert consumed_rank0_task1 == [0, 1] - assert sampler._states["test"]["task0"][0][0] == [] - assert sampler._states["test"]["task0"][0][1] == [[0, 1]] - assert sampler._states["test"]["task1"][0][0] == [] - assert sampler._states["test"]["task1"][0][1] == [[0, 1]] + assert sampled == [0, 1] + assert consumed == [0, 1] class TestSamplerIntegration: @@ -772,7 +703,7 @@ def test_samplers_implement_base_interface(self): def test_samplers_return_consistent_types(self): """Test that all samplers return consistent tuple types.""" - samplers = [(SequentialSampler(), {}), (GRPOGroupNSampler(), {"n_samples_per_prompt": 2})] + samplers = [(SequentialSampler(), {}), (GRPOGroupNSampler(n_samples_per_prompt=2), {})] ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7] batch_size = 4 @@ -792,7 +723,7 @@ def test_samplers_return_consistent_types(self): def test_samplers_handle_edge_cases_consistently(self): """Test that samplers handle edge cases consistently.""" - samplers = [(SequentialSampler(), {}), (GRPOGroupNSampler(), {"n_samples_per_prompt": 2})] + samplers = [(SequentialSampler(), {}), (GRPOGroupNSampler(n_samples_per_prompt=2), {})] # Test empty ready indexes for sampler, kwargs in samplers: diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 1044d89e..cdce3392 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -179,7 +179,6 @@ async def async_get_meta( - 'insert': Internal usage - should not be used by users task_name: Optional task name associated with the request sampling_config: Optional sampling configuration for custom samplers. - For GRPOGroupNSampler, should include "n_samples_per_prompt": int socket: ZMQ async socket for message transmission (injected by decorator) Returns: @@ -206,7 +205,6 @@ async def async_get_meta( ... partition_id="train_0", ... mode="fetch", ... task_name="generate_sequences", - ... sampling_config={"n_samples_per_prompt": 4} ... )) >>> print(batch_meta.is_ready) # True if all samples ready >>> @@ -698,7 +696,7 @@ async def async_check_consumption_status( partition_id=partition_id, ) - if consumption_status is None: + if consumption_status is None or consumption_status.numel() == 0: return False return torch.all(consumption_status == 1).item() @@ -883,7 +881,6 @@ def get_meta( partition_id: Target data partition id task_name: Optional task name associated with the request sampling_config: Optional sampling configuration for custom samplers. - For GRPOGroupNSampler, should include "n_samples_per_prompt": int Returns: BatchMeta: Batch metadata containing data location information diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 83d5b298..39c890c8 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -781,7 +781,7 @@ def __init__( - If a BaseSampler subclass is provided, it will be instantiated - Defaults to SequentialSampler for simple sequential sampling - Example: sampler=GRPOGroupNSampler() (instance) - - Example: sampler=GRPOGroupNSampler (class) + - Example: sampler=SequentialSampler (class) polling_mode: Whether to use polling mode for TransferQueue controller. - If False, the controller will raise an error when no enough data is available. - If True, the controller will return an empty BatchMeta when no enough data is available. @@ -1015,12 +1015,12 @@ def get_metadata( Raises: TimeoutError: If waiting for sufficient data times out in fetch mode """ - if partition_id not in self.partitions: - self.create_partition(partition_id) if mode == "insert": - partition = self._get_partition(partition_id) + if partition_id not in self.partitions: + self.create_partition(partition_id) + partition = self._get_partition(partition_id) if data_fields: # This is called during put_data call without providing metadata. # try to use pre-allocated global index first @@ -1083,6 +1083,7 @@ def get_metadata( ready_for_consume_indexes, batch_size, **(sampling_config or {}), + **kwargs, ) # Check if we got valid results from the sampler @@ -1240,6 +1241,7 @@ def clear_partition(self, partition_id: str, clear_consumption: bool = True): partition.clear_data(global_indexes_range, clear_consumption) self.index_manager.release_partition(partition_id) self.partitions.pop(partition_id) + self.sampler.clear_cache(partition_id) def clear_meta( self, diff --git a/transfer_queue/dataloader/streaming_dataloader.py b/transfer_queue/dataloader/streaming_dataloader.py index 374d5155..20a3b1d3 100644 --- a/transfer_queue/dataloader/streaming_dataloader.py +++ b/transfer_queue/dataloader/streaming_dataloader.py @@ -111,6 +111,7 @@ def __init__( parameter in PyTorch DataLoader is set to None because batching is managed by the StreamingDataset in coordination with RankAwareSampler. """ + self.dataset: StreamingDataset = dataset if collate_fn is None: # use identical collate function to directly return the self-defined @@ -137,3 +138,32 @@ def __init__( persistent_workers=persistent_workers, pin_memory_device=pin_memory_device, ) + + def reset(self): + """Reset the dataset iterator to the beginning. + + Clears the buffer and resets the batch index for a fresh iteration. + """ + self.dataset.reset() + + def step(self, partition_id): + """Switch to a new partition and reset the dataset state. + + This method clears the buffer, resets the batch index, and updates the partition_id + to fetch data from a different partition (e.g., switching from "train" to "val"). + + Args: + partition_id: The new partition ID to switch to. + """ + self.dataset.step(partition_id) + + def get_buffer(self): + """Get the current buffer from the underlying dataset. + + Returns the batch buffer maintained by StreamingDataset, which stores + pre-fetched batches for efficient data access. + + Returns: + list: Buffer containing pre-fetched (TensorDict, BatchMeta) tuples. + """ + return self.dataset.buffer diff --git a/transfer_queue/dataloader/streaming_dataset.py b/transfer_queue/dataloader/streaming_dataset.py index ea1487f2..4b54fb3a 100644 --- a/transfer_queue/dataloader/streaming_dataset.py +++ b/transfer_queue/dataloader/streaming_dataset.py @@ -17,7 +17,7 @@ import os import time import uuid -from typing import Any, Iterator +from typing import Any, Callable, Iterator from tensordict import TensorDict from torch.utils.data import IterableDataset @@ -53,9 +53,7 @@ class StreamingDataset(IterableDataset): ... required_fields=["input_ids", "attention_mask"], ... partition_id="train", ... task_name="update_actor", - ... data_replica_group=data_replica_group_id, # Same for all ranks in data replica group - ... data_replica_rank=local_rank, # local rank in data replica group - ... data_replica_world_size=world_size/dp_world_size, # size of data replica group + ... dp_rank=dp_rank, # Same for all ranks in data replica group ... ) >>> dataloader = StreamingDataLoader( ... dataset, @@ -71,13 +69,14 @@ class StreamingDataset(IterableDataset): def __init__( self, config: dict[str, Any], + batch_size: int, micro_batch_size: int, - required_fields: list[str], + data_fields: list[str], partition_id: str, task_name: str, - data_replica_group: int, - data_replica_rank: int, - data_replica_world_size: int, + dp_rank: int, + fetch_batch_fn: Callable | None = None, + process_batch_fn: Callable | None = None, ): """Initialize the StreamingDataset. @@ -86,20 +85,21 @@ def __init__( - controller_info: ZMQServerInfo for the TransferQueueController - storage_backend: Storage backend type (e.g., "AsyncSimpleStorageManager") - Other backend-specific configuration + batch_size: Batch size for data loading per iter. micro_batch_size: Number of samples per micro-batch. This is the batch size that will be requested from TransferQueue for each iteration. - required_fields: List of field names to retrieve from storage. Only these + data_fields: List of field names to retrieve from storage. Only these fields will be included in the returned batch. partition_id: Partition ID for data versioning. Different partitions can be used for different data versions or splits (e.g., "train", "val"). task_name: Unique identifier for the training task. This is used to track which samples have been consumed by which task. - data_replica_group: The group ID of the current data replica group. All - ranks with the same data_replica_group will receive identical samples. - data_replica_rank: Local rank index within the data_replica_group. Range: - [0, data_replica_world_size - 1] - data_replica_world_size: Total number of ranks in this data_replica_group. - Must be >= 1. + dp_rank: The group ID of the current data group. All + ranks with the same dp_rank will receive identical samples. + fetch_batch_fn: Optional custom function to retrieve batch data. + If None, uses default_fetch_batch_fn function. + process_batch_fn: Optional custom function to post-process + and split data into micro-batches. If None, uses chunk_batch_fn. Raises: ValueError: If input parameters are invalid. @@ -108,41 +108,53 @@ def __init__( if micro_batch_size < 1: raise ValueError(f"micro_batch_size must be >= 1, got {micro_batch_size}") - if len(required_fields) < 1: - raise ValueError(f"required_fields must be a list with at least one field name, got {required_fields}") + if len(data_fields) < 1: + raise ValueError(f"data_fields must be a list with at least one field name, got {data_fields}") - if data_replica_world_size < 1: - raise ValueError(f"data_replica_world_size {data_replica_world_size} must >= 1") - - if data_replica_rank >= data_replica_world_size or data_replica_rank < 0: - raise ValueError( - f"data_replica_rank {data_replica_rank} must be greater than or equal to 0 and less than " - f"data_replica_world_size {data_replica_world_size}" - ) + if dp_rank < 0: + raise ValueError(f"dp_rank {dp_rank} must be greater than or equal to 0") self.config = config + self.batch_size = batch_size self.micro_batch_size = micro_batch_size - self.required_fields = required_fields + self.data_fields = data_fields self.partition_id = partition_id self.task_name = task_name - self.data_replica_group = data_replica_group - self.data_replica_rank = data_replica_rank - self.data_replica_world_size = data_replica_world_size + self.dp_rank = dp_rank + self.get_batch_func = fetch_batch_fn if fetch_batch_fn else default_fetch_batch_fn + self.post_process_for_micro_func = process_batch_fn if process_batch_fn else chunk_batch_fn # Build sampling config for controller self.sampling_config = { - "data_replica_group": self.data_replica_group, - "data_replica_rank": self.data_replica_rank, - "data_replica_world_size": self.data_replica_world_size, + "dp_rank": self.dp_rank, "task_name": self.task_name, - "partition_id": self.partition_id, } self._tq_client = None + # Buffer for caching fetched batches (list of tuples (TensorDict, BatchMeta)). + # Purpose: + # 1) Cache full training batches retrieved from TransferQueue / storage to + # make logging, debugging and replaying batches easier. + # 2) Support multi-pass training on the same samples in some scenarios — + # using `batch_index` to iterate over cached batches multiple times + # avoids re-fetching them from remote storage and reduces network/storage + # overhead. + # 3) Work together with `reset()` / `step()` to manage iteration state cleanly + # and avoid dropping batches that haven't been consumed yet. + self.buffer: list[tuple[TensorDict, BatchMeta]] = [] + self.batch_index = 0 super().__init__() def _create_client(self): + """Create and initialize a TransferQueue client. + + This method initializes the TransferQueueClient with the provided configuration + and storage backend, and sets up the storage manager for data retrieval. + + Raises: + ValueError: If controller_info or storage_backend is missing or invalid. + """ client_id = uuid.uuid4().hex[:8] controller_info = self.config.get("controller_info", None) if not controller_info or not isinstance(controller_info, ZMQServerInfo): @@ -175,30 +187,142 @@ def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]: # Note: For fully streamed production-consumption, please set the environment variable # TQ_PRE_ALLOC_SAMPLE_NUM to the required global_batch_size to make sure consumers can accurately # determine consumption status even before producers have generated the samples. - while not self._tq_client.check_consumption_status(self.task_name, self.partition_id): + while ( + not self._tq_client.check_consumption_status(self.task_name, self.partition_id) + or self.batch_index <= len(self.buffer) - 1 + ): try: - # Get metadata from controller - batch_meta = self._tq_client.get_meta( - data_fields=self.required_fields, - batch_size=self.micro_batch_size, - partition_id=self.partition_id, - task_name=self.task_name, - sampling_config=self.sampling_config, - ) - - # Check if we got valid data - if batch_meta.size == 0: - logger.debug( - f"[StreamingDataset]: Received empty batch, waiting for more data... " - f"Required batch_size={self.micro_batch_size}, data_fields={self.required_fields}," - f"partition_id={self.partition_id}, task_name={self.task_name}." - ) + if self.batch_index <= len(self.buffer) - 1: + current_data = self.buffer[self.batch_index] + self.batch_index += 1 + yield from self.post_process_for_micro_func(*current_data, micro_batch_size=self.micro_batch_size) - time.sleep(TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL) else: - batch = self._tq_client.get_data(batch_meta) - yield (batch, batch_meta) + batch_data, batch_meta = self.get_batch_func( + self._tq_client, + self.data_fields, + self.batch_size, + self.partition_id, + self.task_name, + self.sampling_config, + self.batch_index, + ) + if batch_data is not None: + self.buffer.append((batch_data, batch_meta)) + else: + time.sleep(1) except Exception as e: logger.error(f"[StreamingDataset]: Error in data iteration: {e}") raise + + def reset(self): + """Reset the dataset iterator to the beginning. + + Clears the buffer and resets the batch index for a fresh iteration. + """ + self.batch_index = 0 + + def step(self, partition_id): + """Switch to a new partition and reset the dataset state. + + This method clears the buffer, resets the batch index, and updates the partition_id + to fetch data from a different partition (e.g., switching from "train" to "val"). + + Args: + partition_id: The new partition ID to switch to. + """ + self.buffer = [] + self.batch_index = 0 + self.partition_id = partition_id + + +def default_fetch_batch_fn(tq_client, data_fields, batch_size, partition_id, task_name, sampling_config, batch_index): + """Retrieve a batch of data from TransferQueue. + + This function queries the TransferQueue controller for batch metadata and retrieves + the actual data if available. It handles empty batches gracefully. + + Args: + tq_client: The TransferQueueClient instance for data retrieval. + data_fields: List of field names to retrieve from the batch. + batch_size: The requested batch size. + partition_id: The partition ID for data versioning. + task_name: Unique identifier for the training task. + sampling_config: Configuration dictionary for sampling strategy. + batch_index: Current batch index for tracking consumption progress. + + Returns: + tuple: A tuple containing: + - batch: TensorDict with the retrieved data, or None if batch is empty. + - batch_meta: BatchMeta object containing batch metadata. + """ + # Get metadata from controller + config = {**sampling_config, "batch_index": batch_index, "partition_id": partition_id} + batch_meta = tq_client.get_meta( + data_fields=data_fields, + batch_size=batch_size, + partition_id=partition_id, + task_name=task_name, + sampling_config=config, + ) + + # Check if we got valid data + if batch_meta.size == 0: + logger.debug( + f"[StreamingDataset]: Received empty batch, waiting for more data... " + f"Required batch_size={batch_size}, data_fields={data_fields}," + f"partition_id={partition_id}, task_name={task_name}." + ) + return None, batch_meta + else: + batch = tq_client.get_data(batch_meta) + return batch, batch_meta + + +def chunk_batch_fn(td, batch_meta, micro_batch_size=1): + """Split TensorDict into micro-batches along the batch dimension. + + This function chunks a TensorDict into smaller micro-batches with the specified size, + along with corresponding metadata chunks. Handles cases where batch size is not + evenly divisible by micro_batch_size. + + Args: + td: Input TensorDict with non-empty batch_size. + batch_meta: BatchMeta object to be chunked along with the TensorDict. + micro_batch_size: Target size for each micro-batch (positive integer, default: 1). + + Returns: + list: List of tuples (micro_batch_td, micro_batch_meta) where each tuple + contains a TensorDict chunk and corresponding metadata chunk. + + Raises: + TypeError: If td is not a TensorDict. + ValueError: If micro_batch_size is not a positive integer, batch_size is empty, + or micro_batch_size exceeds total batch size. + """ + if not isinstance(td, TensorDict): + raise TypeError(f"Expected TensorDict, got {type(td).__name__}") + + if not isinstance(micro_batch_size, int) or micro_batch_size <= 0: + raise ValueError(f"micro_batch_size must be a positive integer, got {micro_batch_size}") + + if len(td.batch_size) == 0: + raise ValueError("Input TensorDict must have non-empty batch_size") + + total_size = td.batch_size[0] + if micro_batch_size > total_size: + raise ValueError(f"micro_batch_size ({micro_batch_size}) exceeds total batch size ({total_size})") + + # Calculate number of splits (handles uneven division) + num_splits = (total_size + micro_batch_size - 1) // micro_batch_size + splits = [] + batch_meta_list = batch_meta.chunk(num_splits) + + # Chunk the TensorDict and pair with corresponding metadata chunks + for i in range(num_splits): + start = i * micro_batch_size + end = min(start + micro_batch_size, total_size) + splits.append((td[start:end], batch_meta_list[i])) + + return splits diff --git a/transfer_queue/sampler/base.py b/transfer_queue/sampler/base.py index 2cdd1b5e..93f0edd9 100644 --- a/transfer_queue/sampler/base.py +++ b/transfer_queue/sampler/base.py @@ -74,3 +74,16 @@ def sample( def __call__(self, *args: Any, **kwargs: Any) -> tuple[list[int], list[int]]: return self.sample(*args, **kwargs) + + def clear_cache(self, partition_id: str): + """Clear cached states. + + This method removes any cached sampling results that include the specified + global indexes, ensuring that future sampling operations do not reference + stale data. + + Args: + partition_id: The partition ID associated with the task. + """ + if partition_id in self._states.keys(): + self._states.pop(partition_id) diff --git a/transfer_queue/sampler/grpo_group_n_sampler.py b/transfer_queue/sampler/grpo_group_n_sampler.py index d61c71ad..0578a09f 100644 --- a/transfer_queue/sampler/grpo_group_n_sampler.py +++ b/transfer_queue/sampler/grpo_group_n_sampler.py @@ -38,7 +38,7 @@ class GRPOGroupNSampler(BaseSampler): # Initialize controller with GRPO sampler from transfer_queue import TransferQueueController, GRPOGroupNSampler, AsyncTransferQueueClient - controller = TransferQueueController.remote(sampler=GRPOGroupNSampler) + controller = TransferQueueController.remote(sampler=GRPOGroupNSampler(n_samples_per_prompt=4)) controller_info = process_zmq_server_info(controller) client = AsyncTransferQueueClient( @@ -52,7 +52,6 @@ class GRPOGroupNSampler(BaseSampler): batch_size=16, # Total samples requested partition_id="train_0", task_name="rl_training", - sampling_config={"n_samples_per_prompt": 4} # 4 samples per prompt ) # This will return 16 samples organized as 4 groups of 4 samples each ``` @@ -69,19 +68,29 @@ class GRPOGroupNSampler(BaseSampler): def __init__( self, + n_samples_per_prompt: int = 1, ): """Initialize the GRPOGroupNSampler. The sampler maintains minimal internal state and relies on runtime configuration through the sampling_config parameter. + Args: + n_samples_per_prompt: Number of samples per prompt group. Must be > 0. + """ super().__init__() + # Basic validation + if n_samples_per_prompt <= 0: + raise ValueError(f"n_samples_per_prompt must be positive, got {n_samples_per_prompt}") + self.n_samples_per_prompt = n_samples_per_prompt + def sample( self, ready_indexes: list[int], batch_size: int, - n_samples_per_prompt: int, + task_name: str = "", + partition_id: str = "", *args: Any, **kwargs: Any, ) -> tuple[list[int], list[int]]: @@ -95,50 +104,63 @@ def sample( produced and samples are not labeled as consumed. These should be organized such that consecutive indices belong to the same prompt group. batch_size: Total number of samples to select. Must be divisible by n_samples_per_prompt. - n_samples_per_prompt: Number of samples per prompt group. Must be > 0. - *args: Additional positional arguments (ignored in current implementation) - **kwargs: Additional keyword arguments (ignored in current implementation) + task_name: Unique identifier for the training task. Used for state caching and + tracking consumed samples. + partition_id: Partition ID for data versioning. Used for state organization. + *args: Additional positional arguments (ignored in current implementation). + **kwargs: Additional keyword arguments, key ones are: + - dp_rank: Data parallel rank for multi-GPU training. Used for state cache organization. + - batch_index: Current batch index for tracking consumption progress. Returns: Tuple of (sampled_indexes, consumed_indexes): - - sampled_indexes: List of selected global indices, length = batch_size or empty + - sampled_indexes: List of selected global indices, length = batch_size or empty if + insufficient complete groups are available. - consumed_indexes: List of indices to mark as consumed, identical to sampled_indexes - (without replacement semantics) + (without replacement semantics). + + Raises: + ValueError: batch_size is not divisible by n_samples_per_prompt. Examples: - >>> sampler = GRPOGroupNSampler() + >>> sampler = GRPOGroupNSampler(n_samples_per_prompt=3) >>> ready_indexes = [0, 1, 3, 4, 6, 7] # No complete groups after sorting - >>> sampled, consumed = sampler.sample(ready_indexes, 6, n_samples_per_prompt=3) + >>> sampled, consumed = sampler.sample(ready_indexes, 6) >>> sampled [] >>> consumed [] >>> ready_indexes = [0, 1, 3, 4, 5, 6, 7, 9, 10, 11] # Has complete groups after sorting - >>> sampled, consumed = sampler.sample(ready_indexes, 6, n_samples_per_prompt=3) + >>> sampled, consumed = sampler.sample(ready_indexes, 6) >>> sampled [3, 4, 5, 9, 10, 11] >>> consumed [3, 4, 5, 9, 10, 11] """ - # Basic validation - if n_samples_per_prompt <= 0: - raise ValueError(f"n_samples_per_prompt must be positive, got {n_samples_per_prompt}") + states = self._states.get(partition_id, {}).get(task_name, {}) + dp_rank = kwargs.get("dp_rank", None) + batch_index = kwargs.get("batch_index", None) + + # Return cached result if available + if dp_rank in states.keys() and batch_index in states[dp_rank].keys(): + return states[dp_rank][batch_index] - if batch_size % n_samples_per_prompt != 0: + if batch_size % self.n_samples_per_prompt != 0: raise ValueError( - f"batch_size ({batch_size}) must be a multiple of n_samples_per_prompt ({n_samples_per_prompt})" + f"batch_size ({batch_size}) must be a multiple of n_samples_per_prompt ({self.n_samples_per_prompt})" ) - required_groups = batch_size // n_samples_per_prompt + required_groups = batch_size // self.n_samples_per_prompt sorted_ready_indexes = sorted(ready_indexes) complete_group_indices = [] found_groups = 0 + # Scan for consecutive groups i = 0 - while i <= len(sorted_ready_indexes) - n_samples_per_prompt and found_groups < required_groups: - potential_group = sorted_ready_indexes[i : i + n_samples_per_prompt] + while i <= len(sorted_ready_indexes) - self.n_samples_per_prompt and found_groups < required_groups: + potential_group = sorted_ready_indexes[i : i + self.n_samples_per_prompt] # Check if this forms a complete group (consecutive indices) is_consecutive = all( potential_group[j + 1] - potential_group[j] == 1 for j in range(len(potential_group) - 1) @@ -146,13 +168,25 @@ def sample( if is_consecutive: complete_group_indices.extend(potential_group) found_groups += 1 - i += n_samples_per_prompt + i += self.n_samples_per_prompt else: i += 1 + # Return empty if insufficient complete groups if found_groups < required_groups: return [], [] + sampled_indexes = complete_group_indices consumed_indexes = sampled_indexes.copy() + # Cache the sampling result for deterministic future calls + if dp_rank is not None: + if dp_rank not in states: + states[dp_rank] = {} + states[dp_rank][batch_index] = (sampled_indexes, consumed_indexes) + elif batch_index not in states[dp_rank]: + states[dp_rank][batch_index] = (sampled_indexes, consumed_indexes) + if partition_id not in self._states: + self._states[partition_id] = {} + self._states[partition_id][task_name] = states return sampled_indexes, consumed_indexes diff --git a/transfer_queue/sampler/rank_aware_sampler.py b/transfer_queue/sampler/rank_aware_sampler.py index 64531fd5..34b8b232 100644 --- a/transfer_queue/sampler/rank_aware_sampler.py +++ b/transfer_queue/sampler/rank_aware_sampler.py @@ -54,9 +54,8 @@ def sample( self, ready_indexes: list[int], batch_size: int, - data_replica_group: int, - data_replica_rank: int, - data_replica_world_size: int, + dp_rank: int, + batch_index: int, task_name: str, partition_id: str, *args: Any, @@ -76,8 +75,8 @@ def sample( self._states = { "partition_id": { "task_name": { - data_replica_group: { - data_replica_rank: [[sampled_indexes], ...] # Buffer of cached sampled indices + dp_rank: { + "batch_index": [sampled_indexes] } } } @@ -93,10 +92,9 @@ def sample( as consumed in the corresponding task. batch_size: Number of samples to select. If larger than available ready samples, no samples are returned and both lists are empty. - data_replica_group: The group id of current data replica group. Used to - identify which data replica group this rank belongs to. - data_replica_rank: Local rank inside this data_replica_group. - data_replica_world_size: Total number of ranks in this data_replica_group. + dp_rank: Data parallel rank ID that this worker belongs to + The same Ranks receive the same data samples. + batch_index: Current batch index for tracking consumption progress. task_name: Identifier for the task. partition_id: Partition ID for data management. *args: Additional positional arguments (ignored). @@ -114,14 +112,8 @@ def sample( """ - if data_replica_world_size < 1: - raise ValueError(f"data_replica_world_size {data_replica_world_size} must >= 1") - - if data_replica_rank >= data_replica_world_size or data_replica_rank < 0: - raise ValueError( - f"data_replica_rank {data_replica_rank} must be greater than or equal to 0 and less than " - f"data_replica_world_size {data_replica_world_size}" - ) + if dp_rank < 0: + raise ValueError(f"dp_rank {dp_rank} must be greater than or equal to 0") if partition_id not in self._states: self._states[partition_id] = {} @@ -129,10 +121,10 @@ def sample( if task_name not in self._states[partition_id]: self._states[partition_id][task_name] = {} - if data_replica_group not in self._states[partition_id][task_name]: - self._states[partition_id][task_name][data_replica_group] = {i: [] for i in range(data_replica_world_size)} + if dp_rank not in self._states[partition_id][task_name]: + self._states[partition_id][task_name][dp_rank] = {} - if len(self._states[partition_id][task_name][data_replica_group][data_replica_rank]) == 0: + if batch_index not in self._states[partition_id][task_name][dp_rank]: # Select first batch_size indices from ready_indexes sampled_indexes = ready_indexes[:batch_size] @@ -141,14 +133,10 @@ def sample( consumed_indexes = sampled_indexes - # Cache the sampled indices for other ranks in this data replica group - for i in range(data_replica_world_size): - if i != data_replica_rank: - self._states[partition_id][task_name][data_replica_group][i].append(sampled_indexes) - + self._states[partition_id][task_name][dp_rank][batch_index] = sampled_indexes else: # Return the cached indices (identical to what first rank received) - sampled_indexes = self._states[partition_id][task_name][data_replica_group][data_replica_rank].pop(0) + sampled_indexes = self._states[partition_id][task_name][dp_rank][batch_index] consumed_indexes = sampled_indexes return sampled_indexes, consumed_indexes diff --git a/tutorial/05_streaming_dataloader.py b/tutorial/05_streaming_dataloader.py index 9f6eaf96..96ac33f8 100644 --- a/tutorial/05_streaming_dataloader.py +++ b/tutorial/05_streaming_dataloader.py @@ -167,9 +167,7 @@ def generate_worker(rank_id: int, config: DictConfig, num_samples: int = 20): @ray.remote(num_cpus=0.1) def update_worker( rank_id: int, - data_replica_group: int, - data_replica_rank: int, - data_replica_world_size: int, + dp_rank: int, config: DictConfig, max_steps: int = 5, ): @@ -182,11 +180,8 @@ def update_worker( Args: rank_id: Global rank identifier for logging and display purposes - data_replica_group: ID of the data parallel group this rank belongs to - Ranks in the same group receive the same data samples - data_replica_rank: Local rank index within the data replica group - Range: [0, data_replica_world_size - 1] - data_replica_world_size: Total number of ranks in this data replica group + dp_rank: Data parallel rank ID that this worker belongs to + The same Ranks receive the same data samples config: TransferQueue configuration max_steps: Maximum number of batches to consume @@ -194,10 +189,10 @@ def update_worker( dict: Contains data_replica_rank, data_replica_group, and consumed_ids Example: - For a setup with 2 data replica groups (0 and 1), each with 2 ranks: - - Group 0: ranks [0, 1] receive identical samples - - Group 1: ranks [2, 3] receive identical samples - All ranks within the same group get the same global indexes. + For a setup with 2 data rank (0 and 1): + - Rank 0: receive identical samples + - Rank 1: receive identical samples + All ranks within the same rank index get the same global indexes. Note: The StreamingDataLoader yields tuples of (batch, batch_meta) where: @@ -209,13 +204,12 @@ def update_worker( # This dataset integrates with TransferQueue and handles batch retrieval dataset = StreamingDataset( config=config, - micro_batch_size=2, # Number of samples per batch - required_fields=["meta_idx"], # Fields to retrieve from storage. We can retrieve partial fields! + batch_size=2, + micro_batch_size=2, # Number of samples per micro-batch. + data_fields=["meta_idx"], # Fields to retrieve from storage. We can retrieve partial fields! partition_id="train", # Data partition to consume from task_name="update_task", # Unique task identifier - data_replica_group=data_replica_group, - data_replica_rank=data_replica_rank, - data_replica_world_size=data_replica_world_size, + dp_rank=dp_rank, ) print(f"[Update Worker@{rank_id}] StreamingDataset created successfully") @@ -240,10 +234,7 @@ def update_worker( # Extract sample IDs from the batch ids = batch["meta_idx"].view(-1).tolist() - print( - f"[Update Worker@{rank_id}]: data_replica_rank {data_replica_rank} in " - f"data_replica_group {data_replica_group} retrieved samples: {ids}" - ) + print(f"[Update Worker@{rank_id}]: dp_rank {dp_rank} retrieved samples: {ids}") consumed_ids.extend(ids) # Simulate processing time (remove in real training) @@ -257,8 +248,7 @@ def update_worker( print(f"[Update Worker@{rank_id}] Completed {step} steps, consumed {len(consumed_ids)} samples") return { - "data_replica_rank": data_replica_rank, - "data_replica_group": data_replica_group, + "dp_rank": dp_rank, "consumed_ids": consumed_ids, } @@ -271,7 +261,7 @@ def start_all_generate_actors(config): handlers = [] for i in range(num_workers): - handlers.append(generate_worker.remote(rank_id=i, config=config)) + handlers.append(generate_worker.remote(rank_id=i, config=config, num_samples=20)) return handlers @@ -283,23 +273,18 @@ def start_all_update_actors(config): # Define the distributed training topology rank_ids = [0, 1, 2, 3] - data_replica_group = [0, 0, 1, 1] # First two ranks in group 0, last two in group 1 - data_replica_world_size = 2 # Each group has 2 ranks + dp_rank = [0, 0, 1, 1] # First two ranks in group 0, last two in group 1 print("Training topology configuration:") print(f" - Total ranks: {len(rank_ids)}") - print(f" - Data replica groups: {len(set(data_replica_group))}") - print(f" - World size per group: {data_replica_world_size}") - print(f" - Group assignments: {dict(zip(rank_ids, data_replica_group, strict=False))}") + print(f" - Data parallel rank: {len(set(dp_rank))}") handlers = [] for i in range(len(rank_ids)): handlers.append( update_worker.remote( rank_id=rank_ids[i], - data_replica_group=data_replica_group[i], - data_replica_rank=rank_ids[i] % data_replica_world_size, - data_replica_world_size=data_replica_world_size, + dp_rank=dp_rank[i], config=config, ) )