Skip to content

Commit

Permalink
6676 port losses from monai-generative (#6729)
Browse files Browse the repository at this point in the history
Work towards addressing issue #6676 

### Description

This PR ports spectral, perceptual and patch adversial losses from
[MONAI Generative](https://github.com/Project-MONAI/GenerativeModels).

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Mark Graham <[email protected]>
  • Loading branch information
marksgraham authored Aug 3, 2023
1 parent d6bafc9 commit 6f5cea3
Show file tree
Hide file tree
Showing 12 changed files with 948 additions and 3 deletions.
7 changes: 4 additions & 3 deletions docs/source/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -254,10 +254,11 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is
- The options are

```
[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr]
[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips]
```

which correspond to `nibabel`, `scikit-image`, `scipy`, `pillow`, `tensorboard`,
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, and `zarr` respectively.
which correspond to `nibabel`, `scikit-image`,`scipy`, `pillow`, `tensorboard`,
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr` and `lpips` respectively.


- `pip install 'monai[all]'` installs all the optional dependencies.
15 changes: 15 additions & 0 deletions docs/source/losses.rst
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,21 @@ Reconstruction Losses
.. autoclass:: monai.losses.ssim_loss.SSIMLoss
:members:

`PatchAdversarialLoss`
~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: PatchAdversarialLoss
:members:

`PerceptualLoss`
~~~~~~~~~~~~~~~~~
.. autoclass:: PerceptualLoss
:members:

`JukeboxLoss`
~~~~~~~~~~~~~~
.. autoclass:: JukeboxLoss
:members:


Loss Wrappers
-------------
Expand Down
3 changes: 3 additions & 0 deletions monai/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from __future__ import annotations

from .adversarial_loss import PatchAdversarialLoss
from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss
from .contrastive import ContrastiveLoss
from .deform import BendingEnergyLoss
Expand All @@ -34,7 +35,9 @@
from .giou_loss import BoxGIoULoss, giou
from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss
from .multi_scale import MultiScaleLoss
from .perceptual import PerceptualLoss
from .spatial_mask import MaskedLoss
from .spectral_loss import JukeboxLoss
from .ssim_loss import SSIMLoss
from .tversky import TverskyLoss
from .unified_focal_loss import AsymmetricUnifiedFocalLoss
173 changes: 173 additions & 0 deletions monai/losses/adversarial_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import warnings

import torch
from torch.nn.modules.loss import _Loss

from monai.networks.layers.utils import get_act_layer
from monai.utils import LossReduction
from monai.utils.enums import StrEnum


class AdversarialCriterions(StrEnum):
BCE = "bce"
HINGE = "hinge"
LEAST_SQUARE = "least_squares"


class PatchAdversarialLoss(_Loss):
"""
Calculates an adversarial loss on a Patch Discriminator or a Multi-scale Patch Discriminator.
Warning: due to the possibility of using different criterions, the output of the discrimination
mustn't be passed to a final activation layer. That is taken care of internally within the loss.
Args:
reduction: {``"none"``, ``"mean"``, ``"sum"``}
Specifies the reduction to apply to the output. Defaults to ``"mean"``.
- ``"none"``: no reduction will be applied.
- ``"mean"``: the sum of the output will be divided by the number of elements in the output.
- ``"sum"``: the output will be summed.
criterion: which criterion (hinge, least_squares or bce) you want to use on the discriminators outputs.
Depending on the criterion, a different activation layer will be used. Make sure you don't run the outputs
through an activation layer prior to calling the loss.
no_activation_leastsq: if True, the activation layer in the case of least-squares is removed.
"""

def __init__(
self,
reduction: LossReduction | str = LossReduction.MEAN,
criterion: str = AdversarialCriterions.LEAST_SQUARE,
no_activation_leastsq: bool = False,
) -> None:
super().__init__(reduction=LossReduction(reduction))

if criterion.lower() not in list(AdversarialCriterions):
raise ValueError(
"Unrecognised criterion entered for Adversarial Loss. Must be one in: %s"
% ", ".join(AdversarialCriterions)
)

# Depending on the criterion, a different activation layer is used.
self.real_label = 1.0
self.fake_label = 0.0
self.loss_fct: _Loss
if criterion == AdversarialCriterions.BCE:
self.activation = get_act_layer("SIGMOID")
self.loss_fct = torch.nn.BCELoss(reduction=reduction)
elif criterion == AdversarialCriterions.HINGE:
self.activation = get_act_layer("TANH")
self.fake_label = -1.0
elif criterion == AdversarialCriterions.LEAST_SQUARE:
if no_activation_leastsq:
self.activation = None
else:
self.activation = get_act_layer(name=("LEAKYRELU", {"negative_slope": 0.05}))
self.loss_fct = torch.nn.MSELoss(reduction=reduction)

self.criterion = criterion
self.reduction = reduction

def get_target_tensor(self, input: torch.Tensor, target_is_real: bool) -> torch.Tensor:
"""
Gets the ground truth tensor for the discriminator depending on whether the input is real or fake.
Args:
input: input tensor from the discriminator (output of discriminator, or output of one of the multi-scale
discriminator). This is used to match the shape.
target_is_real: whether the input is real or wannabe-real (1s) or fake (0s).
Returns:
"""
filling_label = self.real_label if target_is_real else self.fake_label
label_tensor = torch.tensor(1).fill_(filling_label).type(input.type()).to(input[0].device)
label_tensor.requires_grad_(False)
return label_tensor.expand_as(input)

def get_zero_tensor(self, input: torch.Tensor) -> torch.Tensor:
"""
Gets a zero tensor.
Args:
input: tensor which shape you want the zeros tensor to correspond to.
Returns:
"""

zero_label_tensor = torch.tensor(0).type(input[0].type()).to(input[0].device)
zero_label_tensor.requires_grad_(False)
return zero_label_tensor.expand_as(input)

def forward(
self, input: torch.Tensor | list, target_is_real: bool, for_discriminator: bool
) -> torch.Tensor | list[torch.Tensor]:
"""
Args:
input: output of Multi-Scale Patch Discriminator or Patch Discriminator; being a list of tensors
or a tensor; they shouldn't have gone through an activation layer.
target_is_real: whereas the input corresponds to discriminator output for real or fake images
for_discriminator: whereas this is being calculated for discriminator or generator loss. In the last
case, target_is_real is set to True, as the generator wants the input to be dimmed as real.
Returns: if reduction is None, returns a list with the loss tensors of each discriminator if multi-scale
discriminator is active, or the loss tensor if there is just one discriminator. Otherwise, it returns the
summed or mean loss over the tensor and discriminator/s.
"""

if not for_discriminator and not target_is_real:
target_is_real = True # With generator, we always want this to be true!
warnings.warn(
"Variable target_is_real has been set to False, but for_discriminator is set"
"to False. To optimise a generator, target_is_real must be set to True."
)

if type(input) is not list:
input = [input]
target_ = []
for _, disc_out in enumerate(input):
if self.criterion != AdversarialCriterions.HINGE:
target_.append(self.get_target_tensor(disc_out, target_is_real))
else:
target_.append(self.get_zero_tensor(disc_out))

# Loss calculation
loss_list = []
for disc_ind, disc_out in enumerate(input):
if self.activation is not None:
disc_out = self.activation(disc_out)
if self.criterion == AdversarialCriterions.HINGE and not target_is_real:
loss_ = self._forward_single(-disc_out, target_[disc_ind])
else:
loss_ = self._forward_single(disc_out, target_[disc_ind])
loss_list.append(loss_)

loss: torch.Tensor | list[torch.Tensor]
if loss_list is not None:
if self.reduction == LossReduction.MEAN:
loss = torch.mean(torch.stack(loss_list))
elif self.reduction == LossReduction.SUM:
loss = torch.sum(torch.stack(loss_list))
else:
loss = loss_list
return loss

def _forward_single(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
forward: torch.Tensor
if self.criterion == AdversarialCriterions.BCE or self.criterion == AdversarialCriterions.LEAST_SQUARE:
forward = self.loss_fct(input, target)
elif self.criterion == AdversarialCriterions.HINGE:
minval = torch.min(input - 1, self.get_zero_tensor(input))
forward = -torch.mean(minval)
return forward
Loading

0 comments on commit 6f5cea3

Please sign in to comment.