Skip to content

Commit

Permalink
Parallelize AIS post-processing (#851)
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape authored Feb 5, 2025
1 parent 959af7d commit c319e36
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 15 deletions.
6 changes: 5 additions & 1 deletion micro_sam/automatic_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,11 @@ def automatic_instance_segmentation(
verbose=verbose,
)

segmenter.initialize(image=image_data, image_embeddings=image_embeddings)
# If we run AIS with tiling then we use the same tile shape for the watershed postprocessing.
if isinstance(segmenter, InstanceSegmentationWithDecoder) and tile_shape is not None:
generate_kwargs.update({"tile_shape": tile_shape, "halo": halo})

segmenter.initialize(image=image_data, image_embeddings=image_embeddings, verbose=verbose)
masks = segmenter.generate(**generate_kwargs)

if len(masks) == 0: # instance segmentation can have no masks, hence we just save empty labels
Expand Down
104 changes: 90 additions & 14 deletions micro_sam/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

import vigra
import numpy as np
import elf.parallel as parallel
from elf.parallel.filters import apply_filter
from skimage.measure import label, regionprops
from skimage.segmentation import relabel_sequential

Expand Down Expand Up @@ -559,11 +561,13 @@ def generate(


# Helper function for tiled embedding computation and checking consistent state.
def _process_tiled_embeddings(predictor, image, image_embeddings, tile_shape, halo):
def _process_tiled_embeddings(predictor, image, image_embeddings, tile_shape, halo, verbose):
if image_embeddings is None:
if tile_shape is None or halo is None:
raise ValueError("To compute tiled embeddings the parameters tile_shape and halo have to be passed.")
image_embeddings = util.precompute_image_embeddings(predictor, image, tile_shape=tile_shape, halo=halo)
image_embeddings = util.precompute_image_embeddings(
predictor, image, tile_shape=tile_shape, halo=halo, verbose=verbose
)

# Use tile shape and halo from the precomputed embeddings if not given.
# Otherwise check that they are consistent.
Expand Down Expand Up @@ -650,7 +654,7 @@ def initialize(
self._original_size = original_size

image_embeddings, tile_shape, halo = _process_tiled_embeddings(
self._predictor, image, image_embeddings, tile_shape, halo
self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose,
)

tiling = blocking([0, 0], original_size, tile_shape)
Expand Down Expand Up @@ -853,6 +857,55 @@ def get_predictor_and_decoder(
return predictor, decoder


def _watershed_from_center_and_boundary_distances_parallel(
    center_distances,
    boundary_distances,
    foreground_map,
    center_distance_threshold,
    boundary_distance_threshold,
    foreground_threshold,
    distance_smoothing,
    min_size,
    tile_shape,
    halo,
    n_threads,
    verbose=False,
):
    """Run the distance-based watershed post-processing block-wise with `elf.parallel`.

    The steps are:
    1. Smooth both distance maps with a block-wise "gaussianSmoothing" filter.
    2. Threshold the foreground map to obtain the foreground mask.
    3. Derive a seed mask from pixels where both smoothed distances are below
       their thresholds, restricted to the foreground mask.
    4. Label the seed mask, run a seeded watershed on the smoothed boundary
       distances within the foreground mask, and apply a size filter.

    Args:
        center_distances: The center distance predictions.
        boundary_distances: The boundary distance predictions.
        foreground_map: The foreground predictions.
        center_distance_threshold: Smoothed center distances below this value can become seeds.
        boundary_distance_threshold: Smoothed boundary distances below this value can become seeds.
        foreground_threshold: Values of `foreground_map` above this threshold count as foreground.
        distance_smoothing: Sigma passed to the Gaussian smoothing of both distance maps.
        min_size: Minimal size passed to the size filter.
        tile_shape: Block shape used for all block-wise operations.
        halo: Halo passed to the seeded watershed.
        n_threads: Number of threads passed to all block-wise operations.
        verbose: Verbosity flag passed to the labeling, watershed and size filter steps.

    Returns:
        The size-filtered segmentation, written into a uint64 output array.
    """
    # Arguments shared by every block-wise call below.
    block_kwargs = {"block_shape": tile_shape, "n_threads": n_threads}

    smoothed_centers = apply_filter(
        center_distances, "gaussianSmoothing", sigma=distance_smoothing, **block_kwargs
    )
    smoothed_boundaries = apply_filter(
        boundary_distances, "gaussianSmoothing", sigma=distance_smoothing, **block_kwargs
    )

    foreground_mask = foreground_map > foreground_threshold

    # Seeds need both distances below their thresholds and must lie in the foreground.
    below_center = smoothed_centers < center_distance_threshold
    below_boundary = smoothed_boundaries < boundary_distance_threshold
    seed_mask = np.logical_and(below_center, below_boundary)
    seed_mask[~foreground_mask] = 0

    seeds = parallel.label(
        seed_mask,
        out=np.zeros(seed_mask.shape, dtype="uint64"),
        verbose=verbose,
        **block_kwargs,
    )

    watershed_result = parallel.seeded_watershed(
        smoothed_boundaries,
        seeds=seeds,
        out=np.zeros_like(seeds, dtype="uint64"),
        halo=halo,
        mask=foreground_mask,
        verbose=verbose,
        **block_kwargs,
    )

    return parallel.size_filter(
        watershed_result,
        out=np.zeros_like(watershed_result, dtype="uint64"),
        min_size=min_size,
        verbose=verbose,
        **block_kwargs,
    )


class InstanceSegmentationWithDecoder:
"""Generates an instance segmentation without prompts, using a decoder.
Expand Down Expand Up @@ -988,6 +1041,9 @@ def generate(
distance_smoothing: float = 1.6,
min_size: int = 0,
output_mode: Optional[str] = "binary_mask",
tile_shape: Optional[Tuple[int, int]] = None,
halo: Optional[Tuple[int, int]] = None,
n_threads: Optional[int] = None,
) -> List[Dict[str, Any]]:
"""Generate instance segmentation for the currently initialized image.
Expand All @@ -1002,6 +1058,11 @@ def generate(
distance_smoothing: Sigma value for smoothing the distance predictions.
min_size: Minimal object size in the segmentation result.
output_mode: The form masks are returned in. Pass None to directly return the instance segmentation.
tile_shape: Tile shape for parallelizing the instance segmentation post-processing.
This parameter is independent from the tile shape for computing the embeddings.
If not given then post-processing will not be parallelized.
halo: Halo for parallel post-processing. See also `tile_shape`.
n_threads: Number of threads for parallel post-processing. See also `tile_shape`.
Returns:
The instance segmentation masks.
Expand All @@ -1013,16 +1074,29 @@ def generate(
foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing)
else:
foreground = self._foreground
# Further optimization: parallel implementation using elf.parallel functionality.
# (Make sure to expose n_threads to avoid over-subscription in case of outer parallelization)
segmentation = watershed_from_center_and_boundary_distances(
self._center_distances, self._boundary_distances, foreground,
center_distance_threshold=center_distance_threshold,
boundary_distance_threshold=boundary_distance_threshold,
foreground_threshold=foreground_threshold,
distance_smoothing=distance_smoothing,
min_size=min_size,
)

if tile_shape is None:
segmentation = watershed_from_center_and_boundary_distances(
self._center_distances, self._boundary_distances, foreground,
center_distance_threshold=center_distance_threshold,
boundary_distance_threshold=boundary_distance_threshold,
foreground_threshold=foreground_threshold,
distance_smoothing=distance_smoothing,
min_size=min_size,
)
else:
if halo is None:
raise ValueError("You must pass a value for halo if tile_shape is given.")
segmentation = _watershed_from_center_and_boundary_distances_parallel(
self._center_distances, self._boundary_distances, foreground,
center_distance_threshold=center_distance_threshold,
boundary_distance_threshold=boundary_distance_threshold,
foreground_threshold=foreground_threshold,
distance_smoothing=distance_smoothing,
min_size=min_size, tile_shape=tile_shape,
halo=halo, n_threads=n_threads, verbose=False,
)

if output_mode is not None:
segmentation = self._to_masks(segmentation, output_mode)
return segmentation
Expand Down Expand Up @@ -1094,7 +1168,7 @@ def initialize(
"""
original_size = image.shape[:2]
image_embeddings, tile_shape, halo = _process_tiled_embeddings(
self._predictor, image, image_embeddings, tile_shape, halo
self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose,
)
tiling = blocking([0, 0], original_size, tile_shape)

Expand Down Expand Up @@ -1127,6 +1201,8 @@ def initialize(
foreground[inner_bb] = output[0][local_bb]
center_distances[inner_bb] = output[1][local_bb]
boundary_distances[inner_bb] = output[2][local_bb]
pbar_update(1)

pbar_close()

# Set the state.
Expand Down

0 comments on commit c319e36

Please sign in to comment.