Skip to content

Commit

Permalink
Add way more optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
math-fehr committed May 22, 2024
1 parent 0c2e1dc commit ee01c99
Show file tree
Hide file tree
Showing 2 changed files with 208 additions and 11 deletions.
4 changes: 2 additions & 2 deletions xdsl_pdl/analysis/check_subset_to_z3.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

from xdsl.ir import Attribute, Operation, SSAValue
from xdsl.parser import IndexType, ModuleOp
from xdsl_pdl.dialects.irdl_extension import CheckSubsetOp, EqOp, YieldOp
from xdsl_pdl.dialects.irdl_extension import CheckSubsetOp, EqOp, MatchOp, YieldOp


def add_attribute_constructors_from_irdl(
Expand Down Expand Up @@ -161,7 +161,7 @@ def get_constraint_as_z3(
for arg in op.args[1:]:
add_constraint(val0 == values_to_z3[arg])
return
if isinstance(op, YieldOp):
if isinstance(op, YieldOp | MatchOp):
return
assert False, f"Unsupported op {op.name}"

Expand Down
215 changes: 206 additions & 9 deletions xdsl_pdl/passes/optimize_irdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
from xdsl.parser import SymbolRefAttr
from xdsl.passes import ModulePass

from xdsl.ir import MLContext, Operation, SSAValue
from xdsl.ir import Attribute, MLContext, Operation, SSAValue
from xdsl.dialects import irdl
from xdsl.rewriter import InsertPoint, Rewriter
from xdsl.traits import IsTerminator
from z3 import v
from xdsl_pdl.dialects import irdl_extension
from xdsl.dialects.builtin import ModuleOp
from xdsl.dialects.builtin import ModuleOp, StringAttr
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
PatternRewriteWalker,
Expand Down Expand Up @@ -58,9 +59,62 @@ def get_bases(value: SSAValue) -> set[SymbolRefAttr | str] | None:
return bases
if isinstance(op, irdl.IsOp):
# TODO: Add support for known types
if isinstance(op.expected, StringAttr):
return {"#builtin.string"}
return None


def is_rooted_dag_with_one_use(value: SSAValue) -> bool:
assert isinstance(value.owner, Operation)
if len(value.uses) != 1 and not isinstance(value.owner, irdl.IsOp):
return False

values_to_walk = [value]
walked_values = {value}

operations: list[Operation] = []

while values_to_walk:
value_to_walk = values_to_walk.pop()
assert isinstance(value_to_walk.owner, Operation)
operations.append(value_to_walk.owner)
for operand in value_to_walk.owner.operands:
if operand in walked_values:
continue
walked_values.add(operand)
values_to_walk.append(operand)

for operation in operations:
assert len(operation.results) == 1
if operation.results[0] == value:
continue
if isinstance(operation, irdl.IsOp):
continue
if isinstance(operation, irdl.ParametricOp) and not operation.args:
continue
for use in operation.results[0].uses:
if use.operation not in operations:
return False

return True


def match_attribute(
value: SSAValue, attr: Attribute, mappings: dict[SSAValue, Attribute] = {}
) -> bool:
if value in mappings:
return mappings[value] == attr
if isinstance(value.owner, irdl.IsOp):
return value.owner.expected == attr
if isinstance(value.owner, irdl.AnyOfOp):
for arg in value.owner.args:
if match_attribute(arg, attr, mappings):
mappings[value] = attr
return True
return False
assert False


class RemoveUnusedOpPattern(RewritePattern):
def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter, /):
if op.dialect_name() == "irdl" and op.results:
Expand Down Expand Up @@ -92,12 +146,86 @@ def match_and_rewrite(self, op: irdl.AnyOfOp, rewriter: PatternRewriter, /):
class AllOfIsPattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: irdl.AllOfOp, rewriter: PatternRewriter, /):
is_op: Operation
for is_arg in op.args:
if isinstance(is_arg.owner, irdl.IsOp):
new_args = [
arg for arg in op.args if arg == is_arg or len(arg.uses) != 1
]
rewriter.replace_matched_op(irdl.AllOfOp(new_args))
is_op = is_arg.owner
break
else:
return

new_args: list[SSAValue] = []
for arg in op.args:
if arg == is_arg:
new_args.append(arg)
continue
if not is_rooted_dag_with_one_use(arg):
new_args.append(arg)
continue
if match_attribute(is_arg, is_op.expected):
continue

# Contradiction in the AllOf
rewriter.replace_matched_op(irdl.AnyOfOp([]))
return

if len(new_args) == len(op.args):
return
rewriter.replace_matched_op(irdl.AllOfOp(new_args))


def is_dag_equivalent(
val1: SSAValue, val2: SSAValue, mappings: dict[SSAValue, SSAValue] | None = None
):
if mappings is None:
mappings = {}
if val1 in mappings:
return mappings[val1] == val2

assert isinstance(val1.owner, Operation)
assert isinstance(val2.owner, Operation)
op1 = val1.owner
op2 = val2.owner

if op1 == op2:
return True

if op1.attributes != op2.attributes:
return False

