From b3fdfdd2111c5d1349a345fbd4e24c570d1fb690 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 6 Dec 2023 22:36:50 -0500 Subject: [PATCH] 6676 port generative networks vqvae (#7285) Partially fixes https://github.com/Project-MONAI/MONAI/issues/6676 ### Description Implements the VQ-VAE network, including the vector quantizer block, from MONAI Generative. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu Signed-off-by: Mark Graham Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: KumoLiu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- docs/source/networks.rst | 13 + monai/bundle/scripts.py | 2 +- monai/networks/layers/__init__.py | 1 + monai/networks/layers/vector_quantizer.py | 233 +++++++++++ monai/networks/nets/__init__.py | 1 + monai/networks/nets/autoencoderkl.py | 14 +- monai/networks/nets/vqvae.py | 466 ++++++++++++++++++++++ tests/test_vector_quantizer.py | 89 +++++ tests/test_vqvae.py | 274 +++++++++++++ 9 files changed, 1085 insertions(+), 8 deletions(-) create mode 100644 monai/networks/layers/vector_quantizer.py create mode 100644 monai/networks/nets/vqvae.py create mode 100644 tests/test_vector_quantizer.py create mode 100644 tests/test_vqvae.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index dbfdf35784..d8be26264b 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -258,6 +258,7 @@ N-Dim Fourier Transform .. autofunction:: monai.networks.blocks.fft_utils_t.fftshift .. autofunction:: monai.networks.blocks.fft_utils_t.ifftshift + Layers ------ @@ -408,6 +409,13 @@ Layers .. autoclass:: LLTM :members: +`Vector Quantizer` +~~~~~~~~~~~~~~~~~~ +.. autoclass:: monai.networks.layers.vector_quantizer.EMAQuantizer + :members: +.. autoclass:: monai.networks.layers.vector_quantizer.VectorQuantizer + :members: + `Utilities` ~~~~~~~~~~~ .. automodule:: monai.networks.layers.convutils @@ -728,6 +736,11 @@ Nets .. autoclass:: voxelmorph +`VQ-VAE` +~~~~~~~~ +.. autoclass:: VQVAE + :members: + Utilities --------- .. 
automodule:: monai.networks.utils diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 20a491e493..2565a3cf64 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -221,7 +221,7 @@ def _download_from_ngc( def _get_latest_bundle_version_monaihosting(name): url = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting" - full_url = f"{url}/{name}" + full_url = f"{url}/{name.lower()}" requests_get, has_requests = optional_import("requests", name="get") if has_requests: resp = requests_get(full_url) diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py index d61ed57f7f..bd3e3af3af 100644 --- a/monai/networks/layers/__init__.py +++ b/monai/networks/layers/__init__.py @@ -37,4 +37,5 @@ ) from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push from .utils import get_act_layer, get_dropout_layer, get_norm_layer, get_pool_layer +from .vector_quantizer import EMAQuantizer, VectorQuantizer from .weight_init import _no_grad_trunc_normal_, trunc_normal_ diff --git a/monai/networks/layers/vector_quantizer.py b/monai/networks/layers/vector_quantizer.py new file mode 100644 index 0000000000..9c354e1009 --- /dev/null +++ b/monai/networks/layers/vector_quantizer.py @@ -0,0 +1,233 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Sequence, Tuple + +import torch +from torch import nn + +__all__ = ["VectorQuantizer", "EMAQuantizer"] + + +class EMAQuantizer(nn.Module): + """ + Vector Quantization module using Exponential Moving Average (EMA) to learn the codebook parameters based on Neural + Discrete Representation Learning by Oord et al. (https://arxiv.org/abs/1711.00937) and the official implementation + that can be found at https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py#L148 and commit + 58d9a2746493717a7c9252938da7efa6006f3739. + + This module is not compatible with TorchScript while working in a Distributed Data Parallelism Module. This is due + to lack of TorchScript support for torch.distributed module as per https://github.com/pytorch/pytorch/issues/41353 + on 22/10/2022. If you want to TorchScript your model, please turn set `ddp_sync` to False. + + Args: + spatial_dims: number of spatial dimensions of the input. + num_embeddings: number of atomic elements in the codebook. + embedding_dim: number of channels of the input and atomic elements. + commitment_cost: scaling factor of the MSE loss between input and its quantized version. Defaults to 0.25. + decay: EMA decay. Defaults to 0.99. + epsilon: epsilon value. Defaults to 1e-5. + embedding_init: initialization method for the codebook. Defaults to "normal". + ddp_sync: whether to synchronize the codebook across processes. Defaults to True. 
+ """ + + def __init__( + self, + spatial_dims: int, + num_embeddings: int, + embedding_dim: int, + commitment_cost: float = 0.25, + decay: float = 0.99, + epsilon: float = 1e-5, + embedding_init: str = "normal", + ddp_sync: bool = True, + ): + super().__init__() + self.spatial_dims: int = spatial_dims + self.embedding_dim: int = embedding_dim + self.num_embeddings: int = num_embeddings + + assert self.spatial_dims in [2, 3], ValueError( + f"EMAQuantizer only supports 4D and 5D tensor inputs but received spatial dims {spatial_dims}." + ) + + self.embedding: torch.nn.Embedding = torch.nn.Embedding(self.num_embeddings, self.embedding_dim) + if embedding_init == "normal": + # Initialization is passed since the default one is normal inside the nn.Embedding + pass + elif embedding_init == "kaiming_uniform": + torch.nn.init.kaiming_uniform_(self.embedding.weight.data, mode="fan_in", nonlinearity="linear") + self.embedding.weight.requires_grad = False + + self.commitment_cost: float = commitment_cost + + self.register_buffer("ema_cluster_size", torch.zeros(self.num_embeddings)) + self.register_buffer("ema_w", self.embedding.weight.data.clone()) + # declare types for mypy + self.ema_cluster_size: torch.Tensor + self.ema_w: torch.Tensor + self.decay: float = decay + self.epsilon: float = epsilon + + self.ddp_sync: bool = ddp_sync + + # Precalculating required permutation shapes + self.flatten_permutation = [0] + list(range(2, self.spatial_dims + 2)) + [1] + self.quantization_permutation: Sequence[int] = [0, self.spatial_dims + 1] + list( + range(1, self.spatial_dims + 1) + ) + + def quantize(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Given an input it projects it to the quantized space and returns additional tensors needed for EMA loss. + + Args: + inputs: Encoding space tensors of shape [B, C, H, W, D]. + + Returns: + torch.Tensor: Flatten version of the input of shape [B*H*W*D, C]. + torch.Tensor: One-hot representation of the quantization indices of shape [B*H*W*D, self.num_embeddings]. + torch.Tensor: Quantization indices of shape [B,H,W,D,1] + + """ + with torch.cuda.amp.autocast(enabled=False): + encoding_indices_view = list(inputs.shape) + del encoding_indices_view[1] + + inputs = inputs.float() + + # Converting to channel last format + flat_input = inputs.permute(self.flatten_permutation).contiguous().view(-1, self.embedding_dim) + + # Calculate Euclidean distances + distances = ( + (flat_input**2).sum(dim=1, keepdim=True) + + (self.embedding.weight.t() ** 2).sum(dim=0, keepdim=True) + - 2 * torch.mm(flat_input, self.embedding.weight.t()) + ) + + # Mapping distances to indexes + encoding_indices = torch.max(-distances, dim=1)[1] + encodings = torch.nn.functional.one_hot(encoding_indices, self.num_embeddings).float() + + # Quantize and reshape + encoding_indices = encoding_indices.view(encoding_indices_view) + + return flat_input, encodings, encoding_indices + + def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor: + """ + Given encoding indices of shape [B,D,H,W,1] embeds them in the quantized space + [B, D, H, W, self.embedding_dim] and reshapes them to [B, self.embedding_dim, D, H, W] to be fed to the + decoder. + + Args: + embedding_indices: Tensor in channel last format which holds indices referencing atomic + elements from self.embedding + + Returns: + torch.Tensor: Quantize space representation of encoding_indices in channel first format. 
+ """ + with torch.cuda.amp.autocast(enabled=False): + embedding: torch.Tensor = ( + self.embedding(embedding_indices).permute(self.quantization_permutation).contiguous() + ) + return embedding + + def distributed_synchronization(self, encodings_sum: torch.Tensor, dw: torch.Tensor) -> None: + """ + TorchScript does not support torch.distributed.all_reduce. This function is a bypassing trick based on the + example: https://pytorch.org/docs/stable/generated/torch.jit.unused.html#torch.jit.unused + + Args: + encodings_sum: The summation of one hot representation of what encoding was used for each + position. + dw: The multiplication of the one hot representation of what encoding was used for each + position with the flattened input. + + Returns: + None + """ + if self.ddp_sync and torch.distributed.is_initialized(): + torch.distributed.all_reduce(tensor=encodings_sum, op=torch.distributed.ReduceOp.SUM) + torch.distributed.all_reduce(tensor=dw, op=torch.distributed.ReduceOp.SUM) + else: + pass + + def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + flat_input, encodings, encoding_indices = self.quantize(inputs) + quantized = self.embed(encoding_indices) + + # Use EMA to update the embedding vectors + if self.training: + with torch.no_grad(): + encodings_sum = encodings.sum(0) + dw = torch.mm(encodings.t(), flat_input) + + if self.ddp_sync: + self.distributed_synchronization(encodings_sum, dw) + + self.ema_cluster_size.data.mul_(self.decay).add_(torch.mul(encodings_sum, 1 - self.decay)) + + # Laplace smoothing of the cluster size + n = self.ema_cluster_size.sum() + weights = (self.ema_cluster_size + self.epsilon) / (n + self.num_embeddings * self.epsilon) * n + self.ema_w.data.mul_(self.decay).add_(torch.mul(dw, 1 - self.decay)) + self.embedding.weight.data.copy_(self.ema_w / weights.unsqueeze(1)) + + # Encoding Loss + loss = self.commitment_cost * torch.nn.functional.mse_loss(quantized.detach(), inputs) + + # Straight Through Estimator + quantized = inputs + (quantized - inputs).detach() + + return quantized, loss, encoding_indices + + +class VectorQuantizer(torch.nn.Module): + """ + Vector Quantization wrapper that is needed as a workaround for the AMP to isolate the non fp16 compatible parts of + the quantization in their own class. + + Args: + quantizer (torch.nn.Module): Quantizer module that needs to return its quantized representation, loss and index + based quantized representation. 
+ """ + + def __init__(self, quantizer: EMAQuantizer): + super().__init__() + + self.quantizer: EMAQuantizer = quantizer + + self.perplexity: torch.Tensor = torch.rand(1) + + def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + quantized, loss, encoding_indices = self.quantizer(inputs) + # Perplexity calculations + avg_probs = ( + torch.histc(encoding_indices.float(), bins=self.quantizer.num_embeddings, max=self.quantizer.num_embeddings) + .float() + .div(encoding_indices.numel()) + ) + + self.perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + return loss, quantized + + def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor: + return self.quantizer.embed(embedding_indices=embedding_indices) + + def quantize(self, encodings: torch.Tensor) -> torch.Tensor: + output = self.quantizer(encodings) + encoding_indices: torch.Tensor = output[2] + return encoding_indices diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index ea08246d25..db3c77c717 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -113,3 +113,4 @@ from .vitautoenc import ViTAutoEnc from .vnet import VNet from .voxelmorph import VoxelMorph, VoxelMorphUNet +from .vqvae import VQVAE diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index 9a9f35d5ae..f7ae77f056 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -38,7 +38,7 @@ class _Upsample(nn.Module): Convolution-based upsampling layer. Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. in_channels: number of input channels to the layer. use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. """ @@ -98,7 +98,7 @@ class _Downsample(nn.Module): Convolution-based downsampling layer. Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. in_channels: number of input channels. """ @@ -132,7 +132,7 @@ class _ResBlock(nn.Module): residual connection between input and output. Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. in_channels: input channels to the layer. norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of channels is divisible by this number. @@ -206,7 +206,7 @@ class _AttentionBlock(nn.Module): Attention block. Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. num_channels: number of input channels. num_head_channels: number of channels in each attention head. norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of @@ -325,7 +325,7 @@ class Encoder(nn.Module): Convolutional cascade that downsamples the image into a spatial latent space. Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. in_channels: number of input channels. channels: sequence of block output channels. out_channels: number of channels in the bottom layer (latent space) of the autoencoder. @@ -463,7 +463,7 @@ class Decoder(nn.Module): Convolutional cascade upsampling from a spatial latent space into an image space. Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). 
+ spatial_dims: number of spatial dimensions, could be 1, 2, or 3. channels: sequence of block output channels. in_channels: number of channels in the bottom layer (latent space) of the autoencoder. out_channels: number of output channels. @@ -611,7 +611,7 @@ class AutoencoderKL(nn.Module): and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. in_channels: number of input channels. out_channels: number of output channels. num_res_blocks: number of residual blocks (see _ResBlock) per level. diff --git a/monai/networks/nets/vqvae.py b/monai/networks/nets/vqvae.py new file mode 100644 index 0000000000..d4771e203a --- /dev/null +++ b/monai/networks/nets/vqvae.py @@ -0,0 +1,466 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Tuple + +import torch +import torch.nn as nn + +from monai.networks.blocks import Convolution +from monai.networks.layers import Act +from monai.networks.layers.vector_quantizer import EMAQuantizer, VectorQuantizer +from monai.utils import ensure_tuple_rep + +__all__ = ["VQVAE"] + + +class VQVAEResidualUnit(nn.Module): + """ + Implementation of the ResidualLayer used in the VQVAE network as originally used in Morphology-preserving + Autoregressive 3D Generative Modelling of the Brain by Tudosiu et al. (https://arxiv.org/pdf/2209.03177.pdf). + + The original implementation can be found at + https://github.com/AmigoLab/SynthAnatomy/blob/main/src/networks/vqvae/baseline.py#L150. + + Args: + spatial_dims: number of spatial dimensions of the input data. + in_channels: number of input channels. + num_res_channels: number of channels in the residual layers. + act: activation type and arguments. Defaults to RELU. + dropout: dropout ratio. Defaults to no dropout. + bias: whether to have a bias term. Defaults to True. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_res_channels: int, + act: tuple | str | None = Act.RELU, + dropout: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.num_res_channels = num_res_channels + self.act = act + self.dropout = dropout + self.bias = bias + + self.conv1 = Convolution( + spatial_dims=self.spatial_dims, + in_channels=self.in_channels, + out_channels=self.num_res_channels, + adn_ordering="DA", + act=self.act, + dropout=self.dropout, + bias=self.bias, + ) + + self.conv2 = Convolution( + spatial_dims=self.spatial_dims, + in_channels=self.num_res_channels, + out_channels=self.in_channels, + bias=self.bias, + conv_only=True, + ) + + def forward(self, x): + return torch.nn.functional.relu(x + self.conv2(self.conv1(x)), True) + + +class Encoder(nn.Module): + """ + Encoder module for VQ-VAE.
+ + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of channels in the latent space (embedding_dim). + channels: sequence containing the number of channels at each level of the encoder. + num_res_layers: number of sequential residual layers at each level. + num_res_channels: number of channels in the residual layers at each level. + downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the + following information: stride (int), kernel_size (int), dilation (int) and padding (int). + dropout: dropout ratio. + act: activation type and arguments. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + channels: Sequence[int], + num_res_layers: int, + num_res_channels: Sequence[int], + downsample_parameters: Sequence[Tuple[int, int, int, int]], + dropout: float, + act: tuple | str | None, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.out_channels = out_channels + self.channels = channels + self.num_res_layers = num_res_layers + self.num_res_channels = num_res_channels + self.downsample_parameters = downsample_parameters + self.dropout = dropout + self.act = act + + blocks: list[nn.Module] = [] + + for i in range(len(self.channels)): + blocks.append( + Convolution( + spatial_dims=self.spatial_dims, + in_channels=self.in_channels if i == 0 else self.channels[i - 1], + out_channels=self.channels[i], + strides=self.downsample_parameters[i][0], + kernel_size=self.downsample_parameters[i][1], + adn_ordering="DA", + act=self.act, + dropout=None if i == 0 else self.dropout, + dropout_dim=1, + dilation=self.downsample_parameters[i][2], + padding=self.downsample_parameters[i][3], + ) + ) + + for _ in range(self.num_res_layers): + blocks.append( + VQVAEResidualUnit( + spatial_dims=self.spatial_dims, + in_channels=self.channels[i], + num_res_channels=self.num_res_channels[i], + act=self.act, + dropout=self.dropout, + ) + ) + + blocks.append( + Convolution( + spatial_dims=self.spatial_dims, + in_channels=self.channels[len(self.channels) - 1], + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + return x + + +class Decoder(nn.Module): + """ + Decoder module for VQ-VAE. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of channels in the latent space (embedding_dim). + out_channels: number of output channels. + channels: sequence containing the number of channels at each level of the decoder. + num_res_layers: number of sequential residual layers at each level. + num_res_channels: number of channels in the residual layers at each level. + upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the + following information: stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int). + dropout: dropout ratio. + act: activation type and arguments. + output_act: activation type and arguments for the output.
+ """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + channels: Sequence[int], + num_res_layers: int, + num_res_channels: Sequence[int], + upsample_parameters: Sequence[Tuple[int, int, int, int, int]], + dropout: float, + act: tuple | str | None, + output_act: tuple | str | None, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.out_channels = out_channels + self.channels = channels + self.num_res_layers = num_res_layers + self.num_res_channels = num_res_channels + self.upsample_parameters = upsample_parameters + self.dropout = dropout + self.act = act + self.output_act = output_act + + reversed_num_channels = list(reversed(self.channels)) + + blocks: list[nn.Module] = [] + blocks.append( + Convolution( + spatial_dims=self.spatial_dims, + in_channels=self.in_channels, + out_channels=reversed_num_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + reversed_num_res_channels = list(reversed(self.num_res_channels)) + for i in range(len(self.channels)): + for _ in range(self.num_res_layers): + blocks.append( + VQVAEResidualUnit( + spatial_dims=self.spatial_dims, + in_channels=reversed_num_channels[i], + num_res_channels=reversed_num_res_channels[i], + act=self.act, + dropout=self.dropout, + ) + ) + + blocks.append( + Convolution( + spatial_dims=self.spatial_dims, + in_channels=reversed_num_channels[i], + out_channels=self.out_channels if i == len(self.channels) - 1 else reversed_num_channels[i + 1], + strides=self.upsample_parameters[i][0], + kernel_size=self.upsample_parameters[i][1], + adn_ordering="DA", + act=self.act, + dropout=self.dropout if i != len(self.channels) - 1 else None, + norm=None, + dilation=self.upsample_parameters[i][2], + conv_only=i == len(self.channels) - 1, + is_transposed=True, + padding=self.upsample_parameters[i][3], + output_padding=self.upsample_parameters[i][4], + ) + ) + + if self.output_act: + blocks.append(Act[self.output_act]()) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + return x + + +class VQVAE(nn.Module): + """ + Vector-Quantised Variational Autoencoder (VQ-VAE) used in Morphology-preserving Autoregressive 3D Generative + Modelling of the Brain by Tudosiu et al. (https://arxiv.org/pdf/2209.03177.pdf) + + The original implementation can be found at + https://github.com/AmigoLab/SynthAnatomy/blob/main/src/networks/vqvae/baseline.py#L163/ + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the + following information: stride (int), kernel_size (int), dilation (int) and padding (int). + upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the + following information: stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int). + num_res_layers: number of sequential residual layers at each level. + channels: number of channels at each level. + num_res_channels: number of channels in the residual layers at each level. + num_embeddings: VectorQuantization number of atomic elements in the codebook. + embedding_dim: VectorQuantization number of channels of the input and atomic elements. + commitment_cost: VectorQuantization commitment_cost.
+ decay: VectorQuantization decay. + epsilon: VectorQuantization epsilon. + act: activation type and arguments. + dropout: dropout ratio. + output_act: activation type and arguments for the output. + ddp_sync: whether to synchronize the codebook across processes. + use_checkpointing: if True, use activation checkpointing to save memory. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + channels: Sequence[int] = (96, 96, 192), + num_res_layers: int = 3, + num_res_channels: Sequence[int] | int = (96, 96, 192), + downsample_parameters: Sequence[Tuple[int, int, int, int]] + | Tuple[int, int, int, int] = ((2, 4, 1, 1), (2, 4, 1, 1), (2, 4, 1, 1)), + upsample_parameters: Sequence[Tuple[int, int, int, int, int]] + | Tuple[int, int, int, int, int] = ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + num_embeddings: int = 32, + embedding_dim: int = 64, + embedding_init: str = "normal", + commitment_cost: float = 0.25, + decay: float = 0.5, + epsilon: float = 1e-5, + dropout: float = 0.0, + act: tuple | str | None = Act.RELU, + output_act: tuple | str | None = None, + ddp_sync: bool = True, + use_checkpointing: bool = False, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.spatial_dims = spatial_dims + self.channels = channels + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.use_checkpointing = use_checkpointing + + if isinstance(num_res_channels, int): + num_res_channels = ensure_tuple_rep(num_res_channels, len(channels)) + + if len(num_res_channels) != len(channels): + raise ValueError( + "`num_res_channels` should be a single integer or a tuple of integers with the same length as " + "`num_channels`." + ) + if all(isinstance(values, int) for values in upsample_parameters): + upsample_parameters_tuple: Sequence = (upsample_parameters,) * len(channels) + else: + upsample_parameters_tuple = upsample_parameters + + if all(isinstance(values, int) for values in downsample_parameters): + downsample_parameters_tuple: Sequence = (downsample_parameters,) * len(channels) + else: + downsample_parameters_tuple = downsample_parameters + + if not all(all(isinstance(value, int) for value in sub_item) for sub_item in downsample_parameters_tuple): + raise ValueError("`downsample_parameters` should be a single tuple of integers or a tuple of tuples.") + + # check if upsample_parameters is a tuple of ints or a tuple of tuples of ints + if not all(all(isinstance(value, int) for value in sub_item) for sub_item in upsample_parameters_tuple): + raise ValueError("`upsample_parameters` should be a single tuple of integers or a tuple of tuples.") + + for parameter in downsample_parameters_tuple: + if len(parameter) != 4: + raise ValueError("`downsample_parameters` should be a tuple of tuples with 4 integers.") + + for parameter in upsample_parameters_tuple: + if len(parameter) != 5: + raise ValueError("`upsample_parameters` should be a tuple of tuples with 5 integers.") + + if len(downsample_parameters_tuple) != len(channels): + raise ValueError( + "`downsample_parameters` should be a tuple of tuples with the same length as `num_channels`." + ) + + if len(upsample_parameters_tuple) != len(channels): + raise ValueError( + "`upsample_parameters` should be a tuple of tuples with the same length as `num_channels`."
+ ) + + self.num_res_layers = num_res_layers + self.num_res_channels = num_res_channels + + self.encoder = Encoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=embedding_dim, + channels=channels, + num_res_layers=num_res_layers, + num_res_channels=num_res_channels, + downsample_parameters=downsample_parameters_tuple, + dropout=dropout, + act=act, + ) + + self.decoder = Decoder( + spatial_dims=spatial_dims, + in_channels=embedding_dim, + out_channels=out_channels, + channels=channels, + num_res_layers=num_res_layers, + num_res_channels=num_res_channels, + upsample_parameters=upsample_parameters_tuple, + dropout=dropout, + act=act, + output_act=output_act, + ) + + self.quantizer = VectorQuantizer( + quantizer=EMAQuantizer( + spatial_dims=spatial_dims, + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + commitment_cost=commitment_cost, + decay=decay, + epsilon=epsilon, + embedding_init=embedding_init, + ddp_sync=ddp_sync, + ) + ) + + def encode(self, images: torch.Tensor) -> torch.Tensor: + output: torch.Tensor + if self.use_checkpointing: + output = torch.utils.checkpoint.checkpoint(self.encoder, images, use_reentrant=False) + else: + output = self.encoder(images) + return output + + def quantize(self, encodings: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + x_loss, x = self.quantizer(encodings) + return x, x_loss + + def decode(self, quantizations: torch.Tensor) -> torch.Tensor: + output: torch.Tensor + + if self.use_checkpointing: + output = torch.utils.checkpoint.checkpoint(self.decoder, quantizations, use_reentrant=False) + else: + output = self.decoder(quantizations) + return output + + def index_quantize(self, images: torch.Tensor) -> torch.Tensor: + return self.quantizer.quantize(self.encode(images=images)) + + def decode_samples(self, embedding_indices: torch.Tensor) -> torch.Tensor: + return self.decode(self.quantizer.embed(embedding_indices)) + + def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + quantizations, quantization_losses = self.quantize(self.encode(images)) + reconstruction = self.decode(quantizations) + + return reconstruction, quantization_losses + + def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor: + z = self.encode(x) + e, _ = self.quantize(z) + return e + + def decode_stage_2_outputs(self, z: torch.Tensor) -> torch.Tensor: + e, _ = self.quantize(z) + image = self.decode(e) + return image diff --git a/tests/test_vector_quantizer.py b/tests/test_vector_quantizer.py new file mode 100644 index 0000000000..43533d0377 --- /dev/null +++ b/tests/test_vector_quantizer.py @@ -0,0 +1,89 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest +from math import prod + +import torch +from parameterized import parameterized + +from monai.networks.layers import EMAQuantizer, VectorQuantizer + +TEST_CASES = [ + [{"spatial_dims": 2, "num_embeddings": 16, "embedding_dim": 8}, (1, 8, 4, 4), (1, 4, 4)], + [{"spatial_dims": 3, "num_embeddings": 16, "embedding_dim": 8}, (1, 8, 4, 4, 4), (1, 4, 4, 4)], +] + + +class TestEMA(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_ema_shape(self, input_param, input_shape, output_shape): + layer = EMAQuantizer(**input_param) + x = torch.randn(input_shape) + layer = layer.train() + outputs = layer(x) + self.assertEqual(outputs[0].shape, input_shape) + self.assertEqual(outputs[2].shape, output_shape) + + layer = layer.eval() + outputs = layer(x) + self.assertEqual(outputs[0].shape, input_shape) + self.assertEqual(outputs[2].shape, output_shape) + + @parameterized.expand(TEST_CASES) + def test_ema_quantize(self, input_param, input_shape, output_shape): + layer = EMAQuantizer(**input_param) + x = torch.randn(input_shape) + outputs = layer.quantize(x) + self.assertEqual(outputs[0].shape, (prod(input_shape[2:]), input_shape[1])) # (HxW[xD], C) + self.assertEqual(outputs[1].shape, (prod(input_shape[2:]), input_param["num_embeddings"])) # (HxW[xD], E) + self.assertEqual(outputs[2].shape, (input_shape[0],) + input_shape[2:]) # (1, H, W, [D]) + + def test_ema(self): + layer = EMAQuantizer(spatial_dims=2, num_embeddings=2, embedding_dim=2, epsilon=0, decay=0) + original_weight_0 = layer.embedding.weight[0].clone() + original_weight_1 = layer.embedding.weight[1].clone() + x_0 = original_weight_0 + x_0 = x_0.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + x_0 = x_0.repeat(1, 1, 1, 2) + 0.001 + + x_1 = original_weight_1 + x_1 = x_1.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + x_1 = x_1.repeat(1, 1, 1, 2) + + x = torch.cat([x_0, x_1], dim=0) + layer = layer.train() + _ = layer(x) + + self.assertTrue(all(layer.embedding.weight[0] != original_weight_0)) + self.assertTrue(all(layer.embedding.weight[1] == original_weight_1)) + + +class TestVectorQuantizer(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_vector_quantizer_shape(self, input_param, input_shape, output_shape): + layer = VectorQuantizer(EMAQuantizer(**input_param)) + x = torch.randn(input_shape) + outputs = layer(x) + self.assertEqual(outputs[1].shape, input_shape) + + @parameterized.expand(TEST_CASES) + def test_vector_quantizer_quantize(self, input_param, input_shape, output_shape): + layer = VectorQuantizer(EMAQuantizer(**input_param)) + x = torch.randn(input_shape) + outputs = layer.quantize(x) + self.assertEqual(outputs.shape, output_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_vqvae.py b/tests/test_vqvae.py new file mode 100644 index 0000000000..4916dc2faa --- /dev/null +++ b/tests/test_vqvae.py @@ -0,0 +1,274 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.vqvae import VQVAE +from tests.utils import SkipIfBeforePyTorchVersion + +TEST_CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "num_res_layers": 1, + "num_res_channels": (4, 4), + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_embeddings": 8, + "embedding_dim": 8, + }, + (1, 1, 8, 8), + (1, 1, 8, 8), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "num_res_layers": 1, + "num_res_channels": 4, + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_embeddings": 8, + "embedding_dim": 8, + }, + (1, 1, 8, 8, 8), + (1, 1, 8, 8, 8), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "num_res_layers": 1, + "num_res_channels": (4, 4), + "downsample_parameters": (2, 4, 1, 1), + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_embeddings": 8, + "embedding_dim": 8, + }, + (1, 1, 8, 8), + (1, 1, 8, 8), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "num_res_layers": 1, + "num_res_channels": (4, 4), + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": (2, 4, 1, 1, 0), + "num_embeddings": 8, + "embedding_dim": 8, + }, + (1, 1, 8, 8, 8), + (1, 1, 8, 8, 8), + ], +] + +TEST_LATENT_SHAPE = { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_res_layers": 1, + "channels": (8, 8), + "num_res_channels": (8, 8), + "num_embeddings": 16, + "embedding_dim": 8, +} + + +class TestVQVAE(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_shape, expected_shape): + device = "cuda" if torch.cuda.is_available() else "cpu" + + net = VQVAE(**input_param).to(device) + + with eval_mode(net): + result, _ = net(torch.randn(input_shape).to(device)) + + self.assertEqual(result.shape, expected_shape) + + @parameterized.expand(TEST_CASES) + @SkipIfBeforePyTorchVersion((1, 11)) + def test_shape_with_checkpoint(self, input_param, input_shape, expected_shape): + device = "cuda" if torch.cuda.is_available() else "cpu" + input_param = input_param.copy() + input_param.update({"use_checkpointing": True}) + + net = VQVAE(**input_param).to(device) + + with eval_mode(net): + result, _ = net(torch.randn(input_shape).to(device)) + + self.assertEqual(result.shape, expected_shape) + + # Removed this test case since TorchScript currently does not support activation checkpoint. 
+ # def test_script(self): + # net = VQVAE( + # spatial_dims=2, + # in_channels=1, + # out_channels=1, + # downsample_parameters=((2, 4, 1, 1),) * 2, + # upsample_parameters=((2, 4, 1, 1, 0),) * 2, + # num_res_layers=1, + # channels=(8, 8), + # num_res_channels=(8, 8), + # num_embeddings=16, + # embedding_dim=8, + # ddp_sync=False, + # ) + # test_data = torch.randn(1, 1, 16, 16) + # test_script_save(net, test_data) + + def test_channels_not_same_size_of_num_res_channels(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16, 16), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + ) + + def test_channels_not_same_size_of_downsample_parameters(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=((2, 4, 1, 1),) * 3, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + ) + + def test_channels_not_same_size_of_upsample_parameters(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 3, + ) + + def test_downsample_parameters_not_sequence_or_int(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=(("test", 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + ) + + def test_upsample_parameters_not_sequence_or_int(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=(("test", 4, 1, 1, 0),) * 2, + ) + + def test_downsample_parameter_length_different_4(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=((2, 4, 1),) * 3, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + ) + + def test_upsample_parameter_length_different_5(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16, 16), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0, 1),) * 3, + ) + + def test_encode_shape(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + + net = VQVAE(**TEST_LATENT_SHAPE).to(device) + + with eval_mode(net): + latent = net.encode(torch.randn(1, 1, 32, 32).to(device)) + + self.assertEqual(latent.shape, (1, 8, 8, 8)) + + def test_index_quantize_shape(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + + net = VQVAE(**TEST_LATENT_SHAPE).to(device) + + with eval_mode(net): + latent = net.index_quantize(torch.randn(1, 1, 32, 32).to(device)) + + self.assertEqual(latent.shape, (1, 8, 8)) + + def test_decode_shape(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + + net = VQVAE(**TEST_LATENT_SHAPE).to(device) + + with eval_mode(net): + latent = net.decode(torch.randn(1, 8, 8, 8).to(device)) + + self.assertEqual(latent.shape, (1, 1, 32, 32)) + + def test_decode_samples_shape(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + + net = 
VQVAE(**TEST_LATENT_SHAPE).to(device) + + with eval_mode(net): + latent = net.decode_samples(torch.randint(low=0, high=16, size=(1, 8, 8)).to(device)) + + self.assertEqual(latent.shape, (1, 1, 32, 32)) + + +if __name__ == "__main__": + unittest.main()
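
For reviewers, a minimal usage sketch of the `VQVAE` API introduced in this patch. The constructor arguments follow the class defaults shown above; the input size, the L1 reconstruction term, and the variable names are illustrative assumptions rather than part of the MONAI API.

```python
import torch
import torch.nn.functional as F

from monai.networks.nets import VQVAE

# 2D VQ-VAE with the default three stride-2 downsampling levels,
# so a 64x64 input is quantized on an 8x8 latent grid.
net = VQVAE(spatial_dims=2, in_channels=1, out_channels=1)

images = torch.randn(2, 1, 64, 64)  # illustrative batch of single-channel images

# forward() returns the reconstruction and the quantization (commitment) loss
reconstruction, quantization_loss = net(images)

# a typical training objective adds a reconstruction term to the quantization loss
loss = F.l1_loss(reconstruction, images) + quantization_loss
loss.backward()

# discrete codebook indices, e.g. as tokens for a second-stage autoregressive model
indices = net.index_quantize(images)   # shape (2, 8, 8)
decoded = net.decode_samples(indices)  # shape (2, 1, 64, 64)
```

The `encode_stage_2_inputs`/`decode_stage_2_outputs` helpers return the quantized latents directly, which is convenient when training a second-stage generative model on top of the frozen VQ-VAE.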