From dff726b774357d1e0e1bd064831f2c1cc2ff2cba Mon Sep 17 00:00:00 2001 From: andyElking Date: Sun, 1 Sep 2024 18:08:55 +0100 Subject: [PATCH] using RuntimeError for when ULD args have wrong structure --- diffrax/_solver/foster_langevin_srk.py | 12 +++++-- diffrax/_term.py | 43 ++++++-------------------- 2 files changed, 19 insertions(+), 36 deletions(-) diff --git a/diffrax/_solver/foster_langevin_srk.py b/diffrax/_solver/foster_langevin_srk.py index 24627fdb..82e6538f 100644 --- a/diffrax/_solver/foster_langevin_srk.py +++ b/diffrax/_solver/foster_langevin_srk.py @@ -26,7 +26,6 @@ UnderdampedLangevinDiffusionTerm, UnderdampedLangevinDriftTerm, UnderdampedLangevinLeaf, - UnderdampedLangevinStructureError, UnderdampedLangevinTuple, UnderdampedLangevinX, WrapTerm, @@ -290,13 +289,20 @@ def compare_args_fun(arg1, arg2): try: grad_f_shape = jax.eval_shape(grad_f, x0) except ValueError: - raise UnderdampedLangevinStructureError("grad_f") + raise RuntimeError( + "The function `grad_f` in the Underdamped Langevin term must be" + " a callable, whose input and output have the same PyTree structure" + " and shapes as the position `x`." + ) def shape_check_fun(_x, _g, _u, _fx): return _x.shape == _g.shape == _u.shape == _fx.shape if not jtu.tree_all(jtu.tree_map(shape_check_fun, x0, gamma, u, grad_f_shape)): - raise UnderdampedLangevinStructureError(None) + raise RuntimeError( + "The shapes and PyTree structures of x0, gamma, u, and grad_f(x0)" + " must match." + ) tay_coeffs = jtu.tree_map(self._tay_coeffs_single, gamma) # tay_coeffs have the same tree structure as gamma, with each leaf being a diff --git a/diffrax/_term.py b/diffrax/_term.py index 6be83043..5457df20 100644 --- a/diffrax/_term.py +++ b/diffrax/_term.py @@ -798,37 +798,6 @@ def _to_vjp(_y, _diff_args, _diff_term): UnderdampedLangevinTuple = tuple[UnderdampedLangevinX, UnderdampedLangevinX] -class UnderdampedLangevinStructureError(Exception): - """Raised when the structure of the arguments in the Underdamped Langevin - terms is incorrect.""" - - # Without this, the ValueError would be caught in _integrate._term_compatible, - # which would then give a less informative error message. - def __init__(self, problematic_arg: Optional[str]): - if problematic_arg is None: - msg = ( - "If `x` is the position of the Underdamped Langevin diffusion," - " then the PyTree structures and shapes of `grad_f(x)` and of " - "the arguments `gamma` and `u` must be the same as the structure" - " and shapes of x." - ) - elif problematic_arg == "grad_f": - msg = ( - "The function `grad_f` in the Underdamped Langevin term must be" - " a callable, whose input and output have the same PyTree structure" - " and shapes as the position `x`." - ) - else: - msg = ( - f"If `x` is the position of the Underdamped Langevin diffusion," - f" then the PyTree structure and shapes of the argument" - f" `{problematic_arg}` must be the same as the structure and" - f" shapes of x." - ) - - super().__init__(msg) - - def _broadcast_pytree(source, target_tree): # Broadcasts the source PyTree to the shape and PyTree structure of # target_tree_shape. Requires that source is a prefix tree of target_tree @@ -853,7 +822,11 @@ def broadcast_underdamped_langevin_arg( try: return _broadcast_pytree(arg, x) except ValueError: - raise UnderdampedLangevinStructureError(arg_name) + raise RuntimeError( + "The PyTree structure and shapes of the arguments `gamma` and `u`" + "in the Underdamped Langevin term must be the same as the structure" + "and shapes of the position `x`." + ) class UnderdampedLangevinDiffusionTerm( @@ -992,7 +965,11 @@ def fun(_gamma, _u, _v, _f_x): f_x = self.grad_f(x) vf_v = jtu.tree_map(fun, gamma, u, v, f_x) except ValueError: - raise UnderdampedLangevinStructureError("grad_f") + raise RuntimeError( + "The function `grad_f` in the Underdamped Langevin term must be" + " a callable, whose input and output have the same PyTree structure" + " and shapes as the position `x`." + ) vf_y = (vf_x, vf_v) return vf_y