Commit

Fixing issues related to dtype=object arrays in interpolation routines (#1655)

* Ensure the distances from which linear-interpolation weights are computed are stored as arrays.

Without this, NumPy implicitly generates ragged arrays (dtype=object),
which is bad for performance and disallowed in newer versions.
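
For illustration, a minimal sketch of the ragged behaviour in question
(the per-axis lengths are made up):

    import numpy as np

    # Per-axis distance vectors generally have different lengths, e.g. on a 3 x 4 grid:
    dists = [np.random.rand(3), np.random.rand(4)]

    try:
        np.array(dists)   # recent NumPy raises ValueError for ragged input
    except ValueError:
        pass              # older NumPy: silently a dtype=object array of shape (2,)

    # Converting each entry explicitly keeps proper float arrays:
    dists = [np.asarray(d, dtype=float) for d in dists]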

* Failure to convert to an array implies it is not a suitable input array.

In older NumPy, this would silently create an array of objects, but that
has the wrong shape anyway.

* Ensure interpolation weight is computed with the correct array type.

* Ensure mesh coordinate calculations are not carried out with dtype=object.

This would previously happen because meshgrids are passed in as ragged arrays,
and NumPy does not convert the rows to float dtype as it should.
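
A minimal sketch of the meshgrid situation (shapes are illustrative; a sparse
meshgrid consists of one differently shaped array per axis):

    import numpy as np

    x, y = np.linspace(0, 1, 3), np.linspace(0, 1, 4)
    mesh = np.meshgrid(x, y, indexing='ij', sparse=True)   # shapes (3, 1) and (1, 4)

    # Wrapping the whole sparse meshgrid at once would give a ragged dtype=object
    # array (or an error on recent NumPy); converting each row separately keeps
    # a proper float dtype:
    mesh = [np.asarray(mi, dtype=float) for mi in mesh]
    print([mi.dtype for mi in mesh])   # [dtype('float64'), dtype('float64')]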

* For integral data, interpolating on integral points may not be appropriate.

One of the tests samples from an integral grid at non-integral points.
This failed after the explicit conversion introduced in c90044a.
Falling back to `float` fixes the test case, though perhaps it would
be better to select a dedicated meshing dtype.
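
The failure mode boils down to the safe cast alone (dtypes here are illustrative):

    import numpy as np

    grid_dtype = np.dtype('int64')    # values sampled on an integral grid
    xi = np.array([0.5, 2.25])        # non-integral evaluation points

    try:
        xi = xi.astype(grid_dtype, casting='safe')   # float64 -> int64 is not a safe cast
    except TypeError:
        xi = np.asarray(xi, dtype=float)             # the fallback chosen here
    print(xi.dtype)                                  # float64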

* Add warning when falling back to `float` for interpolation coefficients.

When this happens, it is likely that the user did something like starting
from an integer mesh, but in that case linear interpolation does not seem
very appropriate.
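
If that fallback is intentional, the new warning can be silenced in the usual
way; a sketch assuming the message text stays as committed:

    import warnings

    with warnings.catch_warnings():
        warnings.filterwarnings(
            'ignore',
            message='Unable to infer accurate dtype for interpolation coefficients')
        ...  # evaluate the interpolant at non-integral points here
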
leftaroundabout authored Aug 30, 2024
1 parent 3528685 commit f720076
Showing 2 changed files with 15 additions and 2 deletions.
12 changes: 11 additions & 1 deletion odl/discr/discr_utils.py
@@ -20,6 +20,7 @@
 from builtins import object
 from functools import partial
 from itertools import product
+from warnings import warn
 
 import numpy as np

@@ -621,6 +622,13 @@ def _find_indices(self, x):
 
         # iterate through dimensions
         for xi, cvec in zip(x, self.coord_vecs):
+            try:
+                xi = np.asarray(xi).astype(self.values.dtype, casting='safe')
+            except TypeError:
+                warn("Unable to infer accurate dtype for"
+                     +" interpolation coefficients, defaulting to `float`.")
+                xi = np.asarray(xi, dtype=float)
+
             idcs = np.searchsorted(cvec, xi) - 1
 
             idcs[idcs < 0] = 0
@@ -706,6 +714,8 @@ def _compute_nearest_weights_edge(idcs, ndist):
 
 def _compute_linear_weights_edge(idcs, ndist):
     """Helper for linear interpolation."""
+    ndist = np.asarray(ndist)
+
     # Get out-of-bounds indices from the norm_distances. Negative
     # means "too low", larger than or equal to 1 means "too high"
     lo = np.where(ndist < 0)
@@ -799,7 +809,7 @@ def _evaluate(self, indices, norm_distances, out=None):
         # axis, resulting in a loop of length 2**ndim
         for lo_hi, edge in zip(product(*([['l', 'h']] * len(indices))),
                                product(*edge_indices)):
-            weight = 1.0
+            weight = np.array([1.0], dtype=self.values.dtype)
             # TODO(kohr-h): determine best summation order from array strides
             for lh, w_lo, w_hi in zip(lo_hi, low_weights, high_weights):

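Taken together, the discr_utils.py changes keep the searchsorted-based index
and weight computation in a proper floating-point dtype. A minimal standalone
sketch of that computation on a 1-d grid (names are illustrative, not the ODL
API):

    import numpy as np

    def find_indices(cvec, xi):
        """Left-cell indices of ``xi`` in the sorted grid ``cvec``, plus normalized distances."""
        # Explicit conversion so ragged or integer input cannot leak dtype=object
        # or integer arithmetic into the weights.
        cvec = np.asarray(cvec, dtype=float)
        xi = np.asarray(xi, dtype=float)
        idcs = np.searchsorted(cvec, xi) - 1
        idcs[idcs < 0] = 0
        idcs[idcs > cvec.size - 2] = cvec.size - 2
        return idcs, (xi - cvec[idcs]) / (cvec[idcs + 1] - cvec[idcs])

    def linear_interp_1d(values, cvec, xi):
        """Piecewise-linear interpolation of ``values`` sampled on ``cvec``, evaluated at ``xi``."""
        values = np.asarray(values, dtype=float)
        idcs, ndist = find_indices(cvec, xi)
        return (1.0 - ndist) * values[idcs] + ndist * values[idcs + 1]

    # Integer samples evaluated at non-integral points work thanks to the
    # explicit conversion to float:
    print(linear_interp_1d([0, 1, 4, 9], [0, 1, 2, 3], [0.5, 2.5]))   # [0.5 6.5]
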
5 changes: 4 additions & 1 deletion odl/util/vectorization.py
@@ -21,7 +21,10 @@
 
 def is_valid_input_array(x, ndim=None):
     """Test if ``x`` is a correctly shaped point array in R^d."""
-    x = np.asarray(x)
+    try:
+        x = np.asarray(x)
+    except ValueError:
+        return False
 
     if ndim is None or ndim == 1:
         return x.ndim == 1 and x.size > 1 or x.ndim == 2 and x.shape[0] == 1
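
With this change, ragged input is reported as unsuitable instead of ending up
as a wrongly shaped dtype=object array; for example, on a NumPy version that
rejects ragged arrays:

    import numpy as np
    from odl.util.vectorization import is_valid_input_array

    print(is_valid_input_array(np.linspace(0, 1, 5)))            # True: 1-d point array
    print(is_valid_input_array([[0.0, 0.5, 1.0], [1.0, 2.0]]))   # False: np.asarray raises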
