-
Notifications
You must be signed in to change notification settings - Fork 71
/
Copy pathGST.py
210 lines (168 loc) · 7.59 KB
/
GST.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from Hyperparameters import Hyperparameters as hp
class GST(nn.Module):
    '''
    Global style token module: a reference encoder followed by a style
    token layer.

    inputs --- passed unchanged to ReferenceEncoder
    outputs --- whatever STL returns for the encoder output
    '''

    def __init__(self):
        super().__init__()
        # Sub-modules are built in this order so that parameter
        # initialisation consumes the RNG in a fixed sequence.
        self.encoder = ReferenceEncoder()
        self.stl = STL()

    def forward(self, inputs):
        '''Encode the reference input, then attend over the style tokens.'''
        return self.stl(self.encoder(inputs))
class ReferenceEncoder(nn.Module):
    '''
    Stack of strided 2-D convolutions followed by a GRU; the GRU's final
    hidden state is the reference embedding.

    inputs --- [N, Ty/r, n_mels*r]  mels (reshaped to [N, 1, Ty, n_mels])
    outputs --- [N, E//2]
    '''

    def __init__(self):
        super().__init__()
        n_layers = len(hp.ref_enc_filters)
        # Channel sizes: a single input channel followed by the configured
        # filter counts.
        channels = [1] + hp.ref_enc_filters

        self.convs = nn.ModuleList([
            nn.Conv2d(in_channels=c_in,
                      out_channels=c_out,
                      kernel_size=(3, 3),
                      stride=(2, 2),
                      padding=(1, 1))
            for c_in, c_out in zip(channels, channels[1:])
        ])
        self.bns = nn.ModuleList([
            nn.BatchNorm2d(num_features=width)
            for width in hp.ref_enc_filters
        ])

        # Size of the mel axis once every conv layer has been applied.
        mel_bins = self.calculate_channels(hp.n_mels, 3, 2, 1, n_layers)
        self.gru = nn.GRU(input_size=hp.ref_enc_filters[-1] * mel_bins,
                          hidden_size=hp.E // 2,
                          batch_first=True)

    def forward(self, inputs):
        '''Return the final GRU hidden state for a batch of mel inputs.'''
        batch = inputs.size(0)
        feats = inputs.view(batch, 1, -1, hp.n_mels)  # [N, 1, Ty, n_mels]

        # Each conv has stride 2, so time and mel axes shrink per layer
        # (see calculate_channels for the exact size).
        for conv, bn in zip(self.convs, self.bns):
            feats = F.relu(bn(conv(feats)))  # [N, C, Ty', n_mels']

        feats = feats.transpose(1, 2)  # [N, Ty', C, n_mels']
        steps = feats.size(1)
        # Merge channel and mel axes into one feature axis per time step.
        feats = feats.contiguous().view(batch, steps, -1)  # [N, Ty', C*n_mels']

        self.gru.flatten_parameters()
        _, hidden = self.gru(feats)  # hidden --- [1, N, E//2]
        return hidden.squeeze(0)

    def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
        '''Return the length of an axis of size L after n_convs convolutions.'''
        size = L
        for _ in range(n_convs):
            size = (size + 2 * pad - kernel_size) // stride + 1
        return size
