diff --git a/cosmos_framework/configs/toml_config/sft_config.py b/cosmos_framework/configs/toml_config/sft_config.py index a15d00a..5d17f2b 100644 --- a/cosmos_framework/configs/toml_config/sft_config.py +++ b/cosmos_framework/configs/toml_config/sft_config.py @@ -669,6 +669,15 @@ class SFTExperimentConfig(BaseModel): trainer: TrainerConfig = Field(default_factory=TrainerConfig) checkpoint: CheckpointConfig = Field(default_factory=CheckpointConfig) dataloader_train: DataloaderTrainConfig = Field(default_factory=DataloaderTrainConfig) + custom: dict[str, Any] = Field( + default_factory=dict, + description=( + "Free-form, project-owned escape hatch. Arbitrary nested content " + "passes through verbatim — the framework never validates inside it. " + "Injected onto the loaded config as ``config.custom`` after Hydra " + "resolution; specify concrete values here (no ${...} interpolation)." + ), + ) # --------------------------------------------------------------------------- @@ -693,16 +702,17 @@ def load_experiment_from_toml( ["optimizer.lr=1e-5", "trainer.max_iter=200"] ["model.config.parallelism.data_parallel_shard_degree=4"] - Calls ``cosmos_framework.utils.config.load_config`` which: + The load then: - 1. Imports the base config module and runs ``make_config()``. This - registers every config group (model, ema, tokenizer, ...) and imports - all experiment modules so their ``cs.store(group="experiment", ...)`` - side-effects fire. - 2. Runs ``override(config, overrides)`` — Hydra ``compose`` then resolves - the ``experiment=`` selector against ``ConfigStore`` and applies - the dotted-path overrides we generated from the TOML, followed by - ``extra_overrides``. + 1. Runs ``load_config`` — imports the base config module, runs + ``make_config()`` (registers config groups + experiment modules), and + lets Hydra ``compose`` resolve the ``experiment=`` selector and + apply the dotted-path overrides, followed by ``extra_overrides``. + 2. 
Injects the TOML's ``[custom]`` table (if any) verbatim onto + ``config.custom`` *after* loading — kept out of ``build_hydra_overrides`` + so it lands as-is, not per-leaf-remapped. Because this happens after + Hydra resolution, ``[custom]`` must hold concrete values; ``${...}`` + interpolation against ``custom`` is not supported. Returns the merged ``Config`` instance, ready for ``launch()``. """ @@ -740,4 +750,10 @@ def load_experiment_from_toml( # Import lazily so this module stays cheap to import in non-training contexts. from cosmos_framework.utils.config import load_config - return load_config(base_config_path, overrides) + config = load_config(base_config_path, overrides) + + # Inject [custom] verbatim after Hydra resolution. Kept off the base config + # schema so the framework-owned hydra configs stay untouched; lands as a + # plain dict reachable via config.custom. + config.custom = raw.get("custom", {}) + return config diff --git a/cosmos_framework/configs/toml_config/sft_config_test.py b/cosmos_framework/configs/toml_config/sft_config_test.py new file mode 100644 index 0000000..3385b41 --- /dev/null +++ b/cosmos_framework/configs/toml_config/sft_config_test.py @@ -0,0 +1,194 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""Loader tests for the free-form ``[custom]`` escape-hatch section of the SFT TOML.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest +from pydantic import ValidationError + +from cosmos_framework.configs.toml_config.sft_config import SFTExperimentConfig +from cosmos_framework.configs.toml_config.toml_config_helper import build_hydra_overrides + +# Representative payload: scalars, a nested sub-table, and an array-of-tables. 
+_CUSTOM_PAYLOAD = { + "scalar_int": 5, + "scalar_str": "hello", + "flag": True, + "ratio": 0.3, + "sampling": {"bug_ratio": 0.3, "nested": {"deep": 1}}, + "items": [ + {"path": "/data/a", "weight": 1.0}, + {"path": "/data/b", "weight": 2.0}, + ], +} + + +# --------------------------------------------------------------------------- # +# 1. pydantic schema validation # +# --------------------------------------------------------------------------- # +class TestSchemaValidation: + def test_custom_section_validates_arbitrary_nested_content(self) -> None: + """Arbitrary nested [custom] content passes through untouched.""" + raw = { + "job": {"task": "vfm", "experiment": "vision_sft_nano"}, + "custom": _CUSTOM_PAYLOAD, + } + cfg = SFTExperimentConfig.model_validate(raw) + # The framework stores it verbatim — no coercion, no inner validation. + assert cfg.custom == _CUSTOM_PAYLOAD + + def test_no_custom_section_defaults_empty(self) -> None: + cfg = SFTExperimentConfig.model_validate({"job": {"task": "vfm", "experiment": "vision_sft_nano"}}) + assert cfg.custom == {} + + def test_unknown_top_level_key_raises(self) -> None: + """Any unknown top-level section that is NOT `custom` still raises.""" + with pytest.raises(ValidationError): + SFTExperimentConfig.model_validate( + { + "job": {"task": "vfm", "experiment": "vision_sft_nano"}, + "bogus_section": {"x": 1}, + } + ) + + def test_unknown_key_inside_optimizer_raises(self) -> None: + """A typo inside a KNOWN section is still a hard error (extra='forbid').""" + with pytest.raises(ValidationError): + SFTExperimentConfig.model_validate( + { + "job": {"task": "vfm", "experiment": "vision_sft_nano"}, + "optimizer": {"lr": 1.0e-4, "not_a_real_key": 1}, + } + ) + + def test_custom_does_not_loosen_sibling_validation(self) -> None: + """Presence of [custom] must not relax extra='forbid' elsewhere.""" + with pytest.raises(ValidationError): + SFTExperimentConfig.model_validate( + { + "job": {"task": "vfm", "experiment": 
"vision_sft_nano"}, + "custom": _CUSTOM_PAYLOAD, + "trainer": {"max_iter": 10, "typo_here": True}, + } + ) + + +# --------------------------------------------------------------------------- # +# 2. build_hydra_overrides must NOT emit [custom] as per-leaf overrides # +# --------------------------------------------------------------------------- # +class TestBuildHydraOverrides: + def test_custom_not_emitted_as_overrides(self) -> None: + raw = { + "job": {"task": "vfm", "experiment": "vision_sft_nano"}, + "optimizer": {"lr": 1.0e-5}, + "custom": _CUSTOM_PAYLOAD, + } + overrides = build_hydra_overrides(raw) + # Nothing under custom (verbatim or remapped) should appear. + assert all("custom" not in o for o in overrides), overrides + + def test_other_keys_still_emitted(self) -> None: + raw = { + "job": {"task": "vfm", "experiment": "vision_sft_nano"}, + "optimizer": {"lr": 1.0e-5}, + "custom": {"a": 1}, + } + overrides = build_hydra_overrides(raw) + assert "experiment=vision_sft_nano" in overrides + assert any(o.startswith("optimizer.lr=") for o in overrides), overrides + + +# --------------------------------------------------------------------------- # +# 3. 
end-to-end load_experiment_from_toml on the shipped vision_sft_nano recipe # +# --------------------------------------------------------------------------- # +_BASE_TOML = """\ +[job] +task = "vfm" +experiment = "vision_sft_nano" +project = "cosmos3" +group = "sft" +name = "sft_config_custom_test" +wandb_mode = "disabled" + +[model.tokenizer] +vae_path = "${oc.env:WAN_VAE_PATH}" + +[checkpoint] +load_path = "${oc.env:BASE_CHECKPOINT_PATH}" +""" + +_CUSTOM_TOML_BLOCK = """\ + +[custom] +scalar_int = 5 +scalar_str = "hello" +flag = true +ratio = 0.3 + +[custom.sampling] +bug_ratio = 0.3 + +[custom.sampling.nested] +deep = 1 + +[[custom.items]] +path = "/data/a" +weight = 1.0 + +[[custom.items]] +path = "/data/b" +weight = 2.0 +""" + + +def _load_or_skip(toml_path: Path): + """Run the real loader, skipping if the training stack can't be imported.""" + from cosmos_framework.configs.toml_config.sft_config import load_experiment_from_toml + + try: + return load_experiment_from_toml(str(toml_path)) + except ImportError as exc: # pragma: no cover — env-dependent + pytest.skip(f"training stack not importable here: {exc!r}") + + +@pytest.fixture +def _dummy_recipe_env(monkeypatch: pytest.MonkeyPatch) -> None: + # vision_sft_nano interpolates these env vars into path strings at resolve time. 
+ monkeypatch.setenv("DATASET_PATH", "/tmp/dummy_dataset") + monkeypatch.setenv("WAN_VAE_PATH", "/tmp/dummy_vae.pth") + monkeypatch.setenv("BASE_CHECKPOINT_PATH", "/tmp/dummy_ckpt") + + +class TestEndToEndLoader: + def test_load_with_custom_section(self, tmp_path: Path, _dummy_recipe_env: None) -> None: + toml_path = tmp_path / "with_custom.toml" + toml_path.write_text(_BASE_TOML + _CUSTOM_TOML_BLOCK) + + config = _load_or_skip(toml_path) + + expected = { + "scalar_int": 5, + "scalar_str": "hello", + "flag": True, + "ratio": 0.3, + "sampling": {"bug_ratio": 0.3, "nested": {"deep": 1}}, + "items": [ + {"path": "/data/a", "weight": 1.0}, + {"path": "/data/b", "weight": 2.0}, + ], + } + # Injected verbatim as a plain dict after Hydra resolution, so a project + # can run MyProjectConfig.model_validate(config.custom) directly. + assert config.custom == expected + + def test_load_without_custom_section_defaults_empty(self, tmp_path: Path, _dummy_recipe_env: None) -> None: + toml_path = tmp_path / "no_custom.toml" + toml_path.write_text(_BASE_TOML) + + config = _load_or_skip(toml_path) + + assert config.custom == {} diff --git a/cosmos_framework/configs/toml_config/toml_config_helper.py b/cosmos_framework/configs/toml_config/toml_config_helper.py index ac54696..82333cf 100644 --- a/cosmos_framework/configs/toml_config/toml_config_helper.py +++ b/cosmos_framework/configs/toml_config/toml_config_helper.py @@ -138,6 +138,9 @@ def build_hydra_overrides(toml_dict: dict) -> list[str]: overlay = dict(toml_dict) overlay["job"] = job + # [custom] lands verbatim on config.custom (see load_experiment_from_toml), + # so it must not be per-leaf-remapped into Hydra overrides here. 
+ overlay.pop("custom", None) for top_key, val in overlay.items(): _emit_with_remap(overrides, [top_key], val, rules) diff --git a/cosmos_framework/utils/config.py b/cosmos_framework/utils/config.py index c59d689..5078fbc 100644 --- a/cosmos_framework/utils/config.py +++ b/cosmos_framework/utils/config.py @@ -517,7 +517,12 @@ def validate(self) -> None: assert self.job.name != "" -def load_config(config_path: str, opts: list[str], enable_one_logger: bool = False) -> Config: +def load_config( + config_path: str, + opts: list[str], + enable_one_logger: bool = False, +) -> Config: + """Load a config from a ``.yaml`` or ``.py`` path and apply ``opts``.""" from cosmos_framework.utils.serialization import from_yaml, load_callable t1 = time.monotonic_ns() @@ -549,7 +554,11 @@ def load_config(config_path: str, opts: list[str], enable_one_logger: bool = Fal return config -def _load_py_config(config_path: str, opts: list[str], validate: bool = True) -> Config: +def _load_py_config( + config_path: str, + opts: list[str], + validate: bool = True, +) -> Config: # NOTE: circular dependency from cosmos_framework.utils.config_helper import get_config_module, override diff --git a/docs/sft_config.md b/docs/sft_config.md index 6762f40..8df6e91 100644 --- a/docs/sft_config.md +++ b/docs/sft_config.md @@ -23,6 +23,7 @@ ______________________________________________________________________ - [`[trainer.callbacks.grad_clip]`](#trainercallbacksgrad_clip) - [`[checkpoint]`](#checkpoint) - [`[dataloader_train]`](#dataloader_train) +- [`[custom]` (free-form escape hatch)](#custom-free-form-escape-hatch) - [Cross-cutting behaviors](#cross-cutting-behaviors) - [`"???"` (MISSING) sentinel](#-missing-sentinel) - [Env interpolation](#env-interpolation) @@ -60,6 +61,7 @@ After validation, the TOML dict is converted to a Hydra override list by [`build [trainer.callbacks.grad_clip] # clip_norm + force_finite [checkpoint] # load_path, save_iter, key-skip blocklist [dataloader_train] # top-level 
scalars only +[custom] # free-form, project-owned escape hatch (opaque to the framework) ``` The full pipeline (dataloader class, dataset wiring, model_instance LazyCall, etc.) lives in the experiment SKU Python file under `cosmos_framework/configs/base/experiment/sft/.py`. The TOML only surfaces values the recipe author wants users to tune. @@ -225,6 +227,27 @@ Top-level dataloader scalars only. The dataloader's class (LazyCall) and full pi | `max_sequence_length` | `null` | Cap on tokens per packed sequence. Remapped to `max_tokens` on the VLM `DataPackerDataLoader`. `null` = no per-token cap. | | `seed` | `42` | Dataloader RNG seed. **VFM only** — skipped on VLM (DataPackerDataLoader has no `seed` ctor kwarg). | +## `[custom]` (free-form escape hatch) + +`[custom]` lets a project carry its own config (dataset paths, sampling ratios, …) in the **same** TOML as the framework knobs. The framework never looks inside it — it's the one section exempt from the `extra="forbid"` typo guard (every other section still rejects unknown keys). + +How it works: + +- **Arbitrary nested content** passes through verbatim — scalars, sub-tables (`[custom.a.b]`), arrays-of-tables (`[[custom.items]]`). +- It does **not** go through Hydra. After `load_config` finishes, the table is attached as a plain `dict` via `config.custom = raw.get("custom", {})` (or `{}` when absent — reading `config.custom` is always safe). +- So values must be **concrete**: `${...}` expressions written inside `[custom]` are **not** resolved (they stay literal strings), other sections cannot reference `${custom.*}`, and `config.custom` is **not** part of `config.to_dict()` / serialized config dumps. + +```toml +[custom] +your_custom_files = "custom_value" +``` + +Read it directly to wire your own pipeline: + +```python +project_cfg = TrainingDatasetConfig.model_validate(config.custom) +``` + ## Cross-cutting behaviors ### `"???"` (MISSING) sentinel @@ -289,9 +312,10 @@ A few useful knobs aren't currently modeled by `SFTExperimentConfig` because the 1. Reads the TOML with `tomllib`. 2. 
Validates the parsed dict against `SFTExperimentConfig` (raises `ValidationError` on unknown keys). 3. Picks the base config from `[job].task`: `TASK_TO_BASE_CONFIG["vfm"|"vlm"]`. -4. Calls `build_hydra_overrides(raw)` to produce a `["--", "experiment=", "k.p=v", …]` list with per-task remaps applied and MISSING values filtered. +4. Calls `build_hydra_overrides(raw)` to produce a `["--", "experiment=", "k.p=v", …]` list with per-task remaps applied and MISSING values filtered. `[custom]` is skipped here (it is injected verbatim in step 7, not per-leaf-remapped). 5. Appends `extra_overrides` (CLI tail) so they take precedence over the TOML. -6. Calls `cosmos_framework.utils.config.load_config(base_config_path, overrides)`, which imports the base config module (running `make_config()` to register every config group and import every experiment SKU's `cs.store(group="experiment", …)`), then runs `override(config, overrides)` — Hydra `compose` resolves the `experiment=` selector against `ConfigStore` and applies the dotted-path overrides. +6. Calls `cosmos_framework.utils.config.load_config(base_config_path, overrides)`, which imports the base config module and runs `make_config()` (registers every config group and imports every experiment SKU's `cs.store(group="experiment", …)`), then `override(config, overrides)` has Hydra `compose` resolve the `experiment=` selector against `ConfigStore` and apply the dotted-path overrides. +7. Injects `[custom]` after loading: `config.custom = raw.get("custom", {})`. This runs **after** Hydra resolution, so it lands as a plain `dict` (`${...}` inside `[custom]` is not resolved and `${custom.*}` cannot be referenced from other sections; not part of serialized config dumps). The returned `Config` is ready for `launch()`.