From a9c250c0f7e1933c494dc8ba0667ac16b8122573 Mon Sep 17 00:00:00 2001 From: lee-t Date: Thu, 20 Nov 2025 22:56:36 -0500 Subject: [PATCH] Refactor: Centralize cache directory management with new utilities --- src/proteingympy/data_import_funcs.py | 270 +++++++++++++++--- .../make_alphamissense_supplementary.py | 52 ++-- src/proteingympy/make_dms_substitutions.py | 72 ++--- src/proteingympy/make_supervised_scores.py | 103 +++---- .../make_zero_shot_substitutions.py | 43 ++- .../make_zeroshot_dms_benchmarks.py | 53 ++-- 6 files changed, 363 insertions(+), 230 deletions(-) diff --git a/src/proteingympy/data_import_funcs.py b/src/proteingympy/data_import_funcs.py index 04a6e6c..2cc97e1 100644 --- a/src/proteingympy/data_import_funcs.py +++ b/src/proteingympy/data_import_funcs.py @@ -1,64 +1,256 @@ +""" +data_import_funcs.py - Base download utilities for ProteinGym data files. + +This module provides centralized cache configuration and download utilities +with caching support, eliminating code duplication across data pipeline modules. +""" + import os +from functools import wraps +from pathlib import Path +from typing import Optional, Union, Callable import requests -def get_dms_substitution_zip(cache_dir: str = ".cache/", use_cache: bool = True) -> str: - """Download the DMS_ProteinGym_substitutions.zip file to the cache directory. - + +# ============================================================================ +# Cache Configuration +# ============================================================================ + +# Default cache directory - can be overridden via environment variable +DEFAULT_CACHE_DIR = Path(os.getenv("PROTEINGYM_CACHE_DIR", ".cache")) + + +def get_cache_dir(cache_dir: Optional[Union[str, Path]] = None) -> Path: + """ + Get the cache directory as a Path object. + + This function provides a centralized way to determine the cache directory, + with the following precedence: + 1. Explicit cache_dir parameter (if provided) + 2. PROTEINGYM_CACHE_DIR environment variable (if set) + 3. Default ".cache" directory + Args: - cache_dir: Directory to store the downloaded file. + cache_dir: Optional override for cache directory. Can be a string or Path object. + + Returns: + Path object for the cache directory + + Examples: + >>> get_cache_dir() + PosixPath('.cache') + + >>> get_cache_dir("/tmp/my_cache") + PosixPath('/tmp/my_cache') + + >>> os.environ["PROTEINGYM_CACHE_DIR"] = "/data/cache" + >>> get_cache_dir() + PosixPath('/data/cache') + """ + if cache_dir is not None: + return Path(cache_dir) + return DEFAULT_CACHE_DIR + + +def set_default_cache_dir(cache_dir: Union[str, Path]) -> None: + """ + Set the default cache directory globally. + + This updates the DEFAULT_CACHE_DIR module variable and also sets + the PROTEINGYM_CACHE_DIR environment variable. + + Args: + cache_dir: New default cache directory + + Example: + >>> set_default_cache_dir("/data/proteingym_cache") + """ + global DEFAULT_CACHE_DIR + DEFAULT_CACHE_DIR = Path(cache_dir) + os.environ["PROTEINGYM_CACHE_DIR"] = str(DEFAULT_CACHE_DIR) + + +# ============================================================================ +# Download Utilities +# ============================================================================ + +def cached_download( + url: str, + filename: str, + cache_dir: Optional[Union[str, Path]] = None, + use_cache: bool = True, + chunk_size: int = 8192 +) -> Path: + """ + Download a file with caching support. + + If the file already exists in the cache and use_cache is True, the cached + version is used. Otherwise, the file is downloaded from the URL. + + Args: + url: URL to download from + filename: Name for the cached file + cache_dir: Cache directory (uses default if None) use_cache: If True, use cached file if it exists. If False, force a fresh download. - + chunk_size: Download chunk size in bytes (default: 8192) + Returns: - Path to the downloaded zip file. + Path to the cached file + + Raises: + requests.HTTPError: If the download fails with an HTTP error + requests.RequestException: If the download fails for other reasons + + Example: + >>> zip_path = cached_download( + ... url="https://zenodo.org/records/15293562/files/data.zip", + ... filename="data.zip", + ... cache_dir=".cache" + ... ) + Using cached file at .cache/data.zip """ - url = "https://zenodo.org/records/15293562/files/DMS_ProteinGym_substitutions.zip" - os.makedirs(cache_dir, exist_ok=True) - zip_path = os.path.join(cache_dir, "DMS_ProteinGym_substitutions.zip") - - if not use_cache or not os.path.exists(zip_path): - if os.path.exists(zip_path): - os.remove(zip_path) - print(f"Downloading {url} to {zip_path}...") + cache_path = get_cache_dir(cache_dir) + cache_path.mkdir(parents=True, exist_ok=True) + + file_path = cache_path / filename + + if not use_cache or not file_path.exists(): + if file_path.exists(): + file_path.unlink() # Remove existing file + + print(f"Downloading {url} to {file_path}...") response = requests.get(url, stream=True) response.raise_for_status() - with open(zip_path, "wb") as f: - for chunk in response.iter_content(chunk_size=8192): + + with open(file_path, "wb") as f: + for chunk in response.iter_content(chunk_size=chunk_size): if chunk: f.write(chunk) print("Download complete.") else: - print(f"Using cached file at {zip_path}.") - return zip_path + print(f"Using cached file at {file_path}") + return file_path -def get_af2_structures_zip(cache_dir: str = ".cache/", use_cache: bool = True) -> str: - """Download the ProteinGym_AF2_structures.zip file to the cache directory. + +def download_with_cache(filename: str, url_param: str = "url"): + """ + Decorator to add caching support to download functions. + + The decorated function should return a URL string (or take a URL as a parameter). + This decorator wraps the function to automatically handle downloading and caching. + + Args: + filename: Name for the cached file + url_param: Name of the URL parameter in the decorated function (default: "url") + + Returns: + Decorator function + + Example: + >>> @download_with_cache("data.zip") + ... def get_data_url(): + ... return "https://example.com/data.zip" + ... + >>> path = get_data_url(cache_dir=".cache", use_cache=True) + Downloading https://example.com/data.zip to .cache/data.zip... + """ + def decorator(func: Callable) -> Callable: + @wraps(func) + def wrapper(cache_dir: Optional[Union[str, Path]] = None, use_cache: bool = True, **kwargs): + # Call the original function to get the URL + # Check if the function expects a 'url' parameter + import inspect + sig = inspect.signature(func) + + if url_param in sig.parameters: + # Function takes URL as parameter, pass through kwargs + url = func(**kwargs) + else: + # Function returns URL + url = func(**kwargs) + + cache_path = get_cache_dir(cache_dir) + return cached_download(url, filename, cache_path, use_cache) + return wrapper + return decorator + + +def download_multiple( + downloads: list[tuple[str, str]], + cache_dir: Optional[Union[str, Path]] = None, + use_cache: bool = True, + chunk_size: int = 8192 +) -> dict[str, Path]: + """ + Download multiple files with caching support. + + Args: + downloads: List of (url, filename) tuples to download + cache_dir: Cache directory (uses default if None) + use_cache: If True, use cached files if they exist + chunk_size: Download chunk size in bytes + + Returns: + Dictionary mapping filenames to their cached paths + + Example: + >>> files = download_multiple([ + ... ("https://example.com/data1.zip", "data1.zip"), + ... ("https://example.com/data2.zip", "data2.zip"), + ... ]) + >>> files["data1.zip"] + PosixPath('.cache/data1.zip') + """ + results = {} + for url, filename in downloads: + results[filename] = cached_download( + url=url, + filename=filename, + cache_dir=cache_dir, + use_cache=use_cache, + chunk_size=chunk_size + ) + return results + + +# ============================================================================ +# Specific ProteinGym Data Downloads +# ============================================================================ + +def get_dms_substitution_zip(cache_dir: Optional[Union[str, Path]] = None, use_cache: bool = True) -> Path: + """ + Download the DMS_ProteinGym_substitutions.zip file to the cache directory. Args: - cache_dir: Directory to store the downloaded file. + cache_dir: Directory to store the downloaded file (uses default if None) use_cache: If True, use cached file if it exists. If False, force a fresh download. Returns: Path to the downloaded zip file. """ - url = ( - "https://zenodo.org/records/15293562/files/ProteinGym_AF2_structures.zip?download=1" + return cached_download( + url="https://zenodo.org/records/15293562/files/DMS_ProteinGym_substitutions.zip", + filename="DMS_ProteinGym_substitutions.zip", + cache_dir=cache_dir, + use_cache=use_cache ) - os.makedirs(cache_dir, exist_ok=True) - zip_path = os.path.join(cache_dir, "ProteinGym_AF2_structures.zip") - if not use_cache or not os.path.exists(zip_path): - if os.path.exists(zip_path): - os.remove(zip_path) - print(f"Downloading {url} to {zip_path}...") - response = requests.get(url, stream=True) - response.raise_for_status() - with open(zip_path, "wb") as f: - for chunk in response.iter_content(chunk_size=8192): - if chunk: - f.write(chunk) - print("Download complete.") - else: - print(f"Using cached file at {zip_path}.") - return zip_path \ No newline at end of file +def get_af2_structures_zip(cache_dir: Optional[Union[str, Path]] = None, use_cache: bool = True) -> Path: + """ + Download the ProteinGym_AF2_structures.zip file to the cache directory. + + Args: + cache_dir: Directory to store the downloaded file (uses default if None) + use_cache: If True, use cached file if it exists. If False, force a fresh download. + + Returns: + Path to the downloaded zip file. + """ + return cached_download( + url="https://zenodo.org/records/15293562/files/ProteinGym_AF2_structures.zip?download=1", + filename="ProteinGym_AF2_structures.zip", + cache_dir=cache_dir, + use_cache=use_cache + ) diff --git a/src/proteingympy/make_alphamissense_supplementary.py b/src/proteingympy/make_alphamissense_supplementary.py index 6ded604..a9ed5ff 100644 --- a/src/proteingympy/make_alphamissense_supplementary.py +++ b/src/proteingympy/make_alphamissense_supplementary.py @@ -6,14 +6,16 @@ import io import json -import os import re import zipfile from typing import Dict, Optional +from pathlib import Path import pandas as pd import requests +from .data_import_funcs import cached_download, get_cache_dir + def _is_proteingym_csv(member_name: str) -> bool: """Return True if the archive member looks like the ProteinGym CSV.""" @@ -97,14 +99,14 @@ def _query_uniprot_accessions(entry_names, session: Optional[requests.Session] = return results -def _add_uniprot_accessions(df: pd.DataFrame, cache_dir: str) -> pd.DataFrame: +def _add_uniprot_accessions(df: pd.DataFrame, cache_dir: Optional[str]) -> pd.DataFrame: """Augment AlphaMissense data with UniProt accessions via the UniProt REST API.""" entry_col = 'Uniprot_ID' if entry_col not in df.columns: return df - cache_path = os.path.join(cache_dir, "alphamissense_uniprot_mapping.json") - cached_mapping = _load_cached_uniprot_mapping(cache_path) + cache_path = get_cache_dir(cache_dir) / "alphamissense_uniprot_mapping.json" + cached_mapping = _load_cached_uniprot_mapping(str(cache_path)) entry_names = sorted(df[entry_col].dropna().unique()) missing = [entry for entry in entry_names if cached_mapping.get(entry) in (None, '')] @@ -122,7 +124,7 @@ def _add_uniprot_accessions(df: pd.DataFrame, cache_dir: str) -> pd.DataFrame: # Merge fetched results into cache cached_mapping.update({entry: fetched.get(entry, cached_mapping.get(entry)) for entry in missing}) - _save_uniprot_mapping(cache_path, cached_mapping) + _save_uniprot_mapping(str(cache_path), cached_mapping) df = df.copy() df['SwissProt_ID'] = df[entry_col] @@ -142,7 +144,7 @@ def _add_uniprot_accessions(df: pd.DataFrame, cache_dir: str) -> pd.DataFrame: } if derived_mapping: cached_mapping.update(derived_mapping) - _save_uniprot_mapping(cache_path, cached_mapping) + _save_uniprot_mapping(str(cache_path), cached_mapping) missing_mask = df[entry_col].isna() @@ -170,45 +172,47 @@ def _derive_accession_from_entry(entry_name: str) -> Optional[str]: return None -def get_alphamissense_proteingym_data(cache_dir: str = ".cache") -> pd.DataFrame: +def get_alphamissense_proteingym_data(cache_dir: str = None) -> pd.DataFrame: """ Download and process AlphaMissense supplementary data for ProteinGym variants. - - This loads Table S8 from Cheng et al. 2023 containing AlphaMissense pathogenicity - scores for ~1.6M variants that match those in ProteinGym from 87 DMS experiments + + This loads Table S8 from Cheng et al. 2023 containing AlphaMissense pathogenicity + scores for ~1.6M variants that match those in ProteinGym from 87 DMS experiments across 72 proteins. - + Args: - cache_dir: Directory to cache downloaded files - + cache_dir: Directory to cache downloaded files (uses default if None) + Returns: DataFrame with columns: - - DMS_id: DMS assay identifier + - DMS_id: DMS assay identifier - Uniprot_ID: UniProt accession (resolved via UniProt API) - SwissProt_ID: Original AlphaMissense SwissProt entry name - variant_id: Variant identifier - AlphaMissense: Pathogenicity score (0-1, higher = more pathogenic) """ - os.makedirs(cache_dir, exist_ok=True) - + # Get cache directory path + cache_path = get_cache_dir(cache_dir) + cache_path.mkdir(parents=True, exist_ok=True) + # File paths - csv_path = os.path.join(cache_dir, "Supplementary_Data_S8_proteingym.csv") + csv_path = cache_path / "Supplementary_Data_S8_proteingym.csv" #url = "https://www.science.org/doi/suppl/10.1126/science.adg7492/suppl_file/science.adg7492_data_s1_to_s9.zip" # Science is blocking requests with TLS fingerprinting, so we rely on a local copy # Preferred zip path is the copy bundled with the package at src/ - repo_zip_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "science.adg7492_data_s1_to_s9.zip")) - cache_zip_path = os.path.join(cache_dir, "science.adg7492_data_s1_to_s9.zip") + repo_zip_path = Path(__file__).parent.parent / "science.adg7492_data_s1_to_s9.zip" + cache_zip_path = cache_path / "science.adg7492_data_s1_to_s9.zip" # Prefer the repository zip file if present, otherwise fallback to cache - if os.path.exists(repo_zip_path): + if repo_zip_path.exists(): zip_path = repo_zip_path else: zip_path = cache_zip_path # Extract CSV if not present - if not os.path.exists(csv_path): - if not os.path.exists(zip_path): + if not csv_path.exists(): + if not zip_path.exists(): print(f"Zip file not found locally. Downloading from GitHub...") url = "https://github.com/ccb-hms/ProteinGymPy/blob/main/src/science.adg7492_data_s1_to_s9.zip?raw=true" try: @@ -222,9 +226,9 @@ def get_alphamissense_proteingym_data(cache_dir: str = ".cache") -> pd.DataFrame except Exception as e: print(f"Warning: Failed to download AlphaMissense data: {e}") - if os.path.exists(zip_path): + if zip_path.exists(): with zipfile.ZipFile(zip_path, 'r') as zip_ref: - extracted_name = _extract_proteingym_csv(zip_ref, csv_path) + extracted_name = _extract_proteingym_csv(zip_ref, str(csv_path)) if extracted_name: print(f"Extracted {extracted_name} from {zip_path} -> {csv_path}") else: diff --git a/src/proteingympy/make_dms_substitutions.py b/src/proteingympy/make_dms_substitutions.py index def0aa3..6a03ea0 100644 --- a/src/proteingympy/make_dms_substitutions.py +++ b/src/proteingympy/make_dms_substitutions.py @@ -5,51 +5,41 @@ Loads 217 DMS substitution assays with UniProt ID mapping. """ -import os import pandas as pd import requests from typing import Dict, List, Optional -import tempfile +from pathlib import Path import zipfile +from .data_import_funcs import cached_download, get_cache_dir -def get_dms_substitution_data(cache_dir: str = ".cache", use_cache: bool = True) -> Dict[str, pd.DataFrame]: + +def get_dms_substitution_data(cache_dir: str = None, use_cache: bool = True) -> Dict[str, pd.DataFrame]: """ Download and process ProteinGym DMS substitution data. - + Returns a dictionary of 217 DMS assays, each as a pandas DataFrame with columns: - - UniProt_id: UniProt accession identifier + - UniProt_id: UniProt accession identifier - DMS_id: DMS assay identifier - mutant: substitution description (e.g. A1P:D2N) - mutated_sequence: full amino acid sequence - DMS_score: experimental measurement (higher = more fit) - DMS_score_bin: binary fitness (1=fit, 0=not fit) - + Args: - cache_dir: Directory to cache downloaded files + cache_dir: Directory to cache downloaded files (uses default if None) use_cache: If True, use cached file if it exists. If False, force a fresh download. - + Returns: Dictionary mapping DMS study names to DataFrames """ - os.makedirs(cache_dir, exist_ok=True) - zip_path = os.path.join(cache_dir, "DMS_ProteinGym_substitutions.zip") - - # Download if not cached or if use_cache is False - if not use_cache or not os.path.exists(zip_path): - if os.path.exists(zip_path): - os.remove(zip_path) - url = "https://zenodo.org/records/15293562/files/DMS_ProteinGym_substitutions.zip" - print(f"Downloading {url} to {zip_path}...") - response = requests.get(url, stream=True) - response.raise_for_status() - with open(zip_path, "wb") as f: - for chunk in response.iter_content(chunk_size=8192): - if chunk: - f.write(chunk) - print("Download complete.") - else: - print(f"Using cached file at {zip_path}.") + # Download using centralized caching utility + zip_path = cached_download( + url="https://zenodo.org/records/15293562/files/DMS_ProteinGym_substitutions.zip", + filename="DMS_ProteinGym_substitutions.zip", + cache_dir=cache_dir, + use_cache=use_cache + ) # Extract and load data progym_tables = {} @@ -194,30 +184,24 @@ def _get_basic_uniprot_mapping(entry_names: List[str]) -> Dict[str, Optional[str return mapping -def get_dms_metadata(cache_dir: str = ".cache") -> pd.DataFrame: +def get_dms_metadata(cache_dir: str = None) -> pd.DataFrame: """ Download and process DMS substitutions metadata/reference file. - + Args: - cache_dir: Directory to cache downloaded files - + cache_dir: Directory to cache downloaded files (uses default if None) + Returns: DataFrame with metadata for 217 DMS assays """ - os.makedirs(cache_dir, exist_ok=True) - metadata_path = os.path.join(cache_dir, "DMS_substitutions.csv") - - if not os.path.exists(metadata_path): - url = "https://zenodo.org/records/15293562/files/DMS_substitutions.csv" - print(f"Downloading metadata from {url}...") - response = requests.get(url, stream=True) - response.raise_for_status() - with open(metadata_path, "wb") as f: - for chunk in response.iter_content(chunk_size=8192): - if chunk: - f.write(chunk) - print("Metadata download complete.") - + # Download using centralized caching utility + metadata_path = cached_download( + url="https://zenodo.org/records/15293562/files/DMS_substitutions.csv", + filename="DMS_substitutions.csv", + cache_dir=cache_dir, + use_cache=True # Metadata rarely changes, always use cache + ) + # Load and process metadata df = pd.read_csv(metadata_path) diff --git a/src/proteingympy/make_supervised_scores.py b/src/proteingympy/make_supervised_scores.py index 325ab5b..264db50 100644 --- a/src/proteingympy/make_supervised_scores.py +++ b/src/proteingympy/make_supervised_scores.py @@ -5,26 +5,26 @@ Handles contiguous_5, modulo_5, and random_5 fold types. """ -import os import pandas as pd import requests -import tempfile +from pathlib import Path import zipfile from typing import Dict, List, Optional, Tuple -import re + +from .data_import_funcs import cached_download, get_cache_dir def get_supervised_substitution_data( - fold_type: str = "random_5", - cache_dir: str = ".cache" + fold_type: str = "random_5", + cache_dir: str = None ) -> Tuple[Dict[str, pd.DataFrame], pd.DataFrame]: """ Download and process raw ProteinGym supervised model substitution scores. - + Args: fold_type: Type of cross-validation fold ("contiguous_5", "modulo_5", or "random_5") - cache_dir: Directory to cache downloaded files - + cache_dir: Directory to cache downloaded files (uses default if None) + Returns: Tuple of (supervised_scores_dict, summary_metrics_df) - supervised_scores_dict: Dictionary mapping DMS assay names to DataFrames with model predictions @@ -32,33 +32,25 @@ def get_supervised_substitution_data( """ if fold_type not in ["contiguous_5", "modulo_5", "random_5"]: raise ValueError("fold_type must be one of: 'contiguous_5', 'modulo_5', 'random_5'") - - os.makedirs(cache_dir, exist_ok=True) - - # Download supervised scores data (this would need the actual URL from Zenodo v1.2) - zip_path = os.path.join(cache_dir, "DMS_supervised_substitutions_scores.zip") - - if not os.path.exists(zip_path): - url = "https://zenodo.org/records/14997691/files/DMS_supervised_substitutions_scores.zip?download=1" - print(f"Downloading supervised scores from {url}...") - - response = requests.get(url, stream=True) - response.raise_for_status() - with open(zip_path, "wb") as f: - for chunk in response.iter_content(chunk_size=8192): - if chunk: - f.write(chunk) - print("Download complete.") - + + # Download supervised scores data using centralized caching utility + zip_path = cached_download( + url="https://zenodo.org/records/14997691/files/DMS_supervised_substitutions_scores.zip?download=1", + filename="DMS_supervised_substitutions_scores.zip", + cache_dir=cache_dir, + use_cache=True + ) + # Check if we need to extract summary metrics - summary_path = os.path.join(cache_dir, "merged_scores_substitutions_DMS.csv") - if not os.path.exists(summary_path) and os.path.exists(zip_path): + cache_path = get_cache_dir(cache_dir) + summary_path = cache_path / "merged_scores_substitutions_DMS.csv" + if not summary_path.exists(): try: with zipfile.ZipFile(zip_path, 'r') as zip_ref: # Check if file exists in zip (at root or in subfolder) target_file = "merged_scores_substitutions_DMS.csv" if target_file in zip_ref.namelist(): - zip_ref.extract(target_file, cache_dir) + zip_ref.extract(target_file, cache_path) except zipfile.BadZipFile: print(f"Warning: Could not read {zip_path} to extract summary metrics") @@ -270,45 +262,30 @@ def _clean_supervised_column_names(supervised_tables: Dict[str, pd.DataFrame]) - return cleaned_tables -def _load_from_zenodo_v12_supervised(cache_dir: str) -> pd.DataFrame: +def _load_from_zenodo_v12_supervised(cache_dir: Optional[str]) -> pd.DataFrame: """ Download and load the merged supervised DMS benchmark scores (Zenodo v1.2). - + Steps: - Downloads the zip file containing DMS scores if not already cached. - Opens the ZIP and loads 'merged_scores_substitutions_DMS.csv'. - + Args: - cache_dir (str): Directory where the ZIP file is stored or downloaded. - + cache_dir: Directory where the ZIP file is stored or downloaded (uses default if None) + Returns: pandas.DataFrame: The merged scores table. """ - - url = ( - "https://zenodo.org/records/14997691/files/" - "DMS_supervised_substitutions_scores.zip?download=1" + # Download using centralized caching utility + zip_path = cached_download( + url="https://zenodo.org/records/14997691/files/DMS_supervised_substitutions_scores.zip?download=1", + filename="DMS_supervised_substitutions_scores.zip", + cache_dir=cache_dir, + use_cache=True ) - zip_path = os.path.join(cache_dir, "DMS_supervised_substitutions_scores.zip") target_file = "DMS_supervised_substitutions_scores/merged_scores_substitutions_DMS.csv" - # -------------------------------------------------------------- - # 1. Download the ZIP if missing - # -------------------------------------------------------------- - if not os.path.exists(zip_path): - print(f"Downloading benchmark ZIP from: {url}") - response = requests.get(url, stream=True) - response.raise_for_status() - - os.makedirs(cache_dir, exist_ok=True) - with open(zip_path, "wb") as f: - for chunk in response.iter_content(chunk_size=8192): - if chunk: - f.write(chunk) - - print("Download complete.") - # -------------------------------------------------------------- # 2. Open ZIP and load the target CSV # -------------------------------------------------------------- @@ -327,28 +304,22 @@ def _load_from_zenodo_v12_supervised(cache_dir: str) -> pd.DataFrame: return df -def get_supervised_metrics(cache_dir: str = ".cache") -> pd.DataFrame: +def get_supervised_metrics(cache_dir: str = None) -> pd.DataFrame: """ Load supervised DMS benchmark metrics (Zenodo v1.2). - + This function: - Ensures the cache directory exists - Downloads the benchmark ZIP if missing - Loads 'merged_scores_substitutions_DMS.csv' - Returns it as a pandas DataFrame - + Args: - cache_dir (str, optional): Directory to store or read cached files. - Defaults to ".cache". - + cache_dir: Directory to store or read cached files (uses default if None) + Returns: pandas.DataFrame: The merged supervised benchmark scores. """ - - # Normalize and create cache directory if necessary - cache_dir = os.path.abspath(cache_dir) - os.makedirs(cache_dir, exist_ok=True) - # Load merged scores via helper function benchmark_table = _load_from_zenodo_v12_supervised(cache_dir) diff --git a/src/proteingympy/make_zero_shot_substitutions.py b/src/proteingympy/make_zero_shot_substitutions.py index c3000b8..5f3df71 100644 --- a/src/proteingympy/make_zero_shot_substitutions.py +++ b/src/proteingympy/make_zero_shot_substitutions.py @@ -4,54 +4,43 @@ Downloads and processes ProteinGym zero-shot model scores for DMS substitution assays. """ -import os import pandas as pd import requests -import tempfile +from pathlib import Path import zipfile from typing import Dict, List, Optional, Any -import re +from .data_import_funcs import cached_download, get_cache_dir -def get_zero_shot_substitution_data(cache_dir: str = ".cache") -> Dict[str, pd.DataFrame]: + +def get_zero_shot_substitution_data(cache_dir: str = None) -> Dict[str, pd.DataFrame]: """ Download and process ProteinGym zero-shot model scores for DMS substitutions. - + This loads zero-shot model predictions across 217 DMS assays for multiple models. Each assay contains predictions from various protein language models and other zero-shot approaches. - + Args: - cache_dir: Directory to cache downloaded files - + cache_dir: Directory to cache downloaded files (uses default if None) + Returns: Dictionary mapping DMS assay names to DataFrames with columns: - UniProt_id: UniProt accession identifier - - DMS_id: DMS assay identifier + - DMS_id: DMS assay identifier - mutant: substitution description - mutated_sequence: full amino acid sequence - DMS_score: experimental measurement - DMS_score_bin: binary fitness classification - [model_name]: Prediction scores from various zero-shot models """ - os.makedirs(cache_dir, exist_ok=True) - - # Download zero-shot scores data - zip_path = os.path.join(cache_dir, "zero_shot_substitutions_scores.zip") - - if not os.path.exists(zip_path): - # URL from ProteinGym Zenodo v1.2 - url = "https://zenodo.org/records/14997691/files/zero_shot_substitutions_scores.zip?download=1" - print(f"Downloading zero-shot scores from {url}...") - response = requests.get(url, stream=True) - response.raise_for_status() - with open(zip_path, "wb") as f: - for chunk in response.iter_content(chunk_size=8192): - if chunk: - f.write(chunk) - print("Download complete.") - else: - print(f"Zero-shot scores found in cache at {zip_path}") + # Download zero-shot scores data using centralized caching utility + zip_path = cached_download( + url="https://zenodo.org/records/14997691/files/zero_shot_substitutions_scores.zip?download=1", + filename="zero_shot_substitutions_scores.zip", + cache_dir=cache_dir, + use_cache=True + ) # Load zero-shot scores zeroshot_tables = _load_zero_shot_data(zip_path) diff --git a/src/proteingympy/make_zeroshot_dms_benchmarks.py b/src/proteingympy/make_zeroshot_dms_benchmarks.py index 1b2e9b1..884ef39 100644 --- a/src/proteingympy/make_zeroshot_dms_benchmarks.py +++ b/src/proteingympy/make_zeroshot_dms_benchmarks.py @@ -5,46 +5,45 @@ Handles 5 performance metrics: Spearman, AUC, MCC, NDCG, and Top_recall. """ -import os import pandas as pd import requests -import tempfile +from pathlib import Path import zipfile from typing import Dict, List, Optional, Any import numpy as np +from .data_import_funcs import cached_download, get_cache_dir -def get_zero_shot_metrics(cache_dir: str = ".cache") -> Dict[str, pd.DataFrame]: + +def get_zero_shot_metrics(cache_dir: str = None) -> Dict[str, pd.DataFrame]: """ Download and process ProteinGym zero-shot benchmarking metrics. - + This loads performance metrics for zero-shot models across 217 DMS assays. The benchmarking uses 5 metrics to evaluate model performance in predicting experimental DMS measurements without training on the specific assay labels. - + Metrics included: 1. Spearman's rank correlation coefficient (primary metric) - 2. Area Under the ROC Curve (AUC) + 2. Area Under the ROC Curve (AUC) 3. Matthews Correlation Coefficient (MCC) for bimodal measurements 4. Normalized Discounted Cumulative Gains (NDCG) for identifying top variants 5. Top K Recall (top 10% of DMS values) - + Args: - cache_dir: Directory to cache downloaded files - + cache_dir: Directory to cache downloaded files (uses default if None) + Returns: Dictionary with 5 entries (one per metric), each containing a DataFrame with: - Rows: 217 DMS assays - Columns: Model performance scores (79 models in v1.2) """ - os.makedirs(cache_dir, exist_ok=True) - # Option 1: Load from GitHub (older approach with 62 models) # benchmark_data = _load_from_github() - + # Option 2: Load from Zenodo v1.2 (79 models) benchmark_data = _load_from_zenodo_v12(cache_dir) - + return benchmark_data @@ -83,29 +82,23 @@ def _load_from_github() -> Dict[str, pd.DataFrame]: return score_list -def _load_from_zenodo_v12(cache_dir: str) -> Dict[str, pd.DataFrame]: +def _load_from_zenodo_v12(cache_dir: Optional[str]) -> Dict[str, pd.DataFrame]: """ Load benchmark data from Zenodo v1.2 repository (79 models). - + Args: - cache_dir: Directory to cache downloaded files - + cache_dir: Directory to cache downloaded files (uses default if None) + Returns: Dictionary with 5 DataFrames for each metric """ - zip_path = os.path.join(cache_dir, "DMS_benchmarks_performance.zip") - - if not os.path.exists(zip_path): - # URL from ProteinGym Zenodo v1.2 - url = "https://zenodo.org/records/14997691/files/DMS_benchmark_performance.zip?download=1" - print(f"Downloading benchmarks from {url}...") - response = requests.get(url, stream=True) - response.raise_for_status() - with open(zip_path, "wb") as f: - for chunk in response.iter_content(chunk_size=8192): - if chunk: - f.write(chunk) - print("Download complete.") + # Download using centralized caching utility + zip_path = cached_download( + url="https://zenodo.org/records/14997691/files/DMS_benchmark_performance.zip?download=1", + filename="DMS_benchmarks_performance.zip", + cache_dir=cache_dir, + use_cache=True + ) # Extract and load benchmark files score_list = {}