diff --git a/ndfilters/_generic.py b/ndfilters/_generic.py index de22186..73cd4f3 100644 --- a/ndfilters/_generic.py +++ b/ndfilters/_generic.py @@ -14,7 +14,7 @@ def generic_filter( size: int | tuple[int, ...], axis: None | int | tuple[int, ...] = None, where: bool | np.ndarray = True, - mode: Literal["mirror"] = "mirror", + mode: Literal["mirror", "nearest", "wrap", "truncate"] = "mirror", args: tuple = (), ) -> np.ndarray: """ @@ -42,7 +42,8 @@ def generic_filter( mode The method used to extend the input array beyond its boundaries. See :func:`scipy.ndimage.generic_filter` for the definitions. - Currently, only "reflect" mode is supported. + Currently, only "mirror", "nearest", "wrap", and "truncate" modes are + supported. args Extra arguments to pass to function. @@ -98,9 +99,6 @@ def function(a: np.ndarray, args: tuple) -> float: f"{size=} should have the same number of elements as {axis=}." ) - if mode != "mirror": # pragma: nocover - raise ValueError(f"Only mode='reflected' is supported, got {mode=}") - axis_numba = ~np.arange(len(axis))[::-1] shape = array.shape @@ -138,6 +136,30 @@ def function(a: np.ndarray, args: tuple) -> float: return result +@numba.njit +def _rectify_index_lower(index: int, size: int, mode: str) -> int: + if mode == "mirror": + return -index + elif mode == "nearest": + return 0 + elif mode == "wrap": + return index % size + else: # pragma: nocover + raise ValueError + + +@numba.njit +def _rectify_index_upper(index: int, size: int, mode: str) -> int: + if mode == "mirror": + return ~(index % size + 1) + elif mode == "nearest": + return size - 1 + elif mode == "wrap": + return index % size + else: # pragma: nocover + raise ValueError + + @numba.njit(parallel=True) def _generic_filter_1d( array: np.ndarray, @@ -157,8 +179,8 @@ def _generic_filter_1d( for ix in numba.prange(array_shape_x): - values = np.empty(shape=size) - mask = np.empty(shape=size, dtype=np.bool_) + values = np.zeros(shape=size) + mask = np.zeros(shape=size, dtype=np.bool_) for kx in range(kernel_shape_x): @@ -166,9 +188,13 @@ def _generic_filter_1d( jx = ix + px if jx < 0: - jx = -jx + if mode == "truncate": + continue + jx = _rectify_index_lower(jx, array_shape_x, mode) elif jx >= array_shape_x: - jx = ~(jx % array_shape_x + 1) + if mode == "truncate": + continue + jx = _rectify_index_upper(jx, array_shape_x, mode) values[kx] = array[it, jx] mask[kx] = where[it, jx] @@ -198,8 +224,8 @@ def _generic_filter_2d( for ix in numba.prange(array_shape_x): for iy in numba.prange(array_shape_y): - values = np.empty(shape=size) - mask = np.empty(shape=size, dtype=np.bool_) + values = np.zeros(shape=size) + mask = np.zeros(shape=size, dtype=np.bool_) for kx in range(kernel_shape_x): @@ -207,9 +233,13 @@ def _generic_filter_2d( jx = ix + px if jx < 0: - jx = -jx + if mode == "truncate": + continue + jx = _rectify_index_lower(jx, array_shape_x, mode) elif jx >= array_shape_x: - jx = ~(jx % array_shape_x + 1) + if mode == "truncate": + continue + jx = _rectify_index_upper(jx, array_shape_x, mode) for ky in range(kernel_shape_y): @@ -217,9 +247,13 @@ def _generic_filter_2d( jy = iy + py if jy < 0: - jy = -jy + if mode == "truncate": + continue + jy = _rectify_index_lower(jy, array_shape_y, mode) elif jy >= array_shape_y: - jy = ~(jy % array_shape_y + 1) + if mode == "truncate": + continue + jy = _rectify_index_upper(jy, array_shape_y, mode) values[kx, ky] = array[it, jx, jy] mask[kx, ky] = where[it, jx, jy] @@ -253,8 +287,8 @@ def _generic_filter_3d( for iy in numba.prange(array_shape_y): for iz in numba.prange(array_shape_z): - values = np.empty(shape=size) - mask = np.empty(shape=size, dtype=np.bool_) + values = np.zeros(shape=size) + mask = np.zeros(shape=size, dtype=np.bool_) for kx in range(kernel_shape_x): @@ -262,9 +296,13 @@ def _generic_filter_3d( jx = ix + px if jx < 0: - jx = -jx + if mode == "truncate": + continue + jx = _rectify_index_lower(jx, array_shape_x, mode) elif jx >= array_shape_x: - jx = ~(jx % array_shape_x + 1) + if mode == "truncate": + continue + jx = _rectify_index_upper(jx, array_shape_x, mode) for ky in range(kernel_shape_y): @@ -272,9 +310,13 @@ def _generic_filter_3d( jy = iy + py if jy < 0: - jy = -jy + if mode == "truncate": + continue + jy = _rectify_index_lower(jy, array_shape_y, mode) elif jy >= array_shape_y: - jy = ~(jy % array_shape_y + 1) + if mode == "truncate": + continue + jy = _rectify_index_upper(jy, array_shape_y, mode) for kz in range(kernel_shape_z): @@ -282,9 +324,13 @@ def _generic_filter_3d( jz = iz + pz if jz < 0: - jz = -jz + if mode == "truncate": + continue + jz = _rectify_index_lower(jz, array_shape_z, mode) elif jz >= array_shape_z: - jz = ~(jz % array_shape_z + 1) + if mode == "truncate": + continue + jz = _rectify_index_upper(jz, array_shape_z, mode) values[kx, ky, kz] = array[it, jx, jy, jz] mask[kx, ky, kz] = where[it, jx, jy, jz] diff --git a/ndfilters/_tests/test_generic.py b/ndfilters/_tests/test_generic.py index 4506429..ed15ceb 100644 --- a/ndfilters/_tests/test_generic.py +++ b/ndfilters/_tests/test_generic.py @@ -36,13 +36,17 @@ def _mean(a: np.ndarray, args: tuple = ()) -> float: argnames="mode", argvalues=[ "mirror", + "nearest", + "wrap", + "truncate", + pytest.param("foo", marks=pytest.mark.xfail), ], ) def test_generic_filter( array: np.ndarray | u.Quantity, function: Callable[[np.ndarray], float], size: int | tuple[int, ...], - mode: Literal["mirror"], + mode: Literal["mirror", "nearest", "wrap", "truncate"], ): result = ndfilters.generic_filter( array=array, @@ -50,17 +54,19 @@ def test_generic_filter( size=size, mode=mode, ) + assert result.shape == array.shape + assert result.sum() != 0 - result_expected = scipy.ndimage.generic_filter( - input=array, - function=function, - size=size, - mode=mode, - ) + if mode != "truncate": + result_expected = scipy.ndimage.generic_filter( + input=array, + function=function, + size=size, + mode=mode, + ) - assert result.shape == array.shape - if isinstance(array, u.Quantity): - assert np.all(result.value == result_expected) - assert result.unit == array.unit - else: - assert np.all(result == result_expected) + if isinstance(array, u.Quantity): + assert np.all(result.value == result_expected) + assert result.unit == array.unit + else: + assert np.all(result == result_expected) diff --git a/ndfilters/_tests/test_trimmed_mean.py b/ndfilters/_tests/test_trimmed_mean.py index 989c0b2..50b3441 100644 --- a/ndfilters/_tests/test_trimmed_mean.py +++ b/ndfilters/_tests/test_trimmed_mean.py @@ -1,3 +1,4 @@ +from typing import Literal import pytest import numpy as np import scipy.ndimage @@ -20,6 +21,7 @@ @pytest.mark.parametrize( argnames="axis", argvalues=[ + None, 0, -1, (0,), @@ -30,11 +32,29 @@ (2, 1, 0), ], ) +@pytest.mark.parametrize( + argnames="where", + argvalues=[ + True, + False, + ], +) @pytest.mark.parametrize("proportion", [0.25, 0.45]) +@pytest.mark.parametrize( + argnames="mode", + argvalues=[ + "mirror", + "nearest", + "wrap", + "truncate", + ], +) def test_trimmed_mean_filter( array: np.ndarray, size: int | tuple[int, ...], axis: None | int | tuple[int, ...], + where: bool | np.ndarray, + mode: Literal["mirror", "nearest", "wrap", "truncate"], proportion: float, ): if axis is None: @@ -51,6 +71,7 @@ def test_trimmed_mean_filter( size=size, proportion=proportion, axis=axis, + where=where, ) return @@ -66,6 +87,7 @@ def test_trimmed_mean_filter( size=size, proportion=proportion, axis=axis, + where=where, ) return @@ -74,8 +96,19 @@ def test_trimmed_mean_filter( size=size, proportion=proportion, axis=axis, + where=where, + mode=mode, ) + assert result.shape == array.shape + assert result.sum() != 0 + + if mode == "truncate": + return + + if not np.all(where): + return + size_scipy = [1] * array.ndim for i, ax in enumerate(axis_normalized): size_scipy[ax] = size_normalized[i] @@ -84,7 +117,7 @@ def test_trimmed_mean_filter( input=array, function=scipy.stats.trim_mean, size=size_scipy, - mode="mirror", + mode=mode, extra_keywords=dict(proportiontocut=proportion), ) diff --git a/ndfilters/_trimmed_mean.py b/ndfilters/_trimmed_mean.py index 1bd011a..54eceb4 100644 --- a/ndfilters/_trimmed_mean.py +++ b/ndfilters/_trimmed_mean.py @@ -14,7 +14,7 @@ def trimmed_mean_filter( size: int | tuple[int, ...], axis: None | int | tuple[int, ...] = None, where: bool | np.ndarray = True, - mode: Literal["mirror"] = "mirror", + mode: Literal["mirror", "nearest", "wrap", "truncate"] = "mirror", proportion: float = 0.25, ) -> np.ndarray: """ @@ -36,7 +36,8 @@ def trimmed_mean_filter( mode The method used to extend the input array beyond its boundaries. See :func:`scipy.ndimage.generic_filter` for the definitions. - Currently, only "reflect" mode is supported. + Currently, only "mirror", "nearest", "wrap", and "truncate" modes are + supported. proportion The proportion to cut from the top and bottom of the distribution. @@ -83,6 +84,8 @@ def _trimmed_mean( (proportion,) = args nobs = array.size + if nobs == 0: + return np.nan lowercut = int(proportion * nobs) uppercut = nobs - lowercut if lowercut > uppercut: # pragma: nocover