From 4ebd5350dc9274abc4d32241b9e339056f84b316 Mon Sep 17 00:00:00 2001 From: Yusuke Oda Date: Fri, 17 Nov 2023 09:08:28 +0900 Subject: [PATCH] Add AugAssign support (#193) --- src/latexify/generate_latex.py | 5 +++- src/latexify/generate_latex_test.py | 14 +++++++++++ src/latexify/transformers/__init__.py | 2 ++ .../transformers/aug_assign_replacer.py | 20 ++++++++++++++++ .../transformers/aug_assign_replacer_test.py | 24 +++++++++++++++++++ 5 files changed, 64 insertions(+), 1 deletion(-) create mode 100644 src/latexify/transformers/aug_assign_replacer.py create mode 100644 src/latexify/transformers/aug_assign_replacer_test.py diff --git a/src/latexify/generate_latex.py b/src/latexify/generate_latex.py index f734bdf..dbcf0e8 100644 --- a/src/latexify/generate_latex.py +++ b/src/latexify/generate_latex.py @@ -47,7 +47,10 @@ def get_latex( # Obtains the source AST. tree = parser.parse_function(fn) - # Applies AST transformations. + # Mandatory AST Transformation. + tree = transformers.AugAssignReplacer().visit(tree) + + # Conditional AST transformation. if merged_config.prefixes is not None: tree = transformers.PrefixTrimmer(merged_config.prefixes).visit(tree) if merged_config.identifiers is not None: diff --git a/src/latexify/generate_latex_test.py b/src/latexify/generate_latex_test.py index 6f128b0..bc61b97 100644 --- a/src/latexify/generate_latex_test.py +++ b/src/latexify/generate_latex_test.py @@ -54,6 +54,20 @@ def f(x): assert generate_latex.get_latex(f, reduce_assignments=True) == latex_with_flag +def test_get_latex_reduce_assignments_with_aug_assign() -> None: + def f(x): + y = 3 + y *= x + return y + + latex_without_flag = r"\begin{array}{l} y = 3 \\ y = y x \\ f(x) = y \end{array}" + latex_with_flag = r"f(x) = 3 x" + + assert generate_latex.get_latex(f) == latex_without_flag + assert generate_latex.get_latex(f, reduce_assignments=False) == latex_without_flag + assert generate_latex.get_latex(f, reduce_assignments=True) == latex_with_flag + + def test_get_latex_use_math_symbols() -> None: def f(alpha): return alpha diff --git a/src/latexify/transformers/__init__.py b/src/latexify/transformers/__init__.py index 7ec48c6..79c9e21 100644 --- a/src/latexify/transformers/__init__.py +++ b/src/latexify/transformers/__init__.py @@ -1,12 +1,14 @@ """Package latexify.transformers.""" from latexify.transformers.assignment_reducer import AssignmentReducer +from latexify.transformers.aug_assign_replacer import AugAssignReplacer from latexify.transformers.function_expander import FunctionExpander from latexify.transformers.identifier_replacer import IdentifierReplacer from latexify.transformers.prefix_trimmer import PrefixTrimmer __all__ = [ "AssignmentReducer", + "AugAssignReplacer", "FunctionExpander", "IdentifierReplacer", "PrefixTrimmer", diff --git a/src/latexify/transformers/aug_assign_replacer.py b/src/latexify/transformers/aug_assign_replacer.py new file mode 100644 index 0000000..0e4df66 --- /dev/null +++ b/src/latexify/transformers/aug_assign_replacer.py @@ -0,0 +1,20 @@ +"""Transformer to replace AugAssign to Assign.""" + +from __future__ import annotations + +import ast + + +class AugAssignReplacer(ast.NodeTransformer): + """NodeTransformer to replace AugAssign to corresponding Assign. + + AugAssign(target, op, value) => Assign([target], BinOp(target, op, value)) + + """ + + def visit_AugAssign(self, node: ast.AugAssign) -> ast.Assign: + left_args = {**vars(node.target), "ctx": ast.Load()} + left = type(node.target)(**left_args) + return ast.Assign( + targets=[node.target], value=ast.BinOp(left, node.op, node.value) + ) diff --git a/src/latexify/transformers/aug_assign_replacer_test.py b/src/latexify/transformers/aug_assign_replacer_test.py new file mode 100644 index 0000000..9dc009e --- /dev/null +++ b/src/latexify/transformers/aug_assign_replacer_test.py @@ -0,0 +1,24 @@ +"""Tests for latexify.transformers.aug_assign_replacer.""" + +import ast + +from latexify import test_utils +from latexify.transformers.aug_assign_replacer import AugAssignReplacer + + +def test_replace() -> None: + tree = ast.AugAssign( + target=ast.Name(id="x", ctx=ast.Store()), + op=ast.Add(), + value=ast.Name(id="y", ctx=ast.Load()), + ) + expected = ast.Assign( + targets=[ast.Name(id="x", ctx=ast.Store())], + value=ast.BinOp( + left=ast.Name(id="x", ctx=ast.Load()), + op=ast.Add(), + right=ast.Name(id="y", ctx=ast.Load()), + ), + ) + transformed = AugAssignReplacer().visit(tree) + test_utils.assert_ast_equal(transformed, expected)