From d2cd67e2cb2bb30488a424b44b3b5613af5f4ff6 Mon Sep 17 00:00:00 2001 From: mahendra-918 Date: Tue, 17 Feb 2026 20:51:01 +0530 Subject: [PATCH] feat: add EIA data preprocessing pipeline for US solar generation --- docs/training_model_new_country.md | 44 +++ .../scripts/preprocess_eia_data.py | 267 +++++++++++++++++ tests/test_eia_preprocessing.py | 268 ++++++++++++++++++ 3 files changed, 579 insertions(+) create mode 100644 src/open_data_pvnet/scripts/preprocess_eia_data.py create mode 100644 tests/test_eia_preprocessing.py diff --git a/docs/training_model_new_country.md b/docs/training_model_new_country.md index 453b6d1..f6c10fd 100644 --- a/docs/training_model_new_country.md +++ b/docs/training_model_new_country.md @@ -77,9 +77,53 @@ Generation data represents the actual solar PV power output for your target coun - **United States**: - EIA (Energy Information Administration) - National level data - Regional ISOs: CAISO (California), ERCOT (Texas), PJM (Eastern US) + - **Use the built-in EIA preprocessing script** (see below) - **United Kingdom**: PVlive API (already implemented in this project) - Check national grid operators and government energy data portals +#### United States - EIA Data Preprocessing + +For the United States, use the built-in EIA preprocessing script to fetch and transform solar generation data: + +```bash +# Set your EIA API key (get one free at https://www.eia.gov/opendata/) +export EIA_API_KEY="your_api_key_here" + +# Fetch and preprocess EIA solar data for specific regions +python -m open_data_pvnet.scripts.preprocess_eia_data \ + --start-date 2023-01-01 \ + --end-date 2023-12-31 \ + --regions CAISO ERCOT PJM \ + --output ./data/us/generation/2023.zarr \ + --frequency hourly + +# Or fetch US48 aggregate data (default) +python -m open_data_pvnet.scripts.preprocess_eia_data \ + --start-date 2023-01-01 \ + --end-date 2023-12-31 \ + --output ./data/us/generation/2023.zarr +``` + +**Available Regions:** +- `US48` - Continental United States 
import os
import logging
import argparse
import pandas as pd
import xarray as xr
import numpy as np
from typing import Optional, List, Dict, Any
from pathlib import Path

from open_data_pvnet.scripts.fetch_eia_data import EIAData

logger = logging.getLogger(__name__)

# US RTO/Region location metadata.  Each region gets a single representative
# (latitude, longitude) point and a stable integer location_id used as the
# dataset coordinate.
US_RTO_LOCATIONS = {
    "CAISO": {"location_id": 1, "latitude": 36.7783, "longitude": -119.4179, "name": "California ISO"},
    "ERCOT": {"location_id": 2, "latitude": 31.9686, "longitude": -99.9018, "name": "Texas"},
    "PJM": {"location_id": 3, "latitude": 40.0583, "longitude": -76.3055, "name": "Mid-Atlantic"},
    "MISO": {"location_id": 4, "latitude": 41.8781, "longitude": -87.6298, "name": "Midwest ISO"},
    "NYISO": {"location_id": 5, "latitude": 42.6526, "longitude": -73.7562, "name": "New York ISO"},
    "ISONE": {"location_id": 6, "latitude": 42.3601, "longitude": -71.0589, "name": "New England ISO"},
    "SPP": {"location_id": 7, "latitude": 38.5767, "longitude": -92.1735, "name": "Southwest Power Pool"},
    "US48": {"location_id": 0, "latitude": 39.8283, "longitude": -98.5795, "name": "Continental US"},
}


