From 28d9b5cb185738a1fbd480c39fa6bd345d7f0664 Mon Sep 17 00:00:00 2001
From: Alejandro Gil
Date: Wed, 12 Jul 2023 14:41:46 -0700
Subject: [PATCH 1/2] Added support for both 3- and 4-channel images.

This change makes it possible to use both RGB and RGBA images, effectively
making the VAE and the Maskgit capable of generating transparency.
---
 infer_vae.py                      |  3 ++-
 muse_maskgit_pytorch/dataset.py   |  8 +++++---
 muse_maskgit_pytorch/vqgan_vae.py | 10 ++++++----
 train_muse_maskgit.py             |  3 ++-
 train_muse_vae.py                 |  3 ++-
 5 files changed, 17 insertions(+), 10 deletions(-)

diff --git a/infer_vae.py b/infer_vae.py
index be1e18b..01022a1 100644
--- a/infer_vae.py
+++ b/infer_vae.py
@@ -112,7 +112,7 @@
 parser.add_argument("--lr", type=float, default=1e-4, help="Learning Rate.")
 parser.add_argument("--vq_codebook_size", type=int, default=256, help="Image Size.")
 parser.add_argument("--vq_codebook_dim", type=int, default=256, help="VQ Codebook dimensions.")
-parser.add_argument("--channels", type=int, default=3, help="Number of channels for the VAE.")
+parser.add_argument("--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA.")
 parser.add_argument("--layers", type=int, default=4, help="Number of layers for the VAE.")
 parser.add_argument("--discr_layers", type=int, default=4, help="Number of layers for the VAE discriminator.")
 parser.add_argument(
@@ -435,6 +435,7 @@ def main():
         center_crop=True if not args.no_center_crop and not args.random_crop else False,
         flip=not args.no_flip,
         random_crop=args.random_crop if args.random_crop else False,
+        alpha_channel=False if args.channels == 3 else True,
     )

     if args.input_image and not args.input_folder:
diff --git a/muse_maskgit_pytorch/dataset.py b/muse_maskgit_pytorch/dataset.py
index 055425a..f8cba1c 100644
--- a/muse_maskgit_pytorch/dataset.py
+++ b/muse_maskgit_pytorch/dataset.py
@@ -45,15 +45,17 @@ def __init__(
         stream=False,
         using_taming=False,
         random_crop=False,
+        alpha_channel=True
     ):
         super().__init__()
         self.dataset = dataset
         self.image_column = image_column
         self.stream = stream
         transform_list = [
-            T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
+            T.Lambda(lambda img: img.convert("RGBA") if img.mode != "RGBA" and alpha_channel else img if img.mode == "RGB" and not alpha_channel else img.convert("RGB")),
             T.Resize(image_size),
         ]
+
         if flip:
             transform_list.append(T.RandomHorizontalFlip())
         if center_crop and not random_crop:
@@ -199,7 +201,7 @@ def __getitem__(self, index):

 class LocalTextImageDataset(Dataset):
     def __init__(
-        self, path, image_size, tokenizer, flip=True, center_crop=True, using_taming=False, random_crop=False
+        self, path, image_size, tokenizer, flip=True, center_crop=True, using_taming=False, random_crop=False, alpha_channel=False
     ):
         super().__init__()
         self.tokenizer = tokenizer
@@ -229,7 +231,7 @@ def __init__(
         self.caption_pair.append(captions)

         transform_list = [
-            T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
+            T.Lambda(lambda img: img.convert("RGBA") if img.mode != "RGBA" and alpha_channel else img if img.mode == "RGB" and not alpha_channel else img.convert("RGB")),
             T.Resize(image_size),
         ]
         if flip:
diff --git a/muse_maskgit_pytorch/vqgan_vae.py b/muse_maskgit_pytorch/vqgan_vae.py
index 56c58f2..0887fe6 100644
--- a/muse_maskgit_pytorch/vqgan_vae.py
+++ b/muse_maskgit_pytorch/vqgan_vae.py
@@ -156,7 +156,7 @@ def forward(self, x):
 # discriminator


 class Discriminator(nn.Module):
-    def __init__(self, dims, channels=3, groups=16, init_kernel_size=5):
+    def __init__(self, dims, channels=4, groups=16, init_kernel_size=5):
         super().__init__()
         dim_pairs = zip(dims[:-1], dims[1:])
@@ -194,7 +194,7 @@ def __init__(
         self,
         dim: int,
         *,
-        channels=3,
+        channels=4,
         layers=4,
         layer_mults=None,
         num_resnet_blocks=1,
@@ -337,7 +337,7 @@ def __init__(
         *,
         dim: int,
         accelerator: Accelerator = None,
-        channels=3,
+        channels=4,
         layers=4,
         l2_recon_loss=False,
         use_hinge_loss=True,
@@ -407,11 +407,13 @@ def vgg(self):
         if exists(self._vgg):
             return self._vgg

-        vgg = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)
+        vgg = torchvision.models.vgg16(pretrained=True)
+        vgg.features[0] = nn.Conv2d(self.channels, 64, kernel_size=3, stride=1, padding=1)
         vgg.classifier = nn.Sequential(*vgg.classifier[:-2])
         self._vgg = vgg.to(self.device)
         return self._vgg

+
     @property
     def encoded_dim(self):
         return self.enc_dec.encoded_dim
diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py
index 802418d..e711cff 100644
--- a/train_muse_maskgit.py
+++ b/train_muse_maskgit.py
@@ -298,7 +298,7 @@
     help="Image Size.",
 )
 parser.add_argument("--vq_codebook_dim", type=int, default=256, help="VQ Codebook dimensions.")
-parser.add_argument("--channels", type=int, default=3, help="Number of channels for the VAE.")
+parser.add_argument("--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA.")
 parser.add_argument("--layers", type=int, default=4, help="Number of layers for the VAE.")
 parser.add_argument("--discr_layers", type=int, default=4, help="Number of layers for the VAE discriminator.")
 parser.add_argument(
@@ -812,6 +812,7 @@ def main():
             flip=False if args.no_flip else True,
             using_taming=False if not args.taming_model_path else True,
             random_crop=args.random_crop if args.random_crop else False,
+            alpha_channel=False if args.channels == 3 else True,
         )
     elif args.link:
         if not args.dataset_name:
diff --git a/train_muse_vae.py b/train_muse_vae.py
index 32cfeab..639edaf 100644
--- a/train_muse_vae.py
+++ b/train_muse_vae.py
@@ -222,7 +222,7 @@
 )
 parser.add_argument("--vq_codebook_size", type=int, default=256, help="Image Size.")
 parser.add_argument("--vq_codebook_dim", type=int, default=256, help="VQ Codebook dimensions.")
-parser.add_argument("--channels", type=int, default=3, help="Number of channels for the VAE.")
+parser.add_argument("--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA.")
 parser.add_argument("--layers", type=int, default=4, help="Number of layers for the VAE.")
 parser.add_argument("--discr_layers", type=int, default=4, help="Number of layers for the VAE discriminator.")
 parser.add_argument(
@@ -563,6 +563,7 @@ def main():
         flip=not args.no_flip,
         stream=args.streaming,
         random_crop=args.random_crop,
+        alpha_channel=False if args.channels == 3 else True,
     )

     # dataloader

From 113af92b73b05012079e7385fb74fe8128fdf377 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 12 Jul 2023 21:42:54 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 infer_vae.py                      |  4 +++-
 muse_maskgit_pytorch/dataset.py   | 28 ++++++++++++++++++++++++----
 muse_maskgit_pytorch/vqgan_vae.py |  1 -
 train_muse_maskgit.py             |  4 +++-
 train_muse_vae.py                 |  4 +++-
 5 files changed, 33 insertions(+), 8 deletions(-)

diff --git a/infer_vae.py b/infer_vae.py
index 01022a1..f5d6efd 100644
--- a/infer_vae.py
+++ b/infer_vae.py
@@ -112,7 +112,9 @@
 parser.add_argument("--lr", type=float, default=1e-4, help="Learning Rate.")
 parser.add_argument("--vq_codebook_size", type=int, default=256, help="Image Size.")
 parser.add_argument("--vq_codebook_dim", type=int, default=256, help="VQ Codebook dimensions.")
-parser.add_argument("--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA.")
+parser.add_argument(
+    "--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA."
+)
 parser.add_argument("--layers", type=int, default=4, help="Number of layers for the VAE.")
 parser.add_argument("--discr_layers", type=int, default=4, help="Number of layers for the VAE discriminator.")
 parser.add_argument(
diff --git a/muse_maskgit_pytorch/dataset.py b/muse_maskgit_pytorch/dataset.py
index f8cba1c..7659e38 100644
--- a/muse_maskgit_pytorch/dataset.py
+++ b/muse_maskgit_pytorch/dataset.py
@@ -45,14 +45,20 @@ def __init__(
         stream=False,
         using_taming=False,
         random_crop=False,
-        alpha_channel=True
+        alpha_channel=True,
     ):
         super().__init__()
         self.dataset = dataset
         self.image_column = image_column
         self.stream = stream
         transform_list = [
-            T.Lambda(lambda img: img.convert("RGBA") if img.mode != "RGBA" and alpha_channel else img if img.mode == "RGB" and not alpha_channel else img.convert("RGB")),
+            T.Lambda(
+                lambda img: img.convert("RGBA")
+                if img.mode != "RGBA" and alpha_channel
+                else img
+                if img.mode == "RGB" and not alpha_channel
+                else img.convert("RGB")
+            ),
             T.Resize(image_size),
         ]

@@ -201,7 +207,15 @@ def __getitem__(self, index):

 class LocalTextImageDataset(Dataset):
     def __init__(
-        self, path, image_size, tokenizer, flip=True, center_crop=True, using_taming=False, random_crop=False, alpha_channel=False
+        self,
+        path,
+        image_size,
+        tokenizer,
+        flip=True,
+        center_crop=True,
+        using_taming=False,
+        random_crop=False,
+        alpha_channel=False,
     ):
         super().__init__()
         self.tokenizer = tokenizer
@@ -231,7 +245,13 @@ def __init__(
         self.caption_pair.append(captions)

         transform_list = [
-            T.Lambda(lambda img: img.convert("RGBA") if img.mode != "RGBA" and alpha_channel else img if img.mode == "RGB" and not alpha_channel else img.convert("RGB")),
+            T.Lambda(
+                lambda img: img.convert("RGBA")
+                if img.mode != "RGBA" and alpha_channel
+                else img
+                if img.mode == "RGB" and not alpha_channel
+                else img.convert("RGB")
+            ),
             T.Resize(image_size),
         ]
         if flip:
diff --git a/muse_maskgit_pytorch/vqgan_vae.py b/muse_maskgit_pytorch/vqgan_vae.py
index 0887fe6..5715e65 100644
--- a/muse_maskgit_pytorch/vqgan_vae.py
+++ b/muse_maskgit_pytorch/vqgan_vae.py
@@ -413,7 +413,6 @@ def vgg(self):
         self._vgg = vgg.to(self.device)
         return self._vgg

-
     @property
     def encoded_dim(self):
         return self.enc_dec.encoded_dim
diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py
index e711cff..4ecd529 100644
--- a/train_muse_maskgit.py
+++ b/train_muse_maskgit.py
@@ -298,7 +298,9 @@
     help="Image Size.",
 )
 parser.add_argument("--vq_codebook_dim", type=int, default=256, help="VQ Codebook dimensions.")
-parser.add_argument("--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA.")
+parser.add_argument(
+    "--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA."
+)
 parser.add_argument("--layers", type=int, default=4, help="Number of layers for the VAE.")
 parser.add_argument("--discr_layers", type=int, default=4, help="Number of layers for the VAE discriminator.")
 parser.add_argument(
diff --git a/train_muse_vae.py b/train_muse_vae.py
index 639edaf..655eb2a 100644
--- a/train_muse_vae.py
+++ b/train_muse_vae.py
@@ -222,7 +222,9 @@
 )
 parser.add_argument("--vq_codebook_size", type=int, default=256, help="Image Size.")
 parser.add_argument("--vq_codebook_dim", type=int, default=256, help="VQ Codebook dimensions.")
-parser.add_argument("--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA.")
+parser.add_argument(
+    "--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA."
+)
 parser.add_argument("--layers", type=int, default=4, help="Number of layers for the VAE.")
 parser.add_argument("--discr_layers", type=int, default=4, help="Number of layers for the VAE discriminator.")
 parser.add_argument(
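
Note on the channel-handling transform added in muse_maskgit_pytorch/dataset.py: Python chains conditional expressions as `A if C1 else (B if C2 else D)`, so an image whose mode is already "RGBA" falls through to `img.convert("RGB")` even when `alpha_channel=True`, dropping the alpha channel that this patch is meant to preserve. The sketch below is not part of the patches; it is only one way the intent could be written explicitly, using an illustrative helper name (`convert_channels`) that does not exist in the repository.

# Sketch only; assumes PIL and torchvision.transforms as already used in dataset.py.
from PIL import Image
import torchvision.transforms as T


def convert_channels(img: Image.Image, alpha_channel: bool) -> Image.Image:
    """Return the image in RGBA mode when alpha is wanted, otherwise in RGB."""
    target_mode = "RGBA" if alpha_channel else "RGB"
    # Only convert when the image is not already in the target mode.
    return img if img.mode == target_mode else img.convert(target_mode)


# Possible wiring inside the existing transform list:
# transform_list = [
#     T.Lambda(lambda img: convert_channels(img, alpha_channel)),
#     T.Resize(image_size),
# ]

Separately, `torchvision.models.vgg16(pretrained=True)` is the legacy weights API (newer torchvision releases deprecate it in favor of the `weights=` argument used by the line this patch removes), and assigning a fresh `nn.Conv2d(self.channels, 64, ...)` to `vgg.features[0]` leaves that first layer randomly initialized, so the perceptual loss no longer uses pretrained weights for it; whether that noticeably affects training is not established here.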