Skip to content

Commit

Permalink
Complex fixes in SDEs
Browse files Browse the repository at this point in the history
  • Loading branch information
Randl authored and patrick-kidger committed Jul 13, 2024
1 parent d6d09dc commit 61d1c91
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions diffrax/_solver/milstein.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def step(
leaves_ΔwΔw = []
for i1, l1 in enumerate(leaves_Δw):
for i2, l2 in enumerate(leaves_Δw):
leaf = jnp.tensordot(l1[..., None], l2[None, ...], axes=1)
leaf = jnp.tensordot(jnp.conj(l1[..., None]), l2[None, ...], axes=1)
if i1 == i2:
eye = jnp.eye(l1.size).reshape(l1.shape + l1.shape)
with jax.numpy_dtype_promotion("standard"):
Expand Down Expand Up @@ -305,7 +305,7 @@ def _to_treemap(_Δw, _g0):
def __dot(_v0, _ΔwΔw):
# _v0 has structure (leaf(y0), leaf(Δw), leaf(Δw))
# _ΔwΔw has structure (leaf(Δw), leaf(Δw))
_out = jnp.tensordot(_v0, _ΔwΔw, axes=jnp.ndim(_ΔwΔw))
_out = jnp.tensordot(jnp.conj(_v0), _ΔwΔw, axes=jnp.ndim(_ΔwΔw))
# _out has structure (leaf(y0),)
return _out

Expand Down
4 changes: 2 additions & 2 deletions diffrax/_solver/srk.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def _comp_g(_t):

g0_g1 = _comp_g(jnp.array([t0, t1], dtype=complex_to_real_dtype(dtype)))
g0 = jtu.tree_map(lambda g_leaf: g_leaf[0], g0_g1)
# g_delta = 0.5 * g1 - g0
# g_delta = 0.5 * (g1 - g0)
g_delta = jtu.tree_map(lambda g_leaf: 0.5 * (g_leaf[1] - g_leaf[0]), g0_g1)
w_kgs = diffusion.prod(g0, w)
a_w = jnp.asarray(self.tableau.coeffs_w.a, dtype=dtype)
Expand Down Expand Up @@ -456,7 +456,7 @@ def sum_prev_stages(_stage_out_buff, _a_j):
)
# Sum up the previous stages weighted by the coefficients in the tableau
return jtu.tree_map(
lambda lf: jnp.tensordot(_a_j, lf, axes=1), _stage_out_view
lambda lf: jnp.tensordot(jnp.conj(_a_j), lf, axes=1), _stage_out_view
)

def insert_jth_stage(results, k_j, j):
Expand Down
2 changes: 1 addition & 1 deletion diffrax/_term.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def _callable_to_path(
# control: Shaped[Array, "*control"]
# return: Shaped[Array, "*state"]
def _prod(vf, control):
return jnp.tensordot(vf, control, axes=jnp.ndim(control))
return jnp.tensordot(jnp.conj(vf), control, axes=jnp.ndim(control))


# This class exists for backward compatibility with `WeaklyDiagonalControlTerm`. If we
Expand Down
2 changes: 1 addition & 1 deletion test/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def path_l2_dist(
# and the length of saveat). Also sum all the PyTree leaves.
def sum_square_diff(y1, y2):
with jax.numpy_dtype_promotion("standard"):
square_diff = jnp.square(y1 - y2)
square_diff = jnp.square(jnp.abs(y1 - y2))
# sum all but the first two axes
axes = range(2, square_diff.ndim)
out = jnp.sum(square_diff, axis=axes)
Expand Down

0 comments on commit 61d1c91

Please sign in to comment.