Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multithread batch sampler for PatchInferer #6139

Merged
merged 9 commits into from
Jul 15, 2023
16 changes: 15 additions & 1 deletion monai/inferers/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from monai.apps.utils import get_logger
from monai.data.meta_tensor import MetaTensor
from monai.data.thread_buffer import ThreadBuffer
from monai.inferers.merger import AvgMerger, Merger
from monai.inferers.splitter import Splitter
from monai.inferers.utils import compute_importance_map, sliding_window_inference
Expand Down Expand Up @@ -103,6 +104,7 @@
the output dictionary to be used for merging.
Defaults to None, where all the keys are used.
match_spatial_shape: whether to crop the output to match the input shape. Defaults to True.
buffer_size: number of patches to be held in the buffer with a separate thread for batch sampling. Defaults to 0.
merger_kwargs: arguments to be passed to `merger_cls` for instantiation.
`merged_shape` is calculated automatically based on the input shape and
the output patch shape unless it is passed here.
Expand All @@ -117,6 +119,7 @@
postprocessing: Callable | None = None,
output_keys: Sequence | None = None,
match_spatial_shape: bool = True,
buffer_size: int = 0,
**merger_kwargs: Any,
) -> None:
Inferer.__init__(self)
Expand Down Expand Up @@ -157,6 +160,8 @@
self.postprocessing = postprocessing

# batch size for patches
if batch_size < 1:
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(f"`batch_size` must be a positive number, {batch_size} is given.")

Check warning on line 164 in monai/inferers/inferer.py

View check run for this annotation

Codecov / codecov/patch

monai/inferers/inferer.py#L164

Added line #L164 was not covered by tests
self.batch_size = batch_size

# model output keys
Expand All @@ -165,6 +170,9 @@
# whether to crop the output to match the input shape
self.match_spatial_shape = match_spatial_shape

# buffer size for multithreaded batch sampling
self.buffer_size = buffer_size

def _batch_sampler(
self, patches: Iterable[tuple[torch.Tensor, Sequence[int]]] | MetaTensor
) -> Iterator[tuple[torch.Tensor, Sequence, int]]:
Expand All @@ -182,10 +190,16 @@
batch_size = min(self.batch_size, total_size - i)
yield patches[i : i + batch_size], patches[i : i + batch_size].meta[PatchKeys.LOCATION], batch_size # type: ignore
else:
buffer: Iterable | ThreadBuffer
if self.buffer_size > 0:
# Use multi-threading to sample patches with a buffer
buffer = ThreadBuffer(patches, buffer_size=self.buffer_size, timeout=0.1)
else:
buffer = patches
patch_batch: list[Any] = [None] * self.batch_size
location_batch: list[Any] = [None] * self.batch_size
idx_in_batch = 0
for sample in patches:
for sample in buffer:
patch_batch[idx_in_batch] = sample[0]
location_batch[idx_in_batch] = sample[1]
idx_in_batch += 1
Expand Down
20 changes: 20 additions & 0 deletions tests/test_patch_inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@
TENSOR_4x4,
]


# non-divisible patch_size leading to larger image (without matching spatial shape)
TEST_CASE_11_PADDING = [
TENSOR_4x4,
Expand Down Expand Up @@ -155,6 +156,23 @@
TENSOR_4x4,
]

# multi-threading
TEST_CASE_14_MULTITHREAD_BUFFER = [
TENSOR_4x4,
dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger, buffer_size=2),
lambda x: x,
TENSOR_4x4,
]

# multi-threading with batch
TEST_CASE_15_MULTITHREADD_BUFFER = [
TENSOR_4x4,
dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger, buffer_size=4, batch_size=4),
lambda x: x,
TENSOR_4x4,
]


# list of tensor output
TEST_CASE_0_LIST_TENSOR = [
TENSOR_4x4,
Expand Down Expand Up @@ -245,6 +263,8 @@ class PatchInfererTests(unittest.TestCase):
TEST_CASE_11_PADDING,
TEST_CASE_12_MATCHING,
TEST_CASE_13_PADDING_MATCHING,
TEST_CASE_14_MULTITHREAD_BUFFER,
TEST_CASE_15_MULTITHREADD_BUFFER,
]
)
def test_patch_inferer_tensor(self, inputs, arguments, network, expected):
Expand Down