// test.cu
#include <cuda.h>
#include <cuda_runtime.h>
#include <sys/time.h>
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cublas_v2.h>
#define d 64        // head dimension; must match the dim you specify in bench.py
#define NEG_INFINITY __int_as_float(0xff800000)
#define B_r 32      // B_r or BM: rows of Q processed per block
#define B_c 32      // B_c or BN: columns (keys) processed per chunk
#define BK 32       // K-tile width; used to be B_c but now separate due to coarsening
// thread-level (2nd level) tiling
#define TM 4        // rows of the per-thread output tile
#define TN 4        // columns of the per-thread output tile
#define CACHE_Q 1   // 1: cache Q over the full d in shared memory (default)
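/*
For reference, the sizes implied by the default macros above (d = 64, B_r = B_c = BK = 32,
TM = TN = 4), worked out by hand:
  num_tiles          = d / BK                = 2
  threads per block  = (B_r/TN) * (B_c/TM)   = 8 * 8 = 64   (coarsened kernels)
  SMEM per block (coarsened kernels, CACHE_Q = 1):
    Q_i[B_r][d] + K_j[B_c][BK+1] + V_j[B_c][BK] + S_i[B_r][B_c+1]
    = (2048 + 1056 + 1024 + 1056) floats = 5184 floats ~= 20.25 KB
*/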
double getTimeStamp() {
struct timeval tv;
gettimeofday( &tv, NULL );
return (double) tv.tv_usec/1000000 + tv.tv_sec;
}
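/*
flash_tiled: flash-attention style forward pass with an online softmax. Each block handles a
B_r-row block of queries (blockIdx.y) of one batch element (blockIdx.x) and iterates over the
T_c chunks of K/V. For every chunk it computes S = scaling * Q_i K_j^T, updates the running
row max m_i and running sum of exponentials l_i, rescales the partial output O_i by
exp(last_m - m), and accumulates exp(S - m) * V_j. After the last chunk, O_i / l_i is written
to out.
*/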
__global__
void flash_tiled(float *out, float *K, float *Q, float* V, float scaling, int batch_stride, int T_r, int T_c)
{
int tid_x = threadIdx.x;
int tid_y = threadIdx.y;
int batch_offset = batch_stride * blockIdx.x;
/*
Q, K and V are fully loaded into shared memory. As a second step we should adjust this to load
them in tiles of B_r x 32 and iterate the multiplications over those 32-wide tiles; this way we
can support a larger d while keeping occupancy high.
*/
__shared__ float Q_i[B_r][d];
__shared__ float K_j[B_r][B_c];
__shared__ float V_j[B_r][d];
// attention result
__shared__ float S_i[B_r][B_c];
// assuming B_c = blockdim.x, within a block, number of tiles a thread has to calculate
const int num_tiles = d/B_c;
float l_i;
float m_i;
assert (B_r == B_c && B_r == blockDim.x && B_r == blockDim.y);
// assert (num_tiles == 1); // Hack: for now
// this will automatically be kept in registers since it is very small
float O_i[num_tiles]; // per-thread output accumulator, one value per tile
for (int t = 0; t < num_tiles; t++) {
O_i[t] = 0;
}
// row-wise statistics
l_i = 0.f;
m_i = NEG_INFINITY;
// load Q_i
for (int t=0; t<num_tiles; t++){
Q_i[tid_y][t * B_c + tid_x] = Q[batch_offset + (blockIdx.y * B_r + tid_y) * d + t * B_c + tid_x ];
}
__syncthreads();
// T_c = seq_len / B_c: number of chunks of K (along the sequence dimension, due to K^T) that we iterate over
for (int j = 0; j < T_c; j++) {
S_i[tid_y][tid_x] = 0.f;
float S_ij = 0.f;
for (int t=0; t<num_tiles; t++){
// load one B_c-wide tile of K_j and the corresponding slice of V_j
K_j[tid_y][tid_x] = K[batch_offset + (tid_y + j * B_c) * d + tid_x + t * B_c];
// TO OPTIMIZE: for now we simply load the full V_j row here
V_j[tid_y][t * B_c + tid_x] = V[batch_offset + (tid_y + j * B_c) * d + tid_x + t * B_c];
__syncthreads();
for (int dd=0; dd<B_c; dd++){
S_ij += Q_i[tid_y][t*B_c+dd] * K_j[tid_x][dd]; // this may lead to bank conflicts when reading K_j
}
__syncthreads();
}
S_i[tid_y][tid_x] += scaling * S_ij;
__syncthreads();
float last_m = m_i;
float m = m_i;
for (int jj = 0; jj < B_c; jj += 1) {
if (m < S_i[tid_y][jj]) {
m = S_i[tid_y][jj];
}
}
__syncthreads();
m_i = m;
// 2) renormalize current O
for (int t = 0; t < num_tiles; t++){
O_i[t] *= exp(last_m - m);
}
// 3) renormalize the sum
float l = exp(last_m - m) * l_i;
// 4) compute \exp(Q_iK^T_{j+1} - m^{j+1}) = \exp(S_i-m^{j+1})
float S_id;
__syncthreads();
for (int dd = 0; dd < B_c; dd++) {
S_id = exp(S_i[tid_y][dd] - m);
l += S_id;
for (int t = 0; t < num_tiles; t++){
// accumulate the exp-weighted V contribution into the output
O_i[t] += S_id * V_j[dd][t * B_c + tid_x];
}
}
l_i = l;
__syncthreads();
}
// normalize the whole thing by the sum and write to output
for (int t = 0; t < num_tiles; t++){
out[batch_offset + (blockIdx.y * B_r + tid_y ) * d + t * B_c + tid_x] = O_i[t] / l_i;
}
}
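/*
flash_tiled_coarse: same online-softmax scheme as flash_tiled, but with thread coarsening.
The block is launched with (B_r/TN) x (B_c/TM) threads and each thread owns a TM x TN tile of
the score matrix S_i, accumulated via register tiles regM/regN (regM holds TM values of Q_i
and regN TN values of K_j at the same feature index dd). K and V are streamed through SMEM in
BK-wide tiles; Q is either cached over the full d (CACHE_Q = 1) or re-streamed per tile.
*/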
__global__
void flash_tiled_coarse(float *out, float* out_l, float *K, float *Q, float* V, float scaling, int batch_stride, int T_r, int T_c, int seq_len)
{
int tid_x = threadIdx.x;
int tid_y = threadIdx.y;
int batch_offset = batch_stride * blockIdx.x;
/*
NOTE: K and V are streamed through shared memory (SMEM) in tiles of B_c x BK, and Q in
B_r x BK slices unless CACHE_Q is set; the multiplications iterate over these BK-wide tiles,
which allows a larger d while keeping occupancy high.
*/
// statically define in SMEM and still address it with indices
#if CACHE_Q
__shared__ float Q_i[B_r][d];    // keep the full B_r x d block of Q resident in SMEM
#else
__shared__ float Q_i[B_r][BK];   // keep only a BK-wide slice of Q resident; saves SMEM at the cost of extra loads
#endif
__shared__ float K_j[B_c][BK+1]; // +1 column to reduce SMEM bank conflicts, since K is read transposed
__shared__ float V_j[B_c][BK];
// attention scores
__shared__ float S_i[B_r][B_c+1]; // +1 column to reduce SMEM bank conflicts (in the naive softmax part)
const int num_tiles = d/BK; // number of BK-wide tiles the attention computation is split into along d
const uint totalResultsBlocktile = B_r * B_c; // number of results to calculate per block
const uint numThreadsBlocktile = totalResultsBlocktile / (TM * TN); // number of threads needed
const int threadId_flat = threadIdx.y * blockDim.x + threadIdx.x; // flattened thread id (used for coalesced loading of tiles)
// each thread processes one TM x TN tile at position:
const int threadCol = threadId_flat % (B_c / TN);
const int threadRow = threadId_flat / (B_c / TN);
float l_i[TM] = {0.0}; // running sum of exponentials per row
float m_i[TM]; // running max value per row
float last_m[TM]; // previous max value per row
float O_i[num_tiles * TN * TM] = {0.0}; // intermediate output accumulator (each thread stores a TM x TN chunk per tile)
// reset the running max to -inf
for (int ii = 0; ii < TM; ii++) {
m_i[ii] = -INFINITY;
}
// WARNING: due to coalescing we should probably add a second set of variables for using BK+1
const uint strideK = numThreadsBlocktile / BK; // rows loaded per pass (64 / 32 = 2 with the defaults)
const uint innerRowK = threadId_flat / BK; // row within the K/V tile being loaded (0..1 with the defaults)
const uint innerColK = threadId_flat % BK; // column within the K/V tile being loaded (0..31 with the defaults)
int id;
// preload Q_i over the full head dimension (only needed when Q is cached over all of d)
const uint innerRowQ = threadId_flat / d; // row of the Q block this thread starts loading
const uint innerColQ = threadId_flat % d; // column of the Q block this thread starts loading
const uint nr_loads = B_r * d / numThreadsBlocktile; // loads per thread to cover the B_r x d block
if (CACHE_Q) {
for (int t=0; t<nr_loads; t++){
// load the B_r x d block of Q with numThreadsBlocktile threads, numThreadsBlocktile consecutive
// elements per pass; the column index runs past d on purpose and relies on the row-major layout
// of Q_i; zero-pad rows past the end of the sequence
id = (blockIdx.y * B_r + innerRowQ) * d + innerColQ + t * numThreadsBlocktile;
if (id < d*seq_len){
Q_i[innerRowQ][innerColQ + t * numThreadsBlocktile] = Q[batch_offset + id];
} else {
Q_i[innerRowQ][innerColQ + t * numThreadsBlocktile] = 0.0;
}
}
}
__syncthreads();
// scratchpad register for register-tiling (coarsening of the matrix mults)
float regM[TM] = {0.0};
float regN[TN] = {0.0};
for (int j = 0; j < T_c; j++) { // iterate over the chunks of K and V
float threadResults[TM * TN] = {0.0}; // storing the intermediate outputs
for (int t=0; t<num_tiles; t++){
// load the current BK-wide tile of K_j (and, if Q is not cached over d, the matching slice of Q_i)
for (int i=0; i<B_r; i+=strideK){
if (not CACHE_Q){
// re-stream the BK-wide slice of Q for this tile (skipped when Q is cached over the whole d)
Q_i[innerRowK+i][innerColK] = Q[batch_offset + (innerRowK + blockIdx.y * B_r) * d + i * d + innerColK + t * B_c];
}
id = (innerRowK + j * B_c) * d + i * d + innerColK + t * B_c;
if (id < d*seq_len){
K_j[innerRowK+i][innerColK] = K[batch_offset + id];
//V_j[innerRowK+i][innerColK+t*B_c] = V[batch_offset + id];
} else {
K_j[innerRowK+i][innerColK] = 0.0;
//V_j[innerRowK+i][innerColK+t*B_c] = 0.0;
}
}
__syncthreads();
for (int dd=0; dd<BK; dd++){ // load elements of Q_i and K_j^T into registers
for (uint i = 0; i < TM; ++i) {
if (CACHE_Q){
regM[i] = Q_i[(threadRow * TM + i)][dd+t*BK]; // Q is cached over the full d: index into the t-th BK-wide slice
} else {
regM[i] = Q_i[(threadRow * TM + i)][dd];
}
}
for (uint i = 0; i < TN; ++i) {
regN[i] = K_j[threadCol * TN + i][dd];
}
for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) {
for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) {
threadResults[resIdxM * TN + resIdxN] += regM[resIdxM] * regN[resIdxN];
}
}
}
__syncthreads();
}
// store the scaled results in S_i (no causal masking in this kernel)
for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) {
for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) {
S_i[(threadRow * TM + resIdxM)][threadCol * TN + resIdxN] = threadResults[resIdxM * TN + resIdxN] *scaling;
}
}
__syncthreads();
for (int i=0;i<TM;++i){
last_m[i] = m_i[i];
float m = m_i[i];
for (int jj = 0; jj < B_c; jj += 1) {
if (m < S_i[threadRow*TM+i][jj]) {
m = S_i[threadRow*TM+i][jj];
}
}
m_i[i] = m;
}
// 2) renormalize current O
if (j > 0) {
for (int t = 0; t < num_tiles; t++){
for (int i=0;i<TM;++i){
for (int jj=0;jj<TN;++jj){
O_i[t*TN*TM + i*TN + jj] *= exp(last_m[i] - m_i[i]);
}
}
}
}
// 3) renormalize the sum l_i
for (int i=0;i<TM;++i){
l_i[i] *= exp(last_m[i] - m_i[i]);
}
// 4) compute \exp(S_i - m^{j+1}) and accumulate exp-weighted V into O_i, streaming V in BK-wide tiles
for (int t = 0; t < num_tiles; t++){
// load V
__syncthreads();
for (int i=0; i<B_r; i+=strideK){
id = (innerRowK + j * B_c) * d + i * d + innerColK + t * B_c;
if (id < d*seq_len){
V_j[innerRowK+i][innerColK] = V[batch_offset + id];
} else {
V_j[innerRowK+i][innerColK] = 0.0;
}
}
__syncthreads();
for (int dd = 0; dd < B_c; dd++) {
for (int ii = 0; ii < TM; ii++){ // exponentiated scores for this thread's TM rows
regM[ii] = exp(S_i[threadRow*TM+ii][dd] - m_i[ii]);
if (t==0){ // the row sum does not depend on the V tile, so only accumulate it once
l_i[ii] += regM[ii];
}
}
for (int jj = 0; jj < TN; jj++){ // V values for this thread's TN output columns
regN[jj] = V_j[dd][threadCol * TN + jj];
}
for (int ii=0;ii<TM;ii++){
for (int jj=0;jj<TN;jj++){ // accumulate the TM x TN output tile
O_i[t*TN*TM + ii*TN + jj] += regM[ii] * regN[jj];
}
}
}
__syncthreads();
}
}
// normalize by the output sum and write to out matrix
for (int t = 0; t < num_tiles; t++){
for (int ii=0;ii<TM;ii++){
for (int jj=0;jj<TN;jj++){
if(blockIdx.y*B_r+threadRow*TM+ii < seq_len){
out[batch_offset + (blockIdx.y * B_r + threadRow*TM + ii) * d + t * B_c + threadCol*TN + jj] = O_i[t*TN*TM + ii*TN + jj] / l_i[ii];
}
}
}
}
}
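/*
flash_tiled_coarse_causal: identical to flash_tiled_coarse except for causal masking. Score
entries whose key index exceeds the query index are set to -INFINITY before the softmax, and
the chunk loop stops at j <= blockIdx.y, since later K/V chunks are fully masked for this
query block.
*/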
__global__
void flash_tiled_coarse_causal(float *out, float* out_l, float *K, float *Q, float* V, float scaling, int batch_stride, int T_r, int T_c, int seq_len)
{
int tid_x = threadIdx.x;
int tid_y = threadIdx.y;
int batch_offset = batch_stride * blockIdx.x;
/*
NOTE: K and V are streamed through shared memory (SMEM) in tiles of B_c x BK, and Q in
B_r x BK slices unless CACHE_Q is set; the multiplications iterate over these BK-wide tiles,
which allows a larger d while keeping occupancy high.
*/
// statically define in SMEM and still address it with indices
#if CACHE_Q
__shared__ float Q_i[B_r][d];    // keep the full B_r x d block of Q resident in SMEM
#else
__shared__ float Q_i[B_r][BK];   // keep only a BK-wide slice of Q resident; saves SMEM at the cost of extra loads
#endif
__shared__ float K_j[B_c][BK+1]; // +1 column to reduce SMEM bank conflicts, since K is read transposed
__shared__ float V_j[B_c][BK];
// attention scores
__shared__ float S_i[B_r][B_c+1]; // +1 column to reduce SMEM bank conflicts (in the naive softmax part)
const int num_tiles = d/BK; // number of BK-wide tiles the attention computation is split into along d
const uint totalResultsBlocktile = B_r * B_c; // number of results to calculate per block
const uint numThreadsBlocktile = totalResultsBlocktile / (TM * TN); // number of threads needed
const int threadId_flat = threadIdx.y * blockDim.x + threadIdx.x; // flattened thread id (used for coalesced loading of tiles)
// each thread processes one TM x TN tile at position:
const int threadCol = threadId_flat % (B_c / TN);
const int threadRow = threadId_flat / (B_c / TN);
float l_i[TM] = {0.0}; // running sum of exponentials per row
float m_i[TM]; // running max value per row
float last_m[TM]; // previous max value per row
float O_i[num_tiles * TN * TM] = {0.0}; // intermediate output accumulator (each thread stores a TM x TN chunk per tile)
// reset the running max to -inf
for (int ii = 0; ii < TM; ii++) {
m_i[ii] = -INFINITY;
}
// WARNING: due to coalescing we should probably add a second set of variables for using BK+1
const uint strideK = numThreadsBlocktile / BK; // rows loaded per pass (64 / 32 = 2 with the defaults)
const uint innerRowK = threadId_flat / BK; // row within the K/V tile being loaded (0..1 with the defaults)
const uint innerColK = threadId_flat % BK; // column within the K/V tile being loaded (0..31 with the defaults)
int id;
// preload Q_i over the full head dimension (only needed when Q is cached over all of d)
const uint innerRowQ = threadId_flat / d; // row of the Q block this thread starts loading
const uint innerColQ = threadId_flat % d; // column of the Q block this thread starts loading
const uint nr_loads = B_r * d / numThreadsBlocktile; // loads per thread to cover the B_r x d block
if (CACHE_Q) {
for (int t=0; t<nr_loads; t++){
// load the B_r x d block of Q with numThreadsBlocktile threads, numThreadsBlocktile consecutive
// elements per pass; the column index runs past d on purpose and relies on the row-major layout
// of Q_i; zero-pad rows past the end of the sequence
id = (blockIdx.y * B_r + innerRowQ) * d + innerColQ + t * numThreadsBlocktile;
if (id < d*seq_len){
Q_i[innerRowQ][innerColQ + t * numThreadsBlocktile] = Q[batch_offset + id];
} else {
Q_i[innerRowQ][innerColQ + t * numThreadsBlocktile] = 0.0;
}
}
}
__syncthreads();
// scratchpad register for register-tiling (coarsening of the matrix mults)
float regM[TM] = {0.0};
float regN[TN] = {0.0};
for (int j = 0; j < T_c && j <= blockIdx.y ; j++) { // iterate over the chunks of K and V; chunks past the current query block are fully masked and skipped
float threadResults[TM * TN] = {0.0}; // storing the intermediate outputs
for (int t=0; t<num_tiles; t++){
// load the current BK-wide tile of K_j (and, if Q is not cached over d, the matching slice of Q_i)
for (int i=0; i<B_r; i+=strideK){
if (not CACHE_Q){
// re-stream the BK-wide slice of Q for this tile (skipped when Q is cached over the whole d)
Q_i[innerRowK+i][innerColK] = Q[batch_offset + (innerRowK + blockIdx.y * B_r) * d + i * d + innerColK + t * B_c];
}
id = (innerRowK + j * B_c) * d + i * d + innerColK + t * B_c;
if (id < d*seq_len){
K_j[innerRowK+i][innerColK] = K[batch_offset + id];
//V_j[innerRowK+i][innerColK+t*B_c] = V[batch_offset + id];
} else {
K_j[innerRowK+i][innerColK] = 0.0;
//V_j[innerRowK+i][innerColK+t*B_c] = 0.0;
}
}
__syncthreads();
for (int dd=0; dd<BK; dd++){ // load elements of Q_i and K_j^T into registers
for (uint i = 0; i < TM; ++i) {
if (CACHE_Q){
regM[i] = Q_i[(threadRow * TM + i)][dd+t*BK]; // Q is cached over the full d: index into the t-th BK-wide slice
} else {
regM[i] = Q_i[(threadRow * TM + i)][dd];
}
}
for (uint i = 0; i < TN; ++i) {
regN[i] = K_j[threadCol * TN + i][dd];
}
for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) {
for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) {
threadResults[resIdxM * TN + resIdxN] += regM[resIdxM] * regN[resIdxN];
}
}
}
__syncthreads();
}
// store the results in S_i, account for causal masking
for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) {
for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) {
if (j*B_c + threadCol * TN + resIdxN <= blockIdx.y * B_r + threadRow * TM + resIdxM){
S_i[(threadRow * TM + resIdxM)][threadCol * TN + resIdxN] = threadResults[resIdxM * TN + resIdxN] *scaling;
} else {
S_i[(threadRow * TM + resIdxM)][threadCol * TN + resIdxN] = -INFINITY;
}
}
}
__syncthreads();
for (int i=0;i<TM;++i){
last_m[i] = m_i[i];
float m = m_i[i];
for (int jj = 0; jj < B_c; jj += 1) {
if (m < S_i[threadRow*TM+i][jj]) {
m = S_i[threadRow*TM+i][jj];
}
}
m_i[i] = m;
}
// 2) renormalize current O
if (j > 0) {
for (int t = 0; t < num_tiles; t++){
for (int i=0;i<TM;++i){
for (int jj=0;jj<TN;++jj){
O_i[t*TN*TM + i*TN + jj] *= exp(last_m[i] - m_i[i]);
}
}
}
}
// 3) renormalize the sum l_i
for (int i=0;i<TM;++i){
l_i[i] *= exp(last_m[i] - m_i[i]);
}
// 4) compute \exp(S_i - m^{j+1}) and accumulate exp-weighted V into O_i, streaming V in BK-wide tiles
for (int t = 0; t < num_tiles; t++){
// load V
__syncthreads();
for (int i=0; i<B_r; i+=strideK){
id = (innerRowK + j * B_c) * d + i * d + innerColK + t * B_c;
if (id < d*seq_len){
V_j[innerRowK+i][innerColK] = V[batch_offset + id];
} else {
V_j[innerRowK+i][innerColK] = 0.0;
}
}
__syncthreads();
for (int dd = 0; dd < B_c; dd++) {
for (int ii = 0; ii < TM; ii++){ // exponentiated scores for this thread's TM rows
regM[ii] = exp(S_i[threadRow*TM+ii][dd] - m_i[ii]);
if (t==0){ // the row sum does not depend on the V tile, so only accumulate it once
l_i[ii] += regM[ii];
}
}
for (int jj = 0; jj < TN; jj++){ // V values for this thread's TN output columns
regN[jj] = V_j[dd][threadCol * TN + jj];
}
for (int ii=0;ii<TM;ii++){
for (int jj=0;jj<TN;jj++){ // accumulate the TM x TN output tile
O_i[t*TN*TM + ii*TN + jj] += regM[ii] * regN[jj];
}
}
}
__syncthreads();
}
}
// normalize by the output sum and write to out matrix
for (int t = 0; t < num_tiles; t++){
for (int ii=0;ii<TM;ii++){
for (int jj=0;jj<TN;jj++){
if(blockIdx.y*B_r+threadRow*TM+ii < seq_len){
out[batch_offset + (blockIdx.y * B_r + threadRow*TM + ii) * d + t * B_c + threadCol*TN + jj] = O_i[t*TN*TM + ii*TN + jj] / l_i[ii];
}
}
}
}
}
void run_flash_tiled(float* O, float* K_d, float* Q_d, float* V_d, int batch_size, int seq_len) {
dim3 blockDim(B_r, B_c);
dim3 gridDim(batch_size, (seq_len+B_r-1)/B_r);
flash_tiled<<<gridDim, blockDim>>>(O, K_d, Q_d, V_d, (float) 1.0, seq_len * d, (int) seq_len/B_r, (int) seq_len/B_c);
cudaDeviceSynchronize();
}
void run_flash_tiled_coarse(float* O, float* K_d, float* Q_d, float* V_d, int batch_size, int seq_len) {
dim3 blockDim(B_r/TN, B_c/TM);
dim3 gridDim(batch_size, (seq_len+B_r-1)/B_r);
flash_tiled_coarse<<<gridDim, blockDim>>>(O, K_d, Q_d, V_d, (float) 1.0, seq_len * d, (int) (seq_len+B_r-1)/B_r, (int) (seq_len+B_c-1)/B_c, seq_len);
cudaDeviceSynchronize();
}
void run_flash_tiled_coarse_causal(float* O, float* K_d, float* Q_d, float* V_d, int batch_size, int seq_len) {
dim3 blockDim(B_r/TN, B_c/TM);
dim3 gridDim(batch_size, (seq_len+B_r-1)/B_r);
flash_tiled_coarse_causal<<<gridDim, blockDim>>>(O, K_d, Q_d, V_d, (float) 1.0, seq_len * d, (int) (seq_len+B_r-1)/B_r, (int) (seq_len+B_c-1)/B_c, seq_len);
cudaDeviceSynchronize();
}
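/*
Launch configuration of the coarsened kernels: blockDim = (B_r/TN, B_c/TM) = (8, 8), i.e. 64
threads per block, and gridDim = (batch_size, ceil(seq_len / B_r)). The wrappers above do not
check for allocation or launch errors; below is a minimal error-checking helper sketched as an
illustration (not part of the original benchmark), using only the standard CUDA runtime API.
*/
#define CUDA_CHECK(call)                                                      \
    do {                                                                      \
        cudaError_t err_ = (call);                                            \
        if (err_ != cudaSuccess) {                                            \
            fprintf(stderr, "CUDA error %s at %s:%d\n",                       \
                    cudaGetErrorString(err_), __FILE__, __LINE__);            \
            exit(EXIT_FAILURE);                                               \
        }                                                                     \
    } while (0)
// Hypothetical usage: CUDA_CHECK(cudaMalloc((void**)&K_d, bytes));
//                     CUDA_CHECK(cudaGetLastError()); // after a kernel launch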
int main() {
int seq_len = 8192;
int batch_size = 8;
float *K_d, *Q_d, *V_d, *O;
cudaMalloc((void**)&O, batch_size * seq_len * d * sizeof(float)); // output spans the full batch, like K, Q and V
cudaMalloc((void**)&K_d, batch_size * seq_len * d * sizeof(float));
cudaMalloc((void**)&Q_d, batch_size * seq_len * d * sizeof(float));
cudaMalloc((void**)&V_d, batch_size * seq_len * d * sizeof(float));
// initialize K to 0, 1, 2, ...
float *K_h = (float*) malloc(batch_size * seq_len * d * sizeof(float));
for (int i = 0; i < batch_size * seq_len * d; i++) {
K_h[i] = i;
}
cudaMemcpy(K_d, K_h, batch_size * seq_len * d * sizeof(float), cudaMemcpyHostToDevice);
// initialize Q to 0, 1, 2, ...
float *Q_h = (float*) malloc(batch_size * seq_len * d * sizeof(float));
for (int i = 0; i < batch_size * seq_len * d; i++) {
Q_h[i] = i;
}
cudaMemcpy(Q_d, Q_h, batch_size * seq_len * d * sizeof(float), cudaMemcpyHostToDevice);
// set V to all ones
float *V_h = (float*) malloc(batch_size * seq_len * d * sizeof(float));
for (int i = 0; i < batch_size * seq_len * d; i++) {
V_h[i] = 1.0;
}
cudaMemcpy(V_d, V_h, batch_size * seq_len * d * sizeof(float), cudaMemcpyHostToDevice);
float *O_h = (float*) malloc(batch_size * seq_len * d * sizeof(float));
double start, end;
start = getTimeStamp();
cudaDeviceSynchronize();
run_flash_tiled_coarse_causal(O, K_d, Q_d, V_d, batch_size, seq_len);
cudaDeviceSynchronize();
end = getTimeStamp();
printf("Time: %f\n", end - start);
cudaMemcpy(O_h, O, batch_size * seq_len * d * sizeof(float), cudaMemcpyDeviceToHost);
cudaDeviceSynchronize();
// release device and host buffers
cudaFree(O); cudaFree(K_d); cudaFree(Q_d); cudaFree(V_d);
free(K_h); free(Q_h); free(V_h); free(O_h);
return 0;
}