Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FXML-4614] Add EmitC index types, lower arith.index_cast, arith.index_castui #183

Merged
merged 27 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
f8c754c
Add EmitC types for indices
cferry-AMD May 7, 2024
8468027
TranslateToCpp emit types
cferry-AMD May 10, 2024
5803f48
Add type converter
cferry-AMD May 10, 2024
49b27f6
Draft: Type converter & tests
cferry-AMD May 10, 2024
3578257
Fix missing symbol dep
cferry-AMD May 13, 2024
1e4dbf1
Issue the new type correctly
cferry-AMD May 13, 2024
bf70671
Play around with type converter, attempting to convert value
cferry-AMD May 13, 2024
de9fff6
Don't convert value, add hack to accept diverging attr type
cferry-AMD May 14, 2024
6a04667
Rename types
cferry-AMD May 14, 2024
590611d
Incorporate trunci upstreaming changes
cferry-AMD May 14, 2024
9639d1f
Add index_cast support, add/sub/mul index ops broken
cferry-AMD May 14, 2024
dd54066
Convert types of constants, fix arith ops test
cferry-AMD May 14, 2024
fae6c49
Remove hack comment, proper way
cferry-AMD May 14, 2024
8dc639c
handle cmpi ops
cferry-AMD May 14, 2024
76d21c1
Add predicates to filter only index types
cferry-AMD May 14, 2024
85a8255
Fix error messages
cferry-AMD May 14, 2024
6bcb8f6
Remove type converter test, already tested in ArithToEmitC
cferry-AMD May 14, 2024
e86a2dc
Review comments, try to fix linker issue on CI
cferry-AMD May 14, 2024
9e7be39
Remove EmitC_SizeType
cferry-AMD May 14, 2024
378a55b
ArithToEmitC depends on EmitCTransforms
cferry-AMD May 15, 2024
9cec233
Factor in type signedness & value type adaptation
cferry-AMD May 15, 2024
59ae2c9
Remove redundant type checks, EmitCTypes includes our new types
cferry-AMD May 15, 2024
c93adb3
Comments
cferry-AMD May 15, 2024
a9348c3
Nits
cferry-AMD May 15, 2024
78f8460
Review comments
cferry-AMD May 15, 2024
ff6dcb8
Add cast op folder
cferry-AMD May 15, 2024
e284808
Handle to-bool cast conversions
cferry-AMD May 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ bool isIntegerIndexOrOpaqueType(Type type);

/// Determines whether \p type is a valid floating-point type in EmitC.
bool isSupportedFloatType(mlir::Type type);

/// Determines whether \p type is a emitc.size_t/ssize_t type.
bool isAnySizeTType(mlir::Type type);

} // namespace emitc
} // namespace mlir

Expand Down
6 changes: 4 additions & 2 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ class EmitC_BinaryOp<string mnemonic, list<Trait> traits = []> :
def CExpression : NativeOpTrait<"emitc::CExpression">;

// Types only used in binary arithmetic operations.
def IntegerIndexOrOpaqueType : AnyTypeOf<[EmitCIntegerType, Index, EmitC_OpaqueType]>;
def IntegerIndexOrOpaqueType : AnyTypeOf<[EmitCIntegerType, Index,
EmitC_SignedSizeT, EmitC_SizeT, EmitC_OpaqueType]>;
def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[EmitCFloatType, IntegerIndexOrOpaqueType]>;

def EmitC_AddOp : EmitC_BinaryOp<"add", [CExpression]> {
Expand Down Expand Up @@ -287,6 +288,7 @@ def EmitC_CastOp : EmitC_Op<"cast",
let arguments = (ins EmitCType:$source);
let results = (outs EmitCType:$dest);
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
let hasFolder = 1;
}

