From 281cb0119c01eaa8e6c841880b91f92f45e8d7f7 Mon Sep 17 00:00:00 2001
From: NoTody <88493484+NoTody@users.noreply.github.com>
Date: Sat, 16 Sep 2023 10:05:52 -0700
Subject: [PATCH] 6973 sincos pos embed (#6986)

Fixes #6973

### Description

Adding support for sincos positional embedding to the
monai.networks.blocks.patchembedding.PatchEmbeddingBlock class. This pull
request corresponds to the issue opened at
https://github.com/Project-MONAI/MONAI/issues/6973

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [x] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: NoTody
---
 monai/networks/blocks/mlp.py               |   2 +-
 monai/networks/blocks/patchembedding.py    |  51 ++++++++---
 monai/networks/blocks/pos_embed_utils.py   | 103 ++++++++++++++++++++++
 monai/networks/blocks/selfattention.py     |   2 +-
 monai/networks/blocks/transformerblock.py  |   2 +-
 monai/networks/nets/transchex.py           |   2 +-
 monai/networks/nets/unetr.py               |  13 ++-
 monai/networks/nets/vit.py                 |  22 +++--
 monai/networks/nets/vitautoenc.py          |  17 ++--
 tests/test_patchembedding.py               |  54 ++++++++----
 10 files changed, 215 insertions(+), 53 deletions(-)
 create mode 100644 monai/networks/blocks/pos_embed_utils.py

diff --git a/monai/networks/blocks/mlp.py b/monai/networks/blocks/mlp.py
index e3ab94b32a..d3510b64d3 100644
--- a/monai/networks/blocks/mlp.py
+++ b/monai/networks/blocks/mlp.py
@@ -32,7 +32,7 @@ def __init__(
         Args:
             hidden_size: dimension of hidden layer.
             mlp_dim: dimension of feedforward layer. If 0, `hidden_size` will be used.
-            dropout_rate: faction of the input units to drop.
+            dropout_rate: fraction of the input units to drop.
             act: activation type and arguments. Defaults to GELU. Also supports "GEGLU" and others.
             dropout_mode: dropout mode, can be "vit" or "swin".
"vit" mode uses two dropout instances as implemented in diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py index 57c0c5ee02..f6d390692e 100644 --- a/monai/networks/blocks/patchembedding.py +++ b/monai/networks/blocks/patchembedding.py @@ -19,12 +19,14 @@ import torch.nn.functional as F from torch.nn import LayerNorm +from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding from monai.networks.layers import Conv, trunc_normal_ -from monai.utils import ensure_tuple_rep, optional_import +from monai.utils import deprecated_arg, ensure_tuple_rep, optional_import from monai.utils.module import look_up_option Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") -SUPPORTED_EMBEDDING_TYPES = {"conv", "perceptron"} +SUPPORTED_PATCH_EMBEDDING_TYPES = {"conv", "perceptron"} +SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos"} class PatchEmbeddingBlock(nn.Module): @@ -35,10 +37,12 @@ class PatchEmbeddingBlock(nn.Module): Example:: >>> from monai.networks.blocks import PatchEmbeddingBlock - >>> PatchEmbeddingBlock(in_channels=4, img_size=32, patch_size=8, hidden_size=32, num_heads=4, pos_embed="conv") + >>> PatchEmbeddingBlock(in_channels=4, img_size=32, patch_size=8, hidden_size=32, num_heads=4, + >>> proj_type="conv", pos_embed_type="sincos") """ + @deprecated_arg(name="pos_embed", since="1.2", new_name="proj_type", msg_suffix="please use `proj_type` instead.") def __init__( self, in_channels: int, @@ -46,7 +50,9 @@ def __init__( patch_size: Sequence[int] | int, hidden_size: int, num_heads: int, - pos_embed: str, + pos_embed: str = "conv", + proj_type: str = "conv", + pos_embed_type: str = "learnable", dropout_rate: float = 0.0, spatial_dims: int = 3, ) -> None: @@ -57,11 +63,12 @@ def __init__( patch_size: dimension of patch size. hidden_size: dimension of hidden layer. num_heads: number of attention heads. - pos_embed: position embedding layer type. - dropout_rate: faction of the input units to drop. + proj_type: patch embedding layer type. + pos_embed_type: position embedding layer type. + dropout_rate: fraction of the input units to drop. spatial_dims: number of spatial dimensions. - - + .. deprecated:: 1.4 + ``pos_embed`` is deprecated in favor of ``proj_type``. 
""" super().__init__() @@ -72,24 +79,25 @@ def __init__( if hidden_size % num_heads != 0: raise ValueError(f"hidden size {hidden_size} should be divisible by num_heads {num_heads}.") - self.pos_embed = look_up_option(pos_embed, SUPPORTED_EMBEDDING_TYPES) + self.proj_type = look_up_option(proj_type, SUPPORTED_PATCH_EMBEDDING_TYPES) + self.pos_embed_type = look_up_option(pos_embed_type, SUPPORTED_POS_EMBEDDING_TYPES) img_size = ensure_tuple_rep(img_size, spatial_dims) patch_size = ensure_tuple_rep(patch_size, spatial_dims) for m, p in zip(img_size, patch_size): if m < p: raise ValueError("patch_size should be smaller than img_size.") - if self.pos_embed == "perceptron" and m % p != 0: + if self.proj_type == "perceptron" and m % p != 0: raise ValueError("patch_size should be divisible by img_size for perceptron.") self.n_patches = np.prod([im_d // p_d for im_d, p_d in zip(img_size, patch_size)]) self.patch_dim = int(in_channels * np.prod(patch_size)) self.patch_embeddings: nn.Module - if self.pos_embed == "conv": + if self.proj_type == "conv": self.patch_embeddings = Conv[Conv.CONV, spatial_dims]( in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size ) - elif self.pos_embed == "perceptron": + elif self.proj_type == "perceptron": # for 3d: "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)" chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims] from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars) @@ -100,7 +108,22 @@ def __init__( ) self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size)) self.dropout = nn.Dropout(dropout_rate) - trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0) + + if self.pos_embed_type == "none": + pass + elif self.pos_embed_type == "learnable": + trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0) + elif self.pos_embed_type == "sincos": + grid_size = [] + for in_size, pa_size in zip(img_size, patch_size): + grid_size.append(in_size // pa_size) + + with torch.no_grad(): + pos_embeddings = build_sincos_position_embedding(grid_size, hidden_size, spatial_dims) + self.position_embeddings.data.copy_(pos_embeddings.float()) + else: + raise ValueError(f"pos_embed_type {self.pos_embed_type} not supported.") + self.apply(self._init_weights) def _init_weights(self, m): @@ -114,7 +137,7 @@ def _init_weights(self, m): def forward(self, x): x = self.patch_embeddings(x) - if self.pos_embed == "conv": + if self.proj_type == "conv": x = x.flatten(2).transpose(-1, -2) embeddings = x + self.position_embeddings embeddings = self.dropout(embeddings) diff --git a/monai/networks/blocks/pos_embed_utils.py b/monai/networks/blocks/pos_embed_utils.py new file mode 100644 index 0000000000..e1f47cd7e9 --- /dev/null +++ b/monai/networks/blocks/pos_embed_utils.py @@ -0,0 +1,103 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from __future__ import annotations
+
+import collections.abc
+from itertools import repeat
+from typing import List, Union
+
+import torch
+import torch.nn as nn
+
+__all__ = ["build_sincos_position_embedding"]
+
+
+# From PyTorch internals
+def _ntuple(n):
+    def parse(x):
+        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+            return tuple(x)
+        return tuple(repeat(x, n))
+
+    return parse
+
+
+def build_sincos_position_embedding(
+    grid_size: Union[int, List[int]], embed_dim: int, spatial_dims: int = 3, temperature: float = 10000.0
+) -> torch.nn.Parameter:
+    """
+    Builds a sin-cos position embedding based on the given grid size, embed dimension, spatial dimensions, and temperature.
+    Reference: https://github.com/cvlab-stonybrook/SelfMedMAE/blob/68d191dfcc1c7d0145db93a6a570362de29e3b30/lib/models/mae3d.py
+
+    Args:
+        grid_size (List[int]): The size of the grid in each spatial dimension.
+        embed_dim (int): The dimension of the embedding.
+        spatial_dims (int): The number of spatial dimensions (2 for 2D, 3 for 3D).
+        temperature (float): The temperature for the sin-cos position embedding.
+
+    Returns:
+        pos_embed (nn.Parameter): The sin-cos position embedding as a fixed (non-trainable) parameter.
+    """
+
+    if spatial_dims == 2:
+        to_2tuple = _ntuple(2)
+        grid_size_t = to_2tuple(grid_size)
+        h, w = grid_size_t
+        grid_h = torch.arange(h, dtype=torch.float32)
+        grid_w = torch.arange(w, dtype=torch.float32)
+
+        grid_h, grid_w = torch.meshgrid(grid_h, grid_w, indexing="ij")
+
+        assert embed_dim % 4 == 0, "Embed dimension must be divisible by 4 for 2D sin-cos position embedding"
+
+        pos_dim = embed_dim // 4
+        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
+        omega = 1.0 / (temperature**omega)
+        out_h = torch.einsum("m,d->md", [grid_h.flatten(), omega])
+        out_w = torch.einsum("m,d->md", [grid_w.flatten(), omega])
+        pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
+    elif spatial_dims == 3:
+        to_3tuple = _ntuple(3)
+        grid_size_t = to_3tuple(grid_size)
+        h, w, d = grid_size_t
+        grid_h = torch.arange(h, dtype=torch.float32)
+        grid_w = torch.arange(w, dtype=torch.float32)
+        grid_d = torch.arange(d, dtype=torch.float32)
+
+        grid_h, grid_w, grid_d = torch.meshgrid(grid_h, grid_w, grid_d, indexing="ij")
+
+        assert embed_dim % 6 == 0, "Embed dimension must be divisible by 6 for 3D sin-cos position embedding"
+
+        pos_dim = embed_dim // 6
+        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
+        omega = 1.0 / (temperature**omega)
+        out_h = torch.einsum("m,d->md", [grid_h.flatten(), omega])
+        out_w = torch.einsum("m,d->md", [grid_w.flatten(), omega])
+        out_d = torch.einsum("m,d->md", [grid_d.flatten(), omega])
+        pos_emb = torch.cat(
+            [
+                torch.sin(out_w),
+                torch.cos(out_w),
+                torch.sin(out_h),
+                torch.cos(out_h),
+                torch.sin(out_d),
+                torch.cos(out_d),
+            ],
+            dim=1,
+        )[None, :, :]
+    else:
+        raise NotImplementedError(f"Spatial Dimension Size {spatial_dims} Not Implemented!")
+
+    pos_embed = nn.Parameter(pos_emb)
+    pos_embed.requires_grad = False
+
+    return pos_embed
diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py
index 71fb549db8..7c81c1704f 100644
--- a/monai/networks/blocks/selfattention.py
+++ b/monai/networks/blocks/selfattention.py
@@ -37,7 +37,7 @@ def __init__(
         Args:
             hidden_size (int): dimension of hidden layer.
             num_heads (int): number of attention heads.
-            dropout_rate (float, optional): faction of the input units to drop. Defaults to 0.0.
+            dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
             qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
             save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.

diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py
index f7d4e0e130..ddf959dad2 100644
--- a/monai/networks/blocks/transformerblock.py
+++ b/monai/networks/blocks/transformerblock.py
@@ -37,7 +37,7 @@ def __init__(
             hidden_size (int): dimension of hidden layer.
             mlp_dim (int): dimension of feedforward layer.
             num_heads (int): number of attention heads.
-            dropout_rate (float, optional): faction of the input units to drop. Defaults to 0.0.
+            dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
             qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False.
             save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.

diff --git a/monai/networks/nets/transchex.py b/monai/networks/nets/transchex.py
index 31e27ffbf2..c73415b63e 100644
--- a/monai/networks/nets/transchex.py
+++ b/monai/networks/nets/transchex.py
@@ -314,7 +314,7 @@ def __init__(
             num_language_layers: number of language transformer layers.
             num_vision_layers: number of vision transformer layers.
             num_mixed_layers: number of mixed transformer layers.
-            drop_out: faction of the input units to drop.
+            drop_out: fraction of the input units to drop.

             The other parameters are part of the `bert_config` to `MultiModal.from_pretrained`.

diff --git a/monai/networks/nets/unetr.py b/monai/networks/nets/unetr.py
index 4cdcd73c4d..bfcd6e7d47 100644
--- a/monai/networks/nets/unetr.py
+++ b/monai/networks/nets/unetr.py
@@ -18,7 +18,7 @@
 from monai.networks.blocks.dynunet_block import UnetOutBlock
 from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock
 from monai.networks.nets.vit import ViT
-from monai.utils import ensure_tuple_rep
+from monai.utils import deprecated_arg, ensure_tuple_rep


 class UNETR(nn.Module):
@@ -27,6 +27,7 @@ class UNETR(nn.Module):
     UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>"
     """

+    @deprecated_arg(name="pos_embed", since="1.2", new_name="proj_type", msg_suffix="please use `proj_type` instead.")
     def __init__(
         self,
         in_channels: int,
@@ -37,6 +38,7 @@ def __init__(
         mlp_dim: int = 3072,
         num_heads: int = 12,
         pos_embed: str = "conv",
+        proj_type: str = "conv",
         norm_name: tuple | str = "instance",
         conv_block: bool = True,
         res_block: bool = True,
@@ -54,7 +56,7 @@ def __init__(
             hidden_size: dimension of hidden layer. Defaults to 768.
             mlp_dim: dimension of feedforward layer. Defaults to 3072.
             num_heads: number of attention heads. Defaults to 12.
-            pos_embed: position embedding layer type. Defaults to "conv".
+            proj_type: patch embedding layer type. Defaults to "conv".
             norm_name: feature normalization type and arguments. Defaults to "instance".
             conv_block: if convolutional block is used. Defaults to True.
             res_block: if residual block is used. Defaults to True.
@@ -63,6 +65,9 @@ def __init__(
             qkv_bias: apply the bias term for the qkv linear layer in self attention block. Defaults to False.
             save_attn: to make accessible the attention in self attention block. Defaults to False.

+        .. deprecated:: 1.4
+            ``pos_embed`` is deprecated in favor of ``proj_type``.
+
         Examples::

             # for single channel input 4-channel output with image size of (96,96,96), feature size of 32 and batch norm
             >>> net = UNETR(in_channels=1, out_channels=4, img_size=96, feature_size=32, norm_name='batch', spatial_dims=2)

             # for 4-channel input 3-channel output with image size of (128,128,128), conv position embedding and instance norm
-            >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='conv', norm_name='instance')
+            >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), proj_type='conv', norm_name='instance')

         """

@@ -98,7 +103,7 @@ def __init__(
             mlp_dim=mlp_dim,
             num_layers=self.num_layers,
             num_heads=num_heads,
-            pos_embed=pos_embed,
+            proj_type=proj_type,
             classification=self.classification,
             dropout_rate=dropout_rate,
             spatial_dims=spatial_dims,
diff --git a/monai/networks/nets/vit.py b/monai/networks/nets/vit.py
index 8cd42b54b1..f033d7ff4a 100644
--- a/monai/networks/nets/vit.py
+++ b/monai/networks/nets/vit.py
@@ -18,6 +18,7 @@

 from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
 from monai.networks.blocks.transformerblock import TransformerBlock
+from monai.utils import deprecated_arg

 __all__ = ["ViT"]

@@ -30,6 +31,7 @@ class ViT(nn.Module):
     ViT supports Torchscript but only works for Pytorch after 1.8.
     """

+    @deprecated_arg(name="pos_embed", since="1.2", new_name="proj_type", msg_suffix="please use `proj_type` instead.")
     def __init__(
         self,
         in_channels: int,
@@ -40,6 +42,8 @@ def __init__(
         num_layers: int = 12,
         num_heads: int = 12,
         pos_embed: str = "conv",
+        proj_type: str = "conv",
+        pos_embed_type: str = "learnable",
         classification: bool = False,
         num_classes: int = 2,
         dropout_rate: float = 0.0,
@@ -57,10 +61,11 @@ def __init__(
             mlp_dim (int, optional): dimension of feedforward layer. Defaults to 3072.
             num_layers (int, optional): number of transformer blocks. Defaults to 12.
             num_heads (int, optional): number of attention heads. Defaults to 12.
-            pos_embed (str, optional): position embedding layer type. Defaults to "conv".
+            proj_type (str, optional): patch embedding layer type. Defaults to "conv".
+            pos_embed_type (str, optional): position embedding type. Defaults to "learnable".
             classification (bool, optional): bool argument to determine if classification is used. Defaults to False.
             num_classes (int, optional): number of classes if classification is used. Defaults to 2.
-            dropout_rate (float, optional): faction of the input units to drop. Defaults to 0.0.
+            dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
             spatial_dims (int, optional): number of spatial dimensions. Defaults to 3.
             post_activation (str, optional): add a final acivation function to the classification head
                 when `classification` is True. Default to "Tanh" for `nn.Tanh()`.
@@ -68,16 +73,20 @@ def __init__(
             qkv_bias (bool, optional): apply bias to the qkv linear layer in self attention block. Defaults to False.
             save_attn (bool, optional): to make accessible the attention in self attention block. Defaults to False.

+        .. deprecated:: 1.4
+            ``pos_embed`` is deprecated in favor of ``proj_type``.
+
         Examples::

             # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone
-            >>> net = ViT(in_channels=1, img_size=(96,96,96), pos_embed='conv')
+            >>> net = ViT(in_channels=1, img_size=(96,96,96), proj_type='conv', pos_embed_type='sincos')

             # for 3-channel with image size of (128,128,128), 24 layers and classification backbone
-            >>> net = ViT(in_channels=3, img_size=(128,128,128), pos_embed='conv', classification=True)
+            >>> net = ViT(in_channels=3, img_size=(128,128,128), proj_type='conv', pos_embed_type='sincos', classification=True)

             # for 3-channel with image size of (224,224), 12 layers and classification backbone
-            >>> net = ViT(in_channels=3, img_size=(224,224), pos_embed='conv', classification=True, spatial_dims=2)
+            >>> net = ViT(in_channels=3, img_size=(224,224), proj_type='conv', pos_embed_type='sincos', classification=True,
+            >>>           spatial_dims=2)

         """

@@ -96,7 +105,8 @@ def __init__(
             patch_size=patch_size,
             hidden_size=hidden_size,
             num_heads=num_heads,
-            pos_embed=pos_embed,
+            proj_type=proj_type,
+            pos_embed_type=pos_embed_type,
             dropout_rate=dropout_rate,
             spatial_dims=spatial_dims,
         )
diff --git a/monai/networks/nets/vitautoenc.py b/monai/networks/nets/vitautoenc.py
index 12d7d4e376..59aae2d54a 100644
--- a/monai/networks/nets/vitautoenc.py
+++ b/monai/networks/nets/vitautoenc.py
@@ -20,7 +20,7 @@
 from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
 from monai.networks.blocks.transformerblock import TransformerBlock
 from monai.networks.layers import Conv
-from monai.utils import ensure_tuple_rep, is_sqrt
+from monai.utils import deprecated_arg, ensure_tuple_rep, is_sqrt

 __all__ = ["ViTAutoEnc"]

@@ -33,6 +33,7 @@ class ViTAutoEnc(nn.Module):
     Modified to also give same dimension outputs as the input size of the image
     """

+    @deprecated_arg(name="pos_embed", since="1.2", new_name="proj_type", msg_suffix="please use `proj_type` instead.")
     def __init__(
         self,
         in_channels: int,
@@ -45,6 +46,7 @@ def __init__(
         num_layers: int = 12,
         num_heads: int = 12,
         pos_embed: str = "conv",
+        proj_type: str = "conv",
         dropout_rate: float = 0.0,
         spatial_dims: int = 3,
         qkv_bias: bool = False,
@@ -61,20 +63,23 @@ def __init__(
             mlp_dim: dimension of feedforward layer. Defaults to 3072.
             num_layers: number of transformer blocks. Defaults to 12.
             num_heads: number of attention heads. Defaults to 12.
-            pos_embed: position embedding layer type. Defaults to "conv".
-            dropout_rate: faction of the input units to drop. Defaults to 0.0.
+            proj_type: patch embedding layer type. Defaults to "conv".
+            dropout_rate: fraction of the input units to drop. Defaults to 0.0.
             spatial_dims: number of spatial dimensions. Defaults to 3.
             qkv_bias: apply bias to the qkv linear layer in self attention block. Defaults to False.
             save_attn: to make accessible the attention in self attention block. Defaults to False.
                 Defaults to False.

+        .. deprecated:: 1.4
+            ``pos_embed`` is deprecated in favor of ``proj_type``.
+
         Examples::

             # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone
             # It will provide an output of same size as that of the input
-            >>> net = ViTAutoEnc(in_channels=1, patch_size=(16,16,16), img_size=(96,96,96), pos_embed='conv')
+            >>> net = ViTAutoEnc(in_channels=1, patch_size=(16,16,16), img_size=(96,96,96), proj_type='conv')

             # for 3-channel with image size of (128,128,128), output will be same size as of input
-            >>> net = ViTAutoEnc(in_channels=3, patch_size=(16,16,16), img_size=(128,128,128), pos_embed='conv')
+            >>> net = ViTAutoEnc(in_channels=3, patch_size=(16,16,16), img_size=(128,128,128), proj_type='conv')

         """

@@ -94,7 +99,7 @@ def __init__(
             patch_size=patch_size,
             hidden_size=hidden_size,
             num_heads=num_heads,
-            pos_embed=pos_embed,
+            proj_type=proj_type,
             dropout_rate=dropout_rate,
             spatial_dims=self.spatial_dims,
         )
diff --git a/tests/test_patchembedding.py b/tests/test_patchembedding.py
index ae7fd14401..77ade984eb 100644
--- a/tests/test_patchembedding.py
+++ b/tests/test_patchembedding.py
@@ -21,6 +21,7 @@
 from monai.networks import eval_mode
 from monai.networks.blocks.patchembedding import PatchEmbed, PatchEmbeddingBlock
 from monai.utils import optional_import
+from tests.utils import SkipIfBeforePyTorchVersion

 einops, has_einops = optional_import("einops")

@@ -31,25 +32,27 @@
             for img_size in [32, 64]:
                 for patch_size in [8, 16]:
                     for num_heads in [8, 12]:
-                        for pos_embed in ["conv", "perceptron"]:
-                            # for classification in (False, True):  # TODO: add classification tests
-                            for nd in (2, 3):
-                                test_case = [
-                                    {
-                                        "in_channels": in_channels,
-                                        "img_size": (img_size,) * nd,
-                                        "patch_size": (patch_size,) * nd,
-                                        "hidden_size": hidden_size,
-                                        "num_heads": num_heads,
-                                        "pos_embed": pos_embed,
-                                        "dropout_rate": dropout_rate,
-                                    },
-                                    (2, in_channels, *([img_size] * nd)),
-                                    (2, (img_size // patch_size) ** nd, hidden_size),
-                                ]
-                                if nd == 2:
-                                    test_case[0]["spatial_dims"] = 2  # type: ignore
-                                TEST_CASE_PATCHEMBEDDINGBLOCK.append(test_case)
+                        for proj_type in ["conv", "perceptron"]:
+                            for pos_embed_type in ["none", "learnable", "sincos"]:
+                                # for classification in (False, True):  # TODO: add classification tests
+                                for nd in (2, 3):
+                                    test_case = [
+                                        {
+                                            "in_channels": in_channels,
+                                            "img_size": (img_size,) * nd,
+                                            "patch_size": (patch_size,) * nd,
+                                            "hidden_size": hidden_size,
+                                            "num_heads": num_heads,
+                                            "pos_embed": proj_type,
+                                            "pos_embed_type": pos_embed_type,
+                                            "dropout_rate": dropout_rate,
+                                        },
+                                        (2, in_channels, *([img_size] * nd)),
+                                        (2, (img_size // patch_size) ** nd, hidden_size),
+                                    ]
+                                    if nd == 2:
+                                        test_case[0]["spatial_dims"] = 2  # type: ignore
+                                    TEST_CASE_PATCHEMBEDDINGBLOCK.append(test_case)

 TEST_CASE_PATCHEMBED = []
 for patch_size in [2]:
@@ -72,6 +75,7 @@
                         TEST_CASE_PATCHEMBED.append(test_case)


+@SkipIfBeforePyTorchVersion((1, 11, 1))
 class TestPatchEmbeddingBlock(unittest.TestCase):
     def setUp(self):
         self.threads = torch.get_num_threads()
@@ -97,6 +101,7 @@ def test_ill_arg(self):
                 hidden_size=128,
                 num_heads=12,
                 pos_embed="conv",
+                pos_embed_type="sincos",
                 dropout_rate=5.0,
             )

@@ -108,6 +113,7 @@ def test_ill_arg(self):
                 hidden_size=512,
                 num_heads=8,
                 pos_embed="perceptron",
+                pos_embed_type="sincos",
                 dropout_rate=0.3,
             )

@@ -132,6 +138,16 @@ def test_ill_arg(self):
                 pos_embed="perceptron",
                 dropout_rate=0.3,
             )
+        with self.assertRaises(ValueError):
+            PatchEmbeddingBlock(
+                in_channels=1,
+                img_size=(97, 97, 97),
+                patch_size=(4, 4, 4),
+                hidden_size=768,
+                num_heads=8,
+                proj_type="perceptron",
+                dropout_rate=0.3,
+            )
         with self.assertRaises(ValueError):
             PatchEmbeddingBlock(
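
Editor's note: a minimal usage sketch of the option this patch introduces, not part of the
diff itself. It assumes the patch above is applied to MONAI; the concrete sizes are
illustrative (hidden_size must be divisible by 6 for a 3D sin-cos embedding and by num_heads).

```python
import torch

from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding

# Patch embedding with a fixed (non-learnable) sin-cos position embedding.
block = PatchEmbeddingBlock(
    in_channels=1,
    img_size=(32, 32, 32),
    patch_size=(8, 8, 8),
    hidden_size=96,
    num_heads=4,
    proj_type="conv",
    pos_embed_type="sincos",
    spatial_dims=3,
)
x = torch.randn(2, 1, 32, 32, 32)
print(block(x).shape)  # torch.Size([2, 64, 96]) -> (batch, n_patches, hidden_size)

# The embedding can also be built directly from the patch-grid size.
pos = build_sincos_position_embedding(grid_size=[4, 4, 4], embed_dim=96, spatial_dims=3)
print(pos.shape, pos.requires_grad)  # torch.Size([1, 64, 96]) False
```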