Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 44 additions & 31 deletions src/climatebenchpress/compressor/plotting/plot_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,27 @@
"zfp-round": "ZFP-ROUND",
"sz3": "SZ3",
"bitround-pco-conservative-rel": "BitRound + PCO",
"bitround-conservative-rel": "BitRound + Zlib",
"stochround": "StochRound + Zlib",
"bitround-conservative-rel": "BitRound + Zstd",
"stochround": "StochRound + Zstd",
"stochround-pco": "StochRound + PCO",
"tthresh": "TTHRESH",
}


def plot_metrics(
basepath: Path = Path(), bound_names: list[str] = ["low", "mid", "high"]
basepath: Path = Path(),
data_loader_base_path: None | Path = None,
bound_names: list[str] = ["low", "mid", "high"],
):
metrics_path = basepath / "metrics"
plots_path = basepath / "plots"
datasets = (data_loader_base_path or basepath) / "datasets"
compressed_datasets = basepath / "compressed-datasets"

df = pd.read_csv(metrics_path / "all_results.csv")
plot_per_variable_metrics(
basepath=basepath,
datasets=datasets,
compressed_datasets=compressed_datasets,
plots_path=plots_path,
all_results=df,
)
Expand All @@ -56,9 +61,9 @@ def plot_metrics(
for metric in ["Relative MAE", "Relative DSSIM", "Relative MaxAbsError"]:
plot_aggregated_rd_curve(
normalized_df,
plots_path / f"rd_curve_{metric.lower().replace(' ', '_')}.pdf",
compression_metric="Relative CR",
distortion_metric=metric,
outfile=plots_path / f"rd_curve_{metric.lower().replace(' ', '_')}.pdf",
agg="median",
bound_names=bound_names,
)
Expand Down Expand Up @@ -138,7 +143,10 @@ def get_normalizer(row):


def plot_per_variable_metrics(
basepath: Path, plots_path: Path, all_results: pd.DataFrame
datasets: Path,
compressed_datasets: Path,
plots_path: Path,
all_results: pd.DataFrame,
):
"""Creates all the plots which only depend on a single variable."""
for dataset in all_results["Dataset"].unique():
Expand All @@ -155,8 +163,9 @@ def plot_per_variable_metrics(
continue
plot_variable_rd_curve(
df[df["Variable"] == var],
dataset_plots_path / f"{var}_compression_ratio_{metric_name}.pdf",
distortion_metric=dist_metric,
outfile=dataset_plots_path
/ f"{var}_compression_ratio_{metric_name}.pdf",
)

error_bounds = df[df["Variable"] == var]["Error Bound"].unique()
Expand All @@ -170,51 +179,49 @@ def plot_per_variable_metrics(
for comp in compressors:
print(f"Plotting {var} error for {comp}...")
plot_variable_error(
basepath,
datasets,
compressed_datasets,
dataset,
err_bound,
comp,
var,
err_bound_path / f"{var}_{comp}.png",
outfile=err_bound_path / f"{var}_{comp}.png",
)


def plot_variable_error(basepath, dataset_name, error_bound, compressor, var, outfile):
if outfile.exists():
def plot_variable_error(
datasets: Path,
compressed_datasets: Path,
dataset_name: str,
error_bound: str,
compressor: str,
var: str,
outfile: None | Path = None,
):
if outfile is not None and outfile.exists():
# These plots can be quite expensive to generate, so we skip if they already exist.
return

compressed = (
basepath
/ ".."
/ "compressor"
/ "compressed-datasets"
compressed_datasets
/ dataset_name
/ error_bound
/ compressor
/ "decompressed.zarr"
)
input = (
basepath
/ ".."
/ "data-loader"
/ "datasets"
/ dataset_name
/ "standardized.zarr"
)
input = datasets / dataset_name / "standardized.zarr"

ds = xr.open_dataset(input, chunks=dict(), engine="zarr").compute()
ds_new = xr.open_dataset(compressed, chunks=dict(), engine="zarr").compute()
ds, ds_new = ds[var], ds_new[var]

plotter = PLOTTERS.get(dataset_name, None)
if plotter:
plotter().plot(ds, ds_new, dataset_name, compressor, var, outfile)
plotter().plot(ds[var], ds_new[var], dataset_name, compressor, var, outfile)
else:
print(f"No plotter found for dataset {dataset_name}")


def plot_variable_rd_curve(df, outfile, distortion_metric):
def plot_variable_rd_curve(df, distortion_metric, outfile: None | Path = None):
plt.figure(figsize=(8, 6))
compressors = df["Compressor"].unique()
for comp in compressors:
Expand Down Expand Up @@ -268,15 +275,17 @@ def plot_variable_rd_curve(df, outfile, distortion_metric):
)

plt.tight_layout()
plt.savefig(outfile, dpi=300)
if outfile is not None:
with outfile.open("wb") as f:
plt.savefig(f, dpi=300)
plt.close()


def plot_aggregated_rd_curve(
normalized_df,
outfile,
compression_metric,
distortion_metric,
outfile: None | Path = None,
agg="median",
bound_names=["low", "mid", "high"],
):
Expand Down Expand Up @@ -367,11 +376,13 @@ def plot_aggregated_rd_curve(
)

plt.tight_layout()
plt.savefig(outfile, dpi=300)
if outfile is not None:
with outfile.open("wb") as f:
plt.savefig(f, dpi=300)
plt.close()


def plot_bound_violations(df, bound_names, outfile):
def plot_bound_violations(df, bound_names, outfile: None | Path = None):
fig, axs = plt.subplots(1, 3, figsize=(len(bound_names) * 6, 6), sharey=True)

for i, bound_name in enumerate(bound_names):
Expand Down Expand Up @@ -401,7 +412,9 @@ def plot_bound_violations(df, bound_names, outfile):
axs[i].set_ylabel("")

fig.tight_layout()
fig.savefig(outfile, dpi=300)
if outfile is not None:
with outfile.open("wb") as f:
fig.savefig(f, dpi=300)
plt.close()


Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from pathlib import Path

import cartopy.crs as ccrs
import matplotlib.colors as mcolors
Expand All @@ -16,7 +17,9 @@ def __init__(self):
def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var):
pass

def plot(self, ds, ds_new, dataset_name, compressor, var, outfile):
def plot(
self, ds, ds_new, dataset_name, compressor, var, outfile: None | Path = None
):
fig, ax = plt.subplots(
nrows=1,
ncols=3,
Expand All @@ -32,7 +35,9 @@ def plot(self, ds, ds_new, dataset_name, compressor, var, outfile):
ax[2].set_title("Error")
fig.suptitle(f"{var} Error for {dataset_name} ({compressor})")
fig.tight_layout()
fig.savefig(outfile, dpi=300)
if outfile is not None:
with outfile.open("wb") as f:
fig.savefig(f, dpi=300)
plt.close()


Expand Down
Loading