From c319e36117d51c4df4a10b58d04c4f58d7f30283 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Wed, 5 Feb 2025 19:23:06 +0100 Subject: [PATCH] Parallelize AIS post-processing (#851) --- micro_sam/automatic_segmentation.py | 6 +- micro_sam/instance_segmentation.py | 104 ++++++++++++++++++++++++---- 2 files changed, 95 insertions(+), 15 deletions(-) diff --git a/micro_sam/automatic_segmentation.py b/micro_sam/automatic_segmentation.py index bcdb4c02..a51bc154 100644 --- a/micro_sam/automatic_segmentation.py +++ b/micro_sam/automatic_segmentation.py @@ -120,7 +120,11 @@ def automatic_instance_segmentation( verbose=verbose, ) - segmenter.initialize(image=image_data, image_embeddings=image_embeddings) + # If we run AIS with tiling then we use the same tile shape for the watershed postprocessing. + if isinstance(segmenter, InstanceSegmentationWithDecoder) and tile_shape is not None: + generate_kwargs.update({"tile_shape": tile_shape, "halo": halo}) + + segmenter.initialize(image=image_data, image_embeddings=image_embeddings, verbose=verbose) masks = segmenter.generate(**generate_kwargs) if len(masks) == 0: # instance segmentation can have no masks, hence we just save empty labels diff --git a/micro_sam/instance_segmentation.py b/micro_sam/instance_segmentation.py index f9ba8908..b03fb5d9 100644 --- a/micro_sam/instance_segmentation.py +++ b/micro_sam/instance_segmentation.py @@ -12,6 +12,8 @@ import vigra import numpy as np +import elf.parallel as parallel +from elf.parallel.filters import apply_filter from skimage.measure import label, regionprops from skimage.segmentation import relabel_sequential @@ -559,11 +561,13 @@ def generate( # Helper function for tiled embedding computation and checking consistent state. -def _process_tiled_embeddings(predictor, image, image_embeddings, tile_shape, halo): +def _process_tiled_embeddings(predictor, image, image_embeddings, tile_shape, halo, verbose): if image_embeddings is None: if tile_shape is None or halo is None: raise ValueError("To compute tiled embeddings the parameters tile_shape and halo have to be passed.") - image_embeddings = util.precompute_image_embeddings(predictor, image, tile_shape=tile_shape, halo=halo) + image_embeddings = util.precompute_image_embeddings( + predictor, image, tile_shape=tile_shape, halo=halo, verbose=verbose + ) # Use tile shape and halo from the precomputed embeddings if not given. # Otherwise check that they are consistent. @@ -650,7 +654,7 @@ def initialize( self._original_size = original_size image_embeddings, tile_shape, halo = _process_tiled_embeddings( - self._predictor, image, image_embeddings, tile_shape, halo + self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose, ) tiling = blocking([0, 0], original_size, tile_shape) @@ -853,6 +857,55 @@ def get_predictor_and_decoder( return predictor, decoder +def _watershed_from_center_and_boundary_distances_parallel( + center_distances, + boundary_distances, + foreground_map, + center_distance_threshold, + boundary_distance_threshold, + foreground_threshold, + distance_smoothing, + min_size, + tile_shape, + halo, + n_threads, + verbose=False, +): + center_distances = apply_filter( + center_distances, "gaussianSmoothing", sigma=distance_smoothing, + block_shape=tile_shape, n_threads=n_threads + ) + boundary_distances = apply_filter( + boundary_distances, "gaussianSmoothing", sigma=distance_smoothing, + block_shape=tile_shape, n_threads=n_threads + ) + + fg_mask = foreground_map > foreground_threshold + + marker_map = np.logical_and( + center_distances < center_distance_threshold, boundary_distances < boundary_distance_threshold + ) + marker_map[~fg_mask] = 0 + + markers = np.zeros(marker_map.shape, dtype="uint64") + markers = parallel.label( + marker_map, out=markers, block_shape=tile_shape, n_threads=n_threads, verbose=verbose, + ) + + seg = np.zeros_like(markers, dtype="uint64") + seg = parallel.seeded_watershed( + boundary_distances, seeds=markers, out=seg, block_shape=tile_shape, + halo=halo, n_threads=n_threads, verbose=verbose, mask=fg_mask, + ) + + out = np.zeros_like(seg, dtype="uint64") + out = parallel.size_filter( + seg, out=out, min_size=min_size, block_shape=tile_shape, n_threads=n_threads, verbose=verbose + ) + + return out + + class InstanceSegmentationWithDecoder: """Generates an instance segmentation without prompts, using a decoder. @@ -988,6 +1041,9 @@ def generate( distance_smoothing: float = 1.6, min_size: int = 0, output_mode: Optional[str] = "binary_mask", + tile_shape: Optional[Tuple[int, int]] = None, + halo: Optional[Tuple[int, int]] = None, + n_threads: Optional[int] = None, ) -> List[Dict[str, Any]]: """Generate instance segmentation for the currently initialized image. @@ -1002,6 +1058,11 @@ def generate( distance_smoothing: Sigma value for smoothing the distance predictions. min_size: Minimal object size in the segmentation result. output_mode: The form masks are returned in. Pass None to directly return the instance segmentation. + tile_shape: Tile shape for parallelizing the instance segmentation post-processing. + This parameter is independent from the tile shape for computing the embeddings. + If not given then post-processing will not be parallelized. + halo: Halo for parallel post-processing. See also `tile_shape`. + n_threads: Number of threads for parallel post-processing. See also `tile_shape`. Returns: The instance segmentation masks. @@ -1013,16 +1074,29 @@ def generate( foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing) else: foreground = self._foreground - # Further optimization: parallel implementation using elf.parallel functionality. - # (Make sure to expose n_threads to avoid over-subscription in case of outer parallelization) - segmentation = watershed_from_center_and_boundary_distances( - self._center_distances, self._boundary_distances, foreground, - center_distance_threshold=center_distance_threshold, - boundary_distance_threshold=boundary_distance_threshold, - foreground_threshold=foreground_threshold, - distance_smoothing=distance_smoothing, - min_size=min_size, - ) + + if tile_shape is None: + segmentation = watershed_from_center_and_boundary_distances( + self._center_distances, self._boundary_distances, foreground, + center_distance_threshold=center_distance_threshold, + boundary_distance_threshold=boundary_distance_threshold, + foreground_threshold=foreground_threshold, + distance_smoothing=distance_smoothing, + min_size=min_size, + ) + else: + if halo is None: + raise ValueError("You must pass a value for halo if tile_shape is given.") + segmentation = _watershed_from_center_and_boundary_distances_parallel( + self._center_distances, self._boundary_distances, foreground, + center_distance_threshold=center_distance_threshold, + boundary_distance_threshold=boundary_distance_threshold, + foreground_threshold=foreground_threshold, + distance_smoothing=distance_smoothing, + min_size=min_size, tile_shape=tile_shape, + halo=halo, n_threads=n_threads, verbose=False, + ) + if output_mode is not None: segmentation = self._to_masks(segmentation, output_mode) return segmentation @@ -1094,7 +1168,7 @@ def initialize( """ original_size = image.shape[:2] image_embeddings, tile_shape, halo = _process_tiled_embeddings( - self._predictor, image, image_embeddings, tile_shape, halo + self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose, ) tiling = blocking([0, 0], original_size, tile_shape) @@ -1127,6 +1201,8 @@ def initialize( foreground[inner_bb] = output[0][local_bb] center_distances[inner_bb] = output[1][local_bb] boundary_distances[inner_bb] = output[2][local_bb] + pbar_update(1) + pbar_close() # Set the state.