Skip to content

Commit

Permalink
Add an option to read shapes from csv
Browse files Browse the repository at this point in the history
Summary:
Make 3 improvements:
1. Read shapes from a csv file.
2. Support bias in the gemm computation.
3. Add Hammer HSTU Triton Matmul Impl.

Reviewed By: nmacchioni

Differential Revision: D55768916

fbshipit-source-id: bb139311b98adcacf4e0383ad569b70865d789f1
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed Apr 5, 2024
1 parent abf184d commit a513690
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 79 deletions.
176 changes: 106 additions & 70 deletions torchbenchmark/operators/gemm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import statistics
from typing import Callable, Generator, List, Optional, Any

import csv
import numpy
import torch
import triton
Expand All @@ -14,40 +15,42 @@
register_metric,
)

from torchbenchmark import REPO_PATH
from hammer.ops.triton.triton_matmul import triton_matmul as hstu_triton_matmul
from .triton_matmul import matmul as triton_matmul

BUILDIN_SHAPES = [
(256, 256, 256),
(384, 384, 384),
(512, 512, 512),
(640, 640, 640),
(768, 768, 768),
(896, 896, 896),
(1024, 1024, 1024),
(1152, 1152, 1152),
(1280, 1280, 1280),
(1408, 1408, 1408),
(1536, 1536, 1536),
(1664, 1664, 1664),
(1792, 1792, 1792),
(1920, 1920, 1920),
(2048, 2048, 2048),
(2176, 2176, 2176),
(2304, 2304, 2304),
(2432, 2432, 2432),
(2560, 2560, 2560),
(2688, 2688, 2688),
(2816, 2816, 2816),
(2944, 2944, 2944),
(3072, 3072, 3072),
(3200, 3200, 3200),
(3328, 3328, 3328),
(3456, 3456, 3456),
(3584, 3584, 3584),
(3712, 3712, 3712),
(3840, 3840, 3840),
(3968, 3968, 3968),
(4096, 4096, 4096),
(256, 256, 256, None),
(384, 384, 384, None),
(512, 512, 512, None),
(640, 640, 640, None),
(768, 768, 768, None),
(896, 896, 896, None),
(1024, 1024, 1024, None),
(1152, 1152, 1152, None),
(1280, 1280, 1280, None),
(1408, 1408, 1408, None),
(1536, 1536, 1536, None),
(1664, 1664, 1664, None),
(1792, 1792, 1792, None),
(1920, 1920, 1920, None),
(2048, 2048, 2048, None),
(2176, 2176, 2176, None),
(2304, 2304, 2304, None),
(2432, 2432, 2432, None),
(2560, 2560, 2560, None),
(2688, 2688, 2688, None),
(2816, 2816, 2816, None),
(2944, 2944, 2944, None),
(3072, 3072, 3072, None),
(3200, 3200, 3200, None),
(3328, 3328, 3328, None),
(3456, 3456, 3456, None),
(3584, 3584, 3584, None),
(3712, 3712, 3712, None),
(3840, 3840, 3840, None),
(3968, 3968, 3968, None),
(4096, 4096, 4096, None),
]


Expand All @@ -56,38 +59,63 @@ def parse_args(args: List[str]) -> argparse.Namespace:
parser.add_argument("--m", default=8, type=int)
parser.add_argument("--k", default=8, type=int)
parser.add_argument("--n", default=8, type=int)
parser.add_argument("--input", default=None, type=str)
args = parser.parse_args(args)
return args


def read_shapes_from_csv(csv_path: str) -> List[List[Optional[int]]]:
    """Load GEMM problem shapes from a CSV file in the gemm operator directory.

    The first row of the file is treated as a header and discarded. Every
    following non-blank row is converted cell by cell: a non-empty cell is
    parsed with ``int`` and an empty (or whitespace-only) cell becomes
    ``None``.

    Args:
        csv_path: Path of the CSV file, joined onto
            ``REPO_PATH/torchbenchmark/operators/gemm``.

    Returns:
        One list per data row, in file order. A file with no rows at all
        (not even a header) yields an empty list.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a non-empty cell is not a valid integer.
    """
    input_file_path = os.path.join(
        REPO_PATH, "torchbenchmark", "operators", "gemm", csv_path
    )
    # newline="" hands line-ending handling to the csv module, as its
    # documentation requires for file objects passed to csv.reader.
    with open(input_file_path, "r", newline="") as f:
        reader = csv.reader(f)
        # Skip the header; the default avoids StopIteration on an empty file.
        next(reader, None)
        return [
            [int(cell) if cell.strip() else None for cell in row]
            for row in reader
            if row  # csv.reader yields [] for blank lines; ignore them
        ]

