# train_ns.py (forked from dmlc/dgl)
"""
Training and testing for node selection tasks in bAbI
"""
import argparse
from data_utils import get_babi_dataloaders
from ggnn_ns import NodeSelectionGGNN
from torch.optim import Adam
import torch
import numpy as np
import time


def main(args):
    # Output feature size and number of edge types for each supported bAbI task.
    out_feats = {4: 4, 15: 5, 16: 6}
    n_etypes = {4: 4, 15: 2, 16: 2}

    train_dataloader, dev_dataloader, test_dataloaders = \
        get_babi_dataloaders(batch_size=args.batch_size,
                             train_size=args.train_num,
                             task_id=args.task_id,
                             q_type=args.question_id)

    model = NodeSelectionGGNN(annotation_size=1,
                              out_feats=out_feats[args.task_id],
                              n_steps=5,
                              n_etypes=n_etypes[args.task_id])
    opt = Adam(model.parameters(), lr=args.lr)

    print(f'Task {args.task_id}, question_id {args.question_id}')
    print(f'Training set size: {len(train_dataloader.dataset)}')
    print(f'Dev set size: {len(dev_dataloader.dataset)}')

    # Training and dev stage
    for epoch in range(args.epochs):
        model.train()
        for i, batch in enumerate(train_dataloader):
            g, labels = batch
            loss, _ = model(g, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()
            print(f'Epoch {epoch}, batch {i} loss: {loss.item()}')

        # Evaluate on the dev set after each epoch.
        dev_preds = []
        dev_labels = []
        model.eval()
        for g, labels in dev_dataloader:
            with torch.no_grad():
                preds = model(g)
            preds = torch.tensor(preds, dtype=torch.long).numpy().tolist()
            labels = labels.numpy().tolist()
            dev_preds += preds
            dev_labels += labels
        acc = np.equal(dev_labels, dev_preds).astype(float).tolist()
        acc = sum(acc) / len(acc)
        print(f'Epoch {epoch}, Dev acc {acc}')

    # Test stage
    for i, dataloader in enumerate(test_dataloaders):
        print(f'Test set {i} size: {len(dataloader.dataset)}')

    test_acc_list = []
    for dataloader in test_dataloaders:
        test_preds = []
        test_labels = []
        model.eval()
        for g, labels in dataloader:
            with torch.no_grad():
                preds = model(g)
            preds = torch.tensor(preds, dtype=torch.long).numpy().tolist()
            labels = labels.numpy().tolist()
            test_preds += preds
            test_labels += labels
        acc = np.equal(test_labels, test_preds).astype(float).tolist()
        acc = sum(acc) / len(acc)
        test_acc_list.append(acc)

    test_acc_mean = np.mean(test_acc_list)
    test_acc_std = np.std(test_acc_list)
    print(f'Mean accuracy over {len(test_acc_list)} test datasets: '
          f'{test_acc_mean}, std: {test_acc_std}')
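

# A hypothetical example invocation (not part of the original script); it assumes
# data_utils.py and ggnn_ns.py from this example directory are importable, and
# simply spells out the default arguments defined below:
#
#     python train_ns.py --task_id 16 --question_id 1 --train_num 50 \
#         --batch_size 10 --lr 1e-3 --epochs 100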
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Gated Graph Neural Networks for node selection tasks in bAbI')
    parser.add_argument('--task_id', type=int, default=16,
                        help='bAbI task id; this example supports tasks 4, 15 and 16')
    parser.add_argument('--question_id', type=int, default=1,
                        help='question id within the task')
    parser.add_argument('--train_num', type=int, default=50,
                        help='number of training examples')
    parser.add_argument('--batch_size', type=int, default=10,
                        help='batch size')
    parser.add_argument('--lr', type=float, default=1e-3,
                        help='learning rate')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of training epochs')
    args = parser.parse_args()

    main(args)