diff --git a/pyproject.toml b/pyproject.toml index 50ecd81..a874a51 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -176,7 +176,6 @@ exclude_lines = [ [tool.mypy] python_version = "3.10" -plugins = "numpy.typing.mypy_plugin" disallow_untyped_defs = true disallow_any_unimported = true no_implicit_optional = true diff --git a/src/xarray_regrid/methods/_shared.py b/src/xarray_regrid/methods/_shared.py index a606976..dff1240 100644 --- a/src/xarray_regrid/methods/_shared.py +++ b/src/xarray_regrid/methods/_shared.py @@ -2,7 +2,7 @@ import warnings from collections.abc import Hashable -from typing import Any, overload +from typing import Any, cast, overload import numpy as np import pandas as pd @@ -70,7 +70,7 @@ def restore_properties( else: result = result.where(covered, fill_value) - return result.transpose(*original_data.dims) + return cast("xr.DataArray | xr.Dataset", result.transpose(*original_data.dims)) @overload diff --git a/src/xarray_regrid/methods/conservative.py b/src/xarray_regrid/methods/conservative.py index 2ab67d9..5cdc945 100644 --- a/src/xarray_regrid/methods/conservative.py +++ b/src/xarray_regrid/methods/conservative.py @@ -1,7 +1,7 @@ """Conservative regridding implementation.""" from collections.abc import Hashable -from typing import overload +from typing import cast, overload import numpy as np import xarray as xr @@ -98,7 +98,7 @@ def conservative_regrid( regridded_data = regridded_data.reindex_like(target_ds, copy=False) - return regridded_data + return cast("xr.DataArray | xr.Dataset", regridded_data) def conservative_regrid_dataset( @@ -241,7 +241,7 @@ def apply_spherical_correction( latitude_res = np.median(np.diff(dot_array[latitude_coord].to_numpy(), 1)) lat_weights = lat_weight(dot_array[latitude_coord].to_numpy(), latitude_res) da.values = utils.normalize_overlap(dot_array.values * lat_weights[:, np.newaxis]) - return da + return cast("xr.DataArray", da) def lat_weight(latitude: np.ndarray, latitude_res: float) -> np.ndarray: @@ -299,4 +299,4 @@ def format_weights( elif sparse is not None: new_weights.data = sparse.COO(weights.data) - return new_weights + return cast("xr.DataArray", new_weights) diff --git a/src/xarray_regrid/methods/interp.py b/src/xarray_regrid/methods/interp.py index 3dc81a0..6304459 100644 --- a/src/xarray_regrid/methods/interp.py +++ b/src/xarray_regrid/methods/interp.py @@ -1,6 +1,6 @@ """Methods based on xr.interp.""" -from typing import Literal, overload +from typing import Literal, cast, overload import xarray as xr @@ -49,4 +49,4 @@ def interp_regrid( for coord in coord_names: interped[coord].attrs = coord_attrs[coord] - return interped + return cast("xr.DataArray | xr.Dataset", interped) diff --git a/src/xarray_regrid/utils.py b/src/xarray_regrid/utils.py index 6979f84..d554a34 100644 --- a/src/xarray_regrid/utils.py +++ b/src/xarray_regrid/utils.py @@ -1,6 +1,6 @@ from collections.abc import Callable, Hashable from dataclasses import dataclass -from typing import Any, TypedDict, overload +from typing import Any, TypedDict, cast, overload import numpy as np import pandas as pd @@ -219,9 +219,12 @@ def call_on_dataset( msg = "Trying to convert Dataset with more than one data variable to DataArray" if len(result.data_vars) > 1: raise TypeError(msg) - return next(iter(result.data_vars.values())).rename(obj.name) + return cast( + "xr.DataArray", + next(iter(result.data_vars.values())).rename(obj.name), + ) - return result + return cast("xr.DataArray | xr.Dataset", result) @overload @@ -283,7 +286,7 @@ def format_for_regrid( if len(obj[var].chunksizes.get(coord, ())) == 1: result[var] = result[var].chunk({coord: -1}) - return result + return cast("xr.DataArray | xr.Dataset", result) def format_lat( @@ -322,7 +325,7 @@ def format_lat( south_pole = obj.isel({lat_coord: 0}) if lon_coord is not None: south_pole = south_pole.mean(lon_coord, keep_attrs=True) - obj = xr.concat([south_pole, obj], dim=lat_coord) # type: ignore + obj = xr.concat([south_pole, obj], dim=lat_coord) lat_vals = np.concatenate([[-polar_lat], lat_vals]) # North pole @@ -330,7 +333,7 @@ def format_lat( north_pole = obj.isel({lat_coord: -1}) if lon_coord is not None: north_pole = north_pole.mean(lon_coord, keep_attrs=True) - obj = xr.concat([obj, north_pole], dim=lat_coord) # type: ignore + obj = xr.concat([obj, north_pole], dim=lat_coord) lat_vals = np.concatenate([lat_vals, [polar_lat]]) obj = update_coord(obj, lat_coord, lat_vals)