From 75769137cefb78616162107b1170ecf763055597 Mon Sep 17 00:00:00 2001 From: David Chanin Date: Sat, 3 Jan 2026 16:58:07 -0500 Subject: [PATCH 1/8] feat: add full tokenization options to cache activations runner --- sae_lens/cache_activations_runner.py | 17 +++++++++++++++-- sae_lens/config.py | 13 +++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/sae_lens/cache_activations_runner.py b/sae_lens/cache_activations_runner.py index bbebefc0b..ddbed06d6 100644 --- a/sae_lens/cache_activations_runner.py +++ b/sae_lens/cache_activations_runner.py @@ -16,7 +16,7 @@ from sae_lens.config import CacheActivationsRunnerConfig from sae_lens.load_model import load_model from sae_lens.training.activations_store import ActivationsStore -from sae_lens.util import str_to_dtype +from sae_lens.util import get_special_token_ids, str_to_dtype def _mk_activations_store( @@ -28,6 +28,16 @@ def _mk_activations_store( Internal method used in CacheActivationsRunner. Used to create a cached dataset from a ActivationsStore. 
""" + device = torch.device("cpu") # since we're saving to disk + exclude_special_tokens = cfg.exclude_special_tokens + if exclude_special_tokens is False: + exclude_special_tokens = None + if exclude_special_tokens is True: + exclude_special_tokens = get_special_token_ids(model.tokenizer) # type: ignore + if exclude_special_tokens is not None: + exclude_special_tokens = torch.tensor( + exclude_special_tokens, dtype=torch.long, device=device + ) return ActivationsStore( model=model, dataset=override_dataset or cfg.dataset_path, @@ -42,13 +52,16 @@ def _mk_activations_store( train_batch_size_tokens=-1, prepend_bos=cfg.prepend_bos, normalize_activations="none", - device=torch.device("cpu"), # since we're saving to disk + device=device, dtype=cfg.dtype, cached_activations_path=None, model_kwargs=cfg.model_kwargs, autocast_lm=cfg.autocast_lm, dataset_trust_remote_code=cfg.dataset_trust_remote_code, seqpos_slice=cfg.seqpos_slice, + exclude_special_tokens=exclude_special_tokens, + disable_concat_sequences=cfg.disable_concat_sequences, + sequence_separator_token=cfg.sequence_separator_token, ) diff --git a/sae_lens/config.py b/sae_lens/config.py index 8ca5d0e74..709f61bd2 100644 --- a/sae_lens/config.py +++ b/sae_lens/config.py @@ -496,6 +496,9 @@ class CacheActivationsRunnerConfig: streaming (bool): Whether to stream the dataset. Streaming large datasets is usually practical. autocast_lm (bool): Whether to use autocast during activation fetching. dataset_trust_remote_code (bool): Whether to trust remote code when loading datasets from Huggingface. + disable_concat_sequences (bool): Whether to disable concatenating sequences and ignore sequences shorter than the context size. If True, disables concatenating and ignores short sequences. + sequence_separator_token (int | Literal["bos", "eos", "sep"] | None): If not `None`, this token will be placed between sentences in a batch to act as a separator. By default, this is the `` token. 
+ exclude_special_tokens (bool | list[int]): Whether to exclude special tokens from the activations. If True, excludes all special tokens. If a list of ints, excludes those token IDs. """ dataset_path: str @@ -533,6 +536,11 @@ class CacheActivationsRunnerConfig: streaming: bool = True autocast_lm: bool = False dataset_trust_remote_code: bool | None = None + disable_concat_sequences: bool = False + sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = ( + special_token_field(default="bos") + ) + exclude_special_tokens: bool | list[int] = False def __post_init__(self): # Automatically determine context_size if dataset is tokenized @@ -562,6 +570,11 @@ def __post_init__(self): self.dataset_path, self.model_name, self.hook_name, None ) + if isinstance(self.exclude_special_tokens, list) and not all( + isinstance(x, int) for x in self.exclude_special_tokens + ): + raise ValueError("exclude_special_tokens list must contain only integers") + @property def sliced_context_size(self) -> int: if self.seqpos_slice is not None: From 6979716ecaa5a738606dcba35f72c3533ab436c1 Mon Sep 17 00:00:00 2001 From: David Chanin Date: Sat, 3 Jan 2026 17:12:31 -0500 Subject: [PATCH 2/8] fixing shuffling behavior for on-disk activations --- sae_lens/cache_activations_runner.py | 18 ++++++- tests/test_cache_activations_runner.py | 69 ++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 1 deletion(-) diff --git a/sae_lens/cache_activations_runner.py b/sae_lens/cache_activations_runner.py index ddbed06d6..bec9de6ae 100644 --- a/sae_lens/cache_activations_runner.py +++ b/sae_lens/cache_activations_runner.py @@ -234,7 +234,7 @@ def _consolidate_shards( "_split": None, } - # fingerprint is generated from dataset.__getstate__ (not includeing _fingerprint) + # fingerprint is generated from dataset.__getstate__ (not including _fingerprint) with open(output_dir / "state.json", "w") as f: json.dump(new_state, f, indent=2) @@ -299,6 +299,22 @@ def run(self) -> Dataset: if 
self.cfg.shuffle: logger.info("Shuffling...") dataset = dataset.shuffle(seed=self.cfg.seed) + # Save the shuffled dataset back to disk + # We need to save to a temp location first since datasets can't overwrite themselves + shuffled_path = final_cached_activation_path / ".shuffled" + dataset.save_to_disk(str(shuffled_path)) + # Remove old unshuffled data and replace with shuffled + for item in final_cached_activation_path.iterdir(): + if item.name != ".shuffled": + if item.is_dir(): + shutil.rmtree(item) + else: + item.unlink() + for item in shuffled_path.iterdir(): + shutil.move(str(item), str(final_cached_activation_path / item.name)) + shuffled_path.rmdir() + # Reload the dataset from the new location + dataset = Dataset.load_from_disk(str(final_cached_activation_path)) if self.cfg.hf_repo_id: logger.info("Pushing to Huggingface Hub...") diff --git a/tests/test_cache_activations_runner.py b/tests/test_cache_activations_runner.py index 40ef2958b..4fa77feb2 100644 --- a/tests/test_cache_activations_runner.py +++ b/tests/test_cache_activations_runner.py @@ -452,3 +452,72 @@ def test_cache_activations_runner_shuffling(tmp_path: Path): torch.from_numpy(unshuffled_acts_array[i]), torch.from_numpy(shuffled_acts_array[shuffled_idx]), ) + + +def test_cache_activations_runner_shuffled_saved_to_disk(tmp_path: Path): + """Test that when shuffle=True, the shuffled dataset is saved to disk (not just returned).""" + # Create test dataset with arbitrary unique tokens + tokenizer = HookedTransformer.from_pretrained("gelu-1l").tokenizer + text = "".join( + [ + " " + word[1:] + for word in tokenizer.vocab # type: ignore + if word[0] == "Ġ" and word[1:].isascii() and word.isalnum() + ] + ) + dataset = Dataset.from_list([{"text": text}]) + + # Create configs for unshuffled and shuffled versions + unshuffled_path = tmp_path / "unshuffled" + shuffled_path = tmp_path / "shuffled" + + unshuffled_cfg = _default_cfg( + unshuffled_path, + context_size=3, + batch_size=2, + 
dataset_num_rows=8, + shuffle=False, + ) + shuffled_cfg = _default_cfg( + shuffled_path, + context_size=3, + batch_size=2, + dataset_num_rows=8, + shuffle=True, + ) + + # Run both + unshuffled_runner = CacheActivationsRunner(unshuffled_cfg, override_dataset=dataset) + unshuffled_runner.run() + + shuffled_runner = CacheActivationsRunner(shuffled_cfg, override_dataset=dataset) + returned_shuffled_ds = shuffled_runner.run() + returned_shuffled_ds.set_format("torch") + + # Load datasets from disk + unshuffled_from_disk = datasets.load_from_disk(str(unshuffled_path)) + shuffled_from_disk = datasets.load_from_disk(str(shuffled_path)) + unshuffled_from_disk.set_format("torch") + shuffled_from_disk.set_format("torch") + + hook_name = unshuffled_cfg.hook_name + + # Verify the shuffled dataset on disk is different from the unshuffled one + unshuffled_tokens = np.array(unshuffled_from_disk["token_ids"]) + shuffled_tokens_on_disk = np.array(shuffled_from_disk["token_ids"]) + assert not np.array_equal( + unshuffled_tokens, shuffled_tokens_on_disk + ), "Shuffled dataset on disk should be different from unshuffled" + + # Verify the shuffled dataset on disk matches what was returned + returned_tokens = np.array(returned_shuffled_ds["token_ids"]) + assert np.array_equal( + shuffled_tokens_on_disk, returned_tokens + ), "Dataset on disk should match returned dataset" + + # Also verify activations match + shuffled_acts_on_disk = np.array(shuffled_from_disk[hook_name]) + returned_acts = np.array(returned_shuffled_ds[hook_name]) + assert np.array_equal( + shuffled_acts_on_disk, returned_acts + ), "Activations on disk should match returned activations" From 846f0508aee03e77d97e8f6aff2fb055f70a4f94 Mon Sep 17 00:00:00 2001 From: David Chanin Date: Sat, 3 Jan 2026 17:32:14 -0500 Subject: [PATCH 3/8] removing from the cache config, since its a noop --- sae_lens/cache_activations_runner.py | 12 +----------- sae_lens/config.py | 7 ------- 2 files changed, 1 insertion(+), 18 deletions(-) diff 
--git a/sae_lens/cache_activations_runner.py b/sae_lens/cache_activations_runner.py index bec9de6ae..2ca77f263 100644 --- a/sae_lens/cache_activations_runner.py +++ b/sae_lens/cache_activations_runner.py @@ -16,7 +16,7 @@ from sae_lens.config import CacheActivationsRunnerConfig from sae_lens.load_model import load_model from sae_lens.training.activations_store import ActivationsStore -from sae_lens.util import get_special_token_ids, str_to_dtype +from sae_lens.util import str_to_dtype def _mk_activations_store( @@ -29,15 +29,6 @@ def _mk_activations_store( from a ActivationsStore. """ device = torch.device("cpu") # since we're saving to disk - exclude_special_tokens = cfg.exclude_special_tokens - if exclude_special_tokens is False: - exclude_special_tokens = None - if exclude_special_tokens is True: - exclude_special_tokens = get_special_token_ids(model.tokenizer) # type: ignore - if exclude_special_tokens is not None: - exclude_special_tokens = torch.tensor( - exclude_special_tokens, dtype=torch.long, device=device - ) return ActivationsStore( model=model, dataset=override_dataset or cfg.dataset_path, @@ -59,7 +50,6 @@ def _mk_activations_store( autocast_lm=cfg.autocast_lm, dataset_trust_remote_code=cfg.dataset_trust_remote_code, seqpos_slice=cfg.seqpos_slice, - exclude_special_tokens=exclude_special_tokens, disable_concat_sequences=cfg.disable_concat_sequences, sequence_separator_token=cfg.sequence_separator_token, ) diff --git a/sae_lens/config.py b/sae_lens/config.py index 709f61bd2..1c3621b1a 100644 --- a/sae_lens/config.py +++ b/sae_lens/config.py @@ -498,7 +498,6 @@ class CacheActivationsRunnerConfig: dataset_trust_remote_code (bool): Whether to trust remote code when loading datasets from Huggingface. disable_concat_sequences (bool): Whether to disable concatenating sequences and ignore sequences shorter than the context size. If True, disables concatenating and ignores short sequences. 
sequence_separator_token (int | Literal["bos", "eos", "sep"] | None): If not `None`, this token will be placed between sentences in a batch to act as a separator. By default, this is the `<bos>` token. - exclude_special_tokens (bool | list[int]): Whether to exclude special tokens from the activations. If True, excludes all special tokens. If a list of ints, excludes those token IDs. """ dataset_path: str @@ -540,7 +539,6 @@ class CacheActivationsRunnerConfig: sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = ( special_token_field(default="bos") ) - exclude_special_tokens: bool | list[int] = False def __post_init__(self): # Automatically determine context_size if dataset is tokenized @@ -570,11 +568,6 @@ def __post_init__(self): self.dataset_path, self.model_name, self.hook_name, None ) - if isinstance(self.exclude_special_tokens, list) and not all( - isinstance(x, int) for x in self.exclude_special_tokens - ): - raise ValueError("exclude_special_tokens list must contain only integers") - @property def sliced_context_size(self) -> int: if self.seqpos_slice is not None: From 1225e66ee4bb72c776812f416313abf104950cc2 Mon Sep 17 00:00:00 2001 From: David Chanin Date: Sat, 3 Jan 2026 17:54:57 -0500 Subject: [PATCH 4/8] only delete old dataset files --- sae_lens/cache_activations_runner.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/sae_lens/cache_activations_runner.py b/sae_lens/cache_activations_runner.py index 2ca77f263..fe01d6314 100644 --- a/sae_lens/cache_activations_runner.py +++ b/sae_lens/cache_activations_runner.py @@ -294,12 +294,14 @@ def run(self) -> Dataset: shuffled_path = final_cached_activation_path / ".shuffled" dataset.save_to_disk(str(shuffled_path)) # Remove old unshuffled data and replace with shuffled + # Only remove known dataset files (from _consolidate_shards output) for item in final_cached_activation_path.iterdir(): - if item.name != ".shuffled": - if item.is_dir(): - shutil.rmtree(item) - else: 
- item.unlink() + is_arrow_file = ( + item.name.startswith("data-") and item.suffix == ".arrow" + ) + is_dataset_metadata = item.name in ("dataset_info.json", "state.json") + if is_arrow_file or is_dataset_metadata: + item.unlink() for item in shuffled_path.iterdir(): shutil.move(str(item), str(final_cached_activation_path / item.name)) shuffled_path.rmdir() From 45808d44702062ac86182afd4d719a8c7b6e94cf Mon Sep 17 00:00:00 2001 From: David Chanin Date: Sat, 3 Jan 2026 18:17:02 -0500 Subject: [PATCH 5/8] adding option to shuffle across sequences --- sae_lens/cache_activations_runner.py | 12 ++ sae_lens/config.py | 10 +- tests/test_cache_activations_runner.py | 150 +++++++++++++++++++++++++ 3 files changed, 171 insertions(+), 1 deletion(-) diff --git a/sae_lens/cache_activations_runner.py b/sae_lens/cache_activations_runner.py index fe01d6314..22486b3ed 100644 --- a/sae_lens/cache_activations_runner.py +++ b/sae_lens/cache_activations_runner.py @@ -88,6 +88,8 @@ def __init__( Value(dtype="int32"), length=self.context_size ) self.features = Features(features_dict) + # Generator for reproducible shuffling across sequences + self._shuffle_generator = torch.Generator().manual_seed(self.cfg.seed) def __str__(self): """ @@ -344,6 +346,16 @@ def _create_shard( ) -> Dataset: hook_names = [self.cfg.hook_name] acts, token_ids = buffer + + # Shuffle across sequences if enabled - this shuffles individual activation + # positions across all sequences, keeping token_ids paired with activations + if self.cfg.shuffle_across_sequences: + n_activations = acts.shape[0] + perm = torch.randperm(n_activations, generator=self._shuffle_generator) + acts = acts[perm] + if token_ids is not None: + token_ids = token_ids[perm] + acts = einops.rearrange( acts, "(bs context_size) d_in -> bs context_size d_in", diff --git a/sae_lens/config.py b/sae_lens/config.py index 1c3621b1a..977a8c74f 100644 --- a/sae_lens/config.py +++ b/sae_lens/config.py @@ -478,7 +478,8 @@ class 
CacheActivationsRunnerConfig: context_size (int): Context size to process. Can be left as -1 if the dataset is tokenized. model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`. new_cached_activations_path (str, optional): The path to save the activations. - shuffle (bool): Whether to shuffle the dataset. + shuffle (bool): Whether to shuffle the dataset at the sequence level. + shuffle_across_sequences (bool): Whether to shuffle individual activations across all sequence positions within each buffer. This treats the buffer as a flat 2D array and shuffles activation positions while keeping token_ids paired with their activations. seed (int): The seed to use for shuffling. dtype (str): Datatype of activations to be stored. device (str): The device for the model. @@ -512,6 +513,7 @@ class CacheActivationsRunnerConfig: # defaults to "activations/{dataset}/{model}/{hook_name} new_cached_activations_path: str | None = None shuffle: bool = True + shuffle_across_sequences: bool = False seed: int = 42 dtype: str = "float32" device: str = "cuda" if torch.cuda.is_available() else "cpu" @@ -568,6 +570,12 @@ def __post_init__(self): self.dataset_path, self.model_name, self.hook_name, None ) + if self.shuffle_across_sequences and not self.shuffle: + raise ValueError( + "shuffle_across_sequences=True requires shuffle=True. " + "Set shuffle=True to enable shuffling across sequences." 
+ ) + @property def sliced_context_size(self) -> int: if self.seqpos_slice is not None: diff --git a/tests/test_cache_activations_runner.py b/tests/test_cache_activations_runner.py index 4fa77feb2..37f4250e8 100644 --- a/tests/test_cache_activations_runner.py +++ b/tests/test_cache_activations_runner.py @@ -521,3 +521,153 @@ def test_cache_activations_runner_shuffled_saved_to_disk(tmp_path: Path): assert np.array_equal( shuffled_acts_on_disk, returned_acts ), "Activations on disk should match returned activations" + + +def test_cache_activations_runner_shuffle_across_sequences(tmp_path: Path): + """Test that shuffle_across_sequences shuffles individual activations across all sequence positions.""" + # Create test dataset with unique tokens + tokenizer = HookedTransformer.from_pretrained("gelu-1l").tokenizer + text = "".join( + [ + " " + word[1:] + for word in tokenizer.vocab # type: ignore + if word[0] == "Ġ" and word[1:].isascii() and word.isalnum() + ] + ) + dataset = Dataset.from_list([{"text": text}]) + + # Create configs for unshuffled and shuffle_across_sequences versions + base_cfg = _default_cfg( + tmp_path / "base", + context_size=4, + batch_size=2, + dataset_num_rows=8, + shuffle=False, + shuffle_across_sequences=False, + ) + shuffle_across_cfg = _default_cfg( + tmp_path / "shuffle_across", + context_size=4, + batch_size=2, + dataset_num_rows=8, + shuffle=True, # Required when shuffle_across_sequences=True + shuffle_across_sequences=True, + ) + + # Get unshuffled dataset + unshuffled_runner = CacheActivationsRunner(base_cfg, override_dataset=dataset) + unshuffled_ds = unshuffled_runner.run() + unshuffled_ds.set_format("torch") + + # Get shuffle_across_sequences dataset + shuffled_runner = CacheActivationsRunner(shuffle_across_cfg, override_dataset=dataset) + shuffled_ds = shuffled_runner.run() + shuffled_ds.set_format("torch") + + # Get activations and tokens + hook_name = base_cfg.hook_name + unshuffled_acts: torch.Tensor = unshuffled_ds[hook_name] # 
type: ignore + unshuffled_tokens: torch.Tensor = unshuffled_ds["token_ids"] # type: ignore + shuffled_acts: torch.Tensor = shuffled_ds[hook_name] # type: ignore + shuffled_tokens: torch.Tensor = shuffled_ds["token_ids"] # type: ignore + + # Convert to numpy for easier manipulation + unshuffled_acts_np = np.array(unshuffled_acts) + unshuffled_tokens_np = np.array(unshuffled_tokens) + shuffled_acts_np = np.array(shuffled_acts) + shuffled_tokens_np = np.array(shuffled_tokens) + + # Verify shapes are preserved + assert unshuffled_acts_np.shape == shuffled_acts_np.shape + assert unshuffled_tokens_np.shape == shuffled_tokens_np.shape + + # Flatten to compare individual activations + unshuffled_acts_flat = unshuffled_acts_np.reshape(-1, unshuffled_acts_np.shape[-1]) + unshuffled_tokens_flat = unshuffled_tokens_np.reshape(-1) + shuffled_acts_flat = shuffled_acts_np.reshape(-1, shuffled_acts_np.shape[-1]) + shuffled_tokens_flat = shuffled_tokens_np.reshape(-1) + + # Verify data is actually shuffled (activations should be in different positions) + assert not np.array_equal(unshuffled_acts_flat, shuffled_acts_flat) + assert not np.array_equal(unshuffled_tokens_flat, shuffled_tokens_flat) + + # Verify token-activation pairs remain aligned after shuffling + # For each unique token in unshuffled, find its activation and verify + # the same token has the same activation in shuffled + for i in range(len(unshuffled_tokens_flat)): + token = unshuffled_tokens_flat[i] + act = unshuffled_acts_flat[i] + # Find where this token is in the shuffled version + shuffled_indices = np.where(shuffled_tokens_flat == token)[0] + # At least one of these positions should have the matching activation + found_match = False + for idx in shuffled_indices: + if np.allclose(act, shuffled_acts_flat[idx], rtol=1e-5, atol=1e-5): + found_match = True + break + assert found_match, f"Token {token} at position {i} lost its paired activation" + + +def 
test_cache_activations_runner_shuffle_across_sequences_reproducible(tmp_path: Path): + """Test that shuffle_across_sequences is reproducible with the same seed.""" + tokenizer = HookedTransformer.from_pretrained("gelu-1l").tokenizer + text = "".join( + [ + " " + word[1:] + for word in tokenizer.vocab # type: ignore + if word[0] == "Ġ" and word[1:].isascii() and word.isalnum() + ] + ) + dataset = Dataset.from_list([{"text": text}]) + + # Create two configs with the same seed (default is 42) + cfg1 = _default_cfg( + tmp_path / "run1", + context_size=4, + batch_size=2, + dataset_num_rows=8, + shuffle=True, # Required when shuffle_across_sequences=True + shuffle_across_sequences=True, + ) + cfg2 = _default_cfg( + tmp_path / "run2", + context_size=4, + batch_size=2, + dataset_num_rows=8, + shuffle=True, # Required when shuffle_across_sequences=True + shuffle_across_sequences=True, + ) + + # Run both + runner1 = CacheActivationsRunner(cfg1, override_dataset=dataset) + ds1 = runner1.run() + ds1.set_format("torch") + + runner2 = CacheActivationsRunner(cfg2, override_dataset=dataset) + ds2 = runner2.run() + ds2.set_format("torch") + + # Results should be identical + hook_name = cfg1.hook_name + acts1 = np.array(ds1[hook_name]) + acts2 = np.array(ds2[hook_name]) + tokens1 = np.array(ds1["token_ids"]) + tokens2 = np.array(ds2["token_ids"]) + + assert np.array_equal(acts1, acts2), "Same seed should produce identical activations" + assert np.array_equal(tokens1, tokens2), "Same seed should produce identical tokens" + + +def test_cache_activations_runner_shuffle_across_sequences_requires_shuffle( + tmp_path: Path, +): + """Test that shuffle_across_sequences=True requires shuffle=True.""" + with pytest.raises( + ValueError, + match="shuffle_across_sequences=True requires shuffle=True", + ): + _default_cfg( + tmp_path, + shuffle=False, + shuffle_across_sequences=True, + ) From dc3dd7625cbb303e856eefe08a8985a6098cf278 Mon Sep 17 00:00:00 2001 From: David Chanin Date: Sat, 3 Jan 
2026 19:07:45 -0500 Subject: [PATCH 6/8] fixing formatting --- tests/test_cache_activations_runner.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_cache_activations_runner.py b/tests/test_cache_activations_runner.py index 37f4250e8..6f7f01401 100644 --- a/tests/test_cache_activations_runner.py +++ b/tests/test_cache_activations_runner.py @@ -560,7 +560,9 @@ def test_cache_activations_runner_shuffle_across_sequences(tmp_path: Path): unshuffled_ds.set_format("torch") # Get shuffle_across_sequences dataset - shuffled_runner = CacheActivationsRunner(shuffle_across_cfg, override_dataset=dataset) + shuffled_runner = CacheActivationsRunner( + shuffle_across_cfg, override_dataset=dataset + ) shuffled_ds = shuffled_runner.run() shuffled_ds.set_format("torch") @@ -654,7 +656,9 @@ def test_cache_activations_runner_shuffle_across_sequences_reproducible(tmp_path tokens1 = np.array(ds1["token_ids"]) tokens2 = np.array(ds2["token_ids"]) - assert np.array_equal(acts1, acts2), "Same seed should produce identical activations" + assert np.array_equal( + acts1, acts2 + ), "Same seed should produce identical activations" assert np.array_equal(tokens1, tokens2), "Same seed should produce identical tokens" From 656849a25c8271d330aec24858257f991e560208 Mon Sep 17 00:00:00 2001 From: David Chanin Date: Sat, 3 Jan 2026 19:23:15 -0500 Subject: [PATCH 7/8] changes from CR --- sae_lens/cache_activations_runner.py | 53 ++++++++++++++++++---------- 1 file changed, 35 insertions(+), 18 deletions(-) diff --git a/sae_lens/cache_activations_runner.py b/sae_lens/cache_activations_runner.py index 22486b3ed..0e0e06bbd 100644 --- a/sae_lens/cache_activations_runner.py +++ b/sae_lens/cache_activations_runner.py @@ -18,6 +18,11 @@ from sae_lens.training.activations_store import ActivationsStore from sae_lens.util import str_to_dtype +# Directory names for temporary operations during caching +_TMP_SHARDS_DIR = ".tmp_shards" +_SHUFFLED_DIR = ".shuffled" 
+_BACKUP_SUFFIX = ".backup" + def _mk_activations_store( model: HookedRootModule, @@ -181,10 +186,10 @@ def _consolidate_shards( f"output_dir is not an existing directory: {output_dir}" ) - other_items = [p for p in output_dir.iterdir() if p.name != ".tmp_shards"] + other_items = [p for p in output_dir.iterdir() if p.name != _TMP_SHARDS_DIR] if other_items: raise FileExistsError( - f"output_dir must be empty (besides .tmp_shards). Found: {other_items}" + f"output_dir must be empty (besides {_TMP_SHARDS_DIR}). Found: {other_items}" ) if not (source_dir / first_shard_dir_name).exists(): @@ -259,7 +264,7 @@ def run(self) -> Dataset: f"Activations directory ({final_cached_activation_path}) is not empty. Please delete it or specify a different path. Exiting the script to prevent accidental deletion of files." ) - tmp_cached_activation_path = final_cached_activation_path / ".tmp_shards/" + tmp_cached_activation_path = final_cached_activation_path / _TMP_SHARDS_DIR tmp_cached_activation_path.mkdir(exist_ok=False, parents=False) ### Create temporary sharded datasets @@ -291,22 +296,34 @@ def run(self) -> Dataset: if self.cfg.shuffle: logger.info("Shuffling...") dataset = dataset.shuffle(seed=self.cfg.seed) - # Save the shuffled dataset back to disk - # We need to save to a temp location first since datasets can't overwrite themselves - shuffled_path = final_cached_activation_path / ".shuffled" + # Save the shuffled dataset back to disk using atomic rename with backup + # to prevent data loss if the process crashes mid-operation. + # Note: shuffled_path must be a sibling (not child) of final_cached_activation_path + # so that renaming the parent doesn't invalidate the shuffled path. 
+ shuffled_path = final_cached_activation_path.parent / ( + final_cached_activation_path.name + _SHUFFLED_DIR + ) + backup_path = final_cached_activation_path.parent / ( + final_cached_activation_path.name + _BACKUP_SUFFIX + ) + dataset.save_to_disk(str(shuffled_path)) - # Remove old unshuffled data and replace with shuffled - # Only remove known dataset files (from _consolidate_shards output) - for item in final_cached_activation_path.iterdir(): - is_arrow_file = ( - item.name.startswith("data-") and item.suffix == ".arrow" - ) - is_dataset_metadata = item.name in ("dataset_info.json", "state.json") - if is_arrow_file or is_dataset_metadata: - item.unlink() - for item in shuffled_path.iterdir(): - shutil.move(str(item), str(final_cached_activation_path / item.name)) - shuffled_path.rmdir() + + # Atomic swap: rename original to backup, then shuffled to original + try: + final_cached_activation_path.rename(backup_path) + shuffled_path.rename(final_cached_activation_path) + # Success - remove backup + shutil.rmtree(backup_path) + except Exception: + # Rollback: restore from backup if it exists + if backup_path.exists() and not final_cached_activation_path.exists(): + backup_path.rename(final_cached_activation_path) + # Clean up shuffled path if it still exists + if shuffled_path.exists(): + shutil.rmtree(shuffled_path) + raise + # Reload the dataset from the new location dataset = Dataset.load_from_disk(str(final_cached_activation_path)) From 4e17344ef45848b518ab844ba8ba4354cb7789d0 Mon Sep 17 00:00:00 2001 From: David Chanin Date: Sat, 3 Jan 2026 23:25:21 -0500 Subject: [PATCH 8/8] shuffle sequences globally, not just per-shard --- sae_lens/cache_activations_runner.py | 70 ++++++++++++++++++++++------ 1 file changed, 57 insertions(+), 13 deletions(-) diff --git a/sae_lens/cache_activations_runner.py b/sae_lens/cache_activations_runner.py index 0e0e06bbd..8829bc030 100644 --- a/sae_lens/cache_activations_runner.py +++ b/sae_lens/cache_activations_runner.py @@ -93,8 
+93,6 @@ def __init__( Value(dtype="int32"), length=self.context_size ) self.features = Features(features_dict) - # Generator for reproducible shuffling across sequences - self._shuffle_generator = torch.Generator().manual_seed(self.cfg.seed) def __str__(self): """ @@ -294,8 +292,63 @@ def run(self) -> Dataset: ) if self.cfg.shuffle: - logger.info("Shuffling...") - dataset = dataset.shuffle(seed=self.cfg.seed) + # shuffle_across_sequences: shuffle individual activations globally, + # treating the entire dataset as a flat array of (total_tokens, d_in). + # This breaks up sequential patterns within sequences while keeping + # token_ids paired with their corresponding activations. + if self.cfg.shuffle_across_sequences: + logger.info("Shuffling across sequences...") + dataset.set_format("torch") + hook_name = self.cfg.hook_name + + # Load all data and flatten + # With torch format, [:] returns tensors directly + all_data = dataset[:] + acts = all_data[hook_name] # (n_seq, context_size, d_in) + token_ids = all_data["token_ids"] # (n_seq, context_size) + n_seq = acts.shape[0] + + acts_flat = einops.rearrange( + acts, "n_seq context_size d_in -> (n_seq context_size) d_in" + ) + token_ids_flat = einops.rearrange( + token_ids, "n_seq context_size -> (n_seq context_size)" + ) + + # Shuffle globally with the same permutation for both + generator = torch.Generator().manual_seed(self.cfg.seed) + perm = torch.randperm(acts_flat.shape[0], generator=generator) + acts_flat = acts_flat[perm] + token_ids_flat = token_ids_flat[perm] + + # Reshape back to sequences + acts_shuffled = einops.rearrange( + acts_flat, + "(n_seq context_size) d_in -> n_seq context_size d_in", + n_seq=n_seq, + context_size=self.context_size, + ) + token_ids_shuffled = einops.rearrange( + token_ids_flat, + "(n_seq context_size) -> n_seq context_size", + n_seq=n_seq, + context_size=self.context_size, + ) + + # Create new dataset from shuffled data + dataset = Dataset.from_dict( + { + hook_name: acts_shuffled, 
+ "token_ids": token_ids_shuffled.to(torch.int32), + }, + features=self.features, + ) + else: + # Sequence-level shuffle only: shuffle the order of sequences (rows) + # Skip if shuffle_across_sequences was used since global shuffle is stronger + logger.info("Shuffling sequences...") + dataset = dataset.shuffle(seed=self.cfg.seed) + # Save the shuffled dataset back to disk using atomic rename with backup # to prevent data loss if the process crashes mid-operation. # Note: shuffled_path must be a sibling (not child) of final_cached_activation_path @@ -364,15 +417,6 @@ def _create_shard( hook_names = [self.cfg.hook_name] acts, token_ids = buffer - # Shuffle across sequences if enabled - this shuffles individual activation - # positions across all sequences, keeping token_ids paired with activations - if self.cfg.shuffle_across_sequences: - n_activations = acts.shape[0] - perm = torch.randperm(n_activations, generator=self._shuffle_generator) - acts = acts[perm] - if token_ids is not None: - token_ids = token_ids[perm] - acts = einops.rearrange( acts, "(bs context_size) d_in -> bs context_size d_in",