Skip to content

Commit

Permalink
[FXML-4320][mlir][emitc] Restrict integer and float types (#143)
Browse files Browse the repository at this point in the history
Restrict which integers types and floating point types are valid in EmitC. This should cover the types which are supported in C++ and is aligned with what the emitter currently supports.
  • Loading branch information
TinaAMD authored Mar 19, 2024
1 parent 31b7dc7 commit e36f02b
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 6 deletions.
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
namespace mlir {
namespace emitc {
void buildTerminatedBody(OpBuilder &builder, Location loc);
/// Determines whether \p type is a valid integer type in EmitC.
bool isValidEmitCIntegerType(mlir::Type type);
/// Determines whether \p type is a valid floating-point type in EmitC.
bool isValidEmitCFloatType(mlir::Type type);
} // namespace emitc
} // namespace mlir

Expand Down
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ class EmitC_BinaryOp<string mnemonic, list<Trait> traits = []> :
def CExpression : NativeOpTrait<"emitc::CExpression">;

// Types only used in binary arithmetic operations.
def IntegerIndexOrOpaqueType : AnyTypeOf<[AnyInteger, Index, EmitC_OpaqueType]>;
def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[AnyFloat, IntegerIndexOrOpaqueType]>;
def IntegerIndexOrOpaqueType : AnyTypeOf<[Valid_EmitC_Integer_Type, Index, EmitC_OpaqueType]>;
def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[Valid_EmitC_Float_Type, IntegerIndexOrOpaqueType]>;

def EmitC_AddOp : EmitC_BinaryOp<"add", [CExpression]> {
let summary = "Addition operation";
Expand Down
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ include "mlir/IR/BuiltinTypeInterfaces.td"
// EmitC type definitions
//===----------------------------------------------------------------------===//

def Valid_EmitC_Integer_Type : Type<CPred<"emitc::isValidEmitCIntegerType($_self)">,
"EmitC integer type">;

def Valid_EmitC_Float_Type : Type<CPred<"emitc::isValidEmitCFloatType($_self)">,
"EmitC floating-point type">;

class EmitC_Type<string name, string typeMnemonic, list<Trait> traits = []>
: TypeDef<EmitC_Dialect, name, traits> {
let mnemonic = typeMnemonic;
Expand Down
29 changes: 29 additions & 0 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,35 @@ void mlir::emitc::buildTerminatedBody(OpBuilder &builder, Location loc) {
builder.create<emitc::YieldOp>(loc);
}

bool mlir::emitc::isValidEmitCIntegerType(Type type) {
if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
switch (intType.getWidth()) {
case 1:
case 8:
case 16:
case 32:
case 64:
return true;
default:
return false;
}
}
return false;
}

bool mlir::emitc::isValidEmitCFloatType(Type type) {
if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
switch (floatType.getWidth()) {
case 32:
case 64:
return true;
default:
return false;
}
}
return false;
}

/// Check that the type of the initial value is compatible with the operations
/// result type.
static LogicalResult verifyInitializationAttribute(Operation *op,
Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Dialect/EmitC/invalid_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -170,31 +170,31 @@ func.func @add_float_pointer(%arg0: f32, %arg1: !emitc.ptr<f32>) {
// -----

func.func @div_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
// expected-error @+1 {{'emitc.div' op operand #0 must be floating-point or integer or index or EmitC opaque type, but got 'tensor<i32>'}}
// expected-error @+1 {{'emitc.div' op operand #0 must be EmitC floating-point type or EmitC integer type or index or EmitC opaque type, but got 'tensor<i32>'}}
%1 = "emitc.div" (%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
return
}

// -----

func.func @mul_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
// expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point or integer or index or EmitC opaque type, but got 'tensor<i32>'}}
// expected-error @+1 {{'emitc.mul' op operand #0 must be EmitC floating-point type or EmitC integer type or index or EmitC opaque type, but got 'tensor<i32>'}}
%1 = "emitc.mul" (%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
return
}

// -----

func.func @rem_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
// expected-error @+1 {{'emitc.rem' op operand #0 must be integer or index or EmitC opaque type, but got 'tensor<i32>'}}
// expected-error @+1 {{'emitc.rem' op operand #0 must be EmitC integer type or index or EmitC opaque type, but got 'tensor<i32>'}}
%1 = "emitc.rem" (%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
return
}

// -----

func.func @rem_float(%arg0: f32, %arg1: f32) {
// expected-error @+1 {{'emitc.rem' op operand #0 must be integer or index or EmitC opaque type, but got 'f32'}}
// expected-error @+1 {{'emitc.rem' op operand #0 must be EmitC integer type or index or EmitC opaque type, but got 'f32'}}
%1 = "emitc.rem" (%arg0, %arg1) : (f32, f32) -> f32
return
}
Expand Down

0 comments on commit e36f02b

Please sign in to comment.