Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions tests/util/test_fof_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,20 @@
This module contains unit tests for the `util/fof_utils.py` module.
"""

import os

import numpy as np
import pytest

from util.fof_utils import (
clean_value,
compare_arrays,
compare_var_and_attr_ds,
fill_nans_for_float32,
get_observation_variables,
get_report_variables,
primary_check,
print_entire_line,
replace_nan_with_sentinel,
split_feedback_dataset,
write_lines,
)
Expand Down Expand Up @@ -196,7 +198,7 @@ def test_fill_nans_for_float32_nan(arr_nan):
Test that if an array containign nan is given, these values are replaced
by -9.99999e05.
"""
array = fill_nans_for_float32(arr_nan)
array = replace_nan_with_sentinel(arr_nan)
expected = np.array([1.0, -9.99999e05, 3.0, 4.0, -9.99999e05], dtype=np.float32)
assert np.array_equal(array, expected)

Expand All @@ -206,7 +208,7 @@ def test_fill_nans_for_float32(arr1):
Test that if an array without nan is given, the output of the function
is the same as the input.
"""
array = fill_nans_for_float32(arr1)
array = replace_nan_with_sentinel(arr1)
assert np.array_equal(array, arr1)


Expand Down Expand Up @@ -322,11 +324,18 @@ def test_compare_var_and_attr_ds(ds1, ds2, tmp_path):
total1, equal1 = compare_var_and_attr_ds(
ds1, ds2, nl=0, output=True, location=file_path
)
total2, equal2 = compare_var_and_attr_ds(ds1, ds2, nl=4, output=True, location=None)

total2, equal2 = compare_var_and_attr_ds(ds1, ds2, nl=4, output=True, location=None)
assert (total1, equal1) == (104, 103)
assert (total2, equal2) == (104, 103)

script_dir = os.path.dirname(os.path.abspath(__file__))
grandparent_dir = os.path.dirname(os.path.dirname(script_dir))

path_name = os.path.join(grandparent_dir, "differences.csv")
if os.path.exists(path_name):
os.remove(path_name)


@pytest.fixture(name="ds3")
def fixture_sample_dataset_3(sample_dataset_fof):
Expand Down
18 changes: 13 additions & 5 deletions util/fof_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,18 @@ def compare_arrays(arr1, arr2, var_name):
return total, equal, diff


def fill_nans_for_float32(arr):
def replace_nan_with_sentinel(arr):
"""
To make sure nan values are recognised.
"""
if arr.dtype == np.float32 and np.isnan(arr).any():
if not np.issubdtype(arr.dtype, np.floating):
return arr

arr = arr.astype(np.float64, copy=False)

if np.isnan(arr).any():
return np.where(np.isnan(arr), -999999, arr)

return arr


Expand Down Expand Up @@ -217,18 +223,20 @@ def compare_var_and_attr_ds(ds1, ds2, nl, output, location):
if output:
if location:
path_name = location

else:
script_dir = os.path.dirname(os.path.abspath(__file__))
path_name = os.path.join(script_dir, "differences.csv")
parent_dir = os.path.dirname(script_dir)
path_name = os.path.join(parent_dir, "differences.csv")

with open(path_name, "w", encoding="utf-8") as f:
f.write("Differences\n")

for var in set(ds1.data_vars).union(ds2.data_vars):
if var in ds1.data_vars and var in ds2.data_vars and var not in list_to_skip:

arr1 = fill_nans_for_float32(ds1[var].values)
arr2 = fill_nans_for_float32(ds2[var].values)
arr1 = replace_nan_with_sentinel(ds1[var].values)
arr2 = replace_nan_with_sentinel(ds2[var].values)

if arr1.size == arr2.size:
t, e, diff = compare_arrays(arr1, arr2, var)
Expand Down