Skip to content

Commit

Permalink
Added "nearest", "wrap", and "truncate" modes to `ndfilters.generic_f…
Browse files Browse the repository at this point in the history
…ilter()`. (#17)
  • Loading branch information
byrdie authored Aug 23, 2024
1 parent c6185e0 commit 3ae5925
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 39 deletions.
92 changes: 69 additions & 23 deletions ndfilters/_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def generic_filter(
size: int | tuple[int, ...],
axis: None | int | tuple[int, ...] = None,
where: bool | np.ndarray = True,
mode: Literal["mirror"] = "mirror",
mode: Literal["mirror", "nearest", "wrap", "truncate"] = "mirror",
args: tuple = (),
) -> np.ndarray:
"""
Expand Down Expand Up @@ -42,7 +42,8 @@ def generic_filter(
mode
The method used to extend the input array beyond its boundaries.
See :func:`scipy.ndimage.generic_filter` for the definitions.
Currently, only "reflect" mode is supported.
Currently, only "mirror", "nearest", "wrap", and "truncate" modes are
supported.
args
Extra arguments to pass to function.
Expand Down Expand Up @@ -98,9 +99,6 @@ def function(a: np.ndarray, args: tuple) -> float:
f"{size=} should have the same number of elements as {axis=}."
)

if mode != "mirror": # pragma: nocover
raise ValueError(f"Only mode='reflected' is supported, got {mode=}")

axis_numba = ~np.arange(len(axis))[::-1]

shape = array.shape
Expand Down Expand Up @@ -138,6 +136,30 @@ def function(a: np.ndarray, args: tuple) -> float:
return result


@numba.njit
def _rectify_index_lower(index: int, size: int, mode: str) -> int:
if mode == "mirror":
return -index
elif mode == "nearest":
return 0
elif mode == "wrap":
return index % size
else: # pragma: nocover
raise ValueError


@numba.njit
def _rectify_index_upper(index: int, size: int, mode: str) -> int:
if mode == "mirror":
return ~(index % size + 1)
elif mode == "nearest":
return size - 1
elif mode == "wrap":
return index % size
else: # pragma: nocover
raise ValueError


@numba.njit(parallel=True)
def _generic_filter_1d(
array: np.ndarray,
Expand All @@ -157,18 +179,22 @@ def _generic_filter_1d(

for ix in numba.prange(array_shape_x):

values = np.empty(shape=size)
mask = np.empty(shape=size, dtype=np.bool_)
values = np.zeros(shape=size)
mask = np.zeros(shape=size, dtype=np.bool_)

for kx in range(kernel_shape_x):

px = kx - kernel_shape_x // 2
jx = ix + px

if jx < 0:
jx = -jx
if mode == "truncate":
continue
jx = _rectify_index_lower(jx, array_shape_x, mode)
elif jx >= array_shape_x:
jx = ~(jx % array_shape_x + 1)
if mode == "truncate":
continue
jx = _rectify_index_upper(jx, array_shape_x, mode)

values[kx] = array[it, jx]
mask[kx] = where[it, jx]
Expand Down Expand Up @@ -198,28 +224,36 @@ def _generic_filter_2d(
for ix in numba.prange(array_shape_x):
for iy in numba.prange(array_shape_y):

values = np.empty(shape=size)
mask = np.empty(shape=size, dtype=np.bool_)
values = np.zeros(shape=size)
mask = np.zeros(shape=size, dtype=np.bool_)

for kx in range(kernel_shape_x):

px = kx - kernel_shape_x // 2
jx = ix + px

if jx < 0:
jx = -jx
if mode == "truncate":
continue
jx = _rectify_index_lower(jx, array_shape_x, mode)
elif jx >= array_shape_x:
jx = ~(jx % array_shape_x + 1)
if mode == "truncate":
continue
jx = _rectify_index_upper(jx, array_shape_x, mode)

for ky in range(kernel_shape_y):

py = ky - kernel_shape_y // 2
jy = iy + py

if jy < 0:
jy = -jy
if mode == "truncate":
continue
jy = _rectify_index_lower(jy, array_shape_y, mode)
elif jy >= array_shape_y:
jy = ~(jy % array_shape_y + 1)
if mode == "truncate":
continue
jy = _rectify_index_upper(jy, array_shape_y, mode)

values[kx, ky] = array[it, jx, jy]
mask[kx, ky] = where[it, jx, jy]
Expand Down Expand Up @@ -253,38 +287,50 @@ def _generic_filter_3d(
for iy in numba.prange(array_shape_y):
for iz in numba.prange(array_shape_z):

values = np.empty(shape=size)
mask = np.empty(shape=size, dtype=np.bool_)
values = np.zeros(shape=size)
mask = np.zeros(shape=size, dtype=np.bool_)

for kx in range(kernel_shape_x):

px = kx - kernel_shape_x // 2
jx = ix + px

if jx < 0:
jx = -jx
if mode == "truncate":
continue
jx = _rectify_index_lower(jx, array_shape_x, mode)
elif jx >= array_shape_x:
jx = ~(jx % array_shape_x + 1)
if mode == "truncate":
continue
jx = _rectify_index_upper(jx, array_shape_x, mode)

for ky in range(kernel_shape_y):

py = ky - kernel_shape_y // 2
jy = iy + py

if jy < 0:
jy = -jy
if mode == "truncate":
continue
jy = _rectify_index_lower(jy, array_shape_y, mode)
elif jy >= array_shape_y:
jy = ~(jy % array_shape_y + 1)
if mode == "truncate":
continue
jy = _rectify_index_upper(jy, array_shape_y, mode)

for kz in range(kernel_shape_z):

pz = kz - kernel_shape_z // 2
jz = iz + pz

if jz < 0:
jz = -jz
if mode == "truncate":
continue
jz = _rectify_index_lower(jz, array_shape_z, mode)
elif jz >= array_shape_z:
jz = ~(jz % array_shape_z + 1)
if mode == "truncate":
continue
jz = _rectify_index_upper(jz, array_shape_z, mode)

values[kx, ky, kz] = array[it, jx, jy, jz]
mask[kx, ky, kz] = where[it, jx, jy, jz]
Expand Down
32 changes: 19 additions & 13 deletions ndfilters/_tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,31 +36,37 @@ def _mean(a: np.ndarray, args: tuple = ()) -> float:
argnames="mode",
argvalues=[
"mirror",
"nearest",
"wrap",
"truncate",
pytest.param("foo", marks=pytest.mark.xfail),
],
)
def test_generic_filter(
array: np.ndarray | u.Quantity,
function: Callable[[np.ndarray], float],
size: int | tuple[int, ...],
mode: Literal["mirror"],
mode: Literal["mirror", "nearest", "wrap", "truncate"],
):
result = ndfilters.generic_filter(
array=array,
function=function,
size=size,
mode=mode,
)
assert result.shape == array.shape
assert result.sum() != 0

result_expected = scipy.ndimage.generic_filter(
input=array,
function=function,
size=size,
mode=mode,
)
if mode != "truncate":
result_expected = scipy.ndimage.generic_filter(
input=array,
function=function,
size=size,
mode=mode,
)

assert result.shape == array.shape
if isinstance(array, u.Quantity):
assert np.all(result.value == result_expected)
assert result.unit == array.unit
else:
assert np.all(result == result_expected)
if isinstance(array, u.Quantity):
assert np.all(result.value == result_expected)
assert result.unit == array.unit
else:
assert np.all(result == result_expected)
35 changes: 34 additions & 1 deletion ndfilters/_tests/test_trimmed_mean.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Literal
import pytest
import numpy as np
import scipy.ndimage
Expand All @@ -20,6 +21,7 @@
@pytest.mark.parametrize(
argnames="axis",
argvalues=[
None,
0,
-1,
(0,),
Expand All @@ -30,11 +32,29 @@
(2, 1, 0),
],
)
@pytest.mark.parametrize(
argnames="where",
argvalues=[
True,
False,
],
)
@pytest.mark.parametrize("proportion", [0.25, 0.45])
@pytest.mark.parametrize(
argnames="mode",
argvalues=[
"mirror",
"nearest",
"wrap",
"truncate",
],
)
def test_trimmed_mean_filter(
array: np.ndarray,
size: int | tuple[int, ...],
axis: None | int | tuple[int, ...],
where: bool | np.ndarray,
mode: Literal["mirror", "nearest", "wrap", "truncate"],
proportion: float,
):
if axis is None:
Expand All @@ -51,6 +71,7 @@ def test_trimmed_mean_filter(
size=size,
proportion=proportion,
axis=axis,
where=where,
)
return

Expand All @@ -66,6 +87,7 @@ def test_trimmed_mean_filter(
size=size,
proportion=proportion,
axis=axis,
where=where,
)
return

Expand All @@ -74,8 +96,19 @@ def test_trimmed_mean_filter(
size=size,
proportion=proportion,
axis=axis,
where=where,
mode=mode,
)

assert result.shape == array.shape
assert result.sum() != 0

if mode == "truncate":
return

if not np.all(where):
return

size_scipy = [1] * array.ndim
for i, ax in enumerate(axis_normalized):
size_scipy[ax] = size_normalized[i]
Expand All @@ -84,7 +117,7 @@ def test_trimmed_mean_filter(
input=array,
function=scipy.stats.trim_mean,
size=size_scipy,
mode="mirror",
mode=mode,
extra_keywords=dict(proportiontocut=proportion),
)

Expand Down
7 changes: 5 additions & 2 deletions ndfilters/_trimmed_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def trimmed_mean_filter(
size: int | tuple[int, ...],
axis: None | int | tuple[int, ...] = None,
where: bool | np.ndarray = True,
mode: Literal["mirror"] = "mirror",
mode: Literal["mirror", "nearest", "wrap", "truncate"] = "mirror",
proportion: float = 0.25,
) -> np.ndarray:
"""
Expand All @@ -36,7 +36,8 @@ def trimmed_mean_filter(
mode
The method used to extend the input array beyond its boundaries.
See :func:`scipy.ndimage.generic_filter` for the definitions.
Currently, only "reflect" mode is supported.
Currently, only "mirror", "nearest", "wrap", and "truncate" modes are
supported.
proportion
The proportion to cut from the top and bottom of the distribution.
Expand Down Expand Up @@ -83,6 +84,8 @@ def _trimmed_mean(
(proportion,) = args

nobs = array.size
if nobs == 0:
return np.nan
lowercut = int(proportion * nobs)
uppercut = nobs - lowercut
if lowercut > uppercut: # pragma: nocover
Expand Down

0 comments on commit 3ae5925

Please sign in to comment.