class STL(nn.Module):
    '''
    Style token layer: uses the reference embedding as a query to attend
    over a bank of learned style tokens.

    inputs --- [N, E//2]
    outputs --- result of self.attention for a [N, 1, E//2] query and
                [N, token_num, E//num_heads] keys
    '''

    def __init__(self):
        super().__init__()
        # Token bank; allocated uninitialised and filled by init.normal_
        # below. float32 is pinned to match the former torch.FloatTensor.
        self.embed = nn.Parameter(
            torch.empty(hp.token_num, hp.E // hp.num_heads, dtype=torch.float32))
        d_q = hp.E // 2
        d_k = hp.E // hp.num_heads
        self.attention = MultiHeadAttention(query_dim=d_q, key_dim=d_k, num_units=hp.E, num_heads=hp.num_heads)

        # Initialised after the attention layers so the RNG is consumed in
        # the same order as before.
        init.normal_(self.embed, mean=0, std=0.5)

    def forward(self, inputs):
        '''Attend over the tanh-squashed style tokens with inputs as query.'''
        N = inputs.size(0)
        query = inputs.unsqueeze(1)  # [N, 1, E//2]
        # torch.tanh replaces the deprecated F.tanh alias (same values).
        # expand() shares the token bank across the batch without copying.
        keys = torch.tanh(self.embed).unsqueeze(0).expand(N, -1, -1)  # [N, token_num, E // num_heads]
        style_embed = self.attention(query, keys)
        return style_embed
class MultiHeadAttention(nn.Module):
    '''
    Multi-head scaled dot-product attention; values are projected from
    the same tensor as the keys.

    input:
        query --- [N, T_q, query_dim]
        key --- [N, T_k, key_dim]
    output:
        out --- [N, T_q, num_units]
    '''

    def __init__(self, query_dim, key_dim, num_units, num_heads):
        super().__init__()
        self.num_units = num_units
        self.num_heads = num_heads
        self.key_dim = key_dim

        # Bias-free projections into the shared num_units space.
        self.W_query = nn.Linear(query_dim, num_units, bias=False)
        self.W_key = nn.Linear(key_dim, num_units, bias=False)
        self.W_value = nn.Linear(key_dim, num_units, bias=False)

    @staticmethod
    def _to_heads(projected, head_width):
        '''[N, T, num_units] -> [h, N, T, head_width] by slicing the last axis.'''
        return torch.stack(projected.split(head_width, dim=2), dim=0)

    def forward(self, query, key):
        '''Attend from query positions over key positions, per head.'''
        head_width = self.num_units // self.num_heads

        q_heads = self._to_heads(self.W_query(query), head_width)  # [h, N, T_q, num_units/h]
        k_heads = self._to_heads(self.W_key(key), head_width)      # [h, N, T_k, num_units/h]
        v_heads = self._to_heads(self.W_value(key), head_width)    # [h, N, T_k, num_units/h]

        # weights = softmax(Q K^T / sqrt(key_dim))
        # NOTE(review): the scale uses the input key feature size, which
        # equals the per-head width only when key_dim == num_units //
        # num_heads -- confirm this is intended for other configurations.
        logits = torch.matmul(q_heads, k_heads.transpose(2, 3))  # [h, N, T_q, T_k]
        weights = F.softmax(logits / (self.key_dim ** 0.5), dim=3)

        context = torch.matmul(weights, v_heads)  # [h, N, T_q, num_units/h]
        # Lay the heads side by side on the feature axis.
        return torch.cat(context.unbind(dim=0), dim=2)  # [N, T_q, num_units]
class MultiHeadAttention2(nn.Module):
    '''
    Multi-head scaled dot-product attention that folds the heads into the
    batch axis; values are projected from the same tensor as the keys.

    input:
        query --- [N, T_q, query_dim]
        keys --- [N, T_k, key_dim]
    output:
        out --- [N, T_q, num_units]

    h is the number of heads. With is_masked=True a causal mask is applied
    so query position i only attends to key positions <= i.

    Raises ValueError if num_units is not divisible by h.
    '''

    def __init__(self,
                 query_dim,
                 key_dim,
                 num_units,
                 h=8,
                 is_masked=False):
        # Zero-argument super(): naming MultiHeadAttention here raised
        # TypeError because this class does not derive from it.
        super().__init__()

        if num_units % h != 0:
            # An uneven split would silently scramble batch and head axes.
            raise ValueError("num_units must be dividable by h")

        self._num_units = num_units
        self._h = h
        self._key_dim = key_dim
        self._is_masked = is_masked

        self.query_layer = nn.Linear(query_dim, num_units, bias=False)
        self.key_layer = nn.Linear(key_dim, num_units, bias=False)
        self.value_layer = nn.Linear(key_dim, num_units, bias=False)

    def forward(self, query, keys):
        '''Return the attention output for query over keys.'''
        Q = self.query_layer(query)
        K = self.key_layer(keys)
        V = self.value_layer(keys)

        # Split each projection into h slices along the feature axis and
        # stack them along the batch axis: [h*N, T, num_units/h].
        chunk_size = self._num_units // self._h
        Q = torch.cat(Q.split(chunk_size, dim=2), dim=0)
        K = torch.cat(K.split(chunk_size, dim=2), dim=0)
        V = torch.cat(V.split(chunk_size, dim=2), dim=0)

        # QK^T scaled by sqrt(key_dim): [h*N, T_q, T_k]
        attention = torch.matmul(Q, K.transpose(1, 2))
        attention = attention / (self._key_dim ** 0.5)

        if self._is_masked:
            # Build the lower-triangular mask from the shape alone, on the
            # same device as the scores. Deriving it from score values
            # wrongly masked scores that happened to be exactly zero and
            # tied every batch element to element 0.
            T_q, T_k = attention.size(1), attention.size(2)
            causal = torch.ones(T_q, T_k, dtype=torch.bool,
                                device=attention.device).tril()
            # Every row keeps at least column 0, so -inf cannot yield an
            # all-masked row (which would give NaN after softmax).
            attention = attention.masked_fill(~causal, float('-inf'))

        attention = F.softmax(attention, dim=-1)

        # Weighted sum of values: [h*N, T_q, num_units/h]
        attention = torch.matmul(attention, V)

        # Undo the head folding: [N, T_q, num_units]
        restore_chunk_size = attention.size(0) // self._h
        attention = torch.cat(
            attention.split(restore_chunk_size, dim=0), dim=2)

        return attention