Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 77 additions & 43 deletions src/climatebenchpress/compressor/plotting/plot_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,20 @@

COMPRESSOR2LINEINFO = {
"jpeg2000": ("#EE7733", "-"),
"sperr": ("#000000", ":"),
"zfp": ("#EE3377", "--"),
"zfp-round": ("#DDAA33", "--"),
"sz3": ("#CC3311", "-."),
"bitround-pco-conservative-rel": ("#0077BB", ":"),
"bitround-conservative-rel": ("#33BBEE", "-"),
"stochround": ("#009988", "--"),
"stochround-pco": ("#BBBBBB", "--"),
"tthresh": ("#000000", "-."),
"tthresh": ("#882255", "-."),
}

COMPRESSOR2LEGEND_NAME = {
"jpeg2000": "JPEG2000",
"sperr": "SPERR",
"zfp": "ZFP",
"zfp-round": "ZFP-ROUND",
"sz3": "SZ3",
Expand All @@ -37,13 +39,19 @@ def plot_metrics(
basepath: Path = Path(),
data_loader_base_path: None | Path = None,
bound_names: list[str] = ["low", "mid", "high"],
normalizer: str = "sz3",
):
metrics_path = basepath / "metrics"
plots_path = basepath / "plots"
datasets = (data_loader_base_path or basepath) / "datasets"
compressed_datasets = basepath / "compressed-datasets"

df = pd.read_csv(metrics_path / "all_results.csv")
df = df[
np.logical_and(
df["Compressor"] != "tthresh", df["Dataset"].str.endswith("tiny")
)
]
plot_per_variable_metrics(
datasets=datasets,
compressed_datasets=compressed_datasets,
Expand All @@ -52,21 +60,23 @@ def plot_metrics(
)

df = rename_error_bounds(df, bound_names)
normalized_df = normalize(df, bound_normalize="mid")
normalized_df = normalize(df, bound_normalize="mid", normalizer=normalizer)

plot_bound_violations(
normalized_df, bound_names, plots_path / "bound_violations.pdf"
)

for metric in ["Relative MAE", "Relative DSSIM", "Relative MaxAbsError"]:
plot_aggregated_rd_curve(
normalized_df,
compression_metric="Relative CR",
distortion_metric=metric,
outfile=plots_path / f"rd_curve_{metric.lower().replace(' ', '_')}.pdf",
agg="median",
bound_names=bound_names,
)
with plt.rc_context(rc={"text.usetex": True}):
plot_aggregated_rd_curve(
normalized_df,
normalizer=normalizer,
compression_metric="Relative CR",
distortion_metric=metric,
outfile=plots_path / f"rd_curve_{metric.lower().replace(' ', '_')}.pdf",
agg="median",
bound_names=bound_names,
)


def rename_error_bounds(df, bound_names):
Expand All @@ -93,28 +103,30 @@ def rename_error_bounds(df, bound_names):
return df


def normalize(data, bound_normalize="mid"):
def normalize(data, bound_normalize="mid", normalizer=None):
"""Generate normalized metrics for each compressor and variable. The normalization
first computes the 'best compressor' with the highest average rank over all variables (ranked by
is done with respect to either a user-provided compressor or the
compressor with the highest average rank over all variables (ranked by
compression ratio).

For each metric, the normalization is done by dividing the metric by the value of the
'best compressor' for the same variable and error bound, i.e.:
normalized_metric = metric[compressor, variable] / metric[best_compressor, variable].
normalizer for the same variable and error bound, i.e.:
normalized_metric = metric[compressor, variable] / metric[normalizer, variable].
"""
# Group by Variable and rank compressors within each variable
ranked = data.copy()
ranked = ranked[ranked["Error Bound"] == bound_normalize]
ranked["CompRatio_Rank"] = ranked.groupby("Variable")[
"Compression Ratio [raw B / enc B]"
].rank(ascending=False)
if normalizer is None:
# Group by Variable and rank compressors within each variable
ranked = data.copy()
ranked = ranked[ranked["Error Bound"] == bound_normalize]
ranked["CompRatio_Rank"] = ranked.groupby("Variable")[
"Compression Ratio [raw B / enc B]"
].rank(ascending=False)

# Calculate average rank for each compressor across all variables
avg_ranks = ranked.groupby("Compressor")["CompRatio_Rank"].mean().reset_index()
avg_ranks.columns = ["Compressor", "Average_Rank"]
avg_ranks = avg_ranks.sort_values("Average_Rank")
# Calculate average rank for each compressor across all variables
avg_ranks = ranked.groupby("Compressor")["CompRatio_Rank"].mean().reset_index()
avg_ranks.columns = ["Compressor", "Average_Rank"]
avg_ranks = avg_ranks.sort_values("Average_Rank")

best_compressor = avg_ranks.iloc[0]["Compressor"]
normalizer = avg_ranks.iloc[0]["Compressor"]

normalized = data.copy()
normalize_vars = [
Expand All @@ -128,7 +140,7 @@ def normalize(data, bound_normalize="mid"):

def get_normalizer(row):
return normalized[
(data["Compressor"] == best_compressor)
(data["Compressor"] == normalizer)
& (data["Variable"] == row["Variable"])
& (data["Error Bound"] == bound_normalize)
][col].item()
Expand Down Expand Up @@ -276,13 +288,13 @@ def plot_variable_rd_curve(df, distortion_metric, outfile: None | Path = None):

plt.tight_layout()
if outfile is not None:
with outfile.open("wb") as f:
plt.savefig(f, dpi=300)
savefig(outfile)
plt.close()


def plot_aggregated_rd_curve(
normalized_df,
normalizer,
compression_metric,
distortion_metric,
outfile: None | Path = None,
Expand Down Expand Up @@ -349,36 +361,47 @@ def plot_aggregated_rd_curve(
plt.legend(
title="Compressor",
loc="upper right",
bbox_to_anchor=(0.95, 0.6),
fontsize=10,
title_fontsize=12,
bbox_to_anchor=(0.95, 0.7),
fontsize=12,
title_fontsize=14,
)
normalizer_label = COMPRESSOR2LEGEND_NAME.get(normalizer, normalizer)
plt.xlabel(
rf"Median Compression Ratio Relative to {normalizer_label} ($\uparrow$)",
fontsize=16,
)
plt.xlabel("Median Compression Ratio Relative to SZ3", fontsize=14)
plt.ylabel("Median Mean Absolute Error Relative to SZ3", fontsize=14)
plt.ylabel(
rf"Median Mean Absolute Error Relative to {normalizer_label} ($\downarrow$)",
fontsize=16,
)
arrow_color = "black"
# Add an arrow pointing into the lower right corner
plt.annotate(
"",
xy=(0.97, 0.05),
xy=(0.95, 0.05),
xycoords="axes fraction",
xytext=(-60, 50),
textcoords="offset points",
arrowprops=dict(arrowstyle="-|>", color="grey", lw=2),
arrowprops=dict(
arrowstyle="-|>, head_length=0.5, head_width=0.5",
color=arrow_color,
lw=5,
),
)
plt.text(
0.85,
0.83,
0.08,
"Better",
transform=plt.gca().transAxes,
fontsize=14,
fontsize=16,
fontweight="bold",
color="grey",
color=arrow_color,
ha="center",
)

plt.tight_layout()
if outfile is not None:
with outfile.open("wb") as f:
plt.savefig(f, dpi=300)
savefig(outfile)
plt.close()


Expand Down Expand Up @@ -413,10 +436,21 @@ def plot_bound_violations(df, bound_names, outfile: None | Path = None):

fig.tight_layout()
if outfile is not None:
with outfile.open("wb") as f:
fig.savefig(f, dpi=300)
savefig(outfile)
plt.close()


def savefig(outfile: Path):
ispdf = outfile.suffix == ".pdf"
if ispdf:
# Saving a PDF with the alternative code below leads to a corrupted file.
# Hence, we use the default savefig method.
# NOTE: This means passing a virtual UPath is only supported for non-PDF files.
plt.savefig(outfile, dpi=300)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At least we should add a warning comment here saying that virtual UPaths are only supported with the second method

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a comment!

else:
with outfile.open("wb") as f:
plt.savefig(f, dpi=300)


if __name__ == "__main__":
plot_metrics(basepath=Path())
plot_metrics(basepath=Path(), data_loader_base_path=Path() / ".." / "data-loader")