diff --git a/odl/discr/discr_utils.py b/odl/discr/discr_utils.py index 4d8d855d60f..a2b49d317a3 100644 --- a/odl/discr/discr_utils.py +++ b/odl/discr/discr_utils.py @@ -20,6 +20,7 @@ from builtins import object from functools import partial from itertools import product +from warnings import warn import numpy as np @@ -119,7 +120,7 @@ def point_collocation(func, points, out=None, **kwargs): >>> ys = [3, 4] >>> mesh = sparse_meshgrid(xs, ys) >>> def vec_valued(x): - ... return (x[0] - 1, 0, x[0] + x[1]) # broadcasting + ... return (x[0] - 1., 0., x[0] + x[1]) # broadcasting >>> # For a function with several output components, we must specify the >>> # shape explicitly in the `out_dtype` parameter >>> func1 = sampling_function( @@ -621,6 +622,13 @@ def _find_indices(self, x): # iterate through dimensions for xi, cvec in zip(x, self.coord_vecs): + try: + xi = np.asarray(xi).astype(self.values.dtype, casting='safe') + except TypeError: + warn("Unable to infer accurate dtype for" + +" interpolation coefficients, defaulting to `float`.") + xi = np.asarray(xi, dtype=float) + idcs = np.searchsorted(cvec, xi) - 1 idcs[idcs < 0] = 0 @@ -706,6 +714,8 @@ def _compute_nearest_weights_edge(idcs, ndist): def _compute_linear_weights_edge(idcs, ndist): """Helper for linear interpolation.""" + ndist = np.asarray(ndist) + # Get out-of-bounds indices from the norm_distances. Negative # means "too low", larger than or equal to 1 means "too high" lo = np.where(ndist < 0) @@ -799,7 +809,7 @@ def _evaluate(self, indices, norm_distances, out=None): # axis, resulting in a loop of length 2**ndim for lo_hi, edge in zip(product(*([['l', 'h']] * len(indices))), product(*edge_indices)): - weight = 1.0 + weight = np.array([1.0], dtype=self.values.dtype) # TODO(kohr-h): determine best summation order from array strides for lh, w_lo, w_hi in zip(lo_hi, low_weights, high_weights): @@ -923,6 +933,23 @@ def _func_out_type(func): return has_out, out_optional +def _broadcast_nested_list(arr_lists, element_shape, ndim): + """ A generalisation of `np.broadcast_to`, applied to an arbitrarily + deep list (or tuple) eventually containing arrays or scalars. """ + if isinstance(arr_lists, np.ndarray) or np.isscalar(arr_lists): + if ndim == 1: + # As usual, 1d is tedious to deal with. This + # code deals with extra dimensions in result + # components that stem from using x instead of + # x[0] in a function. + # Without this, broadcasting fails. + shp = getattr(arr_lists, 'shape', ()) + if shp and shp[0] == 1: + arr_lists = arr_lists.reshape(arr_lists.shape[1:]) + return np.broadcast_to(arr_lists, element_shape) + else: + return [_broadcast_nested_list(row, element_shape, ndim) for row in arr_lists] + def sampling_function(func_or_arr, domain, out_dtype=None): """Return a function that can be used for sampling. @@ -987,10 +1014,11 @@ def _default_oop(func_ip, x, **kwargs): def _default_ip(func_oop, x, out, **kwargs): """Default in-place variant of an out-of-place-only function.""" - result = np.array(func_oop(x, **kwargs), copy=False) - if result.dtype == object: + result = func_oop(x, **kwargs) + try: + result = np.array(result, copy=False) + except ValueError: # Different shapes encountered, need to broadcast - flat_results = result.ravel() if is_valid_input_array(x, domain.ndim): scalar_out_shape = out_shape_from_array(x) elif is_valid_input_meshgrid(x, domain.ndim): @@ -998,8 +1026,8 @@ def _default_ip(func_oop, x, out, **kwargs): else: raise TypeError('invalid input `x`') - bcast_results = [np.broadcast_to(res, scalar_out_shape) - for res in flat_results] + bcast_results = _broadcast_nested_list(result, scalar_out_shape, domain.ndim) + # New array that is flat in the `out_shape` axes, reshape it # to the final `out_shape + scalar_shape`, using the same # order ('C') as the initial `result.ravel()`. @@ -1332,33 +1360,17 @@ def dual_use_func(x, out=None, **kwargs): elif tensor_valued: # The out object can be any array-like of objects with shapes # that should all be broadcastable to scalar_out_shape. - results = np.array(out) - if results.dtype == object or scalar_in: - # Some results don't have correct shape, need to - # broadcast - bcast_res = [] - for res in results.ravel(): - if ndim == 1: - # As usual, 1d is tedious to deal with. This - # code deals with extra dimensions in result - # components that stem from using x instead of - # x[0] in a function. - # Without this, broadcasting fails. - shp = getattr(res, 'shape', ()) - if shp and shp[0] == 1: - res = res.reshape(res.shape[1:]) - bcast_res.append( - np.broadcast_to(res, scalar_out_shape)) + try: + out_arr = np.asarray(out) + except ValueError: + out_arr = np.asarray(_broadcast_nested_list(out, scalar_out_shape, ndim=ndim)) - out_arr = np.array(bcast_res, dtype=scalar_out_dtype) - elif results.dtype != scalar_out_dtype: + if out_arr.dtype != scalar_out_dtype: raise ValueError( 'result is of dtype {}, expected {}' - ''.format(dtype_repr(results.dtype), + ''.format(dtype_repr(out_arr.dtype), dtype_repr(scalar_out_dtype)) ) - else: - out_arr = results out = out_arr.reshape(out_shape) diff --git a/odl/discr/grid.py b/odl/discr/grid.py index ef0d9bd2c35..0317629fe93 100644 --- a/odl/discr/grid.py +++ b/odl/discr/grid.py @@ -1111,8 +1111,8 @@ def uniform_grid_fromintv(intv_prod, shape, nodes_on_bdry=True): shape = normalized_scalar_param_list(shape, intv_prod.ndim, safe_int_conv) - if np.shape(nodes_on_bdry) == (): - nodes_on_bdry = ([(bool(nodes_on_bdry), bool(nodes_on_bdry))] * + if isinstance(nodes_on_bdry, bool): + nodes_on_bdry = ([(nodes_on_bdry, nodes_on_bdry)] * intv_prod.ndim) elif intv_prod.ndim == 1 and len(nodes_on_bdry) == 2: nodes_on_bdry = [nodes_on_bdry] diff --git a/odl/test/discr/discr_space_test.py b/odl/test/discr/discr_space_test.py index 4b254e5d494..ebfda76ae1e 100644 --- a/odl/test/discr/discr_space_test.py +++ b/odl/test/discr/discr_space_test.py @@ -897,7 +897,9 @@ def test_ufunc_corner_cases(odl_tspace_impl): # --- UFuncs with nin = 1, nout = 1 --- # - with pytest.raises(ValueError): + wrong_argcount_error = ValueError if np.__version__ < "1.21" else TypeError + + with pytest.raises(wrong_argcount_error): # Too many arguments x.__array_ufunc__(np.sin, '__call__', x, np.ones((2, 3))) @@ -928,7 +930,7 @@ def test_ufunc_corner_cases(odl_tspace_impl): # --- UFuncs with nin = 2, nout = 1 --- # - with pytest.raises(ValueError): + with pytest.raises(wrong_argcount_error): # Too few arguments x.__array_ufunc__(np.add, '__call__', x) diff --git a/odl/test/space/tensors_test.py b/odl/test/space/tensors_test.py index 1af171b20e3..422a5450fa2 100644 --- a/odl/test/space/tensors_test.py +++ b/odl/test/space/tensors_test.py @@ -1530,7 +1530,9 @@ def test_ufunc_corner_cases(odl_tspace_impl): # --- Ufuncs with nin = 1, nout = 1 --- # - with pytest.raises(ValueError): + wrong_argcount_error = ValueError if np.__version__ < "1.21" else TypeError + + with pytest.raises(wrong_argcount_error): # Too many arguments x.__array_ufunc__(np.sin, '__call__', x, np.ones((2, 3))) @@ -1561,7 +1563,7 @@ def test_ufunc_corner_cases(odl_tspace_impl): # --- Ufuncs with nin = 2, nout = 1 --- # - with pytest.raises(ValueError): + with pytest.raises(wrong_argcount_error): # Too few arguments x.__array_ufunc__(np.add, '__call__', x) diff --git a/odl/util/normalize.py b/odl/util/normalize.py index e28911d32bf..27984f0640a 100644 --- a/odl/util/normalize.py +++ b/odl/util/normalize.py @@ -278,11 +278,11 @@ def normalized_nodes_on_bdry(nodes_on_bdry, length): >>> normalized_nodes_on_bdry([[True, False], False, True], length=3) [(True, False), (False, False), (True, True)] """ - shape = np.shape(nodes_on_bdry) - if shape == (): - out_list = [(bool(nodes_on_bdry), bool(nodes_on_bdry))] * length - elif length == 1 and shape == (2,): - out_list = [(bool(nodes_on_bdry[0]), bool(nodes_on_bdry[1]))] + if isinstance(nodes_on_bdry, bool): + return [(nodes_on_bdry, nodes_on_bdry)] * length + elif (length == 1 and len(nodes_on_bdry) == 2 + and all(isinstance(d, bool) for d in nodes_on_bdry)): + return [nodes_on_bdry[0], nodes_on_bdry[1]] elif len(nodes_on_bdry) == length: out_list = [] diff --git a/odl/util/vectorization.py b/odl/util/vectorization.py index 6d668a8c0fc..460fa62e81c 100644 --- a/odl/util/vectorization.py +++ b/odl/util/vectorization.py @@ -21,7 +21,10 @@ def is_valid_input_array(x, ndim=None): """Test if ``x`` is a correctly shaped point array in R^d.""" - x = np.asarray(x) + try: + x = np.asarray(x) + except ValueError: + return False if ndim is None or ndim == 1: return x.ndim == 1 and x.size > 1 or x.ndim == 2 and x.shape[0] == 1