class EIAPreprocessor:
    """Preprocessor to convert EIA solar data to ocf-data-sampler format.

    Pipeline: fetch raw generation from the EIA API, reshape to the
    (time_utc, location_id) schema, estimate per-location capacity from a
    percentile of historical generation, validate the schema, and optionally
    persist the result to Zarr.
    """

    def __init__(self, api_key: Optional[str] = None):
        """Initialize preprocessor with optional API key.

        Args:
            api_key: EIA API key. If None, key resolution is delegated to
                EIAData (the CLI documents an EIA_API_KEY env-var fallback).
        """
        self.eia_data = EIAData(api_key=api_key)
        self.location_metadata = US_RTO_LOCATIONS

    def fetch_and_preprocess(
        self,
        start_date: str,
        end_date: str,
        regions: Optional[List[str]] = None,
        output_path: Optional[str] = None,
        route: str = "electricity/rto/daily-fuel-type-data",
        frequency: str = "hourly",
    ) -> xr.Dataset:
        """Fetch EIA data and preprocess it. Defaults to US48 if no regions specified.

        Args:
            start_date: Inclusive start date, "YYYY-MM-DD".
            end_date: Inclusive end date, "YYYY-MM-DD".
            regions: Region codes from US_RTO_LOCATIONS; None means ["US48"].
            output_path: If given, the combined dataset is written to this Zarr path.
            route: EIA API route to query.
            frequency: "hourly" or "daily".

        Returns:
            Combined dataset with generation_mw and capacity_mwp over
            (time_utc, location_id).

        Raises:
            ValueError: if no region returns any data, or validation fails.
        """
        logger.info(f"Starting preprocessing for {start_date} to {end_date}")

        if regions is None:
            regions = ["US48"]

        all_datasets = []

        for region in regions:
            logger.info(f"Processing region: {region}")

            facets = {"fueltype": "SUN"}
            # US48 is addressed via the `region` argument rather than a
            # respondent facet; individual RTOs go through the facet filter.
            if region != "US48":
                facets["respondent"] = [region]

            df = self.eia_data.get_data(
                route=route,
                start_date=start_date,
                end_date=end_date,
                frequency=frequency,
                data_cols=["value"],
                facets=facets,
                region=region if region == "US48" else None,
            )

            if df is None or df.empty:
                # Best-effort per region: skip and keep going; only fail if
                # *every* region came back empty (checked below).
                logger.warning(f"No data retrieved for region {region}")
                continue

            ds = self.transform_to_schema(df, region)
            all_datasets.append(ds)

        if not all_datasets:
            raise ValueError("No data retrieved for any region")

        combined_ds = xr.concat(all_datasets, dim="location_id")
        combined_ds = self.estimate_capacity(combined_ds)

        if not self.validate_data(combined_ds):
            raise ValueError("Data validation failed")

        logger.info("Preprocessing completed successfully")

        if output_path:
            self.save_to_zarr(combined_ds, output_path)

        return combined_ds

    def transform_to_schema(self, df: pd.DataFrame, region: str) -> xr.Dataset:
        """Transform raw EIA DataFrame to the required schema.

        Args:
            df: Raw API response with a "period" or "datetime_gmt" time column
                and a "value" generation column. The input is not modified.
            region: Region code; must be a key of US_RTO_LOCATIONS.

        Returns:
            Dataset with generation_mw over (location_id, time_utc) plus
            latitude/longitude coordinates for the region.

        Raises:
            ValueError: for an unknown region or a missing time/value column.
        """
        if region not in self.location_metadata:
            raise ValueError(f"Unknown region: {region}. Available: {list(self.location_metadata.keys())}")

        location_info = self.location_metadata[region]
        location_id = location_info["location_id"]

        # Fix: work on a copy so the column assignments below do not mutate
        # the caller's DataFrame in place.
        df = df.copy()

        if "period" in df.columns:
            df["time_utc"] = pd.to_datetime(df["period"], utc=True)
        elif "datetime_gmt" in df.columns:
            df["time_utc"] = pd.to_datetime(df["datetime_gmt"], utc=True)
        else:
            raise ValueError("No time column found in data")

        if "value" in df.columns:
            df["generation_mw"] = df["value"].astype(np.float32)
        else:
            raise ValueError("No 'value' column found in data")

        df_clean = df[["time_utc", "generation_mw"]].copy()
        df_clean = df_clean.drop_duplicates(subset=["time_utc"])
        df_clean = df_clean.set_index("time_utc")
        df_clean = df_clean.sort_index()

        ds = xr.Dataset.from_dataframe(df_clean)
        ds = ds.expand_dims({"location_id": [location_id]})
        ds = ds.assign_coords({
            "longitude": ("location_id", [location_info["longitude"]]),
            "latitude": ("location_id", [location_info["latitude"]]),
        })

        return ds

    def estimate_capacity(self, ds: xr.Dataset, percentile: float = 99.0) -> xr.Dataset:
        """Estimate capacity using percentile of generation values (default 99th to avoid outliers).

        Adds a "capacity_mwp" variable broadcast to the shape of
        "generation_mw" (constant over time per location).

        Raises:
            ValueError: if a location has no non-NaN generation values, since
                a percentile cannot be computed from an empty sample.
        """
        logger.info(f"Estimating capacity using {percentile}th percentile of generation")

        capacity_values = []
        for loc_id in ds.location_id.values:
            gen_values = ds["generation_mw"].sel(location_id=loc_id).values
            valid = gen_values[~np.isnan(gen_values)]
            # Fix: np.percentile on an empty array fails with an unhelpful
            # low-level error; raise a clear, actionable message instead so a
            # fully-NaN region is easy to diagnose.
            if valid.size == 0:
                raise ValueError(
                    f"No valid generation values for location_id={loc_id}; cannot estimate capacity"
                )
            capacity_values.append(np.percentile(valid, percentile))

        capacity_da = xr.DataArray(
            np.array(capacity_values, dtype=np.float32),
            dims=["location_id"],
            coords={"location_id": ds.location_id},
            name="capacity_mwp"
        )

        # Broadcast per-location scalars across the time axis so capacity_mwp
        # has the same dims as generation_mw, as the schema expects.
        capacity_broadcast = capacity_da.broadcast_like(ds["generation_mw"])
        ds["capacity_mwp"] = capacity_broadcast

        return ds

    def add_location_metadata(self, ds: xr.Dataset) -> xr.Dataset:
        """Add location metadata. Currently handled in transform_to_schema."""
        return ds

    def validate_data(self, ds: xr.Dataset) -> bool:
        """Check if dataset has all required dims, vars, and coords.

        Returns:
            True when the schema is complete; False otherwise (errors are
            logged). Missing generation values and a non-float32 dtype are
            only warned about, not treated as failures.
        """
        logger.info("Validating dataset schema")

        required_dims = {"time_utc", "location_id"}
        if not required_dims.issubset(set(ds.dims)):
            logger.error(f"Missing required dimensions. Expected {required_dims}, got {set(ds.dims)}")
            return False

        required_vars = {"generation_mw", "capacity_mwp"}
        if not required_vars.issubset(set(ds.data_vars)):
            logger.error(f"Missing required variables. Expected {required_vars}, got {set(ds.data_vars)}")
            return False

        required_coords = {"time_utc", "location_id", "longitude", "latitude"}
        if not required_coords.issubset(set(ds.coords)):
            logger.error(f"Missing required coordinates. Expected {required_coords}, got {set(ds.coords)}")
            return False

        gen_missing = ds["generation_mw"].isnull().sum().values
        if gen_missing > 0:
            logger.warning(f"Found {gen_missing} missing values in generation_mw")

        if ds["generation_mw"].dtype != np.float32:
            logger.warning(f"generation_mw dtype is {ds['generation_mw'].dtype}, expected float32")

        logger.info("✅ Dataset validation passed")
        return True

    def save_to_zarr(self, ds: xr.Dataset, path: str, mode: str = "w"):
        """Save to Zarr, handling timezone conversion for compatibility.

        Args:
            ds: Dataset to persist (not modified; a copy is saved).
            path: Target Zarr store path; parent directories are created.
            mode: Zarr write mode, default "w" (overwrite).
        """
        logger.info(f"Saving dataset to {path}")

        Path(path).parent.mkdir(parents=True, exist_ok=True)

        # Zarr doesn't handle datetime64[ns, UTC] well, convert to timezone-naive datetime64[ns]
        ds_to_save = ds.copy()
        if "time_utc" in ds_to_save.coords:
            time_values = pd.DatetimeIndex(ds_to_save["time_utc"].values).tz_localize(None)
            ds_to_save["time_utc"] = time_values

        ds_to_save.to_zarr(path, mode=mode)
        logger.info(f"✅ Dataset saved to {path}")
