From 3209cabd4b336852cec1b077db5e7221944bc564 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Fri, 9 Jan 2026 18:44:02 +0000 Subject: [PATCH 1/4] add online evals --- scripts/eval/run_lm_eval.py | 385 ++++++++++++++ torchtitan/components/lm_evaluator.py | 695 ++++++++++++++++++++++++++ torchtitan/config/job_config.py | 145 ++++++ torchtitan/train.py | 32 +- 4 files changed, 1256 insertions(+), 1 deletion(-) create mode 100755 scripts/eval/run_lm_eval.py create mode 100644 torchtitan/components/lm_evaluator.py diff --git a/scripts/eval/run_lm_eval.py b/scripts/eval/run_lm_eval.py new file mode 100755 index 0000000000..3ce2666e03 --- /dev/null +++ b/scripts/eval/run_lm_eval.py @@ -0,0 +1,385 @@ +#!/usr/bin/env python +# Copyright (c) Nous Research. +# All rights reserved. + +""" +Standalone script for running lm-evaluation-harness on torchtitan checkpoints. + +This script provides a command-line interface for running evaluations with +full reproducibility through seed control and configuration logging. + +Usage: + python scripts/eval/run_lm_eval.py \ + --checkpoint /path/to/checkpoint \ + --model_name qwen3 \ + --model_flavor 10B-A1B \ + --tasks hellaswag,arc_easy,arc_challenge,winogrande \ + --output_dir ./eval_results \ + --seed 42 + +For full reproducibility, all configuration is saved alongside results. 
+""" + +from __future__ import annotations + +import argparse +import json +import os +import random +import sys +from datetime import datetime +from pathlib import Path + + +def setup_paths(torchtitan_path: str | None, lm_eval_path: str | None) -> None: + """Setup Python paths for imports.""" + if torchtitan_path and torchtitan_path not in sys.path: + sys.path.insert(0, torchtitan_path) + + if lm_eval_path and lm_eval_path not in sys.path: + sys.path.insert(0, lm_eval_path) + + # Auto-detect torchtitan path if not provided + if not torchtitan_path: + possible_paths = [ + Path(__file__).parent.parent.parent, # scripts/eval/run_lm_eval.py + Path("/home/phuc/workspace/moe/online_evals/torchtitan"), + ] + for path in possible_paths: + if (path / "torchtitan").exists(): + sys.path.insert(0, str(path)) + break + + +def parse_args() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Run lm-evaluation-harness on torchtitan checkpoints", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Basic usage + python run_lm_eval.py --checkpoint /path/to/checkpoint --model_name qwen3 --model_flavor 10B-A1B + + # With custom tasks and output + python run_lm_eval.py \\ + --checkpoint /path/to/checkpoint \\ + --model_name qwen3 \\ + --model_flavor 10B-A1B \\ + --tasks hellaswag,arc_easy,arc_challenge,winogrande \\ + --output_dir ./my_eval_results \\ + --seed 42 + + # Quick test with limited samples + python run_lm_eval.py \\ + --checkpoint /path/to/checkpoint \\ + --model_name llama3 \\ + --model_flavor 8B \\ + --tasks hellaswag \\ + --limit 100 + """, + ) + + # Required arguments + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to checkpoint directory (HF safetensors format)", + ) + parser.add_argument( + "--model_name", + type=str, + required=True, + choices=["llama3", "qwen3"], + help="Model architecture name", + ) + parser.add_argument( + "--model_flavor", 
+ type=str, + required=True, + help="Model flavor/size (e.g., 8B, 70B, 10B-A1B)", + ) + + # Optional arguments + parser.add_argument( + "--tokenizer_path", + type=str, + default=None, + help="Path to tokenizer (defaults to checkpoint path)", + ) + parser.add_argument( + "--tasks", + type=str, + default="hellaswag,arc_easy,arc_challenge,winogrande", + help="Comma-separated list of evaluation tasks", + ) + parser.add_argument( + "--num_fewshot", + type=int, + default=0, + help="Number of few-shot examples", + ) + parser.add_argument( + "--batch_size", + type=int, + default=4, + help="Batch size for evaluation", + ) + parser.add_argument( + "--max_seq_len", + type=int, + default=2048, + help="Maximum sequence length", + ) + parser.add_argument( + "--limit", + type=int, + default=None, + help="Limit number of samples per task (None = full evaluation)", + ) + parser.add_argument( + "--output_dir", + type=str, + default="./eval_results", + help="Directory to save evaluation results", + ) + parser.add_argument( + "--log_samples", + action="store_true", + default=True, + help="Log individual sample predictions", + ) + parser.add_argument( + "--no_log_samples", + action="store_false", + dest="log_samples", + help="Don't log individual sample predictions", + ) + + # Seed arguments for reproducibility + parser.add_argument( + "--seed", + type=int, + default=42, + help="Base random seed for all RNGs", + ) + parser.add_argument( + "--random_seed", + type=int, + default=None, + help="Override Python random seed", + ) + parser.add_argument( + "--numpy_seed", + type=int, + default=None, + help="Override NumPy random seed", + ) + parser.add_argument( + "--torch_seed", + type=int, + default=None, + help="Override PyTorch random seed", + ) + parser.add_argument( + "--fewshot_seed", + type=int, + default=None, + help="Override fewshot sampler seed", + ) + + # Path configuration + parser.add_argument( + "--torchtitan_path", + type=str, + default=None, + help="Path to torchtitan 
installation", + ) + parser.add_argument( + "--lm_eval_path", + type=str, + default=None, + help="Path to lm-evaluation-harness installation", + ) + + # Other options + parser.add_argument( + "--dtype", + type=str, + default="bfloat16", + choices=["bfloat16", "float16", "float32"], + help="Data type for model", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device to run evaluation on", + ) + + return parser.parse_args() + + +def get_seeds(args: argparse.Namespace) -> tuple[int, int, int, int]: + """Get seeds from args, falling back to base seed.""" + return ( + args.random_seed if args.random_seed is not None else args.seed, + args.numpy_seed if args.numpy_seed is not None else args.seed, + args.torch_seed if args.torch_seed is not None else args.seed, + args.fewshot_seed if args.fewshot_seed is not None else args.seed, + ) + + +def get_seed_string(args: argparse.Namespace) -> str: + """Get seed string for lm_eval CLI format.""" + seeds = get_seeds(args) + return f"{seeds[0]},{seeds[1]},{seeds[2]},{seeds[3]}" + + +def save_eval_config(args: argparse.Namespace, output_dir: Path) -> None: + """Save evaluation configuration for reproducibility.""" + seeds = get_seeds(args) + + config = { + "timestamp": datetime.now().isoformat(), + "checkpoint": args.checkpoint, + "tokenizer_path": args.tokenizer_path or args.checkpoint, + "model_name": args.model_name, + "model_flavor": args.model_flavor, + "tasks": args.tasks, + "num_fewshot": args.num_fewshot, + "batch_size": args.batch_size, + "max_seq_len": args.max_seq_len, + "limit": args.limit, + "dtype": args.dtype, + "device": args.device, + "seeds": { + "random_seed": seeds[0], + "numpy_seed": seeds[1], + "torch_seed": seeds[2], + "fewshot_seed": seeds[3], + }, + "command": " ".join(sys.argv), + } + + config_path = output_dir / "eval_config.json" + with open(config_path, "w") as f: + json.dump(config, f, indent=2) + + print(f"Configuration saved to: {config_path}") + + +def main() -> None: + 
"""Main entry point.""" + args = parse_args() + + # Setup paths + setup_paths(args.torchtitan_path, args.lm_eval_path) + + import lm_eval + + # Import after path setup + import numpy as np + import torch + + # Create output directory + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Get seeds + seeds = get_seeds(args) + random_seed, numpy_seed, torch_seed, fewshot_seed = seeds + + # Set seeds for reproducibility + random.seed(random_seed) + np.random.seed(numpy_seed) + torch.manual_seed(torch_seed) + + # Print configuration + print("=" * 60) + print("LM-EVALUATION-HARNESS") + print("=" * 60) + print(f"Checkpoint: {args.checkpoint}") + print(f"Model: {args.model_name} ({args.model_flavor})") + print(f"Tasks: {args.tasks}") + print(f"Batch size: {args.batch_size}") + print(f"Max seq len: {args.max_seq_len}") + print(f"Limit: {args.limit or 'None (full evaluation)'}") + print( + f"Seeds: random={random_seed}, numpy={numpy_seed}, torch={torch_seed}, fewshot={fewshot_seed}" + ) + print(f"Output: {output_dir}") + print("=" * 60) + + # Save configuration + save_eval_config(args, output_dir) + + # Build model args + tokenizer_path = args.tokenizer_path or args.checkpoint + model_args = ( + f"pretrained={args.checkpoint}," + f"tokenizer_path={tokenizer_path}," + f"model_name={args.model_name}," + f"model_flavor={args.model_flavor}," + f"dtype={args.dtype}," + f"max_seq_len={args.max_seq_len}," + f"device={args.device}" + ) + + print(f"\nModel args: {model_args}") + print("\nStarting evaluation...") + + # Run evaluation + results = lm_eval.simple_evaluate( + model="torchtitan", + model_args=model_args, + tasks=args.tasks.split(","), + num_fewshot=args.num_fewshot, + limit=args.limit, + batch_size=args.batch_size, + log_samples=args.log_samples, + ) + + # Save results + results_path = output_dir / "results.json" + with open(results_path, "w") as f: + json.dump(results, f, indent=2, default=str) + + print(f"\nResults saved to: 
{results_path}") + + # Print summary + print("\n" + "=" * 60) + print("RESULTS SUMMARY") + print("=" * 60) + + for task_name, task_results in results.get("results", {}).items(): + print(f"\n{task_name}:") + for metric, value in task_results.items(): + if isinstance(value, (int, float)): + print(f" {metric}: {value:.4f}") + + # Save summary to text file + summary_path = output_dir / "summary.txt" + with open(summary_path, "w") as f: + f.write(f"LM-Evaluation-Harness Results\n") + f.write(f"Generated: {datetime.now().isoformat()}\n") + f.write(f"Checkpoint: {args.checkpoint}\n") + f.write(f"Model: {args.model_name} ({args.model_flavor})\n") + f.write(f"Tasks: {args.tasks}\n") + f.write(f"Seeds: {get_seed_string(args)}\n\n") + + for task_name, task_results in results.get("results", {}).items(): + f.write(f"\n{task_name}:\n") + for metric, value in task_results.items(): + if isinstance(value, (int, float)): + f.write(f" {metric}: {value:.4f}\n") + + print(f"\nSummary saved to: {summary_path}") + print("\n" + "=" * 60) + print("Evaluation complete!") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/torchtitan/components/lm_evaluator.py b/torchtitan/components/lm_evaluator.py new file mode 100644 index 0000000000..932b46adbc --- /dev/null +++ b/torchtitan/components/lm_evaluator.py @@ -0,0 +1,695 @@ +# Copyright (c) Nous Research. +# All rights reserved. + +""" +LM-Evaluation-Harness integration for torchtitan. + +This module provides automatic evaluation during training using +lm-evaluation-harness with full reproducibility through seed control +and configuration logging. 
+""" + +from __future__ import annotations + +import json +import os +import subprocess +import sys +from dataclasses import asdict +from datetime import datetime +from pathlib import Path +from typing import Any, TYPE_CHECKING + +from torchtitan.tools.logging import logger + +if TYPE_CHECKING: + from torchtitan.config.job_config import JobConfig, LMEvalConfig + + +class LMEvaluator: + """ + Evaluator component for running lm-evaluation-harness during training. + + Supports three execution modes: + - inline: Run in same process (blocks training) + - subprocess: Run in background subprocess (non-blocking) + - slurm: Submit as SLURM job (fully async) + + All modes support full reproducibility through: + - Seed control (4 independent seeds) + - Configuration logging + - Generated SLURM scripts for audit trail + """ + + def __init__(self, job_config: JobConfig, rank: int = 0) -> None: + """ + Initialize the LM Evaluator. + + Args: + job_config: The job configuration containing lm_eval settings + rank: Current distributed rank (eval only runs on rank 0) + """ + self.job_config = job_config + self.lm_eval_config = job_config.lm_eval + self.rank = rank + + # Only rank 0 runs evaluations + self.enabled = self.lm_eval_config.enable and rank == 0 + + if not self.enabled: + return + + # Setup directories + self.dump_folder = Path(job_config.job.dump_folder) + self.output_dir = self.dump_folder / self.lm_eval_config.output_dir + self.slurm_script_dir = self.dump_folder / self.lm_eval_config.slurm_script_dir + self.slurm_logs_dir = self.dump_folder / self.lm_eval_config.slurm_logs_dir + + # Create directories + self.output_dir.mkdir(parents=True, exist_ok=True) + if self.lm_eval_config.mode == "slurm": + self.slurm_script_dir.mkdir(parents=True, exist_ok=True) + self.slurm_logs_dir.mkdir(parents=True, exist_ok=True) + + # Resolve paths + self.torchtitan_path = self._resolve_torchtitan_path() + self.lm_eval_path = self.lm_eval_config.lm_eval_path + + # Track running jobs + 
self.running_jobs: dict[int, dict[str, Any]] = {} + + logger.info( + f"LMEvaluator initialized: mode={self.lm_eval_config.mode}, " + f"tasks={self.lm_eval_config.tasks}, interval={self.lm_eval_config.eval_interval}" + ) + + def _resolve_torchtitan_path(self) -> str: + """Resolve the torchtitan installation path.""" + if self.lm_eval_config.torchtitan_path: + return self.lm_eval_config.torchtitan_path + + # Try to find torchtitan from current module path + possible_paths = [ + Path(__file__).parent.parent.parent, # From components/lm_evaluator.py + Path("/home/phuc/workspace/moe/online_evals/torchtitan"), + ] + + for path in possible_paths: + if (path / "torchtitan").exists(): + return str(path) + + # Fallback to assuming it's in PYTHONPATH + return "" + + def should_evaluate(self, step: int) -> bool: + """ + Check if evaluation should run at this step. + + Args: + step: Current training step + + Returns: + True if evaluation should run + """ + if not self.enabled: + return False + + if step == 0: + return False + + return step % self.lm_eval_config.eval_interval == 0 + + def run_evaluation( + self, + step: int, + checkpoint_path: str, + ) -> dict[str, Any] | None: + """ + Run evaluation for the given checkpoint. 
+ + Args: + step: Current training step + checkpoint_path: Path to the checkpoint to evaluate + + Returns: + Evaluation results dict (inline mode) or job info (subprocess/slurm mode) + """ + if not self.enabled: + return None + + logger.info(f"Starting evaluation at step {step}") + + # Create step-specific output directory + step_output_dir = self.output_dir / f"step_{step}" + step_output_dir.mkdir(parents=True, exist_ok=True) + + # Save evaluation config for reproducibility + self._save_eval_config(step, checkpoint_path, step_output_dir) + + if self.lm_eval_config.mode == "inline": + return self._run_inline(step, checkpoint_path, step_output_dir) + elif self.lm_eval_config.mode == "subprocess": + return self._run_subprocess(step, checkpoint_path, step_output_dir) + elif self.lm_eval_config.mode == "slurm": + return self._run_slurm(step, checkpoint_path, step_output_dir) + else: + raise ValueError(f"Unknown eval mode: {self.lm_eval_config.mode}") + + def _save_eval_config( + self, + step: int, + checkpoint_path: str, + output_dir: Path, + ) -> None: + """Save evaluation configuration for reproducibility.""" + eval_info = { + "step": step, + "checkpoint_path": checkpoint_path, + "timestamp": datetime.now().isoformat(), + "lm_eval_config": asdict(self.lm_eval_config), + "model_config": { + "name": self.job_config.model.name, + "flavor": self.job_config.model.flavor, + }, + "seeds": { + "random_seed": self.lm_eval_config.get_seeds()[0], + "numpy_seed": self.lm_eval_config.get_seeds()[1], + "torch_seed": self.lm_eval_config.get_seeds()[2], + "fewshot_seed": self.lm_eval_config.get_seeds()[3], + }, + } + + config_path = output_dir / "eval_config.json" + with open(config_path, "w") as f: + json.dump(eval_info, f, indent=2, default=str) + + logger.info(f"Saved eval config to {config_path}") + + def _get_eval_command( + self, + checkpoint_path: str, + output_dir: Path, + ) -> list[str]: + """Build the lm_eval command line arguments.""" + cfg = self.lm_eval_config + 
model_cfg = self.job_config.model + + # For DCP checkpoints, tokenizer is in hf_assets_path + # For HF checkpoints, tokenizer is in checkpoint_path + tokenizer_path = model_cfg.hf_assets_path + + # Build model_args string + model_args = ( + f"pretrained={checkpoint_path}," + f"tokenizer_path={tokenizer_path}," + f"model_name={model_cfg.name}," + f"model_flavor={model_cfg.flavor}," + f"dtype=bfloat16," + f"max_seq_len={cfg.max_seq_len}" + ) + + cmd = [ + sys.executable, + "-m", + "lm_eval", + "--model", + "torchtitan", + "--model_args", + model_args, + "--tasks", + cfg.tasks, + "--num_fewshot", + str(cfg.num_fewshot), + "--batch_size", + str(cfg.batch_size), + "--seed", + cfg.get_seed_string(), + "--output_path", + str(output_dir), + ] + + if cfg.limit is not None: + cmd.extend(["--limit", str(cfg.limit)]) + + if cfg.log_samples: + cmd.append("--log_samples") + + return cmd + + def _run_inline( + self, + step: int, + checkpoint_path: str, + output_dir: Path, + ) -> dict[str, Any]: + """Run evaluation inline (blocking).""" + logger.info(f"Running inline evaluation for step {step}") + + try: + import lm_eval + + cfg = self.lm_eval_config + model_cfg = self.job_config.model + + # Set seeds for reproducibility + seeds = cfg.get_seeds() + import random + + import numpy as np + import torch + + random.seed(seeds[0]) + np.random.seed(seeds[1]) + torch.manual_seed(seeds[2]) + + # For DCP checkpoints, tokenizer is in hf_assets_path + tokenizer_path = model_cfg.hf_assets_path + + model_args = ( + f"pretrained={checkpoint_path}," + f"tokenizer_path={tokenizer_path}," + f"model_name={model_cfg.name}," + f"model_flavor={model_cfg.flavor}," + f"dtype=bfloat16," + f"max_seq_len={cfg.max_seq_len}" + ) + + results = lm_eval.simple_evaluate( + model="torchtitan", + model_args=model_args, + tasks=cfg.tasks.split(","), + num_fewshot=cfg.num_fewshot, + limit=cfg.limit, + batch_size=cfg.batch_size, + log_samples=cfg.log_samples, + ) + + # Save results + results_path = output_dir / 
"results.json" + with open(results_path, "w") as f: + json.dump(results, f, indent=2, default=str) + + # Log summary + self._log_results_summary(step, results) + + return results + + except Exception as e: + logger.error(f"Inline evaluation failed: {e}") + error_info = {"error": str(e), "step": step} + error_path = output_dir / "error.json" + with open(error_path, "w") as f: + json.dump(error_info, f, indent=2) + raise + + def _run_subprocess( + self, + step: int, + checkpoint_path: str, + output_dir: Path, + ) -> dict[str, Any]: + """Run evaluation in background subprocess (non-blocking).""" + logger.info(f"Starting subprocess evaluation for step {step}") + + # Generate eval script + script_path = output_dir / "run_eval.py" + self._generate_eval_script(checkpoint_path, output_dir, script_path) + + # Build environment + env = os.environ.copy() + pythonpath = env.get("PYTHONPATH", "") + if self.torchtitan_path: + pythonpath = f"{self.torchtitan_path}:{pythonpath}" + if self.lm_eval_path: + pythonpath = f"{self.lm_eval_path}:{pythonpath}" + env["PYTHONPATH"] = pythonpath + + # Start subprocess + log_path = output_dir / "eval.log" + with open(log_path, "w") as log_file: + process = subprocess.Popen( + [sys.executable, str(script_path)], + env=env, + stdout=log_file, + stderr=subprocess.STDOUT, + ) + + job_info = { + "mode": "subprocess", + "step": step, + "pid": process.pid, + "script_path": str(script_path), + "log_path": str(log_path), + "output_dir": str(output_dir), + } + + self.running_jobs[step] = job_info + logger.info(f"Subprocess evaluation started: PID={process.pid}") + + return job_info + + def _run_slurm( + self, + step: int, + checkpoint_path: str, + output_dir: Path, + ) -> dict[str, Any]: + """Submit evaluation as SLURM job (fully async).""" + logger.info(f"Submitting SLURM evaluation job for step {step}") + + # Generate eval script + script_path = output_dir / "run_eval.py" + self._generate_eval_script(checkpoint_path, output_dir, script_path) + + # 
Generate SLURM script + slurm_script_path = self.slurm_script_dir / f"eval_step_{step}.sh" + self._generate_slurm_script(step, script_path, output_dir, slurm_script_path) + + # Submit job + result = subprocess.run( + ["sbatch", str(slurm_script_path)], + capture_output=True, + text=True, + ) + + if result.returncode != 0: + logger.error(f"SLURM submission failed: {result.stderr}") + raise RuntimeError(f"SLURM submission failed: {result.stderr}") + + # Parse job ID from "Submitted batch job 12345" + job_id = result.stdout.strip().split()[-1] + + job_info = { + "mode": "slurm", + "step": step, + "job_id": job_id, + "script_path": str(script_path), + "slurm_script_path": str(slurm_script_path), + "output_dir": str(output_dir), + } + + self.running_jobs[step] = job_info + logger.info(f"SLURM job submitted: {job_id}") + + # Save job info + job_info_path = output_dir / "slurm_job_info.json" + with open(job_info_path, "w") as f: + json.dump(job_info, f, indent=2) + + return job_info + + def _generate_eval_script( + self, + checkpoint_path: str, + output_dir: Path, + script_path: Path, + ) -> None: + """Generate a standalone evaluation script for reproducibility.""" + cfg = self.lm_eval_config + model_cfg = self.job_config.model + seeds = cfg.get_seeds() + + script_content = f'''#!/usr/bin/env python +""" +Auto-generated evaluation script for reproducibility. 
+Generated at: {datetime.now().isoformat()} + +To rerun this evaluation: + python {script_path} +""" + +import sys +import json +import random +import numpy as np +import torch + +# Add torchtitan to path if needed +torchtitan_path = "{self.torchtitan_path}" +if torchtitan_path and torchtitan_path not in sys.path: + sys.path.insert(0, torchtitan_path) + +lm_eval_path = "{self.lm_eval_path or ''}" +if lm_eval_path and lm_eval_path not in sys.path: + sys.path.insert(0, lm_eval_path) + +import lm_eval + +# Set seeds for reproducibility +RANDOM_SEED = {seeds[0]} +NUMPY_SEED = {seeds[1]} +TORCH_SEED = {seeds[2]} +FEWSHOT_SEED = {seeds[3]} + +random.seed(RANDOM_SEED) +np.random.seed(NUMPY_SEED) +torch.manual_seed(TORCH_SEED) + +# Evaluation configuration +CHECKPOINT_PATH = "{checkpoint_path}" +OUTPUT_DIR = "{output_dir}" +MODEL_NAME = "{model_cfg.name}" +MODEL_FLAVOR = "{model_cfg.flavor}" +TASKS = "{cfg.tasks}" +NUM_FEWSHOT = {cfg.num_fewshot} +BATCH_SIZE = {cfg.batch_size} +MAX_SEQ_LEN = {cfg.max_seq_len} +LIMIT = {cfg.limit if cfg.limit is not None else "None"} +LOG_SAMPLES = {cfg.log_samples} + +print("=" * 60) +print("LM-EVALUATION-HARNESS") +print("=" * 60) +print(f"Checkpoint: {{CHECKPOINT_PATH}}") +print(f"Model: {{MODEL_NAME}} ({{MODEL_FLAVOR}})") +print(f"Tasks: {{TASKS}}") +print(f"Seeds: random={{RANDOM_SEED}}, numpy={{NUMPY_SEED}}, torch={{TORCH_SEED}}, fewshot={{FEWSHOT_SEED}}") +print("=" * 60) + +model_args = ( + f"pretrained={{CHECKPOINT_PATH}}," + f"tokenizer_path={{CHECKPOINT_PATH}}," + f"model_name={{MODEL_NAME}}," + f"model_flavor={{MODEL_FLAVOR}}," + f"dtype=bfloat16," + f"max_seq_len={{MAX_SEQ_LEN}}" +) + +results = lm_eval.simple_evaluate( + model="torchtitan", + model_args=model_args, + tasks=TASKS.split(","), + num_fewshot=NUM_FEWSHOT, + limit=LIMIT, + batch_size=BATCH_SIZE, + log_samples=LOG_SAMPLES, +) + +# Save results +results_path = f"{{OUTPUT_DIR}}/results.json" +with open(results_path, "w") as f: + json.dump(results, f, indent=2, 
default=str) +print(f"\\nResults saved to: {{results_path}}") + +# Print summary +print("\\n" + "=" * 60) +print("RESULTS SUMMARY") +print("=" * 60) +for task_name, task_results in results.get("results", {{}}).items(): + print(f"\\n{{task_name}}:") + for metric, value in task_results.items(): + if isinstance(value, (int, float)): + print(f" {{metric}}: {{value:.4f}}") +''' + + with open(script_path, "w") as f: + f.write(script_content) + + # Make executable + script_path.chmod(0o755) + + logger.info(f"Generated eval script: {script_path}") + + def _generate_slurm_script( + self, + step: int, + eval_script_path: Path, + output_dir: Path, + slurm_script_path: Path, + ) -> None: + """Generate SLURM submission script for evaluation.""" + slurm_cfg = self.lm_eval_config.slurm + job_name = f"lm_eval_step_{step}" + + # Build PYTHONPATH + pythonpath_parts = [] + if self.torchtitan_path: + pythonpath_parts.append(self.torchtitan_path) + if self.lm_eval_path: + pythonpath_parts.append(self.lm_eval_path) + pythonpath = ":".join(pythonpath_parts) if pythonpath_parts else "" + + # Build conda activation command + conda_cmd = "" + if slurm_cfg.conda_env: + conda_cmd = f"conda activate {slurm_cfg.conda_env}" + + # Build extra sbatch args + extra_args = slurm_cfg.extra_sbatch_args + + script_content = f"""#!/bin/bash +#SBATCH --job-name={job_name} +#SBATCH --output={self.slurm_logs_dir}/{job_name}_%j.out +#SBATCH --error={self.slurm_logs_dir}/{job_name}_%j.err +#SBATCH --partition={slurm_cfg.partition} +#SBATCH --gpus-per-node={slurm_cfg.gpus_per_node} +#SBATCH --cpus-per-task={slurm_cfg.cpus_per_task} +#SBATCH --time={slurm_cfg.time} +#SBATCH --qos={slurm_cfg.qos} +""" + + if slurm_cfg.account: + script_content += f"#SBATCH --account={slurm_cfg.account}\n" + if slurm_cfg.reservation: + script_content += f"#SBATCH --reservation={slurm_cfg.reservation}\n" + if extra_args: + script_content += f"#SBATCH {extra_args}\n" + + script_content += f""" +# Auto-generated SLURM script for 
lm-evaluation-harness +# Generated at: {datetime.now().isoformat()} +# Step: {step} +# To resubmit: sbatch {slurm_script_path} + +set -e + +echo "==========================================" +echo "LM-Evaluation-Harness SLURM Job" +echo "Job ID: $SLURM_JOB_ID" +echo "Node: $SLURMD_NODENAME" +echo "Step: {step}" +echo "==========================================" + +# Environment setup +export HF_HOME="{slurm_cfg.hf_cache}" +export TRANSFORMERS_CACHE="{slurm_cfg.hf_cache}" +""" + + if pythonpath: + script_content += f'export PYTHONPATH="{pythonpath}:$PYTHONPATH"\n' + + if conda_cmd: + script_content += f""" +# Activate conda environment +{conda_cmd} +""" + + script_content += f""" +# Run evaluation +echo "Starting evaluation..." +python {eval_script_path} + +echo "==========================================" +echo "Evaluation complete!" +echo "Results saved to: {output_dir}" +echo "==========================================" +""" + + with open(slurm_script_path, "w") as f: + f.write(script_content) + + # Make executable + slurm_script_path.chmod(0o755) + + logger.info(f"Generated SLURM script: {slurm_script_path}") + + def _log_results_summary(self, step: int, results: dict[str, Any]) -> None: + """Log a summary of evaluation results.""" + logger.info(f"Evaluation results at step {step}:") + + for task_name, task_results in results.get("results", {}).items(): + metrics_str = ", ".join( + f"{k}={v:.4f}" + for k, v in task_results.items() + if isinstance(v, (int, float)) + ) + logger.info(f" {task_name}: {metrics_str}") + + def check_running_jobs(self) -> dict[int, str]: + """ + Check status of running evaluation jobs. 
+ + Returns: + Dict mapping step to status ('running', 'completed', 'failed') + """ + statuses = {} + + for step, job_info in list(self.running_jobs.items()): + if job_info["mode"] == "subprocess": + # Check if process is still running + try: + pid = job_info["pid"] + os.kill(pid, 0) # Check if process exists + statuses[step] = "running" + except ProcessLookupError: + # Process finished, check for results + results_path = Path(job_info["output_dir"]) / "results.json" + if results_path.exists(): + statuses[step] = "completed" + # Load and log results + with open(results_path) as f: + results = json.load(f) + self._log_results_summary(step, results) + else: + statuses[step] = "failed" + del self.running_jobs[step] + + elif job_info["mode"] == "slurm": + # Check SLURM job status + job_id = job_info["job_id"] + result = subprocess.run( + ["squeue", "-j", job_id, "-h", "-o", "%T"], + capture_output=True, + text=True, + ) + + if result.returncode != 0 or not result.stdout.strip(): + # Job not in queue, check if completed + results_path = Path(job_info["output_dir"]) / "results.json" + if results_path.exists(): + statuses[step] = "completed" + with open(results_path) as f: + results = json.load(f) + self._log_results_summary(step, results) + else: + statuses[step] = "failed" + del self.running_jobs[step] + else: + status = result.stdout.strip() + if status in ("PENDING", "RUNNING", "CONFIGURING"): + statuses[step] = "running" + elif status == "COMPLETED": + statuses[step] = "completed" + else: + statuses[step] = "failed" + + return statuses + + def get_latest_results(self) -> dict[int, dict[str, Any]]: + """ + Get all available evaluation results. 
# --- LMEvaluator.get_latest_results (continuation of the class above) ------
def get_latest_results(self) -> dict[int, dict[str, Any]]:
    """
    Get all available evaluation results.

    Scans ``self.output_dir`` for ``step_<N>`` directories containing a
    ``results.json``.

    Returns:
        Dict mapping step to results dict
    """
    results: dict[int, dict[str, Any]] = {}

    for step_dir in self.output_dir.iterdir():
        if not (step_dir.is_dir() and step_dir.name.startswith("step_")):
            continue
        suffix = step_dir.name.removeprefix("step_")
        if not suffix.isdigit():
            # Fix: int() on arbitrary suffixes (e.g. a manual "step_backup"
            # directory) previously raised ValueError and aborted the scan.
            continue
        results_path = step_dir / "results.json"
        if results_path.exists():
            with open(results_path) as f:
                results[int(suffix)] = json.load(f)

    return results


