Skip to content

Commit

Permalink
using RuntimeError for when ULD args have wrong structure
Browse files Browse the repository at this point in the history
  • Loading branch information
andyElking committed Sep 1, 2024
1 parent b760c65 commit dff726b
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 36 deletions.
12 changes: 9 additions & 3 deletions diffrax/_solver/foster_langevin_srk.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
UnderdampedLangevinDiffusionTerm,
UnderdampedLangevinDriftTerm,
UnderdampedLangevinLeaf,
UnderdampedLangevinStructureError,
UnderdampedLangevinTuple,
UnderdampedLangevinX,
WrapTerm,
Expand Down Expand Up @@ -290,13 +289,20 @@ def compare_args_fun(arg1, arg2):
try:
grad_f_shape = jax.eval_shape(grad_f, x0)
except ValueError:
raise UnderdampedLangevinStructureError("grad_f")
raise RuntimeError(
"The function `grad_f` in the Underdamped Langevin term must be"
" a callable, whose input and output have the same PyTree structure"
" and shapes as the position `x`."
)

def shape_check_fun(_x, _g, _u, _fx):
return _x.shape == _g.shape == _u.shape == _fx.shape

if not jtu.tree_all(jtu.tree_map(shape_check_fun, x0, gamma, u, grad_f_shape)):
raise UnderdampedLangevinStructureError(None)
raise RuntimeError(
"The shapes and PyTree structures of x0, gamma, u, and grad_f(x0)"
" must match."
)

tay_coeffs = jtu.tree_map(self._tay_coeffs_single, gamma)
# tay_coeffs have the same tree structure as gamma, with each leaf being a
Expand Down
43 changes: 10 additions & 33 deletions diffrax/_term.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,37 +798,6 @@ def _to_vjp(_y, _diff_args, _diff_term):
UnderdampedLangevinTuple = tuple[UnderdampedLangevinX, UnderdampedLangevinX]


class UnderdampedLangevinStructureError(Exception):
"""Raised when the structure of the arguments in the Underdamped Langevin
terms is incorrect."""

# Without this, the ValueError would be caught in _integrate._term_compatible,
# which would then give a less informative error message.
def __init__(self, problematic_arg: Optional[str]):
if problematic_arg is None:
msg = (
"If `x` is the position of the Underdamped Langevin diffusion,"
" then the PyTree structures and shapes of `grad_f(x)` and of "
"the arguments `gamma` and `u` must be the same as the structure"
" and shapes of x."
)
elif problematic_arg == "grad_f":
msg = (
"The function `grad_f` in the Underdamped Langevin term must be"
" a callable, whose input and output have the same PyTree structure"
" and shapes as the position `x`."
)
else:
msg = (
f"If `x` is the position of the Underdamped Langevin diffusion,"
f" then the PyTree structure and shapes of the argument"
f" `{problematic_arg}` must be the same as the structure and"
f" shapes of x."
)

super().__init__(msg)


def _broadcast_pytree(source, target_tree):
# Broadcasts the source PyTree to the shape and PyTree structure of
# target_tree_shape. Requires that source is a prefix tree of target_tree
Expand All @@ -853,7 +822,11 @@ def broadcast_underdamped_langevin_arg(
try:
return _broadcast_pytree(arg, x)
except ValueError:
raise UnderdampedLangevinStructureError(arg_name)
raise RuntimeError(
"The PyTree structure and shapes of the arguments `gamma` and `u`"
"in the Underdamped Langevin term must be the same as the structure"
"and shapes of the position `x`."
)


class UnderdampedLangevinDiffusionTerm(
Expand Down Expand Up @@ -992,7 +965,11 @@ def fun(_gamma, _u, _v, _f_x):
f_x = self.grad_f(x)
vf_v = jtu.tree_map(fun, gamma, u, v, f_x)
except ValueError:
raise UnderdampedLangevinStructureError("grad_f")
raise RuntimeError(
"The function `grad_f` in the Underdamped Langevin term must be"
" a callable, whose input and output have the same PyTree structure"
" and shapes as the position `x`."
)
vf_y = (vf_x, vf_v)
return vf_y

Expand Down

0 comments on commit dff726b

Please sign in to comment.