
Commit 4cee5aa

Remove hardcoded plm checkpoint paths

1 parent 43cc933 commit 4cee5aa

3 files changed: 10 additions & 4 deletions

enzymeexplorer/src/embeddings_extraction/ankh_transformer_utils.py

Lines changed: 3 additions & 1 deletion
@@ -6,10 +6,12 @@
 
 def get_model_and_tokenizer(
     model_name: str,
+    checkpoint_dir: Optional[str] = "data/plm_checkpoints"
 ) -> tuple:
     """
     This function returns bert model and batch converter (basically a tokenizer) based on the model name
     :param model_name: model name
+    :param checkpoint_dir: directory where checkpoints are stored
     :return: a pair of the bert protein model and its tokenizer
     """
     assert model_name in {
@@ -22,7 +24,7 @@ def get_model_and_tokenizer(
     elif model_name == "ankh_tps":
         model, tokenizer = ankh.load_base_model(generation=True)
         model.load_state_dict(
-            torch.load("data/plm_checkpoints/tps_ankh_lr=5e-05_bs=32.pth")[
+            torch.load(f"{checkpoint_dir}/tps_ankh_lr=5e-05_bs=32.pth")[
                 "model_state_dict"
             ],
             strict=False,
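
With this change the Ankh loader accepts a configurable checkpoint directory instead of the hardcoded `data/plm_checkpoints` path. A minimal usage sketch — the import path is inferred from the file's location in the repo, and the directory value is purely illustrative:

    from enzymeexplorer.src.embeddings_extraction.ankh_transformer_utils import (
        get_model_and_tokenizer,
    )

    # Load the base Ankh model and overlay the fine-tuned TPS weights found in
    # a non-default directory (illustrative path, not one shipped with the repo).
    model, tokenizer = get_model_and_tokenizer(
        "ankh_tps", checkpoint_dir="/mnt/checkpoints/plm"
    )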

enzymeexplorer/src/embeddings_extraction/esm_transformer_utils.py

Lines changed: 3 additions & 1 deletion
@@ -24,20 +24,22 @@ def get_model_and_tokenizer(
     model_name: str,
     checkpoint_names: Optional[dict[str, str]] = None,
     return_alphabet: bool = False,
+    checkpoint_dir: Optional[str] = "data/plm_checkpoints",
 ) -> tuple:
     """
     This function returns bert model and batch converter (basically a tokenizer) based on the name
     :param model_name: model name
     :param checkpoint_names: mapping between model name and checkpoint file
     :param return_alphabet: flag to return alphabet object
+    :param checkpoint_dir: directory where checkpoints are stored
     :return: a pair of the bert protein model and its batch converter
     """
     if checkpoint_names is None:
         checkpoint_names = CHECKPOINT_NAMES
     if model_name in checkpoint_names:
         checkpoint_name = checkpoint_names[model_name]
         ckpt = torch.load(
-            f"data/plm_checkpoints/{checkpoint_name}",
+            f"{checkpoint_dir}/{checkpoint_name}",
             map_location=torch.device("cpu"),
         )
         bert_model, bert_alphabet = getattr(esm.pretrained, "esm1v_t33_650M_UR90S_1")()
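
The ESM loader gains the same override. A sketch under the same assumption about the import path; `"esm_model_key"` stands in for a real key of `CHECKPOINT_NAMES`, which this diff does not show:

    from enzymeexplorer.src.embeddings_extraction.esm_transformer_utils import (
        get_model_and_tokenizer,
    )

    # "esm_model_key" is hypothetical and must match an entry in CHECKPOINT_NAMES;
    # the checkpoint directory is likewise illustrative. With return_alphabet=True
    # the alphabet object is returned alongside the model and batch converter.
    model, batch_converter, alphabet = get_model_and_tokenizer(
        "esm_model_key",
        return_alphabet=True,
        checkpoint_dir="/mnt/checkpoints/plm",
    )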

enzymeexplorer/src/screening/tps_predict_fasta.py

Lines changed: 4 additions & 2 deletions
@@ -56,6 +56,7 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument(
         "--ckpt-root-path", type=str, default="data/classifier_checkpoints.pkl"
     )
+    parser.add_argument("--plm_checkpoint_dir", type=str, default="data/plm_checkpoints")
     parser.add_argument("--detection-threshold", type=float, default=0.2)
     parser.add_argument("--detect-precursor-synthases", help="Flag to detect precursor synthases as well. Set to False with `--no-detect-precursor-synthases`.", default=True, action=argparse.BooleanOptionalAction)
     parser.add_argument("--gpu", type=str, default="0")
@@ -94,6 +95,7 @@ def main(args: argparse.Namespace):
         - clf_batch_size: Number of samples processed in each classification batch.
         - output_root: Directory to store prediction outputs.
         - ckpt_root_path: Path to the checkpoint file containing pre-trained classifiers.
+        - plm_checkpoint_dir: Directory where PLM checkpoints are stored.
         - detect_precursor_synthases: Boolean flag to detect precursor synthases.
         - starting_i, end_i: Range of indices to process sequences.
         - gpu: GPU identifier for processing sequences.
@@ -105,7 +107,7 @@
 
     if "esm" in args.model:
         model, batch_converter, alphabet = get_model_and_tokenizer(
-            args.model, return_alphabet=True
+            args.model, return_alphabet=True, checkpoint_dir=args.plm_checkpoint_dir
         )
 
         compute_embeddings_partial = partial(
@@ -119,7 +121,7 @@
     elif "ankh" in args.model:
         model, tokenizer = ankh_get_model_and_tokenizer(args.model)
         compute_embeddings_partial = partial(
-            ankh_compute_embeddings, bert_model=model, tokenizer=tokenizer
+            ankh_compute_embeddings, bert_model=model, tokenizer=tokenizer, checkpoint_dir=args.plm_checkpoint_dir
         )
     else:
         raise NotImplementedError(
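
At the script level the directory is now exposed as a command-line flag. An illustrative invocation using only the flags visible in this diff; the model-selection and FASTA-input arguments that `main` evidently also expects are omitted because their spellings do not appear here:

    python enzymeexplorer/src/screening/tps_predict_fasta.py \
        --plm_checkpoint_dir /mnt/checkpoints/plm \
        --detection-threshold 0.2 \
        --gpu 0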
