This repository has been archived by the owner on Jan 8, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathcheck.py
376 lines (342 loc) · 19.3 KB
/
check.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
import itertools
import time

import numpy as np
import torch
from tqdm import trange, tqdm

from cuda_implementation import relative_positioning_2d, relative_positioning_3d
from np_implementation import python_relative_att_nd
from relative_embedding import EmbeddingPaddingMode, PositionEmbeddingType, KeyStartPosition
from relative_attention import RelativeAttention1d, RelativeAttention2d, RelativeAttention3d
# Optional dependencies: the TensorFlow comparison and the memory profiler are
# only run when their modules can be imported (see the __main__ guard).
try:
    from tf_code import TensorFlowTest
except ImportError:
    tf_is_available = False
else:
    tf_is_available = True

try:
    from pytorch_memlab import profile
except ImportError:
    profiler_is_available = False

    def profile(func):
        """No-op stand-in for ``pytorch_memlab.profile``: return *func* unchanged."""
        return func
else:
    profiler_is_available = True
def assert_equal(a, b, atol=1.e-5):
    """Assert that two array-likes are element-wise close.

    Uses ``np.allclose`` with absolute tolerance *atol* (and numpy's default
    relative tolerance). The failure is raised explicitly rather than through an
    ``assert`` statement, so the check still runs under ``python -O``.

    Note: ``np.allclose`` broadcasts its arguments, so inputs whose shapes differ
    but are broadcast-compatible are compared after broadcasting.

    Args:
        a: First array-like value.
        b: Second array-like value.
        atol: Absolute tolerance; defaults to the previously hard-coded 1e-5.

    Raises:
        AssertionError: If the values are not close; its argument is the maximum
            absolute element-wise difference.
    """
    # asarray lets plain lists work too (``list - list`` would raise TypeError).
    a = np.asarray(a)
    b = np.asarray(b)
    if not np.allclose(a, b, atol=atol):
        raise AssertionError(np.max(np.abs(a - b)))
def correctness_check_1d_basic():
    """Compare RelativeAttention1d output against the numpy reference implementation.

    Runs 100 randomized rounds. Each round covers shared and per-head relative
    embeddings, each both without a mask and with a random boolean mask of shape
    (batch_size * num_heads, width, width).
    """
    batch_size, num_heads, width, depth = 2, 3, 4, 5
    net_kwargs = dict(embedding_padding_modes=EmbeddingPaddingMode.Zero,
                      position_embedding_types=PositionEmbeddingType.Learned,
                      add_bias_to_query_for_relative_logits=False,
                      add_bias_to_query_for_key_logit=False)
    for _ in trange(100, desc='1d correctness check'):
        for shared in (True, False):
            # A fresh random mask is drawn once per sharing mode, before the masks are iterated.
            masks = [None, torch.randn(batch_size * num_heads, width, width) > 0]
            for mask in masks:
                attention = RelativeAttention1d(num_heads, depth * num_heads, width,
                                                heads_share_relative_embeddings=shared,
                                                **net_kwargs)
                query = torch.randn(batch_size, num_heads, width, depth)
                key = torch.randn_like(query)
                torch_out = attention(query, key, mask).detach().numpy()
                width_embeddings = attention.relative_embeddings[0](width).detach().numpy()
                reference = python_relative_att_nd(query.numpy(), key.numpy(), shared,
                                                   width_embeddings, mask=mask)
                assert_equal(torch_out, reference)
