From 134a40ad351e18eb8c2e6afa556356596123a69c Mon Sep 17 00:00:00 2001 From: Florian List Date: Mon, 13 Jan 2025 10:22:53 +0100 Subject: [PATCH 1/2] Fix issue #563: time t0 instead of 0 is passed to _check in _assert_term_compatible, so the term is not required to be well-defined at time 0, but rather at time t0. --- diffrax/_integrate.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index 6e5d278c..cacc1070 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -119,6 +119,7 @@ def _is_none(x: Any) -> bool: def _assert_term_compatible( + t: FloatScalarLike, y: PyTree[ArrayLike], args: PyTree[Any], terms: PyTree[AbstractTerm], @@ -138,7 +139,7 @@ def _check(term_cls, term, term_contr_kwargs, yi): for term, arg, term_contr_kwarg in zip( term.terms, get_args(_tmp), term_contr_kwargs ): - _assert_term_compatible(yi, args, term, arg, term_contr_kwarg) + _assert_term_compatible(t, yi, args, term, arg, term_contr_kwarg) else: raise ValueError( f"Term {term} is not a MultiTerm but is expected to be." @@ -166,7 +167,7 @@ def _check(term_cls, term, term_contr_kwargs, yi): elif n_term_args == 2: vf_type_expected, control_type_expected = term_args try: - vf_type = eqx.filter_eval_shape(term.vf, 0.0, yi, args) + vf_type = eqx.filter_eval_shape(term.vf, t, yi, args) except Exception as e: raise ValueError(f"Error while tracing {term}.vf: " + str(e)) vf_type_compatible = eqx.filter_eval_shape( @@ -178,7 +179,7 @@ def _check(term_cls, term, term_contr_kwargs, yi): contr = ft.partial(term.contr, **term_contr_kwargs) # Work around https://github.com/google/jax/issues/21825 try: - control_type = eqx.filter_eval_shape(contr, 0.0, 0.0) + control_type = eqx.filter_eval_shape(contr, t, t) except Exception as e: raise ValueError(f"Error while tracing {term}.contr: " + str(e)) control_type_compatible = eqx.filter_eval_shape( @@ -1077,6 +1078,7 @@ def _promote(yi): if isinstance(solver, (EulerHeun, ItoMilstein, StratonovichMilstein)): try: _assert_term_compatible( + t0, y0, args, terms, @@ -1098,6 +1100,7 @@ def _promote(yi): # Error checking for term compatibility _assert_term_compatible( + t0, y0, args, terms, From cc0d4bca8fb4ad4a3892a89eb8d3e30cfccba833 Mon Sep 17 00:00:00 2001 From: Riccardo Orsi <104301293+ricor07@users.noreply.github.com> Date: Fri, 27 Dec 2024 11:52:13 +0100 Subject: [PATCH 2/2] Allowing args into grad_f for ULD --- diffrax/_term.py | 8 ++++---- test/test_term.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/diffrax/_term.py b/diffrax/_term.py index efa28d29..d13d430b 100644 --- a/diffrax/_term.py +++ b/diffrax/_term.py @@ -925,13 +925,13 @@ class UnderdampedLangevinDriftTerm(AbstractTerm): gamma: PyTree[ArrayLike] u: PyTree[ArrayLike] - grad_f: Callable[[UnderdampedLangevinX], UnderdampedLangevinX] + grad_f: Callable[[UnderdampedLangevinX, Args], UnderdampedLangevinX] def __init__( self, gamma: PyTree[ArrayLike], u: PyTree[ArrayLike], - grad_f: Callable[[UnderdampedLangevinX], UnderdampedLangevinX], + grad_f: Callable[[UnderdampedLangevinX, Args], UnderdampedLangevinX], ): r""" **Arguments:** @@ -942,7 +942,7 @@ def __init__( a scalar or a PyTree of the same shape as the position vector $x$. - `grad_f`: A callable representing the gradient of the potential function $f$. This callable should take a PyTree of the same shape as $x$ and - return a PyTree of the same shape. + an optional `args` argument, returning a PyTree of the same shape. """ self.gamma = gamma self.u = u @@ -963,7 +963,7 @@ def fun(_gamma, _u, _v, _f_x): vf_x = v try: - f_x = self.grad_f(x) + f_x = self.grad_f(x, args) # Pass args to grad_f vf_v = jtu.tree_map(fun, gamma, u, v, f_x) except ValueError: raise RuntimeError( diff --git a/test/test_term.py b/test/test_term.py index 5260db2c..8e8bf8be 100644 --- a/test/test_term.py +++ b/test/test_term.py @@ -158,3 +158,39 @@ def test_weaklydiagonal_deprecate(): _ = diffrax.WeaklyDiagonalControlTerm( lambda t, y, args: 0.0, lambda t0, t1: jnp.array(t1 - t0) ) + + +def test_underdamped_langevin_drift_term_args(): + """ + Test that the UnderdampedLangevinDriftTerm handles `args` in grad_f correctly. + """ + + # Mock gradient function that uses args + def mock_grad_f(x, args): + return jtu.tree_map(lambda xi, ai: xi + ai, x, args) + + # Mock data + gamma = jnp.array([0.1, 0.2, 0.3]) + u = jnp.array([0.4, 0.5, 0.6]) + x = jnp.array([1.0, 2.0, 3.0]) + v = jnp.array([0.1, 0.2, 0.3]) + args = jnp.array([0.7, 0.8, 0.9]) + y = (x, v) + + # Create instance of the drift term + term = diffrax.UnderdampedLangevinDriftTerm(gamma=gamma, u=u, grad_f=mock_grad_f) + + # Compute the vector field + vf_y = term.vf(0.0, y, args) + + # Extract results + vf_x, vf_v = vf_y + + # Expected results + expected_vf_x = v # By definition, vf_x = v + f_x = x + args # Output of mock_grad_f + expected_vf_v = -gamma * v - u * f_x # Drift term calculation + + # Assertions + assert jnp.allclose(vf_x, expected_vf_x), "vf_x does not match expected results" + assert jnp.allclose(vf_v, expected_vf_v), "vf_v does not match expected results"