diff --git a/README.md b/README.md
index 7b577741..62cb57a9 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@
[ESM3](https://www.science.org/doi/10.1126/science.ads0018) ⋅ [ESM C](https://www.evolutionaryscale.ai/blog/esm-cambrian) ⋅
-[Slack](https://bit.ly/esm-slack) ⋅ [Tutorials](https://github.com/evolutionaryscale/esm/tree/main/cookbook/tutorials)
+[Slack](https://bit.ly/3FKwcWd) ⋅ [Tutorials](https://github.com/evolutionaryscale/esm/tree/main/cookbook/tutorials)
diff --git a/cookbook/snippets/esmc.py b/cookbook/snippets/esmc.py
index 335f3e19..8911b509 100644
--- a/cookbook/snippets/esmc.py
+++ b/cookbook/snippets/esmc.py
@@ -1,15 +1,20 @@
+import math
import os
+import torch
+
from esm.models.esmc import ESMC
-from esm.sdk import client
+from esm.sdk import batch_executor, client
from esm.sdk.api import (
ESMCInferenceClient,
ESMProtein,
+ ESMProteinError,
ESMProteinTensor,
LogitsConfig,
LogitsOutput,
)
-from esm.sdk.forge import ESM3ForgeInferenceClient
+from esm.sdk.forge import ESM3ForgeInferenceClient, ESMCForgeInferenceClient
+from esm.tokenization import get_esmc_model_tokenizers
def main(client: ESMCInferenceClient | ESM3ForgeInferenceClient):
@@ -70,6 +75,58 @@ def raw_forward(model: ESMC):
)
+def compute_pseudoperplexity(
+ forge_client: ESMCForgeInferenceClient, sequence: str
+) -> float:
+ """Compute L-pass pseudoperplexity for a protein sequence via Forge.
+
+ Masks each position one at a time, retrieves logits from Forge, and returns
+ exp(-mean(log_prob_true_aa)). Uses batch_executor for parallel requests.
+
+ Example::
+
+ forge_client = ESMCForgeInferenceClient(
+ model="esmc-6b-2024-12",
+ url="https://forge.evolutionaryscale.ai",
+ token=os.environ["ESM_API_KEY"],
+ )
+ pppl = compute_pseudoperplexity(forge_client, "MKTLLILAVL...")
+ """
+ L = len(sequence)
+ masked_sequences = [sequence[:i] + "_" + sequence[i + 1 :] for i in range(L)]
+
+ def _get_logits(client: ESMCForgeInferenceClient, sequence: str) -> LogitsOutput:
+ protein = ESMProtein(sequence=sequence)
+ protein_tensor = client.encode(protein)
+ if isinstance(protein_tensor, ESMProteinError):
+ raise protein_tensor
+ output = client.logits(protein_tensor, LogitsConfig(sequence=True))
+ if isinstance(output, ESMProteinError):
+ raise output
+ return output
+
+ with batch_executor() as executor:
+ logit_outputs = executor.execute_batch(
+ _get_logits, client=forge_client, sequence=masked_sequences
+ )
+
+ # Build vocab from the tokenizer to map amino acid characters to token indices
+ vocab: dict[str, int] = get_esmc_model_tokenizers().get_vocab()
+
+ log_probs = []
+ for i in range(L):
+ output = logit_outputs[i]
+ if isinstance(output, Exception):
+ raise output
+ logits = output.logits.sequence # shape: (L+2, V)
+ position_logits = logits[i + 1] # +1 for BOS token
+ log_softmax = torch.log_softmax(position_logits, dim=-1)
+ true_aa_idx = vocab[sequence[i]]
+ log_probs.append(log_softmax[true_aa_idx].item())
+
+ return math.exp(-sum(log_probs) / L)
+
+
if __name__ == "__main__":
if os.environ.get("ESM_API_KEY", ""):
print("ESM_API_KEY found. Trying to use model from Forge...")
diff --git a/esm/models/esm3.py b/esm/models/esm3.py
index 0d3ead1d..31cfb853 100644
--- a/esm/models/esm3.py
+++ b/esm/models/esm3.py
@@ -524,7 +524,7 @@ def logits(
if device.type == "cuda"
else contextlib.nullcontext(),
):
- output = self.forward(
+ output = self(
sequence_tokens=input.sequence,
structure_tokens=input.structure,
ss8_tokens=input.secondary_structure,
diff --git a/esm/models/esmc.py b/esm/models/esmc.py
index d3a5bfc9..2e83a4ac 100644
--- a/esm/models/esmc.py
+++ b/esm/models/esmc.py
@@ -77,13 +77,18 @@ def __init__(
@classmethod
def from_pretrained(
- cls, model_name: str = ESMC_600M, device: torch.device | None = None
+ cls,
+ model_name: str = ESMC_600M,
+ device: torch.device | None = None,
+ use_flash_attn: bool = True,
) -> ESMC:
from esm.pretrained import load_local_model
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- model = load_local_model(model_name, device=device)
+ model = load_local_model(
+ model_name, device=device, use_flash_attn=use_flash_attn
+ )
if device.type != "cpu":
model = model.to(torch.bfloat16)
assert isinstance(model, ESMC)
@@ -208,7 +213,7 @@ def logits(
if device.type == "cuda"
else contextlib.nullcontext(),
):
- output = self.forward(sequence_tokens=input.sequence)
+ output = self(sequence_tokens=input.sequence)
assert output.hidden_states is not None
output.hidden_states = (
output.hidden_states[config.ith_hidden_layer : config.ith_hidden_layer + 1]
diff --git a/esm/pretrained.py b/esm/pretrained.py
index e452e1d2..38527832 100644
--- a/esm/pretrained.py
+++ b/esm/pretrained.py
@@ -1,7 +1,9 @@
+import inspect
from typing import Callable
import torch
import torch.nn as nn
+from accelerate import init_empty_weights
from esm.models.esm3 import ESM3
from esm.models.esmc import ESMC
@@ -22,7 +24,7 @@
def ESM3_structure_encoder_v0(device: torch.device | str = "cpu"):
- with torch.device(device):
+ with init_empty_weights():
model = StructureTokenEncoder(
d_model=1024, n_heads=1, v_heads=128, n_layers=2, d_out=128, n_codes=4096
).eval()
@@ -30,34 +32,37 @@ def ESM3_structure_encoder_v0(device: torch.device | str = "cpu"):
data_root("esm3") / "data/weights/esm3_structure_encoder_v0.pth",
map_location=device,
)
- model.load_state_dict(state_dict)
+ model.load_state_dict(state_dict, assign=True)
+ model = model.to(device)
return model
def ESM3_structure_decoder_v0(device: torch.device | str = "cpu"):
- with torch.device(device):
+ with init_empty_weights():
model = StructureTokenDecoder(d_model=1280, n_heads=20, n_layers=30).eval()
state_dict = torch.load(
data_root("esm3") / "data/weights/esm3_structure_decoder_v0.pth",
map_location=device,
)
- model.load_state_dict(state_dict)
+ model.load_state_dict(state_dict, assign=True)
+ model = model.to(device)
return model
def ESM3_function_decoder_v0(device: torch.device | str = "cpu"):
- with torch.device(device):
+ with init_empty_weights():
model = FunctionTokenDecoder().eval()
state_dict = torch.load(
data_root("esm3") / "data/weights/esm3_function_decoder_v0.pth",
map_location=device,
)
- model.load_state_dict(state_dict)
+ model.load_state_dict(state_dict, assign=True)
+ model = model.to(device)
return model
def ESMC_300M_202412(device: torch.device | str = "cpu", use_flash_attn: bool = True):
- with torch.device(device):
+ with init_empty_weights():
model = ESMC(
d_model=960,
n_heads=15,
@@ -69,13 +74,13 @@ def ESMC_300M_202412(device: torch.device | str = "cpu", use_flash_attn: bool =
data_root("esmc-300") / "data/weights/esmc_300m_2024_12_v0.pth",
map_location=device,
)
- model.load_state_dict(state_dict)
-
+ model.load_state_dict(state_dict, assign=True)
+ model = model.to(device)
return model
def ESMC_600M_202412(device: torch.device | str = "cpu", use_flash_attn: bool = True):
- with torch.device(device):
+ with init_empty_weights():
model = ESMC(
d_model=1152,
n_heads=18,
@@ -87,13 +92,13 @@ def ESMC_600M_202412(device: torch.device | str = "cpu", use_flash_attn: bool =
data_root("esmc-600") / "data/weights/esmc_600m_2024_12_v0.pth",
map_location=device,
)
- model.load_state_dict(state_dict)
-
+ model.load_state_dict(state_dict, assign=True)
+ model = model.to(device)
return model
def ESM3_sm_open_v0(device: torch.device | str = "cpu"):
- with torch.device(device):
+ with init_empty_weights():
model = ESM3(
d_model=1536,
n_heads=24,
@@ -107,7 +112,8 @@ def ESM3_sm_open_v0(device: torch.device | str = "cpu"):
state_dict = torch.load(
data_root("esm3") / "data/weights/esm3_sm_open_v1.pth", map_location=device
)
- model.load_state_dict(state_dict)
+ model.load_state_dict(state_dict, assign=True)
+ model = model.to(device)
return model
@@ -122,11 +128,17 @@ def ESM3_sm_open_v0(device: torch.device | str = "cpu"):
def load_local_model(
- model_name: str, device: torch.device = torch.device("cpu")
+ model_name: str,
+ device: torch.device = torch.device("cpu"),
+ use_flash_attn: bool = True,
) -> nn.Module:
if model_name not in LOCAL_MODEL_REGISTRY:
raise ValueError(f"Model {model_name} not found in local model registry.")
- return LOCAL_MODEL_REGISTRY[model_name](device)
+ builder = LOCAL_MODEL_REGISTRY[model_name]
+ kwargs = {}
+ if "use_flash_attn" in inspect.signature(builder).parameters:
+ kwargs["use_flash_attn"] = use_flash_attn
+ return builder(device, **kwargs)
# Register custom versions of ESM3 for use with the local inference API
diff --git a/esm/sdk/api.py b/esm/sdk/api.py
index 77a42b0e..a01fa4f5 100644
--- a/esm/sdk/api.py
+++ b/esm/sdk/api.py
@@ -58,14 +58,29 @@ def __len__(self):
def from_pdb(
cls,
path: PathOrBuffer,
- chain_id: str = "detect",
+ chain_id: str = "all",
id: str | None = None,
is_predicted: bool = False,
) -> ESMProtein:
- protein_chain = ProteinChain.from_pdb(
- path=path, chain_id=chain_id, id=id, is_predicted=is_predicted
- )
- return cls.from_protein_chain(protein_chain)
+ """Return an ESMProtein object from a pdb file.
+
+ Args:
+ path (str | Path | io.TextIO): Path or buffer to read pdb file from. Should be uncompressed.
+ chain_id (str, optional): Select a chain corresponding to (author) chain id. "all" uses all chains,
+ "detect" uses the first detected chain
+ id (str, optional): String identifier to assign to structure. Will attempt to infer otherwise.
+ is_predicted (bool): If True, reads b factor as the confidence readout. Default: False.
+ """
+ if chain_id == "all":
+ protein_complex = ProteinComplex.from_pdb(
+ path=path, id=id, is_predicted=is_predicted
+ )
+ return cls.from_protein_complex(protein_complex)
+ else:
+ protein_chain = ProteinChain.from_pdb(
+ path=path, chain_id=chain_id, id=id, is_predicted=is_predicted
+ )
+ return cls.from_protein_chain(protein_chain)
@classmethod
def from_protein_chain(
@@ -396,6 +411,16 @@ class LogitsConfig:
return_mean_hidden_states: bool = False
ith_hidden_layer: int = -1
+ # SAE config only applies to ESMC models
+ sae_config: SAEConfig | None = None
+
+
+@define
+class SAEConfig:
+ model: str
+ normalize_features: bool = True
+ mode: str | None = None
+
@define
class LogitsOutput:
@@ -409,6 +434,8 @@ class LogitsOutput:
residue_annotation_logits: torch.Tensor | None = None
hidden_states: torch.Tensor | None = None
mean_hidden_state: torch.Tensor | None = None
+    # sae_outputs keys are SAE model names; values are sparse representations of the SAE activations.
+ sae_outputs: dict[str, torch.Tensor] | None = None
@define
diff --git a/esm/sdk/forge.py b/esm/sdk/forge.py
index 98973326..ab6fd2b6 100644
--- a/esm/sdk/forge.py
+++ b/esm/sdk/forge.py
@@ -4,7 +4,7 @@
import base64
import pickle
from concurrent.futures import ThreadPoolExecutor
-from typing import Any, Sequence
+from typing import Any, Sequence, cast
import torch
@@ -720,7 +720,7 @@ def batch_generate(
try:
results.append(future.result())
except Exception as e:
- results.append(ESMProteinError(500, str(e)))
+ results.append(ESMProteinError(error_code=500, error_msg=str(e)))
return results
async def __async_generate_protein(
@@ -967,6 +967,13 @@ def _process_logits_request(
"return_mean_hidden_states": config.return_mean_hidden_states,
"return_hidden_states": config.return_hidden_states,
"ith_hidden_layer": config.ith_hidden_layer,
+ "sae_config": {
+ "model": config.sae_config.model,
+ "normalize_features": config.sae_config.normalize_features,
+ "mode": config.sae_config.mode,
+ }
+ if config.sae_config
+ else None,
}
request = {"model": model_name, "inputs": req, "logits_config": logits_config}
return request
@@ -980,12 +987,31 @@ def _process_logits_response(
data["embeddings"] = _maybe_b64_decode(data["embeddings"], return_bytes)
data["hidden_states"] = _maybe_b64_decode(data["hidden_states"], return_bytes)
+    # NOTE: unlike the logits/embeddings fields above, sae_outputs is always
+    # base64-encoded in application/json responses — even when
+    # return_bytes=False — so it must be decoded unconditionally.
+ def _b64_decode(obj):
+ return (
+ deserialize_tensors(base64.b64decode(obj)) if obj is not None else obj
+ )
+
+ sae_outputs = data["sae_outputs"]
+ if isinstance(sae_outputs, str):
+ # sae outputs are always encoded for application/json
+ sae_outputs = _b64_decode(sae_outputs)
+ sae_outputs = (
+ {k: v.to(torch.float32) for k, v in sae_outputs.items()}
+ if sae_outputs
+ else None
+ )
+ sae_outputs = cast(dict[str, torch.Tensor] | None, sae_outputs)
output = LogitsOutput(
logits=ForwardTrackData(sequence=_maybe_logits(data, "sequence")),
embeddings=maybe_tensor(data["embeddings"]),
mean_embedding=data["mean_embedding"],
hidden_states=maybe_tensor(data["hidden_states"]),
mean_hidden_state=maybe_tensor(data["mean_hidden_state"]),
+ sae_outputs=sae_outputs,
)
return output
diff --git a/esm/utils/constants/esm3.py b/esm/utils/constants/esm3.py
index bea46eda..35c201ac 100644
--- a/esm/utils/constants/esm3.py
+++ b/esm/utils/constants/esm3.py
@@ -100,7 +100,7 @@
def data_root(model: str):
if "INFRA_PROVIDER" in os.environ:
return Path("")
- # Try to download from hugginface if it doesn't exist
+ # Try to download from huggingface if it doesn't exist
if model.startswith("esm3"):
path = Path(snapshot_download(repo_id="EvolutionaryScale/esm3-sm-open-v1"))
elif model.startswith("esmc-300"):
diff --git a/esm/utils/forge_context_manager.py b/esm/utils/forge_context_manager.py
index b1c2bdf3..98de9ee4 100644
--- a/esm/utils/forge_context_manager.py
+++ b/esm/utils/forge_context_manager.py
@@ -48,7 +48,7 @@ class ForgeBatchExecutor:
Args:
max_attempts: Maximum attempts per task before failing.
- max_workers: Maximum number of concurrent workers. Default is 512.
+ max_workers: Maximum number of concurrent workers. Default is 64.
show_progress: Whether to display a tqdm progress bar. Default ``True``.
"""
diff --git a/esm/utils/msa/msa.py b/esm/utils/msa/msa.py
index 838e5b4b..120f1333 100644
--- a/esm/utils/msa/msa.py
+++ b/esm/utils/msa/msa.py
@@ -126,6 +126,9 @@ def from_bytes(cls, data: bytes) -> MSA:
headers = header_bytes.decode().split("\n")
# Sometimes the separation is two newlines, which results in an empty header.
headers = [header for header in headers if header]
+        # If no headers survived filtering (e.g. the MSA was saved via from_sequences),
+        # fall back to one empty header per alignment row.
+ if len(headers) == 0 and depth > 0:
+ headers = [""] * depth
entries = [
FastaEntry(header, b"".join(row).decode())
for header, row in zip(headers, array)
@@ -367,6 +370,9 @@ def from_bytes(cls, data: bytes) -> FastMSA:
headers = header_bytes.decode().split("\n")
# Sometimes the separation is two newlines, which results in an empty header.
headers = [header for header in headers if header]
+        # If no headers survived filtering (e.g. the MSA was saved via from_sequences),
+        # fall back to one empty header per alignment row.
+ if len(headers) == 0 and depth > 0:
+ headers = [""] * depth
return cls(array, headers)
@classmethod
diff --git a/esm/utils/structure/input_builder.py b/esm/utils/structure/input_builder.py
index b2e664fb..4ae55fed 100644
--- a/esm/utils/structure/input_builder.py
+++ b/esm/utils/structure/input_builder.py
@@ -7,7 +7,8 @@
@dataclass
class Modification:
position: int # zero-indexed
- ccd: str
+ ccd: str | None = None
+ smiles: str | None = None # TODO(mlee): add smiles support
@dataclass
@@ -15,6 +16,7 @@ class ProteinInput:
id: str | list[str]
sequence: str
modifications: list[Modification] | None = None
+ cyclic: bool = False
@dataclass
diff --git a/esm/utils/structure/molecular_complex.py b/esm/utils/structure/molecular_complex.py
index e49be51c..c0ec8e1e 100644
--- a/esm/utils/structure/molecular_complex.py
+++ b/esm/utils/structure/molecular_complex.py
@@ -320,29 +320,52 @@ def to_protein_complex(self) -> ProteinComplex:
# Calculate final sequence length (includes chain breaks)
sequence_length = len(single_letter_sequence)
- # Convert flat atoms back to atom37 representation
+ # Convert flat atoms back to atom37 representation using atom names
for res_idx, token_idx in enumerate(protein_indices):
token = self.sequence[token_idx]
start_atom, end_atom = self.token_to_atoms[token_idx]
- # Get atom data for this residue
res_atom_positions = self.atom_positions[start_atom:end_atom]
+ res_atom_names = (
+ np.array(self.atom_names[start_atom:end_atom], dtype=str)
+ if self.atom_names is not None
+ else np.array([], dtype=str)
+ )
- # Reconstruct atom37 representation by exactly reversing the forward conversion logic
- # In from_protein_complex, atoms are added in atom_types order if present in mask
- # So we need to reconstruct the mask and positions in the same order
- atom_count = 0
- for atom_type_idx, atom_name in enumerate(residue_constants.atom_types):
- # Check if this atom type exists for this residue and was present
- residue_atoms = residue_constants.residue_atoms.get(token, [])
- if atom_name in residue_atoms:
- # This atom type exists for this residue, so it should have been included
- if atom_count < len(res_atom_positions):
- atom37_positions[res_idx, atom_type_idx] = res_atom_positions[
- atom_count
- ]
- atom37_mask[res_idx, atom_type_idx] = True
- atom_count += 1
+ # Build a mapping from normalized atom name -> position for this residue
+ # Normalize to uppercase and strip whitespace for robust matching
+ name_to_pos: dict[str, np.ndarray] = {}
+ for i, nm in enumerate(res_atom_names):
+ key = nm.upper().strip()
+ # Prefer first occurrence; ignore duplicates/altlocs
+ if key not in name_to_pos:
+ name_to_pos[key] = res_atom_positions[i]
+
+ canonical_atoms = residue_constants.residue_atoms.get(token, [])
+ swap_map = residue_constants.residue_atom_renaming_swaps.get(token, {})
+
+ # Place coordinates by canonical atom name into atom37 order
+ for cn in canonical_atoms:
+ # Normalize canonical name for lookup
+ cn_key = cn.upper()
+ idx37 = residue_constants.atom_order.get(cn_key)
+ if idx37 is None:
+ continue
+ # Try direct name, otherwise use renaming swap (normalized)
+ present_name = (
+ cn_key
+ if cn_key in name_to_pos
+ else (
+ swap_map.get(cn) if isinstance(swap_map.get(cn), str) else None
+ )
+ )
+ if present_name:
+ present_key = present_name.upper()
+ else:
+ present_key = None
+ if present_key and present_key in name_to_pos:
+ atom37_positions[res_idx, idx37] = name_to_pos[present_key]
+ atom37_mask[res_idx, idx37] = True
# Create arrays that match sequence length (including chain breaks)
# Initialize arrays with proper size
@@ -703,9 +726,10 @@ def to_mmcif(self) -> str:
atom_names[start:end] = names
# Set all AtomArray attributes at once (convert object arrays to proper string arrays)
+        # Note: res_name uses U8 (rather than U4) so that longer CCD codes —
+        # e.g. the 5-character code A1AZ2 — fit without truncation.
atom_array.res_id = atom_res_ids
atom_array.chain_id = np.array(atom_chain_ids, dtype="U4")
- atom_array.res_name = np.array(atom_res_names, dtype="U4")
+ atom_array.res_name = np.array(atom_res_names, dtype="U8")
atom_array.hetero = atom_hetero
atom_array.atom_name = np.array(atom_names, dtype="U4")
atom_array.add_annotation("b_factor", dtype=float)
diff --git a/esm/utils/structure/protein_complex.py b/esm/utils/structure/protein_complex.py
index fa71d2ad..49b4f2f9 100644
--- a/esm/utils/structure/protein_complex.py
+++ b/esm/utils/structure/protein_complex.py
@@ -317,7 +317,9 @@ def as_chain(self, force_conversion: bool = False) -> ProteinChain:
)
@classmethod
- def from_pdb(cls, path: PathOrBuffer, id: str | None = None) -> "ProteinComplex":
+ def from_pdb(
+ cls, path: PathOrBuffer, id: str | None = None, is_predicted: bool = False
+ ) -> "ProteinComplex":
atom_array = PDBFile.read(path).get_structure(
model=1, extra_fields=["b_factor"]
)
@@ -327,7 +329,7 @@ def from_pdb(cls, path: PathOrBuffer, id: str | None = None) -> "ProteinComplex"
chain = chain[~chain.hetero]
if len(chain) == 0:
continue
- chains.append(ProteinChain.from_atomarray(chain, id))
+ chains.append(ProteinChain.from_atomarray(chain, id, is_predicted))
return ProteinComplex.from_chains(chains)
def to_pdb(self, path: PathOrBuffer, include_insertions: bool = True):
diff --git a/pixi.lock b/pixi.lock
index 3aa514df..4eabe1d4 100644
--- a/pixi.lock
+++ b/pixi.lock
@@ -142,6 +142,7 @@ environments:
- conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.21.0-pyhd8ed1ab_1.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/zstandard-0.23.0-py312h66e93f0_1.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb8e6e7a_2.conda
+ - pypi: https://files.pythonhosted.org/packages/7e/46/02ac5e262d4af18054b3e922b2baedbb2a03289ee792162de60a865defc5/accelerate-1.13.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/6f/12/e5e0282d673bb9746bacfb6e2dba8719989d3660cdb2ea79aee9a9651afb/anyio-4.10.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/25/8a/c46dcc25341b5bce5472c718902eb3d38600a903b14fa6aeecef3f21a46f/asttokens-3.0.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/77/06/bb80f5f86020c4551da315d78b3ab75e8228f89f0162f2c3a819e407941a/attrs-25.3.0-py3-none-any.whl
@@ -191,6 +192,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/c6/ac/dac4a63f978e4dcb3c6d3a78c4d8e0192a113d288502a1216950c41b1027/parso-0.8.4-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e4/ea/d836f008d33151c7a1f62caf3d8dd782e4d15f6a43897f64480c2b8de2ad/prompt_toolkit-3.0.50-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/b5/70/5d8df3b09e25bce090399cf48e452d25c935ab72dad19406c77f4e828045/psutil-7.2.2-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/22/a6/858897256d0deac81a172289110f31629fc4cee19b6f01283303e18c8db3/ptyprocess-0.7.0-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/8e/37/efad0257dc6e593a18957422533ff0f87ede7c9c6ea010a2177d738fb82f/pure_eval-0.2.3-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/a3/17/c476487ba903c7d793db633b4e8ca4a420ae272890302189d9402ba8ff85/py3dmol-2.5.2-py2.py3-none-any.whl
@@ -340,6 +342,7 @@ environments:
- conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.23.0-pyhd8ed1ab_0.conda
- conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstandard-0.23.0-py312hea69d52_2.conda
- conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-h6491c7d_2.conda
+ - pypi: https://files.pythonhosted.org/packages/7e/46/02ac5e262d4af18054b3e922b2baedbb2a03289ee792162de60a865defc5/accelerate-1.13.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/6f/12/e5e0282d673bb9746bacfb6e2dba8719989d3660cdb2ea79aee9a9651afb/anyio-4.10.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/25/8a/c46dcc25341b5bce5472c718902eb3d38600a903b14fa6aeecef3f21a46f/asttokens-3.0.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/77/06/bb80f5f86020c4551da315d78b3ab75e8228f89f0162f2c3a819e407941a/attrs-25.3.0-py3-none-any.whl
@@ -377,6 +380,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/c6/ac/dac4a63f978e4dcb3c6d3a78c4d8e0192a113d288502a1216950c41b1027/parso-0.8.4-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/ce/4f/5249960887b1fbe561d9ff265496d170b55a735b76724f10ef19f9e40716/prompt_toolkit-3.0.51-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/80/c4/f5af4c1ca8c1eeb2e92ccca14ce8effdeec651d5ab6053c589b074eda6e1/psutil-7.2.2-cp36-abi3-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/22/a6/858897256d0deac81a172289110f31629fc4cee19b6f01283303e18c8db3/ptyprocess-0.7.0-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/8e/37/efad0257dc6e593a18957422533ff0f87ede7c9c6ea010a2177d738fb82f/pure_eval-0.2.3-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/a3/17/c476487ba903c7d793db633b4e8ca4a420ae272890302189d9402ba8ff85/py3dmol-2.5.2-py2.py3-none-any.whl
@@ -629,6 +633,7 @@ environments:
- conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.1-hb9d3cd8_2.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/zstandard-0.23.0-py312h66e93f0_1.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb8e6e7a_2.conda
+ - pypi: https://files.pythonhosted.org/packages/7e/46/02ac5e262d4af18054b3e922b2baedbb2a03289ee792162de60a865defc5/accelerate-1.13.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/6f/12/e5e0282d673bb9746bacfb6e2dba8719989d3660cdb2ea79aee9a9651afb/anyio-4.10.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/25/8a/c46dcc25341b5bce5472c718902eb3d38600a903b14fa6aeecef3f21a46f/asttokens-3.0.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/77/06/bb80f5f86020c4551da315d78b3ab75e8228f89f0162f2c3a819e407941a/attrs-25.3.0-py3-none-any.whl
@@ -678,6 +683,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/c6/ac/dac4a63f978e4dcb3c6d3a78c4d8e0192a113d288502a1216950c41b1027/parso-0.8.4-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e4/ea/d836f008d33151c7a1f62caf3d8dd782e4d15f6a43897f64480c2b8de2ad/prompt_toolkit-3.0.50-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/b5/70/5d8df3b09e25bce090399cf48e452d25c935ab72dad19406c77f4e828045/psutil-7.2.2-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/22/a6/858897256d0deac81a172289110f31629fc4cee19b6f01283303e18c8db3/ptyprocess-0.7.0-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/8e/37/efad0257dc6e593a18957422533ff0f87ede7c9c6ea010a2177d738fb82f/pure_eval-0.2.3-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/a3/17/c476487ba903c7d793db633b4e8ca4a420ae272890302189d9402ba8ff85/py3dmol-2.5.2-py2.py3-none-any.whl
@@ -854,6 +860,7 @@ environments:
- conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.23.0-pyhd8ed1ab_0.conda
- conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstandard-0.23.0-py312hea69d52_2.conda
- conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-h6491c7d_2.conda
+ - pypi: https://files.pythonhosted.org/packages/7e/46/02ac5e262d4af18054b3e922b2baedbb2a03289ee792162de60a865defc5/accelerate-1.13.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/6f/12/e5e0282d673bb9746bacfb6e2dba8719989d3660cdb2ea79aee9a9651afb/anyio-4.10.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/25/8a/c46dcc25341b5bce5472c718902eb3d38600a903b14fa6aeecef3f21a46f/asttokens-3.0.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/77/06/bb80f5f86020c4551da315d78b3ab75e8228f89f0162f2c3a819e407941a/attrs-25.3.0-py3-none-any.whl
@@ -891,6 +898,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/c6/ac/dac4a63f978e4dcb3c6d3a78c4d8e0192a113d288502a1216950c41b1027/parso-0.8.4-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/ce/4f/5249960887b1fbe561d9ff265496d170b55a735b76724f10ef19f9e40716/prompt_toolkit-3.0.51-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/80/c4/f5af4c1ca8c1eeb2e92ccca14ce8effdeec651d5ab6053c589b074eda6e1/psutil-7.2.2-cp36-abi3-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/22/a6/858897256d0deac81a172289110f31629fc4cee19b6f01283303e18c8db3/ptyprocess-0.7.0-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/8e/37/efad0257dc6e593a18957422533ff0f87ede7c9c6ea010a2177d738fb82f/pure_eval-0.2.3-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/a3/17/c476487ba903c7d793db633b4e8ca4a420ae272890302189d9402ba8ff85/py3dmol-2.5.2-py2.py3-none-any.whl
@@ -960,6 +968,81 @@ packages:
purls: []
size: 8191
timestamp: 1744137672556
+- pypi: https://files.pythonhosted.org/packages/7e/46/02ac5e262d4af18054b3e922b2baedbb2a03289ee792162de60a865defc5/accelerate-1.13.0-py3-none-any.whl
+ name: accelerate
+ version: 1.13.0
+ sha256: cf1a3efb96c18f7b152eb0fa7490f3710b19c3f395699358f08decca2b8b62e0
+ requires_dist:
+ - numpy>=1.17
+ - packaging>=20.0
+ - psutil
+ - pyyaml
+ - torch>=2.0.0
+ - huggingface-hub>=0.21.0
+ - safetensors>=0.4.3
+ - ruff==0.13.1 ; extra == 'quality'
+ - pytest>=7.2.0 ; extra == 'test-prod'
+ - pytest-xdist ; extra == 'test-prod'
+ - pytest-subtests ; extra == 'test-prod'
+ - parameterized ; extra == 'test-prod'
+ - pytest-order ; extra == 'test-prod'
+ - datasets ; extra == 'test-dev'
+ - diffusers ; extra == 'test-dev'
+ - evaluate ; extra == 'test-dev'
+ - torchdata>=0.8.0 ; extra == 'test-dev'
+ - torchpippy>=0.2.0 ; extra == 'test-dev'
+ - transformers ; extra == 'test-dev'
+ - scipy ; extra == 'test-dev'
+ - scikit-learn ; extra == 'test-dev'
+ - tqdm ; extra == 'test-dev'
+ - bitsandbytes ; extra == 'test-dev'
+ - timm ; extra == 'test-dev'
+ - pytest>=7.2.0 ; extra == 'testing'
+ - pytest-xdist ; extra == 'testing'
+ - pytest-subtests ; extra == 'testing'
+ - parameterized ; extra == 'testing'
+ - pytest-order ; extra == 'testing'
+ - datasets ; extra == 'testing'
+ - diffusers ; extra == 'testing'
+ - evaluate ; extra == 'testing'
+ - torchdata>=0.8.0 ; extra == 'testing'
+ - torchpippy>=0.2.0 ; extra == 'testing'
+ - transformers ; extra == 'testing'
+ - scipy ; extra == 'testing'
+ - scikit-learn ; extra == 'testing'
+ - tqdm ; extra == 'testing'
+ - bitsandbytes ; extra == 'testing'
+ - timm ; extra == 'testing'
+ - deepspeed ; extra == 'deepspeed'
+ - rich ; extra == 'rich'
+ - torchao ; extra == 'test-fp8'
+ - wandb ; extra == 'test-trackers'
+ - comet-ml ; extra == 'test-trackers'
+ - tensorboard ; extra == 'test-trackers'
+ - dvclive ; extra == 'test-trackers'
+ - matplotlib ; extra == 'test-trackers'
+ - swanlab[dashboard] ; extra == 'test-trackers'
+ - trackio ; extra == 'test-trackers'
+ - ruff==0.13.1 ; extra == 'dev'
+ - pytest>=7.2.0 ; extra == 'dev'
+ - pytest-xdist ; extra == 'dev'
+ - pytest-subtests ; extra == 'dev'
+ - parameterized ; extra == 'dev'
+ - pytest-order ; extra == 'dev'
+ - datasets ; extra == 'dev'
+ - diffusers ; extra == 'dev'
+ - evaluate ; extra == 'dev'
+ - torchdata>=0.8.0 ; extra == 'dev'
+ - torchpippy>=0.2.0 ; extra == 'dev'
+ - transformers ; extra == 'dev'
+ - scipy ; extra == 'dev'
+ - scikit-learn ; extra == 'dev'
+ - tqdm ; extra == 'dev'
+ - bitsandbytes ; extra == 'dev'
+ - timm ; extra == 'dev'
+ - rich ; extra == 'dev'
+ - sagemaker ; extra == 'sagemaker'
+ requires_python: '>=3.10.0'
- conda: https://conda.anaconda.org/conda-forge/linux-64/alsa-lib-1.2.14-hb9d3cd8_0.conda
sha256: b9214bc17e89bf2b691fad50d952b7f029f6148f4ac4fe7c60c08f093efdf745
md5: 76df83c2a9035c54df5d04ff81bcc02d
@@ -1727,7 +1810,7 @@ packages:
- pypi: ./
name: esm
version: 3.2.4a1
- sha256: 9a3a042ef03cda7a67cb638f08f7536d8c564329f9d9f9024e7428c7bcccc2d7
+ sha256: 75e7ec205413f4d96cd673be9845c893456eac381c0dd5a9655f10adc9a8a9f8
requires_dist:
- torch>=2.2.0
- torchvision
@@ -1752,6 +1835,7 @@ packages:
- boto3
- pygtrie
- dna-features-viewer
+ - accelerate
requires_python: '>=3.12,<3.13'
editable: true
- conda: https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.3.0-pyhd8ed1ab_0.conda
@@ -5028,6 +5112,94 @@ packages:
requires_dist:
- wcwidth
requires_python: '>=3.8'
+- pypi: https://files.pythonhosted.org/packages/80/c4/f5af4c1ca8c1eeb2e92ccca14ce8effdeec651d5ab6053c589b074eda6e1/psutil-7.2.2-cp36-abi3-macosx_11_0_arm64.whl
+ name: psutil
+ version: 7.2.2
+ sha256: 1a7b04c10f32cc88ab39cbf606e117fd74721c831c98a27dc04578deb0c16979
+ requires_dist:
+ - psleak ; extra == 'dev'
+ - pytest ; extra == 'dev'
+ - pytest-instafail ; extra == 'dev'
+ - pytest-xdist ; extra == 'dev'
+ - setuptools ; extra == 'dev'
+ - abi3audit ; extra == 'dev'
+ - black ; extra == 'dev'
+ - check-manifest ; extra == 'dev'
+ - coverage ; extra == 'dev'
+ - packaging ; extra == 'dev'
+ - pylint ; extra == 'dev'
+ - pyperf ; extra == 'dev'
+ - pypinfo ; extra == 'dev'
+ - pytest-cov ; extra == 'dev'
+ - requests ; extra == 'dev'
+ - rstcheck ; extra == 'dev'
+ - ruff ; extra == 'dev'
+ - sphinx ; extra == 'dev'
+ - sphinx-rtd-theme ; extra == 'dev'
+ - toml-sort ; extra == 'dev'
+ - twine ; extra == 'dev'
+ - validate-pyproject[all] ; extra == 'dev'
+ - virtualenv ; extra == 'dev'
+ - vulture ; extra == 'dev'
+ - wheel ; extra == 'dev'
+ - colorama ; os_name == 'nt' and extra == 'dev'
+ - pyreadline3 ; os_name == 'nt' and extra == 'dev'
+ - pywin32 ; implementation_name != 'pypy' and os_name == 'nt' and extra == 'dev'
+ - wheel ; implementation_name != 'pypy' and os_name == 'nt' and extra == 'dev'
+ - wmi ; implementation_name != 'pypy' and os_name == 'nt' and extra == 'dev'
+ - psleak ; extra == 'test'
+ - pytest ; extra == 'test'
+ - pytest-instafail ; extra == 'test'
+ - pytest-xdist ; extra == 'test'
+ - setuptools ; extra == 'test'
+ - pywin32 ; implementation_name != 'pypy' and os_name == 'nt' and extra == 'test'
+ - wheel ; implementation_name != 'pypy' and os_name == 'nt' and extra == 'test'
+ - wmi ; implementation_name != 'pypy' and os_name == 'nt' and extra == 'test'
+ requires_python: '>=3.6'
+- pypi: https://files.pythonhosted.org/packages/b5/70/5d8df3b09e25bce090399cf48e452d25c935ab72dad19406c77f4e828045/psutil-7.2.2-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl
+ name: psutil
+ version: 7.2.2
+ sha256: 076a2d2f923fd4821644f5ba89f059523da90dc9014e85f8e45a5774ca5bc6f9
+ requires_dist:
+ - psleak ; extra == 'dev'
+ - pytest ; extra == 'dev'
+ - pytest-instafail ; extra == 'dev'
+ - pytest-xdist ; extra == 'dev'
+ - setuptools ; extra == 'dev'
+ - abi3audit ; extra == 'dev'
+ - black ; extra == 'dev'
+ - check-manifest ; extra == 'dev'
+ - coverage ; extra == 'dev'
+ - packaging ; extra == 'dev'
+ - pylint ; extra == 'dev'
+ - pyperf ; extra == 'dev'
+ - pypinfo ; extra == 'dev'
+ - pytest-cov ; extra == 'dev'
+ - requests ; extra == 'dev'
+ - rstcheck ; extra == 'dev'
+ - ruff ; extra == 'dev'
+ - sphinx ; extra == 'dev'
+ - sphinx-rtd-theme ; extra == 'dev'
+ - toml-sort ; extra == 'dev'
+ - twine ; extra == 'dev'
+ - validate-pyproject[all] ; extra == 'dev'
+ - virtualenv ; extra == 'dev'
+ - vulture ; extra == 'dev'
+ - wheel ; extra == 'dev'
+ - colorama ; os_name == 'nt' and extra == 'dev'
+ - pyreadline3 ; os_name == 'nt' and extra == 'dev'
+ - pywin32 ; implementation_name != 'pypy' and os_name == 'nt' and extra == 'dev'
+ - wheel ; implementation_name != 'pypy' and os_name == 'nt' and extra == 'dev'
+ - wmi ; implementation_name != 'pypy' and os_name == 'nt' and extra == 'dev'
+ - psleak ; extra == 'test'
+ - pytest ; extra == 'test'
+ - pytest-instafail ; extra == 'test'
+ - pytest-xdist ; extra == 'test'
+ - setuptools ; extra == 'test'
+ - pywin32 ; implementation_name != 'pypy' and os_name == 'nt' and extra == 'test'
+ - wheel ; implementation_name != 'pypy' and os_name == 'nt' and extra == 'test'
+ - wmi ; implementation_name != 'pypy' and os_name == 'nt' and extra == 'test'
+ requires_python: '>=3.6'
- conda: https://conda.anaconda.org/conda-forge/linux-64/pthread-stubs-0.4-hb9d3cd8_1002.conda
sha256: 9c88f8c64590e9567c6c80823f0328e58d3b1efb0e1c539c0315ceca764e0973
md5: b3c17d95b5a10c6e64a21fa17573e70e
diff --git a/pyproject.toml b/pyproject.toml
index 8d15c1da..070e3625 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ version = "3.2.4.a1"
description = "EvolutionaryScale open model repository"
readme = "README.md"
requires-python = ">=3.12,<3.13"
-license = {file = "LICENSE.txt"}
+license = {file = "LICENSE.md"}
authors = [
{name = "EvolutionaryScale Team"}
@@ -44,6 +44,7 @@ dependencies = [
"boto3",
"pygtrie",
"dna_features_viewer",
+ "accelerate",
]
# Pytest
[tool.pytest.ini_options]
diff --git a/tests/Makefile b/tests/Makefile
index 61c37238..2f203b96 100644
--- a/tests/Makefile
+++ b/tests/Makefile
@@ -1,6 +1,7 @@
# OSS-specific variables and commands
DOCKER_TAG ?= dev
DOCKER_IMAGE_OSS=oss_pytests:${DOCKER_TAG}
+INFRA_PROVIDER ?= AWS
build-oss-ci:
docker build \
@@ -12,6 +13,7 @@ build-oss-ci:
start-docker-oss:
docker run \
--rm \
+ -e INFRA_PROVIDER=${INFRA_PROVIDER} \
-e URL=${URL} \
-e ESM3_FORGE_TOKEN=${ESM3_FORGE_TOKEN} \
--name=$(USER)-oss_pytests \