Skip to content

Commit

Permalink
Manual broadcasting for the general tensor-with-inconsistent-dimensio…
Browse files Browse the repository at this point in the history
…ns case.

Previously, this was more or less automatically done in NumPy, but not anymore.
Normally, this is more likely to be a user mistake. However, ODL actually
relies somewhat on such broadcasts when defining functions on "sparse"
mesh grids, so I added this functionality back by recursively
traversing lists of different-shape arrays (like the old NumPy version
did, manually generating ragged dtype=object arrays).
  • Loading branch information
leftaroundabout committed Aug 29, 2024
1 parent 546a01c commit f17aab5
Showing 1 changed file with 29 additions and 27 deletions.
56 changes: 29 additions & 27 deletions odl/discr/discr_utils.py
Original file line number Diff line number Diff line change
@@ -923,6 +923,23 @@ def _func_out_type(func):

return has_out, out_optional

def _broadcast_nested_list(arr_lists, element_shape, ndim):
""" A generalisation of `np.broadcast_to`, applied to an arbitrarily
deep list (or tuple) eventually containing arrays or scalars. """
if isinstance(arr_lists, np.ndarray) or np.isscalar(arr_lists):
if ndim == 1:
# As usual, 1d is tedious to deal with. This
# code deals with extra dimensions in result
# components that stem from using x instead of
# x[0] in a function.
# Without this, broadcasting fails.
shp = getattr(arr_lists, 'shape', ())
if shp and shp[0] == 1:
arr_lists = arr_lists.reshape(arr_lists.shape[1:])
return np.broadcast_to(arr_lists, element_shape)
else:
return [_broadcast_nested_list(row, element_shape, ndim) for row in arr_lists]


def sampling_function(func_or_arr, domain, out_dtype=None):
"""Return a function that can be used for sampling.
@@ -987,19 +1004,20 @@ def _default_oop(func_ip, x, **kwargs):

def _default_ip(func_oop, x, out, **kwargs):
"""Default in-place variant of an out-of-place-only function."""
result = np.array(func_oop(x, **kwargs), copy=False)
if result.dtype == object:
result = func_oop(x, **kwargs)
try:
result = np.array(result, copy=False)
except ValueError:
# Different shapes encountered, need to broadcast
flat_results = result.ravel()
if is_valid_input_array(x, domain.ndim):
scalar_out_shape = out_shape_from_array(x)
elif is_valid_input_meshgrid(x, domain.ndim):
scalar_out_shape = out_shape_from_meshgrid(x)
else:
raise TypeError('invalid input `x`')

bcast_results = [np.broadcast_to(res, scalar_out_shape)
for res in flat_results]
bcast_results = _broadcast_nested_list(result, scalar_out_shape, domain.ndim)

# New array that is flat in the `out_shape` axes, reshape it
# to the final `out_shape + scalar_shape`, using the same
# order ('C') as the initial `result.ravel()`.
@@ -1332,33 +1350,17 @@ def dual_use_func(x, out=None, **kwargs):
elif tensor_valued:
# The out object can be any array-like of objects with shapes
# that should all be broadcastable to scalar_out_shape.
results = np.array(out)
if results.dtype == object or scalar_in:
# Some results don't have correct shape, need to
# broadcast
bcast_res = []
for res in results.ravel():
if ndim == 1:
# As usual, 1d is tedious to deal with. This
# code deals with extra dimensions in result
# components that stem from using x instead of
# x[0] in a function.
# Without this, broadcasting fails.
shp = getattr(res, 'shape', ())
if shp and shp[0] == 1:
res = res.reshape(res.shape[1:])
bcast_res.append(
np.broadcast_to(res, scalar_out_shape))
try:
out_arr = np.asarray(out)
except ValueError:
out_arr = np.asarray(_broadcast_nested_list(out, scalar_out_shape, ndim=ndim))

out_arr = np.array(bcast_res, dtype=scalar_out_dtype)
elif results.dtype != scalar_out_dtype:
if out_arr.dtype != scalar_out_dtype:
raise ValueError(
'result is of dtype {}, expected {}'
''.format(dtype_repr(results.dtype),
''.format(dtype_repr(out_arr.dtype),
dtype_repr(scalar_out_dtype))
)
else:
out_arr = results

out = out_arr.reshape(out_shape)

0 comments on commit f17aab5

Please sign in to comment.