From f17aab5406bf9f9598c5a633f1ca60cdaee88061 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Justus=20Sagem=C3=BCller?= Date: Tue, 13 Aug 2024 18:10:28 +0200 Subject: [PATCH] Manual broadcasting for the general tensor-with-inconsistent-dimensions case. Previously, this was more or less automatically done in NumPy, but not anymore. Normally, this is more likely to be user mistake. However, ODL actually relies somewhat on such broadcasts when defining functions on "sparse" mesh grids, so I added this functionality back be recursively transversing lists of different-shape arrays (like in the old version NumPy did, manually generating ragged dtype=object arrays). --- odl/discr/discr_utils.py | 56 +++++++++++++++++++++------------------- 1 file changed, 29 insertions(+), 27 deletions(-) diff --git a/odl/discr/discr_utils.py b/odl/discr/discr_utils.py index 4d8d855d60f..8f9902bbf9a 100644 --- a/odl/discr/discr_utils.py +++ b/odl/discr/discr_utils.py @@ -923,6 +923,23 @@ def _func_out_type(func): return has_out, out_optional +def _broadcast_nested_list(arr_lists, element_shape, ndim): + """ A generalisation of `np.broadcast_to`, applied to an arbitrarily + deep list (or tuple) eventually containing arrays or scalars. """ + if isinstance(arr_lists, np.ndarray) or np.isscalar(arr_lists): + if ndim == 1: + # As usual, 1d is tedious to deal with. This + # code deals with extra dimensions in result + # components that stem from using x instead of + # x[0] in a function. + # Without this, broadcasting fails. + shp = getattr(arr_lists, 'shape', ()) + if shp and shp[0] == 1: + arr_lists = arr_lists.reshape(arr_lists.shape[1:]) + return np.broadcast_to(arr_lists, element_shape) + else: + return [_broadcast_nested_list(row, element_shape, ndim) for row in arr_lists] + def sampling_function(func_or_arr, domain, out_dtype=None): """Return a function that can be used for sampling. @@ -987,10 +1004,11 @@ def _default_oop(func_ip, x, **kwargs): def _default_ip(func_oop, x, out, **kwargs): """Default in-place variant of an out-of-place-only function.""" - result = np.array(func_oop(x, **kwargs), copy=False) - if result.dtype == object: + result = func_oop(x, **kwargs) + try: + result = np.array(result, copy=False) + except ValueError: # Different shapes encountered, need to broadcast - flat_results = result.ravel() if is_valid_input_array(x, domain.ndim): scalar_out_shape = out_shape_from_array(x) elif is_valid_input_meshgrid(x, domain.ndim): @@ -998,8 +1016,8 @@ def _default_ip(func_oop, x, out, **kwargs): else: raise TypeError('invalid input `x`') - bcast_results = [np.broadcast_to(res, scalar_out_shape) - for res in flat_results] + bcast_results = _broadcast_nested_list(result, scalar_out_shape, domain.ndim) + # New array that is flat in the `out_shape` axes, reshape it # to the final `out_shape + scalar_shape`, using the same # order ('C') as the initial `result.ravel()`. @@ -1332,33 +1350,17 @@ def dual_use_func(x, out=None, **kwargs): elif tensor_valued: # The out object can be any array-like of objects with shapes # that should all be broadcastable to scalar_out_shape. - results = np.array(out) - if results.dtype == object or scalar_in: - # Some results don't have correct shape, need to - # broadcast - bcast_res = [] - for res in results.ravel(): - if ndim == 1: - # As usual, 1d is tedious to deal with. This - # code deals with extra dimensions in result - # components that stem from using x instead of - # x[0] in a function. - # Without this, broadcasting fails. - shp = getattr(res, 'shape', ()) - if shp and shp[0] == 1: - res = res.reshape(res.shape[1:]) - bcast_res.append( - np.broadcast_to(res, scalar_out_shape)) + try: + out_arr = np.asarray(out) + except ValueError: + out_arr = np.asarray(_broadcast_nested_list(out, scalar_out_shape, ndim=ndim)) - out_arr = np.array(bcast_res, dtype=scalar_out_dtype) - elif results.dtype != scalar_out_dtype: + if out_arr.dtype != scalar_out_dtype: raise ValueError( 'result is of dtype {}, expected {}' - ''.format(dtype_repr(results.dtype), + ''.format(dtype_repr(out_arr.dtype), dtype_repr(scalar_out_dtype)) ) - else: - out_arr = results out = out_arr.reshape(out_shape)