# ===========================================================================
# torchtitan/config/job_config.py (additions; hunk context: class Debug ends
# with the round-robin debugging-only option docstring)
# ===========================================================================


@dataclass
class LMEvalSlurmConfig:
    """SLURM configuration for lm-evaluation-harness jobs."""

    partition: str = "gpu"
    """SLURM partition to submit eval jobs to"""

    gpus_per_node: int = 1
    """Number of GPUs per node for evaluation"""

    cpus_per_task: int = 8
    """Number of CPUs per task"""

    time: str = "02:00:00"
    """Time limit for eval job"""

    qos: str = "normal"
    """Quality of service"""

    account: str | None = None
    """SLURM account to charge"""

    reservation: str | None = None
    """SLURM reservation to use"""

    hf_cache: str = "/scratch/huggingface"
    """HuggingFace cache directory on compute nodes"""

    conda_env: str | None = None
    """Conda environment to activate (if any)"""

    extra_sbatch_args: str = ""
    """Additional sbatch arguments (e.g., '--constraint=a100')"""


@dataclass
class LMEvalConfig:
    """
    Configuration for automatic lm-evaluation-harness during training.

    This enables running standardized evaluations at checkpoint intervals,
    with full reproducibility through seed control and config logging.
    """

    enable: bool = False
    """Enable automatic evaluation during training"""

    eval_interval: int = 500
    """
    Run evaluation every N steps. Must be a multiple of checkpoint.interval.
    Evaluation runs after checkpoint is saved at this step.
    """

    tasks: str = "hellaswag,arc_easy"
    """Comma-separated list of lm-evaluation-harness tasks to run"""

    num_fewshot: int = 0
    """Number of few-shot examples for evaluation"""

    limit: int | None = None
    """
    Limit number of samples per task. None = full evaluation.
    Use smaller values (e.g., 100) during training for faster feedback.
    """

    batch_size: int = 4
    """Batch size for evaluation"""

    max_seq_len: int = 2048
    """Maximum sequence length for evaluation"""

    # Seeds for reproducibility
    seed: int = 42
    """Base random seed for all random number generators"""

    random_seed: int | None = None
    """Override Python random seed (defaults to seed if None)"""

    numpy_seed: int | None = None
    """Override NumPy random seed (defaults to seed if None)"""

    torch_seed: int | None = None
    """Override PyTorch random seed (defaults to seed if None)"""

    fewshot_seed: int | None = None
    """Override fewshot sampler seed (defaults to seed if None)"""

    # Execution mode
    mode: Literal["inline", "subprocess", "slurm"] = "inline"
    """
    Execution mode for evaluation:
    - 'inline': Run in same process (blocks training, simplest)
    - 'subprocess': Run in background subprocess (non-blocking)
    - 'slurm': Submit as SLURM job (fully async, separate resources)
    """

    # Output configuration
    output_dir: str = "eval_results"
    """Directory to save evaluation results (relative to job.dump_folder)"""

    log_samples: bool = True
    """Whether to log individual sample predictions"""

    # SLURM configuration (only used when mode='slurm')
    slurm: LMEvalSlurmConfig = field(default_factory=LMEvalSlurmConfig)
    """SLURM configuration for eval jobs"""

    slurm_script_dir: str = "eval_slurm_scripts"
    """Directory to save generated SLURM scripts (relative to job.dump_folder)"""

    slurm_logs_dir: str = "eval_slurm_logs"
    """Directory for SLURM job logs (relative to job.dump_folder)"""
