diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index 8f8bdf1..af8d622 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -10,6 +10,7 @@ COMPRESSOR2LINEINFO = { "jpeg2000": ("#EE7733", "-"), + "sperr": ("#000000", ":"), "zfp": ("#EE3377", "--"), "zfp-round": ("#DDAA33", "--"), "sz3": ("#CC3311", "-."), @@ -17,11 +18,12 @@ "bitround-conservative-rel": ("#33BBEE", "-"), "stochround": ("#009988", "--"), "stochround-pco": ("#BBBBBB", "--"), - "tthresh": ("#000000", "-."), + "tthresh": ("#882255", "-."), } COMPRESSOR2LEGEND_NAME = { "jpeg2000": "JPEG2000", + "sperr": "SPERR", "zfp": "ZFP", "zfp-round": "ZFP-ROUND", "sz3": "SZ3", @@ -37,6 +39,7 @@ def plot_metrics( basepath: Path = Path(), data_loader_base_path: None | Path = None, bound_names: list[str] = ["low", "mid", "high"], + normalizer: str = "sz3", ): metrics_path = basepath / "metrics" plots_path = basepath / "plots" @@ -44,6 +47,11 @@ def plot_metrics( compressed_datasets = basepath / "compressed-datasets" df = pd.read_csv(metrics_path / "all_results.csv") + df = df[ + np.logical_and( + df["Compressor"] != "tthresh", df["Dataset"].str.endswith("tiny") + ) + ] plot_per_variable_metrics( datasets=datasets, compressed_datasets=compressed_datasets, @@ -52,21 +60,23 @@ def plot_metrics( ) df = rename_error_bounds(df, bound_names) - normalized_df = normalize(df, bound_normalize="mid") + normalized_df = normalize(df, bound_normalize="mid", normalizer=normalizer) plot_bound_violations( normalized_df, bound_names, plots_path / "bound_violations.pdf" ) for metric in ["Relative MAE", "Relative DSSIM", "Relative MaxAbsError"]: - plot_aggregated_rd_curve( - normalized_df, - compression_metric="Relative CR", - distortion_metric=metric, - outfile=plots_path / f"rd_curve_{metric.lower().replace(' ', '_')}.pdf", - agg="median", - bound_names=bound_names, - ) + 
with plt.rc_context(rc={"text.usetex": True}): + plot_aggregated_rd_curve( + normalized_df, + normalizer=normalizer, + compression_metric="Relative CR", + distortion_metric=metric, + outfile=plots_path / f"rd_curve_{metric.lower().replace(' ', '_')}.pdf", + agg="median", + bound_names=bound_names, + ) def rename_error_bounds(df, bound_names): @@ -93,28 +103,30 @@ def rename_error_bounds(df, bound_names): return df -def normalize(data, bound_normalize="mid"): +def normalize(data, bound_normalize="mid", normalizer=None): """Generate normalized metrics for each compressor and variable. The normalization - first computes the 'best compressor' with the highest average rank over all variables (ranked by + is done with respect to either a user-provided compressor or the + compressor with the highest average rank over all variables (ranked by compression ratio). For each metric, the normalization is done by dividing the metric by the value of the - 'best compressor' for the same variable and error bound, i.e.: - normalized_metric = metric[compressor, variable] / metric[best_compressor, variable]. + normalizer for the same variable and error bound, i.e.: + normalized_metric = metric[compressor, variable] / metric[normalizer, variable]. 
""" - # Group by Variable and rank compressors within each variable - ranked = data.copy() - ranked = ranked[ranked["Error Bound"] == bound_normalize] - ranked["CompRatio_Rank"] = ranked.groupby("Variable")[ - "Compression Ratio [raw B / enc B]" - ].rank(ascending=False) + if normalizer is None: + # Group by Variable and rank compressors within each variable + ranked = data.copy() + ranked = ranked[ranked["Error Bound"] == bound_normalize] + ranked["CompRatio_Rank"] = ranked.groupby("Variable")[ + "Compression Ratio [raw B / enc B]" + ].rank(ascending=False) - # Calculate average rank for each compressor across all variables - avg_ranks = ranked.groupby("Compressor")["CompRatio_Rank"].mean().reset_index() - avg_ranks.columns = ["Compressor", "Average_Rank"] - avg_ranks = avg_ranks.sort_values("Average_Rank") + # Calculate average rank for each compressor across all variables + avg_ranks = ranked.groupby("Compressor")["CompRatio_Rank"].mean().reset_index() + avg_ranks.columns = ["Compressor", "Average_Rank"] + avg_ranks = avg_ranks.sort_values("Average_Rank") - best_compressor = avg_ranks.iloc[0]["Compressor"] + normalizer = avg_ranks.iloc[0]["Compressor"] normalized = data.copy() normalize_vars = [ @@ -128,7 +140,7 @@ def normalize(data, bound_normalize="mid"): def get_normalizer(row): return normalized[ - (data["Compressor"] == best_compressor) + (data["Compressor"] == normalizer) & (data["Variable"] == row["Variable"]) & (data["Error Bound"] == bound_normalize) ][col].item() @@ -276,13 +288,13 @@ def plot_variable_rd_curve(df, distortion_metric, outfile: None | Path = None): plt.tight_layout() if outfile is not None: - with outfile.open("wb") as f: - plt.savefig(f, dpi=300) + savefig(outfile) plt.close() def plot_aggregated_rd_curve( normalized_df, + normalizer, compression_metric, distortion_metric, outfile: None | Path = None, @@ -349,36 +361,47 @@ def plot_aggregated_rd_curve( plt.legend( title="Compressor", loc="upper right", - bbox_to_anchor=(0.95, 0.6), - 
fontsize=10, - title_fontsize=12, + bbox_to_anchor=(0.95, 0.7), + fontsize=12, + title_fontsize=14, + ) + normalizer_label = COMPRESSOR2LEGEND_NAME.get(normalizer, normalizer) + plt.xlabel( + rf"Median Compression Ratio Relative to {normalizer_label} ($\uparrow$)", + fontsize=16, ) - plt.xlabel("Median Compression Ratio Relative to SZ3", fontsize=14) - plt.ylabel("Median Mean Absolute Error Relative to SZ3", fontsize=14) + plt.ylabel( + rf"Median Mean Absolute Error Relative to {normalizer_label} ($\downarrow$)", + fontsize=16, + ) + arrow_color = "black" # Add an arrow pointing into the lower right corner plt.annotate( "", - xy=(0.97, 0.05), + xy=(0.95, 0.05), xycoords="axes fraction", xytext=(-60, 50), textcoords="offset points", - arrowprops=dict(arrowstyle="-|>", color="grey", lw=2), + arrowprops=dict( + arrowstyle="-|>, head_length=0.5, head_width=0.5", + color=arrow_color, + lw=5, + ), ) plt.text( - 0.85, + 0.83, 0.08, "Better", transform=plt.gca().transAxes, - fontsize=14, + fontsize=16, fontweight="bold", - color="grey", + color=arrow_color, ha="center", ) plt.tight_layout() if outfile is not None: - with outfile.open("wb") as f: - plt.savefig(f, dpi=300) + savefig(outfile) plt.close() @@ -413,10 +436,21 @@ def plot_bound_violations(df, bound_names, outfile: None | Path = None): fig.tight_layout() if outfile is not None: - with outfile.open("wb") as f: - fig.savefig(f, dpi=300) + savefig(outfile) plt.close() +def savefig(outfile: Path): + ispdf = outfile.suffix == ".pdf" + if ispdf: + # Saving a PDF with the alternative code below leads to a corrupted file. + # Hence, we use the default savefig method. + # NOTE: This means passing a virtual UPath is only supported for non-PDF files. + plt.savefig(outfile, dpi=300) + else: + with outfile.open("wb") as f: + plt.savefig(f, dpi=300) + + if __name__ == "__main__": - plot_metrics(basepath=Path()) + plot_metrics(basepath=Path(), data_loader_base_path=Path() / ".." / "data-loader")