def EmitC_CmpOp : EmitC_BinaryOp<"cmp", [CExpression]> {
Expand Down Expand Up @@ -470,7 +472,7 @@ def EmitC_ForOp : EmitC_Op<"for",
upper bound and step respectively, and defines an SSA value for its
induction variable. It has one region capturing the loop body. The induction
variable is represented as an argument of this region. This SSA value is a
signless integer or index. The step is a value of same type.
signless integer, or an index. The step is a value of same type.

This operation has no result. The body region must contain exactly one block
that terminates with `emitc.yield`. Calling ForOp::build will create such a
Expand Down
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -130,4 +130,12 @@ def EmitC_PointerType : EmitC_Type<"Pointer", "ptr"> {
let assemblyFormat = "`<` qualified($pointee) `>`";
}

def EmitC_SignedSizeT : EmitC_Type<"SignedSizeT", "ssize_t"> {
let summary = "EmitC signed size type";
}

def EmitC_SizeT : EmitC_Type<"SizeT", "size_t"> {
let summary = "EmitC unsigned size type";
}

#endif // MLIR_DIALECT_EMITC_IR_EMITCTYPES
13 changes: 13 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/Transforms/TypeConversions.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
//===- TypeConversions.h - Convert signless types into C/C++ types -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
void populateEmitCSizeTypeConversionPatterns(mlir::TypeConverter &converter);
} // namespace mlir
160 changes: 90 additions & 70 deletions mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"
Expand All @@ -36,8 +37,11 @@ class ArithConstantOpConversionPattern
matchAndRewrite(arith::ConstantOp arithConst,
arith::ConstantOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<emitc::ConstantOp>(
arithConst, arithConst.getType(), adaptor.getValue());
Type newTy = this->getTypeConverter()->convertType(arithConst.getType());
cferry-AMD marked this conversation as resolved.
Show resolved Hide resolved
if (!newTy)
return rewriter.notifyMatchFailure(arithConst, "type conversion failed");
rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, newTy,
adaptor.getValue());
return success();
}
};
Expand Down Expand Up @@ -201,6 +205,35 @@ class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
}
};

/// Check if the signedness of type \p ty matches the expected
/// signedness, and issue a type with the correct signedness if
/// necessary.
Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {
if (isa<IntegerType>(ty)) {
// Turns signless integers into signed integers.
if (ty.isUnsignedInteger() != needsUnsigned) {
auto signedness = needsUnsigned
? IntegerType::SignednessSemantics::Unsigned
: IntegerType::SignednessSemantics::Signed;
return IntegerType::get(ty.getContext(), ty.getIntOrFloatBitWidth(),
signedness);
}
} else if (emitc::isAnySizeTType(ty)) {
if (isa<emitc::SizeTType>(ty) != needsUnsigned) {
if (needsUnsigned)
return emitc::SizeTType::get(ty.getContext());
return emitc::SignedSizeTType::get(ty.getContext());
}
}
return ty;
}

/// Insert a cast operation to type \p ty if \p val
/// does not have this type.
Value adaptValueType(Value val, ConversionPatternRewriter &rewriter, Type ty) {
cferry-AMD marked this conversation as resolved.
Show resolved Hide resolved
return rewriter.createOrFold<emitc::CastOp>(val.getLoc(), ty, val);
}

class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
public:
using OpConversionPattern::OpConversionPattern;
Expand Down Expand Up @@ -250,31 +283,25 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
ConversionPatternRewriter &rewriter) const override {

Type type = adaptor.getLhs().getType();
if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
return rewriter.notifyMatchFailure(op, "expected integer or index type");
if (!isa_and_nonnull<IntegerType, emitc::SignedSizeTType, emitc::SizeTType>(
type)) {
return rewriter.notifyMatchFailure(
op, "expected integer or size_t/ssize_t type");
}

bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
emitc::CmpPredicate pred = toEmitCPred(op.getPredicate());
Type arithmeticType = type;
if (type.isUnsignedInteger() != needsUnsigned) {
arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
/*isSigned=*/!needsUnsigned);
}
Value lhs = adaptor.getLhs();
Value rhs = adaptor.getRhs();
if (arithmeticType != type) {
lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
lhs);
rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
rhs);
}

Type arithmeticType = adaptIntegralTypeSignedness(type, needsUnsigned);
Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
cferry-AMD marked this conversation as resolved.
Show resolved Hide resolved
Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);

rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs);
return success();
}
};

