From d99044bb91f4bbd09354f674d74aa4238139a422 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 27 Nov 2023 10:42:50 +0000 Subject: [PATCH 01/12] Adds network and tests --- docs/source/networks.rst | 5 + monai/networks/nets/__init__.py | 1 + monai/networks/nets/autoencoderkl.py | 816 +++++++++++++++++++++++++++ tests/test_autoencoderkl.py | 270 +++++++++ 4 files changed, 1092 insertions(+) create mode 100644 monai/networks/nets/autoencoderkl.py create mode 100644 tests/test_autoencoderkl.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 8eada7933f..45a9d33388 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -595,6 +595,11 @@ Nets .. autoclass:: AutoEncoder :members: +`AutoEncoderKL` +~~~~~~~~~~~~~ +.. autoclass:: AutoEncoderKL + :members: + `VarAutoEncoder` ~~~~~~~~~~~~~~~~ .. autoclass:: VarAutoEncoder diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 9247aaee85..ea08246d25 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -14,6 +14,7 @@ from .ahnet import AHnet, Ahnet, AHNet from .attentionunet import AttentionUnet from .autoencoder import AutoEncoder +from .autoencoderkl import AutoencoderKL from .basic_unet import BasicUNet, BasicUnet, Basicunet, basicunet from .basic_unetplusplus import BasicUNetPlusPlus, BasicUnetPlusPlus, BasicunetPlusPlus, basicunetplusplus from .classifier import Classifier, Critic, Discriminator diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py new file mode 100644 index 0000000000..acdff6c96a --- /dev/null +++ b/monai/networks/nets/autoencoderkl.py @@ -0,0 +1,816 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import importlib.util +import math +from collections.abc import Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks import Convolution +from monai.utils import ensure_tuple_rep + +# To install xformers, use pip install xformers==0.0.16rc401 +if importlib.util.find_spec("xformers") is not None: + import xformers + import xformers.ops + + has_xformers = True +else: + xformers = None + has_xformers = False + +# TODO: Use MONAI's optional_import +# from monai.utils import optional_import +# xformers, has_xformers = optional_import("xformers.ops", name="xformers") + +__all__ = ["AutoencoderKL"] + + +class _Upsample(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Convolution-based upsampling layer. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: number of input channels to the layer. + use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. 
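+
+    A minimal shape sketch (illustrative only, per the support note above; the
+    values are arbitrary and the default interpolation path is assumed):
+
+        >>> import torch
+        >>> up = _Upsample(spatial_dims=2, in_channels=8, use_convtranspose=False)
+        >>> up(torch.randn(1, 8, 16, 16)).shape
+        torch.Size([1, 8, 32, 32])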
+ """ + + def __init__(self, spatial_dims: int, in_channels: int, use_convtranspose: bool) -> None: + super().__init__() + if use_convtranspose: + self.conv = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + strides=2, + kernel_size=3, + padding=1, + conv_only=True, + is_transposed=True, + ) + else: + self.conv = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.use_convtranspose = use_convtranspose + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_convtranspose: + return self.conv(x) + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # https://github.com/pytorch/pytorch/issues/86679 + dtype = x.dtype + if dtype == torch.bfloat16: + x = x.to(torch.float32) + + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + x = x.to(dtype) + + x = self.conv(x) + return x + + +class _Downsample(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Convolution-based downsampling layer. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: number of input channels. + """ + + def __init__(self, spatial_dims: int, in_channels: int) -> None: + super().__init__() + self.pad = (0, 1) * spatial_dims + + self.conv = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + strides=2, + kernel_size=3, + padding=0, + conv_only=True, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = nn.functional.pad(x, self.pad, mode="constant", value=0.0) + x = self.conv(x) + return x + + +class _ResBlock(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a + residual connection between input and output. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: input channels to the layer. + norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of + channels is divisible by this number. + norm_eps: epsilon for the normalisation. + out_channels: number of output channels. 
+ """ + + def __init__( + self, spatial_dims: int, in_channels: int, norm_num_groups: int, norm_eps: float, out_channels: int + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) + self.conv1 = Convolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=norm_eps, affine=True) + self.conv2 = Convolution( + spatial_dims=spatial_dims, + in_channels=self.out_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + if self.in_channels != self.out_channels: + self.nin_shortcut = Convolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + else: + self.nin_shortcut = nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + h = F.silu(h) + h = self.conv1(h) + + h = self.norm2(h) + h = F.silu(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class _AttentionBlock(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Attention block. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + num_channels: number of input channels. + num_head_channels: number of channels in each attention head. + norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of + channels is divisible by this number. + norm_eps: epsilon value to use for the normalisation. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + spatial_dims: int, + num_channels: int, + num_head_channels: int | None = None, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.use_flash_attention = use_flash_attention + self.spatial_dims = spatial_dims + self.num_channels = num_channels + + self.num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 + self.scale = 1 / math.sqrt(num_channels / self.num_heads) + + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True) + + self.to_q = nn.Linear(num_channels, num_channels) + self.to_k = nn.Linear(num_channels, num_channels) + self.to_v = nn.Linear(num_channels, num_channels) + + self.proj_attn = nn.Linear(num_channels, num_channels) + + def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: + """ + Divide hidden state dimension to the multiple attention heads and reshape their input as instances in the batch. 
+ """ + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) + x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) + return x + + def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: + """Combine the output of the attention heads back into the hidden state dimension.""" + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) + x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) + return x + + def _memory_efficient_attention_xformers( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) + return x + + def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + attention_probs = attention_scores.softmax(dim=-1) + x = torch.bmm(attention_probs, value) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + + batch = channel = height = width = depth = -1 + if self.spatial_dims == 2: + batch, channel, height, width = x.shape + if self.spatial_dims == 3: + batch, channel, height, width, depth = x.shape + + # norm + x = self.norm(x) + + if self.spatial_dims == 2: + x = x.view(batch, channel, height * width).transpose(1, 2) + if self.spatial_dims == 3: + x = x.view(batch, channel, height * width * depth).transpose(1, 2) + + # proj to q, k, v + query = self.to_q(x) + key = self.to_k(x) + value = self.to_v(x) + + # Multi-Head Attention + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if self.use_flash_attention: + x = self._memory_efficient_attention_xformers(query, key, value) + else: + x = self._attention(query, key, value) + + x = self.reshape_batch_dim_to_heads(x) + x = x.to(query.dtype) + + if self.spatial_dims == 2: + x = x.transpose(-1, -2).reshape(batch, channel, height, width) + if self.spatial_dims == 3: + x = x.transpose(-1, -2).reshape(batch, channel, height, width, depth) + + return x + residual + + +class Encoder(nn.Module): + """ + Convolutional cascade that downsamples the image into a spatial latent space. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: number of input channels. + num_channels: sequence of block output channels. + out_channels: number of channels in the bottom layer (latent space) of the autoencoder. + num_res_blocks: number of residual blocks (see _ResBlock) per level. + norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. + norm_eps: epsilon for the normalization. + attention_levels: indicate which level from num_channels contain an attention block. + with_nonlocal_attn: if True use non-local attention block. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
+ """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_channels: Sequence[int], + out_channels: int, + num_res_blocks: Sequence[int], + norm_num_groups: int, + norm_eps: float, + attention_levels: Sequence[bool], + with_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.num_channels = num_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.norm_num_groups = norm_num_groups + self.norm_eps = norm_eps + self.attention_levels = attention_levels + + blocks = [] + # Initial convolution + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=num_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + # Residual and downsampling blocks + output_channel = num_channels[0] + for i in range(len(num_channels)): + input_channel = output_channel + output_channel = num_channels[i] + is_final_block = i == len(num_channels) - 1 + + for _ in range(self.num_res_blocks[i]): + blocks.append( + _ResBlock( + spatial_dims=spatial_dims, + in_channels=input_channel, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=output_channel, + ) + ) + input_channel = output_channel + if attention_levels[i]: + blocks.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=input_channel, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + if not is_final_block: + blocks.append(_Downsample(spatial_dims=spatial_dims, in_channels=input_channel)) + + # Non-local attention block + if with_nonlocal_attn is True: + blocks.append( + _ResBlock( + spatial_dims=spatial_dims, + in_channels=num_channels[-1], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=num_channels[-1], + ) + ) + + blocks.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=num_channels[-1], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + blocks.append( + _ResBlock( + spatial_dims=spatial_dims, + in_channels=num_channels[-1], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=num_channels[-1], + ) + ) + # Normalise and convert to latent size + blocks.append( + nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels[-1], eps=norm_eps, affine=True) + ) + blocks.append( + Convolution( + spatial_dims=self.spatial_dims, + in_channels=num_channels[-1], + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + return x + + +class Decoder(nn.Module): + """ + Convolutional cascade upsampling from a spatial latent space into an image space. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + num_channels: sequence of block output channels. + in_channels: number of channels in the bottom layer (latent space) of the autoencoder. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see _ResBlock) per level. + norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. + norm_eps: epsilon for the normalization. + attention_levels: indicate which level from num_channels contain an attention block. 
+ with_nonlocal_attn: if True use non-local attention block. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. + """ + + def __init__( + self, + spatial_dims: int, + num_channels: Sequence[int], + in_channels: int, + out_channels: int, + num_res_blocks: Sequence[int], + norm_num_groups: int, + norm_eps: float, + attention_levels: Sequence[bool], + with_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + use_convtranspose: bool = False, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.num_channels = num_channels + self.in_channels = in_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.norm_num_groups = norm_num_groups + self.norm_eps = norm_eps + self.attention_levels = attention_levels + + reversed_block_out_channels = list(reversed(num_channels)) + + blocks = [] + # Initial convolution + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=reversed_block_out_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + # Non-local attention block + if with_nonlocal_attn is True: + blocks.append( + _ResBlock( + spatial_dims=spatial_dims, + in_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=reversed_block_out_channels[0], + ) + ) + blocks.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + blocks.append( + _ResBlock( + spatial_dims=spatial_dims, + in_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=reversed_block_out_channels[0], + ) + ) + + reversed_attention_levels = list(reversed(attention_levels)) + reversed_num_res_blocks = list(reversed(num_res_blocks)) + block_out_ch = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + block_in_ch = block_out_ch + block_out_ch = reversed_block_out_channels[i] + is_final_block = i == len(num_channels) - 1 + + for _ in range(reversed_num_res_blocks[i]): + blocks.append( + _ResBlock( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=block_out_ch, + ) + ) + block_in_ch = block_out_ch + + if reversed_attention_levels[i]: + blocks.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + if not is_final_block: + blocks.append( + _Upsample(spatial_dims=spatial_dims, in_channels=block_in_ch, use_convtranspose=use_convtranspose) + ) + + blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True)) + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + return x + + +class AutoencoderKL(nn.Module): + """ + Autoencoder model with KL-regularized latent space based on + Rombach et al. 
"High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752 + and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: number of input channels. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see _ResBlock) per level. + num_channels: sequence of block output channels. + attention_levels: sequence of levels to add attention. + latent_channels: latent embedding dimension. + norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. + norm_eps: epsilon for the normalization. + with_encoder_nonlocal_attn: if True use non-local attention block in the encoder. + with_decoder_nonlocal_attn: if True use non-local attention block in the decoder. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + use_checkpointing: if True, use activation checkpointing to save memory. + use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int = 1, + out_channels: int = 1, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + num_channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + latent_channels: int = 3, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + with_encoder_nonlocal_attn: bool = True, + with_decoder_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + use_checkpointing: bool = False, + use_convtranspose: bool = False, + ) -> None: + super().__init__() + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): + raise ValueError("AutoencoderKL expects all num_channels being multiple of norm_num_groups") + + if len(num_channels) != len(attention_levels): + raise ValueError("AutoencoderKL expects num_channels being same size of attention_levels") + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels)) + + if len(num_res_blocks) != len(num_channels): + raise ValueError( + "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " + "`num_channels`." + ) + + if use_flash_attention is True and not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." 
+ ) + + self.encoder = Encoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + num_channels=num_channels, + out_channels=latent_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + attention_levels=attention_levels, + with_nonlocal_attn=with_encoder_nonlocal_attn, + use_flash_attention=use_flash_attention, + ) + self.decoder = Decoder( + spatial_dims=spatial_dims, + num_channels=num_channels, + in_channels=latent_channels, + out_channels=out_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + attention_levels=attention_levels, + with_nonlocal_attn=with_decoder_nonlocal_attn, + use_flash_attention=use_flash_attention, + use_convtranspose=use_convtranspose, + ) + self.quant_conv_mu = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.quant_conv_log_sigma = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.post_quant_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.latent_channels = latent_channels + self.use_checkpointing = use_checkpointing + + def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forwards an image through the spatial encoder, obtaining the latent mean and sigma representations. + + Args: + x: BxCx[SPATIAL DIMS] tensor + + """ + if self.use_checkpointing: + h = torch.utils.checkpoint.checkpoint(self.encoder, x, use_reentrant=False) + else: + h = self.encoder(x) + + z_mu = self.quant_conv_mu(h) + z_log_var = self.quant_conv_log_sigma(h) + z_log_var = torch.clamp(z_log_var, -30.0, 20.0) + z_sigma = torch.exp(z_log_var / 2) + + return z_mu, z_sigma + + def sampling(self, z_mu: torch.Tensor, z_sigma: torch.Tensor) -> torch.Tensor: + """ + From the mean and sigma representations resulting of encoding an image through the latent space, + obtains a noise sample resulting from sampling gaussian noise, multiplying by the variance (sigma) and + adding the mean. + + Args: + z_mu: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] mean vector obtained by the encoder when you encode an image + z_sigma: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] variance vector obtained by the encoder when you encode an image + + Returns: + sample of shape Bx[Z_CHANNELS]x[LATENT SPACE SIZE] + """ + eps = torch.randn_like(z_sigma) + z_vae = z_mu + eps * z_sigma + return z_vae + + def reconstruct(self, x: torch.Tensor) -> torch.Tensor: + """ + Encodes and decodes an input image. + + Args: + x: BxCx[SPATIAL DIMENSIONS] tensor. + + Returns: + reconstructed image, of the same shape as input + """ + z_mu, _ = self.encode(x) + reconstruction = self.decode(z_mu) + return reconstruction + + def decode(self, z: torch.Tensor) -> torch.Tensor: + """ + Based on a latent space sample, forwards it through the Decoder. 
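+        The sample is first mapped through the post-quantization convolution and then through the decoder network.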
+ + Args: + z: Bx[Z_CHANNELS]x[LATENT SPACE SHAPE] + + Returns: + decoded image tensor + """ + z = self.post_quant_conv(z) + if self.use_checkpointing: + dec = torch.utils.checkpoint.checkpoint(self.decoder, z, use_reentrant=False) + else: + dec = self.decoder(z) + return dec + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + z_mu, z_sigma = self.encode(x) + z = self.sampling(z_mu, z_sigma) + reconstruction = self.decode(z) + return reconstruction, z_mu, z_sigma + + def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor: + z_mu, z_sigma = self.encode(x) + z = self.sampling(z_mu, z_sigma) + return z + + def decode_stage_2_outputs(self, z: torch.Tensor) -> torch.Tensor: + image = self.decode(z) + return image diff --git a/tests/test_autoencoderkl.py b/tests/test_autoencoderkl.py new file mode 100644 index 0000000000..58943a1626 --- /dev/null +++ b/tests/test_autoencoderkl.py @@ -0,0 +1,270 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import AutoencoderKL + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": (1, 1, 2), + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, True), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + }, + 
(1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, True), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16, 16), + (1, 1, 16, 16, 16), + (1, 4, 4, 4, 4), + ], +] + + +class TestAutoEncoderKL(unittest.TestCase): + @parameterized.expand(CASES) + def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape): + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.forward(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + self.assertEqual(result[2].shape, expected_latent_shape) + + @parameterized.expand(CASES) + def test_shape_with_convtranspose_and_checkpointing( + self, input_param, input_shape, expected_shape, expected_latent_shape + ): + input_param = input_param.copy() + input_param.update({"use_checkpointing": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.forward(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + self.assertEqual(result[2].shape, expected_latent_shape) + + def test_model_channels_not_multiple_of_norm_num_group(self): + with self.assertRaises(ValueError): + AutoencoderKL( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_channels=(24, 24, 24), + attention_levels=(False, False, False), + latent_channels=8, + num_res_blocks=1, + norm_num_groups=16, + ) + + def test_model_num_channels_not_same_size_of_attention_levels(self): + with self.assertRaises(ValueError): + AutoencoderKL( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_channels=(24, 24, 24), + attention_levels=(False, False), + latent_channels=8, + num_res_blocks=1, + norm_num_groups=16, + ) + + def test_model_num_channels_not_same_size_of_num_res_blocks(self): + with self.assertRaises(ValueError): + AutoencoderKL( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_channels=(24, 24, 24), + attention_levels=(False, False, False), + latent_channels=8, + num_res_blocks=(8, 8), + norm_num_groups=16, + ) + + def test_shape_reconstruction(self): + input_param, input_shape, expected_shape, _ = CASES[0] + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.reconstruct(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape) + + def test_shape_reconstruction_with_convtranspose_and_checkpointing(self): + input_param, input_shape, expected_shape, _ = CASES[0] + input_param = input_param.copy() + input_param.update({"use_checkpointing": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.reconstruct(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape) + + def test_shape_encode(self): + input_param, input_shape, _, expected_latent_shape = CASES[0] + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.encode(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_latent_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + + def test_shape_encode_with_convtranspose_and_checkpointing(self): + input_param, input_shape, _, expected_latent_shape = CASES[0] + input_param = 
input_param.copy() + input_param.update({"use_checkpointing": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.encode(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_latent_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + + def test_shape_sampling(self): + input_param, _, _, expected_latent_shape = CASES[0] + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.sampling( + torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device) + ) + self.assertEqual(result.shape, expected_latent_shape) + + def test_shape_sampling_convtranspose_and_checkpointing(self): + input_param, _, _, expected_latent_shape = CASES[0] + input_param = input_param.copy() + input_param.update({"use_checkpointing": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.sampling( + torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device) + ) + self.assertEqual(result.shape, expected_latent_shape) + + def test_shape_decode(self): + input_param, expected_input_shape, _, latent_shape = CASES[0] + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.decode(torch.randn(latent_shape).to(device)) + self.assertEqual(result.shape, expected_input_shape) + + def test_shape_decode_convtranspose_and_checkpointing(self): + input_param, expected_input_shape, _, latent_shape = CASES[0] + input_param = input_param.copy() + input_param.update({"use_checkpointing": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.decode(torch.randn(latent_shape).to(device)) + self.assertEqual(result.shape, expected_input_shape) + + +if __name__ == "__main__": + unittest.main() From 9d16b48033c2d368b6c80c5fa5b6e4581ad190f4 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 27 Nov 2023 11:02:27 +0000 Subject: [PATCH 02/12] Mypy fixes Signed-off-by: Mark Graham --- monai/networks/nets/autoencoderkl.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index acdff6c96a..7466f3dba6 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -14,6 +14,7 @@ import importlib.util import math from collections.abc import Sequence +from typing import List import torch import torch.nn as nn @@ -80,7 +81,8 @@ def __init__(self, spatial_dims: int, in_channels: int, use_convtranspose: bool) def forward(self, x: torch.Tensor) -> torch.Tensor: if self.use_convtranspose: - return self.conv(x) + conv: torch.Tensor = self.conv(x) + return conv # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 # https://github.com/pytorch/pytorch/issues/86679 @@ -177,6 +179,7 @@ def __init__( conv_only=True, ) + self.nin_shortcut: nn.Module if self.in_channels != self.out_channels: self.nin_shortcut = Convolution( spatial_dims=spatial_dims, @@ -271,7 +274,7 @@ def _memory_efficient_attention_xformers( query = query.contiguous() key = key.contiguous() value = value.contiguous() - x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) + x: torch.Tensor = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) return x def _attention(self, query: torch.Tensor, key: torch.Tensor, value: 
torch.Tensor) -> torch.Tensor: @@ -369,7 +372,7 @@ def __init__( self.norm_eps = norm_eps self.attention_levels = attention_levels - blocks = [] + blocks: List[nn.Module] = [] # Initial convolution blocks.append( Convolution( @@ -513,7 +516,8 @@ def __init__( reversed_block_out_channels = list(reversed(num_channels)) - blocks = [] + blocks: List[nn.Module] = [] + # Initial convolution blocks.append( Convolution( @@ -794,6 +798,7 @@ def decode(self, z: torch.Tensor) -> torch.Tensor: decoded image tensor """ z = self.post_quant_conv(z) + dec: torch.Tensor if self.use_checkpointing: dec = torch.utils.checkpoint.checkpoint(self.decoder, z, use_reentrant=False) else: From fbfd8f588849603506dbe9b0a39bae226e349536 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 27 Nov 2023 11:17:22 +0000 Subject: [PATCH 03/12] Fix doc Signed-off-by: Mark Graham --- docs/source/networks.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 45a9d33388..dbfdf35784 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -596,8 +596,8 @@ Nets :members: `AutoEncoderKL` -~~~~~~~~~~~~~ -.. autoclass:: AutoEncoderKL +~~~~~~~~~~~~~~~ +.. autoclass:: AutoencoderKL :members: `VarAutoEncoder` From cd60374126c4532085f94917ac13ada59f248cb9 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 29 Nov 2023 08:27:24 -0600 Subject: [PATCH 04/12] Update monai/networks/nets/autoencoderkl.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: Mark Graham --- monai/networks/nets/autoencoderkl.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index 7466f3dba6..6d07da2bf0 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -203,8 +203,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: h = F.silu(h) h = self.conv2(h) - if self.in_channels != self.out_channels: - x = self.nin_shortcut(x) + x = self.nin_shortcut(x) return x + h From b9798b3c9737bd752dd7a9e13468098ec3c2652f Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 29 Nov 2023 08:42:42 -0600 Subject: [PATCH 05/12] Update monai/networks/nets/autoencoderkl.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mark Graham --- monai/networks/nets/autoencoderkl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index 6d07da2bf0..3a83627895 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -636,7 +636,7 @@ class AutoencoderKL(nn.Module): with_encoder_nonlocal_attn: if True use non-local attention block in the encoder. with_decoder_nonlocal_attn: if True use non-local attention block in the decoder. use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - use_checkpointing: if True, use activation checkpointing to save memory. + use_checkpoint: if True, use activation checkpoint to save memory. use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. 
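+
+    A minimal usage sketch (illustrative; the small configuration mirrors the unit tests):
+
+        >>> import torch
+        >>> net = AutoencoderKL(
+        ...     spatial_dims=2, in_channels=1, out_channels=1, num_channels=(4, 4, 4),
+        ...     attention_levels=(False, False, False), latent_channels=4,
+        ...     num_res_blocks=1, norm_num_groups=4,
+        ... )
+        >>> reconstruction, z_mu, z_sigma = net(torch.randn(1, 1, 16, 16))
+        >>> reconstruction.shape, z_mu.shape
+        (torch.Size([1, 1, 16, 16]), torch.Size([1, 4, 4, 4]))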
""" From a6bb9258da10b701c1c5a27076a214f9f259b8d6 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 29 Nov 2023 14:50:46 +0000 Subject: [PATCH 06/12] Rename num_channels to channels --- monai/networks/nets/autoencoderkl.py | 56 ++++++++++++++-------------- tests/test_autoencoderkl.py | 20 +++++----- 2 files changed, 38 insertions(+), 38 deletions(-) diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index 3a83627895..6e70e0da25 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -338,7 +338,7 @@ class Encoder(nn.Module): Args: spatial_dims: number of spatial dimensions (1D, 2D, 3D). in_channels: number of input channels. - num_channels: sequence of block output channels. + channels: sequence of block output channels. out_channels: number of channels in the bottom layer (latent space) of the autoencoder. num_res_blocks: number of residual blocks (see _ResBlock) per level. norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. @@ -352,7 +352,7 @@ def __init__( self, spatial_dims: int, in_channels: int, - num_channels: Sequence[int], + channels: Sequence[int], out_channels: int, num_res_blocks: Sequence[int], norm_num_groups: int, @@ -364,7 +364,7 @@ def __init__( super().__init__() self.spatial_dims = spatial_dims self.in_channels = in_channels - self.num_channels = num_channels + self.channels = channels self.out_channels = out_channels self.num_res_blocks = num_res_blocks self.norm_num_groups = norm_num_groups @@ -377,7 +377,7 @@ def __init__( Convolution( spatial_dims=spatial_dims, in_channels=in_channels, - out_channels=num_channels[0], + out_channels=channels[0], strides=1, kernel_size=3, padding=1, @@ -386,11 +386,11 @@ def __init__( ) # Residual and downsampling blocks - output_channel = num_channels[0] - for i in range(len(num_channels)): + output_channel = channels[0] + for i in range(len(channels)): input_channel = output_channel - output_channel = num_channels[i] - is_final_block = i == len(num_channels) - 1 + output_channel = channels[i] + is_final_block = i == len(channels) - 1 for _ in range(self.num_res_blocks[i]): blocks.append( @@ -422,17 +422,17 @@ def __init__( blocks.append( _ResBlock( spatial_dims=spatial_dims, - in_channels=num_channels[-1], + in_channels=channels[-1], norm_num_groups=norm_num_groups, norm_eps=norm_eps, - out_channels=num_channels[-1], + out_channels=channels[-1], ) ) blocks.append( _AttentionBlock( spatial_dims=spatial_dims, - num_channels=num_channels[-1], + num_channels=channels[-1], norm_num_groups=norm_num_groups, norm_eps=norm_eps, use_flash_attention=use_flash_attention, @@ -441,20 +441,20 @@ def __init__( blocks.append( _ResBlock( spatial_dims=spatial_dims, - in_channels=num_channels[-1], + in_channels=channels[-1], norm_num_groups=norm_num_groups, norm_eps=norm_eps, - out_channels=num_channels[-1], + out_channels=channels[-1], ) ) # Normalise and convert to latent size blocks.append( - nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels[-1], eps=norm_eps, affine=True) + nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels[-1], eps=norm_eps, affine=True) ) blocks.append( Convolution( spatial_dims=self.spatial_dims, - in_channels=num_channels[-1], + in_channels=channels[-1], out_channels=out_channels, strides=1, kernel_size=3, @@ -477,7 +477,7 @@ class Decoder(nn.Module): Args: spatial_dims: number of spatial dimensions (1D, 2D, 3D). - num_channels: sequence of block output channels. 
+ channels: sequence of block output channels. in_channels: number of channels in the bottom layer (latent space) of the autoencoder. out_channels: number of output channels. num_res_blocks: number of residual blocks (see _ResBlock) per level. @@ -492,7 +492,7 @@ class Decoder(nn.Module): def __init__( self, spatial_dims: int, - num_channels: Sequence[int], + channels: Sequence[int], in_channels: int, out_channels: int, num_res_blocks: Sequence[int], @@ -505,7 +505,7 @@ def __init__( ) -> None: super().__init__() self.spatial_dims = spatial_dims - self.num_channels = num_channels + self.channels = channels self.in_channels = in_channels self.out_channels = out_channels self.num_res_blocks = num_res_blocks @@ -513,7 +513,7 @@ def __init__( self.norm_eps = norm_eps self.attention_levels = attention_levels - reversed_block_out_channels = list(reversed(num_channels)) + reversed_block_out_channels = list(reversed(channels)) blocks: List[nn.Module] = [] @@ -566,7 +566,7 @@ def __init__( for i in range(len(reversed_block_out_channels)): block_in_ch = block_out_ch block_out_ch = reversed_block_out_channels[i] - is_final_block = i == len(num_channels) - 1 + is_final_block = i == len(channels) - 1 for _ in range(reversed_num_res_blocks[i]): blocks.append( @@ -628,7 +628,7 @@ class AutoencoderKL(nn.Module): in_channels: number of input channels. out_channels: number of output channels. num_res_blocks: number of residual blocks (see _ResBlock) per level. - num_channels: sequence of block output channels. + channels: number of output channels for each block. attention_levels: sequence of levels to add attention. latent_channels: latent embedding dimension. norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. @@ -646,7 +646,7 @@ def __init__( in_channels: int = 1, out_channels: int = 1, num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), - num_channels: Sequence[int] = (32, 64, 64, 64), + channels: Sequence[int] = (32, 64, 64, 64), attention_levels: Sequence[bool] = (False, False, True, True), latent_channels: int = 3, norm_num_groups: int = 32, @@ -660,16 +660,16 @@ def __init__( super().__init__() # All number of channels should be multiple of num_groups - if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): + if any((out_channel % norm_num_groups) != 0 for out_channel in channels): raise ValueError("AutoencoderKL expects all num_channels being multiple of norm_num_groups") - if len(num_channels) != len(attention_levels): + if len(channels) != len(attention_levels): raise ValueError("AutoencoderKL expects num_channels being same size of attention_levels") if isinstance(num_res_blocks, int): - num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels)) + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) - if len(num_res_blocks) != len(num_channels): + if len(num_res_blocks) != len(channels): raise ValueError( "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " "`num_channels`." 
@@ -683,7 +683,7 @@ def __init__( self.encoder = Encoder( spatial_dims=spatial_dims, in_channels=in_channels, - num_channels=num_channels, + channels=channels, out_channels=latent_channels, num_res_blocks=num_res_blocks, norm_num_groups=norm_num_groups, @@ -694,7 +694,7 @@ def __init__( ) self.decoder = Decoder( spatial_dims=spatial_dims, - num_channels=num_channels, + channels=channels, in_channels=latent_channels, out_channels=out_channels, num_res_blocks=num_res_blocks, diff --git a/tests/test_autoencoderkl.py b/tests/test_autoencoderkl.py index 58943a1626..dfe6dc2b56 100644 --- a/tests/test_autoencoderkl.py +++ b/tests/test_autoencoderkl.py @@ -27,7 +27,7 @@ "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_channels": (4, 4, 4), + "channels": (4, 4, 4), "latent_channels": 4, "attention_levels": (False, False, False), "num_res_blocks": 1, @@ -42,7 +42,7 @@ "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_channels": (4, 4, 4), + "channels": (4, 4, 4), "latent_channels": 4, "attention_levels": (False, False, False), "num_res_blocks": (1, 1, 2), @@ -57,7 +57,7 @@ "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_channels": (4, 4, 4), + "channels": (4, 4, 4), "latent_channels": 4, "attention_levels": (False, False, False), "num_res_blocks": 1, @@ -72,7 +72,7 @@ "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_channels": (4, 4, 4), + "channels": (4, 4, 4), "latent_channels": 4, "attention_levels": (False, False, True), "num_res_blocks": 1, @@ -87,7 +87,7 @@ "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_channels": (4, 4, 4), + "channels": (4, 4, 4), "latent_channels": 4, "attention_levels": (False, False, False), "num_res_blocks": 1, @@ -103,7 +103,7 @@ "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_channels": (4, 4, 4), + "channels": (4, 4, 4), "latent_channels": 4, "attention_levels": (False, False, False), "num_res_blocks": 1, @@ -120,7 +120,7 @@ "spatial_dims": 3, "in_channels": 1, "out_channels": 1, - "num_channels": (4, 4, 4), + "channels": (4, 4, 4), "latent_channels": 4, "attention_levels": (False, False, True), "num_res_blocks": 1, @@ -162,7 +162,7 @@ def test_model_channels_not_multiple_of_norm_num_group(self): spatial_dims=2, in_channels=1, out_channels=1, - num_channels=(24, 24, 24), + channels=(24, 24, 24), attention_levels=(False, False, False), latent_channels=8, num_res_blocks=1, @@ -175,7 +175,7 @@ def test_model_num_channels_not_same_size_of_attention_levels(self): spatial_dims=2, in_channels=1, out_channels=1, - num_channels=(24, 24, 24), + channels=(24, 24, 24), attention_levels=(False, False), latent_channels=8, num_res_blocks=1, @@ -188,7 +188,7 @@ def test_model_num_channels_not_same_size_of_num_res_blocks(self): spatial_dims=2, in_channels=1, out_channels=1, - num_channels=(24, 24, 24), + channels=(24, 24, 24), attention_levels=(False, False, False), latent_channels=8, num_res_blocks=(8, 8), From 4caafe0423e3a162de85f4cb53dd7818d1d08ab1 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 29 Nov 2023 14:51:05 +0000 Subject: [PATCH 07/12] Use optional import --- monai/networks/nets/autoencoderkl.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index 6e70e0da25..55ac5635d7 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -11,7 +11,6 @@ from __future__ import annotations -import importlib.util import math from collections.abc 
import Sequence from typing import List @@ -21,21 +20,11 @@ import torch.nn.functional as F from monai.networks.blocks import Convolution -from monai.utils import ensure_tuple_rep # To install xformers, use pip install xformers==0.0.16rc401 -if importlib.util.find_spec("xformers") is not None: - import xformers - import xformers.ops - - has_xformers = True -else: - xformers = None - has_xformers = False - -# TODO: Use MONAI's optional_import -# from monai.utils import optional_import -# xformers, has_xformers = optional_import("xformers.ops", name="xformers") +from monai.utils import ensure_tuple_rep, optional_import + +xformers, has_xformers = optional_import("xformers") __all__ = ["AutoencoderKL"] From 2bf01697ab887a93078623c4d6ece7a95c81728e Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 29 Nov 2023 14:53:11 +0000 Subject: [PATCH 08/12] DCO Remediation Commit for Mark Graham I, Mark Graham , hereby add my Signed-off-by to this commit: d99044bb91f4bbd09354f674d74aa4238139a422 I, Mark Graham , hereby add my Signed-off-by to this commit: a6bb9258da10b701c1c5a27076a214f9f259b8d6 I, Mark Graham , hereby add my Signed-off-by to this commit: 4caafe0423e3a162de85f4cb53dd7818d1d08ab1 Signed-off-by: Mark Graham --- monai/networks/nets/autoencoderkl.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index 55ac5635d7..2643f0889a 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -437,9 +437,7 @@ def __init__( ) ) # Normalise and convert to latent size - blocks.append( - nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels[-1], eps=norm_eps, affine=True) - ) + blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels[-1], eps=norm_eps, affine=True)) blocks.append( Convolution( spatial_dims=self.spatial_dims, From 6a9f6572ea9c96ded8b0a3477b597b7dd05e808b Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 29 Nov 2023 14:55:55 +0000 Subject: [PATCH 09/12] Updates use_checkpointing to use_checkpoint Signed-off-by: Mark Graham --- monai/networks/nets/autoencoderkl.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index 2643f0889a..9a9f35d5ae 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -641,7 +641,7 @@ def __init__( with_encoder_nonlocal_attn: bool = True, with_decoder_nonlocal_attn: bool = True, use_flash_attention: bool = False, - use_checkpointing: bool = False, + use_checkpoint: bool = False, use_convtranspose: bool = False, ) -> None: super().__init__() @@ -720,7 +720,7 @@ def __init__( conv_only=True, ) self.latent_channels = latent_channels - self.use_checkpointing = use_checkpointing + self.use_checkpoint = use_checkpoint def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -730,7 +730,7 @@ def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: x: BxCx[SPATIAL DIMS] tensor """ - if self.use_checkpointing: + if self.use_checkpoint: h = torch.utils.checkpoint.checkpoint(self.encoder, x, use_reentrant=False) else: h = self.encoder(x) @@ -785,7 +785,7 @@ def decode(self, z: torch.Tensor) -> torch.Tensor: """ z = self.post_quant_conv(z) dec: torch.Tensor - if self.use_checkpointing: + if self.use_checkpoint: dec = torch.utils.checkpoint.checkpoint(self.decoder, z, use_reentrant=False) else: dec = self.decoder(z) From 
3e7dc6cac4fe92322af01f1708a00ec762f2fdb2 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 29 Nov 2023 15:06:04 +0000 Subject: [PATCH 10/12] Checkpoint fixes Signed-off-by: Mark Graham --- tests/test_autoencoderkl.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/test_autoencoderkl.py b/tests/test_autoencoderkl.py index dfe6dc2b56..771b78d649 100644 --- a/tests/test_autoencoderkl.py +++ b/tests/test_autoencoderkl.py @@ -18,6 +18,7 @@ from monai.networks import eval_mode from monai.networks.nets import AutoencoderKL +from tests.utils import SkipIfBeforePyTorchVersion device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -143,12 +144,13 @@ def test_shape(self, input_param, input_shape, expected_shape, expected_latent_s self.assertEqual(result[1].shape, expected_latent_shape) self.assertEqual(result[2].shape, expected_latent_shape) + @SkipIfBeforePyTorchVersion((1, 10)) @parameterized.expand(CASES) def test_shape_with_convtranspose_and_checkpointing( self, input_param, input_shape, expected_shape, expected_latent_shape ): input_param = input_param.copy() - input_param.update({"use_checkpointing": True, "use_convtranspose": True}) + input_param.update({"use_checkpoint": True, "use_convtranspose": True}) net = AutoencoderKL(**input_param).to(device) with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) @@ -202,10 +204,11 @@ def test_shape_reconstruction(self): result = net.reconstruct(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) + @SkipIfBeforePyTorchVersion((1, 10)) def test_shape_reconstruction_with_convtranspose_and_checkpointing(self): input_param, input_shape, expected_shape, _ = CASES[0] input_param = input_param.copy() - input_param.update({"use_checkpointing": True, "use_convtranspose": True}) + input_param.update({"use_checkpoint": True, "use_convtranspose": True}) net = AutoencoderKL(**input_param).to(device) with eval_mode(net): result = net.reconstruct(torch.randn(input_shape).to(device)) @@ -219,10 +222,11 @@ def test_shape_encode(self): self.assertEqual(result[0].shape, expected_latent_shape) self.assertEqual(result[1].shape, expected_latent_shape) + @SkipIfBeforePyTorchVersion((1, 10)) def test_shape_encode_with_convtranspose_and_checkpointing(self): input_param, input_shape, _, expected_latent_shape = CASES[0] input_param = input_param.copy() - input_param.update({"use_checkpointing": True, "use_convtranspose": True}) + input_param.update({"use_checkpoint": True, "use_convtranspose": True}) net = AutoencoderKL(**input_param).to(device) with eval_mode(net): result = net.encode(torch.randn(input_shape).to(device)) @@ -238,10 +242,11 @@ def test_shape_sampling(self): ) self.assertEqual(result.shape, expected_latent_shape) + @SkipIfBeforePyTorchVersion((1, 10)) def test_shape_sampling_convtranspose_and_checkpointing(self): input_param, _, _, expected_latent_shape = CASES[0] input_param = input_param.copy() - input_param.update({"use_checkpointing": True, "use_convtranspose": True}) + input_param.update({"use_checkpoint": True, "use_convtranspose": True}) net = AutoencoderKL(**input_param).to(device) with eval_mode(net): result = net.sampling( @@ -256,10 +261,11 @@ def test_shape_decode(self): result = net.decode(torch.randn(latent_shape).to(device)) self.assertEqual(result.shape, expected_input_shape) + @SkipIfBeforePyTorchVersion((1, 10)) def test_shape_decode_convtranspose_and_checkpointing(self): input_param, expected_input_shape, _, latent_shape = 
CASES[0] input_param = input_param.copy() - input_param.update({"use_checkpointing": True, "use_convtranspose": True}) + input_param.update({"use_checkpoint": True, "use_convtranspose": True}) net = AutoencoderKL(**input_param).to(device) with eval_mode(net): result = net.decode(torch.randn(latent_shape).to(device)) From d4302d964cf5aa63372a0becb4b5b4d4af6962e1 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 29 Nov 2023 15:19:24 +0000 Subject: [PATCH 11/12] Change decorator order so test skipping works Signed-off-by: Mark Graham --- tests/test_autoencoderkl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_autoencoderkl.py b/tests/test_autoencoderkl.py index 771b78d649..5978a03aa7 100644 --- a/tests/test_autoencoderkl.py +++ b/tests/test_autoencoderkl.py @@ -144,8 +144,8 @@ def test_shape(self, input_param, input_shape, expected_shape, expected_latent_s self.assertEqual(result[1].shape, expected_latent_shape) self.assertEqual(result[2].shape, expected_latent_shape) - @SkipIfBeforePyTorchVersion((1, 10)) @parameterized.expand(CASES) + @SkipIfBeforePyTorchVersion((1, 10)) def test_shape_with_convtranspose_and_checkpointing( self, input_param, input_shape, expected_shape, expected_latent_shape ): From 68171798d703b822e59676bababc9b4e9e968097 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 29 Nov 2023 15:58:49 +0000 Subject: [PATCH 12/12] Version fix Signed-off-by: Mark Graham --- tests/test_autoencoderkl.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_autoencoderkl.py b/tests/test_autoencoderkl.py index 5978a03aa7..448f1e8e9a 100644 --- a/tests/test_autoencoderkl.py +++ b/tests/test_autoencoderkl.py @@ -145,7 +145,7 @@ def test_shape(self, input_param, input_shape, expected_shape, expected_latent_s self.assertEqual(result[2].shape, expected_latent_shape) @parameterized.expand(CASES) - @SkipIfBeforePyTorchVersion((1, 10)) + @SkipIfBeforePyTorchVersion((1, 11)) def test_shape_with_convtranspose_and_checkpointing( self, input_param, input_shape, expected_shape, expected_latent_shape ): @@ -204,7 +204,7 @@ def test_shape_reconstruction(self): result = net.reconstruct(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) - @SkipIfBeforePyTorchVersion((1, 10)) + @SkipIfBeforePyTorchVersion((1, 11)) def test_shape_reconstruction_with_convtranspose_and_checkpointing(self): input_param, input_shape, expected_shape, _ = CASES[0] input_param = input_param.copy() @@ -222,7 +222,7 @@ def test_shape_encode(self): self.assertEqual(result[0].shape, expected_latent_shape) self.assertEqual(result[1].shape, expected_latent_shape) - @SkipIfBeforePyTorchVersion((1, 10)) + @SkipIfBeforePyTorchVersion((1, 11)) def test_shape_encode_with_convtranspose_and_checkpointing(self): input_param, input_shape, _, expected_latent_shape = CASES[0] input_param = input_param.copy() @@ -242,7 +242,7 @@ def test_shape_sampling(self): ) self.assertEqual(result.shape, expected_latent_shape) - @SkipIfBeforePyTorchVersion((1, 10)) + @SkipIfBeforePyTorchVersion((1, 11)) def test_shape_sampling_convtranspose_and_checkpointing(self): input_param, _, _, expected_latent_shape = CASES[0] input_param = input_param.copy() @@ -261,7 +261,7 @@ def test_shape_decode(self): result = net.decode(torch.randn(latent_shape).to(device)) self.assertEqual(result.shape, expected_input_shape) - @SkipIfBeforePyTorchVersion((1, 10)) + @SkipIfBeforePyTorchVersion((1, 11)) def test_shape_decode_convtranspose_and_checkpointing(self): input_param, 
expected_input_shape, _, latent_shape = CASES[0]
         input_param = input_param.copy()