Skip to content

Commit

Permalink
add tests for projection
Browse files Browse the repository at this point in the history
  • Loading branch information
Henley13 committed May 5, 2020
1 parent f95ed8c commit cf105bc
Show file tree
Hide file tree
Showing 2 changed files with 241 additions and 32 deletions.
85 changes: 53 additions & 32 deletions bigfish/stack/projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@

import numpy as np

from .utils import check_array, check_parameter
from .utils import check_array
from .utils import check_parameter

from .preprocess import cast_img_uint8
from .preprocess import cast_img_uint16

from .filter import mean_filter


Expand All @@ -27,34 +31,39 @@ def maximum_projection(tensor):
"""
# check parameters
check_array(tensor, ndim=3, dtype=[np.uint8, np.uint16], allow_nan=False)
check_array(tensor, ndim=3, dtype=[np.uint8, np.uint16])

# project tensor along the z axis
projected_tensor = tensor.max(axis=0)

return projected_tensor


def mean_projection(tensor):
def mean_projection(tensor, return_float=False):
"""Project the z-dimension of a tensor, computing the mean intensity of
each yx pixel.
Parameters
----------
tensor : np.ndarray, np.uint
A 3-d tensor with shape (z, y, x).
return_float : bool
Return a (potentially more accurate) float array.
Returns
-------
projected_tensor : np.ndarray, np.float
projected_tensor : np.ndarray
A 2-d tensor with shape (y, x).
"""
# check parameters
check_array(tensor, ndim=3, dtype=[np.uint8, np.uint16], allow_nan=False)
check_array(tensor, ndim=3, dtype=[np.uint8, np.uint16])

# project tensor along the z axis
projected_tensor = tensor.mean(axis=0)
if return_float:
projected_tensor = tensor.mean(axis=0)
else:
projected_tensor = tensor.mean(axis=0, dtype=tensor.dtype)

return projected_tensor

Expand All @@ -75,7 +84,7 @@ def median_projection(tensor):
"""
# check parameters
check_array(tensor, ndim=3, dtype=[np.uint8, np.uint16], allow_nan=False)
check_array(tensor, ndim=3, dtype=[np.uint8, np.uint16])

# project tensor along the z axis
projected_tensor = np.median(tensor, axis=0)
Expand All @@ -85,8 +94,8 @@ def median_projection(tensor):


