Created basic pipeline. Starting cleaning script
isamu-isozaki committed Feb 19, 2023
1 parent e3ad84e commit ef02381
Showing 8 changed files with 54 additions and 311 deletions.
muse_maskgit_pytorch/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -1,4 +1,4 @@
 from muse_maskgit_pytorch.vqgan_vae import VQGanVAE
 from muse_maskgit_pytorch.muse_maskgit_pytorch import Transformer, MaskGit, Muse, MaskGitTransformer, TokenCritic
 
-from muse_maskgit_pytorch.trainers import VQGanVAETrainer
+from muse_maskgit_pytorch.trainers import VQGanVAETrainer, MaskGitTrainer
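
With this re-export in place, both trainers can be pulled straight from the package root, e.g.:

    from muse_maskgit_pytorch import VQGanVAETrainer, MaskGitTrainer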
muse_maskgit_pytorch/dataset.py (17 changes: 15 additions & 2 deletions)
@@ -10,7 +10,8 @@
 from PIL import Image, ImageFile
 from pathlib import Path
 from muse_maskgit_pytorch.t5 import MAX_LENGTH
-
+import datasets
+import random
 ImageFile.LOAD_TRUNCATED_IMAGES = True
 
 
@@ -72,4 +73,16 @@ def __getitem__(self, index):
 
         input_ids = encoded.input_ids
         attn_mask = encoded.attention_mask
-        return self.transform(image), input_ids, attn_mask
+        return self.transform(image), input_ids, attn_mask
+
+def get_dataset_from_dataroot(data_root, args):
+    image_paths = list(Path(data_root).rglob("*.[jJ][pP][gG]"))
+    random.shuffle(image_paths)
+    data_dict = {args.image_column: [], args.caption_column: []}
+    for image_path in image_paths:
+        image = Image.open(image_path)
+        if not image.mode == "RGB":
+            image = image.convert("RGB")
+        data_dict[args.image_column].append(image)
+        data_dict[args.caption_column].append(None)
+    return datasets.Dataset.from_dict(data_dict)
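
For context, a minimal sketch of how the new helper could be driven. The column names and the SimpleNamespace stand-in for the parsed CLI args are illustrative assumptions, not part of this commit:

    from types import SimpleNamespace
    from muse_maskgit_pytorch.dataset import get_dataset_from_dataroot

    # hypothetical column names; the real ones come from argparse
    args = SimpleNamespace(image_column="image", caption_column="caption")
    dataset = get_dataset_from_dataroot("path/to/images", args)
    print(len(dataset), dataset.column_names)

Note that the helper opens every JPEG eagerly, so building the Dataset scales with the size of the folder, and captions are left as None, presumably for the cleaning script this commit starts.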
muse_maskgit_pytorch/t5.py (43 changes: 22 additions & 21 deletions)
@@ -56,32 +56,15 @@ def get_encoded_dim(name):
 
 # encoding text
 
-@beartype
-def t5_encode_text(
-    texts: Union[str, List[str]],
-    tokenizer,
-    t5,
-    output_device = None
-):
-    if isinstance(texts, str):
-        texts = [texts]
-
+def t5_encode_text_from_encoded(input_ids,
+                                attn_mask,
+                                t5,
+                                output_device):
     if torch.cuda.is_available():
         t5 = t5.cuda()
 
     device = next(t5.parameters()).device
 
-    encoded = tokenizer.batch_encode_plus(
-        texts,
-        return_tensors = "pt",
-        padding = 'longest',
-        max_length = MAX_LENGTH,
-        truncation = True
-    )
-
-    input_ids = encoded.input_ids.to(device)
-    attn_mask = encoded.attention_mask.to(device)
-
     t5.eval()
 
     with torch.no_grad():
@@ -96,3 +79,21 @@ def t5_encode_text(
 
     encoded_text = encoded_text.to(output_device)
     return encoded_text
+@beartype
+def t5_encode_text(
+    texts: Union[str, List[str]],
+    tokenizer,
+    t5,
+    output_device = None
+):
+    if isinstance(texts, str):
+        texts = [texts]
+
+    encoded = tokenizer.batch_encode_plus(
+        texts,
+        return_tensors = "pt",
+        padding = 'longest',
+        max_length = MAX_LENGTH,
+        truncation = True
+    )
+    return t5_encode_text_from_encoded(encoded.input_ids, encoded.attention_mask, t5, output_device)
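
The refactor splits tokenization from encoding, so callers that already hold tokenized ids (like the updated dataset) can reuse the encoder path directly. A rough usage sketch, assuming a Hugging Face tokenizer/encoder pair; "t5-small" is only an illustrative checkpoint:

    from transformers import T5Tokenizer, T5EncoderModel
    from muse_maskgit_pytorch.t5 import t5_encode_text, t5_encode_text_from_encoded

    tokenizer = T5Tokenizer.from_pretrained("t5-small")  # illustrative checkpoint
    t5 = T5EncoderModel.from_pretrained("t5-small")

    # one-step path: tokenize and encode in a single call
    embeds = t5_encode_text(["a photo of a cat"], tokenizer, t5, output_device="cpu")

    # two-step path: tokenize up front (as the dataset now does), encode later;
    # on a CUDA machine the helper moves the model to the GPU itself
    enc = tokenizer(["a photo of a cat"], return_tensors="pt", padding="longest", truncation=True)
    embeds = t5_encode_text_from_encoded(enc.input_ids, enc.attention_mask, t5, "cpu")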
muse_maskgit_pytorch/trainers/__init__.py (9 changes: 1 addition & 8 deletions)
@@ -1,8 +1 @@
"""
Author: Isamu Isozaki ([email protected])
Description: description
Created: 2023-02-18T19:28:19.819Z
Modified: !date!
Modified By: modifier
"""

from muse_maskgit_pytorch.trainers.vqvae_trainers import VQGanVAETrainer, MaskGitTrainer
muse_maskgit_pytorch/trainers/maskgit_trainer.py (7 changes: 4 additions & 3 deletions)
@@ -22,6 +22,7 @@
 from muse_maskgit_pytorch.diffusers_optimization import get_scheduler
 from muse_maskgit_pytorch.muse_maskgit_pytorch import MaskGit
 from muse_maskgit_pytorch.trainers.base_accelerated_trainer import BaseAcceleratedTrainer
+from muse_maskgit_pytorch.t5 import t5_encode_text_from_encoded
 import torch.nn.functional as F
 def noop(*args, **kwargs):
     pass
@@ -108,12 +109,12 @@ def train_step(self):
         # logs
         train_loss = 0
         with self.accelerator.accumulate(self.model):
-            imgs, token_ids, attention_mask = next(self.dl_iter)
+            imgs, input_ids, attn_mask = next(self.dl_iter)
+            text_embeds = t5_encode_text_from_encoded(input_ids, attn_mask, self.model.t5, device)
             imgs = imgs.to(device)
             loss = self.model(
                 imgs,
-                token_ids=token_ids,
-                attentioN_mask=attention_mask
+                text_embeds=text_embeds,
                 add_gradient_penalty = apply_grad_penalty,
                 return_loss = True
             )
setup.py (1 change: 1 addition & 0 deletions)
@@ -19,6 +19,7 @@
     ],
     install_requires=[
         'accelerate',
+        'datasets',
         'beartype',
         'einops>=0.6',
         'ema-pytorch',
train_muse_maskgit.py (247 changes: 0 additions & 247 deletions)
@@ -113,253 +113,6 @@ def parse_args():
     args = parser.parse_args()
     return args
 
-
-def vae_trainer(args):
-    vae = VQGanVAE(dim=args.dim, vq_codebook_size=args.vq_codebook_size)
-
-    current_step = 0
-    resume_from = args.resume_from
-    # load the vae from disk if we have previously trained one
-    if resume_from:
-        print("Resuming VAE from: ", resume_from)
-        vae.load(resume_from)
-
-        resume_from_parts = resume_from.split(".")
-        for i in range(len(resume_from_parts) - 1, -1, -1):
-            if resume_from_parts[i].isdigit():
-                current_step = int(resume_from_parts[i])
-                print("Found step " + str(current_step))
-                break
-        if current_step == 0:
-            print("No step found")
-
-    trainer = VQGanVAETrainer(
-        vae,
-        folder=args.data_folder,
-        current_step=current_step,
-        num_train_steps=args.num_train_steps,
-        batch_size=args.batch_size,
-        image_size=args.image_size,  # you may want to start with small images, and then curriculum learn to larger ones, but because the vae is all convolution, it should generalize to 512 (as in paper) without training on it
-        lr=args.lr,
-        lr_scheduler=args.lr_scheduler,
-        lr_warmup_steps=args.lr_warmup_steps,
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        max_grad_norm=None,
-        discr_max_grad_norm=None,
-        save_results_every=args.save_results_every,
-        save_model_every=args.save_model_every,
-        results_dir=args.results_dir,
-        logging_dir=args.logging_dir,
-        valid_frac=0.05,
-        random_split_seed=42,
-        use_ema=True,
-        ema_beta=0.995,
-        ema_update_after_step=1,
-        ema_update_every=1,
-        apply_grad_penalty_every=4,
-        accelerate_kwargs={
-            'mixed_precision': args.mixed_precisionWW
-        },
-    )
-
-    trainer.train()
-
-
-def base_maskgit_trainer(
-    args
-):
-    # first instantiate your vae
-
-    vae = VQGanVAE(dim=base_dim, vq_codebook_size=base_vq_codebook_size).cuda()
-
-    print("Resuming VAE from: ", args.resume_from)
-    vae.load(
-        args.resume_from
-    )  # you will want to load the exponentially moving averaged VAE
-
-    # then you plug the vae and transformer into your MaskGit as so
-
-    # (1) create your transformer / attention network
-
-    transformer = MaskGitTransformer(
-        num_tokens=base_num_tokens,  # must be same as codebook size above
-        seq_len=base_seq_len,  # must be equivalent to fmap_size ** 2 in vae
-        dim=base_dim,  # model dimension
-        depth=base_depth,  # depth
-        dim_head=base_dim_head,  # attention head dimension
-        heads=base_heads,  # attention heads,
-        ff_mult=base_ff_mult,  # feedforward expansion factor
-        t5_name=base_t5_name,  # name of your T5
-    )
-
-    # (2) pass your trained VAE and the base transformer to MaskGit
-
-    base_maskgit = MaskGit(
-        vae=vae,  # vqgan vae
-        transformer=transformer,  # transformer
-        image_size=base_image_size,  # image size
-        cond_drop_prob=base_cond_drop_prob,  # conditional dropout, for classifier free guidance
-    ).cuda()
-
-    # ready your training text and images
-    images = torch.randn(4, 3, base_image_size, base_image_size).cuda()
-
-    # feed it into your maskgit instance, with return_loss set to True
-
-    loss = base_maskgit(images, texts=base_texts)
-
-    loss.backward()
-
-    # do this for a long time on much data
-
-    # then...
-    images = base_maskgit.generate(
-        texts=[
-            "a whale breaching from afar",
-            "young girl blowing out candles on her birthday cake",
-            "fireworks with blue and green sparkles",
-        ],
-        cond_scale=base_cond_scale,  # conditioning scale for classifier free guidance
-        timesteps=base_timesteps,
-    )
-
-    # save the base vae
-    base_maskgit.save(args.resume_from.replace(".pt", ".base.pt"))
-
-    # print(images.shape)  # (3, 3, 256, 256)
-
-    # print(images)  # List[PIL.Image.Image]
-
-    img1 = images[0]
-
-    save_image(img1, f"{results_dir}/outputs/base_result.png")
-    # img.save(f'{results_dir}/outputs/base_result.png')
-
-    # for count in len(images):
-    #     for image in images:
-    #         image.save(f'{results_dir}/outputs/base_{count}.png')
-
-
-#
-def superres_maskgit_trainer(
-    superres_texts=args.superres_texts,
-    superres_resume_from=args.superres_resume_from,
-    superres_dim=args.superres_dim,
-    superres_vq_codebook_size=args.superres_vq_codebook_size,
-    superres_num_tokens=args.superres_num_tokens,
-    superres_seq_len=args.superres_seq_len,
-    superres_depth=args.superres_depth,
-    superres_dim_head=args.superres_dim_head,
-    superres_heads=args.superres_heads,
-    superres_ff_mult=args.superres_ff_mult,
-    superres_t5_name=args.superres_t5_name,
-    superres_image_size=args.superres_image_size,
-):
-    # first instantiate your ViT VQGan VAE
-    # a VQGan VAE made of transformers
-
-    vae = VQGanVAE(dim=superres_dim, vq_codebook_size=superres_vq_codebook_size).cuda()
-
-    vae.load(
-        args.resume_from
-    )  # you will want to load the exponentially moving averaged VAE
-
-    # then you plug the VqGan VAE into your MaskGit as so
-
-    # (1) create your transformer / attention network
-
-    transformer = MaskGitTransformer(
-        num_tokens=superres_num_tokens,  # must be same as codebook size above
-        seq_len=superres_seq_len,  # must be equivalent to fmap_size ** 2 in vae
-        dim=superres_dim,  # model dimension
-        depth=superres_depth,  # depth
-        dim_head=superres_dim_head,  # attention head dimension
-        heads=superres_heads,  # attention heads,
-        ff_mult=superres_ff_mult,  # feedforward expansion factor
-        t5_name=superres_t5_name,  # name of your T5
-    )
-
-    # (2) pass your trained VAE and the base transformer to MaskGit
-
-    superres_maskgit = MaskGit(
-        vae=vae,
-        transformer=transformer,
-        cond_drop_prob=0.25,
-        image_size=superres_image_size,  # larger image size
-        cond_image_size=256,  # conditioning image size <- this must be set
-    ).cuda()
-
-    # ready your training text and images
-    images = torch.randn(4, 3, superres_image_size, superres_image_size).cuda()
-
-    # feed it into your maskgit instance, with return_loss set to True
-
-    loss = superres_maskgit(images, texts=superres_texts)
-
-    loss.backward()
-
-    # do this for a long time on much data
-    # then...
-
-    images = superres_maskgit.generate(
-        texts=[
-            "a whale breaching from afar",
-            "young girl blowing out candles on her birthday cake",
-            "fireworks with blue and green sparkles",
-            "waking up to a psychedelic landscape",
-        ],
-        cond_images=F.interpolate(
-            images, 256
-        ),  # conditioning images must be passed in for generating from superres
-        cond_scale=3.0,
-        timesteps=args.superres_timesteps,
-    )
-
-    # save the superres vae
-    superres_maskgit.save(args.resume_from.replace(".pt", ".superres.pt"))
-
-    # print(images.shape)  # (4, 3, 512, 512)
-    # print(images)  # List[PIL.Image.Image]
-
-    img1 = images[0]
-
-    save_image(img1, f"{results_dir}/outputs/superres_result.png")
-
-    # for count in len(images):
-    #     for image in images:
-    #         image.save(f'{results_dir}/outputs/superres_{count}.png')
-
-
-def generate(
-    prompt=args.prompt,
-    base_model_path=args.base_model_path,
-    superres_maskgit=args.superres_maskgit,
-    dim=args.dim,
-    vq_codebook_size=args.vq_codebook_size,
-    timesteps=args.generate_timesteps,
-    cond_scale=args.generate_cond_scale,
-):
-    base_maskgit = VQGanVAE(dim=dim, vq_codebook_size=vq_codebook_size).cuda()
-
-    superres_maskgit = VQGanVAE(dim=dim, vq_codebook_size=vq_codebook_size).cuda()
-
-    # vae.load(model_path)
-
-    base_maskgit.load(args.resume_from.replace(".pt", ".base.pt"))
-    superres_maskgit.load(args.resume_from.replace(".pt", ".superres.pt"))
-
-    # pass in the trained base_maskgit and superres_maskgit from above
-
-    muse = Muse(base=base_maskgit, superres=superres_maskgit)
-
-    images = muse(texts=prompt, timesteps=timesteps, cond_scale=cond_scale)
-
-    print(images)  # List[PIL.Image.Image]
-
-    img1 = images[0]
-
-    save_image(img1, f"{results_dir}/outputs/result.png")
-
 def main():
     args = parse_args()
     accelerator = Accelerator(