Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[midend/lib/Conversion/LowerLinalgToGemmini] Add pass support for gemmini to run E2E LeNet and some tests. #463

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions examples/GemminiDialect/conv_2d_nhwc_fhwc_5x5_i8.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// RUN: buddy-opt %s \
// RUN: --convert-linalg-to-gemmini | \
// RUN: FileCheck %s

// Convolution input: a 1x7x7x1 i8 buffer in which every element is 1.
memref.global "private" @input : memref<1x7x7x1xi8> = dense<[[[[1],[1],[1],[1],[1],[1],[1]],
[[1],[1],[1],[1],[1],[1],[1]],
[[1],[1],[1],[1],[1],[1],[1]],
[[1],[1],[1],[1],[1],[1],[1]],
[[1],[1],[1],[1],[1],[1],[1]],
[[1],[1],[1],[1],[1],[1],[1]],
[[1],[1],[1],[1],[1],[1],[1]]]]>

// Convolution kernel: a 1x5x5x1 i8 buffer in which every element is 1.
memref.global "private" @kernel : memref<1x5x5x1xi8> = dense<[[[[1], [1], [1], [1], [1]],
[[1], [1], [1], [1], [1]],
[[1], [1], [1], [1], [1]],
[[1], [1], [1], [1], [1]],
[[1], [1], [1], [1], [1]]]]>

