-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathGEOM_dataset_preparation.py
378 lines (311 loc) · 15.8 KB
/
GEOM_dataset_preparation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
import argparse
import json
import os
import pickle
import random
from itertools import repeat
from os.path import join
import numpy as np
import pandas as pd
import torch
from rdkit import Chem
from torch_geometric.data import Data, InMemoryDataset
from tqdm import tqdm
from datasets import allowable_features
from extended_feature import get_bond_angle_index, get_bond_angle, dihedral_bond_angle_index
from extended_feature import edge_geo_distance
from torch_geometric.utils import to_dense_adj, dense_to_sparse, sort_edge_index
from rdkit.Geometry.rdGeometry import ComputeDihedralAngle
def mol_to_graph_data_obj_simple_3D(mol):
    """Convert an RDKit molecule into a torch_geometric ``Data`` object with 3D coordinates.

    Atom and bond features are simplified and stored as integer indices into
    the lookup lists of ``allowable_features``.

    :param mol: RDKit ``Mol`` with at least one conformer; only the first
        conformer's coordinates are used.
    :return: ``Data`` with
        ``x`` -- long tensor, one row ``[atomic-number index, chirality index]`` per atom;
        ``edge_index`` -- long tensor ``[2, 2 * num_bonds]``, every bond listed in both directions;
        ``edge_attr`` -- long tensor, one row ``[bond-type index, bond-direction index]`` per directed edge;
        ``positions`` -- float tensor of the conformer's atom positions.
    """
    # todo: more atom/bond features in the future
    atomic_num_table = allowable_features['possible_atomic_num_list']
    chirality_table = allowable_features['possible_chirality_list']
    atom_rows = [
        [atomic_num_table.index(atom.GetAtomicNum()),
         chirality_table.index(atom.GetChiralTag())]
        for atom in mol.GetAtoms()
    ]
    x = torch.tensor(np.array(atom_rows), dtype=torch.long)

    bonds = mol.GetBonds()
    if len(bonds) == 0:
        # Molecule without bonds: empty connectivity with two bond-feature columns.
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, 2), dtype=torch.long)
    else:
        bond_type_table = allowable_features['possible_bonds']
        bond_dir_table = allowable_features['possible_bond_dirs']
        pairs, rows = [], []
        for bond in bonds:
            start, stop = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            row = [bond_type_table.index(bond.GetBondType()),
                   bond_dir_table.index(bond.GetBondDir())]
            # Store the bond once per direction so the graph is undirected.
            pairs += [(start, stop), (stop, start)]
            rows += [row, row]
        # COO connectivity, shape [2, num_edges]
        edge_index = torch.tensor(np.array(pairs).T, dtype=torch.long)
        # edge feature matrix, shape [num_edges, num_edge_features]
        edge_attr = torch.tensor(np.array(rows), dtype=torch.long)

    # Every CREST conformer gets its own mol object, and every such mol object
    # has exactly one RDKit conformer.
    # ref: https://github.com/learningmatter-mit/geom/blob/master/tutorials/
    coords = torch.Tensor(mol.GetConformers()[0].GetPositions())
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, positions=coords)
def summarise(folder=None):
    """Summarise per-molecule statistics of the GEOM drugs subset.

    Reads ``<folder>/rdkit_folder/summary_drugs.json``. For every entry that
    has a non-empty ``pickle_path``, it loads the pickled conformer dictionary
    and records the first conformer's graph size plus every conformer's
    Boltzmann weight. Entries without a ``pickle_path`` are skipped.

    :param folder: directory that contains ``rdkit_folder``. When omitted, the
        module-level ``data_folder`` global is used. That global is set only
        by the ``__main__`` section, so callers that import this module should
        pass ``folder`` explicitly.
    :return: list of dicts with keys ``geom_id``, ``num_edge``, ``num_node``,
        ``num_conf`` and ``boltzmann_weight`` (list of one float per conformer).
    """
    if folder is None:
        folder = data_folder
    dir_name = '{}/rdkit_folder'.format(folder)
    drugs_file = '{}/summary_drugs.json'.format(dir_name)
    with open(drugs_file, 'r') as f:
        drugs_summary = json.load(f)
    # expected: 304,466 molecules
    print('number of items (SMILES): {}'.format(len(drugs_summary.items())))

    sum_list = []
    for _smiles, sub_dic in tqdm(list(drugs_summary.items())):
        ##### Path should match #####
        if sub_dic.get('pickle_path', '') == '':
            continue
        mol_path = join(dir_name, sub_dic['pickle_path'])
        # NOTE: pickle.load can execute arbitrary code; only run this on the
        # trusted GEOM release.
        with open(mol_path, 'rb') as f:
            mol_dic = pickle.load(f)
        conformer_list = mol_dic['conformers']
        first_conformer = conformer_list[0]
        data = mol_to_graph_data_obj_simple_3D(first_conformer['rd_mol'])
        sum_list.append({
            'geom_id': first_conformer['geom_id'],
            'num_edge': len(data.edge_attr),
            'num_node': len(data.positions),
            'num_conf': len(conformer_list),
            # conf['boltzmannweight'] is a float for the conformer (a few rotamers);
            # conf['conformerweights'] would be the fine weights of each rotamer.
            'boltzmann_weight': [conf['boltzmannweight'] for conf in conformer_list],
        })
    return sum_list
