Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions sae_lens/training/activations_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ def get_batch_tokens(
@torch.no_grad()
def get_activations(self, batch_tokens: torch.Tensor):
"""
Returns activations of shape (batches, context, num_layers, d_in)
Returns activations of shape (batch_size, context, num_layers, d_in)

d_in may result from a concatenated head dimension.
"""
Expand All @@ -664,9 +664,9 @@ def get_activations(self, batch_tokens: torch.Tensor):
:, slice(*self.seqpos_slice)
]

n_batches, n_context = layerwise_activations.shape[:2]
batch_size, n_context = layerwise_activations.shape[:2]

stacked_activations = torch.zeros((n_batches, n_context, self.d_in))
stacked_activations = torch.zeros((batch_size, n_context, self.d_in))
Comment on lines +667 to +669

if self.hook_head_index is not None:
stacked_activations[:, :] = layerwise_activations[
Expand All @@ -675,13 +675,13 @@ def get_activations(self, batch_tokens: torch.Tensor):
elif layerwise_activations.ndim > 3: # if we have a head dimension
try:
stacked_activations[:, :] = layerwise_activations.view(
n_batches, n_context, -1
batch_size, n_context, -1
)
except RuntimeError as e:
logger.error(f"Error during view operation: {e}")
logger.info("Attempting to use reshape instead...")
stacked_activations[:, :] = layerwise_activations.reshape(
n_batches, n_context, -1
batch_size, n_context, -1
)
else:
stacked_activations[:, :] = layerwise_activations
Expand Down Expand Up @@ -821,7 +821,9 @@ def get_data_loader(
Return an auto-refilling stream of filtered and mixed activations.
"""
return mixing_buffer(
buffer_size=self.n_batches_in_buffer * self.training_context_size,
buffer_size=self.n_batches_in_buffer
* self.store_batch_size_prompts
* self.training_context_size,
batch_size=self.train_batch_size_tokens,
activations_loader=self._iterate_filtered_activations(),
mix_fraction=self.activations_mixing_fraction,
Expand Down Expand Up @@ -943,7 +945,9 @@ def get_multi_hook_data_loader(
"via from_config_multi_hook"
)
return multi_hook_concat_split_iter(
buffer_size=self.n_batches_in_buffer * self.training_context_size,
buffer_size=self.n_batches_in_buffer
* self.store_batch_size_prompts
* self.training_context_size,
batch_size=self.train_batch_size_tokens,
activations_loader=self._iterate_filtered_multi_hook_activations(),
hook_names=list(self._hook_names),
Expand Down
5 changes: 4 additions & 1 deletion tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class LanguageModelSAERunnerConfigDict(TypedDict, total=False):
model_name: str
model_class_name: str
hook_name: str
hook_eval: str
hook_head_index: int | None
dataset_path: str
dataset_trust_remote_code: bool
Expand All @@ -76,10 +77,10 @@ class LanguageModelSAERunnerConfigDict(TypedDict, total=False):
n_batches_in_buffer: int
training_tokens: int
store_batch_size_prompts: int
normalize_activations: str
seqpos_slice: tuple[int | None, ...] | Sequence[int | None]
disable_concat_sequences: bool
sequence_separator_token: int | Literal["bos", "eos", "sep"] | None
activations_mixing_fraction: float
device: str
llm_device: str | None
act_store_device: str | None
Expand Down Expand Up @@ -112,12 +113,14 @@ class LanguageModelSAERunnerConfigDict(TypedDict, total=False):
checkpoint_path: str | None
save_final_checkpoint: bool
output_path: str | None
resume_from_checkpoint: str | None
verbose: bool
model_kwargs: dict[str, Any]
model_from_pretrained_kwargs: dict[str, Any] | None
sae_lens_version: str
sae_lens_training_version: str
exclude_special_tokens: bool | list[int]
n_batches_for_norm_estimate: int


# Base TrainingSAEConfig fields + all architecture specific fields
Expand Down
10 changes: 9 additions & 1 deletion tests/test_util.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from dataclasses import dataclass
from dataclasses import dataclass, fields
from pathlib import Path

import pytest
import torch
from transformer_lens import HookedTransformer

from sae_lens.config import LanguageModelSAERunnerConfig
from sae_lens.util import (
cosine_similarities,
dtype_to_str,
Expand All @@ -15,6 +16,7 @@
str_to_dtype,
temporary_seed,
)
from tests.helpers import LanguageModelSAERunnerConfigDict


@pytest.mark.parametrize(
Expand Down Expand Up @@ -362,3 +364,9 @@ def test_temporary_seed_none_is_noop():
assert not torch.equal(before, after)
# And we should still get a valid tensor
assert sample.shape == (1,)


def test_language_model_sae_runner_config_dict_matches_config_fields():
config_fields = {f.name for f in fields(LanguageModelSAERunnerConfig)}
dict_fields = set(LanguageModelSAERunnerConfigDict.__annotations__.keys())
assert config_fields == dict_fields
22 changes: 22 additions & 0 deletions tests/training/test_activations_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,26 @@ def test_activations_store__shapes_look_correct_with_real_models_and_datasets(
assert tok_batch.device == store.device


def test_activations_store__can_train_on_entire_activations_buffer(
ts_model: HookedTransformer,
):
n_batches_in_buffer = 4
store_batch_size_prompts = 8
context_size = 5
train_batch_size_tokens = (
n_batches_in_buffer * store_batch_size_prompts * context_size
)
cfg = build_runner_cfg(
n_batches_in_buffer=n_batches_in_buffer,
store_batch_size_prompts=store_batch_size_prompts,
context_size=context_size,
train_batch_size_tokens=train_batch_size_tokens,
)
activation_store = ActivationsStore.from_config(ts_model, cfg)
batch = activation_store.next_batch()
assert batch.shape[0] == train_batch_size_tokens


def test_activations_store__get_activations_head_hook(ts_model: HookedTransformer):
cfg = build_runner_cfg(
hook_name="blocks.0.attn.hook_q",
Expand Down Expand Up @@ -718,13 +738,15 @@ def test_activations_next_batch_excludes_special_tokens(
store_batch_size_prompts=2,
hook_name=hook_name,
train_batch_size_tokens=5,
activations_mixing_fraction=0.0,
)
cfg = build_runner_cfg(
exclude_special_tokens=True,
context_size=5,
store_batch_size_prompts=2,
hook_name=hook_name,
train_batch_size_tokens=5,
activations_mixing_fraction=0.0,
)
dataset = Dataset.from_list([{"text": "hello world"}] * 100)
_, cache = ts_model.run_with_cache(dataset[0]["text"])
Expand Down
Loading