diff --git a/datasets/goes16_loader.py b/datasets/goes16_loader.py index 51fefd1..563a327 100644 --- a/datasets/goes16_loader.py +++ b/datasets/goes16_loader.py @@ -7,47 +7,57 @@ import torch from torch.utils.data import Dataset +from .preprocessing import handle_nans, crop_center, normalize_channels + class GOES16ProxyDataset(Dataset): """ Streams consecutive frame pairs directly from NOAA's public GOES-16 S3 bucket. + Product: 'ABI-L2-CMIPC' (Cloud & Moisture Imagery - CONUS standard projection) Bands: 'C09' (6.9um Mid-Level Water Vapor) or 'C14' (11.2um Longwave Thermal IR) """ - def __init__(self, product='ABI-L2-CMIPC', band='C09', year=2025, day_of_year=150, hour=14): + def __init__( self, product='ABI-L2-CMIPC', band='C09', year=2025, day_of_year=150, hour=14, + crop_size=512, normalization_method='minmax'): + self.bucket_name = 'noaa-goes16' self.product = product self.band = band + self.crop_size = crop_size + self.normalization_method = normalization_method # Configure anonymous public access to bypass mandatory AWS credential steps self.s3 = boto3.client('s3', region_name='us-east-1', config=Config(signature_version=UNSIGNED)) - # S3 Path syntax structure: Product/Year/Day_of_Year/Hour/ + # S3 path syntax structure: Product/Year/Day_of_Year/Hour/ self.prefix = f"{product}/{year}/{day_of_year:03d}/{hour:02d}/" self.file_list = self._get_s3_file_list() + if len(self.file_list) < 2: + raise ValueError(f"Not enough GOES-16 files found for {self.prefix}") + def _get_s3_file_list(self): response = self.s3.list_objects_v2(Bucket=self.bucket_name, Prefix=self.prefix) if 'Contents' not in response: raise FileNotFoundError(f"No files available matching path: s3://{self.bucket_name}/{self.prefix}") - # Match only files containing our desired band descriptor suffix - files = [obj['Key'] for obj in response['Contents'] if f"M6{self.band}" in obj['Key']] + # Match only files containing our desired band descriptor suffix and NetCDF files. 
def _download_and_parse(self, s3_key):
    """
    Fetch one GOES-16 NetCDF file (if not cached locally) and return its
    'CMI' variable as a preprocessed tensor.

    Args:
        s3_key (str): Object key inside ``self.bucket_name``.

    Returns:
        torch.Tensor: The cleaned, center-cropped and normalized grid with a
        leading channel axis added by ``unsqueeze(0)``.
    """
    # The file is cached in the current working directory under its S3 basename.
    local_filename = os.path.basename(s3_key)
    if not os.path.exists(local_filename):
        print(f"Streaming {local_filename} down to disk runtime...")
        # Download to a temporary name first: an interrupted transfer must not
        # leave a truncated file that the existence check above would reuse.
        partial_filename = f"{local_filename}.part"
        try:
            self.s3.download_file(self.bucket_name, s3_key, partial_filename)
            os.replace(partial_filename, local_filename)
        finally:
            if os.path.exists(partial_filename):
                os.remove(partial_filename)

    with nc.Dataset(local_filename, 'r') as rootgrp:
        raw = rootgrp.variables['CMI'][:]

    # netCDF4 normally hands back a masked array when the variable declares a
    # fill value; a plain np.array() conversion would discard the mask and keep
    # the fill values as if they were measurements. Turning masked cells into
    # NaN lets handle_nans() actually replace them. np.ma.filled() returns a
    # non-masked input unchanged.
    data_matrix = np.ma.filled(raw.astype(np.float32), np.nan)
    data_matrix = np.asarray(data_matrix, dtype=np.float32)

    data_matrix = handle_nans(data_matrix, replacement_val=0.0, method='mean')
    data_matrix = crop_center(data_matrix, crop_h=self.crop_size, crop_w=self.crop_size)
    data_matrix = normalize_channels(data_matrix, method=self.normalization_method)

    return torch.from_numpy(data_matrix).unsqueeze(0)  # Shape: [1, H, W]
