diff --git a/sae_lens/cache_activations_runner.py b/sae_lens/cache_activations_runner.py index 6c41455b4..9c464b4db 100644 --- a/sae_lens/cache_activations_runner.py +++ b/sae_lens/cache_activations_runner.py @@ -11,11 +11,10 @@ from datasets.fingerprint import generate_fingerprint from huggingface_hub import HfApi from jaxtyping import Float -from tqdm import tqdm - from sae_lens.config import DTYPE_MAP, CacheActivationsRunnerConfig from sae_lens.load_model import load_model from sae_lens.training.activations_store import ActivationsStore +from tqdm import tqdm class CacheActivationsRunner: @@ -273,10 +272,20 @@ def run(self) -> Dataset: meta_io.seek(0) api = HfApi() + + # Tacks on username to the repo id + user_repo_id = api.create_repo( + self.cfg.hf_repo_id, + private=self.cfg.hf_is_private_repo, + repo_type="dataset", + exist_ok=True, # should exist already + ).repo_id + api.upload_file( path_or_fileobj=meta_io, path_in_repo="cache_activations_runner_cfg.json", - repo_id=self.cfg.hf_repo_id, + repo_id=user_repo_id, + revision=self.cfg.hf_revision, repo_type="dataset", commit_message="Add cache_activations_runner metadata", )