diff --git a/sae_lens/cache_activations_runner.py b/sae_lens/cache_activations_runner.py index bbebefc0b..8829bc030 100644 --- a/sae_lens/cache_activations_runner.py +++ b/sae_lens/cache_activations_runner.py @@ -18,6 +18,11 @@ from sae_lens.training.activations_store import ActivationsStore from sae_lens.util import str_to_dtype +# Directory names for temporary operations during caching +_TMP_SHARDS_DIR = ".tmp_shards" +_SHUFFLED_DIR = ".shuffled" +_BACKUP_SUFFIX = ".backup" + def _mk_activations_store( model: HookedRootModule, @@ -28,6 +33,7 @@ def _mk_activations_store( Internal method used in CacheActivationsRunner. Used to create a cached dataset from a ActivationsStore. """ + device = torch.device("cpu") # since we're saving to disk return ActivationsStore( model=model, dataset=override_dataset or cfg.dataset_path, @@ -42,13 +48,15 @@ def _mk_activations_store( train_batch_size_tokens=-1, prepend_bos=cfg.prepend_bos, normalize_activations="none", - device=torch.device("cpu"), # since we're saving to disk + device=device, dtype=cfg.dtype, cached_activations_path=None, model_kwargs=cfg.model_kwargs, autocast_lm=cfg.autocast_lm, dataset_trust_remote_code=cfg.dataset_trust_remote_code, seqpos_slice=cfg.seqpos_slice, + disable_concat_sequences=cfg.disable_concat_sequences, + sequence_separator_token=cfg.sequence_separator_token, ) @@ -176,10 +184,10 @@ def _consolidate_shards( f"output_dir is not an existing directory: {output_dir}" ) - other_items = [p for p in output_dir.iterdir() if p.name != ".tmp_shards"] + other_items = [p for p in output_dir.iterdir() if p.name != _TMP_SHARDS_DIR] if other_items: raise FileExistsError( - f"output_dir must be empty (besides .tmp_shards). Found: {other_items}" + f"output_dir must be empty (besides {_TMP_SHARDS_DIR}). 
Found: {other_items}" ) if not (source_dir / first_shard_dir_name).exists(): @@ -221,7 +229,7 @@ def _consolidate_shards( "_split": None, } - # fingerprint is generated from dataset.__getstate__ (not includeing _fingerprint) + # fingerprint is generated from dataset.__getstate__ (not including _fingerprint) with open(output_dir / "state.json", "w") as f: json.dump(new_state, f, indent=2) @@ -254,7 +262,7 @@ def run(self) -> Dataset: f"Activations directory ({final_cached_activation_path}) is not empty. Please delete it or specify a different path. Exiting the script to prevent accidental deletion of files." ) - tmp_cached_activation_path = final_cached_activation_path / ".tmp_shards/" + tmp_cached_activation_path = final_cached_activation_path / _TMP_SHARDS_DIR tmp_cached_activation_path.mkdir(exist_ok=False, parents=False) ### Create temporary sharded datasets @@ -284,8 +292,93 @@ def run(self) -> Dataset: ) if self.cfg.shuffle: - logger.info("Shuffling...") - dataset = dataset.shuffle(seed=self.cfg.seed) + # shuffle_across_sequences: shuffle individual activations globally, + # treating the entire dataset as a flat array of (total_tokens, d_in). + # This breaks up sequential patterns within sequences while keeping + # token_ids paired with their corresponding activations. 
+ if self.cfg.shuffle_across_sequences: + logger.info("Shuffling across sequences...") + dataset.set_format("torch") + hook_name = self.cfg.hook_name + + # Load all data and flatten + # With torch format, [:] returns tensors directly + all_data = dataset[:] + acts = all_data[hook_name] # (n_seq, context_size, d_in) + token_ids = all_data["token_ids"] # (n_seq, context_size) + n_seq = acts.shape[0] + + acts_flat = einops.rearrange( + acts, "n_seq context_size d_in -> (n_seq context_size) d_in" + ) + token_ids_flat = einops.rearrange( + token_ids, "n_seq context_size -> (n_seq context_size)" + ) + + # Shuffle globally with the same permutation for both + generator = torch.Generator().manual_seed(self.cfg.seed) + perm = torch.randperm(acts_flat.shape[0], generator=generator) + acts_flat = acts_flat[perm] + token_ids_flat = token_ids_flat[perm] + + # Reshape back to sequences + acts_shuffled = einops.rearrange( + acts_flat, + "(n_seq context_size) d_in -> n_seq context_size d_in", + n_seq=n_seq, + context_size=self.context_size, + ) + token_ids_shuffled = einops.rearrange( + token_ids_flat, + "(n_seq context_size) -> n_seq context_size", + n_seq=n_seq, + context_size=self.context_size, + ) + + # Create new dataset from shuffled data + dataset = Dataset.from_dict( + { + hook_name: acts_shuffled, + "token_ids": token_ids_shuffled.to(torch.int32), + }, + features=self.features, + ) + else: + # Sequence-level shuffle only: shuffle the order of sequences (rows) + # Skip if shuffle_across_sequences was used since global shuffle is stronger + logger.info("Shuffling sequences...") + dataset = dataset.shuffle(seed=self.cfg.seed) + + # Save the shuffled dataset back to disk using atomic rename with backup + # to prevent data loss if the process crashes mid-operation. + # Note: shuffled_path must be a sibling (not child) of final_cached_activation_path + # so that renaming the parent doesn't invalidate the shuffled path. 
+ shuffled_path = final_cached_activation_path.parent / ( + final_cached_activation_path.name + _SHUFFLED_DIR + ) + backup_path = final_cached_activation_path.parent / ( + final_cached_activation_path.name + _BACKUP_SUFFIX + ) + + dataset.save_to_disk(str(shuffled_path)) + + # Atomic swap: rename original to backup, then shuffled to original + try: + final_cached_activation_path.rename(backup_path) + shuffled_path.rename(final_cached_activation_path) + # Success - remove backup + shutil.rmtree(backup_path) + except Exception: + # Rollback: restore from backup if it exists + if backup_path.exists() and not final_cached_activation_path.exists(): + backup_path.rename(final_cached_activation_path) + # Clean up shuffled path if it still exists + if shuffled_path.exists(): + shutil.rmtree(shuffled_path) + raise + + # Reload the dataset from the new location + dataset = Dataset.load_from_disk(str(final_cached_activation_path)) if self.cfg.hf_repo_id: logger.info("Pushing to Huggingface Hub...") @@ -323,6 +416,7 @@ def _create_shard( ) -> Dataset: hook_names = [self.cfg.hook_name] acts, token_ids = buffer + acts = einops.rearrange( acts, "(bs context_size) d_in -> bs context_size d_in", diff --git a/sae_lens/config.py b/sae_lens/config.py index 8ca5d0e74..977a8c74f 100644 --- a/sae_lens/config.py +++ b/sae_lens/config.py @@ -478,7 +478,8 @@ class CacheActivationsRunnerConfig: context_size (int): Context size to process. Can be left as -1 if the dataset is tokenized. model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`. new_cached_activations_path (str, optional): The path to save the activations. - shuffle (bool): Whether to shuffle the dataset. + shuffle (bool): Whether to shuffle the dataset at the sequence level. + shuffle_across_sequences (bool): Whether to shuffle individual activations across all sequence positions of the entire cached dataset, which is loaded into memory as a single buffer. 
This treats the buffer as a flat 2D array and shuffles activation positions while keeping token_ids paired with their activations. seed (int): The seed to use for shuffling. dtype (str): Datatype of activations to be stored. device (str): The device for the model. @@ -496,6 +497,8 @@ class CacheActivationsRunnerConfig: streaming (bool): Whether to stream the dataset. Streaming large datasets is usually practical. autocast_lm (bool): Whether to use autocast during activation fetching. dataset_trust_remote_code (bool): Whether to trust remote code when loading datasets from Huggingface. + disable_concat_sequences (bool): Whether to disable concatenating sequences and ignore sequences shorter than the context size. If True, disables concatenating and ignores short sequences. + sequence_separator_token (int | Literal["bos", "eos", "sep"] | None): If not `None`, this token will be placed between sentences in a batch to act as a separator. By default, this is the `bos` token. """ dataset_path: str @@ -510,6 +513,7 @@ class CacheActivationsRunnerConfig: # defaults to "activations/{dataset}/{model}/{hook_name} new_cached_activations_path: str | None = None shuffle: bool = True + shuffle_across_sequences: bool = False seed: int = 42 dtype: str = "float32" device: str = "cuda" if torch.cuda.is_available() else "cpu" @@ -533,6 +537,10 @@ class CacheActivationsRunnerConfig: streaming: bool = True autocast_lm: bool = False dataset_trust_remote_code: bool | None = None + disable_concat_sequences: bool = False + sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = ( + special_token_field(default="bos") + ) def __post_init__(self): # Automatically determine context_size if dataset is tokenized @@ -562,6 +570,12 @@ def __post_init__(self): self.dataset_path, self.model_name, self.hook_name, None ) + if self.shuffle_across_sequences and not self.shuffle: + raise ValueError( + "shuffle_across_sequences=True requires shuffle=True. 
" + "Set shuffle=True to enable shuffling across sequences." + ) + @property def sliced_context_size(self) -> int: if self.seqpos_slice is not None: diff --git a/tests/test_cache_activations_runner.py b/tests/test_cache_activations_runner.py index 40ef2958b..6f7f01401 100644 --- a/tests/test_cache_activations_runner.py +++ b/tests/test_cache_activations_runner.py @@ -452,3 +452,226 @@ def test_cache_activations_runner_shuffling(tmp_path: Path): torch.from_numpy(unshuffled_acts_array[i]), torch.from_numpy(shuffled_acts_array[shuffled_idx]), ) + + +def test_cache_activations_runner_shuffled_saved_to_disk(tmp_path: Path): + """Test that when shuffle=True, the shuffled dataset is saved to disk (not just returned).""" + # Create test dataset with arbitrary unique tokens + tokenizer = HookedTransformer.from_pretrained("gelu-1l").tokenizer + text = "".join( + [ + " " + word[1:] + for word in tokenizer.vocab # type: ignore + if word[0] == "Ġ" and word[1:].isascii() and word.isalnum() + ] + ) + dataset = Dataset.from_list([{"text": text}]) + + # Create configs for unshuffled and shuffled versions + unshuffled_path = tmp_path / "unshuffled" + shuffled_path = tmp_path / "shuffled" + + unshuffled_cfg = _default_cfg( + unshuffled_path, + context_size=3, + batch_size=2, + dataset_num_rows=8, + shuffle=False, + ) + shuffled_cfg = _default_cfg( + shuffled_path, + context_size=3, + batch_size=2, + dataset_num_rows=8, + shuffle=True, + ) + + # Run both + unshuffled_runner = CacheActivationsRunner(unshuffled_cfg, override_dataset=dataset) + unshuffled_runner.run() + + shuffled_runner = CacheActivationsRunner(shuffled_cfg, override_dataset=dataset) + returned_shuffled_ds = shuffled_runner.run() + returned_shuffled_ds.set_format("torch") + + # Load datasets from disk + unshuffled_from_disk = datasets.load_from_disk(str(unshuffled_path)) + shuffled_from_disk = datasets.load_from_disk(str(shuffled_path)) + unshuffled_from_disk.set_format("torch") + shuffled_from_disk.set_format("torch") 
+ + hook_name = unshuffled_cfg.hook_name + + # Verify the shuffled dataset on disk is different from the unshuffled one + unshuffled_tokens = np.array(unshuffled_from_disk["token_ids"]) + shuffled_tokens_on_disk = np.array(shuffled_from_disk["token_ids"]) + assert not np.array_equal( + unshuffled_tokens, shuffled_tokens_on_disk + ), "Shuffled dataset on disk should be different from unshuffled" + + # Verify the shuffled dataset on disk matches what was returned + returned_tokens = np.array(returned_shuffled_ds["token_ids"]) + assert np.array_equal( + shuffled_tokens_on_disk, returned_tokens + ), "Dataset on disk should match returned dataset" + + # Also verify activations match + shuffled_acts_on_disk = np.array(shuffled_from_disk[hook_name]) + returned_acts = np.array(returned_shuffled_ds[hook_name]) + assert np.array_equal( + shuffled_acts_on_disk, returned_acts + ), "Activations on disk should match returned activations" + + +def test_cache_activations_runner_shuffle_across_sequences(tmp_path: Path): + """Test that shuffle_across_sequences shuffles individual activations across all sequence positions.""" + # Create test dataset with unique tokens + tokenizer = HookedTransformer.from_pretrained("gelu-1l").tokenizer + text = "".join( + [ + " " + word[1:] + for word in tokenizer.vocab # type: ignore + if word[0] == "Ġ" and word[1:].isascii() and word.isalnum() + ] + ) + dataset = Dataset.from_list([{"text": text}]) + + # Create configs for unshuffled and shuffle_across_sequences versions + base_cfg = _default_cfg( + tmp_path / "base", + context_size=4, + batch_size=2, + dataset_num_rows=8, + shuffle=False, + shuffle_across_sequences=False, + ) + shuffle_across_cfg = _default_cfg( + tmp_path / "shuffle_across", + context_size=4, + batch_size=2, + dataset_num_rows=8, + shuffle=True, # Required when shuffle_across_sequences=True + shuffle_across_sequences=True, + ) + + # Get unshuffled dataset + unshuffled_runner = CacheActivationsRunner(base_cfg, 
override_dataset=dataset) + unshuffled_ds = unshuffled_runner.run() + unshuffled_ds.set_format("torch") + + # Get shuffle_across_sequences dataset + shuffled_runner = CacheActivationsRunner( + shuffle_across_cfg, override_dataset=dataset + ) + shuffled_ds = shuffled_runner.run() + shuffled_ds.set_format("torch") + + # Get activations and tokens + hook_name = base_cfg.hook_name + unshuffled_acts: torch.Tensor = unshuffled_ds[hook_name] # type: ignore + unshuffled_tokens: torch.Tensor = unshuffled_ds["token_ids"] # type: ignore + shuffled_acts: torch.Tensor = shuffled_ds[hook_name] # type: ignore + shuffled_tokens: torch.Tensor = shuffled_ds["token_ids"] # type: ignore + + # Convert to numpy for easier manipulation + unshuffled_acts_np = np.array(unshuffled_acts) + unshuffled_tokens_np = np.array(unshuffled_tokens) + shuffled_acts_np = np.array(shuffled_acts) + shuffled_tokens_np = np.array(shuffled_tokens) + + # Verify shapes are preserved + assert unshuffled_acts_np.shape == shuffled_acts_np.shape + assert unshuffled_tokens_np.shape == shuffled_tokens_np.shape + + # Flatten to compare individual activations + unshuffled_acts_flat = unshuffled_acts_np.reshape(-1, unshuffled_acts_np.shape[-1]) + unshuffled_tokens_flat = unshuffled_tokens_np.reshape(-1) + shuffled_acts_flat = shuffled_acts_np.reshape(-1, shuffled_acts_np.shape[-1]) + shuffled_tokens_flat = shuffled_tokens_np.reshape(-1) + + # Verify data is actually shuffled (activations should be in different positions) + assert not np.array_equal(unshuffled_acts_flat, shuffled_acts_flat) + assert not np.array_equal(unshuffled_tokens_flat, shuffled_tokens_flat) + + # Verify token-activation pairs remain aligned after shuffling + # For each unique token in unshuffled, find its activation and verify + # the same token has the same activation in shuffled + for i in range(len(unshuffled_tokens_flat)): + token = unshuffled_tokens_flat[i] + act = unshuffled_acts_flat[i] + # Find where this token is in the shuffled version 
+ shuffled_indices = np.where(shuffled_tokens_flat == token)[0] + # At least one of these positions should have the matching activation + found_match = False + for idx in shuffled_indices: + if np.allclose(act, shuffled_acts_flat[idx], rtol=1e-5, atol=1e-5): + found_match = True + break + assert found_match, f"Token {token} at position {i} lost its paired activation" + + +def test_cache_activations_runner_shuffle_across_sequences_reproducible(tmp_path: Path): + """Test that shuffle_across_sequences is reproducible with the same seed.""" + tokenizer = HookedTransformer.from_pretrained("gelu-1l").tokenizer + text = "".join( + [ + " " + word[1:] + for word in tokenizer.vocab # type: ignore + if word[0] == "Ġ" and word[1:].isascii() and word.isalnum() + ] + ) + dataset = Dataset.from_list([{"text": text}]) + + # Create two configs with the same seed (default is 42) + cfg1 = _default_cfg( + tmp_path / "run1", + context_size=4, + batch_size=2, + dataset_num_rows=8, + shuffle=True, # Required when shuffle_across_sequences=True + shuffle_across_sequences=True, + ) + cfg2 = _default_cfg( + tmp_path / "run2", + context_size=4, + batch_size=2, + dataset_num_rows=8, + shuffle=True, # Required when shuffle_across_sequences=True + shuffle_across_sequences=True, + ) + + # Run both + runner1 = CacheActivationsRunner(cfg1, override_dataset=dataset) + ds1 = runner1.run() + ds1.set_format("torch") + + runner2 = CacheActivationsRunner(cfg2, override_dataset=dataset) + ds2 = runner2.run() + ds2.set_format("torch") + + # Results should be identical + hook_name = cfg1.hook_name + acts1 = np.array(ds1[hook_name]) + acts2 = np.array(ds2[hook_name]) + tokens1 = np.array(ds1["token_ids"]) + tokens2 = np.array(ds2["token_ids"]) + + assert np.array_equal( + acts1, acts2 + ), "Same seed should produce identical activations" + assert np.array_equal(tokens1, tokens2), "Same seed should produce identical tokens" + + +def test_cache_activations_runner_shuffle_across_sequences_requires_shuffle( + 
tmp_path: Path, +): + """Test that shuffle_across_sequences=True requires shuffle=True.""" + with pytest.raises( + ValueError, + match="shuffle_across_sequences=True requires shuffle=True", + ): + _default_cfg( + tmp_path, + shuffle=False, + shuffle_across_sequences=True, + )