diff --git a/muse_maskgit_pytorch/__init__.py b/muse_maskgit_pytorch/__init__.py
index 192453b..1c68988 100644
--- a/muse_maskgit_pytorch/__init__.py
+++ b/muse_maskgit_pytorch/__init__.py
@@ -1,4 +1,4 @@
 from muse_maskgit_pytorch.vqgan_vae import VQGanVAE
 from muse_maskgit_pytorch.muse_maskgit_pytorch import Transformer, MaskGit, Muse, MaskGitTransformer, TokenCritic
-from muse_maskgit_pytorch.trainers import VQGanVAETrainer
+from muse_maskgit_pytorch.trainers import VQGanVAETrainer, MaskGitTrainer
diff --git a/muse_maskgit_pytorch/dataset.py b/muse_maskgit_pytorch/dataset.py
index b76d1c3..658377d 100644
--- a/muse_maskgit_pytorch/dataset.py
+++ b/muse_maskgit_pytorch/dataset.py
@@ -10,7 +10,8 @@
 from PIL import Image, ImageFile
 from pathlib import Path
 
 from muse_maskgit_pytorch.t5 import MAX_LENGTH
-
+import datasets
+import random
 
 ImageFile.LOAD_TRUNCATED_IMAGES = True
@@ -72,4 +73,17 @@ def __getitem__(self, index):
 
         input_ids = encoded.input_ids
         attn_mask = encoded.attention_mask
-        return self.transform(image), input_ids, attn_mask
\ No newline at end of file
+        return self.transform(image), input_ids, attn_mask
+
+# build an in-memory datasets.Dataset from every jpg under data_root; captions are left as None for now
+def get_dataset_from_dataroot(data_root, args):
+    image_paths = list(Path(data_root).rglob("*.[jJ][pP][gG]"))
+    random.shuffle(image_paths)
+    data_dict = {args.image_column: [], args.caption_column: []}
+    for image_path in image_paths:
+        image = Image.open(image_path)
+        if not image.mode == "RGB":
+            image = image.convert("RGB")
+        data_dict[args.image_column].append(image)
+        data_dict[args.caption_column].append(None)
+    return datasets.Dataset.from_dict(data_dict)
\ No newline at end of file
diff --git a/muse_maskgit_pytorch/t5.py b/muse_maskgit_pytorch/t5.py
index ea40e01..2aa98fe 100644
--- a/muse_maskgit_pytorch/t5.py
+++ b/muse_maskgit_pytorch/t5.py
@@ -56,32 +56,18 @@ def get_encoded_dim(name):
 
 # encoding text
 
-@beartype
-def t5_encode_text(
-    texts: Union[str, List[str]],
-    tokenizer,
-    t5,
-    output_device = None
-):
-    if isinstance(texts, str):
-        texts = [texts]
-
+def t5_encode_text_from_encoded(input_ids,
+    attn_mask,
+    t5,
+    output_device):
     if torch.cuda.is_available():
         t5 = t5.cuda()
 
     device = next(t5.parameters()).device
 
-    encoded = tokenizer.batch_encode_plus(
-        texts,
-        return_tensors = "pt",
-        padding = 'longest',
-        max_length = MAX_LENGTH,
-        truncation = True
-    )
-
-    input_ids = encoded.input_ids.to(device)
-    attn_mask = encoded.attention_mask.to(device)
-
+    # keep the behaviour of the old inline tokenization path: move the pre-tokenized inputs onto the encoder's device
+    input_ids, attn_mask = input_ids.to(device), attn_mask.to(device)
+
     t5.eval()
 
     with torch.no_grad():
