Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend ASL support: strings, comparisons, etc. #1

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 96 additions & 6 deletions asl_xdsl/dialects/asl.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,33 @@

def print_parameter(self, printer: Printer) -> None:
"""Print the attribute parameter."""
printer.print("true" if self.data else "false")
printer.print("<true>" if self.data else "<false>")


@irdl_attr_definition
class StringType(ParametrizedAttribute, TypeAttribute):
"""A string type."""

name = "asl.string"


@irdl_attr_definition
class StringAttr(Data[str]):
Comment on lines +164 to +165
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would recommend using the StringAttr in the builtin dialect

"""A string attribute."""

name = "asl.string_attr"

@classmethod
def parse_parameter(cls, parser: AttrParser) -> bool:

Check failure on line 171 in asl_xdsl/dialects/asl.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Method "parse_parameter" overrides class "Data" in an incompatible manner   Return type mismatch: base method returns type "str", override returns type "bool"     "bool" is not assignable to "str" (reportIncompatibleMethodOverride)

Check failure on line 171 in asl_xdsl/dialects/asl.py

View workflow job for this annotation

GitHub Actions / build (3.12)

Method "parse_parameter" overrides class "Data" in an incompatible manner   Return type mismatch: base method returns type "str", override returns type "bool"     "bool" is not assignable to "str" (reportIncompatibleMethodOverride)
"""Parse the attribute parameter."""
parser.parse_characters('"')
value = parser.parse_string()

Check failure on line 174 in asl_xdsl/dialects/asl.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Type of "value" is unknown (reportUnknownVariableType)

Check failure on line 174 in asl_xdsl/dialects/asl.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Type of "parse_string" is unknown (reportUnknownMemberType)

Check failure on line 174 in asl_xdsl/dialects/asl.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Cannot access attribute "parse_string" for class "AttrParser"   Attribute "parse_string" is unknown (reportAttributeAccessIssue)

Check failure on line 174 in asl_xdsl/dialects/asl.py

View workflow job for this annotation

GitHub Actions / build (3.12)

Type of "value" is unknown (reportUnknownVariableType)

Check failure on line 174 in asl_xdsl/dialects/asl.py

View workflow job for this annotation

GitHub Actions / build (3.12)

Type of "parse_string" is unknown (reportUnknownMemberType)

Check failure on line 174 in asl_xdsl/dialects/asl.py

View workflow job for this annotation

GitHub Actions / build (3.12)

Cannot access attribute "parse_string" for class "AttrParser"   Attribute "parse_string" is unknown (reportAttributeAccessIssue)
parser.parse_characters('"')
return value

Check failure on line 176 in asl_xdsl/dialects/asl.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Return type is unknown (reportUnknownVariableType)

Check failure on line 176 in asl_xdsl/dialects/asl.py

View workflow job for this annotation

GitHub Actions / build (3.12)

Return type is unknown (reportUnknownVariableType)

def print_parameter(self, printer: Printer) -> None:
"""Print the attribute parameter."""
printer.print(f'"{self.data}"')


@irdl_attr_definition
Expand Down Expand Up @@ -323,7 +349,7 @@
@classmethod
def parse(cls, parser: Parser) -> ConstantIntOp:
"""Parse the operation."""
value = parser.parse_integer(allow_boolean=False, allow_negative=False)
value = parser.parse_integer(allow_boolean=False, allow_negative=True)
attr_dict = parser.parse_optional_attr_dict()
return ConstantIntOp(value, attr_dict)

Expand Down Expand Up @@ -377,6 +403,41 @@
printer.print_attr_dict(self.attributes)


@irdl_op_definition
class ConstantStringOp(IRDLOperation):
"""A constant string operation."""

name = "asl.constant_string"

value = prop_def(builtin.StringAttr)
res = result_def(StringType)

def __init__(
self, value: str | builtin.StringAttr, attr_dict: Mapping[str, Attribute] = {}
):
if isinstance(value, str):
value = builtin.StringAttr(value)
super().__init__(
result_types=[StringType()],
properties={"value": value},
attributes=attr_dict,
)