template <typename ArithOp, bool needsUnsigned>
template <typename ArithOp, bool castToUnsigned>
class CastConversion : public OpConversionPattern<ArithOp> {
public:
using OpConversionPattern<ArithOp>::OpConversionPattern;
Expand All @@ -284,52 +311,42 @@ class CastConversion : public OpConversionPattern<ArithOp> {
ConversionPatternRewriter &rewriter) const override {

Type opReturnType = this->getTypeConverter()->convertType(op.getType());
if (!isa_and_nonnull<IntegerType>(opReturnType)) {
return rewriter.notifyMatchFailure(op, "expected integer result type");
}
if (!isa_and_nonnull<IntegerType, emitc::SignedSizeTType, emitc::SizeTType>(
opReturnType))
return rewriter.notifyMatchFailure(
op, "expected integer or size_t/ssize_t result type");

if (adaptor.getOperands().size() != 1) {
return rewriter.notifyMatchFailure(
op, "CastConversion only supports unary ops");
}

Type operandType = adaptor.getIn().getType();
if (!isa_and_nonnull<IntegerType>(operandType)) {
return rewriter.notifyMatchFailure(op, "expected integer operand type");
}
if (!isa_and_nonnull<IntegerType, emitc::SignedSizeTType, emitc::SizeTType>(
operandType))
return rewriter.notifyMatchFailure(
op, "expected integer or size_t/ssize_t operand type");

bool isTruncation = operandType.getIntOrFloatBitWidth() >
opReturnType.getIntOrFloatBitWidth();
bool doUnsigned = needsUnsigned || isTruncation;

Type castType = opReturnType;
// For int conversions: if the op is a ui variant and the type wanted as
// return type isn't unsigned, we need to issue an unsigned type to do
// the conversion.
if (castType.isUnsignedInteger() != doUnsigned) {
castType = rewriter.getIntegerType(opReturnType.getIntOrFloatBitWidth(),
/*isSigned=*/!doUnsigned);
}
bool isTruncation =
mgehre-amd marked this conversation as resolved.
Show resolved Hide resolved
(isa<IntegerType>(operandType) && isa<IntegerType>(opReturnType) &&
operandType.getIntOrFloatBitWidth() >
opReturnType.getIntOrFloatBitWidth());
bool doUnsigned = castToUnsigned || isTruncation;

Value actualOp = adaptor.getIn();
// Fix the signedness of the operand if necessary
if (operandType.isUnsignedInteger() != doUnsigned) {
Type correctSignednessType =
rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
/*isSigned=*/!doUnsigned);
actualOp = rewriter.template create<emitc::CastOp>(
op.getLoc(), correctSignednessType, actualOp);
}
// Adapt the signedness of the result (bitwidth-preserving cast)
// This is needed e.g., if the return type is signless.
Type castDestType = adaptIntegralTypeSignedness(opReturnType, doUnsigned);

auto result = rewriter.template create<emitc::CastOp>(op.getLoc(), castType,
actualOp);
// Adapt the signedness of the operand (bitwidth-preserving cast)
Type castSrcType = adaptIntegralTypeSignedness(operandType, doUnsigned);
Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType);

// Fix the signedness of what this operation returns (for integers,
// the arith ops want signless results)
if (castType != opReturnType) {
result = rewriter.template create<emitc::CastOp>(op.getLoc(),
opReturnType, result);
}
// Actual cast (may change bitwidth)
auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(),
castDestType, actualOp);

// Cast to the expected output type
auto result = adaptValueType(cast, rewriter, opReturnType);

rewriter.replaceOp(op, result);
return success();
Expand All @@ -355,7 +372,11 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, arithOp.getType(),
Type newTy = this->getTypeConverter()->convertType(arithOp.getType());
if (!newTy)
return rewriter.notifyMatchFailure(arithOp,
"converting result type failed");
rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, newTy,
adaptor.getOperands());

return success();
Expand All @@ -372,17 +393,17 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
ConversionPatternRewriter &rewriter) const override {

Type type = this->getTypeConverter()->convertType(op.getType());
if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
return rewriter.notifyMatchFailure(op, "expected integer type");
if (!isa_and_nonnull<IntegerType, emitc::SignedSizeTType, emitc::SizeTType>(
type)) {
return rewriter.notifyMatchFailure(
op, "expected integer or size_t/ssize_t type");
}

if (type.isInteger(1)) {
// arith expects wrap-around arithmethic, which doesn't happen on `bool`.
return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
}

Value lhs = adaptor.getLhs();
Value rhs = adaptor.getRhs();
Type arithmeticType = type;
if ((type.isSignlessInteger() || type.isSignedInteger()) &&
!bitEnumContainsAll(op.getOverflowFlags(),
Expand All @@ -392,20 +413,15 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
/*isSigned=*/false);
}
if (arithmeticType != type) {
lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
lhs);
rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
rhs);
}

