Skip to content

Commit

Permalink
Fixed issue
Browse files Browse the repository at this point in the history
  • Loading branch information
LuggiStruggi authored Jan 15, 2025
1 parent a4aeb42 commit a3c1167
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions test/test_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,16 +801,15 @@ def dynamics(t, y, args):
param = args
return param - y


def event_fn(t, y, args, **kwargs):
return y - 1.5

def single_loss_fn(param):
solver = diffrax.Euler()
root_finder = optx.Newton(1e-2, 1e-2, optx.rms_norm)
event = diffrax.Event(event_fn, root_finder)
term = diffrax.ODETerm(dynamics)

term = diffrax.ODETerm(dynamics)
sol = diffrax.diffeqsolve(
term,
solver=solver,
Expand All @@ -831,9 +830,11 @@ def batched_loss_fn(params: jnp.ndarray) -> jnp.ndarray:

def grad_fn(params: jnp.ndarray) -> jnp.ndarray:
return jax.grad(lambda p: jnp.sum(batched_loss_fn(p)))(params)

batch = jnp.array([1.0, 2.0, 3.0])

try:
grad = grad_fn(y0)
grad = grad_fn(batch)
except NotImplementedError as e:
pytest.fail(f"NotImplementedError was raised: {e}")
except Exception as e:
Expand Down

0 comments on commit a3c1167

Please sign in to comment.