Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 72 additions & 1 deletion src/proteingympy/data_import_funcs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import requests
from typing import Dict, List, Optional

def get_dms_substitution_zip(cache_dir: str = ".cache/", use_cache: bool = True) -> str:
"""Download the DMS_ProteinGym_substitutions.zip file to the cache directory.
Expand Down Expand Up @@ -61,4 +62,74 @@ def get_af2_structures_zip(cache_dir: str = ".cache/", use_cache: bool = True) -
else:
print(f"Using cached file at {zip_path}.")

return zip_path
return zip_path


def _query_uniprot_api(entry_names: List[str]) -> Dict[str, Optional[str]]:
"""
Map UniProt entry names to accession IDs using UniProt REST API.

Args:
entry_names: List of UniProt entry names (e.g., 'P53_HUMAN')

Returns:
Dictionary mapping entry name to UniProt accession ID
"""
mapping = {}

# Filter out special cases and duplicates
unique_names = list(set(entry_names))
names_to_query = []

for name in unique_names:
if name == "ANCSZ_Hobbs":
mapping[name] = None
else:
names_to_query.append(name)

if not names_to_query:
return mapping

# Batch queries to avoid URL length limits
batch_size = 50
base_url = "https://rest.uniprot.org/uniprotkb/search"

print(f"Querying UniProt API for {len(names_to_query)} entries...")

for i in range(0, len(names_to_query), batch_size):
batch = names_to_query[i:i+batch_size]

# Construct query: id:NAME1 OR id:NAME2 ...
query_parts = [f"id:{name}" for name in batch]
query = " OR ".join(query_parts)

params = {
"query": query,
"fields": "accession,id",
"format": "json",
"size": len(batch)
}

try:
response = requests.get(base_url, params=params)
response.raise_for_status()

results = response.json().get("results", [])

for result in results:
# API returns 'primaryAccession' and 'uniProtkbId' (entry name)
accession = result.get("primaryAccession")
entry_name = result.get("uniProtkbId")

if entry_name and accession:
mapping[entry_name] = accession

except Exception as e:
print(f"Error querying UniProt API for batch {i//batch_size + 1}: {e}")

# Ensure all requested names are in the mapping (None if not found)
for name in entry_names:
if name not in mapping:
mapping[name] = None

return mapping
77 changes: 3 additions & 74 deletions src/proteingympy/make_dms_substitutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import Dict, List, Optional
import tempfile
import zipfile
from .data_import_funcs import _query_uniprot_api


def get_dms_substitution_data(cache_dir: str = ".cache", use_cache: bool = True) -> Dict[str, pd.DataFrame]:
Expand Down Expand Up @@ -107,10 +108,8 @@ def _add_uniprot_ids(progym_tables: Dict[str, pd.DataFrame]) -> Dict[str, pd.Dat
entry_name = f"{parts[0]}_{parts[1]}" if len(parts) >= 2 else parts[0]
entry_names.append(entry_name)

# Create mapping - this is a simplified version
# In a full implementation, you would use the UniProt API or a mapping service
# For now, we'll create a basic mapping based on the R script's manual curation
uniprot_mapping = _get_basic_uniprot_mapping(entry_names)
# Create mapping using UniProt API
uniprot_mapping = _query_uniprot_api(entry_names)

# Add UniProt_id to each DataFrame
updated_tables = {}
Expand All @@ -124,76 +123,6 @@ def _add_uniprot_ids(progym_tables: Dict[str, pd.DataFrame]) -> Dict[str, pd.Dat
return updated_tables


def _get_basic_uniprot_mapping(entry_names: List[str]) -> Dict[str, Optional[str]]:
"""
Map UniProt entry names to accession IDs using UniProt REST API.

Args:
entry_names: List of UniProt entry names (e.g., 'P53_HUMAN')

Returns:
Dictionary mapping entry name to UniProt accession ID
"""
mapping = {}

# Filter out special cases and duplicates
unique_names = list(set(entry_names))
names_to_query = []

for name in unique_names:
if name == "ANCSZ_Hobbs":
mapping[name] = None
else:
names_to_query.append(name)

if not names_to_query:
return mapping

# Batch queries to avoid URL length limits
batch_size = 50
base_url = "https://rest.uniprot.org/uniprotkb/search"

print(f"Querying UniProt API for {len(names_to_query)} entries...")

for i in range(0, len(names_to_query), batch_size):
batch = names_to_query[i:i+batch_size]

# Construct query: id:NAME1 OR id:NAME2 ...
query_parts = [f"id:{name}" for name in batch]
query = " OR ".join(query_parts)

params = {
"query": query,
"fields": "accession,id",
"format": "json",
"size": len(batch)
}

try:
response = requests.get(base_url, params=params)
response.raise_for_status()

results = response.json().get("results", [])

for result in results:
# API returns 'primaryAccession' and 'uniProtkbId' (entry name)
accession = result.get("primaryAccession")
entry_name = result.get("uniProtkbId")

if entry_name and accession:
mapping[entry_name] = accession

except Exception as e:
print(f"Error querying UniProt API for batch {i//batch_size + 1}: {e}")

# Ensure all requested names are in the mapping (None if not found)
for name in entry_names:
if name not in mapping:
mapping[name] = None

return mapping


def get_dms_metadata(cache_dir: str = ".cache") -> pd.DataFrame:
"""
Download and process DMS substitutions metadata/reference file.
Expand Down
71 changes: 3 additions & 68 deletions src/proteingympy/make_supervised_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import zipfile
from typing import Dict, List, Optional, Tuple
import re
from .data_import_funcs import _query_uniprot_api