def correctness_check_2d():
    """Compare RelativeAttention2d against the numpy reference, with and without the custom kernel.

    Runs 100 randomized rounds over shared / per-head relative embeddings and
    no mask / a random boolean mask. Each case is evaluated twice:
      1. on CPU with ``use_custom_cuda_kernel = False``;
      2. on the GPU with ``use_custom_cuda_kernel = True``.
    Both results must match ``python_relative_att_nd``.
    """
    batch_size, num_heads, width, depth, height = 2, 3, 4, 5, 6
    positions = height * width
    net_kwargs = dict(embedding_padding_modes=EmbeddingPaddingMode.Zero,
                      position_embedding_types=PositionEmbeddingType.Learned,
                      add_bias_to_query_for_relative_logits=False,
                      add_bias_to_query_for_key_logit=False)
    for _ in trange(100, desc='2d correctness check'):
        for shared in (True, False):
            masks = [None, torch.randn(batch_size * num_heads, positions, positions) > 0]
            for mask in masks:
                attention = RelativeAttention2d(num_heads, depth * num_heads, [width, height],
                                                heads_share_relative_embeddings=shared,
                                                **net_kwargs)
                query = torch.randn(batch_size, num_heads, height, width, depth)
                key = torch.randn_like(query)

                # Path 1: default PyTorch implementation on CPU.
                attention.use_custom_cuda_kernel = False
                torch_out = attention(query, key, mask).detach().numpy()
                width_embeddings = attention.relative_embeddings[0](width).detach().numpy()
                height_embeddings = attention.relative_embeddings[1](height).detach().numpy()
                reference = python_relative_att_nd(query.numpy(), key.numpy(), shared,
                                                   width_embeddings, height_embeddings,
                                                   mask=mask)

                # Path 2: the same module on the GPU with the custom kernel enabled.
                attention = attention.cuda()
                attention.use_custom_cuda_kernel = True
                custom_out = attention(query.cuda(), key.cuda(),
                                       None if mask is None else mask.cuda()).detach().cpu().numpy()

                assert_equal(torch_out, reference)
                assert_equal(custom_out, reference)
def correctness_check_tf():
    """Compare the result returned by ``TensorFlowTest().run`` against the numpy reference.

    ``TensorFlowTest().run`` supplies q, k, its own answer and the relative
    embeddings it used. The last two axes of those embeddings are swapped before
    they are passed to ``python_relative_att_nd``, and the TensorFlow answer is
    reshaped to the reference's shape before comparison.
    """
    batch_size, num_heads, width, depth, height = 2, 3, 4, 5, 6
    max_distance = max(height, width)
    for _ in trange(100, desc='tf correctness check'):
        for shared in (True, False):
            q, k, tf_ans, height_emb, width_emb = TensorFlowTest().run(
                batch_size, height, width, num_heads, max_distance, depth, shared)
            # Shared embeddings are transposed as 2-D arrays, per-head ones as 3-D arrays;
            # in both cases the last two axes are swapped.
            axes = (1, 0) if shared else (0, 2, 1)
            width_emb = width_emb.transpose(axes)
            height_emb = height_emb.transpose(axes)
            np_ans = python_relative_att_nd(q, k, shared, width_emb, height_emb)
            assert_equal(tf_ans.reshape(np_ans.shape), np_ans)
def correctness_check_3d():
    """Compare RelativeAttention3d against the numpy reference, with and without the custom kernel.

    Runs 100 randomized rounds over shared / per-head relative embeddings and
    no mask / a random boolean mask. Each case is evaluated on CPU with the
    default PyTorch path and on the GPU with the custom CUDA kernel. Both
    results must match ``python_relative_att_nd``.
    """
    batch_size = 1
    num_heads = 2
    width = 3
    depth = 4
    height = 5
    # Named ``time_`` (as in speed_check / run_profiler) so it does not shadow the ``time`` module.
    time_ = 6
    num_positions = time_ * height * width
    for _ in trange(100, desc='3d correctness check'):
        for heads_share_relative_embedding in (True, False):
            for mask in [None, torch.randn(batch_size * num_heads, num_positions, num_positions) > 0]:
                net = RelativeAttention3d(num_heads, depth * num_heads, [width, height, time_],
                                          heads_share_relative_embeddings=heads_share_relative_embedding,
                                          embedding_padding_modes=EmbeddingPaddingMode.Zero,
                                          position_embedding_types=PositionEmbeddingType.Learned,
                                          add_bias_to_query_for_relative_logits=False,
                                          add_bias_to_query_for_key_logit=False)
                q = torch.randn(batch_size, num_heads, time_, height, width, depth)
                k = torch.randn_like(q)

                # Default PyTorch path on CPU.
                net.use_custom_cuda_kernel = False
                torch_ans = net(q, k, mask).detach().numpy()
                width_key_relative_embeddings = net.relative_embeddings[0](width).detach().numpy()
                height_key_relative_embeddings = net.relative_embeddings[1](height).detach().numpy()
                time_key_relative_embeddings = net.relative_embeddings[2](time_).detach().numpy()
                np_ans = python_relative_att_nd(q.numpy(), k.numpy(), heads_share_relative_embedding,
                                                width_key_relative_embeddings, height_key_relative_embeddings,
                                                time_key_relative_embeddings, mask=mask)

                # Custom CUDA kernel path on the GPU.
                net = net.cuda()
                net.use_custom_cuda_kernel = True
                custom_ans = net(q.cuda(), k.cuda(),
                                 mask.cuda() if mask is not None else None).detach().cpu().numpy()

                assert_equal(torch_ans, np_ans)
                assert_equal(custom_ans, np_ans)
