From 1fa714bd00e727a8de46f819b3244f10c48d4b32 Mon Sep 17 00:00:00 2001 From: ghjuliasialelli Date: Thu, 19 Feb 2026 18:24:48 +0100 Subject: [PATCH 01/14] integration of agbd-lite dataset --- configs/dataset/agbdlite-seg.yaml | 86 +++++++++ configs/dataset/agbdlite.yaml | 59 +++++++ configs/preprocessing/reg_resize.yaml | 23 +++ environment.yaml | 1 + pangaea/datasets/agbdlite-seg.py | 243 ++++++++++++++++++++++++++ pangaea/datasets/agbdlite.py | 232 ++++++++++++++++++++++++ pangaea/engine/evaluator.py | 6 + pangaea/engine/trainer.py | 13 ++ requirements.txt | 2 +- 9 files changed, 664 insertions(+), 1 deletion(-) create mode 100644 configs/dataset/agbdlite-seg.yaml create mode 100644 configs/dataset/agbdlite.yaml create mode 100644 configs/preprocessing/reg_resize.yaml create mode 100644 pangaea/datasets/agbdlite-seg.py create mode 100644 pangaea/datasets/agbdlite.py diff --git a/configs/dataset/agbdlite-seg.yaml b/configs/dataset/agbdlite-seg.yaml new file mode 100644 index 00000000..afb75f30 --- /dev/null +++ b/configs/dataset/agbdlite-seg.yaml @@ -0,0 +1,86 @@ +_target_: pangaea.datasets.agbdlite-seg.AGBDLite +dataset_name: AGBDLite +root_path: ./data/AGBDLite +download_url: "https://zenodo.org/api/records/18485030" +auto_download: True + +img_size: 25 +multi_temporal: False +multi_modal: True + +eval_big: False # whether to evaluate on AGBD-test instead of AGBD-Lite-test +lite_chunk_size: 1 # should be 32, once we can load in batches + +# for classification (biome) +target: "biome" +ignore_index: -1 +num_classes: 14 +classes: + - Shrubs + - Herbaceous vegetation + - Cultivated + - Herbaceous wetland + - Closed-ENL + - Closed-EBL + - Closed-DBL + - Closed-mixed + - Closed-other + - Open-ENL + - Open-EBL + - Open-DBL + - Open-mixed + - Open-other +distribution: + - 0.07687005878775593 + - 0.16642475842962362 + - 0.12770322319075614 + - 0.010562875869991216 + - 0.15101560916278126 + - 0.0937414690181769 + - 0.06032569768227583 + - 0.025453071153456314 + - 0.07738090411514291 + - 0.01700385161159538 + - 0.003516453814446922 + - 0.014763159672950875 + - 0.009043854314480708 + - 0.166195013176566 + + +# features +bands: + optical: + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B8 + - B8A + - B09 + - B11 + - B12 + sar: + - HH + - HV + + +# statistics + +data_mean: + optical: [0.1235409, 0.13345453, 0.15880638, 0.15243885, 0.20159355, 0.32084775, 0.35937154, 0.37473023, 0.38475886, 0.3847485, 0.28990296, 0.21002403] + sar: [-10.24998, -16.50024] + +data_std: + optical: [0.023037845, 0.02647069, 0.030844089, 0.03733916, 0.041663107, 0.06969897, 0.08277564, 0.090343215, 0.086930655, 0.08083575, 0.075588495, 0.058553487] + sar: [8.61129, 8.732289] + +data_min: + optical: [1e-04, 1e-04, 0.0283, 0.0218, 0.0795, 0.0695, 0.0794, 0.0524, 0.0846, 0.0037, 0.0974, 0.0977] + sar: [-83.0, -83.0] + +data_max: + optical: [1.4194, 2.1088, 1.9552, 1.856, 1.6232, 1.6055, 1.586, 1.7232, 1.593, 1.7081, 1.61, 1.6255] + sar: [13.297462, 11.688309] \ No newline at end of file diff --git a/configs/dataset/agbdlite.yaml b/configs/dataset/agbdlite.yaml new file mode 100644 index 00000000..973c5ee8 --- /dev/null +++ b/configs/dataset/agbdlite.yaml @@ -0,0 +1,59 @@ +_target_: pangaea.datasets.agbdlite.AGBDLite +dataset_name: AGBDLite +root_path: ./data/AGBDLite +download_url: "https://zenodo.org/api/records/18485030" +auto_download: True + +img_size: 25 +multi_temporal: False +multi_modal: True + +eval_big: False # whether to evaluate on AGBD-test instead of AGBD-Lite-test +lite_chunk_size: 1 # should be 32, once we can load in batches + +# for regression (agbd or rh98) +target: "agbd" # agbd or rh98 +ignore_index: -1 +num_classes: 1 +classes: + - regression +distribution: + - 1. + +# features +bands: + optical: + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B8 + - B8A + - B09 + - B11 + - B12 + sar: + - HH + - HV + + +# statistics + +data_mean: + optical: [0.1235409, 0.13345453, 0.15880638, 0.15243885, 0.20159355, 0.32084775, 0.35937154, 0.37473023, 0.38475886, 0.3847485, 0.28990296, 0.21002403] + sar: [-10.24998, -16.50024] + +data_std: + optical: [0.023037845, 0.02647069, 0.030844089, 0.03733916, 0.041663107, 0.06969897, 0.08277564, 0.090343215, 0.086930655, 0.08083575, 0.075588495, 0.058553487] + sar: [8.61129, 8.732289] + +data_min: + optical: [1e-04, 1e-04, 0.0283, 0.0218, 0.0795, 0.0695, 0.0794, 0.0524, 0.0846, 0.0037, 0.0974, 0.0977] + sar: [-83.0, -83.0] + +data_max: + optical: [1.4194, 2.1088, 1.9552, 1.856, 1.6232, 1.6055, 1.586, 1.7232, 1.593, 1.7081, 1.61, 1.6255] + sar: [13.297462, 11.688309] \ No newline at end of file diff --git a/configs/preprocessing/reg_resize.yaml b/configs/preprocessing/reg_resize.yaml new file mode 100644 index 00000000..f759b434 --- /dev/null +++ b/configs/preprocessing/reg_resize.yaml @@ -0,0 +1,23 @@ +train: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.ResizeToEncoder + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMinMax + - _target_: pangaea.engine.data_preprocessor.BandPadding + +val: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.ResizeToEncoder + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMinMax + - _target_: pangaea.engine.data_preprocessor.BandPadding + +test: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.ResizeToEncoder + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMinMax + - _target_: pangaea.engine.data_preprocessor.BandPadding diff --git a/environment.yaml b/environment.yaml index 62524cbb..25c3811e 100644 --- a/environment.yaml +++ b/environment.yaml @@ -31,4 +31,5 @@ dependencies: - yacs - wandb - hydra-core + - h5py diff --git a/pangaea/datasets/agbdlite-seg.py b/pangaea/datasets/agbdlite-seg.py new file mode 100644 index 00000000..1f510055 --- /dev/null +++ b/pangaea/datasets/agbdlite-seg.py @@ -0,0 +1,243 @@ +import torch +from pangaea.datasets.base import RawGeoFMDataset + +import numpy as np +import h5py +from os.path import join, isfile +from tqdm import tqdm +import requests +import pathlib + +REF_BIOMES = { + 20: 'Shrubs', 30: 'Herbaceous vegetation', 40: 'Cultivated', 90: 'Herbaceous wetland', + 111: 'Closed-ENL', 112: 'Closed-EBL', 114: 'Closed-DBL', 115: 'Closed-mixed', + 116: 'Closed-other', 121: 'Open-ENL', 122: 'Open-EBL', 124: 'Open-DBL', + 125: 'Open-mixed', 126: 'Open-other' +} + +def download_from_zenodo(api_url, target_filenames, output_dir): + """ + This function downloads specific files from a Zenodo record. It first retrieves + the metadata for the record to find the download URLs for the specified files, + and then downloads only those files. + + Args: + - record_id (str): The Zenodo record ID (e.g., "18485030"). + - target_filenames (list of str): List of filenames to download from the record. + - output_dir (str): Directory where the downloaded files should be saved. + + Returns: + - None: The function saves the files to the specified output directory. + """ + + # Fetch metadata to find the specific download URLs + response = requests.get(api_url) + response.raise_for_status() + files_in_record = response.json().get('files', []) + to_download = {f['key']: f['links']['self'] for f in files_in_record if f['key'] in target_filenames} + if not to_download: + print("None of the specified files were found in this record.") + return + + # Download the files + for filename, download_url in to_download.items(): + file_path = join(output_dir, filename) + with requests.get(download_url, stream=True) as r: + r.raise_for_status() + total_size = int(r.headers.get('content-length', 0)) + with open(file_path, 'wb') as f, tqdm( + total=total_size, unit='B', unit_scale=True, desc=filename + ) as pbar: + for chunk in r.iter_content(chunk_size=1024 * 1024): # 1MB chunks + _ = f.write(chunk) + _ = pbar.update(len(chunk)) + print(f"\nDone. Files saved to '{output_dir}'") + + +class AGBDLite(RawGeoFMDataset): + def __init__( + self, + split: str, + dataset_name: str, + multi_modal: bool, + multi_temporal: int, + root_path: str, + classes: list, + num_classes: int, + ignore_index: int, + img_size: int, + bands: dict[str, list[str]], + distribution: list[int], + data_mean: dict[str, list[str]], + data_std: dict[str, list[str]], + data_min: dict[str, list[str]], + data_max: dict[str, list[str]], + download_url: str, + auto_download: bool, + target: str, + eval_big: bool, + lite_chunk_size: int + ): + super(AGBDLite, self).__init__( + split=split, + dataset_name=dataset_name, + multi_modal=multi_modal, + multi_temporal=multi_temporal, + root_path=root_path, + classes=classes, + num_classes=num_classes, + ignore_index=ignore_index, + img_size=img_size, + bands=bands, + distribution=distribution, + data_mean=data_mean, + data_std=data_std, + data_min=data_min, + data_max=data_max, + download_url=download_url, + auto_download=auto_download, + ) + + assert split in ['train', 'val', 'test'], "split must be one of 'train', 'val', or 'test'" + self.mode = split + self.eval_big = eval_big + self.target = target + assert self.target == 'biome', "Only biome classification is currently supported for AGBD-Lite segmentation. Please set target to 'biome'." + self.lite_chunk_size = lite_chunk_size + self.patch_size = img_size + self.zenodo_record = "18485030" + if auto_download: self.download(self) + + self.s2_bands = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12'] + if self.eval_big and self.mode == 'test' : self.fname = 'AGBD-test.h5' + else: self.fname = f'AGBD-Lite-{self.mode}.h5' + self.f_handle = h5py.File(join(self.root_path, self.fname), 'r') + + with h5py.File(join(self.root_path, self.fname), 'r') as f: + gedi_length = len(f['GEDI']['agbd']) + total_length = (gedi_length // self.lite_chunk_size) + (1 if (gedi_length % self.lite_chunk_size != 0) else 0) + self.gedi_length, self.length = gedi_length, total_length + + # Prepare to map the biome classes to 0-13 for Cross Entropy loss + keys = torch.tensor(sorted(REF_BIOMES.keys())) + self.biome_lookup = torch.full((max(keys.max(), 255) + 1,), -1) + self.biome_lookup[keys] = torch.arange(len(keys)) + + def __len__(self): + # Return the total number of samples + return int(self.length) + + def __getitem__(self, n): + """Returns the i-th item of the dataset. + + Args: + i (int): index of the item + + Raises: + NotImplementedError: raise if the method is not implemented + + Returns: + dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary follwing the format + {"image": + { + "optical": torch.Tensor of shape (C T H W) (where T=1 if single-temporal dataset), + "sar": torch.Tensor of shape (C T H W) (where T=1 if single-temporal dataset), + }, + "target": torch.Tensor of shape (H W) of type torch.int64 for segmentation, torch.float for + regression datasets., + "metadata": dict}. + """ + + # Find the file, tile, and row index corresponding to this chunk + idx_start = n * self.lite_chunk_size + idx_end = min(idx_start + self.lite_chunk_size, self.gedi_length) + f = self.f_handle + + # Sentinel-2 bands ------------------------------------------------------------------------ + + # Set the order and indices for the Sentinel-2 bands + if not hasattr(self, 's2_order') : self.s2_order = list(f['S2_bands'].attrs['order']) + if not hasattr(self, 's2_indices') : self.s2_indices = [self.s2_order.index(band) for band in self.s2_bands] + + # Get the bands + s2_bands = f['S2_bands'][idx_start : idx_end, :, :, self.s2_indices].astype(np.float32) + + # Get the BOA offset, if it exists + if 'S2_boa_offset' in f['Sentinel_metadata'].keys() : s2_boa_offset = f['Sentinel_metadata']['S2_boa_offset'][idx_start : idx_end].astype(np.float32) + else: s2_boa_offset = np.full((s2_bands.shape[0],), 0, dtype = np.float32) + s2_boa_offset = s2_boa_offset[:, np.newaxis, np.newaxis, np.newaxis] + + # Get the surface reflectance values + sr_bands = (s2_bands - s2_boa_offset * 1000) / 10000 + sr_bands[s2_bands == 0] = 0 + sr_bands[sr_bands < 0] = 0 + + # SAR bands (from ALOS-PALSAR-2) ---------------------------------------------------------- + + # Set the order for the ALOS bands + if not hasattr(self, 'alos_order') : self.alos_order = f['ALOS_bands'].attrs['order'] + + # Get the bands as gamma naught values + alos_bands = f['ALOS_bands'][idx_start : idx_end, :, :, :].astype(np.float32) + alos_bands = np.where(alos_bands == 0, -9999.0, 10 * np.log10(np.power(alos_bands, 2)) - 83.0) + + # Target data ----------------------------------------------------------------------------- + lc = torch.from_numpy(np.array(f['LC'][idx_start : idx_end, :, :, 0])).long() + target, _ = torch.mode(lc.flatten(start_dim=1), dim = 1) + target = self.biome_lookup[target] + + # TEMPORARY + placeholder_target = torch.full_like(lc, fill_value = self.ignore_index, dtype = torch.float if self.target in ['agbd', 'rh98'] else torch.long) + placeholder_target[:, self.patch_size // 2, self.patch_size // 2] = target + target = placeholder_target + # TEMPORARY + + # Metadata (if needed) -------------------------------------------------------------------- + region = torch.from_numpy(np.array(f['GEDI']['region_cla'][idx_start : idx_end])) + biome = lc[:, self.patch_size // 2, self.patch_size // 2] + + # Convert to tensors and return ----------------------------------------------------------- + sr_bands = torch.from_numpy(sr_bands).float() + sr_bands = sr_bands.permute(0, 3, 1, 2).unsqueeze(2) # Change to (B, C, 1, H, W) + alos_bands = torch.from_numpy(alos_bands).float() + alos_bands = alos_bands.permute(0, 3, 1, 2).unsqueeze(2) # Change to (B, C, 1, H, W) + + # TEMPORARY + sr_bands = sr_bands.squeeze(0) + alos_bands = alos_bands.squeeze(0) + target = target.squeeze(0) + region = region.squeeze(0) + biome = biome.squeeze(0) + + return { + 'image': { + 'optical': sr_bands, + 'sar': alos_bands + }, + 'target': target, + 'metadata': { + 'region': region, + 'biome': biome + } + } + + @staticmethod + def download(self, silent=False): + + root_path = pathlib.Path(self.root_path) + + # Create the root directory if it does not exist + if not root_path.exists(): root_path.mkdir(parents=True, exist_ok=True) + if root_path.exists() : + if self.eval_big and not isfile(join(root_path, 'AGBD-test.h5')) : + print(f"AGBD-Lite test file does not exist at {root_path}. Downloading test file.") + download_from_zenodo(self.download_url, ["AGBD-test.h5"], output_dir=root_path) + if isfile(join(root_path, 'AGBD-Lite-train.h5')) and isfile(join(root_path, 'AGBD-Lite-val.h5')) and isfile(join(root_path, 'AGBD-Lite-test.h5')) : + if not silent: + print(f"AGBD-Lite files exist at {root_path}. Skipping download.") + return + + # Download the files from https://zenodo.org/records/18485030 + fnames = ["AGBD-Lite-train.h5", "AGBD-Lite-val.h5", "AGBD-Lite-test.h5"] + fnames += ["AGBD-test.h5"] if self.eval_big else [] + download_from_zenodo(self.download_url, fnames, output_dir=root_path) \ No newline at end of file diff --git a/pangaea/datasets/agbdlite.py b/pangaea/datasets/agbdlite.py new file mode 100644 index 00000000..4ae26659 --- /dev/null +++ b/pangaea/datasets/agbdlite.py @@ -0,0 +1,232 @@ +import torch +from pangaea.datasets.base import RawGeoFMDataset + +import numpy as np +import h5py +from os.path import join, isfile +from tqdm import tqdm +import requests +import pathlib + +REF_BIOMES = { + 20: 'Shrubs', 30: 'Herbaceous vegetation', 40: 'Cultivated', 90: 'Herbaceous wetland', + 111: 'Closed-ENL', 112: 'Closed-EBL', 114: 'Closed-DBL', 115: 'Closed-mixed', + 116: 'Closed-other', 121: 'Open-ENL', 122: 'Open-EBL', 124: 'Open-DBL', + 125: 'Open-mixed', 126: 'Open-other' +} + +def download_from_zenodo(api_url, target_filenames, output_dir): + """ + This function downloads specific files from a Zenodo record. It first retrieves + the metadata for the record to find the download URLs for the specified files, + and then downloads only those files. + + Args: + - record_id (str): The Zenodo record ID (e.g., "18485030"). + - target_filenames (list of str): List of filenames to download from the record. + - output_dir (str): Directory where the downloaded files should be saved. + + Returns: + - None: The function saves the files to the specified output directory. + """ + + # Fetch metadata to find the specific download URLs + response = requests.get(api_url) + response.raise_for_status() + files_in_record = response.json().get('files', []) + to_download = {f['key']: f['links']['self'] for f in files_in_record if f['key'] in target_filenames} + if not to_download: + print("None of the specified files were found in this record.") + return + + # Download the files + for filename, download_url in to_download.items(): + file_path = join(output_dir, filename) + with requests.get(download_url, stream=True) as r: + r.raise_for_status() + total_size = int(r.headers.get('content-length', 0)) + with open(file_path, 'wb') as f, tqdm( + total=total_size, unit='B', unit_scale=True, desc=filename + ) as pbar: + for chunk in r.iter_content(chunk_size=1024 * 1024): # 1MB chunks + _ = f.write(chunk) + _ = pbar.update(len(chunk)) + print(f"\nDone. Files saved to '{output_dir}'") + + +class AGBDLite(RawGeoFMDataset): + def __init__( + self, + split: str, + dataset_name: str, + multi_modal: bool, + multi_temporal: int, + root_path: str, + classes: list, + num_classes: int, + ignore_index: int, + img_size: int, + bands: dict[str, list[str]], + distribution: list[int], + data_mean: dict[str, list[str]], + data_std: dict[str, list[str]], + data_min: dict[str, list[str]], + data_max: dict[str, list[str]], + download_url: str, + auto_download: bool, + target: str, + eval_big: bool, + lite_chunk_size: int + ): + super(AGBDLite, self).__init__( + split=split, + dataset_name=dataset_name, + multi_modal=multi_modal, + multi_temporal=multi_temporal, + root_path=root_path, + classes=classes, + num_classes=num_classes, + ignore_index=ignore_index, + img_size=img_size, + bands=bands, + distribution=distribution, + data_mean=data_mean, + data_std=data_std, + data_min=data_min, + data_max=data_max, + download_url=download_url, + auto_download=auto_download, + ) + + assert split in ['train', 'val', 'test'], "split must be one of 'train', 'val', or 'test'" + self.mode = split + self.eval_big = eval_big + self.target = target + self.lite_chunk_size = lite_chunk_size + self.patch_size = img_size + self.zenodo_record = "18485030" + if auto_download: self.download(self) + + self.s2_bands = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12'] + if self.eval_big and self.mode == 'test' : self.fname = 'AGBD-test.h5' + else: self.fname = f'AGBD-Lite-{self.mode}.h5' + self.f_handle = h5py.File(join(self.root_path, self.fname), 'r') + + with h5py.File(join(self.root_path, self.fname), 'r') as f: + gedi_length = len(f['GEDI']['agbd']) + total_length = (gedi_length // self.lite_chunk_size) + (1 if (gedi_length % self.lite_chunk_size != 0) else 0) + self.gedi_length, self.length = gedi_length, total_length + + def __len__(self): + # Return the total number of samples + return int(self.length) + + def __getitem__(self, n): + """Returns the i-th item of the dataset. + + Args: + i (int): index of the item + + Raises: + NotImplementedError: raise if the method is not implemented + + Returns: + dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary follwing the format + {"image": + { + "optical": torch.Tensor of shape (C T H W) (where T=1 if single-temporal dataset), + "sar": torch.Tensor of shape (C T H W) (where T=1 if single-temporal dataset), + }, + "target": torch.Tensor of shape (H W) of type torch.int64 for segmentation, torch.float for + regression datasets., + "metadata": dict}. + """ + + # Find the file, tile, and row index corresponding to this chunk + idx_start = n * self.lite_chunk_size + idx_end = min(idx_start + self.lite_chunk_size, self.gedi_length) + f = self.f_handle + + # Sentinel-2 bands ------------------------------------------------------------------------ + + # Set the order and indices for the Sentinel-2 bands + if not hasattr(self, 's2_order') : self.s2_order = list(f['S2_bands'].attrs['order']) + if not hasattr(self, 's2_indices') : self.s2_indices = [self.s2_order.index(band) for band in self.s2_bands] + + # Get the bands + s2_bands = f['S2_bands'][idx_start : idx_end, :, :, self.s2_indices].astype(np.float32) + + # Get the BOA offset, if it exists + if 'S2_boa_offset' in f['Sentinel_metadata'].keys() : s2_boa_offset = f['Sentinel_metadata']['S2_boa_offset'][idx_start : idx_end].astype(np.float32) + else: s2_boa_offset = np.full((s2_bands.shape[0],), 0, dtype = np.float32) + s2_boa_offset = s2_boa_offset[:, np.newaxis, np.newaxis, np.newaxis] + + # Get the surface reflectance values + sr_bands = (s2_bands - s2_boa_offset * 1000) / 10000 + sr_bands[s2_bands == 0] = 0 + sr_bands[sr_bands < 0] = 0 + + # SAR bands (from ALOS-PALSAR-2) ---------------------------------------------------------- + + # Set the order for the ALOS bands + if not hasattr(self, 'alos_order') : self.alos_order = f['ALOS_bands'].attrs['order'] + + # Get the bands as gamma naught values + alos_bands = f['ALOS_bands'][idx_start : idx_end, :, :, :].astype(np.float32) + alos_bands = np.where(alos_bands == 0, -9999.0, 10 * np.log10(np.power(alos_bands, 2)) - 83.0) + + # Target data ----------------------------------------------------------------------------- + target_value = torch.from_numpy(np.array(f['GEDI'][self.target][idx_start : idx_end], dtype = np.float32)).to(torch.float) + lc = torch.from_numpy(np.array(f['LC'][idx_start : idx_end, :, :, 0])).long() + target = torch.full_like(lc, fill_value = self.ignore_index, dtype = torch.float if self.target in ['agbd', 'rh98'] else torch.long) + target[:, self.patch_size // 2, self.patch_size // 2] = target_value + + # Metadata (if needed) -------------------------------------------------------------------- + region = torch.from_numpy(np.array(f['GEDI']['region_cla'][idx_start : idx_end])).long() + biome = lc[:, self.patch_size // 2, self.patch_size // 2] + + # Convert to tensors and return ----------------------------------------------------------- + sr_bands = torch.from_numpy(sr_bands).float() + sr_bands = sr_bands.permute(0, 3, 1, 2).unsqueeze(2) # Change to (B, C, 1, H, W) + alos_bands = torch.from_numpy(alos_bands).float() + alos_bands = alos_bands.permute(0, 3, 1, 2).unsqueeze(2) # Change to (B, C, 1, H, W) + + # TEMPORARY, until we can return chunks + sr_bands = sr_bands.squeeze(0) + alos_bands = alos_bands.squeeze(0) + target = target.squeeze(0) + region = region.squeeze(0) + biome = biome.squeeze(0) + + return { + 'image': { + 'optical': sr_bands, + 'sar': alos_bands + }, + 'target': target, + 'metadata': { + 'region': region, + 'biome': biome + } + } + + @staticmethod + def download(self, silent=False): + + root_path = pathlib.Path(self.root_path) + + # Create the root directory if it does not exist + if not root_path.exists(): root_path.mkdir(parents=True, exist_ok=True) + if root_path.exists() : + if self.eval_big and not isfile(join(root_path, 'AGBD-test.h5')) : + print(f"AGBD-Lite test file does not exist at {root_path}. Downloading test file.") + download_from_zenodo(self.download_url, ["AGBD-test.h5"], output_dir=root_path) + if isfile(join(root_path, 'AGBD-Lite-train.h5')) and isfile(join(root_path, 'AGBD-Lite-val.h5')) and isfile(join(root_path, 'AGBD-Lite-test.h5')) : + if not silent: + print(f"AGBD-Lite files exist at {root_path}. Skipping download.") + return + + # Download the files from https://zenodo.org/records/18485030 + fnames = ["AGBD-Lite-train.h5", "AGBD-Lite-val.h5", "AGBD-Lite-test.h5"] + fnames += ["AGBD-test.h5"] if self.eval_big else [] + download_from_zenodo(self.download_url, fnames, output_dir=root_path) \ No newline at end of file diff --git a/pangaea/engine/evaluator.py b/pangaea/engine/evaluator.py index 345e7b33..8dd0f7f3 100644 --- a/pangaea/engine/evaluator.py +++ b/pangaea/engine/evaluator.py @@ -631,6 +631,12 @@ def evaluate(self, model, model_name='model', model_ckpt_path=None): else: raise NotImplementedError((f"Inference mode {self.inference_mode} is not implemented.")) + # sparse backprop + if self.val_loader.dataset.dataset_name == "AGBDLite" : + valid_mask = (target != self.val_loader.dataset.ignore_index) + logits = logits[valid_mask] + target = target[valid_mask] + mse += F.mse_loss(logits, target) torch.distributed.all_reduce(mse, op=torch.distributed.ReduceOp.SUM) diff --git a/pangaea/engine/trainer.py b/pangaea/engine/trainer.py index bee930b6..a222cd91 100644 --- a/pangaea/engine/trainer.py +++ b/pangaea/engine/trainer.py @@ -729,6 +729,13 @@ def compute_loss(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tens Returns: torch.Tensor: loss value. """ + + # sparse backprop + if self.train_loader.dataset.dataset_name == "AGBDLite" : + valid_mask = (target != self.train_loader.dataset.ignore_index) + logits = logits[valid_mask.unsqueeze(1)].unsqueeze(1) + target = target[valid_mask] + return self.criterion(logits.squeeze(dim=1), target) @torch.no_grad() @@ -742,6 +749,12 @@ def compute_logging_metrics( target (torch.Tensor): target tensor. """ + # sparse backprop + if self.train_loader.dataset.dataset_name == "AGBDLite" : + valid_mask = (target != self.train_loader.dataset.ignore_index) + logits = logits[valid_mask.unsqueeze(1)].unsqueeze(1) + target = target[valid_mask] + mse = F.mse_loss(logits.squeeze(dim=1), target) self.training_metrics["MSE"].update(mse.item()) diff --git a/requirements.txt b/requirements.txt index 015715f5..472037e3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,4 +20,4 @@ pytest yacs wandb hydra-core - +h5py From b6376a190f9b905a1b4888105cf6a00e7623f59c Mon Sep 17 00:00:00 2001 From: ghjuliasialelli Date: Mon, 23 Feb 2026 10:39:26 +0100 Subject: [PATCH 02/14] working encoders --- configs/dataset/agbdlite.yaml | 18 +++++++++--------- configs/encoder/terramind_large.yaml | 4 ++-- environment.yaml | 1 + pangaea/encoders/terramind_encoder.py | 2 +- 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/configs/dataset/agbdlite.yaml b/configs/dataset/agbdlite.yaml index 973c5ee8..699d2cda 100644 --- a/configs/dataset/agbdlite.yaml +++ b/configs/dataset/agbdlite.yaml @@ -6,7 +6,7 @@ auto_download: True img_size: 25 multi_temporal: False -multi_modal: True +multi_modal: False eval_big: False # whether to evaluate on AGBD-test instead of AGBD-Lite-test lite_chunk_size: 1 # should be 32, once we can load in batches @@ -32,28 +32,28 @@ bands: - B7 - B8 - B8A - - B09 + - B9 - B11 - B12 - sar: - - HH - - HV + # alos: + # - HH + # - HV # statistics data_mean: optical: [0.1235409, 0.13345453, 0.15880638, 0.15243885, 0.20159355, 0.32084775, 0.35937154, 0.37473023, 0.38475886, 0.3847485, 0.28990296, 0.21002403] - sar: [-10.24998, -16.50024] + # alos: [-10.24998, -16.50024] data_std: optical: [0.023037845, 0.02647069, 0.030844089, 0.03733916, 0.041663107, 0.06969897, 0.08277564, 0.090343215, 0.086930655, 0.08083575, 0.075588495, 0.058553487] - sar: [8.61129, 8.732289] + # alos: [8.61129, 8.732289] data_min: optical: [1e-04, 1e-04, 0.0283, 0.0218, 0.0795, 0.0695, 0.0794, 0.0524, 0.0846, 0.0037, 0.0974, 0.0977] - sar: [-83.0, -83.0] + # alos: [-83.0, -83.0] data_max: optical: [1.4194, 2.1088, 1.9552, 1.856, 1.6232, 1.6055, 1.586, 1.7232, 1.593, 1.7081, 1.61, 1.6255] - sar: [13.297462, 11.688309] \ No newline at end of file + # alos: [13.297462, 11.688309] \ No newline at end of file diff --git a/configs/encoder/terramind_large.yaml b/configs/encoder/terramind_large.yaml index 7e7233db..97782c15 100644 --- a/configs/encoder/terramind_large.yaml +++ b/configs/encoder/terramind_large.yaml @@ -1,6 +1,6 @@ _target_: pangaea.encoders.terramind_encoder.terramind_v1_large -encoder_weights: /home/vmarsocci/pangaea-bench/pretrained_models/TerraMind_v1_large.pt -download_url: #https://drive.google.com/uc?id=1CseO5vvMReGlAulm5o4ZgbjUgj8VlAH7&export=download&confirm=yes +encoder_weights: ./pretrained_models/TerraMind_v1_large.pt +download_url: null #https://huggingface.co/ibm-esa-geospatial/TerraMind-1.0-large/blob/main/TerraMind_v1_large.pt # ckpt_path: /home/vmarsocci/pangaea-bench/pretrained_models/TerraMind_v1_large.pt # dim: 768 diff --git a/environment.yaml b/environment.yaml index 25c3811e..59d33986 100644 --- a/environment.yaml +++ b/environment.yaml @@ -32,4 +32,5 @@ dependencies: - wandb - hydra-core - h5py + - albumentations diff --git a/pangaea/encoders/terramind_encoder.py b/pangaea/encoders/terramind_encoder.py index 273df2e0..eb5636a2 100644 --- a/pangaea/encoders/terramind_encoder.py +++ b/pangaea/encoders/terramind_encoder.py @@ -2731,7 +2731,7 @@ def build_terrammind_vit( if encoder_weights is not None: # Load model from checkpoint - state_dict = torch.load(encoder_weights, map_location="cpu", weights_only=True) + state_dict = torch.load(encoder_weights, map_location="cpu", weights_only=False) loaded_keys = model.load_state_dict(state_dict, strict=False) if loaded_keys.missing_keys: logger.warning(f"Missing keys in encoder_weights {encoder_weights}: {loaded_keys.missing_keys}") From f981690b6f4862e3da0e07ac75e07a73d20b272f Mon Sep 17 00:00:00 2001 From: ghjuliasialelli Date: Tue, 24 Feb 2026 10:32:24 +0100 Subject: [PATCH 03/14] ok --- pangaea/run.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pangaea/run.py b/pangaea/run.py index e3409caa..be741193 100644 --- a/pangaea/run.py +++ b/pangaea/run.py @@ -29,6 +29,7 @@ seed_worker, ) +os.environ["WANDB__SERVICE_WAIT"] = "300" def get_exp_info(hydra_config: HydraConf) -> dict[str, str]: """Create a unique experiment name based on the choices made in the config. From ae9e75bae6fd58cfceaff1a9af71cc3625c3869a Mon Sep 17 00:00:00 2001 From: ghjuliasialelli Date: Wed, 25 Feb 2026 17:47:19 +0100 Subject: [PATCH 04/14] modif --- pangaea/datasets/agbdlite.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pangaea/datasets/agbdlite.py b/pangaea/datasets/agbdlite.py index 4ae26659..e117aff1 100644 --- a/pangaea/datasets/agbdlite.py +++ b/pangaea/datasets/agbdlite.py @@ -172,8 +172,10 @@ def __getitem__(self, n): if not hasattr(self, 'alos_order') : self.alos_order = f['ALOS_bands'].attrs['order'] # Get the bands as gamma naught values - alos_bands = f['ALOS_bands'][idx_start : idx_end, :, :, :].astype(np.float32) - alos_bands = np.where(alos_bands == 0, -9999.0, 10 * np.log10(np.power(alos_bands, 2)) - 83.0) + _alos_bands = f['ALOS_bands'][idx_start : idx_end, :, :, :].astype(np.float32) + mask = (_alos_bands != 0) + alos_bands = np.full(_alos_bands.shape, -9999.0, dtype=np.float32) + alos_bands[mask] = 10 * np.log10(np.power(_alos_bands[mask], 2)) - 83.0 # Target data ----------------------------------------------------------------------------- target_value = torch.from_numpy(np.array(f['GEDI'][self.target][idx_start : idx_end], dtype = np.float32)).to(torch.float) From d875406e3e3f0ce9ee2bf1db6b1415c78f1727e8 Mon Sep 17 00:00:00 2001 From: ghjuliasialelli Date: Sat, 7 Mar 2026 10:47:21 +0100 Subject: [PATCH 05/14] update --- configs/dataset/agbd.yaml | 52 +++++++ pangaea/datasets/agbd.py | 302 ++++++++++++++++++++++++++++++++++++++ requirements.txt | 1 + 3 files changed, 355 insertions(+) create mode 100644 configs/dataset/agbd.yaml create mode 100644 pangaea/datasets/agbd.py diff --git a/configs/dataset/agbd.yaml b/configs/dataset/agbd.yaml new file mode 100644 index 00000000..f56256a5 --- /dev/null +++ b/configs/dataset/agbd.yaml @@ -0,0 +1,52 @@ +_target_: pangaea.datasets.agbd.AGBD +dataset_name: AGBD +root_path: /cluster/scratch/gsialelli #/scratch3/gsialelli/patches #/cluster/scratch/gsialelli +download_url: null +auto_download: False + +img_size: 25 +multi_temporal: False +multi_modal: False + +hold_out_region: null +keep_region: False +drop_overlaps: True + +# for regression (agbd or rh98) +target: "agbd" # agbd or rh98 +ignore_index: -1 +num_classes: 1 +classes: + - regression +distribution: + - 1. + +# features +bands: + optical: + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B8 + - B8A + - B9 + - B11 + - B12 + +# statistics + +data_mean: + optical: [0.1244767, 0.13442503, 0.15990143, 0.15322255, 0.2027034, 0.32371777, 0.3625338, 0.3777227, 0.38758147, 0.38764006, 0.29084504, 0.21135607] + +data_std: + optical: [0.024913007, 0.02851103, 0.032457292, 0.038697336, 0.042601082, 0.07086351, 0.08404868, 0.09149762, 0.08816632, 0.082269475, 0.07562889, 0.059278008] + +data_min: + optical: [1e-04, 1e-04, 1e-04, 1e-04, 0.0633, 0.0502, 0.0616, 0.0278, 0.055, 0.0012, 0.0954, 0.0975] + +data_max: + optical: [1.7, 2.1088, 2.12, 2.0032, 1.7413, 1.7228, 1.7084, 1.7312, 1.688, 1.7915, 1.648, 1.6775] diff --git a/pangaea/datasets/agbd.py b/pangaea/datasets/agbd.py new file mode 100644 index 00000000..d6436911 --- /dev/null +++ b/pangaea/datasets/agbd.py @@ -0,0 +1,302 @@ +import torch +from pangaea.datasets.base import RawGeoFMDataset + +import numpy as np +import h5py +from os.path import join +import pickle + +continent_to_region = {'North America': ['California', 'Cuba'], 'South America': ['Paraguay', 'FrenchGuiana'], + 'Africa': ['UnitedRepublicofTanzania', 'Ghana'], 'Europe': ['Austria', 'Greece'], + 'South Asia': ['Nepal', 'ShaanxiProvince'], 'Australasia': ['NewZealand']} + +def initialize_index(fnames, mode, chunk_size, path_mapping, path_h5, hold_out_region = None, keep_region = False, drop_overlaps = False) : + """ + This function creates the index for the dataset. The index is a dictionary which maps the file + names (`fnames`) to the tiles that are in the `mode` (train, val, test); and the tiles to the + number of chunks that make it up. + + Args: + - fnames (list): list of file names + - mode (str): the mode of the dataset (train, val, test) + - chunk_size (int): the size of the chunks + - path_mapping (str): the path to the file mapping each mode to its tiles + - path_h5 (str): the path to the h5 files + - hold_out_region (str): the region to hold out + - keep_region (bool): whether to keep the specified region + - drop_overlaps (bool): whether to drop overlapping patches + + Returns: + - idx (dict): dictionary mapping the file names to the tiles and the tiles to the chunks + - total_length (int): the total number of chunks in the dataset + """ + + # Load the mapping from mode to tile name + with open(join(path_mapping, 'biomes_splits_to_name.pkl'), 'rb') as f: + tile_mapping = pickle.load(f) + + # Skip the tiles in the region to hold out, if specified (only for train and val) + if hold_out_region and mode in ['train', 'val'] : + # Mapping from e.g. New Zealand to the S2 tiles it contains + with open(join(path_h5, 'tiles_per_region.pkl'), 'rb') as f: tiles_per_region = pickle.load(f) + # Mapping from world region (e.g. North America) to the regions in it (e.g. California, Cuba) + subregions = continent_to_region.get(hold_out_region) + hold_out_tiles = [] + for region in subregions : hold_out_tiles.extend(tiles_per_region[region]) + else : hold_out_tiles = [] + + # If need to drop the test patches that overlap with the AEF train set + if drop_overlaps: + with open(join(path_h5, 'AEF_overlaps.pkl'), 'rb') as f: + overlap = pickle.load(f) + + # Iterate over all files + idx = {} + for fname in fnames : + idx[fname] = {} + + with h5py.File(join(path_h5, fname), 'r') as f: + + # Get the tiles in this file which belong to the mode + all_tiles = list(f.keys()) + tiles = np.intersect1d(all_tiles, tile_mapping[mode]) + + # Iterate over the tiles + for tile in tiles : + + if (keep_region and len(hold_out_tiles) > 0): + if tile not in hold_out_tiles : + continue + else: + if tile in hold_out_tiles : continue + + # Get the number of patches in the tile + if drop_overlaps : + if fname in overlap and tile in overlap[fname] : indices_to_skip = overlap[fname][tile] + else: indices_to_skip = [] + n_total = len(f[tile]['GEDI']['agbd']) + n_patches = n_total - len(indices_to_skip) + idx[fname][tile] = {'n_patches' : n_patches, 'n_total' : n_total, 'indices_to_skip' : indices_to_skip} + else: + n_patches = len(f[tile]['GEDI']['agbd']) + idx[fname][tile] = n_patches // chunk_size + + if drop_overlaps : total_length = sum(sum(d['n_patches'] for d in idx[fname].values()) for fname in idx.keys()) + else: total_length = sum(sum(v for v in d.values()) for d in idx.values()) + + return idx, total_length + + +def init_ranges_for_chunk(index, total_length, drop_overlaps = False): + """ + This function creates a list of tuples (start_idx, end_idx, fname, tname) for each tile in the index, where + start_idx and end_idx are the indices of the first and last chunk of the tile in the dataset. This will allow us to + quickly find the file, tile, and row index corresponding to a given chunk index. + + Args: + - index (dict): the index of the dataset, mapping file names to tile names and tile names to number of chunks + - total_length (int): the total number of chunks in the dataset + - oversampling (bool): whether to use oversampling or not + - drop_overlaps (bool): whether to drop overlapping patches or not + + Returns: + - ranges (list): list of tuples (start_idx, end_idx, fname, tname) for each tile in the index + """ + + ranges = [] + start_idx = 0 + for fname, file_data in index.items() : + for tname, tile_data in file_data.items() : + num_patches = tile_data['n_patches'] if drop_overlaps else tile_data + end_idx = start_idx + num_patches + assert end_idx <= total_length, f"Index out of bounds: {end_idx} > {total_length}" + ranges.append((start_idx, end_idx, fname, tname)) + start_idx = end_idx + + return ranges + + +def find_index_for_chunk(index, ranges, n, total_length, drop_overlaps = False) : + """ + For a given `index`, `ranges`, and `n`-th chunk, find the file, tile, and row index corresponding to this chunk. + + Args: + - index (dict): dictionary mapping the files to the tiles and the tiles to the chunks + - ranges (list): list of tuples (start_idx, end_idx, fname, tname) for each tile in the index + - n (int): the n-th chunk + - total_length (int): the total number of chunks in the dataset + - chunk_size (int): the size of the chunks + - oversampling (bool): whether to use oversampling or not + - lite (bool): whether to use the lite version of the dataset + - drop_overlaps (bool): whether to drop overlapping patches or not + + Returns: + - file_name (str): the name of the file + - tile_name (str): the name of the tile + - chunk_within_tile (int): the chunk index within the tile + """ + + # Check that the chunk index is within bounds + assert n < total_length, "The chunk index is out of bounds" + + for start, end, fname, tname in ranges : + if start <= n < end : + chunk_within_tile = n - start + + if drop_overlaps : + tile_data = index[fname][tname] + indices_to_skip = tile_data['indices_to_skip'] + n_total = tile_data['n_total'] + indices_to_keep = np.setdiff1d(np.arange(n_total), indices_to_skip) + chunk_within_tile = indices_to_keep[chunk_within_tile] + + return fname, tname, chunk_within_tile + + +class AGBD(RawGeoFMDataset): + def __init__( + self, + split: str, + dataset_name: str, + multi_modal: bool, + multi_temporal: int, + root_path: str, + classes: list, + num_classes: int, + ignore_index: int, + img_size: int, + bands: dict[str, list[str]], + distribution: list[int], + data_mean: dict[str, list[str]], + data_std: dict[str, list[str]], + data_min: dict[str, list[str]], + data_max: dict[str, list[str]], + download_url: str, + auto_download: bool, + target: str, + hold_out_region: str | None = None, + keep_region: bool = False, + drop_overlaps: bool = False + ): + super(AGBD, self).__init__( + split=split, + dataset_name=dataset_name, + multi_modal=multi_modal, + multi_temporal=multi_temporal, + root_path=root_path, + classes=classes, + num_classes=num_classes, + ignore_index=ignore_index, + img_size=img_size, + bands=bands, + distribution=distribution, + data_mean=data_mean, + data_std=data_std, + data_min=data_min, + data_max=data_max, + download_url=download_url, + auto_download=auto_download + ) + + assert split in ['train', 'val', 'test'], "split must be one of 'train', 'val', or 'test'" + self.mode = split + self.target = target + self.patch_size = img_size + self.s2_bands = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12'] + self.h5_path, self.mapping = root_path, root_path + self.fnames = [f'data_subset-{year}-v4_{i}-20.h5' for i in range(20) for year in [2019,2020]] + + self.hold_out_region = hold_out_region + self.keep_region = keep_region + self.drop_overlaps = drop_overlaps + + self.index, self.length = initialize_index(self.fnames, self.mode, 1, self.mapping, self.h5_path, self.hold_out_region, self.keep_region, self.drop_overlaps) + self.ranges = init_ranges_for_chunk(self.index, self.length, drop_overlaps = self.drop_overlaps) + + self.handles = {fname: h5py.File(join(self.h5_path, fname), 'r') for fname in self.index.keys()} + + + def __len__(self): + # Return the total number of samples + return int(self.length) + + def __getitem__(self, n): + """Returns the i-th item of the dataset. + + Args: + i (int): index of the item + + Raises: + NotImplementedError: raise if the method is not implemented + + Returns: + dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary follwing the format + {"image": + { + "optical": torch.Tensor of shape (C T H W) (where T=1 if single-temporal dataset), + "sar": torch.Tensor of shape (C T H W) (where T=1 if single-temporal dataset), + }, + "target": torch.Tensor of shape (H W) of type torch.int64 for segmentation, torch.float for + regression datasets., + "metadata": dict}. + """ + + # Find the file, tile, and row index corresponding to this chunk + file_name, tile_name, idx = find_index_for_chunk(self.index, self.ranges, n, self.length, self.drop_overlaps) + f = self.handles[file_name][tile_name] + idx_start, idx_end = idx, idx + 1 + + + # Sentinel-2 bands ------------------------------------------------------------------------ + + # Set the order and indices for the Sentinel-2 bands + if not hasattr(self, 's2_order') : self.s2_order = list(f['S2_bands'].attrs['order']) + if not hasattr(self, 's2_indices') : self.s2_indices = [self.s2_order.index(band) for band in self.s2_bands] + + # Get the bands + s2_bands = f['S2_bands'][idx_start : idx_end, :, :, self.s2_indices].astype(np.float32) + + # Get the BOA offset, if it exists + if 'S2_boa_offset' in f['Sentinel_metadata'].keys() : s2_boa_offset = f['Sentinel_metadata']['S2_boa_offset'][idx_start : idx_end].astype(np.float32) + else: s2_boa_offset = np.full((s2_bands.shape[0],), 0, dtype = np.float32) + s2_boa_offset = s2_boa_offset[:, np.newaxis, np.newaxis, np.newaxis] + + # Get the surface reflectance values + sr_bands = (s2_bands - s2_boa_offset * 1000) / 10000 + sr_bands[s2_bands == 0] = 0 + sr_bands[sr_bands < 0] = 0 + + # Target data ----------------------------------------------------------------------------- + target_value = torch.from_numpy(np.array(f['GEDI'][self.target][idx_start : idx_end], dtype = np.float32)).to(torch.float) + lc = torch.from_numpy(np.array(f['LC'][idx_start : idx_end, :, :, 0])).long() + target = torch.full_like(lc, fill_value = self.ignore_index, dtype = torch.float if self.target in ['agbd', 'rh98'] else torch.long) + target[:, self.patch_size // 2, self.patch_size // 2] = target_value + + # Metadata (if needed) -------------------------------------------------------------------- + region = torch.from_numpy(np.array(f['GEDI']['region_cla'][idx_start : idx_end])).long() + biome = lc[:, self.patch_size // 2, self.patch_size // 2] + + # Convert to tensors and return ----------------------------------------------------------- + sr_bands = torch.from_numpy(sr_bands).float() + sr_bands = sr_bands.permute(0, 3, 1, 2).unsqueeze(2) # Change to (B, C, 1, H, W) + + # TEMPORARY, until we can return chunks + sr_bands = sr_bands.squeeze(0) + target = target.squeeze(0) + region = region.squeeze(0) + biome = biome.squeeze(0) + + return { + 'image': { + 'optical': sr_bands + }, + 'target': target, + 'metadata': { + 'region': region, + 'biome': biome + } + } + + @staticmethod + def download(self, silent=False): + pass \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 472037e3..0d44da32 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,3 +21,4 @@ yacs wandb hydra-core h5py +albumentations From 31abe2cd1072d20f1395366527df11d934bdef82 Mon Sep 17 00:00:00 2001 From: ghjuliasialelli Date: Mon, 9 Mar 2026 10:24:10 +0100 Subject: [PATCH 06/14] ok --- configs/dataset/agbd.yaml | 3 ++- configs/dataset/agbdlite.yaml | 1 + pangaea/datasets/agbd.py | 4 ++++ pangaea/datasets/agbdlite.py | 4 ++++ 4 files changed, 11 insertions(+), 1 deletion(-) diff --git a/configs/dataset/agbd.yaml b/configs/dataset/agbd.yaml index f56256a5..2c9ac96f 100644 --- a/configs/dataset/agbd.yaml +++ b/configs/dataset/agbd.yaml @@ -1,6 +1,7 @@ _target_: pangaea.datasets.agbd.AGBD dataset_name: AGBD -root_path: /cluster/scratch/gsialelli #/scratch3/gsialelli/patches #/cluster/scratch/gsialelli +root_path: /scratch3/gsialelli/patches +root_path_cluster: /cluster/scratch/gsialelli download_url: null auto_download: False diff --git a/configs/dataset/agbdlite.yaml b/configs/dataset/agbdlite.yaml index 699d2cda..bfaa7938 100644 --- a/configs/dataset/agbdlite.yaml +++ b/configs/dataset/agbdlite.yaml @@ -1,6 +1,7 @@ _target_: pangaea.datasets.agbdlite.AGBDLite dataset_name: AGBDLite root_path: ./data/AGBDLite +root_path_cluster: /cluster/scratch/gsialelli download_url: "https://zenodo.org/api/records/18485030" auto_download: True diff --git a/pangaea/datasets/agbd.py b/pangaea/datasets/agbd.py index d6436911..8137890a 100644 --- a/pangaea/datasets/agbd.py +++ b/pangaea/datasets/agbd.py @@ -5,6 +5,7 @@ import h5py from os.path import join import pickle +from os import getcwd continent_to_region = {'North America': ['California', 'Cuba'], 'South America': ['Paraguay', 'FrenchGuiana'], 'Africa': ['UnitedRepublicofTanzania', 'Ghana'], 'Europe': ['Austria', 'Greece'], @@ -161,6 +162,7 @@ def __init__( multi_modal: bool, multi_temporal: int, root_path: str, + root_path_cluster: str, classes: list, num_classes: int, ignore_index: int, @@ -184,6 +186,7 @@ def __init__( multi_modal=multi_modal, multi_temporal=multi_temporal, root_path=root_path, + root_path_cluster=root_path_cluster, classes=classes, num_classes=num_classes, ignore_index=ignore_index, @@ -203,6 +206,7 @@ def __init__( self.target = target self.patch_size = img_size self.s2_bands = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12'] + if getcwd().startswith('/cluster') : self.root_path = self.root_path_cluster self.h5_path, self.mapping = root_path, root_path self.fnames = [f'data_subset-{year}-v4_{i}-20.h5' for i in range(20) for year in [2019,2020]] diff --git a/pangaea/datasets/agbdlite.py b/pangaea/datasets/agbdlite.py index e117aff1..89cfadd0 100644 --- a/pangaea/datasets/agbdlite.py +++ b/pangaea/datasets/agbdlite.py @@ -4,6 +4,7 @@ import numpy as np import h5py from os.path import join, isfile +from os import getcwd from tqdm import tqdm import requests import pathlib @@ -62,6 +63,7 @@ def __init__( multi_modal: bool, multi_temporal: int, root_path: str, + root_path_cluster: str, classes: list, num_classes: int, ignore_index: int, @@ -84,6 +86,7 @@ def __init__( multi_modal=multi_modal, multi_temporal=multi_temporal, root_path=root_path, + root_path_cluster=root_path_cluster, classes=classes, num_classes=num_classes, ignore_index=ignore_index, @@ -106,6 +109,7 @@ def __init__( self.patch_size = img_size self.zenodo_record = "18485030" if auto_download: self.download(self) + if getcwd().startswith('/cluster') : self.root_path = self.root_path_cluster self.s2_bands = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12'] if self.eval_big and self.mode == 'test' : self.fname = 'AGBD-test.h5' From e8bb98c380d3a13b9a50fe89e903b8b293f77523 Mon Sep 17 00:00:00 2001 From: ghjuliasialelli Date: Mon, 9 Mar 2026 12:09:50 +0100 Subject: [PATCH 07/14] ok --- configs/dataset/agbd.yaml | 2 +- configs/dataset/agbdlite.yaml | 2 +- pangaea/datasets/agbd.py | 11 +++++------ pangaea/datasets/agbdlite.py | 5 ++--- pangaea/engine/evaluator.py | 8 ++++++-- pangaea/engine/trainer.py | 4 ++-- 6 files changed, 17 insertions(+), 15 deletions(-) diff --git a/configs/dataset/agbd.yaml b/configs/dataset/agbd.yaml index 2c9ac96f..96386e6a 100644 --- a/configs/dataset/agbd.yaml +++ b/configs/dataset/agbd.yaml @@ -1,7 +1,6 @@ _target_: pangaea.datasets.agbd.AGBD dataset_name: AGBD root_path: /scratch3/gsialelli/patches -root_path_cluster: /cluster/scratch/gsialelli download_url: null auto_download: False @@ -9,6 +8,7 @@ img_size: 25 multi_temporal: False multi_modal: False +root_path_cluster: /cluster/scratch/gsialelli hold_out_region: null keep_region: False drop_overlaps: True diff --git a/configs/dataset/agbdlite.yaml b/configs/dataset/agbdlite.yaml index bfaa7938..eab83ab0 100644 --- a/configs/dataset/agbdlite.yaml +++ b/configs/dataset/agbdlite.yaml @@ -1,7 +1,6 @@ _target_: pangaea.datasets.agbdlite.AGBDLite dataset_name: AGBDLite root_path: ./data/AGBDLite -root_path_cluster: /cluster/scratch/gsialelli download_url: "https://zenodo.org/api/records/18485030" auto_download: True @@ -9,6 +8,7 @@ img_size: 25 multi_temporal: False multi_modal: False +root_path_cluster: /cluster/scratch/gsialelli eval_big: False # whether to evaluate on AGBD-test instead of AGBD-Lite-test lite_chunk_size: 1 # should be 32, once we can load in batches diff --git a/pangaea/datasets/agbd.py b/pangaea/datasets/agbd.py index 8137890a..3de905d3 100644 --- a/pangaea/datasets/agbd.py +++ b/pangaea/datasets/agbd.py @@ -34,7 +34,7 @@ def initialize_index(fnames, mode, chunk_size, path_mapping, path_h5, hold_out_r # Load the mapping from mode to tile name with open(join(path_mapping, 'biomes_splits_to_name.pkl'), 'rb') as f: - tile_mapping = pickle.load(f) + tile_mapping = pickle.load(f)[mode] # Skip the tiles in the region to hold out, if specified (only for train and val) if hold_out_region and mode in ['train', 'val'] : @@ -60,7 +60,7 @@ def initialize_index(fnames, mode, chunk_size, path_mapping, path_h5, hold_out_r # Get the tiles in this file which belong to the mode all_tiles = list(f.keys()) - tiles = np.intersect1d(all_tiles, tile_mapping[mode]) + tiles = np.intersect1d(all_tiles, tile_mapping) # Iterate over the tiles for tile in tiles : @@ -162,7 +162,6 @@ def __init__( multi_modal: bool, multi_temporal: int, root_path: str, - root_path_cluster: str, classes: list, num_classes: int, ignore_index: int, @@ -175,6 +174,7 @@ def __init__( data_max: dict[str, list[str]], download_url: str, auto_download: bool, + root_path_cluster: str, target: str, hold_out_region: str | None = None, keep_region: bool = False, @@ -186,7 +186,6 @@ def __init__( multi_modal=multi_modal, multi_temporal=multi_temporal, root_path=root_path, - root_path_cluster=root_path_cluster, classes=classes, num_classes=num_classes, ignore_index=ignore_index, @@ -206,13 +205,13 @@ def __init__( self.target = target self.patch_size = img_size self.s2_bands = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12'] - if getcwd().startswith('/cluster') : self.root_path = self.root_path_cluster + if getcwd().startswith('/cluster') : self.root_path = root_path_cluster self.h5_path, self.mapping = root_path, root_path self.fnames = [f'data_subset-{year}-v4_{i}-20.h5' for i in range(20) for year in [2019,2020]] self.hold_out_region = hold_out_region self.keep_region = keep_region - self.drop_overlaps = drop_overlaps + self.drop_overlaps = drop_overlaps and split == 'test' self.index, self.length = initialize_index(self.fnames, self.mode, 1, self.mapping, self.h5_path, self.hold_out_region, self.keep_region, self.drop_overlaps) self.ranges = init_ranges_for_chunk(self.index, self.length, drop_overlaps = self.drop_overlaps) diff --git a/pangaea/datasets/agbdlite.py b/pangaea/datasets/agbdlite.py index 89cfadd0..1b02dcb5 100644 --- a/pangaea/datasets/agbdlite.py +++ b/pangaea/datasets/agbdlite.py @@ -63,7 +63,6 @@ def __init__( multi_modal: bool, multi_temporal: int, root_path: str, - root_path_cluster: str, classes: list, num_classes: int, ignore_index: int, @@ -76,6 +75,7 @@ def __init__( data_max: dict[str, list[str]], download_url: str, auto_download: bool, + root_path_cluster: str, target: str, eval_big: bool, lite_chunk_size: int @@ -86,7 +86,6 @@ def __init__( multi_modal=multi_modal, multi_temporal=multi_temporal, root_path=root_path, - root_path_cluster=root_path_cluster, classes=classes, num_classes=num_classes, ignore_index=ignore_index, @@ -109,7 +108,7 @@ def __init__( self.patch_size = img_size self.zenodo_record = "18485030" if auto_download: self.download(self) - if getcwd().startswith('/cluster') : self.root_path = self.root_path_cluster + if getcwd().startswith('/cluster') : self.root_path = root_path_cluster self.s2_bands = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12'] if self.eval_big and self.mode == 'test' : self.fname = 'AGBD-test.h5' diff --git a/pangaea/engine/evaluator.py b/pangaea/engine/evaluator.py index 8dd0f7f3..6a4e70d9 100644 --- a/pangaea/engine/evaluator.py +++ b/pangaea/engine/evaluator.py @@ -618,6 +618,9 @@ def evaluate(self, model, model_name='model', model_ckpt_path=None): mse = torch.zeros(1, device=self.device) for batch_idx, data in enumerate(tqdm(self.val_loader, desc=tag)): + + if batch_idx == 500 : break # TODO remove, for debugging purposes + image, target = data['image'], data['target'] image = {k: v.to(self.device) for k, v in image.items()} target = target.to(self.device) @@ -632,7 +635,7 @@ def evaluate(self, model, model_name='model', model_ckpt_path=None): raise NotImplementedError((f"Inference mode {self.inference_mode} is not implemented.")) # sparse backprop - if self.val_loader.dataset.dataset_name == "AGBDLite" : + if self.val_loader.dataset.dataset_name in ["AGBDLite", "AGBD"]: valid_mask = (target != self.val_loader.dataset.ignore_index) logits = logits[valid_mask] target = target[valid_mask] @@ -640,7 +643,8 @@ def evaluate(self, model, model_name='model', model_ckpt_path=None): mse += F.mse_loss(logits, target) torch.distributed.all_reduce(mse, op=torch.distributed.ReduceOp.SUM) - mse = mse / len(self.val_loader) + # mse = mse / (len(self.val_loader) * torch.distributed.get_world_size()) + mse = mse / (500 * torch.distributed.get_world_size()) metrics = {"MSE": mse.item(), "RMSE": torch.sqrt(mse).item()} self.log_metrics(metrics) diff --git a/pangaea/engine/trainer.py b/pangaea/engine/trainer.py index a222cd91..b5a0ea6f 100644 --- a/pangaea/engine/trainer.py +++ b/pangaea/engine/trainer.py @@ -731,7 +731,7 @@ def compute_loss(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tens """ # sparse backprop - if self.train_loader.dataset.dataset_name == "AGBDLite" : + if self.train_loader.dataset.dataset_name in ["AGBDLite", "AGBD"]: valid_mask = (target != self.train_loader.dataset.ignore_index) logits = logits[valid_mask.unsqueeze(1)].unsqueeze(1) target = target[valid_mask] @@ -750,7 +750,7 @@ def compute_logging_metrics( """ # sparse backprop - if self.train_loader.dataset.dataset_name == "AGBDLite" : + if self.train_loader.dataset.dataset_name in ["AGBDLite", "AGBD"]: valid_mask = (target != self.train_loader.dataset.ignore_index) logits = logits[valid_mask.unsqueeze(1)].unsqueeze(1) target = target[valid_mask] From 3d279c06bc95e15bda932732164f02bac7236905 Mon Sep 17 00:00:00 2001 From: ghjuliasialelli Date: Mon, 9 Mar 2026 13:14:23 +0100 Subject: [PATCH 08/14] other approach --- pangaea/datasets/agbd.py | 6 ++++-- pangaea/datasets/agbdlite.py | 7 ++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/pangaea/datasets/agbd.py b/pangaea/datasets/agbd.py index 3de905d3..32697d9f 100644 --- a/pangaea/datasets/agbd.py +++ b/pangaea/datasets/agbd.py @@ -5,7 +5,7 @@ import h5py from os.path import join import pickle -from os import getcwd +import os continent_to_region = {'North America': ['California', 'Cuba'], 'South America': ['Paraguay', 'FrenchGuiana'], 'Africa': ['UnitedRepublicofTanzania', 'Ghana'], 'Europe': ['Austria', 'Greece'], @@ -205,7 +205,9 @@ def __init__( self.target = target self.patch_size = img_size self.s2_bands = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12'] - if getcwd().startswith('/cluster') : self.root_path = root_path_cluster + if os.environ.get('SLURM_SUBMIT_DIR') is not None: + print('Running on cluster, using cluster root path.') + self.root_path = root_path_cluster self.h5_path, self.mapping = root_path, root_path self.fnames = [f'data_subset-{year}-v4_{i}-20.h5' for i in range(20) for year in [2019,2020]] diff --git a/pangaea/datasets/agbdlite.py b/pangaea/datasets/agbdlite.py index 1b02dcb5..65b170bb 100644 --- a/pangaea/datasets/agbdlite.py +++ b/pangaea/datasets/agbdlite.py @@ -4,7 +4,7 @@ import numpy as np import h5py from os.path import join, isfile -from os import getcwd +import os from tqdm import tqdm import requests import pathlib @@ -108,8 +108,9 @@ def __init__( self.patch_size = img_size self.zenodo_record = "18485030" if auto_download: self.download(self) - if getcwd().startswith('/cluster') : self.root_path = root_path_cluster - + if os.environ.get('SLURM_SUBMIT_DIR') is not None: + print('Running on cluster, using cluster root path.') + self.root_path = root_path_cluster self.s2_bands = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12'] if self.eval_big and self.mode == 'test' : self.fname = 'AGBD-test.h5' else: self.fname = f'AGBD-Lite-{self.mode}.h5' From 9df746c814fa6d7df0f76730a44f7aa117d6ffda Mon Sep 17 00:00:00 2001 From: ghjuliasialelli Date: Mon, 9 Mar 2026 14:22:20 +0100 Subject: [PATCH 09/14] fuck this --- pangaea/engine/evaluator.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pangaea/engine/evaluator.py b/pangaea/engine/evaluator.py index 6a4e70d9..172aee7f 100644 --- a/pangaea/engine/evaluator.py +++ b/pangaea/engine/evaluator.py @@ -619,8 +619,6 @@ def evaluate(self, model, model_name='model', model_ckpt_path=None): for batch_idx, data in enumerate(tqdm(self.val_loader, desc=tag)): - if batch_idx == 500 : break # TODO remove, for debugging purposes - image, target = data['image'], data['target'] image = {k: v.to(self.device) for k, v in image.items()} target = target.to(self.device) From 3743816c5f54007f4542737cbb7bb6b990330c46 Mon Sep 17 00:00:00 2001 From: ghjuliasialelli Date: Mon, 9 Mar 2026 15:38:23 +0100 Subject: [PATCH 10/14] ugh --- pangaea/datasets/agbd.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pangaea/datasets/agbd.py b/pangaea/datasets/agbd.py index 32697d9f..fce4e296 100644 --- a/pangaea/datasets/agbd.py +++ b/pangaea/datasets/agbd.py @@ -208,7 +208,8 @@ def __init__( if os.environ.get('SLURM_SUBMIT_DIR') is not None: print('Running on cluster, using cluster root path.') self.root_path = root_path_cluster - self.h5_path, self.mapping = root_path, root_path + else: self.root_path = root_path + self.h5_path, self.mapping = self.root_path, self.root_path self.fnames = [f'data_subset-{year}-v4_{i}-20.h5' for i in range(20) for year in [2019,2020]] self.hold_out_region = hold_out_region From e47162124185e5ed505895b826aaddfca183a7d5 Mon Sep 17 00:00:00 2001 From: ghjuliasialelli Date: Mon, 16 Mar 2026 15:34:08 +0100 Subject: [PATCH 11/14] eval more regularly on AGBDLite --- configs/dataset/agbd.yaml | 1 + pangaea/datasets/agbd.py | 47 ++++++++++++++++++++-------- pangaea/engine/evaluator.py | 3 +- pangaea/engine/trainer.py | 61 ++++++++++++++++++++++++++++--------- 4 files changed, 83 insertions(+), 29 deletions(-) diff --git a/configs/dataset/agbd.yaml b/configs/dataset/agbd.yaml index 96386e6a..85cb2080 100644 --- a/configs/dataset/agbd.yaml +++ b/configs/dataset/agbd.yaml @@ -12,6 +12,7 @@ root_path_cluster: /cluster/scratch/gsialelli hold_out_region: null keep_region: False drop_overlaps: True +eval_lite: True # for regression (agbd or rh98) target: "agbd" # agbd or rh98 diff --git a/pangaea/datasets/agbd.py b/pangaea/datasets/agbd.py index fce4e296..4e79ed86 100644 --- a/pangaea/datasets/agbd.py +++ b/pangaea/datasets/agbd.py @@ -154,6 +154,21 @@ def find_index_for_chunk(index, ranges, n, total_length, drop_overlaps = False) return fname, tname, chunk_within_tile +def get_file_and_idx(eval_lite, mode, index, ranges, n, length, drop_overlaps, handles) : + + if eval_lite and mode == 'val' : + idx_start, idx_end = n, n + 1 + return handles, idx_start, idx_end + + else: + # Find the file, tile, and row index corresponding to this chunk + file_name, tile_name, idx = find_index_for_chunk(index, ranges, n, length, drop_overlaps) + f = handles[file_name][tile_name] + idx_start, idx_end = idx, idx + 1 + return f, idx_start, idx_end + + + class AGBD(RawGeoFMDataset): def __init__( self, @@ -178,7 +193,8 @@ def __init__( target: str, hold_out_region: str | None = None, keep_region: bool = False, - drop_overlaps: bool = False + drop_overlaps: bool = False, + eval_lite: bool = False ): super(AGBD, self).__init__( split=split, @@ -202,6 +218,7 @@ def __init__( assert split in ['train', 'val', 'test'], "split must be one of 'train', 'val', or 'test'" self.mode = split + self.eval_lite = eval_lite self.target = target self.patch_size = img_size self.s2_bands = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12'] @@ -210,16 +227,25 @@ def __init__( self.root_path = root_path_cluster else: self.root_path = root_path self.h5_path, self.mapping = self.root_path, self.root_path - self.fnames = [f'data_subset-{year}-v4_{i}-20.h5' for i in range(20) for year in [2019,2020]] - self.hold_out_region = hold_out_region - self.keep_region = keep_region - self.drop_overlaps = drop_overlaps and split == 'test' + # Evaluate on the AGBD-Lite validation set + if self.eval_lite and split == 'val' : + self.handles = h5py.File(join(self.root_path, 'AGBD-Lite-val.h5'), 'r') + self.index, self.ranges, self.drop_overlaps = None, None, None + self.length = len(self.handles['GEDI']['agbd']) + + # Regular training and evaluation on the full dataset + else: + self.fnames = [f'data_subset-{year}-v4_{i}-20.h5' for i in range(20) for year in [2019,2020]] + + self.hold_out_region = hold_out_region + self.keep_region = keep_region + self.drop_overlaps = drop_overlaps and split == 'test' - self.index, self.length = initialize_index(self.fnames, self.mode, 1, self.mapping, self.h5_path, self.hold_out_region, self.keep_region, self.drop_overlaps) - self.ranges = init_ranges_for_chunk(self.index, self.length, drop_overlaps = self.drop_overlaps) + self.index, self.length = initialize_index(self.fnames, self.mode, 1, self.mapping, self.h5_path, self.hold_out_region, self.keep_region, self.drop_overlaps) + self.ranges = init_ranges_for_chunk(self.index, self.length, drop_overlaps = self.drop_overlaps) - self.handles = {fname: h5py.File(join(self.h5_path, fname), 'r') for fname in self.index.keys()} + self.handles = {fname: h5py.File(join(self.h5_path, fname), 'r') for fname in self.index.keys()} def __len__(self): @@ -248,10 +274,7 @@ def __getitem__(self, n): """ # Find the file, tile, and row index corresponding to this chunk - file_name, tile_name, idx = find_index_for_chunk(self.index, self.ranges, n, self.length, self.drop_overlaps) - f = self.handles[file_name][tile_name] - idx_start, idx_end = idx, idx + 1 - + f, idx_start, idx_end = get_file_and_idx(self.eval_lite, self.mode, self.index, self.ranges, n, self.length, self.drop_overlaps, self.handles) # Sentinel-2 bands ------------------------------------------------------------------------ diff --git a/pangaea/engine/evaluator.py b/pangaea/engine/evaluator.py index 172aee7f..f479eee4 100644 --- a/pangaea/engine/evaluator.py +++ b/pangaea/engine/evaluator.py @@ -641,8 +641,7 @@ def evaluate(self, model, model_name='model', model_ckpt_path=None): mse += F.mse_loss(logits, target) torch.distributed.all_reduce(mse, op=torch.distributed.ReduceOp.SUM) - # mse = mse / (len(self.val_loader) * torch.distributed.get_world_size()) - mse = mse / (500 * torch.distributed.get_world_size()) + mse = mse / (len(self.val_loader) * torch.distributed.get_world_size()) metrics = {"MSE": mse.item(), "RMSE": torch.sqrt(mse).item()} self.log_metrics(metrics) diff --git a/pangaea/engine/trainer.py b/pangaea/engine/trainer.py index b5a0ea6f..d22c3e97 100644 --- a/pangaea/engine/trainer.py +++ b/pangaea/engine/trainer.py @@ -5,6 +5,7 @@ import pathlib import time import numpy as np +import math import torch import torch.nn as nn @@ -30,7 +31,7 @@ def __init__( precision: str, use_wandb: bool, ckpt_interval: int, - eval_interval: int, + eval_interval: float, log_interval: int, best_metric_key: str, ): @@ -49,7 +50,7 @@ def __init__( precision (str): precision to train the model (fp32, fp16, bfp16). use_wandb (bool): whether to use wandb for logging. ckpt_interval (int): interval to save the checkpoint. - eval_interval (int): interval to evaluate the model. + eval_interval (float): interval to evaluate the model in epoch units. log_interval (int): interval to log the training information. best_metric_key (str): metric that determines best checkpoints. """ @@ -67,9 +68,11 @@ def __init__( self.device = device self.use_wandb = use_wandb self.ckpt_interval = ckpt_interval - self.eval_interval = eval_interval + self.eval_interval = float(eval_interval) + if self.eval_interval <= 0: raise ValueError("eval_interval must be > 0.") self.log_interval = log_interval self.best_metric_key = best_metric_key + self.intra_epoch_eval_steps = self._compute_intra_epoch_eval_steps() self.training_stats = { name: RunningAverageMeter(length=self.batch_per_epoch) @@ -101,10 +104,8 @@ def train(self) -> None: # end_time = time.time() for epoch in range(self.start_epoch, self.n_epochs): # train the network for one epoch - if epoch % self.eval_interval == 0: - metrics, used_time = self.evaluator(self.model, f"epoch {epoch}") - self.training_stats["eval_time"].update(used_time) - self.save_best_checkpoint(metrics, epoch) + if self.eval_interval >= 1 and epoch % int(self.eval_interval) == 0: + self._run_validation(epoch, f"epoch {epoch}") self.logger.info("============ Starting epoch %i ... ============" % epoch) # set sampler @@ -159,6 +160,13 @@ def train_one_epoch(self, epoch: int) -> None: if (batch_idx + 1) % self.log_interval == 0: self.log(batch_idx + 1, epoch) + if self.eval_interval < 1 and (batch_idx + 1) in self.intra_epoch_eval_steps: + self._run_validation( + epoch, + f"epoch {epoch} | step {batch_idx + 1}/{self.batch_per_epoch}", + ) + self.model.train() + self.lr_scheduler.step() if self.use_wandb and self.rank == 0: @@ -178,6 +186,26 @@ def train_one_epoch(self, epoch: int) -> None: self.training_stats["batch_time"].update(time.time() - end_time) end_time = time.time() + def _compute_intra_epoch_eval_steps(self) -> set[int]: + """Compute batch indices where validation should run within an epoch.""" + + if self.eval_interval >= 1: return set() + + n_eval = int(math.floor(1.0 / self.eval_interval)) + if n_eval == 1 : return set() + + eval_steps = { + max(1, min(self.batch_per_epoch, int(round(self.batch_per_epoch * i * self.eval_interval)))) + for i in range(1, n_eval + 1) + } + return eval_steps + + def _run_validation(self, epoch: int, model_name: str) -> None: + """Run validation, track timing, and update best checkpoint.""" + metrics, used_time = self.evaluator(self.model, model_name) + self.training_stats["eval_time"].update(used_time) + self.save_best_checkpoint(metrics, epoch) + def get_checkpoint(self, epoch: int) -> dict[str, dict | int]: """Create a checkpoint dictionary, containing references to the pytorch tensors. @@ -309,8 +337,11 @@ def log(self, batch_idx: int, epoch) -> None: left_batch_all = ( self.batch_per_epoch * (self.n_epochs - epoch - 1) + left_batch_this_epoch ) - left_eval_times = ((self.n_epochs - 0.5) // self.eval_interval + 2 - - self.training_stats["eval_time"].count) + if self.eval_interval >= 1: + total_eval_times = int((self.n_epochs - 0.5) // self.eval_interval + 2) + else: + total_eval_times = self.n_epochs * len(self.intra_epoch_eval_steps) + 1 + left_eval_times = max(0, total_eval_times - self.training_stats["eval_time"].count) left_time_this_epoch = sec_to_hm( left_batch_this_epoch * self.training_stats["batch_time"].avg ) @@ -369,7 +400,7 @@ def __init__( precision: str, use_wandb: bool, ckpt_interval: int, - eval_interval: int, + eval_interval: float, log_interval: int, best_metric_key: str, multi_label: bool = False, # <-- Flag for multi-label classification, e.g., BigEarthNet dataset @@ -390,7 +421,7 @@ def __init__( precision (str): precision to train the model (fp32, fp16, bfp16). use_wandb (bool): whether to use wandb for logging. ckpt_interval (int): interval to save the checkpoint. - eval_interval (int): interval to evaluate the model. + eval_interval (float): interval to evaluate the model. log_interval (int): interval to log the training information. best_metric_key (str): metric that determines best checkpoints. multi_label (bool): Flag to enable multi-label classification. @@ -547,7 +578,7 @@ def __init__( precision: str, use_wandb: bool, ckpt_interval: int, - eval_interval: int, + eval_interval: float, log_interval: int, best_metric_key: str, ): @@ -565,7 +596,7 @@ def __init__( precision (str): precision to train the model (fp32, fp16, bfp16). use_wandb (bool): whether to use wandb for logging. ckpt_interval (int): interval to save the checkpoint. - eval_interval (int): interval to evaluate the model. + eval_interval (float): interval to evaluate the model. log_interval (int): interval to log the training information. best_metric_key (str): metric that determines best checkpoints. """ @@ -673,7 +704,7 @@ def __init__( precision: str, use_wandb: bool, ckpt_interval: int, - eval_interval: int, + eval_interval: float, log_interval: int, best_metric_key: str, ): @@ -691,7 +722,7 @@ def __init__( precision (str): precision to train the model (fp32, fp16, bfp16). use_wandb (bool): whether to use wandb for logging. ckpt_interval (int): interval to save the checkpoint. - eval_interval (int): interval to evaluate the model. + eval_interval (float): interval to evaluate the model. log_interval (int): interval to log the training information. best_metric_key (str): metric that determines best checkpoints. """ From 8778e3e413252460e238fec928bd9c91d73b652d Mon Sep 17 00:00:00 2001 From: ghjuliasialelli Date: Fri, 20 Mar 2026 15:30:16 +0100 Subject: [PATCH 12/14] updates --- configs/encoder/prithvi2_100m.yaml | 39 + configs/encoder/terramind_base.yaml | 42 + configs/encoder/terramind_optical_tiny.yaml | 39 + configs/encoder/terramind_tiny.yaml | 42 + configs/encoder/thor.yaml | 26 + pangaea/datasets/agbdlite.py | 16 +- pangaea/encoders/prithvi2_encoder.py | 552 +++++++ pangaea/encoders/spectralgpt_encoder.py | 17 +- pangaea/encoders/terramind_encoder.py | 296 ++-- pangaea/encoders/thor_encoder.py | 1528 +++++++++++++++++++ pangaea/run.py | 1 + 11 files changed, 2458 insertions(+), 140 deletions(-) create mode 100644 configs/encoder/prithvi2_100m.yaml create mode 100644 configs/encoder/terramind_base.yaml create mode 100644 configs/encoder/terramind_optical_tiny.yaml create mode 100644 configs/encoder/terramind_tiny.yaml create mode 100644 configs/encoder/thor.yaml create mode 100644 pangaea/encoders/prithvi2_encoder.py create mode 100644 pangaea/encoders/thor_encoder.py diff --git a/configs/encoder/prithvi2_100m.yaml b/configs/encoder/prithvi2_100m.yaml new file mode 100644 index 00000000..6d4d3916 --- /dev/null +++ b/configs/encoder/prithvi2_100m.yaml @@ -0,0 +1,39 @@ +_target_: pangaea.encoders.prithvi2_encoder.Prithvi2_Encoder +encoder_weights: ./pretrained_models/Prithvi/Prithvi-EO-2.0-100M-TL/Prithvi_EO_V2_100M_TL.pt +download_url: False + +# Model architecture (100M parameters) +embed_dim: 768 +input_size: 224 +in_chans: 6 +patch_size: [1, 16, 16] +num_heads: 12 +depth: 12 +mlp_ratio: 4.0 +drop_path: 0.0 + +# Multi-temporal support +num_frames: ${dataset.multi_temporal} + +# Coordinate encoding options +coords_encoding: null # Can be ["time", "location"] if needed +coords_scale_learn: false + +# Input bands configuration +input_bands: + optical: + - B2 + - B3 + - B4 + - B8A + - B11 + - B12 + +# Output configuration +output_layers: + - 3 + - 5 + - 7 + - 11 + +output_dim: 768 \ No newline at end of file diff --git a/configs/encoder/terramind_base.yaml b/configs/encoder/terramind_base.yaml new file mode 100644 index 00000000..82e7c56e --- /dev/null +++ b/configs/encoder/terramind_base.yaml @@ -0,0 +1,42 @@ +_target_: pangaea.encoders.terramind_encoder.terramind_v1_base +encoder_weights: /home/egm/Data/Projects/FMs/pangaea-bench/pretrained_models/TerraMind_v1_base.pt +download_url: #https://drive.google.com/uc?id=1CseO5vvMReGlAulm5o4ZgbjUgj8VlAH7&export=download&confirm=yes +# ckpt_path: /home/vmarsocci/pangaea-bench/pretrained_models/TerraMind_v1_large.pt + +# dim: 768 +input_size: 224 +patch_size: 16 +merge_method: "mean" +# in_chans: 13 +# num_heads: 6 +# depth: 12 +# mlp_ratio: 4 +# multi_temporal: False +modalities: ["S2L2A", "S1GRD"] + +input_bands: + optical: + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B8 + - B8A + # - B9 + - B10 + - B11 + - B12 + sar: + - VV + - VH + +output_layers: + - 3 + - 5 + - 7 + - 11 + +output_dim: 768 \ No newline at end of file diff --git a/configs/encoder/terramind_optical_tiny.yaml b/configs/encoder/terramind_optical_tiny.yaml new file mode 100644 index 00000000..80ef0e89 --- /dev/null +++ b/configs/encoder/terramind_optical_tiny.yaml @@ -0,0 +1,39 @@ +_target_: pangaea.encoders.terramind_encoder.terramind_v1_tiny +encoder_weights: ./pretrained_models/TerraMind_v1_tiny.pt +download_url: #https://drive.google.com/uc?id=1CseO5vvMReGlAulm5o4ZgbjUgj8VlAH7&export=download&confirm=yes +# ckpt_path: /home/vmarsocci/pangaea-bench/pretrained_models/TerraMind_v1_large.pt + +# dim: 768 +input_size: 224 +patch_size: 16 +merge_method: "mean" +# in_chans: 13 +# num_heads: 6 +# depth: 12 +# mlp_ratio: 4 +# multi_temporal: False +modalities: ["S2L2A"] + +input_bands: + optical: + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B8 + - B8A + # - B9 + - B10 + - B11 + - B12 + +output_layers: + - 3 + - 5 + - 7 + - 11 + +output_dim: 192 \ No newline at end of file diff --git a/configs/encoder/terramind_tiny.yaml b/configs/encoder/terramind_tiny.yaml new file mode 100644 index 00000000..68e269ee --- /dev/null +++ b/configs/encoder/terramind_tiny.yaml @@ -0,0 +1,42 @@ +_target_: pangaea.encoders.terramind_encoder.terramind_v1_tiny +encoder_weights: /home/vmarsocci/pangaea-bench/pretrained_models/TerraMind_v1_tiny.pt +download_url: #https://drive.google.com/uc?id=1CseO5vvMReGlAulm5o4ZgbjUgj8VlAH7&export=download&confirm=yes +# ckpt_path: /home/vmarsocci/pangaea-bench/pretrained_models/TerraMind_v1_large.pt + +# dim: 768 +input_size: 224 +patch_size: 16 +merge_method: "mean" +# in_chans: 13 +# num_heads: 6 +# depth: 12 +# mlp_ratio: 4 +# multi_temporal: False +modalities: ["S2L2A", "S1GRD"] + +input_bands: + optical: + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B8 + - B8A + # - B9 + - B10 + - B11 + - B12 + sar: + - VV + - VH + +output_layers: + - 3 + - 5 + - 7 + - 11 + +output_dim: 192 \ No newline at end of file diff --git a/configs/encoder/thor.yaml b/configs/encoder/thor.yaml new file mode 100644 index 00000000..37e0cc9b --- /dev/null +++ b/configs/encoder/thor.yaml @@ -0,0 +1,26 @@ +_target_: pangaea.encoders.thor_encoder.thor_base_encoder +encoder_weights: ./pretrained_models/THOR/thor_base.ckpt +# Other available THOR encoders: thor_tiny_encoder, thor_small_encoder, thor_large_encoder +input_bands: + optical: + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B8 + - B8A + - B11 + - B12 + # sar: + # - VV + # - VH +input_size: 96 +patch_size: 16 +output_aggr: mean +output_layers: + - 2 + - 5 + - 8 + - 11 \ No newline at end of file diff --git a/pangaea/datasets/agbdlite.py b/pangaea/datasets/agbdlite.py index 65b170bb..c8e81660 100644 --- a/pangaea/datasets/agbdlite.py +++ b/pangaea/datasets/agbdlite.py @@ -80,6 +80,15 @@ def __init__( eval_big: bool, lite_chunk_size: int ): + + assert split in ['train', 'val', 'test'], "split must be one of 'train', 'val', or 'test'" + self.mode = split + self.eval_big = eval_big + self.target = target + self.lite_chunk_size = lite_chunk_size + self.patch_size = img_size + self.zenodo_record = "18485030" + super(AGBDLite, self).__init__( split=split, dataset_name=dataset_name, @@ -100,13 +109,6 @@ def __init__( auto_download=auto_download, ) - assert split in ['train', 'val', 'test'], "split must be one of 'train', 'val', or 'test'" - self.mode = split - self.eval_big = eval_big - self.target = target - self.lite_chunk_size = lite_chunk_size - self.patch_size = img_size - self.zenodo_record = "18485030" if auto_download: self.download(self) if os.environ.get('SLURM_SUBMIT_DIR') is not None: print('Running on cluster, using cluster root path.') diff --git a/pangaea/encoders/prithvi2_encoder.py b/pangaea/encoders/prithvi2_encoder.py new file mode 100644 index 00000000..8c3f0757 --- /dev/null +++ b/pangaea/encoders/prithvi2_encoder.py @@ -0,0 +1,552 @@ +# Copyright (c) IBM Corp. 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------- +# References: +# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm +# transformers: https://github.com/huggingface/transformers +# -------------------------------------------------------- + +from logging import Logger +from pathlib import Path + +import warnings +import logging +from anyio import Path +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from timm.layers import to_2tuple +from timm.models.vision_transformer import Block + +from pangaea.encoders.base import Encoder + +logger = logging.getLogger(__name__) + + +def get_3d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False): + """ + Create 3D sin/cos positional embeddings. + + Args: + embed_dim (int): + Embedding dimension. + grid_size (tuple[int, int, int] | list[int]): + The grid depth, height and width. + add_cls_token (bool, *optional*, defaults to False): + Whether or not to add a classification (CLS) token. + + Returns: + (`torch.FloatTensor` of shape (grid_size[0]*grid_size[1]*grid_size[2], embed_dim) or + (1+grid_size[0]*grid_size[1]*grid_size[2], embed_dim): the position embeddings (with or without cls token) + """ + + assert embed_dim % 16 == 0 + + t_size, h_size, w_size = grid_size + + w_embed_dim = embed_dim // 16 * 6 + h_embed_dim = embed_dim // 16 * 6 + t_embed_dim = embed_dim // 16 * 4 + + w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size)) + h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size)) + t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size)) + + w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1)) + h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1)) + t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0) + + pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1) + + if add_cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be even") + + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def _get_1d_sincos_embed_from_grid_torch(embed_dim: int, pos: torch.Tensor): + """ Modified torch version of *get_1d_sincos_pos_embed_from_grid()*. + + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) - must be float dtype! + out: (M, D) + """ + assert embed_dim % 2 == 0 + assert pos.dtype in [torch.float32, torch.float16, torch.bfloat16] + + omega = torch.arange(embed_dim // 2, dtype=pos.dtype).to(pos.device) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + + return emb + + +def _init_weights(module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +def _interpolate_pos_encoding( + pos_embed: torch.Tensor, + grid_size: tuple[int, int, int] | list[int], + patch_size: tuple[int, int, int] | list[int], + shape: tuple[int, int, int], + embed_dim: int, +): + """ + Adapted from: + - transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding, + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194 + """ + t, h, w = shape + t_patches = t // patch_size[0] + h_patches = h // patch_size[1] + w_patches = w // patch_size[2] + + if [t_patches, h_patches, w_patches] == grid_size: + # No interpolation needed + return pos_embed + if t_patches != grid_size[0]: + # Re-compute pos embedding to handle changed num_frames + new_grid_size = (t_patches, *grid_size[1:]) + new_pos_embed = get_3d_sincos_pos_embed(pos_embed.shape[-1], new_grid_size, add_cls_token=True) + new_pos_embed = torch.from_numpy(new_pos_embed).float().unsqueeze(0) + else: + new_grid_size = grid_size + new_pos_embed = pos_embed + + class_pos_embed, patch_pos_embed = new_pos_embed[:, :1], new_pos_embed[:, 1:] + + patch_pos_embed = patch_pos_embed.reshape(*new_grid_size, embed_dim).permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(h_patches, w_patches), + mode='bicubic', + align_corners=True, + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, embed_dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + +class PatchEmbed(nn.Module): + """3D version of timm.models.vision_transformer.PatchEmbed""" + def __init__( + self, + input_size: tuple[int, int, int] = (1, 224, 224), + patch_size: tuple[int, int, int] = (1, 16, 16), + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: nn.Module | None = None, + flatten: bool = True, + bias: bool = True, + ): + super().__init__() + self.input_size = input_size + self.patch_size = patch_size + self.grid_size = [s // p for s, p in zip(self.input_size, self.patch_size)] + assert self.grid_size >= [1, 1, 1], "Patch size is bigger than input size." + self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2] + self.flatten = flatten + + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + B, C, T, H, W = x.shape + + if T / self.patch_size[0] % 1 or H / self.patch_size[1] % 1 or W / self.patch_size[2] % 1: + warnings.warn(f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}." + f"The border will be ignored, add backbone_padding for pixel-wise tasks.") + + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C + x = self.norm(x) + return x + + +class TemporalEncoder(nn.Module): + def __init__(self, embed_dim: int, trainable_scale: bool = False): + super().__init__() + self.embed_dim = embed_dim + self.year_embed_dim = embed_dim // 2 + self.julian_day_embed_dim = embed_dim - self.year_embed_dim + + # If trainable, initialize scale with small number + if trainable_scale: + self.scale = nn.Parameter(torch.full((1,), 0.1)) + else: + self.register_buffer('scale', torch.ones(1)) + + def forward(self, temporal_coords: torch.Tensor, tokens_per_frame: int | None = None): + """ + temporal_coords: year and day-of-year info with shape (B, T, 2). + tokens_per_frame: number of tokens for each frame in the sample. If provided, embeddings will be + repeated over T dimension, and final shape is (B, T*tokens_per_frame, embed_dim). + """ + shape = temporal_coords.shape[:2] + (-1,) # B, T, -1 + + year = _get_1d_sincos_embed_from_grid_torch( + self.year_embed_dim, temporal_coords[:, :, 0].flatten()).reshape(shape) + julian_day = _get_1d_sincos_embed_from_grid_torch( + self.julian_day_embed_dim, temporal_coords[:, :, 1].flatten()).reshape(shape) + + embedding = self.scale * torch.cat([year, julian_day], dim=-1) + + if tokens_per_frame is not None: + embedding = torch.repeat_interleave(embedding, tokens_per_frame, dim=1) + + return embedding # B, T*tokens_per_frame, embed_dim + + +class LocationEncoder(nn.Module): + def __init__(self, embed_dim: int, trainable_scale: bool = False): + super().__init__() + self.embed_dim = embed_dim + self.lat_embed_dim = embed_dim // 2 + self.lon_embed_dim = embed_dim - self.lat_embed_dim + + # If trainable, initialize scale with small number + if trainable_scale: + self.scale = nn.Parameter(torch.full((1,), 0.1)) + else: + self.register_buffer('scale', torch.ones(1)) + + def forward(self, location_coords: torch.Tensor): + """ + location_coords: lat and lon info with shape (B, 2). + """ + shape = location_coords.shape[:1] + (1, -1) # B, 1, -1 + + lat = _get_1d_sincos_embed_from_grid_torch( + self.lat_embed_dim, location_coords[:, 0].flatten()).reshape(shape) + lon = _get_1d_sincos_embed_from_grid_torch( + self.lon_embed_dim, location_coords[:, 1].flatten()).reshape(shape) + + embedding = self.scale * torch.cat([lat, lon], dim=-1) + + return embedding # B, 1, embed_dim + + +class Prithvi2_Encoder(Encoder): + """ Prithvi ViT Encoder""" + def __init__(self, + encoder_weights: str | Path, + download_url: str, + input_bands: dict[str, list[str]], + input_size: int | tuple[int, int] = 224, + output_dim: int | list[int] = 1024, + output_layers: int | list[int] = [5, 11, 17, 23], + patch_size: int | tuple[int, int, int] = (1, 16, 16), + num_frames: int = 1, + in_chans: int = 3, + embed_dim: int = 1024, + depth: int = 24, + num_heads: int = 16, + mlp_ratio: float = 4., + norm_layer: nn.Module = nn.LayerNorm, + coords_encoding: list[str] | None = None, + coords_scale_learn: bool = False, + drop_path: float = 0., + ** kwargs, + ): + super().__init__( + model_name="Prithvi2", + encoder_weights=encoder_weights, + input_bands=input_bands, + input_size=input_size, + embed_dim=embed_dim, + output_layers=output_layers, + output_dim=output_dim, + multi_temporal=True, + multi_temporal_output=True, + pyramid_output=False, + download_url=download_url, + ) + + self.in_chans = in_chans + self.num_frames = num_frames if isinstance(num_frames, int) and num_frames > 0 else 1 + self.embed_dim = embed_dim + self.img_size = [input_size, input_size] if isinstance(input_size, int) else input_size + self.img_size = to_2tuple(self.img_size) + if isinstance(patch_size, int): + patch_size = (1, patch_size, patch_size) + + self.np = self.img_size[-1] // patch_size[-1] + + + # 3D patch embedding + self.patch_embed = PatchEmbed( + input_size=(self.num_frames,) + self.img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) + self.out_channels = [embed_dim * self.patch_embed.grid_size[0]] * depth + + # Optional temporal and location embedding + coords_encoding = coords_encoding or [] + self.temporal_encoding = 'time' in coords_encoding + self.location_encoding = 'location' in coords_encoding + if self.temporal_encoding: + assert patch_size[0] == 1, f"With temporal encoding, patch_size[0] must be 1, received {patch_size[0]}" + self.temporal_embed_enc = TemporalEncoder(embed_dim, coords_scale_learn) + if self.location_encoding: + self.location_embed_enc = LocationEncoder(embed_dim, coords_scale_learn) + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.register_buffer("pos_embed", torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim)) + + # Transformer layers + self.blocks = [] + for i in range(depth): + self.blocks.append(Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, + drop_path=drop_path,)) + self.blocks = nn.ModuleList(self.blocks) + + self.norm = norm_layer(embed_dim) + + self.initialize_weights() + + + def load_encoder_weights(self, logger: Logger) -> None: + pretrained_model = torch.load(self.encoder_weights, map_location="cpu", weights_only=False) + pretrained_model = {key.replace("encoder.", ""): value for key, value in pretrained_model.items()} + k = pretrained_model.keys() + pretrained_encoder = {} + incompatible_shape = {} + missing = {} + for name, param in self.named_parameters(): + if name not in k: + missing[name] = param.shape + elif pretrained_model[name].shape != param.shape: + incompatible_shape[name] = (param.shape, pretrained_model[name].shape) + else: + pretrained_encoder[name] = pretrained_model[name] + # print(name) + # print(k) + + self.load_state_dict(pretrained_encoder, strict=False) + self.parameters_warning(missing, incompatible_shape, logger) + + def initialize_weights(self): + # initialize (and freeze) position embeddings by sin-cos embedding + pos_embed = get_3d_sincos_pos_embed( + self.pos_embed.shape[-1], self.patch_embed.grid_size, add_cls_token=True + ) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # initialize patch_embeddings like nn.Linear (instead of nn.Conv2d) + w = self.patch_embed.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) + torch.nn.init.normal_(self.cls_token, std=0.02) + self.apply(_init_weights) + + def random_masking(self, sequence, mask_ratio, noise=None): + """ + Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random + noise. + + Args: + sequence (`torch.FloatTensor` of shape `(batch_size, sequence_length, dim)`) + mask_ratio (float): mask ratio to use. + noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*) which is + mainly used for testing purposes to control randomness and maintain the reproducibility + """ + batch_size, seq_length, dim = sequence.shape + len_keep = int(seq_length * (1 - mask_ratio)) + + if noise is None: + noise = torch.rand(batch_size, seq_length, device=sequence.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort(noise, dim=1).to(sequence.device) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim)) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([batch_size, seq_length], device=sequence.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return sequence_unmasked, mask, ids_restore + + def interpolate_pos_encoding(self, sample_shape: tuple[int, int, int]): + + pos_embed = _interpolate_pos_encoding( + pos_embed=self.pos_embed, + grid_size=self.patch_embed.grid_size, + patch_size=self.patch_embed.patch_size, + shape=sample_shape, + embed_dim=self.embed_dim, + ) + return pos_embed + + # def forward( + # self, image, + # temporal_coords: None | torch.Tensor = None, + # location_coords: None | torch.Tensor = None, + # mask_ratio=0.75 + # ): + # x = image["optical"] + # if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1: + # # add time dim + # x = x.unsqueeze(2) + # sample_shape = x.shape[-3:] + + # # embed patches + # x = self.patch_embed(x) + + # pos_embed = self.interpolate_pos_encoding(sample_shape) + # # add pos embed w/o cls token + # x = x + pos_embed[:, 1:, :] + + # if self.temporal_encoding and temporal_coords is not None: + # num_tokens_per_frame = x.shape[1] // self.num_frames + # temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame) + # x = x + temporal_encoding + # if self.location_encoding and location_coords is not None: + # location_encoding = self.location_embed_enc(location_coords) + # x = x + location_encoding + + # # masking: length -> length * mask_ratio + # x, mask, ids_restore = self.random_masking(x, mask_ratio) + + # # append cls token + # cls_token = self.cls_token + pos_embed[:, :1, :] + # cls_tokens = cls_token.expand(x.shape[0], -1, -1) + # x = torch.cat((cls_tokens, x), dim=1) + + # # apply Transformer blocks + # for block in self.blocks: + # x = block(x) + # x = self.norm(x) + + # return x, mask, ids_restore + + def forward( + self, + image, + temporal_coords: None | torch.Tensor = None, + location_coords: None | torch.Tensor = None, + ) -> list[torch.Tensor]: + + x = image["optical"] + if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1: + # add time dim + x = x.unsqueeze(2) + sample_shape = x.shape[-3:] + + # embed patches + x = self.patch_embed(x) + + pos_embed = self.interpolate_pos_encoding(sample_shape) + # add pos embed w/o cls token + x = x + pos_embed[:, 1:, :] + + if self.temporal_encoding and temporal_coords is not None: + num_tokens_per_frame = x.shape[1] // self.num_frames + temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame) + x = x + temporal_encoding + if self.location_encoding and location_coords is not None: + location_encoding = self.location_embed_enc(location_coords) + x = x + location_encoding + + # append cls token + cls_token = self.cls_token + pos_embed[:, :1, :] + cls_tokens = cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # apply Transformer blocks + output = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in self.output_layers: + out = ( + x[:, 1:, :] + .permute(0, 2, 1) + .view( + x.shape[0], + -1, + self.num_frames, + self.np, + self.np, + ) + .squeeze(2) + .contiguous() + ) + output.append(out) + + return output + + def prepare_features_for_image_model(self, features: list[torch.Tensor]) -> list[torch.Tensor]: + out = [] + effective_time_dim = self.patch_embed.input_size[0] // self.patch_embed.patch_size[0] + for x in features: + x_no_token = x[:, 1:, :] + number_of_tokens = x_no_token.shape[1] + tokens_per_timestep = number_of_tokens // effective_time_dim + h = int(np.sqrt(tokens_per_timestep)) + encoded = rearrange( + x_no_token, + "batch (t h w) e -> batch (t e) h w", + e=self.embed_dim, + t=effective_time_dim, + h=h, + ) + out.append(encoded) + return out \ No newline at end of file diff --git a/pangaea/encoders/spectralgpt_encoder.py b/pangaea/encoders/spectralgpt_encoder.py index dfc8962c..6d2be05b 100644 --- a/pangaea/encoders/spectralgpt_encoder.py +++ b/pangaea/encoders/spectralgpt_encoder.py @@ -165,10 +165,21 @@ def load_encoder_weights(self, logger: Logger) -> None: self.parameters_warning(missing, incompatible_shape, logger) def forward(self, image: dict[str, torch.Tensor]) -> list[torch.Tensor]: - # input image of shape B C H W - x = image["optical"].unsqueeze(-3) # B C H W -> B C 1 H W + x = image["optical"] + if x.dim() == 4: + x = x.unsqueeze(-3) # B C H W -> B C 1 H W + elif x.dim() != 5: + raise ValueError( + f"SpectralGPT expects optical input with 4 or 5 dimensions, got shape {tuple(x.shape)}" + ) + + if x.shape[2] != 1: + raise ValueError( + "SpectralGPT is configured as a single-temporal encoder and expects input shaped (B, C, 1, H, W). " + f"Got {tuple(x.shape)}" + ) - x = x.permute(0, 2, 1, 3, 4) # for this model: B, T, C, H, W + x = x.permute(0, 2, 1, 3, 4) # B C 1 H W -> B 1 C H W x = self.patch_embed(x) N, T, L, C = x.shape # T: number of bands; L: spatial diff --git a/pangaea/encoders/terramind_encoder.py b/pangaea/encoders/terramind_encoder.py index eb5636a2..b79e1687 100644 --- a/pangaea/encoders/terramind_encoder.py +++ b/pangaea/encoders/terramind_encoder.py @@ -8,7 +8,7 @@ from abc import ABC, abstractmethod from huggingface_hub import hf_hub_download from PIL import Image -import albumentations as A +# import albumentations as A import numpy as np import torch import torch.nn.functional as F @@ -890,135 +890,135 @@ def postprocess(self, sample): return sample -class DetectionTransform(AbstractTransform): - - def __init__(self, det_threshold=0.6, det_max_instances=None, bbox_order='dist_to_orig', coord_bins=1000, - min_visibility=0.0, return_raw=False): - self.det_threshold = det_threshold - self.det_max_instances = det_max_instances - self.coord_bins = coord_bins - self.min_visibility = min_visibility - self.return_raw = return_raw - - if bbox_order == 'area': - self.bbox_order = self.order_bboxes_by_area - elif bbox_order == 'score': - self.bbox_order = self.order_bboxes_by_score - elif bbox_order == 'random': - self.bbox_order = self.shuffle_bboxes - else: - self.bbox_order = self.order_bboxes_by_dist_to_orig - - @staticmethod - def order_bboxes_by_area(bboxes): - return sorted(bboxes, key=lambda x: (x[2] - x[0]) * (x[3] - x[1]), reverse=True) - - @staticmethod - def order_bboxes_by_dist_to_orig(bboxes): - return sorted(bboxes, key=lambda x: x[0] ** 2 + x[1] ** 2) - - @staticmethod - def order_bboxes_by_score(bboxes): - return sorted(bboxes, key=lambda x: x[5], reverse=True) - - @staticmethod - def shuffle_bboxes(bboxes): - return sorted(bboxes, key=lambda x: random.random()) - - def convert_detection_instance(self, instances): - """Convert instances dict to list of lists where each list takes the form: - [xmin, ymin, xmax, ymax, class_name, score] - """ - - instances = [inst['boxes'] + [inst['class_name'], inst['score']] for inst in instances if - inst['score'] >= self.det_threshold] - return instances - - def bboxes_hflip(self, bboxes: List[Tuple], image_size: Tuple, flip: bool): - image_height, image_width = image_size - if flip: - bboxes = [tuple(A.bbox_hflip(bbox[:4], rows=image_height, cols=image_width)) + tuple(bbox[4:]) - for bbox in bboxes] - - return bboxes - - def bboxes_crop_and_resize(self, bboxes: List[Tuple], crop_coords: Tuple, orig_size: Tuple): - """Crop and resize bounding boxes - - Args: - bboxes: Bounding boxes to crop and resize - crop_coords: Coordinates of the crop (top, left, h, w) - orig_size: Size of the original image - - Returns: - Cropped and resized bounding boxes - """ - orig_height, orig_width = orig_size - top, left, h, w = crop_coords - xmin, ymin, xmax, ymax = left, top, left + w, top + h - bboxes = [tuple(A.bbox_crop(bbox[:4], x_min=xmin, y_min=ymin, x_max=xmax, y_max=ymax, rows=orig_height, - cols=orig_width)) + tuple(bbox[4:]) - for bbox in bboxes] - bboxes = A.core.bbox_utils.filter_bboxes(bboxes, rows=h, cols=w, min_visibility=self.min_visibility) - # No need to resize, bounding boxes in albumentations format are scale invariant - - return bboxes - - def order_and_filter_bboxes(self, bboxes): - if self.det_max_instances is not None and len(bboxes) > self.det_max_instances: - bboxes = self.order_bboxes_by_score(bboxes)[:self.det_max_instances] - - return self.bbox_order(bboxes) - - def convert_bboxes_to_string(self, bboxes: List[Tuple]): - """Convert bounding boxes to a string - - Args: - bboxes: Bounding boxes - - Returns: - String representation of the bounding boxes - """ - # Remove score, quantize coordinates - bins = self.coord_bins - - bboxes = [ - [ - f"xmin={round(xmin * (bins - 1))}", - f"ymin={round(ymin * (bins - 1))}", - f"xmax={round(xmax * (bins - 1))}", - f"ymax={round(ymax * (bins - 1))}", - cls, - ] - for (xmin, ymin, xmax, ymax, cls, score) in bboxes - ] - # Convert each bounding box to a string - bboxes = [' '.join(b) for b in bboxes] - # Convert the list to a str - return ' '.join(bboxes) - - def load(self, path): - with open(path, 'r') as f: - sample = json.load(f) - - return sample - - def preprocess(self, sample): - instances = sample['instances'] - return self.convert_detection_instance(instances) - - def image_augment(self, bboxes: List[Tuple], crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, - rand_aug_idx=None, resample_mode: str = None): - bboxes = self.bboxes_crop_and_resize(bboxes, crop_coords, orig_size) - bboxes = self.bboxes_hflip(bboxes, target_size, flip) - bboxes = self.order_and_filter_bboxes(bboxes) - return bboxes - - def postprocess(self, bboxes): - if self.return_raw: - return bboxes - bboxes = self.convert_bboxes_to_string(bboxes) - return bboxes +# class DetectionTransform(AbstractTransform): + +# def __init__(self, det_threshold=0.6, det_max_instances=None, bbox_order='dist_to_orig', coord_bins=1000, +# min_visibility=0.0, return_raw=False): +# self.det_threshold = det_threshold +# self.det_max_instances = det_max_instances +# self.coord_bins = coord_bins +# self.min_visibility = min_visibility +# self.return_raw = return_raw + +# if bbox_order == 'area': +# self.bbox_order = self.order_bboxes_by_area +# elif bbox_order == 'score': +# self.bbox_order = self.order_bboxes_by_score +# elif bbox_order == 'random': +# self.bbox_order = self.shuffle_bboxes +# else: +# self.bbox_order = self.order_bboxes_by_dist_to_orig + +# @staticmethod +# def order_bboxes_by_area(bboxes): +# return sorted(bboxes, key=lambda x: (x[2] - x[0]) * (x[3] - x[1]), reverse=True) + +# @staticmethod +# def order_bboxes_by_dist_to_orig(bboxes): +# return sorted(bboxes, key=lambda x: x[0] ** 2 + x[1] ** 2) + +# @staticmethod +# def order_bboxes_by_score(bboxes): +# return sorted(bboxes, key=lambda x: x[5], reverse=True) + +# @staticmethod +# def shuffle_bboxes(bboxes): +# return sorted(bboxes, key=lambda x: random.random()) + +# def convert_detection_instance(self, instances): +# """Convert instances dict to list of lists where each list takes the form: +# [xmin, ymin, xmax, ymax, class_name, score] +# """ + +# instances = [inst['boxes'] + [inst['class_name'], inst['score']] for inst in instances if +# inst['score'] >= self.det_threshold] +# return instances + +# def bboxes_hflip(self, bboxes: List[Tuple], image_size: Tuple, flip: bool): +# image_height, image_width = image_size +# if flip: +# bboxes = [tuple(A.bbox_hflip(bbox[:4], rows=image_height, cols=image_width)) + tuple(bbox[4:]) +# for bbox in bboxes] + +# return bboxes + +# def bboxes_crop_and_resize(self, bboxes: List[Tuple], crop_coords: Tuple, orig_size: Tuple): +# """Crop and resize bounding boxes + +# Args: +# bboxes: Bounding boxes to crop and resize +# crop_coords: Coordinates of the crop (top, left, h, w) +# orig_size: Size of the original image + +# Returns: +# Cropped and resized bounding boxes +# """ +# orig_height, orig_width = orig_size +# top, left, h, w = crop_coords +# xmin, ymin, xmax, ymax = left, top, left + w, top + h +# bboxes = [tuple(A.bbox_crop(bbox[:4], x_min=xmin, y_min=ymin, x_max=xmax, y_max=ymax, rows=orig_height, +# cols=orig_width)) + tuple(bbox[4:]) +# for bbox in bboxes] +# bboxes = A.core.bbox_utils.filter_bboxes(bboxes, rows=h, cols=w, min_visibility=self.min_visibility) +# # No need to resize, bounding boxes in albumentations format are scale invariant + +# return bboxes + +# def order_and_filter_bboxes(self, bboxes): +# if self.det_max_instances is not None and len(bboxes) > self.det_max_instances: +# bboxes = self.order_bboxes_by_score(bboxes)[:self.det_max_instances] + +# return self.bbox_order(bboxes) + +# def convert_bboxes_to_string(self, bboxes: List[Tuple]): +# """Convert bounding boxes to a string + +# Args: +# bboxes: Bounding boxes + +# Returns: +# String representation of the bounding boxes +# """ +# # Remove score, quantize coordinates +# bins = self.coord_bins + +# bboxes = [ +# [ +# f"xmin={round(xmin * (bins - 1))}", +# f"ymin={round(ymin * (bins - 1))}", +# f"xmax={round(xmax * (bins - 1))}", +# f"ymax={round(ymax * (bins - 1))}", +# cls, +# ] +# for (xmin, ymin, xmax, ymax, cls, score) in bboxes +# ] +# # Convert each bounding box to a string +# bboxes = [' '.join(b) for b in bboxes] +# # Convert the list to a str +# return ' '.join(bboxes) + +# def load(self, path): +# with open(path, 'r') as f: +# sample = json.load(f) + +# return sample + +# def preprocess(self, sample): +# instances = sample['instances'] +# return self.convert_detection_instance(instances) + +# def image_augment(self, bboxes: List[Tuple], crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, +# rand_aug_idx=None, resample_mode: str = None): +# bboxes = self.bboxes_crop_and_resize(bboxes, crop_coords, orig_size) +# bboxes = self.bboxes_hflip(bboxes, target_size, flip) +# bboxes = self.order_and_filter_bboxes(bboxes) +# return bboxes + +# def postprocess(self, bboxes): +# if self.return_raw: +# return bboxes +# bboxes = self.convert_bboxes_to_string(bboxes) +# return bboxes class CaptionTransform(AbstractTransform): @@ -2751,6 +2751,42 @@ def build_terrammind_vit( return model +def terramind_v1_tiny(**kwargs): + model = build_terrammind_vit( + variant="terramind_v1_tiny", + encoder_depth=12, + # decoder_depth=4, + dim=192, + num_heads=3, + mlp_ratio=4, + qkv_bias=True, + proj_bias=True, + mlp_bias=True, + norm_layer=partial(LayerNorm, eps=1e-6, bias=False), + act_layer=nn.GELU, + gated_mlp=False, + pretrained_bands=PRETRAINED_BANDS, + **kwargs + ) + return model + +def terramind_v1_small(**kwargs): + model = build_terrammind_vit( + variant="terramind_v1_small", + encoder_depth=12, + dim=384, + num_heads=6, + mlp_ratio=4, + qkv_bias=True, + proj_bias=True, + mlp_bias=True, + norm_layer=partial(LayerNorm, eps=1e-6, bias=False), + act_layer=nn.GELU, + gated_mlp=False, + pretrained_bands=PRETRAINED_BANDS, + **kwargs + ) + return model # @TERRATORCH_BACKBONE_REGISTRY.register def terramind_v1_base(**kwargs): diff --git a/pangaea/encoders/thor_encoder.py b/pangaea/encoders/thor_encoder.py new file mode 100644 index 00000000..ee56d5b7 --- /dev/null +++ b/pangaea/encoders/thor_encoder.py @@ -0,0 +1,1528 @@ +import math +import warnings +import logging +from collections.abc import Iterable, Sequence +from functools import partial +from itertools import repeat +from logging import Logger +from pathlib import Path +from typing import Any, Final, Optional + +import numpy as np +import torch +import torch.nn.functional as F +from einops import rearrange +from timm.layers import ( + DropPath, + Mlp, + use_fused_attn, +) +from timm.models.vision_transformer import LayerScale +from torch import Tensor, nn, vmap + +from pangaea.encoders.base import Encoder + +logging.basicConfig( + level=logging.INFO, +) +logger = logging.getLogger(__name__) + + +pangaea_to_thor_band_map = { + "B1": "S2:CoastAerosal", + "B2": "S2:Blue", + "B3": "S2:Green", + "B4": "S2:Red", + "B5": "S2:RE1", + "B6": "S2:RE2", + "B7": "S2:RE3", + "B8": "S2:NIR", + "B8A": "S2:RE4", + "B9": "S2:WaterVapor", + # "B10 # SWIR - Cirrus (60m) # We don't have this band + "B11": "S2:SWIR1", + "B12": "S2:SWIR2", + "VV": "S1:IW-VV", + "VH": "S1:IW-VH", + "ASC_VV": "S1:IW-VV", + "ASC_VH": "S1:IW-VH", + "DSC_VV": "S1:IW-VV", + "DSC_VH": "S1:IW-VH", +} + + +DEFAULT_GROUPS = [ + ["S2:Red", "S2:Green", "S2:Blue", "S2:NIR"], + ["S2:RE1", "S2:RE2", "S2:RE3", "S2:RE4", "S2:SWIR1", "S2:SWIR2"], + ["S2:CoastAerosal", "S2:WaterVapor"], + ["S1:IW-VH", "S1:IW-VV", "S1:EW-VH", "S1:EW-VV"], + ["S1:IW-HV", "S1:IW-HH", "S1:EW-HV", "S1:EW-HH"], + [ + "S3:Oa01_reflectance", + "S3:Oa02_reflectance", + "S3:Oa03_reflectance", + "S3:Oa04_reflectance", + "S3:Oa05_reflectance", + "S3:Oa06_reflectance", + "S3:Oa07_reflectance", + ], + [ + "S3:Oa08_reflectance", + "S3:Oa09_reflectance", + "S3:Oa10_reflectance", + "S3:Oa11_reflectance", + "S3:Oa12_reflectance", + "S3:Oa13_reflectance", + "S3:Oa14_reflectance", + ], + [ + "S3:Oa15_reflectance", + "S3:Oa16_reflectance", + "S3:Oa17_reflectance", + "S3:Oa18_reflectance", + "S3:Oa19_reflectance", + "S3:Oa20_reflectance", + "S3:Oa21_reflectance", + ], + [ + "S3:S1_reflectance_an", + "S3:S2_reflectance_an", + "S3:S3_reflectance_an", + "S3:S4_reflectance_an", + "S3:S5_reflectance_an", + "S3:S6_reflectance_an", + ], + ["S3:S7_BT_in", "S3:S8_BT_in", "S3:S9_BT_in"], +] + +DEFAULT_CHANNELS = { + "S2:Red": { + "GSD": 10, + "patch_size": 16, # px + }, + "S2:Green": { + "GSD": 10, + "patch_size": 16, # px + }, + "S2:Blue": { + "GSD": 10, + "patch_size": 16, # px + }, + "S2:NIR": { + "GSD": 10, + "patch_size": 16, # px + }, + "S2:RE1": { + "GSD": 20, + "patch_size": 16, # px + }, + "S2:RE2": { + "GSD": 20, + "patch_size": 16, # px + }, + "S2:RE3": { + "GSD": 20, + "patch_size": 16, # px + }, + "S2:RE4": { + "GSD": 20, + "patch_size": 16, # px + }, + "S2:SWIR1": { + "GSD": 20, + "patch_size": 16, # px + }, + "S2:SWIR2": { + "GSD": 20, + "patch_size": 16, # px + }, + "S2:CoastAerosal": { + "GSD": 60, + "patch_size": 16, # px + }, + "S2:WaterVapor": { + "GSD": 60, + "patch_size": 16, # px + }, + "S1:IW-VV": { + "GSD": 10, + "patch_size": 16, # px + "patch_embed_name": "S1:VV", + }, + "S1:IW-VH": { + "GSD": 10, + "patch_size": 16, # px + "patch_embed_name": "S1:VH", + }, + "S1:IW-HV": { + "GSD": 10, + "patch_size": 16, # px + "patch_embed_name": "S1:HV", + }, + "S1:IW-HH": { + "GSD": 10, + "patch_size": 16, # px + "patch_embed_name": "S1:HH", + }, + "S1:EW-VV": { + "GSD": 10, + "patch_size": 16, # px + "patch_embed_name": "S1:VV", + }, + "S1:EW-VH": { + "GSD": 10, + "patch_size": 16, # px + "patch_embed_name": "S1:VH", + }, + "S1:EW-HV": { + "GSD": 10, + "patch_size": 16, # px + "patch_embed_name": "S1:HV", + }, + "S1:EW-HH": { + "GSD": 10, + "patch_size": 16, # px + "patch_embed_name": "S1:HH", + }, + "S3:Oa01_reflectance": { + "GSD": 240, # GSD 240 interp + "patch_size": 16, # px + }, + "S3:Oa02_reflectance": { + "GSD": 240, # GSD 240 interp + "patch_size": 16, # px + }, + "S3:Oa03_reflectance": { + "GSD": 240, # GSD 240 interp + "patch_size": 16, # px + }, + "S3:Oa04_reflectance": { + "GSD": 240, # GSD 240 interp + "patch_size": 16, # px + }, + "S3:Oa05_reflectance": { + "GSD": 240, # GSD 240 interp + "patch_size": 16, # px + }, + "S3:Oa06_reflectance": { + "GSD": 240, # GSD 240 interp + "patch_size": 16, # px + }, + "S3:Oa07_reflectance": { + "GSD": 240, # GSD 240 interp + "patch_size": 16, # px + }, + "S3:Oa08_reflectance": { + "GSD": 240, # GSD 240 interp + "patch_size": 16, # px + }, + "S3:Oa09_reflectance": { + "GSD": 240, # GSD 240 interp + "patch_size": 16, # px + }, + "S3:Oa10_reflectance": { + "GSD": 240, # GSD 240 interp + "patch_size": 16, # px + }, + "S3:Oa11_reflectance": { + "GSD": 240, # GSD 240 interp + "patch_size": 16, # px + }, + "S3:Oa12_reflectance": { + "GSD": 240, # GSD 240 interp + "patch_size": 16, # px + }, + "S3:Oa13_reflectance": { + "GSD": 240, # GSD 240 interp + "patch_size": 16, # px + }, + "S3:Oa14_reflectance": { + "GSD": 240, # GSD 240 interp + "patch_size": 16, # px + }, + "S3:Oa15_reflectance": { + "GSD": 240, # GSD 240 interp + "patch_size": 16, # px + }, + "S3:Oa16_reflectance": { + "GSD": 240, # GSD 240 interp + "patch_size": 16, # px + }, + "S3:Oa17_reflectance": { + "GSD": 240, # GSD 240 interp + "patch_size": 16, # px + }, + "S3:Oa18_reflectance": { + "GSD": 240, # GSD 240 interp + "patch_size": 16, # px + }, + "S3:Oa19_reflectance": { + "GSD": 240, # GSD 240 interp + "patch_size": 16, # px + }, + "S3:Oa20_reflectance": { + "GSD": 240, # GSD 240 interp + "patch_size": 16, # px + }, + "S3:Oa21_reflectance": { + "GSD": 240, # GSD 240 interp + "patch_size": 16, # px + }, + "S3:S1_reflectance_an": { + "GSD": 480, # GSD 480 interp + "patch_size": 16, # px + }, + "S3:S2_reflectance_an": { + "GSD": 480, # GSD 480 interp + "patch_size": 16, # px + }, + "S3:S3_reflectance_an": { + "GSD": 480, # GSD 480 interp + "patch_size": 16, # px + }, + "S3:S4_reflectance_an": { + "GSD": 480, # GSD 480 interp + "patch_size": 16, # px + }, + "S3:S5_reflectance_an": { + "GSD": 480, # GSD 480 interp + "patch_size": 16, # px + }, + "S3:S6_reflectance_an": { + "GSD": 480, # GSD 480 interp + "patch_size": 16, # px + }, + "S3:S7_BT_in": { + "GSD": 960, # GSD 960 interp + "patch_size": 16, # px + }, + "S3:S8_BT_in": { + "GSD": 960, # GSD 960 interp + "patch_size": 16, # px + }, + "S3:S9_BT_in": { + "GSD": 960, # GSD 960 interp + "patch_size": 16, # px + }, +} + + +def to_2tuple(x: Any) -> tuple[int, int]: + if isinstance(x, Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, 2)) + + +class IndFlexiPatchEmbed(nn.Module): + def __init__( + self, + ground_covers: list[int], + channels: dict[str, dict[str, int]], + channel_rename_map: dict[str, str] | None = None, + embed_dim: int = 768, + norm_layer: nn.Module | None = None, + flatten: bool = True, + bias: bool = True, + patch_size_seqs: dict[str, Sequence[int]] | Sequence[int] = ( + 4, + 6, + 8, + 10, + 12, + 16, + ), + interpolation: str = "bicubic", + antialias: bool = True, + ) -> None: + """2D image to patch embedding w/ flexible patch sizes, for multiple product bands + Extended from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/patch_embed.py#L24 + + Args: + ground_cover: Ground cover size in meters + channels: Dictionary of product bands and their parameters (GSD, num_patch) + channel_rename_map: Dictionary of channel names to rename, to i.e., use same parameters for different bands + embed_dim: Network embedding dimension size + norm_layer: Optional normalization layer + flatten: Whether to flatten the spatial dimensions of the output + bias: Whether to use bias in convolution + patch_size_seqs: Dict of List of patch sizes for each band or list of patch sizes for all bands to + randomly sample from, unvalidated patch sizes are dropped + patch_size_probs: Optional Dict of list of probabilities for each band or list of probabilities for all + bands to sample corresponding patch_size_seqs elements. If None, then uniform distribution is used + interpolation: Resize interpolation type + antialias: Whether to apply antialiasing resizing + """ + super().__init__() + self.interpolation = interpolation + self.antialias = antialias + self.flatten = flatten + self.channels = channels + self.channel_rename_map = channel_rename_map + self.embed_dim = embed_dim + self.patch_sizes = {} + proj_dict = {} + for product_band, params in channels.items(): + kernel_size = params["patch_size"] + self.patch_sizes[product_band] = to_2tuple(kernel_size) + if channel_rename_map and product_band in channel_rename_map: + product_band = channel_rename_map[product_band] + if product_band in proj_dict: + logger.info(f"Product band {product_band} already added, skipping") + continue + proj_dict[product_band] = nn.Conv2d( + 1, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=bias + ) + self.patch_embed = nn.ModuleDict(proj_dict) + + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + if not isinstance(patch_size_seqs, dict): + patch_size_seqs = { + product_band: patch_size_seqs for product_band in channels + } + + # filter valid patch size seqs + for product_band in patch_size_seqs.keys(): + _patch_size_seq = [] + for patch_size in patch_size_seqs[product_band]: + product_gsd = channels[product_band]["GSD"] + for ground_cover in ground_covers: + product_num_patch = ground_cover // patch_size // product_gsd + if patch_size * product_num_patch * product_gsd != ground_cover: + # Skip patch sizes that don't add up for the given ground cover + continue + if patch_size not in _patch_size_seq: + _patch_size_seq.append(patch_size) + + if len(_patch_size_seq) == 0: + msg = ( + f"No valid patch sizes for {product_band} for ground cover {ground_covers} and GSD {product_gsd}" + f" with patch size seq {patch_size_seqs[product_band]}" + ) + logger.warning(msg) + continue + + logger.info(f"product_band: {product_band}, patch_size_seq: {_patch_size_seq}") + patch_size_seqs[product_band] = sorted(_patch_size_seq) + + self.patch_size_seqs = patch_size_seqs + self.pinvs = {} + + def _resize(self, x: Tensor, shape: tuple[int, int]) -> Tensor: + x_resized = F.interpolate( + x[None, None, ...], + shape, + mode=self.interpolation, + antialias=self.antialias, + ) + return x_resized[0, 0, ...] + + def _calculate_pinv( + self, old_shape: tuple[int, int], new_shape: tuple[int, int] + ) -> Tensor: + mat = [] + for i in range(np.prod(old_shape)): + basis_vec = torch.zeros(old_shape) + basis_vec[np.unravel_index(i, old_shape)] = 1.0 + mat.append(self._resize(basis_vec, new_shape).reshape(-1)) + resize_matrix = torch.stack(mat) + return torch.linalg.pinv(resize_matrix) + + def resize_patch_embed( + self, + patch_embed: Tensor, + patch_size: tuple[int, int], + new_patch_size: tuple[int, int], + ): + """Resize patch_embed to target resolution via pseudo-inverse resizing""" + # Return original kernel if no resize is necessary + if patch_size == new_patch_size: + return patch_embed + + # Calculate pseudo-inverse of resize matrix + if patch_size not in self.pinvs or new_patch_size not in self.pinvs[patch_size]: + if patch_size not in self.pinvs: + self.pinvs[patch_size] = {} + self.pinvs[patch_size][new_patch_size] = self._calculate_pinv( + patch_size, new_patch_size + ) + pinv = self.pinvs[patch_size][new_patch_size] + pinv = pinv.to(patch_embed.device) + + def resample_patch_embed(patch_embed: Tensor): + h, w = new_patch_size + resampled_kernel = pinv @ patch_embed.reshape(-1) + return rearrange(resampled_kernel, "(h w) -> h w", h=h, w=w) + + v_resample_patch_embed = vmap(vmap(resample_patch_embed, 0, 0), 1, 1) + + return v_resample_patch_embed(patch_embed) + + def forward( + self, + x: dict[str, Tensor], + patch_sizes: dict[str, tuple[int, int]] | None = None, + ) -> Tensor | tuple[Tensor, tuple[int, int]]: + if patch_sizes is None: + # During evaluation use base patch sizes if not specified + patch_sizes = self.patch_sizes + + patch_embed_dict = {} + for product_band, data in x.items(): + if product_band not in patch_sizes: + logger.info(f"Skipping product band: {product_band}") + continue + + patch_size = patch_sizes[product_band] + patch_size = to_2tuple(patch_size) + + patch_embed_name = ( + self.channel_rename_map[product_band] + if self.channel_rename_map + else product_band + ) + + # Resize conv weights + if patch_size == self.patch_sizes[product_band]: + weight = self.patch_embed[patch_embed_name].weight + else: + weight = self.resize_patch_embed( + self.patch_embed[patch_embed_name].weight, + self.patch_sizes[product_band], + patch_size, + ) + + # Apply conv with resized weights + data = F.conv2d( + input=data, + weight=weight, + bias=self.patch_embed[patch_embed_name].bias, + stride=patch_size, + ) + + if self.flatten: + data = data.flatten(2).transpose(1, 2) # BCHW -> BNC + + data = self.norm(data) + patch_embed_dict[product_band] = data + + return patch_embed_dict + + +def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos): + """ + Generate 1D sincos positional embedding. + Args: + embed_dim: Output dimension for each position + pos: A list of positions to be encoded: size (M,) + Returns: + Positional embedding: (M, D) + """ + assert embed_dim % 2 == 0 + omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + return emb.double() + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +# Copied from timm.models.vision_transformer +# and adapted to support alibi +# furthermore, we use the flexi attention implementation if available +class Attention(nn.Module): + fused_attn: Final[bool] + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.fused_attn = use_fused_attn() + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward( + self, x: torch.Tensor, alibi: torch.Tensor | None = None + ) -> torch.Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, self.head_dim) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=alibi, + dropout_p=self.attn_drop.p if self.training else 0.0, + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + if alibi is not None: + attn = attn + alibi + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_norm: bool = False, + proj_drop: float = 0.0, + attn_drop: float = 0.0, + init_values: Optional[float] = None, + drop_path: float = 0.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + mlp_layer: nn.Module = Mlp, + ) -> None: + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + ) + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = mlp_layer( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ) + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward( + self, x: torch.Tensor, alibi: torch.Tensor | None = None + ) -> torch.Tensor: + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), alibi))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +def get_slopes(n): + """ + Get slopes for attention bias calculation in alibi attention. + Args: + n: Number of attention heads + Returns: + List of slopes for each attention head + """ + + def get_slopes_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) + + +@torch.jit.script +def get_alibi_thor( + metadata: dict[str, dict[str, int]], + available_groups: dict[str, list[str]], + slopes: torch.Tensor, + offset: float | int = 0.0, + device: torch.device | None = None, + dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + 2D Alibi implementation using Euclidean distance between patches + Args: + metadata: Metadata of the input data with GSD, patch_size, and num_patch information + available_groups: Available groups of the input data + slopes: Slopes for attention heads, used to scale the distance, should be a tensor of shape (num_heads,) + offset: Offset to add to the distance, useful for cls token + device: Device for computation + dtype: Data type for computation + Returns: + distances: Alibi tensor (batch_size, num_heads, num_patches, num_patches), + where num_patches is the total number of patches for all groups + """ + num_patches = 0 + all_points = [] + max_patch_gsd_size = 0 + + for _group_name, group_members in available_groups.items(): + first_member = group_members[0] + product_gsd = metadata[first_member]["GSD"] + product_num_patch = metadata[first_member]["num_patch"] + product_patch_size = metadata[first_member]["patch_size"] + num_patches += int(product_num_patch**2) + max_patch_gsd_size = max(max_patch_gsd_size, product_patch_size * product_gsd) + + line_of_points = torch.arange(0, product_num_patch, dtype=dtype, device=device) + line_of_points *= product_patch_size + line_of_points += product_patch_size / 2 + line_of_points *= product_gsd + + points = torch.cartesian_prod(line_of_points, line_of_points) + all_points.append(points) + + points = torch.cat(all_points, dim=0) + + # Normalize points by largest GSD + points = points / max_patch_gsd_size + + attention_heads = slopes.shape[0] + slopes = slopes.unsqueeze(1).unsqueeze(2) + distances = torch.cdist(points, points) + distances += float(offset) + distances = distances.unsqueeze(0) + distances = distances * slopes * -1 + distances = distances.view(-1, attention_heads, num_patches, num_patches) + + return distances + + +class THOR_Encoder(Encoder): + def __init__( + self, + encoder_weights: str | Path, + input_size: int, + patch_size: int, + input_bands: dict[str, list[str]], + output_layers: list[int], + download_url: str = "", + ground_cover: int | None = None, + output_aggr: int | str = "concat", + #### Default parameters vit 'base_encoder_alibi_patch_embed' ### + ref_patch_size: int = 4, + patch_size_seqs: dict[str, Sequence[int]] | Sequence[int] | None = None, + channels: dict[str, dict[str, int]] = DEFAULT_CHANNELS, + groups: list[list] = DEFAULT_GROUPS, + aggr_type: str = "subsetmean", + ### Base ViT parameters ########################################## + embed_dim=768, + depth=12, + num_heads=12, + embed_band=True, + band_embed_dim=128, + embed_prod=False, + prod_embed_dim=0, + embed_patch_size=True, + pad_prod_embed_null=False, + pad_band_embed_null=False, + mlp_ratio=4.0, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + ) -> None: + if ground_cover is None: + ground_cover = input_size * 10 + logger.info( + f"Assuming ground_cover is {ground_cover} m based on input_size {input_size} px" + ) + self.ground_cover = ground_cover + + bands = [ + pangaea_to_thor_band_map[band] + for modality_bands in input_bands.values() + for band in modality_bands + ] + self.bands = bands + logger.info(f"Bands: {self.bands}") + + self.aggr_type = aggr_type + if patch_size < ref_patch_size: + msg = ( + f"The model was trained with a ref_patch_size of {ref_patch_size}, but the input patch size is {patch_size}." + f" Overriding the ref_patch_size to {patch_size}. This may lead to suboptimal performance." + ) + warnings.warn(msg) + ref_patch_size = patch_size + self.ref_patch_size = ref_patch_size + if patch_size_seqs is None: + patch_size_seqs = [patch_size] + + # Backwards compat channels + self.channels = {} + self.channel_rename_map = {} + for channel, params in channels.items(): + if "num_patch" in params and "patch_size" not in params: + params["patch_size"] = ( + self.ground_cover // params["num_patch"] // params["GSD"] + ) + if "patch_size" in params and "num_patch" not in params: + if isinstance(patch_size_seqs, dict): + min_patch_size_seq = min(patch_size_seqs[channel]) + elif isinstance(patch_size_seqs, Sequence): + min_patch_size_seq = min(patch_size_seqs) + patch_size = min(params["patch_size"], min_patch_size_seq) + + params["num_patch"] = self.ground_cover // patch_size // params["GSD"] + + rename_name = params.pop("patch_embed_name", channel) + if rename_name in self.channel_rename_map: + raise ValueError(f"Duplicate patch embed name {rename_name} found.") + self.channel_rename_map[channel] = rename_name + self.channels[channel] = params + + # Default groups + self.groups = self.validate_group( + groups + ) # {'group0':[product_band, ...], 'group1': ..., ...} + + available_groups: dict[str, list[str]] = self.get_available_groups({ + band: None for band in bands + }) + + self.available_groups = available_groups + + self.output_aggr = output_aggr + if self.output_aggr == "concat": + output_dim = len(available_groups) * embed_dim + elif self.output_aggr in ["mean", "sum", "max", "min"]: + output_dim = embed_dim + elif isinstance(self.output_aggr, int): + output_dim = embed_dim + else: + raise ValueError(f"Unknown output_aggr {self.output_aggr}") + + logger.info(f"THOR Encoder output_dim: {output_dim}") + + size_map = {192: "tiny", 384: "small", 768: "base", 1024: "large"} + + super().__init__( + model_name=f"thor_{size_map[embed_dim]}_encoder", + encoder_weights=encoder_weights, + input_bands=input_bands, + input_size=input_size, + embed_dim=embed_dim, # my_model_embed_dim, fixed parameters + output_dim=output_dim, + output_layers=output_layers, + pyramid_output=False, + download_url=download_url, + multi_temporal=False, # wether support multi-temporal, fixed parametersfixed parameters + multi_temporal_output=False, # wether the output of the model has a temporal dimension + ) + logger.info(f"Available groups: {available_groups}") + self.min_gsd = min([params["GSD"] for params in self.channels.values()]) + + self.ind_patch_embed = IndFlexiPatchEmbed( + ground_covers=[self.ground_cover], + channels=self.channels, + channel_rename_map=self.channel_rename_map, + embed_dim=embed_dim, + patch_size_seqs=patch_size_seqs, + ) + + # Initialize embedding + self.num_heads = num_heads + self.embed_dim = embed_dim + self.embed_prod = embed_prod + self.prod_embed_dim = prod_embed_dim if embed_prod else 0 + self.pad_prod_embed_null = pad_prod_embed_null + self.embed_band = embed_band + self.band_embed_dim = band_embed_dim if embed_band else 0 + self.pad_band_embed_null = pad_band_embed_null + self.pos_embed_dim = embed_dim - prod_embed_dim - band_embed_dim + self.embed_patch_size = embed_patch_size + self.patch_size_embed_dim = self.pos_embed_dim if embed_patch_size else 0 + + self.register_buffer("encoder_slopes", torch.tensor(get_slopes(num_heads))) + + self.init_embeds() + + # Initialize transformer blocks + self.blocks = nn.ModuleList([ + Block( + embed_dim, + num_heads, + mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + ) + for i in range(depth) + ]) + self.norm = norm_layer(embed_dim) + + self.initialize_weights() + + def init_embeds(self): + self.band_embed, self.prod_embed = self.initialize_embedding( + self.band_embed_dim, + self.prod_embed_dim, + ) + + def validate_group(self, groups): + valid_groups = {} + if groups is None: + for group_idx, prodcut_band in enumerate(self.channels.keys()): + valid_groups[f"group{group_idx}"] = [prodcut_band] + return valid_groups + found_bands = [] + for group_idx, group in enumerate(groups): + group_product = None + group_gsd = None + group_patch_size = None + valid_groups[f"group{group_idx}"] = [] + for product_band in group: + product, _ = product_band.split(":") + group_product = product if group_product is None else group_product + group_gsd = ( + self.channels[product_band]["GSD"] + if group_gsd is None + else group_gsd + ) + group_patch_size = ( + self.channels[product_band]["patch_size"] + if group_patch_size is None + else group_patch_size + ) + if self.channels[product_band]["GSD"] != group_gsd: + msg = f"GSD {self.channels[product_band]['GSD']} in group {group} does not match with GSD {group_gsd} in the same group." + raise ValueError(msg) + if self.channels[product_band]["patch_size"] != group_patch_size: + msg = f"Patch size {self.channels[product_band]['patch_size']} in group {group} does not match with patch size {group_patch_size} in the same group." + raise ValueError(msg) + valid_groups[f"group{group_idx}"].append(product_band) + found_bands.append(product_band) + + # Remove bands from channels that are not present in the groups + keys = list(self.channels.keys()) + remove_bands = set(keys) - set(found_bands) + for band in remove_bands: + del self.channels[band] + if remove_bands: + warnings.warn( + f"Removed bands {remove_bands} from channels that are not present in the groups." + ) + + return valid_groups + + def initialize_embedding( + self, + band_embed_dim: int, + prod_embed_dim: int, + ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + band_embed = {} + if self.embed_band: + embed = get_1d_sincos_pos_embed_from_grid( + band_embed_dim, torch.arange(len(self.groups)).cpu().numpy() + ) + for i, spectral_group in enumerate(self.groups.keys()): + band_embed[spectral_group] = torch.from_numpy(embed[i]).float() + + prod_embed = {} + if self.embed_prod: + unique_prod = { + prduct_band.split(":")[0]: None for prduct_band in self.channels.keys() + } + embed = get_1d_sincos_pos_embed_from_grid( + prod_embed_dim, torch.arange(len(unique_prod)).cpu().numpy() + ) + for i, prod in enumerate(unique_prod.keys()): + if self.pad_prod_embed_null: + prod_embed[prod] = torch.zeros_like( + torch.from_numpy(embed[i]) + ).float() + else: + prod_embed[prod] = torch.from_numpy(embed[i]).float() + + band_embed = nn.ParameterDict(band_embed).requires_grad_(False) + prod_embed = nn.ParameterDict(prod_embed).requires_grad_(False) + + return band_embed, prod_embed + + def initialize_weights(self) -> None: + # initialize nn.Linear and nn.LayerNorm + self.apply(self._init_weights) + + def _init_weights(self, m: nn.Module) -> None: + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def get_available_groups( + self, original_input: dict[str, torch.Tensor | None] + ) -> dict[str, list[str]]: + available_groups: dict[str, list[str]] = {} + for group_name, group_member in self.groups.items(): + for member in group_member: + if member in original_input: + if group_name not in available_groups: + available_groups[group_name] = [member] + else: + available_groups[group_name].append(member) + return available_groups + + def get_channel_params( + self, + patch_embed: dict[str, torch.Tensor], + metadata: dict[str, dict[str, int]] | None = None, + ground_cover: int | None = None, + ) -> dict[str, dict[str, int]]: + """Get GSD, num_patch, patch_size for each channel in the input data. + Args: + patch_embed: patch embeddings of the input data {'product_band': (B, N, C)} + metadata: metadata of the input data + """ + channel_params = { + product_band: {"GSD": self.channels[product_band]["GSD"]} + for product_band in patch_embed.keys() + } + + # Override with default if available + if ground_cover is None: + ground_cover = self.ground_cover + + assert isinstance(ground_cover, int), ( + f"Ground cover {ground_cover} is not defined as int, please provide a valid ground cover." + ) + + for product_band in patch_embed: + # Override with metadata if available + if ( + metadata is not None + and product_band in metadata + and "GSD" in metadata[product_band] + ): + channel_params[product_band]["GSD"] = metadata[product_band]["GSD"] + num_patch = int(patch_embed[product_band].shape[1] ** 0.5) + patch_size = ( + ground_cover // num_patch // channel_params[product_band]["GSD"] + ) + channel_params[product_band]["num_patch"] = num_patch + channel_params[product_band]["patch_size"] = patch_size + assert ( + patch_size * channel_params[product_band]["GSD"] * num_patch + == ground_cover + ), ( + f"Patch size {patch_size} * GSD {channel_params[product_band]['GSD']} * num_patch {num_patch} does not match ground cover {ground_cover}" + ) + + return channel_params + + def get_encoder_auxilliary_embed( + self, + original_input: dict[str, torch.Tensor], + available_groups: dict[str, list[str]], + channel_params: dict[str, dict[str, int]], + ) -> dict[str, torch.Tensor]: + auxilliary_embed = {} + for group_name, group_member in available_groups.items(): + product_band = group_member[0] + product, band = product_band.split(":") + + data = original_input[product_band] + + if self.embed_patch_size: + patch_sizes = ( + torch.ones( + (channel_params[product_band]["num_patch"] ** 2), + device=data.device, + dtype=data.dtype, + ) + * channel_params[product_band]["patch_size"] + * channel_params[product_band]["GSD"] + / (self.ref_patch_size * self.min_gsd) + ) + + group_pos_embed = get_1d_sincos_pos_embed_from_grid_torch( + self.patch_size_embed_dim, pos=patch_sizes + ).to(data.dtype) + else: + group_pos_embed = torch.zeros( + ( + channel_params[product_band]["num_patch"] ** 2, + self.pos_embed_dim, + ), + device=data.device, + dtype=data.dtype, + ) + + auxilliary_embed[group_name] = group_pos_embed + + if self.embed_band: + auxilliary_embed[group_name] = torch.cat( + ( + auxilliary_embed[group_name], + self.band_embed[group_name].expand( + auxilliary_embed[group_name].shape[0], -1 + ), + ), + dim=-1, + ) + if self.embed_prod: + auxilliary_embed[group_name] = torch.cat( + ( + auxilliary_embed[group_name], + self.prod_embed[product].expand( + auxilliary_embed[group_name].shape[0], -1 + ), + ), + dim=-1, + ) + + band_ground_cover = int( + data.shape[-1] * channel_params[product_band]["GSD"] + ) + if band_ground_cover != self.ground_cover: + raise ValueError( + f"Input ground cover for {product_band} is {band_ground_cover}x{band_ground_cover}, (image shape:{data.shape[-2:]}) " + f"which does match grid of {channel_params[product_band]['num_patch']}x{channel_params[product_band]['num_patch']}" + f"with patch size {channel_params[product_band]['patch_size']} and GSD {channel_params[product_band]['GSD']}." + f" patches for defined ground cover {self.ground_cover}." + ) + + return auxilliary_embed + + def aggregate_by_group( + self, + patch_embed: dict[str, torch.Tensor], + available_groups: dict[str, list[str]], + ) -> dict[str, torch.Tensor]: + group_embed = {} + for group_name, group_member in available_groups.items(): + if self.aggr_type == "subsetmean": + to_stack = [ + patch_embed[product_band] + for product_band in group_member + if product_band in patch_embed + ] + if len(to_stack) > 0: + group_embed[group_name] = torch.stack(to_stack, -1).mean(-1) + elif self.aggr_type == "subsetsum": + to_stack = [ + patch_embed[product_band] + for product_band in group_member + if product_band in patch_embed + ] + if len(to_stack) > 0: + group_embed[group_name] = torch.stack(to_stack, -1).sum(-1) + + return group_embed + + def forward_intermediates( + self, + inp: dict[str, torch.Tensor], + metadata: dict[str, dict[str, int]] | None = None, + ground_cover: int | None = None, + ) -> tuple[list[torch.Tensor], dict[str, dict[str, int]]]: + """Forward pass with intermediate outputs.""" + + # TODO: add mode option for using i.e., smallest patch size or normal patch size + patch_embeds = self.ind_patch_embed( + inp, + patch_sizes={ + p: min(p_sizes) + for p, p_sizes in self.ind_patch_embed.patch_size_seqs.items() + }, + ) + # {'product:band': (B, N_n, D), ...} + + # Get available groups + available_groups = self.get_available_groups(patch_embeds) + # {'group0': [product_band, ...], ...} + + # Get channel parameters + channel_params = self.get_channel_params(patch_embeds, metadata, ground_cover) + + # Aggregate embeddings by group + group_embeds = self.aggregate_by_group(patch_embeds, available_groups) + # {'group0': (B, N_n, D), ...} + + # Get auxiliary embeddings (positional, band) + aux_embeds = self.get_encoder_auxilliary_embed( + inp, available_groups, channel_params + ) + # {'group0': (N_n, D), ...} + + # Combine group embeddings with auxiliary embeddings + x = torch.cat( + [ + group_embeds[group_name] + aux_embeds[group_name] + for group_name in group_embeds.keys() + ], + dim=1, + ) + # (B, T, D) + + # Compute alibi attention bias + alibi = get_alibi_thor( + channel_params, + available_groups, + slopes=self.encoder_slopes, + offset=0.0, + device=x.device, + dtype=x.dtype, + ) + alibi = alibi.expand(x.shape[0], -1, -1, -1) + + # apply Transformer blocks + output = [] + for i, blk in enumerate(self.blocks): + x = blk(x, alibi) + # if i == len(self.blocks) - 1: + # x = self.norm(x) # TODO: drop? + if i in self.output_layers: + output.append(x) + + return output, channel_params + + def _post_process( + self, features: list[torch.Tensor], channel_params: dict[str, dict[str, int]] + ) -> list[torch.Tensor]: + """Stack embeddings for each group, requires interpolation to the highest num_patch.""" + highest_num_patch = max( + channel_params[product_band]["num_patch"] + for product_band in channel_params.keys() + ) + + out_features = [] + for feature in features: + start_idx = 0 + out = [] + # Important that we iterate through this in the same order we encoded + for group_member in self.available_groups.values(): + if len(group_member) == 0: + continue + product_band = group_member[0] + + num_patch = channel_params[product_band]["num_patch"] + + x_ = feature[:, start_idx : start_idx + num_patch**2, :].reshape( + -1, num_patch, num_patch, self.embed_dim + ) + x_ = x_.permute(0, 3, 1, 2) # B, C, H, W + + # Interpolate if needed + if num_patch != highest_num_patch and self.output_aggr in [ + "concat", + "mean", + "sum", + "max", + "min", + ]: + x_ = F.interpolate( + x_, + size=(highest_num_patch, highest_num_patch), + mode="bilinear", + ) + + out.append(x_) + start_idx += num_patch**2 + + if start_idx != feature.shape[1]: + raise ValueError( + f"Number of patches {start_idx} does not match number of patches in input {feature.shape[1]}" + ) + + if self.output_aggr == "concat": + # Concatenate all group features + out = torch.cat(out, dim=1) + elif self.output_aggr == "mean": + # Mean all group features + out = torch.mean(torch.stack(out, dim=0), dim=0) + elif self.output_aggr == "sum": + # Sum all group features + out = torch.sum(torch.stack(out, dim=0), dim=0) + elif self.output_aggr == "max": + # Max all group features + out = torch.max(torch.stack(out, dim=0), dim=0)[0] + elif self.output_aggr == "min": + # Min all group features + out = torch.min(torch.stack(out, dim=0), dim=0)[0] + elif isinstance(self.output_aggr, int): + # Select the group feature at the specified index + out = out[self.output_aggr] + out_features.append(out) + + return out_features + + def freeze(self): + for param in self.parameters(): + param.requires_grad = False + + def load_encoder_weights(self, logger: Logger) -> None: + pretrained_model = torch.load( + self.encoder_weights, map_location="cpu", weights_only=True + )["state_dict"] + prefix = "mae." + k = pretrained_model.keys() + renamed_keys = { + key.replace(prefix, ""): key for key in k if key.startswith(prefix) + } + pretrained_encoder = {} + incompatible_shape = {} + missing = {} + for name, param in self.named_parameters(): + if name not in renamed_keys: + missing[name] = param.shape + elif pretrained_model[renamed_keys[name]].shape != param.shape: + incompatible_shape[name] = ( + param.shape, + pretrained_model[renamed_keys[name]].shape, + ) + else: + pretrained_encoder[name] = pretrained_model[renamed_keys[name]] + unused_keys = set(renamed_keys.keys()) - set(pretrained_encoder.keys()) + for key in unused_keys: + if key in renamed_keys: + # These are keys that were in the pretrained model but not used, + # typically decoder weights. + logger.debug(f"Unused key {key} in pretrained model") + else: + logger.warning(f"Key {key} not found in pretrained model") + + if missing: + logger.warning( + f"Some keys from the pretrained model were not found in the current model: {missing}" + ) + + if incompatible_shape: + logger.warning( + f"Some parameters have incompatible shapes: {incompatible_shape}" + ) + raise ValueError("Incompatible parameter shapes found, have you loaded the correct size model?") + + self.load_state_dict(pretrained_encoder, strict=False) + self.parameters_warning(missing, incompatible_shape, logger) + logger.info("Loaded encoder weights successfully.") + + # def _preprocess_input(self, x): + # """Preprocess input data for the model.""" + # # Concatenate all modalities + # x = torch.concat([x[modality] for modality in x.keys()], dim=1) + + # x = { + # channel: F.interpolate( + # x[:, [i], :, :], + # ( + # int(self.ground_cover / self.channels[channel]["GSD"]), + # int(self.ground_cover / self.channels[channel]["GSD"]), + # ), + # mode="bilinear", + # ) + # for i, channel in enumerate(self.bands) + # } + + # return x + + def _preprocess_input(self, x): + """Preprocess input data for the model.""" + # Concatenate all modalities + # print(f"x keys: {x.keys()}") + # print(f"x shapes before concat: {[x[modality].shape for modality in x.keys()]}") + x = torch.concat([x[modality] for modality in x.keys()], dim=1) + + # Ensure x has batch dimension + # print(f"x.dim() before unsqueeze: {x.dim()}") + if x.dim() != 3: + x = x.squeeze(2) + + x = { + channel: F.interpolate( + x[:, [i], :, :], + ( + int(self.ground_cover / self.channels[channel]["GSD"]), + int(self.ground_cover / self.channels[channel]["GSD"]), + ), + mode="bilinear", + ) + for i, channel in enumerate(self.bands) + } + + return x + + def forward(self, x: dict[str, torch.Tensor]) -> list[torch.Tensor]: + """Foward pass of the encoder. + + Args: + x (dict[str, torch.Tensor]): encoder's input structured as a dictionary: + x = {modality1: tensor1, modality2: tensor2, ...}, e.g. x = {"optical": tensor1, "sar": tensor2}. + If the encoder is multi-temporal (self.multi_temporal==True), input tensor shape is (B C T H W) with C the + number of bands required by the encoder for the given modality and T the number of time steps. If the + encoder is not multi-temporal, input tensor shape is (B C H W) with C the number of bands required by the + encoder for the given modality. + + Returns: + list[torch.Tensor]: list of the embeddings for each modality. For single-temporal encoders, the list's + elements are of shape (B, embed_dim, H', W'). For multi-temporal encoders, the list's elements are of shape + (B, C', T, H', W') with T the number of time steps if the encoder does not have any time-merging strategy, + else (B, C', H', W') if the encoder has a time-merging strategy (where C'==self.output_dim). + """ + x = self._preprocess_input(x) + outputs, channel_params = self.forward_intermediates(x) + outputs = self._post_process(outputs, channel_params=channel_params) + return outputs + +_large_cfg = { + "embed_dim": 1024, + "depth": 24, + "num_heads": 16, + "band_embed_dim": 256, +} + +_base_cfg = { + "embed_dim": 768, + "depth": 12, + "num_heads": 12, + "band_embed_dim": 128, +} + +_small_cfg = { + "embed_dim": 384, + "depth": 12, + "num_heads": 6, + "band_embed_dim": 64, +} + +_tiny_cfg = { + "embed_dim": 192, + "depth": 12, + "num_heads": 3, + "band_embed_dim": 32, +} + +def thor_tiny_encoder(**kwargs) -> THOR_Encoder: + """THOR Tiny Encoder with Alibi Attention and Flexivit Patch Embedding V1 + + Args: + **kwargs: keyword arguments for the THOR_Encoder class + + Returns: + THOR_Encoder: THOR Tiny Encoder with Alibi Attention and Patch Embedding V1 + """ + model = THOR_Encoder( + **_tiny_cfg, + **kwargs, + ) + return model + +def thor_small_encoder(**kwargs) -> THOR_Encoder: + """THOR Small Encoder with Alibi Attention and Flexivit Patch Embedding V1 + + Args: + **kwargs: keyword arguments for the THOR_Encoder class + + Returns: + THOR_Encoder: THOR Small Encoder with Alibi Attention and Flexivit Patch Embedding V1 + """ + model = THOR_Encoder( + **_small_cfg, + **kwargs, + ) + return model + + +def thor_base_encoder(**kwargs) -> THOR_Encoder: + """THOR Base Encoder with Alibi Attention and Flexivit Patch Embedding V1 + + Args: + **kwargs: keyword arguments for the THOR_Encoder class + + Returns: + THOR_Encoder: THOR Base Encoder with Alibi Attention and Flexivit Patch Embedding V1 + """ + model = THOR_Encoder( + **_base_cfg, + **kwargs, + ) + return model + +def thor_large_encoder(**kwargs) -> THOR_Encoder: + """THOR Large Encoder with Alibi Attention and Flexivit Patch Embedding V1 + + Args: + **kwargs: keyword arguments for the THOR_Encoder class + + Returns: + THOR_Encoder: THOR Large Encoder with Alibi Attention and Flexivit Patch Embedding V1 + """ + model = THOR_Encoder( + **_large_cfg, + **kwargs, + ) + return model + + +__all__ = [ + "thor_tiny_encoder", + "thor_small_encoder", + "thor_base_encoder", + "thor_large_encoder", +] \ No newline at end of file diff --git a/pangaea/run.py b/pangaea/run.py index be741193..d0075a39 100644 --- a/pangaea/run.py +++ b/pangaea/run.py @@ -282,6 +282,7 @@ def main(cfg: DictConfig) -> None: batch_size=cfg.test_batch_size, num_workers=cfg.test_num_workers, pin_memory=True, + prefetch_factor=6, persistent_workers=False, drop_last=False, collate_fn=collate_fn, From 4aabf837df0be55a779352eab14b4ba11e66a8b7 Mon Sep 17 00:00:00 2001 From: ghjuliasialelli Date: Fri, 20 Mar 2026 15:47:35 +0100 Subject: [PATCH 13/14] and prithvi --- configs/encoder/prithvi2_100m.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/encoder/prithvi2_100m.yaml b/configs/encoder/prithvi2_100m.yaml index 6d4d3916..047eb25e 100644 --- a/configs/encoder/prithvi2_100m.yaml +++ b/configs/encoder/prithvi2_100m.yaml @@ -1,6 +1,6 @@ _target_: pangaea.encoders.prithvi2_encoder.Prithvi2_Encoder -encoder_weights: ./pretrained_models/Prithvi/Prithvi-EO-2.0-100M-TL/Prithvi_EO_V2_100M_TL.pt -download_url: False +encoder_weights: ./pretrained_models/Prithvi_EO_V2_100M_TL.pt +download_url: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-100M-TL/resolve/main/Prithvi_EO_V2_100M_TL.pt?download=true # Model architecture (100M parameters) embed_dim: 768 From 4fe37fe6665b485d5ac81792a9e9696abee3469c Mon Sep 17 00:00:00 2001 From: ghjuliasialelli Date: Fri, 27 Mar 2026 17:47:42 +0100 Subject: [PATCH 14/14] ok --- pangaea/datasets/agbdlite.py | 6 +- pangaea/run.py | 2 +- throughput.py | 165 +++++++++++++++++++++++++++++++++++ 3 files changed, 170 insertions(+), 3 deletions(-) create mode 100644 throughput.py diff --git a/pangaea/datasets/agbdlite.py b/pangaea/datasets/agbdlite.py index c8e81660..c0909572 100644 --- a/pangaea/datasets/agbdlite.py +++ b/pangaea/datasets/agbdlite.py @@ -109,12 +109,14 @@ def __init__( auto_download=auto_download, ) - if auto_download: self.download(self) if os.environ.get('SLURM_SUBMIT_DIR') is not None: print('Running on cluster, using cluster root path.') self.root_path = root_path_cluster + if auto_download: self.download(self) self.s2_bands = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12'] - if self.eval_big and self.mode == 'test' : self.fname = 'AGBD-test.h5' + if self.eval_big and self.mode == 'test' : + print("Using AGBD-test.h5 file for evaluation.") + self.fname = 'AGBD-test.h5' else: self.fname = f'AGBD-Lite-{self.mode}.h5' self.f_handle = h5py.File(join(self.root_path, self.fname), 'r') diff --git a/pangaea/run.py b/pangaea/run.py index d0075a39..4be10c8b 100644 --- a/pangaea/run.py +++ b/pangaea/run.py @@ -282,7 +282,7 @@ def main(cfg: DictConfig) -> None: batch_size=cfg.test_batch_size, num_workers=cfg.test_num_workers, pin_memory=True, - prefetch_factor=6, + prefetch_factor=4, persistent_workers=False, drop_last=False, collate_fn=collate_fn, diff --git a/throughput.py b/throughput.py new file mode 100644 index 00000000..8e53accb --- /dev/null +++ b/throughput.py @@ -0,0 +1,165 @@ +"""Measure inference throughput (samples/sec) of a model on a given dataset. + +Usage: + python throughput.py \ + task=regression dataset=agbdlite encoder=gfmswin \ + decoder=reg_upernet preprocessing=reg_default criterion=mse \ + batch_size=64 \ + --warmup 20 --iterations 100 + +All Hydra overrides (task, dataset, encoder, etc.) work as in run.py. +Extra CLI flags (--warmup, --iterations) control the benchmark. +""" + +import argparse +import sys +import time + +import hydra +import torch +from hydra.utils import instantiate +from omegaconf import DictConfig +from torch.utils.data import DataLoader + +from pangaea.datasets.base import GeoFMDataset, RawGeoFMDataset +from pangaea.decoders.base import Decoder +from pangaea.encoders.base import Encoder +from pangaea.utils.collate_fn import get_collate_fn + + +def parse_extra_args(): + """Parse --warmup and --iterations before Hydra consumes the rest.""" + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument("--warmup", type=int, default=20, + help="Number of warm-up forward passes (discarded).") + parser.add_argument("--iterations", type=int, default=100, + help="Number of timed forward passes.") + args, remaining = parser.parse_known_args() + # Put remaining args back so Hydra can parse them + sys.argv = [sys.argv[0]] + remaining + return args + + +extra_args = parse_extra_args() + + +@hydra.main(version_base=None, config_path="configs", config_name="train") +def main(cfg: DictConfig) -> None: + warmup = extra_args.warmup + iterations = extra_args.iterations + batch_size = cfg.batch_size + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + if device.type != "cuda": + print("WARNING: No GPU detected. Throughput numbers on CPU are not meaningful.") + + # ── Build model ────────────────────────────────────────────────────── + encoder: Encoder = instantiate(cfg.encoder) + + #if encoder.model_name != "Prithvi": + # encoder.load_encoder_weights(None) # None logger → silent + + decoder: Decoder = instantiate(cfg.decoder, encoder=encoder) + decoder.to(device) + decoder.eval() + + n_params = sum(p.numel() for p in decoder.parameters()) + n_trainable = sum(p.numel() for p in decoder.parameters() if p.requires_grad) + print(f"Model : {decoder.model_name} (encoder: {encoder.model_name})") + print(f"Parameters : {n_params:,} total, {n_trainable:,} trainable") + + # ── Build dataset & loader ─────────────────────────────────────────── + preprocessor = instantiate( + cfg.preprocessing.test, + dataset_cfg=cfg.dataset, + encoder_cfg=cfg.encoder, + _recursive_=False, + ) + raw_dataset: RawGeoFMDataset = instantiate(cfg.dataset, split="test") + dataset = GeoFMDataset(raw_dataset, preprocessor) + + modalities = list(encoder.input_bands.keys()) + collate_fn = get_collate_fn(modalities) + + loader = DataLoader( + dataset, + batch_size=batch_size, + num_workers=cfg.num_workers, + pin_memory=True, + drop_last=True, + collate_fn=collate_fn, + shuffle=True, + ) + + print(f"Dataset : {cfg.dataset.dataset_name} (split=test, {len(dataset)} samples)") + print(f"Batch size : {batch_size}") + print(f"Warm-up : {warmup} forward passes") + print(f"Iterations : {iterations} forward passes") + print() + + # ── Helper: get one batch (cycles the loader) ──────────────────────── + loader_iter = iter(loader) + + def next_batch(): + nonlocal loader_iter + try: + return next(loader_iter) + except StopIteration: + loader_iter = iter(loader) + return next(loader_iter) + + # ── Warm-up ────────────────────────────────────────────────────────── + print("Warming up …") + with torch.no_grad(): + for _ in range(warmup): + data = next_batch() + image = {k: v.to(device, non_blocking=True) for k, v in data["image"].items()} + target = data["target"].to(device, non_blocking=True) + _ = decoder(image, output_shape=target.shape[-2:]) + + if device.type == "cuda": + torch.cuda.synchronize() + + # ── Timed run ──────────────────────────────────────────────────────── + print("Benchmarking …") + if device.type == "cuda": + torch.cuda.synchronize() + + t_start = time.perf_counter() + + with torch.no_grad(): + for _ in range(iterations): + data = next_batch() + image = {k: v.to(device, non_blocking=True) for k, v in data["image"].items()} + target = data["target"].to(device, non_blocking=True) + _ = decoder(image, output_shape=target.shape[-2:]) + + if device.type == "cuda": + torch.cuda.synchronize() + + elapsed = time.perf_counter() - t_start + + # ── Report ─────────────────────────────────────────────────────────── + total_samples = iterations * batch_size + throughput = total_samples / elapsed + ms_per_sample = (elapsed / total_samples) * 1000 + ms_per_batch = (elapsed / iterations) * 1000 + + print() + print("═" * 50) + print(f" Total time : {elapsed:.2f} s") + print(f" Batches : {iterations}") + print(f" Batch size : {batch_size}") + print(f" Throughput : {throughput:.1f} samples/s") + print(f" Latency/sample : {ms_per_sample:.2f} ms") + print(f" Latency/batch : {ms_per_batch:.2f} ms") + print("═" * 50) + + if device.type == "cuda": + peak_mem = torch.cuda.max_memory_allocated(device) / (1024 ** 3) + print(f" Peak GPU mem : {peak_mem:.2f} GB") + print("═" * 50) + + +if __name__ == "__main__": + main()