class INSAT3DSProxyDataset(Dataset):
    """
    Loader scaffold for INSAT-3DS HDF5 (.h5) datasets.

    Item ``i`` is the pair of files at positions ``i`` and ``i + 1`` of the
    sorted file list, each run through NaN handling, center cropping and
    normalization, plus an all-zero placeholder terrain layer of the same
    spatial size. Calibration and geolocation are still pass-through stubs.
    """

    def __init__(self, data_dir=None, band='WV', year=2024, day_of_year=120, hour=14,
                 crop_size=512, normalization_method='minmax'):
        """
        Args:
            data_dir (str | None): Directory scanned for ``.h5`` / ``.H5`` files.
            band (str): Channel label; only files whose basename contains it are kept.
            year, day_of_year, hour: Stored on the instance; not used for
                file selection by this class.
            crop_size (int): Side length passed to ``crop_center``.
            normalization_method (str): Method passed to ``normalize_channels``.
        """
        self.data_dir = data_dir
        self.band = band
        self.year = year
        self.day_of_year = day_of_year
        self.hour = hour
        self.crop_size = crop_size
        self.normalization_method = normalization_method

        self.file_list = self._get_file_list()

    def _get_file_list(self):
        """Return the sorted, de-duplicated HDF5 paths in ``data_dir`` matching the band."""
        if not self.data_dir or not os.path.isdir(self.data_dir):
            return []

        # Where glob matching is case-insensitive, '*.h5' and '*.H5' return the
        # same file twice; a set keeps every path exactly once so that
        # neighbouring entries are always two different files.
        candidates = set()
        for extension in ('*.h5', '*.H5'):
            candidates.update(glob.glob(os.path.join(self.data_dir, extension)))

        # Use simple band filtering when filenames include channel labels
        return sorted(path for path in candidates if self.band in os.path.basename(path))

    def _read_hdf5_counts(self, file_path):
        """
        Read raw digital counts from an INSAT-3DS HDF5 file as float32.

        This implementation is intentionally minimal until the HDF5 schema is
        known: it reads whichever top-level entry the file lists first.
        """
        with h5py.File(file_path, 'r') as h5f:
            # TODO: validate the correct dataset path for the selected band
            # Example dataset names may include '/IMG_WV', '/IMG_TIR1', '/IMG_IR1', etc.
            dataset_name = next(iter(h5f.keys()))
            counts = np.array(h5f[dataset_name], dtype=np.float32)
        return counts

    def _calibrate_counts(self, counts, band):
        """
        Convert raw counts into a physical brightness quantity.

        Currently a pass-through: ``counts`` is returned unchanged and ``band``
        is ignored.
        """
        # TODO: scientific validation required for actual INSAT-3DS coefficients
        return counts

    def _register_geolocation(self, counts):
        """Placeholder for geolocation extraction; returns ``counts`` unchanged."""
        return counts

    def _load_frame(self, file_path):
        """Read one file and return it as a preprocessed tensor with a leading channel axis."""
        counts = self._read_hdf5_counts(file_path)
        calibrated = self._calibrate_counts(counts, self.band)
        calibrated = handle_nans(calibrated, method='mean')
        frame = crop_center(calibrated, crop_h=self.crop_size, crop_w=self.crop_size)
        frame = normalize_channels(frame, method=self.normalization_method)
        return torch.from_numpy(frame).unsqueeze(0)

    def __len__(self):
        """Number of consecutive file pairs (one fewer than the number of files, never negative)."""
        return max(0, len(self.file_list) - 1)

    def __getitem__(self, idx):
        """
        Return ``(frame_t, frame_t_next, mock_dem)`` for pair ``idx``.

        Raises:
            ValueError: Fewer than two files are available.
            IndexError: ``idx`` is outside ``[-len(self), len(self))``.
        """
        if len(self.file_list) < 2:
            raise ValueError("INSAT-3DS loader requires at least two sequential files.")

        pair_count = len(self)
        if not -pair_count <= idx < pair_count:
            raise IndexError(f"pair index {idx} out of range for {pair_count} frame pairs")
        if idx < 0:
            # Plain list indexing would pair file_list[-1] with file_list[0];
            # map negative indices onto valid pair positions instead.
            idx += pair_count

        frame_t = self._load_frame(self.file_list[idx])
        frame_t_next = self._load_frame(self.file_list[idx + 1])

        # All-zero stand-in for a terrain layer, matching the frame's spatial size.
        _, h, w = frame_t.shape
        mock_dem = torch.zeros((1, h, w), dtype=torch.float32)

        return frame_t, frame_t_next, mock_dem
