Skip to content

Commit

Permalink
Replaced abs with relu, and updated some comments
Browse files Browse the repository at this point in the history
  • Loading branch information
andyElking authored and patrick-kidger committed Dec 20, 2023
1 parent 00d7af2 commit 40402de
Showing 1 changed file with 11 additions and 15 deletions.
26 changes: 11 additions & 15 deletions diffrax/_brownian/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,7 @@ def _split_interval(


class VirtualBrownianTree(AbstractBrownianPath):
"""Brownian simulation that discretises the interval `[t0, t1]` to tolerance `tol`,
and is piecewise quadratic at that discretisation.
"""Brownian simulation that discretises the interval `[t0, t1]` to tolerance `tol`.
Can be initialised with `levy_area` set to `""`, or `"space-time"`.
If `levy_area="space_time"`, then it also computes space-time Lévy area `H`.
Expand Down Expand Up @@ -267,16 +266,12 @@ def _evaluate_leaf(
)

def _cond_fun(_state):
# Slight adaptation on the version of the algorithm given in the
# above-referenced thesis. There the returned value is snapped to one of
# the dyadic grid points, so they just stop once
# jnp.abs(τ - state.s) > self.tol
# Here, because we use quadratic splines to get better samples, we always
# iterate down to the level of the spline.
"""Condition for the binary search for r."""
# If true, continue splitting the interval and descending the tree.
return 2.0 ** (-_state.level) > self.tol

def _body_fun(_state: _State):
"""Single-step of binary search for r."""
"""Single-step of the binary search for r."""

(
_t,
Expand Down Expand Up @@ -318,15 +313,16 @@ def _body_fun(_state: _State):
s = final_state.s
su = 2.0**-final_state.level

sr = r - s
ru = su - sr # make sure su = sr + ru regardless of cancellation error
sr = jax.nn.relu(r - s)
# make sure su = sr + ru regardless of cancellation error
ru = jax.nn.relu(su - sr)

w_s, w_u, w_su = final_state.w_s_u_su

# BM only case
if self.levy_area == "":
z = jr.normal(final_state.key, shape, dtype)
w_sr = sr / su * w_su + jnp.sqrt(jnp.abs(sr * ru / su)) * z
w_sr = sr / su * w_su + jnp.sqrt(sr * ru / su) * z
w_r = w_s + w_sr
return LevyVal(dt=r, W=w_r, H=None, bar_H=None, K=None, bar_K=None)

Expand All @@ -340,15 +336,15 @@ def _body_fun(_state: _State):
x1 = jr.normal(key1, shape, dtype)
x2 = jr.normal(key2, shape, dtype)

sr_ru_half = jnp.sqrt(jnp.abs(sr * ru))
d = jnp.sqrt(jnp.abs(sr3 + ru3))
sr_ru_half = jnp.sqrt(sr * ru)
d = jnp.sqrt(sr3 + ru3)
d_prime = 1 / (2 * su * d)
a = d_prime * sr3 * sr_ru_half
b = d_prime * ru3 * sr_ru_half

w_sr = sr / su * w_su + 6 * sr * ru / su3 * bhh_su + 2 * (a + b) / su * x1
w_r = w_s + w_sr
c = jnp.sqrt(jnp.abs(3 * sr3 * ru3)) / (6 * d)
c = jnp.sqrt(3 * sr3 * ru3) / (6 * d)
bhh_sr = sr3 / su3 * bhh_su - a * x1 + c * x2
bhh_r = bhh_s + bhh_sr + 0.5 * (r * w_s - s * w_r)

Expand Down

0 comments on commit 40402de

Please sign in to comment.