From d8574525f0ffb81390b35d5105c6ae9f944ed05c Mon Sep 17 00:00:00 2001 From: Anthony Duong Date: Thu, 17 Oct 2024 21:27:11 -0700 Subject: [PATCH] extracts repeated code to load_pretrained_saes_yaml() --- docs/generate_sae_table.py | 7 +- sae_lens/toolkit/pretrained_saes_directory.py | 73 +++++++++---------- 2 files changed, 38 insertions(+), 42 deletions(-) diff --git a/docs/generate_sae_table.py b/docs/generate_sae_table.py index 1bb27ad8e..676690803 100644 --- a/docs/generate_sae_table.py +++ b/docs/generate_sae_table.py @@ -2,7 +2,6 @@ from pathlib import Path import pandas as pd -import yaml from tqdm import tqdm from sae_lens import SAEConfig @@ -11,6 +10,7 @@ get_sae_config, handle_config_defaulting, ) +from sae_lens.toolkit.pretrained_saes_directory import load_pretrained_saes_yaml INCLUDED_CFG = [ "id", @@ -32,10 +32,7 @@ def on_pre_build(config): def generate_sae_table(): - # Read the YAML file - yaml_path = Path("sae_lens/pretrained_saes.yaml") - with open(yaml_path, "r") as file: - data = yaml.safe_load(file) + data = load_pretrained_saes_yaml() # Start the Markdown content markdown_content = "# Pretrained SAEs\n\n" diff --git a/sae_lens/toolkit/pretrained_saes_directory.py b/sae_lens/toolkit/pretrained_saes_directory.py index 342866c31..80ccca643 100644 --- a/sae_lens/toolkit/pretrained_saes_directory.py +++ b/sae_lens/toolkit/pretrained_saes_directory.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from functools import cache from importlib import resources -from typing import Optional +from typing import Any, Optional import yaml @@ -19,39 +19,40 @@ class PretrainedSAELookup: config_overrides: dict[str, str] | dict[str, dict[str, str | bool | int]] | None +@cache +def load_pretrained_saes_yaml() -> dict[str, Any]: + with resources.open_text("sae_lens", "pretrained_saes.yaml") as file: + return yaml.safe_load(file) + + @cache def get_pretrained_saes_directory() -> dict[str, PretrainedSAELookup]: - package = "sae_lens" - # Access the file within the package using importlib.resources directory: dict[str, PretrainedSAELookup] = {} - with resources.open_text(package, "pretrained_saes.yaml") as file: - # Load the YAML file content - data = yaml.safe_load(file) - for release, value in data.items(): - saes_map: dict[str, str] = {} - var_explained_map: dict[str, float] = {} - l0_map: dict[str, float] = {} - neuronpedia_id_map: dict[str, str] = {} - - assert "saes" in value, f"Missing 'saes' key in {release}" - for hook_info in value["saes"]: - saes_map[hook_info["id"]] = hook_info["path"] - var_explained_map[hook_info["id"]] = hook_info.get( - "variance_explained", 1.00 - ) - l0_map[hook_info["id"]] = hook_info.get("l0", 0.00) - neuronpedia_id_map[hook_info["id"]] = hook_info.get("neuronpedia") - directory[release] = PretrainedSAELookup( - release=release, - repo_id=value["repo_id"], - model=value["model"], - conversion_func=value.get("conversion_func"), - saes_map=saes_map, - expected_var_explained=var_explained_map, - expected_l0=l0_map, - neuronpedia_id=neuronpedia_id_map, - config_overrides=value.get("config_overrides"), + data = load_pretrained_saes_yaml() + for release, value in data.items(): + saes_map: dict[str, str] = {} + var_explained_map: dict[str, float] = {} + l0_map: dict[str, float] = {} + neuronpedia_id_map: dict[str, str] = {} + assert "saes" in value, f"Missing 'saes' key in {release}" + for hook_info in value["saes"]: + saes_map[hook_info["id"]] = hook_info["path"] + var_explained_map[hook_info["id"]] = hook_info.get( + "variance_explained", 1.00 ) + l0_map[hook_info["id"]] = hook_info.get("l0", 0.00) + neuronpedia_id_map[hook_info["id"]] = hook_info.get("neuronpedia") + directory[release] = PretrainedSAELookup( + release=release, + repo_id=value["repo_id"], + model=value["model"], + conversion_func=value.get("conversion_func"), + saes_map=saes_map, + expected_var_explained=var_explained_map, + expected_l0=l0_map, + neuronpedia_id=neuronpedia_id_map, + config_overrides=value.get("config_overrides"), + ) return directory @@ -66,13 +67,11 @@ def get_norm_scaling_factor(release: str, sae_id: str) -> Optional[float]: Returns: Optional[float]: The norm_scaling_factor if it exists, None otherwise. """ - package = "sae_lens" - with resources.open_text(package, "pretrained_saes.yaml") as file: - data = yaml.safe_load(file) - if release in data: - for sae_info in data[release]["saes"]: - if sae_info["id"] == sae_id: - return sae_info.get("norm_scaling_factor") + data = load_pretrained_saes_yaml() + if release in data: + for sae_info in data[release]["saes"]: + if sae_info["id"] == sae_id: + return sae_info.get("norm_scaling_factor") return None