Skip to content

Commit

Permalink
_integrate.py
Browse files Browse the repository at this point in the history
  • Loading branch information
LuggiStruggi authored Jan 14, 2025
1 parent 134a40a commit d96572e
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,8 @@ def body_fun(state):
event_mask = final_state.event_mask
flat_mask = jtu.tree_leaves(event_mask)
assert all(jnp.shape(x) == () for x in flat_mask)
event_happened = jnp.any(jnp.stack(flat_mask))
float_mask = jnp.array(flat_mask).astype(jnp.float32)
event_happened = jnp.max(float_mask) > 0.0

def _root_find():
_interpolator = solver.interpolation_cls(
Expand Down

0 comments on commit d96572e

Please sign in to comment.