From 6f5cea3795278c20568261ebb38907d94b576b98 Mon Sep 17 00:00:00 2001
From: Mark Graham
Date: Thu, 3 Aug 2023 09:07:40 -0500
Subject: [PATCH] 6676 port losses from monai-generative (#6729)

Work towards addressing issue #6676

### Description
This PR ports spectral, perceptual and patch adversarial losses from [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels).

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Mark Graham
---
 docs/source/installation.md | 7 +-
 docs/source/losses.rst | 15 ++
 monai/losses/__init__.py | 3 +
 monai/losses/adversarial_loss.py | 173 ++++++++++++++
 monai/losses/perceptual.py | 394 +++++++++++++++++++++++++++++++
 monai/losses/spectral_loss.py | 88 +++++++
 requirements-dev.txt | 1 +
 setup.cfg | 3 +
 tests/min_tests.py | 1 +
 tests/test_adversarial_loss.py | 93 ++++++++
 tests/test_perceptual_loss.py | 87 +++++++
 tests/test_spectral_loss.py | 86 +++++++
 12 files changed, 948 insertions(+), 3 deletions(-)
 create mode 100644 monai/losses/adversarial_loss.py
 create mode 100644 monai/losses/perceptual.py
 create mode 100644 monai/losses/spectral_loss.py
 create mode 100644 tests/test_adversarial_loss.py
 create mode 100644 tests/test_perceptual_loss.py
 create mode 100644 tests/test_spectral_loss.py

diff --git a/docs/source/installation.md b/docs/source/installation.md
index bc79040546..6d63fbf08f 100644
--- a/docs/source/installation.md
+++ b/docs/source/installation.md
@@ -254,10 +254,11 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is
 - The options are
 ```
-[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr]
+[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips]
 ```
-which correspond to `nibabel`, `scikit-image`, `scipy`, `pillow`, `tensorboard`,
-`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, and `zarr` respectively.
+which correspond to `nibabel`, `scikit-image`,`scipy`, `pillow`, `tensorboard`, +`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr` and `lpips` respectively. + - `pip install 'monai[all]'` installs all the optional dependencies. diff --git a/docs/source/losses.rst b/docs/source/losses.rst index 0f262894cf..39f1d0e4d1 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -99,6 +99,21 @@ Reconstruction Losses .. autoclass:: monai.losses.ssim_loss.SSIMLoss :members: +`PatchAdversarialLoss` +~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: PatchAdversarialLoss + :members: + +`PerceptualLoss` +~~~~~~~~~~~~~~~~~ +.. autoclass:: PerceptualLoss + :members: + +`JukeboxLoss` +~~~~~~~~~~~~~~ +.. autoclass:: JukeboxLoss + :members: + Loss Wrappers ------------- diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index db6b133ef0..75f4d181d0 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -11,6 +11,7 @@ from __future__ import annotations +from .adversarial_loss import PatchAdversarialLoss from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss from .contrastive import ContrastiveLoss from .deform import BendingEnergyLoss @@ -34,7 +35,9 @@ from .giou_loss import BoxGIoULoss, giou from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss from .multi_scale import MultiScaleLoss +from .perceptual import PerceptualLoss from .spatial_mask import MaskedLoss +from .spectral_loss import JukeboxLoss from .ssim_loss import SSIMLoss from .tversky import TverskyLoss from .unified_focal_loss import AsymmetricUnifiedFocalLoss diff --git a/monai/losses/adversarial_loss.py b/monai/losses/adversarial_loss.py new file mode 100644 index 0000000000..f16fdee564 --- /dev/null +++ b/monai/losses/adversarial_loss.py @@ -0,0 +1,173 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import warnings + +import torch +from torch.nn.modules.loss import _Loss + +from monai.networks.layers.utils import get_act_layer +from monai.utils import LossReduction +from monai.utils.enums import StrEnum + + +class AdversarialCriterions(StrEnum): + BCE = "bce" + HINGE = "hinge" + LEAST_SQUARE = "least_squares" + + +class PatchAdversarialLoss(_Loss): + """ + Calculates an adversarial loss on a Patch Discriminator or a Multi-scale Patch Discriminator. + Warning: due to the possibility of using different criterions, the output of the discrimination + mustn't be passed to a final activation layer. That is taken care of internally within the loss. + + Args: + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + + - ``"none"``: no reduction will be applied. 
+ - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. + + criterion: which criterion (hinge, least_squares or bce) you want to use on the discriminators outputs. + Depending on the criterion, a different activation layer will be used. Make sure you don't run the outputs + through an activation layer prior to calling the loss. + no_activation_leastsq: if True, the activation layer in the case of least-squares is removed. + """ + + def __init__( + self, + reduction: LossReduction | str = LossReduction.MEAN, + criterion: str = AdversarialCriterions.LEAST_SQUARE, + no_activation_leastsq: bool = False, + ) -> None: + super().__init__(reduction=LossReduction(reduction)) + + if criterion.lower() not in list(AdversarialCriterions): + raise ValueError( + "Unrecognised criterion entered for Adversarial Loss. Must be one in: %s" + % ", ".join(AdversarialCriterions) + ) + + # Depending on the criterion, a different activation layer is used. + self.real_label = 1.0 + self.fake_label = 0.0 + self.loss_fct: _Loss + if criterion == AdversarialCriterions.BCE: + self.activation = get_act_layer("SIGMOID") + self.loss_fct = torch.nn.BCELoss(reduction=reduction) + elif criterion == AdversarialCriterions.HINGE: + self.activation = get_act_layer("TANH") + self.fake_label = -1.0 + elif criterion == AdversarialCriterions.LEAST_SQUARE: + if no_activation_leastsq: + self.activation = None + else: + self.activation = get_act_layer(name=("LEAKYRELU", {"negative_slope": 0.05})) + self.loss_fct = torch.nn.MSELoss(reduction=reduction) + + self.criterion = criterion + self.reduction = reduction + + def get_target_tensor(self, input: torch.Tensor, target_is_real: bool) -> torch.Tensor: + """ + Gets the ground truth tensor for the discriminator depending on whether the input is real or fake. + + Args: + input: input tensor from the discriminator (output of discriminator, or output of one of the multi-scale + discriminator). This is used to match the shape. + target_is_real: whether the input is real or wannabe-real (1s) or fake (0s). + Returns: + """ + filling_label = self.real_label if target_is_real else self.fake_label + label_tensor = torch.tensor(1).fill_(filling_label).type(input.type()).to(input[0].device) + label_tensor.requires_grad_(False) + return label_tensor.expand_as(input) + + def get_zero_tensor(self, input: torch.Tensor) -> torch.Tensor: + """ + Gets a zero tensor. + + Args: + input: tensor which shape you want the zeros tensor to correspond to. + Returns: + """ + + zero_label_tensor = torch.tensor(0).type(input[0].type()).to(input[0].device) + zero_label_tensor.requires_grad_(False) + return zero_label_tensor.expand_as(input) + + def forward( + self, input: torch.Tensor | list, target_is_real: bool, for_discriminator: bool + ) -> torch.Tensor | list[torch.Tensor]: + """ + + Args: + input: output of Multi-Scale Patch Discriminator or Patch Discriminator; being a list of tensors + or a tensor; they shouldn't have gone through an activation layer. + target_is_real: whereas the input corresponds to discriminator output for real or fake images + for_discriminator: whereas this is being calculated for discriminator or generator loss. In the last + case, target_is_real is set to True, as the generator wants the input to be dimmed as real. + Returns: if reduction is None, returns a list with the loss tensors of each discriminator if multi-scale + discriminator is active, or the loss tensor if there is just one discriminator. 
Otherwise, it returns the + summed or mean loss over the tensor and discriminator/s. + + """ + + if not for_discriminator and not target_is_real: + target_is_real = True # With generator, we always want this to be true! + warnings.warn( + "Variable target_is_real has been set to False, but for_discriminator is set" + "to False. To optimise a generator, target_is_real must be set to True." + ) + + if type(input) is not list: + input = [input] + target_ = [] + for _, disc_out in enumerate(input): + if self.criterion != AdversarialCriterions.HINGE: + target_.append(self.get_target_tensor(disc_out, target_is_real)) + else: + target_.append(self.get_zero_tensor(disc_out)) + + # Loss calculation + loss_list = [] + for disc_ind, disc_out in enumerate(input): + if self.activation is not None: + disc_out = self.activation(disc_out) + if self.criterion == AdversarialCriterions.HINGE and not target_is_real: + loss_ = self._forward_single(-disc_out, target_[disc_ind]) + else: + loss_ = self._forward_single(disc_out, target_[disc_ind]) + loss_list.append(loss_) + + loss: torch.Tensor | list[torch.Tensor] + if loss_list is not None: + if self.reduction == LossReduction.MEAN: + loss = torch.mean(torch.stack(loss_list)) + elif self.reduction == LossReduction.SUM: + loss = torch.sum(torch.stack(loss_list)) + else: + loss = loss_list + return loss + + def _forward_single(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + forward: torch.Tensor + if self.criterion == AdversarialCriterions.BCE or self.criterion == AdversarialCriterions.LEAST_SQUARE: + forward = self.loss_fct(input, target) + elif self.criterion == AdversarialCriterions.HINGE: + minval = torch.min(input - 1, self.get_zero_tensor(input)) + forward = -torch.mean(minval) + return forward diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py new file mode 100644 index 0000000000..2207de5e64 --- /dev/null +++ b/monai/losses/perceptual.py @@ -0,0 +1,394 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import warnings + +import torch +import torch.nn as nn + +from monai.utils import optional_import +from monai.utils.enums import StrEnum + +LPIPS, _ = optional_import("lpips", name="LPIPS") +torchvision, _ = optional_import("torchvision") + + +class PercetualNetworkType(StrEnum): + alex = "alex" + vgg = "vgg" + squeeze = "squeeze" + radimagenet_resnet50 = "radimagenet_resnet50" + medicalnet_resnet10_23datasets = "medicalnet_resnet10_23datasets" + medical_resnet50_23datasets = "medical_resnet50_23datasets" + resnet50 = "resnet50" + + +class PerceptualLoss(nn.Module): + """ + Perceptual loss using features from pretrained deep neural networks trained. The function supports networks + pretrained on: ImageNet that use the LPIPS approach from Zhang, et al. "The unreasonable effectiveness of deep + features as a perceptual metric." https://arxiv.org/abs/1801.03924 ; RadImagenet from Mei, et al. 
"RadImageNet: An + Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning" + https://pubs.rsna.org/doi/full/10.1148/ryai.210315 ; MedicalNet from Chen et al. "Med3D: Transfer Learning for + 3D Medical Image Analysis" https://arxiv.org/abs/1904.00625 ; + and ResNet50 from Torchvision: https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html . + + The fake 3D implementation is based on a 2.5D approach where we calculate the 2D perceptual loss on slices from all + three axes and average. The full 3D approach uses a 3D network to calculate the perceptual loss. + + Args: + spatial_dims: number of spatial dimensions. + network_type: {``"alex"``, ``"vgg"``, ``"squeeze"``, ``"radimagenet_resnet50"``, + ``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``, ``"resnet50"``} + Specifies the network architecture to use. Defaults to ``"alex"``. + is_fake_3d: if True use 2.5D approach for a 3D perceptual loss. + fake_3d_ratio: ratio of how many slices per axis are used in the 2.5D approach. + cache_dir: path to cache directory to save the pretrained network weights. + pretrained: whether to load pretrained weights. This argument only works when using networks from + LIPIS or Torchvision. Defaults to ``"True"``. + pretrained_path: if `pretrained` is `True`, users can specify a weights file to be loaded + via using this argument. This argument only works when ``"network_type"`` is "resnet50". + Defaults to `None`. + pretrained_state_dict_key: if `pretrained_path` is not `None`, this argument is used to + extract the expected state dict. This argument only works when ``"network_type"`` is "resnet50". + Defaults to `None`. + """ + + def __init__( + self, + spatial_dims: int, + network_type: str = PercetualNetworkType.alex, + is_fake_3d: bool = True, + fake_3d_ratio: float = 0.5, + cache_dir: str | None = None, + pretrained: bool = True, + pretrained_path: str | None = None, + pretrained_state_dict_key: str | None = None, + ): + super().__init__() + + if spatial_dims not in [2, 3]: + raise NotImplementedError("Perceptual loss is implemented only in 2D and 3D.") + + if (spatial_dims == 2 or is_fake_3d) and "medicalnet_" in network_type: + raise ValueError( + "MedicalNet networks are only compatible with ``spatial_dims=3``." + "Argument is_fake_3d must be set to False." + ) + + if network_type.lower() not in list(PercetualNetworkType): + raise ValueError( + "Unrecognised criterion entered for Adversarial Loss. Must be one in: %s" + % ", ".join(PercetualNetworkType) + ) + + if cache_dir: + torch.hub.set_dir(cache_dir) + # raise a warning that this may change the default cache dir for all torch.hub calls + warnings.warn( + f"Setting cache_dir to {cache_dir}, this may change the default cache dir for all torch.hub calls." 
+ ) + + self.spatial_dims = spatial_dims + self.perceptual_function: nn.Module + if spatial_dims == 3 and is_fake_3d is False: + self.perceptual_function = MedicalNetPerceptualSimilarity(net=network_type, verbose=False) + elif "radimagenet_" in network_type: + self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False) + elif network_type == "resnet50": + self.perceptual_function = TorchvisionModelPerceptualSimilarity( + net=network_type, + pretrained=pretrained, + pretrained_path=pretrained_path, + pretrained_state_dict_key=pretrained_state_dict_key, + ) + else: + self.perceptual_function = LPIPS(pretrained=pretrained, net=network_type, verbose=False) + self.is_fake_3d = is_fake_3d + self.fake_3d_ratio = fake_3d_ratio + + def _calculate_axis_loss(self, input: torch.Tensor, target: torch.Tensor, spatial_axis: int) -> torch.Tensor: + """ + Calculate perceptual loss in one of the axis used in the 2.5D approach. After the slices of one spatial axis + is transformed into different instances in the batch, we compute the loss using the 2D approach. + + Args: + input: input 5D tensor. BNHWD + target: target 5D tensor. BNHWD + spatial_axis: spatial axis to obtain the 2D slices. + """ + + def batchify_axis(x: torch.Tensor, fake_3d_perm: tuple) -> torch.Tensor: + """ + Transform slices from one spatial axis into different instances in the batch. + """ + slices = x.float().permute((0,) + fake_3d_perm).contiguous() + slices = slices.view(-1, x.shape[fake_3d_perm[1]], x.shape[fake_3d_perm[2]], x.shape[fake_3d_perm[3]]) + + return slices + + preserved_axes = [2, 3, 4] + preserved_axes.remove(spatial_axis) + + channel_axis = 1 + input_slices = batchify_axis(x=input, fake_3d_perm=(spatial_axis, channel_axis) + tuple(preserved_axes)) + indices = torch.randperm(input_slices.shape[0])[: int(input_slices.shape[0] * self.fake_3d_ratio)].to( + input_slices.device + ) + input_slices = torch.index_select(input_slices, dim=0, index=indices) + target_slices = batchify_axis(x=target, fake_3d_perm=(spatial_axis, channel_axis) + tuple(preserved_axes)) + target_slices = torch.index_select(target_slices, dim=0, index=indices) + + axis_loss = torch.mean(self.perceptual_function(input_slices, target_slices)) + + return axis_loss + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Args: + input: the shape should be BNHW[D]. + target: the shape should be BNHW[D]. + """ + if target.shape != input.shape: + raise ValueError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})") + + if self.spatial_dims == 3 and self.is_fake_3d: + # Compute 2.5D approach + loss_sagittal = self._calculate_axis_loss(input, target, spatial_axis=2) + loss_coronal = self._calculate_axis_loss(input, target, spatial_axis=3) + loss_axial = self._calculate_axis_loss(input, target, spatial_axis=4) + loss = loss_sagittal + loss_axial + loss_coronal + else: + # 2D and real 3D cases + loss = self.perceptual_function(input, target) + + return torch.mean(loss) + + +class MedicalNetPerceptualSimilarity(nn.Module): + """ + Component to perform the perceptual evaluation with the networks pretrained by Chen, et al. "Med3D: Transfer + Learning for 3D Medical Image Analysis". This class uses torch Hub to download the networks from + "Warvito/MedicalNet-models". + + Args: + net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``} + Specifies the network architecture to use. Defaults to ``"medicalnet_resnet10_23datasets"``. 
+ verbose: if false, mute messages from torch Hub load function. + """ + + def __init__(self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False) -> None: + super().__init__() + torch.hub._validate_not_a_forked_repo = lambda a, b, c: True + self.model = torch.hub.load("warvito/MedicalNet-models", model=net, verbose=verbose) + self.eval() + + for param in self.parameters(): + param.requires_grad = False + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Compute perceptual loss using MedicalNet 3D networks. The input and target tensors are inputted in the + pre-trained MedicalNet that is used for feature extraction. Then, these extracted features are normalised across + the channels. Finally, we compute the difference between the input and target features and calculate the mean + value from the spatial dimensions to obtain the perceptual loss. + + Args: + input: 3D input tensor with shape BCDHW. + target: 3D target tensor with shape BCDHW. + """ + input = medicalnet_intensity_normalisation(input) + target = medicalnet_intensity_normalisation(target) + + # Get model outputs + outs_input = self.model.forward(input) + outs_target = self.model.forward(target) + + # Normalise through the channels + feats_input = normalize_tensor(outs_input) + feats_target = normalize_tensor(outs_target) + + results: torch.Tensor = (feats_input - feats_target) ** 2 + results = spatial_average_3d(results.sum(dim=1, keepdim=True), keepdim=True) + + return results + + +def spatial_average_3d(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor: + return x.mean([2, 3, 4], keepdim=keepdim) + + +def normalize_tensor(x: torch.Tensor, eps: float = 1e-10) -> torch.Tensor: + norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) + return x / (norm_factor + eps) + + +def medicalnet_intensity_normalisation(volume): + """Based on https://github.com/Tencent/MedicalNet/blob/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b/datasets/brains18.py#L133""" + mean = volume.mean() + std = volume.std() + return (volume - mean) / std + + +class RadImageNetPerceptualSimilarity(nn.Module): + """ + Component to perform the perceptual evaluation with the networks pretrained on RadImagenet (pretrained by Mei, et + al. "RadImageNet: An Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"). This class + uses torch Hub to download the networks from "Warvito/radimagenet-models". + + Args: + net: {``"radimagenet_resnet50"``} + Specifies the network architecture to use. Defaults to ``"radimagenet_resnet50"``. + verbose: if false, mute messages from torch Hub load function. + """ + + def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False) -> None: + super().__init__() + self.model = torch.hub.load("Warvito/radimagenet-models", model=net, verbose=verbose) + self.eval() + + for param in self.parameters(): + param.requires_grad = False + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + We expect that the input is normalised between [0, 1]. Given the preprocessing performed during the training at + https://github.com/BMEII-AI/RadImageNet, we make sure that the input and target have 3 channels, reorder it from + 'RGB' to 'BGR', and then remove the mean components of each input data channel. The outputs are normalised + across the channels, and we obtain the mean from the spatial dimensions (similar approach to the lpips package). 
+ """ + # If input has just 1 channel, repeat channel to have 3 channels + if input.shape[1] == 1 and target.shape[1] == 1: + input = input.repeat(1, 3, 1, 1) + target = target.repeat(1, 3, 1, 1) + + # Change order from 'RGB' to 'BGR' + input = input[:, [2, 1, 0], ...] + target = target[:, [2, 1, 0], ...] + + # Subtract mean used during training + input = subtract_mean(input) + target = subtract_mean(target) + + # Get model outputs + outs_input = self.model.forward(input) + outs_target = self.model.forward(target) + + # Normalise through the channels + feats_input = normalize_tensor(outs_input) + feats_target = normalize_tensor(outs_target) + + results: torch.Tensor = (feats_input - feats_target) ** 2 + results = spatial_average(results.sum(dim=1, keepdim=True), keepdim=True) + + return results + + +class TorchvisionModelPerceptualSimilarity(nn.Module): + """ + Component to perform the perceptual evaluation with TorchVision models. + Currently, only ResNet50 is supported. The network structure is based on: + https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html + + Args: + net: {``"resnet50"``} + Specifies the network architecture to use. Defaults to ``"resnet50"``. + pretrained: whether to load pretrained weights. Defaults to `True`. + pretrained_path: if `pretrained` is `True`, users can specify a weights file to be loaded + via using this argument. Defaults to `None`. + pretrained_state_dict_key: if `pretrained_path` is not `None`, this argument is used to + extract the expected state dict. Defaults to `None`. + """ + + def __init__( + self, + net: str = "resnet50", + pretrained: bool = True, + pretrained_path: str | None = None, + pretrained_state_dict_key: str | None = None, + ) -> None: + super().__init__() + supported_networks = ["resnet50"] + if net not in supported_networks: + raise NotImplementedError( + f"'net' {net} is not supported, please select a network from {supported_networks}." + ) + + if pretrained_path is None: + network = torchvision.models.resnet50( + weights=torchvision.models.ResNet50_Weights.DEFAULT if pretrained else None + ) + else: + network = torchvision.models.resnet50(weights=None) + if pretrained is True: + state_dict = torch.load(pretrained_path) + if pretrained_state_dict_key is not None: + state_dict = state_dict[pretrained_state_dict_key] + network.load_state_dict(state_dict) + self.final_layer = "layer4.2.relu_2" + self.model = torchvision.models.feature_extraction.create_feature_extractor(network, [self.final_layer]) + self.eval() + + for param in self.parameters(): + param.requires_grad = False + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + We expect that the input is normalised between [0, 1]. Given the preprocessing performed during the training at + https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html#torchvision.models.ResNet50_Weights, + we make sure that the input and target have 3 channels, and then do Z-Score normalization. + The outputs are normalised across the channels, and we obtain the mean from the spatial dimensions (similar + approach to the lpips package). 
+ """ + # If input has just 1 channel, repeat channel to have 3 channels + if input.shape[1] == 1 and target.shape[1] == 1: + input = input.repeat(1, 3, 1, 1) + target = target.repeat(1, 3, 1, 1) + + # Input normalization + input = torchvision_zscore_norm(input) + target = torchvision_zscore_norm(target) + + # Get model outputs + outs_input = self.model.forward(input)[self.final_layer] + outs_target = self.model.forward(target)[self.final_layer] + + # Normalise through the channels + feats_input = normalize_tensor(outs_input) + feats_target = normalize_tensor(outs_target) + + results: torch.Tensor = (feats_input - feats_target) ** 2 + results = spatial_average(results.sum(dim=1, keepdim=True), keepdim=True) + + return results + + +def spatial_average(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor: + return x.mean([2, 3], keepdim=keepdim) + + +def torchvision_zscore_norm(x: torch.Tensor) -> torch.Tensor: + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + x[:, 0, :, :] = (x[:, 0, :, :] - mean[0]) / std[0] + x[:, 1, :, :] = (x[:, 1, :, :] - mean[1]) / std[1] + x[:, 2, :, :] = (x[:, 2, :, :] - mean[2]) / std[2] + return x + + +def subtract_mean(x: torch.Tensor) -> torch.Tensor: + mean = [0.406, 0.456, 0.485] + x[:, 0, :, :] -= mean[0] + x[:, 1, :, :] -= mean[1] + x[:, 2, :, :] -= mean[2] + return x diff --git a/monai/losses/spectral_loss.py b/monai/losses/spectral_loss.py new file mode 100644 index 0000000000..06714f3993 --- /dev/null +++ b/monai/losses/spectral_loss.py @@ -0,0 +1,88 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +import torch.nn.functional as F +from torch.fft import fftn +from torch.nn.modules.loss import _Loss + +from monai.utils import LossReduction + + +class JukeboxLoss(_Loss): + """ + Calculate spectral component based on the magnitude of Fast Fourier Transform (FFT). + + Based on: + Dhariwal, et al. 'Jukebox: A generative model for music.' https://arxiv.org/abs/2005.00341 + + Args: + spatial_dims: number of spatial dimensions. + fft_signal_size: signal size in the transformed dimensions. See torch.fft.fftn() for more information. + fft_norm: {``"forward"``, ``"backward"``, ``"ortho"``} Specifies the normalization mode in the fft. See + torch.fft.fftn() for more information. + + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + + - ``"none"``: no reduction will be applied. + - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. 
+ """ + + def __init__( + self, + spatial_dims: int, + fft_signal_size: tuple[int] | None = None, + fft_norm: str = "ortho", + reduction: LossReduction | str = LossReduction.MEAN, + ) -> None: + super().__init__(reduction=LossReduction(reduction).value) + + self.spatial_dims = spatial_dims + self.fft_signal_size = fft_signal_size + self.fft_dim = tuple(range(1, spatial_dims + 2)) + self.fft_norm = fft_norm + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + input_amplitude = self._get_fft_amplitude(target) + target_amplitude = self._get_fft_amplitude(input) + + # Compute distance between amplitude of frequency components + # See Section 3.3 from https://arxiv.org/abs/2005.00341 + loss = F.mse_loss(target_amplitude, input_amplitude, reduction="none") + + if self.reduction == LossReduction.MEAN.value: + loss = loss.mean() + elif self.reduction == LossReduction.SUM.value: + loss = loss.sum() + elif self.reduction == LossReduction.NONE.value: + pass + + return loss + + def _get_fft_amplitude(self, images: torch.Tensor) -> torch.Tensor: + """ + Calculate the amplitude of the fourier transformations representation of the images + + Args: + images: Images that are to undergo fftn + + Returns: + fourier transformation amplitude + """ + img_fft = fftn(images, s=self.fft_signal_size, dim=self.fft_dim, norm=self.fft_norm) + + amplitude = torch.sqrt(torch.real(img_fft) ** 2 + torch.imag(img_fft) ** 2) + + return amplitude diff --git a/requirements-dev.txt b/requirements-dev.txt index 0ad08e56d2..d419cd2467 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -53,3 +53,4 @@ onnxruntime; python_version <= '3.10' typeguard<3 # https://github.com/microsoft/nni/issues/5457 filelock!=3.12.0 # https://github.com/microsoft/nni/issues/5523 zarr +lpips==0.1.4 diff --git a/setup.cfg b/setup.cfg index a61a42395f..65d5bce2c7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -81,6 +81,7 @@ all = onnx>=1.13.0 onnxruntime; python_version <= '3.10' zarr + lpips==0.1.4 nibabel = nibabel ninja = @@ -148,6 +149,8 @@ onnx = onnxruntime; python_version <= '3.10' zarr = zarr +lpips = + lpips==0.1.4 # # workaround https://github.com/Project-MONAI/MONAI/issues/5882 # MetricsReloaded = # MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded diff --git a/tests/min_tests.py b/tests/min_tests.py index 9a7d920a2e..4c4e374311 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -204,6 +204,7 @@ def run_testsuit(): "test_spatial_combine_transforms", "test_bundle_workflow", "test_zarr_avg_merger", + "test_perceptual_loss", "test_ultrasound_confidence_map_transform", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_adversarial_loss.py b/tests/test_adversarial_loss.py new file mode 100644 index 0000000000..77880725ec --- /dev/null +++ b/tests/test_adversarial_loss.py @@ -0,0 +1,93 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.losses import PatchAdversarialLoss + +shapes_tensors = {"2d": [4, 1, 64, 64], "3d": [4, 1, 64, 64, 64]} +reductions = ["sum", "mean"] +criterion = ["bce", "least_squares", "hinge"] + +TEST_CASE_CREATION_FAIL = [{"reduction": "sum", "criterion": "invalid"}] + +TEST_CASES_LOSS_LOGIC_2D = [] +TEST_CASES_LOSS_LOGIC_3D = [] + +for c in criterion: + for r in reductions: + TEST_CASES_LOSS_LOGIC_2D.append([{"reduction": r, "criterion": c}, shapes_tensors["2d"]]) + TEST_CASES_LOSS_LOGIC_3D.append([{"reduction": r, "criterion": c}, shapes_tensors["3d"]]) + +TEST_CASES_LOSS_LOGIC_LIST = [] +for c in criterion: + TEST_CASES_LOSS_LOGIC_LIST.append([{"reduction": "none", "criterion": c}, shapes_tensors["2d"]]) + TEST_CASES_LOSS_LOGIC_LIST.append([{"reduction": "none", "criterion": c}, shapes_tensors["3d"]]) + + +class TestPatchAdversarialLoss(unittest.TestCase): + def get_input(self, shape, is_positive): + """ + Get tensor for the tests. The tensor is around (-1) or (+1), depending on + is_positive. + """ + if is_positive: + offset = 1 + else: + offset = -1 + return torch.ones(shape) * (offset) + 0.01 * torch.randn(shape) + + def test_criterion(self): + """ + Make sure that unknown criterion fail. + """ + with self.assertRaises(ValueError): + PatchAdversarialLoss(**TEST_CASE_CREATION_FAIL[0]) + + @parameterized.expand(TEST_CASES_LOSS_LOGIC_2D + TEST_CASES_LOSS_LOGIC_3D) + def test_loss_logic(self, input_param: dict, shape_input: list): + """ + We want to make sure that the adversarial losses do what they should. + If the discriminator takes in a tensor that looks positive, yet the label is fake, + the loss should be bigger than that obtained with a tensor that looks negative. + Same for the real label, and for the generator. 
+ """ + loss = PatchAdversarialLoss(**input_param) + fakes = self.get_input(shape_input, is_positive=False) + reals = self.get_input(shape_input, is_positive=True) + # Discriminator: fake label + loss_disc_f_f = loss(fakes, target_is_real=False, for_discriminator=True) + loss_disc_f_r = loss(reals, target_is_real=False, for_discriminator=True) + assert loss_disc_f_f < loss_disc_f_r + # Discriminator: real label + loss_disc_r_f = loss(fakes, target_is_real=True, for_discriminator=True) + loss_disc_r_r = loss(reals, target_is_real=True, for_discriminator=True) + assert loss_disc_r_f > loss_disc_r_r + # Generator: + loss_gen_f = loss(fakes, target_is_real=True, for_discriminator=False) # target_is_real is overridden + loss_gen_r = loss(reals, target_is_real=True, for_discriminator=False) # target_is_real is overridden + assert loss_gen_f > loss_gen_r + + @parameterized.expand(TEST_CASES_LOSS_LOGIC_LIST) + def test_multiple_discs(self, input_param: dict, shape_input): + shapes = [shape_input] + [shape_input[0:2] + [int(i / j) for i in shape_input[2:]] for j in range(1, 3)] + inputs = [self.get_input(shapes[i], is_positive=True) for i in range(len(shapes))] + loss = PatchAdversarialLoss(**input_param) + assert len(loss(inputs, for_discriminator=True, target_is_real=True)) == 3 + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_perceptual_loss.py b/tests/test_perceptual_loss.py new file mode 100644 index 0000000000..2f807d8222 --- /dev/null +++ b/tests/test_perceptual_loss.py @@ -0,0 +1,87 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.losses import PerceptualLoss +from monai.utils import optional_import +from tests.utils import SkipIfBeforePyTorchVersion + +_, has_torchvision = optional_import("torchvision") +TEST_CASES = [ + [{"spatial_dims": 2, "network_type": "squeeze"}, (2, 1, 64, 64), (2, 1, 64, 64)], + [ + {"spatial_dims": 3, "network_type": "squeeze", "is_fake_3d": True, "fake_3d_ratio": 0.1}, + (2, 1, 64, 64, 64), + (2, 1, 64, 64, 64), + ], + [{"spatial_dims": 2, "network_type": "radimagenet_resnet50"}, (2, 1, 64, 64), (2, 1, 64, 64)], + [{"spatial_dims": 2, "network_type": "radimagenet_resnet50"}, (2, 3, 64, 64), (2, 3, 64, 64)], + [ + {"spatial_dims": 3, "network_type": "radimagenet_resnet50", "is_fake_3d": True, "fake_3d_ratio": 0.1}, + (2, 1, 64, 64, 64), + (2, 1, 64, 64, 64), + ], + [ + {"spatial_dims": 3, "network_type": "medicalnet_resnet10_23datasets", "is_fake_3d": False}, + (2, 1, 64, 64, 64), + (2, 1, 64, 64, 64), + ], + [ + {"spatial_dims": 3, "network_type": "resnet50", "is_fake_3d": True, "pretrained": True, "fake_3d_ratio": 0.2}, + (2, 1, 64, 64, 64), + (2, 1, 64, 64, 64), + ], +] + + +@SkipIfBeforePyTorchVersion((1, 11)) +@unittest.skipUnless(has_torchvision, "Requires torchvision") +class TestPerceptualLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_shape, target_shape): + loss = PerceptualLoss(**input_param) + result = loss(torch.randn(input_shape), torch.randn(target_shape)) + self.assertEqual(result.shape, torch.Size([])) + + @parameterized.expand(TEST_CASES) + def test_identical_input(self, input_param, input_shape, target_shape): + loss = PerceptualLoss(**input_param) + tensor = torch.randn(input_shape) + result = loss(tensor, tensor) + self.assertEqual(result, torch.Tensor([0.0])) + + def test_different_shape(self): + loss = PerceptualLoss(spatial_dims=2, network_type="squeeze") + tensor = torch.randn(2, 1, 64, 64) + target = torch.randn(2, 1, 32, 32) + with self.assertRaises(ValueError): + loss(tensor, target) + + def test_1d(self): + with self.assertRaises(NotImplementedError): + PerceptualLoss(spatial_dims=1) + + def test_medicalnet_on_2d_data(self): + with self.assertRaises(ValueError): + PerceptualLoss(spatial_dims=2, network_type="medicalnet_resnet10_23datasets") + + with self.assertRaises(ValueError): + PerceptualLoss(spatial_dims=2, network_type="medicalnet_resnet50_23datasets") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_spectral_loss.py b/tests/test_spectral_loss.py new file mode 100644 index 0000000000..21b5c48de4 --- /dev/null +++ b/tests/test_spectral_loss.py @@ -0,0 +1,86 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.losses import JukeboxLoss +from tests.utils import test_script_save + +TEST_CASES = [ + [ + {"spatial_dims": 2}, + { + "input": torch.tensor([[[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]), + "target": torch.tensor([[[[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]), + }, + 0.070648, + ], + [ + {"spatial_dims": 2, "reduction": "sum"}, + { + "input": torch.tensor([[[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]), + "target": torch.tensor([[[[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]), + }, + 0.8478, + ], + [ + {"spatial_dims": 3}, + { + "input": torch.tensor( + [ + [ + [[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]], + [[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]], + ] + ] + ), + "target": torch.tensor( + [ + [ + [[[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]], + [[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]], + ] + ] + ), + }, + 0.03838, + ], +] + + +class TestJukeboxLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_results(self, input_param, input_data, expected_val): + results = JukeboxLoss(**input_param).forward(**input_data) + np.testing.assert_allclose(results.detach().cpu().numpy(), expected_val, rtol=1e-4) + + def test_2d_shape(self): + results = JukeboxLoss(spatial_dims=2, reduction="none").forward(**TEST_CASES[0][1]) + self.assertEqual(results.shape, (1, 2, 2, 3)) + + def test_3d_shape(self): + results = JukeboxLoss(spatial_dims=3, reduction="none").forward(**TEST_CASES[2][1]) + self.assertEqual(results.shape, (1, 2, 2, 2, 3)) + + def test_script(self): + loss = JukeboxLoss(spatial_dims=2) + test_input = torch.ones(2, 1, 8, 8) + test_script_save(loss, test_input, test_input) + + +if __name__ == "__main__": + unittest.main()
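
Illustrative usage (not part of the diff above): the sketch below shows one way the three ported losses could be combined in a GAN-style training step, assuming the classes behave as documented in this patch. `generator`, `discriminator` and `images` are placeholders for user-supplied networks (e.g. an autoencoder and a patch discriminator) and 2D image batches, the unweighted sum of the generator losses is arbitrary, and `PerceptualLoss` with an LPIPS backbone requires the optional `lpips` dependency added by this PR.

```python
import torch
from monai.losses import JukeboxLoss, PatchAdversarialLoss, PerceptualLoss

adv_loss = PatchAdversarialLoss(criterion="least_squares")           # expects raw discriminator logits
perc_loss = PerceptualLoss(spatial_dims=2, network_type="squeeze")   # LPIPS backbone, needs `lpips`
spec_loss = JukeboxLoss(spatial_dims=2)                              # spectral (FFT-magnitude) loss

def generator_step(generator, discriminator, images):
    # `generator` and `discriminator` are hypothetical user-supplied torch.nn.Module instances.
    recon = generator(images)
    logits = discriminator(recon)  # no final activation: the adversarial loss applies one internally
    g_adv = adv_loss(logits, target_is_real=True, for_discriminator=False)
    return g_adv + perc_loss(recon, images) + spec_loss(recon, images)

def discriminator_step(generator, discriminator, images):
    with torch.no_grad():
        recon = generator(images)
    d_fake = adv_loss(discriminator(recon), target_is_real=False, for_discriminator=True)
    d_real = adv_loss(discriminator(images), target_is_real=True, for_discriminator=True)
    return 0.5 * (d_fake + d_real)
```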