Skip to content

Commit

Permalink
update code for better handling of sets
Browse files Browse the repository at this point in the history
  • Loading branch information
hynky1999 committed Feb 4, 2025
1 parent be0ced3 commit 6ef1e10
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 33 deletions.
17 changes: 16 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,19 @@
# Changelog

## [0.4.3]

### Changed
- Replaced `FiniteSet` from `sympy` with `FiniteSet` from `latex2sympy2_extended.sets` in `src/math_verify/grader.py` and `src/math_verify/parser.py`.
- Modified `sympy_deep_compare_set_and_tuple` and `sympy_compare_sets` functions to use `SympyFiniteSet` for better compatibility with `latex2sympy2_extended`.
- Updated `is_assignment_relation` to use `is_expr_of_only_symbols` instead of `is_assignment_symbol`.
- Improved sorting logic in `sympy_deep_compare_set_and_tuple` to handle `TimeoutError`.

### Added
- New test cases in `tests/test_numina_cases.py` for enhanced expression comparison, including complex expressions and boxed expressions.

### Fixed
- Fixed issues with expression comparison logic, ensuring more accurate results when comparing sets and tuples.

## [0.4.2]
- Bump latex2sympy2_extended to 1.0.2

Expand Down Expand Up @@ -38,4 +53,4 @@

### Removed
- Removed redundant `sympy_compare_set_interval` function
- Removed unnecessary string comparison in some cases
- Removed unnecessary string comparison in some cases
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@

[project]
name = "math-verify"
version = "0.4.2"
version = "0.5.0"
description = "A library for verifying mathematical answers"
authors = [
{ name = "Hynek Kydlíček", email = "[email protected]" }
]
dependencies = [
"latex2sympy2_extended==1.0.2",
"latex2sympy2_extended==1.0.3",
]
requires-python = ">=3.10"

Expand Down
62 changes: 48 additions & 14 deletions src/math_verify/grader.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,17 @@
# SOFTWARE.

# Heavily inspired by https://github.com/QwenLM/Qwen2.5-Math and https://github.com/huggingface/lm-evaluation-harness
from functools import lru_cache
import re
from itertools import product

from latex2sympy2_extended.sets import FiniteSet
from sympy import (
E,
And,
Basic,
EmptySet,
Eq,
FiniteSet,
Float,
GreaterThan,
Interval,
Expand All @@ -45,13 +46,16 @@
StrictLessThan,
Symbol,
Tuple,
default_sort_key,
ordered,
simplify,
)
from sympy.core.relational import Relational
from sympy.core.function import UndefinedFunction
from sympy import FiniteSet as SympyFiniteSet

from math_verify.utils import timeout
from latex2sympy2_extended import is_assignment_symbol
from latex2sympy2_extended import is_expr_of_only_symbols


def safe_sympy_doit(a: Basic | MatrixBase):
Expand Down Expand Up @@ -165,7 +169,7 @@ def sympy_symbolic_eq(a: Basic | MatrixBase, b: Basic | MatrixBase) -> bool:
return False


def sympy_deep_compare_set_and_tuple(gold: FiniteSet | Tuple, pred: FiniteSet | Tuple, precision: int) -> bool:
def sympy_deep_compare_set_and_tuple(gold: SympyFiniteSet | Tuple, pred: SympyFiniteSet | Tuple, precision: int) -> bool:
"""Compare two finite sets by comparing each element with given precision.
Args:
Expand All @@ -179,9 +183,39 @@ def sympy_deep_compare_set_and_tuple(gold: FiniteSet | Tuple, pred: FiniteSet |
Note: in order to fully support finite sets, we should ideally do a Cartesian product comparison,
but this is not implemented yet. For now we rely on sympy ordering the elements consistently.
"""
def unwrap_eq(s):
if is_assignment_relation(s):
return take_last_relation(s).rhs
return s

def sort_key(x):
try:
return default_sort_key(unwrap_eq(x).evalf())
except TimeoutError:
raise
except:
return default_sort_key(unwrap_eq(x))