def get_supervised_substitution_data(
Expand Down Expand Up @@ -136,8 +137,8 @@ def _add_uniprot_ids_supervised(supervised_tables: Dict[str, pd.DataFrame]) -> D
entry_name = f"{parts[0]}_{parts[1]}" if len(parts) >= 2 else parts[0]
entry_names.append(entry_name)

# Create basic mapping (would use UniProt API in practice)
uniprot_mapping = _get_basic_uniprot_mapping(entry_names)
# Create basic mapping using UniProt API
uniprot_mapping = _query_uniprot_api(entry_names)

# Add UniProt_id to each DataFrame
updated_tables = {}
Expand All @@ -151,72 +152,6 @@ def _add_uniprot_ids_supervised(supervised_tables: Dict[str, pd.DataFrame]) -> D
return updated_tables


def _get_basic_uniprot_mapping(entry_names: List[str]) -> Dict[str, Optional[str]]:
"""
Get UniProt accession IDs for a list of entry names using UniProt API.

Args:
entry_names: List of UniProt entry names (e.g., "P53_HUMAN")

Returns:
Dictionary mapping entry names to UniProt accession IDs
"""
mapping = {}
names_to_query = []

# Filter out names we know won't be found or handle special cases
for name in set(entry_names):
if name == "ANCSZ_Hobbs":
mapping[name] = None
else:
names_to_query.append(name)

if not names_to_query:
return mapping

# Batch queries to avoid URL length limits
batch_size = 50
base_url = "https://rest.uniprot.org/uniprotkb/search"

print(f"Querying UniProt API for {len(names_to_query)} entries...")

for i in range(0, len(names_to_query), batch_size):
batch = names_to_query[i:i+batch_size]

# Construct query: id:NAME1 OR id:NAME2 ...
query_parts = [f"id:{name}" for name in batch]
query = " OR ".join(query_parts)

params = {
"query": query,
"fields": "accession,id",
"format": "json",
"size": len(batch)
}

try:
response = requests.get(base_url, params=params)
response.raise_for_status()

results = response.json().get("results", [])

for result in results:
# API returns 'primaryAccession' and 'uniProtkbId' (entry name)
accession = result.get("primaryAccession")
entry_name = result.get("uniProtkbId")

if entry_name and accession:
mapping[entry_name] = accession

except Exception as e:
print(f"Error querying UniProt API for batch {i//batch_size + 1}: {e}")

# Ensure all requested names are in the mapping (None if not found)
for name in entry_names:
if name not in mapping:
mapping[name] = None

return mapping


def _clean_supervised_column_names(supervised_tables: Dict[str, pd.DataFrame]) -> Dict[str, pd.DataFrame]:
Expand Down
74 changes: 4 additions & 70 deletions src/proteingympy/make_zero_shot_substitutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import zipfile
from typing import Dict, List, Optional, Any
import re
from .data_import_funcs import _query_uniprot_api


def get_zero_shot_substitution_data(cache_dir: str = ".cache") -> Dict[str, pd.DataFrame]:
Expand Down Expand Up @@ -121,8 +122,8 @@ def _add_uniprot_ids_zeroshot(zeroshot_tables: Dict[str, pd.DataFrame]) -> Dict[
entry_name = f"{parts[0]}_{parts[1]}" if len(parts) >= 2 else parts[0]
entry_names.append(entry_name)

# Create basic mapping (would use UniProt API in practice)
uniprot_mapping = _get_basic_uniprot_mapping(entry_names)
# Create basic mapping using UniProt API
uniprot_mapping = _query_uniprot_api(entry_names)

# Add UniProt_id to each DataFrame
updated_tables = {}
Expand All @@ -136,74 +137,7 @@ def _add_uniprot_ids_zeroshot(zeroshot_tables: Dict[str, pd.DataFrame]) -> Dict[
return updated_tables


def _get_basic_uniprot_mapping(entry_names: List[str]) -> Dict[str, Optional[str]]:
"""
Map UniProt entry names to accession IDs using UniProt REST API.

Args:
entry_names: List of UniProt entry names (e.g., 'P53_HUMAN')

Returns:
Dictionary mapping entry name to UniProt accession ID
"""
mapping = {}

# Filter out special cases and duplicates
unique_names = list(set(entry_names))
names_to_query = []

for name in unique_names:
if name == "ANCSZ_Hobbs":
mapping[name] = None
else:
names_to_query.append(name)

if not names_to_query:
return mapping

# Batch queries to avoid URL length limits
batch_size = 50
base_url = "https://rest.uniprot.org/uniprotkb/search"

print(f"Querying UniProt API for {len(names_to_query)} entries...")

for i in range(0, len(names_to_query), batch_size):
batch = names_to_query[i:i+batch_size]

# Construct query: id:NAME1 OR id:NAME2 ...
query_parts = [f"id:{name}" for name in batch]
query = " OR ".join(query_parts)

params = {
"query": query,
"fields": "accession,id",
"format": "json",
"size": len(batch)
}

try:
response = requests.get(base_url, params=params)
response.raise_for_status()

results = response.json().get("results", [])

for result in results:
# API returns 'primaryAccession' and 'uniProtkbId' (entry name)
accession = result.get("primaryAccession")
entry_name = result.get("uniProtkbId")

if entry_name and accession:
mapping[entry_name] = accession

except Exception as e:
print(f"Error querying UniProt API for batch {i//batch_size + 1}: {e}")

# Ensure all requested names are in the mapping (None if not found)
for name in entry_names:
if name not in mapping:
mapping[name] = None

return mapping



def _clean_zeroshot_column_names(zeroshot_tables: Dict[str, pd.DataFrame]) -> Dict[str, pd.DataFrame]:
Expand Down