Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
<img src="_assets/logo.png" width="50"/>

[ESM3](https://www.science.org/doi/10.1126/science.ads0018) &sdot; [ESM C](https://www.evolutionaryscale.ai/blog/esm-cambrian) &sdot;
[Slack](https://bit.ly/esm-slack) &sdot; [Tutorials](https://github.com/evolutionaryscale/esm/tree/main/cookbook/tutorials) <br>
[Slack](https://bit.ly/3FKwcWd) &sdot; [Tutorials](https://github.com/evolutionaryscale/esm/tree/main/cookbook/tutorials) <br>
</div>


Expand Down
61 changes: 59 additions & 2 deletions cookbook/snippets/esmc.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
import math
import os

import torch

from esm.models.esmc import ESMC
from esm.sdk import client
from esm.sdk import batch_executor, client
from esm.sdk.api import (
ESMCInferenceClient,
ESMProtein,
ESMProteinError,
ESMProteinTensor,
LogitsConfig,
LogitsOutput,
)
from esm.sdk.forge import ESM3ForgeInferenceClient
from esm.sdk.forge import ESM3ForgeInferenceClient, ESMCForgeInferenceClient
from esm.tokenization import get_esmc_model_tokenizers


def main(client: ESMCInferenceClient | ESM3ForgeInferenceClient):
Expand Down Expand Up @@ -70,6 +75,58 @@ def raw_forward(model: ESMC):
)


def compute_pseudoperplexity(
    forge_client: ESMCForgeInferenceClient, sequence: str
) -> float:
    """Compute L-pass pseudoperplexity for a protein sequence via Forge.

    Masks each of the L positions one at a time (one Forge request per
    position), retrieves per-position logits, and returns
    exp(-mean(log_prob_true_aa)). Requests run in parallel via
    ``batch_executor``.

    Args:
        forge_client: Forge inference client for an ESM C model.
        sequence: Amino-acid sequence in single-letter codes. Must be
            non-empty.

    Returns:
        The pseudoperplexity (lower is better).

    Raises:
        ValueError: If ``sequence`` is empty.
        ESMProteinError: If any Forge encode or logits request fails.

    Example::

        forge_client = ESMCForgeInferenceClient(
            model="esmc-6b-2024-12",
            url="https://forge.evolutionaryscale.ai",
            token=os.environ["ESM_API_KEY"],
        )
        pppl = compute_pseudoperplexity(forge_client, "MKTLLILAVL...")
    """
    num_positions = len(sequence)
    # Guard: an empty sequence would otherwise cause ZeroDivisionError in the
    # final mean, after issuing zero requests.
    if num_positions == 0:
        raise ValueError("sequence must be non-empty")

    # One masked copy per position; "_" is the mask character understood by
    # the ESM sequence tokenizer.
    masked_sequences = [
        sequence[:i] + "_" + sequence[i + 1 :] for i in range(num_positions)
    ]

    def _get_logits(client: ESMCForgeInferenceClient, sequence: str) -> LogitsOutput:
        # Encode + fetch logits for one masked sequence; Forge reports
        # failures as ESMProteinError return values, so surface them by raising.
        protein = ESMProtein(sequence=sequence)
        protein_tensor = client.encode(protein)
        if isinstance(protein_tensor, ESMProteinError):
            raise protein_tensor
        output = client.logits(protein_tensor, LogitsConfig(sequence=True))
        if isinstance(output, ESMProteinError):
            raise output
        return output

    with batch_executor() as executor:
        logit_outputs = executor.execute_batch(
            _get_logits, client=forge_client, sequence=masked_sequences
        )

    # Build vocab from the tokenizer to map amino acid characters to token indices
    vocab: dict[str, int] = get_esmc_model_tokenizers().get_vocab()

    log_probs = []
    for i in range(num_positions):
        output = logit_outputs[i]
        # batch_executor may hand back the exception a task raised instead of
        # its result; re-raise so a single failed position aborts the metric.
        if isinstance(output, Exception):
            raise output
        logits = output.logits.sequence  # shape: (L+2, V) — BOS/EOS flank the sequence
        position_logits = logits[i + 1]  # +1 to skip the BOS token
        log_softmax = torch.log_softmax(position_logits, dim=-1)
        true_aa_idx = vocab[sequence[i]]
        log_probs.append(log_softmax[true_aa_idx].item())

    return math.exp(-sum(log_probs) / num_positions)


if __name__ == "__main__":
if os.environ.get("ESM_API_KEY", ""):
print("ESM_API_KEY found. Trying to use model from Forge...")
Expand Down
2 changes: 1 addition & 1 deletion esm/models/esm3.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ def logits(
if device.type == "cuda"
else contextlib.nullcontext(),
):
output = self.forward(
output = self(
sequence_tokens=input.sequence,
structure_tokens=input.structure,
ss8_tokens=input.secondary_structure,
Expand Down
11 changes: 8 additions & 3 deletions esm/models/esmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,18 @@ def __init__(

@classmethod
def from_pretrained(
cls, model_name: str = ESMC_600M, device: torch.device | None = None
cls,
model_name: str = ESMC_600M,
device: torch.device | None = None,
use_flash_attn: bool = True,
) -> ESMC:
from esm.pretrained import load_local_model

if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_local_model(model_name, device=device)
model = load_local_model(
model_name, device=device, use_flash_attn=use_flash_attn
)
if device.type != "cpu":
model = model.to(torch.bfloat16)
assert isinstance(model, ESMC)
Expand Down Expand Up @@ -208,7 +213,7 @@ def logits(
if device.type == "cuda"
else contextlib.nullcontext(),
):
output = self.forward(sequence_tokens=input.sequence)
output = self(sequence_tokens=input.sequence)
assert output.hidden_states is not None
output.hidden_states = (
output.hidden_states[config.ith_hidden_layer : config.ith_hidden_layer + 1]
Expand Down
44 changes: 28 additions & 16 deletions esm/pretrained.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import inspect
from typing import Callable

import torch
import torch.nn as nn
from accelerate import init_empty_weights

from esm.models.esm3 import ESM3
from esm.models.esmc import ESMC
Expand All @@ -22,42 +24,45 @@


def ESM3_structure_encoder_v0(device: torch.device | str = "cpu"):
with torch.device(device):
with init_empty_weights():
model = StructureTokenEncoder(
d_model=1024, n_heads=1, v_heads=128, n_layers=2, d_out=128, n_codes=4096
).eval()
state_dict = torch.load(
data_root("esm3") / "data/weights/esm3_structure_encoder_v0.pth",
map_location=device,
)
model.load_state_dict(state_dict)
model.load_state_dict(state_dict, assign=True)
model = model.to(device)
return model


def ESM3_structure_decoder_v0(device: torch.device | str = "cpu"):
with torch.device(device):
with init_empty_weights():
model = StructureTokenDecoder(d_model=1280, n_heads=20, n_layers=30).eval()
state_dict = torch.load(
data_root("esm3") / "data/weights/esm3_structure_decoder_v0.pth",
map_location=device,
)
model.load_state_dict(state_dict)
model.load_state_dict(state_dict, assign=True)
model = model.to(device)
return model


def ESM3_function_decoder_v0(device: torch.device | str = "cpu"):
with torch.device(device):
with init_empty_weights():
model = FunctionTokenDecoder().eval()
state_dict = torch.load(
data_root("esm3") / "data/weights/esm3_function_decoder_v0.pth",
map_location=device,
)
model.load_state_dict(state_dict)
model.load_state_dict(state_dict, assign=True)
model = model.to(device)
return model


def ESMC_300M_202412(device: torch.device | str = "cpu", use_flash_attn: bool = True):
with torch.device(device):
with init_empty_weights():
model = ESMC(
d_model=960,
n_heads=15,
Expand All @@ -69,13 +74,13 @@ def ESMC_300M_202412(device: torch.device | str = "cpu", use_flash_attn: bool =
data_root("esmc-300") / "data/weights/esmc_300m_2024_12_v0.pth",
map_location=device,
)
model.load_state_dict(state_dict)

model.load_state_dict(state_dict, assign=True)
model = model.to(device)
return model


def ESMC_600M_202412(device: torch.device | str = "cpu", use_flash_attn: bool = True):
with torch.device(device):
with init_empty_weights():
model = ESMC(
d_model=1152,
n_heads=18,
Expand All @@ -87,13 +92,13 @@ def ESMC_600M_202412(device: torch.device | str = "cpu", use_flash_attn: bool =
data_root("esmc-600") / "data/weights/esmc_600m_2024_12_v0.pth",
map_location=device,
)
model.load_state_dict(state_dict)

model.load_state_dict(state_dict, assign=True)
model = model.to(device)
return model


def ESM3_sm_open_v0(device: torch.device | str = "cpu"):
with torch.device(device):
with init_empty_weights():
model = ESM3(
d_model=1536,
n_heads=24,
Expand All @@ -107,7 +112,8 @@ def ESM3_sm_open_v0(device: torch.device | str = "cpu"):
state_dict = torch.load(
data_root("esm3") / "data/weights/esm3_sm_open_v1.pth", map_location=device
)
model.load_state_dict(state_dict)
model.load_state_dict(state_dict, assign=True)
model = model.to(device)
return model


Expand All @@ -122,11 +128,17 @@ def ESM3_sm_open_v0(device: torch.device | str = "cpu"):


def load_local_model(
model_name: str, device: torch.device = torch.device("cpu")
model_name: str,
device: torch.device = torch.device("cpu"),
use_flash_attn: bool = True,
) -> nn.Module:
if model_name not in LOCAL_MODEL_REGISTRY:
raise ValueError(f"Model {model_name} not found in local model registry.")
return LOCAL_MODEL_REGISTRY[model_name](device)
builder = LOCAL_MODEL_REGISTRY[model_name]
kwargs = {}
if "use_flash_attn" in inspect.signature(builder).parameters:
kwargs["use_flash_attn"] = use_flash_attn
return builder(device, **kwargs)


# Register custom versions of ESM3 for use with the local inference API
Expand Down
37 changes: 32 additions & 5 deletions esm/sdk/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,29 @@ def __len__(self):
def from_pdb(
cls,
path: PathOrBuffer,
chain_id: str = "detect",
chain_id: str = "all",
id: str | None = None,
is_predicted: bool = False,
) -> ESMProtein:
protein_chain = ProteinChain.from_pdb(
path=path, chain_id=chain_id, id=id, is_predicted=is_predicted
)
return cls.from_protein_chain(protein_chain)
"""Return an ESMProtein object from a pdb file.

Args:
path (str | Path | io.TextIO): Path or buffer to read pdb file from. Should be uncompressed.
chain_id (str, optional): Select a chain corresponding to (author) chain id. "all" uses all chains,
"detect" uses the first detected chain
id (str, optional): String identifier to assign to structure. Will attempt to infer otherwise.
is_predicted (bool): If True, reads b factor as the confidence readout. Default: False.
"""
if chain_id == "all":
protein_complex = ProteinComplex.from_pdb(
path=path, id=id, is_predicted=is_predicted
)
return cls.from_protein_complex(protein_complex)
else:
protein_chain = ProteinChain.from_pdb(
path=path, chain_id=chain_id, id=id, is_predicted=is_predicted
)
return cls.from_protein_chain(protein_chain)

@classmethod
def from_protein_chain(
Expand Down Expand Up @@ -396,6 +411,16 @@ class LogitsConfig:
return_mean_hidden_states: bool = False
ith_hidden_layer: int = -1

# SAE config only applies to ESMC models
sae_config: SAEConfig | None = None


@define
class SAEConfig:
model: str
normalize_features: bool = True
mode: str | None = None


@define
class LogitsOutput:
Expand All @@ -409,6 +434,8 @@ class LogitsOutput:
residue_annotation_logits: torch.Tensor | None = None
hidden_states: torch.Tensor | None = None
mean_hidden_state: torch.Tensor | None = None
# sae_outputs keys are sae model names and values are sparse representations of the sae activations
sae_outputs: dict[str, torch.Tensor] | None = None


@define
Expand Down
30 changes: 28 additions & 2 deletions esm/sdk/forge.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import base64
import pickle
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Sequence
from typing import Any, Sequence, cast

import torch

Expand Down Expand Up @@ -720,7 +720,7 @@ def batch_generate(
try:
results.append(future.result())
except Exception as e:
results.append(ESMProteinError(500, str(e)))
results.append(ESMProteinError(error_code=500, error_msg=str(e)))
return results

async def __async_generate_protein(
Expand Down Expand Up @@ -967,6 +967,13 @@ def _process_logits_request(
"return_mean_hidden_states": config.return_mean_hidden_states,
"return_hidden_states": config.return_hidden_states,
"ith_hidden_layer": config.ith_hidden_layer,
"sae_config": {
"model": config.sae_config.model,
"normalize_features": config.sae_config.normalize_features,
"mode": config.sae_config.mode,
}
if config.sae_config
else None,
}
request = {"model": model_name, "inputs": req, "logits_config": logits_config}
return request
Expand All @@ -980,12 +987,31 @@ def _process_logits_response(
data["embeddings"] = _maybe_b64_decode(data["embeddings"], return_bytes)
data["hidden_states"] = _maybe_b64_decode(data["hidden_states"], return_bytes)

# sae outputs are always encoded
# NOTE: leave this intact for application/json since it's always
# base64 bytes, even with return_bytes=False
def _b64_decode(obj):
return (
deserialize_tensors(base64.b64decode(obj)) if obj is not None else obj
)

sae_outputs = data["sae_outputs"]
if isinstance(sae_outputs, str):
# sae outputs are always encoded for application/json
sae_outputs = _b64_decode(sae_outputs)
sae_outputs = (
{k: v.to(torch.float32) for k, v in sae_outputs.items()}
if sae_outputs
else None
)
sae_outputs = cast(dict[str, torch.Tensor] | None, sae_outputs)
output = LogitsOutput(
logits=ForwardTrackData(sequence=_maybe_logits(data, "sequence")),
embeddings=maybe_tensor(data["embeddings"]),
mean_embedding=data["mean_embedding"],
hidden_states=maybe_tensor(data["hidden_states"]),
mean_hidden_state=maybe_tensor(data["mean_hidden_state"]),
sae_outputs=sae_outputs,
)
return output

Expand Down
2 changes: 1 addition & 1 deletion esm/utils/constants/esm3.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@
def data_root(model: str):
if "INFRA_PROVIDER" in os.environ:
return Path("")
# Try to download from hugginface if it doesn't exist
# Try to download from huggingface if it doesn't exist
if model.startswith("esm3"):
path = Path(snapshot_download(repo_id="EvolutionaryScale/esm3-sm-open-v1"))
elif model.startswith("esmc-300"):
Expand Down
2 changes: 1 addition & 1 deletion esm/utils/forge_context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class ForgeBatchExecutor:

Args:
max_attempts: Maximum attempts per task before failing.
max_workers: Maximum number of concurrent workers. Default is 512.
max_workers: Maximum number of concurrent workers. Default is 64.
show_progress: Whether to display a tqdm progress bar. Default ``True``.
"""

Expand Down
Loading
Loading