Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions datasets/goes16_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,47 +7,57 @@
import torch
from torch.utils.data import Dataset

from .preprocessing import handle_nans, crop_center, normalize_channels

class GOES16ProxyDataset(Dataset):
"""
Streams consecutive frame pairs directly from NOAA's public GOES-16 S3 bucket.

Product: 'ABI-L2-CMIPC' (Cloud & Moisture Imagery - CONUS standard projection)
Bands: 'C09' (6.9um Mid-Level Water Vapor) or 'C14' (11.2um Longwave Thermal IR)
"""
def __init__(self, product='ABI-L2-CMIPC', band='C09', year=2025, day_of_year=150, hour=14):
def __init__( self, product='ABI-L2-CMIPC', band='C09', year=2025, day_of_year=150, hour=14,
crop_size=512, normalization_method='minmax'):

self.bucket_name = 'noaa-goes16'
self.product = product
self.band = band
self.crop_size = crop_size
self.normalization_method = normalization_method

# Configure anonymous public access to bypass mandatory AWS credential steps
self.s3 = boto3.client('s3', region_name='us-east-1',
config=Config(signature_version=UNSIGNED))

# S3 Path syntax structure: Product/Year/Day_of_Year/Hour/
# S3 path syntax structure: Product/Year/Day_of_Year/Hour/
self.prefix = f"{product}/{year}/{day_of_year:03d}/{hour:02d}/"
self.file_list = self._get_s3_file_list()

if len(self.file_list) < 2:
raise ValueError(f"Not enough GOES-16 files found for {self.prefix}")

def _get_s3_file_list(self):
response = self.s3.list_objects_v2(Bucket=self.bucket_name, Prefix=self.prefix)
if 'Contents' not in response:
raise FileNotFoundError(f"No files available matching path: s3://{self.bucket_name}/{self.prefix}")

# Match only files containing our desired band descriptor suffix
files = [obj['Key'] for obj in response['Contents'] if f"M6{self.band}" in obj['Key']]
# Match only files containing our desired band descriptor suffix and NetCDF files.
files = [obj['Key'] for obj in response['Contents']
if f"M6{self.band}" in obj['Key'] and obj['Key'].endswith('.nc')]
return sorted(files)

def _download_and_parse(self, s3_key):
local_filename = s3_key.split('/')[-1]
local_filename = os.path.basename(s3_key)
if not os.path.exists(local_filename):
print(f"Streaming {local_filename} down to disk runtime...")
self.s3.download_file(self.bucket_name, s3_key, local_filename)

# Parse matrix metadata
with nc.Dataset(local_filename, 'r') as rootgrp:
data_matrix = np.array(rootgrp.variables['CMI'][:], dtype=np.float32)
data_matrix = np.nan_to_num(data_matrix, nan=0.0)

