From 61d1c91ab52dad2ac1e6b06cba93e1469f47f370 Mon Sep 17 00:00:00 2001 From: Evgenii Zheltonozhskii Date: Mon, 1 Jul 2024 11:31:25 +0300 Subject: [PATCH] Complex fixes in SDEs --- diffrax/_solver/milstein.py | 4 ++-- diffrax/_solver/srk.py | 4 ++-- diffrax/_term.py | 2 +- test/helpers.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/diffrax/_solver/milstein.py b/diffrax/_solver/milstein.py index 3d14e343..e5853762 100644 --- a/diffrax/_solver/milstein.py +++ b/diffrax/_solver/milstein.py @@ -211,7 +211,7 @@ def step( leaves_ΔwΔw = [] for i1, l1 in enumerate(leaves_Δw): for i2, l2 in enumerate(leaves_Δw): - leaf = jnp.tensordot(l1[..., None], l2[None, ...], axes=1) + leaf = jnp.tensordot(jnp.conj(l1[..., None]), l2[None, ...], axes=1) if i1 == i2: eye = jnp.eye(l1.size).reshape(l1.shape + l1.shape) with jax.numpy_dtype_promotion("standard"): @@ -305,7 +305,7 @@ def _to_treemap(_Δw, _g0): def __dot(_v0, _ΔwΔw): # _v0 has structure (leaf(y0), leaf(Δw), leaf(Δw)) # _ΔwΔw has structure (leaf(Δw), leaf(Δw)) - _out = jnp.tensordot(_v0, _ΔwΔw, axes=jnp.ndim(_ΔwΔw)) + _out = jnp.tensordot(jnp.conj(_v0), _ΔwΔw, axes=jnp.ndim(_ΔwΔw)) # _out has structure (leaf(y0),) return _out diff --git a/diffrax/_solver/srk.py b/diffrax/_solver/srk.py index 8ac53715..77c1d53e 100644 --- a/diffrax/_solver/srk.py +++ b/diffrax/_solver/srk.py @@ -407,7 +407,7 @@ def _comp_g(_t): g0_g1 = _comp_g(jnp.array([t0, t1], dtype=complex_to_real_dtype(dtype))) g0 = jtu.tree_map(lambda g_leaf: g_leaf[0], g0_g1) - # g_delta = 0.5 * g1 - g0 + # g_delta = 0.5 * (g1 - g0) g_delta = jtu.tree_map(lambda g_leaf: 0.5 * (g_leaf[1] - g_leaf[0]), g0_g1) w_kgs = diffusion.prod(g0, w) a_w = jnp.asarray(self.tableau.coeffs_w.a, dtype=dtype) @@ -456,7 +456,7 @@ def sum_prev_stages(_stage_out_buff, _a_j): ) # Sum up the previous stages weighted by the coefficients in the tableau return jtu.tree_map( - lambda lf: jnp.tensordot(_a_j, lf, axes=1), _stage_out_view + lambda lf: jnp.tensordot(jnp.conj(_a_j), lf, axes=1), _stage_out_view ) def insert_jth_stage(results, k_j, j): diff --git a/diffrax/_term.py b/diffrax/_term.py index 8d031b95..d43f951b 100644 --- a/diffrax/_term.py +++ b/diffrax/_term.py @@ -257,7 +257,7 @@ def _callable_to_path( # control: Shaped[Array, "*control"] # return: Shaped[Array, "*state"] def _prod(vf, control): - return jnp.tensordot(vf, control, axes=jnp.ndim(control)) + return jnp.tensordot(jnp.conj(vf), control, axes=jnp.ndim(control)) # This class exists for backward compatibility with `WeaklyDiagonalControlTerm`. If we diff --git a/test/helpers.py b/test/helpers.py index fd4097dd..749d35a1 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -100,7 +100,7 @@ def path_l2_dist( # and the length of saveat). Also sum all the PyTree leaves. def sum_square_diff(y1, y2): with jax.numpy_dtype_promotion("standard"): - square_diff = jnp.square(y1 - y2) + square_diff = jnp.square(jnp.abs(y1 - y2)) # sum all but the first two axes axes = range(2, square_diff.ndim) out = jnp.sum(square_diff, axis=axes)