Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ exclude_lines = [

[tool.mypy]
python_version = "3.10"
plugins = "numpy.typing.mypy_plugin"
disallow_untyped_defs = true
disallow_any_unimported = true
no_implicit_optional = true
Expand Down
4 changes: 2 additions & 2 deletions src/xarray_regrid/methods/_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import warnings
from collections.abc import Hashable
from typing import Any, overload
from typing import Any, cast, overload

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -70,7 +70,7 @@ def restore_properties(
else:
result = result.where(covered, fill_value)

return result.transpose(*original_data.dims)
return cast("xr.DataArray | xr.Dataset", result.transpose(*original_data.dims))


@overload
Expand Down
8 changes: 4 additions & 4 deletions src/xarray_regrid/methods/conservative.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Conservative regridding implementation."""

from collections.abc import Hashable
from typing import overload
from typing import cast, overload

import numpy as np
import xarray as xr
Expand Down Expand Up @@ -98,7 +98,7 @@ def conservative_regrid(

regridded_data = regridded_data.reindex_like(target_ds, copy=False)

return regridded_data
return cast("xr.DataArray | xr.Dataset", regridded_data)


def conservative_regrid_dataset(
Expand Down Expand Up @@ -241,7 +241,7 @@ def apply_spherical_correction(
latitude_res = np.median(np.diff(dot_array[latitude_coord].to_numpy(), 1))
lat_weights = lat_weight(dot_array[latitude_coord].to_numpy(), latitude_res)
da.values = utils.normalize_overlap(dot_array.values * lat_weights[:, np.newaxis])
return da
return cast("xr.DataArray", da)


def lat_weight(latitude: np.ndarray, latitude_res: float) -> np.ndarray:
Expand Down Expand Up @@ -299,4 +299,4 @@ def format_weights(
elif sparse is not None:
new_weights.data = sparse.COO(weights.data)

return new_weights
return cast("xr.DataArray", new_weights)
4 changes: 2 additions & 2 deletions src/xarray_regrid/methods/interp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Methods based on xr.interp."""

from typing import Literal, overload
from typing import Literal, cast, overload

import xarray as xr

Expand Down Expand Up @@ -49,4 +49,4 @@ def interp_regrid(
for coord in coord_names:
interped[coord].attrs = coord_attrs[coord]

return interped
return cast("xr.DataArray | xr.Dataset", interped)
15 changes: 9 additions & 6 deletions src/xarray_regrid/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Callable, Hashable
from dataclasses import dataclass
from typing import Any, TypedDict, overload
from typing import Any, TypedDict, cast, overload

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -219,9 +219,12 @@ def call_on_dataset(
msg = "Trying to convert Dataset with more than one data variable to DataArray"
if len(result.data_vars) > 1:
raise TypeError(msg)
return next(iter(result.data_vars.values())).rename(obj.name)
return cast(
"xr.DataArray",
next(iter(result.data_vars.values())).rename(obj.name),
)

return result
return cast("xr.DataArray | xr.Dataset", result)


@overload
Expand Down Expand Up @@ -283,7 +286,7 @@ def format_for_regrid(
if len(obj[var].chunksizes.get(coord, ())) == 1:
result[var] = result[var].chunk({coord: -1})

return result
return cast("xr.DataArray | xr.Dataset", result)


def format_lat(
Expand Down Expand Up @@ -322,15 +325,15 @@ def format_lat(
south_pole = obj.isel({lat_coord: 0})
if lon_coord is not None:
south_pole = south_pole.mean(lon_coord, keep_attrs=True)
obj = xr.concat([south_pole, obj], dim=lat_coord) # type: ignore
obj = xr.concat([south_pole, obj], dim=lat_coord)
lat_vals = np.concatenate([[-polar_lat], lat_vals])

# North pole
if polar_lat - dy <= obj.coords[lat_coord].values[-1] < polar_lat:
north_pole = obj.isel({lat_coord: -1})
if lon_coord is not None:
north_pole = north_pole.mean(lon_coord, keep_attrs=True)
obj = xr.concat([obj, north_pole], dim=lat_coord) # type: ignore
obj = xr.concat([obj, north_pole], dim=lat_coord)
lat_vals = np.concatenate([lat_vals, [polar_lat]])

obj = update_coord(obj, lat_coord, lat_vals)
Expand Down