diff --git a/.gitignore b/.gitignore index c423ecb38..482a106b1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +swanlab_minimal_output # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/sae_lens/config.py b/sae_lens/config.py index 08ad37c7f..7edaf88db 100644 --- a/sae_lens/config.py +++ b/sae_lens/config.py @@ -3,11 +3,10 @@ import warnings from dataclasses import asdict, dataclass, field from pathlib import Path -from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar import simple_parsing import torch -import wandb from datasets import ( Dataset, DatasetDict, @@ -25,6 +24,7 @@ from sae_lens.registry import get_sae_training_class from sae_lens.saes.sae import TrainingSAEConfig from sae_lens.util import str_to_dtype +from sae_lens.wandb_compat import BACKEND, generate_id, wandb if TYPE_CHECKING: pass @@ -97,6 +97,8 @@ def log( sparsity_path: Path | str | None, wandb_aliases: list[str] | None = None, ) -> None: + if BACKEND == "swanlab": + return # Avoid wandb saving errors such as: # ValueError: Artifact name may only contain alphanumeric characters, dashes, underscores, and dots. Invalid name: sae_google/gemma-2b_etc sae_name = trainer.sae.get_name().replace("/", "__") @@ -315,9 +317,7 @@ def __post_init__(self): unique_id = self.logger.wandb_id if unique_id is None: - unique_id = cast( - Any, wandb - ).util.generate_id() # not sure why this type is erroring + unique_id = generate_id() self.checkpoint_path = f"{self.checkpoint_path}/{unique_id}" if self.verbose: diff --git a/sae_lens/llm_sae_training_runner.py b/sae_lens/llm_sae_training_runner.py index deb970d84..75fe8aca0 100644 --- a/sae_lens/llm_sae_training_runner.py +++ b/sae_lens/llm_sae_training_runner.py @@ -7,7 +7,6 @@ from typing import Any, Generic import torch -import wandb from safetensors.torch import save_file from simple_parsing import ArgumentParser from transformer_lens.hook_points import HookedRootModule @@ -32,6 +31,7 @@ from sae_lens.training.activations_store import ActivationsStore from sae_lens.training.sae_trainer import SAETrainer from sae_lens.training.types import DataProvider +from sae_lens.wandb_compat import wandb class InterruptedException(Exception): diff --git a/sae_lens/training/sae_trainer.py b/sae_lens/training/sae_trainer.py index 6cc926fbc..2fb2ef5a6 100644 --- a/sae_lens/training/sae_trainer.py +++ b/sae_lens/training/sae_trainer.py @@ -4,7 +4,6 @@ from typing import Any, Callable, Generic, Protocol import torch -import wandb from safetensors.torch import save_file from torch.optim import Adam from tqdm.auto import tqdm @@ -28,6 +27,7 @@ from sae_lens.training.optim import CoefficientScheduler, get_lr_scheduler from sae_lens.training.types import DataProvider from sae_lens.util import path_or_tmp_dir +from sae_lens.wandb_compat import BACKEND, wandb def _log_feature_sparsity( @@ -398,8 +398,9 @@ def _run_and_log_evals(self): if self.evaluator is not None else {} ) - for key, value in self.sae.log_histograms().items(): - eval_metrics[key] = wandb.Histogram(value) # type: ignore + if BACKEND != "swanlab": + for key, value in self.sae.log_histograms().items(): + eval_metrics[key] = wandb.Histogram(value) # type: ignore wandb.log( eval_metrics, @@ -410,13 +411,16 @@ def _run_and_log_evals(self): @torch.no_grad() def _build_sparsity_log_dict(self) -> dict[str, Any]: log_feature_sparsity = _log_feature_sparsity(self.feature_sparsity) - wandb_histogram = wandb.Histogram(log_feature_sparsity.numpy()) # type: ignore - return { + log_dict = { "metrics/mean_log10_feature_sparsity": log_feature_sparsity.mean().item(), - "plots/feature_density_line_chart": wandb_histogram, "sparsity/below_1e-5": (self.feature_sparsity < 1e-5).sum().item(), "sparsity/below_1e-6": (self.feature_sparsity < 1e-6).sum().item(), } + if BACKEND != "swanlab": + log_dict["plots/feature_density_line_chart"] = wandb.Histogram( + log_feature_sparsity.numpy() + ) + return log_dict @torch.no_grad() def _reset_running_sparsity_stats(self) -> None: diff --git a/sae_lens/wandb_compat.py b/sae_lens/wandb_compat.py new file mode 100644 index 000000000..8fd32d5b7 --- /dev/null +++ b/sae_lens/wandb_compat.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import os +import uuid +from typing import Any + + +def _load_backend() -> tuple[Any, str]: + backend = os.getenv("SAE_LENS_LOGGING_BACKEND", "auto").lower() + if backend not in {"auto", "wandb", "swanlab"}: + backend = "auto" + + if backend in {"auto", "swanlab"}: + try: + import swanlab as backend_module # type: ignore + + return backend_module, "swanlab" + except Exception: + if backend == "swanlab": + raise + + import wandb as backend_module # type: ignore + + return backend_module, "wandb" + + +wandb, BACKEND = _load_backend() + + +def generate_id() -> str: + util = getattr(wandb, "util", None) + if util is not None: + generator = getattr(util, "generate_id", None) + if callable(generator): + try: + return generator() + except Exception: + pass + return uuid.uuid4().hex diff --git a/tests/test_swanlab_log.py b/tests/test_swanlab_log.py new file mode 100644 index 000000000..56a3c3bb8 --- /dev/null +++ b/tests/test_swanlab_log.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import os +from pathlib import Path + +import pytest + +from sae_lens.config import LoggingConfig +from sae_lens.llm_sae_training_runner import LanguageModelSAETrainingRunner +from sae_lens.wandb_compat import BACKEND +from tests.helpers import ( + NEEL_NANDA_C4_10K_DATASET, + TINYSTORIES_MODEL, + build_runner_cfg_for_arch, +) + + +def run_minimal_standard_sae_training(output_dir: Path) -> None: + cfg = build_runner_cfg_for_arch( + architecture="standard", + d_in=64, + d_sae=128, + training_tokens=64, + store_batch_size_prompts=2, + train_batch_size_tokens=4, + context_size=10, + n_batches_in_buffer=2, + n_eval_batches=1, + dataset_path=NEEL_NANDA_C4_10K_DATASET, + hook_name="blocks.0.hook_resid_post", + model_name=TINYSTORIES_MODEL, + n_checkpoints=0, + save_final_checkpoint=False, + output_path=str(output_dir), + logger=LoggingConfig( + log_to_wandb=True, + wandb_project=os.getenv("WANDB_PROJECT", "swanlab-minimal"), + wandb_entity=os.getenv("WANDB_ENTITY"), + wandb_log_frequency=1, + eval_every_n_wandb_logs=1, + ), + ) + + LanguageModelSAETrainingRunner(cfg).run() + + +def test_swanlab_logging_runs(tmp_path: Path) -> None: + if BACKEND != "swanlab": + pytest.skip("swanlab backend not active; set SAE_LENS_LOGGING_BACKEND=swanlab") + os.environ.setdefault("SAE_LENS_LOGGING_BACKEND", "swanlab") + os.environ.setdefault("SWANLAB_MODE", "offline") + os.environ.setdefault("WANDB_MODE", "offline") + run_minimal_standard_sae_training(tmp_path / "swanlab_output") + + +def main() -> None: + os.environ.setdefault("SAE_LENS_LOGGING_BACKEND", "swanlab") + os.environ.setdefault("SWANLAB_MODE", "offline") + os.environ.setdefault("WANDB_MODE", "offline") + output_dir = Path(os.getenv("SAE_OUTPUT_DIR", "swanlab_minimal_output")) + run_minimal_standard_sae_training(output_dir) + + +if __name__ == "__main__": + main()