convgru.py
import torch
import torch.nn as nn
from torch.nn import init

class ConvGRUCell(nn.Module):
    """
    Generate a convolutional GRU cell.
    """

    def __init__(self, input_size, hidden_size, kernel_size):
        super().__init__()
        padding = kernel_size // 2
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.reset_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
        self.update_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
        self.out_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)

        # orthogonal initialisation for the gate kernels, zeros for the biases
        init.orthogonal_(self.reset_gate.weight)
        init.orthogonal_(self.update_gate.weight)
        init.orthogonal_(self.out_gate.weight)
        init.constant_(self.reset_gate.bias, 0.)
        init.constant_(self.update_gate.bias, 0.)
        init.constant_(self.out_gate.bias, 0.)
    def forward(self, input_, prev_state):
        # get batch and spatial sizes
        batch_size = input_.size(0)
        spatial_size = input_.size()[2:]

        # generate an all-zero prev_state, if None is provided
        if prev_state is None:
            state_size = [batch_size, self.hidden_size] + list(spatial_size)
            prev_state = torch.zeros(state_size, dtype=input_.dtype, device=input_.device)

        # data size is [batch, channel, height, width]
        stacked_inputs = torch.cat([input_, prev_state], dim=1)
        update = torch.sigmoid(self.update_gate(stacked_inputs))
        reset = torch.sigmoid(self.reset_gate(stacked_inputs))
        out_inputs = torch.tanh(self.out_gate(torch.cat([input_, prev_state * reset], dim=1)))
        new_state = prev_state * (1 - update) + out_inputs * update

        return new_state
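
# A minimal sketch of unrolling a single cell over time, assuming a sequence
# laid out as (time, batch, channels, height, width); the helper name and
# layout are illustrative, not part of the original file.
def _run_cell_over_sequence(cell, sequence):
    state = None
    states = []
    for frame in sequence:  # frame: (batch, channels, height, width)
        state = cell(frame, state)  # first step starts from a zero state
        states.append(state)
    return states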

class ConvGRU(nn.Module):

    def __init__(self, input_size, hidden_sizes, kernel_sizes, n_layers):
        '''
        Generates a multi-layer convolutional GRU.
        Preserves spatial dimensions across cells, only altering depth.

        Parameters
        ----------
        input_size : integer. depth dimension of input tensors.
        hidden_sizes : integer or list. depth dimensions of hidden state.
            if integer, the same hidden size is used for all cells.
        kernel_sizes : integer or list. sizes of Conv2d gate kernels.
            if integer, the same kernel size is used for all cells.
        n_layers : integer. number of chained `ConvGRUCell`s.
        '''
        super().__init__()

        self.input_size = input_size

        if not isinstance(hidden_sizes, list):
            self.hidden_sizes = [hidden_sizes] * n_layers
        else:
            assert len(hidden_sizes) == n_layers, '`hidden_sizes` must have the same length as n_layers'
            self.hidden_sizes = hidden_sizes
        if not isinstance(kernel_sizes, list):
            self.kernel_sizes = [kernel_sizes] * n_layers
        else:
            assert len(kernel_sizes) == n_layers, '`kernel_sizes` must have the same length as n_layers'
            self.kernel_sizes = kernel_sizes

        self.n_layers = n_layers

        cells = []
        for i in range(self.n_layers):
            # each cell consumes the hidden state of the layer below it
            if i == 0:
                input_dim = self.input_size
            else:
                input_dim = self.hidden_sizes[i - 1]

            cell = ConvGRUCell(input_dim, self.hidden_sizes[i], self.kernel_sizes[i])
            # register the cell via setattr so nn.Module tracks its parameters
            name = 'ConvGRUCell_' + str(i).zfill(2)
            setattr(self, name, cell)
            cells.append(getattr(self, name))

        self.cells = cells
    def forward(self, x, hidden=None):
        '''
        Parameters
        ----------
        x : 4D input tensor. (batch, channels, height, width).
        hidden : list of 4D hidden state representations. (batch, channels, height, width).

        Returns
        -------
        upd_hidden : list of length `n_layers` holding each layer's updated 4D
            hidden state. (batch, channels, height, width).
        '''
        if hidden is None:
            hidden = [None] * self.n_layers

        input_ = x
        upd_hidden = []

        for layer_idx in range(self.n_layers):
            cell = self.cells[layer_idx]
            cell_hidden = hidden[layer_idx]

            # pass through layer
            upd_cell_hidden = cell(input_, cell_hidden)
            upd_hidden.append(upd_cell_hidden)
            # the updated hidden state becomes the input to the next layer
            input_ = upd_cell_hidden

        # retain tensors in a list to allow different hidden sizes per layer
        return upd_hidden
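
# A minimal smoke test, assuming the file is run directly. The shapes
# (batch=1, 8 input channels, 64x64 frames) and the layer configuration
# below are arbitrary illustrations, not part of the original file.
if __name__ == '__main__':
    model = ConvGRU(input_size=8, hidden_sizes=[32, 64, 16],
                    kernel_sizes=[3, 5, 3], n_layers=3)

    x = torch.rand(1, 8, 64, 64)
    hidden = model(x)          # first step: each layer starts from zeros
    hidden = model(x, hidden)  # later steps reuse the returned states

    for i, h in enumerate(hidden):
        # expected: (1, 32, 64, 64), (1, 64, 64, 64), (1, 16, 64, 64)
        print('layer {}: hidden state shape {}'.format(i, tuple(h.shape)))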