@classmethod
def parse(cls, parser: Parser) -> ConstantStringOp:
"""Parse the operation."""
value = parser.parse_str_literal()
attr_dict = parser.parse_optional_attr_dict()
return ConstantStringOp(value, attr_dict)

def print(self, printer: Printer) -> None:
"""Print the operation."""
printer.print(" ", '"' + self.value.data + '"')
if self.attributes:
printer.print(" ")
printer.print_attr_dict(self.attributes)
Comment on lines +426 to +438
Copy link
Member

@superlopuh superlopuh Jan 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would something like this work instead?

assembly_format = "$value attr-dict"



@irdl_op_definition
class NotOp(IRDLOperation):
"""A bitwise NOT operation."""
Expand All @@ -396,6 +457,25 @@
)


@irdl_op_definition
class BoolToI1Op(IRDLOperation):
"""A hack to convert !asl.bool to i1 so that we can use scf.if."""

name = "asl.bool_to_i1"

arg = operand_def(BoolType())
res = result_def(builtin.IntegerType(1))

assembly_format = "$arg `:` type($arg) `->` type($res) attr-dict"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why, but the attr dict tends to be before the type by convention

Suggested change
assembly_format = "$arg `:` type($arg) `->` type($res) attr-dict"
assembly_format = "$arg attr-dict `:` type($arg) `->` type($res)"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, that's because I moved it to the end on the other operations.
Mostly because I hate being it at the beginning, because it is "optional" in a way.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there other dialects that do this? Seems to me like we should stick with the convention if not.


def __init__(self, arg: SSAValue, attr_dict: Mapping[str, Attribute] = {}):
super().__init__(
operands=[arg],
result_types=[builtin.IntegerType(1)],
attributes=attr_dict,
)


class BinaryBoolOp(IRDLOperation):
"""A binary boolean operation."""

Expand Down Expand Up @@ -464,7 +544,7 @@
class NegateIntOp(IRDLOperation):
"""An integer negation operation."""

name = "asl.negate_int"
name = "asl.neg_int"

arg = operand_def(IntegerType)
res = result_def(IntegerType)
Expand Down Expand Up @@ -535,14 +615,14 @@
class ShiftLeftIntOp(BinaryIntOp):
"""An integer left shift operation."""

name = "asl.shiftleft_int"
name = "asl.shl_int"


@irdl_op_definition
class ShiftRightIntOp(BinaryIntOp):
"""An integer right shift operation."""

name = "asl.shiftright_int"
name = "asl.shr_int"


@irdl_op_definition
Expand Down Expand Up @@ -1052,7 +1132,9 @@
ConstantBoolOp,
ConstantIntOp,
ConstantBitVectorOp,
ConstantStringOp,
# Boolean operations
BoolToI1Op,
NotOp,
AndBoolOp,
OrBoolOp,
Expand Down Expand Up @@ -1095,5 +1177,13 @@
# Slices
SliceSingleOp,
],
[BoolType, BoolAttr, IntegerType, BitVectorType, BitVectorAttr],
[
BoolType,
BoolAttr,
IntegerType,
BitVectorType,
BitVectorAttr,
StringType,
StringAttr,
],
)
175 changes: 172 additions & 3 deletions asl_xdsl/interpreters/asl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any

from xdsl.dialects import scf
from xdsl.interpreter import (
Interpreter,
InterpreterFunctions,
Expand Down Expand Up @@ -40,6 +41,40 @@
) -> tuple[Any, ...]:
return interpreter.call_op(op.callee.string_value(), args)

@impl(scf.IfOp)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm surprised this is necessary, xDSL already has an implementation for scf.If, I would have expected it to be enough

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just pushed the change that includes the xDSL version of this to main

