Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replaced abs with relu, and updated some comments #340

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 11 additions & 15 deletions diffrax/_brownian/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,7 @@ def _split_interval(


class VirtualBrownianTree(AbstractBrownianPath):
"""Brownian simulation that discretises the interval `[t0, t1]` to tolerance `tol`,
and is piecewise quadratic at that discretisation.
"""Brownian simulation that discretises the interval `[t0, t1]` to tolerance `tol`.

Can be initialised with `levy_area` set to `""`, or `"space-time"`.
If `levy_area="space_time"`, then it also computes space-time Lévy area `H`.
Expand Down Expand Up @@ -267,16 +266,12 @@ def _evaluate_leaf(
)

def _cond_fun(_state):
# Slight adaptation on the version of the algorithm given in the
# above-referenced thesis. There the returned value is snapped to one of
# the dyadic grid points, so they just stop once
# jnp.abs(τ - state.s) > self.tol
# Here, because we use quadratic splines to get better samples, we always
# iterate down to the level of the spline.
"""Condition for the binary search for r."""
# If true, continue splitting the interval and descending the tree.
return 2.0 ** (-_state.level) > self.tol

def _body_fun(_state: _State):
"""Single-step of binary search for r."""
"""Single-step of the binary search for r."""

(
_t,
Expand Down Expand Up @@ -318,15 +313,16 @@ def _body_fun(_state: _State):
s = final_state.s
su = 2.0**-final_state.level

sr = r - s
ru = su - sr # make sure su = sr + ru regardless of cancellation error
sr = jax.nn.relu(r - s)
# make sure su = sr + ru regardless of cancellation error
ru = jax.nn.relu(su - sr)

w_s, w_u, w_su = final_state.w_s_u_su

# BM only case
if self.levy_area == "":
z = jr.normal(final_state.key, shape, dtype)
w_sr = sr / su * w_su + jnp.sqrt(jnp.abs(sr * ru / su)) * z
w_sr = sr / su * w_su + jnp.sqrt(sr * ru / su) * z
w_r = w_s + w_sr
return LevyVal(dt=r, W=w_r, H=None, bar_H=None, K=None, bar_K=None)

Expand All @@ -340,15 +336,15 @@ def _body_fun(_state: _State):
x1 = jr.normal(key1, shape, dtype)
x2 = jr.normal(key2, shape, dtype)

sr_ru_half = jnp.sqrt(jnp.abs(sr * ru))
d = jnp.sqrt(jnp.abs(sr3 + ru3))
sr_ru_half = jnp.sqrt(sr * ru)
d = jnp.sqrt(sr3 + ru3)
d_prime = 1 / (2 * su * d)
a = d_prime * sr3 * sr_ru_half
b = d_prime * ru3 * sr_ru_half

w_sr = sr / su * w_su + 6 * sr * ru / su3 * bhh_su + 2 * (a + b) / su * x1
w_r = w_s + w_sr
c = jnp.sqrt(jnp.abs(3 * sr3 * ru3)) / (6 * d)
c = jnp.sqrt(3 * sr3 * ru3) / (6 * d)
bhh_sr = sr3 / su3 * bhh_su - a * x1 + c * x2
bhh_r = bhh_s + bhh_sr + 0.5 * (r * w_s - s * w_r)

Expand Down
Loading