diff --git a/pyproject.toml b/pyproject.toml index 777751b..9694e10 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ generate-pdl-rewrite = "xdsl_pdl.tools.generate_pdl_rewrite:main" generate-pdl-matches = "xdsl_pdl.tools.generate_pdl_matches:main" analyze-pdl-rewrite = "xdsl_pdl.tools.analyze_pdl_rewrite:main" generate-table = "xdsl_pdl.tools.generate_table:main" -check-irdl-subset = "xdsl_pdl.tools.check_irdl_subset:main" +test-check-irdl-subset = "xdsl_pdl.tools.test_check_irdl_subset:main" test-pdl-to-irdl-check = "xdsl_pdl.tools.test_pdl_to_irdl_check:main" [build-system] diff --git a/tests/filecheck/check_irdl_subset/eq_to_non_eq.mlir b/tests/filecheck/check_irdl_subset/eq_to_non_eq.mlir index 90d1e68..0082841 100644 --- a/tests/filecheck/check_irdl_subset/eq_to_non_eq.mlir +++ b/tests/filecheck/check_irdl_subset/eq_to_non_eq.mlir @@ -1,4 +1,4 @@ -// RUN: check-irdl-subset %s | filecheck %s +// RUN: test-check-irdl-subset %s | filecheck %s // Check that int | vec is not a subset of int diff --git a/tests/filecheck/check_irdl_subset/int_or_vec_to_int.mlir b/tests/filecheck/check_irdl_subset/int_or_vec_to_int.mlir index 1b43e2d..9f38b35 100644 --- a/tests/filecheck/check_irdl_subset/int_or_vec_to_int.mlir +++ b/tests/filecheck/check_irdl_subset/int_or_vec_to_int.mlir @@ -1,4 +1,4 @@ -// RUN: check-irdl-subset %s | filecheck %s +// RUN: test-check-irdl-subset %s | filecheck %s // Check that int | vec is not a subset of int diff --git a/tests/filecheck/check_irdl_subset/int_to_int_or_vec.mlir b/tests/filecheck/check_irdl_subset/int_to_int_or_vec.mlir index 1c4e391..90ec723 100644 --- a/tests/filecheck/check_irdl_subset/int_to_int_or_vec.mlir +++ b/tests/filecheck/check_irdl_subset/int_to_int_or_vec.mlir @@ -1,4 +1,4 @@ -// RUN: check-irdl-subset %s | filecheck %s +// RUN: test-check-irdl-subset %s | filecheck %s // Check that int is a subset of int | vec diff --git a/tests/filecheck/check_irdl_subset/non_eq_to_eq.mlir b/tests/filecheck/check_irdl_subset/non_eq_to_eq.mlir index f34a8cd..a28bbef 100644 --- a/tests/filecheck/check_irdl_subset/non_eq_to_eq.mlir +++ b/tests/filecheck/check_irdl_subset/non_eq_to_eq.mlir @@ -1,4 +1,4 @@ -// RUN: check-irdl-subset %s | filecheck %s +// RUN: test-check-irdl-subset %s | filecheck %s // Check that int | vec is not a subset of int diff --git a/xdsl_pdl/tools/check_irdl_subset.py b/xdsl_pdl/analysis/check_subset_to_z3.py similarity index 82% rename from xdsl_pdl/tools/check_irdl_subset.py rename to xdsl_pdl/analysis/check_subset_to_z3.py index e48ad15..692da77 100644 --- a/xdsl_pdl/tools/check_irdl_subset.py +++ b/xdsl_pdl/analysis/check_subset_to_z3.py @@ -1,9 +1,8 @@ """ -Check if a group of IRDL variables represent a subset of other IRDL variables. +Create an SMT query to check if a group of IRDL +variables represent a subset of other IRDL variables. """ -import argparse -import sys from typing import Any, Callable from xdsl.traits import SymbolTable from xdsl.utils.hints import isa @@ -11,28 +10,23 @@ from xdsl.dialects.builtin import ( AnyIntegerAttr, - Builtin, IntAttr, - IntegerAttr, IntegerType, ) -from xdsl.dialects.func import Func from xdsl.dialects.irdl import ( - IRDL, AnyOfOp, AnyOp, AttributeOp, BaseOp, IsOp, - ParametersOp, ParametricOp, TypeOp, DialectOp, ) -from xdsl.ir import Attribute, MLContext, Operation, SSAValue -from xdsl.parser import IndexType, ModuleOp, Parser -from xdsl_pdl.dialects.irdl_extension import CheckSubsetOp, EqOp, IRDLExtension, YieldOp +from xdsl.ir import Attribute, Operation, SSAValue +from xdsl.parser import IndexType, ModuleOp +from xdsl_pdl.dialects.irdl_extension import CheckSubsetOp, EqOp, YieldOp def add_attribute_constructors_from_irdl( @@ -134,36 +128,7 @@ def get_constraint_as_z3( assert False, f"Unsupported op {op.name}" -def main(): - arg_parser = argparse.ArgumentParser( - prog="check-irdl-subset", - description="Check if a group of IRDL variables represent a " - "subset of other IRDL variables.", - ) - arg_parser.add_argument( - "input_file", type=str, nargs="?", help="path to input file" - ) - args = arg_parser.parse_args() - - # Setup the xDSL context - ctx = MLContext() - ctx.load_dialect(Builtin) - ctx.load_dialect(Func) - ctx.load_dialect(IRDL) - ctx.load_dialect(IRDLExtension) - - # Grab the input program from the command line or a file - if args.input_file is None: - f = sys.stdin - else: - f = open(args.input_file) - - # - with f: - program = Parser(ctx, f.read()).parse_module() - - solver = z3.Solver() - +def check_subset_to_z3(program: ModuleOp, solver: z3.Solver): assert isinstance(main := program.ops.last, CheckSubsetOp) # The Attribute datatype is an union of all possible attributes found in the @@ -227,15 +192,3 @@ def add_constraint(constraint: Any): for lhs_arg, rhs_arg in zip(lhs_yield.args, rhs_yield.args): constraints.append(values_to_z3[lhs_arg] == values_to_z3[rhs_arg]) solver.add(z3.Not(z3.Exists(constants, z3.And(constraints)))) - - print("SMT program:") - print(solver) - if solver.check() == z3.sat: - print("sat: lhs is not a subset of rhs") - print("model: ", solver.model()) - else: - print("unsat: lhs is a subset of rhs") - - -if "__main__" == __name__: - main() diff --git a/xdsl_pdl/tools/test_check_irdl_subset.py b/xdsl_pdl/tools/test_check_irdl_subset.py new file mode 100644 index 0000000..3238844 --- /dev/null +++ b/xdsl_pdl/tools/test_check_irdl_subset.py @@ -0,0 +1,62 @@ +""" +Check if a group of IRDL variables represent a subset of other IRDL variables. +""" + +import argparse +import sys +import z3 + +from xdsl.dialects.builtin import ( + Builtin, +) +from xdsl.dialects.func import Func +from xdsl.dialects.irdl import IRDL + +from xdsl.ir import MLContext +from xdsl.parser import Parser +from xdsl_pdl.analysis.check_subset_to_z3 import check_subset_to_z3 +from xdsl_pdl.dialects.irdl_extension import IRDLExtension + + +def main(): + arg_parser = argparse.ArgumentParser( + prog="check-irdl-subset", + description="Check if a group of IRDL variables represent a " + "subset of other IRDL variables.", + ) + arg_parser.add_argument( + "input_file", type=str, nargs="?", help="path to input file" + ) + args = arg_parser.parse_args() + + # Setup the xDSL context + ctx = MLContext() + ctx.load_dialect(Builtin) + ctx.load_dialect(Func) + ctx.load_dialect(IRDL) + ctx.load_dialect(IRDLExtension) + + # Grab the input program from the command line or a file + if args.input_file is None: + f = sys.stdin + else: + f = open(args.input_file) + + # + with f: + program = Parser(ctx, f.read()).parse_module() + + solver = z3.Solver() + check_subset_to_z3(program, solver) + + print("SMT program:") + print(solver) + if solver.check() == z3.sat: + print("sat: lhs is not a subset of rhs") + print("model: ", solver.model()) + else: + print("unsat: lhs is a subset of rhs") + + +if "__main__" == __name__: + main()