
Jcaip/llm bsr #1601

Draft
wants to merge 18 commits into base: main
74 changes: 45 additions & 29 deletions benchmarks/benchmark_gpu_sparsity.py
@@ -1,4 +1,5 @@
import argparse
from typing import Callable, List, Optional, Tuple

import pandas as pd
import torch
@@ -11,7 +12,9 @@
create_block_sparse_tensor,
create_semi_structured_tensor,
)
from torchao.utils import benchmark_model
import torch.utils.benchmark as benchmark

from torchao.sparsity.blocksparse import BlockSparseTensor

torch.set_printoptions(
precision=2,
@@ -27,6 +30,17 @@ def benchmark_model_with_warmup(func, x, N_WARMUP=3):
benchmark_model(func, N_WARMUP, device_type="cuda")
return benchmark_model(func, 10, device_type="cuda")

def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    # warmup before timing
    for _ in range(1):
        func(*args, **kwargs)
    t0 = benchmark.Timer(
        stmt="func(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "func": func},
    )
    # adaptive_autorange returns a Measurement; convert the median from seconds to microseconds
    return t0.adaptive_autorange(min_run_time=0.1).median * 1e6


def run_gpu_sparse_benchmark(m, k, n, args):
with torch.no_grad():
@@ -43,7 +57,8 @@ def run_gpu_sparse_benchmark(m, k, n, args):
A = create_block_sparse_tensor(
m, k, args.block_size, args.sparsity_level, dtype
)
A_sparse = A.to_sparse_bsr(blocksize=args.block_size)
# A_sparse = A.to_sparse_bsr(blocksize=args.block_size)
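# swap in torchao's BlockSparseTensor subclass in place of PyTorch's native BSR conversion above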
A_sparse = BlockSparseTensor.from_dense(A, args.block_size).detach()
# BSR kernel tuning
if args.bsr_autotune:
print("Tuning kernel params")
@@ -61,13 +76,16 @@ def run_gpu_sparse_benchmark(m, k, n, args):
raise ValueError(f"Unknown sparsity: {args.sparsity}")

if args.eval_fn == "linear":
b = torch.randn(m, dtype=dtype).cuda()
# b = torch.randn(m, dtype=dtype).cuda()
b = None

# can't use lambda
def dense_func():
@torch.compile(mode="max-autotune")
def dense_func(x):
return F.linear(x, A, b)

def sparse_func():
@torch.compile(mode="max-autotune")
def sparse_func(x):
return F.linear(x, A_sparse, b)

elif args.eval_fn == "mm":
@@ -101,20 +119,17 @@ def sparse_func():
else:
raise ValueError(f"Unknown eval_fn: {args.eval_fn}")

dense_time = benchmark_model_with_warmup(dense_func, "dense.json.gz")
sparse_time = benchmark_model_with_warmup(sparse_func, "sparse.json.gz")

dense_func_c = torch.compile(dense_func, mode="max-autotune")
dense_time_c = benchmark_model_with_warmup(
dense_func_c, "dense_compile.json.gz"
)

sparse_func_c = torch.compile(sparse_func, mode="max-autotune")
sparse_time_c = benchmark_model_with_warmup(
sparse_func_c, "sparse_compile.json.gz"
)
# print(x)
# print(A)
# print(A_sparse.crow_indices())
# print(A_sparse.col_indices())
# print(A_sparse.values())
        # eager timings are not collected here; only the compiled paths are measured
        dense_time, sparse_time = 0, 0

        torch._dynamo.reset()
        dense_time_c = benchmark_torch_function_in_microseconds(dense_func, x)
        sparse_time_c = benchmark_torch_function_in_microseconds(sparse_func, x)

return {
"test_function": args.eval_fn,
@@ -126,8 +141,7 @@ def sparse_func():
"dense": dense_time,
"dense_c": dense_time_c,
"sparse_c": sparse_time_c,
"speedup (d/s)": min(dense_time, dense_time_c)
/ min(sparse_time, sparse_time_c),
"speedup (d/s)": dense_time_c / sparse_time_c,
}


@@ -200,15 +214,17 @@ def sparse_func():
)
elif args.mode == "llama3-8b-w":
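# (m, k, n) GEMM shapes; most of the Llama-3-8B weight shapes are commented out in this draft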
mm_shapes = [
(16, 4096, 11008),
(16, 4096, 4096),
(16, 11008, 4096),
(4096, 4096, 11008),
(4096, 4096, 4096),
(4096, 11008, 4096),
(8192, 4096, 11008),
(8192, 4096, 4096),
(8192, 11008, 4096),
# (32, 32, 16),
(4096, 14336, 1),
# (14336, 4096, 1),
# (11008, 4096, 16),
# (16, 4096, 4096),
# (4096, 4096, 11008),
# (4096, 4096, 4096),
# (4096, 11008, 4096),
# (8192, 4096, 11008),
# (8192, 4096, 4096),
# (8192, 11008, 4096),
]
results = (
run_gpu_sparse_benchmark(m, k, n, args) for (m, k, n) in tqdm(mm_shapes)
66 changes: 66 additions & 0 deletions test/sparsity/test_supermask.py
@@ -0,0 +1,66 @@
import copy
import logging
import unittest
import math

import torch
from torch import nn
from torch.testing._internal import common_utils

from torchao.dtypes import MarlinSparseLayout, SemiSparseLayout
from torchao.quantization.quant_api import (
int4_weight_only,
int8_dynamic_activation_int8_weight,
quantize_,
)
from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_3,
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
)

logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)

class TestSupermask(common_utils.TestCase):

@common_utils.parametrize("sparsity_level", [0.25, 0.5])
@common_utils.parametrize("blocksize", [2, 4, 8])
def test_supermask(self, sparsity_level, blocksize):
input = torch.randn((1, 16)).half().cuda()
model = (
nn.Sequential(
nn.Linear(16, 16, bias=False),
)
.half()
.cuda()
.eval()
)

from torchao.sparsity import SupermaskLinear

M, N = model[0].weight.shape
sparsify_(model, lambda x: SupermaskLinear.from_linear(x, sparsity_level=sparsity_level, blocksize=blocksize))
sparsify_(model, SupermaskLinear.to_linear)
weight_bsr = model[0].weight.to_sparse_bsr(blocksize=blocksize)

# Test correct sparsity level
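# _nnz() on a BSR tensor counts nonzero blocks rather than individual elements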
nnz = weight_bsr._nnz()
expected = round((M // blocksize) * (N // blocksize) * (1 - sparsity_level))
assert nnz == expected, f"Expected {expected} nonzeros, got {nnz}"

def test_from_linear(self):
from torchao.sparsity import SupermaskLinear
linear = nn.Linear(128, 128)
supermask_linear = SupermaskLinear.from_linear(linear, sparsity_level=0.5, blocksize=4)
assert supermask_linear.weight.shape == linear.weight.shape


common_utils.instantiate_parametrized_tests(TestSupermask)


if __name__ == "__main__":
unittest.main()