Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement and test SCFG->AST conversion #127

Merged
merged 12 commits into from
Jun 9, 2024
4 changes: 3 additions & 1 deletion makefile
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ all:
build:
python -m pip install -vv -e .
test:
coverage run -m pytest --pyargs numba_rvsdg
# Activate using the sys.monitoring implementation of coverage.
# Needs at least coverage veraion 7.4.0 to work.
COVERAGE_CORE=sysmon coverage run -m pytest --pyargs numba_rvsdg
coverage report
lint:
pre-commit run --verbose --all-files
Expand Down
5 changes: 4 additions & 1 deletion numba_rvsdg/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
from numba_rvsdg.core.datastructures.ast_transforms import AST2SCFG # noqa
from numba_rvsdg.core.datastructures.ast_transforms import ( # noqa
AST2SCFG,
SCFG2AST,
)
278 changes: 246 additions & 32 deletions numba_rvsdg/core/datastructures/ast_transforms.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,43 @@
import ast
import inspect
import itertools
from typing import Callable, Any, MutableMapping
import textwrap
from dataclasses import dataclass
from collections import defaultdict

from numba_rvsdg.core.datastructures.scfg import SCFG
from numba_rvsdg.core.datastructures.basic_block import PythonASTBlock
from numba_rvsdg.core.datastructures.basic_block import (
PythonASTBlock,
RegionBlock,
SyntheticHead,
SyntheticTail,
SyntheticFill,
SyntheticReturn,
SyntheticAssignment,
SyntheticExitingLatch,
SyntheticExitBranch,
)


def unparse_code(
code: str | list[ast.FunctionDef] | Callable[..., Any]
) -> list[type[ast.AST]]:
# Convert source code into AST.
if isinstance(code, str):
tree = ast.parse(code).body
elif callable(code):
tree = ast.parse(textwrap.dedent(inspect.getsource(code))).body
elif (
isinstance(code, list)
and len(code) > 0
and all([isinstance(i, ast.AST) for i in code])
):
tree = code # type: ignore
else:
msg = "Type: '{type(self.code}}' is not implemented."
raise NotImplementedError(msg)
return tree # type: ignore


class WritableASTBlock:
Expand Down Expand Up @@ -217,34 +249,14 @@ def __init__(
) -> None:
self.prune = prune
self.code = code
self.tree = AST2SCFGTransformer.unparse_code(code)
self.tree = unparse_code(code)
self.block_index: int = 1 # 0 is reserved for genesis block
self.blocks = ASTCFG()
# Initialize first (genesis) block, assume it's named zero.
# (This also initializes the self.current_block attribute.)
self.add_block(0)
self.loop_stack: list[LoopIndices] = []

@staticmethod
def unparse_code(
code: str | list[ast.FunctionDef] | Callable[..., Any]
) -> list[type[ast.AST]]:
# Convert source code into AST.
if isinstance(code, str):
tree = ast.parse(code).body
elif callable(code):
tree = ast.parse(textwrap.dedent(inspect.getsource(code))).body
elif (
isinstance(code, list)
and len(code) > 0
and all([isinstance(i, ast.AST) for i in code])
):
tree = code # type: ignore
else:
msg = "Type: '{type(self.code}}' is not implemented."
raise NotImplementedError(msg)
return tree # type: ignore

def transform_to_ASTCFG(self) -> ASTCFG:
"""Generate ASTCFG from Python function."""
self.transform()
Expand Down Expand Up @@ -329,7 +341,7 @@ def handle_function_def(self, node: ast.FunctionDef) -> None:
# end up being an unreachable block if all other paths through the
# program already call return.
if not isinstance(node.body[-1], ast.Return):
node.body.append(ast.Return(None))
node.body.append(ast.Return())
self.codegen(node.body)

def handle_if(self, node: ast.If) -> None:
Expand Down Expand Up @@ -574,8 +586,8 @@ def function(a: int) -> None
# the CFG.
target = ast.unparse(node.target)
iter_setup = ast.unparse(node.iter)
iter_assign = f"__iterator_{head_index}__"
last_target_value = f"__iter_last_{head_index}__"
iter_assign = f"__scfg_iterator_{head_index}__"
last_target_value = f"__scfg_iter_last_{head_index}__"