Value result = rewriter.template create<EmitCOp>(op.getLoc(),
arithmeticType, lhs, rhs);
Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);

Value arithmeticResult = rewriter.template create<EmitCOp>(
op.getLoc(), arithmeticType, lhs, rhs);

Value result = adaptValueType(arithmeticResult, rewriter, type);

if (arithmeticType != type) {
result =
rewriter.template create<emitc::CastOp>(op.getLoc(), type, result);
}
rewriter.replaceOp(op, result);
return success();
}
Expand Down Expand Up @@ -535,6 +551,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns) {
MLIRContext *ctx = patterns.getContext();

mlir::populateEmitCSizeTypeConversionPatterns(typeConverter);

// clang-format off
patterns.add<
ArithConstantOpConversionPattern,
Expand All @@ -554,6 +572,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
UnsignedCastConversion<arith::TruncIOp>,
SignedCastConversion<arith::ExtSIOp>,
UnsignedCastConversion<arith::ExtUIOp>,
SignedCastConversion<arith::IndexCastOp>,
UnsignedCastConversion<arith::IndexCastUIOp>,
ItoFCastOpConversion<arith::SIToFPOp>,
ItoFCastOpConversion<arith::UIToFPOp>,
FtoICastOpConversion<arith::FPToSIOp>,
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRArithToEmitC
LINK_LIBS PUBLIC
MLIRArithDialect
MLIREmitCDialect
MLIREmitCTransforms
MLIRPass
MLIRTransformUtils
)
31 changes: 25 additions & 6 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ bool mlir::emitc::isSupportedEmitCType(Type type) {
return !llvm::isa<emitc::ArrayType>(elemType) &&
isSupportedEmitCType(elemType);
}
if (type.isIndex())
if (type.isIndex() ||
llvm::isa<emitc::SignedSizeTType, emitc::SizeTType>(type))
return true;
if (llvm::isa<IntegerType>(type))
return isSupportedIntegerType(type);
Expand Down Expand Up @@ -109,7 +110,8 @@ bool mlir::emitc::isSupportedIntegerType(Type type) {
}

bool mlir::emitc::isIntegerIndexOrOpaqueType(Type type) {
return llvm::isa<IndexType, emitc::OpaqueType>(type) ||
return llvm::isa<IndexType, emitc::SignedSizeTType, emitc::SizeTType,
emitc::OpaqueType>(type) ||
isSupportedIntegerType(type);
}

Expand All @@ -126,6 +128,10 @@ bool mlir::emitc::isSupportedFloatType(Type type) {
return false;
}

bool mlir::emitc::isAnySizeTType(Type type) {
return isa<emitc::SignedSizeTType, emitc::SizeTType>(type);
}

/// Check that the type of the initial value is compatible with the operations
/// result type.
static LogicalResult verifyInitializationAttribute(Operation *op,
Expand All @@ -142,6 +148,10 @@ static LogicalResult verifyInitializationAttribute(Operation *op,
Type resultType = op->getResult(0).getType();
Type attrType = cast<TypedAttr>(value).getType();

if (isa<emitc::SignedSizeTType, emitc::SizeTType>(resultType) &&
attrType.isIndex())
return success();

if (resultType != attrType)
return op->emitOpError()
<< "requires attribute to either be an #emitc.opaque attribute or "
Expand Down Expand Up @@ -226,10 +236,19 @@ LogicalResult emitc::AssignOp::verify() {
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
Type input = inputs.front(), output = outputs.front();

return ((llvm::isa<IntegerType, FloatType, IndexType, emitc::OpaqueType,
emitc::PointerType>(input)) &&
(llvm::isa<IntegerType, FloatType, IndexType, emitc::OpaqueType,
emitc::PointerType>(output)));
return (
(llvm::isa<IntegerType, FloatType, IndexType, emitc::OpaqueType,
emitc::PointerType, emitc::SignedSizeTType, emitc::SizeTType>(
input)) &&
(llvm::isa<IntegerType, FloatType, IndexType, emitc::OpaqueType,
emitc::PointerType, emitc::SignedSizeTType, emitc::SizeTType>(
output)));
}

OpFoldResult emitc::CastOp::fold(FoldAdaptor adaptor) {
if (getOperand().getType() == getResult().getType())
return getOperand();
return nullptr;
}

//===----------------------------------------------------------------------===//
Expand Down
Loading