diff --git a/numba_rvsdg/__init__.py b/numba_rvsdg/__init__.py index e69de29..3c59890 100644 --- a/numba_rvsdg/__init__.py +++ b/numba_rvsdg/__init__.py @@ -0,0 +1 @@ +from numba_rvsdg.core.datastructures.ast_transforms import AST2SCFG # noqa diff --git a/numba_rvsdg/core/datastructures/ast_transforms.py b/numba_rvsdg/core/datastructures/ast_transforms.py new file mode 100644 index 0000000..30a93ad --- /dev/null +++ b/numba_rvsdg/core/datastructures/ast_transforms.py @@ -0,0 +1,665 @@ +import ast +import inspect +from typing import Callable, Any, MutableMapping +import textwrap +from dataclasses import dataclass + +from numba_rvsdg.core.datastructures.scfg import SCFG +from numba_rvsdg.core.datastructures.basic_block import PythonASTBlock + + +class WritableASTBlock: + """A basic block containing Python AST that can be written to. + + The recursive AST -> CFG algorithm requires a basic block that can be + written to. + + """ + + name: str + instructions: list[ast.AST] + jump_targets: list[str] + + def __init__( + self, + name: str, + instructions: list[ast.AST] | None = None, + jump_targets: list[str] | None = None, + ) -> None: + self.name = name + self.instructions: list[ast.AST] = ( + [] if instructions is None else instructions + ) + self.jump_targets: list[str] = ( + [] if jump_targets is None else jump_targets + ) + + def set_jump_targets(self, *indices: int) -> None: + """Set jump targets for the block.""" + self.jump_targets = [str(a) for a in indices] + + def is_instruction(self, instruction: type[ast.AST]) -> bool: + """Check if the last instruction is of a certain type.""" + return len(self.instructions) > 0 and isinstance( + self.instructions[-1], instruction + ) + + def is_return(self) -> bool: + """Check if the last instruction is a return statement.""" + return self.is_instruction(ast.Return) + + def is_break(self) -> bool: + """Check if the last instruction is a break statement.""" + return self.is_instruction(ast.Break) + + def is_continue(self) -> bool: + """Check if the last instruction is a continue statement.""" + return self.is_instruction(ast.Continue) + + def seal_outside_loop(self, index: int) -> None: + """Seal the block by setting the jump targets based on the last + instruction. + """ + if self.is_return(): + pass + else: + self.set_jump_targets(index) + + def seal_inside_loop( + self, head_index: int, exit_index: int, default_index: int + ) -> None: + """Seal the block by setting the jump targets based on the last + instruction and taking into account that this block is nested in a + loop. + """ + if self.is_continue(): + self.set_jump_targets(head_index) + elif self.is_break(): + self.set_jump_targets(exit_index) + elif self.is_return(): + pass + else: + self.set_jump_targets(default_index) + + def to_dict(self) -> dict[str, Any]: + return { + "name": self.name, + "instructions": [ast.unparse(n) for n in self.instructions], + "jump_targets": self.jump_targets, + } + + def __repr__(self) -> str: + return ( + f"WritableASTBlock({self.name}, " + f"{self.instructions}, {self.jump_targets})" + ) + + +class ASTCFG(dict[str, WritableASTBlock]): + """A CFG consisting of WritableASTBlocks.""" + + unreachable: set[WritableASTBlock] + empty: set[WritableASTBlock] + noops: set[type[ast.AST]] + + def convert_blocks(self) -> MutableMapping[str, Any]: + """Convert WritableASTBlocks to PythonASTBlocks.""" + return { + v.name: PythonASTBlock( + v.name, + tree=v.instructions, + _jump_targets=tuple(v.jump_targets), + ) + for v in self.values() + } + + def to_dict(self) -> dict[str, dict[str, object]]: + """Convert ASTCFG to simple dict based data structure.""" + return {k: v.to_dict() for (k, v) in self.items()} + + def to_SCFG(self) -> SCFG: + """Convert ASTCFG to SCFG""" + return SCFG(graph=self.convert_blocks()) + + def prune_unreachable(self) -> set[WritableASTBlock]: + """Prune unreachable blocks from the CFG.""" + # Assume that the entry block is named zero (0). + to_visit, reachable, unreachable = set("0"), set(), set() + # Visit all reachable blocks. + while to_visit: + block = to_visit.pop() + if block not in reachable: + # Add block to reachable set. + reachable.add(block) + # Update to_visit with jump targets of the block. + to_visit.update(self[block].jump_targets) + # Remove unreachable blocks. + for block in list(self.keys()): + if block not in reachable: + unreachable.add(self.pop(block)) + self.unreachable = unreachable + return unreachable + + def prune_noops(self) -> set[type[ast.AST]]: + """Prune no-op instructions from the CFG.""" + noops = set() + exclude = (ast.Pass, ast.Continue, ast.Break) + for block in self.values(): + block.instructions = [ + i for i in block.instructions if not isinstance(i, exclude) + ] + noops.update( + [i for i in block.instructions if isinstance(i, exclude)] + ) + self.noops = noops # type: ignore + return noops # type: ignore + + def prune_empty(self) -> set[WritableASTBlock]: + """Prune empty blocks from the CFG.""" + empty = set() + for name, block in list(self.items()): + if not block.instructions: + empty.add(self.pop(name)) + # Empty blocks can only have a single jump target. + it = block.jump_targets[0] + # Iterate over the blocks looking for blocks that point to the + # removed block. Then rewire the jump_targets accordingly. + for b in list(self.values()): + if len(b.jump_targets) == 0: + continue + elif len(b.jump_targets) == 1: + if b.jump_targets[0] == name: + b.jump_targets[0] = it + elif len(b.jump_targets) == 2: + if b.jump_targets[0] == name: + b.jump_targets[0] = it + elif b.jump_targets[1] == name: + b.jump_targets[1] = it + self.empty = empty + return empty + + +@dataclass(frozen=True) +class LoopIndices: + """Structure to hold the head and exit block indices of a loop.""" + + head: int + exit: int + + +class AST2SCFGTransformer: + """AST2SCFGTransformer + + The AST2SCFGTransformer class is responsible for transforming code in the + form of a Python Abstract Syntax Tree (AST) into CFG/SCFG. + + """ + + # Prune noop statements and unreachable/empty blocks from the CFG. + prune: bool + # The code to be transformed. + code: str | list[ast.FunctionDef] | Callable[..., Any] + tree: list[type[ast.AST]] + # Monotonically increasing block index, starts at 1. + block_index: int + # The current block being modified + current_block: WritableASTBlock + # Dict mapping block indices as strings to WritableASTBlocks. + # (This is the data structure to hold the CFG.) + blocks: ASTCFG + # Stack for header and exiting block of current loop. + loop_stack: list[LoopIndices] + + def __init__( + self, + code: str | list[ast.FunctionDef] | Callable[..., Any], + prune: bool = True, + ) -> None: + self.prune = prune + self.code = code + self.tree = AST2SCFGTransformer.unparse_code(code) + self.block_index: int = 1 # 0 is reserved for genesis block + self.blocks = ASTCFG() + # Initialize first (genesis) block, assume it's named zero. + # (This also initializes the self.current_block attribute.) + self.add_block(0) + self.loop_stack: list[LoopIndices] = [] + + @staticmethod + def unparse_code( + code: str | list[ast.FunctionDef] | Callable[..., Any] + ) -> list[type[ast.AST]]: + # Convert source code into AST. + if isinstance(code, str): + tree = ast.parse(code).body + elif callable(code): + tree = ast.parse(textwrap.dedent(inspect.getsource(code))).body + elif ( + isinstance(code, list) + and len(code) > 0 + and all([isinstance(i, ast.AST) for i in code]) + ): + tree = code # type: ignore + else: + msg = "Type: '{type(self.code}}' is not implemented." + raise NotImplementedError(msg) + return tree # type: ignore + + def transform_to_ASTCFG(self) -> ASTCFG: + """Generate ASTCFG from Python function.""" + self.transform() + return self.blocks + + def transform_to_SCFG(self) -> SCFG: + """Generate SCFG from Python function.""" + self.transform() + return self.blocks.to_SCFG() + + def add_block(self, index: int) -> None: + """Create block, add to CFG and set as current_block.""" + self.blocks[str(index)] = self.current_block = WritableASTBlock( + name=str(index) + ) + + def seal_block(self, default_index: int) -> None: + """Seal the current block by setting the jump_targets.""" + if self.loop_stack: + self.current_block.seal_inside_loop( + self.loop_stack[-1].head, + self.loop_stack[-1].exit, + default_index, + ) + else: + self.current_block.seal_outside_loop(default_index) + + def transform(self) -> None: + """Transform Python function stored as self.code.""" + # Assert that the code handed in was a function, we can only transform + # functions. + assert isinstance(self.tree[0], ast.FunctionDef) + # Run recursive code generation. + self.codegen(self.tree) + # Prune if requested. + if self.prune: + _ = self.blocks.prune_unreachable() + _ = self.blocks.prune_noops() + _ = self.blocks.prune_empty() + + def codegen(self, tree: list[type[ast.AST]] | list[ast.stmt]) -> None: + """Recursively transform from a list of AST nodes. + + The function is called 'codegen' as it generates an intermediary + representation (IR) from an abstract syntax tree (AST). The name was + chosen to honour the compiler writing tradition, where this type of + recursive function is commonly called 'codegen'. + + """ + for node in tree: + self.handle_ast_node(node) + + def handle_ast_node(self, node: type[ast.AST] | ast.stmt) -> None: + """Dispatch an AST node to handle.""" + if isinstance(node, ast.FunctionDef): + self.handle_function_def(node) + elif isinstance( + node, + ( + ast.Assign, + ast.AugAssign, + ast.Expr, + ast.Return, + ast.Break, + ast.Continue, + ast.Pass, + ), + ): + self.current_block.instructions.append(node) + elif isinstance(node, ast.If): + self.handle_if(node) + elif isinstance(node, ast.While): + self.handle_while(node) + elif isinstance(node, ast.For): + self.handle_for(node) + else: + raise NotImplementedError(f"Node type {node} not implemented") + + def handle_function_def(self, node: ast.FunctionDef) -> None: + """Handle a function definition.""" + # Insert implicit return None, if the function isn't terminated. May + # end up being an unreachable block if all other paths through the + # program already call return. + if not isinstance(node.body[-1], ast.Return): + node.body.append(ast.Return(None)) + self.codegen(node.body) + + def handle_if(self, node: ast.If) -> None: + """Handle if statement.""" + # Preallocate block indices for then, else, and end-if. + then_index = self.block_index + else_index = self.block_index + 1 + enif_index = self.block_index + 2 + self.block_index += 3 + + # Emit comparison value to current/header block. + self.current_block.instructions.append(node.test) + # Setup jump targets for current/header block. + self.current_block.set_jump_targets(then_index, else_index) + + # Create a new block for the then branch. + self.add_block(then_index) + # Recursively transform then branch (this may alter the current_block). + self.codegen(node.body) + # After recursion, current_block may need a jump target. + self.seal_block(enif_index) + + # Create a new block for the else branch. + self.add_block(else_index) + # Recursively transform then branch (this may alter the current_block). + self.codegen(node.orelse) + # After recursion, current_block may need a jump target. + self.seal_block(enif_index) + + # Create a new block and assign it to the be the current_block, this + # will hold the end-if statements if any exist. We leave 'open' for + # modification. + self.add_block(enif_index) + + def handle_while(self, node: ast.While) -> None: + """Handle while statement.""" + # If the current block already has instructions, we need a new block as + # header. Otherwise just re-use the current_block. This happens + # when the previous statement was an if-statement with an empty + # endif_block, for example. This is possible because the Python + # while-loop does not need to modify it's preheader. + + # Preallocate header, body, else and exiting indices. + # (Technically, we could re-use the current block as header if it is + # still empty. We elect to potentially leave a block empty instead, + # since there is a pass to prune empty blocks anyway.) + head_index = self.block_index + body_index = self.block_index + 1 + exit_index = self.block_index + 2 + else_index = self.block_index + 3 + self.block_index += 4 + + self.current_block.set_jump_targets(head_index) + # And create new header block + self.add_block(head_index) + + # Emit comparison expression into header. + self.current_block.instructions.append(node.test) + # Set the jump targets to be the body and the else branch. + self.current_block.set_jump_targets(body_index, else_index) + + # Create body block. + self.add_block(body_index) + + # Push to loop stack for recursion. + self.loop_stack.append(LoopIndices(head_index, exit_index)) + + # Recurs into the body of the while statement. (This may modify + # current_block). + self.codegen(node.body) + # After recursion, seal current_block. This sets the jump targets based + # on the last instruction in the current_block. + self.seal_block(head_index) + + # Pop values from loop stack post recursion. + loop_indices = self.loop_stack.pop() + assert ( + loop_indices.head == head_index and loop_indices.exit == exit_index + ) + + # Create else block. + self.add_block(else_index) + + # Recurs into the body of the else-branch, again this may modify the + # current_block. + self.codegen(node.orelse) + + # Seal current_block. + self.seal_block(exit_index) + + # Create exit block and leave open for modifictaion. + self.add_block(exit_index) + + def handle_for(self, node: ast.For) -> None: + """Handle for statement. + + The Python 'for' statement needs to be decomposed into a series of + equivalent Python statements, since the semantics of the statement can + not be represented in the control flow graph (CFG) formalism of blocks + with directed edges. We note that the for-loop in Python is effectively + syntactic sugar for a generalised c-style while-loop. To our advantage, + this while-loop can indeed be represented using the blocks and directed + edges of the CFG formalism and allows us to transform the Python + for-loop construct. This docstring explains the decomposition + from for- into while-loop. + + Remember that the for-loop has a target variable that will be assigned, + an iterator to iterate over, a loop body and an else clause. The AST + node has the following signature: + + ast.For(target, iter, body, orelse, type_comment) + + Remember also that Python for-loops can have an else-branch, that is + executed upon regular loop conclusion. + + def function(a: int) -> None + c = 0 + for i in range(10): + c += i + if i == a: + i = 420 # set i arbitrarily + break # early exit, break from loop, bypass else-branch + else: + c += 1 # loop conclusion, i.e. we have not hit the break + return c, i + + So, effectively, to decompose the for-loop, we need to setup the + iterator by calling 'iter(iter)' and assign it to a variable, + initialize the target variable to be None and then check if the + iterator has a next value. If it does, we need to assign that value to + the target variable, enter the body and then check the iterator again + and again and again.. until there are no items left, at which point we + execute the else-branch. + + The Python for-loop usually waits for the iterator to raise a + StopIteration exception to determine when the iteration has concluded. + However, it is possible to use the 'next()' method with a second + argument to avoid exception handling here. We do this so we don't need + to rely on being able to transform exceptions as part of this + transformer. + + i = next(iter, "__sentinel__") + if i != "__sentinel__": + ... + + Lastly, it is important to also remember that the target variable + escapes the scope of the for loop: + + >>> for i in range(1): + ... print("hello loop") + ... + hello loop + >>> i + 0 + >>> + + So, to summarize: we want to decompose a Python for loop into a while + loop with some assignments and he target variable must escape the + scope. + + Consider again the following function: + + def function(a: int) -> None + c = 0 + for i in range(10): + c += i + if i == a: + i = 420 + break + else: + c += 1 + return c, i + + This will be decomposed as the following construct that can be encoded + using the available block and edge primitives of the CFG. + + def function(a: int) -> None + c = 0 + * __iterator_1__ = iter(range(10)) # setup iterator + * i = None # assign target, in this case i + while True: # loop until we break + * __iter_last_1__ = i # backup value of i + * i = next(__iterator_1__, '__sentinel__') # get next i + * if i != '__sentinel__': # regular iteration + c += i # add to accumulator + if i == a: # check for early exit + i = 420 # set i to some wild value + break # early exit break while True + else: # for-else clause + * i == __iter_last_1__ # restore value of i + c += 1 # execute code in for-else clause + break # regular exit break while True + return c, i + + The above is actually a full Python source reconstruction. In the + implementation below, it is only necessary to emit some of the special + assignments (marked above with a *-prefix above) into the blocks of the + CFG. All of the control-flow inside the function will be represented + by the directed edges of the CFG. + + The first two assignments are for the pre-header: + + * __iterator_1__ = iter(range(10)) # setup iterator + * i = None # assign target, in this case i + + The next three is for the header, the predicate determines the end of + the loop. + + * __iter_last_1__ = i # backup value of i + * i = next(__iterator_1__, '__sentinel__') # get next i + * if i != '__sentinel__': # regular iteration + + And lastly, one assignment in the for-else clause + + * i == __iter_last_1__ # restore value of i + + We modify the pre-header, the header and the else blocks with + appropriate Python statements in the following implementation. The + Python code is injected by generating Python source using f-strings and + then using the 'unparse()' function of the 'ast' module to then use the + 'codegen' method of this transformer to emit the required 'ast.AST' + objects into the blocks of the CFG. + + Lastly the important thing to observe is that we can not ignore the + else clause, since this must contain the reset of the variable i, which + will have been set to '__sentinel__'. This reset is required such that + the target variable 'i' will escape the scope of the for-loop. + + """ + # Preallocate indices for header, body, else, and exiting blocks. + head_index = self.block_index + body_index = self.block_index + 1 + else_index = self.block_index + 2 + exit_index = self.block_index + 3 + self.block_index += 4 + + # Assign the components of the for-loop to variables. These variables + # are versioned using the index of the loop header so that scopes can + # be nested. While this is strictly required for the 'iter_setup' it is + # technically optional for the 'last_target_value'... But, we version + # it too so that the two can easily be matched when visually inspecting + # the CFG. + target = ast.unparse(node.target) + iter_setup = ast.unparse(node.iter) + iter_assign = f"__iterator_{head_index}__" + last_target_value = f"__iter_last_{head_index}__" + + # Emit iterator setup to pre-header. + preheader_code = textwrap.dedent( + f""" + {iter_assign} = iter({iter_setup}) + {target} = None + """ + ) + self.codegen(ast.parse(preheader_code).body) + + # Point the current_block to header block. + self.current_block.set_jump_targets(head_index) + # And create new header block. + self.add_block(head_index) + + # Emit header instructions. This first makes a backup of the iteration + # target and then checks if the iterator is exhausted and if the loop + # should continue. The '__sentinel__' is an singleton style marker, so + # it need not be versioned. + + header_code = textwrap.dedent( + f""" + {last_target_value} = {target} + {target} = next({iter_assign}, "__sentinel__") + {target} != "__sentinel__" + """ + ) + self.codegen(ast.parse(header_code).body) + # Set the jump targets to be the body and the else block. + self.current_block.set_jump_targets(body_index, else_index) + + # Create body block. + self.add_block(body_index) + + # Setup loop stack for recursion. + self.loop_stack.append(LoopIndices(head_index, exit_index)) + + # Recurs into the loop body (this may modify current_block). + self.codegen(node.body) + # After recursion, seal current block. + self.seal_block(head_index) + + # Pop values from loop stack post recursion. + loop_indices = self.loop_stack.pop() + assert ( + loop_indices.head == head_index and loop_indices.exit == exit_index + ) + + # Create else block. + self.add_block(else_index) + + # Emit orelse instructions. Needs to be prefixed with an assignment + # such that the for loop target can escape the scope of the loop. + else_code = textwrap.dedent( + f""" + {target} = {last_target_value} + """ + ) + self.codegen(ast.parse(else_code).body) + + # Recurs into the body of the else-branch. + self.codegen(node.orelse) + + # Seal current block, whatever it may be. + self.seal_block(exit_index) + + # Create exit block and leave open for modification + self.add_block(exit_index) + + def render(self) -> None: + """Render the CFG contained in this transformer as a SCFG. + + Useful for debugging purposes, set a breakpoint and then render to view + intermediary results. + + """ + self.blocks.to_SCFG().render() + + +def AST2SCFG(code: str | list[ast.FunctionDef] | Callable[..., Any]) -> SCFG: + """Transform Python function into an SCFG.""" + return AST2SCFGTransformer(code).transform_to_SCFG() + + +def SCFG2AST(scfg: SCFG) -> ast.FunctionDef: # type: ignore + """Transform SCFG with PythonASTBlocks into an AST FunctionDef.""" + # TODO diff --git a/numba_rvsdg/core/datastructures/basic_block.py b/numba_rvsdg/core/datastructures/basic_block.py index 758d932..4619f22 100644 --- a/numba_rvsdg/core/datastructures/basic_block.py +++ b/numba_rvsdg/core/datastructures/basic_block.py @@ -1,4 +1,5 @@ import dis +import ast from typing import Tuple, Dict, List, Optional from dataclasses import dataclass, replace, field @@ -199,6 +200,30 @@ def get_instructions( return out +@dataclass(frozen=True) +class PythonASTBlock(BasicBlock): + """The PythonASTBlock class is a subclass of the BasicBlock that + represents basic blocks with Python AST. + + Attributes + ---------- + begin: int + The starting line. + + end: int + The ending line. + """ + + begin: int = -1 + + end: int = -1 + + tree: List[ast.AST] = field(default_factory=lambda: []) + + def get_tree(self) -> List[ast.AST]: + return self.tree + + @dataclass(frozen=True) class SyntheticBlock(BasicBlock): """The SyntheticBlock represents a artificially added block in a diff --git a/numba_rvsdg/core/datastructures/scfg.py b/numba_rvsdg/core/datastructures/scfg.py index 5e818b9..6508190 100644 --- a/numba_rvsdg/core/datastructures/scfg.py +++ b/numba_rvsdg/core/datastructures/scfg.py @@ -10,6 +10,7 @@ Optional, Generator, Mapping, + MutableMapping, Sized, ) from textwrap import indent @@ -164,7 +165,7 @@ class SCFG(Sized): regions, and variables. """ - graph: Dict[str, BasicBlock] = field(default_factory=dict) + graph: MutableMapping[str, BasicBlock] = field(default_factory=dict) name_gen: NameGenerator = field( default_factory=NameGenerator, compare=False @@ -238,7 +239,11 @@ def __iter__(self) -> Generator[Tuple[str, BasicBlock], None, None]: over the given view. """ # initialise housekeeping datastructures - to_visit, seen = [self.find_head()], [] + try: + to_visit = [self.find_head()] + seen: list[str] = [] + except KeyError: + to_visit, seen = ["0"], [] while to_visit: # get the next name on the list name = to_visit.pop(0) @@ -812,6 +817,10 @@ def view(self, name: Optional[str] = None) -> None: SCFGRenderer(self).view(name) + def render(self) -> None: + """Alias for view().""" + self.view() + @staticmethod def from_yaml(yaml_string: str) -> "Tuple[SCFG, Dict[str, str]]": """Static method that creates an SCFG object from a YAML diff --git a/numba_rvsdg/rendering/rendering.py b/numba_rvsdg/rendering/rendering.py index 0ef79bb..ded8086 100644 --- a/numba_rvsdg/rendering/rendering.py +++ b/numba_rvsdg/rendering/rendering.py @@ -1,9 +1,11 @@ +import ast import logging from abc import abstractmethod from numba_rvsdg.core.datastructures.basic_block import ( BasicBlock, RegionBlock, PythonBytecodeBlock, + PythonASTBlock, SyntheticAssignment, SyntheticBranch, SyntheticBlock, @@ -70,6 +72,8 @@ def render_block( self.render_basic_block(digraph, name, block) if type(block) == PythonBytecodeBlock: # noqa: E721 self.render_basic_block(digraph, name, block) + if type(block) == PythonASTBlock: # noqa: E721 + self.render_python_ast_block(digraph, name, block) # type: ignore elif type(block) == SyntheticAssignment: # noqa: E721 self.render_control_variable_block(digraph, name, block) elif isinstance(block, SyntheticBranch): @@ -103,7 +107,10 @@ def find_base_header(block: BasicBlock) -> BasicBlock: continue src_block = find_base_header(src_block) for dst_name in src_block.jump_targets: - dst_name = find_base_header(blocks[dst_name]).name + try: + dst_name = find_base_header(blocks[dst_name]).name + except KeyError: + continue if dst_name in blocks.keys(): self.g.edge(str(src_block.name), str(dst_name)) else: @@ -272,6 +279,24 @@ def render_basic_block( digraph.node(str(name), shape="rect", label=body) + def render_python_ast_block( + self, digraph: "Digraph", name: str, block: BasicBlock + ) -> None: + code = r"\l".join( + ast.unparse(n) for n in block.get_tree() # type: ignore + ) + body = ( + name + + "\n\n" + + code + + r"\l\ljump targets: " + + str(block.jump_targets) + + r"\lback edges: " + + str(block.backedges) + ) + + digraph.node(str(name), shape="rect", label=body) + def render_control_variable_block( self, digraph: "Digraph", name: str, block: SyntheticAssignment ) -> None: @@ -311,6 +336,13 @@ def render_branching_block( raise Exception("Unknown name type: " + name) digraph.node(str(name), shape="rect", label=body) + def render_scfg(self) -> "Digraph": + """Renders the provided SCFG object.""" + for name, block in self.scfg.graph.items(): # type: ignore + self.render_block(self.g, name, block) + self.render_edges(self.scfg) # type: ignore + return self.g + def view(self, name: Optional[str] = None) -> None: """Method used to view the current SCFG as an external graphviz generated PDF file. @@ -386,4 +418,4 @@ def render_scfg(scfg: SCFG) -> None: The structured control flow graph (SCFG) to be rendered. """ # is this function used?? - ByteFlowRenderer().render_scfg(scfg).view("scfg") # type: ignore + SCFGRenderer(scfg).view("scfg") diff --git a/numba_rvsdg/tests/test_ast_transforms.py b/numba_rvsdg/tests/test_ast_transforms.py new file mode 100644 index 0000000..bc55d1d --- /dev/null +++ b/numba_rvsdg/tests/test_ast_transforms.py @@ -0,0 +1,1071 @@ +# mypy: ignore-errors +import ast +import textwrap +from typing import Callable, Any +from unittest import main, TestCase + +from numba_rvsdg.core.datastructures.ast_transforms import AST2SCFGTransformer + + +class TestAST2SCFGTransformer(TestCase): + + def compare( + self, + function: Callable[..., Any], + expected: dict[str, dict[str, Any]], + unreachable: set[int] = set(), + empty: set[int] = set(), + ): + transformer = AST2SCFGTransformer(function) + astcfg = transformer.transform_to_ASTCFG() + self.assertEqual(expected, astcfg.to_dict()) + self.assertEqual(unreachable, {i.name for i in astcfg.unreachable}) + self.assertEqual(empty, {i.name for i in astcfg.empty}) + + def setUp(self): + self.maxDiff = None + + def test_solo_return(self): + def function() -> int: + return 1 + + expected = { + "0": { + "instructions": ["return 1"], + "jump_targets": [], + "name": "0", + } + } + self.compare(function, expected) + + def test_solo_return_from_string(self): + function = textwrap.dedent( + """ + def function() -> int: + return 1 + """ + ) + + expected = { + "0": { + "instructions": ["return 1"], + "jump_targets": [], + "name": "0", + } + } + self.compare(function, expected) + + def test_solo_return_from_AST(self): + function = ast.parse(textwrap.dedent( + """ + def function() -> int: + return 1 + """)).body + + expected = { + "0": { + "instructions": ["return 1"], + "jump_targets": [], + "name": "0", + } + } + self.compare(function, expected) + + def test_solo_assign(self): + def function() -> None: + x = 1 # noqa: F841 + + expected = { + "0": { + "instructions": ["x = 1", "return"], + "jump_targets": [], + "name": "0", + } + } + self.compare(function, expected) + + def test_solo_pass(self): + def function() -> None: + pass + + expected = { + "0": { + "instructions": ["return"], + "jump_targets": [], + "name": "0", + } + } + self.compare(function, expected) + + def test_assign_return(self): + def function() -> int: + x = 1 + return x + + expected = { + "0": { + "instructions": ["x = 1", "return x"], + "jump_targets": [], + "name": "0", + } + } + self.compare(function, expected) + + def test_if_return(self): + def function(x: int) -> int: + if x < 10: + return 1 + return 2 + + expected = { + "0": { + "instructions": ["x < 10"], + "jump_targets": ["1", "3"], + "name": "0", + }, + "1": { + "instructions": ["return 1"], + "jump_targets": [], + "name": "1", + }, + "3": { + "instructions": ["return 2"], + "jump_targets": [], + "name": "3", + }, + } + self.compare(function, expected, empty={"2"}) + + def test_if_else_return(self): + def function(x: int) -> int: + if x < 10: + return 1 + else: + return 2 + + expected = { + "0": { + "instructions": ["x < 10"], + "jump_targets": ["1", "2"], + "name": "0", + }, + "1": { + "instructions": ["return 1"], + "jump_targets": [], + "name": "1", + }, + "2": { + "instructions": ["return 2"], + "jump_targets": [], + "name": "2", + }, + } + self.compare(function, expected, unreachable={"3"}) + + def test_if_else_assign(self): + def function(x: int) -> int: + if x < 10: + z = 1 + else: + z = 2 + return z + + expected = { + "0": { + "instructions": ["x < 10"], + "jump_targets": ["1", "2"], + "name": "0", + }, + "1": { + "instructions": ["z = 1"], + "jump_targets": ["3"], + "name": "1", + }, + "2": { + "instructions": ["z = 2"], + "jump_targets": ["3"], + "name": "2", + }, + "3": { + "instructions": ["return z"], + "jump_targets": [], + "name": "3", + }, + } + self.compare(function, expected) + + def test_nested_if(self): + def function(x: int, y: int) -> int: + if x < 10: + if y < 5: + y = 1 + else: + y = 2 + else: + if y < 15: + y = 3 + else: + y = 4 + return y + + expected = { + "0": { + "instructions": ["x < 10"], + "jump_targets": ["1", "2"], + "name": "0", + }, + "1": { + "instructions": ["y < 5"], + "jump_targets": ["4", "5"], + "name": "1", + }, + "2": { + "instructions": ["y < 15"], + "jump_targets": ["7", "8"], + "name": "2", + }, + "3": { + "instructions": ["return y"], + "jump_targets": [], + "name": "3", + }, + "4": { + "instructions": ["y = 1"], + "jump_targets": ["3"], + "name": "4", + }, + "5": { + "instructions": ["y = 2"], + "jump_targets": ["3"], + "name": "5", + }, + "7": { + "instructions": ["y = 3"], + "jump_targets": ["3"], + "name": "7", + }, + "8": { + "instructions": ["y = 4"], + "jump_targets": ["3"], + "name": "8", + }, + } + self.compare(function, expected, empty={"6", "9"}) + + def test_nested_if_with_empty_else_and_return(self): + def function(x: int, y: int) -> None: + y << 2 + if x < 10: + y -= 1 + if y < 5: + y = 1 + else: + if y < 15: + y = 2 + else: + return + y += 1 + return y + + expected = { + "0": { + "instructions": ["y << 2", "x < 10"], + "jump_targets": ["1", "2"], + "name": "0", + }, + "1": { + "instructions": ["y -= 1", "y < 5"], + "jump_targets": ["4", "3"], + "name": "1", + }, + "2": { + "instructions": ["y < 15"], + "jump_targets": ["7", "8"], + "name": "2", + }, + "3": { + "instructions": ["return y"], + "jump_targets": [], + "name": "3", + }, + "4": { + "instructions": ["y = 1"], + "jump_targets": ["3"], + "name": "4", + }, + "7": { + "instructions": ["y = 2"], + "jump_targets": ["9"], + "name": "7", + }, + "8": {"instructions": ["return"], "jump_targets": [], "name": "8"}, + "9": { + "instructions": ["y += 1"], + "jump_targets": ["3"], + "name": "9", + }, + } + self.compare(function, expected, empty={"5", "6"}) + + def test_elif(self): + def function(x: int, a: int, b: int) -> int: + if x < 10: + return + elif x < 15: + y = b - a + elif x < 20: + y = a**2 + else: + y = a - b + return y + + expected = { + "0": { + "instructions": ["x < 10"], + "jump_targets": ["1", "2"], + "name": "0", + }, + "1": {"instructions": ["return"], "jump_targets": [], "name": "1"}, + "2": { + "instructions": ["x < 15"], + "jump_targets": ["4", "5"], + "name": "2", + }, + "3": { + "instructions": ["return y"], + "jump_targets": [], + "name": "3", + }, + "4": { + "instructions": ["y = b - a"], + "jump_targets": ["3"], + "name": "4", + }, + "5": { + "instructions": ["x < 20"], + "jump_targets": ["7", "8"], + "name": "5", + }, + "7": { + "instructions": ["y = a ** 2"], + "jump_targets": ["3"], + "name": "7", + }, + "8": { + "instructions": ["y = a - b"], + "jump_targets": ["3"], + "name": "8", + }, + } + self.compare(function, expected, empty={"9", "6"}) + + def test_simple_while(self): + def function() -> int: + x = 0 + while x < 10: + x += 1 + return x + + expected = { + "0": { + "instructions": ["x = 0"], + "jump_targets": ["1"], + "name": "0", + }, + "1": { + "instructions": ["x < 10"], + "jump_targets": ["2", "3"], + "name": "1", + }, + "2": { + "instructions": ["x += 1"], + "jump_targets": ["1"], + "name": "2", + }, + "3": { + "instructions": ["return x"], + "jump_targets": [], + "name": "3", + }, + } + self.compare(function, expected, empty={"4"}) + + def test_nested_while(self): + def function() -> tuple[int, int]: + x, y = 0, 0 + while x < 10: + while y < 5: + x += 1 + y += 1 + x += 1 + return x, y + + expected = { + "0": { + "instructions": ["x, y = (0, 0)"], + "jump_targets": ["1"], + "name": "0", + }, + "1": { + "instructions": ["x < 10"], + "jump_targets": ["5", "3"], + "name": "1", + }, + "3": { + "instructions": ["return (x, y)"], + "jump_targets": [], + "name": "3", + }, + "5": { + "instructions": ["y < 5"], + "jump_targets": ["6", "7"], + "name": "5", + }, + "6": { + "instructions": ["x += 1", "y += 1"], + "jump_targets": ["5"], + "name": "6", + }, + "7": { + "instructions": ["x += 1"], + "jump_targets": ["1"], + "name": "7", + }, + } + + self.compare(function, expected, empty={"2", "4", "8"}) + + def test_if_in_while(self): + def function() -> int: + x = 0 + while x < 10: + if x < 5: + x += 2 + else: + x += 1 + return x + + expected = { + "0": { + "instructions": ["x = 0"], + "jump_targets": ["1"], + "name": "0", + }, + "1": { + "instructions": ["x < 10"], + "jump_targets": ["2", "3"], + "name": "1", + }, + "2": { + "instructions": ["x < 5"], + "jump_targets": ["5", "6"], + "name": "2", + }, + "3": { + "instructions": ["return x"], + "jump_targets": [], + "name": "3", + }, + "5": { + "instructions": ["x += 2"], + "jump_targets": ["1"], + "name": "5", + }, + "6": { + "instructions": ["x += 1"], + "jump_targets": ["1"], + "name": "6", + }, + } + self.compare(function, expected, empty={"4", "7"}) + + def test_while_in_if(self): + def function(a: bool) -> int: + x = 0 + if a is True: + while x < 10: + x += 2 + else: + while x < 10: + x += 1 + return x + + expected = { + "0": { + "instructions": ["x = 0", "a is True"], + "jump_targets": ["4", "8"], + "name": "0", + }, + "3": { + "instructions": ["return x"], + "jump_targets": [], + "name": "3", + }, + "4": { + "instructions": ["x < 10"], + "jump_targets": ["5", "3"], + "name": "4", + }, + "5": { + "instructions": ["x += 2"], + "jump_targets": ["4"], + "name": "5", + }, + "8": { + "instructions": ["x < 10"], + "jump_targets": ["9", "3"], + "name": "8", + }, + "9": { + "instructions": ["x += 1"], + "jump_targets": ["8"], + "name": "9", + }, + } + self.compare( + function, expected, empty={"1", "2", "6", "7", "10", "11"} + ) + + def test_while_break_continue(self): + def function() -> int: + x = 0 + while x < 10: + x += 1 + if x % 2 == 0: + continue + elif x == 9: + break + else: + x += 1 + return x + + expected = { + "0": { + "instructions": ["x = 0"], + "jump_targets": ["1"], + "name": "0", + }, + "1": { + "instructions": ["x < 10"], + "jump_targets": ["2", "3"], + "name": "1", + }, + "2": { + "instructions": ["x += 1", "x % 2 == 0"], + "jump_targets": ["1", "6"], + "name": "2", + }, + "3": { + "instructions": ["return x"], + "jump_targets": [], + "name": "3", + }, + "6": { + "instructions": ["x == 9"], + "jump_targets": ["3", "9"], + "name": "6", + }, + "9": { + "instructions": ["x += 1"], + "jump_targets": ["1"], + "name": "9", + }, + } + self.compare(function, expected, empty={"4", "5", "7", "8", "10"}) + + def test_while_else(self): + def function() -> int: + x = 0 + while x < 10: + x += 1 + else: + x += 1 + return x + + expected = { + "0": { + "instructions": ["x = 0"], + "jump_targets": ["1"], + "name": "0", + }, + "1": { + "instructions": ["x < 10"], + "jump_targets": ["2", "4"], + "name": "1", + }, + "2": { + "instructions": ["x += 1"], + "jump_targets": ["1"], + "name": "2", + }, + "3": { + "instructions": ["return x"], + "jump_targets": [], + "name": "3", + }, + "4": { + "instructions": ["x += 1"], + "jump_targets": ["3"], + "name": "4", + }, + } + self.compare(function, expected) + + def test_simple_for(self): + def function() -> int: + c = 0 + for i in range(10): + c += i + return c + + expected = { + "0": { + "instructions": [ + "c = 0", + "__iterator_1__ = iter(range(10))", + "i = None", + ], + "jump_targets": ["1"], + "name": "0", + }, + "1": { + "instructions": [ + "__iter_last_1__ = i", + "i = next(__iterator_1__, '__sentinel__')", + "i != '__sentinel__'", + ], + "jump_targets": ["2", "3"], + "name": "1", + }, + "2": { + "instructions": ["c += i"], + "jump_targets": ["1"], + "name": "2", + }, + "3": { + "instructions": ["i = __iter_last_1__"], + "jump_targets": ["4"], + "name": "3", + }, + "4": { + "instructions": ["return c"], + "jump_targets": [], + "name": "4", + }, + } + self.compare(function, expected) + + def test_nested_for(self): + def function() -> int: + c = 0 + for i in range(3): + c += i + for j in range(3): + c += j + return c + + expected = { + "0": { + "instructions": [ + "c = 0", + "__iterator_1__ = iter(range(3))", + "i = None", + ], + "jump_targets": ["1"], + "name": "0", + }, + "1": { + "instructions": [ + "__iter_last_1__ = i", + "i = next(__iterator_1__, '__sentinel__')", + "i != '__sentinel__'", + ], + "jump_targets": ["2", "3"], + "name": "1", + }, + "2": { + "instructions": [ + "c += i", + "__iterator_5__ = iter(range(3))", + "j = None", + ], + "jump_targets": ["5"], + "name": "2", + }, + "3": { + "instructions": ["i = __iter_last_1__"], + "jump_targets": ["4"], + "name": "3", + }, + "4": { + "instructions": ["return c"], + "jump_targets": [], + "name": "4", + }, + "5": { + "instructions": [ + "__iter_last_5__ = j", + "j = next(__iterator_5__, '__sentinel__')", + "j != '__sentinel__'", + ], + "jump_targets": ["6", "7"], + "name": "5", + }, + "6": { + "instructions": ["c += j"], + "jump_targets": ["5"], + "name": "6", + }, + "7": { + "instructions": ["j = __iter_last_5__"], + "jump_targets": ["1"], + "name": "7", + }, + } + self.compare(function, expected, empty={"8"}) + + def test_for_with_return_break_and_continue(self): + def function(a: int, b: int) -> int: + for i in range(2): + if i == a: + i = 3 + return i + elif i == b: + i = 4 + break + else: + continue + return i + + expected = { + "0": { + "instructions": [ + "__iterator_1__ = iter(range(2))", + "i = None", + ], + "jump_targets": ["1"], + "name": "0", + }, + "1": { + "instructions": [ + "__iter_last_1__ = i", + "i = next(__iterator_1__, '__sentinel__')", + "i != '__sentinel__'", + ], + "jump_targets": ["2", "3"], + "name": "1", + }, + "2": { + "instructions": ["i == a"], + "jump_targets": ["5", "6"], + "name": "2", + }, + "3": { + "instructions": ["i = __iter_last_1__"], + "jump_targets": ["4"], + "name": "3", + }, + "4": { + "instructions": ["return i"], + "jump_targets": [], + "name": "4", + }, + "5": { + "instructions": ["i = 3", "return i"], + "jump_targets": [], + "name": "5", + }, + "6": { + "instructions": ["i == b"], + "jump_targets": ["8", "1"], + "name": "6", + }, + "8": { + "instructions": ["i = 4"], + "jump_targets": ["4"], + "name": "8", + }, + } + self.compare(function, expected, unreachable={"7", "10"}, empty={"9"}) + + def test_for_with_if_in_else(self): + def function(a: int): + c = 0 + for i in range(10): + c += i + else: + if a: + r = c + else: + r = -1 * c + return r + + expected = { + "0": { + "instructions": [ + "c = 0", + "__iterator_1__ = iter(range(10))", + "i = None", + ], + "jump_targets": ["1"], + "name": "0", + }, + "1": { + "instructions": [ + "__iter_last_1__ = i", + "i = next(__iterator_1__, '__sentinel__')", + "i != '__sentinel__'", + ], + "jump_targets": ["2", "3"], + "name": "1", + }, + "2": { + "instructions": ["c += i"], + "jump_targets": ["1"], + "name": "2", + }, + "3": { + "instructions": ["i = __iter_last_1__", "a"], + "jump_targets": ["5", "6"], + "name": "3", + }, + "4": { + "instructions": ["return r"], + "jump_targets": [], + "name": "4", + }, + "5": { + "instructions": ["r = c"], + "jump_targets": ["4"], + "name": "5", + }, + "6": { + "instructions": ["r = -1 * c"], + "jump_targets": ["4"], + "name": "6", + }, + } + self.compare(function, expected, empty={"7"}) + + def test_for_with_nested_for_else(self): + def function(a: bool) -> int: + c = 1 + for i in range(1): + for j in range(1): + if a: + c *= 3 + break # This break decides, if True skip continue. + else: + c *= 5 + continue # Causes break below to be skipped. + c *= 7 + break # Causes the else below to be skipped + else: + c *= 9 # Not breaking in inner loop leads here + return c + + self.assertEqual(function(True), 3 * 7) + self.assertEqual(function(False), 5 * 9) + expected = { + "0": { + "instructions": [ + "c = 1", + "__iterator_1__ = iter(range(1))", + "i = None", + ], + "jump_targets": ["1"], + "name": "0", + }, + "1": { + "instructions": [ + "__iter_last_1__ = i", + "i = next(__iterator_1__, '__sentinel__')", + "i != '__sentinel__'", + ], + "jump_targets": ["2", "3"], + "name": "1", + }, + "2": { + "instructions": [ + "__iterator_5__ = iter(range(1))", + "j = None", + ], + "jump_targets": ["5"], + "name": "2", + }, + "3": { + "instructions": ["i = __iter_last_1__", "c *= 9"], + "jump_targets": ["4"], + "name": "3", + }, + "4": { + "instructions": ["return c"], + "jump_targets": [], + "name": "4", + }, + "5": { + "instructions": [ + "__iter_last_5__ = j", + "j = next(__iterator_5__, '__sentinel__')", + "j != '__sentinel__'", + ], + "jump_targets": ["6", "7"], + "name": "5", + }, + "6": { + "instructions": ["a"], + "jump_targets": ["9", "5"], + "name": "6", + }, + "7": { + "instructions": ["j = __iter_last_5__", "c *= 5"], + "jump_targets": ["1"], + "name": "7", + }, + "8": { + "instructions": ["c *= 7"], + "jump_targets": ["4"], + "name": "8", + }, + "9": { + "instructions": ["c *= 3"], + "jump_targets": ["8"], + "name": "9", + }, + } + + self.compare(function, expected, empty={"11", "10"}) + + def test_for_with_nested_else_return_break_and_continue(self): + def function(a: int, b: int, c: int, d: int, e: int, f: int) -> int: + for i in range(2): + if i == a: + i = 3 + return i + elif i == b: + i = 4 + break + elif i == c: + i = 5 + continue + else: + while i < 10: + i += 1 + if i == d: + i = 3 + return i + elif i == e: + i = 4 + break + elif i == f: + i = 5 + continue + else: + i += 1 + return i + + expected = { + "0": { + "instructions": [ + "__iterator_1__ = iter(range(2))", + "i = None", + ], + "jump_targets": ["1"], + "name": "0", + }, + "1": { + "instructions": [ + "__iter_last_1__ = i", + "i = next(__iterator_1__, '__sentinel__')", + "i != '__sentinel__'", + ], + "jump_targets": ["2", "3"], + "name": "1", + }, + "11": { + "instructions": ["i = 5"], + "jump_targets": ["1"], + "name": "11", + }, + "14": { + "instructions": ["i < 10"], + "jump_targets": ["15", "1"], + "name": "14", + }, + "15": { + "instructions": ["i += 1", "i == d"], + "jump_targets": ["18", "19"], + "name": "15", + }, + "18": { + "instructions": ["i = 3", "return i"], + "jump_targets": [], + "name": "18", + }, + "19": { + "instructions": ["i == e"], + "jump_targets": ["21", "22"], + "name": "19", + }, + "2": { + "instructions": ["i == a"], + "jump_targets": ["5", "6"], + "name": "2", + }, + "21": { + "instructions": ["i = 4"], + "jump_targets": ["1"], + "name": "21", + }, + "22": { + "instructions": ["i == f"], + "jump_targets": ["24", "25"], + "name": "22", + }, + "24": { + "instructions": ["i = 5"], + "jump_targets": ["14"], + "name": "24", + }, + "25": { + "instructions": ["i += 1"], + "jump_targets": ["14"], + "name": "25", + }, + "3": { + "instructions": ["i = __iter_last_1__"], + "jump_targets": ["4"], + "name": "3", + }, + "4": { + "instructions": ["return i"], + "jump_targets": [], + "name": "4", + }, + "5": { + "instructions": ["i = 3", "return i"], + "jump_targets": [], + "name": "5", + }, + "6": { + "instructions": ["i == b"], + "jump_targets": ["8", "9"], + "name": "6", + }, + "8": { + "instructions": ["i = 4"], + "jump_targets": ["4"], + "name": "8", + }, + "9": { + "instructions": ["i == c"], + "jump_targets": ["11", "14"], + "name": "9", + }, + } + empty = {"7", "10", "12", "13", "16", "17", "20", "23", "26"} + self.compare(function, expected, empty=empty) + + +if __name__ == "__main__": + main()