diff --git a/src/climatebenchpress/compressor/compressors/abc.py b/src/climatebenchpress/compressor/compressors/abc.py index e20429f..effb436 100644 --- a/src/climatebenchpress/compressor/compressors/abc.py +++ b/src/climatebenchpress/compressor/compressors/abc.py @@ -2,11 +2,10 @@ from abc import ABC, abstractmethod from collections import defaultdict -from collections.abc import Mapping +from collections.abc import Mapping, Callable from dataclasses import dataclass from functools import partial from types import MappingProxyType -from typing import Callable, Optional import numpy as np from numcodecs.abc import Codec @@ -42,14 +41,14 @@ class ErrorBound: Attributes ---------- - abs_error : Optional[float] + abs_error : None | float Absolute error bound for the variable. - rel_error : Optional[float] + rel_error : None | float Relative error bound for the variable. """ - abs_error: Optional[float] = None - rel_error: Optional[float] = None + abs_error: None | float = None + rel_error: None | float = None def __post_init__(self): if self.abs_error is not None and self.rel_error is not None: @@ -85,13 +84,13 @@ class Compressor(ABC): def abs_bound_codec( error_bound: float, *, - dtype: Optional[np.dtype] = None, - data_min: Optional[float] = None, - data_max: Optional[float] = None, - data_abs_min: Optional[float] = None, - data_abs_max: Optional[float] = None, - data_min_2d: Optional[np.ndarray] = None, - data_max_2d: Optional[np.ndarray] = None, + dtype: None | np.dtype = None, + data_min: None | float = None, + data_max: None | float = None, + data_abs_min: None | float = None, + data_abs_max: None | float = None, + data_min_2d: None | np.ndarray = None, + data_max_2d: None | np.ndarray = None, ) -> Codec: """Create a codec with an absolute error bound.""" pass @@ -101,13 +100,13 @@ def abs_bound_codec( def rel_bound_codec( error_bound: float, *, - dtype: Optional[np.dtype] = None, - data_min: Optional[float] = None, - data_max: Optional[float] = None, - data_abs_min: Optional[float] = None, - data_abs_max: Optional[float] = None, - data_min_2d: Optional[np.ndarray] = None, - data_max_2d: Optional[np.ndarray] = None, + dtype: None | np.dtype = None, + data_min: None | float = None, + data_max: None | float = None, + data_abs_min: None | float = None, + data_abs_max: None | float = None, + data_min_2d: None | np.ndarray = None, + data_max_2d: None | np.ndarray = None, ) -> Codec: """Create a codec with a relative error bound.""" pass diff --git a/src/climatebenchpress/compressor/scripts/compress.py b/src/climatebenchpress/compressor/scripts/compress.py index df7d40f..8130874 100644 --- a/src/climatebenchpress/compressor/scripts/compress.py +++ b/src/climatebenchpress/compressor/scripts/compress.py @@ -4,9 +4,9 @@ import json import math import traceback -from collections.abc import Container, Mapping +from collections.abc import Callable, Container, Mapping from pathlib import Path -from typing import Callable +from typing import cast import numcodecs_observers import numpy as np @@ -195,6 +195,8 @@ def compress_decompress( measurements = dict() for v in ds: + v = str(v) + if v in exclude_variable: continue if include_variable and v not in include_variable: @@ -225,11 +227,16 @@ def compress_decompress( timing, ], ) as codec_: + codec_ = cast(CodecStack, codec_) # by duck typing only + variables[v] = codec_.encode_decode_data_array( ds[v].compute() if is_safeguarded_zero_dssim else ds[v] ).compute() - cs = [c._codec for c in codec_.__iter__()] + cs = [ + c._codec # type: ignore + for c in codec_.__iter__() + ] measurements[v] = { # bytes measurements: only look at the first and last codec in diff --git a/src/climatebenchpress/compressor/scripts/compute_metrics.py b/src/climatebenchpress/compressor/scripts/compute_metrics.py index 4cb2208..208255c 100644 --- a/src/climatebenchpress/compressor/scripts/compute_metrics.py +++ b/src/climatebenchpress/compressor/scripts/compute_metrics.py @@ -3,8 +3,8 @@ import argparse import json import re +from collections.abc import Iterable from pathlib import Path -from typing import Iterable import pandas as pd import xarray as xr diff --git a/src/climatebenchpress/compressor/scripts/concatenate_metrics.py b/src/climatebenchpress/compressor/scripts/concatenate_metrics.py index 56836ac..2c8d272 100644 --- a/src/climatebenchpress/compressor/scripts/concatenate_metrics.py +++ b/src/climatebenchpress/compressor/scripts/concatenate_metrics.py @@ -3,7 +3,6 @@ import argparse import json from pathlib import Path -from typing import Optional import pandas as pd @@ -139,7 +138,7 @@ def merge_metrics( def get_error_bound_name( variable2bound: dict[str, tuple[str, float]], - error_bound_list: list[dict[str, dict[str, Optional[float]]]], + error_bound_list: list[dict[str, dict[str, None | float]]], bound_names: list[str] = ["low", "mid", "high"], ) -> str: """The function returns either "low", "mid", or "high" depending on which error bound @@ -156,7 +155,7 @@ def get_error_bound_name( A dictionary representing a single error bound, mapping variable names to tuples of error type and error bound. The error type is either "abs_error" or "rel_error", and the error bound is a float. - error_bound_list : list[dict[str, dict[str, Optional[float]]]] + error_bound_list : list[dict[str, dict[str, None | float]]] A list of dictionaries, each representing an error bound (low, mid, high). Each dictionary contains variable names as keys and a dictionary of error types and bounds as values. diff --git a/src/climatebenchpress/compressor/scripts/create_error_bounds.py b/src/climatebenchpress/compressor/scripts/create_error_bounds.py index 26344f4..4f7665e 100644 --- a/src/climatebenchpress/compressor/scripts/create_error_bounds.py +++ b/src/climatebenchpress/compressor/scripts/create_error_bounds.py @@ -4,7 +4,6 @@ import json from dataclasses import dataclass from pathlib import Path -from typing import Optional import numpy as np import pandas as pd @@ -162,8 +161,12 @@ def create_error_bounds( decode_times=False, ) - low_error_bounds, mid_error_bounds, high_error_bounds = dict(), dict(), dict() + low_error_bounds: dict[str, dict[str, float | None]] = dict() + mid_error_bounds: dict[str, dict[str, float | None]] = dict() + high_error_bounds: dict[str, dict[str, float | None]] = dict() + for v in ds: + v = str(v) if v in VAR_NAME_TO_ERA5: low_error_bounds[v], mid_error_bounds[v], high_error_bounds[v] = ( get_error_bounds( @@ -205,7 +208,7 @@ def create_error_bounds( def get_error_bounds( error_bounds: pd.DataFrame, era5_var: str, error_bound_type: str -) -> list[dict[str, Optional[float]]]: +) -> list[dict[str, None | float]]: var_error_bounds = error_bounds[error_bounds["var"] == era5_var].copy() single_level = var_error_bounds["level"].unique()[0] == "single" if single_level: @@ -228,7 +231,7 @@ def get_error_bounds( # Ordered from strictest to most relaxed error bounds. percentiles = ["100%", "99%", "95%"] - var_ebs = [] + var_ebs: list[dict[str, None | float]] = [] for percentile in percentiles: eb_row = var_error_bounds[var_error_bounds["percentile"] == percentile] @@ -246,13 +249,13 @@ def get_error_bounds( return var_ebs -def get_no2_bounds(percentiles=[1.00, 0.99, 0.95]) -> list[dict[str, Optional[float]]]: +def get_no2_bounds(percentiles=[1.00, 0.99, 0.95]) -> list[dict[str, None | float]]: # First we need to transform the bitwise real information into a cumulative # distribution function. real_information_dist = np.cumsum(NO2_REAL_INFORMATION) / np.sum( NO2_REAL_INFORMATION ) - no2_bounds = [] + no2_bounds: list[dict[str, None | float]] = [] for p in percentiles: # Find the first position where cumulative distribution exceeds p. # Add one for 1-based indexing. @@ -275,7 +278,7 @@ def get_no2_bounds(percentiles=[1.00, 0.99, 0.95]) -> list[dict[str, Optional[fl def get_agb_bound( datasets: Path, percentiles=[1.00, 0.99, 0.95] -) -> list[dict[str, Optional[float]]]: +) -> list[dict[str, None | float]]: # Define rough bounding box coordinates for mainland France. # Format: [min_longitude, min_latitude, max_longitude, max_latitude]. FRANCE_BBOX = [-5.5, 42.3, 9.6, 51.1] @@ -295,7 +298,7 @@ def get_agb_bound( mean=agb.agb, spread=agb.agb_sd, percentile=percentiles ) - error_bounds = [] + error_bounds: list[dict[str, None | float]] = [] for a, r in zip(ensemble_bounds.absolute, ensemble_bounds.relative): if VAR_NAME_TO_ERROR_BOUND["agb"] == ABS_ERROR: error_bounds.append( @@ -330,8 +333,12 @@ def compute_ensemble_spread_bounds( spread_nonzero = spread_values[spread_values > 0.0] + absolute: list[float] if len(spread_nonzero) > 0: - absolute = np.nanquantile(spread_nonzero, [1 - p for p in percentile]) + absolute = [ + float(s) + for s in np.nanquantile(spread_nonzero, [1 - p for p in percentile]) + ] else: absolute = [0.0 for _ in percentile] @@ -339,8 +346,11 @@ def compute_ensemble_spread_bounds( rel = spread_values[abs_mean > 0.0] / abs_mean[abs_mean > 0.0] rel_nonzero = rel[rel > 0.0] + relative: list[float] if len(rel_nonzero) > 0: - relative = np.nanquantile(rel_nonzero, [1 - p for p in percentile]) + relative = [ + float(r) for r in np.nanquantile(rel_nonzero, [1 - p for p in percentile]) + ] else: relative = [0.0 for _ in percentile]