Skip to content

Commit 79a545b

Browse files
aangelo9 and gpsaggese authored
TutorTask531_Create_data_downloader_for_EIA (#542)
Co-authored-by: GP Saggese <saggese@gmail.com>
1 parent f4cae24 commit 79a545b

File tree

2 files changed

+338
-20
lines changed

2 files changed

+338
-20
lines changed

causal_automl/TutorTask401_EIA_metadata_downloader_pipeline/eia_utils.py

Lines changed: 89 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
"""
66

77
import logging
8-
from typing import Any, Dict, List, Tuple
8+
import re
9+
from typing import Any, Dict, List, Optional, Tuple, cast
910

11+
import helpers.hdbg as hdbg
1012
import matplotlib.pyplot as plt
1113
import pandas as pd
1214
import requests
@@ -118,12 +120,18 @@ def _get_api_request(self, route: str) -> Dict[str, Any]:
118120
# Build the full API request URL.
119121
url = f"{self._base_url}/{route}?api_key={self._api_key}"
120122
# Send HTTP GET request to the EIA API.
123+
# TODO(alvino): Add error handling for the HTTP request to handle
124+
# potential exceptions such as connection errors or timeouts.
121125
response = requests.get(url, timeout=20)
122126
# Parse JSON content.
127+
# TODO(alvino): Check if the response is successful (e.g.,
128+
# `response.status_code == 200`) before attempting to parse the JSON
129+
# content.
123130
json_data = response.json()
124131
# Get response from parsed payload.
125132
data: Dict[str, Any] = {}
126-
data = json_data.get("response", {})
133+
# TODO(alvino): Add error handling for JSON parsing to manage potential parsing errors.
134+
data = json_data["response"]
127135
return data
128136

129137
def _get_leaf_route_data(self) -> Dict[str, Dict[str, Any]]:
@@ -242,19 +250,19 @@ def _extract_metadata(
242250
"url": url,
243251
"id": f"{route_clean}.{frequency_id}.{metric_id_clean}",
244252
"dataset_id": dataset_id_clean,
245-
"name": data.get("name"),
246-
"description": data.get("description"),
247-
"frequency_id": frequency.get("id"),
253+
"name": data["name"],
254+
"description": data["description"],
255+
"frequency_id": frequency["id"],
248256
"frequency_alias": frequency.get("alias"),
249-
"frequency_description": frequency.get("description"),
250-
"frequency_query": frequency.get("query"),
251-
"frequency_format": frequency.get("format"),
252-
"facets": data.get("facets"),
257+
"frequency_description": frequency["description"],
258+
"frequency_query": frequency["query"],
259+
"frequency_format": frequency["format"],
260+
"facets": data["facets"],
253261
"data": metric_id,
254262
"data_alias": metric_info.get("alias"),
255263
"data_units": metric_info.get("units"),
256-
"start_period": data.get("startPeriod"),
257-
"end_period": data.get("endPeriod"),
264+
"start_period": data["startPeriod"],
265+
"end_period": data["endPeriod"],
258266
"parameter_values_file": param_file_path,
259267
}
260268
flattened_metadata.append(metadata)
@@ -270,6 +278,11 @@ def _get_facet_values(
270278
:param route: dataset route under the EIA v2 API
271279
:return: data containing all facet values
272280
"""
281+
hdbg.dassert_in(
282+
"facets",
283+
metadata,
284+
msg="Column 'facets' not found in metadata index."
285+
)
273286
facets = metadata["facets"]
274287
rows = []
275288
for facet in facets:
@@ -295,31 +308,84 @@ def _get_facet_values(
295308
def build_full_url(
    base_url: str,
    api_key: str,
    *,
    facet_input: Optional[Dict[str, str]] = None,
    start_timestamp: Optional[pd.Timestamp] = None,
    end_timestamp: Optional[pd.Timestamp] = None,
) -> str:
    """
    Build an EIA v2 API URL to data endpoint.

    This function modifies the base metadata URL by:
    - Replacing the metadata endpoint with the actual data endpoint
    - Injecting the provided API key
    - Appending optional facet filters
    - Appending start and end timestamps formatted to match the series frequency

    :param base_url: base API URL with frequency and metric, excluding
        facet values,
        e.g., "https://api.eia.gov/v2/electricity/retail-sales?api_key={API_KEY}&frequency=monthly&data[0]=revenue"
    :param api_key: EIA API key, e.g., "abcd1234xyz"
    :param facet_input: specified facet values, e.g., {"stateid": "KS", "sectorid": "COM"}
    :param start_timestamp: first observation date
    :param end_timestamp: last observation date
    :return: full EIA API URL to data endpoint,
        e.g, "https://api.eia.gov/v2/electricity/retail-sales/data?api_key=abcd1234xyz&frequency=monthly&data[0]=price&facets[stateid][]=KS&facets[sectorid][]=OTH"
    :raises ValueError: if `base_url` contains no `frequency=` parameter
    """
    # Extract the series frequency from the URL; it drives how timestamps
    # are formatted below. Fail loudly instead of casting away a possible
    # `None` and crashing later with an opaque `AttributeError`.
    match = re.search(r"frequency=([a-zA-Z\-]+)", base_url)
    if match is None:
        raise ValueError(
            f"No `frequency=` parameter found in base_url: {base_url}"
        )
    frequency = match.group(1)
    # Point the URL at the data endpoint instead of the metadata endpoint.
    base_url = base_url.replace("?", "/data?")
    url = base_url.replace("{API_KEY}", api_key)
    query_parts = []
    if start_timestamp:
        formatted_start = _format_timestamp(start_timestamp, frequency)
        query_parts.append(f"&start={formatted_start}")
    if end_timestamp:
        formatted_end = _format_timestamp(end_timestamp, frequency)
        query_parts.append(f"&end={formatted_end}")
    if facet_input:
        # Add facet values when specified.
        for facet_id, value in facet_input.items():
            query_parts.append(f"&facets[{facet_id}][]={value}")
    full_url = url + "".join(query_parts)
    return full_url
321352

322353

354+
def _format_timestamp(timestamp: pd.Timestamp, frequency: str) -> pd.Timestamp:
355+
"""
356+
Format a timestamp based on the EIA time series frequency.
357+
358+
Supported formats:
359+
- "annual": "YYYY"
360+
- "quarterly": "YYYY-QN"
361+
- "monthly": "YYYY-MM"
362+
- "daily": "YYYY-MM-DD"
363+
- "hourly": "YYYY-MM-DDTHH"
364+
- "local-hourly": "YYYY-MM-DDTHH-ZZ" (fixed timezone offset, e.g., "-00")
365+
366+
:param timestamp: the timestamp to format
367+
:param frequency: the frequency type (e.g., "monthly", "local-hourly")
368+
:return: formatted timestamp
369+
"""
370+
result = ""
371+
if frequency == "annual":
372+
result = timestamp.strftime("%Y")
373+
elif frequency == "monthly":
374+
result = timestamp.strftime("%Y-%m")
375+
elif frequency == "quarterly":
376+
q = (timestamp.month - 1) // 3 + 1
377+
result = f"{timestamp.year}-Q{q}"
378+
elif frequency == "daily":
379+
result = timestamp.strftime("%Y-%m-%d")
380+
elif frequency == "hourly":
381+
result = timestamp.strftime("%Y-%m-%dT%H")
382+
elif frequency == "local-hourly":
383+
result = timestamp.strftime("%Y-%m-%dT%H") + "-00"
384+
else:
385+
raise ValueError(f"Unsupported frequency: {frequency}")
386+
return result
387+
388+
323389
def plot_distribution(df_metadata: pd.DataFrame, column: str, title: str) -> None:
324390
"""
325391
Plot a distribution count for a specified metadata column.
@@ -329,8 +395,11 @@ def plot_distribution(df_metadata: pd.DataFrame, column: str, title: str) -> Non
329395
'frequency_id', 'data_units')
330396
:param title: title for the plot
331397
"""
332-
if column not in df_metadata.columns:
333-
raise ValueError(f"Column '{column}' not found in metadata index.")
398+
hdbg.dassert_in(
399+
column,
400+
df_metadata.columns,
401+
msg=f"Column '{column}' not found in metadata index."
402+
)
334403
counts = df_metadata[column].value_counts()
335404
ax = counts.plot(kind="bar", figsize=(8, 4), title=title)
336405
ax.set_xlabel(column.replace("_", " ").title())

0 commit comments

Comments
 (0)