Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

beignet.root_scalar #26

Open
wants to merge 49 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 37 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
e3c7f08
wip beignet.root
kleinhenz May 29, 2024
2eb9dc9
cleanup
0x00b1 Jun 10, 2024
487258f
cleanup
0x00b1 Jun 10, 2024
e82b912
cleanup
0x00b1 Jun 10, 2024
7fea1b4
add bisection
kleinhenz Jul 30, 2024
d02b7df
wip autograd
kleinhenz Jul 30, 2024
936923b
wip autograd
kleinhenz Jul 31, 2024
2d4fa10
wip
kleinhenz Jul 31, 2024
634609d
wip
kleinhenz Aug 3, 2024
94db89a
add test
kleinhenz Aug 3, 2024
80138d5
wip
kleinhenz Aug 4, 2024
8e50038
wip
kleinhenz Aug 4, 2024
4a2c8db
comment
kleinhenz Aug 4, 2024
0ea18f7
wip cleanup chandrupatla
kleinhenz Aug 4, 2024
0f40be1
wip
kleinhenz Aug 4, 2024
7a52c2d
add chandrupatla to root_scalar
kleinhenz Aug 4, 2024
581af1a
parametrize root_scalar tests by method
kleinhenz Aug 4, 2024
86e44d8
flyby: remove unused out argument
kleinhenz Aug 4, 2024
28944b7
add type hint
kleinhenz Aug 4, 2024
2e926ad
wip beignet.root
0x00b1 Aug 9, 2024
7e028c8
cleanup
0x00b1 Aug 9, 2024
288b769
cleanup
kleinhenz Aug 9, 2024
8b868f3
f -> func
kleinhenz Aug 9, 2024
c0c1875
cleanup
kleinhenz Aug 10, 2024
f268ec6
cleanup
0x00b1 Aug 12, 2024
ac715de
remove inplace update of iterations
kleinhenz Aug 12, 2024
86ef2d0
use local import
kleinhenz Aug 12, 2024
9a829bb
custom_scalar_root
0x00b1 Aug 12, 2024
748d799
custom_root
0x00b1 Aug 13, 2024
7ce1299
custom_root
0x00b1 Aug 13, 2024
50b9416
make broadcast explicit and forward vmap
kleinhenz Aug 28, 2024
d116541
make broadcast explicit in chandrupatla
kleinhenz Aug 28, 2024
1d97026
cleanup
kleinhenz Aug 28, 2024
2e09fa9
simplify custom_scalar_root + add jacrev test
kleinhenz Aug 28, 2024
729ae4e
add docstring to root_scalar
kleinhenz Aug 28, 2024
9b3bf58
improve docstrings
kleinhenz Aug 28, 2024
f47e5f6
remove non scalar code for this pr
kleinhenz Feb 11, 2025
2fb37ac
simplify signature
kleinhenz Feb 13, 2025
499da53
wip use higher_order_ops
kleinhenz Feb 17, 2025
c2c5e6b
wip use higher_order_ops
kleinhenz Feb 17, 2025
fcd1a80
fix aliasing
kleinhenz Feb 17, 2025
ccb99cb
fix aliasing
kleinhenz Feb 17, 2025
b176e7a
enable compile with fullgraph=True
kleinhenz Feb 17, 2025
c0fd8a3
add unroll option
kleinhenz Mar 2, 2025
789fff2
skip compile tests on windows
kleinhenz Mar 2, 2025
8905a36
remove RootSolutionInfo dataclass
kleinhenz Mar 2, 2025
7c790ae
ci wip
kleinhenz Mar 2, 2025
f6a7a18
fix skip condition
kleinhenz Mar 2, 2025
6188a5c
remove tmp
kleinhenz Mar 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/beignet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from ._apply_rotation_matrix import apply_rotation_matrix
from ._apply_rotation_vector import apply_rotation_vector
from ._apply_transform import apply_transform
from ._bisect import bisect
from ._chandrupatla import chandrupatla
from ._chebyshev_extrema import chebyshev_extrema
from ._chebyshev_gauss_quadrature import chebyshev_gauss_quadrature
from ._chebyshev_interpolation import chebyshev_interpolation
Expand Down Expand Up @@ -311,6 +313,7 @@
from ._random_quaternion import random_quaternion
from ._random_rotation_matrix import random_rotation_matrix
from ._random_rotation_vector import random_rotation_vector
from ._root_scalar import root_scalar
from ._rotation_matrix_identity import rotation_matrix_identity
from ._rotation_matrix_magnitude import rotation_matrix_magnitude
from ._rotation_matrix_mean import rotation_matrix_mean
Expand Down Expand Up @@ -564,6 +567,7 @@
"random_quaternion",
"random_rotation_matrix",
"random_rotation_vector",
"root_scalar",
"rotation_matrix_identity",
"rotation_matrix_magnitude",
"rotation_matrix_mean",
Expand All @@ -589,4 +593,5 @@
"trim_physicists_hermite_polynomial_coefficients",
"trim_polynomial_coefficients",
"trim_probabilists_hermite_polynomial_coefficients",
"root",
]
106 changes: 106 additions & 0 deletions src/beignet/_bisect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from typing import Callable

