Skip to content

Commit 60964ac

Browse files
committed
Updated test data functions
1 parent 9bda9da commit 60964ac

File tree

3 files changed

+87
-81
lines changed

3 files changed

+87
-81
lines changed

hdp/tests/test_utils.py

Lines changed: 51 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,57 @@
1-
import hdp.utils
2-
import xarray
1+
from hdp.utils import *
32
import numpy as np
4-
3+
from math import isclose
4+
from xarray import DataArray
55

66
def test_get_time_stamp():
7-
assert type(hdp.utils.get_time_stamp()) is str
7+
assert type(get_time_stamp()) is str
88

99

1010
def test_get_version():
11-
assert type(hdp.utils.get_version()) is str
12-
13-
14-
def test_test_data_functions():
15-
var = "test"
16-
center_val = 10
17-
amplitude_val = 1
18-
units = "test_units"
19-
20-
ds = hdp.utils.generate_test_dataset(name=var, units=units, center=center_val, amplitude=amplitude_val)
21-
assert type(ds) is xarray.Dataset
22-
assert len(ds.data_vars) == 1
23-
assert var in ds
24-
25-
assert "units" in ds[var].attrs
26-
assert ds[var].attrs["units"] == units
27-
assert ds[var].dims == ("lat", "lon", "time")
28-
assert ds[var].dtype == float
29-
assert ds[var].time.values[0].calendar == "noleap"
30-
assert ds[var].lat.size > 1
31-
assert ds[var].lon.size > 1
32-
assert ds[var].time.size >= 2*365
33-
34-
data = ds[var].compute()
35-
36-
assert np.isclose(data.mean(), center_val)
37-
assert np.isclose(data.max(), center_val + amplitude_val)
38-
assert np.isclose(data.values, data.values[0]).all()
39-
40-
exceed_data = hdp.utils.generate_exceedance_dataarray(ds[var], exceedance_pattern=[1, 0, 1, 0]).compute()
41-
42-
assert np.isclose(exceed_data.mean(), center_val + 0.5)
43-
assert exceed_data.attrs == ds[var].attrs
44-
assert exceed_data.dims == ds[var].dims
45-
assert exceed_data.shape == ds[var].shape
46-
assert exceed_data.dtype == ds[var].dtype
11+
assert type(get_version()) is str
12+
13+
14+
def test_control_dataarray():
15+
control_da = generate_test_control_dataarray(grid_shape=(1,1), start_date="2000", end_date="2100")
16+
assert type(control_da) is DataArray
17+
var = control_da.name
18+
assert control_da.attrs["units"] == "degC"
19+
assert control_da.dims == ("lon", "lat", "time")
20+
assert control_da.dtype == float
21+
assert control_da.time.values[0].calendar == "noleap"
22+
assert control_da.time.values.size >= 365
23+
assert np.sum(np.isnan(control_da)) == 0
24+
25+
control_slope = np.polyfit(np.arange(control_da["time"].size), control_da.mean(dim=["lat", "lon"]).values, 1)[0]
26+
assert np.abs(control_slope) < 0.01
27+
28+
29+
def test_warming_dataarray():
30+
warming_da = generate_test_warming_dataarray(grid_shape=(1,1), start_date="2000", end_date="2100")
31+
avg_da = warming_da.mean(dim=["lat", "lon"]).values
32+
warm_slope = np.polyfit(np.arange(avg_da.size), avg_da, 1)[0]
33+
34+
assert warm_slope > 0
35+
assert not isclose(warm_slope, 0)
36+
37+
38+
def test_rh_dataarray():
39+
rh_da = generate_test_rh_dataarray()
40+
41+
assert rh_da.max() <= 1
42+
assert rh_da.min() >= 0
43+
44+
45+
def test_defaults_compatibility():
46+
control_da = generate_test_control_dataarray()
47+
warming_da = generate_test_warming_dataarray()
48+
rh_da = generate_test_rh_dataarray()
49+
50+
assert control_da.shape == warming_da.shape
51+
assert warming_da.shape == rh_da.shape
52+
53+
assert control_da.dims == warming_da.dims
54+
assert warming_da.dims == rh_da.dims
55+
56+
assert control_da.name == warming_da.name
57+
assert control_da.units == warming_da.units

hdp/tests/test_workflow.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55

66
def test_full_data_workflow():
7-
baseline_temp = hdp.utils.generate_test_dataset(name="temp")["temp"]
8-
baseline_rh = hdp.utils.generate_test_dataset(name="rh", units="%", center=90, amplitude=15)["rh"]
7+
baseline_temp = hdp.utils.generate_test_control_dataarray().rename("temp")
8+
baseline_rh = hdp.utils.generate_test_rh_dataarray().rename("rh")
99
baseline_measures = hdp.measure.format_standard_measures([baseline_temp], rh=baseline_rh)
1010

1111
percentiles = np.arange(0.9, 1, 0.01)
@@ -14,9 +14,8 @@ def test_full_data_workflow():
1414

1515
exceedance_pattern = [1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1]
1616

17-
test_temp = hdp.utils.generate_test_dataset(name="temp")["temp"]
18-
test_temp = hdp.utils.generate_exceedance_dataarray(test_temp, exceedance_pattern)
19-
test_rh = hdp.utils.generate_test_dataset(name="rh", units="%", center=90, amplitude=15)["rh"]
17+
test_temp = hdp.utils.generate_test_warming_dataarray().rename("temp")
18+
test_rh = baseline_rh
2019

