diff --git a/benchmarks/benchmark_gpu_sparsity.py b/benchmarks/benchmark_gpu_sparsity.py index 9e22f6d43a..3918622b25 100644 --- a/benchmarks/benchmark_gpu_sparsity.py +++ b/benchmarks/benchmark_gpu_sparsity.py @@ -1,4 +1,5 @@ import argparse +from typing import Callable, List, Optional, Tuple import pandas as pd import torch @@ -11,7 +12,9 @@ create_block_sparse_tensor, create_semi_structured_tensor, ) -from torchao.utils import benchmark_model +import torch.utils.benchmark as benchmark + +from torchao.sparsity.blocksparse import BlockSparseTensor torch.set_printoptions( precision=2, @@ -27,6 +30,17 @@ def benchmark_model_with_warmup(func, x, N_WARMUP=3): benchmark_model(func, N_WARMUP, device_type="cuda") return benchmark_model(func, 10, device_type="cuda") +def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float: + # warmup + for _ in range(1): + func(*args, **kwargs) + # t0 = benchmark.Timer( + # stmt="func(*args, **kwargs)", + # globals={"args": args, "kwargs": kwargs, "func": func}, + # ) + # return t0.adaptive_autorange(min_run_time=0.1).median * 1e6 + return 1 + def run_gpu_sparse_benchmark(m, k, n, args): with torch.no_grad(): @@ -43,7 +57,8 @@ def run_gpu_sparse_benchmark(m, k, n, args): A = create_block_sparse_tensor( m, k, args.block_size, args.sparsity_level, dtype ) - A_sparse = A.to_sparse_bsr(blocksize=args.block_size) + # A_sparse = A.to_sparse_bsr(blocksize=args.block_size) + A_sparse = BlockSparseTensor.from_dense(A, args.block_size).detach() # BSR kernel tuning if args.bsr_autotune: print("Tuning kernel params") @@ -61,13 +76,16 @@ def run_gpu_sparse_benchmark(m, k, n, args): raise ValueError(f"Unknown sparsity: {args.sparsity}") if args.eval_fn == "linear": - b = torch.randn(m, dtype=dtype).cuda() + # b = torch.randn(m, dtype=dtype).cuda() + b = None # can't use lambda - def dense_func(): + @torch.compile(mode="max-autotune") + def dense_func(x): return F.linear(x, A, b) - def sparse_func(): + @torch.compile(mode="max-autotune") + def sparse_func(x): return F.linear(x, A_sparse, b) elif args.eval_fn == "mm": @@ -101,20 +119,17 @@ def sparse_func(): else: raise ValueError(f"Unknown eval_fn: {args.eval_fn}") - dense_time = benchmark_model_with_warmup(dense_func, "dense.json.gz") - sparse_time = benchmark_model_with_warmup(sparse_func, "sparse.json.gz") - - dense_func_c = torch.compile(dense_func, mode="max-autotune") - dense_time_c = benchmark_model_with_warmup( - dense_func_c, "dense_compile.json.gz" - ) - - sparse_func_c = torch.compile(sparse_func, mode="max-autotune") - sparse_time_c = benchmark_model_with_warmup( - sparse_func_c, "sparse_compile.json.gz" - ) + # print(x) + # print(A) + # print(A_sparse.crow_indices()) + # print(A_sparse.col_indices()) + # print(A_sparse.values()) + dense_time, sparse_time = 0, 0 + dense_time_c, sparse_time_c = 1, 1 - torch._dynamo.reset() + #WARMUP + # dense_time_c = benchmark_torch_function_in_microseconds(dense_func, x) + sparse_time_c = benchmark_torch_function_in_microseconds(sparse_func, x) return { "test_function": args.eval_fn, @@ -126,8 +141,7 @@ def sparse_func(): "dense": dense_time, "dense_c": dense_time_c, "sparse_c": sparse_time_c, - "speedup (d/s)": min(dense_time, dense_time_c) - / min(sparse_time, sparse_time_c), + "speedup (d/s)": dense_time_c / sparse_time_c, } @@ -200,15 +214,17 @@ def sparse_func(): ) elif args.mode == "llama3-8b-w": mm_shapes = [ - (16, 4096, 11008), - (16, 4096, 4096), - (16, 11008, 4096), - (4096, 4096, 11008), - (4096, 4096, 4096), - (4096, 11008, 4096), - (8192, 4096, 11008), - (8192, 4096, 4096), - (8192, 11008, 4096), + # (32, 32, 16), + (4096, 14336, 1), + # (14336, 4096, 1), + # (11008, 4096, 16), + # (16, 4096, 4096), + # (4096, 4096, 11008), + # (4096, 4096, 4096), + # (4096, 11008, 4096), + # (8192, 4096, 11008), + # (8192, 4096, 4096), + # (8192, 11008, 4096), ] results = ( run_gpu_sparse_benchmark(m, k, n, args) for (m, k, n) in tqdm(mm_shapes) diff --git a/test/sparsity/test_supermask.py b/test/sparsity/test_supermask.py new file mode 100644 index 0000000000..233826163f --- /dev/null +++ b/test/sparsity/test_supermask.py @@ -0,0 +1,66 @@ +import copy +import logging +import unittest +import math + +import torch +from torch import nn +from torch.testing._internal import common_utils + +from torchao.dtypes import MarlinSparseLayout, SemiSparseLayout +from torchao.quantization.quant_api import ( + int4_weight_only, + int8_dynamic_activation_int8_weight, + quantize_, +) +from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_ +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_3, + TORCH_VERSION_AT_LEAST_2_4, + TORCH_VERSION_AT_LEAST_2_5, + TORCH_VERSION_AT_LEAST_2_6, +) + +logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO +) + +class TestSupermask(common_utils.TestCase): + + @common_utils.parametrize("sparsity_level", [0.25, 0.5]) + @common_utils.parametrize("blocksize", [2, 4, 8]) + def test_supermask(self, sparsity_level, blocksize): + input = torch.randn((1, 16)).half().cuda() + model = ( + nn.Sequential( + nn.Linear(16, 16, bias=False), + ) + .half() + .cuda() + .eval() + ) + + from torchao.sparsity import SupermaskLinear + + M, N = model[0].weight.shape + sparsify_(model, lambda x: SupermaskLinear.from_linear(x, sparsity_level=sparsity_level, blocksize=blocksize)) + sparsify_(model, SupermaskLinear.to_linear) + weight_bsr = model[0].weight.to_sparse_bsr(blocksize=blocksize) + + # Test correct sparsity level + nnz = weight_bsr._nnz() + expected = round((M // blocksize) * (N // blocksize) * (1 - sparsity_level)) + assert nnz == expected, f"Expected {expected} nonzeros, got {nnz}" + + def test_from_linear(self): + from torchao.sparsity import SupermaskLinear + linear = nn.Linear(128, 128) + supermask_linear = SupermaskLinear.from_linear(linear, sparsity_level=0.5, blocksize=4) + assert supermask_linear.weight.shape == linear.weight.shape + + +common_utils.instantiate_parametrized_tests(TestSupermask) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchao/_models/llama/benchmark_results.txt b/torchao/_models/llama/benchmark_results.txt index d59c5f552e..179d308ce4 100644 --- a/torchao/_models/llama/benchmark_results.txt +++ b/torchao/_models/llama/benchmark_results.txt @@ -50,3 +50,117 @@ OTHER BENCHMARKS 20240910010056, tok/s= 47.85, mem/s= 213.24 GB/s, peak_mem=11.85 GB, model_size= 4.46 GB quant: uintx-4-64, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization uintx-4-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240910010647, tok/s= 34.83, mem/s= 261.42 GB/s, peak_mem=14.99 GB, model_size= 7.51 GB quant: uintx-2-8, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization uintx-2-8 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240910110958, tok/s=223.95, mem/s= 682.88 GB/s, peak_mem= 5.59 GB, model_size= 3.05 GB quant: sparse-marlin, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --quantization sparse-marlin --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 + +20250115111811, tok/s=132.58, tok/s_decode=134.92, ttft=0.0256, mem/s=1989.99 GB/s, peak_mem=16.50 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250115111955, tok/s=132.39, tok/s_decode=134.90, ttft=0.0274, mem/s=1987.19 GB/s, peak_mem=16.50 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250115112851, tok/s=102.36, tok/s_decode=106.53, ttft=0.0759, mem/s= 499.36 GB/s, peak_mem=10.11 GB, model_size= 4.88 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250115113023, tok/s=132.40, tok/s_decode=134.92, ttft=0.0275, mem/s=1987.31 GB/s, peak_mem=16.50 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250115113154, tok/s=102.34, tok/s_decode=106.46, ttft=0.0748, mem/s= 499.29 GB/s, peak_mem=10.11 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250115114035, tok/s= 82.15, tok/s_decode=107.69, ttft=0.5768, mem/s=1233.05 GB/s, peak_mem=36.46 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8192--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250115114623, tok/s= 72.78, tok/s_decode= 88.50, ttft=0.4874, mem/s= 355.08 GB/s, peak_mem=18.27 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8192--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250115114936, tok/s=132.34, tok/s_decode=134.85, ttft=0.0274, mem/s=1986.47 GB/s, peak_mem=16.50 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250115115115, tok/s=102.81, tok/s_decode=106.89, ttft=0.0735, mem/s= 501.58 GB/s, peak_mem=10.11 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250115115406, tok/s=132.39, tok/s_decode=134.90, ttft=0.0274, mem/s=1987.10 GB/s, peak_mem=16.50 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250115115503, tok/s=132.41, tok/s_decode=134.91, ttft=0.0273, mem/s=1987.40 GB/s, peak_mem=16.50 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250115120048, tok/s=132.39, tok/s_decode=134.93, ttft=0.0277, mem/s=1987.15 GB/s, peak_mem=16.50 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250116123651, tok/s=129.31, tok/s_decode=134.38, ttft=0.0576, mem/s= 630.81 GB/s, peak_mem= 6.94 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250116124020, tok/s=110.09, tok/s_decode=132.55, ttft=0.0607, mem/s= 537.06 GB/s, peak_mem= 6.67 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --profile bsr_trace --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250116124956, tok/s=131.75, tok/s_decode=134.13, ttft=0.0263, mem/s=1977.55 GB/s, peak_mem=16.50 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250116130019, tok/s=130.31, tok/s_decode=134.85, ttft=0.0512, mem/s= 635.66 GB/s, peak_mem= 6.67 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250116130350, tok/s= 20.09, tok/s_decode= 20.32, ttft=0.1054, mem/s= 98.00 GB/s, peak_mem=16.97 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121085551, tok/s= 19.53, tok/s_decode= 19.75, ttft=0.1045, mem/s= 117.50 GB/s, peak_mem=16.97 GB, model_size= 6.02 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121090403, tok/s= 5.14, tok/s_decode= 5.17, ttft=0.1720, mem/s= 30.95 GB/s, peak_mem=27.73 GB, model_size= 6.02 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121090648, tok/s=132.21, tok/s_decode=134.58, ttft=0.0261, mem/s=1984.43 GB/s, peak_mem=16.50 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121090848, tok/s=132.12, tok/s_decode=134.62, ttft=0.0274, mem/s=1983.16 GB/s, peak_mem=16.50 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121091251, tok/s= 5.13, tok/s_decode= 5.16, ttft=0.1628, mem/s= 30.89 GB/s, peak_mem=27.73 GB, model_size= 6.02 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121091339, tok/s=121.71, tok/s_decode=134.38, ttft=0.0315, mem/s=1826.78 GB/s, peak_mem=16.50 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --profile baseline --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121091826, tok/s= 4.65, tok/s_decode= 5.17, ttft=0.1760, mem/s= 27.99 GB/s, peak_mem=27.73 GB, model_size= 6.02 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --profile bsr --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121092437, tok/s= 4.65, tok/s_decode= 5.16, ttft=0.1638, mem/s= 27.95 GB/s, peak_mem=27.73 GB, model_size= 6.02 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --profile bsr --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121093419, tok/s= 4.67, tok/s_decode= 5.17, ttft=0.1728, mem/s= 28.10 GB/s, peak_mem=27.73 GB, model_size= 6.02 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --profile bsr --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121093920, tok/s= 2.65, tok/s_decode= 5.14, ttft=0.5703, mem/s= 15.94 GB/s, peak_mem=27.73 GB, model_size= 6.02 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --profile bsr --num_samples 1 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121094143, tok/s= 2.66, tok/s_decode= 5.15, ttft=0.5759, mem/s= 16.03 GB/s, peak_mem=27.73 GB, model_size= 6.02 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --profile bsr --num_samples 1 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121100759, tok/s= 2.82, tok/s_decode= 5.14, ttft=0.5244, mem/s= 16.97 GB/s, peak_mem=27.73 GB, model_size= 6.02 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --profile bsr --num_samples 1 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121101108, tok/s= 2.85, tok/s_decode= 5.13, ttft=0.5582, mem/s= 17.15 GB/s, peak_mem=27.73 GB, model_size= 6.02 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --profile bsr --num_samples 1 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121101728, tok/s= 2.82, tok/s_decode= 5.14, ttft=0.5433, mem/s= 16.98 GB/s, peak_mem=27.73 GB, model_size= 6.02 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --profile bsr --num_samples 1 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121102340, tok/s= 81.98, tok/s_decode=107.42, ttft=0.5773, mem/s=1230.47 GB/s, peak_mem=36.46 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8192--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121102642, tok/s= 82.03, tok/s_decode=107.47, ttft=0.5765, mem/s=1231.23 GB/s, peak_mem=36.44 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8192--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121102757, tok/s= 82.08, tok/s_decode=107.51, ttft=0.5758, mem/s=1231.94 GB/s, peak_mem=36.19 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8192--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121102943, tok/s= 82.10, tok/s_decode=107.54, ttft=0.5757, mem/s=1232.24 GB/s, peak_mem=36.19 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8192--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121103057, tok/s= 82.05, tok/s_decode=107.53, ttft=0.5769, mem/s=1231.59 GB/s, peak_mem=36.19 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8192--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121103140, tok/s= 81.98, tok/s_decode=107.50, ttft=0.5785, mem/s=1230.47 GB/s, peak_mem=36.19 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8192--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121103512, tok/s= 82.09, tok/s_decode=107.54, ttft=0.5757, mem/s=1232.19 GB/s, peak_mem=36.19 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8192--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121104154, tok/s= 82.13, tok/s_decode=107.59, ttft=0.5755, mem/s=1232.79 GB/s, peak_mem=36.19 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8192--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121104406, tok/s=119.88, tok/s_decode=151.12, ttft=0.3441, mem/s= 584.77 GB/s, peak_mem=12.38 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8192--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121104646, tok/s= 82.06, tok/s_decode=107.51, ttft=0.5761, mem/s=1231.68 GB/s, peak_mem=36.19 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8192--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121104931, tok/s= 77.10, tok/s_decode=107.58, ttft=0.7870, mem/s=1157.20 GB/s, peak_mem=36.70 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8192--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121110040, tok/s= 82.08, tok/s_decode=107.53, ttft=0.5756, mem/s=1232.06 GB/s, peak_mem=36.19 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8192--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121110148, tok/s= 76.50, tok/s_decode=107.04, ttft=0.5778, mem/s=1148.24 GB/s, peak_mem=36.19 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8192--profile baseline_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121110258, tok/s=108.99, tok/s_decode=150.54, ttft=0.3432, mem/s= 531.67 GB/s, peak_mem=12.38 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8192--profile bsr_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121110904, tok/s=203.34, tok/s_decode=214.37, ttft=0.0499, mem/s= 991.92 GB/s, peak_mem= 6.67 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121111229, tok/s=182.26, tok/s_decode=214.52, ttft=0.0467, mem/s= 889.09 GB/s, peak_mem= 6.67 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --profile bsr --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121112735, tok/s=182.42, tok/s_decode=214.30, ttft=0.0495, mem/s= 889.89 GB/s, peak_mem= 6.67 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --profile bsr --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121113757, tok/s=182.86, tok/s_decode=214.41, ttft=0.0494, mem/s= 892.01 GB/s, peak_mem= 6.67 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --profile bsr --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121114610, tok/s=182.63, tok/s_decode=214.34, ttft=0.0503, mem/s= 890.88 GB/s, peak_mem= 6.67 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --profile bsr --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121122840, tok/s= 69.40, tok/s_decode= 70.52, ttft=0.0455, mem/s=1824.78 GB/s, peak_mem=27.82 GB, model_size=26.30 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121123616, tok/s=205.26, tok/s_decode=214.95, ttft=0.0434, mem/s=1001.28 GB/s, peak_mem= 6.58 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121124112, tok/s=204.65, tok/s_decode=214.91, ttft=0.0460, mem/s= 998.30 GB/s, peak_mem= 6.58 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121124437, tok/s=205.09, tok/s_decode=215.09, ttft=0.0448, mem/s=1000.48 GB/s, peak_mem= 6.58 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121135848, tok/s=123.54, tok/s_decode=134.43, ttft=0.0113, mem/s=1854.27 GB/s, peak_mem=16.24 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --profile baseline --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121135953, tok/s=182.45, tok/s_decode=214.15, ttft=0.0495, mem/s= 890.04 GB/s, peak_mem= 6.58 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --profile bsr --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121140550, tok/s=123.42, tok/s_decode=134.38, ttft=0.0119, mem/s=1852.55 GB/s, peak_mem=16.24 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --profile baseline --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250121140658, tok/s=182.52, tok/s_decode=214.21, ttft=0.0502, mem/s= 890.35 GB/s, peak_mem= 6.58 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --profile bsr --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250122135225, tok/s=123.34, tok/s_decode=134.41, ttft=0.0121, mem/s=1851.32 GB/s, peak_mem=16.24 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --profile baseline --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250122135328, tok/s=182.94, tok/s_decode=214.26, ttft=0.0487, mem/s= 892.41 GB/s, peak_mem= 6.58 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --profile bsr --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250122135614, tok/s= 11.63, tok/s_decode= 12.98, ttft=0.1701, mem/s= 56.74 GB/s, peak_mem=17.34 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --profile bsr --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250122154428, tok/s= 1.08, tok/s_decode= 1.18, ttft=0.1716, mem/s= 5.29 GB/s, peak_mem=17.34 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --profile bsr --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123101614, tok/s=170.05, tok/s_decode=214.41, ttft=0.0481, mem/s= 829.55 GB/s, peak_mem= 6.58 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --profile bsr --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123101846, tok/s=182.85, tok/s_decode=214.25, ttft=0.0474, mem/s= 891.99 GB/s, peak_mem= 6.58 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --profile bsr --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123113033, tok/s= 11.72, tok/s_decode= 13.29, ttft=0.3043, mem/s= 57.15 GB/s, peak_mem=17.34 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --profile bsr --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123113307, tok/s=179.71, tok/s_decode=213.42, ttft=0.0530, mem/s= 876.67 GB/s, peak_mem= 6.58 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --profile bsr --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123113418, tok/s=182.62, tok/s_decode=214.15, ttft=0.0490, mem/s= 890.83 GB/s, peak_mem= 6.58 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --profile bsr --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123113647, tok/s=182.27, tok/s_decode=214.18, ttft=0.0488, mem/s= 889.15 GB/s, peak_mem= 6.58 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --profile bsr_padded_trition.json.gz --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123114432, tok/s=182.23, tok/s_decode=217.09, ttft=0.0581, mem/s= 888.94 GB/s, peak_mem= 6.58 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --profile bsr_padded_trition --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123123324, tok/s=186.08, tok/s_decode=217.45, ttft=0.0475, mem/s= 907.74 GB/s, peak_mem= 6.58 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --profile bsr_padded_trition --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123135016, tok/s=186.23, tok/s_decode=217.48, ttft=0.0468, mem/s= 908.45 GB/s, peak_mem= 6.58 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --profile bsr_padded_trition --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123145640, tok/s=185.20, tok/s_decode=216.60, ttft=0.0494, mem/s= 903.44 GB/s, peak_mem= 6.58 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --profile bsr_padded_trition --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123145919, tok/s=185.33, tok/s_decode=217.23, ttft=0.0493, mem/s= 904.08 GB/s, peak_mem= 6.58 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --profile bsr_padded_trition --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123150712, tok/s= 77.17, tok/s_decode=109.19, ttft=0.5785, mem/s=1158.24 GB/s, peak_mem=36.46 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8192--profile baseline_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123151351, tok/s= 77.72, tok/s_decode=109.31, ttft=0.5766, mem/s=1166.50 GB/s, peak_mem=36.44 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8192--profile baseline_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123151839, tok/s=102.64, tok/s_decode=154.70, ttft=0.4758, mem/s= 500.68 GB/s, peak_mem=17.94 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8192--profile bsr_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123152218, tok/s= 77.85, tok/s_decode=109.37, ttft=0.5770, mem/s=1168.55 GB/s, peak_mem=36.19 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8192--profile baseline_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123152330, tok/s=102.95, tok/s_decode=154.99, ttft=0.4876, mem/s= 502.20 GB/s, peak_mem=17.67 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8192--profile bsr_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123152615, tok/s=102.68, tok/s_decode=154.82, ttft=0.4879, mem/s= 500.90 GB/s, peak_mem=17.67 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8192--profile bsr_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123153256, tok/s=100.57, tok/s_decode=151.81, ttft=0.4890, mem/s= 490.60 GB/s, peak_mem=17.94 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8192--profile bsr_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123154843, tok/s=101.30, tok/s_decode=152.23, ttft=0.4892, mem/s= 494.15 GB/s, peak_mem=17.92 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8192--profile bsr_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123155937, tok/s=101.24, tok/s_decode=152.16, ttft=0.4889, mem/s= 493.86 GB/s, peak_mem=17.92 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8192--profile bsr_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123160302, tok/s=122.71, tok/s_decode=134.11, ttft=0.0120, mem/s=1841.91 GB/s, peak_mem=16.50 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --profile baseline --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123160508, tok/s=123.29, tok/s_decode=134.19, ttft=0.0116, mem/s=1850.63 GB/s, peak_mem=16.24 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --profile baseline --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123160810, tok/s=123.16, tok/s_decode=134.16, ttft=0.0118, mem/s=1848.66 GB/s, peak_mem=16.24 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --profile baseline --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123161148, tok/s=185.76, tok/s_decode=217.48, ttft=0.0502, mem/s= 906.15 GB/s, peak_mem= 6.58 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --profile bsr_padded_trition --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123161835, tok/s=123.16, tok/s_decode=134.13, ttft=0.0118, mem/s=1848.54 GB/s, peak_mem=16.24 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --profile baseline --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123162259, tok/s=123.05, tok/s_decode=134.13, ttft=0.0122, mem/s=1846.98 GB/s, peak_mem=16.24 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --profile baseline --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123162406, tok/s=186.18, tok/s_decode=217.73, ttft=0.0470, mem/s= 908.22 GB/s, peak_mem= 6.58 GB, model_size= 4.88 GB quant: None, sparse: bsr, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --profile bsr_padded_trition --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123162940, tok/s=133.31, tok/s_decode=134.35, ttft=0.0112, mem/s=2000.93 GB/s, peak_mem=16.24 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123163117, tok/s=133.26, tok/s_decode=134.39, ttft=0.0120, mem/s=2000.18 GB/s, peak_mem=16.24 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123163224, tok/s=133.28, tok/s_decode=134.39, ttft=0.0117, mem/s=2000.52 GB/s, peak_mem=16.24 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123163331, tok/s=207.77, tok/s_decode=218.34, ttft=0.0459, mem/s=1013.55 GB/s, peak_mem= 6.58 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123163555, tok/s=179.75, tok/s_decode=187.99, ttft=0.0481, mem/s= 879.72 GB/s, peak_mem= 6.32 GB, model_size= 4.89 GB quant: None, sparse: bsr-0.9-32, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-32 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123165038, tok/s=207.95, tok/s_decode=218.24, ttft=0.0447, mem/s=1014.41 GB/s, peak_mem= 6.31 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123170322, tok/s=208.20, tok/s_decode=218.38, ttft=0.0442, mem/s=1015.65 GB/s, peak_mem= 6.31 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123171256, tok/s=208.58, tok/s_decode=218.48, ttft=0.0428, mem/s=1017.47 GB/s, peak_mem= 6.31 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123172543, tok/s=146.94, tok/s_decode=149.85, ttft=0.0259, mem/s=1941.80 GB/s, peak_mem=13.94 GB, model_size=13.21 GB quant: None, sparse: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123173042, tok/s=207.86, tok/s_decode=218.47, ttft=0.0461, mem/s=1013.96 GB/s, peak_mem= 6.31 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123173713, tok/s=208.45, tok/s_decode=218.38, ttft=0.0430, mem/s=1016.85 GB/s, peak_mem= 6.31 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123183901, tok/s=207.95, tok/s_decode=218.33, ttft=0.0450, mem/s=1014.43 GB/s, peak_mem= 6.31 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123184904, tok/s=146.81, tok/s_decode=149.91, ttft=0.0275, mem/s=1940.08 GB/s, peak_mem=13.92 GB, model_size=13.21 GB quant: None, sparse: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123184942, tok/s= 63.00, tok/s_decode= 68.67, ttft=0.2616, mem/s= 417.12 GB/s, peak_mem= 9.16 GB, model_size= 6.62 GB quant: int8dq, sparse: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123185104, tok/s=207.53, tok/s_decode=218.47, ttft=0.0475, mem/s=1012.36 GB/s, peak_mem= 6.31 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250123185604, tok/s=208.35, tok/s_decode=218.59, ttft=0.0444, mem/s=1016.38 GB/s, peak_mem= 6.31 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250124120551, tok/s=148.85, tok/s_decode=157.66, ttft=0.0748, mem/s= 726.10 GB/s, peak_mem= 6.58 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250124121349, tok/s= 93.00, tok/s_decode= 93.67, ttft=0.0150, mem/s=1395.96 GB/s, peak_mem=16.47 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250124121533, tok/s=149.71, tok/s_decode=157.95, ttft=0.0695, mem/s= 730.29 GB/s, peak_mem= 6.58 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250124124548, tok/s= 48.92, tok/s_decode= 70.49, ttft=1.2505, mem/s= 734.29 GB/s, peak_mem=36.45 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8192--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250124124720, tok/s= 48.95, tok/s_decode= 70.50, ttft=1.2485, mem/s= 734.75 GB/s, peak_mem=36.70 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8192--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250124125113, tok/s= 48.87, tok/s_decode= 70.78, ttft=1.2673, mem/s= 733.50 GB/s, peak_mem=36.70 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8192--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250124125909, tok/s= 67.03, tok/s_decode= 99.25, ttft=0.9682, mem/s= 326.99 GB/s, peak_mem=18.15 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8192--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250124152728, tok/s=149.00, tok/s_decode=157.80, ttft=0.0745, mem/s= 726.43 GB/s, peak_mem= 6.67 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 \ No newline at end of file diff --git a/torchao/_models/llama/bsr_benchmarks.sh b/torchao/_models/llama/bsr_benchmarks.sh new file mode 100644 index 0000000000..5e0228b6c0 --- /dev/null +++ b/torchao/_models/llama/bsr_benchmarks.sh @@ -0,0 +1,8 @@ +export CHECKPOINT_PATH=../../../checkpoints # path to checkpoints folder +export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B + +#python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt --prefill_size 8192 +#python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt --prefill_size 8192 --sparsity bsr-0.9-64 +#python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt --sparsity bsr-0.9-64 +#python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt --sparsity bsr-0.9-32 diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index b1d3475601..65a35f2fb3 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -794,9 +794,35 @@ def ffn_or_attn_only(mod, fqn): from torchao.sparsity import semi_sparse_weight, sparsify_ if "semi" in sparsity: - # TODO there is a bug here, need to fix + # Fixed sparsity level for 2:4 sparsify_(model.to(device), semi_sparse_weight(), filter_fn=ffn_only) + if "bsr" in sparsity: + from torchao.sparsity import SupermaskLinear, block_sparse_weight + # parse "bsr-0.9-64" + _, sparsity_level, blocksize = sparsity.split("-") + sparsity_level, blocksize = float(sparsity_level), int(blocksize) + sparsify_( + model, + lambda x: SupermaskLinear.from_linear(x, + sparsity_level=sparsity_level, + blocksize=blocksize, + ), + filter_fn=ffn_only, + ) + print(model) + sparsify_( + model, + SupermaskLinear.to_linear, + filter_fn=ffn_only, + ) + print(model) + + # Accelerate with triton bsr kernels + sparsify_(model, + block_sparse_weight(blocksize=blocksize), + filter_fn=ffn_only) + model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9 if save: @@ -811,7 +837,7 @@ def ffn_or_attn_only(mod, fqn): print("Compiling Model") global decode_one_token, prefill decode_one_token = torch.compile( - decode_one_token, mode="reduce-overhead", fullgraph=True + decode_one_token, mode="reduce-overhead", fullgraph=True, dynamic=True, ) if compile_prefill: @@ -850,7 +876,7 @@ def ffn_or_attn_only(mod, fqn): prompt = f"{B_INST} {prompt.strip()} {E_INST}" encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) - if interactive and i >= 0: + if interactive and i >= 0 and prefill_size is None: buffer = [] period_id = tokenizer.encode(".")[0] done_generating = False @@ -920,7 +946,7 @@ def callback(x): device_sync(device=device) # MKG t = time.perf_counter() - t0 - if not interactive and demo_summarize_prompt is None: + if not interactive and demo_summarize_prompt is None and prefill_size is None: tok_list = y[0].tolist() # truncate text after end of string token tokens = ( diff --git a/torchao/kernel/__init__.py b/torchao/kernel/__init__.py index 409da72601..2006d6c403 100644 --- a/torchao/kernel/__init__.py +++ b/torchao/kernel/__init__.py @@ -1,6 +1,9 @@ from torchao.kernel.intmm import int_scaled_matmul, safe_int_mm +from torchao.kernel.bsr_triton_ops import bsr_dense_addmm, broadcast_batch_dims __all__ = [ + "bsr_dense_addmm", + "broadcast_batch_dims" "safe_int_mm", "int_scaled_matmul", ] diff --git a/torchao/kernel/bsr_triton_ops.py b/torchao/kernel/bsr_triton_ops.py new file mode 100644 index 0000000000..f6f28ee4a0 --- /dev/null +++ b/torchao/kernel/bsr_triton_ops.py @@ -0,0 +1,857 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import math +import os +import weakref +from functools import lru_cache +from typing import Optional + +import torch +from torch._dynamo.utils import warn_once +from torch.utils._triton import has_triton + +from torch.sparse._triton_ops_meta import get_meta + + + +TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE = int( + os.getenv("TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE", 2) +) + + +def check(cond, msg): + if not cond: + raise ValueError(msg) + + +def check_bsr_layout(f_name, t): + check( + t.layout == torch.sparse_bsr, + f"{f_name}(): only BSR sparse format is supported for the sparse argument.", + ) + + +def check_device(f_name, t, device): + check( + t.device == device and t.device.type == "cuda", + f"{f_name}(): all inputs are expected to be on the same GPU device.", + ) + + +def check_mm_compatible_shapes(f_name, lhs, rhs): + check( + lhs.dim() >= 2 and rhs.dim() >= 2, + f"{f_name}(): all inputs involved in the matrix product are expected to be at least 2D, " + f"but got lhs.dim() == {lhs.dim()} and rhs.dim() == {rhs.dim()}.", + ) + + _m, kl = lhs.shape[-2:] + kr, _n = rhs.shape[-2:] + + check( + kl == kr, + f"{f_name}(): arguments' sizes involved in the matrix product are not compatible for matrix multiplication, " + f"got lhs.shape[-1] == {kl} which is not equal to rhs.shape[-2] == {kr}.", + ) + + +def check_dtype(f_name, t, dtype, *additional_dtypes): + check( + t.dtype == dtype + and t.dtype + in ((torch.half, torch.bfloat16, torch.float) + tuple(*additional_dtypes)), + f"{f_name}(): all inputs are expected to be of the same dtype " + f"and one of (half, bfloat16, float32) or {additional_dtypes}, " + f"but got dtype == {t.dtype}.", + ) + + +def check_blocksize(f_name, blocksize): + assert len(blocksize) == 2 + + def is_power_of_two(v): + return not (v & (v - 1)) + + def is_compatible_blocksize(b): + res = True + for blocksize in b: + # Triton loads only blocks which are at least 16 and powers of 2. + res = (blocksize >= 16 and is_power_of_two(blocksize)) and res + return res + + check( + is_compatible_blocksize(blocksize), + f"{f_name}(): sparse inputs' blocksize ({blocksize[0]}, {blocksize[1]}) " + "should be at least 16 and a power of 2 in each dimension.", + ) + + +def make_triton_contiguous(t): + """Return input as a triton-contiguous tensor. + + A triton-contiguous tensor is defined as a tensor that has strides + with minimal value smaller than or equal to 1. + + While triton kernels support triton-non-contiguous tensors (all + strides being greater than 1) arguments, a considerable slow-down + occurs because tensor data is copied element-wise rather than + chunk-wise. Zero strides is assumed to not have this defect. + """ + if min(t.stride()) > 1: + # TODO: investigate if contiguity along other axes than the + # last one can be beneficial for performance + return t.contiguous() + else: + return t + + +def broadcast_batch_dims(f_name, *tensors): + try: + return torch.broadcast_shapes(*(t.shape[:-2] for t in tensors)) + except Exception: + check(False, f"{f_name}(): inputs' batch dimensions are not broadcastable!") + + +def slicer(dim, slice_range, *tensors): + for t in tensors: + slices = [slice(None)] * t.dim() + slices[dim] = slice_range + yield t[slices] + + +def multidim_slicer(dims, slices, *tensors): + for t in tensors: + s = [slice(None)] * t.dim() + for d, d_slice in zip(dims, slices): + if d is not None: + s[d] = d_slice + yield t[s] + + +def ptr_stride_extractor(*tensors): + for t in tensors: + yield t + yield from t.stride() + + +def grid_partitioner(full_grid, grid_blocks, tensor_dims_map): + assert 0 <= len(full_grid) <= 3 + assert 0 <= len(grid_blocks) <= 3 + + import itertools + + def generate_grid_points(): + for fg, mg in zip(full_grid, grid_blocks): + yield range(0, fg, mg) + + def generate_sliced_tensors(slices): + for t, t_dims in tensor_dims_map.items(): + yield next(multidim_slicer(t_dims, slices, t)) + + for grid_point in itertools.product(*generate_grid_points()): + grid = [ + min(fg - gp, mg) for fg, gp, mg in zip(full_grid, grid_point, grid_blocks) + ] + slices = [slice(gp, gp + g) for gp, g in zip(grid_point, grid)] + # grid_points are iterated in a "contiguous" order, i.e. + # left dimensions traversed slower than right dimensions. + # This order is reversed for CUDA grids. + yield grid[::-1], *generate_sliced_tensors(slices) + + +def launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks=None): + # cuda_max_grid = (2 ** 31 - 1, 2 ** 16 - 1, 2 ** 16 - 1) + cuda_max_grid = (2147483647, 65535, 65535)[::-1] + if grid_blocks is None: + grid_blocks = cuda_max_grid + else: + + def valid_grid_dim(g, mg): + if g is None: + return mg + else: + # grid must be at least 1 and no greater than mg + return max(1, min(g, mg)) + + grid_blocks = tuple( + valid_grid_dim(g, mg) for g, mg in zip(grid_blocks, cuda_max_grid) + ) # type: ignore[assignment] + + for grid, *sliced_tensors in grid_partitioner( + full_grid, grid_blocks, tensor_dims_map + ): + kernel(grid, *sliced_tensors) + + +def prepare_inputs(bsr, *dense_tensors): + # Introduce fake batch dimension if not present for convenience. + crow_indices = bsr.crow_indices().unsqueeze(0) + col_indices = bsr.col_indices().unsqueeze(0) + values = make_triton_contiguous(bsr.values().unsqueeze(0)) + tensors = [make_triton_contiguous(t.unsqueeze(0)) for t in dense_tensors] + + # Compute broadcasted batch dimension + batch_dims_broadcasted = torch.broadcast_shapes( + values.shape[:-3], *(t.shape[:-2] for t in tensors) + ) + + # Broadcast batch dimensions and squash. + # The result can be either a view or a copy. + def batch_broadcast_and_squash(t, batch_dims, invariant_dims): + return t.broadcast_to(batch_dims + invariant_dims).flatten( + 0, len(batch_dims) - 1 + ) + + crow_indices = batch_broadcast_and_squash( + crow_indices, batch_dims_broadcasted, (-1,) + ) + + col_indices = batch_broadcast_and_squash(col_indices, batch_dims_broadcasted, (-1,)) + values = batch_broadcast_and_squash( + values, batch_dims_broadcasted, values.shape[-3:] + ) + tensors = [ + batch_broadcast_and_squash(t, batch_dims_broadcasted, t.shape[-2:]) + for t in tensors + ] + + return crow_indices, col_indices, values, *tensors + + +def broadcast_batch_dims_bsr(f_name, bsr, *tensors): + batch_shape = broadcast_batch_dims(f_name, bsr, *tensors) + + crow_indices = bsr.crow_indices().broadcast_to(batch_shape + (-1,)) + col_indices = bsr.col_indices().broadcast_to(batch_shape + (-1,)) + values = bsr.values().broadcast_to(batch_shape + bsr.values().shape[-3:]) + size = batch_shape + bsr.shape[-2:] + return torch.sparse_compressed_tensor( + crow_indices, col_indices, values, size=size, layout=bsr.layout + ) + + +# NOTE: this function will ALWAYS create a view +def tile_to_blocksize(t, blocksize): + *rest, m, n = t.shape + new_shape = rest + [ + m // blocksize[0], + blocksize[0], + n // blocksize[1], + blocksize[1], + ] + # using .view instead of .reshape to ensure that the result is + # indeed a view: + return t.view(new_shape).transpose(-3, -2) + + +def as1Dbatch(tensor): + """Return tensor as 3D tensor by either prepending new dimensions to + the tensor shape (when ``tensor.ndim < 3``), or by collapsing + starting dimensions into the first dimension (when ``tensor.ndim > + 3``). + """ + while tensor.ndim < 3: + tensor = tensor.unsqueeze(0) + if tensor.ndim > 3: + tensor = tensor.flatten(0, tensor.ndim - 3) + assert tensor.ndim == 3, tensor.shape + return tensor + + +## addmm functionality + +def bsr_dense_addmm_meta( + M, + K, + N, + Ms, + Ks, + beta, + alpha, + SPLIT_N=None, + GROUP_SIZE_ROW=None, + num_warps=None, + num_stages=None, + sparsity=None, + dtype=None, + out_dtype=None, + _version=0, + **extra, +): + # Specifying _version is useful for situations when one wants to + # discard existing triton kernel tuning results, say, in testing + # bsr_dense_addmm_meta functionality. + if dtype is None: + dtype = torch.float16 + if out_dtype is None: + out_dtype = dtype + if sparsity is None: + sparsity = 0.5 + if {SPLIT_N, num_warps, num_stages, GROUP_SIZE_ROW} == {None}: + device_name = torch.cuda.get_device_name() + key = (M, K, N, Ms, Ks, beta == 0, beta == 1, alpha == 1) + if dtype is out_dtype: + version_dtype = dtype + else: + version_dtype = dtype, out_dtype + meta = get_meta( + "bsr_dense_addmm", + key, + device_name, + version=(_version, version_dtype, sparsity), + ) + if meta is None and sparsity != 0.5: + meta = get_meta( + "bsr_dense_addmm", + key, + device_name, + version=(_version, version_dtype, 0.5), + ) + if meta is None and dtype is not out_dtype: + meta = get_meta( + "bsr_dense_addmm", key, device_name, version=(_version, dtype, 0.5) + ) + if meta is None: + # find approximate meta such that N % SPLIT_N == 0. + matching_meta = get_meta( + "bsr_dense_addmm", + (*key[:2], "*", *key[3:]), + device_name, + version=(_version, version_dtype, 0.5), + ) + if matching_meta is None and dtype is not out_dtype: + matching_meta = get_meta( + "bsr_dense_addmm", + (*key[:2], "*", *key[3:]), + device_name, + version=(_version, dtype, 0.5), + ) + for mkey in sorted(matching_meta or {}): + meta_ = matching_meta[mkey] + n = mkey[2] + split_n = meta_["SPLIT_N"] + c = n // split_n + if N % c == 0 and n <= N: + meta = dict(meta_) + meta["SPLIT_N"] = N // c + if meta is not None: + meta.update(**extra) + return meta + else: + # see [Computing optimal kernel parameters] in + # _triton_ops_meta.py for ways to avoid this warning + # optimize_bsr_dense_addmm( + # M, + # K, + # 16, + # 64, + # 64, + # beta=beta, + # alpha=alpha, + # sparsity=sparsity, + # dtype=dtype, + # opname="bsr_dense_addmm", + # verbose=True, + # ) + # get padded key + padded_key = (M, K, 16, Ms, Ks, beta == 0, beta == 1, alpha == 1) + meta = get_meta( + "bsr_dense_addmm", + padded_key, + device_name, + version=(_version, version_dtype, sparsity), + ) + # breakpoint() + # return meta + # message + # breakpoint() + # warn_once( + # "bsr_dense_addmm uses non-optimal triton kernel parameters" + # f" for {M=} {K=} {N=} {Ms=}, {Ks=} {beta=} {alpha=} {dtype=} {out_dtype=}" + # ) + + SPLIT_N = SPLIT_N or max(N // Ms, 1) + GROUP_SIZE_ROW = GROUP_SIZE_ROW or 4 + num_stages = num_stages or 1 + num_warps = num_warps or 4 + return dict( + SPLIT_N=SPLIT_N, + GROUP_SIZE_ROW=GROUP_SIZE_ROW, + num_stages=num_stages, + num_warps=num_warps, + **extra, + ) + + +class TensorAsKey: + """A light-weight wrapper of a tensor that enables storing tensors as + keys with efficient memory reference based comparision as an + approximation to data equality based keys. + + Motivation: the hash value of a torch tensor is tensor instance + based that does not use data equality and makes the usage of + tensors as keys less useful. For instance, the result of + ``len({a.crow_indices(), a.crow_indices()})`` is `2`, although, + the tensor results from `crow_indices` method call are equal, in + fact, these share the same data storage. + On the other hand, for efficient caching of tensors we want to + avoid calling torch.equal that compares tensors item-wise. + + TensorAsKey offers a compromise in that it guarantees key equality + of tensors that references data in the same storage in the same + manner and without accessing underlying data. However, this + approach does not always guarantee correctness. For instance, for + a complex tensor ``x``, we have ``TensorAsKey(x) == + TensorAsKey(x.conj())`` while ``torch.equal(x, x.conj())`` would + return False. + """ + + def __init__(self, obj): + def get_tensor_key(obj): + # Warning: TensorAsKey does not track negative nor + # conjugate bits of its input object because in the use + # case of wrapping compressed/plain indices of compressed + # sparse tensors (that are always integer tensors with + # non-negative items) these bits are never set. However, + # when extending the use of TensorAsKey to float or + # complex tensors, the values of these bits (see is_neg + # and is_conj methods) must be included in the key as + # well. + assert not (obj.dtype.is_floating_point or obj.dtype.is_complex), obj.dtype + return ( + obj.data_ptr(), + obj.storage_offset(), + obj.shape, + obj.stride(), + obj.dtype, + ) + + self._obj_ref = weakref.ref(obj) + if obj.layout is torch.strided: + self.key = get_tensor_key(obj) + elif obj.layout in {torch.sparse_csr, torch.sparse_bsr}: + self.key = ( + get_tensor_key(obj.crow_indices()), + get_tensor_key(obj.col_indices()), + ) + elif obj.layout in {torch.sparse_csc, torch.sparse_bsc}: + self.key = ( + get_tensor_key(obj.ccol_indices()), + get_tensor_key(obj.row_indices()), + ) + else: + raise NotImplementedError(obj.layout) + self._hash = hash(self.key) + + def __hash__(self): + return self._hash + + def __eq__(self, other): + if not isinstance(other, TensorAsKey): + return False + if self.obj is None or other.obj is None: + # dead objects always compare unequal unless these are + # same objects + return self is other + return self.key == other.key + + @property + def obj(self): + """Return object if alive, otherwise None.""" + return self._obj_ref() + + +def _int_bsr_dense_addmm( + input: torch.Tensor, + bsr: torch.Tensor, + dense: torch.Tensor, + *, + beta=1, + alpha=1, + left_alpha: Optional[torch.Tensor] = None, + right_alpha: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + skip_checks: bool = False, + max_grid: Optional[tuple[Optional[int], Optional[int], Optional[int]]] = None, + meta: Optional[dict] = None, +): + if out is None and dense.dtype is torch.int8: + f_name = "_int_bsr_dense_addmm" + crow_indices = bsr.crow_indices() + batch_ndim = crow_indices.dim() - 1 + M = bsr.shape[batch_ndim] + N = dense.shape[-1] + original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense) + out = torch.empty( + original_batch_dims_broadcasted + (M, N), + dtype=torch.int32, + device=dense.device, + ) + return bsr_dense_addmm( + input, + bsr, + dense, + beta=beta, + alpha=alpha, + left_alpha=left_alpha, + right_alpha=right_alpha, + out=out, + skip_checks=skip_checks, + max_grid=max_grid, + meta=meta, + ) + + +def bsr_dense_addmm( + input: torch.Tensor, + bsr: torch.Tensor, + row_indices: torch.Tensor, + dense: torch.Tensor, + *, + beta=1, + alpha=1, + left_alpha: Optional[torch.Tensor] = None, + right_alpha: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + skip_checks: bool = False, + max_grid: Optional[tuple[Optional[int], Optional[int], Optional[int]]] = None, + meta: Optional[dict] = None, +): + """Compute + + out = beta * input + left_alpha.reshape(-1, 1) * (alpha * (bsr @ dense)) * right_alpha.reshape(1, -1) + + where left_alpha, right_alpha are (* + 1)-D tensors when + specified, otherwise, these are treated as tensors filled with + ones. + """ + f_name = "bsr_dense_addmm" + values = bsr.values() + crow_indices = bsr.crow_indices() + col_indices = bsr.col_indices() + batch_ndim = crow_indices.dim() - 1 + M, K = bsr.shape[batch_ndim : batch_ndim + 2] + blocksize = values.shape[batch_ndim + 1 : batch_ndim + 3] + N = dense.shape[-1] + + # todo: implement checks + + original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense) + if out is None: + out = dense.new_empty(original_batch_dims_broadcasted + (M, N)) + + if bsr._nnz() == 0 or alpha == 0 or N == 0 or M == 0 or K == 0: + if beta == 0: + out.zero_() + else: + out.copy_(input) + if beta != 1: + out.mul_(beta) + return out + + left_alpha_is_one = False + right_alpha_is_one = False + if left_alpha is None: + left_alpha_is_one = True + left_alpha = dense.new_empty(()).expand( + *original_batch_dims_broadcasted, M, N + ) # not referenced + else: + left_alpha = left_alpha.view(*original_batch_dims_broadcasted, M, 1).expand( + *original_batch_dims_broadcasted, M, N + ) + + if right_alpha is None: + right_alpha_is_one = True + right_alpha = dense.new_empty(()).expand( + *original_batch_dims_broadcasted, M, N + ) # not referenced + else: + right_alpha = right_alpha.view(*original_batch_dims_broadcasted, 1, N).expand( + *original_batch_dims_broadcasted, M, N + ) + assert left_alpha.stride()[-1] == 0 + assert right_alpha.stride()[-2] == 0 + + if meta is None: + sparsity = round(1 - bsr._nnz() * blocksize[0] * blocksize[1] / (M * K), 2) + meta = bsr_dense_addmm_meta( + M, + K, + N, + blocksize[0], + blocksize[1], + beta, + alpha, + sparsity=sparsity, + dtype=dense.dtype, + out_dtype=out.dtype, + ) + out_backup = out + + ( + crow_indices, + col_indices, + values, + input, + dense, + left_alpha, + right_alpha, + out, + ) = prepare_inputs(bsr, input, dense, left_alpha, right_alpha, out) + + BM, BK = blocksize + SPLIT_N = meta.get("SPLIT_N", N // BM) + BN = N // SPLIT_N + + out_untiled = out + out = tile_to_blocksize(out, (BM, BN)) + dense = tile_to_blocksize(dense, (BK, BN)) + input = tile_to_blocksize(input, (BM, BN)) + left_alpha = tile_to_blocksize(left_alpha, (BM, BN)) + right_alpha = tile_to_blocksize(right_alpha, (BM, BN)) + + # tl.dot supports float16, float32, int32 as accumulator types. + dot_out_dtype = { + torch.float16: tl.float32, + torch.bfloat16: tl.float32, + torch.float32: tl.float64, + torch.float64: tl.float64, + torch.int8: tl.int32, + torch.int32: tl.int32, + }[out.dtype] + + n_batches = dense.size(0) + n_block_rows = crow_indices.size(-1) - 1 + n_block_cols = dense.size(-3) + + full_grid = (n_batches, n_block_cols, n_block_rows) + if max_grid is not None: + grid_blocks = tuple(max_grid[:3][::-1]) + (None,) * (3 - len(max_grid[:3])) + else: + grid_blocks = None + + tensor_dims_map = { + values: (0, None, None), + crow_indices: (0, None, -1), + col_indices: (0, None, None), + input: (0, -3, -4), + dense: (0, -3, None), + left_alpha: (0, -3, -4), + right_alpha: (0, -3, -4), + out: (0, -3, -4), + } + + assert alpha != 0 + + def kernel(grid, *sliced_tensors): + _bsr_strided_addmm_kernel[grid]( + *ptr_stride_extractor(*sliced_tensors), + beta, + alpha, + beta_is_one=beta == 1, + beta_is_nonzero=beta != 0, + alpha_is_one=alpha == 1, + left_alpha_is_one=left_alpha_is_one, + right_alpha_is_one=right_alpha_is_one, + BLOCKSIZE_ROW=BM, + BLOCKSIZE_INNER=BK, + BLOCKSIZE_COL=BN, + BLOCKSIZE_K=32, + allow_tf32=dot_out_dtype == tl.float32, + acc_dtype=dot_out_dtype, + **meta, + ) + + launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks) + + if out.data_ptr() != out_backup.data_ptr(): + # prepare_inputs has made a copy of out, copy its content back + # to out_backup: + out_backup.copy_(out_untiled.view(out_backup.shape)) + + return out_backup + + +if has_triton(): + import triton + import triton.language as tl + + @triton.jit + def _bsr_strided_addmm_kernel( + # values prologue + values_ptr, + values_batch_stride, + values_nnz_stride, + values_row_block_stride, + values_col_block_stride, + # values epilogue + # crow_indices prologue + crow_indices_ptr, + crow_indices_batch_stride, + crow_indices_stride, + # crow_indices epilogue + # col_indices prologue + col_indices_ptr, + col_indices_batch_stride, + col_indices_stride, + # col_indices epilogue + # input prologue + input_ptr, + input_batch_stride, + input_tiled_row_stride, + input_tiled_col_stride, + input_row_block_stride, + input_col_block_stride, + # input epilogue + # dense prologue + dense_ptr, + dense_batch_stride, + dense_tiled_row_stride, + dense_tiled_col_stride, + dense_row_block_stride, + dense_col_block_stride, + # dense epilogue + # left_alpha prologue + left_alpha_ptr, + left_alpha_batch_stride, + left_alpha_tiled_row_stride, + left_alpha_tiled_col_stride: tl.constexpr, + left_alpha_row_block_stride, + left_alpha_col_block_stride: tl.constexpr, + # left_alpha epilogue + # right_alpha prologue + right_alpha_ptr, + right_alpha_batch_stride, + right_alpha_tiled_row_stride: tl.constexpr, + right_alpha_tiled_col_stride, + right_alpha_row_block_stride: tl.constexpr, + right_alpha_col_block_stride, + # right_alpha epilogue + # output prologue + output_ptr, + output_batch_stride, + output_tiled_row_stride, + output_tiled_col_stride, + output_row_block_stride, + output_col_block_stride, + # output epilogue + beta, + alpha, + beta_is_one: tl.constexpr, + beta_is_nonzero: tl.constexpr, + alpha_is_one: tl.constexpr, + left_alpha_is_one: tl.constexpr, + right_alpha_is_one: tl.constexpr, + BLOCKSIZE_ROW: tl.constexpr, + BLOCKSIZE_COL: tl.constexpr, + BLOCKSIZE_INNER: tl.constexpr, + BLOCKSIZE_K: tl.constexpr, + acc_dtype: tl.constexpr, + allow_tf32: tl.constexpr, + GROUP_SIZE_ROW: tl.constexpr, + SPLIT_N: tl.constexpr, + ): + # left/right_alpha tensors are originally (* + 1)-dimensional + assert left_alpha_tiled_col_stride == 0 + assert left_alpha_col_block_stride == 0 + assert right_alpha_tiled_row_stride == 0 + assert right_alpha_row_block_stride == 0 + + batch_pid = tl.program_id(axis=2) + row_block_pid = tl.program_id(axis=0) + col_block_pid = tl.program_id(axis=1) + n_block_rows = tl.num_programs(axis=0) + n_block_cols = tl.num_programs(axis=1) + + row_block_pid, col_block_pid = tl.swizzle2d( + row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW + ) + + crow_indices_offset_ptr = ( + crow_indices_ptr + + crow_indices_batch_stride * batch_pid + + crow_indices_stride * row_block_pid + ) + nnz_offset = tl.load(crow_indices_offset_ptr) + nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride) + + # Compute nnz for the row with number row_block_pid. + row_nnz = nnz_offset_next - nnz_offset + + row_block_arange = tl.arange(0, BLOCKSIZE_ROW) + inner_block_arange = tl.arange(0, BLOCKSIZE_INNER) + + PADDED_BLOCKSIZE_COL : tl.constexpr = 16 + # if BLOCKSIZE_COL < 16 or BLOCKSIZE_COL % 16 != 0: + # else: + # PADDED_BLOCKSIZE_COL: tl.constexpr = BLOCKSIZE_COL + + col_block_arange = tl.arange(0, PADDED_BLOCKSIZE_COL) + + # Pointers are set to the first block of the current row. + values_block_ptrs = ( + values_ptr + + values_batch_stride * batch_pid + + values_nnz_stride * nnz_offset + + values_row_block_stride * row_block_arange[:, None] + + values_col_block_stride * inner_block_arange[None, :] + ) + + # NOTE: dense is advanced into all dimensions but the tiled row one. + # That will be advanced in the loop according to values in col_indices. + dense_block_ptrs = ( + dense_ptr + + dense_batch_stride * batch_pid + + dense_tiled_col_stride * col_block_pid + + dense_row_block_stride * inner_block_arange[:, None] + + dense_col_block_stride * col_block_arange[None, :] + ) + + # Pointers are set to exact write-to locations + output_ptrs = ( + output_ptr + + output_batch_stride * batch_pid + + output_tiled_row_stride * row_block_pid + + output_tiled_col_stride * col_block_pid + + output_row_block_stride * row_block_arange[:, None] + + output_col_block_stride * col_block_arange[None, :] + ) + + # Set pointer to the first nonzero element in the current row + col_index_nnz_ptr = ( + col_indices_ptr + + col_indices_batch_stride * batch_pid + + col_indices_stride * nnz_offset + ) + + output_acc_block = tl.zeros((BLOCKSIZE_ROW, PADDED_BLOCKSIZE_COL), dtype=acc_dtype) + + nsub_blocks = tl.cdiv(BLOCKSIZE_ROW, BLOCKSIZE_K) + + + for i in range(row_nnz): + values_block = tl.load(values_block_ptrs) + + # find which row of dense needs to get loaded + # for multiplication with values_block. + dense_row_idx = tl.load(col_index_nnz_ptr) + dense_block = tl.load( + dense_block_ptrs + dense_tiled_row_stride * dense_row_idx, + mask=col_block_arange[None, :] < BLOCKSIZE_COL, + ) + + # do block mm + output_acc_block += tl.dot( + values_block, dense_block, allow_tf32=allow_tf32, out_dtype=acc_dtype + ) + + # move val/col_index ptrs to the next block in the row + values_block_ptrs += values_nnz_stride + col_index_nnz_ptr += col_indices_stride + + # write back the result + tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty), mask=col_block_arange[None, :]< BLOCKSIZE_COL) + +else: + _bsr_strided_addmm_kernel = None # type: ignore[assignment] diff --git a/torchao/prototype/sparsity/superblock/README.md b/torchao/prototype/sparsity/superblock/README.md index 6fea1a0e3a..bed75c9ad3 100644 --- a/torchao/prototype/sparsity/superblock/README.md +++ b/torchao/prototype/sparsity/superblock/README.md @@ -66,11 +66,11 @@ Please refer to [TRAINING.md](TRAINING.md) for training from scratch. We use [To For example, if you would like to train a `vit_b_16` from scratch using Supermask, you can use the respective torchvision command found in [TRAINING.md](TRAINING.md) and append the supermask arguments: ``` torchrun --nproc_per_node=8 train.py\ - --model vit_h_14 --epochs 3 --batch-size 64 --opt adamw --lr 0.003 --wd 0.3\ + --model vit_b_16 --epochs 1 --batch-size 64 --opt adamw --lr 0.003 --wd 0.3\ --lr-scheduler cosineannealinglr --lr-warmup-method linear --lr-warmup-epochs 30\ --lr-warmup-decay 0.033 --amp --label-smoothing 0.11 --mixup-alpha 0.2 \ --clip-grad-norm 1 --cutmix-alpha 1.0 --model-ema\ - --sparsity semi_structured --data-path $IMAGENET_PATH + --sparsity bsr --data-path $IMAGENET_PATH ``` Through this command, we are training a `vit_b_16` with 90% sparsity to linear layers using 32x32 tiles. diff --git a/torchao/prototype/sparsity/superblock/benchmark.py b/torchao/prototype/sparsity/superblock/benchmark.py index b87834afae..a53f53c21a 100644 --- a/torchao/prototype/sparsity/superblock/benchmark.py +++ b/torchao/prototype/sparsity/superblock/benchmark.py @@ -81,6 +81,9 @@ def main(args): # With quantization, we must use cuSPARSELt to fuse one of the scalar matmuls. # Otherwise, we observe the CUTLASS kernels to be faster, so we use those instead. accelerate_with_sparsity(model, args) + if "bsr" in args.sparsity: + sparsify_(model, block_sparse_weight(blocksize=args.blocksize)) + elif "semi-structured" in args.sparsityk # compile model = torch.compile(model, mode="max-autotune", fullgraph=True) diff --git a/torchao/prototype/sparsity/superblock/supermask.py b/torchao/prototype/sparsity/superblock/supermask.py index abd23c566e..e1f8a67108 100644 --- a/torchao/prototype/sparsity/superblock/supermask.py +++ b/torchao/prototype/sparsity/superblock/supermask.py @@ -1,14 +1,17 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +import torch.nn as nn import math - import torch -import torch.nn as nn +from torch.autograd import Variable import torch.nn.functional as F +import numpy as np + +from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter # original supermask -scores_min = None -scores_max = 9e9 +scores_min=None +scores_max=9e9 uniform_init_01 = False # adjusted supermask, initialize scores with uniform distribution in [0,1], clamp scores in each step in [0,1] @@ -16,54 +19,51 @@ # scores_max=1. # uniform_init_01 = True - def percentile(t, q): """Return the value that is larger than q% of t""" - k = 1 + round(0.01 * float(q) * (t.numel() - 1)) + k = 1 + round(.01 * float(q) * (t.numel() - 1)) return t.view(-1).kthvalue(k).values class GetSubnet(torch.autograd.Function): """Supermask STE function""" - @staticmethod def forward(ctx, scores, zeros, ones, sparsity): - clamped_scores = scores.clamp(min=scores_min, max=scores_max) - k_val = percentile(clamped_scores, sparsity * 100) - return torch.where( - clamped_scores < k_val, zeros.to(scores.device), ones.to(scores.device) - ) - + clamped_scores = scores.clamp(min=scores_min,max=scores_max) + k_val = percentile(clamped_scores, sparsity*100) + return torch.where(clamped_scores < k_val, zeros.to(scores.device), ones.to(scores.device)) @staticmethod def backward(ctx, g): return g, None, None, None +class ApplyMask(torch.autograd.Function): + """Supermask STE function""" + @staticmethod + def forward(ctx, weight, scores): + return weight * scores + @staticmethod + def backward(ctx, grad_output): + grad_weight = grad_scores = None + if ctx.needs_input_grad[0]: + grad_weight = grad_output + if ctx.needs_input_grad[1]: + grad_scores = grad_output + return grad_weight, grad_scores + + class SupermaskLinear(nn.Linear): """Supermask class for Linear layer""" - - def __init__( - self, - sparsity, - fixed_mask, - fixed_weight, - bitwidth, - transform, - fixed_transform, - *args, - **kwargs, - ): + def __init__(self, sparsity, fixed_mask, fixed_weight, bitwidth, transform, fixed_transform, *args, **kwargs): tile_size = kwargs.pop("tile_size", 1) super(SupermaskLinear, self).__init__(*args, **kwargs) # initialize the scores - max_sparsity = 1 - ( - 1 / math.prod([math.ceil(k / tile_size) for k in self.weight.size()]) - ) + max_sparsity = 1 - (1 / math.prod([math.ceil(k / tile_size) for k in self.weight.size()])) self.sparsity = sparsity if self.sparsity > max_sparsity: print( f"reducing sparsity from {self.sparsity} to {max_sparsity}", - f"(maximum sparsity for layer with shape {self.weight.size()} and tile size {tile_size})", + f"(maximum sparsity for layer with shape {self.weight.size()} and tile size {tile_size})" ) self.sparsity = max_sparsity self.tile_size = tile_size @@ -74,60 +74,42 @@ def __init__( ), requires_grad=not fixed_mask, ) - nn.init.uniform_(self.scores) if uniform_init_01 else nn.init.kaiming_uniform_( - self.scores, a=math.sqrt(5) - ) + nn.init.uniform_(self.scores) if uniform_init_01 else nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5)) - # the shift and the scale are transformation parameters + # the shift and the scale are transformation parameters # the actually used weights = self.weight*self.scale+self.shift # the transformation is activated only for quantized weights - self.shift = nn.Parameter(torch.Tensor(1).fill_(0.0), requires_grad=False) - self.scale = nn.Parameter(torch.Tensor(1).fill_(1.0), requires_grad=False) - + self.shift=nn.Parameter(torch.Tensor(1).fill_(0.), requires_grad=False) + self.scale=nn.Parameter(torch.Tensor(1).fill_(1.), requires_grad=False) + with torch.no_grad(): # if bitwidth is None, then use floating point values in self.weight # if bitwidth is not None, then quantize self.weight into k-bit (k=bitwidth) - # quantized values are -2^(k-1), -2^(k-1)+1, ..., 0, 1, ..., 2^(k-1)-1 + # quantized values are -2^(k-1), -2^(k-1)+1, ..., 0, 1, ..., 2^(k-1)-1 # these quantized values are uniformly distributed if bitwidth is not None: weights_max = torch.max(self.weight).item() weights_min = torch.min(self.weight).item() - least_step = (weights_max - weights_min) / pow(2, bitwidth) - left_bound = weights_min - 1e-6 - right_bound = weights_min + least_step + 1e-6 + least_step = (weights_max-weights_min)/pow(2,bitwidth) + left_bound = weights_min-1e-6 + right_bound = weights_min+least_step+1e-6 # self.shift=nn.Parameter(torch.Tensor(1).fill_( (weights_min+(pow(2,bitwidth-1)+0.5)*least_step) if transform[0] is None else transform[0] ), requires_grad=not fixed_transform[0]) # self.scale=nn.Parameter(torch.Tensor(1).fill_( least_step if transform[1] is None else transform[1] ), requires_grad=not fixed_transform[1]) # for example, if using binary weights (k=1) with -a, +a, set transform = [a,2a]; if using binary weights (k=1) with a, 0, set transform = [0,-a]; - self.shift = nn.Parameter( - torch.Tensor(1).fill_( - 0.0 if transform[0] is None else transform[0] - ), - requires_grad=not fixed_transform[0], - ) - self.scale = nn.Parameter( - torch.Tensor(1).fill_( - 1.0 if transform[1] is None else transform[1] - ), - requires_grad=not fixed_transform[1], - ) - for i in range(-int(pow(2, bitwidth - 1)), int(pow(2, bitwidth - 1))): - self.weight[ - torch.logical_and( - self.weight > left_bound, self.weight <= right_bound - ) - ] = i + self.shift=nn.Parameter(torch.Tensor(1).fill_( 0. if transform[0] is None else transform[0] ), requires_grad=not fixed_transform[0]) + self.scale=nn.Parameter(torch.Tensor(1).fill_( 1. if transform[1] is None else transform[1] ), requires_grad=not fixed_transform[1]) + for i in range(-int(pow(2,bitwidth-1)),int(pow(2,bitwidth-1))): + self.weight[torch.logical_and(self.weight>left_bound, self.weight<=right_bound)] = i left_bound = right_bound right_bound += least_step self.weight.requires_grad = not fixed_weight def get_mask(self): - subnet = GetSubnet.apply( - self.scores, - torch.zeros_like(self.scores), - torch.ones_like(self.scores), - self.sparsity, - ) + subnet = GetSubnet.apply(self.scores, + torch.zeros_like(self.scores), + torch.ones_like(self.scores), + self.sparsity) if self.tile_size != 1: for i, k in enumerate(self.weight.shape): @@ -135,231 +117,40 @@ def get_mask(self): subnet = torch.narrow(subnet, i, 0, k) return subnet - + def sparsify_offline(self): subnet = self.get_mask() - self.weight.data = (self.weight * self.scale + self.shift) * subnet + self.weight.data = (self.weight*self.scale+self.shift) * subnet self.sparsify_weights = True def forward(self, x): if not self.sparsify_weights: subnet = self.get_mask() - w = (self.weight * self.scale + self.shift) * subnet - else: - w = self.weight - return F.linear(x, w, self.bias) - - -class SupermaskConv2d(nn.Conv2d): - """Supermask class for Conv2d layer""" - - def __init__( - self, - sparsity, - fixed_mask, - fixed_weight, - bitwidth, - transform, - fixed_transform, - *args, - **kwargs, - ): - tile_size = kwargs.pop("tile_size", 1) - super(SupermaskConv2d, self).__init__(*args, **kwargs) - # initialize the scores - max_sparsity = 1 - ( - 1 / math.prod([math.ceil(k / tile_size) for k in self.weight.size()]) - ) - self.sparsity = sparsity - if self.sparsity > max_sparsity: - print( - f"reducing sparsity from {self.sparsity} to {max_sparsity}", - f"(maximum sparsity for layer with shape {self.weight.size()} and tile size {tile_size})", - ) - self.sparsity = max_sparsity - self.tile_size = tile_size - self.scores = nn.Parameter( - torch.empty( - [max(1, int(math.ceil(wn / tile_size))) for wn in self.weight.size()] - ), - requires_grad=not fixed_mask, - ) - nn.init.uniform_(self.scores) if uniform_init_01 else nn.init.kaiming_uniform_( - self.scores, a=math.sqrt(5) - ) - - # the shift and the scale are transformation parameters - # the actually used weights = self.weight*self.scale+self.shift - # the transformation is activated only for quantized weights - self.shift = nn.Parameter(torch.Tensor(1).fill_(0.0), requires_grad=False) - self.scale = nn.Parameter(torch.Tensor(1).fill_(1.0), requires_grad=False) - - with torch.no_grad(): - # if bitwidth is None, then use floating point values in self.weight - # if bitwidth is not None, then quantize self.weight into k-bit (k=bitwidth) - # quantized values are -2^(k-1), -2^(k-1)+1, ..., 0, 1, ..., 2^(k-1)-1 - # these quantized values are uniformly distributed - if bitwidth is not None: - weights_max = torch.max(self.weight).item() - weights_min = torch.min(self.weight).item() - least_step = (weights_max - weights_min) / pow(2, bitwidth) - left_bound = weights_min - 1e-6 - right_bound = weights_min + least_step + 1e-6 - # self.shift=nn.Parameter(torch.Tensor(1).fill_( (weights_min+(pow(2,bitwidth-1)+0.5)*least_step) if transform[0] is None else transform[0] ), requires_grad=not fixed_transform[0]) - # self.scale=nn.Parameter(torch.Tensor(1).fill_( least_step if transform[1] is None else transform[1]), requires_grad=not fixed_transform[1]) - # for example, if using binary weights (k=1) with -a, +a, set transform = [a,2a]; if using binary weights (k=1) with a, 0, set transform = [0,-a]; - self.shift = nn.Parameter( - torch.Tensor(1).fill_( - 0.0 if transform[0] is None else transform[0] - ), - requires_grad=not fixed_transform[0], - ) - self.scale = nn.Parameter( - torch.Tensor(1).fill_( - 1.0 if transform[1] is None else transform[1] - ), - requires_grad=not fixed_transform[1], - ) - for i in range(-int(pow(2, bitwidth - 1)), int(pow(2, bitwidth - 1))): - self.weight[ - torch.logical_and( - self.weight > left_bound, self.weight <= right_bound - ) - ] = i - left_bound = right_bound - right_bound += least_step - - self.weight.requires_grad = not fixed_weight - - def forward(self, x): - subnet = GetSubnet.apply( - self.scores, - torch.zeros_like(self.scores), - torch.ones_like(self.scores), - self.sparsity, - ) - - if self.tile_size != 1: - for i, k in enumerate(self.weight.shape): - # if k == 1: continue - subnet = subnet.repeat_interleave(self.tile_size, dim=i) - subnet = torch.narrow(subnet, i, 0, k) - - w = (self.weight * self.scale + self.shift) * subnet - return F.conv2d( - x, w, self.bias, self.stride, self.padding, self.dilation, self.groups - ) - - -def apply_supermask( - model, - linear_sparsity=0.0, - linear_sp_tilesize=1, - conv1x1_sparsity=0.0, - conv1x1_sp_tilesize=1, - conv_sparsity=0.0, - conv_sp_tilesize=1, - skip_last_layer_sparsity=False, - skip_first_transformer_sparsity=False, - device="cuda", - verbose=False, -): - sparsified_modules = {} - - for n, m in model.named_modules(): - # check conditions for skipping sparsity - if skip_last_layer_sparsity and n == "heads.head": - continue - if skip_first_transformer_sparsity and "encoder.layers.encoder_layer_0" in n: - continue - - # convert 1x1 convolutions - if ( - conv1x1_sparsity != 0.0 - and isinstance(m, torch.nn.Conv2d) - and m.kernel_size == (1, 1) - ): - new_m = SupermaskConv2d( - conv1x1_sparsity, - False, - False, - None, - None, - None, - m.in_channels, - m.out_channels, - m.kernel_size, - stride=m.stride, - padding=m.padding, - dilation=m.dilation, - groups=m.groups, - bias=m.bias is not None, - padding_mode=m.padding_mode, - device=device, - tile_size=conv1x1_sp_tilesize, - ) - new_m.weight.data.copy_(m.weight.data) - if m.bias is not None: - new_m.bias.data.copy_(m.bias.data) - sparsified_modules[n] = new_m - continue - - # convert all other convolutions (not tested!) - if conv_sparsity != 0.0 and isinstance(m, torch.nn.Conv2d): - new_m = SupermaskConv2d( - conv_sparsity, - False, - False, - None, - None, - None, - m.in_channels, - m.out_channels, - m.kernel_size, - stride=m.stride, - padding=m.padding, - dilation=m.dilation, - groups=m.groups, - bias=m.bias is not None, - padding_mode=m.padding_mode, - device=device, - tile_size=conv_sp_tilesize, - ) - new_m.weight.data.copy_(m.weight.data) - if m.bias is not None: - new_m.bias.data.copy_(m.bias.data) - sparsified_modules[n] = new_m - continue - - if linear_sparsity != 0.0 and isinstance(m, torch.nn.Linear): - new_m = SupermaskLinear( - linear_sparsity, - False, - False, - None, - None, - None, - m.in_features, - m.out_features, - bias=m.bias is not None, - device=device, - tile_size=linear_sp_tilesize, - ) - new_m.weight.data.copy_(m.weight.data) - if m.bias is not None: - new_m.bias.data.copy_(m.bias.data) - sparsified_modules[n] = new_m - continue - - # add modules to model - for k, v in sparsified_modules.items(): - sm_name, ch_name = k.rsplit(".", 1) - sm = model.get_submodule(sm_name) - sm.add_module(ch_name, v) - - if verbose: - print( - f'sparsified module "{k}" with sparsity={v.sparsity}, tile size={v.tile_size}' - ) - - return model + # w = (self.weight*self.scale+self.shift) + w = ApplyMask.apply(self.weight, subnet) + return F.linear(x, w, self.bias) + return F.linear(x, self.weight, self.bias) + + @classmethod + def from_linear(cls, linear : torch.nn.Linear, sparsity_level:float=0.0, blocksize=1, inference=True): + module_new = None + + assert isinstance(linear, torch.nn.Linear) + module_new = SupermaskLinear( + sparsity_level, False, False, None, None, None, + linear.in_features, + linear.out_features, + bias=linear.bias is not None, + tile_size=blocksize, + ).to(device=linear.weight.device, dtype=linear.weight.dtype) + module_new.weight.data.copy_(linear.weight.data) + if linear.bias is not None: + module_new.bias.data.copy_(linear.bias.data) + if inference: + module_new.sparsify_offline() + return module_new + + @classmethod + def to_linear(cls): + pass + diff --git a/torchao/sparsity/__init__.py b/torchao/sparsity/__init__.py index 77ccd2c00b..96b74fdb70 100644 --- a/torchao/sparsity/__init__.py +++ b/torchao/sparsity/__init__.py @@ -4,13 +4,12 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -from torchao.quantization.quant_api import ( - int8_dynamic_activation_int8_semi_sparse_weight, -) +from .supermask import SupermaskLinear from .sparse_api import ( apply_fake_sparsity, semi_sparse_weight, + block_sparse_weight, sparsify_, ) from .utils import PerChannelNormObserver # noqa: F403 @@ -18,9 +17,10 @@ __all__ = [ "WandaSparsifier", + "SupermaskLinear", "PerChannelNormObserver", "apply_fake_sparsity", "sparsify_", "semi_sparse_weight", - "int8_dynamic_activation_int8_semi_sparse_weight", + "block_sparse_weight", ] diff --git a/torchao/prototype/sparsity/superblock/blocksparse.py b/torchao/sparsity/blocksparse.py similarity index 57% rename from torchao/prototype/sparsity/superblock/blocksparse.py rename to torchao/sparsity/blocksparse.py index b5e8432949..d4e92cc940 100644 --- a/torchao/prototype/sparsity/superblock/blocksparse.py +++ b/torchao/sparsity/blocksparse.py @@ -1,41 +1,17 @@ from functools import partial -from typing import List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import torch -from torch.sparse._triton_ops import broadcast_batch_dims, bsr_dense_addmm from torch.utils._python_dispatch import return_and_correct_aliasing - from torchao.quantization.quant_api import _get_linear_subclass_inserter from torchao.utils import TorchAOBaseTensor -aten = torch.ops.aten - - -# quantization support -@torch.library.custom_op("blocksparse::bsr_to_dense", mutates_args=()) -def bsr_to_dense( - crow_indices: torch.Tensor, - col_indices: torch.Tensor, - values: torch.Tensor, - M: int, - K: int, -) -> torch.Tensor: - return torch.sparse_bsr_tensor( - crow_indices=crow_indices, col_indices=col_indices, values=values, size=(M, K) - ).to_dense() - +from torchao.kernel.bsr_triton_ops import bsr_dense_addmm, broadcast_batch_dims -@torch.library.register_fake("blocksparse::bsr_to_dense") -def bsr_to_dense_abstract( - crow_indices: torch.Tensor, - col_indices: torch.Tensor, - values: torch.Tensor, - M: int, - K: int, -) -> torch.Tensor: - return torch.empty((M, K), dtype=values.dtype, device=values.device) +aten = torch.ops.aten +# custom op definition @torch.library.custom_op("blocksparse::int_addmm", mutates_args=()) def blocksparse_int_addmm( crow_indices: torch.Tensor, @@ -110,20 +86,67 @@ def blocksparse_linear_abstract( return torch.empty(new_shape, dtype=A.dtype, device=A.device) +# bsr wrapper custom op +@torch.library.custom_op("blocksparse::addmm", mutates_args=()) +def blocksparse_addmm( + x_padded: torch.Tensor, + crow_indices: torch.Tensor, + col_indices: torch.Tensor, + row_indices: torch.Tensor, + values: torch.Tensor, + M: int, + K: int, + bias: torch.Tensor, +) -> torch.Tensor: + assert bias is None + weight_bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(M, K)) + N_padded = x_padded.shape[1] + out = x_padded.new_empty((M, N_padded)) + bsr_dense_addmm( + out, + weight_bsr, + row_indices, + x_padded, + alpha=1, + beta=0, + out=out, + ) + return out + + +@torch.library.register_fake("blocksparse::addmm") +def blocksparse_addmm_abstract( + x_padded: torch.Tensor, + crow_indices: torch.Tensor, + col_indices: torch.Tensor, + row_indices: torch.Tensor, + values: torch.Tensor, + M: int, + K: int, + bias: torch.Tensor, +) -> torch.Tensor: + N_padded = x_padded.shape[1] + return x_padded.new_empty((M, N_padded)) + + # Subclass definition class BlockSparseTensor(TorchAOBaseTensor): bsr_crow_indices: Optional[torch.Tensor] bsr_col_indices: Optional[torch.Tensor] + bsr_row_indices: Optional[torch.Tensor] bsr_values: Optional[torch.Tensor] + blocksize: int - __slots__ = ["bsr_crow_indices", "bsr_col_indices", "bsr_values"] + __slots__ = ["bsr_crow_indices", "bsr_col_indices", "bsr_row_indices", "bsr_values"] @staticmethod def __new__( # noqa: PYI034 cls, shape: torch.Size, + blocksize: int, bsr_crow_indices: Optional[torch.Tensor], bsr_col_indices: Optional[torch.Tensor], + bsr_row_indices: Optional[torch.Tensor], bsr_values: Optional[torch.Tensor], requires_grad: bool = False, ): @@ -141,46 +164,54 @@ def __new__( # noqa: PYI034 "requires_grad": requires_grad, } tensor = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + tensor.blocksize = blocksize tensor.bsr_crow_indices = bsr_crow_indices - tensor.bsr_col_indices = bsr_col_indices tensor.bsr_values = bsr_values + tensor.bsr_col_indices = bsr_col_indices + tensor.bsr_row_indices = bsr_row_indices return tensor def __repr__(self) -> str: # type: ignore[override] assert hasattr(self, "shape") return f"{self.__class__.__name__}(shape={self.shape})" - def __tensor_flatten__(self) -> Tuple[List[str], Tuple[torch.Size, bool]]: + def __tensor_flatten__(self) -> Tuple[List[str], Tuple[torch.Size, bool, int]]: inner_tensors = list( filter(lambda x: getattr(self, x) is not None, self.__slots__) ) - tensor_meta = (self.shape, self.requires_grad) + tensor_meta = (self.shape, self.requires_grad, self.blocksize) return inner_tensors, tensor_meta @classmethod def __tensor_unflatten__( cls, inner_tensors, - tensor_meta: Tuple[torch.Size, bool], + tensor_meta: Tuple[torch.Size, bool, int], outer_size, outer_stride, ) -> torch.Tensor: - shape, requires_grad = tensor_meta + shape, requires_grad, blocksize = tensor_meta return cls( shape=shape, + blocksize=blocksize, bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None), bsr_col_indices=inner_tensors.get("bsr_col_indices", None), + bsr_row_indices=inner_tensors.get("bsr_row_indices", None), bsr_values=inner_tensors.get("bsr_values", None), requires_grad=requires_grad, ) + @classmethod def from_dense(cls, dense_tensor, blocksize): bsr_tensor = dense_tensor.to_sparse_bsr(blocksize) + bsr_tensor_t = dense_tensor.t().contiguous().to_sparse_bsr(blocksize) return cls( shape=dense_tensor.shape, + blocksize=blocksize, bsr_crow_indices=bsr_tensor.crow_indices(), bsr_col_indices=bsr_tensor.col_indices(), + bsr_row_indices=bsr_tensor_t.col_indices(), bsr_values=bsr_tensor.values(), requires_grad=False, ) @@ -188,13 +219,24 @@ def from_dense(cls, dense_tensor, blocksize): def apply_fn_to_shard(self, func): return BlockSparseTensor( shape=self.shape, + blocksize=self.blocksize, bsr_crow_indices=func(self.bsr_crow_indices), bsr_col_indices=func(self.bsr_col_indices), + bsr_row_indices=func(self.bsr_row_indices), bsr_values=func(self.bsr_values), requires_grad=self.requires_grad, ) + def dense(self): + return torch.sparse_bsr_tensor( + crow_indices=self.bsr_crow_indices, + col_indices=self.bsr_col_indices, + values=self.bsr_values, + size=self.shape, + ).to_dense() + + # Subclass op dispatch registration implements = BlockSparseTensor.implements @@ -206,6 +248,69 @@ def block_sparse_detach(func, types, args, kwargs): ) +@implements(aten.unsqueeze.default) +def block_sparse_unsqueeze(func, types, args, kwargs): + assert len(args) == 2 + assert len(kwargs) == 0 + assert args[-1] == 2 + bsr = args[0] + assert bsr.dim() == 2 + assert not bsr.requires_grad + return BlockSparseTensor(bsr.shape + (1,), + bsr.blocksize, + bsr.crow_indices(), + bsr.col_indices(), + bsr.values().unsqueeze(-1)) + + +@implements(aten.mul.Tensor) +def block_sparse_mul(func, types, args, kwargs): + assert len(args) == 2 + assert len(kwargs) == 0 + bsr, t = args + + def my_mul(bsr, t): + assert isinstance(bsr, BlockSparseTensor) + assert isinstance(t, torch.Tensor) + assert bsr.dim() == 3 + assert t.dim() == 3 + assert not bsr.requires_grad + assert t.size(0) == 1 + t_blocked = t.view(t.size(0), t.size(1) // 64, 64, 1) + masked_t = t_blocked.transpose(0, 1).index_select(0, bsr.col_indices()) + new_values = bsr.values() * masked_t + return BlockSparseTensor(bsr.shape, + bsr.blocksize, + bsr.crow_indices(), + bsr.col_indices(), + new_values) + + if isinstance(bsr, torch.Tensor) and isinstance(t, BlockSparseTensor): + return my_mul(t, bsr) + return my_mul(bsr, t) + + +@implements(aten.sum.dim_IntList) +def block_sparse_sum(func, types, args, kwargs): + bsr, dim = args + assert type(dim) == list + assert len(dim) == 1 + dim = dim[0] + bsr_dim = bsr.dim() + assert dim == 1 + out = torch.empty((bsr.shape[0], bsr.shape[2]), dtype=bsr.dtype, device=bsr.device) + crow_indices = bsr.crow_indices() + blocksize = bsr.blocksize + + for i in range(crow_indices.shape[0]-1): + start, stop = crow_indices[i], crow_indices[i+1] + temp_sum = bsr.values()[start:stop] + temp_sum = temp_sum.sum(dim=0).sum(dim=1) + out[i * blocksize : (i + 1) * blocksize] = temp_sum + + return out + + @implements(aten.values.default) def block_sparse_values(func, types, args, kwargs): return args[0].bsr_values.detach() @@ -220,6 +325,9 @@ def block_sparse_crow_indices(func, types, args, kwargs): def block_sparse_col_indices(func, types, args, kwargs): return args[0].bsr_col_indices.detach() +@implements(aten.row_indices.default) +def block_sparse_col_indices(func, types, args, kwargs): + return args[0].bsr_row_indices.detach() @implements(aten._nnz.default) def block_sparse__nnz(func, types, args, kwargs): @@ -228,13 +336,23 @@ def block_sparse__nnz(func, types, args, kwargs): @implements(torch.nn.functional.linear) def block_sparse_linear(func, types, args, kwargs): - x, w, bias = args - return torch.ops.blocksparse.linear( - x, w.crow_indices(), w.col_indices(), w.values(), w.shape[0], w.shape[1], bias + x_orig, w, bias = args + x = x_orig.reshape(-1, x_orig.size(-1)).t() + M = w.shape[0] + K = w.shape[1] + N = x.shape[1] + out = torch.ops.blocksparse.addmm( + x, + w.crow_indices(), + w.col_indices(), + w.row_indices(), + w.values(), + M, + K, + None, ) + out_orig = out.t() + if bias is None: + return out_orig - -def block_sparse_weight(blocksize=64): - return _get_linear_subclass_inserter( - partial(BlockSparseTensor.from_dense, blocksize=blocksize) - ) + return out_orig + bias diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py index 3dd7971525..3277518f87 100644 --- a/torchao/sparsity/sparse_api.py +++ b/torchao/sparsity/sparse_api.py @@ -1,9 +1,11 @@ from typing import Callable, Optional +from functools import partial import torch -from torch.ao.pruning import WeightNormSparsifier from torch.sparse import to_sparse_semi_structured +from torchao.prototype.sparsity.sparsifier.weight_norm_sparsifier import WeightNormSparsifier +from torchao.sparsity.blocksparse import BlockSparseTensor from torchao.quantization.quant_api import ( _get_linear_subclass_inserter, _is_linear, @@ -31,6 +33,12 @@ def apply_fake_sparsity(model, **kwargs): sparsifier.squash_mask() +def block_sparse_weight(blocksize=64): + return _get_linear_subclass_inserter( + partial(BlockSparseTensor.from_dense, blocksize=blocksize) + ) + + def semi_sparse_weight(): """ Convert the weight of linear moduels to semi-structured (2:4) sparsity diff --git a/torchao/sparsity/supermask.py b/torchao/sparsity/supermask.py new file mode 100644 index 0000000000..0f2fec55f3 --- /dev/null +++ b/torchao/sparsity/supermask.py @@ -0,0 +1,132 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import torch.nn as nn +import math +import torch +from torch.autograd import Variable +import torch.nn.functional as F +import numpy as np + +from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter + +# original supermask +scores_min=None +scores_max=9e9 + +def percentile(t, q): + """Return the value that is larger than q% of t""" + k = 1 + round(.01 * float(q) * (t.numel() - 1)) + return t.view(-1).kthvalue(k).values + + +class GetSubnet(torch.autograd.Function): + """Supermask STE function""" + @staticmethod + def forward(ctx, scores, zeros, ones, sparsity): + clamped_scores = scores.clamp(min=scores_min,max=scores_max) + k_val = percentile(clamped_scores, sparsity*100) + return torch.where(clamped_scores < k_val, zeros.to(scores.device), ones.to(scores.device)) + + @staticmethod + def backward(ctx, g): + return g, None, None, None + + +class ApplyMask(torch.autograd.Function): + """Supermask STE function""" + @staticmethod + def forward(ctx, weight, scores): + return weight * scores + @staticmethod + def backward(ctx, grad_output): + grad_weight = grad_scores = None + if ctx.needs_input_grad[0]: + grad_weight = grad_output + if ctx.needs_input_grad[1]: + grad_scores = grad_output + return grad_weight, grad_scores + + +class SupermaskLinear(nn.Linear): + """Supermask class for Linear layer""" + def __init__(self, sparsity_level, blocksize, fixed_mask, fixed_weight, *args, **kwargs): + super(SupermaskLinear, self).__init__(*args, **kwargs) + # calculate the maximum sparsity given blocksize for the layer + max_sparsity_level = 1 - (1 / math.prod([math.ceil(k / blocksize) for k in self.weight.size()])) + self.sparsity_level = sparsity_level + if self.sparsity_level > max_sparsity_level: + print( + f"reducing sparsity from {self.sparsity} to {max_sparsity}", + f"(maximum sparsity for layer with shape {self.weight.size()} and tile size {blocksize})" + ) + self.sparsity_level = max_sparsity_level + self.blocksize = blocksize + self.sparsify_weights = False + self.scores = nn.Parameter( + torch.empty( + [max(1, int(math.ceil(wn / blocksize))) for wn in self.weight.size()] + ), + requires_grad=not fixed_mask, + ) + nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5)) + + # NOTE: the previous implementation of Supermask supported quantizing the weights, this has been removed. + + self.weight.requires_grad = not fixed_weight + + def get_mask(self): + subnet = GetSubnet.apply(self.scores, + torch.zeros_like(self.scores), + torch.ones_like(self.scores), + self.sparsity_level) + + if self.blocksize != 1: + for i, k in enumerate(self.weight.shape): + subnet = subnet.repeat_interleave(self.blocksize, dim=i) + subnet = torch.narrow(subnet, i, 0, k) + + return subnet + + + def forward(self, x): + subnet = self.get_mask() + w = ApplyMask.apply(self.weight, subnet) + return F.linear(x, w, self.bias) + + @classmethod + def from_linear(cls, linear, sparsity_level=0.0, blocksize=1, ): + """ + Main entrypoint for creating a SupermaskLinear from a Linear layer. + """ + assert isinstance(linear, torch.nn.Linear) + + supermask_linear = SupermaskLinear( + sparsity_level, blocksize, False, False, + linear.in_features, + linear.out_features, + bias=linear.bias is not None, + ).to(device=linear.weight.device, dtype=linear.weight.dtype) + supermask_linear.weight.data.copy_(linear.weight.data) + if linear.bias is not None: + supermask_linear.bias.data.copy_(linear.bias.data) + return supermask_linear + + @classmethod + def to_linear(cls, supermask_linear): + """ + Convert a SupermaskLinear to a Linear layer. + Replaces the old sparsify_offline() function. + """ + self = supermask_linear + + linear = torch.nn.Linear( + self.in_features, + self.out_features, + bias=self.bias is not None, + ).to(device=self.weight.device, dtype=self.weight.dtype) + + mask = self.get_mask() + linear.weight.data.copy_(self.weight * mask) + if self.bias is not None: + linear.bias.data.copy_(self.bias.data) + return linear