# This ensures it works for {1/3} and {0.333333}
if len(gold) == len(pred) and all(sympy_expr_eq(a, b, precision) for a, b in zip(gold.args, pred.args)):
return True
if len(gold) == len(pred):
if isinstance(gold, SympyFiniteSet):
gold_args = list(ordered(gold.args, keys=sort_key, default=False))
pred_args = list(ordered(pred.args, keys=sort_key, default=False))

elif isinstance(gold, Tuple) and isinstance(pred, FiniteSet):
# We treat the pred as tuple too
pred_args = pred._unsorted_args
gold_args = gold.args

elif isinstance(pred, SympyFiniteSet):
pred_args = list(ordered(pred.args, keys=sort_key, default=False))
gold_args = gold.args
else:
gold_args = gold.args
pred_args = pred.args

return all(sympy_expr_eq(a, b, precision) for a, b in zip(gold_args, pred_args))

return False

Expand Down Expand Up @@ -297,8 +331,8 @@ def sympy_compare_sets(gold: Set | Basic | MatrixBase | Tuple, pred: Set | Basic
True if sets are equal by any comparison method, False otherwise
"""
# Convert non-sets to singleton sets
a_set = gold if isinstance(gold, (Set, Tuple)) else FiniteSet(gold)
b_set = pred if isinstance(pred, (Set, Tuple)) else FiniteSet(pred)
a_set = gold if isinstance(gold, (Set, Tuple)) else SympyFiniteSet(gold)
b_set = pred if isinstance(pred, (Set, Tuple)) else SympyFiniteSet(pred)

# If both are intervals, use interval comparison
if isinstance(a_set, Interval) and isinstance(b_set, Interval):
Expand All @@ -314,16 +348,16 @@ def sympy_compare_sets(gold: Set | Basic | MatrixBase | Tuple, pred: Set | Basic
return True

# For finite sets, compare elements
if isinstance(a_set, (FiniteSet, Tuple)) and isinstance(b_set, (FiniteSet, Tuple)):
if isinstance(a_set, (SympyFiniteSet, Tuple)) and isinstance(b_set, (SympyFiniteSet, Tuple)):
return sympy_deep_compare_set_and_tuple(a_set, b_set, precision)

# Because (1,2) is parsed as Interval(1,2,left_open=True,right_open=True), it can happen that the
# correct answer is (1,2) and the predicted answer is 1,2, which is parsed as Set(1,2)
if isinstance(a_set, Interval) and isinstance(b_set, (FiniteSet, Tuple)):
if isinstance(a_set, Interval) and isinstance(b_set, (SympyFiniteSet, Tuple)):
if a_set.is_open and len(b_set) == 2:
return sympy_deep_compare_set_and_tuple(Tuple(a_set.start, a_set.end), b_set, precision)

if isinstance(b_set, Interval) and isinstance(a_set, (FiniteSet, Tuple)):
if isinstance(b_set, Interval) and isinstance(a_set, (SympyFiniteSet, Tuple)):
if b_set.is_open and len(a_set) == 2:
return sympy_deep_compare_set_and_tuple(a_set, Tuple(b_set.start, b_set.end), precision)

Expand Down Expand Up @@ -401,11 +435,11 @@ def is_assignment_relation(expr: Basic | MatrixBase) -> bool:
Returns:
bool: True if expr is a relational expression or And of relations, False otherwise
"""
if isinstance(expr, Eq) and is_assignment_symbol(expr.lhs):
if isinstance(expr, Eq) and is_expr_of_only_symbols(expr.lhs):
return True

if isinstance(expr, And) and len(expr.args) > 0:
return all(isinstance(arg, Eq) for arg in expr.args) and is_assignment_symbol(expr.args[0].lhs)
return all(isinstance(arg, Eq) for arg in expr.args) and is_expr_of_only_symbols(expr.args[0].lhs)

return False

Expand Down Expand Up @@ -484,12 +518,12 @@ def sympy_expr_eq(gold: Basic | MatrixBase, pred: Basic | MatrixBase, precision:
# We assume that the gold never needs to be simplified, so we don't handle that case
# e.g 1+1+1=3 will never be simplified to 3; it would be possible to do so with lhs-rhs == 0, but we assume the gold is at its most simplified form.
# The new latex2sympy2 will actually convert such cases automatically, so this is in theory not needed
if is_assignment_relation(gold) and not is_relation(pred):
if is_assignment_relation(gold) and not is_equation(pred):
gold = take_last_relation(gold).rhs

# Here we respect the gold and simplify accordingly, thus any of
# k=x+1+z or 1+1+1=3 will be simplified to rhs
if is_equation(pred) and not is_relation(gold):
if is_equation(pred) and not is_equation(gold):
pred = take_last_relation(pred).rhs

if is_relation(gold) and isinstance(pred, Set):
Expand Down
5 changes: 3 additions & 2 deletions src/math_verify/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from typing import Literal, Sequence

import sympy
from sympy import Basic, FiniteSet, MatrixBase, Number
from sympy import Basic, MatrixBase, Number
from latex2sympy2_extended.sets import FiniteSet
from sympy.parsing import parse_expr
from math_verify.grader import should_treat_as_complex
from latex2sympy2_extended.latex2sympy2 import (
Expand Down Expand Up @@ -458,7 +459,7 @@ def extract_latex(match: re.Match, latex_config: LatexExtractionConfig, timeout_
all_elements.extend(expr.args)
else:
all_elements.append(expr)
return sympy.FiniteSet(*all_elements), " and ".join(latex_strs)
return FiniteSet(*all_elements), " and ".join(latex_strs)

# Otherwise return the single expression
return latex_exprs[0], latex_strs[0]
Expand Down
9 changes: 1 addition & 8 deletions tests/test_numina_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,6 @@
r"$(3,2,7)$",
1,
),
(
r"$P(x) = 1$",
r"$p(x) = 1$",
1,
),
(
r"$V_{1}:V_{2}=11:21$",
r"$11:21$",
Expand Down Expand Up @@ -123,8 +118,6 @@
r"$\boxed{-5, \frac{14}{3}}$",
1,
),
#TODO: make sure that \, is translate to ,
# the or joining should be extend if one of the is a
(
r"\boxed{a=4,\,-8,\,-10}",
r"$\boxed{-10,-8,4}$",
Expand Down Expand Up @@ -163,7 +156,7 @@
(
r"$\text{Even}$",
r"$Even$",
1
)
# (
# r"$f(x)$",
Expand Down
31 changes: 26 additions & 5 deletions tests/test_open_thoughts.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,11 @@
r"\boxed{1},\boxed{2},\boxed{3}",
1,
),
(
r"$$x+z=1$$",
r"$$1$$",
0,
),
# (
# r"$$x+z=1$$",
# r"$$1$$",
# 0,
# ),
(
r"$$|AB|=1$$",
r"$$1$$",
Expand All @@ -128,6 +128,27 @@
r"$$1$$",
1,
),
(
r"$x_{1}=10^{\frac{-5+\sqrt{13}}{6}},\quadx_{2}=10^{\frac{-5-\sqrt{13}}{6}}$",
r"$\boxed{10^{\frac{\sqrt{13} - 5}{6}}} \quad \text{and} \quad \boxed{10^{-\frac{5 + \sqrt{13}}{6}}}$",
1,
),
(
r"$y_{1}=-2 x^{2}+4 x+3, y_{2}=3 x^{2}+12 x+10$",
r"\($y_1 = \boxed{-2(x - 1)^2 + 5} \) and \( y_2 = \boxed{3(x + 2)^2 - 2} \) ",
1,
),
(
r"$x_{1}=\frac{1}{2}+\frac{31\sqrt{5}}{216},\quadx_{2}=\frac{1}{2}-\frac{31\sqrt{5}}{216}$",
r"$\boxed{\dfrac{108 + 31\sqrt{5}}{216}} \quad \text{and} \quad \boxed{\dfrac{108 - 31\sqrt{5}}{216}}$",
1,
),
(
r"$x_{1}=10^{\frac{-5+\sqrt{13}}{6}},\quadx_{2}=10^{\frac{-5-\sqrt{13}}{6}}$",
r"$\boxed{10^{\frac{\sqrt{13} - 5}{6}}} \quad \text{and} \quad \boxed{10^{-\frac{5 + \sqrt{13}}{6}}}$",
1,
),
])
def test_numina_cases(gold, pred, expected):
assert compare_strings(gold, pred, match_types=["latex", "expr"]) == expected

0 comments on commit 6ef1e10

Please sign in to comment.