[compiler] support cpu codegen of byteir.arg_max with integer type (#469)

as title

qingyunqu authored Oct 23, 2024
1 parent cdfc595 commit 4c96ed9

Showing 5 changed files with 115 additions and 19 deletions.
82 changes: 63 additions & 19 deletions compiler/lib/Dialect/mhlo/Transforms/DecomposeMhloCustomCallOps.cpp
@@ -143,39 +143,74 @@ struct DecomposeByteIRArgMaxMin : public OpRewritePattern<mhlo::CustomCallOp> {

RankedTensorType inType =
cast<RankedTensorType>(op.getOperand(0).getType());
+    Type inElemType = inType.getElementType();
RankedTensorType outType, outIndexType;
if (op.getResults().size() == 1) {
outIndexType = cast<RankedTensorType>(op.getResults()[0].getType());
-      outType = outIndexType.clone(inType.getElementType());
+      outType = outIndexType.clone(inElemType);
} else if (op.getResults().size() == 2) {
outType = cast<RankedTensorType>(op.getResults()[0].getType());
outIndexType = cast<RankedTensorType>(op.getResults()[1].getType());
} else {
return op.emitError("unsupported result size");
}

-    if (!isa<mlir::FloatType>(inType.getElementType())) {
-      return op.emitError("only support float type");
+    if (!isa<mlir::FloatType, mlir::IntegerType>(inElemType)) {
+      return op.emitError("only support float or int type");
}

// create init values
Value initValue;
if (customCallName == getArgMaxName().str()) {
-      initValue = rewriter.create<mhlo::ConstantOp>(
-          op.getLoc(),
-          DenseElementsAttr::get(
-              RankedTensorType::get({}, inType.getElementType()),
-              {APFloat::getInf(cast<mlir::FloatType>(inType.getElementType())
-                                   .getFloatSemantics(),
-                               /*negative=*/true)}));
+      if (isa<mlir::FloatType>(inElemType)) {
+        initValue = rewriter.create<mhlo::ConstantOp>(
+            op.getLoc(),
+            DenseElementsAttr::get(
+                RankedTensorType::get({}, inElemType),
+                {APFloat::getInf(
+                    cast<mlir::FloatType>(inElemType).getFloatSemantics(),
+                    /*negative=*/true)}));
+      } else if (isa<mlir::IntegerType>(inElemType)) {
+        if (cast<mlir::IntegerType>(inElemType).isSignless() &&
+            inElemType.getIntOrFloatBitWidth() != 1) {
+          initValue = rewriter.create<mhlo::ConstantOp>(
+              op.getLoc(),
+              DenseElementsAttr::get(RankedTensorType::get({}, inElemType),
+                                     {APInt::getSignedMinValue(
+                                         inElemType.getIntOrFloatBitWidth())}));
+        } else {
+          initValue = rewriter.create<mhlo::ConstantOp>(
+              op.getLoc(),
+              DenseElementsAttr::get(
+                  RankedTensorType::get({}, inElemType),
+                  {APInt::getMinValue(inElemType.getIntOrFloatBitWidth())}));
+        }
+      }
} else if (customCallName == getArgMinName().str()) {
-      initValue = rewriter.create<mhlo::ConstantOp>(
-          op.getLoc(),
-          DenseElementsAttr::get(
-              RankedTensorType::get({}, inType.getElementType()),
-              {APFloat::getInf(cast<mlir::FloatType>(inType.getElementType())
-                                   .getFloatSemantics(),
-                               /*negative=*/false)}));
+      if (isa<mlir::FloatType>(inElemType)) {
+        initValue = rewriter.create<mhlo::ConstantOp>(
+            op.getLoc(),
+            DenseElementsAttr::get(
+                RankedTensorType::get({}, inElemType),
+                {APFloat::getInf(
+                    cast<mlir::FloatType>(inElemType).getFloatSemantics(),
+                    /*negative=*/false)}));
+      } else if (isa<mlir::IntegerType>(inElemType)) {
+        if (cast<mlir::IntegerType>(inElemType).isSignless() &&
+            inElemType.getIntOrFloatBitWidth() != 1) {
+          initValue = rewriter.create<mhlo::ConstantOp>(
+              op.getLoc(),
+              DenseElementsAttr::get(RankedTensorType::get({}, inElemType),
+                                     {APInt::getSignedMaxValue(
+                                         inElemType.getIntOrFloatBitWidth())}));
+        } else {
+          initValue = rewriter.create<mhlo::ConstantOp>(
+              op.getLoc(),
+              DenseElementsAttr::get(
+                  RankedTensorType::get({}, inElemType),
+                  {APInt::getMaxValue(inElemType.getIntOrFloatBitWidth())}));
+        }
+      }
} else {
return op.emitError("unknown custom call name");
}
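
The init value is the identity element of the reduction: -inf/+inf for floats, and the integer extreme for integers. Signless MLIR integers other than i1 are treated as signed (matching the SIGNED comparison type chosen further down), while i1 falls back to unsigned semantics. A minimal standalone sketch of that selection, using only llvm::APInt (these helpers are illustrative, not part of the patch):

// Illustrative sketch, not part of the patch: the identity elements picked
// above for integer arg_max/arg_min reductions.
#include "llvm/ADT/APInt.h"
#include "llvm/Support/raw_ostream.h"

// Signless integers other than i1 are treated as signed, so arg_max starts
// from the signed minimum; i1 falls back to the unsigned minimum (0).
static llvm::APInt argMaxInit(unsigned bitWidth, bool signless) {
  if (signless && bitWidth != 1)
    return llvm::APInt::getSignedMinValue(bitWidth);
  return llvm::APInt::getMinValue(bitWidth);
}

// arg_min is symmetric: signed maximum for signless non-i1 integers,
// unsigned maximum (all ones) otherwise.
static llvm::APInt argMinInit(unsigned bitWidth, bool signless) {
  if (signless && bitWidth != 1)
    return llvm::APInt::getSignedMaxValue(bitWidth);
  return llvm::APInt::getMaxValue(bitWidth);
}

int main() {
  llvm::outs() << argMaxInit(32, true) << "\n"; // -2147483648, as in the i32 test below
  llvm::outs() << argMinInit(32, true) << "\n"; // 2147483647
}
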
@@ -206,8 +241,7 @@ struct DecomposeByteIRArgMaxMin : public OpRewritePattern<mhlo::CustomCallOp> {
{
Block &block = reduceOp.getBody().emplaceBlock();
// Add block arguments
-      auto blockValArgumentType =
-          RankedTensorType::get({}, inType.getElementType());
+      auto blockValArgumentType = RankedTensorType::get({}, inElemType);
auto blockIdxArgumentType =
RankedTensorType::get({}, outIndexType.getElementType());
auto compareResultType = RankedTensorType::get({}, rewriter.getI1Type());
@@ -224,6 +258,16 @@ struct DecomposeByteIRArgMaxMin : public OpRewritePattern<mhlo::CustomCallOp> {

mhlo::ComparisonTypeAttr compareTypeAttr = mhlo::ComparisonTypeAttr::get(
rewriter.getContext(), mhlo::ComparisonType::FLOAT);
+    if (isa<mlir::IntegerType>(inElemType)) {
+      if (cast<mlir::IntegerType>(inElemType).isSignless() &&
+          inElemType.getIntOrFloatBitWidth() != 1) {
+        compareTypeAttr = mhlo::ComparisonTypeAttr::get(
+            rewriter.getContext(), mhlo::ComparisonType::SIGNED);
+      } else {
+        compareTypeAttr = mhlo::ComparisonTypeAttr::get(
+            rewriter.getContext(), mhlo::ComparisonType::UNSIGNED);
+      }
+    }
mhlo::ComparisonDirectionAttr compareGeDirectionAttr =
mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
mhlo::ComparisonDirection::GE);
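
The comparison type mirrors the init-value choice: FLOAT for floats, SIGNED for signless non-i1 integers, UNSIGNED for everything else (i1 in particular). A hypothetical helper capturing the same rule (illustrative only, not in the patch):

// Hypothetical helper, not in the patch: floats compare as FLOAT, signless
// non-i1 integers as SIGNED, everything else (e.g. i1) as UNSIGNED.
static mhlo::ComparisonType selectComparisonType(Type elemType) {
  if (isa<mlir::FloatType>(elemType))
    return mhlo::ComparisonType::FLOAT;
  auto intType = cast<mlir::IntegerType>(elemType);
  if (intType.isSignless() && intType.getWidth() != 1)
    return mhlo::ComparisonType::SIGNED;
  return mhlo::ComparisonType::UNSIGNED;
}
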
@@ -43,6 +43,26 @@ func.func @byteir.arg_max$return_1(%arg0: tensor<3x4xf32>) -> tensor<3xi64> {
// CHECK: mhlo.return
// CHECK: return

+func.func @byteir.arg_max.int$return_1(%arg0: tensor<3x4xi32>) -> tensor<3xi64> {
+  %0 = mhlo.custom_call @byteir.arg_max(%arg0) {byteir_attrs = {axis = 1 : i64, keep_dims = false, select_last_index = false}} : (tensor<3x4xi32>) -> tensor<3xi64>
+  return %0 : tensor<3xi64>
+}
+// CHECK-LABEL: func.func @byteir.arg_max.int$return_1
+// CHECK-NOT: byteir.arg_max
+// CHECK-DAG: mhlo.constant dense<-2147483648> : tensor<i32>
+// CHECK-DAG: mhlo.constant dense<0> : tensor<i64>
+// CHECK: mhlo.iota
+// CHECK: mhlo.broadcast_in_dim
+// CHECK: mhlo.reduce
+// CHECK: mhlo.compare GE
+// CHECK: mhlo.select
+// CHECK: mhlo.compare
+// CHECK: mhlo.minimum
+// CHECK: mhlo.select
+// CHECK: mhlo.select
+// CHECK: mhlo.return
+// CHECK: return
+
func.func @byteir.arg_max$return_2(%arg0: tensor<3x4xf32>) -> (tensor<3xf32>, tensor<3xi64>) {
%0:2 = mhlo.custom_call @byteir.arg_max(%arg0) {byteir_attrs = {axis = 1 : i64, keep_dims = false, select_last_index = false}} : (tensor<3x4xf32>) -> (tensor<3xf32>, tensor<3xi64>)
return %0#0, %0#1 : tensor<3xf32>, tensor<3xi64>
@@ -83,6 +103,26 @@ func.func @byteir.arg_min$return_1(%arg0: tensor<3x4xf32>) -> tensor<3xi64> {
// CHECK: mhlo.return
// CHECK: return

func.func @byteir.arg_min.int$return_1(%arg0: tensor<3x4xi32>) -> tensor<3xi64> {
%0 = mhlo.custom_call @byteir.arg_min(%arg0) {byteir_attrs = {axis = 1 : i64, keep_dims = false, select_last_index = false}} : (tensor<3x4xi32>) -> tensor<3xi64>
return %0 : tensor<3xi64>
}
// CHECK-LABEL: func.func @byteir.arg_min.int$return_1
// CHECK-NOT: byteir.arg_min
// CHECK-DAG: mhlo.constant dense<2147483647> : tensor<i32>
// CHECK-DAG: mhlo.constant dense<0> : tensor<i64>
// CHECK: mhlo.iota
// CHECK: mhlo.broadcast_in_dim
// CHECK: mhlo.reduce
// CHECK: mhlo.compare LE
// CHECK: mhlo.select
// CHECK: mhlo.compare
// CHECK: mhlo.minimum
// CHECK: mhlo.select
// CHECK: mhlo.select
// CHECK: mhlo.return
// CHECK: return

func.func @byteir.arg_min$return_2(%arg0: tensor<3x4xf32>) -> (tensor<3xf32>, tensor<3xi64>) {
%0:2 = mhlo.custom_call @byteir.arg_min(%arg0) {byteir_attrs = {axis = 1 : i64, keep_dims = false, select_last_index = false}} : (tensor<3x4xf32>) -> (tensor<3xf32>, tensor<3xi64>)
return %0#0, %0#1 : tensor<3xf32>, tensor<3xi64>
@@ -0,0 +1,4 @@
+func.func @byteir.arg_max$return_2(%arg0: tensor<3x128xi32>) -> (tensor<3xi32>, tensor<3xi64>) {
+  %0:2 = stablehlo.custom_call @byteir.arg_max(%arg0) {byteir_attrs = {axis = 1 : i64, keep_dims = false, select_last_index = false}} : (tensor<3x128xi32>) -> (tensor<3xi32>, tensor<3xi64>)
+  return %0#0, %0#1 : tensor<3xi32>, tensor<3xi64>
+}
@@ -0,0 +1,4 @@
+func.func @byteir.arg_min$return_2(%arg0: tensor<3x128xf32>) -> (tensor<3xf32>, tensor<3xi64>) {
+  %0:2 = stablehlo.custom_call @byteir.arg_min(%arg0) {byteir_attrs = {axis = 1 : i64, keep_dims = false, select_last_index = false}} : (tensor<3x128xf32>) -> (tensor<3xf32>, tensor<3xi64>)
+  return %0#0, %0#1 : tensor<3xf32>, tensor<3xi64>
+}
@@ -0,0 +1,4 @@
+func.func @byteir.arg_min$return_2(%arg0: tensor<3x128xi32>) -> (tensor<3xi32>, tensor<3xi64>) {
+  %0:2 = stablehlo.custom_call @byteir.arg_min(%arg0) {byteir_attrs = {axis = 1 : i64, keep_dims = false, select_last_index = false}} : (tensor<3x128xi32>) -> (tensor<3xi32>, tensor<3xi64>)
+  return %0#0, %0#1 : tensor<3xi32>, tensor<3xi64>
+}
