Skip to content

Commit

Permalink
Quadratic solver matches quartic in ins/outs
Browse files Browse the repository at this point in the history
  • Loading branch information
ekiefl committed Jan 20, 2025
1 parent a1aa278 commit d206410
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 58 deletions.
28 changes: 14 additions & 14 deletions pooltool/ptmath/roots/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def filter_non_physical_roots(
Returns:
A 1D array of the same shape as `roots`, where non-physical (negative or
“too imaginary”) roots are replaced by `np.inf + 0j`, and valid roots are
retained unchanged.
“too imaginary”) roots are replaced by `np.inf + 0j`, and valid roots
remain unchanged.
"""
processed_roots = np.full(len(roots), np.inf, dtype=np.complex128)

Expand Down Expand Up @@ -126,7 +126,7 @@ def filter_non_physical_roots_many(
return processed_roots


def get_smallest_physical_roots(
def get_smallest_physical_root_many(
roots: NDArray[np.complex128],
abs_or_rel_cutoff: float = 1e-3,
rtol: float = 1e-3,
Expand All @@ -136,8 +136,8 @@ def get_smallest_physical_roots(
Args:
roots:
A mxn array of polynomial root solutions, where m is the number of equations
and n is the order of the polynomial.
An array of shape (m, n) of polynomial root solutions, where m is the number
of equations and n is the order of the polynomial.
abs_or_rel_cutoff:
The criteria for a root being real depends on the magnitude of its real
component. If it's large, we require the imaginary component to be less than
Expand All @@ -154,17 +154,16 @@ def get_smallest_physical_roots(
the root is considered real if r.imag == 0, too.
Returns:
An array of shape m. Each value is the smallest root that is real and
positive. If no such root exists (e.g. all roots are complex), then
`np.inf` is used.
An array of shape (m,). Values are the smallest root that is real and positive.
If no such root exists (e.g. all roots are complex), then `np.inf` is used.
"""

processed_roots = filter_non_physical_roots_many(
roots, abs_or_rel_cutoff, rtol, atol
)

# Find the minimum real positive root in each row
min_real_positive_roots = np.min(processed_roots.real, axis=1)
min_real_positive_roots = np.min(processed_roots.real, axis=-1)

return min_real_positive_roots

