From 539667c219e73f6dd949438fe1e7304867d5ba97 Mon Sep 17 00:00:00 2001 From: yurujaja Date: Tue, 18 Mar 2025 11:55:00 +0100 Subject: [PATCH 1/6] add classification tasks, and two geobench datasets --- configs/criterion/binary_cross_entropy.yaml | 1 + configs/dataset/mbigearthnet.yaml | 39 ++++ configs/dataset/meurosat.yaml | 39 ++++ configs/decoder/cls_linear.yaml | 5 + configs/preprocessing/cls_resize.yaml | 20 +++ configs/task/classification.yaml | 30 ++++ configs/task/classification_multi_label.yaml | 34 ++++ pangaea/datasets/geobench/__init__.py | 0 pangaea/datasets/geobench/mbigearthnet.py | 177 +++++++++++++++++++ pangaea/datasets/geobench/meurosat.py | 175 ++++++++++++++++++ pangaea/datasets/utils.py | 25 ++- pangaea/decoders/linearclassifier.py | 113 ++++++++++++ pangaea/engine/data_preprocessor.py | 4 +- pangaea/engine/evaluator.py | 133 ++++++++++++++ pangaea/engine/trainer.py | 123 +++++++++++++ pangaea/utils/collate_fn.py | 1 + pangaea/utils/subset_sampler.py | 32 ++++ 17 files changed, 947 insertions(+), 4 deletions(-) create mode 100644 configs/criterion/binary_cross_entropy.yaml create mode 100644 configs/dataset/mbigearthnet.yaml create mode 100644 configs/dataset/meurosat.yaml create mode 100644 configs/decoder/cls_linear.yaml create mode 100644 configs/preprocessing/cls_resize.yaml create mode 100644 configs/task/classification.yaml create mode 100644 configs/task/classification_multi_label.yaml create mode 100644 pangaea/datasets/geobench/__init__.py create mode 100644 pangaea/datasets/geobench/mbigearthnet.py create mode 100644 pangaea/datasets/geobench/meurosat.py create mode 100644 pangaea/decoders/linearclassifier.py diff --git a/configs/criterion/binary_cross_entropy.yaml b/configs/criterion/binary_cross_entropy.yaml new file mode 100644 index 00000000..2a3a2cfb --- /dev/null +++ b/configs/criterion/binary_cross_entropy.yaml @@ -0,0 +1 @@ +_target_: torch.nn.BCEWithLogitsLoss \ No newline at end of file diff --git 
a/configs/dataset/mbigearthnet.yaml b/configs/dataset/mbigearthnet.yaml new file mode 100644 index 00000000..12ec8a62 --- /dev/null +++ b/configs/dataset/mbigearthnet.yaml @@ -0,0 +1,39 @@ +_target_: pangaea.datasets.geobench.mbigearthnet.mBigEarthNet +dataset_name: mBigEarthNet +root_path: ${oc.env:GEO_BENCH_DIR}/classification_v1.0/m-bigearthnet # ensure sys env var GEO_BENCH_DIR exist +download_url: "recursix/geo-bench-1.0" +auto_download: True +ignore_index: -100 +num_classes: 43 +img_size: 120 +multi_temporal: False +multi_modal: False + +bands: + optical: + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B8 + - B8A + - B9 + - B11 + - B12 + +classes: [''] +distribution: [0,] + +# data stats +data_mean: + optical: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + +data_std: + optical: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +data_min: + optical: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +data_max: + optical: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] \ No newline at end of file diff --git a/configs/dataset/meurosat.yaml b/configs/dataset/meurosat.yaml new file mode 100644 index 00000000..b5026c55 --- /dev/null +++ b/configs/dataset/meurosat.yaml @@ -0,0 +1,39 @@ +_target_: pangaea.datasets.geobench.meurosat.mEuroSat +dataset_name: mEuroSat +root_path: ${oc.env:GEO_BENCH_DIR}/classification_v1.0/m-eurosat # ensure sys env var GEO_BENCH_DIR exist +download_url: "recursix/geo-bench-1.0" +auto_download: True +ignore_index: -100 +multi_temporal: False +multi_modal: False +img_size: 64 +num_classes: 10 + +bands: + optical: + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B8 + - B8A + - B9 + - B10 + - B11 + - B12 + +classes: ['', '', '', '', '', '', '', '', '', ''] +distribution: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + + +data_mean: + optical: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] +data_std: + optical: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] +data_min: + optical: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0] +data_max: + optical: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] \ No newline at end of file diff --git a/configs/decoder/cls_linear.yaml b/configs/decoder/cls_linear.yaml new file mode 100644 index 00000000..b5e4c724 --- /dev/null +++ b/configs/decoder/cls_linear.yaml @@ -0,0 +1,5 @@ +_target_: pangaea.decoders.linearclassifier.LinearClassifier + +encoder: null +num_classes: ${dataset.num_classes} +finetune: ${finetune} \ No newline at end of file diff --git a/configs/preprocessing/cls_resize.yaml b/configs/preprocessing/cls_resize.yaml new file mode 100644 index 00000000..9a00e946 --- /dev/null +++ b/configs/preprocessing/cls_resize.yaml @@ -0,0 +1,20 @@ +train: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.ResizeToEncoder + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.BandPadding + +val: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.ResizeToEncoder + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.BandPadding + +test: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.ResizeToEncoder + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.BandPadding \ No newline at end of file diff --git a/configs/task/classification.yaml b/configs/task/classification.yaml new file mode 100644 index 00000000..16ae0459 --- /dev/null +++ b/configs/task/classification.yaml @@ -0,0 +1,30 @@ +trainer: + _target_: pangaea.engine.trainer.ClassificationTrainer + # params overwritten in run + model: null + train_loader: null + optimizer: null + lr_scheduler: null + evaluator: null + exp_dir: null + device: null + criterion: null + + # params to adapt + n_epochs: 50 + precision: fp32 + 
ckpt_interval: 50 + eval_interval: 5 + log_interval: 5 + best_metric_key: accuracy + use_wandb: ${use_wandb} + +evaluator: + _target_: pangaea.engine.evaluator.ClassificationEvaluator + # params overwritten in run + val_loader: null + exp_dir: null + device: null + use_wandb: ${use_wandb} + inference_mode: null + sliding_inference_batch: null \ No newline at end of file diff --git a/configs/task/classification_multi_label.yaml b/configs/task/classification_multi_label.yaml new file mode 100644 index 00000000..972ddf82 --- /dev/null +++ b/configs/task/classification_multi_label.yaml @@ -0,0 +1,34 @@ +trainer: + _target_: pangaea.engine.trainer.ClassificationTrainer + # params overwritten in run + model: null + train_loader: null + optimizer: null + lr_scheduler: null + evaluator: null + exp_dir: null + device: null + criterion: null + multi_label: true + topk: 1 + + # params to adapt + n_epochs: 50 + precision: fp32 + ckpt_interval: 50 + eval_interval: 5 + log_interval: 5 + best_metric_key: F1 + use_wandb: ${use_wandb} + +evaluator: + _target_: pangaea.engine.evaluator.ClassificationEvaluator + # params overwritten in run + val_loader: null + exp_dir: null + device: null + use_wandb: ${use_wandb} + inference_mode: null + sliding_inference_batch: null + multi_label: true + topk: 1 \ No newline at end of file diff --git a/pangaea/datasets/geobench/__init__.py b/pangaea/datasets/geobench/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pangaea/datasets/geobench/mbigearthnet.py b/pangaea/datasets/geobench/mbigearthnet.py new file mode 100644 index 00000000..84ab4767 --- /dev/null +++ b/pangaea/datasets/geobench/mbigearthnet.py @@ -0,0 +1,177 @@ + +import os +import numpy as np +import torch +from pangaea.datasets.base import RawGeoFMDataset +from pangaea.datasets.utils import decompress_zip_with_progress +import geobench +from pathlib import Path +from tqdm import tqdm +from huggingface_hub import HfApi, hf_hub_download +from torchvision import 
transforms + + +class mBigEarthNet(RawGeoFMDataset): + def __init__( + self, + split: str, + dataset_name: str, + multi_modal: bool, + multi_temporal: int, + root_path: str, + classes: list, + num_classes: int, + ignore_index: int, + img_size: int, + bands: dict[str, list[str]], + distribution: list[int], + data_mean: dict[str, list[str]], + data_std: dict[str, list[str]], + data_min: dict[str, list[str]], + data_max: dict[str, list[str]], + download_url: str, + auto_download: bool, + ): + """Initialize the mBigEarthNet dataset. + Link: https://github.com/ServiceNow/geo-bench + + Args: + split (str): split of the dataset (train, val, test). + dataset_name (str): dataset name. + multi_modal (bool): if the dataset is multi-modal. + multi_temporal (int): number of temporal frames. + root_path (str): root path of the dataset. + classes (list): classes of the dataset. + num_classes (int): number of classes. + ignore_index (int): index to ignore for metrics and loss. + img_size (int): size of the image. + bands (dict[str, list[str]]): bands of the dataset. + distribution (list[int]): class distribution. + data_mean (dict[str, list[str]]): mean for each band for each modality. + Dictionary with keys as the modality and values as the list of means. + e.g. {"s2": [b1_mean, ..., bn_mean], "s1": [b1_mean, ..., bn_mean]} + data_std (dict[str, list[str]]): str for each band for each modality. + Dictionary with keys as the modality and values as the list of stds. + e.g. {"s2": [b1_std, ..., bn_std], "s1": [b1_std, ..., bn_std]} + data_min (dict[str, list[str]]): min for each band for each modality. + Dictionary with keys as the modality and values as the list of mins. + e.g. {"s2": [b1_min, ..., bn_min], "s1": [b1_min, ..., bn_min]} + data_max (dict[str, list[str]]): max for each band for each modality. + Dictionary with keys as the modality and values as the list of maxs. + e.g. 
{"s2": [b1_max, ..., bn_max], "s1": [b1_max, ..., bn_max]} + download_url (str): url to download the dataset. + auto_download (bool): whether to download the dataset automatically. + """ + super(mBigEarthNet, self).__init__( + split=split, + dataset_name=dataset_name, + multi_modal=multi_modal, + multi_temporal=multi_temporal, + root_path=root_path, + classes=classes, + num_classes=num_classes, + ignore_index=ignore_index, + img_size=img_size, + bands=bands, + distribution=distribution, + data_mean=data_mean, + data_std=data_std, + data_min=data_min, + data_max=data_max, + download_url=download_url, + auto_download=auto_download, + ) + + self.data_mean = data_mean + self.data_std = data_std + self.data_min = data_min + self.data_max = data_max + self.classes = classes + self.img_size = img_size + self.distribution = distribution + self.num_classes = num_classes + self.ignore_index = ignore_index + self.download_url = download_url + self.auto_download = auto_download + + self.root_path = root_path + self.split = split + + split_mapping = {'train': 'train', 'val': 'valid', 'test': 'test'} + + task = geobench.load_task_specs(self.root_path) + self.dataset = task.get_dataset(split=split_mapping[self.split]) + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, index): + sample = self.dataset[index] + all_band_names = ( + "01", + "02", + "03", + "04", + "05", + "06", + "07", + "08", + "08A", + "09", + "11", + "12", + ) + rgb_bands = ("04", "03", "02") + + BAND_SETS = {"all": all_band_names, "rgb": rgb_bands} + image, band_names = sample.pack_to_3d(band_names=BAND_SETS["all"]) + label = sample.label + filename = sample.sample_name + + image = torch.from_numpy(image.transpose(2, 0, 1)).float() + image = image / 4095 + image = np.clip(image, 0, 1) + + image=image.unsqueeze(1) + + + return { + "image": { + "optical": image, + }, + "target": torch.tensor(label, dtype=torch.float32), + "metadata": { + "filename": filename}, + } + + def download(self, 
silent=False): + local_directory = Path(os.getenv("GEO_BENCH_DIR")) + dataset_repo = self.download_url + + local_directory.mkdir(parents=True, exist_ok=True) + + api = HfApi() + dataset_files = api.list_repo_files(repo_id=dataset_repo, repo_type="dataset") + + for file in dataset_files: + + if file not in ['classification_v1.0/m-bigearthnet.zip', 'classification_v1.0/normalizer.json']: + continue + + local_file_path = local_directory / file + + local_file_path.parent.mkdir(parents=True, exist_ok=True) + + print(f"Downloading {file}...") + hf_hub_download( + repo_id=dataset_repo, + filename=file, + cache_dir=local_directory, + local_dir=local_directory, + repo_type="dataset", + ) + if file.endswith(".zip"): + print(f"Decompressing ...") + decompress_zip_with_progress(local_directory / file) + + \ No newline at end of file diff --git a/pangaea/datasets/geobench/meurosat.py b/pangaea/datasets/geobench/meurosat.py new file mode 100644 index 00000000..e8def587 --- /dev/null +++ b/pangaea/datasets/geobench/meurosat.py @@ -0,0 +1,175 @@ +import os +import numpy as np +import torch +from pangaea.datasets.base import RawGeoFMDataset +from pangaea.datasets.utils import decompress_zip_with_progress +from pathlib import Path +from huggingface_hub import HfApi, hf_hub_download +import geobench + + +class mEuroSat(RawGeoFMDataset): + def __init__( + self, + split: str, + dataset_name: str, + multi_modal: bool, + multi_temporal: int, + root_path: str, + classes: list, + num_classes: int, + ignore_index: int, + img_size: int, + bands: dict[str, list[str]], + distribution: list[int], + data_mean: dict[str, list[str]], + data_std: dict[str, list[str]], + data_min: dict[str, list[str]], + data_max: dict[str, list[str]], + download_url: str, + auto_download: bool, + ): + """Initialize the mEuroSat dataset. + Link: https://github.com/ServiceNow/geo-bench + + Args: + split (str): split of the dataset (train, val, test). + dataset_name (str): dataset name. 
+ multi_modal (bool): if the dataset is multi-modal. + multi_temporal (int): number of temporal frames. + root_path (str): root path of the dataset. + classes (list): classes of the dataset. + num_classes (int): number of classes. + ignore_index (int): index to ignore for metrics and loss. + img_size (int): size of the image. + bands (dict[str, list[str]]): bands of the dataset. + distribution (list[int]): class distribution. + data_mean (dict[str, list[str]]): mean for each band for each modality. + Dictionary with keys as the modality and values as the list of means. + e.g. {"s2": [b1_mean, ..., bn_mean], "s1": [b1_mean, ..., bn_mean]} + data_std (dict[str, list[str]]): str for each band for each modality. + Dictionary with keys as the modality and values as the list of stds. + e.g. {"s2": [b1_std, ..., bn_std], "s1": [b1_std, ..., bn_std]} + data_min (dict[str, list[str]]): min for each band for each modality. + Dictionary with keys as the modality and values as the list of mins. + e.g. {"s2": [b1_min, ..., bn_min], "s1": [b1_min, ..., bn_min]} + data_max (dict[str, list[str]]): max for each band for each modality. + Dictionary with keys as the modality and values as the list of maxs. + e.g. {"s2": [b1_max, ..., bn_max], "s1": [b1_max, ..., bn_max]} + download_url (str): url to download the dataset. + auto_download (bool): whether to download the dataset automatically. 
+ """ + super(mEuroSat, self).__init__( + split=split, + dataset_name=dataset_name, + multi_modal=multi_modal, + multi_temporal=multi_temporal, + root_path=root_path, + classes=classes, + num_classes=num_classes, + ignore_index=ignore_index, + img_size=img_size, + bands=bands, + distribution=distribution, + data_mean=data_mean, + data_std=data_std, + data_min=data_min, + data_max=data_max, + download_url=download_url, + auto_download=auto_download, + ) + + self.data_mean = data_mean + self.data_std = data_std + self.data_min = data_min + self.data_max = data_max + self.classes = classes + self.img_size = img_size + self.distribution = distribution + self.num_classes = num_classes + self.ignore_index = ignore_index + self.download_url = download_url + self.auto_download = auto_download + + self.root_path = root_path + + self.split = split + + split_mapping = {'train': 'train', 'val': 'valid', 'test': 'test'} + task = geobench.load_task_specs(self.root_path) + self.dataset = task.get_dataset(split=split_mapping[self.split]) + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, index): + sample = self.dataset[index] + all_band_names = ( + "01", + "02", + "03", + "04", + "05", + "06", + "07", + "08", + "08A", + "09", + "10", + "11", + "12", + ) + rgb_bands = ("04", "03", "02") + + BAND_SETS = {"all": all_band_names, "rgb": rgb_bands} + image, band_names = sample.pack_to_3d(band_names=BAND_SETS["all"]) + label = sample.label + filename = sample.sample_name + + image = torch.from_numpy(image.transpose(2, 0, 1)).float() + image = image / 4095 + image = np.clip(image, 0, 1) + + image=image.unsqueeze(1) + + + return { + "image": { + "optical": image, + }, + "target": torch.tensor(label, dtype=torch.int64), + "metadata": { + "filename": filename}, + } + + def download(self, silent=False): + local_directory = Path(os.getenv("GEO_BENCH_DIR")) + dataset_repo = self.download_url + + local_directory.mkdir(parents=True, exist_ok=True) + + api = HfApi() + 
dataset_files = api.list_repo_files(repo_id=dataset_repo, repo_type="dataset") + + for file in dataset_files: + + if file not in ['classification_v1.0/m-eurosat.zip', 'classification_v1.0/normalizer.json']: + continue + + local_file_path = local_directory / file + + local_file_path.parent.mkdir(parents=True, exist_ok=True) + + print(f"Downloading {file}...") + hf_hub_download( + repo_id=dataset_repo, + filename=file, + cache_dir=local_directory, + local_dir=local_directory, + repo_type="dataset", + ) + if file.endswith(".zip"): + print(f"Decompressing ...") + decompress_zip_with_progress(local_directory / file) + + \ No newline at end of file diff --git a/pangaea/datasets/utils.py b/pangaea/datasets/utils.py index 9b28cd88..b88d0f72 100644 --- a/pangaea/datasets/utils.py +++ b/pangaea/datasets/utils.py @@ -4,7 +4,7 @@ import pathlib import concurrent.futures from google.cloud.storage import Client - +import zipfile # Utility progress bar handler for urlretrieve @@ -93,4 +93,25 @@ def read_tif_with_metadata(file: pathlib.Path): arr = dataset.read() # (bands X height X width) transform = dataset.transform crs = dataset.crs - return arr.transpose((1, 2, 0)), transform, crs \ No newline at end of file + return arr.transpose((1, 2, 0)), transform, crs + + +def decompress_zip_with_progress(zip_file_path, extract_to_folder=None): + """Decompress a zip file with a progress bar and remove the symlink.""" + if extract_to_folder is None: + extract_to_folder = zip_file_path.parent + + with zipfile.ZipFile(zip_file_path, "r") as zip_ref: + file_names = zip_ref.namelist() + total_files = len(file_names) + + # Initialize the progress bar with the total number of files + with tqdm.tqdm(total=total_files, unit="file", desc=f"Extracting {zip_file_path.name}") as pbar: + for file in file_names: + # Extract each file + zip_ref.extract(file, extract_to_folder) + # Update the progress bar + pbar.update(1) + + # remove zip file + zip_file_path.unlink() \ No newline at end of file diff 
--git a/pangaea/decoders/linearclassifier.py b/pangaea/decoders/linearclassifier.py new file mode 100644 index 00000000..db9369eb --- /dev/null +++ b/pangaea/decoders/linearclassifier.py @@ -0,0 +1,113 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from pangaea.decoders.base import Decoder +from pangaea.decoders.ltae import LTAE2d, LTAEChannelAdaptor +from pangaea.encoders.base import Encoder + + +class LinearClassifier(Decoder): + + def __init__( + self, + encoder: Encoder, + num_classes: int, + finetune: bool, + feature_multiplier: int = 1, + in_channels: list[int] | None = None, + ): + super().__init__( + encoder=encoder, + num_classes=num_classes, + finetune=finetune, + ) + + self.model_name = "LinearClassifier" + self.encoder = encoder + self.finetune = finetune + self.feature_multiplier = feature_multiplier + + if not self.finetune: + for param in self.encoder.parameters(): + param.requires_grad = False + + self.input_layers = self.encoder.output_layers + self.input_layers_num = len(self.input_layers) + + if in_channels is None: + self.in_channels = [ + dim * feature_multiplier for dim in self.encoder.output_dim + ] + else: + self.in_channels = [dim * feature_multiplier for dim in in_channels] + + self.in_channels = sum(self.in_channels) + # self.in_channels = self.in_channels[-1] + + + self.num_classes = num_classes + self.linear_head = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + nn.Flatten(), + nn.Linear(self.in_channels, self.num_classes)) + + # self.linear_head = nn.Linear(self.in_channels, num_classes) + + + + def forward( + self, img: dict[str, torch.Tensor], output_shape: torch.Size | None = None + ) -> torch.Tensor: + """Compute the segmentation output. + + Args: + img (dict[str, torch.Tensor]): input data structured as a dictionary: + img = {modality1: tensor1, modality2: tensor2, ...}, e.g. img = {"optical": tensor1, "sar": tensor2}. 
+ with tensor1 and tensor2 of shape (B C T=1 H W) with C the number of encoders'bands for the given modality. + output_shape (torch.Size | None, optional): output's spatial dims (H, W) (equals to the target spatial dims). + Defaults to None. + + Returns: + torch.Tensor: output tensor of shape (B, num_classes, H', W') with (H' W') coressponding to the output_shape. + """ + + # img[modality] of shape [B C T=1 H W] + if self.encoder.multi_temporal: + if not self.finetune: + with torch.no_grad(): + feat = self.encoder(img) + else: + feat = self.encoder(img) + + # multi_temporal models can return either (B C' T=1 H' W') + # or (B C' H' W'), we need (B C' H' W') + if self.encoder.multi_temporal_output: + feat = [f.squeeze(-3) for f in feat] + + else: + # remove the temporal dim + # [B C T=1 H W] -> [B C H W] + if not self.finetune: + with torch.no_grad(): + feat = self.encoder({k: v[:, :, 0, :, :] for k, v in img.items()}) + else: + feat = self.encoder({k: v[:, :, 0, :, :] for k, v in img.items()}) + + shapes = torch.tensor([f.shape[2:] for f in feat]) # Extract H, W for each tensor + max_h, max_w = shapes.max(dim=0).values + max_h = max_h.item() + max_w = max_w.item() + resized_feats = [] + for f in feat: + if f.shape[2:] != (max_h, max_w): + resized_feats.append(F.interpolate(f, size=(max_h, max_w), mode='bilinear', align_corners=False)) + else: + resized_feats.append(f) + + final_feat = torch.cat(resized_feats, dim=1) + # final_feat = resized_feats[-1] + + logits = self.linear_head(final_feat) + + return logits \ No newline at end of file diff --git a/pangaea/engine/data_preprocessor.py b/pangaea/engine/data_preprocessor.py index 2784431b..c229fd9c 100644 --- a/pangaea/engine/data_preprocessor.py +++ b/pangaea/engine/data_preprocessor.py @@ -34,9 +34,9 @@ def check_dimension(self, data: dict[str, torch.Tensor | dict[str, torch.Tensor] f"Image dimension must be 4 (C, T, H, W), Got {str(len(v.shape))}" ) - if len(data["target"].shape) != 2: + if 
len(data["target"].shape) not in (0, 1, 2): raise AssertionError( - f"Target dimension must be 2 (H, W), Got {str(len(data['target'].shape))}" + f"Target dimension must be 0 (for classification - single label) or 1 (for classification - multi-label) or 2 (for dense prediction), got {str(len(data['target'].shape))}" ) def check_size(self, data: dict[str, torch.Tensor | dict[str, torch.Tensor]]): diff --git a/pangaea/engine/evaluator.py b/pangaea/engine/evaluator.py index d3c770cb..1d8e1b46 100644 --- a/pangaea/engine/evaluator.py +++ b/pangaea/engine/evaluator.py @@ -3,6 +3,8 @@ import time from pathlib import Path import math +import numpy as np +import sklearn.metrics import wandb import torch @@ -128,6 +130,137 @@ def sliding_inference(model, img, input_size, output_shape=None, stride=None, ma return merged_pred +class ClassificationEvaluator(Evaluator): + def __init__( + self, + val_loader, + exp_dir: str | Path, + device: torch.device, + inference_mode: str = "whole", + sliding_inference_batch: int = None, + use_wandb: bool = False, + multi_label: bool = False, # Flag to indicate multi-label evaluation + topk: int = 1, # For multi-label: if > 1, use top-k selection + ) -> None: + super().__init__(val_loader, exp_dir, device, inference_mode, sliding_inference_batch, use_wandb) + self.multi_label = multi_label + self.topk = topk + + def evaluate( + self, + model: torch.nn.Module, + model_name: str, + model_ckpt_path: str | Path | None = None): + + t = time.time() + if model_ckpt_path is not None: + model_dict = torch.load(model_ckpt_path, map_location=self.device) + model_name = os.path.basename(model_ckpt_path).split(".")[0] + if "model" in model_dict: + model.module.load_state_dict(model_dict["model"]) + else: + model.module.load_state_dict(model_dict) + self.logger.info(f"Loaded {model_name} for evaluation") + + model.eval() + + all_preds = [] + all_targets = [] + total_correct = 0 + total_samples = 0 + + tag = f"Evaluating {model_name} on {self.split} set" + 
for batch_idx, data in enumerate(tqdm(self.val_loader, desc=tag)): + image, target = data["image"], data["target"] + image = {k: v.to(self.device) for k, v in image.items()} + target = target.to(self.device) + + with torch.no_grad(): + logits = model(image) + + if self.multi_label: + # Multi-label evaluation: + # Option 1: If topk > 1, select top-k indices; otherwise, threshold at 0.5. + preds_prob = torch.sigmoid(logits) + if self.topk > 1: + topk_indices = preds_prob.topk(self.topk, dim=1).indices # shape: (B, topk) + preds = torch.zeros_like(preds_prob, dtype=torch.int) + preds.scatter_(1, topk_indices, 1) + else: + preds = (preds_prob > 0.5).int() + + all_preds.append(preds.cpu().numpy()) + all_targets.append(target.cpu().numpy()) + else: + preds = torch.argmax(logits, dim=1) + + total_correct += (preds == target).sum().item() + total_samples += target.numel() + all_preds.append(preds.cpu().numpy()) + all_targets.append(target.cpu().numpy()) + + all_preds = np.concatenate(all_preds, axis=0) + all_targets = np.concatenate(all_targets, axis=0) + + + if self.multi_label: + # For multi-label, accuracy is computed as the subset accuracy. + accuracy = sklearn.metrics.accuracy_score(all_targets, all_preds) + precision, recall, f1, _ = sklearn.metrics.precision_recall_fscore_support( + all_targets, all_preds, average="micro", zero_division=0) + else: + # For single-class tasks, overall accuracy is computed. 
+ accuracy = total_correct / total_samples if total_samples > 0 else 0 + + precision, recall, f1, _ = sklearn.metrics.precision_recall_fscore_support( + all_targets, all_preds,labels=list(range(self.num_classes)), average="macro", zero_division=0) + + metrics = { + "accuracy": accuracy, + "precision": precision, + "recall": recall, + "F1": f1, + } + + self.log_metrics(metrics) + used_time = time.time() - t + return metrics, used_time + + def __call__(self, model, model_name, model_ckpt_path=None): + return self.evaluate(model, model_name, model_ckpt_path) + + def compute_metrics(self): + pass + + def log_metrics(self, metrics: dict): + def format_metric(name, value): + header = f"------- {name} --------\n" + value_str = ( + "-------------------\n" + + "Mean".ljust(self.max_name_len, " ") + + "\t{:>7}".format("%.3f" % value) + ) + return header + value_str + + acc_str = format_metric("Accuracy", metrics["accuracy"]) + prec_str = format_metric("Precision", metrics["precision"]) + recall_str = format_metric("Recall", metrics["recall"]) + f1_str = format_metric("F1-score", metrics["F1"]) + self.logger.info(acc_str) + self.logger.info(prec_str) + self.logger.info(recall_str) + self.logger.info(f1_str) + + if self.use_wandb and self.rank == 0: + wandb.log({ + f"{self.split}_accuracy": metrics["accuracy"], + f"{self.split}_precision": metrics["precision"], + f"{self.split}_recall": metrics["recall"], + f"{self.split}_f1": metrics["F1"], + }) + + + class SegEvaluator(Evaluator): """ SegEvaluator is a class for evaluating segmentation models. 
It extends the Evaluator class and provides methods diff --git a/pangaea/engine/trainer.py b/pangaea/engine/trainer.py index 2291f0ab..50a4450e 100644 --- a/pangaea/engine/trainer.py +++ b/pangaea/engine/trainer.py @@ -354,6 +354,129 @@ def reset_stats(self) -> None: v.reset() +class ClassificationTrainer(Trainer): + def __init__( + self, + model: nn.Module, + train_loader: DataLoader, + criterion: nn.Module, + optimizer: Optimizer, + lr_scheduler: LRScheduler, + evaluator: torch.nn.Module, + n_epochs: int, + exp_dir: pathlib.Path | str, + device: torch.device, + precision: str, + use_wandb: bool, + ckpt_interval: int, + eval_interval: int, + log_interval: int, + best_metric_key: str, + multi_label: bool = False, # <-- Flag for multi-label classification, e.g., BigEarthNet dataset + topk: int = 1, # Top-k predictions to use in multi-label scenario + ): + """Initialize the Trainer for Classification task. + + Args: + model (nn.Module): model to train (encoder + decoder). + train_loader (DataLoader): train data loader. + criterion (nn.Module): criterion to compute the loss. + optimizer (Optimizer): optimizer to update the model's parameters. + lr_scheduler (LRScheduler): lr scheduler to update the learning rate. + evaluator (torch.nn.Module): task evaluator to evaluate the model. + n_epochs (int): number of epochs to train the model. + exp_dir (pathlib.Path | str): path to the experiment directory. + device (torch.device): model + precision (str): precision to train the model (fp32, fp16, bfp16). + use_wandb (bool): whether to use wandb for logging. + ckpt_interval (int): interval to save the checkpoint. + eval_interval (int): interval to evaluate the model. + log_interval (int): interval to log the training information. + best_metric_key (str): metric that determines best checkpoints. + multi_label (bool): Flag to enable multi-label classification. 
+ """ + super().__init__( + model=model, + train_loader=train_loader, + criterion=criterion, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + evaluator=evaluator, + n_epochs=n_epochs, + exp_dir=exp_dir, + device=device, + precision=precision, + use_wandb=use_wandb, + ckpt_interval=ckpt_interval, + eval_interval=eval_interval, + log_interval=log_interval, + best_metric_key=best_metric_key, + ) + + self.multi_label = multi_label + self.topk = topk + + self.training_metrics = { + name: RunningAverageMeter(length=100) for name in ["accuracy", "F1"] + } + self.best_metric = float("-inf") + self.best_metric_comp = operator.gt + + def compute_loss(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + + return self.criterion(logits, target) + + def compute_logging_metrics( + self, logits: torch.Tensor, targets: torch.Tensor + ) -> None: + """Compute logging metrics. + For multi-label: + - Uses sigmoid activation and top-k selection. + For single-class: + - Uses argmax and converts predictions to one-hot encoding. + + Args: + logits (torch.Tensor): logits from the decoder. + target (torch.Tensor): target tensor. + """ + if self.multi_label: + preds_prob = torch.sigmoid(logits) + topk_indices = preds_prob.topk(self.topk, dim=1).indices + preds = torch.zeros_like(preds_prob, dtype=torch.bool) + preds.scatter_(1, topk_indices, 1) + else: + preds = torch.argmax(logits, dim=1) + + one_hot_preds = torch.zeros( + size=(preds.size(0), self.num_classes), + device=preds.device, + dtype=torch.bool + ) + one_hot_preds.scatter_(1, preds.unsqueeze(1), 1) + preds = one_hot_preds + # Convert targets to one-hot. + one_hot_targets = torch.zeros_like(preds) + one_hot_targets.scatter_(1, targets.unsqueeze(1), 1) + targets = one_hot_targets + + # Micro-average: aggregate across all classes. 
+ preds = preds.bool() + targets = targets.bool() + TP = (preds & targets).sum().float() + FP = (preds & ~targets).sum().float() + FN = (~preds & targets).sum().float() + TN = (~preds & ~targets).sum().float() + + acc = (TP + TN) / (TP + TN + FP + FN + 1e-8) + precision = TP / (TP + FP + 1e-8) + recall = TP / (TP + FN + 1e-8) + f1 = 2 * precision * recall / (precision + recall + 1e-8) + + self.training_metrics["accuracy"].update(acc.item()) + self.training_metrics["F1"].update(f1.item()) + + + class SegTrainer(Trainer): def __init__( self, diff --git a/pangaea/utils/collate_fn.py b/pangaea/utils/collate_fn.py index e375794e..924ce469 100644 --- a/pangaea/utils/collate_fn.py +++ b/pangaea/utils/collate_fn.py @@ -44,6 +44,7 @@ def collate_fn( for modality in modalities }, "target": torch.stack([x["target"] for x in batch]), + "metadata": [sample["metadata"] for sample in batch] } return collate_fn diff --git a/pangaea/utils/subset_sampler.py b/pangaea/utils/subset_sampler.py index e8337362..a5a1dd5a 100644 --- a/pangaea/utils/subset_sampler.py +++ b/pangaea/utils/subset_sampler.py @@ -1,6 +1,7 @@ import random from tqdm import tqdm import numpy as np +from collections import defaultdict from pangaea.datasets.base import GeoFMDataset from pangaea.datasets.base import GeoFMSubset @@ -60,6 +61,32 @@ def bin_regression_distributions(regression_distributions, num_bins=3, logger=No ) - 1 return binned_distributions +def balance_cls_indices( + dataset:GeoFMDataset|GeoFMSubset, + strategy, + label_fraction=1.0, + logger=None): + + indices_by_class = defaultdict(list) + + n_samples = len(dataset) + for idx in range(n_samples): + label = dataset[idx]['target'] + indices_by_class[label].append(idx) + + selected_idx = [] + # For each class, sample the same fraction of indices + if strategy == "stratified": + for label, indices in indices_by_class.items(): + num_to_select = max(1, int(len(indices) * label_fraction)) + selected_idx.extend(random.sample(indices, num_to_select)) + 
else: + raise NotImplementedError + + other_idx = list(set(range(len(dataset))) - set(selected_idx)) + + return selected_idx, other_idx + def balance_seg_indices( dataset:GeoFMDataset|GeoFMSubset, @@ -212,6 +239,11 @@ def get_subset_indices(dataset: GeoFMDataset, ) return indices + elif task == "classification" or task == "classification_multi_label": + indices, _ = balance_cls_indices( + dataset, strategy=strategy, label_fraction=label_fraction, logger=logger + ) + elif task == "segmentation" or task == "change_detection": indices, _ = balance_seg_indices( dataset, strategy=strategy, label_fraction=label_fraction, num_bins=num_bins, logger=logger From e2e886add6389176bac5c508144cc3ee35d365c1 Mon Sep 17 00:00:00 2001 From: yurujaja Date: Tue, 18 Mar 2025 14:12:43 +0100 Subject: [PATCH 2/6] add all rest geobench classification datasets --- configs/dataset/mbrickkiln.yaml | 37 +++++ configs/dataset/mforestnet.yaml | 34 +++++ configs/dataset/mpv4ger.yaml | 30 ++++ configs/dataset/mso2sat.yaml | 38 +++++ pangaea/datasets/geobench/mbrickkiln.py | 175 ++++++++++++++++++++++++ pangaea/datasets/geobench/mforestnet.py | 170 +++++++++++++++++++++++ pangaea/datasets/geobench/mpv4ger.py | 158 +++++++++++++++++++++ pangaea/datasets/geobench/mso2sat.py | 175 ++++++++++++++++++++++++ 8 files changed, 817 insertions(+) create mode 100644 configs/dataset/mbrickkiln.yaml create mode 100644 configs/dataset/mforestnet.yaml create mode 100644 configs/dataset/mpv4ger.yaml create mode 100644 configs/dataset/mso2sat.yaml create mode 100644 pangaea/datasets/geobench/mbrickkiln.py create mode 100644 pangaea/datasets/geobench/mforestnet.py create mode 100644 pangaea/datasets/geobench/mpv4ger.py create mode 100644 pangaea/datasets/geobench/mso2sat.py diff --git a/configs/dataset/mbrickkiln.yaml b/configs/dataset/mbrickkiln.yaml new file mode 100644 index 00000000..457dfa07 --- /dev/null +++ b/configs/dataset/mbrickkiln.yaml @@ -0,0 +1,37 @@ +_target_: 
pangaea.datasets.geobench.mbrickkiln.mBrickKiln +dataset_name: mBrickKiln +root_path: ${oc.env:GEO_BENCH_DIR}/classification_v1.0/m-brick-kiln +download_url: "recursix/geo-bench-1.0" +auto_download: True + +num_classes: 2 +img_size: 64 +multi_temporal: False +multi_modal: False + +ignore_index: -100 +classes: ['', ''] +distribution: [0, 0] +bands: + optical: + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B8 + - B8A + - B9 + - B10 + - B11 + - B12 +data_mean: + optical: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] +data_std: + optical: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] +data_min: + optical: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] +data_max: + optical: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] \ No newline at end of file diff --git a/configs/dataset/mforestnet.yaml b/configs/dataset/mforestnet.yaml new file mode 100644 index 00000000..7d40bc43 --- /dev/null +++ b/configs/dataset/mforestnet.yaml @@ -0,0 +1,34 @@ +_target_: pangaea.datasets.geobench.mforestnet.mForestnet +dataset_name: mForestnet +root_path: ${oc.env:GEO_BENCH_DIR}/classification_v1.0/m-forestnet +download_url: "recursix/geo-bench-1.0" +auto_download: True + +num_classes: 12 +img_size: 332 +multi_temporal: False +multi_modal: False + +ignore_index: -100 +classes: ['', '', '', '', '', '', '', '', '', '', '', ''] +distribution: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + +# data stats +bands: + optical: + - B4 + - B3 + - B2 + - B5 + - B6 + - B7 + +data_mean: + optical: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + +data_std: + optical: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +data_min: + optical: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +data_max: + optical: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] \ No newline at end of file diff --git a/configs/dataset/mpv4ger.yaml b/configs/dataset/mpv4ger.yaml new file mode 100644 index 00000000..b0617c9a --- /dev/null +++ b/configs/dataset/mpv4ger.yaml @@ -0,0 +1,30 @@ +_target_: pangaea.datasets.geobench.mpv4ger.mPv4ger +dataset_name: mPv4ger +root_path: ${oc.env:GEO_BENCH_DIR}/classification_v1.0/m-pv4ger 
+download_url: "recursix/geo-bench-1.0" +auto_download: True + +num_classes: 2 +img_size: 320 +multi_temporal: False +multi_modal: False + +ignore_index: -100 +classes: ['', ''] +distribution: [0, 0] + +# data stats +bands: + optical: + - B4 + - B3 + - B2 + +data_mean: + optical: [0.0, 0.0, 0.0] +data_std: + optical: [0.0, 0.0, 0.0] +data_min: + optical: [0.0, 0.0, 0.0] +data_max: + optical: [0.0, 0.0, 0.0] \ No newline at end of file diff --git a/configs/dataset/mso2sat.yaml b/configs/dataset/mso2sat.yaml new file mode 100644 index 00000000..7582d6c0 --- /dev/null +++ b/configs/dataset/mso2sat.yaml @@ -0,0 +1,38 @@ +_target_: pangaea.datasets.geobench.mso2sat.mSo2Sat +dataset_name: mSo2Sat +root_path: ${oc.env:GEO_BENCH_DIR}/classification_v1.0/m-so2sat +download_url: "recursix/geo-bench-1.0" +auto_download: True + + +num_classes: 17 +img_size: 32 +multi_temporal: False +multi_modal: False + +bands: + optical: + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B8 + - B8A + - B11 + - B12 + +ignore_index: -100 +classes: ['', '', '', '', '', '', '', '', '', '', '', '','', '', '', '', ''] +distribution: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + +data_mean: + optical: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + +data_std: + optical: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +data_min: + optical: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +data_max: + optical: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] \ No newline at end of file diff --git a/pangaea/datasets/geobench/mbrickkiln.py b/pangaea/datasets/geobench/mbrickkiln.py new file mode 100644 index 00000000..c2d44259 --- /dev/null +++ b/pangaea/datasets/geobench/mbrickkiln.py @@ -0,0 +1,175 @@ + +import numpy as np +import torch +import os +from pathlib import Path +from pangaea.datasets.base import RawGeoFMDataset +from pangaea.datasets.utils import decompress_zip_with_progress +from huggingface_hub import HfApi, hf_hub_download +import geobench + + +class 
mBrickKiln(RawGeoFMDataset): + def __init__( + self, + split: str, + dataset_name: str, + multi_modal: bool, + multi_temporal: int, + root_path: str, + classes: list, + num_classes: int, + ignore_index: int, + img_size: int, + bands: dict[str, list[str]], + distribution: list[int], + data_mean: dict[str, list[str]], + data_std: dict[str, list[str]], + data_min: dict[str, list[str]], + data_max: dict[str, list[str]], + download_url: str, + auto_download: bool, + ): + """Initialize the m-Brick-Kiln dataset. + Link: https://github.com/ServiceNow/geo-bench + + Args: + split (str): split of the dataset (train, val, test). + dataset_name (str): dataset name. + multi_modal (bool): if the dataset is multi-modal. + multi_temporal (int): number of temporal frames. + root_path (str): root path of the dataset. + classes (list): classes of the dataset. + num_classes (int): number of classes. + ignore_index (int): index to ignore for metrics and loss. + img_size (int): size of the image. + bands (dict[str, list[str]]): bands of the dataset. + distribution (list[int]): class distribution. + data_mean (dict[str, list[str]]): mean for each band for each modality. + Dictionary with keys as the modality and values as the list of means. + e.g. {"s2": [b1_mean, ..., bn_mean], "s1": [b1_mean, ..., bn_mean]} + data_std (dict[str, list[str]]): std for each band for each modality. + Dictionary with keys as the modality and values as the list of stds. + e.g. {"s2": [b1_std, ..., bn_std], "s1": [b1_std, ..., bn_std]} + data_min (dict[str, list[str]]): min for each band for each modality. + Dictionary with keys as the modality and values as the list of mins. + e.g. {"s2": [b1_min, ..., bn_min], "s1": [b1_min, ..., bn_min]} + data_max (dict[str, list[str]]): max for each band for each modality. + Dictionary with keys as the modality and values as the list of maxs. + e.g. {"s2": [b1_max, ..., bn_max], "s1": [b1_max, ..., bn_max]} + download_url (str): url to download the dataset.
+ auto_download (bool): whether to download the dataset automatically. + """ + super(mBrickKiln, self).__init__( + split=split, + dataset_name=dataset_name, + multi_modal=multi_modal, + multi_temporal=multi_temporal, + root_path=root_path, + classes=classes, + num_classes=num_classes, + ignore_index=ignore_index, + img_size=img_size, + bands=bands, + distribution=distribution, + data_mean=data_mean, + data_std=data_std, + data_min=data_min, + data_max=data_max, + download_url=download_url, + auto_download=auto_download, + ) + + self.data_mean = data_mean + self.data_std = data_std + self.data_min = data_min + self.data_max = data_max + self.classes = classes + self.img_size = img_size + self.distribution = distribution + self.num_classes = num_classes + self.ignore_index = ignore_index + self.download_url = download_url + self.auto_download = auto_download + + self.root_path = root_path + self.split = split + + split_mapping = {'train': 'train', 'val': 'valid', 'test': 'test'} + + task = geobench.load_task_specs(self.root_path) + self.dataset = task.get_dataset(split=split_mapping[self.split]) + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, index): + sample = self.dataset[index] + + all_band_names = ( + "01", + "02", + "03", + "04", + "05", + "06", + "07", + "08", + "08A", + "09", + "10", + "11", + "12", + ) + rgb_bands = ("04", "03", "02") + BAND_SETS = {"all": all_band_names, "rgb": rgb_bands} + image, band_names = sample.pack_to_3d(band_names=BAND_SETS["all"]) + label = sample.label + filename = sample.sample_name + + image = torch.from_numpy(image.transpose(2, 0, 1)).float() + image = image / 4095 + image = np.clip(image, 0, 1) + + image=image.unsqueeze(1) + + return { + "image": { + "optical": image, + }, + "target": torch.tensor(label, dtype=torch.int64), + "metadata": { + "filename": filename}, + } + + def download(self, silent=False): + local_directory = Path(os.getenv("GEO_BENCH_DIR")) + dataset_repo = self.download_url + + 
local_directory.mkdir(parents=True, exist_ok=True) + + api = HfApi() + dataset_files = api.list_repo_files(repo_id=dataset_repo, repo_type="dataset") + + for file in dataset_files: + + if file not in ['classification_v1.0/m-brick-kiln.zip', 'classification_v1.0/normalizer.json']: + continue + + local_file_path = local_directory / file + + local_file_path.parent.mkdir(parents=True, exist_ok=True) + + print(f"Downloading {file}...") + hf_hub_download( + repo_id=dataset_repo, + filename=file, + cache_dir=local_directory, + local_dir=local_directory, + repo_type="dataset", + ) + if file.endswith(".zip"): + print(f"Decompressing ...") + decompress_zip_with_progress(local_directory / file) + + \ No newline at end of file diff --git a/pangaea/datasets/geobench/mforestnet.py b/pangaea/datasets/geobench/mforestnet.py new file mode 100644 index 00000000..69c9a96e --- /dev/null +++ b/pangaea/datasets/geobench/mforestnet.py @@ -0,0 +1,170 @@ + +import numpy as np +import torch +import os +from pathlib import Path +from pangaea.datasets.utils import decompress_zip_with_progress +from huggingface_hub import HfApi, hf_hub_download +from pangaea.datasets.base import RawGeoFMDataset +import geobench +from torchvision import transforms + +class mForestnet(RawGeoFMDataset): + def __init__( + self, + split: str, + dataset_name: str, + multi_modal: bool, + multi_temporal: int, + root_path: str, + classes: list, + num_classes: int, + ignore_index: int, + img_size: int, + bands: dict[str, list[str]], + distribution: list[int], + data_mean: dict[str, list[str]], + data_std: dict[str, list[str]], + data_min: dict[str, list[str]], + data_max: dict[str, list[str]], + download_url: str, + auto_download: bool, + ): + """Initialize the mForestNet dataset. + Link: https://github.com/ServiceNow/geo-bench + + Args: + split (str): split of the dataset (train, val, test). + dataset_name (str): dataset name. + multi_modal (bool): if the dataset is multi-modal.
+ multi_temporal (int): number of temporal frames. + root_path (str): root path of the dataset. + classes (list): classes of the dataset. + num_classes (int): number of classes. + ignore_index (int): index to ignore for metrics and loss. + img_size (int): size of the image. + bands (dict[str, list[str]]): bands of the dataset. + distribution (list[int]): class distribution. + data_mean (dict[str, list[str]]): mean for each band for each modality. + Dictionary with keys as the modality and values as the list of means. + e.g. {"s2": [b1_mean, ..., bn_mean], "s1": [b1_mean, ..., bn_mean]} + data_std (dict[str, list[str]]): std for each band for each modality. + Dictionary with keys as the modality and values as the list of stds. + e.g. {"s2": [b1_std, ..., bn_std], "s1": [b1_std, ..., bn_std]} + data_min (dict[str, list[str]]): min for each band for each modality. + Dictionary with keys as the modality and values as the list of mins. + e.g. {"s2": [b1_min, ..., bn_min], "s1": [b1_min, ..., bn_min]} + data_max (dict[str, list[str]]): max for each band for each modality. + Dictionary with keys as the modality and values as the list of maxs. + e.g. {"s2": [b1_max, ..., bn_max], "s1": [b1_max, ..., bn_max]} + download_url (str): url to download the dataset. + auto_download (bool): whether to download the dataset automatically.
+ """ + super(mForestnet, self).__init__( + split=split, + dataset_name=dataset_name, + multi_modal=multi_modal, + multi_temporal=multi_temporal, + root_path=root_path, + classes=classes, + num_classes=num_classes, + ignore_index=ignore_index, + img_size=img_size, + bands=bands, + distribution=distribution, + data_mean=data_mean, + data_std=data_std, + data_min=data_min, + data_max=data_max, + download_url=download_url, + auto_download=auto_download, + ) + + self.data_mean = data_mean + self.data_std = data_std + self.data_min = data_min + self.data_max = data_max + self.classes = classes + self.img_size = img_size + self.distribution = distribution + self.num_classes = num_classes + self.ignore_index = ignore_index + self.download_url = download_url + self.auto_download = auto_download + + self.root_path = root_path + self.split = split + split_mapping = {'train': 'train', 'val': 'valid', 'test': 'test'} + + task = geobench.load_task_specs(self.root_path) + self.dataset = task.get_dataset(split=split_mapping[self.split]) + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, index): + sample = self.dataset[index] + # for band in sample.bands: + # print(f" {band.band_info.name}: {band.data.shape}") + all_band_names = ( + "04", + "03", + "02", + "05", + "06", + "07", + ) + rgb_bands = ("04", "03", "02") + BAND_SETS = {"all": all_band_names, "rgb": rgb_bands} + + image, band_names = sample.pack_to_3d(band_names=BAND_SETS["all"]) + label = sample.label + filename = sample.sample_name + + image = torch.from_numpy(image.transpose(2, 0, 1)).float() + image = image / 255 + image = np.clip(image, 0, 1) + + image=image.unsqueeze(1) + + + return { + "image": { + "optical": image, + }, + "target": torch.tensor(label, dtype=torch.int64), + "metadata": { + "filename": filename}, + } + + def download(self, silent=False): + local_directory = Path(os.getenv("GEO_BENCH_DIR")) + dataset_repo = self.download_url + + local_directory.mkdir(parents=True, 
exist_ok=True) + + api = HfApi() + dataset_files = api.list_repo_files(repo_id=dataset_repo, repo_type="dataset") + + for file in dataset_files: + + if file not in ['classification_v1.0/m-forestnet.zip', 'classification_v1.0/normalizer.json']: + continue + + local_file_path = local_directory / file + + local_file_path.parent.mkdir(parents=True, exist_ok=True) + + print(f"Downloading {file}...") + hf_hub_download( + repo_id=dataset_repo, + filename=file, + cache_dir=local_directory, + local_dir=local_directory, + repo_type="dataset", + ) + if file.endswith(".zip"): + print(f"Decompressing ...") + decompress_zip_with_progress(local_directory / file) + + \ No newline at end of file diff --git a/pangaea/datasets/geobench/mpv4ger.py b/pangaea/datasets/geobench/mpv4ger.py new file mode 100644 index 00000000..3fd238f2 --- /dev/null +++ b/pangaea/datasets/geobench/mpv4ger.py @@ -0,0 +1,158 @@ +import numpy as np +import torch +import os +from pathlib import Path +from pangaea.datasets.base import RawGeoFMDataset +from pangaea.datasets.utils import decompress_zip_with_progress +from huggingface_hub import HfApi, hf_hub_download +import geobench + + +class mPv4ger(RawGeoFMDataset): + def __init__( + self, + split: str, + dataset_name: str, + multi_modal: bool, + multi_temporal: int, + root_path: str, + classes: list, + num_classes: int, + ignore_index: int, + img_size: int, + bands: dict[str, list[str]], + distribution: list[int], + data_mean: dict[str, list[str]], + data_std: dict[str, list[str]], + data_min: dict[str, list[str]], + data_max: dict[str, list[str]], + download_url: str, + auto_download: bool, + ): + """Initialize the m-PV4-Ger dataset. + Link: https://github.com/ServiceNow/geo-bench + Args: + split (str): split of the dataset (train, val, test). + dataset_name (str): dataset name. + multi_modal (bool): if the dataset is multi-modal. + multi_temporal (int): number of temporal frames. + root_path (str): root path of the dataset. 
+ classes (list): classes of the dataset. + num_classes (int): number of classes. + ignore_index (int): index to ignore for metrics and loss. + img_size (int): size of the image. + bands (dict[str, list[str]]): bands of the dataset. + distribution (list[int]): class distribution. + data_mean (dict[str, list[str]]): mean for each band for each modality. + Dictionary with keys as the modality and values as the list of means. + e.g. {"s2": [b1_mean, ..., bn_mean], "s1": [b1_mean, ..., bn_mean]} + data_std (dict[str, list[str]]): std for each band for each modality. + Dictionary with keys as the modality and values as the list of stds. + e.g. {"s2": [b1_std, ..., bn_std], "s1": [b1_std, ..., bn_std]} + data_min (dict[str, list[str]]): min for each band for each modality. + Dictionary with keys as the modality and values as the list of mins. + e.g. {"s2": [b1_min, ..., bn_min], "s1": [b1_min, ..., bn_min]} + data_max (dict[str, list[str]]): max for each band for each modality. + Dictionary with keys as the modality and values as the list of maxs. + e.g. {"s2": [b1_max, ..., bn_max], "s1": [b1_max, ..., bn_max]} + download_url (str): url to download the dataset. + auto_download (bool): whether to download the dataset automatically.
+ """ + super(mPv4ger, self).__init__( + split=split, + dataset_name=dataset_name, + multi_modal=multi_modal, + multi_temporal=multi_temporal, + root_path=root_path, + classes=classes, + num_classes=num_classes, + ignore_index=ignore_index, + img_size=img_size, + bands=bands, + distribution=distribution, + data_mean=data_mean, + data_std=data_std, + data_min=data_min, + data_max=data_max, + download_url=download_url, + auto_download=auto_download, + ) + + self.data_mean = data_mean + self.data_std = data_std + self.data_min = data_min + self.data_max = data_max + self.classes = classes + self.img_size = img_size + self.distribution = distribution + self.num_classes = num_classes + self.ignore_index = ignore_index + self.download_url = download_url + self.auto_download = auto_download + + self.root_path = root_path + self.split = split + + split_mapping = {'train': 'train', 'val': 'valid', 'test': 'test'} + + task = geobench.load_task_specs(self.root_path) + self.dataset = task.get_dataset(split=split_mapping[self.split]) + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, index): + sample = self.dataset[index] + + image, band_names = sample.pack_to_3d(band_names=["Red", "Green", "Blue"]) + label = sample.label + filename = sample.sample_name + + image = torch.from_numpy(image.transpose(2, 0, 1)).float() + image = image / 255 + image = np.clip(image, 0, 1) + + image=image.unsqueeze(1) + + + return { + "image": { + "optical": image, + }, + "target": torch.tensor(label, dtype=torch.int64), + "metadata": { + "filename": filename}, + } + + def download(self, silent=False): + local_directory = Path(os.getenv("GEO_BENCH_DIR")) + dataset_repo = self.download_url + + local_directory.mkdir(parents=True, exist_ok=True) + + api = HfApi() + dataset_files = api.list_repo_files(repo_id=dataset_repo, repo_type="dataset") + + for file in dataset_files: + + if file not in ['classification_v1.0/m-pv4ger.zip', 'classification_v1.0/normalizer.json']: + continue 
+ + local_file_path = local_directory / file + + local_file_path.parent.mkdir(parents=True, exist_ok=True) + + print(f"Downloading {file}...") + hf_hub_download( + repo_id=dataset_repo, + filename=file, + cache_dir=local_directory, + local_dir=local_directory, + repo_type="dataset", + ) + if file.endswith(".zip"): + print(f"Decompressing ...") + decompress_zip_with_progress(local_directory / file) + + + \ No newline at end of file diff --git a/pangaea/datasets/geobench/mso2sat.py b/pangaea/datasets/geobench/mso2sat.py new file mode 100644 index 00000000..4fca7d32 --- /dev/null +++ b/pangaea/datasets/geobench/mso2sat.py @@ -0,0 +1,175 @@ + +import numpy as np +import torch +import os +from pathlib import Path +from pangaea.datasets.base import RawGeoFMDataset +from pangaea.datasets.utils import decompress_zip_with_progress +from huggingface_hub import HfApi, hf_hub_download +import geobench + + +class mSo2Sat(RawGeoFMDataset): + def __init__( + self, + split: str, + dataset_name: str, + multi_modal: bool, + multi_temporal: int, + root_path: str, + classes: list, + num_classes: int, + ignore_index: int, + img_size: int, + bands: dict[str, list[str]], + distribution: list[int], + data_mean: dict[str, list[str]], + data_std: dict[str, list[str]], + data_min: dict[str, list[str]], + data_max: dict[str, list[str]], + download_url: str, + auto_download: bool, + ): + """Initialize the m-So2Sat dataset. + Link: https://github.com/ServiceNow/geo-bench + + Args: + split (str): split of the dataset (train, val, test). + dataset_name (str): dataset name. + multi_modal (bool): if the dataset is multi-modal. + multi_temporal (int): number of temporal frames. + root_path (str): root path of the dataset. + classes (list): classes of the dataset. + num_classes (int): number of classes. + ignore_index (int): index to ignore for metrics and loss. + img_size (int): size of the image. + bands (dict[str, list[str]]): bands of the dataset. + distribution (list[int]): class distribution. 
+ data_mean (dict[str, list[str]]): mean for each band for each modality. + Dictionary with keys as the modality and values as the list of means. + e.g. {"s2": [b1_mean, ..., bn_mean], "s1": [b1_mean, ..., bn_mean]} + data_std (dict[str, list[str]]): std for each band for each modality. + Dictionary with keys as the modality and values as the list of stds. + e.g. {"s2": [b1_std, ..., bn_std], "s1": [b1_std, ..., bn_std]} + data_min (dict[str, list[str]]): min for each band for each modality. + Dictionary with keys as the modality and values as the list of mins. + e.g. {"s2": [b1_min, ..., bn_min], "s1": [b1_min, ..., bn_min]} + data_max (dict[str, list[str]]): max for each band for each modality. + Dictionary with keys as the modality and values as the list of maxs. + e.g. {"s2": [b1_max, ..., bn_max], "s1": [b1_max, ..., bn_max]} + download_url (str): url to download the dataset. + auto_download (bool): whether to download the dataset automatically. + """ + super(mSo2Sat, self).__init__( + split=split, + dataset_name=dataset_name, + multi_modal=multi_modal, + multi_temporal=multi_temporal, + root_path=root_path, + classes=classes, + num_classes=num_classes, + ignore_index=ignore_index, + img_size=img_size, + bands=bands, + distribution=distribution, + data_mean=data_mean, + data_std=data_std, + data_min=data_min, + data_max=data_max, + download_url=download_url, + auto_download=auto_download, + ) + + self.data_mean = data_mean + self.data_std = data_std + self.data_min = data_min + self.data_max = data_max + self.classes = classes + self.img_size = img_size + self.distribution = distribution + self.num_classes = num_classes + self.ignore_index = ignore_index + self.download_url = download_url + self.auto_download = auto_download + + self.root_path = root_path + self.split = split + + split_mapping = {'train': 'train', 'val': 'valid', 'test': 'test'} + + task = geobench.load_task_specs(self.root_path) + self.dataset =
task.get_dataset(split=split_mapping[self.split]) + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, index): + sample = self.dataset[index] + # for band in sample.bands: + # print(f" {band.band_info.name}: {band.data.shape}") + all_band_names = ( + "02 - Blue", + "03 - Green", + "04 - Red", + "05 - Vegetation Red Edge", + "06 - Vegetation Red Edge", + "07 - Vegetation Red Edge", + "08 - NIR", + "08A - Vegetation Red Edge", + "11", + "12", + ) + + rgb_bands = ("04 - Red", "03 - Green", "02 - Blue") + + BAND_SETS = {"all": all_band_names, "rgb": rgb_bands} + image, band_names = sample.pack_to_3d(band_names=BAND_SETS["all"]) + label = sample.label + filename = sample.sample_name + + image = torch.from_numpy(image.transpose(2, 0, 1)).float() + image = np.clip(image, 0, 1) + + image=image.unsqueeze(1) + + + return { + "image": { + "optical": image, + }, + "target": torch.tensor(label, dtype=torch.int64), + "metadata": { + "filename": filename}, + } + + def download(self, silent=False): + local_directory = Path(os.getenv("GEO_BENCH_DIR")) + dataset_repo = self.download_url + + local_directory.mkdir(parents=True, exist_ok=True) + + api = HfApi() + dataset_files = api.list_repo_files(repo_id=dataset_repo, repo_type="dataset") + + for file in dataset_files: + + if file not in ['classification_v1.0/m-so2sat.zip', 'classification_v1.0/normalizer.json']: + continue + + local_file_path = local_directory / file + + local_file_path.parent.mkdir(parents=True, exist_ok=True) + + print(f"Downloading {file}...") + hf_hub_download( + repo_id=dataset_repo, + filename=file, + cache_dir=local_directory, + local_dir=local_directory, + repo_type="dataset", + ) + if file.endswith(".zip"): + print(f"Decompressing ...") + decompress_zip_with_progress(local_directory / file) + + \ No newline at end of file From e2a926f38e457b5c4a2170535a265f8cfb986849 Mon Sep 17 00:00:00 2001 From: yurujaja Date: Fri, 23 May 2025 14:26:31 +0200 Subject: [PATCH 3/6] conditionally 
install geobench at runtime --- pangaea/datasets/geobench/mbigearthnet.py | 9 ++++++++- pangaea/datasets/geobench/mbrickkiln.py | 9 ++++++++- pangaea/datasets/geobench/meurosat.py | 9 ++++++++- pangaea/datasets/geobench/mforestnet.py | 11 +++++++++-- pangaea/datasets/geobench/mpv4ger.py | 9 ++++++++- pangaea/datasets/geobench/mso2sat.py | 9 ++++++++- 6 files changed, 49 insertions(+), 7 deletions(-) diff --git a/pangaea/datasets/geobench/mbigearthnet.py b/pangaea/datasets/geobench/mbigearthnet.py index 84ab4767..755178e7 100644 --- a/pangaea/datasets/geobench/mbigearthnet.py +++ b/pangaea/datasets/geobench/mbigearthnet.py @@ -8,7 +8,14 @@ from pathlib import Path from tqdm import tqdm from huggingface_hub import HfApi, hf_hub_download -from torchvision import transforms +import subprocess +import sys +try: + import geobench +except ImportError: + print("geobench not found. Installing via pip...") + subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-deps", "geobench"]) + import geobench class mBigEarthNet(RawGeoFMDataset): diff --git a/pangaea/datasets/geobench/mbrickkiln.py b/pangaea/datasets/geobench/mbrickkiln.py index c2d44259..aabf1b53 100644 --- a/pangaea/datasets/geobench/mbrickkiln.py +++ b/pangaea/datasets/geobench/mbrickkiln.py @@ -6,7 +6,14 @@ from pangaea.datasets.base import RawGeoFMDataset from pangaea.datasets.utils import decompress_zip_with_progress from huggingface_hub import HfApi, hf_hub_download -import geobench +import subprocess +import sys +try: + import geobench +except ImportError: + print("geobench not found. 
Installing via pip...") + subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-deps", "geobench"]) + import geobench class mBrickKiln(RawGeoFMDataset): diff --git a/pangaea/datasets/geobench/meurosat.py b/pangaea/datasets/geobench/meurosat.py index e8def587..c6d1cc23 100644 --- a/pangaea/datasets/geobench/meurosat.py +++ b/pangaea/datasets/geobench/meurosat.py @@ -5,7 +5,14 @@ from pangaea.datasets.utils import decompress_zip_with_progress from pathlib import Path from huggingface_hub import HfApi, hf_hub_download -import geobench +import subprocess +import sys +try: + import geobench +except ImportError: + print("geobench not found. Installing via pip...") + subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-deps", "geobench"]) + import geobench class mEuroSat(RawGeoFMDataset): diff --git a/pangaea/datasets/geobench/mforestnet.py b/pangaea/datasets/geobench/mforestnet.py index 69c9a96e..777a6d49 100644 --- a/pangaea/datasets/geobench/mforestnet.py +++ b/pangaea/datasets/geobench/mforestnet.py @@ -6,8 +6,15 @@ from pangaea.datasets.utils import decompress_zip_with_progress from huggingface_hub import HfApi, hf_hub_download from pangaea.datasets.base import RawGeoFMDataset -import geobench -from torchvision import transforms +import subprocess +import sys +try: + import geobench +except ImportError: + print("geobench not found. 
Installing via pip...") + subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-deps", "geobench"]) + import geobench + class mForestnet(RawGeoFMDataset): def __init__( diff --git a/pangaea/datasets/geobench/mpv4ger.py b/pangaea/datasets/geobench/mpv4ger.py index 3fd238f2..91749e22 100644 --- a/pangaea/datasets/geobench/mpv4ger.py +++ b/pangaea/datasets/geobench/mpv4ger.py @@ -5,7 +5,14 @@ from pangaea.datasets.base import RawGeoFMDataset from pangaea.datasets.utils import decompress_zip_with_progress from huggingface_hub import HfApi, hf_hub_download -import geobench +import subprocess +import sys +try: + import geobench +except ImportError: + print("geobench not found. Installing via pip...") + subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-deps", "geobench"]) + import geobench class mPv4ger(RawGeoFMDataset): diff --git a/pangaea/datasets/geobench/mso2sat.py b/pangaea/datasets/geobench/mso2sat.py index 4fca7d32..f4b5d90e 100644 --- a/pangaea/datasets/geobench/mso2sat.py +++ b/pangaea/datasets/geobench/mso2sat.py @@ -6,7 +6,14 @@ from pangaea.datasets.base import RawGeoFMDataset from pangaea.datasets.utils import decompress_zip_with_progress from huggingface_hub import HfApi, hf_hub_download -import geobench +import subprocess +import sys +try: + import geobench +except ImportError: + print("geobench not found. 
Installing via pip...") + subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-deps", "geobench"]) + import geobench class mSo2Sat(RawGeoFMDataset): From a03ec0b0343232b7631271fd76b2dec0a271bb31 Mon Sep 17 00:00:00 2001 From: yurujaja Date: Wed, 4 Jun 2025 09:54:44 +0200 Subject: [PATCH 4/6] include segmentation tasks --- .github/CONTRIBUTING.md => CONTRIBUTING.md | 0 DATASET_GUIDE.md | 65 ++++++ README.md | 10 +- configs/dataset/mbigearthnet.yaml | 4 +- configs/dataset/mcashew-plantation.yaml | 54 +++++ configs/dataset/mchesapeake-landcover.yaml | 56 ++++++ configs/dataset/meurosat.yaml | 4 +- configs/dataset/mneontree.yaml | 43 ++++ configs/dataset/mnz-cattle.yaml | 43 ++++ configs/dataset/mpv4ger-seg.yaml | 43 ++++ configs/dataset/msa-crop-type.yaml | 77 ++++++++ pangaea/datasets/geobench/mbigearthnet.py | 4 +- pangaea/datasets/geobench/mbrickkiln.py | 4 +- .../datasets/geobench/mcashew-plantation.py | 185 ++++++++++++++++++ .../geobench/mchesapeake-landcover.py | 170 ++++++++++++++++ pangaea/datasets/geobench/meurosat.py | 3 - pangaea/datasets/geobench/mforestnet.py | 4 +- pangaea/datasets/geobench/mneontree.py | 167 ++++++++++++++++ pangaea/datasets/geobench/mnz-cattle.py | 169 ++++++++++++++++ pangaea/datasets/geobench/mpv4ger-seg.py | 165 ++++++++++++++++ pangaea/datasets/geobench/mpv4ger.py | 4 +- pangaea/datasets/geobench/msa-crop-type.py | 182 +++++++++++++++++ pangaea/datasets/geobench/mso2sat.py | 2 +- 23 files changed, 1438 insertions(+), 20 deletions(-) rename .github/CONTRIBUTING.md => CONTRIBUTING.md (100%) create mode 100644 configs/dataset/mcashew-plantation.yaml create mode 100644 configs/dataset/mchesapeake-landcover.yaml create mode 100644 configs/dataset/mneontree.yaml create mode 100644 configs/dataset/mnz-cattle.yaml create mode 100644 configs/dataset/mpv4ger-seg.yaml create mode 100644 configs/dataset/msa-crop-type.yaml create mode 100644 pangaea/datasets/geobench/mcashew-plantation.py create mode 100644 
pangaea/datasets/geobench/mchesapeake-landcover.py create mode 100644 pangaea/datasets/geobench/mneontree.py create mode 100644 pangaea/datasets/geobench/mnz-cattle.py create mode 100644 pangaea/datasets/geobench/mpv4ger-seg.py create mode 100644 pangaea/datasets/geobench/msa-crop-type.py diff --git a/.github/CONTRIBUTING.md b/CONTRIBUTING.md similarity index 100% rename from .github/CONTRIBUTING.md rename to CONTRIBUTING.md diff --git a/DATASET_GUIDE.md b/DATASET_GUIDE.md index 67fa7748..db58eaeb 100644 --- a/DATASET_GUIDE.md +++ b/DATASET_GUIDE.md @@ -3,6 +3,28 @@ This document provides a detailed overview of the datasets used in this repository. For each dataset, you will find instructions on how to prepare the data, along with command-line examples for running models. *DISCLAIMER*: please consider that we provide the detailed overview for the datasets included in the original repo. Community-contributed datasets may not come with pre-defined command-line examples in this repository. Feel free to adapt the existing examples based on your use case. 
+## 📚 Table of Contents + +- [HLSBurnScars](#hlsburnscars) +- [MADOS](#mados) +- [PASTIS-R](#pastis-r) +- [Sen1Floods11](#sen1floods11) +- [xView2](#xview2) +- [FiveBillionPixels](#fivebillionpixels) +- [DynamicEarthNet](#dynamicearthnet) +- [Crop Type Mapping (South Sudan)](#crop-type-mapping-south-sudan) +- [SpaceNet 7](#spacenet-7) +- [AI4SmallFarms](#ai4smallfarms) +- [BioMassters](#biomassters) + +### 🧪 Community-Contributed Datasets +- [Potsdam](#potsdam) +- [Geo-Bench Datasets](#geo-bench-datasets) + - [Multi-label Classification (e.g., m-BigEarthNet)](#for-multi-label-classification-eg-m-bigearthnet) + - [Single-label Classification (e.g., m-EuroSat, m-Brick-Kiln)](#for-single-label-classification-ie-m-eurosat-m-brick-kiln-m-forestnet-m-pv4ger-m-so2sat) + - [Semantic Segmentation (e.g., m-NZ-Cattle, m-SA-Crop-Type)](#for-semantic-segmentation-ie-m-cashew-plantation-m-chesapeake-landcover-m-neontree-m-nz-cattle-m-pv4ger-seg-and-m-sa-crop-type) + +--- ### HLSBurnScars @@ -222,6 +244,7 @@ This document provides a detailed overview of the datasets used in this reposito ``` In this case, you can specify in the `temp` parameter which frame you want to use. +--- **Note**: The following datasets are **community-contributed** and are not part of the original benchmark repository. 
### Potsdam ``` @@ -234,3 +257,45 @@ This document provides a detailed overview of the datasets used in this reposito criterion=cross_entropy \ task=segmentation ``` +### Geo-Bench Datasets +- For multi-label classification, e.g., m-BigEarthNet + ``` + export GEO_BENCH_DIR=YOUR/PATH/DIR + torchrun --nnodes=1 --nproc_per_node=1 pangaea/run.py \ + --config-name=train \ + dataset=mbigearthnet \ + encoder=dofa \ + decoder=cls_linear \ + preprocessing=cls_resize \ + criterion=binary_cross_entropy \ + task=classification_multi_label \ + finetune=false + ``` + +- For single-label classification, i.e., m-EuroSat, m-Brick-Kiln, m-ForestNet, m-PV4Ger, m-So2Sat + ``` + export GEO_BENCH_DIR=YOUR/PATH/DIR + torchrun --nnodes=1 --nproc_per_node=1 pangaea/run.py \ + --config-name=train \ + dataset=meurosat \ + encoder=dofa \ + decoder=cls_linear \ + preprocessing=cls_resize \ + criterion=cross_entropy \ + task=classification \ + finetune=false + ``` + +- For semantic segmentation, i.e., m-Cashew-Plantation, m-Chesapeake-Landcover, m-NeonTree, m-NZ-Cattle, m-PV4Ger-Seg and m-SA-Crop-Type + ``` + export GEO_BENCH_DIR=YOUR/PATH/DIR + torchrun --nnodes=1 --nproc_per_node=1 pangaea/run.py \ + --config-name=train \ + dataset=mnz-cattle \ + encoder=dofa \ + decoder=seg_upernet \ + preprocessing=seg_default \ + criterion=cross_entropy \ + task=segmentation \ + finetune=false + ``` \ No newline at end of file diff --git a/README.md b/README.md index e48f31a8..bf2ee47b 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,8 @@ # PANGAEA: A Global and Inclusive Benchmark for Geospatial Foundation Models 📢 **News** - - [23/04/2025] we pushed a new version of the code, fixing different bugs (e.g. commands are working for all the datasets now, metric computation with ignore_index is fixed, etc...). In the next month, we will provide: all downloadable datasets and models, downloadable stratified subsamples for all the datasets, classification. Stay tuned! 
+ - [04/06/2025] We integrate [Geo-Bench](https://arxiv.org/abs/2306.03831) Datasets, including six segmentation and six classification tasks. + - [22/04/2025] on EarthDay, PANGAEA was officially adopted to benchmark TerraMind. Read the [news](https://www.linkedin.com/posts/simonetta-cheli-7669879b_earthday-earthobservation-activity-7320439907028467712-LSzl?utm_source=share&utm_medium=member_desktop&rcm=ACoAACdT8q0BDNWYKAdDYGUe_X4fQOzSHO8jgAs) and the [pre-print](https://arxiv.org/abs/2504.11171). We will release the benchmarking code in PANGAEA very soon! - [05/12/2024] the [pre-print](https://arxiv.org/abs/2412.04204) is out! @@ -52,6 +53,7 @@ And the following **datasets**: **Note**: The following datasets are **community-contributed** and are not part of the original benchmark repository. We are grateful for these contributions, which help enrich the benchmark's diversity and applicability. - **Potsdam dataset** [[Link](https://www.isprs.org/education/benchmarks/UrbanSemLab/2d-sem-label-potsdam.aspx)]. Contributed by [@pierreadorni](https://github.com/pierreadorni). +- **Geo-Bench datasets** [[Link](https://github.com/ServiceNow/geo-bench)]. Contributed by [@yurujaja](https://github.com/yurujaja).
The repository supports the following **tasks** using geospatial (foundation) models: - [Single Temporal Semantic Segmentation](#single-temporal-semantic-segmentation) @@ -299,11 +301,11 @@ torchrun --nnodes=1 --nproc_per_node=1 pangaea/run.py \ ### Using Your Own Dataset -Refer to: [Adding a new downstream dataset](.github/CONTRIBUTING.md#adding-a-new-downstream-dataset) +Refer to: [Adding a new downstream dataset](CONTRIBUTING.md#adding-a-new-downstream-dataset) ### Using Your Own Model -Refer to: [Adding a new geospatial foundation model](.github/CONTRIBUTING.md#adding-a-new-geospatial-foundation-model) +Refer to: [Adding a new geospatial foundation model](CONTRIBUTING.md#adding-a-new-geospatial-foundation-model) ## 🏃 Evaluation @@ -316,7 +318,7 @@ torchrun pangaea/run.py --config-name=test ckpt_dir=path_to_ckpt_dir ``` ## ✏️ Contributing -We appreciate all contributions. Please refer to [Contributing Guidelines](.github/CONTRIBUTING.md). +We appreciate all contributions. Please refer to [Contributing Guidelines](CONTRIBUTING.md). 
## ⚠️ TO DO diff --git a/configs/dataset/mbigearthnet.yaml b/configs/dataset/mbigearthnet.yaml index 12ec8a62..8f33bd4e 100644 --- a/configs/dataset/mbigearthnet.yaml +++ b/configs/dataset/mbigearthnet.yaml @@ -29,10 +29,10 @@ distribution: [0,] # data stats data_mean: - optical: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + optical: [378.4027, 482.2730, 706.5345, 720.9285, 1100.6688, 1909.2914, 2191.6985, 2336.8706, 2394.7449, 2368.3127, 1875.2490, 1229.3818] data_std: - optical: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + optical: [462.4629, 519.3317, 552.3597, 680.9677, 690.2879, 982.2125, 1143.4189, 1248.0188, 1223.6399, 1166.8386, 1092.4415, 862.7205] data_min: optical: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] data_max: diff --git a/configs/dataset/mcashew-plantation.yaml b/configs/dataset/mcashew-plantation.yaml new file mode 100644 index 00000000..3a7ca3cc --- /dev/null +++ b/configs/dataset/mcashew-plantation.yaml @@ -0,0 +1,54 @@ +_target_: pangaea.datasets.geobench.mcashew-plantation.mCashewPlant +dataset_name: mCashew-Plant +root_path: ${oc.env:GEO_BENCH_DIR}/segmentation_v1.0/m-cashew-plant +download_url: "recursix/geo-bench-1.0" +auto_download: True + +img_size: 256 +multi_temporal: False +multi_modal: False + +# classes +ignore_index: 255 +num_classes: 7 +classes: + - 'no data' + - 'well-managed plantation' + - 'poorly-managed plantation' + - 'non-plantation' + - 'residential' + - 'background' + - 'uncertain' +distribution: + - 0 + - 0 + - 0 + - 0 + - 0 + - 0 + - 0 + + +bands: + optical: + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B8 + - B8A + - B9 + - B11 + - B12 + +data_mean: + optical: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] +data_std: + optical: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] +data_min: + optical: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] +data_max: + optical: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] \ No newline at end of file diff --git a/configs/dataset/mchesapeake-landcover.yaml 
b/configs/dataset/mchesapeake-landcover.yaml new file mode 100644 index 00000000..27ea7284 --- /dev/null +++ b/configs/dataset/mchesapeake-landcover.yaml @@ -0,0 +1,56 @@ +_target_: pangaea.datasets.geobench.mchesapeake-landcover.mChesapeake +dataset_name: mChesapeake +root_path: ${oc.env:GEO_BENCH_DIR}/segmentation_v1.0/m-chesapeake +download_url: "recursix/geo-bench-1.0" +auto_download: True + +img_size: 256 +multi_temporal: False +multi_modal: False + +# classes +ignore_index: -1 +num_classes: 7 +classes: + - 'water' + - 'tree-canopy-forest' + - 'low-vegetation-field' + - 'barren-land' + - 'impervious-other' + - 'impervious-roads' + - 'no data' +distribution: + - 0 + - 0 + - 0 + - 0 + - 0 + - 0 + - 0 + +# data stats +bands: + optical: + - B2 + - B3 + - B4 + - B8 + +data_mean: + optical: + - 0.0 + - 0.0 + - 0.0 + - 0.0 + +data_std: + optical: + - 0.0 + - 0.0 + - 0.0 + - 0.0 + +data_min: + optical: [0.0000, 0.0, 0.0, 0.0] +data_max: + optical: [0.0000, 0.0, 0.0, 0.0] \ No newline at end of file diff --git a/configs/dataset/meurosat.yaml b/configs/dataset/meurosat.yaml index b5026c55..133c151e 100644 --- a/configs/dataset/meurosat.yaml +++ b/configs/dataset/meurosat.yaml @@ -30,9 +30,9 @@ distribution: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] data_mean: - optical: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + optical: [1359.9467, 1125.5342, 1054.9949, 957.3436, 1219.6685, 2051.6189, 2433.6013, 2360.0952, 751.1636, 12.2881, 1848.8993, 1131.2699, 2665.4370] data_std: - optical: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + optical: [251.3307, 339.6782, 396.7326, 592.8387, 555.0112, 852.5016, 1081.5669, 1115.1144, 404.5709, 4.7965, 978.8317, 745.2798, 1223.8777] data_min: optical: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] data_max: diff --git a/configs/dataset/mneontree.yaml b/configs/dataset/mneontree.yaml new file mode 100644 index 00000000..8b64f516 --- /dev/null +++ b/configs/dataset/mneontree.yaml @@ -0,0 +1,43 @@ +_target_: pangaea.datasets.geobench.mneontree.mNeonTree 
+dataset_name: mNeonTree +root_path: ${oc.env:GEO_BENCH_DIR}/segmentation_v1.0/m-NeonTree +download_url: "recursix/geo-bench-1.0" +auto_download: True + +img_size: 400 +multi_temporal: False +multi_modal: False + +# classes +ignore_index: -1 +num_classes: 2 +classes: + - No tree + - Tree +distribution: + - 0 + - 0 + +# data stats +bands: + optical: + - B4 + - B3 + - B2 + +data_mean: + optical: + - 122.29733333333333 + - 131.00455555555556 + - 108.44846666666666 + +data_std: + optical: + - 54.053087350753124 + - 51.442224204429245 + - 33.09221632578563 + +data_min: + optical: [0.0, 0.0, 0.0] +data_max: + optical: [255.0, 255.0, 255.0] \ No newline at end of file diff --git a/configs/dataset/mnz-cattle.yaml b/configs/dataset/mnz-cattle.yaml new file mode 100644 index 00000000..71562fe7 --- /dev/null +++ b/configs/dataset/mnz-cattle.yaml @@ -0,0 +1,43 @@ +_target_: pangaea.datasets.geobench.mnz-cattle.mNzCattle +dataset_name: mNzCattle +root_path: ${oc.env:GEO_BENCH_DIR}/segmentation_v1.0/m-nz-cattle +download_url: "recursix/geo-bench-1.0" +auto_download: True + +img_size: 500 +multi_temporal: False +multi_modal: False + +# classes +ignore_index: -1 +num_classes: 2 +classes: + - No cattle + - Cattle +distribution: + - 0 + - 0 + +# data stats +bands: + optical: + - B4 + - B3 + - B2 + +data_mean: + optical: + - 126.31354389312978 + - 130.09102671755724 + - 106.51769083969465 + +data_std: + optical: + - 23.38495539231075 + - 19.92435830412983 + - 18.991624007810348 + +data_min: + optical: [0.0, 0.0, 0.0] +data_max: + optical: [255.0, 255.0, 255.0] \ No newline at end of file diff --git a/configs/dataset/mpv4ger-seg.yaml b/configs/dataset/mpv4ger-seg.yaml new file mode 100644 index 00000000..595458b6 --- /dev/null +++ b/configs/dataset/mpv4ger-seg.yaml @@ -0,0 +1,43 @@ +_target_: pangaea.datasets.geobench.mpv4ger-seg.mPv4GerSeg +dataset_name: mPv4GerSeg +root_path: ${oc.env:GEO_BENCH_DIR}/segmentation_v1.0/m-pv4ger-seg +download_url: "recursix/geo-bench-1.0" 
+auto_download: True + +img_size: 320 +multi_temporal: False +multi_modal: False + +# classes +ignore_index: -1 +num_classes: 2 +classes: + - 'no solar pv' + - 'solar pv' +distribution: + - 0 + - 0 + +# data stats +bands: + optical: + - B4 + - B3 + - B2 + +data_mean: + optical: + - 131.102356 + - 137.354091 + - 139.761751 + +data_std: + optical: + - 54.52768048660482 + - 50.86544377633718 + - 48.29800594656056 + +data_min: + optical: [0.0, 0.0, 0.0] +data_max: + optical: [255.0, 255.0, 255.0] \ No newline at end of file diff --git a/configs/dataset/msa-crop-type.yaml b/configs/dataset/msa-crop-type.yaml new file mode 100644 index 00000000..3619f7a7 --- /dev/null +++ b/configs/dataset/msa-crop-type.yaml @@ -0,0 +1,77 @@ +_target_: pangaea.datasets.geobench.msa-crop-type.mSACropType +dataset_name: mSACropType +root_path: ${oc.env:GEO_BENCH_DIR}/segmentation_v1.0/m-SA-crop-type +download_url: "recursix/geo-bench-1.0" +auto_download: True + +img_size: 256 +multi_temporal: False +multi_modal: False + +# classes +ignore_index: -1 +num_classes: 10 +classes: + - 'no data' + - 'Lucerne/Medics' + - 'Planted pastures (perennial)' + - 'Fallow' + - 'Wine grapes' + - 'Weeds' + - 'Small grain grazing' + - 'Wheat' + - 'Canola' + - 'Rooibos' +distribution: + - 0 + - 0 + - 0 + - 0 + - 0 + - 0 + - 0 + - 0 + - 0 + - 0 + +# data stats +bands: + optical: + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B8 + - B8A + - B9 + - B11 + - B12 + +data_mean: + optical: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +data_std: + optical: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +data_min: + optical: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +data_max: + optical: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + +# data_mean: +# optical: +# - 0.0 +# - 0.0 +# - 0.0 + +# data_std: +# optical: +# - 0.0 +# - 0.0 +# - 0.0 + +# data_min: +# optical: [0.0000, 0.0, 0.0] +# data_max: +# optical: [0.0000, 0.0, 0.0] \ No newline at end of 
file diff --git a/pangaea/datasets/geobench/mbigearthnet.py b/pangaea/datasets/geobench/mbigearthnet.py index 755178e7..898a80bb 100644 --- a/pangaea/datasets/geobench/mbigearthnet.py +++ b/pangaea/datasets/geobench/mbigearthnet.py @@ -136,8 +136,8 @@ def __getitem__(self, index): filename = sample.sample_name image = torch.from_numpy(image.transpose(2, 0, 1)).float() - image = image / 4095 - image = np.clip(image, 0, 1) + # image = image / 4095 + # image = np.clip(image, 0, 1) image=image.unsqueeze(1) diff --git a/pangaea/datasets/geobench/mbrickkiln.py b/pangaea/datasets/geobench/mbrickkiln.py index aabf1b53..2132f494 100644 --- a/pangaea/datasets/geobench/mbrickkiln.py +++ b/pangaea/datasets/geobench/mbrickkiln.py @@ -135,8 +135,8 @@ def __getitem__(self, index): filename = sample.sample_name image = torch.from_numpy(image.transpose(2, 0, 1)).float() - image = image / 4095 - image = np.clip(image, 0, 1) + # image = image / 4095 + # image = np.clip(image, 0, 1) image=image.unsqueeze(1) diff --git a/pangaea/datasets/geobench/mcashew-plantation.py b/pangaea/datasets/geobench/mcashew-plantation.py new file mode 100644 index 00000000..039fd809 --- /dev/null +++ b/pangaea/datasets/geobench/mcashew-plantation.py @@ -0,0 +1,185 @@ + +import numpy as np +import torch +import os +from pathlib import Path +from pangaea.datasets.base import RawGeoFMDataset +from pangaea.datasets.utils import decompress_zip_with_progress +from huggingface_hub import HfApi, hf_hub_download +import subprocess +import sys +try: + import geobench +except ImportError: + print("geobench not found. 
Installing via pip...") + subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-deps", "geobench"]) + import geobench + + +class mCashewPlant(RawGeoFMDataset): + def __init__( + self, + split: str, + dataset_name: str, + multi_modal: bool, + multi_temporal: int, + root_path: str, + classes: list, + num_classes: int, + ignore_index: int, + img_size: int, + bands: dict[str, list[str]], + distribution: list[int], + data_mean: dict[str, list[str]], + data_std: dict[str, list[str]], + data_min: dict[str, list[str]], + data_max: dict[str, list[str]], + download_url: str, + auto_download: bool, + ): + """Initialize the m-Cashew-Plantation dataset. + Link: https://github.com/ServiceNow/geo-bench + + Args: + split (str): split of the dataset (train, val, test). + dataset_name (str): dataset name. + multi_modal (bool): if the dataset is multi-modal. + multi_temporal (int): number of temporal frames. + root_path (str): root path of the dataset. + classes (list): classes of the dataset. + num_classes (int): number of classes. + ignore_index (int): index to ignore for metrics and loss. + img_size (int): size of the image. + bands (dict[str, list[str]]): bands of the dataset. + distribution (list[int]): class distribution. + data_mean (dict[str, list[str]]): mean for each band for each modality. + Dictionary with keys as the modality and values as the list of means. + e.g. {"s2": [b1_mean, ..., bn_mean], "s1": [b1_mean, ..., bn_mean]} + data_std (dict[str, list[str]]): std for each band for each modality. + Dictionary with keys as the modality and values as the list of stds. + e.g. {"s2": [b1_std, ..., bn_std], "s1": [b1_std, ..., bn_std]} + data_min (dict[str, list[str]]): min for each band for each modality. + Dictionary with keys as the modality and values as the list of mins. + e.g. {"s2": [b1_min, ..., bn_min], "s1": [b1_min, ..., bn_min]} + data_max (dict[str, list[str]]): max for each band for each modality.
+ Dictionary with keys as the modality and values as the list of maxs. + e.g. {"s2": [b1_max, ..., bn_max], "s1": [b1_max, ..., bn_max]} + download_url (str): url to download the dataset. + auto_download (bool): whether to download the dataset automatically. + """ + super(mCashewPlant, self).__init__( + split=split, + dataset_name=dataset_name, + multi_modal=multi_modal, + multi_temporal=multi_temporal, + root_path=root_path, + classes=classes, + num_classes=num_classes, + ignore_index=ignore_index, + img_size=img_size, + bands=bands, + distribution=distribution, + data_mean=data_mean, + data_std=data_std, + data_min=data_min, + data_max=data_max, + download_url=download_url, + auto_download=auto_download, + ) + + self.data_mean = data_mean + self.data_std = data_std + self.data_min = data_min + self.data_max = data_max + self.classes = classes + self.img_size = img_size + self.distribution = distribution + self.num_classes = num_classes + self.ignore_index = ignore_index + self.download_url = download_url + self.auto_download = auto_download + + self.root_path = root_path + self.split = split + + split_mapping = {'train': 'train', 'val': 'valid', 'test': 'test'} + + task = geobench.load_task_specs(self.root_path) + self.dataset = task.get_dataset(split=split_mapping[self.split]) + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, index): + sample = self.dataset[index] + # for band in sample.bands: + # print(f" {band.band_info.name}: {band.data.shape}") + all_band_names = ( + "01", + "02", + "03", + "04", + "05", + "06", + "07", + "08", + "08A", + "09", + "11", + "12", + ) + + rgb_bands = ("04", "03", "02") + + BAND_SETS = {"all": all_band_names, "rgb": rgb_bands} + image, band_names = sample.pack_to_3d(band_names=BAND_SETS["all"]) + label = sample.label.data + filename = sample.sample_name + + image = torch.from_numpy(image.transpose(2, 0, 1)).float() + image = image / 4095 + image = torch.clamp(image, 0, 1) + + image=image.unsqueeze(1) + + +
return { + "image": { + "optical": image, + }, + "target": torch.tensor(label, dtype=torch.int64), + "filename": filename, + "metadata": {}, + } + + def download(self, silent=False): + local_directory = Path(os.getenv("GEO_BENCH_DIR")) + dataset_repo = self.download_url + + local_directory.mkdir(parents=True, exist_ok=True) + + api = HfApi() + dataset_files = api.list_repo_files(repo_id=dataset_repo, repo_type="dataset") + + for file in dataset_files: + + if file not in ['segmentation_v1.0/m-cashew-plant.zip', 'segmentation_v1.0/normalizer.json']: + continue + + local_file_path = local_directory / file + + local_file_path.parent.mkdir(parents=True, exist_ok=True) + + print(f"Downloading {file}...") + hf_hub_download( + repo_id=dataset_repo, + filename=file, + cache_dir=local_directory, + local_dir=local_directory, + repo_type="dataset", + ) + if file.endswith(".zip"): + print(f"Decompressing ...") + decompress_zip_with_progress(local_directory / file) + + \ No newline at end of file diff --git a/pangaea/datasets/geobench/mchesapeake-landcover.py b/pangaea/datasets/geobench/mchesapeake-landcover.py new file mode 100644 index 00000000..fdeeb488 --- /dev/null +++ b/pangaea/datasets/geobench/mchesapeake-landcover.py @@ -0,0 +1,170 @@ + +import numpy as np +import torch +import os +from pathlib import Path +from pangaea.datasets.base import RawGeoFMDataset +from pangaea.datasets.utils import decompress_zip_with_progress +from huggingface_hub import HfApi, hf_hub_download +import subprocess +import sys +try: + import geobench +except ImportError: + print("geobench not found. 
Installing via pip...") + subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-deps", "geobench"]) + import geobench + + +class mChesapeake(RawGeoFMDataset): + def __init__( + self, + split: str, + dataset_name: str, + multi_modal: bool, + multi_temporal: int, + root_path: str, + classes: list, + num_classes: int, + ignore_index: int, + img_size: int, + bands: dict[str, list[str]], + distribution: list[int], + data_mean: dict[str, list[str]], + data_std: dict[str, list[str]], + data_min: dict[str, list[str]], + data_max: dict[str, list[str]], + download_url: str, + auto_download: bool, + ): + """Initialize the m-Chesapeake-Landcover dataset. + Link: https://github.com/ServiceNow/geo-bench + + Args: + split (str): split of the dataset (train, val, test). + dataset_name (str): dataset name. + multi_modal (bool): if the dataset is multi-modal. + multi_temporal (int): number of temporal frames. + root_path (str): root path of the dataset. + classes (list): classes of the dataset. + num_classes (int): number of classes. + ignore_index (int): index to ignore for metrics and loss. + img_size (int): size of the image. + bands (dict[str, list[str]]): bands of the dataset. + distribution (list[int]): class distribution. + data_mean (dict[str, list[str]]): mean for each band for each modality. + Dictionary with keys as the modality and values as the list of means. + e.g. {"s2": [b1_mean, ..., bn_mean], "s1": [b1_mean, ..., bn_mean]} + data_std (dict[str, list[str]]): std for each band for each modality. + Dictionary with keys as the modality and values as the list of stds. + e.g. {"s2": [b1_std, ..., bn_std], "s1": [b1_std, ..., bn_std]} + data_min (dict[str, list[str]]): min for each band for each modality. + Dictionary with keys as the modality and values as the list of mins. + e.g. {"s2": [b1_min, ..., bn_min], "s1": [b1_min, ..., bn_min]} + data_max (dict[str, list[str]]): max for each band for each modality.
+ Dictionary with keys as the modality and values as the list of maxs. + e.g. {"s2": [b1_max, ..., bn_max], "s1": [b1_max, ..., bn_max]} + download_url (str): url to download the dataset. + auto_download (bool): whether to download the dataset automatically. + """ + super(mChesapeake, self).__init__( + split=split, + dataset_name=dataset_name, + multi_modal=multi_modal, + multi_temporal=multi_temporal, + root_path=root_path, + classes=classes, + num_classes=num_classes, + ignore_index=ignore_index, + img_size=img_size, + bands=bands, + distribution=distribution, + data_mean=data_mean, + data_std=data_std, + data_min=data_min, + data_max=data_max, + download_url=download_url, + auto_download=auto_download, + ) + + self.data_mean = data_mean + self.data_std = data_std + self.data_min = data_min + self.data_max = data_max + self.classes = classes + self.img_size = img_size + self.distribution = distribution + self.num_classes = num_classes + self.ignore_index = ignore_index + self.download_url = download_url + self.auto_download = auto_download + + self.root_path = root_path + self.split = split + + split_mapping = {'train': 'train', 'val': 'valid', 'test': 'test'} + + task = geobench.load_task_specs(self.root_path) + self.dataset = task.get_dataset(split=split_mapping[self.split]) + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, index): + sample = self.dataset[index] + # for band in sample.bands: + # print(f"{band.band_info.name}: {band.data.shape}") + all_band_names = ("Blue", "Green", "Red", "NearInfrared") + rgb_bands = ("Red", "Green", "Blue") + BAND_SETS = {"all": all_band_names, "rgb": rgb_bands} + + image, band_names = sample.pack_to_3d(band_names=BAND_SETS["all"]) + label = sample.label.data + filename = sample.sample_name + + image = torch.from_numpy(image.transpose(2, 0, 1)).float() + image = image / 255 + image = torch.clamp(image, 0, 1) + + image=image.unsqueeze(1) + + + return { + "image": { + "optical": image, + }, + "target":
torch.tensor(label, dtype=torch.int64), + "filename": filename, + "metadata": {}, + } + + def download(self, silent=False): + local_directory = Path(os.getenv("GEO_BENCH_DIR")) + dataset_repo = self.download_url + + local_directory.mkdir(parents=True, exist_ok=True) + + api = HfApi() + dataset_files = api.list_repo_files(repo_id=dataset_repo, repo_type="dataset") + + for file in dataset_files: + + if file not in ['segmentation_v1.0/m-chesapeake.zip', 'segmentation_v1.0/normalizer.json']: + continue + + local_file_path = local_directory / file + + local_file_path.parent.mkdir(parents=True, exist_ok=True) + + print(f"Downloading {file}...") + hf_hub_download( + repo_id=dataset_repo, + filename=file, + cache_dir=local_directory, + local_dir=local_directory, + repo_type="dataset", + ) + if file.endswith(".zip"): + print(f"Decompressing ...") + decompress_zip_with_progress(local_directory / file) + \ No newline at end of file diff --git a/pangaea/datasets/geobench/meurosat.py b/pangaea/datasets/geobench/meurosat.py index c6d1cc23..ff7fd58e 100644 --- a/pangaea/datasets/geobench/meurosat.py +++ b/pangaea/datasets/geobench/meurosat.py @@ -134,12 +134,9 @@ def __getitem__(self, index): filename = sample.sample_name image = torch.from_numpy(image.transpose(2, 0, 1)).float() - image = image / 4095 - image = np.clip(image, 0, 1) image=image.unsqueeze(1) - return { "image": { "optical": image, diff --git a/pangaea/datasets/geobench/mforestnet.py b/pangaea/datasets/geobench/mforestnet.py index 777a6d49..11a0c0e9 100644 --- a/pangaea/datasets/geobench/mforestnet.py +++ b/pangaea/datasets/geobench/mforestnet.py @@ -129,8 +129,8 @@ def __getitem__(self, index): filename = sample.sample_name image = torch.from_numpy(image.transpose(2, 0, 1)).float() - image = image / 255 - image = np.clip(image, 0, 1) + # image = image / 255 + # image = np.clip(image, 0, 1) image=image.unsqueeze(1) diff --git a/pangaea/datasets/geobench/mneontree.py b/pangaea/datasets/geobench/mneontree.py new file 
mode 100644 index 00000000..541a3e8b --- /dev/null +++ b/pangaea/datasets/geobench/mneontree.py @@ -0,0 +1,167 @@ + +import numpy as np +import torch +import os +from pangaea.datasets.base import RawGeoFMDataset +import subprocess +import sys +from pathlib import Path +from huggingface_hub import HfApi, hf_hub_download +from pangaea.datasets.utils import decompress_zip_with_progress +try: + import geobench +except ImportError: + print("geobench not found. Installing via pip...") + subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-deps", "geobench"]) + import geobench + + +class mNeonTree(RawGeoFMDataset): + def __init__( + self, + split: str, + dataset_name: str, + multi_modal: bool, + multi_temporal: int, + root_path: str, + classes: list, + num_classes: int, + ignore_index: int, + img_size: int, + bands: dict[str, list[str]], + distribution: list[int], + data_mean: dict[str, list[str]], + data_std: dict[str, list[str]], + data_min: dict[str, list[str]], + data_max: dict[str, list[str]], + download_url: str, + auto_download: bool, + ): + """Initialize the m-Neon-Tree dataset. + Link: https://github.com/ServiceNow/geo-bench + + Args: + split (str): split of the dataset (train, val, test). + dataset_name (str): dataset name. + multi_modal (bool): if the dataset is multi-modal. + multi_temporal (int): number of temporal frames. + root_path (str): root path of the dataset. + classes (list): classes of the dataset. + num_classes (int): number of classes. + ignore_index (int): index to ignore for metrics and loss. + img_size (int): size of the image. + bands (dict[str, list[str]]): bands of the dataset. + distribution (list[int]): class distribution. + data_mean (dict[str, list[str]]): mean for each band for each modality. + Dictionary with keys as the modality and values as the list of means. + e.g. {"s2": [b1_mean, ..., bn_mean], "s1": [b1_mean, ..., bn_mean]} + data_std (dict[str, list[str]]): std for each band for each modality.
+ Dictionary with keys as the modality and values as the list of stds. + e.g. {"s2": [b1_std, ..., bn_std], "s1": [b1_std, ..., bn_std]} + data_min (dict[str, list[str]]): min for each band for each modality. + Dictionary with keys as the modality and values as the list of mins. + e.g. {"s2": [b1_min, ..., bn_min], "s1": [b1_min, ..., bn_min]} + data_max (dict[str, list[str]]): max for each band for each modality. + Dictionary with keys as the modality and values as the list of maxs. + e.g. {"s2": [b1_max, ..., bn_max], "s1": [b1_max, ..., bn_max]} + download_url (str): url to download the dataset. + auto_download (bool): whether to download the dataset automatically. + """ + super(mNeonTree, self).__init__( + split=split, + dataset_name=dataset_name, + multi_modal=multi_modal, + multi_temporal=multi_temporal, + root_path=root_path, + classes=classes, + num_classes=num_classes, + ignore_index=ignore_index, + img_size=img_size, + bands=bands, + distribution=distribution, + data_mean=data_mean, + data_std=data_std, + data_min=data_min, + data_max=data_max, + download_url=download_url, + auto_download=auto_download, + ) + + self.data_mean = data_mean + self.data_std = data_std + self.data_min = data_min + self.data_max = data_max + self.classes = classes + self.img_size = img_size + self.distribution = distribution + self.num_classes = num_classes + self.ignore_index = ignore_index + self.download_url = download_url + self.auto_download = auto_download + + self.root_path = root_path + self.split = split + + split_mapping = {'train': 'train', 'val': 'valid', 'test': 'test'} + + task = geobench.load_task_specs(self.root_path) + self.dataset = task.get_dataset(split=split_mapping[self.split]) + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, index): + sample = self.dataset[index] + # for band in sample.bands: + # print(f" {band.band_info.name}: {band.data.shape}") + all_band_names = ("BLUE", "CANOPY_HEIGHT_MODEL", "GREEN", "NEON", "RED") + 
rgb_bands = ("Red", "Green", "Blue") + BAND_SETS = {"all": all_band_names, "rgb": rgb_bands} + + image, band_names = sample.pack_to_3d(band_names=BAND_SETS["rgb"]) + label = sample.label.data + filename = sample.sample_name + + image = torch.from_numpy(image.transpose(2, 0, 1)).float() + image=image.unsqueeze(1) + + return { + "image": { + "optical": image, + }, + "target": torch.tensor(label, dtype=torch.int64), + "filename": filename, + "metadata": {}, + } + + def download(self, silent=False): + local_directory = Path(os.getenv("GEO_BENCH_DIR")) + dataset_repo = self.download_url + + local_directory.mkdir(parents=True, exist_ok=True) + + api = HfApi() + dataset_files = api.list_repo_files(repo_id=dataset_repo, repo_type="dataset") + + for file in dataset_files: + + if file not in ['segmentation_v1.0/m-NeonTree.zip', 'segmentation_v1.0/normalizer.json']: + continue + + local_file_path = local_directory / file + + local_file_path.parent.mkdir(parents=True, exist_ok=True) + + print(f"Downloading {file}...") + hf_hub_download( + repo_id=dataset_repo, + filename=file, + cache_dir=local_directory, + local_dir=local_directory, + repo_type="dataset", + ) + if file.endswith(".zip"): + print(f"Decompressing ...") + decompress_zip_with_progress(local_directory / file) + + \ No newline at end of file diff --git a/pangaea/datasets/geobench/mnz-cattle.py b/pangaea/datasets/geobench/mnz-cattle.py new file mode 100644 index 00000000..73db145d --- /dev/null +++ b/pangaea/datasets/geobench/mnz-cattle.py @@ -0,0 +1,169 @@ + +import numpy as np +import torch +import os +from pangaea.datasets.base import RawGeoFMDataset +import subprocess +import sys +from pathlib import Path +from huggingface_hub import HfApi, hf_hub_download +from pangaea.datasets.utils import decompress_zip_with_progress +try: + import geobench +except ImportError: + print("geobench not found.
Installing via pip...") + subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-deps", "geobench"]) + import geobench + + +class mNzCattle(RawGeoFMDataset): + def __init__( + self, + split: str, + dataset_name: str, + multi_modal: bool, + multi_temporal: int, + root_path: str, + classes: list, + num_classes: int, + ignore_index: int, + img_size: int, + bands: dict[str, list[str]], + distribution: list[int], + data_mean: dict[str, list[str]], + data_std: dict[str, list[str]], + data_min: dict[str, list[str]], + data_max: dict[str, list[str]], + download_url: str, + auto_download: bool, + ): + """Initialize the m-NZ-Cattle dataset. + Link: https://phys-techsciences.datastations.nl/dataset.xhtml?persistentId=doi:10.17026/dans-xy6-ngg6 + + Args: + split (str): split of the dataset (train, val, test). + dataset_name (str): dataset name. + multi_modal (bool): if the dataset is multi-modal. + multi_temporal (int): number of temporal frames. + root_path (str): root path of the dataset. + classes (list): classes of the dataset. + num_classes (int): number of classes. + ignore_index (int): index to ignore for metrics and loss. + img_size (int): size of the image. + bands (dict[str, list[str]]): bands of the dataset. + distribution (list[int]): class distribution. + data_mean (dict[str, list[str]]): mean for each band for each modality. + Dictionary with keys as the modality and values as the list of means. + e.g. {"s2": [b1_mean, ..., bn_mean], "s1": [b1_mean, ..., bn_mean]} + data_std (dict[str, list[str]]): str for each band for each modality. + Dictionary with keys as the modality and values as the list of stds. + e.g. {"s2": [b1_std, ..., bn_std], "s1": [b1_std, ..., bn_std]} + data_min (dict[str, list[str]]): min for each band for each modality. + Dictionary with keys as the modality and values as the list of mins. + e.g. {"s2": [b1_min, ..., bn_min], "s1": [b1_min, ..., bn_min]} + data_max (dict[str, list[str]]): max for each band for each modality. 
+ Dictionary with keys as the modality and values as the list of maxs. + e.g. {"s2": [b1_max, ..., bn_max], "s1": [b1_max, ..., bn_max]} + download_url (str): url to download the dataset. + auto_download (bool): whether to download the dataset automatically. + """ + super(mNzCattle, self).__init__( + split=split, + dataset_name=dataset_name, + multi_modal=multi_modal, + multi_temporal=multi_temporal, + root_path=root_path, + classes=classes, + num_classes=num_classes, + ignore_index=ignore_index, + img_size=img_size, + bands=bands, + distribution=distribution, + data_mean=data_mean, + data_std=data_std, + data_min=data_min, + data_max=data_max, + download_url=download_url, + auto_download=auto_download, + ) + + self.data_mean = data_mean + self.data_std = data_std + self.data_min = data_min + self.data_max = data_max + self.classes = classes + self.img_size = img_size + self.distribution = distribution + self.num_classes = num_classes + self.ignore_index = ignore_index + self.download_url = download_url + self.auto_download = auto_download + + self.root_path = root_path + self.split = split + + split_mapping = {'train': 'train', 'val': 'valid', 'test': 'test'} + + self.dataset_path = os.path.join(self.root_path, 'segmentation_v1.0', 'm-nz-cattle') + task = geobench.load_task_specs(self.dataset_path) + self.dataset = task.get_dataset(split=split_mapping[self.split]) + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, index): + sample = self.dataset[index] + + all_band_names = ("Blue", "Green", "Red") + rgb_bands = ("Red", "Green", "Blue") + BAND_SETS = {"all": all_band_names, "rgb": rgb_bands} + + image, band_names = sample.pack_to_3d(band_names=BAND_SETS["rgb"]) + label = sample.label.data + filename = sample.sample_name + + image = torch.from_numpy(image.transpose(2, 0, 1)).float() + + image=image.unsqueeze(1) + + + return { + "image": { + "optical": image, + }, + "target": torch.tensor(label, dtype=torch.int64), + "filename": filename, + 
"metadata": {}, + } + + def download(self, silent=False): + local_directory = Path(os.getenv("GEO_BENCH_DIR")) + dataset_repo = self.download_url + + local_directory.mkdir(parents=True, exist_ok=True) + + api = HfApi() + dataset_files = api.list_repo_files(repo_id=dataset_repo, repo_type="dataset") + + for file in dataset_files: + + if file not in ['segmentation_v1.0/m-nz-cattle.zip', 'segmentation_v1.0/normalizer.json']: + continue + + local_file_path = local_directory / file + + local_file_path.parent.mkdir(parents=True, exist_ok=True) + + print(f"Downloading {file}...") + hf_hub_download( + repo_id=dataset_repo, + filename=file, + cache_dir=local_directory, + local_dir=local_directory, + repo_type="dataset", + ) + if file.endswith(".zip"): + print(f"Decompressing ...") + decompress_zip_with_progress(local_directory / file) + + \ No newline at end of file diff --git a/pangaea/datasets/geobench/mpv4ger-seg.py b/pangaea/datasets/geobench/mpv4ger-seg.py new file mode 100644 index 00000000..dcfe8813 --- /dev/null +++ b/pangaea/datasets/geobench/mpv4ger-seg.py @@ -0,0 +1,165 @@ + +import numpy as np +import torch +import os +from pangaea.datasets.base import RawGeoFMDataset +import subprocess +import sys +from pathlib import Path +from huggingface_hub import HfApi, hf_hub_download +from pangaea.datasets.utils import decompress_zip_with_progress +try: + import geobench +except ImportError: + print("geobench not found. 
Installing via pip...") + subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-deps", "geobench"]) + import geobench + + +class mPv4GerSeg(RawGeoFMDataset): + def __init__( + self, + split: str, + dataset_name: str, + multi_modal: bool, + multi_temporal: int, + root_path: str, + classes: list, + num_classes: int, + ignore_index: int, + img_size: int, + bands: dict[str, list[str]], + distribution: list[int], + data_mean: dict[str, list[str]], + data_std: dict[str, list[str]], + data_min: dict[str, list[str]], + data_max: dict[str, list[str]], + download_url: str, + auto_download: bool, + ): + """Initialize the m-PV4Ger-Seg dataset. + Link: https://github.com/ServiceNow/geo-bench + + Args: + split (str): split of the dataset (train, val, test). + dataset_name (str): dataset name. + multi_modal (bool): if the dataset is multi-modal. + multi_temporal (int): number of temporal frames. + root_path (str): root path of the dataset. + classes (list): classes of the dataset. + num_classes (int): number of classes. + ignore_index (int): index to ignore for metrics and loss. + img_size (int): size of the image. + bands (dict[str, list[str]]): bands of the dataset. + distribution (list[int]): class distribution. + data_mean (dict[str, list[str]]): mean for each band for each modality. + Dictionary with keys as the modality and values as the list of means. + e.g. {"s2": [b1_mean, ..., bn_mean], "s1": [b1_mean, ..., bn_mean]} + data_std (dict[str, list[str]]): str for each band for each modality. + Dictionary with keys as the modality and values as the list of stds. + e.g. {"s2": [b1_std, ..., bn_std], "s1": [b1_std, ..., bn_std]} + data_min (dict[str, list[str]]): min for each band for each modality. + Dictionary with keys as the modality and values as the list of mins. + e.g. {"s2": [b1_min, ..., bn_min], "s1": [b1_min, ..., bn_min]} + data_max (dict[str, list[str]]): max for each band for each modality. 
+ Dictionary with keys as the modality and values as the list of maxs. + e.g. {"s2": [b1_max, ..., bn_max], "s1": [b1_max, ..., bn_max]} + download_url (str): url to download the dataset. + auto_download (bool): whether to download the dataset automatically. + """ + super(mPv4GerSeg, self).__init__( + split=split, + dataset_name=dataset_name, + multi_modal=multi_modal, + multi_temporal=multi_temporal, + root_path=root_path, + classes=classes, + num_classes=num_classes, + ignore_index=ignore_index, + img_size=img_size, + bands=bands, + distribution=distribution, + data_mean=data_mean, + data_std=data_std, + data_min=data_min, + data_max=data_max, + download_url=download_url, + auto_download=auto_download, + ) + + self.data_mean = data_mean + self.data_std = data_std + self.data_min = data_min + self.data_max = data_max + self.classes = classes + self.img_size = img_size + self.distribution = distribution + self.num_classes = num_classes + self.ignore_index = ignore_index + self.download_url = download_url + self.auto_download = auto_download + + self.root_path = root_path + self.split = split + + split_mapping = {'train': 'train', 'val': 'valid', 'test': 'test'} + + task = geobench.load_task_specs(self.root_path) + self.dataset = task.get_dataset(split=split_mapping[self.split]) + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, index): + sample = self.dataset[index] + all_band_names = ("Blue", "Green", "Red") + rgb_bands = ("Red", "Green", "Blue") + BAND_SETS = {"all": all_band_names, "rgb": rgb_bands} + image, band_names = sample.pack_to_3d(band_names=BAND_SETS["rgb"]) + label = sample.label.data + filename = sample.sample_name + + image = torch.from_numpy(image.transpose(2, 0, 1)).float() + + image=image.unsqueeze(1) + + + return { + "image": { + "optical": image, + }, + "target": torch.tensor(label, dtype=torch.int64), + "filename": filename, + "metadata": {}, + } + + def download(self, silent=False): + local_directory = 
Path(os.getenv("GEO_BENCH_DIR")) + dataset_repo = self.download_url + + local_directory.mkdir(parents=True, exist_ok=True) + + api = HfApi() + dataset_files = api.list_repo_files(repo_id=dataset_repo, repo_type="dataset") + + for file in dataset_files: + + if file not in ['segmentation_v1.0/m-pv4ger-seg.zip', 'segmentation_v1.0/normalizer.json']: + continue + + local_file_path = local_directory / file + + local_file_path.parent.mkdir(parents=True, exist_ok=True) + + print(f"Downloading {file}...") + hf_hub_download( + repo_id=dataset_repo, + filename=file, + cache_dir=local_directory, + local_dir=local_directory, + repo_type="dataset", + ) + if file.endswith(".zip"): + print(f"Decompressing ...") + decompress_zip_with_progress(local_directory / file) + diff --git a/pangaea/datasets/geobench/mpv4ger.py b/pangaea/datasets/geobench/mpv4ger.py index 91749e22..4b24550e 100644 --- a/pangaea/datasets/geobench/mpv4ger.py +++ b/pangaea/datasets/geobench/mpv4ger.py @@ -116,8 +116,8 @@ def __getitem__(self, index): filename = sample.sample_name image = torch.from_numpy(image.transpose(2, 0, 1)).float() - image = image / 255 - image = np.clip(image, 0, 1) + # image = image / 255 + # image = np.clip(image, 0, 1) image=image.unsqueeze(1) diff --git a/pangaea/datasets/geobench/msa-crop-type.py b/pangaea/datasets/geobench/msa-crop-type.py new file mode 100644 index 00000000..d368e1c5 --- /dev/null +++ b/pangaea/datasets/geobench/msa-crop-type.py @@ -0,0 +1,182 @@ + +import numpy as np +import torch +import os +from pangaea.datasets.base import RawGeoFMDataset +from pathlib import Path +from huggingface_hub import HfApi, hf_hub_download +from pangaea.datasets.utils import decompress_zip_with_progress +import subprocess +import sys +try: + import geobench +except ImportError: + print("geobench not found. 
Installing via pip...") + subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-deps", "geobench"]) + import geobench + + +class mSACropType(RawGeoFMDataset): + def __init__( + self, + split: str, + dataset_name: str, + multi_modal: bool, + multi_temporal: int, + root_path: str, + classes: list, + num_classes: int, + ignore_index: int, + img_size: int, + bands: dict[str, list[str]], + distribution: list[int], + data_mean: dict[str, list[str]], + data_std: dict[str, list[str]], + data_min: dict[str, list[str]], + data_max: dict[str, list[str]], + download_url: str, + auto_download: bool, + ): + """Initialize the m-SA-Crop-Type dataset. + Link: https://github.com/ServiceNow/geo-bench + + Args: + split (str): split of the dataset (train, val, test). + dataset_name (str): dataset name. + multi_modal (bool): if the dataset is multi-modal. + multi_temporal (int): number of temporal frames. + root_path (str): root path of the dataset. + classes (list): classes of the dataset. + num_classes (int): number of classes. + ignore_index (int): index to ignore for metrics and loss. + img_size (int): size of the image. + bands (dict[str, list[str]]): bands of the dataset. + distribution (list[int]): class distribution. + data_mean (dict[str, list[str]]): mean for each band for each modality. + Dictionary with keys as the modality and values as the list of means. + e.g. {"s2": [b1_mean, ..., bn_mean], "s1": [b1_mean, ..., bn_mean]} + data_std (dict[str, list[str]]): str for each band for each modality. + Dictionary with keys as the modality and values as the list of stds. + e.g. {"s2": [b1_std, ..., bn_std], "s1": [b1_std, ..., bn_std]} + data_min (dict[str, list[str]]): min for each band for each modality. + Dictionary with keys as the modality and values as the list of mins. + e.g. {"s2": [b1_min, ..., bn_min], "s1": [b1_min, ..., bn_min]} + data_max (dict[str, list[str]]): max for each band for each modality. 
+ Dictionary with keys as the modality and values as the list of maxs. + e.g. {"s2": [b1_max, ..., bn_max], "s1": [b1_max, ..., bn_max]} + download_url (str): url to download the dataset. + auto_download (bool): whether to download the dataset automatically. + """ + super(mSACropType, self).__init__( + split=split, + dataset_name=dataset_name, + multi_modal=multi_modal, + multi_temporal=multi_temporal, + root_path=root_path, + classes=classes, + num_classes=num_classes, + ignore_index=ignore_index, + img_size=img_size, + bands=bands, + distribution=distribution, + data_mean=data_mean, + data_std=data_std, + data_min=data_min, + data_max=data_max, + download_url=download_url, + auto_download=auto_download, + ) + + self.data_mean = data_mean + self.data_std = data_std + self.data_min = data_min + self.data_max = data_max + self.classes = classes + self.img_size = img_size + self.distribution = distribution + self.num_classes = num_classes + self.ignore_index = ignore_index + self.download_url = download_url + self.auto_download = auto_download + + self.root_path = root_path + self.split = split + + split_mapping = {'train': 'train', 'val': 'valid', 'test': 'test'} + + task = geobench.load_task_specs(self.root_path) + self.dataset = task.get_dataset(split=split_mapping[self.split]) + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, index): + sample = self.dataset[index] + # for band in sample.bands: + # print(f" {band.band_info.name}: {band.data.shape}") + all_band_names = ( + "01", + "02", + "03", + "04", + "05", + "06", + "07", + "08", + "08A", + "09", + "11", + "12", + ) + rgb_bands = ("04", "03", "02") + BAND_SETS = {"all": all_band_names, "rgb": rgb_bands} + image, band_names = sample.pack_to_3d(band_names=BAND_SETS["all"]) + label = sample.label.data + filename = sample.sample_name + + image = torch.from_numpy(image.transpose(2, 0, 1)).float() + image = image / 255 + image = np.clip(image, 0, 1) + + image=image.unsqueeze(1) + + return { 
+ "image": { + "optical": image, + }, + "target": torch.tensor(label, dtype=torch.int64), + "filename": filename, + "metadata": {}, + } + + def download(self, silent=False): + local_directory = Path(os.getenv("GEO_BENCH_DIR")) + dataset_repo = self.download_url + + local_directory.mkdir(parents=True, exist_ok=True) + + api = HfApi() + dataset_files = api.list_repo_files(repo_id=dataset_repo, repo_type="dataset") + + for file in dataset_files: + + if file not in ['segmentation_v1.0/m-SA-crop-type.zip', 'segmentation_v1.0/normalizer.json']: + continue + + local_file_path = local_directory / file + + local_file_path.parent.mkdir(parents=True, exist_ok=True) + + print(f"Downloading {file}...") + hf_hub_download( + repo_id=dataset_repo, + filename=file, + cache_dir=local_directory, + local_dir=local_directory, + repo_type="dataset", + ) + if file.endswith(".zip"): + print(f"Decompressing ...") + decompress_zip_with_progress(local_directory / file) + + \ No newline at end of file diff --git a/pangaea/datasets/geobench/mso2sat.py b/pangaea/datasets/geobench/mso2sat.py index f4b5d90e..47346b77 100644 --- a/pangaea/datasets/geobench/mso2sat.py +++ b/pangaea/datasets/geobench/mso2sat.py @@ -135,7 +135,7 @@ def __getitem__(self, index): filename = sample.sample_name image = torch.from_numpy(image.transpose(2, 0, 1)).float() - image = np.clip(image, 0, 1) + # image = np.clip(image, 0, 1) image=image.unsqueeze(1) From 5ea9d15d2388e0fa48993296bd9751356d9c9347 Mon Sep 17 00:00:00 2001 From: yurujaja Date: Thu, 5 Jun 2025 18:05:14 +0200 Subject: [PATCH 5/6] update dataset stats --- configs/dataset/mbigearthnet.yaml | 5 +-- configs/dataset/mbrickkiln.yaml | 8 +++-- configs/dataset/mcashew-plantation.yaml | 5 +-- configs/dataset/mchesapeake-landcover.yaml | 13 ++----- configs/dataset/meurosat.yaml | 7 ++-- configs/dataset/mforestnet.yaml | 34 +++++++++++++------ configs/dataset/mpv4ger.yaml | 8 ++--- configs/dataset/msa-crop-type.yaml | 5 +-- configs/dataset/mso2sat.yaml | 5 ++- 
pangaea/datasets/geobench/mbigearthnet.py | 2 -- pangaea/datasets/geobench/mbrickkiln.py | 2 -- .../datasets/geobench/mcashew-plantation.py | 2 -- .../geobench/mchesapeake-landcover.py | 2 -- pangaea/datasets/geobench/mforestnet.py | 6 ++-- pangaea/datasets/geobench/mpv4ger-seg.py | 2 -- pangaea/datasets/geobench/mpv4ger.py | 8 +++-- pangaea/datasets/geobench/msa-crop-type.py | 2 -- pangaea/datasets/geobench/mso2sat.py | 1 - 18 files changed, 58 insertions(+), 59 deletions(-) diff --git a/configs/dataset/mbigearthnet.yaml b/configs/dataset/mbigearthnet.yaml index 8f33bd4e..7eadfeec 100644 --- a/configs/dataset/mbigearthnet.yaml +++ b/configs/dataset/mbigearthnet.yaml @@ -29,10 +29,11 @@ distribution: [0,] # data stats data_mean: - optical: [378.4027, 482.2730, 706.5345, 720.9285, 1100.6688, 1909.2914, 2191.6985, 2336.8706, 2394.7449, 2368.3127, 1875.2490, 1229.3818] + optical: [378.4027, 482.2730, 706.5345, 720.9285, 1100.6688, 1909.2914, 2191.6985, 2336.8706, 2394.7449, 2368.3127, 1875.2487, 1229.3818] data_std: - optical: [462.4629, 519.3317, 552.3597, 680.9677, 690.2879, 982.2125, 1143.4189, 1248.0188, 1223.6399, 1166.8386, 1092.4415, 862.7205] + optical: [157.5666, 255.0429, 303.1750, 391.2943, 380.7916, 551.6558, 638.8196, 744.2009, 675.4041, 561.0154, 563.4095, 479.1786] + data_min: optical: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] data_max: diff --git a/configs/dataset/mbrickkiln.yaml b/configs/dataset/mbrickkiln.yaml index 457dfa07..7e547668 100644 --- a/configs/dataset/mbrickkiln.yaml +++ b/configs/dataset/mbrickkiln.yaml @@ -10,7 +10,7 @@ multi_temporal: False multi_modal: False ignore_index: -100 -classes: ['', ''] +classes: ['not brick kiln', 'brick kiln'] distribution: [0, 0] bands: optical: @@ -27,10 +27,12 @@ bands: - B10 - B11 - B12 + data_mean: - optical: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + optical: [574.7587880700896, 674.3473615470523, 886.3656479311578, 815.0945462528913, 1128.8088426870465, 1934.450471876027, 
2045.7652282437202, 2012.744587807115, 1608.6255233989034, 1129.8171906000355, 83.27188605598549, 90.54924599052214, 68.98768652434848] data_std: - optical: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + optical: [193.60631504991184, 238.75447480113132, 276.9631260242207, 361.15060137326634, 364.5888078793488, 724.2707123576525, 819.653063972575, 794.3652427593881, 800.8538290702304, 704.0219637458916, 36.355745901131705, 28.004671947623894, 24.268892726362033] + data_min: optical: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] data_max: diff --git a/configs/dataset/mcashew-plantation.yaml b/configs/dataset/mcashew-plantation.yaml index 3a7ca3cc..ff8d7e59 100644 --- a/configs/dataset/mcashew-plantation.yaml +++ b/configs/dataset/mcashew-plantation.yaml @@ -45,9 +45,10 @@ bands: - B12 data_mean: - optical: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + optical: [520.1185302734375, 634.7583618164062, 892.461181640625, 880.7075805664062, 1380.6409912109375, 2233.432373046875, 2549.379638671875, 2643.248046875, 2643.531982421875, 2852.87451171875, 2463.933349609375, 1600.9207763671875] data_std: - optical: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + optical: [204.2023468017578, 227.25344848632812, 222.32545471191406, 350.47235107421875, 280.6436767578125, 373.7521057128906, 449.9236145019531, 414.6498107910156, 415.1019592285156, 413.8980407714844, 494.97430419921875, 514.4229736328125] + data_min: optical: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] data_max: diff --git a/configs/dataset/mchesapeake-landcover.yaml b/configs/dataset/mchesapeake-landcover.yaml index 27ea7284..b1313fea 100644 --- a/configs/dataset/mchesapeake-landcover.yaml +++ b/configs/dataset/mchesapeake-landcover.yaml @@ -37,18 +37,11 @@ bands: - B8 data_mean: - optical: - - 0.0 - - 0.0 - - 0.0 - - 0.0 + optical: [0.4807923436164856, 0.5200885534286499, 0.4570387601852417,0.569856584072113] data_std: - optical: - - 0.0 - - 0.0 - - 0.0 - - 0.0 + optical: [0.17441707849502563, 0.1976749747991562, 0.21191735565662384, 
0.2831788957118988] + data_min: optical: [0.0000, 0.0, 0.0, 0.0] diff --git a/configs/dataset/meurosat.yaml b/configs/dataset/meurosat.yaml index 133c151e..434ca3f6 100644 --- a/configs/dataset/meurosat.yaml +++ b/configs/dataset/meurosat.yaml @@ -25,14 +25,15 @@ bands: - B11 - B12 -classes: ['', '', '', '', '', '', '', '', '', ''] +classes: ['AnnualCrop', 'Forest', 'HerbaceousVegetation', 'Highway', 'Industrial', 'Pasture', 'PermanentCrop', 'Residential', 'River', 'SeaLake'] distribution: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] data_mean: - optical: [1359.9467, 1125.5342, 1054.9949, 957.3436, 1219.6685, 2051.6189, 2433.6013, 2360.0952, 751.1636, 12.2881, 1848.8993, 1131.2699, 2665.4370] + optical: [1355.5426, 1113.8855, 1035.7394, 928.2619, 1188.2629, 2032.7325, 2416.5286, 2342.5396, 748.9036, 12.0419, 1810.1284, 1101.3801, 2644.5996] data_std: - optical: [251.3307, 339.6782, 396.7326, 592.8387, 555.0112, 852.5016, 1081.5669, 1115.1144, 404.5709, 4.7965, 978.8317, 745.2798, 1223.8777] + optical: [68.9288, 160.0012, 194.6687, 286.8012, 236.6991, 372.3853, 478.1329, 556.7527, 102.5583, 1.2167, 392.9388, 313.7339, 526.7788] + data_min: optical: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] data_max: diff --git a/configs/dataset/mforestnet.yaml b/configs/dataset/mforestnet.yaml index 7d40bc43..2685545d 100644 --- a/configs/dataset/mforestnet.yaml +++ b/configs/dataset/mforestnet.yaml @@ -10,24 +10,38 @@ multi_temporal: False multi_modal: False ignore_index: -100 -classes: ['', '', '', '', '', '', '', '', '', '', '', ''] +classes: [ + "Oil palm plantation", + "Timber plantation", + "Other large-scale plantations", + "Grassland shrubland", + "Small-scale agriculture", + "Small-scale mixed plantation", + "Small-scale oil palm plantation", + "Mining", + "Fish pond", + "Logging", + "Secondary forest", + "Other", +] + distribution: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] # data stats bands: - optical: - - B4 - - B3 + optical: # 6 bands from Landsat, but band names here are corresponding to 
Sentinel-2 - B2 - - B5 - - B6 - - B7 + - B3 + - B4 + - B8 + - B11 + - B12 data_mean: - optical: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] - + optical: [72.852258, 83.677155, 77.58181, 123.987442, 91.536942, 74.719202] data_std: - optical: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + optical: [15.837172547567825, 14.788812599596188, 16.100543441881086, 16.35234883118129, 13.7882739778638, 12.69131413539181] + data_min: optical: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] data_max: diff --git a/configs/dataset/mpv4ger.yaml b/configs/dataset/mpv4ger.yaml index b0617c9a..6689ca1f 100644 --- a/configs/dataset/mpv4ger.yaml +++ b/configs/dataset/mpv4ger.yaml @@ -10,7 +10,7 @@ multi_temporal: False multi_modal: False ignore_index: -100 -classes: ['', ''] +classes: ["no solar pv", "solar pv"] distribution: [0, 0] # data stats @@ -21,10 +21,10 @@ bands: - B2 data_mean: - optical: [0.0, 0.0, 0.0] + optical: [113.385309, 119.65935, 116.628328] data_std: - optical: [0.0, 0.0, 0.0] + optical: [54.19692448815262, 48.282311849967364, 44.668890717415586] data_min: optical: [0.0, 0.0, 0.0] data_max: - optical: [0.0, 0.0, 0.0] \ No newline at end of file + optical: [255.0, 255.0, 255.0] \ No newline at end of file diff --git a/configs/dataset/msa-crop-type.yaml b/configs/dataset/msa-crop-type.yaml index 3619f7a7..1c92782d 100644 --- a/configs/dataset/msa-crop-type.yaml +++ b/configs/dataset/msa-crop-type.yaml @@ -51,9 +51,10 @@ bands: - B12 data_mean: - optical: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + optical: [12.739611, 16.526744, 26.636417, 36.696639, 46.388679, 58.281453, 63.575819, 68.1836, 69.142591, 69.904566, 83.626811, 65.767679] data_std: - optical: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + optical: [7.492811526301659, 9.329547939662671, 12.674537246073758, 19.421922023931593, 19.487411106531287, 19.959174612412983, 21.53805760692545, 23.05077775347288, 22.329695761624677, 21.877766438821954, 28.14418826277069, 27.2346215312965] + data_min: optical: [0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] data_max: diff --git a/configs/dataset/mso2sat.yaml b/configs/dataset/mso2sat.yaml index 7582d6c0..6c299646 100644 --- a/configs/dataset/mso2sat.yaml +++ b/configs/dataset/mso2sat.yaml @@ -28,10 +28,9 @@ classes: ['', '', '', '', '', '', '', '', '', '', '', '','', '', '', '', ''] distribution: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] data_mean: - optical: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] - + optical: [0.12951652705669403, 0.11734361201524734, 0.11374464631080627, 0.12693354487419128, 0.16917912662029266, 0.19080990552902222, 0.18381330370903015, 0.20517952740192413, 0.1762811541557312, 0.1286638230085373] data_std: - optical: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + optical: [0.040680479258298874, 0.05125178396701813, 0.07254913449287415, 0.06872648745775223, 0.07402216643095016, 0.08412779122591019, 0.08534552156925201, 0.09248979389667511, 0.10270608961582184, 0.09284552931785583] data_min: optical: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] data_max: diff --git a/pangaea/datasets/geobench/mbigearthnet.py b/pangaea/datasets/geobench/mbigearthnet.py index 898a80bb..32b7db26 100644 --- a/pangaea/datasets/geobench/mbigearthnet.py +++ b/pangaea/datasets/geobench/mbigearthnet.py @@ -136,8 +136,6 @@ def __getitem__(self, index): filename = sample.sample_name image = torch.from_numpy(image.transpose(2, 0, 1)).float() - # image = image / 4095 - # image = np.clip(image, 0, 1) image=image.unsqueeze(1) diff --git a/pangaea/datasets/geobench/mbrickkiln.py b/pangaea/datasets/geobench/mbrickkiln.py index 2132f494..c77e4d40 100644 --- a/pangaea/datasets/geobench/mbrickkiln.py +++ b/pangaea/datasets/geobench/mbrickkiln.py @@ -135,8 +135,6 @@ def __getitem__(self, index): filename = sample.sample_name image = torch.from_numpy(image.transpose(2, 0, 1)).float() - # image = image / 4095 - # image = np.clip(image, 0, 1) image=image.unsqueeze(1) diff --git 
a/pangaea/datasets/geobench/mcashew-plantation.py b/pangaea/datasets/geobench/mcashew-plantation.py index 039fd809..4e2926eb 100644 --- a/pangaea/datasets/geobench/mcashew-plantation.py +++ b/pangaea/datasets/geobench/mcashew-plantation.py @@ -137,8 +137,6 @@ def __getitem__(self, index): filename = sample.sample_name image = torch.from_numpy(image.transpose(2, 0, 1)).float() - image = image / 4095 - image = np.clip(image, 0, 1) image=image.unsqueeze(1) diff --git a/pangaea/datasets/geobench/mchesapeake-landcover.py b/pangaea/datasets/geobench/mchesapeake-landcover.py index fdeeb488..bc2b3e57 100644 --- a/pangaea/datasets/geobench/mchesapeake-landcover.py +++ b/pangaea/datasets/geobench/mchesapeake-landcover.py @@ -123,8 +123,6 @@ def __getitem__(self, index): filename = sample.sample_name image = torch.from_numpy(image.transpose(2, 0, 1)).float() - image = image / 255 - image = np.clip(image, 0, 1) image=image.unsqueeze(1) diff --git a/pangaea/datasets/geobench/mforestnet.py b/pangaea/datasets/geobench/mforestnet.py index 11a0c0e9..8ed5ed49 100644 --- a/pangaea/datasets/geobench/mforestnet.py +++ b/pangaea/datasets/geobench/mforestnet.py @@ -114,9 +114,9 @@ def __getitem__(self, index): # for band in sample.bands: # print(f" {band.band_info.name}: {band.data.shape}") all_band_names = ( - "04", - "03", "02", + "03", + "04", "05", "06", "07", @@ -129,8 +129,6 @@ def __getitem__(self, index): filename = sample.sample_name image = torch.from_numpy(image.transpose(2, 0, 1)).float() - # image = image / 255 - # image = np.clip(image, 0, 1) image=image.unsqueeze(1) diff --git a/pangaea/datasets/geobench/mpv4ger-seg.py b/pangaea/datasets/geobench/mpv4ger-seg.py index dcfe8813..94c7bc73 100644 --- a/pangaea/datasets/geobench/mpv4ger-seg.py +++ b/pangaea/datasets/geobench/mpv4ger-seg.py @@ -1,5 +1,3 @@ - -import numpy as np import torch import os from pangaea.datasets.base import RawGeoFMDataset diff --git a/pangaea/datasets/geobench/mpv4ger.py 
b/pangaea/datasets/geobench/mpv4ger.py index 4b24550e..b5e988c7 100644 --- a/pangaea/datasets/geobench/mpv4ger.py +++ b/pangaea/datasets/geobench/mpv4ger.py @@ -111,13 +111,15 @@ def __len__(self): def __getitem__(self, index): sample = self.dataset[index] - image, band_names = sample.pack_to_3d(band_names=["Red", "Green", "Blue"]) + all_band_names = ("Blue", "Green", "Red") + rgb_bands = ("Red", "Green", "Blue") + BAND_SETS = {"all": all_band_names, "rgb": rgb_bands} + image, band_names = sample.pack_to_3d(band_names=BAND_SETS["rgb"]) label = sample.label filename = sample.sample_name image = torch.from_numpy(image.transpose(2, 0, 1)).float() - # image = image / 255 - # image = np.clip(image, 0, 1) + image=image.unsqueeze(1) diff --git a/pangaea/datasets/geobench/msa-crop-type.py b/pangaea/datasets/geobench/msa-crop-type.py index d368e1c5..ae8f309b 100644 --- a/pangaea/datasets/geobench/msa-crop-type.py +++ b/pangaea/datasets/geobench/msa-crop-type.py @@ -135,8 +135,6 @@ def __getitem__(self, index): filename = sample.sample_name image = torch.from_numpy(image.transpose(2, 0, 1)).float() - image = image / 255 - image = np.clip(image, 0, 1) image=image.unsqueeze(1) diff --git a/pangaea/datasets/geobench/mso2sat.py b/pangaea/datasets/geobench/mso2sat.py index 47346b77..b6963213 100644 --- a/pangaea/datasets/geobench/mso2sat.py +++ b/pangaea/datasets/geobench/mso2sat.py @@ -135,7 +135,6 @@ def __getitem__(self, index): filename = sample.sample_name image = torch.from_numpy(image.transpose(2, 0, 1)).float() - # image = np.clip(image, 0, 1) image=image.unsqueeze(1) From 8f91c83e30439567d93a5e3436905e87715a9a95 Mon Sep 17 00:00:00 2001 From: yurujaja Date: Fri, 6 Jun 2025 19:37:19 +0200 Subject: [PATCH 6/6] add knn evalution --- README.md | 3 + configs/criterion/none.yaml | 2 + configs/dataset/fivebillionpixels.yaml | 2 +- configs/dataset/sen1floods11.yaml | 2 +- configs/decoder/cls_knn.yaml | 10 + configs/preprocessing/cls_resize.yaml | 3 + configs/task/knn_probe.yaml 
| 21 ++ ...cation.yaml => linear_classification.yaml} | 4 +- pangaea/datasets/geobench/meurosat.py | 55 ++--- pangaea/decoders/knnclassifier.py | 199 ++++++++++++++++++ pangaea/decoders/linearclassifier.py | 14 +- pangaea/encoders/dofa_encoder.py | 2 +- pangaea/encoders/gfmswin_encoder.py | 2 +- pangaea/encoders/prithvi_encoder.py | 2 +- pangaea/encoders/remoteclip_encoder.py | 2 +- pangaea/encoders/satlasnet_encoder.py | 2 +- pangaea/encoders/scalemae_encoder.py | 6 +- pangaea/encoders/spectralgpt_encoder.py | 2 +- pangaea/encoders/ssl4eo_data2vec_encoder.py | 2 +- pangaea/encoders/ssl4eo_dino_encoder.py | 2 +- pangaea/encoders/ssl4eo_mae_encoder.py | 2 +- pangaea/encoders/ssl4eo_moco_encoder.py | 2 +- pangaea/engine/data_preprocessor.py | 6 +- pangaea/engine/evaluator.py | 127 +++++++++-- pangaea/engine/trainer.py | 63 +++++- pangaea/run.py | 7 +- pangaea/utils/utils.py | 31 ++- 27 files changed, 496 insertions(+), 79 deletions(-) create mode 100644 configs/criterion/none.yaml create mode 100644 configs/decoder/cls_knn.yaml create mode 100644 configs/task/knn_probe.yaml rename configs/task/{classification.yaml => linear_classification.yaml} (79%) create mode 100644 pangaea/decoders/knnclassifier.py diff --git a/README.md b/README.md index 2dd81310..0924d8ae 100644 --- a/README.md +++ b/README.md @@ -382,3 +382,6 @@ If you find this work useful, please cite: url={https://arxiv.org/abs/2412.04204}, } ``` +## Acknowledgements + +The computations/data handling were enabled by resources provided by the National Academic Infrastructure for Supercomputing in Sweden (NAISS), partially funded by the Swedish Research Council through grant agreement no. 2022-06725.
\ No newline at end of file diff --git a/configs/criterion/none.yaml b/configs/criterion/none.yaml new file mode 100644 index 00000000..eb8bd931 --- /dev/null +++ b/configs/criterion/none.yaml @@ -0,0 +1,2 @@ +# returns its input unchanged – produces a 0-parameter nn.Identity module +_target_: torch.nn.Identity diff --git a/configs/dataset/fivebillionpixels.yaml b/configs/dataset/fivebillionpixels.yaml index 70e9e75c..f65730b4 100644 --- a/configs/dataset/fivebillionpixels.yaml +++ b/configs/dataset/fivebillionpixels.yaml @@ -1,6 +1,6 @@ _target_: pangaea.datasets.fivebillionpixels.FiveBillionPixels dataset_name: FiveBillionPixels -root_path: /geomatics/gpuserver-1/vmarsocci/FiveBillionPixels/cropped +root_path: /mimer/NOBACKUP/groups/naiss2024-22-857/datasets/Five-Billion-Pixels/cropped/new download_url: False auto_download: False use_cmyk: False diff --git a/configs/dataset/sen1floods11.yaml b/configs/dataset/sen1floods11.yaml index 1f7a4033..3a4cb6a1 100644 --- a/configs/dataset/sen1floods11.yaml +++ b/configs/dataset/sen1floods11.yaml @@ -1,6 +1,6 @@ _target_: pangaea.datasets.sen1floods11.Sen1Floods11 dataset_name: Sen1Floods11 -root_path: ./data/sen1floods11_v1.1 +root_path: ./data/sen1floods11 download_url: gcs_bucket: sen1floods11 auto_download: True diff --git a/configs/decoder/cls_knn.yaml b/configs/decoder/cls_knn.yaml new file mode 100644 index 00000000..7be7ab56 --- /dev/null +++ b/configs/decoder/cls_knn.yaml @@ -0,0 +1,10 @@ +_target_: pangaea.decoders.knnclassifier.KNNClassifier + +encoder: null +num_classes: ${dataset.num_classes} +finetune: false +knn_k: 200 +knn_t: 0.1 +topk: [1, 2] +normalize: true +feature_dtype: float16 \ No newline at end of file diff --git a/configs/preprocessing/cls_resize.yaml b/configs/preprocessing/cls_resize.yaml index 9a00e946..625a5e70 100644 --- a/configs/preprocessing/cls_resize.yaml +++ b/configs/preprocessing/cls_resize.yaml @@ -3,6 +3,7 @@ train: preprocessor_cfg: - _target_: 
pangaea.engine.data_preprocessor.ResizeToEncoder - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd - _target_: pangaea.engine.data_preprocessor.BandPadding val: @@ -10,6 +11,7 @@ val: preprocessor_cfg: - _target_: pangaea.engine.data_preprocessor.ResizeToEncoder - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd - _target_: pangaea.engine.data_preprocessor.BandPadding test: @@ -17,4 +19,5 @@ test: preprocessor_cfg: - _target_: pangaea.engine.data_preprocessor.ResizeToEncoder - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd - _target_: pangaea.engine.data_preprocessor.BandPadding \ No newline at end of file diff --git a/configs/task/knn_probe.yaml b/configs/task/knn_probe.yaml new file mode 100644 index 00000000..6f9afc47 --- /dev/null +++ b/configs/task/knn_probe.yaml @@ -0,0 +1,21 @@ +trainer: + _target_: pangaea.engine.trainer.KNNTrainer + # params overwritten in run + model: null + train_loader: null + evaluator: null + exp_dir: null + device: null + n_epochs: 1 + precision: fp32 + use_wandb: ${use_wandb} + +evaluator: + _target_: pangaea.engine.evaluator.KNNClassificationEvaluator + # params overwritten in run + val_loader: null + exp_dir: null + device: null + use_wandb: ${use_wandb} + inference_mode: null + sliding_inference_batch: null \ No newline at end of file diff --git a/configs/task/classification.yaml b/configs/task/linear_classification.yaml similarity index 79% rename from configs/task/classification.yaml rename to configs/task/linear_classification.yaml index 16ae0459..6c924a8e 100644 --- a/configs/task/classification.yaml +++ b/configs/task/linear_classification.yaml @@ -1,5 +1,5 @@ trainer: - _target_: pangaea.engine.trainer.ClassificationTrainer + _target_: pangaea.engine.trainer.LinearClassificationTrainer # params overwritten in run 
model: null train_loader: null @@ -20,7 +20,7 @@ trainer: use_wandb: ${use_wandb} evaluator: - _target_: pangaea.engine.evaluator.ClassificationEvaluator + _target_: pangaea.engine.evaluator.LinearClassificationEvaluator # params overwritten in run val_loader: null exp_dir: null diff --git a/pangaea/datasets/geobench/meurosat.py b/pangaea/datasets/geobench/meurosat.py index ff7fd58e..532b19dd 100644 --- a/pangaea/datasets/geobench/meurosat.py +++ b/pangaea/datasets/geobench/meurosat.py @@ -15,7 +15,7 @@ import geobench -class mEuroSat(RawGeoFMDataset): +class mEuroSat(torch.utils.data.Dataset): def __init__( self, split: str, @@ -66,41 +66,30 @@ def __init__( download_url (str): url to download the dataset. auto_download (bool): whether to download the dataset automatically. """ - super(mEuroSat, self).__init__( - split=split, - dataset_name=dataset_name, - multi_modal=multi_modal, - multi_temporal=multi_temporal, - root_path=root_path, - classes=classes, - num_classes=num_classes, - ignore_index=ignore_index, - img_size=img_size, - bands=bands, - distribution=distribution, - data_mean=data_mean, - data_std=data_std, - data_min=data_min, - data_max=data_max, - download_url=download_url, - auto_download=auto_download, - ) - + super(mEuroSat, self).__init__() + self.split = split + self.dataset_name = dataset_name + self.multi_modal = multi_modal + self.multi_temporal = multi_temporal + self.root_path = root_path + self.classes = classes + self.num_classes = num_classes + self.ignore_index = ignore_index + self.img_size = img_size + self.bands = bands + self.distribution = distribution self.data_mean = data_mean self.data_std = data_std self.data_min = data_min self.data_max = data_max - self.classes = classes - self.img_size = img_size - self.distribution = distribution - self.num_classes = num_classes - self.ignore_index = ignore_index self.download_url = download_url self.auto_download = auto_download - self.root_path = root_path + if not 
os.path.exists(self.root_path): + self.download(self) + + - self.split = split split_mapping = {'train': 'train', 'val': 'valid', 'test': 'test'} task = geobench.load_task_specs(self.root_path) @@ -135,15 +124,13 @@ def __getitem__(self, index): image = torch.from_numpy(image.transpose(2, 0, 1)).float() - image=image.unsqueeze(1) + # image=image.unsqueeze(1) return { - "image": { - "optical": image, - }, + "image": image, "target": torch.tensor(label, dtype=torch.int64), - "metadata": { - "filename": filename}, + # "metadata": { + # "filename": filename}, } def download(self, silent=False): diff --git a/pangaea/decoders/knnclassifier.py b/pangaea/decoders/knnclassifier.py new file mode 100644 index 00000000..7e9c0758 --- /dev/null +++ b/pangaea/decoders/knnclassifier.py @@ -0,0 +1,199 @@ +import torch +from typing import Tuple +from torch import Tensor +from tqdm import tqdm +import torch.nn.functional as F + +from pangaea.decoders.base import Decoder +from pangaea.encoders.base import Encoder + + +# Obtained from https://github.com/lightly-ai/lightly/blob/master/lightly/utils/benchmarking/knn.py +def knn_predict( + features_q: Tensor, # (B, D) + features_bank: Tensor, # (N, D) + labels_bank: Tensor, # (N,) + num_classes: int, + knn_k: int = 200, + knn_t: float = 0.1, + ) -> Tensor: + """Run kNN predictions on features based on a feature bank .Non-parametric prediction (InstDisc / MoCo style). + + This method is commonly used to monitor the performance of self-supervised + learning methods. The default parameters are the ones + used in https://arxiv.org/pdf/1805.01978v1.pdf. + + Args: + feature: + Tensor of shape (B, D) for which you want predictions, where B is the + batch size and D is the feature dimension. + feature_bank: + Tensor of shape (N, D) representing a database of features used for kNN, + where N is the number of stored feature vectors. 
+ feature_labels: + Tensor of shape (N,) containing labels for the corresponding + feature vectors in the feature_bank. + num_classes: + Number of classes (e.g., `10` for CIFAR-10). + knn_k: + Number of k nearest neighbors used for kNN. + knn_t: + Temperature parameter to reweight similarities for kNN. + + Returns: + Tensor of shape (B, num_classes) with the predicted class indices sorted + by probability in descending order for each sample. The first index + corresponds to the most probable class. To get the top-1 prediction, + you can access `pred_labels[:, 0]`. + """ + # compute cos similarity between each feature vector and feature bank ---> (B, N) + sim_matrix = torch.mm(features_q, features_bank.T) + + sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1) # (B, K) + + sim_labels = torch.gather( + labels_bank.expand(features_q.size(0), -1), dim=-1, index=sim_indices + ) # (B, K) + + # we do a reweighting of the similarities + sim_weight = (sim_weight / knn_t).exp() + + # counts for each class + one_hot_label = torch.zeros( + features_q.size(0) * knn_k, num_classes, device=sim_labels.device + ) # (B*K, C) + + one_hot_label = one_hot_label.scatter( + dim=-1, index=sim_labels.view(-1, 1), value=1.0 + )# (B*K, C) + + pred_scores = torch.sum( + one_hot_label.view(features_q.size(0), -1, num_classes) + * sim_weight.unsqueeze(dim=-1), + dim=1, + ) # (B, C) + pred_labels = pred_scores.argsort(dim=-1, descending=True) # (B, C) + return pred_labels + + +class KNNClassifier(Decoder): + """Non-parametric decoder – holds only the frozen encoder + feature bank.""" + def __init__( + self, + encoder: Encoder, + num_classes: int, + knn_k: int, + knn_t: float, + topk: Tuple[int, ...] = (1, 5), + finetune: bool = False, + normalize: bool = True, + feature_dtype: torch.dtype | str = torch.float16, + ): + """ KNN classifier decoder. + Args: + encoder (Encoder): Model used for feature extraction. Must define a forward(images) method + that returns a feature tensor. 
+ num_classes (int): number of classes in the dataset. + knn_k (int): number of neighbours used for KNN search. + knn_t (float): temperature parameter to reweight similarities. + topk (int): Tuple of integers defining the top-k accuracy metrics to compute. + finetune (bool): whether to finetune the encoder. + """ + + super().__init__( + encoder=encoder, + num_classes=num_classes, + finetune=finetune, + ) + + self.model_name = "knn_probe" + self.encoder = encoder + self.knn_k = knn_k + self.knn_t = knn_t + self.topk = topk + self.normalize = normalize + if isinstance(feature_dtype, str): + try: + feature_dtype = getattr(torch, feature_dtype) + except AttributeError: + raise ValueError( + f"Unknown dtype string '{feature_dtype}'. " + "Use a torch.dtype or e.g. 'float16', 'bfloat16', 'float32'." + ) + self.feature_dtype = feature_dtype + self._bank: Tensor | None = None # (N, D) + self._bank_labels: Tensor | None = None + + assert self.finetune == False, "KNN classifier does not support finetuning" + for param in self.encoder.parameters(): + param.requires_grad = False + + # Now model.parameters() contains at least one trainable tensor, so DDP is happy. 
+ self.register_parameter("_dummy", torch.nn.Parameter(torch.empty(1))) # A dummy parameter to keep .parameters() non-empty + + + def _extract_features(self, img: dict[str, Tensor]) -> Tensor: + if self.encoder.multi_temporal: + if not self.finetune: + with torch.no_grad(): + feat = self.encoder(img) + else: + feat = self.encoder(img) + # multi_temporal models can return either (B C' T>1 H' W') + # or (B C' H' W'), we need (B C' H' W') + if self.encoder.multi_temporal_output: + feat = [f.squeeze(-3) for f in feat] + else: + # remove the temporal dim + # [B C T=1 H W] -> [B C H W] + if not self.finetune: + with torch.no_grad(): + feat = self.encoder({k: v[:, :, 0, :, :] for k, v in img.items()}) + else: + feat = self.encoder({k: v[:, :, 0, :, :] for k, v in img.items()}) + + # Resize multi-scale outputs → same (H, W) and concat along channels + if isinstance(feat, (list, tuple)): + target_h = max(f.shape[-2] for f in feat) + target_w = max(f.shape[-1] for f in feat) + feat = torch.cat( + [ + f + if f.shape[-2:] == (target_h, target_w) + else F.interpolate(f, size=(target_h, target_w), + mode="bilinear", align_corners=False) + for f in feat + ], + dim=1, + ) # (B, ΣC, H, W) + + # Global-avg-pool → (B, D) + feat = F.adaptive_avg_pool2d(feat, 1).flatten(1) + if self.normalize: + feat = F.normalize(feat, dim=1) + return feat.to(self.feature_dtype) # (B, D) + + @torch.no_grad() + def build_feature_bank(self, train_loader, device): + + feats, labels = [], [] + for batch in tqdm(train_loader, + desc=f"Building kNN bank)", leave=False): + image, target = batch["image"], batch["target"] + image = {k: v.to(device) for k, v in image.items()} + target = target.to(device) + feats.append(self._extract_features(image).cpu()) + labels.append(target.cpu()) + + self._bank = torch.cat(feats).to(device) # (N, D) + self._bank_labels = torch.cat(labels).to(device) # (N,) + + @torch.no_grad() + def classify(self, img: dict[str, Tensor]) -> Tensor: + if self._bank is None: + raise 
RuntimeError("Feature bank empty – call build_feature_bank() first") + + q = self._extract_features(img) + return knn_predict(q, self._bank, self._bank_labels, + num_classes=self.num_classes, + knn_k=self.knn_k, knn_t=self.knn_t) # (B, C) sorted diff --git a/pangaea/decoders/linearclassifier.py b/pangaea/decoders/linearclassifier.py index db9369eb..8a1cf57b 100644 --- a/pangaea/decoders/linearclassifier.py +++ b/pangaea/decoders/linearclassifier.py @@ -3,7 +3,6 @@ import torch.nn.functional as F from pangaea.decoders.base import Decoder -from pangaea.decoders.ltae import LTAE2d, LTAEChannelAdaptor from pangaea.encoders.base import Encoder @@ -17,6 +16,17 @@ def __init__( feature_multiplier: int = 1, in_channels: list[int] | None = None, ): + """ Linear decoder for classification tasks. + + Args: + encoder (Encoder): Model used for feature extraction. Must define a forward(images) method + that returns a feature tensor. + num_classes (int): number of classes in the dataset. + finetune (bool): whether to finetune the encoder. + feature_multiplier (int, optional): feature multiplier. Defaults to 1. + in_channels (list[int], optional): input channels. Defaults to None. + """ + super().__init__( encoder=encoder, num_classes=num_classes, @@ -72,7 +82,7 @@ def forward( torch.Tensor: output tensor of shape (B, num_classes, H', W') with (H' W') coressponding to the output_shape. 
""" - # img[modality] of shape [B C T=1 H W] + # img[modality] of shape [B C T>1 H W] if self.encoder.multi_temporal: if not self.finetune: with torch.no_grad(): diff --git a/pangaea/encoders/dofa_encoder.py b/pangaea/encoders/dofa_encoder.py index e636a32f..4f1f5470 100644 --- a/pangaea/encoders/dofa_encoder.py +++ b/pangaea/encoders/dofa_encoder.py @@ -282,7 +282,7 @@ def forward(self, image): return output def load_encoder_weights(self, logger: Logger) -> None: - pretrained_model = torch.load(self.encoder_weights, map_location="cpu") + pretrained_model = torch.load(self.encoder_weights, map_location="cpu", weights_only=False) k = pretrained_model.keys() pretrained_encoder = {} incompatible_shape = {} diff --git a/pangaea/encoders/gfmswin_encoder.py b/pangaea/encoders/gfmswin_encoder.py index 49b640a2..c07a9094 100644 --- a/pangaea/encoders/gfmswin_encoder.py +++ b/pangaea/encoders/gfmswin_encoder.py @@ -739,7 +739,7 @@ def no_weight_decay_keywords(self): return {"relative_position_bias_table"} def load_encoder_weights(self, logger: Logger) -> None: - pretrained_model = torch.load(self.encoder_weights, map_location="cpu") + pretrained_model = torch.load(self.encoder_weights, map_location="cpu", weights_only=False) pretrained_model = adapt_gfm_pretrained(self, pretrained_model) k = pretrained_model.keys() diff --git a/pangaea/encoders/prithvi_encoder.py b/pangaea/encoders/prithvi_encoder.py index 7a380638..4236295f 100644 --- a/pangaea/encoders/prithvi_encoder.py +++ b/pangaea/encoders/prithvi_encoder.py @@ -115,7 +115,7 @@ def __init__( self.initialize_weights() def load_encoder_weights(self, logger: Logger) -> None: - pretrained_model = torch.load(self.encoder_weights, map_location="cpu") + pretrained_model = torch.load(self.encoder_weights, map_location="cpu", weights_only=False) k = pretrained_model.keys() pretrained_encoder = {} incompatible_shape = {} diff --git a/pangaea/encoders/remoteclip_encoder.py b/pangaea/encoders/remoteclip_encoder.py index 
a904d1a0..60b4e114 100644 --- a/pangaea/encoders/remoteclip_encoder.py +++ b/pangaea/encoders/remoteclip_encoder.py @@ -439,7 +439,7 @@ def freeze(self): param.requires_grad = False def load_encoder_weights(self, logger: Logger) -> None: - pretrained_model = torch.load(self.encoder_weights, map_location="cpu") + pretrained_model = torch.load(self.encoder_weights, map_location="cpu", weights_only=False) visual_only_model = {} for k, v in pretrained_model.items(): if k.startswith("visual."): diff --git a/pangaea/encoders/satlasnet_encoder.py b/pangaea/encoders/satlasnet_encoder.py index dc09668c..308e894f 100644 --- a/pangaea/encoders/satlasnet_encoder.py +++ b/pangaea/encoders/satlasnet_encoder.py @@ -500,7 +500,7 @@ def load_encoder_weights(self, logger: Logger) -> None: else: raise Exception(f"Failed to download weights from {self.weights_url}") - pretrained_model = torch.load(weights_file, map_location=torch.device("cpu")) + pretrained_model = torch.load(weights_file, map_location=torch.device("cpu"), weights_only=False) # If using a model for multi-image, need the Aggretation to wrap underlying backbone model. 
prefix, prefix_allowed_count = None, None diff --git a/pangaea/encoders/scalemae_encoder.py b/pangaea/encoders/scalemae_encoder.py index 3dbe3b44..fad3d82c 100644 --- a/pangaea/encoders/scalemae_encoder.py +++ b/pangaea/encoders/scalemae_encoder.py @@ -136,7 +136,7 @@ def _init_weights(self, m): nn.init.constant_(m.weight, 1.0) def load_encoder_weights(self, logger: Logger) -> None: - pretrained_model = torch.load(self.encoder_weights, map_location="cpu")["model"] + pretrained_model = torch.load(self.encoder_weights, map_location="cpu", weights_only=False)["model"] k = pretrained_model.keys() pretrained_encoder = {} incompatible_shape = {} @@ -149,11 +149,11 @@ def load_encoder_weights(self, logger: Logger) -> None: else: pretrained_encoder[name] = pretrained_model[name] - self.load_state_dict(pretrained_encoder, strict=False) + self.load_state_dict(pretrained_encoder, strict=False, weights_only=False) self.parameters_warning(missing, incompatible_shape, logger) def forward(self, image): - x = image["optical"].squeeze(2) + x = image B, _, h, w = x.shape x = self.patch_embed(x) diff --git a/pangaea/encoders/spectralgpt_encoder.py b/pangaea/encoders/spectralgpt_encoder.py index b8f022cc..dfc8962c 100644 --- a/pangaea/encoders/spectralgpt_encoder.py +++ b/pangaea/encoders/spectralgpt_encoder.py @@ -145,7 +145,7 @@ def __init__( ) def load_encoder_weights(self, logger: Logger) -> None: - pretrained_model = torch.load(self.encoder_weights, map_location="cpu") + pretrained_model = torch.load(self.encoder_weights, map_location="cpu", weights_only=False) pretrained_model = pretrained_model["model"] interpolate_pos_embed(self, pretrained_model) diff --git a/pangaea/encoders/ssl4eo_data2vec_encoder.py b/pangaea/encoders/ssl4eo_data2vec_encoder.py index 28ea33cd..5b28b407 100644 --- a/pangaea/encoders/ssl4eo_data2vec_encoder.py +++ b/pangaea/encoders/ssl4eo_data2vec_encoder.py @@ -522,7 +522,7 @@ def forward(self, images): return output def load_encoder_weights(self, logger: 
Logger) -> None: - checkpoint = torch.load(self.encoder_weights, map_location="cpu") + checkpoint = torch.load(self.encoder_weights, map_location="cpu", weights_only=False) pretrained_model = checkpoint["model"] k = pretrained_model.keys() diff --git a/pangaea/encoders/ssl4eo_dino_encoder.py b/pangaea/encoders/ssl4eo_dino_encoder.py index 308d8bdc..2c10423e 100644 --- a/pangaea/encoders/ssl4eo_dino_encoder.py +++ b/pangaea/encoders/ssl4eo_dino_encoder.py @@ -390,7 +390,7 @@ def forward(self, images): return output def load_encoder_weights(self, logger: Logger) -> None: - checkpoint = torch.load(self.encoder_weights, map_location="cpu") + checkpoint = torch.load(self.encoder_weights, map_location="cpu", weights_only=False) pretrained_model = checkpoint["teacher"] pretrained_model = { k.replace("backbone.", ""): v for k, v in pretrained_model.items() diff --git a/pangaea/encoders/ssl4eo_mae_encoder.py b/pangaea/encoders/ssl4eo_mae_encoder.py index 68bbc1b5..73c059a3 100644 --- a/pangaea/encoders/ssl4eo_mae_encoder.py +++ b/pangaea/encoders/ssl4eo_mae_encoder.py @@ -157,7 +157,7 @@ def forward(self, image): return output def load_encoder_weights(self, logger: Logger) -> None: - checkpoint = torch.load(self.encoder_weights, map_location="cpu") + checkpoint = torch.load(self.encoder_weights, map_location="cpu", weights_only=False) pretrained_model = checkpoint["model"] k = pretrained_model.keys() diff --git a/pangaea/encoders/ssl4eo_moco_encoder.py b/pangaea/encoders/ssl4eo_moco_encoder.py index 7c74635e..5926c2be 100644 --- a/pangaea/encoders/ssl4eo_moco_encoder.py +++ b/pangaea/encoders/ssl4eo_moco_encoder.py @@ -119,7 +119,7 @@ def build_2d_sincos_position_embedding(self, temperature=10000.0): self.pos_embed.requires_grad = False def load_encoder_weights(self, logger: Logger) -> None: - checkpoint = torch.load(self.encoder_weights, map_location="cpu") + checkpoint = torch.load(self.encoder_weights, map_location="cpu", weights_only=False) pretrained_model = 
checkpoint["state_dict"] pretrained_model = { k.replace("module.base_encoder.", ""): v diff --git a/pangaea/engine/data_preprocessor.py b/pangaea/engine/data_preprocessor.py index c229fd9c..7ed45d84 100644 --- a/pangaea/engine/data_preprocessor.py +++ b/pangaea/engine/data_preprocessor.py @@ -119,9 +119,9 @@ def __call__( "target": torch.Tensor of shape (H W), "metadata": dict}. """ - self.check_dimension(data) - for process in self.preprocessor: - data = process(data) + # self.check_dimension(data) + # for process in self.preprocessor: + # data = process(data) return data diff --git a/pangaea/engine/evaluator.py b/pangaea/engine/evaluator.py index ebad59dc..f28a1b6b 100644 --- a/pangaea/engine/evaluator.py +++ b/pangaea/engine/evaluator.py @@ -9,7 +9,9 @@ import torch import torch.nn.functional as F +from torch import Tensor from torch.utils.data import DataLoader +from pangaea.decoders.knnclassifier import KNNClassifier from tqdm import tqdm @@ -135,7 +137,7 @@ def sliding_inference(model, img, input_size, output_shape=None, stride=None, ma return merged_pred -class ClassificationEvaluator(Evaluator): +class LinearClassificationEvaluator(Evaluator): def __init__( self, val_loader, @@ -159,7 +161,7 @@ def evaluate( t = time.time() if model_ckpt_path is not None: - model_dict = torch.load(model_ckpt_path, map_location=self.device) + model_dict = torch.load(model_ckpt_path, map_location=self.device, weights_only=False) model_name = os.path.basename(model_ckpt_path).split(".")[0] if "model" in model_dict: model.module.load_state_dict(model_dict["model"]) @@ -239,10 +241,10 @@ def compute_metrics(self): def log_metrics(self, metrics: dict): def format_metric(name, value): - header = f"------- {name} --------\n" + header = f"[{self.split}] ------- {name} --------\n" value_str = ( - "-------------------\n" - + "Mean".ljust(self.max_name_len, " ") + "[{self.split}] -------------------\n" + + "[{self.split}] Mean".ljust(self.max_name_len, " ") + "\t{:>7}".format("%.3f" % 
value) ) return header + value_str @@ -265,7 +267,108 @@ def format_metric(name, value): }) +class KNNClassificationEvaluator(Evaluator): + """Builds a feature bank from *train_loader* and evaluates on *val_loader*.""" + + def __init__( + self, + val_loader, + exp_dir: str | Path, + device: torch.device, + inference_mode: str = "whole", + sliding_inference_batch: int = None, + use_wandb: bool = False, + ) -> None: + super().__init__(val_loader, exp_dir, device, use_wandb) + + self.logger = logging.getLogger() + + def topk_acc(self, pred_rank: Tensor, target: Tensor, k: int) -> float: + return (pred_rank[:, :k] == target.unsqueeze(1)).any(1).float().mean().item() + + def evaluate( + self, + model: KNNClassifier, + train_loader: DataLoader, # used to build feature bank + model_name: str, + model_ckpt_path: str | Path | None = None): + + """Build bank on the *current* training data, then run k-NN on val/test.""" + t0 = time.time() + if model_ckpt_path is not None: + model_dict = torch.load(model_ckpt_path, map_location=self.device, weights_only=False) + model_name = os.path.basename(model_ckpt_path).split(".")[0] + if "model" in model_dict: + model.module.load_state_dict(model_dict["model"]) + else: + model.module.load_state_dict(model_dict) + self.logger.info(f"Loaded {model_name} for evaluation") + + model.eval() + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + model = model.module + if model._bank is None or model._bank_labels is None: + if train_loader is None: + raise ValueError("train_loader is required to build feature bank for k-NN probe") + model.build_feature_bank(train_loader, self.device) + self.topk = model.topk + total = 0 + topk_correct = {k: 0 for k in self.topk} + + for batch in tqdm(self.val_loader, + desc=f"kNN-eval", leave=False): + image, target = batch["image"], batch["target"] + image = {k: v.to(self.device) for k, v in image.items()} + target = target.to(self.device) + + pred_rank = model.classify(image) # (B, C) ints + bsz 
= target.size(0) + total += bsz + + for k in self.topk: + topk_correct[k] += self.topk_acc(pred_rank, target, k) * bsz + + metrics = {f"top{k}": topk_correct[k] / total for k in self.topk} + self.log_metrics(metrics) + + return metrics, time.time() - t0 + + def __call__(self, model, model_name, model_ckpt_path=None, train_loader=None): + return self.evaluate(model, train_loader, model_name, model_ckpt_path) + + + def log_metrics(self, metrics: dict[str, float]) -> None: +# ensure we have something to align to (reuse class-name length trick) + if not hasattr(self, "max_name_len"): + # 4 is the length of the word "Mean" – keeps the columns aligned + self.max_name_len = 4 + def format_metric(name: str, value: float) -> str: + header = f"[{self.split}] ------- {name} --------\n" + value_str = ( + f"[{self.split}] -------------------\n" + + f"[{self.split}] Mean".ljust(self.max_name_len, " ") + + "\t{:>7}".format("%.3f" % value) + ) + return header + value_str + + top1_str = format_metric("Top-1 Acc", metrics["top1"]) + top5_str = format_metric("Top-2 Acc", metrics["top2"]) + + self.logger.info(top1_str) + self.logger.info(top5_str) + + # optional Weights & Biases logging + if getattr(self, "use_wandb", False) and getattr(self, "rank", 0) == 0: + import wandb # local import keeps dependency optional + wandb.log( + { + f"{self.split}_top1": metrics["top1"], + f"{self.split}_top2": metrics["top2"], + } + ) + + class SegEvaluator(Evaluator): """ SegEvaluator is a class for evaluating segmentation models. 
It extends the Evaluator class and provides methods @@ -398,7 +501,7 @@ def compute_metrics(self, confusion_matrix): def log_metrics(self, metrics): def format_metric(name, values, mean_value): - header = f"------- {name} --------\n" + header = f"[{self.split}] ------- {name} --------\n" metric_str = ( "\n".join( c.ljust(self.max_name_len, " ") + "\t{:>7}".format("%.3f" % num) @@ -407,8 +510,8 @@ def format_metric(name, values, mean_value): + "\n" ) mean_str = ( - "-------------------\n" - + "Mean".ljust(self.max_name_len, " ") + f"[{self.split}]-------------------\n" + + f"[{self.split}] Mean".ljust(self.max_name_len, " ") + "\t{:>7}".format("%.3f" % mean_value) ) return header + metric_str + mean_str @@ -489,7 +592,7 @@ def evaluate(self, model, model_name='model', model_ckpt_path=None): t = time.time() if model_ckpt_path is not None: - model_dict = torch.load(model_ckpt_path, map_location=self.device) + model_dict = torch.load(model_ckpt_path, map_location=self.device, weights_only=False) model_name = os.path.basename(model_ckpt_path).split('.')[0] if 'model' in model_dict: model.module.load_state_dict(model_dict["model"]) @@ -535,9 +638,9 @@ def __call__(self, model, model_name='model', model_ckpt_path=None): return self.evaluate(model, model_name, model_ckpt_path) def log_metrics(self, metrics): - header = "------- MSE and RMSE --------\n" - mse = "-------------------\n" + 'MSE \t{:>7}'.format('%.3f' % metrics['MSE']) + '\n' - rmse = "-------------------\n" + 'RMSE \t{:>7}'.format('%.3f' % metrics['RMSE']) + header = f"[{self.split}] ------- MSE and RMSE --------\n" + mse = f"[{self.split}]-------------------\n" + 'MSE \t{:>7}'.format('%.3f' % metrics['MSE']) + '\n' + rmse = f"[{self.split}]-------------------\n" + 'RMSE \t{:>7}'.format('%.3f' % metrics['RMSE']) self.logger.info(header + mse + rmse) if self.use_wandb and self.rank == 0: diff --git a/pangaea/engine/trainer.py b/pangaea/engine/trainer.py index a9a8788c..8b65ad41 100644 --- 
a/pangaea/engine/trainer.py +++ b/pangaea/engine/trainer.py @@ -20,7 +20,7 @@ def __init__( self, model: nn.Module, train_loader: DataLoader, - criterion: nn.Module, + criterion: nn.Module | None, optimizer: Optimizer, lr_scheduler: LRScheduler, evaluator: torch.nn.Module, @@ -230,7 +230,7 @@ def load_model(self, resume_path: str | pathlib.Path) -> None: Args: resume_path (str | pathlib.Path): path to the checkpoint. """ - model_dict = torch.load(resume_path, map_location=self.device) + model_dict = torch.load(resume_path, map_location=self.device, weights_only=False) if "model" in model_dict: self.model.module.load_state_dict(model_dict["model"]) self.optimizer.load_state_dict(model_dict["optimizer"]) @@ -354,7 +354,7 @@ def reset_stats(self) -> None: v.reset() -class ClassificationTrainer(Trainer): +class LinearClassificationTrainer(Trainer): def __init__( self, model: nn.Module, @@ -476,7 +476,64 @@ def compute_logging_metrics( self.training_metrics["F1"].update(f1.item()) +class KNNTrainer(Trainer): + """A zero-learning shell so run.py can stay unchanged.""" + + + def __init__( + self, + model: nn.Module, # should be KNNClassifier + train_loader: DataLoader, + evaluator, + lr_scheduler, + optimizer, + criterion, + exp_dir: pathlib.Path | str, + device: torch.device, + n_epochs: int, + precision: str, + use_wandb: bool, + ): + dummy_opt = torch.optim.SGD([torch.empty(0, device=device, requires_grad=True)], lr=1) + dummy_sched = torch.optim.lr_scheduler.LambdaLR(dummy_opt, lambda _: 1) + + super().__init__( + model=model, + train_loader=train_loader, + criterion=nn.Identity(), # never used + optimizer=dummy_opt, + lr_scheduler=dummy_sched, + evaluator=evaluator, + n_epochs=n_epochs, + exp_dir=exp_dir, + device=device, + precision=precision, + use_wandb=use_wandb, + ckpt_interval=999, + eval_interval=1, + log_interval=999, + best_metric_key="top1", + ) + self.logger: logging.Logger = logging.getLogger() + self.train_loader = train_loader + # 
------------------------------------------------------------------ # + def train(self): + self.logger.info("=========== k-NN evaluation only ===========") + self.evaluator(self.model, model_name="probe", train_loader=self.train_loader) + self.logger.info("============================================") + dummy_path1 = os.path.join(self.exp_dir, "checkpoint_dummy_best.pth") + dummy_path2 = os.path.join(self.exp_dir, "checkpoint_dummy_final.pth") + if self.rank == 0 and not os.path.exists(dummy_path1): + torch.save({"knn_probe": True}, dummy_path1) + if self.rank == 0 and not os.path.exists(dummy_path2): + torch.save({"knn_probe": True}, dummy_path2) + + # never called + def compute_loss(self, logits, target): ... + def compute_logging_metrics(self, logits, target): ... + + class SegTrainer(Trainer): def __init__( self, diff --git a/pangaea/run.py b/pangaea/run.py index 141eb4fe..7f6bcea3 100644 --- a/pangaea/run.py +++ b/pangaea/run.py @@ -151,7 +151,7 @@ def main(cfg: DictConfig) -> None: collate_fn = get_collate_fn(modalities) # training - if train_run: + if train_run or cfg.task.trainer.model_name == "knn_probe": # get preprocessor train_preprocessor = instantiate( cfg.preprocessing.train, @@ -262,6 +262,7 @@ def main(cfg: DictConfig) -> None: trainer.train() + # Evaluation test_preprocessor = instantiate( cfg.preprocessing.test, @@ -292,6 +293,10 @@ def main(cfg: DictConfig) -> None: model_ckpt_path = get_final_model_ckpt_path(exp_dir) else: model_ckpt_path = get_best_model_ckpt_path(exp_dir) + + if model_ckpt_path is None and not cfg.task.trainer.model_name == "knn_probe": + raise ValueError(f"No model checkpoint found in {exp_dir}") + test_evaluator.evaluate(decoder, "test_model", model_ckpt_path) if cfg.use_wandb and rank == 0: diff --git a/pangaea/utils/utils.py b/pangaea/utils/utils.py index 0960f553..abbc79b0 100644 --- a/pangaea/utils/utils.py +++ b/pangaea/utils/utils.py @@ -1,10 +1,13 @@ import os as os import random from pathlib import Path +from typing 
import Optional import numpy as np import torch +import logging +_log = logging.getLogger(__name__) def seed_worker(worker_id): worker_seed = torch.initial_seed() % 2**32 @@ -41,12 +44,26 @@ def prepare_input(input_res): return dict(img=image) -def get_best_model_ckpt_path(exp_dir: str | Path) -> str: - return os.path.join( - exp_dir, next(f for f in os.listdir(exp_dir) if f.endswith("_best.pth")) +def _find_ckpt(exp_dir: str | Path, suffix: str) -> Optional[str]: + """Return the *first* file that ends with `suffix`; None if nothing found.""" + exp_dir = Path(exp_dir) + for fname in exp_dir.iterdir(): + if fname.name.endswith(suffix): + return str(fname) + # Nothing found – warn once. + _log.warning( + "No checkpoint matching '*%s' found in %s. " + "If this was a k-NN probe (no training), you can ignore this warning. Otherwise, check your experiment directory.", + suffix, exp_dir, ) + return None -def get_final_model_ckpt_path(exp_dir: str | Path) -> str: - return os.path.join( - exp_dir, next(f for f in os.listdir(exp_dir) if f.endswith("_final.pth")) - ) + +def get_best_model_ckpt_path(exp_dir: str | Path) -> Optional[str]: + """Return '/…_best.pth' or None when it does not exist.""" + return _find_ckpt(exp_dir, "_best.pth") + + +def get_final_model_ckpt_path(exp_dir: str | Path) -> Optional[str]: + """Return '/…_final.pth' or None when it does not exist.""" + return _find_ckpt(exp_dir, "_final.pth") \ No newline at end of file