Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend ASL support: strings, comparisons, etc. #1

Merged
merged 8 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 64 additions & 6 deletions asl_xdsl/dialects/asl.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,14 @@ def parse_parameter(cls, parser: AttrParser) -> bool:

def print_parameter(self, printer: Printer) -> None:
"""Print the attribute parameter."""
printer.print("true" if self.data else "false")
printer.print("<true>" if self.data else "<false>")


@irdl_attr_definition
class StringType(ParametrizedAttribute, TypeAttribute):
"""A string type."""

name = "asl.string"


@irdl_attr_definition
Expand Down Expand Up @@ -323,7 +330,7 @@ def __init__(
@classmethod
def parse(cls, parser: Parser) -> ConstantIntOp:
"""Parse the operation."""
value = parser.parse_integer(allow_boolean=False, allow_negative=False)
value = parser.parse_integer(allow_boolean=False, allow_negative=True)
attr_dict = parser.parse_optional_attr_dict()
return ConstantIntOp(value, attr_dict)

Expand Down Expand Up @@ -377,6 +384,29 @@ def print(self, printer: Printer) -> None:
printer.print_attr_dict(self.attributes)


@irdl_op_definition
class ConstantStringOp(IRDLOperation):
"""A constant string operation."""

name = "asl.constant_string"

value = prop_def(builtin.StringAttr)
res = result_def(StringType)

assembly_format = "$value attr-dict"

def __init__(
self, value: str | builtin.StringAttr, attr_dict: Mapping[str, Attribute] = {}
):
if isinstance(value, str):
value = builtin.StringAttr(value)
super().__init__(
result_types=[StringType()],
properties={"value": value},
attributes=attr_dict,
)


@irdl_op_definition
class NotOp(IRDLOperation):
"""A bitwise NOT operation."""
Expand All @@ -396,6 +426,25 @@ def __init__(self, arg: SSAValue, attr_dict: Mapping[str, Attribute] = {}):
)


@irdl_op_definition
class BoolToI1Op(IRDLOperation):
"""A hack to convert !asl.bool to i1 so that we can use scf.if."""

name = "asl.bool_to_i1"

arg = operand_def(BoolType())
res = result_def(builtin.IntegerType(1))

assembly_format = "$arg `:` type($arg) `->` type($res) attr-dict"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why, but the attr dict tends to be before the type by convention

Suggested change
assembly_format = "$arg `:` type($arg) `->` type($res) attr-dict"
assembly_format = "$arg attr-dict `:` type($arg) `->` type($res)"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, that's because I moved it to the end on the other operations.
Mostly because I hate being it at the beginning, because it is "optional" in a way.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there other dialects that do this? Seems to me like we should stick with the convention if not.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Not changed - unsure which to go with)


def __init__(self, arg: SSAValue, attr_dict: Mapping[str, Attribute] = {}):
super().__init__(
operands=[arg],
result_types=[builtin.IntegerType(1)],
attributes=attr_dict,
)


class BinaryBoolOp(IRDLOperation):
"""A binary boolean operation."""

Expand Down Expand Up @@ -464,7 +513,7 @@ class EquivBoolOp(BinaryBoolOp):
class NegateIntOp(IRDLOperation):
"""An integer negation operation."""

name = "asl.negate_int"
name = "asl.neg_int"

arg = operand_def(IntegerType)
res = result_def(IntegerType)
Expand Down Expand Up @@ -535,14 +584,14 @@ class ExpIntOp(BinaryIntOp):
class ShiftLeftIntOp(BinaryIntOp):
"""An integer left shift operation."""

name = "asl.shiftleft_int"
name = "asl.shl_int"


@irdl_op_definition
class ShiftRightIntOp(BinaryIntOp):
"""An integer right shift operation."""

name = "asl.shiftright_int"
name = "asl.shr_int"


@irdl_op_definition
Expand Down Expand Up @@ -1052,7 +1101,9 @@ def __init__(
ConstantBoolOp,
ConstantIntOp,
ConstantBitVectorOp,
ConstantStringOp,
# Boolean operations
BoolToI1Op,
NotOp,
AndBoolOp,
OrBoolOp,
Expand Down Expand Up @@ -1095,5 +1146,12 @@ def __init__(
# Slices
SliceSingleOp,
],
[BoolType, BoolAttr, IntegerType, BitVectorType, BitVectorAttr],
[
BoolType,
BoolAttr,
IntegerType,
BitVectorType,
BitVectorAttr,
StringType,
],
)
148 changes: 145 additions & 3 deletions asl_xdsl/interpreters/asl.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ def run_call(
) -> tuple[Any, ...]:
return interpreter.call_op(op.callee.string_value(), args)

