Add CUTLASS-based row-wise scaled sparse FP8 kernel
1 parent c8eb8d3 · commit 6983b61
Showing 10 changed files with 975 additions and 1 deletion.
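The new op is exercised end to end by the benchmark and test files below. As a quick orientation, here is a minimal usage sketch of the call pattern they assume; the argument order (quantized activation, per-row activation scales, compressed 2:4 weight, weight metadata, per-row weight scales, optional bias) is inferred from the code in this commit, and the concrete shapes and dtypes are illustrative assumptions, not an authoritative API reference.

# Minimal usage sketch, inferred from the benchmark/test code in this commit.
# Argument order assumed: op(xq, x_scale, wq_sparse, wq_meta, w_scale, bias).
import torch
from torchao.ops import (
    rowwise_scaled_linear_sparse_cutlass_f8f8,
    to_sparse_semi_structured_cutlass_sm9x_f8,
)
from torchao.sparsity.utils import create_semi_structured_tensor

m, n, k = 16, 128, 256
# FP8 (e5m2) activation with one scale per row.
xq = torch.randn((m, k), dtype=torch.half, device="cuda").to(torch.float8_e5m2)
x_scale = torch.rand((m,), dtype=torch.half, device="cuda")
# 2:4 semi-structured weight, quantized to FP8 (e4m3), one scale per output row.
w = create_semi_structured_tensor(n, k, dtype=torch.half)
wq_sp, wq_meta = to_sparse_semi_structured_cutlass_sm9x_f8(w.to(torch.float8_e4m3fn))
w_scale = torch.rand((n,), dtype=torch.half, device="cuda")

out = rowwise_scaled_linear_sparse_cutlass_f8f8(
    xq, x_scale, wq_sp, wq_meta, w_scale, None  # None -> no bias
)
print(out.shape)  # expected: (m, n)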
62 changes: 62 additions & 0 deletions
benchmarks/benchmark_rowwise_scaled_linear_sparse_cutlass.py
@@ -0,0 +1,62 @@
import pandas as pd
import torch
from tqdm import tqdm
from triton.testing import do_bench

from torchao.ops import (
    rowwise_scaled_linear_sparse_cutlass_f8f8,
    to_sparse_semi_structured_cutlass_sm9x_f8,
)
from torchao.sparsity.utils import create_semi_structured_tensor


def benchmark_microseconds(f, *args):
    return do_bench(lambda: f(*args), return_mode="median") * 1e3


def get_problem(m: int, n: int, k: int):
    dev = torch.device("cuda")

    A = torch.randn((m, k), dtype=torch.half, device=dev).to(torch.float8_e5m2)
    A_scale = torch.randn((m,), dtype=torch.half, device=dev)
    B = torch.randn((n, k), dtype=torch.half, device=dev).to(torch.float8_e4m3fn)
    B_sp, B_meta = to_sparse_semi_structured_cutlass_sm9x_f8(B)
    B_scale = torch.randn((n,), dtype=torch.half, device=dev)
    C = None

    return A, A_scale, B_sp, B_meta, B_scale, C


def benchmark(m: int, k: int, n: int):
    dev = torch.device("cuda")
    A_ref = torch.randn((m, k), dtype=torch.half, device=dev)
    B_ref = torch.randn((n, k), dtype=torch.half, device=dev)
    fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref)

    A, A_scale, B_sp, B_meta, B_scale, C = get_problem(m, n, k)
    rowwise_scaled_linear_sparse_cutlass_f8f8_time = benchmark_microseconds(
        rowwise_scaled_linear_sparse_cutlass_f8f8, A, A_scale, B_sp, B_meta, B_scale, C
    )
    return {
        "m": m,
        "k": k,
        "n": n,
        "fp16_latency (us)": fp16_time,
        "rowwise_scaled_linear_sparse_cutlass_f8f8 latency (us)": rowwise_scaled_linear_sparse_cutlass_f8f8_time,
        "f8f8 speedup (d/s)": fp16_time / rowwise_scaled_linear_sparse_cutlass_f8f8_time,
    }


if __name__ == "__main__":
    k_vals = (8192, 8192, 8192, 28672)
    n_vals = (8192, 10240, 57344, 8192)

    results = []
    for m in tqdm([1 << i for i in range(10)]):
        for n, k in zip(n_vals, k_vals):
            results.append(benchmark(m, k, n))

    df = pd.DataFrame(results)
    df.to_csv("rowwise_scaled_linear_sparse_cutlass_time_results.csv", index=False)
    print(df.to_markdown(index=False))
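The script sweeps m over powers of two for several Llama-like (n, k) weight shapes, writes per-shape timings to a CSV, and prints a markdown table. As an optional follow-up (not part of this commit), the results can be summarized per weight shape like this:

# Hypothetical post-processing of the CSV written by the benchmark above;
# column names match the dict keys used in benchmark().
import pandas as pd

df = pd.read_csv("rowwise_scaled_linear_sparse_cutlass_time_results.csv")
summary = df.groupby(["n", "k"])["f8f8 speedup (d/s)"].agg(["min", "mean", "max"])
print(summary.reset_index().to_markdown(index=False))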
128 changes: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
import itertools
import random

import pytest
import torch
from torch.testing._internal.common_cuda import SM90OrLater

from torchao.dtypes import (
    Float8Layout,
    to_affine_quantized_floatx,
)
from torchao.ops import (
    rowwise_scaled_linear_sparse_cutlass_f8f8,
    to_sparse_semi_structured_cutlass_sm9x_f8,
)
from torchao.sparsity.utils import create_semi_structured_tensor


X_W_DTYPES = [(torch.float16, torch.float16), (torch.bfloat16, torch.bfloat16)]
XQ_WQ_DTYPES = [
    (torch.float8_e5m2, torch.float8_e4m3fn),
    (torch.float8_e4m3fn, torch.float8_e4m3fn),
]
BATCH_SIZE = [1, 4]
SIZE_MNK = [
    (2, 128, 256),
    (3, 128, 256),
    (13, 128, 256),
    (27, 128, 128),
    (33, 128, 64),
    (65, 128, 32),
]
USE_BIAS = [False, True]
BIAS_DTYPE = [torch.float16]
TEST_PARAMS = list(
    itertools.product(
        X_W_DTYPES,
        XQ_WQ_DTYPES,
        BATCH_SIZE,
        SIZE_MNK,
        USE_BIAS,
        BIAS_DTYPE,
    )
)


def run_test_for_op(
    op,
    x_dtype,
    w_dtype,
    xq_dtype,
    wq_dtype,
    batch_size,
    size_mnk,
    use_bias,
    bias_dtype,
):
    size_m, size_n, size_k = size_mnk

    x = torch.randn((batch_size, size_m, size_k), dtype=x_dtype, device="cuda")
    w = create_semi_structured_tensor(size_n, size_k, dtype=w_dtype)
    bias = torch.rand((size_n,), dtype=bias_dtype, device="cuda") if use_bias else None

    block_size = [1] * (x.dim() - 1) + [x.shape[-1]]
    x_aqt = to_affine_quantized_floatx(
        input_float=x,
        target_dtype=xq_dtype,
        block_size=block_size,
        _layout=Float8Layout(mm_config=None),
    )
    xq, xq_scales, zero_points = x_aqt.tensor_impl.get_plain()
    assert zero_points is None

    block_size = [1] * (w.dim() - 1) + [w.shape[-1]]
    w_aqt = to_affine_quantized_floatx(
        input_float=w,
        target_dtype=wq_dtype,
        block_size=block_size,
        _layout=Float8Layout(mm_config=None),
    )
    wq, wq_scales, zero_points = w_aqt.tensor_impl.get_plain()
    assert zero_points is None
    wq_sp, wq_sp_meta = to_sparse_semi_structured_cutlass_sm9x_f8(wq)
    wq_sp_scales = wq_scales

    xq_2d = xq.view(-1, xq.shape[-1])
    size_m_2d = xq_2d.shape[0]
    output_ref = (
        (xq_2d.float() @ wq.float().T)
        * xq_scales.view(size_m_2d, 1)
        * wq_scales.view(1, size_n)
    )
    if bias is not None:
        output_ref += bias
    output_ref = output_ref.to(x.dtype).reshape(x.shape[:-1] + (size_n,))

    fn_inputs = (xq, xq_scales, wq_sp, wq_sp_meta, wq_sp_scales, bias)
    try:
        output = op(*fn_inputs)
    except NotImplementedError:
        pytest.xfail("operator not implemented")

    torch.testing.assert_close(output, output_ref, rtol=1e-2, atol=5e-3)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not SM90OrLater, reason="FP8 is only supported on H100+ devices")
@pytest.mark.parametrize(
    "x_w_dtypes, xq_wq_dtypes, batch_size, size_mnk, use_bias, bias_dtype",
    TEST_PARAMS,
)
def test_rowwise_scaled_linear_sparse_cutlass_f8f8(
    x_w_dtypes,
    xq_wq_dtypes,
    batch_size,
    size_mnk,
    use_bias,
    bias_dtype,
):
    run_test_for_op(
        rowwise_scaled_linear_sparse_cutlass_f8f8,
        *x_w_dtypes,
        *xq_wq_dtypes,
        batch_size,
        size_mnk,
        use_bias,
        bias_dtype,
    )
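The reference computation in the test makes the scaling semantics explicit: the FP8 operands are multiplied in FP32, the result is scaled row-wise by the activation scales and column-wise by the weight scales, and the bias is added last. Restated as a standalone helper (a sketch of output_ref above; the names are mine, not part of the commit):

# Sketch of the dequantized reference used in the test above.
import torch

def reference_output(xq, x_scale, wq, w_scale, bias=None):
    # xq: (m, k) FP8 activation, x_scale: (m,) per-row scales
    # wq: (n, k) dense FP8 weight (before 2:4 compression), w_scale: (n,) per-row scales
    out = xq.float() @ wq.float().t()
    out = out * x_scale.float().view(-1, 1) * w_scale.float().view(1, -1)
    if bias is not None:
        out = out + bias.float()
    return out  # the test casts this back to the activation dtype before comparing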
1 change: 1 addition & 0 deletions
torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu
@@ -1,3 +1,4 @@
#include <cutlass/cutlass.h>
#include <torch/library.h>

#include "rowwise_scaled_linear_cutlass.cuh"
1 change: 1 addition & 0 deletions
torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu
@@ -1,3 +1,4 @@
#include <cutlass/cutlass.h>
#include <torch/library.h>

#include "rowwise_scaled_linear_cutlass.cuh"