diff --git a/src/climatebenchpress/compressor/scripts/compute_metrics.py b/src/climatebenchpress/compressor/scripts/compute_metrics.py index 208255c..771b158 100644 --- a/src/climatebenchpress/compressor/scripts/compute_metrics.py +++ b/src/climatebenchpress/compressor/scripts/compute_metrics.py @@ -33,6 +33,7 @@ def compute_metrics( include_dataset: None | Iterable[str] = None, exclude_compressor: Iterable[str] = tuple(), include_compressor: None | Iterable[str] = None, + overwrite: bool = False, ): """Compute evaluation metrics for compressors. @@ -54,6 +55,9 @@ def compute_metrics( include_compressor : None | Iterable[str] Compressors to include in evaluation. If `None`, all compressors are included. If specified, only compressors in `include_compressor` will be evaluated. + overwrite : bool + If `True`, overwrite existing `metrics.csv` and `tests.csv` files. Otherwise, + existing files are read and returned without recomputation. """ exclude_compressor = add_compressor_suffixes(exclude_compressor) include_compressor = add_compressor_suffixes(include_compressor) @@ -106,8 +110,16 @@ def compute_metrics( ) compressor_metrics.mkdir(parents=True, exist_ok=True) - compute_compressor_metrics(compressor_metrics, ds, ds_new) - compute_tests(compressor_metrics, variable2error_bound, ds, ds_new) + compute_compressor_metrics( + compressor_metrics, ds, ds_new, overwrite=overwrite + ) + compute_tests( + compressor_metrics, + variable2error_bound, + ds, + ds_new, + overwrite=overwrite, + ) def add_compressor_suffixes(compressors: None | Iterable[str]) -> list[str]: @@ -171,10 +183,13 @@ def parse_error_bounds(error_bound_str: str) -> dict[str, tuple[str, float]]: def compute_compressor_metrics( - compressor_metrics: Path, ds: xr.Dataset, ds_new: xr.Dataset + compressor_metrics: Path, + ds: xr.Dataset, + ds_new: xr.Dataset, + overwrite: bool = False, ) -> pd.DataFrame: metrics_path = compressor_metrics / "metrics.csv" - if metrics_path.exists(): + if metrics_path.exists() and not overwrite: return pd.read_csv(metrics_path) metric_list = [] @@ -206,9 +221,10 @@ def compute_tests( variable2bound: dict[str, tuple[str, float]], ds: xr.Dataset, ds_new: xr.Dataset, + overwrite: bool = False, ) -> pd.DataFrame: tests_path = compressor_metrics / "tests.csv" - if tests_path.exists(): + if tests_path.exists() and not overwrite: return pd.read_csv(tests_path) test_list = [] @@ -336,6 +352,7 @@ def merge_metrics( parser.add_argument("--include-dataset", type=str, nargs="+", default=None) parser.add_argument("--exclude-compressor", type=str, nargs="+", default=[]) parser.add_argument("--include-compressor", type=str, nargs="+", default=None) + parser.add_argument("--overwrite", action="store_true") args = parser.parse_args() compute_metrics( basepath=args.basepath, @@ -344,4 +361,5 @@ def merge_metrics( include_dataset=args.include_dataset, exclude_compressor=args.exclude_compressor, include_compressor=args.include_compressor, + overwrite=args.overwrite, )