src/nested_pandas/series/accessor.py (2 changes: 1 addition & 1 deletion)
@@ -171,7 +171,7 @@ def columns(self) -> list[str]:
@property
def flat_index(self) -> pd.Index:
"""Index of the flattened arrays"""
flat_index = np.repeat(self._series.index, np.diff(self._series.array.list_offsets))
flat_index = np.repeat(self._series.index, self._series.array.list_lengths)
# pd.Index supports np.repeat, so flat_index is the same type as self._series.index
flat_index = cast(pd.Index, flat_index)
return flat_index
src/nested_pandas/series/ext_array.py (199 changes: 170 additions & 29 deletions)
@@ -63,7 +63,10 @@
from nested_pandas.series.dtype import NestedDtype
from nested_pandas.series.nestedseries import NestedSeries # noqa
from nested_pandas.series.utils import (
_MAX_FLAT_SIZE,
chunk_lengths,
chunk_sizes_are_fragmented,
compute_chunk_boundaries,
is_pa_type_a_list,
normalize_list_array,
normalize_struct_list_type,
@@ -74,6 +77,15 @@

__all__ = ["NestedExtensionArray"]

DEFAULT_CHUNK_SIZE = 1_048_576
"""Target number of outer rows per chunk for :meth:`NestedExtensionArray.rechunk`.

Matches PyArrow's default Parquet row group size.
"""

DEFAULT_MIN_CHUNK_SIZE = 4_096
"""Minimum outer rows per chunk before :meth:`NestedExtensionArray.is_fragmented` flags excess chunks."""


BOXED_NESTED_EXTENSION_ARRAY_FORMAT_TRICK = True
"""Use a trick to by-pass pandas limitations on extension array formatting
@@ -306,13 +318,13 @@ def __getitem__(self, item: ScalarIndexer) -> Self | pd.DataFrame: # type: igno
return type(self)(pa.chunked_array([], type=self.dtype.pyarrow_dtype), validate=False)
pa_item = pa.array(item)
if item.dtype.kind in "iu":
return type(self)(self.struct_array.take(pa_item), validate=False)
if item.dtype.kind == "b":
return type(self)(self.struct_array.filter(pa_item), validate=False)
# It should be covered by check_array_indexer above
raise IndexError(
"Only integers, slices and integer or boolean arrays are valid indices."
) # pragma: no cover
pa_item = self.struct_array.take(pa_item)
elif item.dtype.kind == "b":
pa_item = self.struct_array.filter(pa_item)
else: # pragma: no cover
raise IndexError("Only integers, slices and integer or boolean arrays are valid indices.")
result = type(self)(pa_item, validate=False)
return result._rechunk_if_fragmented()

if isinstance(item, tuple):
item = unpack_tuple_and_ellipses(item)
@@ -325,7 +337,8 @@
return self._convert_struct_scalar_to_df(scalar_or_array, copy=False)
# Logically, it must be a pa.ChunkedArray if it is not a scalar
pa_array = cast(pa.ChunkedArray, scalar_or_array)
return type(self)(pa_array, validate=False)
result = type(self)(pa_array, validate=False)
return result._rechunk_if_fragmented()

def __setitem__(self, key, value) -> None:
# TODO: optimize for many chunks
@@ -532,22 +545,28 @@ def take(
fill_mask = indices_array < 0
if not fill_mask.any():
# Nothing to fill, using list-array should be faster
return type(self)(self.list_array.take(indices))
result = type(self)(self.list_array.take(indices))
return result._rechunk_if_fragmented()
validate_indices(indices_array, len(self))
indices_array = pa.array(indices_array, mask=fill_mask)

result = self.struct_array.take(indices_array)
taken = self.struct_array.take(indices_array)
if not pa.compute.is_null(fill_value).as_py():
result = pa.compute.if_else(fill_mask, fill_value, result)
taken = pa.compute.if_else(fill_mask, fill_value, taken)
# Validate for fill_value
return type(self)(result, validate=True)
result = type(self)(taken, validate=True)
return result._rechunk_if_fragmented()

if (indices_array < 0).any():
# Don't modify in-place
indices_array = np.copy(indices_array)
indices_array[indices_array < 0] += len(self)
# list_array should be faster
return type(self)(self.list_array.take(indices_array))
result = type(self)(self.list_array.take(indices_array))
return result._rechunk_if_fragmented()

def _rechunk_if_fragmented(self) -> Self: # type: ignore[name-defined] # noqa: F821
return self.rechunk() if self.is_fragmented() else self

def copy(self) -> Self: # type: ignore[name-defined] # noqa: F821
"""Return a copy of the extension array.
@@ -591,7 +610,10 @@ def format_row(row):
def _concat_same_type(cls, to_concat: Sequence[Self]) -> Self: # type: ignore[name-defined] # noqa: F821
chunks = [chunk for ext_array in to_concat for chunk in ext_array.list_array.iterchunks()]
pa_array = pa.chunked_array(chunks)
return cls(pa_array)
result = cls(pa_array)
if result.is_fragmented():
return result.rechunk()
return result
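# Hedged usage note: repeated concatenation (e.g. building a nested series by
# appending a few rows at a time) yields one chunk per input array; once the
# trailing run of small chunks accumulates DEFAULT_MIN_CHUNK_SIZE outer rows,
# is_fragmented() turns True and the concatenated result is rebalanced here.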

def equals(self, other) -> bool:
"""
@@ -663,12 +685,6 @@ def __array__(self, dtype=None, copy=True):

return array

# Adopted from ArrowExtensionArray
def __getstate__(self):
state = self.__dict__.copy()
state["_storage"] = ListStructStorage(self.list_array.combine_chunks())
return state

# End of Additional magic methods #

@classmethod
@@ -964,17 +980,22 @@ def list_offsets(self) -> pa.Array:

Returns
-------
pa.ChunkedArray
The list offsets of the field arrays.
pa.Array
Cumulative offsets of length ``len(self) + 1``. For a single-chunk
array the dtype is ``int32`` (matching the underlying
``pa.ListArray`` buffer). For a multi-chunk array the dtype is
``int64`` to avoid silent overflow when the total number of flat
rows exceeds ``2**31``.
"""
# Cheap path for a single chunk
if self._storage.num_chunks == 1:
return self.list_array.chunk(0).offsets

# Use int64 to avoid overflow when total flat rows exceed 2^31.
zero_and_lengths = pa.chunked_array(
[
pa.array([0], type=pa.int32()),
pa.array(self.list_lengths, type=pa.int32()),
pa.array([0], type=pa.int64()),
pa.array(self.list_lengths, type=pa.int64()),
]
)
offsets = pa.compute.cumulative_sum(zero_and_lengths)
@@ -1005,6 +1026,90 @@ def num_chunks(self) -> int:
"""Number of chunk_lens in underlying pyarrow.ChunkedArray"""
return self._storage.num_chunks

def is_fragmented(
self,
min_chunk_size: int | None = None,
) -> bool:
"""Check whether :meth:`rechunk` would improve memory layout.

Returns ``True`` if any of the following hold:

- Any "body" chunk (all chunks except a trailing run of small ones)
has fewer than ``min_chunk_size`` outer rows.
- The trailing run of small chunks ("tail") has accumulated enough
rows in total to form a proper chunk (i.e. total >= ``min_chunk_size``).

The tail allowance amortises detection cost for incremental-append
workloads: a rechunk is triggered roughly once per ``min_chunk_size``
single-row appends rather than after every second append.

Parameters
----------
min_chunk_size : int or None
Minimum acceptable number of outer rows per chunk.
Defaults to ``DEFAULT_MIN_CHUNK_SIZE`` (4_096).

Returns
-------
bool
"""
if min_chunk_size is None:
min_chunk_size = DEFAULT_MIN_CHUNK_SIZE
sizes = [len(c) for c in self.list_array.iterchunks()]
return chunk_sizes_are_fragmented(sizes, min_chunk_size)

def _chunk_boundaries(self, chunk_size: int) -> list[int]:
"""Compute chunk boundaries for this array.

Uses the total flat size from chunk metadata (O(n_chunks)) to avoid
computing list_lengths when there is no int32 overflow risk.
"""
n = len(self)
total_flat = sum(len(chunk.values) for chunk in self.list_array.chunks)
if total_flat <= _MAX_FLAT_SIZE:
# No overflow risk: boundaries are pure arithmetic, no need to compute list_lengths.
return list(range(0, n, chunk_size)) + [n]
return compute_chunk_boundaries(self.list_lengths, chunk_size)
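# Hedged worked example of the arithmetic fast path: for len(self) == 10 and
# chunk_size == 4, range(0, 10, 4) gives [0, 4, 8], so the boundaries are
# [0, 4, 8, 10] and rechunk() builds chunks of 4, 4 and 2 outer rows.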

def rechunk(self, chunk_size: int | None = None) -> NestedExtensionArray: # type: ignore[name-defined] # noqa: F821
"""Rechunk the array to approximately ``chunk_size`` outer rows per chunk.

Parameters
----------
chunk_size : int or None
Target number of outer rows per chunk. If ``None``, uses
``DEFAULT_CHUNK_SIZE`` (PyArrow's default Parquet row group size,
1_048_576). Chunks may be smaller when the flat (inner) row count
would overflow int32 offsets (see ``_MAX_FLAT_SIZE``).

Returns
-------
NestedExtensionArray
A new array with rebalanced chunks, or ``self`` if already clean.
"""
if chunk_size is None:
chunk_size = DEFAULT_CHUNK_SIZE
if chunk_size < 1:
raise ValueError("chunk_size must be >= 1")

boundaries = self._chunk_boundaries(chunk_size)

# Fast-path: already in the exact layout rechunk would produce.
# Build expected sizes from boundaries and compare to current chunk sizes.
expected_sizes = [boundaries[i + 1] - boundaries[i] for i in range(len(boundaries) - 1)]
actual_sizes = [len(c) for c in self.list_array.iterchunks()]
if actual_sizes == expected_sizes:
return self

combined = self.list_array
chunks = [
# combine_chunks() slices combined[start:end] into a fresh ListArray
# with zero-based offsets, guaranteeing no int32 overflow.
combined[start:end].combine_chunks()
for start, end in zip(boundaries[:-1], boundaries[1:], strict=True)
]
return type(self)(pa.chunked_array(chunks, type=combined.type))
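# Hedged usage sketch (hypothetical values; assumes the constructor accepts a
# struct-of-list ChunkedArray, as used elsewhere in this module, and does not
# itself rechunk):
#   pa_type = pa.struct({"a": pa.list_(pa.int64())})
#   tiny = pa.chunked_array([pa.array([{"a": [1, 2]}], type=pa_type)] * 10_000)
#   arr = NestedExtensionArray(tiny)          # 10_000 one-row chunks
#   arr.rechunk().num_chunks                  # 1, since 10_000 <= DEFAULT_CHUNK_SIZE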

def get_list_index(self) -> np.ndarray:
"""Keys mapping values to lists"""
if len(self) == 0:
@@ -1102,11 +1207,47 @@ def set_flat_field(self, field: str, value: ArrayLike, *, keep_dtype: bool = Fal
if len(pa_array) != self.flat_length:
raise ValueError("The input must be a struct_scalar or have the same length as the flat arrays")

if isinstance(pa_array, pa.ChunkedArray):
pa_array = pa_array.combine_chunks()
field_list_array = pa.ListArray.from_arrays(values=pa_array, offsets=self.list_offsets)

return self.set_list_field(field, field_list_array, keep_dtype=keep_dtype)
# Convert flat input to ChunkedArray for uniform handling
flat_chunked = pa.chunked_array([pa_array]) if isinstance(pa_array, pa.Array) else pa_array

# Build one ListArray per outer chunk, slicing the flat input accordingly.
# This avoids a global combine_chunks() of the flat array and a global
# list_offsets computation.
list_chunks = []
flat_offset = 0
for outer_chunk in self.list_array.iterchunks():
outer_chunk = cast(pa.ListArray, outer_chunk)
offsets = outer_chunk.offsets
flat_start = offsets[0].as_py()
flat_end = offsets[-1].as_py()
chunk_flat_size = flat_end - flat_start

flat_slice_chunked = flat_chunked[flat_offset : flat_offset + chunk_flat_size]
# Avoid an unnecessary copy when the slice already lands on a single
# contiguous chunk (the common path when both self and pa_array are
# single-chunk arrays).
if flat_slice_chunked.num_chunks == 1:
flat_slice = flat_slice_chunked.chunk(0)
else:
flat_slice = flat_slice_chunked.combine_chunks()
flat_offset += chunk_flat_size

# Normalize offsets to start at 0 (required when outer_chunk is a view
# with non-zero offsets[0] pointing into a larger flat buffer)
chunk_offsets = pa.compute.subtract(offsets, offsets[0]) if flat_start != 0 else offsets

list_chunks.append(pa.ListArray.from_arrays(offsets=chunk_offsets, values=flat_slice))

new_list_array = pa.chunked_array(list_chunks, type=pa.list_(pa_array.type))

# Update the table directly — set_list_field would call pa.array(value, ...)
# which does not handle ChunkedArray inputs reliably.
if field in self.field_names:
field_idx = self.field_names.index(field)
pa_table = self.pa_table.drop(field).add_column(field_idx, field, new_list_array)
else:
pa_table = self.pa_table.append_column(field, new_list_array)
self.pa_table = pa_table
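# Hedged illustration of the per-chunk slicing above: if the outer chunks have
# list lengths [2, 1] and [3] (flat sizes 3 and 3), a flat input of length 6 is
# consumed as flat[0:3] for the first chunk and flat[3:6] for the second, and
# each slice is wrapped into a ListArray using that chunk's re-zeroed offsets.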

def set_list_field(self, field: str, value: ArrayLike, *, keep_dtype: bool = False) -> None:
"""Set the field from list-array