@@ -56,6 +56,7 @@ def parse_args() -> argparse.Namespace:
5656 parser .add_argument (
5757 "--ckpt-root-path" , type = str , default = "data/classifier_checkpoints.pkl"
5858 )
59+ parser .add_argument ("--plm_checkpoint_dir" , type = str , default = "data/plm_checkpoints" )
5960 parser .add_argument ("--detection-threshold" , type = float , default = 0.2 )
6061 parser .add_argument ("--detect-precursor-synthases" , help = "Flag to detect precursor synthases as well. Set to False with `--no-detect-precursor-synthases`." , default = True , action = argparse .BooleanOptionalAction )
6162 parser .add_argument ("--gpu" , type = str , default = "0" )
@@ -94,6 +95,7 @@ def main(args: argparse.Namespace):
9495 - clf_batch_size: Number of samples processed in each classification batch.
9596 - output_root: Directory to store prediction outputs.
9697 - ckpt_root_path: Path to the checkpoint file containing pre-trained classifiers.
98+ - plm_checkpoint_dir: Directory where PLM checkpoints are stored.
9799 - detect_precursor_synthases: Boolean flag to detect precursor synthases.
98100 - starting_i, end_i: Range of indices to process sequences.
99101 - gpu: GPU identifier for processing sequences.
@@ -105,7 +107,7 @@ def main(args: argparse.Namespace):
105107
106108 if "esm" in args .model :
107109 model , batch_converter , alphabet = get_model_and_tokenizer (
108- args .model , return_alphabet = True
110+ args .model , return_alphabet = True , checkpoint_dir = args . plm_checkpoint_dir
109111 )
110112
111113 compute_embeddings_partial = partial (
@@ -119,7 +121,7 @@ def main(args: argparse.Namespace):
119121 elif "ankh" in args .model :
120122 model , tokenizer = ankh_get_model_and_tokenizer (args .model )
121123 compute_embeddings_partial = partial (
122- ankh_compute_embeddings , bert_model = model , tokenizer = tokenizer
124+ ankh_compute_embeddings , bert_model = model , tokenizer = tokenizer , checkpoint_dir = args . plm_checkpoint_dir
123125 )
124126 else :
125127 raise NotImplementedError (
0 commit comments