finish code for masked pretraining
zhihengli-UR committed Aug 29, 2020
1 parent 32d31bf commit 7a4c81a
Showing 9 changed files with 203 additions and 55 deletions.
25 changes: 13 additions & 12 deletions dataset/ham.py
@@ -17,18 +17,20 @@
from networkx.algorithms.cycles import cycle_basis
from torch_geometric.data import Data
from tqdm import tqdm

from utils.automorphism_group import node_equal, edge_equal

MASK_ATOM_INDEX = 0

ATOMS = OrderedDict([('B', 10.81), ('C', 12.011), ('N', 14.007), ('O', 15.999), ('F', 18.998403163), ('Si', 28.085), ('P', 30.973761998), ('S', 32.06), ('Cl', 35.45), ('K', 39.0983), ('Fe', 55.845), ('Se', 78.971), ('Br', 79.904), ('Ru', 101.07), ('Sn', 118.71), ('I', 126.90447)])

BOND_TYPE_DICT = {1.0: 0, 1.5: 1, 2.0: 2, 3.0: 3, '-': 0, '/': 0, '\\': 0, ':': 1, '=': 2, '#': 3}


class HAM(Dataset):
def __init__(self, data_root, dataset_type='train', for_vis=False, cycle_feat=False, degree_feat=False, cross_validation=False, automorphism=True):
def __init__(self, data_root, dataset_type='train', for_vis=False, cycle_feat=False, degree_feat=False, cross_validation=False, automorphism=True, transform=None):
assert dataset_type in {'train', 'test'}
self.dataset_type = dataset_type
self.transform = transform
if not cross_validation:
jsons_root = os.path.join(data_root, dataset_type, '*.json')
else:
@@ -87,7 +89,6 @@ def __init__(self, data_root, dataset_type='train', for_vis=False, cycle_feat=Fa
mapping_lst.append(new_mapping)
self.smiles_cluster_idx_dict[smile] = mapping_lst


def __getitem__(self, index):
"""
get index-th data
@@ -122,7 +123,6 @@ def __getitem__(self, index):
}
"""
data = Data()
data.cg_fg_ratio = len(json_data['cgnodes']) / len(json_data['nodes'])

if 'smiles' not in json_data:
smiles = os.path.splitext(os.path.basename(json_fpath))[0]
@@ -144,17 +144,14 @@ def __getitem__(self, index):
fg_beads: list = json_data['nodes']
fg_beads.sort(key=lambda x: x['id'])
atom_types = torch.LongTensor([list(ATOMS.keys()).index(bead['element']) for bead in fg_beads]).reshape(-1, 1)
atom_types_tensor = torch.zeros((len(atom_types), len(ATOMS)))
atom_types_tensor.scatter_(1, atom_types, 1)

input_tensor = atom_types_tensor
data.x = atom_types

# ======== degree ===========
if self.degree_feat:
degrees = graph.degree
degrees = np.array(degrees)[:, 1]
degrees = torch.tensor(degrees).float().unsqueeze(dim=-1) / 4
input_tensor = torch.cat([input_tensor, degrees], dim=1)
data.degree_or_cycle_feat = degrees

# ========= cycles ==========
if self.cycle_feat:
@@ -164,9 +161,10 @@ def __getitem__(self, index):
for idx_cycle, cycle in enumerate(cycle_lst):
cycle = torch.tensor(cycle)
cycle_indicator_per_node[cycle] = 1
input_tensor = torch.cat([input_tensor, cycle_indicator_per_node], dim=1)

data.x = input_tensor
if hasattr(data, 'degree_or_cycle_feat'):
data.degree_or_cycle_feat = torch.cat([data.degree_or_cycle_feat, cycle_indicator_per_node], dim=1)
else:
data.degree_or_cycle_feat = cycle_indicator_per_node

edges = []
bond_types = []
@@ -256,6 +254,9 @@ def find_positive_vertex(fg_id, cur_cg_id):
if self.dataset_type == 'peptides_martini_prediction' or self.dataset_type == 'peptides' or self.dataset_type == 'ref_mappings':
data.fname = os.path.splitext(os.path.basename(json_fpath))[0]

if self.transform is not None:
data = self.transform(data)

return data

def __len__(self):
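Taken together, the hunks above change what __getitem__ hands to the model: data.x now carries raw atom-type indices (a LongTensor of shape [num_atoms, 1]) instead of a one-hot matrix, and the optional degree and cycle indicators are collected into a separate data.degree_or_cycle_feat attribute. A condensed sketch of that feature construction, with build_node_features as a hypothetical standalone helper rather than a function in the repo:

import networkx as nx
import numpy as np
import torch
from networkx.algorithms.cycles import cycle_basis
from torch_geometric.data import Data


def build_node_features(graph: nx.Graph, fg_beads, atoms, degree_feat=False, cycle_feat=False):
    # Hypothetical standalone version of the feature handling in __getitem__.
    data = Data()
    atom_types = torch.LongTensor(
        [list(atoms.keys()).index(bead['element']) for bead in fg_beads]
    ).reshape(-1, 1)
    data.x = atom_types  # raw type indices; one-hot encoding now happens in the model

    extras = []
    if degree_feat:
        degrees = np.array(graph.degree)[:, 1]
        extras.append(torch.tensor(degrees).float().unsqueeze(dim=-1) / 4)
    if cycle_feat:
        cycle_indicator_per_node = torch.zeros(atom_types.shape[0], 1)
        for cycle in cycle_basis(graph):
            cycle_indicator_per_node[torch.tensor(cycle)] = 1
        extras.append(cycle_indicator_per_node)
    if extras:
        data.degree_or_cycle_feat = torch.cat(extras, dim=1)
    return data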
15 changes: 6 additions & 9 deletions dataset/ham_per_file.py
@@ -59,7 +59,6 @@ def __getitem__(self, index):
}
"""
data = Data()
data.cg_fg_ratio = len(json_data['cgnodes']) / len(json_data['nodes'])

if 'smiles' not in json_data:
smiles = os.path.splitext(os.path.basename(json_fpath))[0]
@@ -82,17 +81,14 @@ def __getitem__(self, index):
fg_beads.sort(key=lambda x: x['id'])
# atom_types = torch.LongTensor([ATOM_TYPES.index(bead['element']) for bead in fg_beads]).reshape(-1, 1)
atom_types = torch.LongTensor([list(ATOMS.keys()).index(bead['element']) for bead in fg_beads]).reshape(-1, 1)
atom_types_tensor = torch.zeros((len(atom_types), len(ATOMS)))
atom_types_tensor.scatter_(1, atom_types, 1)

input_tensor = atom_types_tensor
data.x = atom_types

# ======== degree ===========
if self.degree_feat:
degrees = graph.degree
degrees = np.array(degrees)[:, 1]
degrees = torch.tensor(degrees).float().unsqueeze(dim=-1) / 4
input_tensor = torch.cat([input_tensor, degrees], dim=1)
data.degree_or_cycle_feat = degrees

# ========= cycles ==========
if self.cycle_feat:
@@ -102,9 +98,10 @@ def __getitem__(self, index):
for idx_cycle, cycle in enumerate(cycle_lst):
cycle = torch.tensor(cycle)
cycle_indicator_per_node[cycle] = 1
input_tensor = torch.cat([input_tensor, cycle_indicator_per_node], dim=1)

data.x = input_tensor
if hasattr(data, 'degree_or_cycle_feat'):
data.degree_or_cycle_feat = torch.cat([data.degree_or_cycle_feat, cycle_indicator_per_node], dim=1)
else:
data.degree_or_cycle_feat = cycle_indicator_per_node

edges = []
bond_types = []
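The same refactoring is mirrored in ham_per_file.py, and both dataset classes now accept a transform, which is how the masked pretraining plugs in. A minimal usage sketch — the data root is a placeholder and MaskAtomType is called positionally, as in the trainer further down:

from torch_geometric.data import DataLoader

from dataset.ham import HAM
from utils.transforms import MaskAtomType

train_set = HAM(data_root='data/HAM', dataset_type='train',
                cycle_feat=True, degree_feat=True,
                cross_validation=True, transform=MaskAtomType(0.15))
loader = DataLoader(train_set, batch_size=50, shuffle=True)

batch = next(iter(loader))
# the transform adds the masking targets consumed by self-sup_pre-train.py
print(batch.x.shape, batch.masked_atom_index.shape, batch.masked_atom_type.shape)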
2 changes: 1 addition & 1 deletion inference.py
@@ -38,7 +38,7 @@ def eval(test_dataloader, model, args):
data.batch = torch.zeros(num_nodes).long()
data = data.to(torch.device(0))
edge_index_cpu = data.edge_index.cpu().numpy()
fg_embed, pred_cg_fg_ratio = model(data)
fg_embed = model(data)
dense_adj = torch.sparse.LongTensor(data.edge_index, data.no_bond_edge_attr, (num_nodes, num_nodes)).to_dense()

if args.num_cg_beads is None:
53 changes: 22 additions & 31 deletions model/networks.py
@@ -8,9 +8,8 @@
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as gnn
from torch_scatter import scatter_mean

from dataset.ham import ATOMS
from dataset.ham import MASK_ATOM_INDEX
NUM_ATOMS = len(ATOMS)


@@ -25,18 +24,12 @@ def __init__(self, input_dim, hidden_dim, embedding_dim, args):
self.input_fc, self.nn_conv, self.gru, self.output_fc = self.build_nnconv_layers(input_dim, hidden_dim,
embedding_dim,
layer=gnn.NNConv)
if self.args.use_degree_feat:
input_dim += 1
if self.args.use_cycle_feat:
input_dim += 1

self.fc_pred_cg_beads_ratio = nn.Sequential(
nn.Linear(embedding_dim + input_dim, 1),
nn.Sigmoid()
)

def build_nnconv_layers(self, input_dim, hidden_dim, embedding_dim, layer=gnn.NNConv):
input_fc = nn.Linear(input_dim, hidden_dim, bias=False)
if self.args.use_mask_embed:
input_fc = nn.Embedding(input_dim + 1, hidden_dim, padding_idx=MASK_ATOM_INDEX)
else:
input_fc = nn.Embedding(input_dim, hidden_dim)
if self.args.use_degree_feat:
hidden_dim += 1
if self.args.use_cycle_feat:
@@ -55,13 +48,14 @@ def build_nnconv_layers(self, input_dim, hidden_dim, embedding_dim, layer=gnn.NN
)
return input_fc, nn_conv, gru, output_fc

def nn_conv_forward(self, x, edge_index, edge_attr, batch):
def forward(self, data):
x, edge_index = data.x, data.edge_index
edge_attr = data.edge_attr
if self.args.use_cycle_feat or self.args.use_degree_feat:
x, degree_or_cycle_feat = x[:, :NUM_ATOMS], x[:, NUM_ATOMS:]
out = F.relu(self.input_fc(x))
out = torch.cat([out, degree_or_cycle_feat], dim=1)
out = F.relu(self.input_fc(x)).squeeze(1)
out = torch.cat([out, data.degree_or_cycle_feat], dim=1)
else:
out = F.relu(self.input_fc(x))
out = F.relu(self.input_fc(x)).squeeze(1)
h = out.unsqueeze(0)

for i in range(self.args.num_nn_iter):
@@ -70,20 +64,17 @@ def nn_conv_forward(self, x, edge_index, edge_attr, batch):
out = out.squeeze(0)

out = self.output_fc(out)
feat_lst = [out, x]
if self.args.use_cycle_feat or self.args.use_degree_feat:
feat_lst.append(degree_or_cycle_feat)
out = torch.cat(feat_lst, dim=1)
out = F.normalize(out)

readout = scatter_mean(out, batch, dim=0)
cg_fg_ratio = self.fc_pred_cg_beads_ratio(readout)
return out, cg_fg_ratio

def forward(self, data):
atom_types, edge_index = data.x, data.edge_index
if self.args.use_mask_embed:
atom_types_tensor = torch.zeros((x.shape[0], len(ATOMS) + 1), device=x.device)
else:
atom_types_tensor = torch.zeros((x.shape[0], len(ATOMS)), device=x.device)
atom_types_tensor.scatter_(1, x, 1)

edge_attr = data.edge_attr
fg_embed, cg_fg_ratio = self.nn_conv_forward(atom_types, edge_index, edge_attr, data.batch)
feat_lst = [out, atom_types_tensor]
if self.args.use_cycle_feat or self.args.use_degree_feat:
feat_lst.append(data.degree_or_cycle_feat)
out = torch.cat(feat_lst, dim=1)
fg_embed = F.normalize(out)

return fg_embed, cg_fg_ratio
return fg_embed
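The substantive change in networks.py is the input layer: an nn.Embedding over atom-type indices replaces the nn.Linear over one-hot vectors, and when use_mask_embed is set an extra row is reserved for the mask token (padding_idx=MASK_ATOM_INDEX keeps that row at zero and excludes it from gradient updates). One-hot atom types are instead rebuilt inside forward and concatenated, together with any degree/cycle features, onto the normalized output embedding. A minimal sketch of just the embedding layer, with hypothetical class and variable names and a made-up hidden size:

import torch
import torch.nn as nn
import torch.nn.functional as F

NUM_ATOM_TYPES = 16   # len(ATOMS) in dataset/ham.py
MASK_ATOM_INDEX = 0   # from dataset/ham.py


class InputEmbedding(nn.Module):
    def __init__(self, num_types, hidden_dim, use_mask_embed):
        super().__init__()
        if use_mask_embed:
            # one extra row for the mask token; padding_idx pins it to zeros
            self.embed = nn.Embedding(num_types + 1, hidden_dim, padding_idx=MASK_ATOM_INDEX)
        else:
            self.embed = nn.Embedding(num_types, hidden_dim)

    def forward(self, atom_type_indices):
        # atom_type_indices: (num_nodes, 1) LongTensor, as stored in data.x
        return F.relu(self.embed(atom_type_indices)).squeeze(1)


x = torch.randint(0, NUM_ATOM_TYPES + 1, (5, 1))
layer = InputEmbedding(NUM_ATOM_TYPES, hidden_dim=64, use_mask_embed=True)
print(layer(x).shape)  # torch.Size([5, 64])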
1 change: 1 addition & 0 deletions option.py
@@ -65,6 +65,7 @@ def arg_parse():
parser.add_argument('--inference_method', choices=['dsgpm', 'spec_cluster', 'metis', 'graclus'], default='dsgpm')
parser.add_argument('--svg', action='store_true')
parser.add_argument('--vis_root', type=str)
parser.add_argument('--mask_ratio', type=float, default=0.15)

parser.set_defaults(cuda='0',
lr=1e-3,
14 changes: 14 additions & 0 deletions scripts/pretrain.sh
@@ -0,0 +1,14 @@
#
# Copyright (c) 2020
# Licensed under The MIT License
# Written by Zhiheng Li
# Email: [email protected]
#

CUDA_VISIBLE_DEVICES=0 python self-sup_pre-train.py \
--title ham_pretrain_mask \
--data_root /scratch/zli82/dataset/HAM_dataset/data \
--batch_size 50 \
--num_workers 28 \
--ckpt /scratch/zli82/cg_exp/ckpt/ham_pretrain_mask \
--tb_root /scratch/zli82/cg_exp/experiment/tensorboard
115 changes: 115 additions & 0 deletions self-sup_pre-train.py
@@ -0,0 +1,115 @@
# Copyright (c) 2020
# Licensed under The MIT License
# Written by Zhiheng Li
# Email: [email protected]

import os
import torch
import torch.optim as optim
import tqdm
import itertools

from option import arg_parse
from dataset.ham import HAM, ATOMS
from torch_geometric.data import DataLoader
from model.networks import DSGPM
from model.losses import TripletLoss, PosPairMSE
from utils.util import get_run_name
from torch.utils.tensorboard import SummaryWriter

from utils.stat import AverageMeter
from utils.transforms import MaskAtomType

from warnings import simplefilter
from sklearn.exceptions import UndefinedMetricWarning
simplefilter(action='ignore', category=FutureWarning)
simplefilter(action='ignore', category=UndefinedMetricWarning)


class Trainer:
def __init__(self, args):
self.args = args
train_set = HAM(data_root=args.data_root, dataset_type='train', cycle_feat=args.use_cycle_feat,
degree_feat=args.use_degree_feat, cross_validation=True, automorphism=not args.debug,
transform=MaskAtomType(args.mask_ratio))

self.train_loader = DataLoader(train_set, batch_size=args.batch_size,
num_workers=args.num_workers, pin_memory=True)

self.model = DSGPM(args.input_dim, args.hidden_dim,
args.output_dim, args=args).cuda()
final_feat_dim = args.output_dim + len(ATOMS) + 1 # TODO confirm number of atom types
if self.args.use_cycle_feat:
final_feat_dim += 1
if self.args.use_degree_feat:
final_feat_dim += 1
self.atom_type_classifier = torch.nn.Linear(final_feat_dim, len(ATOMS)).cuda() # TODO confirm number of atom types
self.criterion = torch.nn.CrossEntropyLoss()

# setup optimizer
self.optimizer = optim.Adam(itertools.chain(self.model.parameters(),
self.atom_type_classifier.parameters()),
lr=args.lr, weight_decay=args.weight_decay)

if not args.debug:
run_name = get_run_name(args.title)

self.ckpt_dir = os.path.join(args.ckpt, run_name)
if not os.path.exists(self.ckpt_dir):
os.makedirs(self.ckpt_dir)

if args.tb_log:
tensorboard_dir = os.path.join(args.tb_root, run_name)
if not os.path.exists(tensorboard_dir):
os.mkdir(tensorboard_dir)

self.writer = SummaryWriter(tensorboard_dir)

def train(self, epoch):
self.model.train()
loss_meter = AverageMeter()
accuracy_meter = AverageMeter()

train_loader = iter(self.train_loader)

tbar = tqdm.tqdm(enumerate(train_loader), total=len(self.train_loader), dynamic_ncols=True)

for i, data in tbar:
data = data.to(torch.device(0))
self.optimizer.zero_grad()

fg_embed = self.model(data)
pred = self.atom_type_classifier(fg_embed[data.masked_atom_index])
loss = self.criterion(pred, data.masked_atom_type)
loss.backward()
self.optimizer.step()

accuracy = float(torch.sum(torch.max(pred.detach(), dim=1)[1] == data.masked_atom_type).cpu().item()) / len(pred)
loss_meter.update(loss.item())
accuracy_meter.update(accuracy)

tbar.set_description('[%d/%d] loss: %.4f, accuracy: %.4f'
% (epoch, self.args.epoch, loss_meter.avg, accuracy_meter.avg))

if not self.args.debug and self.args.tb_log:
self.writer.add_scalar('loss', loss_meter.avg, epoch)
self.writer.add_scalar('accuracy', accuracy_meter.avg, epoch)

if not self.args.debug:
state_dict = self.model.module.state_dict() if not isinstance(self.model, DSGPM) else self.model.state_dict()
torch.save(state_dict, os.path.join(self.ckpt_dir, '{}.pth'.format(epoch)))


def main():
args = arg_parse()
args.use_mask_embed = True
assert args.ckpt is not None, '--ckpt is required'
args.devices = [int(device_id) for device_id in args.devices.split(',')]

trainer = Trainer(args)
for e in range(1, args.epoch + 1):
trainer.train(e)


if __name__ == '__main__':
main()
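The MaskAtomType transform itself (utils/transforms.py) is not expanded in this view. Judging from how it is consumed — data.masked_atom_index selects the rows of fg_embed fed to the classifier, data.masked_atom_type holds the original types used as targets, and the embedding reserves index MASK_ATOM_INDEX — a plausible sketch looks like this; whether real atom indices are shifted up by one so they stay distinct from the mask index is an assumption, not something this diff shows:

import torch

MASK_ATOM_INDEX = 0  # from dataset/ham.py


class MaskAtomType:
    # Hypothetical reconstruction of the transform used in the Trainer above.
    def __init__(self, mask_ratio=0.15):
        self.mask_ratio = mask_ratio

    def __call__(self, data):
        num_atoms = data.x.shape[0]
        num_masked = max(1, int(round(num_atoms * self.mask_ratio)))
        perm = torch.randperm(num_atoms)[:num_masked]

        data.masked_atom_index = perm                    # which nodes were masked
        data.masked_atom_type = data.x[perm, 0].clone()  # original types (classifier targets)

        # assumed detail: shift real types by one so index 0 is free for the
        # mask token, then overwrite the chosen atoms with the mask index
        data.x = data.x + 1
        data.x[perm] = MASK_ATOM_INDEX
        return data

Under that reading, the classifier on top of the normalized node embeddings predicts one of len(ATOMS) real types for each masked node, which matches the final_feat_dim bookkeeping in the Trainer (output_dim + len(ATOMS) + 1 plus the optional degree and cycle columns).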
4 changes: 2 additions & 2 deletions train.py
@@ -42,7 +42,7 @@ def train(fold, epoch, train_loader, model, pos_pair_mse_criterion, triplet_crit
data = data.to(torch.device(0))
model.zero_grad()

fg_embed, cg_fg_ratio = model(data)
fg_embed = model(data)

loss = 0
pos_pair_loss = args.pos_pair_weight * pos_pair_mse_criterion(fg_embed, data.pos_pair_index)
@@ -84,7 +84,7 @@ def eval(fold, epoch, test_dataloader, model, args):

max_num_cg_beads = gt_hard_assigns.max(axis=1) + 1

fg_embed, cg_fg_ratio = model(data)
fg_embed = model(data)
dense_adj = torch.sparse.LongTensor(data.edge_index, data.no_bond_edge_attr, (num_nodes, num_nodes)).to_dense()

for _ in range(args.test_shots):
