Skip to content

Commit

Permalink
Add a lot of optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
math-fehr committed May 21, 2024
1 parent eddc5ea commit 0089d70
Show file tree
Hide file tree
Showing 5 changed files with 293 additions and 62 deletions.
33 changes: 14 additions & 19 deletions tests/filecheck/pdl_to_irdl_check/add_commute.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -47,22 +47,17 @@ pdl.pattern @AddCommute : benefit(0) {
}
}

// CHECK: irdl_ext.check_subset {
// CHECK-NEXT: %match_t = irdl.any
// CHECK-NEXT: %match_op_index = irdl.base @builtin::@index {"base_ref" = @builtin::@index}
// CHECK-NEXT: %match_op_integer = irdl.base @builtin::@integer_type {"base_ref" = @builtin::@integer_type}
// CHECK-NEXT: %match_op_t = irdl.any_of(%match_op_index, %match_op_integer)
// CHECK-NEXT: irdl_ext.eq %match_op_t, %match_t
// CHECK-NEXT: irdl_ext.eq %match_op_t, %match_t
// CHECK-NEXT: irdl_ext.eq %match_op_t, %match_t
// CHECK-NEXT: irdl_ext.yield %match_t, %match_t, %match_t
// CHECK-NEXT: } of {
// CHECK-NEXT: %rewrite_t = irdl.any
// CHECK-NEXT: %rewrite_new_op_index = irdl.base @builtin::@index {"base_ref" = @builtin::@index}
// CHECK-NEXT: %rewrite_new_op_integer = irdl.base @builtin::@integer_type {"base_ref" = @builtin::@integer_type}
// CHECK-NEXT: %rewrite_new_op_t = irdl.any_of(%rewrite_new_op_index, %rewrite_new_op_integer)
// CHECK-NEXT: irdl_ext.eq %rewrite_new_op_t, %rewrite_t
// CHECK-NEXT: irdl_ext.eq %rewrite_new_op_t, %rewrite_t
// CHECK-NEXT: irdl_ext.eq %rewrite_new_op_t, %rewrite_t
// CHECK-NEXT: irdl_ext.yield %rewrite_t, %rewrite_t, %rewrite_t
// CHECK-NEXT: }
// CHECK: irdl_ext.check_subset {
// CHECK-NEXT: %match_op_index = irdl.parametric @builtin::@index<>
// CHECK-NEXT: %0 = irdl.base "#int" {"base_name" = "#int"}
// CHECK-NEXT: %match_op_integer = irdl.parametric @builtin::@integer_type<%0>
// CHECK-NEXT: %match_op_t = irdl.any_of(%match_op_index, %match_op_integer)
// CHECK-NEXT: irdl_ext.yield %match_op_t, %match_op_t, %match_op_t
// CHECK-NEXT: } of {
// CHECK-NEXT: %rewrite_new_op_index = irdl.parametric @builtin::@index<>
// CHECK-NEXT: %1 = irdl.base "#int" {"base_name" = "#int"}
// CHECK-NEXT: %rewrite_new_op_integer = irdl.parametric @builtin::@integer_type<%1>
// CHECK-NEXT: %rewrite_new_op_t = irdl.any_of(%rewrite_new_op_index, %rewrite_new_op_integer)
// CHECK-NEXT: irdl_ext.yield %rewrite_new_op_t, %rewrite_new_op_t, %rewrite_new_op_t
// CHECK-NEXT: }

68 changes: 25 additions & 43 deletions tests/filecheck/pdl_to_irdl_check/mulsi_extended_bug.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -80,54 +80,36 @@ pdl.pattern @MulSIExtendedRHSOne : benefit(0) {
}

