Skip to content

Commit

Permalink
Merge pull request Sygil-Dev#36 from Sygil-Dev/dev
Browse files Browse the repository at this point in the history
Merge dev to main.
  • Loading branch information
ZeroCool940711 authored Jun 17, 2023
2 parents c14f970 + 4585883 commit a529f2e
Show file tree
Hide file tree
Showing 20 changed files with 777 additions and 42 deletions.
1 change: 1 addition & 0 deletions .github/.github/FUNDING.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
github: [ZeroCool940711]
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,10 @@ dmypy.json

# setuptools-scm version file
muse_maskgit_pytorch/_version.py

# wandb dir
/wandb/

# data, output
/data/
/output/
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
ci:
autofix_prs: true
autoupdate_branch: 'dev'
autoupdate_commit_msg: '[pre-commit.ci] pre-commit autoupdate'
autoupdate_schedule: weekly

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
Expand Down
45 changes: 38 additions & 7 deletions infer_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
ImageDataset,
get_dataset_from_dataroot,
)
from muse_maskgit_pytorch.vqvae import VQVAE

# Create the parser
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -189,6 +190,11 @@
action="store_true",
help="Use the latest checkpoint using the vae_path folder instead of using just a specific vae_path.",
)
parser.add_argument(
"--use_paintmind",
action="store_true",
help="Use PaintMind VAE..",
)


@dataclass
Expand Down Expand Up @@ -336,7 +342,7 @@ def main():
if args.vae_path and args.taming_model_path:
raise Exception("You can't pass vae_path and taming args at the same time.")

if args.vae_path:
if args.vae_path and not args.use_paintmind:
accelerator.print("Loading Muse VQGanVAE")
vae = VQGanVAE(
dim=args.dim, vq_codebook_size=args.vq_codebook_size, vq_codebook_dim=args.vq_codebook_dim
Expand Down Expand Up @@ -390,6 +396,11 @@ def main():

vae.load(args.vae_path)

if args.use_paintmind:
# load VAE
accelerator.print("Loading VQVAE from 'neggles/vaedump/vit-s-vqgan-f4' ...")
vae: VQVAE = VQVAE.from_pretrained("neggles/vaedump", subfolder="vit-s-vqgan-f4")

elif args.taming_model_path:
print("Loading Taming VQGanVAE")
vae = VQGanVAETaming(
Expand All @@ -398,7 +409,10 @@ def main():
)
args.num_tokens = vae.codebook_size
args.seq_len = vae.get_encoded_fmap_size(args.image_size) ** 2

# move vae to device
vae = vae.to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")

# then you plug the vae and transformer into your MaskGit as so

dataset = ImageDataset(
Expand Down Expand Up @@ -449,11 +463,25 @@ def main():
try:
save_image(dataset[i], f"{output_dir}/input.png")

_, ids, _ = vae.encode(
dataset[i][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")
)
recon = vae.decode_from_ids(ids)
save_image(recon, f"{output_dir}/output.png")
if not args.use_paintmind:
# encode
_, ids, _ = vae.encode(
dataset[i][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")
)
# decode
recon = vae.decode_from_ids(ids)
# print (recon.shape) # torch.Size([1, 3, 512, 1136])
save_image(recon, f"{output_dir}/output.png")
else:
# encode
encoded, _, _ = vae.encode(
dataset[i][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")
)

# decode
recon = vae.decode(encoded).squeeze(0)
recon = torch.clamp(recon, -1.0, 1.0)
save_image(recon, f"{output_dir}/output.png")

# Load input and output images
input_image = PIL.Image.open(f"{output_dir}/input.png")
Expand Down Expand Up @@ -495,7 +523,10 @@ def main():
continue # Retry the loop

else:
print(f"Skipping image {i} after {retries} retries due to out of memory error")
if "out of memory" not in str(e):
print(e)
else:
print(f"Skipping image {i} after {retries} retries due to out of memory error")
break # Exit the retry loop after too many retries


Expand Down
10 changes: 10 additions & 0 deletions muse_maskgit_pytorch/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from .attention import CrossAttention, MemoryEfficientCrossAttention
from .mlp import SwiGLU, SwiGLUFFN, SwiGLUFFNFused

__all__ = [
"SwiGLU",
"SwiGLUFFN",
"SwiGLUFFNFused",
"CrossAttention",
"MemoryEfficientCrossAttention",
]
112 changes: 112 additions & 0 deletions muse_maskgit_pytorch/modules/attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from inspect import isfunction
from typing import Any, Callable, Optional

from einops import rearrange
from torch import nn

try:
from xformers.ops import memory_efficient_attention
except ImportError:
memory_efficient_attention = None


def exists(x):
return x is not None


def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d


class CrossAttention(nn.Module):
def __init__(
self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.0,
):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)

self.scale = dim_head**-0.5
self.heads = heads

self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

def forward(self, x, context=None):
h = self.heads

q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)

q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
q = q * self.scale

sim = q @ k.transpose(-2, -1)
sim = sim.softmax(dim=-1)

out = sim @ v
out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
return self.to_out(out)


class MemoryEfficientCrossAttention(nn.Module):
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
def __init__(
self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.0,
):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)

self.heads = heads
self.dim_head = dim_head

self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
self.attention_op: Optional[Callable] = None

def forward(self, x, context=None):
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)

b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, t.shape[1], self.heads, self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b * self.heads, t.shape[1], self.dim_head)
.contiguous(),
(q, k, v),
)

out = memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)

out = (
out.unsqueeze(0)
.reshape(b, self.heads, out.shape[1], self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b, out.shape[1], self.heads * self.dim_head)
)
return self.to_out(out)
56 changes: 56 additions & 0 deletions muse_maskgit_pytorch/modules/mlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch.nn.functional as F
from torch import Tensor, nn


class SwiGLUFFN(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
bias: bool = True,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

def forward(self, x: Tensor) -> Tensor:
x12 = self.w12(x)
x1, x2 = x12.chunk(2, dim=-1)
hidden = F.silu(x1) * x2
return self.w3(hidden)


try:
from xformers.ops import SwiGLU
except ImportError:
SwiGLU = SwiGLUFFN


class SwiGLUFFNFused(SwiGLU):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
bias: bool = True,
) -> None:
out_features = out_features or in_features
hidden_features = hidden_features or in_features
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
super().__init__(
in_features=in_features,
hidden_features=hidden_features,
out_features=out_features,
bias=bias,
)
22 changes: 11 additions & 11 deletions muse_maskgit_pytorch/trainers/maskgit_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,17 +263,17 @@ def train(self):

self.steps += 1

if self.num_train_steps > 0 and self.steps >= int(self.steps.item()):
if self.on_tpu:
self.accelerator.print(
f"\n[E{epoch + 1}][{int(self.steps.item())}]{proc_label}"
f"[STOP EARLY]: Stopping training early..."
)
else:
self.info_bar.set_description_str(
f"[E{epoch + 1}]{proc_label}" f"[STOP EARLY]: Stopping training early..."
)
break
# if self.num_train_steps > 0 and int(self.steps.item()) >= self.num_train_steps:
# if self.on_tpu:
# self.accelerator.print(
# f"\n[E{epoch + 1}][{int(self.steps.item())}]{proc_label}"
# f"[STOP EARLY]: Stopping training early..."
# )
# else:
# self.info_bar.set_description_str(
# f"[E{epoch + 1}]{proc_label}" f"[STOP EARLY]: Stopping training early..."
# )
# break

# loop complete, save final model
self.accelerator.print(
Expand Down
17 changes: 8 additions & 9 deletions muse_maskgit_pytorch/trainers/vqvae_trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,16 +305,15 @@ def train(self):

if (steps % self.save_results_every) == 0:
self.accelerator.print(
f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: saving to {str(self.results_dir)}"
f"\n[E{epoch + 1}][{steps}] | Logging validation images to {str(self.results_dir)}"
)

self.log_validation_images(logs, steps)

# save model every so often
self.accelerator.wait_for_everyone()
if self.is_main_process and (steps % self.save_model_every) == 0:
self.accelerator.print(
f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: saving model to {str(self.results_dir)}"
)
self.accelerator.print(f"\nStep: {steps} | Saving model to {str(self.results_dir)}")

state_dict = self.accelerator.unwrap_model(self.model).state_dict()
file_name = f"vae.{steps}.pt" if not self.only_save_last_checkpoint else "vae.pt"
Expand All @@ -341,11 +340,11 @@ def train(self):

self.steps += 1

if self.num_train_steps > 0 and self.steps >= int(self.steps.item()):
self.accelerator.print(
f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: " f"[STOP EARLY]: Stopping training early..."
)
break
# if self.num_train_steps > 0 and int(self.steps.item()) >= self.num_train_steps:
# self.accelerator.print(
# f"\n[E{epoch + 1}][{steps}]{proc_label}: " f"[STOP EARLY]: Stopping training early..."
# )
# break

# Loop finished, save model
self.accelerator.wait_for_everyone()
Expand Down
2 changes: 1 addition & 1 deletion muse_maskgit_pytorch/vqgan_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def vgg(self):
if exists(self._vgg):
return self._vgg

vgg = torchvision.models.vgg16(pretrained=True)
vgg = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)
vgg.classifier = nn.Sequential(*vgg.classifier[:-2])
self._vgg = vgg.to(self.device)
return self._vgg
Expand Down
3 changes: 1 addition & 2 deletions muse_maskgit_pytorch/vqgan_vae_taming.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@
from accelerate import Accelerator
from einops import rearrange
from omegaconf import DictConfig, OmegaConf
from taming.models.vqgan import VQModel
from torch import nn
from tqdm_loggable.auto import tqdm

from taming.models.vqgan import VQModel

# constants
CACHE_PATH = Path.home().joinpath(".cache/taming")

Expand Down
7 changes: 7 additions & 0 deletions muse_maskgit_pytorch/vqvae/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .config import VQVAEConfig
from .vqvae import VQVAE

__all__ = [
"VQVAE",
"VQVAEConfig",
]
Loading

0 comments on commit a529f2e

Please sign in to comment.