Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 19 additions & 20 deletions src/climatebenchpress/compressor/compressors/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Mapping
from collections.abc import Mapping, Callable
from dataclasses import dataclass
from functools import partial
from types import MappingProxyType
from typing import Callable, Optional

import numpy as np
from numcodecs.abc import Codec
Expand Down Expand Up @@ -42,14 +41,14 @@ class ErrorBound:

Attributes
----------
abs_error : Optional[float]
abs_error : None | float
Absolute error bound for the variable.
rel_error : Optional[float]
rel_error : None | float
Relative error bound for the variable.
"""

abs_error: Optional[float] = None
rel_error: Optional[float] = None
abs_error: None | float = None
rel_error: None | float = None

def __post_init__(self):
if self.abs_error is not None and self.rel_error is not None:
Expand Down Expand Up @@ -85,13 +84,13 @@ class Compressor(ABC):
def abs_bound_codec(
error_bound: float,
*,
dtype: Optional[np.dtype] = None,
data_min: Optional[float] = None,
data_max: Optional[float] = None,
data_abs_min: Optional[float] = None,
data_abs_max: Optional[float] = None,
data_min_2d: Optional[np.ndarray] = None,
data_max_2d: Optional[np.ndarray] = None,
dtype: None | np.dtype = None,
data_min: None | float = None,
data_max: None | float = None,
data_abs_min: None | float = None,
data_abs_max: None | float = None,
data_min_2d: None | np.ndarray = None,
data_max_2d: None | np.ndarray = None,
) -> Codec:
"""Create a codec with an absolute error bound."""
pass
Expand All @@ -101,13 +100,13 @@ def abs_bound_codec(
def rel_bound_codec(
error_bound: float,
*,
dtype: Optional[np.dtype] = None,
data_min: Optional[float] = None,
data_max: Optional[float] = None,
data_abs_min: Optional[float] = None,
data_abs_max: Optional[float] = None,
data_min_2d: Optional[np.ndarray] = None,
data_max_2d: Optional[np.ndarray] = None,
dtype: None | np.dtype = None,
data_min: None | float = None,
data_max: None | float = None,
data_abs_min: None | float = None,
data_abs_max: None | float = None,
data_min_2d: None | np.ndarray = None,
data_max_2d: None | np.ndarray = None,
) -> Codec:
"""Create a codec with a relative error bound."""
pass
Expand Down
13 changes: 10 additions & 3 deletions src/climatebenchpress/compressor/scripts/compress.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import json
import math
import traceback
from collections.abc import Container, Mapping
from collections.abc import Callable, Container, Mapping
from pathlib import Path
from typing import Callable
from typing import cast

import numcodecs_observers
import numpy as np
Expand Down Expand Up @@ -195,6 +195,8 @@ def compress_decompress(
measurements = dict()

for v in ds:
v = str(v)

if v in exclude_variable:
continue
if include_variable and v not in include_variable:
Expand Down Expand Up @@ -225,11 +227,16 @@ def compress_decompress(
timing,
],
) as codec_:
codec_ = cast(CodecStack, codec_) # by duck typing only

variables[v] = codec_.encode_decode_data_array(
ds[v].compute() if is_safeguarded_zero_dssim else ds[v]
).compute()

cs = [c._codec for c in codec_.__iter__()]
cs = [
c._codec # type: ignore
for c in codec_.__iter__()
]

measurements[v] = {
# bytes measurements: only look at the first and last codec in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import argparse
import json
import re
from collections.abc import Iterable
from pathlib import Path
from typing import Iterable

import pandas as pd
import xarray as xr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import argparse
import json
from pathlib import Path
from typing import Optional

import pandas as pd

Expand Down Expand Up @@ -139,7 +138,7 @@ def merge_metrics(

def get_error_bound_name(
variable2bound: dict[str, tuple[str, float]],
error_bound_list: list[dict[str, dict[str, Optional[float]]]],
error_bound_list: list[dict[str, dict[str, None | float]]],
bound_names: list[str] = ["low", "mid", "high"],
) -> str:
"""The function returns either "low", "mid", or "high" depending on which error bound
Expand All @@ -156,7 +155,7 @@ def get_error_bound_name(
A dictionary representing a single error bound, mapping variable names to
tuples of error type and error bound. The error type is either "abs_error"
or "rel_error", and the error bound is a float.
error_bound_list : list[dict[str, dict[str, Optional[float]]]]
error_bound_list : list[dict[str, dict[str, None | float]]]
A list of dictionaries, each representing an error bound (low, mid, high).
Each dictionary contains variable names as keys and a dictionary of error types
and bounds as values.
Expand Down
30 changes: 20 additions & 10 deletions src/climatebenchpress/compressor/scripts/create_error_bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -162,8 +161,12 @@ def create_error_bounds(
decode_times=False,
)

low_error_bounds, mid_error_bounds, high_error_bounds = dict(), dict(), dict()
low_error_bounds: dict[str, dict[str, float | None]] = dict()
mid_error_bounds: dict[str, dict[str, float | None]] = dict()
high_error_bounds: dict[str, dict[str, float | None]] = dict()

for v in ds:
v = str(v)
if v in VAR_NAME_TO_ERA5:
low_error_bounds[v], mid_error_bounds[v], high_error_bounds[v] = (
get_error_bounds(
Expand Down Expand Up @@ -205,7 +208,7 @@ def create_error_bounds(

def get_error_bounds(
error_bounds: pd.DataFrame, era5_var: str, error_bound_type: str
) -> list[dict[str, Optional[float]]]:
) -> list[dict[str, None | float]]:
var_error_bounds = error_bounds[error_bounds["var"] == era5_var].copy()
single_level = var_error_bounds["level"].unique()[0] == "single"
if single_level:
Expand All @@ -228,7 +231,7 @@ def get_error_bounds(

# Ordered from strictest to most relaxed error bounds.
percentiles = ["100%", "99%", "95%"]
var_ebs = []
var_ebs: list[dict[str, None | float]] = []
for percentile in percentiles:
eb_row = var_error_bounds[var_error_bounds["percentile"] == percentile]

Expand All @@ -246,13 +249,13 @@ def get_error_bounds(
return var_ebs


def get_no2_bounds(percentiles=[1.00, 0.99, 0.95]) -> list[dict[str, Optional[float]]]:
def get_no2_bounds(percentiles=[1.00, 0.99, 0.95]) -> list[dict[str, None | float]]:
# First we need to transform the bitwise real information into a cumulative
# distribution function.
real_information_dist = np.cumsum(NO2_REAL_INFORMATION) / np.sum(
NO2_REAL_INFORMATION
)
no2_bounds = []
no2_bounds: list[dict[str, None | float]] = []
for p in percentiles:
# Find the first position where cumulative distribution exceeds p.
# Add one for 1-based indexing.
Expand All @@ -275,7 +278,7 @@ def get_no2_bounds(percentiles=[1.00, 0.99, 0.95]) -> list[dict[str, Optional[fl

def get_agb_bound(
datasets: Path, percentiles=[1.00, 0.99, 0.95]
) -> list[dict[str, Optional[float]]]:
) -> list[dict[str, None | float]]:
# Define rough bounding box coordinates for mainland France.
# Format: [min_longitude, min_latitude, max_longitude, max_latitude].
FRANCE_BBOX = [-5.5, 42.3, 9.6, 51.1]
Expand All @@ -295,7 +298,7 @@ def get_agb_bound(
mean=agb.agb, spread=agb.agb_sd, percentile=percentiles
)

error_bounds = []
error_bounds: list[dict[str, None | float]] = []
for a, r in zip(ensemble_bounds.absolute, ensemble_bounds.relative):
if VAR_NAME_TO_ERROR_BOUND["agb"] == ABS_ERROR:
error_bounds.append(
Expand Down Expand Up @@ -330,17 +333,24 @@ def compute_ensemble_spread_bounds(

spread_nonzero = spread_values[spread_values > 0.0]

absolute: list[float]
if len(spread_nonzero) > 0:
absolute = np.nanquantile(spread_nonzero, [1 - p for p in percentile])
absolute = [
float(s)
for s in np.nanquantile(spread_nonzero, [1 - p for p in percentile])
]
else:
absolute = [0.0 for _ in percentile]

abs_mean = np.abs(mean_values)
rel = spread_values[abs_mean > 0.0] / abs_mean[abs_mean > 0.0]
rel_nonzero = rel[rel > 0.0]

relative: list[float]
if len(rel_nonzero) > 0:
relative = np.nanquantile(rel_nonzero, [1 - p for p in percentile])
relative = [
float(r) for r in np.nanquantile(rel_nonzero, [1 - p for p in percentile])
]
else:
relative = [0.0 for _ in percentile]

Expand Down
Loading