import torch
from torch import Tensor

from ._root_scalar import RootSolutionInfo


def bisect(
func: Callable,
*args,
lower: float | Tensor,
upper: float | Tensor,
rtol: float | None = None,
atol: float | None = None,
maxiter: int = 100,
return_solution_info: bool = False,
dtype=None,
device=None,
**_,
) -> Tensor | tuple[Tensor, RootSolutionInfo]:
"""Find the root of a scalar (elementwise) function using bisection.

This method is slow but guarenteed to converge.

Parameters
----------
func: Callable
Function to find a root of. Called as `f(x, *args)`.
The function must operate element wise, i.e. `f(x[i]) == f(x)[i]`.
Handling *args via broadcasting is acceptable.

*args
Extra arguments to be passed to `func`.

lower: float | Tensor
Lower bracket for root

upper: float | Tensor
Upper bracket for root

rtol: float | None = None
Relative tolerance

atol: float | None = None
Absolute tolerance

maxiter: int = 100
Maximum number of iterations

return_solution_info: bool = False
Whether to return a `RootSolutionInfo` object

dtype = None
if upper/lower are passed as floats instead of tensors
use this dtype when constructing the tensor.

device = None
if upper/lower are passed as floats instead of tensors
use this device when constructing the tensor.

Returns
-------
Tensor | tuple[Tensor, RootSolutionInfo]
"""
a = torch.as_tensor(lower, dtype=dtype, device=device)
b = torch.as_tensor(upper, dtype=dtype, device=device)
a, b, *args = torch.broadcast_tensors(a, b, *args)

fa = func(a, *args)
fb = func(b, *args)

c = (a + b) / 2
fc = func(c, *args)

eps = torch.finfo(fa.dtype).eps

if rtol is None:
rtol = eps

if atol is None:
atol = 2 * eps

converged = torch.zeros_like(fa, dtype=torch.bool)
iterations = torch.zeros_like(fa, dtype=torch.int)

if (torch.sign(fa) * torch.sign(fb) > 0).any():
raise ValueError("a and b must bracket a root")

for _ in range(maxiter):
converged = converged | ((b - a) / 2 < (rtol * torch.abs(c) + atol))

if converged.all():
break

cond = torch.sign(fc) == torch.sign(fa)
a = torch.where(cond, c, a)
b = torch.where(cond, b, c)
c = (a + b) / 2
fc = func(c, *args)
iterations = iterations + ~converged

if return_solution_info:
return c, RootSolutionInfo(converged=converged, iterations=iterations)
else:
return c
167 changes: 167 additions & 0 deletions src/beignet/_chandrupatla.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
from typing import Callable

