diff --git a/asl_xdsl/dialects/asl.py b/asl_xdsl/dialects/asl.py index cc1d0b3..18c39d1 100644 --- a/asl_xdsl/dialects/asl.py +++ b/asl_xdsl/dialects/asl.py @@ -151,7 +151,14 @@ def parse_parameter(cls, parser: AttrParser) -> bool: def print_parameter(self, printer: Printer) -> None: """Print the attribute parameter.""" - printer.print("true" if self.data else "false") + printer.print("" if self.data else "") + + +@irdl_attr_definition +class StringType(ParametrizedAttribute, TypeAttribute): + """A string type.""" + + name = "asl.string" @irdl_attr_definition @@ -323,7 +330,7 @@ def __init__( @classmethod def parse(cls, parser: Parser) -> ConstantIntOp: """Parse the operation.""" - value = parser.parse_integer(allow_boolean=False, allow_negative=False) + value = parser.parse_integer(allow_boolean=False, allow_negative=True) attr_dict = parser.parse_optional_attr_dict() return ConstantIntOp(value, attr_dict) @@ -377,6 +384,29 @@ def print(self, printer: Printer) -> None: printer.print_attr_dict(self.attributes) +@irdl_op_definition +class ConstantStringOp(IRDLOperation): + """A constant string operation.""" + + name = "asl.constant_string" + + value = prop_def(builtin.StringAttr) + res = result_def(StringType) + + assembly_format = "$value attr-dict" + + def __init__( + self, value: str | builtin.StringAttr, attr_dict: Mapping[str, Attribute] = {} + ): + if isinstance(value, str): + value = builtin.StringAttr(value) + super().__init__( + result_types=[StringType()], + properties={"value": value}, + attributes=attr_dict, + ) + + @irdl_op_definition class NotOp(IRDLOperation): """A bitwise NOT operation.""" @@ -396,6 +426,25 @@ def __init__(self, arg: SSAValue, attr_dict: Mapping[str, Attribute] = {}): ) +@irdl_op_definition +class BoolToI1Op(IRDLOperation): + """A hack to convert !asl.bool to i1 so that we can use scf.if.""" + + name = "asl.bool_to_i1" + + arg = operand_def(BoolType()) + res = result_def(builtin.IntegerType(1)) + + assembly_format = "$arg `:` type($arg) `->` type($res) attr-dict" + + def __init__(self, arg: SSAValue, attr_dict: Mapping[str, Attribute] = {}): + super().__init__( + operands=[arg], + result_types=[builtin.IntegerType(1)], + attributes=attr_dict, + ) + + class BinaryBoolOp(IRDLOperation): """A binary boolean operation.""" @@ -464,7 +513,7 @@ class EquivBoolOp(BinaryBoolOp): class NegateIntOp(IRDLOperation): """An integer negation operation.""" - name = "asl.negate_int" + name = "asl.neg_int" arg = operand_def(IntegerType) res = result_def(IntegerType) @@ -535,14 +584,14 @@ class ExpIntOp(BinaryIntOp): class ShiftLeftIntOp(BinaryIntOp): """An integer left shift operation.""" - name = "asl.shiftleft_int" + name = "asl.shl_int" @irdl_op_definition class ShiftRightIntOp(BinaryIntOp): """An integer right shift operation.""" - name = "asl.shiftright_int" + name = "asl.shr_int" @irdl_op_definition @@ -1052,7 +1101,9 @@ def __init__( ConstantBoolOp, ConstantIntOp, ConstantBitVectorOp, + ConstantStringOp, # Boolean operations + BoolToI1Op, NotOp, AndBoolOp, OrBoolOp, @@ -1095,5 +1146,12 @@ def __init__( # Slices SliceSingleOp, ], - [BoolType, BoolAttr, IntegerType, BitVectorType, BitVectorAttr], + [ + BoolType, + BoolAttr, + IntegerType, + BitVectorType, + BitVectorAttr, + StringType, + ], ) diff --git a/asl_xdsl/interpreters/asl.py b/asl_xdsl/interpreters/asl.py index 414b4af..08e606c 100644 --- a/asl_xdsl/interpreters/asl.py +++ b/asl_xdsl/interpreters/asl.py @@ -40,6 +40,14 @@ def run_call( ) -> tuple[Any, ...]: return interpreter.call_op(op.callee.string_value(), args) + @impl(asl.NegateIntOp) + def run_neg_int( + self, interpreter: Interpreter, op: asl.NegateIntOp, args: tuple[Any, ...] + ) -> tuple[Any, ...]: + arg: int + [arg] = args + return (0 - arg,) + @impl(asl.AddIntOp) def run_add_int( self, interpreter: Interpreter, op: asl.AddIntOp, args: tuple[Any, ...] @@ -49,16 +57,132 @@ def run_add_int( (lhs, rhs) = args return (lhs + rhs,) + @impl(asl.SubIntOp) + def run_sub_int( + self, interpreter: Interpreter, op: asl.SubIntOp, args: tuple[Any, ...] + ) -> tuple[Any, ...]: + lhs: int + rhs: int + (lhs, rhs) = args + return (lhs - rhs,) + + @impl(asl.MulIntOp) + def run_mul_int( + self, interpreter: Interpreter, op: asl.MulIntOp, args: tuple[Any, ...] + ) -> tuple[Any, ...]: + lhs: int + rhs: int + (lhs, rhs) = args + return (lhs * rhs,) + + @impl(asl.ShiftLeftIntOp) + def run_shl_int( + self, interpreter: Interpreter, op: asl.ShiftLeftIntOp, args: tuple[Any, ...] + ) -> tuple[Any, ...]: + lhs: int + rhs: int + (lhs, rhs) = args + assert rhs >= 0 + return (lhs << rhs,) + + @impl(asl.ShiftRightIntOp) + def run_shr_int( + self, interpreter: Interpreter, op: asl.ShiftRightIntOp, args: tuple[Any, ...] + ) -> tuple[Any, ...]: + lhs: int + rhs: int + (lhs, rhs) = args + assert rhs >= 0 + return (lhs >> rhs,) + + @impl(asl.EqIntOp) + def run_eq_int( + self, interpreter: Interpreter, op: asl.EqIntOp, args: tuple[Any, ...] + ) -> tuple[Any, ...]: + lhs: int + rhs: int + (lhs, rhs) = args + return (lhs == rhs,) + + @impl(asl.NeIntOp) + def run_ne_int( + self, interpreter: Interpreter, op: asl.NeIntOp, args: tuple[Any, ...] + ) -> tuple[Any, ...]: + lhs: int + rhs: int + (lhs, rhs) = args + return (lhs != rhs,) + + @impl(asl.LeIntOp) + def run_le_int( + self, interpreter: Interpreter, op: asl.LeIntOp, args: tuple[Any, ...] + ) -> tuple[Any, ...]: + lhs: int + rhs: int + (lhs, rhs) = args + return (lhs <= rhs,) + + @impl(asl.LtIntOp) + def run_lt_int( + self, interpreter: Interpreter, op: asl.LtIntOp, args: tuple[Any, ...] + ) -> tuple[Any, ...]: + lhs: int + rhs: int + (lhs, rhs) = args + return (lhs < rhs,) + + @impl(asl.GeIntOp) + def run_ge_int( + self, interpreter: Interpreter, op: asl.GeIntOp, args: tuple[Any, ...] + ) -> tuple[Any, ...]: + lhs: int + rhs: int + (lhs, rhs) = args + return (lhs >= rhs,) + + @impl(asl.GtIntOp) + def run_gt_int( + self, interpreter: Interpreter, op: asl.GtIntOp, args: tuple[Any, ...] + ) -> tuple[Any, ...]: + lhs: int + rhs: int + (lhs, rhs) = args + return (lhs > rhs,) + @impl(asl.ConstantIntOp) - def run_constant( + def run_constant_int( self, interpreter: Interpreter, op: asl.ConstantIntOp, args: PythonValues ) -> PythonValues: value = op.value return (value.data,) + @impl(asl.ConstantStringOp) + def run_constant_string( + self, interpreter: Interpreter, op: asl.ConstantStringOp, args: PythonValues + ) -> PythonValues: + value = op.value + return (value.data,) + + @impl(asl.BoolToI1Op) + def run_bool_to_i1( + self, interpreter: Interpreter, op: asl.BoolToI1Op, args: PythonValues + ) -> PythonValues: + arg: int + [arg] = args + return (arg,) + # region built-in function implementations - @impl_external("asl_print_int_dec") + @impl_external("print_bits_hex.0") + def asl_print_bits_hex( + self, interpreter: Interpreter, op: Operation, args: PythonValues + ) -> PythonValues: + arg: int + (arg,) = args + interpreter.print(hex(arg)) + return () + + @impl_external("print_int_dec.0") def asl_print_int_dec( self, interpreter: Interpreter, op: Operation, args: PythonValues ) -> PythonValues: @@ -67,7 +191,16 @@ def asl_print_int_dec( interpreter.print(arg) return () - @impl_external("asl_print_char") + @impl_external("print_int_hex.0") + def asl_print_int_hex( + self, interpreter: Interpreter, op: Operation, args: PythonValues + ) -> PythonValues: + arg: int + (arg,) = args + interpreter.print(hex(arg)) + return () + + @impl_external("print_char.0") def asl_print_char( self, interpreter: Interpreter, op: Operation, args: PythonValues ) -> PythonValues: @@ -76,4 +209,13 @@ def asl_print_char( interpreter.print(chr(arg)) return () + @impl_external("print_str.0") + def asl_print_string( + self, interpreter: Interpreter, op: Operation, args: PythonValues + ) -> PythonValues: + arg: str + (arg,) = args + interpreter.print(arg) + return () + # endregion diff --git a/tests/filecheck/dialects/asl/cf.mlir b/tests/filecheck/dialects/asl/cf.mlir new file mode 100644 index 0000000..142b597 --- /dev/null +++ b/tests/filecheck/dialects/asl/cf.mlir @@ -0,0 +1,25 @@ +// RUN: asl-opt %s | asl-opt %s | filecheck %s + +builtin.module { + asl.func @print_str.0(%x : !asl.string) -> () + %c = asl.constant_bool true {attr_dict} + %0 = asl.bool_to_i1 %c : !asl.bool -> i1 + scf.if %0 { + %1 = asl.constant_string "TRUE" {attr_dict} + asl.call @print_str.0(%1) : (!asl.string) -> () + } else { + %2 = asl.constant_string "FALSE" {attr_dict} + asl.call @print_str.0(%2) : (!asl.string) -> () + } + +// CHECK: %c = asl.constant_bool true {attr_dict} +// CHECK-NEXT: %0 = asl.bool_to_i1 %c : !asl.bool -> i1 +// CHECK-NEXT: scf.if %0 { +// CHECK-NEXT: %1 = asl.constant_string "TRUE" {attr_dict} +// CHECK-NEXT: asl.call @print_str.0(%1) : (!asl.string) -> () +// CHECK-NEXT: } else { +// CHECK-NEXT: %2 = asl.constant_string "FALSE" {attr_dict} +// CHECK-NEXT: asl.call @print_str.0(%2) : (!asl.string) -> () +// CHECK-NEXT: } + +} diff --git a/tests/filecheck/dialects/asl/constant_ops.mlir b/tests/filecheck/dialects/asl/constant_ops.mlir index a96e056..0c15b27 100644 --- a/tests/filecheck/dialects/asl/constant_ops.mlir +++ b/tests/filecheck/dialects/asl/constant_ops.mlir @@ -7,6 +7,8 @@ builtin.module { %fourty_two = asl.constant_int 42 {attr_dict} %fourty_two_bits = asl.constant_bits 42 : !asl.bits<32> {attr_dict} + + %fourty_two_string = asl.constant_string "Forty Two" {attr_dict} } // CHECK: builtin.module { @@ -14,4 +16,5 @@ builtin.module { // CHECK-NEXT: %false = asl.constant_bool true {attr_dict} // CHECK-NEXT: %fourty_two = asl.constant_int 42 {attr_dict} // CHECK-NEXT: %fourty_two_bits = asl.constant_bits 42 : !asl.bits<32> {attr_dict} +// CHECK-NEXT: %fourty_two_string = asl.constant_string "Forty Two" {attr_dict} // CHECK-NEXT: } diff --git a/tests/filecheck/dialects/asl/primitives.mlir b/tests/filecheck/dialects/asl/primitives.mlir index e5f541b..fb6b952 100644 --- a/tests/filecheck/dialects/asl/primitives.mlir +++ b/tests/filecheck/dialects/asl/primitives.mlir @@ -24,24 +24,24 @@ builtin.module { %int1, %int2 = "test.op"() : () -> (!asl.int, !asl.int) // CHECK-NEXT: %int1, %int2 = "test.op"() : () -> (!asl.int, !asl.int) - %negate_int = asl.negate_int %int1 : !asl.int -> !asl.int + %neg_int = asl.neg_int %int1 : !asl.int -> !asl.int %add_int = asl.add_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int %sub_int = asl.sub_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int %mul_int = asl.mul_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int %exp_int = asl.exp_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int - %shiftleft_int = asl.shiftleft_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int - %shiftright_int = asl.shiftright_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int + %shiftleft_int = asl.shl_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int + %shiftright_int = asl.shr_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int %div_int = asl.div_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int %fdiv_int = asl.fdiv_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int %frem_int = asl.frem_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int -// CHECK-NEXT: %negate_int = asl.negate_int %int1 +// CHECK-NEXT: %neg_int = asl.neg_int %int1 // CHECK-NEXT: %add_int = asl.add_int %int1, %int2 // CHECK-NEXT: %sub_int = asl.sub_int %int1, %int2 // CHECK-NEXT: %mul_int = asl.mul_int %int1, %int2 // CHECK-NEXT: %exp_int = asl.exp_int %int1, %int2 -// CHECK-NEXT: %shiftleft_int = asl.shiftleft_int %int1, %int2 -// CHECK-NEXT: %shiftright_int = asl.shiftright_int %int1, %int2 +// CHECK-NEXT: %shiftleft_int = asl.shl_int %int1, %int2 +// CHECK-NEXT: %shiftright_int = asl.shr_int %int1, %int2 // CHECK-NEXT: %div_int = asl.div_int %int1, %int2 // CHECK-NEXT: %fdiv_int = asl.fdiv_int %int1, %int2 // CHECK-NEXT: %frem_int = asl.frem_int %int1, %int2 diff --git a/tests/filecheck/dialects/asl/types_attrs.mlir b/tests/filecheck/dialects/asl/types_attrs.mlir index c8e4d30..ac2b60a 100644 --- a/tests/filecheck/dialects/asl/types_attrs.mlir +++ b/tests/filecheck/dialects/asl/types_attrs.mlir @@ -5,7 +5,7 @@ builtin.module { "test.op"() {bool_true = #asl.bool_attr, bool_false = #asl.bool_attr} : () -> () // CHECK: "test.op"() {bool_type = !asl.bool} : () -> () -// CHECK-NEXT: "test.op"() {bool_true = #asl.bool_attrtrue, bool_false = #asl.bool_attrfalse} : () -> () +// CHECK-NEXT: "test.op"() {bool_true = #asl.bool_attr, bool_false = #asl.bool_attr} : () -> () "test.op"() {int_type = !asl.int} : () -> () "test.op"() {constraint_int = !asl.int<42>} : () -> () @@ -24,4 +24,6 @@ builtin.module { // CHECK-NEXT: "test.op"() {bits_type = !asl.bits<32>} : () -> () // CHECK-NEXT: "test.op"() {bits_attr = #asl.bits_attr<42 : 32>} : () -> () + + "test.op"() {string_type = !asl.string} : () -> () } diff --git a/tests/filecheck/exec/t2.mlir b/tests/filecheck/exec/t2.mlir index 846946f..e93b9ef 100644 --- a/tests/filecheck/exec/t2.mlir +++ b/tests/filecheck/exec/t2.mlir @@ -2,14 +2,14 @@ // CHECK: 3 -asl.func @asl_print_int_dec(!asl.int) -asl.func @asl_print_char(!asl.int) +asl.func @print_int_dec.0(!asl.int) +asl.func @print_char.0(!asl.int) asl.func @main.0() -> !asl.int { %0 = asl.constant_int 1 {attr_dict} %1 = asl.constant_int 2 {attr_dict} %2 = asl.call @Test.0(%0, %1) : (!asl.int, !asl.int) -> !asl.int - asl.call @asl_print_int_dec (%2) : (!asl.int) -> () + asl.call @print_int_dec.0 (%2) : (!asl.int) -> () asl.call @println.0() : () -> () @@ -24,6 +24,6 @@ asl.func @Test.0(%x : !asl.int, %y : !asl.int) -> !asl.int { asl.func @println.0() -> () { %0 = asl.constant_int 10 {attr_dict} - asl.call @asl_print_char(%0) : (!asl.int) -> () + asl.call @print_char.0(%0) : (!asl.int) -> () asl.return }