Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions integrations/python/dataloader/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
requires = ["hatchling", "hatch-vcs"]
build-backend = "hatchling.build"

[tool.hatch.metadata]
allow-direct-references = true

[project]
name = "openhouse.dataloader"
dynamic = ["version"]
Expand All @@ -13,7 +10,7 @@ readme = "README.md"
requires-python = ">=3.12"
license = {text = "BSD-2-Clause"}
keywords = ["openhouse", "data-loader", "lakehouse", "iceberg", "datafusion"]
dependencies = ["datafusion==51.0.0", "pyiceberg @ git+https://github.com/sumedhsakdeo/iceberg-python@75ba28bfc6d8bbeac398357c6db80327632a2dc8", "requests>=2.31.0", "sqlglot>=29.0.0", "tenacity>=8.0.0"]
dependencies = ["datafusion==51.0.0", "pyiceberg~=0.11.0", "requests>=2.31.0", "sqlglot>=29.0.0", "tenacity>=8.0.0"]

[project.optional-dependencies]
dev = ["responses>=0.25.0", "ruff>=0.9.0", "pytest>=8.0.0", "twine>=6.0.0", "mypy>=1.14.0", "types-requests>=2.31.0"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def __init__(
filters: Filter | None = None,
context: DataLoaderContext | None = None,
max_attempts: int = 3,
batch_size: int | None = None,
):
"""
Args:
Expand All @@ -92,10 +91,6 @@ def __init__(
filters: Row filter expression, defaults to always_true() (all rows)
context: Data loader context
max_attempts: Total number of attempts including the initial try (default 3)
batch_size: Maximum number of rows per RecordBatch yielded by each split.
Passed to PyArrow's Scanner which produces batches of at most this many
rows. Smaller values reduce peak memory but increase per-batch overhead.
None uses the PyArrow default (~131K rows).
"""
if branch is not None and branch.strip() == "":
raise ValueError("branch must not be empty or whitespace")
Expand All @@ -108,7 +103,6 @@ def __init__(
self._filters = filters if filters is not None else always_true()
self._context = context or DataLoaderContext()
self._max_attempts = max_attempts
self._batch_size = batch_size

@cached_property
def _iceberg_table(self) -> Table:
Expand Down Expand Up @@ -203,5 +197,4 @@ def __iter__(self) -> Iterator[DataLoaderSplit]:
scan_context=scan_context,
transform_sql=transform_sql,
udf_registry=self._context.udf_registry,
batch_size=self._batch_size,
)
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from datafusion.context import SessionContext
from pyarrow import RecordBatch
from pyiceberg.io.pyarrow import ArrowScan
from pyiceberg.table import ArrivalOrder, FileScanTask
from pyiceberg.table import FileScanTask

from openhouse.dataloader._table_scan_context import TableScanContext
from openhouse.dataloader.table_identifier import TableIdentifier
Expand Down Expand Up @@ -56,13 +56,11 @@ def __init__(
scan_context: TableScanContext,
transform_sql: str | None = None,
udf_registry: UDFRegistry | None = None,
batch_size: int | None = None,
):
self._file_scan_task = file_scan_task
self._scan_context = scan_context
self._transform_sql = transform_sql
self._udf_registry = udf_registry or NoOpRegistry()
self._batch_size = batch_size

@property
def id(self) -> str:
Expand All @@ -81,8 +79,7 @@ def __iter__(self) -> Iterator[RecordBatch]:
"""Reads the file scan task and yields Arrow RecordBatches.

Uses PyIceberg's ArrowScan to handle format dispatch, schema resolution,
delete files, and partition spec lookups. The number of batches loaded
into memory at once is bounded to prevent using too much memory at once.
delete files, and partition spec lookups.
"""
ctx = self._scan_context
arrow_scan = ArrowScan(
Expand All @@ -92,10 +89,7 @@ def __iter__(self) -> Iterator[RecordBatch]:
row_filter=ctx.row_filter,
)

batches = arrow_scan.to_record_batches(
[self._file_scan_task],
order=ArrivalOrder(concurrent_streams=1, batch_size=self._batch_size),
)
batches = arrow_scan.to_record_batches([self._file_scan_task])

if self._transform_sql is None:
yield from batches
Expand Down
11 changes: 3 additions & 8 deletions integrations/python/dataloader/tests/integration_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,18 +166,13 @@ def read_token() -> str:
snap1 = OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table=TABLE_ID).snapshot_id
assert snap1 is not None

# 4. Read all data with batch_size and verify batch count
loader = OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table=TABLE_ID, batch_size=2)
batches = [batch for split in loader for batch in split]
assert len(batches) == 2, f"Expected 2 batches (3 rows, batch_size=2), got {len(batches)}"
for batch in batches:
assert batch.num_rows <= 2
result = pa.concat_tables([pa.Table.from_batches([b]) for b in batches]).sort_by(COL_ID)
# 4. Read all data
result = _read_all(OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table=TABLE_ID))
assert result.num_rows == 3
assert result.column(COL_ID).to_pylist() == [1, 2, 3]
assert result.column(COL_NAME).to_pylist() == ["alice", "bob", "charlie"]
assert result.column(COL_SCORE).to_pylist() == [1.1, 2.2, 3.3]
print(f"PASS: read all {result.num_rows} rows in {len(batches)} batches (batch_size=2)")
print(f"PASS: read all {result.num_rows} rows")