// CHECK: irdl_ext.check_subset {
// CHECK-NEXT: %match_t = irdl.any
// CHECK-NEXT: %match_one = irdl.any
// CHECK-NEXT: %match_one_1 = irdl.parametric @builtin::@integer_attr<%match_one, %match_t>
// CHECK-NEXT: %match_one_op_index = irdl.base @builtin::@index {"base_ref" = @builtin::@index}
// CHECK-NEXT: %match_one_op_integer = irdl.base @builtin::@integer_type {"base_ref" = @builtin::@integer_type}
// CHECK-NEXT: %match_one_op_index = irdl.parametric @builtin::@index<>
// CHECK-NEXT: %0 = irdl.base "#int" {"base_name" = "#int"}
// CHECK-NEXT: %match_one_op_integer = irdl.parametric @builtin::@integer_type<%0>
// CHECK-NEXT: %match_one_op_t = irdl.any_of(%match_one_op_index, %match_one_op_integer)
// CHECK-NEXT: irdl_ext.eq %match_one_op_t, %match_t
// CHECK-NEXT: %match_root_index = irdl.base @builtin::@index {"base_ref" = @builtin::@index}
// CHECK-NEXT: %match_root_integer = irdl.base @builtin::@integer_type {"base_ref" = @builtin::@integer_type}
// CHECK-NEXT: %match_root_index = irdl.parametric @builtin::@index<>
// CHECK-NEXT: %1 = irdl.base "#int" {"base_name" = "#int"}
// CHECK-NEXT: %match_root_integer = irdl.parametric @builtin::@integer_type<%1>
// CHECK-NEXT: %match_root_t = irdl.any_of(%match_root_index, %match_root_integer)
// CHECK-NEXT: irdl_ext.eq %match_root_t, %match_t
// CHECK-NEXT: irdl_ext.eq %match_root_t, %match_t
// CHECK-NEXT: irdl_ext.eq %match_root_t, %match_t
// CHECK-NEXT: irdl_ext.eq %match_root_t, %match_t
// CHECK-NEXT: irdl_ext.yield %match_t, %match_t, %match_t, %match_t
// CHECK-NEXT: %2 = irdl.all_of(%match_root_t, %match_one_op_t)
// CHECK-NEXT: irdl_ext.yield %2, %2, %2, %2
// CHECK-NEXT: } of {
// CHECK-NEXT: %rewrite_t = irdl.any
// CHECK-NEXT: %rewrite_one = irdl.any
// CHECK-NEXT: %rewrite_one_1 = irdl.parametric @builtin::@integer_attr<%rewrite_one, %rewrite_t>
// CHECK-NEXT: %rewrite_one_op_index = irdl.base @builtin::@index {"base_ref" = @builtin::@index}
// CHECK-NEXT: %rewrite_one_op_integer = irdl.base @builtin::@integer_type {"base_ref" = @builtin::@integer_type}
// CHECK-NEXT: %rewrite_one_op_index = irdl.parametric @builtin::@index<>
// CHECK-NEXT: %3 = irdl.base "#int" {"base_name" = "#int"}
// CHECK-NEXT: %rewrite_one_op_integer = irdl.parametric @builtin::@integer_type<%3>
// CHECK-NEXT: %rewrite_one_op_t = irdl.any_of(%rewrite_one_op_index, %rewrite_one_op_integer)
// CHECK-NEXT: irdl_ext.eq %rewrite_one_op_t, %rewrite_t
// CHECK-NEXT: %rewrite_zero = irdl.any
// CHECK-NEXT: %rewrite_zero_1 = irdl.parametric @builtin::@integer_attr<%rewrite_zero, %rewrite_t>
// CHECK-NEXT: %rewrite_zero_op_index = irdl.base @builtin::@index {"base_ref" = @builtin::@index}
// CHECK-NEXT: %rewrite_zero_op_integer = irdl.base @builtin::@integer_type {"base_ref" = @builtin::@integer_type}
// CHECK-NEXT: %rewrite_zero_op_index = irdl.parametric @builtin::@index<>
// CHECK-NEXT: %4 = irdl.base "#int" {"base_name" = "#int"}
// CHECK-NEXT: %rewrite_zero_op_integer = irdl.parametric @builtin::@integer_type<%4>
// CHECK-NEXT: %rewrite_zero_op_t = irdl.any_of(%rewrite_zero_op_index, %rewrite_zero_op_integer)
// CHECK-NEXT: irdl_ext.eq %rewrite_zero_op_t, %rewrite_t
// CHECK-NEXT: %rewrite_two = irdl.is 2 : i64
// CHECK-NEXT: %rewrite_i1 = irdl.is i1
// CHECK-NEXT: %rewrite_cmpi_op_index = irdl.base @builtin::@index {"base_ref" = @builtin::@index}
// CHECK-NEXT: %rewrite_cmpi_op_integer = irdl.base @builtin::@integer_type {"base_ref" = @builtin::@integer_type}
// CHECK-NEXT: %rewrite_cmpi_op_index = irdl.parametric @builtin::@index<>
// CHECK-NEXT: %5 = irdl.base "#int" {"base_name" = "#int"}
// CHECK-NEXT: %rewrite_cmpi_op_integer = irdl.parametric @builtin::@integer_type<%5>
// CHECK-NEXT: %rewrite_cmpi_op_t = irdl.any_of(%rewrite_cmpi_op_index, %rewrite_cmpi_op_integer)
// CHECK-NEXT: %rewrite_cmpi_op_i1 = irdl.is i1
// CHECK-NEXT: irdl_ext.eq %rewrite_cmpi_op_t, %rewrite_t
// CHECK-NEXT: irdl_ext.eq %rewrite_cmpi_op_t, %rewrite_t
// CHECK-NEXT: irdl_ext.eq %rewrite_cmpi_op_i1, %rewrite_i1
// CHECK-NEXT: %rewrite_extsi_op_integer1 = irdl.base @builtin::@integer_type {"base_ref" = @builtin::@integer_type}
// CHECK-NEXT: %rewrite_extsi_op_integer2 = irdl.base @builtin::@integer_type {"base_ref" = @builtin::@integer_type}
// CHECK-NEXT: irdl_ext.eq %rewrite_extsi_op_integer1, %rewrite_i1
// CHECK-NEXT: irdl_ext.eq %rewrite_extsi_op_integer2, %rewrite_t
// CHECK-NEXT: %rewrite_root_index = irdl.base @builtin::@index {"base_ref" = @builtin::@index}
// CHECK-NEXT: %rewrite_root_integer = irdl.base @builtin::@integer_type {"base_ref" = @builtin::@integer_type}
// CHECK-NEXT: %6 = irdl.base "#int" {"base_name" = "#int"}
// CHECK-NEXT: %rewrite_extsi_op_integer2 = irdl.parametric @builtin::@integer_type<%6>
// CHECK-NEXT: %rewrite_root_index = irdl.parametric @builtin::@index<>
// CHECK-NEXT: %7 = irdl.base "#int" {"base_name" = "#int"}
// CHECK-NEXT: %rewrite_root_integer = irdl.parametric @builtin::@integer_type<%7>
// CHECK-NEXT: %rewrite_root_t = irdl.any_of(%rewrite_root_index, %rewrite_root_integer)
// CHECK-NEXT: irdl_ext.eq %rewrite_root_t, %rewrite_t
// CHECK-NEXT: irdl_ext.eq %rewrite_root_t, %rewrite_t
// CHECK-NEXT: irdl_ext.eq %rewrite_root_t, %rewrite_t
// CHECK-NEXT: irdl_ext.eq %rewrite_root_t, %rewrite_t
// CHECK-NEXT: irdl_ext.yield %rewrite_t, %rewrite_t, %rewrite_t, %rewrite_t
// CHECK-NEXT: %8 = irdl.all_of(%rewrite_root_t, %rewrite_extsi_op_integer2, %rewrite_cmpi_op_t, %rewrite_zero_op_t, %rewrite_one_op_t)
// CHECK-NEXT: irdl_ext.yield %8, %8, %8, %8
// CHECK-NEXT: }

