diff --git a/src/proteingympy/data_import_funcs.py b/src/proteingympy/data_import_funcs.py index 04a6e6c..6af815e 100644 --- a/src/proteingympy/data_import_funcs.py +++ b/src/proteingympy/data_import_funcs.py @@ -1,5 +1,6 @@ import os import requests +from typing import Dict, List, Optional def get_dms_substitution_zip(cache_dir: str = ".cache/", use_cache: bool = True) -> str: """Download the DMS_ProteinGym_substitutions.zip file to the cache directory. @@ -61,4 +62,74 @@ def get_af2_structures_zip(cache_dir: str = ".cache/", use_cache: bool = True) - else: print(f"Using cached file at {zip_path}.") - return zip_path \ No newline at end of file + return zip_path + + +def _query_uniprot_api(entry_names: List[str]) -> Dict[str, Optional[str]]: + """ + Map UniProt entry names to accession IDs using UniProt REST API. + + Args: + entry_names: List of UniProt entry names (e.g., 'P53_HUMAN') + + Returns: + Dictionary mapping entry name to UniProt accession ID + """ + mapping = {} + + # Filter out special cases and duplicates + unique_names = list(set(entry_names)) + names_to_query = [] + + for name in unique_names: + if name == "ANCSZ_Hobbs": + mapping[name] = None + else: + names_to_query.append(name) + + if not names_to_query: + return mapping + + # Batch queries to avoid URL length limits + batch_size = 50 + base_url = "https://rest.uniprot.org/uniprotkb/search" + + print(f"Querying UniProt API for {len(names_to_query)} entries...") + + for i in range(0, len(names_to_query), batch_size): + batch = names_to_query[i:i+batch_size] + + # Construct query: id:NAME1 OR id:NAME2 ... + query_parts = [f"id:{name}" for name in batch] + query = " OR ".join(query_parts) + + params = { + "query": query, + "fields": "accession,id", + "format": "json", + "size": len(batch) + } + + try: + response = requests.get(base_url, params=params) + response.raise_for_status() + + results = response.json().get("results", []) + + for result in results: + # API returns 'primaryAccession' and 'uniProtkbId' (entry name) + accession = result.get("primaryAccession") + entry_name = result.get("uniProtkbId") + + if entry_name and accession: + mapping[entry_name] = accession + + except Exception as e: + print(f"Error querying UniProt API for batch {i//batch_size + 1}: {e}") + + # Ensure all requested names are in the mapping (None if not found) + for name in entry_names: + if name not in mapping: + mapping[name] = None + + return mapping \ No newline at end of file diff --git a/src/proteingympy/make_dms_substitutions.py b/src/proteingympy/make_dms_substitutions.py index def0aa3..42f7782 100644 --- a/src/proteingympy/make_dms_substitutions.py +++ b/src/proteingympy/make_dms_substitutions.py @@ -11,6 +11,7 @@ from typing import Dict, List, Optional import tempfile import zipfile +from .data_import_funcs import _query_uniprot_api def get_dms_substitution_data(cache_dir: str = ".cache", use_cache: bool = True) -> Dict[str, pd.DataFrame]: @@ -107,10 +108,8 @@ def _add_uniprot_ids(progym_tables: Dict[str, pd.DataFrame]) -> Dict[str, pd.Dat entry_name = f"{parts[0]}_{parts[1]}" if len(parts) >= 2 else parts[0] entry_names.append(entry_name) - # Create mapping - this is a simplified version - # In a full implementation, you would use the UniProt API or a mapping service - # For now, we'll create a basic mapping based on the R script's manual curation - uniprot_mapping = _get_basic_uniprot_mapping(entry_names) + # Create mapping using UniProt API + uniprot_mapping = _query_uniprot_api(entry_names) # Add UniProt_id to each DataFrame updated_tables = {} @@ -124,76 +123,6 @@ def _add_uniprot_ids(progym_tables: Dict[str, pd.DataFrame]) -> Dict[str, pd.Dat return updated_tables -def _get_basic_uniprot_mapping(entry_names: List[str]) -> Dict[str, Optional[str]]: - """ - Map UniProt entry names to accession IDs using UniProt REST API. - - Args: - entry_names: List of UniProt entry names (e.g., 'P53_HUMAN') - - Returns: - Dictionary mapping entry name to UniProt accession ID - """ - mapping = {} - - # Filter out special cases and duplicates - unique_names = list(set(entry_names)) - names_to_query = [] - - for name in unique_names: - if name == "ANCSZ_Hobbs": - mapping[name] = None - else: - names_to_query.append(name) - - if not names_to_query: - return mapping - - # Batch queries to avoid URL length limits - batch_size = 50 - base_url = "https://rest.uniprot.org/uniprotkb/search" - - print(f"Querying UniProt API for {len(names_to_query)} entries...") - - for i in range(0, len(names_to_query), batch_size): - batch = names_to_query[i:i+batch_size] - - # Construct query: id:NAME1 OR id:NAME2 ... - query_parts = [f"id:{name}" for name in batch] - query = " OR ".join(query_parts) - - params = { - "query": query, - "fields": "accession,id", - "format": "json", - "size": len(batch) - } - - try: - response = requests.get(base_url, params=params) - response.raise_for_status() - - results = response.json().get("results", []) - - for result in results: - # API returns 'primaryAccession' and 'uniProtkbId' (entry name) - accession = result.get("primaryAccession") - entry_name = result.get("uniProtkbId") - - if entry_name and accession: - mapping[entry_name] = accession - - except Exception as e: - print(f"Error querying UniProt API for batch {i//batch_size + 1}: {e}") - - # Ensure all requested names are in the mapping (None if not found) - for name in entry_names: - if name not in mapping: - mapping[name] = None - - return mapping - - def get_dms_metadata(cache_dir: str = ".cache") -> pd.DataFrame: """ Download and process DMS substitutions metadata/reference file. diff --git a/src/proteingympy/make_supervised_scores.py b/src/proteingympy/make_supervised_scores.py index 325ab5b..c0600b7 100644 --- a/src/proteingympy/make_supervised_scores.py +++ b/src/proteingympy/make_supervised_scores.py @@ -12,6 +12,7 @@ import zipfile from typing import Dict, List, Optional, Tuple import re +from .data_import_funcs import _query_uniprot_api def get_supervised_substitution_data( @@ -136,8 +137,8 @@ def _add_uniprot_ids_supervised(supervised_tables: Dict[str, pd.DataFrame]) -> D entry_name = f"{parts[0]}_{parts[1]}" if len(parts) >= 2 else parts[0] entry_names.append(entry_name) - # Create basic mapping (would use UniProt API in practice) - uniprot_mapping = _get_basic_uniprot_mapping(entry_names) + # Create basic mapping using UniProt API + uniprot_mapping = _query_uniprot_api(entry_names) # Add UniProt_id to each DataFrame updated_tables = {} @@ -151,72 +152,6 @@ def _add_uniprot_ids_supervised(supervised_tables: Dict[str, pd.DataFrame]) -> D return updated_tables -def _get_basic_uniprot_mapping(entry_names: List[str]) -> Dict[str, Optional[str]]: - """ - Get UniProt accession IDs for a list of entry names using UniProt API. - - Args: - entry_names: List of UniProt entry names (e.g., "P53_HUMAN") - - Returns: - Dictionary mapping entry names to UniProt accession IDs - """ - mapping = {} - names_to_query = [] - - # Filter out names we know won't be found or handle special cases - for name in set(entry_names): - if name == "ANCSZ_Hobbs": - mapping[name] = None - else: - names_to_query.append(name) - - if not names_to_query: - return mapping - - # Batch queries to avoid URL length limits - batch_size = 50 - base_url = "https://rest.uniprot.org/uniprotkb/search" - - print(f"Querying UniProt API for {len(names_to_query)} entries...") - - for i in range(0, len(names_to_query), batch_size): - batch = names_to_query[i:i+batch_size] - - # Construct query: id:NAME1 OR id:NAME2 ... - query_parts = [f"id:{name}" for name in batch] - query = " OR ".join(query_parts) - - params = { - "query": query, - "fields": "accession,id", - "format": "json", - "size": len(batch) - } - - try: - response = requests.get(base_url, params=params) - response.raise_for_status() - - results = response.json().get("results", []) - - for result in results: - # API returns 'primaryAccession' and 'uniProtkbId' (entry name) - accession = result.get("primaryAccession") - entry_name = result.get("uniProtkbId") - - if entry_name and accession: - mapping[entry_name] = accession - - except Exception as e: - print(f"Error querying UniProt API for batch {i//batch_size + 1}: {e}") - - # Ensure all requested names are in the mapping (None if not found) - for name in entry_names: - if name not in mapping: - mapping[name] = None - - return mapping def _clean_supervised_column_names(supervised_tables: Dict[str, pd.DataFrame]) -> Dict[str, pd.DataFrame]: diff --git a/src/proteingympy/make_zero_shot_substitutions.py b/src/proteingympy/make_zero_shot_substitutions.py index c3000b8..184d190 100644 --- a/src/proteingympy/make_zero_shot_substitutions.py +++ b/src/proteingympy/make_zero_shot_substitutions.py @@ -11,6 +11,7 @@ import zipfile from typing import Dict, List, Optional, Any import re +from .data_import_funcs import _query_uniprot_api def get_zero_shot_substitution_data(cache_dir: str = ".cache") -> Dict[str, pd.DataFrame]: @@ -121,8 +122,8 @@ def _add_uniprot_ids_zeroshot(zeroshot_tables: Dict[str, pd.DataFrame]) -> Dict[ entry_name = f"{parts[0]}_{parts[1]}" if len(parts) >= 2 else parts[0] entry_names.append(entry_name) - # Create basic mapping (would use UniProt API in practice) - uniprot_mapping = _get_basic_uniprot_mapping(entry_names) + # Create basic mapping using UniProt API + uniprot_mapping = _query_uniprot_api(entry_names) # Add UniProt_id to each DataFrame updated_tables = {} @@ -136,74 +137,7 @@ def _add_uniprot_ids_zeroshot(zeroshot_tables: Dict[str, pd.DataFrame]) -> Dict[ return updated_tables -def _get_basic_uniprot_mapping(entry_names: List[str]) -> Dict[str, Optional[str]]: - """ - Map UniProt entry names to accession IDs using UniProt REST API. - - Args: - entry_names: List of UniProt entry names (e.g., 'P53_HUMAN') - - Returns: - Dictionary mapping entry name to UniProt accession ID - """ - mapping = {} - - # Filter out special cases and duplicates - unique_names = list(set(entry_names)) - names_to_query = [] - - for name in unique_names: - if name == "ANCSZ_Hobbs": - mapping[name] = None - else: - names_to_query.append(name) - - if not names_to_query: - return mapping - - # Batch queries to avoid URL length limits - batch_size = 50 - base_url = "https://rest.uniprot.org/uniprotkb/search" - - print(f"Querying UniProt API for {len(names_to_query)} entries...") - - for i in range(0, len(names_to_query), batch_size): - batch = names_to_query[i:i+batch_size] - - # Construct query: id:NAME1 OR id:NAME2 ... - query_parts = [f"id:{name}" for name in batch] - query = " OR ".join(query_parts) - - params = { - "query": query, - "fields": "accession,id", - "format": "json", - "size": len(batch) - } - - try: - response = requests.get(base_url, params=params) - response.raise_for_status() - - results = response.json().get("results", []) - - for result in results: - # API returns 'primaryAccession' and 'uniProtkbId' (entry name) - accession = result.get("primaryAccession") - entry_name = result.get("uniProtkbId") - - if entry_name and accession: - mapping[entry_name] = accession - - except Exception as e: - print(f"Error querying UniProt API for batch {i//batch_size + 1}: {e}") - - # Ensure all requested names are in the mapping (None if not found) - for name in entry_names: - if name not in mapping: - mapping[name] = None - - return mapping + def _clean_zeroshot_column_names(zeroshot_tables: Dict[str, pd.DataFrame]) -> Dict[str, pd.DataFrame]: