From 0a53fbf6b7bb9ff5d672eacced36fd25b5f5c2be Mon Sep 17 00:00:00 2001 From: callummcdougall Date: Wed, 18 Sep 2024 09:13:51 +0100 Subject: [PATCH 01/29] support seqpos slicing --- sae_lens/config.py | 3 +++ sae_lens/training/activations_store.py | 27 +++++++++++++------------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/sae_lens/config.py b/sae_lens/config.py index 3fc4ac75e..7673c9201 100644 --- a/sae_lens/config.py +++ b/sae_lens/config.py @@ -60,6 +60,7 @@ class LanguageModelSAERunnerConfig: store_batch_size_prompts (int): The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating actiations. train_batch_size_tokens (int): The batch size for training. This controls the batch size of the SAE Training loop. normalize_activations (str): Activation Normalization Strategy. Either none, expected_average_only_in (estimate the average activation norm and divide activations by it -> this can be folded post training and set to None), or constant_norm_rescale (at runtime set activation norm to sqrt(d_in) and then scale up the SAE output). + seqpos_slice (tuple): Determines slicing of (batch, seq, d_in) activations when constructing batches, during training. Example: for Othello we sometimes use (5, -5). device (str): The device to use. Usually cuda. act_store_device (str): The device to use for the activation store. CPU is advised in order to save vram. seed (int): The seed to use. @@ -153,6 +154,7 @@ class LanguageModelSAERunnerConfig: normalize_activations: str = ( "none" # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update) ) + seqpos_slice: tuple[int | None, ...] 
= (None,) # Misc device: str = "cpu" @@ -461,6 +463,7 @@ class CacheActivationsRunnerConfig: store_batch_size_prompts: int = 32 train_batch_size_tokens: int = 4096 normalize_activations: str = "none" # should always be none for activation caching + seqpos_slice: tuple[int | None, ...] = (None,) # Misc device: str = "cpu" diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index 9f53c0eb5..ab94e74fd 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -88,6 +88,7 @@ def from_config( model_kwargs=cfg.model_kwargs, autocast_lm=cfg.autocast_lm, dataset_trust_remote_code=cfg.dataset_trust_remote_code, + seqpos_slice=cfg.seqpos_slice, ) @classmethod @@ -147,6 +148,7 @@ def __init__( model_kwargs: dict[str, Any] | None = None, autocast_lm: bool = False, dataset_trust_remote_code: bool | None = None, + seqpos_slice: tuple[int | None, ...] = (None,) ): self.model = model if model_kwargs is None: @@ -188,6 +190,7 @@ def __init__( self.dtype = DTYPE_MAP[dtype] self.cached_activations_path = cached_activations_path self.autocast_lm = autocast_lm + self.seqpos_slice = seqpos_slice self.n_dataset_processed = 0 @@ -441,7 +444,7 @@ def get_activations(self, batch_tokens: torch.Tensor): autocast_if_enabled = contextlib.nullcontext() with autocast_if_enabled: - layerwise_activations = self.model.run_with_cache( + layerwise_activations_cache = self.model.run_with_cache( batch_tokens, names_filter=[self.hook_name], stop_at_layer=self.hook_layer + 1, @@ -449,29 +452,26 @@ def get_activations(self, batch_tokens: torch.Tensor): **self.model_kwargs, )[1] - n_batches, n_context = batch_tokens.shape + layerwise_activations = layerwise_activations_cache[self.hook_name][:, slice(*self.seqpos_slice)] + n_batches, n_context = layerwise_activations.shape[:2] stacked_activations = torch.zeros((n_batches, n_context, 1, self.d_in)) if self.hook_head_index is not None: - stacked_activations[:, :, 0] = 
layerwise_activations[self.hook_name][ + stacked_activations[:, :, 0] = layerwise_activations[ :, :, self.hook_head_index ] elif ( - layerwise_activations[self.hook_name].ndim > 3 + layerwise_activations.ndim > 3 ): # if we have a head dimension try: - stacked_activations[:, :, 0] = layerwise_activations[ - self.hook_name - ].view(n_batches, n_context, -1) + stacked_activations[:, :, 0] = layerwise_activations.view(n_batches, n_context, -1) except RuntimeError as e: print(f"Error during view operation: {e}") print("Attempting to use reshape instead...") - stacked_activations[:, :, 0] = layerwise_activations[ - self.hook_name - ].reshape(n_batches, n_context, -1) + stacked_activations[:, :, 0] = layerwise_activations.reshape(n_batches, n_context, -1) else: - stacked_activations[:, :, 0] = layerwise_activations[self.hook_name] + stacked_activations[:, :, 0] = layerwise_activations return stacked_activations @@ -487,6 +487,7 @@ def get_buffer( If raise_on_epoch_end is True, when the dataset it exhausted it will automatically refill the dataset and then raise a StopIteration so that the caller has a chance to react. 
""" context_size = self.context_size + training_context_size = len(range(context_size)[slice(*self.seqpos_slice)]) batch_size = self.store_batch_size_prompts d_in = self.d_in total_size = batch_size * n_batches_in_buffer @@ -494,7 +495,7 @@ def get_buffer( if self.cached_activations_path is not None: # Load the activations from disk - buffer_size = total_size * context_size + buffer_size = total_size * training_context_size # Initialize an empty tensor with an additional dimension for layers new_buffer = torch.zeros( (buffer_size, num_layers, d_in), @@ -548,7 +549,7 @@ def get_buffer( refill_iterator = range(0, batch_size * n_batches_in_buffer, batch_size) # Initialize empty tensor buffer of the maximum required size with an additional dimension for layers new_buffer = torch.zeros( - (total_size, context_size, num_layers, d_in), + (total_size, training_context_size, num_layers, d_in), dtype=self.dtype, # type: ignore device=self.device, ) From 3ba222b5a48994290e7ecab5f3bc7809248c3777 Mon Sep 17 00:00:00 2001 From: jbloomAus Date: Fri, 20 Sep 2024 10:47:03 +0100 Subject: [PATCH 02/29] add basic tests, ensure it's in the SAE config --- sae_lens/config.py | 1 + sae_lens/sae.py | 6 +++++ sae_lens/training/activations_store.py | 3 ++- tests/unit/training/test_activations_store.py | 23 +++++++++++++++++++ tests/unit/training/test_sae_basic.py | 17 ++++++++++++++ 5 files changed, 49 insertions(+), 1 deletion(-) diff --git a/sae_lens/config.py b/sae_lens/config.py index 7673c9201..4e1e14d08 100644 --- a/sae_lens/config.py +++ b/sae_lens/config.py @@ -388,6 +388,7 @@ def get_base_sae_cfg_dict(self) -> dict[str, Any]: "normalize_activations": self.normalize_activations, "activation_fn_kwargs": self.activation_fn_kwargs, "model_from_pretrained_kwargs": self.model_from_pretrained_kwargs, + "seqpos_slice": self.seqpos_slice, } def get_training_sae_cfg_dict(self) -> dict[str, Any]: diff --git a/sae_lens/sae.py b/sae_lens/sae.py index f26b89118..83d503eb8 100644 --- 
a/sae_lens/sae.py +++ b/sae_lens/sae.py @@ -62,6 +62,7 @@ class SAEConfig: activation_fn_kwargs: dict[str, Any] = field(default_factory=dict) neuronpedia_id: Optional[str] = None model_from_pretrained_kwargs: dict[str, Any] = field(default_factory=dict) + seqpos_slice: tuple[int | None, ...] = (None,) @classmethod def from_dict(cls, config_dict: dict[str, Any]) -> "SAEConfig": @@ -81,6 +82,10 @@ def from_dict(cls, config_dict: dict[str, Any]) -> "SAEConfig": for k, v in config_dict.items() if k in cls.__dataclass_fields__ # pylint: disable=no-member } + + if "seqpos_slice" in config_dict: + config_dict["seqpos_slice"] = tuple(config_dict["seqpos_slice"]) + return cls(**config_dict) # def __post_init__(self): @@ -108,6 +113,7 @@ def to_dict(self) -> dict[str, Any]: "normalize_activations": self.normalize_activations, "neuronpedia_id": self.neuronpedia_id, "model_from_pretrained_kwargs": self.model_from_pretrained_kwargs, + "seqpos_slice": self.seqpos_slice, } diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index ab94e74fd..da5b14f5a 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -124,6 +124,7 @@ def from_sae( dataset_trust_remote_code=sae.cfg.dataset_trust_remote_code, dtype=sae.cfg.dtype, device=torch.device(device), + seqpos_slice=sae.cfg.seqpos_slice, ) def __init__( @@ -148,7 +149,7 @@ def __init__( model_kwargs: dict[str, Any] | None = None, autocast_lm: bool = False, dataset_trust_remote_code: bool | None = None, - seqpos_slice: tuple[int | None, ...] = (None,) + seqpos_slice: tuple[int | None, ...] 
= (None,), ): self.model = model if model_kwargs is None: diff --git a/tests/unit/training/test_activations_store.py b/tests/unit/training/test_activations_store.py index 7f14974dd..c0a76b198 100644 --- a/tests/unit/training/test_activations_store.py +++ b/tests/unit/training/test_activations_store.py @@ -478,3 +478,26 @@ def test_validate_pretokenized_dataset_tokenizer_does_nothing_if_the_dataset_pat model_tokenizer = ts_model.tokenizer assert model_tokenizer is not None validate_pretokenized_dataset_tokenizer(ds_path, model_tokenizer) + + +def test_activations_store_respects_seqpos_slice(ts_model: HookedTransformer): + cfg = build_sae_cfg( + context_size=10, + seqpos_slice=(2, 8), # Only consider positions 2 to 7 (inclusive) + ) + dataset = Dataset.from_list( + [ + {"text": "This is a test sentence for slicing."}, + ] + * 100 + ) + + activation_store = ActivationsStore.from_config( + ts_model, cfg, override_dataset=dataset + ) + + batch = activation_store.get_batch_tokens(1) + activations = activation_store.get_activations(batch) + + assert batch.shape == (1, 10) # Full context size + assert activations.shape == (1, 6, 1, cfg.d_in) # Only 6 positions (2 to 7) diff --git a/tests/unit/training/test_sae_basic.py b/tests/unit/training/test_sae_basic.py index 55428bb94..530fca2cd 100644 --- a/tests/unit/training/test_sae_basic.py +++ b/tests/unit/training/test_sae_basic.py @@ -225,6 +225,23 @@ def test_sae_save_and_load_from_pretrained_topk(tmp_path: Path) -> None: assert torch.allclose(sae_out_1, sae_out_2) +def test_sae_seqpos(tmp_path: Path) -> None: + cfg = build_sae_cfg( + seqpos_slice=(1, 3), + device="cpu", + ) + model_path = str(tmp_path) + sae = SAE.from_dict(cfg.get_base_sae_cfg_dict()) + + assert sae.cfg.seqpos_slice == (1, 3) + + sae.save_model(model_path) + + sae_loaded = SAE.load_from_pretrained(model_path, device="cpu") + + assert sae_loaded.cfg.seqpos_slice == (1, 3) + + # TODO: Handle scaling factor in saeBase # def 
test_sae_save_and_load_from_pretrained_lacks_scaling_factor( # tmp_path: Path, From b54d188c3245482ecffa5b600760006c6c2f8abb Mon Sep 17 00:00:00 2001 From: jbloomAus Date: Fri, 20 Sep 2024 10:47:15 +0100 Subject: [PATCH 03/29] format --- sae_lens/training/activations_store.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index da5b14f5a..e4a7618e0 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -453,7 +453,9 @@ def get_activations(self, batch_tokens: torch.Tensor): **self.model_kwargs, )[1] - layerwise_activations = layerwise_activations_cache[self.hook_name][:, slice(*self.seqpos_slice)] + layerwise_activations = layerwise_activations_cache[self.hook_name][ + :, slice(*self.seqpos_slice) + ] n_batches, n_context = layerwise_activations.shape[:2] stacked_activations = torch.zeros((n_batches, n_context, 1, self.d_in)) @@ -462,15 +464,17 @@ def get_activations(self, batch_tokens: torch.Tensor): stacked_activations[:, :, 0] = layerwise_activations[ :, :, self.hook_head_index ] - elif ( - layerwise_activations.ndim > 3 - ): # if we have a head dimension + elif layerwise_activations.ndim > 3: # if we have a head dimension try: - stacked_activations[:, :, 0] = layerwise_activations.view(n_batches, n_context, -1) + stacked_activations[:, :, 0] = layerwise_activations.view( + n_batches, n_context, -1 + ) except RuntimeError as e: print(f"Error during view operation: {e}") print("Attempting to use reshape instead...") - stacked_activations[:, :, 0] = layerwise_activations.reshape(n_batches, n_context, -1) + stacked_activations[:, :, 0] = layerwise_activations.reshape( + n_batches, n_context, -1 + ) else: stacked_activations[:, :, 0] = layerwise_activations From 264a570b852fdeb478184094333c3b574267b422 Mon Sep 17 00:00:00 2001 From: jbloomAus Date: Fri, 20 Sep 2024 10:57:46 +0100 Subject: [PATCH 04/29] fix 
tests --- sae_lens/training/training_sae.py | 1 + tests/unit/training/test_config.py | 1 + 2 files changed, 2 insertions(+) diff --git a/sae_lens/training/training_sae.py b/sae_lens/training/training_sae.py index 217e7252a..66637716e 100644 --- a/sae_lens/training/training_sae.py +++ b/sae_lens/training/training_sae.py @@ -75,6 +75,7 @@ def from_sae_runner_config( context_size=cfg.context_size, dataset_path=cfg.dataset_path, prepend_bos=cfg.prepend_bos, + seqpos_slice=cfg.seqpos_slice, # Training cfg l1_coefficient=cfg.l1_coefficient, lp_norm=cfg.lp_norm, diff --git a/tests/unit/training/test_config.py b/tests/unit/training/test_config.py index ca0f154a2..6643d8182 100644 --- a/tests/unit/training/test_config.py +++ b/tests/unit/training/test_config.py @@ -67,6 +67,7 @@ def test_sae_training_runner_config_get_sae_base_parameters(): "model_from_pretrained_kwargs": { "center_writing_weights": False, }, + "seqpos_slice": (None,), } assert expected_config == cfg.get_base_sae_cfg_dict() From 48b92c535988cb61cbe00e62e054f4bca426f0a1 Mon Sep 17 00:00:00 2001 From: jbloomAus Date: Fri, 20 Sep 2024 11:19:14 +0100 Subject: [PATCH 05/29] fix tests 2 --- sae_lens/config.py | 9 +++++++++ sae_lens/training/training_sae.py | 12 ++++++++++++ 2 files changed, 21 insertions(+) diff --git a/sae_lens/config.py b/sae_lens/config.py index 4e1e14d08..45aaead6a 100644 --- a/sae_lens/config.py +++ b/sae_lens/config.py @@ -430,6 +430,15 @@ def to_json(self, path: str) -> None: def from_json(cls, path: str) -> "LanguageModelSAERunnerConfig": with open(path + "cfg.json", "r") as f: cfg = json.load(f) + + # ensure that seqpos slices is a tuple + # Ensure seqpos_slice is a tuple + if "seqpos_slice" in cfg: + if isinstance(cfg["seqpos_slice"], list): + cfg["seqpos_slice"] = tuple(cfg["seqpos_slice"]) + elif not isinstance(cfg["seqpos_slice"], tuple): + cfg["seqpos_slice"] = (cfg["seqpos_slice"],) + return cls(**cfg) diff --git a/sae_lens/training/training_sae.py 
b/sae_lens/training/training_sae.py index 66637716e..b7925d4ef 100644 --- a/sae_lens/training/training_sae.py +++ b/sae_lens/training/training_sae.py @@ -100,6 +100,18 @@ def from_dict(cls, config_dict: dict[str, Any]) -> "TrainingSAEConfig": valid_config_dict = { key: val for key, val in config_dict.items() if key in valid_field_names } + + # ensure seqpos slice is tuple + # ensure that seqpos slices is a tuple + # Ensure seqpos_slice is a tuple + if "seqpos_slice" in valid_config_dict: + if isinstance(valid_config_dict["seqpos_slice"], list): + valid_config_dict["seqpos_slice"] = tuple( + valid_config_dict["seqpos_slice"] + ) + elif not isinstance(valid_config_dict["seqpos_slice"], tuple): + valid_config_dict["seqpos_slice"] = (valid_config_dict["seqpos_slice"],) + return TrainingSAEConfig(**valid_config_dict) def to_dict(self) -> dict[str, Any]: From 54d1105158529a461fbfefc527560f2bab2027da Mon Sep 17 00:00:00 2001 From: liuman Date: Mon, 30 Sep 2024 17:24:47 +0100 Subject: [PATCH 06/29] fix: Changing the activations store to handle context sizes smaller than dataset lengths for tokenized datasets. --- sae_lens/training/activations_store.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index e4a7618e0..dc1d2a50a 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -221,17 +221,8 @@ def __init__( ds_context_size = len(dataset_sample[self.tokens_column]) if ds_context_size < self.context_size: raise ValueError( - f"""pretokenized dataset has context_size {ds_context_size}, but the provided context_size is {self.context_size}. - The context_size {ds_context_size} is expected to be larger than or equal to the provided context size {self.context_size}.""" - ) - if self.context_size < 0: - raise ValueError( - f"The provided context_size is {self.context_size} is negative. 
Expecting positive context_size" - ) - if self.context_size != ds_context_size: - warnings.warn( - f"""pretokenized dataset has context_size {ds_context_size}, but the provided context_size is {self.context_size}. Some data will be discarded in this case.""", - RuntimeWarning, + f"""pretokenized dataset has context_size {ds_context_size}, but the provided context_size is {self.context_size}. + The context_size {ds_context_size} is expected to be larger than or equal to the provided context size {self.context_size}.""" ) # TODO: investigate if this can work for iterable datasets, or if this is even worthwhile as a perf improvement if hasattr(self.dataset, "set_format"): From eb04a0190049ed6993da65f8c68d22a8fbf87bff Mon Sep 17 00:00:00 2001 From: liuman Date: Mon, 30 Sep 2024 17:56:37 +0100 Subject: [PATCH 07/29] fix: Found bug which allowed for negative context lengths. Removed the bug --- sae_lens/training/activations_store.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index dc1d2a50a..4f126e57d 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -224,6 +224,10 @@ def __init__( f"""pretokenized dataset has context_size {ds_context_size}, but the provided context_size is {self.context_size}. The context_size {ds_context_size} is expected to be larger than or equal to the provided context size {self.context_size}.""" ) + if self.context_size<0: + raise ValueError( + f"The provided context_size is {self.context_size} is negative. 
Expecting positive context_size" + ) # TODO: investigate if this can work for iterable datasets, or if this is even worthwhile as a perf improvement if hasattr(self.dataset, "set_format"): self.dataset.set_format(type="torch", columns=[self.tokens_column]) # type: ignore From cc43814e4512f98c497d237bc9b885d2335ef0e3 Mon Sep 17 00:00:00 2001 From: Oliver De Candido Date: Mon, 30 Sep 2024 19:35:51 +0100 Subject: [PATCH 08/29] Update pytest to test new logic for context size of tokenized dataset --- tests/unit/training/test_activations_store.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/unit/training/test_activations_store.py b/tests/unit/training/test_activations_store.py index c0a76b198..a6d3ee10f 100644 --- a/tests/unit/training/test_activations_store.py +++ b/tests/unit/training/test_activations_store.py @@ -1,6 +1,7 @@ from collections.abc import Iterable from math import ceil from typing import Optional +from typing import Optional import numpy as np import pytest @@ -349,14 +350,13 @@ def test_activations_store___iterate_tokenized_sequences__yields_sequences_of_co assert toks.shape == (5,) -# We expect the code to work for context_size being less than or equal to the +# We expect the code to work for context_size being less than or equal to the # length of the dataset -@pytest.mark.parametrize( - "context_size, expected_error", - [(-1, ValueError), (5, RuntimeWarning), (10, None), (15, ValueError)], -) +@pytest.mark.parametrize("context_size, expected_error", [(-1, ValueError), (5, None), (10, None), (15, ValueError)]) def test_activations_store__errors_on_context_size_mismatch( - ts_model: HookedTransformer, context_size: int, expected_error: Optional[ValueError] + ts_model: HookedTransformer, + context_size: int, + expected_error: Optional[ValueError] ): tokenizer = ts_model.tokenizer assert tokenizer is not None From 0284000efc5c894141a3a99fec0fd7a53b2c117a Mon Sep 17 00:00:00 2001 From: Oliver De Candido Date: Mon, 
30 Sep 2024 19:45:22 +0100 Subject: [PATCH 09/29] Reformat code to pass CI tests --- sae_lens/training/activations_store.py | 6 +++--- tests/unit/training/test_activations_store.py | 11 ++++++----- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index 4f126e57d..8e913b670 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -221,10 +221,10 @@ def __init__( ds_context_size = len(dataset_sample[self.tokens_column]) if ds_context_size < self.context_size: raise ValueError( - f"""pretokenized dataset has context_size {ds_context_size}, but the provided context_size is {self.context_size}. - The context_size {ds_context_size} is expected to be larger than or equal to the provided context size {self.context_size}.""" + f"""pretokenized dataset has context_size {ds_context_size}, but the provided context_size is {self.context_size}. + The context_size {ds_context_size} is expected to be larger than or equal to the provided context size {self.context_size}.""" ) - if self.context_size<0: + if self.context_size < 0: raise ValueError( f"The provided context_size is {self.context_size} is negative. 
Expecting positive context_size" ) diff --git a/tests/unit/training/test_activations_store.py b/tests/unit/training/test_activations_store.py index a6d3ee10f..0b4d139c4 100644 --- a/tests/unit/training/test_activations_store.py +++ b/tests/unit/training/test_activations_store.py @@ -350,13 +350,14 @@ def test_activations_store___iterate_tokenized_sequences__yields_sequences_of_co assert toks.shape == (5,) -# We expect the code to work for context_size being less than or equal to the +# We expect the code to work for context_size being less than or equal to the # length of the dataset -@pytest.mark.parametrize("context_size, expected_error", [(-1, ValueError), (5, None), (10, None), (15, ValueError)]) +@pytest.mark.parametrize( + "context_size, expected_error", + [(-1, ValueError), (5, None), (10, None), (15, ValueError)], +) def test_activations_store__errors_on_context_size_mismatch( - ts_model: HookedTransformer, - context_size: int, - expected_error: Optional[ValueError] + ts_model: HookedTransformer, context_size: int, expected_error: Optional[ValueError] ): tokenizer = ts_model.tokenizer assert tokenizer is not None From c12550f20c936618c0c326fc58296e5e3a7c7f2d Mon Sep 17 00:00:00 2001 From: Oliver De Candido Date: Tue, 1 Oct 2024 16:47:14 +0100 Subject: [PATCH 10/29] Add warning for when context_size is smaller than the dataset context_size --- sae_lens/training/activations_store.py | 5 +++++ tests/unit/training/test_activations_store.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index 8e913b670..e4a7618e0 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -228,6 +228,11 @@ def __init__( raise ValueError( f"The provided context_size is {self.context_size} is negative. 
Expecting positive context_size" ) + if self.context_size != ds_context_size: + warnings.warn( + f"""pretokenized dataset has context_size {ds_context_size}, but the provided context_size is {self.context_size}. Some data will be discarded in this case.""", + RuntimeWarning, + ) # TODO: investigate if this can work for iterable datasets, or if this is even worthwhile as a perf improvement if hasattr(self.dataset, "set_format"): self.dataset.set_format(type="torch", columns=[self.tokens_column]) # type: ignore diff --git a/tests/unit/training/test_activations_store.py b/tests/unit/training/test_activations_store.py index 0b4d139c4..39a447193 100644 --- a/tests/unit/training/test_activations_store.py +++ b/tests/unit/training/test_activations_store.py @@ -354,7 +354,7 @@ def test_activations_store___iterate_tokenized_sequences__yields_sequences_of_co # length of the dataset @pytest.mark.parametrize( "context_size, expected_error", - [(-1, ValueError), (5, None), (10, None), (15, ValueError)], + [(-1, ValueError), (5, RuntimeWarning), (10, None), (15, ValueError)], ) def test_activations_store__errors_on_context_size_mismatch( ts_model: HookedTransformer, context_size: int, expected_error: Optional[ValueError] From 59439bf2fb98cc911568d7c71cbc7ec7dd04c52c Mon Sep 17 00:00:00 2001 From: liuman Date: Tue, 1 Oct 2024 18:20:00 +0100 Subject: [PATCH 11/29] feat: adding support for start and end position offsets for token sequences --- sae_lens/config.py | 10 ++++++++++ sae_lens/training/activations_store.py | 18 ++++++++++++++---- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/sae_lens/config.py b/sae_lens/config.py index 45aaead6a..53810c2c5 100644 --- a/sae_lens/config.py +++ b/sae_lens/config.py @@ -122,6 +122,8 @@ class LanguageModelSAERunnerConfig: streaming: bool = True is_dataset_tokenized: bool = True context_size: int = 128 + start_pos_offset: int = 0 # set to n if you want to exclude first n seq positions from sae training + end_pos_offset: int = 
0 # set to n if you want to exclude last n seq positions from sae training use_cached_activations: bool = False cached_activations_path: Optional[str] = ( None # Defaults to "activations/{dataset}/{model}/{full_hook_name}_{hook_head_index}" @@ -356,6 +358,14 @@ def __post_init__(self): if self.use_ghost_grads: print("Using Ghost Grads.") + + if (self.start_pos_offset < 0) or (self.start_pos_offset > self.context_size): + raise ValueError(f"Start position offset {self.start_pos_offset} should be in range [0,{self.context_size}]") + if (self.end_pos_offset < 0) or (self.end_pos_offset >= self.context_size): + raise ValueError(f"End position offset {self.end_pos_offset} should be in range [0,{self.context_size-1}]") + if self.start_pos_offset + self.end_pos_offset > self.context_size: + raise ValueError(f"""Choice of start and end position overlap. Obtained + {self.start_pos_offset, self.end_pos_offset} with context size {self.context_size}""") @property def total_training_tokens(self) -> int: diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index e4a7618e0..f6d7a62c3 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -75,6 +75,8 @@ def from_config( hook_layer=cfg.hook_layer, hook_head_index=cfg.hook_head_index, context_size=cfg.context_size, + start_pos_offset=cfg.start_pos_offset, + end_pos_offset=cfg.end_pos_offset, d_in=cfg.d_in, n_batches_in_buffer=cfg.n_batches_in_buffer, total_training_tokens=cfg.training_tokens, @@ -97,6 +99,8 @@ def from_sae( model: HookedRootModule, sae: SAE, context_size: int | None = None, + start_pos_offset: int = 0, + end_pos_offset: int = 0, dataset: HfDataset | str | None = None, streaming: bool = True, store_batch_size_prompts: int = 8, @@ -114,6 +118,8 @@ def from_sae( hook_layer=sae.cfg.hook_layer, hook_head_index=sae.cfg.hook_head_index, context_size=sae.cfg.context_size if context_size is None else context_size, + 
start_pos_offset=start_pos_offset, + end_pos_offset=end_pos_offset, prepend_bos=sae.cfg.prepend_bos, streaming=streaming, store_batch_size_prompts=store_batch_size_prompts, @@ -136,6 +142,8 @@ def __init__( hook_layer: int, hook_head_index: int | None, context_size: int, + start_pos_offset: int, + end_pos_offset: int, d_in: int, n_batches_in_buffer: int, total_training_tokens: int, @@ -179,6 +187,8 @@ def __init__( self.hook_layer = hook_layer self.hook_head_index = hook_head_index self.context_size = context_size + self.start_pos_offset = start_pos_offset + self.end_pos_offset = end_pos_offset self.d_in = d_in self.n_batches_in_buffer = n_batches_in_buffer self.half_buffer_size = n_batches_in_buffer // 2 @@ -453,10 +463,7 @@ def get_activations(self, batch_tokens: torch.Tensor): **self.model_kwargs, )[1] - layerwise_activations = layerwise_activations_cache[self.hook_name][ - :, slice(*self.seqpos_slice) - ] - n_batches, n_context = layerwise_activations.shape[:2] + n_batches, n_context = batch_tokens.shape stacked_activations = torch.zeros((n_batches, n_context, 1, self.d_in)) @@ -497,6 +504,9 @@ def get_buffer( d_in = self.d_in total_size = batch_size * n_batches_in_buffer num_layers = 1 + # Calculate the effective context size + context_window = list(range(self.start_pos_offset, context_size-self.end_pos_offset)) + effective_context_size = len(context_window) if self.cached_activations_path is not None: # Load the activations from disk From ac7ed3b6b91d6e2f95480feefd3a50b88df5b888 Mon Sep 17 00:00:00 2001 From: Oliver De Candido Date: Wed, 2 Oct 2024 12:00:26 +0100 Subject: [PATCH 12/29] Add start_pos_offset and end_pos_offset to the SAERunnerConfig --- sae_lens/config.py | 24 +++++++++++++++++------- sae_lens/training/activations_store.py | 4 +++- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/sae_lens/config.py b/sae_lens/config.py index 53810c2c5..db434a7f6 100644 --- a/sae_lens/config.py +++ b/sae_lens/config.py @@ -122,8 +122,12 @@ class 
LanguageModelSAERunnerConfig: streaming: bool = True is_dataset_tokenized: bool = True context_size: int = 128 - start_pos_offset: int = 0 # set to n if you want to exclude first n seq positions from sae training - end_pos_offset: int = 0 # set to n if you want to exclude last n seq positions from sae training + start_pos_offset: int = ( + 0 # set to n if you want to exclude first n seq positions from sae training + ) + end_pos_offset: int = ( + 0 # set to n if you want to exclude last n seq positions from sae training + ) use_cached_activations: bool = False cached_activations_path: Optional[str] = ( None # Defaults to "activations/{dataset}/{model}/{full_hook_name}_{hook_head_index}" @@ -358,14 +362,20 @@ def __post_init__(self): if self.use_ghost_grads: print("Using Ghost Grads.") - + if (self.start_pos_offset < 0) or (self.start_pos_offset > self.context_size): - raise ValueError(f"Start position offset {self.start_pos_offset} should be in range [0,{self.context_size}]") + raise ValueError( + f"Start position offset {self.start_pos_offset} should be in range [0,{self.context_size}]" + ) if (self.end_pos_offset < 0) or (self.end_pos_offset >= self.context_size): - raise ValueError(f"End position offset {self.end_pos_offset} should be in range [0,{self.context_size-1}]") + raise ValueError( + f"End position offset {self.end_pos_offset} should be in range [0,{self.context_size-1}]" + ) if self.start_pos_offset + self.end_pos_offset > self.context_size: - raise ValueError(f"""Choice of start and end position overlap. Obtained - {self.start_pos_offset, self.end_pos_offset} with context size {self.context_size}""") + raise ValueError( + f"""Choice of start and end position overlap. 
Obtained + {self.start_pos_offset, self.end_pos_offset} with context size {self.context_size}""" + ) @property def total_training_tokens(self) -> int: diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index f6d7a62c3..ec0d99497 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -505,7 +505,9 @@ def get_buffer( total_size = batch_size * n_batches_in_buffer num_layers = 1 # Calculate the effective context size - context_window = list(range(self.start_pos_offset, context_size-self.end_pos_offset)) + context_window = list( + range(self.start_pos_offset, context_size - self.end_pos_offset) + ) effective_context_size = len(context_window) if self.cached_activations_path is not None: From 560ae8a7d6700e0588cbf6bd15d461f9f0ba86e8 Mon Sep 17 00:00:00 2001 From: Oliver De Candido Date: Wed, 2 Oct 2024 12:01:14 +0100 Subject: [PATCH 13/29] Add tests for start_pos_offset and end_pos_offset in the LanguageModelSAERunnerConfig --- tests/unit/training/test_config.py | 34 ++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/unit/training/test_config.py b/tests/unit/training/test_config.py index 6643d8182..1446807b3 100644 --- a/tests/unit/training/test_config.py +++ b/tests/unit/training/test_config.py @@ -92,3 +92,37 @@ def test_sae_training_runner_config_expansion_factor(): cfg = LanguageModelSAERunnerConfig() assert cfg.expansion_factor == 4 + + +@pytest.mark.parametrize( + "start_pos_offset, end_pos_offset, expected_error", + [ + (-1, 0, ValueError), + (0, 0, None), + (10, 0, None), + (11, 0, ValueError), + (0, -1, ValueError), + (0, 10, ValueError), + (0, 11, ValueError), + (5, 5, None), + (6, 5, ValueError), + (3, 4, None), + ], +) +def test_sae_training_runner_config_start_end_pos_offset( + start_pos_offset: int, end_pos_offset: int, expected_error: Optional[ValueError] +): + context_size = 10 + if expected_error is ValueError: + with 
pytest.raises(expected_error): + LanguageModelSAERunnerConfig( + start_pos_offset=start_pos_offset, + end_pos_offset=end_pos_offset, + context_size=context_size, + ) + else: + LanguageModelSAERunnerConfig( + start_pos_offset=start_pos_offset, + end_pos_offset=end_pos_offset, + context_size=context_size, + ) From 93ebea68a8c064274075d512316a7073465b0c5e Mon Sep 17 00:00:00 2001 From: liuman Date: Wed, 2 Oct 2024 16:46:37 +0100 Subject: [PATCH 14/29] feat: start and end position offset support for SAELens. --- sae_lens/config.py | 37 +++++++++++++++++-- sae_lens/training/activations_store.py | 4 -- tests/unit/training/test_activations_store.py | 8 +++- tests/unit/training/test_config.py | 2 + 4 files changed, 42 insertions(+), 9 deletions(-) diff --git a/sae_lens/config.py b/sae_lens/config.py index db434a7f6..1f647d954 100644 --- a/sae_lens/config.py +++ b/sae_lens/config.py @@ -40,6 +40,8 @@ class LanguageModelSAERunnerConfig: streaming (bool): Whether to stream the dataset. Streaming large datasets is usually practical. is_dataset_tokenized (bool): NOT IN USE. We used to use this but now automatically detect if the dataset is tokenized. context_size (int): The context size to use when generating activations on which to train the SAE. + start_pos_offset (int): A positive offset to cut off the start of the sequences used to train the SAE. + end_pos_offset (int): A positive offset to cut off the end of the sequences used to train the SAE. use_cached_activations (bool): Whether to use cached activations. This is useful when doing sweeps over the same activations. cached_activations_path (str, optional): The path to the cached activations. d_in (int): The input dimension of the SAE. @@ -363,18 +365,22 @@ def __post_init__(self): if self.use_ghost_grads: print("Using Ghost Grads.") + if self.context_size < 0: + raise ValueError( + f"The provided context_size is {self.context_size} is negative. Expecting positive context_size." 
+ ) + if (self.start_pos_offset < 0) or (self.start_pos_offset > self.context_size): raise ValueError( - f"Start position offset {self.start_pos_offset} should be in range [0,{self.context_size}]" + f"Start position offset {self.start_pos_offset} should be in range [0, {self.context_size}]" ) if (self.end_pos_offset < 0) or (self.end_pos_offset >= self.context_size): raise ValueError( - f"End position offset {self.end_pos_offset} should be in range [0,{self.context_size-1}]" + f"End position offset {self.end_pos_offset} should be in range [0, {self.context_size-1}]" ) if self.start_pos_offset + self.end_pos_offset > self.context_size: raise ValueError( - f"""Choice of start and end position overlap. Obtained - {self.start_pos_offset, self.end_pos_offset} with context size {self.context_size}""" + f"Choice of {self.start_pos_offset=} and {self.end_pos_offset=} is incompatible with {self.context_size=}. We expect start_pos_offset + end_pos_offset < context_size." ) @property @@ -479,6 +485,12 @@ class CacheActivationsRunnerConfig: streaming: bool = True is_dataset_tokenized: bool = True context_size: int = 128 + start_pos_offset: int = ( + 0 # set to n if you want to exclude first n seq positions from sae training + ) + end_pos_offset: int = ( + 0 # set to n if you want to exclude last n seq positions from sae training + ) new_cached_activations_path: Optional[str] = ( None # Defaults to "activations/{dataset}/{model}/{full_hook_name}_{hook_head_index}" ) @@ -524,6 +536,23 @@ def __post_init__(self): if self.act_store_device == "with_model": self.act_store_device = self.device + if self.context_size < 0: + raise ValueError( + f"The provided context_size is {self.context_size} is negative. Expecting positive context_size." 
+ ) + if (self.start_pos_offset < 0) or (self.start_pos_offset > self.context_size): + raise ValueError( + f"Start position offset {self.start_pos_offset} should be in range [0, {self.context_size}]" + ) + if (self.end_pos_offset < 0) or (self.end_pos_offset >= self.context_size): + raise ValueError( + f"End position offset {self.end_pos_offset} should be in range [0, {self.context_size-1}]" + ) + if self.start_pos_offset + self.end_pos_offset > self.context_size: + raise ValueError( + f"Choice of {self.start_pos_offset=} and {self.end_pos_offset=} is incompatible with {self.context_size=}. We expect start_pos_offset + end_pos_offset < context_size." + ) + @dataclass class ToyModelSAERunnerConfig: diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index ec0d99497..222bc7db0 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -234,10 +234,6 @@ def __init__( f"""pretokenized dataset has context_size {ds_context_size}, but the provided context_size is {self.context_size}. The context_size {ds_context_size} is expected to be larger than or equal to the provided context size {self.context_size}.""" ) - if self.context_size < 0: - raise ValueError( - f"The provided context_size is {self.context_size} is negative. Expecting positive context_size" - ) if self.context_size != ds_context_size: warnings.warn( f"""pretokenized dataset has context_size {ds_context_size}, but the provided context_size is {self.context_size}. 
Some data will be discarded in this case.""", diff --git a/tests/unit/training/test_activations_store.py b/tests/unit/training/test_activations_store.py index 39a447193..eef03a5d6 100644 --- a/tests/unit/training/test_activations_store.py +++ b/tests/unit/training/test_activations_store.py @@ -354,7 +354,7 @@ def test_activations_store___iterate_tokenized_sequences__yields_sequences_of_co # length of the dataset @pytest.mark.parametrize( "context_size, expected_error", - [(-1, ValueError), (5, RuntimeWarning), (10, None), (15, ValueError)], + [(5, RuntimeWarning), (10, None), (15, ValueError)], ) def test_activations_store__errors_on_context_size_mismatch( ts_model: HookedTransformer, context_size: int, expected_error: Optional[ValueError] @@ -390,6 +390,12 @@ def test_activations_store__errors_on_context_size_mismatch( ActivationsStore.from_config(ts_model, cfg, override_dataset=tokenized_dataset) +def test_activations_store__errors_on_negative_context_size(): + with pytest.raises(ValueError): + # We should raise an error when the context_size is negative + build_sae_cfg(prepend_bos=True, context_size=-1) + + def test_activations_store___iterate_tokenized_sequences__yields_identical_results_with_and_without_pretokenizing( ts_model: HookedTransformer, ): diff --git a/tests/unit/training/test_config.py b/tests/unit/training/test_config.py index 1446807b3..68688f8e9 100644 --- a/tests/unit/training/test_config.py +++ b/tests/unit/training/test_config.py @@ -1,3 +1,5 @@ +from typing import Optional + import pytest from sae_lens import __version__ From 340500fc8c7e8e93e2c878c4f8344863b320a127 Mon Sep 17 00:00:00 2001 From: Oliver De Candido Date: Wed, 2 Oct 2024 17:02:22 +0100 Subject: [PATCH 15/29] Add test for CacheActivationsRunnerConfig with start and end pos offset --- tests/unit/training/test_config.py | 36 +++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/tests/unit/training/test_config.py 
b/tests/unit/training/test_config.py index 68688f8e9..98cad1851 100644 --- a/tests/unit/training/test_config.py +++ b/tests/unit/training/test_config.py @@ -3,7 +3,7 @@ import pytest from sae_lens import __version__ -from sae_lens.config import LanguageModelSAERunnerConfig +from sae_lens.config import CacheActivationsRunnerConfig, LanguageModelSAERunnerConfig TINYSTORIES_MODEL = "tiny-stories-1M" TINYSTORIES_DATASET = "roneneldan/TinyStories" @@ -128,3 +128,37 @@ def test_sae_training_runner_config_start_end_pos_offset( end_pos_offset=end_pos_offset, context_size=context_size, ) + + +@pytest.mark.parametrize( + "start_pos_offset, end_pos_offset, expected_error", + [ + (-1, 0, ValueError), + (0, 0, None), + (10, 0, None), + (11, 0, ValueError), + (0, -1, ValueError), + (0, 10, ValueError), + (0, 11, ValueError), + (5, 5, None), + (6, 5, ValueError), + (3, 4, None), + ], +) +def test_cache_activations_runner_config_start_end_pos_offset( + start_pos_offset: int, end_pos_offset: int, expected_error: Optional[ValueError] +): + context_size = 10 + if expected_error is ValueError: + with pytest.raises(expected_error): + CacheActivationsRunnerConfig( + start_pos_offset=start_pos_offset, + end_pos_offset=end_pos_offset, + context_size=context_size, + ) + else: + CacheActivationsRunnerConfig( + start_pos_offset=start_pos_offset, + end_pos_offset=end_pos_offset, + context_size=context_size, + ) From c436a4f6a7e01bc89aa9a55be1f20c40d51f96f8 Mon Sep 17 00:00:00 2001 From: Oliver De Candido Date: Wed, 2 Oct 2024 17:09:42 +0100 Subject: [PATCH 16/29] Test cache activation runner wtih valid start and end pos offset --- .../training/test_cache_activations_runner.py | 86 +++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/tests/unit/training/test_cache_activations_runner.py b/tests/unit/training/test_cache_activations_runner.py index 7358fa5cc..ee81525f5 100644 --- a/tests/unit/training/test_cache_activations_runner.py +++ 
b/tests/unit/training/test_cache_activations_runner.py @@ -216,3 +216,89 @@ def W_E(self) -> torch.Tensor: # no errors are ever raised if we do not ask for raise_at_epoch_end for _ in range(32): _ = activations_store.get_batch_tokens(batch_size, raise_at_epoch_end=False) + + +# The way to run this with this command: +# poetry run py.test tests/unit/test_cache_activations_runner.py --profile-svg -s +def test_cache_activations_runner_with_valid_start_end_pos_offset(tmp_path: Path): + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" + + # total_training_steps = 20_000 + context_size = 1024 + start_pos_offset = 12 + end_pos_offset = 12 + effective_context_size = context_size - start_pos_offset - end_pos_offset + print(f"n tokens per context: {context_size}") + n_batches_in_buffer = 32 + print(f"n batches in buffer: {n_batches_in_buffer}") + store_batch_size = 1 + print(f"store_batch_size: {store_batch_size}") + n_buffers = 3 + print(f"n_buffers: {n_buffers}") + + tokens_in_buffer = n_batches_in_buffer * store_batch_size * effective_context_size + total_training_tokens = n_buffers * tokens_in_buffer + print(f"Total Training Tokens: {total_training_tokens}") + + # better if we can look at the files (change tmp_path to a real path to look at the files) + # tmp_path = os.path.join(os.path.dirname(__file__), "tmp") + # tmp_path = Path("/Volumes/T7 Shield/activations/gelu_1l") + # if os.path.exists(tmp_path): + # shutil.rmtree(tmp_path) + + cfg = CacheActivationsRunnerConfig( + new_cached_activations_path=str(tmp_path), + # Pick a tiny model to make this easier. + model_name="gelu-1l", + ## MLP Layer 0 ## + hook_name="blocks.0.hook_mlp_out", + hook_layer=0, + d_in=512, + dataset_path="NeelNanda/c4-tokenized-2b", + context_size=context_size, # Speed things up. + is_dataset_tokenized=True, + prepend_bos=True, # I used to train GPT2 SAEs with a prepended-bos but no longer think we should do this. 
+ training_tokens=total_training_tokens, # For initial testing I think this is a good number. + train_batch_size_tokens=4096, + # Test the start and end pos offset + start_pos_offset=start_pos_offset, + end_pos_offset=end_pos_offset, + # Loss Function + ## Reconstruction Coefficient. + # Buffer details won't matter in we cache / shuffle our activations ahead of time. + n_batches_in_buffer=n_batches_in_buffer, + store_batch_size_prompts=store_batch_size, + normalize_activations="none", + # + shuffle_every_n_buffers=2, + n_shuffles_with_last_section=1, + n_shuffles_in_entire_dir=1, + n_shuffles_final=1, + # Misc + device=device, + seed=42, + dtype="float16", + ) + + # look at the next cell to see some instruction for what to do while this is running. + CacheActivationsRunner(cfg).run() + + assert os.path.exists(tmp_path) + + # assert that there are n_buffer files in the directory. + assert len(os.listdir(tmp_path)) == n_buffers + + for _, buffer_file in enumerate(os.listdir(tmp_path)): + path_to_file = Path(tmp_path) / buffer_file + with safe_open(path_to_file, framework="pt", device=str(device)) as f: # type: ignore + buffer = f.get_tensor("activations") + assert buffer.shape == ( + tokens_in_buffer, + 1, + cfg.d_in, + ) From bdbb585c72301fbead581207398996f938175e81 Mon Sep 17 00:00:00 2001 From: liuman Date: Wed, 2 Oct 2024 17:59:01 +0100 Subject: [PATCH 17/29] feat: Enabling loading of start and end pos offset from saes. 
Adding tests for this --- sae_lens/training/activations_store.py | 6 ++---- tests/unit/helpers.py | 4 ++++ 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index 222bc7db0..7111f41a0 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -99,8 +99,6 @@ def from_sae( model: HookedRootModule, sae: SAE, context_size: int | None = None, - start_pos_offset: int = 0, - end_pos_offset: int = 0, dataset: HfDataset | str | None = None, streaming: bool = True, store_batch_size_prompts: int = 8, @@ -118,8 +116,8 @@ def from_sae( hook_layer=sae.cfg.hook_layer, hook_head_index=sae.cfg.hook_head_index, context_size=sae.cfg.context_size if context_size is None else context_size, - start_pos_offset=start_pos_offset, - end_pos_offset=end_pos_offset, + start_pos_offset=sae.cfg.start_pos_offset, + end_pos_offset=sae.cfg.end_pos_offset, prepend_bos=sae.cfg.prepend_bos, streaming=streaming, store_batch_size_prompts=store_batch_size_prompts, diff --git a/tests/unit/helpers.py b/tests/unit/helpers.py index 78d040548..08264f49d 100644 --- a/tests/unit/helpers.py +++ b/tests/unit/helpers.py @@ -39,6 +39,8 @@ class LanguageModelSAERunnerConfigDict(TypedDict, total=False): checkpoint_path: str dtype: str prepend_bos: bool + start_pos_offset: int + end_pos_offset: int def build_sae_cfg(**kwargs: Any) -> LanguageModelSAERunnerConfig: @@ -75,6 +77,8 @@ def build_sae_cfg(**kwargs: Any) -> LanguageModelSAERunnerConfig: "checkpoint_path": "test/checkpoints", "dtype": "float32", "prepend_bos": True, + "start_pos_offset": 0, + "end_pos_offset": 0, } for key, value in kwargs.items(): From 7f3b76ac065a375c77086f5fa4ff514038e6d2b4 Mon Sep 17 00:00:00 2001 From: liuman Date: Thu, 3 Oct 2024 11:44:57 +0100 Subject: [PATCH 18/29] fix: Renaming variables and a test --- sae_lens/training/activations_store.py | 4 ++-- tests/unit/training/test_activations_store.py | 2 +- 
tests/unit/training/test_cache_activations_runner.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index 7111f41a0..111586d52 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -499,10 +499,10 @@ def get_buffer( total_size = batch_size * n_batches_in_buffer num_layers = 1 # Calculate the effective context size - context_window = list( + training_context_slice = list( range(self.start_pos_offset, context_size - self.end_pos_offset) ) - effective_context_size = len(context_window) + training_context_size = len(training_context_slice) if self.cached_activations_path is not None: # Load the activations from disk diff --git a/tests/unit/training/test_activations_store.py b/tests/unit/training/test_activations_store.py index eef03a5d6..207e31c37 100644 --- a/tests/unit/training/test_activations_store.py +++ b/tests/unit/training/test_activations_store.py @@ -487,7 +487,7 @@ def test_validate_pretokenized_dataset_tokenizer_does_nothing_if_the_dataset_pat validate_pretokenized_dataset_tokenizer(ds_path, model_tokenizer) -def test_activations_store_respects_seqpos_slice(ts_model: HookedTransformer): +def test_activations_store_respects_position_offsets(ts_model: HookedTransformer): cfg = build_sae_cfg( context_size=10, seqpos_slice=(2, 8), # Only consider positions 2 to 7 (inclusive) diff --git a/tests/unit/training/test_cache_activations_runner.py b/tests/unit/training/test_cache_activations_runner.py index ee81525f5..b9676bff6 100644 --- a/tests/unit/training/test_cache_activations_runner.py +++ b/tests/unit/training/test_cache_activations_runner.py @@ -232,7 +232,7 @@ def test_cache_activations_runner_with_valid_start_end_pos_offset(tmp_path: Path context_size = 1024 start_pos_offset = 12 end_pos_offset = 12 - effective_context_size = context_size - start_pos_offset - end_pos_offset + training_context_size = 
context_size - start_pos_offset - end_pos_offset print(f"n tokens per context: {context_size}") n_batches_in_buffer = 32 print(f"n batches in buffer: {n_batches_in_buffer}") @@ -241,7 +241,7 @@ def test_cache_activations_runner_with_valid_start_end_pos_offset(tmp_path: Path n_buffers = 3 print(f"n_buffers: {n_buffers}") - tokens_in_buffer = n_batches_in_buffer * store_batch_size * effective_context_size + tokens_in_buffer = n_batches_in_buffer * store_batch_size * training_context_size total_training_tokens = n_buffers * tokens_in_buffer print(f"Total Training Tokens: {total_training_tokens}") From 755ba75bee5160a5964b35bb5ee6ed555c29b19f Mon Sep 17 00:00:00 2001 From: liuman Date: Thu, 3 Oct 2024 11:53:43 +0100 Subject: [PATCH 19/29] adds test for position offests for saes --- tests/unit/training/test_sae_basic.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/unit/training/test_sae_basic.py b/tests/unit/training/test_sae_basic.py index 530fca2cd..f72754616 100644 --- a/tests/unit/training/test_sae_basic.py +++ b/tests/unit/training/test_sae_basic.py @@ -301,3 +301,23 @@ def test_sae_change_dtype() -> None: sae.to(dtype=torch.float16) assert sae.dtype == torch.float16 assert sae.cfg.dtype == "torch.float16" + + +def test_sae_position_offsets(tmp_path: Path) -> None: + cfg = build_sae_cfg(device="cpu", + context_size = 10, + start_pos_offset = 2, + end_pos_offset = 2, + dtype="float64") + model_path = str(tmp_path) + sae = SAE.from_dict(cfg.get_base_sae_cfg_dict()) + + assert sae.cfg.start_pos_offset == 2 + assert sae.cfg.end_pos_offset == 2 + + sae.save_model(model_path) + + sae_loaded = sae.load_from_pretrained(model_path, device="cpu") + + assert sae_loaded.cfg.start_pos_offset == 2 + assert sae_loaded.cfg.end_pos_offset == 2 \ No newline at end of file From d680041ea0e7cccf889f3b241e575e0cee4cb108 Mon Sep 17 00:00:00 2001 From: Oliver De Candido Date: Thu, 3 Oct 2024 11:55:37 +0100 Subject: [PATCH 20/29] reformats files with black 
--- tests/unit/training/test_sae_basic.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/unit/training/test_sae_basic.py b/tests/unit/training/test_sae_basic.py index f72754616..a157f8b15 100644 --- a/tests/unit/training/test_sae_basic.py +++ b/tests/unit/training/test_sae_basic.py @@ -304,11 +304,13 @@ def test_sae_change_dtype() -> None: def test_sae_position_offsets(tmp_path: Path) -> None: - cfg = build_sae_cfg(device="cpu", - context_size = 10, - start_pos_offset = 2, - end_pos_offset = 2, - dtype="float64") + cfg = build_sae_cfg( + device="cpu", + context_size=10, + start_pos_offset=2, + end_pos_offset=2, + dtype="float64", + ) model_path = str(tmp_path) sae = SAE.from_dict(cfg.get_base_sae_cfg_dict()) @@ -318,6 +320,6 @@ def test_sae_position_offsets(tmp_path: Path) -> None: sae.save_model(model_path) sae_loaded = sae.load_from_pretrained(model_path, device="cpu") - + assert sae_loaded.cfg.start_pos_offset == 2 - assert sae_loaded.cfg.end_pos_offset == 2 \ No newline at end of file + assert sae_loaded.cfg.end_pos_offset == 2 From 776fdd76fc064f70370335cb9d37a176e8280183 Mon Sep 17 00:00:00 2001 From: Oliver De Candido Date: Thu, 3 Oct 2024 12:10:31 +0100 Subject: [PATCH 21/29] Add start and end pos offset to the base sae dict --- sae_lens/config.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sae_lens/config.py b/sae_lens/config.py index 1f647d954..61fca500e 100644 --- a/sae_lens/config.py +++ b/sae_lens/config.py @@ -406,6 +406,8 @@ def get_base_sae_cfg_dict(self) -> dict[str, Any]: "activation_fn_str": self.activation_fn, "apply_b_dec_to_input": self.apply_b_dec_to_input, "context_size": self.context_size, + "start_pos_offset": self.start_pos_offset, + "end_pos_offset": self.end_pos_offset, "prepend_bos": self.prepend_bos, "dataset_path": self.dataset_path, "dataset_trust_remote_code": self.dataset_trust_remote_code, From 06254473012b71ecbeaa297cb0c7bd9c4647b899 Mon Sep 17 00:00:00 2001 From: Oliver De 
Candido Date: Thu, 3 Oct 2024 12:30:19 +0100 Subject: [PATCH 22/29] fix test for sae training runner config with position offsets --- tests/unit/training/test_config.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unit/training/test_config.py b/tests/unit/training/test_config.py index 98cad1851..f5c64b458 100644 --- a/tests/unit/training/test_config.py +++ b/tests/unit/training/test_config.py @@ -60,6 +60,8 @@ def test_sae_training_runner_config_get_sae_base_parameters(): "hook_head_index": None, "device": "cpu", "context_size": 128, + "start_pos_offset": 0, + "end_pos_offset": 0, "prepend_bos": True, "finetuning_scaling_factor": False, "dataset_path": "", From f7d6a38bddd3de062b246034d87d34faf6f42c35 Mon Sep 17 00:00:00 2001 From: Oliver De Candido Date: Thu, 3 Oct 2024 13:40:23 +0100 Subject: [PATCH 23/29] add a benchmark test to train an SAE on OthelloGPT --- .../test_language_model_sae_runner.py | 82 +++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/tests/benchmark/test_language_model_sae_runner.py b/tests/benchmark/test_language_model_sae_runner.py index 29d99c9c5..6d6f6b046 100644 --- a/tests/benchmark/test_language_model_sae_runner.py +++ b/tests/benchmark/test_language_model_sae_runner.py @@ -292,3 +292,85 @@ def test_language_model_sae_runner_top_k(): assert sae is not None # know whether or not this works by looking at the dashboard! 
+ + +def test_language_model_sae_runner_othellogpt(): + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" + + # total_training_steps = 20_000 + total_training_steps = 500 + batch_size = 4096 + total_training_tokens = total_training_steps * batch_size + print(f"Total Training Tokens: {total_training_tokens}") + + lr_warm_up_steps = 0 + lr_decay_steps = 40_000 + print(f"lr_decay_steps: {lr_decay_steps}") + l1_warmup_steps = 10_000 + print(f"l1_warmup_steps: {l1_warmup_steps}") + + cfg = LanguageModelSAERunnerConfig( + # Data Generating Function (Model + Training Distibuion) + model_name="othello-gpt", # othello-gpt model + hook_name="blocks.6.hook_resid_pre", # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points) + hook_layer=6, # Only one layer in the model. + d_in=512, # the width of the mlp output. + dataset_path="taufeeque/othellogpt", # this is a tokenized language dataset on Huggingface for OthelloGPT games. + is_dataset_tokenized=True, + streaming=True, # we could pre-download the token dataset if it was small. + # SAE Parameters + mse_loss_normalization=None, # We won't normalize the mse loss, + expansion_factor=16, # the width of the SAE. Larger will result in better stats but slower training. + b_dec_init_method="geometric_median", # The geometric median can be used to initialize the decoder weights. + apply_b_dec_to_input=False, # We won't apply the decoder weights to the input. + normalize_sae_decoder=False, + scale_sparsity_penalty_by_decoder_norm=True, + decoder_heuristic_init=True, + init_encoder_as_decoder_transpose=True, + normalize_activations="expected_average_only_in", + # Training Parameters + lr=0.00003, # lower the better, we'll go fairly high to speed up the tutorial. + adam_beta1=0.9, # adam params (default, but once upon a time we experimented with these.) 
+ adam_beta2=0.999, + lr_scheduler_name="constant", # constant learning rate with warmup. Could be better schedules out there. + lr_warm_up_steps=lr_warm_up_steps, # this can help avoid too many dead features initially. + lr_decay_steps=lr_decay_steps, # this will help us avoid overfitting. + l1_coefficient=0.001, # will control how sparse the feature activations are + l1_warm_up_steps=l1_warmup_steps, # this can help avoid too many dead features initially. + lp_norm=1.0, # the L1 penalty (and not a Lp for p < 1) + train_batch_size_tokens=batch_size, + context_size=59, # will control the length of the prompts we feed to the model. Larger is better but slower. so for the tutorial we'll use a short one. + start_pos_offset=5, + end_pos_offset=5, + # Activation Store Parameters + n_batches_in_buffer=32, # controls how many activations we store / shuffle. + training_tokens=total_training_tokens, # 100 million tokens is quite a few, but we want to see good stats. Get a coffee, come back. + store_batch_size_prompts=32, + # Resampling protocol + use_ghost_grads=False, # we don't use ghost grads anymore. + feature_sampling_window=500, # this controls our reporting of feature sparsity stats + dead_feature_window=1e6, # would effect resampling or ghost grads if we were using it. + dead_feature_threshold=1e-4, # would effect resampling or ghost grads if we were using it. + # WANDB + log_to_wandb=False, # always use wandb unless you are just testing code. + wandb_project="benchmark", + wandb_log_frequency=100, + eval_every_n_wandb_logs=20, + # Misc + device=device, + seed=42, + n_checkpoints=0, + checkpoint_path="checkpoints", + dtype="torch.float32", + ) + + # look at the next cell to see some instruction for what to do while this is running. + sae = SAETrainingRunner(cfg).run() + + assert sae is not None + # know whether or not this works by looking at the dashboard! 
From 9f16ff21177b8dcc318f45000c86acdc85abb45a Mon Sep 17 00:00:00 2001 From: Oliver De Candido Date: Thu, 3 Oct 2024 13:40:50 +0100 Subject: [PATCH 24/29] Remove double import from typing --- tests/unit/training/test_activations_store.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/training/test_activations_store.py b/tests/unit/training/test_activations_store.py index 207e31c37..6c8456426 100644 --- a/tests/unit/training/test_activations_store.py +++ b/tests/unit/training/test_activations_store.py @@ -1,7 +1,6 @@ from collections.abc import Iterable from math import ceil from typing import Optional -from typing import Optional import numpy as np import pytest From 99ace75aa09d14ac4cc1e09bc19df0dfb22f5625 Mon Sep 17 00:00:00 2001 From: Oliver De Candido Date: Thu, 3 Oct 2024 13:53:59 +0100 Subject: [PATCH 25/29] change dead_feature_window to int --- tests/benchmark/test_language_model_sae_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/benchmark/test_language_model_sae_runner.py b/tests/benchmark/test_language_model_sae_runner.py index 6d6f6b046..166d0c796 100644 --- a/tests/benchmark/test_language_model_sae_runner.py +++ b/tests/benchmark/test_language_model_sae_runner.py @@ -354,7 +354,7 @@ def test_language_model_sae_runner_othellogpt(): # Resampling protocol use_ghost_grads=False, # we don't use ghost grads anymore. feature_sampling_window=500, # this controls our reporting of feature sparsity stats - dead_feature_window=1e6, # would effect resampling or ghost grads if we were using it. + dead_feature_window=1000000, # would effect resampling or ghost grads if we were using it. dead_feature_threshold=1e-4, # would effect resampling or ghost grads if we were using it. # WANDB log_to_wandb=False, # always use wandb unless you are just testing code. 
From c0dc5bf27ba8391722b74c373191064094b08fc3 Mon Sep 17 00:00:00 2001 From: Oliver De Candido Date: Fri, 4 Oct 2024 10:37:16 +0100 Subject: [PATCH 26/29] remove print statements from test file --- tests/unit/training/test_cache_activations_runner.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/unit/training/test_cache_activations_runner.py b/tests/unit/training/test_cache_activations_runner.py index b9676bff6..806d91bea 100644 --- a/tests/unit/training/test_cache_activations_runner.py +++ b/tests/unit/training/test_cache_activations_runner.py @@ -104,7 +104,6 @@ def test_load_cached_activations(): tokens_in_buffer = n_batches_in_buffer * store_batch_size * context_size total_training_tokens = n_buffers * tokens_in_buffer - print(f"Total Training Tokens: {total_training_tokens}") # better if we can look at the files cached_activations_fixture_path = os.path.join( @@ -233,17 +232,12 @@ def test_cache_activations_runner_with_valid_start_end_pos_offset(tmp_path: Path start_pos_offset = 12 end_pos_offset = 12 training_context_size = context_size - start_pos_offset - end_pos_offset - print(f"n tokens per context: {context_size}") n_batches_in_buffer = 32 - print(f"n batches in buffer: {n_batches_in_buffer}") store_batch_size = 1 - print(f"store_batch_size: {store_batch_size}") n_buffers = 3 - print(f"n_buffers: {n_buffers}") tokens_in_buffer = n_batches_in_buffer * store_batch_size * training_context_size total_training_tokens = n_buffers * tokens_in_buffer - print(f"Total Training Tokens: {total_training_tokens}") # better if we can look at the files (change tmp_path to a real path to look at the files) # tmp_path = os.path.join(os.path.dirname(__file__), "tmp") From 9130ff97f17edee032730851d6a1656433cda35f Mon Sep 17 00:00:00 2001 From: Oliver De Candido Date: Fri, 4 Oct 2024 19:38:43 +0100 Subject: [PATCH 27/29] Rebase on seqpos tuple implementation and remove start/end pos offset --- sae_lens/config.py | 56 +++++----------- 
sae_lens/training/activations_store.py | 19 ++---- .../test_language_model_sae_runner.py | 6 +- tests/unit/helpers.py | 4 -- .../training/test_cache_activations_runner.py | 12 ++-- tests/unit/training/test_config.py | 66 ++++++------------- tests/unit/training/test_sae_basic.py | 22 ------- 7 files changed, 46 insertions(+), 139 deletions(-) diff --git a/sae_lens/config.py b/sae_lens/config.py index 61fca500e..61b777586 100644 --- a/sae_lens/config.py +++ b/sae_lens/config.py @@ -40,8 +40,6 @@ class LanguageModelSAERunnerConfig: streaming (bool): Whether to stream the dataset. Streaming large datasets is usually practical. is_dataset_tokenized (bool): NOT IN USE. We used to use this but now automatically detect if the dataset is tokenized. context_size (int): The context size to use when generating activations on which to train the SAE. - start_pos_offset (int): A positive offset to cut off the start of the sequences used to train the SAE. - end_pos_offset (int): A positive offset to cut off the end of the sequences used to train the SAE. use_cached_activations (bool): Whether to use cached activations. This is useful when doing sweeps over the same activations. cached_activations_path (str, optional): The path to the cached activations. d_in (int): The input dimension of the SAE. @@ -62,7 +60,7 @@ class LanguageModelSAERunnerConfig: store_batch_size_prompts (int): The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating actiations. train_batch_size_tokens (int): The batch size for training. This controls the batch size of the SAE Training loop. normalize_activations (str): Activation Normalization Strategy. Either none, expected_average_only_in (estimate the average activation norm and divide activations by it -> this can be folded post training and set to None), or constant_norm_rescale (at runtime set activation norm to sqrt(d_in) and then scale up the SAE output). 
- seqpos_slice (tuple): Determines slicing of (batch, seq, d_in) activations when constructing batches, during training. Example: for Othello we sometimes use (5, -5). + seqpos_slice (tuple): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, step_size), e.g. for Othello we sometimes use (5, -5). device (str): The device to use. Usually cuda. act_store_device (str): The device to use for the activation store. CPU is advised in order to save vram. seed (int): The seed to use. @@ -124,12 +122,6 @@ class LanguageModelSAERunnerConfig: streaming: bool = True is_dataset_tokenized: bool = True context_size: int = 128 - start_pos_offset: int = ( - 0 # set to n if you want to exclude first n seq positions from sae training - ) - end_pos_offset: int = ( - 0 # set to n if you want to exclude last n seq positions from sae training - ) use_cached_activations: bool = False cached_activations_path: Optional[str] = ( None # Defaults to "activations/{dataset}/{model}/{full_hook_name}_{hook_head_index}" @@ -370,18 +362,7 @@ def __post_init__(self): f"The provided context_size is {self.context_size} is negative. Expecting positive context_size." ) - if (self.start_pos_offset < 0) or (self.start_pos_offset > self.context_size): - raise ValueError( - f"Start position offset {self.start_pos_offset} should be in range [0, {self.context_size}]" - ) - if (self.end_pos_offset < 0) or (self.end_pos_offset >= self.context_size): - raise ValueError( - f"End position offset {self.end_pos_offset} should be in range [0, {self.context_size-1}]" - ) - if self.start_pos_offset + self.end_pos_offset > self.context_size: - raise ValueError( - f"Choice of {self.start_pos_offset=} and {self.end_pos_offset=} is incompatible with {self.context_size=}. We expect start_pos_offset + end_pos_offset < context_size." 
- ) + _validate_seqpos(seqpos=self.seqpos_slice, context_size=self.context_size) @property def total_training_tokens(self) -> int: @@ -406,8 +387,6 @@ def get_base_sae_cfg_dict(self) -> dict[str, Any]: "activation_fn_str": self.activation_fn, "apply_b_dec_to_input": self.apply_b_dec_to_input, "context_size": self.context_size, - "start_pos_offset": self.start_pos_offset, - "end_pos_offset": self.end_pos_offset, "prepend_bos": self.prepend_bos, "dataset_path": self.dataset_path, "dataset_trust_remote_code": self.dataset_trust_remote_code, @@ -487,12 +466,6 @@ class CacheActivationsRunnerConfig: streaming: bool = True is_dataset_tokenized: bool = True context_size: int = 128 - start_pos_offset: int = ( - 0 # set to n if you want to exclude first n seq positions from sae training - ) - end_pos_offset: int = ( - 0 # set to n if you want to exclude last n seq positions from sae training - ) new_cached_activations_path: Optional[str] = ( None # Defaults to "activations/{dataset}/{model}/{full_hook_name}_{hook_head_index}" ) @@ -542,18 +515,8 @@ def __post_init__(self): raise ValueError( f"The provided context_size is {self.context_size} is negative. Expecting positive context_size." ) - if (self.start_pos_offset < 0) or (self.start_pos_offset > self.context_size): - raise ValueError( - f"Start position offset {self.start_pos_offset} should be in range [0, {self.context_size}]" - ) - if (self.end_pos_offset < 0) or (self.end_pos_offset >= self.context_size): - raise ValueError( - f"End position offset {self.end_pos_offset} should be in range [0, {self.context_size-1}]" - ) - if self.start_pos_offset + self.end_pos_offset > self.context_size: - raise ValueError( - f"Choice of {self.start_pos_offset=} and {self.end_pos_offset=} is incompatible with {self.context_size=}. We expect start_pos_offset + end_pos_offset < context_size." 
- ) + + _validate_seqpos(seqpos=self.seqpos_slice, context_size=self.context_size) @dataclass @@ -640,6 +603,17 @@ def _default_cached_activations_path( return path +def _validate_seqpos(seqpos: tuple[int | None, ...], context_size: int) -> None: + # Ensure that the step-size is positive; an omitted step (None) means 1 + if len(seqpos) == 3: + step_size = 1 if seqpos[2] is None else seqpos[2] + assert ( + step_size > 0 + ), f"Ensure the step_size {seqpos[2]=} for sequence slicing is positive." + # Ensure that the choice of seqpos doesn't end up with an empty list + assert len(list(range(context_size))[slice(*seqpos)]) > 0 + + @dataclass class PretokenizeRunnerConfig: tokenizer_name: str = "gpt2" diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index 111586d52..d4a2320de 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -75,8 +75,6 @@ def from_config( hook_layer=cfg.hook_layer, hook_head_index=cfg.hook_head_index, context_size=cfg.context_size, - start_pos_offset=cfg.start_pos_offset, - end_pos_offset=cfg.end_pos_offset, d_in=cfg.d_in, n_batches_in_buffer=cfg.n_batches_in_buffer, total_training_tokens=cfg.training_tokens, @@ -116,8 +114,6 @@ def from_sae( hook_layer=sae.cfg.hook_layer, hook_head_index=sae.cfg.hook_head_index, context_size=sae.cfg.context_size if context_size is None else context_size, - start_pos_offset=sae.cfg.start_pos_offset, - end_pos_offset=sae.cfg.end_pos_offset, prepend_bos=sae.cfg.prepend_bos, streaming=streaming, store_batch_size_prompts=store_batch_size_prompts, @@ -140,8 +136,6 @@ def __init__( hook_layer: int, hook_head_index: int | None, context_size: int, - start_pos_offset: int, - end_pos_offset: int, d_in: int, n_batches_in_buffer: int, total_training_tokens: int, @@ -185,8 +179,6 @@ def __init__( self.hook_layer = hook_layer self.hook_head_index = hook_head_index self.context_size = context_size - self.start_pos_offset = start_pos_offset - self.end_pos_offset = end_pos_offset
self.d_in = d_in self.n_batches_in_buffer = n_batches_in_buffer self.half_buffer_size = n_batches_in_buffer // 2 @@ -457,7 +449,11 @@ def get_activations(self, batch_tokens: torch.Tensor): **self.model_kwargs, )[1] - n_batches, n_context = batch_tokens.shape + layerwise_activations = layerwise_activations_cache[self.hook_name][ + :, slice(*self.seqpos_slice) + ] + + n_batches, n_context = layerwise_activations.shape[:2] stacked_activations = torch.zeros((n_batches, n_context, 1, self.d_in)) @@ -498,11 +494,6 @@ def get_buffer( d_in = self.d_in total_size = batch_size * n_batches_in_buffer num_layers = 1 - # Calculate the effective context size - training_context_slice = list( - range(self.start_pos_offset, context_size - self.end_pos_offset) - ) - training_context_size = len(training_context_slice) if self.cached_activations_path is not None: # Load the activations from disk diff --git a/tests/benchmark/test_language_model_sae_runner.py b/tests/benchmark/test_language_model_sae_runner.py index 166d0c796..995af247c 100644 --- a/tests/benchmark/test_language_model_sae_runner.py +++ b/tests/benchmark/test_language_model_sae_runner.py @@ -306,13 +306,10 @@ def test_language_model_sae_runner_othellogpt(): total_training_steps = 500 batch_size = 4096 total_training_tokens = total_training_steps * batch_size - print(f"Total Training Tokens: {total_training_tokens}") lr_warm_up_steps = 0 lr_decay_steps = 40_000 - print(f"lr_decay_steps: {lr_decay_steps}") l1_warmup_steps = 10_000 - print(f"l1_warmup_steps: {l1_warmup_steps}") cfg = LanguageModelSAERunnerConfig( # Data Generating Function (Model + Training Distibuion) @@ -345,8 +342,7 @@ def test_language_model_sae_runner_othellogpt(): lp_norm=1.0, # the L1 penalty (and not a Lp for p < 1) train_batch_size_tokens=batch_size, context_size=59, # will control the length of the prompts we feed to the model. Larger is better but slower. so for the tutorial we'll use a short one. 
- start_pos_offset=5, - end_pos_offset=5, + seqpos_slice=(5, -5), # Activation Store Parameters n_batches_in_buffer=32, # controls how many activations we store / shuffle. training_tokens=total_training_tokens, # 100 million tokens is quite a few, but we want to see good stats. Get a coffee, come back. diff --git a/tests/unit/helpers.py b/tests/unit/helpers.py index 08264f49d..78d040548 100644 --- a/tests/unit/helpers.py +++ b/tests/unit/helpers.py @@ -39,8 +39,6 @@ class LanguageModelSAERunnerConfigDict(TypedDict, total=False): checkpoint_path: str dtype: str prepend_bos: bool - start_pos_offset: int - end_pos_offset: int def build_sae_cfg(**kwargs: Any) -> LanguageModelSAERunnerConfig: @@ -77,8 +75,6 @@ def build_sae_cfg(**kwargs: Any) -> LanguageModelSAERunnerConfig: "checkpoint_path": "test/checkpoints", "dtype": "float32", "prepend_bos": True, - "start_pos_offset": 0, - "end_pos_offset": 0, } for key, value in kwargs.items(): diff --git a/tests/unit/training/test_cache_activations_runner.py b/tests/unit/training/test_cache_activations_runner.py index 806d91bea..f10ee1c9a 100644 --- a/tests/unit/training/test_cache_activations_runner.py +++ b/tests/unit/training/test_cache_activations_runner.py @@ -219,7 +219,7 @@ def W_E(self) -> torch.Tensor: # The way to run this with this command: # poetry run py.test tests/unit/test_cache_activations_runner.py --profile-svg -s -def test_cache_activations_runner_with_valid_start_end_pos_offset(tmp_path: Path): +def test_cache_activations_runner_with_valid_seqpos(tmp_path: Path): if torch.cuda.is_available(): device = "cuda" elif torch.backends.mps.is_available(): @@ -229,9 +229,8 @@ def test_cache_activations_runner_with_valid_start_end_pos_offset(tmp_path: Path # total_training_steps = 20_000 context_size = 1024 - start_pos_offset = 12 - end_pos_offset = 12 - training_context_size = context_size - start_pos_offset - end_pos_offset + seqpos_slice = (12, -12) + training_context_size = 
len(range(context_size)[slice(*seqpos_slice)]) n_batches_in_buffer = 32 store_batch_size = 1 n_buffers = 3 @@ -259,9 +258,8 @@ def test_cache_activations_runner_with_valid_start_end_pos_offset(tmp_path: Path prepend_bos=True, # I used to train GPT2 SAEs with a prepended-bos but no longer think we should do this. training_tokens=total_training_tokens, # For initial testing I think this is a good number. train_batch_size_tokens=4096, - # Test the start and end pos offset - start_pos_offset=start_pos_offset, - end_pos_offset=end_pos_offset, + # Test the sequence slicing + seqpos_slice=seqpos_slice, # Loss Function ## Reconstruction Coefficient. # Buffer details won't matter in we cache / shuffle our activations ahead of time. diff --git a/tests/unit/training/test_config.py b/tests/unit/training/test_config.py index f5c64b458..6b1a93d95 100644 --- a/tests/unit/training/test_config.py +++ b/tests/unit/training/test_config.py @@ -60,8 +60,6 @@ def test_sae_training_runner_config_get_sae_base_parameters(): "hook_head_index": None, "device": "cpu", "context_size": 128, - "start_pos_offset": 0, - "end_pos_offset": 0, "prepend_bos": True, "finetuning_scaling_factor": False, "dataset_path": "", @@ -98,69 +96,45 @@ def test_sae_training_runner_config_expansion_factor(): assert cfg.expansion_factor == 4 -@pytest.mark.parametrize( - "start_pos_offset, end_pos_offset, expected_error", - [ - (-1, 0, ValueError), - (0, 0, None), - (10, 0, None), - (11, 0, ValueError), - (0, -1, ValueError), - (0, 10, ValueError), - (0, 11, ValueError), - (5, 5, None), - (6, 5, ValueError), - (3, 4, None), - ], -) -def test_sae_training_runner_config_start_end_pos_offset( - start_pos_offset: int, end_pos_offset: int, expected_error: Optional[ValueError] +test_cases_for_seqpos = [ + ((None, 10, -1), AssertionError), + ((None, 10, 0), AssertionError), + ((5, 5, None), AssertionError), + ((6, 3, None), AssertionError), +] + + +@pytest.mark.parametrize("seqpos_slice, expected_error", 
test_cases_for_seqpos) +def test_sae_training_runner_config_seqpos( + seqpos_slice: tuple[int, int], expected_error: Optional[AssertionError] ): context_size = 10 - if expected_error is ValueError: + if expected_error is AssertionError: with pytest.raises(expected_error): LanguageModelSAERunnerConfig( - start_pos_offset=start_pos_offset, - end_pos_offset=end_pos_offset, + seqpos_slice=seqpos_slice, context_size=context_size, ) else: LanguageModelSAERunnerConfig( - start_pos_offset=start_pos_offset, - end_pos_offset=end_pos_offset, + seqpos_slice=seqpos_slice, context_size=context_size, ) -@pytest.mark.parametrize( - "start_pos_offset, end_pos_offset, expected_error", - [ - (-1, 0, ValueError), - (0, 0, None), - (10, 0, None), - (11, 0, ValueError), - (0, -1, ValueError), - (0, 10, ValueError), - (0, 11, ValueError), - (5, 5, None), - (6, 5, ValueError), - (3, 4, None), - ], -) -def test_cache_activations_runner_config_start_end_pos_offset( - start_pos_offset: int, end_pos_offset: int, expected_error: Optional[ValueError] +@pytest.mark.parametrize("seqpos_slice, expected_error", test_cases_for_seqpos) +def test_cache_activations_runner_config_seqpos( + seqpos_slice: tuple[int, int], expected_error: Optional[AssertionError] ): context_size = 10 - if expected_error is ValueError: + if expected_error is AssertionError: with pytest.raises(expected_error): CacheActivationsRunnerConfig( - start_pos_offset=start_pos_offset, - end_pos_offset=end_pos_offset, + seqpos_slice=seqpos_slice, context_size=context_size, ) else: CacheActivationsRunnerConfig( - start_pos_offset=start_pos_offset, - end_pos_offset=end_pos_offset, + seqpos_slice=seqpos_slice, context_size=context_size, ) diff --git a/tests/unit/training/test_sae_basic.py b/tests/unit/training/test_sae_basic.py index a157f8b15..530fca2cd 100644 --- a/tests/unit/training/test_sae_basic.py +++ b/tests/unit/training/test_sae_basic.py @@ -301,25 +301,3 @@ def test_sae_change_dtype() -> None: sae.to(dtype=torch.float16) 
assert sae.dtype == torch.float16 assert sae.cfg.dtype == "torch.float16" - - -def test_sae_position_offsets(tmp_path: Path) -> None: - cfg = build_sae_cfg( - device="cpu", - context_size=10, - start_pos_offset=2, - end_pos_offset=2, - dtype="float64", - ) - model_path = str(tmp_path) - sae = SAE.from_dict(cfg.get_base_sae_cfg_dict()) - - assert sae.cfg.start_pos_offset == 2 - assert sae.cfg.end_pos_offset == 2 - - sae.save_model(model_path) - - sae_loaded = sae.load_from_pretrained(model_path, device="cpu") - - assert sae_loaded.cfg.start_pos_offset == 2 - assert sae_loaded.cfg.end_pos_offset == 2 From 125b2751ea786ff0bdd6941268de90dd4c75c33d Mon Sep 17 00:00:00 2001 From: Oliver De Candido Date: Wed, 9 Oct 2024 14:46:44 +0200 Subject: [PATCH 28/29] Reword docstring for seqpos to be clearer. --- sae_lens/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sae_lens/config.py b/sae_lens/config.py index 61b777586..5bc2a1a7d 100644 --- a/sae_lens/config.py +++ b/sae_lens/config.py @@ -60,7 +60,7 @@ class LanguageModelSAERunnerConfig: store_batch_size_prompts (int): The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating actiations. train_batch_size_tokens (int): The batch size for training. This controls the batch size of the SAE Training loop. normalize_activations (str): Activation Normalization Strategy. Either none, expected_average_only_in (estimate the average activation norm and divide activations by it -> this can be folded post training and set to None), or constant_norm_rescale (at runtime set activation norm to sqrt(d_in) and then scale up the SAE output). - seqpos_slice (tuple): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, step_size), e.g. for Othello we sometimes use (5, -5). + seqpos_slice (tuple): Determines slicing of activations when constructing batches during training. 
The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0. device (str): The device to use. Usually cuda. act_store_device (str): The device to use for the activation store. CPU is advised in order to save vram. seed (int): The seed to use. From 552eea6a3cd620a870a1628febc4b8c6a21d6152 Mon Sep 17 00:00:00 2001 From: Oliver De Candido Date: Wed, 9 Oct 2024 15:22:07 +0200 Subject: [PATCH 29/29] Added script to train an SAE on othelloGPT --- ...raining_a_sparse_autoencoder_othelloGPT.py | 107 ++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 scripts/training_a_sparse_autoencoder_othelloGPT.py diff --git a/scripts/training_a_sparse_autoencoder_othelloGPT.py b/scripts/training_a_sparse_autoencoder_othelloGPT.py new file mode 100644 index 000000000..ee48914d2 --- /dev/null +++ b/scripts/training_a_sparse_autoencoder_othelloGPT.py @@ -0,0 +1,107 @@ +import os + +import torch + +from sae_lens import ( + SAE, + HookedSAETransformer, + LanguageModelSAERunnerConfig, + SAETrainingRunner, + upload_saes_to_huggingface, +) + +if torch.cuda.is_available(): + device = "cuda" +elif torch.backends.mps.is_available(): + device = "mps" +else: + device = "cpu" + +print("Using device:", device) +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +model_name = "othello-gpt" +model = HookedSAETransformer.from_pretrained(model_name) + +dataset_path = "taufeeque/othellogpt" +context_size = 59 + +layer = 5 +training_tokens = int(1e3) +train_batch_size_tokens = 2048 +n_steps = int(training_tokens / train_batch_size_tokens) + +print(LanguageModelSAERunnerConfig()) +runner_cfg = LanguageModelSAERunnerConfig( + # + # Data generation + model_name=model_name, + hook_name=f"blocks.{layer}.mlp.hook_post", + hook_layer=layer, + d_in=model.cfg.d_mlp, + dataset_path=dataset_path, + is_dataset_tokenized=True, + prepend_bos=False, + streaming=True, + train_batch_size_tokens=train_batch_size_tokens, + 
context_size=context_size, + seqpos_slice=(5, -5), + # + # SAE achitecture + architecture="gated", + expansion_factor=8, + b_dec_init_method="zeros", + apply_b_dec_to_input=True, + normalize_sae_decoder=False, + scale_sparsity_penalty_by_decoder_norm=True, + decoder_heuristic_init=True, + init_encoder_as_decoder_transpose=True, + # + # Activations store + n_batches_in_buffer=32, + store_batch_size_prompts=16, + training_tokens=training_tokens, + # + # Training hyperparameters (standard) + lr=2e-4, + adam_beta1=0.9, + adam_beta2=0.999, + lr_scheduler_name="constant", + lr_warm_up_steps=int(0.2 * n_steps), + lr_decay_steps=int(0.2 * n_steps), + # + # Training hyperparameters (SAE-specific) + l1_coefficient=5, + l1_warm_up_steps=int(0.2 * n_steps), + use_ghost_grads=False, + feature_sampling_window=1000, + dead_feature_window=500, + dead_feature_threshold=1e-5, + # + # Logging / evals + log_to_wandb=True, + wandb_project=f"othello_gpt_sae_{layer=}", + wandb_log_frequency=30, + eval_every_n_wandb_logs=10, + checkpoint_path="checkpoints", + # + # Misc. + device=str(device), + seed=42, + n_checkpoints=5, + dtype="float32", +) + +# t.set_grad_enabled(True) +runner = SAETrainingRunner(runner_cfg) +sae = runner.run() + +hf_repo_id = "callummcdougall/arena-demos-othellogpt" +sae_id = "blocks.5.mlp.hook_post-v1" + +upload_saes_to_huggingface({sae_id: sae}, hf_repo_id=hf_repo_id) + +othellogpt_sae = SAE.from_pretrained( + release=hf_repo_id, sae_id=sae_id, device=str(device) +)[0]