diff --git a/recipes/aero_cfd/scripts/train_ahmedml.py b/recipes/aero_cfd/scripts/train_ahmedml.py index 0fb236a2..b6f1759c 100644 --- a/recipes/aero_cfd/scripts/train_ahmedml.py +++ b/recipes/aero_cfd/scripts/train_ahmedml.py @@ -1,6 +1,7 @@ # Copyright © 2026 Emmi AI GmbH. All rights reserved. from aero_cfd.presets import AhmedMLPreset + from noether.core.distributed.utils import accelerator_to_device from noether.training.runners import HydraRunner diff --git a/recipes/aero_cfd/scripts/train_drivaerml.py b/recipes/aero_cfd/scripts/train_drivaerml.py index bf69999d..c3a051e5 100644 --- a/recipes/aero_cfd/scripts/train_drivaerml.py +++ b/recipes/aero_cfd/scripts/train_drivaerml.py @@ -1,6 +1,7 @@ # Copyright © 2026 Emmi AI GmbH. All rights reserved. from aero_cfd.presets import DrivAerMLPreset + from noether.core.distributed.utils import accelerator_to_device from noether.training.runners import HydraRunner diff --git a/recipes/aero_cfd/scripts/train_drivaernet.py b/recipes/aero_cfd/scripts/train_drivaernet.py index ab7f1c73..5bc458ed 100644 --- a/recipes/aero_cfd/scripts/train_drivaernet.py +++ b/recipes/aero_cfd/scripts/train_drivaernet.py @@ -1,6 +1,7 @@ # Copyright © 2026 Emmi AI GmbH. All rights reserved. from aero_cfd.presets import DrivAerNetPreset + from noether.core.distributed.utils import accelerator_to_device from noether.training.runners import HydraRunner diff --git a/recipes/aero_cfd/scripts/train_emmi_wing.py b/recipes/aero_cfd/scripts/train_emmi_wing.py index 921c1cec..0d03ec0f 100644 --- a/recipes/aero_cfd/scripts/train_emmi_wing.py +++ b/recipes/aero_cfd/scripts/train_emmi_wing.py @@ -5,6 +5,7 @@ from pathlib import Path from aero_cfd.presets import EmmiWingPreset + from noether.core.distributed.utils import accelerator_to_device from noether.data.datasets.cfd.emmi_wing.dataset_hf import EmmiWingHFDataset from noether.training.runners import HydraRunner diff --git a/recipes/aero_cfd/scripts/train_shapenet_car.py b/recipes/aero_cfd/scripts/train_shapenet_car.py index 980cfdf3..f7d58b27 100644 --- a/recipes/aero_cfd/scripts/train_shapenet_car.py +++ b/recipes/aero_cfd/scripts/train_shapenet_car.py @@ -1,6 +1,7 @@ # Copyright © 2026 Emmi AI GmbH. All rights reserved. from aero_cfd.presets import ShapeNetCarPreset + from noether.core.distributed.utils import accelerator_to_device from noether.training.runners import HydraRunner diff --git a/recipes/aero_cfd/showcase/utils/forces.py b/recipes/aero_cfd/showcase/utils/forces.py index 979c754a..47b4442f 100644 --- a/recipes/aero_cfd/showcase/utils/forces.py +++ b/recipes/aero_cfd/showcase/utils/forces.py @@ -8,7 +8,6 @@ from pathlib import Path import torch - from aero_cfd.utils.drag_lift import FlowConditions, compute_force_coefficients diff --git a/src/noether/core/schemas/models/__init__.py b/src/noether/core/schemas/models/__init__.py index d34eeb73..1d42a476 100644 --- a/src/noether/core/schemas/models/__init__.py +++ b/src/noether/core/schemas/models/__init__.py @@ -5,6 +5,7 @@ from .transformer import TransformerConfig from .transolver import TransolverConfig, TransolverPlusPlusConfig from .upt import UPTConfig +from .vit import ViTConfig __all__ = [ "ModelBaseConfig", @@ -13,4 +14,5 @@ "TransolverPlusPlusConfig", "TransformerConfig", "UPTConfig", + "ViTConfig", ] diff --git a/src/noether/core/schemas/models/vit.py b/src/noether/core/schemas/models/vit.py new file mode 100644 index 00000000..7b0f32b3 --- /dev/null +++ b/src/noether/core/schemas/models/vit.py @@ -0,0 +1,60 @@ +# Copyright © 2026 Emmi AI GmbH. All rights reserved. + +from pydantic import ConfigDict, Field, computed_field + +from noether.core.schemas.modules.blocks import TransformerBlockConfig + +from .base import ModelBaseConfig + + +class ViTConfig(ModelBaseConfig): + """Configuration for ViT model""" + + model_config = ConfigDict(extra="forbid") + + coord_dim: int = Field(..., ge=1) + """Coordinate dimensionality of the input grid (2 for 2D, 3 for 3D).""" + + out_channels: int = Field(..., ge=1) + """Number of output channels emitted per spatial cell.""" + + patch_size: int = Field(..., ge=2) + """Patch side length in cells. The grid resolution must be divisible by this value.""" + + hidden_dim: int = Field(192, ge=1) + """Token hidden dimension throughout the transformer stack.""" + + num_heads: int = Field(6, ge=1) + """Number of attention heads in each transformer block.""" + + depth: int = Field(10, ge=1) + """Number of stacked transformer blocks.""" + + mlp_ratio: int = Field(4, ge=1) + """FFN expansion factor inside each transformer block.""" + + use_conditioning: bool = True + """If True, enable AdaLN-Zero conditioning (forward requires ``cond``); if False, plain ViT (``cond`` must be ``None``).""" + + token_dropout: float = Field(0.0, ge=0.0, le=1.0) + """Per-patch token dropout probability used during training.""" + + attn_drop: float = Field(0.0, ge=0.0, le=1.0) + """Dropout probability inside attention.""" + + use_conv_output_head: bool = True + """If True, decode via a cascaded PixelShuffle conv head; if False, decode via a linear unpatchify.""" + + @computed_field # type: ignore[prop-decorator] + @property + def transformer_block_config(self) -> TransformerBlockConfig: + return TransformerBlockConfig( + hidden_dim=self.hidden_dim, + num_heads=self.num_heads, + mlp_expansion_factor=self.mlp_ratio, + attention_constructor="dot_product", + condition_dim=self.hidden_dim if self.use_conditioning else None, + use_rope=True, + dropout=self.attn_drop, + init_weights="xavier", + ) diff --git a/src/noether/core/types.py b/src/noether/core/types.py index 7dcafbc5..e396ca04 100644 --- a/src/noether/core/types.py +++ b/src/noether/core/types.py @@ -2,7 +2,9 @@ from typing import Literal -InitWeightsMode = Literal["truncnormal002", "torch", "truncnormal", "truncnormal002-identity", "torchs", "zeros"] +InitWeightsMode = Literal[ + "truncnormal002", "torch", "truncnormal", "truncnormal002-identity", "torchs", "zeros", "xavier" +] ActivationTypes = Literal["GELU", "TANH", "SIGMOID", "RELU", "LEAKY_RELU", "SOFTPLUS", "ELU", "SILU"] diff --git a/src/noether/modeling/functional/init.py b/src/noether/modeling/functional/init.py index 265b8e68..40f1a7ec 100644 --- a/src/noether/modeling/functional/init.py +++ b/src/noether/modeling/functional/init.py @@ -43,6 +43,18 @@ def init_trunc_normal_zero_bias(layer_module: nn.Module, std: float = 0.02) -> N nn.init.constant_(layer_module.bias, 0.0) +def init_xavier_uniform_zero_bias(layer_module: nn.Module) -> None: + """Initialize the weight tensor of a nn.Module instance using Xavier uniform with a zero bias vector. + + Args: + layer_module: An nn.Module instance, either a Linear or Conv layer. + """ + if isinstance(layer_module, ALL_LAYERS): + nn.init.xavier_uniform_(layer_module.weight) + if layer_module.bias is not None: + nn.init.constant_(layer_module.bias, 0.0) + + def apply_init_method( module: torch.nn.Module, proj_weight: torch.Tensor, @@ -61,5 +73,7 @@ def apply_init_method( elif init_method == "truncnormal002-identity": module.apply(init_trunc_normal_zero_bias) torch.nn.init.zeros_(proj_weight) + elif init_method == "xavier": + module.apply(init_xavier_uniform_zero_bias) else: raise NotImplementedError(f"Weight initialization method {init_method} not implemented for DotProductAttention") diff --git a/src/noether/modeling/models/__init__.py b/src/noether/modeling/models/__init__.py index 0536525f..8bdbaded 100644 --- a/src/noether/modeling/models/__init__.py +++ b/src/noether/modeling/models/__init__.py @@ -12,12 +12,14 @@ from .transformer import Transformer from .transolver import Transolver from .upt import UPT +from .vit import ViT __all__ = [ "AnchoredBranchedUPT", "Transformer", "Transolver", "UPT", + "ViT", "AeroABUPT", "AeroTransformer", "AeroTransformerConfig", diff --git a/src/noether/modeling/models/transformer.py b/src/noether/modeling/models/transformer.py index 0b5ccf07..a3c0de5a 100644 --- a/src/noether/modeling/models/transformer.py +++ b/src/noether/modeling/models/transformer.py @@ -34,17 +34,20 @@ def forward( self, x: torch.Tensor, attn_kwargs: dict[str, torch.Tensor], + condition: torch.Tensor | None = None, ) -> torch.Tensor: """Forward pass of the Transformer model. Args: x: Input tensor of shape (batch_size, seq_len, hidden_dim). attn_kwargs: Additional arguments for the attention mechanism. + condition: Optional conditioning vector of shape (batch_size, condition_dim) consumed + by each block's AdaLN-Zero modulation. ``None`` (default) for unconditioned models. Returns: torch.Tensor: Output tensor after processing through the Transformer model. """ for block in self.blocks: - x, _ = block(x, attn_kwargs=attn_kwargs) + x, _ = block(x, condition=condition, attn_kwargs=attn_kwargs) return x diff --git a/src/noether/modeling/models/vit.py b/src/noether/modeling/models/vit.py new file mode 100644 index 00000000..5c966e4c --- /dev/null +++ b/src/noether/modeling/models/vit.py @@ -0,0 +1,188 @@ +# Copyright © 2026 Emmi AI GmbH. All rights reserved. + +from __future__ import annotations + +import torch.nn.functional as F +from einops import rearrange +from torch import Tensor, nn + +from noether.core.schemas.models import TransformerConfig, ViTConfig +from noether.core.schemas.modules import ( + ContinuousSincosEmbeddingConfig, + RopeFrequencyConfig, +) +from noether.modeling.models.transformer import Transformer +from noether.modeling.modules.layers import ( + AvgPool2DPatchify, + ContinuousSincosEmbed, + ConvOutputHead, + FinalLayer, + MaskPatchify, + RopeFrequency, +) + + +class ViT(nn.Module): + """Vision Transformer for spatial regression on continuous-coordinate grids. + + Based on the ViT paper (https://arxiv.org/pdf/2010.11929) with several modifications, such as: + + - Continuous coordinate inputs with sincos positional embedding and RoPE (vs. learned 1D position embeddings). + - Optional AdaLN-Zero conditioning, à la DiT (https://arxiv.org/abs/2212.09748). + - RMSNorm and QK-norm in attention (vs. LayerNorm only). + """ + + def __init__(self, config: ViTConfig) -> None: + """ + Args: + config: Configuration for the ViT model. See + :class:`~noether.core.schemas.models.ViTConfig` for available options. + """ + super().__init__() + + self.coord_dim = config.coord_dim + self.out_channels = config.out_channels + self.patch_size = config.patch_size + self.hidden_dim = config.hidden_dim + self.num_heads = config.num_heads + self.token_dropout = config.token_dropout + self.use_conditioning = config.use_conditioning + + # patchify + self.pool_patch = AvgPool2DPatchify(patch_size=config.patch_size) + self.mask_patchify = MaskPatchify(patch_size=config.patch_size) + + # positional encoding + self.pos_embedding = ContinuousSincosEmbed( + config=ContinuousSincosEmbeddingConfig(hidden_dim=config.hidden_dim, input_dim=config.coord_dim), # type: ignore[call-arg] + ) + self.rope = RopeFrequency( + config=RopeFrequencyConfig( # type: ignore[call-arg] + input_dim=config.coord_dim, + hidden_dim=config.hidden_dim // config.num_heads, + ), + ) + + self.backbone = Transformer( + config=TransformerConfig( + name="vit_backbone", + hidden_dim=config.hidden_dim, + depth=config.depth, + transformer_block_config=config.transformer_block_config, + ) + ) + + # output heads + self.use_conv_output_head = config.use_conv_output_head + if config.use_conv_output_head: + self.final_layer = FinalLayer( + config.hidden_dim, 1, config.hidden_dim, use_modulation=config.use_conditioning + ) + self.conv_output_head: ConvOutputHead | None = ConvOutputHead( + config.hidden_dim, config.out_channels, config.patch_size + ) + else: + self.final_layer = FinalLayer( + config.hidden_dim, config.patch_size, config.out_channels, use_modulation=config.use_conditioning + ) + self.conv_output_head = None + + self.initialize_weights() + + def initialize_weights(self) -> None: + """Initialize backbone weights""" + if self.final_layer.adaLN_modulation is not None: + nn.init.constant_(self.final_layer.adaLN_modulation.weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation.bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + # Zero the last conv of the ConvOutputHead so decoding starts near identity. + if self.conv_output_head is not None: + last_stage = self.conv_output_head.stages[-1] + if not isinstance(last_stage, nn.Sequential): + raise ValueError("Expected last stage of ConvOutputHead to be nn.Sequential.") + for module in reversed(list(last_stage)): + if isinstance(module, nn.Conv2d): + nn.init.constant_(module.weight, 0) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + break + + def unpatchify(self, x: Tensor, grid_h: int, grid_w: int) -> Tensor: + """Linear unpatchify: ``(B, L, p²·C_out) → (B, H, W, C_out)``.""" + p = self.patch_size + c = self.out_channels + b, seq_len, patch_dim = x.shape + if seq_len != grid_h * grid_w: + raise ValueError(f"Sequence length {seq_len} doesn't match grid {grid_h}*{grid_w}.") + if patch_dim != p * p * c: + raise ValueError(f"Patch dim mismatch: expected {p * p * c}, got {patch_dim}.") + x = x.view(b, grid_h, grid_w, p, p, c) + return rearrange(x, "b h w p q c -> b (h p) (w q) c") + + def forward( + self, + x: Tensor | None, + coords: Tensor, + mask: Tensor | None = None, + cond: Tensor | None = None, + return_tokens: bool = False, + ) -> Tensor | tuple[Tensor, tuple[int, int]]: + """Run the standard ViT. + + Args: + x: Optional pre-computed patch embeddings of shape ``(B, L, hidden_dim)``. When + ``None``, tokens come purely from positional encoding. + coords: Per-cell coordinates of shape ``(B, H, W, coord_dim)``. + mask: Optional per-cell fluid mask of shape ``(B, H, W)``. + cond: AdaLN conditioning vector of shape ``(B, hidden_dim)``. Required when the ViT + was built with ``use_conditioning=True`` (the default); must be ``None`` otherwise. + return_tokens: If True, return raw post-FinalLayer tokens plus ``(grid_h, grid_w)`` + instead of the decoded spatial output. + + Returns: + Either ``(B, H, W, out_channels)`` or ``(tokens, (grid_h, grid_w))`` if + ``return_tokens``. + """ + if self.use_conditioning and cond is None: + raise ValueError("ViT was built with use_conditioning=True; `cond` is required.") + if not self.use_conditioning and cond is not None: + raise ValueError("ViT was built with use_conditioning=False; `cond` must be None.") + + # Patchify coords (used for both sincos pos embed and RoPE). + coords_patched = self.pool_patch(coords) # (B, gh, gw, coord_dim) + _, grid_h, grid_w, _ = coords_patched.shape + coords_flat = coords_patched.flatten(1, 2) # (B, L, coord_dim) + + rope_freqs = self.rope(coords_flat) + pos_encoded = self.pos_embedding(coords_flat) + tokens = pos_encoded + + if x is not None: + tokens = tokens + x + + patch_mask: Tensor | None = None + if mask is not None: + patch_mask = self.mask_patchify(mask) + tokens = tokens * patch_mask.unsqueeze(-1).float() + + condition = F.silu(cond) if cond is not None else None + + attn_kwargs = {"freqs": rope_freqs} + if patch_mask is not None: + # Patch mask is also attention mask; (B, L) -> (B, 1, 1, L) so SDPA broadcasts across heads and queries + attn_kwargs["attn_mask"] = patch_mask[:, None, None, :] + + tokens = self.backbone(tokens, attn_kwargs=attn_kwargs, condition=condition) + + tokens = self.final_layer(tokens, condition) + + if return_tokens: + return tokens, (grid_h, grid_w) + + if self.conv_output_head is not None: + decoded: Tensor = self.conv_output_head(tokens, grid_h, grid_w) + return decoded + + return self.unpatchify(tokens, grid_h, grid_w) diff --git a/src/noether/modeling/modules/layers/__init__.py b/src/noether/modeling/modules/layers/__init__.py index 0a20dcc4..41461ac9 100644 --- a/src/noether/modeling/modules/layers/__init__.py +++ b/src/noether/modeling/modules/layers/__init__.py @@ -7,3 +7,9 @@ from .rope_frequency import RopeFrequency from .scalar_conditioner import ScalarsConditioner from .transformer_batchnorm import TransformerBatchNorm +from .vit_layers import ( + AvgPool2DPatchify, + ConvOutputHead, + FinalLayer, + MaskPatchify, +) diff --git a/src/noether/modeling/modules/layers/linear_projection.py b/src/noether/modeling/modules/layers/linear_projection.py index b0c28332..32095943 100644 --- a/src/noether/modeling/modules/layers/linear_projection.py +++ b/src/noether/modeling/modules/layers/linear_projection.py @@ -4,7 +4,7 @@ from torch import nn from noether.core.schemas.modules.layers import LinearProjectionConfig -from noether.modeling.functional.init import init_trunc_normal_zero_bias +from noether.modeling.functional.init import init_trunc_normal_zero_bias, init_xavier_uniform_zero_bias class LinearProjection(nn.Module): @@ -55,6 +55,8 @@ def reset_parameters(self) -> None: pass elif self.init_weights in ["truncnormal", "truncnormal002"]: init_trunc_normal_zero_bias(self.project) + elif self.init_weights == "xavier": + init_xavier_uniform_zero_bias(self.project) elif self.init_weights == "zeros": assert isinstance(self.project, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)) nn.init.zeros_(self.project.weight) diff --git a/src/noether/modeling/modules/layers/vit_layers.py b/src/noether/modeling/modules/layers/vit_layers.py new file mode 100644 index 00000000..141c7fe9 --- /dev/null +++ b/src/noether/modeling/modules/layers/vit_layers.py @@ -0,0 +1,161 @@ +# Copyright © 2025 Emmi AI GmbH. All rights reserved. + +import torch.nn.functional as F +from einops import rearrange +from torch import Tensor, nn + +from noether.modeling.functional.modulation import modulate_scale_shift + + +class AvgPool2DPatchify(nn.Module): + """Tokenize a 2D grid by average-pooling each ``patch_size``×``patch_size`` patch.""" + + def __init__(self, patch_size: int = 16) -> None: + super().__init__() + self.patch_size = patch_size + self.patch = nn.AvgPool2d(kernel_size=patch_size, stride=patch_size) + + def forward(self, x: Tensor) -> Tensor: + """Pool spatial features into patches. + + Args: + x: Input grid with shape ``(B, H, W, C)``. + + Returns: + Pooled patch grid of shape ``(B, H // patch_size, W // patch_size, C)``. + """ + x = rearrange(x, "b h w c -> b c h w") + x = self.patch(x) + return rearrange(x, "b c h w -> b h w c") + + +class MaskPatchify(nn.Module): + """Downsample a boolean mask to patch resolution via max-pooling (``True`` = at least one valid cell).""" + + def __init__(self, patch_size: int) -> None: + super().__init__() + self.patch_size = patch_size + + def forward(self, mask: Tensor) -> Tensor: + """Downsample boolean mask to patch resolution. + + Args: + mask: Boolean mask of shape ``(B, H, W)``. + + Returns: + Flat boolean mask of shape ``(B, (H // patch_size) * (W // patch_size))``. + """ + pooled = F.max_pool2d(mask.float(), kernel_size=self.patch_size, stride=self.patch_size) + return pooled.flatten(1).bool() + + +class FinalLayer(nn.Module): + """Final unpatchify projection with optional AdaLN modulation conditioned on a global vector ``c``.""" + + def __init__( + self, + hidden_size: int, + patch_size: int, + out_channels: int, + use_modulation: bool = True, + ) -> None: + super().__init__() + self.norm_final = nn.RMSNorm(hidden_size, eps=1e-6, elementwise_affine=True) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation: nn.Linear | None = ( + nn.Linear(hidden_size, 2 * hidden_size, bias=True) if use_modulation else None + ) + + def forward(self, x: Tensor, c: Tensor | None = None) -> Tensor: + """Apply (optionally AdaLN-modulated) norm then linear projection. + + Args: + x: Tokens of shape ``(B, L, hidden_size)``. + c: Conditioning vector of shape ``(B, hidden_size)`` when ``use_modulation=True``; + must be ``None`` when ``use_modulation=False``. The caller is responsible for any + upstream activation (e.g. SiLU) — this layer applies the AdaLN linear directly. + + Returns: + Tensor of shape ``(B, L, patch_size**2 * out_channels)``. + """ + if self.adaLN_modulation is None: + if c is not None: + raise ValueError("FinalLayer was built with use_modulation=False; do not pass `c`.") + return self.linear(self.norm_final(x)) # type: ignore[no-any-return] + if c is None: + raise ValueError("FinalLayer was built with use_modulation=True; a conditioning vector `c` is required.") + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate_scale_shift(self.norm_final(x), scale=scale, shift=shift) + return self.linear(x) # type: ignore[no-any-return] + + +class ConvOutputHead(nn.Module): + """Conv output head decodes tokens to spatial output""" + + def __init__( + self, + hidden_dim: int, + out_channels: int, + patch_size: int, + mid_channels: int = 64, + ) -> None: + super().__init__() + if patch_size < 2 or (patch_size & (patch_size - 1)) != 0: + raise ValueError(f"ConvOutputHead requires patch_size to be a power of 2 >= 2, got {patch_size}") + self.patch_size = patch_size + self.out_channels = out_channels + + factors = self._factorize(patch_size) + self.stages = nn.ModuleList() + + for i, factor in enumerate(factors): + is_first = i == 0 + is_last = i == len(factors) - 1 + ch_in = hidden_dim if is_first else mid_channels + ch_out = out_channels if is_last else mid_channels + + layers: list[nn.Module] = [ + nn.Conv2d(ch_in, ch_in, 3, padding=1), + nn.SiLU(), + ] + if is_first and len(factors) > 1: + layers += [nn.Conv2d(ch_in, ch_in, 3, padding=1), nn.SiLU()] + layers += [ + nn.Conv2d(ch_in, ch_out * factor**2, 1), + nn.PixelShuffle(factor), + ] + self.stages.append(nn.Sequential(*layers)) + + @staticmethod + def _factorize(patch_size: int) -> list[int]: + factors: list[int] = [] + remaining = patch_size + while remaining % 4 == 0: + factors.append(4) + remaining //= 4 + if remaining == 2: + factors.append(2) + return factors + + def forward( + self, + x: Tensor, + grid_h: int, + grid_w: int, + ) -> Tensor: + """Decode tokens to spatial output via cascaded PixelShuffle stages. + + Args: + x: Flattened tokens of shape ``(B, grid_h * grid_w, hidden_dim)``. + grid_h: Patch grid height (``H // patch_size``). + grid_w: Patch grid width (``W // patch_size``). + + Returns: + Spatial tensor of shape ``(B, H, W, out_channels)`` after upsampling. + """ + b = x.shape[0] + x = x.view(b, grid_h, grid_w, -1) + x = rearrange(x, "b h w c -> b c h w") + for stage in self.stages: + x = stage(x) + return rearrange(x, "b c h w -> b h w c") diff --git a/src/noether/modeling/modules/mlp/upactdown_mlp.py b/src/noether/modeling/modules/mlp/upactdown_mlp.py index 88fd6d44..af45ebb9 100644 --- a/src/noether/modeling/modules/mlp/upactdown_mlp.py +++ b/src/noether/modeling/modules/mlp/upactdown_mlp.py @@ -4,7 +4,7 @@ from torch import nn from noether.core.schemas.modules.mlp import UpActDownMLPConfig -from noether.modeling.functional.init import init_trunc_normal_zero_bias +from noether.modeling.functional.init import init_trunc_normal_zero_bias, init_xavier_uniform_zero_bias from noether.modeling.modules.activations import Activation @@ -48,6 +48,8 @@ def reset_parameters(self) -> None: elif self.init_weights == "truncnormal002-identity": self.apply(init_trunc_normal_zero_bias) nn.init.zeros_(self.fc2.weight) + elif self.init_weights == "xavier": + self.apply(init_xavier_uniform_zero_bias) else: raise NotImplementedError( f"Initialization method {self.init_weights} not implemented. " diff --git a/tests/unit/noether/modeling/models/test_vit.py b/tests/unit/noether/modeling/models/test_vit.py new file mode 100644 index 00000000..2b13bb76 --- /dev/null +++ b/tests/unit/noether/modeling/models/test_vit.py @@ -0,0 +1,128 @@ +# Copyright © 2025 Emmi AI GmbH. All rights reserved. + +import pytest +import torch + +from noether.core.schemas.models import ViTConfig +from noether.modeling.models import ViT + +_HIDDEN_DIM = 32 + + +def _make_inputs(B: int = 2, H: int = 16, W: int = 16, coord_dim: int = 2): + return dict( + x=None, + coords=torch.randn(B, H, W, coord_dim), + mask=torch.ones(B, H, W, dtype=torch.bool), + cond=torch.randn(B, _HIDDEN_DIM), + ) + + +def _make_model(**overrides): + cfg = dict( + name="vit_test", + coord_dim=2, + out_channels=4, + patch_size=8, + hidden_dim=_HIDDEN_DIM, + num_heads=2, + depth=2, + ) + cfg.update(overrides) + return ViT(config=ViTConfig(**cfg)) + + +def test_forward_shape_conv_head(): + model = _make_model() + out = model(**_make_inputs()) + assert out.shape == (2, 16, 16, 4) + + +def test_forward_shape_linear_unpatchify(): + model = _make_model(use_conv_output_head=False) + out = model(**_make_inputs()) + assert out.shape == (2, 16, 16, 4) + + +def test_forward_without_mask(): + model = _make_model() + inputs = _make_inputs() + inputs["mask"] = None + out = model(**inputs) + assert out.shape == (2, 16, 16, 4) + + +def test_fully_solid_mask_does_not_nan(): + model = _make_model().eval() + inputs = _make_inputs() + inputs["mask"] = torch.zeros_like(inputs["mask"]) + out = model(**inputs) + assert not torch.isnan(out).any() + + +def test_return_tokens(): + """With use_conv_output_head=True, FinalLayer is configured with patch_size=1 and + out_channels=hidden_dim, so its output is (B, num_patches, hidden_dim).""" + model = _make_model() + tokens, (gh, gw) = model(**_make_inputs(), return_tokens=True) + assert (gh, gw) == (2, 2) # H=16, patch=8 => 2x2 patches + assert tokens.shape == (2, gh * gw, 32) + + +def test_cond_required(): + model = _make_model() + inputs = _make_inputs() + inputs["cond"] = None + with pytest.raises(ValueError, match="cond"): + model(**inputs) + + +def test_arbitrary_grid_shape(): + """The ViT works on rectangular grids as long as patch_size divides H and W.""" + model = _make_model(patch_size=4) + coords = torch.randn(1, 32, 16, 2) + out = model(x=None, coords=coords, cond=torch.randn(1, _HIDDEN_DIM)) + assert out.shape == (1, 32, 16, 4) + + +def test_unconditioned_vit(): + """With ``use_conditioning=False`` the ViT runs without ``cond`` and has no AdaLN machinery.""" + model = _make_model(use_conditioning=False) + # No per-block AdaLN modulation submodule. + for block in model.backbone.blocks: + assert block.modulation is None + # FinalLayer has no AdaLN modulation submodule either. + assert model.final_layer.adaLN_modulation is None + + inputs = _make_inputs() + inputs["cond"] = None + out = model(**inputs) + assert out.shape == (2, 16, 16, 4) + + # Passing cond when conditioning is disabled is an error. + inputs["cond"] = torch.randn(2, _HIDDEN_DIM) + with pytest.raises(ValueError, match="must be None"): + model(**inputs) + + +def test_unconditioned_vit_backward_pass(): + model = _make_model(use_conditioning=False) + inputs = _make_inputs() + inputs["cond"] = None + out = model(**inputs) + out.sum().backward() + grad = model.final_layer.linear.weight.grad + # final_layer.linear is zero-initialized; gradient still flows through it. + assert grad is not None + assert not torch.isnan(grad).any() + + +def test_backward_pass_grads_first_block_modulation(): + """Drops the earlier cond_embedder-grad check: cond_embedder is gone. Verify a representative + learnable parameter (the first block's AdaLN modulation Linear) gets a gradient instead.""" + model = _make_model() + out = model(**_make_inputs()) + out.sum().backward() + grad = model.backbone.blocks[0].modulation.project.weight.grad # type: ignore[union-attr] + assert grad is not None + assert not torch.isnan(grad).any()