From d983ff08da116ba71fcdaaa77d4c3eaf30c02b02 Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Thu, 1 Feb 2024 20:01:57 +0100 Subject: [PATCH] Make subampling maskers to output accel --- direct/common/subsample.py | 245 +++++++++++++++++++++++----------- direct/data/mri_transforms.py | 46 +++++-- 2 files changed, 204 insertions(+), 87 deletions(-) diff --git a/direct/common/subsample.py b/direct/common/subsample.py index 7762c5e7..5ba07131 100644 --- a/direct/common/subsample.py +++ b/direct/common/subsample.py @@ -16,8 +16,13 @@ import torch import direct.data.transforms as T -from direct.common._gaussian import gaussian_mask_1d, gaussian_mask_2d # pylint: disable=no-name-in-module -from direct.common._poisson import poisson as _poisson # pylint: disable=no-name-in-module +from direct.common._gaussian import ( + gaussian_mask_1d, + gaussian_mask_2d, +) # pylint: disable=no-name-in-module +from direct.common._poisson import ( + poisson as _poisson, +) # pylint: disable=no-name-in-module from direct.environment import DIRECT_CACHE_DIR from direct.types import Number from direct.utils import str_to_class @@ -33,7 +38,6 @@ "FastMRIEquispacedMaskFunc", "FastMRIMagicMaskFunc", "FastMRIRandomMaskFunc", - "Gaussian1DMaskFunc", "Gaussian2DMaskFunc", "RadialMaskFunc", "SpiralMaskFunc", @@ -88,6 +92,11 @@ def __init__( self.accelerations = accelerations self.uniform_range = uniform_range + if uniform_range and (len(center_fractions) != 2 or len(accelerations) != 2): + raise ValueError( + f"When `uniform_range` is True, both `center_fractions` and `accelerations` should have " + f"a length of two. Received center_fractions={center_fractions} and accelerations={accelerations}." + ) self.rng = np.random.RandomState() @@ -95,35 +104,43 @@ def choose_acceleration(self): if not self.accelerations: return None - if not self.uniform_range: + if self.uniform_range: + acceleration = self.rng.uniform(low=min(self.accelerations), high=max(self.accelerations), size=1)[0] + if self.center_fractions is None: + return acceleration + center_fraction = self.rng.uniform( + low=min(self.center_fractions), high=max(self.center_fractions), size=1 + )[0] + else: choice = self.rng.randint(0, len(self.accelerations)) acceleration = self.accelerations[choice] + if self.center_fractions is None: return acceleration center_fraction = self.center_fractions[choice] - return center_fraction, acceleration - raise NotImplementedError("Uniform range is not yet implemented.") + center_fraction = min(1 / acceleration, center_fraction) + return center_fraction, acceleration @abstractmethod def mask_func(self, *args, **kwargs): raise NotImplementedError("This method should be implemented by a child class.") - def __call__(self, *args, **kwargs) -> torch.Tensor: + def __call__(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, float, float]]: """Produces a sampling mask by calling class method :meth:`mask_func`. + This also might return additional values. + Parameters ---------- - *args - **kwargs + - *args: Variable length arguments. + - **kwargs: Variable keyword arguments. Returns ------- - mask: torch.Tensor - Sampling mask. + The return values of :meth:`mask_func` method. """ - mask = self.mask_func(*args, **kwargs) - return mask + return self.mask_func(*args, **kwargs) class FastMRIMaskFunc(BaseMaskFunc): @@ -204,7 +221,8 @@ def mask_func( shape: Union[List[int], Tuple[int, ...]], return_acs: bool = False, seed: Optional[Union[int, Iterable[int]]] = None, - ) -> torch.Tensor: + return_acceleration: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, float, float]]: """Creates vertical line mask. Parameters @@ -217,12 +235,13 @@ def mask_func( seed: int or iterable of ints or None (optional) Seed for the random number generator. Setting the seed ensures the same mask is generated each time for the same shape. Default: None. - + return_acceleration : bool + If True, output will contain acceleration and center_fraction. Default: False. Returns ------- - mask: torch.Tensor - The sampling mask. + Union[torch.Tensor, Tuple[torch.Tensor, float, float]] + The sampling mask, and optionally, acceleration and center_fraction. """ if len(shape) < 3: raise ValueError("Shape should have 3 or more dimensions") @@ -236,13 +255,15 @@ def mask_func( mask = self.center_mask_func(num_cols, num_low_freqs) if return_acs: - return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) + torch_mask = torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) + return (torch_mask, acceleration, center_fraction) if return_acceleration else torch_mask # Create the mask prob = (num_cols / acceleration - num_low_freqs) / (num_cols - num_low_freqs) mask = mask | (self.rng.uniform(size=num_cols) < prob) - return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) + torch_mask = torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) + return (torch_mask, acceleration, center_fraction) if return_acceleration else torch_mask class CartesianRandomMaskFunc(FastMRIRandomMaskFunc): @@ -282,7 +303,8 @@ def mask_func( shape: Union[List[int], Tuple[int, ...]], return_acs: bool = False, seed: Optional[Union[int, Iterable[int]]] = None, - ) -> torch.Tensor: + return_acceleration: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, float, float]]: """Creates a random vertical Cartesian mask. Parameters @@ -295,12 +317,13 @@ def mask_func( seed: int or iterable of ints or None (optional) Seed for the random number generator. Setting the seed ensures the same mask is generated each time for the same shape. Default: None. - + return_acceleration : bool + If True, output will contain acceleration and num_center_lines. Default: False. Returns ------- - mask: torch.Tensor - The sampling mask. + Union[torch.Tensor, Tuple[torch.Tensor, float, float]] + The sampling mask, and optionally, acceleration and num_center_lines. """ if len(shape) < 3: raise ValueError("Shape should have 3 or more dimensions") @@ -313,13 +336,15 @@ def mask_func( mask = self.center_mask_func(num_cols, num_center_lines) if return_acs: - return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) + torch_mask = torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) + return (torch_mask, acceleration, num_center_lines) if return_acceleration else torch_mask # Create the mask prob = (num_cols / acceleration - num_center_lines) / (num_cols - num_center_lines) mask = mask | (self.rng.uniform(size=num_cols) < prob) - return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) + torch_mask = torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) + return (torch_mask, acceleration, num_center_lines) if return_acceleration else torch_mask class FastMRIEquispacedMaskFunc(FastMRIMaskFunc): @@ -359,7 +384,8 @@ def mask_func( shape: Union[List[int], Tuple[int, ...]], return_acs: bool = False, seed: Optional[Union[int, Iterable[int]]] = None, - ) -> torch.Tensor: + return_acceleration: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, float, float]]: """Creates an vertical equispaced vertical line mask. Parameters @@ -372,11 +398,13 @@ def mask_func( seed: int or iterable of ints or None (optional) Seed for the random number generator. Setting the seed ensures the same mask is generated each time for the same shape. Default: None. + return_acceleration : bool + If True, output will contain acceleration and center_fraction. Default: False. Returns ------- - mask: torch.Tensor - The sampling mask. + Union[torch.Tensor, Tuple[torch.Tensor, float, float]] + The sampling mask, and optionally, acceleration and center_fraction. """ if len(shape) < 3: raise ValueError("Shape should have 3 or more dimensions") @@ -389,8 +417,9 @@ def mask_func( mask = self.center_mask_func(num_cols, num_low_freqs) - if return_acs: - return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) + if return_acs or (num_low_freqs - num_cols // acceleration >= 0): + torch_mask = torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) + return (torch_mask, acceleration, center_fraction) if return_acceleration else torch_mask # determine acceleration rate by adjusting for the number of low frequencies adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / (num_low_freqs * acceleration - num_cols) @@ -400,7 +429,8 @@ def mask_func( accel_samples = np.around(accel_samples).astype(np.uint) mask[accel_samples] = True - return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) + torch_mask = torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) + return (torch_mask, acceleration, center_fraction) if return_acceleration else torch_mask class CartesianEquispacedMaskFunc(FastMRIEquispacedMaskFunc): @@ -440,7 +470,8 @@ def mask_func( shape: Union[List[int], Tuple[int, ...]], return_acs: bool = False, seed: Optional[Union[int, Iterable[int]]] = None, - ) -> torch.Tensor: + return_acceleration: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, float, float]]: """Creates an equispaced vertical Cartesian mask. Parameters @@ -453,12 +484,13 @@ def mask_func( seed: int or iterable of ints or None (optional) Seed for the random number generator. Setting the seed ensures the same mask is generated each time for the same shape. Default: None. - + return_acceleration : bool + If True, output will contain acceleration and num_center_lines. Default: False. Returns ------- - mask: torch.Tensor - The sampling mask. + Union[torch.Tensor, Tuple[torch.Tensor, float, float]] + The sampling mask, and optionally, acceleration and num_center_lines. """ if len(shape) < 3: raise ValueError("Shape should have 3 or more dimensions") @@ -472,7 +504,8 @@ def mask_func( mask = self.center_mask_func(num_cols, num_center_lines) if return_acs: - return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) + torch_mask = torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) + return (torch_mask, acceleration, num_center_lines) if return_acceleration else torch_mask # determine acceleration rate by adjusting for the number of low frequencies adjusted_accel = (acceleration * (num_center_lines - num_cols)) / ( @@ -484,7 +517,8 @@ def mask_func( accel_samples = np.around(accel_samples).astype(np.uint) mask[accel_samples] = True - return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) + torch_mask = torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) + return (torch_mask, acceleration, num_center_lines) if return_acceleration else torch_mask class FastMRIMagicMaskFunc(FastMRIMaskFunc): @@ -517,7 +551,8 @@ def mask_func( shape: Union[List[int], Tuple[int, ...]], return_acs: bool = False, seed: Optional[Union[int, Iterable[int]]] = None, - ) -> torch.Tensor: + return_acceleration: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, float, float]]: r"""Creates a vertical equispaced mask that exploits conjugate symmetry. @@ -531,11 +566,13 @@ def mask_func( seed: int or iterable of ints or None (optional) Seed for the random number generator. Setting the seed ensures the same mask is generated each time for the same shape. Default: None. + return_acceleration : bool + If True, output will contain acceleration and center_fraction. Default: False. Returns ------- - mask: torch.Tensor - The sampling mask. + Union[torch.Tensor, Tuple[torch.Tensor, float, float]] + The sampling mask, and optionally, acceleration and center_fraction. """ if len(shape) < 3: raise ValueError("Shape should have 3 or more dimensions") @@ -553,7 +590,8 @@ def mask_func( acs_mask = self.center_mask_func(num_cols, num_low_freqs) if return_acs: - return torch.from_numpy(self._reshape_and_broadcast_mask(shape, acs_mask)) + torch_mask = torch.from_numpy(self._reshape_and_broadcast_mask(shape, acs_mask)) + return (torch_mask, acceleration, center_fraction) if return_acceleration else torch_mask # adjust acceleration rate based on target acceleration. adjusted_target_cols_to_sample = target_cols_to_sample - num_low_freqs @@ -582,7 +620,8 @@ def mask_func( mask = np.fft.fftshift(np.concatenate((mask_positive, mask_negative))) mask = np.logical_or(mask, acs_mask) - return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) + torch_mask = torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) + return (torch_mask, acceleration, center_fraction) if return_acceleration else torch_mask class CartesianMagicMaskFunc(FastMRIMagicMaskFunc): @@ -627,7 +666,8 @@ def mask_func( shape: Union[List[int], Tuple[int, ...]], return_acs: bool = False, seed: Optional[Union[int, Iterable[int]]] = None, - ) -> torch.Tensor: + return_acceleration: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, float, float]]: r"""Creates an equispaced Cartesian mask that exploits conjugate symmetry. Parameters @@ -640,11 +680,13 @@ def mask_func( seed: int or iterable of ints or None (optional) Seed for the random number generator. Setting the seed ensures the same mask is generated each time for the same shape. Default: None. + return_acceleration : bool + If True, output will contain acceleration and num_center_lines. Default: False. Returns ------- - mask: torch.Tensor - The sampling mask. + Union[torch.Tensor, Tuple[torch.Tensor, float, float]] + The sampling mask, and optionally, acceleration and num_center_lines. """ if len(shape) < 3: raise ValueError("Shape should have 3 or more dimensions") @@ -661,7 +703,8 @@ def mask_func( acs_mask = self.center_mask_func(num_cols, num_center_lines) if return_acs: - return torch.from_numpy(self._reshape_and_broadcast_mask(shape, acs_mask)) + torch_mask = torch.from_numpy(self._reshape_and_broadcast_mask(shape, acs_mask)) + return (torch_mask, acceleration, num_center_lines) if return_acceleration else torch_mask # adjust acceleration rate based on target acceleration. adjusted_target_cols_to_sample = target_cols_to_sample - num_center_lines @@ -690,7 +733,8 @@ def mask_func( mask = np.fft.fftshift(np.concatenate((mask_positive, mask_negative))) mask = np.logical_or(mask, acs_mask) - return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) + torch_mask = torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) + return (torch_mask, acceleration, num_center_lines) if return_acceleration else torch_mask class CalgaryCampinasMaskFunc(BaseMaskFunc): @@ -731,10 +775,11 @@ def mask_func( shape: Union[List[int], Tuple[int, ...]], return_acs: bool = False, seed: Optional[Union[int, Iterable[int]]] = None, - ) -> torch.Tensor: + return_acceleration: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, float, float]]: r"""Downloads and loads pre-computed Poisson masks. - Currently supports shapes of :math`218 \times 170/174/180` and acceleration factors of `5` or `10`. + Currently, supports shapes of :math`218 \times 170/174/180` and acceleration factors of `5` or `10`. Parameters ---------- @@ -746,28 +791,37 @@ def mask_func( seed: int or iterable of ints or None (optional) Seed for the random number generator. Setting the seed ensures the same mask is generated each time for the same shape. Default: None. + return_acceleration : bool + If True, output will contain acceleration and center_fraction. Default: False. Returns ------- - mask: torch.Tensor - The sampling mask. + Union[torch.Tensor, Tuple[torch.Tensor, float, float]] + The sampling mask, and optionally, acceleration and center_fraction. """ shape = tuple(shape)[:-1] - if return_acs: - return torch.from_numpy(self.circular_centered_mask(shape, 18)) if shape not in self.shapes: raise ValueError(f"No mask of shape {shape} is available in the CalgaryCampinas dataset.") with temp_seed(self.rng, seed): acceleration = self.choose_acceleration() + + acs_mask = self.circular_centered_mask(shape, 18) + center_fraction = acs_mask.sum() / np.prod(acs_mask.shape) + + if return_acs: + torch_mask = torch.from_numpy(acs_mask) + return (torch_mask, acceleration, center_fraction) if return_acceleration else torch_mask + masks = self.masks[acceleration] mask, num_masks = masks[shape] # Randomly pick one example choice = self.rng.randint(0, num_masks) - return torch.from_numpy(mask[choice][np.newaxis, ..., np.newaxis]) + torch_mask = torch.from_numpy(mask[choice][np.newaxis, ..., np.newaxis]) + return (torch_mask, acceleration, center_fraction) if return_acceleration else torch_mask def __load_masks(self, acceleration): masks_path = DIRECT_CACHE_DIR / "calgary_campinas_masks" @@ -954,7 +1008,8 @@ def mask_func( shape: Union[List[int], Tuple[int, ...]], return_acs: bool = False, seed: Optional[Union[int, Iterable[int]]] = None, - ) -> torch.Tensor: + return_acceleration: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, float, float]]: """Produces :class:`CIRCUSMaskFunc` sampling masks. Parameters @@ -967,11 +1022,13 @@ def mask_func( seed: int or iterable of ints or None (optional) Seed for the random number generator. Setting the seed ensures the same mask is generated each time for the same shape. Default: None. + return_acceleration : bool + If True, output will contain acceleration and center_fraction. Default: False. Returns ------- - mask: torch.Tensor - The sampling mask. + Union[torch.Tensor, Tuple[torch.Tensor, float, float]] + The sampling mask, and optionally, acceleration and center_fraction. """ if len(shape) < 3: @@ -993,10 +1050,14 @@ def mask_func( acceleration=acceleration, ) + acs_mask = self.circular_centered_mask(mask).unsqueeze(0).unsqueeze(-1) + center_fraction = acs_mask.sum() / np.prod(acs_mask.shape) + if return_acs: - return self.circular_centered_mask(mask).unsqueeze(0).unsqueeze(-1) + return (acs_mask, acceleration, center_fraction) if return_acceleration else acs_mask - return mask.unsqueeze(0).unsqueeze(-1) + mask = mask.unsqueeze(0).unsqueeze(-1) + return (mask, acceleration, center_fraction) if return_acceleration else mask class RadialMaskFunc(CIRCUSMaskFunc): @@ -1099,7 +1160,8 @@ def mask_func( shape: Union[List[int], Tuple[int, ...]], return_acs: bool = False, seed: Optional[Union[int, Iterable[int]]] = None, - ) -> torch.Tensor: + return_acceleration: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, float, float]]: """Produces variable Density Poisson sampling masks. Parameters @@ -1112,11 +1174,13 @@ def mask_func( seed: int or iterable of ints or None (optional) Seed for the random number generator. Setting the seed ensures the same mask is generated each time for the same shape. Default: None. + return_acceleration : bool + If True, output will contain acceleration and center_fraction. Default: False. Returns ------- - mask: torch.Tensor - The sampling mask of shape (1, shape[0], shape[1], 1). + Union[torch.Tensor, Tuple[torch.Tensor, float, float]] + The sampling mask, and optionally, acceleration and center_fraction. """ if len(shape) < 3: raise ValueError("Shape should have 3 or more dimensions") @@ -1124,12 +1188,17 @@ def mask_func( with temp_seed(self.rng, seed): center_fraction, acceleration = self.choose_acceleration() + if return_acs: - return torch.from_numpy( + acs_mask = torch.from_numpy( centered_disk_mask((num_rows, num_cols), center_fraction)[np.newaxis, ..., np.newaxis] ).bool() + return (acs_mask, acceleration, center_fraction) if return_acceleration else acs_mask + mask = self.poisson(num_rows, num_cols, center_fraction, acceleration, integerize_seed(seed)) - return torch.from_numpy(mask[np.newaxis, ..., np.newaxis]).bool() + + torch_mask = torch.from_numpy(mask[np.newaxis, ..., np.newaxis]).bool() + return (torch_mask, acceleration, center_fraction) if return_acceleration else torch_mask def poisson( self, @@ -1223,7 +1292,8 @@ def mask_func( shape: Union[List[int], Tuple[int, ...]], return_acs: bool = False, seed: Optional[Union[int, Iterable[int]]] = None, - ) -> torch.Tensor: + return_acceleration: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, float, float]]: """Creates a vertical gaussian mask. Parameters @@ -1236,12 +1306,13 @@ def mask_func( seed: int or iterable of ints or None (optional) Seed for the random number generator. Setting the seed ensures the same mask is generated each time for the same shape. Default: None. - + return_acceleration : bool + If True, output will contain acceleration and center_fraction. Default: False. Returns ------- - mask: torch.Tensor - The sampling mask. + Union[torch.Tensor, Tuple[torch.Tensor, float, float]] + The sampling mask, and optionally, acceleration and center_fraction. """ if len(shape) < 3: raise ValueError("Shape should have 3 or more dimensions") @@ -1255,15 +1326,22 @@ def mask_func( mask = self.center_mask_func(num_cols, num_low_freqs).astype(int) if return_acs: - return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) + torch_mask = torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask).astype(bool)) + return (torch_mask, acceleration, center_fraction) if return_acceleration else torch_mask # Calls cython function nonzero_count = int(np.round(num_cols / acceleration - num_low_freqs - 1)) gaussian_mask_1d( - nonzero_count, num_cols, num_cols // 2, 4 * np.sqrt(num_cols // 2), mask, integerize_seed(seed) + nonzero_count, + num_cols, + num_cols // 2, + 6 * np.sqrt(num_cols // 2), + mask, + integerize_seed(seed), ) - return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask).astype(bool)) + torch_mask = torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask).astype(bool)) + return (torch_mask, acceleration, center_fraction) if return_acceleration else torch_mask class Gaussian2DMaskFunc(BaseMaskFunc): @@ -1286,7 +1364,8 @@ def mask_func( shape: Union[List[int], Tuple[int, ...]], return_acs: bool = False, seed: Optional[Union[int, Iterable[int]]] = None, - ) -> torch.Tensor: + return_acceleration: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, float, float]]: """Creates a 2D gaussian mask. Parameters @@ -1299,12 +1378,13 @@ def mask_func( seed: int or iterable of ints or None (optional) Seed for the random number generator. Setting the seed ensures the same mask is generated each time for the same shape. Default: None. - + return_acceleration : bool + If True, output will contain acceleration and center_fraction. Default: False. Returns ------- - mask: torch.Tensor - The sampling mask. + Union[torch.Tensor, Tuple[torch.Tensor, float, float]] + The sampling mask, and optionally, acceleration and center_fraction. """ if len(shape) < 3: raise ValueError("Shape should have 3 or more dimensions") @@ -1316,16 +1396,25 @@ def mask_func( mask = centered_disk_mask((num_rows, num_cols), center_fraction) if return_acs: - return torch.from_numpy(mask[np.newaxis, ..., np.newaxis]).bool() + torch_mask = torch.from_numpy(mask[np.newaxis, ..., np.newaxis]).bool() + return (torch_mask, acceleration, center_fraction) if return_acceleration else torch_mask # Calls cython function nonzero_count = int(np.round(num_cols * num_rows / acceleration - mask.sum() - 1)) std = 6 * np.array([np.sqrt(num_rows // 2), np.sqrt(num_cols // 2)], dtype=float) gaussian_mask_2d( - nonzero_count, num_rows, num_cols, num_rows // 2, num_cols // 2, std, mask, integerize_seed(seed) + nonzero_count, + num_rows, + num_cols, + num_rows // 2, + num_cols // 2, + std, + mask, + integerize_seed(seed), ) - return torch.from_numpy(mask[np.newaxis, ..., np.newaxis]).bool() + torch_mask = torch.from_numpy(mask[np.newaxis, ..., np.newaxis]).bool() + return (torch_mask, acceleration, center_fraction) if return_acceleration else torch_mask def integerize_seed(seed: Union[None, Tuple[int, ...], List[int]]) -> int: diff --git a/direct/data/mri_transforms.py b/direct/data/mri_transforms.py index 19bb5b82..34518427 100644 --- a/direct/data/mri_transforms.py +++ b/direct/data/mri_transforms.py @@ -325,23 +325,42 @@ def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: if sample["kspace"].ndim == 5 and self.dynamic_mask: nz = sample["kspace"].shape[1] # Number of time frames or slices sampling_mask = [] + acceleration = [] + center_fraction = [] if self.return_acs: acs_mask = [] + seed = None if not self.use_seed else tuple(map(ord, str(sample["filename"]))) + + if seed: + np.random.seed(seed) + dynamic_seeds = [int(_) for _ in np.random.randint(0, 10000, nz)] + else: + dynamic_seeds = [None for _ in range(nz)] + for i in range(nz): - seed = None if not self.use_seed else tuple(map(ord, str(sample["filename"]))) - if seed: - seed = tuple(element + i for element in seed) - sampling_mask.append( - self.mask_func(shape=shape, seed=seed, return_acs=False).to(sample["kspace"].dtype) + sampling_mask_z, acceleration_z, center_fraction_z = self.mask_func( + shape=shape, seed=seed, return_acs=False, return_acceleration=True ) + + sampling_mask.append(sampling_mask_z.to(sample["kspace"].dtype)) + + acceleration.append(acceleration_z) + center_fraction.append(center_fraction_z) + if self.return_acs: - acs_mask.append(self.mask_func(shape=shape, seed=seed, return_acs=True).to(sample["kspace"].dtype)) + acs_mask.append( + self.mask_func(shape=shape, seed=dynamic_seeds[i], return_acs=True).to(sample["kspace"].dtype) + ) + sampling_mask = torch.stack(sampling_mask, dim=1) if self.return_acs: acs_mask = torch.stack(acs_mask, dim=1) else: seed = None if not self.use_seed else tuple(map(ord, str(sample["filename"]))) - sampling_mask = self.mask_func(shape=shape, seed=seed, return_acs=False).to(sample["kspace"].dtype) + + sampling_mask, acceleration, center_fraction = self.mask_func( + shape=shape, seed=seed, return_acs=False, return_acceleration=True + ) if sample["kspace"].ndim == 5: sampling_mask = sampling_mask.unsqueeze(1) @@ -351,8 +370,15 @@ def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: if sample["kspace"].ndim == 5: acs_mask = acs_mask.unsqueeze(1) + acceleration = [acceleration] + center_fraction = [center_fraction] + # Shape (1, [1 or nz], height, width, 1) - sample["sampling_mask"] = sampling_mask + sample["sampling_mask"] = sampling_mask.to(sample["kspace"].dtype) + + sample["acceleration"] = torch.tensor(acceleration, dtype=sample["kspace"].dtype) + sample["center_fraction"] = torch.tensor(center_fraction, dtype=sample["kspace"].dtype) + if self.return_acs: sample["acs_mask"] = acs_mask return sample @@ -2243,7 +2269,9 @@ def build_mri_transforms( mri_transforms += [ ComputeScalingFactor( - normalize_key=scaling_key, percentile=scale_percentile, scaling_factor_key=TransformKey.SCALING_FACTOR + normalize_key=scaling_key, + percentile=scale_percentile, + scaling_factor_key=TransformKey.SCALING_FACTOR, ), Normalize(scaling_factor_key=TransformKey.SCALING_FACTOR), ]