diff --git a/odl/discr/discr_utils.py b/odl/discr/discr_utils.py index 4d8d855d60f..8f9902bbf9a 100644 --- a/odl/discr/discr_utils.py +++ b/odl/discr/discr_utils.py @@ -923,6 +923,23 @@ def _func_out_type(func): return has_out, out_optional +def _broadcast_nested_list(arr_lists, element_shape, ndim): + """ A generalisation of `np.broadcast_to`, applied to an arbitrarily + deep list (or tuple) eventually containing arrays or scalars. """ + if isinstance(arr_lists, np.ndarray) or np.isscalar(arr_lists): + if ndim == 1: + # As usual, 1d is tedious to deal with. This + # code deals with extra dimensions in result + # components that stem from using x instead of + # x[0] in a function. + # Without this, broadcasting fails. + shp = getattr(arr_lists, 'shape', ()) + if shp and shp[0] == 1: + arr_lists = arr_lists.reshape(arr_lists.shape[1:]) + return np.broadcast_to(arr_lists, element_shape) + else: + return [_broadcast_nested_list(row, element_shape, ndim) for row in arr_lists] + def sampling_function(func_or_arr, domain, out_dtype=None): """Return a function that can be used for sampling. @@ -987,10 +1004,11 @@ def _default_oop(func_ip, x, **kwargs): def _default_ip(func_oop, x, out, **kwargs): """Default in-place variant of an out-of-place-only function.""" - result = np.array(func_oop(x, **kwargs), copy=False) - if result.dtype == object: + result = func_oop(x, **kwargs) + try: + result = np.array(result, copy=False) + except ValueError: # Different shapes encountered, need to broadcast - flat_results = result.ravel() if is_valid_input_array(x, domain.ndim): scalar_out_shape = out_shape_from_array(x) elif is_valid_input_meshgrid(x, domain.ndim): @@ -998,8 +1016,8 @@ def _default_ip(func_oop, x, out, **kwargs): else: raise TypeError('invalid input `x`') - bcast_results = [np.broadcast_to(res, scalar_out_shape) - for res in flat_results] + bcast_results = _broadcast_nested_list(result, scalar_out_shape, domain.ndim) + # New array that is flat in the `out_shape` axes, reshape it # to the final `out_shape + scalar_shape`, using the same # order ('C') as the initial `result.ravel()`. @@ -1332,33 +1350,17 @@ def dual_use_func(x, out=None, **kwargs): elif tensor_valued: # The out object can be any array-like of objects with shapes # that should all be broadcastable to scalar_out_shape. - results = np.array(out) - if results.dtype == object or scalar_in: - # Some results don't have correct shape, need to - # broadcast - bcast_res = [] - for res in results.ravel(): - if ndim == 1: - # As usual, 1d is tedious to deal with. This - # code deals with extra dimensions in result - # components that stem from using x instead of - # x[0] in a function. - # Without this, broadcasting fails. - shp = getattr(res, 'shape', ()) - if shp and shp[0] == 1: - res = res.reshape(res.shape[1:]) - bcast_res.append( - np.broadcast_to(res, scalar_out_shape)) + try: + out_arr = np.asarray(out) + except ValueError: + out_arr = np.asarray(_broadcast_nested_list(out, scalar_out_shape, ndim=ndim)) - out_arr = np.array(bcast_res, dtype=scalar_out_dtype) - elif results.dtype != scalar_out_dtype: + if out_arr.dtype != scalar_out_dtype: raise ValueError( 'result is of dtype {}, expected {}' - ''.format(dtype_repr(results.dtype), + ''.format(dtype_repr(out_arr.dtype), dtype_repr(scalar_out_dtype)) ) - else: - out_arr = results out = out_arr.reshape(out_shape)