diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index 6e5d278c..cacc1070 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -119,6 +119,7 @@ def _is_none(x: Any) -> bool: def _assert_term_compatible( + t: FloatScalarLike, y: PyTree[ArrayLike], args: PyTree[Any], terms: PyTree[AbstractTerm], @@ -138,7 +139,7 @@ def _check(term_cls, term, term_contr_kwargs, yi): for term, arg, term_contr_kwarg in zip( term.terms, get_args(_tmp), term_contr_kwargs ): - _assert_term_compatible(yi, args, term, arg, term_contr_kwarg) + _assert_term_compatible(t, yi, args, term, arg, term_contr_kwarg) else: raise ValueError( f"Term {term} is not a MultiTerm but is expected to be." @@ -166,7 +167,7 @@ def _check(term_cls, term, term_contr_kwargs, yi): elif n_term_args == 2: vf_type_expected, control_type_expected = term_args try: - vf_type = eqx.filter_eval_shape(term.vf, 0.0, yi, args) + vf_type = eqx.filter_eval_shape(term.vf, t, yi, args) except Exception as e: raise ValueError(f"Error while tracing {term}.vf: " + str(e)) vf_type_compatible = eqx.filter_eval_shape( @@ -178,7 +179,7 @@ def _check(term_cls, term, term_contr_kwargs, yi): contr = ft.partial(term.contr, **term_contr_kwargs) # Work around https://github.com/google/jax/issues/21825 try: - control_type = eqx.filter_eval_shape(contr, 0.0, 0.0) + control_type = eqx.filter_eval_shape(contr, t, t) except Exception as e: raise ValueError(f"Error while tracing {term}.contr: " + str(e)) control_type_compatible = eqx.filter_eval_shape( @@ -1077,6 +1078,7 @@ def _promote(yi): if isinstance(solver, (EulerHeun, ItoMilstein, StratonovichMilstein)): try: _assert_term_compatible( + t0, y0, args, terms, @@ -1098,6 +1100,7 @@ def _promote(yi): # Error checking for term compatibility _assert_term_compatible( + t0, y0, args, terms,