From 66e604042f1d56a918c28abfd212bae91b4045dc Mon Sep 17 00:00:00 2001 From: Fergal Walsh Date: Fri, 22 Dec 2023 17:17:15 +0000 Subject: [PATCH] Add router --- tealish/nodes.py | 121 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 116 insertions(+), 5 deletions(-) diff --git a/tealish/nodes.py b/tealish/nodes.py index 22de921..63197ad 100644 --- a/tealish/nodes.py +++ b/tealish/nodes.py @@ -16,7 +16,7 @@ from .base import BaseNode from .errors import CompileError, ParseError from .tx_expressions import parse_expression -from .tealish_builtins import Var +from .tealish_builtins import Var, constants from .types import ( AVMType, AnyType, @@ -249,6 +249,8 @@ def consume(cls, compiler: "TealishCompiler", parent: Node) -> "Statement": return InnerTxn.consume(compiler, parent) elif line.startswith("struct "): return StructDefinition.consume(compiler, parent) + elif line.startswith("router:"): + return Router.consume(compiler, parent) else: return LineStatement.consume(compiler, parent) @@ -287,16 +289,16 @@ def consume(cls, compiler: "TealishCompiler", parent: Optional[Node]) -> "Progra expect_struct_definition = False if exit_statement: - if not isinstance(n, (Func, Block, Comment, Blank)): + if not isinstance(n, (Func, DecoratedFunc, Block, Comment, Blank)): raise ParseError( f"Unexpected statement at line {n.line_no}." + f" Only Block and Function definitions should occure after a {exit_statement}." ) else: - if isinstance(n, (Func, Block)): + if isinstance(n, (Func, DecoratedFunc, Block)): raise ParseError( f"Unexpected {n} definition at line {n.line_no}. " - + "Block and Function definitions must occur after an exit statement (e.g Exit, switch, jump)." + + "Block and Function definitions must occur after an exit statement (e.g Exit, switch, jump, router)." ) if is_exit_statement(n): exit_statement = n @@ -795,6 +797,115 @@ def _tealish(self) -> str: return output +class Route(Node): + pattern = r"(?P.*)" + name: str + + @property + def label(self): + return f"route_{self.name}" + + def process(self) -> None: + self.arg_expressions = [] + self.func = self.lookup_func(self.name) + if "public" not in self.func.attributes: + raise CompileError(f"{self.name} is not a public function", node=self) + for i, (arg, type_name) in enumerate(self.func.args.args): + a = i + 1 + arg_type = get_type_instance(type_name) + if type_name == "bytes": + line = f"Txn.ApplicationArgs[{a}]" + elif type_name == "int": + line = f"btoi(Txn.ApplicationArgs[{a}])" + elif isinstance(arg_type, IntType) and arg_type.size != 8: + line = f"Cast(btoi(Txn.ApplicationArgs[{a}]), {type_name})" + elif isinstance(arg_type, BytesType): + line = f"Cast(Txn.ApplicationArgs[{a}], {type_name})" + expression = GenericExpression.parse(line, self, self.compiler) + expression.process() + self.arg_expressions.append(expression) + + def _tealish(self) -> str: + output = f"{self.name}\n" + return output + + +class Router(InlineStatement): + possible_child_nodes = [Route] + pattern = r"router:$" + + def __init__( + self, + line: str, + parent: Optional[Node] = None, + compiler: Optional["TealishCompiler"] = None, + ) -> None: + super().__init__(line, parent, compiler) + self.routes: List[Route] = [] + + def add_route(self, node: Route) -> None: + self.routes.append(node) + self.add_child(node) + + @classmethod + def consume(cls, compiler: "TealishCompiler", parent: Optional[Node]) -> "Switch": + router = Router(compiler.consume_line(), parent, compiler=compiler) + while True: + if compiler.peek() == "end": + compiler.consume_line() + break + router.add_route(Route(compiler.consume_line(), router, compiler=compiler)) + return router + + def process(self) -> None: + for node in self.child_nodes: + node.process() + + def write_teal(self, writer: "TealWriter") -> None: + writer.write(self, f"// {self.line}") + for route in self.routes: + writer.write(self, "txna ApplicationArgs 0") + writer.write(self, f'pushbytes "{route.name}"') + writer.write(self, "==") + writer.write(self, f"bnz {route.label}") + writer.write(self, "err // unexpected value") + + for i, route in enumerate(self.routes): + writer.write(self, f"{route.label}:") + func = self.lookup_func(route.name) + oc = func.attributes["public"].get("OnCompletion", "NoOp") + if oc == "CreateApplication": + writer.write(self, "txn ApplicationID") + writer.write(self, "pushint 0") + writer.write(self, "==") + writer.write(self, "assert // ApplicationID == 0") + else: + writer.write(self, "txn OnCompletion") + writer.write(self, f"pushint {constants[oc][1]} // {oc}") + writer.write(self, "==") + writer.write(self, "assert") + + for arg_expression in route.arg_expressions: + writer.write(self, f"// {arg_expression.tealish()}") + writer.write(self, arg_expression) + writer.write(self, f"callsub {func.label}") + if func.returns: + if isinstance(func.returns[0], IntType): + writer.write(self, "itob") + writer.write(self, "pushbytes 0x151f7c75 // arc4 return prefix") + writer.write(self, "concat") + writer.write(self, "log") + writer.write(self, "pushint 1") + writer.write(self, "return") + + def _tealish(self) -> str: + output = "router:\n" + for n in self.child_nodes: + output += indent(n.tealish()) + output += "end\n" + return output + + class TealLine(Node): def write_teal(self, writer: "TealWriter") -> None: writer.write(self, f"{self.line}") @@ -1892,5 +2003,5 @@ def indent(s: str) -> str: def is_exit_statement(node): - if isinstance(node, (Exit, Switch, Jump)): + if isinstance(node, (Exit, Switch, Jump, Router)): return True