Skip to content

Commit

Permalink
dialects: Add constant bits ops in asl dialect
Browse files Browse the repository at this point in the history
  • Loading branch information
math-fehr committed Dec 7, 2024
1 parent b100c26 commit c8c3c3f
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 24 deletions.
79 changes: 56 additions & 23 deletions asl_xdsl/dialects/asl.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,8 @@ class BitVectorType(ParametrizedAttribute, TypeAttribute):

width: ParameterDef[builtin.IntAttr]

def __init__(self, width: int | builtin.IntAttr):
if isinstance(width, int):
width = builtin.IntAttr(width)
super().__init__([width])
def __init__(self, width: int):
super().__init__([builtin.IntAttr(width)])

@classmethod
def parse_parameters(cls, parser: AttrParser) -> Sequence[Attribute]:
Expand All @@ -91,32 +89,27 @@ class BitVectorAttr(ParametrizedAttribute):
name = "asl.bits_attr"

value: ParameterDef[builtin.IntAttr]
width: ParameterDef[builtin.IntAttr]
type: ParameterDef[BitVectorType]

def maximum_value(self) -> int:
"""Return the maximum value that can be represented."""
return (1 << self.width.data) - 1
return (1 << self.type.width.data) - 1

@staticmethod
def normalize_value(value: int, width: int) -> int:
"""Normalize the value to the range [0, 2^width)."""
max_value = 1 << width
return ((value % max_value) + max_value) % max_value

def __init__(self, value: int | builtin.IntAttr, width: int | builtin.IntAttr):
if isinstance(value, int):
value = builtin.IntAttr(value)
if isinstance(width, int):
width = builtin.IntAttr(width)

value_int = value.data
value = builtin.IntAttr(self.normalize_value(value_int, width.data))
super().__init__([value, width])
def __init__(self, value: int, type: BitVectorType):
value = self.normalize_value(value, type.width.data)
super().__init__([builtin.IntAttr(value), type])

def _verify(self) -> None:
if self.value.data < 0 or self.value.data >= self.maximum_value():
raise VerifyException(
f"Value {self.value.data} is out of range for width {self.width.data}"
f"Value {self.value.data} is out of range "
f"for width {self.type.width.data}"
)

@classmethod
Expand All @@ -125,7 +118,7 @@ def parse_parameters(cls, parser: AttrParser) -> Sequence[Attribute]:
parser.parse_characters("<")
value = builtin.IntAttr(parser.parse_integer())
parser.parse_characters(":")
width = builtin.IntAttr(parser.parse_integer())
width = parser.parse_attribute()
parser.parse_characters(">")
return [value, width]

Expand All @@ -134,7 +127,7 @@ def print_parameters(self, printer: Printer) -> None:
printer.print("<")
printer.print(self.value.data)
printer.print(" : ")
printer.print(self.width.data)
printer.print(self.type.width.data)
printer.print(">")


Expand All @@ -147,12 +140,10 @@ class ConstantBoolOp(IRDLOperation):
value = prop_def(BoolAttr)
res = result_def(BoolType())

def __init__(self, value: bool | BoolAttr, attr_dict: Mapping[str, Attribute] = {}):
if isinstance(value, bool):
value = BoolAttr(value)
def __init__(self, value: bool, attr_dict: Mapping[str, Attribute] = {}):
super().__init__(
result_types=[BoolType()],
properties={"value": value},
properties={"value": BoolAttr(value)},
attributes=attr_dict,
)

Expand Down Expand Up @@ -206,8 +197,50 @@ def print(self, printer: Printer) -> None:
printer.print_attr_dict(self.attributes)


@irdl_op_definition
class ConstantBitVectorOp(IRDLOperation):
"""A constant bit vector operation."""

name = "asl.constant_bits"

value = prop_def(BitVectorAttr)
res = result_def(BitVectorType)

def __init__(
self,
value: BitVectorAttr,
attr_dict: Mapping[str, Attribute] = {},
) -> None:
super().__init__(
result_types=[value.type],
properties={"value": value},
attributes=attr_dict,
)

@classmethod
def parse(cls, parser: Parser) -> ConstantBitVectorOp:
"""Parse the operation."""
value = parser.parse_integer()
parser.parse_characters(":")

type = parser.parse_attribute()
if not isinstance(type, BitVectorType):
parser.raise_error(f"Expected bit vector type, got {type}")

value = BitVectorAttr(value, type)
attr_dict = parser.parse_optional_attr_dict()
return ConstantBitVectorOp(value, attr_dict)

def print(self, printer: Printer) -> None:
"""Print the operation."""
printer.print(" ", self.value.value.data, " : ", self.res.type)
if self.attributes:
printer.print(" ")
printer.print_attr_dict(self.attributes)


ASLDialect = Dialect(
"asl",
[ConstantBoolOp, ConstantIntOp],
[ConstantBoolOp, ConstantIntOp, ConstantBitVectorOp],
[BoolType, BoolAttr, IntegerType, BitVectorType, BitVectorAttr],
)
3 changes: 3 additions & 0 deletions tests/filecheck/dialects/asl/constant_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@ builtin.module {
%false = asl.constant_bool false {attr_dict}

%fourty_two = asl.constant_int 42 {attr_dict}

%fourty_two_bits = asl.constant_bits 42 : !asl.bits<32> {attr_dict}
}

// CHECK: builtin.module {
// CHECK-NEXT: %true = asl.constant_bool true {"attr_dict"}
// CHECK-NEXT: %false = asl.constant_bool true {"attr_dict"}
// CHECK-NEXT: %fourty_two = asl.constant_int 42 {"attr_dict"}
// CHECK-NEXT: %fourty_two_bits = asl.constant_bits 42 : !asl.bits<32> {"attr_dict"}
// CHECK-NEXT: }
2 changes: 1 addition & 1 deletion tests/filecheck/dialects/asl/types_attrs.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ builtin.module {
"test.op"() {int_type = !asl.int} : () -> ()

"test.op"() {bits_type = !asl.bits<32>} : () -> ()
"test.op"() {bits_attr = #asl.bits_attr<42 : 32>} : () -> ()
"test.op"() {bits_attr = #asl.bits_attr<42 : !asl.bits<32>>} : () -> ()
}

// CHECK: "test.op"() {"bool_type" = !asl.bool} : () -> ()
Expand Down

0 comments on commit c8c3c3f

Please sign in to comment.