diff --git a/docs/getting_started.md b/docs/getting_started.md index 6bd7416..f9bc3aa 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -211,6 +211,24 @@ ds = xr.open_zarr(s3.get_mapper(dataset_path), consolidated=True) print(ds) ``` +6. **Accessing US EIA Data from S3** +The US EIA solar generation data is stored in the S3 bucket `s3://ocf-open-data-pvnet/data/us/eia/`. Similar to GFS data, it can be accessed directly using `xarray` and `s3fs`. + +```python +import xarray as xr +import s3fs + +# Create an S3 filesystem object (Public Access) +s3 = s3fs.S3FileSystem(anon=True) + +# Open the US EIA dataset (Latest Version) +dataset_path = 's3://ocf-open-data-pvnet/data/us/eia/latest/target_eia_data_processed.zarr' +ds = xr.open_zarr(s3.get_mapper(dataset_path), consolidated=True) + +# Display the dataset +print(ds) +``` + ### Best Practices for Using APIs - **API Keys**: Most APIs require authentication via an API key. Store keys securely using environment variables or secret management tools. diff --git a/docs/us_data_preprocessing.md b/docs/us_data_preprocessing.md new file mode 100644 index 0000000..a84eedd --- /dev/null +++ b/docs/us_data_preprocessing.md @@ -0,0 +1,199 @@ +# US Data Preprocessing for ocf-data-sampler + +This document describes how to preprocess EIA solar generation data for use with ocf-data-sampler and PVNet training. + +## Overview + +The EIA data collected by `collect_eia_data.py` needs to be preprocessed to match the format expected by ocf-data-sampler, which follows the UK GSP data structure. 
+ +## Data Format Requirements + +### Input Format (from `collect_eia_data.py`) +- **Dimensions**: `(timestamp, ba_code)` +- **Variables**: `generation_mw`, `ba_name`, `latitude`, `longitude`, `value-units` +- **Index**: MultiIndex on `(timestamp, ba_code)` + +### Output Format (for ocf-data-sampler) +- **Dimensions**: `(ba_id, datetime_gmt)` where `ba_id` is int64 +- **Variables**: `generation_mw`, `capacity_mwp` +- **Coordinates**: `ba_code`, `ba_name`, `latitude`, `longitude` (optional) +- **Chunking**: `{"ba_id": 1, "datetime_gmt": 1000}` +- **Format**: Zarr with consolidated metadata + +## Preprocessing Steps + +The preprocessing script (`preprocess_eia_for_sampler.py`) performs the following transformations: + +1. **Rename timestamp to datetime_gmt** + - Converts to UTC timezone + - Removes timezone info (matches UK format) + +2. **Map BA codes to numeric IDs** + - Creates numeric `ba_id` (0, 1, 2, ...) for each unique `ba_code` + - Generates metadata CSV with mapping + +3. **Add capacity data** + - Estimates `capacity_mwp` from maximum historical generation + - Applies safety factor (1.15x) to account for capacity > max generation + - Ensures capacity >= generation + +4. 
**Restructure dataset** + - Sets index to `(ba_id, datetime_gmt)` + - Applies proper chunking for efficient access + - Saves with consolidated metadata + +## Usage + +### Step 1: Collect Raw EIA Data + +```bash +python src/open_data_pvnet/scripts/collect_eia_data.py \ + --start 2020-01-01 \ + --end 2023-12-31 \ + --output src/open_data_pvnet/data/target_eia_data.zarr +``` + +### Step 2: Preprocess for ocf-data-sampler + +```bash +python src/open_data_pvnet/scripts/preprocess_eia_for_sampler.py \ + --input src/open_data_pvnet/data/target_eia_data.zarr \ + --output src/open_data_pvnet/data/target_eia_data_processed.zarr \ + --metadata-output src/open_data_pvnet/data/us_ba_metadata.csv +``` + +### Step 3: Verify Compatibility + +```bash +python src/open_data_pvnet/scripts/test_eia_sampler_compatibility.py \ + --data-path src/open_data_pvnet/data/target_eia_data_processed.zarr \ + --config-path src/open_data_pvnet/configs/PVNet_configs/datamodule/configuration/us_configuration.yaml +``` + +## Capacity Data Options + +The preprocessing script supports three methods for obtaining capacity data: + +### 1. Estimate from Generation (Default) +```bash +--capacity-method estimate +``` +Estimates capacity as `max(generation) * 1.15`. This is a simple heuristic that works reasonably well for initial testing. + +### 2. Load from File +```bash +--capacity-method file --capacity-file path/to/capacity.csv +``` +Loads capacity data from a CSV file with columns `ba_code` and `capacity_mwp`. + +### 3. Static Value (Not Recommended) +```bash +--capacity-method static +``` +Uses a static capacity value for all BAs. Only for testing. 
+ +## Output Files + +### Processed Zarr Dataset +- **Location**: `src/open_data_pvnet/data/target_eia_data_processed.zarr` +- **Format**: Zarr v3 with consolidated metadata +- **Structure**: Matches UK GSP format for ocf-data-sampler compatibility + +### BA Metadata CSV +- **Location**: `src/open_data_pvnet/data/us_ba_metadata.csv` +- **Columns**: `ba_id`, `ba_code`, `ba_name`, `latitude`, `longitude` +- **Purpose**: Mapping between numeric IDs and BA codes, plus spatial coordinates + +## Configuration + +Update the US configuration file to point to the processed data: + +```yaml +# src/open_data_pvnet/configs/PVNet_configs/datamodule/configuration/us_configuration.yaml +input_data: + gsp: + zarr_path: "src/open_data_pvnet/data/target_eia_data_processed.zarr" + time_resolution_minutes: 60 # EIA data is hourly + # ... other settings +``` + +## Troubleshooting + +### Issue: "ba_code not found in dataset" +**Solution**: Ensure the input file was created by `collect_eia_data.py` and has the correct structure. + +### Issue: "Missing capacity data" +**Solution**: The script will estimate capacity automatically. If you have external capacity data, use `--capacity-method file`. + +### Issue: "ocf-data-sampler compatibility test fails" +**Solution**: +1. Check that dimensions are `(ba_id, datetime_gmt)` +2. Verify `datetime_gmt` is datetime64 without timezone +3. Ensure `capacity_mwp` variable exists +4. Check that Zarr has consolidated metadata + +### Issue: "Generation exceeds capacity" +**Solution**: The script automatically ensures `capacity_mwp >= generation_mw * 1.01`. If this fails, check for data quality issues. + +## Next Steps + +After preprocessing: +1. Verify data with `test_eia_sampler_compatibility.py` +2. Update configuration files +3. Test data loading with ocf-data-sampler +4. 
Proceed with PVNet training setup
+
+## References
+
+- UK GSP Data Format: `src/open_data_pvnet/scripts/generate_combined_gsp.py`
+- ocf-data-sampler Documentation: https://github.com/openclimatefix/ocf-data-sampler
+- EIA Data Collection: `src/open_data_pvnet/scripts/collect_eia_data.py`
+
+
+
+## S3 Data Storage & Retrieval
+
+We support storing processed US EIA data in S3 for easier access across environments.
+
+### 1. Uploading to S3
+
+You can upload processed data directly to S3 using the `collect_and_preprocess_eia.py` script. This requires AWS credentials with write access to the target bucket.
+
+```bash
+python src/open_data_pvnet/scripts/collect_and_preprocess_eia.py \
+    --start 2023-01-01 --end 2023-01-07 \
+    --upload-to-s3 \
+    --s3-bucket ocf-open-data-pvnet \
+    --s3-version v1 \
+    --public # make objects public-read
+```
+
+Or upload existing data using the utility script:
+
+```bash
+python src/open_data_pvnet/scripts/upload_eia_to_s3.py \
+    --input src/open_data_pvnet/data/target_eia_data_processed.zarr \
+    --version v1 \
+    --public
+```
+
+### 2. Accessing from S3
+
+Update your `us_configuration.yaml` to point to the S3 path:
+
+```yaml
+input_data:
+  gsp:
+    zarr_path: "s3://ocf-open-data-pvnet/data/us/eia/v1/target_eia_data_processed.zarr"
+    public: True
+```
+
+Or access directly in Python:
+
+```python
+import s3fs
+import xarray as xr
+
+s3 = s3fs.S3FileSystem(anon=True)
+ds = xr.open_zarr(s3.get_mapper("s3://ocf-open-data-pvnet/data/us/eia/v1/target_eia_data_processed.zarr"), consolidated=True)
+```
diff --git a/docs/us_eia_dataset_format.md b/docs/us_eia_dataset_format.md
new file mode 100644
index 0000000..2b46e06
--- /dev/null
+++ b/docs/us_eia_dataset_format.md
@@ -0,0 +1,102 @@
+# US EIA Dataset Technical Documentation
+
+## Overview
+
+We collect hourly solar generation data for the United States from the **US Energy Information Administration (EIA) Open Data API**. 
This dataset serves as the primary ground truth for training solar forecasting models for the US region. + +Key characteristics: +- **Source**: [EIA Hourly Electricity Grid Monitor](https://www.eia.gov/electricity/gridmonitor/dashboard/electric_overview/US48/US48) +- **Granularity**: Hourly resolution +- **Coverage**: Major US Balancing Authorities (ISOs/RTOs) +- **License**: Public Domain (US Government Data) + +--- + +## Data Formats + +### 1. Raw Data (Intermediate) + +The data collected by `collect_eia_data.py` is stored in Zarr format with the following structure: + +- **Dimensions**: `(timestamp, ba_code)` +- **Variables**: + - `generation_mw`: Electricity generation in Megawatts (MW) + - `ba_name`: Full name of the Balancing Authority + - `latitude`: Approximate centroid latitude + - `longitude`: Approximate centroid longitude + - `value-units`: Unit string (e.g., "megawatthours") + +### 2. Processed Data (Ready for Training) + +The raw data is preprocessed by `preprocess_eia_for_sampler.py` to match the format required by `ocf-data-sampler`. This format aligns with the UK GSP dataset structure. + +- **Dimensions**: `(ba_id, datetime_gmt)` +- **Chunking**: `{"ba_id": 1, "datetime_gmt": 1000}` +- **Variables**: + +| Variable | Type | Description | +|----------|------|-------------| +| `generation_mw` | `float32` | Solar generation in MW | +| `capacity_mwp` | `float32` | Estimated installed capacity in MWp | + +- **Coordinates**: + +| Coordinate | Type | Description | +|------------|------|-------------| +| `ba_id` | `int64` | Numeric ID mapped to each BA code | +| `datetime_gmt` | `datetime64[ns]` | Timestamp in UTC | +| `ba_code` | `string` | ISO/RTO code (e.g., "CISO") | +| `ba_name` | `string` | Full name of the BA | +| `latitude` | `float32` | Centroid latitude | +| `longitude` | `float32` | Centroid longitude | + +--- + +## Metadata & Mapping + +A metadata CSV file (`us_ba_metadata.csv`) is generated alongside the processed data. 
It maps numeric `ba_id`s to their corresponding codes and locations. + +| ba_id | ba_code | ba_name | latitude | longitude | +|-------|---------|---------|----------|-----------| +| 0 | CISO | California ISO | 37.0 | -120.0 | +| 1 | ERCO | Electric Reliability Council of Texas | 31.0 | -99.0 | +| ... | ... | ... | ... | ... | + +--- + +## Capacity Estimation + +Unlike UK PVLive, the EIA dataset does not provide historical installed capacity. We estimate capacity using a heuristic based on maximum historical generation: + +```python +capacity = max(generation_mw) * 1.15 +min_capacity = 100.0 MW +``` + +- **Method**: `estimate` (Default) +- **Safety Factor**: 1.15 (Assumes max generation is ~85% of theoretical capacity due to efficiencies/weather) +- **Minimum**: 100 MW floor to prevent zeros for missing data intervals + +--- + +## Data Quality & Validation + +- **Missing Data**: Intervals with missing data are typically represented as NaNs. The `ocf-data-sampler` handles this by finding valid contiguous time periods. +- **Timezone**: All timestamps are converted to **UTC**. +- **Negative Generation**: Clipped to 0. 
+ +## Usage + +### Loading Data with Xarray + +```python +import xarray as xr +import s3fs + +# Local Load +ds = xr.open_zarr("src/open_data_pvnet/data/target_eia_data_processed.zarr", consolidated=True) + +# S3 Load (Public) +s3 = s3fs.S3FileSystem(anon=True) +ds = xr.open_zarr(s3.get_mapper("s3://ocf-open-data-pvnet/data/us/eia/latest/target_eia_data_processed.zarr"), consolidated=True) +``` diff --git a/src/open_data_pvnet/configs/PVNet_configs/datamodule/configuration/us_configuration.yaml b/src/open_data_pvnet/configs/PVNet_configs/datamodule/configuration/us_configuration.yaml new file mode 100644 index 0000000..84766b9 --- /dev/null +++ b/src/open_data_pvnet/configs/PVNet_configs/datamodule/configuration/us_configuration.yaml @@ -0,0 +1,94 @@ +general: + description: Configuration for US GFS and EIA data + name: us_config + +input_data: + gsp: + # Path to US EIA data in zarr format (processed by preprocess_eia_for_sampler.py) + # Raw data from collect_eia_data.py must be preprocessed first + zarr_path: "src/open_data_pvnet/data/target_eia_data_processed.zarr" + # S3 path example (matches UK pattern): + # zarr_path: "s3://ocf-open-data-pvnet/data/us/eia/latest/target_eia_data_processed.zarr" + interval_start_minutes: -60 + interval_end_minutes: 480 + time_resolution_minutes: 60 # EIA data is hourly + dropout_timedeltas_minutes: [] + dropout_fraction: 0.0 + public: True # Required for S3 access without credentials + + nwp: + gfs: + time_resolution_minutes: 180 # Match the dataset's resolution (3 hours) + interval_start_minutes: -180 + interval_end_minutes: 540 # Cover the forecast horizon + dropout_fraction: 0.0 + dropout_timedeltas_minutes: [] + zarr_path: "s3://ocf-open-data-pvnet/data/gfs.zarr" + provider: "gfs" + image_size_pixels_height: 2 + image_size_pixels_width: 2 + public: True + channels: + - dlwrf + - dswrf + - hcc + - lcc + - mcc + - prate + - r + - t + - tcc + - u10 + - u100 + - v10 + - v100 + - vis + # Normalisation constants (generic or 
calculated, using defaults from example for now) + normalisation_constants: + dlwrf: + mean: 298.342 + std: 96.305916 + dswrf: + mean: 168.12321 + std: 246.18533 + hcc: + mean: 35.272 + std: 42.525383 + lcc: + mean: 43.578342 + std: 44.3732 + mcc: + mean: 33.738823 + std: 43.150745 + prate: + mean: 2.8190969e-05 + std: 0.00010159573 + r: + mean: 18.359747 + std: 25.440672 + t: + mean: 278.5223 + std: 22.825893 + tcc: + mean: 66.841606 + std: 41.030598 + u10: + mean: -0.0022310058 + std: 5.470838 + u100: + mean: 0.0823025 + std: 6.8899174 + v10: + mean: 0.06219831 + std: 4.7401133 + v100: + mean: 0.0797807 + std: 6.076132 + vis: + mean: 19628.32 + std: 8294.022 + + solar_position: + interval_start_minutes: -60 + interval_end_minutes: 480 + time_resolution_minutes: 60 diff --git a/src/open_data_pvnet/configs/PVNet_configs/experiment/example_us_run.yaml b/src/open_data_pvnet/configs/PVNet_configs/experiment/example_us_run.yaml new file mode 100644 index 0000000..e2f39f1 --- /dev/null +++ b/src/open_data_pvnet/configs/PVNet_configs/experiment/example_us_run.yaml @@ -0,0 +1,24 @@ +# @package _global_ + +# to execute this experiment run: +# python run.py experiment=example_us_run.yaml + +defaults: + - override /trainer: default.yaml + - override /model: multimodal.yaml + - override /datamodule: streamed_batches.yaml + - override /callbacks: default.yaml + - override /logger: wandb.yaml + - override /hydra: default.yaml + +seed: 518 + +trainer: + min_epochs: 1 + max_epochs: 2 + +datamodule: + batch_size: 4 + configuration: "src/open_data_pvnet/configs/PVNet_configs/datamodule/configuration/us_configuration.yaml" + num_train_samples: 10 + num_val_samples: 5 diff --git a/src/open_data_pvnet/configs/eia_s3_config.yaml b/src/open_data_pvnet/configs/eia_s3_config.yaml new file mode 100644 index 0000000..5b63223 --- /dev/null +++ b/src/open_data_pvnet/configs/eia_s3_config.yaml @@ -0,0 +1,6 @@ +s3: + bucket: "ocf-open-data-pvnet" + prefix: "data/us/eia" + region: "us-east-1" + 
versioning: true + public: true diff --git a/src/open_data_pvnet/configs/gfs_us_data_config.yaml b/src/open_data_pvnet/configs/gfs_us_data_config.yaml new file mode 100644 index 0000000..5570935 --- /dev/null +++ b/src/open_data_pvnet/configs/gfs_us_data_config.yaml @@ -0,0 +1,37 @@ +general: + name: "gfs_us_config" + description: "Configuration for US GFS data sampling" + +input_data: + nwp: + gfs: + time_resolution_minutes: 180 # Match the dataset's resolution (3 hours) + interval_start_minutes: 0 + interval_end_minutes: 1080 # 6 forecast steps (6 * 3 hours) + dropout_timedeltas_minutes: null + accum_channels: [] + max_staleness_minutes: 1080 # Match interval_end_minutes for consistency + s3_bucket: "noaa-gfs-bdp-pds" + s3_prefix: "gfs" + local_output_dir: "tmp/gfs/us" + zarr_path: "s3://ocf-open-data-pvnet/data/gfs.zarr" + provider: "gfs" + image_size_pixels_height: 1 + image_size_pixels_width: 1 + channels: + [ + "dlwrf", + "dswrf", + "hcc", + "lcc", + "mcc", + "prate", + "r", + "t", + "tcc", + "u10", + "u100", + "v10", + "v100", + "vis", + ] diff --git a/src/open_data_pvnet/main.py b/src/open_data_pvnet/main.py index f8d733e..1791aba 100644 --- a/src/open_data_pvnet/main.py +++ b/src/open_data_pvnet/main.py @@ -80,6 +80,13 @@ def _add_common_arguments(parser, provider_name): default="eu", help="Specify the DWD dataset region (default: eu)", ) + elif provider_name == "gfs": + parser.add_argument( + "--region", + choices=["global", "us"], + default="global", + help="Specify the GFS dataset region (default: global)", + ) parser.add_argument( "--overwrite", diff --git a/src/open_data_pvnet/nwp/gfs.py b/src/open_data_pvnet/nwp/gfs.py index 364db23..890b859 100644 --- a/src/open_data_pvnet/nwp/gfs.py +++ b/src/open_data_pvnet/nwp/gfs.py @@ -1,8 +1,156 @@ import logging +from pathlib import Path +import xarray as xr +import boto3 +from botocore import UNSIGNED +from botocore.config import Config +import shutil +from open_data_pvnet.utils.env_loader import 
PROJECT_BASE +from open_data_pvnet.utils.config_loader import load_config logger = logging.getLogger(__name__) +def fetch_gfs_data(year, month, day, hour, config): + """Downloads GFS GRIB2 files from NOAA S3 bucket.""" + s3_bucket = config.get("s3_bucket", "noaa-gfs-bdp-pds") + + # Determine output directory, default to a tmp location if not specified + output_dir_rel = config.get("local_output_dir", "tmp/gfs/data") + local_output_dir = Path(PROJECT_BASE) / output_dir_rel / "raw" / f"{year}-{month:02d}-{day:02d}-{hour:02d}" + local_output_dir.mkdir(parents=True, exist_ok=True) + + interval_end = config.get("interval_end_minutes", 1080) + resolution = config.get("time_resolution_minutes", 180) + # Generate steps (e.g., 0, 3, 6 ... hours) + steps = range(0, (interval_end // 60) + 1, resolution // 60) + + s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED)) + downloaded_files = [] + + for step in steps: + # Key format: gfs.20231201/00/atmos/gfs.t00z.pgrb2.0p25.f000 + # This structure matches the NOAA GFS bucket + s3_key = f"gfs.{year}{month:02d}{day:02d}/{hour:02d}/atmos/gfs.t{hour:02d}z.pgrb2.0p25.f{step:03d}" + filename = Path(s3_key).name + local_path = local_output_dir / filename + + if not local_path.exists(): + logger.info(f"Downloading {s3_key} from {s3_bucket}") + try: + s3.download_file(s3_bucket, s3_key, str(local_path)) + downloaded_files.append(local_path) + except Exception as e: + logger.error(f"Failed to download {s3_key}: {e}") + else: + downloaded_files.append(local_path) + + return downloaded_files -def process_gfs_data(year, month): - logger.info(f"Downloading GFS data for {year}-{month}") - raise NotImplementedError("The process_gfs_data function is not implemented yet.") +def convert_grib_to_zarr(files, output_path, config): + """Converts downloaded GRIB files to a single Zarr dataset.""" + # Import cfgrib here to avoid hard dependency at module level + import cfgrib + + datasets = [] + needed_channels = set(config.get("channels", 
[])) + + for f in files: + try: + # GFS GRIB files often contain multiple 'hypercubes' (variable groups with different dims) + # cfgrib.open_datasets handles this by returning a list of xarray Datasets + grib_datasets = cfgrib.open_datasets(str(f)) + + if not grib_datasets: + logger.warning(f"No datasets found in {f}") + continue + + # Merge variables from all parts of the GRIB file + # compat='override' is often necessary if coordinates differ slightly due to precision + merged_ds = xr.merge(grib_datasets, compat='override') + + # Filter channels + if needed_channels: + available_vars = set(merged_ds.data_vars) + # Keep only what is in needed_channels (intersection) + vars_to_keep = available_vars.intersection(needed_channels) + + if not vars_to_keep: + logger.warning(f"No matching channels found in {f}. Available: {available_vars}. Requested: {needed_channels}") + # Decide whether to continue empty or skip. Keeping empty might break downstream. + # We'll skip this file's contribution if it lacks all desired data. + # Or we could just warn. 
+ else: + merged_ds = merged_ds[list(vars_to_keep)] + + # GFS files are usually one 'step' per file + # We assume the list of files is ordered by step + datasets.append(merged_ds) + + except Exception as e: + logger.error(f"Error processing {f}: {e}") + continue + + if not datasets: + logger.error("No GRIB datasets could be processed.") + return None + + try: + # Concatenate along step/time + # Note: GRIB files often load with a 'step' dimension if valid_time is different but ref_time is same + full_ds = xr.concat(datasets, dim="step") + + # Save to Zarr + full_ds.to_zarr(output_path, mode="w") + logger.info(f"Successfully saved Zarr to {output_path}") + return output_path + + except Exception as e: + logger.error(f"Error during final concat/save to Zarr: {e}") + return None + +def process_gfs_data(year, month, day, hour=None, region="global", overwrite=False): + logger.info(f"Processing GFS data for {year}-{month} {day} {hour} region={region}") + + if region == "us": + config_path = PROJECT_BASE / "src/open_data_pvnet/configs/gfs_us_data_config.yaml" + elif region == "global": + config_path = PROJECT_BASE / "src/open_data_pvnet/configs/gfs_data_config.yaml" + else: + raise ValueError(f"Invalid region for GFS: {region}") + + if not config_path.exists(): + raise FileNotFoundError(f"Config file not found at {config_path}") + + config = load_config(config_path) + # Extract GFS specific config + gfs_config = config.get("input_data", {}).get("nwp", {}).get("gfs", {}) + + if not gfs_config: + logger.error("No GFS configuration found in input_data.nwp.gfs") + return + + # Fetch data + # (Hour defaults to 00 if None, typical for daily run start) + target_hour = hour if hour is not None else 0 + files = fetch_gfs_data(year, month, day, target_hour, gfs_config) + + if not files: + logger.error("No files were downloaded.") + return + + # Determine Output Path + output_dir_rel = gfs_config.get("local_output_dir", f"tmp/gfs/{region}") + local_output_dir = Path(PROJECT_BASE) / 
output_dir_rel + zarr_dir = local_output_dir / "zarr" / f"{year}-{month:02d}-{day:02d}-{target_hour:02d}" + + if not zarr_dir.exists() or overwrite: + result = convert_grib_to_zarr(files, zarr_dir, gfs_config) + + if result: + # Cleanup raw files to save space + raw_dir = files[0].parent + if raw_dir.exists(): + shutil.rmtree(raw_dir) + logger.info(f"Cleaned up raw files in {raw_dir}") + else: + logger.info(f"Output Zarr already exists at {zarr_dir}. Use overwrite=True to replace.") diff --git a/src/open_data_pvnet/scripts/archive.py b/src/open_data_pvnet/scripts/archive.py index 4a2df6f..728cfc4 100644 --- a/src/open_data_pvnet/scripts/archive.py +++ b/src/open_data_pvnet/scripts/archive.py @@ -50,9 +50,9 @@ def handle_archive( elif provider == "gfs": logger.info( - f"Processing GFS data for {year}-{month:02d}-{day:02d} with overwrite={overwrite}" + f"Processing GFS {region} data for {year}-{month:02d}-{day:02d} with overwrite={overwrite}" ) - process_gfs_data(year, month, day, hour, overwrite=overwrite) + process_gfs_data(year, month, day, hour, region=region, overwrite=overwrite) elif provider == "dwd": hours = range(24) if hour is None else [hour] for hour in hours: diff --git a/src/open_data_pvnet/scripts/collect_and_preprocess_eia.py b/src/open_data_pvnet/scripts/collect_and_preprocess_eia.py new file mode 100644 index 0000000..3566bdb --- /dev/null +++ b/src/open_data_pvnet/scripts/collect_and_preprocess_eia.py @@ -0,0 +1,240 @@ +""" +Collect and Preprocess EIA Data for ocf-data-sampler + +This script combines EIA data collection and preprocessing into a single workflow. +It collects raw EIA data and immediately preprocesses it for ocf-data-sampler compatibility. 
+ +Usage: + python src/open_data_pvnet/scripts/collect_and_preprocess_eia.py \ + --start 2020-01-01 \ + --end 2023-12-31 \ + --output-dir src/open_data_pvnet/data +""" + +import argparse +import logging +import os +import sys +from pathlib import Path + +# Add parent directory to path to import modules +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from open_data_pvnet.scripts.fetch_eia_data import EIAData +from open_data_pvnet.scripts.preprocess_eia_for_sampler import preprocess_eia_data +from open_data_pvnet.utils.env_loader import load_environment_variables +import pandas as pd +import xarray as xr +import numpy as np + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + + +def main(): + parser = argparse.ArgumentParser( + description="Collect and preprocess EIA data for ocf-data-sampler" + ) + parser.add_argument( + "--start", + type=str, + default="2020-01-01", + help="Start date YYYY-MM-DD" + ) + parser.add_argument( + "--end", + type=str, + required=True, + help="End date YYYY-MM-DD" + ) + parser.add_argument( + "--bas", + nargs="+", + default=None, + help="List of BA codes (default: all major ISOs)" + ) + parser.add_argument( + "--output-dir", + type=str, + default="src/open_data_pvnet/data", + help="Output directory for data files" + ) + parser.add_argument( + "--capacity-method", + type=str, + choices=["estimate", "file", "static"], + default="estimate", + help="Method for capacity data (default: estimate)" + ) + parser.add_argument( + "--capacity-file", + type=str, + default=None, + help="Path to capacity data CSV (if --capacity-method=file)" + ) + parser.add_argument( + "--skip-collection", + action="store_true", + help="Skip data collection, only preprocess existing data" + ) + parser.add_argument( + "--raw-output", + type=str, + default=None, + help="Path for raw EIA data (default: {output_dir}/target_eia_data.zarr)" + ) + parser.add_argument( + 
"--processed-output", + type=str, + default=None, + help="Path for processed data (default: {output_dir}/target_eia_data_processed.zarr)" + ) + + # S3 Upload Arguments + parser.add_argument("--upload-to-s3", action="store_true", help="Upload processed data to S3") + parser.add_argument("--s3-bucket", default="ocf-open-data-pvnet", help="S3 Bucket name") + parser.add_argument("--s3-prefix", default="data/us/eia", help="S3 Prefix") + parser.add_argument("--s3-version", default="latest", help="Data version string") + parser.add_argument("--dry-run", action="store_true", help="Simulate S3 upload") + parser.add_argument("--public", action="store_true", help="Make S3 objects public-read") + + args = parser.parse_args() + + # Set up paths + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + raw_path = args.raw_output or str(output_dir / "target_eia_data.zarr") + processed_path = args.processed_output or str(output_dir / "target_eia_data_processed.zarr") + metadata_path = str(output_dir / "us_ba_metadata.csv") + + # Step 1: Collect raw EIA data + if not args.skip_collection: + try: + load_environment_variables() + except Exception as e: + logger.warning(f"Could not load environment variables: {e}") + + if args.bas is None: + DEFAULT_BAS = ['CISO', 'ERCO', 'PJM', 'MISO', 'NYIS', 'ISNE', 'SWPP'] + bas = DEFAULT_BAS + else: + bas = args.bas + + eia = EIAData() + if not eia.api_key: + logger.error("EIA_API_KEY not set. 
Exiting.") + return 1 + + logger.info(f"Fetching data from {args.start} to {args.end} for BAs: {bas}") + + df = eia.get_hourly_solar_data( + start_date=args.start, + end_date=args.end, + ba_codes=bas + ) + + if df.empty: + logger.error("No data fetched.") + return 1 + + logger.info(f"Fetched {len(df)} rows.") + + ba_centroids = { + 'CISO': {'latitude': 37.0, 'longitude': -120.0}, + 'ERCO': {'latitude': 31.0, 'longitude': -99.0}, + 'PJM': {'latitude': 40.0, 'longitude': -77.0}, + 'MISO': {'latitude': 40.0, 'longitude': -90.0}, + 'NYIS': {'latitude': 43.0, 'longitude': -75.0}, + 'ISNE': {'latitude': 44.0, 'longitude': -71.0}, + 'SWPP': {'latitude': 38.0, 'longitude': -98.0}, + } + + # Add coordinates + df["latitude"] = df["ba_code"].map(lambda x: ba_centroids.get(x, {}).get('latitude', np.nan)) + df["longitude"] = df["ba_code"].map(lambda x: ba_centroids.get(x, {}).get('longitude', np.nan)) + + # Ensure timestamp is datetime + df["timestamp"] = pd.to_datetime(df["timestamp"], utc=True) + + # Ensure timestamp is timezone-naive UTC for Zarr compatibility + if "timestamp" in df.columns: + df["timestamp"] = pd.to_datetime(df["timestamp"]).dt.tz_convert(None) + + # Set index + df = df.set_index(["timestamp", "ba_code"]) + + # Convert to xarray + ds = xr.Dataset.from_dataframe(df) + + # Ensure output directory exists + os.makedirs(os.path.dirname(os.path.abspath(raw_path)), exist_ok=True) + + ds.to_zarr(raw_path, mode="w", consolidated=True) + + logger.info(f"✅ Raw EIA data collected: {raw_path}") + else: + logger.info("Skipping data collection (--skip-collection)") + if not os.path.exists(raw_path): + logger.error(f"Raw data file not found: {raw_path}") + return 1 + + # Step 2: Preprocess + try: + preprocess_eia_data( + input_path=raw_path, + output_path=processed_path, + metadata_output_path=metadata_path, + capacity_method=args.capacity_method, + capacity_file=args.capacity_file, + ) + except Exception as e: + logger.error(f"Preprocessing failed: {e}") + return 1 + + # 
Step 3: Optional S3 Upload + if args.upload_to_s3: + from open_data_pvnet.scripts.upload_eia_to_s3 import upload_directory_to_s3 + + # Upload processed data + full_prefix = f"{args.s3_prefix}/{args.s3_version}" + full_prefix = full_prefix.replace("//", "/") # Safety check + + logger.info(f"Uploading processed data to s3://{args.s3_bucket}/{full_prefix}") + + success = upload_directory_to_s3( + local_dir=processed_path, + bucket=args.s3_bucket, + prefix=full_prefix, + dry_run=args.dry_run, + public=args.public + ) + + if success: + # Also upload metadata + if os.path.exists(metadata_path): + meta_prefix = f"{full_prefix}/{os.path.basename(metadata_path)}" + logger.info(f"Uploading metadata to s3://{args.s3_bucket}/{meta_prefix}") + from open_data_pvnet.scripts.upload_eia_to_s3 import upload_file, get_s3_client + s3_client = get_s3_client(args.dry_run) + upload_file(s3_client, metadata_path, args.s3_bucket, meta_prefix, args.dry_run, args.public) + + logger.info("✅ S3 upload completed") + else: + logger.error("❌ S3 upload failed") + # Don't fail the whole script if upload fails, but warn user + + # Summary + logger.info("Collection and preprocessing complete!") + logger.info(f"Raw data: {raw_path}") + logger.info(f"Processed data: {processed_path}") + logger.info(f"Metadata: {metadata_path}") + if args.upload_to_s3: + logger.info(f"S3 Target: s3://{args.s3_bucket}/{args.s3_prefix}/{args.s3_version}") + + return 0 + + +if __name__ == "__main__": + exit(main()) + diff --git a/src/open_data_pvnet/scripts/collect_eia_data.py b/src/open_data_pvnet/scripts/collect_eia_data.py new file mode 100644 index 0000000..72300c1 --- /dev/null +++ b/src/open_data_pvnet/scripts/collect_eia_data.py @@ -0,0 +1,97 @@ +import pandas as pd +import logging +from datetime import datetime +from open_data_pvnet.scripts.fetch_eia_data import EIAData +from open_data_pvnet.utils.env_loader import load_environment_variables +import xarray as xr +import numpy as np +import os +import argparse + 
+logger = logging.getLogger(__name__) + +# Major US ISOs/RTOs +DEFAULT_BAS = [ + 'CISO', # CAISO + 'ERCO', # ERCOT + 'PJM', # PJM + 'MISO', # MISO + 'NYIS', # NYISO + 'ISNE', # ISO-NE + 'SWPP', # SPP +] + +def main(): + try: + load_environment_variables() + except Exception as e: + logger.warning(f"Could not load environment variables: {e}") + + logging.basicConfig(level=logging.INFO) + parser = argparse.ArgumentParser(description="Collect EIA Solar Data") + parser.add_argument("--start", type=str, default="2020-01-01", help="Start date YYYY-MM-DD") + parser.add_argument("--end", type=str, default=datetime.now().strftime("%Y-%m-%d"), help="End date YYYY-MM-DD") + parser.add_argument("--bas", nargs="+", default=DEFAULT_BAS, help="List of BA codes") + parser.add_argument("--output", type=str, default="src/open_data_pvnet/data/target_eia_data.nc", help="Output path") + + args = parser.parse_args() + + eia = EIAData() + if not eia.api_key: + logger.error("EIA_API_KEY not set. Exiting.") + return + + logger.info(f"Fetching data from {args.start} to {args.end} for BAs: {args.bas}") + + try: + df = eia.get_hourly_solar_data( + start_date=args.start, + end_date=args.end, + ba_codes=args.bas + ) + + if df.empty: + logger.warning("No data fetched.") + return + + logger.info(f"Fetched {len(df)} rows.") + + ba_centroids = { + 'CISO': {'latitude': 37.0, 'longitude': -120.0}, + 'ERCO': {'latitude': 31.0, 'longitude': -99.0}, + 'PJM': {'latitude': 40.0, 'longitude': -77.0}, + 'MISO': {'latitude': 40.0, 'longitude': -90.0}, + 'NYIS': {'latitude': 43.0, 'longitude': -75.0}, + 'ISNE': {'latitude': 44.0, 'longitude': -71.0}, + 'SWPP': {'latitude': 38.0, 'longitude': -98.0}, + } + + df["latitude"] = df["ba_code"].map(lambda x: ba_centroids.get(x, {}).get('latitude', np.nan)) + df["longitude"] = df["ba_code"].map(lambda x: ba_centroids.get(x, {}).get('longitude', np.nan)) + + df["timestamp"] = pd.to_datetime(df["timestamp"], utc=True) + df["timestamp"] = 
df["timestamp"].dt.tz_convert(None) + + df = df.set_index(["timestamp", "ba_code"]) + + ds = xr.Dataset.from_dataframe(df) + + output_path = args.output + os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True) + + if output_path.endswith(".zarr"): + if os.path.exists(output_path): + pass + ds.to_zarr(output_path, mode="w", consolidated=True) + else: + ds.to_netcdf(output_path) + + logger.info(f"Data successfully stored in {output_path}") + logger.info(f"Note: For ocf-data-sampler compatibility, run preprocess_eia_for_sampler.py on this file") + + except Exception as e: + logger.error(f"Failed to collect data: {e}") + raise + +if __name__ == "__main__": + main() diff --git a/src/open_data_pvnet/scripts/fetch_eia_data.py b/src/open_data_pvnet/scripts/fetch_eia_data.py new file mode 100644 index 0000000..762396e --- /dev/null +++ b/src/open_data_pvnet/scripts/fetch_eia_data.py @@ -0,0 +1,119 @@ +import os +import logging +import requests +import pandas as pd +from datetime import datetime +from typing import Optional, List, Union + +logger = logging.getLogger(__name__) + +class EIAData: + """ + Class to fetch data from the EIA Open Data API. + """ + BASE_URL = "https://api.eia.gov/v2" + + def __init__(self, api_key: Optional[str] = None): + self.api_key = api_key or os.getenv("EIA_API_KEY") + if not self.api_key: + logger.warning("EIA_API_KEY not found in environment variables.") + + def get_hourly_solar_data( + self, + start_date: Union[str, datetime], + end_date: Union[str, datetime], + ba_codes: Optional[List[str]] = None, + timeout: int = 30 + ) -> pd.DataFrame: + """ + Fetch hourly solar generation data for specific Balancing Authorities or all available. + + Args: + start_date: Start date (inclusive) in 'YYYY-MM-DD' or 'YYYY-MM-DDTHH' format. + end_date: End date (inclusive) in 'YYYY-MM-DD' or 'YYYY-MM-DDTHH' format. + ba_codes: List of Balancing Authority codes (e.g., ['CISO', 'PJM']). If None, fetches for all. 
+ timeout: Request timeout in seconds. + + Returns: + pd.DataFrame: DataFrame containing values, timestamps, and BA codes. + """ + if not self.api_key: + raise ValueError("API Key is required to fetch data.") + + # Ensure dates are strings in ISO format if they are datetime objects + if isinstance(start_date, datetime): + start_date = start_date.strftime("%Y-%m-%dT%H") + if isinstance(end_date, datetime): + end_date = end_date.strftime("%Y-%m-%dT%H") + + # Endpoint for hourly electricity generation by fuel type + # Route: electricity/rto/fuel-type-data + url = f"{self.BASE_URL}/electricity/rto/fuel-type-data/data/" + + params = { + "api_key": self.api_key, + "frequency": "hourly", + "data[0]": "value", + "facets[fueltype][]": "SUN", # Solar + "start": start_date, + "end": end_date, + "sort[0][column]": "period", + "sort[0][direction]": "asc", + "offset": 0, + "length": 5000, # Max length per page + } + + if ba_codes: + params["facets[respondent][]"] = ba_codes + + all_data = [] + offset = 0 + + while True: + current_params = params.copy() + current_params["offset"] = offset + try: + response = requests.get(url, params=current_params, timeout=timeout) + if not response.ok: + logger.error(f"EIA API Error: {response.text}") + response.raise_for_status() + data = response.json() + + if "response" not in data or "data" not in data["response"]: + logger.error(f"Unexpected response format: {data.keys()}") + break + + batch = data["response"]["data"] + if not batch: + break + + all_data.extend(batch) + + total = int(data["response"].get("total", 0)) + if len(all_data) >= total or len(batch) < 5000: + break + + offset += 5000 + + except requests.RequestException as e: + logger.error(f"Error fetching data from EIA: {e}") + raise + + if not all_data: + return pd.DataFrame() + + df = pd.DataFrame(all_data) + + df["period"] = pd.to_datetime(df["period"]) + + df = df.rename(columns={ + "period": "timestamp", + "value": "generation_mw", + "respondent": "ba_code", + 
"respondent-name": "ba_name" + }) + + cols_to_keep = ["timestamp", "ba_code", "ba_name", "generation_mw", "value-units"] + cols_to_keep = [c for c in cols_to_keep if c in df.columns] + + return df[cols_to_keep] diff --git a/src/open_data_pvnet/scripts/generate_combined_eia.py b/src/open_data_pvnet/scripts/generate_combined_eia.py new file mode 100644 index 0000000..3b8c34e --- /dev/null +++ b/src/open_data_pvnet/scripts/generate_combined_eia.py @@ -0,0 +1,209 @@ +""" +Generate Combined EIA Data Script + +This script fetches EIA data for all US Balancing Authorities (BAs) and combines them into a single Zarr dataset, +matching the format required by ocf-data-sampler (equivalent to UK's generate_combined_gsp.py). + +Usage: + python src/open_data_pvnet/scripts/generate_combined_eia.py --start-year 2020 --end-year 2024 --output-folder data + +Requirements: + - EIA_API_KEY environment variable + - pandas + - xarray + - zarr + - typer + +The script will: +1. Fetch data for all default BAs from EIA API (or specified BAs) +2. Preprocess data to match ocf-data-sampler format (ba_id, datetime_gmt) +3. Add capacity estimates +4. Convert to xarray Dataset and save as Zarr format +5. Output file: combined_eia_{start_date}_{end_date}.zarr + +Note: This script combines collection and preprocessing into a single step, matching the UK pattern. 
+""" + +import pandas as pd +import xarray as xr +import numpy as np +from datetime import datetime +from typing import Optional, List +import pytz +import os +import typer +import logging +from pathlib import Path +import sys + +# Add parent directory to path to import modules +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from open_data_pvnet.scripts.fetch_eia_data import EIAData +from open_data_pvnet.scripts.preprocess_eia_for_sampler import ( + create_ba_mapping, + estimate_capacity_from_generation, +) +from open_data_pvnet.utils.env_loader import load_environment_variables + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# Major US ISOs/RTOs +DEFAULT_BAS = [ + 'CISO', # CAISO + 'ERCO', # ERCOT + 'PJM', # PJM + 'MISO', # MISO + 'NYIS', # NYISO + 'ISNE', # ISO-NE + 'SWPP', # SPP +] + +# BA Centroids (Approximate) +BA_CENTROIDS = { + 'CISO': {'latitude': 37.0, 'longitude': -120.0}, + 'ERCO': {'latitude': 31.0, 'longitude': -99.0}, + 'PJM': {'latitude': 40.0, 'longitude': -77.0}, + 'MISO': {'latitude': 40.0, 'longitude': -90.0}, + 'NYIS': {'latitude': 43.0, 'longitude': -75.0}, + 'ISNE': {'latitude': 44.0, 'longitude': -71.0}, + 'SWPP': {'latitude': 38.0, 'longitude': -98.0}, +} + + +def main( + start_year: int = typer.Option(2020, help="Start year for data collection"), + end_year: int = typer.Option(2025, help="End year for data collection"), + output_folder: str = typer.Option("data", help="Output folder for the zarr dataset"), + bas: Optional[List[str]] = typer.Option(None, help="List of BA codes (default: all major ISOs)"), + capacity_method: str = typer.Option("estimate", help="Method for capacity data (estimate/file/static)"), +): + """ + Generate combined EIA data for all BAs and save as a zarr dataset. + + This matches the UK generate_combined_gsp.py pattern but for US EIA data. 
+ """ + try: + load_environment_variables() + except Exception as e: + logger.warning(f"Could not load environment variables: {e}") + + range_start = datetime(start_year, 1, 1, tzinfo=pytz.UTC) + range_end = datetime(end_year, 1, 1, tzinfo=pytz.UTC) + + # Use default BAs if not specified + if bas is None: + bas = DEFAULT_BAS + + eia = EIAData() + if not eia.api_key: + logger.error("EIA_API_KEY not set. Exiting.") + raise typer.Exit(code=1) + + logger.info(f"Fetching EIA data from {range_start.date()} to {range_end.date()} for BAs: {bas}") + + # Fetch data for all BAs + df = eia.get_hourly_solar_data( + start_date=range_start.strftime("%Y-%m-%d"), + end_date=range_end.strftime("%Y-%m-%d"), + ba_codes=bas + ) + + if df.empty: + logger.error("No data retrieved for any BAs - terminating") + raise typer.Exit(code=1) + + logger.info(f"Fetched {len(df)} rows for {df['ba_code'].nunique()} BAs") + + # Add coordinates + df["latitude"] = df["ba_code"].map(lambda x: BA_CENTROIDS.get(x, {}).get('latitude', np.nan)) + df["longitude"] = df["ba_code"].map(lambda x: BA_CENTROIDS.get(x, {}).get('longitude', np.nan)) + + # Rename timestamp to datetime_gmt and ensure proper format + if "timestamp" in df.columns: + df["datetime_gmt"] = pd.to_datetime(df["timestamp"], utc=True) + df["datetime_gmt"] = df["datetime_gmt"].dt.tz_convert(None) + df = df.drop(columns=["timestamp"]) + elif "datetime_gmt" not in df.columns: + logger.error("No timestamp or datetime_gmt column found") + raise typer.Exit(code=1) + + # Ensure generation_mw is numeric + if "generation_mw" in df.columns: + df["generation_mw"] = pd.to_numeric(df["generation_mw"], errors="coerce") + else: + logger.error("No generation_mw column found") + raise typer.Exit(code=1) + + # Create BA mapping (ba_code -> ba_id) + ba_to_id, metadata = create_ba_mapping(df) + df["ba_id"] = df["ba_code"].map(ba_to_id) + + # Handle capacity data + if capacity_method == "estimate": + logger.info("Estimating capacity from maximum generation") + 
capacity_estimates = estimate_capacity_from_generation(df) + df["capacity_mwp"] = df["ba_code"].map(capacity_estimates) + else: + logger.warning(f"Capacity method '{capacity_method}' not fully implemented in this script") + # Fallback to estimate + capacity_estimates = estimate_capacity_from_generation(df) + df["capacity_mwp"] = df["ba_code"].map(capacity_estimates) + + # Validate capacity data + if df["capacity_mwp"].isna().any(): + logger.warning("Some BAs have missing capacity data, filling with estimates") + missing_bas = df[df["capacity_mwp"].isna()]["ba_code"].unique() + for ba in missing_bas: + ba_gen = df[df["ba_code"] == ba]["generation_mw"].max() + estimated_capacity = ba_gen * 1.15 if not pd.isna(ba_gen) else 100.0 + df.loc[df["ba_code"] == ba, "capacity_mwp"] = estimated_capacity + + # Ensure capacity >= generation (with small tolerance) + df["capacity_mwp"] = df[["capacity_mwp", "generation_mw"]].max(axis=1) * 1.01 + + # Select and reorder columns + columns_to_keep = ["ba_id", "datetime_gmt", "generation_mw", "capacity_mwp"] + if "ba_code" in df.columns: + columns_to_keep.append("ba_code") + if "ba_name" in df.columns: + columns_to_keep.append("ba_name") + if "latitude" in df.columns: + columns_to_keep.append("latitude") + if "longitude" in df.columns: + columns_to_keep.append("longitude") + + df_processed = df[columns_to_keep].copy() + + # Set index to match UK format: (ba_id, datetime_gmt) + df_processed = df_processed.set_index(["ba_id", "datetime_gmt"]) + + # Convert to xarray Dataset + ds_processed = xr.Dataset.from_dataframe(df_processed) + + # Ensure datetime_gmt is datetime64[ns] (no timezone) + if "datetime_gmt" in ds_processed.coords: + ds_processed.coords["datetime_gmt"] = ds_processed.coords["datetime_gmt"].astype(np.datetime64) + + # Apply chunking like UK implementation + ds_processed = ds_processed.chunk({"ba_id": 1, "datetime_gmt": 1000}) + + # Save to Zarr + os.makedirs(output_folder, exist_ok=True) + filename = 
f"combined_eia_{range_start.date()}_{range_end.date()}.zarr" + output_path = os.path.join(output_folder, filename) + ds_processed.to_zarr(output_path, mode="w", consolidated=True) + + logger.info(f"Successfully saved combined EIA dataset to {output_path}") + logger.info(f"Dataset contains {len(ba_to_id)} BAs for period {range_start.date()} to {range_end.date()}") + + # Also save metadata CSV + metadata_path = os.path.join(output_folder, f"us_ba_metadata_{range_start.date()}_{range_end.date()}.csv") + metadata.to_csv(metadata_path, index=False) + logger.info(f"BA metadata saved to {metadata_path}") + + +if __name__ == "__main__": + typer.run(main) + diff --git a/src/open_data_pvnet/scripts/preprocess_eia_for_sampler.py b/src/open_data_pvnet/scripts/preprocess_eia_for_sampler.py new file mode 100644 index 0000000..d71aff2 --- /dev/null +++ b/src/open_data_pvnet/scripts/preprocess_eia_for_sampler.py @@ -0,0 +1,318 @@ +""" +Preprocess EIA Data for ocf-data-sampler + +This script converts raw EIA data collected by collect_eia_data.py into the format +expected by ocf-data-sampler, matching the UK GSP data structure. 
+ +UK GSP Format: +- Dimensions: (gsp_id, datetime_gmt) where gsp_id is int64 +- Variables: generation_mw, capacity_mwp, installedcapacity_mwp +- Chunking: {"gsp_id": 1, "datetime_gmt": 1000} + +US EIA Format (input): +- Dimensions: (timestamp, ba_code) where ba_code is string +- Variables: generation_mw, ba_name, latitude, longitude + +US EIA Format (output): +- Dimensions: (ba_id, datetime_gmt) where ba_id is int64 +- Variables: generation_mw, capacity_mwp +- Coordinates: ba_code, ba_name, latitude, longitude + +Usage: + python src/open_data_pvnet/scripts/preprocess_eia_for_sampler.py \ + --input src/open_data_pvnet/data/target_eia_data.zarr \ + --output src/open_data_pvnet/data/target_eia_data_processed.zarr \ + --metadata-output src/open_data_pvnet/data/us_ba_metadata.csv +""" + +import pandas as pd +import xarray as xr +import numpy as np +import logging +import os +import argparse +from pathlib import Path + +logger = logging.getLogger(__name__) + + +def estimate_capacity_from_generation( + df: pd.DataFrame, ba_col: str = "ba_code", gen_col: str = "generation_mw" +) -> pd.Series: + """ + Estimate capacity from maximum historical generation. + + This is a simple heuristic: capacity ≈ max(generation) * safety_factor + The safety factor accounts for the fact that max generation is typically + less than installed capacity (due to weather, maintenance, etc.) 
+ + Args: + df: DataFrame with generation data + ba_col: Column name for BA identifier + gen_col: Column name for generation values + + Returns: + Series with capacity estimates indexed by BA code + """ + # Group by BA and find max generation + max_gen = df.groupby(ba_col)[gen_col].max() + + # Apply safety factor (typically max generation is 70-90% of capacity) + # Using 0.85 as a reasonable estimate + safety_factor = 1.15 # 1/0.85 ≈ 1.15 to get capacity from max gen + capacity = max_gen * safety_factor + + # Ensure minimum capacity (at least 100 MW for major BAs) + capacity = capacity.clip(lower=100.0) + + logger.info(f"Estimated capacity for {len(capacity)} BAs") + logger.debug(f"Capacity range: {capacity.min():.2f} - {capacity.max():.2f} MW") + + return capacity + + +def create_ba_mapping(df: pd.DataFrame, ba_col: str = "ba_code") -> tuple[dict, pd.DataFrame]: + """ + Create mapping from BA codes to numeric IDs. + + Args: + df: DataFrame with BA codes + ba_col: Column name for BA codes + + Returns: + Tuple of (ba_code_to_id dict, metadata DataFrame) + """ + unique_bas = sorted(df[ba_col].unique()) + ba_to_id = {ba: idx for idx, ba in enumerate(unique_bas)} + + # Create metadata DataFrame + metadata = pd.DataFrame({ + "ba_id": list(ba_to_id.values()), + "ba_code": list(ba_to_id.keys()), + }) + + # Add coordinates if available + if "latitude" in df.columns and "longitude" in df.columns: + coords = df.groupby(ba_col)[["latitude", "longitude"]].first() + metadata = metadata.merge( + coords.reset_index(), + on="ba_code", + how="left" + ) + + # Add BA names if available + if "ba_name" in df.columns: + names = df.groupby(ba_col)["ba_name"].first() + metadata = metadata.merge( + names.reset_index(), + on="ba_code", + how="left" + ) + + logger.info(f"Created mapping for {len(ba_to_id)} BAs") + + return ba_to_id, metadata + + +def preprocess_eia_data( + input_path: str, + output_path: str, + metadata_output_path: str = None, + capacity_method: str = "estimate", + 
capacity_file: str = None, +) -> str: + """ + Preprocess EIA data to match ocf-data-sampler format. + + Args: + input_path: Path to input EIA Zarr/NetCDF file + output_path: Path to output processed Zarr file + metadata_output_path: Path to save BA metadata CSV + capacity_method: Method for capacity data ("estimate", "file", "static") + capacity_file: Path to capacity data file (if method is "file") + + Returns: + Path to output file + """ + logger.info(f"Loading EIA data from {input_path}") + + if input_path.endswith(".zarr"): + ds = xr.open_dataset(input_path, engine="zarr") + else: + ds = xr.open_dataset(input_path) + + df = ds.to_dataframe() + if isinstance(df.index, pd.MultiIndex): + df = df.reset_index() + else: + if df.index.name in ["timestamp", "datetime_gmt"]: + df = df.reset_index() + + logger.info(f"Loaded {len(df)} rows") + + if "ba_code" not in df.columns: + if isinstance(ds.indexes.get("ba_code"), pd.Index): + df = df.reset_index() + else: + raise ValueError("ba_code not found in dataset. 
Check input data format.") + + logger.info(f"Loaded {len(df)} rows for {df['ba_code'].nunique()} BAs") + + if "timestamp" in df.columns: + df["datetime_gmt"] = pd.to_datetime(df["timestamp"], utc=True) + df["datetime_gmt"] = df["datetime_gmt"].dt.tz_convert(None) + df = df.drop(columns=["timestamp"]) + elif "datetime_gmt" not in df.columns: + raise ValueError("No timestamp or datetime_gmt column found") + + if "generation_mw" in df.columns: + df["generation_mw"] = pd.to_numeric(df["generation_mw"], errors="coerce") + else: + raise ValueError("No generation_mw column found") + + ba_to_id, metadata = create_ba_mapping(df) + df["ba_id"] = df["ba_code"].map(ba_to_id) + + if capacity_method == "estimate": + logger.info("Estimating capacity from maximum generation") + capacity_estimates = estimate_capacity_from_generation(df) + df["capacity_mwp"] = df["ba_code"].map(capacity_estimates) + elif capacity_method == "file" and capacity_file: + logger.info(f"Loading capacity from {capacity_file}") + capacity_df = pd.read_csv(capacity_file) + capacity_map = dict(zip(capacity_df["ba_code"], capacity_df["capacity_mwp"])) + df["capacity_mwp"] = df["ba_code"].map(capacity_map) + elif capacity_method == "static": + logger.warning("Using static capacity values (not recommended)") + df["capacity_mwp"] = 1000.0 + else: + raise ValueError(f"Invalid capacity_method: {capacity_method}") + + if df["capacity_mwp"].isna().any(): + logger.warning("Some BAs have missing capacity data, filling with estimates") + missing_bas = df[df["capacity_mwp"].isna()]["ba_code"].unique() + for ba in missing_bas: + ba_gen = df[df["ba_code"] == ba]["generation_mw"].max() + estimated_capacity = ba_gen * 1.15 if not pd.isna(ba_gen) else 100.0 + df.loc[df["ba_code"] == ba, "capacity_mwp"] = estimated_capacity + + df["capacity_mwp"] = df[["capacity_mwp", "generation_mw"]].max(axis=1) * 1.01 + + columns_to_keep = ["ba_id", "datetime_gmt", "generation_mw", "capacity_mwp"] + if "ba_code" in df.columns: + 
columns_to_keep.append("ba_code") + if "ba_name" in df.columns: + columns_to_keep.append("ba_name") + if "latitude" in df.columns: + columns_to_keep.append("latitude") + if "longitude" in df.columns: + columns_to_keep.append("longitude") + + df_processed = df[columns_to_keep].copy() + + df_processed = df_processed.set_index(["ba_id", "datetime_gmt"]) + + ds_processed = xr.Dataset.from_dataframe(df_processed) + + # Ensure datetime_gmt is datetime64[ns] (no timezone) + if "datetime_gmt" in ds_processed.coords: + ds_processed.coords["datetime_gmt"] = ds_processed.coords["datetime_gmt"].astype(np.datetime64) + + # Apply chunking like UK implementation + # We'll use: {"ba_id": 1, "datetime_gmt": 1000} + try: + import dask + ds_processed = ds_processed.chunk({"ba_id": 1, "datetime_gmt": 1000}) + except ImportError: + logger.warning("Dask not installed, skipping chunking. Performance may be affected.") + + # Ensure output directory exists + output_dir = os.path.dirname(os.path.abspath(output_path)) + os.makedirs(output_dir, exist_ok=True) + + # Save to Zarr with consolidated metadata + logger.info(f"Saving processed data to {output_path}") + ds_processed.to_zarr(output_path, mode="w", consolidated=True) + + # Save metadata CSV + if metadata_output_path: + metadata_dir = os.path.dirname(os.path.abspath(metadata_output_path)) + os.makedirs(metadata_dir, exist_ok=True) + metadata.to_csv(metadata_output_path, index=False) + logger.info(f"Saved BA metadata to {metadata_output_path}") + + logger.info(f"✅ Successfully preprocessed EIA data") + logger.info(f" Output: {output_path}") + logger.info(f" Dimensions: {dict(ds_processed.dims)}") + logger.info(f" Variables: {list(ds_processed.data_vars)}") + + return output_path + + +def main(): + """CLI entry point.""" + parser = argparse.ArgumentParser( + description="Preprocess EIA data for ocf-data-sampler compatibility" + ) + parser.add_argument( + "--input", + type=str, + required=True, + help="Input EIA data file (Zarr or NetCDF)" 
+ ) + parser.add_argument( + "--output", + type=str, + required=True, + help="Output processed Zarr file path" + ) + parser.add_argument( + "--metadata-output", + type=str, + default=None, + help="Output path for BA metadata CSV (optional)" + ) + parser.add_argument( + "--capacity-method", + type=str, + choices=["estimate", "file", "static"], + default="estimate", + help="Method for obtaining capacity data (default: estimate)" + ) + parser.add_argument( + "--capacity-file", + type=str, + default=None, + help="Path to capacity data CSV file (if --capacity-method=file)" + ) + parser.add_argument( + "--log-level", + type=str, + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR"], + help="Logging level" + ) + + args = parser.parse_args() + + logging.basicConfig( + level=getattr(logging, args.log_level), + format='%(asctime)s - %(levelname)s - %(message)s' + ) + + try: + preprocess_eia_data( + input_path=args.input, + output_path=args.output, + metadata_output_path=args.metadata_output, + capacity_method=args.capacity_method, + capacity_file=args.capacity_file, + ) + except Exception as e: + logger.error(f"Failed to preprocess data: {e}", exc_info=True) + raise + + +if __name__ == "__main__": + main() + diff --git a/src/open_data_pvnet/scripts/upload_eia_to_s3.py b/src/open_data_pvnet/scripts/upload_eia_to_s3.py new file mode 100644 index 0000000..a7be882 --- /dev/null +++ b/src/open_data_pvnet/scripts/upload_eia_to_s3.py @@ -0,0 +1,185 @@ +import logging +import argparse +import os +import sys +from pathlib import Path +from typing import Optional, List +import boto3 +from botocore.exceptions import NoCredentialsError, ClientError +from datetime import datetime +import yaml + +# Add parent directory to path to import modules if needed +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from open_data_pvnet.utils.env_loader import load_environment_variables + +logger = logging.getLogger(__name__) + +def load_s3_config(config_path: str = None) -> 
dict: + """Load S3 configuration from yaml file.""" + if config_path is None: + # Default to standard location + config_path = str(Path(__file__).parent.parent / "configs" / "eia_s3_config.yaml") + + if os.path.exists(config_path): + with open(config_path, "r") as f: + return yaml.safe_load(f).get("s3", {}) + return {} + +def get_s3_client(dry_run: bool = False): + """Get authenticated S3 client.""" + if dry_run: + return None + + try: + # Check for credentials + session = boto3.Session() + credentials = session.get_credentials() + if not credentials: + logger.error("No AWS credentials found. Please configure them via 'aws configure' or env vars.") + return None + return boto3.client("s3") + except Exception as e: + logger.error(f"Failed to create S3 client: {e}") + return None + +def check_bucket_access(s3_client, bucket: str) -> bool: + """Check if bucket exists and is accessible.""" + if s3_client is None: return True # Dry run assumes access + + try: + s3_client.head_bucket(Bucket=bucket) + return True + except ClientError as e: + error_code = int(e.response['Error']['Code']) + if error_code == 404: + logger.error(f"Bucket '{bucket}' does not exist.") + elif error_code == 403: + logger.error(f"Access denied to bucket '{bucket}'. 
check permissions.") + else: + logger.error(f"Error accessing bucket '{bucket}': {e}") + return False + +def upload_file( + s3_client, + local_path: str, + bucket: str, + s3_key: str, + dry_run: bool = False, + public: bool = False +) -> bool: + """Upload a single file to S3.""" + if dry_run: + logger.info(f"[DRY RUN] Would upload {local_path} to s3://{bucket}/{s3_key}") + return True + + try: + extra_args = {} + if public: + extra_args['ACL'] = 'public-read' + + logger.info(f"Uploading {local_path} to s3://{bucket}/{s3_key}") + s3_client.upload_file(local_path, bucket, s3_key, ExtraArgs=extra_args) + return True + except Exception as e: + logger.error(f"Failed to upload {local_path}: {e}") + return False + +import posixpath + +def upload_directory_to_s3( + local_dir: str, + bucket: str, + prefix: str, + dry_run: bool = False, + public: bool = False +) -> bool: + """Upload a directory (e.g. Zarr store) to S3 recursively.""" + s3_client = get_s3_client(dry_run) + if not dry_run and not s3_client: + return False + + if not check_bucket_access(s3_client, bucket): + return False + + local_path = Path(local_dir) + if not local_path.exists(): + logger.error(f"Local path {local_dir} does not exist.") + return False + + prefix = prefix.strip("/") + failed_uploads = [] + + if local_path.is_file(): + s3_key = posixpath.join(prefix, local_path.name) + if not upload_file(s3_client, str(local_path), bucket, s3_key, dry_run, public): + failed_uploads.append(str(local_path)) + + elif local_path.is_dir(): + for root, _, files in os.walk(local_path): + for file in files: + full_path = Path(root) / file + relative_path = full_path.relative_to(local_path) + + relative_path_str = str(relative_path).replace(os.sep, "/") + s3_key = posixpath.join(prefix, local_path.name, relative_path_str) + + if not upload_file(s3_client, str(full_path), bucket, s3_key, dry_run, public): + failed_uploads.append(str(full_path)) + + if failed_uploads: + logger.error(f"❌ Failed to upload 
{len(failed_uploads)} files:") + for f in failed_uploads[:10]: + logger.error(f" - {f}") + if len(failed_uploads) > 10: + logger.error(f" ... and {len(failed_uploads) - 10} more.") + return False + + return True + +def main(): + parser = argparse.ArgumentParser(description="Upload EIA data to S3") + parser.add_argument("--input", required=True, help="Path to file or directory to upload (e.g. Zarr store)") + parser.add_argument("--bucket", default="ocf-open-data-pvnet", help="S3 Bucket name") + parser.add_argument("--prefix", default="data/us/eia", help="Base S3 prefix") + parser.add_argument("--version", default="latest", help="Version string (e.g. v1, 2024-01-01)") + parser.add_argument("--dry-run", action="store_true", help="Simulate upload without actual transfer") + parser.add_argument("--public", action="store_true", help="Set ACL to public-read") + parser.add_argument("--log-level", default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR"]) + + args = parser.parse_args() + + logging.basicConfig(level=getattr(logging, args.log_level), format='%(asctime)s - %(levelname)s - %(message)s') + + # Construct full prefix + full_prefix = posixpath.join(args.prefix, args.version) + + logger.info("="*60) + logger.info(f"S3 Upload Utility") + logger.info(f"Input: {args.input}") + logger.info(f"Target: s3://{args.bucket}/{full_prefix}") + logger.info(f"Dry Run: {args.dry_run}") + logger.info(f"Public Access: {args.public}") + logger.info("="*60) + + try: + load_environment_variables() + except Exception: + pass + + if upload_directory_to_s3( + local_dir=args.input, + bucket=args.bucket, + prefix=full_prefix, + dry_run=args.dry_run, + public=args.public + ): + logger.info("✅ Upload completed successfully.") + return 0 + else: + logger.error("❌ Upload failed.") + return 1 + +if __name__ == "__main__": + exit(main()) diff --git a/tests/test_collect_eia.py b/tests/test_collect_eia.py new file mode 100644 index 0000000..2375572 --- /dev/null +++ b/tests/test_collect_eia.py 
import pytest
import pandas as pd
import numpy as np
from open_data_pvnet.scripts.fetch_eia_data import EIAData
from unittest.mock import MagicMock, patch
import os
import shutil

# Mock EIAData since we tested it separately, we just want to test collector logic
from open_data_pvnet.scripts.collect_eia_data import main as collect_main
import xarray as xr

@pytest.fixture
def mock_args():
    """CLI arguments for a one-day CISO collection run writing to a temp store."""
    return ["--start", "2023-01-01", "--end", "2023-01-02", "--bas", "CISO", "--output", "tmp_test_output.zarr"]

def test_collect_data(mock_args):
    """collect_main writes a Zarr store with timestamp/ba_code coords and BA metadata.

    EIAData is patched so only the collector's transform/write logic is
    exercised — no network access.
    """
    mock_df = pd.DataFrame({
        "timestamp": ["2023-01-01T00", "2023-01-01T01"],
        "ba_code": ["CISO", "CISO"],
        "value": [100, 200],
        "ba_name": ["CAISO", "CAISO"],
        "value-units": ["MWh", "MWh"]
    })

    with patch("open_data_pvnet.scripts.collect_eia_data.EIAData") as MockEIA:
        instance = MockEIA.return_value
        instance.api_key = "test_key"
        instance.get_hourly_solar_data.return_value = mock_df

        with patch("sys.argv", ["script_name"] + mock_args):
            collect_main()

    # try/finally guarantees the on-disk store is removed even when an
    # assertion fails; previously a failing assert leaked tmp_test_output.zarr
    # into the working directory, poisoning subsequent test runs.
    try:
        assert os.path.exists("tmp_test_output.zarr")
        ds = xr.open_zarr("tmp_test_output.zarr", consolidated=False)
        assert "timestamp" in ds.coords
        assert "ba_code" in ds.coords
        assert "latitude" in ds.data_vars or "latitude" in ds.coords
        assert ds["latitude"].values[0] == 37.0  # CISO lat
    finally:
        if os.path.exists("tmp_test_output.zarr"):
            shutil.rmtree("tmp_test_output.zarr")
"megawatthours" + }, + { + "period": "2023-01-01T01", + "respondent": "CISO", + "respondent-name": "California Independent System Operator", + "fueltypeid": "SUN", + "type-name": "Solar", + "value": 1200, + "value-units": "megawatthours" + } + ] + } + } + +def test_init_no_api_key(): + with patch.dict("os.environ", {}, clear=True): + eia = EIAData() + assert eia.api_key is None + +def test_init_with_env_var(): + with patch.dict("os.environ", {"EIA_API_KEY": "test_key"}, clear=True): + eia = EIAData() + assert eia.api_key == "test_key" + +def test_get_hourly_solar_data_success(mock_response_data): + eia = EIAData(api_key="test_key") + + with patch("requests.get") as mock_get: + mock_response = MagicMock() + mock_response.json.return_value = mock_response_data + mock_response.raise_for_status.return_value = None + mock_get.return_value = mock_response + + df = eia.get_hourly_solar_data( + start_date="2023-01-01T00", + end_date="2023-01-01T01", + ba_codes=["CISO"] + ) + + assert not df.empty + assert len(df) == 2 + assert "timestamp" in df.columns + assert "generation_mw" in df.columns + assert df.iloc[0]["generation_mw"] == 1000 + assert df.iloc[1]["generation_mw"] == 1200 + assert pd.api.types.is_datetime64_any_dtype(df["timestamp"]) + + # Verify call args + # Check that requests.get was called with correct params + assert mock_get.called + call_args, call_kwargs = mock_get.call_args + params = call_kwargs["params"] + + # Verify basic params + assert params["api_key"] == "test_key" + assert params["frequency"] == "hourly" + + # Verify fueltype facet - check if key exists (may have brackets) + fueltype_found = False + for key in params.keys(): + if "fueltype" in str(key) and params[key] == "SUN": + fueltype_found = True + break + assert fueltype_found, f"Expected 'facets[fueltype][]': 'SUN' in params, got {params}" + + # Verify respondent facet - check if key exists and value matches + respondent_found = False + for key in params.keys(): + if "respondent" in 
str(key): + # Value might be a list or single item + value = params[key] + if value == ["CISO"] or (isinstance(value, list) and "CISO" in value): + respondent_found = True + break + assert respondent_found, f"Expected 'facets[respondent][]': ['CISO'] in params, got {params}" + +def test_get_hourly_solar_data_pagination(): + eia = EIAData(api_key="test_key") + + # Create a scenario where total is 6000 and we get 5000 in first batch + first_batch = [{"period": "2023-01-01T00", "value": i, "respondent": "CISO"} for i in range(5000)] + second_batch = [{"period": "2023-01-02T00", "value": i, "respondent": "CISO"} for i in range(1000)] + + response1 = {"response": {"total": 6000, "data": first_batch}} + response2 = {"response": {"total": 6000, "data": second_batch}} + + with patch("requests.get") as mock_get: + mock_response1 = MagicMock() + mock_response1.json.return_value = response1 + + mock_response2 = MagicMock() + mock_response2.json.return_value = response2 + + # Side effect to return different responses + mock_get.side_effect = [mock_response1, mock_response2] + + df = eia.get_hourly_solar_data("2023-01-01", "2023-01-02") + + assert len(df) == 6000 + assert mock_get.call_count == 2 + + # Check offsets + call_args_list = mock_get.call_args_list + assert call_args_list[0][1]["params"]["offset"] == 0 + assert call_args_list[1][1]["params"]["offset"] == 5000 + +def test_missing_api_key_error(): + with patch.dict("os.environ", {}, clear=True): + eia = EIAData(api_key=None) + with pytest.raises(ValueError, match="API Key is required"): + eia.get_hourly_solar_data("2023-01-01", "2023-01-02") diff --git a/tests/test_eia_sampler_compatibility.py b/tests/test_eia_sampler_compatibility.py new file mode 100644 index 0000000..a415f0f --- /dev/null +++ b/tests/test_eia_sampler_compatibility.py @@ -0,0 +1,101 @@ +import pytest +import xarray as xr +import numpy as np +import pandas as pd +from pathlib import Path +import logging + +logger = logging.getLogger(__name__) + 
@pytest.fixture
def data_path():
    """Path to the preprocessed EIA Zarr store (relative to the repo root)."""
    return "src/open_data_pvnet/data/target_eia_data_processed.zarr"

@pytest.fixture
def config_path():
    """Path to the US PVNet datamodule configuration YAML."""
    return "src/open_data_pvnet/configs/PVNet_configs/datamodule/configuration/us_configuration.yaml"


def test_zarr_structure(data_path):
    """Test basic Zarr structure and format.

    Verifies the dimensions, variables and coordinate dtypes that
    ocf-data-sampler expects, plus a generation<=capacity sanity check.
    """
    if not Path(data_path).exists():
        pytest.skip(f"Data file not found at {data_path}")

    logger.info(f"Testing Zarr structure: {data_path}")

    ds = None
    try:
        ds = xr.open_dataset(data_path, engine="zarr")

        # Check required dimensions
        required_dims = ["ba_id", "datetime_gmt"]
        missing_dims = [d for d in required_dims if d not in ds.dims]
        if missing_dims:
            pytest.fail(f"Missing required dimensions: {missing_dims}")

        # Check required variables
        required_vars = ["generation_mw", "capacity_mwp"]
        missing_vars = [v for v in required_vars if v not in ds.data_vars]
        if missing_vars:
            pytest.fail(f"Missing required variables: {missing_vars}")

        # Check datetime_gmt format
        if "datetime_gmt" in ds.coords:
            dt_coord = ds.coords["datetime_gmt"]
            if not np.issubdtype(dt_coord.dtype, np.datetime64):
                logger.warning(f"datetime_gmt is not datetime64: {dt_coord.dtype}")

        # Check ba_id format
        if "ba_id" in ds.coords:
            ba_coord = ds.coords["ba_id"]
            if not np.issubdtype(ba_coord.dtype, np.integer):
                logger.warning(f"ba_id is not integer: {ba_coord.dtype}")

        # Check data ranges: capacity should bound generation (10% tolerance).
        gen_data = ds["generation_mw"]
        cap_data = ds["capacity_mwp"]
        if (gen_data > cap_data * 1.1).any():
            logger.warning("Some generation values exceed capacity")

    except Exception as e:
        pytest.fail(f"Failed to open/validate Zarr: {e}")
    finally:
        # Close in `finally`: pytest.fail raises an outcome exception that
        # derives from BaseException and escapes `except Exception`, so a
        # success-path-only close would leak the dataset handle on failure.
        if ds is not None:
            ds.close()


def test_ocf_data_sampler_compatibility(data_path, config_path):
    """Test compatibility with ocf-data-sampler.

    Loads the real configuration and Zarr store, then runs
    find_valid_time_periods over them; skips when the optional
    ocf-data-sampler dependency is absent.
    """
    if not Path(data_path).exists():
        pytest.skip(f"Data file not found at {data_path}")

    logger.info(f"Testing ocf-data-sampler compatibility")

    try:
        from ocf_data_sampler.config import load_yaml_configuration

        # Load configuration
        config = load_yaml_configuration(config_path)

        # Load data
        ds = xr.open_dataset(data_path, engine="zarr")

        # Test find_valid_time_periods
        try:
            # Imported here to avoid a top-level failure if the package
            # (or this internal module path) is missing.
            from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
            valid_times = find_valid_time_periods({"gsp": ds}, config)

            logger.info(f"Found {len(valid_times)} valid time periods")

        except ImportError:
            pytest.skip("ocf_data_sampler not installed or internal path changed")
        except Exception as e:
            pytest.fail(f"find_valid_time_periods failed: {e}")
        finally:
            ds.close()

    except ImportError as e:
        pytest.skip(f"ocf-data-sampler not available: {e}")
    except Exception as e:
        pytest.fail(f"ocf-data-sampler compatibility test failed: {e}")
test_upload_file_real(self): + """Test upload_file with dry_run=False""" + mock_client = MagicMock() + result = upload_file( + mock_client, + self.test_file, + "my-bucket", + "key", + dry_run=False + ) + self.assertTrue(result) + mock_client.upload_file.assert_called_once() + + def test_check_bucket_access_success(self): + mock_client = MagicMock() + result = check_bucket_access(mock_client, "my-bucket") + self.assertTrue(result) + mock_client.head_bucket.assert_called_with(Bucket="my-bucket") + + def test_check_bucket_access_dry_run(self): + result = check_bucket_access(None, "my-bucket") # None client implies dry_run in usage + self.assertTrue(result) + + @patch("open_data_pvnet.scripts.upload_eia_to_s3.get_s3_client") + def test_upload_directory(self, mock_get_client): + mock_client = MagicMock() + mock_get_client.return_value = mock_client + + result = upload_directory_to_s3( + self.test_dir, + "my-bucket", + "prefix", + dry_run=False + ) + + self.assertTrue(result) + # Should be called for the one file in temp dir + mock_client.upload_file.assert_called() + +if __name__ == '__main__': + unittest.main()