From 58cd2afe551daf39e46e207384265b711cbce3ff Mon Sep 17 00:00:00 2001 From: Or Biri Date: Sun, 28 Apr 2024 15:56:45 +0300 Subject: [PATCH] [CIR] Extend support for floating point attributes This commit extends the support for floating point attributes parsing by using the new `AsmParser::parseFloat(fltSemnatics, APFloat&)` interface. As a drive-by, this commit also harmonizes the cir.fp print/parse namespace usage, and adds the constraint of supporting only "CIRFPType"s for cir.fp in tablegen instead of verifying it manually in the parsing logic. --- .../include/clang/CIR/Dialect/IR/CIRAttrs.td | 25 ++++-- clang/lib/CIR/Dialect/IR/CIRAttrs.cpp | 52 ++++------- clang/test/CIR/IR/attribute.cir | 25 ++++++ clang/test/CIR/IR/float.cir | 90 +++++++++++++++++++ clang/test/CIR/IR/invalid.cir | 59 ++++++++++++ clang/test/CIR/Lowering/class.cir | 2 +- clang/test/CIR/Lowering/struct.cir | 2 +- 7 files changed, 209 insertions(+), 46 deletions(-) create mode 100644 clang/test/CIR/IR/attribute.cir create mode 100644 clang/test/CIR/IR/float.cir diff --git a/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td index a6a27006f357..a81ac5037caa 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td @@ -285,13 +285,20 @@ def FPAttr : CIR_Attr<"FP", "fp", [TypedAttrInterface]> { let summary = "An attribute containing a floating-point value"; let description = [{ An fp attribute is a literal attribute that represents a floating-point - value of the specified floating-point type. + value of the specified floating-point type. Supporting only CIR FP types. }]; - let parameters = (ins AttributeSelfTypeParameter<"">:$type, "APFloat":$value); + let parameters = (ins + AttributeSelfTypeParameter<"", "::mlir::cir::CIRFPTypeInterface">:$type, + APFloatParameter<"">:$value + ); let builders = [ AttrBuilderWithInferredContext<(ins "Type":$type, "const APFloat &":$value), [{ - return $_get(type.getContext(), type, value); + return $_get(type.getContext(), mlir::cast(type), value); + }]>, + AttrBuilder<(ins "Type":$type, + "const APFloat &":$value), [{ + return $_get($_ctxt, mlir::cast(type), value); }]>, ]; let extraClassDeclaration = [{ @@ -319,7 +326,7 @@ def ComplexAttr : CIR_Attr<"Complex", "complex", [TypedAttrInterface]> { contains values of the same CIR type. }]; - let parameters = (ins + let parameters = (ins AttributeSelfTypeParameter<"", "mlir::cir::ComplexType">:$type, "mlir::TypedAttr":$real, "mlir::TypedAttr":$imag); @@ -820,7 +827,7 @@ def AddressSpaceAttr : CIR_Attr<"AddressSpace", "addrspace"> { let extraClassDeclaration = [{ static constexpr char kTargetKeyword[] = "}]#targetASCase.symbol#[{"; static constexpr int32_t kFirstTargetASValue = }]#targetASCase.value#[{; - + bool isLang() const; bool isTarget() const; unsigned getTargetValue() const; @@ -980,7 +987,7 @@ def ASTCallExprAttr : AST<"CallExpr", "call.expr", // VisibilityAttr //===----------------------------------------------------------------------===// -def VK_Default : I32EnumAttrCase<"Default", 1, "default">; +def VK_Default : I32EnumAttrCase<"Default", 1, "default">; def VK_Hidden : I32EnumAttrCase<"Hidden", 2, "hidden">; def VK_Protected : I32EnumAttrCase<"Protected", 3, "protected">; @@ -1013,7 +1020,7 @@ def VisibilityAttr : CIR_Attr<"Visibility", "visibility"> { bool isDefault() const { return getValue() == VisibilityKind::Default; }; bool isHidden() const { return getValue() == VisibilityKind::Hidden; }; bool isProtected() const { return getValue() == VisibilityKind::Protected; }; - }]; + }]; } @@ -1160,7 +1167,7 @@ def AnnotationAttr : CIR_Attr<"Annotation", "annotation"> { let parameters = (ins "StringAttr":$name, "ArrayAttr":$args); - let assemblyFormat = "`<` struct($name, $args) `>`"; + let assemblyFormat = "`<` struct($name, $args) `>`"; let extraClassDeclaration = [{ bool isNoArgs() const { return getArgs().empty(); }; @@ -1187,7 +1194,7 @@ def GlobalAnnotationValuesAttr : CIR_Attr<"GlobalAnnotationValues", void *c __attribute__((annotate("noargvar"))); void foo(int i) __attribute__((annotate("noargfunc"))) {} ``` - After CIR lowering prepare pass, compiler generates a + After CIR lowering prepare pass, compiler generates a `GlobalAnnotationValuesAttr` like the following: ``` #cir &value, mlir::Type ty); + mlir::FailureOr &value, + mlir::cir::CIRFPTypeInterface fpType); static mlir::ParseResult parseConstPtr(mlir::AsmParser &parser, mlir::IntegerAttr &value); @@ -311,50 +312,31 @@ LogicalResult IntAttr::verify(function_ref emitError, // FPAttr definitions //===----------------------------------------------------------------------===// -static void printFloatLiteral(mlir::AsmPrinter &p, llvm::APFloat value, - mlir::Type ty) { +static void printFloatLiteral(AsmPrinter &p, APFloat value, Type ty) { p << value; } -static mlir::ParseResult -parseFloatLiteral(mlir::AsmParser &parser, - mlir::FailureOr &value, mlir::Type ty) { - double rawValue; - if (parser.parseFloat(rawValue)) { - return parser.emitError(parser.getCurrentLocation(), - "expected floating-point value"); - } - - auto losesInfo = false; - value.emplace(rawValue); +static ParseResult parseFloatLiteral(AsmParser &parser, + FailureOr &value, + CIRFPTypeInterface fpType) { - auto tyFpInterface = dyn_cast(ty); - if (!tyFpInterface) { - // Parsing of the current floating-point literal has succeeded, but the - // given attribute type is invalid. This error will be reported later when - // the attribute is being verified. - return success(); - } + APFloat parsedValue(0.0); + if (parser.parseFloat(fpType.getFloatSemantics(), parsedValue)) + return failure(); - value->convert(tyFpInterface.getFloatSemantics(), - llvm::RoundingMode::TowardZero, &losesInfo); + value.emplace(parsedValue); return success(); } -cir::FPAttr cir::FPAttr::getZero(mlir::Type type) { - return get( - type, APFloat::getZero( - mlir::cast(type).getFloatSemantics())); +FPAttr FPAttr::getZero(Type type) { + return get(type, + APFloat::getZero( + mlir::cast(type).getFloatSemantics())); } -LogicalResult cir::FPAttr::verify(function_ref emitError, - Type type, APFloat value) { - auto fltTypeInterface = mlir::dyn_cast(type); - if (!fltTypeInterface) { - emitError() << "expected floating-point type"; - return failure(); - } - if (APFloat::SemanticsToEnum(fltTypeInterface.getFloatSemantics()) != +LogicalResult FPAttr::verify(function_ref emitError, + CIRFPTypeInterface fpType, APFloat value) { + if (APFloat::SemanticsToEnum(fpType.getFloatSemantics()) != APFloat::SemanticsToEnum(value.getSemantics())) { emitError() << "floating-point semantics mismatch"; return failure(); diff --git a/clang/test/CIR/IR/attribute.cir b/clang/test/CIR/IR/attribute.cir new file mode 100644 index 000000000000..4c9d4083ad4a --- /dev/null +++ b/clang/test/CIR/IR/attribute.cir @@ -0,0 +1,25 @@ +// RUN: cir-opt %s -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s + +cir.func @float_attrs_pass() { + "test.float_attrs"() { + // CHECK: float_attr = #cir.fp<2.000000e+00> : !cir.float + float_attr = #cir.fp<2.> : !cir.float + } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = #cir.fp<-2.000000e+00> : !cir.float + float_attr = #cir.fp<-2.> : !cir.float + } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = #cir.fp<2.000000e+00> : !cir.double + float_attr = #cir.fp<2.> : !cir.double + } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = #cir.fp<2.000000e+00> : !cir.long_double + float_attr = #cir.fp<2.> : !cir.long_double + } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = #cir.fp<2.000000e+00> : !cir.long_double + float_attr = #cir.fp<2.> : !cir.long_double + } : () -> () + cir.return +} \ No newline at end of file diff --git a/clang/test/CIR/IR/float.cir b/clang/test/CIR/IR/float.cir new file mode 100644 index 000000000000..1be52c339ad5 --- /dev/null +++ b/clang/test/CIR/IR/float.cir @@ -0,0 +1,90 @@ +// RUN: cir-opt %s | FileCheck %s + +// Adapted from mlir/test/IR/parser.mlir + +// CHECK-LABEL: @f32_special_values +cir.func @f32_special_values() { + // F32 signaling NaNs. + // CHECK: cir.const #cir.fp<0x7F800001> : !cir.float + %0 = cir.const #cir.fp<0x7F800001> : !cir.float + // CHECK: cir.const #cir.fp<0x7FBFFFFF> : !cir.float + %1 = cir.const #cir.fp<0x7FBFFFFF> : !cir.float + + // F32 quiet NaNs. + // CHECK: cir.const #cir.fp<0x7FC00000> : !cir.float + %2 = cir.const #cir.fp<0x7FC00000> : !cir.float + // CHECK: cir.const #cir.fp<0xFFFFFFFF> : !cir.float + %3 = cir.const #cir.fp<0xFFFFFFFF> : !cir.float + + // F32 positive infinity. + // CHECK: cir.const #cir.fp<0x7F800000> : !cir.float + %4 = cir.const #cir.fp<0x7F800000> : !cir.float + // F32 negative infinity. + // CHECK: cir.const #cir.fp<0xFF800000> : !cir.float + %5 = cir.const #cir.fp<0xFF800000> : !cir.float + + cir.return +} + +// CHECK-LABEL: @f64_special_values +cir.func @f64_special_values() { + // F64 signaling NaNs. + // CHECK: cir.const #cir.fp<0x7FF0000000000001> : !cir.double + %0 = cir.const #cir.fp<0x7FF0000000000001> : !cir.double + // CHECK: cir.const #cir.fp<0x7FF8000000000000> : !cir.double + %1 = cir.const #cir.fp<0x7FF8000000000000> : !cir.double + + // F64 quiet NaNs. + // CHECK: cir.const #cir.fp<0x7FF0000001000000> : !cir.double + %2 = cir.const #cir.fp<0x7FF0000001000000> : !cir.double + // CHECK: cir.const #cir.fp<0xFFF0000001000000> : !cir.double + %3 = cir.const #cir.fp<0xFFF0000001000000> : !cir.double + + // F64 positive infinity. + // CHECK: cir.const #cir.fp<0x7FF0000000000000> : !cir.double + %4 = cir.const #cir.fp<0x7FF0000000000000> : !cir.double + // F64 negative infinity. + // CHECK: cir.const #cir.fp<0xFFF0000000000000> : !cir.double + %5 = cir.const #cir.fp<0xFFF0000000000000> : !cir.double + + // Check that values that can't be represented with the default format, use + // hex instead. + // CHECK: cir.const #cir.fp<0xC1CDC00000000000> : !cir.double + %6 = cir.const #cir.fp<0xC1CDC00000000000> : !cir.double + + cir.return +} + +// CHECK-LABEL: @f80_special_values +cir.func @f80_special_values() { + // F80 signaling NaNs. + // CHECK: cir.const #cir.fp<0x7FFFE000000000000001> : !cir.long_double + %0 = cir.const #cir.fp<0x7FFFE000000000000001> : !cir.long_double + // CHECK: cir.const #cir.fp<0x7FFFB000000000000011> : !cir.long_double + %1 = cir.const #cir.fp<0x7FFFB000000000000011> : !cir.long_double + + // F80 quiet NaNs. + // CHECK: cir.const #cir.fp<0x7FFFC000000000100000> : !cir.long_double + %2 = cir.const #cir.fp<0x7FFFC000000000100000> : !cir.long_double + // CHECK: cir.const #cir.fp<0x7FFFE000000001000000> : !cir.long_double + %3 = cir.const #cir.fp<0x7FFFE000000001000000> : !cir.long_double + + // F80 positive infinity. + // CHECK: cir.const #cir.fp<0x7FFF8000000000000000> : !cir.long_double + %4 = cir.const #cir.fp<0x7FFF8000000000000000> : !cir.long_double + // F80 negative infinity. + // CHECK: cir.const #cir.fp<0xFFFF8000000000000000> : !cir.long_double + %5 = cir.const #cir.fp<0xFFFF8000000000000000> : !cir.long_double + + cir.return +} + +// We want to print floats in exponential notation with 6 significant digits, +// but it may lead to precision loss when parsing back, in which case we print +// the decimal form instead. +// CHECK-LABEL: @f32_potential_precision_loss() +cir.func @f32_potential_precision_loss() { + // CHECK: cir.const #cir.fp<1.23697901> : !cir.float + %0 = cir.const #cir.fp<1.23697901> : !cir.float + cir.return +} diff --git a/clang/test/CIR/IR/invalid.cir b/clang/test/CIR/IR/invalid.cir index 9df6e0c858fb..316ac3080797 100644 --- a/clang/test/CIR/IR/invalid.cir +++ b/clang/test/CIR/IR/invalid.cir @@ -1378,3 +1378,62 @@ module { cir.return } } +// ----- + +// Type of the attribute must be a CIR floating point type + +// expected-error @below {{invalid kind of type specified}} +cir.global external @f = #cir.fp<0.5> : !cir.int + +// ----- + +// Value must be a floating point literal or integer literal + +// expected-error @below {{expected floating point literal}} +cir.global external @f = #cir.fp<"blabla"> : !cir.float + +// ----- + +// Integer value must be in the width of the floating point type + +// expected-error @below {{hexadecimal float constant out of range for type}} +cir.global external @f = #cir.fp<0x7FC000000> : !cir.float + +// ----- + +// Integer value must be in the width of the floating point type + +// expected-error @below {{hexadecimal float constant out of range for type}} +cir.global external @f = #cir.fp<0x7FC000007FC0000000> : !cir.double + +// ----- + +// Integer value must be in the width of the floating point type + +// expected-error @below {{hexadecimal float constant out of range for type}} +cir.global external @f = #cir.fp<0x7FC0000007FC0000007FC000000> : !cir.long_double + +// ----- + +// Long double with `double` semnatics should have a value that fits in a double. + +// CHECK: cir.global external @f = #cir.fp<0x7FC000007FC000000000> : !cir.long_double +cir.global external @f = #cir.fp<0x7FC000007FC000000000> : !cir.long_double + +// expected-error @below {{hexadecimal float constant out of range for type}} +cir.global external @f = #cir.fp<0x7FC000007FC000000000> : !cir.long_double + +// ----- + +// Verify no need for type inside the attribute + +// expected-error @below {{expected '>'}} +cir.global external @f = #cir.fp<0x7FC00000 : !cir.float> : !cir.float + +// ----- + +// Verify literal must be hex or float + +// expected-error @below {{unexpected decimal integer literal for a floating point value}} +// expected-note @below {{add a trailing dot to make the literal a float}} +cir.global external @f = #cir.fp<42> : !cir.float diff --git a/clang/test/CIR/Lowering/class.cir b/clang/test/CIR/Lowering/class.cir index dd028f4c3b7d..4f0c25151179 100644 --- a/clang/test/CIR/Lowering/class.cir +++ b/clang/test/CIR/Lowering/class.cir @@ -44,7 +44,7 @@ module { // CHECK: %0 = llvm.mlir.undef : !llvm.struct<"class.S1", (i32, f32, ptr)> // CHECK: %1 = llvm.mlir.constant(1 : i32) : i32 // CHECK: %2 = llvm.insertvalue %1, %0[0] : !llvm.struct<"class.S1", (i32, f32, ptr)> - // CHECK: %3 = llvm.mlir.constant(0.099999994 : f32) : f32 + // CHECK: %3 = llvm.mlir.constant(1.000000e-01 : f32) : f32 // CHECK: %4 = llvm.insertvalue %3, %2[1] : !llvm.struct<"class.S1", (i32, f32, ptr)> // CHECK: %5 = llvm.mlir.zero : !llvm.ptr // CHECK: %6 = llvm.insertvalue %5, %4[2] : !llvm.struct<"class.S1", (i32, f32, ptr)> diff --git a/clang/test/CIR/Lowering/struct.cir b/clang/test/CIR/Lowering/struct.cir index a1a3d352c8a1..c89a58a9772e 100644 --- a/clang/test/CIR/Lowering/struct.cir +++ b/clang/test/CIR/Lowering/struct.cir @@ -44,7 +44,7 @@ module { // CHECK: %0 = llvm.mlir.undef : !llvm.struct<"struct.S1", (i32, f32, ptr)> // CHECK: %1 = llvm.mlir.constant(1 : i32) : i32 // CHECK: %2 = llvm.insertvalue %1, %0[0] : !llvm.struct<"struct.S1", (i32, f32, ptr)> - // CHECK: %3 = llvm.mlir.constant(0.099999994 : f32) : f32 + // CHECK: %3 = llvm.mlir.constant(1.000000e-01 : f32) : f32 // CHECK: %4 = llvm.insertvalue %3, %2[1] : !llvm.struct<"struct.S1", (i32, f32, ptr)> // CHECK: %5 = llvm.mlir.zero : !llvm.ptr // CHECK: %6 = llvm.insertvalue %5, %4[2] : !llvm.struct<"struct.S1", (i32, f32, ptr)>