Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[compiler] support decompose of byteir.softmax #454

Merged
merged 1 commit into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,69 @@ struct DecomposeByteIRAddN : public OpRewritePattern<mhlo::CustomCallOp> {
}
};

struct DecomposeByteIRSoftmax : public OpRewritePattern<mhlo::CustomCallOp> {
using OpRewritePattern<mhlo::CustomCallOp>::OpRewritePattern;

LogicalResult matchAndRewrite(mhlo::CustomCallOp op,
PatternRewriter &rewriter) const override {
if (op.getCallTargetName() != getSoftmaxName())
return failure();

DictionaryAttr byteirAttrs =
cast<DictionaryAttr>(op->getAttr(getCustomCallAttrName()));
if (!byteirAttrs)
return failure();
auto axisAttr = cast<IntegerAttr>(byteirAttrs.get("axis"));

RankedTensorType inType =
cast<RankedTensorType>(op.getOperand(0).getType());
Value exp = rewriter.create<mhlo::ExpOp>(op.getLoc(), op.getOperand(0));
Value reduce;
{
SmallVector<int64_t> reduceResultShape(inType.getShape());
reduceResultShape.erase(reduceResultShape.begin() + axisAttr.getInt());
RankedTensorType reduceResultType =
RankedTensorType::get(reduceResultShape, inType.getElementType());

Value initValue = rewriter.create<mhlo::ConstantOp>(
op.getLoc(),
DenseElementsAttr::get(
RankedTensorType::get({}, inType.getElementType()),
{APFloat::getZero(cast<mlir::FloatType>(inType.getElementType())
.getFloatSemantics())}));
auto reduceOp = rewriter.create<mhlo::ReduceOp>(
op.getLoc(), reduceResultType, exp, initValue,
rewriter.getI64TensorAttr({axisAttr.getInt()}));

Block &block = reduceOp.getBody().emplaceBlock();
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
auto blockValArgumentType =
RankedTensorType::get({}, inType.getElementType());
block.addArgument(blockValArgumentType, op->getLoc());
block.addArgument(blockValArgumentType, op->getLoc());
auto *firstValArg = block.args_begin();
auto *secondValArg = std::next(firstValArg);
Value result = rewriter.create<mhlo::AddOp>(op->getLoc(), *firstValArg,
*secondValArg);
rewriter.create<mhlo::ReturnOp>(op->getLoc(), result);

reduce = reduceOp.getResults()[0];
}

SmallVector broadcastDim =
llvm::to_vector(llvm::seq<int64_t>(0, inType.getRank()));
broadcastDim.erase(broadcastDim.begin() + axisAttr.getInt());
Value broadcast = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
op->getLoc(), inType, reduce,
rewriter.create<shape::ShapeOfOp>(op.getLoc(), exp),
rewriter.getI64TensorAttr(broadcastDim));
Value result = rewriter.create<mhlo::DivOp>(op->getLoc(), exp, broadcast);
rewriter.replaceOp(op, result);
return success();
}
};

struct DecomposeByteIRArgMaxMin : public OpRewritePattern<mhlo::CustomCallOp> {
DecomposeByteIRArgMaxMin(MLIRContext *context, llvm::StringRef customCallName)
: OpRewritePattern<mhlo::CustomCallOp>(context),
Expand Down Expand Up @@ -230,6 +293,9 @@ struct DecomposeMhloCustomCallOpsPass
if (!legalOpsSet.contains(getAddNName())) {
patterns.add<DecomposeByteIRAddN>(context);
}
if (!legalOpsSet.contains(getSoftmaxName())) {
patterns.add<DecomposeByteIRSoftmax>(context);
}
if (!legalOpsSet.contains(getArgMaxName())) {
patterns.add<DecomposeByteIRArgMaxMin>(context, getArgMaxName());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,19 @@ func.func @byteir.addn(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor
// CHECK: mhlo.add
// CHECK: return

func.func @byteir.softmax(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
%0 = mhlo.custom_call @byteir.softmax(%arg0) {byteir_attrs = {axis = 1 : i64}} : (tensor<4x4xf32>) -> tensor<4x4xf32>
return %0 : tensor<4x4xf32>
}
// CHECK-LABEL: func.func @byteir.softmax
// CHECK-NOT: byteir.softmax
// CHECK: mhlo.exp
// CHECK: mhlo.reduce
// CHECK-SAME: mhlo.add
// CHECK: mhlo.broadcast_in_dim
// CHECK: mhlo.div
// CHECK: return

func.func @byteir.arg_max$return_1(%arg0: tensor<3x4xf32>) -> tensor<3xi64> {
%0 = mhlo.custom_call @byteir.arg_max(%arg0) {byteir_attrs = {axis = 1 : i64, keep_dims = false, select_last_index = false}} : (tensor<3x4xf32>) -> tensor<3xi64>
return %0 : tensor<3xi64>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
func.func @byteir.softmax(%arg0: tensor<10x128xf32>) -> tensor<10x128xf32> {
%0 = stablehlo.custom_call @byteir.softmax(%arg0) {byteir_attrs = {axis = 1 : i64}} : (tensor<10x128xf32>) -> tensor<10x128xf32>
return %0 : tensor<10x128xf32>
}
Loading