// Test entry point: sets up the operands of a 5x5 convolution over the 7x7
// input defined above.
func.func @main() -> i8 {
%0 = arith.constant 0 : i8
%input = memref.get_global @input : memref<1x7x7x1xi8>
%kernel = memref.get_global @kernel : memref<1x5x5x1xi8>
// A 5x5 window over a 7x7 input yields a 3x3 result (7 - 5 + 1 = 3).
%output = memref.alloc() : memref<1x3x3x1xi8>

// The conversion is expected to emit a single gemmini.tile_conv whose operand
// types are the original input, a 25x1 i8 buffer, a 1-element i32 buffer and a
// 9x1 i8 buffer.
// CHECK: gemmini.tile_conv %{{[0-9]+}} %alloc_{{[0-9]+}} %alloc_{{[0-9]+}} %alloc_{{[0-9]+}} %{{.+}} %{{.+}} :
// CHECK-SAME: memref<1x7x7x1xi8> memref<25x1xi8> memref<1xi32> memref<9x1xi8> i64 i64
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe you don't need CHECK-SAME here; I feel the content matched by CHECK-SAME is not important. The same is true for the following examples.

Copy link
Member

@linuxlonelyeagle linuxlonelyeagle Feb 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gemmini.tile_conv ({{.*}} ({{.*}} ... may be better, since you're not capturing the variables and then using them.

// The convolution under test; --convert-linalg-to-gemmini is expected to
// rewrite it into the gemmini.tile_conv matched above.
linalg.conv_2d_nhwc_fhwc
ins(%input, %kernel : memref<1x7x7x1xi8>, memref<1x5x5x1xi8>)
outs(%output : memref<1x3x3x1xi8>)
// Print the 1x3x3x1 result buffer.
gemmini.print %output : memref<1x3x3x1xi8>
// Return the constant 0 defined at the top of the function.
return %0 : i8
}
31 changes: 31 additions & 0 deletions examples/GemminiDialect/conv_2d_nhwc_fhwc_f32.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// RUN: buddy-opt %s \
// RUN: --convert-linalg-to-gemmini="acc_t=f32" | \
// RUN: FileCheck %s

// Convolution input: a 1x5x5x1 f32 buffer holding the values 1.0 through 25.0
// in row-major order.
memref.global "private" @input : memref<1x5x5x1xf32> = dense<[[[[1.],[2.],[3.],[4.],[5.]],
[[6.],[7.],[8.],[9.],[10.]],
[[11.],[12.],[13.],[14.],[15.]],
[[16.],[17.],[18.],[19.],[20.]],
[[21.],[22.],[23.],[24.],[25.]]]]>

// Convolution kernel: a 1x3x3x1 f32 buffer in which every element is 1.0.
memref.global "private" @kernel : memref<1x3x3x1xf32> = dense<[[[[1.], [1.], [1.]],
[[1.], [1.], [1.]],
[[1.], [1.], [1.]]]]>


// Test entry point: runs a 3x3 f32 convolution over the 5x5 input defined
// above, prints the result and returns 0.
func.func @main() -> i8 {
%0 = arith.constant 0 : i8
// Input (NHWC): batch = 1, height = 5, width = 5, input channels = 1.
%input = memref.get_global @input : memref<1x5x5x1xf32>
// Kernel (FHWC): output channels = 1, height = 3, width = 3, input channels = 1.
%kernel = memref.get_global @kernel : memref<1x3x3x1xf32>
// Output (NHWC): batch = 1, height = 3, width = 3, output channels = 1.
%output = memref.alloc() : memref<1x3x3x1xf32>
// The operand types are printed on the same line as the op, so they are
// matched as a continuation of the first directive, like the i8 tests do.
// CHECK: gemmini.tile_conv %{{.+}} %alloc_{{[0-9]+}} %alloc_{{[0-9]+}} %alloc_{{[0-9]+}} %{{.+}} %{{.+}} :
// CHECK-SAME: memref<1x5x5x1xf32> memref<9x1xf32> memref<1xf32> memref<9x1xf32> i64 i64
linalg.conv_2d_nhwc_fhwc
ins(%input, %kernel : memref<1x5x5x1xf32>, memref<1x3x3x1xf32>)
outs(%output : memref<1x3x3x1xf32>)
gemmini.print %output : memref<1x3x3x1xf32>
return %0 : i8
}
30 changes: 30 additions & 0 deletions examples/GemminiDialect/conv_2d_nhwc_fhwc_i8.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// RUN: buddy-opt %s \
// RUN: --convert-linalg-to-gemmini | \
// RUN: FileCheck %s

// Convolution input: a 1x7x7x1 i8 buffer in which every element is 1.
memref.global "private" @input : memref<1x7x7x1xi8> = dense<[[[[1],[1],[1],[1],[1],[1],[1]],
[[1],[1],[1],[1],[1],[1],[1]],
[[1],[1],[1],[1],[1],[1],[1]],
[[1],[1],[1],[1],[1],[1],[1]],
[[1],[1],[1],[1],[1],[1],[1]],
[[1],[1],[1],[1],[1],[1],[1]],
[[1],[1],[1],[1],[1],[1],[1]]]]>

// Convolution kernel: a 1x3x3x1 i8 buffer in which every element is 1.
memref.global "private" @kernel : memref<1x3x3x1xi8> = dense<[[[[1], [1], [1]],
[[1], [1], [1]],
[[1], [1], [1]]]]>

// Test entry point: runs a 3x3 i8 convolution over the 7x7 input defined
// above, prints the result and returns 0.
func.func @main() -> i8 {
%0 = arith.constant 0 : i8
%input = memref.get_global @input : memref<1x7x7x1xi8>
%kernel = memref.get_global @kernel : memref<1x3x3x1xi8>
// A 3x3 window over a 7x7 input yields a 5x5 result (7 - 3 + 1 = 5).
%output = memref.alloc() : memref<1x5x5x1xi8>

// The conversion is expected to emit a single gemmini.tile_conv whose operand
// types are the original input, a 9x1 i8 buffer, a 1-element i32 buffer and a
// 25x1 i8 buffer.
// CHECK: gemmini.tile_conv %{{[0-9]+}} %alloc_{{[0-9]+}} %alloc_{{[0-9]+}} %alloc_{{[0-9]+}} %{{.+}} %{{.+}} :
// CHECK-SAME: memref<1x7x7x1xi8> memref<9x1xi8> memref<1xi32> memref<25x1xi8> i64 i64
linalg.conv_2d_nhwc_fhwc
ins(%input, %kernel : memref<1x7x7x1xi8>, memref<1x3x3x1xi8>)
outs(%output : memref<1x5x5x1xi8>)
gemmini.print %output : memref<1x5x5x1xi8>
return %0 : i8
}
20 changes: 20 additions & 0 deletions examples/GemminiDialect/makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2491,3 +2491,23 @@ exo-matmul-4-run:
-I${RISCV}/../../generators/gemmini/software/gemmini-rocc-tests \
-O2 -static -o a.out
@spike --extension=gemmini pk a.out

# Run print.mlir through the linalg-to-gemmini, linalg-to-loops and
# lower-gemmini passes and write the resulting IR to log.mlir for inspection.
gemmini-print-lower:
@${BUDDY_OPT} ./print.mlir \
-convert-linalg-to-gemmini \
-convert-linalg-to-loops \
-lower-gemmini \
-o log.mlir

# Lower print.mlir with the same pass pipeline, translate it to LLVM IR,
# compile it to a riscv64 object (log.o), link it statically into `print`
# with the RISC-V GCC toolchain, and run it under spike with the gemmini
# extension.
gemmini-print-run:
@${BUDDY_OPT} ./print.mlir \
-convert-linalg-to-gemmini \
-convert-linalg-to-loops \
-lower-gemmini | \
${BUDDY_TRANSLATE} -buddy-to-llvmir | \
${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \
-mattr=+buddyext,+D -float-abi=hard \
-relocation-model=pic \
-o log.o
@riscv64-unknown-linux-gnu-gcc -O2 -static log.o -o print
@spike --extension=gemmini pk print
31 changes: 31 additions & 0 deletions examples/GemminiDialect/print.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// RUN: buddy-opt %s \
// RUN:     --convert-linalg-to-gemmini | \
// RUN: FileCheck %s
//
// Only the linalg-to-gemmini conversion is exercised by this test: the checks
// below match gemmini.print and gemmini.print_scalar ops, which are lowered
// away by --lower-gemmini (gemmini.print_scalar becomes a printf call). The
// full lowering pipeline is driven by the gemmini-print-run makefile target.

// Test entry point: prints one scalar and then buffers of rank 1, 2 and 3.
func.func @main() -> i8 {
%c0 = arith.constant 0 : i8

// Scalar case.
%scalar = arith.constant 42 : i8
// CHECK: gemmini.print_scalar %{{.*}} : i8
gemmini.print_scalar %scalar : i8

%vector = memref.alloc() : memref<4xi8> // 1-D vector
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use English.

%matrix = memref.alloc() : memref<2x3xi8> // 2-D matrix
%tensor = memref.alloc() : memref<1x2x3xi8> // 3-D tensor
%c1 = arith.constant 1 : i8
// Fill every buffer with 1 before printing it; memref.alloc does not
// initialize its contents, so an unfilled buffer would print garbage.
linalg.fill ins(%c1 : i8) outs(%vector : memref<4xi8>)
linalg.fill ins(%c1 : i8) outs(%matrix : memref<2x3xi8>)
linalg.fill ins(%c1 : i8) outs(%tensor : memref<1x2x3xi8>)
// CHECK: gemmini.print %{{.*}} : memref<4xi8>
gemmini.print %vector : memref<4xi8>
// CHECK: gemmini.print %{{.*}} : memref<2x3xi8>
gemmini.print %matrix : memref<2x3xi8>
// CHECK: gemmini.print %{{.*}} : memref<1x2x3xi8>
gemmini.print %tensor : memref<1x2x3xi8>
// Release the buffers allocated above.
memref.dealloc %vector : memref<4xi8>
memref.dealloc %matrix : memref<2x3xi8>
memref.dealloc %tensor : memref<1x2x3xi8>

return %c0 : i8
}
56 changes: 31 additions & 25 deletions midend/include/Dialect/Gemmini/Gemmini.td
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def ConfigStOp : Gemmini_Op<"config_st"> {
}];
let arguments = (ins I64:$stride,
DefaultValuedAttr<I64Attr, "0">:$activation,
DefaultValuedAttr<F32Attr, "1.0">:$scale);
DefaultValuedAttr<F32Attr, "1.0">:$scale);
let assemblyFormat = "$stride attr-dict `:` type($stride)";
}

