forked from dmlc/dgl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
192 lines (167 loc) · 8.15 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
from utils import evaluate_f1_score
from data_loader import load_PPI
import argparse
import numpy as np
import os
class GNNFiLMLayer(nn.Module):
    """One GNN-FiLM message-passing layer for a multi-relational graph.

    For every edge type the layer holds a linear map ``W`` applied to the
    source features and a linear "hypernet" ``film`` whose output is split
    into ``gamma`` and ``beta``.  Messages are ``relu(gamma * W(x) + beta)``,
    summed per edge type and then across edge types, followed by LayerNorm
    and dropout.

    Parameters
    ----------
    in_size : int
        Size of the incoming node features.
    out_size : int
        Size of the node features produced by this layer.
    etypes : iterable of str
        Edge type names; one ``W`` and one ``film`` module is created per name.
    dropout : float, optional
        Dropout probability applied to the normalised output.
    """

    def __init__(self, in_size, out_size, etypes, dropout=0.1):
        super().__init__()
        self.in_size = in_size
        self.out_size = out_size

        # Weights for the different types of edges.
        self.W = nn.ModuleDict({
            name: nn.Linear(in_size, out_size, bias=False) for name in etypes
        })

        # Hypernets that produce the affine parameters (gamma and beta, hence
        # 2 * out_size outputs) for the different types of edges.
        self.film = nn.ModuleDict({
            name: nn.Linear(in_size, 2 * out_size, bias=False) for name in etypes
        })

        # LayerNorm applied to the aggregated node features.
        self.layernorm = nn.LayerNorm(out_size)

        # Dropout applied after the LayerNorm.
        self.dropout = nn.Dropout(dropout)

    def forward(self, g, feat_dict):
        """Run one round of FiLM-modulated message passing.

        Parameters
        ----------
        g : DGL heterograph
            The multi-relational graph (treated as a heterograph).
        feat_dict : dict
            Maps node type name to that type's feature tensor.

        Returns
        -------
        dict
            Maps every node type in ``g.ntypes`` to its updated features.
        """
        # Work in a local scope so the temporary message tensors and the
        # aggregated 'h' field are not left behind on the caller's graph.
        with g.local_scope():
            funcs = {}  # message and reduce functions, keyed by edge type

            # For each type of edge, compute the messages and register how
            # they are sent and reduced.
            for srctype, etype, dsttype in g.canonical_etypes:
                # Apply W_l to the source features.
                messages = self.W[etype](feat_dict[srctype])
                # Use the destination features to compute the affine parameters.
                film_weights = self.film[etype](feat_dict[dsttype])
                gamma = film_weights[:, :self.out_size]
                beta = film_weights[:, self.out_size:]
                # NOTE(review): this elementwise product pairs row i of the
                # source tensor with row i of the destination tensor, so it
                # presumably relies on srctype == dsttype -- confirm.
                messages = gamma * messages + beta
                messages = F.relu_(messages)
                # Store on the source nodes under the edge type's name so that
                # copy_u can pick it up below.
                g.nodes[srctype].data[etype] = messages
                funcs[etype] = (fn.copy_u(etype, 'm'), fn.sum('m', 'h'))

            # Reduce type-wise first ('sum' inside each edge type), then sum
            # across the different edge types.
            g.multi_update_all(funcs, 'sum')

            # Apply LayerNorm and dropout to the aggregated features.
            return {
                ntype: self.dropout(self.layernorm(g.nodes[ntype].data['h']))
                for ntype in g.ntypes
            }
class GNNFiLM(nn.Module):
    """Stack of GNN-FiLM layers followed by a linear prediction head.

    The first layer maps ``in_size`` to ``hidden_size``; the remaining
    ``num_layers - 1`` layers map ``hidden_size`` to ``hidden_size``.  The
    head maps ``hidden_size`` to ``out_size`` and its output is passed through
    a sigmoid.
    """

    def __init__(self, etypes, in_size, hidden_size, out_size, num_layers, dropout=0.1):
        super(GNNFiLM, self).__init__()
        # Input width of each layer: the raw feature size for the first one,
        # the hidden size for every layer after it.
        layer_inputs = [in_size] + [hidden_size] * (num_layers - 1)
        self.film_layers = nn.ModuleList(
            [GNNFiLMLayer(width, hidden_size, etypes, dropout) for width in layer_inputs]
        )
        self.predict = nn.Linear(hidden_size, out_size, bias=True)

    def forward(self, g, out_key):
        """Return sigmoid scores for the nodes of type ``out_key``.

        The input features are read from the ``'feat'`` field of every node
        type in ``g``.
        """
        # Prepare the input feature dict, one entry per node type.
        feats = {}
        for ntype in g.ntypes:
            feats[ntype] = g.nodes[ntype].data['feat']

        for film_layer in self.film_layers:
            feats = film_layer(g, feats)

        # Use the final embedding of the requested node type to predict.
        scores = self.predict(feats[out_key])
        return torch.sigmoid(scores)
def _run_epoch(model, loader, criterion, device, optimizer=None, out_key='_N'):
    """Run one pass over ``loader`` and return ``(mean_loss, mean_f1)``.

    When ``optimizer`` is given, a gradient step is taken after every batch.
    Otherwise the pass only evaluates; the caller is responsible for putting
    the model in eval mode and disabling autograd.
    """
    losses = []
    f1_scores = []
    for batch in loader:
        g = batch.graph.to(device)
        # A no-op when the loader already placed the labels on ``device``.
        labels = batch.label.to(device)
        # Call the module rather than ``forward`` so registered hooks run.
        logits = model(g, out_key)
        loss = criterion(logits, labels)
        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        losses.append(loss.item())
        f1_scores.append(
            evaluate_f1_score(logits.detach().cpu().numpy(), labels.detach().cpu().numpy())
        )
    return np.mean(losses), np.mean(f1_scores)


