From a632bdce7234153ffc1fc096f8be2343bead0c9a Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Mon, 2 Jun 2025 15:30:12 +0100 Subject: [PATCH 01/14] Use Juniper's ensemble bounds --- .../compressor/scripts/create_error_bounds.py | 246 +++++++++++++++++- 1 file changed, 242 insertions(+), 4 deletions(-) diff --git a/src/climatebenchpress/compressor/scripts/create_error_bounds.py b/src/climatebenchpress/compressor/scripts/create_error_bounds.py index f33c84a..69be31a 100644 --- a/src/climatebenchpress/compressor/scripts/create_error_bounds.py +++ b/src/climatebenchpress/compressor/scripts/create_error_bounds.py @@ -1,10 +1,49 @@ __all__ = ["create_error_bounds"] import json +from dataclasses import dataclass from pathlib import Path +from typing import Optional +import numpy as np +import pandas as pd import xarray as xr +# Table has header: +# var,level,percentile,min,range,max,lpbits,lpabsolute,lprelative,brlqabsolute,brabsolute,brrelative,esabsolute,esrelative,esquadratic,unabsolute,cabsolute,crelative,cquadratic,pick,crlinquant,crbitround,crlinquantquadstep +# +# esabsolute and esrelative are respectively the absolute and relative error bounds +# derived from the ERA5 ensembles. +ERROR_BOUNDS = "https://raw.githubusercontent.com/juntyr/era5-ensemble/refs/heads/main/table-raw.csv?token=GHSAT0AAAAAACTGGFLKSCEPFNNUEGWSWPEA2CJOYSQ" + + +VAR_NAME_TO_ERA5 = { + # NextGEMS Icon Outgoing Longwave Radiation (OLR). + # Closest ERA5 equivalent Top net long-wave (thermal) radiation + # (https://www.ecmwf.int/sites/default/files/elibrary/2015/18490-radiation-quantities-ecmwf-model-and-mars.pdf). + # which is the negative of OLR. + # ERA5 documentation: https://codes.ecmwf.int/grib/param-db/179 + "rlut": "ttr", + # NextGEMS Icon Precipitation + # ERA5 documentation: https://codes.ecmwf.int/grib/param-db/228 + "pr": "tp", + # Air temperature. + # ERA5 documentation: https://codes.ecmwf.int/grib/param-db/130 + # The CMIP6 data contains temperature data for multiple pressure levels, + # we use the 2m ERA5 temperature data to derive the error bound for all + # pressure levels. + "ta": "t2m", + # Sea surface temperature. + # ERA5 documentation: https://codes.ecmwf.int/grib/param-db/34 + "tos": "sst", + # ERA5 documentation: https://codes.ecmwf.int/grib/param-db/165 + "10m_u_component_of_wind": "u10", + # ERA5 documentation: https://codes.ecmwf.int/grib/param-db/166 + "10m_v_component_of_wind": "v10", + # ERA5 documentation: https://codes.ecmwf.int/grib/param-db/151 + "mean_sea_level_pressure": "msl", +} + def create_error_bounds( basepath: Path = Path(), @@ -13,6 +52,8 @@ def create_error_bounds( datasets = (data_loader_base_path or basepath) / "datasets" datasets_error_bounds = basepath / "datasets-error-bounds" + era5_error_bounds = pd.read_csv(ERROR_BOUNDS) + for dataset in datasets.iterdir(): if dataset.name == ".gitignore": continue @@ -33,10 +74,28 @@ def create_error_bounds( # principled method to selct the error bounds. low_error_bounds, mid_error_bounds, high_error_bounds = dict(), dict(), dict() for v in ds: - data_range: float = (ds[v].max() - ds[v].min()).values.item() # type: ignore - low_error_bounds[v] = {"abs_error": 0.0001 * data_range, "rel_error": None} - mid_error_bounds[v] = {"abs_error": 0.001 * data_range, "rel_error": None} - high_error_bounds[v] = {"abs_error": 0.01 * data_range, "rel_error": None} + if v in VAR_NAME_TO_ERA5: + low_error_bounds[v], mid_error_bounds[v], high_error_bounds[v] = ( + get_error_bounds(era5_error_bounds, VAR_NAME_TO_ERA5[v]) + ) + elif v == "agb": + low_error_bounds[v], mid_error_bounds[v], high_error_bounds[v] = ( + get_agb_bound(datasets, percentiles=[1.00, 0.99, 0.95]) + ) + else: + data_range: float = (ds[v].max() - ds[v].min()).values.item() # type: ignore + low_error_bounds[v] = { + "abs_error": 0.0001 * data_range, + "rel_error": None, + } + mid_error_bounds[v] = { + "abs_error": 0.001 * data_range, + "rel_error": None, + } + high_error_bounds[v] = { + "abs_error": 0.01 * data_range, + "rel_error": None, + } error_bounds = [low_error_bounds, mid_error_bounds, high_error_bounds] @@ -46,6 +105,185 @@ def create_error_bounds( json.dump(error_bounds, f) +def get_error_bounds( + error_bounds: pd.DataFrame, era5_var: str +) -> list[dict[str, Optional[float]]]: + var_error_bounds = error_bounds[error_bounds["var"] == era5_var] + assert len(var_error_bounds) == 3, "Expected three error bounds for each variable." + + # Ordered from strictest to most relaxed error bounds. + percentiles = ["100%", "99%", "95%"] + var_ebs = [] + for percentile in percentiles: + eb_row = var_error_bounds[var_error_bounds["percentile"] == percentile] + eb_type = eb_row["pick"].values.item() + + if eb_type == "quadratic": + # Right now no compressor supports quadratic error bounds. We therefore + # fall back to absolute error bounds for them. + eb_type = "absolute" + + if eb_type == "relative": + # Relative error bounds are given as a percentage with an "%" at the end, + # so we need to convert them to a fraction. + rel_error = float(eb_row[f"es{eb_type}"].values.item()[:-1]) / 100.0 + var_ebs.append( + { + "abs_error": None, + "rel_error": rel_error, + } + ) + elif eb_type == "absolute": + var_ebs.append( + { + "abs_error": eb_row[f"es{eb_type}"].values.item(), + "rel_error": None, + } + ) + else: + raise ValueError(f"Unknown error bound type: {eb_type}") + + return var_ebs + + +def get_agb_bound( + datasets: Path, percentiles=[1.00, 0.99, 0.95] +) -> list[dict[str, Optional[float]]]: + # Define rough bounding box coordinates for mainland France. + # Format: [min_longitude, min_latitude, max_longitude, max_latitude]. + FRANCE_BBOX = [-5.5, 42.3, 9.6, 51.1] + + agb = xr.open_dataset( + datasets + / "esa-biomass-cci" + / "download" + / "ESACCI-BIOMASS-L4-AGB-MERGED-100m-2020-fv5.01.nc" + ) + agb = agb.sel( + lon=slice(FRANCE_BBOX[0], FRANCE_BBOX[2]), + lat=slice(FRANCE_BBOX[3], FRANCE_BBOX[1]), + ) + + ensemble_bounds = compute_ensemble_spread_bounds( + mean=agb.agb, spread=agb.agb_sd, percentile=percentiles + ) + + minfo = compute_minimum_bound( + mean=agb.agb, + spread=agb.agb_sd, + percentile=percentiles, + ) + + error_bounds = [] + for a, r, b in zip( + ensemble_bounds.absolute, ensemble_bounds.relative, minfo.mean_bin_spread_bounds + ): + cabs = f"{np.nansum(np.abs(np.array(b) - a) * minfo.mean_bin_counts) / np.sum(minfo.mean_bin_counts):.1e}" + crel = f"{np.nansum(np.abs(np.array(b) - np.abs((np.array(minfo.mean_bin_edges[:-1]) + np.array(minfo.mean_bin_edges[1:])) / 2) * r) * minfo.mean_bin_counts) / np.sum(minfo.mean_bin_counts):.1e}" + + bounds = [("absolute", cabs), ("relative", crel)] + bound_pick = sorted(bounds, key=lambda x: float(x[1]))[0][0] + + if bound_pick == "absolute": + error_bounds.append( + { + "abs_error": float(a), + "rel_error": None, + } + ) + elif bound_pick == "relative": + error_bounds.append( + { + "abs_error": None, + "rel_error": float(r), + } + ) + + return error_bounds + + +@dataclass +class EnsembleSpreadBounds: + percentile: list[float] + absolute: list[float] + relative: list[float] + + +def compute_ensemble_spread_bounds( + mean: xr.DataArray, spread: xr.DataArray, percentile: list[float] +) -> EnsembleSpreadBounds: + mean = mean.values.flatten() + spread = spread.values.flatten() + + spread_nonzero = spread[spread > 0.0] + + if len(spread_nonzero) > 0: + absolute = np.nanquantile(spread_nonzero, [1 - p for p in percentile]) + else: + absolute = [0.0 for _ in percentile] + + abs_mean = np.abs(mean) + rel = spread[abs_mean > 0.0] / abs_mean[abs_mean > 0.0] + rel_nonzero = rel[rel > 0.0] + + if len(rel_nonzero) > 0: + relative = np.nanquantile(rel_nonzero, [1 - p for p in percentile]) + else: + relative = [0.0 for _ in percentile] + + return EnsembleSpreadBounds( + percentile=percentile, + absolute=absolute, + relative=relative, + ) + + +@dataclass +class MinimumBounds: + percentile: list[float] + mean_bin_edges: list[float] + mean_bin_counts: list[int] + mean_bin_spread_bounds: list[list[float]] + + +def compute_minimum_bound( + mean: xr.DataArray, + spread: xr.DataArray, + percentile: list[float], + nbins: int = 100, +) -> MinimumBounds: + mean, spread = mean.copy(deep=True), spread.copy(deep=True) + + mean = mean.values.flatten() + spread = spread.values.flatten() + + mean_bin_edges = np.nanquantile(mean, np.linspace(0.0, 1.0, nbins + 1)) + + ibin = np.minimum(np.searchsorted(mean_bin_edges, mean), nbins - 1) + + mean_bin_spread_bounds = [np.zeros(nbins) for _ in percentile] + + for i in range(nbins): + ispread = spread[ibin == i] + + if len(ispread) > 0: + bs = np.nanquantile(ispread, [1 - p for p in percentile]) + else: + bs = [np.nan for _ in percentile] + + for bd, b in zip(mean_bin_spread_bounds, bs): + bd[i] = b + + mean_bin_counts, _ = np.histogram(mean, bins=mean_bin_edges) + + return MinimumBounds( + percentile=percentile, + mean_bin_edges=list(mean_bin_edges), + mean_bin_counts=list(mean_bin_counts), + mean_bin_spread_bounds=[list(bs) for bs in mean_bin_spread_bounds], + ) + + if __name__ == "__main__": create_error_bounds( basepath=Path(), From fa5f474dd9781f5077a1d439b8afbce4788b02f8 Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Thu, 12 Jun 2025 17:15:45 +0100 Subject: [PATCH 02/14] Appease mypy --- .../compressor/scripts/create_error_bounds.py | 33 +++++++++---------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/src/climatebenchpress/compressor/scripts/create_error_bounds.py b/src/climatebenchpress/compressor/scripts/create_error_bounds.py index 69be31a..4010f6f 100644 --- a/src/climatebenchpress/compressor/scripts/create_error_bounds.py +++ b/src/climatebenchpress/compressor/scripts/create_error_bounds.py @@ -76,7 +76,7 @@ def create_error_bounds( for v in ds: if v in VAR_NAME_TO_ERA5: low_error_bounds[v], mid_error_bounds[v], high_error_bounds[v] = ( - get_error_bounds(era5_error_bounds, VAR_NAME_TO_ERA5[v]) + get_error_bounds(era5_error_bounds, VAR_NAME_TO_ERA5[str(v)]) ) elif v == "agb": low_error_bounds[v], mid_error_bounds[v], high_error_bounds[v] = ( @@ -116,7 +116,7 @@ def get_error_bounds( var_ebs = [] for percentile in percentiles: eb_row = var_error_bounds[var_error_bounds["percentile"] == percentile] - eb_type = eb_row["pick"].values.item() + eb_type = eb_row["pick"].item() if eb_type == "quadratic": # Right now no compressor supports quadratic error bounds. We therefore @@ -126,7 +126,7 @@ def get_error_bounds( if eb_type == "relative": # Relative error bounds are given as a percentage with an "%" at the end, # so we need to convert them to a fraction. - rel_error = float(eb_row[f"es{eb_type}"].values.item()[:-1]) / 100.0 + rel_error = float(eb_row[f"es{eb_type}"].item()[:-1]) / 100.0 var_ebs.append( { "abs_error": None, @@ -134,9 +134,10 @@ def get_error_bounds( } ) elif eb_type == "absolute": + abs_error = float(eb_row[f"es{eb_type}"].item()) var_ebs.append( { - "abs_error": eb_row[f"es{eb_type}"].values.item(), + "abs_error": abs_error, "rel_error": None, } ) @@ -212,18 +213,18 @@ class EnsembleSpreadBounds: def compute_ensemble_spread_bounds( mean: xr.DataArray, spread: xr.DataArray, percentile: list[float] ) -> EnsembleSpreadBounds: - mean = mean.values.flatten() - spread = spread.values.flatten() + mean_values = mean.values.flatten() + spread_values = spread.values.flatten() - spread_nonzero = spread[spread > 0.0] + spread_nonzero = spread_values[spread_values > 0.0] if len(spread_nonzero) > 0: absolute = np.nanquantile(spread_nonzero, [1 - p for p in percentile]) else: absolute = [0.0 for _ in percentile] - abs_mean = np.abs(mean) - rel = spread[abs_mean > 0.0] / abs_mean[abs_mean > 0.0] + abs_mean = np.abs(mean_values) + rel = spread_values[abs_mean > 0.0] / abs_mean[abs_mean > 0.0] rel_nonzero = rel[rel > 0.0] if len(rel_nonzero) > 0: @@ -252,19 +253,17 @@ def compute_minimum_bound( percentile: list[float], nbins: int = 100, ) -> MinimumBounds: - mean, spread = mean.copy(deep=True), spread.copy(deep=True) + mean_values = mean.copy(deep=True).values.flatten() + spread_values = spread.copy(deep=True).values.flatten() - mean = mean.values.flatten() - spread = spread.values.flatten() + mean_bin_edges = np.nanquantile(mean_values, np.linspace(0.0, 1.0, nbins + 1)) - mean_bin_edges = np.nanquantile(mean, np.linspace(0.0, 1.0, nbins + 1)) - - ibin = np.minimum(np.searchsorted(mean_bin_edges, mean), nbins - 1) + ibin = np.minimum(np.searchsorted(mean_bin_edges, mean_values), nbins - 1) mean_bin_spread_bounds = [np.zeros(nbins) for _ in percentile] for i in range(nbins): - ispread = spread[ibin == i] + ispread = spread_values[ibin == i] if len(ispread) > 0: bs = np.nanquantile(ispread, [1 - p for p in percentile]) @@ -274,7 +273,7 @@ def compute_minimum_bound( for bd, b in zip(mean_bin_spread_bounds, bs): bd[i] = b - mean_bin_counts, _ = np.histogram(mean, bins=mean_bin_edges) + mean_bin_counts, _ = np.histogram(mean_values, bins=mean_bin_edges) return MinimumBounds( percentile=percentile, From 19cd25908252cc6a72f88ea4c05493b798517c1f Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Mon, 16 Jun 2025 10:15:31 +0100 Subject: [PATCH 03/14] Avoid pick column; add netcdf dependency; update error bounds link --- pyproject.toml | 1 + .../compressor/scripts/create_error_bounds.py | 135 +++++------------- 2 files changed, 40 insertions(+), 96 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b015aac..1791186 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ dependencies = [ "cftime~=1.6.0", "dask>=2024.12.0,<2025.4", "matplotlib~=3.8", + "netcdf4~=1.7.2", "numcodecs>=0.13.0,<0.17", "numcodecs-combinators[xarray]~=0.2.4", "numcodecs-observers~=0.1.1", diff --git a/src/climatebenchpress/compressor/scripts/create_error_bounds.py b/src/climatebenchpress/compressor/scripts/create_error_bounds.py index 4010f6f..f8fcda5 100644 --- a/src/climatebenchpress/compressor/scripts/create_error_bounds.py +++ b/src/climatebenchpress/compressor/scripts/create_error_bounds.py @@ -10,11 +10,11 @@ import xarray as xr # Table has header: -# var,level,percentile,min,range,max,lpbits,lpabsolute,lprelative,brlqabsolute,brabsolute,brrelative,esabsolute,esrelative,esquadratic,unabsolute,cabsolute,crelative,cquadratic,pick,crlinquant,crbitround,crlinquantquadstep +# var,level,percentile,min,range,max,lpbits,lpabsolute,lprelative,brlqabsolute,brabsolute,brrelative,esabsolute,esrelative,esquadratic,unabsolute,cabsolute,crelative,cquadratic,pick,crlinquant,crbitround,crlinquantquadstep,exabsmean,exabsmax,exrelmean,exrelmax # # esabsolute and esrelative are respectively the absolute and relative error bounds # derived from the ERA5 ensembles. -ERROR_BOUNDS = "https://raw.githubusercontent.com/juntyr/era5-ensemble/refs/heads/main/table-raw.csv?token=GHSAT0AAAAAACTGGFLKSCEPFNNUEGWSWPEA2CJOYSQ" +ERROR_BOUNDS = "https://gist.githubusercontent.com/juntyr/bbe2780256e5f91d8f2cb2f606b7935f/raw/table-raw.csv" VAR_NAME_TO_ERA5 = { @@ -45,6 +45,21 @@ } +ABS_ERROR = "abs_error" +REL_ERROR = "rel_error" +VAR_NAME_TO_ERROR_BOUND = { + "rlut": REL_ERROR, + "agb": REL_ERROR, + "pr": ABS_ERROR, + "ta": ABS_ERROR, + "tos": ABS_ERROR, + "10m_u_component_of_wind": ABS_ERROR, + "10m_v_component_of_wind": ABS_ERROR, + "mean_sea_level_pressure": ABS_ERROR, + "no2": ABS_ERROR, +} + + def create_error_bounds( basepath: Path = Path(), data_loader_base_path: None | Path = None, @@ -70,13 +85,15 @@ def create_error_bounds( decode_times=False, ) - # TODO: This is a temporary solution that should be replaced by a more - # principled method to selct the error bounds. low_error_bounds, mid_error_bounds, high_error_bounds = dict(), dict(), dict() for v in ds: if v in VAR_NAME_TO_ERA5: low_error_bounds[v], mid_error_bounds[v], high_error_bounds[v] = ( - get_error_bounds(era5_error_bounds, VAR_NAME_TO_ERA5[str(v)]) + get_error_bounds( + era5_error_bounds, + VAR_NAME_TO_ERA5[str(v)], + VAR_NAME_TO_ERROR_BOUND[str(v)], + ) ) elif v == "agb": low_error_bounds[v], mid_error_bounds[v], high_error_bounds[v] = ( @@ -85,16 +102,16 @@ def create_error_bounds( else: data_range: float = (ds[v].max() - ds[v].min()).values.item() # type: ignore low_error_bounds[v] = { - "abs_error": 0.0001 * data_range, - "rel_error": None, + ABS_ERROR: 0.0001 * data_range, + REL_ERROR: None, } mid_error_bounds[v] = { - "abs_error": 0.001 * data_range, - "rel_error": None, + ABS_ERROR: 0.001 * data_range, + REL_ERROR: None, } high_error_bounds[v] = { - "abs_error": 0.01 * data_range, - "rel_error": None, + ABS_ERROR: 0.01 * data_range, + REL_ERROR: None, } error_bounds = [low_error_bounds, mid_error_bounds, high_error_bounds] @@ -106,7 +123,7 @@ def create_error_bounds( def get_error_bounds( - error_bounds: pd.DataFrame, era5_var: str + error_bounds: pd.DataFrame, era5_var: str, error_bound_type: str ) -> list[dict[str, Optional[float]]]: var_error_bounds = error_bounds[error_bounds["var"] == era5_var] assert len(var_error_bounds) == 3, "Expected three error bounds for each variable." @@ -116,33 +133,17 @@ def get_error_bounds( var_ebs = [] for percentile in percentiles: eb_row = var_error_bounds[var_error_bounds["percentile"] == percentile] - eb_type = eb_row["pick"].item() - - if eb_type == "quadratic": - # Right now no compressor supports quadratic error bounds. We therefore - # fall back to absolute error bounds for them. - eb_type = "absolute" - if eb_type == "relative": + if error_bound_type == REL_ERROR: # Relative error bounds are given as a percentage with an "%" at the end, # so we need to convert them to a fraction. - rel_error = float(eb_row[f"es{eb_type}"].item()[:-1]) / 100.0 - var_ebs.append( - { - "abs_error": None, - "rel_error": rel_error, - } - ) - elif eb_type == "absolute": - abs_error = float(eb_row[f"es{eb_type}"].item()) - var_ebs.append( - { - "abs_error": abs_error, - "rel_error": None, - } - ) + rel_error = float(eb_row["esrelative"].item()[:-1]) / 100.0 + var_ebs.append({ABS_ERROR: None, REL_ERROR: rel_error}) + elif error_bound_type == ABS_ERROR: + abs_error = float(eb_row["esabsolute"].item()) + var_ebs.append({ABS_ERROR: abs_error, REL_ERROR: None}) else: - raise ValueError(f"Unknown error bound type: {eb_type}") + raise ValueError(f"Unknown error bound type: {error_bound_type}") return var_ebs @@ -169,30 +170,16 @@ def get_agb_bound( mean=agb.agb, spread=agb.agb_sd, percentile=percentiles ) - minfo = compute_minimum_bound( - mean=agb.agb, - spread=agb.agb_sd, - percentile=percentiles, - ) - error_bounds = [] - for a, r, b in zip( - ensemble_bounds.absolute, ensemble_bounds.relative, minfo.mean_bin_spread_bounds - ): - cabs = f"{np.nansum(np.abs(np.array(b) - a) * minfo.mean_bin_counts) / np.sum(minfo.mean_bin_counts):.1e}" - crel = f"{np.nansum(np.abs(np.array(b) - np.abs((np.array(minfo.mean_bin_edges[:-1]) + np.array(minfo.mean_bin_edges[1:])) / 2) * r) * minfo.mean_bin_counts) / np.sum(minfo.mean_bin_counts):.1e}" - - bounds = [("absolute", cabs), ("relative", crel)] - bound_pick = sorted(bounds, key=lambda x: float(x[1]))[0][0] - - if bound_pick == "absolute": + for a, r in zip(ensemble_bounds.absolute, ensemble_bounds.relative): + if VAR_NAME_TO_ERROR_BOUND["agb"] == ABS_ERROR: error_bounds.append( { "abs_error": float(a), "rel_error": None, } ) - elif bound_pick == "relative": + elif VAR_NAME_TO_ERROR_BOUND["agb"] == REL_ERROR: error_bounds.append( { "abs_error": None, @@ -239,50 +226,6 @@ def compute_ensemble_spread_bounds( ) -@dataclass -class MinimumBounds: - percentile: list[float] - mean_bin_edges: list[float] - mean_bin_counts: list[int] - mean_bin_spread_bounds: list[list[float]] - - -def compute_minimum_bound( - mean: xr.DataArray, - spread: xr.DataArray, - percentile: list[float], - nbins: int = 100, -) -> MinimumBounds: - mean_values = mean.copy(deep=True).values.flatten() - spread_values = spread.copy(deep=True).values.flatten() - - mean_bin_edges = np.nanquantile(mean_values, np.linspace(0.0, 1.0, nbins + 1)) - - ibin = np.minimum(np.searchsorted(mean_bin_edges, mean_values), nbins - 1) - - mean_bin_spread_bounds = [np.zeros(nbins) for _ in percentile] - - for i in range(nbins): - ispread = spread_values[ibin == i] - - if len(ispread) > 0: - bs = np.nanquantile(ispread, [1 - p for p in percentile]) - else: - bs = [np.nan for _ in percentile] - - for bd, b in zip(mean_bin_spread_bounds, bs): - bd[i] = b - - mean_bin_counts, _ = np.histogram(mean_values, bins=mean_bin_edges) - - return MinimumBounds( - percentile=percentile, - mean_bin_edges=list(mean_bin_edges), - mean_bin_counts=list(mean_bin_counts), - mean_bin_spread_bounds=[list(bs) for bs in mean_bin_spread_bounds], - ) - - if __name__ == "__main__": create_error_bounds( basepath=Path(), From 5135dd0307b3a86f43e22b09644245f213651e6a Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Mon, 16 Jun 2025 10:40:48 +0100 Subject: [PATCH 04/14] Add units to ERA5 variables --- .../compressor/scripts/create_error_bounds.py | 24 +++++++++++++++---- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/src/climatebenchpress/compressor/scripts/create_error_bounds.py b/src/climatebenchpress/compressor/scripts/create_error_bounds.py index f8fcda5..4b9b6d2 100644 --- a/src/climatebenchpress/compressor/scripts/create_error_bounds.py +++ b/src/climatebenchpress/compressor/scripts/create_error_bounds.py @@ -19,28 +19,42 @@ VAR_NAME_TO_ERA5 = { # NextGEMS Icon Outgoing Longwave Radiation (OLR). - # Closest ERA5 equivalent Top net long-wave (thermal) radiation + # Closest ERA5 equivalent Mean flux top net long-wave radiation # (https://www.ecmwf.int/sites/default/files/elibrary/2015/18490-radiation-quantities-ecmwf-model-and-mars.pdf). # which is the negative of OLR. - # ERA5 documentation: https://codes.ecmwf.int/grib/param-db/179 - "rlut": "ttr", + # NOTE: Be careful in using the flux instead of the time-accumulated variables. + # ERA5 documentation: https://codes.ecmwf.int/grib/param-db/235040 + # ERA5 unit: W m-2 + # NextGEMS unit: W m-2 + "rlut": "avg_tnlwrf", # NextGEMS Icon Precipitation - # ERA5 documentation: https://codes.ecmwf.int/grib/param-db/228 - "pr": "tp", + # NOTE: Be careful in using the flux instead of the time-accumulated variables. + # ERA5 documentation: https://codes.ecmwf.int/grib/param-db/235055 + # ERA5 unit: kg m-2 s-1 + # NextGEMS unit: kg m-2 s-1 + "pr": "avg_tprate", # Air temperature. # ERA5 documentation: https://codes.ecmwf.int/grib/param-db/130 # The CMIP6 data contains temperature data for multiple pressure levels, # we use the 2m ERA5 temperature data to derive the error bound for all # pressure levels. + # ERA5 unit: K + # CMIP6 unit: K "ta": "t2m", # Sea surface temperature. + # NOTE: Difference in units means we should use absolute error bounds. # ERA5 documentation: https://codes.ecmwf.int/grib/param-db/34 + # ERA5 unit: K + # CMIP6 unit: degC "tos": "sst", # ERA5 documentation: https://codes.ecmwf.int/grib/param-db/165 + # Units will match because data source is ERA5. "10m_u_component_of_wind": "u10", # ERA5 documentation: https://codes.ecmwf.int/grib/param-db/166 + # Units will match because data source is ERA5. "10m_v_component_of_wind": "v10", # ERA5 documentation: https://codes.ecmwf.int/grib/param-db/151 + # Units will match because data source is ERA5. "mean_sea_level_pressure": "msl", } From 9ca0a796bd8800bbf2ecc018f3bf6d43ad2baac8 Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Tue, 17 Jun 2025 12:10:24 +0100 Subject: [PATCH 05/14] Save name of error bound when computing metrics --- .../compressor/scripts/collect_metrics.py | 58 +++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/src/climatebenchpress/compressor/scripts/collect_metrics.py b/src/climatebenchpress/compressor/scripts/collect_metrics.py index 904c995..2d1415d 100644 --- a/src/climatebenchpress/compressor/scripts/collect_metrics.py +++ b/src/climatebenchpress/compressor/scripts/collect_metrics.py @@ -3,6 +3,7 @@ import json import re from pathlib import Path +from typing import Optional import pandas as pd import xarray as xr @@ -30,6 +31,7 @@ def collect_metrics( ): datasets = (data_loader_base_path or basepath) / "datasets" compressed_datasets = basepath / "compressed-datasets" + error_bounds_dir = basepath / "datasets-error-bounds" metrics_dir = basepath / "metrics" all_results = [] @@ -37,8 +39,14 @@ def collect_metrics( if dataset.name == ".gitignore": continue + with (error_bounds_dir / dataset.name / "error_bounds.json").open() as f: + error_bound_list = json.load(f) + for error_bound in dataset.iterdir(): variable2error_bound = parse_error_bounds(error_bound.name) + error_bound_name = get_error_bound_name( + variable2error_bound, error_bound_list + ) for compressor in error_bound.iterdir(): print(f"Evaluating {compressor.stem} on {dataset.name}...") @@ -75,12 +83,62 @@ def collect_metrics( df = merge_metrics(measurements, metrics, tests) df["Dataset"] = dataset.name df["Error Bound"] = error_bound.name + df["Error Bound Name"] = error_bound_name all_results.append(df) all_results_df = pd.concat(all_results) all_results_df.to_csv(metrics_dir / "all_results.csv", index=False) +def get_error_bound_name( + variable2bound: dict[str, tuple[str, float]], + error_bound_list: list[dict[str, dict[str, Optional[float]]]], + bound_names: list[str] = ["low", "mid", "high"], +) -> str: + """The function returns either "low", "mid", or "high" depending on which error bound + from the variable2bound dictionary matches the exact error bound in the error_bound_list. + + error_bound_list contains one dictionary for each error bound (low, mid, high). + Each of these dictionaries contains the error bounds for + each variable. The variable names in the dictionaries should exactly match the variable names + in the variable2bound dictionary. + + Parameters + ---------- + variable2bound : dict[str, tuple[str, float]] + A dictionary representing a single error bound, mapping variable names to + tuples of error type and error bound. The error type is either "abs_error" + or "rel_error", and the error bound is a float. + error_bound_list : list[dict[str, dict[str, Optional[float]]]] + A list of dictionaries, each representing an error bound (low, mid, high). + Each dictionary contains variable names as keys and a dictionary of error types + and bounds as values. + bound_names : list[str], optional + A list of names for the error bounds, by default ["low", "mid", "high"]. + """ + + # Convert the variable2bound dictionary to match the format of error_bound_list. + for k in variable2bound.keys(): + variable2bound[k] = { + "abs_error": ( + variable2bound[k][1] if variable2bound[k][0] == "abs_error" else None + ), + "rel_error": ( + variable2bound[k][1] if variable2bound[k][0] == "rel_error" else None + ), + } + + # Return the name of the error bound that matches variable2bound. + for bound_name, error_bound in zip(bound_names, error_bound_list): + if variable2bound == error_bound: + return bound_name + + raise ValueError( + f"Error bounds {variable2bound} do not match any of the error bounds " + f"{error_bound_list}." + ) + + def parse_error_bounds(error_bound_str: str) -> dict[str, tuple[str, float]]: """ The error bound string is of the form From d2204bad15f01040acc41bc59d72ddee619e55b1 Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Tue, 17 Jun 2025 12:14:30 +0100 Subject: [PATCH 06/14] Adjust minimum computation to avoid division by zero --- src/climatebenchpress/compressor/scripts/compress.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/climatebenchpress/compressor/scripts/compress.py b/src/climatebenchpress/compressor/scripts/compress.py index 819ef56..08c3b9f 100644 --- a/src/climatebenchpress/compressor/scripts/compress.py +++ b/src/climatebenchpress/compressor/scripts/compress.py @@ -55,7 +55,8 @@ def compress( ) for v in ds: abs_vals = xr.ufuncs.abs(ds[v]) - ds_abs_mins[v] = abs_vals.min().values.item() + # Take minimum of non-zero absolute values to avoid division by zero. + ds_abs_mins[v] = abs_vals.where(abs_vals > 0).min().values.item() ds_abs_maxs[v] = abs_vals.max().values.item() ds_mins[v] = ds[v].min().values.item() ds_maxs[v] = ds[v].max().values.item() From 427b5686e4a65a32165c8393d981a4dde938661d Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Tue, 17 Jun 2025 14:13:58 +0100 Subject: [PATCH 07/14] Adjust plotting to handle mix of relative and absolute error bounds --- .../compressor/plotting/plot_metrics.py | 133 ++++++++++-------- 1 file changed, 71 insertions(+), 62 deletions(-) diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index 1ebd5a9..be49fa3 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -12,30 +12,48 @@ from .variable_plotters import PLOTTERS COMPRESSOR2LINEINFO = { - "jpeg2000": ("#EE7733", "-"), - "sperr": ("#117733", ":"), - "zfp": ("#EE3377", "--"), - "zfp-round": ("#DDAA33", "--"), - "sz3": ("#CC3311", "-."), - "bitround-pco-conservative-rel": ("#0077BB", ":"), - "bitround-conservative-rel": ("#33BBEE", "-"), - "stochround": ("#009988", "--"), - "stochround-pco": ("#BBBBBB", "--"), - "tthresh": ("#882255", "-."), + ("jpeg2000", ("#EE7733", "-")), + ("sperr", ("#117733", ":")), + ("zfp", ("#EE3377", "--")), + ("zfp-round", ("#DDAA33", "--")), + ("sz3", ("#CC3311", "-.")), + ("bitround-pco", ("#0077BB", ":")), + ("bitround", ("#33BBEE", "-")), + ("stochround-pco", ("#BBBBBB", "--")), + ("stochround", ("#009988", "--")), + ("tthresh", ("#882255", "-.")), } -COMPRESSOR2LEGEND_NAME = { - "jpeg2000": "JPEG2000", - "sperr": "SPERR", - "zfp": "ZFP", - "zfp-round": "ZFP-ROUND", - "sz3": "SZ3", - "bitround-pco-conservative-rel": "BitRound + PCO", - "bitround-conservative-rel": "BitRound + Zstd", - "stochround": "StochRound + Zstd", - "stochround-pco": "StochRound + PCO", - "tthresh": "TTHRESH", -} + +def get_lineinfo(compressor: str) -> tuple[str, str]: + """Get the line color and style for a given compressor.""" + for comp, (color, linestyle) in COMPRESSOR2LINEINFO: + if compressor.startswith(comp): + return color, linestyle + raise ValueError(f"Unknown compressor: {compressor}") + + +COMPRESSOR2LEGEND_NAME = [ + ("jpeg2000", "JPEG2000"), + ("sperr", "SPERR"), + ("zfp-round", "ZFP-ROUND"), + ("zfp", "ZFP"), + ("sz3", "SZ3"), + ("bitround-pco", "BitRound + PCO"), + ("bitround", "BitRound + Zstd"), + ("stochround-pco", "StochRound + PCO"), + ("stochround", "StochRound + Zstd"), + ("tthresh", "TTHRESH"), +] + + +def get_legend_name(compressor: str) -> str: + """Get the legend name for a given compressor.""" + for comp, name in COMPRESSOR2LEGEND_NAME: + if compressor.startswith(comp): + return name + + return compressor # Fallback to the compressor name if not found in the mapping. def plot_metrics( @@ -68,14 +86,15 @@ def plot_metrics( all_results=df, ) - df = rename_error_bounds(df, bound_names) - plot_throughput(df, plots_path / "throughput.pdf") - plot_instruction_count(df, plots_path / "instruction_count.pdf") + df = rename_compressors(df) + normalized_df = normalize(df, bound_normalize="mid", normalizer=normalizer) normalized_df = normalize(df, bound_normalize="mid", normalizer=normalizer) plot_bound_violations( normalized_df, bound_names, plots_path / "bound_violations.pdf" ) + plot_throughput(df, plots_path / "throughput.pdf") + plot_instruction_count(df, plots_path / "instruction_count.pdf") for metric in ["Relative MAE", "Relative DSSIM", "Relative MaxAbsError"]: with plt.rc_context(rc={"text.usetex": True}): @@ -90,27 +109,17 @@ def plot_metrics( ) -def rename_error_bounds(df, bound_names): - """Give error bound consistent names between variables. By default the error bounds - have the pattern {variable_name}-{bound_type}={bound_value}.""" - # Get unique variables - variables = df["Variable"].unique() - - # Process each variable - for variable in variables: - var_selector = df["Variable"] == variable - var_data = df[var_selector] - - error_bounds = sort_error_bounds(var_data["Error Bound"].unique()) - - assert len(error_bounds) == len(bound_names), ( - f"Number of error bounds {len(error_bounds)} does not match number of bound names {len(bound_names)} for {variable}." - ) - - for i in range(len(error_bounds)): - bound_selector = var_data["Error Bound"] == error_bounds[i] - df.loc[bound_selector & var_selector, "Error Bound"] = bound_names[i] - +def rename_compressors(df): + """Give compressors consistent names. They sometimes have suffixes if they are + applied on a converted error bound. The three patterns are: + - {compressor_name} + - {compressor_name}-conservative-abs + - {compressor_name}-conservative-rel + """ + df = df.copy() + df["Compressor"] = df["Compressor"].str.replace( + r"-(conservative-(abs|rel))$", "", regex=True + ) return df @@ -141,7 +150,7 @@ def normalize(data, bound_normalize="mid", normalizer=None): if normalizer is None: # Group by Variable and rank compressors within each variable ranked = data.copy() - ranked = ranked[ranked["Error Bound"] == bound_normalize] + ranked = ranked[ranked["Error Bound Name"] == bound_normalize] ranked["CompRatio_Rank"] = ranked.groupby("Variable")[ "Compression Ratio [raw B / enc B]" ].rank(ascending=False) @@ -167,7 +176,7 @@ def get_normalizer(row): return normalized[ (data["Compressor"] == normalizer) & (data["Variable"] == row["Variable"]) - & (data["Error Bound"] == bound_normalize) + & (data["Error Bound Name"] == bound_normalize) ][col].item() for col, new_col in normalize_vars: @@ -297,11 +306,11 @@ def plot_variable_rd_curve(df, distortion_metric, outfile: None | Path = None): for i in sorting_ixs ] distortion = [compressor_data[distortion_metric].iloc[i] for i in sorting_ixs] - color, linestyle = COMPRESSOR2LINEINFO[comp] + color, linestyle = get_lineinfo(comp) plt.plot( compr_ratio, distortion, - label=COMPRESSOR2LEGEND_NAME[comp], + label=get_legend_name(comp), marker="s", color=color, linestyle=linestyle, @@ -356,7 +365,7 @@ def plot_aggregated_rd_curve( ): plt.figure(figsize=(8, 6)) compressors = normalized_df["Compressor"].unique() - agg_distortion = normalized_df.groupby(["Error Bound", "Compressor"])[ + agg_distortion = normalized_df.groupby(["Error Bound Name", "Compressor"])[ [compression_metric, distortion_metric] ].agg(agg) for comp in compressors: @@ -368,11 +377,11 @@ def plot_aggregated_rd_curve( agg_distortion.loc[(bound, comp), distortion_metric] for bound in bound_names ] - color, linestyle = COMPRESSOR2LINEINFO[comp] + color, linestyle = get_lineinfo(comp) plt.plot( compr_ratio, distortion, - label=COMPRESSOR2LEGEND_NAME[comp], + label=get_legend_name(comp), marker="s", color=color, linestyle=linestyle, @@ -418,7 +427,7 @@ def plot_aggregated_rd_curve( fontsize=12, title_fontsize=14, ) - normalizer_label = COMPRESSOR2LEGEND_NAME.get(normalizer, normalizer) + normalizer_label = get_legend_name(normalizer) plt.xlabel( rf"Median Compression Ratio Relative to {normalizer_label} ($\uparrow$)", fontsize=16, @@ -463,7 +472,7 @@ def plot_throughput(df, outfile: None | Path = None): # with instruction count measurements. encode_col = "Encode Throughput [raw B / s]" decode_col = "Decode Throughput [raw B / s]" - new_df = df[["Compressor", "Error Bound", encode_col, decode_col]].copy() + new_df = df[["Compressor", "Error Bound Name", encode_col, decode_col]].copy() transformed_encode_col = "Encode Throughput [s / MB]" transformed_decode_col = "Decode Throughput [s / MB]" new_df[transformed_encode_col] = 1e6 / new_df[encode_col] @@ -492,7 +501,7 @@ def plot_instruction_count(df, outfile: None | Path = None): def get_median_and_quantiles(df, encode_column, decode_column): - return df.groupby(["Compressor", "Error Bound"])[ + return df.groupby(["Compressor", "Error Bound Name"])[ [encode_column, decode_column] ].agg( encode_median=pd.NamedAgg( @@ -522,14 +531,14 @@ def plot_grouped_df(grouped_df, title, ylabel, outfile: None | Path = None): # Bar width bar_width = 0.35 compressors = grouped_df.index.levels[0].tolist() - x_labels = [COMPRESSOR2LEGEND_NAME[c] for c in compressors] + x_labels = [get_legend_name(c) for c in compressors] x_positions = range(len(x_labels)) error_bounds = ["low", "mid", "high"] for i, error_bound in enumerate(error_bounds): ax = axes[i] - bound_data = grouped_df.xs(error_bound, level="Error Bound") + bound_data = grouped_df.xs(error_bound, level="Error Bound Name") # Plot encode throughput ax.bar( @@ -541,7 +550,7 @@ def plot_grouped_df(grouped_df, title, ylabel, outfile: None | Path = None): bound_data["encode_upper_quantile"], ], label="Encoding", - color=[COMPRESSOR2LINEINFO[comp][0] for comp in compressors], + color=[get_lineinfo(comp)[0] for comp in compressors], ) # Plot decode throughput @@ -554,7 +563,7 @@ def plot_grouped_df(grouped_df, title, ylabel, outfile: None | Path = None): bound_data["decode_upper_quantile"], ], label="Decoding", - edgecolor=[COMPRESSOR2LINEINFO[comp][0] for comp in compressors], + edgecolor=[get_lineinfo(comp)[0] for comp in compressors], fill=False, linewidth=4, ) @@ -591,8 +600,8 @@ def plot_bound_violations(df, bound_names, outfile: None | Path = None): fig, axs = plt.subplots(1, 3, figsize=(len(bound_names) * 6, 6), sharey=True) for i, bound_name in enumerate(bound_names): - df_bound = df[df["Error Bound"] == bound_name].copy() - df_bound["Compressor"] = df_bound["Compressor"].map(COMPRESSOR2LEGEND_NAME) + df_bound = df[df["Error Bound Name"] == bound_name].copy() + df_bound["Compressor"] = df_bound["Compressor"].map(get_legend_name) pass_fail = df_bound.pivot( index="Compressor", columns="Variable", values="Satisfies Bound (Passed)" ) From 35fb48992c6700e0f2c3f8b5140510aee403318a Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Tue, 17 Jun 2025 14:19:50 +0100 Subject: [PATCH 08/14] Initialize new dict to pass mypy --- .../compressor/scripts/collect_metrics.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/climatebenchpress/compressor/scripts/collect_metrics.py b/src/climatebenchpress/compressor/scripts/collect_metrics.py index 2d1415d..20b7147 100644 --- a/src/climatebenchpress/compressor/scripts/collect_metrics.py +++ b/src/climatebenchpress/compressor/scripts/collect_metrics.py @@ -118,8 +118,9 @@ def get_error_bound_name( """ # Convert the variable2bound dictionary to match the format of error_bound_list. + new_bound_format = dict() for k in variable2bound.keys(): - variable2bound[k] = { + new_bound_format[k] = { "abs_error": ( variable2bound[k][1] if variable2bound[k][0] == "abs_error" else None ), @@ -128,13 +129,13 @@ def get_error_bound_name( ), } - # Return the name of the error bound that matches variable2bound. + # Return the name of the error bound that matches new_bound_format. for bound_name, error_bound in zip(bound_names, error_bound_list): - if variable2bound == error_bound: + if new_bound_format == error_bound: return bound_name raise ValueError( - f"Error bounds {variable2bound} do not match any of the error bounds " + f"Error bounds {new_bound_format} do not match any of the error bounds " f"{error_bound_list}." ) From 70408a7bd749418cc2a78047217b4b02cbba799a Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Tue, 17 Jun 2025 15:57:14 +0100 Subject: [PATCH 09/14] Turn set into list --- src/climatebenchpress/compressor/plotting/plot_metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index be49fa3..b7ed96f 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -11,7 +11,7 @@ from .error_dist_plotter import ErrorDistPlotter from .variable_plotters import PLOTTERS -COMPRESSOR2LINEINFO = { +COMPRESSOR2LINEINFO = [ ("jpeg2000", ("#EE7733", "-")), ("sperr", ("#117733", ":")), ("zfp", ("#EE3377", "--")), @@ -22,7 +22,7 @@ ("stochround-pco", ("#BBBBBB", "--")), ("stochround", ("#009988", "--")), ("tthresh", ("#882255", "-.")), -} +] def get_lineinfo(compressor: str) -> tuple[str, str]: From e642a9d00889240e2f6aa45a260aa0b473f42981 Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Tue, 17 Jun 2025 15:59:36 +0100 Subject: [PATCH 10/14] Reorder ZFP variants for proper plotting --- src/climatebenchpress/compressor/plotting/plot_metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index b7ed96f..a9d2a45 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -14,8 +14,8 @@ COMPRESSOR2LINEINFO = [ ("jpeg2000", ("#EE7733", "-")), ("sperr", ("#117733", ":")), - ("zfp", ("#EE3377", "--")), ("zfp-round", ("#DDAA33", "--")), + ("zfp", ("#EE3377", "--")), ("sz3", ("#CC3311", "-.")), ("bitround-pco", ("#0077BB", ":")), ("bitround", ("#33BBEE", "-")), From cc0a71225f0b9fb769a1aa5a4039b46a49a68dec Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Tue, 17 Jun 2025 16:24:11 +0100 Subject: [PATCH 11/14] Avoid double normalization call --- src/climatebenchpress/compressor/plotting/plot_metrics.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index a9d2a45..a72c7b2 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -87,8 +87,6 @@ def plot_metrics( ) df = rename_compressors(df) - normalized_df = normalize(df, bound_normalize="mid", normalizer=normalizer) - normalized_df = normalize(df, bound_normalize="mid", normalizer=normalizer) plot_bound_violations( normalized_df, bound_names, plots_path / "bound_violations.pdf" From 22fae2845edd4a89dafa24774851eee57f785bef Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Wed, 18 Jun 2025 15:29:01 +0100 Subject: [PATCH 12/14] Average over multiple levels for temperature; use relative error for wind --- .../compressor/scripts/create_error_bounds.py | 30 ++++++++++++++----- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/src/climatebenchpress/compressor/scripts/create_error_bounds.py b/src/climatebenchpress/compressor/scripts/create_error_bounds.py index 4b9b6d2..8026cfe 100644 --- a/src/climatebenchpress/compressor/scripts/create_error_bounds.py +++ b/src/climatebenchpress/compressor/scripts/create_error_bounds.py @@ -35,12 +35,9 @@ "pr": "avg_tprate", # Air temperature. # ERA5 documentation: https://codes.ecmwf.int/grib/param-db/130 - # The CMIP6 data contains temperature data for multiple pressure levels, - # we use the 2m ERA5 temperature data to derive the error bound for all - # pressure levels. # ERA5 unit: K # CMIP6 unit: K - "ta": "t2m", + "ta": "t", # Sea surface temperature. # NOTE: Difference in units means we should use absolute error bounds. # ERA5 documentation: https://codes.ecmwf.int/grib/param-db/34 @@ -67,8 +64,8 @@ "pr": ABS_ERROR, "ta": ABS_ERROR, "tos": ABS_ERROR, - "10m_u_component_of_wind": ABS_ERROR, - "10m_v_component_of_wind": ABS_ERROR, + "10m_u_component_of_wind": REL_ERROR, + "10m_v_component_of_wind": REL_ERROR, "mean_sea_level_pressure": ABS_ERROR, "no2": ABS_ERROR, } @@ -139,8 +136,25 @@ def create_error_bounds( def get_error_bounds( error_bounds: pd.DataFrame, era5_var: str, error_bound_type: str ) -> list[dict[str, Optional[float]]]: - var_error_bounds = error_bounds[error_bounds["var"] == era5_var] - assert len(var_error_bounds) == 3, "Expected three error bounds for each variable." + var_error_bounds = error_bounds[error_bounds["var"] == era5_var].copy() + single_level = var_error_bounds["level"].unique()[0] == "single" + if single_level: + assert len(var_error_bounds) == 3, ( + "Expected three error bounds for each variable." + ) + else: + # For variables with multiple levels (only air temperature at this point) + # take the average error bound across all levels. + var_error_bounds["esrelative_float"] = ( + var_error_bounds["esrelative"].str.rstrip("%").astype(float) + ) + grouped = ( + var_error_bounds.groupby(["percentile"]) + .agg({"esabsolute": "mean", "esrelative_float": "mean"}) + .reset_index() + ) + grouped["esrelative"] = grouped["esrelative_float"].astype(str) + "%" + var_error_bounds = grouped[["percentile", "esabsolute", "esrelative"]] # Ordered from strictest to most relaxed error bounds. percentiles = ["100%", "99%", "95%"] From c58ef9a9c271234a3e913d1a47709fc60ee8b547 Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Wed, 18 Jun 2025 15:46:03 +0100 Subject: [PATCH 13/14] Plotting adjustments --- .../compressor/plotting/error_dist_plotter.py | 29 ++++++----- .../compressor/plotting/plot_metrics.py | 52 +++++++++++++++++-- 2 files changed, 63 insertions(+), 18 deletions(-) diff --git a/src/climatebenchpress/compressor/plotting/error_dist_plotter.py b/src/climatebenchpress/compressor/plotting/error_dist_plotter.py index b528c9a..ab79802 100644 --- a/src/climatebenchpress/compressor/plotting/error_dist_plotter.py +++ b/src/climatebenchpress/compressor/plotting/error_dist_plotter.py @@ -30,7 +30,7 @@ def compute_errors(self, compressor, ds, ds_new, var, err_bound_type): error / np.abs(ds[var]), ) .compute() - .value + .values ) else: raise ValueError(f"Unknown error bound type: {err_bound_type}") @@ -42,34 +42,35 @@ def plot_error_bound_histograms( variables, compressors, error_bound_vals, - compressor2legendname, - compressor2lineinfo, + get_legend_name, + get_line_info, ): """ Plot error histograms for a single error bound across all variables in that dataset. """ # We only plot bitround and stochround once because the lossless compressor - # does not change the error plot distribution. - legend_names = compressor2legendname.copy() - legend_names.update( - { - "bitround-conservative-rel": "BitRound", - "stochround": "StochRound", - } - ) + # does not change the error plot distribution. Hence, we ignore the PCO + # compressors here. compressors = [comp for comp in compressors if "-pco" not in comp] for j, var in enumerate(variables): for comp in compressors: + color, linestyle = get_line_info(comp) + label = get_legend_name(comp) + # Don't state the lossless compressor in the legend. + if label.startswith("BitRound"): + label = "BitRound" + elif label.startswith("StochRound"): + label = "StochRound" self.axes[j, col_index].hist( self.errors[var][comp], bins=100, density=True, histtype="step", - label=legend_names.get(comp, comp), - color=compressor2lineinfo.get(comp, ("#000000", "-"))[0], - linestyle=compressor2lineinfo.get(comp, ("#000000", "-"))[1], + label=label, + color=color, + linestyle=linestyle, linewidth=2, alpha=0.8, ) diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index a72c7b2..648d189 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -64,6 +64,7 @@ def plot_metrics( exclude_dataset: list[str] = [], exclude_compressor: list[str] = [], tiny_datasets: bool = False, + use_latex: bool = True, ): metrics_path = basepath / "metrics" plots_path = basepath / "plots" @@ -95,7 +96,7 @@ def plot_metrics( plot_instruction_count(df, plots_path / "instruction_count.pdf") for metric in ["Relative MAE", "Relative DSSIM", "Relative MaxAbsError"]: - with plt.rc_context(rc={"text.usetex": True}): + with plt.rc_context(rc={"text.usetex": use_latex}): plot_aggregated_rd_curve( normalized_df, normalizer=normalizer, @@ -263,8 +264,8 @@ def plot_per_variable_metrics( variables, compressors, error_bound_vals, - COMPRESSOR2LEGEND_NAME, - COMPRESSOR2LINEINFO, + get_legend_name, + get_lineinfo, ) fig, _ = error_dist_plotter.get_final_figure() @@ -362,6 +363,11 @@ def plot_aggregated_rd_curve( bound_names=["low", "mid", "high"], ): plt.figure(figsize=(8, 6)) + if distortion_metric == "DSSIM": + # For fields with large number of NaNs, the DSSIM values are unreliable + # which is why we exclude them here. + normalized_df = normalized_df[~normalized_df["Variable"].isin(["ta", "tos"])] + compressors = normalized_df["Compressor"].unique() agg_distortion = normalized_df.groupby(["Error Bound Name", "Compressor"])[ [compression_metric, distortion_metric] @@ -417,6 +423,7 @@ def plot_aggregated_rd_curve( right=True, ) + normalizer_label = get_legend_name(normalizer) if "MAE" in distortion_metric: plt.legend( title="Compressor", @@ -425,7 +432,6 @@ def plot_aggregated_rd_curve( fontsize=12, title_fontsize=14, ) - normalizer_label = get_legend_name(normalizer) plt.xlabel( rf"Median Compression Ratio Relative to {normalizer_label} ($\uparrow$)", fontsize=16, @@ -458,6 +464,42 @@ def plot_aggregated_rd_curve( color=arrow_color, ha="center", ) + elif "DSSIM" in distortion_metric: + plt.xlabel( + rf"Median Compression Ratio Relative to {normalizer_label} ($\uparrow$)", + fontsize=16, + ) + plt.ylabel( + rf"Median DSSIM to {normalizer_label} ($\downarrow$)", + fontsize=16, + ) + arrow_color = "black" + # Add an arrow pointing into the top right corner + plt.annotate( + "", + xy=(0.95, 0.95), + xycoords="axes fraction", + xytext=(-60, -50), + textcoords="offset points", + arrowprops=dict( + arrowstyle="-|>, head_length=0.5, head_width=0.5", + color=arrow_color, + lw=5, + ), + ) + # Attach the text to the lower left of the arrow + plt.text( + 0.83, + 0.92, + "Better", + transform=plt.gca().transAxes, + fontsize=16, + fontweight="bold", + color=arrow_color, + ha="center", + va="center", + ) + plt.legend().remove() plt.tight_layout() if outfile is not None: @@ -653,6 +695,7 @@ def savefig(outfile: Path): parser.add_argument("--exclude-dataset", type=str, nargs="+", default=[]) parser.add_argument("--exclude-compressor", type=str, nargs="+", default=[]) parser.add_argument("--tiny-datasets", action="store_true", default=False) + parser.add_argument("--avoid-latex", action="store_true", default=False) args = parser.parse_args() plot_metrics( @@ -661,4 +704,5 @@ def savefig(outfile: Path): exclude_compressor=args.exclude_compressor, exclude_dataset=args.exclude_dataset, tiny_datasets=args.tiny_datasets, + use_latex=(not args.avoid_latex), ) From 11250d7ec51c510a9de7900254aca89d0f0d3726 Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Thu, 19 Jun 2025 14:02:41 +0100 Subject: [PATCH 14/14] Use relative error bound for precip --- src/climatebenchpress/compressor/scripts/create_error_bounds.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/climatebenchpress/compressor/scripts/create_error_bounds.py b/src/climatebenchpress/compressor/scripts/create_error_bounds.py index 8026cfe..e64be6a 100644 --- a/src/climatebenchpress/compressor/scripts/create_error_bounds.py +++ b/src/climatebenchpress/compressor/scripts/create_error_bounds.py @@ -61,7 +61,7 @@ VAR_NAME_TO_ERROR_BOUND = { "rlut": REL_ERROR, "agb": REL_ERROR, - "pr": ABS_ERROR, + "pr": REL_ERROR, "ta": ABS_ERROR, "tos": ABS_ERROR, "10m_u_component_of_wind": REL_ERROR,