From 2d2dbe15593bc59f11fafe3b069f615786be608c Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Fri, 10 May 2024 10:48:42 -0500 Subject: [PATCH 01/12] Add Perf Kernels Add Perf Kernels This is a combination of 2 commits. Add Perf Kernels Add Perf Kernels This is a combination of 6 commits. add perf-kernels fix formating issues fix unused variables and other bugs fix other issues remove scripts save check changes format save save try pre-commit check save --- .github/workflows/amd_perf_kernel_tests.yml | 133 ++ .../03-matrix-multiplication-all-types.py | 377 ++++ .../03-matrix-multiplication-stream-k.py | 395 +++++ python/perf-kernels/06-attention-decode.py | 730 ++++++++ .../06-fused-attention-fwd-transV.py | 308 ++++ .../perf-kernels/06-fused-attention-transV.py | 928 ++++++++++ python/perf-kernels/README.md | 63 + python/perf-kernels/flash-attention.py | 1527 +++++++++++++++++ python/perf-kernels/hbm-bw-test.py | 200 +++ ...trix-multiplication-stream-k-oldversion.py | 485 ++++++ ...iplication-stream-k-singlekern-autotune.py | 563 ++++++ ...ultiplication-stream-k-singleloop-nomod.py | 387 +++++ 12 files changed, 6096 insertions(+) create mode 100644 .github/workflows/amd_perf_kernel_tests.yml create mode 100644 python/perf-kernels/03-matrix-multiplication-all-types.py create mode 100755 python/perf-kernels/03-matrix-multiplication-stream-k.py create mode 100644 python/perf-kernels/06-attention-decode.py create mode 100644 python/perf-kernels/06-fused-attention-fwd-transV.py create mode 100644 python/perf-kernels/06-fused-attention-transV.py create mode 100644 python/perf-kernels/README.md create mode 100644 python/perf-kernels/flash-attention.py create mode 100644 python/perf-kernels/hbm-bw-test.py create mode 100644 python/perf-kernels/streamk/03-matrix-multiplication-stream-k-oldversion.py create mode 100644 python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singlekern-autotune.py create mode 100644 python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singleloop-nomod.py diff --git a/.github/workflows/amd_perf_kernel_tests.yml b/.github/workflows/amd_perf_kernel_tests.yml new file mode 100644 index 000000000000..07424924a832 --- /dev/null +++ b/.github/workflows/amd_perf_kernel_tests.yml @@ -0,0 +1,133 @@ +name: AMD Perf Kernel Tests + +on: + workflow_dispatch: + pull_request: + branches: [main_perf] + merge_group: + branches: [main_perf] + types: [checks_requested] + push: + branches: [main_perf] + +concurrency: + group: ${{ github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main_perf' }} + +permissions: read-all + +env: + TRITON_BUILD_WITH_CLANG_LLD: "TRUE" + TRITON_USE_ASSERT_ENABLED_LLVM: "TRUE" + TRITON_DISABLE_LINE_INFO: 1 + +jobs: + Check-File-Changes: + if: github.event_name == 'pull_request' + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Check file changes + run: | + git fetch origin ${{ github.base_ref }} + changed_files=$(git diff --name-only origin/${{ github.base_ref }} ${{ github.sha }}) + echo "Changed files:" + echo "$changed_files" + if echo "$changed_files" | grep -v "^python/perf-kernels/"; then + echo "Changes detected outside of the python/perf-kernels directory. Failing the workflow." 
+ exit 1 + fi + + Runner-Preparation-AMD: + runs-on: ubuntu-latest + timeout-minutes: 30 + outputs: + matrix-HIP: ${{ steps.set-matrix.outputs.matrix-HIP }} + steps: + - name: Prepare runner matrix + id: set-matrix + run: | + if [ x"${{ github.repository }}" == x"ROCm/triton" ]; then + echo '::set-output name=matrix-HIP::[["self-hosted", "rocm.gfx90a"]]' + else + echo '::set-output name=matrix-HIP::[["ubuntu-latest"]]' + fi + + pre-commit: + name: pre-commit (code formatting) + needs: Runner-Preparation-AMD + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: '3.12' + cache: 'pip' + - name: Compute hash of pre-commit config + id: cache-key + run: | + echo "pre_commit_hash=$(sha256sum .pre-commit-config.yaml)" >> $GITHUB_OUTPUT + shell: bash + - name: Cache pre-commit's cache dir + uses: actions/cache@v4 + with: + # Note that we cannot use environment variables here given there is + # no shell to interpret them in the paths. + path: | + ~/.cache/pre-commit + key: ${{ runner.os }}-${{ steps.cache-key.outputs.pre_commit_hash }} + - name: Check pre-commit + run: | + python3 -m pip install --upgrade pre-commit + # TODO: ignore the first yapf failure until https://github.com/google/yapf/issues/1164 is fixed + python3 -m pre_commit run --all-files --verbose yapf &> /dev/null || true + # If first run of yapf worked and made changes reset the tree to the original state + git reset --hard + python3 -m pre_commit run --all-files --verbose + - name: Print diff of changes if pre-commit failed + if: failure() + run: | + git diff + + Integration-Tests-AMD: + needs: Runner-Preparation-AMD + if: needs.Runner-Preparation-AMD.outputs.matrix-HIP != '' + runs-on: ${{ matrix.runner }} + timeout-minutes: 30 + strategy: + matrix: + runner: ${{fromJson(needs.Runner-Preparation-AMD.outputs.matrix-HIP)}} + container: + image: rocm/pytorch:rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2 + options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Clear cache + run: | + rm -rf ~/.triton + mkdir -p ~/.triton + ls -alh ~/.triton + - name: Update PATH + run: | + echo "/opt/rocm/llvm/bin" >> $GITHUB_PATH + - name: Install pip dependencies + run: | + python3 -m pip install --upgrade pip + python3 -m pip install lit matplotlib pandas + - name: Install Triton + run: | + echo "PATH is '$PATH'" + pip uninstall -y triton + cd python + pip install -v -e . 
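+ # NOTE: the editable install above builds Triton from source inside the ROCm
+ # container; /opt/rocm/llvm/bin was added to PATH earlier so the ROCm LLVM
+ # toolchain can be picked up during the build.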
+ - name: Run Perf Kernels Unit Tests + run: | + pytest -vvv ./python/perf-kernels/flash-attention.py + - name: Run Perf Kernels Benchmark + run: | + python ./python/perf-kernels/flash-attention.py diff --git a/python/perf-kernels/03-matrix-multiplication-all-types.py b/python/perf-kernels/03-matrix-multiplication-all-types.py new file mode 100644 index 000000000000..1b0676079ede --- /dev/null +++ b/python/perf-kernels/03-matrix-multiplication-all-types.py @@ -0,0 +1,377 @@ +import torch + +import triton +import triton.language as tl +import sys +import argparse +import pytest +import re + + +@triton.autotune( + configs=[ + triton.Config( + {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 4, 'waves_per_eu': 0}, + num_warps=8, num_stages=0), + triton.Config( + {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'waves_per_eu': 0}, + num_warps=8, num_stages=0), + triton.Config( + {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2}, + num_warps=4, num_stages=0), + triton.Config( + {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, + num_warps=8, num_stages=0), + triton.Config( + {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 32, 'waves_per_eu': 2}, + num_warps=4, num_stages=0), + ], + key=['M', 'N', 'K'], + use_cuda_graph=True, +) +@triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % args['BLOCK_SIZE_K'] == 0, +}) +@triton.jit +def matmul_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + # Matrix dimensions + M, + N, + K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + EVEN_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + ACTIVATION: tl.constexpr, +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + if GROUP_SIZE_M == 1: + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + else: + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. 
+ # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + # See above `Pointer Arithmetics` section for details + offs_k = tl.arange(0, BLOCK_SIZE_K) + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + acc_dtype = tl.float32 if c_ptr.type.element_ty != tl.int8 else tl.int32 + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + if EVEN_K: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + else: + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # We accumulate along the K dimension. + accumulator += tl.dot(a, b) + + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + # You can fuse arbitrary activation functions here + # while the accumulator is still in FP32! + if ACTIVATION == "leaky_relu": + accumulator = leaky_relu(accumulator) + c = accumulator.to(c_ptr.type.element_ty) + + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`. +@triton.jit +def leaky_relu(x): + x = x + 1 + return tl.where(x >= 0, x, 0.01 * x) + + +# %% +# We can now create a convenience wrapper function that only takes two input tensors, +# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel. + + +def matmul(a, b, c, activation=""): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + # assert a.is_contiguous(), "Matrix A must be contiguous" + # assert b.is_contiguous(), "Matrix B must be contiguous" + M, K = a.shape + K, N = b.shape + # 1D launch kernel where each block gets its own program. 
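+ # For illustration: with M = N = 1024 and an autotuned config of
+ # BLOCK_SIZE_M = BLOCK_SIZE_N = 128, this launches ceil(1024/128) * ceil(1024/128) = 64 programs.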
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + matmul_kernel[grid]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + ACTIVATION=activation, + ) + + +TORCH_HAS_FP8E5B16 = hasattr(torch, 'float8_e5m2fnuz') +TORCH_HAS_FP8E4B8 = hasattr(torch, 'float8_e4m3fnuz') +tl_to_torch_types = { + tl.float16: torch.float16, + tl.bfloat16: torch.bfloat16, + tl.float32: torch.float32, + tl.int8: torch.int8, + tl.int32: torch.int32, +} +if TORCH_HAS_FP8E5B16: + tl_to_torch_types[tl.float8e5b16] = torch.float8_e5m2fnuz +if TORCH_HAS_FP8E4B8: + tl_to_torch_types[tl.float8e4b8] = torch.float8_e4m3fnuz + +name_to_tl_types = { + 'int8': tl.int8, + 'int32': tl.int32, + 'fp16': tl.float16, + 'fp32': tl.float32, + 'bf16': tl.bfloat16, + 'fp8e4': tl.float8e4b8, + 'fp8e5': tl.float8e5b16, +} + + +def gen_input(M, N, ty_name, needTrans, seed, device='cuda'): + d_type = name_to_tl_types[ty_name] + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @triton.jit + def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + input = tl.load(input_ptr + offsets, mask=mask) + output = input + tl.store(output_ptr + offsets, output, mask=mask) + + if needTrans: + raw_data = torch.randn((N, M), dtype=torch.float32, device='cuda').T + else: + raw_data = torch.randn((M, N), dtype=torch.float32, device='cuda') + # avoid type conversion rounding errors of subnormal values + raw_data += 0.1 + if d_type == tl.float8e4b8: + raw_data += torch.sign(raw_data) + + if (d_type == tl.float8e4b8 and TORCH_HAS_FP8E4B8) or \ + (d_type == tl.float8e5b16 and TORCH_HAS_FP8E5B16) or not d_type.is_fp8(): + input = raw_data.to(tl_to_torch_types[d_type]) + input_f16 = input.to(torch.float16) + else: + f8_tensor = raw_data.to(torch.int8) + # keep only two bits of exponent to avoid overflow + f8_tensor = f8_tensor & 0b00111111 + input = triton.reinterpret(f8_tensor, d_type) + input_f16 = torch.empty_like(f8_tensor, dtype=torch.float16) + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + n_elements = raw_data.numel() + copy_kernel[grid](input, input_f16, n_elements, BLOCK_SIZE=1024) + + return input, input_f16 + + +# %% +# Unit Test +# --------- +# +# We can test our custom matrix multiplication operation against a native torch implementation (i.e., rocBLAS). +def get_x_vals(): + x_vals = [(1024 * v, 1024 * v, 1024 * v) for v in range(1, 9)] + + x_vals += [(4864, 4096, 8192), (9728, 8192, 65536)] + + return x_vals + + +@pytest.mark.parametrize("M, N, K, in_dtype, out_dtype, col_a, col_b", [ + (*shape, in_dtype, out_dtype, col_a, col_b) + for shape in get_x_vals() + for in_dtype, out_dtype in [('fp16', 'fp16'), ('bf16', 'bf16'), ('fp16', + 'fp32'), ('fp32', + 'fp32'), ('fp8e4', + 'fp16'), ('fp8e5', 'fp16'), + #('int8', 'int8'), + ('int8', 'int32')] + # Only test k-major tensors because + # 1. This is the most preformant config and the current focus + # 2. Other case does not work with num_stages=0 (TODO (zhanglx)) + for col_a in [True, False] + for col_b in [True, False] +]) +def test_correctness(M, N, K, col_a, col_b, in_dtype, out_dtype): + a, a_fp16 = gen_input(M, K, in_dtype, col_a, 1, device='cuda') + b, b_fp16 = gen_input(K, N, in_dtype, col_b, 2, device='cuda') + # Allocates output. 
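+ # The output dtype follows the requested out_dtype; matmul() stores the result into c in place.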
+ tl_out_dtype = name_to_tl_types[out_dtype] + torch_out_dtype = tl_to_torch_types[tl_out_dtype] + c = torch.empty((M, N), device=a.device, dtype=torch_out_dtype) + matmul(a, b, c, activation="") + if in_dtype == 'fp8e4' or in_dtype == 'fp8e5' or in_dtype == 'int8': + # For f8 and int8 inputs, use fp16 for torch.matmul + torch_output = torch.matmul(a_fp16, b_fp16) + else: + torch_output = torch.matmul(a, b) + #print(f"triton_output={c}") + #print(f"torch_output={torch_output}") + rtol = 0 if torch.version.hip is None else 1e-2 + if in_dtype == 'int8': + torch.testing.assert_close(c.to(torch.float16), torch_output, atol=1e-3, rtol=rtol) + else: + torch.testing.assert_close(c, torch_output.to(torch_out_dtype), atol=5e-3, rtol=rtol) + + +# %% +# Benchmark +# --------- +# +# Square Matrix Performance +# ~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# We can now compare the performance of our kernel against that of rocBLAS. Here we focus on square matrices, +# but feel free to arrange this script as you wish to benchmark any other matrix shape. + + +def get_type(provider): + res = re.findall(r'\(.*?\)', provider) + return res[0][1:-1] + + +inout_dtype = { + 'int8': torch.int8, + 'fp16': torch.float16, + 'fp32': torch.float32, + 'bf16': torch.bfloat16, + 'fp8e4': torch.float16, + 'fp8e5': torch.float16, +} + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=['M', 'N', 'K'], # Argument names to use as an x-axis for the plot + x_vals=get_x_vals(), + line_arg='provider', # Argument name whose value corresponds to a different line in the plot + # Possible values for `line_arg` + line_vals=[ + 'rocblas(fp16)', 'rocblas(bf16)', 'triton(fp16)', 'triton(bf16)', 'triton(int8)', 'triton(fp8e4)', + 'triton(fp8e5)' + ], + # Label name for the lines + line_names=[ + "rocBLAS.Fp16", "rocBLAS.Bf16", "Triton.Fp16", "Triton.Bf16", "Triton.Int8", "Triton.Fp8E4", "Triton.Fp8E5" + ], + ylabel="TFLOPS", # Label name for the y-axis + plot_name="matmul-performance", # Name for the plot, used also as a file name for saving the plot. + args={}, + )) +def benchmark(M, N, K, provider): + in_dtype = get_type(provider) + out_dtype = inout_dtype[in_dtype] + + quantiles = [0.5, 0.2, 0.8] + if 'rocblas' in provider: + a = torch.randn((M, K), dtype=tl_to_torch_types[name_to_tl_types[in_dtype]], device='cuda') + b = torch.randn((K, N), dtype=tl_to_torch_types[name_to_tl_types[in_dtype]], device='cuda') + + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) + else: # triton, different data types + assert "triton" in provider + a, _ = gen_input(M, K, in_dtype, False, 1, device='cuda') + b, _ = gen_input(K, N, in_dtype, True, 2, device='cuda') + # Allocates output. 
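+ # c is allocated once outside the timed lambda so do_bench measures only the kernel itself.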
+ c = torch.empty((M, N), device=a.device, dtype=out_dtype) + + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c, activation=""), quantiles=quantiles) + global verbose + if verbose: + print(f'SIZE: {M},{N},{K} Best tuning config: ({matmul_kernel.get_best_config()})') + perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) + return perf(ms), perf(max_ms), perf(min_ms) + + +def parse_args(): + parser = argparse.ArgumentParser( + prog="GEMM tutorial example", + allow_abbrev=False, + ) + + parser.add_argument("-v", action='store_true', default=False, help="Print out the best tuning config") + args = parser.parse_args() + + return args + + +def main(): + # assign to a global verbose var to indicate whether print + # best tuning config + global verbose + args = parse_args() + verbose = args.v + benchmark.run(show_plots=True, print_data=True) + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/python/perf-kernels/03-matrix-multiplication-stream-k.py b/python/perf-kernels/03-matrix-multiplication-stream-k.py new file mode 100755 index 000000000000..62d820719b9a --- /dev/null +++ b/python/perf-kernels/03-matrix-multiplication-stream-k.py @@ -0,0 +1,395 @@ +#!/usr/bin/env python +## matmul stream-k implementation +## Credit goes to @pommedeterresautee +## See https://github.com/openai/triton/issues/1393 + +# (echo 'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"') | sudo tee -a /etc/modprobe.d/RestrictedProfiling.conf >/dev/null +# sudo update-initramfs -u -k all +# cat /proc/driver/nvidia/params | grep RmProfilingAdminOnly +# sudo apt-get install zlib1g-dev +# for reproductible experiments +# sudo nvidia-smi -pm 1 -i 0 +# sudo nvidia-smi -i 0 -pl 350 # 400 for A100 +# sudo nvidia-smi -i 0 -lgc 1005 +from typing import Optional + +import torch +import triton +import triton.language as tl +import random + +#from triton.runtime.driver import CudaUtils +import json + +torch.manual_seed(123) +random.seed(123) + +#device = torch.cuda.current_device() +#cuda_utils = CudaUtils() +#total_sm = cuda_utils.get_device_properties(device)["multiprocessor_count"] +#total_sm = 110 # for MI250 +total_sm = 304 # for MI300X +print(f"total SMs: {total_sm}") + +# --------------------------------------------------------------------------- +# Triton kernels +# --------------------------------------------------------------------------- + + +@triton.jit() +def swizzle_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr): + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = tile_id // width + group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (tile_id % group_size) + pid_n = (tile_id % width) // group_size + return pid_m, pid_n + + +@triton.jit() +def linear_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr): + pid_m = tile_id // tl.cdiv(N, BLOCK_N) + pid_n = tile_id % tl.cdiv(N, BLOCK_N) + return pid_m, pid_n + + +@triton.jit() +def streamk_gemm( + A, + B, + C, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + total_full_tiles_streamk, + total_partial_tiles_streamk, + iters_per_tile, + total_tiles_streamk, + total_programs_streamk, + ACC_TYPE: tl.constexpr, + GROUP_M: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid = 
tl.program_id(0) + + # Determine whether we are in the first wave or full_tiles phase based on pid + is_first_wave = pid < total_programs_streamk and total_programs_streamk > 0 + + # Calculate starting and ending iterations for first wave + if not is_first_wave: + tile_id = tl.program_id(0) + total_tiles_streamk - total_programs_streamk + if GROUP_M > 0: + pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + else: + pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rk = tl.arange(0, BLOCK_K) + # pointers + A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(0, tl.cdiv(K, BLOCK_K)): + a = tl.load(A_BASE) + b = tl.load(B_BASE) + acc += tl.dot(a, b) + A_BASE += BLOCK_K * stride_ak + B_BASE += BLOCK_K * stride_bk + # acc = acc.to(tl.float16) # restore C.dtype.element_ty + # rematerialize rm and rn to save registers +# rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) +# rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn + tl.store(C_, acc) + else: + # start_iter = pid * total_full_tiles_streamk + tl.minimum(pid, total_partial_tiles_streamk) + start_iter = pid * total_full_tiles_streamk + tl.minimum(pid, total_partial_tiles_streamk) + last_iter = (pid + 1) * total_full_tiles_streamk + tl.minimum(pid + 1, total_partial_tiles_streamk) + while start_iter < last_iter: + remainder = start_iter % iters_per_tile + end_iter = tl.minimum(start_iter + (iters_per_tile - remainder), last_iter) + # where are we in the grid + tile_id = start_iter // iters_per_tile + if GROUP_M > 0: + pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + else: + pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rk = tl.arange(0, BLOCK_K) + A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_K * stride_ak * remainder + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_K * stride_bk * remainder + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for current_iter in range(start_iter, end_iter): + a = tl.load(A_BASE) + b = tl.load(B_BASE) + acc += tl.dot(a, b) + A_BASE += BLOCK_K * stride_ak + B_BASE += BLOCK_K * stride_bk + + if remainder == 0 and end_iter % iters_per_tile == 0: + C_ = C + rm[:, + None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! + tl.store(C_, acc) + else: + C_ = C + rm[:, + None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! 
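+ # This tile is shared with other programs in the stream-k phase, so the partial
+ # accumulator is combined with an atomic add; the host wrapper zero-initializes C
+ # so the partial sums compose correctly.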
+ tl.atomic_add(C_, acc) + + start_iter = end_iter + + +# --------------------------------------------------------------------------- +# Wrapper +# --------------------------------------------------------------------------- + + +class matmul(torch.autograd.Function): + + _debug = True + + @staticmethod + def set_debug(debug: bool): + matmul._debug = debug + + @staticmethod + def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M: int, BLK_N: int, BLK_K: int, + two_tiles: bool, num_stages: int, num_warps: int, waves_per_eu: int, mfmaInstrSize: int, kpack: int): + device = a.device + + assert a.is_contiguous() and b.is_contiguous(), "non-contiguous inputs are not supported" + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + # accumulator types + ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + # compute grid (work to do per SM on the first wave) + total_blocks_M = triton.cdiv(M, BLK_M) + total_blocks_N = triton.cdiv(N, BLK_N) + iters_per_tile = triton.cdiv(K, BLK_K) + GROUP_M = 4 # 0 to disable swizzling + total_tiles = total_blocks_M * total_blocks_N + + if total_programs_streamk > 0: # Stream-K + # last wave may occupy less than total_programs_streamk SMs + total_tiles_streamk = total_tiles % total_programs_streamk + # for two-tile Stream-K + data-parallel from original paper + if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: + total_tiles_streamk += total_programs_streamk + # remaining tiles are computed using classical blocking + total_blocking_tiles = total_tiles - total_tiles_streamk + total_iters_streamk = total_tiles_streamk * iters_per_tile + # iterations related to full waves + total_full_tiles_streamk = total_iters_streamk // total_programs_streamk + # iterations related to last (partial) wave + total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk + + else: # all tiles are computed using classical blocking + total_blocking_tiles = total_tiles + total_tiles_streamk = 0 + total_full_tiles_streamk = 0 + total_partial_tiles_streamk = 0 + total_iters_streamk = 0 + + if matmul._debug: + print(f"M,N,K={M},{N},{K} ; BLK_M,N,K={BLK_M},{BLK_N},{BLK_K}") + print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}") + print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}") + print(f"{total_programs_streamk=}") + print(f"{total_blocking_tiles=}") + print(f"{total_full_tiles_streamk=}") + print(f"{total_partial_tiles_streamk=}") + print(f"{iters_per_tile=}") + print(f"{total_iters_streamk=}") + + # allocates output + c = torch.zeros((M, N), device=device, dtype=a.dtype) + # allocates locks to sync work accross SMs + grids = total_programs_streamk + total_blocking_tiles + kk = streamk_gemm[(grids, )]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + total_full_tiles_streamk=total_full_tiles_streamk, + total_partial_tiles_streamk=total_partial_tiles_streamk, + iters_per_tile=iters_per_tile, + total_tiles_streamk=total_tiles_streamk, + total_programs_streamk=total_programs_streamk, + ACC_TYPE=ACC_TYPE, + GROUP_M=GROUP_M, + BLOCK_M=BLK_M, + BLOCK_N=BLK_N, + BLOCK_K=BLK_K, + num_stages=num_stages, + num_warps=num_warps, + waves_per_eu=waves_per_eu, + matrix_instr_nonkdim=mfmaInstrSize, + kpack=kpack, + ) + if matmul._debug: + print(f"{kk.n_regs} registers used, {kk.n_spills} spills") + + # print(kk.asm['ttgir']) + 
# print(kk.asm['amdgcn']) + + return c + + @staticmethod + def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLK_M=128, BLK_N=128, BLK_K=32, two_tiles=True, + num_stages=3, num_warps=4, waves_per_eu=2, mfmaInstrSize=16, kpack=1): + return matmul._call(a=a, b=b, total_programs_streamk=grid, BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K, + two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages, waves_per_eu=waves_per_eu, + mfmaInstrSize=mfmaInstrSize, kpack=kpack) + + +# --------------------------------------------------------------------------- +# Example and Benchmark +# --------------------------------------------------------------------------- + +perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3) + +#m, n, k = 4864, 4096, 8256 # some problem size to test +#m, n, k = 4096, 4096, 8192 # some problem size to test +#m, n, k = 8192, 8192, 8192 # some problem size to test +m, n, k = 6912, 768, 256 # some problem size to test +A = torch.randn(m, k, device="cuda", dtype=torch.float16) +B = torch.randn(k, n, device="cuda", dtype=torch.float16) +BLK_M = 64 +BLK_N = 64 +BLK_K = 64 +two_tiles = 'True' +num_stages = 0 +num_warps = 4 +waves_per_eu = 0 +mfmaInstrSize = 16 +kpack = 2 + +matmul.set_debug(True) +C = matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, + kpack) +#exit(0) +matmul.set_debug(False) +expected = A @ B + +#assert torch.allclose(C, expected, atol=1), f"max: {(C - expected).abs().max().item()}\n{C}\n{expected}" +print("pass validation test") + +# for debugging, uncomment the following line +# exit(0) + +triton_ms = triton.testing.do_bench(lambda: torch.matmul(A, B)) +print(f"PyTorch: {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") + +triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, + num_warps, waves_per_eu, mfmaInstrSize, kpack)) +print(f"hybrid stream-k (grid={total_sm}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") + +triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm * 2, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, + num_warps, waves_per_eu, mfmaInstrSize, kpack)) +print(f"hybrid stream-k (grid={total_sm * 2}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") + +triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, 0, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, + waves_per_eu, mfmaInstrSize, kpack)) +print(f"tile matmul (grid=0): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") + +exit(0) +# --------------------------------------------------------------------------- +# Log-sampled benchmark +# --------------------------------------------------------------------------- + +# tried to reproduce the tests described in the paper +num_samples = 1000 # 32768 +step = 256 +values = ((torch.logspace(torch.tensor(step).log2(), + torch.tensor(8192).log2(), num_samples, base=2) / step).round() * step).unique().tolist() +shapes = [(int(m), int(n), int(k)) for m in values for n in values for k in values] +shapes = random.sample(shapes, num_samples) +assert len(shapes) == num_samples + +results = [] +for idx, (m, n, k) in enumerate(shapes): + # print progress bar + if idx % 10 == 0 and idx > 0: + speedups = [r["speedup"] for r in results] + print(f"{idx}/{num_samples} - average speedup: {sum(speedups) / len(speedups):.3f}") + + A = torch.randn(m, k, device="cuda", dtype=torch.float16) + B = torch.randn(k, n, device="cuda", dtype=torch.float16) + output: Optional[torch.Tensor] = None + + def 
wrapper_matmul(*args, **kwargs): + global output + output = matmul.apply(*args, **kwargs) + return output + + expected = A @ B + pytorch_ms = triton.testing.do_bench(lambda: A @ B) + measures = list() + for two_tiles in [True, False]: + nb_sm = [total_sm, total_sm * 2] + total_tile = (m // BLK_M) * (n // BLK_N) + if total_tile < total_sm * 2: + nb_sm.append(total_tile) + nb_sm += random.sample(range(2, total_sm * 2, 2), 10) + for sm in nb_sm: + triton_ms = triton.testing.do_bench( + lambda: wrapper_matmul(A, B, sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu)) + max_disc = (output - expected).abs().max().item() + # large tolerance to accomodate for large K (rounding due to half precision), we just want to catch bugs. + assert max_disc <= 5., f"pb size: {m}x{n}x{k} - max discrepancy: {max_disc} - sm: {sm}, 2 tiles: {two_tiles}\n{output}\n{expected}" + info = { + "2 tiles": two_tiles, + "sm": sm, + "disc": max_disc, + "triton_ms": triton_ms, + } + measures.append(info) + best_triton_ms = min([m["triton_ms"] for m in measures]) + d = { + "m": m, + "n": n, + "k": k, + "triton": measures, + "pytorch_ms": pytorch_ms, + "speedup": pytorch_ms / best_triton_ms, + } + results.append(d) + measures = list() + +results.sort(key=lambda x: x["speedup"], reverse=False) + +# --------------------------------------------------------------------------- +# Benchmark export +# --------------------------------------------------------------------------- + +with open("results.json", "w") as f: + json.dump(results, f, indent=4) + +# 32760/32768 - average speedup: 0.962 (A100) +# 990/1000 - average speedup: 1.063 (3090 RTX with while loop and 2 tiles disabled / enabled) diff --git a/python/perf-kernels/06-attention-decode.py b/python/perf-kernels/06-attention-decode.py new file mode 100644 index 000000000000..3f38e5031eca --- /dev/null +++ b/python/perf-kernels/06-attention-decode.py @@ -0,0 +1,730 @@ +from typing import Optional +import pytest +import torch +import sys + +import triton +import triton.language as tl + + +def _strides(x: torch.Tensor, *stride_names: str): + assert x.ndim == len(stride_names) + return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)} + + +@triton.jit +def _fwd_kernel_splitK( + Q, + K, + V, + sm_scale, + Out_splitK, # [B, H, split_k, Mq, K] + Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] + Seq_len, + stride_qz, + stride_qm, + stride_qg, + stride_qh, + stride_qk, + stride_kz, + stride_kn, + stride_kg, + stride_kh, + stride_kk, + stride_vz, + stride_vn, + stride_vg, + stride_vh, + stride_vk, + stride_osk_zhg, + stride_osk_s, + stride_osk_m, + stride_osk_k, + stride_mzhg, + stride_m2, + stride_ms, + stride_mm, + Z, + N_CTX_Q, + N_CTX_K, + BLOCK_N_PER_SPLIT, + H: tl.constexpr, + G: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + BOUNDS_CHECKS_N: tl.constexpr, + USE_SEQ_LEN: tl.constexpr, + PACKED_PER_VAL: tl.constexpr = 1, + N_GROUPS: tl.constexpr = 1, +): + """This kernel can accept non-quantized or int4-quantized keys/values. + PACKED_PER_VAL determines the quantization type: + - PACKED_PER_VAL == 1 means no quantization + - PACKED_PER_VAL == 8 means 4-bit quantization (8 packed quantized values inside one int32) + For the quantized case K/V should be int32 tensors. + Quantization can be row-wise (when N_GROUPS = 1) or group-wise with N_GROUPS = 2, 4, or 8. 
+ Quantization coefficients are stored at the beginning of the row along the last dimension of K/V + So K[B, H, M, :] has a form + [ quant_coef0, quant_coef1, ...| + group0_quant_value0, group0_quant_value1,... | + group1_quant_value0, group1_quant_value1,...] + where each quant_coef is an int32 which should be interpreted as 2 packed float16: scale and offset. + + """ + tl.static_assert( + (PACKED_PER_VAL == 1 and tl.constexpr(K.dtype.element_ty != tl.int32)) + or (PACKED_PER_VAL == 8 and tl.constexpr(K.dtype.element_ty == tl.int32)), + f"Only 4-bit quantization is supported, K/V should have dtype int32 in " + f"the quantized case: {PACKED_PER_VAL=} {tl.constexpr(K.dtype)=} {tl.constexpr(K.dtype.element_ty)=}", + ) + tl.static_assert( + (((N_GROUPS == 1 or N_GROUPS == 2) or N_GROUPS == 4) or N_GROUPS == 8), + "Number of quantization groups can be 1 (row-wise quantization), 2, 4, or 8.", + ) + + QUANTIZED: tl.constexpr = PACKED_PER_VAL > 1 + PACKED_D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // PACKED_PER_VAL // N_GROUPS + D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // N_GROUPS + + start_m = tl.program_id(0) + off_zhg = tl.program_id(1) + off_z = off_zhg // (H * G) + off_h = (off_zhg // G) % H + off_g = off_zhg % G + splitk_idx = tl.program_id(2) + + lo = splitk_idx * BLOCK_N_PER_SPLIT + if USE_SEQ_LEN: + kv_len = tl.load(Seq_len + off_z) + else: + kv_len = N_CTX_K + hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len) + + Q_block_ptr = tl.make_block_ptr( + base=Q + off_h * stride_qh + off_z * stride_qz + off_g * stride_qg, + shape=(N_CTX_Q, D_PER_GROUP), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, D_PER_GROUP), + order=(1, 0), + ) + + k_base = K + off_h * stride_kh + off_z * stride_kz + off_g * stride_kg + # Additional shift by 1 along the last dimension in the quantized case, since + # the first element along that dim contains packed quantization coefficients. 
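+ # For illustration (an assumed config, not enforced here): with BLOCK_DMODEL = 128,
+ # PACKED_PER_VAL = 8 and N_GROUPS = 1, each quantized K/V row is laid out as
+ # [ 1 int32 of packed scale/shift | 16 int32 holding 128 int4 values ],
+ # so the block pointer below skips N_GROUPS elements to reach the packed values.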
+ K_block_ptr = tl.make_block_ptr( + base=k_base + stride_kk * QUANTIZED * N_GROUPS, + shape=(PACKED_D_PER_GROUP, hi), + strides=(stride_kk, stride_kn), + offsets=(0, lo), + block_shape=(PACKED_D_PER_GROUP, BLOCK_N), + order=(0, 1), + ) + v_base = V + off_h * stride_vh + off_z * stride_vz + off_g * stride_vg + V_block_ptr = tl.make_block_ptr( + base=v_base + stride_vk * QUANTIZED * N_GROUPS, + shape=(hi, PACKED_D_PER_GROUP), + strides=(stride_vn, stride_vk), + offsets=(lo, 0), + block_shape=(BLOCK_N, PACKED_D_PER_GROUP), + order=(1, 0), + ) + + if QUANTIZED: + # Pointers to quantization coefficients + K_scale_shift_block_ptr = tl.make_block_ptr( + base=k_base, + shape=(1, hi), + strides=(stride_kk, stride_kn), + offsets=(0, lo), + block_shape=(1, BLOCK_N), + order=(0, 1), + ) + V_scale_shift_block_ptr = tl.make_block_ptr( + base=v_base, + shape=(hi, 1), + strides=(stride_vn, stride_vk), + offsets=(lo, 0), + block_shape=(BLOCK_N, 1), + order=(1, 0), + ) + else: + K_scale_shift_block_ptr = None + V_scale_shift_block_ptr = None + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + + acc = tl.zeros([BLOCK_M, D_PER_GROUP], dtype=tl.float32) # noqa: F821 + + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout + q = tl.load( # noqa: F821 + tl.advance(Q_block_ptr, (0, 0)), boundary_check=(0, )) + q = (q * qk_scale).to(q.dtype) + + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + k, v = load_dequantize_k_v_group( + K_block_ptr, + V_block_ptr, + K_scale_shift_block_ptr, + V_scale_shift_block_ptr, + BOUNDS_CHECKS_N, + PACKED_PER_VAL, + PACKED_D_PER_GROUP, + Q.dtype.element_ty, + 0, + ) + + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) # noqa: F821 + + # TODO: This is slow, and only needed at the last iteration. + # Maybe we can unroll the last iteration instead? 
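+ # Out-of-range key columns are masked to -inf so exp2 maps them to 0 and they
+ # drop out of the softmax normalization.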
+ if BOUNDS_CHECKS_N: + qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf")) + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + p = p.to(Q.dtype.element_ty) + + # -- scale and update acc -- + acc *= alpha[:, None] + acc += tl.dot(p, v) + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + if PACKED_PER_VAL > 1: + K_scale_shift_block_ptr = tl.advance(K_scale_shift_block_ptr, (0, BLOCK_N)) + V_scale_shift_block_ptr = tl.advance(V_scale_shift_block_ptr, (BLOCK_N, 0)) + + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out_splitK + off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s, + shape=(N_CTX_Q, D_PER_GROUP), + strides=(stride_osk_m, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, D_PER_GROUP), + order=(1, 0), + ) + tl.store( + tl.advance(O_block_ptr, (0, 0)), + acc, + boundary_check=(0, ), + ) + # Write metadata for split-K reduction + Metadata_ptr = (Metadata + off_zhg * stride_mzhg + splitk_idx * stride_ms + start_m * BLOCK_M + + tl.arange(0, BLOCK_M)) + tl.store(Metadata_ptr, m_i) + tl.store(Metadata_ptr + stride_m2, l_i) + + +@triton.jit +def load_dequantize_k_v_group( + K_block_ptr, + V_block_ptr, + K_scale_shift_block_ptr, + V_scale_shift_block_ptr, + BOUNDS_CHECKS_N: tl.constexpr, + PACKED_PER_VAL: tl.constexpr, + PACKED_D_PER_GROUP: tl.constexpr, + dtype: tl.constexpr, + group_id: tl.constexpr, +): + #Load K/V for a given block. In case of int4-quantized K/V, + # dequantize them after loading. If quantization is group-wise, + # use group_id to advance the pointers to the current group. + + # Advance to the current quantization group + K_block_ptr = tl.advance(K_block_ptr, (PACKED_D_PER_GROUP * group_id, 0)) + V_block_ptr = tl.advance(V_block_ptr, (0, PACKED_D_PER_GROUP * group_id)) + + # -- load k, v -- + k = tl.load(K_block_ptr, boundary_check=(1, ) if BOUNDS_CHECKS_N else ()) + v = tl.load(V_block_ptr, boundary_check=(0, ) if BOUNDS_CHECKS_N else ()) + + if PACKED_PER_VAL > 1: + # K/V are quantized, load quantization coefficients and dequantize + K_scale_shift_block_ptr = tl.advance(K_scale_shift_block_ptr, (group_id, 0)) + V_scale_shift_block_ptr = tl.advance(V_scale_shift_block_ptr, (0, group_id)) + + k_scale_shift = tl.load(K_scale_shift_block_ptr, boundary_check=(1, ) if BOUNDS_CHECKS_N else ()) + v_scale_shift = tl.load(V_scale_shift_block_ptr, boundary_check=(0, ) if BOUNDS_CHECKS_N else ()) + + k_scale, k_shift = cast_uint32_to_half2(k_scale_shift) + v_scale, v_shift = cast_uint32_to_half2(v_scale_shift) + v = dequantize(v, v_scale, v_shift, PACKED_PER_VAL).to(dtype) + k_t = dequantize( + tl.trans(k), + tl.trans(k_scale), + tl.trans(k_shift), + PACKED_PER_VAL, + ).to(dtype) + k = tl.trans(k_t) + return k, v + + +@triton.jit +def cast_uint32_to_half2(scale_shift): + # Extract two float16 packed into one int32 + scale = scale_shift & 0xFFFF + shift = scale_shift >> 16 + scale = scale.to(tl.uint16).to(tl.float16, bitcast=True) + shift = shift.to(tl.uint16).to(tl.float16, bitcast=True) + return scale, shift + + +@triton.jit +def dequantize( + x_, + scale, + shift, + PACKED_PER_VAL: tl.constexpr = 8, +): + # PACKED_PER_VAL is the number of values packed into + # each element x_. For example, for int4 quantization + #and x_ of type int32, PACKED_PER_VAL is 8. 
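+ # For illustration: an int32 tile of shape (BLOCK_N, 16) with PACKED_PER_VAL = 8
+ # dequantizes to a (BLOCK_N, 128) fp16 tile.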
+ + BLOCK_N: tl.constexpr = x_.shape[0] + BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1] + offsets = tl.arange(0, PACKED_PER_VAL) * 4 + quant_offset = (x_[:, None, :] >> offsets[None, :, None]) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL) + + quant_offset = tl.view(quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL)) + # Trick - instead of converting int4 to float16 we view it as float16 + # and then multiply by 32768 * 512 == 2**24 + quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True) + quant_offset = (quant_offset * 32768.0).to(tl.float16) + scale_512 = scale * 512 + + dequant = quant_offset * scale_512 + shift + return dequant + + +@triton.jit +def _splitK_reduce( + Out_splitK, # [B, H, split_k, Mq, K] + Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] + Out, # [B, H, M, K] + LSE, # [B, H, M] + stride_osk_zhg, + stride_osk_s, + stride_osk_m, + stride_osk_k, + stride_mzhg, + stride_m2, + stride_ms, + stride_mm, + stride_oz, + stride_oh, + stride_og, + stride_om, + stride_ok, + stride_lse_zhg, + stride_lse_m, + M_ceil: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + H: tl.constexpr, + G: tl.constexpr, + split_k: tl.constexpr, + splitK_pow2: tl.constexpr, + use_mask: tl.constexpr, +): + off_zhg = tl.program_id(0) + off_z = off_zhg // (H * G) + off_h = (off_zhg // G) % H + off_g = off_zhg % G + off_m = tl.program_id(1) + off_k = tl.program_id(2) + + # read chunk + spk_idx = tl.arange(0, splitK_pow2) + kidx = tl.arange(0, BLOCK_SIZE) + + Metadata_ptr = (Metadata + stride_mzhg * off_zhg + spk_idx * stride_ms + off_m * stride_mm) + + o_ptr = (Out_splitK + off_zhg * stride_osk_zhg + stride_osk_m * off_m + off_k * BLOCK_SIZE + + stride_osk_s * spk_idx[:, None] + kidx[None, :] * stride_osk_k) + + # read max values of each splitK + if use_mask: + spk_mask = spk_idx < split_k + l_m = tl.load(Metadata_ptr, mask=spk_mask, other=float("-inf")) + l_sum = tl.load(Metadata_ptr + stride_m2, mask=spk_mask, other=0.0) + acc = tl.load(o_ptr, mask=spk_mask[:, None], other=0.0) + else: + l_m = tl.load(Metadata_ptr) + l_sum = tl.load(Metadata_ptr + stride_m2) + acc = tl.load(o_ptr) + + g_m = tl.max(l_m, axis=0) + alpha = tl.math.exp2(l_m - g_m) + + # read sum + l_sum *= alpha + g_sum = tl.sum(l_sum, axis=0) + acc = acc * alpha[:, None] + acc_out = tl.sum(acc, axis=0) / g_sum + Out_ptr = (Out + stride_oz * off_z + stride_oh * off_h + stride_og * off_g + stride_om * off_m + + off_k * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)) + tl.store(Out_ptr, acc_out) + l_ptrs = LSE + off_zhg * stride_lse_zhg + off_m + tl.store(l_ptrs, (g_m + tl.math.log2(g_sum)) / 1.44269504) + + +def quantize_kv_int4(k: torch.Tensor, num_groups: int = 1) -> torch.Tensor: + # Scale and shift are such that quantization linearly maps + # int4 values range [0..15] to input values range min(k)..max(k) + # individually for every row + k = k.reshape(*k.shape[:-1], num_groups, k.shape[-1] // num_groups) + max_vals = torch.max(k, dim=-1, keepdim=True).values + min_vals = torch.min(k, dim=-1, keepdim=True).values + scale_k: torch.Tensor = (max_vals - min_vals) / 15 + + shift_k = torch.min(k, dim=-1, keepdim=True).values + scale_k = scale_k.to(torch.float16) + shift_k = shift_k.to(torch.float16) + + in_bytes = ((k - shift_k.expand(k.shape)) / scale_k.expand(k.shape)) + 0.5 + in_bytes = in_bytes.to(torch.uint8) + in_int4 = in_bytes & 0xF + in_int4_packed = in_int4[..., ::2] + (in_int4[..., 1::2] << 4) + scale_shift = torch.concat([scale_k.view(torch.uint8), shift_k.view(torch.uint8)], dim=-1) + k_quant = 
torch.concat( + [ + scale_shift.flatten(start_dim=-2), + in_int4_packed.flatten(start_dim=-2), + ], + dim=-1, + ).view(torch.int16) + return k_quant + + +def dequantize_kv_fp16(quant_k: torch.Tensor, num_groups: int = 1) -> torch.Tensor: + k_i16 = quant_k.view(torch.int16) + k_ui8 = k_i16.view(torch.uint8) + + ss_size = num_groups * 4 + scale_shift_ui8 = k_ui8[..., 0:ss_size] + scale_shift_ui8 = scale_shift_ui8.reshape(*scale_shift_ui8.shape[:-1], num_groups, 4) + scale = scale_shift_ui8[..., 0:2].view(torch.float16) + shift = scale_shift_ui8[..., 2:4].view(torch.float16) + + kv_ui8 = k_ui8[..., ss_size:] + k_ui8 = kv_ui8.reshape(*kv_ui8.shape[:-1], num_groups, -1) + k1_i4 = k_ui8 & 0xF + k2_i4 = (k_ui8 & 0xF0) >> 4 + k_shape = k1_i4.shape + k1_f16 = k1_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape) + k2_f16 = k2_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape) + + out = torch.empty((*k1_f16.shape[:-1], k1_f16.shape[-1] * 2), dtype=torch.float16, device=quant_k.device) + out[..., ::2] = k1_f16 + out[..., 1::2] = k2_f16 + out = out.reshape(*k_shape[:-2], -1) + + return out + + +def get_split_k(B: int, G: int, H: int, Mk: int) -> int: + """Heuristic for the number of splits""" + bh = max(B * H, 1) # NOTE: Handle B*h=0 case + split_k = max(Mk, 1024) // bh + max_chunk_size = 64 + while split_k > 0 and Mk / split_k < max_chunk_size: + split_k = split_k // 2 + while B * H * G * split_k >= 1024: + split_k = split_k // 2 + split_k = min(split_k, 512) + split_k = max(split_k, 1) + return split_k + + +class _attention(torch.autograd.Function): + + OPERATOR = _fwd_kernel_splitK + SUPPORTED_DEVICES = {"cuda"} + CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0) + SUPPORTED_DTYPES = { + torch.half, + torch.bfloat16, + } + SUPPORTED_MAX_K = 128 + SUPPORTS_DROPOUT = False + SUPPORTS_CUSTOM_SCALE = True + SUPPORTS_BMGHK = True + NAME = "triton_splitKF" + + @staticmethod + def forward(cls, q, k, v, scale_float): + + cls.SPLIT_K: Optional[int] = None + cls.BLOCK_M = 16 + cls.BLOCK_N = 64 + + cls.NUM_GROUPS = 1 # Default quantization is row-wise + + # attn_bias = inp.attn_bias + seq_len = None + + # Transpose in the case of MQA/GQA + mqa_swap_seqlen_head = False + if k.shape[3] > 1 and k.stride(3) == 0 and v.stride(3) == 0: + mqa_swap_seqlen_head = True + assert q.shape[1] == 1 + q = q.transpose(1, 3) + k = k[:, :, :, :1] + v = v[:, :, :, :1] + + if k.dtype == torch.int32: + # Quantized K/V + PACKED_PER_VAL = 8 + Lk = (k.shape[-1] - cls.NUM_GROUPS) * 8 + else: + Lk = k.shape[-1] + PACKED_PER_VAL = 1 + + B, Mk, G, H, Kkv = k.shape + B, M, G, H, Kq = q.shape + assert Lk == Kq, f"Keys have head dim {Lk} but queries have head dim {Kq}" + # print(f"B = {B}, M = {M}, G = {G}, H = {H}, Kkv = {Kkv}, Kq = {Kq}") + + BLOCK_M = cls.BLOCK_M + BLOCK_N = cls.BLOCK_N + if cls.SPLIT_K is not None: + split_k = cls.SPLIT_K + else: + # Use heuristics + split_k = get_split_k(B, G, H, Mk) + + M_ceil = (M + BLOCK_M - 1) // BLOCK_M * BLOCK_M + o_splitk = torch.empty([B * G * H, split_k, M_ceil, Kq], dtype=torch.float32, device=q.device) + metadata = torch.empty([B * G * H, 2, split_k, M_ceil], dtype=torch.float32, device=q.device) + lse = torch.empty((B * G * H, M), device=q.device, dtype=torch.float32) + grid = (triton.cdiv(M, BLOCK_M), B * G * H, split_k) + + num_warps = 1 + split_size = (Mk + split_k - 1) // split_k + use_seq_len = seq_len is not None + + # print(f"B = {B}, G = {G}, H = {H}, split_k = {split_k}, M_ceil = {M_ceil}, Kq = {Kq}, num_of_wgs = {G * G * H * split_k}") + + 
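+ # Grid is (query blocks, B * G * H, split_k): each program processes one split of
+ # the keys/values for one query block and writes partial outputs plus per-row
+ # (m_i, l_i) metadata, which _splitK_reduce merges into the final output and LSE.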
_fwd_kernel_splitK[grid]( + Q=q, + K=k, + V=v, + sm_scale=scale_float, + Out_splitK=o_splitk, + Metadata=metadata, + Seq_len=seq_len, + **_strides(q, "qz", "qm", "qg", "qh", "qk"), + **_strides(k, "kz", "kn", "kg", "kh", "kk"), + **_strides(v, "vz", "vn", "vg", "vh", "vk"), + **_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"), + **_strides(metadata, "mzhg", "m2", "ms", "mm"), + Z=B, + H=H, + G=G, + N_CTX_Q=M, + N_CTX_K=Mk, + BLOCK_N_PER_SPLIT=split_size, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=Lk, + BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_seq_len, + USE_SEQ_LEN=use_seq_len, + num_warps=num_warps, + num_stages=1, + PACKED_PER_VAL=PACKED_PER_VAL, + N_GROUPS=cls.NUM_GROUPS if PACKED_PER_VAL > 1 else 1, + ) + + if mqa_swap_seqlen_head: + out = torch.empty((B, H, G, M, Kq), device=q.device, dtype=q.dtype).transpose(1, 3) + else: + out = torch.empty((B, M, G, H, Kq), device=q.device, dtype=q.dtype) + + # Merge together + splitK_pow2 = triton.next_power_of_2(split_k) + use_mask = splitK_pow2 > split_k + if B * G * H * M >= 512: + k_block_num = 1 + else: + k_block_num = 2 + assert out.shape[-1] % k_block_num == 0 + k_block_size = out.shape[-1] // k_block_num + grid = (B * G * H, M, k_block_num) + _splitK_reduce[grid]( + o_splitk, metadata, out, lse, **_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"), + **_strides(metadata, "mzhg", "m2", "ms", "mm"), **_strides(out, "oz", "om", "og", "oh", "ok"), + **_strides(lse, "lse_zhg", "lse_m"), M_ceil=M_ceil, BLOCK_SIZE=k_block_size, G=G, H=H, + # TODO: Tune num_warps + split_k=split_k, splitK_pow2=splitK_pow2, use_mask=use_mask, num_warps=4) + + lse = lse.reshape([B, G, H, M]) + if mqa_swap_seqlen_head: + # H/M dimensions have been swapped + out = out.transpose(1, 3) + lse = lse.transpose(2, 3) + if q.ndim == 4: + # BMGHK -> BMHK + assert G == 1 + out = out[:, :, 0] + lse = lse[:, 0] + if Mk == 0: + out.zero_() + if mqa_swap_seqlen_head: + out = out.reshape(B, -1, M * G, Kq).transpose(1, 2).contiguous() + else: + out = out.reshape(B, H * G, -1, Kq).contiguous() + + return out + + +attention = _attention.apply + + +def get_input_shapes(): + cases = [(max(1, 2**(16 - i)), 1, 2**i, 16, 1, 128) + for i in range(8, 18)] + [(max(1, 2**(16 - i)), 1, 2**i, 16, 2, 128) for i in range(8, 18)] + + return cases + + +@pytest.mark.parametrize('B, Mq, Mkv, Hq, Hkv, K', get_input_shapes()) +def test_op_fwd(B, Mq, Mkv, Hq, Hkv, K, dtype=torch.float16): + torch.manual_seed(20) + q = (torch.empty((B, Mq, Hkv, (Hq + Hkv - 1) // Hkv, K), dtype=dtype, + device="cuda").normal_(mean=0., std=0.5).requires_grad_()) + k = (torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, + device="cuda").normal_(mean=0., + std=0.5).requires_grad_()).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1) + v = (torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, + device="cuda").normal_(mean=0., + std=0.5).requires_grad_()).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1) + scale = 1 / K**0.5 + tri_out = attention(q, k, v, scale) + + q = q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3) + k = k.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + v = v.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + attn = (q @ k.transpose(-1, -2) * scale).softmax(-1) + ref_out = attn @ v + + # compare + torch.testing.assert_close(ref_out, tri_out, atol=1e-3, rtol=0) + + +@pytest.mark.parametrize('B, Mq, Mkv, Hq, Hkv, K', get_input_shapes()) +def test_op_fwd_int4_kv(B, Mq, Mkv, Hq, Hkv, K, dtype=torch.float16): + torch.manual_seed(2) + q = (torch.empty((B, Mq, Hkv, (Hq + Hkv - 1) // Hkv, K), dtype=dtype, + 
device="cuda").normal_(mean=1.0, std=0.5).requires_grad_()) + k = (torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, + device="cuda").normal_(mean=1.0, + std=0.5).requires_grad_()).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1) + v = (torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, + device="cuda").normal_(mean=1.0, + std=0.5).requires_grad_()).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1) + + num_groups = 1 + quant_k = (quantize_kv_int4(k, num_groups=num_groups).contiguous().view(torch.int32)) + quant_v = (quantize_kv_int4(v, num_groups=num_groups).contiguous().view(torch.int32)) + scale = 1 / K**0.5 + tri_out = attention(q, quant_k, quant_v, scale) + + q = q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3) + k = k.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + v = v.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + attn = (q @ k.transpose(-1, -2) * scale).softmax(-1) + ref_out = attn @ v + # compare + torch.testing.assert_close(ref_out, tri_out, atol=2.1e-2, rtol=0) + + # since quantization introduces rounding error, use the + # dequantized kv as inputs to the ref implementation to reduce + # the tolerance to 1e-3 + dqk = dequantize_kv_fp16(quant_k, num_groups=num_groups) + dqv = dequantize_kv_fp16(quant_v, num_groups=num_groups) + dqk = dqk.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + dqv = dqv.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + dq_attn = (q @ dqk.transpose(-1, -2) * scale).softmax(-1) + dq_ref_out = dq_attn @ dqv + torch.testing.assert_close(dq_ref_out, tri_out, atol=1e-3, rtol=0) + + +def test_quantization(): + a = torch.randn((2, 4, 32), dtype=torch.float16, device='cuda') + qa = quantize_kv_int4(a, num_groups=4) + dqa = dequantize_kv_fp16(qa, num_groups=4) + torch.testing.assert_close(a, dqa, atol=1.5e-1, rtol=1e-1) + + +try: + FLASH_VER = 2 +except BaseException: + try: + FLASH_VER = 1 + except BaseException: + FLASH_VER = None +HAS_FLASH = FLASH_VER is not None + +configs = [] +for mode in ['fwd']: + # for D_HEAD in [128]: + for causal in [False]: + configs.append( + triton.testing.Benchmark( + x_names=['B', 'Mq', 'Mkv', 'Hq', 'Hkv', 'K'], x_vals=get_input_shapes(), line_arg='provider', + line_vals=['triton'] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []), styles=[('red', '-'), + ('blue', '-')], + ylabel='ms', plot_name=f'fused-attention-d{128}-{mode}-causal={causal}', args={ + # 'D_HEAD': D_HEAD, + 'dtype': torch.float16, 'mode': mode, 'causal': causal + })) + + +@triton.testing.perf_report(configs) +def bench_flash_attention(B, Mq, Mkv, Hq, Hkv, K, causal, mode, provider, dtype=torch.float16, device="cuda"): + assert mode in ['fwd', 'bwd'] + warmup = 100 + rep = 400 + ms = 0 + if provider == "triton": + q = torch.randn([B, Mq, Hkv, Hq // Hkv, K], device="cuda", dtype=dtype, requires_grad=False) + k = torch.randn([B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, + requires_grad=False).expand(-1, -1, -1, Hq // Hkv, -1) + v = torch.randn([B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, + requires_grad=False).expand(-1, -1, -1, Hq // Hkv, -1) + + sm_scale = 1.3 + fn = lambda: attention(q, k, v, sm_scale) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + + # flops_per_matmul = 2 * B * Hq * (Mq * K * Mkv + Mq * Mkv * K) + # total_flops = 2 * flops_per_matmul + # totalBytes = ((B * Mkv * Hkv * K * 2) + (B * Mq * Hq * K) + (B * Mq * Hq * K)) * 2 + + # return totalBytes / ms * 1e-9 + return ms * 1000 + + +def main(): + bench_flash_attention.run(save_path='.', print_data=True) + + +if __name__ == '__main__': + 
sys.exit(main()) diff --git a/python/perf-kernels/06-fused-attention-fwd-transV.py b/python/perf-kernels/06-fused-attention-fwd-transV.py new file mode 100644 index 000000000000..53517a395c8d --- /dev/null +++ b/python/perf-kernels/06-fused-attention-fwd-transV.py @@ -0,0 +1,308 @@ +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) + +Extra Credits: +- Original flash attention paper (https://arxiv.org/abs/2205.14135) +- Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) +- Adam P. Goucher for simplified vector math + +""" + +import pytest +import torch +import sys + +import triton +import triton.language as tl + +# Pick the fp8 data type + +# AMD E5M2B16 +# float8:tl.constexpr = torch.float8_e5m2fnuz + +# AMD E4M3B8 +# Note: When picking this f8 data type, scaling is required when using f8 +# for the second gemm +TORCH_HAS_FP8E4 = hasattr(torch, 'float8_e4m3fnuz') +float8: tl.constexpr = None if not TORCH_HAS_FP8E4 else torch.float8_e4m3fnuz + + +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) + + +@triton.jit +def _attn_fwd( + Q, + K, + V, + sm_scale, + M, + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vn, + stride_vk, + stride_oz, + stride_oh, + stride_om, + stride_on, + Z, + H, + N_CTX, + BLOCK_DMODEL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + pre_load_v: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + qkv_offset = off_hz * stride_qh + Q_block_ptr = tl.make_block_ptr(base=Q + qkv_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) + K_block_ptr = tl.make_block_ptr(base=K + qkv_offset, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_kk, stride_kn), + offsets=(0, 0), block_shape=(BLOCK_DMODEL, BLOCK_N), order=(0, 1)) + V_block_ptr = tl.make_block_ptr(base=V + qkv_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_vk, stride_vn), + offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_DMODEL), order=(0, 1)) + # initialize offsets + # offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + q = tl.load(Q_block_ptr) + # it's even better to multiply the qk_scale and convert to f16 + # than doing it inside the loop + # So conversion is quite cheap + q = (q * qk_scale).to(q.dtype) + lo, hi = 0, N_CTX + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(K_block_ptr) + if pre_load_v: + v = tl.load(V_block_ptr) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + #qk = (qk * qk_scale) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not pre_load_v: + v = tl.load(V_block_ptr) + acc += tl.dot(p.to(v.dtype), v) + # -- update m_i and l_i + 
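+ # l_ij is the current block's row sum of probabilities; the running sum l_i is
+ # rescaled by alpha = 2^(m_i - m_ij), matching the rescaling applied to acc above.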
l_ij = tl.sum(p, 1) + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + acc = acc / l_i[:, None] + # write back O + O_block_ptr = tl.make_block_ptr(base=Out + qkv_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) + tl.store(O_block_ptr, acc.to(Out.type.element_ty)) + + +empty = torch.empty(128, device="cuda") + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, sm_scale): + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-2] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + o = torch.empty_like(q, dtype=v.dtype) + if torch.version.hip is None: + BLOCK_M = 128 + BLOCK_N = 64 if Lk <= 64 else 32 + num_stages = 4 if Lk <= 64 else 3 + num_warps = 4 if Lk <= 64 else 8 + + ## hardcoded best perf_configs for MI250 + if Lk == 64: + ## D_HEAD = 64 + BLOCK_M = 128 + BLOCK_N = 64 + waves_per_eu = 3 + num_warps = 4 + num_stages = 1 + ## causal=False likes to pre load v but causal=True does not + pre_load_v = False if causal else True + slice_k_tile = 32 + kpack = 1 + else: + ## D_HEAD = 128 + ## For fp16, pick BLOCK_M=256, num_warps=8 + ## For fp8, pick BLOCK_M=128, num_warps=4 + ## TODO (zhanglx): add tuning infra for FA + BLOCK_M = 128 if TORCH_HAS_FP8E4 and q.dtype == torch.float8_e4m3fnuz else 256 + BLOCK_N = 128 + waves_per_eu = 2 + num_warps = BLOCK_M // 32 + num_stages = 1 + pre_load_v = False + slice_k_tile = 32 + kpack = 1 + + grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1) + M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + + _attn_fwd[grid]( + q, + k, + v, + sm_scale, + M, + o, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), + q.shape[0], + q.shape[1], + N_CTX=q.shape[2], + BLOCK_DMODEL=Lk, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + waves_per_eu=waves_per_eu, + num_warps=num_warps, + num_stages=num_stages, + pre_load_v=pre_load_v, + slice_k_tile=slice_k_tile, + kpack=kpack, + ) + + return o + + +attention = _attention.apply + +name_to_torch_types = {'fp16': torch.float16, 'bf16': torch.bfloat16, 'fp8': float8} + + +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD, dtype', + [(*shape, dtype) + for shape in [(4, 48, 1024, 128), (4, 48, 2048, 128), (4, 48, 4096, 128)] + for dtype in ['fp16', 'bf16', 'fp8']]) +def test_op_fwd(Z, H, N_CTX, D_HEAD, dtype): + torch.manual_seed(20) + init_dtype = torch.float16 if dtype == 'fp8' else name_to_torch_types[dtype] + q = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=init_dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()) + k = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=init_dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()) + v = (torch.empty((Z, H, D_HEAD, N_CTX), dtype=init_dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()) + sm_scale = 0.5 + # reference implementation + # M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + p = torch.softmax(p.float(), dim=-1).to(q.dtype) + ref_out = torch.matmul(p, v.transpose(2, 3)) + # triton implementation + # q,k casting for partial fp8 + q = 
q.to(name_to_torch_types[dtype]) + k = k.to(name_to_torch_types[dtype]) + # dout = torch.randn_like(q, dtype=torch.float16) + tri_out = attention(q, k, v, sm_scale) + # compare + atol = 1.4e-1 if dtype == 'fp8' else 1e-2 + rtol = 1e-2 if dtype == 'fp8' else 3e-3 + torch.testing.assert_close(ref_out, tri_out, atol=atol, rtol=rtol) + + +try: + FLASH_VER = 2 +except BaseException: + try: + FLASH_VER = 1 + except BaseException: + FLASH_VER = None +HAS_FLASH = FLASH_VER is not None + +# vary seq length for fixed head and batch=4 +configs = [] +for dtype in ['fp16', 'bf16', 'fp8']: + for D_HEAD in [128]: + for causal in [False]: + configs.append( + triton.testing.Benchmark( + x_names=['BATCH', 'H', 'N_CTX'], x_vals=[ + (16, 16, 1024), + (8, 16, 2048), + (4, 16, 4096), + (2, 16, 8192), + (1, 16, 16384), + (4, 48, 1024), + (4, 48, 2048), + (4, 48, 4096), + (4, 48, 8192), + (4, 48, 16384), + ], line_arg='provider', line_vals=['triton'], line_names=['Triton'], + #styles=[('red', '-'), ('blue', '-')], + ylabel='ms', plot_name=f'fused-attention-fwd-d{D_HEAD}-causal={causal}-{dtype}', + args={'D_HEAD': D_HEAD, 'dtype': dtype, 'causal': causal})) + + +@triton.testing.perf_report(configs) +def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, provider, dtype, device="cuda"): + if dtype == 'fp8' and not TORCH_HAS_FP8E4: + sys.exit("fp8 is not available") + warmup = 25 + rep = 100 + init_dtype = torch.float16 if dtype != 'bf16' else torch.bfloat16 + q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=init_dtype, device="cuda", requires_grad=True) + k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=init_dtype, device="cuda", requires_grad=True) + v = torch.randn((BATCH, H, D_HEAD, N_CTX), dtype=init_dtype, device="cuda", requires_grad=True) + sm_scale = 1.3 + # q,k casting for partial fp8 + q = q.to(name_to_torch_types[dtype]) + k = k.to(name_to_torch_types[dtype]) + fn = lambda: attention(q, k, v, sm_scale) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + flops_per_matmul = 2. * BATCH * H * N_CTX * N_CTX * D_HEAD + total_flops = 2 * flops_per_matmul + return total_flops / ms * 1e-9 + + +def main(): + bench_flash_attention.run(save_path='.', print_data=True) + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/python/perf-kernels/06-fused-attention-transV.py b/python/perf-kernels/06-fused-attention-transV.py new file mode 100644 index 000000000000..60113d3aa17d --- /dev/null +++ b/python/perf-kernels/06-fused-attention-transV.py @@ -0,0 +1,928 @@ +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) + +Extra Credits: +- Original flash attention paper (https://arxiv.org/abs/2205.14135) +- Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) +- Adam P. 
Goucher for simplified vector math + +""" + +import pytest +import torch + +import triton +import triton.language as tl + +torch_dtype: tl.constexpr = torch.float16 +TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2fnuz') +if TORCH_HAS_FP8E5: + torch_dtype: tl.constexpr = torch.float8_e5m2fnuz + + +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) + + +@triton.jit +def _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGE: tl.constexpr, + offs_m: tl.constexpr, + offs_n: tl.constexpr, + N_CTX, + pre_load_v: tl.constexpr, +): + # range of values handled by this stage + if STAGE == 1: + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + lo = tl.multiple_of(lo, BLOCK_M) + K_block_ptr = tl.advance(K_block_ptr, (0, lo)) + V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) + # causal = False + else: + lo, hi = 0, N_CTX + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(K_block_ptr) + if pre_load_v: + v = tl.load(V_block_ptr) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + if STAGE == 2: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) + qk = tl.where(mask, qk, float("-inf")) + qk += tl.dot(q, k) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not pre_load_v: + v = tl.load(V_block_ptr) + acc += tl.dot(p.to(v.dtype), v) + # -- update m_i and l_i + l_ij = tl.sum(p, 1) + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + return acc, l_i, m_i + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, + num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, + num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, + num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 0, 'pre_load_v': True}, + num_stages=1, num_warps=4), # d64-False + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 0, 'pre_load_v': False}, + num_stages=1, num_warps=4), # d64-True + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'slice_k_tile': 32, 'pre_load_v': False}, + num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 32, 'pre_load_v': False}, + num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 32, 'pre_load_v': False}, + num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 32, 'pre_load_v': True}, + num_stages=1, num_warps=4), # d64-False + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 32, 'pre_load_v': False}, + num_stages=1, num_warps=4), # d64-True + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'slice_k_tile': 64, 'pre_load_v': False}, + num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 
128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 64, 'pre_load_v': False}, + num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 64, 'pre_load_v': False}, + num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 64, 'pre_load_v': True}, + num_stages=1, num_warps=4), # d64-False + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 64, 'pre_load_v': False}, + num_stages=1, num_warps=4), # d64-True + ], + key=['Z', 'H', 'N_CTX', 'STAGE', 'BLOCK_DMODEL'], +) +@triton.jit +def _attn_fwd( + Q, + K, + V, + sm_scale, + M, + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vn, + stride_vk, + stride_oz, + stride_oh, + stride_om, + stride_on, + Z, + H, + N_CTX, + BLOCK_DMODEL: tl.constexpr, + STAGE: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + pre_load_v: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + qkv_offset = off_hz * stride_qh + Q_block_ptr = tl.make_block_ptr(base=Q + qkv_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) + K_block_ptr = tl.make_block_ptr(base=K + qkv_offset, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_kk, stride_kn), + offsets=(0, 0), block_shape=(BLOCK_DMODEL, BLOCK_N), order=(0, 1)) + V_block_ptr = tl.make_block_ptr(base=V + qkv_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_vk, stride_vn), + offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_DMODEL), order=(0, 1)) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout on NV GPUs but in VGPRs on AMD GPUs + q = tl.load(Q_block_ptr) + q = (q * qk_scale).to(q.dtype) + # stage 1: off-band + # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE + # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE + if STAGE & 1: + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + 4 - STAGE, + offs_m, + offs_n, + N_CTX, + pre_load_v, + ) + # stage 2: on-band + if STAGE & 2: + # barrier makes it easier for compiler to schedule the + # two loops independently + tl.debug_barrier() + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + 2, + offs_m, + offs_n, + N_CTX, + pre_load_v, + ) + # epilogue + # write back m + acc = acc / l_i[:, None] + m_ptrs = M + off_hz * N_CTX + offs_m + tl.store(m_ptrs, m_i + tl.math.log2(l_i)) + # write back O + O_block_ptr = tl.make_block_ptr(base=Out + qkv_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) + tl.store(O_block_ptr, acc.to(Out.type.element_ty)) + + +@triton.jit +def _bwd_preprocess( + Out, + DO, + NewDO, + Delta, + BLOCK_M: 
tl.constexpr, + D_HEAD: tl.constexpr, +): + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_n = tl.arange(0, D_HEAD) + # load + o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + # compute + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do) + tl.store(Delta + off_m, delta) + + +@triton.jit +def _bwd_kernel( + Q, + K, + V, + sm_scale, + Out, + DO, + DQ, + DK, + DV, + L, + D, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + Z, + H, + N_CTX, + P_SEQ, + num_block_q, + num_block_kv, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + CAUSAL: tl.constexpr, +): + off_hz = tl.program_id(0) + off_z = off_hz // H + off_h = off_hz % H + qk_scale = sm_scale * 1.44269504 + # offset pointers for batch/head + Q += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_kz + off_h * stride_kh + V += off_z * stride_vz + off_h * stride_vh + DO += off_z * stride_qz + off_h * stride_qh + DQ += off_z * stride_qz + off_h * stride_qh + DK += off_z * stride_kz + off_h * stride_kh + DV += off_z * stride_vz + off_h * stride_vh + # See fwd pass above for explanation. + qk_scale = sm_scale * 1.44269504 + for start_n in range(0, num_block_kv): + if CAUSAL: + lo = tl.math.max(start_n * BLOCK_M - P_SEQ, 0) + else: + lo = 0 + # initialize row/col offsets + offs_qm = lo + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) + offs_m = tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_DMODEL) + # initialize pointers to value-like data + q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + v_ptrs = V + (offs_n[None, :] * stride_qm + offs_k[:, None] * stride_qk) + do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + # pointer to row-wise quantities in value-like data + D_ptrs = D + off_hz * N_CTX + l_ptrs = L + off_hz * N_CTX + # initialize dk amd dv + dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # k and v stay in SRAM throughout + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + # loop over rows + for start_m in range(lo, num_block_q * BLOCK_M, BLOCK_M): + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + q = tl.load(q_ptrs) + # recompute p = softmax(qk, dim=-1).T + if CAUSAL: + qk = tl.where(P_SEQ + offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf")) + else: + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, tl.trans(k)) + l_i = tl.load(l_ptrs + offs_m_curr) + p = tl.math.exp2(qk * qk_scale - l_i[:, None]) + # compute dv + do = tl.load(do_ptrs) + dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do) + # compute dp = dot(v, do) + Di = tl.load(D_ptrs + offs_m_curr) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] + dp += tl.dot(do, v) + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + # compute dk = dot(ds.T, q) + dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q) + # compute dq + dq = tl.load(dq_ptrs) + dq += tl.dot(ds.to(Q.dtype.element_ty), k) + tl.store(dq_ptrs, dq) + # increment pointers + dq_ptrs += BLOCK_M * stride_qm + 
q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_qm + # write-back + dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) + tl.store(dk_ptrs, dk) + tl.store(dv_ptrs, dv) + + +@triton.jit +def _bwd_kernel_dk_dv( + Q, + K, + V, + sm_scale, + Out, + DO, + DK, + DV, + L, + D, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + Z, + H, + N_CTX, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + # Q is consumed depending on block ID. Every block uses + # previous block offset by BLOCK_M x D_HEAD. + qvk_offset = off_hz * stride_qh + qdo_offset = qvk_offset + start_m * BLOCK_M * stride_qm + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # offs_d = tl.arange(0, BLOCK_DMODEL) + # Initialize pointers to Q, K, V + Q_block_ptr = tl.make_block_ptr(base=Q + qdo_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_qm, stride_qk), + offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_DMODEL), order=(1, 0)) + K_block_ptr = tl.make_block_ptr(base=K + qvk_offset, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_kk, stride_kn), + offsets=(0, start_m * BLOCK_M), block_shape=(BLOCK_DMODEL, BLOCK_N), order=(0, 1)) + V_block_ptr = tl.make_block_ptr(base=V + qvk_offset, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_vn, stride_vk), + offsets=(0, start_m * BLOCK_M), block_shape=(BLOCK_DMODEL, BLOCK_N), order=(0, 1)) + DO_block_ptr = tl.make_block_ptr(base=DO + qdo_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_qm, stride_qk), + offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_DMODEL), order=(1, 0)) + # pointer to row-wise quantities in value-like data + D_ptrs = D + off_hz * N_CTX + l_ptrs = L + off_hz * N_CTX + qk_scale = sm_scale * 1.44269504 + # load k and v: they will stay in SRAM throughout + k = tl.load(K_block_ptr) + k = (k * qk_scale).to(k.dtype) + v = tl.load(V_block_ptr) + dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # This lower loop bound is because of the causal mask. We create a lower triangular + # result. The upper triangular is -inf (becomes 0 when we do e^x). As such, it can + # be ignored in the GEMM. 
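+    # For example (illustrative numbers, not a tuned configuration): with
+    # BLOCK_M = 64 and start_m = 2, this program owns K/V rows 128..191. Under
+    # the causal mask, query rows 0..127 never attend to those keys, so the
+    # q/do loop below can start at lo = start_m * BLOCK_M = 128 instead of 0.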
+ lo = start_m * BLOCK_M + hi = N_CTX + # loop over q, do + for start_n in range(lo, hi, BLOCK_N): + offs_m_curr = offs_n[:, None] + start_n + # -- load q, do -- + q = tl.load(Q_block_ptr) + do = tl.load(DO_block_ptr) + # -- compute qk ---- + qk = tl.dot(q, k) + qk = tl.where(offs_m_curr >= offs_m[None, :], qk, float("-inf")) + l_i = tl.load(l_ptrs + offs_m_curr) + p = tl.math.exp2(qk - l_i) + # -- compute dv ---- + dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do) + # compute dp = dot(v, do) + Di = tl.load(D_ptrs + offs_m_curr) + dp = tl.zeros([BLOCK_N, BLOCK_M], dtype=tl.float32) - Di + dp += tl.dot(do, v) + # compute ds = p * (dp - delta[:, None]) + ds = p * dp + # compute dk + dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q) + # update pointers + Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_N, 0)) + DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_N, 0)) + # initialize pointers to output + DK_block_ptr = tl.make_block_ptr(base=DK + qvk_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_kn, stride_kk), + offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) + DV_block_ptr = tl.make_block_ptr(base=DV + qvk_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_vk, stride_vn), + offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) + tl.store(DK_block_ptr, (dk * sm_scale).to(k.dtype)) + tl.store(DV_block_ptr, dv.to(v.dtype)) + + +@triton.jit +def _bwd_kernel_dq( + Q, + K, + V, + sm_scale, + Out, + DO, + DQ, + L, + D, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + Z, + H, + N_CTX, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + qvk_offset = off_hz * stride_qh + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # offs_d = tl.arange(0, BLOCK_DMODEL) + # Initialize pointers to Q, K, V + Q_block_ptr = tl.make_block_ptr(base=Q + qvk_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) + K_block_ptr = tl.make_block_ptr(base=K + qvk_offset, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_kk, stride_kn), + offsets=(0, 0), block_shape=(BLOCK_DMODEL, BLOCK_N), order=(0, 1)) + V_block_ptr = tl.make_block_ptr(base=V + qvk_offset, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_vn, stride_vk), + offsets=(0, 0), block_shape=(BLOCK_DMODEL, BLOCK_N), order=(0, 1)) + DO_block_ptr = tl.make_block_ptr(base=DO + qvk_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) + # pointer to row-wise quantities in value-like data + D_ptrs = D + off_hz * N_CTX + l_ptrs = L + off_hz * N_CTX + qk_scale = sm_scale * 1.44269504 + # load q and do: they will stay in SRAM throughout + q = tl.load(Q_block_ptr) + q = (q * qk_scale).to(q.dtype) + do = tl.load(DO_block_ptr) + Di = tl.load(D_ptrs + offs_m) + l_i = tl.load(l_ptrs + offs_m) + dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # loop over k, v + lo = 0 + hi = (start_m + 1) * BLOCK_M + for start_n in range(lo, hi, BLOCK_N): + # -- load k, v -- + k = tl.load(K_block_ptr) + v = tl.load(V_block_ptr) + # -- compute qk ---- + qk = tl.dot(q, k) + qk = tl.where(offs_m[:, None] >= (offs_n[None, :] + start_n), qk, float("-inf")) + p = tl.math.exp2(qk - 
l_i[:, None]) + # compute dp = dot(v, do) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] + dp += tl.dot(do, v) + # compute ds = p * (dp - delta[:, None]) + ds = p * dp + # compute dq. Unfortunately we cannot avoid transpose here as this loop + # uses k both normal and transpose. + dq += tl.dot(ds.to(Q.dtype.element_ty), tl.trans(k)) + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (0, BLOCK_N)) + # initialize pointers to output + DQ_block_ptr = tl.make_block_ptr(base=DQ + qvk_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) + tl.store(DQ_block_ptr, (dq * sm_scale).to(q.dtype)) + + +empty = torch.empty(128, device="cuda") + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False): + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-2] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + o = torch.empty_like(q) + if torch.version.hip is None: + # BLOCK_M = 128 + # BLOCK_N = 64 if Lk <= 64 else 32 + # num_stages = 4 if Lk <= 64 else 3 + # num_warps = 4 if Lk <= 64 else 8 + pass + + stage = 3 if causal else 1 + grid = lambda META: (triton.cdiv(q.shape[2], META['BLOCK_M']), q.shape[0] * q.shape[1], 1) + M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + + _attn_fwd[grid]( + q, + k, + v, + sm_scale, + M, + o, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), + q.shape[0], + q.shape[1], + N_CTX=q.shape[2], + BLOCK_DMODEL=Lk, + STAGE=stage, + ) + + ## restore the grid for bwd kernel + best_config = _attn_fwd.get_best_config() + block_m = int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1]) + grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1) + + ctx.save_for_backward(q, k, v, o, M) + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = Lk + ctx.causal = causal + ctx.split_kernel = split_kernel + return o + + @staticmethod + def backward(ctx, do): + # configuration is not supported + assert (not (ctx.split_kernel and not ctx.causal)) + if torch.version.hip is not None: + BLOCK = 64 + else: + BLOCK = 128 + q, k, v, o, L = ctx.saved_tensors + do = do.contiguous() + dq = torch.zeros_like(q, dtype=torch.float32) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + delta = torch.empty_like(L) + do_scaled = torch.empty_like(do) + # Figure out what BLOCK size fwd used and adjust num_blocks accordingly. + # If the two are the same, we don't need this but the bwd pass block size + # is smaller than the fwd so we need this scaling to ensure we loop over all + # values and don't skip some blocks. + # Alternatively we could compute a new grid but this keeps it consistent + # with fwd and easier to reason about. 
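+        # Worked example (illustrative numbers): if the autotuned fwd picked
+        # BLOCK_M = 128 for N_CTX = 4096, then ctx.grid[0] = 4096 // 128 = 32.
+        # With the bwd BLOCK = 64, block_scale = (4096 // 32) // 64 = 2, so the
+        # preprocess kernel below runs with BLOCK_M = 2 * 64 = 128 rows per
+        # program, matching the fwd tiling.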
+ block_scale = (q.shape[2] // ctx.grid[0]) // BLOCK + _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )]( + o, + do, + do_scaled, + delta, + BLOCK_M=block_scale * BLOCK, + D_HEAD=ctx.BLOCK_DMODEL, + ) + if not ctx.split_kernel: + _bwd_kernel[(ctx.grid[1], )]( + q, + k, + v, + ctx.sm_scale, + o, + do_scaled, + dq, + dk, + dv, + L, + delta, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + q.shape[0], + q.shape[1], + q.shape[2], + block_scale * ctx.grid[0], + BLOCK_M=BLOCK, + BLOCK_N=BLOCK, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, + num_warps=4, + CAUSAL=ctx.causal, + num_stages=1, + ) + else: + dq = torch.zeros_like(q) + _bwd_kernel_dk_dv[(block_scale * ctx.grid[0], ctx.grid[1])]( + q, + k, + v, + ctx.sm_scale, + o, + do_scaled, + dk, + dv, + L, + delta, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + q.shape[0], + q.shape[1], + q.shape[2], + BLOCK_M=BLOCK, + BLOCK_N=BLOCK, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, + num_warps=4, + num_stages=1, + ) + _bwd_kernel_dq[ctx.grid]( + q, + k, + v, + ctx.sm_scale, + o, + do_scaled, + dq, + L, + delta, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + q.shape[0], + q.shape[1], + q.shape[2], + BLOCK_M=2 * BLOCK, + BLOCK_N=BLOCK, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, + num_warps=4, + waves_per_eu=1, + num_stages=1, + ) + # print(h.asm["ttgir"]) + return dq, dk, dv, None, None, None + + +attention = _attention.apply + + +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ + (4, 48, 1024, 64), + (4, 48, 2048, 64), + (4, 48, 4096, 64), + (4, 48, 1024, 128), + (4, 48, 2048, 128), + (4, 48, 4096, 128), + #(4, 48, 8192, 64), + #(4, 48, 16384, 64) +]) +@pytest.mark.parametrize('causal', [False, True]) +@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) +def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): + torch.manual_seed(20) + q = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()) + k = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()) + v = (torch.empty((Z, H, D_HEAD, N_CTX), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()) + sm_scale = 0.5 + # dout = torch.randn_like(q) + # reference implementation + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + if causal: + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).to(v.dtype) + ref_out = torch.matmul(p, v.transpose(2, 3)) + # triton implementation + tri_out = attention(q, k, v, causal, sm_scale) + # compare + assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0) + + +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ + (4, 48, 1024, 64), + (4, 48, 2048, 64), + (4, 48, 4096, 64), + (1, 16, 8192, 64), +]) +@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) +def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16): + torch.manual_seed(20) + causal = True + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + 
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
+    sm_scale = 0.5
+    split_kernel = True
+    dout = torch.randn_like(q)
+    # reference implementation
+    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
+    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
+    if causal:
+        p[:, :, M == 0] = float("-inf")
+    p = torch.softmax(p.float(), dim=-1).to(v.dtype)
+    ref_out = torch.matmul(p, v)
+    ref_out.backward(dout)
+    ref_dv, v.grad = v.grad.clone(), None
+    ref_dk, k.grad = k.grad.clone(), None
+    ref_dq, q.grad = q.grad.clone(), None
+    # triton implementation
+    tri_out = attention(q, k, v, causal, sm_scale, split_kernel)
+    tri_out.backward(dout)
+    tri_dv, v.grad = v.grad.clone(), None
+    tri_dk, k.grad = k.grad.clone(), None
+    tri_dq, q.grad = q.grad.clone(), None
+    # compare
+    assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)
+    if torch.version.hip is None:
+        assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=0)
+    # The current block size for MI200 series is 64x64. This results in
+    # larger differences in float results due to rounding.
+    else:
+        assert torch.allclose(ref_dv, tri_dv, atol=5e-2, rtol=0)
+    assert torch.allclose(ref_dk, tri_dk, atol=5e-2, rtol=0)
+    assert torch.allclose(ref_dq, tri_dq, atol=5e-2, rtol=0)
+
+
+try:
+    from flash_attn.flash_attn_interface import \
+        flash_attn_qkvpacked_func as flash_attn_func
+    FLASH_VER = 2
+except BaseException:
+    try:
+        from flash_attn.flash_attn_interface import flash_attn_func
+        FLASH_VER = 1
+    except BaseException:
+        FLASH_VER = None
+HAS_FLASH = FLASH_VER is not None
+
+name_to_torch_types = {
+    'fp16': torch.float16,
+    'bf16': torch.bfloat16,
+}
+
+# vary seq length for fixed head and batch=4
+configs = []
+for mode in ['fwd']:
+    for dtype in ["fp16", "bf16"]:
+        for D_HEAD in [128, 64]:
+            for causal in [False, True]:
+                configs.append(
+                    triton.testing.Benchmark(
+                        x_names=['BATCH', 'H', 'N_CTX'], x_vals=[
+                            (16, 16, 1024),
+                            (8, 16, 2048),
+                            (4, 16, 4096),
+                            (2, 16, 8192),
+                            (1, 16, 16384),
+                            (4, 48, 1024),
+                            (4, 48, 2048),
+                            (4, 48, 4096),
+                            (4, 48, 8192),
+                            (4, 48, 16384),
+                        ], line_arg='provider', line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
+                        line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []), styles=[('red', '-'),
+                                                                                                       ('blue', '-')],
+                        ylabel='ms', plot_name=f'fused-attention-d{D_HEAD}-{mode}-causal={causal}-{dtype}',
+                        args={'D_HEAD': D_HEAD, 'dtype': dtype, 'mode': mode, 'causal': causal}))
+
+
+@triton.testing.perf_report(configs)
+def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype, device="cuda"):
+    assert mode in ['fwd', 'bwd']
+    warmup = 25
+    rep = 100
+    init_dtype = name_to_torch_types[dtype]
+    split_kernel = False
+    # Bwd pass only supports causal=True right now
+    if mode == 'bwd':
+        causal = True
+        split_kernel = True
+    if provider == "triton":
+        q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=init_dtype, device="cuda", requires_grad=True)
+        k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=init_dtype, device="cuda", requires_grad=True)
+        v = torch.randn((BATCH, H, D_HEAD, N_CTX), dtype=init_dtype, device="cuda", requires_grad=True)
+        sm_scale = 1.3
+        fn = lambda: attention(q, k, v, causal, sm_scale, split_kernel)
+        if mode == 'bwd':
+            o = fn()
+            do = torch.randn_like(o)
+            fn = lambda: o.backward(do, retain_graph=True)
+        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+    if provider == "flash":
+        qkv = torch.randn((BATCH, N_CTX, 3, H, D_HEAD), dtype=init_dtype, device=device,
+                          requires_grad=True)
+        if FLASH_VER == 1:
+            lengths = torch.full((BATCH, ), fill_value=N_CTX, device=device)
+            cu_seqlens = torch.zeros((BATCH + 1, ), device=device, dtype=torch.int32)
+            cu_seqlens[1:] = lengths.cumsum(0)
+            qkv = qkv.reshape(BATCH * N_CTX, 3, H, D_HEAD)
+            fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=causal)
+        elif FLASH_VER == 2:
+            fn = lambda: flash_attn_func(qkv, causal=causal)
+        else:
+            raise ValueError(f'unknown {FLASH_VER = }')
+        if mode == 'bwd':
+            o = fn()
+            do = torch.randn_like(o)
+            fn = lambda: o.backward(do, retain_graph=True)
+        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+    flops_per_matmul = 2. * BATCH * H * N_CTX * N_CTX * D_HEAD
+    total_flops = 2 * flops_per_matmul
+    if causal:
+        total_flops *= 0.5
+    if mode == 'bwd':
+        total_flops *= 2.5  # 2.0(bwd) + 0.5(recompute)
+    return total_flops / ms * 1e-9
+
+
+# only works on post-Ampere GPUs right now
+bench_flash_attention.run(save_path='.', print_data=True)
diff --git a/python/perf-kernels/README.md b/python/perf-kernels/README.md
new file mode 100644
index 000000000000..5bcedbf49cdd
--- /dev/null
+++ b/python/perf-kernels/README.md
@@ -0,0 +1,63 @@
+# AMD Perf Kernels
+
+This directory contains customized/tuned/experimental kernels for AMD Instinct series GPUs.
+Please make sure your Triton compiler is v2.1 or later, and is from the OpenAI Triton repository
+[here](https://github.com/openai/triton). To install Triton, please see
+[these](https://github.com/openai/triton/tree/main?tab=readme-ov-file#install-from-source) instructions.
+
+## `06-fused-attention-transV.py`
+
+This script is a copy of `tutorials/06-fused-attention.py` with the following
+two changes:
+
+- Tensor V is transposed so that the seqlen/N_CTX dimension becomes the
+fastest changing (a.k.a. leading or least strided) dimension.
+This script produces better performance than `tutorials/06-fused-attention.py`
+since it has better LDS access efficiency for tensor V.
+Note that in the future, we'll improve the LDS access efficiency for
+non-transposed tensor V, i.e. where the head dimension is the fastest changing dimension.
+- Only the fwd kernel is benchmarked.
+
+## `06-fused-attention-fwd-transV.py`
+
+This script is used to produce the best performance for the fwd kernel.
+It is a copy of `06-fused-attention-transV.py` with the following
+changes:
+
+- All bwd kernels are removed.
+- Storing `m` at the end of the fwd kernel is removed.
+- The autotuner is removed. All parameters for D=64 and D=128 are pre-tuned
+on MI250X and hard-coded.
+
+Note that this script is also used to benchmark FA performance with 2 GCDs.
+Check the [2GCD benchmark script](https://github.com/ROCmSoftwarePlatform/triton/blob/triton-mlir/scripts/amd/benchmark_flash_attention.py) for more details.
+
+## `flash-attention.py`
+
+This script contains the Flash Attention kernel with the following support:
+
+- Arbitrary Q and KV sequence lengths, and arbitrary head sizes
+- Autoregressive or "causal" masking
+- Flash Attention v2 with variable sequence lengths
+- Multi and Grouped Query attention
+- ALiBi bias
+- Matrix bias
+
+These are currently supported for the forward kernel only.
+
+## `06-attention-decode.py`
+
+This script contains the Flash Decoding kernel.
+
+## `hbm-bw-test.py`
+
+This script measures HBM bandwidth on your device.
+
+## `03-matrix-multiplication-all-types.py`
+
+This script contains the GEMM kernel that supports int8, int32, fp16,
+fp32, bf16 and f8 (both e5m2 and e4m3) datatypes.
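+
+The fp8 paths in this directory rely on a PyTorch build that exposes the AMD
+`fnuz` fp8 dtypes; the attention scripts guard on this with checks such as the
+following (shown only as an illustration):
+
+```python
+import torch
+
+# True on builds that ship the AMD fp8 types used by these kernels.
+TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2fnuz')
+TORCH_HAS_FP8E4 = hasattr(torch, 'float8_e4m3fnuz')
+print(TORCH_HAS_FP8E5, TORCH_HAS_FP8E4)
+```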
+ +## `03-matrix-multiplication-stream-k.py` + +This script contains the GEMM kernel that implements [stream-k](https://arxiv.org/abs/2301.03598) diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py new file mode 100644 index 000000000000..6fc861b281fa --- /dev/null +++ b/python/perf-kernels/flash-attention.py @@ -0,0 +1,1527 @@ +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) +Credits: OpenAI kernel team, AMD ML Frameworks Triton team + +Features supported: + +1) Fwd with causal masking +2) Any sequence lengths without padding (currently fwd kernel only) +3) Support for different sequence lengths for q and k +4) Nested tensor API currently does not support dropout or bias. + +Not currently supported: + +1) Non power of two head dims + +""" + +import argparse +import pytest +import sys +import torch + +import triton +import triton.language as tl + +torch_dtype: tl.constexpr = torch.float16 + +TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2fnuz') +if TORCH_HAS_FP8E5: + torch_dtype: tl.constexpr = torch.float8_e5m2fnuz + + +class MetaData(): + cu_seqlens_q = None + cu_seqlens_k = None + max_seqlens_q = 0 + max_seqlens_k = 0 + bias = None + alibi_slopes = None + causal = False + num_contexts = 0 + varlen = False + dropout_p, return_encoded_softmax = 0.0, False + + def __init__(self, sm_scale=1.0): + self.sm_scale = sm_scale + + def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k): + self.varlen = True + self.cu_seqlens_q = cu_seqlens_q + self.cu_seqlens_k = cu_seqlens_k + # Without "varlen", there should still be one sequence. + assert len(cu_seqlens_q) >= 2 + assert len(cu_seqlens_q) == len(cu_seqlens_k) + self.num_contexts = len(cu_seqlens_q) - 1 + for i in range(0, self.num_contexts): + self.max_seqlens_q = max(cu_seqlens_q[i + 1].item() - cu_seqlens_q[i].item(), self.max_seqlens_q) + self.max_seqlens_k = max(cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item(), self.max_seqlens_k) + + def need_bias(self, bias, batch, nheads, seqlen_q, seqlen_k): + assert bias.is_cuda + assert bias.dim() == 4 + assert bias.shape[0] == 1 + assert bias.shape[2:] == (seqlen_q, seqlen_k) + self.bias = bias + + def need_alibi(self, alibi_slopes, batch, nheads): + assert alibi_slopes.is_cuda + assert alibi_slopes.dim() == 2 + assert alibi_slopes.shape[0] == batch + assert alibi_slopes.shape[1] == nheads + self.alibi_slopes = alibi_slopes + + def need_causal(self): + self.causal = True + + def need_dropout(self, dropout_p, return_encoded_softmax): + self.dropout_p = dropout_p + self.return_encoded_softmax = return_encoded_softmax + + def check_args(self, q, k, v, o): + assert q.dim() == k.dim() and q.dim() == v.dim() + if self.varlen: + assert q.dim() == 3 + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + assert self.cu_seqlens_q is not None + assert self.cu_seqlens_k is not None + assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k) + # TODO: Remove once bias is supported with varlen + assert self.bias is None + # TODO:Remove once dropout is supported with varlen + assert self.dropout_p == 0.0 + assert not self.return_encoded_softmax + else: + assert q.dim() == 4 + batch, nheads_q, seqlen_q, head_size = q.shape + _, nheads_k, seqlen_k, _ = k.shape + assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0 + assert self.cu_seqlens_q is None and self.cu_seqlens_k is None + assert k.shape == v.shape + assert q.shape[-1] == 
k.shape[-1] and q.shape[-1] == v.shape[-1] + # TODO: Change assert if we support qkl f8 and v f16 + assert q.dtype == k.dtype and q.dtype == v.dtype + assert head_size <= 256 + assert o.shape == q.shape + assert (nheads_q % nheads_k) == 0 + + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) + + +@triton.jit +def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): + ms = tl.arange(0, m) + ns = tl.arange(0, n) + return philox_offset + ms[:, None] * stride + ns[None, :] + + +@triton.jit +def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride).to(tl.uint32) + # TODO: use tl.randint for better performance + return tl.rand(philox_seed, rng_offsets) + + +@triton.jit +def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride) + rng_keep = rng_output > dropout_p + return rng_keep + + +@triton.jit +def load_fn(block_ptr, first, second, pad): + if first and second: + tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) + elif first: + tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad) + elif second: + tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad) + else: + tensor = tl.load(block_ptr) + return tensor + + +@triton.jit +def print_gpu(prefix, val=None): + if (tl.program_id(0) == 0) and ((tl.program_id(1) == 0) and (tl.program_id(2) == 0)): + if val is not None: + tl.device_print(prefix, val) + else: + tl.device_print(prefix) + + +@triton.jit +def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False): + # when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix + # for casual mask we want something like this where (1 is kept and 0 is masked) + # seqlen_q = 2 and seqlen_k = 5 + # 1 1 1 1 0 + # 1 1 1 1 1 + # seqlen_q = 5 and seqlen_k = 2 + # 0 0 + # 0 0 + # 0 0 + # 1 0 + # 1 1 + # for alibi the diagonal is 0 indicating no penalty for attending to that spot and increasing penalty for attending further from the diagonal + # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False + # 1. offs_m[:,None] = [[0], + # [1], + # 2. offs_m[:,None] + seqlen_k = [[5], + # [6], + # 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3], + # [4], + # 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1], + # [4], [ 4, 3, 2, 1, 0]] + # 5. 
-1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1], + # [ -4, -3, -2, -1, 0]], + relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + if transpose: + return alibi_block.T + else: + return alibi_block + + +def compute_alibi_tensor(alibi_slopes, seqlen_q, seqlen_k): + q_idx = torch.arange(seqlen_q, dtype=torch.int32, device="cuda").unsqueeze(-1) # (N_CTX_Q, 1) + k_idx = torch.arange(seqlen_k, dtype=torch.int32, device="cuda").unsqueeze(0) # (1, N_CTX_K) + relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - k_idx) # (N_CTX_Q, N_CTX_K) + return -1 * alibi_slopes.unsqueeze(-1).unsqueeze(-1) * relative_pos # (Z, H, N_CTX_Q, N_CTX_K) + + +@triton.jit +def _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, actual_seqlen_k, actual_seqlen_q, dropout_p, + philox_seed, batch_philox_offset, encoded_softmax_block_ptr, block_min, block_max, offs_n_causal, + masked_blocks, n_extra_tokens, bias_ptr, alibi_slope, IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, OFFS_M: tl.constexpr, + OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, PADDED_HEAD: tl.constexpr): + # loop over k, v, and update accumulator + for start_n in range(block_min, block_max, BLOCK_N): + # For padded blocks, we will overrun the tensor size if + # we load all BLOCK_N. For others, the blocks are all within range. + k = load_fn(K_block_ptr, PADDED_HEAD, MASK_STEPS and (n_extra_tokens != 0), "zero") + if PRE_LOAD_V: + v = load_fn(V_block_ptr, MASK_STEPS and (n_extra_tokens != 0), PADDED_HEAD, "zero") + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + # We start from end of seqlen_k so only the first iteration would need + # to be checked for padding if it is not a multiple of block_n + # TODO: This can be optimized to only be true for the padded block. + if MASK_STEPS: + # If this is the last block / iteration, we want to + # mask if the sequence length is not a multiple of block size + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn. + # last step might get wasted but that is okay. check if this masking works For + # that case. + if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): + boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32) + size_n = start_n + OFFS_N[None, :] + mask = size_n < boundary_m[:, None] + qk = tl.where(mask, qk, float("-inf")) + if IS_CAUSAL: + causal_boundary = start_n + offs_n_causal + causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] + qk = tl.where(causal_mask, qk, float("-inf")) + # -- compute qk ---- + qk += tl.dot(q, k) + if bias_ptr is not None: + bias = load_fn(bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero") + # While bias is added after multiplying qk with sm_scale, + # our optimization to use 2^x instead of e^x results in an additional + # scale factor of log2(e) which we must also multiply the bias with. 
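+            # In other words, exp(x) == exp2(x * log2(e)): with the exp2-based
+            # softmax, every additive term in the exponent (q @ k * sm_scale,
+            # bias, alibi) must carry the log2(e) = 1.44269504089 factor.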
+ qk += (bias * 1.44269504089) + + if alibi_slope is not None: + # Compute the global position of each token within the sequence + global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + global_n_positions = start_n + tl.arange(0, BLOCK_N) + + alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions, + global_n_positions) + + qk += (alibi_block * 1.44269504089) # scale factor of log2(e) + + # softmax + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + + # CAVEAT: Must update l_ij before applying dropout + l_ij = tl.sum(p, 1) + if ENABLE_DROPOUT: + philox_offset = batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N + keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, actual_seqlen_k) + if RETURN_ENCODED_SOFTMAX: + tl.store(encoded_softmax_block_ptr, tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty)) + p = tl.where(keep, p, 0.0) + elif RETURN_ENCODED_SOFTMAX: + tl.store(encoded_softmax_block_ptr, p.to(encoded_softmax_block_ptr.type.element_ty)) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not PRE_LOAD_V: + v = load_fn(V_block_ptr, MASK_STEPS and (n_extra_tokens != 0), PADDED_HEAD, "zero") + # -- update m_i and l_i + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, BLOCK_N)) + return acc, l_i, m_i + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': True}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=8), + # TODO: This config fails with head_size not pow2 with data mismatches. Check why. 
+ # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + ], + key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], + use_cuda_graph=True, +) +@triton.jit +def attn_fwd( + Q, + K, + V, + bias, + sm_scale, + L, + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + stride_oz, + stride_oh, + stride_om, + stride_on, + stride_bz, + stride_bh, + stride_bm, + stride_bn, + stride_az, + stride_ah, + cu_seqlens_q, + cu_seqlens_k, + dropout_p, + philox_seed, + philox_offset_base, + encoded_softmax, + alibi_slopes, + HQ: tl.constexpr, + HK: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + MAX_SEQLENS_Q: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, + VARLEN: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + BIAS_TYPE: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, + USE_ALIBI: tl.constexpr, + BATCH_SIZE: tl.constexpr, +): + start_m = tl.program_id(0) + off_h_q = tl.program_id(1) + off_z = tl.program_id(2) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + if VARLEN: + cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) + cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) + seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start + # We have a one-size-fits-all grid in id(0). Some seqlens might be too + # small for all start_m so for those we return early. + if start_m * BLOCK_M > seqlen_q: + return + cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) + cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + else: + cu_seqlens_q_start = 0 + cu_seqlens_k_start = 0 + seqlen_q = MAX_SEQLENS_Q + seqlen_k = MAX_SEQLENS_K + + # Now we compute whether we need to exit early due to causal masking. + # This is because for seqlen_q > seqlen_k, M rows of the attn scores + # are completely masked, resulting in 0s written to the output, and + # inf written to LSE. We don't need to do any GEMMs in this case. + # This block of code determines what N is, and if this WG is operating + # on those M rows. + n_blocks = cdiv_fn(seqlen_k, BLOCK_N) + if (IS_CAUSAL): + # If seqlen_q == seqlen_k, the attn scores are a square matrix. + # If seqlen_q != seqlen_k, attn scores are rectangular which means + # the causal mask boundary is bottom right aligned, and ends at either + # the top edge (seqlen_q < seqlen_k) or left edge. + # This captures the decrease in n_blocks if we have a rectangular attn matrix + n_blocks_seqlen = cdiv_fn((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + # This is what adjusts the block_max for the current WG, only + # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks + n_blocks = min(n_blocks, n_blocks_seqlen) + # If we have no blocks after adjusting for seqlen deltas, this WG is part of + # the blocks that are all 0. We exit early. 
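+        # Worked example (illustrative numbers): seqlen_q = 512, seqlen_k = 256,
+        # BLOCK_M = BLOCK_N = 128. The causal diagonal is bottom-right aligned,
+        # so query rows 0..255 attend to no keys at all. For start_m = 0,
+        # n_blocks_seqlen = cdiv(1 * 128 + 256 - 512, 128) <= 0 and we exit
+        # early below; for start_m = 3 it is cdiv(256, 128) = 2, i.e. every key
+        # block is visited.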
+ if n_blocks <= 0: + o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh + O_block_ptr = tl.make_block_ptr(base=Out + o_offset, shape=(seqlen_q, BLOCK_DMODEL), + strides=(stride_om, stride_on), offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) + # We still need to write 0s to the result + tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0, 1)) + l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # We store inf to LSE, not -inf because in the bwd pass, we subtract this + # from qk which makes it -inf, such that exp(qk - inf) = 0 for these masked blocks. + l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) + tl.store(l_ptrs, l) + # TODO: Should dropout and return encoded softmax be handled here too? + return + + # If MQA / GQA, set the K and V head offsets appropriately. + GROUP_SIZE: tl.constexpr = HQ // HK + if GROUP_SIZE != 1: + off_h_k = off_h_q // GROUP_SIZE + else: + off_h_k = off_h_q + + # need_padding = False + n_extra_tokens = 0 + if seqlen_k < BLOCK_N: + # need_padding = True + n_extra_tokens = BLOCK_N - seqlen_k + elif seqlen_k % BLOCK_N: + # need_padding = True + n_extra_tokens = seqlen_k % BLOCK_N + PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) + + # Compute pointers for all the tensors used in this kernel. + q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm + Q_block_ptr = tl.make_block_ptr(base=Q + q_offset, shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_qm, stride_qk), offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) + k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn + K_block_ptr = tl.make_block_ptr(base=K + k_offset, shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), + strides=(stride_kk, stride_kn), offsets=(0, 0), block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1)) + v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk + V_block_ptr = tl.make_block_ptr(base=V + v_offset, shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), + strides=(stride_vk, stride_vn), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0)) + if BIAS_TYPE != 0: + b_offset = off_h_q * stride_bh # Note: this might get large enough to overflow on some configs + bias_ptr = tl.make_block_ptr( + base=bias + b_offset, + shape=(seqlen_q, seqlen_k), + strides=(stride_bm, stride_bn), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + else: + bias_ptr = None + + if USE_ALIBI: + a_offset = off_z * stride_az + off_h_q * stride_ah + alibi_slope = tl.load(alibi_slopes + a_offset) + else: + alibi_slope = None + + if ENABLE_DROPOUT: + off_hz = off_z * HQ + off_h_q + batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k + else: + batch_philox_offset = 0 + # We can ask to return the dropout mask without actually doing any dropout. In + # this case, we return an invalid pointer so indicate the mask is not valid. + # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. 
+ if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.make_block_ptr(base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, + shape=(seqlen_q, seqlen_k), strides=(seqlen_k, 1), + offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0)) + else: + encoded_softmax_block_ptr = 0 + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use 2^x in the loop as we do not + # have native e^x support in HW. + qk_scale = sm_scale * 1.44269504089 + # Q is loaded once at the beginning and shared by all N blocks. + q = load_fn(Q_block_ptr, True, PADDED_HEAD, "zero") + q = (q * qk_scale).to(Q_block_ptr.type.element_ty) + + # Here we compute how many full and masked blocks we have. + padded_block_k = n_extra_tokens != 0 + is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) + if IS_CAUSAL: + # There are always at least BLOCK_M // BLOCK_N masked blocks. + # Additionally there might be one more due to dissimilar seqlens. + masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) + else: + # Padding on Q does not need to be masked in the FA loop. + masked_blocks = padded_block_k + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional block. + # In this case we might exceed n_blocks so pick the min. + masked_blocks = min(masked_blocks, n_blocks) + n_full_blocks = n_blocks - masked_blocks + block_min = 0 + block_max = n_blocks * BLOCK_N + # Compute for full blocks. Here we set causal to false regardless of its actual + # value because there is no masking. Similarly we do not need padding. + if n_full_blocks > 0: + block_max = (n_blocks - masked_blocks) * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, seqlen_k, seqlen_q, + dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, block_max, 0, 0, 0, bias_ptr, alibi_slope, + # IS_CAUSAL, .... + False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, False, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD) + block_min = block_max + block_max = n_blocks * BLOCK_N + + tl.debug_barrier() + # Remaining blocks, if any, are full / not masked. + if (masked_blocks > 0): + if IS_CAUSAL: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) + else: + offs_n_causal = 0 + K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, n_full_blocks)) + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, seqlen_k, seqlen_q, + dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr, + block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, bias_ptr, + alibi_slope, IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, + # _, MASK_STEPS, ... 
+ PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD) + # epilogue + acc = acc / l_i[:, None] + if ENABLE_DROPOUT: + acc = acc / (1 - dropout_p) + # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, + # then we have one block with a row of all NaNs which come from computing + # softmax over a row of all -infs (-inf - inf = NaN). We check for that here + # and store 0s where there are NaNs as these rows should've been zeroed out. + end_m_idx = (start_m + 1) * BLOCK_M + start_m_idx = start_m * BLOCK_M + causal_start_idx = seqlen_q - seqlen_k + acc = acc.to(Out.type.element_ty) + if IS_CAUSAL: + if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: + out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32) + mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) + out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] + z = 0.0 + acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) + # write back LSE + l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. + # This is only true for the last M block. For others, overflow_size will be -ve + overflow_size = end_m_idx - seqlen_q + if overflow_size > 0: + boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size, dtype=tl.int32) + # This is a > check because mask being 0 blocks the store. + l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) + tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) + else: + tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + + # write back O + o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh + O_block_ptr = tl.make_block_ptr(base=Out + o_offset, shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_om, stride_on), offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) + # Need boundary check on this to make sure the padding from the + # Q and KV tensors in both dims are not part of what we store back. + # TODO: Do the boundary check optionally. 
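+    # The boundary check clips the store to (seqlen_q, ACTUAL_BLOCK_DMODEL), so
+    # rows and head-dim columns that exist only because of the BLOCK_M /
+    # BLOCK_DMODEL padding are never written to Out.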
+ tl.store(O_block_ptr, acc, boundary_check=(0, 1)) + + +@triton.jit +def _attn_bwd_preprocess( + Out, + DO, + Delta, + stride_oz, + stride_oh, + stride_om, + stride_on, + stride_doz, + stride_doh, + stride_dom, + stride_don, + seqlen_q, + head_dim, + BLOCK_M: tl.constexpr, + D_HEAD: tl.constexpr, +): + # off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + # off_n = tl.arange(0, D_HEAD) + off_m = tl.program_id(0) * BLOCK_M + off_h = tl.program_id(1) # head index + off_z = tl.program_id(2) # batch index + num_h = tl.num_programs(1) + o_offset = off_h * stride_oh + off_z * stride_oz + O_block_ptr = tl.make_block_ptr(base=Out + o_offset, shape=(seqlen_q, head_dim), strides=(stride_om, stride_on), + offsets=(off_m, 0), block_shape=(BLOCK_M, D_HEAD), order=(1, 0)) + do_offset = off_h * stride_doh + off_z * stride_doz + DO_block_ptr = tl.make_block_ptr(base=DO + do_offset, shape=(seqlen_q, head_dim), strides=(stride_dom, stride_don), + offsets=(off_m, 0), block_shape=(BLOCK_M, D_HEAD), order=(1, 0)) + # load + # o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + # do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + o = tl.load(O_block_ptr, boundary_check=(0, 1), padding_option="zero").to(tl.float32) + do = tl.load(DO_block_ptr, boundary_check=(0, 1), padding_option="zero").to(tl.float32) + # compute + delta = tl.sum(o * do, axis=1) + # write-back, shape (q.shape[0] * q.shape[1], q.shape[2]) + off_zh = off_z * num_h + off_h * 1 + # Check for OOB accesses + delta_ptrs = Delta + off_zh * seqlen_q + off_m + tl.arange(0, BLOCK_M) + overflow = off_m + BLOCK_M - seqlen_q + if overflow > 0: + boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow, dtype=tl.int32) + mask = boundary > tl.arange(0, BLOCK_M) + tl.store(delta_ptrs, delta, mask=mask) + else: + tl.store(delta_ptrs, delta) + + +@triton.jit +def _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, + # shared by Q/K/V/DO. + stride_tok, stride_d, H, N_CTX, BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + # Filled in by the wrapper. + start_n, start_m, num_steps, MASK: tl.constexpr): + offs_m = start_m + tl.arange(0, BLOCK_M1) + offs_n = start_n + tl.arange(0, BLOCK_N1) + # offs_k = tl.arange(0, BLOCK_DMODEL) + QT_block_ptr = tl.make_block_ptr(base=Q, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_d, stride_tok), + offsets=(0, start_m), block_shape=(BLOCK_DMODEL, BLOCK_M1), order=(0, 1)) + DO_block_ptr = tl.make_block_ptr(base=DO, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), + offsets=(start_m, 0), block_shape=(BLOCK_M1, BLOCK_DMODEL), order=(1, 0)) + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + curr_m = start_m + step_m = BLOCK_M1 + for blk_idx in range(num_steps): + qT = tl.load(QT_block_ptr) + # Load m before computing qk to reduce pipeline stall. + offs_m = curr_m + tl.arange(0, BLOCK_M1) + m = tl.load(M + offs_m) + kqT = tl.dot(k, qT) + if alibi_slope is not None: + alibi_block = compute_alibi_block(alibi_slope, N_CTX, N_CTX, offs_m, offs_n, True) + kqT += alibi_block * 1.44269504089 + + pT = tl.math.exp2(kqT - m[None, :]) + # Autoregressive masking. + if MASK: + mask = (offs_m[None, :] >= offs_n[:, None]) + pT = tl.where(mask, pT, 0.0) + do = tl.load(DO_block_ptr) + # Compute dV. + ppT = pT + ppT = ppT.to(tl.float16) + dv += tl.dot(ppT, do) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # Compute dP and dS. 
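+        # dP^T = V @ dO^T and dS^T = P^T * (dP^T - Di[None, :]), where
+        # Di = rowsum(dO * O) was computed in _attn_bwd_preprocess.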
+ dpT = tl.dot(v, tl.trans(do)) + dsT = pT * (dpT - Di[None, :]) + dsT = dsT.to(tl.float16) + dk += tl.dot(dsT, tl.trans(qT)) + # Increment pointers. + curr_m += step_m + QT_block_ptr = tl.advance(QT_block_ptr, (0, step_m)) + DO_block_ptr = tl.advance(DO_block_ptr, (step_m, 0)) + return dk, dv + + +@triton.jit +def _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope, + # shared by Q/K/V/DO. + stride_tok, stride_d, H, N_CTX, BLOCK_M2: tl.constexpr, BLOCK_N2: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + # Filled in by the wrapper. + start_m, start_n, num_steps, MASK: tl.constexpr): + offs_m = start_m + tl.arange(0, BLOCK_M2) + offs_n = start_n + tl.arange(0, BLOCK_N2) + # offs_k = tl.arange(0, BLOCK_DMODEL) + KT_block_ptr = tl.make_block_ptr(base=K, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_d, stride_tok), + offsets=(0, start_n), block_shape=(BLOCK_DMODEL, BLOCK_N2), order=(0, 1)) + VT_block_ptr = tl.make_block_ptr(base=V, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_d, stride_tok), + offsets=(0, start_n), block_shape=(BLOCK_DMODEL, BLOCK_N2), order=(0, 1)) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + curr_n = start_n + step_n = BLOCK_N2 + for blk_idx in range(num_steps): + kT = tl.load(KT_block_ptr) + qk = tl.dot(q, kT) + if alibi_slope is not None: + alibi_block = compute_alibi_block(alibi_slope, N_CTX, N_CTX, offs_m, offs_n) + qk += alibi_block * 1.44269504089 + + p = tl.math.exp2(qk - m) + # Autoregressive masking. + if MASK: + offs_n = curr_n + tl.arange(0, BLOCK_N2) + mask = (offs_m[:, None] >= offs_n[None, :]) + p = tl.where(mask, p, 0.0) + # Compute dP and dS. + vT = tl.load(VT_block_ptr) + dp = tl.dot(do, vT).to(tl.float32) + ds = p * (dp - Di[:, None]) + ds = ds.to(tl.float16) + # Compute dQ.0. + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. + dq += tl.dot(ds, tl.trans(kT)) + # Increment pointers. + curr_n += step_n + KT_block_ptr = tl.advance(KT_block_ptr, (0, step_n)) + VT_block_ptr = tl.advance(VT_block_ptr, (0, step_n)) + return dq + + +@triton.jit +def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D, + # shared by Q/K/V/DO. + stride_z, stride_h, stride_tok, stride_d, + # H = 16, N_CTX = 1024 + H, N_CTX, BLOCK_DMODEL: tl.constexpr, BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr, + BLOCK_M2: tl.constexpr, BLOCK_N2: tl.constexpr, BLK_SLICE_FACTOR: tl.constexpr, USE_ALIBI: tl.constexpr): + LN2: tl.constexpr = 0.6931471824645996 # = ln(2) + + bhid = tl.program_id(2) + off_chz = (bhid * N_CTX).to(tl.int64) + adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) + pid = tl.program_id(0) + + # offset pointers for batch/head + Q += adj + K += adj + V += adj + DO += adj + DQ += adj + DK += adj + DV += adj + M += off_chz + D += off_chz + + # offs_k = tl.arange(0, BLOCK_DMODEL) + + start_n = pid * BLOCK_N1 + # This assignment is important. It is what allows us to pick the diagonal + # blocks. Later, when we want to do the lower triangular, we update start_m + # after the first dkdv call. 
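+    # With start_m == start_n, the masked dk/dv pass below covers exactly the
+    # BLOCK_N1-wide diagonal band in MASK_BLOCK_M1-sized steps.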
+ start_m = start_n + + MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR + # offs_n = start_n + tl.arange(0, BLOCK_N1) + + dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + + K_block_ptr = tl.make_block_ptr( + base=K, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_n, 0), + block_shape=(BLOCK_N1, BLOCK_DMODEL), + order=(1, 0), + ) + V_block_ptr = tl.make_block_ptr( + base=V, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_n, 0), + block_shape=(BLOCK_N1, BLOCK_DMODEL), + order=(1, 0), + ) + + # load K and V: they stay in SRAM throughout the inner loop for dkdv. + k = tl.load(K_block_ptr) + v = tl.load(V_block_ptr) + + if USE_ALIBI: + a_offset = bhid + alibi_slope = tl.load(alibi_slopes + a_offset) + else: + alibi_slope = None + + # compute dK and dV for blocks close to the diagonal that need to be masked + num_steps = BLOCK_N1 // MASK_BLOCK_M1 + dk, dv = _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, stride_tok, stride_d, H, N_CTX, + MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, start_n, start_m, num_steps, MASK=True) + + # compute dK and dV for blocks that don't need masking further from the diagonal + start_m += num_steps * MASK_BLOCK_M1 + num_steps = (N_CTX - start_m) // BLOCK_M1 + + dk, dv = _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, stride_tok, stride_d, H, N_CTX, + BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, start_n, start_m, num_steps, MASK=False) + + DV_block_ptrs = tl.make_block_ptr(base=DV, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), + offsets=(start_n, 0), block_shape=(BLOCK_N1, BLOCK_DMODEL), order=(1, 0)) + tl.store(DV_block_ptrs, dv.to(v.dtype)) + + # Write back dK. + dk *= sm_scale + DK_block_ptrs = tl.make_block_ptr(base=DK, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), + offsets=(start_n, 0), block_shape=(BLOCK_N1, BLOCK_DMODEL), order=(1, 0)) + tl.store(DK_block_ptrs, dk.to(k.dtype)) + + # THIS BLOCK DOES DQ: + start_m = pid * BLOCK_M2 + end_n = start_m + BLOCK_M2 + + MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR + offs_m = start_m + tl.arange(0, BLOCK_M2) + + Q_block_ptr = tl.make_block_ptr(base=Q, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), + offsets=(start_m, 0), block_shape=(BLOCK_M2, BLOCK_DMODEL), order=(1, 0)) + + DO_block_ptr = tl.make_block_ptr(base=DO, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), + offsets=(start_m, 0), block_shape=(BLOCK_M2, BLOCK_DMODEL), order=(1, 0)) + q = tl.load(Q_block_ptr) + do = tl.load(DO_block_ptr) + dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32) + + m = tl.load(M + offs_m) + m = m[:, None] + + # Compute dQ for masked (diagonal) blocks. + # NOTE: This code scans each row of QK^T backward (from right to left, + # but inside each call to _attn_bwd_dq, from left to right), but that's + # not due to anything important. I just wanted to reuse the loop + # structure for dK & dV above as much as possible. 
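+    # Stage 1 handles the MASK_BLOCK_N2-wide key blocks touching the diagonal
+    # with masking enabled; stage 2 then sweeps the fully unmasked blocks to
+    # their left with the full BLOCK_N2 step.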
+ num_steps = BLOCK_M2 // MASK_BLOCK_N2 + dq = _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope, stride_tok, stride_d, H, N_CTX, BLOCK_M2, MASK_BLOCK_N2, + BLOCK_DMODEL, start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, MASK=True) + end_n -= num_steps * MASK_BLOCK_N2 + # stage 2 + num_steps = end_n // BLOCK_N2 + dq = _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope, stride_tok, stride_d, H, N_CTX, BLOCK_M2, BLOCK_N2, + BLOCK_DMODEL, start_m, end_n - num_steps * BLOCK_N2, num_steps, MASK=False) + # Write back dQ. + DQ_block_ptr = tl.make_block_ptr(base=DQ, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), + offsets=(start_m, 0), block_shape=(BLOCK_M2, BLOCK_DMODEL), order=(1, 0)) + dq *= LN2 + tl.store(DQ_block_ptr, dq.to(q.dtype)) + + +empty = torch.empty(128, device="cuda") + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, o, metadata): + # NOTE: a large bias tensor leads to overflow during pointer arithmetic + if (metadata.bias is not None): + assert (metadata.bias.numel() < 2**31) + + if o is None: + o = torch.empty_like(q, dtype=v.dtype) + metadata.check_args(q, k, v, o) + if metadata.varlen: + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + batch = metadata.num_contexts + q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) + k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) + v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) + o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) + else: + batch, nheads_q, seqlen_q, head_size = q.shape + _, nheads_k, seqlen_k, _ = k.shape + q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3)) + k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3)) + v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3)) + o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3)) + + # Get closest power of 2 over or equal to 32. + padded_d_model = 1 << (head_size - 1).bit_length() + padded_d_model = max(padded_d_model, 16) + + grid = lambda META: (triton.cdiv(metadata.max_seqlens_q, META['BLOCK_M']), nheads_q, batch) + + # encoded_softmax is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out + # to give a consistent starting point and then populate it with the output of softmax with the sign bit set according + # to the dropout mask. The resulting return allows this mask to be fed into the reference implementation for testing + # only. This return holds no useful output aside from debugging. + if metadata.return_encoded_softmax: + encoded_softmax = torch.zeros((q.shape[0], q.shape[1], q.shape[2], k.shape[2]), device=q.device, + dtype=torch.float32) + else: + encoded_softmax = None + + M = torch.empty((batch, nheads_q, metadata.max_seqlens_q), device=q.device, dtype=torch.float32) + + # Seed the RNG so we get reproducible results for testing. 
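+        # Fixed Philox constants: the same seed/offset are passed to the kernel
+        # and stashed on ctx so the dropout pattern can be reproduced later.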
+ philox_seed = 0x1BF52 + philox_offset = 0x1D4B42 + + if metadata.bias is not None: + bias_strides = (metadata.bias.stride(0), metadata.bias.stride(1), metadata.bias.stride(2), + metadata.bias.stride(3)) + else: + bias_strides = (0, 0, 0, 0) + + if metadata.alibi_slopes is not None: + alibi_strides = (metadata.alibi_slopes.stride(0), metadata.alibi_slopes.stride(1)) + else: + alibi_strides = (0, 0) + + attn_fwd[grid](q, k, v, metadata.bias, metadata.sm_scale, M, o, *q_strides, *k_strides, *v_strides, *o_strides, + *bias_strides, *alibi_strides, metadata.cu_seqlens_q, metadata.cu_seqlens_k, + dropout_p=metadata.dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, + encoded_softmax=encoded_softmax, alibi_slopes=metadata.alibi_slopes, HQ=nheads_q, HK=nheads_k, + ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=metadata.max_seqlens_q, + MAX_SEQLENS_K=metadata.max_seqlens_k, IS_CAUSAL=metadata.causal, VARLEN=metadata.varlen, + BLOCK_DMODEL=padded_d_model, BIAS_TYPE=0 if metadata.bias is None else 1, + USE_ALIBI=False if metadata.alibi_slopes is None else True, ENABLE_DROPOUT=metadata.dropout_p + > 0.0, RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax, BATCH_SIZE=q.shape[0]) + + ctx.save_for_backward(q, k, v, o, M) + ctx.grid = grid + ctx.sm_scale = metadata.sm_scale + ctx.BLOCK_DMODEL = head_size + ctx.causal = metadata.causal + ctx.alibi_slopes = metadata.alibi_slopes + ctx.dropout_p = metadata.dropout_p + ctx.philox_seed = philox_seed + ctx.philox_offset = philox_offset + ctx.encoded_softmax = encoded_softmax + ctx.return_encoded_softmax = metadata.return_encoded_softmax + return o, encoded_softmax + + @staticmethod + def backward(ctx, do, _): + if torch.version.hip is not None: + BLOCK = 64 + else: + BLOCK = 128 + q, k, v, o, M = ctx.saved_tensors + assert do.is_contiguous() + assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() + seqlen_q = q.shape[2] + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + BATCH, N_HEAD, N_CTX = q.shape[:3] + PRE_BLOCK = 128 + # NUM_WARPS, NUM_STAGES = 4, 1 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32 + BLK_SLICE_FACTOR = 2 + RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) + arg_k = k + arg_k = arg_k * (ctx.sm_scale * RCP_LN2) + assert N_CTX % PRE_BLOCK == 0 + delta = torch.empty_like(M) + _, Lk, _ = q.shape[-1], k.shape[-1], v.shape[-1] + # padded_head = (Lk != ctx.BLOCK_DMODEL) + grid_preprocess = (triton.cdiv(do.shape[2], BLOCK), do.shape[1], do.shape[0]) + _attn_bwd_preprocess[grid_preprocess]( + o, + do, + delta, + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), + do.stride(0), + do.stride(1), + do.stride(2), + do.stride(3), + seqlen_q, + head_dim=Lk, + BLOCK_M=BLOCK, + D_HEAD=ctx.BLOCK_DMODEL, + ) + grid = lambda META: (triton.cdiv(N_CTX, META['BLOCK_N1']), 1, BATCH * N_HEAD) + _attn_bwd[grid]( + q, + arg_k, + v, + ctx.sm_scale, + ctx.alibi_slopes, + do, + dq, + dk, + dv, + M, + delta, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + N_HEAD, + N_CTX, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, + BLOCK_M1=BLOCK_M1, + BLOCK_N1=BLOCK_N1, + BLOCK_M2=BLOCK_M2, + BLOCK_N2=BLOCK_N2, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + USE_ALIBI=False if ctx.alibi_slopes is None else True, + ) + + return dq, dk, dv, None, None + + +attention = _attention.apply + + +def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype): + torch.manual_seed(20) + + # Initialize q, k, v + q = torch.randn((Z, HQ, N_CTX_Q, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + k = 
torch.randn((Z, HK, N_CTX_K, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((Z, HK, N_CTX_K, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + sm_scale = D_HEAD**-0.5 + input_metadata = MetaData(sm_scale=sm_scale) + input_metadata.max_seqlens_q = N_CTX_Q + input_metadata.max_seqlens_k = N_CTX_K + return q, k, v, input_metadata + + +def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype): + torch.manual_seed(20) + + # Random sequence lengths. Using N_CTX as kind of max of sum of individual seqs + max_seqlens_q = N_CTX_Q // Z + max_seqlens_k = N_CTX_K // Z + seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z, ), dtype=torch.int32) + seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z, ), dtype=torch.int32) + max_seqlens_q = torch.max(seqlens_q).item() + max_seqlens_k = torch.max(seqlens_k).item() + + # Calculate cumulative sequence lengths + cu_seqlens_q = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_q.cumsum(dim=0, dtype=torch.int32)]) + cu_seqlens_k = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_k.cumsum(dim=0, dtype=torch.int32)]) + cu_seqlens_q = cu_seqlens_q.to(device="cuda") + cu_seqlens_k = cu_seqlens_k.to(device="cuda") + # -1 because the last entry of cu_seqlens_q specifies the end of the last seq + # num_ctxs = len(cu_seqlens_q) - 1 + + # Initialize q, k, v with variable lengths + total_q = cu_seqlens_q[-1].item() + total_k = cu_seqlens_k[-1].item() + q = torch.randn((total_q, HQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + k = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + v = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + sm_scale = D_HEAD**-0.5 + input_metadata = MetaData(sm_scale=sm_scale) + input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) + return q, k, v, input_metadata + + +@pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [ + (4, 48, 24, 1024, 1024, 64), + (1, 24, 6, 8192, 8192, 64), + (1, 4, 2, 16384, 16384, 128), + (2, 16, 4, 1020, 987, 128), + (2, 16, 4, 15498, 2, 128), + (2, 16, 2, 7, 16219, 64), + (4, 48, 12, 1, 1, 64), + (4, 48, 48, 1, 1, 128), + (4, 48, 24, 3, 3, 128), + (4, 48, 48, 1001, 990, 64), + (1, 8, 8, 8081, 7099, 64), + (1, 4, 4, 16330, 15989, 128), + (4, 4, 1, 1024, 1024, 33), + (4, 4, 2, 65, 1018, 65), + (4, 4, 4, 128, 128, 65), + (4, 4, 4, 113, 123, 1), +]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('use_alibi', [True, False]) +def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, dtype=torch.float16): + torch.manual_seed(20) + q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype) + if causal: + input_metadata.need_causal() + + if use_alibi: + # for n heads the set of slopes is the geometric sequence that starts 2^(-8/n) + alibi_slopes = torch.tensor([2**(-8 / HQ * i) for i in range(1, HQ + 1)], dtype=torch.float32, + device="cuda").repeat(Z, 1) + input_metadata.need_alibi(alibi_slopes, Z, HQ) + else: + alibi_slopes = None + + if TORCH_HAS_FP8E5: + q = q.to(torch_dtype) + k = k.to(torch_dtype) + o = torch.empty_like(q) + + # triton implementation + tri_out, _ = attention(q, k, v, o, input_metadata) + + # Replicate K and V if using MQA/GQA + if HQ != HK: + k = k.view(k.shape[0], k.shape[1], -1, k.shape[2], + k.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(k.shape[0], -1, k.shape[2], k.shape[3]) + v = v.view(v.shape[0], 
v.shape[1], -1, v.shape[2], + v.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(v.shape[0], -1, v.shape[2], v.shape[3]) + + scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * input_metadata.sm_scale + if causal: + mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), diagonal=N_CTX_K - N_CTX_Q) + scores[:, :, mask == 0] = float("-inf") + if use_alibi: + scores += compute_alibi_tensor(alibi_slopes, N_CTX_Q, N_CTX_K) + + p = torch.softmax(scores, dim=-1) + if causal: + # If N_CTX_Q > N_CTX_K, there is at least one row of all -infs going into + # the softmax. This produces a row of NaNs as -inf - -inf == NaN. So we fix + # this by converting the NaNs to 0s, which is what they should be out of the softmax. + nan_mask = torch.isnan(p) + p[nan_mask == 1] = 0 + ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v) + # compare + torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) + + +@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ + (4, 48, 1024, 1024, 64), + (4, 24, 8192, 8192, 64), + (2, 4, 16384, 16384, 128), + (2, 16, 1020, 987, 128), + (2, 16, 15498, 2, 128), + (2, 16, 7, 16219, 64), + (4, 48, 1, 1, 64), + (4, 48, 1, 1, 128), + (4, 48, 3, 3, 128), + (4, 48, 1001, 990, 64), + (1, 8, 8081, 7099, 64), + (1, 8, 16330, 15989, 128), + (4, 4, 1024, 1024, 33), + (4, 4, 65, 1019, 65), + (4, 4, 128, 128, 65), + (4, 4, 113, 123, 1), +]) +@pytest.mark.parametrize('causal', [False, True]) +@pytest.mark.parametrize('use_bias', [True]) +def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=torch.float16): + pytest.skip() + torch.manual_seed(20) + sm_scale = D_HEAD**-0.5 + input_metadata = MetaData(sm_scale=sm_scale) + input_metadata.max_seqlens_q = N_CTX_Q + input_metadata.max_seqlens_k = N_CTX_K + if causal: + input_metadata.need_causal() + if use_bias: + bias = torch.randn((1, H, N_CTX_Q, N_CTX_K), dtype=torch.float32, device="cuda") + input_metadata.need_bias(bias, Z, H, N_CTX_Q, N_CTX_K) + else: + bias = None + q = torch.randn((Z, H, N_CTX_Q, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + k = torch.randn((Z, H, N_CTX_K, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + v = torch.randn((Z, H, N_CTX_K, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + if TORCH_HAS_FP8E5: + q = q.to(torch_dtype) + k = k.to(torch_dtype) + o = torch.empty_like(q) + + # triton implementation + tri_out, _ = attention(q, k, v, o, input_metadata) + # reference implementation:171 + + scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * sm_scale + if causal: + mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), diagonal=N_CTX_K - N_CTX_Q) + scores[:, :, mask == 0] = float("-inf") + if use_bias: + scores += input_metadata.bias + p = torch.softmax(scores, dim=-1) + if causal: + # If N_CTX_Q > N_CTX_K, there is at least one row of all -infs going into + # the softmax. This produces a row of NaNs as -inf - -inf == NaN. So we fix + # this by converting the NaNs to 0s, which is what they should be out of the softmax. 
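+        # These rows correspond to query positions that see no keys at all; the
+        # kernel epilogue writes zeros for them too, so zeroing the reference
+        # keeps the two outputs comparable.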
+ nan_mask = torch.isnan(p) + p[nan_mask == 1] = 0 + ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v) + # compare + torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) + + +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 8192, 64), (4, 48, 256, 64), (4, 48, 512, 64), + (4, 48, 1024, 64), (8, 48, 4096, 64), (4, 48, 8192, 64), + (4, 48, 128, 128), (4, 48, 4096, 128), (4, 48, 16384, 128), + (4, 16, 1024, 128), (4, 16, 8192, 128), (32, 48, 8192, 128)]) +@pytest.mark.parametrize('causal', [True, False]) +def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): + pytest.skip() + + q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, D_HEAD, dtype) + tri_out = torch.empty_like(q) + ref_out = torch.empty_like(q) + + for i in range(0, input_metadata.num_contexts): + start_q, start_k = input_metadata.cu_seqlens_q[i], input_metadata.cu_seqlens_k[i] + end_q, end_k = input_metadata.cu_seqlens_q[i + 1], input_metadata.cu_seqlens_k[i + 1] + scores = torch.einsum('qhd,khd->qhk', q[start_q:end_q], k[start_k:end_k]).float() + p = torch.softmax(scores * input_metadata.sm_scale, dim=-1).half() + ref_out[start_q:end_q] = torch.einsum('qhk,khd->qhd', p, v[start_k:end_k]) + attention(q, k, v, tri_out, input_metadata) + torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize('Z, HQ, HK, N_CTX, D_HEAD', [(2, 48, 24, 128, 64), (4, 48, 12, 256, 64), (4, 48, 4, 512, 64), + (4, 48, 2, 1024, 64), (8, 48, 6, 4096, 64), (4, 48, 8, 16384, 64), + (4, 64, 16, 128, 128), (4, 64, 4, 4096, 128), + (4, 64, 8, 16384, 128), (4, 16, 4, 1024, 128), + (4, 16, 2, 8192, 128), (32, 128, 32, 8192, 128)]) +@pytest.mark.parametrize('causal', [False]) +def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16): + q, k, v, input_metadata = varlen_input_helper(Z, HQ, HK, N_CTX, N_CTX, D_HEAD, dtype) + ref_out = torch.empty_like(q) + tri_out = torch.empty_like(q) + # Make KV look like HQ/HK "groups" of HK. Later, we will reshape so the + # size aligns with Q. 
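+    # (total_k, HK, D) -> (total_k, HK, HQ // HK, D); the per-context reshape
+    # below then flattens the group dim so each of the HQ query heads lines up
+    # with its shared KV head.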
+ k_ref = k.view(k.shape[0], k.shape[1], 1, k.shape[2]).expand(-1, -1, HQ // HK, -1) + v_ref = v.view(v.shape[0], v.shape[1], 1, v.shape[2]).expand(-1, -1, HQ // HK, -1) + for i in range(0, input_metadata.num_contexts): + start_q, start_k = input_metadata.cu_seqlens_q[i], input_metadata.cu_seqlens_k[i] + end_q, end_k = input_metadata.cu_seqlens_q[i + 1], input_metadata.cu_seqlens_k[i + 1] + k_curr = k_ref[start_k:end_k] + k_curr = k_curr.reshape(k_curr.shape[0], -1, k_curr.shape[3]) + v_curr = v_ref[start_k:end_k] + v_curr = v_curr.reshape(v_curr.shape[0], -1, v_curr.shape[3]) + scores = torch.einsum('qhd,khd->qhk', q[start_q:end_q], k_curr).float() + p = torch.softmax(scores * input_metadata.sm_scale, dim=-1).half() + ref_out[start_q:end_q] = torch.einsum('qhk,khd->qhd', p, v_curr) + attention(q, k, v, tri_out, input_metadata) + torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ + (4, 48, 1024, 64), + (4, 48, 2048, 64), + (2, 48, 4096, 64), + (1, 16, 1024, 64), + (1, 16, 1024, 128), + #(1, 16, 8192, 63), + #(1, 16, 1022, 64), +]) +@pytest.mark.parametrize('qseqlen_not_equal_kseqlen', [None]) +@pytest.mark.parametrize('torch_sdpa_test', [False, True]) +@pytest.mark.parametrize('causal', [True]) +@pytest.mark.parametrize('use_alibi', [False, True]) +def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sdpa_test, use_alibi, + dtype=torch.float16): + torch.manual_seed(20) + if qseqlen_not_equal_kseqlen is not None: + seqlen_q = qseqlen_not_equal_kseqlen + else: + seqlen_q = N_CTX + seqlen_k = N_CTX + + if causal and ((N_CTX - 1) & N_CTX): + pytest.skip() + if causal and seqlen_q != seqlen_k: + pytest.skip() + + sm_scale = D_HEAD**-0.5 + input_metadata = MetaData(sm_scale=sm_scale) + input_metadata.max_seqlens_q = seqlen_q + input_metadata.max_seqlens_k = seqlen_k + + dropout_p = 0 + q = (torch.empty((Z, H, seqlen_q, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + k = (torch.empty((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + v = (torch.empty((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + o = torch.empty_like(q) + + if causal: + input_metadata.need_causal() + + if use_alibi and not torch_sdpa_test: + # for n heads the set of slopes is the geometric sequence that starts 2^(-8/n) + alibi_slopes = torch.tensor([2**(-8 / H * i) for i in range(1, H + 1)], dtype=torch.float32, + device="cuda").repeat(Z, 1) + input_metadata.need_alibi(alibi_slopes, Z, H) + dout = torch.randn_like(q) + # reference implementation + if torch_sdpa_test: + ref_out, ref_softmax = torch.ops.aten._scaled_dot_product_attention_math(q, k, v, dropout_p=dropout_p, + is_causal=causal, scale=sm_scale, + dropout_mask=None) + ref_out.backward(dout.to(device=ref_out.device, dtype=ref_out.dtype)) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + else: + M = torch.tril(torch.ones((seqlen_q, seqlen_k), device="cuda")) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + if use_alibi: + p += compute_alibi_tensor(alibi_slopes, N_CTX, N_CTX) + if causal: + p[:, :, M == 0] = float("-inf") + + p = torch.softmax(p.float(), dim=-1).type(dtype=p.dtype) + ref_out = torch.matmul(p, v) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + + # # 
triton implementation + tri_out, _ = attention(q, k, v, o, input_metadata) + tri_out.backward(dout) + tri_dv, v.grad = v.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dq, q.grad = q.grad.clone(), None + # test + #print("reference") + #print(ref_dv) + #print("tri") + #print(tri_dv) + # compare + torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0) + # The current block size for MI200 series is 64x64. This results in + # larger differences in float results due to rounding. + + if dtype == torch.bfloat16: + ATOL = 1e-1 * max(1.0, (seqlen_q + D_HEAD) / 64.0) + if dtype == torch.float32: + ATOL = 1e-3 * max(1.0, (seqlen_q + D_HEAD) / 64.0) + else: + ATOL = 1e-1 * max(1.0, (seqlen_q + D_HEAD) / 64.0) + + RTOL = 0 + + torch.testing.assert_close(ref_dv, tri_dv, atol=ATOL, rtol=RTOL) + torch.testing.assert_close(ref_dk, tri_dk, atol=ATOL, rtol=RTOL) + torch.testing.assert_close(ref_dq, tri_dq, atol=ATOL, rtol=RTOL) + + +def nonvarlen_benchmark_configs(): + configs = [ + (16, 16, 16, 1024, 1024), + (8, 16, 16, 2048, 2048), + (4, 16, 16, 4096, 4096), + (2, 16, 16, 8192, 8192), + (1, 16, 16, 16384, 16384), + (2, 48, 48, 1024, 1024), + (2, 48, 48, 2048, 1024), + (2, 48, 48, 4096, 8192), + (2, 48, 48, 8192, 4096), + (2, 48, 48, 16384, 8192), + (8, 16, 16, 1989, 15344), + (4, 16, 16, 4097, 163), + (2, 16, 16, 8122, 2159), + (1, 16, 16, 16281, 7), + (2, 48, 48, 1021, 1020), + (2, 48, 48, 2001, 2048), + (2, 48, 48, 3996, 9639), + (2, 48, 48, 8181, 1021), + ] + return configs + + +def varlen_benchmark_configs(): + configs = [ + (2, 16, 4, 1024, 1024), + (8, 16, 2, 2048, 2048), + (4, 16, 8, 4096, 4096), + (2, 16, 4, 8192, 8192), + (2, 16, 8, 16384, 16384), + (2, 48, 12, 1024, 1024), + (2, 48, 24, 2048, 2048), + (2, 48, 8, 4096, 4096), + (2, 48, 4, 8192, 8192), + (2, 48, 2, 16384, 16384), + (2, 64, 32, 1024, 1024), + (4, 64, 16, 2048, 2048), + (4, 64, 8, 4096, 4096), + (4, 64, 32, 8192, 8192), + (4, 128, 16, 16384, 16384), + ] + return configs + + +def run_benchmark(custom): + + args = parse_args() + dtype = arg_to_torch_dtype[args.dtype] + # hk = args.hq if not args.hk else args.hk + # sk = args.sq if not args.sk else args.sk + head_size = 128 if not args.d else args.d + mode = 'fwd' + x_names = ['BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K'] + causal = args.causal + varlen = args.varlen + configs = [] + if custom: + x_vals_list = [(args.b, args.hq, args.hk, args.sq, args.sk)] + else: + if varlen: + x_vals_list = varlen_benchmark_configs() + else: + x_vals_list = nonvarlen_benchmark_configs() + print_time = args.return_time + line_names = 'Time (ms)' if print_time else 'TFLOPS' + configs.append( + triton.testing.Benchmark(x_names=x_names, x_vals=x_vals_list, line_arg='provider', line_vals=['triton'], + line_names=[line_names], styles=[('red', '-')], ylabel='ms', + plot_name=f'fused-attention-{mode}-d{head_size}{"-varlen" if varlen else ""}', + args={'D_HEAD': head_size, 'dtype': dtype, 'causal': causal, 'mode': mode})) + + @triton.testing.perf_report(configs) + def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal, mode, provider, device="cuda"): + assert mode in ["fwd", "bwd"] + warmup = 25 + rep = 100 + # TODO: Enable bias after testing. 
+ # if use_bias: + # bias = torch.randn((1, H, N_CTX, N_CTX), dtype=torch.float32, device="cuda") + # input_metadata.need_bias(bias, BATCH, H, N_CTX, N_CTX) + # else: + # bias = None + # bias = None + + # Bwd pass only supports causal=True right now + if mode == 'bwd': + causal = True + + flops_per_matmul = 0 + if varlen: + q, k, v, input_metadata = varlen_input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype) + for i in range(0, input_metadata.num_contexts): + seqlen_q = input_metadata.cu_seqlens_q[i + 1] - input_metadata.cu_seqlens_q[i] + seqlen_k = input_metadata.cu_seqlens_k[i + 1] - input_metadata.cu_seqlens_k[i] + # x2 for 2 GEMMs + flops_per_matmul += seqlen_q.item() * seqlen_k.item() * HQ * D_HEAD * 2 + else: + q, k, v, input_metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype) + flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD + if causal: + input_metadata.need_causal() + o = torch.empty_like(q) + fn = lambda: attention(q, k, v, o, input_metadata) + if mode == 'bwd': + o, _ = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + total_flops = 2 * flops_per_matmul + # TODO: This needs to be fixed for unequal Q/K seqlens + if causal: + total_flops *= 0.5 + if mode == "bwd": + total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute) + if print_time: + return ms + else: + return total_flops / ms * 1e-9 + + bench_flash_attention.run(save_path=".", print_data=True) + + +def parse_args(): + parser = argparse.ArgumentParser( + prog="Benchmark FlashAttention", + allow_abbrev=False, + ) + parser.add_argument("-b", type=int, default=0) + parser.add_argument("-hq", type=int, default=0) + parser.add_argument("-hk", type=int, default=0) + parser.add_argument("-sq", type=int, default=0) + parser.add_argument("-sk", type=int, default=0) + parser.add_argument("-d", type=int, default=0) + parser.add_argument("-causal", action='store_true', default=False) + parser.add_argument("-varlen", action='store_true', default=False) + parser.add_argument("-dtype", default='fp16') + parser.add_argument("-return_time", action='store_true', default=False) + return parser.parse_args() + + +arg_to_torch_dtype = {'fp16': torch.float16, 'bf16': torch.bfloat16, 'fp32': torch.float32} + + +def main(): + args = parse_args() + custom_config = False + if args.b or args.hq or args.hk or args.sq or args.sk or args.d: + custom_config = True + assert args.b and args.hq and args.sq and args.d, \ + "If custom config is specified, please provide \ + all of batch, number of Q heads, Q sequence length \ + and head size." + + assert args.dtype in arg_to_torch_dtype, \ + "Only fp16, bf16 and f32 types currently supported." + + run_benchmark(custom_config) + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/python/perf-kernels/hbm-bw-test.py b/python/perf-kernels/hbm-bw-test.py new file mode 100644 index 000000000000..a20ce044eaee --- /dev/null +++ b/python/perf-kernels/hbm-bw-test.py @@ -0,0 +1,200 @@ +""" +Simple test to measure achieved HBM bandwidth. +This kernel moves N bytes of data from one region in HBM to another, using Triton. +""" + +# %% +# Compute Kernel +# -------------- + +import argparse +import sys +import torch + +import triton +import triton.language as tl + + +@triton.jit +def copy_kernel( + input_ptr, # *Pointer* to input vector. + output_ptr, # *Pointer* to output vector. + NUM_ELEMENTS: tl.constexpr, # Total elements to move. 
+ BLOCK_SIZE: tl.constexpr, # Elements to load / store per iteration + VECTOR_SIZE: tl.constexpr, # Size of the entire vector being moved. + READ_ONLY: tl.constexpr, +): + pid = tl.program_id(axis=0) + # Offset at which to start for this WG. + lo = pid * NUM_ELEMENTS + # Offset until which to read for this WG. + hi = lo + NUM_ELEMENTS + # NUM_ITERS: tl.constexpr = triton.cdiv(NUM_ELEMENTS, BLOCK_SIZE) + IRREGULAR_SIZE: tl.constexpr = NUM_ELEMENTS % BLOCK_SIZE + acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + if IRREGULAR_SIZE: + hi = hi - IRREGULAR_SIZE + # Move buffer in chunks of block_size + for idx in range(lo, hi, BLOCK_SIZE): + offsets = idx + tl.arange(0, BLOCK_SIZE) + in_vals = tl.load(input_ptr + offsets) + acc += in_vals + if not READ_ONLY: + tl.store(output_ptr + offsets, in_vals) + # Unroll last irregular iter in case the total sized moved by this WG + # is not a multiple of block size. + if IRREGULAR_SIZE: + lo = hi + hi = hi + IRREGULAR_SIZE + offsets = lo + tl.arange(0, BLOCK_SIZE) + mask = offsets < hi + in_vals = tl.load(input_ptr + offsets, mask=mask) + if not READ_ONLY: + tl.store(output_ptr + offsets, in_vals, mask=mask) + + if READ_ONLY: + tl.store(output_ptr + tl.arange(0, BLOCK_SIZE), acc) + + +def copy(src: torch.Tensor, block_size, wgs, dst: torch.Tensor): + assert src.is_cuda + vector_size = src.numel() + assert dst.numel() == vector_size or dst.numel() == block_size + size_per_wg = vector_size / wgs + assert size_per_wg >= block_size, \ + "Too many WGS. Please increase the size of the buffer using -size." \ + f" We want a buffer of size {wgs * block_size} f32 elements or larger." + grid = (wgs, 1, 1) + # Each WG will move these many elements + n_elements = triton.cdiv(vector_size, wgs) + # If we want to read only, we do a dummy write of a single block size back to HBM + read_only = dst.numel() != src.numel() + copy_kernel[grid]( + src, + dst, + NUM_ELEMENTS=n_elements, + BLOCK_SIZE=block_size, + VECTOR_SIZE=vector_size, + READ_ONLY=read_only, + num_warps=4, + ) + + +def get_reference(x, wgs, gbps): + ms = triton.testing.do_bench(lambda: torch.clone(x)) + bw = gbps(ms) + triton_output = torch.empty_like(x) + copy(x, block_size=16384, wgs=wgs, dst=triton_output) + err = triton_output - x + if torch.count_nonzero(err): + assert False, f"Torch and Triton do not match - max error is "\ + f"{torch.max(torch.abs(err))}" + return bw + + +def align_size_to_wgs(size, wgs): + return (size // wgs) * wgs + + +def run_benchmark_suite(vector_size, block_size, num_cores, read_only): + configs = [] + # Define WGs in powers of 2 from 1 - 2048. + x_vals = [(2**i) for i in range(0, 12)] + num_cu_aligned_wgs = [(num_cores * i) for i in range(1, 5)] + import bisect + for i in num_cu_aligned_wgs: + bisect.insort(x_vals, i) + configs.append( + triton.testing.Benchmark( + x_names=['wgs'], # Argument names to use as an x-axis for the plot. + x_vals=x_vals, x_log=True, # x axis is logarithmic. + line_arg='provider', # Argument name whose value corresponds to a different line in the plot. + line_vals=['triton'], # Possible values for `line_arg`. + line_names=['Triton'], # Label name for the lines. + styles=[('blue', '-'), ('green', '-')], # Line styles. + ylabel='GiB/s', # Label name for the y-axis. + plot_name=f'size={vector_size}', # Name for the plot. Used also as a file name for saving the plot. + args={'size': vector_size}, # Values for function arguments not in `x_names` and `y_name`. 
+ )) + + @triton.testing.perf_report(configs) + def benchmark(size, provider, wgs): + aligned_size = align_size_to_wgs(size, wgs) + src_tensor = torch.randn(aligned_size, device='cuda') + dst_tensor = torch.empty(block_size, device='cuda') + if not read_only: + dst_tensor = torch.empty_like(src_tensor) + ms = triton.testing.do_bench(lambda: copy(src_tensor, block_size, wgs, dst_tensor)) + # 8 because 4 bytes from load, 4 from store. + if read_only: + gbps = lambda ms: 4 * size / ms * 1e3 / 1024**3 + else: + gbps = lambda ms: 8 * size / ms * 1e3 / 1024**3 + return gbps(ms) + + benchmark.run(print_data=True, show_plots=True) + + +def parse_args(): + parser = argparse.ArgumentParser( + prog="HBM Bandwidth Benchmark", + allow_abbrev=False, + ) + parser.add_argument("-direction", type=str, default="read-only", + help="Data movement direction: read-only, read-write") + parser.add_argument("-size", type=int, default=1024, help="Size of buffer moved, in MiB") + parser.add_argument("-num_wgs", type=int, default=0, help="Number of workgroups to use") + parser.add_argument("-block_size", type=int, default=16384, help="Block size per iteration to load / store") + parser.add_argument("-run_sweep", action='store_true', default=False, help="Run sweep of B/W vs workgroups") + return parser.parse_args() + + +def main(): + args = parse_args() + torch.manual_seed(0) + num_cores = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count + size = args.size + rw = args.direction == "read_write" + num_elements = size * 1024 * 1024 // 4 + if args.run_sweep: + assert args.num_wgs == 0, "If running the benchmark suite, please do not specify the number of WGs to use." + run_benchmark_suite(num_elements, args.block_size, num_cores, not rw) + return + if args.num_wgs == 0: + # num_wgs not user specified - get from device properties + num_wgs = num_cores + print(f"Using {num_wgs} workgroups. It is recommended to "\ + "use -num_wgs to provide this number.") + else: + assert args.num_wgs > 0, "Please provide a positive, non-zero number of workgroups!" + num_wgs = args.num_wgs + if num_wgs % num_cores: + print(f"Note! Your device has {num_cores} cores. It is recommended to use"\ + " a number for workgroups that is a multiple of this number."\ + f" You have currently chosen {num_wgs}.") + num_elements_rounded = align_size_to_wgs(num_elements, num_wgs) + if num_elements != num_elements_rounded: + print(f"Removing last {num_elements - num_elements_rounded} elements to "\ + "get a tensor size aligned to multiple of number of workgroups.") + num_elements = num_elements_rounded + src_tensor = torch.randn(num_elements, device="cuda") + if rw: + # 8 because 4B for read. 4B for write. + gbps = lambda ms: 8 * num_elements / ms * 1e3 / 1024**3 + ref_bw = get_reference(src_tensor, num_wgs, gbps) + print(f"Reference PyTorch bandwidth = {ref_bw} GiB/s") + else: + gbps = lambda ms: 4 * num_elements / ms * 1e3 / 1024**3 + if size < 1024: + print("Note! It is recommended to use a buffer larger than 1 GiB.") + if num_elements % args.block_size: + print("Note! This config is suboptimal. 
It is recommended to use a buffer that"\ + f" is a multiple of wgs x block size = {num_wgs * args.block_size} elements.") + dst_tensor = torch.empty_like(src_tensor) if rw else torch.empty(args.block_size, device='cuda') + triton_ms = triton.testing.do_bench(lambda: copy(src_tensor, args.block_size, num_wgs, dst=dst_tensor), warmup=1, + rep=1) + print(f"Triton bandwidth = {gbps(triton_ms)} GiB/s") + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-oldversion.py b/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-oldversion.py new file mode 100644 index 000000000000..beb8b0df9b1f --- /dev/null +++ b/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-oldversion.py @@ -0,0 +1,485 @@ +## matmul stream-k implementation +## Credit goes to @pommedeterresautee +## See https://github.com/openai/triton/issues/1393 + +# (echo 'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"') | sudo tee -a /etc/modprobe.d/RestrictedProfiling.conf >/dev/null +# sudo update-initramfs -u -k all +# cat /proc/driver/nvidia/params | grep RmProfilingAdminOnly +# sudo apt-get install zlib1g-dev +# for reproductible experiments +# sudo nvidia-smi -pm 1 -i 0 +# sudo nvidia-smi -i 0 -pl 350 # 400 for A100 +# sudo nvidia-smi -i 0 -lgc 1005 +from typing import Optional + +import torch +import triton +import triton.language as tl +import random + +#from triton.runtime.driver import CudaUtils +import json + +torch.manual_seed(123) +random.seed(123) + +#device = torch.cuda.current_device() +#cuda_utils = CudaUtils() +#total_sm = cuda_utils.get_device_properties(device)["multiprocessor_count"] +#total_sm = 110 # for MI250 +total_sm = 304 # for MI300X +print(f"total SMs: {total_sm}") + +# --------------------------------------------------------------------------- +# Triton kernels +# --------------------------------------------------------------------------- + + +@triton.jit() +def swizzle_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr): + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = tile_id // width + group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (tile_id % group_size) + pid_n = (tile_id % width) // group_size + return pid_m, pid_n + + +@triton.jit() +def linear_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr): + pid_m = tile_id // tl.cdiv(N, BLOCK_N) + pid_n = tile_id % tl.cdiv(N, BLOCK_N) + return pid_m, pid_n + + +# iterate, multiply and accumulate over K axis +@triton.jit() +def mac_loop( + A, + B, + C, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + tile_id, + mod1, + mod2, + iters_per_tile, + start_iter, + end_iter, + pid_m, + pid_n, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, +): + + # where are we in the grid + # tile_id = start_iter // iters_per_tile + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rk = tl.arange(0, BLOCK_K) + # A = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_K * stride_ak * (start_iter % iters_per_tile) + # B = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_K * stride_bk * (start_iter % iters_per_tile) + A = A + rm[:, None] * stride_am + 
rk[None, :] * stride_ak + BLOCK_K * stride_ak * (mod1) + B = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_K * stride_bk * (mod1) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + for current_iter in range(start_iter, end_iter): + a = tl.load(A) + b = tl.load(B) + acc += tl.dot(a, b) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + + #if end_iter % iters_per_tile == 0: # last iteration of the tile always happens before its start on another SM + + +# if mod2 == 0:# last iteration of the tile always happens before its start on another SM +# C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! +# tl.store(C_, acc) +# if start_iter % iters_per_tile != 0: # only if tile has been partially processed +# if mod1 != 0: # only if tile has been partially processed +# tl.atomic_xchg(locks + tile_id, 1) +# else: +# while tl.atomic_cas(locks + tile_id, 1, 1) != 1: +# pass +# C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! +# tl.atomic_add(C_, acc) + if mod1 == 0 and mod2 == 0: + C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! + tl.store(C_, acc) + else: + C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! + tl.atomic_add(C_, acc) + + +@triton.jit() +def first_wave( + A, + B, + C, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + total_full_tiles_streamk, + total_partial_tiles_streamk, + iters_per_tile, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + GROUP_M: tl.constexpr, +): + pid = tl.program_id(0) + start_iter = pid * total_full_tiles_streamk + tl.minimum(pid, total_partial_tiles_streamk) + last_iter = (pid + 1) * total_full_tiles_streamk + tl.minimum(pid + 1, total_partial_tiles_streamk) + + while start_iter < last_iter: + end_iter = tl.minimum(start_iter + (iters_per_tile - start_iter % iters_per_tile), last_iter) + mod1 = start_iter % iters_per_tile + mod2 = end_iter % iters_per_tile + tile_id = start_iter // iters_per_tile + if GROUP_M > 0: + pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + else: + pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + mac_loop( + A, + B, + C, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + tile_id, + mod1, + mod2, + iters_per_tile, + start_iter, + end_iter, + pid_m, + pid_n, + BLOCK_M, + BLOCK_N, + BLOCK_K, + ACC_TYPE, + ) + + start_iter = end_iter + + +# similar to the reference matmul kernel +@triton.jit() +def full_tiles( + A, + B, + C, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + total_tiles_streamk, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + GROUP_M: tl.constexpr, +): + # first wave has done more tiles than there are SMs, we adjust pid + tile_id = tl.program_id(0) + total_tiles_streamk + if GROUP_M > 0: + pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + else: + pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rk = tl.arange(0, BLOCK_K) + # pointers + A = A + rm[:, None] * stride_am + rk[None, :] * 
stride_ak + B = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(0, tl.cdiv(K, BLOCK_K)): + a = tl.load(A) + b = tl.load(B) + acc += tl.dot(a, b) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + acc = acc.to(tl.float16) # restore C.dtype.element_ty + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn + tl.store(C, acc) + + +# --------------------------------------------------------------------------- +# Wrapper +# --------------------------------------------------------------------------- + + +class matmul(torch.autograd.Function): + + _debug = False + + @staticmethod + def set_debug(debug: bool): + matmul._debug = debug + + @staticmethod + def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M: int, BLK_N: int, BLK_K: int, + two_tiles: bool, num_stages: int, num_warps: int): + device = a.device + + assert a.is_contiguous() and b.is_contiguous(), "non-contiguous inputs are not supported" + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + # accumulator types + ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + # compute grid (work to do per SM on the first wave) + total_blocks_M = triton.cdiv(M, BLK_M) + total_blocks_N = triton.cdiv(N, BLK_N) + iters_per_tile = triton.cdiv(K, BLK_K) + GROUP_M = 8 # 0 to disable swizzling + total_tiles = total_blocks_M * total_blocks_N + + if total_programs_streamk > 0: # Stream-K + # last wave may occupy less than total_programs_streamk SMs + total_tiles_streamk = total_tiles % total_programs_streamk + # for two-tile Stream-K + data-parallel from original paper + if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: + total_tiles_streamk += total_programs_streamk + # remaining tiles are computed using classical blocking + total_blocking_tiles = total_tiles - total_tiles_streamk + total_iters_streamk = total_tiles_streamk * iters_per_tile + # iterations related to full waves + total_full_tiles_streamk = total_iters_streamk // total_programs_streamk + # iterations related to last (partial) wave + total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk + + else: # all tiles are computed using classical blocking + total_blocking_tiles = total_tiles + total_tiles_streamk = 0 + total_full_tiles_streamk = 0 + total_partial_tiles_streamk = 0 + total_iters_streamk = 0 + + if matmul._debug: + print(f"M,N,K={M},{N},{K} ; BLK_M,N,K={BLK_M},{BLK_N},{BLK_K}") + print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}") + print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}") + print(f"{total_programs_streamk=}") + print(f"{total_blocking_tiles=}") + print(f"{iters_per_tile=}") + print(f"{total_iters_streamk=}") + + # allocates output + c = torch.zeros((M, N), device=device, dtype=a.dtype) + # allocates locks to sync work accross SMs + k1 = first_wave[(total_programs_streamk, )]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + total_full_tiles_streamk=total_full_tiles_streamk, + total_partial_tiles_streamk=total_partial_tiles_streamk, + iters_per_tile=iters_per_tile, + BLOCK_M=BLK_M, + BLOCK_N=BLK_N, + BLOCK_K=BLK_K, + ACC_TYPE=ACC_TYPE, + GROUP_M=GROUP_M, + 
num_stages=num_stages, + num_warps=num_warps, + ) + if matmul._debug: + print(f"{k1.n_regs} registers used, {k1.n_spills} spills") + k2 = full_tiles[(total_blocking_tiles, )]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + total_tiles_streamk=total_tiles_streamk, + BLOCK_M=BLK_M, + BLOCK_N=BLK_N, + BLOCK_K=BLK_K, + ACC_TYPE=ACC_TYPE, + GROUP_M=GROUP_M, + num_stages=num_stages, + num_warps=num_warps, + ) + if matmul._debug: + print(f"{k2.n_regs} registers used, {k2.n_spills} spills") + return c + + @staticmethod + def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLK_M=128, BLK_N=128, BLK_K=32, two_tiles=True, + num_stages=3, num_warps=4): + return matmul._call(a=a, b=b, total_programs_streamk=grid, BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K, + two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages) + + +# --------------------------------------------------------------------------- +# Example and Benchmark +# --------------------------------------------------------------------------- + +perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3) + +m, n, k = 8192, 8192, 8192 # some problem size to test +A = torch.randn(m, k, device="cuda", dtype=torch.float16) +B = torch.randn(k, n, device="cuda", dtype=torch.float16) +BLK_M = 128 +BLK_N = 256 +BLK_K = 16 +two_tiles = 'True' +num_stages = 0 +num_warps = 4 + +matmul.set_debug(True) +C = matmul.apply(A, B, total_sm, 128, 128, 32, 4, 4) +matmul.set_debug(False) +expected = A @ B + +assert torch.allclose(C, expected, atol=1), f"max: {(C - expected).abs().max().item()}\n{C}\n{expected}" + +# for debugging, uncomment the following line +# exit(0) + +triton_ms = triton.testing.do_bench(lambda: torch.matmul(A, B)) +print(f"PyTorch: {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") + +triton_ms = triton.testing.do_bench( + lambda: matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps)) +print(f"hybrid stream-k (grid={total_sm}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") + +triton_ms = triton.testing.do_bench( + lambda: matmul.apply(A, B, total_sm * 2, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps)) +print(f"hybrid stream-k (grid={total_sm * 2}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") + +triton_ms = triton.testing.do_bench( + lambda: matmul.apply(A, B, 0, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps)) +print(f"tile matmul (grid=0): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") + +exit(0) +# --------------------------------------------------------------------------- +# Log-sampled benchmark +# --------------------------------------------------------------------------- + +# tried to reproduce the tests described in the paper +num_samples = 1000 # 32768 +step = 256 +values = ((torch.logspace(torch.tensor(step).log2(), + torch.tensor(8192).log2(), num_samples, base=2) / step).round() * step).unique().tolist() +shapes = [(int(m), int(n), int(k)) for m in values for n in values for k in values] +shapes = random.sample(shapes, num_samples) +assert len(shapes) == num_samples + +results = [] +for idx, (m, n, k) in enumerate(shapes): + # print progress bar + if idx % 10 == 0 and idx > 0: + speedups = [r["speedup"] for r in results] + print(f"{idx}/{num_samples} - average speedup: {sum(speedups) / len(speedups):.3f}") + + A = torch.randn(m, k, device="cuda", dtype=torch.float16) + B = torch.randn(k, n, device="cuda", dtype=torch.float16) + output: Optional[torch.Tensor] = None + + def wrapper_matmul(*args, 
**kwargs): + global output + output = matmul.apply(*args, **kwargs) + return output + + expected = A @ B + pytorch_ms = triton.testing.do_bench(lambda: A @ B) + measures = list() + for two_tiles in [True, False]: + nb_sm = [total_sm, total_sm * 2] + total_tile = (m // 128) * (n // 128) + if total_tile < total_sm * 2: + nb_sm.append(total_tile) + nb_sm += random.sample(range(2, total_sm * 2, 2), 10) + for sm in nb_sm: + triton_ms = triton.testing.do_bench(lambda: wrapper_matmul(A, B, sm, 128, 128, 32, two_tiles, 4, 4)) + max_disc = (output - expected).abs().max().item() + # large tolerance to accomodate for large K (rounding due to half precision), we just want to catch bugs. + assert max_disc <= 5., f"pb size: {m}x{n}x{k} - max discrepancy: {max_disc} - sm: {sm}, 2 tiles: {two_tiles}\n{output}\n{expected}" + info = { + "2 tiles": two_tiles, + "sm": sm, + "disc": max_disc, + "triton_ms": triton_ms, + } + measures.append(info) + best_triton_ms = min([m["triton_ms"] for m in measures]) + d = { + "m": m, + "n": n, + "k": k, + "triton": measures, + "pytorch_ms": pytorch_ms, + "speedup": pytorch_ms / best_triton_ms, + } + results.append(d) + measures = list() + +results.sort(key=lambda x: x["speedup"], reverse=False) + +# --------------------------------------------------------------------------- +# Benchmark export +# --------------------------------------------------------------------------- + +with open("results.json", "w") as f: + json.dump(results, f, indent=4) + +# 32760/32768 - average speedup: 0.962 (A100) +# 990/1000 - average speedup: 1.063 (3090 RTX with while loop and 2 tiles disabled / enabled) diff --git a/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singlekern-autotune.py b/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singlekern-autotune.py new file mode 100644 index 000000000000..a35d691a0225 --- /dev/null +++ b/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singlekern-autotune.py @@ -0,0 +1,563 @@ +## matmul stream-k implementation +## Credit goes to @pommedeterresautee +## See https://github.com/openai/triton/issues/1393 + +# (echo 'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"') | sudo tee -a /etc/modprobe.d/RestrictedProfiling.conf >/dev/null +# sudo update-initramfs -u -k all +# cat /proc/driver/nvidia/params | grep RmProfilingAdminOnly +# sudo apt-get install zlib1g-dev +# for reproductible experiments +# sudo nvidia-smi -pm 1 -i 0 +# sudo nvidia-smi -i 0 -pl 350 # 400 for A100 +# sudo nvidia-smi -i 0 -lgc 1005 +from typing import Optional + +import torch +import triton +import triton.language as tl +import random + +#from triton.runtime.driver import CudaUtils +import json + +torch.manual_seed(123) +random.seed(123) + +#device = torch.cuda.current_device() +#cuda_utils = CudaUtils() +#total_sm = cuda_utils.get_device_properties(device)["multiprocessor_count"] +#total_sm = 110 # for MI250 +total_sm = 304 # for MI300X +print(f"total SMs: {total_sm}") +# global flag to indicate whether using the full tuing space +tuning_full_space = True +# --------------------------------------------------------------------------- +# Triton kernels +# --------------------------------------------------------------------------- + + +@triton.jit() +def swizzle_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr): + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = tile_id // 
width + group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (tile_id % group_size) + pid_n = (tile_id % width) // group_size + return pid_m, pid_n + + +@triton.jit() +def linear_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr): + pid_m = tile_id // tl.cdiv(N, BLOCK_N) + pid_n = tile_id % tl.cdiv(N, BLOCK_N) + return pid_m, pid_n + + +@triton.jit() +def get_tile_config(M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, two_tiles, + total_programs_streamk): + total_blocks_M = tl.cdiv(M, BLOCK_M) + total_blocks_N = tl.cdiv(N, BLOCK_N) + iters_per_tile = tl.cdiv(K, BLOCK_K) + # GROUP_M = 0 # 0 to disable swizzling + total_tiles = total_blocks_M * total_blocks_N + if total_programs_streamk > 0: # Stream-K + # last wave may occupy less than total_programs_streamk SMs + total_tiles_streamk = total_tiles % total_programs_streamk + # for two-tile Stream-K + data-parallel from original paper + if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: + total_tiles_streamk += total_programs_streamk + # remaining tiles are computed using classical blocking + total_iters_streamk = total_tiles_streamk * iters_per_tile + # iterations related to full waves + total_full_tiles_streamk = total_iters_streamk // total_programs_streamk + # iterations related to last (partial) wave + total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk + + else: # all tiles are computed using classical blocking + total_tiles_streamk = 0 + total_full_tiles_streamk = 0 + total_partial_tiles_streamk = 0 + total_iters_streamk = 0 + + return iters_per_tile, total_tiles_streamk, total_full_tiles_streamk, total_partial_tiles_streamk, total_iters_streamk + + +# pruned some unreasonable config +def prune_configs(configs, named_args): + # call only for full tuning space + if not tuning_full_space: + return configs + + SIZE_M = named_args["A"].shape[0] + SIZE_N = named_args["B"].shape[1] + # SIZE_K = named_args["A"].shape[1] + + pruned_configs = [] + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, _ =\ + kw["BLOCK_M"], kw["BLOCK_N"], kw["BLOCK_K"] + if SIZE_M <= 32 and BLOCK_M != 32: + continue + if SIZE_N <= 32 and BLOCK_N != 32: + continue + + pruned_configs.append(config) + + return pruned_configs + + +def get_full_tuning_space(): + configs = [] + if not tuning_full_space: + return configs + + block_mn_range = [64, 128, 256] + block_k_range = [16, 32, 64] + num_warps_range = [1, 2, 4, 8] + # group_m_range = [0, 1, 2, 4, 8] + group_m_range = [0, 4, 8] + # For now we see better perf with num_stages=0 for all gemm configs we care + # But keep this explicit so that we do not forget we may need to set it to + # other values in the future + num_stage_range = [0] + waves_per_eu_range = [0] + matrix_instr_nonkdim_range = [16, 32] + kpack_range = [1, 2] + + for block_m in block_mn_range: + for block_n in block_mn_range: + for block_k in block_k_range: + for num_warps in num_warps_range: + for group_m in group_m_range: + for num_stages in num_stage_range: + for num_waves_per_eu in waves_per_eu_range: + for matrix_instr_nonkdim in matrix_instr_nonkdim_range: + for kpack in kpack_range: + configs.append( + triton.Config( + { + 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, + 'GROUP_M': group_m, 'waves_per_eu': num_waves_per_eu, + 'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack + }, + num_stages=num_stages, + num_warps=num_warps, + )) 
+ + return configs + + +#To do: we need update the default autotune configuration once we go through the whole performance test sets. +@triton.autotune( + configs=get_full_tuning_space() if tuning_full_space else [ + triton.Config( + { + 'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 16, 'GROUP_M': 8, 'waves_per_eu': 0, 'matrix_instr_nonkdim': + 16, 'kpack': 1 + }, num_warps=4, num_stages=0), + triton.Config( + { + 'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 16, 'GROUP_M': 8, 'waves_per_eu': 2, 'matrix_instr_nonkdim': + 16, 'kpack': 1 + }, num_warps=4, num_stages=0), + triton.Config( + { + 'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 16, 'GROUP_M': 4, 'waves_per_eu': 0, 'matrix_instr_nonkdim': + 16, 'kpack': 1 + }, num_warps=4, num_stages=0), + triton.Config( + { + 'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 16, 'GROUP_M': 4, 'waves_per_eu': 2, 'matrix_instr_nonkdim': + 16, 'kpack': 1 + }, num_warps=4, num_stages=0), + triton.Config( + { + 'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 16, 'waves_per_eu': 0, 'matrix_instr_nonkdim': + 16, 'kpack': 1 + }, num_warps=4, num_stages=0), + triton.Config( + { + 'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 16, 'GROUP_M': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': + 16, 'kpack': 1 + }, num_warps=4, num_stages=4), + ], + key=['M', 'N', 'K'], + # prune_configs_by={ + # 'early_config_prune': prune_configs, + # 'perf_model': None, + # "top_k": None + # }, + reset_to_zero=['C'], +) +@triton.jit() +def streamk_gemm( + A, + B, + C, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + # total_full_tiles_streamk, total_partial_tiles_streamk, iters_per_tile, + # total_tiles_streamk, + total_programs_streamk, + two_tiles, + ACC_TYPE: tl.constexpr, + GROUP_M: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid = tl.program_id(0) + iters_per_tile, total_tiles_streamk, total_full_tiles_streamk, total_partial_tiles_streamk, total_iters_streamk = get_tile_config( + M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, total_programs_streamk) + + # Determine whether we are in the first wave or full_tiles phase based on pid + is_first_wave = pid < total_programs_streamk and total_programs_streamk > 0 + + # Calculate starting and ending iterations for first wave + if not is_first_wave: + tile_id = tl.program_id(0) + total_tiles_streamk - total_programs_streamk + if GROUP_M > 0: + pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + else: + pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + precomputed_stride_ak = BLOCK_K * stride_ak + precomputed_stride_bk = BLOCK_K * stride_bk + # pointers + A_BASE = A + ram[:, None] * stride_am + rk[None, :] * stride_ak + B_BASE = B + rk[:, None] * stride_bk + rbn[None, :] * stride_bn + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(0, tl.cdiv(K, BLOCK_K)): + a = tl.load(A_BASE) + b = tl.load(B_BASE) + acc += tl.dot(a, b) + A_BASE += precomputed_stride_ak + B_BASE += precomputed_stride_bk + # acc = acc.to(tl.float16) # restore C.dtype.element_ty + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C_ = C + 
rm[:, None] * stride_cm + rn[None, :] * stride_cn + tl.store(C_, acc) + else: + start_iter = pid * total_full_tiles_streamk + tl.minimum(pid, total_partial_tiles_streamk) + last_iter = (pid + 1) * total_full_tiles_streamk + tl.minimum(pid + 1, total_partial_tiles_streamk) + while start_iter < last_iter: + remainder = start_iter % iters_per_tile + end_iter = tl.minimum(start_iter + (iters_per_tile - remainder), last_iter) + # where are we in the grid + tile_id = start_iter // iters_per_tile + if GROUP_M > 0: + pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + else: + pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + A_BASE = A + ram[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_K * stride_ak * remainder + B_BASE = B + rk[:, None] * stride_bk + rbn[None, :] * stride_bn + BLOCK_K * stride_bk * remainder + precomputed_stride_ak = BLOCK_K * stride_ak + precomputed_stride_bk = BLOCK_K * stride_bk + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for current_iter in range(start_iter, end_iter): + a = tl.load(A_BASE) + b = tl.load(B_BASE) + acc += tl.dot(a, b) + A_BASE += precomputed_stride_ak + B_BASE += precomputed_stride_bk + + # acc = acc.to(tl.float16) # restore C.dtype.element_ty + if remainder == 0 and end_iter % iters_per_tile == 0: + C_ = C + rm[:, + None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! + tl.store(C_, acc) + else: + C_ = C + rm[:, + None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! 
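+                # Partial tile: several programs may each contribute a partial sum of the
+                # K loop to this C tile, so accumulate with an atomic add instead of
+                # overwriting the tile with a plain store.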
+ tl.atomic_add(C_, acc) + + start_iter = end_iter + + +# --------------------------------------------------------------------------- +# Wrapper +# --------------------------------------------------------------------------- + + +class matmul(torch.autograd.Function): + + _debug = True + + @staticmethod + def set_debug(debug: bool): + matmul._debug = debug + + @staticmethod + def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLOCK_M: int, BLOCK_N: int, BLOCK_K: int, + two_tiles: bool, num_stages: int, num_warps: int, waves_per_eu: int, mfmaInstrSize: int, kpack: int): + + def compute_total_blocking_tiles(M, N, BLOCK_M, BLOCK_N, two_tiles, total_programs_streamk): + total_blocks_M = triton.cdiv(M, BLOCK_M) + total_blocks_N = triton.cdiv(N, BLOCK_N) + total_tiles = total_blocks_M * total_blocks_N + + if total_programs_streamk > 0: # Stream-K + # last wave may occupy less than total_programs_streamk SMs + total_tiles_streamk = total_tiles % total_programs_streamk + # for two-tile Stream-K + data-parallel from original paper + if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: + total_tiles_streamk += total_programs_streamk + # remaining tiles are computed using classical blocking + total_blocking_tiles = total_tiles - total_tiles_streamk + else: # all tiles are computed using classical blocking + total_blocking_tiles = total_tiles + + return total_blocking_tiles + + device = a.device + + assert a.is_contiguous() and b.is_contiguous(), "non-contiguous inputs are not supported" + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + # accumulator types + ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + # compute grid (work to do per SM on the first wave) + # GROUP_M = 8 # 0 to disable swizzling + + if matmul._debug: + total_blocks_M = triton.cdiv(M, BLOCK_M) + total_blocks_N = triton.cdiv(N, BLOCK_N) + iters_per_tile = triton.cdiv(K, BLOCK_K) + total_tiles = total_blocks_M * total_blocks_N + if total_programs_streamk > 0: # Stream-K + # last wave may occupy less than total_programs_streamk SMs + total_tiles_streamk = total_tiles % total_programs_streamk + # for two-tile Stream-K + data-parallel from original paper + if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: + total_tiles_streamk += total_programs_streamk + # remaining tiles are computed using classical blocking + total_blocking_tiles = total_tiles - total_tiles_streamk + total_iters_streamk = total_tiles_streamk * iters_per_tile + # iterations related to full waves + # total_full_tiles_streamk = total_iters_streamk // total_programs_streamk + # iterations related to last (partial) wave + total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk + + else: # all tiles are computed using classical blocking + total_blocking_tiles = total_tiles + total_tiles_streamk = 0 + # total_full_tiles_streamk = 0 + total_partial_tiles_streamk = 0 + total_iters_streamk = 0 + print(f"M,N,K={M},{N},{K} ; BLOCK_M,N,K={BLOCK_M},{BLOCK_N},{BLOCK_K}") + print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}") + print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}") + print(f"{total_programs_streamk=}") + print(f"{total_blocking_tiles=}") + print(f"{total_partial_tiles_streamk=}") + print(f"{iters_per_tile=}") + print(f"{total_iters_streamk=}") + + # allocates output + c = torch.zeros((M, N), device=device, dtype=a.dtype) + grids 
= lambda META: (total_programs_streamk + compute_total_blocking_tiles(M, N, META['BLOCK_M'], META[ + 'BLOCK_N'], two_tiles, total_programs_streamk), ) + kk = streamk_gemm[(grids)]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + # total_full_tiles_streamk=total_full_tiles_streamk, + # total_partial_tiles_streamk=total_partial_tiles_streamk, + # iters_per_tile=iters_per_tile, + # total_tiles_streamk=total_tiles_streamk, + total_programs_streamk=total_programs_streamk, + two_tiles=two_tiles, + ACC_TYPE=ACC_TYPE, + # GROUP_M=GROUP_M, + # BLOCK_M=BLOCK_M, + # BLOCK_N=BLOCK_N, + # BLOCK_K=BLOCK_K, + # num_stages=num_stages, + # num_warps=num_warps, + # waves_per_eu = waves_per_eu, + ) + if matmul._debug: + print(f"{kk.n_regs} registers used, {kk.n_spills} spills") + + # print(kk.asm['ttgir']) + # print(kk.asm['amdgcn']) + return c + + @staticmethod + def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLOCK_M=128, BLOCK_N=128, BLOCK_K=32, two_tiles=True, + num_stages=3, num_warps=4, waves_per_eu=2, mfmaInstrSize=16, kpack=1): + return matmul._call(a=a, b=b, total_programs_streamk=grid, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, + two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages, waves_per_eu=waves_per_eu, + mfmaInstrSize=mfmaInstrSize, kpack=kpack) + + +# --------------------------------------------------------------------------- +# Example and Benchmark +# --------------------------------------------------------------------------- + +perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3) + +#m, n, k = 1792, 7424, 4864 # some problem size to test +#m, n, k = 8192, 8192, 8192 # some problem size to test +m, n, k = 4096, 4096, 8192 # some problem size to test +A = torch.randn(m, k, device="cuda", dtype=torch.float16) +B = torch.randn(k, n, device="cuda", dtype=torch.float16) +#A = torch.ones((m, k), device="cuda", dtype=torch.float16) +#B = torch.ones((k, n), device="cuda", dtype=torch.float16) +BLOCK_M = 256 +BLOCK_N = 256 +BLOCK_K = 64 +two_tiles = True +num_stages = 0 +num_warps = 8 +waves_per_eu = 0 +mfmaInstrSize = 16 +kpack = 1 + +matmul.set_debug(True) +C = matmul.apply(A, B, total_sm, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, num_stages, num_warps, waves_per_eu, + mfmaInstrSize, kpack) +matmul.set_debug(False) +expected = A @ B + +assert torch.allclose(C, expected, atol=1), f"max: {(C - expected).abs().max().item()}\n{C}\n{expected}" +print("pass validation test") + +# for debugging, uncomment the following line +#exit(0) + +triton_ms = triton.testing.do_bench(lambda: torch.matmul(A, B)) +print(f"PyTorch: {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") + +triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, + num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack)) +print(f"hybrid stream-k (grid={total_sm}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") +print(f'SIZE: {m},{n},{k} Best tuning config: ({streamk_gemm.get_best_config()})') + +triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm * 2, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, + num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack)) +print(f"hybrid stream-k (grid={total_sm * 2}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") +print(f'SIZE: {m},{n},{k} Best tuning config: ({streamk_gemm.get_best_config()})') + +triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, 0, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, num_stages, + num_warps, 
waves_per_eu, mfmaInstrSize, kpack)) +print(f"tile matmul (grid=0): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") +print(f'SIZE: {m},{n},{k} Best tuning config: ({streamk_gemm.get_best_config()})') + +exit(0) +# --------------------------------------------------------------------------- +# Log-sampled benchmark +# --------------------------------------------------------------------------- + +# tried to reproduce the tests described in the paper +perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3) +num_samples = 1000 # 32768 +step = 256 +values = ((torch.logspace(torch.tensor(step).log2(), + torch.tensor(8192).log2(), num_samples, base=2) / step).round() * step).unique().tolist() +shapes = [(int(m), int(n), int(k)) for m in values for n in values for k in values] +shapes = random.sample(shapes, num_samples) +assert len(shapes) == num_samples + +results = [] +for idx, (m, n, k) in enumerate(shapes): + # print progress bar + if idx % 10 == 0 and idx > 0: + speedups = [r["speedup"] for r in results] + print(f"{idx}/{num_samples} - average speedup: {sum(speedups) / len(speedups):.3f}") + + A = torch.randn(m, k, device="cuda", dtype=torch.float16) + B = torch.randn(k, n, device="cuda", dtype=torch.float16) + output: Optional[torch.Tensor] = None + + def wrapper_matmul(*args, **kwargs): + global output + output = matmul.apply(*args, **kwargs) + return output + + expected = A @ B + pytorch_ms = triton.testing.do_bench(lambda: A @ B) + measures = list() + for two_tiles in [True, False]: + nb_sm = [total_sm, total_sm * 2] + total_tile = (m // BLOCK_M) * (n // BLOCK_N) + if total_tile < total_sm * 2: + nb_sm.append(total_tile) + nb_sm += random.sample(range(2, total_sm * 2, 2), 10) + for sm in nb_sm: + triton_ms = triton.testing.do_bench(lambda: wrapper_matmul(A, B, sm, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, + num_stages, num_warps, waves_per_eu)) + max_disc = (output - expected).abs().max().item() + # large tolerance to accomodate for large K (rounding due to half precision), we just want to catch bugs. 
+ assert max_disc <= 5., f"pb size: {m}x{n}x{k} - max discrepancy: {max_disc} - sm: {sm}, 2 tiles: {two_tiles}\n{output}\n{expected}" + Best_tuning_config = f'SIZE: {m},{n},{k} Best tuning config: ({streamk_gemm.get_best_config()})' + info = { + "2 tiles": two_tiles, + "sm": sm, + "disc": max_disc, + "triton_ms": triton_ms, + "Best tuning config": Best_tuning_config, + } + measures.append(info) + best_triton_ms = min([m["triton_ms"] for m in measures]) + d = { + "m": m, + "n": n, + "k": k, + "triton": measures, + "pytorch_ms": pytorch_ms, + "speedup": pytorch_ms / best_triton_ms, + } + results.append(d) + measures = list() + +results.sort(key=lambda x: x["speedup"], reverse=False) + +# --------------------------------------------------------------------------- +# Benchmark export +# --------------------------------------------------------------------------- + +with open("results.json", "w") as f: + json.dump(results, f, indent=4) + +# 32760/32768 - average speedup: 0.962 (A100) +# 990/1000 - average speedup: 1.063 (3090 RTX with while loop and 2 tiles disabled / enabled) diff --git a/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singleloop-nomod.py b/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singleloop-nomod.py new file mode 100644 index 000000000000..2651ad59d923 --- /dev/null +++ b/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singleloop-nomod.py @@ -0,0 +1,387 @@ +## matmul stream-k implementation +## Credit goes to @pommedeterresautee +## See https://github.com/openai/triton/issues/1393 + +# (echo 'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"') | sudo tee -a /etc/modprobe.d/RestrictedProfiling.conf >/dev/null +# sudo update-initramfs -u -k all +# cat /proc/driver/nvidia/params | grep RmProfilingAdminOnly +# sudo apt-get install zlib1g-dev +# for reproductible experiments +# sudo nvidia-smi -pm 1 -i 0 +# sudo nvidia-smi -i 0 -pl 350 # 400 for A100 +# sudo nvidia-smi -i 0 -lgc 1005 +from typing import Optional + +import torch +import triton +import triton.language as tl +import random + +#from triton.runtime.driver import CudaUtils +import json + +torch.manual_seed(123) +random.seed(123) + +#device = torch.cuda.current_device() +#cuda_utils = CudaUtils() +#total_sm = cuda_utils.get_device_properties(device)["multiprocessor_count"] +#total_sm = 110 # for MI250 +total_sm = 304 # for MI300X +print(f"total SMs: {total_sm}") + +# --------------------------------------------------------------------------- +# Triton kernels +# --------------------------------------------------------------------------- + + +@triton.jit() +def swizzle_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr): + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = tile_id // width + group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (tile_id % group_size) + pid_n = (tile_id % width) // group_size + return pid_m, pid_n + + +@triton.jit() +def linear_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr): + pid_m = tile_id // tl.cdiv(N, BLOCK_N) + pid_n = tile_id % tl.cdiv(N, BLOCK_N) + return pid_m, pid_n + + +@triton.jit() +def first_wave( + A, + B, + C, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + total_full_tiles_streamk, + 
total_partial_tiles_streamk, + iters_per_tile, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + GROUP_M: tl.constexpr, +): + pid = tl.program_id(0) + start_iter = pid * total_full_tiles_streamk + tl.minimum(pid, total_partial_tiles_streamk) + last_iter = (pid + 1) * total_full_tiles_streamk + tl.minimum(pid + 1, total_partial_tiles_streamk) + + while start_iter < last_iter: + remainder = start_iter % iters_per_tile + end_iter = tl.minimum(start_iter + (iters_per_tile - remainder), last_iter) + # where are we in the grid + tile_id = start_iter // iters_per_tile + if GROUP_M > 0: + pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + else: + pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rk = tl.arange(0, BLOCK_K) + A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_K * stride_ak * remainder + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_K * stride_bk * remainder + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + for current_iter in range(start_iter, end_iter): + a = tl.load(A_BASE) + b = tl.load(B_BASE) + acc += tl.dot(a, b) + A_BASE += BLOCK_K * stride_ak + B_BASE += BLOCK_K * stride_bk + + if remainder == 0 and end_iter % iters_per_tile == 0: + C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! + tl.store(C_, acc) + else: + C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! + tl.atomic_add(C_, acc) + + start_iter = end_iter + + +# similar to the reference matmul kernel +@triton.jit() +def full_tiles( + A, + B, + C, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + total_tiles_streamk, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + GROUP_M: tl.constexpr, +): + # first wave has done more tiles than there are SMs, we adjust pid + tile_id = tl.program_id(0) + total_tiles_streamk + if GROUP_M > 0: + pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + else: + pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rk = tl.arange(0, BLOCK_K) + # pointers + A = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + B = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(0, tl.cdiv(K, BLOCK_K)): + a = tl.load(A) + b = tl.load(B) + acc += tl.dot(a, b) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + acc = acc.to(tl.float16) # restore C.dtype.element_ty + # rematerialize rm and rn to save registers + # rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + # rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn + tl.store(C, acc) + + +# --------------------------------------------------------------------------- +# Wrapper +# --------------------------------------------------------------------------- + + +class matmul(torch.autograd.Function): + + _debug = True + + @staticmethod + def set_debug(debug: bool): + matmul._debug = debug + + @staticmethod + def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, 
BLK_M: int, BLK_N: int, BLK_K: int, + two_tiles: bool, num_stages: int, num_warps: int, waves_per_eu: int, mfmaInstrSize: int, kpack: int): + device = a.device + + assert a.is_contiguous() and b.is_contiguous(), "non-contiguous inputs are not supported" + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + # accumulator types + ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + # compute grid (work to do per SM on the first wave) + total_blocks_M = triton.cdiv(M, BLK_M) + total_blocks_N = triton.cdiv(N, BLK_N) + iters_per_tile = triton.cdiv(K, BLK_K) + GROUP_M = 4 # 0 to disable swizzling + total_tiles = total_blocks_M * total_blocks_N + + if total_programs_streamk > 0: # Stream-K + # last wave may occupy less than total_programs_streamk SMs + total_tiles_streamk = total_tiles % total_programs_streamk + # for two-tile Stream-K + data-parallel from original paper + if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: + total_tiles_streamk += total_programs_streamk + # remaining tiles are computed using classical blocking + total_blocking_tiles = total_tiles - total_tiles_streamk + total_iters_streamk = total_tiles_streamk * iters_per_tile + # iterations related to full waves + total_full_tiles_streamk = total_iters_streamk // total_programs_streamk + # iterations related to last (partial) wave + total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk + + else: # all tiles are computed using classical blocking + total_blocking_tiles = total_tiles + total_tiles_streamk = 0 + total_full_tiles_streamk = 0 + total_partial_tiles_streamk = 0 + total_iters_streamk = 0 + + if matmul._debug: + print(f"M,N,K={M},{N},{K} ; BLK_M,N,K={BLK_M},{BLK_N},{BLK_K}") + print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}") + print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}") + print(f"{total_programs_streamk=}") + print(f"{total_blocking_tiles=}") + print(f"{iters_per_tile=}") + print(f"{total_iters_streamk=}") + + # allocates output + c = torch.zeros((M, N), device=device, dtype=a.dtype) + + k1 = first_wave[(total_programs_streamk, )]( + a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), + total_full_tiles_streamk=total_full_tiles_streamk, total_partial_tiles_streamk=total_partial_tiles_streamk, + iters_per_tile=iters_per_tile, BLOCK_M=BLK_M, BLOCK_N=BLK_N, BLOCK_K=BLK_K, ACC_TYPE=ACC_TYPE, + GROUP_M=GROUP_M, num_stages=num_stages, num_warps=num_warps, waves_per_eu=waves_per_eu, + matrix_instr_nonkdim=mfmaInstrSize, kpack=kpack) + if matmul._debug: + print(f"{k1.n_regs} registers used, {k1.n_spills} spills") + k2 = full_tiles[(total_blocking_tiles, )](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), + c.stride(0), c.stride(1), total_tiles_streamk=total_tiles_streamk, + BLOCK_M=BLK_M, BLOCK_N=BLK_N, BLOCK_K=BLK_K, ACC_TYPE=ACC_TYPE, + GROUP_M=GROUP_M, num_stages=num_stages, num_warps=num_warps, + waves_per_eu=waves_per_eu, matrix_instr_nonkdim=mfmaInstrSize, + kpack=kpack) + if matmul._debug: + print(f"{k2.n_regs} registers used, {k2.n_spills} spills") +# print(k2.asm['amdgcn']) + return c + + @staticmethod + def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLK_M=128, BLK_N=128, BLK_K=32, two_tiles=True, + num_stages=3, num_warps=4, waves_per_eu=2, mfmaInstrSize=16, kpack=1): + return matmul._call(a=a, b=b, total_programs_streamk=grid, 
BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K, + two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages, waves_per_eu=waves_per_eu, + mfmaInstrSize=mfmaInstrSize, kpack=kpack) + + +# --------------------------------------------------------------------------- +# Example and Benchmark +# --------------------------------------------------------------------------- + +perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3) + +#m, n, k = 4864, 4096, 8256 # some problem size to test +m, n, k = 6912, 768, 256 # some problem size to test +#m, n, k = 8192, 8192, 8192 # some problem size to test +A = torch.randn(m, k, device="cuda", dtype=torch.float16) +B = torch.randn(k, n, device="cuda", dtype=torch.float16) +#A = torch.ones((m, k), device="cuda", dtype=torch.float16) +#B = torch.ones((k, n), device="cuda", dtype=torch.float16) +BLK_M = 64 +BLK_N = 64 +BLK_K = 64 +two_tiles = 'True' +num_stages = 0 +num_warps = 4 +waves_per_eu = 0 +mfmaInstrSize = 16 +kpack = 2 + +matmul.set_debug(True) +C = matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, + kpack) +#exit(0) +matmul.set_debug(False) +expected = A @ B + +assert torch.allclose(C, expected, atol=1), f"max: {(C - expected).abs().max().item()}\n{C}\n{expected}" + +# for debugging, uncomment the following line + +triton_ms = triton.testing.do_bench(lambda: torch.matmul(A, B)) +print(f"PyTorch: {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") + +triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, + num_warps, waves_per_eu, mfmaInstrSize, kpack)) +print(f"hybrid stream-k (grid={total_sm}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") + +triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm * 2, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, + num_warps, waves_per_eu, mfmaInstrSize, kpack)) +print(f"hybrid stream-k (grid={total_sm * 2}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") + +triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, 0, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, + waves_per_eu, mfmaInstrSize, kpack)) +print(f"tile matmul (grid=0): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") + +exit(0) +# --------------------------------------------------------------------------- +# Log-sampled benchmark +# --------------------------------------------------------------------------- + +# tried to reproduce the tests described in the paper +num_samples = 1000 # 32768 +step = 256 +values = ((torch.logspace(torch.tensor(step).log2(), + torch.tensor(8192).log2(), num_samples, base=2) / step).round() * step).unique().tolist() +shapes = [(int(m), int(n), int(k)) for m in values for n in values for k in values] +shapes = random.sample(shapes, num_samples) +assert len(shapes) == num_samples + +results = [] +for idx, (m, n, k) in enumerate(shapes): + # print progress bar + if idx % 10 == 0 and idx > 0: + speedups = [r["speedup"] for r in results] + print(f"{idx}/{num_samples} - average speedup: {sum(speedups) / len(speedups):.3f}") + + A = torch.randn(m, k, device="cuda", dtype=torch.float16) + B = torch.randn(k, n, device="cuda", dtype=torch.float16) + output: Optional[torch.Tensor] = None + + def wrapper_matmul(*args, **kwargs): + global output + output = matmul.apply(*args, **kwargs) + return output + + expected = A @ B + pytorch_ms = triton.testing.do_bench(lambda: A @ B) + measures = list() + for two_tiles in [True, False]: + nb_sm = [total_sm, total_sm * 2] + total_tile = (m // 
BLK_M) * (n // BLK_N) + if total_tile < total_sm * 2: + nb_sm.append(total_tile) + nb_sm += random.sample(range(2, total_sm * 2, 2), 10) + for sm in nb_sm: + triton_ms = triton.testing.do_bench(lambda: wrapper_matmul( + A, B, sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack)) + max_disc = (output - expected).abs().max().item() + # large tolerance to accomodate for large K (rounding due to half precision), we just want to catch bugs. + assert max_disc <= 5., f"pb size: {m}x{n}x{k} - max discrepancy: {max_disc} - sm: {sm}, 2 tiles: {two_tiles}\n{output}\n{expected}" + info = { + "2 tiles": two_tiles, + "sm": sm, + "disc": max_disc, + "triton_ms": triton_ms, + } + measures.append(info) + best_triton_ms = min([m["triton_ms"] for m in measures]) + d = { + "m": m, + "n": n, + "k": k, + "triton": measures, + "pytorch_ms": pytorch_ms, + "speedup": pytorch_ms / best_triton_ms, + } + results.append(d) + measures = list() + +results.sort(key=lambda x: x["speedup"], reverse=False) + +# --------------------------------------------------------------------------- +# Benchmark export +# --------------------------------------------------------------------------- + +with open("results.json", "w") as f: + json.dump(results, f, indent=4) + +# 32760/32768 - average speedup: 0.962 (A100) +# 990/1000 - average speedup: 1.063 (3090 RTX with while loop and 2 tiles disabled / enabled) From 17575ea88e229bbdbcd476553b3bc25b3b8dab58 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Mon, 13 May 2024 14:36:34 -0400 Subject: [PATCH 02/12] skip backward (#586) --- python/perf-kernels/flash-attention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py index 6fc861b281fa..d70a43ecd36c 100644 --- a/python/perf-kernels/flash-attention.py +++ b/python/perf-kernels/flash-attention.py @@ -1277,6 +1277,7 @@ def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16 @pytest.mark.parametrize('use_alibi', [False, True]) def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sdpa_test, use_alibi, dtype=torch.float16): + pytest.skip() torch.manual_seed(20) if qseqlen_not_equal_kseqlen is not None: seqlen_q = qseqlen_not_equal_kseqlen From a3d784a869aad6801694680f424c4f36e447db98 Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Thu, 16 May 2024 15:20:39 -0500 Subject: [PATCH 03/12] Change all block pointers to tensor pointers (#585) Change all block pointers to tensor pointers Block pointers are for nvidia TMAs. They are useful for regular loads as well but not well supported. Also cleaned up some code I came across along the way and updated comment at the top. 
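As an editorial illustration of the block-pointer-to-tensor-pointer rewrite (a sketch, not part of the patch; masked_tile_copy and its arguments are placeholders rather than anything in flash-attention.py), the two load styles look roughly like this:

    import triton
    import triton.language as tl


    @triton.jit
    def masked_tile_copy(X, Y, M, N, stride_m, stride_n, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
        pid_m = tl.program_id(0)
        pid_n = tl.program_id(1)
        offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        # Old style (what this patch removes): a block pointer plus boundary_check
        #   x_bp = tl.make_block_ptr(base=X, shape=(M, N), strides=(stride_m, stride_n),
        #                            offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
        #                            block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
        #   x = tl.load(x_bp, boundary_check=(0, 1), padding_option="zero")
        # New style: a tensor of pointers plus an explicit out-of-bounds mask.
        x_ptrs = X + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
        mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
        x = tl.load(x_ptrs, mask=mask, other=0.0)
        tl.store(Y + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n, x, mask=mask)

Per the commit's rationale, explicit pointer arithmetic with a mask is handled well on all backends, while tl.make_block_ptr is primarily a fit for NVIDIA's TMA path.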
--- python/perf-kernels/flash-attention.py | 246 ++++++++++++------------- 1 file changed, 119 insertions(+), 127 deletions(-) diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py index d70a43ecd36c..42e9ac310195 100644 --- a/python/perf-kernels/flash-attention.py +++ b/python/perf-kernels/flash-attention.py @@ -2,19 +2,21 @@ Fused Attention =============== -This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) -Credits: OpenAI kernel team, AMD ML Frameworks Triton team +This is a Triton implementation of the Flash Attention v2 algorithm +See https://tridao.me/publications/flash2/flash2.pdf -Features supported: +Credits: +AMD Triton kernels team +OpenAI kernel team -1) Fwd with causal masking -2) Any sequence lengths without padding (currently fwd kernel only) -3) Support for different sequence lengths for q and k -4) Nested tensor API currently does not support dropout or bias. - -Not currently supported: +Currently only the forward kernel is supported, and contains these features: -1) Non power of two head dims +1) Fwd with causal masking +2) Arbitrary Q and KV sequence lengths +3) Arbitrary head sizes +4) Multi and grouped query attention +5) Variable sequence lengths +6) ALiBi and matrix bias """ @@ -28,10 +30,6 @@ torch_dtype: tl.constexpr = torch.float16 -TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2fnuz') -if TORCH_HAS_FP8E5: - torch_dtype: tl.constexpr = torch.float8_e5m2fnuz - class MetaData(): cu_seqlens_q = None @@ -141,16 +139,22 @@ def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): return rng_keep +# Convenience function to load with optional boundary checks. +# "First" is the major dim, "second" is the minor dim. 
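+# Passing None for an offset skips the bounds check along that dimension entirely.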
@triton.jit -def load_fn(block_ptr, first, second, pad): - if first and second: - tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) - elif first: - tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad) - elif second: - tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad) +def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second): + if offset_first is not None and offset_second is not None: + mask = (offset_first[:, None] < boundary_first) & \ + (offset_second[None, :] < boundary_second) + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_first is not None: + mask = offset_first[:, None] < boundary_first + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_second is not None: + mask = offset_second[None, :] < boundary_second + tensor = tl.load(ptrs, mask=mask, other=0.0) else: - tensor = tl.load(block_ptr) + tensor = tl.load(ptrs) return tensor @@ -204,19 +208,26 @@ def compute_alibi_tensor(alibi_slopes, seqlen_q, seqlen_k): @triton.jit -def _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, actual_seqlen_k, actual_seqlen_q, dropout_p, - philox_seed, batch_philox_offset, encoded_softmax_block_ptr, block_min, block_max, offs_n_causal, - masked_blocks, n_extra_tokens, bias_ptr, alibi_slope, IS_CAUSAL: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, OFFS_M: tl.constexpr, - OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, PADDED_HEAD: tl.constexpr): +def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, start_m, + actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, batch_philox_offset, encoded_sm_ptrs, + block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, + IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, PADDED_HEAD: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr): # loop over k, v, and update accumulator for start_n in range(block_min, block_max, BLOCK_N): # For padded blocks, we will overrun the tensor size if # we load all BLOCK_N. For others, the blocks are all within range. - k = load_fn(K_block_ptr, PADDED_HEAD, MASK_STEPS and (n_extra_tokens != 0), "zero") + if MASK_STEPS: + k_offs_n = start_n + tl.arange(0, BLOCK_N) + else: + k_offs_n = None + k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL) + k = load_fn(k_ptrs, k_offs_k, k_offs_n, ACTUAL_BLOCK_DMODEL, actual_seqlen_k) if PRE_LOAD_V: - v = load_fn(V_block_ptr, MASK_STEPS and (n_extra_tokens != 0), PADDED_HEAD, "zero") + # We can use the same offsets as k, just with dims transposed. 
+ v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # We start from end of seqlen_k so only the first iteration would need # to be checked for padding if it is not a multiple of block_n @@ -238,8 +249,9 @@ def _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, actual_ qk = tl.where(causal_mask, qk, float("-inf")) # -- compute qk ---- qk += tl.dot(q, k) - if bias_ptr is not None: - bias = load_fn(bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero") + if bias_ptrs is not None: + bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None + bias = load_fn(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, actual_seqlen_k) # While bias is added after multiplying qk with sm_scale, # our optimization to use 2^x instead of e^x results in an additional # scale factor of log2(e) which we must also multiply the bias with. @@ -249,10 +261,8 @@ def _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, actual_ # Compute the global position of each token within the sequence global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M) global_n_positions = start_n + tl.arange(0, BLOCK_N) - alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions, global_n_positions) - qk += (alibi_block * 1.44269504089) # scale factor of log2(e) # softmax @@ -266,26 +276,26 @@ def _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, actual_ philox_offset = batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, actual_seqlen_k) if RETURN_ENCODED_SOFTMAX: - tl.store(encoded_softmax_block_ptr, tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty)) + tl.store(encoded_sm_ptrs, tl.where(keep, p, -p).to(encoded_sm_ptrs.type.element_ty)) p = tl.where(keep, p, 0.0) elif RETURN_ENCODED_SOFTMAX: - tl.store(encoded_softmax_block_ptr, p.to(encoded_softmax_block_ptr.type.element_ty)) + tl.store(encoded_sm_ptrs, p.to(encoded_sm_ptrs.type.element_ty)) # -- update output accumulator -- alpha = tl.math.exp2(m_i - m_ij) acc = acc * alpha[:, None] if not PRE_LOAD_V: - v = load_fn(V_block_ptr, MASK_STEPS and (n_extra_tokens != 0), PADDED_HEAD, "zero") + v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) # -- update m_i and l_i l_i = l_i * alpha + l_ij # update m_i and l_i m_i = m_ij - acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - if bias_ptr is not None: - bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) + acc += tl.dot(p.to(v.type.element_ty), v) + k_ptrs += BLOCK_N * stride_kn + v_ptrs += BLOCK_N * stride_vk + if bias_ptrs is not None: + bias_ptrs += BLOCK_N * stride_bn if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, BLOCK_N)) + encoded_sm_ptrs += BLOCK_N return acc, l_i, m_i @@ -364,7 +374,7 @@ def attn_fwd( BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, - BIAS_TYPE: tl.constexpr, + USE_BIAS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_ALIBI: tl.constexpr, @@ -375,6 +385,7 @@ def attn_fwd( off_z = tl.program_id(2) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) if VARLEN: cu_seqlens_q_start = 
tl.load(cu_seqlens_q + off_z) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) @@ -412,18 +423,20 @@ def attn_fwd( # If we have no blocks after adjusting for seqlen deltas, this WG is part of # the blocks that are all 0. We exit early. if n_blocks <= 0: - o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh - O_block_ptr = tl.make_block_ptr(base=Out + o_offset, shape=(seqlen_q, BLOCK_DMODEL), - strides=(stride_om, stride_on), offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) + o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) + o_ptrs_mask = offs_m[:, None] < seqlen_q # We still need to write 0s to the result - tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0, 1)) + tl.store(o_ptrs, acc, mask=o_ptrs_mask) + # The tensor allocated for L is based on MAX_SEQLENS_Q as that is + # statically known. l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m # We store inf to LSE, not -inf because in the bwd pass, we subtract this # from qk which makes it -inf, such that exp(qk - inf) = 0 for these masked blocks. l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) - tl.store(l_ptrs, l) + l_ptrs_mask = offs_m < MAX_SEQLENS_Q + tl.store(l_ptrs, l, mask=l_ptrs_mask) # TODO: Should dropout and return encoded softmax be handled here too? return @@ -434,41 +447,26 @@ def attn_fwd( else: off_h_k = off_h_q - # need_padding = False n_extra_tokens = 0 if seqlen_k < BLOCK_N: - # need_padding = True n_extra_tokens = BLOCK_N - seqlen_k elif seqlen_k % BLOCK_N: - # need_padding = True n_extra_tokens = seqlen_k % BLOCK_N PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) # Compute pointers for all the tensors used in this kernel. 
- q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm - Q_block_ptr = tl.make_block_ptr(base=Q + q_offset, shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), - strides=(stride_qm, stride_qk), offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) - k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn - K_block_ptr = tl.make_block_ptr(base=K + k_offset, shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), - strides=(stride_kk, stride_kn), offsets=(0, 0), block_shape=(BLOCK_DMODEL, BLOCK_N), - order=(0, 1)) - v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk - V_block_ptr = tl.make_block_ptr(base=V + v_offset, shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), - strides=(stride_vk, stride_vn), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(1, 0)) - if BIAS_TYPE != 0: - b_offset = off_h_q * stride_bh # Note: this might get large enough to overflow on some configs - bias_ptr = tl.make_block_ptr( - base=bias + b_offset, - shape=(seqlen_q, seqlen_k), - strides=(stride_bm, stride_bn), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_N), - order=(1, 0), - ) + q_offset = Q + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm + q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + k_offset = K + off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn + k_ptrs = k_offset + offs_d[:, None] * stride_kk + offs_n[None, :] * stride_kn + v_offset = V + off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk + v_ptrs = v_offset + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn + if USE_BIAS: + # Note: this might get large enough to overflow on some configs + bias_offset = off_h_q * stride_bh + bias_ptrs = bias + bias_offset + offs_m[:, None] * stride_bm + offs_n[None, :] * stride_bn else: - bias_ptr = None + bias_ptrs = None if USE_ALIBI: a_offset = off_z * stride_az + off_h_q * stride_ah @@ -483,14 +481,11 @@ def attn_fwd( batch_philox_offset = 0 # We can ask to return the dropout mask without actually doing any dropout. In # this case, we return an invalid pointer so indicate the mask is not valid. - # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.make_block_ptr(base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, - shape=(seqlen_q, seqlen_k), strides=(seqlen_k, 1), - offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_N), - order=(1, 0)) + encoded_sm_base = encoded_softmax + off_h_q * seqlen_q * seqlen_k + encoded_sm_ptrs = encoded_sm_base + offs_m[:, None] * seqlen_k + offs_n[None, :] else: - encoded_softmax_block_ptr = 0 + encoded_sm_ptrs = None # initialize pointer to m and l m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) @@ -499,8 +494,11 @@ def attn_fwd( # have native e^x support in HW. qk_scale = sm_scale * 1.44269504089 # Q is loaded once at the beginning and shared by all N blocks. - q = load_fn(Q_block_ptr, True, PADDED_HEAD, "zero") - q = (q * qk_scale).to(Q_block_ptr.type.element_ty) + q_ptrs_mask = offs_m[:, None] < seqlen_q + if PADDED_HEAD: + q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) + q = (q * qk_scale).to(q.type.element_ty) # Here we compute how many full and masked blocks we have. 
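+    # Only blocks that might read past seqlen_k or cross the causal boundary are
+    # treated as "masked"; all remaining blocks take the fast path without bounds checks.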
padded_block_k = n_extra_tokens != 0 @@ -522,14 +520,16 @@ def attn_fwd( # value because there is no masking. Similarly we do not need padding. if n_full_blocks > 0: block_max = (n_blocks - masked_blocks) * BLOCK_N - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, seqlen_k, seqlen_q, - dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr, + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, + start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset, + encoded_sm_ptrs, # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ - block_min, block_max, 0, 0, 0, bias_ptr, alibi_slope, + block_min, block_max, 0, 0, 0, alibi_slope, # IS_CAUSAL, .... False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, # _, MASK_STEPS, ... - PRE_LOAD_V, False, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD) + PRE_LOAD_V, False, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD, + ACTUAL_BLOCK_DMODEL) block_min = block_max block_max = n_blocks * BLOCK_N @@ -540,18 +540,20 @@ def attn_fwd( offs_n_causal = offs_n + (seqlen_q - seqlen_k) else: offs_n_causal = 0 - K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) - V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) - if bias_ptr is not None: - bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) + k_ptrs += n_full_blocks * BLOCK_N * stride_kn + v_ptrs += n_full_blocks * BLOCK_N * stride_vk + if USE_BIAS: + bias_ptrs += n_full_blocks * BLOCK_N * stride_bn if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, n_full_blocks)) - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, seqlen_k, seqlen_q, - dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr, - block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, bias_ptr, - alibi_slope, IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, + encoded_sm_ptrs += n_full_blocks * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, + start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset, + encoded_sm_ptrs, block_min, block_max, offs_n_causal, masked_blocks, + n_extra_tokens, alibi_slope, IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, + offs_n, # _, MASK_STEPS, ... - PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD) + PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD, + ACTUAL_BLOCK_DMODEL) # epilogue acc = acc / l_i[:, None] if ENABLE_DROPOUT: @@ -578,21 +580,20 @@ def attn_fwd( overflow_size = end_m_idx - seqlen_q if overflow_size > 0: boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size, dtype=tl.int32) - # This is a > check because mask being 0 blocks the store. - l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) + l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) else: tl.store(l_ptrs, m_i + tl.math.log2(l_i)) # write back O - o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh - O_block_ptr = tl.make_block_ptr(base=Out + o_offset, shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), - strides=(stride_om, stride_on), offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) - # Need boundary check on this to make sure the padding from the - # Q and KV tensors in both dims are not part of what we store back. 
- # TODO: Do the boundary check optionally. - tl.store(O_block_ptr, acc, boundary_check=(0, 1)) + o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on + o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1) + if overflow_size > 0: + o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q) + if PADDED_HEAD: + o_ptrs_mask = o_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) @triton.jit @@ -941,7 +942,7 @@ def forward(ctx, q, k, v, o, metadata): encoded_softmax=encoded_softmax, alibi_slopes=metadata.alibi_slopes, HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=metadata.max_seqlens_q, MAX_SEQLENS_K=metadata.max_seqlens_k, IS_CAUSAL=metadata.causal, VARLEN=metadata.varlen, - BLOCK_DMODEL=padded_d_model, BIAS_TYPE=0 if metadata.bias is None else 1, + BLOCK_DMODEL=padded_d_model, USE_BIAS=False if metadata.bias is None else True, USE_ALIBI=False if metadata.alibi_slopes is None else True, ENABLE_DROPOUT=metadata.dropout_p > 0.0, RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax, BATCH_SIZE=q.shape[0]) @@ -1065,8 +1066,6 @@ def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype): cu_seqlens_k = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_k.cumsum(dim=0, dtype=torch.int32)]) cu_seqlens_q = cu_seqlens_q.to(device="cuda") cu_seqlens_k = cu_seqlens_k.to(device="cuda") - # -1 because the last entry of cu_seqlens_q specifies the end of the last seq - # num_ctxs = len(cu_seqlens_q) - 1 # Initialize q, k, v with variable lengths total_q = cu_seqlens_q[-1].item() @@ -1114,9 +1113,6 @@ def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, dtype=to else: alibi_slopes = None - if TORCH_HAS_FP8E5: - q = q.to(torch_dtype) - k = k.to(torch_dtype) o = torch.empty_like(q) # triton implementation @@ -1150,11 +1146,11 @@ def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, dtype=to @pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ (4, 48, 1024, 1024, 64), - (4, 24, 8192, 8192, 64), + (4, 12, 8192, 8192, 64), (2, 4, 16384, 16384, 128), (2, 16, 1020, 987, 128), (2, 16, 15498, 2, 128), - (2, 16, 7, 16219, 64), + (2, 4, 7, 16219, 64), (4, 48, 1, 1, 64), (4, 48, 1, 1, 128), (4, 48, 3, 3, 128), @@ -1164,12 +1160,12 @@ def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, dtype=to (4, 4, 1024, 1024, 33), (4, 4, 65, 1019, 65), (4, 4, 128, 128, 65), - (4, 4, 113, 123, 1), + # TODO: This config fails. Disabled until triaged and fixed. 
+ # (4, 4, 113, 123, 1), ]) -@pytest.mark.parametrize('causal', [False, True]) +@pytest.mark.parametrize('causal', [True, False]) @pytest.mark.parametrize('use_bias', [True]) def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=torch.float16): - pytest.skip() torch.manual_seed(20) sm_scale = D_HEAD**-0.5 input_metadata = MetaData(sm_scale=sm_scale) @@ -1185,9 +1181,6 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=tor q = torch.randn((Z, H, N_CTX_Q, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() k = torch.randn((Z, H, N_CTX_K, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() v = torch.randn((Z, H, N_CTX_K, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() - if TORCH_HAS_FP8E5: - q = q.to(torch_dtype) - k = k.to(torch_dtype) o = torch.empty_like(q) # triton implementation @@ -1218,9 +1211,8 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=tor (4, 16, 1024, 128), (4, 16, 8192, 128), (32, 48, 8192, 128)]) @pytest.mark.parametrize('causal', [True, False]) def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): - pytest.skip() + q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, N_CTX, D_HEAD, dtype) - q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, D_HEAD, dtype) tri_out = torch.empty_like(q) ref_out = torch.empty_like(q) @@ -1413,8 +1405,8 @@ def run_benchmark(custom): args = parse_args() dtype = arg_to_torch_dtype[args.dtype] - # hk = args.hq if not args.hk else args.hk - # sk = args.sq if not args.sk else args.sk + hk = args.hq if not args.hk else args.hk + sk = args.sq if not args.sk else args.sk head_size = 128 if not args.d else args.d mode = 'fwd' x_names = ['BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K'] @@ -1422,7 +1414,7 @@ def run_benchmark(custom): varlen = args.varlen configs = [] if custom: - x_vals_list = [(args.b, args.hq, args.hk, args.sq, args.sk)] + x_vals_list = [(args.b, args.hq, hk, args.sq, sk)] else: if varlen: x_vals_list = varlen_benchmark_configs() From aa6685a16dde93b0c559f16f39cf0cf2994c27a9 Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Mon, 20 May 2024 14:57:21 -0500 Subject: [PATCH 04/12] Add support for bshd layout (#587) Add support for layouts commonly used by users. Add option for varlen / thd layout to specify equal context lengths for all batches. Also often used by users. --- python/perf-kernels/flash-attention.py | 216 +++++++++++++------------ 1 file changed, 114 insertions(+), 102 deletions(-) diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py index 42e9ac310195..d36caaf61952 100644 --- a/python/perf-kernels/flash-attention.py +++ b/python/perf-kernels/flash-attention.py @@ -28,8 +28,6 @@ import triton import triton.language as tl -torch_dtype: tl.constexpr = torch.float16 - class MetaData(): cu_seqlens_q = None @@ -41,6 +39,7 @@ class MetaData(): causal = False num_contexts = 0 varlen = False + layout = None dropout_p, return_encoded_softmax = 0.0, False def __init__(self, sm_scale=1.0): @@ -48,6 +47,7 @@ def __init__(self, sm_scale=1.0): def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k): self.varlen = True + self.layout = 'thd' self.cu_seqlens_q = cu_seqlens_q self.cu_seqlens_k = cu_seqlens_k # Without "varlen", there should still be one sequence. 
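For illustration, a minimal sketch of exercising the new 'bshd' layout path introduced by this commit, assuming the MetaData class and the attention entry point defined in flash-attention.py; the helper name run_bshd_example and its default sizes are hypothetical, not part of the patch:

import torch

def run_bshd_example(Z=2, H=4, N_CTX=128, D_HEAD=64, dtype=torch.float16):
    # bshd: Q, K, V are [batch, seqlen, num_heads, head_size]
    q = torch.randn((Z, N_CTX, H, D_HEAD), dtype=dtype, device="cuda")
    k = torch.randn((Z, N_CTX, H, D_HEAD), dtype=dtype, device="cuda")
    v = torch.randn((Z, N_CTX, H, D_HEAD), dtype=dtype, device="cuda")
    meta = MetaData(sm_scale=D_HEAD**-0.5)
    meta.max_seqlens_q = N_CTX
    meta.max_seqlens_k = N_CTX
    meta.layout = 'bshd'          # strides are derived from the layout, no transpose needed
    o = torch.empty_like(q)
    out, _ = attention(q, k, v, o, meta)
    return out

This mirrors what input_helper does for the tests: only the metadata layout field changes between 'bhsd' and 'bshd'; the kernel picks up the right strides via get_strides_from_layout.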
@@ -81,10 +81,10 @@ def need_dropout(self, dropout_p, return_encoded_softmax): def check_args(self, q, k, v, o): assert q.dim() == k.dim() and q.dim() == v.dim() + + batch, nheads_q, nheads_k, head_size = get_shape_from_layout(q, k, self) if self.varlen: assert q.dim() == 3 - total_q, nheads_q, head_size = q.shape - total_k, nheads_k, _ = k.shape assert self.cu_seqlens_q is not None assert self.cu_seqlens_k is not None assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k) @@ -95,8 +95,6 @@ def check_args(self, q, k, v, o): assert not self.return_encoded_softmax else: assert q.dim() == 4 - batch, nheads_q, seqlen_q, head_size = q.shape - _, nheads_k, seqlen_k, _ = k.shape assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0 assert self.cu_seqlens_q is None and self.cu_seqlens_k is None assert k.shape == v.shape @@ -106,6 +104,8 @@ def check_args(self, q, k, v, o): assert head_size <= 256 assert o.shape == q.shape assert (nheads_q % nheads_k) == 0 + assert self.layout is not None + assert self.layout == 'thd' or not self.varlen @triton.jit @@ -326,60 +326,14 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri use_cuda_graph=True, ) @triton.jit -def attn_fwd( - Q, - K, - V, - bias, - sm_scale, - L, - Out, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vk, - stride_vn, - stride_oz, - stride_oh, - stride_om, - stride_on, - stride_bz, - stride_bh, - stride_bm, - stride_bn, - stride_az, - stride_ah, - cu_seqlens_q, - cu_seqlens_k, - dropout_p, - philox_seed, - philox_offset_base, - encoded_softmax, - alibi_slopes, - HQ: tl.constexpr, - HK: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr, - MAX_SEQLENS_Q: tl.constexpr, - MAX_SEQLENS_K: tl.constexpr, - VARLEN: tl.constexpr, - IS_CAUSAL: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - PRE_LOAD_V: tl.constexpr, - USE_BIAS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - RETURN_ENCODED_SOFTMAX: tl.constexpr, - USE_ALIBI: tl.constexpr, - BATCH_SIZE: tl.constexpr, -): +def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, + stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, + stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, cu_seqlens_q, cu_seqlens_k, + dropout_p, philox_seed, philox_offset_base, encoded_softmax, alibi_slopes, HQ: tl.constexpr, + HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_ALIBI: tl.constexpr): start_m = tl.program_id(0) off_h_q = tl.program_id(1) off_z = tl.program_id(2) @@ -876,6 +830,44 @@ def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D, empty = torch.empty(128, device="cuda") +def get_shape_from_layout(q, k, metadata): + if metadata.layout == 'thd': + nheads_q, nheads_k = q.shape[1], k.shape[1] + head_size = q.shape[-1] + batch = metadata.num_contexts + elif metadata.layout == 'bhsd': + batch, nheads_q, _, head_size = q.shape + nheads_k = k.shape[1] + elif metadata.layout == 'bshd': + batch, _, nheads_q, head_size = q.shape + nheads_k = k.shape[2] + else: + assert False, "Got unsupported 
layout." + return batch, nheads_q, nheads_k, head_size + + +# TODO: This can probably optimized to have fewer lines of code. +def get_strides_from_layout(q, k, v, o, metadata): + if metadata.layout == 'thd': + q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) + k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) + v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) + o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) + elif metadata.layout == 'bhsd': + q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3)) + k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3)) + v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3)) + o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3)) + elif metadata.layout == 'bshd': + q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) + k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) + v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) + o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + else: + assert False, 'Got unsupported layout.' + return q_strides, k_strides, v_strides, o_strides + + class _attention(torch.autograd.Function): @staticmethod @@ -887,24 +879,14 @@ def forward(ctx, q, k, v, o, metadata): if o is None: o = torch.empty_like(q, dtype=v.dtype) metadata.check_args(q, k, v, o) - if metadata.varlen: - total_q, nheads_q, head_size = q.shape - total_k, nheads_k, _ = k.shape - batch = metadata.num_contexts - q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) - k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) - v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) - o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) - else: - batch, nheads_q, seqlen_q, head_size = q.shape - _, nheads_k, seqlen_k, _ = k.shape - q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3)) - k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3)) - v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3)) - o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3)) + + batch, nheads_q, nheads_k, head_size = get_shape_from_layout(q, k, metadata) + q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, metadata) # Get closest power of 2 over or equal to 32. padded_d_model = 1 << (head_size - 1).bit_length() + # Smallest head_dim supported is 16. If smaller, the tile in the + # kernel is padded - there is no padding in memory for any dims. 
padded_d_model = max(padded_d_model, 16) grid = lambda META: (triton.cdiv(metadata.max_seqlens_q, META['BLOCK_M']), nheads_q, batch) @@ -944,7 +926,7 @@ def forward(ctx, q, k, v, o, metadata): MAX_SEQLENS_K=metadata.max_seqlens_k, IS_CAUSAL=metadata.causal, VARLEN=metadata.varlen, BLOCK_DMODEL=padded_d_model, USE_BIAS=False if metadata.bias is None else True, USE_ALIBI=False if metadata.alibi_slopes is None else True, ENABLE_DROPOUT=metadata.dropout_p - > 0.0, RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax, BATCH_SIZE=q.shape[0]) + > 0.0, RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax) ctx.save_for_backward(q, k, v, o, M) ctx.grid = grid @@ -1036,30 +1018,41 @@ def backward(ctx, do, _): attention = _attention.apply -def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype): +def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout): torch.manual_seed(20) # Initialize q, k, v - q = torch.randn((Z, HQ, N_CTX_Q, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - k = torch.randn((Z, HK, N_CTX_K, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - v = torch.randn((Z, HK, N_CTX_K, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + if layout == 'bhsd': + q_tensor_shape = (Z, HQ, N_CTX_Q, D_HEAD) + k_tensor_shape = (Z, HK, N_CTX_K, D_HEAD) + elif layout == 'bshd': + q_tensor_shape = (Z, N_CTX_Q, HQ, D_HEAD) + k_tensor_shape = (Z, N_CTX_K, HK, D_HEAD) + else: + assert False, 'Got unsupported tensor layout' + q = torch.randn(q_tensor_shape, dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn(k_tensor_shape, dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn(k_tensor_shape, dtype=dtype, device="cuda", requires_grad=True) sm_scale = D_HEAD**-0.5 input_metadata = MetaData(sm_scale=sm_scale) input_metadata.max_seqlens_q = N_CTX_Q input_metadata.max_seqlens_k = N_CTX_K + input_metadata.layout = layout return q, k, v, input_metadata -def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype): +def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, equal_seqlens=False): torch.manual_seed(20) # Random sequence lengths. 
Using N_CTX as kind of max of sum of individual seqs - max_seqlens_q = N_CTX_Q // Z - max_seqlens_k = N_CTX_K // Z - seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z, ), dtype=torch.int32) - seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z, ), dtype=torch.int32) - max_seqlens_q = torch.max(seqlens_q).item() - max_seqlens_k = torch.max(seqlens_k).item() + if not equal_seqlens: + max_seqlens_q = N_CTX_Q // Z + max_seqlens_k = N_CTX_K // Z + seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z, ), dtype=torch.int32) + seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z, ), dtype=torch.int32) + else: + seqlens_q = torch.full((Z, ), N_CTX_Q // Z) + seqlens_k = torch.full((Z, ), N_CTX_K // Z) # Calculate cumulative sequence lengths cu_seqlens_q = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_q.cumsum(dim=0, dtype=torch.int32)]) @@ -1099,9 +1092,10 @@ def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype): ]) @pytest.mark.parametrize('causal', [True, False]) @pytest.mark.parametrize('use_alibi', [True, False]) -def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, dtype=torch.float16): +@pytest.mark.parametrize('layout', ['bshd', 'bhsd']) +def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout, dtype=torch.float16): torch.manual_seed(20) - q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype) + q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout) if causal: input_metadata.need_causal() @@ -1118,6 +1112,11 @@ def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, dtype=to # triton implementation tri_out, _ = attention(q, k, v, o, input_metadata) + # Transpose here if layout is bshd so we have same reference code for all layouts + if layout == 'bshd': + q = q.transpose(1, 2).clone() + k = k.transpose(1, 2).clone() + v = v.transpose(1, 2).clone() # Replicate K and V if using MQA/GQA if HQ != HK: k = k.view(k.shape[0], k.shape[1], -1, k.shape[2], @@ -1141,6 +1140,8 @@ def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, dtype=to p[nan_mask == 1] = 0 ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v) # compare + if layout == 'bshd': + ref_out = ref_out.transpose(1, 2).clone() torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) @@ -1169,8 +1170,7 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=tor torch.manual_seed(20) sm_scale = D_HEAD**-0.5 input_metadata = MetaData(sm_scale=sm_scale) - input_metadata.max_seqlens_q = N_CTX_Q - input_metadata.max_seqlens_k = N_CTX_K + q, k, v, input_metadata = input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout='bhsd') if causal: input_metadata.need_causal() if use_bias: @@ -1178,9 +1178,6 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=tor input_metadata.need_bias(bias, Z, H, N_CTX_Q, N_CTX_K) else: bias = None - q = torch.randn((Z, H, N_CTX_Q, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() - k = torch.randn((Z, H, N_CTX_K, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() - v = torch.randn((Z, H, N_CTX_K, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() o = torch.empty_like(q) # triton implementation @@ -1211,6 +1208,7 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=tor (4, 16, 1024, 128), (4, 16, 8192, 128), (32, 48, 8192, 128)]) @pytest.mark.parametrize('causal', [True, False]) def 
test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): + q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, N_CTX, D_HEAD, dtype) tri_out = torch.empty_like(q) @@ -1401,9 +1399,8 @@ def varlen_benchmark_configs(): return configs -def run_benchmark(custom): +def run_benchmark(custom, args): - args = parse_args() dtype = arg_to_torch_dtype[args.dtype] hk = args.hq if not args.hk else args.hk sk = args.sq if not args.sk else args.sk @@ -1411,7 +1408,7 @@ def run_benchmark(custom): mode = 'fwd' x_names = ['BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K'] causal = args.causal - varlen = args.varlen + varlen = args.layout == 'thd' configs = [] if custom: x_vals_list = [(args.b, args.hq, hk, args.sq, sk)] @@ -1425,7 +1422,7 @@ def run_benchmark(custom): configs.append( triton.testing.Benchmark(x_names=x_names, x_vals=x_vals_list, line_arg='provider', line_vals=['triton'], line_names=[line_names], styles=[('red', '-')], ylabel='ms', - plot_name=f'fused-attention-{mode}-d{head_size}{"-varlen" if varlen else ""}', + plot_name=f'fused-attention-{mode}-d{head_size}-layout{args.layout}', args={'D_HEAD': head_size, 'dtype': dtype, 'causal': causal, 'mode': mode})) @triton.testing.perf_report(configs) @@ -1447,14 +1444,15 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal flops_per_matmul = 0 if varlen: - q, k, v, input_metadata = varlen_input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype) + q, k, v, input_metadata = varlen_input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, + args.equal_seqlens) for i in range(0, input_metadata.num_contexts): seqlen_q = input_metadata.cu_seqlens_q[i + 1] - input_metadata.cu_seqlens_q[i] seqlen_k = input_metadata.cu_seqlens_k[i + 1] - input_metadata.cu_seqlens_k[i] # x2 for 2 GEMMs flops_per_matmul += seqlen_q.item() * seqlen_k.item() * HQ * D_HEAD * 2 else: - q, k, v, input_metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype) + q, k, v, input_metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, args.layout) flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD if causal: input_metadata.need_causal() @@ -1479,6 +1477,15 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal bench_flash_attention.run(save_path=".", print_data=True) +def supported_layouts(): + layouts = \ + 'bhsd: Q, K, V are individual tensors of [batch, num_heads, seqlen_q/k, head_size]' \ + 'bshd: Q, K, V are individual tensors of [batch, seqlen_q/k, num_heads, head_size]' \ + 'thd: Q, K, V are individual tensors of [total_q/k, num_heads, head_size]' \ + 'This layout is sometimes called "varlen" or "grouped" layout.' 
+ return layouts + + def parse_args(): parser = argparse.ArgumentParser( prog="Benchmark FlashAttention", @@ -1489,11 +1496,14 @@ def parse_args(): parser.add_argument("-hk", type=int, default=0) parser.add_argument("-sq", type=int, default=0) parser.add_argument("-sk", type=int, default=0) + parser.add_argument("-equal_seqlens", action='store_true', default=False, + help='If specified, each context within the thd layout' \ + ' has same seqlen as sq and sk') parser.add_argument("-d", type=int, default=0) parser.add_argument("-causal", action='store_true', default=False) - parser.add_argument("-varlen", action='store_true', default=False) parser.add_argument("-dtype", default='fp16') parser.add_argument("-return_time", action='store_true', default=False) + parser.add_argument("-layout", type=str, default='bhsd', help=supported_layouts()) return parser.parse_args() @@ -1503,6 +1513,8 @@ def parse_args(): def main(): args = parse_args() custom_config = False + assert args.layout == 'thd' or not args.equal_seqlens, \ + "Equal sequence lengths arg must be used with the thd layout." if args.b or args.hq or args.hk or args.sq or args.sk or args.d: custom_config = True assert args.b and args.hq and args.sq and args.d, \ @@ -1513,7 +1525,7 @@ def main(): assert args.dtype in arg_to_torch_dtype, \ "Only fp16, bf16 and f32 types currently supported." - run_benchmark(custom_config) + run_benchmark(custom_config, args) if __name__ == '__main__': From dbe11738b9de976b4423db46faa94385a900ae6e Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Tue, 16 Jul 2024 19:22:27 -0400 Subject: [PATCH 05/12] Post-Merge CI (#612) * remove on push for Integration Tests * rename * add post merge test * save * dtype params * skip bad config * fix more stuff --- ... => amd_perf_kernel_Integration_tests.yml} | 8 +- .../amd_perf_kernel_postmerge_tests.yml | 92 +++++++++++++++++++ python/perf-kernels/flash-attention.py | 11 ++- 3 files changed, 101 insertions(+), 10 deletions(-) rename .github/workflows/{amd_perf_kernel_tests.yml => amd_perf_kernel_Integration_tests.yml} (95%) create mode 100644 .github/workflows/amd_perf_kernel_postmerge_tests.yml diff --git a/.github/workflows/amd_perf_kernel_tests.yml b/.github/workflows/amd_perf_kernel_Integration_tests.yml similarity index 95% rename from .github/workflows/amd_perf_kernel_tests.yml rename to .github/workflows/amd_perf_kernel_Integration_tests.yml index 07424924a832..a8a8b3d50b9e 100644 --- a/.github/workflows/amd_perf_kernel_tests.yml +++ b/.github/workflows/amd_perf_kernel_Integration_tests.yml @@ -1,4 +1,4 @@ -name: AMD Perf Kernel Tests +name: AMD Perf Kernel Integration Tests on: workflow_dispatch: @@ -7,8 +7,6 @@ on: merge_group: branches: [main_perf] types: [checks_requested] - push: - branches: [main_perf] concurrency: group: ${{ github.ref }} @@ -36,8 +34,8 @@ jobs: changed_files=$(git diff --name-only origin/${{ github.base_ref }} ${{ github.sha }}) echo "Changed files:" echo "$changed_files" - if echo "$changed_files" | grep -v "^python/perf-kernels/"; then - echo "Changes detected outside of the python/perf-kernels directory. Failing the workflow." + if echo "$changed_files" | grep -vE "^python/perf-kernels/|^\.github/workflows/amd_"; then + echo "Changes detected outside of the python/perf-kernels directory or .github/workflows/amd_ files. Failing the workflow." 
exit 1 fi diff --git a/.github/workflows/amd_perf_kernel_postmerge_tests.yml b/.github/workflows/amd_perf_kernel_postmerge_tests.yml new file mode 100644 index 000000000000..40f211118541 --- /dev/null +++ b/.github/workflows/amd_perf_kernel_postmerge_tests.yml @@ -0,0 +1,92 @@ +name: AMD Perf Kernel Post-Merge Tests + +on: + workflow_dispatch: + push: + branches: [main_perf, micmelesse/post_merge_ci] + +concurrency: + group: ${{ github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main_perf' }} + +permissions: read-all + +env: + TRITON_BUILD_WITH_CLANG_LLD: "TRUE" + TRITON_USE_ASSERT_ENABLED_LLVM: "TRUE" + TRITON_DISABLE_LINE_INFO: 1 + +jobs: + Runner-Preparation-AMD: + runs-on: ubuntu-latest + timeout-minutes: 30 + outputs: + matrix-HIP: ${{ steps.set-matrix.outputs.matrix-HIP }} + steps: + - name: Prepare runner matrix + id: set-matrix + run: | + if [ x"${{ github.repository }}" == x"ROCm/triton" ]; then + echo '::set-output name=matrix-HIP::[["self-hosted", "rocm.gfx90a"]]' + else + echo '::set-output name=matrix-HIP::[["ubuntu-latest"]]' + fi + + PostMerge-Tests-AMD: + needs: Runner-Preparation-AMD + if: needs.Runner-Preparation-AMD.outputs.matrix-HIP != '' + runs-on: ${{ matrix.runner }} + timeout-minutes: 30 + strategy: + matrix: + runner: ${{fromJson(needs.Runner-Preparation-AMD.outputs.matrix-HIP)}} + container: + image: rocm/pytorch:rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2 + options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 # Ensure the entire history is fetched for rebase + - name: Add upstream remote + run: | + git config --global --add safe.directory /__w/triton/triton + if [ $(git remote | grep -c upstream) -eq 0 ]; then + git remote add upstream https://github.com/triton-lang/triton.git + fi + git fetch upstream + - name: Rebase onto upstream/main + run: | + git config --global user.email "ci@amd.com" + git config --global user.name "Github Actions Post-Merge CI Script" + git rebase upstream/main || { echo "Rebase failed"; exit 1; } + - name: Show Git Log + run: | + echo "Git log after rebase from upstream/main to HEAD:" + git log $(git rev-parse upstream/main~2)..HEAD --oneline --graph --decorate + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Clear cache + run: | + rm -rf ~/.triton + mkdir -p ~/.triton + ls -alh ~/.triton + - name: Update PATH + run: | + echo "/opt/rocm/llvm/bin" >> $GITHUB_PATH + - name: Install pip dependencies + run: | + python3 -m pip install --upgrade pip + python3 -m pip install lit matplotlib pandas + - name: Install Triton + run: | + echo "PATH is '$PATH'" + pip uninstall -y triton + cd python + pip install -v -e . 
+ - name: Run Perf Kernels Unit Tests + run: | + pytest -vvv ./python/perf-kernels/flash-attention.py + - name: Run Perf Kernels Benchmark + run: | + python ./python/perf-kernels/flash-attention.py diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py index d36caaf61952..8177cf4ebf30 100644 --- a/python/perf-kernels/flash-attention.py +++ b/python/perf-kernels/flash-attention.py @@ -309,8 +309,8 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri num_warps=8), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': True}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), + # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, + # num_warps=4), triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, @@ -1166,7 +1166,8 @@ def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout, ]) @pytest.mark.parametrize('causal', [True, False]) @pytest.mark.parametrize('use_bias', [True]) -def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=torch.float16): +@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) +def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype): torch.manual_seed(20) sm_scale = D_HEAD**-0.5 input_metadata = MetaData(sm_scale=sm_scale) @@ -1174,7 +1175,7 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=tor if causal: input_metadata.need_causal() if use_bias: - bias = torch.randn((1, H, N_CTX_Q, N_CTX_K), dtype=torch.float32, device="cuda") + bias = torch.randn((1, H, N_CTX_Q, N_CTX_K), dtype=dtype, device="cuda") input_metadata.need_bias(bias, Z, H, N_CTX_Q, N_CTX_K) else: bias = None @@ -1197,7 +1198,7 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=tor # this by converting the NaNs to 0s, which is what they should be out of the softmax. 
nan_mask = torch.isnan(p) p[nan_mask == 1] = 0 - ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v) + ref_out = torch.einsum('bhqk,bhkd->bhqd', p.to(dtype), v) # compare torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) From 23ba5467d83db1d8ca36f8ee34ff287ae089469c Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Thu, 18 Jul 2024 17:04:16 -0500 Subject: [PATCH 06/12] Increase CI timeout (#615) Increase CI timeout --- .github/workflows/amd_perf_kernel_Integration_tests.yml | 2 +- .github/workflows/amd_perf_kernel_postmerge_tests.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/amd_perf_kernel_Integration_tests.yml b/.github/workflows/amd_perf_kernel_Integration_tests.yml index a8a8b3d50b9e..956ff8903115 100644 --- a/.github/workflows/amd_perf_kernel_Integration_tests.yml +++ b/.github/workflows/amd_perf_kernel_Integration_tests.yml @@ -95,7 +95,7 @@ jobs: needs: Runner-Preparation-AMD if: needs.Runner-Preparation-AMD.outputs.matrix-HIP != '' runs-on: ${{ matrix.runner }} - timeout-minutes: 30 + timeout-minutes: 90 strategy: matrix: runner: ${{fromJson(needs.Runner-Preparation-AMD.outputs.matrix-HIP)}} diff --git a/.github/workflows/amd_perf_kernel_postmerge_tests.yml b/.github/workflows/amd_perf_kernel_postmerge_tests.yml index 40f211118541..21470c094e46 100644 --- a/.github/workflows/amd_perf_kernel_postmerge_tests.yml +++ b/.github/workflows/amd_perf_kernel_postmerge_tests.yml @@ -36,7 +36,7 @@ jobs: needs: Runner-Preparation-AMD if: needs.Runner-Preparation-AMD.outputs.matrix-HIP != '' runs-on: ${{ matrix.runner }} - timeout-minutes: 30 + timeout-minutes: 90 strategy: matrix: runner: ${{fromJson(needs.Runner-Preparation-AMD.outputs.matrix-HIP)}} From df4c4d3a7fa7a1329626972b36b9b5d8a84c75f2 Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Fri, 19 Jul 2024 17:50:49 -0500 Subject: [PATCH 07/12] Couple of FA optimizations (#608) Couple of FA optimizations Set SM scale multiplication to a constexpr. Minor asm improvement. Changed acc scaling to adjust for softmax division to multiplication with reciprocal. ~10% perf improvement. 
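The acc-scaling change is easiest to see outside the kernel. A minimal PyTorch sketch (plain tensors, not the Triton code itself) of the old divide-by-l_i epilogue versus the new multiply-by-reciprocal form, assuming shapes that mirror the kernel's acc tile and l_i row sums:

import torch

BLOCK_M, BLOCK_DMODEL = 128, 128
acc = torch.randn(BLOCK_M, BLOCK_DMODEL, dtype=torch.float32)   # attention accumulator tile
l_i = torch.rand(BLOCK_M, dtype=torch.float32) + 1.0            # per-row softmax normalizers

out_div = acc / l_i[:, None]       # old epilogue: divide the full (BLOCK_M, BLOCK_DMODEL) tile
l_recip = 1.0 / l_i[:, None]       # new epilogue: one reciprocal per row of l_i ...
out_mul = acc * l_recip            # ... then a cheap elementwise multiply across the tile
assert torch.allclose(out_div, out_mul)

The point of the change is that the expensive reciprocal (Newton-Raphson on the GPU) is evaluated once per row of l_i rather than once per element of the much larger acc tile; the multiply that remains per element is cheap.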
--------- Co-authored-by: Michael Melesse --- python/perf-kernels/flash-attention.py | 39 ++++++++++++-------------- 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py index 8177cf4ebf30..988438340abe 100644 --- a/python/perf-kernels/flash-attention.py +++ b/python/perf-kernels/flash-attention.py @@ -301,35 +301,28 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri @triton.autotune( configs=[ - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=8), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=8), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': True}, num_stages=1, + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, - # num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=8), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=8), - # TODO: This config fails with head_size not pow2 with data mismatches. Check why. - # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + # Fall-back config. 
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), ], - key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], + key=['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'], use_cuda_graph=True, ) @triton.jit -def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, - stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, - stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, cu_seqlens_q, cu_seqlens_k, - dropout_p, philox_seed, philox_offset_base, encoded_softmax, alibi_slopes, HQ: tl.constexpr, +def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, + stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, + stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, cu_seqlens_q, + cu_seqlens_k, dropout_p, philox_seed, philox_offset_base, encoded_softmax, alibi_slopes, HQ: tl.constexpr, HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, @@ -446,13 +439,13 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) # scale sm_scale by log_2(e) and use 2^x in the loop as we do not # have native e^x support in HW. - qk_scale = sm_scale * 1.44269504089 + QK_SCALE: tl.constexpr = SM_SCALE * 1.44269504089 # Q is loaded once at the beginning and shared by all N blocks. q_ptrs_mask = offs_m[:, None] < seqlen_q if PADDED_HEAD: q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) - q = (q * qk_scale).to(q.type.element_ty) + q = (q * QK_SCALE).to(q.type.element_ty) # Here we compute how many full and masked blocks we have. padded_block_k = n_extra_tokens != 0 @@ -509,7 +502,10 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD, ACTUAL_BLOCK_DMODEL) # epilogue - acc = acc / l_i[:, None] + # This helps the compiler do Newton Raphson on l_i vs on acc which is much larger. + l_recip = 1 / l_i[:, None] + acc = acc * l_recip + if ENABLE_DROPOUT: acc = acc / (1 - dropout_p) # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, @@ -1198,6 +1194,7 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype): # this by converting the NaNs to 0s, which is what they should be out of the softmax. 
nan_mask = torch.isnan(p) p[nan_mask == 1] = 0 + ref_out = torch.einsum('bhqk,bhkd->bhqd', p.to(dtype), v) # compare torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) From 52a908fd512e1eb0790d2c48eae2b70e982e751d Mon Sep 17 00:00:00 2001 From: xiaohuguo2023 <149615094+xiaohuguo2023@users.noreply.github.com> Date: Wed, 31 Jul 2024 20:33:02 +0100 Subject: [PATCH 08/12] streamk v0.1 (#619) * streamk v0.1 * remove unused variable * fix format issues * add README * fix format issue * change num_sms to num_cus --- .../03-matrix-multiplication-stream-k.py | 395 -------- ...trix-multiplication-stream-k-oldversion.py | 485 ---------- ...iplication-stream-k-singlekern-autotune.py | 563 ------------ ...ultiplication-stream-k-singleloop-nomod.py | 387 -------- python/perf-kernels/streamk/README.md | 43 + python/perf-kernels/streamk/streamk_kernel.py | 206 +++++ python/perf-kernels/streamk/tune_streamk.py | 847 ++++++++++++++++++ 7 files changed, 1096 insertions(+), 1830 deletions(-) delete mode 100755 python/perf-kernels/03-matrix-multiplication-stream-k.py delete mode 100644 python/perf-kernels/streamk/03-matrix-multiplication-stream-k-oldversion.py delete mode 100644 python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singlekern-autotune.py delete mode 100644 python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singleloop-nomod.py create mode 100644 python/perf-kernels/streamk/README.md create mode 100644 python/perf-kernels/streamk/streamk_kernel.py create mode 100644 python/perf-kernels/streamk/tune_streamk.py diff --git a/python/perf-kernels/03-matrix-multiplication-stream-k.py b/python/perf-kernels/03-matrix-multiplication-stream-k.py deleted file mode 100755 index 62d820719b9a..000000000000 --- a/python/perf-kernels/03-matrix-multiplication-stream-k.py +++ /dev/null @@ -1,395 +0,0 @@ -#!/usr/bin/env python -## matmul stream-k implementation -## Credit goes to @pommedeterresautee -## See https://github.com/openai/triton/issues/1393 - -# (echo 'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"') | sudo tee -a /etc/modprobe.d/RestrictedProfiling.conf >/dev/null -# sudo update-initramfs -u -k all -# cat /proc/driver/nvidia/params | grep RmProfilingAdminOnly -# sudo apt-get install zlib1g-dev -# for reproductible experiments -# sudo nvidia-smi -pm 1 -i 0 -# sudo nvidia-smi -i 0 -pl 350 # 400 for A100 -# sudo nvidia-smi -i 0 -lgc 1005 -from typing import Optional - -import torch -import triton -import triton.language as tl -import random - -#from triton.runtime.driver import CudaUtils -import json - -torch.manual_seed(123) -random.seed(123) - -#device = torch.cuda.current_device() -#cuda_utils = CudaUtils() -#total_sm = cuda_utils.get_device_properties(device)["multiprocessor_count"] -#total_sm = 110 # for MI250 -total_sm = 304 # for MI300X -print(f"total SMs: {total_sm}") - -# --------------------------------------------------------------------------- -# Triton kernels -# --------------------------------------------------------------------------- - - -@triton.jit() -def swizzle_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr): - grid_m = tl.cdiv(M, BLOCK_M) - grid_n = tl.cdiv(N, BLOCK_N) - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = tile_id // width - group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (tile_id % group_size) - pid_n = (tile_id % width) // group_size - return pid_m, pid_n - - -@triton.jit() 
-def linear_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr): - pid_m = tile_id // tl.cdiv(N, BLOCK_N) - pid_n = tile_id % tl.cdiv(N, BLOCK_N) - return pid_m, pid_n - - -@triton.jit() -def streamk_gemm( - A, - B, - C, - M, - N, - K, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - total_full_tiles_streamk, - total_partial_tiles_streamk, - iters_per_tile, - total_tiles_streamk, - total_programs_streamk, - ACC_TYPE: tl.constexpr, - GROUP_M: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, -): - pid = tl.program_id(0) - - # Determine whether we are in the first wave or full_tiles phase based on pid - is_first_wave = pid < total_programs_streamk and total_programs_streamk > 0 - - # Calculate starting and ending iterations for first wave - if not is_first_wave: - tile_id = tl.program_id(0) + total_tiles_streamk - total_programs_streamk - if GROUP_M > 0: - pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - else: - pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - - # do matrix multiplication - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - rk = tl.arange(0, BLOCK_K) - # pointers - A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak - B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for k in range(0, tl.cdiv(K, BLOCK_K)): - a = tl.load(A_BASE) - b = tl.load(B_BASE) - acc += tl.dot(a, b) - A_BASE += BLOCK_K * stride_ak - B_BASE += BLOCK_K * stride_bk - # acc = acc.to(tl.float16) # restore C.dtype.element_ty - # rematerialize rm and rn to save registers -# rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) -# rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn - tl.store(C_, acc) - else: - # start_iter = pid * total_full_tiles_streamk + tl.minimum(pid, total_partial_tiles_streamk) - start_iter = pid * total_full_tiles_streamk + tl.minimum(pid, total_partial_tiles_streamk) - last_iter = (pid + 1) * total_full_tiles_streamk + tl.minimum(pid + 1, total_partial_tiles_streamk) - while start_iter < last_iter: - remainder = start_iter % iters_per_tile - end_iter = tl.minimum(start_iter + (iters_per_tile - remainder), last_iter) - # where are we in the grid - tile_id = start_iter // iters_per_tile - if GROUP_M > 0: - pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - else: - pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - rk = tl.arange(0, BLOCK_K) - A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_K * stride_ak * remainder - B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_K * stride_bk * remainder - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for current_iter in range(start_iter, end_iter): - a = tl.load(A_BASE) - b = tl.load(B_BASE) - acc += tl.dot(a, b) - A_BASE += BLOCK_K * stride_ak - B_BASE += BLOCK_K * stride_bk - - if remainder == 0 and end_iter % iters_per_tile == 0: - C_ = C + rm[:, - None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! 
- tl.store(C_, acc) - else: - C_ = C + rm[:, - None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! - tl.atomic_add(C_, acc) - - start_iter = end_iter - - -# --------------------------------------------------------------------------- -# Wrapper -# --------------------------------------------------------------------------- - - -class matmul(torch.autograd.Function): - - _debug = True - - @staticmethod - def set_debug(debug: bool): - matmul._debug = debug - - @staticmethod - def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M: int, BLK_N: int, BLK_K: int, - two_tiles: bool, num_stages: int, num_warps: int, waves_per_eu: int, mfmaInstrSize: int, kpack: int): - device = a.device - - assert a.is_contiguous() and b.is_contiguous(), "non-contiguous inputs are not supported" - # checks constraints - assert a.shape[1] == b.shape[0], "incompatible dimensions" - M, K = a.shape - _, N = b.shape - # accumulator types - ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - # compute grid (work to do per SM on the first wave) - total_blocks_M = triton.cdiv(M, BLK_M) - total_blocks_N = triton.cdiv(N, BLK_N) - iters_per_tile = triton.cdiv(K, BLK_K) - GROUP_M = 4 # 0 to disable swizzling - total_tiles = total_blocks_M * total_blocks_N - - if total_programs_streamk > 0: # Stream-K - # last wave may occupy less than total_programs_streamk SMs - total_tiles_streamk = total_tiles % total_programs_streamk - # for two-tile Stream-K + data-parallel from original paper - if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: - total_tiles_streamk += total_programs_streamk - # remaining tiles are computed using classical blocking - total_blocking_tiles = total_tiles - total_tiles_streamk - total_iters_streamk = total_tiles_streamk * iters_per_tile - # iterations related to full waves - total_full_tiles_streamk = total_iters_streamk // total_programs_streamk - # iterations related to last (partial) wave - total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk - - else: # all tiles are computed using classical blocking - total_blocking_tiles = total_tiles - total_tiles_streamk = 0 - total_full_tiles_streamk = 0 - total_partial_tiles_streamk = 0 - total_iters_streamk = 0 - - if matmul._debug: - print(f"M,N,K={M},{N},{K} ; BLK_M,N,K={BLK_M},{BLK_N},{BLK_K}") - print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}") - print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}") - print(f"{total_programs_streamk=}") - print(f"{total_blocking_tiles=}") - print(f"{total_full_tiles_streamk=}") - print(f"{total_partial_tiles_streamk=}") - print(f"{iters_per_tile=}") - print(f"{total_iters_streamk=}") - - # allocates output - c = torch.zeros((M, N), device=device, dtype=a.dtype) - # allocates locks to sync work accross SMs - grids = total_programs_streamk + total_blocking_tiles - kk = streamk_gemm[(grids, )]( - a, - b, - c, - M, - N, - K, - a.stride(0), - a.stride(1), - b.stride(0), - b.stride(1), - c.stride(0), - c.stride(1), - total_full_tiles_streamk=total_full_tiles_streamk, - total_partial_tiles_streamk=total_partial_tiles_streamk, - iters_per_tile=iters_per_tile, - total_tiles_streamk=total_tiles_streamk, - total_programs_streamk=total_programs_streamk, - ACC_TYPE=ACC_TYPE, - GROUP_M=GROUP_M, - BLOCK_M=BLK_M, - BLOCK_N=BLK_N, - BLOCK_K=BLK_K, - num_stages=num_stages, - num_warps=num_warps, - waves_per_eu=waves_per_eu, - 
matrix_instr_nonkdim=mfmaInstrSize, - kpack=kpack, - ) - if matmul._debug: - print(f"{kk.n_regs} registers used, {kk.n_spills} spills") - - # print(kk.asm['ttgir']) - # print(kk.asm['amdgcn']) - - return c - - @staticmethod - def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLK_M=128, BLK_N=128, BLK_K=32, two_tiles=True, - num_stages=3, num_warps=4, waves_per_eu=2, mfmaInstrSize=16, kpack=1): - return matmul._call(a=a, b=b, total_programs_streamk=grid, BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K, - two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages, waves_per_eu=waves_per_eu, - mfmaInstrSize=mfmaInstrSize, kpack=kpack) - - -# --------------------------------------------------------------------------- -# Example and Benchmark -# --------------------------------------------------------------------------- - -perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3) - -#m, n, k = 4864, 4096, 8256 # some problem size to test -#m, n, k = 4096, 4096, 8192 # some problem size to test -#m, n, k = 8192, 8192, 8192 # some problem size to test -m, n, k = 6912, 768, 256 # some problem size to test -A = torch.randn(m, k, device="cuda", dtype=torch.float16) -B = torch.randn(k, n, device="cuda", dtype=torch.float16) -BLK_M = 64 -BLK_N = 64 -BLK_K = 64 -two_tiles = 'True' -num_stages = 0 -num_warps = 4 -waves_per_eu = 0 -mfmaInstrSize = 16 -kpack = 2 - -matmul.set_debug(True) -C = matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, - kpack) -#exit(0) -matmul.set_debug(False) -expected = A @ B - -#assert torch.allclose(C, expected, atol=1), f"max: {(C - expected).abs().max().item()}\n{C}\n{expected}" -print("pass validation test") - -# for debugging, uncomment the following line -# exit(0) - -triton_ms = triton.testing.do_bench(lambda: torch.matmul(A, B)) -print(f"PyTorch: {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, - num_warps, waves_per_eu, mfmaInstrSize, kpack)) -print(f"hybrid stream-k (grid={total_sm}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm * 2, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, - num_warps, waves_per_eu, mfmaInstrSize, kpack)) -print(f"hybrid stream-k (grid={total_sm * 2}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, 0, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, - waves_per_eu, mfmaInstrSize, kpack)) -print(f"tile matmul (grid=0): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -exit(0) -# --------------------------------------------------------------------------- -# Log-sampled benchmark -# --------------------------------------------------------------------------- - -# tried to reproduce the tests described in the paper -num_samples = 1000 # 32768 -step = 256 -values = ((torch.logspace(torch.tensor(step).log2(), - torch.tensor(8192).log2(), num_samples, base=2) / step).round() * step).unique().tolist() -shapes = [(int(m), int(n), int(k)) for m in values for n in values for k in values] -shapes = random.sample(shapes, num_samples) -assert len(shapes) == num_samples - -results = [] -for idx, (m, n, k) in enumerate(shapes): - # print progress bar - if idx % 10 == 0 and idx > 0: - speedups = [r["speedup"] for r in results] - print(f"{idx}/{num_samples} - average speedup: {sum(speedups) / len(speedups):.3f}") - - A = 
torch.randn(m, k, device="cuda", dtype=torch.float16) - B = torch.randn(k, n, device="cuda", dtype=torch.float16) - output: Optional[torch.Tensor] = None - - def wrapper_matmul(*args, **kwargs): - global output - output = matmul.apply(*args, **kwargs) - return output - - expected = A @ B - pytorch_ms = triton.testing.do_bench(lambda: A @ B) - measures = list() - for two_tiles in [True, False]: - nb_sm = [total_sm, total_sm * 2] - total_tile = (m // BLK_M) * (n // BLK_N) - if total_tile < total_sm * 2: - nb_sm.append(total_tile) - nb_sm += random.sample(range(2, total_sm * 2, 2), 10) - for sm in nb_sm: - triton_ms = triton.testing.do_bench( - lambda: wrapper_matmul(A, B, sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu)) - max_disc = (output - expected).abs().max().item() - # large tolerance to accomodate for large K (rounding due to half precision), we just want to catch bugs. - assert max_disc <= 5., f"pb size: {m}x{n}x{k} - max discrepancy: {max_disc} - sm: {sm}, 2 tiles: {two_tiles}\n{output}\n{expected}" - info = { - "2 tiles": two_tiles, - "sm": sm, - "disc": max_disc, - "triton_ms": triton_ms, - } - measures.append(info) - best_triton_ms = min([m["triton_ms"] for m in measures]) - d = { - "m": m, - "n": n, - "k": k, - "triton": measures, - "pytorch_ms": pytorch_ms, - "speedup": pytorch_ms / best_triton_ms, - } - results.append(d) - measures = list() - -results.sort(key=lambda x: x["speedup"], reverse=False) - -# --------------------------------------------------------------------------- -# Benchmark export -# --------------------------------------------------------------------------- - -with open("results.json", "w") as f: - json.dump(results, f, indent=4) - -# 32760/32768 - average speedup: 0.962 (A100) -# 990/1000 - average speedup: 1.063 (3090 RTX with while loop and 2 tiles disabled / enabled) diff --git a/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-oldversion.py b/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-oldversion.py deleted file mode 100644 index beb8b0df9b1f..000000000000 --- a/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-oldversion.py +++ /dev/null @@ -1,485 +0,0 @@ -## matmul stream-k implementation -## Credit goes to @pommedeterresautee -## See https://github.com/openai/triton/issues/1393 - -# (echo 'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"') | sudo tee -a /etc/modprobe.d/RestrictedProfiling.conf >/dev/null -# sudo update-initramfs -u -k all -# cat /proc/driver/nvidia/params | grep RmProfilingAdminOnly -# sudo apt-get install zlib1g-dev -# for reproductible experiments -# sudo nvidia-smi -pm 1 -i 0 -# sudo nvidia-smi -i 0 -pl 350 # 400 for A100 -# sudo nvidia-smi -i 0 -lgc 1005 -from typing import Optional - -import torch -import triton -import triton.language as tl -import random - -#from triton.runtime.driver import CudaUtils -import json - -torch.manual_seed(123) -random.seed(123) - -#device = torch.cuda.current_device() -#cuda_utils = CudaUtils() -#total_sm = cuda_utils.get_device_properties(device)["multiprocessor_count"] -#total_sm = 110 # for MI250 -total_sm = 304 # for MI300X -print(f"total SMs: {total_sm}") - -# --------------------------------------------------------------------------- -# Triton kernels -# --------------------------------------------------------------------------- - - -@triton.jit() -def swizzle_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr): - grid_m = tl.cdiv(M, BLOCK_M) - grid_n = 
tl.cdiv(N, BLOCK_N) - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = tile_id // width - group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (tile_id % group_size) - pid_n = (tile_id % width) // group_size - return pid_m, pid_n - - -@triton.jit() -def linear_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr): - pid_m = tile_id // tl.cdiv(N, BLOCK_N) - pid_n = tile_id % tl.cdiv(N, BLOCK_N) - return pid_m, pid_n - - -# iterate, multiply and accumulate over K axis -@triton.jit() -def mac_loop( - A, - B, - C, - M, - N, - K, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - tile_id, - mod1, - mod2, - iters_per_tile, - start_iter, - end_iter, - pid_m, - pid_n, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - ACC_TYPE: tl.constexpr, -): - - # where are we in the grid - # tile_id = start_iter // iters_per_tile - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - rk = tl.arange(0, BLOCK_K) - # A = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_K * stride_ak * (start_iter % iters_per_tile) - # B = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_K * stride_bk * (start_iter % iters_per_tile) - A = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_K * stride_ak * (mod1) - B = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_K * stride_bk * (mod1) - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - - for current_iter in range(start_iter, end_iter): - a = tl.load(A) - b = tl.load(B) - acc += tl.dot(a, b) - A += BLOCK_K * stride_ak - B += BLOCK_K * stride_bk - - #if end_iter % iters_per_tile == 0: # last iteration of the tile always happens before its start on another SM - - -# if mod2 == 0:# last iteration of the tile always happens before its start on another SM -# C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! -# tl.store(C_, acc) -# if start_iter % iters_per_tile != 0: # only if tile has been partially processed -# if mod1 != 0: # only if tile has been partially processed -# tl.atomic_xchg(locks + tile_id, 1) -# else: -# while tl.atomic_cas(locks + tile_id, 1, 1) != 1: -# pass -# C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! -# tl.atomic_add(C_, acc) - if mod1 == 0 and mod2 == 0: - C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! - tl.store(C_, acc) - else: - C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! 
- tl.atomic_add(C_, acc) - - -@triton.jit() -def first_wave( - A, - B, - C, - M, - N, - K, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - total_full_tiles_streamk, - total_partial_tiles_streamk, - iters_per_tile, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - ACC_TYPE: tl.constexpr, - GROUP_M: tl.constexpr, -): - pid = tl.program_id(0) - start_iter = pid * total_full_tiles_streamk + tl.minimum(pid, total_partial_tiles_streamk) - last_iter = (pid + 1) * total_full_tiles_streamk + tl.minimum(pid + 1, total_partial_tiles_streamk) - - while start_iter < last_iter: - end_iter = tl.minimum(start_iter + (iters_per_tile - start_iter % iters_per_tile), last_iter) - mod1 = start_iter % iters_per_tile - mod2 = end_iter % iters_per_tile - tile_id = start_iter // iters_per_tile - if GROUP_M > 0: - pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - else: - pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - mac_loop( - A, - B, - C, - M, - N, - K, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - tile_id, - mod1, - mod2, - iters_per_tile, - start_iter, - end_iter, - pid_m, - pid_n, - BLOCK_M, - BLOCK_N, - BLOCK_K, - ACC_TYPE, - ) - - start_iter = end_iter - - -# similar to the reference matmul kernel -@triton.jit() -def full_tiles( - A, - B, - C, - M, - N, - K, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - total_tiles_streamk, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - ACC_TYPE: tl.constexpr, - GROUP_M: tl.constexpr, -): - # first wave has done more tiles than there are SMs, we adjust pid - tile_id = tl.program_id(0) + total_tiles_streamk - if GROUP_M > 0: - pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - else: - pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - - # do matrix multiplication - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - rk = tl.arange(0, BLOCK_K) - # pointers - A = A + rm[:, None] * stride_am + rk[None, :] * stride_ak - B = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for k in range(0, tl.cdiv(K, BLOCK_K)): - a = tl.load(A) - b = tl.load(B) - acc += tl.dot(a, b) - A += BLOCK_K * stride_ak - B += BLOCK_K * stride_bk - acc = acc.to(tl.float16) # restore C.dtype.element_ty - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - C = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn - tl.store(C, acc) - - -# --------------------------------------------------------------------------- -# Wrapper -# --------------------------------------------------------------------------- - - -class matmul(torch.autograd.Function): - - _debug = False - - @staticmethod - def set_debug(debug: bool): - matmul._debug = debug - - @staticmethod - def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M: int, BLK_N: int, BLK_K: int, - two_tiles: bool, num_stages: int, num_warps: int): - device = a.device - - assert a.is_contiguous() and b.is_contiguous(), "non-contiguous inputs are not supported" - # checks constraints - assert a.shape[1] == b.shape[0], "incompatible dimensions" - M, K = a.shape - _, N = b.shape - # accumulator types - ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, 
torch.float32] else tl.int32 - # compute grid (work to do per SM on the first wave) - total_blocks_M = triton.cdiv(M, BLK_M) - total_blocks_N = triton.cdiv(N, BLK_N) - iters_per_tile = triton.cdiv(K, BLK_K) - GROUP_M = 8 # 0 to disable swizzling - total_tiles = total_blocks_M * total_blocks_N - - if total_programs_streamk > 0: # Stream-K - # last wave may occupy less than total_programs_streamk SMs - total_tiles_streamk = total_tiles % total_programs_streamk - # for two-tile Stream-K + data-parallel from original paper - if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: - total_tiles_streamk += total_programs_streamk - # remaining tiles are computed using classical blocking - total_blocking_tiles = total_tiles - total_tiles_streamk - total_iters_streamk = total_tiles_streamk * iters_per_tile - # iterations related to full waves - total_full_tiles_streamk = total_iters_streamk // total_programs_streamk - # iterations related to last (partial) wave - total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk - - else: # all tiles are computed using classical blocking - total_blocking_tiles = total_tiles - total_tiles_streamk = 0 - total_full_tiles_streamk = 0 - total_partial_tiles_streamk = 0 - total_iters_streamk = 0 - - if matmul._debug: - print(f"M,N,K={M},{N},{K} ; BLK_M,N,K={BLK_M},{BLK_N},{BLK_K}") - print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}") - print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}") - print(f"{total_programs_streamk=}") - print(f"{total_blocking_tiles=}") - print(f"{iters_per_tile=}") - print(f"{total_iters_streamk=}") - - # allocates output - c = torch.zeros((M, N), device=device, dtype=a.dtype) - # allocates locks to sync work accross SMs - k1 = first_wave[(total_programs_streamk, )]( - a, - b, - c, - M, - N, - K, - a.stride(0), - a.stride(1), - b.stride(0), - b.stride(1), - c.stride(0), - c.stride(1), - total_full_tiles_streamk=total_full_tiles_streamk, - total_partial_tiles_streamk=total_partial_tiles_streamk, - iters_per_tile=iters_per_tile, - BLOCK_M=BLK_M, - BLOCK_N=BLK_N, - BLOCK_K=BLK_K, - ACC_TYPE=ACC_TYPE, - GROUP_M=GROUP_M, - num_stages=num_stages, - num_warps=num_warps, - ) - if matmul._debug: - print(f"{k1.n_regs} registers used, {k1.n_spills} spills") - k2 = full_tiles[(total_blocking_tiles, )]( - a, - b, - c, - M, - N, - K, - a.stride(0), - a.stride(1), - b.stride(0), - b.stride(1), - c.stride(0), - c.stride(1), - total_tiles_streamk=total_tiles_streamk, - BLOCK_M=BLK_M, - BLOCK_N=BLK_N, - BLOCK_K=BLK_K, - ACC_TYPE=ACC_TYPE, - GROUP_M=GROUP_M, - num_stages=num_stages, - num_warps=num_warps, - ) - if matmul._debug: - print(f"{k2.n_regs} registers used, {k2.n_spills} spills") - return c - - @staticmethod - def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLK_M=128, BLK_N=128, BLK_K=32, two_tiles=True, - num_stages=3, num_warps=4): - return matmul._call(a=a, b=b, total_programs_streamk=grid, BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K, - two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages) - - -# --------------------------------------------------------------------------- -# Example and Benchmark -# --------------------------------------------------------------------------- - -perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3) - -m, n, k = 8192, 8192, 8192 # some problem size to test -A = torch.randn(m, k, device="cuda", dtype=torch.float16) -B = torch.randn(k, n, device="cuda", dtype=torch.float16) -BLK_M = 128 -BLK_N = 256 -BLK_K = 16 -two_tiles = 
'True' -num_stages = 0 -num_warps = 4 - -matmul.set_debug(True) -C = matmul.apply(A, B, total_sm, 128, 128, 32, 4, 4) -matmul.set_debug(False) -expected = A @ B - -assert torch.allclose(C, expected, atol=1), f"max: {(C - expected).abs().max().item()}\n{C}\n{expected}" - -# for debugging, uncomment the following line -# exit(0) - -triton_ms = triton.testing.do_bench(lambda: torch.matmul(A, B)) -print(f"PyTorch: {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -triton_ms = triton.testing.do_bench( - lambda: matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps)) -print(f"hybrid stream-k (grid={total_sm}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -triton_ms = triton.testing.do_bench( - lambda: matmul.apply(A, B, total_sm * 2, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps)) -print(f"hybrid stream-k (grid={total_sm * 2}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -triton_ms = triton.testing.do_bench( - lambda: matmul.apply(A, B, 0, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps)) -print(f"tile matmul (grid=0): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -exit(0) -# --------------------------------------------------------------------------- -# Log-sampled benchmark -# --------------------------------------------------------------------------- - -# tried to reproduce the tests described in the paper -num_samples = 1000 # 32768 -step = 256 -values = ((torch.logspace(torch.tensor(step).log2(), - torch.tensor(8192).log2(), num_samples, base=2) / step).round() * step).unique().tolist() -shapes = [(int(m), int(n), int(k)) for m in values for n in values for k in values] -shapes = random.sample(shapes, num_samples) -assert len(shapes) == num_samples - -results = [] -for idx, (m, n, k) in enumerate(shapes): - # print progress bar - if idx % 10 == 0 and idx > 0: - speedups = [r["speedup"] for r in results] - print(f"{idx}/{num_samples} - average speedup: {sum(speedups) / len(speedups):.3f}") - - A = torch.randn(m, k, device="cuda", dtype=torch.float16) - B = torch.randn(k, n, device="cuda", dtype=torch.float16) - output: Optional[torch.Tensor] = None - - def wrapper_matmul(*args, **kwargs): - global output - output = matmul.apply(*args, **kwargs) - return output - - expected = A @ B - pytorch_ms = triton.testing.do_bench(lambda: A @ B) - measures = list() - for two_tiles in [True, False]: - nb_sm = [total_sm, total_sm * 2] - total_tile = (m // 128) * (n // 128) - if total_tile < total_sm * 2: - nb_sm.append(total_tile) - nb_sm += random.sample(range(2, total_sm * 2, 2), 10) - for sm in nb_sm: - triton_ms = triton.testing.do_bench(lambda: wrapper_matmul(A, B, sm, 128, 128, 32, two_tiles, 4, 4)) - max_disc = (output - expected).abs().max().item() - # large tolerance to accomodate for large K (rounding due to half precision), we just want to catch bugs. 
- assert max_disc <= 5., f"pb size: {m}x{n}x{k} - max discrepancy: {max_disc} - sm: {sm}, 2 tiles: {two_tiles}\n{output}\n{expected}" - info = { - "2 tiles": two_tiles, - "sm": sm, - "disc": max_disc, - "triton_ms": triton_ms, - } - measures.append(info) - best_triton_ms = min([m["triton_ms"] for m in measures]) - d = { - "m": m, - "n": n, - "k": k, - "triton": measures, - "pytorch_ms": pytorch_ms, - "speedup": pytorch_ms / best_triton_ms, - } - results.append(d) - measures = list() - -results.sort(key=lambda x: x["speedup"], reverse=False) - -# --------------------------------------------------------------------------- -# Benchmark export -# --------------------------------------------------------------------------- - -with open("results.json", "w") as f: - json.dump(results, f, indent=4) - -# 32760/32768 - average speedup: 0.962 (A100) -# 990/1000 - average speedup: 1.063 (3090 RTX with while loop and 2 tiles disabled / enabled) diff --git a/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singlekern-autotune.py b/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singlekern-autotune.py deleted file mode 100644 index a35d691a0225..000000000000 --- a/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singlekern-autotune.py +++ /dev/null @@ -1,563 +0,0 @@ -## matmul stream-k implementation -## Credit goes to @pommedeterresautee -## See https://github.com/openai/triton/issues/1393 - -# (echo 'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"') | sudo tee -a /etc/modprobe.d/RestrictedProfiling.conf >/dev/null -# sudo update-initramfs -u -k all -# cat /proc/driver/nvidia/params | grep RmProfilingAdminOnly -# sudo apt-get install zlib1g-dev -# for reproductible experiments -# sudo nvidia-smi -pm 1 -i 0 -# sudo nvidia-smi -i 0 -pl 350 # 400 for A100 -# sudo nvidia-smi -i 0 -lgc 1005 -from typing import Optional - -import torch -import triton -import triton.language as tl -import random - -#from triton.runtime.driver import CudaUtils -import json - -torch.manual_seed(123) -random.seed(123) - -#device = torch.cuda.current_device() -#cuda_utils = CudaUtils() -#total_sm = cuda_utils.get_device_properties(device)["multiprocessor_count"] -#total_sm = 110 # for MI250 -total_sm = 304 # for MI300X -print(f"total SMs: {total_sm}") -# global flag to indicate whether using the full tuing space -tuning_full_space = True -# --------------------------------------------------------------------------- -# Triton kernels -# --------------------------------------------------------------------------- - - -@triton.jit() -def swizzle_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr): - grid_m = tl.cdiv(M, BLOCK_M) - grid_n = tl.cdiv(N, BLOCK_N) - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = tile_id // width - group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (tile_id % group_size) - pid_n = (tile_id % width) // group_size - return pid_m, pid_n - - -@triton.jit() -def linear_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr): - pid_m = tile_id // tl.cdiv(N, BLOCK_N) - pid_n = tile_id % tl.cdiv(N, BLOCK_N) - return pid_m, pid_n - - -@triton.jit() -def get_tile_config(M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, two_tiles, - total_programs_streamk): - total_blocks_M = tl.cdiv(M, BLOCK_M) - total_blocks_N = tl.cdiv(N, BLOCK_N) - 
iters_per_tile = tl.cdiv(K, BLOCK_K) - # GROUP_M = 0 # 0 to disable swizzling - total_tiles = total_blocks_M * total_blocks_N - if total_programs_streamk > 0: # Stream-K - # last wave may occupy less than total_programs_streamk SMs - total_tiles_streamk = total_tiles % total_programs_streamk - # for two-tile Stream-K + data-parallel from original paper - if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: - total_tiles_streamk += total_programs_streamk - # remaining tiles are computed using classical blocking - total_iters_streamk = total_tiles_streamk * iters_per_tile - # iterations related to full waves - total_full_tiles_streamk = total_iters_streamk // total_programs_streamk - # iterations related to last (partial) wave - total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk - - else: # all tiles are computed using classical blocking - total_tiles_streamk = 0 - total_full_tiles_streamk = 0 - total_partial_tiles_streamk = 0 - total_iters_streamk = 0 - - return iters_per_tile, total_tiles_streamk, total_full_tiles_streamk, total_partial_tiles_streamk, total_iters_streamk - - -# pruned some unreasonable config -def prune_configs(configs, named_args): - # call only for full tuning space - if not tuning_full_space: - return configs - - SIZE_M = named_args["A"].shape[0] - SIZE_N = named_args["B"].shape[1] - # SIZE_K = named_args["A"].shape[1] - - pruned_configs = [] - for config in configs: - kw = config.kwargs - BLOCK_M, BLOCK_N, _ =\ - kw["BLOCK_M"], kw["BLOCK_N"], kw["BLOCK_K"] - if SIZE_M <= 32 and BLOCK_M != 32: - continue - if SIZE_N <= 32 and BLOCK_N != 32: - continue - - pruned_configs.append(config) - - return pruned_configs - - -def get_full_tuning_space(): - configs = [] - if not tuning_full_space: - return configs - - block_mn_range = [64, 128, 256] - block_k_range = [16, 32, 64] - num_warps_range = [1, 2, 4, 8] - # group_m_range = [0, 1, 2, 4, 8] - group_m_range = [0, 4, 8] - # For now we see better perf with num_stages=0 for all gemm configs we care - # But keep this explicit so that we do not forget we may need to set it to - # other values in the future - num_stage_range = [0] - waves_per_eu_range = [0] - matrix_instr_nonkdim_range = [16, 32] - kpack_range = [1, 2] - - for block_m in block_mn_range: - for block_n in block_mn_range: - for block_k in block_k_range: - for num_warps in num_warps_range: - for group_m in group_m_range: - for num_stages in num_stage_range: - for num_waves_per_eu in waves_per_eu_range: - for matrix_instr_nonkdim in matrix_instr_nonkdim_range: - for kpack in kpack_range: - configs.append( - triton.Config( - { - 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, - 'GROUP_M': group_m, 'waves_per_eu': num_waves_per_eu, - 'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack - }, - num_stages=num_stages, - num_warps=num_warps, - )) - - return configs - - -#To do: we need update the default autotune configuration once we go through the whole performance test sets. 
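-# When tuning_full_space is True, the decorator below sweeps the configs generated by
-# get_full_tuning_space(); otherwise only the short hand-picked list is tried, keyed on (M, N, K).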
-@triton.autotune( - configs=get_full_tuning_space() if tuning_full_space else [ - triton.Config( - { - 'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 16, 'GROUP_M': 8, 'waves_per_eu': 0, 'matrix_instr_nonkdim': - 16, 'kpack': 1 - }, num_warps=4, num_stages=0), - triton.Config( - { - 'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 16, 'GROUP_M': 8, 'waves_per_eu': 2, 'matrix_instr_nonkdim': - 16, 'kpack': 1 - }, num_warps=4, num_stages=0), - triton.Config( - { - 'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 16, 'GROUP_M': 4, 'waves_per_eu': 0, 'matrix_instr_nonkdim': - 16, 'kpack': 1 - }, num_warps=4, num_stages=0), - triton.Config( - { - 'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 16, 'GROUP_M': 4, 'waves_per_eu': 2, 'matrix_instr_nonkdim': - 16, 'kpack': 1 - }, num_warps=4, num_stages=0), - triton.Config( - { - 'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 16, 'waves_per_eu': 0, 'matrix_instr_nonkdim': - 16, 'kpack': 1 - }, num_warps=4, num_stages=0), - triton.Config( - { - 'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 16, 'GROUP_M': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': - 16, 'kpack': 1 - }, num_warps=4, num_stages=4), - ], - key=['M', 'N', 'K'], - # prune_configs_by={ - # 'early_config_prune': prune_configs, - # 'perf_model': None, - # "top_k": None - # }, - reset_to_zero=['C'], -) -@triton.jit() -def streamk_gemm( - A, - B, - C, - M, - N, - K, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - # total_full_tiles_streamk, total_partial_tiles_streamk, iters_per_tile, - # total_tiles_streamk, - total_programs_streamk, - two_tiles, - ACC_TYPE: tl.constexpr, - GROUP_M: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, -): - pid = tl.program_id(0) - iters_per_tile, total_tiles_streamk, total_full_tiles_streamk, total_partial_tiles_streamk, total_iters_streamk = get_tile_config( - M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, total_programs_streamk) - - # Determine whether we are in the first wave or full_tiles phase based on pid - is_first_wave = pid < total_programs_streamk and total_programs_streamk > 0 - - # Calculate starting and ending iterations for first wave - if not is_first_wave: - tile_id = tl.program_id(0) + total_tiles_streamk - total_programs_streamk - if GROUP_M > 0: - pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - else: - pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - - # do matrix multiplication - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - ram = tl.max_contiguous(tl.multiple_of(rm, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) - precomputed_stride_ak = BLOCK_K * stride_ak - precomputed_stride_bk = BLOCK_K * stride_bk - # pointers - A_BASE = A + ram[:, None] * stride_am + rk[None, :] * stride_ak - B_BASE = B + rk[:, None] * stride_bk + rbn[None, :] * stride_bn - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for k in range(0, tl.cdiv(K, BLOCK_K)): - a = tl.load(A_BASE) - b = tl.load(B_BASE) - acc += tl.dot(a, b) - A_BASE += precomputed_stride_ak - B_BASE += precomputed_stride_bk - # acc = acc.to(tl.float16) # restore C.dtype.element_ty - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn - tl.store(C_, acc) - else: - start_iter = pid * total_full_tiles_streamk + 
tl.minimum(pid, total_partial_tiles_streamk) - last_iter = (pid + 1) * total_full_tiles_streamk + tl.minimum(pid + 1, total_partial_tiles_streamk) - while start_iter < last_iter: - remainder = start_iter % iters_per_tile - end_iter = tl.minimum(start_iter + (iters_per_tile - remainder), last_iter) - # where are we in the grid - tile_id = start_iter // iters_per_tile - if GROUP_M > 0: - pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - else: - pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - ram = tl.max_contiguous(tl.multiple_of(rm, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) - A_BASE = A + ram[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_K * stride_ak * remainder - B_BASE = B + rk[:, None] * stride_bk + rbn[None, :] * stride_bn + BLOCK_K * stride_bk * remainder - precomputed_stride_ak = BLOCK_K * stride_ak - precomputed_stride_bk = BLOCK_K * stride_bk - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for current_iter in range(start_iter, end_iter): - a = tl.load(A_BASE) - b = tl.load(B_BASE) - acc += tl.dot(a, b) - A_BASE += precomputed_stride_ak - B_BASE += precomputed_stride_bk - - # acc = acc.to(tl.float16) # restore C.dtype.element_ty - if remainder == 0 and end_iter % iters_per_tile == 0: - C_ = C + rm[:, - None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! - tl.store(C_, acc) - else: - C_ = C + rm[:, - None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! - tl.atomic_add(C_, acc) - - start_iter = end_iter - - -# --------------------------------------------------------------------------- -# Wrapper -# --------------------------------------------------------------------------- - - -class matmul(torch.autograd.Function): - - _debug = True - - @staticmethod - def set_debug(debug: bool): - matmul._debug = debug - - @staticmethod - def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLOCK_M: int, BLOCK_N: int, BLOCK_K: int, - two_tiles: bool, num_stages: int, num_warps: int, waves_per_eu: int, mfmaInstrSize: int, kpack: int): - - def compute_total_blocking_tiles(M, N, BLOCK_M, BLOCK_N, two_tiles, total_programs_streamk): - total_blocks_M = triton.cdiv(M, BLOCK_M) - total_blocks_N = triton.cdiv(N, BLOCK_N) - total_tiles = total_blocks_M * total_blocks_N - - if total_programs_streamk > 0: # Stream-K - # last wave may occupy less than total_programs_streamk SMs - total_tiles_streamk = total_tiles % total_programs_streamk - # for two-tile Stream-K + data-parallel from original paper - if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: - total_tiles_streamk += total_programs_streamk - # remaining tiles are computed using classical blocking - total_blocking_tiles = total_tiles - total_tiles_streamk - else: # all tiles are computed using classical blocking - total_blocking_tiles = total_tiles - - return total_blocking_tiles - - device = a.device - - assert a.is_contiguous() and b.is_contiguous(), "non-contiguous inputs are not supported" - # checks constraints - assert a.shape[1] == b.shape[0], "incompatible dimensions" - M, K = a.shape - _, N = b.shape - # accumulator types - ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - # compute grid (work to do per SM on the first wave) - # 
GROUP_M = 8 # 0 to disable swizzling - - if matmul._debug: - total_blocks_M = triton.cdiv(M, BLOCK_M) - total_blocks_N = triton.cdiv(N, BLOCK_N) - iters_per_tile = triton.cdiv(K, BLOCK_K) - total_tiles = total_blocks_M * total_blocks_N - if total_programs_streamk > 0: # Stream-K - # last wave may occupy less than total_programs_streamk SMs - total_tiles_streamk = total_tiles % total_programs_streamk - # for two-tile Stream-K + data-parallel from original paper - if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: - total_tiles_streamk += total_programs_streamk - # remaining tiles are computed using classical blocking - total_blocking_tiles = total_tiles - total_tiles_streamk - total_iters_streamk = total_tiles_streamk * iters_per_tile - # iterations related to full waves - # total_full_tiles_streamk = total_iters_streamk // total_programs_streamk - # iterations related to last (partial) wave - total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk - - else: # all tiles are computed using classical blocking - total_blocking_tiles = total_tiles - total_tiles_streamk = 0 - # total_full_tiles_streamk = 0 - total_partial_tiles_streamk = 0 - total_iters_streamk = 0 - print(f"M,N,K={M},{N},{K} ; BLOCK_M,N,K={BLOCK_M},{BLOCK_N},{BLOCK_K}") - print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}") - print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}") - print(f"{total_programs_streamk=}") - print(f"{total_blocking_tiles=}") - print(f"{total_partial_tiles_streamk=}") - print(f"{iters_per_tile=}") - print(f"{total_iters_streamk=}") - - # allocates output - c = torch.zeros((M, N), device=device, dtype=a.dtype) - grids = lambda META: (total_programs_streamk + compute_total_blocking_tiles(M, N, META['BLOCK_M'], META[ - 'BLOCK_N'], two_tiles, total_programs_streamk), ) - kk = streamk_gemm[(grids)]( - a, - b, - c, - M, - N, - K, - a.stride(0), - a.stride(1), - b.stride(0), - b.stride(1), - c.stride(0), - c.stride(1), - # total_full_tiles_streamk=total_full_tiles_streamk, - # total_partial_tiles_streamk=total_partial_tiles_streamk, - # iters_per_tile=iters_per_tile, - # total_tiles_streamk=total_tiles_streamk, - total_programs_streamk=total_programs_streamk, - two_tiles=two_tiles, - ACC_TYPE=ACC_TYPE, - # GROUP_M=GROUP_M, - # BLOCK_M=BLOCK_M, - # BLOCK_N=BLOCK_N, - # BLOCK_K=BLOCK_K, - # num_stages=num_stages, - # num_warps=num_warps, - # waves_per_eu = waves_per_eu, - ) - if matmul._debug: - print(f"{kk.n_regs} registers used, {kk.n_spills} spills") - - # print(kk.asm['ttgir']) - # print(kk.asm['amdgcn']) - return c - - @staticmethod - def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLOCK_M=128, BLOCK_N=128, BLOCK_K=32, two_tiles=True, - num_stages=3, num_warps=4, waves_per_eu=2, mfmaInstrSize=16, kpack=1): - return matmul._call(a=a, b=b, total_programs_streamk=grid, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, - two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages, waves_per_eu=waves_per_eu, - mfmaInstrSize=mfmaInstrSize, kpack=kpack) - - -# --------------------------------------------------------------------------- -# Example and Benchmark -# --------------------------------------------------------------------------- - -perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3) - -#m, n, k = 1792, 7424, 4864 # some problem size to test -#m, n, k = 8192, 8192, 8192 # some problem size to test -m, n, k = 4096, 4096, 8192 # some problem size to test -A = torch.randn(m, k, device="cuda", 
dtype=torch.float16) -B = torch.randn(k, n, device="cuda", dtype=torch.float16) -#A = torch.ones((m, k), device="cuda", dtype=torch.float16) -#B = torch.ones((k, n), device="cuda", dtype=torch.float16) -BLOCK_M = 256 -BLOCK_N = 256 -BLOCK_K = 64 -two_tiles = True -num_stages = 0 -num_warps = 8 -waves_per_eu = 0 -mfmaInstrSize = 16 -kpack = 1 - -matmul.set_debug(True) -C = matmul.apply(A, B, total_sm, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, num_stages, num_warps, waves_per_eu, - mfmaInstrSize, kpack) -matmul.set_debug(False) -expected = A @ B - -assert torch.allclose(C, expected, atol=1), f"max: {(C - expected).abs().max().item()}\n{C}\n{expected}" -print("pass validation test") - -# for debugging, uncomment the following line -#exit(0) - -triton_ms = triton.testing.do_bench(lambda: torch.matmul(A, B)) -print(f"PyTorch: {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, - num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack)) -print(f"hybrid stream-k (grid={total_sm}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") -print(f'SIZE: {m},{n},{k} Best tuning config: ({streamk_gemm.get_best_config()})') - -triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm * 2, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, - num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack)) -print(f"hybrid stream-k (grid={total_sm * 2}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") -print(f'SIZE: {m},{n},{k} Best tuning config: ({streamk_gemm.get_best_config()})') - -triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, 0, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, num_stages, - num_warps, waves_per_eu, mfmaInstrSize, kpack)) -print(f"tile matmul (grid=0): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") -print(f'SIZE: {m},{n},{k} Best tuning config: ({streamk_gemm.get_best_config()})') - -exit(0) -# --------------------------------------------------------------------------- -# Log-sampled benchmark -# --------------------------------------------------------------------------- - -# tried to reproduce the tests described in the paper -perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3) -num_samples = 1000 # 32768 -step = 256 -values = ((torch.logspace(torch.tensor(step).log2(), - torch.tensor(8192).log2(), num_samples, base=2) / step).round() * step).unique().tolist() -shapes = [(int(m), int(n), int(k)) for m in values for n in values for k in values] -shapes = random.sample(shapes, num_samples) -assert len(shapes) == num_samples - -results = [] -for idx, (m, n, k) in enumerate(shapes): - # print progress bar - if idx % 10 == 0 and idx > 0: - speedups = [r["speedup"] for r in results] - print(f"{idx}/{num_samples} - average speedup: {sum(speedups) / len(speedups):.3f}") - - A = torch.randn(m, k, device="cuda", dtype=torch.float16) - B = torch.randn(k, n, device="cuda", dtype=torch.float16) - output: Optional[torch.Tensor] = None - - def wrapper_matmul(*args, **kwargs): - global output - output = matmul.apply(*args, **kwargs) - return output - - expected = A @ B - pytorch_ms = triton.testing.do_bench(lambda: A @ B) - measures = list() - for two_tiles in [True, False]: - nb_sm = [total_sm, total_sm * 2] - total_tile = (m // BLOCK_M) * (n // BLOCK_N) - if total_tile < total_sm * 2: - nb_sm.append(total_tile) - nb_sm += random.sample(range(2, total_sm * 2, 2), 10) - for sm in nb_sm: - triton_ms = triton.testing.do_bench(lambda: wrapper_matmul(A, B, sm, BLOCK_M, BLOCK_N, BLOCK_K, 
two_tiles, - num_stages, num_warps, waves_per_eu)) - max_disc = (output - expected).abs().max().item() - # large tolerance to accomodate for large K (rounding due to half precision), we just want to catch bugs. - assert max_disc <= 5., f"pb size: {m}x{n}x{k} - max discrepancy: {max_disc} - sm: {sm}, 2 tiles: {two_tiles}\n{output}\n{expected}" - Best_tuning_config = f'SIZE: {m},{n},{k} Best tuning config: ({streamk_gemm.get_best_config()})' - info = { - "2 tiles": two_tiles, - "sm": sm, - "disc": max_disc, - "triton_ms": triton_ms, - "Best tuning config": Best_tuning_config, - } - measures.append(info) - best_triton_ms = min([m["triton_ms"] for m in measures]) - d = { - "m": m, - "n": n, - "k": k, - "triton": measures, - "pytorch_ms": pytorch_ms, - "speedup": pytorch_ms / best_triton_ms, - } - results.append(d) - measures = list() - -results.sort(key=lambda x: x["speedup"], reverse=False) - -# --------------------------------------------------------------------------- -# Benchmark export -# --------------------------------------------------------------------------- - -with open("results.json", "w") as f: - json.dump(results, f, indent=4) - -# 32760/32768 - average speedup: 0.962 (A100) -# 990/1000 - average speedup: 1.063 (3090 RTX with while loop and 2 tiles disabled / enabled) diff --git a/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singleloop-nomod.py b/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singleloop-nomod.py deleted file mode 100644 index 2651ad59d923..000000000000 --- a/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singleloop-nomod.py +++ /dev/null @@ -1,387 +0,0 @@ -## matmul stream-k implementation -## Credit goes to @pommedeterresautee -## See https://github.com/openai/triton/issues/1393 - -# (echo 'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"') | sudo tee -a /etc/modprobe.d/RestrictedProfiling.conf >/dev/null -# sudo update-initramfs -u -k all -# cat /proc/driver/nvidia/params | grep RmProfilingAdminOnly -# sudo apt-get install zlib1g-dev -# for reproductible experiments -# sudo nvidia-smi -pm 1 -i 0 -# sudo nvidia-smi -i 0 -pl 350 # 400 for A100 -# sudo nvidia-smi -i 0 -lgc 1005 -from typing import Optional - -import torch -import triton -import triton.language as tl -import random - -#from triton.runtime.driver import CudaUtils -import json - -torch.manual_seed(123) -random.seed(123) - -#device = torch.cuda.current_device() -#cuda_utils = CudaUtils() -#total_sm = cuda_utils.get_device_properties(device)["multiprocessor_count"] -#total_sm = 110 # for MI250 -total_sm = 304 # for MI300X -print(f"total SMs: {total_sm}") - -# --------------------------------------------------------------------------- -# Triton kernels -# --------------------------------------------------------------------------- - - -@triton.jit() -def swizzle_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr): - grid_m = tl.cdiv(M, BLOCK_M) - grid_n = tl.cdiv(N, BLOCK_N) - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = tile_id // width - group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (tile_id % group_size) - pid_n = (tile_id % width) // group_size - return pid_m, pid_n - - -@triton.jit() -def linear_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr): - pid_m = tile_id // tl.cdiv(N, BLOCK_N) - pid_n = tile_id % tl.cdiv(N, BLOCK_N) - 
return pid_m, pid_n - - -@triton.jit() -def first_wave( - A, - B, - C, - M, - N, - K, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - total_full_tiles_streamk, - total_partial_tiles_streamk, - iters_per_tile, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - ACC_TYPE: tl.constexpr, - GROUP_M: tl.constexpr, -): - pid = tl.program_id(0) - start_iter = pid * total_full_tiles_streamk + tl.minimum(pid, total_partial_tiles_streamk) - last_iter = (pid + 1) * total_full_tiles_streamk + tl.minimum(pid + 1, total_partial_tiles_streamk) - - while start_iter < last_iter: - remainder = start_iter % iters_per_tile - end_iter = tl.minimum(start_iter + (iters_per_tile - remainder), last_iter) - # where are we in the grid - tile_id = start_iter // iters_per_tile - if GROUP_M > 0: - pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - else: - pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - rk = tl.arange(0, BLOCK_K) - A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_K * stride_ak * remainder - B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_K * stride_bk * remainder - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - - for current_iter in range(start_iter, end_iter): - a = tl.load(A_BASE) - b = tl.load(B_BASE) - acc += tl.dot(a, b) - A_BASE += BLOCK_K * stride_ak - B_BASE += BLOCK_K * stride_bk - - if remainder == 0 and end_iter % iters_per_tile == 0: - C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! - tl.store(C_, acc) - else: - C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! 
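-            # This tile's K range is split across programs, so each program accumulates its partial
-            # result into C with an atomic add; correctness relies on the wrapper allocating C with torch.zeros.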
- tl.atomic_add(C_, acc) - - start_iter = end_iter - - -# similar to the reference matmul kernel -@triton.jit() -def full_tiles( - A, - B, - C, - M, - N, - K, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - total_tiles_streamk, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - ACC_TYPE: tl.constexpr, - GROUP_M: tl.constexpr, -): - # first wave has done more tiles than there are SMs, we adjust pid - tile_id = tl.program_id(0) + total_tiles_streamk - if GROUP_M > 0: - pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - else: - pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - - # do matrix multiplication - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - rk = tl.arange(0, BLOCK_K) - # pointers - A = A + rm[:, None] * stride_am + rk[None, :] * stride_ak - B = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for k in range(0, tl.cdiv(K, BLOCK_K)): - a = tl.load(A) - b = tl.load(B) - acc += tl.dot(a, b) - A += BLOCK_K * stride_ak - B += BLOCK_K * stride_bk - acc = acc.to(tl.float16) # restore C.dtype.element_ty - # rematerialize rm and rn to save registers - # rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - # rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - C = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn - tl.store(C, acc) - - -# --------------------------------------------------------------------------- -# Wrapper -# --------------------------------------------------------------------------- - - -class matmul(torch.autograd.Function): - - _debug = True - - @staticmethod - def set_debug(debug: bool): - matmul._debug = debug - - @staticmethod - def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M: int, BLK_N: int, BLK_K: int, - two_tiles: bool, num_stages: int, num_warps: int, waves_per_eu: int, mfmaInstrSize: int, kpack: int): - device = a.device - - assert a.is_contiguous() and b.is_contiguous(), "non-contiguous inputs are not supported" - # checks constraints - assert a.shape[1] == b.shape[0], "incompatible dimensions" - M, K = a.shape - _, N = b.shape - # accumulator types - ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - # compute grid (work to do per SM on the first wave) - total_blocks_M = triton.cdiv(M, BLK_M) - total_blocks_N = triton.cdiv(N, BLK_N) - iters_per_tile = triton.cdiv(K, BLK_K) - GROUP_M = 4 # 0 to disable swizzling - total_tiles = total_blocks_M * total_blocks_N - - if total_programs_streamk > 0: # Stream-K - # last wave may occupy less than total_programs_streamk SMs - total_tiles_streamk = total_tiles % total_programs_streamk - # for two-tile Stream-K + data-parallel from original paper - if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: - total_tiles_streamk += total_programs_streamk - # remaining tiles are computed using classical blocking - total_blocking_tiles = total_tiles - total_tiles_streamk - total_iters_streamk = total_tiles_streamk * iters_per_tile - # iterations related to full waves - total_full_tiles_streamk = total_iters_streamk // total_programs_streamk - # iterations related to last (partial) wave - total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk - - else: # all tiles are computed using classical blocking - total_blocking_tiles = total_tiles - total_tiles_streamk = 0 - total_full_tiles_streamk = 
0 - total_partial_tiles_streamk = 0 - total_iters_streamk = 0 - - if matmul._debug: - print(f"M,N,K={M},{N},{K} ; BLK_M,N,K={BLK_M},{BLK_N},{BLK_K}") - print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}") - print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}") - print(f"{total_programs_streamk=}") - print(f"{total_blocking_tiles=}") - print(f"{iters_per_tile=}") - print(f"{total_iters_streamk=}") - - # allocates output - c = torch.zeros((M, N), device=device, dtype=a.dtype) - - k1 = first_wave[(total_programs_streamk, )]( - a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), - total_full_tiles_streamk=total_full_tiles_streamk, total_partial_tiles_streamk=total_partial_tiles_streamk, - iters_per_tile=iters_per_tile, BLOCK_M=BLK_M, BLOCK_N=BLK_N, BLOCK_K=BLK_K, ACC_TYPE=ACC_TYPE, - GROUP_M=GROUP_M, num_stages=num_stages, num_warps=num_warps, waves_per_eu=waves_per_eu, - matrix_instr_nonkdim=mfmaInstrSize, kpack=kpack) - if matmul._debug: - print(f"{k1.n_regs} registers used, {k1.n_spills} spills") - k2 = full_tiles[(total_blocking_tiles, )](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), - c.stride(0), c.stride(1), total_tiles_streamk=total_tiles_streamk, - BLOCK_M=BLK_M, BLOCK_N=BLK_N, BLOCK_K=BLK_K, ACC_TYPE=ACC_TYPE, - GROUP_M=GROUP_M, num_stages=num_stages, num_warps=num_warps, - waves_per_eu=waves_per_eu, matrix_instr_nonkdim=mfmaInstrSize, - kpack=kpack) - if matmul._debug: - print(f"{k2.n_regs} registers used, {k2.n_spills} spills") -# print(k2.asm['amdgcn']) - return c - - @staticmethod - def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLK_M=128, BLK_N=128, BLK_K=32, two_tiles=True, - num_stages=3, num_warps=4, waves_per_eu=2, mfmaInstrSize=16, kpack=1): - return matmul._call(a=a, b=b, total_programs_streamk=grid, BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K, - two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages, waves_per_eu=waves_per_eu, - mfmaInstrSize=mfmaInstrSize, kpack=kpack) - - -# --------------------------------------------------------------------------- -# Example and Benchmark -# --------------------------------------------------------------------------- - -perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3) - -#m, n, k = 4864, 4096, 8256 # some problem size to test -m, n, k = 6912, 768, 256 # some problem size to test -#m, n, k = 8192, 8192, 8192 # some problem size to test -A = torch.randn(m, k, device="cuda", dtype=torch.float16) -B = torch.randn(k, n, device="cuda", dtype=torch.float16) -#A = torch.ones((m, k), device="cuda", dtype=torch.float16) -#B = torch.ones((k, n), device="cuda", dtype=torch.float16) -BLK_M = 64 -BLK_N = 64 -BLK_K = 64 -two_tiles = 'True' -num_stages = 0 -num_warps = 4 -waves_per_eu = 0 -mfmaInstrSize = 16 -kpack = 2 - -matmul.set_debug(True) -C = matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, - kpack) -#exit(0) -matmul.set_debug(False) -expected = A @ B - -assert torch.allclose(C, expected, atol=1), f"max: {(C - expected).abs().max().item()}\n{C}\n{expected}" - -# for debugging, uncomment the following line - -triton_ms = triton.testing.do_bench(lambda: torch.matmul(A, B)) -print(f"PyTorch: {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, - num_warps, waves_per_eu, mfmaInstrSize, kpack)) -print(f"hybrid stream-k (grid={total_sm}): 
{triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm * 2, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, - num_warps, waves_per_eu, mfmaInstrSize, kpack)) -print(f"hybrid stream-k (grid={total_sm * 2}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, 0, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, - waves_per_eu, mfmaInstrSize, kpack)) -print(f"tile matmul (grid=0): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -exit(0) -# --------------------------------------------------------------------------- -# Log-sampled benchmark -# --------------------------------------------------------------------------- - -# tried to reproduce the tests described in the paper -num_samples = 1000 # 32768 -step = 256 -values = ((torch.logspace(torch.tensor(step).log2(), - torch.tensor(8192).log2(), num_samples, base=2) / step).round() * step).unique().tolist() -shapes = [(int(m), int(n), int(k)) for m in values for n in values for k in values] -shapes = random.sample(shapes, num_samples) -assert len(shapes) == num_samples - -results = [] -for idx, (m, n, k) in enumerate(shapes): - # print progress bar - if idx % 10 == 0 and idx > 0: - speedups = [r["speedup"] for r in results] - print(f"{idx}/{num_samples} - average speedup: {sum(speedups) / len(speedups):.3f}") - - A = torch.randn(m, k, device="cuda", dtype=torch.float16) - B = torch.randn(k, n, device="cuda", dtype=torch.float16) - output: Optional[torch.Tensor] = None - - def wrapper_matmul(*args, **kwargs): - global output - output = matmul.apply(*args, **kwargs) - return output - - expected = A @ B - pytorch_ms = triton.testing.do_bench(lambda: A @ B) - measures = list() - for two_tiles in [True, False]: - nb_sm = [total_sm, total_sm * 2] - total_tile = (m // BLK_M) * (n // BLK_N) - if total_tile < total_sm * 2: - nb_sm.append(total_tile) - nb_sm += random.sample(range(2, total_sm * 2, 2), 10) - for sm in nb_sm: - triton_ms = triton.testing.do_bench(lambda: wrapper_matmul( - A, B, sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack)) - max_disc = (output - expected).abs().max().item() - # large tolerance to accomodate for large K (rounding due to half precision), we just want to catch bugs. 
-            assert max_disc <= 5., f"pb size: {m}x{n}x{k} - max discrepancy: {max_disc} - sm: {sm}, 2 tiles: {two_tiles}\n{output}\n{expected}"
-            info = {
-                "2 tiles": two_tiles,
-                "sm": sm,
-                "disc": max_disc,
-                "triton_ms": triton_ms,
-            }
-            measures.append(info)
-    best_triton_ms = min([m["triton_ms"] for m in measures])
-    d = {
-        "m": m,
-        "n": n,
-        "k": k,
-        "triton": measures,
-        "pytorch_ms": pytorch_ms,
-        "speedup": pytorch_ms / best_triton_ms,
-    }
-    results.append(d)
-    measures = list()
-
-results.sort(key=lambda x: x["speedup"], reverse=False)
-
-# ---------------------------------------------------------------------------
-# Benchmark export
-# ---------------------------------------------------------------------------
-
-with open("results.json", "w") as f:
-    json.dump(results, f, indent=4)
-
-# 32760/32768 - average speedup: 0.962 (A100)
-# 990/1000 - average speedup: 1.063 (3090 RTX with while loop and 2 tiles disabled / enabled)
diff --git a/python/perf-kernels/streamk/README.md b/python/perf-kernels/streamk/README.md
new file mode 100644
index 000000000000..aa0b11d41b73
--- /dev/null
+++ b/python/perf-kernels/streamk/README.md
@@ -0,0 +1,43 @@
+# streamk gemm script v0.1
+
+The plan is to use this version as the base version for future Triton stream-k GEMM development.
+
+### Main features
+- comparable performance with tune_gemm
+
+- use a persistent loop so that a WG may work on multiple output tiles, and also allow workgroups to do part of the work for an output tile.
+
+- use an atomic spin lock to replace atomic_add for the final output.
+
+- pid renumbering based on the chiplet structure of MI300X
+
+- dynamic grid setting
+
+- tuning script adapted from tune_gemm
+
+### Usage
+
+Go to the script directory
+```bash
+cd triton/python/perf-kernels/streamk
+```
+
+1. Tune gemm sizes given in a yaml file and check correctness along the way
+```bash
+python tune_streamk.py --gemm_size_file input_gemm_sizes.yaml --compare
+```
+
+2. Tune a single gemm size
+```bash
+python tune_streamk.py -m 16 -n 16 -k 16
+```
+
+3. Choose the file to store tuning results
+```bash
+python tune_streamk.py --gemm_size_file input_gemm_sizes.yaml --o output_tuning.yaml
+```
+
+4. 
Only check correctness given the tuning results +```bash +python tune_streamk.py --gemm_size_file output_tuning.yaml --compare_wo_tuning +``` diff --git a/python/perf-kernels/streamk/streamk_kernel.py b/python/perf-kernels/streamk/streamk_kernel.py new file mode 100644 index 000000000000..138e6540e203 --- /dev/null +++ b/python/perf-kernels/streamk/streamk_kernel.py @@ -0,0 +1,206 @@ +import triton +import triton.language as tl + + +@triton.jit() +def get_new_pid(current_pid, num_cus): + # Number of XCDs + num_xcds = 8 + # Number of pids per XCD in the new arrangement + pids_per_xcd = num_cus // num_xcds + # Compute current XCD and local pid within the XCD + xcd = current_pid % num_xcds + local_pid = current_pid // num_xcds + + # Calculate new pid based on the new grouping + new_pid = xcd * pids_per_xcd + local_pid + return new_pid + + +@triton.jit() +def get_tiles_config( + M, + N, + K, + num_cus, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + total_blocks_M = tl.cdiv(M, BLOCK_SIZE_M) + total_blocks_N = tl.cdiv(N, BLOCK_SIZE_N) + iters_per_tile = tl.cdiv(K, BLOCK_SIZE_K) + + total_tiles = total_blocks_M * total_blocks_N + if num_cus > 0 and total_tiles > num_cus: # Stream-K + total_streamk_tiles = total_tiles % num_cus + total_full_tiles = total_tiles - total_streamk_tiles + total_streamk_iters = total_streamk_tiles * iters_per_tile + # iterations related to full waves + streamk_iters_pcu = total_streamk_iters // num_cus + # iterations related to last (partial) wave + streamk_remainder_iters = total_streamk_iters % num_cus + + else: # all tiles are computed using classical blocking + total_full_tiles = total_tiles + total_streamk_tiles = 0 + streamk_iters_pcu = 0 + streamk_remainder_iters = 0 + total_streamk_iters = 0 + + return iters_per_tile, total_full_tiles, total_streamk_tiles, streamk_iters_pcu, streamk_remainder_iters + + +@triton.jit() +def streamk_gemm( + A, + B, + C, + P, + locks, + M, + N, + K, + num_cus, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + EVEN_K: tl.constexpr, +): + pid = tl.program_id(0) + pid = get_new_pid(pid, num_cus) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + iters_per_tile, total_full_tiles, total_streamk_tiles, streamk_iters_pcu, streamk_remainder_iters = get_tiles_config( + M, N, K, num_cus, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K) + + acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 + rk = tl.arange(0, BLOCK_SIZE_K) + + for tile_id in range(pid, total_full_tiles, num_cus): + if GROUP_SIZE_M == 1: + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + else: + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), 
dtype=acc_dtype) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if EVEN_K: + a = tl.load(A_BASE) + b = tl.load(B_BASE) + else: + a = tl.load(A_BASE, mask=rk[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(B_BASE, mask=rk[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + acc += tl.dot(a, b) + A_BASE += BLOCK_SIZE_K * stride_ak + B_BASE += BLOCK_SIZE_K * stride_bk + + c = acc.to(C.type.element_ty) + + rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn + mask = (rm < M)[:, None] & (rn < N)[None, :] + tl.store(C_, c, mask=mask) + + start_iter = total_full_tiles * iters_per_tile + pid * streamk_iters_pcu + tl.minimum(pid, streamk_remainder_iters) + last_iter = total_full_tiles * iters_per_tile + (pid + 1) * streamk_iters_pcu + tl.minimum( + pid + 1, streamk_remainder_iters) + while start_iter < last_iter: + remainder = start_iter % iters_per_tile + end_iter = tl.minimum(start_iter + (iters_per_tile - remainder), last_iter) + # where are we in the grid + tile_id = start_iter // iters_per_tile + if GROUP_SIZE_M == 1: + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + else: + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + # rk = tl.arange(0, BLOCK_SIZE_K) + A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_SIZE_K * stride_ak * remainder + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_SIZE_K * stride_bk * remainder + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + for current_iter in range(start_iter, end_iter): + if EVEN_K: + a = tl.load(A_BASE) + b = tl.load(B_BASE) + else: + global_k_offset = (current_iter % iters_per_tile) * BLOCK_SIZE_K + k_mask = global_k_offset + rk < K + a = tl.load(A_BASE, mask=k_mask[None, :], other=0.0) + b = tl.load(B_BASE, mask=k_mask[:, None], other=0.0) + acc += tl.dot(a, b) + A_BASE += BLOCK_SIZE_K * stride_ak + B_BASE += BLOCK_SIZE_K * stride_bk + + tile_iter = tile_id * iters_per_tile + if start_iter == tile_iter: + tile_iter_end = tile_iter + iters_per_tile + next_pid = pid + 1 + end = end_iter + while (end < tile_iter_end and next_pid < num_cus): + # todo: try use tl.load once cache modifier landed upstream + while tl.atomic_cas(locks + next_pid, 1, 1) != 1: + pass + rm1 = tl.arange(0, BLOCK_SIZE_M) + rn1 = tl.arange(0, BLOCK_SIZE_N) + rm1 = tl.max_contiguous(tl.multiple_of(rm1, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn1 = tl.max_contiguous(tl.multiple_of(rn1, BLOCK_SIZE_N), BLOCK_SIZE_N) + P_ = P + next_pid * BLOCK_SIZE_M * BLOCK_SIZE_N + rm1[:, None] * BLOCK_SIZE_N + rn1[None, :] + acc += tl.load(P_) + end += streamk_iters_pcu + (next_pid < streamk_remainder_iters) + + next_pid += 1 + + c = acc.to(C.type.element_ty) + + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, 
BLOCK_SIZE_N), BLOCK_SIZE_N) + C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn + mask = (rm < M)[:, None] & (rn < N)[None, :] + tl.store(C_, c, mask=mask) + + else: + rm1 = tl.arange(0, BLOCK_SIZE_M) + rn1 = tl.arange(0, BLOCK_SIZE_N) + rm1 = tl.max_contiguous(tl.multiple_of(rm1, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn1 = tl.max_contiguous(tl.multiple_of(rn1, BLOCK_SIZE_N), BLOCK_SIZE_N) + P_ = P + pid * BLOCK_SIZE_M * BLOCK_SIZE_N + rm1[:, None] * BLOCK_SIZE_N + rn1[None, :] + tl.store(P_, acc) + tl.atomic_xchg(locks + pid, 1) + + start_iter = end_iter diff --git a/python/perf-kernels/streamk/tune_streamk.py b/python/perf-kernels/streamk/tune_streamk.py new file mode 100644 index 000000000000..3b0fbdb960c7 --- /dev/null +++ b/python/perf-kernels/streamk/tune_streamk.py @@ -0,0 +1,847 @@ +# fp8 +import argparse +import sys +import yaml +import os +import glob +import subprocess + +import torch +import triton +import triton.language as tl + +from streamk_kernel import streamk_gemm + +from datetime import datetime +import multiprocessing +import pandas as pd + +device_oi = 650. / 3.0 + + +def get_full_tuning_space(): + configs = [] + + block_mn_range = [16, 32, 64, 128, 256] + block_k_range = [16, 32, 64, 128, 256] + num_warps_range = [1, 2, 4, 8] + group_m_range = [1, 4, 8, 16, 32] + # For now we see better perf with num_stages=0 for all gemm configs we care + # But keep this explicit so that we do not forget we may need to set it to + # other values in the future + num_stage_range = [0] + waves_per_eu_range = [0] + matrix_instr_nonkdim_range = [16, 32] + kpack_range = [1, 2] + + for block_m in block_mn_range: + for block_n in block_mn_range: + for block_k in block_k_range: + for num_warps in num_warps_range: + for group_m in group_m_range: + for num_stages in num_stage_range: + for waves_per_eu in waves_per_eu_range: + for matrix_instr_nonkdim in matrix_instr_nonkdim_range: + for kpack in kpack_range: + configs.append({ + 'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, + 'GROUP_SIZE_M': group_m, 'num_warps': num_warps, 'num_stages': num_stages, + 'waves_per_eu': waves_per_eu, 'matrix_instr_nonkdim': matrix_instr_nonkdim, + 'kpack': kpack + }) + + return configs + + +def get_gemm_oi(M, N, K): + FLOPs = 2 * M * N * K + # 4 for fp32 + # to do check dtype for bytesmoved + bytesmoved = (M * K + K * N + 2 * M * N) * 4 + return FLOPs / bytesmoved + + +def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b): + pruned_configs = [] + + if M < 32 or N < 32: + mfma = 16 + else: + mfma = 32 + + # TODO (zhanglx): figure out the boundary between large and small gemms + large_gemm = False + if M >= 2048 and N >= 2048: + large_gemm = True + + for config in configs: + BLOCK_SIZE_M = config.get("BLOCK_SIZE_M") + BLOCK_SIZE_N = config.get("BLOCK_SIZE_N") + BLOCK_SIZE_K = config.get("BLOCK_SIZE_K") + num_warps = config.get("num_warps") + matrix_instr_nonkdim = config.get("matrix_instr_nonkdim") + kpack = config.get("kpack") + if matrix_instr_nonkdim > mfma: + continue + if mfma == 4 and BLOCK_SIZE_K < 64: + continue + # some layouts could not work properly in case + # number elemens per thread is less 1 + if BLOCK_SIZE_M * BLOCK_SIZE_N < 64: + continue + GROUP_M = config.get("GROUP_SIZE_M") + if BLOCK_SIZE_M < matrix_instr_nonkdim or BLOCK_SIZE_N < matrix_instr_nonkdim: + continue + if BLOCK_SIZE_K == 16 and matrix_instr_nonkdim == 16 and kpack == 2: + continue + if M <= matrix_instr_nonkdim and BLOCK_SIZE_M != matrix_instr_nonkdim: + continue + if N <= 
matrix_instr_nonkdim and BLOCK_SIZE_N != matrix_instr_nonkdim: + continue + # Skip BLOCK_SIZE that is too large compare to M/N + # unless BLOCK_SIZE is already small enough + if BLOCK_SIZE_M > M * 2 and BLOCK_SIZE_M != 16: + continue + if BLOCK_SIZE_N > N * 2 and BLOCK_SIZE_N != 16: + continue + # skip large GROUP_M + if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1: + continue + # out of shared memory resource + # TODO (zhanglx): This does not consider the LDS usage in the epilogue + LDS = BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b + if LDS > 65536: + continue + # Skip small block sizes and num_warps for large gemm + # For fp16 and f8, we want to only use BLOCK_SIZE >= 64 + if large_gemm: + if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64: + continue + if BLOCK_SIZE_K < 64: + continue + if num_warps < 4: + continue + + pruned_configs.append(config) + + return pruned_configs + + +def run_bash_command_wrapper(commandstring, capture=True): + try: + run_bash_command(commandstring, capture) + except subprocess.CalledProcessError: + if not capture: + print(f"running {commandstring} one more time") + run_bash_command(commandstring, capture) + + +def run_bash_command(commandstring, capture=True): + if capture: + proc = subprocess.run(commandstring, shell=True, check=True, executable='/bin/bash', stdout=subprocess.PIPE) + return proc.stdout.splitlines() + proc = subprocess.run(commandstring, shell=True, check=True, executable='/bin/bash') + return None + + +def read_config(config): + block_m = config.get('BLOCK_SIZE_M') + block_n = config.get('BLOCK_SIZE_N') + block_k = config.get('BLOCK_SIZE_K') + group_m = config.get('GROUP_SIZE_M') + num_warps = config.get('num_warps') + num_stages = config.get('num_stages') + waves_per_eu = config.get('waves_per_eu') + mfma_instr_size = config.get('matrix_instr_nonkdim') + kpack = config.get('kpack') + return block_m, block_n, block_k, group_m, num_warps, num_stages, waves_per_eu, mfma_instr_size, kpack + + +def gen_kernel_and_configStr_from_config(M, N, K, num_cus, EVEN_K, config, dtype_a, dtype_b, dtype_c, dtype_p, + dtype_lock): + block_m, block_n, block_k, group_m, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config(config) + torch_dtype_a = 'fp16' + torch_dtype_b = 'fp16' + torch_dtype_c = 'fp16' + torch_dtype_p = 'fp32' + torch_dtype_lock = 'int32' + if dtype_a: + torch_dtype_a = tl_to_torch_types[name_to_tl_types[dtype_a]] + if dtype_b: + torch_dtype_b = tl_to_torch_types[name_to_tl_types[dtype_b]] + if dtype_c: + torch_dtype_c = tl_to_torch_types[name_to_tl_types[dtype_c]] + if dtype_p: + torch_dtype_p = tl_to_torch_types[name_to_tl_types[dtype_p]] + if dtype_lock: + torch_dtype_lock = tl_to_torch_types[name_to_tl_types[dtype_lock]] + configStr = f"M{M}_N{N}_K{K}_BM{block_m}_BN{block_n}_BK{block_k}_GM{group_m}_nW{num_warps}_nS{num_stages}_EU{waves_per_eu}_kP{kpack}_mfma{mfmaInstrSize}" + + matmul_def_str = f""" +def matmul_{configStr}(a, b, c, P, locks, M, N, K, num_cus, am, ak, bk, bn, cm, cn, warmup=False): + grid = num_cus + #print(f'config: streamk_gemm_{configStr}', flush=True) + if warmup: + streamk_gemm_{configStr}.warmup( + {torch_dtype_a}, {torch_dtype_b}, {torch_dtype_c}, {torch_dtype_p}, {torch_dtype_lock}, + M, N, K, num_cus, + am, ak, bk, bn, cm, cn, + BLOCK_SIZE_M = {block_m}, + BLOCK_SIZE_N = {block_n}, + BLOCK_SIZE_K = {block_k}, + GROUP_SIZE_M = {group_m}, + num_warps = {num_warps}, + num_stages = {num_stages}, + waves_per_eu = {waves_per_eu}, + matrix_instr_nonkdim = 
{mfmaInstrSize}, + kpack = {kpack}, + EVEN_K = {EVEN_K}, + grid=(1,) + ) + return None + else: + streamk_gemm_{configStr}[grid,]( + a, b, c, P, locks, + M, N, K, num_cus, + am, ak, bk, bn, cm, cn, + BLOCK_SIZE_M = {block_m}, + BLOCK_SIZE_N = {block_n}, + BLOCK_SIZE_K = {block_k}, + GROUP_SIZE_M = {group_m}, + num_warps = {num_warps}, + num_stages = {num_stages}, + waves_per_eu = {waves_per_eu}, + matrix_instr_nonkdim = {mfmaInstrSize}, + kpack = {kpack}, + EVEN_K = {EVEN_K} + ) + return c + +def try_config_{configStr}(M, N, K, num_cus, am, ak, bk, bn, cm, cn): + try: + matmul_{configStr}(None, None, None, None, None, M, N, K, num_cus, am, ak, bk, bn, cm, cn, True) + return True + except Exception as e: + print(f'invalid config(compilation): {configStr}: ', e, flush=True) + return False +""" + return configStr, matmul_def_str + + +def generated_kernel_name(M, N, K, gpu_id): + return f"generated_kernel{M}-{N}-{K}-{gpu_id}.py" + + +# Open {len(gpus)} files +# generated_kernelM-N-K-{gpus[0]}.py, generated_kernelM-N-K-{gpus[1]}.py, ..., generated_kernelM-N-K-{gpus[-1]}.py +# and generate +# 1. matmul kernels of all configs +# 2. wrapper function matmul to invoke all the generated kernels +# 3. Another wraper function try_config to invoke matmul function +# 4. test_gemm to invoke +# 4.1 run try_config in parallel +# 4.2 matmul in a loop of 10 iterations +def generate_kernel(M, N, K, num_cus, col_a, col_b, dtype_a, dtype_b, dtype_c, dtype_p, dtype_lock, init_type, configs, + jobs, iters, run_bench): + filenames = [] + for i in range(jobs): + filenames.append(generated_kernel_name(M, N, K, i)) + f_kernel = [open(path, 'w') for path in filenames] + + # write imports + import_str = """import torch +import triton +import triton.language as tl +import argparse +import sys +import multiprocessing +from tune_streamk import gen_input +""" + for fi in range(jobs): + f_kernel[fi].write(import_str + "\n") + + # write definitions of streamk_gemm_xxx + # and matmul_xxx and try_config + with open("streamk_kernel.py") as file: + streamk_gemm_code = file.read() + idx = 0 + for config in configs: + file_idx = idx % jobs + EVEN_K = True if K % config.get('BLOCK_SIZE_K') == 0 else False + configStr, matmul_def_str = gen_kernel_and_configStr_from_config(M, N, K, num_cus, EVEN_K, config, dtype_a, + dtype_b, dtype_c, dtype_p, dtype_lock) + # Copy the streamk_gemm with name replaced + streamk_gemm_config = streamk_gemm_code.replace("streamk_gemm", f"streamk_gemm_{configStr}") + streamk_gemm_config = streamk_gemm_config.replace("import triton.language as tl", "") + streamk_gemm_config = streamk_gemm_config.replace("import triton", "") + f_kernel[file_idx].write(streamk_gemm_config + "\n\n") + f_kernel[file_idx].write(matmul_def_str + "\n") + idx += 1 + + # write test_gemm + # pre string + block_m = config.get('BLOCK_SIZE_M') + block_n = config.get('BLOCK_SIZE_N') + test_gemm_pre_str = f"""def test_gemm(M, N, K, num_cus, num_threads): + thread_pool = multiprocessing.Pool(processes=num_threads) + a, a_fp16 = gen_input(M, K, '{dtype_a}', {col_a}, 1, '{init_type}', device='cuda') + b, b_fp16 = gen_input(K, N, '{dtype_b}', {col_b}, 2, '{init_type}', device='cuda') + c = torch.zeros((M, N), device=a.device, dtype={tl_to_torch_types[name_to_tl_types[dtype_c]]}) + task_args = (M, N, K, num_cus, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1)) + + if num_threads > 1: + results = [] + config_names = [] +""" + for fi in range(jobs): + f_kernel[fi].write(test_gemm_pre_str + "\n") + + # warm up 
call of all matmul functions in parallel + idx = 0 + for config in configs: + EVEN_K = True if K % config.get('BLOCK_SIZE_K') == 0 else False + configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, num_cus, EVEN_K, config, None, None, None, None, + None) + task_str = f" results += [thread_pool.apply_async(try_config_{configStr}, args=task_args)]\n" + \ + f" config_names += ['{configStr}']\n" + f_kernel[idx % jobs].write(task_str) + idx += 1 + + for fi in range(jobs): + threadpool_str = """ + failed_configs = [] + for i in range(len(results)): + results[i].wait() + res = results[i].get() + if not res: + failed_configs += [config_names[i]] + thread_pool.close() + thread_pool.join() + with open("{filename}.failed_configs", "w") as f: + for cfg in failed_configs: + f.write(cfg + "\\n") + else: + try: + with open("{filename}.failed_configs", "r") as f: + failed_configs = [cfg.strip() for cfg in f.readlines()] + except Exception: + failed_configs = [] + """.format(filename=filenames[fi]) + f_kernel[fi].write(threadpool_str) + # call all matmul_xxx functions + idx = 0 + runs = iters if run_bench else 200 + for config in configs: + EVEN_K = True if K % config.get('BLOCK_SIZE_K') == 0 else False + configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, num_cus, EVEN_K, config, None, None, None, None, + None) + block_m = config.get('BLOCK_SIZE_M') + block_n = config.get('BLOCK_SIZE_N') + matmul_call_str = f""" + if '{configStr}' not in failed_configs: + print(f"{configStr}") + for i in range({runs}): + locks = torch.zeros((num_cus,), device = "cuda", dtype = torch.int32) + P = torch.zeros((num_cus, {block_m}*{block_n}), device="cuda", dtype=torch.float32) + d = matmul_{configStr}(a, b, c, P, locks, M, N, K, num_cus, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1))""" + f_kernel[idx % jobs].write(matmul_call_str + "\n") + idx += 1 + # post string + for fi in range(jobs): + f_kernel[fi].write(" return d\n") + + # def main and call test_gemm + def_main_str = """ +def main(): + parser = argparse.ArgumentParser( + prog="tune a specific gemm size", + allow_abbrev=False,) + parser.add_argument("-n", type=int, default=1, help='number of threads') + args = parser.parse_args() + numThreads = args.n + num_cus = 304 + """ + test_gemm_call_str = f'test_gemm({M}, {N}, {K}, num_cus, numThreads)' + for fi in range(jobs): + f_kernel[fi].write(def_main_str) + f_kernel[fi].write(test_gemm_call_str + "\n\n") + f_kernel[fi].write("""if __name__ == '__main__': + sys.exit(main())""") + f_kernel[fi].close() + + +def extract_kernel_time(M, N, K, num_cus, EVEN_K, config, df): + # Correct the header by removing 'sig' and 'obj' to reduce number from 21 to 19 + # once the bug is fixed, we should not need below two lines + cols = [ + 'Index', 'KernelName', 'gpu-id', 'queue-id', 'queue-index', 'pid', 'tid', 'grd', 'wgr', 'lds', 'scr', + 'arch_vgpr', 'accum_vgpr', 'sgpr', 'wave_size', 'DispatchNs', 'BeginNs', 'EndNs', 'CompleteNs' + ] + df.columns = cols + + configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, num_cus, EVEN_K, config, None, None, None, None, None) + + filtered_df = df[df['KernelName'].str.contains(configStr, na=False)].copy() + filtered_df['DurationNs'] = filtered_df['EndNs'] - filtered_df['BeginNs'] + meanTime = filtered_df['DurationNs'].tail(100).mean() + return config, meanTime + + +def profile_batch_kernels(M, N, K, num_cus, gpuid, gpus, jobs, verbose): + ngpus = len(gpus) + gpuIdx = gpus.index(gpuid) + if gpuIdx + 1 > jobs: + return + 
os.environ['ROCR_VISIBLE_DEVICES'] = str(gpuid) + jobId = gpuIdx + while jobId < jobs: + if verbose: + print(f"profiling {generated_kernel_name(M, N, K, jobId)} on GPU {gpuid}") + run_bash_command_wrapper( + f"rocprofv2 --plugin file --plugin-version 1 --kernel-trace -o {jobId} python {generated_kernel_name(M, N, K, jobId)}", + capture=(verbose < 2)) + jobId += ngpus + + +def tune_gemm_config(M, N, K, num_cus, col_a, col_b, dtype_a, dtype_b, dtype_c, dtype_p, dtype_lock, init_type, configs, + run_bench, jobs, iters, skipWarmup, verbose=0, num_threads=16, gpus=[0]): + # Generate kernel out of all configs + generate_kernel(M, N, K, num_cus, col_a, col_b, dtype_a, dtype_b, dtype_c, dtype_p, dtype_lock, init_type, configs, + jobs, iters, run_bench) + + # remove any compiled kernel in the cache + run_bash_command("rm -rf ~/.triton/cache") + + # precompile the kernels in parallel + start_time = datetime.now() + if not skipWarmup: + for i in range(jobs): + run_bash_command(f"python {generated_kernel_name(M, N, K, i)} -n {num_threads}", capture=(verbose < 2)) + compile_end = datetime.now() + compile_time = compile_end - start_time + if verbose: + print(f"compile time: {compile_time}", flush=True) + + # profile generated kernels + running = [ + multiprocessing.Process(target=profile_batch_kernels, args=(M, N, K, num_cus, gpu_id, gpus, jobs, verbose)) + for gpu_id in gpus + ] + for p in running: + p.start() + for p in running: + p.join() + + profile_end = datetime.now() + profile_time = profile_end - compile_end + if verbose: + print(f"profile time: {profile_time}", flush=True) + + # post process results.csv to get the best config and minTime + # TODO: process the file in parallel + minTime = 1024 * 1024 * 1024 + thread_pool = multiprocessing.Pool(processes=num_threads) + tasks = [] + idx = 0 + df_prof = [ + pd.read_csv(f"results_{i}.csv", skiprows=1, header=None, delimiter=',', quotechar='"', escapechar='\\') + for i in range(jobs) + ] + for config in configs: + EVEN_K = True if K % config.get('BLOCK_SIZE_K') == 0 else False + file_idx = idx % jobs + tasks += [ + thread_pool.apply_async(extract_kernel_time, args=(M, N, K, num_cus, EVEN_K, config, df_prof[file_idx])) + ] + idx += 1 + thread_pool.close() + thread_pool.join() + + for task in tasks: + config, myTime = task.get() + if myTime: + min_us = myTime / 1000 + if min_us < minTime: + minTime = min_us + bestConfig = config + else: + min_us = -1 + print(f"invalid config(post processing): SIZE {M} {N} {K}: {config}", flush=True) + post_end = datetime.now() + post_time = post_end - profile_end + if verbose: + print(f"post procesing time: {post_time}", flush=True) + return minTime, bestConfig, compile_time, profile_time, post_time + + +def gen_input(M, N, ty_name, needTrans, seed, init_type, device='cuda'): + d_type = name_to_tl_types[ty_name] + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @triton.jit + def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + input = tl.load(input_ptr + offsets, mask=mask) + output = input + tl.store(output_ptr + offsets, output, mask=mask) + + def init_by_size_and_type(size, dtype, init_type): + if init_type == 'hpl': + return torch.empty(size, device='cuda', dtype=dtype).uniform_(-0.5, 0.5) + # This init type has element[i] in row[j] equal to sin(i+j*N) + elif init_type == 'trig_float': + M, N = size + return torch.reshape(torch.arange(0, M * N), (M, N)).sin().to(dtype=dtype, 
device='cuda') + elif init_type == 'zeros': + return torch.zeros(size, dtype=dtype, device='cuda') + elif init_type == "randn": + temp = torch.randn(size, dtype=dtype, device='cuda') + return temp + else: + raise ValueError("Bad matrix initialization type.") + + raw_data = init_by_size_and_type((N, M) if needTrans else (M, N), torch.float32, init_type) + if needTrans: + raw_data = raw_data.T + if (d_type == tl.float8e4b8 and TORCH_HAS_FP8E4B8) or \ + (d_type == tl.float8e5b16 and TORCH_HAS_FP8E5B16) or not d_type.is_fp8(): + input = raw_data.to(tl_to_torch_types[d_type]) + input_f16 = input.to(torch.float16) + else: + f8_tensor = raw_data.to(torch.int8) + # keep only two bits of exponent to avoid overflow + f8_tensor = f8_tensor & 0b00111111 + input = triton.reinterpret(f8_tensor, d_type) + input_f16 = torch.empty_like(f8_tensor, dtype=torch.float16) + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + n_elements = raw_data.numel() + copy_kernel[grid](input, input_f16, n_elements, BLOCK_SIZE=1024) + + return input, input_f16 + + +def matmul(a, b, c, P, locks, num_cus, block_m, block_n, block_k, group_m, num_warps, num_stages, waves_per_eu, + mfmaInstrSize, kpack, EVEN_K): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + #assert a.is_contiguous(), "Matrix A must be contiguous" + #assert b.is_contiguous(), "Matrix B must be contiguous" + M, K = a.shape + K, N = b.shape + # 1D launch kernel where each block gets its own program. + + grid = num_cus + + streamk_gemm[ + grid, + ](a, b, c, P, locks, M, N, K, num_cus, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), + BLOCK_SIZE_M=block_m, BLOCK_SIZE_N=block_n, BLOCK_SIZE_K=block_k, GROUP_SIZE_M=group_m, num_warps=num_warps, + num_stages=num_stages, waves_per_eu=waves_per_eu, matrix_instr_nonkdim=mfmaInstrSize, kpack=kpack, EVEN_K=EVEN_K) + return c + + +def test_correctness(M, N, K, num_cus, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, config, verbose): + block_m, block_n, block_k, group_m, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config(config) + torch.manual_seed(0) + #a = torch.randn((M, K), device='cuda', dtype=datatype) + #b = torch.randn((K, N), device='cuda', dtype=datatype) + a, a_fp16 = gen_input(M, K, dtype_a, col_a, 1, init_type, device='cuda') + b, b_fp16 = gen_input(K, N, dtype_b, col_b, 2, init_type, device='cuda') + # Allocates output. 
+ print(f"{block_k}") + EVEN_K = K % block_k == 0 + c = torch.zeros((M, N), device=a.device, dtype=tl_to_torch_types[name_to_tl_types[dtype_c]]) + locks = torch.zeros((num_cus, ), device="cuda", dtype=torch.int32) + P = torch.zeros((num_cus, block_m * block_n), device="cuda", dtype=torch.float32) + triton_output = matmul(a, b, c, P, locks, num_cus, block_m, block_n, block_k, group_m, num_warps, num_stages, + waves_per_eu, mfmaInstrSize, kpack, EVEN_K) + torch_output = torch.matmul(a_fp16, b_fp16) + # print(f"triton_output={triton_output}") + # print(f"torch_output={torch_output}") + rtol = 0 if torch.version.hip is None else 1e-2 + atol = 1e-3 + row_a_str = 'N' if col_a else 'T' + row_b_str = 'N' if col_b else 'T' + size_str = '' + if verbose: + size_str = f'SIZE M: {M}, N: {N}, K: {K}, trans: {row_a_str}{row_b_str}' + if torch.allclose(triton_output.to(torch.float16), torch_output, atol=atol, rtol=rtol): + print(f'{size_str} Correct✅') + else: + print(f'{size_str} Incorrect❌') + + +def get_default_tuning_result_filename(): + git_branch_name = run_bash_command("git rev-parse --abbrev-ref HEAD") + git_branch_name = git_branch_name[0].decode() + git_commit_hash = run_bash_command("git rev-parse --short HEAD") + git_commit_hash = git_commit_hash[0].decode() + + dt_string = datetime.now().strftime("%m-%d-%Y-%H:%M:%S") + defaultName = f"tuning_results_{git_branch_name}@{git_commit_hash}_{dt_string}.yaml" + return defaultName + + +def parse_args(): + parser = argparse.ArgumentParser( + prog="tune a specific gemm size", + allow_abbrev=False, + ) + + parser.add_argument("-m", type=int, default=0) + parser.add_argument("-n", type=int, default=0) + parser.add_argument("-k", type=int, default=0) + parser.add_argument("-col_a", action='store_true', default=False, help='whether matrix a is column major') + parser.add_argument("-col_b", action='store_true', default=False, help='whether matrix b is column major') + parser.add_argument("-dtype_a", type=str, default='fp16', help="matrix a element data type") + parser.add_argument("-dtype_b", type=str, default='fp16', help="matrix b element data type") + parser.add_argument("-dtype_c", type=str, default='fp16', help="output element data type") + parser.add_argument("--ngpus", type=int, default=0, help='number of GPUs used in the profiling step') + parser.add_argument("--gpu_ids", type=lambda s: [int(id) for id in s.split(',')], default=[], + help='list of gpu ids to use for tuning') + parser.add_argument("--gemm_size_file", type=str, default="", help='yaml file to indicate matrix size') + parser.add_argument("--o", type=str, default=get_default_tuning_result_filename(), + help='yaml file to store tuning results') + parser.add_argument("--keep", action='store_true', default=False, help='keep generated files') + parser.add_argument("--compare", action='store_true', default=False, help="Whether check result correctness") + parser.add_argument("--compare_wo_tuning", action='store_true', default=False, + help="Whether check result correctness") + parser.add_argument("--benchmark", action='store_true', default=False, help="Benchmark the given config") + parser.add_argument("--time_breakdown", action='store_true', default=False, + help="Show detailed time breakdown of each step during the tuning") + parser.add_argument("--verbose", action='store_true', default=False, + help="enables time_breakdown and additional logging messages") + parser.add_argument("--num_threads", type=int, default=16, + help="number of threads to use for kernel compilation and post 
processing") + parser.add_argument("--jobs", type=int, default=1, help="number of generated files") + parser.add_argument("--iters", type=int, default=1000, help="number of generated files") + parser.add_argument("--init_type", type=str, default='randn', + help="Initialization type for input matrices (default uniform rand [0, 1.0)])") + parser.add_argument("--no_warmup", action='store_true', default=False, help="Do not call the warmup kernel") + args = parser.parse_args() + + return args + + +TORCH_HAS_FP8E5B16 = hasattr(torch, 'float8_e5m2fnuz') +TORCH_HAS_FP8E4B8 = hasattr(torch, 'float8_e4m3fnuz') +tl_to_torch_types = { + tl.float16: torch.float16, + tl.bfloat16: torch.bfloat16, + tl.float32: torch.float32, + tl.int8: torch.int8, + tl.int32: torch.int32, +} +if TORCH_HAS_FP8E5B16: + tl_to_torch_types[tl.float8e5b16] = torch.float8_e5m2fnuz +if TORCH_HAS_FP8E4B8: + tl_to_torch_types[tl.float8e4b8] = torch.float8_e4m3fnuz + +name_to_tl_types = { + 'int8': tl.int8, + 'int32': tl.int32, + 'fp16': tl.float16, + 'fp32': tl.float32, + 'bf16': tl.bfloat16, + 'fp8': tl.float8e4b8, + 'bf8': tl.float8e5b16, +} + + +def process_item(item): + M = item['M'] + N = item['N'] + K = item['K'] + col_a = False if item['rowMajorA'] == 'T' else True + col_b = False if item['rowMajorB'] == 'T' else True + del item['M'] + del item['N'] + del item['K'] + del item['rowMajorA'] + del item['rowMajorB'] + return M, N, K, col_a, col_b, item + + +def type_name_to_bytes(ty_name): + if '32' in ty_name: + return 4 + if '16' in ty_name: + return 2 + if '8' in ty_name: + return 1 + else: + print(f"Unrecognized input type name {ty_name}") + sys.exit(1) + + +def format_output(unformatted): + if unformatted < 0.0001: + formatted = "{:.3e}".format(unformatted) + elif unformatted > 1000: + formatted = "{:.1f}".format(unformatted) + else: + formatted = "{:.2f}".format(unformatted) + return formatted + + +def main(): + args = parse_args() + matrix_size_file = args.gemm_size_file + tuning_output_file = args.o + keepTmp = args.keep + run_bench = args.benchmark + jobs = args.jobs + iters = args.iters + skipWarmup = args.no_warmup + num_cus = 304 + + # Get GPU ids + ngpus = args.ngpus + gpu_ids = args.gpu_ids + if ngpus != 0 and gpu_ids: + print("--ngpus and --gpu_ids are mutually exclusive options") + return os.EX_USAGE + if ngpus == 0 and not gpu_ids: + ngpus = 1 + if ngpus != 0: + gpus = range(ngpus) + if gpu_ids: + gpus = gpu_ids + + if run_bench: + gpus = [gpus[0]] + jobs = 1 + + # Get element type + dtype_a = args.dtype_a + dtype_b = args.dtype_b + dtype_c = args.dtype_c + dtype_p = 'fp32' + dtype_lock = 'int32' + if dtype_a not in name_to_tl_types or dtype_b not in name_to_tl_types or dtype_c not in name_to_tl_types: + print(f"Unsupported dtype_a {args.dtype_a} or dtype_b {args.dtype_b} or dtype_c {args.dtype_c}") + print("Supported types: ", list(name_to_tl_types.keys())) + sys.exit(1) + + mnks = [] + # TODO: make it more robust to get user input + init_type = args.init_type + if matrix_size_file == "" or not os.path.isfile(matrix_size_file): + M = args.m + N = args.n + K = args.k + col_a = args.col_a + col_b = args.col_b + mnks = [(M, N, K, col_a, col_b, None)] + else: + with open(matrix_size_file) as file: + matrix_sizes = yaml.safe_load(file) + for item in matrix_sizes: + M, N, K, col_a, col_b, item = process_item(item) + mnks.append((M, N, K, col_a, col_b, item)) + + # Check correctness from given configs + if args.compare_wo_tuning: + for (M, N, K, col_a, col_b, myConfig) in mnks: + test_correctness(M, N, K, num_cus, 
col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, myConfig, True) + return + + configs_full = get_full_tuning_space() + + start_time = datetime.now() + if run_bench: + print(f"Benchmarking gemm with {dtype_a} inputs") + print("trans M N K TFLOPS us") + else: + print(f"Tuning {len(mnks)} gemm sizes starts at: {start_time}", flush=True) + f_results = open(tuning_output_file, 'w') + + for (M, N, K, col_a, col_b, myConfig) in mnks: + start_local_time = datetime.now() + # Obtain a pruned tuning space according to gemm size + # If running benchmark, use the provided config + pruned_configs = [myConfig] if run_bench else prune_configs(M, N, K, configs_full, type_name_to_bytes(dtype_a), + type_name_to_bytes(dtype_b)) + + row_a_str = 'N' if col_a else 'T' + row_b_str = 'N' if col_b else 'T' + size_str = f'SIZE: {M} {N} {K} {row_a_str}{row_b_str}' + if not run_bench: + print(f"{size_str} nConfigs: {len(pruned_configs)}", end=" ", flush=True) + else: + print(f"{row_a_str}{row_b_str} {M:5d} {N:5d} {K:5d} ", end="") + + # The main tuning funtion for one gemm size + verbose_level = 0 + if args.time_breakdown: + verbose_level = 1 + if args.verbose: + verbose_level = 2 + minTime, bestConfig, compile_time, profile_time, post_time = tune_gemm_config( + M, N, K, num_cus, col_a, col_b, dtype_a, dtype_b, dtype_c, dtype_p, dtype_lock, init_type, pruned_configs, + run_bench, jobs, iters, skipWarmup, num_threads=args.num_threads, gpus=gpus, verbose=verbose_level) + + EVEN_K = True if K % bestConfig.get('BLOCK_SIZE_K') == 0 else False + # post processing the numbers + perf_tflops = lambda us: 2 * M * N * K * 1e-12 / (us * 1e-6) + tri_tflops = perf_tflops(minTime) + formatted_tflops = format_output(tri_tflops) + minTime = format_output(minTime) + if not run_bench: + print(f'TFLOPS: {formatted_tflops} time(us): {minTime}', end=" ", flush=True) + + bestConfig_compact_str, _ = gen_kernel_and_configStr_from_config(M, N, K, num_cus, EVEN_K, bestConfig, None, + None, None, None, None) + if not run_bench: + print(f'best_config: {bestConfig_compact_str}', end=" ", flush=True) + + # write best config to tuning_results.yaml + if run_bench: + print(f"{formatted_tflops} {minTime}") + + sizeDict = {'M': M, 'N': N, 'K': K, 'rowMajorA': row_a_str, 'rowMajorB': row_b_str} + sizeDict.update(bestConfig) + if not run_bench: + f_results.write("- " + str(sizeDict) + " ") + f_results.write(f'# TFLOPS: {formatted_tflops} time(us): {minTime}\n') + + # remove generated files if asked to + if not keepTmp: + for i in range(jobs): + generated_script = generated_kernel_name(M, N, K, i) + os.remove(generated_script) + if not skipWarmup: + os.remove(generated_script + ".failed_configs") + for f in glob.glob(f"results_{i}.*"): + os.remove(f) + + # Check correctness if asked to + if args.compare: + print("correctness: ", end=" ", flush=True) + test_correctness(M, N, K, num_cus, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, bestConfig, False) + elif not run_bench: + print("", flush=True) + + end_local_time = datetime.now() + if not run_bench: + print( + f">>> Elapsed time: {end_local_time - start_local_time} = {compile_time} (compile) + {profile_time} (profile) + {post_time} (post processing)", + flush=True) + + if not run_bench: + f_results.close() + + end_time = datetime.now() + tuning_time = end_time - start_time + if not run_bench: + print(f"Tuning ends at: {end_time}") + print(f"Total tuning time (h:m:s): {tuning_time}") + + +if __name__ == '__main__': + sys.exit(main()) From 1d2e06681f9cb912086cade50cc53338d2a490b8 Mon Sep 17 00:00:00 
2001 From: Bruno Mazzotti Date: Tue, 6 Aug 2024 14:20:12 -0300 Subject: [PATCH 09/12] Add explicit multiply-reduce GEMM kernel (#621) * Add explicit multiply-reduce GEMM kernel * Remove `SPLIT_K` argument from kernel * Remove `GROUP_SIZE_M` argument from kernel * Remove conditional call to `tl.dot` from kernel * Remove table with performance data from README --- python/perf-kernels/README.md | 8 ++++ .../perf-kernels/multreduce_matmul_kernel.py | 45 +++++++++++++++++++ 2 files changed, 53 insertions(+) create mode 100644 python/perf-kernels/multreduce_matmul_kernel.py diff --git a/python/perf-kernels/README.md b/python/perf-kernels/README.md index 5bcedbf49cdd..b8f930ef94ea 100644 --- a/python/perf-kernels/README.md +++ b/python/perf-kernels/README.md @@ -61,3 +61,11 @@ fp32, bf16 and f8 (both e5m2 and e4m3) datatypes. ## `03-matrix-multiplication-stream-k.py` This script contains the GEMM kernel that implements [stream-k](https://arxiv.org/abs/2301.03598) + +## `multreduce_matmul_kernel.py` + +Kernel that implements GEMM with explicit multiply-reduce instructions for small block sizes. Such +small block sizes aren't natively supported by `tl.dot` operator. + +Despite being numerically correct, this kernel performed worse than a corresponding GEMM kernel that +used `tl.dot` with minimum block size equal to $16$. diff --git a/python/perf-kernels/multreduce_matmul_kernel.py b/python/perf-kernels/multreduce_matmul_kernel.py new file mode 100644 index 000000000000..61535d5bcdd3 --- /dev/null +++ b/python/perf-kernels/multreduce_matmul_kernel.py @@ -0,0 +1,45 @@ +import triton +import triton.language as tl + + +# Kernel that implements GEMM with explicit multiply-reduce instructions for small block sizes. +# Based on **tune_gemm** `matmul_kernel` from commit `cf44637` (see `triton-mlir` branch). 
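# Editorial sketch (not part of this patch): a hypothetical host-side launcher for the
# multreduce kernel defined just below, showing how a 1D grid and the argument list line
# up with the kernel signature. The wrapper name `multreduce_matmul` and the tiny default
# block sizes are assumptions for illustration only; the patch itself adds just the
# @triton.jit kernel. Assumes `multreduce_matmul_kernel` (below) is in scope.
import torch
import triton

def multreduce_matmul(a, b, bias=None, BLOCK_SIZE_M=4, BLOCK_SIZE_N=8, BLOCK_SIZE_K=8):
    # C = A @ B (+ optional per-row bias), one program per BLOCK_SIZE_M x BLOCK_SIZE_N tile.
    M, K = a.shape
    Kb, N = b.shape
    assert K == Kb, "Incompatible dimensions"
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), )
    multreduce_matmul_kernel[grid](
        a, b, c, bias, M, N, K,
        a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),
        bias.stride(0) if bias is not None else 0,
        BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
        BIAS=bias is not None, EVEN_K=(K % BLOCK_SIZE_K == 0),
    )
    return c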
+@triton.jit +def multreduce_matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, + stride_cm, stride_cn, stride_bias, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, BIAS: tl.constexpr, EVEN_K: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + offs_k = tl.arange(0, BLOCK_SIZE_K) + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) + a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn + if BIAS: + bias_ptrs = bias_ptr + offs_am * stride_bias + bias = tl.load(bias_ptrs, mask=offs_am < M, other=0.0) + acc_dtype = tl.float32 if a_ptr.type.element_ty != tl.int8 else tl.int32 + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if EVEN_K: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + else: + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # Dot product implemented as explicit multiply-reduce: + a = tl.reshape(a, (BLOCK_SIZE_M, BLOCK_SIZE_K, 1)).to(acc_dtype) + b = tl.reshape(b, (1, BLOCK_SIZE_K, BLOCK_SIZE_N)).to(acc_dtype) + accumulator += tl.sum(a * b, axis=1) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + c = accumulator.to(c_ptr.type.element_ty) + if BIAS: + c += bias[:, None] + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) From 51d0d9201bbcd7479468958e006ff22090eec5a2 Mon Sep 17 00:00:00 2001 From: Joseph Groenenboom Date: Thu, 8 Aug 2024 15:10:22 +0000 Subject: [PATCH 10/12] Add support for causal masking as a toggle and more datatype support --- python/perf-kernels/flash-attention.py | 34 +++++++++++++++++--------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py index 988438340abe..faac1fe7d123 100644 --- a/python/perf-kernels/flash-attention.py +++ b/python/perf-kernels/flash-attention.py @@ -632,14 +632,14 @@ def _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, do = tl.load(DO_block_ptr) # Compute dV. ppT = pT - ppT = ppT.to(tl.float16) + ppT = ppT.to(do.dtype) dv += tl.dot(ppT, do) # D (= delta) is pre-divided by ds_scale. Di = tl.load(D + offs_m) # Compute dP and dS. dpT = tl.dot(v, tl.trans(do)) dsT = pT * (dpT - Di[None, :]) - dsT = dsT.to(tl.float16) + dsT = dsT.to(qT.dype) dk += tl.dot(dsT, tl.trans(qT)) # Increment pointers. curr_m += step_m @@ -685,7 +685,7 @@ def _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope, vT = tl.load(VT_block_ptr) dp = tl.dot(do, vT).to(tl.float32) ds = p * (dp - Di[:, None]) - ds = ds.to(tl.float16) + ds = ds.to(kT.dtype) # Compute dQ.0. # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. 
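# Editorial gloss (not part of the patch) on the de-scaling note above, stated as an
# assumption from the surrounding code: with the exp2 trick, the host pre-multiplies K by
# sm_scale * 1/ln(2), so inside this loop
#   qk = q @ (sm_scale/ln(2) * K)^T   and   p = exp2(qk - m)
# reproduce the forward softmax probabilities. The accumulated
#   dq += ds @ (sm_scale/ln(2) * K)
# therefore carries an extra 1/ln(2) factor, which the epilogue cancels via `dq *= LN2`,
# leaving the mathematically expected dq = sm_scale * dS @ K.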
dq += tl.dot(ds, tl.trans(kT)) @@ -765,14 +765,14 @@ def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D, # compute dK and dV for blocks close to the diagonal that need to be masked num_steps = BLOCK_N1 // MASK_BLOCK_M1 dk, dv = _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, stride_tok, stride_d, H, N_CTX, - MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, start_n, start_m, num_steps, MASK=True) + MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, start_n, start_m, num_steps, MASK=CAUSAL) # compute dK and dV for blocks that don't need masking further from the diagonal start_m += num_steps * MASK_BLOCK_M1 num_steps = (N_CTX - start_m) // BLOCK_M1 dk, dv = _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, stride_tok, stride_d, H, N_CTX, - BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, start_n, start_m, num_steps, MASK=False) + BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, start_n, start_m, num_steps, MASK=CAUSAL) DV_block_ptrs = tl.make_block_ptr(base=DV, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), offsets=(start_n, 0), block_shape=(BLOCK_N1, BLOCK_DMODEL), order=(1, 0)) @@ -999,6 +999,7 @@ def backward(ctx, do, _): q.stride(3), N_HEAD, N_CTX, + CAUSAL=ctx.causal, BLOCK_DMODEL=ctx.BLOCK_DMODEL, BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, @@ -1261,10 +1262,11 @@ def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16 ]) @pytest.mark.parametrize('qseqlen_not_equal_kseqlen', [None]) @pytest.mark.parametrize('torch_sdpa_test', [False, True]) -@pytest.mark.parametrize('causal', [True]) +@pytest.mark.parametrize('causal', [False,True]) @pytest.mark.parametrize('use_alibi', [False, True]) +@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sdpa_test, use_alibi, - dtype=torch.float16): + dtype): pytest.skip() torch.manual_seed(20) if qseqlen_not_equal_kseqlen is not None: @@ -1396,6 +1398,15 @@ def varlen_benchmark_configs(): ] return configs +def nonvarlen_backward_benchmark_configs(): + configs=[(16, 16, 16, 1024, 1024), + (8, 16, 16, 2048, 2048), + (4, 16, 16, 4096, 4096), + (2, 16, 16, 8192, 8192), + (1, 16, 16, 16384, 16384), + (2, 48, 48, 1024, 1024), + ] + return configs def run_benchmark(custom, args): @@ -1403,7 +1414,7 @@ def run_benchmark(custom, args): hk = args.hq if not args.hk else args.hk sk = args.sq if not args.sk else args.sk head_size = 128 if not args.d else args.d - mode = 'fwd' + mode = args.direction x_names = ['BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K'] causal = args.causal varlen = args.layout == 'thd' @@ -1413,6 +1424,8 @@ def run_benchmark(custom, args): else: if varlen: x_vals_list = varlen_benchmark_configs() + elif mode == 'bwd': + x_vals_list = nonvarlen_backward_benchmark_configs() else: x_vals_list = nonvarlen_benchmark_configs() print_time = args.return_time @@ -1436,10 +1449,6 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal # bias = None # bias = None - # Bwd pass only supports causal=True right now - if mode == 'bwd': - causal = True - flops_per_matmul = 0 if varlen: q, k, v, input_metadata = varlen_input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, @@ -1502,6 +1511,7 @@ def parse_args(): parser.add_argument("-dtype", default='fp16') parser.add_argument("-return_time", action='store_true', default=False) parser.add_argument("-layout", type=str, default='bhsd', help=supported_layouts()) + parser.add_argument("-direction", default='fwd') return parser.parse_args() From 
ae4633c4e7e12bcd17710635b0516958644c2c50 Mon Sep 17 00:00:00 2001 From: Joseph Groenenboom Date: Thu, 8 Aug 2024 12:32:57 -0500 Subject: [PATCH 11/12] Unify with new forward tests and set num_stages --- python/perf-kernels/flash-attention.py | 132 +++-- python/tutorials/06-fused-attention.py | 753 +++++++++++++++---------- 2 files changed, 525 insertions(+), 360 deletions(-) diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py index faac1fe7d123..58cf45024dcd 100644 --- a/python/perf-kernels/flash-attention.py +++ b/python/perf-kernels/flash-attention.py @@ -316,7 +316,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri num_warps=4), ], key=['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'], - use_cuda_graph=True, + #use_cuda_graph=True, ) @triton.jit def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, @@ -639,7 +639,7 @@ def _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, # Compute dP and dS. dpT = tl.dot(v, tl.trans(do)) dsT = pT * (dpT - Di[None, :]) - dsT = dsT.to(qT.dype) + dsT = dsT.to(qT.dtype) dk += tl.dot(dsT, tl.trans(qT)) # Increment pointers. curr_m += step_m @@ -695,13 +695,12 @@ def _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope, VT_block_ptr = tl.advance(VT_block_ptr, (0, step_n)) return dq - @triton.jit def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D, # shared by Q/K/V/DO. stride_z, stride_h, stride_tok, stride_d, # H = 16, N_CTX = 1024 - H, N_CTX, BLOCK_DMODEL: tl.constexpr, BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr, + H, N_CTX, CAUSAL: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr, BLOCK_M2: tl.constexpr, BLOCK_N2: tl.constexpr, BLK_SLICE_FACTOR: tl.constexpr, USE_ALIBI: tl.constexpr): LN2: tl.constexpr = 0.6931471824645996 # = ln(2) @@ -943,6 +942,7 @@ def backward(ctx, do, _): BLOCK = 64 else: BLOCK = 128 + num_stages = 1 q, k, v, o, M = ctx.saved_tensors assert do.is_contiguous() assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() @@ -1007,6 +1007,7 @@ def backward(ctx, do, _): BLOCK_N2=BLOCK_N2, BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, USE_ALIBI=False if ctx.alibi_slopes is None else True, + num_stages = 1, ) return dq, dk, dv, None, None @@ -1260,92 +1261,86 @@ def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16 #(1, 16, 8192, 63), #(1, 16, 1022, 64), ]) -@pytest.mark.parametrize('qseqlen_not_equal_kseqlen', [None]) -@pytest.mark.parametrize('torch_sdpa_test', [False, True]) -@pytest.mark.parametrize('causal', [False,True]) +@pytest.mark.parametrize('causal', [False, True]) @pytest.mark.parametrize('use_alibi', [False, True]) @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) -def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sdpa_test, use_alibi, - dtype): - pytest.skip() - torch.manual_seed(20) - if qseqlen_not_equal_kseqlen is not None: - seqlen_q = qseqlen_not_equal_kseqlen - else: - seqlen_q = N_CTX - seqlen_k = N_CTX - - if causal and ((N_CTX - 1) & N_CTX): - pytest.skip() - if causal and seqlen_q != seqlen_k: - pytest.skip() - - sm_scale = D_HEAD**-0.5 - input_metadata = MetaData(sm_scale=sm_scale) - input_metadata.max_seqlens_q = seqlen_q - input_metadata.max_seqlens_k = seqlen_k - - dropout_p = 0 - q = (torch.empty((Z, H, seqlen_q, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, 
std=0.5).requires_grad_()) - k = (torch.empty((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) - v = (torch.empty((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) - o = torch.empty_like(q) - +@pytest.mark.parametrize('layout', ['bhsd']) +def test_op_bwd(Z, H, N_CTX, D_HEAD, causal, use_alibi, + layout, dtype): + torch.manual_seed(20) + + N_CTX_Q = N_CTX_K = N_CTX + HQ = HK = H + + q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout) + dout = torch.randn_like(q) + if causal: input_metadata.need_causal() - if use_alibi and not torch_sdpa_test: + if use_alibi: # for n heads the set of slopes is the geometric sequence that starts 2^(-8/n) - alibi_slopes = torch.tensor([2**(-8 / H * i) for i in range(1, H + 1)], dtype=torch.float32, + alibi_slopes = torch.tensor([2**(-8 / HQ * i) for i in range(1, HQ + 1)], dtype=torch.float32, device="cuda").repeat(Z, 1) - input_metadata.need_alibi(alibi_slopes, Z, H) - dout = torch.randn_like(q) - # reference implementation - if torch_sdpa_test: - ref_out, ref_softmax = torch.ops.aten._scaled_dot_product_attention_math(q, k, v, dropout_p=dropout_p, - is_causal=causal, scale=sm_scale, - dropout_mask=None) - ref_out.backward(dout.to(device=ref_out.device, dtype=ref_out.dtype)) - ref_dv, v.grad = v.grad.clone(), None - ref_dk, k.grad = k.grad.clone(), None - ref_dq, q.grad = q.grad.clone(), None + input_metadata.need_alibi(alibi_slopes, Z, HQ) else: - M = torch.tril(torch.ones((seqlen_q, seqlen_k), device="cuda")) - p = torch.matmul(q, k.transpose(2, 3)) * sm_scale - if use_alibi: - p += compute_alibi_tensor(alibi_slopes, N_CTX, N_CTX) - if causal: - p[:, :, M == 0] = float("-inf") + alibi_slopes = None - p = torch.softmax(p.float(), dim=-1).type(dtype=p.dtype) - ref_out = torch.matmul(p, v) - ref_out.backward(dout) - ref_dv, v.grad = v.grad.clone(), None - ref_dk, k.grad = k.grad.clone(), None - ref_dq, q.grad = q.grad.clone(), None + o = torch.empty_like(q) - # # triton implementation + # triton implementation tri_out, _ = attention(q, k, v, o, input_metadata) tri_out.backward(dout) tri_dv, v.grad = v.grad.clone(), None tri_dk, k.grad = k.grad.clone(), None tri_dq, q.grad = q.grad.clone(), None - # test - #print("reference") - #print(ref_dv) - #print("tri") - #print(tri_dv) + + # Transpose here if layout is bshd so we have same reference code for all layouts + if layout == 'bshd': + q = q.transpose(1, 2).clone() + k = k.transpose(1, 2).clone() + v = v.transpose(1, 2).clone() + # Replicate K and V if using MQA/GQA + if HQ != HK: + k = k.view(k.shape[0], k.shape[1], -1, k.shape[2], + k.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(k.shape[0], -1, k.shape[2], k.shape[3]) + v = v.view(v.shape[0], v.shape[1], -1, v.shape[2], + v.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(v.shape[0], -1, v.shape[2], v.shape[3]) + + scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * input_metadata.sm_scale + if causal: + mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), diagonal=N_CTX_K - N_CTX_Q) + scores[:, :, mask == 0] = float("-inf") + if use_alibi: + scores += compute_alibi_tensor(alibi_slopes, N_CTX_Q, N_CTX_K) + + p = torch.softmax(scores, dim=-1) + if causal: + # If N_CTX_Q > N_CTX_K, there is at least one row of all -infs going into + # the softmax. This produces a row of NaNs as -inf - -inf == NaN. So we fix + # this by converting the NaNs to 0s, which is what they should be out of the softmax. 
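# Editorial illustration (not part of the patch): under causal masking with N_CTX_Q >
# N_CTX_K, a fully masked row of scores becomes all -inf, and softmax over such a row
# yields NaNs, which the lines below replace with zeros. A standalone sketch of the same
# fix (kept as comments so it does not interfere with the test's own variables):
#
#   import torch
#   scores = torch.tensor([[0.0, 1.0], [float("-inf"), float("-inf")]])
#   p = torch.softmax(scores, dim=-1)                        # [[0.2689, 0.7311], [nan, nan]]
#   p = torch.where(torch.isnan(p), torch.zeros_like(p), p)  # [[0.2689, 0.7311], [0.0, 0.0]]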
+ nan_mask = torch.isnan(p) + p = torch.where(nan_mask == 1,0,p) + ref_out = torch.einsum('bhqk,bhkd->bhqd', p.to(dtype), v) # compare - torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0) + if layout == 'bshd': + ref_out = ref_out.transpose(1, 2).clone() + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + + torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) + # The current block size for MI200 series is 64x64. This results in # larger differences in float results due to rounding. if dtype == torch.bfloat16: - ATOL = 1e-1 * max(1.0, (seqlen_q + D_HEAD) / 64.0) + ATOL = 1e-1 * max(1.0, (N_CTX_Q + D_HEAD) / 64.0) if dtype == torch.float32: - ATOL = 1e-3 * max(1.0, (seqlen_q + D_HEAD) / 64.0) + ATOL = 1e-3 * max(1.0, (N_CTX_Q + D_HEAD) / 64.0) else: - ATOL = 1e-1 * max(1.0, (seqlen_q + D_HEAD) / 64.0) + ATOL = 1e-1 * max(1.0, (N_CTX_Q + D_HEAD) / 64.0) RTOL = 0 @@ -1353,7 +1348,6 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sd torch.testing.assert_close(ref_dk, tri_dk, atol=ATOL, rtol=RTOL) torch.testing.assert_close(ref_dq, tri_dq, atol=ATOL, rtol=RTOL) - def nonvarlen_benchmark_configs(): configs = [ (16, 16, 16, 1024, 1024), diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index e533576d467b..c661510f6b2f 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -17,57 +17,60 @@ import triton import triton.language as tl +# Pick the fp8 data type -def is_hip(): - return triton.runtime.driver.active.get_current_target().backend == "hip" +# AMD E4M3B8 +# Note: When picking this f8 data type, scaling is required when using f8 +# for the second gemm +#TORCH_HAS_FP8E4B8 = hasattr(torch, 'float8_e4m3fnuz') + +# AMD E5M2B16 +TORCH_HAS_FP8E5B16 = hasattr(torch, 'float8_e5m2fnuz') @triton.jit -def _attn_fwd_inner(acc, l_i, m_i, q, # - K_block_ptr, V_block_ptr, # - start_m, qk_scale, # - BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, # - STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, # - N_CTX: tl.constexpr, fp8_v: tl.constexpr): +def _attn_fwd_inner(acc, l_i, m_i, q, + K_block_ptr, V_block_ptr, + start_m, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, + N_CTX, + pre_load_v: tl.constexpr): # range of values handled by this stage if STAGE == 1: lo, hi = 0, start_m * BLOCK_M elif STAGE == 2: lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M lo = tl.multiple_of(lo, BLOCK_M) + K_block_ptr = tl.advance(K_block_ptr, (0, lo)) + V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) # causal = False else: lo, hi = 0, N_CTX - K_block_ptr = tl.advance(K_block_ptr, (0, lo)) - V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) # loop over k, v and update accumulator for start_n in range(lo, hi, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- k = tl.load(K_block_ptr) - qk = tl.dot(q, k) + if pre_load_v: + v = tl.load(V_block_ptr) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) if STAGE == 2: mask = offs_m[:, None] >= (start_n + offs_n[None, :]) - qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk -= m_ij[:, None] - else: - m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) - qk = qk * qk_scale - m_ij[:, None] + qk = tl.where(mask, qk, float("-inf")) + qk += tl.dot(q, k) + m_ij 
= tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] p = tl.math.exp2(qk) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - alpha = tl.math.exp2(m_i - m_ij) - l_i = l_i * alpha + l_ij # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) acc = acc * alpha[:, None] - # update acc - v = tl.load(V_block_ptr) - if fp8_v: - p = p.to(tl.float8e5) - else: - p = p.to(tl.float16) - acc = tl.dot(p, v, acc) + if not pre_load_v: + v = tl.load(V_block_ptr) + acc += tl.dot(p.to(v.dtype), v) + # -- update m_i and l_i + l_ij = tl.sum(p, 1) + l_i = l_i * alpha + l_ij # update m_i and l_i m_i = m_ij V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) @@ -75,78 +78,72 @@ def _attn_fwd_inner(acc, l_i, m_i, q, # return acc, l_i, m_i -# We don't run auto-tuning every time to keep the tutorial fast. Keeping +# We don't run auto-tuning everytime to keep the tutorial fast. Uncommenting # the code below and commenting out the equivalent parameters is convenient for # re-tuning. -configs = [ - triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \ - for BM in [64, 128]\ - for BN in [32, 64]\ - for s in ([1] if is_hip() else [3, 4, 7])\ - for w in [4, 8]\ -] - +@triton.autotune( + configs=[ + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=4), + ], + key=['Z', 'H', 'N_CTX', 'STAGE', 'BLOCK_DMODEL'], +) -def keep(conf): - BLOCK_M = conf.kwargs["BLOCK_M"] - BLOCK_N = conf.kwargs["BLOCK_N"] - if BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8: - return False - return True - -@triton.autotune(list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM"]) @triton.jit -def _attn_fwd(Q, K, V, sm_scale, M, Out, # - stride_qz, stride_qh, stride_qm, stride_qk, # - stride_kz, stride_kh, stride_kn, stride_kk, # - stride_vz, stride_vh, stride_vk, stride_vn, # - stride_oz, stride_oh, stride_om, stride_on, # - Z, H, N_CTX, # - HEAD_DIM: tl.constexpr, # - BLOCK_M: tl.constexpr, # - BLOCK_N: tl.constexpr, # - STAGE: tl.constexpr # +def _attn_fwd(Q, K, V, sm_scale, M, Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + stride_oz, stride_oh, stride_om, stride_on, + Z, H, + N_CTX, + BLOCK_DMODEL: tl.constexpr, + STAGE: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + pre_load_v: tl.constexpr, ): - tl.static_assert(BLOCK_N <= HEAD_DIM) start_m = tl.program_id(0) off_hz = tl.program_id(1) - off_z = off_hz // H - off_h = off_hz % H - qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh + qvk_offset = off_hz * stride_qh # block pointers Q_block_ptr = tl.make_block_ptr( base=Q + qvk_offset, - shape=(N_CTX, HEAD_DIM), + shape=(N_CTX, BLOCK_DMODEL), strides=(stride_qm, stride_qk), offsets=(start_m * BLOCK_M, 0), - 
block_shape=(BLOCK_M, HEAD_DIM), + block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0), ) - v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0) V_block_ptr = tl.make_block_ptr( base=V + qvk_offset, - shape=(N_CTX, HEAD_DIM), + shape=(N_CTX, BLOCK_DMODEL), strides=(stride_vk, stride_vn), offsets=(0, 0), - block_shape=(BLOCK_N, HEAD_DIM), - order=v_order, + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), ) K_block_ptr = tl.make_block_ptr( base=K + qvk_offset, - shape=(HEAD_DIM, N_CTX), + shape=(BLOCK_DMODEL, N_CTX), strides=(stride_kk, stride_kn), offsets=(0, 0), - block_shape=(HEAD_DIM, BLOCK_N), + block_shape=(BLOCK_DMODEL, BLOCK_N), order=(0, 1), ) O_block_ptr = tl.make_block_ptr( base=Out + qvk_offset, - shape=(N_CTX, HEAD_DIM), + shape=(N_CTX, BLOCK_DMODEL), strides=(stride_om, stride_on), offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, HEAD_DIM), + block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0), ) # initialize offsets @@ -155,80 +152,96 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, # # initialize pointer to m and l m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 - acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) - # load scales - qk_scale = sm_scale - qk_scale *= 1.44269504 # 1/log(2) - # load q: it will stay in SRAM throughout + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout on NV GPUs but in VGPRs on AMD GPUs q = tl.load(Q_block_ptr) + q = (q * qk_scale).to(q.dtype) # stage 1: off-band # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE if STAGE & 1: - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, # - start_m, qk_scale, # - BLOCK_M, HEAD_DIM, BLOCK_N, # - 4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 # + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, + start_m, + BLOCK_M, BLOCK_DMODEL, BLOCK_N, + 4 - STAGE, offs_m, offs_n, N_CTX, + pre_load_v, ) # stage 2: on-band if STAGE & 2: # barrier makes it easier for compielr to schedule the # two loops independently - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, # - start_m, qk_scale, # - BLOCK_M, HEAD_DIM, BLOCK_N, # - 2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 # + tl.debug_barrier() + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, + start_m, + BLOCK_M, BLOCK_DMODEL, BLOCK_N, + 2, offs_m, offs_n, N_CTX, + pre_load_v, ) # epilogue - m_i += tl.math.log2(l_i) + # write back m acc = acc / l_i[:, None] m_ptrs = M + off_hz * N_CTX + offs_m - tl.store(m_ptrs, m_i) + tl.store(m_ptrs, m_i + tl.math.log2(l_i)) tl.store(O_block_ptr, acc.to(Out.type.element_ty)) - @triton.jit -def _attn_bwd_preprocess(O, DO, # - Delta, # - Z, H, N_CTX, # - BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr # +def _attn_bwd_preprocess(O, DO, + Delta, + Z, H, N_CTX, + BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr ): off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) off_hz = tl.program_id(1) - off_n = tl.arange(0, HEAD_DIM) - # load - o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]) - do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * 
HEAD_DIM + off_n[None, :]).to(tl.float32) + off_n = tl.arange(0, D_HEAD) + o = tl.load(O + off_hz * D_HEAD * N_CTX + off_m[:, None] * D_HEAD + off_n[None, :]) + do = tl.load(DO + off_hz * D_HEAD * N_CTX + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) delta = tl.sum(o * do, axis=1) - # write-back tl.store(Delta + off_hz * N_CTX + off_m, delta) # The main inner-loop logic for computing dK and dV. @triton.jit -def _attn_bwd_dkdv(dk, dv, # - Q, k, v, sm_scale, # - DO, # - M, D, # +def _attn_bwd_dkdv(dk, dv, + Q, k, v, sm_scale, + DO, + M, D, # shared by Q/K/V/DO. - stride_tok, stride_d, # - H, N_CTX, BLOCK_M1: tl.constexpr, # - BLOCK_N1: tl.constexpr, # - HEAD_DIM: tl.constexpr, # + stride_tok, stride_d, + H, N_CTX, BLOCK_M1: tl.constexpr, + BLOCK_N1: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, # Filled in by the wrapper. - start_n, start_m, num_steps, # + start_n, start_m, num_steps, MASK: tl.constexpr): offs_m = start_m + tl.arange(0, BLOCK_M1) offs_n = start_n + tl.arange(0, BLOCK_N1) - offs_k = tl.arange(0, HEAD_DIM) - qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d - do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + offs_k = tl.arange(0, BLOCK_DMODEL) + QT_block_ptr = tl.make_block_ptr( + base=Q, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_d, stride_tok), + offsets=(0, start_m), + block_shape=(BLOCK_DMODEL, BLOCK_M1), + order=(0,1) + ) + DO_block_ptr = tl.make_block_ptr( + base=DO, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_m, 0), + block_shape=(BLOCK_M1, BLOCK_DMODEL), + order=(1,0) + ) # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) curr_m = start_m step_m = BLOCK_M1 for blk_idx in range(num_steps): - qT = tl.load(qT_ptrs) + qT = tl.load(QT_block_ptr) # Load m before computing qk to reduce pipeline stall. offs_m = curr_m + tl.arange(0, BLOCK_M1) m = tl.load(M + offs_m) @@ -238,7 +251,7 @@ def _attn_bwd_dkdv(dk, dv, # if MASK: mask = (offs_m[None, :] >= offs_n[:, None]) pT = tl.where(mask, pT, 0.0) - do = tl.load(do_ptrs) + do = tl.load(DO_block_ptr) # Compute dV. ppT = pT ppT = ppT.to(tl.float16) @@ -246,35 +259,49 @@ def _attn_bwd_dkdv(dk, dv, # # D (= delta) is pre-divided by ds_scale. Di = tl.load(D + offs_m) # Compute dP and dS. - dpT = tl.dot(v, tl.trans(do)).to(tl.float32) + dpT = tl.dot(v, tl.trans(do)) dsT = pT * (dpT - Di[None, :]) dsT = dsT.to(tl.float16) dk += tl.dot(dsT, tl.trans(qT)) # Increment pointers. curr_m += step_m - qT_ptrs += step_m * stride_tok - do_ptrs += step_m * stride_tok + QT_block_ptr = tl.advance(QT_block_ptr, (0, step_m)) + DO_block_ptr = tl.advance(DO_block_ptr, (step_m, 0)) return dk, dv # the main inner-loop logic for computing dQ @triton.jit -def _attn_bwd_dq(dq, q, K, V, # +def _attn_bwd_dq(dq, q, K, V, do, m, D, # shared by Q/K/V/DO. - stride_tok, stride_d, # - H, N_CTX, # - BLOCK_M2: tl.constexpr, # - BLOCK_N2: tl.constexpr, # - HEAD_DIM: tl.constexpr, + stride_tok, stride_d, + H, N_CTX, + BLOCK_M2: tl.constexpr, + BLOCK_N2: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, # Filled in by the wrapper. 
- start_m, start_n, num_steps, # + start_m, start_n, num_steps, MASK: tl.constexpr): offs_m = start_m + tl.arange(0, BLOCK_M2) offs_n = start_n + tl.arange(0, BLOCK_N2) - offs_k = tl.arange(0, HEAD_DIM) - kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d - vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d + offs_k = tl.arange(0, BLOCK_DMODEL) + KT_block_ptr = tl.make_block_ptr( + base=K, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_d, stride_tok), + offsets=(0, start_n), + block_shape=(BLOCK_DMODEL, BLOCK_N2), + order=(0, 1) + ) + VT_block_ptr = tl.make_block_ptr( + base=V, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_d, stride_tok), + offsets=(0, start_n), + block_shape=(BLOCK_DMODEL, BLOCK_N2), + order=(0, 1) + ) # D (= delta) is pre-divided by ds_scale. Di = tl.load(D + offs_m) # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. @@ -282,8 +309,7 @@ def _attn_bwd_dq(dq, q, K, V, # curr_n = start_n step_n = BLOCK_N2 for blk_idx in range(num_steps): - kT = tl.load(kT_ptrs) - vT = tl.load(vT_ptrs) + kT = tl.load(KT_block_ptr) qk = tl.dot(q, kT) p = tl.math.exp2(qk - m) # Autoregressive masking. @@ -292,6 +318,7 @@ def _attn_bwd_dq(dq, q, K, V, # mask = (offs_m[:, None] >= offs_n[None, :]) p = tl.where(mask, p, 0.0) # Compute dP and dS. + vT = tl.load(VT_block_ptr) dp = tl.dot(do, vT).to(tl.float32) ds = p * (dp - Di[:, None]) ds = ds.to(tl.float16) @@ -300,25 +327,50 @@ def _attn_bwd_dq(dq, q, K, V, # dq += tl.dot(ds, tl.trans(kT)) # Increment pointers. curr_n += step_n - kT_ptrs += step_n * stride_tok - vT_ptrs += step_n * stride_tok + KT_block_ptr = tl.advance(KT_block_ptr, (0, step_n)) + VT_block_ptr = tl.advance(VT_block_ptr, (0, step_n)) return dq +@triton.autotune( + configs=[ + triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 1}, + num_stages=1, num_warps=4), + triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2}, + num_stages=1, num_warps=4), + triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 1}, + num_stages=1, num_warps=4), + triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 2}, + num_stages=1, num_warps=4), + triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 1}, + num_stages=1, num_warps=4), + triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 2}, + num_stages=1, num_warps=4), + triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 1}, + num_stages=1, num_warps=4), + triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2}, + num_stages=1, num_warps=4), + triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2}, + num_stages=1, num_warps=8), + ], + key=['H', 'N_CTX', 'BLOCK_DMODEL'], +) + @triton.jit -def _attn_bwd(Q, K, V, sm_scale, # - DO, # - DQ, DK, DV, # +def _attn_bwd(Q, K, V, sm_scale, + DO, + DQ, DK, DV, M, D, # shared by Q/K/V/DO. 
- stride_z, stride_h, stride_tok, stride_d, # - H, N_CTX, # - BLOCK_M1: tl.constexpr, # - BLOCK_N1: tl.constexpr, # - BLOCK_M2: tl.constexpr, # - BLOCK_N2: tl.constexpr, # - BLK_SLICE_FACTOR: tl.constexpr, # - HEAD_DIM: tl.constexpr): + stride_z, stride_h, stride_tok, stride_d, + # H = 16, N_CTX = 1024 + H, N_CTX, + BLOCK_DMODEL: tl.constexpr, + BLOCK_M1: tl.constexpr, + BLOCK_N1: tl.constexpr, + BLOCK_M2: tl.constexpr, + BLOCK_N2: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr): LN2: tl.constexpr = 0.6931471824645996 # = ln(2) bhid = tl.program_id(2) @@ -337,58 +389,91 @@ def _attn_bwd(Q, K, V, sm_scale, # M += off_chz D += off_chz - # load scales - offs_k = tl.arange(0, HEAD_DIM) + offs_k = tl.arange(0, BLOCK_DMODEL) start_n = pid * BLOCK_N1 + # This assignment is important. It is what allows us to pick the diagonal + # blocks. Later, when we want to do the lower triangular, we update start_m + # after the first dkdv call. start_m = start_n MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR offs_n = start_n + tl.arange(0, BLOCK_N1) - dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) - dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) - # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) - v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + K_block_ptr = tl.make_block_ptr( + base=K, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_n, 0), + block_shape=(BLOCK_N1, BLOCK_DMODEL), + order=(1, 0), + ) + V_block_ptr = tl.make_block_ptr( + base=V, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_n, 0), + block_shape=(BLOCK_N1, BLOCK_DMODEL), + order=(1, 0), + ) + + # load K and V: they stay in SRAM throughout the inner loop for dkdv. + k = tl.load(K_block_ptr) + v = tl.load(V_block_ptr) num_steps = BLOCK_N1 // MASK_BLOCK_M1 - dk, dv = _attn_bwd_dkdv(dk, dv, # - Q, k, v, sm_scale, # - DO, # - M, D, # - stride_tok, stride_d, # - H, N_CTX, # - MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM, # - start_n, start_m, num_steps, # - MASK=True # + dk, dv = _attn_bwd_dkdv(dk, dv, + Q, k, v, sm_scale, + DO, + M, D, + stride_tok, stride_d, + H, N_CTX, + MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, + start_n, start_m, num_steps, + MASK=True ) start_m += num_steps * MASK_BLOCK_M1 num_steps = (N_CTX - start_m) // BLOCK_M1 # Compute dK and dV for non-masked blocks. - dk, dv = _attn_bwd_dkdv( # - dk, dv, # - Q, k, v, sm_scale, # - DO, # - M, D, # - stride_tok, stride_d, # - H, N_CTX, # - BLOCK_M1, BLOCK_N1, HEAD_DIM, # - start_n, start_m, num_steps, # - MASK=False # + dk, dv = _attn_bwd_dkdv( + dk, dv, + Q, k, v, sm_scale, + DO, + M, D, + stride_tok, stride_d, + H, N_CTX, + BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, + start_n, start_m, num_steps, + MASK=False ) - dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d - tl.store(dv_ptrs, dv) + DV_block_ptrs = tl.make_block_ptr( + base=DV, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_n, 0), + block_shape=(BLOCK_N1, BLOCK_DMODEL), + order=(1,0) + ) + tl.store(DV_block_ptrs, dv.to(tl.float16)) # Write back dK. 
dk *= sm_scale - dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d - tl.store(dk_ptrs, dk) + DK_block_ptrs = tl.make_block_ptr( + base=DK, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_n, 0), + block_shape=(BLOCK_N1, BLOCK_DMODEL), + order=(1,0) + ) + tl.store(DK_block_ptrs, dk.to(tl.float16)) # THIS BLOCK DOES DQ: start_m = pid * BLOCK_M2 @@ -397,9 +482,26 @@ def _attn_bwd(Q, K, V, sm_scale, # MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR offs_m = start_m + tl.arange(0, BLOCK_M2) - q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) - dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) - do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + Q_block_ptr = tl.make_block_ptr( + base=Q, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_m, 0), + block_shape=(BLOCK_M2, BLOCK_DMODEL), + order=(1, 0) + ) + + DO_block_ptr = tl.make_block_ptr( + base=DO, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_m, 0), + block_shape=(BLOCK_M2, BLOCK_DMODEL), + order=(1, 0) + ) + q = tl.load(Q_block_ptr) + do = tl.load(DO_block_ptr) + dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32) m = tl.load(M + offs_m) m = m[:, None] @@ -410,29 +512,39 @@ def _attn_bwd(Q, K, V, sm_scale, # # not due to anything important. I just wanted to reuse the loop # structure for dK & dV above as much as possible. num_steps = BLOCK_M2 // MASK_BLOCK_N2 - dq = _attn_bwd_dq(dq, q, K, V, # - do, m, D, # - stride_tok, stride_d, # - H, N_CTX, # - BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM, # - start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, # - MASK=True # + dq = _attn_bwd_dq(dq, q, K, V, + do, m, D, + stride_tok, stride_d, + H, N_CTX, + BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL, + start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, + MASK=True ) end_n -= num_steps * MASK_BLOCK_N2 # stage 2 num_steps = end_n // BLOCK_N2 - dq = _attn_bwd_dq(dq, q, K, V, # - do, m, D, # - stride_tok, stride_d, # - H, N_CTX, # - BLOCK_M2, BLOCK_N2, HEAD_DIM, # - start_m, end_n - num_steps * BLOCK_N2, num_steps, # - MASK=False # + dq = _attn_bwd_dq(dq, q, K, V, + do, m, D, + stride_tok, stride_d, + H, N_CTX, + BLOCK_M2, BLOCK_N2, BLOCK_DMODEL, + start_m, end_n - num_steps * BLOCK_N2, num_steps, + MASK=False ) # Write back dQ. - dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + DQ_block_ptr = tl.make_block_ptr( + base=DQ, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_m, 0), + block_shape=(BLOCK_M2, BLOCK_DMODEL), + order=(1, 0) + ) dq *= LN2 - tl.store(dq_ptrs, dq) + tl.store(DQ_block_ptr, dq.to(tl.float16)) + + +empty = torch.empty(128, device="cuda") class _attention(torch.autograd.Function): @@ -440,42 +552,56 @@ class _attention(torch.autograd.Function): @staticmethod def forward(ctx, q, k, v, causal, sm_scale): # shape constraints - HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] - # when v is in float8_e5m2 it is transposed. 
- HEAD_DIM_V = v.shape[-1] - assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V - assert HEAD_DIM_K in {16, 32, 64, 128, 256} - o = torch.empty_like(q) + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + o = torch.empty_like(q, dtype=v.dtype) + if torch.version.hip is None: + BLOCK_M = 128 + BLOCK_N = 64 if Lk <= 64 else 32 + num_stages = 4 if Lk <= 64 else 3 + num_warps = 4 if Lk <= 64 else 8 + # Tuning for H100 + if torch.cuda.get_device_capability()[0] == 9: + num_warps = 8 + num_stages = 7 if Lk >= 64 else 3 stage = 3 if causal else 1 - extra_kern_args = {} - # Tuning for AMD target - if is_hip(): - waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2 - extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True} - - grid = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), q.shape[0] * q.shape[1], 1) - M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + grid = lambda META: ( + triton.cdiv(q.shape[2], META['BLOCK_M']), + q.shape[0] * q.shape[1], + 1 + ) + M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) _attn_fwd[grid]( - q, k, v, sm_scale, M, o, # - q.stride(0), q.stride(1), q.stride(2), q.stride(3), # - k.stride(0), k.stride(1), k.stride(2), k.stride(3), # - v.stride(0), v.stride(1), v.stride(2), v.stride(3), # - o.stride(0), o.stride(1), o.stride(2), o.stride(3), # - q.shape[0], q.shape[1], # - N_CTX=q.shape[2], # - HEAD_DIM=HEAD_DIM_K, # - STAGE=stage, # - **extra_kern_args) + q, k, v, sm_scale, M, o, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + q.shape[0], q.shape[1], + N_CTX=q.shape[2], + BLOCK_DMODEL=Lk, + STAGE=stage, + ) + + ## restore the grid for bwd kernel + #best_config = _attn_fwd.get_best_config() + block_m = 64#int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1]) + grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1) ctx.save_for_backward(q, k, v, o, M) ctx.grid = grid ctx.sm_scale = sm_scale - ctx.HEAD_DIM = HEAD_DIM_K + ctx.BLOCK_DMODEL = Lk ctx.causal = causal return o @staticmethod def backward(ctx, do): + if torch.version.hip is not None: + BLOCK = 64 + else: + BLOCK = 128 q, k, v, o, M = ctx.saved_tensors assert do.is_contiguous() assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() @@ -484,49 +610,96 @@ def backward(ctx, do): dv = torch.empty_like(v) BATCH, N_HEAD, N_CTX = q.shape[:3] PRE_BLOCK = 128 - NUM_WARPS, NUM_STAGES = 4, 5 - BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 + NUM_WARPS, NUM_STAGES = 4, 1 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32 BLK_SLICE_FACTOR = 2 RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) arg_k = k arg_k = arg_k * (ctx.sm_scale * RCP_LN2) - PRE_BLOCK = 128 assert N_CTX % PRE_BLOCK == 0 pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD) delta = torch.empty_like(M) _attn_bwd_preprocess[pre_grid]( - o, do, # - delta, # - BATCH, N_HEAD, N_CTX, # - BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM # + o, do, + delta, + BATCH, N_HEAD, N_CTX, + BLOCK_M=PRE_BLOCK, D_HEAD=ctx.BLOCK_DMODEL + ) + grid = lambda META: ( + triton.cdiv(N_CTX, META['BLOCK_N1']), + 1, + BATCH * N_HEAD ) - grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) _attn_bwd[grid]( - q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, # - M, delta, # - q.stride(0), q.stride(1), q.stride(2), 
q.stride(3), # - N_HEAD, N_CTX, # - BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, # - BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, # - BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # - HEAD_DIM=ctx.HEAD_DIM, # - num_warps=NUM_WARPS, # - num_stages=NUM_STAGES # + q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, + M, delta, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + N_HEAD, N_CTX, + BLOCK_DMODEL=ctx.BLOCK_DMODEL ) return dq, dk, dv, None, None - attention = _attention.apply +name_to_torch_types = { + 'fp16': torch.float16, +} + +if TORCH_HAS_FP8E5B16: + name_to_torch_types['fp8'] = torch.float8_e5m2fnuz + +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD, dtype', +[ (*shape, dtype) + for shape in [(4, 48, 1024, 64), + (4, 48, 2048, 64), + (4, 48, 4096, 64), + (4, 48, 1024, 128), + (4, 48, 2048, 128), + (4, 48, 4096, 128)] + for dtype in ['fp16', 'fp8']]) +@pytest.mark.parametrize('causal', [False, True]) +def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype): + if dtype == 'fp8' and not TORCH_HAS_FP8E5B16: + pytest.skip("fp8 not supported") + torch.manual_seed(20) + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5).requires_grad_() -@pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM", [(1, 2, 1024, 64)]) -@pytest.mark.parametrize("causal", [True]) -def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16): + q = q.to(name_to_torch_types[dtype]) + k = k.to(name_to_torch_types[dtype]) + sm_scale = 0.5 + dout = torch.randn_like(q, dtype=torch.float16) + # reference implementation + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + p = torch.matmul(q.half(), k.transpose(2, 3).half()) * sm_scale + if causal: + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).half() + ref_out = torch.matmul(p, v) + # triton implementation + tri_out = attention(q, k, v, causal, sm_scale) + # compare + atol = 1.4e-1 if dtype == 'fp8' else 1e-2 + rtol = 1e-2 if dtype == 'fp8' else 0 + torch.testing.assert_close(ref_out, tri_out, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', + [(4, 48, 1024, 64), + (4, 48, 2048, 64), + (4, 48, 4096, 64), + (1, 16, 8192, 64), + (1, 16, 1024, 64), + ]) +def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16): torch.manual_seed(20) - q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) - k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) - v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + causal = True + q = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + k = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + v = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + sm_scale = 0.5 dout = torch.randn_like(q) # reference implementation @@ -535,28 +708,27 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16): if causal: p[:, :, M == 0] = float("-inf") p = torch.softmax(p.float(), dim=-1).half() - # p = torch.exp(p) ref_out = torch.matmul(p, v) ref_out.backward(dout) ref_dv, v.grad = 
v.grad.clone(), None ref_dk, k.grad = k.grad.clone(), None ref_dq, q.grad = q.grad.clone(), None - # triton implementation - tri_out = attention(q, k, v, causal, sm_scale).half() + # # triton implementation + tri_out = attention(q, k, v, causal, sm_scale) tri_out.backward(dout) tri_dv, v.grad = v.grad.clone(), None tri_dk, k.grad = k.grad.clone(), None tri_dq, q.grad = q.grad.clone(), None # compare - assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0) - rtol = 0.0 - # Relative tolerance workaround for known hardware limitation of MI200 GPU. - # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices - if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a": - rtol = 1e-2 - assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=rtol) - assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=rtol) - assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=rtol) + torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0) + if torch.version.hip is None: + torch.testing.assert_close(ref_dv, tri_dv, atol=1e-2, rtol=0) + # The current block size for MI200 series is 64x64. This results in + # larger differences in float results due to rounding. + else: + torch.testing.assert_close(ref_dv, tri_dv, atol=5e-2, rtol=0) + torch.testing.assert_close(ref_dk, tri_dk, atol=5e-2, rtol=1e-2) + torch.testing.assert_close(ref_dq, tri_dq, atol=5e-2, rtol=1e-2) try: @@ -566,68 +738,69 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16): except BaseException: HAS_FLASH = False -TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2') -BATCH, N_HEADS, HEAD_DIM = 4, 32, 64 # vary seq length for fixed head and batch=4 configs = [] -for mode in ["fwd", "bwd"]: - for causal in [True, False]: - if mode == "bwd" and not causal: - continue - configs.append( - triton.testing.Benchmark( - x_names=["N_CTX"], - x_vals=[2**i for i in range(10, 15)], - line_arg="provider", - line_vals=["triton-fp16"] + (["triton-fp8"] if TORCH_HAS_FP8 else []) + - (["flash"] if HAS_FLASH else []), - line_names=["Triton [FP16]"] + (["Triton [FP8]"] if TORCH_HAS_FP8 else []) + - (["Flash-2"] if HAS_FLASH else []), - styles=[("red", "-"), ("blue", "-"), ("green", "-")], - ylabel="ms", - plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{HEAD_DIM}-{mode}-causal={causal}", +for mode in ['fwd', 'bwd']: + for D_HEAD in [128, 64]: + for causal in [False, True]: + if mode == 'bwd' and causal == False: + continue + configs.append(triton.testing.Benchmark( + x_names=['BATCH', 'H', 'N_CTX'], + x_vals=[(4, 16, 1024), + (8, 16, 2048), + (4, 16, 4096), + (2, 16, 8192), + (1, 16, 16384), + (4, 48, 1024), + (4, 48, 2048), + (4, 48, 4096), + (4, 48, 8192), + (4, 48, 16384), + ], + line_arg='provider', + line_vals=['triton'] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []), + styles=[('red', '-'), ('blue', '-')], + ylabel='ms', + plot_name=f'fused-attention-{mode}-d{D_HEAD}-causal={causal}', args={ - "H": N_HEADS, - "BATCH": BATCH, - "HEAD_DIM": HEAD_DIM, - "mode": mode, - "causal": causal, + 'D_HEAD': D_HEAD, + 'dtype': torch.float16, + 'mode': mode, + 'causal': causal, }, )) @triton.testing.perf_report(configs) -def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device="cuda"): +def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda"): assert mode in 
["fwd", "bwd"] warmup = 25 - rep = 100 - dtype = torch.float16 - if "triton" in provider: - q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) - k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) - v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) - if mode == "fwd" and "fp8" in provider: - q = q.to(torch.float8_e5m2) - k = k.to(torch.float8_e5m2) - v = v.permute(0, 1, 3, 2).contiguous() - v = v.permute(0, 1, 3, 2) - v = v.to(torch.float8_e5m2) - sm_scale = 1.3 + rep = 10 + # Bwd pass only supports causal=True right now + if mode == 'bwd': + causal = True + if provider == "triton": + q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + sm_scale = D_HEAD ** -0.5 fn = lambda: attention(q, k, v, causal, sm_scale) - if mode == "bwd": + if mode == 'bwd': o = fn() do = torch.randn_like(o) fn = lambda: o.backward(do, retain_graph=True) ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) if provider == "flash": - qkv = torch.randn((BATCH, N_CTX, 3, H, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) + qkv = torch.randn((BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True) fn = lambda: flash_attn_func(qkv, causal=causal) if mode == "bwd": o = fn() do = torch.randn_like(o) fn = lambda: o.backward(do, retain_graph=True) ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) - flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * HEAD_DIM + flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD total_flops = 2 * flops_per_matmul if causal: total_flops *= 0.5 @@ -635,7 +808,5 @@ def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, dev total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute) return total_flops / ms * 1e-9 - -if __name__ == "__main__": - # only works on post-Ampere GPUs right now - bench_flash_attention.run(save_path=".", print_data=True) +# only works on post-Ampere GPUs right now +bench_flash_attention.run(save_path=".", print_data=True) From 550f3954f6570fdf00470c49c0c81529b9a72816 Mon Sep 17 00:00:00 2001 From: Joseph Groenenboom Date: Mon, 12 Aug 2024 11:45:59 -0500 Subject: [PATCH 12/12] revert changes to tutorial kernel --- python/tutorials/06-fused-attention.py | 753 ++++++++++--------------- 1 file changed, 291 insertions(+), 462 deletions(-) diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index c661510f6b2f..e533576d467b 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -17,60 +17,57 @@ import triton import triton.language as tl -# Pick the fp8 data type -# AMD E4M3B8 -# Note: When picking this f8 data type, scaling is required when using f8 -# for the second gemm -#TORCH_HAS_FP8E4B8 = hasattr(torch, 'float8_e4m3fnuz') - -# AMD E5M2B16 -TORCH_HAS_FP8E5B16 = hasattr(torch, 'float8_e5m2fnuz') +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" @triton.jit -def _attn_fwd_inner(acc, l_i, m_i, q, - K_block_ptr, V_block_ptr, - start_m, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, - STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, - N_CTX, - pre_load_v: tl.constexpr): +def _attn_fwd_inner(acc, l_i, m_i, 
q, # + K_block_ptr, V_block_ptr, # + start_m, qk_scale, # + BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, # + STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, # + N_CTX: tl.constexpr, fp8_v: tl.constexpr): # range of values handled by this stage if STAGE == 1: lo, hi = 0, start_m * BLOCK_M elif STAGE == 2: lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M lo = tl.multiple_of(lo, BLOCK_M) - K_block_ptr = tl.advance(K_block_ptr, (0, lo)) - V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) # causal = False else: lo, hi = 0, N_CTX + K_block_ptr = tl.advance(K_block_ptr, (0, lo)) + V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) # loop over k, v and update accumulator for start_n in range(lo, hi, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- k = tl.load(K_block_ptr) - if pre_load_v: - v = tl.load(V_block_ptr) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k) if STAGE == 2: mask = offs_m[:, None] >= (start_n + offs_n[None, :]) - qk = tl.where(mask, qk, float("-inf")) - qk += tl.dot(q, k) - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk = qk - m_ij[:, None] + qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + else: + m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) + qk = qk * qk_scale - m_ij[:, None] p = tl.math.exp2(qk) - # -- update output accumulator -- - alpha = tl.math.exp2(m_i - m_ij) - acc = acc * alpha[:, None] - if not pre_load_v: - v = tl.load(V_block_ptr) - acc += tl.dot(p.to(v.dtype), v) - # -- update m_i and l_i l_ij = tl.sum(p, 1) + # -- update m_i and l_i + alpha = tl.math.exp2(m_i - m_ij) l_i = l_i * alpha + l_ij + # -- update output accumulator -- + acc = acc * alpha[:, None] + # update acc + v = tl.load(V_block_ptr) + if fp8_v: + p = p.to(tl.float8e5) + else: + p = p.to(tl.float16) + acc = tl.dot(p, v, acc) # update m_i and l_i m_i = m_ij V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) @@ -78,72 +75,78 @@ def _attn_fwd_inner(acc, l_i, m_i, q, return acc, l_i, m_i -# We don't run auto-tuning everytime to keep the tutorial fast. Uncommenting +# We don't run auto-tuning every time to keep the tutorial fast. Keeping # the code below and commenting out the equivalent parameters is convenient for # re-tuning. 
-@triton.autotune( - configs=[ - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=4), - ], - key=['Z', 'H', 'N_CTX', 'STAGE', 'BLOCK_DMODEL'], -) +configs = [ + triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \ + for BM in [64, 128]\ + for BN in [32, 64]\ + for s in ([1] if is_hip() else [3, 4, 7])\ + for w in [4, 8]\ +] + +def keep(conf): + BLOCK_M = conf.kwargs["BLOCK_M"] + BLOCK_N = conf.kwargs["BLOCK_N"] + if BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8: + return False + return True + +@triton.autotune(list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM"]) @triton.jit -def _attn_fwd(Q, K, V, sm_scale, M, Out, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vk, stride_vn, - stride_oz, stride_oh, stride_om, stride_on, - Z, H, - N_CTX, - BLOCK_DMODEL: tl.constexpr, - STAGE: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - pre_load_v: tl.constexpr, +def _attn_fwd(Q, K, V, sm_scale, M, Out, # + stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vk, stride_vn, # + stride_oz, stride_oh, stride_om, stride_on, # + Z, H, N_CTX, # + HEAD_DIM: tl.constexpr, # + BLOCK_M: tl.constexpr, # + BLOCK_N: tl.constexpr, # + STAGE: tl.constexpr # ): + tl.static_assert(BLOCK_N <= HEAD_DIM) start_m = tl.program_id(0) off_hz = tl.program_id(1) - qvk_offset = off_hz * stride_qh + off_z = off_hz // H + off_h = off_hz % H + qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh # block pointers Q_block_ptr = tl.make_block_ptr( base=Q + qvk_offset, - shape=(N_CTX, BLOCK_DMODEL), + shape=(N_CTX, HEAD_DIM), strides=(stride_qm, stride_qk), offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), + block_shape=(BLOCK_M, HEAD_DIM), order=(1, 0), ) + v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0) V_block_ptr = tl.make_block_ptr( base=V + qvk_offset, - shape=(N_CTX, BLOCK_DMODEL), + shape=(N_CTX, HEAD_DIM), strides=(stride_vk, stride_vn), offsets=(0, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(1, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=v_order, ) K_block_ptr = tl.make_block_ptr( base=K + qvk_offset, - shape=(BLOCK_DMODEL, N_CTX), + shape=(HEAD_DIM, N_CTX), strides=(stride_kk, stride_kn), offsets=(0, 0), - block_shape=(BLOCK_DMODEL, BLOCK_N), + block_shape=(HEAD_DIM, BLOCK_N), order=(0, 1), ) O_block_ptr = tl.make_block_ptr( base=Out + qvk_offset, - shape=(N_CTX, BLOCK_DMODEL), + shape=(N_CTX, HEAD_DIM), strides=(stride_om, stride_on), offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), + block_shape=(BLOCK_M, HEAD_DIM), order=(1, 0), ) # initialize offsets @@ -152,96 
+155,80 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, # initialize pointer to m and l m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # scale sm_scale by log_2(e) and use - # 2^x instead of exp in the loop because CSE and LICM - # don't work as expected with `exp` in the loop - qk_scale = sm_scale * 1.44269504 - # load q: it will stay in SRAM throughout on NV GPUs but in VGPRs on AMD GPUs + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + # load scales + qk_scale = sm_scale + qk_scale *= 1.44269504 # 1/log(2) + # load q: it will stay in SRAM throughout q = tl.load(Q_block_ptr) - q = (q * qk_scale).to(q.dtype) # stage 1: off-band # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE if STAGE & 1: - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, - start_m, - BLOCK_M, BLOCK_DMODEL, BLOCK_N, - 4 - STAGE, offs_m, offs_n, N_CTX, - pre_load_v, + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, # + start_m, qk_scale, # + BLOCK_M, HEAD_DIM, BLOCK_N, # + 4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 # ) # stage 2: on-band if STAGE & 2: # barrier makes it easier for compielr to schedule the # two loops independently - tl.debug_barrier() - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, - start_m, - BLOCK_M, BLOCK_DMODEL, BLOCK_N, - 2, offs_m, offs_n, N_CTX, - pre_load_v, + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, # + start_m, qk_scale, # + BLOCK_M, HEAD_DIM, BLOCK_N, # + 2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 # ) # epilogue - # write back m + m_i += tl.math.log2(l_i) acc = acc / l_i[:, None] m_ptrs = M + off_hz * N_CTX + offs_m - tl.store(m_ptrs, m_i + tl.math.log2(l_i)) + tl.store(m_ptrs, m_i) tl.store(O_block_ptr, acc.to(Out.type.element_ty)) + @triton.jit -def _attn_bwd_preprocess(O, DO, - Delta, - Z, H, N_CTX, - BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr +def _attn_bwd_preprocess(O, DO, # + Delta, # + Z, H, N_CTX, # + BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr # ): off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) off_hz = tl.program_id(1) - off_n = tl.arange(0, D_HEAD) - o = tl.load(O + off_hz * D_HEAD * N_CTX + off_m[:, None] * D_HEAD + off_n[None, :]) - do = tl.load(DO + off_hz * D_HEAD * N_CTX + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + off_n = tl.arange(0, HEAD_DIM) + # load + o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]) + do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32) delta = tl.sum(o * do, axis=1) + # write-back tl.store(Delta + off_hz * N_CTX + off_m, delta) # The main inner-loop logic for computing dK and dV. @triton.jit -def _attn_bwd_dkdv(dk, dv, - Q, k, v, sm_scale, - DO, - M, D, +def _attn_bwd_dkdv(dk, dv, # + Q, k, v, sm_scale, # + DO, # + M, D, # # shared by Q/K/V/DO. - stride_tok, stride_d, - H, N_CTX, BLOCK_M1: tl.constexpr, - BLOCK_N1: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, + stride_tok, stride_d, # + H, N_CTX, BLOCK_M1: tl.constexpr, # + BLOCK_N1: tl.constexpr, # + HEAD_DIM: tl.constexpr, # # Filled in by the wrapper. 
- start_n, start_m, num_steps, + start_n, start_m, num_steps, # MASK: tl.constexpr): offs_m = start_m + tl.arange(0, BLOCK_M1) offs_n = start_n + tl.arange(0, BLOCK_N1) - offs_k = tl.arange(0, BLOCK_DMODEL) - QT_block_ptr = tl.make_block_ptr( - base=Q, - shape=(BLOCK_DMODEL, N_CTX), - strides=(stride_d, stride_tok), - offsets=(0, start_m), - block_shape=(BLOCK_DMODEL, BLOCK_M1), - order=(0,1) - ) - DO_block_ptr = tl.make_block_ptr( - base=DO, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_tok, stride_d), - offsets=(start_m, 0), - block_shape=(BLOCK_M1, BLOCK_DMODEL), - order=(1,0) - ) + offs_k = tl.arange(0, HEAD_DIM) + qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d + do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) curr_m = start_m step_m = BLOCK_M1 for blk_idx in range(num_steps): - qT = tl.load(QT_block_ptr) + qT = tl.load(qT_ptrs) # Load m before computing qk to reduce pipeline stall. offs_m = curr_m + tl.arange(0, BLOCK_M1) m = tl.load(M + offs_m) @@ -251,7 +238,7 @@ def _attn_bwd_dkdv(dk, dv, if MASK: mask = (offs_m[None, :] >= offs_n[:, None]) pT = tl.where(mask, pT, 0.0) - do = tl.load(DO_block_ptr) + do = tl.load(do_ptrs) # Compute dV. ppT = pT ppT = ppT.to(tl.float16) @@ -259,49 +246,35 @@ def _attn_bwd_dkdv(dk, dv, # D (= delta) is pre-divided by ds_scale. Di = tl.load(D + offs_m) # Compute dP and dS. - dpT = tl.dot(v, tl.trans(do)) + dpT = tl.dot(v, tl.trans(do)).to(tl.float32) dsT = pT * (dpT - Di[None, :]) dsT = dsT.to(tl.float16) dk += tl.dot(dsT, tl.trans(qT)) # Increment pointers. curr_m += step_m - QT_block_ptr = tl.advance(QT_block_ptr, (0, step_m)) - DO_block_ptr = tl.advance(DO_block_ptr, (step_m, 0)) + qT_ptrs += step_m * stride_tok + do_ptrs += step_m * stride_tok return dk, dv # the main inner-loop logic for computing dQ @triton.jit -def _attn_bwd_dq(dq, q, K, V, +def _attn_bwd_dq(dq, q, K, V, # do, m, D, # shared by Q/K/V/DO. - stride_tok, stride_d, - H, N_CTX, - BLOCK_M2: tl.constexpr, - BLOCK_N2: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + HEAD_DIM: tl.constexpr, # Filled in by the wrapper. - start_m, start_n, num_steps, + start_m, start_n, num_steps, # MASK: tl.constexpr): offs_m = start_m + tl.arange(0, BLOCK_M2) offs_n = start_n + tl.arange(0, BLOCK_N2) - offs_k = tl.arange(0, BLOCK_DMODEL) - KT_block_ptr = tl.make_block_ptr( - base=K, - shape=(BLOCK_DMODEL, N_CTX), - strides=(stride_d, stride_tok), - offsets=(0, start_n), - block_shape=(BLOCK_DMODEL, BLOCK_N2), - order=(0, 1) - ) - VT_block_ptr = tl.make_block_ptr( - base=V, - shape=(BLOCK_DMODEL, N_CTX), - strides=(stride_d, stride_tok), - offsets=(0, start_n), - block_shape=(BLOCK_DMODEL, BLOCK_N2), - order=(0, 1) - ) + offs_k = tl.arange(0, HEAD_DIM) + kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d + vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d # D (= delta) is pre-divided by ds_scale. Di = tl.load(D + offs_m) # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. @@ -309,7 +282,8 @@ def _attn_bwd_dq(dq, q, K, V, curr_n = start_n step_n = BLOCK_N2 for blk_idx in range(num_steps): - kT = tl.load(KT_block_ptr) + kT = tl.load(kT_ptrs) + vT = tl.load(vT_ptrs) qk = tl.dot(q, kT) p = tl.math.exp2(qk - m) # Autoregressive masking. 
@@ -318,7 +292,6 @@ def _attn_bwd_dq(dq, q, K, V, mask = (offs_m[:, None] >= offs_n[None, :]) p = tl.where(mask, p, 0.0) # Compute dP and dS. - vT = tl.load(VT_block_ptr) dp = tl.dot(do, vT).to(tl.float32) ds = p * (dp - Di[:, None]) ds = ds.to(tl.float16) @@ -327,50 +300,25 @@ def _attn_bwd_dq(dq, q, K, V, dq += tl.dot(ds, tl.trans(kT)) # Increment pointers. curr_n += step_n - KT_block_ptr = tl.advance(KT_block_ptr, (0, step_n)) - VT_block_ptr = tl.advance(VT_block_ptr, (0, step_n)) + kT_ptrs += step_n * stride_tok + vT_ptrs += step_n * stride_tok return dq -@triton.autotune( - configs=[ - triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 1}, - num_stages=1, num_warps=4), - triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2}, - num_stages=1, num_warps=4), - triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 1}, - num_stages=1, num_warps=4), - triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 2}, - num_stages=1, num_warps=4), - triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 1}, - num_stages=1, num_warps=4), - triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 2}, - num_stages=1, num_warps=4), - triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 1}, - num_stages=1, num_warps=4), - triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2}, - num_stages=1, num_warps=4), - triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2}, - num_stages=1, num_warps=8), - ], - key=['H', 'N_CTX', 'BLOCK_DMODEL'], -) - @triton.jit -def _attn_bwd(Q, K, V, sm_scale, - DO, - DQ, DK, DV, +def _attn_bwd(Q, K, V, sm_scale, # + DO, # + DQ, DK, DV, # M, D, # shared by Q/K/V/DO. - stride_z, stride_h, stride_tok, stride_d, - # H = 16, N_CTX = 1024 - H, N_CTX, - BLOCK_DMODEL: tl.constexpr, - BLOCK_M1: tl.constexpr, - BLOCK_N1: tl.constexpr, - BLOCK_M2: tl.constexpr, - BLOCK_N2: tl.constexpr, - BLK_SLICE_FACTOR: tl.constexpr): + stride_z, stride_h, stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M1: tl.constexpr, # + BLOCK_N1: tl.constexpr, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + BLK_SLICE_FACTOR: tl.constexpr, # + HEAD_DIM: tl.constexpr): LN2: tl.constexpr = 0.6931471824645996 # = ln(2) bhid = tl.program_id(2) @@ -389,91 +337,58 @@ def _attn_bwd(Q, K, V, sm_scale, M += off_chz D += off_chz - offs_k = tl.arange(0, BLOCK_DMODEL) + # load scales + offs_k = tl.arange(0, HEAD_DIM) start_n = pid * BLOCK_N1 - # This assignment is important. It is what allows us to pick the diagonal - # blocks. Later, when we want to do the lower triangular, we update start_m - # after the first dkdv call. 
start_m = start_n MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR offs_n = start_n + tl.arange(0, BLOCK_N1) - dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) - dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) - K_block_ptr = tl.make_block_ptr( - base=K, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_tok, stride_d), - offsets=(start_n, 0), - block_shape=(BLOCK_N1, BLOCK_DMODEL), - order=(1, 0), - ) - V_block_ptr = tl.make_block_ptr( - base=V, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_tok, stride_d), - offsets=(start_n, 0), - block_shape=(BLOCK_N1, BLOCK_DMODEL), - order=(1, 0), - ) - - # load K and V: they stay in SRAM throughout the inner loop for dkdv. - k = tl.load(K_block_ptr) - v = tl.load(V_block_ptr) + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) num_steps = BLOCK_N1 // MASK_BLOCK_M1 - dk, dv = _attn_bwd_dkdv(dk, dv, - Q, k, v, sm_scale, - DO, - M, D, - stride_tok, stride_d, - H, N_CTX, - MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, - start_n, start_m, num_steps, - MASK=True + dk, dv = _attn_bwd_dkdv(dk, dv, # + Q, k, v, sm_scale, # + DO, # + M, D, # + stride_tok, stride_d, # + H, N_CTX, # + MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM, # + start_n, start_m, num_steps, # + MASK=True # ) start_m += num_steps * MASK_BLOCK_M1 num_steps = (N_CTX - start_m) // BLOCK_M1 # Compute dK and dV for non-masked blocks. - dk, dv = _attn_bwd_dkdv( - dk, dv, - Q, k, v, sm_scale, - DO, - M, D, - stride_tok, stride_d, - H, N_CTX, - BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, - start_n, start_m, num_steps, - MASK=False + dk, dv = _attn_bwd_dkdv( # + dk, dv, # + Q, k, v, sm_scale, # + DO, # + M, D, # + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M1, BLOCK_N1, HEAD_DIM, # + start_n, start_m, num_steps, # + MASK=False # ) - DV_block_ptrs = tl.make_block_ptr( - base=DV, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_tok, stride_d), - offsets=(start_n, 0), - block_shape=(BLOCK_N1, BLOCK_DMODEL), - order=(1,0) - ) - tl.store(DV_block_ptrs, dv.to(tl.float16)) + dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dv_ptrs, dv) # Write back dK. 
dk *= sm_scale - DK_block_ptrs = tl.make_block_ptr( - base=DK, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_tok, stride_d), - offsets=(start_n, 0), - block_shape=(BLOCK_N1, BLOCK_DMODEL), - order=(1,0) - ) - tl.store(DK_block_ptrs, dk.to(tl.float16)) + dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dk_ptrs, dk) # THIS BLOCK DOES DQ: start_m = pid * BLOCK_M2 @@ -482,26 +397,9 @@ def _attn_bwd(Q, K, V, sm_scale, MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR offs_m = start_m + tl.arange(0, BLOCK_M2) - Q_block_ptr = tl.make_block_ptr( - base=Q, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_tok, stride_d), - offsets=(start_m, 0), - block_shape=(BLOCK_M2, BLOCK_DMODEL), - order=(1, 0) - ) - - DO_block_ptr = tl.make_block_ptr( - base=DO, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_tok, stride_d), - offsets=(start_m, 0), - block_shape=(BLOCK_M2, BLOCK_DMODEL), - order=(1, 0) - ) - q = tl.load(Q_block_ptr) - do = tl.load(DO_block_ptr) - dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32) + q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) + do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) m = tl.load(M + offs_m) m = m[:, None] @@ -512,39 +410,29 @@ def _attn_bwd(Q, K, V, sm_scale, # not due to anything important. I just wanted to reuse the loop # structure for dK & dV above as much as possible. num_steps = BLOCK_M2 // MASK_BLOCK_N2 - dq = _attn_bwd_dq(dq, q, K, V, - do, m, D, - stride_tok, stride_d, - H, N_CTX, - BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL, - start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, - MASK=True + dq = _attn_bwd_dq(dq, q, K, V, # + do, m, D, # + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM, # + start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, # + MASK=True # ) end_n -= num_steps * MASK_BLOCK_N2 # stage 2 num_steps = end_n // BLOCK_N2 - dq = _attn_bwd_dq(dq, q, K, V, - do, m, D, - stride_tok, stride_d, - H, N_CTX, - BLOCK_M2, BLOCK_N2, BLOCK_DMODEL, - start_m, end_n - num_steps * BLOCK_N2, num_steps, - MASK=False + dq = _attn_bwd_dq(dq, q, K, V, # + do, m, D, # + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M2, BLOCK_N2, HEAD_DIM, # + start_m, end_n - num_steps * BLOCK_N2, num_steps, # + MASK=False # ) # Write back dQ. - DQ_block_ptr = tl.make_block_ptr( - base=DQ, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_tok, stride_d), - offsets=(start_m, 0), - block_shape=(BLOCK_M2, BLOCK_DMODEL), - order=(1, 0) - ) + dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d dq *= LN2 - tl.store(DQ_block_ptr, dq.to(tl.float16)) - - -empty = torch.empty(128, device="cuda") + tl.store(dq_ptrs, dq) class _attention(torch.autograd.Function): @@ -552,56 +440,42 @@ class _attention(torch.autograd.Function): @staticmethod def forward(ctx, q, k, v, causal, sm_scale): # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128} - o = torch.empty_like(q, dtype=v.dtype) - if torch.version.hip is None: - BLOCK_M = 128 - BLOCK_N = 64 if Lk <= 64 else 32 - num_stages = 4 if Lk <= 64 else 3 - num_warps = 4 if Lk <= 64 else 8 - # Tuning for H100 - if torch.cuda.get_device_capability()[0] == 9: - num_warps = 8 - num_stages = 7 if Lk >= 64 else 3 + HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] + # when v is in float8_e5m2 it is transposed. 
+ HEAD_DIM_V = v.shape[-1] + assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V + assert HEAD_DIM_K in {16, 32, 64, 128, 256} + o = torch.empty_like(q) stage = 3 if causal else 1 - grid = lambda META: ( - triton.cdiv(q.shape[2], META['BLOCK_M']), - q.shape[0] * q.shape[1], - 1 - ) - M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + extra_kern_args = {} + # Tuning for AMD target + if is_hip(): + waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2 + extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True} + + grid = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), q.shape[0] * q.shape[1], 1) + M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) _attn_fwd[grid]( - q, k, v, sm_scale, M, o, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - o.stride(0), o.stride(1), o.stride(2), o.stride(3), - q.shape[0], q.shape[1], - N_CTX=q.shape[2], - BLOCK_DMODEL=Lk, - STAGE=stage, - ) - - ## restore the grid for bwd kernel - #best_config = _attn_fwd.get_best_config() - block_m = 64#int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1]) - grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1) + q, k, v, sm_scale, M, o, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + o.stride(0), o.stride(1), o.stride(2), o.stride(3), # + q.shape[0], q.shape[1], # + N_CTX=q.shape[2], # + HEAD_DIM=HEAD_DIM_K, # + STAGE=stage, # + **extra_kern_args) ctx.save_for_backward(q, k, v, o, M) ctx.grid = grid ctx.sm_scale = sm_scale - ctx.BLOCK_DMODEL = Lk + ctx.HEAD_DIM = HEAD_DIM_K ctx.causal = causal return o @staticmethod def backward(ctx, do): - if torch.version.hip is not None: - BLOCK = 64 - else: - BLOCK = 128 q, k, v, o, M = ctx.saved_tensors assert do.is_contiguous() assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() @@ -610,96 +484,49 @@ def backward(ctx, do): dv = torch.empty_like(v) BATCH, N_HEAD, N_CTX = q.shape[:3] PRE_BLOCK = 128 - NUM_WARPS, NUM_STAGES = 4, 1 - BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32 + NUM_WARPS, NUM_STAGES = 4, 5 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 BLK_SLICE_FACTOR = 2 RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) arg_k = k arg_k = arg_k * (ctx.sm_scale * RCP_LN2) + PRE_BLOCK = 128 assert N_CTX % PRE_BLOCK == 0 pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD) delta = torch.empty_like(M) _attn_bwd_preprocess[pre_grid]( - o, do, - delta, - BATCH, N_HEAD, N_CTX, - BLOCK_M=PRE_BLOCK, D_HEAD=ctx.BLOCK_DMODEL - ) - grid = lambda META: ( - triton.cdiv(N_CTX, META['BLOCK_N1']), - 1, - BATCH * N_HEAD + o, do, # + delta, # + BATCH, N_HEAD, N_CTX, # + BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM # ) + grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) _attn_bwd[grid]( - q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, - M, delta, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - N_HEAD, N_CTX, - BLOCK_DMODEL=ctx.BLOCK_DMODEL + q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, # + M, delta, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + N_HEAD, N_CTX, # + BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, # + BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, # + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # + HEAD_DIM=ctx.HEAD_DIM, # + num_warps=NUM_WARPS, # + num_stages=NUM_STAGES # ) return dq, dk, dv, None, None + 
attention = _attention.apply -name_to_torch_types = { - 'fp16': torch.float16, -} - -if TORCH_HAS_FP8E5B16: - name_to_torch_types['fp8'] = torch.float8_e5m2fnuz - -@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD, dtype', -[ (*shape, dtype) - for shape in [(4, 48, 1024, 64), - (4, 48, 2048, 64), - (4, 48, 4096, 64), - (4, 48, 1024, 128), - (4, 48, 2048, 128), - (4, 48, 4096, 128)] - for dtype in ['fp16', 'fp8']]) -@pytest.mark.parametrize('causal', [False, True]) -def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype): - if dtype == 'fp8' and not TORCH_HAS_FP8E5B16: - pytest.skip("fp8 not supported") - torch.manual_seed(20) - q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5).requires_grad_() - k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5).requires_grad_() - v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5).requires_grad_() - q = q.to(name_to_torch_types[dtype]) - k = k.to(name_to_torch_types[dtype]) - sm_scale = 0.5 - dout = torch.randn_like(q, dtype=torch.float16) - # reference implementation - M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) - p = torch.matmul(q.half(), k.transpose(2, 3).half()) * sm_scale - if causal: - p[:, :, M == 0] = float("-inf") - p = torch.softmax(p.float(), dim=-1).half() - ref_out = torch.matmul(p, v) - # triton implementation - tri_out = attention(q, k, v, causal, sm_scale) - # compare - atol = 1.4e-1 if dtype == 'fp8' else 1e-2 - rtol = 1e-2 if dtype == 'fp8' else 0 - torch.testing.assert_close(ref_out, tri_out, atol=atol, rtol=rtol) - - -@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', - [(4, 48, 1024, 64), - (4, 48, 2048, 64), - (4, 48, 4096, 64), - (1, 16, 8192, 64), - (1, 16, 1024, 64), - ]) -def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16): +@pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM", [(1, 2, 1024, 64)]) +@pytest.mark.parametrize("causal", [True]) +def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16): torch.manual_seed(20) - causal = True - q = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) - k = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) - v = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) - + q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) sm_scale = 0.5 dout = torch.randn_like(q) # reference implementation @@ -708,27 +535,28 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16): if causal: p[:, :, M == 0] = float("-inf") p = torch.softmax(p.float(), dim=-1).half() + # p = torch.exp(p) ref_out = torch.matmul(p, v) ref_out.backward(dout) ref_dv, v.grad = v.grad.clone(), None ref_dk, k.grad = k.grad.clone(), None ref_dq, q.grad = q.grad.clone(), None - # # triton implementation - tri_out = attention(q, k, v, causal, sm_scale) + # triton implementation + tri_out = attention(q, k, v, causal, sm_scale).half() tri_out.backward(dout) tri_dv, v.grad = v.grad.clone(), None tri_dk, k.grad = k.grad.clone(), None tri_dq, q.grad = q.grad.clone(), None # compare - 
torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0) - if torch.version.hip is None: - torch.testing.assert_close(ref_dv, tri_dv, atol=1e-2, rtol=0) - # The current block size for MI200 series is 64x64. This results in - # larger differences in float results due to rounding. - else: - torch.testing.assert_close(ref_dv, tri_dv, atol=5e-2, rtol=0) - torch.testing.assert_close(ref_dk, tri_dk, atol=5e-2, rtol=1e-2) - torch.testing.assert_close(ref_dq, tri_dq, atol=5e-2, rtol=1e-2) + assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0) + rtol = 0.0 + # Relative tolerance workaround for known hardware limitation of MI200 GPU. + # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices + if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a": + rtol = 1e-2 + assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=rtol) + assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=rtol) + assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=rtol) try: @@ -738,69 +566,68 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16): except BaseException: HAS_FLASH = False +TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2') +BATCH, N_HEADS, HEAD_DIM = 4, 32, 64 # vary seq length for fixed head and batch=4 configs = [] -for mode in ['fwd', 'bwd']: - for D_HEAD in [128, 64]: - for causal in [False, True]: - if mode == 'bwd' and causal == False: - continue - configs.append(triton.testing.Benchmark( - x_names=['BATCH', 'H', 'N_CTX'], - x_vals=[(4, 16, 1024), - (8, 16, 2048), - (4, 16, 4096), - (2, 16, 8192), - (1, 16, 16384), - (4, 48, 1024), - (4, 48, 2048), - (4, 48, 4096), - (4, 48, 8192), - (4, 48, 16384), - ], - line_arg='provider', - line_vals=['triton'] + (['flash'] if HAS_FLASH else []), - line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []), - styles=[('red', '-'), ('blue', '-')], - ylabel='ms', - plot_name=f'fused-attention-{mode}-d{D_HEAD}-causal={causal}', +for mode in ["fwd", "bwd"]: + for causal in [True, False]: + if mode == "bwd" and not causal: + continue + configs.append( + triton.testing.Benchmark( + x_names=["N_CTX"], + x_vals=[2**i for i in range(10, 15)], + line_arg="provider", + line_vals=["triton-fp16"] + (["triton-fp8"] if TORCH_HAS_FP8 else []) + + (["flash"] if HAS_FLASH else []), + line_names=["Triton [FP16]"] + (["Triton [FP8]"] if TORCH_HAS_FP8 else []) + + (["Flash-2"] if HAS_FLASH else []), + styles=[("red", "-"), ("blue", "-"), ("green", "-")], + ylabel="ms", + plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{HEAD_DIM}-{mode}-causal={causal}", args={ - 'D_HEAD': D_HEAD, - 'dtype': torch.float16, - 'mode': mode, - 'causal': causal, + "H": N_HEADS, + "BATCH": BATCH, + "HEAD_DIM": HEAD_DIM, + "mode": mode, + "causal": causal, }, )) @triton.testing.perf_report(configs) -def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda"): +def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device="cuda"): assert mode in ["fwd", "bwd"] warmup = 25 - rep = 10 - # Bwd pass only supports causal=True right now - if mode == 'bwd': - causal = True - if provider == "triton": - q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", 
requires_grad=True) - sm_scale = D_HEAD ** -0.5 + rep = 100 + dtype = torch.float16 + if "triton" in provider: + q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) + k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) + v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) + if mode == "fwd" and "fp8" in provider: + q = q.to(torch.float8_e5m2) + k = k.to(torch.float8_e5m2) + v = v.permute(0, 1, 3, 2).contiguous() + v = v.permute(0, 1, 3, 2) + v = v.to(torch.float8_e5m2) + sm_scale = 1.3 fn = lambda: attention(q, k, v, causal, sm_scale) - if mode == 'bwd': + if mode == "bwd": o = fn() do = torch.randn_like(o) fn = lambda: o.backward(do, retain_graph=True) ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) if provider == "flash": - qkv = torch.randn((BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True) + qkv = torch.randn((BATCH, N_CTX, 3, H, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) fn = lambda: flash_attn_func(qkv, causal=causal) if mode == "bwd": o = fn() do = torch.randn_like(o) fn = lambda: o.backward(do, retain_graph=True) ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) - flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD + flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * HEAD_DIM total_flops = 2 * flops_per_matmul if causal: total_flops *= 0.5 @@ -808,5 +635,7 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute) return total_flops / ms * 1e-9 -# only works on post-Ampere GPUs right now -bench_flash_attention.run(save_path=".", print_data=True) + +if __name__ == "__main__": + # only works on post-Ampere GPUs right now + bench_flash_attention.run(save_path=".", print_data=True)
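
For reference, below is a minimal usage sketch of the `attention` autograd wrapper that the perf-kernel version of flash-attention.py in this series exposes (`attention = _attention.apply`, taking `q, k, v, causal, sm_scale`). This is an editor's illustration, not part of the patch: it assumes a ROCm or CUDA device is available and that the snippet is appended to (or run inside) python/perf-kernels/flash-attention.py itself, since that file name contains a dash and is not directly importable as a module. Shapes, dtype, and sm_scale follow the tests and benchmark defined in the patch.

import torch  # already imported at the top of flash-attention.py; repeated here for clarity

# Hypothetical standalone driver (editor's sketch). In the patch itself this
# logic lives in test_op_bwd and bench_flash_attention.
def run_attention_once(Z=1, H=16, N_CTX=1024, D_HEAD=64, causal=True):
    dtype = torch.float16
    q = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    k = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    v = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    sm_scale = D_HEAD ** -0.5  # same scaling the benchmark in this patch uses
    # Forward pass: launches the autotuned _attn_fwd Triton kernel.
    o = attention(q, k, v, causal, sm_scale)
    # Backward pass: launches _attn_bwd_preprocess followed by _attn_bwd.
    # N_CTX must be a multiple of PRE_BLOCK (128) and D_HEAD in {16, 32, 64, 128}.
    o.backward(torch.randn_like(o))
    return o, q.grad, k.grad, v.grad

Calling run_attention_once() with the defaults mirrors the (1, 16, 1024, 64) test case in test_op_bwd; the gradients it returns can be compared against a PyTorch softmax-attention reference in the same way the test does.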