import torch
from torch import Tensor

from ._root_scalar import RootSolutionInfo


def chandrupatla(
    func: Callable,
    *args,
    lower: float | Tensor,
    upper: float | Tensor,
    rtol: float | None = None,
    atol: float | None = None,
    maxiter: int = 100,
    return_solution_info: bool = False,
    dtype=None,
    device=None,
    **_,
) -> "Tensor | tuple[Tensor, RootSolutionInfo]":
    """Find the root of a scalar (elementwise) function using Chandrupatla's method.

    The method maintains a bracket around the root and, at each step, chooses
    between bisection and inverse quadratic interpolation, so it is guaranteed
    to converge for a bracketed root while typically requiring far fewer
    function evaluations than plain bisection.

    Parameters
    ----------
    func: Callable
        Function to find a root of. Called as `f(x, *args)`.
        The function must operate element wise, i.e. `f(x[i]) == f(x)[i]`.
        Handling *args via broadcasting is acceptable.

    *args
        Extra arguments to be passed to `func`.

    lower: float | Tensor
        Lower bracket for root.

    upper: float | Tensor
        Upper bracket for root.

    rtol: float | None = None
        Relative tolerance. Defaults to the machine epsilon of the
        function-value dtype.

    atol: float | None = None
        Absolute tolerance. Defaults to twice the machine epsilon of the
        function-value dtype.

    maxiter: int = 100
        Maximum number of iterations.

    return_solution_info: bool = False
        Whether to also return a `RootSolutionInfo` object.

    dtype = None
        If upper/lower are passed as floats instead of tensors,
        use this dtype when constructing the tensors.

    device = None
        If upper/lower are passed as floats instead of tensors,
        use this device when constructing the tensors.

    Returns
    -------
    Tensor | tuple[Tensor, RootSolutionInfo]
        Elementwise root estimate, optionally together with per-element
        convergence information.

    Raises
    ------
    ValueError
        If any element of `(lower, upper)` does not bracket a sign change.

    References
    ----------

    [1] Tirupathi R. Chandrupatla. A new hybrid quadratic/bisection algorithm for
    finding the zero of a nonlinear function without using derivatives.
    Advances in Engineering Software, 28.3:145-149, 1997.
    """
    # maintain three points a,b,c for inverse quadratic interpolation
    # we will keep (a,b) as the bracketing interval
    a = torch.as_tensor(lower, dtype=dtype, device=device)
    b = torch.as_tensor(upper, dtype=dtype, device=device)
    a, b, *args = torch.broadcast_tensors(a, b, *args)
    c = a

    fa = func(a, *args)
    fb = func(b, *args)
    fc = fa

    # root estimate: the bracket endpoint with the smaller |f|
    xm = torch.where(torch.abs(fa) < torch.abs(fb), a, b)

    eps = torch.finfo(fa.dtype).eps

    if rtol is None:
        rtol = eps

    if atol is None:
        atol = 2 * eps

    converged = torch.zeros_like(fa, dtype=torch.bool)
    iterations = torch.zeros_like(fa, dtype=torch.int)

    if (torch.sign(fa) * torch.sign(fb) > 0).any():
        raise ValueError("a and b must bracket a root")

    for _ in range(maxiter):
        # freeze the estimate for elements that have already converged
        xm = torch.where(
            converged, xm, torch.where(torch.abs(fa) < torch.abs(fb), a, b)
        )
        tol = atol + torch.abs(xm) * rtol
        bracket_size = torch.abs(b - a)
        tlim = tol / bracket_size
        # tlim > 0.5 is equivalent to 0.5 * bracket_size < tol
        converged = converged | (tlim > 0.5)

        if converged.all():
            break

        a, b, c, fa, fb, fc = _find_root_chandrupatla_iter(
            func, *args, a=a, b=b, c=c, fa=fa, fb=fb, fc=fc, tlim=tlim
        )

        iterations = iterations + ~converged

    if return_solution_info:
        return xm, RootSolutionInfo(converged=converged, iterations=iterations)
    else:
        return xm


