[codegen][gpu] Adding support to generic op and flexible layout to pad_to_intrinsics on convolution #20073

Merged 2 commits on Feb 25, 2025
@@ -21,7 +21,9 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/OpDefinition.h"

 namespace mlir::iree_compiler::Preprocessing {

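The new includes serve the map-driven padding added further down: `mlir/IR/AffineExpr.h` provides `getAffineDimExpr`, which the new `applyPadding` helper uses to locate a loop dimension among an indexing map's results. A minimal sketch of that lookup, as an illustrative standalone helper that is not part of the patch:

```cpp
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"

// Return the result position of loop dimension `dim` in `map`, e.g. for
// (d0, d1, d2, d3) -> (d0, d3, d1, d2) the dimension d3 sits at position 1.
unsigned resultPositionOf(mlir::AffineMap map, unsigned dim) {
  mlir::AffineExpr expr = mlir::getAffineDimExpr(dim, map.getContext());
  for (unsigned i = 0, e = map.getNumResults(); i < e; ++i)
    if (map.getResult(i) == expr)
      return i;
  return ~0u; // composite results such as d1 + d4 do not match a plain dim
}
```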
@@ -170,13 +172,6 @@ getIntrinsics(linalg::LinalgOp linalgOp,
 static void
 padConvOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
           ArrayRef<IREE::HAL::ExecutableTargetAttr> executableTargets) {
-  if (!isa<linalg::ConvolutionOpInterface>(*linalgOp)) {
-    return;
-  }
-  // TODO: Handle other variants.
-  if (!isa<linalg::Conv2DNhwcHwcfOp>(linalgOp))
-    return;
-
   // Early exit if cannot find intrinsics or if multiple executable targets.
   SmallVector<GPUMatmulShapeType> intrinsics =
       getIntrinsics(linalgOp, executableTargets);
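The removed guard restricted the pass to `linalg.conv_2d_nhwc_hwcf`; the convolution check now happens once in `runOnOperation` (last hunk of this file) against the convolution interface, so named convolutions in other layouts and equivalent `linalg.generic` ops qualify too. A sketch of that interface-based filtering, assuming the same MLIR Linalg API used elsewhere in the pass:

```cpp
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "llvm/ADT/SmallVector.h"

// Collect every linalg op implementing the convolution interface, the same
// predicate the pass now applies in runOnOperation.
static void collectConvCandidates(
    mlir::FunctionOpInterface funcOp,
    llvm::SmallVectorImpl<mlir::linalg::LinalgOp> &targetConvOps) {
  funcOp.walk([&](mlir::linalg::LinalgOp linalgOp) {
    // Accepts named convolutions in any layout (NHWC, NCHW, ...) as well as
    // linalg.generic ops whose indexing maps form a convolution.
    if (mlir::linalg::isaConvolutionOpInterface(linalgOp))
      targetConvOps.push_back(linalgOp);
  });
}
```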
@@ -209,11 +204,10 @@ padConvOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp,

   int64_t mDim = convolutionDims->outputImage.back();
   int64_t nDim = convolutionDims->outputChannel.front();
-  // TODO: Support NCHW convolutions. This is just a matmul_transpose_a,
-  // however the distribution patterns currently do not support that variant.
-  if (mDim > nDim) {
-    return;
-  }
+  // In NCHW convolutions, mDim > nDim, and the positions of the input and
+  // filter tensors will be swapped in the igemm passes later.
+  bool isIGemmOperandSwapped = mDim > nDim;

   int64_t kDim = convolutionDims->inputChannel.front();
   int64_t mSize = bounds[mDim];
   int64_t nSize = bounds[nDim];
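Why `mDim > nDim` flags the swapped case: `inferConvolutionDims` numbers loop dimensions from the output layout, so the relative order of the output-image and output-channel dims differs between NHWC and NCHW. A sketch annotated with those loop orders, assuming `linalgOp` and the early exits from the surrounding `padConvOp`:

```cpp
// Upstream MLIR helper already used above to classify the conv's loop dims.
mlir::FailureOr<mlir::linalg::ConvolutionDimensions> convolutionDims =
    mlir::linalg::inferConvolutionDims(linalgOp);
if (mlir::failed(convolutionDims))
  return;

// linalg.conv_2d_nhwc_hwcf loops: (n=d0, oh=d1, ow=d2, f=d3, kh, kw, c)
//   => mDim = d2 (OW) < nDim = d3 (OC): operands stay in place.
// linalg.conv_2d_nchw_fchw loops: (n=d0, f=d1, oh=d2, ow=d3, c, kh, kw)
//   => mDim = d3 (OW) > nDim = d1 (OC): the IGEMM lowering later swaps the
//      matmul operands, which is what isIGemmOperandSwapped records.
int64_t mDim = convolutionDims->outputImage.back();
int64_t nDim = convolutionDims->outputChannel.front();
bool isIGemmOperandSwapped = mDim > nDim;
```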
@@ -226,8 +220,6 @@ padConvOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
       cast<ShapedType>(linalgOp.getDpsInputOperand(1)->get().getType())
           .getElementType();

-  // TODO: Generalize to other dimensions.
-  // Try to search for pad value and check only filter dimension is blocked.
   SmallVector<std::array<int64_t, 3>> mnkPaddingCandidates;
   for (const GPUMatmulShapeType &intrinsic : intrinsics) {

@@ -241,12 +233,17 @@
       return llvm::divideCeil(value, padTo) * padTo - value;
     };

+    auto mIntrinsicSize =
+        isIGemmOperandSwapped ? intrinsic.nSizes[0] : intrinsic.mSizes[0];
+    auto nIntrinsicSize =
+        isIGemmOperandSwapped ? intrinsic.mSizes[0] : intrinsic.nSizes[0];
+
     if (mSize % intrinsic.mSizes[0] != 0) {
-      mPadding = getPadding(mSize, intrinsic.mSizes[0]);
+      mPadding = getPadding(mSize, mIntrinsicSize);
     }

     if (nSize % intrinsic.nSizes[0] != 0) {
-      nPadding = getPadding(nSize, intrinsic.nSizes[0]);
+      nPadding = getPadding(nSize, nIntrinsicSize);
     }

     if (kSize % intrinsic.kSizes[0] != 0) {
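The `getPadding` lambda rounds a dimension up to the next multiple of the chosen intrinsic size and returns the difference, so dimensions that already divide evenly get zero padding. A self-contained check of the arithmetic, using the sizes that appear in the tests below:

```cpp
#include <cassert>
#include <cstdint>

// Mirrors the getPadding lambda above: round `value` up to a multiple of
// `padTo` and return how much padding that requires.
static int64_t getPadding(int64_t value, int64_t padTo) {
  int64_t roundedUp = (value + padTo - 1) / padTo; // llvm::divideCeil
  return roundedUp * padTo - value;
}

int main() {
  // The channel dim of 4 in the tests below is padded up to 16: high[.. 12].
  assert(getPadding(4, 16) == 12);
  // Dims already divisible by the intrinsic size need no padding.
  assert(getPadding(128, 16) == 0);
  return 0;
}
```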
@@ -268,32 +265,51 @@ padConvOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp,

   Value newInput = linalgOp.getDpsInputOperand(0)->get();
   Value newFilter = linalgOp.getDpsInputOperand(1)->get();
-  Value newOuts = linalgOp.getDpsInitOperand(0)->get();
+  Value newOutput = linalgOp.getDpsInitOperand(0)->get();

+  auto indexingMaps = linalgOp.getIndexingMapsArray();
+  auto inputMap = indexingMaps[0];
+  auto filterMap = indexingMaps[1];
+  auto outputMap = indexingMaps[2];
+
   Location loc = linalgOp.getLoc();
   OpFoldResult mPadding = rewriter.getIndexAttr(mnkPadding[0]);
   OpFoldResult nPadding = rewriter.getIndexAttr(mnkPadding[1]);
   OpFoldResult kPadding = rewriter.getIndexAttr(mnkPadding[2]);
   OpFoldResult zero = rewriter.getIndexAttr(0);
-  if (!isConstantIntValue(mPadding, 0) || !isConstantIntValue(kPadding, 0)) {
-    // For NHWC, the m-padding is for W and k-padding is for C
-    newInput = getPaddedValue(rewriter, loc, newInput,
-                              {zero, zero, mPadding, kPadding});
-  }
-  if (!isConstantIntValue(nPadding, 0) || !isConstantIntValue(kPadding, 0)) {
-    // For HWCF, the n-padding is for F and k-padding is for C
-    newFilter = getPaddedValue(rewriter, loc, newFilter,
-                               {zero, zero, kPadding, nPadding});
-  }
-  if (!isConstantIntValue(mPadding, 0) || !isConstantIntValue(nPadding, 0)) {
-    // For output, the m-padding is for W and k-padding is for F
-    newOuts = getPaddedValue(rewriter, loc, newOuts,
-                             {zero, zero, mPadding, nPadding});
-  }
+
+  auto createExprToIdMap = [](AffineMap map) {
+    llvm::SmallDenseMap<AffineExpr, unsigned> exprToIdMap;
+    for (unsigned i = 0; i < map.getNumResults(); ++i) {
+      exprToIdMap[map.getResult(i)] = i;
+    }
+    return exprToIdMap;
+  };
+
+  auto applyPadding = [&](AffineMap map, OpFoldResult padding1,
+                          OpFoldResult padding2, unsigned dim1, unsigned dim2,
+                          Value &paddingTarget) {
+    if (!isConstantIntValue(padding1, 0) || !isConstantIntValue(padding2, 0)) {
+      llvm::SmallDenseMap<AffineExpr, unsigned> exprToIdMap =
+          createExprToIdMap(map);
+      auto id1 = exprToIdMap[getAffineDimExpr(dim1, map.getContext())];
+      auto id2 = exprToIdMap[getAffineDimExpr(dim2, map.getContext())];
+
+      llvm::SmallVector<OpFoldResult> paddingValues(4, zero);
+      paddingValues[id1] = padding1;
+      paddingValues[id2] = padding2;
+      paddingTarget =
+          getPaddedValue(rewriter, loc, paddingTarget, paddingValues);
+    }
+  };
+
+  applyPadding(inputMap, mPadding, kPadding, mDim, kDim, newInput);
+  applyPadding(filterMap, nPadding, kPadding, nDim, kDim, newFilter);
+  applyPadding(outputMap, mPadding, nPadding, mDim, nDim, newOutput);

   linalg::LinalgOp paddedConv2dOp =
-      mlir::clone(rewriter, linalgOp, {newOuts.getType()},
-                  ArrayRef<Value>{newInput, newFilter, newOuts});
+      mlir::clone(rewriter, linalgOp, {newOutput.getType()},
+                  ArrayRef<Value>{newInput, newFilter, newOutput});
   // Extract slice.
   IntegerAttr one = rewriter.getI64IntegerAttr(1);
   SmallVector<OpFoldResult> offsets(4, zero);
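The `exprToIdMap` lookup is what makes the layout flexible: rather than hard-coding which tensor dimension receives m-, n-, or k-padding, `applyPadding` finds the result position of each padded loop dimension in the operand's indexing map. A self-contained analogue with plain integers standing in for `AffineExpr` results, using the NCHW filter map and the padding amounts from the test below:

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

int main() {
  // Stand-in for the filter indexing map of linalg.conv_2d_nchw_fchw,
  // (d1, d4, d5, d6): result position i holds loop dimension results[i].
  std::vector<unsigned> filterMapResults = {1, 4, 5, 6};

  // createExprToIdMap analogue: loop dim -> tensor (result) position.
  std::map<unsigned, unsigned> exprToId;
  for (unsigned i = 0; i < filterMapResults.size(); ++i)
    exprToId[filterMapResults[i]] = i;

  // For NCHW/FCHW, nDim = d1 and kDim = d4 (see the dim sketch above).
  unsigned nDim = 1, kDim = 4;
  std::vector<int64_t> paddingValues(4, 0);
  paddingValues[exprToId[nDim]] = 0;  // nPadding: 320 is already aligned
  paddingValues[exprToId[kDim]] = 12; // kPadding: 4 -> 16
  // Prints 0 12 0 0, matching tensor.pad ... high[0, 12, 0, 0] in the test.
  for (int64_t p : paddingValues)
    std::cout << p << " ";
  std::cout << "\n";
  return 0;
}
```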
@@ -562,8 +578,7 @@ void PadToIntrinsicsPass::runOnOperation() {
   SmallVector<linalg::LinalgOp> targetContractOps;
   for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) {
     funcOp.walk([&](linalg::LinalgOp linalgOp) {
-      if (isa<linalg::Conv2DNhwcHwcfOp>(linalgOp.getOperation()) &&
-          padConvOps) {
+      if (linalg::isaConvolutionOpInterface(linalgOp) && padConvOps) {
         targetConvOps.push_back(linalgOp);
       } else if (isa<linalg::BatchMatmulOp, linalg::MatmulOp,
                      linalg::MatmulTransposeBOp>(linalgOp.getOperation()) &&

@@ -34,6 +34,77 @@ func.func @main0(%arg0: tensor<2x130x130x4xf16>, %arg1: tensor<3x3x4x320xf16>, %

 // -----

+// CHECK-LABEL: func.func @conv_nchw_fchw(
+// CHECK-SAME:  %[[ARG0:.+]]: tensor<2x4x130x130xf16>,
+// CHECK-SAME:  %[[ARG1:.+]]: tensor<320x4x3x3xf16>,
+// CHECK-SAME:  %[[ARG2:.+]]: tensor<2x320x128x128xf32>)
+func.func @conv_nchw_fchw(%arg0: tensor<2x4x130x130xf16>, %arg1: tensor<320x4x3x3xf16>, %arg2: tensor<2x320x128x128xf32>)
+    -> tensor<2x320x128x128xf32> {
+  %conv0 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+             ins(%arg0, %arg1 : tensor<2x4x130x130xf16>, tensor<320x4x3x3xf16>)
+             outs(%arg2 : tensor<2x320x128x128xf32>) -> tensor<2x320x128x128xf32>
+  return %conv0 : tensor<2x320x128x128xf32>
+}
+
+// CHECK:      %[[CST0:.+]] = arith.constant 0.0{{.*}} : f16
+// CHECK:      %[[PAD0:.+]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 12, 0, 0]
+// CHECK:        tensor.yield %[[CST0]] : f16
+// CHECK-NEXT:   tensor<2x4x130x130xf16> to tensor<2x16x130x130xf16>
+// CHECK:      %[[PAD1:.+]] = tensor.pad %[[ARG1]] low[0, 0, 0, 0] high[0, 12, 0, 0]
+// CHECK:        tensor<320x4x3x3xf16> to tensor<320x16x3x3xf16>
+// CHECK:      %[[CONV:.+]] = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+// CHECK-SAME:   ins(%[[PAD0]], %[[PAD1]] : tensor<2x16x130x130xf16>, tensor<320x16x3x3xf16>)
+// CHECK-SAME:   outs(%[[ARG2]] : tensor<2x320x128x128xf32>)
+// CHECK:      return %[[CONV]] : tensor<2x320x128x128xf32>
+
+// CONVOLUTION: tensor.pad {{.*}} low[0, 0, 0, 0] high[0, 12, 0, 0]
+// CONVOLUTION: tensor.pad {{.*}} low[0, 0, 0, 0] high[0, 12, 0, 0]
+
+// CONTRACT-NOT: tensor.pad {{.*}} low[0, 0, 0, 0] high[0, 12, 0, 0]
+// CONTRACT-NOT: tensor.pad {{.*}} low[0, 0, 0, 0] high[0, 12, 0, 0]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+
+// CHECK-LABEL: func.func @conv_generic_nhwc_fhwc(
+// CHECK-SAME:  %[[ARG0:.+]]: tensor<2x130x130x4xf16>,
+// CHECK-SAME:  %[[ARG1:.+]]: tensor<320x3x3x4xf16>,
+// CHECK-SAME:  %[[ARG2:.+]]: tensor<2x128x128x320xf32>)
+func.func @conv_generic_nhwc_fhwc(%arg0: tensor<2x130x130x4xf16>, %arg1: tensor<320x3x3x4xf16>, %arg2: tensor<2x128x128x320xf32>) -> tensor<2x128x128x320xf32> {
+  %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<2x130x130x4xf16>, tensor<320x3x3x4xf16>) outs(%arg2 : tensor<2x128x128x320xf32>) {
+  ^bb0(%in: f16, %in_0: f16, %out: f32):
+    %2 = arith.extf %in : f16 to f32
+    %3 = arith.extf %in_0 : f16 to f32
+    %4 = arith.mulf %2, %3 : f32
+    %5 = arith.addf %out, %4 : f32
+    linalg.yield %5 : f32
+  } -> tensor<2x128x128x320xf32>
+  return %1 : tensor<2x128x128x320xf32>
+}
+
+// CHECK:      %[[CST0:.+]] = arith.constant 0.0{{.*}} : f16
+// CHECK:      %[[PAD0:.+]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 0, 0, 12]
+// CHECK:        tensor.yield %[[CST0]] : f16
+// CHECK-NEXT:   tensor<2x130x130x4xf16> to tensor<2x130x130x16xf16>
+// CHECK:      %[[PAD1:.+]] = tensor.pad %[[ARG1]] low[0, 0, 0, 0] high[0, 0, 0, 12]
+// CHECK:        tensor<320x3x3x4xf16> to tensor<320x3x3x16xf16>
+// CHECK:      %[[CONV:.+]] = linalg.generic
+// CHECK-SAME:   indexing_maps = [#map, #map1, #map2],
+// CHECK-SAME:   ins(%[[PAD0]], %[[PAD1]] : tensor<2x130x130x16xf16>, tensor<320x3x3x16xf16>)
+// CHECK-SAME:   outs(%[[ARG2]] : tensor<2x128x128x320xf32>)
+// CHECK:      return %[[CONV]] : tensor<2x128x128x320xf32>
+
+// CONVOLUTION: tensor.pad {{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+// CONVOLUTION: tensor.pad {{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+
+// CONTRACT-NOT: tensor.pad {{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+// CONTRACT-NOT: tensor.pad {{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+
+// -----
+
 // CHECK-LABEL: func.func @main1(
 // CHECK-SAME:  %[[ARG0:.+]]: tensor<2x130x130x320xf16>,
 // CHECK-SAME:  %[[ARG1:.+]]: tensor<3x3x320x4xf16>,