Skip to content

Commit c10a5db

Browse files
Merge pull request #304 from NCAS-CMS/fix_mask_equal_2
Fix masking for fill_value and missing_value for numpy 2.4.0
2 parents 166d0b4 + 22443ce commit c10a5db

File tree

3 files changed

+74
-3
lines changed

3 files changed

+74
-3
lines changed

activestorage/active.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,8 @@ def hfix(x):
143143
missing_value = ds.attrs.get('missing_value')
144144
# see https://github.com/NCAS-CMS/PyActiveStorage/pull/303
145145
if isinstance(missing_value, np.ndarray):
146-
missing_value = missing_value[0]
146+
if missing_value.size == 1:
147+
missing_value = missing_value[0]
147148
valid_min = hfix(ds.attrs.get('valid_min'))
148149
valid_max = hfix(ds.attrs.get('valid_max'))
149150
valid_range = hfix(ds.attrs.get('valid_range'))

activestorage/storage.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,19 @@ def mask_missing(data, missing):
130130
fill_value, missing_value, valid_min, valid_max = missing
131131

132132
if fill_value is not None:
133-
data = np.ma.masked_equal(data, fill_value)
133+
if isinstance(fill_value, np.ndarray) or isinstance(fill_value, list):
134+
data = np.ma.masked_where(data == fill_value, data)
135+
else:
136+
data = np.ma.masked_equal(data, fill_value)
134137

135138
if missing_value is not None:
136-
data = np.ma.masked_equal(data, missing_value)
139+
if isinstance(missing_value, np.ndarray) or isinstance(missing_value, list):
140+
try:
141+
data = np.ma.masked_where(data == missing_value, data)
142+
except ValueError: # not broadcastable
143+
raise ValueError("Data and missing_value arrays are not brodcastable!")
144+
else:
145+
data = np.ma.masked_equal(data, missing_value)
137146

138147
if valid_max is not None:
139148
data = np.ma.masked_greater(data, valid_max)

tests/unit/test_storage.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,67 @@
66
import activestorage.storage as st
77

88

9+
def test_mask_missing():
10+
"""Test mask missing."""
11+
missing_1 = ([-900.], np.array([-900.]), None, None)
12+
missing_2 = ([-900., 33.], np.array([-900., 33.]), None, None)
13+
data_1 = np.ma.array(
14+
[[[-900., 33.], [33., -900], [33., 44.]]],
15+
mask=False,
16+
fill_value=-900.0,
17+
dtype=float
18+
)
19+
data_2 = np.ma.array(
20+
[[[-900., 33.], [33., -900], [33., 44.]]],
21+
mask=False,
22+
fill_value=[-900.0, 33.],
23+
dtype=float
24+
)
25+
res_1 = st.mask_missing(data_1, missing_1)
26+
expected_1 = np.ma.array(
27+
data_1,
28+
mask=[[[True, False], [False, True], [False, False]]]
29+
)
30+
np.testing.assert_array_equal(res_1, expected_1)
31+
res_2 = st.mask_missing(data_2, missing_2)
32+
expected_2 = np.ma.array(
33+
data_2,
34+
mask=[[[True, True], [False, False], [False, False]]]
35+
)
36+
np.testing.assert_array_equal(res_2, expected_2)
37+
38+
39+
def test_mask_missing_missing_broadcastable():
40+
"""Test mask missing when fill_value cant be broadcast to data."""
41+
data = np.ma.array(
42+
[[[-900., 33.], [33., -900], [33., 44.]]],
43+
mask=False,
44+
fill_value=np.array([-900.0]),
45+
dtype=float
46+
)
47+
missing = (-900, np.array([-900., 33.]), None, None)
48+
res = st.mask_missing(data, missing)
49+
expected = np.ma.array(
50+
data,
51+
mask=[[[True, True], [False, False], [False, False]]]
52+
)
53+
np.testing.assert_array_equal(res, expected)
54+
55+
56+
def test_mask_missing_missing_not_broadcastable():
57+
"""Test mask missing when fill_value cant be broadcast to data."""
58+
data = np.ma.array(
59+
[[[-900., 33.], [33., -900], [33., 44.]]],
60+
mask=False,
61+
fill_value=np.array([-900.0]),
62+
dtype=float
63+
)
64+
missing = (-900, np.array([-900., -900., 33.]), None, None)
65+
msg = "Data and missing_value arrays are not brodcastable!"
66+
with pytest.raises(ValueError, match=msg):
67+
st.mask_missing(data, missing)
68+
69+
970
def test_reduce_chunk():
1071
"""Test reduce chunk entirely."""
1172
rfile = "tests/test_data/cesm2_native.nc"

0 commit comments

Comments
 (0)