Skip to content

Commit

Permalink
dialects: add asl dialect with builtin types and attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
math-fehr committed Dec 7, 2024
1 parent 2058668 commit 28bb142
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 1 deletion.
166 changes: 166 additions & 0 deletions asl_xdsl/dialects/asl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
from __future__ import annotations

from collections.abc import Sequence

from xdsl.dialects import builtin
from xdsl.ir import (
Attribute,
Data,
Dialect,
ParametrizedAttribute,
TypeAttribute,
VerifyException,
)
from xdsl.irdl import ParameterDef, irdl_attr_definition
from xdsl.parser import AttrParser
from xdsl.printer import Printer


@irdl_attr_definition
class BoolType(ParametrizedAttribute, TypeAttribute):
"""A boolean type."""

name = "asl.bool"


@irdl_attr_definition
class BoolAttr(Data[bool]):
"""A boolean attribute."""

name = "asl.bool_attr"

@classmethod
def parse_parameter(cls, parser: AttrParser) -> bool:
"""Parse the attribute parameter."""
parser.parse_characters("<")
value = parser.parse_boolean()
parser.parse_characters(">")
return value

def print_parameter(self, printer: Printer) -> None:
"""Print the attribute parameter."""
printer.print("true" if self.data else "false")


@irdl_attr_definition
class IntegerType(ParametrizedAttribute, TypeAttribute):
"""An arbitrary-precision integer type."""

name = "asl.int"


@irdl_attr_definition
class IntegerAttr(ParametrizedAttribute):
"""An arbitrary-precision integer attribute."""

name = "asl.int_attr"

value: ParameterDef[builtin.IntAttr]

def __init__(self, value: int | builtin.IntAttr):
if isinstance(value, int):
value = builtin.IntAttr(value)
super().__init__([value])

@classmethod
def parse_parameters(cls, parser: AttrParser) -> Sequence[Attribute]:
"""Parse the attribute parameters."""
parser.parse_characters("<")
value = builtin.IntAttr(parser.parse_integer())
parser.parse_characters(">")
return [value]

def print_parameters(self, printer: Printer) -> None:
"""Print the attribute parameters."""
printer.print("<")
printer.print(self.value.data)
printer.print(">")


@irdl_attr_definition
class BitVectorType(ParametrizedAttribute, TypeAttribute):
"""A bit vector type."""

name = "asl.bits"

width: ParameterDef[builtin.IntAttr]

def __init__(self, width: int | builtin.IntAttr):
if isinstance(width, int):
width = builtin.IntAttr(width)
super().__init__([width])

@classmethod
def parse_parameters(cls, parser: AttrParser) -> Sequence[Attribute]:
"""Parse the attribute parameters."""
parser.parse_characters("<")
width = builtin.IntAttr(parser.parse_integer())
parser.parse_characters(">")
return [width]

def print_parameters(self, printer: Printer) -> None:
"""Print the attribute parameters."""
printer.print("<")
printer.print(self.width.data)
printer.print(">")


@irdl_attr_definition
class BitVectorAttr(ParametrizedAttribute):
"""A bit vector attribute."""

name = "asl.bits_attr"

value: ParameterDef[builtin.IntAttr]
width: ParameterDef[builtin.IntAttr]

def maximum_value(self) -> int:
"""Return the maximum value that can be represented."""
return (1 << self.width.data) - 1

@staticmethod
def normalize_value(value: int, width: int) -> int:
"""Normalize the value to the range [0, 2^width)."""
max_value = 1 << width
return ((value % max_value) + max_value) % max_value

def __init__(self, value: int | builtin.IntAttr, width: int | builtin.IntAttr):
if isinstance(value, int):
value = builtin.IntAttr(value)
if isinstance(width, int):
width = builtin.IntAttr(width)

value_int = value.data
value = builtin.IntAttr(self.normalize_value(value_int, width.data))
super().__init__([value, width])

def _verify(self) -> None:
if self.value.data < 0 or self.value.data >= self.maximum_value():
raise VerifyException(
f"Value {self.value.data} is out of range for width {self.width.data}"
)

@classmethod
def parse_parameters(cls, parser: AttrParser) -> Sequence[Attribute]:
"""Parse the attribute parameters."""
parser.parse_characters("<")
value = builtin.IntAttr(parser.parse_integer())
parser.parse_characters(":")
width = builtin.IntAttr(parser.parse_integer())
parser.parse_characters(">")
return [value, width]

def print_parameters(self, printer: Printer) -> None:
"""Print the attribute parameters."""
printer.print("<")
printer.print(self.value.data)
printer.print(" : ")
printer.print(self.width.data)
printer.print(">")


ASLDialect = Dialect(
"asl",
[],
[BoolType, BoolAttr, IntegerType, IntegerAttr, BitVectorType, BitVectorAttr],
)
5 changes: 4 additions & 1 deletion asl_xdsl/tools/asl_opt.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from xdsl.xdsl_opt_main import xDSLOptMain

from asl_xdsl.dialects.asl import ASLDialect


class ASLOptMain(xDSLOptMain):
def register_all_dialects(self):
return super().register_all_dialects()
super().register_all_dialects()
self.ctx.load_dialect(ASLDialect)

def register_all_passes(self):
return super().register_all_passes()
Expand Down
19 changes: 19 additions & 0 deletions tests/filecheck/dialects/asl/types_attrs.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: asl-opt %s | asl-opt %s | filecheck %s

builtin.module {
"test.op"() {bool_type = !asl.bool} : () -> ()
"test.op"() {bool_true = #asl.bool_attr<true>, bool_false = #asl.bool_attr<false>} : () -> ()

"test.op"() {int_type = !asl.int} : () -> ()
"test.op"() {int_attr = #asl.int_attr<42>} : () -> ()

"test.op"() {bits_type = !asl.bits<32>} : () -> ()
"test.op"() {bits_attr = #asl.bits_attr<42 : 32>} : () -> ()
}

// CHECK: "test.op"() {"bool_type" = !asl.bool} : () -> ()
// CHECK-NEXT: "test.op"() {"bool_true" = #asl.bool_attrtrue, "bool_false" = #asl.bool_attrfalse} : () -> ()
// CHECK-NEXT: "test.op"() {"int_type" = !asl.int} : () -> ()
// CHECK-NEXT: "test.op"() {"int_attr" = #asl.int_attr<42>} : () -> ()
// CHECK-NEXT: "test.op"() {"bits_type" = !asl.bits<32>} : () -> ()
// CHECK-NEXT: "test.op"() {"bits_attr" = #asl.bits_attr<42 : 32>} : () -> ()

0 comments on commit 28bb142

Please sign in to comment.