diff --git a/LIST_OF_PAPERS.md b/LIST_OF_PAPERS.md
index 6d159beb..bf2421b0 100644
--- a/LIST_OF_PAPERS.md
+++ b/LIST_OF_PAPERS.md
@@ -14,7 +14,9 @@ The following is a short list of fastMRI publications. Clicking on the title wil
 10. Defazio, A., Murrell, T., & Recht, M. P. (2020). [MRI Banding Removal via Adversarial Training](#mri-banding-removal-via-adversarial-training). In *Advances in Neural Information Processing Systems*, 33, pages 7660-7670.
 11. Muckley, M. J.\*, Riemenschneider, B.\*, Radmanesh, A., Kim, S., Jeong, G., Ko, J., ... & Knoll, F. (2021). [Results of the 2020 fastMRI Challenge for Machine Learning MR Image Reconstruction](#results-of-the-2020-fastmri-challenge-for-machine-learning-mr-image-reconstruction). *IEEE Transactions on Medical Imaging*, 40(9), pages 2306-2317.
 12. Johnson, P. M., Jeong, G., Hammernik, K., Schlemper, J., Qin, C., Duan, J., ..., & Knoll, F. [Evaluation of the Robustness of Learned MR Image Reconstruction to Systematic Deviations Between Training and Test Data for the Models from the fastMRI Challenge](#evaluation-of-the-robustness-of-learned-mr-image-reconstruction-to-systematic-deviations-between-training-and-test-data-for-the-models-from-the-fastmri-challenge). In *MICCAI Machine Learning for Medical Image Reconstruction Workshop*, pages 25–34, 2021.
-13. Bakker, T., Muckley, M.J., Romero-Soriano, A., Drozdzal, M. & Pineda, L. (2022). [On learning adaptive acquisition policies for undersampled multi-coil MRI reconstruction](https://arxiv.org/abs/2203.16392). *Accepted at MIDL, 2022*, to appear.
+13. Bakker, T., Muckley, M.J., Romero-Soriano, A., Drozdzal, M. & Pineda, L. (2022). [On learning adaptive acquisition policies for undersampled multi-coil MRI reconstruction](https://arxiv.org/abs/2203.16392). In *Medical Imaging with Deep Learning*.
+14. Radmanesh, A.\*, Muckley, M. J.\*, Murrell, T., Lindsey, E., Sriram, A., Knoll, F., ... & Lui, Y. W. (2022). [Exploring the Acceleration Limits of Deep Learning VarNet-based Two-dimensional Brain MRI](https://doi.org/10.1148/ryai.210313). *Radiology: Artificial Intelligence*, e210313.
+
 ## fastMRI: An open dataset and benchmarks for accelerated MRI
 
@@ -282,4 +284,35 @@ Most current approaches to undersampled multi-coil MRI reconstruction focus on l
 pages={to appear},
 year={2022},
 }
-```
\ No newline at end of file
+```
+
+## Exploring the Acceleration Limits of Deep Learning VarNet-based Two-dimensional Brain MRI
+
+[publication](https://doi.org/10.1148/ryai.210313)
+
+**Purpose**
+
+To explore the limits of deep learning-based brain MRI reconstruction and identify useful acceleration ranges for general-purpose imaging and potential screening.
+
+**Materials and Methods**
+
+In this retrospective study conducted from 2019 through 2021, a model was trained for reconstruction on 5,847 brain MRIs. Performance was evaluated across a wide range of accelerations (up to 100-fold along a single phase-encoded direction for two-dimensional [2D] slices) on the fastMRI test set collected by New York University, consisting of 558 image volumes. In a sample of 69 volumes, reconstructions were rated by radiologists to identify two clinical thresholds: 1) general-purpose diagnostic imaging and 2) potential use in a screening protocol. A Monte Carlo procedure was developed for estimating reconstruction error with only undersampled data. The model was evaluated on both in-domain and out-of-domain data. Confidence intervals were calculated using the percentile bootstrap method.
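+
+To make the Monte Carlo idea concrete, a minimal sketch follows. This is illustrative only: `reconstruct` and `resample` are hypothetical placeholders, not the paper's exact procedure. The model's own reconstruction serves as a surrogate reference, and error statistics are collected over repeated re-subsampling of the measured data.
+
+```python
+import torch
+
+def monte_carlo_error_estimate(masked_kspace, mask, reconstruct, resample, n_trials=10):
+    """Proxy reconstruction error from undersampled data alone (illustrative)."""
+    reference = reconstruct(masked_kspace, mask)  # surrogate "ground truth"
+    errors = []
+    for _ in range(n_trials):
+        # draw a fresh, more aggressive undersampling of the available data
+        kspace_i, mask_i = resample(masked_kspace, mask)
+        errors.append(torch.mean((reconstruct(kspace_i, mask_i) - reference) ** 2))
+    return torch.stack(errors).mean()
+```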
+
+**Results**
+
+Radiologists rated all 69 volumes (100%) as having sufficient image quality for general-purpose imaging at up to 4× acceleration and 65 of 69 volumes (94%) as having sufficient image quality for screening at up to 14× acceleration. The Monte Carlo procedure estimated ground truth peak signal-to-noise ratio and mean squared error with coefficients of determination greater than 0.5 at all accelerations. Out-of-distribution experiments demonstrated the model’s ability to produce images substantially distinct from the training set, even at 100× acceleration.
+
+**Conclusion**
+
+For 2D brain images using deep learning-based reconstruction, maximum acceleration for potential screening was 3–4 times higher than that for diagnostic general-purpose imaging.
+
+```BibTeX
+@article{radmanesh2022exploring,
+  title={Exploring the Acceleration Limits of Deep Learning {VarNet}-based Two-dimensional Brain {MRI}},
+  author={Radmanesh, Alireza and Muckley, Matthew J and Murrell, Tullie and Lindsey, Emma and Sriram, Anuroop and Knoll, Florian and Sodickson, Daniel K and Lui, Yvonne W},
+  journal={Radiology: Artificial Intelligence},
+  pages={e210313},
+  year={2022},
+  publisher={Radiological Society of North America}
+}
+```
diff --git a/README.md b/README.md
index d525d96a..6b397a3c 100644
--- a/README.md
+++ b/README.md
@@ -190,4 +190,5 @@ corresponding abstracts, as well as links to preprints and code can be found
 10. Defazio, A., Murrell, T., & Recht, M. P. (2020). [MRI Banding Removal via Adversarial Training](https://papers.nips.cc/paper/2020/hash/567b8f5f423af15818a068235807edc0-Abstract.html). In *Advances in Neural Information Processing Systems*, 33, pages 7660-7670.
 11. Muckley, M. J.\*, Riemenschneider, B.\*, Radmanesh, A., Kim, S., Jeong, G., Ko, J., ... & Knoll, F. (2021). [Results of the 2020 fastMRI Challenge for Machine Learning MR Image Reconstruction](https://doi.org/10.1109/TMI.2021.3075856). *IEEE Transactions on Medical Imaging*, 40(9), pages 2306-2317.
 12. Johnson, P. M., Jeong, G., Hammernik, K., Schlemper, J., Qin, C., Duan, J., ..., & Knoll, F. (2021). [Evaluation of the Robustness of Learned MR Image Reconstruction to Systematic Deviations Between Training and Test Data for the Models from the fastMRI Challenge](https://doi.org/10.1007/978-3-030-88552-6_3). In *MICCAI MLMIR Workshop*, pages 25–34,
-13. Bakker, T., Muckley, M.J., Romero-Soriano, A., Drozdzal, M. & Pineda, L. (2022). [On learning adaptive acquisition policies for undersampled multi-coil MRI reconstruction](https://arxiv.org/abs/2203.16392). *In MIDL*. Accepted.
+13. Bakker, T., Muckley, M.J., Romero-Soriano, A., Drozdzal, M. & Pineda, L. (2022). [On learning adaptive acquisition policies for undersampled multi-coil MRI reconstruction](https://arxiv.org/abs/2203.16392). In *MIDL*.
+14. Radmanesh, A.\*, Muckley, M. J.\*, Murrell, T., Lindsey, E., Sriram, A., Knoll, F., ... & Lui, Y. W. (2022). [Exploring the Acceleration Limits of Deep Learning VarNet-based Two-dimensional Brain MRI](https://doi.org/10.1148/ryai.210313). *Radiology: Artificial Intelligence*, e210313.
diff --git a/fastmri/models/__init__.py b/fastmri/models/__init__.py
index 863b627e..b7b030ec 100644
--- a/fastmri/models/__init__.py
+++ b/fastmri/models/__init__.py
@@ -5,6 +5,7 @@ LICENSE file in the root directory of this source tree.
 """
+from ._zsnet import ZSNet
 from .adaptive_varnet import AdaptiveVarNet
 from .policy import StraightThroughPolicy
 from .unet import Unet
diff --git a/fastmri/models/_zsnet.py b/fastmri/models/_zsnet.py
new file mode 100644
index 00000000..6fdae51b
--- /dev/null
+++ b/fastmri/models/_zsnet.py
@@ -0,0 +1,319 @@
+"""
+Copyright (c) Facebook, Inc. and its affiliates.
+
+This source code is licensed under the MIT license found in the
+LICENSE file in the root directory of this source tree.
+"""
+
+import math
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import fastmri
+from fastmri.data import transforms
+from fastmri.models.varnet import SensitivityModel
+
+from ._zsnet_unet import ZSNetUnet
+
+
+def _create_zero_tensor(tensor_type: torch.Tensor) -> torch.Tensor:
+    return torch.zeros(
+        (1, 1, 1, 1, 1), dtype=tensor_type.dtype, device=tensor_type.device
+    )
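+
+# `torch.where` broadcasts this (1, 1, 1, 1, 1) zero against full k-space
+# tensors in the data-consistency steps below; building it with the input's
+# dtype and device avoids host-device transfers inside the cascade loops.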
""" +from ._zsnet import ZSNet from .adaptive_varnet import AdaptiveVarNet from .policy import StraightThroughPolicy from .unet import Unet diff --git a/fastmri/models/_zsnet.py b/fastmri/models/_zsnet.py new file mode 100644 index 00000000..6fdae51b --- /dev/null +++ b/fastmri/models/_zsnet.py @@ -0,0 +1,319 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import math +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import fastmri +from fastmri.data import transforms +from fastmri.models.varnet import SensitivityModel + +from ._zsnet_unet import ZSNetUnet + + +def _create_zero_tensor(tensor_type: torch.Tensor) -> torch.Tensor: + return torch.zeros( + (1, 1, 1, 1, 1), dtype=tensor_type.dtype, device=tensor_type.device + ) + + +class NormUnet(nn.Module): + """ + Normalized U-Net model. + This is the same as a regular U-Net, but with normalization applied to the + input before the U-Net. This keeps the values more numerically stable + during training. + """ + + def __init__( + self, + chans: int, + num_pools: int, + in_chans: int = 2, + out_chans: int = 2, + drop_prob: float = 0.0, + ): + """ + Args: + chans: Number of output channels of the first convolution layer. + num_pools: Number of down-sampling and up-sampling layers. + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + drop_prob: Dropout probability. + """ + super().__init__() + + self.unet = ZSNetUnet( + in_chans=in_chans, + out_chans=out_chans, + chans=chans, + num_pool_layers=num_pools, + drop_prob=drop_prob, + ) + + def complex_to_chan_dim(self, x: torch.Tensor) -> torch.Tensor: + b, c, h, w, two = x.shape + return x.permute(0, 4, 1, 2, 3).reshape(b, two * c, h, w) + + def chan_complex_to_last_dim(self, x: torch.Tensor) -> torch.Tensor: + b, c2, h, w = x.shape + assert c2 % 2 == 0 + c = c2 // 2 + return x.view(b, 2, c, h, w).permute(0, 2, 3, 4, 1).contiguous() + + def norm(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # group norm + b, c, h, w = x.shape + + x = x.contiguous().view(b, c, c // c * h * w) + mean = ( + x.mean(dim=2) + .view(b, c, 1, 1, 1) + .expand(b, c, c // c, 1, 1) + .contiguous() + .view(b, c, 1, 1) + ) + std = ( + x.std(dim=2) + .view(b, c, 1, 1, 1) + .expand(b, c, c // c, 1, 1) + .contiguous() + .view(b, c, 1, 1) + ) + + x = x.view(b, c, h, w) + + return (x - mean) / std, mean, std + + def unnorm( + self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor + ) -> torch.Tensor: + if not mean.shape[1] == x.shape[1]: + mean = mean[:, :2] + if not std.shape[1] == x.shape[1]: + std = std[:, :2] + return x * std + mean + + def pad( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]: + _, _, h, w = x.shape + w_mult = ((w - 1) | 15) + 1 + h_mult = ((h - 1) | 15) + 1 + w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)] + h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)] + # TODO: fix this type when PyTorch fixes theirs + # the documentation lies - this actually takes a list + # https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L3457 + # https://github.com/pytorch/pytorch/pull/16949 + + x = F.pad(x, w_pad + h_pad) + + return x, (h_pad, w_pad, h_mult, w_mult) + + def unpad( + self, + x: torch.Tensor, + h_pad: 
+    def pad(
+        self, x: torch.Tensor
+    ) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]:
+        _, _, h, w = x.shape
+        w_mult = ((w - 1) | 15) + 1
+        h_mult = ((h - 1) | 15) + 1
+        w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)]
+        h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)]
+        # TODO: fix this type when PyTorch fixes theirs
+        # the documentation lies - this actually takes a list
+        # https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L3457
+        # https://github.com/pytorch/pytorch/pull/16949
+        x = F.pad(x, w_pad + h_pad)
+
+        return x, (h_pad, w_pad, h_mult, w_mult)
+
+    def unpad(
+        self,
+        x: torch.Tensor,
+        h_pad: List[int],
+        w_pad: List[int],
+        h_mult: int,
+        w_mult: int,
+    ) -> torch.Tensor:
+        return x[..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]]
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # get shapes for unet and normalize
+        x = self.complex_to_chan_dim(x)
+        x, pad_sizes = self.pad(x)
+        x, mean, std = self.norm(x)
+
+        x = self.unet(x.contiguous())
+
+        # get shapes back and unnormalize
+        x = self.unnorm(x, mean, std)
+        x = self.unpad(x, *pad_sizes)
+        x = self.chan_complex_to_last_dim(x)
+
+        return x
+
+
+class ZSNetSensitivityModel(SensitivityModel):
+    def __init__(
+        self,
+        chans: int,
+        num_pools: int,
+        in_chans: int = 2,
+        out_chans: int = 2,
+        drop_prob: float = 0.0,
+        mask_center: bool = True,
+    ):
+        super().__init__(chans, num_pools, in_chans, out_chans, drop_prob, mask_center)
+        # replace the parent's NormUnet with the ZSNet variant defined above
+        self.norm_unet = NormUnet(
+            chans,
+            num_pools,
+            in_chans=in_chans,
+            out_chans=out_chans,
+            drop_prob=drop_prob,
+        )
+
+
+class ZSNet(nn.Module):
+    def __init__(
+        self,
+        image_crop_size: int,
+        num_concat_cascades: int = 18,
+        sens_chans: int = 10,
+        sens_pools: int = 5,
+        chans: int = 20,
+        pools: int = 5,
+        mask_center: bool = False,
+    ):
+        """
+        Args:
+            image_crop_size: Height to center-crop images to inside each
+                cascade before applying the regularizer.
+            num_concat_cascades: Number of cascades (i.e., layers) for the
+                variational network, including the initial non-concatenating
+                cascade.
+            sens_chans: Number of channels for sensitivity map U-Net.
+            sens_pools: Number of downsampling and upsampling layers for
+                sensitivity map U-Net.
+            chans: Number of channels for cascade U-Net.
+            pools: Number of downsampling and upsampling layers for cascade
+                U-Net.
+            mask_center: Whether to mask center of k-space for sensitivity map
+                calculation.
+        """
+        super().__init__()
+
+        self.sens_net = ZSNetSensitivityModel(
+            sens_chans, sens_pools, mask_center=mask_center
+        )
+        self.init_layer = ZSNetBlock(NormUnet(chans, pools), crop_size=image_crop_size)
+        self.cascades = nn.ModuleList()
+        # each concat block sees the current image plus the initial image
+        chan_mult = 2
+        for _ in range(1, num_concat_cascades):
+            self.cascades.append(
+                ZSNetConcatBlock(
+                    NormUnet(chans=chans, num_pools=pools, in_chans=2 * chan_mult),
+                    crop_size=image_crop_size,
+                )
+            )
+
+    def forward(
+        self,
+        masked_kspace: torch.Tensor,
+        mask: torch.Tensor,
+        num_low_frequencies: Optional[int] = None,
+        only_center: bool = False,
+    ) -> torch.Tensor:
+        sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies)
+        kspace_pred, image = self.init_layer(
+            masked_kspace.clone(), masked_kspace, mask, sens_maps
+        )
+        previous_images = [image]
+        for cascade in self.cascades:
+            kspace_pred, image = cascade(
+                kspace_pred,
+                masked_kspace,
+                mask,
+                sens_maps,
+                previous_images,
+            )
+
+        return fastmri.rss(fastmri.complex_abs(fastmri.ifft2c(kspace_pred)), dim=1)
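+
+
+# Cascade blocks below pair a soft data-consistency step (the masked residual
+# against the measured k-space, scaled by a learned dc_weight) with a learned
+# image-space regularizer; ZSNetConcatBlock additionally conditions the
+# regularizer on earlier cascade images passed in via `previous_images`.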
+class ZSNetBaseBlock(nn.Module):
+    def __init__(self, model: nn.Module, crop_size: int):
+        """
+        Args:
+            model: Module for the "regularization" component of the
+                variational network.
+            crop_size: Height to center-crop images to before applying
+                `model`.
+        """
+        super().__init__()
+
+        self.model = model
+        self.crop_size = crop_size
+        self.dc_weight = nn.Parameter(torch.ones(1))
+
+    def sens_expand(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
+        return fastmri.fft2c(fastmri.complex_mul(x, sens_maps))
+
+    def sens_reduce(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
+        return fastmri.complex_mul(
+            fastmri.ifft2c(x), fastmri.complex_conj(sens_maps)
+        ).sum(dim=1, keepdim=True)
+
+    def image_crop(self, image: torch.Tensor) -> torch.Tensor:
+        input_shape = image.shape
+        crop_size = (min(self.crop_size, input_shape[-3]), input_shape[-2])
+
+        return transforms.complex_center_crop(image, crop_size)
+
+    def image_uncrop(
+        self, image: torch.Tensor, original_image: torch.Tensor
+    ) -> torch.Tensor:
+        """Insert values back into original image."""
+        in_shape = original_image.shape
+        crop_height = image.shape[-3]
+        in_height = in_shape[-3]
+        pad_height = (in_height - crop_height) // 2
+        if (in_height - crop_height) % 2 != 0:
+            pad_height_top = pad_height + 1
+        else:
+            pad_height_top = pad_height
+
+        # an explicit end index keeps the slice valid when pad_height == 0
+        original_image[..., pad_height_top : in_height - pad_height, :, :] = image[...]  # type: ignore
+
+        return original_image
+
+    def apply_model_with_crop(self, image: torch.Tensor) -> torch.Tensor:
+        if self.crop_size is not None:
+            image = self.image_uncrop(
+                self.model(self.image_crop(image)), image[..., :2].clone()
+            )
+        else:
+            image = self.model(image)
+
+        return image
+
+
+class ZSNetBlock(ZSNetBaseBlock):
+    def forward(
+        self,
+        current_kspace: torch.Tensor,
+        ref_kspace: torch.Tensor,
+        mask: torch.Tensor,
+        sens_maps: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        zero = _create_zero_tensor(current_kspace)
+        soft_dc = torch.where(mask, current_kspace - ref_kspace, zero) * self.dc_weight
+        image = self.sens_reduce(current_kspace, sens_maps)
+        model_term = self.sens_expand(self.apply_model_with_crop(image), sens_maps)
+
+        return current_kspace - soft_dc - model_term, image
+
+
+class ZSNetConcatBlock(ZSNetBaseBlock):
+    def forward(
+        self,
+        current_kspace: torch.Tensor,
+        ref_kspace: torch.Tensor,
+        mask: torch.Tensor,
+        sens_maps: torch.Tensor,
+        previous_images: List[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        zero = _create_zero_tensor(current_kspace)
+        soft_dc = torch.where(mask, current_kspace - ref_kspace, zero) * self.dc_weight
+        image = self.sens_reduce(current_kspace, sens_maps)
+
+        model_term = self.sens_expand(
+            self.apply_model_with_crop(torch.cat([image] + previous_images, dim=-1)),
+            sens_maps,
+        )
+        return current_kspace - soft_dc - model_term, image
diff --git a/fastmri/models/_zsnet_unet.py b/fastmri/models/_zsnet_unet.py
new file mode 100644
index 00000000..2b9bcede
--- /dev/null
+++ b/fastmri/models/_zsnet_unet.py
@@ -0,0 +1,185 @@
+"""
+Copyright (c) Facebook, Inc. and its affiliates.
+
+This source code is licensed under the MIT license found in the
+LICENSE file in the root directory of this source tree.
+"""
+
+import torch
+from torch import nn
+from torch.nn import functional as F
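+
+
+# ZSNetUnet follows fastmri.models.Unet closely; the main difference is the
+# residual 1x1-convolution shortcut inside ConvBlock (defined below).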
+class ZSNetUnet(nn.Module):
+    """
+    PyTorch implementation of a U-Net model.
+
+    O. Ronneberger, P. Fischer, and T. Brox. U-Net: Convolutional networks
+    for biomedical image segmentation. In International Conference on Medical
+    Image Computing and Computer-Assisted Intervention, pages 234–241.
+    Springer, 2015.
+    """
+
+    def __init__(
+        self,
+        in_chans: int,
+        out_chans: int,
+        chans: int = 32,
+        num_pool_layers: int = 4,
+        drop_prob: float = 0.0,
+    ):
+        """
+        Args:
+            in_chans: Number of channels in the input to the U-Net model.
+            out_chans: Number of channels in the output of the U-Net model.
+            chans: Number of output channels of the first convolution layer.
+            num_pool_layers: Number of down-sampling and up-sampling layers.
+            drop_prob: Dropout probability.
+        """
+        super().__init__()
+
+        self.in_chans = in_chans
+        self.out_chans = out_chans
+        self.chans = chans
+        self.num_pool_layers = num_pool_layers
+        self.drop_prob = drop_prob
+
+        self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)])
+        ch = chans
+        for _ in range(num_pool_layers - 1):
+            self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob))
+            ch *= 2
+        self.conv = ConvBlock(ch, ch * 2, drop_prob)
+
+        self.up_conv = nn.ModuleList()
+        self.up_transpose_conv = nn.ModuleList()
+        for _ in range(num_pool_layers - 1):
+            self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch))
+            self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob))
+            ch //= 2
+
+        self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch))
+        self.up_conv.append(
+            nn.Sequential(
+                ConvBlock(ch * 2, ch, drop_prob),
+                nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1),
+            )
+        )
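+
+    # Channel schedule built above: `chans` doubles at each down-sampling
+    # stage and once more in the bottleneck, then halves on the way up; a
+    # final 1x1 convolution maps back to `out_chans`.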
+    def forward(self, image: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            image: Input 4D tensor of shape `(N, in_chans, H, W)`.
+
+        Returns:
+            Output tensor of shape `(N, out_chans, H, W)`.
+        """
+        stack = []
+        output = image
+
+        # apply down-sampling layers
+        for layer in self.down_sample_layers:
+            output = layer(output)
+            stack.append(output)
+            output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0)
+
+        output = self.conv(output)
+
+        # apply up-sampling layers
+        for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv):
+            downsample_layer = stack.pop()
+            output = transpose_conv(output)
+
+            # reflect pad on the right/bottom if needed to handle odd input dimensions
+            padding = [0, 0, 0, 0]
+            if output.shape[-1] != downsample_layer.shape[-1]:
+                padding[1] = 1  # padding right
+            if output.shape[-2] != downsample_layer.shape[-2]:
+                padding[3] = 1  # padding bottom
+            if torch.sum(torch.tensor(padding)) != 0:
+                output = F.pad(output, padding, "reflect")
+
+            output = torch.cat([output, downsample_layer], dim=1)
+            output = conv(output)
+
+        return output
+
+
+class ConvBlock(nn.Module):
+    """
+    A Convolutional Block consisting of two convolution layers, each followed
+    by instance normalization, LeakyReLU activation, and dropout, plus a
+    residual shortcut implemented as a 1x1 convolution.
+    """
+
+    def __init__(self, in_chans: int, out_chans: int, drop_prob: float):
+        """
+        Args:
+            in_chans: Number of channels in the input.
+            out_chans: Number of channels in the output.
+            drop_prob: Dropout probability.
+        """
+        super().__init__()
+
+        self.in_chans = in_chans
+        self.out_chans = out_chans
+        self.drop_prob = drop_prob
+
+        self.shortcut = nn.Conv2d(
+            in_chans, out_chans, 1, stride=1, padding=0, bias=False
+        )
+        self.layers = nn.Sequential(
+            nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False),
+            nn.InstanceNorm2d(out_chans),
+            nn.LeakyReLU(negative_slope=0.2, inplace=True),
+            nn.Dropout2d(drop_prob),
+            nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
+            nn.InstanceNorm2d(out_chans),
+            nn.LeakyReLU(negative_slope=0.2, inplace=True),
+            nn.Dropout2d(drop_prob),
+        )
+
+    def forward(self, image: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            image: Input 4D tensor of shape `(N, in_chans, H, W)`.
+
+        Returns:
+            Output tensor of shape `(N, out_chans, H, W)`.
+        """
+        return self.shortcut(image) + self.layers(image)
+
+
+class TransposeConvBlock(nn.Module):
+    """
+    A Transpose Convolutional Block consisting of one transpose convolution
+    layer followed by instance normalization and LeakyReLU activation.
+    """
+
+    def __init__(self, in_chans: int, out_chans: int):
+        """
+        Args:
+            in_chans: Number of channels in the input.
+            out_chans: Number of channels in the output.
+        """
+        super().__init__()
+
+        self.in_chans = in_chans
+        self.out_chans = out_chans
+
+        self.layers = nn.Sequential(
+            nn.ConvTranspose2d(
+                in_chans, out_chans, kernel_size=2, stride=2, bias=False
+            ),
+            nn.InstanceNorm2d(out_chans),
+            nn.LeakyReLU(negative_slope=0.2, inplace=True),
+        )
+
+    def forward(self, image: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            image: Input 4D tensor of shape `(N, in_chans, H, W)`.
+
+        Returns:
+            Output tensor of shape `(N, out_chans, H*2, W*2)`.
+        """
+        return self.layers(image)
diff --git a/fastmri/models/varnet.py b/fastmri/models/varnet.py
index 92e0b045..7c1e8a8f 100644
--- a/fastmri/models/varnet.py
+++ b/fastmri/models/varnet.py
@@ -156,7 +156,7 @@ def __init__(
         """
         super().__init__()
        self.mask_center = mask_center
-        self.norm_unet = NormUnet(
+        self.norm_unet: nn.Module = NormUnet(
            chans,
            num_pools,
            in_chans=in_chans,
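
The `nn.Module` annotation on `self.norm_unet` in `varnet.py` appears to exist so that `ZSNetSensitivityModel` can replace the attribute with its own `NormUnet` variant without upsetting type checkers. A quick smoke test for the new model follows; this is a sketch inferred from the signatures in this diff, and the shapes and hyperparameters are illustrative rather than taken from the paper or its training setup.

```python
import torch

from fastmri.models import ZSNet

# deliberately small configuration for a CPU smoke test; defaults are larger
model = ZSNet(image_crop_size=320, num_concat_cascades=4, sens_chans=4, chans=8)

masked_kspace = torch.randn(1, 4, 640, 368, 2)  # (batch, coils, H, W, re/im)
mask = torch.zeros(1, 1, 1, 368, 1, dtype=torch.bool)
mask[..., ::4, :] = True  # keep every fourth phase-encode column

with torch.no_grad():
    output = model(masked_kspace, mask)

print(output.shape)  # torch.Size([1, 640, 368])
```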