Skip to content

Commit

Permalink
Fix issue #563: time t0 instead of 0 is passed to _check in _assert_t…
Browse files Browse the repository at this point in the history
…erm_compatible, so the term is not required to be well-defined at time 0, but rather at time t0.
  • Loading branch information
FloList authored and patrick-kidger committed Jan 13, 2025
1 parent 4a308b8 commit 134a40a
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def _is_none(x: Any) -> bool:


def _assert_term_compatible(
t: FloatScalarLike,
y: PyTree[ArrayLike],
args: PyTree[Any],
terms: PyTree[AbstractTerm],
Expand All @@ -138,7 +139,7 @@ def _check(term_cls, term, term_contr_kwargs, yi):
for term, arg, term_contr_kwarg in zip(
term.terms, get_args(_tmp), term_contr_kwargs
):
_assert_term_compatible(yi, args, term, arg, term_contr_kwarg)
_assert_term_compatible(t, yi, args, term, arg, term_contr_kwarg)
else:
raise ValueError(
f"Term {term} is not a MultiTerm but is expected to be."
Expand Down Expand Up @@ -166,7 +167,7 @@ def _check(term_cls, term, term_contr_kwargs, yi):
elif n_term_args == 2:
vf_type_expected, control_type_expected = term_args
try:
vf_type = eqx.filter_eval_shape(term.vf, 0.0, yi, args)
vf_type = eqx.filter_eval_shape(term.vf, t, yi, args)
except Exception as e:
raise ValueError(f"Error while tracing {term}.vf: " + str(e))
vf_type_compatible = eqx.filter_eval_shape(
Expand All @@ -178,7 +179,7 @@ def _check(term_cls, term, term_contr_kwargs, yi):
contr = ft.partial(term.contr, **term_contr_kwargs)
# Work around https://github.com/google/jax/issues/21825
try:
control_type = eqx.filter_eval_shape(contr, 0.0, 0.0)
control_type = eqx.filter_eval_shape(contr, t, t)
except Exception as e:
raise ValueError(f"Error while tracing {term}.contr: " + str(e))
control_type_compatible = eqx.filter_eval_shape(
Expand Down Expand Up @@ -1077,6 +1078,7 @@ def _promote(yi):
if isinstance(solver, (EulerHeun, ItoMilstein, StratonovichMilstein)):
try:
_assert_term_compatible(
t0,
y0,
args,
terms,
Expand All @@ -1098,6 +1100,7 @@ def _promote(yi):

# Error checking for term compatibility
_assert_term_compatible(
t0,
y0,
args,
terms,
Expand Down

0 comments on commit 134a40a

Please sign in to comment.