From 2b9ec3188221fc7b79c85e3baa6c6f590800ca86 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Mon, 10 Jul 2023 14:11:47 -0400 Subject: [PATCH 1/3] implement multithread buffer for batch sampling Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/inferers/inferer.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index f9c04664be..5f3ffb02d5 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -22,6 +22,7 @@ from monai.apps.utils import get_logger from monai.data.meta_tensor import MetaTensor +from monai.data.thread_buffer import ThreadBuffer from monai.inferers.merger import AvgMerger, Merger from monai.inferers.splitter import Splitter from monai.inferers.utils import compute_importance_map, sliding_window_inference @@ -103,6 +104,7 @@ class PatchInferer(Inferer): the output dictionary to be used for merging. Defaults to None, where all the keys are used. match_spatial_shape: whether to crop the output to match the input shape. Defaults to True. + buffer_size: number of patches to be held in the buffer with a separate thread for batch sampling. Defaults to 0. merger_kwargs: arguments to be passed to `merger_cls` for instantiation. `merged_shape` is calculated automatically based on the input shape and the output patch shape unless it is passed here. @@ -117,6 +119,7 @@ def __init__( postprocessing: Callable | None = None, output_keys: Sequence | None = None, match_spatial_shape: bool = True, + buffer_size: int = 0, **merger_kwargs: Any, ) -> None: Inferer.__init__(self) @@ -157,6 +160,8 @@ def __init__( self.postprocessing = postprocessing # batch size for patches + if batch_size < 1: + raise ValueError(f"`batch_size` must be a positive number, {batch_size} is given.") self.batch_size = batch_size # model output keys @@ -165,6 +170,9 @@ def __init__( # whether to crop the output to match the input shape self.match_spatial_shape = match_spatial_shape + # buffer size for multithreaded batch sampling + self.buffer_size = buffer_size + def _batch_sampler( self, patches: Iterable[tuple[torch.Tensor, Sequence[int]]] | MetaTensor ) -> Iterator[tuple[torch.Tensor, Sequence, int]]: @@ -182,10 +190,16 @@ def _batch_sampler( batch_size = min(self.batch_size, total_size - i) yield patches[i : i + batch_size], patches[i : i + batch_size].meta[PatchKeys.LOCATION], batch_size # type: ignore else: + buffer: Iterable | ThreadBuffer + if self.buffer_size > 0: + # Use multi-threading to sample patches with a buffer + buffer = ThreadBuffer(patches, buffer_size=self.buffer_size, timeout=0.01) + else: + buffer = patches patch_batch: list[Any] = [None] * self.batch_size location_batch: list[Any] = [None] * self.batch_size idx_in_batch = 0 - for sample in patches: + for sample in buffer: patch_batch[idx_in_batch] = sample[0] location_batch[idx_in_batch] = sample[1] idx_in_batch += 1 From a40b29a7854e5774a4243c81e18bdf31df411849 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Mon, 10 Jul 2023 14:12:02 -0400 Subject: [PATCH 2/3] add unittests for multithread buffer Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- tests/test_patch_inferer.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/test_patch_inferer.py b/tests/test_patch_inferer.py index b0d25a98b9..bc6fc30a88 100644 --- a/tests/test_patch_inferer.py +++ b/tests/test_patch_inferer.py @@ -127,6 +127,7 @@ TENSOR_4x4, ] + # non-divisible patch_size leading to larger image (without matching spatial shape) TEST_CASE_11_PADDING = [ TENSOR_4x4, @@ -155,6 +156,23 @@ TENSOR_4x4, ] +# multi-threading +TEST_CASE_14_MULTITHREAD_BUFFER = [ + TENSOR_4x4, + dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger, buffer_size=2), + lambda x: x, + TENSOR_4x4, +] + +# multi-threading with batch +TEST_CASE_15_MULTITHREADD_BUFFER = [ + TENSOR_4x4, + dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger, buffer_size=4, batch_size=4), + lambda x: x, + TENSOR_4x4, +] + + # list of tensor output TEST_CASE_0_LIST_TENSOR = [ TENSOR_4x4, @@ -245,6 +263,8 @@ class PatchInfererTests(unittest.TestCase): TEST_CASE_11_PADDING, TEST_CASE_12_MATCHING, TEST_CASE_13_PADDING_MATCHING, + TEST_CASE_14_MULTITHREAD_BUFFER, + TEST_CASE_15_MULTITHREADD_BUFFER, ] ) def test_patch_inferer_tensor(self, inputs, arguments, network, expected): From aa34cc1936e043eef8080a20c06720b3fe4886ba Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Thu, 13 Jul 2023 11:32:01 -0400 Subject: [PATCH 3/3] increase timeout Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/inferers/inferer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 5f3ffb02d5..5484970d82 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -193,7 +193,7 @@ def _batch_sampler( buffer: Iterable | ThreadBuffer if self.buffer_size > 0: # Use multi-threading to sample patches with a buffer - buffer = ThreadBuffer(patches, buffer_size=self.buffer_size, timeout=0.01) + buffer = ThreadBuffer(patches, buffer_size=self.buffer_size, timeout=0.1) else: buffer = patches patch_batch: list[Any] = [None] * self.batch_size