[compiler] support decomposition of byteir.arg_max/arg_min
qingyunqu committed Jul 27, 2024
1 parent 2b4231c commit 3e39320
Showing 5 changed files with 262 additions and 1 deletion.
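The new pattern lowers a byteir.arg_max / byteir.arg_min custom call to a two-result mhlo.reduce over (value, index) pairs, where the index operand is an mhlo.dynamic_iota along the reduction axis. The input form the pattern matches, taken verbatim from the tests added below:

func.func @byteir.arg_max$return_1(%arg0: tensor<3x4xf32>) -> tensor<3xi64> {
  %0 = mhlo.custom_call @byteir.arg_max(%arg0) {byteir_attrs = {axis = 1 : i64, keep_dims = false, select_last_index = false}} : (tensor<3x4xf32>) -> tensor<3xi64>
  return %0 : tensor<3xi64>
}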
5 changes: 5 additions & 0 deletions compiler/include/byteir/Dialect/mhlo/Passes.td
@@ -116,6 +116,11 @@ def DecomposeMhloCustomCallOps : Pass<"decompose-mhlo-custom-call-ops", "mlir::func::FuncOp"> {
    ListOption<"legalOps", "legal-ops", "std::string",
               "List of custom call ops not to be decomposed">,
  ];
  let dependentDialects = [
    "mlir::arith::ArithDialect",
    "mlir::shape::ShapeDialect",
    "mlir::tensor::TensorDialect",
  ];
}
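The decomposition added below creates arith.constant, tensor.dim, and tensor.from_elements ops, so these dialects must be loaded before the pass runs. For reference, the dependentDialects list corresponds to a getDependentDialects override in the TableGen-generated pass base, roughly as follows (a sketch, not the literal generated code):

void getDependentDialects(mlir::DialectRegistry &registry) const override {
  // Preload every dialect the rewrite patterns may create ops from.
  registry.insert<mlir::arith::ArithDialect>();
  registry.insert<mlir::shape::ShapeDialect>();
  registry.insert<mlir::tensor::TensorDialect>();
}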

//===----------------------------------------------------------------------===//
169 changes: 169 additions & 0 deletions compiler/lib/Dialect/mhlo/Transforms/DecomposeMhloCustomCallOps.cpp
@@ -18,7 +18,10 @@
#include "byteir/Dialect/mhlo/Transforms/DecomposeMhloCustomCallOps.h"
#include "byteir/Dialect/mhlo/Util/CustomCallUtil.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/StringSet.h"

@@ -50,6 +53,166 @@ struct DecomposeByteIRAddN : public OpRewritePattern<mhlo::CustomCallOp> {
  }
};

struct DecomposeByteIRArgMaxMin : public OpRewritePattern<mhlo::CustomCallOp> {
  DecomposeByteIRArgMaxMin(MLIRContext *context, llvm::StringRef customCallName)
      : OpRewritePattern<mhlo::CustomCallOp>(context),
        customCallName(customCallName.str()) {}
  LogicalResult matchAndRewrite(mhlo::CustomCallOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getCallTargetName() != customCallName)
      return failure();

    DictionaryAttr byteirAttrs =
        dyn_cast_or_null<DictionaryAttr>(op->getAttr(getCustomCallAttrName()));
    if (!byteirAttrs)
      return failure();
    auto axisAttr = cast<IntegerAttr>(byteirAttrs.get("axis"));
    auto keepDimAttr = cast<BoolAttr>(byteirAttrs.get("keep_dims"));
    auto selectLastIndexAttr =
        cast<BoolAttr>(byteirAttrs.get("select_last_index"));
    if (selectLastIndexAttr.getValue()) {
      return op.emitError("unimplemented: select_last_index = true");
    }
    // TODO(lyq): support keep_dims = true
    if (keepDimAttr.getValue()) {
      return op.emitError("unimplemented: keep_dims = true");
    }

    RankedTensorType inType =
        cast<RankedTensorType>(op.getOperand(0).getType());
    RankedTensorType outType, outIndexType;
    if (op.getResults().size() == 1) {
      outIndexType = cast<RankedTensorType>(op.getResults()[0].getType());
      outType = outIndexType.clone(inType.getElementType());
    } else if (op.getResults().size() == 2) {
      outType = cast<RankedTensorType>(op.getResults()[0].getType());
      outIndexType = cast<RankedTensorType>(op.getResults()[1].getType());
    } else {
      return op.emitError("unsupported result size");
    }

    if (!isa<mlir::FloatType>(inType.getElementType())) {
      return op.emitError("only support float type");
    }

    // Create the reduction init values: -inf for arg_max, +inf for arg_min.
    Value initValue;
    if (customCallName == getArgMaxName().str()) {
      initValue = rewriter.create<mhlo::ConstantOp>(
          op.getLoc(),
          DenseElementsAttr::get(
              RankedTensorType::get({}, inType.getElementType()),
              {APFloat::getInf(cast<mlir::FloatType>(inType.getElementType())
                                   .getFloatSemantics(),
                               /*negative=*/true)}));
    } else if (customCallName == getArgMinName().str()) {
      initValue = rewriter.create<mhlo::ConstantOp>(
          op.getLoc(),
          DenseElementsAttr::get(
              RankedTensorType::get({}, inType.getElementType()),
              {APFloat::getInf(cast<mlir::FloatType>(inType.getElementType())
                                   .getFloatSemantics(),
                               /*negative=*/false)}));
    } else {
      return op.emitError("unknown custom call name");
    }
    Value initIndex = rewriter.create<mhlo::ConstantOp>(
        op.getLoc(),
        DenseElementsAttr::get(
            RankedTensorType::get({}, outIndexType.getElementType()),
            {APInt::getZero(
                outIndexType.getElementType().getIntOrFloatBitWidth())}));

    llvm::SmallVector<Value> inputShapeVec;
    for (int64_t i = 0; i < inType.getRank(); i++) {
      inputShapeVec.push_back(rewriter.create<tensor::DimOp>(
          op.getLoc(), op.getOperand(0),
          rewriter.create<arith::ConstantOp>(op.getLoc(),
                                             rewriter.getIndexAttr(i))));
    }
    Value inputShapeTensor =
        rewriter.create<tensor::FromElementsOp>(op.getLoc(), inputShapeVec);
    Value indexTensor = rewriter.create<mhlo::DynamicIotaOp>(
        op.getLoc(), inType.clone(outIndexType.getElementType()),
        inputShapeTensor, axisAttr);
    auto reduceOp = rewriter.create<mhlo::ReduceOp>(
        op.getLoc(), TypeRange{outType, outIndexType},
        ValueRange{op.getOperand(0), indexTensor},
        ValueRange{initValue, initIndex},
        rewriter.getI64TensorAttr({axisAttr.getInt()}));
    {
      Block &block = reduceOp.getBody().emplaceBlock();
      // Add block arguments
      auto blockValArgumentType =
          RankedTensorType::get({}, inType.getElementType());
      auto blockIdxArgumentType =
          RankedTensorType::get({}, outIndexType.getElementType());
      auto compareResultType = RankedTensorType::get({}, rewriter.getI1Type());
      block.addArgument(blockValArgumentType, op->getLoc());
      block.addArgument(blockIdxArgumentType, op->getLoc());

      block.addArgument(blockValArgumentType, op->getLoc());
      block.addArgument(blockIdxArgumentType, op->getLoc());

      auto *firstValArg = block.args_begin();
      auto *firstIdxArg = std::next(firstValArg);
      auto *secondValArg = std::next(firstIdxArg);
      auto *secondIdxArg = std::next(secondValArg);

      mhlo::ComparisonTypeAttr compareTypeAttr = mhlo::ComparisonTypeAttr::get(
          rewriter.getContext(), mhlo::ComparisonType::FLOAT);
      mhlo::ComparisonDirectionAttr compareGeDirectionAttr =
          mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
                                             mhlo::ComparisonDirection::GE);
      mhlo::ComparisonDirectionAttr compareLeDirectionAttr =
          mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
                                             mhlo::ComparisonDirection::LE);
      mhlo::ComparisonDirectionAttr compareEqDirectionAttr =
          mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
                                             mhlo::ComparisonDirection::EQ);

      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(&block);
      Value compareResult;
      if (customCallName == getArgMaxName().str()) {
        compareResult = rewriter.create<mhlo::CompareOp>(
            op->getLoc(), compareResultType, *firstValArg, *secondValArg,
            compareGeDirectionAttr, compareTypeAttr);
      } else {
        compareResult = rewriter.create<mhlo::CompareOp>(
            op->getLoc(), compareResultType, *firstValArg, *secondValArg,
            compareLeDirectionAttr, compareTypeAttr);
      }

      Value retValResult = rewriter.create<mhlo::SelectOp>(
          op->getLoc(), compareResult, *firstValArg, *secondValArg);

      // Pick the smaller index when the compared values are equal.
      Value compareEqResult = rewriter.create<mhlo::CompareOp>(
          op->getLoc(), compareResultType, *firstValArg, *secondValArg,
          compareEqDirectionAttr, compareTypeAttr);
      Value minIdx = rewriter.create<mhlo::MinOp>(op->getLoc(), *firstIdxArg,
                                                  *secondIdxArg);
      Value idxWithGeVal = rewriter.create<mhlo::SelectOp>(
          op->getLoc(), compareResult, *firstIdxArg, *secondIdxArg);
      Value retIdxResult = rewriter.create<mhlo::SelectOp>(
          op->getLoc(), compareEqResult, minIdx, idxWithGeVal);

      rewriter.create<mhlo::ReturnOp>(op->getLoc(),
                                      ValueRange{retValResult, retIdxResult});
    }

    if (op.getResults().size() == 1) {
      rewriter.replaceOp(op, reduceOp.getResults()[1]);
    } else {
      rewriter.replaceOp(op, reduceOp.getResults());
    }
    return success();
  }

  std::string customCallName;
};
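The reducer keeps the larger (arg_max) or smaller (arg_min) value and, on ties, the smaller index: for input [0.0, 2.0, 1.0, 2.0], arg_max yields index 1, matching select_last_index = false. For arg_max, the reduce region built above corresponds to the following MLIR (a hand-written sketch with invented SSA names, not verbatim pass output):

^bb0(%lhs_val: tensor<f32>, %lhs_idx: tensor<i64>, %rhs_val: tensor<f32>, %rhs_idx: tensor<i64>):
  %ge      = mhlo.compare GE, %lhs_val, %rhs_val, FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
  %max_val = mhlo.select %ge, %lhs_val, %rhs_val : tensor<i1>, tensor<f32>
  %eq      = mhlo.compare EQ, %lhs_val, %rhs_val, FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
  %min_idx = mhlo.minimum %lhs_idx, %rhs_idx : tensor<i64>
  %ge_idx  = mhlo.select %ge, %lhs_idx, %rhs_idx : tensor<i1>, tensor<i64>
  %res_idx = mhlo.select %eq, %min_idx, %ge_idx : tensor<i1>, tensor<i64>
  mhlo.return %max_val, %res_idx : tensor<f32>, tensor<i64>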

struct DecomposeMhloCustomCallOpsPass
    : public DecomposeMhloCustomCallOpsBase<DecomposeMhloCustomCallOpsPass> {
  DecomposeMhloCustomCallOpsPass(ArrayRef<std::string> legalOps) {
@@ -67,6 +230,12 @@ struct DecomposeMhloCustomCallOpsPass
    if (!legalOpsSet.contains(getAddNName())) {
      patterns.add<DecomposeByteIRAddN>(context);
    }
    if (!legalOpsSet.contains(getArgMaxName())) {
      patterns.add<DecomposeByteIRArgMaxMin>(context, getArgMaxName());
    }
    if (!legalOpsSet.contains(getArgMinName())) {
      patterns.add<DecomposeByteIRArgMaxMin>(context, getArgMinName());
    }

    FrozenRewritePatternSet frozenPatterns(std::move(patterns));
    if (failed(applyPatternsAndFoldGreedily(funcOp, frozenPatterns))) {
4 changes: 4 additions & 0 deletions compiler/lib/Dialect/mhlo/Transforms/PassDetail.h
@@ -34,6 +34,10 @@ namespace mhlo {
class MhloDialect;
} // namespace mhlo

namespace arith {
class ArithDialect;
} // namespace arith

namespace shape {
class ShapeDialect;
} // namespace shape
@@ -1,4 +1,4 @@
-// RUN: byteir-opt %s --decompose-mhlo-custom-call-ops | FileCheck %s
+// RUN: byteir-opt %s --decompose-mhlo-custom-call-ops --canonicalize | FileCheck %s

func.func @byteir.addn(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> tensor<4xf32> {
  %0 = mhlo.custom_call @byteir.addn(%arg0, %arg1, %arg2) {byteir_attrs = {}} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
@@ -10,3 +10,82 @@ func.func @byteir.addn(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: mhlo.add
// CHECK: return

func.func @byteir.arg_max$return_1(%arg0: tensor<3x4xf32>) -> tensor<3xi64> {
  %0 = mhlo.custom_call @byteir.arg_max(%arg0) {byteir_attrs = {axis = 1 : i64, keep_dims = false, select_last_index = false}} : (tensor<3x4xf32>) -> tensor<3xi64>
  return %0 : tensor<3xi64>
}
// CHECK-LABEL: func.func @byteir.arg_max$return_1
// CHECK-NOT: byteir.arg_max
// CHECK-DAG: mhlo.constant dense<0xFF800000> : tensor<f32>
// CHECK-DAG: mhlo.constant dense<0> : tensor<i64>
// CHECK: mhlo.iota
// CHECK: mhlo.broadcast_in_dim
// CHECK: mhlo.reduce
// CHECK: mhlo.compare GE
// CHECK: mhlo.select
// CHECK: mhlo.compare
// CHECK: mhlo.minimum
// CHECK: mhlo.select
// CHECK: mhlo.select
// CHECK: mhlo.return
// CHECK: return

func.func @byteir.arg_max$return_2(%arg0: tensor<3x4xf32>) -> (tensor<3xf32>, tensor<3xi64>) {
  %0:2 = mhlo.custom_call @byteir.arg_max(%arg0) {byteir_attrs = {axis = 1 : i64, keep_dims = false, select_last_index = false}} : (tensor<3x4xf32>) -> (tensor<3xf32>, tensor<3xi64>)
  return %0#0, %0#1 : tensor<3xf32>, tensor<3xi64>
}
// CHECK-LABEL: func.func @byteir.arg_max$return_2
// CHECK-NOT: byteir.arg_max
// CHECK-DAG: mhlo.constant dense<0xFF800000> : tensor<f32>
// CHECK-DAG: mhlo.constant dense<0> : tensor<i64>
// CHECK: mhlo.iota
// CHECK: mhlo.broadcast_in_dim
// CHECK: mhlo.reduce
// CHECK: mhlo.compare GE
// CHECK: mhlo.select
// CHECK: mhlo.compare
// CHECK: mhlo.minimum
// CHECK: mhlo.select
// CHECK: mhlo.select
// CHECK: mhlo.return
// CHECK: return

func.func @byteir.arg_min$return_1(%arg0: tensor<3x4xf32>) -> tensor<3xi64> {
  %0 = mhlo.custom_call @byteir.arg_min(%arg0) {byteir_attrs = {axis = 1 : i64, keep_dims = false, select_last_index = false}} : (tensor<3x4xf32>) -> tensor<3xi64>
  return %0 : tensor<3xi64>
}
// CHECK-LABEL: func.func @byteir.arg_min$return_1
// CHECK-NOT: byteir.arg_min
// CHECK-DAG: mhlo.constant dense<0x7F800000> : tensor<f32>
// CHECK-DAG: mhlo.constant dense<0> : tensor<i64>
// CHECK: mhlo.iota
// CHECK: mhlo.broadcast_in_dim
// CHECK: mhlo.reduce
// CHECK: mhlo.compare LE
// CHECK: mhlo.select
// CHECK: mhlo.compare
// CHECK: mhlo.minimum
// CHECK: mhlo.select
// CHECK: mhlo.select
// CHECK: mhlo.return
// CHECK: return

func.func @byteir.arg_min$return_2(%arg0: tensor<3x4xf32>) -> (tensor<3xf32>, tensor<3xi64>) {
  %0:2 = mhlo.custom_call @byteir.arg_min(%arg0) {byteir_attrs = {axis = 1 : i64, keep_dims = false, select_last_index = false}} : (tensor<3x4xf32>) -> (tensor<3xf32>, tensor<3xi64>)
  return %0#0, %0#1 : tensor<3xf32>, tensor<3xi64>
}
// CHECK-LABEL: func.func @byteir.arg_min$return_2
// CHECK-NOT: byteir.arg_min
// CHECK-DAG: mhlo.constant dense<0x7F800000> : tensor<f32>
// CHECK-DAG: mhlo.constant dense<0> : tensor<i64>
// CHECK: mhlo.iota
// CHECK: mhlo.broadcast_in_dim
// CHECK: mhlo.reduce
// CHECK: mhlo.compare LE
// CHECK: mhlo.select
// CHECK: mhlo.compare
// CHECK: mhlo.minimum
// CHECK: mhlo.select
// CHECK: mhlo.select
// CHECK: mhlo.return
// CHECK: return
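With this test file checked out, the updated RUN line expands to roughly the following command (the .mlir path is hypothetical, since the diff header above does not show the file name):

byteir-opt decompose-mhlo-custom-call-ops.mlir --decompose-mhlo-custom-call-ops --canonicalize | FileCheck decompose-mhlo-custom-call-ops.mlir

The added --canonicalize folds the shape computation (tensor.dim / tensor.from_elements / mhlo.dynamic_iota) on these static shapes into the mhlo.iota and mhlo.broadcast_in_dim that the CHECK lines expect.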
@@ -0,0 +1,4 @@
func.func @byteir.arg_max$return_2(%arg0: tensor<3x4xf32>) -> (tensor<3xf32>, tensor<3xi64>) {
  %0:2 = stablehlo.custom_call @byteir.arg_max(%arg0) {byteir_attrs = {axis = 1 : i64, keep_dims = false, select_last_index = false}} : (tensor<3x4xf32>) -> (tensor<3xf32>, tensor<3xi64>)
  return %0#0, %0#1 : tensor<3xf32>, tensor<3xi64>
}