@@ -96,3 +79,21 @@ def t5_encode_text(
     encoded_text.to(output_device)
 
     return encoded_text
+@beartype
+def t5_encode_text(
+    texts: Union[str, List[str]],
+    tokenizer,
+    t5,
+    output_device = None
+):
+    if isinstance(texts, str):
+        texts = [texts]
+
+    encoded = tokenizer.batch_encode_plus(
+        texts,
+        return_tensors = "pt",
+        padding = 'longest',
+        max_length = MAX_LENGTH,
+        truncation = True
+    )
+    return t5_encode_text_from_encoded(encoded.input_ids, encoded.attention_mask, t5, output_device)
diff --git a/muse_maskgit_pytorch/trainers/__init__.py b/muse_maskgit_pytorch/trainers/__init__.py
index fd0b1b3..22de99a 100644
--- a/muse_maskgit_pytorch/trainers/__init__.py
+++ b/muse_maskgit_pytorch/trainers/__init__.py
@@ -1,8 +1,2 @@
-"""
-Author: Isamu Isozaki (isamu.website@gmail.com)
-Description: description
-Created: 2023-02-18T19:28:19.819Z
-Modified: !date!
-Modified By: modifier
-"""
-
+from muse_maskgit_pytorch.trainers.vqvae_trainers import VQGanVAETrainer
+from muse_maskgit_pytorch.trainers.maskgit_trainer import MaskGitTrainer
\ No newline at end of file
diff --git a/muse_maskgit_pytorch/trainers/maskgit_trainer.py b/muse_maskgit_pytorch/trainers/maskgit_trainer.py
index 9f43ead..e0a8f5d 100644
--- a/muse_maskgit_pytorch/trainers/maskgit_trainer.py
+++ b/muse_maskgit_pytorch/trainers/maskgit_trainer.py
@@ -22,6 +22,7 @@ from muse_maskgit_pytorch.diffusers_optimization import get_scheduler
 from muse_maskgit_pytorch.muse_maskgit_pytorch import MaskGit
 from muse_maskgit_pytorch.trainers.base_accelerated_trainer import BaseAcceleratedTrainer
+from muse_maskgit_pytorch.t5 import t5_encode_text_from_encoded
 import torch.nn.functional as F
 
 def noop(*args, **kwargs):
     pass
@@ -108,12 +109,12 @@ def train_step(self):
         # logs
         train_loss = 0
         with self.accelerator.accumulate(self.model):
-            imgs, token_ids, attention_mask = next(self.dl_iter)
+            imgs, input_ids, attn_mask = next(self.dl_iter)
+            text_embeds = t5_encode_text_from_encoded(input_ids, attn_mask, self.model.t5, device)
             imgs = imgs.to(device)
             loss = self.model(
                 imgs,
-                token_ids=token_ids,
-                attentioN_mask=attention_mask
+                text_embeds=text_embeds,
                 add_gradient_penalty = apply_grad_penalty,
                 return_loss = True
             )
diff --git a/setup.py b/setup.py
index 91ba65f..8e36652 100644
--- a/setup.py
+++ b/setup.py
@@ -19,6 +19,7 @@
   ],
   install_requires=[
     'accelerate',
+    'datasets',
     'beartype',
     'einops>=0.6',
     'ema-pytorch',
diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py
index e7b01f7..a25bc20 100644
--- a/train_muse_maskgit.py
+++ b/train_muse_maskgit.py
@@ -113,253 +113,6 @@ def parse_args():
     args = parser.parse_args()
     return args
 
-
-def vae_trainer(args):
-    vae = VQGanVAE(dim=args.dim, vq_codebook_size=args.vq_codebook_size)
-
-    current_step = 0
-    resume_from = args.resume_from
-    # load the vae from disk if we have previously trained one
-    if resume_from:
-        print("Resuming VAE from: ", resume_from)
-        vae.load(resume_from)
-
-        resume_from_parts = resume_from.split(".")
-        for i in range(len(resume_from_parts) - 1, -1, -1):
-            if resume_from_parts[i].isdigit():
-                current_step = int(resume_from_parts[i])
-                print("Found step " + str(current_step))
-                break
-        if current_step == 0:
-            print("No step found")
-
-    trainer = VQGanVAETrainer(
-        vae,
-        folder=args.data_folder,
-        current_step=current_step,
-        num_train_steps=args.num_train_steps,
-        batch_size=args.batch_size,
-        image_size=args.image_size,  # you may want to start with small images, and then curriculum learn to larger ones, but because the vae is all convolution, it should generalize to 512 (as in paper) without training on it
-        lr=args.lr,
-        lr_scheduler=args.lr_scheduler,
-        lr_warmup_steps=args.lr_warmup_steps,
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        max_grad_norm=None,
-        discr_max_grad_norm=None,
-        save_results_every=args.save_results_every,
-        save_model_every=args.save_model_every,
-        results_dir=args.results_dir,
-        logging_dir=args.logging_dir,
-        valid_frac=0.05,
-        random_split_seed=42,
-        use_ema=True,
-        ema_beta=0.995,
-        ema_update_after_step=1,
-        ema_update_every=1,
-        apply_grad_penalty_every=4,
-        accelerate_kwargs={
-            'mixed_precision': args.mixed_precisionWW
-        },
-    )
-
-    trainer.train()
-
-
-def base_maskgit_trainer(
-    args
-):
-    # first instantiate your vae
-
-    vae = VQGanVAE(dim=base_dim, vq_codebook_size=base_vq_codebook_size).cuda()
-
-    print("Resuming VAE from: ", args.resume_from)
-    vae.load(
-        args.resume_from
-    )  # you will want to load the exponentially moving averaged VAE
-
-    # then you plug the vae and transformer into your MaskGit as so
-
-    # (1) create your transformer / attention network
-
-    transformer = MaskGitTransformer(
-        num_tokens=base_num_tokens,  # must be same as codebook size above
-        seq_len=base_seq_len,  # must be equivalent to fmap_size ** 2 in vae
-        dim=base_dim,  # model dimension
-        depth=base_depth,  # depth
-        dim_head=base_dim_head,  # attention head dimension
-        heads=base_heads,  # attention heads,
-        ff_mult=base_ff_mult,  # feedforward expansion factor
-        t5_name=base_t5_name,  # name of your T5
-    )
-
-    # (2) pass your trained VAE and the base transformer to MaskGit
-
-    base_maskgit = MaskGit(
-        vae=vae,  # vqgan vae
-        transformer=transformer,  # transformer
-        image_size=base_image_size,  # image size
-        cond_drop_prob=base_cond_drop_prob,  # conditional dropout, for classifier free guidance
-    ).cuda()
-
-    # ready your training text and images
-    images = torch.randn(4, 3, base_image_size, base_image_size).cuda()
-
-    # feed it into your maskgit instance, with return_loss set to True
-
-    loss = base_maskgit(images, texts=base_texts)
-
-    loss.backward()
-
-    # do this for a long time on much data
-
-    # then...
-    images = base_maskgit.generate(
-        texts=[
-            "a whale breaching from afar",
-            "young girl blowing out candles on her birthday cake",
-            "fireworks with blue and green sparkles",
-        ],
-        cond_scale=base_cond_scale,  # conditioning scale for classifier free guidance
-        timesteps=base_timesteps,
-    )
-
-    # save the base vae
-    base_maskgit.save(args.resume_from.replace(".pt", ".base.pt"))
-
-    # print(images.shape) # (3, 3, 256, 256)
-
-    # print(images) # List[PIL.Image.Image]
-
-    img1 = images[0]
-
-    save_image(img1, f"{results_dir}/outputs/base_result.png")
-    # img.save(f'{results_dir}/outputs/base_result.png')
-
-    # for count in len(images):
-    # for image in images:
-    #     image.save(f'{results_dir}/outputs/base_{count}.png')
-
-
-#
-def superres_maskgit_trainer(
-    superres_texts=args.superres_texts,
-    superres_resume_from=args.superres_resume_from,
-    superres_dim=args.superres_dim,
-    superres_vq_codebook_size=args.superres_vq_codebook_size,
-    superres_num_tokens=args.superres_num_tokens,
-    superres_seq_len=args.superres_seq_len,
-    superres_depth=args.superres_depth,
-    superres_dim_head=args.superres_dim_head,
-    superres_heads=args.superres_heads,
-    superres_ff_mult=args.superres_ff_mult,
-    superres_t5_name=args.superres_t5_name,
-    superres_image_size=args.superres_image_size,
-):
-    # first instantiate your ViT VQGan VAE
-    # a VQGan VAE made of transformers
-
-    vae = VQGanVAE(dim=superres_dim, vq_codebook_size=superres_vq_codebook_size).cuda()
-
-    vae.load(
-        args.resume_from
-    )  # you will want to load the exponentially moving averaged VAE
-
-    # then you plug the VqGan VAE into your MaskGit as so
-
-    # (1) create your transformer / attention network
-
-    transformer = MaskGitTransformer(
-        num_tokens=superres_num_tokens,  # must be same as codebook size above
-        seq_len=superres_seq_len,  # must be equivalent to fmap_size ** 2 in vae
-        dim=superres_dim,  # model dimension
-        depth=superres_depth,  # depth
-        dim_head=superres_dim_head,  # attention head dimension
-        heads=superres_heads,  # attention heads,
-        ff_mult=superres_ff_mult,  # feedforward expansion factor
-        t5_name=superres_t5_name,  # name of your T5
-    )
-
-    # (2) pass your trained VAE and the base transformer to MaskGit
-
-    superres_maskgit = MaskGit(
-        vae=vae,
-        transformer=transformer,
-        cond_drop_prob=0.25,
-        image_size=superres_image_size,  # larger image size
-        cond_image_size=256,  # conditioning image size <- this must be set
-    ).cuda()
-
-    # ready your training text and images
-    images = torch.randn(4, 3, superres_image_size, superres_image_size).cuda()
-
-    # feed it into your maskgit instance, with return_loss set to True
-
-    loss = superres_maskgit(images, texts=superres_texts)
-
-    loss.backward()
-
-    # do this for a long time on much data
-    # then...
-
-    images = superres_maskgit.generate(
-        texts=[
-            "a whale breaching from afar",
-            "young girl blowing out candles on her birthday cake",
-            "fireworks with blue and green sparkles",
-            "waking up to a psychedelic landscape",
-        ],
-        cond_images=F.interpolate(
-            images, 256
-        ),  # conditioning images must be passed in for generating from superres
-        cond_scale=3.0,
-        timesteps=args.superres_timesteps,
-    )
-
-    # save the superres vae
-    superres_maskgit.save(args.resume_from.replace(".pt", ".superres.pt"))
-
-    # print(images.shape) # (4, 3, 512, 512)
-    # print(images) # List[PIL.Image.Image]
-
-    img1 = images[0]
-
-    save_image(img1, f"{results_dir}/outputs/superres_result.png")
-
-    # for count in len(images):
-    # for image in images:
-    #     image.save(f'{results_dir}/outputs/superres_{count}.png')
-
-
-def generate(
-    prompt=args.prompt,
-    base_model_path=args.base_model_path,
-    superres_maskgit=args.superres_maskgit,
-    dim=args.dim,
-    vq_codebook_size=args.vq_codebook_size,
-    timesteps=args.generate_timesteps,
-    cond_scale=args.generate_cond_scale,
-):
-    base_maskgit = VQGanVAE(dim=dim, vq_codebook_size=vq_codebook_size).cuda()
-
-    superres_maskgit = VQGanVAE(dim=dim, vq_codebook_size=vq_codebook_size).cuda()
-
-    # vae.load(model_path)
-
-    base_maskgit.load(args.resume_from.replace(".pt", ".base.pt"))
-    superres_maskgit.load(args.resume_from.replace(".pt", ".superres.pt"))
-
-    # pass in the trained base_maskgit and superres_maskgit from above
-
-    muse = Muse(base=base_maskgit, superres=superres_maskgit)
-
-    images = muse(texts=prompt, timesteps=timesteps, cond_scale=cond_scale)
-
-    print(images)  # List[PIL.Image.Image]
-
-    img1 = images[0]
-
-    save_image(img1, f"{results_dir}/outputs/result.png")
-
 def main():
     args = parse_args()
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with=args.report_to,
         logging_dir=args.logging_dir,
     )