def main(args):
    """Train GNN-FiLM, keep the best-validation checkpoint and report test F1.

    Reads from ``args``: ``gpu``, ``dataset``, ``batch_size``, ``hidden_size``,
    ``num_layers``, ``dropout``, ``lr``, ``wd``, ``step_size``, ``gamma``,
    ``max_epoch``, ``early_stopping``, ``save_dir`` and ``name``.

    Raises
    ------
    ValueError
        If ``args.dataset`` is not ``'PPI'``.
    """
    # Step 1: Prepare graph data and retrieve train/validation/test dataloader ============================= #
    if args.gpu >= 0 and torch.cuda.is_available():
        device = 'cuda:{}'.format(args.gpu)
    else:
        device = 'cpu'

    if args.dataset != 'PPI':
        # Fail here instead of with a NameError on the first use of ``etypes``.
        raise ValueError('Unsupported dataset: {!r} (only "PPI" is available)'.format(args.dataset))
    train_set, valid_set, test_set, etypes, in_size, out_size = load_PPI(args.batch_size, device)

    # Step 2: Create model and training components=========================================================== #
    # The dropout rate from the command line is forwarded to the model.
    model = GNNFiLM(etypes, in_size, args.hidden_size, out_size, args.num_layers, args.dropout).to(device)
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.step_size, gamma=args.gamma)

    # Step 3: training epochs =============================================================================== #
    checkpoint_path = os.path.join(args.save_dir, args.name)
    checkpoint_saved = False
    lastf1 = 0
    cnt = 0
    # Start below any reachable F1 so the first finite validation score
    # always produces a checkpoint.
    best_val_f1 = float('-inf')

    for epoch in range(args.max_epoch):
        model.train()
        train_loss, train_f1 = _run_epoch(model, train_set, criterion, device, optimizer)
        scheduler.step()

        model.eval()
        with torch.no_grad():
            val_loss, val_f1 = _run_epoch(model, valid_set, criterion, device)

        print('Epoch {:d} | Train Loss {:.4f} | Train F1 {:.4f} | Val Loss {:.4f} | Val F1 {:.4f} |'.format(epoch + 1, train_loss, train_f1, val_loss, val_f1))

        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True

        # Early stopping: count consecutive epochs whose validation F1 is
        # below the reference score; the reference only moves when the score
        # does not drop.
        if val_f1 < lastf1:
            cnt += 1
            if cnt == args.early_stopping:
                print('Early stop.')
                break
        else:
            cnt = 0
            lastf1 = val_f1

    # Step 4: evaluate on the test split ==================================================================== #
    model.eval()
    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    # Otherwise no checkpoint was written in this run (e.g. max_epoch == 0);
    # evaluate the current weights instead of loading a file left by an
    # earlier run.
    with torch.no_grad():
        test_loss, test_f1 = _run_epoch(model, test_set, criterion, device)
    print("Test F1: {:.4f} | Test loss: {:.4f}".format(test_f1, test_loss))
if __name__ == '__main__':
    # Command-line entry point: parse the hyper-parameters, make sure the
    # checkpoint directory exists, then train and evaluate.
    # Help texts use argparse's ``%(default)s`` so the documented default can
    # never drift from the real one.
    parser = argparse.ArgumentParser(description='GNN-FiLM')
    parser.add_argument("--dataset", type=str, default="PPI", help="DGL dataset for this GNN-FiLM")
    parser.add_argument("--gpu", type=int, default=-1, help="GPU Index. Default: -1, using CPU.")
    # NOTE(review): main() takes the input/output sizes from load_PPI and does
    # not read --in_size / --out_size; the flags are kept for compatibility.
    parser.add_argument("--in_size", type=int, default=50, help="Input dimensionalities")
    parser.add_argument("--hidden_size", type=int, default=320, help="Hidden layer dimensionalities")
    parser.add_argument("--out_size", type=int, default=121, help="Output dimensionalities")
    parser.add_argument("--num_layers", type=int, default=4, help="Number of GNN layers")
    parser.add_argument("--batch_size", type=int, default=5, help="Batch size")
    parser.add_argument("--max_epoch", type=int, default=1500, help="The max number of epochs. Default: %(default)s")
    parser.add_argument("--early_stopping", type=int, default=80, help="Early stopping. Default: %(default)s")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate. Default: %(default)s")
    parser.add_argument("--wd", type=float, default=0.0009, help="Weight decay. Default: %(default)s")
    parser.add_argument('--step-size', type=int, default=40, help='Period of learning rate decay.')
    parser.add_argument('--gamma', type=float, default=0.8, help='Multiplicative factor of learning rate decay.')
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout rate. Default: %(default)s")
    parser.add_argument('--save_dir', type=str, default='./out', help='Path to save the model.')
    parser.add_argument("--name", type=str, default='GNN-FiLM', help="Saved model name.")

    args = parser.parse_args()
    print(args)

    # makedirs also creates missing parent directories and does not fail when
    # the directory already exists.
    os.makedirs(args.save_dir, exist_ok=True)

    main(args)