
Fix incorrect buffer size calculations in ActivationsStore #692

Open
danra wants to merge 6 commits into decoderesearch:main from danra:fix_buffer_size2

Conversation

@danra danra commented Jun 6, 2026

Contributor

Description

ActivationsStore.get_data_loader and .get_multi_hook_data_loader both calculated the buffer size incorrectly. I stumbled upon this via an assertion error when running test_language_model_sae_runner_othellogpt in benchmark/test_language_model_sae_runner.py. That benchmark doesn't run in CI, which is probably why the bug went undetected.

A couple of tangential issues also came up on the way to fixing the main problem, and they are fixed here as well.

Review Guidance

Best reviewed commit-by-commit.

Type of change

Please delete options that are not relevant.

  • Bug fix (non-breaking change which fixes an issue)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected before) - buffer size would now be larger in existing code.

Checklist:

  • I have not made corresponding changes to the documentation - is this significant enough to be documented as a breaking change? Performance might differ (see below) as well as mixing behavior (which would now work as it was supposed to before)
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

You have tested formatting, typing and tests

  • I have run make check-ci to check format and linting. (you can run make format to format code if needed.)

Performance Check.

Tested benchmark/test_language_model_sae_runner.py:test_language_model_sae_runner_top_k. Losses seem unchanged, but on my machine the test now takes longer (~20 secs vs. ~15 secs previously), possibly due to the larger memory footprint.

danra and others added 3 commits June 6, 2026 15:35
… config fields

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…erConfig, test passes

Add four missing and drop normalize_activations, which lives on the SAE config
rather than the runner's.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Copilot AI review requested due to automatic review settings June 6, 2026 22:58

Copilot AI left a comment

Contributor

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

This PR adjusts activations buffering to reflect the full token capacity of the activations buffer and adds tests/typing checks to keep configuration helpers aligned with config fields.

Changes:

  • Fix buffer_size passed into activation mixing/concat iterators to account for prompts × context (token count).
  • Add a regression test exercising the “train on entire activations buffer” case.
  • Add a test to ensure LanguageModelSAERunnerConfigDict stays in sync with LanguageModelSAERunnerConfig dataclass fields and update the TypedDict accordingly.
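The sync test mentioned in the last bullet could look roughly like the sketch below. It is a minimal illustration, not the PR's actual test: `RunnerConfig`, `RunnerConfigDict`, and `key_mismatch` are hypothetical stand-ins for `LanguageModelSAERunnerConfig`, `LanguageModelSAERunnerConfigDict`, and whatever helper the real test uses, and the default values are made up.

```python
from dataclasses import dataclass, fields
from typing import TypedDict


@dataclass
class RunnerConfig:  # hypothetical stand-in for LanguageModelSAERunnerConfig
    context_size: int = 128
    n_batches_in_buffer: int = 20
    store_batch_size_prompts: int = 32


class RunnerConfigDict(TypedDict, total=False):  # hypothetical stand-in
    context_size: int
    n_batches_in_buffer: int
    store_batch_size_prompts: int


def key_mismatch(config_cls: type, dict_cls: type) -> tuple[set[str], set[str]]:
    """Return (dataclass fields missing from the TypedDict, TypedDict keys absent from the dataclass)."""
    dataclass_names = {f.name for f in fields(config_cls)}
    typed_dict_keys = set(dict_cls.__annotations__)
    return dataclass_names - typed_dict_keys, typed_dict_keys - dataclass_names


missing, extra = key_mismatch(RunnerConfig, RunnerConfigDict)
assert not missing and not extra
```

Reporting both directions of the difference makes a failure message say directly which keys to add and which to drop, which matches the "add four missing and drop normalize_activations" kind of fix described in the commits.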

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

File: Description
tests/training/test_activations_store.py: Adds a new regression test and pins activations_mixing_fraction in existing tests.
tests/test_util.py: Adds a sync test to ensure the TypedDict matches dataclass config fields.
tests/helpers.py: Updates LanguageModelSAERunnerConfigDict keys to match runner config fields.
sae_lens/training/activations_store.py: Fixes buffer sizing math and clarifies naming in activation shaping.

Comment on lines +189 to +205
def test_activations_store__can_train_on_entire_activations_buffer(
    ts_model: HookedTransformer,
):
    n_batches_in_buffer = 4
    store_batch_size_prompts = 8
    context_size = 5
    train_batch_size_tokens = (
        n_batches_in_buffer * store_batch_size_prompts * context_size
    )
    cfg = build_runner_cfg(
        n_batches_in_buffer=n_batches_in_buffer,
        store_batch_size_prompts=store_batch_size_prompts,
        context_size=context_size,
        train_batch_size_tokens=train_batch_size_tokens,
    )
    activation_store = ActivationsStore.from_config(ts_model, cfg)
    activation_store.next_batch()
Comment on lines +667 to +669
  batch_size, n_context = layerwise_activations.shape[:2]

- stacked_activations = torch.zeros((n_batches, n_context, self.d_in))
+ stacked_activations = torch.zeros((batch_size, n_context, self.d_in))
danra added 3 commits June 6, 2026 17:29
Trying to consume the expected number of tokens from the activations buffer reveals
that it is smaller than expected (mixing_buffer raises ValueError on `buffer_size < batch_size`)

test_language_model_sae_runner_othellogpt in benchmark/test_language_model_sae_runner.py
also fails with the same error. It probably went unnoticed because it's not run in CI.
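The arithmetic behind this failure can be sketched as follows. The numbers come from the regression test in this PR; `buffer_capacity_tokens` and `check_buffer` are hypothetical helpers, not sae_lens functions, and the undersized value is only an assumed example of a size that omits one factor (the exact pre-fix formula is not shown in this thread).

```python
def buffer_capacity_tokens(
    n_batches_in_buffer: int,
    store_batch_size_prompts: int,
    context_size: int,
) -> int:
    """Token capacity when every factor is counted (prompts x context per batch)."""
    return n_batches_in_buffer * store_batch_size_prompts * context_size


def check_buffer(buffer_size: int, batch_size: int) -> None:
    """Hypothetical stand-in for the buffer_size < batch_size check in the commit message."""
    if buffer_size < batch_size:
        raise ValueError(
            f"buffer_size ({buffer_size}) is smaller than batch_size ({batch_size})"
        )


# Values from the regression test: 4 batches x 8 prompts x 5 tokens of context.
full = buffer_capacity_tokens(4, 8, 5)
train_batch_size_tokens = full
check_buffer(full, train_batch_size_tokens)  # does not raise

# Assumed undersized value that drops the prompts factor (illustrative only).
undersized = 4 * 5
try:
    check_buffer(undersized, train_batch_size_tokens)
except ValueError as err:
    print(err)
```

Setting the training batch equal to the full capacity is what makes the regression test sensitive: any buffer size that leaves out a factor falls below the batch size and trips the check.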
… test unexpectedly fails!)

Fixes failing test in last commit
Also fixes test_language_model_sae_runner_othellogpt

However, test_activations_next_batch_excludes_special_tokens unexpectedly starts failing!
…xes buffer, test passes again

Mixing caused the test to incorrectly conclude bos is missing when it wasn't
@danra danra force-pushed the fix_buffer_size2 branch from e79b9bb to 145cd7b Compare June 7, 2026 00:30

@chanind chanind left a comment

Collaborator

I'm hesitant to merge this PR, since it is a breaking change that would cause existing configs to OOM. I agree the config parameter name is confusing; it's really n_sequences_in_buffer rather than n_batches_in_buffer. Could we instead rename the config param to n_sequences_in_buffer and deprecate the old n_batches_in_buffer (while still allowing it to work backwards compatibly; we can use an @deprecated flag)? That's probably more clear anyway since batches is already a confusing term given there's the LLM batches vs SAE batches, and it's confusing which is which.
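A backwards-compatible rename along these lines might be sketched as below. This is only an illustration under stated assumptions: `BufferConfig` is a hypothetical stand-in for the real runner config, the default of 20 is made up, and the sketch uses `warnings.warn` in `__post_init__` rather than an `@deprecated` decorator.

```python
import warnings
from dataclasses import dataclass
from typing import Optional


@dataclass
class BufferConfig:  # hypothetical stand-in for the runner config
    n_sequences_in_buffer: int = 20
    # Deprecated alias, kept so that existing configs keep working.
    n_batches_in_buffer: Optional[int] = None

    def __post_init__(self) -> None:
        if self.n_batches_in_buffer is not None:
            warnings.warn(
                "n_batches_in_buffer is deprecated and counts sequences, "
                "not batches; use n_sequences_in_buffer instead.",
                DeprecationWarning,
                # Skip __post_init__ and the generated __init__ so the
                # warning points at the caller's line.
                stacklevel=3,
            )
            # Preserve the old behavior: the value is used unchanged.
            self.n_sequences_in_buffer = self.n_batches_in_buffer


old_style = BufferConfig(n_batches_in_buffer=4)  # warns, still works
new_style = BufferConfig(n_sequences_in_buffer=4)  # no warning
```

Because the old value is copied across unchanged, memory use for existing configs stays the same, which addresses the OOM concern. Python's default filters hide `DeprecationWarning` raised outside `__main__`, so a user-facing variant may prefer `FutureWarning`.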

danra commented Jun 9, 2026

Contributor Author

Makes sense to me.

One argument against: perhaps some research was done given an incorrect assumption on how many batches were in the buffer and mixed together. But I suppose making the deprecation and documentation clear on this is good enough; better than unexpectedly OOM-ing, and probably also better than just making a completely breaking change (raising an exception on using old parameter name, not just a deprecation warning).

> That's probably more clear anyway since batches is already a confusing term given there's the LLM batches vs SAE batches, and it's confusing which is which.

Agreed, it confused me for sure.

chanind commented Jun 9, 2026

Collaborator

> One argument against: perhaps some research was done given an incorrect assumption on how many batches were in the buffer and mixed together. But I suppose making the deprecation and documentation clear on this is good enough; better than unexpectedly OOM-ing, and probably also better than just making a completely breaking change (raising an exception on using old parameter name, not just a deprecation warning).

If there was research done with an incorrect assumption about the batches in the mixing buffer, there's not much we can do about that at this point IMO. I think with a deprecation warning we can output a helpful message explaining everything to the user: what the issue is, as well as the new parameter name. That's probably the gentlest way we can fix things. If you feel strongly that we should error and crash the run for the user rather than just outputting a deprecation warning with a helpful message, I'm open to that too, as long as it's extremely clear to the user what to do to fix things. But I would argue for the warning rather than the error, since I imagine a lot of users just want to train an SAE, don't care about parameter names, and likely don't know what the mixing buffer is to begin with.

danra commented Jun 9, 2026

Contributor Author

The soft deprecation warning seems like the best solution to me, too.

3 participants