From 2426cd79443409535478af5cef959ba198f2a6cf Mon Sep 17 00:00:00 2001
From: Owen Lockwood <42878312+lockwo@users.noreply.github.com>
Date: Thu, 22 Aug 2024 22:42:50 -0700
Subject: [PATCH] a

---
 diffrax/_solver/stochastic_theta.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/diffrax/_solver/stochastic_theta.py b/diffrax/_solver/stochastic_theta.py
index 042681e4..99e8d7db 100644
--- a/diffrax/_solver/stochastic_theta.py
+++ b/diffrax/_solver/stochastic_theta.py
@@ -17,9 +17,16 @@
 
 
 def _implicit_relation(z1, nonlinear_solve_args):
-    vf_prod_drift, t1, y0, args, control, k0_drift, k0_diff, theta = (
-        nonlinear_solve_args
-    )
+    (
+        vf_prod_drift,
+        t1,
+        y0,
+        args,
+        control,
+        k0_drift,
+        k0_diff,
+        theta,
+    ) = nonlinear_solve_args
     add_state = (y0**ω + z1**ω).ω
     implicit_drift = (vf_prod_drift(t1, add_state, args, control) ** ω * theta).ω
     euler_drift = ((1 - theta) * k0_drift**ω).ω
@@ -61,9 +68,9 @@ class StochasticTheta(
 
     theta: float
     term_structure: ClassVar = MultiTerm[tuple[ODETerm, AbstractTerm]]
-    interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = (
-        LocalLinearInterpolation
-    )
+    interpolation_cls: ClassVar[
+        Callable[..., LocalLinearInterpolation]
+    ] = LocalLinearInterpolation
     root_finder: optx.AbstractRootFinder = with_stepsize_controller_tols(optx.Chord)()
     root_find_max_steps: int = 10