# ligand_graph_features.py
# Converts RDKit molecules into PyTorch Geometric graph Data objects using index-encoded atom and bond features.
from rdkit import Chem
import torch
from torch_geometric.data import Data
import numpy as np
# allowable node and edge features
# allowable_features = {
# 'possible_atomic_num_list' : list(range(1, 119)),
# 'possible_formal_charge_list' : [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5],
# 'possible_chirality_list' : [
# Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
# Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
# Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
# Chem.rdchem.ChiralType.CHI_OTHER
# ],
# 'possible_hybridization_list' : [
# Chem.rdchem.HybridizationType.S,
# Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
# Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
# Chem.rdchem.HybridizationType.SP3D2, Chem.rdchem.HybridizationType.UNSPECIFIED
# ],
# 'possible_numH_list' : [0, 1, 2, 3, 4, 5, 6, 7, 8],
# 'possible_implicit_valence_list' : [0, 1, 2, 3, 4, 5, 6],
# 'possible_degree_list' : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
# 'possible_bonds' : [
# Chem.rdchem.BondType.SINGLE,
# Chem.rdchem.BondType.DOUBLE,
# Chem.rdchem.BondType.TRIPLE,
# Chem.rdchem.BondType.AROMATIC
# ],
# 'possible_bond_dirs' : [ # only for double bond stereo information
# Chem.rdchem.BondDir.NONE,
# Chem.rdchem.BondDir.ENDUPRIGHT,
# Chem.rdchem.BondDir.ENDDOWNRIGHT
# ]
# }
# Vocabulary of categorical atom/bond properties (from Yang).
# mol_to_graph_data_obj_simple encodes a property as the position of its value
# inside the matching list, so the ORDER of every list below is significant.
allowable_features = {
    # Atomic numbers 1..118.
    "possible_atomic_num_list": list(range(1, 119)),
    # Formal charges -5..+5.
    "possible_formal_charge_list": list(range(-5, 6)),
    "possible_chirality_list": [
        Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
        Chem.rdchem.ChiralType.CHI_OTHER,
    ],
    "possible_hybridization_list": [
        Chem.rdchem.HybridizationType.S,
        Chem.rdchem.HybridizationType.SP,
        Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3,
        Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2,
        Chem.rdchem.HybridizationType.UNSPECIFIED,
    ],
    # Hydrogen counts 0..8.
    "possible_numH_list": list(range(9)),
    # Implicit valences 0..6.
    "possible_implicit_valence_list": list(range(7)),
    # Atom degrees 0..10.
    "possible_degree_list": list(range(11)),
    "possible_bonds": [
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC,
    ],
    # Index 0 means aromatic, index 1 means non-aromatic.
    "possible_aromatic_list": [True, False],
    # Only for double bond stereo information.
    "possible_bond_dirs": [
        Chem.rdchem.BondDir.NONE,
        Chem.rdchem.BondDir.ENDUPRIGHT,
        Chem.rdchem.BondDir.ENDDOWNRIGHT,
    ],
}
# def mol_to_graph_data_obj_simple(mol):
# """
# Converts rdkit mol object to graph Data object required by the pytorch
# geometric package. NB: Uses simplified atom and bond features, and represent
# as indices
# :param mol: rdkit mol object
# :return: graph data object with the attributes: x, edge_index, edge_attr
# """
# # atoms
# num_atom_features = 2 # atom type, chirality tag
# atom_features_list = []
# for atom in mol.GetAtoms():
# atom_feature = [allowable_features['possible_atomic_num_list'].index(
# atom.GetAtomicNum())] + [allowable_features[
# 'possible_chirality_list'].index(atom.GetChiralTag())]
# atom_features_list.append(atom_feature)
# x = torch.tensor(np.array(atom_features_list), dtype=torch.long)
#
# # bonds
# num_bond_features = 2 # bond type, bond direction
# if len(mol.GetBonds()) > 0: # mol has bonds
# edges_list = []
# edge_features_list = []
# for bond in mol.GetBonds():
# i = bond.GetBeginAtomIdx()
# j = bond.GetEndAtomIdx()
# edge_feature = [allowable_features['possible_bonds'].index(
# bond.GetBondType())] + [allowable_features[
# 'possible_bond_dirs'].index(
# bond.GetBondDir())]
# edges_list.append((i, j))
# edge_features_list.append(edge_feature)
# edges_list.append((j, i))
# edge_features_list.append(edge_feature)
#
# # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
# edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long)
#
# # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
# edge_attr = torch.tensor(np.array(edge_features_list),
# dtype=torch.long)
# else: # mol has no bonds
# edge_index = torch.empty((2, 0), dtype=torch.long)
# edge_attr = torch.empty((0, num_bond_features), dtype=torch.long)
#
# data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
#
# return data
def mol_to_graph_data_obj_simple(mol):  # from Yang
    """
    Convert an rdkit mol object into the graph Data object required by the
    pytorch geometric package.

    NB: uses simplified atom and bond features. Each feature is stored as the
    index of its value inside the matching ``allowable_features`` list, not as
    the raw value.

    :param mol: rdkit mol object
    :return: graph data object with the attributes:
        x          -- long tensor [num_atoms, 6]; columns are atomic number,
                      degree, formal charge, hybridization, aromaticity and
                      chirality tag (all as indices)
        edge_index -- long tensor [2, 2 * num_bonds] in COO format; every bond
                      is stored once per direction
        edge_attr  -- long tensor [2 * num_bonds, 2]; columns are bond type
                      and bond direction (as indices)
    :raises ValueError: if an atom or bond property is not present in the
        corresponding ``allowable_features`` list (raised by ``list.index``)
    """
    num_atom_features = 6  # atomic num, degree, charge, hybridization, aromatic, chirality
    num_bond_features = 2  # bond type, bond direction

    # Hoist the lookup tables so the loops below do not repeat dict lookups.
    atomic_nums = allowable_features["possible_atomic_num_list"]
    degrees = allowable_features["possible_degree_list"]
    charges = allowable_features["possible_formal_charge_list"]
    hybridizations = allowable_features["possible_hybridization_list"]
    aromatic_flags = allowable_features["possible_aromatic_list"]
    chiralities = allowable_features["possible_chirality_list"]
    bond_types = allowable_features["possible_bonds"]
    bond_dirs = allowable_features["possible_bond_dirs"]

    # atoms
    atom_features_list = [
        [
            atomic_nums.index(atom.GetAtomicNum()),
            degrees.index(atom.GetDegree()),
            charges.index(atom.GetFormalCharge()),
            hybridizations.index(atom.GetHybridization()),
            aromatic_flags.index(atom.GetIsAromatic()),
            chiralities.index(atom.GetChiralTag()),
        ]
        for atom in mol.GetAtoms()
    ]
    if atom_features_list:
        x = torch.tensor(np.array(atom_features_list), dtype=torch.long)
    else:
        # np.array([]) is 1-D, which would yield x of shape (0,); build the
        # empty 2-D matrix explicitly, mirroring the no-bond case below.
        x = torch.empty((0, num_atom_features), dtype=torch.long)

    # bonds
    edges_list = []
    edge_features_list = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edge_feature = [
            bond_types.index(bond.GetBondType()),
            bond_dirs.index(bond.GetBondDir()),
        ]
        # Store the bond in both directions with the same features.
        edges_list.append((i, j))
        edge_features_list.append(edge_feature)
        edges_list.append((j, i))
        edge_features_list.append(edge_feature)

    if edges_list:  # mol has bonds
        # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
        edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long)
        # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = torch.tensor(np.array(edge_features_list), dtype=torch.long)
    else:  # mol has no bonds
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, num_bond_features), dtype=torch.long)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)