class Operator(BenchmarkOperator):
USE_BUILTIN_SHAPES = True
DEFAULT_METRICS = ["latency", "speedup", "accuracy"]

def __init__(self, mode: str, device: str, extra_args: List[str] = []):
if not extra_args:
self.USE_BUILTIN_SHAPES = True
super().__init__(mode=mode, device=device, extra_args=extra_args)
if not self.extra_args:
self.DEFAULT_NUM_BATCH = len(BUILDIN_SHAPES)
self.extra_builtin_metrics = ["speedup", "accuracy"]
else:
self.USE_BUILTIN_SHAPES = False
self.DEFAULT_NUM_BATCH = 1
self.tbargs = parse_args(self.extra_args)
super().__init__(mode=mode, device=device, extra_args=extra_args)
self.required_metrics = list(
set(self.required_metrics + self.extra_builtin_metrics)
)
if self.tbargs.input:
self.shapes = read_shapes_from_csv(self.tbargs.input)
else:
self.shapes = [(self.tb_args.m, self.tbargs.k, self.tbargs.n)]
self.DEFAULT_NUM_BATCH = len(self.shapes)


@register_benchmark()
def triton_matmul(self, a, b) -> Callable:
return lambda: triton_matmul(a, b)
def triton_matmul(self, a, b, bias) -> Callable:
if not bias == None:
return lambda: triton_matmul(a, b) + bias
else:
return lambda: triton_matmul(a, b)


@register_benchmark(baseline=True)
def aten_matmul(self, a, b) -> Callable:
return lambda: torch.matmul(a, b)
def aten_matmul(self, a, b, bias) -> Callable:
if not bias == None:
return lambda: torch.matmul(a, b) + bias
else:
return lambda: torch.matmul(a, b)


@register_benchmark()
def hstu_triton_matmul(self, a, b, bias) -> Callable:
    """Return a zero-argument callable that runs the HSTU Triton matmul.

    Args:
        a: Left-hand matmul operand.
        b: Right-hand matmul operand.
        bias: Value added to the matmul result, or ``None`` for no bias.

    Returns:
        A callable computing ``hstu_triton_matmul(a, b)``, plus ``bias``
        when one is given.
    """
    # Inside the lambdas, `hstu_triton_matmul` resolves to the module-level
    # import, not this method: class-body names are not visible from
    # nested function scopes.
    # Identity test instead of `== None`, so the check never depends on how
    # the bias object implements `__eq__`.
    if bias is None:
        return lambda: hstu_triton_matmul(a, b)
    return lambda: hstu_triton_matmul(a, b) + bias

def get_x_val(self, example_inputs) -> float:
# x-value: computation intensity
a, w = example_inputs
a, w, bias = example_inputs
m, k = a.size()
k, n = w.size()
# computation intensity for the shape m, n, k
Expand All @@ -96,44 +124,48 @@ def get_x_val(self, example_inputs) -> float:

@register_metric()
def gbps(self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics) -> float:
a, w = example_inputs
a, w, bias = example_inputs
numel = a.numel() + w.numel() + (torch.mm(a, w).numel())
numel = numel * a.element_size() / 1e9
gbps = list(map(lambda x: numel / x * 1e3, metrics.latency))
return statistics.median(gbps)

@register_metric(skip_baseline=True)
def xShape(self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics) -> list[int]:
a, w = example_inputs
a, w, bias = example_inputs
m, k = a.size()
k, n = w.size()
if not bias == None:
return [m, k, n, bias.size()[0]]
return [m, k, n]