def _build_parser() -> argparse.ArgumentParser:
    """Construct the CLI argument parser for the preprocessing script."""
    parser = argparse.ArgumentParser(
        description="Preprocess EIA solar generation data for ocf-data-sampler"
    )
    parser.add_argument("--start-date", required=True,
                        help="Start date in YYYY-MM-DD format")
    parser.add_argument("--end-date", required=True,
                        help="End date in YYYY-MM-DD format")
    parser.add_argument("--regions", nargs="+", default=None,
                        help="List of region codes (e.g., CAISO ERCOT). Default: US48")
    parser.add_argument("--output", required=True,
                        help="Output path for Zarr file")
    parser.add_argument("--api-key", default=None,
                        help="EIA API key (optional, uses EIA_API_KEY env var if not provided)")
    parser.add_argument("--frequency", default="hourly", choices=["hourly", "daily"],
                        help="Data frequency (default: hourly)")
    parser.add_argument("--log-level", default="INFO",
                        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
                        help="Logging level (default: INFO)")
    return parser


def main():
    """CLI entry point: parse arguments, configure logging, run the pipeline."""
    args = _build_parser().parse_args()

    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )

    preprocessor = EIAPreprocessor(api_key=args.api_key)

    try:
        ds = preprocessor.fetch_and_preprocess(
            start_date=args.start_date,
            end_date=args.end_date,
            regions=args.regions,
            output_path=args.output,
            frequency=args.frequency,
        )

        logger.info("Preprocessing completed successfully!")
        logger.info(f"Dataset shape: {ds.dims}")
        logger.info(f"Variables: {list(ds.data_vars)}")

    except Exception as e:
        # Log the failure before re-raising so the CLI always reports a cause.
        logger.error(f"Preprocessing failed: {e}")
        raise


