From 47fa08d3558aea8c5b7232ce08bb1b70280bb9fb Mon Sep 17 00:00:00 2001 From: thodson-usgs Date: Fri, 17 Apr 2026 21:05:30 -0500 Subject: [PATCH] Collapse duplicated accessor methods The three interp methods (linear/nearest/cubic) had identical 3-line bodies differing only in the method string; extract a private `_interp`. most_common/least_common were near-identical with an `anti_mode` flag already threaded through to `compute_mode`; extract a private `_mode` that also builds the error label from the same flag. No behavior change; public API unchanged. --- src/xarray_regrid/regrid.py | 61 ++++++++++++++++++------------------- 1 file changed, 29 insertions(+), 32 deletions(-) diff --git a/src/xarray_regrid/regrid.py b/src/xarray_regrid/regrid.py index b2ed389..872253a 100644 --- a/src/xarray_regrid/regrid.py +++ b/src/xarray_regrid/regrid.py @@ -1,5 +1,5 @@ from collections.abc import Hashable -from typing import Any, overload +from typing import Any, Literal, overload import numpy as np import xarray as xr @@ -7,6 +7,8 @@ from xarray_regrid.methods import conservative, flox_reduce, interp from xarray_regrid.utils import format_for_regrid +InterpMethod = Literal["linear", "nearest", "cubic"] + @xr.register_dataarray_accessor("regrid") @xr.register_dataset_accessor("regrid") @@ -40,9 +42,7 @@ def linear( Returns: Data regridded to the target dataset coordinates. """ - ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim) - ds_formatted = format_for_regrid(self._obj, ds_target_grid) - return interp.interp_regrid(ds_formatted, ds_target_grid, "linear") + return self._interp(ds_target_grid, "linear", time_dim) def nearest( self, @@ -59,9 +59,7 @@ def nearest( Returns: Data regridded to the target dataset coordinates. """ - ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim) - ds_formatted = format_for_regrid(self._obj, ds_target_grid) - return interp.interp_regrid(ds_formatted, ds_target_grid, "nearest") + return self._interp(ds_target_grid, "nearest", time_dim) def cubic( self, @@ -78,9 +76,17 @@ def cubic( Returns: Data regridded to the target dataset coordinates. """ + return self._interp(ds_target_grid, "cubic", time_dim) + + def _interp( + self, + ds_target_grid: xr.Dataset, + method: InterpMethod, + time_dim: str | None, + ) -> xr.DataArray | xr.Dataset: ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim) ds_formatted = format_for_regrid(self._obj, ds_target_grid) - return interp.interp_regrid(ds_formatted, ds_target_grid, "cubic") + return interp.interp_regrid(ds_formatted, ds_target_grid, method) def conservative( self, @@ -160,27 +166,7 @@ def most_common( Returns: Regridded data. """ - ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim) - - if isinstance(self._obj, xr.Dataset): - msg = ( - "The 'most common value' regridder is not implemented for\n", - "xarray.Dataset, as it requires specifying the expected labels.\n" - "Please select only a single variable (as DataArray),\n" - " and regrid it separately.", - ) - raise ValueError(msg) - - ds_formatted = format_for_regrid(self._obj, ds_target_grid, stats=True) - - return flox_reduce.compute_mode( - ds_formatted, - ds_target_grid, - values, - time_dim, - fill_value, - anti_mode=False, - ) + return self._mode(ds_target_grid, values, time_dim, fill_value, anti_mode=False) def least_common( self, @@ -212,14 +198,25 @@ def least_common( Returns: Regridded data. """ + return self._mode(ds_target_grid, values, time_dim, fill_value, anti_mode=True) + + def _mode( + self, + ds_target_grid: xr.Dataset, + values: np.ndarray, + time_dim: str | None, + fill_value: None | Any, + anti_mode: bool, + ) -> xr.DataArray: ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim) if isinstance(self._obj, xr.Dataset): + label = "least common value" if anti_mode else "most common value" msg = ( - "The 'least common value' regridder is not implemented for\n", + f"The '{label}' regridder is not implemented for\n" "xarray.Dataset, as it requires specifying the expected labels.\n" "Please select only a single variable (as DataArray),\n" - " and regrid it separately.", + " and regrid it separately." ) raise ValueError(msg) @@ -231,7 +228,7 @@ def least_common( values, time_dim, fill_value, - anti_mode=True, + anti_mode=anti_mode, ) def stat(