def handle_nans(data_matrix, replacement_val=0.0, method="zero"):
    """
    Replace NaN entries in an array or tensor.

    Args:
        data_matrix (np.ndarray|torch.Tensor): Input data grid. Non-tensor
            input is copied and cast to float32.
        replacement_val (float): Value used by the "zero" method, and the
            fallback for "mean"/"median" when no usable statistic exists
            (for example when every entry is NaN).
        method (str): Replacement method: "zero", "mean", or "median".

    Returns:
        same type as data_matrix: sanitized output. Tensors come back on the
        device of the input.

    Raises:
        ValueError: ``method`` is not one of the supported names.
    """
    is_tensor = isinstance(data_matrix, torch.Tensor)
    if is_tensor:
        # detach() is required before numpy() for tensors that track gradients.
        data = data_matrix.detach().cpu().numpy()
    else:
        data = np.array(data_matrix, dtype=np.float32)

    if method == "zero":
        fill_value = replacement_val
    elif method in ("mean", "median"):
        # Compute the statistic over the non-NaN entries only; doing it this
        # way avoids the RuntimeWarning np.nanmean/np.nanmedian emit on an
        # all-NaN grid and lets us fall back instead of filling NaN with NaN.
        valid = data[~np.isnan(data)]
        fill_value = replacement_val
        if valid.size:
            statistic = np.mean(valid) if method == "mean" else np.median(valid)
            if np.isfinite(statistic):
                fill_value = statistic
    else:
        raise ValueError(f"Unsupported NaN replacement method: {method}")

    clean = np.nan_to_num(data, nan=fill_value)
    return torch.from_numpy(clean).to(data_matrix.device) if is_tensor else clean


def _edge_pad(data, pad_h, pad_w):
    """Pad the last two axes symmetrically by repeating the border values."""
    h, w = data.shape[-2:]
    if isinstance(data, torch.Tensor):
        # Clamped index ranges repeat the first/last row and column, which is
        # the tensor equivalent of np.pad(..., mode="edge").
        rows = torch.arange(-pad_h, h + pad_h, device=data.device).clamp(0, h - 1)
        cols = torch.arange(-pad_w, w + pad_w, device=data.device).clamp(0, w - 1)
        return data.index_select(-2, rows).index_select(-1, cols)

    leading = [(0, 0)] * (len(data.shape) - 2)
    return np.pad(data, pad_width=leading + [(pad_h, pad_h), (pad_w, pad_w)], mode="edge")