@impl(asl.NegateIntOp)
def run_neg_int(
self, interpreter: Interpreter, op: asl.NegateIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
arg: int
[arg] = args
return (0 - arg,)

@impl(asl.AddIntOp)
def run_add_int(
self, interpreter: Interpreter, op: asl.AddIntOp, args: tuple[Any, ...]
Expand All @@ -49,16 +57,132 @@ def run_add_int(
(lhs, rhs) = args
return (lhs + rhs,)

@impl(asl.SubIntOp)
def run_sub_int(
self, interpreter: Interpreter, op: asl.SubIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
lhs: int
rhs: int
(lhs, rhs) = args
return (lhs - rhs,)

@impl(asl.MulIntOp)
def run_mul_int(
self, interpreter: Interpreter, op: asl.MulIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
lhs: int
rhs: int
(lhs, rhs) = args
return (lhs * rhs,)

@impl(asl.ShiftLeftIntOp)
def run_shl_int(
self, interpreter: Interpreter, op: asl.ShiftLeftIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
lhs: int
rhs: int
(lhs, rhs) = args
assert rhs >= 0
return (lhs << rhs,)

@impl(asl.ShiftRightIntOp)
def run_shr_int(
self, interpreter: Interpreter, op: asl.ShiftRightIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
lhs: int
rhs: int
(lhs, rhs) = args
assert rhs >= 0
return (lhs >> rhs,)

@impl(asl.EqIntOp)
def run_eq_int(
self, interpreter: Interpreter, op: asl.EqIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
lhs: int
rhs: int
(lhs, rhs) = args
return (lhs == rhs,)

@impl(asl.NeIntOp)
def run_ne_int(
self, interpreter: Interpreter, op: asl.NeIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
lhs: int
rhs: int
(lhs, rhs) = args
return (lhs != rhs,)

@impl(asl.LeIntOp)
def run_le_int(
self, interpreter: Interpreter, op: asl.LeIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
lhs: int
rhs: int
(lhs, rhs) = args
return (lhs <= rhs,)

@impl(asl.LtIntOp)
def run_lt_int(
self, interpreter: Interpreter, op: asl.LtIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
lhs: int
rhs: int
(lhs, rhs) = args
return (lhs < rhs,)

@impl(asl.GeIntOp)
def run_ge_int(
self, interpreter: Interpreter, op: asl.GeIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
lhs: int
rhs: int
(lhs, rhs) = args
return (lhs >= rhs,)

@impl(asl.GtIntOp)
def run_gt_int(
self, interpreter: Interpreter, op: asl.GtIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
lhs: int
rhs: int
(lhs, rhs) = args
return (lhs > rhs,)

@impl(asl.ConstantIntOp)
def run_constant(
def run_constant_int(
self, interpreter: Interpreter, op: asl.ConstantIntOp, args: PythonValues
) -> PythonValues:
value = op.value
return (value.data,)

@impl(asl.ConstantStringOp)
def run_constant_string(
self, interpreter: Interpreter, op: asl.ConstantStringOp, args: PythonValues
) -> PythonValues:
value = op.value
return (value.data,)

@impl(asl.BoolToI1Op)
def run_bool_to_i1(
self, interpreter: Interpreter, op: asl.BoolToI1Op, args: PythonValues
) -> PythonValues:
arg: int
[arg] = args
return (arg,)

# region built-in function implementations

@impl_external("asl_print_int_dec")
@impl_external("print_bits_hex.0")
def asl_print_bits_hex(
self, interpreter: Interpreter, op: Operation, args: PythonValues
) -> PythonValues:
arg: int
(arg,) = args
interpreter.print(hex(arg))
return ()

@impl_external("print_int_dec.0")
def asl_print_int_dec(
self, interpreter: Interpreter, op: Operation, args: PythonValues
) -> PythonValues:
Expand All @@ -67,7 +191,16 @@ def asl_print_int_dec(
interpreter.print(arg)
return ()

@impl_external("asl_print_char")
@impl_external("print_int_hex.0")
def asl_print_int_hex(
self, interpreter: Interpreter, op: Operation, args: PythonValues
) -> PythonValues:
arg: int
(arg,) = args
interpreter.print(hex(arg))
return ()

@impl_external("print_char.0")
def asl_print_char(
self, interpreter: Interpreter, op: Operation, args: PythonValues
) -> PythonValues:
Expand All @@ -76,4 +209,13 @@ def asl_print_char(
interpreter.print(chr(arg))
return ()

@impl_external("print_str.0")
def asl_print_string(
self, interpreter: Interpreter, op: Operation, args: PythonValues
) -> PythonValues:
arg: str
(arg,) = args
interpreter.print(arg)
return ()

# endregion
25 changes: 25 additions & 0 deletions tests/filecheck/dialects/asl/cf.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// RUN: asl-opt %s | asl-opt %s | filecheck %s

builtin.module {
asl.func @print_str.0(%x : !asl.string) -> ()
%c = asl.constant_bool true {attr_dict}
%0 = asl.bool_to_i1 %c : !asl.bool -> i1
scf.if %0 {
%1 = asl.constant_string "TRUE" {attr_dict}
asl.call @print_str.0(%1) : (!asl.string) -> ()
} else {
%2 = asl.constant_string "FALSE" {attr_dict}
asl.call @print_str.0(%2) : (!asl.string) -> ()
}

// CHECK: %c = asl.constant_bool true {attr_dict}
// CHECK-NEXT: %0 = asl.bool_to_i1 %c : !asl.bool -> i1
// CHECK-NEXT: scf.if %0 {
// CHECK-NEXT: %1 = asl.constant_string "TRUE" {attr_dict}
// CHECK-NEXT: asl.call @print_str.0(%1) : (!asl.string) -> ()
// CHECK-NEXT: } else {
// CHECK-NEXT: %2 = asl.constant_string "FALSE" {attr_dict}
// CHECK-NEXT: asl.call @print_str.0(%2) : (!asl.string) -> ()
// CHECK-NEXT: }

}
3 changes: 3 additions & 0 deletions tests/filecheck/dialects/asl/constant_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@ builtin.module {
%fourty_two = asl.constant_int 42 {attr_dict}

%fourty_two_bits = asl.constant_bits 42 : !asl.bits<32> {attr_dict}

%fourty_two_string = asl.constant_string "Forty Two" {attr_dict}
}

// CHECK: builtin.module {
// CHECK-NEXT: %true = asl.constant_bool true {attr_dict}
// CHECK-NEXT: %false = asl.constant_bool true {attr_dict}
// CHECK-NEXT: %fourty_two = asl.constant_int 42 {attr_dict}
// CHECK-NEXT: %fourty_two_bits = asl.constant_bits 42 : !asl.bits<32> {attr_dict}
// CHECK-NEXT: %fourty_two_string = asl.constant_string "Forty Two" {attr_dict}
// CHECK-NEXT: }
12 changes: 6 additions & 6 deletions tests/filecheck/dialects/asl/primitives.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,24 @@ builtin.module {
%int1, %int2 = "test.op"() : () -> (!asl.int, !asl.int)
// CHECK-NEXT: %int1, %int2 = "test.op"() : () -> (!asl.int, !asl.int)

%negate_int = asl.negate_int %int1 : !asl.int -> !asl.int
%neg_int = asl.neg_int %int1 : !asl.int -> !asl.int
%add_int = asl.add_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int
%sub_int = asl.sub_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int
%mul_int = asl.mul_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int
%exp_int = asl.exp_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int
%shiftleft_int = asl.shiftleft_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int
%shiftright_int = asl.shiftright_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int
%shiftleft_int = asl.shl_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int
%shiftright_int = asl.shr_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int
%div_int = asl.div_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int
%fdiv_int = asl.fdiv_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int
%frem_int = asl.frem_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int

// CHECK-NEXT: %negate_int = asl.negate_int %int1
// CHECK-NEXT: %neg_int = asl.neg_int %int1
// CHECK-NEXT: %add_int = asl.add_int %int1, %int2
// CHECK-NEXT: %sub_int = asl.sub_int %int1, %int2
// CHECK-NEXT: %mul_int = asl.mul_int %int1, %int2
// CHECK-NEXT: %exp_int = asl.exp_int %int1, %int2
// CHECK-NEXT: %shiftleft_int = asl.shiftleft_int %int1, %int2
// CHECK-NEXT: %shiftright_int = asl.shiftright_int %int1, %int2
// CHECK-NEXT: %shiftleft_int = asl.shl_int %int1, %int2
// CHECK-NEXT: %shiftright_int = asl.shr_int %int1, %int2
// CHECK-NEXT: %div_int = asl.div_int %int1, %int2
// CHECK-NEXT: %fdiv_int = asl.fdiv_int %int1, %int2
// CHECK-NEXT: %frem_int = asl.frem_int %int1, %int2
Expand Down
Loading
Loading