Skip to content

Commit

Permalink
Recursively add attribute information to the IRDL check
Browse files Browse the repository at this point in the history
  • Loading branch information
math-fehr committed May 20, 2024
1 parent 2049e15 commit eddc5ea
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 4 deletions.
25 changes: 22 additions & 3 deletions xdsl_pdl/analysis/check_subset_to_z3.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
variables represent a subset of other IRDL variables.
"""

from typing import Any, Callable
from typing import Any, Callable, Sequence
from xdsl.traits import SymbolTable
from xdsl.utils.hints import isa
import z3
Expand All @@ -18,6 +18,7 @@
StringAttr,
)
from xdsl.dialects.irdl import (
AllOfOp,
AnyOfOp,
AnyOp,
AttributeOp,
Expand Down Expand Up @@ -55,6 +56,12 @@ def add_attribute_constructors_from_irdl(
)


def create_z3_attribute(attribute_sort: Any, attr_name: str, *parameters: Any) -> Any:
if parameters:
return attribute_sort.__dict__[attr_name](*parameters)
return attribute_sort.__dict__[attr_name]


def convert_attr_to_z3_attr(attr: Attribute, attribute_sort: Any) -> Any:
if attr == IndexType():
return attribute_sort.__dict__["builtin.index"]
Expand Down Expand Up @@ -104,6 +111,14 @@ def get_constraint_as_z3(
)
)
return
if isinstance(op, AllOfOp):
if not op.operands:
values_to_z3[op.output] = create_value(op.output)
return
values_to_z3[op.output] = values_to_z3[op.operands[0]]
for operand in op.operands[1:]:
add_constraint(values_to_z3[op.output] == values_to_z3[operand])
return
if isinstance(op, IsOp):
values_to_z3[op.output] = convert_attr_to_z3_attr(op.expected, attribute_sort)
return
Expand All @@ -122,7 +137,9 @@ def get_constraint_as_z3(
)
values_to_z3[op.output] = create_value(op.output)
add_constraint(
attribute_sort.__dict__["is_" + attribute_name](values_to_z3[op.output])
create_z3_attribute(
attribute_sort, "is_" + attribute_name, values_to_z3[op.output]
)
)
return
if isinstance(op, ParametricOp):
Expand All @@ -135,7 +152,9 @@ def get_constraint_as_z3(
parameters = [values_to_z3[arg] for arg in op.args]
attribute_name = dialect_def.sym_name.data + "." + base_attr_def.sym_name.data

values_to_z3[op.output] = attribute_sort.__dict__[attribute_name](*parameters)
values_to_z3[op.output] = create_z3_attribute(
attribute_sort, attribute_name, *parameters
)
return
if isinstance(op, EqOp):
val0 = values_to_z3[op.args[0]]
Expand Down
100 changes: 99 additions & 1 deletion xdsl_pdl/passes/pdl_to_irdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@
TypeOp,
)

from xdsl.dialects.builtin import IntegerType, SymbolRefAttr, ModuleOp
from xdsl.dialects.builtin import IntegerType, SymbolRefAttr, ModuleOp, UnitAttr
from xdsl.dialects import irdl
from xdsl.traits import SymbolTable
from xdsl.utils.hints import isa
from z3 import Symbol
from xdsl_pdl.dialects.irdl_extension import CheckSubsetOp, EqOp, YieldOp


Expand Down Expand Up @@ -332,6 +334,100 @@ def convert_pdl_match_to_irdl_match(
walker.rewrite_op(program)


def get_op_ref_outside_dialect(
op_ref: SymbolRefAttr, location: Operation
) -> SymbolRefAttr:
"""Get an operation reference outside of the dialect."""
base_def = SymbolTable.lookup_symbol(location, op_ref)
assert base_def is not None
assert isinstance(base_def, irdl.AttributeOp | irdl.TypeOp)
base_def_dialect = base_def.parent_op()
assert isinstance(base_def_dialect, irdl.DialectOp)
new_ref = SymbolRefAttr(base_def_dialect.sym_name, [base_def.sym_name])
return new_ref


def create_param_attr_constraint_from_definition(
attr_def: irdl.TypeOp | irdl.AttributeOp,
rewriter: PatternRewriter,
) -> irdl.ParametricOp:
"""Clone the constraints on an attribute parameters at a given location."""
cloned_attr_def = attr_def.clone()
parameters = []
for cloned_op, op in zip(
cloned_attr_def.body.walk(), attr_def.body.walk(), strict=True
):
cloned_op.detach()
if isinstance(cloned_op, irdl.BaseOp) and cloned_op.base_ref is not None:
cloned_op.base_ref = get_op_ref_outside_dialect(cloned_op.base_ref, op)
if isinstance(cloned_op, irdl.ParametricOp):
cloned_op.base_type = get_op_ref_outside_dialect(cloned_op.base_type, op)
if isinstance(cloned_op, irdl.ParametersOp):
parameters = cloned_op.args
cloned_op.erase()
continue
rewriter.insert_op_before_matched_op(cloned_op)
cloned_attr_def.erase()

parent_dialect = attr_def.parent_op()
assert isinstance(parent_dialect, irdl.DialectOp)

param_op = irdl.ParametricOp(
SymbolRefAttr(parent_dialect.sym_name, [attr_def.sym_name]), parameters
)
rewriter.insert_op_before_matched_op(param_op)
return param_op


@dataclass
class EmbedIRDLAttrPattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(
self, op: irdl.BaseOp | irdl.ParametricOp, rewriter: PatternRewriter, /
):
if "processed" in op.attributes:
return
if isinstance(op, irdl.BaseOp):
# We cannot unfold attributes that are not from the IRDL module
if op.base_name is not None:
return
assert op.base_ref is not None
attr_def = SymbolTable.lookup_symbol(op, op.base_ref)
assert attr_def is not None
assert isinstance(attr_def, irdl.AttributeOp | irdl.TypeOp)
param_op = create_param_attr_constraint_from_definition(attr_def, rewriter)
param_op.attributes["processed"] = UnitAttr()
op.attributes["processed"] = UnitAttr()
cloned_op = op.clone()
rewriter.insert_op_before_matched_op(cloned_op)
rewriter.replace_matched_op(
irdl.AllOfOp([cloned_op.output, param_op.output])
)
return

attr_def = SymbolTable.lookup_symbol(op, op.base_type)
assert attr_def is not None
assert isinstance(attr_def, irdl.AttributeOp | irdl.TypeOp)
param_op = create_param_attr_constraint_from_definition(attr_def, rewriter)
param_op.attributes["processed"] = UnitAttr()
op.attributes["processed"] = UnitAttr()
cloned_op = op.clone()
rewriter.insert_op_before_matched_op(cloned_op)
rewriter.replace_matched_op(irdl.AllOfOp([cloned_op.output, param_op.output]))
return


def embed_irdl_attr_verifiers(op: Operation):
walker = PatternRewriteWalker(
GreedyRewritePatternApplier(
[
EmbedIRDLAttrPattern(),
]
)
)
walker.rewrite_op(op)


class PDLToIRDLPass(ModulePass):
def apply(self, ctx: MLContext, op: ModuleOp):
# Grab the rewrite operation which should be the last one
Expand Down Expand Up @@ -363,3 +459,5 @@ def apply(self, ctx: MLContext, op: ModuleOp):

# Convert the remaining PDL operations to IRDL operations
convert_pdl_match_to_irdl_match(check_subset, irdl_ops)

embed_irdl_attr_verifiers(check_subset)

0 comments on commit eddc5ea

Please sign in to comment.