def run_if(

Check failure on line 45 in asl_xdsl/interpreters/asl.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Return type, "tuple[Unknown, ...]", is partially unknown (reportUnknownParameterType)

Check failure on line 45 in asl_xdsl/interpreters/asl.py

View workflow job for this annotation

GitHub Actions / build (3.12)

Return type, "tuple[Unknown, ...]", is partially unknown (reportUnknownParameterType)
self,
interpreter: Interpreter,
op: scf.IfOp,
args: tuple[Any, ...],
):
cond = args[0]
results = [] # hack that is maybe good enough for the print_bool function
if cond:
args = interpreter.run_ssacfg_region(
op.true_region, tuple(results), "then_region"

Check failure on line 55 in asl_xdsl/interpreters/asl.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Argument type is partially unknown   Argument corresponds to parameter "args" in function "run_ssacfg_region"   Argument type is "tuple[Unknown, ...]" (reportUnknownArgumentType)

Check failure on line 55 in asl_xdsl/interpreters/asl.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Argument type is partially unknown   Argument corresponds to parameter "iterable" in function "__new__"   Argument type is "list[Unknown]" (reportUnknownArgumentType)

Check failure on line 55 in asl_xdsl/interpreters/asl.py

View workflow job for this annotation

GitHub Actions / build (3.12)

Argument type is partially unknown   Argument corresponds to parameter "args" in function "run_ssacfg_region"   Argument type is "tuple[Unknown, ...]" (reportUnknownArgumentType)

Check failure on line 55 in asl_xdsl/interpreters/asl.py

View workflow job for this annotation

GitHub Actions / build (3.12)

Argument type is partially unknown   Argument corresponds to parameter "iterable" in function "__new__"   Argument type is "list[Unknown]" (reportUnknownArgumentType)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
op.true_region, tuple(results), "then_region"
op.true_region, (), "then_region"

)
else:
args = interpreter.run_ssacfg_region(
op.false_region, tuple(results), "else_region"

Check failure on line 59 in asl_xdsl/interpreters/asl.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Argument type is partially unknown   Argument corresponds to parameter "args" in function "run_ssacfg_region"   Argument type is "tuple[Unknown, ...]" (reportUnknownArgumentType)

Check failure on line 59 in asl_xdsl/interpreters/asl.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Argument type is partially unknown   Argument corresponds to parameter "iterable" in function "__new__"   Argument type is "list[Unknown]" (reportUnknownArgumentType)

Check failure on line 59 in asl_xdsl/interpreters/asl.py

View workflow job for this annotation

GitHub Actions / build (3.12)

Argument type is partially unknown   Argument corresponds to parameter "args" in function "run_ssacfg_region"   Argument type is "tuple[Unknown, ...]" (reportUnknownArgumentType)

Check failure on line 59 in asl_xdsl/interpreters/asl.py

View workflow job for this annotation

GitHub Actions / build (3.12)

Argument type is partially unknown   Argument corresponds to parameter "iterable" in function "__new__"   Argument type is "list[Unknown]" (reportUnknownArgumentType)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
op.false_region, tuple(results), "else_region"
op.false_region, (), "else_region"

)

return tuple(results)

@impl_terminator(scf.YieldOp)
def run_yield(
self, interpreter: Interpreter, op: scf.YieldOp, args: tuple[Any, ...]
):
return ReturnedValues(args), ()

