Skip to content

Commit

Permalink
Merge branch 'dev' into Owen/control_revamp
Browse files Browse the repository at this point in the history
  • Loading branch information
lockwo authored Jan 16, 2025
2 parents d304d9f + cc0d4bc commit 22d00ca
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 8 deletions.
15 changes: 11 additions & 4 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def _is_none(x: Any) -> bool:


def _assert_term_compatible(
t: FloatScalarLike,
y: PyTree[ArrayLike],
args: PyTree[Any],
terms: PyTree[AbstractTerm],
Expand All @@ -145,7 +146,7 @@ def _check(term_cls, term, term_contr_kwargs, yi):
for term, arg, term_contr_kwarg in zip(
term.terms, get_args(_tmp), term_contr_kwargs
):
_assert_term_compatible(yi, args, term, arg, term_contr_kwarg)
_assert_term_compatible(t, yi, args, term, arg, term_contr_kwarg)
else:
raise ValueError(
f"Term {term} is not a MultiTerm but is expected to be."
Expand Down Expand Up @@ -173,7 +174,7 @@ def _check(term_cls, term, term_contr_kwargs, yi):
elif n_term_args == 3:
vf_type_expected, control_type_expected, path_type_expected = term_args
try:
vf_type = eqx.filter_eval_shape(term.vf, 0.0, yi, args)
vf_type = eqx.filter_eval_shape(term.vf, t, yi, args)
except Exception as e:
raise ValueError(f"Error while tracing {term}.vf: " + str(e))
vf_type_compatible = eqx.filter_eval_shape(
Expand All @@ -186,7 +187,7 @@ def _check(term_cls, term, term_contr_kwargs, yi):
contr = ft.partial(term.contr, **term_contr_kwargs)
# Work around https://github.com/google/jax/issues/21825
try:
control_type, path_type = eqx.filter_eval_shape(contr, 0.0, 0.0)
control_type = eqx.filter_eval_shape(contr, t, t)
except Exception as e:
raise ValueError(f"Error while tracing {term}.contr: " + str(e))
control_type_compatible = eqx.filter_eval_shape(
Expand Down Expand Up @@ -1096,6 +1097,7 @@ def _promote(yi):
if isinstance(solver, (EulerHeun, ItoMilstein, StratonovichMilstein)):
try:
_assert_term_compatible(
t0,
y0,
args,
terms,
Expand All @@ -1118,7 +1120,12 @@ def _promote(yi):
# Error checking for term compatibility

_assert_term_compatible(
y0, args, terms, solver.term_structure, solver.term_compatible_contr_kwargs
t0,
y0,
args,
terms,
solver.term_structure,
solver.term_compatible_contr_kwargs,
)

if is_sde(terms):
Expand Down
8 changes: 4 additions & 4 deletions diffrax/_term.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,13 +1082,13 @@ class UnderdampedLangevinDriftTerm(AbstractTerm):

gamma: PyTree[ArrayLike]
u: PyTree[ArrayLike]
grad_f: Callable[[UnderdampedLangevinX], UnderdampedLangevinX]
grad_f: Callable[[UnderdampedLangevinX, Args], UnderdampedLangevinX]

def __init__(
self,
gamma: PyTree[ArrayLike],
u: PyTree[ArrayLike],
grad_f: Callable[[UnderdampedLangevinX], UnderdampedLangevinX],
grad_f: Callable[[UnderdampedLangevinX, Args], UnderdampedLangevinX],
):
r"""
**Arguments:**
Expand All @@ -1099,7 +1099,7 @@ def __init__(
a scalar or a PyTree of the same shape as the position vector $x$.
- `grad_f`: A callable representing the gradient of the potential function $f$.
This callable should take a PyTree of the same shape as $x$ and
return a PyTree of the same shape.
an optional `args` argument, returning a PyTree of the same shape.
"""
self.gamma = gamma
self.u = u
Expand Down Expand Up @@ -1129,7 +1129,7 @@ def fun(_gamma, _u, _v, _f_x):

vf_x = v
try:
f_x = self.grad_f(x)
f_x = self.grad_f(x, args) # Pass args to grad_f
vf_v = jtu.tree_map(fun, gamma, u, v, f_x)
except ValueError:
raise RuntimeError(
Expand Down
36 changes: 36 additions & 0 deletions test/test_term.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,39 @@ def test_weaklydiagonal_deprecate():
_ = diffrax.WeaklyDiagonalControlTerm(
lambda t, y, args: 0.0, lambda t0, t1: jnp.array(t1 - t0)
)


def test_underdamped_langevin_drift_term_args():
"""
Test that the UnderdampedLangevinDriftTerm handles `args` in grad_f correctly.
"""

# Mock gradient function that uses args
def mock_grad_f(x, args):
return jtu.tree_map(lambda xi, ai: xi + ai, x, args)

# Mock data
gamma = jnp.array([0.1, 0.2, 0.3])
u = jnp.array([0.4, 0.5, 0.6])
x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([0.1, 0.2, 0.3])
args = jnp.array([0.7, 0.8, 0.9])
y = (x, v)

# Create instance of the drift term
term = diffrax.UnderdampedLangevinDriftTerm(gamma=gamma, u=u, grad_f=mock_grad_f)

# Compute the vector field
vf_y = term.vf(0.0, y, args)

# Extract results
vf_x, vf_v = vf_y

# Expected results
expected_vf_x = v # By definition, vf_x = v
f_x = x + args # Output of mock_grad_f
expected_vf_v = -gamma * v - u * f_x # Drift term calculation

# Assertions
assert jnp.allclose(vf_x, expected_vf_x), "vf_x does not match expected results"
assert jnp.allclose(vf_v, expected_vf_v), "vf_v does not match expected results"

0 comments on commit 22d00ca

Please sign in to comment.