Skip to content

Commit

Permalink
Support lowering of vector.contract to amx for brgemm
Browse files Browse the repository at this point in the history
  • Loading branch information
shahidact committed Mar 7, 2025
1 parent 0d504c7 commit a1a2315
Show file tree
Hide file tree
Showing 7 changed files with 919 additions and 1 deletion.
9 changes: 9 additions & 0 deletions include/TPP/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#ifndef TPP_PASSES_H
#define TPP_PASSES_H

#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
Expand Down Expand Up @@ -96,6 +97,14 @@ namespace xegpu {
class XeGPUDialect;
} // namespace xegpu

namespace amx {
class AMXDialect;
} // namespace amx

namespace x86vector {
class X86VectorDialect;
} // namespace x86vector

} // namespace mlir

namespace mlir {
Expand Down
12 changes: 12 additions & 0 deletions include/TPP/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,18 @@ def VectorContractToFMA : Pass<
"arith::ArithDialect"];
}

def VectorContractToAMX : Pass<
"vector-contract-to-amx"> {
let summary = "Perform vector amx lowering of vector contraction ops";
let dependentDialects = ["memref::MemRefDialect",
"scf::SCFDialect",
"tensor::TensorDialect",
"vector::VectorDialect",
"arith::ArithDialect",
"amx::AMXDialect",
"x86vector::X86VectorDialect"];
}


def BrgemmLinalgTiling : Pass<"tile-brgemm-linalg"> {
let summary = "Tile bregmm matmul and reduction dimension.";
Expand Down
2 changes: 2 additions & 0 deletions lib/TPP/DefaultPipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ struct DefaultPipeline : public tpp::impl::DefaultPipelineBase<DefaultPipeline>,
// Lower to LLVM
ConvertVectorToLLVMPassOptions options;
options.amx = vnni::utils::hasAMX();
if (options.amx)
options.x86Vector = true;
pm.addPass(createConvertVectorToLLVMPass(options));
pm.addPass(createFinalizeMemRefToLLVMConversionPass());
pm.addPass(createSCFToControlFlowPass());
Expand Down
5 changes: 4 additions & 1 deletion lib/TPP/PassBundles/VectorToKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
//
//===----------------------------------------------------------------------===//

#include "TPP/Transforms/Utils/VNNIUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "llvm/Support/Debug.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/Debug.h"

#include "TPP/PassBundles.h"
#include "TPP/PassUtils.h"
Expand Down Expand Up @@ -51,6 +52,8 @@ struct VectorToKernel : public tpp::impl::VectorToKernelBase<VectorToKernel>,
void constructPipeline() override {
pm.addNestedPass<func::FuncOp>(createHoistVectorTransfers());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
if (vnni::utils::hasAMX())
pm.addNestedPass<func::FuncOp>(createVectorContractToAMX());
pm.addNestedPass<func::FuncOp>(createVectorContractToFMA());
}
};
1 change: 1 addition & 0 deletions lib/TPP/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ add_mlir_library(TPPTransforms
VectorContractToOuterproduct.cpp
HoistVectorTransfers.cpp
VectorContractToFMA.cpp
VectorContractToAMX.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/TPP
Expand Down
Loading

0 comments on commit a1a2315

Please sign in to comment.