From 628190a98b67d8a5a7606173dc553f9e628dd161 Mon Sep 17 00:00:00 2001 From: akai <2780896955@qq.com> Date: Wed, 24 Dec 2025 22:14:42 +0800 Subject: [PATCH 1/4] Add support for logging with SwanLab. --- sae_lens/config.py | 8 ++-- sae_lens/llm_sae_training_runner.py | 2 +- sae_lens/training/sae_trainer.py | 16 ++++--- sae_lens/wandb_compat.py | 39 +++++++++++++++++ tests/test_swanlab_log.py | 65 +++++++++++++++++++++++++++++ 5 files changed, 119 insertions(+), 11 deletions(-) create mode 100644 sae_lens/wandb_compat.py create mode 100644 tests/test_swanlab_log.py diff --git a/sae_lens/config.py b/sae_lens/config.py index a2a91833d..65700e2bc 100644 --- a/sae_lens/config.py +++ b/sae_lens/config.py @@ -7,7 +7,7 @@ import simple_parsing import torch -import wandb +from sae_lens.wandb_compat import BACKEND, generate_id, wandb from datasets import ( Dataset, DatasetDict, @@ -92,6 +92,8 @@ def log( sparsity_path: Path | str | None, wandb_aliases: list[str] | None = None, ) -> None: + if BACKEND == "swanlab": + return # Avoid wandb saving errors such as: # ValueError: Artifact name may only contain alphanumeric characters, dashes, underscores, and dots. Invalid name: sae_google/gemma-2b_etc sae_name = trainer.sae.get_name().replace("/", "__") @@ -310,9 +312,7 @@ def __post_init__(self): unique_id = self.logger.wandb_id if unique_id is None: - unique_id = cast( - Any, wandb - ).util.generate_id() # not sure why this type is erroring + unique_id = generate_id() self.checkpoint_path = f"{self.checkpoint_path}/{unique_id}" if self.verbose: diff --git a/sae_lens/llm_sae_training_runner.py b/sae_lens/llm_sae_training_runner.py index deb970d84..719d2ae73 100644 --- a/sae_lens/llm_sae_training_runner.py +++ b/sae_lens/llm_sae_training_runner.py @@ -7,7 +7,7 @@ from typing import Any, Generic import torch -import wandb +from sae_lens.wandb_compat import wandb from safetensors.torch import save_file from simple_parsing import ArgumentParser from transformer_lens.hook_points import HookedRootModule diff --git a/sae_lens/training/sae_trainer.py b/sae_lens/training/sae_trainer.py index 6cc926fbc..6060a6bb8 100644 --- a/sae_lens/training/sae_trainer.py +++ b/sae_lens/training/sae_trainer.py @@ -4,7 +4,7 @@ from typing import Any, Callable, Generic, Protocol import torch -import wandb +from sae_lens.wandb_compat import BACKEND, wandb from safetensors.torch import save_file from torch.optim import Adam from tqdm.auto import tqdm @@ -398,8 +398,9 @@ def _run_and_log_evals(self): if self.evaluator is not None else {} ) - for key, value in self.sae.log_histograms().items(): - eval_metrics[key] = wandb.Histogram(value) # type: ignore + if BACKEND != "swanlab": + for key, value in self.sae.log_histograms().items(): + eval_metrics[key] = wandb.Histogram(value) # type: ignore wandb.log( eval_metrics, @@ -410,13 +411,16 @@ def _run_and_log_evals(self): @torch.no_grad() def _build_sparsity_log_dict(self) -> dict[str, Any]: log_feature_sparsity = _log_feature_sparsity(self.feature_sparsity) - wandb_histogram = wandb.Histogram(log_feature_sparsity.numpy()) # type: ignore - return { + log_dict = { "metrics/mean_log10_feature_sparsity": log_feature_sparsity.mean().item(), - "plots/feature_density_line_chart": wandb_histogram, "sparsity/below_1e-5": (self.feature_sparsity < 1e-5).sum().item(), "sparsity/below_1e-6": (self.feature_sparsity < 1e-6).sum().item(), } + if BACKEND != "swanlab": + log_dict["plots/feature_density_line_chart"] = wandb.Histogram( + log_feature_sparsity.numpy() + ) + return log_dict @torch.no_grad() def _reset_running_sparsity_stats(self) -> None: diff --git a/sae_lens/wandb_compat.py b/sae_lens/wandb_compat.py new file mode 100644 index 000000000..8fd32d5b7 --- /dev/null +++ b/sae_lens/wandb_compat.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import os +import uuid +from typing import Any + + +def _load_backend() -> tuple[Any, str]: + backend = os.getenv("SAE_LENS_LOGGING_BACKEND", "auto").lower() + if backend not in {"auto", "wandb", "swanlab"}: + backend = "auto" + + if backend in {"auto", "swanlab"}: + try: + import swanlab as backend_module # type: ignore + + return backend_module, "swanlab" + except Exception: + if backend == "swanlab": + raise + + import wandb as backend_module # type: ignore + + return backend_module, "wandb" + + +wandb, BACKEND = _load_backend() + + +def generate_id() -> str: + util = getattr(wandb, "util", None) + if util is not None: + generator = getattr(util, "generate_id", None) + if callable(generator): + try: + return generator() + except Exception: + pass + return uuid.uuid4().hex diff --git a/tests/test_swanlab_log.py b/tests/test_swanlab_log.py new file mode 100644 index 000000000..56a3c3bb8 --- /dev/null +++ b/tests/test_swanlab_log.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import os +from pathlib import Path + +import pytest + +from sae_lens.config import LoggingConfig +from sae_lens.llm_sae_training_runner import LanguageModelSAETrainingRunner +from sae_lens.wandb_compat import BACKEND +from tests.helpers import ( + NEEL_NANDA_C4_10K_DATASET, + TINYSTORIES_MODEL, + build_runner_cfg_for_arch, +) + + +def run_minimal_standard_sae_training(output_dir: Path) -> None: + cfg = build_runner_cfg_for_arch( + architecture="standard", + d_in=64, + d_sae=128, + training_tokens=64, + store_batch_size_prompts=2, + train_batch_size_tokens=4, + context_size=10, + n_batches_in_buffer=2, + n_eval_batches=1, + dataset_path=NEEL_NANDA_C4_10K_DATASET, + hook_name="blocks.0.hook_resid_post", + model_name=TINYSTORIES_MODEL, + n_checkpoints=0, + save_final_checkpoint=False, + output_path=str(output_dir), + logger=LoggingConfig( + log_to_wandb=True, + wandb_project=os.getenv("WANDB_PROJECT", "swanlab-minimal"), + wandb_entity=os.getenv("WANDB_ENTITY"), + wandb_log_frequency=1, + eval_every_n_wandb_logs=1, + ), + ) + + LanguageModelSAETrainingRunner(cfg).run() + + +def test_swanlab_logging_runs(tmp_path: Path) -> None: + if BACKEND != "swanlab": + pytest.skip("swanlab backend not active; set SAE_LENS_LOGGING_BACKEND=swanlab") + os.environ.setdefault("SAE_LENS_LOGGING_BACKEND", "swanlab") + os.environ.setdefault("SWANLAB_MODE", "offline") + os.environ.setdefault("WANDB_MODE", "offline") + run_minimal_standard_sae_training(tmp_path / "swanlab_output") + + +def main() -> None: + os.environ.setdefault("SAE_LENS_LOGGING_BACKEND", "swanlab") + os.environ.setdefault("SWANLAB_MODE", "offline") + os.environ.setdefault("WANDB_MODE", "offline") + output_dir = Path(os.getenv("SAE_OUTPUT_DIR", "swanlab_minimal_output")) + run_minimal_standard_sae_training(output_dir) + + +if __name__ == "__main__": + main() From e1f005f540673a0bb8bd3007945ba144d6b17b46 Mon Sep 17 00:00:00 2001 From: akai <2780896955@qq.com> Date: Wed, 24 Dec 2025 22:33:16 +0800 Subject: [PATCH 2/4] ignore test outputdir --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index c423ecb38..482a106b1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +swanlab_minimal_output # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] From 9da0930b0bda6bb844d9f9c23b4b59aceef0b42f Mon Sep 17 00:00:00 2001 From: akai <2780896955@qq.com> Date: Thu, 25 Dec 2025 03:12:23 +0000 Subject: [PATCH 3/4] make format Signed-off-by: akai <2780896955@qq.com> --- sae_lens/config.py | 4 ++-- sae_lens/llm_sae_training_runner.py | 2 +- sae_lens/training/sae_trainer.py | 2 +- tests/_comparison/sae_lens_old/config.py | 2 +- tests/_comparison/sae_lens_old/sae_training_runner.py | 2 +- tests/_comparison/sae_lens_old/training/sae_trainer.py | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sae_lens/config.py b/sae_lens/config.py index 65700e2bc..552135fcd 100644 --- a/sae_lens/config.py +++ b/sae_lens/config.py @@ -3,11 +3,10 @@ import warnings from dataclasses import asdict, dataclass, field from pathlib import Path -from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar import simple_parsing import torch -from sae_lens.wandb_compat import BACKEND, generate_id, wandb from datasets import ( Dataset, DatasetDict, @@ -20,6 +19,7 @@ from sae_lens.registry import get_sae_training_class from sae_lens.saes.sae import TrainingSAEConfig from sae_lens.util import str_to_dtype +from sae_lens.wandb_compat import BACKEND, generate_id, wandb if TYPE_CHECKING: pass diff --git a/sae_lens/llm_sae_training_runner.py b/sae_lens/llm_sae_training_runner.py index 719d2ae73..75fe8aca0 100644 --- a/sae_lens/llm_sae_training_runner.py +++ b/sae_lens/llm_sae_training_runner.py @@ -7,7 +7,6 @@ from typing import Any, Generic import torch -from sae_lens.wandb_compat import wandb from safetensors.torch import save_file from simple_parsing import ArgumentParser from transformer_lens.hook_points import HookedRootModule @@ -32,6 +31,7 @@ from sae_lens.training.activations_store import ActivationsStore from sae_lens.training.sae_trainer import SAETrainer from sae_lens.training.types import DataProvider +from sae_lens.wandb_compat import wandb class InterruptedException(Exception): diff --git a/sae_lens/training/sae_trainer.py b/sae_lens/training/sae_trainer.py index 6060a6bb8..2fb2ef5a6 100644 --- a/sae_lens/training/sae_trainer.py +++ b/sae_lens/training/sae_trainer.py @@ -4,7 +4,6 @@ from typing import Any, Callable, Generic, Protocol import torch -from sae_lens.wandb_compat import BACKEND, wandb from safetensors.torch import save_file from torch.optim import Adam from tqdm.auto import tqdm @@ -28,6 +27,7 @@ from sae_lens.training.optim import CoefficientScheduler, get_lr_scheduler from sae_lens.training.types import DataProvider from sae_lens.util import path_or_tmp_dir +from sae_lens.wandb_compat import BACKEND, wandb def _log_feature_sparsity( diff --git a/tests/_comparison/sae_lens_old/config.py b/tests/_comparison/sae_lens_old/config.py index 75a963bcf..b57c4ee72 100644 --- a/tests/_comparison/sae_lens_old/config.py +++ b/tests/_comparison/sae_lens_old/config.py @@ -6,7 +6,6 @@ import simple_parsing import torch -import wandb from datasets import ( Dataset, DatasetDict, @@ -15,6 +14,7 @@ load_dataset, ) +import wandb from tests._comparison.sae_lens_old import __version__, logger DTYPE_MAP = { diff --git a/tests/_comparison/sae_lens_old/sae_training_runner.py b/tests/_comparison/sae_lens_old/sae_training_runner.py index 18bf1dc43..5fb1d7c3e 100644 --- a/tests/_comparison/sae_lens_old/sae_training_runner.py +++ b/tests/_comparison/sae_lens_old/sae_training_runner.py @@ -6,10 +6,10 @@ from typing import Any, cast import torch -import wandb from simple_parsing import ArgumentParser from transformer_lens.hook_points import HookedRootModule +import wandb from tests._comparison.sae_lens_old import logger from tests._comparison.sae_lens_old.config import ( HfDataset, diff --git a/tests/_comparison/sae_lens_old/training/sae_trainer.py b/tests/_comparison/sae_lens_old/training/sae_trainer.py index 69e00b138..098cb516b 100644 --- a/tests/_comparison/sae_lens_old/training/sae_trainer.py +++ b/tests/_comparison/sae_lens_old/training/sae_trainer.py @@ -2,11 +2,11 @@ from typing import Any, Protocol, cast import torch -import wandb from torch.optim import Adam from tqdm import tqdm from transformer_lens.hook_points import HookedRootModule +import wandb from tests._comparison.sae_lens_old import __version__ from tests._comparison.sae_lens_old.config import LanguageModelSAERunnerConfig from tests._comparison.sae_lens_old.evals import EvalConfig, run_evals From e46b04c82d6d5f11567b7cb1c6a8d4ae5f4125e6 Mon Sep 17 00:00:00 2001 From: David Chanin Date: Thu, 25 Dec 2025 23:44:43 -0500 Subject: [PATCH 4/4] fixing linting --- tests/_comparison/sae_lens_old/config.py | 2 +- tests/_comparison/sae_lens_old/sae_training_runner.py | 2 +- tests/_comparison/sae_lens_old/training/sae_trainer.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/_comparison/sae_lens_old/config.py b/tests/_comparison/sae_lens_old/config.py index b57c4ee72..75a963bcf 100644 --- a/tests/_comparison/sae_lens_old/config.py +++ b/tests/_comparison/sae_lens_old/config.py @@ -6,6 +6,7 @@ import simple_parsing import torch +import wandb from datasets import ( Dataset, DatasetDict, @@ -14,7 +15,6 @@ load_dataset, ) -import wandb from tests._comparison.sae_lens_old import __version__, logger DTYPE_MAP = { diff --git a/tests/_comparison/sae_lens_old/sae_training_runner.py b/tests/_comparison/sae_lens_old/sae_training_runner.py index 5fb1d7c3e..18bf1dc43 100644 --- a/tests/_comparison/sae_lens_old/sae_training_runner.py +++ b/tests/_comparison/sae_lens_old/sae_training_runner.py @@ -6,10 +6,10 @@ from typing import Any, cast import torch +import wandb from simple_parsing import ArgumentParser from transformer_lens.hook_points import HookedRootModule -import wandb from tests._comparison.sae_lens_old import logger from tests._comparison.sae_lens_old.config import ( HfDataset, diff --git a/tests/_comparison/sae_lens_old/training/sae_trainer.py b/tests/_comparison/sae_lens_old/training/sae_trainer.py index 098cb516b..69e00b138 100644 --- a/tests/_comparison/sae_lens_old/training/sae_trainer.py +++ b/tests/_comparison/sae_lens_old/training/sae_trainer.py @@ -2,11 +2,11 @@ from typing import Any, Protocol, cast import torch +import wandb from torch.optim import Adam from tqdm import tqdm from transformer_lens.hook_points import HookedRootModule -import wandb from tests._comparison.sae_lens_old import __version__ from tests._comparison.sae_lens_old.config import LanguageModelSAERunnerConfig from tests._comparison.sae_lens_old.evals import EvalConfig, run_evals