diff --git a/train_muse_vae.py b/train_muse_vae.py
index 89f1928..8aac876 100644
--- a/train_muse_vae.py
+++ b/train_muse_vae.py
@@ -2,7 +2,7 @@
 import torch.nn.functional as F
 from torchvision.utils import save_image
 from pathlib import Path
-
+from datasets import load_dataset
 import os
 from muse_maskgit_pytorch import (
     VQGanVAE,
@@ -264,26 +264,17 @@ def parse_args():
     args = parser.parse_args()
     return args
 
-
-def vae_trainer(args):
+def main():
+    args = parse_args()
+    accelerator = Accelerator(
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        mixed_precision=args.mixed_precision,
+        log_with=args.report_to,
+        logging_dir=args.logging_dir,
+    )
+    dataset = ImageDataset(args.folder, args.image_size)
     vae = VQGanVAE(dim=args.dim, vq_codebook_size=args.vq_codebook_size)
 
-    current_step = 0
-    resume_from = args.resume_from
-    # load the vae from disk if we have previously trained one
-    if resume_from:
-        print("Resuming VAE from: ", resume_from)
-        vae.load(resume_from)
-
-        resume_from_parts = resume_from.split(".")
-        for i in range(len(resume_from_parts) - 1, -1, -1):
-            if resume_from_parts[i].isdigit():
-                current_step = int(resume_from_parts[i])
-                print("Found step " + str(current_step))
-                break
-        if current_step == 0:
-            print("No step found")
-
     trainer = VQGanVAETrainer(
         vae,
         folder=args.data_folder,
@@ -315,16 +306,6 @@
     trainer.train()
-def main():
-    args = parse_args()
-    accelerator = Accelerator(
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        mixed_precision=args.mixed_precision,
-        log_with=args.report_to,
-        logging_dir=args.logging_dir,
-    )
-    dataset = ImageDataset(args.folder, args.image_size)
-
 if __name__ == "__main__":