Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 91 additions & 30 deletions elt-common/src/elt_common/iceberg/schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import itertools
from typing import Sequence
from typing import Collection

import pyarrow as pa
from pyiceberg.schema import Schema
Expand All @@ -9,18 +8,20 @@
DateType,
DecimalType,
DoubleType,
IcebergType,
IntegerType,
ListType,
LongType,
NestedField,
PrimitiveType,
StringType,
TimeType,
StructType,
TimestampType,
TimestamptzType,
TimeType,
)


def arrow_type_to_iceberg(arrow_type: pa.DataType) -> PrimitiveType:
def arrow_type_to_iceberg(arrow_type: pa.DataType, field_id: int = 1) -> IcebergType:
"""Returns the Iceberg type for the given pyarrow data type.

:raises TypeError: If the type is unknown or is not supported
Expand Down Expand Up @@ -58,6 +59,25 @@ def arrow_type_to_iceberg(arrow_type: pa.DataType) -> PrimitiveType:
or pa.types.is_fixed_size_binary(arrow_type)
):
return BinaryType()

elif pa.types.is_list(arrow_type):
# The list itself uses field_id, the lists element type uses the subsequent id
element_type = arrow_type_to_iceberg(arrow_type.value_type, field_id + 1)

# HACK: element_required is set to false because of difficulties getting object
# list elements from JSON to be optional.
# Not sure if this is a limitation of pyarrow or I just didn't find the right incantation
Comment thread
martyngigg marked this conversation as resolved.
return ListType(element_id=field_id, element_type=element_type, element_required=False)

elif pa.types.is_struct(arrow_type):
iceberg_fields = []
next_field_id = field_id
for subfield in arrow_type.fields:
iceberg_field = arrow_field_to_iceberg(arrow_field=subfield, column_id=next_field_id)
next_field_id = get_max_field_id(iceberg_field) + 1
iceberg_fields.append(iceberg_field)

return StructType(*iceberg_fields)
else:
raise TypeError(f"Pyarrow type '{arrow_type}' unknown to type mapper.")

Expand All @@ -66,50 +86,91 @@ def arrow_field_to_iceberg(column_id: int, arrow_field: pa.Field) -> NestedField
return NestedField(
field_id=column_id,
name=arrow_field.name,
field_type=arrow_type_to_iceberg(arrow_field.type),
field_type=arrow_type_to_iceberg(arrow_field.type, column_id + 1),
required=not arrow_field.nullable,
)


def create_schema(arrow_schema: pa.Schema, identifier_fields: Sequence[str] = ()) -> Schema:
def create_schema(arrow_schema: pa.Schema, identifier_fields: Collection[str] = ()) -> Schema:
"""Convert a pyarrow schema into an iceberg schema

:param arrow_schema: A pyarrow schema.
:param identifier_fields: An optional list of fields to mark as identifiers
"""
iceberg_fields, identifier_field_ids = [], []
for index, arrow_field in enumerate(arrow_schema):
col_id = index + 1
iceberg_fields.append(arrow_field_to_iceberg(col_id, arrow_field))
col_id = 1
for arrow_field in arrow_schema:
field = arrow_field_to_iceberg(col_id, arrow_field)

iceberg_fields.append(field)
if arrow_field.name in identifier_fields:
identifier_field_ids.append(col_id)

col_id = get_max_field_id(field) + 1
Comment thread
martyngigg marked this conversation as resolved.

return Schema(*iceberg_fields, identifier_field_ids=identifier_field_ids)


def evolve_schema(iceberg_schema: Schema, new_arrow_schema: pa.Schema) -> Schema | None:
"""Attempt to evolve the schema to match the data.

Returns the new schema if updates were applied, else None
Only new fields are considered backwards compatible. This is less permissive
than should be allowed - renaming fields, reordering files, and some type/property
changes could also be allowed - but iceberg rejects the changes when trying
to actually write to the table.

