diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index c1861b9..8f8bdf1 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -26,22 +26,27 @@ "zfp-round": "ZFP-ROUND", "sz3": "SZ3", "bitround-pco-conservative-rel": "BitRound + PCO", - "bitround-conservative-rel": "BitRound + Zlib", - "stochround": "StochRound + Zlib", + "bitround-conservative-rel": "BitRound + Zstd", + "stochround": "StochRound + Zstd", "stochround-pco": "StochRound + PCO", "tthresh": "TTHRESH", } def plot_metrics( - basepath: Path = Path(), bound_names: list[str] = ["low", "mid", "high"] + basepath: Path = Path(), + data_loader_base_path: None | Path = None, + bound_names: list[str] = ["low", "mid", "high"], ): metrics_path = basepath / "metrics" plots_path = basepath / "plots" + datasets = (data_loader_base_path or basepath) / "datasets" + compressed_datasets = basepath / "compressed-datasets" df = pd.read_csv(metrics_path / "all_results.csv") plot_per_variable_metrics( - basepath=basepath, + datasets=datasets, + compressed_datasets=compressed_datasets, plots_path=plots_path, all_results=df, ) @@ -56,9 +61,9 @@ def plot_metrics( for metric in ["Relative MAE", "Relative DSSIM", "Relative MaxAbsError"]: plot_aggregated_rd_curve( normalized_df, - plots_path / f"rd_curve_{metric.lower().replace(' ', '_')}.pdf", compression_metric="Relative CR", distortion_metric=metric, + outfile=plots_path / f"rd_curve_{metric.lower().replace(' ', '_')}.pdf", agg="median", bound_names=bound_names, ) @@ -138,7 +143,10 @@ def get_normalizer(row): def plot_per_variable_metrics( - basepath: Path, plots_path: Path, all_results: pd.DataFrame + datasets: Path, + compressed_datasets: Path, + plots_path: Path, + all_results: pd.DataFrame, ): """Creates all the plots which only depend on a single variable.""" for dataset in all_results["Dataset"].unique(): @@ -155,8 +163,9 @@ def plot_per_variable_metrics( continue plot_variable_rd_curve( df[df["Variable"] == var], - dataset_plots_path / f"{var}_compression_ratio_{metric_name}.pdf", distortion_metric=dist_metric, + outfile=dataset_plots_path + / f"{var}_compression_ratio_{metric_name}.pdf", ) error_bounds = df[df["Variable"] == var]["Error Bound"].unique() @@ -170,51 +179,49 @@ def plot_per_variable_metrics( for comp in compressors: print(f"Plotting {var} error for {comp}...") plot_variable_error( - basepath, + datasets, + compressed_datasets, dataset, err_bound, comp, var, - err_bound_path / f"{var}_{comp}.png", + outfile=err_bound_path / f"{var}_{comp}.png", ) -def plot_variable_error(basepath, dataset_name, error_bound, compressor, var, outfile): - if outfile.exists(): +def plot_variable_error( + datasets: Path, + compressed_datasets: Path, + dataset_name: str, + error_bound: str, + compressor: str, + var: str, + outfile: None | Path = None, +): + if outfile is not None and outfile.exists(): # These plots can be quite expensive to generate, so we skip if they already exist. return compressed = ( - basepath - / ".." - / "compressor" - / "compressed-datasets" + compressed_datasets / dataset_name / error_bound / compressor / "decompressed.zarr" ) - input = ( - basepath - / ".." - / "data-loader" - / "datasets" - / dataset_name - / "standardized.zarr" - ) + input = datasets / dataset_name / "standardized.zarr" ds = xr.open_dataset(input, chunks=dict(), engine="zarr").compute() ds_new = xr.open_dataset(compressed, chunks=dict(), engine="zarr").compute() - ds, ds_new = ds[var], ds_new[var] plotter = PLOTTERS.get(dataset_name, None) if plotter: - plotter().plot(ds, ds_new, dataset_name, compressor, var, outfile) + plotter().plot(ds[var], ds_new[var], dataset_name, compressor, var, outfile) else: print(f"No plotter found for dataset {dataset_name}") -def plot_variable_rd_curve(df, outfile, distortion_metric): +def plot_variable_rd_curve(df, distortion_metric, outfile: None | Path = None): plt.figure(figsize=(8, 6)) compressors = df["Compressor"].unique() for comp in compressors: @@ -268,15 +275,17 @@ def plot_variable_rd_curve(df, outfile, distortion_metric): ) plt.tight_layout() - plt.savefig(outfile, dpi=300) + if outfile is not None: + with outfile.open("wb") as f: + plt.savefig(f, dpi=300) plt.close() def plot_aggregated_rd_curve( normalized_df, - outfile, compression_metric, distortion_metric, + outfile: None | Path = None, agg="median", bound_names=["low", "mid", "high"], ): @@ -367,11 +376,13 @@ def plot_aggregated_rd_curve( ) plt.tight_layout() - plt.savefig(outfile, dpi=300) + if outfile is not None: + with outfile.open("wb") as f: + plt.savefig(f, dpi=300) plt.close() -def plot_bound_violations(df, bound_names, outfile): +def plot_bound_violations(df, bound_names, outfile: None | Path = None): fig, axs = plt.subplots(1, 3, figsize=(len(bound_names) * 6, 6), sharey=True) for i, bound_name in enumerate(bound_names): @@ -401,7 +412,9 @@ def plot_bound_violations(df, bound_names, outfile): axs[i].set_ylabel("") fig.tight_layout() - fig.savefig(outfile, dpi=300) + if outfile is not None: + with outfile.open("wb") as f: + fig.savefig(f, dpi=300) plt.close() diff --git a/src/climatebenchpress/compressor/plotting/variable_plotters.py b/src/climatebenchpress/compressor/plotting/variable_plotters.py index ee202ff..8c76858 100644 --- a/src/climatebenchpress/compressor/plotting/variable_plotters.py +++ b/src/climatebenchpress/compressor/plotting/variable_plotters.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from pathlib import Path import cartopy.crs as ccrs import matplotlib.colors as mcolors @@ -16,7 +17,9 @@ def __init__(self): def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): pass - def plot(self, ds, ds_new, dataset_name, compressor, var, outfile): + def plot( + self, ds, ds_new, dataset_name, compressor, var, outfile: None | Path = None + ): fig, ax = plt.subplots( nrows=1, ncols=3, @@ -32,7 +35,9 @@ def plot(self, ds, ds_new, dataset_name, compressor, var, outfile): ax[2].set_title("Error") fig.suptitle(f"{var} Error for {dataset_name} ({compressor})") fig.tight_layout() - fig.savefig(outfile, dpi=300) + if outfile is not None: + with outfile.open("wb") as f: + fig.savefig(f, dpi=300) plt.close() diff --git a/src/climatebenchpress/py.typed b/src/climatebenchpress/compressor/py.typed similarity index 100% rename from src/climatebenchpress/py.typed rename to src/climatebenchpress/compressor/py.typed