Skip to content

Commit

Permalink
[CIR] Extend support for floating point attributes
Browse files Browse the repository at this point in the history
This commit extends the support for floating point attributes parsing by using
the new `AsmParser::parseFloat(fltSemnatics, APFloat&)` interface.
As a drive-by, this commit also harmonizes the cir.fp print/parse namespace
usage, and adds the constraint of supporting only "CIRFPType"s for cir.fp in
tablegen instead of verifying it manually in the parsing logic.
  • Loading branch information
orbiri committed Nov 2, 2024
1 parent 3ef67c1 commit 2a65d71
Show file tree
Hide file tree
Showing 7 changed files with 202 additions and 40 deletions.
13 changes: 10 additions & 3 deletions clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -270,13 +270,20 @@ def FPAttr : CIR_Attr<"FP", "fp", [TypedAttrInterface]> {
let summary = "An attribute containing a floating-point value";
let description = [{
An fp attribute is a literal attribute that represents a floating-point
value of the specified floating-point type.
value of the specified floating-point type. Supporting only CIR FP types.
}];
let parameters = (ins AttributeSelfTypeParameter<"">:$type, "APFloat":$value);
let parameters = (ins
AttributeSelfTypeParameter<"", "::mlir::cir::CIRFPTypeInterface">:$type,
APFloatParameter<"">:$value
);
let builders = [
AttrBuilderWithInferredContext<(ins "Type":$type,
"const APFloat &":$value), [{
return $_get(type.getContext(), type, value);
return $_get(type.getContext(), type.cast<CIRFPTypeInterface>(), value);
}]>,
AttrBuilder<(ins "Type":$type,
"const APFloat &":$value), [{
return $_get($_ctxt, type.cast<CIRFPTypeInterface>(), value);
}]>,
];
let extraClassDeclaration = [{
Expand Down
51 changes: 16 additions & 35 deletions clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ static void printFloatLiteral(mlir::AsmPrinter &p, llvm::APFloat value,
mlir::Type ty);
static mlir::ParseResult
parseFloatLiteral(mlir::AsmParser &parser,
mlir::FailureOr<llvm::APFloat> &value, mlir::Type ty);
mlir::FailureOr<llvm::APFloat> &value,
mlir::cir::CIRFPTypeInterface fpType);

static mlir::ParseResult parseConstPtr(mlir::AsmParser &parser,
mlir::IntegerAttr &value);
Expand Down Expand Up @@ -311,50 +312,30 @@ LogicalResult IntAttr::verify(function_ref<InFlightDiagnostic()> emitError,
// FPAttr definitions
//===----------------------------------------------------------------------===//

static void printFloatLiteral(mlir::AsmPrinter &p, llvm::APFloat value,
mlir::Type ty) {
static void printFloatLiteral(AsmPrinter &p, APFloat value, Type ty) {
p << value;
}

static mlir::ParseResult
parseFloatLiteral(mlir::AsmParser &parser,
mlir::FailureOr<llvm::APFloat> &value, mlir::Type ty) {
double rawValue;
if (parser.parseFloat(rawValue)) {
return parser.emitError(parser.getCurrentLocation(),
"expected floating-point value");
}

auto losesInfo = false;
value.emplace(rawValue);
static ParseResult parseFloatLiteral(AsmParser &parser,
FailureOr<APFloat> &value,
CIRFPTypeInterface fpType) {

auto tyFpInterface = dyn_cast<cir::CIRFPTypeInterface>(ty);
if (!tyFpInterface) {
// Parsing of the current floating-point literal has succeeded, but the
// given attribute type is invalid. This error will be reported later when
// the attribute is being verified.
return success();
}
APFloat parsedValue(0.0);
if (parser.parseFloat(fpType.getFloatSemantics(), parsedValue))
return failure();

value->convert(tyFpInterface.getFloatSemantics(),
llvm::RoundingMode::TowardZero, &losesInfo);
value.emplace(parsedValue);
return success();
}

cir::FPAttr cir::FPAttr::getZero(mlir::Type type) {
return get(
type, APFloat::getZero(
mlir::cast<cir::CIRFPTypeInterface>(type).getFloatSemantics()));
FPAttr FPAttr::getZero(Type type) {
return get(type, APFloat::getZero(
type.cast<CIRFPTypeInterface>().getFloatSemantics()));
}

LogicalResult cir::FPAttr::verify(function_ref<InFlightDiagnostic()> emitError,
Type type, APFloat value) {
auto fltTypeInterface = mlir::dyn_cast<cir::CIRFPTypeInterface>(type);
if (!fltTypeInterface) {
emitError() << "expected floating-point type";
return failure();
}
if (APFloat::SemanticsToEnum(fltTypeInterface.getFloatSemantics()) !=
LogicalResult FPAttr::verify(function_ref<InFlightDiagnostic()> emitError,
CIRFPTypeInterface fpType, APFloat value) {
if (APFloat::SemanticsToEnum(fpType.getFloatSemantics()) !=
APFloat::SemanticsToEnum(value.getSemantics())) {
emitError() << "floating-point semantics mismatch";
return failure();
Expand Down
25 changes: 25 additions & 0 deletions clang/test/CIR/IR/attribute.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// RUN: cir-opt %s -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s

cir.func @float_attrs_pass() {
"test.float_attrs"() {
// CHECK: float_attr = #cir.fp<2.000000e+00> : !cir.float
float_attr = #cir.fp<2.> : !cir.float
} : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = #cir.fp<-2.000000e+00> : !cir.float
float_attr = #cir.fp<-2.> : !cir.float
} : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = #cir.fp<2.000000e+00> : !cir.double
float_attr = #cir.fp<2.> : !cir.double
} : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = #cir.fp<2.000000e+00> : !cir.long_double<!cir.f80>
float_attr = #cir.fp<2.> : !cir.long_double<!cir.f80>
} : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = #cir.fp<2.000000e+00> : !cir.long_double<!cir.double>
float_attr = #cir.fp<2.> : !cir.long_double<!cir.double>
} : () -> ()
cir.return
}
90 changes: 90 additions & 0 deletions clang/test/CIR/IR/float.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// RUN: cir-opt %s | FileCheck %s

// Adapted from mlir/test/IR/parser.mlir

// CHECK-LABEL: @f32_special_values
cir.func @f32_special_values() {
// F32 signaling NaNs.
// CHECK: cir.const(#cir.fp<0x7F800001> : !cir.float) : !cir.float
%0 = cir.const(#cir.fp<0x7F800001> : !cir.float) : !cir.float
// CHECK: cir.const(#cir.fp<0x7FBFFFFF> : !cir.float) : !cir.float
%1 = cir.const(#cir.fp<0x7FBFFFFF> : !cir.float) : !cir.float

// F32 quiet NaNs.
// CHECK: cir.const(#cir.fp<0x7FC00000> : !cir.float) : !cir.float
%2 = cir.const(#cir.fp<0x7FC00000> : !cir.float) : !cir.float
// CHECK: cir.const(#cir.fp<0xFFFFFFFF> : !cir.float) : !cir.float
%3 = cir.const(#cir.fp<0xFFFFFFFF> : !cir.float) : !cir.float

// F32 positive infinity.
// CHECK: cir.const(#cir.fp<0x7F800000> : !cir.float) : !cir.float
%4 = cir.const(#cir.fp<0x7F800000> : !cir.float) : !cir.float
// F32 negative infinity.
// CHECK: cir.const(#cir.fp<0xFF800000> : !cir.float) : !cir.float
%5 = cir.const(#cir.fp<0xFF800000> : !cir.float) : !cir.float

cir.return
}

// CHECK-LABEL: @f64_special_values
cir.func @f64_special_values() {
// F64 signaling NaNs.
// CHECK: cir.const(#cir.fp<0x7FF0000000000001> : !cir.double) : !cir.double
%0 = cir.const(#cir.fp<0x7FF0000000000001> : !cir.double) : !cir.double
// CHECK: cir.const(#cir.fp<0x7FF8000000000000> : !cir.double) : !cir.double
%1 = cir.const(#cir.fp<0x7FF8000000000000> : !cir.double) : !cir.double

// F64 quiet NaNs.
// CHECK: cir.const(#cir.fp<0x7FF0000001000000> : !cir.double) : !cir.double
%2 = cir.const(#cir.fp<0x7FF0000001000000> : !cir.double) : !cir.double
// CHECK: cir.const(#cir.fp<0xFFF0000001000000> : !cir.double) : !cir.double
%3 = cir.const(#cir.fp<0xFFF0000001000000> : !cir.double) : !cir.double

// F64 positive infinity.
// CHECK: cir.const(#cir.fp<0x7FF0000000000000> : !cir.double) : !cir.double
%4 = cir.const(#cir.fp<0x7FF0000000000000> : !cir.double) : !cir.double
// F64 negative infinity.
// CHECK: cir.const(#cir.fp<0xFFF0000000000000> : !cir.double) : !cir.double
%5 = cir.const(#cir.fp<0xFFF0000000000000> : !cir.double) : !cir.double

// Check that values that can't be represented with the default format, use
// hex instead.
// CHECK: cir.const(#cir.fp<0xC1CDC00000000000> : !cir.double) : !cir.double
%6 = cir.const(#cir.fp<0xC1CDC00000000000> : !cir.double) : !cir.double

cir.return
}

// CHECK-LABEL: @f80_special_values
cir.func @f80_special_values() {
// F80 signaling NaNs.
// CHECK: cir.const(#cir.fp<0x7FFFE000000000000001> : !cir.long_double<!cir.f80>) : !cir.long_double<!cir.f80>
%0 = cir.const(#cir.fp<0x7FFFE000000000000001> : !cir.long_double<!cir.f80>) : !cir.long_double<!cir.f80>
// CHECK: cir.const(#cir.fp<0x7FFFB000000000000011> : !cir.long_double<!cir.f80>) : !cir.long_double<!cir.f80>
%1 = cir.const(#cir.fp<0x7FFFB000000000000011> : !cir.long_double<!cir.f80>) : !cir.long_double<!cir.f80>

// F80 quiet NaNs.
// CHECK: cir.const(#cir.fp<0x7FFFC000000000100000> : !cir.long_double<!cir.f80>) : !cir.long_double<!cir.f80>
%2 = cir.const(#cir.fp<0x7FFFC000000000100000> : !cir.long_double<!cir.f80>) : !cir.long_double<!cir.f80>
// CHECK: cir.const(#cir.fp<0x7FFFE000000001000000> : !cir.long_double<!cir.f80>) : !cir.long_double<!cir.f80>
%3 = cir.const(#cir.fp<0x7FFFE000000001000000> : !cir.long_double<!cir.f80>) : !cir.long_double<!cir.f80>

// F80 positive infinity.
// CHECK: cir.const(#cir.fp<0x7FFF8000000000000000> : !cir.long_double<!cir.f80>) : !cir.long_double<!cir.f80>
%4 = cir.const(#cir.fp<0x7FFF8000000000000000> : !cir.long_double<!cir.f80>) : !cir.long_double<!cir.f80>
// F80 negative infinity.
// CHECK: cir.const(#cir.fp<0xFFFF8000000000000000> : !cir.long_double<!cir.f80>) : !cir.long_double<!cir.f80>
%5 = cir.const(#cir.fp<0xFFFF8000000000000000> : !cir.long_double<!cir.f80>) : !cir.long_double<!cir.f80>

cir.return
}

// We want to print floats in exponential notation with 6 significant digits,
// but it may lead to precision loss when parsing back, in which case we print
// the decimal form instead.
// CHECK-LABEL: @f32_potential_precision_loss()
cir.func @f32_potential_precision_loss() {
// CHECK: cir.const(#cir.fp<1.23697901> : !cir.float) : !cir.float
%0 = cir.const(#cir.fp<1.23697901> : !cir.float) : !cir.float
cir.return
}
59 changes: 59 additions & 0 deletions clang/test/CIR/IR/invalid.cir
Original file line number Diff line number Diff line change
Expand Up @@ -1378,3 +1378,62 @@ module {
cir.return
}
}
// -----

// Type of the attribute must be a CIR floating point type

// expected-error @below {{invalid kind of type specified}}
cir.global external @f = #cir.fp<0.5> : !cir.int<s, 32>

// -----

// Value must be a floating point literal or integer literal

// expected-error @below {{expected floating point literal}}
cir.global external @f = #cir.fp<"blabla"> : !cir.float

// -----

// Integer value must be in the width of the floating point type

// expected-error @below {{hexadecimal float constant out of range for type}}
cir.global external @f = #cir.fp<0x7FC000000> : !cir.float

// -----

// Integer value must be in the width of the floating point type

// expected-error @below {{hexadecimal float constant out of range for type}}
cir.global external @f = #cir.fp<0x7FC000007FC0000000> : !cir.double

// -----

// Integer value must be in the width of the floating point type

// expected-error @below {{hexadecimal float constant out of range for type}}
cir.global external @f = #cir.fp<0x7FC0000007FC0000007FC000000> : !cir.long_double<!cir.f80>

// -----

// Long double with `double` semnatics should have a value that fits in a double.

// CHECK: cir.global external @f = #cir.fp<0x7FC000007FC000000000> : !cir.long_double<!cir.f80>
cir.global external @f = #cir.fp<0x7FC000007FC000000000> : !cir.long_double<!cir.f80>

// expected-error @below {{hexadecimal float constant out of range for type}}
cir.global external @f = #cir.fp<0x7FC000007FC000000000> : !cir.long_double<!cir.double>

// -----

// Verify no need for type inside the attribute

// expected-error @below {{expected '>'}}
cir.global external @f = #cir.fp<0x7FC00000 : !cir.float> : !cir.float

// -----

// Verify literal must be hex or float

// expected-error @below {{unexpected decimal integer literal for a floating point value}}
// expected-note @below {{add a trailing dot to make the literal a float}}
cir.global external @f = #cir.fp<42> : !cir.float
2 changes: 1 addition & 1 deletion clang/test/CIR/Lowering/class.cir
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ module {
// CHECK: %0 = llvm.mlir.undef : !llvm.struct<"class.S1", (i32, f32, ptr)>
// CHECK: %1 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %2 = llvm.insertvalue %1, %0[0] : !llvm.struct<"class.S1", (i32, f32, ptr)>
// CHECK: %3 = llvm.mlir.constant(0.099999994 : f32) : f32
// CHECK: %3 = llvm.mlir.constant(1.000000e-01 : f32) : f32
// CHECK: %4 = llvm.insertvalue %3, %2[1] : !llvm.struct<"class.S1", (i32, f32, ptr)>
// CHECK: %5 = llvm.mlir.zero : !llvm.ptr
// CHECK: %6 = llvm.insertvalue %5, %4[2] : !llvm.struct<"class.S1", (i32, f32, ptr)>
Expand Down
2 changes: 1 addition & 1 deletion clang/test/CIR/Lowering/struct.cir
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ module {
// CHECK: %0 = llvm.mlir.undef : !llvm.struct<"struct.S1", (i32, f32, ptr)>
// CHECK: %1 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %2 = llvm.insertvalue %1, %0[0] : !llvm.struct<"struct.S1", (i32, f32, ptr)>
// CHECK: %3 = llvm.mlir.constant(0.099999994 : f32) : f32
// CHECK: %3 = llvm.mlir.constant(1.000000e-01 : f32) : f32
// CHECK: %4 = llvm.insertvalue %3, %2[1] : !llvm.struct<"struct.S1", (i32, f32, ptr)>
// CHECK: %5 = llvm.mlir.zero : !llvm.ptr
// CHECK: %6 = llvm.insertvalue %5, %4[2] : !llvm.struct<"struct.S1", (i32, f32, ptr)>
Expand Down

0 comments on commit 2a65d71

Please sign in to comment.