:returns: None if the schema didn't change, or the new schema if it did (in a backward compatible way).
:raises ValueError: If the schema has incompatible changes.
"""
existing_columns = set(iceberg_schema.column_names)
new_columns = set(new_arrow_schema.names) - existing_columns
if new_columns:
num_existing_fields = len(iceberg_schema.fields)

return Schema(
*(
itertools.chain(
iceberg_schema.fields,
[
arrow_field_to_iceberg(
num_existing_fields + index + 1, new_arrow_schema.field(name)
)
for index, name in enumerate(new_arrow_schema.names)
if name in new_columns
],
new_iceberg_schema = create_schema(new_arrow_schema, iceberg_schema.identifier_field_names())

if new_iceberg_schema == iceberg_schema:
return None
else:
# If there are incompatible changes, throw an error
incompatibilities = []
for f in iceberg_schema.fields:
try:
new_field = new_iceberg_schema.find_field(f.field_id)
except ValueError:
incompatibilities.append(f"Field id {f.field_id} removed")
continue

if f.name != new_field.name:
incompatibilities.append(
f"Field {f.field_id} changed name from '{f.name}' to '{new_field.name}'"
)
)
)
elif f.field_type != new_field.field_type:
incompatibilities.append(
f"Field '{f.name}' (id: {f.field_id}) changed type from '{f.field_type}' to '{new_field.field_type}'"
)
elif new_field.required != f.required:
incompatibilities.append(
f"Field '{f.name}' (id: {f.field_id}) 'required' changed to {new_field.required}"
)

if incompatibilities:
raise ValueError(f"Incompatible changes to schema: {incompatibilities}")

# The new schema is different, but backwards compatible
return new_iceberg_schema


def get_max_field_id(f: NestedField) -> int:
"""Return the largest field_id from an Iceberg field.

