Fix incorrect buffer size calculations in ActivationsStore #692
… config fields Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…erConfig, test passes. Add four missing and drop normalize_activations, which lives on the SAE config rather than the runner's. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
This PR adjusts activations buffering to reflect the full token capacity of the activations buffer and adds tests/typing checks to keep configuration helpers aligned with config fields.
Changes:
- Fix `buffer_size` passed into activation mixing/concat iterators to account for prompts × context (token count).
- Add a regression test exercising the “train on entire activations buffer” case.
- Add a test to ensure `LanguageModelSAERunnerConfigDict` stays in sync with `LanguageModelSAERunnerConfig` dataclass fields and update the TypedDict accordingly.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| tests/training/test_activations_store.py | Adds a new regression test and pins activations_mixing_fraction in existing tests. |
| tests/test_util.py | Adds a sync test to ensure the TypedDict matches dataclass config fields. |
| tests/helpers.py | Updates LanguageModelSAERunnerConfigDict keys to match runner config fields. |
| sae_lens/training/activations_store.py | Fixes buffer sizing math and clarifies naming in activation shaping. |
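
The sync check described for `tests/test_util.py` can be sketched roughly as follows. This is a minimal illustration, not the PR's actual test: `RunnerConfig`, `RunnerConfigDict`, and `config_key_mismatch` are hypothetical stand-ins for `LanguageModelSAERunnerConfig`, `LanguageModelSAERunnerConfigDict`, and whatever helper the real test uses.

```python
from dataclasses import dataclass, fields
from typing import TypedDict


@dataclass
class RunnerConfig:  # hypothetical stand-in for LanguageModelSAERunnerConfig
    n_batches_in_buffer: int = 4
    store_batch_size_prompts: int = 8
    context_size: int = 5


class RunnerConfigDict(TypedDict, total=False):  # hypothetical stand-in
    n_batches_in_buffer: int
    store_batch_size_prompts: int
    context_size: int


def config_key_mismatch(dict_type: type, dataclass_type: type) -> tuple[set[str], set[str]]:
    """Return (fields missing from the TypedDict, TypedDict keys not on the dataclass)."""
    dict_keys = set(dict_type.__annotations__)
    field_names = {f.name for f in fields(dataclass_type)}
    return field_names - dict_keys, dict_keys - field_names


missing, extra = config_key_mismatch(RunnerConfigDict, RunnerConfig)
assert not missing, f"TypedDict is missing keys: {missing}"
assert not extra, f"TypedDict has keys that are not config fields: {extra}"
```

Comparing in both directions catches a field added to the dataclass but not the TypedDict, as well as a stale key (such as one that belongs on a different config) left in the TypedDict.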

tests/training/test_activations_store.py

    def test_activations_store__can_train_on_entire_activations_buffer(
        ts_model: HookedTransformer,
    ):
        n_batches_in_buffer = 4
        store_batch_size_prompts = 8
        context_size = 5
        train_batch_size_tokens = (
            n_batches_in_buffer * store_batch_size_prompts * context_size
        )
        cfg = build_runner_cfg(
            n_batches_in_buffer=n_batches_in_buffer,
            store_batch_size_prompts=store_batch_size_prompts,
            context_size=context_size,
            train_batch_size_tokens=train_batch_size_tokens,
        )
        activation_store = ActivationsStore.from_config(ts_model, cfg)
        activation_store.next_batch()


sae_lens/training/activations_store.py

      batch_size, n_context = layerwise_activations.shape[:2]

    - stacked_activations = torch.zeros((n_batches, n_context, self.d_in))
    + stacked_activations = torch.zeros((batch_size, n_context, self.d_in))

Trying to consume the expected number of tokens from the activations buffer reveals that it is smaller than expected (mixing_buffer raises ValueError on `buffer_size < batch_size`). test_language_model_sae_runner_othellogpt in benchmark/test_language_model_sae_runner.py also fails with the same error. It probably went unnoticed because it's not run in CI.
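A rough sketch of the guard mentioned in this commit message, assuming a much-simplified buffer that works on plain Python lists. `mixing_buffer_sketch` is a hypothetical stand-in and not the `sae_lens` implementation; the buffer sizes used below are illustrative only.

```python
import random
from collections.abc import Iterable, Iterator


def mixing_buffer_sketch(
    buffer_size: int,
    batch_size: int,
    rows: Iterable[int],
    mix_fraction: float = 0.5,
    seed: int = 0,
) -> Iterator[list[int]]:
    """Simplified stand-in: shuffle a filled buffer and serve fixed-size batches."""
    # A batch can never be served from a buffer smaller than the batch itself.
    if buffer_size < batch_size:
        raise ValueError(
            f"buffer_size ({buffer_size}) must be >= batch_size ({batch_size})"
        )
    rng = random.Random(seed)
    buffer: list[int] = []
    for row in rows:
        buffer.append(row)
        if len(buffer) < buffer_size:
            continue
        rng.shuffle(buffer)
        keep = int(buffer_size * mix_fraction)  # rows held back to mix with new ones
        # Always serve at least one batch, so a batch as large as the whole
        # buffer still makes progress.
        n_serve = max(1, (len(buffer) - keep) // batch_size) * batch_size
        for start in range(0, n_serve, batch_size):
            yield buffer[start : start + batch_size]
        del buffer[:n_serve]
    # Serve any full batches left over once the input is exhausted.
    for start in range(0, len(buffer) - batch_size + 1, batch_size):
        yield buffer[start : start + batch_size]
```

Because this is a generator, the `ValueError` only surfaces when the first batch is requested, which is consistent with the regression test needing nothing more than a single `next_batch()` call to expose an undersized buffer.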
… test unexpectedly fails!) Fixes failing test in last commit. Also fixes test_language_model_sae_runner_othellogpt. However, test_activations_next_batch_excludes_special_tokens unexpectedly starts failing!
…xes buffer, test passes again. Mixing caused the test to incorrectly conclude that bos was missing when it wasn't.
chanind left a comment
I'm hesitant to merge this PR, since this is a breaking change that would cause existing configs to OOM. I agree the config parameter name is confusing: it's really n_sequences_in_buffer rather than n_batches_in_buffer. Could we instead rename the config param to n_sequences_in_buffer and deprecate the old n_batches_in_buffer (while still allowing it to work backwards-compatibly; we can use an @deprecated flag)? That's probably clearer anyway, since "batches" is already a confusing term: there are LLM batches and SAE batches, and it's unclear which is which.
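To make the OOM concern concrete, here is a back-of-the-envelope sketch using the values from the regression test. It assumes, following the comment above, that the pre-fix code effectively treated `n_batches_in_buffer` as a count of sequences; the helper names, `d_in = 512`, and 4-byte values are hypothetical choices for illustration.

```python
def tokens_if_sequences(n_in_buffer: int, context_size: int) -> int:
    """Buffer capacity in tokens if the parameter counts sequences (assumed old behaviour)."""
    return n_in_buffer * context_size


def tokens_if_batches(
    n_batches_in_buffer: int, store_batch_size_prompts: int, context_size: int
) -> int:
    """Buffer capacity in tokens if the parameter counts LLM batches of prompts."""
    return n_batches_in_buffer * store_batch_size_prompts * context_size


def buffer_bytes(n_tokens: int, d_in: int, bytes_per_value: int = 4) -> int:
    """Approximate memory for one activation vector of width d_in per token."""
    return n_tokens * d_in * bytes_per_value


old_tokens = tokens_if_sequences(4, 5)   # 20 tokens
new_tokens = tokens_if_batches(4, 8, 5)  # 160 tokens
growth = new_tokens // old_tokens        # 8, i.e. store_batch_size_prompts

print(old_tokens, new_tokens, growth)
print(buffer_bytes(old_tokens, d_in=512), buffer_bytes(new_tokens, d_in=512))
```

Under these assumptions the buffer grows by a factor equal to `store_batch_size_prompts`, which is why a config tuned against the old behaviour could run out of memory after the fix.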
Makes sense to me. One argument against: perhaps some research was done under an incorrect assumption about how many batches were in the buffer and mixed together. But I suppose making the deprecation and documentation clear on this is good enough; better than unexpectedly OOM-ing, and probably also better than making a completely breaking change (raising an exception when the old parameter name is used, not just a deprecation warning).
Agreed, it confused me for sure.
If research was done under an incorrect assumption about the batches in the mixing buffer, there's not much we can do about that at this point, IMO. With a deprecation warning we can output a helpful message explaining the issue to the user, as well as the new parameter name; that's probably the gentlest way we can fix things. If you feel strongly that we should error and crash the run rather than just output a deprecation warning with a helpful message, I'm open to that too, as long as it's extremely clear to the user what to do to fix things. But I would argue for the warning rather than the error, since I imagine a lot of users just want to train an SAE, don't care about parameter names, and likely don't know what the mixing buffer is to begin with.
The soft deprecation warning seems like the best solution to me, too.
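One way the soft deprecation agreed on above could look, as a minimal sketch. `RunnerConfigSketch`, its default value, and the warning text are hypothetical; the real config class has many more fields and may handle the alias differently.

```python
import warnings
from dataclasses import dataclass
from typing import Optional


@dataclass
class RunnerConfigSketch:  # hypothetical, not the real runner config
    n_sequences_in_buffer: int = 20  # placeholder default for illustration
    n_batches_in_buffer: Optional[int] = None  # deprecated alias, kept for old configs

    def __post_init__(self) -> None:
        if self.n_batches_in_buffer is not None:
            warnings.warn(
                "n_batches_in_buffer is deprecated: it has always counted "
                "sequences rather than batches. Use n_sequences_in_buffer "
                "instead; the value is carried over unchanged.",
                DeprecationWarning,
                # __post_init__ <- generated __init__ <- caller
                stacklevel=3,
            )
            # Keep the old behaviour: same value, clearer name.
            self.n_sequences_in_buffer = self.n_batches_in_buffer


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    cfg = RunnerConfigSketch(n_batches_in_buffer=4)

print(cfg.n_sequences_in_buffer, [str(w.message) for w in caught])
```

Note that Python hides `DeprecationWarning` by default unless it is triggered by code running in `__main__`, so if the message is meant to reach end users who only launch training runs, `FutureWarning` (shown by default) may be the more visible category.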
Description
`ActivationsStore.get_data_loader` and `.get_multi_hook_data_loader` both calculated the buffer size incorrectly, which I stumbled upon via an assertion error when running `test_language_model_sae_runner_othellogpt` under `benchmark/test_language_model_sae_runner.py`. That benchmark doesn't run in CI, which is probably why the bug went undetected. A couple of other tangential issues also came up on the way to fixing the main problem, and are fixed here as well.
Review Guidance
Best reviewed commit-by-commit.
Type of change
…expected before) - buffer size would now be larger in existing code.
Checklist:
You have tested formatting, typing and tests
`make check-ci` to check format and linting. (You can run `make format` to format code if needed.)
Performance Check.
Tested `benchmark/test_language_model_sae_runner.py:test_language_model_sae_runner_top_k`: losses seem unchanged, but on my machine the test now takes longer (~20 s vs. ~15 s previously), possibly due to the larger memory footprint.