Skip to content

Commit

Permalink
Improve counter-example generation
Browse files Browse the repository at this point in the history
  • Loading branch information
math-fehr committed Jan 19, 2024
1 parent 1ee49fc commit f95abf3
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 25 deletions.
124 changes: 100 additions & 24 deletions xdsl_pdl/fuzzing/generate_pdl_matches.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

from dataclasses import dataclass, field
from random import Random
from typing import Generator, Iterable
from typing import Generator, Generic, Iterable, TypeVar

from xdsl.ir import Attribute, MLContext, Operation, Region, SSAValue
from xdsl.ir import Attribute, Block, MLContext, Operation, Region, SSAValue
from xdsl.dialects.builtin import (
IntegerAttr,
IntegerType,
Expand Down Expand Up @@ -130,25 +130,109 @@ def pdl_to_operations(
return region, synth_ops


T = TypeVar("T")


@dataclass
class UnionFind(Generic[T]):
"""Union-find data structure for representing equivalence classes."""

parents: dict[T, T] = field(default_factory=dict)

def find(self, value: T) -> T:
if value not in self.parents:
self.parents[value] = value
if self.parents[value] == value:
return value
while self.parents[value] != value:
parent = self.parents[value]
grand_parent = self.parents[parent]
(value, self.parents[value]) = (parent, grand_parent)
return value

def union(self, value1: T, value2: T) -> None:
self.parents[self.find(value1)] = self.find(value2)


def get_edges(ops: Iterable[Operation]) -> set[tuple[Operation, Operation]]:
"""Get all edges of the DAG formed by the given operations."""
edges = set[tuple[Operation, Operation]]()
for op in ops:
for operand in op.operands:
if isinstance(operand.owner, Operation) and operand.owner in ops:
edges.add((operand.owner, op))
return edges


def get_roots(ops: list[Operation]) -> list[Operation]:
"""Get all operations that do not depend on any other operations on the list."""
roots = set(ops)
for _, to_op in get_edges(ops):
roots.discard(to_op)
return list(roots)


def get_connected_components(ops: list[Operation]) -> list[list[Operation]]:
"""Get all connected components of the DAG formed by the given operations."""
uf = UnionFind[Operation]()
for from_op, to_op in get_edges(ops):
uf.union(from_op, to_op)
components: dict[Operation, list[Operation]] = {}
for op in ops:
root = uf.find(op)
components.setdefault(root, []).append(op)
return list(components.values())


def get_all_interleavings(
ops: list[Operation],
) -> Generator[list[Operation], None, None]:
current_block: Block,
region: Region,
ctx: MLContext,
) -> Generator[Region, None, None]:
"""
Generate all possible interleavings of the given operations,
Generate all possible interleaving of the given operations,
while respecting dominance order.
"""
if ops == []:
yield []
if not ops:
yield region
return
for i in range(len(ops)):
dominating_ops = {
operand.owner
for operand in ops[i].operands
if isinstance(operand.owner, Operation)
}
if dominating_ops.isdisjoint(ops):
for interleaving in get_all_interleavings(ops[:i] + ops[i + 1 :]):
yield [ops[i]] + interleaving

components = get_connected_components(ops)

# If we have multiple connected components, we can split them,
# and recurse on each component.
if len(components) != 1:
component1 = components[0]
component2 = [op for component in components[1:] for op in component]
block1 = Block()
block2 = Block()
region.add_block(block1)
region.add_block(block2)
terminator = ctx.get_op("test.terminator").create(successors=[block1, block2])
current_block.add_op(terminator)

for _ in get_all_interleavings(component1, block1, region, ctx):
yield from get_all_interleavings(component2, block2, region, ctx)

# Rollback the changes
current_block.erase_op(terminator)
region.erase_block(block1)
region.erase_block(block2)
return

# If we have a single connected component, we add a root to the current block
roots = get_roots(ops)
assert roots
for root in roots:
current_block.add_op(root)
use_op = ctx.get_op("test.use_op").create(operands=root.results)
current_block.add_op(use_op)
for _ in get_all_interleavings(ops[1:], current_block, region, ctx):
yield region
current_block.erase_op(use_op)
current_block.erase_op(root)

return


Expand All @@ -163,12 +247,4 @@ def get_all_matches(
assert len(region.blocks[0].ops) == 0

region, ops = pdl_to_operations(pattern, region, ctx, randgen)
for interleaving in get_all_interleavings(ops):
for op in interleaving:
region.blocks[0].add_op(op)
region.blocks[0].add_op(
ctx.get_op("test.use_op").create(operands=op.results)
)
yield region
for op in region.blocks[0].ops:
region.blocks[0].detach_op(op)
yield from get_all_interleavings(ops, region.blocks[0], region, ctx)
1 change: 0 additions & 1 deletion xdsl_pdl/tools/generate_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
PatternOp,
)
from xdsl_pdl.analysis.pdl_analysis import (
PDLAnalysisAborted,
PDLAnalysisException,
pdl_analysis_pass,
)
Expand Down

0 comments on commit f95abf3

Please sign in to comment.