diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 88887a24c8..3ee9b2fa44 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -3443,7 +3443,9 @@ def filter_count(self, image_np: NdarrayOrTensor, locations: np.ndarray) -> tupl idx = self.R.permutation(tuple(range(image_np.shape[0]))) idx = idx[: self.num_patches] idx_np = convert_data_type(idx, np.ndarray)[0] - return image_np[idx], locations[idx_np] + image_np = image_np[idx] + locations = locations[idx_np] + return image_np, locations elif self.sort_fn not in (None, GridPatchSort.MIN, GridPatchSort.MAX): raise ValueError(f'`sort_fn` should be either "min", "max", "random" or None! {self.sort_fn} provided!') return super().filter_count(image_np, locations) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 4ba5849c46..79742f0582 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -2436,7 +2436,7 @@ class RandGridPatchd(RandomizableTransform, MapTransform, MultiSampleTrait): overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0). If only one float number is given, it will be applied to all dimensions. Defaults to 0.0. sort_fn: when `num_patches` is provided, it determines if keep patches with highest values (`"max"`), - lowest values (`"min"`), or in their default order (`None`). Default to None. + lowest values (`"min"`), in random ("random"), or in their default order (`None`). Default to None. threshold: a value to keep only the patches whose sum of intensities are less than the threshold. Defaults to no filtering. pad_mode: the mode for padding the input image by `patch_size` to include patches that cross boundaries.