From a56a6dd6ebc4493ba2b5a839d9a44de6572aa2bc Mon Sep 17 00:00:00 2001 From: odashi Date: Sat, 18 Nov 2023 21:18:51 +0000 Subject: [PATCH] fix wrappign around pow --- src/latexify/codegen/expression_codegen.py | 16 +++++++++++----- .../codegen/expression_codegen_test.py | 19 +++++++++++++++++++ 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/src/latexify/codegen/expression_codegen.py b/src/latexify/codegen/expression_codegen.py index f88869c..9239d72 100644 --- a/src/latexify/codegen/expression_codegen.py +++ b/src/latexify/codegen/expression_codegen.py @@ -408,16 +408,22 @@ def visit_Call(self, node: ast.Call) -> str: if rule.is_unary and len(node.args) == 1: # Unary function. Applies the same wrapping policy with the unary operators. + precedence = expression_rules.get_precedence(node) + arg = node.args[0] # NOTE(odashi): # Factorial "x!" is treated as a special case: it requires both inner/outer # parentheses for correct interpretation. - precedence = expression_rules.get_precedence(node) - arg = node.args[0] - force_wrap = isinstance(arg, ast.Call) and ( + force_wrap_factorial = isinstance(arg, ast.Call) and ( func_name == "factorial" or ast_utils.extract_function_name_or_none(arg) == "factorial" ) - arg_latex = self._wrap_operand(arg, precedence, force_wrap) + # Note(odashi): + # Wrapping is also required if the argument is pow. + # https://github.com/google/latexify_py/issues/189 + force_wrap_pow = isinstance(arg, ast.BinOp) and isinstance(arg.op, ast.Pow) + arg_latex = self._wrap_operand( + arg, precedence, force_wrap_factorial or force_wrap_pow + ) elements = [rule.left, arg_latex, rule.right] else: arg_latex = ", ".join(self.visit(arg) for arg in node.args) @@ -490,7 +496,7 @@ def _wrap_operand( latex = self.visit(child) child_prec = expression_rules.get_precedence(child) - if child_prec < parent_prec or force_wrap and child_prec == parent_prec: + if force_wrap or child_prec < parent_prec: return rf"\mathopen{{}}\left( {latex} \mathclose{{}}\right)" return latex diff --git a/src/latexify/codegen/expression_codegen_test.py b/src/latexify/codegen/expression_codegen_test.py index 8ed960b..e869777 100644 --- a/src/latexify/codegen/expression_codegen_test.py +++ b/src/latexify/codegen/expression_codegen_test.py @@ -218,6 +218,25 @@ def test_visit_call(code: str, latex: str) -> None: assert expression_codegen.ExpressionCodegen().visit(node) == latex +@pytest.mark.parametrize( + "code,latex", + [ + ("log(x)**2", r"\mathopen{}\left( \log x \mathclose{}\right)^{2}"), + ("log(x**2)", r"\log \mathopen{}\left( x^{2} \mathclose{}\right)"), + ( + "log(x**2)**3", + r"\mathopen{}\left(" + r" \log \mathopen{}\left( x^{2} \mathclose{}\right)" + r" \mathclose{}\right)^{3}", + ), + ], +) +def test_visit_call_with_pow(code: str, latex: str) -> None: + node = ast_utils.parse_expr(code) + assert isinstance(node, (ast.Call, ast.BinOp)) + assert expression_codegen.ExpressionCodegen().visit(node) == latex + + @pytest.mark.parametrize( "src_suffix,dest_suffix", [