import torch
from torch import nn
from torch_scatter import scatter_add
from torchdrug import layers


class GearNet(nn.Module):
    """
    Geometry Aware Relational Graph Neural Network proposed in
    `Protein Representation Learning by Geometric Structure Pretraining`_.

    This variant adds a scalar MLP head with a sigmoid output on top of the graph
    representation, for binary graph classification with ``nn.BCELoss``.

    .. _Protein Representation Learning by Geometric Structure Pretraining:
        https://arxiv.org/pdf/2203.06125.pdf

    Parameters:
        input_dim (int): input dimension
        hidden_dims (list of int): hidden dimensions
        num_relation (int): number of relations
        edge_input_dim (int, optional): dimension of edge features
        num_angle_bin (int, optional): number of bins used to discretize angles between edges.
            The discretized angles serve as relations in edge message passing.
            If not provided, edge message passing is disabled.
        short_cut (bool, optional): whether to use shortcut (residual) connections
        batch_norm (bool, optional): whether to apply batch normalization
        activation (str or function, optional): activation function
        concat_hidden (bool, optional): concatenate the hidden representations of all layers as the output
        readout (str, optional): readout function. Available functions are ``sum`` and ``mean``.
    """

    def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, num_angle_bin=None,
                 short_cut=False, batch_norm=False, activation="relu", concat_hidden=False, readout="sum"):
        super(GearNet, self).__init__()
        self.input_dim = input_dim
        self.output_dim = sum(hidden_dims) if concat_hidden else hidden_dims[-1]
        self.dims = [input_dim] + list(hidden_dims)
        self.edge_dims = [edge_input_dim] + self.dims[:-1]
        self.num_relation = num_relation
        self.num_angle_bin = num_angle_bin
        self.short_cut = short_cut
        self.concat_hidden = concat_hidden
        self.batch_norm = batch_norm
        self.dropout = nn.Dropout(p=0.5)

        self.layers = nn.ModuleList()
        for i in range(len(self.dims) - 1):
            self.layers.append(layers.GeometricRelationalGraphConv(self.dims[i], self.dims[i + 1], num_relation,
                                                                   None, batch_norm, activation))

        if num_angle_bin:
            self.spatial_line_graph = layers.SpatialLineGraph(num_angle_bin)
            self.edge_layers = nn.ModuleList()
            for i in range(len(self.edge_dims) - 1):
                self.edge_layers.append(layers.GeometricRelationalGraphConv(
                    self.edge_dims[i], self.edge_dims[i + 1], num_angle_bin, None, batch_norm, activation))

        if batch_norm:
            self.batch_norms = nn.ModuleList()
            for i in range(len(self.dims) - 1):
                self.batch_norms.append(nn.BatchNorm1d(self.dims[i + 1]))

        if readout == "sum":
            self.readout = layers.SumReadout()
        elif readout == "mean":
            self.readout = layers.MeanReadout()
        else:
            raise ValueError("Unknown readout `%s`" % readout)

        # MLP output layer
        if concat_hidden:
            self.mlp = layers.MLP(sum(hidden_dims), 1, batch_norm=False, dropout=0)
        else:
            self.mlp = layers.MLP(hidden_dims[-1], 1, batch_norm=False, dropout=0, activation="relu")
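
    # A construction sketch. The hyperparameters below are assumptions based on the
    # paper's GearNet-Edge configuration, not values taken from this file: 21-dim
    # one-hot residue inputs, 7 relation types, 59-dim edge features, 8 angle bins.
    #
    #   model = GearNet(input_dim=21, hidden_dims=[512] * 6, num_relation=7,
    #                   edge_input_dim=59, num_angle_bin=8, short_cut=True,
    #                   batch_norm=True, concat_hidden=True, readout="sum")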

    def forward(self, graph, input=None, all_loss=None, metric=None):
        """
        Compute the node representations and the graph representation(s), then map
        each graph representation to a binary classification probability.

        Parameters:
            graph (Graph): :math:`n` graph(s)
            input (Tensor): input node representations
            all_loss (Tensor, optional): if specified, add loss to this tensor
            metric (dict, optional): if specified, output metrics to this dict

        Returns:
            Tensor: per-graph probabilities of shape :math:`(n,)`, obtained by
            applying a sigmoid to the output of the MLP head
        """
        hiddens = []
        layer_input = input

        if self.num_angle_bin:
            # build the line graph once: each edge of `graph` becomes a node,
            # and discretized angles between incident edges define the relations
            line_graph = self.spatial_line_graph(graph)
            edge_input = line_graph.node_feature.float()

        for i in range(len(self.layers)):
            hidden = self.layers[i](graph, layer_input)
            if self.short_cut and hidden.shape == layer_input.shape:
                hidden = hidden + layer_input
            if self.num_angle_bin:
                # edge message passing: propagate edge representations on the line
                # graph, then scatter them back into the (node, relation) buckets
                # of the original graph
                edge_hidden = self.edge_layers[i](line_graph, edge_input)
                edge_weight = graph.edge_weight.unsqueeze(-1)
                node_out = graph.edge_list[:, 1] * self.num_relation + graph.edge_list[:, 2]
                update = scatter_add(edge_hidden * edge_weight, node_out, dim=0,
                                     dim_size=graph.num_node * self.num_relation)
                update = update.view(graph.num_node, self.num_relation * edge_hidden.shape[1])
                update = self.layers[i].linear(update)
                update = self.layers[i].activation(update)
                hidden = hidden + update
                edge_input = edge_hidden
            if self.batch_norm:
                hidden = self.batch_norms[i](hidden)
            # dropout on every layer's output for regularization
            hidden = self.dropout(hidden)
            hiddens.append(hidden)
            layer_input = hidden

        if self.concat_hidden:
            node_feature = torch.cat(hiddens, dim=-1)
        else:
            node_feature = hiddens[-1]
        graph_feature = self.readout(graph, node_feature)
        pred = self.mlp(graph_feature).squeeze(-1)
        # sigmoid so the output matches the [0, 1] targets expected by nn.BCELoss
        return torch.sigmoid(pred)
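

# A minimal usage sketch, assuming torchdrug's `data.Graph` / `data.Graph.pack` API.
# The graph sizes, feature dimensions and targets below are made-up illustrations;
# `num_angle_bin` stays disabled because these random graphs carry no 3D coordinates
# for `SpatialLineGraph` to derive angles from.
if __name__ == "__main__":
    from torchdrug import data

    num_node, num_relation, input_dim = 20, 7, 21
    # each edge is a (node_in, node_out, relation) triple
    edge_list = torch.stack([
        torch.randint(num_node, (100,)),
        torch.randint(num_node, (100,)),
        torch.randint(num_relation, (100,)),
    ], dim=-1)
    graph = data.Graph(edge_list, num_node=num_node, num_relation=num_relation)
    batch = data.Graph.pack([graph, graph])  # a batch of n = 2 graphs

    model = GearNet(input_dim, hidden_dims=[128, 128, 128], num_relation=num_relation)
    node_input = torch.randn(batch.num_node, input_dim)
    prob = model(batch, node_input)  # shape (2,), values in (0, 1)

    # the sigmoid output pairs directly with nn.BCELoss for binary targets
    target = torch.tensor([0.0, 1.0])
    loss = nn.BCELoss()(prob, target)
    print(prob.shape, loss.item())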