- For primitive fields this is just the field_id
- For list fields this is the larget field_id from the list's element type
- For struct fields this is the largest id across all of its subfields (potentially recursively)
"""
if f.field_type.is_primitive:
return f.field_id
elif isinstance(f.field_type, StructType):
struct_fields = f.field_type.fields
if not struct_fields:
return f.field_id
return max(get_max_field_id(sub) for sub in struct_fields)
elif isinstance(f.field_type, ListType):
return get_max_field_id(f.field_type.element_field)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
else:
return None
raise ValueError("Can only get fields ids for primitive, list, and struct fields")
129 changes: 107 additions & 22 deletions elt-common/tests/unit_tests/iceberg/test_schema.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,46 @@
"""Tests for elt_common.iceberg.schema"""

import pyarrow as pa
from pyiceberg.schema import Schema, NestedField
import pytest
from pyiceberg.schema import NestedField, Schema
from pyiceberg.types import (
BinaryType,
BooleanType,
DateType,
DecimalType,
DoubleType,
IntegerType,
ListType,
LongType,
StringType,
TimeType,
StructType,
TimestampType,
TimestamptzType,
TimeType,
)
import pytest

from elt_common.iceberg.schema import arrow_type_to_iceberg, create_schema, evolve_schema


arrow_fields = [
pa.field("row_id", pa.int64(), nullable=False),
pa.field("entry_name", pa.string(), nullable=False),
pa.field("entry_timestamp", pa.timestamp(unit="us")),
pa.field("entry_weight", pa.float64()),
]


@pytest.fixture()
def arrow_schema() -> pa.Schema:
return pa.schema(
[
pa.field("row_id", pa.int64(), nullable=False),
pa.field("entry_name", pa.string(), nullable=False),
pa.field("entry_timestamp", pa.timestamp(unit="us")),
pa.field("entry_weight", pa.float64()),
]
)
return pa.schema(arrow_fields)


iceberg_fields = [
NestedField(field_id=1, name="row_id", field_type=LongType(), required=True),
NestedField(field_id=2, name="entry_name", field_type=StringType(), required=True),
NestedField(field_id=3, name="entry_timestamp", field_type=TimestampType()),
NestedField(field_id=4, name="entry_weight", field_type=DoubleType()),
]


def test_unsupported_arrow_type_raises():
Expand All @@ -54,13 +65,61 @@ def test_unsupported_arrow_type_raises():
(pa.binary(), BinaryType),
(pa.large_binary(), BinaryType),
(pa.binary(8), BinaryType),
(pa.struct([("test", pa.int32())]), StructType),
(pa.struct([("nested", pa.struct([("test", pa.int32())]))]), StructType),
(pa.list_(pa.int32()), ListType),
(
pa.list_(
pa.struct(
[
("list_of_structs", pa.list_(pa.struct([("a", pa.int32())]))),
("something", pa.binary()),
]
)
),
ListType,
),
],
)
def test_returns_expected_iceberg_type(arrow_type, expected_type):
result = arrow_type_to_iceberg(arrow_type)
assert isinstance(result, expected_type)


def test_arrow_type_to_iceberg_nested_fields():
arrow_type = pa.struct(
[
(
"a",
pa.list_(
pa.struct(
[
("b", pa.int32()),
("c", pa.string()),
]
)
),
),
("d", pa.struct([("e", pa.timestamp("ms"))])),
]
)
result = arrow_type_to_iceberg(arrow_type)
assert isinstance(result, StructType)
assert len(result.fields) == 2

list_field = result.fields[0]
assert isinstance(list_field.field_type, ListType)
list_struct = list_field.field_type.element_type
assert isinstance(list_struct, StructType)
assert len(list_struct.fields) == 2
assert isinstance(list_struct.fields[0].field_type, IntegerType)
assert isinstance(list_struct.fields[1].field_type, StringType)

struct_field = result.fields[1]
assert isinstance(struct_field.field_type, StructType)
assert isinstance(struct_field.field_type.fields[0].field_type, TimestampType)


def test_maps_decimal_precision_and_scale():
result = arrow_type_to_iceberg(pa.decimal128(12, 3))

Expand Down Expand Up @@ -101,29 +160,55 @@ def test_create_iceberg_schema(arrow_schema: pa.Schema, identifier_fields):


@pytest.mark.parametrize(
["iceberg_field_names", "expected_new_field_names"],
["iceberg_field_idxs", "expected_new_field_names"],
[
([], {"row_id", "entry_name", "entry_timestamp", "entry_weight"}),
([], ["row_id", "entry_name", "entry_timestamp", "entry_weight"]),
(
["row_id", "entry_name", "entry_timestamp"],
{"row_id", "entry_name", "entry_timestamp", "entry_weight"},
[0, 1, 2],
["row_id", "entry_name", "entry_timestamp", "entry_weight"],
),
(["row_id", "entry_name", "entry_timestamp", "entry_weight"], {}),
([0, 1, 2, 3], []),
],
)
def test_evolve_schema(
arrow_schema: pa.Schema, iceberg_field_names: list[str], expected_new_field_names
arrow_schema: pa.Schema, iceberg_field_idxs: list[int], expected_new_field_names
):
existing_fields = [
NestedField(field_id=i + 1, name=name, field_type=StringType(), required=False)
for i, name in enumerate(iceberg_field_names)
]
existing_fields = [iceberg_fields[i] for i in iceberg_field_idxs]
existing_schema = Schema(*existing_fields)

schema_with_new_fields = evolve_schema(existing_schema, arrow_schema)

if expected_new_field_names:
assert schema_with_new_fields is not None
assert {f.name for f in schema_with_new_fields.fields} == expected_new_field_names
assert [f.name for f in schema_with_new_fields.fields] == expected_new_field_names
else:
assert schema_with_new_fields is None


@pytest.mark.parametrize(
["iceberg_field_idxs", "new_fields"],
[
# Fields removed
([0], []),
([0], arrow_fields[1:]),
([0, 1], arrow_fields[:1]),
([0, 1, 2], arrow_fields[:2]),
([0], [arrow_fields[1]]),
([1, 2], arrow_fields[2:4]),
# Fields reordered
([0, 1], [arrow_fields[1], arrow_fields[0]]),
([3, 2, 1], [arrow_fields[1], arrow_fields[3], arrow_fields[2]]),
# Field property changed
([0], [pa.field("row_id_renamed", pa.int64(), nullable=False)]),
([0], [pa.field("row_id", pa.int32(), nullable=False)]),
([0], [pa.field("row_id", pa.int64(), nullable=True)]),
],
)
def test_evolve_schema_incompatible(iceberg_field_idxs, new_fields):
existing_fields = [iceberg_fields[i] for i in iceberg_field_idxs]
existing_schema = Schema(*existing_fields)

new_schema = pa.schema(new_fields)

with pytest.raises(ValueError):
evolve_schema(existing_schema, new_schema)
Comment thread
martyngigg marked this conversation as resolved.
Loading