diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml new file mode 100644 index 0000000..92b89fc --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -0,0 +1,84 @@ +name: Bug Report +description: File a bug report to help us improve Quark +title: "[Bug]: " +labels: ["bug", "triage"] +body: + - type: markdown + attributes: + value: | + Thanks for taking the time to fill out this bug report! + + - type: textarea + id: description + attributes: + label: Bug Description + description: What happened? What did you expect to happen? + placeholder: I tried to... but instead... + validations: + required: true + + - type: textarea + id: reproduction + attributes: + label: Steps to Reproduce + description: How can we reproduce this issue? + placeholder: | + 1. Configure '...' + 2. Run training with '...' + 3. See error + validations: + required: true + + - type: textarea + id: code + attributes: + label: Minimal Code Example + description: Please provide a minimal code example that reproduces the issue + render: python + placeholder: | + import torch + from models.transformer import Transformer, TransformerConfig + + # Your code here... + validations: + required: false + + - type: textarea + id: traceback + attributes: + label: Error Message / Traceback + description: If applicable, paste the full error message or traceback + render: shell + placeholder: Paste error message here... + validations: + required: false + + - type: textarea + id: environment + attributes: + label: Environment + description: Python version, PyTorch version, OS, hardware (GPU/CPU) + placeholder: | + Python 3.12 + PyTorch 2.10 + Ubuntu 22.04 + NVIDIA A100 + validations: + required: true + + - type: textarea + id: additional-context + attributes: + label: Additional Context + description: Add any other context about the problem here + placeholder: Any additional information... + validations: + required: false + + - type: checkboxes + id: terms + attributes: + label: Checklist + options: + - label: I have searched existing issues to ensure this is not a duplicate + required: true diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000..e6ccc3a --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,5 @@ +blank_issues_enabled: true +contact_links: + - name: GitHub Discussions + url: https://github.com/phnazari/quark/discussions + about: Ask questions or start a discussion about Quark diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml new file mode 100644 index 0000000..b92ec26 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -0,0 +1,57 @@ +name: Feature Request +description: Suggest a new feature or enhancement for Quark +title: "[Feature]: " +labels: ["enhancement"] +body: + - type: markdown + attributes: + value: | + Thanks for suggesting a new feature! + + - type: textarea + id: problem + attributes: + label: Problem Statement + description: What problem would this feature solve? What are you trying to achieve? + placeholder: I'm trying to... but currently... + validations: + required: true + + - type: textarea + id: solution + attributes: + label: Proposed Solution + description: Describe how you'd like this to work + placeholder: I would like Quark to... + validations: + required: true + + - type: textarea + id: example + attributes: + label: Example Code + description: Show us how you'd like to use this feature + render: python + placeholder: | + from models.transformer import Transformer, TransformerConfig + + # Example of how the feature would be used + validations: + required: false + + - type: textarea + id: additional-context + attributes: + label: Additional Context + description: Any other context, references to papers, implementations in other frameworks, etc. + placeholder: Related papers, implementations, etc. + validations: + required: false + + - type: checkboxes + id: terms + attributes: + label: Checklist + options: + - label: I have searched existing issues to ensure this feature hasn't been requested + required: true diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..6594f86 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,19 @@ +## Description + + + +## Type of Change + +- [ ] 🐛 Bug fix +- [ ] ✨ New feature +- [ ] 💥 Breaking change +- [ ] 📚 Documentation +- [ ] Other + +## Changes Made + + + +- +- +- diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..09afde6 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,59 @@ +# ============================= +# Adapted from: +# https://github.com/fla-org/flash-linear-attention +# ============================= +name: Lint + +on: + workflow_dispatch: + pull_request: + branches: [main] + push: + branches: [main] + +jobs: + lint: + runs-on: ubuntu-latest + + steps: + - name: Check out repo + uses: actions/checkout@v4 + with: + submodules: true + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version-file: .python-version + + - name: Set up uv + uses: astral-sh/setup-uv@v3 + + - name: Cache uv environment + uses: actions/cache@v4 + with: + path: .venv + key: ${{ runner.os }}-uv-${{ hashFiles('pyproject.toml', 'uv.lock') }} + + - name: Cache pre-commit hooks + uses: actions/cache@v4 + with: + path: ~/.cache/pre-commit + key: ${{ runner.os }}-precommit-${{ hashFiles('.pre-commit-config.yaml') }} + + - name: Sync environment + run: uv sync --extra dev + + - name: Get changed files + id: changed-files + uses: tj-actions/changed-files@v46.0.5 + + - name: Lint only changed files + if: ${{ steps.changed-files.outputs.all_changed_files != '' }} + run: | + echo "Changed files: ${{ steps.changed-files.outputs.all_changed_files }}" + uv run pre-commit run --files ${{ steps.changed-files.outputs.all_changed_files }} + + - name: No files changed + if: ${{ steps.changed-files.outputs.all_changed_files == '' }} + run: echo "No changed files to lint." diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..7b25bb0 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,38 @@ +name: Tests + +on: + workflow_dispatch: + pull_request: + branches: [main] + push: + branches: [main] + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - name: Check out repo + uses: actions/checkout@v4 + with: + submodules: true + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version-file: .python-version + + - name: Set up uv + uses: astral-sh/setup-uv@v3 + + - name: Cache uv environment + uses: actions/cache@v4 + with: + path: .venv + key: ${{ runner.os }}-uv-${{ hashFiles('pyproject.toml', 'uv.lock') }} + + - name: Sync environment + run: uv sync --extra dev + + - name: Run tests + run: uv run pytest tests/ -v diff --git a/checkpoint_utils.py b/checkpoint_utils.py index d75a9c7..4322be7 100644 --- a/checkpoint_utils.py +++ b/checkpoint_utils.py @@ -31,7 +31,7 @@ def save_checkpoint(step, model, engine, cfg, metrics=None): "scaler": engine.scaler.state_dict(), } - exp_dir = os.path.join(cfg.out_dir, cfg.exp_name) + exp_dir = os.path.join(cfg.out_dir, cfg.checkpoint.exp_name) os.makedirs(exp_dir, exist_ok=True) save_path = os.path.join(exp_dir, f"ckpt_step_{step}.pth") @@ -46,14 +46,18 @@ def save_checkpoint(step, model, engine, cfg, metrics=None): def maybe_load_checkpoint(cfg): """Load a checkpoint if resuming, else return None.""" - if not cfg.resume: + if not cfg.checkpoint.resume: return None - resume_exp_name = cfg.resume_exp_name if cfg.resume_exp_name is not None else cfg.exp_name + resume_exp_name = ( + cfg.checkpoint.resume_exp_name + if cfg.checkpoint.resume_exp_name is not None + else cfg.checkpoint.exp_name + ) ckpt_dir = os.path.join(cfg.out_dir, resume_exp_name) - if cfg.resume_step is not None: - ckpt_path = os.path.join(ckpt_dir, f"ckpt_step_{cfg.resume_step}.pth") + if cfg.checkpoint.resume_step is not None: + ckpt_path = os.path.join(ckpt_dir, f"ckpt_step_{cfg.checkpoint.resume_step}.pth") else: ckpt_path = _latest_checkpoint(ckpt_dir, prefix="ckpt_step_") diff --git a/configs/checkpoint/default.yaml b/configs/checkpoint/default.yaml new file mode 100644 index 0000000..251ac36 --- /dev/null +++ b/configs/checkpoint/default.yaml @@ -0,0 +1,8 @@ +save_last_checkpoint: true +save_intermediate_checkpoints: true +save_every_steps: 1000 +resume: false +resume_step: null +resume_exp_name: null +over_write: true +exp_name: default diff --git a/configs/config.yaml b/configs/config.yaml index 9d8437a..4b22cd5 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -1,59 +1,14 @@ defaults: - model: transformer - data: fineweb10B + - training: default + - system: default + - logging: default + - checkpoint: default - _self_ -training: - steps_budget: 19064 - eval_every_steps: 200 - log_every_steps: 100 - grad_accumulation_steps: 2 - # Optimizer - optim: adamw - fused_optim: false - lr: 7e-4 - weight_decay: 0.1 - beta1: 0.9 - beta2: 0.95 - grad_clip: 1.0 - eps: 1e-15 - # Scheduler - scheduler: warmup_cosine - warmup_steps: 950 - cooldown_steps: null - lr_start: 0.0 - lr_end: null - lr_end_pct: 0.1 - # Early stopping - early_stopping_patience: 0 - -system: - dtype: bfloat16 - compile_model: false - seed: 42 - ddp_backend: nccl - -logging: - wandb_log: true - wandb_project: quark - wandb_log_layer_stats: false - -checkpoint: - save_last_checkpoint: true - save_intermediate_checkpoints: true - save_every_steps: 1000 - resume: false - resume_step: null - resume_exp_name: null - over_write: true - exp_name: default - out_dir: /fast/pnazari/quark -data: - seq_len: 2048 - micro_batch_size: 16 - hydra: job: chdir: false diff --git a/configs/data/fineweb10B.yaml b/configs/data/fineweb10B.yaml index 2de72db..5b2d089 100644 --- a/configs/data/fineweb10B.yaml +++ b/configs/data/fineweb10B.yaml @@ -1,10 +1,7 @@ dataset: fineweb10B +vocab_size: 50304 trainset_path: /fast/pnazari/data/fwedu_sample_10BT/tokenized_EleutherAI_gpt-neox-20b/ctx_2048/train validset_path: /fast/pnazari/data/fwedu_sample_10BT/tokenized_EleutherAI_gpt-neox-20b/ctx_2048/valid -seq_len: 1024 -micro_batch_size: 32 -num_workers: 17 -sampler: stateful_random -sampler_seed: 42 +seq_len: 2048 eval: true valid_tokens: 10000000 diff --git a/configs/logging/default.yaml b/configs/logging/default.yaml new file mode 100644 index 0000000..a1f0986 --- /dev/null +++ b/configs/logging/default.yaml @@ -0,0 +1,3 @@ +wandb_log: true +wandb_project: quark +wandb_log_layer_stats: false diff --git a/configs/system/default.yaml b/configs/system/default.yaml new file mode 100644 index 0000000..4e43480 --- /dev/null +++ b/configs/system/default.yaml @@ -0,0 +1,4 @@ +dtype: bfloat16 +compile_model: false +seed: 42 +ddp_backend: nccl diff --git a/configs/training/default.yaml b/configs/training/default.yaml new file mode 100644 index 0000000..5585afc --- /dev/null +++ b/configs/training/default.yaml @@ -0,0 +1,26 @@ +steps_budget: 19064 +micro_batch_size: 16 +eval_every_steps: 200 +log_every_steps: 100 +grad_accumulation_steps: 2 +num_workers: 17 +sampler: stateful_random +sampler_seed: 42 +# Optimizer +optim: adamw +fused_optim: false +lr: 7e-4 +weight_decay: 0.1 +beta1: 0.9 +beta2: 0.95 +grad_clip: 1.0 +eps: 1e-15 +# Scheduler +scheduler: warmup_cosine +warmup_steps: 950 +cooldown_steps: null +lr_start: 0.0 +lr_end: null +lr_end_pct: 0.1 +# Early stopping +early_stopping_patience: 0 diff --git a/data/dataloaders.py b/data/dataloaders.py index 51d2cca..4b85136 100644 --- a/data/dataloaders.py +++ b/data/dataloaders.py @@ -15,7 +15,7 @@ def get_dataloaders(cfg): """Load trainset and validset, return DataLoaders.""" - train_set = load_from_disk(cfg.trainset_path) + train_set = load_from_disk(cfg.data.trainset_path) if not isinstance(train_set, Dataset): raise ValueError("dataset should be a datasets.Dataset") @@ -33,23 +33,23 @@ def collate_fn(batch): trainloader = DataLoader( train_set, sampler=train_sampler, - batch_size=cfg.micro_batch_size, - num_workers=cfg.num_workers, + batch_size=cfg.training.micro_batch_size, + num_workers=cfg.training.num_workers, pin_memory=True, - prefetch_factor=2 if cfg.num_workers > 0 else None, - persistent_workers=cfg.num_workers > 0, + prefetch_factor=2 if cfg.training.num_workers > 0 else None, + persistent_workers=cfg.training.num_workers > 0, collate_fn=collate_fn if has_docs_lengths else None, ) - if not cfg.validset_path: + if not cfg.data.eval or not cfg.data.validset_path: return trainloader, None - valid_set = load_from_disk(cfg.validset_path) + valid_set = load_from_disk(cfg.data.validset_path) if not isinstance(valid_set, Dataset): raise ValueError("dataset should be a datasets.Dataset") - if getattr(cfg, "valid_tokens", None): - valid_rows = cfg.valid_tokens // (cfg.seq_len + 1) + if cfg.data.valid_tokens: + valid_rows = cfg.data.valid_tokens // (cfg.data.seq_len + 1) valid_set = valid_set.take(min(len(valid_set), valid_rows)) if dist.is_initialized(): @@ -61,13 +61,13 @@ def collate_fn(batch): validloader = DataLoader( valid_set, - batch_size=cfg.micro_batch_size, + batch_size=cfg.training.micro_batch_size, drop_last=True, shuffle=False, sampler=valid_sampler, - num_workers=cfg.num_workers, + num_workers=cfg.training.num_workers, pin_memory=True, - prefetch_factor=2 if cfg.num_workers > 0 else None, + prefetch_factor=2 if cfg.training.num_workers > 0 else None, persistent_workers=False, collate_fn=collate_fn if has_docs_lengths_valid else None, ) @@ -79,52 +79,62 @@ def _get_sampler(train_set, cfg): """Initialize a sampler for the training DataLoader.""" ddp = dist.is_initialized() - if cfg.sampler == "random": + if cfg.training.sampler == "random": if ddp: sampler = DistributedSampler( - train_set, shuffle=True, seed=cfg.sampler_seed, drop_last=True + train_set, shuffle=True, seed=cfg.training.sampler_seed, drop_last=True ) else: sampler = RandomSampler( train_set, generator=( - torch.Generator().manual_seed(cfg.sampler_seed) if cfg.sampler_seed else None + torch.Generator().manual_seed(cfg.training.sampler_seed) + if cfg.training.sampler_seed + else None ), ) - elif cfg.sampler == "sequential": + elif cfg.training.sampler == "sequential": if ddp: sampler = DistributedSampler(train_set, shuffle=False, drop_last=True) else: sampler = SequentialSampler(train_set) - elif cfg.sampler == "stateful_random": - micro_step_start = cfg.resume_step * cfg.grad_accumulation_steps if cfg.resume else 0 + elif cfg.training.sampler == "stateful_random": + micro_step_start = ( + cfg.checkpoint.resume_step * cfg.training.grad_accumulation_steps + if cfg.checkpoint.resume + else 0 + ) if ddp: sampler = StatefulDistributedSampler( train_set, - batch_size=cfg.micro_batch_size, - seed=cfg.sampler_seed, + batch_size=cfg.training.micro_batch_size, + seed=cfg.training.sampler_seed, start_iter=micro_step_start, ) else: sampler = StatefulRandomSampler( train_set, - batch_size=cfg.micro_batch_size, + batch_size=cfg.training.micro_batch_size, shuffle=True, - seed=cfg.sampler_seed, + seed=cfg.training.sampler_seed, start_idx=micro_step_start, ) - elif cfg.sampler == "stateful_sequential": - micro_step_start = cfg.resume_step * cfg.grad_accumulation_steps if cfg.resume else 0 + elif cfg.training.sampler == "stateful_sequential": + micro_step_start = ( + cfg.checkpoint.resume_step * cfg.training.grad_accumulation_steps + if cfg.checkpoint.resume + else 0 + ) if ddp: raise NotImplementedError("StatefulDistributedSampler currently needs a seed.") sampler = StatefulSequentialSampler( - train_set, batch_size=cfg.micro_batch_size, start_idx=micro_step_start + train_set, batch_size=cfg.training.micro_batch_size, start_idx=micro_step_start ) else: - raise NotImplementedError(f"Sampler {cfg.sampler} is not implemented.") + raise NotImplementedError(f"Sampler {cfg.training.sampler} is not implemented.") return sampler diff --git a/engine/engine.py b/engine/engine.py index decfc53..a639158 100644 --- a/engine/engine.py +++ b/engine/engine.py @@ -22,17 +22,17 @@ def __init__(self, model, cfg, device, local_rank, ckpt): self.micro_steps = 0 self.accumulated_samples = 0 - self.seq_len = cfg.seq_len - self.accumulation_steps = cfg.grad_accumulation_steps - self.grad_clip = cfg.grad_clip - self.dtype = cfg.dtype + self.seq_len = cfg.data.seq_len + self.accumulation_steps = cfg.training.grad_accumulation_steps + self.grad_clip = cfg.training.grad_clip + self.dtype = cfg.system.dtype self.device = device # Load model state dict if resuming - if cfg.resume: + if cfg.checkpoint.resume: model.load_state_dict(ckpt["state_dict"]) - self.micro_steps = ckpt["step"] * cfg.grad_accumulation_steps + self.micro_steps = ckpt["step"] * cfg.training.grad_accumulation_steps # Move model to device and wrap in DDP self.model = model.to(device) @@ -40,7 +40,7 @@ def __init__(self, model, cfg, device, local_rank, ckpt): self.model = DDP(self.model, device_ids=[local_rank], find_unused_parameters=False) # Compile - if cfg.compile_model: + if cfg.system.compile_model: print("Compiling the model...") self.model = torch.compile(self.model) @@ -59,11 +59,11 @@ def __init__(self, model, cfg, device, local_rank, ckpt): self.scaler = torch.amp.GradScaler(enabled=(self.dtype == "float16")) # Optimizer and scheduler - param_groups = get_param_groups(model, cfg.weight_decay) + param_groups = get_param_groups(model, cfg.training.weight_decay) self.optimizer = initialize_optimizer(param_groups, cfg) self.scheduler = initialize_scheduler(self.optimizer, cfg) - if cfg.resume: + if cfg.checkpoint.resume: self.optimizer.load_state_dict(ckpt["optimizer"]) if ckpt.get("scheduler"): self.scheduler.load_state_dict(ckpt["scheduler"]) diff --git a/eval/harness.py b/eval/harness.py index 11fbc1a..36b16fe 100644 --- a/eval/harness.py +++ b/eval/harness.py @@ -7,7 +7,7 @@ from lm_eval.api.registry import register_model from lm_eval.models.huggingface import HFLM from models import DeltaNet, DeltaNetWrapperConfig, Transformer, TransformerConfig -from utils import flatten_config +from omegaconf import OmegaConf def build_model_for_eval(config_name, model_overrides=None): @@ -18,7 +18,7 @@ def build_model_for_eval(config_name, model_overrides=None): model_overrides (dict, optional): Model overrides. Returns: - tuple: (model, flat_cfg) + tuple: (model, cfg) """ with hydra.initialize(version_base=None, config_path="../configs"): # We override the model if specified, e.g., model=delta_net @@ -28,9 +28,8 @@ def build_model_for_eval(config_name, model_overrides=None): overrides.append(f"model.{k}={v}") cfg = hydra.compose(config_name="config", overrides=overrides) - flat_cfg = flatten_config(cfg) - vocab_size = flat_cfg.vocab_size if hasattr(flat_cfg, "vocab_size") else 50304 + vocab_size = cfg.data.vocab_size model_type = cfg.model.model_type if model_type == "delta_net": @@ -46,7 +45,7 @@ def build_model_for_eval(config_name, model_overrides=None): use_short_conv=cfg.model.use_short_conv, conv_size=cfg.model.conv_size, ) - return DeltaNet(config), flat_cfg + return DeltaNet(config), cfg config = TransformerConfig( vocab_size=vocab_size, @@ -54,11 +53,11 @@ def build_model_for_eval(config_name, model_overrides=None): num_layers=cfg.model.num_layers, num_heads=cfg.model.num_heads, head_dim=cfg.model.head_dim, - block_size=flat_cfg.seq_len, + block_size=cfg.data.seq_len, dropout=cfg.model.dropout, bias=cfg.model.bias, ) - return Transformer(config), flat_cfg + return Transformer(config), cfg @register_model("quark") @@ -75,6 +74,7 @@ def __init__( dtype="bfloat16", **kwargs, ): + """Initialize QuarkLM for evaluation.""" # 1. Build model and load checkpoint # Extract Quark-specific kwargs to avoid passing them to HFLM custom_kwargs = [ @@ -95,27 +95,40 @@ def __init__( if k in kwargs: model_overrides[k] = kwargs.pop(k) - model_wrapper, flat_cfg = build_model_for_eval(config, model_overrides) + model_wrapper, cfg = build_model_for_eval(config, model_overrides) - # Configure flat_cfg for maybe_load_checkpoint - flat_cfg.resume = True + # Configure cfg for maybe_load_checkpoint + cfg_override = OmegaConf.create({"checkpoint": {"resume": True}}) + cfg = OmegaConf.merge(cfg, cfg_override) # Determine out_dir and exp_name if os.path.isabs(checkpoint): - flat_cfg.out_dir = os.path.dirname(checkpoint) - flat_cfg.resume_exp_name = os.path.basename(checkpoint) + cfg_override = OmegaConf.create( + { + "out_dir": os.path.dirname(checkpoint), + "checkpoint": {"resume_exp_name": os.path.basename(checkpoint)}, + } + ) + cfg = OmegaConf.merge(cfg, cfg_override) else: - # Assume checkpoint is just the exp_name within the default flat_cfg.out_dir - flat_cfg.resume_exp_name = checkpoint + cfg_override = OmegaConf.create( + { + "checkpoint": {"resume_exp_name": checkpoint}, + } + ) + cfg = OmegaConf.merge(cfg, cfg_override) # Parse step if provided step = model_overrides.get("step") if step is not None: - flat_cfg.resume_step = int(step) - else: - flat_cfg.resume_step = None + cfg_override = OmegaConf.create( + { + "checkpoint": {"resume_step": int(step)}, + } + ) + cfg = OmegaConf.merge(cfg, cfg_override) - ckpt = maybe_load_checkpoint(flat_cfg) + ckpt = maybe_load_checkpoint(cfg) if ckpt is None: raise ValueError(f"Failed to load checkpoint for {checkpoint}") diff --git a/models/__init__.py b/models/__init__.py index 5acdfa8..8011c4c 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,6 +1,12 @@ """Model registry.""" -from .delta_net import DeltaNet, DeltaNetWrapperConfig from .transformer import Transformer, TransformerConfig -__all__ = ["Transformer", "TransformerConfig", "DeltaNet", "DeltaNetWrapperConfig"] +__all__ = ["Transformer", "TransformerConfig"] + +try: + from .delta_net import DeltaNet, DeltaNetWrapperConfig + + __all__ += ["DeltaNet", "DeltaNetWrapperConfig"] +except ImportError: + pass diff --git a/optim/init_optim.py b/optim/init_optim.py index 62eb22f..b4d4ca0 100644 --- a/optim/init_optim.py +++ b/optim/init_optim.py @@ -7,124 +7,128 @@ def initialize_optimizer(param_groups, cfg): """Initialize an optimizer from config.""" - if cfg.optim == "adamw": + if cfg.training.optim == "adamw": optimizer = torch.optim.AdamW( param_groups, - lr=cfg.lr, - betas=[cfg.beta1, cfg.beta2], - weight_decay=cfg.weight_decay, - fused=cfg.fused_optim, - eps=cfg.eps, + lr=cfg.training.lr, + betas=[cfg.training.beta1, cfg.training.beta2], + weight_decay=cfg.training.weight_decay, + fused=cfg.training.fused_optim, + eps=cfg.training.eps, ) - elif cfg.optim == "nadamw": + elif cfg.training.optim == "nadamw": kwargs = dict( - lr=cfg.lr, - betas=[cfg.beta1, cfg.beta2], - weight_decay=cfg.weight_decay, + lr=cfg.training.lr, + betas=[cfg.training.beta1, cfg.training.beta2], + weight_decay=cfg.training.weight_decay, decoupled_weight_decay=True, - eps=cfg.eps, + eps=cfg.training.eps, ) # fused only supported on CUDA - if cfg.fused_optim and torch.cuda.is_available(): + if cfg.training.fused_optim and torch.cuda.is_available(): kwargs["fused"] = True optimizer = torch.optim.NAdam(param_groups, **kwargs) - elif cfg.optim == "sgd": + elif cfg.training.optim == "sgd": optimizer = torch.optim.SGD( param_groups, - lr=cfg.lr, - momentum=cfg.beta1, - dampening=getattr(cfg, "dampening", 0.0), - weight_decay=cfg.weight_decay, + lr=cfg.training.lr, + momentum=cfg.training.beta1, + dampening=getattr(cfg.training, "dampening", 0.0), + weight_decay=cfg.training.weight_decay, ) - elif cfg.optim == "signSGD": + elif cfg.training.optim == "signSGD": from .sign_sgd import signSGD optimizer = signSGD( param_groups, - lr=cfg.lr, - momentum=cfg.beta1, - dampening=getattr(cfg, "dampening", 0.0), - weight_decay=cfg.weight_decay, + lr=cfg.training.lr, + momentum=cfg.training.beta1, + dampening=getattr(cfg.training, "dampening", 0.0), + weight_decay=cfg.training.weight_decay, ) else: - raise NotImplementedError(f"Not implemented optim: {cfg.optim}.") + raise NotImplementedError(f"Not implemented optim: {cfg.training.optim}.") return optimizer def initialize_scheduler(optimizer, cfg): """Initialize a learning rate scheduler from config.""" - if cfg.scheduler is None: + if cfg.training.scheduler is None: return None # Warmup steps: int or fraction of steps_budget warmup_steps = None - if getattr(cfg, "warmup_steps", None) is not None: + if cfg.training.warmup_steps is not None: warmup_steps = ( - cfg.warmup_steps - if isinstance(cfg.warmup_steps, int) - else int(cfg.warmup_steps * cfg.steps_budget) + cfg.training.warmup_steps + if isinstance(cfg.training.warmup_steps, int) + else int(cfg.training.warmup_steps * cfg.training.steps_budget) ) # Cooldown steps: int or fraction of steps_budget cooldown_steps = None - if getattr(cfg, "cooldown_steps", None) is not None: + if cfg.training.cooldown_steps is not None: cooldown_steps = ( - cfg.cooldown_steps - if isinstance(cfg.cooldown_steps, int) - else int(cfg.cooldown_steps * cfg.steps_budget) + cfg.training.cooldown_steps + if isinstance(cfg.training.cooldown_steps, int) + else int(cfg.training.cooldown_steps * cfg.training.steps_budget) ) # Final LR: direct or as fraction of peak lr lr_end = None - if getattr(cfg, "lr_end", None) is not None or getattr(cfg, "lr_end_pct", None) is not None: - lr_end = cfg.lr_end if (cfg.lr_end is not None) else (cfg.lr_end_pct * cfg.lr) + if cfg.training.lr_end is not None or cfg.training.lr_end_pct is not None: + lr_end = ( + cfg.training.lr_end + if cfg.training.lr_end is not None + else (cfg.training.lr_end_pct * cfg.training.lr) + ) - if cfg.scheduler == "warmup_cosine": + if cfg.training.scheduler == "warmup_cosine": scheduler = WarmupCosine( optimizer, - lr_start=cfg.lr_start, - lr_max=cfg.lr, + lr_start=cfg.training.lr_start, + lr_max=cfg.training.lr, lr_end=lr_end, warmup_steps=warmup_steps, - T=cfg.steps_budget, + T=cfg.training.steps_budget, ) - elif cfg.scheduler == "wsd": - cooldown_start_step = cfg.steps_budget - cooldown_steps + elif cfg.training.scheduler == "wsd": + cooldown_start_step = cfg.training.steps_budget - cooldown_steps scheduler = WSD( optimizer, - lr_start=cfg.lr_start, - lr_max=cfg.lr, + lr_start=cfg.training.lr_start, + lr_max=cfg.training.lr, lr_end=lr_end, warmup_steps=warmup_steps, cooldown_start_step=cooldown_start_step, cooldown_steps=cooldown_steps, ) - elif cfg.scheduler == "warmup_constant": + elif cfg.training.scheduler == "warmup_constant": scheduler = WarmupConstant( optimizer, - lr_start=cfg.lr_start, - lr_max=cfg.lr, + lr_start=cfg.training.lr_start, + lr_max=cfg.training.lr, warmup_steps=warmup_steps, ) - elif cfg.scheduler == "linear_cooldown": - cooldown_start_step = cfg.resume_step + elif cfg.training.scheduler == "linear_cooldown": + cooldown_start_step = cfg.checkpoint.resume_step scheduler = LinearCooldown( optimizer, - lr_max=cfg.lr, + lr_max=cfg.training.lr, lr_end=lr_end, cooldown_start_step=cooldown_start_step, cooldown_steps=cooldown_steps, ) else: - raise NotImplementedError(f"Not implemented scheduler: {cfg.scheduler}.") + raise NotImplementedError(f"Not implemented scheduler: {cfg.training.scheduler}.") return scheduler diff --git a/pyproject.toml b/pyproject.toml index 3181f12..17143c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "pandas>=3.0.0", "python-dotenv>=1.2.1", "torch>=2.10.0", + "tqdm>=4.67.3", "transformers>=5.1.0", "wandb>=0.25.0", ] @@ -22,6 +23,7 @@ dependencies = [ dev = [ "ruff>=0.15.1", "pre-commit>=4.5.1", + "pytest>=9.0.2", ] [build-system] @@ -35,6 +37,7 @@ packages = ["models"] target-version = "py312" line-length = 100 src = ["src"] +exclude = ["libs/"] [tool.ruff.lint] select = ["C", "E", "F", "I", "W", "D", "N"] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..f70441e --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Quark test suite.""" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..59d5381 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,19 @@ +"""Shared fixtures for tests.""" + +import pytest +from hydra import compose, initialize +from omegaconf import DictConfig + + +@pytest.fixture() +def cfg() -> DictConfig: + """Return default resolved config.""" + with initialize(version_base=None, config_path="../configs"): + return compose(config_name="config") + + +@pytest.fixture() +def cfg_delta_net() -> DictConfig: + """Return config with model=delta_net.""" + with initialize(version_base=None, config_path="../configs"): + return compose(config_name="config", overrides=["model=delta_net"]) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py new file mode 100644 index 0000000..21846af --- /dev/null +++ b/tests/test_checkpoint.py @@ -0,0 +1,103 @@ +"""Tests for checkpoint save/load utilities.""" + +import os +import tempfile + +import torch +from checkpoint_utils import maybe_load_checkpoint, save_checkpoint +from omegaconf import OmegaConf + + +def _make_checkpoint_cfg(tmpdir, exp_name="test_exp", resume=False, resume_step=None): + """Create a minimal config for checkpoint tests.""" + return OmegaConf.create( + { + "out_dir": tmpdir, + "checkpoint": { + "exp_name": exp_name, + "resume": resume, + "resume_step": resume_step, + "resume_exp_name": None, + "save_intermediate_checkpoints": True, + "save_last_checkpoint": True, + "over_write": True, + }, + } + ) + + +def _dummy_engine(): + """Create a minimal mock engine with optimizer, scheduler, scaler.""" + + class MockEngine: + pass + + param = torch.nn.Parameter(torch.zeros(4)) + engine = MockEngine() + engine.optimizer = torch.optim.SGD([param], lr=0.01) + engine.scheduler = None + engine.scaler = torch.amp.GradScaler(enabled=False) + return engine + + +def test_save_and_load_checkpoint(): + """Should save and reload a checkpoint correctly.""" + with tempfile.TemporaryDirectory() as tmpdir: + model = torch.nn.Linear(4, 4) + engine = _dummy_engine() + cfg = _make_checkpoint_cfg(tmpdir) + + save_checkpoint(step=100, model=model, engine=engine, cfg=cfg, metrics={"loss": [1.0]}) + + # Verify file was created + ckpt_path = os.path.join(tmpdir, "test_exp", "ckpt_step_100.pth") + assert os.path.exists(ckpt_path) + + # Load it back + load_cfg = _make_checkpoint_cfg(tmpdir, resume=True, resume_step=100) + ckpt = maybe_load_checkpoint(load_cfg) + + assert ckpt is not None + assert ckpt["step"] == 100 + assert "state_dict" in ckpt + assert "optimizer" in ckpt + + +def test_maybe_load_checkpoint_returns_none_when_not_resuming(): + """Should return None when resume=False.""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = _make_checkpoint_cfg(tmpdir, resume=False) + assert maybe_load_checkpoint(cfg) is None + + +def test_save_checkpoint_creates_metrics_file(): + """Should write a metrics.json alongside the checkpoint.""" + with tempfile.TemporaryDirectory() as tmpdir: + model = torch.nn.Linear(4, 4) + engine = _dummy_engine() + cfg = _make_checkpoint_cfg(tmpdir) + + metrics = {"train/loss": [2.5, 2.3, 2.1]} + save_checkpoint(step=50, model=model, engine=engine, cfg=cfg, metrics=metrics) + + metrics_path = os.path.join(tmpdir, "test_exp", "metrics.json") + assert os.path.exists(metrics_path) + + +def test_load_latest_checkpoint(): + """Should load the latest checkpoint when resume_step is None.""" + with tempfile.TemporaryDirectory() as tmpdir: + model = torch.nn.Linear(4, 4) + engine = _dummy_engine() + cfg = _make_checkpoint_cfg(tmpdir) + + # Save two checkpoints + save_checkpoint(step=100, model=model, engine=engine, cfg=cfg) + save_checkpoint(step=200, model=model, engine=engine, cfg=cfg) + + # Load latest (resume_step=None) + load_cfg = _make_checkpoint_cfg(tmpdir, resume=True, resume_step=None) + ckpt = maybe_load_checkpoint(load_cfg) + + assert ckpt is not None + assert ckpt["step"] == 200 diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..5904583 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,222 @@ +"""Tests for Hydra config composition.""" + +from hydra import compose, initialize +from omegaconf import OmegaConf + +# -- Structure tests --------------------------------------------------------- + + +def test_default_config_has_all_groups(cfg): + """Default config should contain all top-level groups.""" + for group in ("model", "data", "training", "system", "logging", "checkpoint"): + assert group in cfg, f"Missing config group: {group}" + + +def test_no_flat_keys_at_root(cfg): + """No training/system/logging/checkpoint keys should leak to the root.""" + root_keys = set(OmegaConf.to_container(cfg, resolve=True).keys()) + expected_root = {"model", "data", "training", "system", "logging", "checkpoint", "out_dir"} + assert root_keys == expected_root, f"Unexpected root keys: {root_keys - expected_root}" + + +# -- Training config --------------------------------------------------------- + + +def test_training_keys(cfg): + """Training group should contain all expected keys.""" + expected = { + "steps_budget", + "micro_batch_size", + "eval_every_steps", + "log_every_steps", + "grad_accumulation_steps", + "num_workers", + "sampler", + "sampler_seed", + "optim", + "fused_optim", + "lr", + "weight_decay", + "beta1", + "beta2", + "grad_clip", + "eps", + "scheduler", + "warmup_steps", + "cooldown_steps", + "lr_start", + "lr_end", + "lr_end_pct", + "early_stopping_patience", + } + actual = set(OmegaConf.to_container(cfg.training, resolve=True).keys()) + assert expected == actual, f"Missing: {expected - actual}, Extra: {actual - expected}" + + +def test_training_default_values(cfg): + """Spot-check default training values.""" + assert cfg.training.optim == "adamw" + assert cfg.training.scheduler == "warmup_cosine" + assert cfg.training.lr == 7e-4 + assert cfg.training.grad_accumulation_steps == 2 + assert cfg.training.micro_batch_size == 16 + assert cfg.training.sampler == "stateful_random" + + +# -- System config ----------------------------------------------------------- + + +def test_system_keys(cfg): + """System group should contain expected keys.""" + expected = {"dtype", "compile_model", "seed", "ddp_backend"} + actual = set(OmegaConf.to_container(cfg.system, resolve=True).keys()) + assert expected == actual + + +def test_system_default_values(cfg): + """Spot-check default system values.""" + assert cfg.system.dtype == "bfloat16" + assert cfg.system.compile_model is False + assert cfg.system.seed == 42 + + +# -- Logging config ---------------------------------------------------------- + + +def test_logging_keys(cfg): + """Logging group should contain expected keys.""" + expected = {"wandb_log", "wandb_project", "wandb_log_layer_stats"} + actual = set(OmegaConf.to_container(cfg.logging, resolve=True).keys()) + assert expected == actual + + +def test_logging_default_values(cfg): + """Spot-check default logging values.""" + assert cfg.logging.wandb_project == "quark" + assert cfg.logging.wandb_log is True + + +# -- Checkpoint config ------------------------------------------------------- + + +def test_checkpoint_keys(cfg): + """Checkpoint group should contain expected keys.""" + expected = { + "save_last_checkpoint", + "save_intermediate_checkpoints", + "save_every_steps", + "resume", + "resume_step", + "resume_exp_name", + "over_write", + "exp_name", + } + actual = set(OmegaConf.to_container(cfg.checkpoint, resolve=True).keys()) + assert expected == actual + + +def test_checkpoint_default_values(cfg): + """Spot-check default checkpoint values.""" + assert cfg.checkpoint.resume is False + assert cfg.checkpoint.exp_name == "default" + assert cfg.checkpoint.save_last_checkpoint is True + + +# -- Data config ------------------------------------------------------------- + + +def test_data_keys(cfg): + """Data group should contain expected keys.""" + expected = { + "dataset", + "vocab_size", + "trainset_path", + "validset_path", + "seq_len", + "eval", + "valid_tokens", + } + actual = set(OmegaConf.to_container(cfg.data, resolve=True).keys()) + assert expected == actual + + +def test_data_has_vocab_size(cfg): + """Data config should define vocab_size.""" + assert cfg.data.vocab_size == 50304 + + +def test_data_does_not_have_training_fields(cfg): + """Data config should not contain fields that moved to training.""" + data_keys = set(OmegaConf.to_container(cfg.data, resolve=True).keys()) + for field in ("micro_batch_size", "num_workers", "sampler", "sampler_seed"): + assert field not in data_keys, f"'{field}' should be in training, not data" + + +# -- Model configs ----------------------------------------------------------- + + +def test_default_model_is_transformer(cfg): + """Default model should be transformer.""" + assert cfg.model.model_type == "transformer" + assert cfg.model.hidden_size == 384 + assert cfg.model.num_layers == 21 + + +def test_delta_net_model_override(cfg_delta_net): + """model=delta_net should switch to delta_net config.""" + assert cfg_delta_net.model.model_type == "delta_net" + assert cfg_delta_net.model.hidden_size == 1024 + assert cfg_delta_net.model.num_layers == 23 + + +def test_model_override_preserves_other_groups(cfg_delta_net): + """Switching model should not affect other config groups.""" + assert cfg_delta_net.training.optim == "adamw" + assert cfg_delta_net.system.dtype == "bfloat16" + assert cfg_delta_net.data.vocab_size == 50304 + + +# -- CLI overrides ----------------------------------------------------------- + + +def test_cli_override_training(): + """CLI overrides should propagate to the correct group.""" + with initialize(version_base=None, config_path="../configs"): + cfg = compose(config_name="config", overrides=["training.lr=1e-4"]) + assert cfg.training.lr == 1e-4 + + +def test_cli_override_system(): + """CLI override for system group should work.""" + with initialize(version_base=None, config_path="../configs"): + cfg = compose(config_name="config", overrides=["system.seed=123"]) + assert cfg.system.seed == 123 + + +def test_cli_override_checkpoint(): + """CLI override for checkpoint group should work.""" + with initialize(version_base=None, config_path="../configs"): + cfg = compose( + config_name="config", + overrides=["checkpoint.exp_name=my_experiment", "checkpoint.resume=true"], + ) + assert cfg.checkpoint.exp_name == "my_experiment" + assert cfg.checkpoint.resume is True + + +def test_cli_override_multiple(): + """Multiple CLI overrides should all take effect.""" + with initialize(version_base=None, config_path="../configs"): + cfg = compose( + config_name="config", + overrides=[ + "model=delta_net", + "training.lr=1e-3", + "training.steps_budget=5000", + "system.compile_model=true", + ], + ) + assert cfg.model.model_type == "delta_net" + assert cfg.training.lr == 1e-3 + assert cfg.training.steps_budget == 5000 + assert cfg.system.compile_model is True diff --git a/tests/test_engine.py b/tests/test_engine.py new file mode 100644 index 0000000..7eab96e --- /dev/null +++ b/tests/test_engine.py @@ -0,0 +1,87 @@ +"""Tests for TorchEngine initialization with structured config.""" + +import torch +from engine import TorchEngine +from models.transformer import Transformer, TransformerConfig +from omegaconf import OmegaConf + + +def _small_model(): + """Create a tiny transformer for testing.""" + config = TransformerConfig( + vocab_size=256, hidden_size=32, num_layers=2, num_heads=2, head_dim=16, block_size=64 + ) + return Transformer(config) + + +def test_engine_init(cfg): + """TorchEngine should initialize from structured config without errors.""" + # Override to avoid needing GPU / real data paths + test_cfg = OmegaConf.merge( + cfg, + OmegaConf.create( + { + "system": {"compile_model": False, "dtype": "float32"}, + "checkpoint": {"resume": False}, + "data": {"seq_len": 64}, + } + ), + ) + model = _small_model() + engine = TorchEngine(model, test_cfg, device="cpu", local_rank=None, ckpt=None) + + assert engine.seq_len == 64 + assert engine.accumulation_steps == test_cfg.training.grad_accumulation_steps + assert engine.grad_clip == test_cfg.training.grad_clip + assert engine.dtype == "float32" + assert engine.optimizer is not None + assert engine.scheduler is not None + + +def test_engine_step(cfg): + """TorchEngine.step should return loss and grad_norm.""" + test_cfg = OmegaConf.merge( + cfg, + OmegaConf.create( + { + "system": {"compile_model": False, "dtype": "float32"}, + "checkpoint": {"resume": False}, + "data": {"seq_len": 16}, + "training": {"grad_accumulation_steps": 1}, + } + ), + ) + model = _small_model() + engine = TorchEngine(model, test_cfg, device="cpu", local_rank=None, ckpt=None) + + # Create a fake batch + batch = {"input_ids": torch.randint(0, 256, (2, 17))} # seq_len + 1 + loss, grad_norm = engine.step(batch) + + assert loss is not None + assert loss.ndim == 0 + assert not torch.isnan(loss) + assert grad_norm is not None + + +def test_engine_eval(cfg): + """TorchEngine.eval should return average loss.""" + test_cfg = OmegaConf.merge( + cfg, + OmegaConf.create( + { + "system": {"compile_model": False, "dtype": "float32"}, + "checkpoint": {"resume": False}, + "data": {"seq_len": 16}, + } + ), + ) + model = _small_model() + engine = TorchEngine(model, test_cfg, device="cpu", local_rank=None, ckpt=None) + + # Create a tiny DataLoader-like iterable + batches = [{"input_ids": torch.randint(0, 256, (2, 17))} for _ in range(3)] + avg_loss = engine.eval(batches) + + assert isinstance(avg_loss, float) + assert avg_loss > 0 diff --git a/tests/test_model.py b/tests/test_model.py new file mode 100644 index 0000000..fcd9f87 --- /dev/null +++ b/tests/test_model.py @@ -0,0 +1,95 @@ +"""Tests for model building and instantiation.""" + +import pytest +import torch +from models.transformer import Transformer, TransformerConfig + +try: + from models.delta_net import DeltaNet, DeltaNetWrapperConfig + + HAS_FLA = True +except ImportError: + HAS_FLA = False + + +# -- build_model (transformer only, no triton needed) ------------------------- + + +def test_build_model_transformer(cfg): + """build_model should create a Transformer with correct dimensions from config.""" + vocab_size = cfg.data.vocab_size + config = TransformerConfig( + vocab_size=vocab_size, + hidden_size=cfg.model.hidden_size, + num_layers=cfg.model.num_layers, + num_heads=cfg.model.num_heads, + head_dim=cfg.model.head_dim, + block_size=cfg.data.seq_len, + dropout=cfg.model.dropout, + bias=cfg.model.bias, + ) + model = Transformer(config) + + assert isinstance(model, Transformer) + assert config.hidden_size == 384 + assert config.num_layers == 21 + assert config.num_heads == 8 + assert config.vocab_size == 50304 + + +@pytest.mark.skipif(not HAS_FLA, reason="flash-linear-attention / triton not available") +def test_build_model_delta_net(cfg_delta_net): + """build_model should create a DeltaNet with correct dimensions.""" + cfg = cfg_delta_net + config = DeltaNetWrapperConfig( + vocab_size=cfg.data.vocab_size, + hidden_size=cfg.model.hidden_size, + num_layers=cfg.model.num_layers, + num_heads=cfg.model.num_heads, + ) + model = DeltaNet(config) + + assert isinstance(model, DeltaNet) + assert config.hidden_size == 1024 + assert config.num_layers == 23 + + +# -- Transformer forward ------------------------------------------------------ + + +def test_transformer_forward(): + """Transformer forward should return (logits, loss) with correct shapes.""" + config = TransformerConfig( + vocab_size=256, hidden_size=32, num_layers=2, num_heads=2, head_dim=16, block_size=64 + ) + model = Transformer(config) + + idx = torch.randint(0, 256, (2, 16)) + targets = torch.randint(0, 256, (2, 16)) + + logits, loss = model(idx, targets) + assert logits.shape == (2, 16, 256) + assert loss is not None + assert loss.ndim == 0 # scalar + + +def test_transformer_forward_no_targets(): + """Transformer forward without targets should return loss=None.""" + config = TransformerConfig( + vocab_size=256, hidden_size=32, num_layers=2, num_heads=2, head_dim=16, block_size=64 + ) + model = Transformer(config) + + idx = torch.randint(0, 256, (2, 16)) + logits, loss = model(idx) + assert logits.shape == (2, 16, 256) + assert loss is None + + +def test_transformer_weight_tying(): + """Embedding and lm_head weights should be tied.""" + config = TransformerConfig( + vocab_size=256, hidden_size=32, num_layers=2, num_heads=2, head_dim=16, block_size=64 + ) + model = Transformer(config) + assert model.wte.weight is model.lm_head.weight diff --git a/tests/test_optim.py b/tests/test_optim.py new file mode 100644 index 0000000..ba435fc --- /dev/null +++ b/tests/test_optim.py @@ -0,0 +1,107 @@ +"""Tests for optimizer and scheduler initialization.""" + +import torch +from omegaconf import OmegaConf +from optim import initialize_optimizer, initialize_scheduler + + +def _dummy_param_groups(lr=1e-3): + """Create minimal param groups for testing.""" + param = torch.nn.Parameter(torch.zeros(4)) + return [{"params": [param], "weight_decay": 0.1}] + + +# -- Optimizer ---------------------------------------------------------------- + + +def test_initialize_adamw(cfg): + """Should create an AdamW optimizer from structured config.""" + groups = _dummy_param_groups() + optimizer = initialize_optimizer(groups, cfg) + assert isinstance(optimizer, torch.optim.AdamW) + assert optimizer.defaults["lr"] == cfg.training.lr + + +def test_initialize_nadamw(cfg): + """Should create an NAdam optimizer when optim=nadamw.""" + test_cfg = OmegaConf.merge(cfg, OmegaConf.create({"training": {"optim": "nadamw"}})) + groups = _dummy_param_groups() + optimizer = initialize_optimizer(groups, test_cfg) + assert isinstance(optimizer, torch.optim.NAdam) + + +def test_initialize_sgd(cfg): + """Should create an SGD optimizer when optim=sgd.""" + test_cfg = OmegaConf.merge(cfg, OmegaConf.create({"training": {"optim": "sgd"}})) + groups = _dummy_param_groups() + optimizer = initialize_optimizer(groups, test_cfg) + assert isinstance(optimizer, torch.optim.SGD) + + +# -- Scheduler ---------------------------------------------------------------- + + +def test_initialize_warmup_cosine(cfg): + """Should create a WarmupCosine scheduler.""" + groups = _dummy_param_groups() + optimizer = torch.optim.SGD(groups, lr=1e-3) + scheduler = initialize_scheduler(optimizer, cfg) + assert scheduler is not None + assert hasattr(scheduler, "step") + assert hasattr(scheduler, "warmup_steps") + + +def test_initialize_wsd(cfg): + """Should create a WSD scheduler.""" + test_cfg = OmegaConf.merge( + cfg, + OmegaConf.create( + { + "training": { + "scheduler": "wsd", + "cooldown_steps": 100, + "lr_end": 1e-5, + }, + } + ), + ) + groups = _dummy_param_groups() + optimizer = torch.optim.SGD(groups, lr=1e-3) + scheduler = initialize_scheduler(optimizer, test_cfg) + assert scheduler is not None + + +def test_initialize_warmup_constant(cfg): + """Should create a WarmupConstant scheduler.""" + test_cfg = OmegaConf.merge( + cfg, OmegaConf.create({"training": {"scheduler": "warmup_constant"}}) + ) + groups = _dummy_param_groups() + optimizer = torch.optim.SGD(groups, lr=1e-3) + scheduler = initialize_scheduler(optimizer, test_cfg) + assert scheduler is not None + + +def test_scheduler_none(cfg): + """Should return None when scheduler is null.""" + test_cfg = OmegaConf.merge(cfg, OmegaConf.create({"training": {"scheduler": None}})) + groups = _dummy_param_groups() + optimizer = torch.optim.SGD(groups, lr=1e-3) + scheduler = initialize_scheduler(optimizer, test_cfg) + assert scheduler is None + + +def test_scheduler_step_changes_lr(cfg): + """Scheduler.step() should modify the optimizer's learning rate.""" + groups = _dummy_param_groups() + optimizer = torch.optim.SGD(groups, lr=1e-3) + scheduler = initialize_scheduler(optimizer, cfg) + + initial_lr = optimizer.param_groups[0]["lr"] + # Step through warmup + for _ in range(cfg.training.warmup_steps + 10): + scheduler.step() + post_warmup_lr = optimizer.param_groups[0]["lr"] + + # LR should have changed from initial (lr_start=0.0) + assert initial_lr != post_warmup_lr diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..7c2f02d --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,219 @@ +"""Tests for utility functions.""" + +import os +import tempfile + +import torch +from omegaconf import OmegaConf +from utils import get_param_groups, get_variant_name, log, maybe_make_dir + +# -- get_variant_name -------------------------------------------------------- + + +def test_variant_name_transformer(cfg): + """Transformer config should return 'Transformer'.""" + assert get_variant_name(cfg) == "Transformer" + + +def test_variant_name_delta_net(cfg_delta_net): + """DeltaNet config should return 'DeltaNet'.""" + assert get_variant_name(cfg_delta_net) == "DeltaNet" + + +# -- get_param_groups -------------------------------------------------------- + + +def test_param_groups_separates_decay(): + """Linear weights should decay, biases and norms should not.""" + model = torch.nn.Sequential( + torch.nn.Linear(4, 8, bias=True), + torch.nn.LayerNorm(8), + torch.nn.Linear(8, 4, bias=False), + ) + groups = get_param_groups(model, weight_decay=0.1) + assert len(groups) == 2 + + decay_group = groups[0] + no_decay_group = groups[1] + + assert decay_group["weight_decay"] == 0.1 + assert no_decay_group["weight_decay"] == 0.0 + + # Linear weights should be in decay group + decay_numel = sum(p.numel() for p in decay_group["params"]) + no_decay_numel = sum(p.numel() for p in no_decay_group["params"]) + + # 4*8 + 8*4 = 64 in decay (two Linear weights) + assert decay_numel == 64 + # 8 (bias) + 8 + 8 (LayerNorm weight + bias) = 24 in no_decay + assert no_decay_numel == 24 + + +def test_param_groups_all_params_accounted(): + """Every parameter should be in exactly one group.""" + model = torch.nn.Sequential( + torch.nn.Embedding(100, 16), + torch.nn.Linear(16, 32), + torch.nn.Linear(32, 16, bias=False), + ) + groups = get_param_groups(model, weight_decay=0.05) + group_params = set() + for g in groups: + for p in g["params"]: + group_params.add(id(p)) + + all_params = {id(p) for p in model.parameters()} + assert group_params == all_params + + +# -- maybe_make_dir ---------------------------------------------------------- + + +def test_maybe_make_dir_creates_directory(): + """Should create the experiment directory.""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = OmegaConf.create( + { + "out_dir": tmpdir, + "checkpoint": { + "save_intermediate_checkpoints": True, + "save_last_checkpoint": True, + "resume": False, + "resume_exp_name": None, + "over_write": True, + "exp_name": "test_exp", + }, + } + ) + maybe_make_dir(cfg) + assert os.path.isdir(os.path.join(tmpdir, "test_exp")) + + +def test_maybe_make_dir_skips_when_no_checkpoints(): + """Should not create a directory when checkpointing is disabled.""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = OmegaConf.create( + { + "out_dir": tmpdir, + "checkpoint": { + "save_intermediate_checkpoints": False, + "save_last_checkpoint": False, + "resume": False, + "resume_exp_name": None, + "over_write": True, + "exp_name": "should_not_exist", + }, + } + ) + maybe_make_dir(cfg) + assert not os.path.exists(os.path.join(tmpdir, "should_not_exist")) + + +def test_maybe_make_dir_overwrites_existing(): + """Should remove and recreate when over_write is True.""" + with tempfile.TemporaryDirectory() as tmpdir: + exp_dir = os.path.join(tmpdir, "overwrite_exp") + os.makedirs(exp_dir) + marker = os.path.join(exp_dir, "old_file.txt") + with open(marker, "w") as f: + f.write("old") + + cfg = OmegaConf.create( + { + "out_dir": tmpdir, + "checkpoint": { + "save_intermediate_checkpoints": True, + "save_last_checkpoint": True, + "resume": False, + "resume_exp_name": None, + "over_write": True, + "exp_name": "overwrite_exp", + }, + } + ) + maybe_make_dir(cfg) + assert os.path.isdir(exp_dir) + assert not os.path.exists(marker) + + +# -- log --------------------------------------------------------------------- + + +def test_log_populates_metrics(cfg): + """Log should append metrics to the metrics dict.""" + metrics = {} + # We need a mock optimizer with param_groups + dummy_param = torch.nn.Parameter(torch.zeros(1)) + optimizer = torch.optim.SGD([dummy_param], lr=0.001) + + # Disable wandb for this test + test_cfg = OmegaConf.merge(cfg, OmegaConf.create({"logging": {"wandb_log": False}})) + + train_loss = torch.tensor(2.5) + + log( + test_cfg, + metrics, + micro_step=100, + train_loss=train_loss, + train_loss_array=[train_loss], + valid_loss=None, + optimizer=optimizer, + world_size=1, + grad_norm=torch.tensor(0.5), + ) + + assert "step" in metrics + assert "train/loss" in metrics + assert "train/grad_norm" in metrics + assert len(metrics["step"]) == 1 + assert metrics["step"][0] == 100 // cfg.training.grad_accumulation_steps + + +def test_log_includes_throughput(cfg): + """Log should include throughput and step_time when provided.""" + metrics = {} + dummy_param = torch.nn.Parameter(torch.zeros(1)) + optimizer = torch.optim.SGD([dummy_param], lr=0.001) + test_cfg = OmegaConf.merge(cfg, OmegaConf.create({"logging": {"wandb_log": False}})) + train_loss = torch.tensor(2.5) + + log( + test_cfg, + metrics, + micro_step=100, + train_loss=train_loss, + train_loss_array=[train_loss], + valid_loss=None, + optimizer=optimizer, + world_size=1, + throughput=50000.0, + step_time=0.5, + ) + + assert "train/throughput" in metrics + assert "train/step_time" in metrics + assert metrics["train/throughput"][0] == 50000.0 + + +def test_log_includes_valid_loss(cfg): + """Log should include valid/loss when provided.""" + metrics = {} + dummy_param = torch.nn.Parameter(torch.zeros(1)) + optimizer = torch.optim.SGD([dummy_param], lr=0.001) + test_cfg = OmegaConf.merge(cfg, OmegaConf.create({"logging": {"wandb_log": False}})) + train_loss = torch.tensor(2.5) + + log( + test_cfg, + metrics, + micro_step=100, + train_loss=train_loss, + train_loss_array=[train_loss], + valid_loss=1.8, + optimizer=optimizer, + world_size=1, + ) + + assert "valid/loss" in metrics + assert "valid/ppl" in metrics diff --git a/torch_utils.py b/torch_utils.py index ee58fc2..cbac4d5 100644 --- a/torch_utils.py +++ b/torch_utils.py @@ -32,9 +32,9 @@ def pytorch_setup(cfg): elif torch.backends.mps.is_available(): device = "mps" - random.seed(cfg.seed + seed_offset) - np.random.seed(cfg.seed + seed_offset) - torch.manual_seed(cfg.seed + seed_offset) + random.seed(cfg.system.seed + seed_offset) + np.random.seed(cfg.system.seed + seed_offset) + torch.manual_seed(cfg.system.seed + seed_offset) torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True diff --git a/train.py b/train.py index c829c8a..505b347 100644 --- a/train.py +++ b/train.py @@ -2,12 +2,14 @@ Usage: python train.py + python train.py model=delta_net python train.py training.steps_budget=10000 python train.py training.optim=nadamw training.scheduler=wsd training.cooldown_steps=100 torchrun --standalone --nproc_per_node=4 train.py """ import os +import time from collections import defaultdict import hydra @@ -15,14 +17,15 @@ from data import get_dataloaders from engine import TorchEngine from models import DeltaNet, DeltaNetWrapperConfig, Transformer, TransformerConfig -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from torch_utils import destroy_ddp, pytorch_setup -from utils import flatten_config, log, maybe_make_dir, print_master +from tqdm import tqdm +from utils import get_variant_name, log, maybe_make_dir, print_master -def build_model(cfg: DictConfig, flat_cfg): +def build_model(cfg: DictConfig): """Build a model from config, dispatching on model_type.""" - vocab_size = flat_cfg.vocab_size if hasattr(flat_cfg, "vocab_size") else 50304 + vocab_size = cfg.data.vocab_size model_type = cfg.model.model_type if model_type == "delta_net": @@ -46,7 +49,7 @@ def build_model(cfg: DictConfig, flat_cfg): num_layers=cfg.model.num_layers, num_heads=cfg.model.num_heads, head_dim=cfg.model.head_dim, - block_size=flat_cfg.seq_len, + block_size=cfg.data.seq_len, dropout=cfg.model.dropout, bias=cfg.model.bias, ) @@ -56,55 +59,71 @@ def build_model(cfg: DictConfig, flat_cfg): @hydra.main(version_base=None, config_path="configs", config_name="config") def main(cfg: DictConfig): # noqa: C901 """Main training function.""" - flat_cfg = flatten_config(cfg) - - local_rank, world_size, device, master_process = pytorch_setup(flat_cfg) + # Setup + local_rank, world_size, device, master_process = pytorch_setup(cfg) if master_process: - maybe_make_dir(flat_cfg) + maybe_make_dir(cfg) - ckpt = maybe_load_checkpoint(flat_cfg) + # Checkpoint + ckpt = maybe_load_checkpoint(cfg) - trainloader, validloader = get_dataloaders(flat_cfg) + # Data + trainloader, validloader = get_dataloaders(cfg) - model, config = build_model(cfg, flat_cfg) + # Model + model, config = build_model(cfg) if master_process: + variant_name = get_variant_name(cfg) num_params = sum(p.numel() for p in model.parameters()) + total_tokens = ( + cfg.training.steps_budget + * cfg.training.grad_accumulation_steps + * cfg.training.micro_batch_size + * cfg.data.seq_len + * world_size + ) print("=" * 80) - print(f"Training {cfg.model.model_type}") + print(f"Training {variant_name}") print(f" Device: {device}, dtype: {cfg.system.dtype}") print(f" Hidden: {config.hidden_size}, Layers: {config.num_layers}") if hasattr(config, "head_dim"): print(f" Heads: {config.num_heads}, Head dim: {config.head_dim}") else: print(f" Heads: {config.num_heads}") - total_tokens = ( - flat_cfg.steps_budget - * flat_cfg.grad_accumulation_steps - * flat_cfg.micro_batch_size - * flat_cfg.seq_len - * world_size - ) print(f" Parameters: {num_params:,}") - print(f" Steps budget: {flat_cfg.steps_budget}") - print(f" Total tokens: {total_tokens:,}") + print(f" Steps budget: {cfg.training.steps_budget}") + print(f" Total training tokens: {total_tokens:,}") print("=" * 80) - engine = TorchEngine(model, flat_cfg, device, local_rank, ckpt) + # Engine + engine = TorchEngine(model, cfg, device, local_rank, ckpt) + # W&B if cfg.logging.wandb_log and master_process: import uuid import wandb - from omegaconf import OmegaConf + variant_name = get_variant_name(cfg) num_params = sum(p.numel() for p in model.parameters()) + total_tokens = ( + cfg.training.steps_budget + * cfg.training.grad_accumulation_steps + * cfg.training.micro_batch_size + * cfg.data.seq_len + * world_size + ) run_id = uuid.uuid4().hex[:6] wandb.init( project=cfg.logging.wandb_project, - name=f"{cfg.model.model_type}-L{cfg.model.num_layers}-D{cfg.model.hidden_size}-lr{cfg.training.lr}-{run_id}", - config={**OmegaConf.to_container(cfg, resolve=True), "num_params": num_params}, + name=f"{variant_name}-L{cfg.model.num_layers}-D{cfg.model.hidden_size}-lr{cfg.training.lr}-{run_id}", + config={ + **OmegaConf.to_container(cfg, resolve=True), + "num_params": num_params, + "total_tokens": total_tokens, + }, ) wandb.run.log_code( root=".", @@ -118,10 +137,11 @@ def main(cfg: DictConfig): # noqa: C901 wandb.define_metric("step") wandb.define_metric("*", step_metric="step") - steps_budget = flat_cfg.steps_budget - micro_step_budget = steps_budget * flat_cfg.grad_accumulation_steps - step_start = flat_cfg.resume_step if flat_cfg.resume else 0 - micro_step_start = step_start * flat_cfg.grad_accumulation_steps + # Training loop + steps_budget = cfg.training.steps_budget + micro_step_budget = steps_budget * cfg.training.grad_accumulation_steps + step_start = cfg.checkpoint.resume_step if cfg.checkpoint.resume else 0 + micro_step_start = step_start * cfg.training.grad_accumulation_steps print_master( f"=== Start Training from step {step_start}/{steps_budget}, " @@ -130,24 +150,55 @@ def main(cfg: DictConfig): # noqa: C901 metrics = defaultdict(list) train_loss_array = [] + t_log_start = time.time() + micro_step_prev = micro_step_start + pbar = None + if master_process: + pbar = tqdm( + total=steps_budget, + initial=step_start, + desc="Training steps", + dynamic_ncols=True, + ) for micro_step, micro_batch in enumerate(trainloader, micro_step_start + 1): - step = micro_step // flat_cfg.grad_accumulation_steps - is_step = micro_step % flat_cfg.grad_accumulation_steps == 0 + step = micro_step // cfg.training.grad_accumulation_steps + is_step = micro_step % cfg.training.grad_accumulation_steps == 0 if step > steps_budget and is_step: break + # Train train_loss, grad_norm = engine.step(micro_batch) train_loss_array.append(train_loss) + # Progress bar (one tick per optimizer step) + if master_process and is_step and pbar is not None: + pbar.update(1) + pbar.set_postfix(loss=float(train_loss)) + + # Eval valid_loss = None - if flat_cfg.eval and validloader and step % flat_cfg.eval_every_steps == 0 and is_step: + if cfg.data.eval and validloader and step % cfg.training.eval_every_steps == 0 and is_step: print_master("Evaluating on validation set") valid_loss = engine.eval(validloader) - if master_process and step % flat_cfg.log_every_steps == 0 and is_step: + # Log + if master_process and step % cfg.training.log_every_steps == 0 and is_step: + # Throughput + t_log_end = time.time() + dt = t_log_end - t_log_start + micro_steps_done = micro_step - micro_step_prev + throughput = ( + micro_steps_done + * cfg.training.micro_batch_size + * cfg.data.seq_len + * world_size + / dt + ) + step_time = dt / (micro_steps_done / cfg.training.grad_accumulation_steps) + log( - flat_cfg, + cfg, metrics, micro_step, train_loss, @@ -156,21 +207,30 @@ def main(cfg: DictConfig): # noqa: C901 engine.optimizer, world_size, grad_norm, + throughput=throughput, + step_time=step_time, ) train_loss_array = [] + t_log_start = time.time() + micro_step_prev = micro_step + # Checkpoint if ( master_process - and flat_cfg.save_intermediate_checkpoints - and flat_cfg.save_every_steps - and step % flat_cfg.save_every_steps == 0 + and cfg.checkpoint.save_intermediate_checkpoints + and cfg.checkpoint.save_every_steps + and step % cfg.checkpoint.save_every_steps == 0 and is_step ): - save_checkpoint(step, model, engine, flat_cfg, metrics) + save_checkpoint(step, model, engine, cfg, metrics) + + # End of training + if master_process and pbar is not None: + pbar.close() print_master("=== Training Completed! ===") - if master_process and flat_cfg.save_last_checkpoint: - save_checkpoint(step, model, engine, flat_cfg, metrics) + if master_process and cfg.checkpoint.save_last_checkpoint: + save_checkpoint(step, model, engine, cfg, metrics) if cfg.logging.wandb_log and master_process: import wandb diff --git a/utils.py b/utils.py index 918f9a4..425ef07 100644 --- a/utils.py +++ b/utils.py @@ -3,10 +3,9 @@ import math import os import shutil -from types import SimpleNamespace import torch -from omegaconf import DictConfig, OmegaConf +from omegaconf import DictConfig def print_master(msg): @@ -18,14 +17,11 @@ def print_master(msg): print(msg) -def flatten_config(cfg: DictConfig) -> SimpleNamespace: - """Flatten nested Hydra config into a flat namespace.""" - flat = {} - for group in ["model", "data", "training", "system", "logging", "checkpoint"]: - if group in cfg: - flat.update(OmegaConf.to_container(cfg[group], resolve=True)) - flat["out_dir"] = cfg.out_dir - return SimpleNamespace(**flat) +def get_variant_name(cfg: DictConfig) -> str: + """Derive variant name from model config.""" + model_type = cfg.model.model_type + names = {"transformer": "Transformer", "delta_net": "DeltaNet"} + return names.get(model_type, model_type) def get_param_groups(model, weight_decay): @@ -62,15 +58,15 @@ def get_param_groups(model, weight_decay): def maybe_make_dir(cfg): """Create experiment directory if checkpointing is enabled.""" - if not cfg.save_intermediate_checkpoints and not cfg.save_last_checkpoint: + if not cfg.checkpoint.save_intermediate_checkpoints and not cfg.checkpoint.save_last_checkpoint: return - if cfg.resume and cfg.resume_exp_name is None: + if cfg.checkpoint.resume and cfg.checkpoint.resume_exp_name is None: return - exp_dir = os.path.join(cfg.out_dir, cfg.exp_name) + exp_dir = os.path.join(cfg.out_dir, cfg.checkpoint.exp_name) if os.path.exists(exp_dir): - if not cfg.over_write: + if not cfg.checkpoint.over_write: raise ValueError(f"Found existing exp_dir at {exp_dir}.") print(f"Removing experiment dir: {exp_dir}") shutil.rmtree(exp_dir) @@ -89,6 +85,8 @@ def log( optimizer, world_size, grad_norm=None, + throughput=None, + step_time=None, ): """Update metrics, print to console, and log to W&B.""" if isinstance(train_loss_array, list): @@ -98,8 +96,8 @@ def log( new_metrics = { "micro_step": micro_step, - "step": micro_step // cfg.grad_accumulation_steps, - "tokens": micro_step * cfg.micro_batch_size * cfg.seq_len * world_size, + "step": micro_step // cfg.training.grad_accumulation_steps, + "tokens": micro_step * cfg.training.micro_batch_size * cfg.data.seq_len * world_size, "lr": optimizer.param_groups[0].get("lr", float("NaN")), "train/loss": train_loss.item(), "train/loss_avg": train_loss_avg, @@ -116,8 +114,13 @@ def log( grad_norm.item() if hasattr(grad_norm, "item") else grad_norm ) + if throughput is not None: + new_metrics["train/throughput"] = throughput + if step_time is not None: + new_metrics["train/step_time"] = step_time + for k, v in new_metrics.items(): - metrics[k].append(v) + metrics.setdefault(k, []).append(v) msg = " | ".join( f"{key}: {value:.3e}" if isinstance(value, float) else f"{key}: {value}" @@ -125,7 +128,7 @@ def log( ) print(msg) - if cfg.wandb_log: + if cfg.logging.wandb_log: import wandb wandb.log(dict(new_metrics)) diff --git a/uv.lock b/uv.lock index ec2d4fe..ff428e4 100644 --- a/uv.lock +++ b/uv.lock @@ -2198,6 +2198,7 @@ dependencies = [ { name = "pandas" }, { name = "python-dotenv" }, { name = "torch" }, + { name = "tqdm" }, { name = "transformers" }, { name = "wandb" }, ] @@ -2205,6 +2206,7 @@ dependencies = [ [package.optional-dependencies] dev = [ { name = "pre-commit" }, + { name = "pytest" }, { name = "ruff" }, ] @@ -2218,9 +2220,11 @@ requires-dist = [ { name = "matplotlib", specifier = ">=3.10.8" }, { name = "pandas", specifier = ">=3.0.0" }, { name = "pre-commit", marker = "extra == 'dev'", specifier = ">=4.5.1" }, + { name = "pytest", marker = "extra == 'dev'", specifier = ">=9.0.2" }, { name = "python-dotenv", specifier = ">=1.2.1" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.15.1" }, { name = "torch", specifier = ">=2.10.0" }, + { name = "tqdm", specifier = ">=4.67.3" }, { name = "transformers", specifier = ">=5.1.0" }, { name = "wandb", specifier = ">=0.25.0" }, ]