From 2426cd79443409535478af5cef959ba198f2a6cf Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Thu, 22 Aug 2024 22:42:50 -0700 Subject: [PATCH] a --- diffrax/_solver/stochastic_theta.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/diffrax/_solver/stochastic_theta.py b/diffrax/_solver/stochastic_theta.py index 042681e4..99e8d7db 100644 --- a/diffrax/_solver/stochastic_theta.py +++ b/diffrax/_solver/stochastic_theta.py @@ -17,9 +17,16 @@ def _implicit_relation(z1, nonlinear_solve_args): - vf_prod_drift, t1, y0, args, control, k0_drift, k0_diff, theta = ( - nonlinear_solve_args - ) + ( + vf_prod_drift, + t1, + y0, + args, + control, + k0_drift, + k0_diff, + theta, + ) = nonlinear_solve_args add_state = (y0**ω + z1**ω).ω implicit_drift = (vf_prod_drift(t1, add_state, args, control) ** ω * theta).ω euler_drift = ((1 - theta) * k0_drift**ω).ω @@ -61,9 +68,9 @@ class StochasticTheta( theta: float term_structure: ClassVar = MultiTerm[tuple[ODETerm, AbstractTerm]] - interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = ( - LocalLinearInterpolation - ) + interpolation_cls: ClassVar[ + Callable[..., LocalLinearInterpolation] + ] = LocalLinearInterpolation root_finder: optx.AbstractRootFinder = with_stepsize_controller_tols(optx.Chord)() root_find_max_steps: int = 10