Expand All @@ -179,8 +178,8 @@ def get_sorted_physical_roots(
Args:
roots:
A mxn array of polynomial root solutions, where m is the number of equations
and n is the order of the polynomial.
An array of shape (m, n) of polynomial root solutions, where m is the number
of equations and n is the order of the polynomial.
abs_or_rel_cutoff:
The criteria for a root being real depends on the magnitude of its real
component. If it's large, we require the imaginary component to be less than
Expand All @@ -197,14 +196,15 @@ def get_sorted_physical_roots(
the root is considered real if r.imag == 0, too.
Returns:
An array of shape mx4. Columns are sorted by smallest real positive root.
Negative and imaginary roots are converted to infinity.
An array of shape (m,). Values are the smallest root that is real and positive.
Columns are sorted by smallest real positive root. Negative and imaginary roots
are converted to infinity.
"""

processed_roots = filter_non_physical_roots_many(
roots, abs_or_rel_cutoff, rtol, atol
)

sorted_real_positive_roots = np.sort(processed_roots.real, axis=1)
sorted_real_positive_roots = np.sort(processed_roots.real, axis=-1)

return sorted_real_positive_roots
49 changes: 48 additions & 1 deletion pooltool/ptmath/roots/quadratic.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import cmath
from typing import Tuple

import numpy as np
from numba import jit
from numpy.typing import NDArray

import pooltool.constants as const


@jit(nopython=True, cache=const.use_numba_cache)
def solve(a: float, b: float, c: float) -> Tuple[float, float]:
def solve_old(a: float, b: float, c: float) -> Tuple[float, float]:
"""Solve a quadratic equation At^2 + Bt + C = 0 (just-in-time compiled)"""
if a == 0:
if b == 0:
Expand All @@ -26,3 +28,48 @@ def solve(a: float, b: float, c: float) -> Tuple[float, float]:
u1 = (-bp - delta**0.5) / a
u2 = -u1 - b / a
return u1, u2


@jit(nopython=True, cache=const.use_numba_cache)
def solve(a: float, b: float, c: float) -> NDArray[np.complex128]:
_a = complex(a)
_b = complex(b)
_c = complex(c)

roots = np.full(2, np.nan, dtype=np.complex128)

if abs(_a) != 0:
# Quadratic case
d = _b * _b - 4 * _a * _c
sqrt_d = cmath.sqrt(d)

# Sign trick to reduce catastrophic cancellation
sign_b = 1.0 if _b.real >= 0 else -1.0

r1_num = -_b - sign_b * sqrt_d
r1_den = 2 * _a

# Fallback if numerator is tiny
if abs(r1_num) < 1e-14 * abs(r1_den):
r1_num = -_b + sign_b * sqrt_d

r1 = r1_num / r1_den

# Use product identity for x2
if abs(r1) < 1e-14:
r2 = (-_b + sqrt_d) / (2 * _a)
else:
r2 = (_c / _a) / r1

roots[0] = r1
roots[1] = r2
return roots

if abs(_b) != 0:
# Linear case
r1 = -_c / _b
roots[0] = r1
return roots

# Equation is just c=0. Either zero or infinite solutions. Returns nans
return roots
16 changes: 7 additions & 9 deletions pooltool/ptmath/roots/quartic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from numpy.typing import NDArray

import pooltool.constants as const
from pooltool.ptmath.roots.quadratic import solve as solve_quadratic
from pooltool.ptmath.roots import quadratic
from pooltool.utils.strenum import StrEnum, auto


Expand All @@ -15,11 +15,12 @@ class QuarticSolver(StrEnum):


@jit(nopython=True, cache=const.use_numba_cache)
def _solve_quadratics(ps: NDArray[np.complex128]) -> NDArray[np.complex128]:
def _solve_quadratics(ps: NDArray[np.float64]) -> NDArray[np.complex128]:
"""Solves an array of quadratics.
This is used internally by the quartic solver when it is passed coefficients where
a=b=0, which make the polynomial quadratic, not quartic.
a=b=0, which make the polynomial quadratic, not quartic. It delegates to
`quadratic.solve`.
Args:
ps:
Expand All @@ -34,9 +35,7 @@ def _solve_quadratics(ps: NDArray[np.complex128]) -> NDArray[np.complex128]:
roots = np.full((m, 4), np.inf, dtype=np.complex128)

for i in range(m):
r1, r2 = solve_quadratic(ps[i, 0].real, ps[i, 1].real, ps[i, 2].real)
roots[i, 0] = r1
roots[i, 1] = r2
roots[i, :2] = quadratic.solve(ps[i, 0], ps[i, 1], ps[i, 2])

return roots

Expand Down Expand Up @@ -88,8 +87,7 @@ def solve_quartics(
all_roots[quartic_mask] = quartic_roots

if np.any(quadratic_mask):
quadratic_ps = ps[quadratic_mask, 2:].astype(np.complex128)
quadratic_roots = _solve_quadratics(quadratic_ps)
quadratic_roots = _solve_quadratics(ps[quadratic_mask, 2:])
all_roots[quadratic_mask] = quadratic_roots

return all_roots
Expand Down Expand Up @@ -207,7 +205,7 @@ def _solve(

if a == 0 and b == 0:
# Quadratic!
return _solve_quadratics(p[np.newaxis, 2:])[0, :], 4
return quadratic.solve(p[2], p[3], p[4]), 4

# The analytic solutions don't like 0s
if (p == 0).any():
Expand Down
90 changes: 56 additions & 34 deletions tests/ptmath/roots/test_quadratic.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,64 @@
import math

import numpy as np
import pytest

from pooltool.ptmath.roots.quadratic import solve


def test_solve_standard_quadratic():
# x^2 - 5x + 6 = 0
# Solutions: x = 2 or x = 3
# Solutions: x = 2, x = 3
u1, u2 = solve(1.0, -5.0, 6.0)
solutions = sorted([u1, u2])
assert pytest.approx(solutions[0]) == 2.0
assert pytest.approx(solutions[1]) == 3.0
solutions = sorted([u1, u2], key=lambda z: (z.real, z.imag))
# First root -> 2.0 + 0.0j
assert solutions[0].real == 2.0
assert solutions[0].imag == 0.0
# Second root -> 3.0 + 0.0j
assert solutions[1].real == 3.0
assert solutions[1].imag == 0.0

# x^2 - x - 2 = 0
# Solutions: x = 2, x = -1
# Solutions: x = -1, x = 2
u1, u2 = solve(1.0, -1.0, -2.0)
solutions = sorted([u1, u2])
assert pytest.approx(solutions[0]) == -1.0
assert pytest.approx(solutions[1]) == 2.0
solutions = sorted([u1, u2], key=lambda z: (z.real, z.imag))
# First root -> -1.0 + 0.0j
assert solutions[0].real == -1.0
assert solutions[0].imag == 0.0
# Second root -> 2.0 + 0.0j
assert solutions[1].real == 2.0
assert solutions[1].imag == 0.0

# Perfect square: x^2 - 4x + 4 = 0
# Single repeated solution: x = 2
u1, u2 = solve(1.0, -4.0, 4.0)
solutions = sorted([u1, u2])
assert pytest.approx(solutions[0]) == 2.0
assert pytest.approx(solutions[1]) == 2.0
solutions = sorted([u1, u2], key=lambda z: (z.real, z.imag))
# Both roots -> 2.0 + 0.0j
for root in solutions:
assert root.real == 2.0
assert root.imag == 0.0

# Difference of squares: x^2 - 4 = 0
# Solutions: x = 2, x = -2
# Solutions: x = -2, x = 2
u1, u2 = solve(1.0, 0.0, -4.0)
solutions = sorted([u1, u2])
assert pytest.approx(solutions[0], 0.0001) == -2.0
assert pytest.approx(solutions[1], 0.0001) == 2.0


def test_solve_negative_discriminant():
# Equation with negative discriminant: x^2 + x + 1 = 0 Solutions are complex, but
# since we are taking the square root directly, we get nan.

u1, u2 = solve(1.0, 1.0, 1.0)
assert math.isnan(u1)
assert math.isnan(u2)
solutions = sorted([u1, u2], key=lambda z: (z.real, z.imag))
# First root -> -2.0 + 0.0j
assert solutions[0].real == -2.0
assert solutions[0].imag == 0.0
# Second root -> 2.0 + 0.0j
assert solutions[1].real == 2.0
assert solutions[1].imag == 0.0

# Complex roots: x^2 + 1 = 0
# Solutions: x = i, x = -i
u1, u2 = solve(1.0, 0.0, 1.0)
solutions = sorted([u1, u2], key=lambda z: (z.real, z.imag))
# First root -> -i -> (0.0, -1.0)
assert solutions[0].real == 0.0
assert solutions[0].imag == -1.0
# Second root -> i -> (0.0, 1.0)
assert solutions[1].real == 0.0
assert solutions[1].imag == 1.0


def test_solve_large_values():
Expand All @@ -53,31 +70,36 @@ def test_solve_large_values():
u1, u2 = solve(a, b, c)
solutions = sorted([u1, u2])

# The large root should be close to 1e7, the smaller should be close to 1e-7. The
# required relative tolerance required to pass this test is pretty large (1e-2).
assert pytest.approx(solutions[0], rel=1e-2) == 1e-7
assert pytest.approx(solutions[1], rel=1e-2) == 1e7
# The large root should be close to 1e7, the smaller should be close to 1e-7. We're
# able to use a very small relative tolerance due to the way the solver avoids
# catastrophic cancellation.
assert pytest.approx(solutions[0].real, rel=1e-12) == 1e-7
assert pytest.approx(solutions[1].real, rel=1e-12) == 1e7

assert solutions[0].imag == 0.0
assert solutions[1].imag == 0.0


def test_solve_linear_equation():
# a=0, b≠0 => linear equation b*t + c = 0 => t=-c/b
# e.g. 2t + 4 = 0 => t=-2
r1, r2 = solve(0.0, 2.0, 4.0)
assert r1 == -2.0
assert math.isnan(r2) # The second "root" should be NaN for a linear equation
assert r1.real == -2.0
assert r1.imag == 0.0
assert np.isnan(r2)


def test_solve_degenerate_no_solution():
# a=0, b=0, c≠0 => no solutions
# e.g. 0*t^2 + 0*t + 5 = 0 => no real solution
r1, r2 = solve(0.0, 0.0, 5.0)
assert math.isnan(r1)
assert math.isnan(r2)
assert np.isnan(r1)
assert np.isnan(r2)


def test_solve_degenerate_infinite_solutions():
# a=0, b=0, c=0 => infinite solutions
# e.g. 0*t^2 + 0*t + 0 = 0 => t can be anything
r1, r2 = solve(0.0, 0.0, 0.0)
assert math.isnan(r1)
assert math.isnan(r2)
assert np.isnan(r1)
assert np.isnan(r2)

0 comments on commit d206410

Please sign in to comment.