Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions src/climatebenchpress/compressor/plotting/plot_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import seaborn as sns
import xarray as xr

from ..scripts.collect_metrics import parse_error_bounds
from ..scripts.compute_metrics import parse_error_bounds
from .error_dist_plotter import ErrorDistPlotter
from .variable_plotters import PLOTTERS

Expand Down Expand Up @@ -58,7 +58,7 @@ def get_legend_name(compressor: str) -> str:

def plot_metrics(
basepath: Path = Path(),
data_loader_base_path: None | Path = None,
data_loader_basepath: None | Path = None,
bound_names: list[str] = ["low", "mid", "high"],
normalizer: str = "sz3",
exclude_dataset: list[str] = [],
Expand All @@ -68,7 +68,7 @@ def plot_metrics(
):
metrics_path = basepath / "metrics"
plots_path = basepath / "plots"
datasets = (data_loader_base_path or basepath) / "datasets"
datasets = (data_loader_basepath or basepath) / "datasets"
compressed_datasets = basepath / "compressed-datasets"

df = pd.read_csv(metrics_path / "all_results.csv")
Expand Down Expand Up @@ -696,11 +696,17 @@ def savefig(outfile: Path):
parser.add_argument("--exclude-compressor", type=str, nargs="+", default=[])
parser.add_argument("--tiny-datasets", action="store_true", default=False)
parser.add_argument("--avoid-latex", action="store_true", default=False)
parser.add_argument("--basepath", type=Path, default=Path())
parser.add_argument(
"--data-loader-basepath",
type=Path,
default=Path() / ".." / "data-loader",
)
args = parser.parse_args()

plot_metrics(
basepath=Path(),
data_loader_base_path=Path() / ".." / "data-loader",
basepath=args.basepath,
data_loader_basepath=args.data_loader_basepath,
exclude_compressor=args.exclude_compressor,
exclude_dataset=args.exclude_dataset,
tiny_datasets=args.tiny_datasets,
Expand Down
10 changes: 7 additions & 3 deletions src/climatebenchpress/compressor/scripts/compress.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ def compress(
include_dataset: None | Container[str] = None,
exclude_compressor: Container[str] = tuple(),
include_compressor: None | Container[str] = None,
data_loader_base_path: None | Path = None,
data_loader_basepath: None | Path = None,
progress: bool = True,
):
datasets = (data_loader_base_path or basepath) / "datasets"
datasets = (data_loader_basepath or basepath) / "datasets"
compressed_datasets = basepath / "compressed-datasets"
datasets_error_bounds = basepath / "datasets-error-bounds"

Expand Down Expand Up @@ -187,6 +187,10 @@ def get_error_bounds(
parser.add_argument("--include-dataset", type=str, nargs="+", default=None)
parser.add_argument("--exclude-compressor", type=str, nargs="+", default=[])
parser.add_argument("--include-compressor", type=str, nargs="+", default=None)
parser.add_argument("--basepath", type=Path, default=Path())
parser.add_argument(
"--data-loader-basepath", type=Path, default=Path() / ".." / "data-loader"
)
args = parser.parse_args()

compress(
Expand All @@ -195,6 +199,6 @@ def get_error_bounds(
include_dataset=args.include_dataset,
exclude_compressor=args.exclude_compressor,
include_compressor=args.include_compressor,
data_loader_base_path=Path() / ".." / "data-loader",
data_loader_basepath=args.data_loader_basepath,
progress=True,
)
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
__all__ = ["collect_metrics"]
__all__ = ["compute_metrics"]

import argparse
import json
import re
from pathlib import Path
from typing import Optional
from typing import Iterable

import pandas as pd
import xarray as xr
Expand All @@ -25,30 +26,35 @@
}


def collect_metrics(
def compute_metrics(
basepath: Path = Path(),
data_loader_base_path: None | Path = None,
data_loader_basepath: None | Path = None,
exclude_dataset: Iterable[str] = tuple(),
include_dataset: None | Iterable[str] = None,
exclude_compressor: Iterable[str] = tuple(),
include_compressor: None | Iterable[str] = None,
):
datasets = (data_loader_base_path or basepath) / "datasets"
exclude_compressor = add_compressor_suffixes(exclude_compressor)
include_compressor = add_compressor_suffixes(include_compressor)

datasets = (data_loader_basepath or basepath) / "datasets"
compressed_datasets = basepath / "compressed-datasets"
error_bounds_dir = basepath / "datasets-error-bounds"
metrics_dir = basepath / "metrics"

all_results = []
for dataset in compressed_datasets.iterdir():
if dataset.name == ".gitignore":
if dataset.name == ".gitignore" or dataset.name in exclude_dataset:
continue
if include_dataset and dataset.name not in include_dataset:
continue

with (error_bounds_dir / dataset.name / "error_bounds.json").open() as f:
error_bound_list = json.load(f)

for error_bound in dataset.iterdir():
variable2error_bound = parse_error_bounds(error_bound.name)
error_bound_name = get_error_bound_name(
variable2error_bound, error_bound_list
)

for compressor in error_bound.iterdir():
if compressor.stem in exclude_compressor:
continue
if include_compressor and compressor.stem not in include_compressor:
continue
print(f"Evaluating {compressor.stem} on {dataset.name}...")

compressed_dataset = (
Expand All @@ -74,70 +80,21 @@ def collect_metrics(
)
compressor_metrics.mkdir(parents=True, exist_ok=True)

metrics = compute_metrics(compressor_metrics, ds, ds_new)
tests = compute_tests(
compressor_metrics, variable2error_bound, ds, ds_new
)
measurements = load_measurements(compressed_dataset, compressor)

df = merge_metrics(measurements, metrics, tests)
df["Dataset"] = dataset.name
df["Error Bound"] = error_bound.name
df["Error Bound Name"] = error_bound_name
all_results.append(df)
compute_compressor_metrics(compressor_metrics, ds, ds_new)
compute_tests(compressor_metrics, variable2error_bound, ds, ds_new)

all_results_df = pd.concat(all_results)
all_results_df.to_csv(metrics_dir / "all_results.csv", index=False)

def add_compressor_suffixes(compressors: None | Iterable[str]) -> list[str]:
if compressors is None:
return []

def get_error_bound_name(
variable2bound: dict[str, tuple[str, float]],
error_bound_list: list[dict[str, dict[str, Optional[float]]]],
bound_names: list[str] = ["low", "mid", "high"],
) -> str:
"""The function returns either "low", "mid", or "high" depending on which error bound
from the variable2bound dictionary matches the exact error bound in the error_bound_list.

error_bound_list contains one dictionary for each error bound (low, mid, high).
Each of these dictionaries contains the error bounds for
each variable. The variable names in the dictionaries should exactly match the variable names
in the variable2bound dictionary.

Parameters
----------
variable2bound : dict[str, tuple[str, float]]
A dictionary representing a single error bound, mapping variable names to
tuples of error type and error bound. The error type is either "abs_error"
or "rel_error", and the error bound is a float.
error_bound_list : list[dict[str, dict[str, Optional[float]]]]
A list of dictionaries, each representing an error bound (low, mid, high).
Each dictionary contains variable names as keys and a dictionary of error types
and bounds as values.
bound_names : list[str], optional
A list of names for the error bounds, by default ["low", "mid", "high"].
"""
extended_compressors = []
for compressor in compressors:
extended_compressors.append(compressor)
extended_compressors.append(compressor + "-conservative-rel")
extended_compressors.append(compressor + "-conservative-abs")

# Convert the variable2bound dictionary to match the format of error_bound_list.
new_bound_format = dict()
for k in variable2bound.keys():
new_bound_format[k] = {
"abs_error": (
variable2bound[k][1] if variable2bound[k][0] == "abs_error" else None
),
"rel_error": (
variable2bound[k][1] if variable2bound[k][0] == "rel_error" else None
),
}

# Return the name of the error bound that matches new_bound_format.
for bound_name, error_bound in zip(bound_names, error_bound_list):
if new_bound_format == error_bound:
return bound_name

raise ValueError(
f"Error bounds {new_bound_format} do not match any of the error bounds "
f"{error_bound_list}."
)
return extended_compressors


def parse_error_bounds(error_bound_str: str) -> dict[str, tuple[str, float]]:
Expand Down Expand Up @@ -187,7 +144,7 @@ def parse_error_bounds(error_bound_str: str) -> dict[str, tuple[str, float]]:
return result


def compute_metrics(
def compute_compressor_metrics(
compressor_metrics: Path, ds: xr.Dataset, ds_new: xr.Dataset
) -> pd.DataFrame:
metrics_path = compressor_metrics / "metrics.csv"
Expand All @@ -197,7 +154,15 @@ def compute_metrics(
metric_list = []
for name, metric in EVALUATION_METRICS.items():
for v in ds_new:
error = metric(ds[v], ds_new[v])
try:
error = metric(ds[v], ds_new[v])
except Exception as e:
print(
f"Error computing metric {name} for variable {v} on "
f"{compressor_metrics.parent.name}: {e}"
)
error = float("nan")

metric_list.append(
{
"Metric": name,
Expand All @@ -223,7 +188,16 @@ def compute_tests(
test_list = []
for name, test in PASSFAIL_TESTS.items():
for v in ds_new:
test_result, test_value = test(ds[v], ds_new[v])
try:
test_result, test_value = test(ds[v], ds_new[v])
except Exception as e:
print(
f"Error computing test {name} for variable {v} on "
f"{compressor_metrics.parent.name}: {e}"
)
test_result = False
test_value = float("nan")

test_list.append(
{
"Test": name,
Expand Down Expand Up @@ -327,7 +301,21 @@ def merge_metrics(


if __name__ == "__main__":
collect_metrics(
basepath=Path(),
data_loader_base_path=Path() / ".." / "data-loader",
parser = argparse.ArgumentParser()
parser.add_argument("--basepath", type=Path, default=Path())
parser.add_argument(
"--data-loader-basepath", type=Path, default=Path() / ".." / "data-loader"
)
parser.add_argument("--exclude-dataset", type=str, nargs="+", default=[])
parser.add_argument("--include-dataset", type=str, nargs="+", default=None)
parser.add_argument("--exclude-compressor", type=str, nargs="+", default=[])
parser.add_argument("--include-compressor", type=str, nargs="+", default=None)
args = parser.parse_args()
compute_metrics(
basepath=args.basepath,
data_loader_basepath=args.data_loader_basepath,
exclude_dataset=args.exclude_dataset,
include_dataset=args.include_dataset,
exclude_compressor=args.exclude_compressor,
include_compressor=args.include_compressor,
)
Loading