diff --git a/examples/GemminiDialect/conv_2d_nhwc_fhwc_5x5_i8.mlir b/examples/GemminiDialect/conv_2d_nhwc_fhwc_5x5_i8.mlir
new file mode 100644
index 0000000000..a7d7662b10
--- /dev/null
+++ b/examples/GemminiDialect/conv_2d_nhwc_fhwc_5x5_i8.mlir
@@ -0,0 +1,32 @@
+// RUN: buddy-opt %s \
+// RUN: --convert-linalg-to-gemmini | \
+// RUN: FileCheck %s
+
+memref.global "private" @input : memref<1x7x7x1xi8> = dense<[[[[1],[1],[1],[1],[1],[1],[1]],
+                                                              [[1],[1],[1],[1],[1],[1],[1]],
+                                                              [[1],[1],[1],[1],[1],[1],[1]],
+                                                              [[1],[1],[1],[1],[1],[1],[1]],
+                                                              [[1],[1],[1],[1],[1],[1],[1]],
+                                                              [[1],[1],[1],[1],[1],[1],[1]],
+                                                              [[1],[1],[1],[1],[1],[1],[1]]]]>
+
+memref.global "private" @kernel : memref<1x5x5x1xi8> = dense<[[[[1], [1], [1], [1], [1]],
+                                                               [[1], [1], [1], [1], [1]],
+                                                               [[1], [1], [1], [1], [1]],
+                                                               [[1], [1], [1], [1], [1]],
+                                                               [[1], [1], [1], [1], [1]]]]>
+
+func.func @main() -> i8 {
+  %0 = arith.constant 0 : i8
+  %input = memref.get_global @input : memref<1x7x7x1xi8>
+  %kernel = memref.get_global @kernel : memref<1x5x5x1xi8>
+  %output = memref.alloc() : memref<1x3x3x1xi8>
+
+  // CHECK: gemmini.tile_conv %{{[0-9]+}} %alloc_{{[0-9]+}} %alloc_{{[0-9]+}} %alloc_{{[0-9]+}} %{{.+}} %{{.+}} :
+  // CHECK-SAME: memref<1x7x7x1xi8> memref<25x1xi8> memref<1xi32> memref<9x1xi8> i64 i64
+  linalg.conv_2d_nhwc_fhwc
+    ins(%input, %kernel : memref<1x7x7x1xi8>, memref<1x5x5x1xi8>)
+    outs(%output : memref<1x3x3x1xi8>)
+  gemmini.print %output : memref<1x3x3x1xi8>
+  return %0 : i8
+}
diff --git a/examples/GemminiDialect/conv_2d_nhwc_fhwc_f32.mlir b/examples/GemminiDialect/conv_2d_nhwc_fhwc_f32.mlir
new file mode 100644
index 0000000000..998b2d4388
--- /dev/null
+++ b/examples/GemminiDialect/conv_2d_nhwc_fhwc_f32.mlir
@@ -0,0 +1,31 @@
+// RUN: buddy-opt %s \
+// RUN: --convert-linalg-to-gemmini="acc_t=f32" | \
+// RUN: FileCheck %s
+
+memref.global "private" @input : memref<1x5x5x1xf32> = dense<[[[[1.],[2.],[3.],[4.],[5.]],
+                                                               [[6.],[7.],[8.],[9.],[10.]],
+                                                               [[11.],[12.],[13.],[14.],[15.]],
+                                                               [[16.],[17.],[18.],[19.],[20.]],
+                                                               [[21.],[22.],[23.],[24.],[25.]]]]>
+
+memref.global "private" @kernel : memref<1x3x3x1xf32> = dense<[[[[1.], [1.], [1.]],
+                                                                [[1.], [1.], [1.]],
+                                                                [[1.], [1.], [1.]]]]>
+
+
+func.func @main() -> i8 {
+  %0 = arith.constant 0 : i8
+  // batch size = 1, input channel = 1
+  %input = memref.get_global @input : memref<1x5x5x1xf32>
+  // output channel = 1
+  %kernel = memref.get_global @kernel : memref<1x3x3x1xf32>
+  // batch size, h, w, output channel
+  %output = memref.alloc() : memref<1x3x3x1xf32>
+  // CHECK: gemmini.tile_conv %{{.+}} %alloc_{{[0-9]+}} %alloc_{{[0-9]+}} %alloc_{{[0-9]+}} %{{.+}} %{{.+}} :
+  // CHECK: memref<1x5x5x1xf32> memref<9x1xf32> memref<1xf32> memref<9x1xf32> i64 i64
+  linalg.conv_2d_nhwc_fhwc
+    ins(%input, %kernel : memref<1x5x5x1xf32>, memref<1x3x3x1xf32>)
+    outs(%output : memref<1x3x3x1xf32>)
+  gemmini.print %output : memref<1x3x3x1xf32>
+  return %0 : i8
+}
diff --git a/examples/GemminiDialect/conv_2d_nhwc_fhwc_i8.mlir b/examples/GemminiDialect/conv_2d_nhwc_fhwc_i8.mlir
new file mode 100644
index 0000000000..0bfeafca19
--- /dev/null
+++ b/examples/GemminiDialect/conv_2d_nhwc_fhwc_i8.mlir
@@ -0,0 +1,30 @@
+// RUN: buddy-opt %s \
+// RUN: --convert-linalg-to-gemmini | \
+// RUN: FileCheck %s
+
+memref.global "private" @input : memref<1x7x7x1xi8> = dense<[[[[1],[1],[1],[1],[1],[1],[1]],
+                                                              [[1],[1],[1],[1],[1],[1],[1]],
+                                                              [[1],[1],[1],[1],[1],[1],[1]],
+                                                              [[1],[1],[1],[1],[1],[1],[1]],
+                                                              [[1],[1],[1],[1],[1],[1],[1]],
+                                                              [[1],[1],[1],[1],[1],[1],[1]],
+                                                              [[1],[1],[1],[1],[1],[1],[1]]]]>
+
+memref.global "private" @kernel : memref<1x3x3x1xi8> = dense<[[[[1], [1], [1]], + [[1], [1], [1]], + [[1], [1], [1]]]]> + +func.func @main() -> i8 { + %0 = arith.constant 0 : i8 + %input = memref.get_global @input : memref<1x7x7x1xi8> + %kernel = memref.get_global @kernel : memref<1x3x3x1xi8> + %output = memref.alloc() : memref<1x5x5x1xi8> + + // CHECK: gemmini.tile_conv %{{[0-9]+}} %alloc_{{[0-9]+}} %alloc_{{[0-9]+}} %alloc_{{[0-9]+}} %{{.+}} %{{.+}} : + // CHECK-SAME: memref<1x7x7x1xi8> memref<9x1xi8> memref<1xi32> memref<25x1xi8> i64 i64 + linalg.conv_2d_nhwc_fhwc + ins(%input, %kernel : memref<1x7x7x1xi8>, memref<1x3x3x1xi8>) + outs(%output : memref<1x5x5x1xi8>) + gemmini.print %output : memref<1x5x5x1xi8> + return %0 : i8 +} diff --git a/examples/GemminiDialect/makefile b/examples/GemminiDialect/makefile index ca30047f40..873c8d4e10 100644 --- a/examples/GemminiDialect/makefile +++ b/examples/GemminiDialect/makefile @@ -2491,3 +2491,23 @@ exo-matmul-4-run: -I${RISCV}/../../generators/gemmini/software/gemmini-rocc-tests \ -O2 -static -o a.out @spike --extension=gemmini pk a.out + +gemmini-print-lower: + @${BUDDY_OPT} ./print.mlir \ + -convert-linalg-to-gemmini \ + -convert-linalg-to-loops \ + -lower-gemmini \ + -o log.mlir + +gemmini-print-run: + @${BUDDY_OPT} ./print.mlir \ + -convert-linalg-to-gemmini \ + -convert-linalg-to-loops \ + -lower-gemmini | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \ + -mattr=+buddyext,+D -float-abi=hard \ + -relocation-model=pic \ + -o log.o + @riscv64-unknown-linux-gnu-gcc -O2 -static log.o -o print + @spike --extension=gemmini pk print diff --git a/examples/GemminiDialect/print.mlir b/examples/GemminiDialect/print.mlir new file mode 100644 index 0000000000..d366a9417f --- /dev/null +++ b/examples/GemminiDialect/print.mlir @@ -0,0 +1,31 @@ +// RUN: buddy-opt %s \ +// RUN: --convert-linalg-to-gemmini \ +// --convert-linalg-to-loops \ +// --lower-gemmini | \ +// RUN: FileCheck %s + +func.func @main() -> i8 { + %c0 = arith.constant 0 : i8 + + %scalar = arith.constant 42 : i8 + // CHECK: gemmini.print_scalar %{{.*}} : i8 + gemmini.print_scalar %scalar : i8 + + %vector = memref.alloc() : memref<4xi8> // 1D向量 + %matrix = memref.alloc() : memref<2x3xi8> // 2D矩阵 + %tensor = memref.alloc() : memref<1x2x3xi8> // 3D张量 + %c1 = arith.constant 1 : i8 + linalg.fill ins(%c1 : i8) outs(%vector : memref<4xi8>) + linalg.fill ins(%c1 : i8) outs(%matrix : memref<2x3xi8>) + // CHECK: gemmini.print %{{.*}} : memref<4xi8> + gemmini.print %vector : memref<4xi8> + // CHECK: gemmini.print %{{.*}} : memref<2x3xi8> + gemmini.print %matrix : memref<2x3xi8> + // CHECK: gemmini.print %{{.*}} : memref<1x2x3xi8> + gemmini.print %tensor : memref<1x2x3xi8> + memref.dealloc %vector : memref<4xi8> + memref.dealloc %matrix : memref<2x3xi8> + memref.dealloc %tensor : memref<1x2x3xi8> + + return %c0 : i8 +} diff --git a/midend/include/Dialect/Gemmini/Gemmini.td b/midend/include/Dialect/Gemmini/Gemmini.td index e098ccc578..00d852258e 100644 --- a/midend/include/Dialect/Gemmini/Gemmini.td +++ b/midend/include/Dialect/Gemmini/Gemmini.td @@ -64,7 +64,7 @@ def ConfigStOp : Gemmini_Op<"config_st"> { }]; let arguments = (ins I64:$stride, DefaultValuedAttr:$activation, - DefaultValuedAttr:$scale); + DefaultValuedAttr:$scale); let assemblyFormat = "$stride attr-dict `:` type($stride)"; } @@ -88,19 +88,19 @@ def ConfigExOp : Gemmini_Op<"config_ex"> { ConfigExOp configures the execute pipeline. 
- dataflow: output-stationary (0) or weight-stationary (1) dataflow - sysAct: activation function relu (1) or no activation function (0) - - sysShift: the number of bits by which the accumulated result of a matmul + - sysShift: the number of bits by which the accumulated result of a matmul is right-shifted when leaving the systolic array. - - sysAccScale: the scalar value by which we scale the accType output of the + - sysAccScale: the scalar value by which we scale the accType output of the accumulator down to inputType values when reading from the accumulator. (In the default config, rs1[63:32] is of type float32) - cStride: TODO - - aStride: the stride (in scratchpad addresses) by which the rows of A are - fed into the systolic array. "A" in this context refers to the - left-hand matrix A in the matmul represented by A * B = C. - If this stride is 1, then we feed consecutive rows in the - scratchpad, starting from the starting address of A, into the - systolic array as the A matrix. If the stride is 2, then we feed + - aStride: the stride (in scratchpad addresses) by which the rows of A are + fed into the systolic array. "A" in this context refers to the + left-hand matrix A in the matmul represented by A * B = C. + If this stride is 1, then we feed consecutive rows in the + scratchpad, starting from the starting address of A, into the + systolic array as the A matrix. If the stride is 2, then we feed every other row into the systolic array instead. - aTranspose: transpose A - bTranspose: transpose B @@ -192,7 +192,13 @@ def MvoutOp : Gemmini_Op<"mvout"> { def PrintOp : Gemmini_Op<"print"> { let summary = "Print memref value."; - let arguments = (ins AnyTypeOf<[I8MemRef, I32MemRef, F32MemRef, F64MemRef]>:$input); + let arguments = (ins AnyTypeOf<[I8MemRef, I32MemRef, F32MemRef, F64MemRef]>:$input); + let assemblyFormat = "$input attr-dict `:` type($input)"; +} + +def PrintScalarOp : Gemmini_Op<"print_scalar"> { + let summary = "Print a scalar value."; + let arguments = (ins AnyType:$input); let assemblyFormat = "$input attr-dict `:` type($input)"; } @@ -224,7 +230,7 @@ def PreloadOp : Gemmini_Op<"preload"> { let arguments = (ins I64:$bdAddr, I64:$cAddr, I64:$bdRows, I64:$bdCols, I64:$cRows, I64:$cCols); let assemblyFormat = [{ - $bdAddr $cAddr $bdRows $bdCols $cRows $cCols attr-dict `:` type($bdAddr) + $bdAddr $cAddr $bdRows $bdCols $cRows $cCols attr-dict `:` type($bdAddr) type($cAddr) type($bdRows) type($bdCols) type($cRows) type($cCols) }]; } @@ -308,17 +314,17 @@ def TileConvOp : Gemmini_Op<"tile_conv"> { I64:$outRowDim, I64:$outColDim, I64:$kernelDim, DefaultValuedAttr:$scale, DefaultValuedAttr:$stride, - DefaultValuedAttr:$inputDilation, + DefaultValuedAttr:$inputDilation, DefaultValuedAttr:$kernelDilation, - DefaultValuedAttr:$padding, + DefaultValuedAttr:$padding, DefaultValuedAttr:$wrot180, - DefaultValuedAttr:$transOutput1203, + DefaultValuedAttr:$transOutput1203, DefaultValuedAttr:$transInput3120, - DefaultValuedAttr:$transWeight1203, + DefaultValuedAttr:$transWeight1203, DefaultValuedAttr:$transWeight0132, DefaultValuedAttr:$act, DefaultValuedAttr:$poolSize, - DefaultValuedAttr:$poolStride, + DefaultValuedAttr:$poolStride, DefaultValuedAttr:$poolPadding); let assemblyFormat = [{ $input $weights $bias $output $outRowDim $outColDim $kernelDim attr-dict `:` type($input) @@ -330,13 +336,13 @@ def TileConvOp : Gemmini_Op<"tile_conv"> { // Gemmini intrinsic operation definitions //===----------------------------------------------------------------------===// -class 
Gemmini_IntrOpBase traits = []> : - LLVM_IntrOpBase traits = []> : + LLVM_IntrOpBase overloadedResults=*/[], - /*list overloadedOperands=*/[], - /*list traits=*/traits, + /*list overloadedResults=*/[], + /*list overloadedOperands=*/[], + /*list traits=*/traits, /*int numResults=*/0>; def Gemmini_Mvin_IntrOp : Gemmini_IntrOpBase<"mvin">, @@ -357,13 +363,13 @@ def Gemmini_Flush_IntrOp : Gemmini_IntrOpBase<"flush">, def Gemmini_ConifgLd_IntrOp : Gemmini_IntrOpBase<"config_ld">, Arguments<(ins LLVM_Type, LLVM_Type)>; -def Gemmini_ConfigSt_IntrOp : Gemmini_IntrOpBase<"config_st">, +def Gemmini_ConfigSt_IntrOp : Gemmini_IntrOpBase<"config_st">, Arguments<(ins LLVM_Type, LLVM_Type)>; -def Gemmini_ConfigEX_IntrOp : Gemmini_IntrOpBase<"config_ex">, +def Gemmini_ConfigEX_IntrOp : Gemmini_IntrOpBase<"config_ex">, Arguments<(ins LLVM_Type, LLVM_Type)>; -def Gemmini_ConfigNorm_IntrOp : Gemmini_IntrOpBase<"config_norm">, +def Gemmini_ConfigNorm_IntrOp : Gemmini_IntrOpBase<"config_norm">, Arguments<(ins LLVM_Type, LLVM_Type)>; def Gemmini_Preload_IntrOp : Gemmini_IntrOpBase<"preload">, @@ -414,4 +420,4 @@ def Gemmini_LoopConvWsConfig5_IntrOp : Gemmini_IntrOpBase<"loop_conv_ws_config5" def Gemmini_LoopConvWsConfig6_IntrOp : Gemmini_IntrOpBase<"loop_conv_ws_config6">, Arguments<(ins LLVM_Type, LLVM_Type)>; -#endif +#endif diff --git a/midend/lib/Conversion/LowerGemmini/LowerGemminiPass.cpp b/midend/lib/Conversion/LowerGemmini/LowerGemminiPass.cpp index 0dcecff32f..2369848040 100644 --- a/midend/lib/Conversion/LowerGemmini/LowerGemminiPass.cpp +++ b/midend/lib/Conversion/LowerGemmini/LowerGemminiPass.cpp @@ -152,6 +152,95 @@ class PrintOpLowering : public ConversionPattern { } }; +class PrintScalarOpLowering : public ConversionPattern { +public: + explicit PrintScalarOpLowering(MLIRContext *context) + : ConversionPattern(gemmini::PrintScalarOp::getOperationName(), 1, + context) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto context = rewriter.getContext(); + auto loc = op->getLoc(); + + ModuleOp parentModule = op->getParentOfType(); + + auto printfRef = getOrInsertPrintf(rewriter, parentModule); + + Type elementType = op->getOperand(0).getType(); + Value formatSpecifierCst; + + if (elementType == rewriter.getF32Type() || + elementType == rewriter.getF64Type()) { + formatSpecifierCst = getOrCreateGlobalString( + loc, rewriter, "scalar_fmt", StringRef("%f\n\0", 5), parentModule); + } else if (elementType == rewriter.getI8Type() || + elementType == rewriter.getI32Type()) { + formatSpecifierCst = getOrCreateGlobalString( + loc, rewriter, "scalar_fmt", StringRef("%d\n\0", 5), parentModule); + } + + Value valueToPrint = op->getOperand(0); + if (elementType == rewriter.getF32Type()) { + valueToPrint = rewriter.create(loc, rewriter.getF64Type(), + valueToPrint); + } else if (elementType == rewriter.getI8Type()) { + valueToPrint = rewriter.create(loc, rewriter.getI32Type(), + valueToPrint); + } + + rewriter.create( + loc, getPrintfType(context), printfRef, + ArrayRef({formatSpecifierCst, valueToPrint})); + + rewriter.eraseOp(op); + return success(); + } + +private: + static LLVM::LLVMFunctionType getPrintfType(MLIRContext *context) { + auto llvmI32Ty = IntegerType::get(context, 32); + auto llvmPtr = LLVM::LLVMPointerType::get(context); + return LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtr, true); + } + + static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter, + ModuleOp module) { + auto *context = 
module.getContext(); + if (module.lookupSymbol("printf")) + return SymbolRefAttr::get(context, "printf"); + + PatternRewriter::InsertionGuard insertGuard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + rewriter.create(module.getLoc(), "printf", + getPrintfType(context)); + return SymbolRefAttr::get(context, "printf"); + } + + static Value getOrCreateGlobalString(Location loc, OpBuilder &builder, + StringRef name, StringRef value, + ModuleOp module) { + LLVM::GlobalOp global; + if (!(global = module.lookupSymbol(name))) { + OpBuilder::InsertionGuard insertGuard(builder); + builder.setInsertionPointToStart(module.getBody()); + auto type = LLVM::LLVMArrayType::get( + IntegerType::get(builder.getContext(), 8), value.size()); + global = builder.create(loc, type, true, + LLVM::Linkage::Internal, name, + builder.getStringAttr(value), 0); + } + + Value globalPtr = builder.create(loc, global); + Value cst0 = builder.create(loc, builder.getI64Type(), + builder.getIndexAttr(0)); + return builder.create( + loc, LLVM::LLVMPointerType::get(builder.getContext()), global.getType(), + globalPtr, ArrayRef({cst0, cst0})); + } +}; + namespace { class LowerGemminiToLLVMPass : public PassWrapper> { @@ -222,6 +311,7 @@ void LowerGemminiToLLVMPass::runOnOperation() { cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); populateFuncToLLVMConversionPatterns(converter, patterns); patterns.add(&getContext()); + patterns.add(&getContext()); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } diff --git a/midend/lib/Conversion/LowerLinalgToGemmini/LowerLinalgToGemmini.cpp b/midend/lib/Conversion/LowerLinalgToGemmini/LowerLinalgToGemmini.cpp index bfee320cc4..7def482508 100644 --- a/midend/lib/Conversion/LowerLinalgToGemmini/LowerLinalgToGemmini.cpp +++ b/midend/lib/Conversion/LowerLinalgToGemmini/LowerLinalgToGemmini.cpp @@ -47,7 +47,7 @@ class MatmulLowering : public OpRewritePattern { Value input0 = inputs[0]; Value input1 = inputs[1]; Value output0 = ouputs[0]; - MemRefType input0Type = dyn_cast(input0.getType()); + MemRefType input0Type = dyn_cast(input0.getType()); MemRefType biasType = MemRefType::get(input0Type.getShape(), rewriter.getI32Type()); TypedAttr fillOpInputAttr = rewriter.getI32IntegerAttr(0); @@ -75,6 +75,167 @@ class MatmulLowering : public OpRewritePattern { std::string accType; }; +class Conv2DNhwcFhwcLowering + : public OpRewritePattern { +public: + explicit Conv2DNhwcFhwcLowering(MLIRContext *context, std::string accType) + : OpRewritePattern(context), accType(accType) {} + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp, + PatternRewriter &rewriter) const override { + Value input = convOp.getInputs()[0]; + Value kernel = convOp.getInputs()[1]; + Value output = convOp.getOutputs()[0]; + Location loc = convOp.getLoc(); + + MemRefType inputType = dyn_cast(input.getType()); + MemRefType kernelType = dyn_cast(kernel.getType()); + MemRefType outputType = dyn_cast(output.getType()); + + Type kernelElemType = kernelType.getElementType(); + Type outputElemType = outputType.getElementType(); + + ArrayRef inputShape = inputType.getShape(); + + DenseIntElementsAttr dilationsAttr = convOp.getDilationsAttr(); + DenseIntElementsAttr stridesAttr = convOp.getStridesAttr(); + + size_t dilations = 1; + size_t strides = 1; + if (dilationsAttr) + dilations = (*dilationsAttr.begin()).getLimitedValue(); + if (stridesAttr) + strides = (*stridesAttr.begin()).getLimitedValue(); + + 
if (inputShape[1] != inputShape[2]) // h, w + return failure(); + ArrayRef kernelShape = kernelType.getShape(); + if (kernelShape[1] != kernelShape[2]) // h, w + return failure(); + ArrayRef outputShape = outputType.getShape(); + + // Create kernelMat(hwc, f) and outputMat(nhw, c). + SmallVector kernelMatShape = { + kernelShape[1] * kernelShape[2] * kernelShape[3], kernelShape[0]}; + MemRefType kernelMatType = MemRefType::get(kernelMatShape, kernelElemType); + Value kernelMat = rewriter.create(loc, kernelMatType); + + SmallVector outputMatShape = { + outputShape[0] * outputShape[1] * outputShape[2], outputShape[3]}; + MemRefType outputMatType = MemRefType::get(outputMatShape, outputElemType); + Value outputMat = rewriter.create(loc, outputMatType); + + MemRefType biasType = + MemRefType::get(outputShape[3], rewriter.getI32Type()); + if (accType == "f32") + biasType = MemRefType::get(outputShape[3], rewriter.getF32Type()); + Value bias = rewriter.create(loc, biasType); + + TypedAttr attr = rewriter.getI32IntegerAttr(0); + if (accType == "f32") + attr = rewriter.getF32FloatAttr(0); + Value constant0 = rewriter.create(loc, attr); + SmallVector inputs = {constant0}; + SmallVector outputs = {bias}; + rewriter.create(loc, inputs, outputs); + + // kernelShape + Operation *loopOp = nullptr; + SmallVector loopIvs; + for (size_t i = 0; i != kernelShape.size(); i++) { + Value lowerBound = rewriter.create(loc, 0); + Value upperBound = + rewriter.create(loc, kernelShape[i]); + Value step = rewriter.create(loc, 1); + auto loop = + rewriter.create(loc, lowerBound, upperBound, step); + loopIvs.push_back(loop.getInductionVar()); + if (i == 0) + loopOp = loop.getOperation(); + rewriter.setInsertionPointToStart(loop.getBody()); + } + + Value kernelDim = rewriter.create( + loc, kernelShape[1]); // dim_h = dim_w + Value inChannels = + rewriter.create(loc, kernelShape[3]); + + // Conv kernel mapping (f,h,w,c) -> (h*w*c, f) + Value tmp0 = + rewriter.create(loc, loopIvs[1], kernelDim); // h * kW + tmp0 = rewriter.create(loc, tmp0, inChannels); // * C + Value tmp1 = + rewriter.create(loc, loopIvs[2], inChannels); // w * C + tmp0 = rewriter.create(loc, tmp0, tmp1); // + (w * C) + tmp0 = rewriter.create(loc, tmp0, loopIvs[3]); // + c + + // load kernel + Value element = rewriter.create(loc, kernel, loopIvs); + SmallVector indices = {tmp0, loopIvs[0]}; // [h*w*c, f] + rewriter.create(loc, element, kernelMat, + indices); // Store the loaded data + rewriter.setInsertionPointAfter(loopOp); + + attr = rewriter.getI64IntegerAttr(outputShape[1]); + Value outRowDim = rewriter.create(loc, attr); + attr = rewriter.getI64IntegerAttr(outputShape[2]); + Value outColDim = rewriter.create(loc, attr); + kernelDim = rewriter.create(loc, attr); + kernelDim = rewriter.create( + loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(kernelShape[1])); + + rewriter.create( + loc, input, kernelMat, bias, outputMat, outRowDim, outColDim, kernelDim, + llvm::APFloat(float(1.0)), strides, dilations); + + // After the conv operation is completed, the data in outputMat needs to be + // transferred into output (2-D to 4-D). 
+ loopIvs.clear(); + indices.clear(); + + for (size_t i = 0; i < outputShape.size(); i++) { + Value lowerBound = rewriter.create(loc, 0); + Value upperBound = + rewriter.create(loc, outputShape[i]); + Value step = rewriter.create(loc, 1); + auto loop = + rewriter.create(loc, lowerBound, upperBound, step); + loopIvs.push_back(loop.getInductionVar()); + if (i == 0) + loopOp = loop.getOperation(); + rewriter.setInsertionPointToStart(loop.getBody()); + } + + // Map output from 2D (N*H*W, C) back to NHWC (n,h,w,c) + Value outH = rewriter.create(loc, outputShape[1]); + Value outW = rewriter.create(loc, outputShape[2]); + + // Calculate the row index in the 2D matrix: n * (H*W) + h * W + w + tmp0 = rewriter.create(loc, loopIvs[0], outH); // n * H + tmp0 = rewriter.create(loc, tmp0, outW); // * W + tmp1 = rewriter.create(loc, loopIvs[1], outW); // h * W + tmp0 = rewriter.create(loc, tmp0, tmp1); // + (h * W) + tmp0 = rewriter.create(loc, tmp0, loopIvs[2]); // + w + + // The index in the 2D matrix is [n*H*W + h*W + w, c] + indices.assign({tmp0, loopIvs[3]}); + + tmp0 = rewriter.create(loc, outputMat, indices); + rewriter.create(loc, tmp0, output, loopIvs); + rewriter.setInsertionPointAfter(loopOp); + + rewriter.create(loc, kernelMat); + rewriter.create(loc, outputMat); + rewriter.create(loc, bias); + + rewriter.eraseOp(convOp); + return success(); + } + +private: + std::string accType; +}; + class Conv2DNchwFchwLowering : public OpRewritePattern { public: @@ -88,9 +249,9 @@ class Conv2DNchwFchwLowering Value input1 = inputs[1]; Value output = convOp.getOutputs()[0]; Location loc = convOp.getLoc(); - MemRefType inputType = dyn_cast(input0.getType()); - MemRefType weightsType = dyn_cast(input1.getType()); - MemRefType outputType = dyn_cast(output.getType()); + MemRefType inputType = dyn_cast(input0.getType()); + MemRefType weightsType = dyn_cast(input1.getType()); + MemRefType outputType = dyn_cast(output.getType()); ArrayRef inputShape = inputType.getShape(); ArrayRef outputShape = outputType.getShape(); ArrayRef weightsShape = weightsType.getShape(); @@ -233,9 +394,9 @@ class Conv2DNhwcHwcfLowering Value kernel = convOp.getInputs()[1]; Value output = convOp.getOutputs()[0]; Location loc = convOp.getLoc(); - MemRefType inputType = dyn_cast(input.getType()); - MemRefType kernelType = dyn_cast(kernel.getType()); - MemRefType outputType = dyn_cast(output.getType()); + MemRefType inputType = dyn_cast(input.getType()); + MemRefType kernelType = dyn_cast(kernel.getType()); + MemRefType outputType = dyn_cast(output.getType()); Type kernelElemType = kernelType.getElementType(); Type outputElemType = outputType.getElementType(); ArrayRef inputShape = inputType.getShape(); @@ -359,11 +520,11 @@ class BatchMatMulOpLowering : public OpRewritePattern { Value input0 = inputs[0]; Value input1 = inputs[1]; Value output = batchMatMulOp.getOutputs()[0]; - MemRefType input0Type = dyn_cast(input0.getType()); + MemRefType input0Type = dyn_cast(input0.getType()); ArrayRef input0Shape = input0Type.getShape(); - MemRefType input1Type = dyn_cast(input1.getType()); + MemRefType input1Type = dyn_cast(input1.getType()); ArrayRef input1Shape = input1Type.getShape(); - MemRefType outputType = dyn_cast(output.getType()); + MemRefType outputType = dyn_cast(output.getType()); ArrayRef outputShape = outputType.getShape(); Type elemType = input0Type.getElementType(); for (unsigned i = 0; i != input0Shape[0]; i++) { @@ -414,6 +575,7 @@ void populateLowerLinalgToGemminiConversionPatterns(RewritePatternSet &patterns, std::string 
accType) { patterns.add(patterns.getContext(), accType); patterns.add(patterns.getContext(), accType); + patterns.add(patterns.getContext(), accType); patterns.add(patterns.getContext(), accType); patterns.add(patterns.getContext()); } diff --git a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index 31304a913e..76d2b9bab2 100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp +++ b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -38,7 +38,8 @@ using namespace buddy::gemmini; namespace { int64_t getNumberFromValue(Value &value) { - return dyn_cast(value.getDefiningOp()->getAttr("value")).getInt(); + return dyn_cast(value.getDefiningOp()->getAttr("value")) + .getInt(); } acc_scale_t_bits acc_scale_t_to_acc_scale_t_bits(acc_scale_t x) { @@ -249,7 +250,8 @@ struct GemminiMvinLowering : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { Value input = mvinOp.getInput(); Location loc = input.getLoc(); - MemRefType memRefType = dyn_cast(mvinOp.getOperandTypes().front()); + MemRefType memRefType = + dyn_cast(mvinOp.getOperandTypes().front()); llvm::ArrayRef memRefShape = memRefType.getShape(); TypeRange resultType = mlir::TypeRange(rewriter.getIndexType()); Value extractOp = rewriter.create( @@ -281,7 +283,8 @@ struct GemminiMvin2Lowering : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { Value input = mvin2Op.getInput(); Location loc = input.getLoc(); - MemRefType memRefType = dyn_cast(mvin2Op.getOperandTypes().front()); + MemRefType memRefType = + dyn_cast(mvin2Op.getOperandTypes().front()); llvm::ArrayRef memRefShape = memRefType.getShape(); TypeRange resultType = mlir::TypeRange(rewriter.getIndexType()); Value extractOp = rewriter.create( @@ -313,7 +316,8 @@ struct GemminiMvin3Lowering : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { Value input = mvin3Op.getInput(); Location loc = input.getLoc(); - MemRefType memRefType = dyn_cast(mvin3Op.getOperandTypes().front()); + MemRefType memRefType = + dyn_cast(mvin3Op.getOperandTypes().front()); llvm::ArrayRef memRefShape = memRefType.getShape(); TypeRange resultType = mlir::TypeRange(rewriter.getIndexType()); Value extractOp = rewriter.create( @@ -353,7 +357,8 @@ struct GemminiMvoutLowering : public ConvertOpToLLVMPattern { rewriter.create(loc, i64Type, extractOp); Value spadAddr = mvoutOp.getAddr(); uint64_t number = getNumberFromValue(spadAddr); - MemRefType memRefType =dyn_cast(mvoutOp.getOperandTypes().front()); + MemRefType memRefType = + dyn_cast(mvoutOp.getOperandTypes().front()); llvm::ArrayRef memRefShape = memRefType.getShape(); uint64_t spadAddrInt = (uint64_t)memRefShape[0] << (addrLen + 16) | (uint64_t)memRefShape[1] << addrLen | number; @@ -947,9 +952,12 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { MemRefType bArrayType = dyn_cast(bArray.getType()); MemRefType cArrayType = dyn_cast(cArray.getType()); MemRefType dArrayType = dyn_cast(dArray.getType()); - StridedLayoutAttr aArrayLayout = dyn_cast(aArrayType.getLayout()); - StridedLayoutAttr bArrayLayout = dyn_cast(bArrayType.getLayout()); - StridedLayoutAttr cArrayLayout = dyn_cast(cArrayType.getLayout()); + StridedLayoutAttr aArrayLayout = + dyn_cast(aArrayType.getLayout()); + StridedLayoutAttr bArrayLayout = + dyn_cast(bArrayType.getLayout()); + StridedLayoutAttr cArrayLayout = + dyn_cast(cArrayType.getLayout()); SmallVector 
resultType = {rewriter.getIndexType()}; TypeRange typeRange(resultType); Location loc = tileMatMulOp.getLoc(); @@ -1145,30 +1153,26 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { rewriter.create(loc, rs1Value, rs2Value); } - void spTiledConv(int batchSize, int inRowDim, int inColDim, int inChannels, - int outChannels, int outRowDim, int outColDim, - int poolOutRowDim, int poolOutColDim, int stride, - int padding, int kernelDim, int kernelDilation, int inStride, - int weightStride, int outStride, int poolSize, - int poolStride, int poolPadding, int batches, int porows, - int pocols, int pochs, int krows, int kcols, int kchs, - int lpad, int rpad, int upad, int dpad, int plpad, int prpad, - int pupad, int pdpad, Value &input, Value &weights, - Value &output, Value &bias, int act, acc_scale_t scale, - bool wrot180, bool transOutput1203, bool transInput3120, - bool transWeight1203, bool transWeight0132, bool noBias, - bool noPool, bool downsample, bool inputDilated, bool dw, - TileConvOp &tileConvOp, - ConversionPatternRewriter &rewriter) const { + void gemminiRiscConvWs( + int batchSize, int inRowDim, int inColDim, int inChannels, + int outChannels, int outRowDim, int outColDim, int poolOutRowDim, + int poolOutColDim, int stride, int padding, int kernelDim, + int kernelDilation, int inStride, int weightStride, int outStride, + int poolSize, int poolStride, int poolPadding, int batches, int porows, + int pocols, int pochs, int krows, int kcols, int kchs, int lpad, int rpad, + int upad, int dpad, int plpad, int prpad, int pupad, int pdpad, + Value &input, Value &weights, Value &output, Value &bias, int act, + acc_scale_t scale, bool wrot180, bool transOutput1203, + bool transInput3120, bool transWeight1203, bool transWeight0132, + bool noBias, bool noPool, bool downsample, bool inputDilated, + int maxPixelsPerRow, bool dw, TileConvOp &tileConvOp, + ConversionPatternRewriter &rewriter) const { Location loc = tileConvOp.getLoc(); - if (dw) { - kchs = 1; - pochs = 1; - } - const int orows = porows * poolStride + poolSize - 1 - pupad - pdpad; const int ocols = pocols * poolStride + poolSize - 1 - plpad - prpad; + + const int ichs = kchs; const int ochs = pochs; // Calculate image dimensions @@ -1180,10 +1184,7 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { int irowsUnpadded = irows - upad - dpad; int icolsUnpadded = icols - lpad - rpad; - const int ichs = kchs; - #define UNDILATED(x) ((inputDilated) ? (((x) + 1) / 2) : (x)) - if (inputDilated) { irowsUnpadded = (irowsUnpadded + 1) / 2; icolsUnpadded = (icolsUnpadded + 1) / 2; @@ -1192,18 +1193,6 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { icols = icolsUnpadded + UNDILATED(lpad) + UNDILATED(rpad); } -#ifdef HAS_FIRST_LAYER_OPTIMIZATIONS - const bool transposed = - transOutput1203 || transInput3120 || transWeight1203 || transWeight0132; - int maxPixelsPerRow = transposed || wrot180 || downsample || inputDilated || - kernelDilation > 1 || ichs > dim - ? 
1 - : dim / ichs; - if (maxPixelsPerRow > kcols) - maxPixelsPerRow = kcols; -#else - const int maxPixelsPerRow = 1; -#endif // Calculate spad address offsets const int outChannelsPerBank = ochs / dim + (ochs % dim != 0); const int inChannelsPerBank = kchs / dim + (kchs % dim != 0); @@ -1226,25 +1215,13 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { if (output != 0) { cSpAddrRow = (cSpAddrRow + accRows / 2) % accRows; } - if (inRowDim == inColDim && outRowDim == outColDim && - poolOutRowDim == poolOutColDim) { - gemminiLoopConvWs( - batchSize, inRowDim, inChannels, outChannels, outRowDim, - poolOutRowDim, stride, padding, kernelDim, kernelDilation, poolSize, - poolStride, poolPadding, batches, porows, pocols, pochs, krows, kcols, - kchs, lpad, rpad, upad, dpad, plpad, prpad, pupad, pdpad, orows, - ocols, weights, output, bias, input, noBias, noPool, downsample, - wrot180, inputDilated, act, transOutput1203, transWeight1203, - transWeight0132, transInput3120, maxPixelsPerRow, dw, tileConvOp, - rewriter); - return; - } - if (!noPool) { + + if ((inRowDim == inColDim) && (outRowDim == outColDim) && + (poolOutRowDim == poolOutColDim) && !noPool) { llvm::outs() << "Pooling with rectangular convolutions is currently not " "supported.\n"; return; } - // Only rectangular convolutions will use the following C code // mvin bias const size_t maxBlockLen = MAX_BYTES / (dim * 1); const size_t maxBlockLenAcc = MAX_BYTES / (dim * 4); @@ -1357,6 +1334,7 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { } } } + // mvin weights if (weights != NULL) { int max_chs_per_mvin = @@ -1424,6 +1402,7 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { } } } + // Compute { const int b_it = transInput3120 ? dim : 1; @@ -1444,14 +1423,12 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { bool newWeights = true; for (int b = 0; b < batches; b += b_it) { for (int orow = 0; orow < orows; orow++) { - // Skip some kernel rows due to input-dilation if (inputDilated && ((krow * kernelDilation + orow * stride - upad) % 2 != 0)) { continue; } for (int ocol = 0; ocol < ocols;) { - // Skip some cols dimensions due to input-dilation if (inputDilated && ((kcol + ocol * stride - lpad) % 2 != 0)) { ocol++; @@ -1575,463 +1552,500 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { } } } else { - printf("Pooling with rectangular convolutions is currently not " - "supported.\n"); + // TODO: need to enable pooling + printf("Pooling in RISC mode is unsupported.\n "); exit(1); } } } - void tiledConv(int batchSize, int inRowDim, int inColDim, int inChannels, - int outChannels, int outRowDim, int outColDim, int stride, - int inputDilation, int kernelDilation, int padding, - int kernelDim, int inStride, int weightStride, int outStride, - bool wrot180, bool transOutput1203, bool transInput3120, - bool transWeight1203, bool transWeight0132, int batches, - int porows, int pocols, int pochs, int krows, int kcols, - int kchs, const Value &input, const Value &weights, - const Value &bias, Value &output, int act, acc_scale_t scale, - int poolSize, int poolStride, int poolPadding, - TileConvOp &tileConvOp, - ConversionPatternRewriter &rewriter) const { - bool noBias = false; - bool noPool = poolStride == 0; - if (noPool) { - poolSize = 1; - poolStride = 1; - poolPadding = 0; + void spTiledConv(int batchSize, int inRowDim, int inColDim, int inChannels, + int outChannels, int outRowDim, int outColDim, + int poolOutRowDim, int poolOutColDim, int stride, + int padding, int 
kernelDim, int kernelDilation, int inStride, + int weightStride, int outStride, int poolSize, + int poolStride, int poolPadding, int batches, int porows, + int pocols, int pochs, int krows, int kcols, int kchs, + int lpad, int rpad, int upad, int dpad, int plpad, int prpad, + int pupad, int pdpad, Value &input, Value &weights, + Value &output, Value &bias, int act, acc_scale_t scale, + bool wrot180, bool transOutput1203, bool transInput3120, + bool transWeight1203, bool transWeight0132, bool noBias, + bool noPool, bool downsample, bool inputDilated, bool dw, + TileConvOp &tileConvOp, + ConversionPatternRewriter &rewriter) const { + + if (dw) { + kchs = 1; + pochs = 1; } - const bool downsample = stride == 2 && kernelDim == 1 && - inRowDim % 2 == 0 && inColDim % 2 == 0 && - padding == 0 && noPool && inputDilation == 1 && - !transInput3120; - const int inputDilated = inputDilation == 2; - int64_t stDramStride = transOutput1203 - ? batchSize * outChannels * sizeOfElemT - : outChannels * sizeOfElemT; - Location loc = tileConvOp.getLoc(); - Value strideValue = rewriter.create( - loc, rewriter.getI64IntegerAttr(stDramStride)); - rewriter.create(loc, strideValue, act, llvm::APFloat(scale)); - rewriter.create( - loc, /*dataflow = */ WEIGHT_STATIONARY, /*act = */ 0, /*shift = */ 0, - /*scale = */ llvm::APFloat((float)0), /*cStride = */ inputDilation, - /*aStride = */ stride >> downsample, - /*aTranspose = */ transInput3120, /*bTranspose*/ transWeight0132, - /*setOnlyStrides = */ false); - const int poolOutRowDim = - (outRowDim + 2 * poolPadding - poolSize) / poolStride + 1; - const int poolOutColDim = - (outColDim + 2 * poolPadding - poolSize) / poolStride + 1; - const int dilatedInRowDim = inRowDim + (inputDilation - 1) * (inRowDim - 1); - const int dilatedInColDim = inColDim + (inputDilation - 1) * (inColDim - 1); - - int porowEnd = poolOutRowDim; - - for (int b = 0; b < batchSize; b += batches) { - for (int porow = 0; porow < porowEnd; porow += porows) { - const int orow = porow * poolStride - poolPadding; - for (int pocol = 0; pocol < poolOutColDim; pocol += pocols) { - const int ocol = pocol * poolStride - poolPadding; - for (int poch = 0; poch < outChannels; poch += pochs) { - for (int krow = 0; krow < kernelDim; krow += krows) { - const int orow_floored = orow < 0 ? 0 : orow; - - int irow = - orow_floored * stride + krow * kernelDilation - padding; - for (int kcol = 0; kcol < kernelDim; kcol += kcols) { - const int ocol_floored = ocol < 0 ? 0 : ocol; - int icol = - ocol_floored * stride + kcol * kernelDilation - padding; - - for (int kch = 0; kch < inChannels; kch += kchs) { - TypedAttr offsetAttr = rewriter.getI64IntegerAttr( - ((b * poolOutRowDim * poolOutColDim + - porow * poolOutColDim + pocol) * - outChannels + - poch) * - sizeOfElemT); - Value offsetValue = - rewriter.create(loc, offsetAttr); - Value out = rewriter.create( - tileConvOp.getLoc(), rewriter.getI64Type(), output, - offsetValue); - if (transOutput1203) { - offsetAttr = rewriter.getI64IntegerAttr( - ((porow * poolOutColDim * batchSize + - pocol * batchSize + b) * + + const int ichs = kchs; + +#ifdef HAS_FIRST_LAYER_OPTIMIZATIONS + const bool transposed = + transOutput1203 || transInput3120 || transWeight1203 || transWeight0132; + int maxPixelsPerRow = transposed || wrot180 || downsample || inputDilated || + kernelDilation > 1 || ichs > dim + ? 
1 + : dim / ichs; + if (maxPixelsPerRow > kcols) + maxPixelsPerRow = kcols; +#else + const int maxPixelsPerRow = 1; +#endif + + // TODO: add an option to select between gemminiRiscConvWs and + // gemminiLoopConvWs if (inRowDim == inColDim && outRowDim == outColDim && + // poolOutRowDim == poolOutColDim) { + // gemminiLoopConvWs( + // batchSize, inRowDim, inChannels, outChannels, outRowDim, + // poolOutRowDim, stride, padding, kernelDim, kernelDilation, + // poolSize, poolStride, poolPadding, batches, porows, pocols, pochs, + // krows, kcols, kchs, lpad, rpad, upad, dpad, plpad, prpad, pupad, + // pdpad, orows, ocols, weights, output, bias, input, noBias, noPool, + // downsample, wrot180, inputDilated, act, transOutput1203, + // transWeight1203, transWeight0132, transInput3120, maxPixelsPerRow, + // dw, tileConvOp, rewriter); + // return; + // } + + gemminiRiscConvWs( + batchSize, inRowDim, inColDim, inChannels, outChannels, outRowDim, + outColDim, poolOutRowDim, poolOutColDim, stride, padding, kernelDim, + kernelDilation, inStride, weightStride, outStride, poolSize, poolStride, + poolPadding, batches, porows, pocols, pochs, krows, kcols, kchs, lpad, + rpad, upad, dpad, plpad, prpad, pupad, pdpad, input, weights, output, + bias, act, scale, wrot180, transOutput1203, transInput3120, + transWeight1203, transWeight0132, noBias, noPool, downsample, + inputDilated, maxPixelsPerRow, dw, tileConvOp, rewriter); + + void tiledConv( + int batchSize, int inRowDim, int inColDim, int inChannels, + int outChannels, int outRowDim, int outColDim, int stride, + int inputDilation, int kernelDilation, int padding, int kernelDim, + int inStride, int weightStride, int outStride, bool wrot180, + bool transOutput1203, bool transInput3120, bool transWeight1203, + bool transWeight0132, int batches, int porows, int pocols, int pochs, + int krows, int kcols, int kchs, const Value &input, + const Value &weights, const Value &bias, Value &output, int act, + acc_scale_t scale, int poolSize, int poolStride, int poolPadding, + TileConvOp &tileConvOp, ConversionPatternRewriter &rewriter) const { + bool noBias = false; + bool noPool = poolStride == 0; + if (noPool) { + poolSize = 1; + poolStride = 1; + poolPadding = 0; + } + const bool downsample = stride == 2 && kernelDim == 1 && + inRowDim % 2 == 0 && inColDim % 2 == 0 && + padding == 0 && noPool && inputDilation == 1 && + !transInput3120; + const int inputDilated = inputDilation == 2; + int64_t stDramStride = transOutput1203 + ? 
batchSize * outChannels * sizeOfElemT + : outChannels * sizeOfElemT; + Location loc = tileConvOp.getLoc(); + Value strideValue = rewriter.create( + loc, rewriter.getI64IntegerAttr(stDramStride)); + rewriter.create(loc, strideValue, act, llvm::APFloat(scale)); + rewriter.create( + loc, /*dataflow = */ WEIGHT_STATIONARY, /*act = */ 0, /*shift = */ 0, + /*scale = */ llvm::APFloat((float)0), /*cStride = */ inputDilation, + /*aStride = */ stride >> downsample, + /*aTranspose = */ transInput3120, /*bTranspose*/ transWeight0132, + /*setOnlyStrides = */ false); + const int poolOutRowDim = + (outRowDim + 2 * poolPadding - poolSize) / poolStride + 1; + const int poolOutColDim = + (outColDim + 2 * poolPadding - poolSize) / poolStride + 1; + const int dilatedInRowDim = + inRowDim + (inputDilation - 1) * (inRowDim - 1); + const int dilatedInColDim = + inColDim + (inputDilation - 1) * (inColDim - 1); + + int porowEnd = poolOutRowDim; + + for (int b = 0; b < batchSize; b += batches) { + for (int porow = 0; porow < porowEnd; porow += porows) { + const int orow = porow * poolStride - poolPadding; + for (int pocol = 0; pocol < poolOutColDim; pocol += pocols) { + const int ocol = pocol * poolStride - poolPadding; + for (int poch = 0; poch < outChannels; poch += pochs) { + for (int krow = 0; krow < kernelDim; krow += krows) { + const int orow_floored = orow < 0 ? 0 : orow; + + int irow = + orow_floored * stride + krow * kernelDilation - padding; + for (int kcol = 0; kcol < kernelDim; kcol += kcols) { + const int ocol_floored = ocol < 0 ? 0 : ocol; + int icol = + ocol_floored * stride + kcol * kernelDilation - padding; + + for (int kch = 0; kch < inChannels; kch += kchs) { + TypedAttr offsetAttr = rewriter.getI64IntegerAttr( + ((b * poolOutRowDim * poolOutColDim + + porow * poolOutColDim + pocol) * outChannels + poch) * sizeOfElemT); - offsetValue = + Value offsetValue = rewriter.create(loc, offsetAttr); - out = rewriter.create(tileConvOp.getLoc(), - rewriter.getI64Type(), - output, offsetValue); - } + Value out = rewriter.create( + tileConvOp.getLoc(), rewriter.getI64Type(), output, + offsetValue); + if (transOutput1203) { + offsetAttr = rewriter.getI64IntegerAttr( + ((porow * poolOutColDim * batchSize + + pocol * batchSize + b) * + outChannels + + poch) * + sizeOfElemT); + offsetValue = + rewriter.create(loc, offsetAttr); + out = rewriter.create( + tileConvOp.getLoc(), rewriter.getI64Type(), output, + offsetValue); + } - if (krow + krows < kernelDim || kcol + kcols < kernelDim || - kch + kchs < inChannels) { - out = rewriter.create( - tileConvOp.getLoc(), rewriter.getI64IntegerAttr(0)); - } - Value pochValue = rewriter.create( - tileConvOp.getLoc(), - rewriter.getI64IntegerAttr(poch * sizeOfAccT)); - Value bias_ = rewriter.create( - tileConvOp.getLoc(), rewriter.getI64Type(), bias, - pochValue); - if (krow > 0 || kcol > 0 || kch > 0) { - bias_ = rewriter.create( - tileConvOp.getLoc(), rewriter.getI64IntegerAttr(0)); - } + if (krow + krows < kernelDim || kcol + kcols < kernelDim || + kch + kchs < inChannels) { + out = rewriter.create( + tileConvOp.getLoc(), rewriter.getI64IntegerAttr(0)); + } + Value pochValue = rewriter.create( + tileConvOp.getLoc(), + rewriter.getI64IntegerAttr(poch * sizeOfAccT)); + Value bias_ = rewriter.create( + tileConvOp.getLoc(), rewriter.getI64Type(), bias, + pochValue); + if (krow > 0 || kcol > 0 || kch > 0) { + bias_ = rewriter.create( + tileConvOp.getLoc(), rewriter.getI64IntegerAttr(0)); + } - const int batches_ = - batchSize - b > batches ? 
batches : batchSize - b; - const int porows_ = poolOutRowDim - porow > porows - ? porows - : poolOutRowDim - porow; - const int pocols_ = poolOutColDim - pocol > pocols - ? pocols - : poolOutColDim - pocol; - const int pochs_ = - outChannels - poch > pochs ? pochs : outChannels - poch; - const int krows_ = - kernelDim - krow > krows ? krows : kernelDim - krow; - const int kcols_ = - kernelDim - kcol > kcols ? kcols : kernelDim - kcol; - const int kchs_ = - inChannels - kch > kchs ? kchs : inChannels - kch; - - const int ocols_ = pocols_ * poolStride + poolSize - 1; - const int orows_ = porows_ * poolStride + poolSize - 1; - - const int plpad = ocol < 0 ? -ocol : 0; - const int prpad = - ocol + ocols_ > outColDim ? ocol + ocols_ - outColDim : 0; - const int pupad = orow < 0 ? -orow : 0; - const int pdpad = - orow + orows_ > outRowDim ? orow + orows_ - outRowDim : 0; - - const int dilatedKrows_ = - krows_ + (kernelDilation - 1) * (krows_ - 1); - const int dilatedKcols_ = - kcols_ + (kernelDilation - 1) * (kcols_ - 1); - - const int icols_ = - (ocols_ - plpad - prpad) * stride + dilatedKcols_ - 1; - const int irows_ = - (orows_ - pupad - pdpad) * stride + dilatedKrows_ - 1; - - int lpad = icol < 0 ? -icol : 0; - int rpad = icol + icols_ > dilatedInColDim - ? icol + icols_ - dilatedInColDim - : 0; - int upad = irow < 0 ? -irow : 0; - int dpad = irow + irows_ > dilatedInRowDim - ? irow + irows_ - dilatedInRowDim - : 0; - - if (inputDilated) { - lpad += lpad == 0 && icol % 2 != 0; - rpad += rpad == 0 && (icol + icols_) % 2 != 1; - upad += upad == 0 && irow % 2 != 0; - dpad += dpad == 0 && (irow + irows_) % 2 != 1; - } + const int batches_ = + batchSize - b > batches ? batches : batchSize - b; + const int porows_ = poolOutRowDim - porow > porows + ? porows + : poolOutRowDim - porow; + const int pocols_ = poolOutColDim - pocol > pocols + ? pocols + : poolOutColDim - pocol; + const int pochs_ = + outChannels - poch > pochs ? pochs : outChannels - poch; + const int krows_ = + kernelDim - krow > krows ? krows : kernelDim - krow; + const int kcols_ = + kernelDim - kcol > kcols ? kcols : kernelDim - kcol; + const int kchs_ = + inChannels - kch > kchs ? kchs : inChannels - kch; + + const int ocols_ = pocols_ * poolStride + poolSize - 1; + const int orows_ = porows_ * poolStride + poolSize - 1; + + const int plpad = ocol < 0 ? -ocol : 0; + const int prpad = ocol + ocols_ > outColDim + ? ocol + ocols_ - outColDim + : 0; + const int pupad = orow < 0 ? -orow : 0; + const int pdpad = orow + orows_ > outRowDim + ? orow + orows_ - outRowDim + : 0; + + const int dilatedKrows_ = + krows_ + (kernelDilation - 1) * (krows_ - 1); + const int dilatedKcols_ = + kcols_ + (kernelDilation - 1) * (kcols_ - 1); + + const int icols_ = + (ocols_ - plpad - prpad) * stride + dilatedKcols_ - 1; + const int irows_ = + (orows_ - pupad - pdpad) * stride + dilatedKrows_ - 1; + + int lpad = icol < 0 ? -icol : 0; + int rpad = icol + icols_ > dilatedInColDim + ? icol + icols_ - dilatedInColDim + : 0; + int upad = irow < 0 ? -irow : 0; + int dpad = irow + irows_ > dilatedInRowDim + ? 
irow + irows_ - dilatedInRowDim + : 0; - int krow_ = krow; - int kcol_ = kcol; - if (wrot180) { - krow_ = kernelDim - krow - krows_; - kcol_ = kernelDim - kcol - kcols_; - } - offsetAttr = rewriter.getI64IntegerAttr( - ((krow_ * kernelDim * inChannels + kcol_ * inChannels + - kch) * - outChannels + - poch) * - sizeOfElemT); - offsetValue = rewriter.create( - tileConvOp.getLoc(), offsetAttr); - Value weightsSlice = rewriter.create( - tileConvOp.getLoc(), rewriter.getI64Type(), weights, - offsetValue); - if (transWeight1203) { + if (inputDilated) { + lpad += lpad == 0 && icol % 2 != 0; + rpad += rpad == 0 && (icol + icols_) % 2 != 1; + upad += upad == 0 && irow % 2 != 0; + dpad += dpad == 0 && (irow + irows_) % 2 != 1; + } + + int krow_ = krow; + int kcol_ = kcol; + if (wrot180) { + krow_ = kernelDim - krow - krows_; + kcol_ = kernelDim - kcol - kcols_; + } offsetAttr = rewriter.getI64IntegerAttr( - ((kch * kernelDim * kernelDim + krow_ * kernelDim + - kcol_) * + ((krow_ * kernelDim * inChannels + kcol_ * inChannels + + kch) * outChannels + poch) * sizeOfElemT); offsetValue = rewriter.create( tileConvOp.getLoc(), offsetAttr); - weightsSlice = rewriter.create( + Value weightsSlice = rewriter.create( tileConvOp.getLoc(), rewriter.getI64Type(), weights, offsetValue); - } else if (transWeight0132) { + if (transWeight1203) { + offsetAttr = rewriter.getI64IntegerAttr( + ((kch * kernelDim * kernelDim + krow_ * kernelDim + + kcol_) * + outChannels + + poch) * + sizeOfElemT); + offsetValue = rewriter.create( + tileConvOp.getLoc(), offsetAttr); + weightsSlice = rewriter.create( + tileConvOp.getLoc(), rewriter.getI64Type(), weights, + offsetValue); + } else if (transWeight0132) { + offsetAttr = rewriter.getI64IntegerAttr( + ((krow_ * kernelDim * outChannels + + kcol_ * outChannels + poch) * + inChannels + + kch) * + sizeOfElemT); + offsetValue = rewriter.create( + tileConvOp.getLoc(), offsetAttr); + weightsSlice = rewriter.create( + tileConvOp.getLoc(), rewriter.getI64Type(), weights, + offsetValue); + } offsetAttr = rewriter.getI64IntegerAttr( - ((krow_ * kernelDim * outChannels + - kcol_ * outChannels + poch) * + ((b * inRowDim * inColDim + + ((irow + upad) >> inputDilated) * inColDim + + ((icol + lpad) >> inputDilated)) * inChannels + kch) * sizeOfElemT); offsetValue = rewriter.create( tileConvOp.getLoc(), offsetAttr); - weightsSlice = rewriter.create( - tileConvOp.getLoc(), rewriter.getI64Type(), weights, + Value in = rewriter.create( + tileConvOp.getLoc(), rewriter.getI64Type(), input, offsetValue); - } - offsetAttr = rewriter.getI64IntegerAttr( - ((b * inRowDim * inColDim + - ((irow + upad) >> inputDilated) * inColDim + - ((icol + lpad) >> inputDilated)) * - inChannels + - kch) * - sizeOfElemT); - offsetValue = rewriter.create( - tileConvOp.getLoc(), offsetAttr); - Value in = rewriter.create( - tileConvOp.getLoc(), rewriter.getI64Type(), input, - offsetValue); - if (transInput3120) { - offsetAttr = rewriter.getI64IntegerAttr( - ((kch * inRowDim * inColDim + - ((irow + upad) >> inputDilated) * inColDim + - ((icol + lpad) >> inputDilated)) * - batchSize + - b) * - sizeOfElemT); - in = rewriter.create(tileConvOp.getLoc(), - rewriter.getI64Type(), - input, offsetValue); - } + if (transInput3120) { + offsetAttr = rewriter.getI64IntegerAttr( + ((kch * inRowDim * inColDim + + ((irow + upad) >> inputDilated) * inColDim + + ((icol + lpad) >> inputDilated)) * + batchSize + + b) * + sizeOfElemT); + in = rewriter.create(tileConvOp.getLoc(), + rewriter.getI64Type(), + input, offsetValue); + } - spTiledConv( - 
batchSize, inRowDim, inColDim, inChannels, outChannels, - outRowDim, outColDim, poolOutRowDim, poolOutColDim, - stride, padding, kernelDim, kernelDilation, inStride, - weightStride, outStride, poolSize, poolStride, - poolPadding, batches_, porows_, pocols_, pochs_, krows_, - kcols_, kchs_, lpad, rpad, upad, dpad, plpad, prpad, - pupad, pdpad, in, weightsSlice, out, bias_, act, scale, - wrot180, transOutput1203, transInput3120, transWeight1203, - transWeight0132, noBias, noPool, downsample, inputDilated, - false, tileConvOp, rewriter); + spTiledConv( + batchSize, inRowDim, inColDim, inChannels, outChannels, + outRowDim, outColDim, poolOutRowDim, poolOutColDim, + stride, padding, kernelDim, kernelDilation, inStride, + weightStride, outStride, poolSize, poolStride, + poolPadding, batches_, porows_, pocols_, pochs_, krows_, + kcols_, kchs_, lpad, rpad, upad, dpad, plpad, prpad, + pupad, pdpad, in, weightsSlice, out, bias_, act, scale, + wrot180, transOutput1203, transInput3120, + transWeight1203, transWeight0132, noBias, noPool, + downsample, inputDilated, false, tileConvOp, rewriter); + } } } } } } } + IntegerAttr flushAttr = rewriter.getI64IntegerAttr(0); + Value flushValue = rewriter.create( + loc, rewriter.getI64Type(), flushAttr); + rewriter.replaceOpWithNewOp(tileConvOp, flushValue, + flushValue); } - IntegerAttr flushAttr = rewriter.getI64IntegerAttr(0); - Value flushValue = rewriter.create( - loc, rewriter.getI64Type(), flushAttr); - rewriter.replaceOpWithNewOp(tileConvOp, flushValue, - flushValue); - } - int tiledConvTotalSpadRows(bool acc, int stride, int inputDilation, - int kernelDilation, bool downsample, - bool transWeight0132, bool transInput3120, - int batches, int porows, int pocols, int ochs, - int krows, int kcols, int kchs, int poolSize, - int poolStride) const { + int tiledConvTotalSpadRows( + bool acc, int stride, int inputDilation, int kernelDilation, + bool downsample, bool transWeight0132, bool transInput3120, int batches, + int porows, int pocols, int ochs, int krows, int kcols, int kchs, + int poolSize, int poolStride) const { - const int orows = porows * poolStride + poolSize - 1; - const int ocols = pocols * poolStride + poolSize - 1; + const int orows = porows * poolStride + poolSize - 1; + const int ocols = pocols * poolStride + poolSize - 1; - const int krowsDilated = krows + (kernelDilation - 1) * (krows - 1); - const int kcolsDilated = kcols + (kernelDilation - 1) * (kcols - 1); + const int krowsDilated = krows + (kernelDilation - 1) * (krows - 1); + const int kcolsDilated = kcols + (kernelDilation - 1) * (kcols - 1); - int irows = orows * stride + krowsDilated - 1; - int icols = ocols * stride + kcolsDilated - 1; - const int ichs = kchs; + int irows = orows * stride + krowsDilated - 1; + int icols = ocols * stride + kcolsDilated - 1; + const int ichs = kchs; - irows = irows / inputDilation + (irows % inputDilation != 0); - icols = icols / inputDilation + (icols % inputDilation != 0); + irows = irows / inputDilation + (irows % inputDilation != 0); + icols = icols / inputDilation + (icols % inputDilation != 0); - const int inChannelsPerBank = ichs / dim + (ichs % dim != 0); - const int outChannelsPerBank = ochs / dim + (ochs % dim != 0); - const int batchesPerBank = batches / dim + (batches % dim != 0); - - const int aRows = transInput3120 - ? 
(batchesPerBank * ichs * (irows >> downsample) *
-                          (icols >> downsample))
-                    : (inChannelsPerBank * batches *
-                          (irows >> downsample) * (icols >> downsample));
+      const int inChannelsPerBank = ichs / dim + (ichs % dim != 0);
+      const int outChannelsPerBank = ochs / dim + (ochs % dim != 0);
+      const int batchesPerBank = batches / dim + (batches % dim != 0);
-    const int bRows = transWeight0132
-                          ? inChannelsPerBank * kcols * krows * ochs
-                          : outChannelsPerBank * kcols * krows * kchs;
+      const int aRows = transInput3120
+                            ? (batchesPerBank * ichs * (irows >> downsample) *
+                               (icols >> downsample))
+                            : (inChannelsPerBank * batches *
+                               (irows >> downsample) * (icols >> downsample));
-    const int cRows = outChannelsPerBank * batches * orows * ocols;
+      const int bRows = transWeight0132
+                            ? inChannelsPerBank * kcols * krows * ochs
+                            : outChannelsPerBank * kcols * krows * kchs;
-    return acc ? cRows : aRows + bRows;
-  }
+      const int cRows = outChannelsPerBank * batches * orows * ocols;
-public:
-  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
-  explicit GemminiTileConvLowering(LLVMTypeConverter &typeConverter,
-                                   int64_t dim, int64_t addrLen,
-                                   int64_t accRows, int64_t bankRows,
-                                   size_t sizeOfElemT, size_t sizeOfAccT)
-      : ConvertOpToLLVMPattern(typeConverter), dim(dim), addrLen(addrLen),
-        accRows(accRows), bankRows(bankRows), sizeOfElemT(sizeOfElemT),
-        sizeOfAccT(sizeOfAccT) {}
-  LogicalResult
-  matchAndRewrite(TileConvOp tileConvOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Value input = tileConvOp.getInput();
-    Value output = tileConvOp.getOutput();
-    Value weights = tileConvOp.getWeights();
-    Value bias = tileConvOp.getBias();
-    MemRefType inputType = dyn_cast(input.getType());
-    MemRefType biasType = dyn_cast(bias.getType());
-    ArrayRef inputShape = inputType.getShape();
-    ArrayRef biasShape = biasType.getShape();
-
-    Value outRowDimValue = tileConvOp.getOutRowDim();
-    int outRowDim = getNumberFromValue(outRowDimValue);
-    Value outColDimValue = tileConvOp.getOutColDim();
-    int outColDim = getNumberFromValue(outColDimValue);
-    Value kernelDimValue = tileConvOp.getKernelDim();
-    int kernelDim = getNumberFromValue(kernelDimValue);
-    int batchSize = inputShape[0];
-    int inRowDim = inputShape[1];
-    int inColDim = inputShape[2];
-    int inChannels = inputShape[3];
-    int outChannels = biasShape[0];
-    int stride = tileConvOp.getStride();
-    int inputDilation = tileConvOp.getInputDilation();
-    int kernelDilation = tileConvOp.getKernelDilation();
-    int padding = tileConvOp.getPadding();
-    int act = tileConvOp.getAct();
-    float scale = tileConvOp.getScale().convertToFloat();
-    int poolSize = tileConvOp.getPoolSize();
-    int poolStride = tileConvOp.getPoolStride();
-    int poolPadding = tileConvOp.getPoolPadding();
-    bool wrot180 = tileConvOp.getWrot180();
-    bool transOutput1203 = tileConvOp.getTransOutput1203();
-    bool transInput3120 = tileConvOp.getTransInput3120();
-    bool transWeight1203 = tileConvOp.getTransWeight1203();
-    bool transWeight0132 = tileConvOp.getTransWeight0132();
-    Location loc = tileConvOp.getLoc();
-    IntegerType i64Type = rewriter.getI64Type();
-    Value inputExtractOp =
-        rewriter.create(loc, input);
-    Value inputIndexCastOp =
-        rewriter.create(loc, i64Type, inputExtractOp);
-    Value outputExtractOp =
-        rewriter.create(loc, output);
-    Value outputIndexCastOp =
-        rewriter.create(loc, i64Type, outputExtractOp);
-    Value biasExtractOp =
-        rewriter.create(loc, bias);
-    Value biasIndexCastOp =
-        rewriter.create(loc, i64Type, biasExtractOp);
-    Value weightsExtractOp =
-        rewriter.create(loc, weights);
-    Value weightsIndexCastOp =
-        rewriter.create(loc, i64Type, weightsExtractOp);
-    const bool noPool = poolSize == 0;
-    if (noPool) {
-      poolSize = 1;
-      poolStride = 1;
-      poolPadding = 0;
+      return acc ? cRows : aRows + bRows;
     }
-    const int poolOutRowDim =
-        (outRowDim + 2 * poolPadding - poolSize) / poolStride + 1;
-    const int poolOutColDim =
-        (outColDim + 2 * poolPadding - poolSize) / poolStride + 1;
-    const bool downsample = stride == 2 && kernelDim == 1 && padding == 0 &&
-                            noPool && inRowDim % 2 == 0 && inColDim % 2 == 0;
-    int args[] = {batchSize, poolOutRowDim, poolOutColDim, outChannels,
-                  kernelDim, kernelDim, inChannels};
-    const int maxArgs[] = {batchSize, poolOutRowDim, poolOutColDim, outChannels,
-                           kernelDim, kernelDim, inChannels};
-    const int orowsIdx = 1;
-    const int ocolsIdx = 2;
-    const int outChannelsIdx = 3;
-    const int inChannelsIdx = 6;
-    const int maxSpadRows = (BANK_NUM * bankRows / 2);
-    const int maxAccRows = (accRows / 2);
-    int spadRows = tiledConvTotalSpadRows(
-        false, stride, inputDilation, kernelDilation, downsample,
-        transWeight0132, transInput3120, args[0], args[1], args[2], args[3],
-        args[4], args[5], args[6], poolSize, poolStride);
-    int accRows = tiledConvTotalSpadRows(
-        true, stride, inputDilation, kernelDilation, downsample,
-        transWeight0132, transInput3120, args[0], args[1], args[2], args[3],
-        args[4], args[5], args[6], poolSize, poolStride);
-    while (spadRows > maxSpadRows || accRows > maxAccRows) {
-      int maxVal = -1;
-      int maxIdx = -1;
-      for (size_t i = 0; i < sizeof(args) / sizeof(args[0]); i++) {
-        if (!(i == ocolsIdx && args[i] <= dim && args[orowsIdx] > 1) &&
-            args[i] > maxVal) {
-          maxVal = args[i];
-          maxIdx = i;
-        }
-      }
-      if (maxIdx == outChannelsIdx || maxIdx == inChannelsIdx) {
-        if (args[maxIdx] % dim != 0) {
-          args[maxIdx] = (args[maxIdx] / dim) * dim;
-        } else {
-          args[maxIdx] -= dim;
-        }
-        args[maxIdx] = args[maxIdx] == 0 ? 1 : args[maxIdx];
-      } else {
-        args[maxIdx]--;
+  public:
+    using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+    explicit GemminiTileConvLowering(LLVMTypeConverter & typeConverter,
+                                     int64_t dim, int64_t addrLen,
+                                     int64_t accRows, int64_t bankRows,
+                                     size_t sizeOfElemT, size_t sizeOfAccT)
+        : ConvertOpToLLVMPattern(typeConverter), dim(dim), addrLen(addrLen),
+          accRows(accRows), bankRows(bankRows), sizeOfElemT(sizeOfElemT),
+          sizeOfAccT(sizeOfAccT) {}
+    LogicalResult matchAndRewrite(TileConvOp tileConvOp, OpAdaptor adaptor,
+                                  ConversionPatternRewriter & rewriter)
+        const override {
+      Value input = tileConvOp.getInput();
+      Value output = tileConvOp.getOutput();
+      Value weights = tileConvOp.getWeights();
+      Value bias = tileConvOp.getBias();
+      MemRefType inputType = dyn_cast(input.getType());
+      MemRefType biasType = dyn_cast(bias.getType());
+      ArrayRef inputShape = inputType.getShape();
+      ArrayRef biasShape = biasType.getShape();
+
+      Value outRowDimValue = tileConvOp.getOutRowDim();
+      int outRowDim = getNumberFromValue(outRowDimValue);
+      Value outColDimValue = tileConvOp.getOutColDim();
+      int outColDim = getNumberFromValue(outColDimValue);
+      Value kernelDimValue = tileConvOp.getKernelDim();
+      int kernelDim = getNumberFromValue(kernelDimValue);
+      int batchSize = inputShape[0];
+      int inRowDim = inputShape[1];
+      int inColDim = inputShape[2];
+      int inChannels = inputShape[3];
+      int outChannels = biasShape[0];
+      int stride = tileConvOp.getStride();
+      int inputDilation = tileConvOp.getInputDilation();
+      int kernelDilation = tileConvOp.getKernelDilation();
+      int padding = tileConvOp.getPadding();
+      int act = tileConvOp.getAct();
+      float scale = tileConvOp.getScale().convertToFloat();
+      int poolSize = tileConvOp.getPoolSize();
+      int poolStride = tileConvOp.getPoolStride();
+      int poolPadding = tileConvOp.getPoolPadding();
+      bool wrot180 = tileConvOp.getWrot180();
+      bool transOutput1203 = tileConvOp.getTransOutput1203();
+      bool transInput3120 = tileConvOp.getTransInput3120();
+      bool transWeight1203 = tileConvOp.getTransWeight1203();
+      bool transWeight0132 = tileConvOp.getTransWeight0132();
+      Location loc = tileConvOp.getLoc();
+      IntegerType i64Type = rewriter.getI64Type();
+      Value inputExtractOp =
+          rewriter.create(loc, input);
+      Value inputIndexCastOp =
+          rewriter.create(loc, i64Type, inputExtractOp);
+      Value outputExtractOp =
+          rewriter.create(loc, output);
+      Value outputIndexCastOp =
+          rewriter.create(loc, i64Type, outputExtractOp);
+      Value biasExtractOp =
+          rewriter.create(loc, bias);
+      Value biasIndexCastOp =
+          rewriter.create(loc, i64Type, biasExtractOp);
+      Value weightsExtractOp =
+          rewriter.create(loc, weights);
+      Value weightsIndexCastOp =
+          rewriter.create(loc, i64Type, weightsExtractOp);
+      const bool noPool = poolSize == 0;
+      if (noPool) {
+        poolSize = 1;
+        poolStride = 1;
+        poolPadding = 0;
       }
-      spadRows = tiledConvTotalSpadRows(
+      const int poolOutRowDim =
+          (outRowDim + 2 * poolPadding - poolSize) / poolStride + 1;
+      const int poolOutColDim =
+          (outColDim + 2 * poolPadding - poolSize) / poolStride + 1;
+      const bool downsample = stride == 2 && kernelDim == 1 && padding == 0 &&
+                              noPool && inRowDim % 2 == 0 && inColDim % 2 == 0;
+      int args[] = {batchSize, poolOutRowDim, poolOutColDim, outChannels,
+                    kernelDim, kernelDim, inChannels};
+      const int maxArgs[] = {batchSize, poolOutRowDim, poolOutColDim,
+                             outChannels, kernelDim, kernelDim,
+                             inChannels};
+      const int orowsIdx = 1;
+      const int ocolsIdx = 2;
+      const int outChannelsIdx = 3;
+      const int inChannelsIdx = 6;
+      const int maxSpadRows = (BANK_NUM * bankRows / 2);
+      const int maxAccRows = (accRows / 2);
+      int spadRows = tiledConvTotalSpadRows(
           false, stride, inputDilation, kernelDilation, downsample,
           transWeight0132, transInput3120, args[0], args[1], args[2], args[3],
           args[4], args[5], args[6], poolSize, poolStride);
-      accRows = tiledConvTotalSpadRows(
+      int accRows = tiledConvTotalSpadRows(
           true, stride, inputDilation, kernelDilation, downsample,
           transWeight0132, transInput3120, args[0], args[1], args[2], args[3],
           args[4], args[5], args[6], poolSize, poolStride);
-    }
-    bool notIncreased = false;
-    while (!notIncreased) {
-      notIncreased = true;
-
-      int argsCandidate[] = {args[0], args[1], args[2], args[3],
-                             args[4], args[5], args[6]};
-      argsCandidate[ocolsIdx]++;
-
-      if (argsCandidate[ocolsIdx] > maxArgs[ocolsIdx])
-        continue;
-
-      spadRows = tiledConvTotalSpadRows(
-          false, stride, inputDilation, kernelDilation, downsample,
-          transWeight0132, transInput3120, argsCandidate[0], argsCandidate[1],
-          argsCandidate[2], argsCandidate[3], argsCandidate[4],
-          argsCandidate[5], argsCandidate[6], poolSize, poolStride);
-      accRows = tiledConvTotalSpadRows(
-          true, stride, inputDilation, kernelDilation, downsample,
-          transWeight0132, transInput3120, argsCandidate[0], argsCandidate[1],
-          argsCandidate[2], argsCandidate[3], argsCandidate[4],
-          argsCandidate[5], argsCandidate[6], poolSize, poolStride);
+      while (spadRows > maxSpadRows || accRows > maxAccRows) {
+        int maxVal = -1;
+        int maxIdx = -1;
+        for (size_t i = 0; i < sizeof(args) / sizeof(args[0]); i++) {
+          if (!(i == ocolsIdx && args[i] <= dim && args[orowsIdx] > 1) &&
+              args[i] > maxVal) {
+            maxVal = args[i];
+            maxIdx = i;
+          }
+        }
-      if (spadRows <= maxSpadRows && accRows <= maxAccRows) {
-        args[ocolsIdx] = argsCandidate[ocolsIdx];
-        notIncreased = false;
+        if (maxIdx == outChannelsIdx || maxIdx == inChannelsIdx) {
+          if (args[maxIdx] % dim != 0) {
+            args[maxIdx] = (args[maxIdx] / dim) * dim;
+          } else {
+            args[maxIdx] -= dim;
+          }
+          args[maxIdx] = args[maxIdx] == 0 ? 1 : args[maxIdx];
+        } else {
+          args[maxIdx]--;
+        }
+        spadRows = tiledConvTotalSpadRows(
+            false, stride, inputDilation, kernelDilation, downsample,
+            transWeight0132, transInput3120, args[0], args[1], args[2], args[3],
+            args[4], args[5], args[6], poolSize, poolStride);
+        accRows = tiledConvTotalSpadRows(
+            true, stride, inputDilation, kernelDilation, downsample,
+            transWeight0132, transInput3120, args[0], args[1], args[2], args[3],
+            args[4], args[5], args[6], poolSize, poolStride);
       }
-    }
+      bool notIncreased = false;
+      while (!notIncreased) {
+        notIncreased = true;
-    bool nothingIncreased = false;
-    while (!nothingIncreased) {
-      nothingIncreased = true;
-      for (size_t i = 0; i < sizeof(args) / sizeof(args[0]); i++) {
         int argsCandidate[] = {args[0], args[1], args[2], args[3],
                                args[4], args[5], args[6]};
-        argsCandidate[i]++;
+        argsCandidate[ocolsIdx]++;
-        if (argsCandidate[i] > maxArgs[i])
+        if (argsCandidate[ocolsIdx] > maxArgs[ocolsIdx])
           continue;
+        spadRows = tiledConvTotalSpadRows(
            false, stride, inputDilation, kernelDilation, downsample,
            transWeight0132, transInput3120, argsCandidate[0], argsCandidate[1],
@@ -2044,82 +2058,113 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern {
            argsCandidate[5], argsCandidate[6], poolSize, poolStride);
         if (spadRows <= maxSpadRows && accRows <= maxAccRows) {
-          args[i] = argsCandidate[i];
-          nothingIncreased = false;
+          args[ocolsIdx] = argsCandidate[ocolsIdx];
+          notIncreased = false;
         }
       }
+
+      bool nothingIncreased = false;
+      while (!nothingIncreased) {
+        nothingIncreased = true;
+        for (size_t i = 0; i < sizeof(args) / sizeof(args[0]); i++) {
+          int argsCandidate[] = {args[0], args[1], args[2], args[3],
+                                 args[4], args[5], args[6]};
+          argsCandidate[i]++;
+
+          if (argsCandidate[i] > maxArgs[i])
+            continue;
+          spadRows = tiledConvTotalSpadRows(
+              false, stride, inputDilation, kernelDilation, downsample,
+              transWeight0132, transInput3120, argsCandidate[0],
+              argsCandidate[1], argsCandidate[2], argsCandidate[3],
+              argsCandidate[4], argsCandidate[5], argsCandidate[6], poolSize,
+              poolStride);
+          accRows = tiledConvTotalSpadRows(
+              true, stride, inputDilation, kernelDilation, downsample,
+              transWeight0132, transInput3120, argsCandidate[0],
+              argsCandidate[1], argsCandidate[2], argsCandidate[3],
+              argsCandidate[4], argsCandidate[5], argsCandidate[6], poolSize,
+              poolStride);
+
+          if (spadRows <= maxSpadRows && accRows <= maxAccRows) {
+            args[i] = argsCandidate[i];
+            nothingIncreased = false;
+          }
+        }
+      }
+      const int batches = args[0];
+      const int orows = args[1];
+      const int ocols = args[2];
+      const int ochs = args[3];
+      const int krows = args[4];
+      const int kcols = args[5];
+      const int kchs = args[6];
+
+      const int inStride = inChannels;
+      const int outStride = outChannels;
+      const int weightStride = outChannels;
+      tiledConv(batchSize, inRowDim, inColDim, inChannels, outChannels,
+                outRowDim, outColDim, stride, inputDilation, kernelDilation,
+                padding, kernelDim, inStride, weightStride, outStride, wrot180,
+                transOutput1203, transInput3120, transWeight1203,
+                transWeight0132, batches, orows, ocols, ochs, krows, kcols,
+                kchs, inputIndexCastOp, weightsIndexCastOp, biasIndexCastOp,
+                outputIndexCastOp, act, scale, poolSize,
+                noPool ? 0 : poolStride, poolPadding, tileConvOp, rewriter);
+      return success();
     }
-    const int batches = args[0];
-    const int orows = args[1];
-    const int ocols = args[2];
-    const int ochs = args[3];
-    const int krows = args[4];
-    const int kcols = args[5];
-    const int kchs = args[6];
-
-    const int inStride = inChannels;
-    const int outStride = outChannels;
-    const int weightStride = outChannels;
-    tiledConv(batchSize, inRowDim, inColDim, inChannels, outChannels, outRowDim,
-              outColDim, stride, inputDilation, kernelDilation, padding,
-              kernelDim, inStride, weightStride, outStride, wrot180,
-              transOutput1203, transInput3120, transWeight1203, transWeight0132,
-              batches, orows, ocols, ochs, krows, kcols, kchs, inputIndexCastOp,
-              weightsIndexCastOp, biasIndexCastOp, outputIndexCastOp, act,
-              scale, poolSize, noPool ? 0 : poolStride, poolPadding, tileConvOp,
-              rewriter);
-    return success();
-  }
-private:
-  int64_t dim;
-  int64_t addrLen;
-  int64_t accRows;
-  int64_t bankRows;
-  size_t sizeOfElemT;
-  size_t sizeOfAccT;
-};
+  private:
+    int64_t dim;
+    int64_t addrLen;
+    int64_t accRows;
+    int64_t bankRows;
+    size_t sizeOfElemT;
+    size_t sizeOfAccT;
+  };
-void mlir::populateGemminiLegalizeForLLVMExportPatterns(
-    LLVMTypeConverter &converter, RewritePatternSet &patterns, int64_t dim,
-    int64_t addrLen, int64_t accRows, int64_t bankRows, size_t sizeOfElemT,
-    size_t sizeOfAccT) {
-  patterns
-      .add, ForwardOperands,
-           ForwardOperands>(converter, &converter.getContext());
-  patterns.add(converter);
-  patterns.add(converter);
-  patterns.add(converter, dim);
-  patterns.add(converter, addrLen);
-  patterns.add(converter, addrLen);
-  patterns.add(converter, addrLen);
-  patterns.add(converter, addrLen);
-  patterns.add(converter);
-  patterns.add(converter);
-  patterns.add(converter, dim, addrLen);
-  patterns.add(converter, addrLen);
-  patterns.add(converter, addrLen);
-  patterns.add(converter, addrLen);
-  patterns.add(converter, dim, addrLen, accRows,
+  void mlir::populateGemminiLegalizeForLLVMExportPatterns(
+      LLVMTypeConverter &converter, RewritePatternSet &patterns, int64_t dim,
+      int64_t addrLen, int64_t accRows, int64_t bankRows, size_t sizeOfElemT,
+      size_t sizeOfAccT) {
+    patterns.add,
+                 ForwardOperands,
+                 ForwardOperands>(converter,
+                                  &converter.getContext());
+    patterns.add(converter);
+    patterns.add(converter);
+    patterns.add(converter, dim);
+    patterns.add(converter, addrLen);
+    patterns.add(converter, addrLen);
+    patterns.add(converter, addrLen);
+    patterns.add(converter, addrLen);
+    patterns.add(converter);
+    patterns.add(converter);
+    patterns.add(converter, dim, addrLen);
+    patterns.add(converter, addrLen);
+    patterns.add(converter, addrLen);
+    patterns.add(converter, addrLen);
+    patterns.add(converter, dim, addrLen, accRows,
+                 bankRows, sizeOfElemT, sizeOfAccT);
+    patterns.add(converter, dim, addrLen, accRows,
                  bankRows, sizeOfElemT, sizeOfAccT);
-  patterns.add(converter, dim, addrLen, accRows,
-               bankRows, sizeOfElemT, sizeOfAccT);
-}
+  }
-void mlir::configureGemminiLegalizeForExportTarget(
-    LLVMConversionTarget &target) {
-  target.addLegalOp<
-      Flush_IntrOp, ConfigSt_IntrOp, ConifgLd_IntrOp, ConfigEX_IntrOp,
-      Mvin_IntrOp, Mvin2_IntrOp, Mvin3_IntrOp, Mvout_IntrOp, Preload_IntrOp,
-      ComputePreloaded_IntrOp, ComputeAccumulated_IntrOp,
-      LoopWsConfigBounds_IntrOp, LoopWsConfigAddrsAB_IntrOp,
-      LoopWsConfigAddrsDC_IntrOp, LoopWsConfigStridesAB_IntrOp,
-      LoopWsConfigStridesDC_IntrOp, LoopWs_IntrOp, LoopConvWsConfig1_IntrOp,
-      LoopConvWsConfig2_IntrOp, LoopConvWsConfig3_IntrOp,
-      LoopConvWsConfig4_IntrOp, LoopConvWsConfig5_IntrOp,
-      LoopConvWsConfig6_IntrOp, LoopConvWs_IntrOp, ConfigNorm_IntrOp>();
-  target.addIllegalOp();
-}
+  void
+  mlir::configureGemminiLegalizeForExportTarget(LLVMConversionTarget &target) {
+    target.addLegalOp<
+        Flush_IntrOp, ConfigSt_IntrOp, ConifgLd_IntrOp, ConfigEX_IntrOp,
+        Mvin_IntrOp, Mvin2_IntrOp, Mvin3_IntrOp, Mvout_IntrOp, Preload_IntrOp,
+        ComputePreloaded_IntrOp, ComputeAccumulated_IntrOp,
+        LoopWsConfigBounds_IntrOp, LoopWsConfigAddrsAB_IntrOp,
+        LoopWsConfigAddrsDC_IntrOp, LoopWsConfigStridesAB_IntrOp,
+        LoopWsConfigStridesDC_IntrOp, LoopWs_IntrOp, LoopConvWsConfig1_IntrOp,
+        LoopConvWsConfig2_IntrOp, LoopConvWsConfig3_IntrOp,
+        LoopConvWsConfig4_IntrOp, LoopConvWsConfig5_IntrOp,
+        LoopConvWsConfig6_IntrOp, LoopConvWs_IntrOp, ConfigNorm_IntrOp>();
+    target.addIllegalOp();
+  }
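
The GemminiTileConvLowering pattern in the hunk above picks conv tile sizes with a two-phase heuristic: first shrink the largest of {batches, orows, ocols, ochs, krows, kcols, kchs} until the scratchpad and accumulator row estimates fit the hardware budgets, then greedily grow dimensions back toward their maxima while the estimates still fit. The following self-contained C++ sketch (not part of the patch) illustrates that search in isolation; estimateRows, kMaxSpadRows, and kMaxAccRows are made-up stand-ins for tiledConvTotalSpadRows and the real Gemmini bank/accumulator limits.

// tile_search_sketch.cpp -- illustrative only; simplified cost model.
#include <algorithm>
#include <array>
#include <cstddef>
#include <iostream>

constexpr int kMaxSpadRows = 4096; // assumed scratchpad budget (rows)
constexpr int kMaxAccRows = 1024;  // assumed accumulator budget (rows)

// Hypothetical stand-in for tiledConvTotalSpadRows: grows with each tile dim.
static int estimateRows(const std::array<int, 7> &dims) {
  int sum = 1;
  for (int d : dims)
    sum += d; // the real model multiplies banked row counts per operand
  return sum * 8;
}

int main() {
  // {batches, orows, ocols, ochs, krows, kcols, kchs}, as in the lowering.
  std::array<int, 7> args = {4, 56, 56, 64, 3, 3, 64};
  const std::array<int, 7> maxArgs = args;

  // Phase 1: shrink the largest dimension until the estimate fits both budgets.
  while (estimateRows(args) > kMaxSpadRows || estimateRows(args) > kMaxAccRows) {
    std::size_t maxIdx = 0;
    for (std::size_t i = 1; i < args.size(); ++i)
      if (args[i] > args[maxIdx])
        maxIdx = i;
    args[maxIdx] = std::max(1, args[maxIdx] - 1);
  }

  // Phase 2: greedily grow any dimension back while the candidate still fits.
  bool grew = true;
  while (grew) {
    grew = false;
    for (std::size_t i = 0; i < args.size(); ++i) {
      std::array<int, 7> candidate = args;
      if (++candidate[i] > maxArgs[i])
        continue;
      if (estimateRows(candidate) <= kMaxSpadRows &&
          estimateRows(candidate) <= kMaxAccRows) {
        args = candidate;
        grew = true;
      }
    }
  }

  for (int d : args)
    std::cout << d << ' ';
  std::cout << '\n';
  return 0;
}

The real lowering differs in two ways worth noting: it uses separate spad/acc estimates from tiledConvTotalSpadRows, and when the dimension being shrunk is an in/out channel count it steps down in multiples of the systolic-array dimension rather than by one.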