diff --git a/asl_xdsl/dialects/asl.py b/asl_xdsl/dialects/asl.py index 0736516..7d7c58f 100644 --- a/asl_xdsl/dialects/asl.py +++ b/asl_xdsl/dialects/asl.py @@ -64,10 +64,8 @@ class BitVectorType(ParametrizedAttribute, TypeAttribute): width: ParameterDef[builtin.IntAttr] - def __init__(self, width: int | builtin.IntAttr): - if isinstance(width, int): - width = builtin.IntAttr(width) - super().__init__([width]) + def __init__(self, width: int): + super().__init__([builtin.IntAttr(width)]) @classmethod def parse_parameters(cls, parser: AttrParser) -> Sequence[Attribute]: @@ -91,11 +89,11 @@ class BitVectorAttr(ParametrizedAttribute): name = "asl.bits_attr" value: ParameterDef[builtin.IntAttr] - width: ParameterDef[builtin.IntAttr] + type: ParameterDef[BitVectorType] def maximum_value(self) -> int: """Return the maximum value that can be represented.""" - return (1 << self.width.data) - 1 + return (1 << self.type.width.data) - 1 @staticmethod def normalize_value(value: int, width: int) -> int: @@ -103,20 +101,15 @@ def normalize_value(value: int, width: int) -> int: max_value = 1 << width return ((value % max_value) + max_value) % max_value - def __init__(self, value: int | builtin.IntAttr, width: int | builtin.IntAttr): - if isinstance(value, int): - value = builtin.IntAttr(value) - if isinstance(width, int): - width = builtin.IntAttr(width) - - value_int = value.data - value = builtin.IntAttr(self.normalize_value(value_int, width.data)) - super().__init__([value, width]) + def __init__(self, value: int, type: BitVectorType): + value = self.normalize_value(value, type.width.data) + super().__init__([builtin.IntAttr(value), type]) def _verify(self) -> None: if self.value.data < 0 or self.value.data >= self.maximum_value(): raise VerifyException( - f"Value {self.value.data} is out of range for width {self.width.data}" + f"Value {self.value.data} is out of range " + f"for width {self.type.width.data}" ) @classmethod @@ -125,7 +118,7 @@ def parse_parameters(cls, parser: AttrParser) -> Sequence[Attribute]: parser.parse_characters("<") value = builtin.IntAttr(parser.parse_integer()) parser.parse_characters(":") - width = builtin.IntAttr(parser.parse_integer()) + width = parser.parse_attribute() parser.parse_characters(">") return [value, width] @@ -134,7 +127,7 @@ def print_parameters(self, printer: Printer) -> None: printer.print("<") printer.print(self.value.data) printer.print(" : ") - printer.print(self.width.data) + printer.print(self.type.width.data) printer.print(">") @@ -147,12 +140,10 @@ class ConstantBoolOp(IRDLOperation): value = prop_def(BoolAttr) res = result_def(BoolType()) - def __init__(self, value: bool | BoolAttr, attr_dict: Mapping[str, Attribute] = {}): - if isinstance(value, bool): - value = BoolAttr(value) + def __init__(self, value: bool, attr_dict: Mapping[str, Attribute] = {}): super().__init__( result_types=[BoolType()], - properties={"value": value}, + properties={"value": BoolAttr(value)}, attributes=attr_dict, ) @@ -206,8 +197,50 @@ def print(self, printer: Printer) -> None: printer.print_attr_dict(self.attributes) +@irdl_op_definition +class ConstantBitVectorOp(IRDLOperation): + """A constant bit vector operation.""" + + name = "asl.constant_bits" + + value = prop_def(BitVectorAttr) + res = result_def(BitVectorType) + + def __init__( + self, + value: BitVectorAttr, + attr_dict: Mapping[str, Attribute] = {}, + ) -> None: + super().__init__( + result_types=[value.type], + properties={"value": value}, + attributes=attr_dict, + ) + + @classmethod + def parse(cls, parser: Parser) -> ConstantBitVectorOp: + """Parse the operation.""" + value = parser.parse_integer() + parser.parse_characters(":") + + type = parser.parse_attribute() + if not isinstance(type, BitVectorType): + parser.raise_error(f"Expected bit vector type, got {type}") + + value = BitVectorAttr(value, type) + attr_dict = parser.parse_optional_attr_dict() + return ConstantBitVectorOp(value, attr_dict) + + def print(self, printer: Printer) -> None: + """Print the operation.""" + printer.print(" ", self.value.value.data, " : ", self.res.type) + if self.attributes: + printer.print(" ") + printer.print_attr_dict(self.attributes) + + ASLDialect = Dialect( "asl", - [ConstantBoolOp, ConstantIntOp], + [ConstantBoolOp, ConstantIntOp, ConstantBitVectorOp], [BoolType, BoolAttr, IntegerType, BitVectorType, BitVectorAttr], ) diff --git a/tests/filecheck/dialects/asl/constant_ops.mlir b/tests/filecheck/dialects/asl/constant_ops.mlir index 23a9107..a015c45 100644 --- a/tests/filecheck/dialects/asl/constant_ops.mlir +++ b/tests/filecheck/dialects/asl/constant_ops.mlir @@ -5,10 +5,13 @@ builtin.module { %false = asl.constant_bool false {attr_dict} %fourty_two = asl.constant_int 42 {attr_dict} + + %fourty_two_bits = asl.constant_bits 42 : !asl.bits<32> {attr_dict} } // CHECK: builtin.module { // CHECK-NEXT: %true = asl.constant_bool true {"attr_dict"} // CHECK-NEXT: %false = asl.constant_bool true {"attr_dict"} // CHECK-NEXT: %fourty_two = asl.constant_int 42 {"attr_dict"} +// CHECK-NEXT: %fourty_two_bits = asl.constant_bits 42 : !asl.bits<32> {"attr_dict"} // CHECK-NEXT: } diff --git a/tests/filecheck/dialects/asl/types_attrs.mlir b/tests/filecheck/dialects/asl/types_attrs.mlir index f9495b5..88b1c2f 100644 --- a/tests/filecheck/dialects/asl/types_attrs.mlir +++ b/tests/filecheck/dialects/asl/types_attrs.mlir @@ -7,7 +7,7 @@ builtin.module { "test.op"() {int_type = !asl.int} : () -> () "test.op"() {bits_type = !asl.bits<32>} : () -> () - "test.op"() {bits_attr = #asl.bits_attr<42 : 32>} : () -> () + "test.op"() {bits_attr = #asl.bits_attr<42 : !asl.bits<32>>} : () -> () } // CHECK: "test.op"() {"bool_type" = !asl.bool} : () -> ()