if __name__ == "__main__":
    main()
import pytest
import pandas as pd
import numpy as np
import xarray as xr
from unittest.mock import Mock, patch
from open_data_pvnet.scripts.preprocess_eia_data import EIAPreprocessor, US_RTO_LOCATIONS


@pytest.fixture
def mock_eia_response():
    """Six hours of fake EIA hourly solar data for CAISO."""
    n_rows = 6
    return pd.DataFrame({
        "period": [f"2023-01-01T{hour:02d}" for hour in range(n_rows)],
        "value": [0, 50, 150, 300, 250, 100],
        "fueltype": ["SUN"] * n_rows,
        "respondent": ["CAISO"] * n_rows,
    })


@pytest.fixture
def preprocessor():
    """Fresh EIAPreprocessor wired with a dummy API key."""
    return EIAPreprocessor(api_key="test_key")


def test_preprocessor_init():
    """Constructor passes the API key through and wires the location table."""
    instance = EIAPreprocessor(api_key="test_key")
    assert instance.location_metadata == US_RTO_LOCATIONS
    assert instance.eia_data.api_key == "test_key"


def test_transform_to_schema(preprocessor, mock_eia_response):
    """Transformation yields the expected dims, variables, coords and values."""
    ds = preprocessor.transform_to_schema(mock_eia_response, "CAISO")

    for dim in ("time_utc", "location_id"):
        assert dim in ds.dims
    assert "generation_mw" in ds.data_vars
    for coord in ("longitude", "latitude", "location_id"):
        assert coord in ds.coords

    caiso = US_RTO_LOCATIONS["CAISO"]
    assert ds.coords["location_id"].values[0] == caiso["location_id"]
    assert ds.coords["longitude"].values[0] == caiso["longitude"]
    assert ds.coords["latitude"].values[0] == caiso["latitude"]

    assert ds["generation_mw"].dtype == np.float32
    assert len(ds.time_utc) == 6


def test_transform_to_schema_with_datetime_gmt(preprocessor):
    """A datetime_gmt column is accepted in place of period."""
    frame = pd.DataFrame({
        "datetime_gmt": pd.to_datetime(["2023-01-01T00", "2023-01-01T01"], utc=True),
        "value": [100, 150],
    })

    ds = preprocessor.transform_to_schema(frame, "ERCOT")

    assert "time_utc" in ds.dims
    assert len(ds.time_utc) == 2


