Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing several bugs in Halide's AMX support #8350

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
95 changes: 40 additions & 55 deletions src/ExtractTileOperations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,35 +81,7 @@ const auto wild_i32x = Variable::make(Int(32, 0), "*");

Tile<1> get_1d_tile_index(const Expr &e) {
if (const auto *r1 = e.as<Ramp>()) {

const auto stride_var = Variable::make(Int(32), "stride");
const auto v1 = Variable::make(Int(32), "v1");
const auto v2 = Variable::make(Int(32), "v2");
const auto v3 = Variable::make(Int(32), "v3");

Expr patterns[] = {
((v1 * stride_var) + v2) * v3,
v3 * ((v1 * stride_var) + v2),
(v2 + (v1 * stride_var)) * v3,
v3 * (v2 + (v1 * stride_var)),
};

std::map<std::string, Expr> matches;
for (const auto &pattern : patterns) {
if (expr_match(pattern, r1->base, matches)) {
auto stride = std::move(matches["stride"]);
// stride must be a constant in order to not be confused with v1
if (stride.as<IntImm>()) {
return {true, r1->base, {std::move(stride)}, {r1->lanes}};
}

// if stride wasn't a constant then v1 could possibly be the stride if constant
auto v1_expr = std::move(matches["v1"]);
if (v1_expr.as<IntImm>()) {
return {true, r1->base, {std::move(v1_expr)}, {r1->lanes}};
}
}
}
return {true, r1->base, {r1->stride}, {r1->lanes}};
}

return {};
Expand Down Expand Up @@ -218,7 +190,7 @@ Tile<3> get_3d_tile_index(const Expr &e) {
* The pattern which is getting matched looks roughly like
* `broadcast(ramp(0, 1, r), x*y) / broadcast(4, x*y*r) + optional(broadcast(base, x*y*r)) * broadcast(8, x*y*r) +
* broadcast(ramp(0, 1, r), x*y) % broadcast(4, x*y*r) +
* broadcast(ramp(broadcast(_, r), broadcast(4, r), x) , y)`
* broadcast(ramp(broadcast(_, r), broadcast(4, r), y) , x)`
*/
Tile<3> get_3d_rhs_tile_index(const Expr &e, int element_width) {
const auto *sub = e.as<Sub>();
Expand All @@ -239,38 +211,38 @@ Tile<3> get_3d_rhs_tile_index(const Expr &e, int element_width) {
// The right hand side of the add expression is used for retrieving the dimensions of the matrix.
// obtain the x, y, r dimensions
// this expr looks like below, the shape of `add_lhs->a` can be seen further down below
// broadcast(ramp(0, 1, r), x*y) % broadcast(4, x*y*r) + broadcast(ramp(broadcast(base, r), broadcast(4, r), x) , y)
// broadcast(ramp(0, 1, r), x*y) % broadcast(4, x*y*r) + broadcast(ramp(broadcast(base, r), broadcast(4, r), y) , x)
const Add *dim_expr = add_lhs->b.as<Add>();

if (!dim_expr) {
return {};
}

// broadcast(ramp(broadcast(_, r), broadcast(4, r), x), y)
// broadcast(ramp(broadcast(_, r), broadcast(4, r), y), x)
const Broadcast *base_stride_bc = dim_expr->b.as<Broadcast>();

if (!base_stride_bc) {
return {};
}

int tile_y = base_stride_bc->lanes;
int tile_x = base_stride_bc->lanes;

// broadcast(ramp(0, 1, r), x*y) % broadcast(4, x*y*r)
const Mod *mod = dim_expr->a.as<Mod>();

if (!mod) {
std::vector<Expr> results{};
const Expr mod_pattern = Mod::make(wild_i32x, Broadcast::make(4 / element_width, 0));
if (!expr_match(mod_pattern, dim_expr->a, results)) {
return {};
}

// broadcast(ramp(0, 1, r), x*y)
const Broadcast *bc_ramp = mod->a.as<Broadcast>();
const Broadcast *bc_ramp = results[0].as<Broadcast>();

if (!bc_ramp) {
return {};
}

int tile_xy = bc_ramp->lanes;
int tile_x = tile_xy / tile_y;
int tile_y = tile_xy / tile_x;

// ramp(0, 1, r)
const Ramp *r_ramp = bc_ramp->value.as<Ramp>();
Expand All @@ -282,21 +254,13 @@ Tile<3> get_3d_rhs_tile_index(const Expr &e, int element_width) {
int tile_r = r_ramp->lanes;

// get the base and stride
// ramp(broadcast(_, r), broadcast(4, r), x)
const Ramp *base_stride_ramp = base_stride_bc->value.as<Ramp>();

if (!base_stride_ramp) {
// ramp(broadcast(_, r), broadcast(4, r), y)
const Expr base_stride_ramp_pattern = Ramp::make(Broadcast::make(wild_i32, tile_r), Broadcast::make(4 / element_width, tile_r), tile_y);
if (!expr_match(base_stride_ramp_pattern, base_stride_bc->value, results)) {
return {};
}

// broadcast(_, r)
const Broadcast *base_bc = base_stride_ramp->base.as<Broadcast>();

if (!base_bc) {
return {};
}

Expr base = base_bc->value;
Expr base = results[0];
Expr stride;

bool found_stride = false;
Expand All @@ -308,7 +272,6 @@ Tile<3> get_3d_rhs_tile_index(const Expr &e, int element_width) {
// this stride pattern can occur if `tile_r` is the same size as `acc`
auto stride_pattern = Broadcast::make(Ramp::make(0, 1, tile_r), tile_x * tile_y) / Broadcast::make((4 / element_width), tile_x * tile_y * tile_r) * Broadcast::make(wild_i32, tile_x * tile_y * tile_r);

std::vector<Expr> results{};
if (expr_match(stride_pattern, add_lhs->a, results)) {
found_stride = true;
stride = std::move(results[0]);
Expand Down Expand Up @@ -353,19 +316,41 @@ BaseStride get_rhs_tile_index(const Expr &index, int element_width, int tile_x,

return {true, rhs_tile3.base, rhs_tile3.stride[0] * element_width};
} else {
// 1D: degenerate as dot product. There are two cases:
// * tile_r is 4, so effectively there is only one row in the loaded tile
// * rhs.stride.1 == 4 && tile_y = 1, where the loaded RHS has shape (K/4)x4
// and is contiguous in the memory
if (rhs_tile1.extent[0] != tile_y * tile_r) {
return {};
}
if (!(rhs_tile1.stride[0].as<IntImm>() && rhs_tile1.stride[0].as<IntImm>()->value == 1)) {
return {};
}

if (tile_r == 4 / element_width) {
return {true, rhs_tile1.base, 0};
}

// times 4 because of the rhs layout, each vector used by AMX is 4 bytes in size.
// For the 4 gets divided by the element width which means each vector has 4 elements in u8/i8 and
// 2 elements for bf16.
return {true, rhs_tile1.base, rhs_tile1.stride[0] * (4 / element_width)};
if (tile_y == 1) {
// 4 elements in u8/i8 and 2 elements for bf16.
return {true, rhs_tile1.base, 4 / element_width};
}

return {};
}
} else {
// The only case where there is a ramp of ramp is when tile_y = 1 and so RHS has size (K/4)x4
// (and rhs.stride.1 != 4, because otherwise it degenerates to the 1D case)
if (tile_y != rhs_tile2.extent[0] || tile_r != rhs_tile2.extent[1]) {
return {};
}
if (!(rhs_tile2.stride[1].as<IntImm>() && rhs_tile2.stride[1].as<IntImm>()->value == 1)) {
return {};
}

if (tile_y != 1) {
return {};
}

return {true, rhs_tile2.base, rhs_tile2.stride[0]};
}
Expand Down
16 changes: 15 additions & 1 deletion test/correctness/tiled_matmul.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#include "Halide.h"

#include <halide_test_dirs.h>
#include <stdio.h>

using namespace Halide;
Expand Down Expand Up @@ -134,6 +136,7 @@ bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r) {
Buffer<int32_t> out(col, row);

result.realize(out);
// result.compile_to_llvm_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.ll", {A_buf, B_buf}, target);

// uncomment to check the matrices
// std::cout << "Matrix A\n";
Expand Down Expand Up @@ -248,7 +251,18 @@ auto matmul_su = &matmul<int8_t, uint8_t>;
auto matmul_uu = &matmul<uint8_t, uint8_t>;

bool run_tests(bool (*fn)(int, int, int, int, int, int), int element_width) {
return fn(2, 2, 16, 2, 2, 8 / element_width) && fn(4, 4, 8, 4, 4, 8 / element_width) && fn(32, 32, 32, 8, 8, 8 / element_width) && fn(32, 32, 32, 8, 8, 4 / element_width);
return true
// TODO: tile_x and tile_y is not supported because they degenerate to a pattern that the matcher for LHS fails to recognize
// && fn(2, 2, 16, 1, 2, 4 / element_width)
// && fn(2, 2, 16, 2, 2, 4 / element_width)
&& fn(2, 2, 16, 2, 2, 8 / element_width)
&& fn(4, 4, 8, 4, 4, 8 / element_width)
&& fn(8, 8, 4, 8, 8, 4 / element_width)
&& fn(32, 32, 32, 8, 8, 8 / element_width)
&& fn(32, 32, 32, 8, 8, 4 / element_width)
&& fn(32, 32, 32, 6, 8, 4 / element_width)
&& fn(32, 32, 32, 6, 8, 8 / element_width)
;
}

int main(int argc, char **argv) {
Expand Down
4 changes: 4 additions & 0 deletions test/error/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ tests(GROUPS error
split_same_var_names.cpp
store_at_without_compute_at.cpp
thread_id_outside_block_id.cpp
tiled_matmul_wrong_layout.cpp
tiled_matmul_wrong_modulo.cpp
tiled_matmul_wrong_pattern.cpp
tiled_matmul_wrong_tiling.cpp
too_many_args.cpp
tuple_arg_select_undef.cpp
tuple_output_bounds_check.cpp
Expand Down
114 changes: 114 additions & 0 deletions test/error/tiled_matmul_wrong_layout.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
#include "Halide.h"
#include "halide_test_dirs.h"
#include <stdio.h>

using namespace Halide;

// Writes a pseudo-random value into every element buf(k, j) for
// k in [0, acc) and j in [0, row). Each value is rand() % 256 shifted by
// the smallest value representable in IntT.
template<typename IntT>
void fill_buffer_a(Buffer<IntT> &buf, int row, int acc) {
    constexpr IntT lowest = std::numeric_limits<IntT>::min();
    for (int j = 0; j < row; ++j) {
        for (int k = 0; k < acc; ++k) {
            const auto sample = rand() % 256;
            buf(k, j) = sample + lowest;
        }
    }
}

// Writes a pseudo-random value into every element buf(lane, c, g) for
// lane in [0, 8), c in [0, col) and g in [0, acc / 4). Each value is
// rand() % 256 shifted by the smallest value representable in IntT.
template<typename IntT>
void fill_buffer_b(Buffer<IntT> &buf, int col, int acc) {
    constexpr IntT lowest = std::numeric_limits<IntT>::min();
    const int groups = acc / 4;
    for (int g = 0; g < groups; ++g) {
        for (int c = 0; c < col; ++c) {
            for (int lane = 0; lane < 8; ++lane) {
                const auto sample = rand() % 256;
                buf(lane, c, g) = sample + lowest;
            }
        }
    }
}

// Builds a tiled matrix-multiply pipeline whose accumulator is stored in
// MemoryType::AMXTile, using a B buffer whose innermost extent is 8 even
// though only indices 0..3 (r % 4) are read.
//
// row, col, acc          - extents of the output rows/columns and of the reduction.
// tile_x, tile_y, tile_r - tile sizes used to tile x/y and to split the reduction.
// validate               - when false, the pipeline is only lowered for the
//                          Sapphire Rapids target (expected to raise an error);
//                          when true, the pipeline is realized and its output is
//                          compared against a scalar reference computation.
//
// Returns false on the first mismatching output element in validate mode,
// and true otherwise.
template<typename LhsInt8, typename RhsInt8>
bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r, bool validate) {
Target target("x86-64-linux-avx512_sapphirerapids");
Buffer<LhsInt8> A_buf(acc, row);
// Each tile in B is padded with another 4 bytes.
// (The innermost dimension has extent 8, but the algorithm below only reads r % 4.)
Buffer<RhsInt8> B_buf(8, col, acc / 4);

Var x("x"), y("y");
RDom r(0, acc);

// Algorithm: zero-initialize, then accumulate the widened int32 products over r.
Func mm("matmul");
mm(x, y) = cast<int32_t>(0);
mm(x, y) += cast<int32_t>(A_buf(r, y)) * cast<int32_t>(B_buf(r % 4, x, r / 4));

Var rxi("rxi"), ryi("ryi");
RVar rri("rri"), rro("rro");

// Schedule of the update stage: tile x/y by (tile_x, tile_y), split the
// reduction by tile_r, and vectorize all three inner loops.
mm.compute_at(mm.in(), x)
.store_in(MemoryType::AMXTile)
.update()
.tile(x, y, rxi, ryi, tile_x, tile_y, TailStrategy::GuardWithIf)
.split(r, rro, rri, tile_r)
.reorder(rri, rxi, ryi, rro, x, y)
.atomic()
.vectorize(rri)
.vectorize(rxi)
.vectorize(ryi);

// Schedule of the pure (initialization) definition: same tile shape, vectorized.
Var ixi("ixi"), iyi("iyi");
mm.compute_at(mm.in(), x)
.tile(x, y, ixi, iyi, tile_x, tile_y)
.vectorize(ixi)
.vectorize(iyi);

// Schedule of the Func returned by mm.in(): same tile shape, vectorized.
Var mmxi("mmxi"), mmyi("mmyi");
mm.in()
.tile(x, y, mmxi, mmyi, tile_x, tile_y)
.vectorize(mmxi)
.vectorize(mmyi);

Func result = mm.in();

if (!validate) {
// Should err with AMX mapping failure since B buffer has a
// different layout than expected by AMX
result.compile_to_lowered_stmt("/dev/null", {A_buf, B_buf}, Halide::Text, target);
} else {
std::cerr << "Validating compiled program\n";

fill_buffer_a(A_buf, row, acc);
fill_buffer_b(B_buf, col, acc);
Buffer<int32_t> out(col, row);
result.realize(out);

// Recompute every output element with plain scalar loops, using the same
// indexing as the algorithm above, and report the first mismatch.
for (int j = 0; j < row; ++j) {
for (int i = 0; i < col; ++i) {
int32_t val = 0;
for (int k = 0; k < acc; ++k) {
val += static_cast<int32_t>(A_buf(k, j)) * static_cast<int32_t>(B_buf(k % 4, i, k / 4));
}
if (val != out(i, j)) {
std::cerr << "Invalid result at " << i << ", " << j << "\n"
<< out(i, j) << " != " << val << "\n"
<< "Matrix dims: " << row << "x" << col << "x" << acc << "\nTile dims: " << tile_x << "x" << tile_y << "x" << tile_r << "\n";
return false;
}
}
}
}

return true;
}

// Entry point of the error test.
//
// Without arguments the pipeline is only lowered for the Sapphire Rapids
// target. With the single argument "--validate" the pipeline is instead
// realized and checked against a scalar reference; that mode is skipped
// (exit status 0) when the JIT target lacks Target::AVX512_SapphireRapids.
//
// Returns 0 when matmul() reports success and 1 when it reports a
// validation mismatch, so a wrong result is no longer silently ignored.
int main(int argc, char **argv) {
    const bool validate = (argc == 2 && argv[1] == std::string("--validate"));

    if (validate && !get_jit_target_from_environment().has_feature(Target::AVX512_SapphireRapids)) {
        std::cerr << "Skipping test since target does not support AMX\n";
        return 0;
    }

    // Note: theoretically we should be able to compile this if tile_x is set to 1,
    // in which case each row of a tile becomes contiguous in memory again.
    // However, we cannot do this because the matcher for LHS cannot handle the
    // case when tile_x or tile_y is 1.
    const bool ok = matmul<int8_t, int8_t>(32, 32, 32, 8, 8, 4, validate);

    // Propagate a validation mismatch to the caller via the exit status.
    return ok ? 0 : 1;
}
Loading
Loading