Add CUTLASS-based row-wise scaled sparse FP8 kernel
alexsamardzic committed Feb 9, 2025
1 parent c8eb8d3 commit 6983b61
Showing 10 changed files with 975 additions and 1 deletion.
62 changes: 62 additions & 0 deletions benchmarks/benchmark_rowwise_scaled_linear_sparse_cutlass.py
@@ -0,0 +1,62 @@
import pandas as pd
import torch
from tqdm import tqdm
from triton.testing import do_bench

from torchao.ops import (
    rowwise_scaled_linear_sparse_cutlass_f8f8,
    to_sparse_semi_structured_cutlass_sm9x_f8,
)
from torchao.sparsity.utils import create_semi_structured_tensor


def benchmark_microseconds(f, *args):
    return do_bench(lambda: f(*args), return_mode="median") * 1e3


def get_problem(m: int, n: int, k: int):
    dev = torch.device("cuda")

    A = torch.randn((m, k), dtype=torch.half, device=dev).to(torch.float8_e5m2)
    A_scale = torch.randn((m,), dtype=torch.half, device=dev)
    # The weight must already have a 2:4 semi-structured sparsity pattern
    # before it can be compressed into values plus metadata.
    B = create_semi_structured_tensor(n, k, dtype=torch.float8_e4m3fn)
    B_sp, B_meta = to_sparse_semi_structured_cutlass_sm9x_f8(B)
    B_scale = torch.randn((n,), dtype=torch.half, device=dev)
    C = None

    return A, A_scale, B_sp, B_meta, B_scale, C


def benchmark(m: int, k: int, n: int):
    dev = torch.device("cuda")
    A_ref = torch.randn((m, k), dtype=torch.half, device=dev)
    B_ref = torch.randn((n, k), dtype=torch.half, device=dev)
    fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref)

    A, A_scale, B_sp, B_meta, B_scale, C = get_problem(m, n, k)
    rowwise_scaled_linear_sparse_cutlass_f8f8_time = benchmark_microseconds(
        rowwise_scaled_linear_sparse_cutlass_f8f8, A, A_scale, B_sp, B_meta, B_scale, C
    )

    return {
        "m": m,
        "k": k,
        "n": n,
        "fp16_latency (us)": fp16_time,
        "rowwise_scaled_linear_sparse_cutlass_f8f8 latency (us)": rowwise_scaled_linear_sparse_cutlass_f8f8_time,
        "f8f8 speedup (d/s)": fp16_time / rowwise_scaled_linear_sparse_cutlass_f8f8_time,
    }


if __name__ == "__main__":
    k_vals = (8192, 8192, 8192, 28672)
    n_vals = (8192, 10240, 57344, 8192)

    results = []
    for m in tqdm([1 << i for i in range(10)]):
        for n, k in zip(n_vals, k_vals):
            results.append(benchmark(m, k, n))

    df = pd.DataFrame(results)
    df.to_csv("rowwise_scaled_linear_sparse_cutlass_time_results.csv", index=False)
    print(df.to_markdown(index=False))
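
For reference, the new op can also be exercised standalone. A minimal sketch (assumes a CUDA build of torchao on an SM90+ GPU; shapes, dtypes, and the call signature mirror the benchmark above):

import torch

from torchao.ops import (
    rowwise_scaled_linear_sparse_cutlass_f8f8,
    to_sparse_semi_structured_cutlass_sm9x_f8,
)
from torchao.sparsity.utils import create_semi_structured_tensor

m, n, k = 16, 128, 256
# Row-wise scaled FP8 activation.
A = torch.randn((m, k), dtype=torch.half, device="cuda").to(torch.float8_e5m2)
A_scale = torch.randn((m,), dtype=torch.half, device="cuda")
# 2:4 semi-structured FP8 weight, compressed into values plus metadata.
B = create_semi_structured_tensor(n, k, dtype=torch.float8_e4m3fn)
B_sp, B_meta = to_sparse_semi_structured_cutlass_sm9x_f8(B)
B_scale = torch.randn((n,), dtype=torch.half, device="cuda")

# The last argument is the optional bias; None here.
out = rowwise_scaled_linear_sparse_cutlass_f8f8(A, A_scale, B_sp, B_meta, B_scale, None)
print(out.shape)  # torch.Size([16, 128])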
2 changes: 2 additions & 0 deletions setup.py
@@ -256,13 +256,15 @@ def get_extensions():
    if use_cuda and not IS_WINDOWS:
        use_cutlass = True
        cutlass_dir = os.path.join(third_party_path, "cutlass")
        cutlass_util_include_dir = os.path.join(cutlass_dir, "tools", "util", "include")
        cutlass_include_dir = os.path.join(cutlass_dir, "include")
        cutlass_extensions_include_dir = os.path.join(cwd, extensions_cuda_dir)
    if use_cutlass:
        extra_compile_args["nvcc"].extend(
            [
                "-DTORCHAO_USE_CUTLASS",
                "-I" + cutlass_include_dir,
                "-I" + cutlass_util_include_dir,
                "-I" + cutlass_extensions_include_dir,
            ]
        )
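
The two extra -I flags make the CUTLASS utility headers (cutlass/tools/util/include) and the local cutlass_extensions headers visible to nvcc. A quick smoke check that a source build picked up the new kernel (a sketch; assumes torchao was built with CUDA, e.g. via pip install -e .):

# Resolving the Python wrappers succeeds once torchao imports cleanly;
# actually calling them requires the compiled CUDA extension.
from torchao.ops import (
    rowwise_scaled_linear_sparse_cutlass_f8f8,
    to_sparse_semi_structured_cutlass_sm9x_f8,
)

print(rowwise_scaled_linear_sparse_cutlass_f8f8)
print(to_sparse_semi_structured_cutlass_sm9x_f8)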
128 changes: 128 additions & 0 deletions test/test_rowwise_scaled_linear_sparse_cutlass.py
@@ -0,0 +1,128 @@
import itertools
import random

import pytest
import torch
from torch.testing._internal.common_cuda import SM90OrLater

from torchao.dtypes import (
    Float8Layout,
    to_affine_quantized_floatx,
)
from torchao.ops import (
    rowwise_scaled_linear_sparse_cutlass_f8f8,
    to_sparse_semi_structured_cutlass_sm9x_f8,
)
from torchao.sparsity.utils import create_semi_structured_tensor


X_W_DTYPES = [(torch.float16, torch.float16), (torch.bfloat16, torch.bfloat16)]
XQ_WQ_DTYPES = [
    (torch.float8_e5m2, torch.float8_e4m3fn),
    (torch.float8_e4m3fn, torch.float8_e4m3fn),
]
BATCH_SIZE = [1, 4]
SIZE_MNK = [
    (2, 128, 256),
    (3, 128, 256),
    (13, 128, 256),
    (27, 128, 128),
    (33, 128, 64),
    (65, 128, 32),
]
USE_BIAS = [False, True]
BIAS_DTYPE = [torch.float16]
TEST_PARAMS = list(
    itertools.product(
        X_W_DTYPES,
        XQ_WQ_DTYPES,
        BATCH_SIZE,
        SIZE_MNK,
        USE_BIAS,
        BIAS_DTYPE,
    )
)


def run_test_for_op(
    op,
    x_dtype,
    w_dtype,
    xq_dtype,
    wq_dtype,
    batch_size,
    size_mnk,
    use_bias,
    bias_dtype,
):
    size_m, size_n, size_k = size_mnk

    x = torch.randn((batch_size, size_m, size_k), dtype=x_dtype, device="cuda")
    w = create_semi_structured_tensor(size_n, size_k, dtype=w_dtype)
    bias = torch.rand((size_n,), dtype=bias_dtype, device="cuda") if use_bias else None

    block_size = [1] * (x.dim() - 1) + [x.shape[-1]]
    x_aqt = to_affine_quantized_floatx(
        input_float=x,
        target_dtype=xq_dtype,
        block_size=block_size,
        _layout=Float8Layout(mm_config=None),
    )
    xq, xq_scales, zero_points = x_aqt.tensor_impl.get_plain()
    assert zero_points is None

    block_size = [1] * (w.dim() - 1) + [w.shape[-1]]
    w_aqt = to_affine_quantized_floatx(
        input_float=w,
        target_dtype=wq_dtype,
        block_size=block_size,
        _layout=Float8Layout(mm_config=None),
    )
    wq, wq_scales, zero_points = w_aqt.tensor_impl.get_plain()
    assert zero_points is None
    wq_sp, wq_sp_meta = to_sparse_semi_structured_cutlass_sm9x_f8(wq)
    wq_sp_scales = wq_scales

    xq_2d = xq.view(-1, xq.shape[-1])
    size_m_2d = xq_2d.shape[0]
    output_ref = (
        (xq_2d.float() @ wq.float().T)
        * xq_scales.view(size_m_2d, 1)
        * wq_scales.view(1, size_n)
    )
    if bias is not None:
        output_ref += bias
    output_ref = output_ref.to(x.dtype).reshape(x.shape[:-1] + (size_n,))

    fn_inputs = (xq, xq_scales, wq_sp, wq_sp_meta, wq_sp_scales, bias)
    try:
        output = op(*fn_inputs)
    except NotImplementedError:
        pytest.xfail("operator not implemented")

    torch.testing.assert_close(output, output_ref, rtol=1e-2, atol=5e-3)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not SM90OrLater, reason="FP8 is only supported on H100+ devices")
@pytest.mark.parametrize(
    "x_w_dtypes, xq_wq_dtypes, batch_size, size_mnk, use_bias, bias_dtype",
    TEST_PARAMS,
)
def test_rowwise_scaled_linear_sparse_cutlass_f8f8(
    x_w_dtypes,
    xq_wq_dtypes,
    batch_size,
    size_mnk,
    use_bias,
    bias_dtype,
):
    run_test_for_op(
        rowwise_scaled_linear_sparse_cutlass_f8f8,
        *x_w_dtypes,
        *xq_wq_dtypes,
        batch_size,
        size_mnk,
        use_bias,
        bias_dtype,
    )
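
For intuition, the reference path in run_test_for_op reduces to a dense float matmul with the per-row activation scales and per-column weight scales applied afterward, plus the optional bias. A condensed sketch of that math (the helper name is illustrative, not part of the commit):

def reference_rowwise_scaled_linear(xq, x_scale, wq, w_scale, bias=None):
    # out[i, j] = (sum over k of xq[i, k] * wq[j, k]) * x_scale[i] * w_scale[j] (+ bias[j])
    out = (xq.float() @ wq.float().T) * x_scale.view(-1, 1) * w_scale.view(1, -1)
    if bias is not None:
        out = out + bias
    return out

The CUTLASS kernel is expected to match this within the rtol/atol above while reading the weight in compressed 2:4 form (values plus metadata) and applying the scales inside the kernel.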
@@ -1,3 +1,4 @@
#include <cutlass/cutlass.h>
#include <torch/library.h>

#include "rowwise_scaled_linear_cutlass.cuh"
@@ -1,3 +1,4 @@
#include <cutlass/cutlass.h>
#include <torch/library.h>

#include "rowwise_scaled_linear_cutlass.cuh"