class Molecule3DDataset(InMemoryDataset):
    """GEOM-drugs dataset of molecular graphs with 3D conformer coordinates.

    The mode is selected by the ``smiles_copy_from_3D_file`` keyword.

    * Absent (3D mode): the GEOM-drugs summary is shuffled with ``seed``.
      Molecules whose SMILES appear in the downstream benchmark SMILES files
      are dropped. Among the rest, a molecule is kept only if it has between
      ``n_conf`` and ``n_upper`` conformers and its first ``n_conf``
      conformers canonicalise to a single SMILES. The first ``n_conf``
      conformers of the first ``n_mol`` kept molecules are stored.
    * Present (3D-derived 2D mode): SMILES are read one per line from that
      file (e.g. the ``processed/smiles.csv`` written in 3D mode) and
      de-duplicated. The first conformer of each is stored.

    :param root: dataset root; ``raw`` and ``processed`` sub-folders are created.
    :param n_mol: number of molecules to keep (3D mode).
    :param n_conf: conformers kept per molecule, also the minimum required (3D mode).
    :param n_upper: maximum number of conformers a molecule may have (3D mode).
    :param seed: seed used to shuffle the GEOM summary (3D mode).
    :param empty: if True, the processed file is not loaded after construction.
    :param kwargs: ``smiles_copy_from_3D_file`` (see above);
        ``data_folder``, the directory holding ``rdkit_folder`` (defaults to the
        module-level ``data_folder`` global);
        ``downstream_dir``, the directory holding ``<task>/processed/smiles.csv``
        for the downstream tasks (defaults to the original hard-coded path).
    """

    _DOWNSTREAM_TASKS = ("tox21", "toxcast", "clintox", "bbbp", "sider", "muv", "hiv",
                         "bace", "esol", "lipophilicity")
    # NOTE(review): placeholder path kept for backward compatibility; pass
    # downstream_dir=... to point at the real location.
    _DEFAULT_DOWNSTREAM_DIR = "/data/username/datasets/molecule_datasets"

    def __init__(self, root, n_mol, n_conf, n_upper, transform=None, seed=777,
                 pre_transform=None, pre_filter=None, empty=False, **kwargs):
        os.makedirs(root, exist_ok=True)
        os.makedirs(join(root, 'raw'), exist_ok=True)
        os.makedirs(join(root, 'processed'), exist_ok=True)
        # Set only for the 3D-derived 2D datasets (SMILES copied from a 3D dataset).
        self.smiles_copy_from_3D_file = kwargs.get('smiles_copy_from_3D_file')
        self.data_folder = kwargs.get('data_folder')
        self.downstream_dir = kwargs.get('downstream_dir', self._DEFAULT_DOWNSTREAM_DIR)
        self.root, self.seed = root, seed
        self.n_mol, self.n_conf, self.n_upper = n_mol, n_conf, n_upper
        self.pre_transform, self.pre_filter = pre_transform, pre_filter
        # The parent constructor runs process() when the processed file is
        # missing, so every attribute process() reads must be set above.
        super(Molecule3DDataset, self).__init__(
            root, transform, pre_transform, pre_filter)
        if not empty:
            self.data, self.slices = torch.load(self.processed_paths[0])
        print('root: {},\ndata: {},\nn_mol: {},\nn_conf: {}'.format(
            self.root, self.data, self.n_mol, self.n_conf))

    def get(self, idx):
        """Slice the ``idx``-th graph out of the collated storage."""
        data = Data()
        keys = self.data.keys
        # Newer torch_geometric releases expose ``keys`` as a method.
        if callable(keys):
            keys = keys()
        for key in keys:
            item, slices = self.data[key], self.slices[key]
            s = list(repeat(slice(None), item.dim()))
            s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1])
            # Index with a tuple; indexing with a list of slices is deprecated.
            data[key] = item[tuple(s)]
        return data

    @property
    def raw_file_names(self):
        return os.listdir(self.raw_dir)

    @property
    def processed_file_names(self):
        return 'geometric_data_processed.pt'

    def download(self):
        return

    def _rdkit_dir(self):
        """Return ``<data_folder>/rdkit_folder``, preferring the constructor kwarg."""
        folder = self.data_folder if self.data_folder is not None else data_folder
        return '{}/rdkit_folder'.format(folder)

    def _load_downstream_smiles(self):
        """Return the union of SMILES listed for all downstream tasks."""
        whole_SMILES_set = set()
        for task in self._DOWNSTREAM_TASKS:
            print("====== {} ======".format(task))
            file_path = "{}/{}/processed/smiles.csv".format(self.downstream_dir, task)
            whole_SMILES_set |= set(load_SMILES_list(file_path))
        print("len of downstream SMILES:", len(whole_SMILES_set))
        return whole_SMILES_set

    @staticmethod
    def _geometric_targets(positions, edge_index):
        """Compute the per-atom ``bond_angle_true`` and per-edge ``dihedral_angle_true``.

        For edge ``e = (src, dst)``, ``unit_vector[e]`` is the normalised
        ``positions[src] - positions[dst]``. ``direction_unit[a]`` is the
        vector sum of ``unit_vector`` over the edges whose source is atom ``a``.
        For unit vectors, the squared norm of that sum equals the number of
        vectors plus twice the sum of their pairwise cosines. It is stored as
        the ``[num_atoms, 1]`` bond-angle target.

        For the dihedral target, each endpoint's ``direction_unit`` is stripped
        of its component along the bond axis. The per-edge dot product of the
        two remaining vectors is stored, shape ``[num_edges]``.
        """
        src, dst = edge_index[0], edge_index[1]
        unit = positions[src] - positions[dst]
        unit_vector = unit / unit.norm(dim=-1, keepdim=True)
        unit_neg_vector = -unit_vector
        # Vector sum per source atom (the original summed all components to a scalar).
        direction_unit = positions.new_zeros((positions.shape[0], 3)).index_add_(
            0, src, unit_vector)
        bond_angle_true = (direction_unit.norm(dim=-1) ** 2).unsqueeze(1)

        d_src, d_dst = direction_unit[src], direction_unit[dst]
        # Each rejection projects its own endpoint's vector off the bond axis.
        rej_pos = d_src - (d_src * unit_vector).sum(-1, keepdim=True) * unit_vector
        rej_neg = d_dst - (d_dst * unit_neg_vector).sum(-1, keepdim=True) * unit_neg_vector
        dihedral_angle_true = (rej_pos * rej_neg).sum(-1)
        return bond_angle_true, dihedral_angle_true

    def process(self):
        data_list = []
        data_smiles_list = []
        dir_name = self._rdkit_dir()
        drugs_file = '{}/summary_drugs.json'.format(dir_name)

        if self.smiles_copy_from_3D_file is None:  # 3D datasets
            # Downstream SMILES are only needed here, to exclude them from pre-training.
            whole_SMILES_set = self._load_downstream_smiles()
            with open(drugs_file, 'r') as f:
                drugs_summary = json.load(f)
            drugs_summary = list(drugs_summary.items())
            print('# of SMILES: {}'.format(len(drugs_summary)))
            # expected: 304,466 molecules
            random.seed(self.seed)
            random.shuffle(drugs_summary)
            mol_idx, idx, notfound = 0, 0, 0
            for smiles, sub_dic in tqdm(drugs_summary):
                smiles = smiles.strip()
                if smiles in whole_SMILES_set:
                    continue
                if sub_dic.get('pickle_path', '') == '':
                    notfound += 1
                    continue
                mol_path = join(dir_name, sub_dic['pickle_path'])
                # NOTE: pickle.load can execute arbitrary code; only use on the trusted GEOM release.
                with open(mol_path, 'rb') as f:
                    mol_dic = pickle.load(f)
                conformer_list = mol_dic['conformers']

                ##### count should match #####
                conf_n = len(conformer_list)
                if conf_n < self.n_conf or conf_n > self.n_upper:
                    notfound += 1
                    continue

                ##### SMILES should match #####
                # export prefix=https://github.com/learningmatter-mit/geom
                # Ref: ${prefix}/issues/4#issuecomment-853486681
                # Ref: ${prefix}/blob/master/tutorials/02_loading_rdkit_mols.ipynb
                selected = conformer_list[:self.n_conf]
                conf_list_raw = [Chem.MolToSmiles(c['rd_mol']) for c in selected]
                conf_list = [Chem.MolToSmiles(Chem.MolFromSmiles(s)) for s in conf_list_raw]
                # check that they're all the same
                same_confs = len(set(conf_list)) == 1
                same_confs_raw = len(set(conf_list_raw)) == 1
                if not same_confs:
                    if same_confs_raw:
                        print("Interesting")
                    notfound += 1
                    continue

                # select the first n_conf conformations
                for conformer_dict in selected:
                    data = mol_to_graph_data_obj_simple_3D(conformer_dict['rd_mol'])
                    data.id = torch.tensor([idx])
                    data.energy = conformer_dict['totalenergy']
                    data.bond_angle_true, data.dihedral_angle_true = \
                        self._geometric_targets(data.positions, data.edge_index)
                    data.mol_id = torch.tensor([mol_idx])
                    data_smiles_list.append(smiles)
                    data_list.append(data)
                    idx += 1

                # Count the molecule first, so mol_idx is the number processed.
                mol_idx += 1
                # select the first n_mol molecules
                if mol_idx >= self.n_mol:
                    break
            print('mol id: [0, {}]\tlen of smiles: {}\tlen of set(smiles): {}'.format(
                mol_idx - 1, len(data_smiles_list), len(set(data_smiles_list))))

        else:  # 2D datasets
            with open(self.smiles_copy_from_3D_file, 'r') as f:
                lines = f.readlines()
            # De-duplicate while preserving file order.
            data_smiles_list = list(dict.fromkeys(line.strip() for line in lines))
            # load 3D structure
            with open(drugs_file, 'r') as f:
                drugs_summary = json.load(f)
            # expected: 304,466 molecules
            print('number of items (SMILES): {}'.format(len(drugs_summary)))
            mol_idx, idx, notfound = 0, 0, 0
            for smiles in tqdm(data_smiles_list):
                sub_dic = drugs_summary[smiles]
                mol_path = join(dir_name, sub_dic['pickle_path'])
                with open(mol_path, 'rb') as f:
                    mol_dic = pickle.load(f)
                rdkit_mol = mol_dic['conformers'][0]['rd_mol']
                data = mol_to_graph_data_obj_simple_3D(rdkit_mol)
                data.mol_id = torch.tensor([mol_idx])
                data.id = torch.tensor([idx])
                data_list.append(data)
                mol_idx += 1
                idx += 1

        # NOTE(review): pre_filter drops graphs but not their SMILES, so
        # smiles.csv can fall out of step with the saved data when a filter is set.
        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]
        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data_smiles_series = pd.Series(data_smiles_list)
        saver_path = join(self.processed_dir, 'smiles.csv')
        print('saving to {}'.format(saver_path))
        data_smiles_series.to_csv(saver_path, index=False, header=False)
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
        print("%d molecules do not meet the requirements" % notfound)
        print("%d molecules have been processed" % mol_idx)
        print("%d conformers have been processed" % idx)
        return
