Skip to content

Commit

Permalink
Nit fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Randl committed Oct 31, 2023
1 parent 28c77fc commit bdda361
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 10 deletions.
7 changes: 0 additions & 7 deletions diffrax/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from .saveat import SaveAt, SubSaveAt
from .solution import is_okay, is_successful, RESULTS, Solution
from .solver import (
AbstractImplicitSolver,
AbstractItoSolver,
AbstractSolver,
AbstractStratonovichSolver,
Expand Down Expand Up @@ -608,11 +607,6 @@ def diffeqsolve(

# Error checking and warning for complex dtypes
if any(jtu.tree_leaves(jtu.tree_map(jnp.iscomplexobj, y0))):
if isinstance(solver, AbstractImplicitSolver):
raise ValueError(
"Implicit solvers in conjunction with complex dtypes is currently not "
"supported."
)
warnings.warn(
"Complex dtype support is work in progress, please read "
"https://github.com/patrick-kidger/diffrax/pull/197 and proceed carefully."
Expand Down Expand Up @@ -667,7 +661,6 @@ def diffeqsolve(
t1 = jnp.asarray(t1, dtype=dtype)
if dt0 is not None:
dt0 = jnp.asarray(dt0, dtype=dtype)
timelikes.append(dtype)

def _get_subsaveat_ts(saveat):
out = [s.ts for s in jtu.tree_leaves(saveat.subs, is_leaf=_is_subsaveat)]
Expand Down
3 changes: 0 additions & 3 deletions test/test_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,6 @@ def test_basic(solver, t_dtype, y_dtype, treedef, stepsize_controller, getkey):
def f(t, y, args):
return jtu.tree_map(lambda _y: operator.mul(-1j, _y), y)

if isinstance(solver, diffrax.AbstractImplicitSolver):
return

else:

def f(t, y, args):
Expand Down

0 comments on commit bdda361

Please sign in to comment.