diff --git a/docs/data_stages/README.md b/docs/data_stages/README.md new file mode 100644 index 0000000000..de56e6e18f --- /dev/null +++ b/docs/data_stages/README.md @@ -0,0 +1,332 @@ +# Multi-Stage Data Training + +Multi-stage training allows switching between different data mixtures at specified training steps, similar to approaches used in Qwen3, DeepSeek-V3, and Llama 3. + +## Quick Start + +Data stages are **optional**. If no `[[training.data_stages]]` are defined, a single stage is auto-created from `[training]` data fields (backward compatible). When stages ARE defined, they override `[training]` data fields completely. + +### Multi-Stage Example + +Define `[[training.data_stages]]` sections for multi-stage training: + +```toml +[training] +steps = 150000 + +[[training.data_stages]] +name = "general" +start_step = 0 +end_step = 100000 +dataset_type = "nanoset" +dataset_folders = ["/data/general", "/data/math", "/data/code"] +dataset_weights = [0.8, 0.1, 0.1] +seq_len = 4096 + +[[training.data_stages]] +name = "reasoning" +start_step = 100000 +end_step = 130000 +dataset_type = "nanoset" +dataset_folders = ["/data/general", "/data/math", "/data/code"] +dataset_weights = [0.3, 0.35, 0.35] +seq_len = 4096 + +[[training.data_stages]] +name = "long_context" +start_step = 130000 +dataset_type = "nanoset" +dataset_folders = ["/data/general", "/data/math", "/data/code"] +dataset_weights = [0.3, 0.35, 0.35] +seq_len = 32768 +``` + +## Configuration Fields + +Each `[[training.data_stages]]` section must define all data-related fields explicitly: + +| Field | Type | Required | Description | +|-------|------|----------|-------------| +| `name` | string | Yes | Stage identifier for logging | +| `start_step` | int | Yes | Step when stage begins (inclusive) | +| `end_step` | int | No | Step when stage ends (exclusive). 
Omit for final stage | +| `dataset` | string | Yes* | Dataset name (for huggingface type) | +| `dataset_path` | string | No | Path to dataset | +| `dataset_type` | string | Yes | `"huggingface"`, `"nanoset"`, `"preprocessed"`, `"packed_memmap"` | +| `dataset_folders` | list | Yes* | Folders for nanoset datasets | +| `dataset_weights` | list | No | Weights for blending datasets (must sum to 1.0) | +| `dataset_random_seed` | int | No | Random seed for this stage (defaults to `training.dataset_random_seed`) | +| `seq_len` | int | Yes | Sequence length | + +*Required based on `dataset_type`: `dataset` for huggingface, `dataset_folders` for nanoset. + +## Single-Stage Training (Backward Compatible) + +For single-stage training, you can simply use `[training]` data fields - no `[[training.data_stages]]` needed: + +```toml +[training] +steps = 100000 +dataset_type = "huggingface" +dataset = "c4_test" +seq_len = 4096 +``` + +A single stage named "default" is auto-created internally. This maintains full backward compatibility with existing configs. 
+ +Alternatively, you can explicitly define a single stage: + +```toml +[training] +steps = 100000 + +[[training.data_stages]] +name = "pretrain" +start_step = 0 +dataset_type = "nanoset" +dataset_folders = ["/data/web", "/data/books"] +dataset_weights = [0.7, 0.3] +seq_len = 4096 +``` + +## Validation + +The following validations are performed at startup: + +- **Stage coverage**: First stage must start at step 0, no gaps or overlaps between stages +- **Required fields**: `name`, `start_step`, `dataset_type`, `seq_len` must be defined +- **Dataset source**: `dataset` required for huggingface, `dataset_folders` required for nanoset +- **Weights**: Must be non-negative, each <= 1.0, sum to 1.0, count must match folders +- **Value ranges**: `seq_len > 0`, `dataset_random_seed >= 0`, `start_step < training.steps` + +## Common Patterns + +### Pattern 1: Change Data Mixture + +```toml +[[training.data_stages]] +name = "pretrain" +start_step = 0 +end_step = 100000 +dataset_type = "nanoset" +dataset_folders = ["/data/web", "/data/books", "/data/code"] +dataset_weights = [0.7, 0.2, 0.1] # 70% web, 20% books, 10% code +seq_len = 4096 + +[[training.data_stages]] +name = "annealing" +start_step = 100000 +dataset_type = "nanoset" +dataset_folders = ["/data/web", "/data/books", "/data/code"] +dataset_weights = [0.4, 0.3, 0.3] # More balanced for final phase +seq_len = 4096 +``` + +### Pattern 2: Context Extension + +```toml +[[training.data_stages]] +name = "base" +start_step = 0 +end_step = 90000 +dataset_type = "nanoset" +dataset_folders = ["/data/web", "/data/books", "/data/code"] +dataset_weights = [0.5, 0.3, 0.2] +seq_len = 4096 + +[[training.data_stages]] +name = "long_context" +start_step = 90000 +dataset_type = "nanoset" +dataset_folders = ["/data/web", "/data/books", "/data/code"] +dataset_weights = [0.5, 0.3, 0.2] +seq_len = 32768 +``` + +### Pattern 3: Different Random Seeds (Multi-Epoch) + +```toml +[[training.data_stages]] +name = "epoch1" +start_step = 0 +end_step 
= 50000 +dataset_type = "nanoset" +dataset_folders = ["/data/web", "/data/books"] +dataset_weights = [0.7, 0.3] +dataset_random_seed = 1234 +seq_len = 4096 + +[[training.data_stages]] +name = "epoch2" +start_step = 50000 +dataset_type = "nanoset" +dataset_folders = ["/data/web", "/data/books"] +dataset_weights = [0.7, 0.3] +dataset_random_seed = 5678 +seq_len = 4096 +``` + +### Pattern 4: Mid-Training Ablation + +For ablation studies where you want to test different data mixtures from a checkpoint, you can add stages that start mid-training. The system will auto-create a "default" stage from `[training]` fields for the gap. + +**Ablation config** (start new mixture at step 5): +```toml +[training] +steps = 10 +# These fields cover steps 0-5 (auto-created as "default") +dataset = "c4_test" +dataset_type = "huggingface" +seq_len = 512 + +# Ablation stage starts at step 5 with different random seed +[[training.data_stages]] +name = "ablation_stage" +start_step = 5 +dataset = "c4_test" +dataset_type = "huggingface" +seq_len = 512 +dataset_random_seed = 9999 # Different seed for ablation +``` + +The system auto-creates "default" for steps 0-5 from `[training]`, then transitions to "ablation_stage" at step 5. + +## Logging + +At training start, a stage plan is logged: + +``` +============================================================ +DATA STAGE TRAINING PLAN +============================================================ +Total stages: 3 + +Stage 1: general + Steps: 0 -> 100,000 (100,000 steps) + Estimated tokens: 409.60B tokens + Dataset type: nanoset + Dataset folders: 3 folders + Weights: [0.800, 0.100, 0.100] + Sequence length: 4096 + +Stage 2: reasoning + Steps: 100,000 -> 130,000 (30,000 steps) + ... 
+============================================================ +``` + +At each transition: + +``` +============================================================ +DATA STAGE TRANSITION +============================================================ +Step 100000: 'general' -> 'reasoning' +Changes: dataset_weights +New weights: [0.300, 0.350, 0.350] +============================================================ +``` + +## Checkpoint & Resume + +Stage state is automatically saved in checkpoints: +- `stage_idx`: Current stage index +- `stage_name`: Current stage name +- `dataloader_state`: Position within the dataset + +On resume, the exact stage and dataloader position are restored. No manual intervention needed. + +## Testing + +### Test Configs + +Test configs are located in `docs/data_stages/configs/`: + +| Config | Description | +|--------|-------------| +| `data_stages_test.toml` | 3 stages with transitions at step 5 and 10 | +| `data_stages_backcompat_test.toml` | No data_stages (backward compatibility) | +| `data_stages_ablation_test.toml` | Stages start at step 5 (ablation use case) | + +### Automated Test Suite + +Run the test script to verify all functionality: + +```bash +./scripts/test_data_stages.sh +``` + +The test suite runs 5 tests: + +``` +============================================================ +DATA STAGES TEST SUITE +============================================================ + +[Test 1] Backward Compatibility: No [[training.data_stages]] + ✓ Auto-created 'default' stage from [training] + ✓ Stage named 'default' + ✓ Training completed successfully + +[Test 2] Multi-Stage Training: Full run with 3 stages + ✓ Transition at step 5: stage_1_general -> stage_2_reasoning + ✓ Transition at step 10: stage_2_reasoning -> stage_3_final + +[Test 3] Checkpoint Resume: from step 7 + ✓ Stage correctly restored to stage_2_reasoning + ✓ Dataloader position restored + ✓ Training resumed at correct step (8) + +[Test 4] Reproducibility: Comparing losses between full 
and resumed runs +Step | Full Run | Resume | Match +------|----------|----------|------ +8 | 4.7073 | 4.7073 | ✓ +9 | 4.0312 | 4.0312 | ✓ +10 | 4.0548 | 4.0548 | ✓ +11 | 3.8143 | 3.8143 | ✓ +12 | 3.8702 | 3.8702 | ✓ +13 | 4.2306 | 4.2306 | ✓ +14 | 3.6354 | 3.6354 | ✓ +15 | 3.7099 | 3.7099 | ✓ + +[Test 5] Ablation: Stages start at step 5 + ✓ Auto-created 'default' stage for gap (steps 0-5) + ✓ First stage is 'default' + ✓ Second stage is 'ablation_stage' + ✓ Transition occurred: default -> ablation_stage + ✓ Training completed successfully + +============================================================ +ALL TESTS PASSED! +============================================================ +``` + +### Manual Testing + +```bash +# Backward compatibility (no data_stages) +CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --standalone \ + -m torchtitan.train --job.config_file docs/data_stages/configs/data_stages_backcompat_test.toml + +# Multi-stage training +CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --standalone \ + -m torchtitan.train --job.config_file docs/data_stages/configs/data_stages_test.toml + +# Resume from step 7 +CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --standalone \ + -m torchtitan.train --job.config_file docs/data_stages/configs/data_stages_test.toml \ + --checkpoint.load_step 7 + +# Ablation (stages start mid-training) +CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --standalone \ + -m torchtitan.train --job.config_file docs/data_stages/configs/data_stages_ablation_test.toml +``` + +### What the Tests Verify + +1. **Backward compatibility**: Existing configs without `[[training.data_stages]]` still work +2. **Stage transitions**: Dataloader rebuilds correctly at stage boundaries +3. **Checkpoint saves**: Stage index + exact dataloader position (sample count) +4. **Resume restores**: Exact state - losses match between full run and resumed run +5. 
**Ablation mode**: When `[[training.data_stages]]` starts after step 0 (e.g., step 5), the system auto-creates a "default" stage from `[training]` fields to cover the gap (steps 0-5). This lets you train initially with `[training]` only, then later add stages mid-training to test different data mixtures from a checkpoint. +6. **No data skip/repeat**: Same batches processed in same order diff --git a/docs/data_stages/configs/data_stages_ablation_test.toml b/docs/data_stages/configs/data_stages_ablation_test.toml new file mode 100644 index 0000000000..063a7677e2 --- /dev/null +++ b/docs/data_stages/configs/data_stages_ablation_test.toml @@ -0,0 +1,79 @@ +# Copyright (c) Nous Research. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Test config for dmayhem's ablation use case +# [training] data fields cover steps 0-5 (auto-created as "default") +# [[training.data_stages]] starts at step 5 with different seed (ablation) +# +# This tests the scenario where a user: +# 1. Trained with [training] only to step 5 +# 2. 
Now wants to resume with different data mixture from step 5 + +[job] +dump_folder = "./outputs/data_stages_ablation_test" +description = "Ablation test - stages start mid-training" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 +enable_tensorboard = false +enable_wandb = false + +[model] +name = "llama3" +flavor = "debugmodel" +hf_assets_path = "./tests/assets/tokenizer" + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 + +[training] +local_batch_size = 2 +max_norm = 1.0 +steps = 10 + +# These fields cover steps 0-5 (auto-created as "default") +dataset = "c4_test" +dataset_type = "huggingface" +seq_len = 512 + +# Ablation stage starts at step 5 with different random seed +[[training.data_stages]] +name = "ablation_stage" +start_step = 5 +dataset = "c4_test" +dataset_type = "huggingface" +seq_len = 512 +dataset_random_seed = 9999 + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +tensor_parallel_degree = 1 +context_parallel_degree = 1 + +[checkpoint] +enable = true +folder = "checkpoint" +interval = 1 +last_save_model_only = false +async_mode = "disabled" + +[activation_checkpoint] +mode = "none" + +[compile] +enable = false + +[debug] +seed = 42 diff --git a/docs/data_stages/configs/data_stages_backcompat_test.toml b/docs/data_stages/configs/data_stages_backcompat_test.toml new file mode 100644 index 0000000000..33bd307ed7 --- /dev/null +++ b/docs/data_stages/configs/data_stages_backcompat_test.toml @@ -0,0 +1,63 @@ +# Copyright (c) Nous Research. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Backward compatibility test config +# No [[training.data_stages]] defined - should auto-create single "default" stage +# +# This tests that existing configs without data_stages still work. 
+ +[job] +dump_folder = "./outputs/data_stages_backcompat_test" +description = "Backward compatibility test - no data_stages" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 +enable_tensorboard = false +enable_wandb = false + +[model] +name = "llama3" +flavor = "debugmodel" +hf_assets_path = "./tests/assets/tokenizer" + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 + +[training] +local_batch_size = 2 +max_norm = 1.0 +steps = 5 + +# Old-style data config (no data_stages) +dataset = "c4_test" +dataset_type = "huggingface" +seq_len = 512 + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +tensor_parallel_degree = 1 +context_parallel_degree = 1 + +[checkpoint] +enable = false + +[activation_checkpoint] +mode = "none" + +[compile] +enable = false + +[debug] +seed = 42 diff --git a/docs/data_stages/configs/data_stages_test.toml b/docs/data_stages/configs/data_stages_test.toml new file mode 100644 index 0000000000..5e8a0a9669 --- /dev/null +++ b/docs/data_stages/configs/data_stages_test.toml @@ -0,0 +1,102 @@ +# Copyright (c) Nous Research. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Test config for multi-stage data training +# Tests: stage transitions, checkpoint save/resume, reproducibility +# +# Test scenarios: +# 1. Full run: steps 0->15 with 3 stage transitions +# 2. 
Resume test: run 0->7, resume 7->15, compare with full run +# +# Usage: +# # Full run +# CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --standalone \ +# -m torchtitan.train --job.config_file docs/data_stages/configs/data_stages_test.toml +# +# # Resume from step 7 +# CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --standalone \ +# -m torchtitan.train --job.config_file docs/data_stages/configs/data_stages_test.toml \ +# --checkpoint.load_step 7 + +[job] +dump_folder = "./outputs/data_stages_test" +description = "Data stages test - verify transitions and checkpoint resume" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 +enable_tensorboard = false +enable_wandb = false + +[model] +name = "llama3" +flavor = "debugmodel" +hf_assets_path = "./tests/assets/tokenizer" + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 + +[training] +local_batch_size = 2 +max_norm = 1.0 +steps = 15 + +# 3 stages: transitions at step 5 and step 10 +# Each stage is fully self-contained with all data configuration +[[training.data_stages]] +name = "stage_1_general" +start_step = 0 +end_step = 5 +dataset = "c4_test" +dataset_type = "huggingface" +seq_len = 512 +dataset_random_seed = 1111 + +[[training.data_stages]] +name = "stage_2_reasoning" +start_step = 5 +end_step = 10 +dataset = "c4_test" +dataset_type = "huggingface" +seq_len = 512 +dataset_random_seed = 2222 + +[[training.data_stages]] +name = "stage_3_final" +start_step = 10 +dataset = "c4_test" +dataset_type = "huggingface" +seq_len = 512 +dataset_random_seed = 3333 + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +tensor_parallel_degree = 1 +context_parallel_degree = 1 + +[checkpoint] +enable = true +folder = "checkpoint" +interval = 1 # Save every step for testing resume at any point +last_save_model_only = false +async_mode = "disabled" + +[activation_checkpoint] +mode = "none" + +[compile] +enable = 
false + +[debug] +seed = 42 # Fixed seed for reproducibility diff --git a/scripts/test_data_stages.sh b/scripts/test_data_stages.sh new file mode 100755 index 0000000000..1e6ade41ac --- /dev/null +++ b/scripts/test_data_stages.sh @@ -0,0 +1,277 @@ +#!/bin/bash +# Copyright (c) Nous Research. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Test script for multi-stage data training +# Verifies: backward compatibility, stage transitions, checkpoint resume, ablation +# +# Usage: ./scripts/test_data_stages.sh + +set -e + +STAGES_CONFIG="docs/data_stages/configs/data_stages_test.toml" +BACKCOMPAT_CONFIG="docs/data_stages/configs/data_stages_backcompat_test.toml" +ABLATION_CONFIG="docs/data_stages/configs/data_stages_ablation_test.toml" + +STAGES_OUTPUT="./outputs/data_stages_test" +BACKCOMPAT_OUTPUT="./outputs/data_stages_backcompat_test" +ABLATION_OUTPUT="./outputs/data_stages_ablation_test" + +FULL_LOG="/tmp/data_stages_full_run.log" +RESUME_LOG="/tmp/data_stages_resume_run.log" +BACKCOMPAT_LOG="/tmp/data_stages_backcompat.log" +ABLATION_LOG="/tmp/data_stages_ablation.log" + +echo "============================================================" +echo "DATA STAGES TEST SUITE" +echo "============================================================" +echo "" + +# Clean previous outputs +rm -rf "$STAGES_OUTPUT" "$BACKCOMPAT_OUTPUT" "$ABLATION_OUTPUT" +rm -f "$FULL_LOG" "$RESUME_LOG" "$BACKCOMPAT_LOG" "$ABLATION_LOG" + +############################################################################## +# Test 1: Backward Compatibility (no data_stages defined) +############################################################################## +echo "[Test 1] Backward Compatibility: No [[training.data_stages]]" +echo "------------------------------------------------------------" +echo "Config uses only [training] data fields, no data_stages." 
+echo "" + +CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --standalone \ + -m torchtitan.train --job.config_file "$BACKCOMPAT_CONFIG" 2>&1 | tee "$BACKCOMPAT_LOG" + +echo "" +echo "[Test 1] Verifying backward compatibility..." + +if grep -q "No \[\[training.data_stages\]\] defined. Auto-created single stage from \[training\] config" "$BACKCOMPAT_LOG"; then + echo " ✓ Auto-created 'default' stage from [training]" +else + echo " ✗ Failed to auto-create stage from [training]" + exit 1 +fi + +if grep -q "Stage 1: default" "$BACKCOMPAT_LOG"; then + echo " ✓ Stage named 'default'" +else + echo " ✗ Stage name incorrect" + exit 1 +fi + +if grep -q "Training completed" "$BACKCOMPAT_LOG"; then + echo " ✓ Training completed successfully" +else + echo " ✗ Training did not complete" + exit 1 +fi + +echo "" +echo "[Test 1] PASSED: Backward compatibility works" +echo "" + +############################################################################## +# Test 2: Multi-stage with transitions +############################################################################## +echo "[Test 2] Multi-Stage Training: Full run with 3 stages" +echo "------------------------------------------------------------" +echo "Config has 3 stages with transitions at step 5 and 10." +echo "" + +CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --standalone \ + -m torchtitan.train --job.config_file "$STAGES_CONFIG" 2>&1 | tee "$FULL_LOG" + +echo "" +echo "[Test 2] Verifying stage transitions occurred..." 
+ +if grep -q "stage_1_general.*stage_2_reasoning" "$FULL_LOG"; then + echo " ✓ Transition at step 5: stage_1_general -> stage_2_reasoning" +else + echo " ✗ Missing transition at step 5" + exit 1 +fi + +if grep -q "stage_2_reasoning.*stage_3_final" "$FULL_LOG"; then + echo " ✓ Transition at step 10: stage_2_reasoning -> stage_3_final" +else + echo " ✗ Missing transition at step 10" + exit 1 +fi + +echo "" +echo "[Test 2] PASSED: Stage transitions work correctly" +echo "" + +############################################################################## +# Test 3: Checkpoint resume +############################################################################## +echo "[Test 3] Checkpoint Resume: from step 7" +echo "------------------------------------------------------------" +echo "Resume from checkpoint at step 7, verify state restoration." +echo "" + +CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --standalone \ + -m torchtitan.train --job.config_file "$STAGES_CONFIG" \ + --checkpoint.load_step 7 2>&1 | tee "$RESUME_LOG" + +echo "" +echo "[Test 3] Verifying checkpoint restore..." 
+ +if grep -q "Checkpoint was at stage 'stage_2_reasoning'" "$RESUME_LOG"; then + echo " ✓ Stage correctly restored to stage_2_reasoning" +else + echo " ✗ Stage not correctly restored" + exit 1 +fi + +if grep -q "Restored dataloader position from checkpoint" "$RESUME_LOG"; then + echo " ✓ Dataloader position restored" +else + echo " ✗ Dataloader position not restored" + exit 1 +fi + +if grep -q "Training starts at step 8" "$RESUME_LOG"; then + echo " ✓ Training resumed at correct step (8)" +else + echo " ✗ Training did not resume at correct step" + exit 1 +fi + +echo "" +echo "[Test 3] PASSED: Checkpoint resume works correctly" +echo "" + +############################################################################## +# Test 4: Reproducibility (compare losses) +############################################################################## +echo "[Test 4] Reproducibility: Comparing losses between full and resumed runs" +echo "------------------------------------------------------------" + +# Extract losses from both runs (steps 8-15) +extract_losses() { + grep -oP "step:\s*\K\d+.*?loss:\s*[\d.]+" "$1" | \ + sed 's/\x1b\[[0-9;]*m//g' | \ + awk '{print $1, $3}' | \ + while read step loss; do + if [ "$step" -ge 8 ] && [ "$step" -le 15 ]; then + echo "$step $loss" + fi + done +} + +FULL_LOSSES=$(extract_losses "$FULL_LOG") +RESUME_LOSSES=$(extract_losses "$RESUME_LOG") + +echo "Step | Full Run | Resume | Match" +echo "------|----------|----------|------" + +MISMATCH=0 +for step in 8 9 10 11 12 13 14 15; do + full=$(echo "$FULL_LOSSES" | grep "^$step " | awk '{print $2}') + resume=$(echo "$RESUME_LOSSES" | grep "^$step " | awk '{print $2}') + + if [ -z "$full" ] || [ -z "$resume" ]; then + echo "$step | N/A | N/A | ?" 
+ continue + fi + + # Compare with tolerance (4 decimal places) + diff=$(echo "$full $resume" | awk '{printf "%.4f", ($1-$2)^2}') + if [ "$diff" = "0.0000" ]; then + match="✓" + else + match="✗" + MISMATCH=1 + fi + + printf "%-5s | %-8s | %-8s | %s\n" "$step" "$full" "$resume" "$match" +done + +echo "" +if [ "$MISMATCH" -eq 0 ]; then + echo "[Test 4] PASSED: Losses match exactly" +else + echo "[Test 4] FAILED: Losses do not match" + exit 1 +fi +echo "" + +############################################################################## +# Test 5: Ablation (stages start mid-training) +############################################################################## +echo "[Test 5] Ablation: Stages start at step 5 (dmayhem use case)" +echo "------------------------------------------------------------" +echo "Config has [training] data for steps 0-5, then [[training.data_stages]]" +echo "at step 5 with different random seed. Tests mid-training ablation." +echo "" + +CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --standalone \ + -m torchtitan.train --job.config_file "$ABLATION_CONFIG" 2>&1 | tee "$ABLATION_LOG" + +echo "" +echo "[Test 5] Verifying ablation setup..." 
+ +if grep -q "Auto-created 'default' stage from \[training\] for steps 0-5" "$ABLATION_LOG"; then + echo " ✓ Auto-created 'default' for gap (steps 0-5)" +else + echo " ✗ Failed to auto-create default" + exit 1 +fi + +if grep -q "Stage 1: default" "$ABLATION_LOG"; then + echo " ✓ First stage is 'default'" +else + echo " ✗ First stage not correctly named" + exit 1 +fi + +if grep -q "Stage 2: ablation_stage" "$ABLATION_LOG"; then + echo " ✓ Second stage is 'ablation_stage'" +else + echo " ✗ Second stage not correctly named" + exit 1 +fi + +if grep -q "default.*ablation_stage" "$ABLATION_LOG"; then + echo " ✓ Transition occurred: default -> ablation_stage" +else + echo " ✗ Transition did not occur" + exit 1 +fi + +if grep -q "Training completed" "$ABLATION_LOG"; then + echo " ✓ Training completed successfully" +else + echo " ✗ Training did not complete" + exit 1 +fi + +echo "" +echo "[Test 5] PASSED: Ablation mode works correctly" +echo "" + +############################################################################## +# Final Summary +############################################################################## +echo "============================================================" +echo "ALL TESTS PASSED!" 
+echo "============================================================" +echo "" +echo "Verified:" +echo " [Test 1] Backward compatibility - no data_stages" +echo " [Test 2] Multi-stage transitions" +echo " [Test 3] Checkpoint save/resume" +echo " [Test 4] Exact reproducibility on resume" +echo " [Test 5] Ablation mode (stages start mid-training)" +echo "" +echo "============================================================" + +# Cleanup +rm -rf "$STAGES_OUTPUT" "$BACKCOMPAT_OUTPUT" "$ABLATION_OUTPUT" +rm -f "$FULL_LOG" "$RESUME_LOG" "$BACKCOMPAT_LOG" "$ABLATION_LOG" + +exit 0 diff --git a/torchtitan/components/data_stages.py b/torchtitan/components/data_stages.py new file mode 100644 index 0000000000..2d93769c0c --- /dev/null +++ b/torchtitan/components/data_stages.py @@ -0,0 +1,747 @@ +# Copyright (c) Nous Research. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Multi-stage data training support. + +This module provides DataStageManager and StageAwareDataloader for training +with different data mixtures at different stages of training, similar to +approaches used in Qwen3, DeepSeek-V3, and Llama 3. + +Data stages are OPTIONAL. If no [[training.data_stages]] are defined, a single +stage is auto-created from [training] data fields for backward compatibility. +When stages ARE defined, they override [training] data fields completely. 
+ +Example usage (multi-stage): + [[training.data_stages]] + name = "general" + start_step = 0 + end_step = 100000 + dataset_type = "nanoset" + dataset_folders = ["/data/web", "/data/books", "/data/code"] + dataset_weights = [0.7, 0.2, 0.1] + seq_len = 4096 + + [[training.data_stages]] + name = "reasoning" + start_step = 100000 + dataset_type = "nanoset" + dataset_folders = ["/data/web", "/data/books", "/data/code"] + dataset_weights = [0.3, 0.35, 0.35] + seq_len = 4096 + +Example usage (single-stage, backward compatible): + [training] + dataset_type = "huggingface" + dataset = "c4_test" + seq_len = 4096 + # No [[training.data_stages]] needed - auto-created internally +""" + +import math +from dataclasses import dataclass +from typing import Any, Callable, Iterator + +from torchtitan.components.dataloader import BaseDataLoader +from torchtitan.config.job_config import DataStage, JobConfig +from torchtitan.tools.logging import logger + + +@dataclass +class EffectiveStageConfig: + """Resolved stage config. Each stage must define all required fields.""" + + dataset: str | None + dataset_path: str | None + dataset_type: str + dataset_folders: list[str] + dataset_weights: list[float] | None + dataset_random_seed: int + seq_len: int + + +class DataStageManager: + """Manages data stage transitions during training. + + Tracks current stage based on training step, handles stage transitions, + and builds dataloaders with stage-specific configurations. 
+ """ + + def __init__( + self, + job_config: JobConfig, + build_dataloader_fn: Callable, + dp_world_size: int, + dp_rank: int, + tokenizer: Any, + ): + self.job_config = job_config + self.build_dataloader_fn = build_dataloader_fn + self.dp_world_size = dp_world_size + self.dp_rank = dp_rank + self.tokenizer = tokenizer + + # Convert dicts to DataStage objects if needed (TOML parser returns dicts) + raw_stages = job_config.training.data_stages + self.stages: list[DataStage] = [] + for stage in raw_stages: + if isinstance(stage, dict): + self.stages.append(DataStage(**stage)) + else: + self.stages.append(stage) + + training = job_config.training + + # Case 1: No stages defined - auto-create from [training] (backward compat) + if not self.stages: + auto_stage = DataStage( + name="default", + start_step=0, + end_step=None, + dataset=training.dataset, + dataset_path=training.dataset_path, + dataset_type=training.dataset_type, + dataset_folders=training.dataset_folders, + dataset_weights=training.dataset_weights, + dataset_random_seed=training.dataset_random_seed, + seq_len=training.seq_len, + ) + self.stages.append(auto_stage) + logger.info( + "No [[training.data_stages]] defined. " + "Auto-created single stage from [training] config." + ) + else: + # Sort stages by start_step first to find the earliest + self.stages.sort(key=lambda s: s.start_step) + first_stage_start = self.stages[0].start_step + + # Check if [training] has non-default data fields + training_has_data = self._has_non_default_training_data(training) + + if first_stage_start == 0: + # Case 2: Stages cover from step 0 - warn if [training] has non-default data + if training_has_data: + raise ValueError( + "Cannot define data fields in both [training] and [[training.data_stages]].\n" + "Your [[training.data_stages]] starts at step 0, so it fully covers training.\n" + "Please either:\n" + " 1. Remove data fields from [training] and use [[training.data_stages]] only, OR\n" + " 2. 
Remove [[training.data_stages]] and use [training] only\n\n" + "Detected non-default [training] data fields:\n" + f"{self._format_training_data_fields(training)}" + ) + else: + # Case 3: Stages start after step 0 - auto-create stage from [training] + # for the gap (ablation/resume use case) + # Always use [training] values for the gap, even if they are defaults + auto_stage = DataStage( + name="default", + start_step=0, + end_step=first_stage_start, + dataset=training.dataset, + dataset_path=training.dataset_path, + dataset_type=training.dataset_type, + dataset_folders=training.dataset_folders, + dataset_weights=training.dataset_weights, + dataset_random_seed=training.dataset_random_seed, + seq_len=training.seq_len, + ) + self.stages.insert(0, auto_stage) + logger.info( + f"[[training.data_stages]] starts at step {first_stage_start}. " + f"Auto-created 'default' stage from [training] for steps 0-{first_stage_start}." + ) + + self._current_stage_idx = 0 + + # Sort stages by start_step for consistent ordering + self.stages.sort(key=lambda s: s.start_step) + self._validate_stages() + self._log_stage_plan() + + @property + def current_stage(self) -> DataStage: + """Get current stage config.""" + return self.stages[self._current_stage_idx] + + @property + def current_stage_idx(self) -> int: + """Get current stage index.""" + return self._current_stage_idx + + def _validate_stages(self) -> None: + """Validate stage configurations comprehensively.""" + training = self.job_config.training + total_steps = training.steps + + # Check for duplicate stage names + stage_names = [s.name for s in self.stages] + if len(stage_names) != len(set(stage_names)): + duplicates = [n for n in stage_names if stage_names.count(n) > 1] + raise ValueError(f"Duplicate stage names found: {set(duplicates)}") + + # Validate each stage + for i, stage in enumerate(self.stages): + # Validate stage name + if not stage.name or not stage.name.strip(): + raise ValueError(f"Stage at index {i} has empty or 
whitespace name") + + # Validate step ranges + if stage.start_step < 0: + raise ValueError( + f"Stage '{stage.name}' has invalid start_step: {stage.start_step}" + ) + if stage.end_step is not None and stage.end_step <= stage.start_step: + raise ValueError( + f"Stage '{stage.name}' has end_step ({stage.end_step}) <= " + f"start_step ({stage.start_step})" + ) + + # Validate required fields are set + if stage.dataset_type is None: + raise ValueError( + f"Stage '{stage.name}' must define 'dataset_type' " + "(e.g., 'huggingface', 'nanoset')" + ) + if stage.seq_len is None: + raise ValueError(f"Stage '{stage.name}' must define 'seq_len'") + if stage.seq_len <= 0: + raise ValueError( + f"Stage '{stage.name}' has invalid seq_len: {stage.seq_len}. " + "seq_len must be positive." + ) + + # Validate dataset_random_seed if provided + if stage.dataset_random_seed is not None and stage.dataset_random_seed < 0: + raise ValueError( + f"Stage '{stage.name}' has negative dataset_random_seed: " + f"{stage.dataset_random_seed}. Seeds must be non-negative." + ) + + # Validate start_step doesn't exceed total training steps + if stage.start_step >= total_steps: + raise ValueError( + f"Stage '{stage.name}' starts at step {stage.start_step} but " + f"training.steps is only {total_steps}. This stage would never run." 
+ ) + + # Validate dataset source based on type + if stage.dataset_type == "nanoset": + if not stage.dataset_folders: + raise ValueError( + f"Stage '{stage.name}' with dataset_type='nanoset' " + "must define 'dataset_folders'" + ) + # Validate no empty or whitespace-only folder paths + for j, folder in enumerate(stage.dataset_folders): + if not folder or not folder.strip(): + raise ValueError( + f"Stage '{stage.name}' has empty or whitespace-only path " + f"in dataset_folders at index {j}" + ) + elif stage.dataset_type == "huggingface": + if stage.dataset is None: + raise ValueError( + f"Stage '{stage.name}' with dataset_type='huggingface' " + "must define 'dataset'" + ) + # Validate dataset is not empty string + if not stage.dataset.strip(): + raise ValueError( + f"Stage '{stage.name}' has empty or whitespace-only 'dataset' value" + ) + + # Validate dataset_weights if provided + if stage.dataset_weights is not None: + if not stage.dataset_weights: + raise ValueError( + f"Stage '{stage.name}' has empty dataset_weights list" + ) + # Check for NaN or inf values + for j, w in enumerate(stage.dataset_weights): + if math.isnan(w): + raise ValueError( + f"Stage '{stage.name}' has NaN in dataset_weights " + f"at index {j}: {stage.dataset_weights}" + ) + if math.isinf(w): + raise ValueError( + f"Stage '{stage.name}' has infinity in dataset_weights " + f"at index {j}: {stage.dataset_weights}" + ) + if any(w < 0 for w in stage.dataset_weights): + raise ValueError( + f"Stage '{stage.name}' has negative dataset_weights: " + f"{stage.dataset_weights}" + ) + if any(w > 1 for w in stage.dataset_weights): + raise ValueError( + f"Stage '{stage.name}' has dataset_weight > 1: " + f"{stage.dataset_weights}" + ) + # Check weights sum to 1 (with tolerance for floating point) + weight_sum = sum(stage.dataset_weights) + if abs(weight_sum - 1.0) > 0.001: + raise ValueError( + f"Stage '{stage.name}' dataset_weights must sum to 1.0, " + f"but sum is {weight_sum:.6f}: {stage.dataset_weights}" 
+ ) + # Check weights match folders count for nanoset + if stage.dataset_type == "nanoset" and stage.dataset_folders: + if len(stage.dataset_weights) != len(stage.dataset_folders): + raise ValueError( + f"Stage '{stage.name}' has {len(stage.dataset_weights)} " + f"weights but {len(stage.dataset_folders)} folders" + ) + + # Check first stage starts at step 0 + if self.stages[0].start_step != 0: + raise ValueError( + f"First stage '{self.stages[0].name}' must start at step 0, " + f"but starts at {self.stages[0].start_step}" + ) + + # Check for gaps or overlaps between stages + for i in range(len(self.stages) - 1): + current = self.stages[i] + next_stage = self.stages[i + 1] + + # Determine current stage's end + if current.end_step is not None: + current_end = current.end_step + else: + # If no end_step, it should extend to next stage's start + current_end = next_stage.start_step + + if current_end < next_stage.start_step: + raise ValueError( + f"Gap in data stages: '{current.name}' ends at {current_end} " + f"but '{next_stage.name}' starts at {next_stage.start_step}. " + f"Steps {current_end} to {next_stage.start_step - 1} are not covered." + ) + elif current_end > next_stage.start_step: + raise ValueError( + f"Overlap in data stages: '{current.name}' ends at {current_end} " + f"but '{next_stage.name}' starts at {next_stage.start_step}. " + f"Steps {next_stage.start_step} to {current_end - 1} are covered " + f"by both stages." + ) + + # Check last stage covers until training end + last_stage = self.stages[-1] + if last_stage.end_step is not None: + if last_stage.end_step < total_steps: + raise ValueError( + f"Last stage '{last_stage.name}' ends at step {last_stage.end_step} " + f"but training.steps is {total_steps}. " + f"Steps {last_stage.end_step} to {total_steps - 1} are not covered. " + f"Remove 'end_step' from the last stage to cover until training end." 
+ ) + elif last_stage.end_step > total_steps: + logger.warning( + f"Last stage '{last_stage.name}' end_step ({last_stage.end_step}) " + f"exceeds training.steps ({total_steps}). " + f"Training will end at step {total_steps}." + ) + + def _has_non_default_training_data(self, training) -> bool: + """Check if [training] has non-default data fields set.""" + # Default values from Training dataclass + defaults = { + "dataset": "c4_test", + "dataset_path": None, + "dataset_type": "huggingface", + "dataset_folders": [], + "dataset_weights": None, + # Note: dataset_random_seed and seq_len are not checked since they + # have valid defaults that users commonly keep + } + return ( + training.dataset != defaults["dataset"] + or training.dataset_path != defaults["dataset_path"] + or training.dataset_type != defaults["dataset_type"] + or training.dataset_folders != defaults["dataset_folders"] + or training.dataset_weights != defaults["dataset_weights"] + ) + + def _format_training_data_fields(self, training) -> str: + """Format [training] data fields for error message.""" + lines = [] + if training.dataset != "c4_test": + lines.append(f" dataset = {training.dataset!r}") + if training.dataset_path is not None: + lines.append(f" dataset_path = {training.dataset_path!r}") + if training.dataset_type != "huggingface": + lines.append(f" dataset_type = {training.dataset_type!r}") + if training.dataset_folders: + lines.append(f" dataset_folders = {training.dataset_folders!r}") + if training.dataset_weights is not None: + lines.append(f" dataset_weights = {training.dataset_weights!r}") + return "\n".join(lines) if lines else " (none detected)" + + def _log_stage_plan(self) -> None: + """Log the data stage training plan.""" + logger.info("=" * 60) + logger.info("DATA STAGE TRAINING PLAN") + logger.info("=" * 60) + logger.info(f"Total stages: {len(self.stages)}") + logger.info("") + + training = self.job_config.training + total_steps = training.steps + + for i, stage in 
enumerate(self.stages): + effective = self.get_effective_config(stage) + end_step = stage.end_step if stage.end_step is not None else total_steps + stage_steps = end_step - stage.start_step + + # Calculate tokens for this stage + # tokens = steps * global_batch_size * seq_len + # Note: global_batch_size may be -1 (auto), so we show what we can + global_bs = training.global_batch_size + if global_bs > 0: + tokens = stage_steps * global_bs * effective.seq_len + token_str = f"{tokens / 1e9:.2f}B tokens" + else: + token_str = ( + f"{stage_steps} steps × batch_size × {effective.seq_len} seq_len" + ) + + logger.info(f"Stage {i + 1}: {stage.name}") + logger.info( + f" Steps: {stage.start_step:,} -> {end_step:,} ({stage_steps:,} steps)" + ) + logger.info(f" Estimated tokens: {token_str}") + logger.info(f" Dataset type: {effective.dataset_type}") + + if effective.dataset_folders: + logger.info( + f" Dataset folders: {len(effective.dataset_folders)} folders" + ) + for folder in effective.dataset_folders[:3]: # Show first 3 + logger.info(f" - {folder}") + if len(effective.dataset_folders) > 3: + logger.info( + f" ... and {len(effective.dataset_folders) - 3} more" + ) + else: + logger.info(f" Dataset: {effective.dataset}") + + if effective.dataset_weights: + weights_str = ", ".join( + f"{w:.3f}" for w in effective.dataset_weights[:5] + ) + if len(effective.dataset_weights) > 5: + weights_str += f", ... ({len(effective.dataset_weights)} total)" + logger.info(f" Weights: [{weights_str}]") + + logger.info(f" Sequence length: {effective.seq_len}") + logger.info("") + + logger.info("=" * 60) + + def get_effective_config(self, stage: DataStage) -> EffectiveStageConfig: + """Get effective config for a stage. 
Each stage is self-contained.""" + # Use training.dataset_random_seed as fallback since it's optional + training = self.job_config.training + return EffectiveStageConfig( + dataset=stage.dataset, + dataset_path=stage.dataset_path, + dataset_type=stage.dataset_type, + dataset_folders=stage.dataset_folders, + dataset_weights=stage.dataset_weights, + dataset_random_seed=( + stage.dataset_random_seed + if stage.dataset_random_seed is not None + else training.dataset_random_seed + ), + seq_len=stage.seq_len, + ) + + def find_stage_for_step(self, step: int) -> int: + """Find the stage index for the given training step.""" + if step < 0: + raise ValueError(f"Step cannot be negative, got {step}") + for i, stage in enumerate(self.stages): + in_range = step >= stage.start_step + if stage.end_step is not None: + in_range = in_range and step < stage.end_step + elif i + 1 < len(self.stages): + # If no end_step, use next stage's start_step + in_range = in_range and step < self.stages[i + 1].start_step + if in_range: + return i + # Default to last stage if step exceeds all ranges + return len(self.stages) - 1 + + def set_stage_for_step(self, step: int) -> bool: + """Set current stage based on step. Returns True if stage changed.""" + new_idx = self.find_stage_for_step(step) + if new_idx != self._current_stage_idx: + old_stage = self.stages[self._current_stage_idx] + self._current_stage_idx = new_idx + new_stage = self.stages[new_idx] + return True + return False + + def maybe_transition_stage(self, step: int) -> bool: + """Check if stage transition needed at this step. 
Returns True if transitioned.""" + new_idx = self.find_stage_for_step(step) + if new_idx != self._current_stage_idx: + old_stage = self.stages[self._current_stage_idx] + new_stage = self.stages[new_idx] + + logger.info("=" * 60) + logger.info("DATA STAGE TRANSITION") + logger.info("=" * 60) + logger.info(f"Step {step}: '{old_stage.name}' -> '{new_stage.name}'") + + old_effective = self.get_effective_config(old_stage) + new_effective = self.get_effective_config(new_stage) + + # Log what changed + changes = [] + if old_effective.dataset_weights != new_effective.dataset_weights: + changes.append("dataset_weights") + if old_effective.dataset_folders != new_effective.dataset_folders: + changes.append("dataset_folders") + if old_effective.seq_len != new_effective.seq_len: + changes.append( + f"seq_len: {old_effective.seq_len} -> {new_effective.seq_len}" + ) + + if changes: + logger.info(f"Changes: {', '.join(changes)}") + else: + logger.info("No config changes (stage name only)") + + if new_effective.dataset_weights: + weights_str = ", ".join( + f"{w:.3f}" for w in new_effective.dataset_weights[:5] + ) + if len(new_effective.dataset_weights) > 5: + weights_str += f", ... 
({len(new_effective.dataset_weights)} total)" + logger.info(f"New weights: [{weights_str}]") + + logger.info("=" * 60) + + self._current_stage_idx = new_idx + return True + return False + + def build_dataloader_for_stage( + self, stage_idx: int | None = None + ) -> BaseDataLoader: + """Build dataloader for the specified stage (or current stage if None).""" + if stage_idx is None: + stage_idx = self._current_stage_idx + + stage = self.stages[stage_idx] + effective = self.get_effective_config(stage) + + logger.info(f"Building dataloader for stage '{stage.name}' (idx={stage_idx})") + + # Temporarily override training config with stage-specific values + training = self.job_config.training + original_values = {} + override_fields = [ + ("dataset", effective.dataset), + ("dataset_path", effective.dataset_path), + ("dataset_type", effective.dataset_type), + ("dataset_folders", effective.dataset_folders), + ("dataset_weights", effective.dataset_weights), + ("dataset_random_seed", effective.dataset_random_seed), + ("seq_len", effective.seq_len), + ] + + for field_name, new_value in override_fields: + original_values[field_name] = getattr(training, field_name) + setattr(training, field_name, new_value) + + try: + dataloader = self.build_dataloader_fn( + dp_world_size=self.dp_world_size, + dp_rank=self.dp_rank, + tokenizer=self.tokenizer, + job_config=self.job_config, + ) + finally: + # Restore original config values + for field_name, original_value in original_values.items(): + setattr(training, field_name, original_value) + + return dataloader + + def state_dict(self) -> dict[str, Any]: + """Return state for checkpointing.""" + return {"current_stage_idx": self._current_stage_idx} + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """Restore state from checkpoint.""" + if "current_stage_idx" in state_dict: + old_idx = self._current_stage_idx + self._current_stage_idx = state_dict["current_stage_idx"] + if old_idx != self._current_stage_idx: + logger.info( 
+ f"Restored data stage index: {old_idx} -> {self._current_stage_idx} " + f"(stage: '{self.current_stage.name}')" + ) + + +class StageAwareDataloader(BaseDataLoader): + """Dataloader wrapper that handles multi-stage training with proper checkpoint support. + + This wrapper: + 1. Manages the underlying dataloader for the current stage + 2. Saves/restores both stage index AND dataloader state for exact checkpoint resume + 3. Rebuilds dataloader on stage transitions + + The key insight for checkpoint correctness: + - When saving: we save {stage_idx, dataloader_state_for_current_stage} + - When loading: we restore stage_idx, rebuild dataloader for that stage, + then restore the dataloader's internal state + + This ensures exact resume: same stage, same position within the dataset. + """ + + def __init__( + self, + stage_manager: DataStageManager, + initial_dataloader: BaseDataLoader, + ): + self._stage_manager = stage_manager + self._dataloader = initial_dataloader + self._dp_rank = stage_manager.dp_rank + self._dp_world_size = stage_manager.dp_world_size + + @property + def dataloader(self) -> BaseDataLoader: + """Get the underlying dataloader.""" + return self._dataloader + + def rebuild_for_current_stage(self) -> None: + """Rebuild the underlying dataloader for the current stage.""" + self._dataloader = self._stage_manager.build_dataloader_for_stage() + + def maybe_transition(self, step: int) -> bool: + """Check for stage transition and rebuild if needed. Returns True if transitioned.""" + if self._stage_manager.maybe_transition_stage(step): + self.rebuild_for_current_stage() + return True + return False + + def __iter__(self) -> Iterator: + """Iterate over the underlying dataloader.""" + return iter(self._dataloader) + + def __len__(self) -> int: + """Return length of underlying dataloader if available.""" + return len(self._dataloader) + + def state_dict(self) -> dict[str, Any]: + """Save state for checkpointing. 
+ + Saves: + - stage_idx: Which stage we're in + - dataloader_state: Position within current stage's dataset + - dp_world_size: For validation on resume + """ + return { + "stage_idx": self._stage_manager.current_stage_idx, + "stage_name": self._stage_manager.current_stage.name, + "dataloader_state": self._dataloader.state_dict(), + "world_size": self._dp_world_size, + } + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """Restore state from checkpoint. + + This is the critical method for exact checkpoint resume: + 1. Restore stage_idx to stage manager + 2. Rebuild dataloader for the correct stage + 3. Restore dataloader's internal state (position in dataset) + """ + if not state_dict: + return + + # Validate world size consistency + if "world_size" in state_dict: + saved_world_size = state_dict["world_size"] + if saved_world_size != self._dp_world_size: + raise ValueError( + f"Data parallel world size changed from {saved_world_size} to " + f"{self._dp_world_size}. Dataloader state is incompatible." + ) + + # Restore stage index + if "stage_idx" in state_dict: + saved_stage_idx = state_dict["stage_idx"] + saved_stage_name = state_dict.get("stage_name", "unknown") + num_stages = len(self._stage_manager.stages) + + # Validate stage_idx is within bounds + if saved_stage_idx < 0 or saved_stage_idx >= num_stages: + raise ValueError( + f"Checkpoint stage_idx ({saved_stage_idx}) is out of bounds. " + f"Current config has {num_stages} stages (indices 0-{num_stages - 1}). " + f"Checkpoint was at stage '{saved_stage_name}'. " + "The stage configuration may have changed since the checkpoint was saved." + ) + + current_stage_idx = self._stage_manager.current_stage_idx + + if saved_stage_idx != current_stage_idx: + logger.info( + f"Checkpoint was at stage '{saved_stage_name}' (idx={saved_stage_idx}), " + f"rebuilding dataloader..." 
+ ) + # Update stage manager's index + self._stage_manager._current_stage_idx = saved_stage_idx + # Rebuild dataloader for the restored stage + self.rebuild_for_current_stage() + + # Restore dataloader state (position in dataset) + if "dataloader_state" in state_dict: + try: + self._dataloader.load_state_dict(state_dict["dataloader_state"]) + logger.info("Restored dataloader position from checkpoint") + except Exception as e: + logger.warning( + f"Failed to restore dataloader state: {e}. " + "Training will resume from beginning of current stage's dataset." + ) + + +def build_stage_aware_dataloader( + job_config: JobConfig, + build_dataloader_fn: Callable, + dp_world_size: int, + dp_rank: int, + tokenizer: Any, +) -> tuple[StageAwareDataloader, DataStageManager]: + """Build a stage-aware dataloader. + + Data stages are required. At least one stage must be defined in + [[training.data_stages]]. + + Returns: + tuple of (StageAwareDataloader, DataStageManager) + """ + stage_manager = DataStageManager( + job_config=job_config, + build_dataloader_fn=build_dataloader_fn, + dp_world_size=dp_world_size, + dp_rank=dp_rank, + tokenizer=tokenizer, + ) + + # Build initial dataloader for stage 0 (or whichever stage step 0 falls into) + initial_dataloader = stage_manager.build_dataloader_for_stage() + dataloader = StageAwareDataloader(stage_manager, initial_dataloader) + logger.info(f"Created StageAwareDataloader with {len(stage_manager.stages)} stages") + + return dataloader, stage_manager diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 8a19466d63..386a189cac 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -290,6 +290,55 @@ class LRScheduler: """ +@dataclass +class DataStage: + """Configuration for a single training data stage. + + Each stage can override data-related settings from the Training config. + Fields set to None or empty will inherit from the base Training config. 
+ + Example TOML: + [[training.data_stages]] + name = "general" + start_step = 0 + end_step = 100000 + dataset_weights = [0.7, 0.2, 0.1] + """ + + name: str = "stage_0" + """Name of the stage for logging and identification""" + + start_step: int = 0 + """Step at which this stage starts (inclusive)""" + + end_step: int | None = None + """Step at which this stage ends (exclusive). None means until training end or next stage""" + + # Data config fields - mirror Training config, None means inherit + dataset: str | None = None + """Dataset to use. None = inherit from training config""" + + dataset_path: str | None = None + """Path to the dataset. None = inherit from training config""" + + dataset_type: Literal[ + "huggingface", "nanoset", "preprocessed", "packed_memmap" + ] | None = None + """Type of dataset. None = inherit from training config""" + + dataset_folders: list[str] = field(default_factory=list) + """List of folders for Nanoset. Empty = inherit from training config""" + + dataset_weights: list[float] | None = None + """Weights for blending datasets. None = inherit from training config""" + + dataset_random_seed: int | None = None + """Random seed for this stage. None = inherit from training config""" + + seq_len: int | None = None + """Sequence length for this stage. None = inherit from training config""" + + @dataclass class Training: dataset: str = "c4_test" @@ -372,6 +421,28 @@ class Training: many temporary files. """ + data_stages: list[DataStage] = field(default_factory=list) + """ + List of data stages for multi-stage training. Each stage can override + dataset configs (dataset, weights, seq_len, etc.) for a range of training steps. + If empty, single-stage training is used with the base Training config. + + Stages are defined by start_step and end_step. The stage active at any step + is determined by which stage's range contains that step. 
+ + Example TOML: + [[training.data_stages]] + name = "general" + start_step = 0 + end_step = 100000 + dataset_weights = [0.7, 0.2, 0.1] + + [[training.data_stages]] + name = "reasoning" + start_step = 100000 + dataset_weights = [0.3, 0.35, 0.35] + """ + @dataclass class Parallelism: diff --git a/torchtitan/train.py b/torchtitan/train.py index cde676431c..90a13636ca 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -16,6 +16,10 @@ import torchtitan.protocols.train_spec as train_spec_module from torchtitan.components.checkpoint import CheckpointManager +from torchtitan.components.data_stages import ( + build_stage_aware_dataloader, + DataStageManager, +) from torchtitan.components.dataloader import DataloaderExhaustedError from torchtitan.components.ft import FTManager, maybe_semi_sync_training from torchtitan.components.loss import rescale_accumulated_loss @@ -54,6 +58,7 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful): # non-swappable training components checkpointer: CheckpointManager ft_manager: FTManager + data_stage_manager: DataStageManager # runtime utilities device: torch.device @@ -121,11 +126,13 @@ def __init__(self, job_config: JobConfig): else None ) - self.dataloader = self.train_spec.build_dataloader_fn( + # Build dataloader with data stage support + self.dataloader, self.data_stage_manager = build_stage_aware_dataloader( + job_config=job_config, + build_dataloader_fn=self.train_spec.build_dataloader_fn, dp_world_size=dp_degree, dp_rank=dp_rank, tokenizer=self.tokenizer, - job_config=job_config, ) # build model (using meta init) @@ -784,13 +791,21 @@ def train(self): ), ), ): - first_step_save = self.step == 0 and job_config.checkpoint.enable_first_step_checkpoint + first_step_save = ( + self.step == 0 and job_config.checkpoint.enable_first_step_checkpoint + ) if first_step_save: self.checkpointer.save(1, False) data_iterator = self.batch_generator(self.dataloader) while self.should_continue_training(): self.step += 1 + + # 
Check for data stage transition + if self.dataloader.maybe_transition(self.step): + # Rebuild iterator after stage transition + data_iterator = self.batch_generator(self.dataloader) + self.gc_handler.run(self.step) try: self.train_step(data_iterator)