diff --git a/xdsl_pdl/fuzzing/generate_pdl_matches.py b/xdsl_pdl/fuzzing/generate_pdl_matches.py index 9c11601..0332067 100644 --- a/xdsl_pdl/fuzzing/generate_pdl_matches.py +++ b/xdsl_pdl/fuzzing/generate_pdl_matches.py @@ -2,9 +2,9 @@ from dataclasses import dataclass, field from random import Random -from typing import Generator, Iterable +from typing import Generator, Generic, Iterable, TypeVar -from xdsl.ir import Attribute, MLContext, Operation, Region, SSAValue +from xdsl.ir import Attribute, Block, MLContext, Operation, Region, SSAValue from xdsl.dialects.builtin import ( IntegerAttr, IntegerType, @@ -130,25 +130,109 @@ def pdl_to_operations( return region, synth_ops +T = TypeVar("T") + + +@dataclass +class UnionFind(Generic[T]): + """Union-find data structure for representing equivalence classes.""" + + parents: dict[T, T] = field(default_factory=dict) + + def find(self, value: T) -> T: + if value not in self.parents: + self.parents[value] = value + if self.parents[value] == value: + return value + while self.parents[value] != value: + parent = self.parents[value] + grand_parent = self.parents[parent] + (value, self.parents[value]) = (parent, grand_parent) + return value + + def union(self, value1: T, value2: T) -> None: + self.parents[self.find(value1)] = self.find(value2) + + +def get_edges(ops: Iterable[Operation]) -> set[tuple[Operation, Operation]]: + """Get all edges of the DAG formed by the given operations.""" + edges = set[tuple[Operation, Operation]]() + for op in ops: + for operand in op.operands: + if isinstance(operand.owner, Operation) and operand.owner in ops: + edges.add((operand.owner, op)) + return edges + + +def get_roots(ops: list[Operation]) -> list[Operation]: + """Get all operations that do not depend on any other operations on the list.""" + roots = set(ops) + for _, to_op in get_edges(ops): + roots.discard(to_op) + return list(roots) + + +def get_connected_components(ops: list[Operation]) -> list[list[Operation]]: + """Get all connected components of the DAG formed by the given operations.""" + uf = UnionFind[Operation]() + for from_op, to_op in get_edges(ops): + uf.union(from_op, to_op) + components: dict[Operation, list[Operation]] = {} + for op in ops: + root = uf.find(op) + components.setdefault(root, []).append(op) + return list(components.values()) + + def get_all_interleavings( ops: list[Operation], -) -> Generator[list[Operation], None, None]: + current_block: Block, + region: Region, + ctx: MLContext, +) -> Generator[Region, None, None]: """ - Generate all possible interleavings of the given operations, + Generate all possible interleaving of the given operations, while respecting dominance order. """ - if ops == []: - yield [] + if not ops: + yield region return - for i in range(len(ops)): - dominating_ops = { - operand.owner - for operand in ops[i].operands - if isinstance(operand.owner, Operation) - } - if dominating_ops.isdisjoint(ops): - for interleaving in get_all_interleavings(ops[:i] + ops[i + 1 :]): - yield [ops[i]] + interleaving + + components = get_connected_components(ops) + + # If we have multiple connected components, we can split them, + # and recurse on each component. + if len(components) != 1: + component1 = components[0] + component2 = [op for component in components[1:] for op in component] + block1 = Block() + block2 = Block() + region.add_block(block1) + region.add_block(block2) + terminator = ctx.get_op("test.terminator").create(successors=[block1, block2]) + current_block.add_op(terminator) + + for _ in get_all_interleavings(component1, block1, region, ctx): + yield from get_all_interleavings(component2, block2, region, ctx) + + # Rollback the changes + current_block.erase_op(terminator) + region.erase_block(block1) + region.erase_block(block2) + return + + # If we have a single connected component, we add a root to the current block + roots = get_roots(ops) + assert roots + for root in roots: + current_block.add_op(root) + use_op = ctx.get_op("test.use_op").create(operands=root.results) + current_block.add_op(use_op) + for _ in get_all_interleavings(ops[1:], current_block, region, ctx): + yield region + current_block.erase_op(use_op) + current_block.erase_op(root) + return @@ -163,12 +247,4 @@ def get_all_matches( assert len(region.blocks[0].ops) == 0 region, ops = pdl_to_operations(pattern, region, ctx, randgen) - for interleaving in get_all_interleavings(ops): - for op in interleaving: - region.blocks[0].add_op(op) - region.blocks[0].add_op( - ctx.get_op("test.use_op").create(operands=op.results) - ) - yield region - for op in region.blocks[0].ops: - region.blocks[0].detach_op(op) + yield from get_all_interleavings(ops, region.blocks[0], region, ctx) diff --git a/xdsl_pdl/tools/generate_table.py b/xdsl_pdl/tools/generate_table.py index f41f366..011d315 100644 --- a/xdsl_pdl/tools/generate_table.py +++ b/xdsl_pdl/tools/generate_table.py @@ -17,7 +17,6 @@ PatternOp, ) from xdsl_pdl.analysis.pdl_analysis import ( - PDLAnalysisAborted, PDLAnalysisException, pdl_analysis_pass, )