From 0006c2dba433c8fbaa4c1add67ade2928802bf29 Mon Sep 17 00:00:00 2001 From: Alejandro Gil Date: Tue, 13 Jun 2023 05:03:14 -0700 Subject: [PATCH 01/16] Fixed missing arguments on the train_muse_vae.py script. --- train_muse_vae.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/train_muse_vae.py b/train_muse_vae.py index a2c53a3..0eaeab5 100644 --- a/train_muse_vae.py +++ b/train_muse_vae.py @@ -300,6 +300,17 @@ action="store_true", help="Whether to use the latest checkpoint", ) +parser.add_argument( + "--do_not_save_config", + action="store_true", + default=False, + help="Generate example YAML configuration file", +) +parser.add_argument( + "--use_l2_recon_loss", + action="store_true", + help="Use F.mse_loss instead of F.l1_loss.", +) @dataclass @@ -361,7 +372,6 @@ class Arguments: use_l2_recon_loss: bool = False debug: bool = False config_path: Optional[str] = None - generate_config: bool = False def preprocess_webdataset(args, image): @@ -374,13 +384,6 @@ def main(): if args.config_path: print("Using config file and ignoring CLI args") - if args.generate_config: - conf = OmegaConf.structured(args) - - # dumps to file: - with open(args.config_path, "w") as f: - OmegaConf.save(conf, f) - try: conf = OmegaConf.load(args.config_path) conf_keys = conf.keys() @@ -459,6 +462,7 @@ def main(): vq_codebook_dim=args.vq_codebook_dim, vq_codebook_size=args.vq_codebook_size, accelerator=accelerator, + l2_recon_loss=args.use_l2_recon_loss, ) if args.latest_checkpoint: @@ -582,6 +586,7 @@ def main(): num_cycles=args.num_cycles, scheduler_power=args.scheduler_power, num_epochs=args.num_epochs, + args=args, ) trainer.train() From 2c0f27ff57c87da9ea30558334f0aae7b4a9832e Mon Sep 17 00:00:00 2001 From: Alejandro Gil Date: Tue, 13 Jun 2023 05:42:07 -0700 Subject: [PATCH 02/16] Fixed missing arguments on the train_muse_vae.py script. --- muse_maskgit_pytorch/trainers/vqvae_trainers.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/muse_maskgit_pytorch/trainers/vqvae_trainers.py b/muse_maskgit_pytorch/trainers/vqvae_trainers.py index 5824ddc..0930e6e 100644 --- a/muse_maskgit_pytorch/trainers/vqvae_trainers.py +++ b/muse_maskgit_pytorch/trainers/vqvae_trainers.py @@ -305,16 +305,15 @@ def train(self): if (steps % self.save_results_every) == 0: self.accelerator.print( - f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: saving to {str(self.results_dir)}" + f"\n[E{epoch + 1}][{steps}] | Logging validation images to {str(self.results_dir)}\n" ) + self.log_validation_images(logs, steps) # save model every so often self.accelerator.wait_for_everyone() if self.is_main_process and (steps % self.save_model_every) == 0: - self.accelerator.print( - f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: saving model to {str(self.results_dir)}" - ) + self.accelerator.print(f"\nStep: {steps - 1} | Saving model to {str(self.results_dir)}") state_dict = self.accelerator.unwrap_model(self.model).state_dict() file_name = f"vae.{steps}.pt" if not self.only_save_last_checkpoint else "vae.pt" From bd06d9dd6453eba08fe43ff03e0426a63eb31ed8 Mon Sep 17 00:00:00 2001 From: Alejandro Gil Date: Tue, 13 Jun 2023 21:12:48 -0700 Subject: [PATCH 03/16] Small fix to the step counter on the vae trainer when saving the model to disk, the print was showing the counter as having 1 step less than it actually had. 
--- muse_maskgit_pytorch/trainers/vqvae_trainers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/muse_maskgit_pytorch/trainers/vqvae_trainers.py b/muse_maskgit_pytorch/trainers/vqvae_trainers.py index 0930e6e..7768622 100644 --- a/muse_maskgit_pytorch/trainers/vqvae_trainers.py +++ b/muse_maskgit_pytorch/trainers/vqvae_trainers.py @@ -313,7 +313,7 @@ def train(self): # save model every so often self.accelerator.wait_for_everyone() if self.is_main_process and (steps % self.save_model_every) == 0: - self.accelerator.print(f"\nStep: {steps - 1} | Saving model to {str(self.results_dir)}") + self.accelerator.print(f"\nStep: {steps} | Saving model to {str(self.results_dir)}") state_dict = self.accelerator.unwrap_model(self.model).state_dict() file_name = f"vae.{steps}.pt" if not self.only_save_last_checkpoint else "vae.pt" From 764a771ddbcd795bcca08bde50880ccbe1ffd436 Mon Sep 17 00:00:00 2001 From: Andi Powers Holmes Date: Wed, 14 Jun 2023 22:24:57 +1000 Subject: [PATCH 04/16] Update .pre-commit-config.yaml add config for https://pre-commit.ci --- .pre-commit-config.yaml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7f63f19..5174c96 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,11 @@ # See https://pre-commit.com for more information # See https://pre-commit.com/hooks.html for more hooks +ci: + autofix_prs: true + autoupdate_branch: 'dev' + autoupdate_commit_msg: '[pre-commit.ci] pre-commit autoupdate' + autoupdate_schedule: weekly + repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 From e832c1267f2124c7e9bafb40a834a887ea568bba Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 14 Jun 2023 12:26:20 +0000 Subject: [PATCH 05/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- muse_maskgit_pytorch/vqgan_vae_taming.py | 3 +-- train_muse_maskgit.py | 3 +-- train_muse_vae.py | 3 +-- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/muse_maskgit_pytorch/vqgan_vae_taming.py b/muse_maskgit_pytorch/vqgan_vae_taming.py index 2479161..1eeaf00 100644 --- a/muse_maskgit_pytorch/vqgan_vae_taming.py +++ b/muse_maskgit_pytorch/vqgan_vae_taming.py @@ -10,11 +10,10 @@ from accelerate import Accelerator from einops import rearrange from omegaconf import DictConfig, OmegaConf +from taming.models.vqgan import VQModel from torch import nn from tqdm_loggable.auto import tqdm -from taming.models.vqgan import VQModel - # constants CACHE_PATH = Path.home().joinpath(".cache/taming") diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py index 40ecdf4..b7697ff 100644 --- a/train_muse_maskgit.py +++ b/train_muse_maskgit.py @@ -10,6 +10,7 @@ import datasets import diffusers import transformers +import wandb from accelerate.utils import ProjectConfiguration from datasets import load_dataset from diffusers.optimization import SchedulerType, get_scheduler @@ -17,8 +18,6 @@ from rich import inspect from torch.optim import Optimizer -import wandb - try: import torch_xla import torch_xla.core.xla_model as xm diff --git a/train_muse_vae.py b/train_muse_vae.py index 0eaeab5..d395ab7 100644 --- a/train_muse_vae.py +++ b/train_muse_vae.py @@ -5,12 +5,11 @@ from dataclasses import dataclass from typing import Optional, Union +import wandb from accelerate.utils import ProjectConfiguration from datasets import load_dataset from omegaconf import OmegaConf -import 
wandb - from muse_maskgit_pytorch import ( VQGanVAE, VQGanVAETaming, From cc9c5a598865ba32ac55a42c927b409b572327b3 Mon Sep 17 00:00:00 2001 From: Andrew Powers-Holmes Date: Wed, 14 Jun 2023 20:40:14 +1000 Subject: [PATCH 06/16] add wandb, data, output dirs to gitignore --- .gitignore | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.gitignore b/.gitignore index 7d149a3..d382f6b 100644 --- a/.gitignore +++ b/.gitignore @@ -141,3 +141,10 @@ dmypy.json # setuptools-scm version file muse_maskgit_pytorch/_version.py + +# wandb dir +/wandb/ + +# data, output +/data/ +/output/ From 769fe2f807ad33b0af884a51eb255f04b36c005b Mon Sep 17 00:00:00 2001 From: Andrew Powers-Holmes Date: Wed, 14 Jun 2023 20:42:52 +1000 Subject: [PATCH 07/16] add shiny VQVAE impl and test script --- muse_maskgit_pytorch/modules/__init__.py | 10 ++ muse_maskgit_pytorch/modules/attention.py | 98 +++++++++++ muse_maskgit_pytorch/modules/mlp.py | 56 ++++++ muse_maskgit_pytorch/vqvae/__init__.py | 7 + muse_maskgit_pytorch/vqvae/config.py | 64 +++++++ muse_maskgit_pytorch/vqvae/discriminator.py | 80 +++++++++ muse_maskgit_pytorch/vqvae/layers.py | 181 ++++++++++++++++++++ muse_maskgit_pytorch/vqvae/quantize.py | 44 +++++ muse_maskgit_pytorch/vqvae/vqvae.py | 48 ++++++ scripts/vqvae_test.py | 78 +++++++++ 10 files changed, 666 insertions(+) create mode 100644 muse_maskgit_pytorch/modules/__init__.py create mode 100644 muse_maskgit_pytorch/modules/attention.py create mode 100644 muse_maskgit_pytorch/modules/mlp.py create mode 100644 muse_maskgit_pytorch/vqvae/__init__.py create mode 100644 muse_maskgit_pytorch/vqvae/config.py create mode 100644 muse_maskgit_pytorch/vqvae/discriminator.py create mode 100644 muse_maskgit_pytorch/vqvae/layers.py create mode 100644 muse_maskgit_pytorch/vqvae/quantize.py create mode 100644 muse_maskgit_pytorch/vqvae/vqvae.py create mode 100644 scripts/vqvae_test.py diff --git a/muse_maskgit_pytorch/modules/__init__.py b/muse_maskgit_pytorch/modules/__init__.py new file mode 100644 index 0000000..eb619ff --- /dev/null +++ b/muse_maskgit_pytorch/modules/__init__.py @@ -0,0 +1,10 @@ +from .attention import CrossAttention, MemoryEfficientCrossAttention +from .mlp import SwiGLU, SwiGLUFFN, SwiGLUFFNFused + +__all__ = [ + "SwiGLU", + "SwiGLUFFN", + "SwiGLUFFNFused", + "CrossAttention", + "MemoryEfficientCrossAttention", +] diff --git a/muse_maskgit_pytorch/modules/attention.py b/muse_maskgit_pytorch/modules/attention.py new file mode 100644 index 0000000..5945278 --- /dev/null +++ b/muse_maskgit_pytorch/modules/attention.py @@ -0,0 +1,98 @@ +from inspect import isfunction +from typing import Any, Optional + +from einops import rearrange +from torch import nn + +try: + from xformers.ops import memory_efficient_attention +except ImportError: + memory_efficient_attention = None + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + + def forward(self, x, context=None): + h = self.heads + 
+ q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) + q = q * self.scale + + sim = q @ k.transpose(-2, -1) + sim = sim.softmax(dim=-1) + + out = sim @ v + out = rearrange(out, "(b h) n d -> b n (h d)", h=h) + return self.to_out(out) + + +class MemoryEfficientCrossAttention(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.heads = heads + self.dim_head = dim_head + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.attention_op: Optional[Any] = None + + def forward(self, x, context=None): + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + out = memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + return self.to_out(out) diff --git a/muse_maskgit_pytorch/modules/mlp.py b/muse_maskgit_pytorch/modules/mlp.py new file mode 100644 index 0000000..7dd35dd --- /dev/null +++ b/muse_maskgit_pytorch/modules/mlp.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Optional + +import torch.nn.functional as F +from torch import Tensor, nn + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +try: + from xformers.ops import SwiGLU +except ImportError: + SwiGLU = SwiGLUFFN + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/muse_maskgit_pytorch/vqvae/__init__.py b/muse_maskgit_pytorch/vqvae/__init__.py new file mode 100644 index 0000000..9e7f389 --- /dev/null +++ b/muse_maskgit_pytorch/vqvae/__init__.py @@ -0,0 +1,7 @@ +from .config import VQVAEConfig +from .vqvae import VQVAE + +__all__ = [ + "VQVAE", + "VQVAEConfig", +] diff --git a/muse_maskgit_pytorch/vqvae/config.py b/muse_maskgit_pytorch/vqvae/config.py new file mode 100644 index 0000000..dadd8ba --- /dev/null +++ b/muse_maskgit_pytorch/vqvae/config.py @@ -0,0 +1,64 @@ +from typing import Dict, List, Optional, Tuple, Union + +from pydantic import BaseModel, Field + + +class EncoderConfig(BaseModel): + image_size: int = Field(...) + patch_size: int = Field(...) + dim: int = Field(...) + depth: int = Field(...) + num_head: int = Field(...) + mlp_dim: int = Field(...) + in_channels: int = Field(...) + dim_head: int = Field(...) + dropout: float = Field(...) + + +class DecoderConfig(BaseModel): + image_size: int = Field(...) + patch_size: int = Field(...) + dim: int = Field(...) + depth: int = Field(...) + num_head: int = Field(...) + mlp_dim: int = Field(...) + out_channels: int = Field(...) + dim_head: int = Field(...) + dropout: float = Field(...) + + +class VQVAEConfig(BaseModel): + n_embed: int = Field(...) + embed_dim: int = Field(...) + beta: float = Field(...) + enc: EncoderConfig = Field(...) + dec: DecoderConfig = Field(...) 
+ + +VIT_S_CONFIG = VQVAEConfig( + n_embed=8192, + embed_dim=32, + beta=0.25, + enc=EncoderConfig( + image_size=256, + patch_size=8, + dim=512, + depth=8, + num_head=8, + mlp_dim=2048, + in_channels=3, + dim_head=64, + dropout=0.0, + ), + dec=DecoderConfig( + image_size=256, + patch_size=8, + dim=512, + depth=8, + num_head=8, + mlp_dim=2048, + out_channels=3, + dim_head=64, + dropout=0.0, + ), +) diff --git a/muse_maskgit_pytorch/vqvae/discriminator.py b/muse_maskgit_pytorch/vqvae/discriminator.py new file mode 100644 index 0000000..fcf0e2a --- /dev/null +++ b/muse_maskgit_pytorch/vqvae/discriminator.py @@ -0,0 +1,80 @@ +import functools + +import torch.nn as nn + + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find("BatchNorm") != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) + + +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator""" + + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super().__init__() + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + kw = 4 + padw = 1 + sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2**n, 8) + sequence += [ + nn.Conv2d( + ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias + ), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True), + ] + + nf_mult_prev = nf_mult + nf_mult = min(2**n_layers, 8) + sequence += [ + nn.Conv2d( + ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias + ), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True), + ] + + sequence += [ + nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) + ] # output 1 channel prediction map + self.model = nn.Sequential(*sequence) + + self.apply(self.init_func) + + def forward(self, input): + """Standard forward.""" + return self.model(input) + + def init_func(self, m): # define the initialization function + init_gain = 0.02 + classname = m.__class__.__name__ + if hasattr(m, "weight") and (classname.find("Conv") != -1 or classname.find("Linear") != -1): + nn.init.normal_(m.weight.data, 0.0, init_gain) + if hasattr(m, "bias") and m.bias is not None: + nn.init.constant_(m.bias.data, 0.0) + elif ( + classname.find("BatchNorm2d") != -1 + ): # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 
+ nn.init.normal_(m.weight.data, 1.0, init_gain) + nn.init.constant_(m.bias.data, 0.0) diff --git a/muse_maskgit_pytorch/vqvae/layers.py b/muse_maskgit_pytorch/vqvae/layers.py new file mode 100644 index 0000000..c0278fb --- /dev/null +++ b/muse_maskgit_pytorch/vqvae/layers.py @@ -0,0 +1,181 @@ +import torch +from diffusers.utils import is_xformers_available +from einops import rearrange +from einops.layers.torch import Rearrange +from torch import nn + +from muse_maskgit_pytorch.modules import CrossAttention, MemoryEfficientCrossAttention, SwiGLUFFNFused + + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + + +class FeedForward(nn.Module): + def __init__(self, dim, mlp_dim, dropout=0.0): + super().__init__() + self.w_1 = nn.Linear(dim, mlp_dim) + self.act = nn.GELU() + self.dropout = nn.Dropout(p=dropout) + self.w_2 = nn.Linear(mlp_dim, dim) + + def forward(self, x): + x = self.w_1(x) + x = self.act(x) + x = self.dropout(x) + x = self.w_2(x) + + return x + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class Layer(nn.Module): + ATTENTION_MODES = {"vanilla": CrossAttention, "xformer": MemoryEfficientCrossAttention} + + def __init__(self, dim, dim_head, mlp_dim, num_head=8, dropout=0.0): + super().__init__() + attn_mode = "xformer" if is_xformers_available() else "vanilla" + attn_cls = self.ATTENTION_MODES[attn_mode] + self.norm1 = nn.LayerNorm(dim) + self.attn1 = attn_cls(query_dim=dim, heads=num_head, dim_head=dim_head, dropout=dropout) + self.norm2 = nn.LayerNorm(dim) + self.ffnet = SwiGLUFFNFused(in_features=dim, hidden_features=mlp_dim) + + def forward(self, x): + x = self.attn1(self.norm1(x)) + x + x = self.ffnet(self.norm2(x)) + x + + return x + + +class Transformer(nn.Module): + def __init__(self, dim, depth, num_head, dim_head, mlp_dim, dropout=0.0): + super().__init__() + self.layers = nn.Sequential(*[Layer(dim, dim_head, mlp_dim, num_head, dropout) for i in range(depth)]) + + def forward(self, x): + x = self.layers(x) + + return x + + +class Encoder(nn.Module): + def __init__( + self, + image_size, + patch_size, + dim, + depth, + num_head, + mlp_dim, + in_channels=3, + out_channels=3, + dim_head=64, + dropout=0.0, + ): + super().__init__() + + self.image_size = image_size + self.patch_size = patch_size + + assert image_size % patch_size == 0, "Image dimensions must be divisible by the patch size." 
+ + self.to_patch_embedding = nn.Sequential( + nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size, bias=False), + Rearrange("b c h w -> b (h w) c"), + ) + + scale = dim**-0.5 + num_patches = (image_size // patch_size) ** 2 + self.position_embedding = nn.Parameter(torch.randn(1, num_patches, dim) * scale) + self.norm_pre = nn.LayerNorm(dim) + self.transformer = Transformer(dim, depth, num_head, dim_head, mlp_dim, dropout) + + self.initialize_weights() + + def initialize_weights(self): + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x): + x = self.to_patch_embedding(x) + x = x + self.position_embedding + x = self.norm_pre(x) + x = self.transformer(x) + + return x + + +class Decoder(nn.Module): + def __init__( + self, + image_size, + patch_size, + dim, + depth, + num_head, + mlp_dim, + in_channels=3, + out_channels=3, + dim_head=64, + dropout=0.0, + ): + super().__init__() + + self.image_size = image_size + self.patch_size = patch_size + + assert image_size % patch_size == 0, "Image dimensions must be divisible by the patch size." + + scale = dim**-0.5 + num_patches = (image_size // patch_size) ** 2 + self.position_embedding = nn.Parameter(torch.randn(1, num_patches, dim) * scale) + self.transformer = Transformer(dim, depth, num_head, dim_head, mlp_dim, dropout) + self.norm = nn.LayerNorm(dim) + self.proj = nn.Linear(dim, out_channels * patch_size * patch_size, bias=True) + + self.initialize_weights() + + def initialize_weights(self): + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x): + x = x + self.position_embedding + x = self.transformer(x) + x = self.norm(x) + x = self.proj(x) + x = rearrange( + x, + "b (h w) (p1 p2 c) -> b c (h p1) (w p2)", + h=self.image_size // self.patch_size, + p1=self.patch_size, + p2=self.patch_size, + ) + + return x diff --git a/muse_maskgit_pytorch/vqvae/quantize.py b/muse_maskgit_pytorch/vqvae/quantize.py new file mode 100644 index 0000000..97453ad --- /dev/null +++ b/muse_maskgit_pytorch/vqvae/quantize.py @@ -0,0 +1,44 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class VectorQuantize(nn.Module): + def __init__(self, n_e, vq_embed_dim, beta=0.25): + super().__init__() + self.n_e = n_e + self.vq_embed_dim = vq_embed_dim + self.beta = beta + + self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim) + self.embedding.weight.data.normal_() + + def forward(self, z): + z = F.normalize(z, p=2, dim=-1) + z_flattened = z.view(-1, self.vq_embed_dim) + embed_norm = F.normalize(self.embedding.weight, p=2, dim=-1) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + d = ( + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(embed_norm**2, dim=1) + - 2 * torch.einsum("bd,nd->bn", z_flattened, embed_norm) + ) + + encoding_indices = torch.argmin(d, dim=1).view(*z.shape[:-1]) + z_q = self.embedding(encoding_indices).view(z.shape) + z_q = F.normalize(z_q, p=2, dim=-1) + + # compute loss for embedding + loss 
= self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + return z_q, loss, encoding_indices + + def get_codebook_entry(self, indices): + z_q = self.embedding(indices) + z_q = F.normalize(z_q, p=2, dim=-1) + + return z_q diff --git a/muse_maskgit_pytorch/vqvae/vqvae.py b/muse_maskgit_pytorch/vqvae/vqvae.py new file mode 100644 index 0000000..155bc7e --- /dev/null +++ b/muse_maskgit_pytorch/vqvae/vqvae.py @@ -0,0 +1,48 @@ +import logging + +import torch +import torch.nn as nn +from diffusers import ConfigMixin, ModelMixin +from diffusers.configuration_utils import register_to_config + +from .layers import Decoder, Encoder +from .quantize import VectorQuantize + +logger = logging.getLogger(__name__) + + +class VQVAE(ModelMixin, ConfigMixin): + @register_to_config + def __init__(self, n_embed, embed_dim, beta, enc, dec, **kwargs): + super().__init__() + self.encoder = Encoder(**enc) + self.decoder = Decoder(**dec) + + self.prev_quant = nn.Linear(enc["dim"], embed_dim) + self.quantize = VectorQuantize(n_embed, embed_dim, beta) + self.post_quant = nn.Linear(embed_dim, dec["dim"]) + + def freeze(self): + self.eval() + self.requires_grad_(False) + + def encode(self, x): + x = self.encoder(x) + x = self.prev_quant(x) + x, loss, indices = self.quantize(x) + return x, loss, indices + + def decode(self, x): + x = self.post_quant(x) + x = self.decoder(x) + return x.clamp(-1.0, 1.0) + + def forward(self, img: torch.FloatTensor): + z, loss, indices = self.encode(img) + rec = self.decode(z) + return rec, loss + + def decode_from_ids(self, indice): + z_q = self.quantize.get_codebook_entry(indice) + img = self.decode(z_q) + return img diff --git a/scripts/vqvae_test.py b/scripts/vqvae_test.py new file mode 100644 index 0000000..36f23d1 --- /dev/null +++ b/scripts/vqvae_test.py @@ -0,0 +1,78 @@ +import logging +from pathlib import Path + +import torch +from huggingface_hub import hf_hub_download +from PIL import Image +from torchvision import transforms as T +from torchvision.utils import save_image + +from muse_maskgit_pytorch.vqvae import VQVAE + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# where to find the model and the test images +model_repo = "neggles/vaedump" +model_subdir = "vit-s-vqgan-f4" +test_images = ["testimg_1.png", "testimg_2.png"] + +# where to save the preprocessed and reconstructed images +image_dir = Path.cwd().joinpath("temp") +image_dir.mkdir(exist_ok=True, parents=True) + +# image transforms for the VQVAE +transform_enc = T.Compose([T.Resize(512), T.RandomCrop(256), T.ToTensor()]) +transform_dec = T.Compose([T.ConvertImageDtype(torch.uint8), T.ToPILImage()]) + + +def get_save_path(path: Path, append: str) -> Path: + return path.with_stem(f"{path.stem}-{append}") + + +def main(): + torch_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + dtype = torch.float32 + + # load VAE + logger.info(f"Loading VQVAE from {model_repo}/{model_subdir}...") + vae: VQVAE = VQVAE.from_pretrained(model_repo, subfolder=model_subdir, torch_dtype=dtype) + vae = vae.to(torch_device) + logger.info(f"Loaded VQVAE from {model_repo} to {vae.device} with dtype {vae.dtype}") + + # download and process images + for image in test_images: + image_path = hf_hub_download(model_repo, subfolder="images", filename=image, output_dir=image_dir) + image_path = Path(image_path) + logger.info(f"Downloaded {image_path}, size {image_path.stat().st_size} bytes") + + # 
preprocess + image_obj = Image.open(image_path).convert("RGB") + image_tensor: torch.Tensor = transform_enc(image_obj) + save_path = get_save_path(image_path, "prepro") + save_image(image_tensor, save_path, normalize=True, range=(-1.0, 1.0)) + logger.info(f"Saved preprocessed image to {save_path}") + + # encode + encoded, _, _ = vae.encode(image_tensor.unsqueeze(0).to(vae.device)) + + # decode + reconstructed = vae.decode(encoded).squeeze(0) + reconstructed = torch.clamp(reconstructed, -1.0, 1.0) + + # save + save_path = get_save_path(image_path, "recon") + save_image(reconstructed, save_path, normalize=True, range=(-1.0, 1.0)) + logger.info(f"Saved reconstructed image to {save_path}") + + # compare + image_prepro = transform_dec(image_tensor) + image_recon = transform_dec(reconstructed) + canvas = Image.new("RGB", (512 + 12, 256 + 8), (248, 248, 242)) + canvas.paste(image_prepro, (4, 4)) + canvas.paste(image_recon, (256 + 8, 4)) + save_path = get_save_path(image_path, "compare") + canvas.save(save_path) + logger.info(f"Saved comparison image to {save_path}") + + logger.info("Done!") From 2361789dc0d47766017138aebb68b88051aca62e Mon Sep 17 00:00:00 2001 From: Andrew Powers-Holmes Date: Thu, 15 Jun 2023 12:06:04 +1000 Subject: [PATCH 08/16] fix scripts/vqvae_test.py --- scripts/vqvae_test.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/scripts/vqvae_test.py b/scripts/vqvae_test.py index 36f23d1..1349fb6 100644 --- a/scripts/vqvae_test.py +++ b/scripts/vqvae_test.py @@ -27,7 +27,9 @@ def get_save_path(path: Path, append: str) -> Path: - return path.with_stem(f"{path.stem}-{append}") + # append a string to the filename before the extension + # n.b. only keeps the final suffix, e.g. "foo.xyz.png" -> "foo-prepro.png" + return path.with_name(f"{path.stem}-{append}{path.suffix}") def main(): @@ -42,7 +44,7 @@ def main(): # download and process images for image in test_images: - image_path = hf_hub_download(model_repo, subfolder="images", filename=image, output_dir=image_dir) + image_path = hf_hub_download(model_repo, subfolder="images", filename=image, local_dir=image_dir) image_path = Path(image_path) logger.info(f"Downloaded {image_path}, size {image_path.stat().st_size} bytes") @@ -76,3 +78,7 @@ def main(): logger.info(f"Saved comparison image to {save_path}") logger.info("Done!") + + +if __name__ == "__main__": + main() From 4a3c5aec6525a6409757a7a950c80d4aac759f7f Mon Sep 17 00:00:00 2001 From: Andrew Powers-Holmes Date: Thu, 15 Jun 2023 12:21:25 +1000 Subject: [PATCH 09/16] vqvae: rename some components and adjust some formatting --- muse_maskgit_pytorch/modules/attention.py | 22 ++++++++++++++++++---- muse_maskgit_pytorch/vqvae/config.py | 2 -- muse_maskgit_pytorch/vqvae/layers.py | 2 -- muse_maskgit_pytorch/vqvae/quantize.py | 2 +- muse_maskgit_pytorch/vqvae/vqvae.py | 17 ++++++++++++----- 5 files changed, 31 insertions(+), 14 deletions(-) diff --git a/muse_maskgit_pytorch/modules/attention.py b/muse_maskgit_pytorch/modules/attention.py index 5945278..c12e18d 100644 --- a/muse_maskgit_pytorch/modules/attention.py +++ b/muse_maskgit_pytorch/modules/attention.py @@ -1,5 +1,5 @@ from inspect import isfunction -from typing import Any, Optional +from typing import Any, Callable, Optional from einops import rearrange from torch import nn @@ -21,7 +21,14 @@ def default(val, d): class CrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + def __init__( + self, + query_dim, + context_dim=None, + heads=8, + 
dim_head=64, + dropout=0.0, + ): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) @@ -56,7 +63,14 @@ def forward(self, x, context=None): class MemoryEfficientCrossAttention(nn.Module): # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + def __init__( + self, + query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.0, + ): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) @@ -69,7 +83,7 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0. self.to_v = nn.Linear(context_dim, inner_dim, bias=False) self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) - self.attention_op: Optional[Any] = None + self.attention_op: Optional[Callable] = None def forward(self, x, context=None): q = self.to_q(x) diff --git a/muse_maskgit_pytorch/vqvae/config.py b/muse_maskgit_pytorch/vqvae/config.py index dadd8ba..ed2605a 100644 --- a/muse_maskgit_pytorch/vqvae/config.py +++ b/muse_maskgit_pytorch/vqvae/config.py @@ -1,5 +1,3 @@ -from typing import Dict, List, Optional, Tuple, Union - from pydantic import BaseModel, Field diff --git a/muse_maskgit_pytorch/vqvae/layers.py b/muse_maskgit_pytorch/vqvae/layers.py index c0278fb..80a6cd5 100644 --- a/muse_maskgit_pytorch/vqvae/layers.py +++ b/muse_maskgit_pytorch/vqvae/layers.py @@ -78,7 +78,6 @@ def __init__( num_head, mlp_dim, in_channels=3, - out_channels=3, dim_head=64, dropout=0.0, ): @@ -132,7 +131,6 @@ def __init__( depth, num_head, mlp_dim, - in_channels=3, out_channels=3, dim_head=64, dropout=0.0, diff --git a/muse_maskgit_pytorch/vqvae/quantize.py b/muse_maskgit_pytorch/vqvae/quantize.py index 97453ad..fd81ea7 100644 --- a/muse_maskgit_pytorch/vqvae/quantize.py +++ b/muse_maskgit_pytorch/vqvae/quantize.py @@ -37,7 +37,7 @@ def forward(self, z): return z_q, loss, encoding_indices - def get_codebook_entry(self, indices): + def decode_ids(self, indices): z_q = self.embedding(indices) z_q = F.normalize(z_q, p=2, dim=-1) diff --git a/muse_maskgit_pytorch/vqvae/vqvae.py b/muse_maskgit_pytorch/vqvae/vqvae.py index 155bc7e..68320df 100644 --- a/muse_maskgit_pytorch/vqvae/vqvae.py +++ b/muse_maskgit_pytorch/vqvae/vqvae.py @@ -19,7 +19,7 @@ def __init__(self, n_embed, embed_dim, beta, enc, dec, **kwargs): self.decoder = Decoder(**dec) self.prev_quant = nn.Linear(enc["dim"], embed_dim) - self.quantize = VectorQuantize(n_embed, embed_dim, beta) + self.quantizer = VectorQuantize(n_embed, embed_dim, beta) self.post_quant = nn.Linear(embed_dim, dec["dim"]) def freeze(self): @@ -29,7 +29,7 @@ def freeze(self): def encode(self, x): x = self.encoder(x) x = self.prev_quant(x) - x, loss, indices = self.quantize(x) + x, loss, indices = self.quantizer(x) return x, loss, indices def decode(self, x): @@ -37,12 +37,19 @@ def decode(self, x): x = self.decoder(x) return x.clamp(-1.0, 1.0) - def forward(self, img: torch.FloatTensor): - z, loss, indices = self.encode(img) + def forward(self, inputs: torch.FloatTensor): + z, loss, _ = self.encode(inputs) rec = self.decode(z) return rec, loss + def encode_to_ids(self, inputs): + _, _, indices = self.encode(inputs) + return indices + def decode_from_ids(self, indice): - z_q = self.quantize.get_codebook_entry(indice) + z_q = self.quantizer.decode_ids(indice) img = self.decode(z_q) return img + + def __call__(self, inputs: 
torch.FloatTensor): + return self.forward(inputs) From 0e07e175aeccaa1034bf41d4417bc5271e0a8490 Mon Sep 17 00:00:00 2001 From: Andrew Powers-Holmes Date: Thu, 15 Jun 2023 12:51:15 +1000 Subject: [PATCH 10/16] fix vqvae weight load --- muse_maskgit_pytorch/vqvae/vqvae.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/muse_maskgit_pytorch/vqvae/vqvae.py b/muse_maskgit_pytorch/vqvae/vqvae.py index 68320df..23a31c6 100644 --- a/muse_maskgit_pytorch/vqvae/vqvae.py +++ b/muse_maskgit_pytorch/vqvae/vqvae.py @@ -19,7 +19,7 @@ def __init__(self, n_embed, embed_dim, beta, enc, dec, **kwargs): self.decoder = Decoder(**dec) self.prev_quant = nn.Linear(enc["dim"], embed_dim) - self.quantizer = VectorQuantize(n_embed, embed_dim, beta) + self.quantize = VectorQuantize(n_embed, embed_dim, beta) self.post_quant = nn.Linear(embed_dim, dec["dim"]) def freeze(self): @@ -29,7 +29,7 @@ def freeze(self): def encode(self, x): x = self.encoder(x) x = self.prev_quant(x) - x, loss, indices = self.quantizer(x) + x, loss, indices = self.quantize(x) return x, loss, indices def decode(self, x): @@ -47,7 +47,7 @@ def encode_to_ids(self, inputs): return indices def decode_from_ids(self, indice): - z_q = self.quantizer.decode_ids(indice) + z_q = self.quantize.decode_ids(indice) img = self.decode(z_q) return img From e51716b10e4c4f5a519d1c050eca00fc73ab7831 Mon Sep 17 00:00:00 2001 From: Alejandro Gil Date: Thu, 15 Jun 2023 01:41:43 -0700 Subject: [PATCH 11/16] Fixed issue where the training would stop early than intended. --- muse_maskgit_pytorch/trainers/maskgit_trainer.py | 2 +- muse_maskgit_pytorch/trainers/vqvae_trainers.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/muse_maskgit_pytorch/trainers/maskgit_trainer.py b/muse_maskgit_pytorch/trainers/maskgit_trainer.py index cba2ba2..d1f7939 100644 --- a/muse_maskgit_pytorch/trainers/maskgit_trainer.py +++ b/muse_maskgit_pytorch/trainers/maskgit_trainer.py @@ -263,7 +263,7 @@ def train(self): self.steps += 1 - if self.num_train_steps > 0 and self.steps >= int(self.steps.item()): + if self.num_train_steps > 0 and int(self.steps.item()) >= self.num_train_steps: if self.on_tpu: self.accelerator.print( f"\n[E{epoch + 1}][{int(self.steps.item())}]{proc_label}" diff --git a/muse_maskgit_pytorch/trainers/vqvae_trainers.py b/muse_maskgit_pytorch/trainers/vqvae_trainers.py index 7768622..f082232 100644 --- a/muse_maskgit_pytorch/trainers/vqvae_trainers.py +++ b/muse_maskgit_pytorch/trainers/vqvae_trainers.py @@ -340,9 +340,9 @@ def train(self): self.steps += 1 - if self.num_train_steps > 0 and self.steps >= int(self.steps.item()): + if self.num_train_steps > 0 and int(self.steps.item()) >= self.num_train_steps: self.accelerator.print( - f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: " f"[STOP EARLY]: Stopping training early..." + f"\n[E{epoch + 1}][{steps}]{proc_label}: " f"[STOP EARLY]: Stopping training early..." ) break From 6daee002da246ec76aef81f75f230f713b4d31a6 Mon Sep 17 00:00:00 2001 From: Alejandro Gil Date: Fri, 16 Jun 2023 04:47:05 -0700 Subject: [PATCH 12/16] Changed the pretrained "pretrained=True" to be "weights=torchvision.models.VGG16_Weights.DEFAULT" so it is compatible with latest version of torchvision and also so we always get the latest version of the weights. 
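For reference, the torchvision call this change moves to looks like the short
sketch below (a minimal sketch, assuming torchvision >= 0.13; the classifier
slicing mirrors the line that already follows in the same `vgg` property):

    import torchvision
    from torch import nn

    # Deprecated since torchvision 0.13:
    #   vgg = torchvision.models.vgg16(pretrained=True)
    # Enum-based replacement; DEFAULT resolves to the newest released weights.
    vgg = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)
    # Keep everything but the last two classifier layers so the model returns
    # features rather than class logits, as vqgan_vae.py already does.
    vgg.classifier = nn.Sequential(*vgg.classifier[:-2])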
--- muse_maskgit_pytorch/vqgan_vae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/muse_maskgit_pytorch/vqgan_vae.py b/muse_maskgit_pytorch/vqgan_vae.py index f44814a..56c58f2 100644 --- a/muse_maskgit_pytorch/vqgan_vae.py +++ b/muse_maskgit_pytorch/vqgan_vae.py @@ -407,7 +407,7 @@ def vgg(self): if exists(self._vgg): return self._vgg - vgg = torchvision.models.vgg16(pretrained=True) + vgg = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT) vgg.classifier = nn.Sequential(*vgg.classifier[:-2]) self._vgg = vgg.to(self.device) return self._vgg From 26f38e57d9e4dc3006cf7c7441bac1be93c81764 Mon Sep 17 00:00:00 2001 From: Alejandro Gil Date: Sat, 17 Jun 2023 10:07:02 -0700 Subject: [PATCH 13/16] Create FUNDING.yml --- .github/.github/FUNDING.yml | 1 + 1 file changed, 1 insertion(+) create mode 100644 .github/.github/FUNDING.yml diff --git a/.github/.github/FUNDING.yml b/.github/.github/FUNDING.yml new file mode 100644 index 0000000..2b7c680 --- /dev/null +++ b/.github/.github/FUNDING.yml @@ -0,0 +1 @@ +github: [ZeroCool940711] From 0a8847fe2c61ad53ea9de4740076f07c2c2c1574 Mon Sep 17 00:00:00 2001 From: Alejandro Gil Date: Sat, 17 Jun 2023 10:08:53 -0700 Subject: [PATCH 14/16] Removed some code that was causing the training to stop early with some multi processes scenarios. --- .../trainers/maskgit_trainer.py | 22 +++++++++---------- .../trainers/vqvae_trainers.py | 12 +++++----- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/muse_maskgit_pytorch/trainers/maskgit_trainer.py b/muse_maskgit_pytorch/trainers/maskgit_trainer.py index d1f7939..b4c83c8 100644 --- a/muse_maskgit_pytorch/trainers/maskgit_trainer.py +++ b/muse_maskgit_pytorch/trainers/maskgit_trainer.py @@ -263,17 +263,17 @@ def train(self): self.steps += 1 - if self.num_train_steps > 0 and int(self.steps.item()) >= self.num_train_steps: - if self.on_tpu: - self.accelerator.print( - f"\n[E{epoch + 1}][{int(self.steps.item())}]{proc_label}" - f"[STOP EARLY]: Stopping training early..." - ) - else: - self.info_bar.set_description_str( - f"[E{epoch + 1}]{proc_label}" f"[STOP EARLY]: Stopping training early..." - ) - break + #if self.num_train_steps > 0 and int(self.steps.item()) >= self.num_train_steps: + #if self.on_tpu: + #self.accelerator.print( + #f"\n[E{epoch + 1}][{int(self.steps.item())}]{proc_label}" + #f"[STOP EARLY]: Stopping training early..." + #) + #else: + #self.info_bar.set_description_str( + #f"[E{epoch + 1}]{proc_label}" f"[STOP EARLY]: Stopping training early..." + #) + #break # loop complete, save final model self.accelerator.print( diff --git a/muse_maskgit_pytorch/trainers/vqvae_trainers.py b/muse_maskgit_pytorch/trainers/vqvae_trainers.py index f082232..3d931bd 100644 --- a/muse_maskgit_pytorch/trainers/vqvae_trainers.py +++ b/muse_maskgit_pytorch/trainers/vqvae_trainers.py @@ -305,7 +305,7 @@ def train(self): if (steps % self.save_results_every) == 0: self.accelerator.print( - f"\n[E{epoch + 1}][{steps}] | Logging validation images to {str(self.results_dir)}\n" + f"\n[E{epoch + 1}][{steps}] | Logging validation images to {str(self.results_dir)}" ) self.log_validation_images(logs, steps) @@ -340,11 +340,11 @@ def train(self): self.steps += 1 - if self.num_train_steps > 0 and int(self.steps.item()) >= self.num_train_steps: - self.accelerator.print( - f"\n[E{epoch + 1}][{steps}]{proc_label}: " f"[STOP EARLY]: Stopping training early..." 
- ) - break + #if self.num_train_steps > 0 and int(self.steps.item()) >= self.num_train_steps: + #self.accelerator.print( + #f"\n[E{epoch + 1}][{steps}]{proc_label}: " f"[STOP EARLY]: Stopping training early..." + #) + #break # Loop finished, save model self.accelerator.wait_for_everyone() From 5048d1507d15c3ffe0f2c9b5d56cce26a6050f34 Mon Sep 17 00:00:00 2001 From: Alejandro Gil Date: Sat, 17 Jun 2023 10:09:44 -0700 Subject: [PATCH 15/16] Added options to the infer_vae.py for using the paintmind vae, this is a WIP and do not work properly. --- infer_vae.py | 41 ++++++++++++++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/infer_vae.py b/infer_vae.py index 88aea18..cbf84cf 100644 --- a/infer_vae.py +++ b/infer_vae.py @@ -25,6 +25,7 @@ ImageDataset, get_dataset_from_dataroot, ) +from muse_maskgit_pytorch.vqvae import VQVAE # Create the parser parser = argparse.ArgumentParser() @@ -189,6 +190,11 @@ action="store_true", help="Use the latest checkpoint using the vae_path folder instead of using just a specific vae_path.", ) +parser.add_argument( + "--use_paintmind", + action="store_true", + help="Use PaintMind VAE..", +) @dataclass @@ -336,7 +342,7 @@ def main(): if args.vae_path and args.taming_model_path: raise Exception("You can't pass vae_path and taming args at the same time.") - if args.vae_path: + if args.vae_path and not args.use_paintmind: accelerator.print("Loading Muse VQGanVAE") vae = VQGanVAE( dim=args.dim, vq_codebook_size=args.vq_codebook_size, vq_codebook_dim=args.vq_codebook_dim @@ -390,6 +396,11 @@ def main(): vae.load(args.vae_path) + if args.use_paintmind: + # load VAE + accelerator.print(f"Loading VQVAE from 'neggles/vaedump/vit-s-vqgan-f4' ...") + vae: VQVAE = VQVAE.from_pretrained("neggles/vaedump", subfolder="vit-s-vqgan-f4") + elif args.taming_model_path: print("Loading Taming VQGanVAE") vae = VQGanVAETaming( @@ -398,7 +409,10 @@ def main(): ) args.num_tokens = vae.codebook_size args.seq_len = vae.get_encoded_fmap_size(args.image_size) ** 2 + + # move vae to device vae = vae.to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}") + # then you plug the vae and transformer into your MaskGit as so dataset = ImageDataset( @@ -449,11 +463,21 @@ def main(): try: save_image(dataset[i], f"{output_dir}/input.png") - _, ids, _ = vae.encode( - dataset[i][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}") - ) - recon = vae.decode_from_ids(ids) - save_image(recon, f"{output_dir}/output.png") + if not args.use_paintmind: + # encode + _, ids, _ = vae.encode(dataset[i][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")) + # decode + recon = vae.decode_from_ids(ids) + #print (recon.shape) # torch.Size([1, 3, 512, 1136]) + save_image(recon, f"{output_dir}/output.png") + else: + # encode + encoded, _, _ = vae.encode(dataset[i][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")) + + # decode + recon = vae.decode(encoded).squeeze(0) + recon = torch.clamp(recon, -1.0, 1.0) + save_image(recon, f"{output_dir}/output.png") # Load input and output images input_image = PIL.Image.open(f"{output_dir}/input.png") @@ -495,7 +519,10 @@ def main(): continue # Retry the loop else: - print(f"Skipping image {i} after {retries} retries due to out of memory error") + if"out of memory" not in str(e): + print(e) + else: + print(f"Skipping image {i} after {retries} retries due to out of memory error") break # Exit the retry loop after too many retries From 3b871d747a974cb0a033e9cc3e3d71812e662157 Mon Sep 
17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 17 Jun 2023 17:13:16 +0000 Subject: [PATCH 16/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- infer_vae.py | 14 +++++++----- .../trainers/maskgit_trainer.py | 22 +++++++++---------- .../trainers/vqvae_trainers.py | 10 ++++----- 3 files changed, 25 insertions(+), 21 deletions(-) diff --git a/infer_vae.py b/infer_vae.py index cbf84cf..b6290bd 100644 --- a/infer_vae.py +++ b/infer_vae.py @@ -398,7 +398,7 @@ def main(): if args.use_paintmind: # load VAE - accelerator.print(f"Loading VQVAE from 'neggles/vaedump/vit-s-vqgan-f4' ...") + accelerator.print("Loading VQVAE from 'neggles/vaedump/vit-s-vqgan-f4' ...") vae: VQVAE = VQVAE.from_pretrained("neggles/vaedump", subfolder="vit-s-vqgan-f4") elif args.taming_model_path: @@ -465,14 +465,18 @@ def main(): if not args.use_paintmind: # encode - _, ids, _ = vae.encode(dataset[i][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")) + _, ids, _ = vae.encode( + dataset[i][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}") + ) # decode recon = vae.decode_from_ids(ids) - #print (recon.shape) # torch.Size([1, 3, 512, 1136]) + # print (recon.shape) # torch.Size([1, 3, 512, 1136]) save_image(recon, f"{output_dir}/output.png") else: # encode - encoded, _, _ = vae.encode(dataset[i][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")) + encoded, _, _ = vae.encode( + dataset[i][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}") + ) # decode recon = vae.decode(encoded).squeeze(0) @@ -519,7 +523,7 @@ def main(): continue # Retry the loop else: - if"out of memory" not in str(e): + if "out of memory" not in str(e): print(e) else: print(f"Skipping image {i} after {retries} retries due to out of memory error") diff --git a/muse_maskgit_pytorch/trainers/maskgit_trainer.py b/muse_maskgit_pytorch/trainers/maskgit_trainer.py index b4c83c8..69eaeb5 100644 --- a/muse_maskgit_pytorch/trainers/maskgit_trainer.py +++ b/muse_maskgit_pytorch/trainers/maskgit_trainer.py @@ -263,17 +263,17 @@ def train(self): self.steps += 1 - #if self.num_train_steps > 0 and int(self.steps.item()) >= self.num_train_steps: - #if self.on_tpu: - #self.accelerator.print( - #f"\n[E{epoch + 1}][{int(self.steps.item())}]{proc_label}" - #f"[STOP EARLY]: Stopping training early..." - #) - #else: - #self.info_bar.set_description_str( - #f"[E{epoch + 1}]{proc_label}" f"[STOP EARLY]: Stopping training early..." - #) - #break + # if self.num_train_steps > 0 and int(self.steps.item()) >= self.num_train_steps: + # if self.on_tpu: + # self.accelerator.print( + # f"\n[E{epoch + 1}][{int(self.steps.item())}]{proc_label}" + # f"[STOP EARLY]: Stopping training early..." + # ) + # else: + # self.info_bar.set_description_str( + # f"[E{epoch + 1}]{proc_label}" f"[STOP EARLY]: Stopping training early..." + # ) + # break # loop complete, save final model self.accelerator.print( diff --git a/muse_maskgit_pytorch/trainers/vqvae_trainers.py b/muse_maskgit_pytorch/trainers/vqvae_trainers.py index 3d931bd..f857171 100644 --- a/muse_maskgit_pytorch/trainers/vqvae_trainers.py +++ b/muse_maskgit_pytorch/trainers/vqvae_trainers.py @@ -340,11 +340,11 @@ def train(self): self.steps += 1 - #if self.num_train_steps > 0 and int(self.steps.item()) >= self.num_train_steps: - #self.accelerator.print( - #f"\n[E{epoch + 1}][{steps}]{proc_label}: " f"[STOP EARLY]: Stopping training early..." 
- #) - #break + # if self.num_train_steps > 0 and int(self.steps.item()) >= self.num_train_steps: + # self.accelerator.print( + # f"\n[E{epoch + 1}][{steps}]{proc_label}: " f"[STOP EARLY]: Stopping training early..." + # ) + # break # Loop finished, save model self.accelerator.wait_for_everyone()
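Taken together, patches 07-16 add a ViT-based VQ-VAE under
muse_maskgit_pytorch/vqvae alongside the existing VQGanVAE, wire it into
infer_vae.py behind --use_paintmind (still marked WIP there), and ship
scripts/vqvae_test.py as a smoke test. A minimal round trip with the new class
looks roughly like the sketch below; it follows scripts/vqvae_test.py, so the
neggles/vaedump checkpoint, the vit-s-vqgan-f4 subfolder, and example.png are
assumptions rather than fixed requirements:

    import torch
    from PIL import Image
    from torchvision import transforms as T

    from muse_maskgit_pytorch.vqvae import VQVAE

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Checkpoint location copied from scripts/vqvae_test.py; any compatible
    # diffusers-style VQVAE checkpoint should load the same way.
    vae: VQVAE = VQVAE.from_pretrained("neggles/vaedump", subfolder="vit-s-vqgan-f4")
    vae = vae.to(device)
    vae.freeze()  # eval() + requires_grad_(False), as defined in vqvae.py

    # 256x256 input matches the ViT-S config (image_size=256, patch_size=8).
    to_tensor = T.Compose([T.Resize(256), T.CenterCrop(256), T.ToTensor()])
    image = to_tensor(Image.open("example.png").convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        z, vq_loss, ids = vae.encode(image)   # latents, commitment loss, codebook ids
        recon = vae.decode(z)                 # reconstruct from continuous latents
        recon_ids = vae.decode_from_ids(ids)  # or from the discrete codebook ids

    print(recon.shape, recon_ids.shape)  # both (1, 3, 256, 256)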