@@ -27,6 +27,10 @@ What this is good for
Despite the binary name, treat ``noether-eval`` as the generic
inference/evaluation entry point — it has no eval-only logic baked in.

For **interactive** work (notebooks, prototyping, debugging) where you don't
need callbacks, logging, or reproducibility, see :ref:`loading-a-run-in-python`
below — that path skips Hydra and the trainer entirely.

Quick start
-----------

@@ -263,6 +267,73 @@ itself, or with ``PYTHONPATH`` set:
Each run directory also contains a ``code.tar.gz`` snapshot of the codebase at
training time, useful when the source tree has drifted.
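For example, unpacking the snapshot for inspection; a minimal sketch (the
destination path is arbitrary):

.. code-block:: python

    import tarfile

    # Extract the training-time code snapshot next to your scratch space.
    with tarfile.open("/path/to/outputs/2026-01-10_abc12/code.tar.gz") as tf:
        tf.extractall("/tmp/run_code_snapshot")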

.. _loading-a-run-in-python:

Loading a run in Python (notebooks, prototyping)
------------------------------------------------

When you want to poke at a trained model in a notebook — inspect predictions
on a single sample, prototype a new visualization, debug a head-scratcher —
``noether-eval`` is overkill: it stands up Hydra, the trainer, the tracker,
and the callback loop just to give you a model and a dataset.

The :mod:`noether.inference` package exposes a single :class:`Run` class
for that case. Construction reads the run's ``hp_resolved.yaml`` and
validates it; the model and dataset are built on demand — no Hydra, no
trainer:

.. code-block:: python

    from noether.inference import Run

    run = Run("/path/to/outputs/2026-01-10_abc12")

    # Optional: patch the config before building artifacts —
    # typically to point dataset paths at this machine's data.
    for ds_cfg in run.config.datasets.values():
        ds_cfg.root = "/local/path/to/data"

    dataset = run.dataset("test")
    model = run.model(checkpoint="latest", device="cuda")
    # checkpoint examples: "latest", "best_model.<metric>", "E10", "latest_ema=0.9999"

``Run`` exposes three lazy methods, all independent — you don't have to
call them in order, and you don't have to call them all. Pick whichever
fits your use case (a short sketch follows the list):

- ``run.model(...)`` — the trained model with checkpoint weights loaded.
Only needs the run dir; works on **any** tensor dict you can construct.
- ``run.normalizers(split)`` — the field normalizers (e.g. for converting
model predictions back to physical units). Built without instantiating
the dataset; the data files do not need to be present.
- ``run.dataset(split)`` — the dataset, with the same collator the trainer
wired. **This** is the one that needs the original data files on disk.
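For instance, the normalizers can be rebuilt on a machine that holds only the
run directory, with no data files present. A minimal sketch:

.. code-block:: python

    from noether.inference import Run

    run = Run("/path/to/outputs/2026-01-10_abc12")
    # No dataset is instantiated and no data files are read here.
    norms = run.normalizers("test")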

That separation matters in particular for the **bring-your-own-data** flow
— applying a trained model to a CAD mesh, a custom point cloud, or any
data that isn't packaged as a noether ``Dataset``:

.. code-block:: python

    import torch

    from noether.inference import Run

    run = Run("/path/to/outputs/2026-01-10_abc12")
    model = run.model(device="cuda")
    norms = run.normalizers()

    # You build the input dict yourself, matching the model's forward signature.
    with torch.inference_mode():
        pred = model(**my_inputs)

    # Same normalizers the training data used — denormalize the prediction.
    pressure_phys = norms["surface_pressure"].inverse(pred["surface_pressure"])

This is **not** a substitute for ``noether-eval``: there are no metrics,
no callbacks, no run output directory, and no reproducibility guarantees.
Use it for interactive work; use ``noether-eval`` for everything else.

A worked example — load a trained AB-UPT / DrivAerML run, do both the
"standard" and the "bring your own data" flow, and plot predictions vs.
ground truth — lives at ``notebooks/ab_upt_drivaerml_inference.ipynb``.

A note on ``--help`` and the binary name
----------------------------------------

13 changes: 13 additions & 0 deletions notebooks/README.md
@@ -0,0 +1,13 @@
# Notebooks

Interactive notebooks for exploring noether models and datasets.

These are **not** part of CI — they require trained run directories and recipe-specific datasets on disk.
They are intended for prototyping, debugging, and visualization, not reproducible eval. For reproducible eval runs,
use `noether-eval`.

## Index

- [`ab_upt_drivaerml_inference.ipynb`](ab_upt_drivaerml_inference.ipynb) —
Load a trained AB-UPT / DrivAerML model and its dataset interactively, run a forward pass, and visualize predictions
in physical units.
584 changes: 584 additions & 0 deletions notebooks/ab_upt_drivaerml_inference.ipynb

4 changes: 4 additions & 0 deletions src/noether/core/types.py
@@ -28,3 +28,7 @@ class CheckpointKeys:
    """ The state dicts of the callbacks. """
    GRAD_SCALER = "grad_scaler"
    """ The state dict of the grad scaler (if used). """
    NORMALIZER_CONFIGS = "normalizer_configs"
    """ Per-field preprocessor configs for normalization (serialized dict of the test split's ``dataset_normalizers``). """
    NORMALIZER_STATISTICS = "normalizer_statistics"
    """ Resolved statistics dict (means/stds/bounds) loaded from the dataset class's ``STATS_FILE`` at write time. """
49 changes: 49 additions & 0 deletions src/noether/core/writers/checkpoint_writer.py
@@ -14,12 +14,44 @@
from noether.core.schemas.models import ModelBaseConfig
from noether.core.types import CheckpointKeys
from noether.core.utils.training import UpdateCounter
from noether.data.container import DataContainer

if TYPE_CHECKING:
    from noether.core.models import ModelBase
    from noether.training.trainers import BaseTrainer


def _build_normalizer_payload(data_container: DataContainer | None) -> tuple[dict | None, dict | None]:
    """Pull per-field normalizer configs and resolved statistics from the trainer's
    :class:`~noether.data.container.DataContainer`.

    Sources the test split's ``dataset_normalizers`` (per-field preprocessor specs) off the dataset's own config
    (``Dataset.config``) and the resolved statistics from :meth:`Dataset.fetch_statistics`, then embeds both
    in the checkpoint so that :func:`~noether.inference.load_normalizers_from_checkpoint` can reconstruct field
    normalizers without ``hp_resolved.yaml`` or the recipe's stats file on the loading machine.

    Falls back to the first available split if ``test`` isn't configured.
    Returns ``(None, None)`` if no DataContainer, no datasets, or no ``dataset_normalizers`` entry.
    """
    if not isinstance(data_container, DataContainer) or not data_container.datasets:
        return None, None
    split = "test" if "test" in data_container.datasets else next(iter(data_container.datasets))
    dataset = data_container.datasets[split]
    normalizer_configs = getattr(getattr(dataset, "config", None), "dataset_normalizers", None)
    if not normalizer_configs:
        return None, None

    configs_dump: dict[str, Any] = {}
    for key, val in normalizer_configs.items():
        if isinstance(val, list):
            configs_dump[key] = [c.model_dump() for c in val]
        else:
            configs_dump[key] = val.model_dump()

    statistics = dataset.fetch_statistics() if hasattr(dataset, "fetch_statistics") else None
    return configs_dump, statistics


class CheckpointWriter:
    """Class to easily write checkpoints in a structured way to the disk.

@@ -65,6 +97,8 @@ def save_model_checkpoint(
        state_dict: dict[str, Any],
        model_config: ModelBaseConfig | None = None,
        model_info: str | None = None,
        normalizer_configs: dict[str, Any] | None = None,
        normalizer_statistics: dict[str, Any] | None = None,
        **extra,
    ) -> None:
        """Save a checkpoint to disk.
@@ -100,6 +134,11 @@
            raise RuntimeError(f"An unexpected error occurred during model_dump: {e}") from e
        output_dict[CheckpointKeys.CONFIG_KIND] = model_config.config_kind

        if normalizer_configs is not None:
            output_dict[CheckpointKeys.NORMALIZER_CONFIGS] = normalizer_configs
        if normalizer_statistics is not None:
            output_dict[CheckpointKeys.NORMALIZER_STATISTICS] = normalizer_statistics

        # Construct model URI with optional model_info; follows structure: {model_name}_{model_info}_cp={checkpoint}_model.th
        model_info = f"_{model_info}" if model_info else ""
        model_uri = self.path_provider.checkpoint_path / f"{model_name}{model_info}_cp={checkpoint_tag}_model.th"
@@ -139,6 +178,8 @@ def save(
        # NOTE: this has to be called from all ranks because random states are gathered to rank0
        trainer_sd = trainer.state_dict() if trainer is not None else None

        normalizer_configs, normalizer_statistics = _build_normalizer_payload(getattr(trainer, "data_container", None))

        if is_rank0():
            self._save_separate_models(
                model=model,
@@ -150,6 +191,8 @@
                model_names_to_save=model_names_to_save,
                save_frozen_weights=save_frozen_weights,
                model_info=model_info,
                normalizer_configs=normalizer_configs,
                normalizer_statistics=normalizer_statistics,
            )

        if trainer_sd is not None:
@@ -176,6 +219,8 @@ def _save_separate_models(
        save_frozen_weights: bool,
        model_info: str | None = None,
        model_name: str | None = None,
        normalizer_configs: dict[str, Any] | None = None,
        normalizer_statistics: dict[str, Any] | None = None,
    ):
        if isinstance(model, DistributedDataParallel):
            raise RuntimeError("DistributedDataParallel models should be unwrapped before saving.")
@@ -203,6 +248,8 @@
                model_info=model_info,
                state_dict=model.state_dict(),
                model_config=getattr(model, "model_config", None),
                normalizer_configs=normalizer_configs,
                normalizer_statistics=normalizer_statistics,
            )

            # --- Save Optimizer ---
@@ -234,6 +281,8 @@
                save_latest_optim=save_latest_optim,
                model_names_to_save=model_names_to_save,
                save_frozen_weights=save_frozen_weights,
                normalizer_configs=normalizer_configs,
                normalizer_statistics=normalizer_statistics,
            )
        else:
            raise NotImplementedError
4 changes: 4 additions & 0 deletions src/noether/inference/__init__.py
@@ -1 +1,5 @@
# Copyright © 2025 Emmi AI GmbH. All rights reserved.

from noether.inference.run import Run, load_model_from_checkpoint, load_normalizers_from_checkpoint

__all__ = ["Run", "load_model_from_checkpoint", "load_normalizers_from_checkpoint"]
33 changes: 2 additions & 31 deletions src/noether/inference/cli/main_inference.py
@@ -3,13 +3,13 @@
import logging
import os
import sys
import tempfile
from pathlib import Path

import hydra
import yaml
from omegaconf import DictConfig, OmegaConf

from noether.inference.run import sanitize_hp_resolved
from noether.inference.runners.inference_runner import InferenceRunner
from noether.training.cli import setup_hydra

@@ -19,35 +19,6 @@
_LEGACY_NAV_KEYS = ("input_dir", "run_id", "stage_name")


def _to_plain_python(obj):
    """Recursively convert tuples/sets to lists so the result round-trips through ``yaml.safe_dump`` (and therefore
    Hydra's safe loader) without ``!!python/...`` tags."""
    if isinstance(obj, dict):
        return {k: _to_plain_python(v) for k, v in obj.items()}
    if isinstance(obj, (tuple, set, frozenset)):
        return [_to_plain_python(v) for v in obj]
    if isinstance(obj, list):
        return [_to_plain_python(v) for v in obj]
    return obj


def _sanitize_hp_resolved_for_hydra(hp_resolved_path: Path) -> Path:
    """Rewrite ``hp_resolved.yaml`` to a temp file with no Python-specific tags.

    Resolved configs are dumped via ``yaml.dump``, which emits ``!!python/tuple`` for tuple values (notably
    ``dataset_statistics``). Hydra loads configs with a safe loader, so we re-serialize using ``yaml.safe_dump``
    over a list-converted dict before handing the path to Hydra.
    """
    with open(hp_resolved_path) as f:
        config = yaml.full_load(f)

    tmp_dir = Path(tempfile.mkdtemp(prefix="noether_eval_"))
    safe_path = tmp_dir / "hp_resolved.yaml"
    with open(safe_path, "w") as f:
        yaml.safe_dump(_to_plain_python(config), f, sort_keys=False)
    return safe_path


def _pop_eval_path_args(argv: list[str]) -> tuple[dict[str, str], list[str]]:
    """Extract path-navigation args from ``argv``.

@@ -131,7 +102,7 @@ def _inject_hp_resolved_into_argv() -> None:
            "Make sure run_dir points at a training run output directory "
            "(typically output_path/run_id[/stage_name])."
        )
    safe_hp = sanitize_hp_resolved(hp_resolved)

    # `hp_resolved.yaml` is dumped with `exclude_unset=True`, so values that were generated at training-time
    # (e.g. `run_id`) are absent. Infer them from the run_dir path and inject as forced overrides so the eval run
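For reference, the tuple-tag failure mode that ``sanitize_hp_resolved`` works
around, as a standalone sketch:

.. code-block:: python

    import yaml

    cfg = {"dataset_statistics": ("mean", "std")}
    text = yaml.dump(cfg)  # emits 'dataset_statistics: !!python/tuple ...'
    # yaml.safe_load(text) would raise yaml.constructor.ConstructorError.

    # After tuple-to-list conversion, safe_dump round-trips cleanly.
    assert yaml.safe_load(yaml.safe_dump({"dataset_statistics": ["mean", "std"]}))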