From 72691c81eabc78ee414c0dc545fc71de8ec7e49c Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Wed, 16 Apr 2025 09:52:31 +0100 Subject: [PATCH 1/5] Add plotting functionality --- compressed-datasets/.gitignore | 4 +- metrics/.gitignore | 4 +- plots/.gitignore | 3 + pyproject.toml | 1 + .../compressor/plotting/__init__.py | 1 + .../compressor/plotting/plot_metrics.py | 401 ++++++++++++++++++ .../compressor/plotting/variable_plotters.py | 200 +++++++++ 7 files changed, 610 insertions(+), 4 deletions(-) create mode 100644 plots/.gitignore create mode 100644 src/climatebenchpress/compressor/plotting/__init__.py create mode 100644 src/climatebenchpress/compressor/plotting/plot_metrics.py create mode 100644 src/climatebenchpress/compressor/plotting/variable_plotters.py diff --git a/compressed-datasets/.gitignore b/compressed-datasets/.gitignore index 0336a5a..0f3db69 100644 --- a/compressed-datasets/.gitignore +++ b/compressed-datasets/.gitignore @@ -1,2 +1,2 @@ -/*/*/decompressed.zarr -/*/*/measurements.json +/*/*/*/decompressed.zarr +/*/*/*/measurements.json diff --git a/metrics/.gitignore b/metrics/.gitignore index 581f497..dacf85c 100644 --- a/metrics/.gitignore +++ b/metrics/.gitignore @@ -1,3 +1,3 @@ -/*/*/metrics.csv -/*/*/tests.csv +/*/*/*/metrics.csv +/*/*/*/tests.csv /all_results.csv diff --git a/plots/.gitignore b/plots/.gitignore new file mode 100644 index 0000000..844f945 --- /dev/null +++ b/plots/.gitignore @@ -0,0 +1,3 @@ +/*.png +/*/*.png +/*/*/*.png diff --git a/pyproject.toml b/pyproject.toml index 8a92df2..1c6a9bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dependencies = [ "numcodecs-wasm-zlib~=0.3.0", "pandas~=2.2", "scipy~=1.14", + "seaborn>=0.13.2", "tabulate~=0.9", "typed-classproperties~=1.1.0", "xarray>=2024.11.0,<2025.4", diff --git a/src/climatebenchpress/compressor/plotting/__init__.py b/src/climatebenchpress/compressor/plotting/__init__.py new file mode 100644 index 0000000..c9c2ef6 --- /dev/null +++ b/src/climatebenchpress/compressor/plotting/__init__.py @@ -0,0 +1 @@ +__all__: list[str] = [] diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py new file mode 100644 index 0000000..0d9d85f --- /dev/null +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -0,0 +1,401 @@ +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +import xarray as xr + +from .variable_plotters import PLOTTERS + +COMPRESSOR2COLOR = { + "jpeg2000": "#EE7733", + "zfp": "#EE3377", + "sz3": "#CC3311", + "bitround-pco-conservative-rel": "#0077BB", + "bitround-conservative-rel": "#33BBEE", + "stochround": "#009988", + "tthresh": "#BBBBBB", +} + +COMPRESSOR2LEGEND_NAME = { + "jpeg2000": "JPEG2000", + "zfp": "ZFP", + "sz3": "SZ3", + "bitround-pco-conservative-rel": "BitRound + PCO", + "bitround-conservative-rel": "BitRound + Zlib", + "stochround": "StochRound", + "tthresh": "TTHRESH", +} + + +def plot_metrics( + basepath: Path = Path(), bound_names: list[str] = ["low", "mid", "high"] +): + metrics_path = basepath / "metrics" + plots_path = basepath / "plots" + + df = pd.read_csv(metrics_path / "all_results.csv") + plot_per_variable_metrics( + basepath=basepath, + plots_path=plots_path, + all_results=df, + ) + + df = rename_error_bounds(df, bound_names) + normalized_df = normalize(df, bound_normalize="mid") + + plot_bound_violations( + normalized_df, bound_names, plots_path / "bound_violations.png" + ) + + for metric in ["Normalized_MAE", "Normalized_DSSIM", "Normalized_MaxAbsError"]: + plot_aggregated_rd_curve( + normalized_df, + plots_path / f"rd_curve_{metric.lower()}.png", + compression_metric="Normalized_CR", + distortion_metric=metric, + agg="median", + bound_names=bound_names, + ) + + +def rename_error_bounds(df, bound_names): + """Give error bound consistent names between variables. By default the error bounds + have the pattern {variable_name}-{bound_type}={bound_value}.""" + # Get unique variables + variables = df["Variable"].unique() + + # Process each variable + for variable in variables: + var_selector = df["Variable"] == variable + var_data = df[var_selector] + + error_bounds = sorted( + var_data["Error Bound"].unique(), + key=lambda x: float(x.split("=")[1].split("_")[0]), + ) + + assert len(error_bounds) == len(bound_names) + for i in range(len(error_bounds)): + bound_selector = var_data["Error Bound"] == error_bounds[i] + df.loc[bound_selector & var_selector, "Error Bound"] = bound_names[i] + + return df + + +def normalize(data, bound_normalize="mid"): + """Generate normalized metrics for each compressor and variable. The normalization + first computes the 'best compressor' with the highest average rank over all variables (ranked by + compression ratio). + + For each metric, the normalization is done by dividing the metric by the value of the + 'best compressor' for the same variable and error bound, i.e.: + normalized_metric = metric[compressor, variable] / metric[best_compressor, variable]. + """ + # Group by Variable and rank compressors within each variable + ranked = data.copy() + ranked = ranked[ranked["Error Bound"] == bound_normalize] + ranked["CompRatio_Rank"] = ranked.groupby("Variable")[ + "Compression Ratio [raw B / enc B]" + ].rank(ascending=False) + + # Calculate average rank for each compressor across all variables + avg_ranks = ranked.groupby("Compressor")["CompRatio_Rank"].mean().reset_index() + avg_ranks.columns = ["Compressor", "Average_Rank"] + avg_ranks = avg_ranks.sort_values("Average_Rank") + + best_compressor = avg_ranks.iloc[0]["Compressor"] + + normalized = data.copy() + normalize_vars = [ + ("Compression Ratio [raw B / enc B]", "Normalized_CR"), + ("MAE", "Normalized_MAE"), + ("DSSIM", "Normalized_DSSIM"), + ("Max Absolute Error", "Normalized_MaxAbsError"), + ] + # Avoid negative values. By default, DSSIM is in the range [-1, 1]. + normalized["DSSIM"] = normalized["DSSIM"] + 1.0 + + def get_normalizer(row): + return normalized[ + (data["Compressor"] == best_compressor) + & (data["Variable"] == row["Variable"]) + & (data["Error Bound"] == bound_normalize) + ][col].item() + + for col, new_col in normalize_vars: + normalized[new_col] = normalized.apply( + lambda x: x[col] / get_normalizer(x), + axis=1, + ) + + return normalized + + +def plot_per_variable_metrics( + basepath: Path, plots_path: Path, all_results: pd.DataFrame +): + """Creates all the plots which only depend on a single variable.""" + for dataset in all_results["Dataset"].unique(): + df = all_results[all_results["Dataset"] == dataset] + dataset_plots_path = plots_path / dataset + dataset_plots_path.mkdir(parents=True, exist_ok=True) + + # For each variable and compressor, plot the input, output, and error fields. + variables = df["Variable"].unique() + for var in variables: + for dist_metric in ["Max Absolute Error", "MAE"]: + metric_name = dist_metric.lower().replace(" ", "_") + if df[df["Variable"] == var][dist_metric].isnull().all(): + continue + plot_variable_rd_curve( + df[df["Variable"] == var], + dataset_plots_path / f"{var}_compression_ratio_{metric_name}.png", + distortion_metric=dist_metric, + ) + + error_bounds = df[df["Variable"] == var]["Error Bound"].unique() + for err_bound in error_bounds: + compressors = df[ + (df["Variable"] == var) & (df["Error Bound"] == err_bound) + ]["Compressor"].unique() + + err_bound_path = dataset_plots_path / err_bound + err_bound_path.mkdir(parents=True, exist_ok=True) + for comp in compressors: + print(f"Plotting {var} error for {comp}...") + plot_variable_error( + basepath, + dataset, + err_bound, + comp, + var, + err_bound_path / f"{var}_{comp}.png", + ) + + +def plot_variable_error(basepath, dataset_name, error_bound, compressor, var, outfile): + if outfile.exists(): + # These plots can be quite expensive to generate, so we skip if they already exist. + return + + compressed = ( + basepath + / ".." + / "compressor" + / "compressed-datasets" + / dataset_name + / error_bound + / compressor + / "decompressed.zarr" + ) + input = ( + basepath + / ".." + / "data-loader" + / "datasets" + / dataset_name + / "standardized.zarr" + ) + + ds = xr.open_dataset(input, chunks=dict(), engine="zarr").compute() + ds_new = xr.open_dataset(compressed, chunks=dict(), engine="zarr").compute() + ds, ds_new = ds[var], ds_new[var] + + plotter = PLOTTERS.get(dataset_name, None) + if plotter: + plotter().plot(ds, ds_new, dataset_name, compressor, var, outfile) + else: + print(f"No plotter found for dataset {dataset_name}") + + +def plot_variable_rd_curve(df, outfile, distortion_metric): + plt.figure(figsize=(8, 6)) + compressors = df["Compressor"].unique() + for comp in compressors: + compressor_data = df[df["Compressor"] == comp] + sorting_ixs = np.argsort(compressor_data["Compression Ratio [raw B / enc B]"]) + compr_ratio = [ + compressor_data["Compression Ratio [raw B / enc B]"].iloc[i] + for i in sorting_ixs + ] + distortion = [compressor_data[distortion_metric].iloc[i] for i in sorting_ixs] + plt.plot( + compr_ratio, + distortion, + label=COMPRESSOR2LEGEND_NAME[comp], + marker="s", + color=COMPRESSOR2COLOR[comp], + linewidth=4, + markersize=8, + ) + + plt.xlabel("Compression Ratio [raw B / enc B]", fontsize=14) + plt.xscale("log") + if distortion_metric != "PSNR": + # PSNR is already on log scale. + plt.yscale("log") + plt.ylabel(distortion_metric, fontsize=14) + + plt.legend( + title="Compressor", + fontsize=10, + title_fontsize=12, + ) + plt.tick_params( + axis="both", + which="major", + labelsize=14, + length=12, + direction="in", + top=True, + right=True, + ) + plt.tick_params( + axis="both", + which="minor", + length=6, + direction="in", + top=True, + right=True, + ) + + plt.tight_layout() + plt.savefig(outfile, dpi=300) + plt.close() + + +def plot_aggregated_rd_curve( + normalized_df, + outfile, + compression_metric, + distortion_metric, + agg="median", + bound_names=["low", "mid", "high"], +): + plt.figure(figsize=(8, 6)) + compressors = normalized_df["Compressor"].unique() + agg_distortion = normalized_df.groupby(["Error Bound", "Compressor"])[ + [compression_metric, distortion_metric] + ].agg(agg) + for comp in compressors: + compr_ratio = [ + agg_distortion.loc[(bound, comp), compression_metric] + for bound in bound_names + ] + distortion = [ + agg_distortion.loc[(bound, comp), distortion_metric] + for bound in bound_names + ] + plt.plot( + compr_ratio, + distortion, + label=COMPRESSOR2LEGEND_NAME[comp], + marker="s", + color=COMPRESSOR2COLOR[comp], + linewidth=4, + markersize=8, + ) + + plt.xlabel(compression_metric, fontsize=14) + plt.xscale("log") + if "PSNR" not in distortion_metric: + # PSNR is already on log scale. + plt.yscale("log") + plt.ylabel(distortion_metric, fontsize=14) + + plt.legend( + title="Compressor", + fontsize=10, + title_fontsize=12, + ) + plt.tick_params( + axis="both", + which="major", + labelsize=14, + length=12, + direction="in", + top=True, + right=True, + ) + plt.tick_params( + axis="both", + which="minor", + length=6, + direction="in", + top=True, + right=True, + ) + + if "MAE" in distortion_metric: + plt.legend( + title="Compressor", + loc="upper right", + bbox_to_anchor=(0.95, 0.6), + fontsize=10, + title_fontsize=12, + ) + plt.xlabel("Normalized Compression Ratio", fontsize=14) + plt.ylabel("Normalized Mean Absolute Error", fontsize=14) + # Add an arrow pointing into the lower right corner + plt.annotate( + "", + xy=(0.97, 0.05), + xycoords="axes fraction", + xytext=(-60, 50), + textcoords="offset points", + arrowprops=dict(arrowstyle="-|>", color="grey", lw=2), + ) + plt.text( + 0.85, + 0.08, + "Better", + transform=plt.gca().transAxes, + fontsize=14, + fontweight="bold", + color="grey", + ha="center", + ) + + plt.tight_layout() + plt.savefig(outfile, dpi=300) + plt.close() + + +def plot_bound_violations(df, bound_names, outfile): + fig, axs = plt.subplots(1, 3, figsize=(len(bound_names) * 6, 6), sharey=True) + + for i, bound_name in enumerate(bound_names): + df_bound = df[df["Error Bound"] == bound_name] + pass_fail = df_bound.pivot( + index="Compressor", columns="Variable", values="Satisfies Bound (Passed)" + ) + pass_fail = pass_fail.astype(np.float32) + fraction_fail = df_bound.pivot( + index="Compressor", columns="Variable", values="Satisfies Bound (Value)" + ) + annotations = fraction_fail.map( + lambda x: "{:.2f}".format(x * 100) if x * 100 >= 0.01 else "<0.01" + ) + annotations[fraction_fail == 0.0] = "" + sns.heatmap( + pass_fail, + cbar=False, + cmap="vlag_r", + annot=annotations, + fmt="s", + linewidths=0.5, + ax=axs[i], + ) + axs[i].set_title(f"Bound: {bound_name}") + if i != 0: + axs[i].set_ylabel("") + + fig.tight_layout() + fig.savefig(outfile, dpi=300) + plt.close() + + +if __name__ == "__main__": + plot_metrics(basepath=Path()) diff --git a/src/climatebenchpress/compressor/plotting/variable_plotters.py b/src/climatebenchpress/compressor/plotting/variable_plotters.py new file mode 100644 index 0000000..623203b --- /dev/null +++ b/src/climatebenchpress/compressor/plotting/variable_plotters.py @@ -0,0 +1,200 @@ +from abc import ABC, abstractmethod + +import cartopy.crs as ccrs +import matplotlib.colors as mcolors +import matplotlib.pyplot as plt +import numpy as np + + +class Plotter(ABC): + def __init__(self): + self.projection = ccrs.Robinson() + + @abstractmethod + def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): + pass + + def plot(self, ds, ds_new, dataset_name, compressor, var, outfile): + fig, ax = plt.subplots( + nrows=1, + ncols=3, + figsize=(20, 7), + subplot_kw={"projection": self.projection}, + ) + self.plot_fields(fig, ax, ds, ds_new, dataset_name, var) + ax[0].coastlines() + ax[1].coastlines() + ax[2].coastlines() + ax[0].set_title("Original Dataset") + ax[1].set_title("Compressed Dataset") + ax[2].set_title("Error") + fig.suptitle(f"{var} Error for {dataset_name} ({compressor})") + plt.tight_layout() + plt.savefig(outfile) + plt.close() + + +class CmipAtmosPlotter(Plotter): + def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): + selector = dict(time=0, plev=3) + ds.isel(**selector).plot(ax=ax[0], transform=ccrs.PlateCarree()) + ds_new.isel(**selector).plot( + ax=ax[1], transform=ccrs.PlateCarree(), robust=True + ) + error = ds.isel(**selector) - ds_new.isel(**selector) + error.plot(ax=ax[2], transform=ccrs.PlateCarree()) + + +class CmipOceanPlotter(Plotter): + def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): + pcm0 = ax[0].pcolormesh( + ds.longitude.values, + ds.latitude.values, + ds.isel(time=0).values.squeeze(), + transform=ccrs.PlateCarree(), + shading="auto", + cmap="coolwarm", + ) + fig.colorbar( + pcm0, ax=ax[0], orientation="vertical", fraction=0.046, pad=0.04 + ).set_label("degC") + + pcm1 = ax[1].pcolormesh( + ds_new.longitude.values, + ds_new.latitude.values, + ds_new.isel(time=0).values.squeeze(), + transform=ccrs.PlateCarree(), + shading="auto", + cmap="coolwarm", + ) + fig.colorbar( + pcm1, ax=ax[1], orientation="vertical", fraction=0.046, pad=0.04 + ).set_label("degC") + + error = ds.isel(time=0) - ds_new.isel(time=0) + pcm2 = ax[2].pcolormesh( + ds.longitude.values, + ds.latitude.values, + error.values.squeeze(), + transform=ccrs.PlateCarree(), + shading="auto", + cmap="coolwarm", + ) + fig.colorbar( + pcm2, ax=ax[2], orientation="vertical", fraction=0.046, pad=0.04 + ).set_label("degC") + + +class Era5Plotter(Plotter): + def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): + selector = dict(time=0) + error = ds.isel(**selector) - ds_new.isel(**selector) + + # Instead of using the inbuilt xarray plot method, we are manually doing + # the projection and calling pcolormesh. By doing so we can avoid having + # to do the projection three times and only have to do it once and re-use + # it between plots. + lons = ds.isel(**selector).longitude.values + lats = ds.isel(**selector).latitude.values + lon_grid, lat_grid = np.meshgrid(lons, lats) + xys = self.projection.transform_points(ccrs.PlateCarree(), lon_grid, lat_grid) + x, y = xys[..., 0], xys[..., 1] + # Wind variable plots coolwarm because they lie around 0 and change in sign + # signifies change in wind direction. + cmap = "coolwarm" if var.startswith("10m") else "viridis" + c1 = ax[0].pcolormesh(x, y, ds.isel(**selector).values.squeeze(), cmap=cmap) + c2 = ax[1].pcolormesh( + x, + y, + ds_new.isel(**selector).values.squeeze(), + cmap=cmap, + ) + c3 = ax[2].pcolormesh(x, y, error.values.squeeze(), cmap="coolwarm") + for i, c in enumerate([c1, c2, c3]): + fig.colorbar(c, ax=ax[i], shrink=0.6) + + +class NextGEMSPlotter(Plotter): + def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): + selector = dict(time=0) + error = ds.isel(**selector) - ds_new.isel(**selector) + + lons = ds.isel(**selector).lon.values + lats = ds.isel(**selector).lat.values + lon_grid, lat_grid = np.meshgrid(lons, lats) + xys = self.projection.transform_points(ccrs.PlateCarree(), lon_grid, lat_grid) + x, y = xys[..., 0], xys[..., 1] + + cmap = "Blues" + max_val = max( + ds.isel(**selector).max().values.item(), + ds_new.isel(**selector).max().values.item(), + ) + color_norm = mcolors.LogNorm(vmin=1e-12, vmax=max_val) if var == "pr" else None + # Avoid zero values for log transformation for precipitation + offset = 1e-12 if var == "pr" else 0 + c1 = ax[0].pcolormesh( + x, + y, + ds.isel(**selector).values.squeeze() + offset, + norm=color_norm, + cmap=cmap, + ) + c2 = ax[1].pcolormesh( + x, + y, + ds_new.isel(**selector).values.squeeze() + offset, + norm=color_norm, + cmap=cmap, + ) + c3 = ax[2].pcolormesh(x, y, error.values.squeeze(), cmap="coolwarm") + for i, c in enumerate([c1, c2, c3]): + fig.colorbar(c, ax=ax[i], shrink=0.6) + + +class CamsPlotter(Plotter): + def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): + selector = dict(valid_time=0, pressure_level=3) + in_min = ds.isel(**selector).min().values.item() + in_max = ds.isel(**selector).max().values.item() + out_min = ds_new.isel(**selector).min().values.item() + out_max = ds_new.isel(**selector).max().values.item() + vmin, vmax = min(in_min, out_min), max(in_max, out_max) + vmin = max(vmin, 1e-14) # Avoid zero values for log transformation + color_norm = mcolors.LogNorm(vmin=vmin, vmax=vmax) + ds.isel(**selector).plot( + ax=ax[0], + transform=ccrs.PlateCarree(), + norm=color_norm, + cmap="gist_earth", + ) + ds_new.isel(**selector).plot( + ax=ax[1], + transform=ccrs.PlateCarree(), + norm=color_norm, + cmap="gist_earth", + ) + error = ds.isel(**selector) - ds_new.isel(**selector) + error.plot(ax=ax[2], transform=ccrs.PlateCarree()) + + +class EsaBiomassPlotter(Plotter): + def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): + selector = dict(time=0) + ds.isel(**selector).plot(ax=ax[0]) + ds_new.isel(**selector).plot(ax=ax[1]) + error = ds.isel(**selector) - ds_new.isel(**selector) + error.plot(ax=ax[2]) + ax[0].set_title("Original Dataset") + ax[1].set_title("Compressed Dataset") + ax[2].set_title("Error") + + +PLOTTERS = { + "cams-nitrogen-dioxide-tiny": CamsPlotter, + "cmip6-access-ta-tiny": CmipAtmosPlotter, + "cmip6-access-tos-tiny": CmipOceanPlotter, + "era5-tiny": Era5Plotter, + "esa-biomass-cci-tiny": EsaBiomassPlotter, + "nextgems-icon-tiny": NextGEMSPlotter, +} From d272a7f76be1f67a8162d8af5dba79dae4c1abc6 Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Wed, 16 Apr 2025 10:07:58 +0100 Subject: [PATCH 2/5] Update dependencies --- pyproject.toml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1c6a9bb..6c9aab1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,8 +6,10 @@ readme = "README.md" requires-python = ">=3.12" dependencies = [ "astropy~=7.0.1", + "cartopy~=0.24.1", "cf-xarray~=0.10", "dask>=2024.12.0,<2025.4", + "matplotlib~=3.10.1", "numcodecs>=0.13.0,<0.17", "numcodecs-combinators[xarray]~=0.2.4", "numcodecs-observers~=0.1.1", @@ -24,7 +26,7 @@ dependencies = [ "numcodecs-wasm-zlib~=0.3.0", "pandas~=2.2", "scipy~=1.14", - "seaborn>=0.13.2", + "seaborn~=0.13.2", "tabulate~=0.9", "typed-classproperties~=1.1.0", "xarray>=2024.11.0,<2025.4", @@ -39,6 +41,7 @@ dev = [ "pre-commit~=4.0", "ruff~=0.9", "scipy-stubs~=1.15", + "types-seaborn~=0.13.2.20250111", ] [tool.setuptools.packages.find] @@ -48,5 +51,5 @@ where = ["src"] "climatebenchpress.compressor" = ["py.typed"] [[tool.mypy.overrides]] -module = ["numcodecs.*", "astropy.convolution.*"] +module = ["numcodecs.*", "astropy.convolution.*", "cartopy.*"] follow_untyped_imports = true From 4e29acbf9861baa96356781cd1e6ae1edf0208ec Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Thu, 17 Apr 2025 15:56:58 +0100 Subject: [PATCH 3/5] Matplotlib version change; pdf printing --- pyproject.toml | 2 +- .../compressor/plotting/plot_metrics.py | 34 ++++++++++-------- .../compressor/plotting/variable_plotters.py | 35 ++++++++++++++----- 3 files changed, 47 insertions(+), 24 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6c9aab1..80954e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ dependencies = [ "cartopy~=0.24.1", "cf-xarray~=0.10", "dask>=2024.12.0,<2025.4", - "matplotlib~=3.10.1", + "matplotlib~=3.8", "numcodecs>=0.13.0,<0.17", "numcodecs-combinators[xarray]~=0.2.4", "numcodecs-observers~=0.1.1", diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index 0d9d85f..9289872 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -8,14 +8,14 @@ from .variable_plotters import PLOTTERS -COMPRESSOR2COLOR = { - "jpeg2000": "#EE7733", - "zfp": "#EE3377", - "sz3": "#CC3311", - "bitround-pco-conservative-rel": "#0077BB", - "bitround-conservative-rel": "#33BBEE", - "stochround": "#009988", - "tthresh": "#BBBBBB", +COMPRESSOR2LINEINFO = { + "jpeg2000": ("#EE7733", "-"), + "zfp": ("#EE3377", "--"), + "sz3": ("#CC3311", "-."), + "bitround-pco-conservative-rel": ("#0077BB", ":"), + "bitround-conservative-rel": ("#33BBEE", "-"), + "stochround": ("#009988", "--"), + "tthresh": ("#BBBBBB", "-."), } COMPRESSOR2LEGEND_NAME = { @@ -46,13 +46,13 @@ def plot_metrics( normalized_df = normalize(df, bound_normalize="mid") plot_bound_violations( - normalized_df, bound_names, plots_path / "bound_violations.png" + normalized_df, bound_names, plots_path / "bound_violations.pdf" ) for metric in ["Normalized_MAE", "Normalized_DSSIM", "Normalized_MaxAbsError"]: plot_aggregated_rd_curve( normalized_df, - plots_path / f"rd_curve_{metric.lower()}.png", + plots_path / f"rd_curve_{metric.lower()}.pdf", compression_metric="Normalized_CR", distortion_metric=metric, agg="median", @@ -151,7 +151,7 @@ def plot_per_variable_metrics( continue plot_variable_rd_curve( df[df["Variable"] == var], - dataset_plots_path / f"{var}_compression_ratio_{metric_name}.png", + dataset_plots_path / f"{var}_compression_ratio_{metric_name}.pdf", distortion_metric=dist_metric, ) @@ -221,12 +221,14 @@ def plot_variable_rd_curve(df, outfile, distortion_metric): for i in sorting_ixs ] distortion = [compressor_data[distortion_metric].iloc[i] for i in sorting_ixs] + color, linestyle = COMPRESSOR2LINEINFO[comp] plt.plot( compr_ratio, distortion, label=COMPRESSOR2LEGEND_NAME[comp], marker="s", - color=COMPRESSOR2COLOR[comp], + color=color, + linestyle=linestyle, linewidth=4, markersize=8, ) @@ -288,12 +290,14 @@ def plot_aggregated_rd_curve( agg_distortion.loc[(bound, comp), distortion_metric] for bound in bound_names ] + color, linestyle = COMPRESSOR2LINEINFO[comp] plt.plot( compr_ratio, distortion, label=COMPRESSOR2LEGEND_NAME[comp], marker="s", - color=COMPRESSOR2COLOR[comp], + color=color, + linestyle=linestyle, linewidth=4, markersize=8, ) @@ -336,8 +340,8 @@ def plot_aggregated_rd_curve( fontsize=10, title_fontsize=12, ) - plt.xlabel("Normalized Compression Ratio", fontsize=14) - plt.ylabel("Normalized Mean Absolute Error", fontsize=14) + plt.xlabel("Median Compression Ratio Relative to SZ3", fontsize=14) + plt.ylabel("Median Mean Absolute Error Relative to SZ3", fontsize=14) # Add an arrow pointing into the lower right corner plt.annotate( "", diff --git a/src/climatebenchpress/compressor/plotting/variable_plotters.py b/src/climatebenchpress/compressor/plotting/variable_plotters.py index 623203b..5fed2b8 100644 --- a/src/climatebenchpress/compressor/plotting/variable_plotters.py +++ b/src/climatebenchpress/compressor/plotting/variable_plotters.py @@ -7,6 +7,8 @@ class Plotter(ABC): + datasets: list[str] + def __init__(self): self.projection = ccrs.Robinson() @@ -35,6 +37,8 @@ def plot(self, ds, ds_new, dataset_name, compressor, var, outfile): class CmipAtmosPlotter(Plotter): + datasets = ["cmip6-access-ta-tiny", "cmip6-access-ta"] + def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): selector = dict(time=0, plev=3) ds.isel(**selector).plot(ax=ax[0], transform=ccrs.PlateCarree()) @@ -46,6 +50,8 @@ def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): class CmipOceanPlotter(Plotter): + datasets = ["cmip6-access-tos-tiny", "cmip6-access-tos"] + def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): pcm0 = ax[0].pcolormesh( ds.longitude.values, @@ -86,6 +92,8 @@ def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): class Era5Plotter(Plotter): + datasets = ["era5-tiny", "era5"] + def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): selector = dict(time=0) error = ds.isel(**selector) - ds_new.isel(**selector) @@ -115,6 +123,8 @@ def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): class NextGEMSPlotter(Plotter): + datasets = ["nextgems-icon-tiny", "nextgems-icon"] + def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): selector = dict(time=0) error = ds.isel(**selector) - ds_new.isel(**selector) @@ -153,6 +163,8 @@ def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): class CamsPlotter(Plotter): + datasets = ["cams-nitrogen-dioxide-tiny", "cams-nitrogen-dioxide"] + def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): selector = dict(valid_time=0, pressure_level=3) in_min = ds.isel(**selector).min().values.item() @@ -179,6 +191,8 @@ def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): class EsaBiomassPlotter(Plotter): + datasets = ["esa-biomass-cci-tiny", "esa-biomass-cci"] + def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): selector = dict(time=0) ds.isel(**selector).plot(ax=ax[0]) @@ -190,11 +204,16 @@ def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): ax[2].set_title("Error") -PLOTTERS = { - "cams-nitrogen-dioxide-tiny": CamsPlotter, - "cmip6-access-ta-tiny": CmipAtmosPlotter, - "cmip6-access-tos-tiny": CmipOceanPlotter, - "era5-tiny": Era5Plotter, - "esa-biomass-cci-tiny": EsaBiomassPlotter, - "nextgems-icon-tiny": NextGEMSPlotter, -} +plotter_clss: list[type[Plotter]] = [ + CamsPlotter, + CmipAtmosPlotter, + CmipOceanPlotter, + Era5Plotter, + EsaBiomassPlotter, + NextGEMSPlotter, +] +PLOTTERS: dict[str, type[Plotter]] = dict() +for plotter_cls in plotter_clss: + for dataset in plotter_cls.datasets: + assert dataset not in PLOTTERS, f"Duplicate dataset found: {dataset}" + PLOTTERS[dataset] = plotter_cls From c3658e65858fbe62c78121f212a31f95594099e4 Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Fri, 18 Apr 2025 11:04:10 +0100 Subject: [PATCH 4/5] Rename normalized metrics to relative metrics --- .../compressor/plotting/plot_metrics.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index 9289872..62eeabb 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -49,11 +49,11 @@ def plot_metrics( normalized_df, bound_names, plots_path / "bound_violations.pdf" ) - for metric in ["Normalized_MAE", "Normalized_DSSIM", "Normalized_MaxAbsError"]: + for metric in ["Relative MAE", "Relative DSSIM", "Relative MaxAbsError"]: plot_aggregated_rd_curve( normalized_df, - plots_path / f"rd_curve_{metric.lower()}.pdf", - compression_metric="Normalized_CR", + plots_path / f"rd_curve_{metric.lower().replace(' ', '_')}.pdf", + compression_metric="Relative CR", distortion_metric=metric, agg="median", bound_names=bound_names, @@ -109,10 +109,10 @@ def normalize(data, bound_normalize="mid"): normalized = data.copy() normalize_vars = [ - ("Compression Ratio [raw B / enc B]", "Normalized_CR"), - ("MAE", "Normalized_MAE"), - ("DSSIM", "Normalized_DSSIM"), - ("Max Absolute Error", "Normalized_MaxAbsError"), + ("Compression Ratio [raw B / enc B]", "Relative CR"), + ("MAE", "Relative MAE"), + ("DSSIM", "Relative DSSIM"), + ("Max Absolute Error", "Relative MaxAbsError"), ] # Avoid negative values. By default, DSSIM is in the range [-1, 1]. normalized["DSSIM"] = normalized["DSSIM"] + 1.0 @@ -302,12 +302,12 @@ def plot_aggregated_rd_curve( markersize=8, ) - plt.xlabel(compression_metric, fontsize=14) + plt.xlabel(f"{agg.title()} {compression_metric}", fontsize=14) plt.xscale("log") if "PSNR" not in distortion_metric: # PSNR is already on log scale. plt.yscale("log") - plt.ylabel(distortion_metric, fontsize=14) + plt.ylabel(f"{agg.title()} {distortion_metric}", fontsize=14) plt.legend( title="Compressor", From ee6b0d016a8a2eb64f39cb928c79fff1d1af3a14 Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Fri, 18 Apr 2025 11:04:44 +0100 Subject: [PATCH 5/5] Add rasterized and dpi=300 to plotting --- .../compressor/plotting/variable_plotters.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/climatebenchpress/compressor/plotting/variable_plotters.py b/src/climatebenchpress/compressor/plotting/variable_plotters.py index 5fed2b8..ee202ff 100644 --- a/src/climatebenchpress/compressor/plotting/variable_plotters.py +++ b/src/climatebenchpress/compressor/plotting/variable_plotters.py @@ -31,8 +31,8 @@ def plot(self, ds, ds_new, dataset_name, compressor, var, outfile): ax[1].set_title("Compressed Dataset") ax[2].set_title("Error") fig.suptitle(f"{var} Error for {dataset_name} ({compressor})") - plt.tight_layout() - plt.savefig(outfile) + fig.tight_layout() + fig.savefig(outfile, dpi=300) plt.close() @@ -46,7 +46,7 @@ def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): ax=ax[1], transform=ccrs.PlateCarree(), robust=True ) error = ds.isel(**selector) - ds_new.isel(**selector) - error.plot(ax=ax[2], transform=ccrs.PlateCarree()) + error.plot(ax=ax[2], transform=ccrs.PlateCarree(), rasterized=True) class CmipOceanPlotter(Plotter): @@ -60,6 +60,7 @@ def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): transform=ccrs.PlateCarree(), shading="auto", cmap="coolwarm", + rasterized=True, ) fig.colorbar( pcm0, ax=ax[0], orientation="vertical", fraction=0.046, pad=0.04 @@ -72,6 +73,7 @@ def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): transform=ccrs.PlateCarree(), shading="auto", cmap="coolwarm", + rasterized=True, ) fig.colorbar( pcm1, ax=ax[1], orientation="vertical", fraction=0.046, pad=0.04 @@ -85,6 +87,7 @@ def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): transform=ccrs.PlateCarree(), shading="auto", cmap="coolwarm", + rasterized=True, ) fig.colorbar( pcm2, ax=ax[2], orientation="vertical", fraction=0.046, pad=0.04 @@ -116,6 +119,7 @@ def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): y, ds_new.isel(**selector).values.squeeze(), cmap=cmap, + rasterized=True, ) c3 = ax[2].pcolormesh(x, y, error.values.squeeze(), cmap="coolwarm") for i, c in enumerate([c1, c2, c3]): @@ -149,6 +153,7 @@ def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): ds.isel(**selector).values.squeeze() + offset, norm=color_norm, cmap=cmap, + rasterized=True, ) c2 = ax[1].pcolormesh( x, @@ -156,6 +161,7 @@ def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): ds_new.isel(**selector).values.squeeze() + offset, norm=color_norm, cmap=cmap, + rasterized=True, ) c3 = ax[2].pcolormesh(x, y, error.values.squeeze(), cmap="coolwarm") for i, c in enumerate([c1, c2, c3]): @@ -179,12 +185,14 @@ def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): transform=ccrs.PlateCarree(), norm=color_norm, cmap="gist_earth", + rasterized=True, ) ds_new.isel(**selector).plot( ax=ax[1], transform=ccrs.PlateCarree(), norm=color_norm, cmap="gist_earth", + rasterized=True, ) error = ds.isel(**selector) - ds_new.isel(**selector) error.plot(ax=ax[2], transform=ccrs.PlateCarree()) @@ -198,7 +206,7 @@ def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): ds.isel(**selector).plot(ax=ax[0]) ds_new.isel(**selector).plot(ax=ax[1]) error = ds.isel(**selector) - ds_new.isel(**selector) - error.plot(ax=ax[2]) + error.plot(ax=ax[2], rasterized=True) ax[0].set_title("Original Dataset") ax[1].set_title("Compressed Dataset") ax[2].set_title("Error")