Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Algorithmic decorator (implement output for _repr_latex_) #163

Merged
merged 31 commits into from
Dec 15, 2022
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 33 additions & 26 deletions src/integration_tests/algorithmic_style_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,8 @@
from __future__ import annotations

import textwrap
from typing import Any, Callable

from latexify import generate_latex


def check_algorithm(
fn: Callable[..., Any],
latex: str,
**kwargs,
) -> None:
"""Helper to check if the obtained function has the expected LaTeX form.

Args:
fn: Function to check.
latex: LaTeX form of `fn`.
**kwargs: Arguments passed to `frontend.get_latex`.
"""
# Checks the syntax:
# def fn(...):
# ...
# latexified = get_latex(fn, style=ALGORITHM, **kwargs)
latexified = generate_latex.get_latex(
fn, style=generate_latex.Style.ALGORITHMIC, **kwargs
)
assert latexified == latex
from integration_tests import integration_utils


def test_factorial() -> None:
Expand All @@ -50,7 +27,20 @@ def fact(n):
\end{algorithmic}
""" # noqa: E501
).strip()
check_algorithm(fact, latex)
ipython_latex = (
r"\begin{array}{l}"
r" \mathbf{function} \ \mathrm{fact}(n) \\"
r" \hspace{1em} \mathbf{if} \ n = 0 \\"
r" \hspace{2em} \mathbf{return} \ 1 \\"
r" \hspace{1em} \mathbf{else} \\"
r" \hspace{2em}"
r" \mathbf{return} \ n \cdot"
r" \mathrm{fact} \mathopen{}\left( n - 1 \mathclose{}\right) \\"
r" \hspace{1em} \mathbf{end \ if} \\"
r" \mathbf{end \ function}"
r" \end{array}"
)
integration_utils.check_algorithm(fact, latex, ipython_latex)


def test_collatz() -> None:
Expand Down Expand Up @@ -82,4 +72,21 @@ def collatz(n):
\end{algorithmic}
"""
).strip()
check_algorithm(collatz, latex)
ipython_latex = (
r"\begin{array}{l}"
r" \mathbf{function} \ \mathrm{collatz}(n) \\"
r" \hspace{1em} \mathrm{iterations} \gets 0 \\"
r" \hspace{1em} \mathbf{while} \ n > 1 \\"
r" \hspace{2em} \mathbf{if} \ n \mathbin{\%} 2 = 0 \\"
r" \hspace{3em} n \gets \left\lfloor\frac{n}{2}\right\rfloor \\"
r" \hspace{2em} \mathbf{else} \\"
r" \hspace{3em} n \gets 3 \cdot n + 1 \\"
r" \hspace{2em} \mathbf{end \ if} \\"
r" \hspace{2em}"
r" \mathrm{iterations} \gets \mathrm{iterations} + 1 \\"
r" \hspace{1em} \mathbf{end \ while} \\"
r" \hspace{1em} \mathbf{return} \ \mathrm{iterations} \\"
r" \mathbf{end \ function}"
r" \end{array}"
)
integration_utils.check_algorithm(collatz, latex, ipython_latex)
46 changes: 43 additions & 3 deletions src/integration_tests/integration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,60 @@ def check_function(
if not kwargs:
latexified = frontend.function(fn)
assert str(latexified) == latex
assert latexified._repr_latex_() == rf"$$ \displaystyle {latex} $$"
assert latexified._repr_latex_() == r"$$ \displaystyle " + latex + " $$"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Restore the original expression, and don't use "+" concatenation as well as possible. f-strings or join() are basically the most fastest option to concatenate strings and "+" is the worst choice.

Ditto for all other places.

def f():
  for i in range(1_000_000):
    x = f"The answer is {i}."
  return x

def g():
  for i in range(1_000_000):
    x = "The answer is " + str(i) + "."
  return x

%time f()
%time g()
CPU times: user 191 ms, sys: 0 ns, total: 191 ms
Wall time: 192 ms
CPU times: user 327 ms, sys: 0 ns, total: 327 ms
Wall time: 327 ms

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good to know, I think I've converted it all back. When codegen though lots of places can't use f strings because of \ not allowed :(


# Checks the syntax:
# @function(**kwargs)
# def fn(...):
# ...
latexified = frontend.function(**kwargs)(fn)
assert str(latexified) == latex
assert latexified._repr_latex_() == rf"$$ \displaystyle {latex} $$"
assert latexified._repr_latex_() == r"$$ \displaystyle " + latex + " $$"

# Checks the syntax:
# def fn(...):
# ...
# latexified = function(fn, **kwargs)
latexified = frontend.function(fn, **kwargs)
assert str(latexified) == latex
assert latexified._repr_latex_() == rf"$$ \displaystyle {latex} $$"
assert latexified._repr_latex_() == r"$$ \displaystyle " + latex + " $$"


def check_algorithm(
fn: Callable[..., Any],
latex: str,
ipython_latex: str,
**kwargs,
) -> None:
"""Helper to check if the obtained function has the expected LaTeX form.

Args:
fn: Function to check.
latex: LaTeX form of `fn`.
ipython_latex: IPython LaTeX form of `fn`
**kwargs: Arguments passed to `frontend.get_latex`.
"""
# Checks the syntax:
# @algorithmic
# def fn(...):
# ...
if not kwargs:
latexified = frontend.algorithmic(fn)
assert str(latexified) == latex
assert latexified._repr_latex_() == "$ " + ipython_latex + " $"

# Checks the syntax:
# @algorithmic(**kwargs)
# def fn(...):
# ...
latexified = frontend.algorithmic(**kwargs)(fn)
assert str(latexified) == latex
assert latexified._repr_latex_() == "$ " + ipython_latex + " $"

# Checks the syntax:
# def fn(...):
# ...
# latexified = algorithmic(fn, **kwargs)
latexified = frontend.algorithmic(fn, **kwargs)
assert str(latexified) == latex
assert latexified._repr_latex_() == "$ " + ipython_latex + " $"
1 change: 1 addition & 0 deletions src/latexify/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
AlgorithmicCodegen = algorithmic_codegen.AlgorithmicCodegen
ExpressionCodegen = expression_codegen.ExpressionCodegen
FunctionCodegen = function_codegen.FunctionCodegen
IPythonAlgorithmicCodegen = algorithmic_codegen.IPythonAlgorithmicCodegen
156 changes: 150 additions & 6 deletions src/latexify/codegen/algorithmic_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def visit_Expr(self, node: ast.Expr) -> str:
rf"\State ${self._expression_codegen.visit(node.value)}$"
)

# TODO(ZibingZhang): support nested functions
def visit_FunctionDef(self, node: ast.FunctionDef) -> str:
"""Visit a FunctionDef node."""
# Arguments
Expand All @@ -78,7 +79,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> str:
body_strs: list[str] = [self.visit(stmt) for stmt in node.body]
body_latex = "\n".join(body_strs)

latex += f"{body_latex}\n"
latex += body_latex + "\n"
latex += self._add_indent("\\EndFunction\n")
return latex + self._add_indent(r"\end{algorithmic}")

Expand All @@ -89,10 +90,10 @@ def visit_If(self, node: ast.If) -> str:
with self._increment_level():
body_latex = "\n".join(self.visit(stmt) for stmt in node.body)

latex = self._add_indent(f"\\If{{${cond_latex}$}}\n{body_latex}")
latex = self._add_indent(f"\\If{{${cond_latex}$}}\n" + body_latex)

if node.orelse:
latex += "\n" + self._add_indent(r"\Else") + "\n"
latex += "\n" + self._add_indent("\\Else\n")
with self._increment_level():
latex += "\n".join(self.visit(stmt) for stmt in node.orelse)

Expand Down Expand Up @@ -124,7 +125,8 @@ def visit_While(self, node: ast.While) -> str:
body_latex = "\n".join(self.visit(stmt) for stmt in node.body)
return (
self._add_indent(f"\\While{{${cond_latex}$}}\n")
+ f"{body_latex}\n"
+ body_latex
+ "\n"
+ self._add_indent(r"\EndWhile")
)

Expand All @@ -136,9 +138,151 @@ def _increment_level(self) -> Generator[None, None, None]:
self._indent_level -= 1

def _add_indent(self, line: str) -> str:
"""Adds whitespace before the line.
"""Adds an indent before the line.

Args:
line: The line to add whitespace to.
line: The line to add an indent to.
"""
return self._indent_level * self._SPACES_PER_INDENT * " " + line


class IPythonAlgorithmicCodegen(ast.NodeVisitor):
"""Codegen for single algorithms targeting IPython.

This codegen works for Module with single FunctionDef node to generate a single
LaTeX expression of the given algorithm.
"""

_EM_PER_INDENT = 1
_LINE_BREAK = r" \\ "

_identifier_converter: identifier_converter.IdentifierConverter
_indent_level: int

def __init__(
self, *, use_math_symbols: bool = False, use_set_symbols: bool = False
) -> None:
"""Initializer.

Args:
use_math_symbols: Whether to convert identifiers with a math symbol surface
(e.g., "alpha") to the LaTeX symbol (e.g., "\\alpha").
use_set_symbols: Whether to use set symbols or not.
"""
self._expression_codegen = expression_codegen.ExpressionCodegen(
use_math_symbols=use_math_symbols, use_set_symbols=use_set_symbols
)
self._identifier_converter = identifier_converter.IdentifierConverter(
use_math_symbols=use_math_symbols
)
self._indent_level = 0

def generic_visit(self, node: ast.AST) -> str:
raise exceptions.LatexifyNotSupportedError(
f"Unsupported AST: {type(node).__name__}"
)

def visit_Assign(self, node: ast.Assign) -> str:
"""Visit an Assign node."""
operands: list[str] = [
self._expression_codegen.visit(target) for target in node.targets
]
operands.append(self._expression_codegen.visit(node.value))
operands_latex = r" \gets ".join(operands)
return self._add_indent(operands_latex)

def visit_Expr(self, node: ast.Expr) -> str:
"""Visit an Expr node."""
return self._add_indent(self._expression_codegen.visit(node.value))

# TODO(ZibingZhang): support nested functions
def visit_FunctionDef(self, node: ast.FunctionDef) -> str:
"""Visit a FunctionDef node."""
# Arguments
arg_strs = [
self._identifier_converter.convert(arg.arg)[0] for arg in node.args.args
]
# Body
with self._increment_level():
body_strs: list[str] = [self.visit(stmt) for stmt in node.body]
body = self._LINE_BREAK.join(body_strs)

return (
r"\begin{array}{l} "
+ self._add_indent(r"\mathbf{function}")
+ rf" \ \mathrm{{{node.name}}}({', '.join(arg_strs)})"
+ self._LINE_BREAK
+ body
+ self._LINE_BREAK
+ self._add_indent(r"\mathbf{end \ function}")
+ r" \end{array}"
)

# TODO(ZibingZhang): support \ELSIF
def visit_If(self, node: ast.If) -> str:
"""Visit an If node."""
cond_latex = self._expression_codegen.visit(node.test)
with self._increment_level():
body_latex = self._LINE_BREAK.join(self.visit(stmt) for stmt in node.body)
latex = self._add_indent(
r"\mathbf{if} \ " + cond_latex + self._LINE_BREAK + body_latex
)

if node.orelse:
latex += self._LINE_BREAK + self._add_indent(r"\mathbf{else} \\ ")
with self._increment_level():
latex += self._LINE_BREAK.join(self.visit(stmt) for stmt in node.orelse)

return latex + self._LINE_BREAK + self._add_indent(r"\mathbf{end \ if}")

def visit_Module(self, node: ast.Module) -> str:
"""Visit a Module node."""
return self.visit(node.body[0])

def visit_Return(self, node: ast.Return) -> str:
"""Visit a Return node."""
return (
self._add_indent(r"\mathbf{return}")
+ r" \ "
+ self._expression_codegen.visit(node.value)
if node.value is not None
else self._add_indent(r"\mathbf{return}")
)

def visit_While(self, node: ast.While) -> str:
"""Visit a While node."""
if node.orelse:
raise exceptions.LatexifyNotSupportedError(
"While statement with the else clause is not supported"
)

cond_latex = self._expression_codegen.visit(node.test)
with self._increment_level():
body_latex = self._LINE_BREAK.join(self.visit(stmt) for stmt in node.body)
return (
self._add_indent(r"\mathbf{while} \ ")
+ cond_latex
+ self._LINE_BREAK
+ body_latex
+ self._LINE_BREAK
+ self._add_indent(r"\mathbf{end \ while}")
)

@contextlib.contextmanager
def _increment_level(self) -> Generator[None, None, None]:
"""Context manager controlling indent level."""
self._indent_level += 1
yield
self._indent_level -= 1

def _add_indent(self, line: str) -> str:
"""Adds an indent before the line.

Args:
line: The line to add an indent to.
"""
return (
rf"\hspace{{{self._indent_level * self._EM_PER_INDENT}em}} {line}"
if self._indent_level > 0
else line
)
Loading