6676 port generative networks vqvae (#7285)
Partially fixes #6676

### Description

Implements the VQ-VAE network, including the vector quantizer block,
from MONAI Generative.
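
A minimal usage sketch of the newly exposed quantizer layers (constructor arguments and return values follow the docstrings added in this PR; tensor shapes are illustrative):

```python
import torch

from monai.networks.layers import EMAQuantizer, VectorQuantizer

# 2D quantizer with a codebook of 16 atomic elements, each with 8 channels
quantizer = VectorQuantizer(EMAQuantizer(spatial_dims=2, num_embeddings=16, embedding_dim=8))

x = torch.randn(1, 8, 32, 32)   # encoding-space tensor in [B, C, H, W] layout
loss, quantized = quantizer(x)  # `quantized` has the same shape as `x`
print(quantizer.perplexity)     # exp(entropy) of codebook usage for this batch
```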

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: KumoLiu <[email protected]>
Signed-off-by: Mark Graham <[email protected]>
Signed-off-by: YunLiu <[email protected]>
Co-authored-by: YunLiu <[email protected]>
Co-authored-by: KumoLiu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
4 people authored Dec 7, 2023
1 parent fac754d commit b3fdfdd
Showing 9 changed files with 1,085 additions and 8 deletions.
13 changes: 13 additions & 0 deletions docs/source/networks.rst
@@ -258,6 +258,7 @@ N-Dim Fourier Transform
.. autofunction:: monai.networks.blocks.fft_utils_t.fftshift
.. autofunction:: monai.networks.blocks.fft_utils_t.ifftshift


Layers
------

@@ -408,6 +409,13 @@ Layers
.. autoclass:: LLTM
:members:

`Vector Quantizer`
~~~~~~~~~~~~~~~~~~
.. autoclass:: monai.networks.layers.vector_quantizer.EMAQuantizer
:members:
.. autoclass:: monai.networks.layers.vector_quantizer.VectorQuantizer
:members:

`Utilities`
~~~~~~~~~~~
.. automodule:: monai.networks.layers.convutils
@@ -728,6 +736,11 @@ Nets

.. autoclass:: voxelmorph

`VQ-VAE`
~~~~~~~~
.. autoclass:: VQVAE
:members:

Utilities
---------
.. automodule:: monai.networks.utils
2 changes: 1 addition & 1 deletion monai/bundle/scripts.py
@@ -221,7 +221,7 @@ def _download_from_ngc(

def _get_latest_bundle_version_monaihosting(name):
url = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting"
full_url = f"{url}/{name}"
full_url = f"{url}/{name.lower()}"
requests_get, has_requests = optional_import("requests", name="get")
if has_requests:
resp = requests_get(full_url)
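
For illustration, with a hypothetical bundle name, the request now targets the lowercased path:

```python
name = "Spleen_CT_Segmentation"  # hypothetical bundle name, used only for illustration
full_url = f"https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting/{name.lower()}"
# -> "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting/spleen_ct_segmentation"
```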
1 change: 1 addition & 0 deletions monai/networks/layers/__init__.py
@@ -37,4 +37,5 @@
)
from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push
from .utils import get_act_layer, get_dropout_layer, get_norm_layer, get_pool_layer
from .vector_quantizer import EMAQuantizer, VectorQuantizer
from .weight_init import _no_grad_trunc_normal_, trunc_normal_
233 changes: 233 additions & 0 deletions monai/networks/layers/vector_quantizer.py
@@ -0,0 +1,233 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Sequence, Tuple

import torch
from torch import nn

__all__ = ["VectorQuantizer", "EMAQuantizer"]


class EMAQuantizer(nn.Module):
"""
Vector Quantization module using Exponential Moving Average (EMA) to learn the codebook parameters based on Neural
Discrete Representation Learning by Oord et al. (https://arxiv.org/abs/1711.00937) and the official implementation
that can be found at https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py#L148 and commit
58d9a2746493717a7c9252938da7efa6006f3739.
This module is not compatible with TorchScript when used inside a DistributedDataParallel module, due to the
lack of TorchScript support for the torch.distributed module (see https://github.com/pytorch/pytorch/issues/41353,
as of 22/10/2022). If you want to TorchScript your model, set `ddp_sync` to False.
Args:
spatial_dims: number of spatial dimensions of the input.
num_embeddings: number of atomic elements in the codebook.
embedding_dim: number of channels of the input and atomic elements.
commitment_cost: scaling factor of the MSE loss between input and its quantized version. Defaults to 0.25.
decay: EMA decay. Defaults to 0.99.
epsilon: epsilon value. Defaults to 1e-5.
embedding_init: initialization method for the codebook. Defaults to "normal".
ddp_sync: whether to synchronize the codebook across processes. Defaults to True.
"""

def __init__(
self,
spatial_dims: int,
num_embeddings: int,
embedding_dim: int,
commitment_cost: float = 0.25,
decay: float = 0.99,
epsilon: float = 1e-5,
embedding_init: str = "normal",
ddp_sync: bool = True,
):
super().__init__()
self.spatial_dims: int = spatial_dims
self.embedding_dim: int = embedding_dim
self.num_embeddings: int = num_embeddings

if self.spatial_dims not in [2, 3]:
raise ValueError(
f"EMAQuantizer only supports 4D and 5D tensor inputs but received spatial dims {spatial_dims}."
)

self.embedding: torch.nn.Embedding = torch.nn.Embedding(self.num_embeddings, self.embedding_dim)
if embedding_init == "normal":
# Nothing to do here: nn.Embedding already initializes its weights from a normal distribution by default
pass
elif embedding_init == "kaiming_uniform":
torch.nn.init.kaiming_uniform_(self.embedding.weight.data, mode="fan_in", nonlinearity="linear")
self.embedding.weight.requires_grad = False

self.commitment_cost: float = commitment_cost

self.register_buffer("ema_cluster_size", torch.zeros(self.num_embeddings))
self.register_buffer("ema_w", self.embedding.weight.data.clone())
# declare types for mypy
self.ema_cluster_size: torch.Tensor
self.ema_w: torch.Tensor
self.decay: float = decay
self.epsilon: float = epsilon

self.ddp_sync: bool = ddp_sync

# Precalculating required permutation shapes
self.flatten_permutation = [0] + list(range(2, self.spatial_dims + 2)) + [1]
self.quantization_permutation: Sequence[int] = [0, self.spatial_dims + 1] + list(
range(1, self.spatial_dims + 1)
)

def quantize(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Given an input, projects it to the quantized space and returns the additional tensors needed for the EMA loss.
Args:
inputs: Encoding space tensors of shape [B, C, H, W, D].
Returns:
torch.Tensor: Flattened version of the input of shape [B*H*W*D, C].
torch.Tensor: One-hot representation of the quantization indices of shape [B*H*W*D, self.num_embeddings].
torch.Tensor: Quantization indices of shape [B, H, W, D].
"""
with torch.cuda.amp.autocast(enabled=False):
encoding_indices_view = list(inputs.shape)
del encoding_indices_view[1]

inputs = inputs.float()

# Converting to channel last format
flat_input = inputs.permute(self.flatten_permutation).contiguous().view(-1, self.embedding_dim)

# Calculate Euclidean distances
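# uses the expansion ||x - e||^2 = ||x||^2 + ||e||^2 - 2 * x . e between every flat input vector and every codebook entry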
distances = (
(flat_input**2).sum(dim=1, keepdim=True)
+ (self.embedding.weight.t() ** 2).sum(dim=0, keepdim=True)
- 2 * torch.mm(flat_input, self.embedding.weight.t())
)

# Mapping distances to indexes
encoding_indices = torch.max(-distances, dim=1)[1]
encodings = torch.nn.functional.one_hot(encoding_indices, self.num_embeddings).float()

# Quantize and reshape
encoding_indices = encoding_indices.view(encoding_indices_view)

return flat_input, encodings, encoding_indices

def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor:
"""
Given encoding indices of shape [B, D, H, W], embeds them in the quantized space
[B, D, H, W, self.embedding_dim] and reshapes the result to [B, self.embedding_dim, D, H, W] to be fed to the
decoder.
Args:
embedding_indices: Tensor in channel-last format which holds indices referencing atomic
elements from self.embedding.
Returns:
torch.Tensor: Quantized space representation of encoding_indices in channel-first format.
"""
with torch.cuda.amp.autocast(enabled=False):
embedding: torch.Tensor = (
self.embedding(embedding_indices).permute(self.quantization_permutation).contiguous()
)
return embedding

def distributed_synchronization(self, encodings_sum: torch.Tensor, dw: torch.Tensor) -> None:
"""
TorchScript does not support torch.distributed.all_reduce. This function is a workaround based on the
example: https://pytorch.org/docs/stable/generated/torch.jit.unused.html#torch.jit.unused
Args:
encodings_sum: The summation of the one-hot representation of which encoding was used for each
position.
dw: The matrix product of the one-hot representation of which encoding was used for each
position and the flattened input.
Returns:
None
"""
if self.ddp_sync and torch.distributed.is_initialized():
torch.distributed.all_reduce(tensor=encodings_sum, op=torch.distributed.ReduceOp.SUM)
torch.distributed.all_reduce(tensor=dw, op=torch.distributed.ReduceOp.SUM)
else:
pass

def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
flat_input, encodings, encoding_indices = self.quantize(inputs)
quantized = self.embed(encoding_indices)

# Use EMA to update the embedding vectors
if self.training:
with torch.no_grad():
encodings_sum = encodings.sum(0)
dw = torch.mm(encodings.t(), flat_input)

if self.ddp_sync:
self.distributed_synchronization(encodings_sum, dw)

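# EMA update of the per-code usage counts: N_i <- decay * N_i + (1 - decay) * n_i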
self.ema_cluster_size.data.mul_(self.decay).add_(torch.mul(encodings_sum, 1 - self.decay))

# Laplace smoothing of the cluster size
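# epsilon keeps every smoothed count strictly positive, so codes that were never selected do not cause a division by zero in the normalisation below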
n = self.ema_cluster_size.sum()
weights = (self.ema_cluster_size + self.epsilon) / (n + self.num_embeddings * self.epsilon) * n
self.ema_w.data.mul_(self.decay).add_(torch.mul(dw, 1 - self.decay))
self.embedding.weight.data.copy_(self.ema_w / weights.unsqueeze(1))

# Encoding Loss
loss = self.commitment_cost * torch.nn.functional.mse_loss(quantized.detach(), inputs)

# Straight Through Estimator
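# identity mapping in the forward pass; gradients w.r.t. `quantized` flow straight back to `inputs`, while the codebook itself is trained by the EMA updates above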
quantized = inputs + (quantized - inputs).detach()

return quantized, loss, encoding_indices


class VectorQuantizer(torch.nn.Module):
"""
Vector quantization wrapper, needed as a workaround for AMP so that the parts of the quantization that are not
fp16-compatible are isolated in their own class.
Args:
quantizer (torch.nn.Module): Quantizer module that must return its quantized representation, loss and
index-based quantized representation.
"""

def __init__(self, quantizer: EMAQuantizer):
super().__init__()

self.quantizer: EMAQuantizer = quantizer

self.perplexity: torch.Tensor = torch.rand(1)

def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
quantized, loss, encoding_indices = self.quantizer(inputs)
# Perplexity calculations
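# perplexity = exp(entropy) of the codebook usage: num_embeddings when usage is uniform, 1 when a single code is always selected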
avg_probs = (
torch.histc(encoding_indices.float(), bins=self.quantizer.num_embeddings, max=self.quantizer.num_embeddings)
.float()
.div(encoding_indices.numel())
)

self.perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

return loss, quantized

def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor:
return self.quantizer.embed(embedding_indices=embedding_indices)

def quantize(self, encodings: torch.Tensor) -> torch.Tensor:
output = self.quantizer(encodings)
encoding_indices: torch.Tensor = output[2]
return encoding_indices
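
A short end-to-end sketch of the behaviour implemented above (straight-through gradients and the index round trip), assuming a small 2D configuration:

```python
import torch

from monai.networks.layers import EMAQuantizer, VectorQuantizer

vq = VectorQuantizer(EMAQuantizer(spatial_dims=2, num_embeddings=4, embedding_dim=3))

x = torch.randn(2, 3, 8, 8, requires_grad=True)
loss, quantized = vq(x)

# straight-through estimator: the quantized output is differentiable w.r.t. the encoder output
quantized.sum().backward()
assert x.grad is not None and x.grad.shape == x.shape

# index round trip: map to codebook indices, then embed back to a channel-first tensor
indices = vq.quantize(x.detach())  # [B, H, W]
decoded = vq.embed(indices)        # [B, embedding_dim, H, W]
assert decoded.shape == x.shape
```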
1 change: 1 addition & 0 deletions monai/networks/nets/__init__.py
@@ -113,3 +113,4 @@
from .vitautoenc import ViTAutoEnc
from .vnet import VNet
from .voxelmorph import VoxelMorph, VoxelMorphUNet
from .vqvae import VQVAE
14 changes: 7 additions & 7 deletions monai/networks/nets/autoencoderkl.py
@@ -38,7 +38,7 @@ class _Upsample(nn.Module):
Convolution-based upsampling layer.
Args:
spatial_dims: number of spatial dimensions (1D, 2D, 3D).
spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
in_channels: number of input channels to the layer.
use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder.
"""
@@ -98,7 +98,7 @@ class _Downsample(nn.Module):
Convolution-based downsampling layer.
Args:
spatial_dims: number of spatial dimensions (1D, 2D, 3D).
spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
in_channels: number of input channels.
"""

@@ -132,7 +132,7 @@ class _ResBlock(nn.Module):
residual connection between input and output.
Args:
spatial_dims: number of spatial dimensions (1D, 2D, 3D).
spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
in_channels: input channels to the layer.
norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of
channels is divisible by this number.
@@ -206,7 +206,7 @@ class _AttentionBlock(nn.Module):
Attention block.
Args:
spatial_dims: number of spatial dimensions (1D, 2D, 3D).
spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
num_channels: number of input channels.
num_head_channels: number of channels in each attention head.
norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of
@@ -325,7 +325,7 @@ class Encoder(nn.Module):
Convolutional cascade that downsamples the image into a spatial latent space.
Args:
spatial_dims: number of spatial dimensions (1D, 2D, 3D).
spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
in_channels: number of input channels.
channels: sequence of block output channels.
out_channels: number of channels in the bottom layer (latent space) of the autoencoder.
@@ -463,7 +463,7 @@ class Decoder(nn.Module):
Convolutional cascade upsampling from a spatial latent space into an image space.
Args:
spatial_dims: number of spatial dimensions (1D, 2D, 3D).
spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
channels: sequence of block output channels.
in_channels: number of channels in the bottom layer (latent space) of the autoencoder.
out_channels: number of output channels.
@@ -611,7 +611,7 @@ class AutoencoderKL(nn.Module):
and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162
Args:
spatial_dims: number of spatial dimensions (1D, 2D, 3D).
spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
in_channels: number of input channels.
out_channels: number of output channels.
num_res_blocks: number of residual blocks (see _ResBlock) per level.