-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathfp6.cu
602 lines (476 loc) · 23.2 KB
/
fp6.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>
#include <torch/library.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <stdint.h>
#include <cmath>
#include <cstring>
#include <stdexcept>
// Raised by the host-side FP6 encoders when an input is NaN or +/-inf,
// neither of which has an FP6 encoding. Derives from std::invalid_argument.
struct fp6_nan_inf : std::invalid_argument {
  fp6_nan_inf() : std::invalid_argument("Encounter +/-inf or NaN, which is not representable in FP6.") {}
};
// Raised by the host-side FP6 encoders when |x| >= 30, which would round past
// the largest finite FP6 value (28); FP6 has no infinity to saturate to.
struct fp6_overflow : std::invalid_argument {
  fp6_overflow() : std::invalid_argument("FP6 overflow. FP6 cannot represent +/-inf. Make sure input < 30.0") {}
};
// A false value that depends on a template parameter. static_assert(always_false<T>, ...)
// therefore fires only when the enclosing template branch is actually instantiated,
// unlike static_assert(false), which fires as soon as the template is parsed.
template <typename U> constexpr std::false_type always_false = std::false_type{};
// Converts one float / __half / __nv_bfloat16 / c10::Half / c10::BFloat16 value to FP6
// (1 sign, 3 exponent, 2 mantissa bits, exponent bias 3) with round-to-nearest-even.
// Returns the code in the low bits of a uint8_t: bit 5 = sign, bits 4-2 = exponent,
// bits 1-0 = mantissa.
//
// Host builds throw fp6_nan_inf for NaN/inf and fp6_overflow for |x| >= 30. Device code
// cannot throw, so on CUDA such inputs are not validated and the returned code is not meaningful.
//
// This implementation uses less bit manipulation than to_fp6_bits(), so it's less error-prone.
// On CPU, for FP32->FP6, bit manipulation (to_fp6_bits()) is 20% faster than this.
// On CUDA, dtype conversion kernels are memory-bound. Thus, using to_fp6_value() or
// to_fp6_bits() does not matter much. However, to_fp6_bits() has a lot of branching
// based on input value, thus it will cause warp divergence.
template <typename T>
__device__ __host__ static uint8_t to_fp6_value(T a) {
  float fp32_value;
  // need to use if constexpr so that the branches are pruned at compile-time.
  // without it, expression in each branch must be valid regardless of template type T.
  if constexpr (std::is_same_v<T, float>)
    fp32_value = a;
  else if constexpr (std::is_same_v<T, __half>)
    fp32_value = __half2float(a);
  else if constexpr (std::is_same_v<T, __nv_bfloat16>)
    fp32_value = __bfloat162float(a);
  else if constexpr (std::is_same_v<T, c10::Half> || std::is_same_v<T, c10::BFloat16>)
    fp32_value = static_cast<float>(a);
  else
    static_assert(always_false<T>, "Only float, __half, __nv_bfloat16, c10::Half, and c10::BFloat16 are supported");

#ifndef __CUDA_ARCH__
  if (std::isnan(fp32_value) || std::isinf(fp32_value)) throw fp6_nan_inf();
  if (std::abs(fp32_value) >= 30.0f) throw fp6_overflow();
#endif

  uint32_t in_bits;
  std::memcpy(&in_bits, &fp32_value, sizeof(fp32_value));
  const uint8_t sign = in_bits >> 31u << 5u;

  // |x| obtained by clearing the sign bit (no reliance on host/device math overloads).
  const uint32_t abs_bits = in_bits & 0x7FFF'FFFFu;
  float abs_value;
  std::memcpy(&abs_value, &abs_bits, sizeof(abs_value));

  // FP6 subnormal range: |x| < 0.25, the smallest FP6 normal (E=001 M=00).
  // FP6 subnormals are multiples of 2^-4, so the code is round_half_even(|x| * 16). The
  // multiplication by 16 is exact, so rintf() (round-to-nearest-even in the default rounding
  // mode) rounds exactly once. Scaling into the FP32 subnormal range (as the normal path does)
  // would round FP32 inputs twice and could turn a value just below a tie into a tie.
  // The result k lies in [0, 4]. k == 4 is exactly the code of 0.25 (E=001 M=00), which is
  // the correct result when rounding up into the normal range.
  // Out-of-range inputs (including NaN) are replaced by 0 so the float->uint8_t cast is always defined.
  const float subnormal_input = (abs_value < 0.25f) ? abs_value : 0.0f;
  const uint8_t subnormal_result = sign | static_cast<uint8_t>(rintf(subnormal_input * 16.0f));

  // FP6 normal range: |x| >= 0.25.
  // Multiplying by 2^(3-127) re-biases the FP32 exponent to FP6's bias. FP32 bits 25-23 then
  // hold the FP6 exponent and bits 22-21 the FP6 mantissa. For |x| >= 0.25 the product is
  // still an FP32 normal (>= 2^-126), so the scaling is exact.
  const float scaled = fp32_value * 0x1p-124f;  // 2^(3-127), float literal: no double math on device
  uint32_t bits;
  std::memcpy(&bits, &scaled, sizeof(scaled));
  uint8_t normal_result = sign | ((bits >> 21u) & 0x1Fu);
  // round to nearest even, using the 21 truncated bits left-aligned in a 32-bit word
  const uint32_t remainder = bits << 11u;
  if ((remainder > 0x8000'0000u) || ((remainder == 0x8000'0000u) && (normal_result & 1u))) {
    normal_result += 1;
  }

  // Select rather than branch, so that lanes of a warp do not diverge on the input's magnitude.
  return (abs_value < 0.25f) ? subnormal_result : normal_result;
}
// Packs a floating-point format description into a single integer, because C++17 does not
// allow a struct as a template non-type parameter.
// Layout: upper 16 bits = number of exponent bits, lower 16 bits = number of stored mantissa bits.
static constexpr uint32_t encode_fp_spec(uint32_t n_exp, uint32_t n_man) {
  const uint32_t exp_field = n_exp << 16u;
  return exp_field | n_man;
}
static constexpr uint32_t FP32_SPEC = encode_fp_spec(8u, 23u);  // IEEE binary32
static constexpr uint32_t FP16_SPEC = encode_fp_spec(5u, 10u);  // IEEE binary16
static constexpr uint32_t BF16_SPEC = encode_fp_spec(8u, 7u);   // bfloat16
// Returns a mask with the lowest `len` bits set, e.g. ones_mask(3) == 0b111.
// NOTE: only works for len < 32 (shifting a 32-bit value by >= 32 is undefined).
__device__ __host__ static constexpr uint32_t ones_mask(uint32_t len) { return ~(~0u << len); }
// inspired by __internal_float2half() and float2half() from "cuda_fp16.hpp"
//
// Converts the raw bit pattern `bits` of one floating-point value, whose format is described
// by FP_SPEC (see encode_fp_spec), into a FP6 code (E3M2, exponent bias 3; bit 5 = sign).
// It uses integer arithmetic only and rounds to nearest even.
// NOTE(review): T is assumed to be exactly 1 + N_EXP + N_MAN bits wide (the callers here pair
// uint32_t with FP32 and uint16_t with FP16/BF16). The `remainder` computation relies on left
// shifts being truncated to T's width, and HALF_REMAINDER is T's most significant bit.
// Host-only validation: throws fp6_nan_inf if all exponent bits are 1 (inf/NaN), and
// fp6_overflow if |x| >= 30. Device builds skip both checks.
template <typename T, uint32_t FP_SPEC>
__device__ __host__ static uint8_t to_fp6_bits(T bits) {
constexpr uint32_t N_EXP = FP_SPEC >> 16u;
constexpr uint32_t N_MAN = FP_SPEC & ones_mask(16u);
constexpr uint32_t N_EXP_MAN = N_EXP + N_MAN;
// sanity checks. will be removed in template instantiation.
// minimum 1 bit above FP6 (3 exponent bits and 2 mantissa bits) to avoid edge cases.
static_assert(N_EXP >= 4, "Number of exponent bits must be >= 4.");
static_assert(N_MAN >= 3, "Number of mantissa bits must be >= 3.");
// bits discarded by the conversion, left-aligned in T (used for rounding at the end)
T remainder = 0u;
// input sign bit moved to FP6's sign position (bit 5)
T sign = bits >> N_EXP_MAN << 5u;
bits &= ones_mask(N_EXP_MAN); // clear sign bit
T result;
// input exponent bias (2^(N_EXP-1) - 1) minus FP6 exponent bias (3)
constexpr uint32_t EXP_BIAS_DIFF = ones_mask(N_EXP - 1u) - 3u;
// only checks for invalid values on CPU, since we can't throw exception in CUDA
#ifndef __CUDA_ARCH__
// all exponent bits are 1s
if (bits >= (ones_mask(N_EXP) << N_MAN)) throw fp6_nan_inf();
// max FP6 (28) + half of least significand (2) = 30 (assume N_MAN >= 3)
if (bits >= (((EXP_BIAS_DIFF + 7u) << N_MAN) | (0x7u << (N_MAN - 3u)))) throw fp6_overflow();
#endif
// FP6 normal number (E>=001)
if (bits >= ((EXP_BIAS_DIFF + 1u) << N_MAN)) {
// keep the low (N_MAN - 2) bits that are about to be dropped, left-aligned in T
remainder = bits << (1u + N_EXP + 2u);
bits -= (EXP_BIAS_DIFF << N_MAN); // update exponent
result = sign | (bits >> (N_MAN - 2u));
}
// FP6 subnormal number (more than half of min FP6 subnormal = 0.0625 * 0.5)
else if (bits > ((EXP_BIAS_DIFF - 2u) << N_MAN)) {
T exp = bits >> N_MAN;
T man = bits & ones_mask(N_MAN);
// to make subnormal FP6 from a normal input value (FP32/FP16/BF16)
// step 1: add implicit 1 to mantissa
man |= (1u << N_MAN);
// step 2: shift mantissa right so that exponent value is equal to
// exponent value of FP6 subnormal, which is -2 (equivalent to E=001)
// (shift is 1, 2 or 3 here)
T shift = EXP_BIAS_DIFF + 1u - exp;
// the dropped (shift + N_MAN - 2) low bits, left-aligned in T
remainder = man << (1u + N_EXP + 2u - shift);
result = sign | (man >> (shift + (N_MAN - 2u))); // implicit E=000
}
// FP6 underflow. E=000, M=00
else {
result = sign;
}
// round to nearest even
// a remainder equal to HALF_REMAINDER means the dropped bits are exactly half an FP6 ulp;
// a carry out of the mantissa correctly increments the exponent
constexpr T HALF_REMAINDER = 1u << N_EXP_MAN;
if ((remainder > HALF_REMAINDER) || ((remainder == HALF_REMAINDER) && (result & 0x1u))) {
result += 1;
}
return result;
}
// Decodes one FP6 code (E3M2, exponent bias 3) held in the lower 6 bits of `a` into T, going
// through FP32. Bits 7-6 of `a` do not affect the result.
// NOTE: probably not efficient for FP6->FP16 and FP6->BF16 on CPU since FP32->FP16/BF16 is slow.
//
// The sign and the 5 exponent+mantissa bits are copied straight into their FP32 positions:
//   FP6:  SE EEMM
//   FP32: S000 00EE EMM0 0000 0000 0000 0000 0000
// This gives the right value divided by 2^(127-3), the exponent-bias difference, and the
// multiplication below undoes that. FP6 subnormals land on FP32 subnormals, so they decode correctly too.
// NOTE(review): 0x1p124 has no `f` suffix, so the multiply is done in double precision.
// A float literal would avoid double math on GPUs, but float subnormal inputs may then be
// flushed to zero under -ftz=true / --use_fast_math. Confirm before changing.
template <typename T>
__device__ __host__ static T from_fp6(uint8_t a) {
  const uint32_t code = a;  // zero-extend
  const uint32_t fp32_bits = ((code >> 5u) << 31u) | ((code & 0x1Fu) << 21u);
  float value;
  std::memcpy(&value, &fp32_bits, sizeof(value));
  value *= 0x1p124;  // 2^(127-3)
  return static_cast<T>(value);
}
namespace torchao {
template <typename T, uint32_t FP_SPEC> void to_fp6_unpacked_cpu_impl(const T *bits_ptr, uint8_t *fp6_ptr, int n) {
// exception within OpenMP parallel region must be caught.
// set a flag when exception occurs, then re-raise it.
bool found_nan_inf = false;
bool found_overflow = false;
#pragma omp parallel for
for (int i = 0; i < n; i++) {
try { fp6_ptr[i] = to_fp6_bits<T, FP_SPEC>(bits_ptr[i]); }
catch (fp6_nan_inf) { found_nan_inf = true; }
catch (fp6_overflow) { found_overflow = true; }
}
if (found_nan_inf) throw fp6_nan_inf();
if (found_overflow) throw fp6_overflow();
}
// this is useful for debugging
// Converts a contiguous CPU FP32/FP16/BF16 tensor into a uint8 tensor of the same shape that holds
// one unpacked FP6 code per element. Throws fp6_nan_inf / fp6_overflow for values FP6 cannot
// represent, and std::invalid_argument for other dtypes.
at::Tensor to_fp6_unpacked_cpu(at::Tensor fp_tensor) {
  TORCH_CHECK(fp_tensor.is_contiguous());
  TORCH_CHECK(fp_tensor.is_cpu());
  // the conversion loop uses 32-bit int indices; a larger tensor would be silently truncated
  TORCH_CHECK(fp_tensor.numel() <= INT32_MAX, "Input has too many elements (", fp_tensor.numel(),
              "); at most ", INT32_MAX, " are supported");
  at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp_tensor.device());
  at::Tensor fp6_tensor = at::empty(fp_tensor.sizes(), options);
  uint8_t *fp6_ptr = fp6_tensor.data_ptr<uint8_t>();
  const int n = static_cast<int>(fp_tensor.numel());
  const auto dtype = fp_tensor.dtype();
  // the conversion works on raw bit patterns, so reinterpret the data as unsigned integers
  if (dtype == torch::kFloat32) {
    const uint32_t *fp32_ptr = reinterpret_cast<uint32_t *>(fp_tensor.data_ptr<float>());
    to_fp6_unpacked_cpu_impl<uint32_t, FP32_SPEC>(fp32_ptr, fp6_ptr, n);
  } else if (dtype == torch::kFloat16) {
    const uint16_t *fp16_ptr = reinterpret_cast<uint16_t *>(fp_tensor.data_ptr<at::Half>());
    to_fp6_unpacked_cpu_impl<uint16_t, FP16_SPEC>(fp16_ptr, fp6_ptr, n);
  } else if (dtype == torch::kBFloat16) {
    const uint16_t *bf16_ptr = reinterpret_cast<uint16_t *>(fp_tensor.data_ptr<at::BFloat16>());
    to_fp6_unpacked_cpu_impl<uint16_t, BF16_SPEC>(bf16_ptr, fp6_ptr, n);
  } else {
    throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted.");
  }
  return fp6_tensor;
}
// Elementwise FP32/FP16/BF16 -> unpacked FP6 kernel. Each thread of a 1-D grid handles one
// element; threads whose index is >= n do nothing. Values are converted with to_fp6_value(),
// which does not validate inputs on the device.
template <typename T>
__global__ void to_fp6_unpacked_kernel(const T *fp_ptr, uint8_t *fp6_ptr, int n) {
  const int i = threadIdx.x + blockIdx.x * blockDim.x;
  if (i >= n) return;
  // NOTE: each warp writes 32 uint8 (32 bytes) to global memory. Vector stores (e.g. uchar4)
  // could be used to issue wider writes and improve memory throughput.
  fp6_ptr[i] = to_fp6_value(fp_ptr[i]);
}
// this is useful for debugging
// Converts a contiguous CUDA FP32/FP16/BF16 tensor into a uint8 tensor of the same shape that holds
// one unpacked FP6 code per element. Unlike the CPU version, NaN/inf and |x| >= 30 are NOT
// detected, because device code cannot throw. The kernel is enqueued asynchronously on the
// current CUDA stream of the input's device. Throws std::invalid_argument for other dtypes.
at::Tensor to_fp6_unpacked_cuda(at::Tensor fp_tensor) {
  TORCH_CHECK(fp_tensor.is_contiguous());
  TORCH_CHECK(fp_tensor.is_cuda());
  // the kernel uses 32-bit int element indices
  TORCH_CHECK(fp_tensor.numel() <= INT32_MAX, "Input has too many elements (", fp_tensor.numel(),
              "); at most ", INT32_MAX, " are supported");
  const at::ScalarType dtype = fp_tensor.scalar_type();
  if (dtype != torch::kFloat32 && dtype != torch::kFloat16 && dtype != torch::kBFloat16) {
    throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted.");
  }

  // make the input's device current, so the stream lookup and the launch target that device
  const c10::cuda::CUDAGuard device_guard(fp_tensor.device());
  at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp_tensor.device());
  at::Tensor fp6_tensor = at::empty(fp_tensor.sizes(), options);
  const int n = static_cast<int>(fp_tensor.numel());
  // launching a grid with zero blocks is an invalid configuration
  if (n == 0) return fp6_tensor;

  uint8_t *fp6_ptr = fp6_tensor.data_ptr<uint8_t>();
  constexpr int block_size = 256;
  // 64-bit arithmetic: n + block_size - 1 would overflow int when n is close to INT32_MAX
  const int grid_size = static_cast<int>((static_cast<int64_t>(n) + block_size - 1) / block_size);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (dtype == torch::kFloat32) {
    to_fp6_unpacked_kernel<<<grid_size, block_size, 0, stream>>>(fp_tensor.data_ptr<float>(), fp6_ptr, n);
  } else if (dtype == torch::kFloat16) {
    to_fp6_unpacked_kernel<<<grid_size, block_size, 0, stream>>>(fp_tensor.data_ptr<at::Half>(), fp6_ptr, n);
  } else {
    to_fp6_unpacked_kernel<<<grid_size, block_size, 0, stream>>>(fp_tensor.data_ptr<at::BFloat16>(), fp6_ptr, n);
  }
  // report launch errors here rather than at some later, unrelated CUDA call
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return fp6_tensor;
}
template <typename T, uint32_t FP_SPEC> void to_fp6_packed_cpu_impl(const T *bits_ptr, uint8_t *fp6_ptr, int n) {
// exception within OpenMP parallel region must be caught.
// set a flag when exception occurs, then re-raise it.
bool found_nan_inf = false;
bool found_overflow = false;
#pragma omp parallel for
for (int i = 0; i < n / 4; i++) {
try {
uint8_t val0 = to_fp6_bits<T, FP_SPEC>(bits_ptr[i * 4]);
uint8_t val1 = to_fp6_bits<T, FP_SPEC>(bits_ptr[i * 4 + 1]);
uint8_t val2 = to_fp6_bits<T, FP_SPEC>(bits_ptr[i * 4 + 2]);
uint8_t val3 = to_fp6_bits<T, FP_SPEC>(bits_ptr[i * 4 + 3]);
fp6_ptr[i * 3] = (val0 << 2) | (val1 >> 4); // 0000 0011
fp6_ptr[i * 3 + 1] = (val1 << 4) | (val2 >> 2); // 1111 2222
fp6_ptr[i * 3 + 2] = (val2 << 6) | (val3); // 2233 3333
}
catch (fp6_nan_inf) { found_nan_inf = true; }
catch (fp6_overflow) { found_overflow = true; }
}
if (found_nan_inf) throw fp6_nan_inf();
if (found_overflow) throw fp6_overflow();
}
// Converts a contiguous 2-D CPU FP32/FP16/BF16 tensor of shape (M, N), with N a multiple of 4,
// into a uint8 tensor of shape (M, N / 4 * 3) that holds packed FP6 codes (4 values per 3 bytes).
// Throws fp6_nan_inf / fp6_overflow for values FP6 cannot represent, and std::invalid_argument
// for other dtypes.
at::Tensor to_fp6_packed_cpu(at::Tensor fp_tensor) {
  TORCH_CHECK(fp_tensor.is_contiguous());
  TORCH_CHECK(fp_tensor.is_cpu());
  TORCH_CHECK(fp_tensor.ndimension() == 2);
  // keep sizes in 64 bits; narrowing them to int could silently truncate
  const int64_t M = fp_tensor.size(0);
  const int64_t N = fp_tensor.size(1);
  TORCH_CHECK(N % 4 == 0, "Last dimension must be a multiple of 4, receives ", N);
  // the conversion loop uses 32-bit int indices
  TORCH_CHECK(fp_tensor.numel() <= INT32_MAX, "Input has too many elements (", fp_tensor.numel(),
              "); at most ", INT32_MAX, " are supported");
  at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp_tensor.device());
  // N / 4 * 3 is exact because N % 4 == 0, and unlike N * 3 / 4 it cannot overflow
  at::Tensor fp6_tensor = at::empty({M, N / 4 * 3}, options);
  uint8_t *fp6_ptr = fp6_tensor.data_ptr<uint8_t>();
  const int n = static_cast<int>(fp_tensor.numel());
  const auto dtype = fp_tensor.dtype();
  // the conversion works on raw bit patterns, so reinterpret the data as unsigned integers
  if (dtype == torch::kFloat32) {
    const uint32_t *fp32_ptr = reinterpret_cast<uint32_t *>(fp_tensor.data_ptr<float>());
    to_fp6_packed_cpu_impl<uint32_t, FP32_SPEC>(fp32_ptr, fp6_ptr, n);
  } else if (dtype == torch::kFloat16) {
    const uint16_t *fp16_ptr = reinterpret_cast<uint16_t *>(fp_tensor.data_ptr<at::Half>());
    to_fp6_packed_cpu_impl<uint16_t, FP16_SPEC>(fp16_ptr, fp6_ptr, n);
  } else if (dtype == torch::kBFloat16) {
    const uint16_t *bf16_ptr = reinterpret_cast<uint16_t *>(fp_tensor.data_ptr<at::BFloat16>());
    to_fp6_packed_cpu_impl<uint16_t, BF16_SPEC>(bf16_ptr, fp6_ptr, n);
  } else {
    throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted.");
  }
  return fp6_tensor;
}
// define our own vector types since NVIDIA doesn't provide them.
// Each holds four 16-bit values with 8-byte alignment, so one struct can be read with a
// single 8-byte vector load.
struct __align__(8) fp16_vec4 { __half x, y, z, w; };
struct __align__(8) bf16_vec4 { __nv_bfloat16 x, y, z, w; };
// FP32/FP16/BF16 -> packed FP6 kernel. Each thread converts 4 consecutive inputs into 3 output
// bytes, so each block covers BLOCK_SIZE * 4 inputs and BLOCK_SIZE * 3 output bytes.
// Launch layout: 1-D grid, and blockDim.x must equal BLOCK_SIZE. The offsets below use
// blockDim.x, while the shared buffer and the write-back loop use BLOCK_SIZE.
// Shared memory: BLOCK_SIZE * 3 bytes (static).
// Preconditions:
//  - n is a multiple of 4, because the guard only checks the first of a thread's 4 elements.
//  - fp_ptr is aligned for the vector loads: 16 bytes for float4, 8 bytes for fp16_vec4/bf16_vec4.
// NaN/inf/overflow are not detected on the device (see to_fp6_value()).
template <typename T, int BLOCK_SIZE>
__global__ void to_fp6_packed_kernel(const T *fp_ptr, uint8_t *fp6_ptr, int n) {
const int tid = threadIdx.x;
const int input_offset = (blockIdx.x * blockDim.x) * 4;
const int output_offset = (blockIdx.x * blockDim.x) * 3;
fp_ptr += input_offset;
fp6_ptr += output_offset;
// staging buffer for this block's packed output, so the global write below is coalesced
__shared__ uint8_t shmem[BLOCK_SIZE * 3];
if (input_offset + tid * 4 < n) {
uint8_t val0, val1, val2, val3;
// vector load for coalesced memory read
if constexpr (std::is_same_v<T, float>) {
float4 values = reinterpret_cast<const float4 *>(fp_ptr)[tid];
val0 = to_fp6_value(values.x);
val1 = to_fp6_value(values.y);
val2 = to_fp6_value(values.z);
val3 = to_fp6_value(values.w);
} else if constexpr (std::is_same_v<T, at::Half> || std::is_same_v<T, __half>) {
fp16_vec4 values = reinterpret_cast<const fp16_vec4 *>(fp_ptr)[tid];
val0 = to_fp6_value(values.x);
val1 = to_fp6_value(values.y);
val2 = to_fp6_value(values.z);
val3 = to_fp6_value(values.w);
} else if constexpr (std::is_same_v<T, at::BFloat16> || std::is_same_v<T, __nv_bfloat16>) {
bf16_vec4 values = reinterpret_cast<const bf16_vec4 *>(fp_ptr)[tid];
val0 = to_fp6_value(values.x);
val1 = to_fp6_value(values.y);
val2 = to_fp6_value(values.z);
val3 = to_fp6_value(values.w);
} else {
// fallback. no coalesced memory access. (assert false instead?)
val0 = to_fp6_value(fp_ptr[tid * 4]);
val1 = to_fp6_value(fp_ptr[tid * 4 + 1]);
val2 = to_fp6_value(fp_ptr[tid * 4 + 2]);
val3 = to_fp6_value(fp_ptr[tid * 4 + 3]);
}
// pack 4 x 6-bit codes into 3 bytes, most significant bits first
shmem[tid * 3] = (val0 << 2) | (val1 >> 4); // 0000 0011
shmem[tid * 3 + 1] = (val1 << 4) | (val2 >> 2); // 1111 2222
shmem[tid * 3 + 2] = (val2 << 6) | (val3); // 2233 3333
}
// outside the branch above, so every thread of the block reaches the barrier
__syncthreads();
// coalesced memory write
// TODO: write in larger word size
// consecutive threads copy consecutive bytes; the bound n / 4 * 3 is the total output size,
// which matches exactly the shmem bytes written above
for (int i = 0; i < 3; i++) {
if (output_offset + BLOCK_SIZE * i + tid < n / 4 * 3) {
fp6_ptr[BLOCK_SIZE * i + tid] = shmem[BLOCK_SIZE * i + tid];
}
}
}
// Converts a contiguous 2-D CUDA FP32/FP16/BF16 tensor of shape (M, N), with N a multiple of 4,
// into a uint8 tensor of shape (M, N / 4 * 3) that holds packed FP6 codes (4 values per 3 bytes).
// NaN/inf and |x| >= 30 are NOT detected on the device. The kernel is enqueued asynchronously
// on the current CUDA stream of the input's device. Throws std::invalid_argument for other dtypes.
at::Tensor to_fp6_packed_cuda(at::Tensor fp_tensor) {
  TORCH_CHECK(fp_tensor.is_contiguous());
  TORCH_CHECK(fp_tensor.is_cuda());
  TORCH_CHECK(fp_tensor.ndimension() == 2);
  // keep sizes in 64 bits; narrowing them to int could silently truncate
  const int64_t M = fp_tensor.size(0);
  const int64_t N = fp_tensor.size(1);
  TORCH_CHECK(N % 4 == 0, "Last dimension must be a multiple of 4, receives ", N);
  // the kernel uses 32-bit int element indices
  TORCH_CHECK(fp_tensor.numel() <= INT32_MAX, "Input has too many elements (", fp_tensor.numel(),
              "); at most ", INT32_MAX, " are supported");
  const at::ScalarType dtype = fp_tensor.scalar_type();
  if (dtype != torch::kFloat32 && dtype != torch::kFloat16 && dtype != torch::kBFloat16) {
    throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted.");
  }
  // each thread reads its 4 inputs with one vector load (float4, or an 8-byte struct for 16-bit
  // types). A misaligned base pointer (e.g. from a view with an odd storage offset) would fault.
  const uintptr_t vec_bytes = static_cast<uintptr_t>(4 * fp_tensor.element_size());
  TORCH_CHECK(reinterpret_cast<uintptr_t>(fp_tensor.data_ptr()) % vec_bytes == 0,
              "Input data pointer must be ", vec_bytes, "-byte aligned");

  // make the input's device current, so the stream lookup and the launch target that device
  const c10::cuda::CUDAGuard device_guard(fp_tensor.device());
  at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp_tensor.device());
  // N / 4 * 3 is exact because N % 4 == 0, and unlike N * 3 / 4 it cannot overflow
  at::Tensor fp6_tensor = at::empty({M, N / 4 * 3}, options);
  const int n = static_cast<int>(fp_tensor.numel());
  // launching a grid with zero blocks is an invalid configuration
  if (n == 0) return fp6_tensor;

  uint8_t *fp6_ptr = fp6_tensor.data_ptr<uint8_t>();
  // times 4 since each thread will handle 4 values
  constexpr int block_size = 256;
  // 64-bit arithmetic avoids int overflow of the rounded-up numerator for large n
  const int grid_size = static_cast<int>((static_cast<int64_t>(n) + block_size * 4 - 1) / (block_size * 4));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (dtype == torch::kFloat32) {
    to_fp6_packed_kernel<float, block_size><<<grid_size, block_size, 0, stream>>>(fp_tensor.data_ptr<float>(), fp6_ptr, n);
  } else if (dtype == torch::kFloat16) {
    to_fp6_packed_kernel<at::Half, block_size><<<grid_size, block_size, 0, stream>>>(fp_tensor.data_ptr<at::Half>(), fp6_ptr, n);
  } else {
    to_fp6_packed_kernel<at::BFloat16, block_size><<<grid_size, block_size, 0, stream>>>(fp_tensor.data_ptr<at::BFloat16>(), fp6_ptr, n);
  }
  // report launch errors here rather than at some later, unrelated CUDA call
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return fp6_tensor;
}
// Decodes n unpacked FP6 codes (one per byte) into values of type T with from_fp6().
// The loop is an OpenMP parallel for.
template <typename T>
void from_fp6_unpacked_cpu_impl(const uint8_t *fp6_ptr, T *fp_ptr, int n) {
#pragma omp parallel for
  for (int idx = 0; idx < n; ++idx) {
    const uint8_t code = fp6_ptr[idx];
    fp_ptr[idx] = from_fp6<T>(code);
  }
}
// Decodes a contiguous CPU uint8 tensor of unpacked FP6 codes (one per byte) into a tensor of the
// same shape with dtype `dtype`. Throws std::invalid_argument unless `dtype` is FP32, FP16 or BF16.
at::Tensor from_fp6_unpacked_cpu(at::Tensor fp6_tensor, c10::ScalarType dtype) {
  TORCH_CHECK(fp6_tensor.dtype() == torch::kUInt8);
  TORCH_CHECK(fp6_tensor.is_contiguous());
  TORCH_CHECK(fp6_tensor.is_cpu());
  // the decode loop uses 32-bit int indices; a larger tensor would be silently truncated
  TORCH_CHECK(fp6_tensor.numel() <= INT32_MAX, "Input has too many elements (", fp6_tensor.numel(),
              "); at most ", INT32_MAX, " are supported");
  at::TensorOptions options = at::TensorOptions().dtype(dtype).device(fp6_tensor.device());
  at::Tensor fp_tensor = at::empty(fp6_tensor.sizes(), options);
  const uint8_t *fp6_ptr = fp6_tensor.data_ptr<uint8_t>();
  const int n = static_cast<int>(fp6_tensor.numel());
  if (dtype == torch::kFloat32) {
    from_fp6_unpacked_cpu_impl(fp6_ptr, fp_tensor.data_ptr<float>(), n);
  } else if (dtype == torch::kFloat16) {
    from_fp6_unpacked_cpu_impl(fp6_ptr, fp_tensor.data_ptr<at::Half>(), n);
  } else if (dtype == torch::kBFloat16) {
    from_fp6_unpacked_cpu_impl(fp6_ptr, fp_tensor.data_ptr<at::BFloat16>(), n);
  } else {
    throw std::invalid_argument("Only FP32, FP16, and BF16 outputs are accepted.");
  }
  return fp_tensor;
}
// Elementwise unpacked FP6 -> T kernel. Each thread of a 1-D grid decodes one byte; threads
// whose index is >= n do nothing.
template <typename T>
__global__ void from_fp6_unpacked_kernel(const uint8_t *fp6_ptr, T *fp_ptr, int n) {
  // TODO: use vector load for reading from global memory
  const int i = threadIdx.x + blockIdx.x * blockDim.x;
  if (i >= n) return;
  fp_ptr[i] = from_fp6<T>(fp6_ptr[i]);
}
// Decodes a contiguous CUDA uint8 tensor of unpacked FP6 codes (one per byte) into a tensor of the
// same shape with dtype `dtype`. The kernel is enqueued asynchronously on the current CUDA stream
// of the input's device. Throws std::invalid_argument unless `dtype` is FP32, FP16 or BF16.
at::Tensor from_fp6_unpacked_cuda(at::Tensor fp6_tensor, c10::ScalarType dtype) {
  TORCH_CHECK(fp6_tensor.dtype() == torch::kUInt8);
  TORCH_CHECK(fp6_tensor.is_contiguous());
  TORCH_CHECK(fp6_tensor.is_cuda());
  // the kernel uses 32-bit int element indices
  TORCH_CHECK(fp6_tensor.numel() <= INT32_MAX, "Input has too many elements (", fp6_tensor.numel(),
              "); at most ", INT32_MAX, " are supported");
  if (dtype != torch::kFloat32 && dtype != torch::kFloat16 && dtype != torch::kBFloat16) {
    throw std::invalid_argument("Only FP32, FP16, and BF16 outputs are accepted.");
  }

  // make the input's device current, so the stream lookup and the launch target that device
  const c10::cuda::CUDAGuard device_guard(fp6_tensor.device());
  at::TensorOptions options = at::TensorOptions().dtype(dtype).device(fp6_tensor.device());
  at::Tensor fp_tensor = at::empty(fp6_tensor.sizes(), options);
  const int n = static_cast<int>(fp6_tensor.numel());
  // launching a grid with zero blocks is an invalid configuration
  if (n == 0) return fp_tensor;

  const uint8_t *fp6_ptr = fp6_tensor.data_ptr<uint8_t>();
  constexpr int block_size = 256;
  // 64-bit arithmetic: n + block_size - 1 would overflow int when n is close to INT32_MAX
  const int grid_size = static_cast<int>((static_cast<int64_t>(n) + block_size - 1) / block_size);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (dtype == torch::kFloat32) {
    from_fp6_unpacked_kernel<<<grid_size, block_size, 0, stream>>>(fp6_ptr, fp_tensor.data_ptr<float>(), n);
  } else if (dtype == torch::kFloat16) {
    from_fp6_unpacked_kernel<<<grid_size, block_size, 0, stream>>>(fp6_ptr, fp_tensor.data_ptr<at::Half>(), n);
  } else {
    from_fp6_unpacked_kernel<<<grid_size, block_size, 0, stream>>>(fp6_ptr, fp_tensor.data_ptr<at::BFloat16>(), n);
  }
  // report launch errors here rather than at some later, unrelated CUDA call
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return fp_tensor;
}
// Decodes packed FP6 data: every 3 input bytes hold 4 codes, most significant bits first, and
// produce 4 output values. Bytes beyond the last full group of 3 are not read.
// The loop is an OpenMP parallel for.
template <typename T>
void from_fp6_packed_cpu_impl(const uint8_t *fp6_ptr, T *fp_ptr, int n) {
  const int num_groups = n / 3;
#pragma omp parallel for
  for (int g = 0; g < num_groups; ++g) {
    const uint8_t *in = fp6_ptr + g * 3;
    T *out = fp_ptr + g * 4;
    const uint8_t b0 = in[0];  // 0000 0011
    const uint8_t b1 = in[1];  // 1111 2222
    const uint8_t b2 = in[2];  // 2233 3333
    out[0] = from_fp6<T>(b0 >> 2);
    out[1] = from_fp6<T>(((b0 & 0x3u) << 4) | (b1 >> 4));
    out[2] = from_fp6<T>(((b1 & 0xFu) << 2) | (b2 >> 6));
    out[3] = from_fp6<T>(b2 & 0x3Fu);
  }
}
// Decodes a contiguous 2-D CPU uint8 tensor of packed FP6 codes with shape (M, N), N a multiple
// of 3, into a tensor of shape (M, N / 3 * 4) with dtype `dtype`.
// Throws std::invalid_argument unless `dtype` is FP32, FP16 or BF16.
at::Tensor from_fp6_packed_cpu(at::Tensor fp6_tensor, c10::ScalarType dtype) {
  TORCH_CHECK(fp6_tensor.dtype() == torch::kUInt8);
  TORCH_CHECK(fp6_tensor.is_contiguous());
  TORCH_CHECK(fp6_tensor.is_cpu());
  TORCH_CHECK(fp6_tensor.ndimension() == 2);
  // keep sizes in 64 bits; narrowing them to int could silently truncate
  const int64_t M = fp6_tensor.size(0);
  const int64_t N = fp6_tensor.size(1);
  TORCH_CHECK(N % 3 == 0, "Last dimension must be a multiple of 3, receives ", N);
  // the decode loop indexes inputs and outputs with 32-bit ints. The output (4/3 of the input)
  // is the larger of the two, so bounding it bounds both.
  const int64_t out_numel = M * (N / 3 * 4);
  TORCH_CHECK(out_numel <= INT32_MAX, "Output would have too many elements (", out_numel,
              "); at most ", INT32_MAX, " are supported");
  at::TensorOptions options = at::TensorOptions().dtype(dtype).device(fp6_tensor.device());
  at::Tensor fp_tensor = at::empty({M, N / 3 * 4}, options);
  const uint8_t *fp6_ptr = fp6_tensor.data_ptr<uint8_t>();
  const int n = static_cast<int>(fp6_tensor.numel());
  if (dtype == torch::kFloat32) {
    from_fp6_packed_cpu_impl(fp6_ptr, fp_tensor.data_ptr<float>(), n);
  } else if (dtype == torch::kFloat16) {
    from_fp6_packed_cpu_impl(fp6_ptr, fp_tensor.data_ptr<at::Half>(), n);
  } else if (dtype == torch::kBFloat16) {
    from_fp6_packed_cpu_impl(fp6_ptr, fp_tensor.data_ptr<at::BFloat16>(), n);
  } else {
    throw std::invalid_argument("Only FP32, FP16, and BF16 outputs are accepted.");
  }
  return fp_tensor;
}
// Packed FP6 -> T kernel. Each thread of a 1-D grid reads one 3-byte group (4 codes, most
// significant bits first) and writes 4 output values. Threads with no full group do nothing.
template <typename T>
__global__ void from_fp6_packed_kernel(const uint8_t *fp6_ptr, T *fp_ptr, int n) {
  const int group = threadIdx.x + blockIdx.x * blockDim.x;
  if (group >= n / 3) return;
  // TODO: use vector load for reading from global memory
  const uint8_t *in = fp6_ptr + group * 3;
  T *out = fp_ptr + group * 4;
  const uint8_t b0 = in[0];  // 0000 0011
  const uint8_t b1 = in[1];  // 1111 2222
  const uint8_t b2 = in[2];  // 2233 3333
  out[0] = from_fp6<T>(b0 >> 2);
  out[1] = from_fp6<T>(((b0 & 0x3u) << 4) | (b1 >> 4));
  out[2] = from_fp6<T>(((b1 & 0xFu) << 2) | (b2 >> 6));
  out[3] = from_fp6<T>(b2 & 0x3Fu);
}
// Decodes a contiguous 2-D CUDA uint8 tensor of packed FP6 codes with shape (M, N), N a multiple
// of 3, into a tensor of shape (M, N / 3 * 4) with dtype `dtype`. The kernel is enqueued
// asynchronously on the current CUDA stream of the input's device.
// Throws std::invalid_argument unless `dtype` is FP32, FP16 or BF16.
at::Tensor from_fp6_packed_cuda(at::Tensor fp6_tensor, c10::ScalarType dtype) {
  TORCH_CHECK(fp6_tensor.dtype() == torch::kUInt8);
  TORCH_CHECK(fp6_tensor.is_contiguous());
  TORCH_CHECK(fp6_tensor.is_cuda());
  TORCH_CHECK(fp6_tensor.ndimension() == 2);
  // keep sizes in 64 bits; narrowing them to int could silently truncate
  const int64_t M = fp6_tensor.size(0);
  const int64_t N = fp6_tensor.size(1);
  TORCH_CHECK(N % 3 == 0, "Last dimension must be a multiple of 3, receives ", N);
  // the kernel indexes inputs and outputs with 32-bit ints. The output (4/3 of the input) is
  // the larger of the two, so bounding it bounds both.
  const int64_t out_numel = M * (N / 3 * 4);
  TORCH_CHECK(out_numel <= INT32_MAX, "Output would have too many elements (", out_numel,
              "); at most ", INT32_MAX, " are supported");
  if (dtype != torch::kFloat32 && dtype != torch::kFloat16 && dtype != torch::kBFloat16) {
    throw std::invalid_argument("Only FP32, FP16, and BF16 outputs are accepted.");
  }

  // make the input's device current, so the stream lookup and the launch target that device
  const c10::cuda::CUDAGuard device_guard(fp6_tensor.device());
  at::TensorOptions options = at::TensorOptions().dtype(dtype).device(fp6_tensor.device());
  at::Tensor fp_tensor = at::empty({M, N / 3 * 4}, options);
  const int n = static_cast<int>(fp6_tensor.numel());
  // launching a grid with zero blocks is an invalid configuration
  if (n == 0) return fp_tensor;

  const uint8_t *fp6_ptr = fp6_tensor.data_ptr<uint8_t>();
  // times 3 because each thread reads 3 bytes (which represent 4 FP6 values)
  constexpr int block_size = 256;
  // 64-bit arithmetic avoids int overflow of the rounded-up numerator for large n
  const int grid_size = static_cast<int>((static_cast<int64_t>(n) + block_size * 3 - 1) / (block_size * 3));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (dtype == torch::kFloat32) {
    from_fp6_packed_kernel<<<grid_size, block_size, 0, stream>>>(fp6_ptr, fp_tensor.data_ptr<float>(), n);
  } else if (dtype == torch::kFloat16) {
    from_fp6_packed_kernel<<<grid_size, block_size, 0, stream>>>(fp6_ptr, fp_tensor.data_ptr<at::Half>(), n);
  } else {
    from_fp6_packed_kernel<<<grid_size, block_size, 0, stream>>>(fp6_ptr, fp_tensor.data_ptr<at::BFloat16>(), n);
  }
  // report launch errors here rather than at some later, unrelated CUDA call
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return fp_tensor;
}
// Registers the CPU implementations of the four FP6 conversion operators with the dispatcher.
// NOTE(review): the operator schemas are presumably declared elsewhere with
// TORCH_LIBRARY(torchao, m) (not visible here); confirm they match these function signatures.
TORCH_LIBRARY_IMPL(torchao, CPU, m) {
m.impl("torchao::to_fp6_unpacked", &to_fp6_unpacked_cpu);
m.impl("torchao::to_fp6_packed", &to_fp6_packed_cpu);
m.impl("torchao::from_fp6_unpacked", &from_fp6_unpacked_cpu);
m.impl("torchao::from_fp6_packed", &from_fp6_packed_cpu);
}
// Registers the CUDA implementations of the four FP6 conversion operators with the dispatcher.
// Unlike the CPU versions, the CUDA encoders do not reject NaN/inf or out-of-range inputs.
// NOTE(review): the operator schemas are presumably declared elsewhere with
// TORCH_LIBRARY(torchao, m) (not visible here); confirm they match these function signatures.
TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
m.impl("torchao::to_fp6_unpacked", &to_fp6_unpacked_cuda);
m.impl("torchao::to_fp6_packed", &to_fp6_packed_cuda);
m.impl("torchao::from_fp6_unpacked", &from_fp6_unpacked_cuda);
m.impl("torchao::from_fp6_packed", &from_fp6_packed_cuda);
}
}