Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 29 additions & 32 deletions src/xarray_regrid/regrid.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from collections.abc import Hashable
from typing import Any, overload
from typing import Any, Literal, overload

import numpy as np
import xarray as xr

from xarray_regrid.methods import conservative, flox_reduce, interp
from xarray_regrid.utils import format_for_regrid

InterpMethod = Literal["linear", "nearest", "cubic"]


@xr.register_dataarray_accessor("regrid")
@xr.register_dataset_accessor("regrid")
Expand Down Expand Up @@ -40,9 +42,7 @@ def linear(
Returns:
Data regridded to the target dataset coordinates.
"""
ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
ds_formatted = format_for_regrid(self._obj, ds_target_grid)
return interp.interp_regrid(ds_formatted, ds_target_grid, "linear")
return self._interp(ds_target_grid, "linear", time_dim)

def nearest(
self,
Expand All @@ -59,9 +59,7 @@ def nearest(
Returns:
Data regridded to the target dataset coordinates.
"""
ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
ds_formatted = format_for_regrid(self._obj, ds_target_grid)
return interp.interp_regrid(ds_formatted, ds_target_grid, "nearest")
return self._interp(ds_target_grid, "nearest", time_dim)

def cubic(
self,
Expand All @@ -78,9 +76,17 @@ def cubic(
Returns:
Data regridded to the target dataset coordinates.
"""
return self._interp(ds_target_grid, "cubic", time_dim)

def _interp(
self,
ds_target_grid: xr.Dataset,
method: InterpMethod,
time_dim: str | None,
) -> xr.DataArray | xr.Dataset:
ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
ds_formatted = format_for_regrid(self._obj, ds_target_grid)
return interp.interp_regrid(ds_formatted, ds_target_grid, "cubic")
return interp.interp_regrid(ds_formatted, ds_target_grid, method)

def conservative(
self,
Expand Down Expand Up @@ -160,27 +166,7 @@ def most_common(
Returns:
Regridded data.
"""
ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)

if isinstance(self._obj, xr.Dataset):
msg = (
"The 'most common value' regridder is not implemented for\n",
"xarray.Dataset, as it requires specifying the expected labels.\n"
"Please select only a single variable (as DataArray),\n"
" and regrid it separately.",
)
raise ValueError(msg)

ds_formatted = format_for_regrid(self._obj, ds_target_grid, stats=True)

return flox_reduce.compute_mode(
ds_formatted,
ds_target_grid,
values,
time_dim,
fill_value,
anti_mode=False,
)
return self._mode(ds_target_grid, values, time_dim, fill_value, anti_mode=False)

def least_common(
self,
Expand Down Expand Up @@ -212,14 +198,25 @@ def least_common(
Returns:
Regridded data.
"""
return self._mode(ds_target_grid, values, time_dim, fill_value, anti_mode=True)

def _mode(
self,
ds_target_grid: xr.Dataset,
values: np.ndarray,
time_dim: str | None,
fill_value: None | Any,
anti_mode: bool,
) -> xr.DataArray:
ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)

if isinstance(self._obj, xr.Dataset):
label = "least common value" if anti_mode else "most common value"
msg = (
"The 'least common value' regridder is not implemented for\n",
f"The '{label}' regridder is not implemented for\n"
"xarray.Dataset, as it requires specifying the expected labels.\n"
"Please select only a single variable (as DataArray),\n"
" and regrid it separately.",
" and regrid it separately."
)
raise ValueError(msg)

Expand All @@ -231,7 +228,7 @@ def least_common(
values,
time_dim,
fill_value,
anti_mode=True,
anti_mode=anti_mode,
)

def stat(
Expand Down