def _config_check_key_sizes(n, use_custom, key_start_position, width_q, height_q, time_q):
    """Yield the (width_k, height_k, time_k) key sizes config_check runs for one setting.

    Key sizes that differ from the query sizes are only tried when the custom
    kernel is used on a 2d/3d net. The 2d net (n == 1) never varies time. When
    ``key_start_position == 0``, keys half the query size are skipped.
    """
    for width_k, height_k, time_k in itertools.product((width_q, width_q // 2, width_q * 2),
                                                       (height_q, height_q // 2, height_q * 2),
                                                       (time_q,)):
        sizes = ((width_k, width_q), (height_k, height_q), (time_k, time_q))
        if (n == 0 or not use_custom) and any(k != q for k, q in sizes):
            continue
        if n == 1 and time_k != time_q:
            continue
        if key_start_position == 0 and any(k == q // 2 for k, q in sizes):  # before and q > k
            continue
        yield width_k, height_k, time_k


def _config_check_tensor(n, batch_size, num_heads, width, height, time_, depth):
    """Return a random CPU tensor (batch, heads, width[, height[, time]], depth) for an (n+1)-d net."""
    spatial = (width, height, time_)[:n + 1]
    return torch.randn(batch_size, num_heads, *spatial, depth)


def config_check():
    """Smoke-test that many RelativeAttention{1,2,3}d configurations run a forward pass.

    No outputs are compared; each ``net(q, k)`` only has to complete without
    raising. The nets are built positionally from the enumerated settings. 2d/3d
    nets are run both with the custom CUDA kernel (on GPU) and without it (on
    CPU); the 1d net is only run without it.
    """
    batch_size = 2
    num_heads = 3
    depth = 5
    width_q = 8
    height_q = 6
    time_q = 4
    net_classes = (RelativeAttention1d, RelativeAttention2d, RelativeAttention3d)
    # itertools.product iterates in the same order as the equivalent nested for-loops.
    settings = itertools.product(
        range(3),       # n: index into net_classes (1d, 2d, 3d)
        (3,),           # max_relative_positions_past
        (2, 5),         # max_relative_positions_future
        (True, False),  # heads_share_relative_embedding
        range(3),       # embedding_padding_mode
        range(3),       # position_embedding_type
        range(2),       # key_start_position
        (True,),        # add_bias_to_query_for_relative_logits
        (True,),        # add_bias_to_query_for_key_logit
    )
    # 832 is the number of forward passes the loops below perform in total.
    with tqdm(total=832, desc='check configs') as config_tqdm:
        for (n, past, future, shared, padding_mode, embedding_type, key_start_position,
             bias_for_relative_logits, bias_for_key_logit) in settings:
            if embedding_type == 1 and padding_mode == 2:  # learned and extend
                continue
            if embedding_type == 0 and not shared:  # fixed and !shared
                continue
            net = net_classes[n](num_heads, num_heads * depth, past, future, shared,
                                 padding_mode, embedding_type, key_start_position,
                                 bias_for_relative_logits, bias_for_key_logit)
            for use_custom in (True, False):
                if n == 0:
                    if use_custom:  # the 1d net is only exercised without the custom kernel
                        continue
                else:
                    net.use_custom_cuda_kernel = use_custom
                    net = net.cuda() if use_custom else net.cpu()
                q = _config_check_tensor(n, batch_size, num_heads, width_q, height_q, time_q, depth)
                if use_custom:
                    q = q.cuda()
                for width_k, height_k, time_k in _config_check_key_sizes(
                        n, use_custom, key_start_position, width_q, height_q, time_q):
                    k = _config_check_tensor(n, batch_size, num_heads, width_k, height_k, time_k, depth)
                    if use_custom:
                        k = k.cuda()
                    net(q, k)
                    config_tqdm.update()
def grad_check():
    """Run ``torch.autograd.gradcheck`` on the custom relative-positioning ops.

    For each of ten size cases, the first pass checks ``relative_positioning_2d``
    with the time sizes forced to 1, and the second pass checks
    ``relative_positioning_3d``. Each pass uses three masks: none, a 2-D mask
    without the leading N dimension, and a 3-D mask of shape (N, L_q, L_k).
    All inputs are float64 CUDA tensors (double precision, as recommended for gradcheck).
    """
    # Columns: (H, B, w_q, h_q, w_k, h_k, t_k, t_q)
    cases = (
        (1, 1, 1, 1, 1, 1, 1, 1),
        (1, 1, 2, 2, 1, 1, 2, 2),
        (1, 2, 1, 1, 2, 2, 4, 4),
        (1, 2, 4, 4, 4, 4, 1, 1),
        (2, 1, 8, 8, 4, 4, 1, 2),
        (2, 1, 1, 1, 8, 8, 2, 1),
        (2, 2, 8, 8, 1, 1, 4, 1),
        (2, 2, 2, 4, 1, 2, 1, 4),
        (4, 1, 2, 4, 4, 2, 2, 2),
        (4, 2, 4, 2, 2, 1, 2, 4),
    )
    for H, B, w_q, h_q, w_k, h_k, case_t_k, case_t_q in tqdm(cases, desc='grad check'):
        N = B * H
        for is_3d in (False, True):
            t_k, t_q = (case_t_k, case_t_q) if is_3d else (1, 1)
            len_q = w_q * h_q * t_q
            len_k = w_k * h_k * t_k
            logits = torch.randn(N, len_q, len_k).double().cuda().requires_grad_()
            r_w = torch.randn(N, len_q, w_q + w_k - 1).double().cuda().requires_grad_()
            r_h = torch.randn(N, len_q, h_q + h_k - 1).double().cuda().requires_grad_()
            masks = (None,
                     (torch.randn(len_q, len_k) > 0).bool().cuda(),
                     (torch.randn(N, len_q, len_k) > 0).bool().cuda())
            for mask in masks:
                if is_3d:
                    r_t = torch.randn(N, len_q, t_q + t_k - 1).double().cuda().requires_grad_()
                    torch.autograd.gradcheck(relative_positioning_3d,
                                             (logits, r_t, r_h, r_w, t_q, h_q, w_q, t_k, h_k, w_k, mask))
                else:
                    torch.autograd.gradcheck(relative_positioning_2d,
                                             (logits, r_h, r_w, h_q, w_q, h_k, w_k, mask))
def _time_runs(net, q, k, backward, num_runs):
    """Return wall-clock seconds for *num_runs* calls of ``net(q, k)``, plus backward if requested.

    One untimed warm-up call is made first. CUDA kernels launch asynchronously,
    so the device is synchronized before the clock starts and before it stops;
    otherwise mostly launch overhead would be measured.
    """
    ans = net(q, k)  # warmup
    if backward:
        ans.mean().backward()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in trange(num_runs):
        ans = net(q, k)
        if backward:
            ans.mean().backward()
    torch.cuda.synchronize()
    return time.perf_counter() - start


def speed_check():
    """Benchmark the custom CUDA kernel against the default PyTorch path.

    Times forward-only and forward+backward runs for RelativeAttention2d (batch 32)
    and RelativeAttention3d (batch 8). The default/custom time ratios (speedups)
    are shown in the progress-bar description as f2/b2 (2d forward/backward) and
    f3/b3 (3d forward/backward); a ratio not yet measured shows as 'N/A'.
    """
    model_depth = 32
    num_heads = 4
    width = 16
    height = 16
    time_ = 4
    num_runs = 1000
    speed_tqdm = tqdm([False, True, False, True], desc='speed check')
    shared_params = dict(max_relative_positions_future=None, heads_share_relative_embeddings=True,
                         embedding_padding_modes=EmbeddingPaddingMode.Extend,
                         position_embedding_types=PositionEmbeddingType.Hybrid,
                         key_start_positions=KeyStartPosition.BeforeQuery,
                         add_bias_to_query_for_relative_logits=True,
                         add_bias_to_query_for_key_logit=True, use_custom_cuda_kernel=True)
    forward_speedup_2d = 'N/A'
    backward_speedup_2d = 'N/A'
    forward_speedup_3d = 'N/A'
    backward_speedup_3d = 'N/A'
    for i, backward in enumerate(speed_tqdm):
        is_2d = i // 2 == 0  # first two iterations: 2d, last two: 3d
        if is_2d:
            B = 32
            net = RelativeAttention2d(num_heads, model_depth, [width, height], **shared_params).cuda()
        else:
            B = 8
            net = RelativeAttention3d(num_heads, model_depth, [width, height, time_], **shared_params).cuda()
        with torch.set_grad_enabled(backward):
            if is_2d:
                q = torch.randn(B, num_heads, height, width, model_depth // num_heads).cuda()
            else:
                q = torch.randn(B, num_heads, time_, height, width, model_depth // num_heads).cuda()
            k = torch.randn_like(q)  # already on q's device
            if backward:
                q.requires_grad_()
                k.requires_grad_()
            net.use_custom_cuda_kernel = True
            custom = _time_runs(net, q, k, backward, num_runs)
            net.use_custom_cuda_kernel = False
            default = _time_runs(net, q, k, backward, num_runs)
        speedup = default / custom
        if is_2d:
            if backward:
                backward_speedup_2d = speedup
            else:
                forward_speedup_2d = speedup
        else:
            if backward:
                backward_speedup_3d = speedup
            else:
                forward_speedup_3d = speedup
        # f = forward-only speedup, b = forward+backward speedup.
        speed_tqdm.set_description(f'f2: {forward_speedup_2d}, b2: {backward_speedup_2d}, '
                                   f'f3:{forward_speedup_3d}, b3:{backward_speedup_3d}')
def run_profiler():
    """Run one forward+backward pass per attention variant under the memory profiler.

    Covers 2d and 3d RelativeAttention, each with and without the custom CUDA
    kernel. Every combination goes through its own ``@profile``-decorated
    ``run_*`` function with a random boolean mask.
    """
    # (is_2d, is_custom) -> profiled runner
    runners = {(True, True): run_2d_custom, (True, False): run_2d_default,
               (False, True): run_3d_custom, (False, False): run_3d_default}
    for is_2d in (True, False):
        for is_custom in (True, False):
            model_depth = 32
            B = 32 if is_2d else 8
            num_heads = 8 if is_2d else 4
            width = 16 if is_2d else 8
            height = 16 if is_2d else 8
            time_ = 4
            shared_params = dict(heads_share_relative_embeddings=True,
                                 embedding_padding_modes=EmbeddingPaddingMode.Extend,
                                 position_embedding_types=PositionEmbeddingType.Hybrid,
                                 add_bias_to_query_for_relative_logits=True,
                                 add_bias_to_query_for_key_logit=True,
                                 use_custom_cuda_kernel=is_custom)
            if is_2d:
                net = RelativeAttention2d(num_heads, model_depth, [width, height], **shared_params).cuda()
                q = torch.randn(B, num_heads, height, width, model_depth // num_heads).cuda()
            else:
                net = RelativeAttention3d(num_heads, model_depth, [width, height, time_], **shared_params).cuda()
                # (time, height, width) layout, as in correctness_check_3d and speed_check.
                q = torch.randn(B, num_heads, time_, height, width, model_depth // num_heads).cuda()
            k = torch.randn_like(q)  # already on q's device
            num_positions = height * width * (1 if is_2d else time_)
            m = (torch.randn(B * num_heads, num_positions, num_positions) > 0).cuda()
            net.use_custom_cuda_kernel = is_custom
            runners[is_2d, is_custom](net, q, k, m)
@profile
def run_2d_custom(net, q, k, m):
    """Profiled forward+backward pass; run_profiler dispatches the 2d, custom-kernel case here."""
    output = net(q, k, m)
    loss = output.mean()
    loss.backward()
@profile
def run_2d_default(net, q, k, m):
    """Profiled forward+backward pass; run_profiler dispatches the 2d, default-path case here."""
    output = net(q, k, m)
    loss = output.mean()
    loss.backward()
@profile
def run_3d_custom(net, q, k, m):
    """Profiled forward+backward pass; run_profiler dispatches the 3d, custom-kernel case here."""
    output = net(q, k, m)
    loss = output.mean()
    loss.backward()
@profile
def run_3d_default(net, q, k, m):
    """Profiled forward+backward pass; run_profiler dispatches the 3d, default-path case here."""
    output = net(q, k, m)
    loss = output.mean()
    loss.backward()
if __name__ == '__main__':
    # Run every check in order; optional checks are skipped with a message when
    # their dependency failed to import.
    correctness_check_1d_basic()
    correctness_check_2d()
    if not tf_is_available:
        print('skipping tf check')
    else:
        correctness_check_tf()
    for check in (correctness_check_3d, config_check, grad_check, speed_check):
        check()
    if not profiler_is_available:
        print('skipping memory profile checks')
    else:
        run_profiler()