Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions recipes/aero_cfd/scripts/train_ahmedml.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright © 2026 Emmi AI GmbH. All rights reserved.

from aero_cfd.presets import AhmedMLPreset

from noether.core.distributed.utils import accelerator_to_device
from noether.training.runners import HydraRunner

Expand Down
1 change: 1 addition & 0 deletions recipes/aero_cfd/scripts/train_drivaerml.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright © 2026 Emmi AI GmbH. All rights reserved.

from aero_cfd.presets import DrivAerMLPreset

from noether.core.distributed.utils import accelerator_to_device
from noether.training.runners import HydraRunner

Expand Down
1 change: 1 addition & 0 deletions recipes/aero_cfd/scripts/train_drivaernet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright © 2026 Emmi AI GmbH. All rights reserved.

from aero_cfd.presets import DrivAerNetPreset

from noether.core.distributed.utils import accelerator_to_device
from noether.training.runners import HydraRunner

Expand Down
1 change: 1 addition & 0 deletions recipes/aero_cfd/scripts/train_emmi_wing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path

from aero_cfd.presets import EmmiWingPreset

from noether.core.distributed.utils import accelerator_to_device
from noether.data.datasets.cfd.emmi_wing.dataset_hf import EmmiWingHFDataset
from noether.training.runners import HydraRunner
Expand Down
1 change: 1 addition & 0 deletions recipes/aero_cfd/scripts/train_shapenet_car.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright © 2026 Emmi AI GmbH. All rights reserved.

from aero_cfd.presets import ShapeNetCarPreset

from noether.core.distributed.utils import accelerator_to_device
from noether.training.runners import HydraRunner

Expand Down
1 change: 0 additions & 1 deletion recipes/aero_cfd/showcase/utils/forces.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from pathlib import Path

import torch

from aero_cfd.utils.drag_lift import FlowConditions, compute_force_coefficients


Expand Down
2 changes: 2 additions & 0 deletions src/noether/core/schemas/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .transformer import TransformerConfig
from .transolver import TransolverConfig, TransolverPlusPlusConfig
from .upt import UPTConfig
from .vit import ViTConfig

__all__ = [
"ModelBaseConfig",
Expand All @@ -13,4 +14,5 @@
"TransolverPlusPlusConfig",
"TransformerConfig",
"UPTConfig",
"ViTConfig",
]
60 changes: 60 additions & 0 deletions src/noether/core/schemas/models/vit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright © 2026 Emmi AI GmbH. All rights reserved.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we move this one too src/noether/modeling/models/vit.py directly?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for the other models its still in this place. i think if we do this we should move all of the model schemas together in a separate PR maybe?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also think it should be done at once

from pydantic import ConfigDict, Field, computed_field

from noether.core.schemas.modules.blocks import TransformerBlockConfig

from .base import ModelBaseConfig


class ViTConfig(ModelBaseConfig):
"""Configuration for ViT model"""

model_config = ConfigDict(extra="forbid")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the transformer block config not used?

coord_dim: int = Field(..., ge=1)
"""Coordinate dimensionality of the input grid (2 for 2D, 3 for 3D)."""

out_channels: int = Field(..., ge=1)
"""Number of output channels emitted per spatial cell."""

patch_size: int = Field(..., ge=2)
"""Patch side length in cells. The grid resolution must be divisible by this value."""

hidden_dim: int = Field(192, ge=1)
"""Token hidden dimension throughout the transformer stack."""

num_heads: int = Field(6, ge=1)
"""Number of attention heads in each transformer block."""

depth: int = Field(10, ge=1)
"""Number of stacked transformer blocks."""

mlp_ratio: int = Field(4, ge=1)
"""FFN expansion factor inside each transformer block."""

use_conditioning: bool = True
"""If True, enable AdaLN-Zero conditioning (forward requires ``cond``); if False, plain ViT (``cond`` must be ``None``)."""

token_dropout: float = Field(0.0, ge=0.0, le=1.0)
"""Per-patch token dropout probability used during training."""

attn_drop: float = Field(0.0, ge=0.0, le=1.0)
"""Dropout probability inside attention."""

use_conv_output_head: bool = True
"""If True, decode via a cascaded PixelShuffle conv head; if False, decode via a linear unpatchify."""

@computed_field # type: ignore[prop-decorator]
@property
def transformer_block_config(self) -> TransformerBlockConfig:
return TransformerBlockConfig(
hidden_dim=self.hidden_dim,
num_heads=self.num_heads,
mlp_expansion_factor=self.mlp_ratio,
attention_constructor="dot_product",
condition_dim=self.hidden_dim if self.use_conditioning else None,
use_rope=True,
dropout=self.attn_drop,
init_weights="xavier",
)
4 changes: 3 additions & 1 deletion src/noether/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from typing import Literal

InitWeightsMode = Literal["truncnormal002", "torch", "truncnormal", "truncnormal002-identity", "torchs", "zeros"]
InitWeightsMode = Literal[
"truncnormal002", "torch", "truncnormal", "truncnormal002-identity", "torchs", "zeros", "xavier"
]