if len(op1.operands) != len(op2.operands):
return False
for operand1, operand2 in zip(op1.operands, op2.operands, strict=True):
if not is_dag_equivalent(operand1, operand2, mappings):
return False
mappings[val1] = val2
return True


class AllOfEquivAnyOfPattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: irdl.AllOfOp, rewriter: PatternRewriter, /):
for index, arg in enumerate(op.args):
if not isinstance(
arg.owner, irdl.AnyOfOp
) or not is_rooted_dag_with_one_use(arg):
continue
for index2, arg2 in list(enumerate(op.args))[index + 1 :]:
if not isinstance(
arg2.owner, irdl.AnyOfOp
) or not is_rooted_dag_with_one_use(arg2):
continue
if not is_dag_equivalent(arg, arg2):
continue

rewriter.replace_matched_op(
irdl.AllOfOp(
[
*op.args[:index2],
*op.args[index2 + 1 :],
]
)
)
return


Expand Down Expand Up @@ -286,6 +414,8 @@ def match_and_rewrite(self, op: irdl.AllOfOp, rewriter: PatternRewriter, /):
for index, arg in enumerate(op.args):
if not isinstance(arg.owner, irdl.AnyOfOp):
continue
if len(arg.owner.output.uses) != 1:
continue
new_all_ofs: list[irdl.AllOfOp] = []
for any_of_arg in arg.owner.args:
new_all_ofs.append(
Expand All @@ -305,6 +435,47 @@ def match_and_rewrite(self, op: irdl.AllOfOp, rewriter: PatternRewriter, /):
if bases == set():
rewriter.replace_matched_op(irdl.AnyOfOp([]))

is_value = None
for arg in op.args:
if not isinstance(arg.owner, irdl.IsOp):
continue
if is_value is None:
is_value = arg.owner.expected
continue
if is_value != arg.owner.expected:
rewriter.replace_matched_op(irdl.AnyOfOp([]))
return


class RemoveBaseFromAllOfInNestedAnyOfPattern(RewritePattern):
"""
On a pattern like this: "AllOf(AnyOf(y, z), x)", if the AnyOf is only used in
the AllOf, we can remove y or z if their bases are incompatible with x.
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: irdl.AllOfOp, rewriter: PatternRewriter, /):
for any_of_arg in op.args:
if not isinstance(any_of_arg.owner, irdl.AnyOfOp):
continue
if not is_rooted_dag_with_one_use(any_of_arg):
continue

for arg in op.args:
if arg == any_of_arg:
continue
bases = get_bases(arg)
if bases is None:
continue
new_any_of_args: list[SSAValue] = []
for arg_any_of in any_of_arg.owner.args:
if meet_bases(bases, get_bases(arg_any_of)) != set():
new_any_of_args.append(arg_any_of)
if len(new_any_of_args) == len(any_of_arg.owner.args):
continue
rewriter.replace_op(any_of_arg.owner, irdl.AnyOfOp(new_any_of_args))
return


class RemoveDuplicateMatchOpPattern(RewritePattern):
@op_type_rewrite_pattern
Expand All @@ -317,6 +488,9 @@ def match_and_rewrite(
if isinstance(match_op, irdl_extension.MatchOp):
match_ops.append(match_op)

if not match_ops:
return

# Detach the match operations
for match_op in match_ops:
match_op.detach()
Expand All @@ -333,6 +507,9 @@ def match_and_rewrite(
match_op2.erase()
dedup_match_ops[index2] = None

if None not in dedup_match_ops:
return

deduped_match_ops = [
match_op for match_op in dedup_match_ops if match_op is not None
]
Expand All @@ -346,16 +523,31 @@ def match_and_rewrite(
)


class CSEIsParametricPattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(
self, op: irdl.IsOp | irdl.ParametricOp, rewriter: PatternRewriter, /
):
current_op = op.next_op
while current_op is not None:
if (
current_op.name == op.name
and list(current_op.operands) == list(op.operands)
and current_op.attributes == op.attributes
):
rewriter.replace_op(current_op, [], [op.output])
return
current_op = current_op.next_op


class OptimizeIRDL(ModulePass):
def apply(self, ctx: MLContext, op: ModuleOp):

walker = PatternRewriteWalker(
GreedyRewritePatternApplier(
[
RemoveUnusedOpPattern(),
AllOfSinglePattern(),
AnyOfSinglePattern(),
AllOfIsPattern(),
AllOfAnyPattern(),
AllOfBaseBasePattern(),
AllOfParametricBasePattern(),
Expand All @@ -364,10 +556,15 @@ def apply(self, ctx: MLContext, op: ModuleOp):
RemoveEqOpPattern(),
AllOfNestedPattern(),
AnyOfNestedPattern(),
NestAllOfInAnyOfPattern(),
# NestAllOfInAnyOfPattern(),
AllOfEquivAnyOfPattern(),
AllOfIsPattern(),
RemoveAllOfContradictionPatterns(),
RemoveBaseFromAllOfInNestedAnyOfPattern(),
RemoveDuplicateMatchOpPattern(),
CSEIsParametricPattern(),
]
)
)

walker.rewrite_op(op)

0 comments on commit ee01c99

Please sign in to comment.