1313
1414import click
1515import sqlitedict
16+ import numpy as np
1617
1718from profold2 .data .dataset import ProteinStructureDataset
1819from profold2 .data .parsers import parse_fasta , parse_a3m
@@ -615,6 +616,7 @@ def attr_update_weight_and_task(**args):
615616)
616617@click .option ("--mask" , type = str , default = "-" , hidden = True )
617618@click .option ("--task_def" , type = str , default = json .dumps (task .make_def ()), hidden = True )
619+ @click .option ("--task_pid_prefix" , type = str , default = "tcr_pmhc_" , hidden = True )
618620@click .option (
619621 "--chunksize" ,
620622 type = int ,
@@ -642,7 +644,23 @@ def predict(**args):
642644 max_var_depth = None
643645 )
644646
def _parse_result(a3m_string):
    """Yield ``(pid, pred)`` pairs extracted from the description lines of an a3m string.

    The string is handed to ``parse_fasta`` and only the descriptions it
    returns are inspected. Each description is split into fields; the first
    field is taken as the record id and the rest are scanned for a
    ``key:value`` entry whose key is ``Elo_score``.

    Args:
        a3m_string: Text content to parse.

    Yields:
        A ``(pid, pred)`` tuple for every description whose id starts with
        ``args.task_pid_prefix``. ``pred`` is the JSON-decoded value of the
        first ``Elo_score`` field, or a list of ``task.task_num`` ``None``
        entries when no such field is present.
    """
    _, descriptions = parse_fasta(a3m_string)
    for description in descriptions:
        # NOTE(review): the separator is reproduced exactly as it appears
        # here (a tab followed by a space) -- confirm against the writer side.
        pid, *extras = description.split("\t ")
        if not pid.startswith(args.task_pid_prefix):
            continue

        # Placeholder used when the record carries no Elo_score field.
        pred = [None] * task.task_num
        for extra in extras:
            key, sep, value = extra.partition(":")
            # `sep` is empty when the field has no ":" at all.
            if sep and key == "Elo_score":
                pred = json.loads(value)
                break
        yield pid, pred
660+
645661 for model in args .model :
662+ pred_dict = defaultdict (list )
663+
646664 for ref_pkl in glob .glob (os .path .join (args .ref_pkl , f"{ model } _*.pkl" )):
647665 pdb_id = os .path .basename (ref_pkl )
648666 assert pdb_id .startswith (f"{ model } _" )
@@ -663,11 +681,32 @@ def predict(**args):
663681 )
664682 with io .StringIO (a3m_string ) as a3m_file :
665683 setattr (args , "a3m_file" , [a3m_file ])
666- with open (
667- os .path .join (args .output_dir , f"{ model } _{ pid } .a3m" ), "w"
668- ) as output_file :
684+
685+ output_file_path = os .path .join (args .output_dir , f"{ model } _{ pid } .a3m" )
686+ with open ( output_file_path , "w" ) as output_file :
669687 setattr (args , "output_file" , output_file )
670688 energy .main (args )
689+ with open (output_file_path , "r" ) as output_file :
690+ a3m_string = output_file .read ()
691+ for pid , pred in _parse_result (a3m_string ):
692+ pred_dict [pid ].append (pred )
693+
694+ with open (os .path .join (args .output_dir , f"{ model } _pred.csv" ), "w" ) as f :
695+ writer = csv .DictWriter (f , fieldnames = ["id" , "chains" ] + task .task_name_list )
696+ writer .writeheader ()
697+ for pid , pred_list in pred_dict .items ():
698+ chain_list , * _ = data .chain_list [pid ] # FIX: data.get_chain_list(protein_id)
699+ assert chain_list , (pid , pid in data .chain_list , len (data .chain_list ))
700+ _ , pred_mask = task .make_label (0 , chain_list )
701+
702+ pred_list , pred_mask = np .asarray (pred_list ), np .asarray (pred_mask )
703+ pred_list = np .sum (pred_list * pred_mask [None ], axis = 0 ) / pred_list .shape [0 ]
704+
705+ row = {"id" : pid , "chains" : "_" .join (chain_list )}
706+ for idx , (pred , mask ) in enumerate (zip (pred_list , pred_mask )):
707+ if mask :
708+ row [task .task_name_list [idx ]] = pred
709+ writer .writerow (row )
671710
672711
673712if __name__ == "__main__" :
0 commit comments