From cf83455e1f416dadc64bf126da316610083153de Mon Sep 17 00:00:00 2001 From: pierreadorni Date: Fri, 3 Oct 2025 17:12:02 +0200 Subject: [PATCH 1/2] add timm encoder and dinov3 config --- configs/encoder/dinov3_vitl_sat.yaml | 10 ++++ pangaea/encoders/timm_encoder.py | 73 ++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+) create mode 100644 configs/encoder/dinov3_vitl_sat.yaml create mode 100644 pangaea/encoders/timm_encoder.py diff --git a/configs/encoder/dinov3_vitl_sat.yaml b/configs/encoder/dinov3_vitl_sat.yaml new file mode 100644 index 00000000..5cc58114 --- /dev/null +++ b/configs/encoder/dinov3_vitl_sat.yaml @@ -0,0 +1,10 @@ +_target_: pangaea.encoders.timm_encoder.TimmEncoder +timm_model_name: vit_large_patch16_dinov3.sat493m +input_size: 512 + +input_bands: + optical: + - B2 + - B3 + - B4 + diff --git a/pangaea/encoders/timm_encoder.py b/pangaea/encoders/timm_encoder.py new file mode 100644 index 00000000..6726253f --- /dev/null +++ b/pangaea/encoders/timm_encoder.py @@ -0,0 +1,73 @@ +from logging import Logger + +import torch +import numpy as np +import timm + +from pangaea.encoders.base import Encoder + + +torch.autograd.set_detect_anomaly(True) + + +class TimmEncoder(Encoder): + """Timm Encoder class.""" + + def __init__( + self, + timm_model_name: str, + input_bands: dict[str, list[str]], + input_size: int, + ) -> None: + """Initialize the TimmEncoder. + + Args: + timm_model_name (str): Name of the timm model passed to ``timm.create_model``. + input_bands (dict[str, list[str]]): Input bands. + input_size (int): Input size. + + Raises: + RuntimeError: If ``timm.create_model`` raises a RuntimeError while loading the model.
+ """ + + super().__init__( + model_name=timm_model_name, + input_bands=input_bands, + input_size=input_size, + embed_dim=0, # will be set after model is created + output_layers=[], + output_dim=0, # will be set after model is created + multi_temporal=False, + multi_temporal_output=False, + pyramid_output=False, # will be set after model is created + encoder_weights="", + download_url="", + ) + + try: + self.model = timm.create_model( + timm_model_name + pretrained=True, + features_only=True, + ) + except RuntimeError as e: + raise RuntimeError( + f"Error loading Timm model {timm_model_name}. " + "Please ensure that the timm_model_name parameter is correct and that you have internet access to download the weights." + ) from e + + self.embed_dim = self.model.model.embed_dim + self.output_dim = self.model.feature_info.channels() + self.output_layers = list(range(len(self.output_dim))) + self.pyramid_output = np.unique(self.output_dim).shape[0] > 1 # pyramid if multiple unique output dims + + def forward(self, image): + x = torch.cat( + [image[modality].squeeze(2) for modality in self.input_bands], dim=1 + ) + + pred = self.model(x) + return [v for v in pred] + + def load_encoder_weights(self, logger: Logger) -> None: + pass From e6c03d6e2f14b2963575e5b0b5972a945f00623c Mon Sep 17 00:00:00 2001 From: pierreadorni Date: Wed, 8 Oct 2025 13:33:04 +0200 Subject: [PATCH 2/2] fix missing comma --- pangaea/encoders/timm_encoder.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pangaea/encoders/timm_encoder.py b/pangaea/encoders/timm_encoder.py index 6726253f..79cdfa6a 100644 --- a/pangaea/encoders/timm_encoder.py +++ b/pangaea/encoders/timm_encoder.py @@ -7,9 +7,6 @@ from pangaea.encoders.base import Encoder -torch.autograd.set_detect_anomaly(True) - - class TimmEncoder(Encoder): """Timm Encoder class.""" @@ -46,7 +43,7 @@ def __init__( try: self.model = timm.create_model( - timm_model_name + timm_model_name, pretrained=True, features_only=True, )