Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
swanlab_minimal_output
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
10 changes: 5 additions & 5 deletions sae_lens/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
import warnings
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar

import simple_parsing
import torch
import wandb
from datasets import (
Dataset,
DatasetDict,
Expand All @@ -25,6 +24,7 @@
from sae_lens.registry import get_sae_training_class
from sae_lens.saes.sae import TrainingSAEConfig
from sae_lens.util import str_to_dtype
from sae_lens.wandb_compat import BACKEND, generate_id, wandb

if TYPE_CHECKING:
pass
Expand Down Expand Up @@ -97,6 +97,8 @@ def log(
sparsity_path: Path | str | None,
wandb_aliases: list[str] | None = None,
) -> None:
if BACKEND == "swanlab":
return
# Avoid wandb saving errors such as:
# ValueError: Artifact name may only contain alphanumeric characters, dashes, underscores, and dots. Invalid name: sae_google/gemma-2b_etc
sae_name = trainer.sae.get_name().replace("/", "__")
Expand Down Expand Up @@ -315,9 +317,7 @@ def __post_init__(self):

unique_id = self.logger.wandb_id
if unique_id is None:
unique_id = cast(
Any, wandb
).util.generate_id() # not sure why this type is erroring
unique_id = generate_id()
self.checkpoint_path = f"{self.checkpoint_path}/{unique_id}"

if self.verbose:
Expand Down
2 changes: 1 addition & 1 deletion sae_lens/llm_sae_training_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from typing import Any, Generic

import torch
import wandb
from safetensors.torch import save_file
from simple_parsing import ArgumentParser
from transformer_lens.hook_points import HookedRootModule
Expand All @@ -32,6 +31,7 @@
from sae_lens.training.activations_store import ActivationsStore
from sae_lens.training.sae_trainer import SAETrainer
from sae_lens.training.types import DataProvider
from sae_lens.wandb_compat import wandb


class InterruptedException(Exception):
Expand Down
16 changes: 10 additions & 6 deletions sae_lens/training/sae_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Any, Callable, Generic, Protocol

import torch
import wandb
from safetensors.torch import save_file
from torch.optim import Adam
from tqdm.auto import tqdm
Expand All @@ -28,6 +27,7 @@
from sae_lens.training.optim import CoefficientScheduler, get_lr_scheduler
from sae_lens.training.types import DataProvider
from sae_lens.util import path_or_tmp_dir
from sae_lens.wandb_compat import BACKEND, wandb


def _log_feature_sparsity(
Expand Down Expand Up @@ -398,8 +398,9 @@ def _run_and_log_evals(self):
if self.evaluator is not None
else {}
)
for key, value in self.sae.log_histograms().items():
eval_metrics[key] = wandb.Histogram(value) # type: ignore
if BACKEND != "swanlab":

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather not have if-statements like this in the code. Can you provide an interface so that swanlab works identically to wandb and we don't need any of these checks? Or maybe make a generic Logger interface with both SwanLab and WandB implementations of it? It shouldn't be the responsibility of the SAETrainer to keep track of which logger is being used.

for key, value in self.sae.log_histograms().items():
eval_metrics[key] = wandb.Histogram(value) # type: ignore

wandb.log(
eval_metrics,
Expand All @@ -410,13 +411,16 @@ def _run_and_log_evals(self):
@torch.no_grad()
def _build_sparsity_log_dict(self) -> dict[str, Any]:
log_feature_sparsity = _log_feature_sparsity(self.feature_sparsity)
wandb_histogram = wandb.Histogram(log_feature_sparsity.numpy()) # type: ignore
return {
log_dict = {
"metrics/mean_log10_feature_sparsity": log_feature_sparsity.mean().item(),
"plots/feature_density_line_chart": wandb_histogram,
"sparsity/below_1e-5": (self.feature_sparsity < 1e-5).sum().item(),
"sparsity/below_1e-6": (self.feature_sparsity < 1e-6).sum().item(),
}
if BACKEND != "swanlab":
log_dict["plots/feature_density_line_chart"] = wandb.Histogram(
log_feature_sparsity.numpy()
)
return log_dict

@torch.no_grad()
def _reset_running_sparsity_stats(self) -> None:
Expand Down
39 changes: 39 additions & 0 deletions sae_lens/wandb_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from __future__ import annotations

import os
import uuid
from typing import Any


def _load_backend() -> tuple[Any, str]:
backend = os.getenv("SAE_LENS_LOGGING_BACKEND", "auto").lower()
if backend not in {"auto", "wandb", "swanlab"}:
backend = "auto"

if backend in {"auto", "swanlab"}:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather the default be wandb, not swanlab. If a user just happens to have both SAELens and swanlab installed in the same project, their logging will seem to break, and this will be very hard to debug. IMO, if users want to use swanlab, they should have to opt in explicitly with an environment variable.

try:
import swanlab as backend_module # type: ignore

return backend_module, "swanlab"
except Exception:
if backend == "swanlab":
raise

import wandb as backend_module # type: ignore

return backend_module, "wandb"


wandb, BACKEND = _load_backend()


def generate_id() -> str:
    """Return an identifier string for a run.

    The active backend module's ``util.generate_id`` is used when the module
    exposes a callable under that name. If the helper is missing, is not
    callable, or raises, a random UUID4 hex string is returned instead.
    """
    backend_generator = getattr(getattr(wandb, "util", None), "generate_id", None)
    if callable(backend_generator):
        try:
            return backend_generator()
        except Exception:
            # Best-effort: any backend failure falls through to the UUID below.
            pass
    return uuid.uuid4().hex
65 changes: 65 additions & 0 deletions tests/test_swanlab_log.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from __future__ import annotations

import os
from pathlib import Path

import pytest

from sae_lens.config import LoggingConfig
from sae_lens.llm_sae_training_runner import LanguageModelSAETrainingRunner
from sae_lens.wandb_compat import BACKEND
from tests.helpers import (
NEEL_NANDA_C4_10K_DATASET,
TINYSTORIES_MODEL,
build_runner_cfg_for_arch,
)


def run_minimal_standard_sae_training(output_dir: Path) -> None:
    """Run a very small "standard" SAE training job with logging switched on.

    The logging project and entity are read from the ``WANDB_PROJECT`` and
    ``WANDB_ENTITY`` environment variables (the project falls back to
    ``"swanlab-minimal"``). Checkpointing is disabled and metrics are logged
    and evaluated on every logging step.

    Args:
        output_dir: Directory passed to the runner config as ``output_path``.
    """
    logging_cfg = LoggingConfig(
        log_to_wandb=True,
        wandb_project=os.getenv("WANDB_PROJECT", "swanlab-minimal"),
        wandb_entity=os.getenv("WANDB_ENTITY"),
        wandb_log_frequency=1,
        eval_every_n_wandb_logs=1,
    )
    runner_cfg = build_runner_cfg_for_arch(
        architecture="standard",
        d_in=64,
        d_sae=128,
        training_tokens=64,
        store_batch_size_prompts=2,
        train_batch_size_tokens=4,
        context_size=10,
        n_batches_in_buffer=2,
        n_eval_batches=1,
        dataset_path=NEEL_NANDA_C4_10K_DATASET,
        hook_name="blocks.0.hook_resid_post",
        model_name=TINYSTORIES_MODEL,
        n_checkpoints=0,
        save_final_checkpoint=False,
        output_path=str(output_dir),
        logger=logging_cfg,
    )

    runner = LanguageModelSAETrainingRunner(runner_cfg)
    runner.run()


def test_swanlab_logging_runs(tmp_path: Path) -> None:
    """Smoke-test a minimal training run while the swanlab backend is active.

    The test is skipped unless ``BACKEND`` equals ``"swanlab"``. Environment
    defaults are applied only to variables that are not already set, and the
    variables added here are removed again afterwards so they do not leak
    into other tests running in the same process.
    """
    if BACKEND != "swanlab":
        pytest.skip("swanlab backend not active; set SAE_LENS_LOGGING_BACKEND=swanlab")

    env_defaults = {
        "SAE_LENS_LOGGING_BACKEND": "swanlab",
        "SWANLAB_MODE": "offline",
        "WANDB_MODE": "offline",
    }
    # Same semantics as os.environ.setdefault: values the user has already
    # exported are left untouched, so only the keys added here need cleanup.
    added_keys = [key for key in env_defaults if key not in os.environ]
    for key in added_keys:
        os.environ[key] = env_defaults[key]
    try:
        run_minimal_standard_sae_training(tmp_path / "swanlab_output")
    finally:
        for key in added_keys:
            os.environ.pop(key, None)


def main() -> None:
    """Run the minimal training job as a script.

    Fills in offline/swanlab defaults for any of the logging environment
    variables that are not already set, then trains into the directory named
    by ``SAE_OUTPUT_DIR`` (default ``"swanlab_minimal_output"``).
    """
    for name, fallback in (
        ("SAE_LENS_LOGGING_BACKEND", "swanlab"),
        ("SWANLAB_MODE", "offline"),
        ("WANDB_MODE", "offline"),
    ):
        os.environ.setdefault(name, fallback)
    target_dir = os.getenv("SAE_OUTPUT_DIR", "swanlab_minimal_output")
    run_minimal_standard_sae_training(Path(target_dir))


if __name__ == "__main__":

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: don't use main functions in tests. Just run the test using pytest.

main()
Loading