Skip to content
22 changes: 14 additions & 8 deletions earthkit/data/sources/cds.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,26 +137,32 @@ def retrieve(target, args):
@normalize("date", "date-list(%Y-%m-%d)")
@normalize("area", "bounding-box(list)")
def _normalize_request(**kwargs):
    """Return a canonical copy of the request keyword arguments.

    Keys are inserted in sorted order. Every value is passed through
    ``ensure_iterable``; the result is then sorted, except for ``area`` and
    ``grid`` whose element order is kept as given. A value that ends up with
    exactly one element is stored as that bare element, so ``"2t"`` and
    ``["2t"]`` yield the same entry, as do ``["2t", "msl"]`` and
    ``["msl", "2t"]``.

    Returns a new ``dict``; ``kwargs`` itself is not modified here.
    """
    request = {}
    for key in sorted(kwargs):
        # NOTE(review): indexing below assumes ensure_iterable returns a
        # sequence (list/tuple) rather than a lazy iterator -- confirm.
        values = ensure_iterable(kwargs[key])
        # "area" and "grid" are positional (e.g. a bounding box), so sorting
        # them would change their meaning.
        if key not in ("area", "grid"):
            values = sorted(values)
        request[key] = values[0] if len(values) == 1 else values
    return request

@cached_property
def requests(self):
    """Return the list of normalized requests derived from ``self._args``.

    Each entry of ``self._args`` may carry an optional ``split_on`` item:

    * a ``dict`` mapping a request key to a batch size, or
    * a key or an iterable of keys, each treated as batch size 1
      (``None`` entries are ignored).

    Without a (non-empty) ``split_on`` the request is kept as a single
    request. Otherwise the request is normalized, the values of every
    ``split_on`` key are cut into batches of the given size, and one
    sub-request is produced per combination of batches. Every resulting
    request is passed through ``_normalize_request`` before being returned.
    """
    requests = []
    for request in self._args:
        # Work on a shallow copy so that removing "split_on" does not
        # modify the caller-supplied arguments stored in self._args.
        request = dict(request)
        split_on = request.pop("split_on", {})
        if not isinstance(split_on, dict):
            split_on = {k: 1 for k in ensure_iterable(split_on) if k is not None}
        if not split_on:
            requests.append(request)
            continue

        # Normalize before splitting so the batches are cut from the
        # canonical (sorted) value lists.
        request = self._normalize_request(**request)
        for values in itertools.product(
            *[batched(ensure_iterable(request[k]), v) for k, v in split_on.items()]
        ):
            # zip pairs each split key with the batch chosen for it.
            subrequest = dict(zip(split_on, values))
            requests.append(request | subrequest)
    # Final pass: also normalizes the unsplit requests and the batch values
    # substituted into the sub-requests.
    return [self._normalize_request(**request) for request in requests]


# Expose CdsRetriever under the module-level name ``source``.
# NOTE(review): presumably this is the name the source loader looks up -- confirm.
source = CdsRetriever
30 changes: 30 additions & 0 deletions tests/sources/test_cds.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,36 @@ def test_cds_grib_multi_var_date(date, expected_date):
assert s.metadata("date") == expected_date


@pytest.mark.long_test
@pytest.mark.download
@pytest.mark.skipif(NO_CDS, reason="No access to CDS")
@pytest.mark.parametrize(
    "variable1,variable2,expected_vars",
    (
        ("2t", ["2t"], {"t2m"}),
        (["2t", "msl"], ["msl", "2t"], {"t2m", "msl"}),
    ),
)
def test_cds_normalized_request(variable1, variable2, expected_vars):
    """Check that equivalent CDS requests resolve to the same path.

    ``variable1`` and ``variable2`` describe the same variables written
    differently (scalar vs. one-element list, or a different order). Both
    sources must report the same ``path``, and the retrieved data must
    contain ``expected_vars`` on the requested area and grid.
    """
    base_request = dict(
        product_type="reanalysis",
        area=[50, -50, 20, 50],
        grid=[2, 1],
        date="2012-12-12",
        time="12:00",
    )
    s1 = from_source(
        "cds", "reanalysis-era5-single-levels", variable=variable1, **base_request
    )
    s2 = from_source(
        "cds", "reanalysis-era5-single-levels", variable=variable2, **base_request
    )
    assert s1.path == s2.path

    # Convert once and reuse the dataset for all remaining checks.
    ds = s1.to_xarray()
    assert set(ds.data_vars) == expected_vars
    # area/grid must not be reordered by normalization: 2-degree steps in
    # longitude, 1-degree steps in latitude.
    assert ds["longitude"].values.tolist() == list(range(-50, 51, 2))
    assert ds["latitude"].values.tolist() == list(range(50, 19, -1))


@pytest.mark.long_test
@pytest.mark.download
@pytest.mark.skipif(NO_CDS, reason="No access to CDS")
Expand Down