Skip to content

Commit

Permalink
got Rid of internally using LevyVal in VBT
Browse files Browse the repository at this point in the history
  • Loading branch information
andyElking authored and patrick-kidger committed Apr 20, 2024
1 parent fbd83de commit 28c19b3
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 30 deletions.
4 changes: 2 additions & 2 deletions diffrax/_brownian/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def evaluate(
self.shape,
)
if use_levy:
out = levy_tree_transpose(self.shape, self.levy_area, out)
out = levy_tree_transpose(self.shape, out)
assert isinstance(out, LevyVal)
return out

Expand All @@ -137,7 +137,7 @@ def _evaluate_leaf(
w = jr.normal(key, shape.shape, shape.dtype) * w_std

if use_levy:
return LevyVal(dt=t1 - t0, W=w, H=hh, bar_H=None, K=None, bar_K=None)
return LevyVal(dt=t1 - t0, W=w, H=hh, K=None)
else:
return w

Expand Down
69 changes: 44 additions & 25 deletions diffrax/_brownian/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,13 @@ class _State(eqx.Module):
bkk_s_u_su: Optional[FloatTriple] # \bar{K}_s, _u, _{s,u}


def _levy_diff(x0: LevyVal, x1: LevyVal) -> LevyVal:
def _levy_diff(_, x0: tuple, x1: tuple) -> LevyVal:
r"""Computes $(W_{s,u}, H_{s,u})$ from $(W_s, \bar{H}_{s,u})$ and
$(W_u, \bar{H}_u)$, where $\bar{H}_u = u * H_u$.
**Arguments:**
- `_`: unused, for the purposes of aligning the `jtu.tree_map`.
- `x0`: `LevyVal` at time `s`.
- `x1`: `LevyVal` at time `u`.
Expand All @@ -83,19 +84,39 @@ def _levy_diff(x0: LevyVal, x1: LevyVal) -> LevyVal:
`LevyVal(W_su, H_su)`
"""

su = jnp.asarray(x1.dt - x0.dt, dtype=x0.W.dtype)
w_su = x1.W - x0.W
if x0.H is None or x1.H is None: # BM only case
return LevyVal(dt=su, W=w_su, H=None, bar_H=None, K=None, bar_K=None)

assert (x0.bar_H is not None) and (x1.bar_H is not None)
# if we are at this point levy_area == "space-time"
_su = jnp.where(jnp.abs(su) < jnp.finfo(su).eps, jnp.inf, su)
inverse_su = 1 / _su
u_bb_s = x1.dt * x0.W - x0.dt * x1.W
bhh_su = x1.bar_H - x0.bar_H - 0.5 * u_bb_s # bhh_su = H_{s,u} * (u-s)
hh_su = inverse_su * bhh_su
return LevyVal(dt=su, W=w_su, H=hh_su, bar_H=None, K=None, bar_K=None)
if len(x0) == 2: # BM only case
assert len(x1) == 2
dt0, w0 = x0
dt1, w1 = x1
su = jnp.asarray(dt1 - dt0, dtype=w0.dtype)
return LevyVal(dt=su, W=w1 - w0, H=None, K=None)

elif len(x0) == 4: # space-time levy area case
assert len(x1) == 4
dt0, w0, hh0, bhh0 = x0
dt1, w1, hh1, bhh1 = x1

w_su = w1 - w0
su = jnp.asarray(dt1 - dt0, dtype=w0.dtype)
_su = jnp.where(jnp.abs(su) < jnp.finfo(su).eps, jnp.inf, su)
inverse_su = 1 / _su
u_bb_s = dt1 * w0 - dt0 * w1
bhh_su = bhh1 - bhh0 - 0.5 * u_bb_s # bhh_su = H_{s,u} * (u-s)
hh_su = inverse_su * bhh_su
return LevyVal(dt=su, W=w_su, H=hh_su, K=None)
else:
assert False


def _make_levy_val(_, x: tuple) -> LevyVal:
if len(x) == 2:
dt, w = x
return LevyVal(dt=dt, W=w, H=None, K=None)
elif len(x) == 4:
dt, w, hh, bhh = x
return LevyVal(dt=dt, W=w, H=hh, K=None)
else:
assert False


def _split_interval(
Expand Down Expand Up @@ -210,9 +231,7 @@ def sqrt_mult(z):
dt=jtu.tree_map(mult, x.dt),
W=jtu.tree_map(sqrt_mult, x.W),
H=jtu.tree_map(sqrt_mult, x.H),
bar_H=None,
K=jtu.tree_map(sqrt_mult, x.K),
bar_K=None,
)

@eqx.filter_jit
Expand All @@ -223,30 +242,28 @@ def evaluate(
left: bool = True,
use_levy: bool = False,
) -> Union[PyTree[Array], LevyVal]:
def _is_levy_val(obj):
return isinstance(obj, LevyVal)

t0 = eqxi.nondifferentiable(t0, name="t0")
# map the interval [self.t0, self.t1] onto [0,1]
t0 = linear_rescale(self.t0, t0, self.t1)
levy_0 = self._evaluate(t0)
if t1 is None:
levy_out = levy_0
levy_out = jtu.tree_map(_make_levy_val, self.shape, levy_out)

else:
t1 = eqxi.nondifferentiable(t1, name="t1")
# map the interval [self.t0, self.t1] onto [0,1]
t1 = linear_rescale(self.t0, t1, self.t1)
levy_1 = self._evaluate(t1)
levy_out = jtu.tree_map(_levy_diff, levy_0, levy_1, is_leaf=_is_levy_val)
levy_out = jtu.tree_map(_levy_diff, self.shape, levy_0, levy_1)

levy_out = levy_tree_transpose(self.shape, self.levy_area, levy_out)
levy_out = levy_tree_transpose(self.shape, levy_out)
# now map [0,1] back onto [self.t0, self.t1]
levy_out = self._denormalise_bm_inc(levy_out)
assert isinstance(levy_out, LevyVal)
return levy_out if use_levy else levy_out.W

def _evaluate(self, r: RealScalarLike) -> PyTree[LevyVal]:
def _evaluate(self, r: RealScalarLike) -> PyTree:
"""Maps the _evaluate_leaf function at time r using self.key onto self.shape"""
r = eqxi.error_if(
r,
Expand All @@ -261,7 +278,9 @@ def _evaluate_leaf(
key,
r: RealScalarLike,
struct: jax.ShapeDtypeStruct,
) -> LevyVal:
) -> Union[
tuple[RealScalarLike, Array], tuple[RealScalarLike, Array, Array, Array]
]:
shape, dtype = struct.shape, struct.dtype

t0 = jnp.zeros((), dtype)
Expand Down Expand Up @@ -355,7 +374,7 @@ def _body_fun(_state: _State):
else:
assert False
w_r = w_mean + bb
return LevyVal(dt=r, W=w_r, H=None, bar_H=None, K=None, bar_K=None)
return r, w_r

elif self.levy_area == "space-time":
# This is based on Theorem 6.1.4 of Foster's thesis (see above).
Expand Down Expand Up @@ -399,7 +418,7 @@ def _body_fun(_state: _State):
else:
assert False

return LevyVal(dt=r, W=w_r, H=hh_r, bar_H=bhh_r, K=None, bar_K=None)
return r, w_r, hh_r, bhh_r

def _brownian_arch(
self, _state: _State, shape, dtype
Expand Down
4 changes: 1 addition & 3 deletions diffrax/_custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,10 @@ class LevyVal(eqx.Module):
dt: PyTree
W: PyTree
H: Optional[PyTree]
bar_H: Optional[PyTree]
K: Optional[PyTree]
bar_K: Optional[PyTree]


def levy_tree_transpose(tree_shape, levy_area: LevyArea, tree: PyTree):
def levy_tree_transpose(tree_shape, tree: PyTree):
"""Helper that takes a PyTree of LevyVals and transposes
into a LevyVal of PyTrees.
Expand Down

0 comments on commit 28c19b3

Please sign in to comment.