@impl(asl.NegateIntOp)
def run_neg_int(
self, interpreter: Interpreter, op: asl.NegateIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
arg: int
[arg] = args
return (0 - arg,)

@impl(asl.AddIntOp)
def run_add_int(
self, interpreter: Interpreter, op: asl.AddIntOp, args: tuple[Any, ...]
Expand All @@ -49,16 +84,132 @@
(lhs, rhs) = args
return (lhs + rhs,)

@impl(asl.SubIntOp)
def run_sub_int(
self, interpreter: Interpreter, op: asl.SubIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
lhs: int
rhs: int
(lhs, rhs) = args
return (lhs - rhs,)

@impl(asl.MulIntOp)
def run_mul_int(
self, interpreter: Interpreter, op: asl.MulIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
lhs: int
rhs: int
(lhs, rhs) = args
return (lhs * rhs,)

@impl(asl.ShiftLeftIntOp)
def run_shl_int(
self, interpreter: Interpreter, op: asl.ShiftLeftIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
lhs: int
rhs: int
(lhs, rhs) = args
assert rhs >= 0
return (lhs << rhs,)

@impl(asl.ShiftRightIntOp)
def run_shr_int(
self, interpreter: Interpreter, op: asl.ShiftRightIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
lhs: int
rhs: int
(lhs, rhs) = args
assert rhs >= 0
return (lhs >> rhs,)

@impl(asl.EqIntOp)
def run_eq_int(
self, interpreter: Interpreter, op: asl.EqIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
lhs: int
rhs: int
(lhs, rhs) = args
return (lhs == rhs,)

@impl(asl.NeIntOp)
def run_ne_int(
self, interpreter: Interpreter, op: asl.NeIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
lhs: int
rhs: int
(lhs, rhs) = args
return (lhs != rhs,)

@impl(asl.LeIntOp)
def run_le_int(
self, interpreter: Interpreter, op: asl.LeIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
lhs: int
rhs: int
(lhs, rhs) = args
return (lhs <= rhs,)

@impl(asl.LtIntOp)
def run_lt_int(
self, interpreter: Interpreter, op: asl.LtIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
lhs: int
rhs: int
(lhs, rhs) = args
return (lhs < rhs,)

@impl(asl.GeIntOp)
def run_ge_int(
self, interpreter: Interpreter, op: asl.GeIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
lhs: int
rhs: int
(lhs, rhs) = args
return (lhs >= rhs,)

@impl(asl.GtIntOp)
def run_gt_int(
self, interpreter: Interpreter, op: asl.GtIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
lhs: int
rhs: int
(lhs, rhs) = args
return (lhs > rhs,)

@impl(asl.ConstantIntOp)
def run_constant(
def run_constant_int(
self, interpreter: Interpreter, op: asl.ConstantIntOp, args: PythonValues
) -> PythonValues:
value = op.value
return (value.data,)

@impl(asl.ConstantStringOp)
def run_constant_string(
self, interpreter: Interpreter, op: asl.ConstantStringOp, args: PythonValues
) -> PythonValues:
value = op.value
return (value.data,)

@impl(asl.BoolToI1Op)
def run_bool_to_i1(
self, interpreter: Interpreter, op: asl.BoolToI1Op, args: PythonValues
) -> PythonValues:
arg: int
[arg] = args
return (arg,)

# region built-in function implementations

@impl_external("asl_print_int_dec")
@impl_external("print_bits_hex.0")
def asl_print_bits_hex(
self, interpreter: Interpreter, op: Operation, args: PythonValues
) -> PythonValues:
arg: int
(arg,) = args
interpreter.print(hex(arg))
return ()

@impl_external("print_int_dec.0")
def asl_print_int_dec(
self, interpreter: Interpreter, op: Operation, args: PythonValues
) -> PythonValues:
Expand All @@ -67,7 +218,16 @@
interpreter.print(arg)
return ()

@impl_external("asl_print_char")
@impl_external("print_int_hex.0")
def asl_print_int_hex(
self, interpreter: Interpreter, op: Operation, args: PythonValues
) -> PythonValues:
arg: int
(arg,) = args
interpreter.print(hex(arg))
return ()

@impl_external("print_char.0")
def asl_print_char(
self, interpreter: Interpreter, op: Operation, args: PythonValues
) -> PythonValues:
Expand All @@ -76,4 +236,13 @@
interpreter.print(chr(arg))
return ()

@impl_external("print_str.0")
def asl_print_string(
self, interpreter: Interpreter, op: Operation, args: PythonValues
) -> PythonValues:
arg: str
(arg,) = args
interpreter.print(arg)
return ()

# endregion
Loading
Loading