diff --git a/.github/.github/FUNDING.yml b/.github/.github/FUNDING.yml new file mode 100644 index 0000000..2b7c680 --- /dev/null +++ b/.github/.github/FUNDING.yml @@ -0,0 +1 @@ +github: [ZeroCool940711] diff --git a/.gitignore b/.gitignore index 7d149a3..d382f6b 100644 --- a/.gitignore +++ b/.gitignore @@ -141,3 +141,10 @@ dmypy.json # setuptools-scm version file muse_maskgit_pytorch/_version.py + +# wandb dir +/wandb/ + +# data, output +/data/ +/output/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7f63f19..5174c96 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,11 @@ # See https://pre-commit.com for more information # See https://pre-commit.com/hooks.html for more hooks +ci: + autofix_prs: true + autoupdate_branch: 'dev' + autoupdate_commit_msg: '[pre-commit.ci] pre-commit autoupdate' + autoupdate_schedule: weekly + repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 diff --git a/infer_vae.py b/infer_vae.py index 88aea18..b6290bd 100644 --- a/infer_vae.py +++ b/infer_vae.py @@ -25,6 +25,7 @@ ImageDataset, get_dataset_from_dataroot, ) +from muse_maskgit_pytorch.vqvae import VQVAE # Create the parser parser = argparse.ArgumentParser() @@ -189,6 +190,11 @@ action="store_true", help="Use the latest checkpoint using the vae_path folder instead of using just a specific vae_path.", ) +parser.add_argument( + "--use_paintmind", + action="store_true", + help="Use PaintMind VAE..", +) @dataclass @@ -336,7 +342,7 @@ def main(): if args.vae_path and args.taming_model_path: raise Exception("You can't pass vae_path and taming args at the same time.") - if args.vae_path: + if args.vae_path and not args.use_paintmind: accelerator.print("Loading Muse VQGanVAE") vae = VQGanVAE( dim=args.dim, vq_codebook_size=args.vq_codebook_size, vq_codebook_dim=args.vq_codebook_dim @@ -390,6 +396,11 @@ def main(): vae.load(args.vae_path) + if args.use_paintmind: + # load VAE + accelerator.print("Loading VQVAE from 'neggles/vaedump/vit-s-vqgan-f4' ...") + vae: VQVAE = VQVAE.from_pretrained("neggles/vaedump", subfolder="vit-s-vqgan-f4") + elif args.taming_model_path: print("Loading Taming VQGanVAE") vae = VQGanVAETaming( @@ -398,7 +409,10 @@ def main(): ) args.num_tokens = vae.codebook_size args.seq_len = vae.get_encoded_fmap_size(args.image_size) ** 2 + + # move vae to device vae = vae.to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}") + # then you plug the vae and transformer into your MaskGit as so dataset = ImageDataset( @@ -449,11 +463,25 @@ def main(): try: save_image(dataset[i], f"{output_dir}/input.png") - _, ids, _ = vae.encode( - dataset[i][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}") - ) - recon = vae.decode_from_ids(ids) - save_image(recon, f"{output_dir}/output.png") + if not args.use_paintmind: + # encode + _, ids, _ = vae.encode( + dataset[i][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}") + ) + # decode + recon = vae.decode_from_ids(ids) + # print (recon.shape) # torch.Size([1, 3, 512, 1136]) + save_image(recon, f"{output_dir}/output.png") + else: + # encode + encoded, _, _ = vae.encode( + dataset[i][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}") + ) + + # decode + recon = vae.decode(encoded).squeeze(0) + recon = torch.clamp(recon, -1.0, 1.0) + save_image(recon, f"{output_dir}/output.png") # Load input and output images input_image = PIL.Image.open(f"{output_dir}/input.png") @@ -495,7 +523,10 @@ def main(): continue # Retry the loop else: - 
print(f"Skipping image {i} after {retries} retries due to out of memory error") + if "out of memory" not in str(e): + print(e) + else: + print(f"Skipping image {i} after {retries} retries due to out of memory error") break # Exit the retry loop after too many retries diff --git a/muse_maskgit_pytorch/modules/__init__.py b/muse_maskgit_pytorch/modules/__init__.py new file mode 100644 index 0000000..eb619ff --- /dev/null +++ b/muse_maskgit_pytorch/modules/__init__.py @@ -0,0 +1,10 @@ +from .attention import CrossAttention, MemoryEfficientCrossAttention +from .mlp import SwiGLU, SwiGLUFFN, SwiGLUFFNFused + +__all__ = [ + "SwiGLU", + "SwiGLUFFN", + "SwiGLUFFNFused", + "CrossAttention", + "MemoryEfficientCrossAttention", +] diff --git a/muse_maskgit_pytorch/modules/attention.py b/muse_maskgit_pytorch/modules/attention.py new file mode 100644 index 0000000..c12e18d --- /dev/null +++ b/muse_maskgit_pytorch/modules/attention.py @@ -0,0 +1,112 @@ +from inspect import isfunction +from typing import Any, Callable, Optional + +from einops import rearrange +from torch import nn + +try: + from xformers.ops import memory_efficient_attention +except ImportError: + memory_efficient_attention = None + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +class CrossAttention(nn.Module): + def __init__( + self, + query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.0, + ): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + + def forward(self, x, context=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) + q = q * self.scale + + sim = q @ k.transpose(-2, -1) + sim = sim.softmax(dim=-1) + + out = sim @ v + out = rearrange(out, "(b h) n d -> b n (h d)", h=h) + return self.to_out(out) + + +class MemoryEfficientCrossAttention(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__( + self, + query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.0, + ): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.heads = heads + self.dim_head = dim_head + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.attention_op: Optional[Callable] = None + + def forward(self, x, context=None): + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + out = memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + out = ( + 
out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + return self.to_out(out) diff --git a/muse_maskgit_pytorch/modules/mlp.py b/muse_maskgit_pytorch/modules/mlp.py new file mode 100644 index 0000000..7dd35dd --- /dev/null +++ b/muse_maskgit_pytorch/modules/mlp.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch.nn.functional as F +from torch import Tensor, nn + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +try: + from xformers.ops import SwiGLU +except ImportError: + SwiGLU = SwiGLUFFN + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/muse_maskgit_pytorch/trainers/maskgit_trainer.py b/muse_maskgit_pytorch/trainers/maskgit_trainer.py index cba2ba2..69eaeb5 100644 --- a/muse_maskgit_pytorch/trainers/maskgit_trainer.py +++ b/muse_maskgit_pytorch/trainers/maskgit_trainer.py @@ -263,17 +263,17 @@ def train(self): self.steps += 1 - if self.num_train_steps > 0 and self.steps >= int(self.steps.item()): - if self.on_tpu: - self.accelerator.print( - f"\n[E{epoch + 1}][{int(self.steps.item())}]{proc_label}" - f"[STOP EARLY]: Stopping training early..." - ) - else: - self.info_bar.set_description_str( - f"[E{epoch + 1}]{proc_label}" f"[STOP EARLY]: Stopping training early..." - ) - break + # if self.num_train_steps > 0 and int(self.steps.item()) >= self.num_train_steps: + # if self.on_tpu: + # self.accelerator.print( + # f"\n[E{epoch + 1}][{int(self.steps.item())}]{proc_label}" + # f"[STOP EARLY]: Stopping training early..." + # ) + # else: + # self.info_bar.set_description_str( + # f"[E{epoch + 1}]{proc_label}" f"[STOP EARLY]: Stopping training early..." 
+ # ) + # break # loop complete, save final model self.accelerator.print( diff --git a/muse_maskgit_pytorch/trainers/vqvae_trainers.py b/muse_maskgit_pytorch/trainers/vqvae_trainers.py index 5824ddc..f857171 100644 --- a/muse_maskgit_pytorch/trainers/vqvae_trainers.py +++ b/muse_maskgit_pytorch/trainers/vqvae_trainers.py @@ -305,16 +305,15 @@ def train(self): if (steps % self.save_results_every) == 0: self.accelerator.print( - f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: saving to {str(self.results_dir)}" + f"\n[E{epoch + 1}][{steps}] | Logging validation images to {str(self.results_dir)}" ) + self.log_validation_images(logs, steps) # save model every so often self.accelerator.wait_for_everyone() if self.is_main_process and (steps % self.save_model_every) == 0: - self.accelerator.print( - f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: saving model to {str(self.results_dir)}" - ) + self.accelerator.print(f"\nStep: {steps} | Saving model to {str(self.results_dir)}") state_dict = self.accelerator.unwrap_model(self.model).state_dict() file_name = f"vae.{steps}.pt" if not self.only_save_last_checkpoint else "vae.pt" @@ -341,11 +340,11 @@ def train(self): self.steps += 1 - if self.num_train_steps > 0 and self.steps >= int(self.steps.item()): - self.accelerator.print( - f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: " f"[STOP EARLY]: Stopping training early..." - ) - break + # if self.num_train_steps > 0 and int(self.steps.item()) >= self.num_train_steps: + # self.accelerator.print( + # f"\n[E{epoch + 1}][{steps}]{proc_label}: " f"[STOP EARLY]: Stopping training early..." + # ) + # break # Loop finished, save model self.accelerator.wait_for_everyone() diff --git a/muse_maskgit_pytorch/vqgan_vae.py b/muse_maskgit_pytorch/vqgan_vae.py index f44814a..56c58f2 100644 --- a/muse_maskgit_pytorch/vqgan_vae.py +++ b/muse_maskgit_pytorch/vqgan_vae.py @@ -407,7 +407,7 @@ def vgg(self): if exists(self._vgg): return self._vgg - vgg = torchvision.models.vgg16(pretrained=True) + vgg = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT) vgg.classifier = nn.Sequential(*vgg.classifier[:-2]) self._vgg = vgg.to(self.device) return self._vgg diff --git a/muse_maskgit_pytorch/vqgan_vae_taming.py b/muse_maskgit_pytorch/vqgan_vae_taming.py index 2479161..1eeaf00 100644 --- a/muse_maskgit_pytorch/vqgan_vae_taming.py +++ b/muse_maskgit_pytorch/vqgan_vae_taming.py @@ -10,11 +10,10 @@ from accelerate import Accelerator from einops import rearrange from omegaconf import DictConfig, OmegaConf +from taming.models.vqgan import VQModel from torch import nn from tqdm_loggable.auto import tqdm -from taming.models.vqgan import VQModel - # constants CACHE_PATH = Path.home().joinpath(".cache/taming") diff --git a/muse_maskgit_pytorch/vqvae/__init__.py b/muse_maskgit_pytorch/vqvae/__init__.py new file mode 100644 index 0000000..9e7f389 --- /dev/null +++ b/muse_maskgit_pytorch/vqvae/__init__.py @@ -0,0 +1,7 @@ +from .config import VQVAEConfig +from .vqvae import VQVAE + +__all__ = [ + "VQVAE", + "VQVAEConfig", +] diff --git a/muse_maskgit_pytorch/vqvae/config.py b/muse_maskgit_pytorch/vqvae/config.py new file mode 100644 index 0000000..ed2605a --- /dev/null +++ b/muse_maskgit_pytorch/vqvae/config.py @@ -0,0 +1,62 @@ +from pydantic import BaseModel, Field + + +class EncoderConfig(BaseModel): + image_size: int = Field(...) + patch_size: int = Field(...) + dim: int = Field(...) + depth: int = Field(...) + num_head: int = Field(...) + mlp_dim: int = Field(...) + in_channels: int = Field(...) + dim_head: int = Field(...) 
+ dropout: float = Field(...) + + +class DecoderConfig(BaseModel): + image_size: int = Field(...) + patch_size: int = Field(...) + dim: int = Field(...) + depth: int = Field(...) + num_head: int = Field(...) + mlp_dim: int = Field(...) + out_channels: int = Field(...) + dim_head: int = Field(...) + dropout: float = Field(...) + + +class VQVAEConfig(BaseModel): + n_embed: int = Field(...) + embed_dim: int = Field(...) + beta: float = Field(...) + enc: EncoderConfig = Field(...) + dec: DecoderConfig = Field(...) + + +VIT_S_CONFIG = VQVAEConfig( + n_embed=8192, + embed_dim=32, + beta=0.25, + enc=EncoderConfig( + image_size=256, + patch_size=8, + dim=512, + depth=8, + num_head=8, + mlp_dim=2048, + in_channels=3, + dim_head=64, + dropout=0.0, + ), + dec=DecoderConfig( + image_size=256, + patch_size=8, + dim=512, + depth=8, + num_head=8, + mlp_dim=2048, + out_channels=3, + dim_head=64, + dropout=0.0, + ), +) diff --git a/muse_maskgit_pytorch/vqvae/discriminator.py b/muse_maskgit_pytorch/vqvae/discriminator.py new file mode 100644 index 0000000..fcf0e2a --- /dev/null +++ b/muse_maskgit_pytorch/vqvae/discriminator.py @@ -0,0 +1,80 @@ +import functools + +import torch.nn as nn + + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find("BatchNorm") != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) + + +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator""" + + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super().__init__() + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + kw = 4 + padw = 1 + sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2**n, 8) + sequence += [ + nn.Conv2d( + ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias + ), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True), + ] + + nf_mult_prev = nf_mult + nf_mult = min(2**n_layers, 8) + sequence += [ + nn.Conv2d( + ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias + ), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True), + ] + + sequence += [ + nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) + ] # output 1 channel prediction map + self.model = nn.Sequential(*sequence) + + self.apply(self.init_func) + + def forward(self, input): + """Standard forward.""" + return self.model(input) + + def init_func(self, m): # define the initialization function + init_gain = 0.02 + classname = m.__class__.__name__ + if hasattr(m, "weight") and (classname.find("Conv") != -1 or classname.find("Linear") != -1): + nn.init.normal_(m.weight.data, 0.0, init_gain) + if hasattr(m, "bias") and m.bias is not None: + nn.init.constant_(m.bias.data, 0.0) + elif ( + classname.find("BatchNorm2d") != -1 + ): # BatchNorm Layer's weight is not a 
matrix; only normal distribution applies. + nn.init.normal_(m.weight.data, 1.0, init_gain) + nn.init.constant_(m.bias.data, 0.0) diff --git a/muse_maskgit_pytorch/vqvae/layers.py b/muse_maskgit_pytorch/vqvae/layers.py new file mode 100644 index 0000000..80a6cd5 --- /dev/null +++ b/muse_maskgit_pytorch/vqvae/layers.py @@ -0,0 +1,179 @@ +import torch +from diffusers.utils import is_xformers_available +from einops import rearrange +from einops.layers.torch import Rearrange +from torch import nn + +from muse_maskgit_pytorch.modules import CrossAttention, MemoryEfficientCrossAttention, SwiGLUFFNFused + + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + + +class FeedForward(nn.Module): + def __init__(self, dim, mlp_dim, dropout=0.0): + super().__init__() + self.w_1 = nn.Linear(dim, mlp_dim) + self.act = nn.GELU() + self.dropout = nn.Dropout(p=dropout) + self.w_2 = nn.Linear(mlp_dim, dim) + + def forward(self, x): + x = self.w_1(x) + x = self.act(x) + x = self.dropout(x) + x = self.w_2(x) + + return x + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class Layer(nn.Module): + ATTENTION_MODES = {"vanilla": CrossAttention, "xformer": MemoryEfficientCrossAttention} + + def __init__(self, dim, dim_head, mlp_dim, num_head=8, dropout=0.0): + super().__init__() + attn_mode = "xformer" if is_xformers_available() else "vanilla" + attn_cls = self.ATTENTION_MODES[attn_mode] + self.norm1 = nn.LayerNorm(dim) + self.attn1 = attn_cls(query_dim=dim, heads=num_head, dim_head=dim_head, dropout=dropout) + self.norm2 = nn.LayerNorm(dim) + self.ffnet = SwiGLUFFNFused(in_features=dim, hidden_features=mlp_dim) + + def forward(self, x): + x = self.attn1(self.norm1(x)) + x + x = self.ffnet(self.norm2(x)) + x + + return x + + +class Transformer(nn.Module): + def __init__(self, dim, depth, num_head, dim_head, mlp_dim, dropout=0.0): + super().__init__() + self.layers = nn.Sequential(*[Layer(dim, dim_head, mlp_dim, num_head, dropout) for i in range(depth)]) + + def forward(self, x): + x = self.layers(x) + + return x + + +class Encoder(nn.Module): + def __init__( + self, + image_size, + patch_size, + dim, + depth, + num_head, + mlp_dim, + in_channels=3, + dim_head=64, + dropout=0.0, + ): + super().__init__() + + self.image_size = image_size + self.patch_size = patch_size + + assert image_size % patch_size == 0, "Image dimensions must be divisible by the patch size." 
+ + self.to_patch_embedding = nn.Sequential( + nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size, bias=False), + Rearrange("b c h w -> b (h w) c"), + ) + + scale = dim**-0.5 + num_patches = (image_size // patch_size) ** 2 + self.position_embedding = nn.Parameter(torch.randn(1, num_patches, dim) * scale) + self.norm_pre = nn.LayerNorm(dim) + self.transformer = Transformer(dim, depth, num_head, dim_head, mlp_dim, dropout) + + self.initialize_weights() + + def initialize_weights(self): + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x): + x = self.to_patch_embedding(x) + x = x + self.position_embedding + x = self.norm_pre(x) + x = self.transformer(x) + + return x + + +class Decoder(nn.Module): + def __init__( + self, + image_size, + patch_size, + dim, + depth, + num_head, + mlp_dim, + out_channels=3, + dim_head=64, + dropout=0.0, + ): + super().__init__() + + self.image_size = image_size + self.patch_size = patch_size + + assert image_size % patch_size == 0, "Image dimensions must be divisible by the patch size." + + scale = dim**-0.5 + num_patches = (image_size // patch_size) ** 2 + self.position_embedding = nn.Parameter(torch.randn(1, num_patches, dim) * scale) + self.transformer = Transformer(dim, depth, num_head, dim_head, mlp_dim, dropout) + self.norm = nn.LayerNorm(dim) + self.proj = nn.Linear(dim, out_channels * patch_size * patch_size, bias=True) + + self.initialize_weights() + + def initialize_weights(self): + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x): + x = x + self.position_embedding + x = self.transformer(x) + x = self.norm(x) + x = self.proj(x) + x = rearrange( + x, + "b (h w) (p1 p2 c) -> b c (h p1) (w p2)", + h=self.image_size // self.patch_size, + p1=self.patch_size, + p2=self.patch_size, + ) + + return x diff --git a/muse_maskgit_pytorch/vqvae/quantize.py b/muse_maskgit_pytorch/vqvae/quantize.py new file mode 100644 index 0000000..fd81ea7 --- /dev/null +++ b/muse_maskgit_pytorch/vqvae/quantize.py @@ -0,0 +1,44 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class VectorQuantize(nn.Module): + def __init__(self, n_e, vq_embed_dim, beta=0.25): + super().__init__() + self.n_e = n_e + self.vq_embed_dim = vq_embed_dim + self.beta = beta + + self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim) + self.embedding.weight.data.normal_() + + def forward(self, z): + z = F.normalize(z, p=2, dim=-1) + z_flattened = z.view(-1, self.vq_embed_dim) + embed_norm = F.normalize(self.embedding.weight, p=2, dim=-1) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + d = ( + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(embed_norm**2, dim=1) + - 2 * torch.einsum("bd,nd->bn", z_flattened, embed_norm) + ) + + encoding_indices = torch.argmin(d, dim=1).view(*z.shape[:-1]) + z_q = self.embedding(encoding_indices).view(z.shape) + z_q = F.normalize(z_q, p=2, dim=-1) + + # compute loss for embedding + loss = self.beta * 
torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + return z_q, loss, encoding_indices + + def decode_ids(self, indices): + z_q = self.embedding(indices) + z_q = F.normalize(z_q, p=2, dim=-1) + + return z_q diff --git a/muse_maskgit_pytorch/vqvae/vqvae.py b/muse_maskgit_pytorch/vqvae/vqvae.py new file mode 100644 index 0000000..23a31c6 --- /dev/null +++ b/muse_maskgit_pytorch/vqvae/vqvae.py @@ -0,0 +1,55 @@ +import logging + +import torch +import torch.nn as nn +from diffusers import ConfigMixin, ModelMixin +from diffusers.configuration_utils import register_to_config + +from .layers import Decoder, Encoder +from .quantize import VectorQuantize + +logger = logging.getLogger(__name__) + + +class VQVAE(ModelMixin, ConfigMixin): + @register_to_config + def __init__(self, n_embed, embed_dim, beta, enc, dec, **kwargs): + super().__init__() + self.encoder = Encoder(**enc) + self.decoder = Decoder(**dec) + + self.prev_quant = nn.Linear(enc["dim"], embed_dim) + self.quantize = VectorQuantize(n_embed, embed_dim, beta) + self.post_quant = nn.Linear(embed_dim, dec["dim"]) + + def freeze(self): + self.eval() + self.requires_grad_(False) + + def encode(self, x): + x = self.encoder(x) + x = self.prev_quant(x) + x, loss, indices = self.quantize(x) + return x, loss, indices + + def decode(self, x): + x = self.post_quant(x) + x = self.decoder(x) + return x.clamp(-1.0, 1.0) + + def forward(self, inputs: torch.FloatTensor): + z, loss, _ = self.encode(inputs) + rec = self.decode(z) + return rec, loss + + def encode_to_ids(self, inputs): + _, _, indices = self.encode(inputs) + return indices + + def decode_from_ids(self, indice): + z_q = self.quantize.decode_ids(indice) + img = self.decode(z_q) + return img + + def __call__(self, inputs: torch.FloatTensor): + return self.forward(inputs) diff --git a/scripts/vqvae_test.py b/scripts/vqvae_test.py new file mode 100644 index 0000000..1349fb6 --- /dev/null +++ b/scripts/vqvae_test.py @@ -0,0 +1,84 @@ +import logging +from pathlib import Path + +import torch +from huggingface_hub import hf_hub_download +from PIL import Image +from torchvision import transforms as T +from torchvision.utils import save_image + +from muse_maskgit_pytorch.vqvae import VQVAE + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# where to find the model and the test images +model_repo = "neggles/vaedump" +model_subdir = "vit-s-vqgan-f4" +test_images = ["testimg_1.png", "testimg_2.png"] + +# where to save the preprocessed and reconstructed images +image_dir = Path.cwd().joinpath("temp") +image_dir.mkdir(exist_ok=True, parents=True) + +# image transforms for the VQVAE +transform_enc = T.Compose([T.Resize(512), T.RandomCrop(256), T.ToTensor()]) +transform_dec = T.Compose([T.ConvertImageDtype(torch.uint8), T.ToPILImage()]) + + +def get_save_path(path: Path, append: str) -> Path: + # append a string to the filename before the extension + # n.b. only keeps the final suffix, e.g. 
"foo.xyz.png" -> "foo-prepro.png" + return path.with_name(f"{path.stem}-{append}{path.suffix}") + + +def main(): + torch_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + dtype = torch.float32 + + # load VAE + logger.info(f"Loading VQVAE from {model_repo}/{model_subdir}...") + vae: VQVAE = VQVAE.from_pretrained(model_repo, subfolder=model_subdir, torch_dtype=dtype) + vae = vae.to(torch_device) + logger.info(f"Loaded VQVAE from {model_repo} to {vae.device} with dtype {vae.dtype}") + + # download and process images + for image in test_images: + image_path = hf_hub_download(model_repo, subfolder="images", filename=image, local_dir=image_dir) + image_path = Path(image_path) + logger.info(f"Downloaded {image_path}, size {image_path.stat().st_size} bytes") + + # preprocess + image_obj = Image.open(image_path).convert("RGB") + image_tensor: torch.Tensor = transform_enc(image_obj) + save_path = get_save_path(image_path, "prepro") + save_image(image_tensor, save_path, normalize=True, range=(-1.0, 1.0)) + logger.info(f"Saved preprocessed image to {save_path}") + + # encode + encoded, _, _ = vae.encode(image_tensor.unsqueeze(0).to(vae.device)) + + # decode + reconstructed = vae.decode(encoded).squeeze(0) + reconstructed = torch.clamp(reconstructed, -1.0, 1.0) + + # save + save_path = get_save_path(image_path, "recon") + save_image(reconstructed, save_path, normalize=True, range=(-1.0, 1.0)) + logger.info(f"Saved reconstructed image to {save_path}") + + # compare + image_prepro = transform_dec(image_tensor) + image_recon = transform_dec(reconstructed) + canvas = Image.new("RGB", (512 + 12, 256 + 8), (248, 248, 242)) + canvas.paste(image_prepro, (4, 4)) + canvas.paste(image_recon, (256 + 8, 4)) + save_path = get_save_path(image_path, "compare") + canvas.save(save_path) + logger.info(f"Saved comparison image to {save_path}") + + logger.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py index 40ecdf4..b7697ff 100644 --- a/train_muse_maskgit.py +++ b/train_muse_maskgit.py @@ -10,6 +10,7 @@ import datasets import diffusers import transformers +import wandb from accelerate.utils import ProjectConfiguration from datasets import load_dataset from diffusers.optimization import SchedulerType, get_scheduler @@ -17,8 +18,6 @@ from rich import inspect from torch.optim import Optimizer -import wandb - try: import torch_xla import torch_xla.core.xla_model as xm diff --git a/train_muse_vae.py b/train_muse_vae.py index a2c53a3..d395ab7 100644 --- a/train_muse_vae.py +++ b/train_muse_vae.py @@ -5,12 +5,11 @@ from dataclasses import dataclass from typing import Optional, Union +import wandb from accelerate.utils import ProjectConfiguration from datasets import load_dataset from omegaconf import OmegaConf -import wandb - from muse_maskgit_pytorch import ( VQGanVAE, VQGanVAETaming, @@ -300,6 +299,17 @@ action="store_true", help="Whether to use the latest checkpoint", ) +parser.add_argument( + "--do_not_save_config", + action="store_true", + default=False, + help="Generate example YAML configuration file", +) +parser.add_argument( + "--use_l2_recon_loss", + action="store_true", + help="Use F.mse_loss instead of F.l1_loss.", +) @dataclass @@ -361,7 +371,6 @@ class Arguments: use_l2_recon_loss: bool = False debug: bool = False config_path: Optional[str] = None - generate_config: bool = False def preprocess_webdataset(args, image): @@ -374,13 +383,6 @@ def main(): if args.config_path: print("Using config file and ignoring CLI 
args") - if args.generate_config: - conf = OmegaConf.structured(args) - - # dumps to file: - with open(args.config_path, "w") as f: - OmegaConf.save(conf, f) - try: conf = OmegaConf.load(args.config_path) conf_keys = conf.keys() @@ -459,6 +461,7 @@ def main(): vq_codebook_dim=args.vq_codebook_dim, vq_codebook_size=args.vq_codebook_size, accelerator=accelerator, + l2_recon_loss=args.use_l2_recon_loss, ) if args.latest_checkpoint: @@ -582,6 +585,7 @@ def main(): num_cycles=args.num_cycles, scheduler_power=args.scheduler_power, num_epochs=args.num_epochs, + args=args, ) trainer.train()