Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 169 additions & 0 deletions tests/test_controller_data_partitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,3 +653,172 @@ def test_get_consumption_status_parameter():
print("✓ get_consumption_status with mask works")

print("Consumption status mask parameter tests passed!\n")


def test_pre_allocated_indexes_basic():
"""Test basic pre-allocated indexes functionality in DataPartitionStatus."""
from transfer_queue.controller import DataPartitionStatus

print("Testing pre-allocated indexes basic functionality...")

partition = DataPartitionStatus(partition_id="prealloc_test")

# Initially, pre_allocated_global_indexes should be empty
assert len(partition.pre_allocated_global_indexes) == 0
assert partition.total_samples_num == 0

print("✓ Initial state correct")

# Register pre-allocated indexes
pre_allocated = [0, 1, 2, 3, 4]
partition.register_pre_allocated_indexes(pre_allocated)

assert partition.pre_allocated_global_indexes == set(pre_allocated)
# global_indexes should still be empty until retrieved
assert partition.total_samples_num == 0

print("✓ Pre-allocated indexes registered")

# activate pre-allocated indexes
retrieved = partition.activate_pre_allocated_indexes(3)

assert len(retrieved) == 3
assert set(retrieved) == {0, 1, 2}
assert partition.global_indexes == {0, 1, 2}
assert partition.pre_allocated_global_indexes == {3, 4}
assert partition.total_samples_num == 3

print("✓ Pre-allocated indexes activate & retrieved correctly")

# Activate remaining indexes
retrieved = partition.activate_pre_allocated_indexes(5)

assert len(retrieved) == 2 # Only 2 remaining
assert set(retrieved) == {3, 4}
assert partition.global_indexes == {0, 1, 2, 3, 4}
assert partition.pre_allocated_global_indexes == set()
assert partition.total_samples_num == 5

print("✓ All pre-allocated indexes retrieved")

print("Pre-allocated indexes basic tests passed!\n")


def test_pre_allocated_indexes_consumption_status():
"""Test that pre-allocated indexes are included in consumption status."""
import torch

from transfer_queue.controller import DataPartitionStatus

print("Testing pre-allocated indexes in consumption status...")

partition = DataPartitionStatus(partition_id="consumption_test")

# Register pre-allocated indexes
partition.register_pre_allocated_indexes([0, 1, 2, 3, 4])

# Get consumption status - should include pre-allocated indexes
global_index, consumption_status = partition.get_consumption_status("test_task", mask=True)

# global_index should include all pre-allocated indexes
assert torch.equal(global_index, torch.tensor([0, 1, 2, 3, 4], dtype=torch.long))
# All consumption statuses should be 0 (not consumed yet)
assert torch.all(consumption_status == 0)

print("✓ Consumption status includes pre-allocated indexes")

# Mark some samples as consumed
partition.mark_consumed("test_task", [0, 2, 4])

# Get consumption status again
global_index, consumption_status = partition.get_consumption_status("test_task", mask=True)

assert consumption_status[0].item() == 1 # consumed
assert consumption_status[1].item() == 0 # not consumed
assert consumption_status[2].item() == 1 # consumed
assert consumption_status[3].item() == 0 # not consumed
assert consumption_status[4].item() == 1 # consumed

print("✓ Marked consumed works with pre-allocated indexes")

print("Pre-allocated indexes consumption status tests passed!\n")


def test_pre_allocated_indexes_in_scan_data_status():
"""Test that pre-allocated indexes affect scan_data_status behavior."""
from transfer_queue.controller import DataPartitionStatus

print("Testing pre-allocated indexes in scan_data_status...")

partition = DataPartitionStatus(partition_id="scan_test")

# Register pre-allocated indexes (5 samples)
partition.register_pre_allocated_indexes([0, 1, 2, 3, 4])

# Before any production, scan should return empty (no samples produced yet)
ready = partition.scan_data_status(field_names=["input_ids"], task_name="test_task")
assert ready == []

print("✓ Scan returns empty before production")

# Now produce some samples (0, 2, 4)
partition.update_production_status(
global_indices=[0, 2, 4],
field_names=["input_ids"],
dtypes={i: {"input_ids": "torch.int32"} for i in [0, 2, 4]},
shapes={i: {"input_ids": (32,)} for i in [0, 2, 4]},
)

# Scan should return produced and unconsumed samples
ready = partition.scan_data_status(field_names=["input_ids"], task_name="test_task")
assert set(ready) == {0, 2, 4}

print("✓ Scan returns produced samples correctly")

# Mark sample 2 as consumed
partition.mark_consumed("test_task", [2])

# Scan should now return only 0 and 4
ready = partition.scan_data_status(field_names=["input_ids"], task_name="test_task")
assert set(ready) == {0, 4}

print("✓ Scan respects consumption status")

print("Pre-allocated indexes scan_data_status tests passed!\n")


def test_pre_allocated_indexes_mixed_with_dynamic():
"""Test mixing pre-allocated indexes with dynamically allocated ones."""
from transfer_queue.controller import DataPartitionStatus

print("Testing mixed pre-allocated and dynamic indexes...")

partition = DataPartitionStatus(partition_id="mixed_test")

# Register 3 pre-allocated indexes
partition.register_pre_allocated_indexes([0, 1, 2])

# Simulate adding more samples (indexes 5, 6, 7)
# This would happen when producer calls update_production_status
partition.update_production_status(
global_indices=[5, 6, 7],
field_names=["input_ids"],
dtypes={i: {"input_ids": "torch.int32"} for i in [5, 6, 7]},
shapes={i: {"input_ids": (32,)} for i in [5, 6, 7]},
)

# Now global_indexes should only contain dynamically generated in (5,6,7)
assert partition.global_indexes == {5, 6, 7}
assert partition.total_samples_num == 3

# all pre-allocated
retrieved = partition.activate_pre_allocated_indexes(3)
assert set(retrieved) == {0, 1, 2}

# Now global_indexes should have both pre-allocated (0,1,2) and dynamic (5,6,7)
assert partition.global_indexes == {0, 1, 2, 5, 6, 7}
assert partition.total_samples_num == 6

print("✓ Mixed pre-allocated and dynamic indexes work correctly")

print("Mixed indexes tests passed!\n")
Loading
Loading