logs (relative to job.dump_folder)""" + + # Path configuration + torchtitan_path: str | None = None + """ + Path to torchtitan installation. If None, auto-detected. + Used for PYTHONPATH in subprocess/slurm modes. + """ + + lm_eval_path: str | None = None + """ + Path to lm-evaluation-harness installation. If None, uses installed package. + """ + + def __post_init__(self): + if self.enable and self.eval_interval <= 0: + raise ValueError("eval_interval must be positive when lm_eval is enabled") + + def get_seeds(self) -> tuple[int, int, int, int]: + """Return (random_seed, numpy_seed, torch_seed, fewshot_seed).""" + return ( + self.random_seed if self.random_seed is not None else self.seed, + self.numpy_seed if self.numpy_seed is not None else self.seed, + self.torch_seed if self.torch_seed is not None else self.seed, + self.fewshot_seed if self.fewshot_seed is not None else self.seed, + ) + + def get_seed_string(self) -> str: + """Return seed string for CLI: 'random,numpy,torch,fewshot'.""" + seeds = self.get_seeds() + return f"{seeds[0]},{seeds[1]},{seeds[2]},{seeds[3]}" + + @dataclass class JobConfig: """ @@ -1166,6 +1310,7 @@ class JobConfig: validation: Validation = field(default_factory=Validation) grpo: GRPO = field(default_factory=GRPO) debug: Debug = field(default_factory=Debug) + lm_eval: LMEvalConfig = field(default_factory=LMEvalConfig) def to_dict(self) -> dict[str, Any]: return asdict(self) diff --git a/torchtitan/train.py b/torchtitan/train.py index 18895480a0..a878a1c797 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -18,6 +18,7 @@ from torchtitan.components.checkpoint import CheckpointManager from torchtitan.components.dataloader import DataloaderExhaustedError from torchtitan.components.ft import FTManager, maybe_semi_sync_training +from torchtitan.components.lm_evaluator import LMEvaluator from torchtitan.components.loss import rescale_accumulated_loss from torchtitan.components.metrics import ( build_metrics_processor, @@ -54,6 
+55,7 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful): # non-swappable training components checkpointer: CheckpointManager ft_manager: FTManager + lm_evaluator: LMEvaluator | None # runtime utilities device: torch.device @@ -328,6 +330,16 @@ def __init__(self, job_config: JobConfig): ft_manager=self.ft_manager, ) + # Initialize LM Evaluator for automatic evaluation during training + self.lm_evaluator = None + if job_config.lm_eval.enable: + rank = ( + torch.distributed.get_rank() + if torch.distributed.is_initialized() + else 0 + ) + self.lm_evaluator = LMEvaluator(job_config, rank=rank) + loss_parallel_enabled = ( parallel_dims.tp_enabled and not job_config.parallelism.disable_loss_parallel @@ -524,7 +536,9 @@ def _collect_moe_expert_metrics(self) -> dict[str, Any]: if counts is None: continue - checkpoint_impl = getattr(transformer_block, "checkpoint_impl", None) + checkpoint_impl = getattr( + transformer_block, "checkpoint_impl", None + ) if ( CheckpointImpl is not None @@ -778,6 +792,22 @@ def train(self): self.step, last_step=(self.step == job_config.training.steps) ) + # Run lm-evaluation-harness if enabled and at eval interval + if self.lm_evaluator is not None and self.lm_evaluator.should_evaluate( + self.step + ): + # Get checkpoint path for evaluation + checkpoint_folder = job_config.checkpoint.folder + checkpoint_path = os.path.join( + job_config.job.dump_folder, + checkpoint_folder, + f"step-{self.step}", + ) + try: + self.lm_evaluator.run_evaluation(self.step, checkpoint_path) + except Exception as e: + logger.warning(f"LM evaluation failed at step {self.step}: {e}") + # Run validation if validator is available if ( self.job_config.validation.enable From 169a179d876f284e9a3b8ba53bc3b95565bf695e Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Fri, 9 Jan 2026 18:55:04 +0000 Subject: [PATCH 2/4] add slurm support --- torchtitan/components/lm_evaluator.py | 59 +++++++++++++++++++++------ torchtitan/config/job_config.py | 15 ++++--- 2 
files changed, 55 insertions(+), 19 deletions(-) diff --git a/torchtitan/components/lm_evaluator.py b/torchtitan/components/lm_evaluator.py index 932b46adbc..898fe13839 100644 --- a/torchtitan/components/lm_evaluator.py +++ b/torchtitan/components/lm_evaluator.py @@ -447,6 +447,7 @@ def _generate_eval_script( # Evaluation configuration CHECKPOINT_PATH = "{checkpoint_path}" +TOKENIZER_PATH = "{model_cfg.hf_assets_path}" OUTPUT_DIR = "{output_dir}" MODEL_NAME = "{model_cfg.name}" MODEL_FLAVOR = "{model_cfg.flavor}" @@ -468,7 +469,7 @@ def _generate_eval_script( model_args = ( f"pretrained={{CHECKPOINT_PATH}}," - f"tokenizer_path={{CHECKPOINT_PATH}}," + f"tokenizer_path={{TOKENIZER_PATH}}," f"model_name={{MODEL_NAME}}," f"model_flavor={{MODEL_FLAVOR}}," f"dtype=bfloat16," @@ -529,25 +530,23 @@ def _generate_slurm_script( pythonpath_parts.append(self.lm_eval_path) pythonpath = ":".join(pythonpath_parts) if pythonpath_parts else "" - # Build conda activation command - conda_cmd = "" - if slurm_cfg.conda_env: - conda_cmd = f"conda activate {slurm_cfg.conda_env}" - # Build extra sbatch args extra_args = slurm_cfg.extra_sbatch_args script_content = f"""#!/bin/bash #SBATCH --job-name={job_name} -#SBATCH --output={self.slurm_logs_dir}/{job_name}_%j.out -#SBATCH --error={self.slurm_logs_dir}/{job_name}_%j.err -#SBATCH --partition={slurm_cfg.partition} -#SBATCH --gpus-per-node={slurm_cfg.gpus_per_node} +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --gpus-per-task={slurm_cfg.gpus_per_node} #SBATCH --cpus-per-task={slurm_cfg.cpus_per_task} #SBATCH --time={slurm_cfg.time} -#SBATCH --qos={slurm_cfg.qos} +#SBATCH --partition={slurm_cfg.partition} +#SBATCH --output={self.slurm_logs_dir}/{job_name}_%j.out +#SBATCH --error={self.slurm_logs_dir}/{job_name}_%j.err """ + if slurm_cfg.qos: + script_content += f"#SBATCH --qos={slurm_cfg.qos}\n" if slurm_cfg.account: script_content += f"#SBATCH --account={slurm_cfg.account}\n" if slurm_cfg.reservation: @@ -571,6 +570,12 @@ def 
_generate_slurm_script( echo "==========================================" # Environment setup +export LOGLEVEL=INFO +export FI_PROVIDER="efa" +export PYTHONFAULTHANDLER=1 +export LD_LIBRARY_PATH=/opt/amazon/efa/lib:$LD_LIBRARY_PATH +export LD_LIBRARY_PATH=/usr/local/lib/:$LD_LIBRARY_PATH +export CUDA_LAUNCH_BLOCKING=0 export HF_HOME="{slurm_cfg.hf_cache}" export TRANSFORMERS_CACHE="{slurm_cfg.hf_cache}" """ @@ -578,21 +583,49 @@ def _generate_slurm_script( if pythonpath: script_content += f'export PYTHONPATH="{pythonpath}:$PYTHONPATH"\n' - if conda_cmd: + # Activate virtual environment if specified + if slurm_cfg.venv_path: + script_content += f""" +# Activate Python virtual environment +export PATH="{slurm_cfg.venv_path}/bin:$PATH" +export CONDA_PREFIX="{slurm_cfg.venv_path}" +echo "Activated venv: {slurm_cfg.venv_path}" +""" + elif slurm_cfg.conda_env: script_content += f""" # Activate conda environment -{conda_cmd} +conda activate {slurm_cfg.conda_env} """ script_content += f""" +# Verify python is available +which python || echo "ERROR: python not found in PATH" + +# Record start time +START_TIME=$(date +%s) + # Run evaluation echo "Starting evaluation..." python {eval_script_path} +EXIT_CODE=$? +END_TIME=$(date +%s) +DURATION=$((END_TIME - START_TIME)) + +# Save status +if [ $EXIT_CODE -eq 0 ]; then + echo "SUCCESS: Evaluation completed (Duration: ${{DURATION}}s)" +else + echo "FAILED: Evaluation failed with exit code $EXIT_CODE" +fi + echo "==========================================" echo "Evaluation complete!" 
echo "Results saved to: {output_dir}" +echo "Duration: ${{DURATION}} seconds" echo "==========================================" + +exit $EXIT_CODE """ with open(slurm_script_path, "w") as f: diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 62fe61d487..5d15ee5867 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -1143,20 +1143,20 @@ class Debug: class LMEvalSlurmConfig: """SLURM configuration for lm-evaluation-harness jobs.""" - partition: str = "gpu" + partition: str = "batch" """SLURM partition to submit eval jobs to""" gpus_per_node: int = 1 """Number of GPUs per node for evaluation""" - cpus_per_task: int = 8 + cpus_per_task: int = 16 """Number of CPUs per task""" time: str = "02:00:00" """Time limit for eval job""" - qos: str = "normal" - """Quality of service""" + qos: str | None = None + """Quality of service (optional, uses cluster default if not specified)""" account: str | None = None """SLURM account to charge""" @@ -1164,14 +1164,17 @@ class LMEvalSlurmConfig: reservation: str | None = None """SLURM reservation to use""" - hf_cache: str = "/scratch/huggingface" + hf_cache: str = "/home/shared/huggingface-cache" """HuggingFace cache directory on compute nodes""" + venv_path: str | None = None + """Path to Python virtual environment (e.g., /path/to/env)""" + conda_env: str | None = None """Conda environment to activate (if any)""" extra_sbatch_args: str = "" - """Additional sbatch arguments (e.g., '--constraint=a100')""" + """Additional sbatch arguments (e.g., '--exclusive')""" @dataclass From e1571cf5d98f912bba2f3e189f4deeb0751e480b Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Fri, 9 Jan 2026 20:43:34 +0000 Subject: [PATCH 3/4] add readme and example script --- docs/online_evals.md | 214 ++++++++++++++++++ torchtitan/components/lm_evaluator.py | 22 +- torchtitan/config/job_config.py | 27 +++ .../train_configs/online_eval_test.toml | 57 +++++ 4 files changed, 314 insertions(+), 6 
deletions(-) create mode 100644 docs/online_evals.md create mode 100644 torchtitan/models/llama3/train_configs/online_eval_test.toml diff --git a/docs/online_evals.md b/docs/online_evals.md new file mode 100644 index 0000000000..6f39445ad9 --- /dev/null +++ b/docs/online_evals.md @@ -0,0 +1,214 @@ +# Online Evaluation with lm-evaluation-harness + +This document describes how to run automatic evaluations during training using [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness). + +## Overview + +Online evaluation allows you to track model performance on standard benchmarks throughout training, providing insights into: +- Training progress and convergence +- Potential overfitting or underfitting +- Optimal checkpoint selection + +## Features + +- **Three execution modes**: inline (blocking), subprocess (background), SLURM (async) +- **Full reproducibility**: Seed control for all random number generators +- **Automatic checkpoint evaluation**: Runs after checkpoint saves at specified intervals +- **Generated scripts**: Standalone Python scripts for reproducibility and debugging + +## Configuration + +Add the `[lm_eval]` section to your training config: + +```toml +[lm_eval] +enable = true +tasks = "hellaswag,arc_easy" +eval_interval = 500 # Run eval every N steps +num_fewshot = 0 +limit = 100 # Samples per task (None = full eval) +batch_size = 4 +mode = "slurm" # "inline", "subprocess", or "slurm" + +[lm_eval.slurm] +partition = "batch" +time = "01:00:00" +gpus_per_node = 1 +``` + +## Execution Modes + +### 1. Inline Mode (`mode = "inline"`) + +Runs evaluation in the same process as training. Simple but **blocks training** during evaluation. + +**Best for**: Quick tests, small models, or when GPU resources are limited. + +```toml +[lm_eval] +enable = true +mode = "inline" +limit = 50 # Use small limit to reduce blocking time +``` + +### 2. Subprocess Mode (`mode = "subprocess"`) + +Runs evaluation in a background subprocess on the same node. 
**Non-blocking** but shares resources with training. + +**Best for**: Development and testing on single nodes. + +```toml +[lm_eval] +enable = true +mode = "subprocess" +``` + +### 3. SLURM Mode (`mode = "slurm"`) + +Submits evaluation as a separate SLURM job. **Fully async** with dedicated resources. + +**Best for**: Production training on clusters. + +```toml +[lm_eval] +enable = true +mode = "slurm" +job_name_prefix = "my_experiment_eval" + +[lm_eval.slurm] +partition = "batch" +time = "02:00:00" +gpus_per_node = 1 +cpus_per_task = 16 +hf_cache = "/home/shared/huggingface-cache" +``` + +## Configuration Reference + +### LMEvalConfig + +| Option | Type | Default | Description | +|--------|------|---------|-------------| +| `enable` | bool | false | Enable automatic evaluation | +| `eval_interval` | int | 500 | Run eval every N steps (must align with checkpoint.interval) | +| `tasks` | str | "hellaswag,arc_easy" | Comma-separated lm-eval tasks | +| `num_fewshot` | int | 0 | Number of few-shot examples | +| `limit` | int\|None | None | Samples per task (None = full) | +| `batch_size` | int | 4 | Evaluation batch size | +| `max_seq_len` | int | 2048 | Maximum sequence length | +| `mode` | str | "inline" | Execution mode | +| `seed` | int | 42 | Base random seed | +| `output_dir` | str | "eval_results" | Results directory (relative to dump_folder) | +| `log_samples` | bool | true | Log individual predictions | +| `job_name_prefix` | str | "lm_eval" | SLURM job name prefix | + +### LMEvalSlurmConfig + +| Option | Type | Default | Description | +|--------|------|---------|-------------| +| `partition` | str | "batch" | SLURM partition | +| `gpus_per_node` | int | 1 | GPUs for eval job | +| `cpus_per_task` | int | 16 | CPUs per task | +| `time` | str | "02:00:00" | Time limit | +| `qos` | str\|None | None | Quality of service | +| `account` | str\|None | None | SLURM account | +| `hf_cache` | str | "/home/shared/huggingface-cache" | HuggingFace cache path | +| 
`venv_path` | str\|None | None | Python venv path | +| `conda_env` | str\|None | None | Conda environment name | + +## Output Structure + +After training with online eval, your dump folder will contain: + +``` +dump_folder/ +├── checkpoint/ +│ ├── step-500/ +│ └── step-1000/ +├── eval_results/ +│ ├── step_500/ +│ │ ├── eval_config.json # Evaluation configuration +│ │ ├── results.json # lm-eval results +│ │ └── run_eval.py # Standalone eval script +│ └── step_1000/ +│ └── ... +├── eval_slurm_scripts/ # (SLURM mode only) +│ ├── eval_step_500.sh +│ └── eval_step_1000.sh +└── eval_slurm_logs/ # (SLURM mode only) + ├── lm_eval_step_500_12345.out + └── lm_eval_step_500_12345.err +``` + +## Re-running Evaluations + +Each evaluation generates a standalone script that can be re-run: + +```bash +# Re-run a specific evaluation +python /path/to/dump_folder/eval_results/step_500/run_eval.py + +# Or resubmit the SLURM job +sbatch /path/to/dump_folder/eval_slurm_scripts/eval_step_500.sh +``` + +## Quick Test + +Use the provided test config to verify online eval works: + +```bash +# Set paths (adjust to your environment) +export TORCHTITAN_PATH="/home/phuc/workspace/moe/online_evals/torchtitan" +export LM_EVAL_PATH="/home/phuc/workspace/moe/online_evals/lm-evaluation-harness" +export PYTHON_ENV="/home/phuc/workspace/moe/online_evals/env/bin" + +# Run test: 20 training steps with eval at steps 5, 10, 15, 20 +PYTHONPATH="${TORCHTITAN_PATH}:${LM_EVAL_PATH}:$PYTHONPATH" \ +${PYTHON_ENV}/torchrun --nproc_per_node=1 --standalone \ + -m torchtitan.train \ + --job.config-file ${TORCHTITAN_PATH}/torchtitan/models/llama3/train_configs/online_eval_test.toml +``` + +The test config (`torchtitan/models/llama3/train_configs/online_eval_test.toml`) runs: +- 20 training steps on the c4_test dataset with llama3 debugmodel +- Checkpoints saved every 5 steps +- Evaluation (hellaswag, 20 samples) after each checkpoint +- SLURM mode (async) for non-blocking evaluation + +**Expected output:** +``` 
+[titan] Training starts at step 1 +[titan] step: 5 loss: ... +[titan] Saving the checkpoint... +[titan] Starting evaluation at step 5 +[titan] Evaluation results at step 5: +[titan] hellaswag: acc,none=0.XXXX, acc_norm,none=0.XXXX +... +``` + +## Troubleshooting + +### Eval interval doesn't match checkpoint interval + +The `eval_interval` must be a multiple of `checkpoint.interval`. Evaluation runs after checkpoints are saved. + +### SLURM job fails with "python not found" + +Set `venv_path` in SLURM config to your Python environment: + +```toml +[lm_eval.slurm] +venv_path = "/path/to/your/venv" +``` + +### Out of memory during inline eval + +- Reduce `batch_size` +- Use `mode = "slurm"` to run on separate resources +- Set `limit` to reduce number of samples + +### Evaluation results not appearing + +Check the SLURM logs in `eval_slurm_logs/` for errors. Common issues: +- Missing HuggingFace cache permissions +- Incompatible model/tokenizer paths diff --git a/torchtitan/components/lm_evaluator.py b/torchtitan/components/lm_evaluator.py index 898fe13839..eecc866ab0 100644 --- a/torchtitan/components/lm_evaluator.py +++ b/torchtitan/components/lm_evaluator.py @@ -313,7 +313,7 @@ def _run_subprocess( # Generate eval script script_path = output_dir / "run_eval.py" - self._generate_eval_script(checkpoint_path, output_dir, script_path) + self._generate_eval_script(step, checkpoint_path, output_dir, script_path) # Build environment env = os.environ.copy() @@ -359,7 +359,7 @@ def _run_slurm( # Generate eval script script_path = output_dir / "run_eval.py" - self._generate_eval_script(checkpoint_path, output_dir, script_path) + self._generate_eval_script(step, checkpoint_path, output_dir, script_path) # Generate SLURM script slurm_script_path = self.slurm_script_dir / f"eval_step_{step}.sh" @@ -400,6 +400,7 @@ def _run_slurm( def _generate_eval_script( self, + step: int, checkpoint_path: str, output_dir: Path, script_path: Path, @@ -446,6 +447,7 @@ def 
_generate_eval_script( torch.manual_seed(TORCH_SEED) # Evaluation configuration +STEP = {step} CHECKPOINT_PATH = "{checkpoint_path}" TOKENIZER_PATH = "{model_cfg.hf_assets_path}" OUTPUT_DIR = "{output_dir}" @@ -501,6 +503,7 @@ def _generate_eval_script( for metric, value in task_results.items(): if isinstance(value, (int, float)): print(f" {{metric}}: {{value:.4f}}") + ''' with open(script_path, "w") as f: @@ -520,7 +523,8 @@ def _generate_slurm_script( ) -> None: """Generate SLURM submission script for evaluation.""" slurm_cfg = self.lm_eval_config.slurm - job_name = f"lm_eval_step_{step}" + job_name_prefix = self.lm_eval_config.job_name_prefix + job_name = f"{job_name_prefix}_step_{step}" # Build PYTHONPATH pythonpath_parts = [] @@ -583,8 +587,10 @@ def _generate_slurm_script( if pythonpath: script_content += f'export PYTHONPATH="{pythonpath}:$PYTHONPATH"\n' - # Activate virtual environment if specified + # Determine python executable path + # Use venv python if specified, otherwise use the current python executable if slurm_cfg.venv_path: + python_exec = f"{slurm_cfg.venv_path}/bin/python" script_content += f""" # Activate Python virtual environment export PATH="{slurm_cfg.venv_path}/bin:$PATH" @@ -592,21 +598,25 @@ def _generate_slurm_script( echo "Activated venv: {slurm_cfg.venv_path}" """ elif slurm_cfg.conda_env: + python_exec = "python" # Will use conda python after activation script_content += f""" # Activate conda environment conda activate {slurm_cfg.conda_env} """ + else: + # Use the current python executable (best default for matching environment) + python_exec = sys.executable script_content += f""" # Verify python is available -which python || echo "ERROR: python not found in PATH" +which {python_exec} || echo "ERROR: {python_exec} not found" # Record start time START_TIME=$(date +%s) # Run evaluation echo "Starting evaluation..." -python {eval_script_path} +{python_exec} {eval_script_path} EXIT_CODE=$? 
END_TIME=$(date +%s) diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 5d15ee5867..7f74553552 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -1238,6 +1238,13 @@ class LMEvalConfig: - 'slurm': Submit as SLURM job (fully async, separate resources) """ + job_name_prefix: str = "lm_eval" + """ + Prefix for SLURM job names. Final job name will be {prefix}_step_{step}. + Use this to identify eval jobs for a specific training run. + Example: 'lm_eval_my_experiment' -> 'lm_eval_my_experiment_step_100' + """ + # Output configuration output_dir: str = "eval_results" """Directory to save evaluation results (relative to job.dump_folder)""" @@ -1267,6 +1274,26 @@ class LMEvalConfig: Path to lm-evaluation-harness installation. If None, uses installed package. """ + # WandB integration + log_to_wandb: bool = True + """ + Log evaluation results to WandB. Results are logged to the same run as training, + allowing you to track eval metrics alongside training metrics over time. + Requires wandb to be enabled in metrics config and initialized during training. + """ + + wandb_project: str | None = None + """ + WandB project name. If None, auto-detected from the training run. + Falls back to WANDB_PROJECT env var or 'torchtitan' if not set. + """ + + wandb_entity: str | None = None + """ + WandB entity/team name. If None, auto-detected from the training run. + Falls back to WANDB_TEAM env var if not set. 
+ """ + def __post_init__(self): if self.enable and self.eval_interval <= 0: raise ValueError("eval_interval must be positive when lm_eval is enabled") diff --git a/torchtitan/models/llama3/train_configs/online_eval_test.toml b/torchtitan/models/llama3/train_configs/online_eval_test.toml new file mode 100644 index 0000000000..edff0194ce --- /dev/null +++ b/torchtitan/models/llama3/train_configs/online_eval_test.toml @@ -0,0 +1,57 @@ +# Online Evaluation Test Config +# Quick test: 20 training steps with evaluation every 5 steps +# +# Usage: +# PYTHONPATH="/path/to/torchtitan:/path/to/lm-evaluation-harness:$PYTHONPATH" \ +# torchrun --nproc_per_node=1 --standalone \ +# -m torchtitan.train --job.config-file configs/online_eval_test.toml + +[job] +dump_folder = "./outputs/online_eval_test" +description = "Online evaluation test run" + +[model] +name = "llama3" +flavor = "debugmodel" + +[training] +steps = 20 +local_batch_size = 1 +seq_len = 128 +dataset = "c4_test" + +[checkpoint] +enable = true +interval = 5 +folder = "checkpoint" + +[lr_scheduler] +warmup_steps = 0 + +[metrics] +log_freq = 1 + +# ============================================ +# Online Evaluation Configuration +# ============================================ +[lm_eval] +enable = true +tasks = "hellaswag" +eval_interval = 5 # Eval at steps 5, 10, 15, 20 +num_fewshot = 0 +limit = 20 # Small limit for fast testing +batch_size = 4 +max_seq_len = 2048 +mode = "slurm" # Async SLURM evaluation +seed = 42 +output_dir = "eval_results" +log_samples = false # Disable for faster evals +job_name_prefix = "eval_test" + +# SLURM config +[lm_eval.slurm] +partition = "batch" +time = "00:30:00" +gpus_per_node = 1 +cpus_per_task = 16 +hf_cache = "/home/shared/huggingface-cache" From 0741b21b21ea05d4746b4803680ac86da3ceefc9 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Fri, 9 Jan 2026 20:45:20 +0000 Subject: [PATCH 4/4] minor --- docs/online_evals.md | 12 ------------ 1 file changed, 12 deletions(-) diff --git 
a/docs/online_evals.md b/docs/online_evals.md index 6f39445ad9..2425f2eaf8 100644 --- a/docs/online_evals.md +++ b/docs/online_evals.md @@ -200,15 +200,3 @@ Set `venv_path` in SLURM config to your Python environment: [lm_eval.slurm] venv_path = "/path/to/your/venv" ``` - -### Out of memory during inline eval - -- Reduce `batch_size` -- Use `mode = "slurm"` to run on separate resources -- Set `limit` to reduce number of samples - -### Evaluation results not appearing - -Check the SLURM logs in `eval_slurm_logs/` for errors. Common issues: -- Missing HuggingFace cache permissions -- Incompatible model/tokenizer paths