def test_transform_to_schema_unknown_region(preprocessor, mock_eia_response):
    """An unrecognised region code is rejected."""
    with pytest.raises(ValueError, match="Unknown region"):
        preprocessor.transform_to_schema(mock_eia_response, "UNKNOWN_REGION")


def test_transform_to_schema_missing_time_column(preprocessor):
    """A frame with no recognised time column is rejected."""
    frame = pd.DataFrame({"value": [100, 150]})

    with pytest.raises(ValueError, match="No time column found"):
        preprocessor.transform_to_schema(frame, "CAISO")


def test_transform_to_schema_missing_value_column(preprocessor):
    """A frame with no value column is rejected."""
    frame = pd.DataFrame({"period": ["2023-01-01T00", "2023-01-01T01"]})

    with pytest.raises(ValueError, match="No 'value' column found"):
        preprocessor.transform_to_schema(frame, "CAISO")


def test_estimate_capacity(preprocessor, mock_eia_response):
    """Capacity equals the 99th percentile of generation, stored as float32."""
    ds = preprocessor.transform_to_schema(mock_eia_response, "CAISO")
    with_capacity = preprocessor.estimate_capacity(ds, percentile=99.0)

    assert "capacity_mwp" in with_capacity.data_vars
    assert with_capacity["capacity_mwp"].dtype == np.float32

    expected = np.percentile(mock_eia_response["value"].values, 99.0)
    actual = with_capacity["capacity_mwp"].isel(location_id=0, time_utc=0).values
    assert np.isclose(actual, expected, rtol=0.01)


def test_estimate_capacity_with_max_percentile(preprocessor, mock_eia_response):
    """The 100th percentile reduces to the maximum generation value."""
    ds = preprocessor.transform_to_schema(mock_eia_response, "CAISO")
    with_capacity = preprocessor.estimate_capacity(ds, percentile=100.0)

    expected = mock_eia_response["value"].max()
    actual = with_capacity["capacity_mwp"].isel(location_id=0, time_utc=0).values
    assert np.isclose(actual, expected, rtol=0.01)


def test_validate_data_success(preprocessor, mock_eia_response):
    """A fully transformed dataset passes validation."""
    ds = preprocessor.estimate_capacity(
        preprocessor.transform_to_schema(mock_eia_response, "CAISO")
    )

    assert preprocessor.validate_data(ds) is True


def test_validate_data_missing_dimension(preprocessor, mock_eia_response):
    """Dropping the location_id dimension makes validation fail."""
    ds = preprocessor.estimate_capacity(
        preprocessor.transform_to_schema(mock_eia_response, "CAISO")
    )

    assert preprocessor.validate_data(ds.isel(location_id=0, drop=True)) is False


def test_validate_data_missing_variable(preprocessor, mock_eia_response):
    """A dataset without capacity_mwp fails validation."""
    ds = preprocessor.transform_to_schema(mock_eia_response, "CAISO")
    assert preprocessor.validate_data(ds) is False


def test_validate_data_missing_coordinate(preprocessor, mock_eia_response):
    """Removing the longitude coordinate makes validation fail."""
    ds = preprocessor.estimate_capacity(
        preprocessor.transform_to_schema(mock_eia_response, "CAISO")
    )

    assert preprocessor.validate_data(ds.drop_vars("longitude")) is False


def test_save_to_zarr(preprocessor, mock_eia_response, tmp_path):
    """Round-trip through Zarr preserves variables and time-axis length."""
    ds = preprocessor.estimate_capacity(
        preprocessor.transform_to_schema(mock_eia_response, "CAISO")
    )

    target = tmp_path / "test_output.zarr"
    preprocessor.save_to_zarr(ds, str(target))
    assert target.exists()

    reloaded = xr.open_zarr(target)
    assert "generation_mw" in reloaded.data_vars
    assert "capacity_mwp" in reloaded.data_vars
    assert len(reloaded.time_utc) == 6


