Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
ccb7c4f
Add pressure_dimension_str arg to geopotential_thickness (#297)
aaTman Jan 11, 2026
d5ca07c
`DurationMeanError` memory fix and add time resolution option (#296)
aaTman Jan 13, 2026
f2d8cd2
Move parallel config check outside of function (#301)
aaTman Jan 13, 2026
d09f456
feat: Forecast wrapper for custom xarray datasets (#302)
darothen Jan 13, 2026
1035230
Simplify IBTrACS polars subset (#303)
aaTman Jan 14, 2026
4027de8
Update `geopotential_thickness` var names and docstring (#306)
aaTman Jan 14, 2026
1afff35
Clarify default preprocess function names; geopotential division fix …
aaTman Jan 14, 2026
5d239b8
Remove "cases" key requirement in yamls and dicts (#308)
aaTman Jan 15, 2026
4793532
remove out-of-date notebook from docs
aaTman Jan 15, 2026
ed3b9e6
CIRA Icechunk store (#310)
aaTman Jan 15, 2026
342668a
update pyproject and uv lock
aaTman Jan 16, 2026
fcd329e
Merge pull request #311 from brightbandtech:chore/upgrade-scores
aaTman Jan 16, 2026
89a9bdb
add TODO
aaTman Jan 17, 2026
94eeb98
update PR template
aaTman Jan 19, 2026
0fabca5
Remove `IndividualCaseCollection` (#317)
aaTman Jan 22, 2026
63f8b13
Cleanup docstrings in repo (#318)
aaTman Jan 24, 2026
5350563
add explanation for dim reqs (#320)
aaTman Jan 24, 2026
71f3a0b
Update `defaults` and `inputs` to include new CIRA icechunk store (#319)
aaTman Jan 24, 2026
79f4b74
Bump version from 0.2.0 to 0.3.0 (#324)
aaTman Jan 26, 2026
b1fc1b6
Updated API (#321)
aaTman Jan 26, 2026
34a3965
Golden tests (#323)
aaTman Jan 26, 2026
8979bed
PyPI Preparation (#315)
aaTman Jan 26, 2026
54f5c35
update pyproject version for release
aaTman Jan 26, 2026
c367043
Merge branch 'main' into develop
aaTman Jan 26, 2026
43fafbf
Remove duplicate function and fixtures (#326)
aaTman Jan 27, 2026
8445eba
Merge branch 'main' into develop
aaTman Mar 2, 2026
42c176e
update install docs for pypi and fix typo in link
aaTman Mar 2, 2026
9794aaf
Update ar calcs with improved parallelization (#329)
aaTman Mar 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 29 additions & 50 deletions data_prep/ar_bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,7 @@
from dask.distributed import Client
from matplotlib.patches import Rectangle

import extremeweatherbench.cases as cases
import extremeweatherbench.derived as derived
import extremeweatherbench.inputs as inputs
import extremeweatherbench.regions as regions
import extremeweatherbench.utils as utils
from extremeweatherbench import calc, cases, derived, inputs, regions, utils
from extremeweatherbench.events import atmospheric_river as ar

logging.basicConfig()
Expand Down Expand Up @@ -413,31 +409,6 @@ def find_timestamp_peak_field(
return peak_time_idx, peak_ivt_value


def create_composite_ar_mask(
    ar_mask: xr.DataArray,
    land_intersection: Optional[xr.DataArray] = None,
) -> Tuple[xr.DataArray, Optional[xr.DataArray]]:
    """Collapse time-resolved AR masks into single composite masks.

    The composite is the elementwise maximum over the time dimension, so a
    grid point is flagged if it belonged to an AR at any timestep.

    Args:
        ar_mask: Binary AR mask with a time dimension ("valid_time" or "time").
        land_intersection: Optional land intersection mask, reduced over the
            same time dimension as ``ar_mask``.

    Returns:
        Tuple of (composite_ar_mask, composite_land_intersection); the second
        element is None when no land intersection was supplied.
    """
    # Prefer "valid_time" when present, otherwise fall back to "time".
    reduce_dim = "valid_time" if "valid_time" in ar_mask.dims else "time"

    composite_land = (
        None
        if land_intersection is None
        else land_intersection.max(dim=reduce_dim)
    )
    return ar_mask.max(dim=reduce_dim), composite_land


def expand_bounds_to_contiguous_ar(
ar_mask: xr.DataArray,
land_intersection: xr.DataArray,
Expand Down Expand Up @@ -835,7 +806,23 @@ def process_ar_event(
era5_data = era5_data.sel(
valid_time=era5_data.valid_time.dt.hour.isin([0, 6, 12, 18])
)
era5_data = inputs.maybe_subset_variables(era5_data, variables=era5_ar.variables)

# Rough check, if you're adding variables beyond AtmosphericRiverVariables or
# specific_humidity, u_component..., v_component..., this will fail
if isinstance(era5_ar.variables, list):
if len(era5_ar.variables) > 1 and isinstance(
era5_ar.variables[0], derived.DerivedVariable
):
raise ValueError(
"Only accepted variables are derived.AtmosphericRiverVariables"
" or [specific_humidity, u_component_of_wind, v_component_of_wind]."
)
elif isinstance(era5_ar.variables[0], derived.DerivedVariable):
era5_ar_vars = era5_ar.variables[0].variables
else:
era5_ar_vars = era5_ar.variables

era5_data = inputs.maybe_subset_variables(era5_data, variables=era5_ar_vars)
era5_subset = era5_ar.subset_data_to_case(era5_data, case)
era5_subset = era5_subset.chunk()
# Generate IVT
Expand All @@ -844,19 +831,17 @@ def process_ar_event(
specific_humidity=era5_subset["specific_humidity"],
eastward_wind=era5_subset["eastward_wind"],
northward_wind=era5_subset["northward_wind"],
levels=era5_subset["adjusted_level"],
)
).persist()
ivt_da.name = "integrated_vapor_transport"
# Compute IVT Laplacian
ivt_laplacian = ar.integrated_vapor_transport_laplacian(ivt_da)
ivt_laplacian = ar.integrated_vapor_transport_laplacian(ivt_da).compute()
ivt_laplacian.name = "integrated_vapor_transport_laplacian"

# Compute AR mask
ar_mask = ar.atmospheric_river_mask(
ivt=ivt_da,
ivt_laplacian=ivt_laplacian,
min_size_gridpoints=AR_OBJECT_CONFIG["min_area_gridpoints"],
)
).compute()

# Generate land mask for peak time finding
logger.info(" Generating land mask...")
Expand Down Expand Up @@ -902,20 +887,14 @@ def process_ar_event(

# Create composite AR mask over entire time range (max)
logger.info(" Creating composite AR mask over time...")
composite_ar_mask, composite_land_intersection = create_composite_ar_mask(
ar_mask, land_intersection=land_mask
composite_ar_mask = ar_mask.max(
dim=["valid_time" if "valid_time" in ar_mask.dims else "time"]
)

logger.info(
" Composite AR mask has %s grid points", composite_ar_mask.sum().values
)
logger.info(
" Composite land intersection has %s grid points",
composite_land_intersection.sum().values
if composite_land_intersection is not None
else 0,
)

composite_land_intersection = calc.find_land_intersection(composite_ar_mask)
# Find bounds using composite mask & expand to contiguous AR
left_lon, right_lon, bottom_lat, top_lat, largest_obj_metadata = (
find_ar_bounds_from_largest_object(
Expand Down Expand Up @@ -1052,11 +1031,11 @@ def main():
# shapes
}
if parallel:
with joblib.parallel_backend("dask"):
ar_bounds_results_enhanced = joblib.Parallel(n_jobs=len(ar_events))(
joblib.delayed(process_ar_event)(single_case, era5_ar, AR_OBJECT_CONFIG)
for single_case in ar_events
)
# Fixing the n_jobs arbitrarily; change as desired
ar_bounds_results_enhanced = joblib.Parallel(n_jobs=8)(
joblib.delayed(process_ar_event)(single_case, era5_ar, AR_OBJECT_CONFIG)
for single_case in ar_events
)
else:
# Run in serial using a list comprehension
ar_bounds_results_enhanced = [
Expand Down
10 changes: 4 additions & 6 deletions docs/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,18 @@
You can install ExtremeWeatherBench using:

```
uv pip install git+https://github.com/brightband/ExtremeWeatherBench.git
uv sync --all-extras
uv add extremeweatherbench
```

## Using pip
Alternatively, you can use pip in the same fashion (using a virtual environment is recommended):
```
pip install git+https://github.com/brightband/ExtremeWeatherBench.git
pip install extremeweatherbench
```

Alternatively, for developers, you can install the package in development mode:

For development, an easy approach is to clone the repository and use `uv` to sync all required and optional dependencies:

```
git clone https://github.com/brightband/ExtremeWeatherBench.git
git clone https://github.com/brightbandtech/ExtremeWeatherBench.git
uv sync --all-extras
```
129 changes: 87 additions & 42 deletions src/extremeweatherbench/calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
import scores.categorical as categorical
import shapely
import xarray as xr
from numba import float64, guvectorize
from scipy import ndimage
from skimage import filters

from extremeweatherbench import utils

Expand Down Expand Up @@ -276,50 +279,60 @@ def geopotential_thickness(
return geopotential_thickness


def nantrapezoid(
y: np.ndarray,
x: np.ndarray | None = None,
dx: float = 1.0,
axis: int = -1,
):
"""Trapezoid rule for arrays with nans.
@guvectorize(
    [(float64[:], float64[:], float64[:])],
    "(n),(n)->()",
    nopython=True,
    target="parallel",
)
def _nantrapezoid_kernel(y, x, out):
    """1D nan-aware trapezoid integration kernel.

    Accumulates ``dx * (y0 + y1) / 2`` over consecutive sample pairs,
    dropping any interval where either ``y`` endpoint is NaN — that interval
    contributes nothing instead of poisoning the whole integral.

    Compiled by numba as a parallel gufunc: ``y`` (values) and ``x`` (sample
    positions) are 1D core inputs of equal length; ``out`` receives the
    scalar integral. NaNs in ``x`` are NOT filtered, so a NaN spacing still
    propagates into the result when both ``y`` endpoints are finite.
    """
    total = 0.0
    for i in range(len(y) - 1):
        y0 = y[i]
        y1 = y[i + 1]
        dx = x[i + 1] - x[i]
        # skip intervals where either endpoint is nan
        if not (np.isnan(y0) or np.isnan(y1)):
            total += dx * (y0 + y1) / 2.0
    out[()] = total


def nantrapezoid_4d(y, x):
"""
Wrapper that moves the integration axis to last position,
calls the guvectorize kernel, then returns result.

Identical to np.trapezoid but with nans handled correctly in the summation.
y: (time, level, lat, lon)
x: (level,) or same shape as y
"""
y = np.asanyarray(y)
if x is None:
# Create an array of the step size
d = np.full(y.shape[axis] - 1, dx) if y.shape[axis] > 1 else np.array([dx])
# reshape to correct shape
shape = [1] * y.ndim
shape[axis] = d.shape[0]
d = d.reshape(shape)
else:
x = np.asanyarray(x)
if x.ndim == 1:
d = np.diff(x)
# reshape to correct shape
shape = [1] * y.ndim
shape[axis] = d.shape[0]
d = d.reshape(shape)
else:
d = np.diff(x, axis=axis)
if y.ndim != d.ndim:
d = np.expand_dims(d, axis=1)
nd = y.ndim
slice1 = [slice(None)] * nd
slice2 = [slice(None)] * nd
slice1[axis] = slice(1, None)
slice2[axis] = slice(None, -1)
try:
# This is the only location different from np.trapezoid
ret = np.nansum(d * (y[tuple(slice1)] + y[tuple(slice2)]) / 2.0, axis=axis)
except ValueError:
# Operations didn't work, cast to ndarray
d = np.asarray(d)
y = np.asarray(y)
ret = np.add.reduce(d * (y[tuple(slice1)] + y[tuple(slice2)]) / 2.0, axis)
return ret

return _nantrapezoid_kernel(y, x)


def nantrapezoid_pressure_levels(da: xr.DataArray):
    """Integrate a quantity over the pressure column with a nan-aware trapezoid rule.

    Args:
        da: DataArray with dimensions (time, latitude, longitude, level),
            where the "level" coordinate is in hPa.

    Returns:
        A DataArray of the quantity integrated over the entire column, with
        the "level" dimension removed.
    """
    # hPa -> Pa so the integral comes out in SI units.
    pressure_pa = da["level"] * 100

    return xr.apply_ufunc(
        nantrapezoid_4d,
        da,
        pressure_pa,
        input_core_dims=[["level"], ["level"]],
        output_core_dims=[[]],
        # Map lazily over dask chunks instead of forcing a compute here.
        dask="parallelized",
        output_dtypes=[float],
    )


def specific_humidity_from_relative_humidity(
Expand Down Expand Up @@ -871,3 +884,35 @@ def _is_true_landfall(
# - AttributeError: invalid/None geometry
# - ValueError/TypeError: invalid coordinate values
return False


def _binary_dilation_ufunc(data: xr.DataArray, dilation_radius: int) -> xr.DataArray:
"""Apply binary dilation to a single 2D (lat, lon) slice.

Args:
data: 2D boolean array of shape (lat, lon)
dilation_radius: radius for the dilation in gridpoints

Returns:
Dilated boolean array of shape (lat, lon)
"""
size = dilation_radius * 2 + 1
struct = np.ones((size, size))
return np.expand_dims(
ndimage.binary_dilation(data.squeeze(), structure=struct).astype(np.int8),
axis=0,
)


def _compute_blurred_laplacian_ufunc(data: xr.DataArray, sigma: float) -> xr.DataArray:
    """Compute a Gaussian-smoothed Laplacian of a 2D IVT field.

    Args:
        data: 2D IVT slice to differentiate.
        sigma: Standard deviation of the Gaussian smoothing kernel.

    Returns:
        The Laplacian of ``data`` blurred with a Gaussian filter.
    """
    # Sharpen-then-smooth: the raw Laplacian is noisy, so blur it before use.
    return ndimage.gaussian_filter(filters.laplace(data), sigma=sigma)
Loading
Loading