diff --git a/configs/encoder/galileo_base.yaml b/configs/encoder/galileo_base.yaml new file mode 100644 index 00000000..010f2716 --- /dev/null +++ b/configs/encoder/galileo_base.yaml @@ -0,0 +1,97 @@ +_target_: pangaea.encoders.galileo_encoder.galileo_base +models_folder: ./pretrained_models/galileo/models/ +encoder_weights: ./pretrained_models/galileo/models/base/encoder.pt +download_url: https://huggingface.co/nasaharvest/galileo/resolve/main/models/ + +# For sizes other than 'base': see config in pretrained_modesl/galileo/[size]/config.json +input_size: 80 +size: 'base' + +# These parameters are hard-coded for each model size in pangaea.encoders.galileo.gelileo_{base/nano/tiny} +embed_dim: 768 # same as "embedding_size" from original implementation of Galileo Encoder +# depth: 12 +# num_heads: 12 +# mlp_ratio: 4 +# max_sequence_length: 24 +freeze_projections: ${finetune} +# drop_path: 0.1 +# max_patch_size: 8 +do_pool: false + +data_bands: ${dataset.bands} +data_gsd: 10. # ${dataset.gsd} + +input_bands: + optical: + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B8 + - B8A + - B9 + - B10 + - B11 + - B12 + sar: + - VV + - VH + # Other modalities: not implemented yet + # era5: + # - temperature_2m + # - total_precipitation_sum + # tc: + # - def + # - soil + # - aet + # viirs: + # - avg_rad + # srtm: + # - elevation + # - slope + # dw: + # - DW_water + # - DW_trees + # - DW_grass + # - DW_flooded_vegetation + # - DW_crops + # - DW_shrub_and_scrub + # - DW_built + # "DW_bare", + # - DW_snow_and_ice + # wc: + # - WC_temporarycrops + # - WC_maize + # - WC_wintercereals + # - WC_springcereals + # - WC_irrigation + +token_exit_cfg: + S1: 12 + S2_RGB: 12 + S2_Red_Edge: 12 + S2_NIR_10m: 12 + S2_NIR_20m: 12 + S2_SWIR: 12 + NDVI: 6 + ERA5: 6 + TC: 6 + VIIRS: 12 + SRTM: 6 + DW: 0 + WC: 0 + LS: 0 + location: 12 + DW_static: 0 + WC_static: 0 + +output_layers: + - 3 + - 5 + - 7 + - 11 + +output_dim: 768 diff --git a/configs/preprocessing/seg_default_moddrop.yaml b/configs/preprocessing/seg_default_moddrop.yaml new file mode 100644 index 00000000..703d1c85 --- /dev/null +++ b/configs/preprocessing/seg_default_moddrop.yaml @@ -0,0 +1,24 @@ +train: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.RandomCropToEncoder + - _target_: pangaea.engine.data_preprocessor.ModalityDrop + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd + - _target_: pangaea.engine.data_preprocessor.BandPadding + +val: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.ModalityDrop + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd + - _target_: pangaea.engine.data_preprocessor.BandPadding + +test: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.ModalityDrop + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd + - _target_: pangaea.engine.data_preprocessor.BandPadding \ No newline at end of file diff --git a/configs/preprocessing/seg_focus_crop_moddrop.yaml b/configs/preprocessing/seg_focus_crop_moddrop.yaml new file mode 100644 index 00000000..3d9c2460 --- /dev/null +++ b/configs/preprocessing/seg_focus_crop_moddrop.yaml @@ -0,0 +1,25 @@ +train: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.FocusRandomCropToEncoder + - _target_: pangaea.engine.data_preprocessor.ModalityDrop + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd + - _target_: pangaea.engine.data_preprocessor.BandPadding + +val: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.ModalityDrop + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd + - _target_: pangaea.engine.data_preprocessor.BandPadding + +test: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.ModalityDrop + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd + - _target_: pangaea.engine.data_preprocessor.BandPadding + diff --git a/pangaea/encoders/galileo_encoder.py b/pangaea/encoders/galileo_encoder.py new file mode 100644 index 00000000..f68d06ef --- /dev/null +++ b/pangaea/encoders/galileo_encoder.py @@ -0,0 +1,1682 @@ +import os +import collections.abc +import itertools +import json +import math +from logging import Logger +from collections import OrderedDict +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import OrderedDict as OrderedDictType + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from torch import Tensor, vmap +from torch.jit import Final + +import urllib +from pangaea.encoders.base import Encoder, DownloadProgressBar + + +# constants +CONFIG_FILENAME = "config.json" +ENCODER_FILENAME = "encoder.pt" +BASE_GSD = 10 + +# band information +S1_BANDS = ["VV", "VH"] +S2_BANDS = [ + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8", + "B8A", + "B11", + "B12", +] +ERA5_BANDS = ["temperature_2m", "total_precipitation_sum"] +TC_BANDS = ["def", "soil", "aet"] +VIIRS_BANDS = ["avg_rad"] +SRTM_BANDS = ["elevation", "slope"] +DW_BANDS = [ + "DW_water", + "DW_trees", + "DW_grass", + "DW_flooded_vegetation", + "DW_crops", + "DW_shrub_and_scrub", + "DW_built", + "DW_bare", + "DW_snow_and_ice", +] +WC_BANDS = [ + "WC_temporarycrops", + "WC_maize", + "WC_wintercereals", + "WC_springcereals", + "WC_irrigation", +] +STATIC_DW_BANDS = [f"{x}_static" for x in DW_BANDS] +STATIC_WC_BANDS = [f"{x}_static" for x in WC_BANDS] + +LANDSCAN_BANDS = ["b1"] +LOCATION_BANDS = ["x", "y", "z"] + +SPACE_TIME_BANDS = S1_BANDS + S2_BANDS + ["NDVI"] +TIME_BANDS = ERA5_BANDS + TC_BANDS + VIIRS_BANDS +SPACE_BANDS = SRTM_BANDS + DW_BANDS + WC_BANDS +STATIC_BANDS = LANDSCAN_BANDS + LOCATION_BANDS + STATIC_DW_BANDS + STATIC_WC_BANDS + + +SPACE_TIME_BANDS_GROUPS_IDX: OrderedDictType[str, List[int]] = OrderedDict( + { + "S1": [SPACE_TIME_BANDS.index(b) for b in S1_BANDS], + "S2_RGB": [SPACE_TIME_BANDS.index(b) for b in ["B2", "B3", "B4"]], + "S2_Red_Edge": [SPACE_TIME_BANDS.index(b) for b in ["B5", "B6", "B7"]], + "S2_NIR_10m": [SPACE_TIME_BANDS.index(b) for b in ["B8"]], + "S2_NIR_20m": [SPACE_TIME_BANDS.index(b) for b in ["B8A"]], + "S2_SWIR": [SPACE_TIME_BANDS.index(b) for b in ["B11", "B12"]], + "NDVI": [SPACE_TIME_BANDS.index("NDVI")], + } +) + +TIME_BAND_GROUPS_IDX: OrderedDictType[str, List[int]] = OrderedDict( + { + "ERA5": [TIME_BANDS.index(b) for b in ERA5_BANDS], + "TC": [TIME_BANDS.index(b) for b in TC_BANDS], + "VIIRS": [TIME_BANDS.index(b) for b in VIIRS_BANDS], + } +) + +SPACE_BAND_GROUPS_IDX: OrderedDictType[str, List[int]] = OrderedDict( + { + "SRTM": [SPACE_BANDS.index(b) for b in SRTM_BANDS], + "DW": [SPACE_BANDS.index(b) for b in DW_BANDS], + "WC": [SPACE_BANDS.index(b) for b in WC_BANDS], + } +) + +STATIC_BAND_GROUPS_IDX: OrderedDictType[str, List[int]] = OrderedDict( + { + "LS": [STATIC_BANDS.index(b) for b in LANDSCAN_BANDS], + "location": [STATIC_BANDS.index(b) for b in LOCATION_BANDS], + "DW_static": [STATIC_BANDS.index(b) for b in STATIC_DW_BANDS], + "WC_static": [STATIC_BANDS.index(b) for b in STATIC_WC_BANDS], + } +) + + +def get_2d_sincos_pos_embed_with_resolution( + embed_dim, grid_size, res, cls_token=False, device="cpu" +): + """ + grid_size: int of the grid height and width + res: array of size n, representing the resolution of a pixel (say, in meters), + return: + pos_embed: [n,grid_size*grid_size, embed_dim] or [n,1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + res = res.to(device) + grid_h = torch.arange(grid_size, device=device) + grid_w = torch.arange(grid_size, device=device) + grid = torch.meshgrid( + grid_w, grid_h, indexing="xy" + ) # here h goes first,direction reversed for numpy + grid = torch.stack(grid, dim=0) # 2 x h x w + + # grid = grid.reshape([2, 1, grid_size, grid_size]) + grid = torch.einsum("chw,n->cnhw", grid, res) # 2 x n x h x w + _, n, h, w = grid.shape + pos_embed = get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid) # # (nxH*W, D/2) + pos_embed = pos_embed.reshape(n, h * w, embed_dim) + if cls_token: + pos_embed = torch.cat( + [ + torch.zeros([n, 1, embed_dim], device=pos_embed.device), + pos_embed, + ], + dim=1, + ) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = torch.arange(embed_dim // 2, device=pos.device) / embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + return emb + + +def get_month_encoding_table(embed_dim): + """Sinusoid month encoding table, for 12 months indexed from 0-11""" + assert embed_dim % 2 == 0 + angles = torch.arange(0, 13) / (12 / (2 * np.pi)) + + sin_table = torch.sin(torch.stack([angles for _ in range(embed_dim // 2)], axis=-1)) + cos_table = torch.cos(torch.stack([angles for _ in range(embed_dim // 2)], axis=-1)) + month_table = torch.concatenate([sin_table[:-1], cos_table[:-1]], axis=-1) + + return month_table # (M, D) + + +def adjust_learning_rate( + optimizer, + epoch, + warmup_epochs, + total_epochs, + max_lr, + min_lr, +): + """Decay the learning rate with half-cycle cosine after warmup""" + if epoch < warmup_epochs: + lr = max_lr * epoch / warmup_epochs + else: + lr = min_lr + (max_lr - min_lr) * 0.5 * ( + 1.0 + math.cos(math.pi * (epoch - warmup_epochs) / (total_epochs - warmup_epochs)) + ) + for group in optimizer.param_groups: + group["lr"] = lr + return lr + + +# thanks to https://github.com/bwconrad/flexivit/ for this nice implementation +# of the FlexiPatchEmbed module +def to_2tuple(x: Any) -> Tuple: + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(itertools.repeat(x, 2)) + + +class FlexiPatchEmbed(nn.Module): + def __init__( + self, + patch_size: Union[int, Tuple[int, int]], + in_chans: int = 3, + embed_dim: int = 128, + norm_layer: Optional[nn.Module] = None, + bias: bool = True, + patch_size_seq: Sequence[int] = (1, 2, 3, 4, 5, 6), + interpolation: str = "bicubic", + antialias: bool = True, + ) -> None: + """2D image to patch embedding w/ flexible patch sizes + Extended from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/patch_embed.py#L24 + by https://github.com/bwconrad/flexivit/ + + Args: + patch_size: Base patch size. i.e the size of the parameter buffer + in_chans: Number of input image channels + embed_dim: Network embedding dimension size + norm_layer: Optional normalization layer + bias: Whether to use bias in convolution + patch_size_seq: List of patch sizes to randomly sample from + interpolation: Resize interpolation type + antialias: Whether to apply antialiasing resizing + """ + super().__init__() + + self.patch_size = to_2tuple(patch_size) + + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=bias, + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + # Flexi specific attributes + self.interpolation = interpolation + self.antialias = antialias + + self.patch_size_seq = patch_size_seq + + # Pre-calculate pinvs + self.pinvs = self._cache_pinvs() + + def _cache_pinvs(self) -> dict: + """Pre-calculate all pinv matrices""" + pinvs = {} + for ps in self.patch_size_seq: + tuple_ps = to_2tuple(ps) + pinvs[tuple_ps] = self._calculate_pinv(self.patch_size, tuple_ps) + return pinvs + + def _resize(self, x: Tensor, shape: Tuple[int, int]) -> Tensor: + x_resized = F.interpolate( + x[None, None, ...], + shape, + mode=self.interpolation, + antialias=self.antialias, + ) + return x_resized[0, 0, ...] + + def _calculate_pinv(self, old_shape: Tuple[int, int], new_shape: Tuple[int, int]) -> Tensor: + mat = [] + for i in range(np.prod(old_shape)): + basis_vec = torch.zeros(old_shape) + basis_vec[np.unravel_index(i, old_shape)] = 1.0 + mat.append(self._resize(basis_vec, new_shape).reshape(-1)) + resize_matrix = torch.stack(mat) + return torch.linalg.pinv(resize_matrix) + + def resize_patch_embed(self, patch_embed: Tensor, new_patch_size: Tuple[int, int]): + """Resize patch_embed to target resolution via pseudo-inverse resizing""" + # Return original kernel if no resize is necessary + if self.patch_size == new_patch_size: + return patch_embed + + # Calculate pseudo-inverse of resize matrix + if new_patch_size not in self.pinvs: + self.pinvs[new_patch_size] = self._calculate_pinv(self.patch_size, new_patch_size) + pinv = self.pinvs[new_patch_size] + pinv = pinv.to(patch_embed.device) + + def resample_patch_embed(patch_embed: Tensor): + h, w = new_patch_size + resampled_kernel = pinv @ patch_embed.reshape(-1) + return rearrange(resampled_kernel, "(h w) -> h w", h=h, w=w) + + v_resample_patch_embed = vmap(vmap(resample_patch_embed, 0, 0), 1, 1) + + return v_resample_patch_embed(patch_embed) + + def forward( + self, + x: Tensor, + patch_size: Optional[Union[int, Tuple[int, int]]] = None, + ) -> Union[Tensor, Tuple[Tensor, Tuple[int, int]]]: + # x has input shape [b, h, w, (t), c] + batch_size = x.shape[0] + has_time_dimension = False + num_timesteps = 0 # ignored if has_time_dimension is False + if len(x.shape) == 5: + has_time_dimension = True + num_timesteps = x.shape[3] + x = rearrange(x, "b h w t c -> (b t) c h w") + else: + x = rearrange(x, "b h w c -> b c h w") + + if not patch_size: + # During evaluation use base patch size if not specified + patch_size = self.patch_size + + patch_size = to_2tuple(patch_size) + + # Resize conv weights + if patch_size == self.patch_size: + weight = self.proj.weight + else: + weight = self.resize_patch_embed(self.proj.weight, patch_size) + # Apply conv with resized weights + x = F.conv2d(x, weight, bias=self.proj.bias, stride=patch_size) + + if has_time_dimension: + x = rearrange(x, "(b t) c h w -> b h w t c", b=batch_size, t=num_timesteps) + else: + x = rearrange(x, "b c h w -> b h w c") + x = self.norm(x) + + return x + + +class Attention(nn.Module): + # https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py + fast_attn: Final[bool] + + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_norm=False, + attn_drop=0.0, + proj_drop=0.0, + norm_layer=nn.LayerNorm, + cross_attn: bool = False, + ): + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.fast_attn = hasattr(torch.nn.functional, "scaled_dot_product_attention") # FIXME + + self.cross_attn = cross_attn + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.k = nn.Linear(dim, dim, bias=qkv_bias) + self.v = nn.Linear(dim, dim, bias=qkv_bias) + + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, y=None, attn_mask=None): + B, N, C = x.shape + + q = self.q(x) + + if y is None: + assert not self.cross_attn + k = self.k(x) + v = self.v(x) + else: + assert self.cross_attn + k = self.k(y) + v = self.v(y) + + q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads) + k = rearrange(k, "b n (h d) -> b h n d", h=self.num_heads) + v = rearrange(v, "b n (h d) -> b h n d", h=self.num_heads) + + q, k = self.q_norm(q), self.k_norm(k) + if self.fast_attn: + if attn_mask is not None: + attn_mask = attn_mask[:, None, None].repeat((1, self.num_heads, N, 1)) + x = F.scaled_dot_product_attention( + q, + k, + v, + # a value of True indicates that the element should take part in attention + attn_mask=attn_mask, + dropout_p=self.attn_drop.p, + ) + else: + if attn_mask is not None: + raise NotImplementedError + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + bias=True, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.drop1 = nn.Dropout(drop) + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop2 = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_norm=False, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + init_values=None, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + cross_attn: bool = False, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=drop, + norm_layer=norm_layer, + cross_attn=cross_attn, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=drop, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + + def forward(self, x, y, attn_mask): + x = x + self.drop_path(self.ls1(self.attn(self.norm1(x), y, attn_mask))) + x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class ModuleListWithInit(nn.ModuleList): + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + +class GalileoBase(nn.Module): + cross_attn: bool + + def __init__( + self, + embedding_size: int = 128, + depth=2, + mlp_ratio=2, + num_heads=8, + max_sequence_length=24, + base_patch_size: int = 4, + use_channel_embs: bool = True, + drop_path: float = 0.0, + ): + super().__init__() + + self.space_time_groups = SPACE_TIME_BANDS_GROUPS_IDX + self.space_groups = SPACE_BAND_GROUPS_IDX + self.time_groups = TIME_BAND_GROUPS_IDX + self.static_groups = STATIC_BAND_GROUPS_IDX + self.embedding_size = embedding_size + self.base_patch_size = base_patch_size + + self.blocks = ModuleListWithInit( + [ + Block( + embedding_size, + num_heads, + mlp_ratio, + qkv_bias=True, + norm_layer=nn.LayerNorm, + cross_attn=self.cross_attn, + drop_path=drop_path, + ) + for _ in range(depth) + ] + ) + + self.max_sequence_length = max_sequence_length + # we have 4 embeddings (pos_in_time, pos_in_space, month, channel) so each get + # 0.25 of the dimension. This will change soon anyway + self.pos_embed = nn.Parameter( + get_1d_sincos_pos_embed_from_grid_torch( + int(embedding_size * 0.25), torch.arange(max_sequence_length) + ), + requires_grad=False, + ) + month_tab = get_month_encoding_table(int(embedding_size * 0.25)) + self.month_embed = nn.Embedding.from_pretrained(month_tab, freeze=True) + if use_channel_embs: + args = {"requires_grad": True} + else: + args = {"requires_grad": False} + self.s_t_channel_embed = nn.Parameter( + torch.zeros(len(SPACE_TIME_BANDS_GROUPS_IDX), int(embedding_size * 0.25)), **args + ) + self.sp_channel_embed = nn.Parameter( + torch.zeros(len(SPACE_BAND_GROUPS_IDX), int(embedding_size * 0.25)), **args + ) + self.t_channel_embed = nn.Parameter( + torch.zeros(len(TIME_BAND_GROUPS_IDX), int(embedding_size * 0.25)), **args + ) + self.st_channel_embed = nn.Parameter( + torch.zeros(len(STATIC_BAND_GROUPS_IDX), int(embedding_size * 0.25)), **args + ) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + @classmethod + def collapse_and_combine_hwtc( + cls, + s_t_x: torch.Tensor, + sp_x: torch.Tensor, + t_x: torch.Tensor, + st_x: torch.Tensor, + s_t_m: torch.Tensor, + sp_m: torch.Tensor, + t_m: torch.Tensor, + st_m: torch.Tensor, + ): + s_t_x = rearrange(s_t_x, "b h w t c_g d -> b (h w t c_g) d") + sp_x = rearrange(sp_x, "b h w c_g d -> b (h w c_g) d") + t_x = rearrange(t_x, "b t c_g d -> b (t c_g) d") + + s_t_m = rearrange(s_t_m, "b h w t c_g-> b (h w t c_g)") + sp_m = rearrange(sp_m, "b h w c_g-> b (h w c_g)") + t_m = rearrange(t_m, "b t c_g -> b (t c_g)") + + x = torch.cat( + [ + s_t_x, + sp_x, + t_x, + st_x, + ], + dim=1, + ) + m = torch.cat([s_t_m, sp_m, t_m, st_m], dim=1) + return x, m + + @classmethod + def split_and_expand_hwtc( + cls, + x: torch.Tensor, + h: int, + w: int, + t: int, + s_t_c_g: int, + sp_c_g: int, + t_c_g: int, + st_c_g: int, + ): + n_s_t_t = h * w * t * s_t_c_g + n_t_t = t * t_c_g + + s_t_x = rearrange(x[:, :n_s_t_t], "b (h w t c) d -> b h w t c d", h=h, w=w, t=t, c=s_t_c_g) + sp_x = rearrange( + x[:, n_s_t_t : -(n_t_t + st_c_g)], "b (h w c) d -> b h w c d", h=h, w=w, c=sp_c_g + ) + t_x = rearrange(x[:, -(n_t_t + st_c_g) : -st_c_g], "b (t c) d -> b t c d", t=t, c=t_c_g) + st_x = x[:, -st_c_g:] + + return s_t_x, sp_x, t_x, st_x + + def apply_encodings(self, s_t_x, sp_x, t_x, st_x, months, patch_size, input_res): + b, h, w, t, s_t_c_g, _ = s_t_x.shape + sp_c_g, t_c_g = sp_x.shape[-2], t_x.shape[-2] + st_c_g = st_x.shape[-2] + + s_t_channel = repeat(self.s_t_channel_embed, "c_g d -> b h w t c_g d", b=b, h=h, w=w, t=t) + t_channel = repeat(self.t_channel_embed, "c_g d -> b t c_g d", b=b, t=t) + st_channel = repeat(self.st_channel_embed, "c_g d -> b c_g d", b=b) + sp_channel = repeat(self.sp_channel_embed, "c_g d -> b h w c_g d", b=b, h=h, w=w) + + pos_embed_s_t = repeat( + self.pos_embed[:t], "t d -> b h w t c_g d", b=b, h=h, w=w, c_g=s_t_c_g + ) + m_embed_s_t = repeat( + self.month_embed(months), "b t d -> b h w t c_g d", h=h, w=w, c_g=s_t_c_g + ) + + pos_embed_t = repeat(self.pos_embed[:t], "t d -> b t c_g d", b=b, c_g=t_c_g) + m_embed_t = repeat(self.month_embed(months), "b t d -> b t c_g d", c_g=t_c_g) + t_zeros = torch.zeros(b, t, t_c_g, int(self.embedding_size * 0.25), device=t_x.device) + + sp_zeros = torch.zeros( + b, + h, + w, + sp_c_g, + sp_channel.shape[-1] * 2, + device=sp_channel.device, + ) + + st_zeros = torch.zeros(b, st_c_g, st_channel.shape[-1] * 3, device=st_channel.device) + + # find the resolution that each token represents, which will be + # the number of pixels in a patch * the resolution of each pixel + if patch_size is None: + patch_size = self.base_patch_size + token_res = input_res * patch_size + gsd_ratio = token_res / BASE_GSD + + assert h == w, "get_2d_sincos_pos_embed_with_resolution currently requires that h==w" + spatial_embed = get_2d_sincos_pos_embed_with_resolution( + int(self.embedding_size * 0.25), + h, + torch.ones(b).to(s_t_x.device) * gsd_ratio, + device=s_t_x.device, + ) + spatial_embed = rearrange(spatial_embed, "b (h w) d -> b h w d", h=h, w=w) + spatial_embed_s_t = repeat( + spatial_embed, "b h w d -> b h w t c_g d", h=h, w=w, t=t, c_g=s_t_c_g + ) + spatial_embed_s = repeat(spatial_embed, "b h w d -> b h w c_g d", h=h, w=w, c_g=sp_c_g) + + s_t_embed = torch.cat([s_t_channel, pos_embed_s_t, m_embed_s_t, spatial_embed_s_t], dim=-1) + sp_embed = torch.cat([sp_channel, sp_zeros, spatial_embed_s], dim=-1) + t_embed = torch.cat([t_channel, pos_embed_t, m_embed_t, t_zeros], dim=-1) + st_embed = torch.cat([st_channel, st_zeros], dim=-1) + return s_t_x + s_t_embed, sp_x + sp_embed, t_x + t_embed, st_x + st_embed + + +class Galileo_Encoder(GalileoBase): + cross_attn = False + + def __init__( + self, + max_patch_size: int = 8, + embedding_size: int = 128, + depth=2, + mlp_ratio=2, + num_heads=8, + max_sequence_length=24, + freeze_projections: bool = False, + drop_path: float = 0.0, + **kwargs + ): + super().__init__( + embedding_size, + depth, + mlp_ratio, + num_heads, + max_sequence_length, + max_patch_size, + use_channel_embs=True, + drop_path=drop_path, + ) + + self.space_time_embed = nn.ModuleDict( + { + group_name: FlexiPatchEmbed( + in_chans=len(group), embed_dim=embedding_size, patch_size=max_patch_size + ) + for group_name, group in self.space_time_groups.items() + } + ) + self.space_embed = nn.ModuleDict( + { + group_name: FlexiPatchEmbed( + in_chans=len(group), embed_dim=embedding_size, patch_size=max_patch_size + ) + for group_name, group in self.space_groups.items() + } + ) + self.time_embed = nn.ModuleDict( + { + group_name: nn.Linear(in_features=len(group), out_features=embedding_size) + for group_name, group in self.time_groups.items() + } + ) + self.static_embed = nn.ModuleDict( + { + group_name: nn.Linear(in_features=len(group), out_features=embedding_size) + for group_name, group in self.static_groups.items() + } + ) + if freeze_projections: + self.space_time_embed.requires_grad_(False) + self.space_embed.requires_grad_(False) + self.time_embed.requires_grad_(False) + self.static_embed.requires_grad_(False) + self.norm = nn.LayerNorm(embedding_size) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def apply_linear_projection( + self, + s_t_x: torch.Tensor, + sp_x: torch.Tensor, + t_x: torch.Tensor, + st_x: torch.Tensor, + s_t_m: torch.Tensor, + sp_m: torch.Tensor, + t_m: torch.Tensor, + st_m: torch.Tensor, + patch_size: int, + ): + """ + Given a [B, H, W, (T), C] inputs, returns a [B, H, W, (T), C_G, D] output. + We assume that the spatial masks are consistent for the given patch size, + so that if patch_size == 2 then one possible mask would be + [0, 0, 1, 1] + [0, 0, 1, 1] + [1, 1, 0, 0] + [1, 1, 0, 0] + for the H, W dimensions + """ + b, h, w, t, _ = s_t_x.shape + new_h, new_w = h // patch_size, w // patch_size + + s_t_l, sp_l, t_l, st_l, s_t_m_l, sp_m_l, t_m_l, st_m_l = [], [], [], [], [], [], [], [] + for idx, (channel_group, channel_idxs) in enumerate(self.space_time_groups.items()): + s_t_m_l.append(s_t_m[:, 0::patch_size, 0::patch_size, :, idx]) + if s_t_m_l[-1].min() == 0: + s_t_l.append( + self.space_time_embed[channel_group]( + s_t_x[:, :, :, :, channel_idxs], patch_size=patch_size + ) + ) + else: + s_t_l.append( + torch.empty( + b, + new_h, + new_w, + t, + self.embedding_size, + dtype=s_t_x.dtype, + device=s_t_x.device, + ) + ) + for idx, (channel_group, channel_idxs) in enumerate(self.space_groups.items()): + sp_m_l.append(sp_m[:, 0::patch_size, 0::patch_size, idx]) + if sp_m_l[-1].min() == 0: + sp_l.append( + self.space_embed[channel_group]( + sp_x[:, :, :, channel_idxs], patch_size=patch_size + ) + ) + else: + sp_l.append( + torch.empty( + b, + new_h, + new_w, + self.embedding_size, + dtype=sp_x.dtype, + device=sp_x.device, + ) + ) + + for idx, (channel_group, channel_idxs) in enumerate(self.time_groups.items()): + t_m_l.append(t_m[:, :, idx]) + if t_m_l[-1].min() == 0: + t_l.append(self.time_embed[channel_group](t_x[:, :, channel_idxs])) + else: + t_l.append( + torch.empty(b, t, self.embedding_size, dtype=t_x.dtype, device=t_x.device) + ) + + for idx, (channel_group, channel_idxs) in enumerate(self.static_groups.items()): + st_m_l.append(st_m[:, idx]) + if st_m_l[-1].min() == 0: + st_l.append(self.static_embed[channel_group](st_x[:, channel_idxs])) + else: + st_l.append( + torch.empty(b, self.embedding_size, dtype=st_x.dtype, device=st_x.device) + ) + + return ( + torch.stack(s_t_l, dim=-2), + torch.stack(sp_l, dim=-2), + torch.stack(t_l, dim=-2), + torch.stack(st_l, dim=-2), + torch.stack(s_t_m_l, dim=-1), + torch.stack(sp_m_l, dim=-1), + torch.stack(t_m_l, dim=-1), + torch.stack(st_m_l, dim=-1), + ) + + @staticmethod + def remove_masked_tokens(x, mask): + org_mask_dtype = mask.dtype + mask = mask.bool() + # https://stackoverflow.com/a/68621610/2332296 + # move all non-masked values to the front of their rows + sorted_mask, indices = torch.sort((~mask).int(), dim=1, descending=True, stable=True) + x = x.gather(1, indices[:, :, None].expand_as(x)) + # set masked values to 0 (not really necessary since we'll ignore them anyway) + x = x * sorted_mask.unsqueeze(-1) + + # cut off to the length of the longest sequence + max_length = sorted_mask.sum(-1).max() + x = x[:, :max_length] + updated_mask = 1 - sorted_mask[:, :max_length] + + return x, indices, updated_mask.to(dtype=org_mask_dtype) + + @staticmethod + def add_removed_tokens(x, indices, mask): + masked_tokens = repeat( + torch.zeros_like(x[0, 0, :]), "d -> b t d", b=x.shape[0], t=indices.shape[1] + ) + full_mask = torch.cat( + ( + mask, + torch.ones( + (x.shape[0], indices.shape[1] - x.shape[1]), device=x.device, dtype=mask.dtype + ), + ), + dim=-1, + ) + # can't set value on leaf variable + out = masked_tokens.clone() + # put tokens in full masked tensor (at the first N positions in every row) + out[~full_mask.bool()] = x[~mask.bool()] + # then move them to their original positions + out = out.scatter(1, indices[:, :, None].expand_as(out), out) + full_mask = full_mask.scatter(1, indices.expand_as(full_mask), full_mask) + return out, full_mask + + def apply_attn( + self, + s_t_x, + sp_x, + t_x, + st_x, + s_t_m, + sp_m, + t_m, + st_m, + months, + patch_size, + input_res, + exit_after, + token_exit_cfg, + ): + if token_exit_cfg: + exit_s_t, exit_sp, exit_t, exit_st = self.create_token_exit_ids( + s_t_x, sp_x, t_x, st_x, token_exit_cfg + ) + exit_ids_seq, _ = self.collapse_and_combine_hwtc( + exit_s_t, exit_sp, exit_t, exit_st, s_t_m, sp_m, t_m, st_m + ) + # exited_tokens starts as linear projections! + exited_tokens, _ = self.collapse_and_combine_hwtc( + s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m + ) + else: + exit_ids_seq = None + exited_tokens = None + + _, h, w, t, s_t_c_g, _ = s_t_x.shape + sp_c_g, t_c_g, st_c_g = sp_x.shape[3], t_x.shape[-2], st_x.shape[-2] + s_t_x, sp_x, t_x, st_x = self.apply_encodings( + s_t_x, sp_x, t_x, st_x, months, patch_size, input_res + ) + x, m = self.collapse_and_combine_hwtc(s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m) + + # we only care about the values >= 1 for this mask, since 2 just tells the decoder + # to decode those tokens. From the perspective of the encoder, 1 and 2 are equivalent + # since they both represent masked values + new_m = m >= 1 + x, indices, new_m = self.remove_masked_tokens(x, new_m) # new_m is shape (bsz, seq_len) + + if exit_ids_seq is not None: + exit_ids_seq, _, _ = self.remove_masked_tokens(exit_ids_seq, m >= 1) + # still linear projections + exited_tokens, _, _ = self.remove_masked_tokens(exited_tokens, m >= 1) + + for i_blk, blk in enumerate(self.blocks): + if (exit_after is not None) and ((i_blk + 1) > exit_after): + # if exit_after is N, then we exit after the Nth layer + # if exit_after is 0, then all layers are skipped + break + + # skip the 0th block since this is just the linear + # projection + if (exit_ids_seq is not None) and (i_blk > 0): + assert exited_tokens is not None + # half depth + exited_tokens = torch.where( + condition=(exit_ids_seq == i_blk), + input=x.detach(), + other=exited_tokens.detach(), + ) + + # we take the inverse of the mask because a value + # of True indicates the value *should* take part in + # attention + x = blk(x=x, y=None, attn_mask=~new_m.bool()) + + if exit_ids_seq is not None: + assert exited_tokens is not None + # full depth + # IMPORTANT: write this to x + x = torch.where( + condition=(exit_ids_seq == (i_blk + 1)), # 2 for full depth + input=x.detach(), + other=exited_tokens.detach(), + ) + + # we don't care about the mask returned by add_removed_tokens, since we will + # just use the original, unclipped mask here + x, _ = self.add_removed_tokens(x, indices, new_m) + return ( + *self.split_and_expand_hwtc(x, h, w, t, s_t_c_g, sp_c_g, t_c_g, st_c_g), + s_t_m, + sp_m, + t_m, + st_m, + ) + + @classmethod + def average_tokens(cls, s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m): + x, m = cls.collapse_and_combine_hwtc(s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m) + x, _, m = cls.remove_masked_tokens(x, m) + x_for_mean = x * (1 - m.unsqueeze(-1)) + return x_for_mean.sum(dim=1) / torch.sum(1 - m, -1, keepdim=True) + + @classmethod + def apply_mask_and_average_tokens_per_patch( + cls, + s_t_x: torch.Tensor, + sp_x: torch.Tensor, + t_x: torch.Tensor, + st_x: torch.Tensor, + s_t_m: torch.Tensor, + sp_m: torch.Tensor, + t_m: torch.Tensor, + st_m: torch.Tensor, + ): + s_t_x = rearrange(s_t_x, "b t_h t_w t c_g d -> b (t_h t_w) (t c_g) d") + sp_x = rearrange(sp_x, "b t_h t_w c_g d -> b (t_h t_w) c_g d") + # repeat time tokens over space + t_x = repeat( + rearrange(t_x, "b t c_g d -> b (t c_g) d"), "b n d -> b s n d", s=sp_x.shape[1] + ) + st_x = repeat(st_x, "b c_g d -> b s c_g d", s=sp_x.shape[1]) + s_t_m = rearrange(s_t_m, "b t_h t_w t c_g-> b (t_h t_w) (t c_g)") + sp_m = rearrange(sp_m, "b t_h t_w c_g-> b (t_h t_w) c_g") + t_m = repeat(rearrange(t_m, "b t c_g -> b (t c_g)"), "b n -> b s n", s=sp_x.shape[1]) + st_m = repeat(st_m, "b c_g -> b s c_g", s=sp_x.shape[1]) + + x = torch.cat([s_t_x, sp_x, t_x, st_x], dim=2) # B, S, N, D + m = torch.cat([s_t_m, sp_m, t_m, st_m], dim=2) # B, S, N + + x_for_mean = x * (1 - m.unsqueeze(-1)) + + return x_for_mean.sum(dim=2) / torch.sum(1 - m, -1, keepdim=True) + + def create_token_exit_ids(self, s_t_x, sp_x, t_x, st_x, token_exit_cfg): + exit_s_t = torch.zeros_like(s_t_x) + exit_sp = torch.zeros_like(sp_x) + exit_t = torch.zeros_like(t_x) + exit_st = torch.zeros_like(st_x) + + for idx, (key, _) in enumerate(self.space_time_groups.items()): + exit_s_t[:, :, :, :, idx, :] = token_exit_cfg[key] + + for idx, (key, _) in enumerate(self.space_groups.items()): + exit_sp[:, :, :, idx, :] = token_exit_cfg[key] + + for idx, (key, _) in enumerate(self.time_groups.items()): + exit_t[:, :, idx, :] = token_exit_cfg[key] + + for idx, (key, _) in enumerate(self.static_groups.items()): + exit_st[:, idx, :] = token_exit_cfg[key] + return exit_s_t, exit_sp, exit_t, exit_st + + def forward( + self, + s_t_x: torch.Tensor, + sp_x: torch.Tensor, + t_x: torch.Tensor, + st_x: torch.Tensor, + s_t_m: torch.Tensor, + sp_m: torch.Tensor, + t_m: torch.Tensor, + st_m: torch.Tensor, + months: torch.Tensor, + patch_size: int, + input_resolution_m: Optional[int] = BASE_GSD, + exit_after: Optional[int] = None, + token_exit_cfg: Optional[Dict] = None, + add_layernorm_on_exit: bool = True, + ): + s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m = self.apply_linear_projection( + s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m, patch_size) + + if (exit_after is None) or (exit_after > 0): + s_t_x, sp_x, t_x, st_x, s_t_m, st_m, t_m, st_m = self.apply_attn( + s_t_x, + sp_x, + t_x, + st_x, + s_t_m, + sp_m, + t_m, + st_m, + months, + patch_size, + input_resolution_m, + exit_after=exit_after, + token_exit_cfg=token_exit_cfg, + ) + + if add_layernorm_on_exit: + s_t_x = self.norm(s_t_x) + sp_x = self.norm(sp_x) + t_x = self.norm(t_x) + st_x = self.norm(st_x) + return ( + s_t_x, + sp_x, + t_x, + st_x, + s_t_m, + sp_m, + t_m, + st_m, + months, + ) + + @classmethod + def load_from_folder(cls, folder: Path, device: torch.device): + if not (folder / CONFIG_FILENAME).exists(): + all_files_in_folder = [f.name for f in folder.glob("*")] + raise ValueError( + f"Expected {CONFIG_FILENAME} in {folder}, found {all_files_in_folder}" + ) + if not (folder / ENCODER_FILENAME).exists(): + all_files_in_folder = [f.name for f in folder.glob("*")] + raise ValueError( + f"Expected {ENCODER_FILENAME} in {folder}, found {all_files_in_folder}" + ) + + with (folder / CONFIG_FILENAME).open("r") as f: + config = json.load(f) + model_config = config["model"] + encoder_config = model_config["encoder"] + encoder = cls(**encoder_config) + + state_dict = torch.load(folder / ENCODER_FILENAME, map_location=device) + for key in list(state_dict.keys()): + # this cleans the state dict, which occasionally had an extra + # ".backbone" included in the key names + state_dict[key.replace(".backbone", "")] = state_dict.pop(key) + encoder.load_state_dict(state_dict) + return encoder + + +class GalileoWrapper(Encoder): + """ + Adapted from class Galileo Wrapper https://github.com/nasaharvest/galileo/blob/main/src/galileo.py#L1245 + """ + + # we assume any data passed to this wrapper + # will contain S2 data with the following channels + S2_BAND_ORDERING = [ + "B1", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8", + "B8A", + "B9", + "B10", + "B11", + "B12", + ] + S1_BAND_ORDERING = [ + "VV", + "VH", + ] + + def __init__( + self, + # new - from Pangaea Base Encoder + encoder_weights: str | Path, + model_name: str, + size: str, + input_bands: dict[str, list[str]], + data_bands: dict[str, list[str]], + input_size: int, + embed_dim, + output_layers: int | list[int], + output_dim: int | list[int], + download_url: str, + models_folder: str | Path, + token_exit_cfg, + + # old - from original GalileoWrapper implementation + # pretrained_path: Path, + # patch_size: int = 8, + month: int = 6, + do_pool: bool = True, + add_layernorm_on_exit: bool = True, + + **kwargs # now contains parameters for Galileo_Encoder + ): + models_path = os.path.join(models_folder, size) + if encoder_weights is None: + encoder_weights = os.path.join(models_path, ENCODER_FILENAME) + + input_bands = self.drop_encoder_modalities(data_bands, input_bands) + + self.galileo_models_path = models_path + self.models_folder = models_folder + self.size = size + super().__init__( + model_name=model_name, + encoder_weights=encoder_weights, + input_bands=input_bands, + input_size=input_size, + embed_dim=embed_dim, + output_dim=output_dim, + output_layers=output_layers, + multi_temporal=True, + multi_temporal_output=False, + pyramid_output=False, + download_url=download_url, + ) + self.encoder = Galileo_Encoder(**kwargs) # Encoder.load_from_folder(pretrained_path) + self.dim = self.encoder.embedding_size + self.patch_size = self.encoder.base_patch_size + self.data_gsd = kwargs["data_gsd"] if "data_gsd" in kwargs.keys() else BASE_GSD + self.grid_size: Optional[int] = None + self.do_pool = do_pool + self.month = month + self.kept_s2_band_idx = [i for i, v in enumerate(self.S2_BAND_ORDERING) if v in S2_BANDS] + self.kept_s1_band_idx = [i for i, v in enumerate(self.S1_BAND_ORDERING) if v in S1_BANDS] + kept_s2_band_names = [val for val in self.S2_BAND_ORDERING if val in S2_BANDS] + kept_s1_band_names = [val for val in self.S1_BAND_ORDERING if val in S1_BANDS] + self.to_galileo_s2_map = [SPACE_TIME_BANDS.index(val) for val in kept_s2_band_names] + self.to_galileo_s1_map = [SPACE_TIME_BANDS.index(val) for val in kept_s1_band_names] + self.s_t_channels_s2 = [ + idx for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX) if "S2" in key + ] + self.s_t_channels_s1 = [ + idx for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX) if "S1" in key + ] + self.add_layernorm_on_exit = add_layernorm_on_exit + self.token_exit_cfg = token_exit_cfg + + def drop_encoder_modalities(self, data_bands, encoder_bands): + encoder_modalities = list(encoder_bands.keys()) + for k in encoder_modalities: + if k not in data_bands.keys(): + del encoder_bands[k] + return encoder_bands + + def load_encoder_weights(self, logger: Logger) -> None: + self.encoder = Galileo_Encoder.load_from_folder(Path(self.galileo_models_path), + device=torch.device("cpu") + ) + + def preprocess( + self, + s2: Optional[torch.Tensor] = None, + s1: Optional[torch.Tensor] = None, + months: Optional[torch.Tensor] = None, + ): + """ + Converts Pangaea batches to Galileo batches. + + + batch contains only the image data from dataset + + Extended to (s1) or (s2) or (s1 and s2) inputs. + """ + + # images should have shape (b h w c) or (b h w t c) + if s1 is None: # s2 only + assert s2 is not None, "no s1 or s2?" + + s_t_channels = self.s_t_channels_s2 + data_dtype = s2.dtype + data_device = s2.device + if len(s2.shape) == 4: + b, c_s2, h, w = s2.shape # /!\ Pangaea shape + t = 1 + s2 = rearrange(s2, "b c h w -> b h w c") + else: + assert len(s2.shape) == 5 + b, c_s2, t, h, w = s2.shape + s2 = rearrange(s2, "b c t h w -> b h w t c") + assert c_s2 == len(self.S2_BAND_ORDERING) + + # add a single timestep + s_t_x = torch.zeros( + (b, h, w, t, len(SPACE_TIME_BANDS)), dtype=s2.dtype, device=s2.device + ) + if len(s2.shape) == 4: + s_t_x[:, :, :, 0, self.to_galileo_s2_map] = s2[:, :, :, self.kept_s2_band_idx] + else: + s_t_x[:, :, :, :, self.to_galileo_s2_map] = s2[:, :, :, :, self.kept_s2_band_idx] + + elif s2 is None: # s1 only + assert s1 is not None, "no s1 or s2?" + + s_t_channels = self.s_t_channels_s1 + data_dtype = s1.dtype + data_device = s1.device + if len(s1.shape) == 4: + b, c_s1, h, w = s1.shape + t = 1 + s1 = rearrange(s1, "b c h w -> b h w c") + else: + assert len(s1.shape) == 5 + b, c_s1, t, h, w = s1.shape + s1 = rearrange(s1, "b c t h w -> b h w t c") + assert c_s1 == len(self.S1_BAND_ORDERING) + + # add a single timestep + s_t_x = torch.zeros( + (b, h, w, t, len(SPACE_TIME_BANDS)), dtype=s1.dtype, device=s1.device + ) + if len(s1.shape) == 4: + s_t_x[:, :, :, 0, self.to_galileo_s1_map] = s1[:, :, :, self.kept_s1_band_idx] + else: + s_t_x[:, :, :, :, self.to_galileo_s1_map] = s1[:, :, :, :, self.kept_s1_band_idx] + + else: # s1 and s2 + s_t_channels = self.s_t_channels_s1 + self.s_t_channels_s2 + + assert s2.dtype == s1.dtype, f"Got different data types for s1 ({s1.dtype}) and s2 ({s2.dtype})" + + data_dtype = s2.dtype + data_device = s2.device + + if len(s2.shape) == 4: + b, c_s2, h, w = s2.shape # /!\ Pangaea shape + t = 1 + s2 = rearrange(s2, "b c h w -> b h w c") + else: + assert len(s2.shape) == 5 + b, c_s2, t, h, w = s2.shape + s2 = rearrange(s2, "b c t h w -> b h w t c") + assert c_s2 == len(self.S2_BAND_ORDERING) + + if len(s1.shape) == 4: + b, c_s1, h, w = s1.shape + t = 1 + s1 = rearrange(s1, "b c h w -> b h w c") + else: + assert len(s1.shape) == 5 + b, c_s1, t, h, w = s1.shape + s1 = rearrange(s1, "b c t h w -> b h w t c") + assert c_s1 == len(self.S1_BAND_ORDERING) + + # add a single timestep + s_t_x = torch.zeros( + (b, h, w, t, len(SPACE_TIME_BANDS)), dtype=s2.dtype, device=s2.device + ) + if len(s2.shape) == 4: + s_t_x[:, :, :, 0, self.to_galileo_s2_map] = s2[:, :, :, self.kept_s2_band_idx] + else: + s_t_x[:, :, :, :, self.to_galileo_s2_map] = s2[:, :, :, :, self.kept_s2_band_idx] + if len(s1.shape) == 4: + s_t_x[:, :, :, 0, self.to_galileo_s1_map] = s1[:, :, :, self.kept_s1_band_idx] + else: + s_t_x[:, :, :, :, self.to_galileo_s1_map] = s1[:, :, :, :, self.kept_s1_band_idx] + + + s_t_m = torch.ones( + (b, h, w, t, len(SPACE_TIME_BANDS_GROUPS_IDX)), + dtype=data_dtype, + device=data_device, + ) + s_t_m[:, :, :, :, s_t_channels] = 0 + + if months is None: + months = torch.ones((b, t), dtype=data_dtype, device=data_device) * self.month + else: + assert months.shape[-1] == t + + self.grid_size = int(s_t_x.shape[1] / self.patch_size) + + return ( + s_t_x, + torch.empty((b, h, w, len(SPACE_BANDS)), dtype=data_dtype, device=data_device), + torch.empty((b, t, len(TIME_BANDS)), dtype=data_dtype, device=data_device), + torch.empty((b, len(STATIC_BANDS)), dtype=data_dtype, device=data_device), + s_t_m, + torch.ones( + (b, h, w, len(SPACE_BAND_GROUPS_IDX)), dtype=data_dtype, device=data_device + ), + torch.ones((b, t, len(TIME_BAND_GROUPS_IDX)), dtype=data_dtype, device=data_device), + torch.ones((b, len(STATIC_BAND_GROUPS_IDX)), dtype=data_dtype, device=data_device), + months.long(), + ) + + def forward(self, x: dict[str, torch.Tensor] | torch.Tensor | None = None, months=None, **kwargs) -> list[torch.Tensor]: + """Foward pass of the encoder. + + Args: + x (dict[str, torch.Tensor]): encoder's input structured as a dictionary: + x = {modality1: tensor1, modality2: tensor2, ...}, e.g. x = {"optical": tensor1, "sar": tensor2}. + If the encoder is multi-temporal (self.multi_temporal==True), input tensor shape is (B C T H W) with C the + number of bands required by the encoder for the given modality and T the number of time steps. If the + encoder is not multi-temporal, input tensor shape is (B C H W) with C the number of bands required by the + encoder for the given modality. + Note: Galileo Encoder is multi_temporal. + + Returns: + list[torch.Tensor]: list of the embeddings for each modality. For single-temporal encoders, the list's + elements are of shape (B, embed_dim, H', W'). For multi-temporal encoders, the list's elements are of shape + (B, C', T, H', W') with T the number of time steps if the encoder does not have any time-merging strategy, + else (B, C', H', W') if the encoder has a time-merging strategy (where C'==self.output_dim). + """ + s1, s2 = None, None + if "optical" in x.keys(): + if x['optical'].any(): + s2 = x['optical'] + s_t_channels = self.s_t_channels_s2 + + if "sar" in x.keys(): + if x['sar'].any(): + s1 = x['sar'] + s_t_channels = self.s_t_channels_s1 + + months = months + + s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m, months = self.preprocess( + s2=s2, s1=s1, months=months + ) + + # Modified Encoder.forward for extracting intermediate features from the attention blocks + s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m = self.encoder.apply_linear_projection( + s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m, self.patch_size) + + exit_after = None + if (exit_after is None) or (exit_after > 0): + s_t_x, sp_x, t_x, st_x, s_t_m, st_m, t_m, st_m, outputs = self.apply_attn( + s_t_x, + sp_x, + t_x, + st_x, + s_t_m, + sp_m, + t_m, + st_m, + months, + self.patch_size, + input_res = self.data_gsd, + exit_after=exit_after, + token_exit_cfg=self.token_exit_cfg, + add_layernorm_on_exit=self.add_layernorm_on_exit + ) + + if self.do_pool: + return self.encoder.average_tokens(s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m) + else: + ### New code + # we will be assuming we only want s_t_x, and (for now) that we want all s_2 bands + # s_t_x has shape [b, h, w, t, c_g, d] + # and we want [b, h * w, d] + + # Select s_t_x features only (s_t_channels is s1 or s2 (or both) for now) + features = [rearrange(s_t_x_feat[0][:, :, :, :, s_t_channels, :].mean(dim=3), + "b h w c_g d -> b d h w c_g").mean(dim=-1) + for s_t_x_feat in outputs] + return features + + def apply_attn( + self, + s_t_x, + sp_x, + t_x, + st_x, + s_t_m, + sp_m, + t_m, + st_m, + months, + patch_size, + input_res, + exit_after, + token_exit_cfg, + add_layernorm_on_exit + ): + if token_exit_cfg: + exit_s_t, exit_sp, exit_t, exit_st = self.encoder.create_token_exit_ids( + s_t_x, sp_x, t_x, st_x, token_exit_cfg + ) + exit_ids_seq, _ = self.encoder.collapse_and_combine_hwtc( + exit_s_t, exit_sp, exit_t, exit_st, s_t_m, sp_m, t_m, st_m + ) + # exited_tokens starts as linear projections! + exited_tokens, _ = self.encoder.collapse_and_combine_hwtc( + s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m + ) + else: + exit_ids_seq = None + exited_tokens = None + + _, h, w, t, s_t_c_g, _ = s_t_x.shape + sp_c_g, t_c_g, st_c_g = sp_x.shape[3], t_x.shape[-2], st_x.shape[-2] + s_t_x, sp_x, t_x, st_x = self.encoder.apply_encodings( + s_t_x, sp_x, t_x, st_x, months, patch_size, input_res + ) + x, m = self.encoder.collapse_and_combine_hwtc(s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m) + + # we only care about the values >= 1 for this mask, since 2 just tells the decoder + # to decode those tokens. From the perspective of the encoder, 1 and 2 are equivalent + # since they both represent masked values + new_m = m >= 1 + x, indices, new_m = self.encoder.remove_masked_tokens(x, new_m) # new_m is shape (bsz, seq_len) + + if exit_ids_seq is not None: + exit_ids_seq, _, _ = self.encoder.remove_masked_tokens(exit_ids_seq, m >= 1) + # still linear projections + exited_tokens, _, _ = self.encoder.remove_masked_tokens(exited_tokens, m >= 1) + + outputs = [] + for i_blk, blk in enumerate(self.encoder.blocks): + if (exit_after is not None) and ((i_blk + 1) > exit_after): + # if exit_after is N, then we exit after the Nth layer + # if exit_after is 0, then all layers are skipped + break + + # skip the 0th block since this is just the linear + # projection + if (exit_ids_seq is not None) and (i_blk > 0): + assert exited_tokens is not None + # half depth + exited_tokens = torch.where( + condition=(exit_ids_seq == i_blk), + input=x.detach(), + other=exited_tokens.detach(), + ) + + # we take the inverse of the mask because a value + # of True indicates the value *should* take part in + # attention + x = blk(x=x, y=None, attn_mask=~new_m.bool()) + if i_blk in self.output_layers: + outputs.append(x) + + if exit_ids_seq is not None: + assert exited_tokens is not None + # full depth + # IMPORTANT: write this to x + x = torch.where( + condition=(exit_ids_seq == (i_blk + 1)), # 2 for full depth + input=x.detach(), + other=exited_tokens.detach(), + ) + # Replace last output for pyramid decoder? + outputs[-1] = x + + # we don't care about the mask returned by add_removed_tokens, since we will + # just use the original, unclipped mask here + x, _ = self.encoder.add_removed_tokens(x, indices, new_m) + outputs = [self.encoder.add_removed_tokens(out, indices, new_m)[0] for out in outputs] + outputs = [list(self.encoder.split_and_expand_hwtc(out, h, w, t, s_t_c_g, sp_c_g, t_c_g, st_c_g)) + for out in outputs] + + s_t_x, sp_x, t_x, st_x = self.encoder.split_and_expand_hwtc(x, h, w, t, s_t_c_g, sp_c_g, t_c_g, st_c_g) + + if add_layernorm_on_exit: # on last elem of output only + s_t_x = self.encoder.norm(s_t_x) + sp_x = self.encoder.norm(sp_x) + t_x = self.encoder.norm(t_x) + st_x = self.encoder.norm(st_x) + + for i, out in enumerate(outputs[-1]): + outputs[-1][i] = self.encoder.norm(out) + + return ( + s_t_x, + sp_x, + t_x, + st_x, + s_t_m, + sp_m, + t_m, + st_m, + outputs, + ) + + def download_model(self) -> None: + """Download the model and config if the weights are not already downloaded.""" + if self.download_url and not os.path.isfile(self.encoder_weights): + save_path = os.path.join(self.models_folder, self.size) + os.makedirs(save_path, exist_ok=True) + + files = ["/config.json", "/encoder.pt", "/decoder.pt", "/second_decoder.pt", "/target_encoder.pt"] + # Example url: https://huggingface.co/nasaharvest/galileo/resolve/main/models/base/config.json?download=true + + pbar = DownloadProgressBar(f"Downloading {save_path + f}") + + for f in files: + url = self.download_url + self.size + f + "?download=true" + try: + urllib.request.urlretrieve( + url, + save_path + f, + pbar, + ) + except urllib.error.HTTPError as e: + print( + "Error while downloading model: The server couldn't fulfill the request." + ) + print("Error code: ", e.code) + except urllib.error.URLError as e: + print("Error while downloading model: Failed to reach a server.") + print("Reason: ", e.reason) + +def galileo_base(**kwargs): + model = GalileoWrapper( + model_name="galileo_base", + max_patch_size=8, + embedding_size=768, + depth=12, + mlp_ratio=4, + num_heads=12, + max_sequence_length=24, + #freeze_projections=False, # in kwargs, from config.yaml + drop_path=0.1, + **kwargs # config from yaml + ) + return model + +def galileo_nano(**kwargs): + model = GalileoWrapper( + model_name="galileo_nano", + max_patch_size=8, + embedding_size=128, + depth=4, + mlp_ratio=4, + num_heads=8, + max_sequence_length=24, + #freeze_projections=False, # in kwargs, from config.yaml + drop_path=0.1, + **kwargs # config from yaml + ) + return model + + +def galileo_tiny(**kwargs): + model = GalileoWrapper( + model_name="galileo_tiny", + max_patch_size=8, + embedding_size=192, + depth=12, + mlp_ratio=4, + num_heads=3, + max_sequence_length=24, + #freeze_projections=False, # in kwargs, from config.yaml + drop_path=0.1, + **kwargs # config from yaml + ) + return model diff --git a/pangaea/engine/data_preprocessor.py b/pangaea/engine/data_preprocessor.py index c229fd9c..906d30de 100644 --- a/pangaea/engine/data_preprocessor.py +++ b/pangaea/engine/data_preprocessor.py @@ -941,6 +941,44 @@ def __init__( size, scale, ratio, interpolation, antialias, resize_target, **meta ) +class ModalityDrop(BasePreprocessor): + def __init__(self, **meta) -> None: + """Intialize the ModalityDrop. + + Drop **encoder** modality if not in dataset (same as BandFilter, but for the encoder). + + Args: + meta: statistics/info of the input data and target encoder + data_bands: bands of incoming data + encoder_bands: expected bands by encoder + """ + super().__init__() + + self.avail_modalities = [] + for k in meta["encoder_bands"].keys(): + if k in meta["data_bands"].keys(): + self.avail_modalities.append(k) + + if not self.avail_modalities: + raise ValueError("No common input modalities after ModalityDrop!") + + + def __call__( + self, data: dict[str, torch.Tensor | dict[str, torch.Tensor]] + ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: + """Drop **encoder** modality if not in dataset. + Args: + data (dict): input + """ + return data + + def update_meta(self, meta): + """Tracking the meta statistics/info for next processor.""" + meta["encoder_bands"] = {k: val + for k, val in meta["encoder_bands"].items() + if k in self.avail_modalities} + return meta + def _setup_size(size, error_msg): if isinstance(size, numbers.Number):