def focus_projection(tensor):
"""Project the z-dimension of a tensor as describe in Aubin's thesis
(part 5.3, strategy 5).
"""Project the z-dimension of a tensor as describe in Samacoits Aubin's
thesis (part 5.3, strategy 5).
1) We keep 75% best in-focus z-slices.
2) Compute a focus value for each voxel zyx with a 7x7 neighborhood window.
Expand All @@ -104,7 +113,7 @@ def focus_projection(tensor):
"""
# check parameters
check_array(tensor, ndim=3, dtype=[np.uint8, np.uint16], allow_nan=False)
check_array(tensor, ndim=3, dtype=[np.uint8, np.uint16])

# remove out-of-focus z-slices
in_focus_image = in_focus_selection(tensor,
Expand Down Expand Up @@ -141,9 +150,9 @@ def focus_projection_fast(tensor, proportion=0.75, neighborhood_size=7,
method="median"):
"""Project the z-dimension of a tensor.
Inspired from Aubin's thesis (part 5.3, strategy 5). Compare to the
original algorithm we use the same focus levels to select the in-focus
z-slices and project our tensor.
Inspired from Samacoits Aubin's thesis (part 5.3, strategy 5). Compare to
the original algorithm we use the same focus measures to select the
in-focus z-slices and project our tensor.
1) Compute a focus value for each voxel zyx with a fixed neighborhood size.
2) We keep 75% best in-focus z-slices (based on a global focus score).
Expand All @@ -156,7 +165,7 @@ def focus_projection_fast(tensor, proportion=0.75, neighborhood_size=7,
A 3-d tensor with shape (z, y, x).
proportion : float or int
Proportion of z-slices to keep (float between 0 and 1) or number of
z-slices to keep (integer above 1).
z-slices to keep (positive integer).
neighborhood_size : int
The size of the square used to define the neighborhood of each pixel.
method : str
Expand All @@ -168,9 +177,8 @@ def focus_projection_fast(tensor, proportion=0.75, neighborhood_size=7,
A 2-d tensor with shape (y, x).
"""
# TODO case where proportion = {0, 1}
# check parameters
check_array(tensor, ndim=3, dtype=[np.uint8, np.uint16], allow_nan=False)
check_array(tensor, ndim=3, dtype=[np.uint8, np.uint16])
check_parameter(proportion=(float, int),
neighborhood_size=int)
if isinstance(proportion, float) and 0 <= proportion <= 1:
Expand Down Expand Up @@ -232,7 +240,7 @@ def in_focus_selection(image, proportion, neighborhood_size=30):
A 3-d tensor with shape (z, y, x).
proportion : float or int
Proportion of z-slices to keep (float between 0 and 1) or number of
z-slices to keep (integer above 1).
z-slices to keep (positive integer).
neighborhood_size : int
The size of the square used to define the neighborhood of each pixel.
Expand All @@ -246,8 +254,7 @@ def in_focus_selection(image, proportion, neighborhood_size=30):
# check parameters
check_array(image,
ndim=3,
dtype=[np.uint8, np.uint16, np.float32, np.float64],
allow_nan=False)
dtype=[np.uint8, np.uint16, np.float32, np.float64])
check_parameter(proportion=(float, int),
neighborhood_size=int)
if isinstance(proportion, float) and 0 <= proportion <= 1:
Expand All @@ -268,7 +275,7 @@ def in_focus_selection(image, proportion, neighborhood_size=30):
return in_focus_image


def focus_measurement(image, neighborhood_size=30):
def focus_measurement(image, neighborhood_size=30, cast_8bit=False):
"""Helmli and Scherer’s mean method used as a focus metric.
For each pixel xy in an image, we compute the ratio:
Expand All @@ -286,6 +293,9 @@ def focus_measurement(image, neighborhood_size=30):
A 2-d or 3-d tensor with shape (y, x) or (z, y, x).
neighborhood_size : int
The size of the square used to define the neighborhood of each pixel.
cast_8bit : bool
Cast image in 8 bit before measuring the focus scores. Can speed up
the computation, but vanish the signal as well.
Returns
-------
Expand All @@ -306,7 +316,10 @@ def focus_measurement(image, neighborhood_size=30):
check_parameter(neighborhood_size=int)

# cast image in np.uint8
image = cast_img_uint8(image)
if cast_8bit:
image = cast_img_uint8(image, catch_warning=True)
else:
image = cast_img_uint16(image)

if image.ndim == 2:
ratio, global_focus = _focus_measurement_2d(image, neighborhood_size)
Expand All @@ -330,7 +343,7 @@ def _focus_measurement_2d(image, neighborhood_size):
Parameters
----------
image : np.ndarray, np.np.uint8
image : np.ndarray, np.uint
A 2-d tensor with shape (y, x).
neighborhood_size : int
The size of the square used to define the neighborhood of each pixel.
Expand Down Expand Up @@ -374,7 +387,7 @@ def _focus_measurement_3d(image, neighborhood_size):
Parameters
----------
image : np.ndarray, np.uint8
image : np.ndarray, np.uint
A 3-d tensor with shape (z, y, x).
neighborhood_size : int
The size of the square used to define the neighborhood of each pixel.
Expand Down Expand Up @@ -417,7 +430,7 @@ def get_in_focus_indices(global_focus, proportion):
is (z,) for a 3-d image or () for a 2-d image.
proportion : float or int
Proportion of z-slices to keep (float between 0 and 1) or number of
z-slices to keep (integer above 1).
z-slices to keep (positive integer).
Returns
-------
Expand All @@ -428,11 +441,7 @@ def get_in_focus_indices(global_focus, proportion):
# check parameters
check_parameter(global_focus=(np.ndarray, np.float32),
proportion=(float, int))
if isinstance(global_focus, np.ndarray):
check_array(global_focus,
ndim=[0, 1],
dtype=np.float32,
allow_nan=False)
check_array(global_focus, ndim=[0, 1], dtype=np.float32, allow_nan=False)
if isinstance(proportion, float) and 0 <= proportion <= 1:
n = int(len(global_focus) * proportion)
elif isinstance(proportion, int) and 0 <= proportion:
Expand All @@ -448,7 +457,7 @@ def get_in_focus_indices(global_focus, proportion):
return indices_to_keep


def _one_hot_3d(indices, depth):
def _one_hot_3d(indices, depth, return_boolean=False):
"""Build a 3-d one-hot matrix from a 2-d indices matrix.
Parameters
Expand All @@ -457,15 +466,24 @@ def _one_hot_3d(indices, depth):
A 2-d tensor with integer indices and shape (y, x).
depth : int
Depth of the 3-d one-hot matrix.
return_boolean : bool
Return a boolean one-hot encoded matrix.
Returns
-------
one_hot : np.ndarray, np.uint8
one_hot : np.ndarray
A 3-d binary tensor with shape (depth, y, x)
"""
# check parameters
check_parameter(depth=int)
check_array(indices,
ndim=2,
dtype=[np.uint8, np.uint16, np.uint32,
np.int8, np.int16, np.int32, np.int64])

# initialize the 3-d one-hot matrix
one_hot = np.zeros((indices.size, depth), dtype=np.uint8)
one_hot = np.zeros((indices.size, depth), dtype=indices.dtype)

# flatten the matrix to easily one-hot encode it, then reshape it
one_hot[np.arange(indices.size), indices.ravel()] = 1
Expand All @@ -474,4 +492,7 @@ def _one_hot_3d(indices, depth):
# rearrange the axis
one_hot = np.moveaxis(one_hot, source=2, destination=0)

if return_boolean:
one_hot = one_hot.astype(bool)

return one_hot
Loading

0 comments on commit cf105bc

Please sign in to comment.