From 20c025c492699a7ad24247051ea77deab91fd315 Mon Sep 17 00:00:00 2001 From: Aditya Vartak Date: Tue, 14 Dec 2021 15:48:08 +0530 Subject: [PATCH] Added new args to remove some hardcoded paths added args for 1. data_dir: To fetch the dataset.csv from user specified loc 2. gen_csv_dir : directory to save the generated csv into 3. model_weights_dir: directory containing the model weights --- generate/generate.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/generate/generate.py b/generate/generate.py index bf37540..afce12b 100644 --- a/generate/generate.py +++ b/generate/generate.py @@ -35,7 +35,10 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('--model_weight', type=str, help="path of model weights", required=True) + parser.add_argument('--model_weights_dir', type=str, help="directory containing the model weights", required=True) + parser.add_argument('--model_weight', type=str, help="path of model weights in model_weights_root directory", required=True) + parser.add_argument('--data_dir', type=str, help="dataset directory containing the datasets", required=True) + parser.add_argument('--gen_csv_dir', type=str, help="directory to save the generated Moleucles' CSV to ", required=True) parser.add_argument('--scaffold', action='store_true', default=False, help='condition on scaffold') parser.add_argument('--lstm', action='store_true', default=False, help='use lstm for transforming scaffold') parser.add_argument('--csv_name', type=str, help="name to save the generated mols in csv format", required=True) @@ -57,7 +60,7 @@ context = "C" - data = pd.read_csv('/home/viraj.bagal/viraj/ligflow/Code/code/cond_gpt/datasets/' + args.data_name + '.csv') + data = pd.read_csv(args.data_dir + args.data_name + '.csv') data = data.dropna(axis=0).reset_index(drop=True) data.columns = data.columns.str.lower() @@ -125,7 +128,7 @@ model = GPT(mconf) - model.load_state_dict(torch.load('/home/viraj.bagal/viraj/ligflow/Code/code/cond_gpt/weights/' + args.model_weight)) + model.load_state_dict(torch.load(args.model_weights_dir + args.model_weight)) model.to('cuda') print('Model loaded') @@ -458,7 +461,7 @@ results = pd.concat(all_dfs) - results.to_csv('gen_csv_again/' + args.csv_name + '.csv', index = False) + results.to_csv(args.gen_csv_dir + args.csv_name + '.csv', index = False) unique_smiles = list(set(results['smiles'])) canon_smiles = [canonic_smiles(s) for s in results['smiles']]