diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index e85bb0f6b227b9..a3c3aafa7d33bb 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -39,6 +39,54 @@ class ArithConstantOpConversionPattern } }; +/// Return an operation that returns true (in i1) when operand is NaN. +emitc::CmpOp isNan(ConversionPatternRewriter &rewriter, Location loc, + Value operand) { + // A value is NaN exactly when it compares unequal to itself. + return rewriter.create( + loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, operand, operand); +} + +class CmpFOpConversion : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + emitc::CmpPredicate predicate; + switch (op.getPredicate()) { + case arith::CmpFPredicate::UGT: + // unordered or greater than + predicate = emitc::CmpPredicate::gt; + break; + case arith::CmpFPredicate::ULT: + // unordered or less than + predicate = emitc::CmpPredicate::lt; + break; + case arith::CmpFPredicate::UNO: { + // unordered, i.e. either operand is nan + auto lhsIsNan = isNan(rewriter, op.getLoc(), adaptor.getLhs()); + if (adaptor.getLhs() == adaptor.getRhs()) { + rewriter.replaceOp(op, lhsIsNan); + return success(); + } + auto rhsIsNan = isNan(rewriter, op.getLoc(), adaptor.getRhs()); + rewriter.replaceOpWithNewOp(op, op.getType(), lhsIsNan, + rhsIsNan); + return success(); + } + default: + return rewriter.notifyMatchFailure(op.getLoc(), + "cannot match predicate "); + } + + rewriter.replaceOpWithNewOp( + op, op.getType(), predicate, adaptor.getLhs(), adaptor.getRhs()); + return success(); + } +}; + template class ArithOpConversion final : public OpConversionPattern { public: @@ -99,6 +147,7 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter, ArithOpConversion, ArithOpConversion, ArithOpConversion, + CmpFOpConversion, SelectOpConversion >(typeConverter, ctx); // clang-format on diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir index e5f2c330b851c3..c6bbf0a2890380 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir @@ -55,3 +55,16 @@ func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) - %0 = arith.select %arg0, %arg1, %arg2 : i1, tensor<8xi32> return } + +// ----- + +func.func @cmpf(%arg0: f32, %arg1: f32) { + // CHECK: emitc.cmp gt, %arg0, %arg1 : (f32, f32) -> i1 + %0 = arith.cmpf ugt, %arg0, %arg1 : f32 + // CHECK: emitc.cmp lt, %arg0, %arg1 : (f32, f32) -> i1 + %1 = arith.cmpf ult, %arg0, %arg1 : f32 + // CHECK: emitc.cmp ne, %arg0, %arg0 : (f32, f32) -> i1 + %2 = arith.cmpf uno, %arg0, %arg0 : f32 + + return +}