Expand All @@ -88,19 +88,19 @@ def ConfigExOp : Gemmini_Op<"config_ex"> {
ConfigExOp configures the execute pipeline.
- dataflow: output-stationary (0) or weight-stationary (1) dataflow
- sysAct: activation function relu (1) or no activation function (0)
- sysShift: the number of bits by which the accumulated result of a matmul
- sysShift: the number of bits by which the accumulated result of a matmul
is right-shifted when leaving the systolic array.
- sysAccScale: the scalar value by which we scale the accType output of the
- sysAccScale: the scalar value by which we scale the accType output of the
accumulator down to inputType values when reading from the
accumulator.
(In the default config, rs1[63:32] is of type float32)
- cStride: TODO
- aStride: the stride (in scratchpad addresses) by which the rows of A are
fed into the systolic array. "A" in this context refers to the
left-hand matrix A in the matmul represented by A * B = C.
If this stride is 1, then we feed consecutive rows in the
scratchpad, starting from the starting address of A, into the
systolic array as the A matrix. If the stride is 2, then we feed
- aStride: the stride (in scratchpad addresses) by which the rows of A are
fed into the systolic array. "A" in this context refers to the
left-hand matrix A in the matmul represented by A * B = C.
If this stride is 1, then we feed consecutive rows in the
scratchpad, starting from the starting address of A, into the
systolic array as the A matrix. If the stride is 2, then we feed
every other row into the systolic array instead.
- aTranspose: transpose A
- bTranspose: transpose B
Expand Down Expand Up @@ -192,7 +192,13 @@ def MvoutOp : Gemmini_Op<"mvout"> {

def PrintOp : Gemmini_Op<"print"> {
let summary = "Print memref value.";
let arguments = (ins AnyTypeOf<[I8MemRef, I32MemRef, F32MemRef, F64MemRef]>:$input);
let arguments = (ins AnyTypeOf<[I8MemRef, I32MemRef, F32MemRef, F64MemRef]>:$input);
let assemblyFormat = "$input attr-dict `:` type($input)";
}

// Prints one scalar value. The operand is restricted to the element types the
// lowering to printf knows how to format, mirroring how PrintOp restricts its
// memref operand.
def PrintScalarOp : Gemmini_Op<"print_scalar"> {
  let summary = "Print a scalar value.";
  let description = [{
    Prints a single scalar followed by a newline. Supported operand types are
    i8, i32, f32 and f64.

    Example:

    ```mlir
    gemmini.print_scalar %value : i8
    ```
  }];
  let arguments = (ins AnyTypeOf<[I8, I32, F32, F64]>:$input);
  let assemblyFormat = "$input attr-dict `:` type($input)";
}

Expand Down Expand Up @@ -224,7 +230,7 @@ def PreloadOp : Gemmini_Op<"preload"> {
let arguments = (ins I64:$bdAddr, I64:$cAddr, I64:$bdRows,
I64:$bdCols, I64:$cRows, I64:$cCols);
let assemblyFormat = [{
$bdAddr $cAddr $bdRows $bdCols $cRows $cCols attr-dict `:` type($bdAddr)
$bdAddr $cAddr $bdRows $bdCols $cRows $cCols attr-dict `:` type($bdAddr)
type($cAddr) type($bdRows) type($bdCols) type($cRows) type($cCols)
}];
}
Expand Down Expand Up @@ -308,17 +314,17 @@ def TileConvOp : Gemmini_Op<"tile_conv"> {
I64:$outRowDim, I64:$outColDim, I64:$kernelDim,
DefaultValuedAttr<F32Attr, "1.0">:$scale,
DefaultValuedAttr<I64Attr, "1">:$stride,
DefaultValuedAttr<I64Attr, "1">:$inputDilation,
DefaultValuedAttr<I64Attr, "1">:$inputDilation,
DefaultValuedAttr<I64Attr, "1">:$kernelDilation,
DefaultValuedAttr<I64Attr, "0">:$padding,
DefaultValuedAttr<I64Attr, "0">:$padding,
DefaultValuedAttr<BoolAttr, "false">:$wrot180,
DefaultValuedAttr<BoolAttr, "false">:$transOutput1203,
DefaultValuedAttr<BoolAttr, "false">:$transOutput1203,
DefaultValuedAttr<BoolAttr, "false">:$transInput3120,
DefaultValuedAttr<BoolAttr, "false">:$transWeight1203,
DefaultValuedAttr<BoolAttr, "false">:$transWeight1203,
DefaultValuedAttr<BoolAttr, "false">:$transWeight0132,
DefaultValuedAttr<I64Attr, "0">:$act,
DefaultValuedAttr<I64Attr, "0">:$poolSize,
DefaultValuedAttr<I64Attr, "0">:$poolStride,
DefaultValuedAttr<I64Attr, "0">:$poolStride,
DefaultValuedAttr<I64Attr, "0">:$poolPadding);
let assemblyFormat = [{
$input $weights $bias $output $outRowDim $outColDim $kernelDim attr-dict `:` type($input)
Expand All @@ -330,13 +336,13 @@ def TileConvOp : Gemmini_Op<"tile_conv"> {
// Gemmini intrinsic operation definitions
//===----------------------------------------------------------------------===//

class Gemmini_IntrOpBase<string mnemonic, list<Trait> traits = []> :
LLVM_IntrOpBase</*Dialect dialect=*/Gemmini_Dialect,
class Gemmini_IntrOpBase<string mnemonic, list<Trait> traits = []> :
LLVM_IntrOpBase</*Dialect dialect=*/Gemmini_Dialect,
/*string opName=*/"intr." # mnemonic,
/*string enumName=*/"riscv_" # !subst(".", "_", mnemonic),
/*list<int> overloadedResults=*/[],
/*list<int> overloadedOperands=*/[],
/*list<Trait> traits=*/traits,
/*list<int> overloadedResults=*/[],
/*list<int> overloadedOperands=*/[],
/*list<Trait> traits=*/traits,
/*int numResults=*/0>;

def Gemmini_Mvin_IntrOp : Gemmini_IntrOpBase<"mvin">,
Expand All @@ -357,13 +363,13 @@ def Gemmini_Flush_IntrOp : Gemmini_IntrOpBase<"flush">,
def Gemmini_ConifgLd_IntrOp : Gemmini_IntrOpBase<"config_ld">,
Arguments<(ins LLVM_Type, LLVM_Type)>;

def Gemmini_ConfigSt_IntrOp : Gemmini_IntrOpBase<"config_st">,
def Gemmini_ConfigSt_IntrOp : Gemmini_IntrOpBase<"config_st">,
Arguments<(ins LLVM_Type, LLVM_Type)>;

def Gemmini_ConfigEX_IntrOp : Gemmini_IntrOpBase<"config_ex">,
def Gemmini_ConfigEX_IntrOp : Gemmini_IntrOpBase<"config_ex">,
Arguments<(ins LLVM_Type, LLVM_Type)>;

def Gemmini_ConfigNorm_IntrOp : Gemmini_IntrOpBase<"config_norm">,
def Gemmini_ConfigNorm_IntrOp : Gemmini_IntrOpBase<"config_norm">,
Arguments<(ins LLVM_Type, LLVM_Type)>;

def Gemmini_Preload_IntrOp : Gemmini_IntrOpBase<"preload">,
Expand Down Expand Up @@ -414,4 +420,4 @@ def Gemmini_LoopConvWsConfig5_IntrOp : Gemmini_IntrOpBase<"loop_conv_ws_config5"
def Gemmini_LoopConvWsConfig6_IntrOp : Gemmini_IntrOpBase<"loop_conv_ws_config6">,
Arguments<(ins LLVM_Type, LLVM_Type)>;

#endif
#endif
90 changes: 90 additions & 0 deletions midend/lib/Conversion/LowerGemmini/LowerGemminiPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,95 @@ class PrintOpLowering : public ConversionPattern {
}
};

/// Lowers `gemmini.print_scalar` to a call to the C `printf` function.
class PrintScalarOpLowering : public ConversionPattern {
public:
  /// Registers the pattern for `gemmini.print_scalar` with benefit 1.
  explicit PrintScalarOpLowering(MLIRContext *context)
      : ConversionPattern(gemmini::PrintScalarOp::getOperationName(), 1,
                          context) {}

  /// Replaces the print op with a `printf` call on its single operand.
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Types are spelled out so readers do not have to look up what each
    // helper returns.
    MLIRContext *context = rewriter.getContext();
    Location loc = op->getLoc();

    ModuleOp parentModule = op->getParentOfType<ModuleOp>();

    // Symbol reference to `printf`; the declaration is inserted into the
    // module on first use.
    FlatSymbolRefAttr printfRef = getOrInsertPrintf(rewriter, parentModule);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do not use auto. The data types here can be confusing for people who are not familiar with this code.


// Use the remapped operand handed in by the conversion framework rather than
// the original operand still attached to `op`.
Value valueToPrint = operands[0];
Type elementType = valueToPrint.getType();
Value formatSpecifierCst;

if (elementType == rewriter.getF32Type() ||
    elementType == rewriter.getF64Type()) {
  // Each format string gets its own global name: getOrCreateGlobalString
  // reuses an existing global with the same name, so sharing one name between
  // "%f" and "%d" would make the second kind printed pick up the wrong format.
  // The length 4 covers '%', 'f', '\n' and the explicit terminating NUL.
  formatSpecifierCst = getOrCreateGlobalString(
      loc, rewriter, "scalar_fmt_f", StringRef("%f\n\0", 4), parentModule);
  // printf is variadic, so a float argument has to be passed as a double.
  if (elementType == rewriter.getF32Type())
    valueToPrint = rewriter.create<LLVM::FPExtOp>(loc, rewriter.getF64Type(),
                                                  valueToPrint);
} else if (elementType == rewriter.getI8Type() ||
           elementType == rewriter.getI32Type()) {
  formatSpecifierCst = getOrCreateGlobalString(
      loc, rewriter, "scalar_fmt_d", StringRef("%d\n\0", 4), parentModule);
  // "%d" expects an int, so sign-extend i8 to i32.
  if (elementType == rewriter.getI8Type())
    valueToPrint = rewriter.create<LLVM::SExtOp>(loc, rewriter.getI32Type(),
                                                 valueToPrint);
} else {
  // No format string exists for other types; report a match failure instead
  // of building a printf call with a null format operand.
  return rewriter.notifyMatchFailure(
      op, "gemmini.print_scalar only supports i8, i32, f32 and f64 operands");
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For unsupported data types, you should return failure.

// Emit the call printf(format, value) using the variadic printf signature.
rewriter.create<LLVM::CallOp>(
loc, getPrintfType(context), printfRef,
ArrayRef<Value>({formatSpecifierCst, valueToPrint}));

// The print op is now fully replaced by the call, so remove it.
rewriter.eraseOp(op);
return success();
}

private:
/// Builds the LLVM function type used for `printf`: an i32 result, one pointer
/// parameter, and a variadic tail.
static LLVM::LLVMFunctionType getPrintfType(MLIRContext *context) {
  IntegerType resultTy = IntegerType::get(context, 32);
  LLVM::LLVMPointerType formatPtrTy = LLVM::LLVMPointerType::get(context);
  return LLVM::LLVMFunctionType::get(resultTy, formatPtrTy,
                                     /*isVarArg=*/true);
}

/// Returns a symbol reference to `printf`. If `module` has no `printf`
/// function yet, a declaration is created at the start of the module body.
static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
                                           ModuleOp module) {
  MLIRContext *ctx = module.getContext();
  FlatSymbolRefAttr printfSymbol = SymbolRefAttr::get(ctx, "printf");

  if (!module.lookupSymbol<LLVM::LLVMFuncOp>("printf")) {
    // The guard restores the caller's insertion point when it leaves scope.
    PatternRewriter::InsertionGuard restoreInsertionPoint(rewriter);
    rewriter.setInsertionPointToStart(module.getBody());
    rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf",
                                      getPrintfType(ctx));
  }
  return printfSymbol;
}

/// Returns a pointer to the first character of the module-level string global
/// called `name`. If no such global exists, an internal constant global
/// initialized with `value` is created at the start of the module body.
static Value getOrCreateGlobalString(Location loc, OpBuilder &builder,
                                     StringRef name, StringRef value,
                                     ModuleOp module) {
  LLVM::GlobalOp global = module.lookupSymbol<LLVM::GlobalOp>(name);
  if (!global) {
    // The guard restores the caller's insertion point when it leaves scope.
    OpBuilder::InsertionGuard restoreInsertionPoint(builder);
    builder.setInsertionPointToStart(module.getBody());
    IntegerType charTy = IntegerType::get(builder.getContext(), 8);
    LLVM::LLVMArrayType stringTy =
        LLVM::LLVMArrayType::get(charTy, value.size());
    global = builder.create<LLVM::GlobalOp>(
        loc, stringTy, /*isConstant=*/true, LLVM::Linkage::Internal, name,
        builder.getStringAttr(value), /*alignment=*/0);
  }

  // Address of the array, then a GEP with indices [0, 0] to reach its first
  // element.
  Value arrayAddress = builder.create<LLVM::AddressOfOp>(loc, global);
  Value zeroIndex = builder.create<LLVM::ConstantOp>(
      loc, builder.getI64Type(), builder.getIndexAttr(0));
  Type pointerTy = LLVM::LLVMPointerType::get(builder.getContext());
  return builder.create<LLVM::GEPOp>(loc, pointerTy, global.getType(),
                                     arrayAddress,
                                     ArrayRef<Value>({zeroIndex, zeroIndex}));
}
};

namespace {
class LowerGemminiToLLVMPass
: public PassWrapper<LowerGemminiToLLVMPass, OperationPass<ModuleOp>> {
Expand Down Expand Up @@ -222,6 +311,7 @@ void LowerGemminiToLLVMPass::runOnOperation() {
cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
populateFuncToLLVMConversionPatterns(converter, patterns);
patterns.add<PrintOpLowering>(&getContext());
patterns.add<PrintScalarOpLowering>(&getContext());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... context = &getContext();
patterns.add<...>(context);

if (failed(applyPartialConversion(module, target, std::move(patterns))))
signalPassFailure();
}
Expand Down
Loading