def _find_root_chandrupatla_iter(
func: Callable,
*args,
a: Tensor,
b: Tensor,
c: Tensor,
fa: Tensor,
fb: Tensor,
fc: Tensor,
tlim: Tensor,
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
# check validity of inverse quadratic interpolation
xi = (a - b) / (c - b)
phi = (fa - fb) / (fc - fb)
do_iqi = (phi.pow(2) < xi) & ((1 - phi).pow(2) < (1 - xi))

# use iqi where applicable, otherwise bisect interval
t = torch.where(
do_iqi,
fa / (fb - fa) * fc / (fb - fc)
+ (c - a) / (b - a) * fa / (fc - fa) * fb / (fc - fb),
0.5,
)
t = torch.clip(t, min=tlim, max=1 - tlim)

xt = a + t * (b - a)
ft = func(xt, *args)

# check which side of root t is on
cond = torch.sign(ft) == torch.sign(fa)

# update a,b,c maintaining (a,b) a bracket of root
# NOTE we do not maintain the order of a and b
c = torch.where(cond, a, b)
fc = torch.where(cond, fa, fb)
b = torch.where(cond, b, a)
fb = torch.where(cond, fb, fa)
a = xt
fa = ft

return a, b, c, fa, fb, fc
67 changes: 67 additions & 0 deletions src/beignet/_root_scalar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import dataclasses
from typing import Callable, Literal

from torch import Tensor

import beignet
import beignet.func


@dataclasses.dataclass
class RootSolutionInfo:
    """Per-element diagnostic information returned by the scalar root solvers."""

    # Boolean tensor: True where the elementwise solve met its tolerance.
    converged: Tensor
    # Integer tensor: iterations spent on each element before convergence.
    iterations: Tensor


def root_scalar(
    func: Callable,
    *args,
    method: Literal["bisect", "chandrupatla"] = "chandrupatla",
    implicit_diff: bool = True,
    options: dict | None = None,
) -> Tensor | tuple[Tensor, RootSolutionInfo]:
    """
    Find the root of a scalar (elementwise) function.

    Parameters
    ----------
    func: Callable
        Function to find a root of. Called as `f(x, *args)`.
        The function must operate element wise, i.e. `f(x[i]) == f(x)[i]`.
        Handling *args via broadcasting is acceptable.

    *args
        Extra arguments to be passed to `func`.

    method: Literal["bisect", "chandrupatla"] = "chandrupatla"
        Solver method to use.
        * bisect: `beignet.bisect`
        * chandrupatla: `beignet.chandrupatla`
        See docstring of underlying solvers for description of options dict.

    implicit_diff: bool = True
        If true, the solver is wrapped in `beignet.func.custom_scalar_root` which
        enables gradients with respect to *args using implicit differentiation.

    options: dict | None = None
        A dictionary of options that are passed through to the solver as keyword args.


    Returns
    -------
    Tensor | tuple[Tensor, RootSolutionInfo]
    """
    # dispatch table for supported solvers
    solvers = {
        "bisect": beignet.bisect,
        "chandrupatla": beignet.chandrupatla,
    }

    try:
        solver = solvers[method]
    except KeyError:
        raise ValueError(f"method {method} not recognized") from None

    if implicit_diff:
        solver = beignet.func.custom_scalar_root(solver)

    solver_kwargs = {} if options is None else options

    return solver(func, *args, **solver_kwargs)
6 changes: 6 additions & 0 deletions src/beignet/func/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,7 @@
from ._custom_scalar_root import custom_scalar_root
from ._space import space

# Public API of the beignet.func subpackage.
__all__ = [
    "custom_scalar_root",
    "space",
]
Loading
Loading