batched/aio/inference/model_batch_processor.py (7 additions, 0 deletions)
@@ -30,6 +30,7 @@ def __init__(
cache: AsyncCache[dict[str, Feature], Feature] | None = None,
max_batch_length: int | None = None,
pad_tokens: Optional[dict[str, int]] = None,
padding_side: str = "right",
priority_strategy: PriorityStrategy = PriorityStrategy.NONE,
batch_item_cls: type[AsyncBatchItem[dict[str, Feature], Feature]] = AsyncBatchItem[dict[str, Feature], Feature],
spread_kwargs: bool = False,
@@ -45,6 +46,7 @@ def __init__(
cache (AsyncCache | None): An optional cache for storing results. Defaults to None.
max_batch_length (int | None): The maximum length of a batch. Defaults to None.
pad_tokens (dict[str, int] | None): Dictionary of padding tokens for each feature. Defaults to None.
padding_side (str): The side on which to add padding tokens, either "left" or "right". Defaults to "right".
priority_strategy (PriorityStrategy): The strategy to use for prioritizing items.
batch_item_cls (type[AsyncBatchItem]): The class to use for batch items. Defaults to AsyncBatchItem.
spread_kwargs (bool): Whether to spread the kwargs over passing dict as args. Defaults to False.
@@ -61,6 +63,7 @@ def __init__(
)

self.pad_tokens = pad_tokens or {}
self.padding_side = padding_side
self.spread_kwargs = spread_kwargs

async def _process_batches(self):
@@ -77,6 +80,7 @@ async def _process_batches(self):
batch_inputs = stack_features(
[item.content for item in batch],
pad_tokens=self.pad_tokens,
padding_side=self.padding_side,
)

batch_outputs = (
@@ -129,6 +133,7 @@ def dynamically(
small_batch_threshold: int = 8,
max_batch_length: int | None = None,
pad_tokens: Optional[dict[str, int]] = None,
padding_side: str = "right",
priority_strategy: PriorityStrategy = PriorityStrategy.NONE,
cache: AsyncCache[dict[str, Feature], Feature] | None = None,
batch_item_cls: type[AsyncBatchItem[dict[str, Feature], Feature]] = AsyncBatchItem[dict[str, Feature], Feature],
@@ -144,6 +149,7 @@ def dynamically(
small_batch_threshold (int): The threshold to give priority to small batches. Defaults to 8.
max_batch_length (int | None): The maximum length of a batch. Defaults to None.
pad_tokens (dict[str, int] | None): Padding token values for each input feature. Defaults to None.
padding_side (str): The side on which to add padding tokens, either "left" or "right". Defaults to "right".
priority_strategy (PriorityStrategy): The strategy to use for prioritizing items.
cache (AsyncCache | None): An optional cache for storing results.
batch_item_cls (type[AsyncBatchItem]): The class to use for batch items. Defaults to AsyncBatchItem.
@@ -170,6 +176,7 @@ def make_processor(_func: BatchInfer) -> AsyncModelBatchProcessor:
small_batch_threshold=small_batch_threshold,
max_batch_length=max_batch_length,
pad_tokens=pad_tokens,
padding_side=padding_side,
priority_strategy=priority_strategy,
cache=cache,
batch_item_cls=batch_item_cls,
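Taken together, these changes let callers opt into left padding on the async path, either when constructing AsyncModelBatchProcessor directly or through the dynamically decorator. A minimal usage sketch, mirroring the tests added later in this PR (the import path is taken from the file above; the package may also re-export the decorator at a shorter path):

```python
import asyncio

import numpy as np

# Module path as touched in this diff; a shorter re-exported path may exist.
from batched.aio.inference.model_batch_processor import dynamically


@dynamically(batch_size=2, pad_tokens={"input_ids": 0}, padding_side="left")
async def echo_ids(features):
    # Receives the stacked batch dict; returns one row per queued item.
    return features["input_ids"]


async def main():
    # Two concurrent calls land in the same batch, so the shorter
    # sequence is left-padded: [4, 5] -> [0, 4, 5].
    long_row, short_row = await asyncio.gather(
        echo_ids(input_ids=np.array([[1, 2, 3]])),
        echo_ids(input_ids=np.array([[4, 5]])),
    )
    print(long_row)   # [[1 2 3]]
    print(short_row)  # [[0 4 5]]


asyncio.run(main())
```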
batched/inference/helper.py (13 additions, 2 deletions)
@@ -56,13 +56,18 @@ def torch_or_np(item: Any):
raise ValueError(msg)


def stack_features(inputs: list[dict[str, Feature]], pad_tokens: dict[str, int]) -> dict[str, Feature]:
def stack_features(
inputs: list[dict[str, Feature]],
pad_tokens: dict[str, int],
padding_side: str = "right",
) -> dict[str, Feature]:
"""
Stack a list of model features into a single batch.

Args:
inputs (list[ModelFeatures]): List of input features to stack.
pad_tokens (dict[str, int]): Dictionary of padding tokens for each feature.
padding_side (str): The side on which to add padding tokens, either "left" or "right". Defaults to "right".

Returns:
ModelFeatures: Stacked features as a single batch.
@@ -76,7 +81,13 @@ def stack_features(inputs: list[dict[str, Feature]], pad_tokens: dict[str, int])
for i, item in enumerate(inputs):
for key, tensor in padded_tensors.items():
tensor_length = item[key].shape[0]
tensor[i, :tensor_length] = item[key]
if padding_side == "left":
# Left padding: fill from the right side (end of sequence)
start_idx = max_length - tensor_length
tensor[i, start_idx:] = item[key]
else:
# Right padding: fill from the left side (start of sequence)
tensor[i, :tensor_length] = item[key]

return padded_tensors

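The slice arithmetic in the left branch is easy to sanity-check in isolation. A self-contained NumPy sketch of the same fill logic, using hypothetical inputs and a pad token of 0:

```python
import numpy as np

# Two hypothetical 1-D features of unequal length.
items = [np.array([1, 2, 3]), np.array([4, 5])]
max_length = max(item.shape[0] for item in items)
padded = np.zeros((len(items), max_length), dtype=items[0].dtype)  # pad token 0

for i, item in enumerate(items):
    n = item.shape[0]
    # Left padding writes each item into the last n slots, so pad tokens
    # accumulate at the front; right padding would use padded[i, :n] = item.
    padded[i, max_length - n:] = item

print(padded)  # [[1 2 3]
               #  [0 4 5]]
```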
batched/inference/model_batch_processor.py (7 additions, 0 deletions)
@@ -27,6 +27,7 @@ def __init__(
timeout_ms: float = 5.0,
small_batch_threshold: int = 8,
pad_tokens: Optional[dict[str, int]] = None,
padding_side: str = "right",
spread_kwargs: bool = False,
):
"""
@@ -38,6 +39,7 @@ def __init__(
timeout_ms (float): The timeout in milliseconds between batch generation attempts. Defaults to 5.0.
small_batch_threshold (int): The threshold for considering a batch as small. Defaults to 8.
pad_tokens (dict[str, int] | None): Dictionary of padding tokens for each feature. Defaults to None.
padding_side (str): The side on which to add padding tokens, either "left" or "right". Defaults to "right".
spread_kwargs (bool): Whether to spread the kwargs over passing dict as args. Defaults to False.
"""
super().__init__(
@@ -48,6 +50,7 @@ def __init__(
)

self.pad_tokens = pad_tokens or {}
self.padding_side = padding_side
self.spread_kwargs = spread_kwargs

def _process_batches(self):
@@ -67,6 +70,7 @@ def _process_batches(self):
batch_inputs = stack_features(
[item.content for item in batch],
pad_tokens=self.pad_tokens,
padding_side=self.padding_side,
)

batch_outputs = self.batch_func(**batch_inputs) if self.spread_kwargs else self.batch_func(batch_inputs)
@@ -144,6 +148,7 @@ def dynamically(
timeout_ms: float = 5.0,
small_batch_threshold: int = 8,
pad_tokens: Optional[dict[str, int]] = None,
padding_side: str = "right",
spread_kwargs: bool = False,
) -> Callable:
"""
@@ -159,6 +164,7 @@ def dynamically(
timeout_ms (float): The timeout in milliseconds between batch generation attempts. Defaults to 5.0.
small_batch_threshold (int): The threshold to give priority to small batches. Defaults to 8.
pad_tokens (dict[str, int] | None): Dictionary of padding tokens for each feature. Defaults to None.
padding_side (str): The side on which to add padding tokens, either "left" or "right". Defaults to "right".
spread_kwargs (bool): Whether to spread the kwargs over passing dict as args. Defaults to False.

Returns:
@@ -193,6 +199,7 @@ def make_processor(_func: BatchInfer) -> ModelBatchProcessor:
timeout_ms=timeout_ms,
small_batch_threshold=small_batch_threshold,
pad_tokens=pad_tokens,
padding_side=padding_side,
spread_kwargs=spread_kwargs,
)

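The synchronous path mirrors the async one. A rough sketch, assuming the sync decorator blocks until its batch has been processed (import path again taken from the diff):

```python
import numpy as np

from batched.inference.model_batch_processor import dynamically


@dynamically(batch_size=2, pad_tokens={"input_ids": 0}, padding_side="left")
def echo_ids(features):
    return features["input_ids"]


# A single call forms its own batch, so nothing is padded here; calls
# arriving from other threads within timeout_ms would be stacked and
# left-padded together, as in the async example above.
print(echo_ids(input_ids=np.array([[7, 8]])))  # [[7 8]]
```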
tests/aio/inference/aio_inference_batch_processor_test.py (235 additions, 0 deletions)
@@ -83,3 +83,238 @@ async def batch_func(features: dict[str, Feature]) -> Feature:
np.testing.assert_array_equal(result, np.array([[1, 2], [3, 4], [5, 6], [7, 8]]))
assert processor.stats.total_batches == 2
assert processor.stats.total_processed == 4


@pytest.mark.asyncio
async def test_async_model_batch_processor_with_padding_side_left():
"""Test AsyncModelBatchProcessor with left padding."""
async def batch_func(features: dict[str, Feature]) -> Feature:
return features["input_ids"]

processor = AsyncModelBatchProcessor(
batch_func,
batch_size=2,
pad_tokens={"input_ids": 0},
padding_side="left"
)

# Use asyncio.gather to ensure concurrent execution
results = await asyncio.gather(
processor(input_ids=np.array([[1, 2, 3]])),
processor(input_ids=np.array([[4, 5]]))
)

# Both inputs should be batched together and padded to length 3
# The second input should be left-padded: [4, 5] -> [0, 4, 5]
expected_results = [
np.array([[1, 2, 3]]),
np.array([[0, 4, 5]]) # Left-padded
]

# Sort results by first element to ensure consistent ordering
results = sorted(results, key=lambda x: x[0, 0])
expected_results = sorted(expected_results, key=lambda x: x[0, 0])

for result, expected in zip(results, expected_results):
np.testing.assert_array_equal(result, expected)


@pytest.mark.asyncio
async def test_async_model_batch_processor_with_padding_side_right():
"""Test AsyncModelBatchProcessor with explicit right padding."""
async def batch_func(features: dict[str, Feature]) -> Feature:
return features["input_ids"]

processor = AsyncModelBatchProcessor(
batch_func,
batch_size=3,
pad_tokens={"input_ids": 0},
padding_side="right"
)

# Process different length sequences concurrently
results = await asyncio.gather(
processor(input_ids=np.array([[1, 2, 3]])),
processor(input_ids=np.array([[4, 5]]))
)

# The results should be padded correctly
expected_results = [
np.array([[1, 2, 3]]),
np.array([[4, 5, 0]])
]

# Check that we got the expected results (in any order)
assert len(results) == 2
found_first = any(np.array_equal(r, expected_results[0]) for r in results)
found_second = any(np.array_equal(r, expected_results[1]) for r in results)
assert found_first and found_second, f"Expected results not found in {results}"


@pytest.mark.asyncio
async def test_async_model_batch_processor_multiple_keys_left_padding():
"""Test AsyncModelBatchProcessor with multiple keys and left padding."""
async def batch_func(features: dict[str, Feature]) -> Feature:
return {
"input_ids": features["input_ids"],
"attention_mask": features["attention_mask"]
}

processor = AsyncModelBatchProcessor(
batch_func,
batch_size=2,
pad_tokens={"input_ids": 0, "attention_mask": 0},
padding_side="left"
)

# Use asyncio.gather to ensure concurrent execution
results = await asyncio.gather(
processor(
input_ids=np.array([[1, 2]]),
attention_mask=np.array([[1, 1]])
),
processor(
input_ids=np.array([[3, 4, 5]]),
attention_mask=np.array([[1, 1, 1]])
)
)

# Both inputs should be batched together and padded to length 3
# The first input should be left-padded: [1, 2] -> [0, 1, 2]
expected_results = [
{
"input_ids": np.array([[0, 1, 2]]),
"attention_mask": np.array([[0, 1, 1]])
},
{
"input_ids": np.array([[3, 4, 5]]),
"attention_mask": np.array([[1, 1, 1]])
}
]

# Sort results by first element to ensure consistent ordering
results = sorted(results, key=lambda x: x["input_ids"][0, 0])
expected_results = sorted(expected_results, key=lambda x: x["input_ids"][0, 0])

for result, expected in zip(results, expected_results):
np.testing.assert_array_equal(result["input_ids"], expected["input_ids"])
np.testing.assert_array_equal(result["attention_mask"], expected["attention_mask"])


@pytest.mark.asyncio
async def test_async_dynamically_decorator_with_padding_side_left():
"""Test async dynamically decorator with left padding."""
@dynamically(batch_size=2, pad_tokens={"input_ids": 0}, padding_side="left")
async def batch_func(features: dict[str, Feature]) -> Feature:
return features["input_ids"]

# Use asyncio.gather to ensure concurrent execution
results = await asyncio.gather(
batch_func(input_ids=np.array([[1, 2, 3]])),
batch_func(input_ids=np.array([[4, 5]]))
)

# Both inputs should be batched together and padded to length 3
# The second input should be left-padded: [4, 5] -> [0, 4, 5]
expected_results = [
np.array([[1, 2, 3]]),
np.array([[0, 4, 5]]) # Left-padded
]

# Sort results by first element to ensure consistent ordering
results = sorted(results, key=lambda x: x[0, 0])
expected_results = sorted(expected_results, key=lambda x: x[0, 0])

for result, expected in zip(results, expected_results):
np.testing.assert_array_equal(result, expected)


@pytest.mark.asyncio
async def test_async_dynamically_decorator_with_padding_side_right():
"""Test async dynamically decorator with right padding."""
@dynamically(batch_size=2, pad_tokens={"input_ids": 0}, padding_side="right")
async def batch_func(features: dict[str, Feature]) -> Feature:
return features["input_ids"]

# Use asyncio.gather to ensure concurrent execution
results = await asyncio.gather(
batch_func(input_ids=np.array([[1, 2, 3]])),
batch_func(input_ids=np.array([[4, 5]]))
)

# Both inputs should be batched together and padded to length 3
# The second input should be right-padded: [4, 5] -> [4, 5, 0]
expected_results = [
np.array([[1, 2, 3]]),
np.array([[4, 5, 0]]) # Right-padded
]

# Sort results by first element to ensure consistent ordering
results = sorted(results, key=lambda x: x[0, 0])
expected_results = sorted(expected_results, key=lambda x: x[0, 0])

for result, expected in zip(results, expected_results):
np.testing.assert_array_equal(result, expected)


@pytest.mark.asyncio
async def test_async_model_batch_processor_padding_side_initialization():
"""Test that padding_side is properly stored during initialization."""
async def dummy_batch_func(features: dict[str, Feature]) -> Feature:
return features["input"]

processor = AsyncModelBatchProcessor(
dummy_batch_func,
batch_size=32,
timeout_ms=5.0,
small_batch_threshold=8,
padding_side="left"
)

assert processor.padding_side == "left"

# Test default
processor_default = AsyncModelBatchProcessor(dummy_batch_func)
assert processor_default.padding_side == "right"


@pytest.mark.asyncio
async def test_async_model_batch_processor_concurrent_calls_with_padding():
"""Test AsyncModelBatchProcessor with concurrent calls and left padding."""
async def batch_func(features: dict[str, Feature]) -> Feature:
await asyncio.sleep(0.01) # Simulate async processing
return features["input_ids"]

processor = AsyncModelBatchProcessor(
batch_func,
batch_size=5,
timeout_ms=50.0,
pad_tokens={"input_ids": 0},
padding_side="left"
)

# Create inputs of different lengths
inputs = [
np.array([[1, 2, 3]]),
np.array([[4, 5]]),
np.array([[6, 7, 8, 9]]),
np.array([[10, 11]]),
np.array([[12, 13, 14]])
]

# Execute concurrent calls
results = await asyncio.gather(*[
processor(input_ids=inp) for inp in inputs
])

# Verify results - all should be left-padded to max length (4)
expected = [
np.array([[0, 1, 2, 3]]),
np.array([[0, 0, 4, 5]]),
np.array([[6, 7, 8, 9]]),
np.array([[0, 0, 10, 11]]),
np.array([[0, 12, 13, 14]])
]

for result, exp in zip(results, expected):
np.testing.assert_array_equal(result, exp)