diff --git a/tests/test_controller_data_partitions.py b/tests/test_controller_data_partitions.py index df9f60d3..976b277b 100644 --- a/tests/test_controller_data_partitions.py +++ b/tests/test_controller_data_partitions.py @@ -653,3 +653,172 @@ def test_get_consumption_status_parameter(): print("✓ get_consumption_status with mask works") print("Consumption status mask parameter tests passed!\n") + + +def test_pre_allocated_indexes_basic(): + """Test basic pre-allocated indexes functionality in DataPartitionStatus.""" + from transfer_queue.controller import DataPartitionStatus + + print("Testing pre-allocated indexes basic functionality...") + + partition = DataPartitionStatus(partition_id="prealloc_test") + + # Initially, pre_allocated_global_indexes should be empty + assert len(partition.pre_allocated_global_indexes) == 0 + assert partition.total_samples_num == 0 + + print("✓ Initial state correct") + + # Register pre-allocated indexes + pre_allocated = [0, 1, 2, 3, 4] + partition.register_pre_allocated_indexes(pre_allocated) + + assert partition.pre_allocated_global_indexes == set(pre_allocated) + # global_indexes should still be empty until retrieved + assert partition.total_samples_num == 0 + + print("✓ Pre-allocated indexes registered") + + # activate pre-allocated indexes + retrieved = partition.activate_pre_allocated_indexes(3) + + assert len(retrieved) == 3 + assert set(retrieved) == {0, 1, 2} + assert partition.global_indexes == {0, 1, 2} + assert partition.pre_allocated_global_indexes == {3, 4} + assert partition.total_samples_num == 3 + + print("✓ Pre-allocated indexes activate & retrieved correctly") + + # Activate remaining indexes + retrieved = partition.activate_pre_allocated_indexes(5) + + assert len(retrieved) == 2 # Only 2 remaining + assert set(retrieved) == {3, 4} + assert partition.global_indexes == {0, 1, 2, 3, 4} + assert partition.pre_allocated_global_indexes == set() + assert partition.total_samples_num == 5 + + print("✓ All pre-allocated indexes retrieved") + + print("Pre-allocated indexes basic tests passed!\n") + + +def test_pre_allocated_indexes_consumption_status(): + """Test that pre-allocated indexes are included in consumption status.""" + import torch + + from transfer_queue.controller import DataPartitionStatus + + print("Testing pre-allocated indexes in consumption status...") + + partition = DataPartitionStatus(partition_id="consumption_test") + + # Register pre-allocated indexes + partition.register_pre_allocated_indexes([0, 1, 2, 3, 4]) + + # Get consumption status - should include pre-allocated indexes + global_index, consumption_status = partition.get_consumption_status("test_task", mask=True) + + # global_index should include all pre-allocated indexes + assert torch.equal(global_index, torch.tensor([0, 1, 2, 3, 4], dtype=torch.long)) + # All consumption statuses should be 0 (not consumed yet) + assert torch.all(consumption_status == 0) + + print("✓ Consumption status includes pre-allocated indexes") + + # Mark some samples as consumed + partition.mark_consumed("test_task", [0, 2, 4]) + + # Get consumption status again + global_index, consumption_status = partition.get_consumption_status("test_task", mask=True) + + assert consumption_status[0].item() == 1 # consumed + assert consumption_status[1].item() == 0 # not consumed + assert consumption_status[2].item() == 1 # consumed + assert consumption_status[3].item() == 0 # not consumed + assert consumption_status[4].item() == 1 # consumed + + print("✓ Marked consumed works with pre-allocated indexes") + + print("Pre-allocated indexes consumption status tests passed!\n") + + +def test_pre_allocated_indexes_in_scan_data_status(): + """Test that pre-allocated indexes affect scan_data_status behavior.""" + from transfer_queue.controller import DataPartitionStatus + + print("Testing pre-allocated indexes in scan_data_status...") + + partition = DataPartitionStatus(partition_id="scan_test") + + # Register pre-allocated indexes (5 samples) + partition.register_pre_allocated_indexes([0, 1, 2, 3, 4]) + + # Before any production, scan should return empty (no samples produced yet) + ready = partition.scan_data_status(field_names=["input_ids"], task_name="test_task") + assert ready == [] + + print("✓ Scan returns empty before production") + + # Now produce some samples (0, 2, 4) + partition.update_production_status( + global_indices=[0, 2, 4], + field_names=["input_ids"], + dtypes={i: {"input_ids": "torch.int32"} for i in [0, 2, 4]}, + shapes={i: {"input_ids": (32,)} for i in [0, 2, 4]}, + ) + + # Scan should return produced and unconsumed samples + ready = partition.scan_data_status(field_names=["input_ids"], task_name="test_task") + assert set(ready) == {0, 2, 4} + + print("✓ Scan returns produced samples correctly") + + # Mark sample 2 as consumed + partition.mark_consumed("test_task", [2]) + + # Scan should now return only 0 and 4 + ready = partition.scan_data_status(field_names=["input_ids"], task_name="test_task") + assert set(ready) == {0, 4} + + print("✓ Scan respects consumption status") + + print("Pre-allocated indexes scan_data_status tests passed!\n") + + +def test_pre_allocated_indexes_mixed_with_dynamic(): + """Test mixing pre-allocated indexes with dynamically allocated ones.""" + from transfer_queue.controller import DataPartitionStatus + + print("Testing mixed pre-allocated and dynamic indexes...") + + partition = DataPartitionStatus(partition_id="mixed_test") + + # Register 3 pre-allocated indexes + partition.register_pre_allocated_indexes([0, 1, 2]) + + # Simulate adding more samples (indexes 5, 6, 7) + # This would happen when producer calls update_production_status + partition.update_production_status( + global_indices=[5, 6, 7], + field_names=["input_ids"], + dtypes={i: {"input_ids": "torch.int32"} for i in [5, 6, 7]}, + shapes={i: {"input_ids": (32,)} for i in [5, 6, 7]}, + ) + + # Now global_indexes should only contain dynamically generated in (5,6,7) + assert partition.global_indexes == {5, 6, 7} + assert partition.total_samples_num == 3 + + # all pre-allocated + retrieved = partition.activate_pre_allocated_indexes(3) + assert set(retrieved) == {0, 1, 2} + + # Now global_indexes should have both pre-allocated (0,1,2) and dynamic (5,6,7) + assert partition.global_indexes == {0, 1, 2, 5, 6, 7} + assert partition.total_samples_num == 6 + + print("✓ Mixed pre-allocated and dynamic indexes work correctly") + + print("Mixed indexes tests passed!\n") diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index eda92678..a59b140a 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -62,17 +62,10 @@ TQ_CONTROLLER_GET_METADATA_TIMEOUT = int(os.environ.get("TQ_CONTROLLER_GET_METADATA_TIMEOUT", 1)) TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL = int(os.environ.get("TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL", 5)) - -TQ_INIT_SAMPLE_NUM = int(os.environ.get("TQ_INIT_SAMPLE_NUM", 1)) # Initial number of samples -TQ_INIT_FIELD_NUM = int(os.environ.get("TQ_INIT_FIELD_NUM", 1)) - -# Expansion configuration - Unified approach using minimum expansion sizes -TQ_SAMPLE_MIN_EXPANSION_SIZE = int( - os.environ.get("TQ_SAMPLE_MIN_EXPANSION_SIZE", 1) -) # Minimum expansion size for samples (rows) -TQ_FIELD_MIN_EXPANSION_SIZE = int( - os.environ.get("TQ_FIELD_MIN_EXPANSION_SIZE", 1) -) # Minimum expansion size for fields (columns) +# Sample pre-allocation for StreamingDataLoader compatibility. +# By pre-allocating sample indices (typically global_batch_size), consumers can accurately +# determine consumption status even before producers have generated the samples. +TQ_PRE_ALLOC_SAMPLE_NUM = int(os.environ.get("TQ_PRE_ALLOC_SAMPLE_NUM", 1)) class PartitionIndexManager: @@ -94,7 +87,7 @@ def __init__(self): # Track all active indexes self.allocated_indexes = set() - def allocate_indexes(self, partition_id, count=1) -> list: + def allocate_indexes(self, partition_id, count=1) -> list[int]: """ Allocate global_indexes for the specified partition. Prioritizes obtaining from reusable pool, allocates new indexes when insufficient. @@ -216,7 +209,8 @@ class DataPartitionStatus: # Production status tensor - dynamically expandable # Values: 0 = not produced, 1 = ready for consumption - production_status: Tensor = torch.zeros(TQ_INIT_SAMPLE_NUM, TQ_INIT_FIELD_NUM, dtype=torch.int8) + + production_status: Tensor = torch.zeros(TQ_PRE_ALLOC_SAMPLE_NUM, 1, dtype=torch.int8) # Consumption status per task - task_name -> consumption_tensor # Each tensor tracks which samples have been consumed by that task @@ -227,6 +221,10 @@ class DataPartitionStatus: default_factory=set ) # set of global indexes that have been added to this partition + pre_allocated_global_indexes: set[int] = field( + default_factory=set + ) # set of global indexes that pre-allocated, but not active in this partition + # Field metadata field_name_mapping: dict[str, int] = field(default_factory=dict) # field_name -> column_index field_dtypes: dict[int, dict[str, Any]] = field(default_factory=dict) # global_idx -> {field: dtype} @@ -258,22 +256,79 @@ def allocated_samples_num(self) -> int: """Current number of allocated rows in the tensor.""" return self.production_status.shape[0] + # ==================== Index Pre-Allocation Methods ==================== + + def register_pre_allocated_indexes(self, allocated_indexes: list[int]): + """ + Register pre-allocated sample indexes to this partition. + + These indexes are reserved before actual data production, allowing consumers + to see the expected total sample count via get_consumption_status even when + producers haven't generated all samples yet. + + Args: + allocated_indexes: List of global indexes to pre-allocate + """ + + if len(allocated_indexes) < 1: + logger.info("Trying to pre-allocate global_indexes with empty list!") + return + + self.pre_allocated_global_indexes.update(allocated_indexes) + + # Expand the state matrices + max_sample_idx = max(allocated_indexes) + required_samples = max_sample_idx + 1 + + with self.data_status_lock: + self.ensure_samples_capacity(required_samples) + + logger.debug(f"Pre-allocated indexes in {self.partition_id}: {allocated_indexes}") + + def activate_pre_allocated_indexes(self, sample_num: int) -> list[int]: + """ + Activate and retrieve pre-allocated indexes for use in data insertion. + + This method consumes pre-allocated indexes and marks them as active in global_indexes. + If pre-allocated indexes are insufficient, returns all available ones. + + Args: + sample_num: Number of indexes needed + + Returns: + List of retrieved global indexes + """ + available_indexes = len(self.pre_allocated_global_indexes) + + if available_indexes < sample_num: + global_index_to_allocate = list(self.pre_allocated_global_indexes) + logger.debug( + f"Not enough pre-allocated indexes in partition {self.partition_id}. " + f"Returning {available_indexes} of {sample_num} requested." + ) + else: + global_index_to_allocate = list(sorted(self.pre_allocated_global_indexes))[:sample_num] + + self.global_indexes.update(global_index_to_allocate) + self.pre_allocated_global_indexes.difference_update(set(global_index_to_allocate)) + + return global_index_to_allocate + # ==================== Dynamic Expansion Methods ==================== def ensure_samples_capacity(self, required_samples: int) -> None: """ Ensure the production status tensor has enough rows for the required samples. - Dynamically expands if needed using unified minimum expansion size. Args: required_samples: Minimum number of samples needed """ + current_sample_space = self.allocated_samples_num if required_samples > current_sample_space: - # Expand rows using minimum expansion size for predictable memory usage + # Expand rows expansion_needed = required_samples - current_sample_space - min_expansion = max(TQ_SAMPLE_MIN_EXPANSION_SIZE, expansion_needed) - new_samples = current_sample_space + min_expansion + new_samples = current_sample_space + expansion_needed new_fields = self.production_status.shape[1] expanded_tensor = torch.zeros(new_samples, new_fields, dtype=torch.int8) @@ -286,15 +341,11 @@ def ensure_samples_capacity(self, required_samples: int) -> None: expanded_consumption[:current_sample_space] = consumption_tensor self.consumption_status[task_name] = expanded_consumption - logger.debug( - f"Expanded partition {self.partition_id} from {current_sample_space} " - f"to {new_samples} samples (added {min_expansion} samples)" - ) + logger.debug(f"Expanded partition {self.partition_id} from {current_sample_space} to {new_samples} samples") def ensure_fields_capacity(self, required_fields: int) -> None: """ Ensure the production status tensor has enough columns for the required fields. - Dynamically expands if needed using unified minimum expansion size. Args: required_fields: Minimum number of fields needed @@ -302,20 +353,16 @@ def ensure_fields_capacity(self, required_fields: int) -> None: current_fields = self.production_status.shape[1] if required_fields > current_fields: - # Expand columns using minimum expansion size for predictable memory usage + # Expand columns expansion_needed = required_fields - current_fields - min_expansion = max(TQ_FIELD_MIN_EXPANSION_SIZE, expansion_needed) - new_fields = current_fields + min_expansion + new_fields = current_fields + expansion_needed new_samples = self.production_status.shape[0] expanded_tensor = torch.zeros(new_samples, new_fields, dtype=torch.int8) expanded_tensor[:, :current_fields] = self.production_status self.production_status = expanded_tensor - logger.debug( - f"Expanded partition {self.partition_id} from {current_fields} " - f"to {new_fields} fields (added {min_expansion} fields)" - ) + logger.debug(f"Expanded partition {self.partition_id} from {current_fields} to {new_fields} fields") # ==================== Production Status Interface ==================== @@ -487,7 +534,9 @@ def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Te # Get consumption status for requested task consumption_status = self.consumption_status[task_name] - partition_global_index = torch.tensor(sorted(self.global_indexes), dtype=torch.long) + partition_global_index = torch.tensor( + sorted(self.global_indexes | self.pre_allocated_global_indexes), dtype=torch.long + ) if mask: consumption_status = consumption_status[partition_global_index] @@ -526,7 +575,9 @@ def get_production_status_for_fields( production_status = self.production_status[:, col_mask] - partition_global_index = torch.tensor(sorted(self.global_indexes), dtype=torch.long) + partition_global_index = torch.tensor( + sorted(self.global_indexes | self.pre_allocated_global_indexes), dtype=torch.long + ) if mask: production_status = production_status[partition_global_index] @@ -774,9 +825,11 @@ def __init__( def create_partition(self, partition_id: str) -> bool: """ - Create a new data partition. + Create a new data partition with pre-allocated sample indexes. - Note: Partitions now dynamically expand as needed, so initial capacity is not required. + Partitions dynamically expand as needed. Additionally, TQ_PRE_ALLOC_SAMPLE_NUM + indexes are pre-allocated to allow consumers to determine consumption status + before all samples are produced. Args: partition_id: Unique identifier for the partition (e.g., "train@global_batch_0") @@ -790,7 +843,11 @@ def create_partition(self, partition_id: str) -> bool: self.partitions[partition_id] = DataPartitionStatus(partition_id=partition_id) - logger.info(f"Created partition {partition_id}") + # Pre-allocate global indexes for consumer consumption tracking + global_indexes = self.index_manager.allocate_indexes(partition_id, count=TQ_PRE_ALLOC_SAMPLE_NUM) + self.partitions[partition_id].register_pre_allocated_indexes(global_indexes) + + logger.info(f"Created partition {partition_id} with {TQ_PRE_ALLOC_SAMPLE_NUM} pre-allocated indexes") return True def _get_partition(self, partition_id: str) -> Optional[DataPartitionStatus]: @@ -965,10 +1022,26 @@ def get_metadata( self.create_partition(partition_id) if mode == "insert": + partition = self._get_partition(partition_id) + if data_fields: - # First put_data call, get_metadata in insert mode - batch_global_indexes = self.index_manager.allocate_indexes(partition_id, count=batch_size) + # This is called during put_data call without providing metadata. + # try to use pre-allocated global index first + + if batch_size is None: + raise ValueError("must provide batch_size for inserting new data") + + assert partition is not None + batch_global_indexes = partition.activate_pre_allocated_indexes(batch_size) + + if len(batch_global_indexes) < batch_size: + new_global_indexes = self.index_manager.allocate_indexes( + partition_id, count=(batch_size - len(batch_global_indexes)) + ) + batch_global_indexes.extend(new_global_indexes) + else: + # TODO: separate this "clear" related logic into a separated mode # clear metadata call passes empty data_fields batch_global_indexes = self.index_manager.get_indexes_for_partition(partition_id) return self.generate_batch_meta(partition_id, batch_global_indexes, data_fields, mode) diff --git a/transfer_queue/dataloader/streaming_dataset.py b/transfer_queue/dataloader/streaming_dataset.py index 77d71863..ea1487f2 100644 --- a/transfer_queue/dataloader/streaming_dataset.py +++ b/transfer_queue/dataloader/streaming_dataset.py @@ -172,7 +172,9 @@ def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]: assert self._tq_client is not None, "Failed to create TransferQueue client" - # TODO: need to consider async scenario where the samples in partition is dynamically increasing + # Note: For fully streamed production-consumption, please set the environment variable + # TQ_PRE_ALLOC_SAMPLE_NUM to the required global_batch_size to make sure consumers can accurately + # determine consumption status even before producers have generated the samples. while not self._tq_client.check_consumption_status(self.task_name, self.partition_id): try: # Get metadata from controller diff --git a/tutorial/05_streaming_dataloader.py b/tutorial/05_streaming_dataloader.py index d532c80e..9f6eaf96 100644 --- a/tutorial/05_streaming_dataloader.py +++ b/tutorial/05_streaming_dataloader.py @@ -334,7 +334,6 @@ def main(): - StreamingDataLoader: DataLoader wrapper yielding (batch, batch_meta) tuples - RankAwareSampler: Enables correct data consumption across data replica ranks - Data Replica Group: Ranks that should receive identical data samples (TP, PP, ...) - """ ) ) @@ -342,6 +341,11 @@ def main(): # Step 1: Setup TransferQueue infrastructure print("\n[Phase 1] Setting up TransferQueue infrastructure...") + print( + "\nIn real-world usage, please export the environment variable of TQ_PRE_ALLOC_SAMPLE_NUM to " + "global_batch_size to make sure consumers can accurately determine consumption status even before " + "producers have generated the samples." + ) controller, storage_units, config = setup_transfer_queue() # Step 2: Launch data generation actors