246 changes: 246 additions & 0 deletions xdsl_pdl/passes/optimize_irdl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
from xdsl.parser import SymbolRefAttr
from xdsl.passes import ModulePass

from xdsl.ir import MLContext, Operation, SSAValue, Use
from xdsl.dialects import irdl
from xdsl.rewriter import InsertPoint
from xdsl_pdl.dialects import irdl_extension
from xdsl.dialects.builtin import ModuleOp
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
PatternRewriteWalker,
PatternRewriter,
RewritePattern,
op_type_rewrite_pattern,
)


def all_of_has_base_of_type(op: irdl.AllOfOp, base_type: SymbolRefAttr) -> int | None:
"""Return the index of the first argument of `op` that is a `BaseOp` with the given `base_type`."""
for index, arg in enumerate(op.args):
if not isinstance(arg.owner, irdl.BaseOp):
continue
if arg.owner.base_ref == base_type:
return index
return None


class RemoveUnusedOpPattern(RewritePattern):
def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter, /):
if op.dialect_name() == "irdl" and op.results:
for result in op.results:
if result.uses:
return
rewriter.erase_op(op)


class AllOfSinglePattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: irdl.AllOfOp, rewriter: PatternRewriter, /):
if len(op.args) == 1:
rewriter.replace_matched_op([], [op.args[0]])
return
if len(op.args) == 0:
rewriter.replace_matched_op(irdl.AnyOp())
return


class AllOfIsPattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: irdl.AllOfOp, rewriter: PatternRewriter, /):
for is_arg in op.args:
if isinstance(is_arg.owner, irdl.IsOp):
new_args = [
arg for arg in op.args if arg == is_arg or len(arg.uses) != 1
]
rewriter.replace_matched_op(irdl.AllOfOp(new_args))
return


class AllOfAnyPattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: irdl.AllOfOp, rewriter: PatternRewriter, /):
for index, arg in enumerate(op.args):
if isinstance(arg.owner, irdl.AnyOp) and len(arg.uses) == 1:
rewriter.replace_matched_op(
irdl.AllOfOp(op.args[:index] + op.args[index + 1 :])
)
return


class AllOfBaseBasePattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: irdl.AllOfOp, rewriter: PatternRewriter, /):
for index, arg in enumerate(op.args):
if not isinstance(arg.owner, irdl.BaseOp):
continue
if len(arg.uses) != 1:
continue
for arg2 in op.args[:index] + op.args[index + 1 :]:
if not isinstance(arg2.owner, irdl.BaseOp):
continue
if (
arg.owner.base_ref == arg2.owner.base_ref
and arg.owner.base_name == arg2.owner.base_name
):
rewriter.replace_matched_op(
irdl.AllOfOp(op.args[:index] + op.args[index + 1 :])
)
return


class AllOfParametricBasePattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: irdl.AllOfOp, rewriter: PatternRewriter, /):
for arg in op.args:
if not isinstance(arg.owner, irdl.ParametricOp):
continue
for index2, arg2 in enumerate(op.args):
if not isinstance(arg2.owner, irdl.BaseOp):
continue
if len(arg2.uses) != 1:
continue
if arg.owner.base_type == arg2.owner.base_ref:
rewriter.replace_matched_op(
irdl.AllOfOp(op.args[:index2] + op.args[index2 + 1 :])
)
return


class AllOfParametricParametricPattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: irdl.AllOfOp, rewriter: PatternRewriter, /):
for index, arg in enumerate(op.args):
if not isinstance(arg.owner, irdl.ParametricOp):
continue
for index2, arg2 in list(enumerate(op.args))[index + 1 :]:
if not isinstance(arg2.owner, irdl.ParametricOp):
continue
if arg.owner.base_type == arg2.owner.base_type:
args: list[SSAValue] = []
for param1, param2 in zip(arg.owner.args, arg2.owner.args):
param_all_of = irdl.AllOfOp([param1, param2])
rewriter.insert_op_before_matched_op(param_all_of)
args.append(param_all_of.output)
new_parametric = irdl.ParametricOp(arg.owner.base_type, args)
rewriter.replace_matched_op(
[
new_parametric,
irdl.AllOfOp(
[
*op.args[:index],
*op.args[index + 1 : index2],
*op.args[index2 + 1 :],
new_parametric.output,
]
),
]
)
return


class AllOfIdenticalPattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: irdl.AllOfOp, rewriter: PatternRewriter, /):
for index1, arg1 in enumerate(op.args):
for arg2 in op.args[index1 + 1 :]:
if arg1 == arg2:
rewriter.replace_matched_op(
irdl.AllOfOp(
[
arg
for index, arg in enumerate(op.args)
if index != index1
]
)
)
return


class AllOfNestedPattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: irdl.AllOfOp, rewriter: PatternRewriter, /):
for index, arg in enumerate(op.args):
if not isinstance(arg.owner, irdl.AllOfOp):
continue
new_args = [*op.args[:index], *arg.owner.args, *op.args[index + 1 :]]
rewriter.replace_matched_op(irdl.AllOfOp(new_args))
return


class RemoveEqOpPattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: irdl_extension.EqOp, rewriter: PatternRewriter, /):
if len(op.args) != 2:
return
lhs = op.args[0]
rhs = op.args[1]
assert isinstance(lhs.owner, Operation)
assert isinstance(rhs.owner, Operation)
assert lhs.owner.parent_block() == rhs.owner.parent_block()
block = lhs.owner.parent_block()
assert block is not None

# Get the operation indices of the operands
block_ops = list(block.ops)
index_lhs = block_ops.index(lhs.owner)
index_rhs = block_ops.index(rhs.owner)

# Get the earliest operation using either operand
earliest_use_index = None
for index, block_op in enumerate(block_ops):
if lhs in block_op.operands or rhs in block_op.operands:
earliest_use_index = index
break
else:
assert False

# Merging both operations is harder in that case, so we don't do it for now
if earliest_use_index < max(index_lhs, index_rhs):
return

# Get the latest operation
if index_lhs > index_rhs:
insert_point = InsertPoint.after(lhs.owner)
else:
insert_point = InsertPoint.after(rhs.owner)

# Create a new `AllOfOp` with the operands of the `EqOp`
all_of_op = irdl.AllOfOp([lhs, rhs])
rewriter.insert_op_at_location(all_of_op, insert_point)

# Erase the `EqOp`
rewriter.erase_matched_op()

# Replace uses of both operands with the `AllOfOp`
# Do not replace the uses of the `AllOfOp` itself
for use in [*lhs.uses, *rhs.uses]:
if use.operation is all_of_op:
continue
operands = use.operation.operands
use.operation.operands = [
*operands[: use.index],
all_of_op.output,
*operands[use.index + 1 :],
]


class OptimizeIRDL(ModulePass):
def apply(self, ctx: MLContext, op: ModuleOp):
walker = PatternRewriteWalker(
GreedyRewritePatternApplier(
[
RemoveUnusedOpPattern(),
AllOfSinglePattern(),
AllOfIsPattern(),
AllOfAnyPattern(),
AllOfBaseBasePattern(),
AllOfParametricBasePattern(),
AllOfParametricParametricPattern(),
AllOfIdenticalPattern(),
RemoveEqOpPattern(),
AllOfNestedPattern(),
]
)
)
walker.rewrite_op(op)
6 changes: 6 additions & 0 deletions xdsl_pdl/tools/analyze_irdl_invariants.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from xdsl.dialects.irdl import IRDL
from xdsl.dialects.pdl import PDL
from xdsl_pdl.passes.optimize_irdl import OptimizeIRDL
from xdsl_pdl.passes.pdl_to_irdl import PDLToIRDLPass


Expand All @@ -29,6 +30,7 @@ def main():
)
arg_parser.add_argument("input_file", type=str, help="path to input file")
arg_parser.add_argument("irdl_file", type=str, help="path to IRDL file")
arg_parser.add_argument("--debug", action="store_true", help="enable debug mode")
args = arg_parser.parse_args()

# Setup the xDSL context
Expand All @@ -52,6 +54,10 @@ def main():
)

PDLToIRDLPass().apply(ctx, program)
OptimizeIRDL().apply(ctx, program)
if args.debug:
print("Converted IRDL program:")
print(program)
solver = z3.Solver()
check_subset_to_z3(program, solver)

Expand Down
Loading

0 comments on commit 0089d70

Please sign in to comment.