Skip to content

Commit

Permalink
Merge pull request Sygil-Dev#10 from wnakano/tamig-vae-vqgan
Browse files Browse the repository at this point in the history
Tamig vae vqgan
  • Loading branch information
isamu-isozaki authored Mar 4, 2023
2 parents f9ff752 + fa4d40c commit 2505734
Show file tree
Hide file tree
Showing 9 changed files with 249 additions and 15 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
wandb
results
models
dataset
taming

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
1 change: 1 addition & 0 deletions muse_maskgit_pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from muse_maskgit_pytorch.vqgan_vae import VQGanVAE
from muse_maskgit_pytorch.vqgan_vae_taming import VQGanVAETaming
from muse_maskgit_pytorch.muse_maskgit_pytorch import (
Transformer,
MaskGit,
Expand Down
38 changes: 38 additions & 0 deletions muse_maskgit_pytorch/distributed_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""
Utility functions for optional distributed execution.
To use,
1. set the `BACKENDS` to the ones you want to make available,
2. in the script, wrap the argument parser with `wrap_arg_parser`,
3. in the script, set and use the backend by calling
`set_backend_from_args`.
You can check whether a backend is in use with the `using_backend`
function.
"""



is_distributed = None
"""Whether we are distributed."""
backend = None
"""Backend in usage."""

def require_set_backend():
"""Raise an `AssertionError` when the backend has not been set."""
assert backend is not None, (
'distributed backend is not set. Please call '
'`distributed_utils.set_backend_from_args` at the start of your script'
)


def using_backend(test_backend):
"""Return whether the backend is set to `test_backend`.
`test_backend` may be a string of the name of the backend or
its class.
"""
require_set_backend()
if isinstance(test_backend, str):
return backend.BACKEND_NAME == test_backend
return isinstance(backend, test_backend)
7 changes: 4 additions & 3 deletions muse_maskgit_pytorch/muse_maskgit_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@

import torchvision.transforms as T

from typing import Callable, Optional, List
from typing import Callable, Optional, List, Union

from einops import rearrange, repeat

from beartype import beartype

from muse_maskgit_pytorch.vqgan_vae import VQGanVAE
from muse_maskgit_pytorch.vqgan_vae_taming import VQGanVAETaming
from muse_maskgit_pytorch.t5 import (
t5_encode_text,
get_encoded_dim,
Expand Down Expand Up @@ -476,8 +477,8 @@ def __init__(
noise_schedule: Callable = cosine_schedule,
token_critic: Optional[TokenCritic] = None,
self_token_critic=False,
vae: Optional[VQGanVAE] = None,
cond_vae: Optional[VQGanVAE] = None,
vae: Optional[Union[VQGanVAE, VQGanVAETaming]] = None,
cond_vae: Optional[Union[VQGanVAE, VQGanVAETaming]] = None,
cond_image_size=None,
cond_drop_prob=0.5,
self_cond_prob=0.9,
Expand Down
3 changes: 3 additions & 0 deletions muse_maskgit_pytorch/trainers/maskgit_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,14 @@ def __init__(
def log_validation_images(
self, validation_prompts, step, cond_image=None, cond_scale=3, temperature=1
):

images = self.model.generate(
validation_prompts,
cond_images=cond_image,
cond_scale=cond_scale,
temperature=temperature,
)
step = int(step.item())
save_file = str(self.results_dir / f"MaskGit" / f"maskgit_{step}.png")
os.makedirs(str(self.results_dir / f"MaskGit"), exist_ok = True)

Expand Down Expand Up @@ -211,6 +213,7 @@ def train_step(self):
if self.model.cond_image_size:
self.print("With conditional image training, we recommend keeping the validation prompts to empty strings")
cond_image = F.interpolate(imgs[0], 256)

self.log_validation_images(
self.validation_prompts, self.steps, cond_image=cond_image
)
Expand Down
168 changes: 168 additions & 0 deletions muse_maskgit_pytorch/vqgan_vae_taming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
import os
import copy
import urllib
from pathlib import Path
from tqdm import tqdm
from math import sqrt, log

from omegaconf import OmegaConf
import importlib

import torch
from torch import nn
import torch.nn.functional as F

from einops import rearrange
from taming.models.vqgan import VQModel #, GumbelVQ
import muse_maskgit_pytorch.distributed_utils as distributed_utils

# constants
CACHE_PATH = Path("~/.cache/taming")
CACHE_PATH.mkdir(parents=True, exist_ok=True)

VQGAN_VAE_PATH = 'https://heibox.uni-heidelberg.de/f/140747ba53464f49b476/?dl=1'
VQGAN_VAE_CONFIG_PATH = 'https://heibox.uni-heidelberg.de/f/6ecf2af6c658432c8298/?dl=1'

# helpers methods

def exists(val):
return val is not None

def default(val, d):
return val if exists(val) else d


def download(url, filename = None, root = CACHE_PATH, is_distributed=None, backend=None):
filename = default(filename, os.path.basename(url))
download_target = os.path.join(root, filename)
download_target_tmp = os.path.join(root, f'tmp.{filename}')

if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")


if os.path.isfile(download_target):
return download_target

with urllib.request.urlopen(url) as source, open(download_target_tmp, "wb") as output:
with tqdm(total=int(source.info().get("Content-Length")), ncols=80) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break

output.write(buffer)
loop.update(len(buffer))

os.rename(download_target_tmp, download_target)
if (
distributed_utils.is_distributed
and distributed_utils.backend.is_local_root_worker()
):
distributed_utils.backend.local_barrier()
return download_target


# VQGAN from Taming Transformers paper
# https://arxiv.org/abs/2012.09841

def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)

def instantiate_from_config(config):
if not "target" in config:
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()))

class VQGanVAETaming(nn.Module):
def __init__(self, vqgan_model_path=None, vqgan_config_path=None):
super().__init__()

if vqgan_model_path is None:
model_filename = 'vqgan.1024.model.ckpt'
config_filename = 'vqgan.1024.config.yml'
download(VQGAN_VAE_CONFIG_PATH, config_filename)
download(VQGAN_VAE_PATH, model_filename)
config_path = str(Path(CACHE_PATH) / config_filename)
model_path = str(Path(CACHE_PATH) / model_filename)
else:
model_path = vqgan_model_path
config_path = vqgan_config_path

config = OmegaConf.load(config_path)

model = instantiate_from_config(config["model"])

state = torch.load(model_path, map_location = 'cpu')['state_dict']
model.load_state_dict(state, strict = False)

print(f"Loaded VQGAN from {model_path} and {config_path}")

self.model = model

# f as used in https://github.com/CompVis/taming-transformers#overview-of-pretrained-models
f = config.model.params.ddconfig.resolution / config.model.params.ddconfig.attn_resolutions[0]

self.num_layers = int(log(f)/log(2))
self.channels = 3
self.image_size = 256
self.num_tokens = config.model.params.n_embed
self.is_gumbel = False # isinstance(self.model, GumbelVQ)
self.codebook_size = config["model"]["params"]["n_embed"]


@torch.no_grad()
def get_codebook_indices(self, img):
b = img.shape[0]
img = (2 * img) - 1
_, _, [_, _, indices] = self.model.encode(img)
if self.is_gumbel:
return rearrange(indices, 'b h w -> b (h w)', b=b)
return rearrange(indices, '(b n) -> b n', b = b)

def get_encoded_fmap_size(self, image_size):
return image_size // (2**self.num_layers)


def decode_from_ids(self, img_seq):
print(img_seq.shape)
img_seq = rearrange(img_seq, "b h w -> b (h w)")
b, n = img_seq.shape
one_hot_indices = F.one_hot(img_seq, num_classes = self.num_tokens).float()
z = one_hot_indices @ self.model.quantize.embed.weight if self.is_gumbel \
else (one_hot_indices @ self.model.quantize.embedding.weight)

z = rearrange(z, 'b (h w) c -> b c h w', h = int(sqrt(n)))
img = self.model.decode(z)

img = (img.clamp(-1., 1.) + 1) * 0.5
print(img)
return img

def encode(self, im_seq):
# encode output
# fmap, loss, (perplexity, min_encodings, min_encodings_indices) = self.model.encode(im_seq)
fmap, loss, (_, _, min_encodings_indices) = self.model.encode(im_seq)

b, _, h, w = fmap.shape
min_encodings_indices = rearrange(min_encodings_indices, "(b h w) 1 -> b h w", h=h, w=w, b=b)
return fmap, min_encodings_indices, loss

def decode_ids(self, ids):
return self.model.decode_code(ids)


def copy_for_eval(self):
device = next(self.parameters()).device
vae_copy = copy.deepcopy(self.cpu())

vae_copy.eval()
return vae_copy.to(device)

def forward(self, img):
raise NotImplemented

2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@
"beartype",
"einops>=0.6",
"ema-pytorch",
"omegaconf>=2.3.0",
"pillow",
"sentencepiece",
"torch>=1.6",
"taming-transformers>=0.0.1"
"transformers",
"torch>=1.6",
"torchvision",
Expand Down
36 changes: 25 additions & 11 deletions train_muse_maskgit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
from muse_maskgit_pytorch import (
VQGanVAE,
VQGanVAETaming,
MaskGitTrainer,
MaskGit,
MaskGitTransformer,
Expand Down Expand Up @@ -156,7 +157,7 @@ def parse_args():
parser.add_argument(
"--vae_path",
type=str,
default="",
default=None,
help="Path to the vae model. eg. 'results/vae.steps.pt'",
)
parser.add_argument(
Expand Down Expand Up @@ -229,6 +230,12 @@ def parse_args():
default=None,
help="Path to the last saved checkpoint. 'results/maskgit.steps.pt'",
)
parser.add_argument('--taming', dest='taming', action='store_true', default=None)
parser.add_argument('--taming_model_path', type=str, default = None,
help='path to your trained VQGAN weights. This should be a .ckpt file. (only valid when taming option is enabled)')

parser.add_argument('--taming_config_path', type=str, default = None,
help='path to your trained VQGAN config. This should be a .yaml file. (only valid when taming option is enabled)')
# Parse the argument
return parser.parse_args()

Expand All @@ -252,23 +259,30 @@ def main():
)
elif args.dataset_name:
dataset = load_dataset(args.dataset_name)["train"]
if all([bool(args.vae_path), bool(args.taming)]):
raise Exception("You can't pass vae_path and taming args at the same time.")

if args.vae_path:
print("Loading Muse VQGanVAE")
vae = VQGanVAE(dim=args.dim, vq_codebook_size=args.vq_codebook_size).to(
accelerator.device
)

vae = VQGanVAE(dim=args.dim, vq_codebook_size=args.vq_codebook_size).to(
accelerator.device
)

print("Resuming VAE from: ", args.vae_path)
vae.load(
args.vae_path
) # you will want to load the exponentially moving averaged VAE
print("Resuming VAE from: ", args.vae_path)
vae.load(
args.vae_path
) # you will want to load the exponentially moving averaged VAE
elif args.taming:
print("Loading Taming VQGanVAE")
vae = VQGanVAETaming(vqgan_model_path=args.taming_model_path, vqgan_config_path=args.taming_config_path)

# then you plug the vae and transformer into your MaskGit as so

# (1) create your transformer / attention network

transformer = MaskGitTransformer(
num_tokens=args.num_tokens, # must be same as codebook size above
seq_len=args.seq_len, # must be equivalent to fmap_size ** 2 in vae
num_tokens=vae.codebook_size, # must be same as codebook size above
seq_len=vae.get_encoded_fmap_size(args.image_size) ** 2, # must be equivalent to fmap_size ** 2 in vae
dim=args.dim, # model dimension
depth=args.depth, # depth
dim_head=args.dim_head, # attention head dimension
Expand Down
7 changes: 6 additions & 1 deletion train_muse_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def parse_args():
default=None,
help="Path to the last saved checkpoint. 'results/vae.steps.pt'",
)

# Parse the argument
return parser.parse_args()

Expand All @@ -209,7 +210,11 @@ def main():
)
elif args.dataset_name:
dataset = load_dataset(args.dataset_name)["train"]
vae = VQGanVAE(dim=args.dim, vq_codebook_size=args.vq_codebook_size)

if args.taming:
vae = VQGanVAE(args.vqgan_model_path, args.vqgan_config_path)
else:
vae = VQGanVAE(dim=args.dim, vq_codebook_size=args.vq_codebook_size)

if args.resume_path:
print(f"Resuming VAE from: {args.resume_path}")
Expand Down

0 comments on commit 2505734

Please sign in to comment.