Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 26 additions & 10 deletions cosmos_framework/configs/toml_config/sft_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,15 @@ class SFTExperimentConfig(BaseModel):
trainer: TrainerConfig = Field(default_factory=TrainerConfig)
checkpoint: CheckpointConfig = Field(default_factory=CheckpointConfig)
dataloader_train: DataloaderTrainConfig = Field(default_factory=DataloaderTrainConfig)
custom: dict[str, Any] = Field(
default_factory=dict,
description=(
"Free-form, project-owned escape hatch. Arbitrary nested content "
"passes through verbatim — the framework never validates inside it. "
"Injected onto the loaded config as ``config.custom`` after Hydra "
"resolution; specify concrete values here (no ${...} interpolation)."
),
)


# ---------------------------------------------------------------------------
Expand All @@ -693,16 +702,17 @@ def load_experiment_from_toml(
["optimizer.lr=1e-5", "trainer.max_iter=200"]
["model.config.parallelism.data_parallel_shard_degree=4"]

Calls ``cosmos_framework.utils.config.load_config`` which:
The load then:

1. Imports the base config module and runs ``make_config()``. This
registers every config group (model, ema, tokenizer, ...) and imports
all experiment modules so their ``cs.store(group="experiment", ...)``
side-effects fire.
2. Runs ``override(config, overrides)`` — Hydra ``compose`` then resolves
the ``experiment=<name>`` selector against ``ConfigStore`` and applies
the dotted-path overrides we generated from the TOML, followed by
``extra_overrides``.
1. Runs ``load_config`` — imports the base config module, runs
``make_config()`` (registers config groups + experiment modules), and
lets Hydra ``compose`` resolve the ``experiment=<name>`` selector and
apply the dotted-path overrides, followed by ``extra_overrides``.
2. Injects the TOML's ``[custom]`` table (if any) verbatim onto
``config.custom`` *after* loading — kept out of ``build_hydra_overrides``
so it lands as-is, not per-leaf-remapped. Because this happens after
Hydra resolution, ``[custom]`` must hold concrete values; ``${...}``
interpolation against ``custom`` is not supported.

Returns the merged ``Config`` instance, ready for ``launch()``.
"""
Expand Down Expand Up @@ -740,4 +750,10 @@ def load_experiment_from_toml(
# Import lazily so this module stays cheap to import in non-training contexts.
from cosmos_framework.utils.config import load_config

return load_config(base_config_path, overrides)
config = load_config(base_config_path, overrides)

# Inject [custom] verbatim after Hydra resolution. Kept off the base config
# schema so the framework-owned hydra configs stay untouched; lands as a
# plain dict reachable via config.custom.
config.custom = raw.get("custom", {})
return config
194 changes: 194 additions & 0 deletions cosmos_framework/configs/toml_config/sft_config_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: OpenMDW-1.1

"""Loader tests for the free-form ``[custom]`` escape-hatch section of the SFT TOML."""

from __future__ import annotations

from pathlib import Path

import pytest
from pydantic import ValidationError

from cosmos_framework.configs.toml_config.sft_config import SFTExperimentConfig
from cosmos_framework.configs.toml_config.toml_config_helper import build_hydra_overrides

# Representative [custom] payload used by the tests in this module. It covers
# the three shapes exercised here: scalars, a nested sub-table, and an
# array-of-tables.
_CUSTOM_PAYLOAD = {
    # Plain scalars: int, str, bool, float.
    "scalar_int": 5,
    "scalar_str": "hello",
    "flag": True,
    "ratio": 0.3,
    # Nested sub-table, two levels deep.
    "sampling": {
        "bug_ratio": 0.3,
        "nested": {"deep": 1},
    },
    # Array-of-tables.
    "items": [
        {"path": "/data/a", "weight": 1.0},
        {"path": "/data/b", "weight": 2.0},
    ],
}


# --------------------------------------------------------------------------- #
# 1. pydantic schema validation #
# --------------------------------------------------------------------------- #
class TestSchemaValidation:
    """Pydantic-level checks for the ``custom`` field of ``SFTExperimentConfig``.

    Every test validates a raw dict containing a minimal ``job`` section plus
    whichever extra top-level sections the case needs.
    """

    @staticmethod
    def _raw_config(**sections):
        """Return a fresh raw config dict: the minimal ``job`` section plus ``sections``."""
        return {"job": {"task": "vfm", "experiment": "vision_sft_nano"}, **sections}

    def test_custom_section_validates_arbitrary_nested_content(self) -> None:
        """Arbitrary nested [custom] content passes through untouched."""
        validated = SFTExperimentConfig.model_validate(self._raw_config(custom=_CUSTOM_PAYLOAD))
        # Stored verbatim: no coercion and no validation of the inner structure.
        assert validated.custom == _CUSTOM_PAYLOAD

    def test_no_custom_section_defaults_empty(self) -> None:
        """Leaving out [custom] gives an empty dict."""
        validated = SFTExperimentConfig.model_validate(self._raw_config())
        assert validated.custom == {}

    def test_unknown_top_level_key_raises(self) -> None:
        """Any unknown top-level section that is NOT `custom` still raises."""
        payload = self._raw_config(bogus_section={"x": 1})
        with pytest.raises(ValidationError):
            SFTExperimentConfig.model_validate(payload)

    def test_unknown_key_inside_optimizer_raises(self) -> None:
        """A typo inside a KNOWN section is still a hard error (extra='forbid')."""
        payload = self._raw_config(optimizer={"lr": 1.0e-4, "not_a_real_key": 1})
        with pytest.raises(ValidationError):
            SFTExperimentConfig.model_validate(payload)

    def test_custom_does_not_loosen_sibling_validation(self) -> None:
        """Presence of [custom] must not relax extra='forbid' elsewhere."""
        payload = self._raw_config(
            custom=_CUSTOM_PAYLOAD,
            trainer={"max_iter": 10, "typo_here": True},
        )
        with pytest.raises(ValidationError):
            SFTExperimentConfig.model_validate(payload)


# --------------------------------------------------------------------------- #
# 2. build_hydra_overrides must NOT emit [custom] as per-leaf overrides #
# --------------------------------------------------------------------------- #
class TestBuildHydraOverrides:
    """``build_hydra_overrides`` must ignore ``[custom]`` and still emit every other section."""

    @staticmethod
    def _raw(custom=None):
        """Return a fresh raw TOML dict, adding a ``custom`` section only when given.

        A new dict is built on every call so that two ``build_hydra_overrides``
        invocations in one test never share input state.
        """
        raw = {
            "job": {"task": "vfm", "experiment": "vision_sft_nano"},
            "optimizer": {"lr": 1.0e-5},
        }
        if custom is not None:
            raw["custom"] = custom
        return raw

    def test_custom_not_emitted_as_overrides(self) -> None:
        """[custom] contributes no overrides, neither verbatim nor remapped."""
        overrides = build_hydra_overrides(self._raw(custom=_CUSTOM_PAYLOAD))
        # Nothing under custom should appear by name.
        assert all("custom" not in o for o in overrides), overrides
        # A name check alone cannot see a leaf that was remapped to another
        # path, so also require the output to match the one built without
        # any [custom] section at all.
        baseline = build_hydra_overrides(self._raw())
        assert overrides == baseline, (overrides, baseline)

    def test_other_keys_still_emitted(self) -> None:
        """Sections other than [custom] keep producing their overrides."""
        overrides = build_hydra_overrides(self._raw(custom={"a": 1}))
        assert "experiment=vision_sft_nano" in overrides
        assert any(o.startswith("optimizer.lr=") for o in overrides), overrides


# --------------------------------------------------------------------------- #
# 3. end-to-end load_experiment_from_toml on the shipped vision_sft_nano recipe #
# --------------------------------------------------------------------------- #
# Minimal recipe TOML written to disk by the end-to-end tests. The
# ``${oc.env:...}`` values name environment variables that the
# ``_dummy_recipe_env`` fixture sets to placeholder paths.
_BASE_TOML = """\
[job]
task = "vfm"
experiment = "vision_sft_nano"
project = "cosmos3"
group = "sft"
name = "sft_config_custom_test"
wandb_mode = "disabled"

[model.tokenizer]
vae_path = "${oc.env:WAN_VAE_PATH}"

[checkpoint]
load_path = "${oc.env:BASE_CHECKPOINT_PATH}"
"""

# ``[custom]`` section appended to ``_BASE_TOML`` by the with-custom test. Its
# contents are the TOML spelling of ``_CUSTOM_PAYLOAD``: scalars, nested
# sub-tables, and an array-of-tables.
_CUSTOM_TOML_BLOCK = """\

[custom]
scalar_int = 5
scalar_str = "hello"
flag = true
ratio = 0.3

[custom.sampling]
bug_ratio = 0.3

[custom.sampling.nested]
deep = 1

[[custom.items]]
path = "/data/a"
weight = 1.0

[[custom.items]]
path = "/data/b"
weight = 2.0
"""


def _load_or_skip(toml_path: Path):
    """Load ``toml_path`` through the real loader, or skip the calling test.

    An ``ImportError`` raised while loading is turned into ``pytest.skip`` so
    the test is reported as skipped rather than failed.
    """
    from cosmos_framework.configs.toml_config.sft_config import load_experiment_from_toml

    try:
        loaded = load_experiment_from_toml(str(toml_path))
    except ImportError as exc:  # pragma: no cover — env-dependent
        pytest.skip(f"training stack not importable here: {exc!r}")
    else:
        return loaded


@pytest.fixture
def _dummy_recipe_env(monkeypatch: pytest.MonkeyPatch) -> None:
    """Set the environment variables the test recipe reads to placeholder paths."""
    # vision_sft_nano interpolates these env vars into path strings at resolve time.
    placeholder_paths = {
        "DATASET_PATH": "/tmp/dummy_dataset",
        "WAN_VAE_PATH": "/tmp/dummy_vae.pth",
        "BASE_CHECKPOINT_PATH": "/tmp/dummy_ckpt",
    }
    for env_name, env_value in placeholder_paths.items():
        monkeypatch.setenv(env_name, env_value)


class TestEndToEndLoader:
    """End-to-end checks of ``load_experiment_from_toml`` with and without ``[custom]``.

    Both tests go through ``_load_or_skip``, so they are skipped when loading
    raises ``ImportError``.
    """

    def test_load_with_custom_section(self, tmp_path: Path, _dummy_recipe_env: None) -> None:
        """A populated [custom] table surfaces unchanged as ``config.custom``."""
        toml_path = tmp_path / "with_custom.toml"
        # TOML is UTF-8 by specification; do not depend on the locale default.
        toml_path.write_text(_BASE_TOML + _CUSTOM_TOML_BLOCK, encoding="utf-8")

        config = _load_or_skip(toml_path)

        # Injected verbatim as a plain dict after Hydra resolution, so a project
        # can run MyProjectConfig.model_validate(config.custom) directly.
        # _CUSTOM_TOML_BLOCK is the TOML spelling of _CUSTOM_PAYLOAD, so compare
        # against that single definition instead of a hand-copied duplicate.
        assert config.custom == _CUSTOM_PAYLOAD

    def test_load_without_custom_section_defaults_empty(self, tmp_path: Path, _dummy_recipe_env: None) -> None:
        """Without a [custom] table, ``config.custom`` is an empty dict."""
        toml_path = tmp_path / "no_custom.toml"
        toml_path.write_text(_BASE_TOML, encoding="utf-8")

        config = _load_or_skip(toml_path)

        assert config.custom == {}
3 changes: 3 additions & 0 deletions cosmos_framework/configs/toml_config/toml_config_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ def build_hydra_overrides(toml_dict: dict) -> list[str]:

overlay = dict(toml_dict)
overlay["job"] = job
# [custom] lands verbatim on config.custom (see load_experiment_from_toml),
# so it must not be per-leaf-remapped into Hydra overrides here.
overlay.pop("custom", None)

for top_key, val in overlay.items():
_emit_with_remap(overrides, [top_key], val, rules)
Expand Down
13 changes: 11 additions & 2 deletions cosmos_framework/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,12 @@ def validate(self) -> None:
assert self.job.name != ""


def load_config(config_path: str, opts: list[str], enable_one_logger: bool = False) -> Config:
def load_config(
config_path: str,
opts: list[str],
enable_one_logger: bool = False,
) -> Config:
"""Load a config from a ``.yaml`` or ``.py`` path and apply ``opts``."""
from cosmos_framework.utils.serialization import from_yaml, load_callable

t1 = time.monotonic_ns()
Expand Down Expand Up @@ -549,7 +554,11 @@ def load_config(config_path: str, opts: list[str], enable_one_logger: bool = Fal
return config


def _load_py_config(config_path: str, opts: list[str], validate: bool = True) -> Config:
def _load_py_config(
config_path: str,
opts: list[str],
validate: bool = True,
) -> Config:
# NOTE: circular dependency
from cosmos_framework.utils.config_helper import get_config_module, override

Expand Down
28 changes: 26 additions & 2 deletions docs/sft_config.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ ______________________________________________________________________
- [`[trainer.callbacks.grad_clip]`](#trainercallbacksgrad_clip)
- [`[checkpoint]`](#checkpoint)
- [`[dataloader_train]`](#dataloader_train)
- [`[custom]` (free-form escape hatch)](#custom-free-form-escape-hatch)
- [Cross-cutting behaviors](#cross-cutting-behaviors)
- [`"???"` (MISSING) sentinel](#-missing-sentinel)
- [Env interpolation](#env-interpolation)
Expand Down Expand Up @@ -60,6 +61,7 @@ After validation, the TOML dict is converted to a Hydra override list by [`build
[trainer.callbacks.grad_clip] # clip_norm + force_finite
[checkpoint] # load_path, save_iter, key-skip blocklist
[dataloader_train] # top-level scalars only
[custom] # free-form, project-owned escape hatch (opaque to the framework)
```

The full pipeline (dataloader class, dataset wiring, model_instance LazyCall, etc.) lives in the experiment SKU Python file under `cosmos_framework/configs/base/experiment/sft/<recipe>.py`. The TOML only surfaces values the recipe author wants users to tune.
Expand Down Expand Up @@ -225,6 +227,27 @@ Top-level dataloader scalars only. The dataloader's class (LazyCall) and full pi
| `max_sequence_length` | `null` | Cap on tokens per packed sequence. Remapped to `max_tokens` on the VLM `DataPackerDataLoader`. `null` = no per-token cap. |
| `seed` | `42` | Dataloader RNG seed. **VFM only** — skipped on VLM (DataPackerDataLoader has no `seed` ctor kwarg). |

## `[custom]` (free-form escape hatch)

`[custom]` lets a project carry its own config (dataset paths, sampling ratios, …) in the **same** TOML as the framework knobs. The framework never looks inside it — it's the one section exempt from the `extra="forbid"` typo guard (every other section still rejects unknown keys).

How it works:

- **Arbitrary nested content** passes through verbatim — scalars, sub-tables (`[custom.a.b]`), arrays-of-tables (`[[custom.items]]`).
- It does **not** go through Hydra. After `load_config` finishes, `load_experiment_from_toml` attaches the table as a plain `dict` via `config.custom = raw.get("custom", {})`. When the section is absent this yields `{}`, so reading `config.custom` is always safe on a config returned by `load_experiment_from_toml`.
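- So values must be **concrete**: because the table bypasses Hydra, `${...}` expressions written inside `[custom]` are not resolved, and other sections cannot interpolate against `custom`. `config.custom` is also **not** part of `config.to_dict()` / serialized config dumps.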

```toml
[custom]
your_custom_files = "custom_value"
```

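Read it directly to wire your own pipeline, for example by validating it against a project-owned pydantic model (`TrainingDatasetConfig` below is a placeholder for your own class):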

```python
project_cfg = TrainingDatasetConfig.model_validate(config.custom)
```

## Cross-cutting behaviors

### `"???"` (MISSING) sentinel
Expand Down Expand Up @@ -289,9 +312,10 @@ A few useful knobs aren't currently modeled by `SFTExperimentConfig` because the
1. Reads the TOML with `tomllib`.
2. Validates the parsed dict against `SFTExperimentConfig` (raises `ValidationError` on unknown keys).
3. Picks the base config from `[job].task`: `TASK_TO_BASE_CONFIG["vfm"|"vlm"]`.
4. Calls `build_hydra_overrides(raw)` to produce a `["--", "experiment=<name>", "k.p=v", …]` list with per-task remaps applied and MISSING values filtered.
4. Calls `build_hydra_overrides(raw)` to produce a `["--", "experiment=<name>", "k.p=v", …]` list with per-task remaps applied and MISSING values filtered. `[custom]` is skipped here (it is injected verbatim in step 7, not per-leaf-remapped).
5. Appends `extra_overrides` (CLI tail) so they take precedence over the TOML.
6. Calls `cosmos_framework.utils.config.load_config(base_config_path, overrides)`, which imports the base config module (running `make_config()` to register every config group and import every experiment SKU's `cs.store(group="experiment", …)`), then runs `override(config, overrides)` — Hydra `compose` resolves the `experiment=<name>` selector against `ConfigStore` and applies the dotted-path overrides.
6. Calls `cosmos_framework.utils.config.load_config(base_config_path, overrides)`, which imports the base config module and runs `make_config()` (registers every config group and imports every experiment SKU's `cs.store(group="experiment", …)`), then `override(config, overrides)` has Hydra `compose` resolve the `experiment=<name>` selector against `ConfigStore` and apply the dotted-path overrides.
7. Injects `[custom]` after loading: `config.custom = raw.get("custom", {})`. This runs **after** Hydra resolution, so it lands as a plain `dict` (no `${custom}` interpolation; not part of serialized config dumps).

The returned `Config` is ready for `launch()`.

Expand Down