From 1c2e8165e68b87eebfe2e600e518164cfd5ae720 Mon Sep 17 00:00:00 2001
From: Zoey Sun
Date: Wed, 12 Feb 2025 13:07:09 -0800
Subject: [PATCH] Small modifications to quantize_bench script (#3684)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/760

This Diff makes the following changes:
1. Move the sanity checks on quant ops so they run ahead of time instead of inside the benchmarking functions.
2. Add an output directory option for storing benchmarking results (see the sketch below).
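For example, a minimal sketch of where results land after this change (the flag and file names below come from this diff; the invocation, kernel name, and /tmp/bench path are hypothetical illustrations, not part of the patch):

    # Hypothetical invocation:
    #   python quantize_bench.py --kernels fp8_rowwise --export_csv --plot --output_dir /tmp/bench
    # The script builds its output paths the same way main() does in this diff.
    import os

    output_dir = "/tmp/bench"  # assumed example; default is /tmp
    os.makedirs(output_dir, exist_ok=True)
    csv_file = os.path.join(output_dir, "quantize_ops_benchmark.csv")  # written when --export_csv is set
    img_fn = os.path.join(output_dir, "quantize_ops_benchmark.png")  # written when --plot is set
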
Reviewed By: jiawenliu64

Differential Revision: D69483115
---
 .../gen_ai/bench/quantize_bench.py | 275 ++++++++++--------
 1 file changed, 146 insertions(+), 129 deletions(-)

diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
index 7eeff8a8c..95560ae63 100644
--- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
+++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
@@ -6,7 +6,10 @@
 
 import argparse
 import itertools
+
 import os
+
+from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple
 
 import matplotlib.pyplot as plt
@@ -96,13 +99,36 @@ def get_ldm_shapes() -> List[Tuple[int, int, int, int]]:
     ]
 
 
+@dataclass
+class Metrics:
+    op_name: str
+
+    sim: float = 0.0
+    ms: float = 0.0
+    tflops: float = 0.0
+    gbps: float = 0.0
+
+    def __str__(self) -> str:
+        return (
+            "%s sim: %.3f.\n%s ms: %.3f. \n" "%s TFLOPS: %.3f. \n%s GB/s: %.3f."
+        ) % (
+            self.op_name,
+            self.sim,
+            self.op_name,
+            self.ms,
+            self.op_name,
+            self.tflops,
+            self.op_name,
+            self.gbps,
+        )
+
+
 def benchmark_grouped(
     quantize_ops: List[QuantizeOpBase],
     b: List[int],
     m: List[int],
     n: List[int],
     k: List[int],
-    kernels: Optional[List[str]] = None,
     bench_quantize: bool = False,
     use_rotating_buffer_bench: bool = False,
     use_cuda_graph: bool = True,
@@ -131,71 +157,60 @@ def benchmark_grouped(
     results: Dict[str, Any] = {"M": log_m, "N": log_n, "K": log_k, "groups": num_groups}
     # Benchmark each operator.
     for quantize_op in quantize_ops:
-        # If kernel filter is provided, skip kernels that arent requested.
-        kernel_requested = (kernels is None) or (
-            kernels is not None and quantize_op.name in kernels
-        )
-        # Also check if the operator is supported.
-        if kernel_requested and quantize_op.supported:
-            # Get the quantized tensors for this operator.
-            preprocessed_args = quantize_op.preprocess(A, B)
-            quantized_vals = quantize_op.quantize(*preprocessed_args)
-            # Compute the output given quantized values.
-            output = quantize_op.compute(*quantized_vals)
-            # Some kernels may pad output, just take the first m values of each row.
-            output = [o[: m[i]] for i, o in enumerate(output)]
-            # Compare the quantize op output to reference as a sanity check.
-            sim_check: float = 0
-            for i in range(num_groups):
-                sim_check += float(
-                    torch.mean(torch.pow(output[i] - out_ref[i], 2)).item()
+        metrics = Metrics(op_name=quantize_op.name)
+        # Get the quantized tensors for this operator.
+        preprocessed_args = quantize_op.preprocess(A, B)
+        quantized_vals = quantize_op.quantize(*preprocessed_args)
+        # Compute the output given quantized values.
+        output = quantize_op.compute(*quantized_vals)
+        # Some kernels may pad output, just take the first m values of each row.
+        output = [o[: m[i]] for i, o in enumerate(output)]
+        # Compare the quantize op output to reference as a sanity check.
+
+        for i in range(num_groups):
+            metrics.sim += float(
+                torch.mean(torch.pow(output[i] - out_ref[i], 2)).item()
+            )
+
+        # Now perform benchmark.
+        if bench_quantize:
+            # Benchmark both quantize and compute.
+            with profiler_or_nullcontext(enabled=trace, with_stack=True):
+                metrics.ms = quantize_op.benchmark(
+                    *preprocessed_args,
+                    bench_quantize=True,
+                    use_rotating_buffer_bench=use_rotating_buffer_bench,
+                    use_cuda_graph=use_cuda_graph,
+                )
+        else:
+            with profiler_or_nullcontext(enabled=trace, with_stack=True):
+                metrics.ms = quantize_op.benchmark(
+                    *quantized_vals,
+                    bench_quantize=False,
+                    use_rotating_buffer_bench=use_rotating_buffer_bench,
+                    use_cuda_graph=use_cuda_graph,
                 )
-            # Now perform benchmark.
-            if bench_quantize:
-                # Benchmark both quantize and compute.
-                with profiler_or_nullcontext(enabled=trace, with_stack=True):
-                    ms_runtime = quantize_op.benchmark(
-                        *preprocessed_args,
-                        bench_quantize=True,
-                        use_rotating_buffer_bench=use_rotating_buffer_bench,
-                        use_cuda_graph=use_cuda_graph,
-                    )
-            else:
-                with profiler_or_nullcontext(enabled=trace, with_stack=True):
-                    ms_runtime = quantize_op.benchmark(
-                        *quantized_vals,
-                        bench_quantize=False,
-                        use_rotating_buffer_bench=use_rotating_buffer_bench,
-                        use_cuda_graph=use_cuda_graph,
-                    )
-
-            # Print out results for this op.
-            tflops = 0
-            gbps = 0
-            for i in range(num_groups):
-                tflops += 2 * b[i] * m[i] * n[i] * k[i] / (ms_runtime / 1e3) / 1e12
-                gbps += (
-                    (
-                        quantized_vals[0][i][: m[i]].numel()
-                        * quantized_vals[0][i][: m[i]].element_size()
-                        + quantized_vals[1][i].numel()
-                        * quantized_vals[1][i].element_size()
-                        + output[i].numel() * output[i].element_size()
-                    )
-                    / (ms_runtime / 1e3)
-                    / 1e9
+        # Print out results for this op.
+        for i in range(num_groups):
+            metrics.tflops += 2 * b[i] * m[i] * n[i] * k[i] / (metrics.ms / 1e3) / 1e12
+            metrics.gbps += (
+                (
+                    quantized_vals[0][i][: m[i]].numel()
+                    * quantized_vals[0][i][: m[i]].element_size()
+                    + quantized_vals[1][i].numel() * quantized_vals[1][i].element_size()
+                    + output[i].numel() * output[i].element_size()
                 )
-            print(f"{quantize_op.name} sim: {sim_check:.3f}.")
-            print(f"{quantize_op.name} ms: {ms_runtime:.3f}.")
-            print(f"{quantize_op.name} TFLOPS: {tflops:.3f}.")
-            print(f"{quantize_op.name} GB/s: {gbps:.3f}.")
+                / (metrics.ms / 1e3)
+                / 1e9
+            )
+        print(metrics)
 
-            # Save results for this operator.
-            results[f"{quantize_op.name}_sim"] = sim_check
-            results[f"{quantize_op.name}_ms"] = ms_runtime
-            results[f"{quantize_op.name}_tflops"] = tflops
-            results[f"{quantize_op.name}_gb/s"] = gbps
+        # Save results for this operator.
+        results[f"{quantize_op.name}_sim"] = metrics.sim
+        results[f"{quantize_op.name}_ms"] = metrics.ms
+        results[f"{quantize_op.name}_tflops"] = metrics.tflops
+        results[f"{quantize_op.name}_gb/s"] = metrics.gbps
     return results
@@ -206,7 +221,6 @@ def benchmark(
     m: int,
     n: int,
     k: int,
-    kernels: Optional[List[str]] = None,
     bench_quantize: bool = False,
     use_rotating_buffer_bench: bool = False,
     use_cuda_graph: bool = True,
@@ -226,66 +240,58 @@ def benchmark(
     results: Dict[str, Any] = {"B": b, "M": m, "N": n, "K": k}
     # Benchmark each operator.
     for quantize_op in quantize_ops:
-        # If kernel filter is provided, skip kernels that arent requested.
-        kernel_requested = (kernels is None) or (
-            kernels is not None and quantize_op.name in kernels
-        )
-        # Also check if the operator is supported.
-        if kernel_requested and quantize_op.supported:
-            # Preprocess data if needed.
-            preprocessed_args = quantize_op.preprocess(A, B)
-            # Get the quantized tensors for this operator.
-            quantized_vals = quantize_op.quantize(*preprocessed_args)
-            # Compute the output given quantized values.
-            output = quantize_op.compute(*quantized_vals)
-            # Compare the quantize op output to reference as a sanity check.
-            sim_check = torch.mean(torch.pow(output - out_ref, 2))
-
-            # Now perform benchmark.
-            if bench_quantize:
-                # Benchmark both quantize and compute.
-                with profiler_or_nullcontext(enabled=trace, with_stack=True):
-                    ms_runtime = quantize_op.benchmark(
-                        *preprocessed_args,
-                        bench_quantize=True,
-                        use_rotating_buffer_bench=use_rotating_buffer_bench,
-                        use_cuda_graph=use_cuda_graph,
-                    )
-            else:
-                with profiler_or_nullcontext(enabled=trace, with_stack=True):
-                    ms_runtime = quantize_op.benchmark(
-                        *quantized_vals,
-                        bench_quantize=False,
-                        use_rotating_buffer_bench=use_rotating_buffer_bench,
-                        use_cuda_graph=use_cuda_graph,
-                    )
-
-            # Print out results for this op.
-            tflops = 2 * b * m * n * k / (ms_runtime / 1e3) / 1e12
-            gbps = (
-                (
-                    quantized_vals[0].numel() * quantized_vals[0].element_size()
-                    + quantized_vals[1].numel() * quantized_vals[1].element_size()
-                    + output.numel() * output.element_size()
+        metrics = Metrics(op_name=quantize_op.name)
+        # Preprocess data if needed.
+        preprocessed_args = quantize_op.preprocess(A, B)
+        # Get the quantized tensors for this operator.
+        quantized_vals = quantize_op.quantize(*preprocessed_args)
+        # Compute the output given quantized values.
+        output = quantize_op.compute(*quantized_vals)
+        # Compare the quantize op output to reference as a sanity check.
+        metrics.sim = torch.mean(torch.pow(output - out_ref, 2)).item()
+
+        # Now perform benchmark.
+        if bench_quantize:
+            # Benchmark both quantize and compute.
+            with profiler_or_nullcontext(enabled=trace, with_stack=True):
+                metrics.ms = quantize_op.benchmark(
+                    *preprocessed_args,
+                    bench_quantize=True,
+                    use_rotating_buffer_bench=use_rotating_buffer_bench,
+                    use_cuda_graph=use_cuda_graph,
                 )
+        else:
+            with profiler_or_nullcontext(enabled=trace, with_stack=True):
+                metrics.ms = quantize_op.benchmark(
+                    *quantized_vals,
+                    bench_quantize=False,
+                    use_rotating_buffer_bench=use_rotating_buffer_bench,
+                    use_cuda_graph=use_cuda_graph,
+                )
+
+        # Print out results for this op.
+        metrics.tflops = 2 * b * m * n * k / (metrics.ms / 1e3) / 1e12
+        metrics.gbps = (
+            (
+                quantized_vals[0].numel() * quantized_vals[0].element_size()
+                + quantized_vals[1].numel() * quantized_vals[1].element_size()
+                + output.numel() * output.element_size()
             )
-            print(f"{quantize_op.name} sim: {sim_check:.3f}.")
-            print(f"{quantize_op.name} ms: {ms_runtime:.3f}.")
-            print(f"{quantize_op.name} TFLOPS: {tflops:.3f}.")
-            print(f"{quantize_op.name} GB/s: {gbps:.3f}.")
+            / (metrics.ms / 1e3)
+            / 1e9
+        )
+        print(metrics)
 
-            # Save results for this operator.
-            results[f"{quantize_op.name}_sim"] = sim_check.item()
-            results[f"{quantize_op.name}_ms"] = ms_runtime
-            results[f"{quantize_op.name}_tflops"] = tflops
-            results[f"{quantize_op.name}_gb/s"] = gbps
+        # Save results for this operator.
+        results[f"{quantize_op.name}_sim"] = metrics.sim
+        results[f"{quantize_op.name}_ms"] = metrics.ms
+        results[f"{quantize_op.name}_tflops"] = metrics.tflops
+        results[f"{quantize_op.name}_gb/s"] = metrics.gbps
     return results
 
 
-def plot_benchmark(results: List[Dict[str, Any]]) -> None:
+def plot_benchmark(results: List[Dict[str, Any]], output_dir: str) -> None:
     """Create a barplot visualizing the TFLOPS of each kernel."""
     # Reprocess into new dataframe with proper graph format.
     data = []
@@ -306,21 +312,26 @@ def plot_benchmark(results: List[Dict[str, Any]]) -> None:
     plt.yscale("log")
     ax = sns.barplot(x="MNK", y="TFLOPS", hue="kernel", data=df)
     ax.tick_params(axis="x", labelsize=3)
-    plot.savefig("quantize_ops_benchmark.png", dpi=300)
+    img_fn = os.path.join(output_dir, "quantize_ops_benchmark.png")
+    plot.savefig(img_fn, dpi=300)
+
+
+def collect_kernels_to_profile(kernels: Optional[List[str]]) -> List[QuantizeOpBase]:
+    # Get existing quantization operators.
+    quantize_ops = get_quantize_ops()
+    quantize_ops = [op for op in quantize_ops if op.supported]
+    if kernels is None:
+        return quantize_ops
+    return [op for op in quantize_ops if op.name in kernels]
 
 
 def main(args: Any):
     if args.enable_amd_env_vars:
         set_amd_env_vars()
-
-    # Get operators to quantize.
-    quantize_ops = get_quantize_ops()
-
-    # If kernel filter is provided, parse it.
-    if args.kernels is not None:
-        kernels = args.kernels.strip().split(",")
-    else:
-        kernels = None
+    # If kernel filter is provided, parse it. Else, benchmark all kernels.
+    quantize_ops = collect_kernels_to_profile(
+        args.kernels.strip().split(",") if args.kernels else None
+    )
 
     # Enumerate shapes to benchmark.
     if args.grouped and not args.groups:
@@ -382,23 +393,29 @@ def main(args: Any):
             m,  # pyre-ignore[6]: Incompatible parameter type [6]
             n,  # pyre-ignore[6]: Incompatible parameter type [6]
             k,  # pyre-ignore[6]: Incompatible parameter type [6]
-            kernels,
             args.bench_quantize,
             args.use_rotating_buffer_bench,
             not args.no_cuda_graph,
            args.trace,
         )
         benchmark_results.append(quantize_measurements)
+    if args.export_csv or args.plot:
+        os.makedirs(args.output_dir, exist_ok=True)
+        print("csv and images will be saved to " + args.output_dir)
     if args.export_csv:
+        csv_file = os.path.join(args.output_dir, "quantize_ops_benchmark.csv")
         # Export results to a CSV file.
         df = pd.DataFrame(benchmark_results)
-        df.to_csv("quantize_ops_benchmark.csv", index=False)
+        df.to_csv(csv_file, index=False)
     if args.plot:
-        plot_benchmark(benchmark_results)
+        plot_benchmark(benchmark_results, args.output_dir)
 
 
 def invoke_main() -> None:
     parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--output_dir", default="/tmp", help="Directory to save plots and csvs to"
+    )
     parser.add_argument(
         "--export_csv",
         action="store_true",