@register_metric()
def tflops(self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics) -> float:
a, w = example_inputs
a, w, bias = example_inputs
m, k = a.size()
k, n = w.size()
flops = m * k * 2 * n
latency = numpy.median(metrics.latency)
return flops / latency / 1e12 * 1e3
if not bias == None:
flops = m * k * 2 * n + 2 * m * n
else:
flops = m * k * 2 * n
return [flops / x / 1e12 * 1e3 for x in metrics.latency]

def get_input_iter(self) -> Generator:
if self.USE_BUILTIN_SHAPES:
for shape in BUILDIN_SHAPES:
m, k, n = shape
a = torch.randn(
(m, k), device=self.device, dtype=torch.float16
).requires_grad_(False)
w = torch.randn(
(k, n), device=self.device, dtype=torch.float16
for shape in self.shapes:
m, k, n, bias = shape
a = torch.randn(
(m, k), device=self.device, dtype=torch.float16
).requires_grad_(False)
w = torch.randn(
(k, n), device=self.device, dtype=torch.float16
).requires_grad_(False)
if not bias == None:
bias = torch.randn(
(bias), device=self.device, dtype=torch.float16
).requires_grad_(False)
yield a, w
while True:
yield None
else:
meta_tensor = torch.randn((self.tbargs.m, self.tbargs.k), device="meta")
yield torch.randn_like(meta_tensor, device=self.device).requires_grad(False)
yield a, w, bias
while True:
yield None

def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
output = fn()
Expand All @@ -156,20 +188,24 @@ def plot(self):
x_vals=self.output.x_vals, # different possible values for `x_name`
line_arg="provider", # argument name whose value corresponds to a different line in the plot
line_vals=[
"aten_matmul",
"triton_matmul",
"hstu_triton_matmul",
], # possible values for `line_arg``
line_names=[
"ATen GEMM",
"Triton GEMM",
"HSTU Triton GEMM",
], # label name for the lines
styles=[("blue", "-"), ("green", "-")], # line styles
ylabel="speedup", # label name for the y-axis
styles=[("blue", "-"), ("green", "-"), ("red", "-")], # line styles
ylabel="tflops", # label name for the y-axis
plot_name="gemm-performance", # name for the plot. Used also as a file name for saving the plot.
args={}, # values for function arguments not in `x_names` and `y_name`
)
)
def _plot(density, provider):
speedup = self.output.get_y_vals(density, provider, "speedup")
return speedup
tflops = self.output.get_y_vals(density, provider, "tflops")
return tflops

save_path = "/tmp/test_gemm"

Expand Down
59 changes: 59 additions & 0 deletions torchbenchmark/operators/gemm/amd.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
M,K,N,Bias
152710,512,2048,2048
64896,512,2048,2048
6,1536,384,
6,768,3072,
6144,384,384,
462,384,768,
462,384,4096,
462,384,1472,
6144,3072,384,
6144,384,1536,
6,1536,3072,
1536,768,768,
462,768,768,
462,768,4096,
462,768,1472,
1536,6144,768,
1536,768,3072,
6,3072,3072,
384,1536,1536,
462,1536,768,
462,1536,4096,
462,1536,1472,
384,12288,1536,
384,1536,6144,
96,1536,1536,
96,12288,1536,
96,1536,6144,
256,256,256,
384,384,384,
512,512,512,
640,640,640,
768,768,768,
896,896,896,
1024,1024,1024,
1152,1152,1152,
1280,1280,1280,
1408,1408,1408,
1536,1536,1536,
1664,1664,1664,
1792,1792,1792,
1920,1920,1920,
2048,2048,2048,
2176,2176,2176,
2304,2304,2304,
2432,2432,2432,
2560,2560,2560,
2688,2688,2688,
2816,2816,2816,
2944,2944,2944,
3072,3072,3072,
3200,3200,3200,
3328,3328,3328,
3456,3456,3456,
3584,3584,3584,
3712,3712,3712,
3840,3840,3840,
3968,3968,3968,
4096,4096,4096,
Loading

0 comments on commit a513690

Please sign in to comment.