def test_fetch_and_preprocess_single_region(preprocessor, mock_eia_response, tmp_path):
    """End-to-end pipeline for one region writes a valid Zarr store."""
    target = tmp_path / "output.zarr"
    with patch.object(preprocessor.eia_data, 'get_data', return_value=mock_eia_response):
        ds = preprocessor.fetch_and_preprocess(
            start_date="2023-01-01",
            end_date="2023-01-02",
            regions=["CAISO"],
            output_path=str(target),
        )

    for var in ("generation_mw", "capacity_mwp"):
        assert var in ds.data_vars
    for dim in ("time_utc", "location_id"):
        assert dim in ds.dims
    assert target.exists()


def test_fetch_and_preprocess_multiple_regions(preprocessor, mock_eia_response, tmp_path):
    """Two regions are concatenated along location_id."""
    caiso_frame = mock_eia_response.copy()
    ercot_frame = mock_eia_response.copy()
    ercot_frame["value"] = ercot_frame["value"] * 1.5

    with patch.object(preprocessor.eia_data, 'get_data', side_effect=[caiso_frame, ercot_frame]):
        ds = preprocessor.fetch_and_preprocess(
            start_date="2023-01-01",
            end_date="2023-01-02",
            regions=["CAISO", "ERCOT"],
        )

    assert len(ds.location_id) == 2
    for region in ("CAISO", "ERCOT"):
        assert US_RTO_LOCATIONS[region]["location_id"] in ds.location_id.values


def test_fetch_and_preprocess_us48_default(preprocessor, mock_eia_response):
    """With no regions given, the pipeline falls back to US48."""
    with patch.object(preprocessor.eia_data, 'get_data', return_value=mock_eia_response) as mock_get:
        ds = preprocessor.fetch_and_preprocess(
            start_date="2023-01-01",
            end_date="2023-01-02",
            regions=None,
        )

        assert mock_get.call_args[1]["region"] == "US48"
        assert ds.location_id.values[0] == US_RTO_LOCATIONS["US48"]["location_id"]


def test_fetch_and_preprocess_no_data(preprocessor):
    """A None response from the API raises a ValueError."""
    with patch.object(preprocessor.eia_data, 'get_data', return_value=None):
        with pytest.raises(ValueError, match="No data retrieved"):
            preprocessor.fetch_and_preprocess(
                start_date="2023-01-01",
                end_date="2023-01-02",
                regions=["CAISO"],
            )


def test_fetch_and_preprocess_empty_dataframe(preprocessor):
    """An empty DataFrame response raises a ValueError."""
    with patch.object(preprocessor.eia_data, 'get_data', return_value=pd.DataFrame()):
        with pytest.raises(ValueError, match="No data retrieved"):
            preprocessor.fetch_and_preprocess(
                start_date="2023-01-01",
                end_date="2023-01-02",
                regions=["CAISO"],
            )


def test_fetch_and_preprocess_validation_failure(preprocessor, mock_eia_response):
    """A failing validation step aborts the pipeline."""
    with patch.object(preprocessor.eia_data, 'get_data', return_value=mock_eia_response), \
         patch.object(preprocessor, 'validate_data', return_value=False):
        with pytest.raises(ValueError, match="Data validation failed"):
            preprocessor.fetch_and_preprocess(
                start_date="2023-01-01",
                end_date="2023-01-02",
                regions=["CAISO"],
            )


def test_us_rto_locations_structure():
    """Every region entry carries well-formed metadata."""
    for info in US_RTO_LOCATIONS.values():
        for key in ("location_id", "latitude", "longitude", "name"):
            assert key in info

        assert isinstance(info["location_id"], int)
        assert isinstance(info["latitude"], (int, float))
        assert isinstance(info["longitude"], (int, float))
        assert isinstance(info["name"], str)

        assert -90 <= info["latitude"] <= 90
        assert -180 <= info["longitude"] <= 180


def test_us_rto_locations_unique_ids():
    """No two regions share a location_id."""
    ids = [info["location_id"] for info in US_RTO_LOCATIONS.values()]
    assert len(set(ids)) == len(ids), "Location IDs must be unique"