# 5a. Row filter
loader = OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table=TABLE_ID, filters=col(COL_ID) > 1)
Expand Down
137 changes: 0 additions & 137 deletions integrations/python/dataloader/tests/test_arrival_order.py

This file was deleted.

27 changes: 0 additions & 27 deletions integrations/python/dataloader/tests/test_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,30 +541,3 @@ def fake_scan(**kwargs):
branch_splits = list(OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", branch="my-branch"))
assert len(branch_splits) == 1
assert branch_splits[0]._file_scan_task.file.file_path == "branch.parquet"


# --- batch_size tests ---


def test_batch_size_forwarded_to_splits(tmp_path):
    """Check that every split yielded by the loader carries the requested batch_size."""
    expected_batch_size = 32768
    produced = list(
        OpenHouseDataLoader(
            catalog=_make_real_catalog(tmp_path),
            database="db",
            table="tbl",
            batch_size=expected_batch_size,
        )
    )

    # Require at least one split so the per-split check cannot pass vacuously.
    assert len(produced) >= 1
    assert all(item._batch_size == expected_batch_size for item in produced)


def test_batch_size_default_is_none(tmp_path):
    """Check that splits carry a batch_size of None when the loader is given none."""
    produced = list(
        OpenHouseDataLoader(
            catalog=_make_real_catalog(tmp_path),
            database="db",
            table="tbl",
        )
    )

    # Require at least one split so the per-split check cannot pass vacuously.
    assert len(produced) >= 1
    assert all(item._batch_size is None for item in produced)
46 changes: 0 additions & 46 deletions integrations/python/dataloader/tests/test_data_loader_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def _create_test_split(
transform_sql: str | None = None,
table_id: TableIdentifier = _DEFAULT_TABLE_ID,
udf_registry: UDFRegistry | None = None,
batch_size: int | None = None,
) -> DataLoaderSplit:
"""Create a DataLoaderSplit for testing by writing data to disk.

Expand Down Expand Up @@ -104,7 +103,6 @@ def _create_test_split(
scan_context=scan_context,
transform_sql=transform_sql,
udf_registry=udf_registry,
batch_size=batch_size,
)


Expand Down Expand Up @@ -398,47 +396,3 @@ def test_transform_with_quoted_identifier(tmp_path):

assert result.num_rows == 1
assert result.column("name").to_pylist() == ["MASKED"]


# --- batch_size tests ---

# Schema with a single non-required long field named "id" (field_id=1),
# passed to _create_test_split by the batch_size tests in this section.
_BATCH_SCHEMA = Schema(
NestedField(field_id=1, name="id", field_type=LongType(), required=False),
)


def _make_table(num_rows: int) -> pa.Table:
    """Build a single-column Arrow table whose int64 "id" column holds 0..num_rows-1."""
    id_values = pa.array(list(range(num_rows)), type=pa.int64())
    return pa.table({"id": id_values})


def test_split_batch_size_limits_rows_per_batch(tmp_path):
    """Verify that a 100-row split read with batch_size=10 yields batches of at most 10 rows."""
    split = _create_test_split(
        tmp_path,
        _make_table(100),
        FileFormat.PARQUET,
        _BATCH_SCHEMA,
        batch_size=10,
    )

    chunks = list(split)
    row_counts = [chunk.num_rows for chunk in chunks]

    assert len(chunks) >= 2, "Expected multiple batches with batch_size=10 and 100 rows"
    assert all(count <= 10 for count in row_counts)
    # The batches together must still account for every row written.
    assert sum(row_counts) == 100


def test_split_batch_size_none_returns_all_rows(tmp_path):
    """Verify that a split created without a batch_size argument still returns all 50 rows."""
    split = _create_test_split(
        tmp_path,
        _make_table(50),
        FileFormat.PARQUET,
        _BATCH_SCHEMA,
    )

    combined = pa.Table.from_batches(list(split))
    ids = combined.column("id").to_pylist()

    assert combined.num_rows == 50
    # Sort before comparing: the test does not rely on row order.
    assert sorted(ids) == list(range(50))


def test_split_batch_size_preserves_data(tmp_path):
    """Verify that reading 25 rows with batch_size=7 returns every id exactly once."""
    split = _create_test_split(
        tmp_path,
        _make_table(25),
        FileFormat.PARQUET,
        _BATCH_SCHEMA,
        batch_size=7,
    )

    combined = pa.Table.from_batches(list(split))
    ids = combined.column("id").to_pylist()

    assert combined.num_rows == 25
    # Sort before comparing: the test does not rely on row order.
    assert sorted(ids) == list(range(25))
Loading