h, w = data_matrix.shape
data_matrix = data_matrix[h//2-256 : h//2+256, w//2-256 : w//2+256]
data_matrix = handle_nans(data_matrix, replacement_val=0.0, method='mean')
data_matrix = crop_center(data_matrix, crop_h=self.crop_size, crop_w=self.crop_size)
data_matrix = normalize_channels(data_matrix, method=self.normalization_method)

return torch.from_numpy(data_matrix).unsqueeze(0) # Shape: [1, H, W]

Expand All @@ -58,7 +68,6 @@ def __getitem__(self, idx):
frame_t = self._download_and_parse(self.file_list[idx])
frame_t_next = self._download_and_parse(self.file_list[idx + 1])

# Terrain Context layer (DEM mapping simulation) matching image geometry
_, h, w = frame_t.shape
mock_dem = torch.zeros((1, h, w), dtype=torch.float32)

Expand Down
92 changes: 68 additions & 24 deletions datasets/insat3ds_loader.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,100 @@
import glob
import os
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset

from .preprocessing import handle_nans, crop_center, normalize_channels

class INSAT3DSProxyDataset(Dataset):
"""
Placeholder loader for INSAT-3DS HDF5 (.h5) datasets.
Loader scaffold for INSAT-3DS HDF5 (.h5) datasets.

This class outlines the required API signatures, calibration parameters, and
metadata mappings for IMAGER payload channels.
"""
def __init__(self, data_dir=None, band='WV', year=2024, day_of_year=120, hour=14):
# TODO: scientific validation required
# Deferred from audit phase: INSAT-3DS HDF5 pipeline construction
def __init__(self, data_dir=None, band='WV', year=2024, day_of_year=120, hour=14, crop_size=512, normalization_method='minmax'):
self.data_dir = data_dir
self.band = band
self.year = year
self.day_of_year = day_of_year
self.hour = hour

# In a real implementation, scan the data_dir for matching H5 files
self.file_list = []

self.crop_size = crop_size
self.normalization_method = normalization_method

self.file_list = self._get_file_list()

def _get_file_list(self):
if not self.data_dir or not os.path.isdir(self.data_dir):
return []

patterns = [os.path.join(self.data_dir, '*.h5'), os.path.join(self.data_dir, '*.H5')]
files = []
for pattern in patterns:
files.extend(glob.glob(pattern))

# Use simple band filtering when filenames include channel labels
filtered = [f for f in files if self.band in os.path.basename(f)]
return sorted(filtered)

def _read_hdf5_counts(self, file_path):
"""
Reads raw digital counts from INSAT-3DS HDF5 datasets.
Reads raw digital counts from an INSAT-3DS HDF5 file.

This implementation is intentionally minimal until the HDF5 schema is known.
"""
# TODO: scientific validation required
# Placeholder for h5py loading and group/dataset indexing (e.g., '/IMG_WV' or '/IMG_TIR1')
raise NotImplementedError("INSAT-3DS HDF5 file reader not fully implemented yet.")
with h5py.File(file_path, 'r') as h5f:
# TODO: validate the correct dataset path for the selected band
# Example dataset names may include '/IMG_WV', '/IMG_TIR1', '/IMG_IR1', etc.
dataset_name = next(iter(h5f.keys()))
counts = np.array(h5f[dataset_name], dtype=np.float32)
return counts

def _calibrate_counts(self, counts, band):
"""
Converts raw counts (Digital Numbers) into Brightness Temperature (BT)
for thermal/water vapor channels or Albedo for visible channels.
Convert raw counts into a physical brightness quantity.

Calibration equations will depend on the channel payload metadata.
"""
# TODO: scientific validation required
# Calibration equations require slope/intercept values stored in HDF5 metadata attributes
# TODO: scientific validation required for actual INSAT-3DS coefficients
return counts

def _register_geolocation(self, counts):
"""
Extracts geolocation coordinate lookup parameters to project pixel velocities
correctly onto latitude/longitude grids centered at 74E.
Placeholder for geolocation extraction and coordinate mapping.
"""
# TODO: scientific validation required
return counts

def __len__(self):
# TODO: scientific validation required
return 0
return max(0, len(self.file_list) - 1)

def __getitem__(self, idx):
# TODO: scientific validation required
# Should return frame_t, frame_t_next, and topography DEM context
raise NotImplementedError("INSAT-3DS loader dataset indexing is not active.")
if len(self.file_list) < 2:
raise ValueError("INSAT-3DS loader requires at least two sequential files.")

file_t = self.file_list[idx]
file_t_next = self.file_list[idx + 1]

counts_t = self._read_hdf5_counts(file_t)
counts_t_next = self._read_hdf5_counts(file_t_next)

calibrated_t = self._calibrate_counts(counts_t, self.band)
calibrated_t_next = self._calibrate_counts(counts_t_next, self.band)

calibrated_t = handle_nans(calibrated_t, method='mean')
calibrated_t_next = handle_nans(calibrated_t_next, method='mean')

frame_t = crop_center(calibrated_t, crop_h=self.crop_size, crop_w=self.crop_size)
frame_t_next = crop_center(calibrated_t_next, crop_h=self.crop_size, crop_w=self.crop_size)

frame_t = normalize_channels(frame_t, method=self.normalization_method)
frame_t_next = normalize_channels(frame_t_next, method=self.normalization_method)

frame_t = torch.from_numpy(frame_t).unsqueeze(0)
frame_t_next = torch.from_numpy(frame_t_next).unsqueeze(0)

_, h, w = frame_t.shape
mock_dem = torch.zeros((1, h, w), dtype=torch.float32)

return frame_t, frame_t_next, mock_dem
97 changes: 85 additions & 12 deletions datasets/preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,98 @@
import numpy as np
import torch

def handle_nans(data_matrix, replacement_val=0.0):

def handle_nans(data_matrix, replacement_val=0.0, method="zero"):
"""
Replaces NaN entries in numerical grids.
Replace invalid values in a 2D array or tensor.

Args:
data_matrix (np.ndarray|torch.Tensor): Input data grid.
replacement_val (float): Fallback value for missing entries.
method (str): Replacement method: "zero", "mean", or "median".

Returns:
same type as data_matrix: sanitized output.
"""
return np.nan_to_num(data_matrix, nan=replacement_val)
if isinstance(data_matrix, torch.Tensor):
data = data_matrix.cpu().numpy()
is_tensor = True
else:
data = np.array(data_matrix, dtype=np.float32)
is_tensor = False

if method == "zero":
clean = np.nan_to_num(data, nan=replacement_val)
elif method == "mean":
mean_val = np.nanmean(data)
clean = np.nan_to_num(data, nan=mean_val)
elif method == "median":
median_val = np.nanmedian(data)
clean = np.nan_to_num(data, nan=median_val)
else:
raise ValueError(f"Unsupported NaN replacement method: {method}")

return torch.from_numpy(clean) if is_tensor else clean


def crop_center(data_matrix, crop_h=512, crop_w=512):
"""
Crops a central region of specified height and width from a 2D matrix.
Crops a central region from a 2D array.

If the requested crop is larger than the source image, the function
pads the image symmetrically with edge values.
"""
h, w = data_matrix.shape
start_h = h // 2 - crop_h // 2
start_w = w // 2 - crop_w // 2
return data_matrix[start_h : start_h + crop_h, start_w : start_w + crop_w]
if crop_h > h or crop_w > w:
pad_h = max(0, (crop_h - h + 1) // 2)
pad_w = max(0, (crop_w - w + 1) // 2)
data_matrix = np.pad(
data_matrix,
pad_width=((pad_h, pad_h), (pad_w, pad_w)),
mode="edge"
)
h, w = data_matrix.shape

def normalize_channels(image_tensor, method="minmax"):
start_h = max(0, h // 2 - crop_h // 2)
start_w = max(0, w // 2 - crop_w // 2)
return data_matrix[start_h:start_h + crop_h, start_w:start_w + crop_w]


def normalize_channels(image_tensor, method="minmax", clip_percentile=(1, 99)):
"""
Placeholders for data scale standardization.
Normalize a 2D image using a standard scaling method.

Args:
image_tensor (np.ndarray|torch.Tensor): 2D image grid.
method (str): One of ["minmax", "zscore", "percentile", "none"].
clip_percentile (tuple): Percentiles used for percentile normalization.

Returns:
np.ndarray or torch.Tensor: normalized image in the same type.
"""
# TODO: scientific validation required
# Deferred from audit phase: input normalization needs physical limits definition
return image_tensor
if isinstance(image_tensor, torch.Tensor):
data = image_tensor.cpu().numpy().astype(np.float32)
is_tensor = True
else:
data = np.array(image_tensor, dtype=np.float32)
is_tensor = False

if method == "none":
normalized = data
elif method == "minmax":
min_val = np.nanmin(data)
max_val = np.nanmax(data)
span = max_val - min_val if max_val > min_val else 1.0
normalized = (data - min_val) / span
elif method == "zscore":
mu = np.nanmean(data)
sigma = np.nanstd(data)
normalized = (data - mu) / (sigma if sigma > 0 else 1.0)
elif method == "percentile":
low, high = np.nanpercentile(data, clip_percentile)
clipped = np.clip(data, low, high)
normalized = (clipped - low) / max(high - low, 1e-6)
else:
raise ValueError(f"Unsupported normalization method: {method}")

return torch.from_numpy(normalized) if is_tensor else normalized
Loading