From 02378a0cb49f678ca5e0d01303d22184a93c0ac5 Mon Sep 17 00:00:00 2001 From: Xiaorui Dong Date: Tue, 18 Apr 2023 12:44:01 -0400 Subject: [PATCH 01/28] Add a setup file for the project --- setup.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 setup.py diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..8ab68af --- /dev/null +++ b/setup.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +#-*- coding: utf-8 -*- + +from setuptools import setup, find_packages + +with open("README.md", "r", encoding="utf-8") as fh: + long_description = fh.read() + +setup( + name="GeomMol", + version="1.0.0", + author="Lagnajit Pattanaik", + author_email="lagnajit@mit.com", + description="Machine learning tools for molecule conformer generation", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/PattanaikL/GeoMol", + packages=find_packages(), + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Topic :: Scientific/Engineering :: Chemistry" + ], + license="MIT License", + python_requires='>=3.7', +) From 3079937efbc7992f5fd6d3683680dbb916934006 Mon Sep 17 00:00:00 2001 From: Xiaorui Dong Date: Tue, 18 Apr 2023 12:46:02 -0400 Subject: [PATCH 02/28] Utilize unused imports in featurization --- model/featurization.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/model/featurization.py b/model/featurization.py index 786d137..56dd389 100644 --- a/model/featurization.py +++ b/model/featurization.py @@ -153,11 +153,11 @@ def featurize_mol(self, mol_dic): 1 if atom.GetIsAromatic() else 0]) atom_features.extend(one_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6])) atom_features.extend(one_k_encoding(atom.GetHybridization(), [ - Chem.rdchem.HybridizationType.SP, - Chem.rdchem.HybridizationType.SP2, - Chem.rdchem.HybridizationType.SP3, - Chem.rdchem.HybridizationType.SP3D, - Chem.rdchem.HybridizationType.SP3D2])) + HybridizationType.SP, + HybridizationType.SP2, + HybridizationType.SP3, + HybridizationType.SP3D, + HybridizationType.SP3D2])) atom_features.extend(one_k_encoding(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6])) atom_features.extend(one_k_encoding(atom.GetFormalCharge(), [-1, 0, 1])) atom_features.extend([int(ring.IsAtomInRingOfSize(i, 3)), From 9563cbe3513362458a173d2b14798e58495123d0 Mon Sep 17 00:00:00 2001 From: Xiaorui Dong Date: Tue, 18 Apr 2023 12:47:42 -0400 Subject: [PATCH 03/28] clean up imports in featurization --- model/featurization.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/model/featurization.py b/model/featurization.py index 56dd389..2224a2b 100644 --- a/model/featurization.py +++ b/model/featurization.py @@ -1,18 +1,18 @@ from rdkit import Chem -from rdkit.Chem.rdchem import HybridizationType +from rdkit.Chem.rdchem import ChiralType, HybridizationType from rdkit.Chem.rdchem import BondType as BT -from rdkit.Chem.rdchem import ChiralType -import os.path as osp -import numpy as np import glob +import os.path as osp import pickle import random +import numpy as np import torch import torch.nn.functional as F from torch_scatter import scatter from torch_geometric.data import Dataset, Data, DataLoader + from model.utils import get_dihedral_pairs dihedral_pattern = Chem.MolFromSmarts('[*]~[*]~[*]~[*]') From 79c122ff7da58d83bc82f058de8eba19bd05cb44 Mon Sep 17 00:00:00 2001 From: Xiaorui Dong Date: Tue, 18 Apr 2023 12:49:06 -0400 Subject: [PATCH 04/28] Move constant variables to the top in featurization --- model/featurization.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/model/featurization.py b/model/featurization.py index 2224a2b..f5ea03c 100644 --- a/model/featurization.py +++ b/model/featurization.py @@ -15,11 +15,21 @@ from model.utils import get_dihedral_pairs -dihedral_pattern = Chem.MolFromSmarts('[*]~[*]~[*]~[*]') + +bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3} chirality = {ChiralType.CHI_TETRAHEDRAL_CW: -1., ChiralType.CHI_TETRAHEDRAL_CCW: 1., ChiralType.CHI_UNSPECIFIED: 0, ChiralType.CHI_OTHER: 0} +dihedral_pattern = Chem.MolFromSmarts('[*]~[*]~[*]~[*]') + +qm9_types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4} +drugs_types = {'H': 0, 'Li': 1, 'B': 2, 'C': 3, 'N': 4, 'O': 5, 'F': 6, 'Na': 7, 'Mg': 8, 'Al': 9, 'Si': 10, + 'P': 11, 'S': 12, 'Cl': 13, 'K': 14, 'Ca': 15, 'V': 16, 'Cr': 17, 'Mn': 18, 'Cu': 19, 'Zn': 20, + 'Ga': 21, 'Ge': 22, 'As': 23, 'Se': 24, 'Br': 25, 'Ag': 26, 'In': 27, 'Sb': 28, 'I': 29, 'Gd': 30, + 'Pt': 31, 'Au': 32, 'Hg': 33, 'Bi': 34} + + def one_k_encoding(value, choices): @@ -247,14 +257,6 @@ def construct_loader(args, modes=('train', 'val')): return loaders -bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3} -qm9_types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4} -drugs_types = {'H': 0, 'Li': 1, 'B': 2, 'C': 3, 'N': 4, 'O': 5, 'F': 6, 'Na': 7, 'Mg': 8, 'Al': 9, 'Si': 10, - 'P': 11, 'S': 12, 'Cl': 13, 'K': 14, 'Ca': 15, 'V': 16, 'Cr': 17, 'Mn': 18, 'Cu': 19, 'Zn': 20, - 'Ga': 21, 'Ge': 22, 'As': 23, 'Se': 24, 'Br': 25, 'Ag': 26, 'In': 27, 'Sb': 28, 'I': 29, 'Gd': 30, - 'Pt': 31, 'Au': 32, 'Hg': 33, 'Bi': 34} - - def featurize_mol_from_smiles(smiles, dataset='qm9'): if dataset == 'qm9': From e5959c16acd2a45ae93329a960f13246966ae29d Mon Sep 17 00:00:00 2001 From: Xiaorui Dong Date: Tue, 18 Apr 2023 12:55:53 -0400 Subject: [PATCH 05/28] Simplify dataset definition and reference --- model/featurization.py | 51 +++++++++++++++++++----------------------- 1 file changed, 23 insertions(+), 28 deletions(-) diff --git a/model/featurization.py b/model/featurization.py index f5ea03c..9705662 100644 --- a/model/featurization.py +++ b/model/featurization.py @@ -28,8 +28,7 @@ 'P': 11, 'S': 12, 'Cl': 13, 'K': 14, 'Ca': 15, 'V': 16, 'Cr': 17, 'Mn': 18, 'Cu': 19, 'Zn': 20, 'Ga': 21, 'Ge': 22, 'As': 23, 'Se': 24, 'Br': 25, 'Ag': 26, 'In': 27, 'Sb': 28, 'I': 29, 'Gd': 30, 'Pt': 31, 'Au': 32, 'Hg': 33, 'Bi': 34} - - +dataset_types = {'qm9': qm9_types, 'drugs': drugs_types} def one_k_encoding(value, choices): @@ -48,24 +47,29 @@ def one_k_encoding(value, choices): class geom_confs(Dataset): - def __init__(self, root, split_path, mode, transform=None, pre_transform=None, max_confs=10): - super(geom_confs, self).__init__(root, transform, pre_transform) + + dataset = '' + + def __init__(self, + root, + split_path, + mode, + transform=None, + pre_transform=None, + max_confs=10): + super().__init__(root, transform, pre_transform) self.root = root self.split_idx = 0 if mode == 'train' else 1 if mode == 'val' else 2 self.split = np.load(split_path, allow_pickle=True)[self.split_idx] - self.bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3} - - # try: - # with open(osp.join(self.root, 'all_data.pickle'), 'rb') as f: - # data_dict = pickle.load(f) - # smiles = [list(data_dict)[i] for i in self.split] - # self.pickle_files = [data_dict[smi] for smi in smiles] - # except FileNotFoundError: - self.dihedral_pairs = {} # for memoization + self.bonds = bonds + + self.dihedral_pairs = {} # for memoization all_files = sorted(glob.glob(osp.join(self.root, '*.pickle'))) - self.pickle_files = [f for i, f in enumerate(all_files) if i in self.split] + self.pickle_files = [f for i, f in enumerate(all_files) + if i in self.split] self.max_confs = max_confs + self.types = dataset_types[self.dataset] def len(self): # return len(self.pickle_files) # should we change this to an integer for random sampling? @@ -217,20 +221,14 @@ def featurize_mol(self, mol_dic): return data - class qm9_confs(geom_confs): - def __init__(self, root, split_path, mode, transform=None, pre_transform=None, max_confs=10): - super(qm9_confs, self).__init__(root, split_path, mode, transform, pre_transform, max_confs) - self.types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4} + + dataset = 'qm9' class drugs_confs(geom_confs): - def __init__(self, root, split_path, mode, transform=None, pre_transform=None, max_confs=10): - super(drugs_confs, self).__init__(root, split_path, mode, transform, pre_transform, max_confs) - self.types = {'H': 0, 'Li': 1, 'B': 2, 'C': 3, 'N': 4, 'O': 5, 'F': 6, 'Na': 7, 'Mg': 8, 'Al': 9, 'Si': 10, - 'P': 11, 'S': 12, 'Cl': 13, 'K': 14, 'Ca': 15, 'V': 16, 'Cr': 17, 'Mn': 18, 'Cu': 19, 'Zn': 20, - 'Ga': 21, 'Ge': 22, 'As': 23, 'Se': 24, 'Br': 25, 'Ag': 26, 'In': 27, 'Sb': 28, 'I': 29, 'Gd': 30, - 'Pt': 31, 'Au': 32, 'Hg': 33, 'Bi': 34} + + dataset = 'drugs' def construct_loader(args, modes=('train', 'val')): @@ -259,10 +257,7 @@ def construct_loader(args, modes=('train', 'val')): def featurize_mol_from_smiles(smiles, dataset='qm9'): - if dataset == 'qm9': - types = qm9_types - elif dataset == 'drugs': - types = drugs_types + types = dataset_types[dataset] # filter fragments if '.' in smiles: From 2f1dbdb0869aa4ba98fc4dd03a2910ff1317e45d Mon Sep 17 00:00:00 2001 From: Xiaorui Dong Date: Tue, 18 Apr 2023 12:59:04 -0400 Subject: [PATCH 06/28] Add two convenient function check_mol and smiles_to_mol --- model/featurization.py | 43 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/model/featurization.py b/model/featurization.py index 9705662..79a92b0 100644 --- a/model/featurization.py +++ b/model/featurization.py @@ -348,3 +348,46 @@ def featurize_mol_from_smiles(smiles, dataset='qm9'): data.edge_index_dihedral_pairs = get_dihedral_pairs(data.edge_index, data=data) return data + + +def smiles_to_mol(smiles: str, + check_mol: bool = True): + """ + Convert a SMILES string to a RDKit molecule. + """ + try: + mol = Chem.AddHs(Chem.MolFromSmiles(smiles)) + except Exception: + return None + if check_mol: + return _check_mol(mol, smiles=smiles) + return mol + + +def _check_mol(mol, + smiles=None): + """ + Check if a molecule is valid. + """ + # filter fragments + if smiles is not None: + if '.' in smiles: + return None + else: + frags = Chem.rdmolops.GetMolFrags(mol, + asMols=False) + if len(frags) > 1: + return None + + # filter out mols model can't make predictions for + if mol.GetNumAtoms() < 4: + return None + if mol.GetNumBonds() < 4: + # in Lucky' original implementation + # this criteria is included in geom_confs.featurize_mol + # but not included in featurize_mol_from_smiles + # add it here anyway + return None + if not mol.HasSubstructMatch(dihedral_pattern): + return None + return mol From 9af9a8d3c45f64ad67dc137651b6c748f910b609 Mon Sep 17 00:00:00 2001 From: Xiaorui Dong Date: Tue, 18 Apr 2023 13:11:52 -0400 Subject: [PATCH 07/28] Add a helper function to create necessary features --- model/featurization.py | 85 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) diff --git a/model/featurization.py b/model/featurization.py index 79a92b0..106c895 100644 --- a/model/featurization.py +++ b/model/featurization.py @@ -391,3 +391,88 @@ def _check_mol(mol, if not mol.HasSubstructMatch(dihedral_pattern): return None return mol + + +def _mol_to_features(mol, + dataset: str = 'qm9'): + """ + Prepare necessary information for converting a RDKit mol object to a torch_geometry_data object. + """ + types = dataset_types[dataset] + + type_idx = [] + atomic_number = [] + atom_features = [] + chiral_tag = [] + neighbor_dict = {} + ring = mol.GetRingInfo() + + n_atom = mol.GetNumAtoms() + # Atomic features + for i, atom in enumerate(mol.GetAtoms()): + type_idx.append(types[atom.GetSymbol()]) + if len(atom.GetNeighbors()) > 1: + n_ids = [n.GetIdx() for n in atom.GetNeighbors()] + neighbor_dict[i] = torch.tensor(n_ids) + chiral_tag.append(chirality[atom.GetChiralTag()]) + atomic_number.append(atom.GetAtomicNum()) + atom_features.extend([atom.GetAtomicNum(), + 1 if atom.GetIsAromatic() else 0]) + atom_features.extend(one_k_encoding( + atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6])) + atom_features.extend(one_k_encoding(atom.GetHybridization(), [ + HybridizationType.SP, + HybridizationType.SP2, + HybridizationType.SP3, + HybridizationType.SP3D, + HybridizationType.SP3D2])) + atom_features.extend(one_k_encoding( + atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6])) + atom_features.extend(one_k_encoding( + atom.GetFormalCharge(), [-1, 0, 1])) + atom_features.extend([int(ring.IsAtomInRingOfSize(i, 3)), + int(ring.IsAtomInRingOfSize(i, 4)), + int(ring.IsAtomInRingOfSize(i, 5)), + int(ring.IsAtomInRingOfSize(i, 6)), + int(ring.IsAtomInRingOfSize(i, 7)), + int(ring.IsAtomInRingOfSize(i, 8))]) + atom_features.extend(one_k_encoding( + int(ring.NumAtomRings(i)), [0, 1, 2, 3])) + + z = torch.tensor(atomic_number, dtype=torch.long) + chiral_tag = torch.tensor(chiral_tag, dtype=torch.float) + + # Edge features + row, col, edge_type, bond_features = [], [], [], [] + for bond in mol.GetBonds(): + start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() + row += [start, end] + col += [end, start] + edge_type += 2 * [bonds[bond.GetBondType()]] + bt = tuple( + sorted( + [bond.GetBeginAtom().GetAtomicNum(), + bond.GetEndAtom().GetAtomicNum()] + )), bond.GetBondTypeAsDouble() + bond_features += 2 * [int(bond.IsInRing()), + int(bond.GetIsConjugated()), + int(bond.GetIsAromatic())] + + edge_index = torch.tensor([row, col], dtype=torch.long) + edge_type = torch.tensor(edge_type, dtype=torch.long) + edge_attr = F.one_hot(edge_type, num_classes=len(bonds)).to(torch.float) + + perm = (edge_index[0] * n_atom + edge_index[1]).argsort() + edge_index = edge_index[:, perm] + edge_type = edge_type[perm] + edge_attr = edge_attr[perm] + + row, col = edge_index + hs = (z == 1).to(torch.float) + num_hs = scatter(hs[row], col, dim_size=n_atom).tolist() + + x1 = F.one_hot(torch.tensor(type_idx), num_classes=len(types)) + x2 = torch.tensor(atom_features).view(n_atom, -1) + x = torch.cat([x1.to(torch.float), x2], dim=-1) + + return x, z, edge_index, edge_attr, neighbor_dict, chiral_tag From 2f51656ce3bd55abe8636930f62d87c76cf19ea4 Mon Sep 17 00:00:00 2001 From: Xiaorui Dong Date: Tue, 18 Apr 2023 13:14:57 -0400 Subject: [PATCH 08/28] Add a function to featurize molecule given a RDKit Mol object This can be useful for other softwares using rdkit --- model/featurization.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/model/featurization.py b/model/featurization.py index 106c895..6e29454 100644 --- a/model/featurization.py +++ b/model/featurization.py @@ -6,6 +6,7 @@ import os.path as osp import pickle import random +from typing import Optional import numpy as np import torch @@ -476,3 +477,27 @@ def _mol_to_features(mol, x = torch.cat([x1.to(torch.float), x2], dim=-1) return x, z, edge_index, edge_attr, neighbor_dict, chiral_tag + + +def featurize_mol(mol, + dataset: str = 'qm9', + smiles: Optional[str] = None, + name: str = ''): + """ + Featurize a molecule. + """ + mol = _check_mol(mol, smiles=smiles) + name = smiles if (smiles and not name) else name + + if mol: + x, _, edge_index, edge_attr, neighbor_dict, chiral_tag \ + = _mol_to_features(mol, dataset=dataset) + data = Data(x=x, + edge_index=edge_index, + edge_attr=edge_attr, + neighbors=neighbor_dict, + chiral_tag=chiral_tag, + name=name) + data.edge_index_dihedral_pairs = get_dihedral_pairs( + data.edge_index, + data=data) From aa04db4381a67ca3fe8305de5ab9bdead023fb6b Mon Sep 17 00:00:00 2001 From: Xiaorui Dong Date: Tue, 18 Apr 2023 13:15:43 -0400 Subject: [PATCH 09/28] Bugfix: featurize mol missing returns --- model/featurization.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/model/featurization.py b/model/featurization.py index 6e29454..ebfcd88 100644 --- a/model/featurization.py +++ b/model/featurization.py @@ -499,5 +499,6 @@ def featurize_mol(mol, chiral_tag=chiral_tag, name=name) data.edge_index_dihedral_pairs = get_dihedral_pairs( - data.edge_index, - data=data) + data.edge_index, + data=data) + return data From 1616e5396e9921108049865b5eb653122b9c71a5 Mon Sep 17 00:00:00 2001 From: Xiaorui Dong Date: Tue, 18 Apr 2023 13:18:03 -0400 Subject: [PATCH 10/28] Simplify featurize_mol_from_smiles --- model/featurization.py | 107 +++++------------------------------------ 1 file changed, 12 insertions(+), 95 deletions(-) diff --git a/model/featurization.py b/model/featurization.py index ebfcd88..ef8a0aa 100644 --- a/model/featurization.py +++ b/model/featurization.py @@ -256,101 +256,6 @@ def construct_loader(args, modes=('train', 'val')): return loaders -def featurize_mol_from_smiles(smiles, dataset='qm9'): - - types = dataset_types[dataset] - - # filter fragments - if '.' in smiles: - return None - - # filter mols rdkit can't intrinsically handle - mol = Chem.MolFromSmiles(smiles) - if mol: - mol = Chem.AddHs(mol) - else: - return None - N = mol.GetNumAtoms() - - # filter out mols model can't make predictions for - if not mol.HasSubstructMatch(dihedral_pattern): - return None - if N < 4: - return None - - type_idx = [] - atomic_number = [] - atom_features = [] - chiral_tag = [] - neighbor_dict = {} - ring = mol.GetRingInfo() - for i, atom in enumerate(mol.GetAtoms()): - type_idx.append(types[atom.GetSymbol()]) - n_ids = [n.GetIdx() for n in atom.GetNeighbors()] - if len(n_ids) > 1: - neighbor_dict[i] = torch.tensor(n_ids) - chiral_tag.append(chirality[atom.GetChiralTag()]) - atomic_number.append(atom.GetAtomicNum()) - atom_features.extend([atom.GetAtomicNum(), - 1 if atom.GetIsAromatic() else 0]) - atom_features.extend(one_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6])) - atom_features.extend(one_k_encoding(atom.GetHybridization(), [ - Chem.rdchem.HybridizationType.SP, - Chem.rdchem.HybridizationType.SP2, - Chem.rdchem.HybridizationType.SP3, - Chem.rdchem.HybridizationType.SP3D, - Chem.rdchem.HybridizationType.SP3D2])) - atom_features.extend(one_k_encoding(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6])) - atom_features.extend(one_k_encoding(atom.GetFormalCharge(), [-1, 0, 1])) - atom_features.extend([int(ring.IsAtomInRingOfSize(i, 3)), - int(ring.IsAtomInRingOfSize(i, 4)), - int(ring.IsAtomInRingOfSize(i, 5)), - int(ring.IsAtomInRingOfSize(i, 6)), - int(ring.IsAtomInRingOfSize(i, 7)), - int(ring.IsAtomInRingOfSize(i, 8))]) - atom_features.extend(one_k_encoding(int(ring.NumAtomRings(i)), [0, 1, 2, 3])) - - z = torch.tensor(atomic_number, dtype=torch.long) - chiral_tag = torch.tensor(chiral_tag, dtype=torch.float) - - row, col, edge_type, bond_features = [], [], [], [] - for bond in mol.GetBonds(): - start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() - row += [start, end] - col += [end, start] - edge_type += 2 * [bonds[bond.GetBondType()]] - bt = tuple( - sorted([bond.GetBeginAtom().GetAtomicNum(), bond.GetEndAtom().GetAtomicNum()])), bond.GetBondTypeAsDouble() - bond_features += 2 * [int(bond.IsInRing()), - int(bond.GetIsConjugated()), - int(bond.GetIsAromatic())] - - edge_index = torch.tensor([row, col], dtype=torch.long) - edge_type = torch.tensor(edge_type, dtype=torch.long) - edge_attr = F.one_hot(edge_type, num_classes=len(bonds)).to(torch.float) - # bond_features = torch.tensor(bond_features, dtype=torch.float).view(len(bond_type), -1) - - perm = (edge_index[0] * N + edge_index[1]).argsort() - edge_index = edge_index[:, perm] - edge_type = edge_type[perm] - # edge_attr = torch.cat([edge_attr[perm], bond_features], dim=-1) - edge_attr = edge_attr[perm] - - row, col = edge_index - hs = (z == 1).to(torch.float) - num_hs = scatter(hs[row], col, dim_size=N).tolist() - - x1 = F.one_hot(torch.tensor(type_idx), num_classes=len(types)) - x2 = torch.tensor(atom_features).view(N, -1) - x = torch.cat([x1.to(torch.float), x2], dim=-1) - - data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, neighbors=neighbor_dict, chiral_tag=chiral_tag, - name=smiles) - data.edge_index_dihedral_pairs = get_dihedral_pairs(data.edge_index, data=data) - - return data - - def smiles_to_mol(smiles: str, check_mol: bool = True): """ @@ -502,3 +407,15 @@ def featurize_mol(mol, data.edge_index, data=data) return data + + +def featurize_mol_from_smiles(smiles: str, + dataset='qm9'): + """ + Featurize a molecule from a SMILES string. + """ + mol = smiles_to_mol(smiles, check_mol=True) + if mol: + return featurize_mol(mol, + dataset=dataset, + name=smiles) From ae796e54af91ebf74b973c4ae3d3970e2fefea14 Mon Sep 17 00:00:00 2001 From: Xiaorui Dong Date: Tue, 18 Apr 2023 13:23:48 -0400 Subject: [PATCH 11/28] Simplify featurize_mol method of geom_confs --- model/featurization.py | 105 ++++++++--------------------------------- 1 file changed, 19 insertions(+), 86 deletions(-) diff --git a/model/featurization.py b/model/featurization.py index ef8a0aa..00e7a8b 100644 --- a/model/featurization.py +++ b/model/featurization.py @@ -96,31 +96,21 @@ def open_pickle(self, mol_path): return dic def featurize_mol(self, mol_dic): - confs = mol_dic['conformers'] + confs, name = mol_dic['conformers'], mol_dic["smiles"] random.shuffle(confs) # shuffle confs - name = mol_dic["smiles"] # filter mols rdkit can't intrinsically handle - mol_ = Chem.MolFromSmiles(name) - if mol_: - canonical_smi = Chem.MolToSmiles(mol_) - else: - return None - - # skip conformers with fragments - if '.' in name: + try: + canonical_smi = Chem.MolToSmiles(Chem.MolFromSmiles(name)) + except Exception: return None # skip conformers without dihedrals - N = confs[0]['rd_mol'].GetNumAtoms() - if N < 4: - return None - if confs[0]['rd_mol'].GetNumBonds() < 4: - return None - if not confs[0]['rd_mol'].HasSubstructMatch(dihedral_pattern): + if _check_mol(confs[0]['rd_mol'], smiles=name) is None: return None - pos = torch.zeros([self.max_confs, N, 3]) + n_atom = confs[0]['rd_mol'].GetNumAtoms() + pos = torch.zeros([self.max_confs, n_atom, 3]) pos_mask = torch.zeros(self.max_confs, dtype=torch.int64) k = 0 for conf in confs: @@ -134,7 +124,7 @@ def featurize_mol(self, mol_dic): # filter for conformers that may have reacted try: conf_canonical_smi = Chem.MolToSmiles(Chem.RemoveHs(mol)) - except Exception as e: + except Exception: continue if conf_canonical_smi != canonical_smi: @@ -151,74 +141,17 @@ def featurize_mol(self, mol_dic): if k == 0: return None - type_idx = [] - atomic_number = [] - atom_features = [] - chiral_tag = [] - neighbor_dict = {} - ring = correct_mol.GetRingInfo() - for i, atom in enumerate(correct_mol.GetAtoms()): - type_idx.append(self.types[atom.GetSymbol()]) - n_ids = [n.GetIdx() for n in atom.GetNeighbors()] - if len(n_ids) > 1: - neighbor_dict[i] = torch.tensor(n_ids) - chiral_tag.append(chirality[atom.GetChiralTag()]) - atomic_number.append(atom.GetAtomicNum()) - atom_features.extend([atom.GetAtomicNum(), - 1 if atom.GetIsAromatic() else 0]) - atom_features.extend(one_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6])) - atom_features.extend(one_k_encoding(atom.GetHybridization(), [ - HybridizationType.SP, - HybridizationType.SP2, - HybridizationType.SP3, - HybridizationType.SP3D, - HybridizationType.SP3D2])) - atom_features.extend(one_k_encoding(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6])) - atom_features.extend(one_k_encoding(atom.GetFormalCharge(), [-1, 0, 1])) - atom_features.extend([int(ring.IsAtomInRingOfSize(i, 3)), - int(ring.IsAtomInRingOfSize(i, 4)), - int(ring.IsAtomInRingOfSize(i, 5)), - int(ring.IsAtomInRingOfSize(i, 6)), - int(ring.IsAtomInRingOfSize(i, 7)), - int(ring.IsAtomInRingOfSize(i, 8))]) - atom_features.extend(one_k_encoding(int(ring.NumAtomRings(i)), [0, 1, 2, 3])) - - z = torch.tensor(atomic_number, dtype=torch.long) - chiral_tag = torch.tensor(chiral_tag, dtype=torch.float) - - row, col, edge_type, bond_features = [], [], [], [] - for bond in correct_mol.GetBonds(): - start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() - row += [start, end] - col += [end, start] - edge_type += 2 * [self.bonds[bond.GetBondType()]] - bt = tuple(sorted([bond.GetBeginAtom().GetAtomicNum(), bond.GetEndAtom().GetAtomicNum()])), bond.GetBondTypeAsDouble() - bond_features += 2 * [int(bond.IsInRing()), - int(bond.GetIsConjugated()), - int(bond.GetIsAromatic())] - - edge_index = torch.tensor([row, col], dtype=torch.long) - edge_type = torch.tensor(edge_type, dtype=torch.long) - edge_attr = F.one_hot(edge_type, num_classes=len(self.bonds)).to(torch.float) - # bond_features = torch.tensor(bond_features, dtype=torch.float).view(len(bond_type), -1) - - perm = (edge_index[0] * N + edge_index[1]).argsort() - edge_index = edge_index[:, perm] - edge_type = edge_type[perm] - # edge_attr = torch.cat([edge_attr[perm], bond_features], dim=-1) - edge_attr = edge_attr[perm] - - row, col = edge_index - hs = (z == 1).to(torch.float) - num_hs = scatter(hs[row], col, dim_size=N).tolist() - - x1 = F.one_hot(torch.tensor(type_idx), num_classes=len(self.types)) - x2 = torch.tensor(atom_features).view(N, -1) - x = torch.cat([x1.to(torch.float), x2], dim=-1) - - data = Data(x=x, z=z, pos=[pos], edge_index=edge_index, edge_attr=edge_attr, neighbors=neighbor_dict, - chiral_tag=chiral_tag, name=name, boltzmann_weight=conf['boltzmannweight'], - degeneracy=conf['degeneracy'], mol=correct_mol, pos_mask=pos_mask) + x, z, edge_index, edge_attr, neighbor_dict, chiral_tag \ + = _mol_to_features(correct_mol, self.dataset) + + data = Data(x=x, z=z, pos=[pos], + edge_index=edge_index, edge_attr=edge_attr, + neighbors=neighbor_dict, + chiral_tag=chiral_tag, + name=name, mol=correct_mol, + boltzmann_weight=conf['boltzmannweight'], + degeneracy=conf['degeneracy'], + pos_mask=pos_mask) return data From f341404fe69a24a3c64379415c9af42daafb0526 Mon Sep 17 00:00:00 2001 From: Xiaorui Dong Date: Tue, 18 Apr 2023 13:37:03 -0400 Subject: [PATCH 12/28] Add from_data_list to avoid errors in neighbors in pyg 2 --- model/featurization.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/model/featurization.py b/model/featurization.py index 00e7a8b..8bd36a6 100644 --- a/model/featurization.py +++ b/model/featurization.py @@ -7,15 +7,18 @@ import pickle import random from typing import Optional +from packaging import version import numpy as np import torch import torch.nn.functional as F +import torch_geometric as tg +from torch_geometric.data import Batch, Data, DataLoader, Dataset from torch_scatter import scatter -from torch_geometric.data import Dataset, Data, DataLoader from model.utils import get_dihedral_pairs +tg_version_ge_2 = version.parse(tg.__version__) > version.parse('2.0.0') bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3} chirality = {ChiralType.CHI_TETRAHEDRAL_CW: -1., @@ -352,3 +355,22 @@ def featurize_mol_from_smiles(smiles: str, return featurize_mol(mol, dataset=dataset, name=smiles) + + +def from_data_list(data_list: list): + """ + Creates a batch object from a list of data objects. This is useful for inference with an improvisational list of features from different molecules. + This function is a wrapper for the torch_geometric function Batch.from_data_list + with a special treatment for the neighbors attribute. If without the + treatment, neighbors will be collapsed into a single dict and only have keys in the + first elements, causing an error raised in "get_neighbor_ids". + + It has only been tested and applied for torch_geometric over version 2.0.0. + """ + if tg_version_ge_2: + batch_data = Batch.from_data_list(data_list, + exclude_keys=['neighbors']) + batch_data.neighbors = [d.neighbors for d in data_list] + else: + batch_data = Batch.from_data_list(data_list) + return batch_data From 398897b19a89033c19630db7e7dd33f03d1115b6 Mon Sep 17 00:00:00 2001 From: Xiaorui Dong Date: Tue, 18 Apr 2023 14:22:51 -0400 Subject: [PATCH 13/28] Reame model to geomol --- {model => geomol}/GNN.py | 0 {model => geomol}/__init__.py | 0 {model => geomol}/cycle_utils.py | 0 {model => geomol}/featurization.py | 0 {model => geomol}/inference.py | 0 {model => geomol}/model.py | 0 {model => geomol}/parsing.py | 0 {model => geomol}/training.py | 0 {model => geomol}/utils.py | 0 setup.py | 2 +- train.py | 8 ++++---- utils.py | 2 +- 12 files changed, 6 insertions(+), 6 deletions(-) rename {model => geomol}/GNN.py (100%) rename {model => geomol}/__init__.py (100%) rename {model => geomol}/cycle_utils.py (100%) rename {model => geomol}/featurization.py (100%) rename {model => geomol}/inference.py (100%) rename {model => geomol}/model.py (100%) rename {model => geomol}/parsing.py (100%) rename {model => geomol}/training.py (100%) rename {model => geomol}/utils.py (100%) diff --git a/model/GNN.py b/geomol/GNN.py similarity index 100% rename from model/GNN.py rename to geomol/GNN.py diff --git a/model/__init__.py b/geomol/__init__.py similarity index 100% rename from model/__init__.py rename to geomol/__init__.py diff --git a/model/cycle_utils.py b/geomol/cycle_utils.py similarity index 100% rename from model/cycle_utils.py rename to geomol/cycle_utils.py diff --git a/model/featurization.py b/geomol/featurization.py similarity index 100% rename from model/featurization.py rename to geomol/featurization.py diff --git a/model/inference.py b/geomol/inference.py similarity index 100% rename from model/inference.py rename to geomol/inference.py diff --git a/model/model.py b/geomol/model.py similarity index 100% rename from model/model.py rename to geomol/model.py diff --git a/model/parsing.py b/geomol/parsing.py similarity index 100% rename from model/parsing.py rename to geomol/parsing.py diff --git a/model/training.py b/geomol/training.py similarity index 100% rename from model/training.py rename to geomol/training.py diff --git a/model/utils.py b/geomol/utils.py similarity index 100% rename from model/utils.py rename to geomol/utils.py diff --git a/setup.py b/setup.py index 8ab68af..94ee2fc 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ long_description = fh.read() setup( - name="GeomMol", + name="GeoMol", version="1.0.0", author="Lagnajit Pattanaik", author_email="lagnajit@mit.com", diff --git a/train.py b/train.py index 36d87af..5aab533 100644 --- a/train.py +++ b/train.py @@ -6,11 +6,11 @@ import numpy as np import random -from model.model import GeoMol -from model.training import train, test, NoamLR +from geomol.model import GeoMol +from geomol.training import train, test, NoamLR from utils import create_logger, dict_to_str, plot_train_val_loss, save_yaml_file, get_optimizer_and_scheduler -from model.featurization import construct_loader -from model.parsing import parse_train_args, set_hyperparams +from geomol.featurization import construct_loader +from geomol.parsing import parse_train_args, set_hyperparams from torch.utils.tensorboard import SummaryWriter import resource diff --git a/utils.py b/utils.py index 7e5c894..c6fbc16 100644 --- a/utils.py +++ b/utils.py @@ -7,7 +7,7 @@ import yaml import torch -from model.training import build_lr_scheduler +from geomol.training import build_lr_scheduler sns.set_style('whitegrid', {'axes.edgecolor': '.2'}) From 0b0b64cd2247fb685ce85c549e0fc458fd27dccd Mon Sep 17 00:00:00 2001 From: Xiaorui Dong Date: Tue, 18 Apr 2023 14:25:44 -0400 Subject: [PATCH 14/28] Rename model to geomol continue --- generate_confs.py | 16 ++++++++-------- geomol/featurization.py | 2 +- geomol/inference.py | 6 ++++-- geomol/model.py | 4 ++-- geomol/utils.py | 2 +- 5 files changed, 16 insertions(+), 14 deletions(-) diff --git a/generate_confs.py b/generate_confs.py index c17c1ea..0272a9f 100644 --- a/generate_confs.py +++ b/generate_confs.py @@ -9,10 +9,10 @@ import torch import yaml -from model.model import GeoMol -from model.featurization import featurize_mol_from_smiles +from geomol.model import GeoMol +from geomol.featurization import featurize_mol_from_smiles from torch_geometric.data import Batch -from model.inference import construct_conformers +from geomol.inference import construct_conformers parser = ArgumentParser() @@ -45,17 +45,17 @@ conformer_dict = {} for smi, n_confs in tqdm(test_data.values): - + # create data object (skip smiles rdkit can't handle) tg_data = featurize_mol_from_smiles(smi, dataset=dataset) if not tg_data: print(f'failed to featurize SMILES: {smi}') continue - + # generate model predictions data = Batch.from_data_list([tg_data]) model(data, inference=True, n_model_confs=n_confs*2) - + # set coords n_atoms = tg_data.x.size(0) model_coords = construct_conformers(data, model) @@ -73,9 +73,9 @@ except Exception as e: pass mols.append(mol) - + conformer_dict[smi] = mols - + # save to file if args.out: with open(f'{args.out}', 'wb') as f: diff --git a/geomol/featurization.py b/geomol/featurization.py index 8bd36a6..92a9201 100644 --- a/geomol/featurization.py +++ b/geomol/featurization.py @@ -16,7 +16,7 @@ from torch_geometric.data import Batch, Data, DataLoader, Dataset from torch_scatter import scatter -from model.utils import get_dihedral_pairs +from geomol.utils import get_dihedral_pairs tg_version_ge_2 = version.parse(tg.__version__) > version.parse('2.0.0') diff --git a/geomol/inference.py b/geomol/inference.py index e626f2e..9bbe282 100644 --- a/geomol/inference.py +++ b/geomol/inference.py @@ -2,8 +2,8 @@ import numpy as np import networkx as nx import torch_geometric as tg -from model.utils import batch_dihedrals -from model.cycle_utils import * +from geomol.utils import batch_dihedrals +from geomol.cycle_utils import * device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') @@ -436,3 +436,5 @@ def build_gamma_rotation_inf(gamma_sin, gamma_cos, n_model_confs): H_gamma[:, 2, 2] = gamma_cos return H_gamma + + diff --git a/geomol/model.py b/geomol/model.py index 56ba621..1e26701 100644 --- a/geomol/model.py +++ b/geomol/model.py @@ -6,8 +6,8 @@ from torch_geometric.nn import global_add_pool from torch_scatter import scatter -from model.GNN import GNN, MLP -from model.utils import * +from geomol.GNN import GNN, MLP +from geomol.utils import * from itertools import permutations import numpy as np diff --git a/geomol/utils.py b/geomol/utils.py index d27c941..1af9a7a 100644 --- a/geomol/utils.py +++ b/geomol/utils.py @@ -2,7 +2,7 @@ import torch_geometric as tg from torch_geometric.utils import degree import networkx as nx -from model.cycle_utils import get_current_cycle_indices +from geomol.cycle_utils import get_current_cycle_indices device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') angle_mask_ref = torch.LongTensor([[0, 0, 0, 0, 0, 0], From 39a60fb5b8f510795ca00220866cfe8c7cd3e50b Mon Sep 17 00:00:00 2001 From: Xiaorui Dong Date: Tue, 18 Apr 2023 15:26:56 -0400 Subject: [PATCH 15/28] Adjust for the change in global_add_pool PR#4827 in PYG changes how scatter is called, defaults to the dim=-2 instead of 0. This change directly calls scatter instead of global_add_pool to avoid the incompatibility. --- geomol/model.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/geomol/model.py b/geomol/model.py index 1e26701..012416c 100644 --- a/geomol/model.py +++ b/geomol/model.py @@ -245,7 +245,15 @@ def embed(self, x, edge_index, edge_attr, batch): x2 = x_global[x_mask, :] else: - h_mol = self.h_mol_mlp(global_add_pool(x2, batch)) + # global_add_pool changed in PR #4827 to use dim=-2 instead of 0 by default + # Use a more general version to support both new and old versions of PyTorch Geometric + size = int(batch.max().item() + 1) + h_mol = self.h_mol_mlp( + scatter(x2, + batch, + dim=0, + dim_size=size, + reduce='sum')) return x1, x2, h_mol From 5e3cb5ad03a54bd45c227b162fb954aa97dec771 Mon Sep 17 00:00:00 2001 From: Xiaorui Dong Date: Tue, 18 Apr 2023 15:44:00 -0400 Subject: [PATCH 16/28] update generate_confs with from_data_list --- generate_confs.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/generate_confs.py b/generate_confs.py index 0272a9f..1e4e3a5 100644 --- a/generate_confs.py +++ b/generate_confs.py @@ -10,8 +10,7 @@ import yaml from geomol.model import GeoMol -from geomol.featurization import featurize_mol_from_smiles -from torch_geometric.data import Batch +from geomol.featurization import featurize_mol_from_smiles, from_data_list from geomol.inference import construct_conformers @@ -53,7 +52,7 @@ continue # generate model predictions - data = Batch.from_data_list([tg_data]) + data = from_data_list([tg_data]) model(data, inference=True, n_model_confs=n_confs*2) # set coords From a59e38b60502af77d8665bb2975ffb76fa36fec0 Mon Sep 17 00:00:00 2001 From: Xiaorui Dong Date: Tue, 18 Apr 2023 16:01:28 -0400 Subject: [PATCH 17/28] Add a dummy __init__ file to root for easy external import --- __init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 __init__.py diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 From 6a877695d665b3bd30cce066572e4e7d51eb0010 Mon Sep 17 00:00:00 2001 From: Xiaorui Dong Date: Tue, 12 Sep 2023 14:07:54 -0400 Subject: [PATCH 18/28] AutoPEP8 correction --- generate_confs.py | 2 +- geomol/GNN.py | 12 ++++-- geomol/featurization.py | 17 ++++----- geomol/inference.py | 31 +++++++++------ geomol/model.py | 81 ++++++++++++++++++++++++++-------------- geomol/training.py | 9 ++++- geomol/utils.py | 4 +- scripts/compare_confs.py | 24 ++++++------ setup.py | 2 +- utils.py | 17 +++++++-- 10 files changed, 126 insertions(+), 73 deletions(-) diff --git a/generate_confs.py b/generate_confs.py index 1e4e3a5..850652c 100644 --- a/generate_confs.py +++ b/generate_confs.py @@ -53,7 +53,7 @@ # generate model predictions data = from_data_list([tg_data]) - model(data, inference=True, n_model_confs=n_confs*2) + model(data, inference=True, n_model_confs=n_confs * 2) # set coords n_atoms = tg_data.x.size(0) diff --git a/geomol/GNN.py b/geomol/GNN.py index d332c18..466931d 100644 --- a/geomol/GNN.py +++ b/geomol/GNN.py @@ -13,11 +13,12 @@ class MLP(nn.Module): Inputs: in_dim (int): number of features contained in the input layer. - out_dim (int): number of features input and output from each hidden layer, + out_dim (int): number of features input and output from each hidden layer, including the output layer. num_layers (int): number of layers in the network activation (torch function): activation function to be used during the hidden layers """ + def __init__(self, in_dim, out_dim, num_layers, activation=torch.nn.ReLU(), layer_norm=False, batch_norm=False): super(MLP, self).__init__() self.layers = nn.ModuleList() @@ -30,11 +31,13 @@ def __init__(self, in_dim, out_dim, num_layers, activation=torch.nn.ReLU(), laye self.layers.append(nn.Linear(in_dim, h_dim)) else: self.layers.append(nn.Linear(h_dim, h_dim)) - if layer_norm: self.layers.append(nn.LayerNorm(h_dim)) - if batch_norm: self.layers.append(nn.BatchNorm1d(h_dim)) + if layer_norm: + self.layers.append(nn.LayerNorm(h_dim)) + if batch_norm: + self.layers.append(nn.BatchNorm1d(h_dim)) self.layers.append(activation) self.layers.append(nn.Linear(h_dim, out_dim)) - + def forward(self, x): for i in range(len(self.layers)): x = self.layers[i](x) @@ -46,6 +49,7 @@ class MetaLayer(torch.nn.Module): `"Relational Inductive Biases, Deep Learning, and Graph Networks" `_ paper. """ + def __init__(self, edge_model=None, node_model=None): super(MetaLayer, self).__init__() self.edge_model = edge_model diff --git a/geomol/featurization.py b/geomol/featurization.py index 92a9201..4357e2f 100644 --- a/geomol/featurization.py +++ b/geomol/featurization.py @@ -291,11 +291,10 @@ def _mol_to_features(mol, row += [start, end] col += [end, start] edge_type += 2 * [bonds[bond.GetBondType()]] - bt = tuple( - sorted( - [bond.GetBeginAtom().GetAtomicNum(), - bond.GetEndAtom().GetAtomicNum()] - )), bond.GetBondTypeAsDouble() + bt = tuple(sorted( + [bond.GetBeginAtom().GetAtomicNum(), + bond.GetEndAtom().GetAtomicNum()] + )), bond.GetBondTypeAsDouble() bond_features += 2 * [int(bond.IsInRing()), int(bond.GetIsConjugated()), int(bond.GetIsAromatic())] @@ -332,16 +331,16 @@ def featurize_mol(mol, if mol: x, _, edge_index, edge_attr, neighbor_dict, chiral_tag \ - = _mol_to_features(mol, dataset=dataset) + = _mol_to_features(mol, dataset=dataset) data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, neighbors=neighbor_dict, chiral_tag=chiral_tag, name=name) - data.edge_index_dihedral_pairs = get_dihedral_pairs( - data.edge_index, - data=data) + data.edge_index_dihedral_pairs \ + = get_dihedral_pairs(data.edge_index, + data=data) return data diff --git a/geomol/inference.py b/geomol/inference.py index 9bbe282..1ba9df8 100644 --- a/geomol/inference.py +++ b/geomol/inference.py @@ -36,7 +36,12 @@ def construct_conformers(data, model): if any(x_cycle_check) and any(y_cycle_check): # both in new cycle cycle_indices = get_current_cycle_indices(cycles, x_cycle_check, x_index) - cycle_avg_coords, cycle_avg_indices = smooth_cycle_coords(model, cycle_indices, new_pos, dihedral_pairs, i) # i instead of i+1 + cycle_avg_coords, cycle_avg_indices \ + = smooth_cycle_coords(model, + cycle_indices, + new_pos, + dihedral_pairs, + i) # i instead of i+1 # new graph if x_index not in Sx: @@ -48,9 +53,11 @@ def construct_conformers(data, model): p_mask = [True if a in Sx else False for a in sorted(cycle_avg_indices)] q_mask = [True if a in sorted(cycle_avg_indices) else False for a in Sx] p_reorder = sorted(range(len(cycle_avg_indices)), key=lambda k: cycle_avg_indices[k]) - aligned_cycle_coords = align_coords_Kabsch(cycle_avg_coords[p_reorder].permute(1, 0, 2).unsqueeze(0), new_pos[Sx].permute(1, 0, 2), p_mask, q_mask) + aligned_cycle_coords = align_coords_Kabsch( + cycle_avg_coords[p_reorder].permute(1, 0, 2).unsqueeze(0), + new_pos[Sx].permute(1, 0, 2), p_mask, q_mask) aligned_cycle_coords = aligned_cycle_coords.squeeze(0).permute(1, 0, 2) - cycle_avg_indices_reordered = [cycle_avg_indices[l] for l in p_reorder] + cycle_avg_indices_reordered = [cycle_avg_indices[i] for i in p_reorder] # apply to all new coordinates? new_pos[cycle_avg_indices_reordered] = aligned_cycle_coords @@ -63,7 +70,7 @@ def construct_conformers(data, model): if any(y_cycle_check): cycle_indices = get_current_cycle_indices(cycles, y_cycle_check, y_index) cycle_added = True - in_cycle = len(cycle_indices)+1 + in_cycle = len(cycle_indices) + 1 # new graph p_coords = torch.zeros([4, model.n_model_confs, 3]) @@ -94,8 +101,8 @@ def construct_conformers(data, model): # set Y if cycle_added: - cycle_avg_coords, cycle_avg_indices = smooth_cycle_coords(model, cycle_indices, new_pos, dihedral_pairs, i+1) - cycle_avg_coords = cycle_avg_coords - cycle_avg_coords[cycle_avg_indices == y_index] # move y to origin + cycle_avg_coords, cycle_avg_indices = smooth_cycle_coords(model, cycle_indices, new_pos, dihedral_pairs, i + 1) + cycle_avg_coords = cycle_avg_coords - cycle_avg_coords[cycle_avg_indices == y_index] # move y to origin q_idx = model.neighbors[y_index] q_coords_mask = [True if a in q_idx else False for a in cycle_avg_indices] q_coords = torch.zeros([4, model.n_model_confs, 3]) @@ -147,10 +154,10 @@ def smooth_cycle_coords(model, cycle_indices, new_pos, dihedral_pairs, cycle_sta cycle_len = len(cycle_indices) # get dihedral pairs corresponding to current cycle - cycle_pairs = dihedral_pairs[cycle_start_idx:cycle_start_idx+cycle_len] + cycle_pairs = dihedral_pairs[cycle_start_idx:cycle_start_idx + cycle_len] # create indices for cycle - cycle_i = np.arange(cycle_start_idx, cycle_start_idx+cycle_len) + cycle_i = np.arange(cycle_start_idx, cycle_start_idx + cycle_len) # create ordered dihedral pairs and indices which each start at a different point in the cycle cycle_dihedral_pair_orders = np.stack([np.roll(cycle_pairs, -i, axis=0) for i in range(len(cycle_pairs))])[:-1] @@ -258,7 +265,11 @@ def smooth_cycle_coords(model, cycle_indices, new_pos, dihedral_pairs, cycle_sta p_cycle_coords_aligned = align_coords_Kabsch(p_cycle_coords, q_cycle_coords, cycle_rmsd_mask).permute(0, 2, 1, 3) # average aligned coords - cycle_avg_coords_ = torch.vstack([q_cycle_coords_aligned.unsqueeze(0), p_cycle_coords_aligned]) * cycle_mask[:, Sx_cycle[0]].unsqueeze(-1).unsqueeze(-1) + cycle_avg_coords_ \ + = torch.vstack([q_cycle_coords_aligned.unsqueeze(0), + p_cycle_coords_aligned]) \ + * cycle_mask[:, Sx_cycle[0]].unsqueeze(-1).unsqueeze(-1) + cycle_avg_coords = cycle_avg_coords_.sum(dim=0) / cycle_mask[:, Sx_cycle[0]].sum(dim=0).unsqueeze(-1).unsqueeze(-1) return cycle_avg_coords, Sx_cycle[0] @@ -436,5 +447,3 @@ def build_gamma_rotation_inf(gamma_sin, gamma_cos, n_model_confs): H_gamma[:, 2, 2] = gamma_cos return H_gamma - - diff --git a/geomol/model.py b/geomol/model.py index 012416c..48f8a58 100644 --- a/geomol/model.py +++ b/geomol/model.py @@ -79,8 +79,8 @@ def forward(self, data, ignore_neighbors=False, inference=False, n_model_confs=N self.generate_model_prediction(data.x, data.edge_index, data.edge_attr, data.batch, data.chiral_tag) return - x, edge_index, edge_attr, pos_list, batch, pos_mask, chiral_tag = \ - data.x, data.edge_index, data.edge_attr, data.pos, data.batch, data.pos_mask, data.chiral_tag + x, edge_index, edge_attr, pos_list, batch, pos_mask, chiral_tag \ + = data.x, data.edge_index, data.edge_attr, data.pos, data.batch, data.pos_mask, data.chiral_tag # assign neighborhoods self.assign_neighborhoods(x, edge_index, edge_attr, batch, data) @@ -216,8 +216,8 @@ def embed(self, x, edge_index, edge_attr, batch): # stochasticity rand_dist = torch.distributions.normal.Normal(loc=0, scale=self.random_vec_std) # rand_dist = torch.distributions.uniform.Uniform(torch.tensor([0.0]), torch.tensor([1.0])) - rand_x = rand_dist.sample([x.size(0), self.n_model_confs, self.random_vec_dim]).squeeze(-1).to(self.device) # added squeeze - rand_edge = rand_dist.sample([edge_attr.size(0), self.n_model_confs, self.random_vec_dim]).squeeze(-1).to(self.device) # added squeeze + rand_x = rand_dist.sample([x.size(0), self.n_model_confs, self.random_vec_dim]).squeeze(-1).to(self.device) # added squeeze + rand_edge = rand_dist.sample([edge_attr.size(0), self.n_model_confs, self.random_vec_dim]).squeeze(-1).to(self.device) # added squeeze x = torch.cat([x.unsqueeze(1).repeat(1, self.n_model_confs, 1), rand_x], dim=-1) edge_attr = torch.cat([edge_attr.unsqueeze(1).repeat(1, self.n_model_confs, 1), rand_edge], dim=-1) @@ -234,9 +234,18 @@ def embed(self, x, edge_index, edge_attr, batch): x_transformer = x_transformer.permute(1, 0, 2, 3).reshape(n_max, -1, self.model_dim) x_transformer_mask = x_mask.unsqueeze(1).repeat(1, self.n_model_confs, 1).view(-1, n_max) - x_global = self.global_embed(x_transformer, src_key_padding_mask=~x_transformer_mask).view( - n_max, max(batch)+1, self.n_model_confs, -1).permute(1, 0, 2, 3) * \ - x_transformer_mask.view(max(batch)+1, n_max, self.n_model_confs, 1) + x_global \ + = self.global_embed(x_transformer, + src_key_padding_mask=~x_transformer_mask) \ + .view(n_max, + max(batch) + 1, + self.n_model_confs, + -1)\ + .permute(1, 0, 2, 3) \ + * x_transformer_mask.view(max(batch) + 1, + n_max, + self.n_model_confs, + 1) # global reps for torsions h_mol = self.h_mol_mlp(x_global.sum(dim=1)) @@ -249,11 +258,11 @@ def embed(self, x, edge_index, edge_attr, batch): # Use a more general version to support both new and old versions of PyTorch Geometric size = int(batch.max().item() + 1) h_mol = self.h_mol_mlp( - scatter(x2, - batch, - dim=0, - dim_size=size, - reduce='sum')) + scatter(x2, + batch, + dim=0, + dim_size=size, + reduce='sum')) return x1, x2, h_mol @@ -273,8 +282,11 @@ def model_local_stats(self, x, chiral_tag): h_ = h.permute(1, 0, 2, 3).reshape(4, self.n_neighborhoods * self.n_model_confs, self.model_dim * 2) # CHECK RESHAPE OP h_mask = self.neighbor_mask.bool().unsqueeze(1).repeat(1, self.n_model_confs, 1).view(self.n_neighborhoods * self.n_model_confs, 4) - h_new = self.encoder(h_, src_key_padding_mask=~h_mask).view(4, self.n_neighborhoods, self.n_model_confs, self.model_dim * 2).permute(1, 0, 2, 3) \ - * self.neighbor_mask.unsqueeze(-1).unsqueeze(-1) + h_new \ + = self.encoder(h_, src_key_padding_mask=~h_mask) \ + .view(4, self.n_neighborhoods, self.n_model_confs, self.model_dim * 2)\ + .permute(1, 0, 2, 3) \ + * self.neighbor_mask.unsqueeze(-1).unsqueeze(-1) unit_normals = self.coord_pred(h_new) * self.neighbor_mask.unsqueeze(-1).unsqueeze(-1) # tetrahedral chiral corrections @@ -359,7 +371,8 @@ def local_loss(self, true_one_hop, true_two_hop, true_angles, model_one_hop, mod # bending angles loss model_angles_perms = model_angles.unsqueeze(1).repeat(1, 6, 1) - angle_loss_perm = torch.sum(von_Mises_loss(true_angles, model_angles_perms) * true_angles.bool(), dim=-1) / (true_angles.bool().sum(dim=-1) + 1e-10) + angle_loss_perm = torch.sum(von_Mises_loss(true_angles, model_angles_perms) * true_angles.bool(), + dim=-1) / (true_angles.bool().sum(dim=-1) + 1e-10) angle_loss = scatter(angle_loss_perm.max(dim=-1).values, self.neighborhood_to_mol_map, reduce="mean") return one_hop_loss, two_hop_loss, angle_loss @@ -418,7 +431,8 @@ def model_pair_stats(self, x, batch, h_mol): q_Z_translated_combos = q_Z_translated[:, qZ_idx, :] p_Y_alpha_combos = p_Y_alpha.unsqueeze(1).repeat(1, 9, 1, 1) - model_dihedrals_sin, model_dihedrals_cos = batch_dihedrals(p_T_alpha_combos, torch.zeros_like(p_Y_alpha_combos), p_Y_alpha_combos, q_Z_translated_combos) + model_dihedrals_sin, model_dihedrals_cos = batch_dihedrals( + p_T_alpha_combos, torch.zeros_like(p_Y_alpha_combos), p_Y_alpha_combos, q_Z_translated_combos) model_dihedrals_sin = model_dihedrals_sin * self.dihedral_mask.unsqueeze(-1) model_dihedrals_cos = model_dihedrals_cos * self.dihedral_mask.unsqueeze(-1) model_dihedrals = torch.stack([model_dihedrals_sin, model_dihedrals_cos], dim=0) @@ -480,12 +494,14 @@ def ground_truth_pair_stats(self, pos): true_dihedral_yn_coords = true_dihedral_coords[:, 3][~self.y_map_to_neighbor_x.bool(), :].view(-1, 3, 6, self.n_true_confs, 3)[:, qZ_idx, :] # calculate true dihedrals - true_dihedrals_sin, true_dihedrals_cos = batch_dihedrals(true_dihedral_xn_coords, true_dihedral_x_coords, true_dihedral_y_coords, true_dihedral_yn_coords) + true_dihedrals_sin, true_dihedrals_cos = batch_dihedrals( + true_dihedral_xn_coords, true_dihedral_x_coords, true_dihedral_y_coords, true_dihedral_yn_coords) true_dihedrals_sin = true_dihedrals_sin * self.dihedral_mask.unsqueeze(-1).unsqueeze(-1) true_dihedrals_cos = true_dihedrals_cos * self.dihedral_mask.unsqueeze(-1).unsqueeze(-1) true_dihedrals = torch.stack([true_dihedrals_sin, true_dihedrals_cos], dim=0) # true_dihedrals = batch_vector_angles(true_dihedral_xn_coords, true_dihedral_x_coords, true_dihedral_y_coords, - # true_dihedral_yn_coords).view(-1, 9, 6, self.n_true_confs) * self.dihedral_mask.unsqueeze(-1).unsqueeze(-1) + # true_dihedral_yn_coords).view(-1, 9, 6, self.n_true_confs) * + # self.dihedral_mask.unsqueeze(-1).unsqueeze(-1) # calculate true three-hop distances true_three_hop = torch.linalg.norm(true_dihedral_xn_coords - true_dihedral_yn_coords, dim=-1) * self.dihedral_mask.unsqueeze(-1).unsqueeze(-1) @@ -514,7 +530,12 @@ def pair_loss(self, true_dihedrals, model_dihedrals, true_three_hop, model_three # dihedral loss model_dihedrals_perms = model_dihedrals.unsqueeze(-1).repeat(1, 1, 1, 6) - dihedral_loss_perms = torch.sum(von_Mises_loss(true_dihedrals[1], model_dihedrals_perms[1], true_dihedrals[0], model_dihedrals_perms[0]) * self.dihedral_mask.unsqueeze(-1), dim=-2) / (self.dihedral_mask.sum(dim=-1, keepdim=True) + 1e-10) + dihedral_loss_perms = torch.sum(von_Mises_loss(true_dihedrals[1], + model_dihedrals_perms[1], + true_dihedrals[0], + model_dihedrals_perms[0]) * self.dihedral_mask.unsqueeze(-1), + dim=-2) / (self.dihedral_mask.sum(dim=-1, + keepdim=True) + 1e-10) dihedral_loss = scatter(dihedral_loss_perms.max(dim=-1).values, self.neighborhood_pairs_to_mol_map, reduce="mean") # three-hop distance loss @@ -565,7 +586,8 @@ def align_dihedral_neighbors(self, dihedral_node_reps, dihedral_neighbors, batch q_X_prime = q_H[self.y_map_to_neighbor_x.bool()] transform_matrix = torch.diag(torch.tensor([-1., -1., 1.]).to(self.device)).unsqueeze(0).unsqueeze(0).unsqueeze(0) - q_Z_translated = torch.matmul(transform_matrix, q_Z_prime.unsqueeze(-1)).squeeze(-1) + p_Y_prime.unsqueeze(1) # broadcast over not coordinates + q_Z_translated = torch.matmul(transform_matrix, + q_Z_prime.unsqueeze(-1)).squeeze(-1) + p_Y_prime.unsqueeze(1) # broadcast over not coordinates # calculate alpha dihedral_h_mol = h_mol[batch[self.dihedral_pairs[0]]] # (n_dihedral_pairs, n_model_confs. model_dim/2) @@ -574,11 +596,13 @@ def align_dihedral_neighbors(self, dihedral_node_reps, dihedral_neighbors, batch if self. random_alpha: rand_dist = torch.distributions.normal.Normal(loc=0, scale=self.random_vec_std) rand_alpha = rand_dist.sample([self.n_dihedral_pairs, self.n_model_confs, self.random_vec_dim]).squeeze(-1).to(self.device) - alpha = self.alpha_mlp(torch.cat([dihedral_x_node_reps, dihedral_y_node_reps, dihedral_h_mol, rand_alpha], dim=-1)) + \ - self.alpha_mlp(torch.cat([dihedral_y_node_reps, dihedral_x_node_reps, dihedral_h_mol, rand_alpha], dim=-1)) + alpha \ + = self.alpha_mlp(torch.cat([dihedral_x_node_reps, dihedral_y_node_reps, dihedral_h_mol, rand_alpha], dim=-1)) \ + + self.alpha_mlp(torch.cat([dihedral_y_node_reps, dihedral_x_node_reps, dihedral_h_mol, rand_alpha], dim=-1)) else: - alpha = self.alpha_mlp(torch.cat([dihedral_x_node_reps, dihedral_y_node_reps, dihedral_h_mol], dim=-1)) + \ - self.alpha_mlp(torch.cat([dihedral_y_node_reps, dihedral_x_node_reps, dihedral_h_mol], dim=-1)) + alpha \ + = self.alpha_mlp(torch.cat([dihedral_x_node_reps, dihedral_y_node_reps, dihedral_h_mol], dim=-1)) \ + + self.alpha_mlp(torch.cat([dihedral_y_node_reps, dihedral_x_node_reps, dihedral_h_mol], dim=-1)) alpha = alpha.view(self.n_dihedral_pairs, self.n_model_confs, 1) self.v_star = torch.cat([torch.cos(alpha), torch.sin(alpha)], dim=-1) @@ -596,14 +620,17 @@ def align_dihedral_neighbors(self, dihedral_node_reps, dihedral_neighbors, batch q_reps = dihedral_y_neighbor_reps[~self.y_map_to_neighbor_x.bool()].view(-1, 3, self.n_model_confs, self.model_dim) cx_reps = dihedral_x_node_reps.unsqueeze(1).repeat(1, 9, 1, 1) cy_reps = dihedral_y_node_reps.unsqueeze(1).repeat(1, 9, 1, 1) - self.c_ij = self.c_mlp(torch.cat([p_reps[:, pT_idx], cx_reps, q_reps[:, qZ_idx], cy_reps], dim=-1)) + \ - self.c_mlp(torch.cat([q_reps[:, qZ_idx], cy_reps, p_reps[:, pT_idx], cx_reps], dim=-1)) + self.c_ij \ + = self.c_mlp(torch.cat([p_reps[:, pT_idx], cx_reps, q_reps[:, qZ_idx], cy_reps], dim=-1)) \ + + self.c_mlp(torch.cat([q_reps[:, qZ_idx], cy_reps, p_reps[:, pT_idx], cx_reps], dim=-1)) # calculate gamma sin and cos A_ij = self.build_A_matrix(XYTi_XYZj_curr_sin, XYTi_XYZj_curr_cos) * self.dihedral_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) A_curr = torch.sum(A_ij * self.c_ij.unsqueeze(-1), dim=1) determinants = torch.det(A_curr) + 1e-10 - A_curr_inv_ = A_curr.view(self.n_dihedral_pairs, self.n_model_confs, 4)[:, :, [3, 1, 2, 0]] * torch.tensor([[[1., -1., -1., 1.]]]).to(self.device) + A_curr_inv_ = A_curr.view(self.n_dihedral_pairs, + self.n_model_confs, + 4)[:, :, [3, 1, 2, 0]] * torch.tensor([[[1., -1., -1., 1.]]]).to(self.device) A_curr_inv = (A_curr_inv_ / determinants.unsqueeze(-1)).view(self.n_dihedral_pairs, self.n_model_confs, 2, 2) A_curr_inv_v_star = torch.matmul(A_curr_inv, self.v_star.unsqueeze(-1)).squeeze(-1) diff --git a/geomol/training.py b/geomol/training.py index 0995b6b..a02393c 100644 --- a/geomol/training.py +++ b/geomol/training.py @@ -100,6 +100,7 @@ class NoamLR(_LRScheduler): total_epochs * steps_per_epoch). This is roughly based on the learning rate schedule from Attention is All You Need, section 5.3 (https://arxiv.org/abs/1706.03762). """ + def __init__(self, optimizer: Optimizer, warmup_epochs: List[Union[float, int]], @@ -119,8 +120,12 @@ def __init__(self, :param max_lr: The maximum learning rate (achieved after warmup_epochs). :param final_lr: The final learning rate (achieved after total_epochs). """ - assert len(optimizer.param_groups) == len(warmup_epochs) == len(total_epochs) == len(init_lr) == \ - len(max_lr) == len(final_lr) + assert len(optimizer.param_groups) \ + == len(warmup_epochs) \ + == len(total_epochs) \ + == len(init_lr) \ + == len(max_lr) \ + == len(final_lr) self.num_lrs = len(optimizer.param_groups) diff --git a/geomol/utils.py b/geomol/utils.py index 1af9a7a..08e0a46 100644 --- a/geomol/utils.py +++ b/geomol/utils.py @@ -200,7 +200,7 @@ def batch_dihedrals(p0, p1, p2, p3, angle=False): else: den = torch.linalg.norm(torch.cross(s1, s2, dim=-1), dim=-1) * torch.linalg.norm(torch.cross(s2, s3, dim=-1), dim=-1) + 1e-10 - return sin_d_/den, cos_d_/den + return sin_d_ / den, cos_d_ / den def batch_vector_angles(xn, x, y, yn): @@ -227,7 +227,7 @@ def von_Mises_loss(a, b, a_sin=None, b_sin=None): if torch.is_tensor(a_sin): out = a * b + a_sin * b_sin else: - out = a * b + torch.sqrt(1-a**2 + 1e-5) * torch.sqrt(1-b**2 + 1e-5) + out = a * b + torch.sqrt(1 - a**2 + 1e-5) * torch.sqrt(1 - b**2 + 1e-5) return out diff --git a/scripts/compare_confs.py b/scripts/compare_confs.py index 9662830..9b3732c 100644 --- a/scripts/compare_confs.py +++ b/scripts/compare_confs.py @@ -23,7 +23,7 @@ def calc_performance_stats(true_confs, model_confs): - + threshold = np.arange(0, 2.5, .125) rmsd_list = [] for tc in true_confs: @@ -42,7 +42,7 @@ def calc_performance_stats(true_confs, model_confs): coverage_precision = np.sum(rmsd_array.min(axis=0, keepdims=True) < np.expand_dims(threshold, 1), axis=1) / len(model_confs) amr_precision = rmsd_array.min(axis=0).mean() - + return coverage_recall, amr_recall, coverage_precision, amr_precision @@ -63,45 +63,45 @@ def clean_confs(smi, confs): for smi, n_confs, corrected_smi in tqdm(test_data.values): if not Chem.MolFromSmiles(smi): continue - + try: model_confs = model_preds[corrected_smi] except KeyError: print(f'no model prediction available: {corrected_smi}') - coverage_recall.append(threshold_ranges*0) + coverage_recall.append(threshold_ranges * 0) amr_recall.append(np.nan) - coverage_precision.append(threshold_ranges*0) + coverage_precision.append(threshold_ranges * 0) amr_precision.append(np.nan) test_smiles.append(smi) continue - # failure if model can't generate confs + # failure if model can't generate confs if len(model_confs) == 0: print(f'model failed: {smi}') - coverage_recall.append(threshold_ranges*0) + coverage_recall.append(threshold_ranges * 0) amr_recall.append(np.nan) - coverage_precision.append(threshold_ranges*0) + coverage_precision.append(threshold_ranges * 0) amr_precision.append(np.nan) test_smiles.append(smi) continue - + try: true_confs = true_mols[smi] except KeyError: print(f'cannot find ground truth conformer file: {smi}') continue - + # remove reacted conformers true_confs = clean_confs(corrected_smi, true_confs) if len(true_confs) == 0: print(f'poor ground truth conformers: {corrected_smi}') continue - + stats = calc_performance_stats(true_confs, model_confs) if not stats: print(f'failure calculating stats: {smi, corrected_smi}') continue - + cr, mr, cp, mp = stats coverage_recall.append(cr) amr_recall.append(mr) diff --git a/setup.py b/setup.py index 94ee2fc..b5e0172 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -#-*- coding: utf-8 -*- +# -*- coding: utf-8 -*- from setuptools import setup, find_packages diff --git a/utils.py b/utils.py index c6fbc16..c23c144 100644 --- a/utils.py +++ b/utils.py @@ -19,8 +19,10 @@ sns.color_palette('husl') local_modules = ['gnn', 'encoder', 'coord_pred', 'd_mlp'] + class Standardizer: """Z-score standardization""" + def __init__(self, mean, std): self.mean = mean self.std = std @@ -173,12 +175,19 @@ def get_optimizer_and_scheduler(args, model, train_data_size): if args.scheduler == 'plateau': if args.separate_opts: - scheduling_fn = lambda opt: torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.7, - patience=5, min_lr=args.lr / 100) + def scheduling_fn(opt): + return torch.optim.lr_scheduler.ReduceLROnPlateau(opt, + mode='min', + factor=0.7, + patience=5, + min_lr=args.lr / 100) scheduler = MultipleScheduler(optimizer, scheduling_fn) else: - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.7, - patience=5, min_lr=args.lr/100) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, + mode='min', + factor=0.7, + patience=5, + min_lr=args.lr / 100) elif args.scheduler == 'noam': scheduler = build_lr_scheduler(optimizer=optimizer, args=args, train_data_size=train_data_size) else: From cf1daef73e8d3d11579099aba4d97c91c579bae7 Mon Sep 17 00:00:00 2001 From: Xiaorui Dong Date: Tue, 12 Sep 2023 14:08:17 -0400 Subject: [PATCH 19/28] Ignore Mac .DS_Store --- .gitignore | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 5fb6284..939e37f 100644 --- a/.gitignore +++ b/.gitignore @@ -13,4 +13,7 @@ __pycache__/ .eggs/ # Jupyter -.ipynb_checkpoints/ \ No newline at end of file +.ipynb_checkpoints/ + +# Mac File System +.DS_Store \ No newline at end of file From 6bfda9ce0eb0f95c616e13d63265a5db14851109 Mon Sep 17 00:00:00 2001 From: Xiaorui Dong Date: Wed, 13 Mar 2024 21:41:41 -0400 Subject: [PATCH 20/28] Update the environment setup script 1. Modularization 2. able to user provided CUDA version 3. Found a workaround for Mchip Mac 4. Remove the constraints for deps 5. backup the original dep constraints in environment_reproduce.yml --- devtools/create_env.sh | 73 +++----------- devtools/environment.yml | 146 +--------------------------- devtools/environment_reproduce.yml | 146 ++++++++++++++++++++++++++++ devtools/initialize_conda.sh | 19 ++++ devtools/install_pyg.sh | 122 +++++++++++++++++++++++ devtools/install_pyg_macos_arm64.sh | 82 ++++++++++++++++ 6 files changed, 386 insertions(+), 202 deletions(-) create mode 100644 devtools/environment_reproduce.yml create mode 100644 devtools/initialize_conda.sh create mode 100644 devtools/install_pyg.sh create mode 100644 devtools/install_pyg_macos_arm64.sh diff --git a/devtools/create_env.sh b/devtools/create_env.sh index 12deec1..0b7ae8a 100644 --- a/devtools/create_env.sh +++ b/devtools/create_env.sh @@ -1,10 +1,12 @@ -# Developed by Kevin A. Spiekermann +# Developed by Xiaorui Dong and Kevin A. Spiekermann # This script does the following tasks: # - creates the conda -# - prompts user for desired CUDA version # - installs PyTorch with specified CUDA version in the environment # - installs torch torch-geometric in the environment +SCRIPT_DIR=$(dirname $0) + +CONDA_ENV_NAME="GeoMol" # get OS type unameOut="$(uname -s)" @@ -17,66 +19,17 @@ case "${unameOut}" in esac echo "Running ${machine}..." +if [ "$machine" != "MacOS" ]; then + # Prompt the user to input their desired CUDA version or 'cpu' + echo "Please input your desired CUDA version in the format xx.xx (e.g., 10.2, 12.3) or 'cpu' for no CUDA available:" + read cuda_input -# request user to select one of the supported CUDA versions -# source: https://pytorch.org/get-started/locally/ -PS3='Please enter 1, 2, 3, or 4 to specify the desired CUDA version from the options above: ' -options=("9.2" "10.1" "10.2" "cpu" "Quit") -select opt in "${options[@]}" -do - case $opt in - "9.2") - CUDA="cudatoolkit=9.2" - CUDA_VERSION="cu92" - break - ;; - "10.1") - CUDA="cudatoolkit=10.1" - CUDA_VERSION="cu101" - break - ;; - "10.2") - CUDA="cudatoolkit=10.2" - CUDA_VERSION="cu102" - break - ;; - "cpu") - # "cpuonly" works for Linux and Windows - CUDA="cpuonly" - # Mac does not use "cpuonly" - if [ $machine == "Mac" ] - then - CUDA=" " - fi - CUDA_VERSION="cpu" - break - ;; - "Quit") - exit - ;; - *) echo "invalid option $REPLY";; - esac -done - -echo "Creating conda environment..." -echo "Running: conda env create -f environment.yml" -conda env create -f devtools/environment.yml +if [ "$machine" == "MacOS" ] && [ "$(uname -m)" == "arm64" ]; then -# activate the environment to install torch-geometric -source activate GeoMol + $SHELL $SCRIPT_DIR/install_pyg_macos_arm64.sh -n $CONDA_ENV_NAME -echo "Installing PyTorch with requested CUDA version..." -echo "Running: conda install pytorch torchvision $CUDA -c pytorch" -conda install pytorch torchvision $CUDA -c pytorch +else -echo "Installing torch-geometric..." -echo "Using CUDA version: $CUDA_VERSION" -# get PyTorch version -TORCH_VERSION=$(python -c "import torch; print(torch.__version__)") -echo "Using PyTorch version: $TORCH_VERSION" + source $SCRIPT_DIR/install_pyg.sh -n $CONDA_ENV_NAME -c $cuda_input -pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-${TORCH_VERSION}+${CUDA_VERSION}.html -pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-${TORCH_VERSION}+${CUDA_VERSION}.html -pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-${TORCH_VERSION}+${CUDA_VERSION}.html -pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-${TORCH_VERSION}+${CUDA_VERSION}.html -pip install torch-geometric +fi diff --git a/devtools/environment.yml b/devtools/environment.yml index 02962b2..e131c5d 100644 --- a/devtools/environment.yml +++ b/devtools/environment.yml @@ -1,146 +1,8 @@ name: GeoMol channels: - - rdkit - - anaconda - - conda-forge - defaults + - conda-forge dependencies: - - _libgcc_mutex=0.1=main - - _openmp_mutex=4.5=1_gnu - - argon2-cffi=20.1.0=py37h5e8e339_2 - - async_generator=1.10=py_0 - - attrs=21.2.0=pyhd8ed1ab_0 - - backcall=0.2.0=pyh9f0ad1d_0 - - backports=1.0=py_2 - - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 - - blas=1.0=mkl - - bleach=3.3.0=pyh44b312d_0 - - bzip2=1.0.8=h7b6447c_0 - - ca-certificates=2021.5.30=ha878542_0 - - cairo=1.16.0=hf32fb01_1 - - certifi=2021.5.30=py37h89c1867_0 - - cffi=1.14.5=py37hc58025e_0 - - cycler=0.10.0=py37_0 - - dbus=1.13.18=hb2f20db_0 - - decorator=4.4.2=pyhd3eb1b0_0 - - defusedxml=0.7.1=pyhd8ed1ab_0 - - entrypoints=0.3=pyhd8ed1ab_1003 - - expat=2.4.1=h2531618_2 - - fontconfig=2.13.1=h6c09931_0 - - freetype=2.10.4=h5ab3b9f_0 - - glib=2.68.2=h36276a3_0 - - gst-plugins-base=1.14.0=h8213a91_2 - - gstreamer=1.14.0=h28cd5cc_2 - - icu=58.2=he6710b0_3 - - importlib-metadata=4.6.0=py37h89c1867_0 - - intel-openmp=2021.2.0=h06a4308_610 - - ipykernel=5.5.5=py37h085eea5_0 - - ipython=7.25.0=py37h085eea5_1 - - ipython_genutils=0.2.0=py_1 - - ipywidgets=7.6.3=pyhd3deb0d_0 - - jedi=0.18.0=py37h89c1867_2 - - jinja2=3.0.1=pyhd8ed1ab_0 - - jpeg=9b=h024ee3a_2 - - jsonschema=3.2.0=pyhd8ed1ab_3 - - jupyter=1.0.0=py37h89c1867_6 - - jupyter_client=6.1.12=pyhd8ed1ab_0 - - jupyter_console=6.4.0=pyhd8ed1ab_0 - - jupyter_core=4.7.1=py37h89c1867_0 - - jupyterlab_pygments=0.1.2=pyh9f0ad1d_0 - - jupyterlab_widgets=1.0.0=pyhd8ed1ab_1 - - kiwisolver=1.3.1=py37h2531618_0 - - lcms2=2.12=h3be6417_0 - - ld_impl_linux-64=2.35.1=h7274673_9 - - libboost=1.73.0=h3ff78a5_11 - - libffi=3.3=he6710b0_2 - - libgcc-ng=9.3.0=h5101ec6_17 - - libgfortran-ng=7.5.0=ha8ba4b0_17 - - libgfortran4=7.5.0=ha8ba4b0_17 - - libgomp=9.3.0=h5101ec6_17 - - libpng=1.6.37=hbc83047_0 - - libsodium=1.0.18=h36c2ea0_1 - - libstdcxx-ng=9.3.0=hd4cf53a_17 - - libtiff=4.2.0=h85742a9_0 - - libuuid=1.0.3=h1bed415_2 - - libwebp-base=1.2.0=h27cfd23_0 - - libxcb=1.14=h7b6447c_0 - - libxml2=2.9.12=h03d6c58_0 - - lz4-c=1.9.3=h2531618_0 - - markupsafe=2.0.1=py37h5e8e339_0 - - matplotlib=3.3.4=py37h06a4308_0 - - matplotlib-base=3.3.4=py37h62a2d02_0 - - matplotlib-inline=0.1.2=pyhd8ed1ab_2 - - mistune=0.8.4=py37h5e8e339_1004 - - mkl=2021.2.0=h06a4308_296 - - mkl-service=2.3.0=py37h27cfd23_1 - - mkl_fft=1.3.0=py37h42c9631_2 - - mkl_random=1.2.1=py37ha9443f7_2 - - nbclient=0.5.3=pyhd8ed1ab_0 - - nbconvert=6.1.0=py37h89c1867_0 - - nbformat=5.1.3=pyhd8ed1ab_0 - - ncurses=6.2=he6710b0_1 - - nest-asyncio=1.5.1=pyhd8ed1ab_0 - - networkx=2.5.1=pyhd3eb1b0_0 - - notebook=6.4.0=pyha770c72_0 - - numpy=1.20.2=py37h2d18471_0 - - numpy-base=1.20.2=py37hfae3a4d_0 - - olefile=0.46=py37_0 - - openssl=1.1.1k=h7f98852_0 - - packaging=20.9=pyh44b312d_0 - - pandas=1.2.5=py37h295c915_0 - - pandoc=2.14.0.3=h7f98852_0 - - pandocfilters=1.4.2=py_1 - - parso=0.8.2=pyhd8ed1ab_0 - - pcre=8.45=h295c915_0 - - pexpect=4.8.0=pyh9f0ad1d_2 - - pickleshare=0.7.5=py_1003 - - pillow=8.2.0=py37he98fc37_0 - - pip=21.1.3=py37h06a4308_0 - - pixman=0.40.0=h7b6447c_0 - - pot=0.7.0=py37h3340039_0 - - prometheus_client=0.11.0=pyhd8ed1ab_0 - - prompt-toolkit=3.0.19=pyha770c72_0 - - prompt_toolkit=3.0.19=hd8ed1ab_0 - - ptyprocess=0.7.0=pyhd3deb0d_0 - - py-boost=1.73.0=py37ha9443f7_11 - - py3dmol=0.9.1=pyhd8ed1ab_0 - - pycparser=2.20=pyh9f0ad1d_2 - - pygments=2.9.0=pyhd8ed1ab_0 - - pyparsing=2.4.7=pyhd3eb1b0_0 - - pyqt=5.9.2=py37h05f1152_2 - - pyrsistent=0.17.3=py37h5e8e339_2 - - python=3.7.10=h12debd9_4 - - python-dateutil=2.8.1=pyhd3eb1b0_0 - - python_abi=3.7=2_cp37m - - pytz=2021.1=pyhd3eb1b0_0 - - pyyaml=5.3.1=py37h7b6447c_1 - - pyzmq=22.1.0=py37h336d617_0 - - qt=5.9.7=h5867ecd_1 - - qtconsole=5.1.1=pyhd8ed1ab_0 - - qtpy=1.9.0=py_0 - - rdkit=2020.09.1.0=py37hd50e099_1 - - readline=8.1=h27cfd23_0 - - scipy=1.6.2=py37had2a1c9_1 - - seaborn=0.11.1=pyhd3eb1b0_0 - - send2trash=1.7.1=pyhd8ed1ab_0 - - setuptools=52.0.0=py37h06a4308_0 - - sip=4.19.8=py37hf484d3e_0 - - six=1.16.0=pyhd3eb1b0_0 - - sqlite=3.36.0=hc218d9a_0 - - terminado=0.10.1=py37h89c1867_0 - - testpath=0.5.0=pyhd8ed1ab_0 - - tk=8.6.10=hbc83047_0 - - tornado=6.1=py37h27cfd23_0 - - tqdm=4.61.1=pyhd3eb1b0_1 - - traitlets=5.0.5=py_0 - - typing_extensions=3.10.0.0=pyha770c72_0 - - wcwidth=0.2.5=pyh9f0ad1d_2 - - webencodings=0.5.1=py_1 - - wheel=0.36.2=pyhd3eb1b0_0 - - widgetsnbextension=3.5.1=py37h89c1867_4 - - xz=5.2.5=h7b6447c_0 - - yaml=0.2.5=h7b6447c_0 - - zeromq=4.3.4=h9c3ff4c_0 - - zipp=3.4.1=pyhd8ed1ab_0 - - zlib=1.2.11=h7b6447c_3 - - zstd=1.4.9=haebb681_0 + - rdkit >=2020.03.2 + - networkx + - pot diff --git a/devtools/environment_reproduce.yml b/devtools/environment_reproduce.yml new file mode 100644 index 0000000..02962b2 --- /dev/null +++ b/devtools/environment_reproduce.yml @@ -0,0 +1,146 @@ +name: GeoMol +channels: + - rdkit + - anaconda + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=4.5=1_gnu + - argon2-cffi=20.1.0=py37h5e8e339_2 + - async_generator=1.10=py_0 + - attrs=21.2.0=pyhd8ed1ab_0 + - backcall=0.2.0=pyh9f0ad1d_0 + - backports=1.0=py_2 + - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 + - blas=1.0=mkl + - bleach=3.3.0=pyh44b312d_0 + - bzip2=1.0.8=h7b6447c_0 + - ca-certificates=2021.5.30=ha878542_0 + - cairo=1.16.0=hf32fb01_1 + - certifi=2021.5.30=py37h89c1867_0 + - cffi=1.14.5=py37hc58025e_0 + - cycler=0.10.0=py37_0 + - dbus=1.13.18=hb2f20db_0 + - decorator=4.4.2=pyhd3eb1b0_0 + - defusedxml=0.7.1=pyhd8ed1ab_0 + - entrypoints=0.3=pyhd8ed1ab_1003 + - expat=2.4.1=h2531618_2 + - fontconfig=2.13.1=h6c09931_0 + - freetype=2.10.4=h5ab3b9f_0 + - glib=2.68.2=h36276a3_0 + - gst-plugins-base=1.14.0=h8213a91_2 + - gstreamer=1.14.0=h28cd5cc_2 + - icu=58.2=he6710b0_3 + - importlib-metadata=4.6.0=py37h89c1867_0 + - intel-openmp=2021.2.0=h06a4308_610 + - ipykernel=5.5.5=py37h085eea5_0 + - ipython=7.25.0=py37h085eea5_1 + - ipython_genutils=0.2.0=py_1 + - ipywidgets=7.6.3=pyhd3deb0d_0 + - jedi=0.18.0=py37h89c1867_2 + - jinja2=3.0.1=pyhd8ed1ab_0 + - jpeg=9b=h024ee3a_2 + - jsonschema=3.2.0=pyhd8ed1ab_3 + - jupyter=1.0.0=py37h89c1867_6 + - jupyter_client=6.1.12=pyhd8ed1ab_0 + - jupyter_console=6.4.0=pyhd8ed1ab_0 + - jupyter_core=4.7.1=py37h89c1867_0 + - jupyterlab_pygments=0.1.2=pyh9f0ad1d_0 + - jupyterlab_widgets=1.0.0=pyhd8ed1ab_1 + - kiwisolver=1.3.1=py37h2531618_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.35.1=h7274673_9 + - libboost=1.73.0=h3ff78a5_11 + - libffi=3.3=he6710b0_2 + - libgcc-ng=9.3.0=h5101ec6_17 + - libgfortran-ng=7.5.0=ha8ba4b0_17 + - libgfortran4=7.5.0=ha8ba4b0_17 + - libgomp=9.3.0=h5101ec6_17 + - libpng=1.6.37=hbc83047_0 + - libsodium=1.0.18=h36c2ea0_1 + - libstdcxx-ng=9.3.0=hd4cf53a_17 + - libtiff=4.2.0=h85742a9_0 + - libuuid=1.0.3=h1bed415_2 + - libwebp-base=1.2.0=h27cfd23_0 + - libxcb=1.14=h7b6447c_0 + - libxml2=2.9.12=h03d6c58_0 + - lz4-c=1.9.3=h2531618_0 + - markupsafe=2.0.1=py37h5e8e339_0 + - matplotlib=3.3.4=py37h06a4308_0 + - matplotlib-base=3.3.4=py37h62a2d02_0 + - matplotlib-inline=0.1.2=pyhd8ed1ab_2 + - mistune=0.8.4=py37h5e8e339_1004 + - mkl=2021.2.0=h06a4308_296 + - mkl-service=2.3.0=py37h27cfd23_1 + - mkl_fft=1.3.0=py37h42c9631_2 + - mkl_random=1.2.1=py37ha9443f7_2 + - nbclient=0.5.3=pyhd8ed1ab_0 + - nbconvert=6.1.0=py37h89c1867_0 + - nbformat=5.1.3=pyhd8ed1ab_0 + - ncurses=6.2=he6710b0_1 + - nest-asyncio=1.5.1=pyhd8ed1ab_0 + - networkx=2.5.1=pyhd3eb1b0_0 + - notebook=6.4.0=pyha770c72_0 + - numpy=1.20.2=py37h2d18471_0 + - numpy-base=1.20.2=py37hfae3a4d_0 + - olefile=0.46=py37_0 + - openssl=1.1.1k=h7f98852_0 + - packaging=20.9=pyh44b312d_0 + - pandas=1.2.5=py37h295c915_0 + - pandoc=2.14.0.3=h7f98852_0 + - pandocfilters=1.4.2=py_1 + - parso=0.8.2=pyhd8ed1ab_0 + - pcre=8.45=h295c915_0 + - pexpect=4.8.0=pyh9f0ad1d_2 + - pickleshare=0.7.5=py_1003 + - pillow=8.2.0=py37he98fc37_0 + - pip=21.1.3=py37h06a4308_0 + - pixman=0.40.0=h7b6447c_0 + - pot=0.7.0=py37h3340039_0 + - prometheus_client=0.11.0=pyhd8ed1ab_0 + - prompt-toolkit=3.0.19=pyha770c72_0 + - prompt_toolkit=3.0.19=hd8ed1ab_0 + - ptyprocess=0.7.0=pyhd3deb0d_0 + - py-boost=1.73.0=py37ha9443f7_11 + - py3dmol=0.9.1=pyhd8ed1ab_0 + - pycparser=2.20=pyh9f0ad1d_2 + - pygments=2.9.0=pyhd8ed1ab_0 + - pyparsing=2.4.7=pyhd3eb1b0_0 + - pyqt=5.9.2=py37h05f1152_2 + - pyrsistent=0.17.3=py37h5e8e339_2 + - python=3.7.10=h12debd9_4 + - python-dateutil=2.8.1=pyhd3eb1b0_0 + - python_abi=3.7=2_cp37m + - pytz=2021.1=pyhd3eb1b0_0 + - pyyaml=5.3.1=py37h7b6447c_1 + - pyzmq=22.1.0=py37h336d617_0 + - qt=5.9.7=h5867ecd_1 + - qtconsole=5.1.1=pyhd8ed1ab_0 + - qtpy=1.9.0=py_0 + - rdkit=2020.09.1.0=py37hd50e099_1 + - readline=8.1=h27cfd23_0 + - scipy=1.6.2=py37had2a1c9_1 + - seaborn=0.11.1=pyhd3eb1b0_0 + - send2trash=1.7.1=pyhd8ed1ab_0 + - setuptools=52.0.0=py37h06a4308_0 + - sip=4.19.8=py37hf484d3e_0 + - six=1.16.0=pyhd3eb1b0_0 + - sqlite=3.36.0=hc218d9a_0 + - terminado=0.10.1=py37h89c1867_0 + - testpath=0.5.0=pyhd8ed1ab_0 + - tk=8.6.10=hbc83047_0 + - tornado=6.1=py37h27cfd23_0 + - tqdm=4.61.1=pyhd3eb1b0_1 + - traitlets=5.0.5=py_0 + - typing_extensions=3.10.0.0=pyha770c72_0 + - wcwidth=0.2.5=pyh9f0ad1d_2 + - webencodings=0.5.1=py_1 + - wheel=0.36.2=pyhd3eb1b0_0 + - widgetsnbextension=3.5.1=py37h89c1867_4 + - xz=5.2.5=h7b6447c_0 + - yaml=0.2.5=h7b6447c_0 + - zeromq=4.3.4=h9c3ff4c_0 + - zipp=3.4.1=pyhd8ed1ab_0 + - zlib=1.2.11=h7b6447c_3 + - zstd=1.4.9=haebb681_0 diff --git a/devtools/initialize_conda.sh b/devtools/initialize_conda.sh new file mode 100644 index 0000000..f44d76a --- /dev/null +++ b/devtools/initialize_conda.sh @@ -0,0 +1,19 @@ +echo "Initializing Conda..." + +if which mamba > /dev/null; then + conda_bin="mamba" + echo "As mamba is available, using mamba by default..." +else + conda_bin="conda" +fi + +conda_base_dir=$(dirname $(dirname $CONDA_EXE)) + +if [ "$conda_bin" = "mamba" ]; then + source "$conda_base_dir/etc/profile.d/conda.sh" + source "$conda_base_dir/etc/profile.d/mamba.sh" +else + source "$conda_base_dir/etc/profile.d/conda.sh" +fi + +export conda_bin diff --git a/devtools/install_pyg.sh b/devtools/install_pyg.sh new file mode 100644 index 0000000..f79653f --- /dev/null +++ b/devtools/install_pyg.sh @@ -0,0 +1,122 @@ +# A script to install Pytorch geometric on normal platform +# Author: Xiaorui Dong +# Inspired by this https://medium.com/@jgbrasier/installing-pytorch-geometric-on-mac-m1-with-accelerated-gpu-support-2e7118535c50 + +CONDA_ENV_NAME="GeoMol" +PYTHON_VERSION="3.12" +CUDA_VERSION="cpu" +SCRIPT_DIR=$(dirname $0) # Assume the other scripts are available in the same directory as this file + +# Function to display usage +usage() { + echo "Usage: $0 [-n ] [--name ] [-v ] [--python-version ] [- ] [--cuda-version ]" + exit 1 +} + +# Parse short options (-n and -v) +while getopts ":n:v:c:" opt; do + case ${opt} in + n ) + CONDA_ENV_NAME=$OPTARG + ;; + v ) + PYTHON_VERSION=$OPTARG + ;; + c ) + CUDA_VERSION=$OPTARG + ;; + \? ) + usage + ;; + esac +done + +# Remove the processed options from the parameters +shift $((OPTIND -1)) + +# Parse long options (--name and --version) +for arg in "$@"; do + case $arg in + --name=*) + CONDA_ENV_NAME="${arg#*=}" + shift # Remove --name from processing + ;; + --python_version=*) + PYTHON_VERSION="${arg#*=}" + shift # Remove --version from processing + ;; + --cuda_version=*) + CUDA_VERSION="${arg#*=}" + shift # Remove --version from processing + ;; + *) + usage + ;; + esac +done + +# parse cuda +# Using regex to capture the major and minor version numbers for detailed matching +if [[ "$(uname)" != 'Darwin' ]]; then + if [[ $CUDA_VERSION =~ ^([0-9]+)\.([0-9]+)(\.([0-9]+))?$ ]]; then + major_version="${BASH_REMATCH[1]}" + minor_version="${BASH_REMATCH[2]}" + cuda_version_formatted="${major_version}.${minor_version}" + + # Construct the CUDA and CUDA_VERSION variables based on input + CUDA="cudatoolkit=$cuda_version_formatted" + CUDA_VERSION="cu${major_version}${minor_version}" + elif [ "$cuda_input" == "cpu" ]; then + # For CPU-only selection + CUDA="cpuonly" + CUDA_VERSION="cpu" + else + echo "Invalid input. Please ensure you enter a valid CUDA version in the format xx.xx or 'cpu'." + exit 1 + fi +else + CUDA="cpuonly" + CUDA_VERSION="cpu" +fi +echo "You selected CUDA version: $CUDA_VERSION ($CUDA)" + +source $SCRIPT_DIR/initialize_conda.sh + +if conda env list | grep -qw $CONDA_ENV_NAME; then + $conda_bin activate $CONDA_ENV_NAME +else + $conda_bin create -n $CONDA_ENV_NAME python=$PYTHON_VERSION -y + $conda_bin activate $CONDA_ENV_NAME +fi + +# check Python version +PYTHON_VERSION=$(python --version) +echo "Using Python version: $PYTHON_VERSION" + +# install PyTorch +echo "Installing PyTorch with requested CUDA version $CUDA_VERSION..." +# echo "Running: conda install pytorch torchvision $CUDA -c pytorch -y" +# $conda_bin install pytorch torchvision $CUDA -c pytorch -y +echo "Running: pip install torch torchvision" +pip install torch torchvision --index-url https://download.pytorch.org/whl/$CUDA_VERSION + +# get PyTorch version +TORCH_VERSION=$(python -c "import torch; print(torch.__version__)") +if [ -n $TORCH_VERSION ]; then + echo "Using PyTorch version: $TORCH_VERSION" +else + echo "Cannot find a matched PyTorch version with $CUDA_VERSION for Python $PYTHON_VERSION. Exit." + # echo "Removing the installed environment" + # source deactivate + # $conda_bin env remove -n $environmentName + exit 1 +fi + +# install torch_geometric +echo "Installing torch-geometric..." +echo "Using CUDA version: $CUDA_VERSION" +pip install torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-${TORCH_VERSION}+${CUDA_VERSION}.html +pip install torch-geometric + +# install other package +$conda_bin env update -f $SCRIPT_DIR/environment.yml -n $CONDA_ENV_NAME \ No newline at end of file diff --git a/devtools/install_pyg_macos_arm64.sh b/devtools/install_pyg_macos_arm64.sh new file mode 100644 index 0000000..b13c7b5 --- /dev/null +++ b/devtools/install_pyg_macos_arm64.sh @@ -0,0 +1,82 @@ +# A script to install Pytorch geometric on Mchip MacOS +# Author: Xiaorui Dong +# Inspired by this https://medium.com/@jgbrasier/installing-pytorch-geometric-on-mac-m1-with-accelerated-gpu-support-2e7118535c50 + +CONDA_ENV_NAME="GeoMol" +PYTHON_VERSION="3.12" +SCRIPT_DIR=$(dirname $0) # Assume the other scripts are available in the same directory as this file + +# Function to display usage +usage() { + echo "Usage: $0 [-n ] [--name ] [-v ] [--version ]" + exit 1 +} + +# Parse short options (-n and -v) +while getopts ":n:v:" opt; do + case ${opt} in + n ) + CONDA_ENV_NAME=$OPTARG + ;; + v ) + PYTHON_VERSION=$OPTARG + ;; + \? ) + usage + ;; + esac +done + +# Remove the processed options from the parameters +shift $((OPTIND -1)) + +# Parse long options (--name and --version) +for arg in "$@"; do + case $arg in + --name=*) + CONDA_ENV_NAME="${arg#*=}" + shift # Remove --name from processing + ;; + --version=*) + PYTHON_VERSION="${arg#*=}" + shift # Remove --version from processing + ;; + *) + # Handle unrecognized options + usage + ;; + esac +done + +source $SCRIPT_DIR/initialize_conda.sh + +if conda env list | grep -qw $CONDA_ENV_NAME; then + $conda_bin activate $CONDA_ENV_NAME +else + $conda_bin create -n $CONDA_ENV_NAME python=$PYTHON_VERSION -y + $conda_bin activate $CONDA_ENV_NAME +fi + +PYTHON_VERSION=$(python --version) +echo "Using Python version: $PYTHON_VERSION" + +# make sure compiler are correctly installed +$conda_bin install -y clang_osx-arm64 clangxx_osx-arm64 gfortran_osx-arm64 + +os_version=$(sw_vers -productVersion) + +# install PyTorch and pytorch_geometric with the correct compiler +echo "Installing PyTorch..." +MACOSX_DEPLOYMENT_TARGET=$os_version CC=clang CXX=clang++ python -m pip --no-cache-dir install torch torchvision +TORCH_VERSION=$(python -c "import torch; print(torch.__version__)") + +MACOSX_DEPLOYMENT_TARGET=$os_version CC=clang CXX=clang++ \ +python -m pip --no-cache-dir install torch_scatter torch_sparse torch_cluster torch_spline_conv \ +-f https://data.pyg.org/whl/torch-${TORCH_VERSION}+cpu.html + +MACOSX_DEPLOYMENT_TARGET=$os_version CC=clang CXX=clang++ \ +python -m pip --no-cache-dir install torch-geometric + +# install other packages +$conda_bin env update -f $SCRIPT_DIR/environment.yml -n $CONDA_ENV_NAME +$conda_bin install nomkl From acdc7a8986d4d7f5f869e8fa19a45632666f867d Mon Sep 17 00:00:00 2001 From: Xiaorui Dong Date: Wed, 13 Mar 2024 22:15:50 -0400 Subject: [PATCH 21/28] Add a convenient path to the trained models --- geomol/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/geomol/utils.py b/geomol/utils.py index 08e0a46..6059293 100644 --- a/geomol/utils.py +++ b/geomol/utils.py @@ -1,3 +1,5 @@ +from pathlib import Path + import torch import torch_geometric as tg from torch_geometric.utils import degree @@ -18,6 +20,8 @@ [1, 3], [2, 3]]).to(device) +model_path = Path(__file__).parents[1] / "trained_models" + def get_neighbor_ids(data): """ From 50475e0b58ca4938261f25c1cb0603ef6ec9c966 Mon Sep 17 00:00:00 2001 From: Xiaorui Dong Date: Wed, 13 Mar 2024 23:42:18 -0400 Subject: [PATCH 22/28] Add YAML as requirement --- devtools/environment.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/devtools/environment.yml b/devtools/environment.yml index e131c5d..6750a7f 100644 --- a/devtools/environment.yml +++ b/devtools/environment.yml @@ -6,3 +6,5 @@ dependencies: - rdkit >=2020.03.2 - networkx - pot + - yaml + - pyyaml From 9d6f7688fef048a898d406506c20af6ca9d58eca Mon Sep 17 00:00:00 2001 From: Xiaorui Dong Date: Thu, 14 Mar 2024 11:43:28 -0400 Subject: [PATCH 23/28] Decouple model device assignment with available device Originally, model is assigned to GPU whenever possible. This causes conflicts when using a model on a CPU but on a machine with GPU. --- geomol/model.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/geomol/model.py b/geomol/model.py index 48f8a58..e1476eb 100644 --- a/geomol/model.py +++ b/geomol/model.py @@ -17,7 +17,7 @@ class GeoMol(nn.Module): - def __init__(self, hyperparams, num_node_features, num_edge_features): + def __init__(self, hyperparams, num_node_features, num_edge_features, device=None): super(GeoMol, self).__init__() self.model_dim = hyperparams['model_dim'] @@ -27,7 +27,10 @@ def __init__(self, hyperparams, num_node_features, num_edge_features): self.loss_type = hyperparams['loss_type'] self.teacher_force = hyperparams['teacher_force'] self.random_alpha = hyperparams['random_alpha'] - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + if device is None: + torch.device('cuda' if torch.cuda.is_available() else 'cpu') + else: + self.device = torch.device(device) self.gnn = GNN(node_dim=num_node_features + self.random_vec_dim, edge_dim=num_edge_features + self.random_vec_dim, From 563ef0c508e1ac65fd46088767aba605267bc0c8 Mon Sep 17 00:00:00 2001 From: Xiaorui Dong Date: Thu, 14 Mar 2024 11:48:38 -0400 Subject: [PATCH 24/28] Workaround for batch_angles_from_coords angle_mask_ref and angle_combos are predefined without awareness of where the operation will be. Add a temporary workaround. --- geomol/utils.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/geomol/utils.py b/geomol/utils.py index 6059293..8af762f 100644 --- a/geomol/utils.py +++ b/geomol/utils.py @@ -6,21 +6,20 @@ import networkx as nx from geomol.cycle_utils import get_current_cycle_indices -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +model_path = Path(__file__).parents[1] / "trained_models" + angle_mask_ref = torch.LongTensor([[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0], [1, 1, 1, 0, 0, 0], - [1, 1, 1, 1, 1, 1]]).to(device) + [1, 1, 1, 1, 1, 1]]) angle_combos = torch.LongTensor([[0, 1], [0, 2], [1, 2], [0, 3], [1, 3], - [2, 3]]).to(device) - -model_path = Path(__file__).parents[1] / "trained_models" + [2, 3]]) def get_neighbor_ids(data): @@ -123,7 +122,7 @@ def get_dihedral_pairs(edge_index, data): keep.append(pair) - keep = [t.to(device) for t in keep] + keep = [t for t in keep] return torch.stack(keep).t() @@ -165,6 +164,12 @@ def batch_angles_from_coords(coords, mask): """ Given coordinates, compute all local neighborhood angles """ + device = coords.device + + global angle_mask_ref, angle_combos + angle_mask_ref = angle_mask_ref.to(device) + angle_combos = angle_combos.to(device) + if coords.dim() == 4: all_possible_combos = coords[:, angle_combos] v_a, v_b = all_possible_combos.split(1, dim=2) # does one of these need to be negative? From d8d4058d0a9349ed721d6c28899966a5013290c3 Mon Sep 17 00:00:00 2001 From: Xiaorui Dong Date: Thu, 14 Mar 2024 11:51:37 -0400 Subject: [PATCH 25/28] introduce more accurate device control during inference Previously, inference's device are kind of hardcoded. This commit makes sure that device are corrected assigned, so that you can run inference on CPU or GPU on a machine with GPU or inference on CPU on a machine without GPU --- geomol/cycle_utils.py | 2 +- geomol/inference.py | 74 +++++++++++++++++++++++-------------------- 2 files changed, 41 insertions(+), 35 deletions(-) diff --git a/geomol/cycle_utils.py b/geomol/cycle_utils.py index 213cf27..005ed41 100644 --- a/geomol/cycle_utils.py +++ b/geomol/cycle_utils.py @@ -48,7 +48,7 @@ def align_coords_Kabsch(p_cycle_coords, q_cycle_coords, p_mask, q_mask=None): H = torch.matmul(p_cycle_coords_centered.permute(0, 1, 3, 2), q_cycle_coords_centered.unsqueeze(0)) u, s, v = torch.svd(H) d = torch.sign(torch.det(torch.matmul(v, u.permute(0, 1, 3, 2)))) - R_1 = torch.diag_embed(torch.ones([p_cycle_coords.size(0), q_cycle_coords.size(0), 3])) + R_1 = torch.diag_embed(torch.ones([p_cycle_coords.size(0), q_cycle_coords.size(0), 3], device=u.device)) R_1[:, :, 2, 2] = d R = torch.matmul(v, torch.matmul(R_1, u.permute(0, 1, 3, 2))) b = q_cycle_coords[:, q_mask].mean(dim=1) - torch.matmul(R, p_cycle_coords[:, :, p_mask].mean(dim=2).unsqueeze( diff --git a/geomol/inference.py b/geomol/inference.py index 1ba9df8..14b7598 100644 --- a/geomol/inference.py +++ b/geomol/inference.py @@ -5,16 +5,16 @@ from geomol.utils import batch_dihedrals from geomol.cycle_utils import * -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +def construct_conformers(data, model, device=None): -def construct_conformers(data, model): + device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu') G = nx.to_undirected(tg.utils.to_networkx(data)) cycles = nx.cycle_basis(G) - new_pos = torch.zeros([data.batch.size(0), model.n_model_confs, 3]) - dihedral_pairs = model.dihedral_pairs.t().detach().numpy() + new_pos = torch.zeros([data.batch.size(0), model.n_model_confs, 3], device=device) + dihedral_pairs = model.dihedral_pairs.t().detach().cpu().numpy() Sx = [] Sy = [] @@ -41,7 +41,8 @@ def construct_conformers(data, model): cycle_indices, new_pos, dihedral_pairs, - i) # i instead of i+1 + i, # i instead of i+1 + device) # new graph if x_index not in Sx: @@ -73,7 +74,7 @@ def construct_conformers(data, model): in_cycle = len(cycle_indices) + 1 # new graph - p_coords = torch.zeros([4, model.n_model_confs, 3]) + p_coords = torch.zeros([4, model.n_model_confs, 3], device=device) p_idx = model.neighbors[x_index] if x_index not in Sx: @@ -87,11 +88,11 @@ def construct_conformers(data, model): # update indices Sx.extend([x_index]) - Sx.extend(model.neighbors[x_index].detach().numpy()) + Sx.extend(model.neighbors[x_index].detach().cpu().numpy()) Sx = list(set(Sx)) Sy.extend([y_index]) - Sy.extend(model.neighbors[y_index].detach().numpy()) + Sy.extend(model.neighbors[y_index].detach().cpu().numpy()) # set px p_X = new_pos[x_index] @@ -101,11 +102,11 @@ def construct_conformers(data, model): # set Y if cycle_added: - cycle_avg_coords, cycle_avg_indices = smooth_cycle_coords(model, cycle_indices, new_pos, dihedral_pairs, i + 1) + cycle_avg_coords, cycle_avg_indices = smooth_cycle_coords(model, cycle_indices, new_pos, dihedral_pairs, i + 1, device) cycle_avg_coords = cycle_avg_coords - cycle_avg_coords[cycle_avg_indices == y_index] # move y to origin q_idx = model.neighbors[y_index] q_coords_mask = [True if a in q_idx else False for a in cycle_avg_indices] - q_coords = torch.zeros([4, model.n_model_confs, 3]) + q_coords = torch.zeros([4, model.n_model_confs, 3], device=device) q_reorder = np.argsort([np.where(a == q_idx)[0][0] for a in torch.tensor(cycle_avg_indices)[q_coords_mask]]) q_coords[0:sum(q_coords_mask)] = cycle_avg_coords[q_coords_mask][q_reorder] new_pos_Sy = cycle_avg_coords.clone() @@ -128,12 +129,12 @@ def construct_conformers(data, model): # translate q new_p_Y = new_pos_Sx_2[Sx == y_index] - transform_matrix = torch.diag(torch.tensor([-1., -1., 1.])).unsqueeze(0).unsqueeze(0) + transform_matrix = torch.diag(torch.tensor([-1., -1., 1.], device=device)).unsqueeze(0).unsqueeze(0) new_pos_Sy_3 = torch.matmul(transform_matrix, new_pos_Sy_2.unsqueeze(-1)).squeeze(-1) + new_p_Y # rotate by gamma H_gamma = calculate_gamma(model.n_model_confs, model.dihedral_mask[i], model.c_ij[i], model.v_star[i], Sx, Sy, - p_idx, q_idx, x_index, y_index, new_pos_Sx_2, new_pos_Sy_3, new_p_Y) + p_idx, q_idx, x_index, y_index, new_pos_Sx_2, new_pos_Sy_3, new_p_Y, device) new_pos_Sx_3 = torch.matmul(H_gamma.unsqueeze(0), new_pos_Sx_2.unsqueeze(-1)).squeeze(-1) # update all coordinates @@ -148,7 +149,7 @@ def construct_conformers(data, model): return new_pos -def smooth_cycle_coords(model, cycle_indices, new_pos, dihedral_pairs, cycle_start_idx): +def smooth_cycle_coords(model, cycle_indices, new_pos, dihedral_pairs, cycle_start_idx, device): # find index of cycle starting position cycle_len = len(cycle_indices) @@ -175,7 +176,7 @@ def smooth_cycle_coords(model, cycle_indices, new_pos, dihedral_pairs, cycle_sta x_indices, y_indices = pairs.transpose() - p_coords = torch.zeros([cycle_len, 4, model.n_model_confs, 3]) + p_coords = torch.zeros([cycle_len, 4, model.n_model_confs, 3], device=device) p_idx = [model.neighbors[x] for x in x_indices] if ii == 0: @@ -225,13 +226,13 @@ def smooth_cycle_coords(model, cycle_indices, new_pos, dihedral_pairs, cycle_sta # translate q new_p_Y = new_pos_Sx_2[i][Sx_cycle[i] == y_indices[i]].squeeze(-1) - transform_matrix = torch.diag(torch.tensor([-1., -1., 1.])).unsqueeze(0).unsqueeze(0) + transform_matrix = torch.diag(torch.tensor([-1., -1., 1.], device=device)).unsqueeze(0).unsqueeze(0) new_pos_Sy_3 = torch.matmul(transform_matrix, new_pos_Sy_2[i].unsqueeze(-1)).squeeze(-1) + new_p_Y # rotate by gamma H_gamma = calculate_gamma(model.n_model_confs, model.dihedral_mask[ids[i]], model.c_ij[ids[i]], model.v_star[ids[i]], Sx_cycle[i], Sy_cycle[i], p_idx[i], q_idx[i], pairs[i][0], - pairs[i][1], new_pos_Sx_2[i], new_pos_Sy_3, new_p_Y) + pairs[i][1], new_pos_Sx_2[i], new_pos_Sy_3, new_p_Y, device) new_pos_Sx_3 = torch.matmul(H_gamma.unsqueeze(0), new_pos_Sx_2[i].unsqueeze(-1)).squeeze(-1) # update all coordinates @@ -246,7 +247,7 @@ def smooth_cycle_coords(model, cycle_indices, new_pos, dihedral_pairs, cycle_sta if not np.all(ids == cycle_i_orders[-1]): Sy_cycle = [[] for i in range(cycle_len)] else: - cycle_mask = torch.ones([cycle_pos.size(0), cycle_pos.size(1)]) + cycle_mask = torch.ones([cycle_pos.size(0), cycle_pos.size(1)], device=device) for i in range(cycle_len): cycle_mask[i, y_indices[i]] = 0 y_neighbor_ids = model.neighbors[y_indices[i]] @@ -277,10 +278,10 @@ def smooth_cycle_coords(model, cycle_indices, new_pos, dihedral_pairs, cycle_sta def construct_conformers_acyclic(data, n_true_confs, n_model_confs, dihedral_pairs, neighbors, model_p_coords, model_q_coords, dihedral_x_mask, dihedral_y_mask, x_map_to_neighbor_y, - y_map_to_neighbor_x, dihedral_mask, c_ij, v_star): + y_map_to_neighbor_x, dihedral_mask, c_ij, v_star, device): pos = torch.cat([torch.cat([p[0][i] for p in data.pos]).unsqueeze(1) for i in range(n_true_confs)], dim=1) - new_pos = torch.zeros([pos.size(0), n_model_confs, 3]).to(device) + new_pos = torch.zeros([pos.size(0), n_model_confs, 3], device=device) dihedral_pairs = dihedral_pairs.t().detach().cpu().numpy() Sx = [] @@ -295,7 +296,7 @@ def construct_conformers_acyclic(data, n_true_confs, n_model_confs, dihedral_pai continue # new graph - p_coords = torch.zeros([4, n_model_confs, 3]).to(device) + p_coords = torch.zeros([4, n_model_confs, 3], device=device) p_idx = neighbors[x_index] if x_index not in Sx: @@ -309,11 +310,11 @@ def construct_conformers_acyclic(data, n_true_confs, n_model_confs, dihedral_pai # update indices Sx.extend([x_index]) - Sx.extend(neighbors[x_index].detach().numpy()) + Sx.extend(neighbors[x_index].detach().cpu().numpy()) Sx = list(set(Sx)) Sy.extend([y_index]) - Sy.extend(neighbors[y_index].detach().numpy()) + Sy.extend(neighbors[y_index].detach().cpu().numpy()) # set px p_X = new_pos[x_index] @@ -338,12 +339,12 @@ def construct_conformers_acyclic(data, n_true_confs, n_model_confs, dihedral_pai # translate q new_p_Y = new_pos_Sx_2[Sx == y_index] - transform_matrix = torch.diag(torch.tensor([-1., -1., 1.])).unsqueeze(0).unsqueeze(0) + transform_matrix = torch.diag(torch.tensor([-1., -1., 1.], device=device)).unsqueeze(0).unsqueeze(0) new_pos_Sy_3 = torch.matmul(transform_matrix, new_pos_Sy_2.unsqueeze(-1)).squeeze(-1) + new_p_Y # rotate by gamma H_gamma = calculate_gamma(n_model_confs, dihedral_mask[i], c_ij[i], v_star[i], Sx, Sy, p_idx, q_idx, x_index, - y_index, new_pos_Sx_2, new_pos_Sy_3, new_p_Y) + y_index, new_pos_Sx_2, new_pos_Sy_3, new_p_Y, device) new_pos_Sx_3 = torch.matmul(H_gamma.unsqueeze(0), new_pos_Sx_2.unsqueeze(-1)).squeeze(-1) # update all coordinates @@ -364,10 +365,10 @@ def construct_conformers_acyclic(data, n_true_confs, n_model_confs, dihedral_pai def calculate_gamma(n_model_confs, dihedral_mask, c_ij, v_star, Sx, Sy, p_idx, q_idx, x_index, y_index, - new_pos_Sx_2, new_pos_Sy_3, new_p_Y): + new_pos_Sx_2, new_pos_Sy_3, new_p_Y, device): # calculate current dihedrals - pT_prime = torch.zeros([3, n_model_confs, 3]).to(device) - qZ_translated = torch.zeros([3, n_model_confs, 3]).to(device) + pT_prime = torch.zeros([3, n_model_confs, 3], device=device) + qZ_translated = torch.zeros([3, n_model_confs, 3], device=device) pY_prime = new_p_Y.repeat(9, 1, 1) qX = torch.zeros_like(pY_prime) @@ -379,19 +380,24 @@ def calculate_gamma(n_model_confs, dihedral_mask, c_ij, v_star, Sx, Sy, p_idx, q qZ_translated[:len(q_ids_in_Sy)] = new_pos_Sy_3[q_ids_in_Sy] XYTi_XYZj_curr_sin, XYTi_XYZj_curr_cos = batch_dihedrals(pT_prime[pT_idx], qX, pY_prime, qZ_translated[qZ_idx]) - A_ij = build_A_matrix_inf(XYTi_XYZj_curr_sin, XYTi_XYZj_curr_cos, n_model_confs) * dihedral_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + A_ij = build_A_matrix_inf( + XYTi_XYZj_curr_sin, + XYTi_XYZj_curr_cos, + n_model_confs, + device=device, + ) * dihedral_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) # build A matrix A_curr = torch.sum(A_ij * c_ij.unsqueeze(-1), dim=0) determinants = torch.det(A_curr) + 1e-10 - A_curr_inv_ = A_curr.view(n_model_confs, 4)[:, [3, 1, 2, 0]] * torch.tensor([[1., -1., -1., 1.]]) + A_curr_inv_ = A_curr.view(n_model_confs, 4)[:, [3, 1, 2, 0]] * torch.tensor([[1., -1., -1., 1.]], device=device) A_curr_inv = (A_curr_inv_ / determinants.unsqueeze(-1)).view(n_model_confs, 2, 2) A_curr_inv_v_star = torch.matmul(A_curr_inv, v_star.unsqueeze(-1)).squeeze(-1) # get gamma matrix v_gamma = A_curr_inv_v_star / (A_curr_inv_v_star.norm(dim=-1, keepdim=True) + 1e-10) gamma_cos, gamma_sin = v_gamma.split(1, dim=-1) - H_gamma = build_gamma_rotation_inf(gamma_sin.squeeze(-1), gamma_cos.squeeze(-1), n_model_confs) + H_gamma = build_gamma_rotation_inf(gamma_sin.squeeze(-1), gamma_cos.squeeze(-1), n_model_confs, device) return H_gamma @@ -428,9 +434,9 @@ def rotation_matrix_inf_v2(neighbor_coords, neighbor_map): return H -def build_A_matrix_inf(curr_sin, curr_cos, n_model_confs): +def build_A_matrix_inf(curr_sin, curr_cos, n_model_confs, device): - A_ij = torch.FloatTensor([[[[0, 0], [0, 0]]]]).repeat(9, n_model_confs, 1, 1) + A_ij = torch.FloatTensor([[[[0, 0], [0, 0]]]]).repeat(9, n_model_confs, 1, 1).to(device) A_ij[:, :, 0, 0] = curr_cos A_ij[:, :, 0, 1] = curr_sin A_ij[:, :, 1, 0] = curr_sin @@ -439,8 +445,8 @@ def build_A_matrix_inf(curr_sin, curr_cos, n_model_confs): return A_ij -def build_gamma_rotation_inf(gamma_sin, gamma_cos, n_model_confs): - H_gamma = torch.FloatTensor([[[1, 0, 0], [0, 0, 0], [0, 0, 0]]]).repeat(n_model_confs, 1, 1) +def build_gamma_rotation_inf(gamma_sin, gamma_cos, n_model_confs, device): + H_gamma = torch.FloatTensor([[[1, 0, 0], [0, 0, 0], [0, 0, 0]]]).repeat(n_model_confs, 1, 1).to(device) H_gamma[:, 1, 1] = gamma_cos H_gamma[:, 1, 2] = -gamma_sin H_gamma[:, 2, 1] = gamma_sin From f8cccd4ddde746cbe871bd3010a483716a1a224a Mon Sep 17 00:00:00 2001 From: Xiaorui Dong Date: Thu, 14 Mar 2024 13:12:26 -0400 Subject: [PATCH 26/28] Make the model device agnotic --- geomol/model.py | 78 +++++++++++++++++++++++-------------------------- 1 file changed, 37 insertions(+), 41 deletions(-) diff --git a/geomol/model.py b/geomol/model.py index e1476eb..b9503b0 100644 --- a/geomol/model.py +++ b/geomol/model.py @@ -17,7 +17,7 @@ class GeoMol(nn.Module): - def __init__(self, hyperparams, num_node_features, num_edge_features, device=None): + def __init__(self, hyperparams, num_node_features, num_edge_features): super(GeoMol, self).__init__() self.model_dim = hyperparams['model_dim'] @@ -27,10 +27,6 @@ def __init__(self, hyperparams, num_node_features, num_edge_features, device=Non self.loss_type = hyperparams['loss_type'] self.teacher_force = hyperparams['teacher_force'] self.random_alpha = hyperparams['random_alpha'] - if device is None: - torch.device('cuda' if torch.cuda.is_available() else 'cpu') - else: - self.device = torch.device(device) self.gnn = GNN(node_dim=num_node_features + self.random_vec_dim, edge_dim=num_edge_features + self.random_vec_dim, @@ -110,18 +106,18 @@ def forward(self, data, ignore_neighbors=False, inference=False, n_model_confs=N pos_mask_L2 = pos_mask.view(molecule_loss.size(2), self.n_true_confs).t() pos_mask_L1 = pos_mask_L2.unsqueeze(1).repeat(1, self.n_model_confs, 1) - molecule_loss = torch.where(pos_mask_L1 == 1, molecule_loss, torch.FloatTensor([9e99]).to(self.device)) + molecule_loss = torch.where(pos_mask_L1 == 1, molecule_loss, torch.FloatTensor([9e99], device=molecule_loss.device)) if self.loss_type == 'implicit_mle': if DEBUG_NEIGHBORHOOD_PAIRS or self.teacher_force: L1 = torch.where(pos_mask_L2 == 1, torch.min(molecule_loss, dim=0).values, - torch.FloatTensor([0]).to(self.device)).sum(dim=0) / pos_mask_L2.sum(dim=0) + torch.FloatTensor([0], device=pos_mask_L2.device)).sum(dim=0) / pos_mask_L2.sum(dim=0) else: L1 = torch.min(molecule_loss, dim=0).values.sum(dim=0) / self.n_model_confs L2 = torch.where(pos_mask_L2 == 1, torch.min(molecule_loss, dim=1).values, - torch.FloatTensor([0]).to(self.device)).sum(dim=0) / pos_mask_L2.sum(dim=0) + torch.FloatTensor([0], device=pos_mask_L2.device)).sum(dim=0) / pos_mask_L2.sum(dim=0) # logging self.run_writer_mle(True if L1.mean() > L2.mean() else False, molecule_loss, pos_mask_L2) @@ -142,13 +138,13 @@ def forward(self, data, ignore_neighbors=False, inference=False, n_model_confs=N if self.teacher_force: cost_mat_i = cost_mat_detach[i, :n_true_confs_batch[i], :n_true_confs_batch[i]] ot_mat = ot.emd(a=H_1, b=H_1, M=np.max(np.abs(cost_mat_i)) + cost_mat_i, numItermax=10000) - ot_mat_attached = torch.tensor(ot_mat, device=self.device, requires_grad=False).float() + ot_mat_attached = torch.tensor(ot_mat, device=molecule_loss.device, requires_grad=False).float() ot_mat_list.append(ot_mat_attached) loss += torch.sum(ot_mat_attached * molecule_loss[:n_true_confs_batch[i], :n_true_confs_batch[i], i]) else: cost_mat_i = cost_mat_detach[i, :n_true_confs_batch[i]] ot_mat = ot.emd(a=H_1, b=H_2, M=np.max(np.abs(cost_mat_i)) + cost_mat_i, numItermax=10000) - ot_mat_attached = torch.tensor(ot_mat, device=self.device, requires_grad=False).float() + ot_mat_attached = torch.tensor(ot_mat, device=molecule_loss.device, requires_grad=False).float() ot_mat_list.append(ot_mat_attached) loss += torch.sum(ot_mat_attached * molecule_loss[:n_true_confs_batch[i], :, i]) @@ -168,13 +164,13 @@ def assign_neighborhoods(self, x, edge_index, edge_attr, batch, data): self.n_dihedral_pairs = len(self.dihedral_pairs.t()) # mask for neighbors - self.neighbor_mask = torch.zeros([self.n_neighborhoods, 4]).to(self.device) + self.neighbor_mask = torch.zeros([self.n_neighborhoods, 4], device=x.device) # maps node index to hidden index as given by self.neighbors self.x_to_h_map = torch.zeros(x.size(0)) # maps local neighborhood to batch molecule - self.neighborhood_to_mol_map = torch.zeros(self.n_neighborhoods, dtype=torch.int64).to(self.device) + self.neighborhood_to_mol_map = torch.zeros(self.n_neighborhoods, dtype=torch.int64, device=x.device) for i, (a, n) in enumerate(self.neighbors.items()): self.x_to_h_map[a] = i @@ -183,18 +179,18 @@ def assign_neighborhoods(self, x, edge_index, edge_attr, batch, data): self.neighborhood_to_mol_map[i] = batch[a] # maps which atom in (x,y) corresponds to the same atom in (y,x) for each dihedral pair - self.x_map_to_neighbor_y = torch.zeros([self.n_dihedral_pairs, 4]).to(self.device) - self.y_map_to_neighbor_x = torch.zeros([self.n_dihedral_pairs, 4]).to(self.device) + self.x_map_to_neighbor_y = torch.zeros_like(self.neighbor_mask) + self.y_map_to_neighbor_x = torch.zeros_like(self.neighbor_mask) # neighbor mask but for dihedral pairs - self.dihedral_x_mask = torch.zeros([self.n_dihedral_pairs, 4]).to(self.device) - self.dihedral_y_mask = torch.zeros([self.n_dihedral_pairs, 4]).to(self.device) + self.dihedral_x_mask = torch.zeros_like(self.neighbor_mask) + self.dihedral_y_mask = torch.zeros_like(self.neighbor_mask) # maps neighborhood pair to batch molecule - self.neighborhood_pairs_to_mol_map = torch.zeros(self.n_dihedral_pairs, dtype=torch.int64).to(self.device) + self.neighborhood_pairs_to_mol_map = torch.zeros(self.n_dihedral_pairs, dtype=torch.int64, device=x.device) # indicates which type of bond is formed by X-Y - self.xy_bond_type = torch.zeros([self.n_dihedral_pairs, 4]).to(self.device) + self.xy_bond_type = torch.zeros_like(self.neighbor_mask) for i, (s, e) in enumerate(self.dihedral_pairs.t()): # this indicates which neighbor is the correct x <--> y map (see overleaf doc) @@ -219,8 +215,8 @@ def embed(self, x, edge_index, edge_attr, batch): # stochasticity rand_dist = torch.distributions.normal.Normal(loc=0, scale=self.random_vec_std) # rand_dist = torch.distributions.uniform.Uniform(torch.tensor([0.0]), torch.tensor([1.0])) - rand_x = rand_dist.sample([x.size(0), self.n_model_confs, self.random_vec_dim]).squeeze(-1).to(self.device) # added squeeze - rand_edge = rand_dist.sample([edge_attr.size(0), self.n_model_confs, self.random_vec_dim]).squeeze(-1).to(self.device) # added squeeze + rand_x = rand_dist.sample([x.size(0), self.n_model_confs, self.random_vec_dim]).squeeze(-1).to(x.device) # added squeeze + rand_edge = rand_dist.sample([edge_attr.size(0), self.n_model_confs, self.random_vec_dim]).squeeze(-1).to(edge_attr.device) # added squeeze x = torch.cat([x.unsqueeze(1).repeat(1, self.n_model_confs, 1), rand_x], dim=-1) edge_attr = torch.cat([edge_attr.unsqueeze(1).repeat(1, self.n_model_confs, 1), rand_edge], dim=-1) @@ -271,8 +267,8 @@ def embed(self, x, edge_index, edge_attr, batch): def model_local_stats(self, x, chiral_tag): - n_h = torch.zeros([self.n_neighborhoods, 4, self.n_model_confs, self.model_dim]).to(self.device) - x_h = torch.zeros([self.n_neighborhoods, self.n_model_confs, self.model_dim]).to(self.device) + n_h = torch.zeros([self.n_neighborhoods, 4, self.n_model_confs, self.model_dim], device=x.device) + x_h = torch.zeros([self.n_neighborhoods, self.n_model_confs, self.model_dim], device=x.device) for i, (a, n) in enumerate(self.neighbors.items()): n_h[i, 0:len(n), :] = x[n] @@ -322,7 +318,7 @@ def model_local_stats(self, x, chiral_tag): self.neighbor_mask) if self.teacher_force: - R = random_rotation_matrix([self.n_neighborhoods, 1, self.n_model_confs]).to(self.device) + R = random_rotation_matrix([self.n_neighborhoods, 1, self.n_model_confs]).to(self.true_local_coords.device) self.model_local_coords = torch.matmul(R, self.true_local_coords[:, 0].unsqueeze(-1)).squeeze(-1) return model_one_hop, model_two_hop, model_angles @@ -342,13 +338,13 @@ def ground_truth_local_stats(self, pos): """ n_neighborhoods = len(self.neighbors) - self.true_local_coords = torch.zeros(n_neighborhoods, 6, 4, self.n_true_confs, 3).to(self.device) + self.true_local_coords = torch.zeros(n_neighborhoods, 6, 4, self.n_true_confs, 3, device=pos.device) for i, (a, n) in enumerate(self.neighbors.items()): # permutations for symmetric hydrogens n_perms = n.unsqueeze(0).repeat(6, 1) - perms = torch.tensor(list(permutations(n[self.leaf_hydrogens[a]]))).to(self.device) + perms = torch.tensor(list(permutations(n[self.leaf_hydrogens[a]])), device=n_perms.device) if perms.size(1) != 0: n_perms[0:len(perms), self.leaf_hydrogens[a]] = perms @@ -391,13 +387,13 @@ def model_pair_stats(self, x, batch, h_mol): :return: tuple of true stats (dihedral and three-hop), each with size (n_dihedral_pairs, 9, n_true_confs) """ - dihedral_x_neighbors = torch.zeros([self.n_dihedral_pairs, 4, self.n_model_confs, 3]).to(self.device) - dihedral_x_node_reps = torch.zeros([self.n_dihedral_pairs, self.n_model_confs, self.model_dim]).to(self.device) - dihedral_x_neighbor_reps = torch.zeros([self.n_dihedral_pairs, 4, self.n_model_confs, self.model_dim]).to(self.device) + dihedral_x_neighbors = torch.zeros([self.n_dihedral_pairs, 4, self.n_model_confs, 3], device=x.device) + dihedral_x_node_reps = torch.zeros([self.n_dihedral_pairs, self.n_model_confs, self.model_dim], device=x.device) + dihedral_x_neighbor_reps = torch.zeros([self.n_dihedral_pairs, 4, self.n_model_confs, self.model_dim], device=x.device) - dihedral_y_neighbors = torch.zeros([self.n_dihedral_pairs, 4, self.n_model_confs, 3]).to(self.device) - dihedral_y_node_reps = torch.zeros([self.n_dihedral_pairs, self.n_model_confs, self.model_dim]).to(self.device) - dihedral_y_neighbor_reps = torch.zeros([self.n_dihedral_pairs, 4, self.n_model_confs, self.model_dim]).to(self.device) + dihedral_y_neighbors = torch.zeros_like(dihedral_x_neighbors) + dihedral_y_node_reps = torch.zeros_like(dihedral_x_node_reps) + dihedral_y_neighbor_reps = torch.zeros_like(dihedral_x_neighbor_reps) for i, (s, e) in enumerate(self.dihedral_pairs.t()): @@ -465,7 +461,7 @@ def ground_truth_pair_stats(self, pos): """ n_dihedral_pairs = len(self.dihedral_pairs.t()) - true_dihedral_coords = torch.zeros([n_dihedral_pairs, 4, 4, 6, self.n_true_confs, 3]).to(self.device) + true_dihedral_coords = torch.zeros([n_dihedral_pairs, 4, 4, 6, self.n_true_confs, 3], device=pos.device) for i, (s, e) in enumerate(self.dihedral_pairs.t()): # construct true coordinates (order is x_n, x, y, y_n) @@ -473,8 +469,8 @@ def ground_truth_pair_stats(self, pos): y_neighbor_map_perms = self.neighbors[e.item()].unsqueeze(1).repeat(1, 6) # permutations for symmetric hydrogens - x_perms = torch.tensor(list(permutations(self.neighbors[s.item()][self.leaf_hydrogens[s.item()]]))).t().to(self.device) - y_perms = torch.tensor(list(permutations(self.neighbors[e.item()][self.leaf_hydrogens[e.item()]]))).t().to(self.device) + x_perms = torch.tensor(list(permutations(self.neighbors[s.item()][self.leaf_hydrogens[s.item()]]))).t().to(pos.device) + y_perms = torch.tensor(list(permutations(self.neighbors[e.item()][self.leaf_hydrogens[e.item()]]))).t().to(pos.device) if x_perms.size(0) != 0: x_neighbor_map_perms[self.leaf_hydrogens[s.item()], 0:x_perms.size(1)] = x_perms @@ -588,7 +584,7 @@ def align_dihedral_neighbors(self, dihedral_node_reps, dihedral_neighbors, batch p_Y_prime = p_H[self.x_map_to_neighbor_y.bool()] q_X_prime = q_H[self.y_map_to_neighbor_x.bool()] - transform_matrix = torch.diag(torch.tensor([-1., -1., 1.]).to(self.device)).unsqueeze(0).unsqueeze(0).unsqueeze(0) + transform_matrix = torch.diag(torch.tensor([-1., -1., 1.], device=q_Z_prime.device)).unsqueeze(0).unsqueeze(0).unsqueeze(0) q_Z_translated = torch.matmul(transform_matrix, q_Z_prime.unsqueeze(-1)).squeeze(-1) + p_Y_prime.unsqueeze(1) # broadcast over not coordinates @@ -598,7 +594,7 @@ def align_dihedral_neighbors(self, dihedral_node_reps, dihedral_neighbors, batch # more stochasticity! if self. random_alpha: rand_dist = torch.distributions.normal.Normal(loc=0, scale=self.random_vec_std) - rand_alpha = rand_dist.sample([self.n_dihedral_pairs, self.n_model_confs, self.random_vec_dim]).squeeze(-1).to(self.device) + rand_alpha = rand_dist.sample([self.n_dihedral_pairs, self.n_model_confs, self.random_vec_dim]).squeeze(-1).to(dihedral_x_node_reps.device) alpha \ = self.alpha_mlp(torch.cat([dihedral_x_node_reps, dihedral_y_node_reps, dihedral_h_mol, rand_alpha], dim=-1)) \ + self.alpha_mlp(torch.cat([dihedral_y_node_reps, dihedral_x_node_reps, dihedral_h_mol, rand_alpha], dim=-1)) @@ -628,12 +624,12 @@ def align_dihedral_neighbors(self, dihedral_node_reps, dihedral_neighbors, batch + self.c_mlp(torch.cat([q_reps[:, qZ_idx], cy_reps, p_reps[:, pT_idx], cx_reps], dim=-1)) # calculate gamma sin and cos - A_ij = self.build_A_matrix(XYTi_XYZj_curr_sin, XYTi_XYZj_curr_cos) * self.dihedral_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + A_ij = self.build_A_matrix(XYTi_XYZj_curr_sin, XYTi_XYZj_curr_cos).to(self.dihedral_mask.device) * self.dihedral_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) A_curr = torch.sum(A_ij * self.c_ij.unsqueeze(-1), dim=1) determinants = torch.det(A_curr) + 1e-10 A_curr_inv_ = A_curr.view(self.n_dihedral_pairs, self.n_model_confs, - 4)[:, :, [3, 1, 2, 0]] * torch.tensor([[[1., -1., -1., 1.]]]).to(self.device) + 4)[:, :, [3, 1, 2, 0]] * torch.tensor([[[1., -1., -1., 1.]]], device=A_curr.device) A_curr_inv = (A_curr_inv_ / determinants.unsqueeze(-1)).view(self.n_dihedral_pairs, self.n_model_confs, 2, 2) A_curr_inv_v_star = torch.matmul(A_curr_inv, self.v_star.unsqueeze(-1)).squeeze(-1) @@ -641,7 +637,7 @@ def align_dihedral_neighbors(self, dihedral_node_reps, dihedral_neighbors, batch gamma_cos, gamma_sin = v_gamma.split(1, dim=-1) # rotate p_coords by gamma - H_gamma = self.build_alpha_rotation(gamma_sin.squeeze(-1), gamma_cos.squeeze(-1)) + H_gamma = self.build_alpha_rotation(gamma_sin.squeeze(-1), gamma_cos.squeeze(-1)).to(p_T_prime.device) p_T_alpha = torch.matmul(H_gamma.unsqueeze(1), p_T_prime.unsqueeze(-1)).squeeze(-1) return q_Z_prime, p_T_alpha, p_Y_prime, q_Z_translated @@ -653,7 +649,7 @@ def build_alpha_rotation(self, alpha, alpha_cos=None): :param alpha: predicted values of torsion parameter alpha (n_dihedral_pairs, n_model_confs) :return: alpha rotation matrix (n_dihedral_pairs, n_model_confs, 3, 3) """ - H_alpha = torch.FloatTensor([[[[1, 0, 0], [0, 0, 0], [0, 0, 0]]]]).repeat(self.n_dihedral_pairs, self.n_model_confs, 1, 1).to(self.device) + H_alpha = torch.FloatTensor([[[[1, 0, 0], [0, 0, 0], [0, 0, 0]]]]).repeat(self.n_dihedral_pairs, self.n_model_confs, 1, 1) if torch.is_tensor(alpha_cos): H_alpha[:, :, 1, 1] = alpha_cos @@ -670,7 +666,7 @@ def build_alpha_rotation(self, alpha, alpha_cos=None): def build_A_matrix(self, curr_sin, curr_cos): - A_ij = torch.FloatTensor([[[[[0, 0], [0, 0]]]]]).repeat(self.n_dihedral_pairs, 9, self.n_model_confs, 1, 1).to(self.device) + A_ij = torch.FloatTensor([[[[[0, 0], [0, 0]]]]]).repeat(self.n_dihedral_pairs, 9, self.n_model_confs, 1, 1) A_ij[:, :, :, 0, 0] = curr_cos A_ij[:, :, :, 0, 1] = curr_sin A_ij[:, :, :, 1, 0] = curr_sin From 689d572d219fa196dabe124e07437ecb6bbca835 Mon Sep 17 00:00:00 2001 From: Xiaorui Dong Date: Thu, 14 Mar 2024 13:53:34 -0400 Subject: [PATCH 27/28] Fix shapes for dihedral_pairs-related operations in model --- geomol/model.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/geomol/model.py b/geomol/model.py index b9503b0..399be04 100644 --- a/geomol/model.py +++ b/geomol/model.py @@ -179,18 +179,18 @@ def assign_neighborhoods(self, x, edge_index, edge_attr, batch, data): self.neighborhood_to_mol_map[i] = batch[a] # maps which atom in (x,y) corresponds to the same atom in (y,x) for each dihedral pair - self.x_map_to_neighbor_y = torch.zeros_like(self.neighbor_mask) - self.y_map_to_neighbor_x = torch.zeros_like(self.neighbor_mask) + self.x_map_to_neighbor_y = torch.zeros([self.n_dihedral_pairs, 4], device=x.device) + self.y_map_to_neighbor_x = torch.zeros_like(self.x_map_to_neighbor_y) # neighbor mask but for dihedral pairs - self.dihedral_x_mask = torch.zeros_like(self.neighbor_mask) - self.dihedral_y_mask = torch.zeros_like(self.neighbor_mask) + self.dihedral_x_mask = torch.zeros_like(self.x_map_to_neighbor_y) + self.dihedral_y_mask = torch.zeros_like(self.dihedral_x_mask) # maps neighborhood pair to batch molecule self.neighborhood_pairs_to_mol_map = torch.zeros(self.n_dihedral_pairs, dtype=torch.int64, device=x.device) # indicates which type of bond is formed by X-Y - self.xy_bond_type = torch.zeros_like(self.neighbor_mask) + self.xy_bond_type = torch.zeros_like(self.x_map_to_neighbor_y) for i, (s, e) in enumerate(self.dihedral_pairs.t()): # this indicates which neighbor is the correct x <--> y map (see overleaf doc) From 70341a480a42c5ad84d923bb55d838d7d9f5791e Mon Sep 17 00:00:00 2001 From: Xiaorui Dong Date: Thu, 14 Mar 2024 16:32:33 -0400 Subject: [PATCH 28/28] Fix device mismatch during inference for a ring molecule Molecule CN1C2=C(C=C(C=C2)Cl)C(=NCC1=O)C3=CC=CC=C3 helps identify this problem --- geomol/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/geomol/inference.py b/geomol/inference.py index 14b7598..93fb117 100644 --- a/geomol/inference.py +++ b/geomol/inference.py @@ -107,7 +107,7 @@ def construct_conformers(data, model, device=None): q_idx = model.neighbors[y_index] q_coords_mask = [True if a in q_idx else False for a in cycle_avg_indices] q_coords = torch.zeros([4, model.n_model_confs, 3], device=device) - q_reorder = np.argsort([np.where(a == q_idx)[0][0] for a in torch.tensor(cycle_avg_indices)[q_coords_mask]]) + q_reorder = torch.argsort(torch.tensor([torch.where(a == q_idx)[0][0] for a in torch.tensor(cycle_avg_indices)[q_coords_mask]])) q_coords[0:sum(q_coords_mask)] = cycle_avg_coords[q_coords_mask][q_reorder] new_pos_Sy = cycle_avg_coords.clone() Sy = cycle_avg_indices