def load_SMILES_list(file_path):
    """Read a file of SMILES strings, one per line.

    The file is read in binary mode. Each line has surrounding whitespace
    (including the newline) stripped and is then decoded with the default
    codec. Blank lines therefore become empty strings.

    :param file_path: path to the SMILES file.
    :return: list of SMILES strings in file order.
    """
    with open(file_path, 'rb') as handle:
        raw_lines = handle.readlines()
    return [raw.strip().decode() for raw in tqdm(raw_lines)]
if __name__ == '__main__':
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    def _str2bool(value):
        """Parse a command-line boolean; argparse's ``type=bool`` maps 'False' to True."""
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean, got {!r}'.format(value))

    parser = argparse.ArgumentParser()
    # Accepts both a bare `--sum` and the legacy `--sum True/False` form.
    parser.add_argument('--sum', type=_str2bool, nargs='?', const=True, default=False,
                        help='cal dataset stats')
    parser.add_argument('--n_mol', type=int, help='number of unique smiles/molecules')
    parser.add_argument('--n_conf', type=int, help='number of conformers of each molecule')
    parser.add_argument('--n_upper', type=int, help='upper bound for number of conformers')
    parser.add_argument('--data_folder', type=str, required=True)
    args = parser.parse_args()
    # Module-level global read by summarise() and Molecule3DDataset.process().
    data_folder = args.data_folder

    if args.sum:
        sum_list = summarise()
        with open('{}/summarise.json'.format(data_folder), 'w') as fout:
            json.dump(sum_list, fout)
    else:
        missing = [name for name in ('n_mol', 'n_conf', 'n_upper')
                   if getattr(args, name) is None]
        if missing:
            parser.error('the following arguments are required unless --sum is set: '
                         + ', '.join('--' + name for name in missing))
        n_mol, n_conf, n_upper = args.n_mol, args.n_conf, args.n_upper
        root_2d = '{}/GEOM_2D_nmol{}_nconf{}_nupper{}'.format(data_folder, n_mol, n_conf, n_upper)
        root_3d = '{}/GEOM_3D_nmol{}_nconf{}_nupper{}'.format(data_folder, n_mol, n_conf, n_upper)
        # Generate 3D Datasets (2D SMILES + 3D Conformer)
        Molecule3DDataset(root=root_3d, n_mol=n_mol, n_conf=n_conf, n_upper=n_upper)
        # Generate 2D Datasets (2D SMILES), reusing the SMILES selected for the 3D dataset
        Molecule3DDataset(root=root_2d, n_mol=n_mol, n_conf=n_conf, n_upper=n_upper,
                          smiles_copy_from_3D_file='%s/processed/smiles.csv' % root_3d)

##### to copy the data to SLURM_TMPDIR under the `datasets` folder #####
# wget https://dataverse.harvard.edu/api/access/datafile/4327252
# mv 4327252 rdkit_folder.tar.gz
# cp rdkit_folder.tar.gz $SLURM_TMPDIR
# cd $SLURM_TMPDIR
# tar -xvf rdkit_folder.tar.gz

##### for data pre-processing #####
# python GEOM_dataset_preparation.py --n_mol 100 --n_conf 5 --n_upper 1000 --data_folder $SLURM_TMPDIR
# python GEOM_dataset_preparation.py --n_mol 50000 --n_conf 1 --n_upper 1000 --data_folder /tmp/datasets