Skip to content

Commit

Permalink
Add support for generic convolutions and flexible layouts in pad_to_intrinsics
Browse files Browse the repository at this point in the history
Signed-off-by: jerryyin <[email protected]>
  • Loading branch information
jerryyin committed Feb 24, 2025
1 parent 308d176 commit 795c322
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 36 deletions.
88 changes: 52 additions & 36 deletions compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpDefinition.h"

namespace mlir::iree_compiler::Preprocessing {

Expand Down Expand Up @@ -170,13 +173,6 @@ getIntrinsics(linalg::LinalgOp linalgOp,
static void
padConvOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
ArrayRef<IREE::HAL::ExecutableTargetAttr> executableTargets) {
if (!isa<linalg::ConvolutionOpInterface>(*linalgOp)) {
return;
}
// TODO: Handle other variants.
if (!isa<linalg::Conv2DNhwcHwcfOp>(linalgOp))
return;

// Early exit if cannot find intrinsics or if multiple executable targets.
SmallVector<GPUMatmulShapeType> intrinsics =
getIntrinsics(linalgOp, executableTargets);
Expand Down Expand Up @@ -209,11 +205,10 @@ padConvOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp,

int64_t mDim = convolutionDims->outputImage.back();
int64_t nDim = convolutionDims->outputChannel.front();
// TODO: Support NCHW convolutions. This is just a matmul_transpose_a,
// however the distribution patterns currently do not support that variant.
if (mDim > nDim) {
return;
}
// In NCHW convolutions, mDim > nDim and the position of the input with filter
// tensors will be swapped in igemm passes later.
bool isIGemmOperandSwapped = mDim > nDim;

int64_t kDim = convolutionDims->inputChannel.front();
int64_t mSize = bounds[mDim];
int64_t nSize = bounds[nDim];
Expand All @@ -226,8 +221,6 @@ padConvOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
cast<ShapedType>(linalgOp.getDpsInputOperand(1)->get().getType())
.getElementType();

// TODO: Generalize to other dimensions.
// Try to search for pad value and check only filter dimension is blocked.
SmallVector<std::array<int64_t, 3>> mnkPaddingCandidates;
for (const GPUMatmulShapeType &intrinsic : intrinsics) {

Expand All @@ -241,12 +234,17 @@ padConvOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
return llvm::divideCeil(value, padTo) * padTo - value;
};

auto mIntrinsicSize =
isIGemmOperandSwapped ? intrinsic.nSizes[0] : intrinsic.mSizes[0];
auto nIntrinsicSize =
isIGemmOperandSwapped ? intrinsic.mSizes[0] : intrinsic.nSizes[0];

if (mSize % intrinsic.mSizes[0] != 0) {
mPadding = getPadding(mSize, intrinsic.mSizes[0]);
mPadding = getPadding(mSize, mIntrinsicSize);
}

if (nSize % intrinsic.nSizes[0] != 0) {
nPadding = getPadding(nSize, intrinsic.nSizes[0]);
nPadding = getPadding(nSize, nIntrinsicSize);
}

if (kSize % intrinsic.kSizes[0] != 0) {
Expand All @@ -268,32 +266,51 @@ padConvOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp,

Value newInput = linalgOp.getDpsInputOperand(0)->get();
Value newFilter = linalgOp.getDpsInputOperand(1)->get();
Value newOuts = linalgOp.getDpsInitOperand(0)->get();
Value newOutput = linalgOp.getDpsInitOperand(0)->get();

auto indexingMaps = linalgOp.getIndexingMapsArray();
auto inputMap = indexingMaps[0];
auto filterMap = indexingMaps[1];
auto outputMap = indexingMaps[2];

Location loc = linalgOp.getLoc();
OpFoldResult mPadding = rewriter.getIndexAttr(mnkPadding[0]);
OpFoldResult nPadding = rewriter.getIndexAttr(mnkPadding[1]);
OpFoldResult kPadding = rewriter.getIndexAttr(mnkPadding[2]);
OpFoldResult zero = rewriter.getIndexAttr(0);
if (!isConstantIntValue(mPadding, 0) || !isConstantIntValue(kPadding, 0)) {
// For NHWC, the m-padding is for W and k-padding is for C
newInput = getPaddedValue(rewriter, loc, newInput,
{zero, zero, mPadding, kPadding});
}
if (!isConstantIntValue(nPadding, 0) || !isConstantIntValue(kPadding, 0)) {
// For HWCF, the n-padding is for F and k-padding is for C
newFilter = getPaddedValue(rewriter, loc, newFilter,
{zero, zero, kPadding, nPadding});
}
if (!isConstantIntValue(mPadding, 0) || !isConstantIntValue(nPadding, 0)) {
// For output, the m-padding is for W and k-padding is for F
newOuts = getPaddedValue(rewriter, loc, newOuts,
{zero, zero, mPadding, nPadding});
}

auto createExprToIdMap = [](AffineMap map) {
llvm::SmallDenseMap<AffineExpr, unsigned> exprToIdMap;
for (unsigned i = 0; i < map.getNumResults(); ++i) {
exprToIdMap[map.getResult(i)] = i;
}
return exprToIdMap;
};

auto applyPadding = [&](AffineMap map, OpFoldResult padding1,
OpFoldResult padding2, unsigned dim1, unsigned dim2,
Value &paddingTarget) {
if (!isConstantIntValue(padding1, 0) || !isConstantIntValue(padding2, 0)) {
llvm::SmallDenseMap<AffineExpr, unsigned> exprToIdMap =
createExprToIdMap(map);
auto id1 = exprToIdMap[getAffineDimExpr(dim1, map.getContext())];
auto id2 = exprToIdMap[getAffineDimExpr(dim2, map.getContext())];

llvm::SmallVector<OpFoldResult> paddingValues(4, zero);
paddingValues[id1] = padding1;
paddingValues[id2] = padding2;
paddingTarget =
getPaddedValue(rewriter, loc, paddingTarget, paddingValues);
}
};

applyPadding(inputMap, mPadding, kPadding, mDim, kDim, newInput);
applyPadding(filterMap, nPadding, kPadding, nDim, kDim, newFilter);
applyPadding(outputMap, mPadding, nPadding, mDim, nDim, newOutput);

linalg::LinalgOp paddedConv2dOp =
mlir::clone(rewriter, linalgOp, {newOuts.getType()},
ArrayRef<Value>{newInput, newFilter, newOuts});
mlir::clone(rewriter, linalgOp, {newOutput.getType()},
ArrayRef<Value>{newInput, newFilter, newOutput});
// Extract slice.
IntegerAttr one = rewriter.getI64IntegerAttr(1);
SmallVector<OpFoldResult> offsets(4, zero);
Expand Down Expand Up @@ -562,8 +579,7 @@ void PadToIntrinsicsPass::runOnOperation() {
SmallVector<linalg::LinalgOp> targetContractOps;
for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) {
funcOp.walk([&](linalg::LinalgOp linalgOp) {
if (isa<linalg::Conv2DNhwcHwcfOp>(linalgOp.getOperation()) &&
padConvOps) {
if (linalg::isaConvolutionOpInterface(linalgOp) && padConvOps) {
targetConvOps.push_back(linalgOp);
} else if (isa<linalg::BatchMatmulOp, linalg::MatmulOp,
linalg::MatmulTransposeBOp>(linalgOp.getOperation()) &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,77 @@ func.func @main0(%arg0: tensor<2x130x130x4xf16>, %arg1: tensor<3x3x4x320xf16>, %

// -----

// CHECK-LABEL: func.func @conv_nchw_fchw(
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x4x130x130xf16>,
// CHECK-SAME: %[[ARG1:.+]]: tensor<320x4x3x3xf16>,
// CHECK-SAME: %[[ARG2:.+]]: tensor<2x320x128x128xf32>)
// NCHW-layout 2-D convolution whose input-channel dim is 4 (not a multiple of
// the MFMA intrinsic's k size); the pass under test is expected to pad it
// (the CHECK lines below verify 4 -> 16 padding on both input and filter).
func.func @conv_nchw_fchw(%arg0: tensor<2x4x130x130xf16>, %arg1: tensor<320x4x3x3xf16>, %arg2: tensor<2x320x128x128xf32>)
-> tensor<2x320x128x128xf32> {
// Unit-stride, unit-dilation conv_2d_nchw_fchw: 2x4x130x130 input with a
// 320x4x3x3 filter produces the 2x320x128x128 accumulator below.
%conv0 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
ins(%arg0, %arg1 : tensor<2x4x130x130xf16>, tensor<320x4x3x3xf16>)
outs(%arg2 : tensor<2x320x128x128xf32>) -> tensor<2x320x128x128xf32>
return %conv0 : tensor<2x320x128x128xf32>
}

// CHECK: %[[CST0:.+]] = arith.constant 0.0{{.*}} : f16
// CHECK: %[[PAD0:.+]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 12, 0, 0]
// CHECK: tensor.yield %[[CST0]] : f16
// CHECK-NEXT: tensor<2x4x130x130xf16> to tensor<2x16x130x130xf16>
// CHECK: %[[PAD1:.+]] = tensor.pad %[[ARG1]] low[0, 0, 0, 0] high[0, 12, 0, 0]
// CHECK: tensor<320x4x3x3xf16> to tensor<320x16x3x3xf16>
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
// CHECK-SAME: ins(%[[PAD0]], %[[PAD1]] : tensor<2x16x130x130xf16>, tensor<320x16x3x3xf16>)
// CHECK-SAME: outs(%[[ARG2]] : tensor<2x320x128x128xf32>)
// CHECK: return %[[CONV]] : tensor<2x320x128x128xf32>

// CONVOLUTION: tensor.pad {{.*}} low[0, 0, 0, 0] high[0, 12, 0, 0]
// CONVOLUTION: tensor.pad {{.*}} low[0, 0, 0, 0] high[0, 12, 0, 0]

// CONTRACT-NOT: tensor.pad {{.*}} low[0, 0, 0, 0] high[0, 12, 0, 0]
// CONTRACT-NOT: tensor.pad {{.*}} low[0, 0, 0, 0] high[0, 12, 0, 0]

// -----

#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>

// CHECK-LABEL: func.func @conv_generic_nhwc_fhwc(
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x130x130x4xf16>,
// CHECK-SAME: %[[ARG1:.+]]: tensor<320x3x3x4xf16>,
// CHECK-SAME: %[[ARG2:.+]]: tensor<2x128x128x320xf32>)
// NHWC/FHWC convolution expressed as a linalg.generic (via #map/#map1/#map2
// above) rather than a named conv op; exercises the interface-based
// convolution matching. Input channels = 4, so the pass is expected to pad
// the reduction (channel) dim — the CHECK lines below verify 4 -> 16.
func.func @conv_generic_nhwc_fhwc(%arg0: tensor<2x130x130x4xf16>, %arg1: tensor<320x3x3x4xf16>, %arg2: tensor<2x128x128x320xf32>) -> tensor<2x128x128x320xf32> {
%1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<2x130x130x4xf16>, tensor<320x3x3x4xf16>) outs(%arg2 : tensor<2x128x128x320xf32>) {
^bb0(%in: f16, %in_0: f16, %out: f32):
// Body is a widening multiply-accumulate: f16 operands extended to f32,
// multiplied, and added into the f32 accumulator.
%2 = arith.extf %in : f16 to f32
%3 = arith.extf %in_0 : f16 to f32
%4 = arith.mulf %2, %3 : f32
%5 = arith.addf %out, %4 : f32
linalg.yield %5 : f32
} -> tensor<2x128x128x320xf32>
return %1 : tensor<2x128x128x320xf32>
}

// CHECK: %[[CST0:.+]] = arith.constant 0.0{{.*}} : f16
// CHECK: %[[PAD0:.+]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 0, 0, 12]
// CHECK: tensor.yield %[[CST0]] : f16
// CHECK-NEXT: tensor<2x130x130x4xf16> to tensor<2x130x130x16xf16>
// CHECK: %[[PAD1:.+]] = tensor.pad %[[ARG1]] low[0, 0, 0, 0] high[0, 0, 0, 12]
// CHECK: tensor<320x3x3x4xf16> to tensor<320x3x3x16xf16>
// CHECK: %[[CONV:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#map, #map1, #map2],
// CHECK-SAME: ins(%[[PAD0]], %[[PAD1]] : tensor<2x130x130x16xf16>, tensor<320x3x3x16xf16>)
// CHECK-SAME: outs(%[[ARG2]] : tensor<2x128x128x320xf32>)
// CHECK: return %[[CONV]] : tensor<2x128x128x320xf32>

// CONVOLUTION: tensor.pad {{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
// CONVOLUTION: tensor.pad {{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]

// CONTRACT-NOT: tensor.pad {{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
// CONTRACT-NOT: tensor.pad {{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]

// -----

// CHECK-LABEL: func.func @main1(
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x130x130x320xf16>,
// CHECK-SAME: %[[ARG1:.+]]: tensor<3x3x320x4xf16>,
Expand Down

0 comments on commit 795c322

Please sign in to comment.