-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Extend ASL support: strings, comparisons, etc. #1
base: main
Are you sure you want to change the base?
Changes from all commits
9a0cc79
c280e29
fbd25c3
acfa33e
8d6ab37
ffab6a3
cb451ed
e05e161
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -151,7 +151,33 @@ | |||||
|
||||||
def print_parameter(self, printer: Printer) -> None: | ||||||
"""Print the attribute parameter.""" | ||||||
printer.print("true" if self.data else "false") | ||||||
printer.print("<true>" if self.data else "<false>") | ||||||
|
||||||
|
||||||
@irdl_attr_definition | ||||||
class StringType(ParametrizedAttribute, TypeAttribute): | ||||||
"""A string type.""" | ||||||
|
||||||
name = "asl.string" | ||||||
|
||||||
|
||||||
@irdl_attr_definition | ||||||
class StringAttr(Data[str]): | ||||||
"""A string attribute.""" | ||||||
|
||||||
name = "asl.string_attr" | ||||||
|
||||||
@classmethod | ||||||
def parse_parameter(cls, parser: AttrParser) -> bool: | ||||||
Check failure on line 171 in asl_xdsl/dialects/asl.py GitHub Actions / build (3.10)
Check failure on line 171 in asl_xdsl/dialects/asl.py GitHub Actions / build (3.12)
|
||||||
"""Parse the attribute parameter.""" | ||||||
parser.parse_characters('"') | ||||||
value = parser.parse_string() | ||||||
Check failure on line 174 in asl_xdsl/dialects/asl.py GitHub Actions / build (3.10)
Check failure on line 174 in asl_xdsl/dialects/asl.py GitHub Actions / build (3.10)
Check failure on line 174 in asl_xdsl/dialects/asl.py GitHub Actions / build (3.10)
Check failure on line 174 in asl_xdsl/dialects/asl.py GitHub Actions / build (3.12)
Check failure on line 174 in asl_xdsl/dialects/asl.py GitHub Actions / build (3.12)
|
||||||
parser.parse_characters('"') | ||||||
return value | ||||||
Check failure on line 176 in asl_xdsl/dialects/asl.py GitHub Actions / build (3.10)
|
||||||
|
||||||
def print_parameter(self, printer: Printer) -> None: | ||||||
"""Print the attribute parameter.""" | ||||||
printer.print(f'"{self.data}"') | ||||||
|
||||||
|
||||||
@irdl_attr_definition | ||||||
|
@@ -323,7 +349,7 @@ | |||||
@classmethod | ||||||
def parse(cls, parser: Parser) -> ConstantIntOp: | ||||||
"""Parse the operation.""" | ||||||
value = parser.parse_integer(allow_boolean=False, allow_negative=False) | ||||||
value = parser.parse_integer(allow_boolean=False, allow_negative=True) | ||||||
attr_dict = parser.parse_optional_attr_dict() | ||||||
return ConstantIntOp(value, attr_dict) | ||||||
|
||||||
|
@@ -377,6 +403,41 @@ | |||||
printer.print_attr_dict(self.attributes) | ||||||
|
||||||
|
||||||
@irdl_op_definition | ||||||
class ConstantStringOp(IRDLOperation): | ||||||
"""A constant string operation.""" | ||||||
|
||||||
name = "asl.constant_string" | ||||||
|
||||||
value = prop_def(builtin.StringAttr) | ||||||
res = result_def(StringType) | ||||||
|
||||||
def __init__( | ||||||
self, value: str | builtin.StringAttr, attr_dict: Mapping[str, Attribute] = {} | ||||||
): | ||||||
if isinstance(value, str): | ||||||
value = builtin.StringAttr(value) | ||||||
super().__init__( | ||||||
result_types=[StringType()], | ||||||
properties={"value": value}, | ||||||
attributes=attr_dict, | ||||||
) | ||||||
|
||||||
@classmethod | ||||||
def parse(cls, parser: Parser) -> ConstantStringOp: | ||||||
"""Parse the operation.""" | ||||||
value = parser.parse_str_literal() | ||||||
attr_dict = parser.parse_optional_attr_dict() | ||||||
return ConstantStringOp(value, attr_dict) | ||||||
|
||||||
def print(self, printer: Printer) -> None: | ||||||
"""Print the operation.""" | ||||||
printer.print(" ", '"' + self.value.data + '"') | ||||||
if self.attributes: | ||||||
printer.print(" ") | ||||||
printer.print_attr_dict(self.attributes) | ||||||
Comment on lines
+426
to
+438
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would something like this work instead? assembly_format = "$value attr-dict" |
||||||
|
||||||
|
||||||
@irdl_op_definition | ||||||
class NotOp(IRDLOperation): | ||||||
"""A bitwise NOT operation.""" | ||||||
|
@@ -396,6 +457,25 @@ | |||||
) | ||||||
|
||||||
|
||||||
@irdl_op_definition | ||||||
class BoolToI1Op(IRDLOperation): | ||||||
"""A hack to convert !asl.bool to i1 so that we can use scf.if.""" | ||||||
|
||||||
name = "asl.bool_to_i1" | ||||||
|
||||||
arg = operand_def(BoolType()) | ||||||
res = result_def(builtin.IntegerType(1)) | ||||||
|
||||||
assembly_format = "$arg `:` type($arg) `->` type($res) attr-dict" | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure why, but the attr dict tends to be before the type by convention
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, that's because I moved it to the end on the other operations. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are there other dialects that do this? Seems to me like we should stick with the convention if not. |
||||||
|
||||||
def __init__(self, arg: SSAValue, attr_dict: Mapping[str, Attribute] = {}): | ||||||
super().__init__( | ||||||
operands=[arg], | ||||||
result_types=[builtin.IntegerType(1)], | ||||||
attributes=attr_dict, | ||||||
) | ||||||
|
||||||
|
||||||
class BinaryBoolOp(IRDLOperation): | ||||||
"""A binary boolean operation.""" | ||||||
|
||||||
|
@@ -464,7 +544,7 @@ | |||||
class NegateIntOp(IRDLOperation): | ||||||
"""An integer negation operation.""" | ||||||
|
||||||
name = "asl.negate_int" | ||||||
name = "asl.neg_int" | ||||||
|
||||||
arg = operand_def(IntegerType) | ||||||
res = result_def(IntegerType) | ||||||
|
@@ -535,14 +615,14 @@ | |||||
class ShiftLeftIntOp(BinaryIntOp): | ||||||
"""An integer left shift operation.""" | ||||||
|
||||||
name = "asl.shiftleft_int" | ||||||
name = "asl.shl_int" | ||||||
|
||||||
|
||||||
@irdl_op_definition | ||||||
class ShiftRightIntOp(BinaryIntOp): | ||||||
"""An integer right shift operation.""" | ||||||
|
||||||
name = "asl.shiftright_int" | ||||||
name = "asl.shr_int" | ||||||
|
||||||
|
||||||
@irdl_op_definition | ||||||
|
@@ -1052,7 +1132,9 @@ | |||||
ConstantBoolOp, | ||||||
ConstantIntOp, | ||||||
ConstantBitVectorOp, | ||||||
ConstantStringOp, | ||||||
# Boolean operations | ||||||
BoolToI1Op, | ||||||
NotOp, | ||||||
AndBoolOp, | ||||||
OrBoolOp, | ||||||
|
@@ -1095,5 +1177,13 @@ | |||||
# Slices | ||||||
SliceSingleOp, | ||||||
], | ||||||
[BoolType, BoolAttr, IntegerType, BitVectorType, BitVectorAttr], | ||||||
[ | ||||||
BoolType, | ||||||
BoolAttr, | ||||||
IntegerType, | ||||||
BitVectorType, | ||||||
BitVectorAttr, | ||||||
StringType, | ||||||
StringAttr, | ||||||
], | ||||||
) |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,5 +1,6 @@ | ||||||
from typing import Any | ||||||
|
||||||
from xdsl.dialects import scf | ||||||
from xdsl.interpreter import ( | ||||||
Interpreter, | ||||||
InterpreterFunctions, | ||||||
|
@@ -40,6 +41,40 @@ | |||||
) -> tuple[Any, ...]: | ||||||
return interpreter.call_op(op.callee.string_value(), args) | ||||||
|
||||||
@impl(scf.IfOp) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm surprised this is necessary, xDSL already has an implementation for scf.If, I would have expected it to be enough There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I just pushed the change that includes the xDSL version of this to main |
||||||
def run_if( | ||||||
Check failure on line 45 in asl_xdsl/interpreters/asl.py GitHub Actions / build (3.10)
|
||||||
self, | ||||||
interpreter: Interpreter, | ||||||
op: scf.IfOp, | ||||||
args: tuple[Any, ...], | ||||||
): | ||||||
cond = args[0] | ||||||
results = [] # hack that is maybe good enough for the print_bool function | ||||||
if cond: | ||||||
args = interpreter.run_ssacfg_region( | ||||||
op.true_region, tuple(results), "then_region" | ||||||
Check failure on line 55 in asl_xdsl/interpreters/asl.py GitHub Actions / build (3.10)
Check failure on line 55 in asl_xdsl/interpreters/asl.py GitHub Actions / build (3.10)
Check failure on line 55 in asl_xdsl/interpreters/asl.py GitHub Actions / build (3.12)
|
||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
) | ||||||
else: | ||||||
args = interpreter.run_ssacfg_region( | ||||||
op.false_region, tuple(results), "else_region" | ||||||
Check failure on line 59 in asl_xdsl/interpreters/asl.py GitHub Actions / build (3.10)
Check failure on line 59 in asl_xdsl/interpreters/asl.py GitHub Actions / build (3.10)
Check failure on line 59 in asl_xdsl/interpreters/asl.py GitHub Actions / build (3.12)
|
||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
) | ||||||
|
||||||
return tuple(results) | ||||||
|
||||||
@impl_terminator(scf.YieldOp) | ||||||
def run_yield( | ||||||
self, interpreter: Interpreter, op: scf.YieldOp, args: tuple[Any, ...] | ||||||
): | ||||||
return ReturnedValues(args), () | ||||||
|
||||||
@impl(asl.NegateIntOp) | ||||||
def run_neg_int( | ||||||
self, interpreter: Interpreter, op: asl.NegateIntOp, args: tuple[Any, ...] | ||||||
) -> tuple[Any, ...]: | ||||||
arg: int | ||||||
[arg] = args | ||||||
return (0 - arg,) | ||||||
|
||||||
@impl(asl.AddIntOp) | ||||||
def run_add_int( | ||||||
self, interpreter: Interpreter, op: asl.AddIntOp, args: tuple[Any, ...] | ||||||
|
@@ -49,16 +84,132 @@ | |||||
(lhs, rhs) = args | ||||||
return (lhs + rhs,) | ||||||
|
||||||
@impl(asl.SubIntOp) | ||||||
def run_sub_int( | ||||||
self, interpreter: Interpreter, op: asl.SubIntOp, args: tuple[Any, ...] | ||||||
) -> tuple[Any, ...]: | ||||||
lhs: int | ||||||
rhs: int | ||||||
(lhs, rhs) = args | ||||||
return (lhs - rhs,) | ||||||
|
||||||
@impl(asl.MulIntOp) | ||||||
def run_mul_int( | ||||||
self, interpreter: Interpreter, op: asl.MulIntOp, args: tuple[Any, ...] | ||||||
) -> tuple[Any, ...]: | ||||||
lhs: int | ||||||
rhs: int | ||||||
(lhs, rhs) = args | ||||||
return (lhs * rhs,) | ||||||
|
||||||
@impl(asl.ShiftLeftIntOp) | ||||||
def run_shl_int( | ||||||
self, interpreter: Interpreter, op: asl.ShiftLeftIntOp, args: tuple[Any, ...] | ||||||
) -> tuple[Any, ...]: | ||||||
lhs: int | ||||||
rhs: int | ||||||
(lhs, rhs) = args | ||||||
assert rhs >= 0 | ||||||
return (lhs << rhs,) | ||||||
|
||||||
@impl(asl.ShiftRightIntOp) | ||||||
def run_shr_int( | ||||||
self, interpreter: Interpreter, op: asl.ShiftRightIntOp, args: tuple[Any, ...] | ||||||
) -> tuple[Any, ...]: | ||||||
lhs: int | ||||||
rhs: int | ||||||
(lhs, rhs) = args | ||||||
assert rhs >= 0 | ||||||
return (lhs >> rhs,) | ||||||
|
||||||
@impl(asl.EqIntOp) | ||||||
def run_eq_int( | ||||||
self, interpreter: Interpreter, op: asl.EqIntOp, args: tuple[Any, ...] | ||||||
) -> tuple[Any, ...]: | ||||||
lhs: int | ||||||
rhs: int | ||||||
(lhs, rhs) = args | ||||||
return (lhs == rhs,) | ||||||
|
||||||
@impl(asl.NeIntOp) | ||||||
def run_ne_int( | ||||||
self, interpreter: Interpreter, op: asl.NeIntOp, args: tuple[Any, ...] | ||||||
) -> tuple[Any, ...]: | ||||||
lhs: int | ||||||
rhs: int | ||||||
(lhs, rhs) = args | ||||||
return (lhs != rhs,) | ||||||
|
||||||
@impl(asl.LeIntOp) | ||||||
def run_le_int( | ||||||
self, interpreter: Interpreter, op: asl.LeIntOp, args: tuple[Any, ...] | ||||||
) -> tuple[Any, ...]: | ||||||
lhs: int | ||||||
rhs: int | ||||||
(lhs, rhs) = args | ||||||
return (lhs <= rhs,) | ||||||
|
||||||
@impl(asl.LtIntOp) | ||||||
def run_lt_int( | ||||||
self, interpreter: Interpreter, op: asl.LtIntOp, args: tuple[Any, ...] | ||||||
) -> tuple[Any, ...]: | ||||||
lhs: int | ||||||
rhs: int | ||||||
(lhs, rhs) = args | ||||||
return (lhs < rhs,) | ||||||
|
||||||
@impl(asl.GeIntOp) | ||||||
def run_ge_int( | ||||||
self, interpreter: Interpreter, op: asl.GeIntOp, args: tuple[Any, ...] | ||||||
) -> tuple[Any, ...]: | ||||||
lhs: int | ||||||
rhs: int | ||||||
(lhs, rhs) = args | ||||||
return (lhs >= rhs,) | ||||||
|
||||||
@impl(asl.GtIntOp) | ||||||
def run_gt_int( | ||||||
self, interpreter: Interpreter, op: asl.GtIntOp, args: tuple[Any, ...] | ||||||
) -> tuple[Any, ...]: | ||||||
lhs: int | ||||||
rhs: int | ||||||
(lhs, rhs) = args | ||||||
return (lhs > rhs,) | ||||||
|
||||||
@impl(asl.ConstantIntOp) | ||||||
def run_constant( | ||||||
def run_constant_int( | ||||||
self, interpreter: Interpreter, op: asl.ConstantIntOp, args: PythonValues | ||||||
) -> PythonValues: | ||||||
value = op.value | ||||||
return (value.data,) | ||||||
|
||||||
@impl(asl.ConstantStringOp) | ||||||
def run_constant_string( | ||||||
self, interpreter: Interpreter, op: asl.ConstantStringOp, args: PythonValues | ||||||
) -> PythonValues: | ||||||
value = op.value | ||||||
return (value.data,) | ||||||
|
||||||
@impl(asl.BoolToI1Op) | ||||||
def run_bool_to_i1( | ||||||
self, interpreter: Interpreter, op: asl.BoolToI1Op, args: PythonValues | ||||||
) -> PythonValues: | ||||||
arg: int | ||||||
[arg] = args | ||||||
return (arg,) | ||||||
|
||||||
# region built-in function implementations | ||||||
|
||||||
@impl_external("asl_print_int_dec") | ||||||
@impl_external("print_bits_hex.0") | ||||||
def asl_print_bits_hex( | ||||||
self, interpreter: Interpreter, op: Operation, args: PythonValues | ||||||
) -> PythonValues: | ||||||
arg: int | ||||||
(arg,) = args | ||||||
interpreter.print(hex(arg)) | ||||||
return () | ||||||
|
||||||
@impl_external("print_int_dec.0") | ||||||
def asl_print_int_dec( | ||||||
self, interpreter: Interpreter, op: Operation, args: PythonValues | ||||||
) -> PythonValues: | ||||||
|
@@ -67,7 +218,16 @@ | |||||
interpreter.print(arg) | ||||||
return () | ||||||
|
||||||
@impl_external("asl_print_char") | ||||||
@impl_external("print_int_hex.0") | ||||||
def asl_print_int_hex( | ||||||
self, interpreter: Interpreter, op: Operation, args: PythonValues | ||||||
) -> PythonValues: | ||||||
arg: int | ||||||
(arg,) = args | ||||||
interpreter.print(hex(arg)) | ||||||
return () | ||||||
|
||||||
@impl_external("print_char.0") | ||||||
def asl_print_char( | ||||||
self, interpreter: Interpreter, op: Operation, args: PythonValues | ||||||
) -> PythonValues: | ||||||
|
@@ -76,4 +236,13 @@ | |||||
interpreter.print(chr(arg)) | ||||||
return () | ||||||
|
||||||
@impl_external("print_str.0") | ||||||
def asl_print_string( | ||||||
self, interpreter: Interpreter, op: Operation, args: PythonValues | ||||||
) -> PythonValues: | ||||||
arg: str | ||||||
(arg,) = args | ||||||
interpreter.print(arg) | ||||||
return () | ||||||
|
||||||
# endregion |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would recommend using the StringAttr in the builtin dialect