From 5bdd890f661fe243c0336a5316b38613b76b10a2 Mon Sep 17 00:00:00 2001 From: andyElking Date: Fri, 18 Oct 2024 16:40:55 +0100 Subject: [PATCH] small fix of docs in all three and a return type in quicsort --- diffrax/_solver/align.py | 5 +++-- diffrax/_solver/quicsort.py | 8 ++++---- diffrax/_solver/should.py | 5 +++-- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/diffrax/_solver/align.py b/diffrax/_solver/align.py index 4fb5f47e..c6bc6105 100644 --- a/diffrax/_solver/align.py +++ b/diffrax/_solver/align.py @@ -46,8 +46,9 @@ def __init__(self, beta, a1, b1, aa, chh): class ALIGN(AbstractFosterLangevinSRK[_ALIGNCoeffs, _ErrorEstimate]): r"""The Adaptive Langevin via Interpolated Gradients and Noise method designed by James Foster. This is a second order solver for the - Underdamped Langevin Diffusion, the terms for which can be created using - [`diffrax.make_underdamped_langevin_term`][]. Uses two evaluations of the vector + Underdamped Langevin Diffusion, and accepts terms of the form + `MultiTerm(UnderdampedLangevinDriftTerm, UnderdampedLangevinDiffusionTerm)`. + Uses two evaluations of the vector field per step, but is FSAL, so in practice it only requires one. ??? cite "Reference" diff --git a/diffrax/_solver/quicsort.py b/diffrax/_solver/quicsort.py index da9b53be..4f21bd6f 100644 --- a/diffrax/_solver/quicsort.py +++ b/diffrax/_solver/quicsort.py @@ -47,9 +47,9 @@ def __init__(self, beta_lr1, a_lr1, b_lr1, a_third, a_div_h): class QUICSORT(AbstractFosterLangevinSRK[_QUICSORTCoeffs, None]): r"""The QUadrature Inspired and Contractive Shifted ODE with Runge-Kutta Three method by James Foster and Daire O'Kane. This is a third order solver for the - Underdamped Langevin Diffusion, the terms for which can be created using - [`diffrax.make_underdamped_langevin_term`][]. Uses two evaluations of the vector - field per step. + Underdamped Langevin Diffusion, and accepts terms of the form + `MultiTerm(UnderdampedLangevinDriftTerm, UnderdampedLangevinDiffusionTerm)`. + Uses two evaluations of the vector field per step. ??? cite "Reference" @@ -199,7 +199,7 @@ def _compute_step( coeffs: _QUICSORTCoeffs, rho: UnderdampedLangevinX, prev_f: Optional[UnderdampedLangevinX], - ) -> tuple[UnderdampedLangevinX, UnderdampedLangevinX, UnderdampedLangevinX, None]: + ) -> tuple[UnderdampedLangevinX, UnderdampedLangevinX, None, None]: del prev_f dtypes = jtu.tree_map(jnp.result_type, x0) w: UnderdampedLangevinX = jtu.tree_map(jnp.asarray, levy.W, dtypes) diff --git a/diffrax/_solver/should.py b/diffrax/_solver/should.py index 6d8b0cb0..caab54d3 100644 --- a/diffrax/_solver/should.py +++ b/diffrax/_solver/should.py @@ -59,8 +59,9 @@ def __init__(self, beta_half, a_half, b_half, beta1, a1, b1, aa, chh, ckk): class ShOULD(AbstractFosterLangevinSRK[_ShOULDCoeffs, None]): r"""The Shifted-ODE Runge-Kutta Three method designed by James Foster. This is a third order solver for the - Underdamped Langevin Diffusion, the terms for which can be created using - [`diffrax.make_underdamped_langevin_term`][]. Uses three evaluations of the vector + Underdamped Langevin Diffusion, the terms of the form + `MultiTerm(UnderdampedLangevinDriftTerm, UnderdampedLangevinDiffusionTerm)`. + Uses three evaluations of the vector field per step, but is FSAL, so in practice it only requires two. ??? cite "Reference"