diff --git a/tests/util/test_fof_utils.py b/tests/util/test_fof_utils.py index bc57bec..d7345fe 100644 --- a/tests/util/test_fof_utils.py +++ b/tests/util/test_fof_utils.py @@ -2,6 +2,8 @@ This module contains unit tests for the `util/fof_utils.py` module. """ +import os + import numpy as np import pytest @@ -9,11 +11,11 @@ clean_value, compare_arrays, compare_var_and_attr_ds, - fill_nans_for_float32, get_observation_variables, get_report_variables, primary_check, print_entire_line, + replace_nan_with_sentinel, split_feedback_dataset, write_lines, ) @@ -196,7 +198,7 @@ def test_fill_nans_for_float32_nan(arr_nan): Test that if an array containign nan is given, these values are replaced by -9.99999e05. """ - array = fill_nans_for_float32(arr_nan) + array = replace_nan_with_sentinel(arr_nan) expected = np.array([1.0, -9.99999e05, 3.0, 4.0, -9.99999e05], dtype=np.float32) assert np.array_equal(array, expected) @@ -206,7 +208,7 @@ def test_fill_nans_for_float32(arr1): Test that if an array without nan is given, the output of the function is the same as the input. """ - array = fill_nans_for_float32(arr1) + array = replace_nan_with_sentinel(arr1) assert np.array_equal(array, arr1) @@ -322,11 +324,18 @@ def test_compare_var_and_attr_ds(ds1, ds2, tmp_path): total1, equal1 = compare_var_and_attr_ds( ds1, ds2, nl=0, output=True, location=file_path ) - total2, equal2 = compare_var_and_attr_ds(ds1, ds2, nl=4, output=True, location=None) + total2, equal2 = compare_var_and_attr_ds(ds1, ds2, nl=4, output=True, location=None) assert (total1, equal1) == (104, 103) assert (total2, equal2) == (104, 103) + script_dir = os.path.dirname(os.path.abspath(__file__)) + grandparent_dir = os.path.dirname(os.path.dirname(script_dir)) + + path_name = os.path.join(grandparent_dir, "differences.csv") + if os.path.exists(path_name): + os.remove(path_name) + @pytest.fixture(name="ds3") def fixture_sample_dataset_3(sample_dataset_fof): diff --git a/util/fof_utils.py b/util/fof_utils.py index 9400b00..9d35087 100644 --- a/util/fof_utils.py +++ b/util/fof_utils.py @@ -102,12 +102,18 @@ def compare_arrays(arr1, arr2, var_name): return total, equal, diff -def fill_nans_for_float32(arr): +def replace_nan_with_sentinel(arr): """ To make sure nan values are recognised. """ - if arr.dtype == np.float32 and np.isnan(arr).any(): + if not np.issubdtype(arr.dtype, np.floating): + return arr + + arr = arr.astype(np.float64, copy=False) + + if np.isnan(arr).any(): return np.where(np.isnan(arr), -999999, arr) + return arr @@ -217,9 +223,11 @@ def compare_var_and_attr_ds(ds1, ds2, nl, output, location): if output: if location: path_name = location + else: script_dir = os.path.dirname(os.path.abspath(__file__)) - path_name = os.path.join(script_dir, "differences.csv") + parent_dir = os.path.dirname(script_dir) + path_name = os.path.join(parent_dir, "differences.csv") with open(path_name, "w", encoding="utf-8") as f: f.write("Differences\n") @@ -227,8 +235,8 @@ def compare_var_and_attr_ds(ds1, ds2, nl, output, location): for var in set(ds1.data_vars).union(ds2.data_vars): if var in ds1.data_vars and var in ds2.data_vars and var not in list_to_skip: - arr1 = fill_nans_for_float32(ds1[var].values) - arr2 = fill_nans_for_float32(ds2[var].values) + arr1 = replace_nan_with_sentinel(ds1[var].values) + arr2 = replace_nan_with_sentinel(ds2[var].values) if arr1.size == arr2.size: t, e, diff = compare_arrays(arr1, arr2, var)