diff --git a/src/integration_tests/algorithmic_style_test.py b/src/integration_tests/algorithmic_style_test.py index 9e493b0..8cbe68e 100644 --- a/src/integration_tests/algorithmic_style_test.py +++ b/src/integration_tests/algorithmic_style_test.py @@ -3,31 +3,8 @@ from __future__ import annotations import textwrap -from typing import Any, Callable -from latexify import generate_latex - - -def check_algorithm( - fn: Callable[..., Any], - latex: str, - **kwargs, -) -> None: - """Helper to check if the obtained function has the expected LaTeX form. - - Args: - fn: Function to check. - latex: LaTeX form of `fn`. - **kwargs: Arguments passed to `frontend.get_latex`. - """ - # Checks the syntax: - # def fn(...): - # ... - # latexified = get_latex(fn, style=ALGORITHM, **kwargs) - latexified = generate_latex.get_latex( - fn, style=generate_latex.Style.ALGORITHMIC, **kwargs - ) - assert latexified == latex +from integration_tests import integration_utils def test_factorial() -> None: @@ -50,7 +27,20 @@ def fact(n): \end{algorithmic} """ # noqa: E501 ).strip() - check_algorithm(fact, latex) + ipython_latex = ( + r"\begin{array}{l}" + r" \mathbf{function} \ \mathrm{fact}(n) \\" + r" \hspace{1em} \mathbf{if} \ n = 0 \\" + r" \hspace{2em} \mathbf{return} \ 1 \\" + r" \hspace{1em} \mathbf{else} \\" + r" \hspace{2em}" + r" \mathbf{return} \ n \cdot" + r" \mathrm{fact} \mathopen{}\left( n - 1 \mathclose{}\right) \\" + r" \hspace{1em} \mathbf{end \ if} \\" + r" \mathbf{end \ function}" + r" \end{array}" + ) + integration_utils.check_algorithm(fact, latex, ipython_latex) def test_collatz() -> None: @@ -82,4 +72,21 @@ def collatz(n): \end{algorithmic} """ ).strip() - check_algorithm(collatz, latex) + ipython_latex = ( + r"\begin{array}{l}" + r" \mathbf{function} \ \mathrm{collatz}(n) \\" + r" \hspace{1em} \mathrm{iterations} \gets 0 \\" + r" \hspace{1em} \mathbf{while} \ n > 1 \\" + r" \hspace{2em} \mathbf{if} \ n \mathbin{\%} 2 = 0 \\" + r" \hspace{3em} n \gets \left\lfloor\frac{n}{2}\right\rfloor \\" + r" \hspace{2em} \mathbf{else} \\" + r" \hspace{3em} n \gets 3 \cdot n + 1 \\" + r" \hspace{2em} \mathbf{end \ if} \\" + r" \hspace{2em}" + r" \mathrm{iterations} \gets \mathrm{iterations} + 1 \\" + r" \hspace{1em} \mathbf{end \ while} \\" + r" \hspace{1em} \mathbf{return} \ \mathrm{iterations} \\" + r" \mathbf{end \ function}" + r" \end{array}" + ) + integration_utils.check_algorithm(collatz, latex, ipython_latex) diff --git a/src/integration_tests/integration_utils.py b/src/integration_tests/integration_utils.py index 92ada9a..5ffe7a0 100644 --- a/src/integration_tests/integration_utils.py +++ b/src/integration_tests/integration_utils.py @@ -43,3 +43,43 @@ def check_function( latexified = frontend.function(fn, **kwargs) assert str(latexified) == latex assert latexified._repr_latex_() == rf"$$ \displaystyle {latex} $$" + + +def check_algorithm( + fn: Callable[..., Any], + latex: str, + ipython_latex: str, + **kwargs, +) -> None: + """Helper to check if the obtained function has the expected LaTeX form. + + Args: + fn: Function to check. + latex: LaTeX form of `fn`. + ipython_latex: IPython LaTeX form of `fn` + **kwargs: Arguments passed to `frontend.get_latex`. + """ + # Checks the syntax: + # @algorithmic + # def fn(...): + # ... + if not kwargs: + latexified = frontend.algorithmic(fn) + assert str(latexified) == latex + assert latexified._repr_latex_() == f"$ {ipython_latex} $" + + # Checks the syntax: + # @algorithmic(**kwargs) + # def fn(...): + # ... + latexified = frontend.algorithmic(**kwargs)(fn) + assert str(latexified) == latex + assert latexified._repr_latex_() == f"$ {ipython_latex} $" + + # Checks the syntax: + # def fn(...): + # ... + # latexified = algorithmic(fn, **kwargs) + latexified = frontend.algorithmic(fn, **kwargs) + assert str(latexified) == latex + assert latexified._repr_latex_() == f"$ {ipython_latex} $" diff --git a/src/latexify/codegen/__init__.py b/src/latexify/codegen/__init__.py index 1aea2c3..8d09290 100644 --- a/src/latexify/codegen/__init__.py +++ b/src/latexify/codegen/__init__.py @@ -5,3 +5,4 @@ AlgorithmicCodegen = algorithmic_codegen.AlgorithmicCodegen ExpressionCodegen = expression_codegen.ExpressionCodegen FunctionCodegen = function_codegen.FunctionCodegen +IPythonAlgorithmicCodegen = algorithmic_codegen.IPythonAlgorithmicCodegen diff --git a/src/latexify/codegen/algorithmic_codegen.py b/src/latexify/codegen/algorithmic_codegen.py index 0460ffd..685663c 100644 --- a/src/latexify/codegen/algorithmic_codegen.py +++ b/src/latexify/codegen/algorithmic_codegen.py @@ -60,6 +60,7 @@ def visit_Expr(self, node: ast.Expr) -> str: rf"\State ${self._expression_codegen.visit(node.value)}$" ) + # TODO(ZibingZhang): support nested functions def visit_FunctionDef(self, node: ast.FunctionDef) -> str: """Visit a FunctionDef node.""" # Arguments @@ -89,14 +90,14 @@ def visit_If(self, node: ast.If) -> str: with self._increment_level(): body_latex = "\n".join(self.visit(stmt) for stmt in node.body) - latex = self._add_indent(f"\\If{{${cond_latex}$}}\n{body_latex}") + latex = self._add_indent(f"\\If{{${cond_latex}$}}\n" + body_latex) if node.orelse: - latex += "\n" + self._add_indent(r"\Else") + "\n" + latex += "\n" + self._add_indent("\\Else\n") with self._increment_level(): latex += "\n".join(self.visit(stmt) for stmt in node.orelse) - return latex + "\n" + self._add_indent(r"\EndIf") + return f"{latex}\n" + self._add_indent(r"\EndIf") def visit_Module(self, node: ast.Module) -> str: """Visit a Module node.""" @@ -136,9 +137,145 @@ def _increment_level(self) -> Generator[None, None, None]: self._indent_level -= 1 def _add_indent(self, line: str) -> str: - """Adds whitespace before the line. + """Adds an indent before the line. Args: - line: The line to add whitespace to. + line: The line to add an indent to. """ return self._indent_level * self._SPACES_PER_INDENT * " " + line + + +class IPythonAlgorithmicCodegen(ast.NodeVisitor): + """Codegen for single algorithms targeting IPython. + + This codegen works for Module with single FunctionDef node to generate a single + LaTeX expression of the given algorithm. + """ + + _EM_PER_INDENT = 1 + _LINE_BREAK = r" \\ " + + _identifier_converter: identifier_converter.IdentifierConverter + _indent_level: int + + def __init__( + self, *, use_math_symbols: bool = False, use_set_symbols: bool = False + ) -> None: + """Initializer. + + Args: + use_math_symbols: Whether to convert identifiers with a math symbol surface + (e.g., "alpha") to the LaTeX symbol (e.g., "\\alpha"). + use_set_symbols: Whether to use set symbols or not. + """ + self._expression_codegen = expression_codegen.ExpressionCodegen( + use_math_symbols=use_math_symbols, use_set_symbols=use_set_symbols + ) + self._identifier_converter = identifier_converter.IdentifierConverter( + use_math_symbols=use_math_symbols + ) + self._indent_level = 0 + + def generic_visit(self, node: ast.AST) -> str: + raise exceptions.LatexifyNotSupportedError( + f"Unsupported AST: {type(node).__name__}" + ) + + def visit_Assign(self, node: ast.Assign) -> str: + """Visit an Assign node.""" + operands: list[str] = [ + self._expression_codegen.visit(target) for target in node.targets + ] + operands.append(self._expression_codegen.visit(node.value)) + operands_latex = r" \gets ".join(operands) + return self._add_indent(operands_latex) + + def visit_Expr(self, node: ast.Expr) -> str: + """Visit an Expr node.""" + return self._add_indent(self._expression_codegen.visit(node.value)) + + # TODO(ZibingZhang): support nested functions + def visit_FunctionDef(self, node: ast.FunctionDef) -> str: + """Visit a FunctionDef node.""" + # Arguments + arg_strs = [ + self._identifier_converter.convert(arg.arg)[0] for arg in node.args.args + ] + # Body + with self._increment_level(): + body_strs: list[str] = [self.visit(stmt) for stmt in node.body] + body = self._LINE_BREAK.join(body_strs) + + return ( + r"\begin{array}{l} " + + self._add_indent(r"\mathbf{function}") + + rf" \ \mathrm{{{node.name}}}({', '.join(arg_strs)})" + + f"{self._LINE_BREAK}{body}{self._LINE_BREAK}" + + self._add_indent(r"\mathbf{end \ function}") + + r" \end{array}" + ) + + # TODO(ZibingZhang): support \ELSIF + def visit_If(self, node: ast.If) -> str: + """Visit an If node.""" + cond_latex = self._expression_codegen.visit(node.test) + with self._increment_level(): + body_latex = self._LINE_BREAK.join(self.visit(stmt) for stmt in node.body) + latex = self._add_indent( + rf"\mathbf{{if}} \ {cond_latex}{self._LINE_BREAK}{body_latex}" + ) + + if node.orelse: + latex += self._LINE_BREAK + self._add_indent(r"\mathbf{else} \\ ") + with self._increment_level(): + latex += self._LINE_BREAK.join(self.visit(stmt) for stmt in node.orelse) + + return latex + self._LINE_BREAK + self._add_indent(r"\mathbf{end \ if}") + + def visit_Module(self, node: ast.Module) -> str: + """Visit a Module node.""" + return self.visit(node.body[0]) + + def visit_Return(self, node: ast.Return) -> str: + """Visit a Return node.""" + return ( + self._add_indent(r"\mathbf{return} \ ") + + self._expression_codegen.visit(node.value) + if node.value is not None + else self._add_indent(r"\mathbf{return}") + ) + + def visit_While(self, node: ast.While) -> str: + """Visit a While node.""" + if node.orelse: + raise exceptions.LatexifyNotSupportedError( + "While statement with the else clause is not supported" + ) + + cond_latex = self._expression_codegen.visit(node.test) + with self._increment_level(): + body_latex = self._LINE_BREAK.join(self.visit(stmt) for stmt in node.body) + return ( + self._add_indent(r"\mathbf{while} \ ") + + f"{cond_latex}{self._LINE_BREAK}{body_latex}{self._LINE_BREAK}" + + self._add_indent(r"\mathbf{end \ while}") + ) + + @contextlib.contextmanager + def _increment_level(self) -> Generator[None, None, None]: + """Context manager controlling indent level.""" + self._indent_level += 1 + yield + self._indent_level -= 1 + + def _add_indent(self, line: str) -> str: + """Adds an indent before the line. + + Args: + line: The line to add an indent to. + """ + return ( + rf"\hspace{{{self._indent_level * self._EM_PER_INDENT}em}} {line}" + if self._indent_level > 0 + else line + ) diff --git a/src/latexify/codegen/algorithmic_codegen_test.py b/src/latexify/codegen/algorithmic_codegen_test.py index 80a2d0c..a0781c7 100644 --- a/src/latexify/codegen/algorithmic_codegen_test.py +++ b/src/latexify/codegen/algorithmic_codegen_test.py @@ -165,3 +165,135 @@ def test_visit_while_with_else() -> None: match="^While statement with the else clause is not supported$", ): algorithmic_codegen.AlgorithmicCodegen().visit(node) + + +@pytest.mark.parametrize( + "code,latex", + [ + ("x = 3", r"x \gets 3"), + ("a = b = 0", r"a \gets b \gets 0"), + ], +) +def test_visit_assign_jupyter(code: str, latex: str) -> None: + node = ast.parse(textwrap.dedent(code)).body[0] + assert isinstance(node, ast.Assign) + assert algorithmic_codegen.IPythonAlgorithmicCodegen().visit(node) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + ( + "def f(x): return x", + ( + r"\begin{array}{l}" + r" \mathbf{function}" + r" \ \mathrm{f}(x) \\" + r" \hspace{1em} \mathbf{return} \ x \\" + r" \mathbf{end \ function}" + r" \end{array}" + ), + ), + ( + "def f(a, b, c): return 3", + ( + r"\begin{array}{l}" + r" \mathbf{function}" + r" \ \mathrm{f}(a, b, c) \\" + r" \hspace{1em} \mathbf{return} \ 3 \\" + r" \mathbf{end \ function}" + r" \end{array}" + ), + ), + ], +) +def test_visit_functiondef_ipython(code: str, latex: str) -> None: + node = ast.parse(textwrap.dedent(code)).body[0] + assert isinstance(node, ast.FunctionDef) + assert algorithmic_codegen.IPythonAlgorithmicCodegen().visit(node) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + ( + "if x < y: return x", + ( + r"\mathbf{if} \ x < y \\" + r" \hspace{1em} \mathbf{return} \ x \\" + r" \mathbf{end \ if}" + ), + ), + ( + "if True: x\nelse: y", + ( + r"\mathbf{if} \ \mathrm{True} \\" + r" \hspace{1em} x \\" + r" \mathbf{else} \\" + r" \hspace{1em} y \\" + r" \mathbf{end \ if}" + ), + ), + ], +) +def test_visit_if_ipython(code: str, latex: str) -> None: + node = ast.parse(textwrap.dedent(code)).body[0] + assert isinstance(node, ast.If) + assert algorithmic_codegen.IPythonAlgorithmicCodegen().visit(node) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + ( + "return x + y", + r"\mathbf{return} \ x + y", + ), + ( + "return", + r"\mathbf{return}", + ), + ], +) +def test_visit_return_ipython(code: str, latex: str) -> None: + node = ast.parse(textwrap.dedent(code)).body[0] + assert isinstance(node, ast.Return) + assert algorithmic_codegen.IPythonAlgorithmicCodegen().visit(node) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + ( + "while x < y: x = x + 1", + ( + r"\mathbf{while} \ x < y \\" + r" \hspace{1em} x \gets x + 1 \\" + r" \mathbf{end \ while}" + ), + ) + ], +) +def test_visit_while_ipython(code: str, latex: str) -> None: + node = ast.parse(textwrap.dedent(code)).body[0] + assert isinstance(node, ast.While) + assert algorithmic_codegen.IPythonAlgorithmicCodegen().visit(node) == latex + + +def test_visit_while_with_else_ipython() -> None: + node = ast.parse( + textwrap.dedent( + """ + while True: + x = x + else: + x = y + """ + ) + ).body[0] + assert isinstance(node, ast.While) + with pytest.raises( + exceptions.LatexifyNotSupportedError, + match="^While statement with the else clause is not supported$", + ): + algorithmic_codegen.IPythonAlgorithmicCodegen().visit(node) diff --git a/src/latexify/generate_latex.py b/src/latexify/generate_latex.py index 2d8d817..f734bdf 100644 --- a/src/latexify/generate_latex.py +++ b/src/latexify/generate_latex.py @@ -8,7 +8,7 @@ from latexify import codegen from latexify import config as cfg -from latexify import exceptions, parser, transformers +from latexify import parser, transformers class Style(enum.Enum): @@ -63,14 +63,16 @@ def get_latex( use_math_symbols=merged_config.use_math_symbols, use_set_symbols=merged_config.use_set_symbols, ).visit(tree) - elif style == Style.IPYTHON_ALGORITHMIC: - # TODO(ZibingZhang): implement algorithmic codegen for ipython - raise exceptions.LatexifyNotSupportedError elif style == Style.FUNCTION: return codegen.FunctionCodegen( use_math_symbols=merged_config.use_math_symbols, use_signature=merged_config.use_signature, use_set_symbols=merged_config.use_set_symbols, ).visit(tree) + elif style == Style.IPYTHON_ALGORITHMIC: + return codegen.IPythonAlgorithmicCodegen( + use_math_symbols=merged_config.use_math_symbols, + use_set_symbols=merged_config.use_set_symbols, + ).visit(tree) raise ValueError(f"Unrecognized style: {style}") diff --git a/src/latexify/ipython_wrappers.py b/src/latexify/ipython_wrappers.py index 88aa015..c77a869 100644 --- a/src/latexify/ipython_wrappers.py +++ b/src/latexify/ipython_wrappers.py @@ -95,7 +95,7 @@ def _repr_html_(self) -> str | tuple[str, dict[str, Any]] | None: def _repr_latex_(self) -> str | tuple[str, dict[str, Any]] | None: """IPython hook to display LaTeX visualization.""" return ( - r"$ " + self._ipython_latex + " $" + f"$ {self._ipython_latex} $" if self._ipython_latex is not None else self._ipython_error ) @@ -133,7 +133,7 @@ def _repr_html_(self) -> str | tuple[str, dict[str, Any]] | None: def _repr_latex_(self) -> str | tuple[str, dict[str, Any]] | None: """IPython hook to display LaTeX visualization.""" return ( - r"$$ \displaystyle " + self._latex + " $$" + rf"$$ \displaystyle {self._latex} $$" if self._latex is not None else self._error )