diff --git a/diffrax/_brownian/tree.py b/diffrax/_brownian/tree.py index a75df2da..5083948b 100644 --- a/diffrax/_brownian/tree.py +++ b/diffrax/_brownian/tree.py @@ -103,8 +103,7 @@ def _split_interval( class VirtualBrownianTree(AbstractBrownianPath): - """Brownian simulation that discretises the interval `[t0, t1]` to tolerance `tol`, - and is piecewise quadratic at that discretisation. + """Brownian simulation that discretises the interval `[t0, t1]` to tolerance `tol`. Can be initialised with `levy_area` set to `""`, or `"space-time"`. If `levy_area="space_time"`, then it also computes space-time Lévy area `H`. @@ -267,16 +266,12 @@ def _evaluate_leaf( ) def _cond_fun(_state): - # Slight adaptation on the version of the algorithm given in the - # above-referenced thesis. There the returned value is snapped to one of - # the dyadic grid points, so they just stop once - # jnp.abs(τ - state.s) > self.tol - # Here, because we use quadratic splines to get better samples, we always - # iterate down to the level of the spline. + """Condition for the binary search for r.""" + # If true, continue splitting the interval and descending the tree. return 2.0 ** (-_state.level) > self.tol def _body_fun(_state: _State): - """Single-step of binary search for r.""" + """Single-step of the binary search for r.""" ( _t, @@ -318,15 +313,16 @@ def _body_fun(_state: _State): s = final_state.s su = 2.0**-final_state.level - sr = r - s - ru = su - sr # make sure su = sr + ru regardless of cancellation error + sr = jax.nn.relu(r - s) + # make sure su = sr + ru regardless of cancellation error + ru = jax.nn.relu(su - sr) w_s, w_u, w_su = final_state.w_s_u_su # BM only case if self.levy_area == "": z = jr.normal(final_state.key, shape, dtype) - w_sr = sr / su * w_su + jnp.sqrt(jnp.abs(sr * ru / su)) * z + w_sr = sr / su * w_su + jnp.sqrt(sr * ru / su) * z w_r = w_s + w_sr return LevyVal(dt=r, W=w_r, H=None, bar_H=None, K=None, bar_K=None) @@ -340,15 +336,15 @@ def _body_fun(_state: _State): x1 = jr.normal(key1, shape, dtype) x2 = jr.normal(key2, shape, dtype) - sr_ru_half = jnp.sqrt(jnp.abs(sr * ru)) - d = jnp.sqrt(jnp.abs(sr3 + ru3)) + sr_ru_half = jnp.sqrt(sr * ru) + d = jnp.sqrt(sr3 + ru3) d_prime = 1 / (2 * su * d) a = d_prime * sr3 * sr_ru_half b = d_prime * ru3 * sr_ru_half w_sr = sr / su * w_su + 6 * sr * ru / su3 * bhh_su + 2 * (a + b) / su * x1 w_r = w_s + w_sr - c = jnp.sqrt(jnp.abs(3 * sr3 * ru3)) / (6 * d) + c = jnp.sqrt(3 * sr3 * ru3) / (6 * d) bhh_sr = sr3 / su3 * bhh_su - a * x1 + c * x2 bhh_r = bhh_s + bhh_sr + 0.5 * (r * w_s - s * w_r)