ActivationTypes = Literal["GELU", "TANH", "SIGMOID", "RELU", "LEAKY_RELU", "SOFTPLUS", "ELU", "SILU"]

Expand Down
14 changes: 14 additions & 0 deletions src/noether/modeling/functional/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,18 @@ def init_trunc_normal_zero_bias(layer_module: nn.Module, std: float = 0.02) -> N
nn.init.constant_(layer_module.bias, 0.0)


def init_xavier_uniform_zero_bias(layer_module: nn.Module) -> None:
"""Initialize the weight tensor of a nn.Module instance using Xavier uniform with a zero bias vector.

Args:
layer_module: An nn.Module instance, either a Linear or Conv layer.
"""
if isinstance(layer_module, ALL_LAYERS):
nn.init.xavier_uniform_(layer_module.weight)
if layer_module.bias is not None:
nn.init.constant_(layer_module.bias, 0.0)


def apply_init_method(
module: torch.nn.Module,
proj_weight: torch.Tensor,
Expand All @@ -61,5 +73,7 @@ def apply_init_method(
elif init_method == "truncnormal002-identity":
module.apply(init_trunc_normal_zero_bias)
torch.nn.init.zeros_(proj_weight)
elif init_method == "xavier":
module.apply(init_xavier_uniform_zero_bias)
else:
raise NotImplementedError(f"Weight initialization method {init_method} not implemented for DotProductAttention")
2 changes: 2 additions & 0 deletions src/noether/modeling/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
from .transformer import Transformer
from .transolver import Transolver
from .upt import UPT
from .vit import ViT

__all__ = [
"AnchoredBranchedUPT",
"Transformer",
"Transolver",
"UPT",
"ViT",
"AeroABUPT",
"AeroTransformer",
"AeroTransformerConfig",
Expand Down
5 changes: 4 additions & 1 deletion src/noether/modeling/models/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,20 @@ def forward(
self,
x: torch.Tensor,
attn_kwargs: dict[str, torch.Tensor],
condition: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass of the Transformer model.

Args:
x: Input tensor of shape (batch_size, seq_len, hidden_dim).
attn_kwargs: Additional arguments for the attention mechanism.
condition: Optional conditioning vector of shape (batch_size, condition_dim) consumed
by each block's AdaLN-Zero modulation. ``None`` (default) for unconditioned models.
Returns:
torch.Tensor: Output tensor after processing through the Transformer model.
"""

for block in self.blocks:
x, _ = block(x, attn_kwargs=attn_kwargs)
x, _ = block(x, condition=condition, attn_kwargs=attn_kwargs)

return x
188 changes: 188 additions & 0 deletions src/noether/modeling/models/vit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
# Copyright © 2026 Emmi AI GmbH. All rights reserved.

from __future__ import annotations

import torch.nn.functional as F
from einops import rearrange
from torch import Tensor, nn

from noether.core.schemas.models import TransformerConfig, ViTConfig
from noether.core.schemas.modules import (
ContinuousSincosEmbeddingConfig,
RopeFrequencyConfig,
)
from noether.modeling.models.transformer import Transformer
from noether.modeling.modules.layers import (
AvgPool2DPatchify,
ContinuousSincosEmbed,
ConvOutputHead,
FinalLayer,
MaskPatchify,
RopeFrequency,
)


class ViT(nn.Module):
"""Vision Transformer for spatial regression on continuous-coordinate grids.

Based on the ViT paper (https://arxiv.org/pdf/2010.11929) with several modifications, such as:

- Continuous coordinate inputs with sincos positional embedding and RoPE (vs. learned 1D position embeddings).
- Optional AdaLN-Zero conditioning, à la DiT (https://arxiv.org/abs/2212.09748).
- RMSNorm and QK-norm in attention (vs. LayerNorm only).
"""

def __init__(self, config: ViTConfig) -> None:
"""
Comment thread
kinggongzilla marked this conversation as resolved.
Args:
config: Configuration for the ViT model. See
:class:`~noether.core.schemas.models.ViTConfig` for available options.
"""
super().__init__()

self.coord_dim = config.coord_dim
self.out_channels = config.out_channels
self.patch_size = config.patch_size
self.hidden_dim = config.hidden_dim
self.num_heads = config.num_heads
self.token_dropout = config.token_dropout
self.use_conditioning = config.use_conditioning

# patchify
self.pool_patch = AvgPool2DPatchify(patch_size=config.patch_size)
self.mask_patchify = MaskPatchify(patch_size=config.patch_size)

# positional encoding
self.pos_embedding = ContinuousSincosEmbed(
config=ContinuousSincosEmbeddingConfig(hidden_dim=config.hidden_dim, input_dim=config.coord_dim), # type: ignore[call-arg]
)
self.rope = RopeFrequency(
config=RopeFrequencyConfig( # type: ignore[call-arg]
input_dim=config.coord_dim,
hidden_dim=config.hidden_dim // config.num_heads,
),
)

self.backbone = Transformer(
config=TransformerConfig(
name="vit_backbone",
hidden_dim=config.hidden_dim,
depth=config.depth,
transformer_block_config=config.transformer_block_config,
)
)

# output heads
self.use_conv_output_head = config.use_conv_output_head
if config.use_conv_output_head:
self.final_layer = FinalLayer(
config.hidden_dim, 1, config.hidden_dim, use_modulation=config.use_conditioning
)
self.conv_output_head: ConvOutputHead | None = ConvOutputHead(
config.hidden_dim, config.out_channels, config.patch_size
)
else:
self.final_layer = FinalLayer(
config.hidden_dim, config.patch_size, config.out_channels, use_modulation=config.use_conditioning
)
self.conv_output_head = None

self.initialize_weights()

def initialize_weights(self) -> None:
"""Initialize backbone weights"""
if self.final_layer.adaLN_modulation is not None:
nn.init.constant_(self.final_layer.adaLN_modulation.weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation.bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)

# Zero the last conv of the ConvOutputHead so decoding starts near identity.
if self.conv_output_head is not None:
last_stage = self.conv_output_head.stages[-1]
if not isinstance(last_stage, nn.Sequential):
raise ValueError("Expected last stage of ConvOutputHead to be nn.Sequential.")
for module in reversed(list(last_stage)):
if isinstance(module, nn.Conv2d):
nn.init.constant_(module.weight, 0)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
break

def unpatchify(self, x: Tensor, grid_h: int, grid_w: int) -> Tensor:
"""Linear unpatchify: ``(B, L, p²·C_out) → (B, H, W, C_out)``."""
p = self.patch_size
c = self.out_channels
b, seq_len, patch_dim = x.shape
if seq_len != grid_h * grid_w:
raise ValueError(f"Sequence length {seq_len} doesn't match grid {grid_h}*{grid_w}.")
if patch_dim != p * p * c:
raise ValueError(f"Patch dim mismatch: expected {p * p * c}, got {patch_dim}.")
x = x.view(b, grid_h, grid_w, p, p, c)
return rearrange(x, "b h w p q c -> b (h p) (w q) c")

def forward(
self,
x: Tensor | None,
coords: Tensor,
mask: Tensor | None = None,
cond: Tensor | None = None,
return_tokens: bool = False,
) -> Tensor | tuple[Tensor, tuple[int, int]]:
"""Run the standard ViT.

Args:
x: Optional pre-computed patch embeddings of shape ``(B, L, hidden_dim)``. When
``None``, tokens come purely from positional encoding.
coords: Per-cell coordinates of shape ``(B, H, W, coord_dim)``.
mask: Optional per-cell fluid mask of shape ``(B, H, W)``.
cond: AdaLN conditioning vector of shape ``(B, hidden_dim)``. Required when the ViT
was built with ``use_conditioning=True`` (the default); must be ``None`` otherwise.
return_tokens: If True, return raw post-FinalLayer tokens plus ``(grid_h, grid_w)``
instead of the decoded spatial output.

Returns:
Either ``(B, H, W, out_channels)`` or ``(tokens, (grid_h, grid_w))`` if
``return_tokens``.
"""
if self.use_conditioning and cond is None:
raise ValueError("ViT was built with use_conditioning=True; `cond` is required.")
if not self.use_conditioning and cond is not None:
raise ValueError("ViT was built with use_conditioning=False; `cond` must be None.")

# Patchify coords (used for both sincos pos embed and RoPE).
coords_patched = self.pool_patch(coords) # (B, gh, gw, coord_dim)
_, grid_h, grid_w, _ = coords_patched.shape
coords_flat = coords_patched.flatten(1, 2) # (B, L, coord_dim)

rope_freqs = self.rope(coords_flat)
pos_encoded = self.pos_embedding(coords_flat)
tokens = pos_encoded

if x is not None:
tokens = tokens + x

patch_mask: Tensor | None = None
if mask is not None:
patch_mask = self.mask_patchify(mask)
tokens = tokens * patch_mask.unsqueeze(-1).float()

condition = F.silu(cond) if cond is not None else None

attn_kwargs = {"freqs": rope_freqs}
if patch_mask is not None:
# Patch mask is also attention mask; (B, L) -> (B, 1, 1, L) so SDPA broadcasts across heads and queries
attn_kwargs["attn_mask"] = patch_mask[:, None, None, :]

tokens = self.backbone(tokens, attn_kwargs=attn_kwargs, condition=condition)

tokens = self.final_layer(tokens, condition)

if return_tokens:
return tokens, (grid_h, grid_w)

if self.conv_output_head is not None:
decoded: Tensor = self.conv_output_head(tokens, grid_h, grid_w)
return decoded

return self.unpatchify(tokens, grid_h, grid_w)
6 changes: 6 additions & 0 deletions src/noether/modeling/modules/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,9 @@
from .rope_frequency import RopeFrequency
from .scalar_conditioner import ScalarsConditioner
from .transformer_batchnorm import TransformerBatchNorm
from .vit_layers import (
AvgPool2DPatchify,
ConvOutputHead,
FinalLayer,
MaskPatchify,
)
Loading
Loading