-
Notifications
You must be signed in to change notification settings - Fork 238
feat: Add Swanlab logging backend #605
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,4 @@ | ||
| swanlab_minimal_output | ||
| # Byte-compiled / optimized / DLL files | ||
| __pycache__/ | ||
| *.py[cod] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import os | ||
| import uuid | ||
| from typing import Any | ||
|
|
||
|
|
||
| def _load_backend() -> tuple[Any, str]: | ||
| backend = os.getenv("SAE_LENS_LOGGING_BACKEND", "auto").lower() | ||
| if backend not in {"auto", "wandb", "swanlab"}: | ||
| backend = "auto" | ||
|
|
||
| if backend in {"auto", "swanlab"}: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd rather the default be wandb, not swanlab. If a user just happens to have SAELens and swanlab deps in the same project their logging will seem to break and this will be very hard to debug. IMO if users want to use swanlab they should need to opt-in specifically with a ENV var. |
||
| try: | ||
| import swanlab as backend_module # type: ignore | ||
|
|
||
| return backend_module, "swanlab" | ||
| except Exception: | ||
| if backend == "swanlab": | ||
| raise | ||
|
|
||
| import wandb as backend_module # type: ignore | ||
|
|
||
| return backend_module, "wandb" | ||
|
|
||
|
|
||
| wandb, BACKEND = _load_backend() | ||
|
|
||
|
|
||
def generate_id() -> str:
    """Return a new run identifier.

    Uses the active backend module's ``util.generate_id`` when that attribute
    exists and is callable. If the helper is missing, is not callable, or
    raises, a random UUID4 hex string is returned instead.
    """
    backend_util = getattr(wandb, "util", None)
    backend_generate = getattr(backend_util, "generate_id", None)
    if callable(backend_generate):
        try:
            new_id = backend_generate()
        except Exception:
            # Best effort only: any backend failure falls through to the
            # UUID fallback below.
            pass
        else:
            return new_id
    return uuid.uuid4().hex
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,65 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import os | ||
| from pathlib import Path | ||
|
|
||
| import pytest | ||
|
|
||
| from sae_lens.config import LoggingConfig | ||
| from sae_lens.llm_sae_training_runner import LanguageModelSAETrainingRunner | ||
| from sae_lens.wandb_compat import BACKEND | ||
| from tests.helpers import ( | ||
| NEEL_NANDA_C4_10K_DATASET, | ||
| TINYSTORIES_MODEL, | ||
| build_runner_cfg_for_arch, | ||
| ) | ||
|
|
||
|
|
||
def run_minimal_standard_sae_training(output_dir: Path) -> None:
    """Run a very small "standard" architecture SAE training job with logging on.

    The logging project is read from ``WANDB_PROJECT`` (default
    ``"swanlab-minimal"``) and the entity from ``WANDB_ENTITY`` (no default).
    Logging and evaluation are both configured to happen on every step.

    Args:
        output_dir: Directory passed to the runner config as ``output_path``.
    """
    logging_cfg = LoggingConfig(
        log_to_wandb=True,
        wandb_project=os.getenv("WANDB_PROJECT", "swanlab-minimal"),
        wandb_entity=os.getenv("WANDB_ENTITY"),
        wandb_log_frequency=1,
        eval_every_n_wandb_logs=1,
    )

    cfg = build_runner_cfg_for_arch(
        architecture="standard",
        # Model and data source.
        model_name=TINYSTORIES_MODEL,
        hook_name="blocks.0.hook_resid_post",
        dataset_path=NEEL_NANDA_C4_10K_DATASET,
        # Tiny dimensions and token budget.
        d_in=64,
        d_sae=128,
        training_tokens=64,
        context_size=10,
        store_batch_size_prompts=2,
        train_batch_size_tokens=4,
        n_batches_in_buffer=2,
        n_eval_batches=1,
        # No checkpoints are requested.
        n_checkpoints=0,
        save_final_checkpoint=False,
        output_path=str(output_dir),
        logger=logging_cfg,
    )

    runner = LanguageModelSAETrainingRunner(cfg)
    runner.run()
|
|
||
|
|
||
def test_swanlab_logging_runs(tmp_path: Path) -> None:
    """Smoke-test a minimal training run when the swanlab backend is active.

    The test is skipped unless ``BACKEND`` is ``"swanlab"``. Offline-mode
    defaults are applied only for environment variables that are not already
    set, and exactly those variables are removed again afterwards so the
    settings do not leak into other tests running in the same process.
    """
    if BACKEND != "swanlab":
        pytest.skip("swanlab backend not active; set SAE_LENS_LOGGING_BACKEND=swanlab")

    env_defaults = {
        "SAE_LENS_LOGGING_BACKEND": "swanlab",
        "SWANLAB_MODE": "offline",
        "WANDB_MODE": "offline",
    }
    # Only touch variables the caller has not set (setdefault semantics), and
    # remember them so they can be removed again.
    injected = [name for name in env_defaults if name not in os.environ]
    for name in injected:
        os.environ[name] = env_defaults[name]
    try:
        run_minimal_standard_sae_training(tmp_path / "swanlab_output")
    finally:
        for name in injected:
            os.environ.pop(name, None)
|
|
||
|
|
||
def main() -> None:
    """Run the minimal training job as a script.

    Applies swanlab/offline defaults for any of the three environment
    variables that are not already set, then trains into the directory named
    by ``SAE_OUTPUT_DIR`` (default ``"swanlab_minimal_output"``).

    NOTE(review): ``BACKEND`` is imported from ``sae_lens.wandb_compat`` at the
    top of this file, so the backend has presumably already been resolved by
    the time ``SAE_LENS_LOGGING_BACKEND`` is defaulted here -- confirm whether
    this default can still influence backend selection.
    """
    for env_name, fallback in (
        ("SAE_LENS_LOGGING_BACKEND", "swanlab"),
        ("SWANLAB_MODE", "offline"),
        ("WANDB_MODE", "offline"),
    ):
        os.environ.setdefault(env_name, fallback)

    target = os.getenv("SAE_OUTPUT_DIR", "swanlab_minimal_output")
    run_minimal_standard_sae_training(Path(target))
|
|
||
|
|
||
| if __name__ == "__main__": | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nit: don't use main functions in tests. Just run the test using pytest. |
||
| main() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd rather not need to have if-statements like this in code. Can you provide an interface so swanlab works identically to wandb, and we don't need any of these checks? Or maybe make a generic `Logger` interface and have both Swanlab and WandB versions of that same interface? It shouldn't be the responsibility of the SAETrainer to keep track of what logger is being used.