Skip to content

Commit

Permalink
Fix minor complex-related problems
Browse files Browse the repository at this point in the history
  • Loading branch information
Randl committed Oct 29, 2023
1 parent 1d67804 commit 28c77fc
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
4 changes: 2 additions & 2 deletions diffrax/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ def _rms_norm_jvp(x, tx):
pred = (out == 0) | jnp.isinf(out)
numerator = jnp.where(pred, 0, x)
denominator = jnp.where(pred, 1, out * x.size)
t_out = jnp.dot(numerator / denominator, tx)
return out, t_out
t_out = jnp.dot(numerator / denominator, jnp.conj(tx))
return out, jnp.real(t_out)


def adjoint_rms_seminorm(x: Tuple[PyTree, PyTree, PyTree, PyTree]) -> Scalar:
Expand Down
4 changes: 3 additions & 1 deletion diffrax/nonlinear_solver/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,6 @@ def jac(fn: Callable, x: PyTree, args: PyTree) -> LU_Jacobian:
if not jnp.issubdtype(flat, jnp.inexact):
# Handle integer arguments
flat = flat.astype(jnp.float32)
return jsp.linalg.lu_factor(jax.jacfwd(curried)(flat))
return jsp.linalg.lu_factor(
jax.jacfwd(curried, holomorphic=jnp.iscomplexobj(flat))(flat)
)

0 comments on commit 28c77fc

Please sign in to comment.