diff --git a/test/xrt/16_gemm_8x16_transform_vec_4x4/gen.py b/test/xrt/16_gemm_8x16_transform_vec_4x4/gen.py new file mode 100644 index 000000000..8bae9d067 --- /dev/null +++ b/test/xrt/16_gemm_8x16_transform_vec_4x4/gen.py @@ -0,0 +1,264 @@ +# aie.py -*- Python -*- +# +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: MIT + +import air +import air.compiler.util +from air.dialects import linalg, tensor, arith, func, memref +from air.ir import * +import air.passmanager +from air.dialects import air as airdialect +from air.compiler.util import run_transform +import sys + +ctx = Context() + +################################################ +## Tiling +################################################ + +air_tiled_ir_string = """ +#map = affine_map<()[s0] -> (s0 * 128)> +#map1 = affine_map<()[s0] -> (s0 * 64)> +#map2 = affine_map<()[s0] -> (s0 * 8)> +#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)> +#map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)> +#map5 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)> +#map6 = affine_map<(d0) -> (d0 * 16)> +#map7 = affine_map<(d0) -> (d0 * 4)> +module { + func.func @matmul_bf16(%0 : memref<512x1024xbf16>, %1 : memref<128x8x8x64xbf16>, %2 : memref<512x512xbf16>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + %cst = arith.constant 0.000000e+00 : bf16 + %shape = memref.alloc() : memref<4xindex> + memref.store %c16, %shape[%c0] : memref<4xindex> + memref.store %c8, %shape[%c1] : memref<4xindex> + memref.store %c8, %shape[%c2] : memref<4xindex> + memref.store %c4, %shape[%c3] : memref<4xindex> + scf.parallel (%arg0, %arg1) = (%c0, %c0) to (%c4, %c4) step (%c1, %c1) { + %3 = affine.apply #map()[%arg0] + %4 = affine.apply #map()[%arg1] + %map1 = affine.apply #map6(%arg1) + %subview = memref.subview %2[%3, %4] [128, 128] [1, 1] : memref<512x512xbf16> to memref<128x128xbf16, strided<[512, 1], offset: ?>> + %alloc = memref.alloc() : memref<128x1024xbf16, 1> + scf.for %arg2 = %c0 to %c1024 step %c256 { + %subview_2 = memref.subview %0[%3, %arg2] [128, 256] [1, 1] : memref<512x1024xbf16> to memref<128x256xbf16, strided<[1024, 1], offset: ?>> + %subview_3 = memref.subview %alloc[0, %arg2] [128, 256] [1, 1] : memref<128x1024xbf16, 1> to memref<128x256xbf16, strided<[1024, 1], offset: ?>, 1> + memref.copy %subview_2, %subview_3 : memref<128x256xbf16, strided<[1024, 1], offset: ?>> to memref<128x256xbf16, strided<[1024, 1], offset: ?>, 1> + } + %alloc_0 = memref.alloc() : memref<128x8x8x16xbf16, 1> + scf.for %arg2 = %c0 to %c128 step %c32 { + %subview_2 = memref.subview %1[%arg2, 0, 0, %map1] [32, 8, 8, 16] [1, 1, 1, 1] : memref<128x8x8x64xbf16> to memref<32x8x8x16xbf16, strided<[4096, 512, 64, 1], offset: ?>> + %transpose = memref.transpose %subview_2 (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<32x8x8x16xbf16, strided<[4096, 512, 64, 1], offset: ?>> to memref<32x8x8x16xbf16, strided<[4096, 64, 512, 1], offset: ?>> + %subview_3 = memref.subview %alloc_0[%arg2, 0, 0, 0] [32, 8, 8, 16] [1, 1, 1, 1] : memref<128x8x8x16xbf16, 1> to memref<32x8x8x16xbf16, strided<[1024, 128, 16, 1], offset: ?>, 1> + memref.copy %transpose, %subview_3 : 
memref<32x8x8x16xbf16, strided<[4096, 64, 512, 1], offset: ?>> to memref<32x8x8x16xbf16, strided<[1024, 128, 16, 1], offset: ?>, 1> + } + %alloc_1 = memref.alloc() : memref<128x128xbf16, 1> + scf.parallel (%arg2, %arg3) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + %5 = affine.apply #map1()[%arg2] + %6 = affine.apply #map1()[%arg3] + %map = affine.apply #map7(%arg3) + %subview_2 = memref.subview %alloc_1[%5, %6] [64, 64] [1, 1] : memref<128x128xbf16, 1> to memref<64x64xbf16, strided<[128, 1], offset: ?>, 1> + %alloc_3 = memref.alloc() : memref<16x16x4x4xbf16, 2> + linalg.fill ins(%cst : bf16) outs(%alloc_3 : memref<16x16x4x4xbf16, 2>) + scf.for %arg4 = %c0 to %c128 step %c8 { + %7 = affine.apply #map1()[%arg2] + %8 = affine.apply #map2()[%arg4] + %subview_4 = memref.subview %alloc[%7, %8] [64, 64] [1, 1] : memref<128x1024xbf16, 1> to memref<64x64xbf16, strided<[1024, 1], offset: ?>, 1> + %alloc_5 = memref.alloc() : memref<8x16x4x8xbf16, 2> + %expand_shape = memref.expand_shape %subview_4 [[0, 1], [2, 3]] output_shape [16, 4, 8, 8] : memref<64x64xbf16, strided<[1024, 1], offset: ?>, 1> into memref<16x4x8x8xbf16, strided<[4096, 1024, 8, 1], offset: ?>, 1> + %transpose_6 = memref.transpose %expand_shape (d0, d1, d2, d3) -> (d2, d0, d1, d3) : memref<16x4x8x8xbf16, strided<[4096, 1024, 8, 1], offset: ?>, 1> to memref<8x16x4x8xbf16, strided<[8, 4096, 1024, 1], offset: ?>, 1> + air.dma_memcpy_nd (%alloc_5[] [] [], %transpose_6[] [] []) : (memref<8x16x4x8xbf16, 2>, memref<8x16x4x8xbf16, strided<[8, 4096, 1024, 1], offset: ?>, 1>) + %9 = affine.apply #map2()[%arg4] + %10 = affine.apply #map1()[%arg3] + %subview_7 = memref.subview %alloc_0[%arg4, 0, %map, 0] [8, 8, 4, 16] [1, 1, 1, 1] : memref<128x8x8x16xbf16, 1> to memref<8x8x4x16xbf16, strided<[1024, 128, 16, 1], offset: ?>, 1> + %alloc_8 = memref.alloc() : memref<4x4x8x8x4xbf16, 2> + %expand_shape_9 = memref.expand_shape %subview_7 [[0], [1], [2], [3, 4]] output_shape [8, 8, 4, 4, 4] : memref<8x8x4x16xbf16, strided<[1024, 128, 16, 1], offset: ?>, 1> into memref<8x8x4x4x4xbf16, strided<[1024, 128, 16, 4, 1], offset: ?>, 1> + %transpose_10 = memref.transpose %expand_shape_9 (d0, d1, d2, d3, d4) -> (d2, d3, d0, d1, d4) : memref<8x8x4x4x4xbf16, strided<[1024, 128, 16, 4, 1], offset: ?>, 1> to memref<4x4x8x8x4xbf16, strided<[16, 4, 1024, 128, 1], offset: ?>, 1> + air.dma_memcpy_nd (%alloc_8[] [] [], %transpose_10[] [] []) : (memref<4x4x8x8x4xbf16, 2>, memref<4x4x8x8x4xbf16, strided<[16, 4, 1024, 128, 1], offset: ?>, 1>) + %reshape = memref.reshape %alloc_8(%shape) : (memref<4x4x8x8x4xbf16, 2>, memref<4xindex>) -> memref<16x8x8x4xbf16, 2> + linalg.generic {indexing_maps = [#map3, #map4, #map5], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], library_call = "matmul_bf16_bf16"} ins(%alloc_5, %reshape : memref<8x16x4x8xbf16, 2>, memref<16x8x8x4xbf16, 2>) outs(%alloc_3 : memref<16x16x4x4xbf16, 2>) { + ^bb0(%in: bf16, %in_11: bf16, %out: bf16): + %11 = arith.mulf %in, %in_11 : bf16 + %12 = arith.addf %out, %11 : bf16 + linalg.yield %12 : bf16 + } + memref.dealloc %alloc_5 : memref<8x16x4x8xbf16, 2> + memref.dealloc %alloc_8 : memref<4x4x8x8x4xbf16, 2> + } + %transpose = memref.transpose %alloc_3 (d0, d1, d2, d3) -> (d1, d2, d0, d3) : memref<16x16x4x4xbf16, 2> to memref<16x4x16x4xbf16, strided<[16, 4, 256, 1]>, 2> + air.dma_memcpy_nd (%subview_2[] [] [], %transpose[] [] []) : (memref<64x64xbf16, strided<[128, 1], offset: ?>, 1>, memref<16x4x16x4xbf16, strided<[16, 4, 256, 1]>, 2>) + memref.dealloc %alloc_3 : 
memref<16x16x4x4xbf16, 2>
+        scf.reduce
+      }
+      memref.copy %alloc_1, %subview : memref<128x128xbf16, 1> to memref<128x128xbf16, strided<[512, 1], offset: ?>>
+      memref.dealloc %alloc : memref<128x1024xbf16, 1>
+      memref.dealloc %alloc_0 : memref<128x8x8x16xbf16, 1>
+      memref.dealloc %alloc_1 : memref<128x128xbf16, 1>
+      scf.reduce
+    }
+    return
+  }
+}
+"""
+air_module = Module.parse(air_tiled_ir_string, context=ctx)
+
+with open("air_tiled.mlir", "w") as f:
+    f.write(str(air_module))
+
+################################################
+## Binding scf.parallel to air hierarchies
+################################################
+
+pipeline = (
+    "builtin.module("
+    + ",".join(
+        [
+            "air-copy-to-dma",
+            "buffer-results-to-out-params",
+            "air-linalg-to-func{link-with=mm.o}",
+            "air-par-to-herd{depth=1}",
+            "air-par-to-launch{has-air-segment=true}",
+            "canonicalize",
+            "cse",
+        ]
+    )
+    + ")"
+)
+pm = air.passmanager.PassManager.parse(pipeline, context=ctx)
+pm.run(air_module.operation)
+
+with open("air_sync.mlir", "w") as f:
+    f.write(str(air_module))
+
+###############################################
+# Extract event dependency and optimize schedule
+###############################################
+
+pipeline = (
+    "builtin.module("
+    + ",".join(
+        [
+            "air-dependency",
+            "air-dependency-schedule-opt",
+            "air-specialize-dma-broadcast",
+            "air-dma-to-channel",
+            "canonicalize",
+            "cse",
+            "air-dependency-canonicalize",
+            "canonicalize",
+            "cse",
+            "func.func(air-loop-fusion)",
+            "air-label-scf-for-to-ping-pong",
+        ]
+    )
+    + ")"
+)
+pm = air.passmanager.PassManager.parse(pipeline, context=ctx)
+pm.run(air_module.operation)
+
+with open("air_fusion.mlir", "w") as f:
+    f.write(str(air_module))
+
+# Re-parsing the IR here works around a segmentation fault seen when the next
+# pipeline runs on the in-memory module; the root cause is not yet understood.
+air_module = Module.parse(str(air_module), context=ctx)
+pipeline = (
+    "builtin.module("
+    + ",".join(
+        [
+            "air-ping-pong-transform{keep-memref-dealloc=true}",
+            "canonicalize",
+            "cse",
+            "air-specialize-channel-wrap-and-stride",
+            "canonicalize",
+            "cse",
+        ]
+    )
+    + ")"
+)
+pm = air.passmanager.PassManager.parse(pipeline, context=ctx)
+pm.run(air_module.operation)
+with open("aircc_input.mlir", "w") as f:
+    f.write(str(air_module))
+
+################################################
+## Place herds onto the segment
+################################################
+
+air_async_module = Module.parse(str(air_module), context=ctx)
+pipeline = (
+    "builtin.module("
+    + ",".join(
+        [
+            "func.func(air-collapse-herd)",
+            "canonicalize",
+            "cse",
+            "air-place-herds{num-rows=4 num-cols=1 row-anchor=2 col-anchor=0}",
+            "canonicalize",
+            "cse",
+            "func.func(air-renumber-dma)",
+        ]
+    )
+    + ")"
+)
+pm = air.passmanager.PassManager.parse(pipeline, context=ctx)
+pm.run(air_module.operation)
+with open("air_placed.mlir", "w") as f:
+    f.write(str(air_module))
+
+################################################
+## MLIR-AIR to MLIR-AIE
+################################################
+
+pipeline = (
+    "builtin.module("
+    + ",".join(
+        [
+            "canonicalize",
+            "cse",
+            "air-to-aie{row-offset=2 col-offset=0 device=npu1_4col emit-while-loop=true}",
+            "canonicalize",
+        ]
+    )
+    + ")"
+)
+pm = air.passmanager.PassManager.parse(pipeline, context=ctx)
+pm.run(air_module.operation)
+with open("aircc_decomp_aiecc.mlir", "w") as f:
+    f.write(str(air_module))
+
+################################################
+## MLIR-AIR runtime lowering
+################################################
+
+pipeline = (
+    "builtin.module("
+    + ",".join(
+        [
+            "air-to-std",
+            "canonicalize",
+ 
"symbol-dce", + "func.func(affine-loop-opt{affine-opt-tile-sizes=16,16})", + "func.func(air-unroll-outer-affine-loops{depth=4})", + "affine-expand-index-ops", + "airrt-to-npu", + "canonicalize", + ] + ) + + ")" +) +pm = air.passmanager.PassManager.parse(pipeline, context=ctx) +pm.run(air_module.operation) +with open("aie.mlir", "w") as f: + f.write(str(air_module)) diff --git a/test/xrt/16_gemm_8x16_transform_vec_4x4/matrix_multiplication.h b/test/xrt/16_gemm_8x16_transform_vec_4x4/matrix_multiplication.h new file mode 100644 index 000000000..64afca3c4 --- /dev/null +++ b/test/xrt/16_gemm_8x16_transform_vec_4x4/matrix_multiplication.h @@ -0,0 +1,290 @@ +//===- matrix_multiplication.h ----------------------------000---*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// Copyright (C) 2024, Advanced Micro Devices, Inc. +// +//===----------------------------------------------------------------------===// + +// This file contains common helper functions for the matrix multiplication +// host code, such as verifying and printing matrices. + +#ifndef MATRIX_MULTIPLICATION_H +#define MATRIX_MULTIPLICATION_H + +#include +#include + +namespace matmul_common { + +namespace po = boost::program_options; + +// -------------------------------------------------------------------------- +// Command Line Argument Handling +// -------------------------------------------------------------------------- + +void check_arg_file_exists(po::variables_map &vm_in, std::string name) { + if (!vm_in.count(name)) { + throw std::runtime_error("Error: no " + name + " file was provided\n"); + } else { + std::ifstream test(vm_in[name].as()); + if (!test) { + throw std::runtime_error("The " + name + " file " + + vm_in[name].as() + + " does not exist.\n"); + } + } +} + +void add_default_options(po::options_description &desc) { + desc.add_options()("help,h", "produce help message")( + "xclbin,x", po::value()->required(), + "the input xclbin path")( + "kernel,k", po::value()->required(), + "the kernel name in the XCLBIN (for instance PP_PRE_FD)")( + "verbosity,v", po::value()->default_value(0), + "the verbosity of the output")( + "instr,i", po::value()->required(), + "path of file containing userspace instructions to be sent to the LX6"); +} + +void parse_options(int argc, const char *argv[], po::options_description &desc, + po::variables_map &vm) { + try { + po::store(po::parse_command_line(argc, argv, desc), vm); + po::notify(vm); + + if (vm.count("help")) { + std::cout << desc << "\n"; + std::exit(1); + } + } catch (const std::exception &ex) { + std::cerr << ex.what() << "\n\n"; + std::cerr << "Usage:\n" << desc << "\n"; + std::exit(1); + } + + check_arg_file_exists(vm, "xclbin"); + check_arg_file_exists(vm, "instr"); +} + +// -------------------------------------------------------------------------- +// AIE Specifics +// -------------------------------------------------------------------------- + +std::vector load_instr_sequence(std::string instr_path) { + std::ifstream instr_file(instr_path); + std::string line; + std::vector instr_v; + while (std::getline(instr_file, line)) { + std::istringstream iss(line); + uint32_t a; + if (!(iss >> std::hex >> a)) { + throw std::runtime_error("Unable to parse instruction file\n"); + } + instr_v.push_back(a); + } + return instr_v; +} + +// -------------------------------------------------------------------------- 
+// Matrix / Float / Math +// -------------------------------------------------------------------------- + +static inline std::int16_t random_int16_t() { + return (std::int16_t)rand() % 0x10000; +} + +static inline std::bfloat16_t random_bfloat16_t() { + // Random numbers should NOT be uniformly between 0 and 1, because that + // would make the matrix product AB always close to 1. + return std::bfloat16_t(4.0 * (float)rand() / (float)(RAND_MAX)); +} + +template +void matmul_naive(int M, int N, int K, const std::vector A, + const std::vector B, std::vector &C) { + for (int row = 0; row < M; row++) { + for (int col = 0; col < N; col++) { + Tout running_sum = 0; + for (int k = 0; k < K; k++) { + running_sum += Tout(A[row * K + k] * B[k * N + col]); + } + C[row * N + col] = Tout(running_sum); + } + } +} + +template +void matmul(int M, int N, int K, const std::vector A, + const std::vector B, std::vector &C) { + // A is an MxK matrix + // B is a KxN matrix + // C is the MxN output matrix, assumed to be zeroed out + + constexpr int K_block_size = 64; + const int n_K_blocks = K / K_block_size; + + const Tin *B_origin = B.data(); /* Avoid a calls to B.data() within the loop + with this const variable. B does not get + resized, so the pointer remains valid. */ + + const Tin *A_base = A.data(); /* Points to start of current row of A, + monotonically increasing by K. */ + const Tin *B_base = B_origin; /* Points to start of current column of B; + increases by 1 in each inner loop, resets + to B_origin (0) at the start of a new row + (outer loop). */ + + const Tin *A_ptr = A_base; + const Tin *B_ptr = B_base; + Tout *C_ptr = C.data(); /* Monotonically increasing by 1. */ + + for (int row = 0; row < M; row++) { + for (int col = 0; col < N; col++) { + A_ptr = A_base; + B_ptr = B_base; + Tout running_sum = 0; + for (int k = 0; k < n_K_blocks; k++) { + for (int i = 0; i < K_block_size; i++) { + running_sum += Tout(*A_ptr) * Tout(*B_ptr); + A_ptr += 1; // Advance to right neighbor; next value in this row + B_ptr += N; // Advance to bottom neighbor; next value in this column + } + } + *C_ptr = Tout(running_sum); + C_ptr += 1; + B_base += 1; /* Next iteration: same row of A (A_base unchanged), + next column of B (B_base increases by 1) */ + } + A_base += K; // Advance to next row of A + B_base = B_origin; /* Next row of A means we need to restart at the first + column of B. */ + } +} + +// nearly_equal function adapted from Stack Overflow, License CC BY-SA 4.0 +// Original author: P-Gn +// Source: https://stackoverflow.com/a/32334103 +bool nearly_equal(float a, float b, float epsilon = 128 * FLT_EPSILON, + float abs_th = FLT_MIN) +// those defaults are arbitrary and could be removed +{ + assert(std::numeric_limits::epsilon() <= epsilon); + assert(epsilon < 1.f); + + if (a == b) + return true; + + auto diff = std::abs(a - b); + auto norm = + std::min((std::abs(a) + std::abs(b)), std::numeric_limits::max()); + // or even faster: std::min(std::abs(a + b), + // std::numeric_limits::max()); keeping this commented out until I + // update figures below + return diff < std::max(abs_th, epsilon * norm); +} + +template +void print_matrix(const std::vector matrix, int n_cols, + int n_printable_rows = 10, int n_printable_cols = 10, + std::ostream &ostream = std::cout, + const char col_sep[] = " ", const char elide_sym[] = " ... 
", + int w = -1) { + assert(matrix.size() % n_cols == 0); + + auto maxima = std::minmax_element(matrix.begin(), matrix.end()); + T max_val = std::max(*maxima.first, std::abs(*maxima.second)); + size_t n_digits = log10(max_val); + if (w == -1) { + w = n_digits; + } + int n_rows = matrix.size() / n_cols; + + n_printable_rows = std::min(n_rows, n_printable_rows); + n_printable_cols = std::min(n_cols, n_printable_cols); + + const bool elide_rows = n_printable_rows < n_rows; + const bool elide_cols = n_printable_cols < n_cols; + + if (elide_rows || elide_cols) { + w = std::max((int)w, (int)strlen(elide_sym)); + } + + w += 3; // for decimal point and two decimal digits + ostream << std::fixed << std::setprecision(2); + +#define print_row(what) \ + for (int col = 0; col < n_printable_cols / 2; col++) { \ + ostream << std::right << std::setw(w) << (what); \ + ostream << std::setw(0) << col_sep; \ + } \ + if (elide_cols) { \ + ostream << std::setw(0) << elide_sym; \ + } \ + for (int col = n_printable_cols / 2 + 1; col < n_printable_cols; col++) { \ + ostream << std::right << std::setw(w) << (what); \ + ostream << std::setw(0) << col_sep; \ + } + + for (int row = 0; row < n_printable_rows / 2; row++) { + print_row(matrix[row * n_rows + col]); + ostream << std::endl; + } + if (elide_rows) { + print_row(elide_sym); + ostream << std::endl; + } + for (int row = n_printable_rows / 2 + 1; row < n_printable_rows; row++) { + print_row(matrix[row * n_rows + col]); + ostream << std::endl; + } + +#undef print_row +} + +template +int verify(int M, int N, int K, std::vector A, std::vector B, + std::vector C) { + int errors = 0; + int max_printable_errors = 500; + const float absTol = 0.5; + const float relTol = 0.5; + + std::vector CRef(M * N); + matmul(M, N, K, A, B, CRef); + + for (int row = 0; row < M; row++) { + for (int col = 0; col < N; col++) { + if (!nearly_equal(CRef[row * N + col], C[row * N + col], relTol, + absTol)) { + errors++; + if (errors < max_printable_errors) { + std::cout << "Error in row " << row << ", col " << col << ". " + << "Expected " << std::setw(4) << (float)CRef[row * N + col] + << ", got " << std::setw(4) << (float)C[row * N + col] + << "." << std::endl; + } + } + } + } + + if (errors >= max_printable_errors) { + std::cout << "...and " << std::setw(0) << errors << " further errors." + << std::endl; + } + if (errors > 0) { + std::cout << std::endl << "Reference:" << std::endl; + matmul_common::print_matrix(CRef, N); + std::cout << std::endl << "Output:" << std::endl; + matmul_common::print_matrix(C, N); + } + + return errors; +} + +} // namespace matmul_common + +#endif diff --git a/test/xrt/16_gemm_8x16_transform_vec_4x4/mm.cc b/test/xrt/16_gemm_8x16_transform_vec_4x4/mm.cc new file mode 100644 index 000000000..c0e009a48 --- /dev/null +++ b/test/xrt/16_gemm_8x16_transform_vec_4x4/mm.cc @@ -0,0 +1,138 @@ +//===- mm.cc ----------------------------------------------000---*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// Copyright (C) 2024, Advanced Micro Devices, Inc. 
+// +//===----------------------------------------------------------------------===// + +#define __AIENGINE__ 2 +#define NOCPP +#define __AIEARCH__ 20 + +#include +#include +#include +#include + +#define REL_WRITE 0 +#define REL_READ 1 + +#include + +#include "zero.cc" + +template +void matmul_vectorized(const T_in *__restrict pA, const T_in *__restrict pB, + T_out *__restrict pC) { + using MMUL = aie::mmul; + + event0(); + + for (unsigned z = 0; z < rowA; z += 2) + chess_loop_range(2, ) { + T_out *__restrict pC1 = pC + (z)*MMUL::size_C; + T_out *__restrict pC2 = pC + ((z + 1)) * MMUL::size_C; + + for (unsigned j = 0; j < colB; j += 2) + chess_prepare_for_pipelining chess_loop_range(8, ) { + const T_in *__restrict pA1 = pA + (z)*MMUL::size_A; + const T_in *__restrict pA2 = pA + ((z + 1)) * MMUL::size_A; + const T_in *__restrict pB1 = pB + (j)*colA * MMUL::size_B; + const T_in *__restrict pB2 = pB + ((j + 1)) * colA * MMUL::size_B; + aie::vector A0 = aie::load_v(pA1); + pA1 += rowA * MMUL::size_A; + aie::vector A1 = aie::load_v(pA2); + pA2 += rowA * MMUL::size_A; + aie::vector B0 = aie::load_v(pB1); + pB1 += MMUL::size_B; + aie::vector B1 = aie::load_v(pB2); + pB2 += MMUL::size_B; + + aie::vector acc_C00 = + aie::load_v(pC1); + aie::vector acc_C01 = + aie::load_v(pC1 + MMUL::size_C * rowA); + aie::vector acc_C10 = + aie::load_v(pC2); + aie::vector acc_C11 = + aie::load_v(pC2 + MMUL::size_C * rowA); + + MMUL C00(acc_C00); + MMUL C01(acc_C01); + MMUL C10(acc_C10); + MMUL C11(acc_C11); + + C00.mac(A0, B0); + C01.mac(A0, B1); + C10.mac(A1, B0); + C11.mac(A1, B1); + + for (unsigned i = 1; i < colA; ++i) + chess_prepare_for_pipelining chess_loop_range(7, ) { + A0 = aie::load_v(pA1); + pA1 += rowA * MMUL::size_A; + A1 = aie::load_v(pA2); + pA2 += rowA * MMUL::size_A; + B0 = aie::load_v(pB1); + pB1 += MMUL::size_B; + B1 = aie::load_v(pB2); + pB2 += MMUL::size_B; + C00.mac(A0, B0); + C01.mac(A0, B1); + C10.mac(A1, B0); + C11.mac(A1, B1); + } + + aie::store_v(pC1, C00.template to_vector()); + pC1 += MMUL::size_C * rowA; + aie::store_v(pC1, C01.template to_vector()); + pC1 += MMUL::size_C * rowA; + aie::store_v(pC2, C10.template to_vector()); + pC2 += MMUL::size_C * rowA; + aie::store_v(pC2, C11.template to_vector()); + pC2 += MMUL::size_C * rowA; + } + } + + event1(); +} + +template +void matmul_vectorized_4x8x4_bf16_bf16(const bfloat16 *__restrict pA, + const bfloat16 *__restrict pB, + bfloat16 *__restrict pC) { + constexpr int r = 4; + constexpr int s = 8; + constexpr int t = 4; + static_assert(m % (2 * r) == 0 && m / (2 * r) > 0); + static_assert(k % (2 * s) == 0 && k / (2 * s) > 0); + static_assert(n % (2 * t) == 0 && n / (2 * t) > 0); + return matmul_vectorized( + pA, pB, pC); +} + +extern "C" { + +#define combos(X) X(bfloat16, bf16, bfloat16, bf16, 4, 8, 4) + +#define matmul_vectorized_c_func(ctype_in, mlir_type_in, ctype_out, \ + mlir_type_out, r, s, t) \ + void matmul_##mlir_type_in##_##mlir_type_out(ctype_in *a_in, ctype_in *b_in, \ + ctype_out *c_out) { \ + matmul_vectorized_##r##x##s##x##t##_##mlir_type_in##_##mlir_type_out< \ + 64, 64, 64>(a_in, b_in, c_out); \ + } + +#define zero_vectorized_c_func(ctype_in, mlir_type_in, ctype_out, \ + mlir_type_out, r, s, t) \ + void linalg_fill_bf16_view16x16x4x4xbf16as2(ctype_out *c_out) { \ + zero_vectorized(c_out); \ + } + +combos(matmul_vectorized_c_func) combos(zero_vectorized_c_func) + +} // extern "C" diff --git a/test/xrt/16_gemm_8x16_transform_vec_4x4/run.lit b/test/xrt/16_gemm_8x16_transform_vec_4x4/run.lit new file mode 100644 index 
000000000..62b994ec1 --- /dev/null +++ b/test/xrt/16_gemm_8x16_transform_vec_4x4/run.lit @@ -0,0 +1,10 @@ +// (c) Copyright 2024 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: MIT + +// REQUIRES: ryzen_ai, valid_xchess_license + +// RUN: xchesscc_wrapper aie2 -I %aietools/include -c %S/mm.cc -o mm.o +// RUN: %python %S/gen.py +// RUN: %python aiecc.py --xchesscc --xbridge --no-aiesim --aie-generate-cdo --aie-generate-npu --no-compile-host --xclbin-name=aie.xclbin --npu-insts-name=insts.txt aie.mlir +// RUN: g++-13 %S/test.cpp -o test.exe -std=c++23 -Wall %xrt_flags -lrt -lstdc++ -lboost_program_options -lboost_filesystem +// RUN: %run_on_npu ./test.exe -x aie.xclbin -k MLIR_AIE -i insts.txt diff --git a/test/xrt/16_gemm_8x16_transform_vec_4x4/test.cpp b/test/xrt/16_gemm_8x16_transform_vec_4x4/test.cpp new file mode 100644 index 000000000..05dcf18ef --- /dev/null +++ b/test/xrt/16_gemm_8x16_transform_vec_4x4/test.cpp @@ -0,0 +1,230 @@ +//===- test.cpp -------------------------------------------000---*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// Copyright (C) 2024, Advanced Micro Devices, Inc. +// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "xrt/xrt_bo.h" +#include "xrt/xrt_device.h" +#include "xrt/xrt_kernel.h" + +#include "matrix_multiplication.h" + +constexpr int M = 512; +constexpr int K = 1024; +constexpr int N = 512; + +constexpr int Tx = 16; +constexpr int Ty = 8; + +constexpr int A_VOLUME = M * K; +constexpr int B_VOLUME = N * K; +constexpr int C_VOLUME = M * N; + +using A_DATATYPE = std::bfloat16_t; +using B_DATATYPE = std::bfloat16_t; +using C_DATATYPE = std::bfloat16_t; + +constexpr int A_SIZE = (A_VOLUME * sizeof(A_DATATYPE)); +constexpr int B_SIZE = (B_VOLUME * sizeof(B_DATATYPE)); +constexpr int C_SIZE = (C_VOLUME * sizeof(C_DATATYPE)); + +constexpr bool VERIFY = true; + +namespace po = boost::program_options; + +int main(int argc, const char *argv[]) { + + // Program arguments parsing + po::options_description desc("Allowed options"); + po::variables_map vm; + matmul_common::add_default_options(desc); + matmul_common::parse_options(argc, argv, desc, vm); + int verbosity = vm["verbosity"].as(); + + srand(time(NULL)); + + std::vector instr_v = + matmul_common::load_instr_sequence(vm["instr"].as()); + if (verbosity >= 1) + std::cout << "Sequence instr count: " << instr_v.size() << "\n"; + + // Start the XRT test code + // Get a device handle + unsigned int device_index = 0; + auto device = xrt::device(device_index); + + // Load the xclbin + if (verbosity >= 1) + std::cout << "Loading xclbin: " << vm["xclbin"].as() << "\n"; + auto xclbin = xrt::xclbin(vm["xclbin"].as()); + + if (verbosity >= 1) + std::cout << "Kernel opcode: " << vm["kernel"].as() << "\n"; + std::string Node = vm["kernel"].as(); + + // Get the kernel from the xclbin + auto xkernels = xclbin.get_kernels(); + auto xkernel = *std::find_if(xkernels.begin(), xkernels.end(), + [Node, verbosity](xrt::xclbin::kernel &k) { + auto name = k.get_name(); + if (verbosity >= 1) { + std::cout << "Name: " << name << std::endl; + } + return name.rfind(Node, 0) == 0; + }); + auto kernelName = xkernel.get_name(); + + if (verbosity >= 1) + std::cout << "Registering xclbin: " << 
vm["xclbin"].as() + << "\n"; + + device.register_xclbin(xclbin); + + // get a hardware context + if (verbosity >= 1) + std::cout << "Getting hardware context.\n"; + xrt::hw_context context(device, xclbin.get_uuid()); + + // get a kernel handle + if (verbosity >= 1) + std::cout << "Getting handle to kernel:" << kernelName << "\n"; + auto kernel = xrt::kernel(context, kernelName); + + auto bo_instr = xrt::bo(device, instr_v.size() * sizeof(int), + XCL_BO_FLAGS_CACHEABLE, kernel.group_id(1)); + auto bo_a = + xrt::bo(device, A_SIZE, XRT_BO_FLAGS_HOST_ONLY, kernel.group_id(3)); + auto bo_b = + xrt::bo(device, B_SIZE, XRT_BO_FLAGS_HOST_ONLY, kernel.group_id(4)); + auto bo_c = + xrt::bo(device, C_SIZE, XRT_BO_FLAGS_HOST_ONLY, kernel.group_id(5)); + + if (verbosity >= 1) + std::cout << "Writing data into buffer objects.\n"; + + A_DATATYPE *bufA = bo_a.map(); + std::vector AVec(A_VOLUME); + for (int i = 0; i < A_VOLUME; i++) { + AVec[i] = matmul_common::random_bfloat16_t(); + } + memcpy(bufA, AVec.data(), (AVec.size() * sizeof(A_DATATYPE))); + B_DATATYPE *bufB = bo_b.map(); + std::vector BVec(B_VOLUME); + for (int i = 0; i < B_VOLUME; i++) { + BVec[i] = matmul_common::random_bfloat16_t(); + } + std::vector BlockedBVec(B_VOLUME); + for (int k = 0; k < (K / Ty); k++) { + for (int n = 0; n < (N / Tx); n++) { + for (int ty = 0; ty < Ty; ty++) { + for (int tx = 0; tx < Tx; tx++) { + int inputIdx = tx + (Tx * ty) + (n * Ty * Tx) + (k * N * Ty); + int blockIdx = tx + (N * ty) + (Tx * n) + (k * Ty * N); + BlockedBVec[blockIdx] = BVec[inputIdx]; + } + } + } + } + memcpy(bufB, BlockedBVec.data(), (BlockedBVec.size() * sizeof(B_DATATYPE))); + C_DATATYPE *bufC = bo_c.map(); + std::vector CVec(C_VOLUME); + memcpy(bufC, CVec.data(), (CVec.size() * sizeof(C_DATATYPE))); + + void *bufInstr = bo_instr.map(); + memcpy(bufInstr, instr_v.data(), instr_v.size() * sizeof(int)); + + bo_instr.sync(XCL_BO_SYNC_BO_TO_DEVICE); + bo_a.sync(XCL_BO_SYNC_BO_TO_DEVICE); + bo_b.sync(XCL_BO_SYNC_BO_TO_DEVICE); + bo_c.sync(XCL_BO_SYNC_BO_TO_DEVICE); + + unsigned num_iter = 1; + float npu_time_total = 0; + float npu_time_min = 9999999; + float npu_time_max = 0; + + int errors = 0; + float macs = 2.0 * float(M) * float(K) * float(N); + + for (unsigned iter = 0; iter < num_iter; iter++) { + + if (verbosity >= 1) { + std::cout << "Running Kernel.\n"; + } + auto start = std::chrono::high_resolution_clock::now(); + unsigned int opcode = 3; + auto run = kernel(opcode, bo_instr, instr_v.size(), bo_a, bo_b, bo_c); + run.wait(); + auto stop = std::chrono::high_resolution_clock::now(); + + bo_c.sync(XCL_BO_SYNC_BO_FROM_DEVICE); + memcpy(CVec.data(), bufC, (CVec.size() * sizeof(C_DATATYPE))); + std::vector CVecRef(C_VOLUME); + if (VERIFY) { + if (verbosity >= 1) { + std::cout << "Verifying against reference matmul ..." << std::endl; + } + auto vstart = std::chrono::system_clock::now(); + matmul_common::matmul(M, N, K, AVec, BVec, CVecRef); + errors = matmul_common::verify(M, N, K, AVec, BVec, CVec); + auto vstop = std::chrono::system_clock::now(); + float vtime = + std::chrono::duration_cast(vstop - vstart) + .count(); + if (verbosity >= 1) { + std::cout << "Verify time: " << vtime << "secs." << std::endl; + } + } else { + if (verbosity >= 1) + std::cout << "WARNING: matmul results not verified." << std::endl; + } + + float npu_time = + std::chrono::duration_cast(stop - start) + .count(); + + npu_time_total += npu_time; + npu_time_min = (npu_time < npu_time_min) ? npu_time : npu_time_min; + npu_time_max = (npu_time > npu_time_max) ? 
npu_time : npu_time_max;
+  }
+
+  std::cout << std::endl
+            << "Avg NPU matmul time: " << npu_time_total / num_iter << "us."
+            << std::endl;
+  std::cout << "Avg NPU gflops: " << macs / (1000 * npu_time_total / num_iter)
+            << std::endl;
+
+  std::cout << std::endl
+            << "Min NPU matmul time: " << npu_time_min << "us." << std::endl;
+  std::cout << "Min NPU gflops: " << macs / (1000 * npu_time_min) << std::endl;
+
+  std::cout << std::endl
+            << "Max NPU matmul time: " << npu_time_max << "us." << std::endl;
+  std::cout << "Max NPU gflops: " << macs / (1000 * npu_time_max) << std::endl;
+
+  if (VERIFY && !errors) {
+    std::cout << "\nPASS!\n\n";
+    return 0;
+  } else {
+    std::cout << "\nError count: " << errors << "\n\n";
+    std::cout << "\nFailed.\n\n";
+    return 1;
+  }
+}
diff --git a/test/xrt/16_gemm_8x16_transform_vec_4x4/zero.cc b/test/xrt/16_gemm_8x16_transform_vec_4x4/zero.cc
new file mode 100644
index 000000000..8c13b601d
--- /dev/null
+++ b/test/xrt/16_gemm_8x16_transform_vec_4x4/zero.cc
@@ -0,0 +1,33 @@
+//===- zero.cc --------------------------------------------000---*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// Copyright (C) 2024, Advanced Micro Devices, Inc.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ZERO_CC
+#define ZERO_CC
+
+#include <aie_api/aie.hpp>
+#include
+#include
+#include
+
+template <typename T, int M, int N, int r>
+void zero_vectorized(T *__restrict c) {
+  const aie::vector<T, r> zeros = aie::zeros<T, r>();
+  const T *__restrict c_end = c + M * N;
+  for (; c + r < c_end; c += r) {
+    aie::store_v(c, zeros);
+  }
+  // Do a scalar write for any remainder not divisible by vector instruction
+  // size r
+  for (; c < c_end; c++) {
+    *c = 0;
+  }
+}
+
+#endif
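
For reference, the host-side reordering of B that test.cpp performs (with Tx = 16 and Ty = 8) before copying B into the device buffer bo_b can be read as the standalone helper sketched below. The function name block_b and its free-function form are illustrative additions, not part of the patch; the index arithmetic is copied verbatim from the loop nest in test.cpp.

#include <vector>

// Minimal sketch of the B reordering done in test.cpp. K and N are the
// matrix dimensions; Ty x Tx is the tile size (8 x 16 in this test).
template <typename T>
std::vector<T> block_b(const std::vector<T> &B, int K, int N, int Ty, int Tx) {
  std::vector<T> blocked(B.size());
  for (int k = 0; k < K / Ty; k++)        // band of Ty consecutive rows
    for (int n = 0; n < N / Tx; n++)      // Tx-wide block within the band
      for (int ty = 0; ty < Ty; ty++)     // row within the tile
        for (int tx = 0; tx < Tx; tx++) { // column within the tile
          // Source index: tiles stored contiguously, tile by tile.
          int inputIdx = tx + (Tx * ty) + (n * Ty * Tx) + (k * N * Ty);
          // Destination index: row-major element (k*Ty + ty, n*Tx + tx).
          int blockIdx = tx + (N * ty) + (Tx * n) + (k * Ty * N);
          blocked[blockIdx] = B[inputIdx];
        }
  return blocked;
}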