# Emit iterator setup to pre-header.
preheader_code = textwrap.dedent(
Expand All @@ -593,14 +605,14 @@ def function(a: int) -> None

# Emit header instructions. This first makes a backup of the iteration
# target and then checks if the iterator is exhausted and if the loop
# should continue. The '__sentinel__' is an singleton style marker, so
# it need not be versioned.
# should continue. The '__scfg__sentinel__' is an singleton style
# marker, so it need not be versioned.

header_code = textwrap.dedent(
f"""
{last_target_value} = {target}
{target} = next({iter_assign}, "__sentinel__")
{target} != "__sentinel__"
{target} = next({iter_assign}, "__scfg_sentinel__")
{target} != "__scfg_sentinel__"
"""
)
self.codegen(ast.parse(header_code).body)
Expand Down Expand Up @@ -655,11 +667,213 @@ def render(self) -> None:
self.blocks.to_SCFG().render()


class SCFG2ASTTransformer:

def transform(
self, original: ast.FunctionDef, scfg: SCFG
) -> ast.FunctionDef:
body: list[ast.AST] = []
self.region_stack = [scfg.region]
self.scfg = scfg
for name, block in scfg.concealed_region_view.items():
if type(block) is RegionBlock and block.kind == "branch":
continue
body.extend(self.codegen(block))
fdef = ast.FunctionDef(
name="transformed_function",
args=original.args,
body=body,
lineno=0,
decorator_list=original.decorator_list,
returns=original.returns,
)
return fdef

def lookup(self, item: Any) -> Any:
subregion_scfg = self.region_stack[-1].subregion
parent_region_block = self.region_stack[-1].parent_region
if item in subregion_scfg: # type: ignore
return subregion_scfg[item] # type: ignore
else:
return self.rlookup(parent_region_block, item) # type: ignore

def rlookup(self, region_block: RegionBlock, item: Any) -> Any:
if item in region_block.subregion: # type: ignore
return region_block.subregion[item] # type: ignore
elif region_block.parent_region is not None:
return self.rlookup(region_block.parent_region, item)
else:
raise KeyError(f"Item {item} not found in subregion or parent")

def codegen(self, block: Any) -> list[ast.AST]:
sklam marked this conversation as resolved.
Show resolved Hide resolved
if type(block) is PythonASTBlock:
if len(block.jump_targets) == 2:
if type(block.tree[-1]) in (ast.Name, ast.Compare):
test = block.tree[-1]
else:
test = block.tree[-1].value # type: ignore
body = self.codegen(self.lookup(block.jump_targets[0]))
orelse = self.codegen(self.lookup(block.jump_targets[1]))
if_node = ast.If(test, body, orelse)
return block.tree[:-1] + [if_node]
elif block.fallthrough and type(block.tree[-1]) is ast.Return:
# The value of the ast.Return could be either None or an
# ast.AST type. In the case of None, this refers to a plain
# 'return', which is implicitly 'return None'. So, if it is
# None, we assign the __scfg_return_value__ an
# ast.Constant(None) and whatever the ast.AST node is
# otherwise.
val = block.tree[-1].value
return block.tree[:-1] + [
ast.Assign(
[ast.Name("__scfg_return_value__")],
(ast.Constant(None) if val is None else val),
lineno=0,
)
]
elif block.fallthrough or block.is_exiting:
return block.tree
else:
raise NotImplementedError
elif type(block) is RegionBlock:
# We maintain a stack of the current region, in order to allow for
# random node lookup by name.
self.region_stack.append(block)

# This is a custom view that uses the concealed_region_view and
# additionally filters all branch regions. Essentially, branch
# regions will be visited by calling codegen recursively from
# blocks with multiple jump targets and all other regions must be
# visited linearly.
def codegen_view() -> list[Any]:
return list(
itertools.chain.from_iterable(
self.codegen(b)
for b in block.subregion.concealed_region_view.values() # type: ignore # noqa
if not (type(b) is RegionBlock and b.kind == "branch")
)
)

if block.kind in ("head", "tail", "branch"):
rval = codegen_view()
elif block.kind == "loop":
# A loop region gives rise to a Python while __scfg_loop_cont__
# loop. We recursively visit the body. The exiting latch will
# update __scfg_loop_continue__.
rval = [
ast.Assign(
[ast.Name("__scfg_loop_cont__")],
ast.Constant(True),
lineno=0,
),
ast.While(
test=ast.Name("__scfg_loop_cont__"),
body=codegen_view(),
orelse=[],
),
]
else:
raise NotImplementedError
self.region_stack.pop()
return rval
elif type(block) is SyntheticAssignment:
# Synthetic assignments just create Python assignments, one for
# each variable..
return [
ast.Assign([ast.Name(t)], ast.Constant(v), lineno=0)
for t, v in block.variable_assignment.items()
]
elif type(block) is SyntheticTail:
# Synthetic tails do nothing.
return []
elif type(block) is SyntheticFill:
# Synthetic fills must have a pass statement to main syntactical
# correctness of the final program.
return [ast.Pass()]
elif type(block) is SyntheticReturn:
# Synthetic return blocks must re-assigne the return value to a
# special reserved variable.
return [ast.Return(ast.Name("__scfg_return_value__"))]
elif type(block) is SyntheticExitingLatch:
# The synthetic exiting latch simply assigns the negated value of
# the exit variable to '__scfg_loop_cont__'.
assert len(block.jump_targets) == 1
assert len(block.backedges) == 1
return [
ast.Assign(
[ast.Name("__scfg_loop_cont__")],
ast.UnaryOp(ast.Not(), ast.Name(block.variable)),
lineno=0,
)
]
elif type(block) in (SyntheticExitBranch, SyntheticHead):
# Both the Synthetic exit branch and the synthetic head contain a
# branching statement with potentially multiple outgoing branches.
# This means we must recursively generate an if-cascade in Python,
# such that all jump targets may be visisted. Looking at the
# resulting AST, it does appear as though the compilation of the
# AST to source code will use `elif` statements.

# Create a reverse lookup from the branch_value_table
# branch_name --> list of variables that lead there
reverse = defaultdict(list)
for (
variable_value,
jump_target,
) in block.branch_value_table.items():
reverse[jump_target].append(variable_value)
# recursive generation of if-cascade

def if_cascade(jump_targets: list[str]) -> list[ast.AST]:
if len(jump_targets) == 1:
# base case, final else
return self.codegen(self.lookup(jump_targets.pop()))
else:
# otherwise generate if statement for current jump_target
current = jump_targets.pop()
# compare to all variable values that point to this
# jump_target
if_test = ast.Compare(
left=ast.Name(block.variable),
ops=[ast.In()],
comparators=[
ast.Tuple(
elts=[
ast.Constant(i) for i in reverse[current]
],
ctx=ast.Load(),
)
],
)
# Create the the if-statement itself, using the test. Do
# code-gen for the block that the is being pointed to and
# recurse for the rest of the jump_targets.
if_node = ast.If(
test=if_test,
body=self.codegen(self.lookup(current)),
orelse=if_cascade(jump_targets),
)
return [if_node]

# Send in a copy of the jump_targets as this list will be mutated.
return if_cascade(list(block.jump_targets[::-1]))
else:
raise NotImplementedError

raise NotImplementedError("unreachable")


def AST2SCFG(code: str | list[ast.FunctionDef] | Callable[..., Any]) -> SCFG:
"""Transform Python function into an SCFG."""
return AST2SCFGTransformer(code).transform_to_SCFG()


def SCFG2AST(scfg: SCFG) -> ast.FunctionDef: # type: ignore
"""Transform SCFG with PythonASTBlocks into an AST FunctionDef."""
# TODO
def SCFG2AST(
code: str | list[ast.FunctionDef] | Callable[..., Any], scfg: SCFG
) -> ast.FunctionDef:
"""Transform SCFG with PythonASTBlocks into an AST FunctionDef defined in
code."""
original_ast = unparse_code(code)[0]
return SCFG2ASTTransformer().transform(
original=original_ast, scfg=scfg # type: ignore
)
4 changes: 2 additions & 2 deletions numba_rvsdg/core/datastructures/scfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,11 @@ def new_var_name(self, kind: str) -> str:
"""
if kind in self.kinds.keys():
idx = self.kinds[kind]
name = str(kind) + "_var_" + str(idx)
name = "__scfg_" + str(kind) + "_var_" + str(idx) + "__"
self.kinds[kind] = idx + 1
else:
idx = 0
name = str(kind) + "_var_" + str(idx)
name = "__scfg_" + str(kind) + "_var_" + str(idx) + "__"
self.kinds[kind] = idx + 1
return name

Expand Down
Loading