2120
hw_definitions = [[3,0,0], [3,1,1], [4,2,0], [4,1,3], [5,0,1], [5,1,4]]
2221

@@ -37,11 +36,6 @@ def test_full_data_workflow():
3736
assert metrics.definition.values[5] == "5-1-4"
3837
assert (metrics.percentile.values == percentiles).all()
3938

40-
assert (metrics["temp.temp_threshold.HWF"] == metrics["temp_hi.temp_hi_threshold.HWF"]).all()
41-
assert (metrics["temp.temp_threshold.HWD"] == metrics["temp_hi.temp_hi_threshold.HWD"]).all()
42-
assert (metrics["temp.temp_threshold.HWA"] == metrics["temp_hi.temp_hi_threshold.HWA"]).all()
43-
assert (metrics["temp.temp_threshold.HWN"] == metrics["temp_hi.temp_hi_threshold.HWN"]).all()
44-
4539
metric_means = metrics.mean()
4640

4741
assert metric_means["temp.temp_threshold.HWF"] >= metric_means["temp.temp_threshold.HWD"]

hdp/utils.py

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -36,45 +36,46 @@ def get_func_description(func):
3636
return desc
3737

3838

39-
def generate_test_dataset(center=25, amplitude=10, name="temperature", units="degC"):
39+
def generate_test_warming_dataarray(start_date="2000-01-01", end_date="2009-12-31", grid_shape=(10, 10), warming_period=10):
40+
base_data = generate_test_control_dataarray(start_date=start_date, end_date=end_date, grid_shape=grid_shape)
41+
base_data += xarray.DataArray(np.arange(base_data["time"].size) / (365*warming_period), dims=["time"], coords={"time": base_data["time"]})
42+
return base_data
43+
44+
45+
def generate_test_rh_dataarray(start_date="2000-01-01", end_date="2009-12-31", grid_shape=(10, 10)):
46+
base_data = generate_test_control_dataarray(start_date=start_date, end_date=end_date, grid_shape=grid_shape)
47+
base_data = abs(base_data / base_data.max() - 0.3)
48+
base_data = base_data.rename("test_rh_data")
49+
base_data.attrs["units"] = 'g/g'
50+
return base_data
51+
52+
53+
def generate_test_control_dataarray(start_date="2000-01-01", end_date="2009-12-31", grid_shape=(10, 10)):
4054
time_values = xarray.date_range(
41-
start=cftime.DatetimeNoLeap(2000, 1, 1),
42-
end=cftime.DatetimeNoLeap(2009, 12, 31),
55+
start=start_date,
56+
end=end_date,
4357
freq="D",
4458
calendar="noleap",
4559
use_cftime=True
4660
)
61+
temperature_seasonal_ts = 20 + 15*np.sin(np.pi*np.arange(time_values.size, dtype=float) / 365)
62+
temperature_seasonal_vals = np.broadcast_to(temperature_seasonal_ts, (grid_shape[0], grid_shape[1], temperature_seasonal_ts.size))
4763

48-
temp_timeseries = center + amplitude*np.sin(np.pi*np.arange(time_values.size, dtype=float) / 365)
49-
temp_values = np.broadcast_to(temp_timeseries, (3, 3, temp_timeseries.size))
50-
51-
temp_da = xarray.DataArray(
52-
data=da.from_array(temp_values),
53-
dims=["lat", "lon", "time"],
64+
lat_vals = np.linspace(-90, 90, grid_shape[1], dtype=float)
65+
66+
lat_grad = np.broadcast_to(np.abs(lat_vals) / 90 * 15, grid_shape).T
67+
temperature_seasonal_vals = temperature_seasonal_vals - lat_grad[:, :, None]
68+
69+
return xarray.DataArray(
70+
data=temperature_seasonal_vals,
71+
dims=["lon", "lat", "time"],
5472
coords={
55-
"lat": np.array([-90, 0, 90], dtype=float),
56-
"lon": np.array([-180, 0, 180], dtype=float),
73+
"lon": np.linspace(-180, 180, grid_shape[0], dtype=float),
74+
"lat": lat_vals,
5775
"time": time_values
5876
},
59-
name=name,
77+
name="test_temperature_data",
6078
attrs={
61-
"units": units
79+
"units": "degC"
6280
}
63-
).chunk(dict(lat=1, lon=1))
64-
return xarray.Dataset({name: temp_da})
65-
66-
67-
def generate_exceedance_dataarray(measure, exceedance_pattern, multiplier=1.0):
68-
tiles = np.ceil(measure.time.size / len(exceedance_pattern)).astype(int)
69-
pattern_data = np.broadcast_to(np.tile(exceedance_pattern, tiles)[:measure.time.size], measure.shape)*multiplier
70-
71-
ret_da = None
72-
with xarray.set_options(keep_attrs=True):
73-
ret_da = measure + xarray.DataArray(
74-
data=da.from_array(pattern_data),
75-
dims=measure.dims,
76-
coords=measure.coords,
77-
name=measure.name,
78-
attrs=measure.attrs
79-
).chunk(measure.chunksizes)
80-
return ret_da
81+
).chunk('auto')

0 commit comments

Comments
 (0)