def crop_center(data_matrix, crop_h=512, crop_w=512):
    """
    Crop a central region from the last two axes of an array or tensor.

    If the requested crop is larger than the source image, the image is first
    padded symmetrically with edge values. Leading axes (if any) are kept
    untouched, so both ``(H, W)`` and ``(..., H, W)`` inputs are accepted.

    Args:
        data_matrix (np.ndarray|torch.Tensor): Input with at least two axes.
        crop_h (int): Output height.
        crop_w (int): Output width.

    Returns:
        same type as data_matrix: region whose last two axes are ``(crop_h, crop_w)``.

    Raises:
        ValueError: The crop size is not positive, the input has fewer than
            two axes, or padding is needed but the input has an empty axis.
    """
    if crop_h < 1 or crop_w < 1:
        raise ValueError(f"Crop size must be positive, got ({crop_h}, {crop_w})")
    if len(data_matrix.shape) < 2:
        raise ValueError("crop_center expects an input with at least two axes")

    h, w = data_matrix.shape[-2:]
    pad_h = max(0, (crop_h - h + 1) // 2)
    pad_w = max(0, (crop_w - w + 1) // 2)
    if pad_h or pad_w:
        if h == 0 or w == 0:
            raise ValueError("Cannot edge-pad an input with an empty spatial axis")
        data_matrix = _edge_pad(data_matrix, pad_h, pad_w)
        h, w = data_matrix.shape[-2:]

    start_h = max(0, h // 2 - crop_h // 2)
    start_w = max(0, w // 2 - crop_w // 2)
    return data_matrix[..., start_h:start_h + crop_h, start_w:start_w + crop_w]


def normalize_channels(image_tensor, method="minmax", clip_percentile=(1, 99)):
    """
    Normalize an image using a standard scaling method.

    Statistics are computed over the whole input while ignoring NaN entries.

    Args:
        image_tensor (np.ndarray|torch.Tensor): Image grid.
        method (str): One of ["minmax", "zscore", "percentile", "none"].
        clip_percentile (tuple): Percentiles used for percentile normalization.

    Returns:
        np.ndarray or torch.Tensor: float32 result in the same container type
        as the input; tensors come back on the device of the input. An input
        with no non-NaN entries (including an empty one) is returned unscaled.

    Raises:
        ValueError: ``method`` is not one of the supported names.
    """
    if method not in ("none", "minmax", "zscore", "percentile"):
        raise ValueError(f"Unsupported normalization method: {method}")

    is_tensor = isinstance(image_tensor, torch.Tensor)
    if is_tensor:
        # detach() is required before numpy() for tensors that track gradients.
        data = image_tensor.detach().cpu().numpy().astype(np.float32)
    else:
        data = np.array(image_tensor, dtype=np.float32)

    has_valid_values = bool((~np.isnan(data)).any())
    if method == "none" or not has_valid_values:
        # Nothing to scale against: the nan-aware reductions below would warn
        # on all-NaN input and raise on empty input.
        normalized = data
    elif method == "minmax":
        min_val = np.nanmin(data)
        max_val = np.nanmax(data)
        # A constant image has zero span; divide by 1 so it maps to all zeros.
        span = max_val - min_val if max_val > min_val else 1.0
        normalized = (data - min_val) / span
    elif method == "zscore":
        mu = np.nanmean(data)
        sigma = np.nanstd(data)
        normalized = (data - mu) / (sigma if sigma > 0 else 1.0)
    else:  # "percentile"
        low, high = np.nanpercentile(data, clip_percentile)
        clipped = np.clip(data, low, high)
        normalized = (clipped - low) / max(high - low, 1e-6)

    return torch.from_numpy(normalized).to(image_tensor.device) if is_tensor else normalized
torch.tensor(3.0, dtype=torch.float32) + assert torch.allclose(result, torch.tensor([[1.0, expected_mean], [3.0, 5.0]])) + + +def test_handle_nans_replaces_nan_with_median_numpy(): + data = np.array([[2.0, np.nan], [5.0, 1.0]], dtype=np.float32) + result = handle_nans(data, method='median') + + expected = np.array([[2.0, 2.0], [5.0, 1.0]], dtype=np.float32) + assert np.array_equal(result, expected) + + +def test_handle_nans_unsupported_method_raises(): + data = np.array([[1.0, np.nan]], dtype=np.float32) + with pytest.raises(ValueError, match='Unsupported NaN replacement method'): + handle_nans(data, method='invalid') + + +def test_crop_center_returns_center_section(): + data = np.arange(16, dtype=np.float32).reshape(4, 4) + result = crop_center(data, crop_h=2, crop_w=2) + + expected = np.array([[5.0, 6.0], [9.0, 10.0]], dtype=np.float32) + assert result.shape == (2, 2) + assert np.array_equal(result, expected) + + +def test_crop_center_pads_when_crop_larger_than_input(): + data = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) + result = crop_center(data, crop_h=4, crop_w=4) + + assert result.shape == (4, 4) + assert np.all(result[0, :] == result[1, :]) + assert np.all(result[-1, :] == result[-2, :]) + assert np.all(result[:, 0] == result[:, 1]) + assert np.all(result[:, -1] == result[:, -2]) + + +def test_normalize_channels_minmax(): + data = np.array([[2.0, 4.0], [6.0, 8.0]], dtype=np.float32) + result = normalize_channels(data, method='minmax') + + expected = np.array([[0.0, 0.33333334], [0.6666667, 1.0]], dtype=np.float32) + assert np.allclose(result, expected) + + +def test_normalize_channels_zscore_tensor(): + data = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32) + result = normalize_channels(data, method='zscore') + + assert isinstance(result, torch.Tensor) + assert torch.allclose(result.mean(), torch.tensor(0.0), atol=1e-6) + assert torch.allclose(result.std(unbiased=False), torch.tensor(1.0), atol=1e-6) + + +def 
test_normalize_channels_percentile_clips_and_scales(): + data = np.array([[0.0, 1.0], [100.0, 200.0]], dtype=np.float32) + result = normalize_channels(data, method='percentile', clip_percentile=(0, 50)) + + assert result.min() == 0.0 + assert result.max() == 1.0 + assert result[1, 1] == 1.0 + + +def test_normalize_channels_none_returns_same_array(): + data = np.array([[5.0, 10.0]], dtype=np.float32) + result = normalize_channels(data, method='none') + + assert np.array_equal(result, data)