From fb489a7be3db688fb72f341f9e143d800f5b1d07 Mon Sep 17 00:00:00 2001 From: ButterCream <56580073+korakoe@users.noreply.github.com> Date: Fri, 9 Jun 2023 19:38:45 +0800 Subject: [PATCH 01/62] Resolve StopIteration by seeking to start --- muse_maskgit_pytorch/trainers/vqvae_trainers.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/muse_maskgit_pytorch/trainers/vqvae_trainers.py b/muse_maskgit_pytorch/trainers/vqvae_trainers.py index cb8e13f..ea10d52 100644 --- a/muse_maskgit_pytorch/trainers/vqvae_trainers.py +++ b/muse_maskgit_pytorch/trainers/vqvae_trainers.py @@ -177,7 +177,10 @@ def log_validation_images(self, models_to_evaluate, logs, steps): for model, filename in models_to_evaluate: model.eval() - valid_data = next(self.valid_dl_iter) + try: + valid_data = next(self.valid_dl_iter) + except StopIteration: + valid_data = self.valif_dl_iter.seek(0) valid_data = valid_data.to(self.device) recons = model(valid_data, return_recons=True) From a2e11a887aceb625c3adda494e2a1a9bf2328315 Mon Sep 17 00:00:00 2001 From: ButterCream <56580073+korakoe@users.noreply.github.com> Date: Fri, 9 Jun 2023 19:42:10 +0800 Subject: [PATCH 02/62] Seek is only for files Instead I'm reinstantiating the valid_dl --- muse_maskgit_pytorch/trainers/vqvae_trainers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/muse_maskgit_pytorch/trainers/vqvae_trainers.py b/muse_maskgit_pytorch/trainers/vqvae_trainers.py index ea10d52..19dad42 100644 --- a/muse_maskgit_pytorch/trainers/vqvae_trainers.py +++ b/muse_maskgit_pytorch/trainers/vqvae_trainers.py @@ -180,7 +180,9 @@ def log_validation_images(self, models_to_evaluate, logs, steps): try: valid_data = next(self.valid_dl_iter) except StopIteration: - valid_data = self.valif_dl_iter.seek(0) + self.valid_dl_iter = iter(self.valid_dl) + valid_data = next(self.valid_dl_iter) + valid_data = valid_data.to(self.device) recons = model(valid_data, return_recons=True) From 0169f746237cdaf846bd0499c261fd4006a9ece8 Mon Sep 17 00:00:00 2001 From: Korakoe <56580073+korakoe@users.noreply.github.com> Date: Sat, 10 Jun 2023 11:11:54 +0800 Subject: [PATCH 03/62] implement Random cropping will detriment maskgit, but VAE will benefit from this --- muse_maskgit_pytorch/dataset.py | 7 +++++-- train_muse_vae.py | 7 +++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/muse_maskgit_pytorch/dataset.py b/muse_maskgit_pytorch/dataset.py index b42a23f..3863fc5 100644 --- a/muse_maskgit_pytorch/dataset.py +++ b/muse_maskgit_pytorch/dataset.py @@ -39,7 +39,8 @@ def __init__( flip=True, center_crop=True, stream=False, - using_taming=False + using_taming=False, + random_crop = False, ): super().__init__() self.dataset = dataset @@ -51,8 +52,10 @@ def __init__( ] if flip: transform_list.append(T.RandomHorizontalFlip()) - if center_crop: + if center_crop and not random_crop: transform_list.append(T.CenterCrop(image_size)) + if random_crop: + transform_list.append(T.RandomCrop(image_size, pad_if_needed=True)) transform_list.append(T.ToTensor()) self.transform = T.Compose(transform_list) self.using_taming = using_taming diff --git a/train_muse_vae.py b/train_muse_vae.py index 23edd4d..24c15d8 100644 --- a/train_muse_vae.py +++ b/train_muse_vae.py @@ -45,6 +45,11 @@ action="store_true", help="Don't flip image.", ) +parser.add_argument( + "--random_crop", + action="store_true", + help="Crop the images at random locations instead of cropping from the center.", + ) parser.add_argument( "--dataset_save_path", type=str, @@ -267,6 +272,7 @@ class 
Arguments: validation_image_scale: float = 1.0 no_center_crop: bool = False no_flip: bool = False + random_crop: bool = False dataset_save_path: Optional[str] = None clear_previous_experiments: bool = False max_grad_norm: Optional[float] = None @@ -480,6 +486,7 @@ def main(): center_crop=not args.no_center_crop, flip=not args.no_flip, stream=args.streaming, + random_crop=args.random_crop ) # dataloader From 5311ad853c8d9d0e8df9af286ef90dc95f177d1f Mon Sep 17 00:00:00 2001 From: Korakoe <56580073+korakoe@users.noreply.github.com> Date: Sat, 10 Jun 2023 11:11:54 +0800 Subject: [PATCH 04/62] implement Random cropping will detriment maskgit, but VAE will benefit from this --- muse_maskgit_pytorch/dataset.py | 7 +++++-- train_muse_vae.py | 7 +++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/muse_maskgit_pytorch/dataset.py b/muse_maskgit_pytorch/dataset.py index b42a23f..3863fc5 100644 --- a/muse_maskgit_pytorch/dataset.py +++ b/muse_maskgit_pytorch/dataset.py @@ -39,7 +39,8 @@ def __init__( flip=True, center_crop=True, stream=False, - using_taming=False + using_taming=False, + random_crop = False, ): super().__init__() self.dataset = dataset @@ -51,8 +52,10 @@ def __init__( ] if flip: transform_list.append(T.RandomHorizontalFlip()) - if center_crop: + if center_crop and not random_crop: transform_list.append(T.CenterCrop(image_size)) + if random_crop: + transform_list.append(T.RandomCrop(image_size, pad_if_needed=True)) transform_list.append(T.ToTensor()) self.transform = T.Compose(transform_list) self.using_taming = using_taming diff --git a/train_muse_vae.py b/train_muse_vae.py index 23edd4d..24c15d8 100644 --- a/train_muse_vae.py +++ b/train_muse_vae.py @@ -45,6 +45,11 @@ action="store_true", help="Don't flip image.", ) +parser.add_argument( + "--random_crop", + action="store_true", + help="Crop the images at random locations instead of cropping from the center.", + ) parser.add_argument( "--dataset_save_path", type=str, @@ -267,6 +272,7 @@ class Arguments: validation_image_scale: float = 1.0 no_center_crop: bool = False no_flip: bool = False + random_crop: bool = False dataset_save_path: Optional[str] = None clear_previous_experiments: bool = False max_grad_norm: Optional[float] = None @@ -480,6 +486,7 @@ def main(): center_crop=not args.no_center_crop, flip=not args.no_flip, stream=args.streaming, + random_crop=args.random_crop ) # dataloader From 28315eefaa07d8b27a457fed91e0c14075fdef0c Mon Sep 17 00:00:00 2001 From: ButterCream <56580073+korakoe@users.noreply.github.com> Date: Sat, 10 Jun 2023 21:13:17 +0800 Subject: [PATCH 05/62] Use LR steps counter instead of num_steps --- muse_maskgit_pytorch/trainers/vqvae_trainers.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/muse_maskgit_pytorch/trainers/vqvae_trainers.py b/muse_maskgit_pytorch/trainers/vqvae_trainers.py index 19dad42..15c8f0e 100644 --- a/muse_maskgit_pytorch/trainers/vqvae_trainers.py +++ b/muse_maskgit_pytorch/trainers/vqvae_trainers.py @@ -105,11 +105,16 @@ def __init__( self.optim = get_optimizer(use_8bit_adam, optimizer, vae_parameters, lr, weight_decay) self.discr_optim = get_optimizer(use_8bit_adam, optimizer, discr_parameters, lr, weight_decay) + if self.num_train_steps <= 0: + self.num_lr_steps = self.num_train_steps * self.gradient_accumulation_steps + else: + self.num_lr_steps = self.num_epochs * len(self.dl) + self.lr_scheduler: LRScheduler = get_scheduler( lr_scheduler_type, optimizer=self.optim, num_warmup_steps=lr_warmup_steps * self.gradient_accumulation_steps, 
- num_training_steps=self.num_train_steps * self.gradient_accumulation_steps, + num_training_steps=self.num_lr_steps, num_cycles=num_cycles, power=scheduler_power, ) @@ -118,7 +123,7 @@ def __init__( lr_scheduler_type, optimizer=self.discr_optim, num_warmup_steps=lr_warmup_steps * self.gradient_accumulation_steps, - num_training_steps=self.num_train_steps * self.gradient_accumulation_steps, + num_training_steps=self.num_lr_steps, num_cycles=num_cycles, power=scheduler_power, ) From ae8609a7baa6b5f7d1ff22ce9c3f4e48f0a911ec Mon Sep 17 00:00:00 2001 From: ButterCream <56580073+korakoe@users.noreply.github.com> Date: Sat, 10 Jun 2023 21:14:05 +0800 Subject: [PATCH 06/62] Other way around --- muse_maskgit_pytorch/trainers/vqvae_trainers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/muse_maskgit_pytorch/trainers/vqvae_trainers.py b/muse_maskgit_pytorch/trainers/vqvae_trainers.py index 15c8f0e..3bffe3a 100644 --- a/muse_maskgit_pytorch/trainers/vqvae_trainers.py +++ b/muse_maskgit_pytorch/trainers/vqvae_trainers.py @@ -105,7 +105,7 @@ def __init__( self.optim = get_optimizer(use_8bit_adam, optimizer, vae_parameters, lr, weight_decay) self.discr_optim = get_optimizer(use_8bit_adam, optimizer, discr_parameters, lr, weight_decay) - if self.num_train_steps <= 0: + if self.num_train_steps > 0: self.num_lr_steps = self.num_train_steps * self.gradient_accumulation_steps else: self.num_lr_steps = self.num_epochs * len(self.dl) From 61e4325903d817c61b3b035abe8d0523c97dc2f7 Mon Sep 17 00:00:00 2001 From: ButterCream <56580073+korakoe@users.noreply.github.com> Date: Sat, 10 Jun 2023 23:28:18 +0800 Subject: [PATCH 07/62] Implement proper scheduling with epoch system (#3) * implement Random cropping will detriment maskgit, but VAE will benefit from this * Use LR steps counter instead of num_steps * Other way around --- muse_maskgit_pytorch/trainers/vqvae_trainers.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/muse_maskgit_pytorch/trainers/vqvae_trainers.py b/muse_maskgit_pytorch/trainers/vqvae_trainers.py index 19dad42..3bffe3a 100644 --- a/muse_maskgit_pytorch/trainers/vqvae_trainers.py +++ b/muse_maskgit_pytorch/trainers/vqvae_trainers.py @@ -105,11 +105,16 @@ def __init__( self.optim = get_optimizer(use_8bit_adam, optimizer, vae_parameters, lr, weight_decay) self.discr_optim = get_optimizer(use_8bit_adam, optimizer, discr_parameters, lr, weight_decay) + if self.num_train_steps > 0: + self.num_lr_steps = self.num_train_steps * self.gradient_accumulation_steps + else: + self.num_lr_steps = self.num_epochs * len(self.dl) + self.lr_scheduler: LRScheduler = get_scheduler( lr_scheduler_type, optimizer=self.optim, num_warmup_steps=lr_warmup_steps * self.gradient_accumulation_steps, - num_training_steps=self.num_train_steps * self.gradient_accumulation_steps, + num_training_steps=self.num_lr_steps, num_cycles=num_cycles, power=scheduler_power, ) @@ -118,7 +123,7 @@ def __init__( lr_scheduler_type, optimizer=self.discr_optim, num_warmup_steps=lr_warmup_steps * self.gradient_accumulation_steps, - num_training_steps=self.num_train_steps * self.gradient_accumulation_steps, + num_training_steps=self.num_lr_steps, num_cycles=num_cycles, power=scheduler_power, ) From 91c6dcfd03a2f7303e33525d7c0a256071d8979e Mon Sep 17 00:00:00 2001 From: Korakoe <56580073+korakoe@users.noreply.github.com> Date: Sun, 11 Jun 2023 11:16:11 +0800 Subject: [PATCH 08/62] implement epoch system for muse scheduler --- train_muse_maskgit.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 
deletion(-) diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py index 5d19559..cbb8666 100644 --- a/train_muse_maskgit.py +++ b/train_muse_maskgit.py @@ -735,11 +735,16 @@ def main(): args.use_8bit_adam, args.optimizer, set(transformer.parameters()), args.lr, args.weight_decay ) + if args.num_train_steps > 0: + num_lr_steps = args.num_train_steps * args.gradient_accumulation_steps + else: + num_lr_steps = args.num_epochs * len(dataloader) + scheduler: SchedulerType = get_scheduler( args.lr_scheduler, optimizer=optimizer, num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, - num_training_steps=args.num_train_steps * args.gradient_accumulation_steps, + num_training_steps=num_lr_steps, num_cycles=args.num_cycles, power=args.scheduler_power, ) From 60e6f77504f2b8fd0033403aefca048eb8cb1d7b Mon Sep 17 00:00:00 2001 From: Andrew Powers-Holmes Date: Sat, 10 Jun 2023 13:57:08 +0000 Subject: [PATCH 09/62] update tpu-vm.env --- tpu-vm.env | 62 +++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 50 insertions(+), 12 deletions(-) diff --git a/tpu-vm.env b/tpu-vm.env index 2dd926e..a2045b7 100644 --- a/tpu-vm.env +++ b/tpu-vm.env @@ -2,23 +2,61 @@ # dot-source this, ok? export PYTHONUNBUFFERED='1' -export ACCELERATE_LOG_LEVEL='DEBUG' -# you may need to clear LD_PRELOAD so tcmalloc doesn't mess us up -#unset LD_PRELOAD +## General log level opts for Accelerate/Transformers +#export ACCELERATE_LOG_LEVEL='INFO' +#export TRANSFORMERS_LOG_LEVEL='INFO' -# set LD_LIBRARY_PATH to point at fresh libtpu -export LD_LIBRARY_PATH=/usr/local/libtpu.so:${LD_LIBRARY_PATH} +# tcmalloc breaks things and google enable it by default, so that's gotta go +unset LD_PRELOAD -# Set these if you want some fun debug info -#export TF_CPP_MIN_LOG_LEVEL=0 -#export TF_CPP_LOG_THREAD_ID=1 -#export TF_CPP_VMODULE='tensor=4,computation_client=5,xrt_computation_client=5,aten_xla_type=5' -#export PT_XLA_DEBUG=1 +# add the dir where `libtpu-nightly` puts the library to LD_LIBRARY_PATH +export LD_LIBRARY_PATH="/usr/local/lib/python3.8/dist-packages/libtpu/:${LD_LIBRARY_PATH}" -# we can't use PJRT with DistributedDataParallel because it uses multithreading and the RNG sync fails. -# thanks google. +# PJRT doesn't work with Accelerate yet so we deconfigure it and go back to old XRT unset PJRT_DEVICE export XRT_TPU_CONFIG='localservice;0;localhost:51011' export MASTER_ADDR='localhost' export MASTER_PORT='12355' + +## see https://github.com/pytorch/xla/issues/4914 +export XLA_IR_SHAPE_CACHE_SIZE=12288 + +## useful options for debug +#export PT_XLA_DEBUG=1 +# Enables the Python stack trace to be captured where creating IR nodes, hence allowing to understand which PyTorch operation was responsible for generating the IR. +#export XLA_IR_DEBUG=1 +# Path to save the IR graphs generated during execution. +#export XLA_SAVE_TENSORS_FILE='' +# File type for above. can be text, dot (GraphViz), or hlo (native) +#export XLA_SAVE_TENSORS_FMT='text' +# Path to save metrics after every op +#export XLA_METRICS_FILE= +# In case of compilation/execution error, the offending HLO graph will be saved here. 
+#export XLA_SAVE_HLO_FILE= + +# Enable OpByOp dispatch for "get tensors" +#export XLA_GET_TENSORS_OPBYOP=1 +# Enable OpByOp dispatch for "sync tensors" +#export XLA_SYNC_TENSORS_OPBYOP=1 +# Force XLA tensor sync before moving to next step +#export XLA_SYNC_WAIT=1 + +# Force downcasting of fp32 to bf16 +#export XLA_USE_BF16=1 +# Force downcasting of fp32 to fp16 +#export XLA_USE_F16=1 +# Force downcasting of fp64 to fp32 +#export XLA_USE_32BIT_LONG=1 + +## TPU runtime / compilation debug logging +# All XLA log messages are INFO level so this is required +#export TF_CPP_MIN_LOG_LEVEL=0 +# Print the thread ID in log messages +#export TF_CPP_LOG_THREAD_ID=1 +# What modules to print from at what level +#export TF_CPP_VMODULE='tensor=4,computation_client=5,xrt_computation_client=5,aten_xla_type=5' + +## Limit to single TPU chip/core, can be useful for testing +# export TPU_PROCESS_BOUNDS='1,1,1' +# export TPU_VISIBLE_CHIPS=0 From 4bdc669fe76cd2697516d15f57f7c57323300ed7 Mon Sep 17 00:00:00 2001 From: Andrew Powers-Holmes Date: Sun, 11 Jun 2023 05:48:18 +0000 Subject: [PATCH 10/62] vqgan_vae_taming: fix broken einops map --- muse_maskgit_pytorch/vqgan_vae_taming.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/muse_maskgit_pytorch/vqgan_vae_taming.py b/muse_maskgit_pytorch/vqgan_vae_taming.py index c9fa12b..3602b03 100644 --- a/muse_maskgit_pytorch/vqgan_vae_taming.py +++ b/muse_maskgit_pytorch/vqgan_vae_taming.py @@ -11,7 +11,7 @@ from einops import rearrange from omegaconf import OmegaConf, DictConfig -from taming.models.vqgan import VQModel # , GumbelVQ +from taming.models.vqgan import VQModel from torch import nn from tqdm_loggable.auto import tqdm @@ -104,9 +104,9 @@ def __init__(self, vqgan_model_path=None, vqgan_config_path=None, accelerator: A model.load_state_dict(state, strict=False) print(f"Loaded VQGAN from {model_path} and {config_path}") - self.model = model - # f as used in https://github.com/CompVis/taming-transformers#overview-of-pretrained-models + self.model: VQModel = model + # f as used in https://github.com/CompVis/taming-transformers#overview-of-pretrained-models f = config.model.params.ddconfig.resolution / config.model.params.ddconfig.attn_resolutions[0] self.num_layers = int(log(f) / log(2)) self.channels = 3 @@ -149,7 +149,7 @@ def encode(self, im_seq): fmap, loss, (_, _, min_encodings_indices) = self.model.encode(im_seq) b, _, h, w = fmap.shape - min_encodings_indices = rearrange(min_encodings_indices, "(b h w) 1 -> b h w", h=h, w=w, b=b) + min_encodings_indices = rearrange(min_encodings_indices, "(b h w) -> b h w", h=h, w=w, b=b) return fmap, min_encodings_indices, loss def decode_ids(self, ids): From 4baf63f997bcee1ba208383729a7aaf446f3ce2b Mon Sep 17 00:00:00 2001 From: Alejandro Gil Date: Sun, 11 Jun 2023 00:13:57 -0700 Subject: [PATCH 11/62] Added hf_split_name to set the split or subset to use for huggingface datasets. 
- Renamed `skip_arrow` to `no_cache.` --- train_muse_maskgit.py | 40 ++++++++++++++++++++++------------------ train_muse_vae.py | 20 +++++++++++++------- 2 files changed, 35 insertions(+), 25 deletions(-) diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py index cbb8666..df5c440 100644 --- a/train_muse_maskgit.py +++ b/train_muse_maskgit.py @@ -188,6 +188,12 @@ default=None, help="ID of HuggingFace dataset to use (cannot be used with --train_data_dir)", ) +parser.add_argument( + "--hf_split_name", + type=str, + default="train", + help="Subset or split to use from the dataset when using a dataset form HuggingFace.", +) parser.add_argument( "--streaming", action="store_true", @@ -254,9 +260,9 @@ help="Image Size.", ) parser.add_argument( - "--vq_codebook_dim", - type=int, - default=256, + "--vq_codebook_dim", + type=int, + default=256, help="VQ Codebook dimensions.") parser.add_argument( "--cond_drop_prob", @@ -335,9 +341,9 @@ help="The path to cache huggingface models", ) parser.add_argument( - "--skip_arrow", + "--no_cache", action="store_true", - help="whether to skip converting the dataset to arrow, and to directly fetch data", + help="Do not save the dataset pyarrow cache/files to disk to save disk space and reduce the time it takes to launch the training.", ) parser.add_argument( "--link", @@ -425,7 +431,7 @@ class Arguments: optimizer: str = "Lion" weight_decay: float = 0.0 cache_path: Optional[str] = None - skip_arrow: bool = False + no_cache: bool = False link: bool = False latest_checkpoint: bool = False debug: bool = False @@ -491,15 +497,13 @@ def main(): # Load the dataset (main process first to download, rest will load from cache) with accelerator.main_process_first(): if args.train_data_dir is not None: - if args.skip_arrow: - pass - else: - dataset = get_dataset_from_dataroot( - args.train_data_dir, - image_column=args.image_column, - caption_column=args.caption_column, - save_path=args.dataset_save_path, - ) + dataset = get_dataset_from_dataroot( + args.train_data_dir, + image_column=args.image_column, + caption_column=args.caption_column, + save_path=args.dataset_save_path, + save=not args.no_cache, + ) elif args.dataset_name is not None: dataset = load_dataset( args.dataset_name, @@ -510,9 +514,9 @@ def main(): ) if args.streaming: if args.cache_path: - dataset = load_dataset(args.dataset_name, cache_dir=args.cache_path)["train"] + dataset = load_dataset(args.dataset_name, cache_dir=args.cache_path)[args.hf_split_name] else: - dataset = load_dataset(args.dataset_name)["train"] + dataset = load_dataset(args.dataset_name)[args.hf_split_name] else: raise ValueError("You must pass either train_data_dir or dataset_name (but not both)") @@ -686,7 +690,7 @@ def main(): # Create the dataset objects with accelerator.main_process_first(): - if args.skip_arrow and args.train_data_dir: + if args.no_cache and args.train_data_dir: dataset = LocalTextImageDataset( args.train_data_dir, args.image_size, diff --git a/train_muse_vae.py b/train_muse_vae.py index 24c15d8..928a366 100644 --- a/train_muse_vae.py +++ b/train_muse_vae.py @@ -138,6 +138,12 @@ default=None, help="Name of the huggingface dataset used.", ) +parser.add_argument( + "--hf_split_name", + type=str, + default="train", + help="Subset or split to use from the dataset when using a dataset form HuggingFace.", +) parser.add_argument( "--streaming", action="store_true", @@ -256,9 +262,9 @@ help="The path to cache huggingface models", ) parser.add_argument( - "--skip_arrow", + "--no_cache", action="store_true", - 
help="Whether to skip saving the dataset to Arrow files", + help="Do not save the dataset pyarrow cache/files to disk to save disk space and reduce the time it takes to launch the training.", ) parser.add_argument( "--latest_checkpoint", @@ -318,7 +324,7 @@ class Arguments: optimizer: str = "Lion" weight_decay: float = 0.0 cache_path: Optional[str] = None - skip_arrow: bool = False + no_cache: bool = False latest_checkpoint: bool = False debug: bool = False config_path: Optional[str] = None @@ -376,7 +382,7 @@ def main(): image_column=args.image_column, caption_column=args.caption_column, save_path=args.dataset_save_path, - save=not args.skip_arrow + save=not args.no_cache ) elif args.dataset_name: if args.cache_path: @@ -392,9 +398,9 @@ def main(): print("Dataset doesn't support streaming, disabling streaming") args.streaming = False if args.cache_path: - dataset = load_dataset(args.dataset_name, cache_dir=args.cache_path)["train"] + dataset = load_dataset(args.dataset_name, cache_dir=args.cache_path)[args.hf_split_name] else: - dataset = load_dataset(args.dataset_name)["train"] + dataset = load_dataset(args.dataset_name)[args.hf_split_name] if args.resume_path is not None: load = True @@ -476,7 +482,7 @@ def main(): accelerator=accelerator, ) - + current_step = 0 dataset = ImageDataset( From 7a108f727b119c5af4d38b84c2a16e258e92107e Mon Sep 17 00:00:00 2001 From: Alejandro Gil Date: Sun, 11 Jun 2023 00:17:41 -0700 Subject: [PATCH 12/62] Added a check so when the cache exist but the folder where the original data for the dataset is modified the cache is removed and recreated, this will help keeping up the cache up to date with the dataset directory. --- muse_maskgit_pytorch/dataset.py | 33 +++++++++++++++++++++++++++++---- train_muse_maskgit.py | 1 + 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/muse_maskgit_pytorch/dataset.py b/muse_maskgit_pytorch/dataset.py index 3863fc5..53e2278 100644 --- a/muse_maskgit_pytorch/dataset.py +++ b/muse_maskgit_pytorch/dataset.py @@ -1,5 +1,5 @@ import os -import random +import random, shutil import sys import time from pathlib import Path @@ -7,7 +7,7 @@ import datasets import torch -from datasets import Image +from datasets import Image, load_from_disk from PIL import Image as pImage from PIL import ImageFile from torch.utils.data import DataLoader, Dataset, random_split @@ -287,11 +287,35 @@ def save_dataset_with_progress(dataset, save_path): time.sleep(1) -def get_dataset_from_dataroot(data_root, image_column="image", caption_column="caption", save_path="dataset", save=True): +def get_dataset_from_dataroot( + data_root, image_column="image", caption_column="caption", save_path="dataset", save=True, + ): # Check if data_root is a symlink and resolve it to its target location if it is if os.path.islink(data_root): data_root = os.path.realpath(data_root) + if os.path.exists(save_path): + # Get the modified time of save_path + save_path_mtime = os.stat(save_path).st_mtime + + if save: + # Traverse the directory tree of data_root and get the modified time of all files and subdirectories + print("Checking modified date of all the files and subdirectories in the dataset folder.") + data_root_mtime = max( + os.stat(os.path.join(root, f)).st_mtime + for root, dirs, files in os.walk(data_root) + for f in files + dirs + ) + + # Check if data_root is newer than save_path + if data_root_mtime > save_path_mtime: + print("The data_root folder has being updated recently. 
Removing previously saved dataset and updating it.") + shutil.rmtree(save_path, ignore_errors=True) + else: + print("The dataset is up-to-date. Loading...") + # Load the dataset from save_path if it is up-to-date + return load_from_disk(save_path) + extensions = ["jpg", "jpeg", "png", "webp"] image_paths = [] @@ -315,9 +339,10 @@ def get_dataset_from_dataroot(data_root, image_column="image", caption_column="c data_dict[caption_column].append(captions) dataset = datasets.Dataset.from_dict(data_dict) dataset = dataset.cast_column(image_column, Image()) - # dataset.save_to_disk(save_path) + if save: save_dataset_with_progress(dataset, save_path) + return dataset diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py index df5c440..0ad7035 100644 --- a/train_muse_maskgit.py +++ b/train_muse_maskgit.py @@ -407,6 +407,7 @@ class Arguments: logging_dir: str = "results/logs" vae_path: Optional[str] = None dataset_name: Optional[str] = None + hf_split_name: Optional[str] = None streaming: bool = False train_data_dir: Optional[str] = None num_train_steps: int = -1 From 74ce3466802ce3e35d26d39eb216ad1538578933 Mon Sep 17 00:00:00 2001 From: Alejandro Gil Date: Sun, 11 Jun 2023 00:34:45 -0700 Subject: [PATCH 13/62] Added random_crop argument for the maskgit training. --- muse_maskgit_pytorch/dataset.py | 14 +++++++++----- train_muse_maskgit.py | 8 +++++++- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/muse_maskgit_pytorch/dataset.py b/muse_maskgit_pytorch/dataset.py index 53e2278..21409eb 100644 --- a/muse_maskgit_pytorch/dataset.py +++ b/muse_maskgit_pytorch/dataset.py @@ -85,16 +85,18 @@ def __init__( flip=True, center_crop=True, stream=False, - using_taming=False + using_taming=False, + random_crop=False, ): super().__init__( dataset, image_size=image_size, image_column=image_column, flip=flip, - center_crop=center_crop, stream=stream, - using_taming=using_taming + center_crop=center_crop, + using_taming=using_taming, + random_crop=random_crop, ) self.caption_column: str = caption_column self.tokenizer: T5Tokenizer = tokenizer @@ -192,7 +194,7 @@ def __getitem__(self, index): class LocalTextImageDataset(Dataset): - def __init__(self, path, image_size, tokenizer, flip=True, center_crop=True, using_taming=False): + def __init__(self, path, image_size, tokenizer, flip=True, center_crop=True, using_taming=False, random_crop=False): super().__init__() self.tokenizer = tokenizer self.using_taming = using_taming @@ -226,8 +228,10 @@ def __init__(self, path, image_size, tokenizer, flip=True, center_crop=True, usi ] if flip: transform_list.append(T.RandomHorizontalFlip()) - if center_crop: + if center_crop and not random_crop: transform_list.append(T.CenterCrop(image_size)) + if random_crop: + transform_list.append(T.RandomCrop(image_size, pad_if_needed=True)) transform_list.append(T.ToTensor()) self.transform = T.Compose(transform_list) diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py index 0ad7035..1b88de2 100644 --- a/train_muse_maskgit.py +++ b/train_muse_maskgit.py @@ -69,6 +69,11 @@ action="store_true", help="Don't do center crop.", ) +parser.add_argument( + "--random_crop", + action="store_true", + help="Crop the images at random locations instead of cropping from the center.", + ) parser.add_argument( "--no_flip", action="store_true", @@ -698,7 +703,8 @@ def main(): tokenizer=transformer.tokenizer, center_crop=False if args.no_center_crop else True, flip=False if args.no_flip else True, - using_taming=False if not args.taming_model_path else True + using_taming=False 
if not args.taming_model_path else True, + random_crop=args.random_crop if args.random_crop else False, ) elif args.link: if not args.dataset_name: From e1a3ba1ade2a864c84f53839c048833528df1e7b Mon Sep 17 00:00:00 2001 From: Korakoe <56580073+korakoe@users.noreply.github.com> Date: Sun, 11 Jun 2023 16:14:18 +0800 Subject: [PATCH 14/62] Prevent possible reinstantiation also dont validate ema, inconsistent with other trainers --- .../trainers/vqvae_trainers.py | 50 ++++++++----------- 1 file changed, 22 insertions(+), 28 deletions(-) diff --git a/muse_maskgit_pytorch/trainers/vqvae_trainers.py b/muse_maskgit_pytorch/trainers/vqvae_trainers.py index 3bffe3a..b0be794 100644 --- a/muse_maskgit_pytorch/trainers/vqvae_trainers.py +++ b/muse_maskgit_pytorch/trainers/vqvae_trainers.py @@ -176,35 +176,34 @@ def save(self, path): ) self.accelerator.save(pkg, path) - def log_validation_images(self, models_to_evaluate, logs, steps): + def log_validation_images(self, logs, steps): log_imgs = [] - prompts = ["vae"] if len(models_to_evaluate) == 1 else ["vae", "ema"] - for model, filename in models_to_evaluate: - model.eval() + self.model.eval() - try: - valid_data = next(self.valid_dl_iter) - except StopIteration: - self.valid_dl_iter = iter(self.valid_dl) - valid_data = next(self.valid_dl_iter) - - valid_data = valid_data.to(self.device) + try: + valid_data = next(self.valid_dl_iter) + except StopIteration: + self.valid_dl_iter = iter(self.valid_dl) + valid_data = next(self.valid_dl_iter) - recons = model(valid_data, return_recons=True) + valid_data = valid_data.to(self.device) - # else save a grid of images + recons = self.model(valid_data, return_recons=True) - imgs_and_recons = torch.stack((valid_data, recons), dim=0) - imgs_and_recons = rearrange(imgs_and_recons, "r b ... -> (b r) ...") + # else save a grid of images - imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0.0, 1.0) - grid = make_grid(imgs_and_recons, nrow=2, normalize=True, value_range=(0, 1)) + imgs_and_recons = torch.stack((valid_data, recons), dim=0) + imgs_and_recons = rearrange(imgs_and_recons, "r b ... -> (b r) ...") - logs["reconstructions"] = grid - save_file = str(self.results_dir / f"{filename}.png") - save_image(grid, save_file) - log_imgs.append(Image.open(save_file)) - super().log_validation_images(log_imgs, steps, prompts=prompts) + imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0.0, 1.0) + grid = make_grid(imgs_and_recons, nrow=2, normalize=True, value_range=(0, 1)) + + logs["reconstructions"] = grid + save_file = str(self.results_dir / f"{steps}.png") + save_image(grid, save_file) + log_imgs.append(Image.open(save_file)) + super().log_validation_images(log_imgs, steps, prompts=["vae"]) + self.model.train() def train(self): self.steps = self.steps + 1 @@ -289,12 +288,7 @@ def train(self): # sample results every so often if (steps % self.save_results_every) == 0: - vaes_to_evaluate = ((self.model, str(steps)),) - - if self.use_ema: - vaes_to_evaluate = ((ema_model.ema_model, f"{steps}.ema"),) + vaes_to_evaluate - - self.log_validation_images(vaes_to_evaluate, logs, steps) + self.log_validation_images(logs, steps) self.accelerator.print(f"[E{epoch + 1}][S{steps:05d}]{proc_label}: saving to {str(self.results_dir)}") # save model every so often From 1ff8bd78d8d180f839276579715b44e70b6c8fe8 Mon Sep 17 00:00:00 2001 From: Alejandro Gil Date: Sun, 11 Jun 2023 01:19:01 -0700 Subject: [PATCH 15/62] Fixed double spacing to comply with codefactor. 
--- train_muse_maskgit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py index 1b88de2..df3d52d 100644 --- a/train_muse_maskgit.py +++ b/train_muse_maskgit.py @@ -412,7 +412,7 @@ class Arguments: logging_dir: str = "results/logs" vae_path: Optional[str] = None dataset_name: Optional[str] = None - hf_split_name: Optional[str] = None + hf_split_name: Optional[str] = None streaming: bool = False train_data_dir: Optional[str] = None num_train_steps: int = -1 From 574a61c86c8d3b99ed5b813f88a84b823993c4ee Mon Sep 17 00:00:00 2001 From: Alejandro Gil Date: Sun, 11 Jun 2023 00:13:57 -0700 Subject: [PATCH 16/62] Added hf_split_name to set the split or subset to use for huggingface datasets. - Renamed `skip_arrow` to `no_cache.` --- train_muse_maskgit.py | 40 ++++++++++++++++++++++------------------ train_muse_vae.py | 20 +++++++++++++------- 2 files changed, 35 insertions(+), 25 deletions(-) diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py index cbb8666..df5c440 100644 --- a/train_muse_maskgit.py +++ b/train_muse_maskgit.py @@ -188,6 +188,12 @@ default=None, help="ID of HuggingFace dataset to use (cannot be used with --train_data_dir)", ) +parser.add_argument( + "--hf_split_name", + type=str, + default="train", + help="Subset or split to use from the dataset when using a dataset form HuggingFace.", +) parser.add_argument( "--streaming", action="store_true", @@ -254,9 +260,9 @@ help="Image Size.", ) parser.add_argument( - "--vq_codebook_dim", - type=int, - default=256, + "--vq_codebook_dim", + type=int, + default=256, help="VQ Codebook dimensions.") parser.add_argument( "--cond_drop_prob", @@ -335,9 +341,9 @@ help="The path to cache huggingface models", ) parser.add_argument( - "--skip_arrow", + "--no_cache", action="store_true", - help="whether to skip converting the dataset to arrow, and to directly fetch data", + help="Do not save the dataset pyarrow cache/files to disk to save disk space and reduce the time it takes to launch the training.", ) parser.add_argument( "--link", @@ -425,7 +431,7 @@ class Arguments: optimizer: str = "Lion" weight_decay: float = 0.0 cache_path: Optional[str] = None - skip_arrow: bool = False + no_cache: bool = False link: bool = False latest_checkpoint: bool = False debug: bool = False @@ -491,15 +497,13 @@ def main(): # Load the dataset (main process first to download, rest will load from cache) with accelerator.main_process_first(): if args.train_data_dir is not None: - if args.skip_arrow: - pass - else: - dataset = get_dataset_from_dataroot( - args.train_data_dir, - image_column=args.image_column, - caption_column=args.caption_column, - save_path=args.dataset_save_path, - ) + dataset = get_dataset_from_dataroot( + args.train_data_dir, + image_column=args.image_column, + caption_column=args.caption_column, + save_path=args.dataset_save_path, + save=not args.no_cache, + ) elif args.dataset_name is not None: dataset = load_dataset( args.dataset_name, @@ -510,9 +514,9 @@ def main(): ) if args.streaming: if args.cache_path: - dataset = load_dataset(args.dataset_name, cache_dir=args.cache_path)["train"] + dataset = load_dataset(args.dataset_name, cache_dir=args.cache_path)[args.hf_split_name] else: - dataset = load_dataset(args.dataset_name)["train"] + dataset = load_dataset(args.dataset_name)[args.hf_split_name] else: raise ValueError("You must pass either train_data_dir or dataset_name (but not both)") @@ -686,7 +690,7 @@ def main(): # Create the dataset objects with 
accelerator.main_process_first():
- if args.skip_arrow and args.train_data_dir:
+ if args.no_cache and args.train_data_dir:
dataset = LocalTextImageDataset(
args.train_data_dir,
args.image_size,
diff --git a/train_muse_vae.py b/train_muse_vae.py
index 24c15d8..928a366 100644
--- a/train_muse_vae.py
+++ b/train_muse_vae.py
@@ -138,6 +138,12 @@
default=None,
help="Name of the huggingface dataset used.",
)
+parser.add_argument(
+ "--hf_split_name",
+ type=str,
+ default="train",
+ help="Subset or split to use from the dataset when using a dataset form HuggingFace.",
+)
parser.add_argument(
"--streaming",
action="store_true",
@@ -256,9 +262,9 @@
help="The path to cache huggingface models",
)
parser.add_argument(
- "--skip_arrow",
+ "--no_cache",
action="store_true",
- help="Whether to skip saving the dataset to Arrow files",
+ help="Do not save the dataset pyarrow cache/files to disk to save disk space and reduce the time it takes to launch the training.",
)
parser.add_argument(
"--latest_checkpoint",
action="store_true",
@@ -318,7 +324,7 @@ class Arguments:
optimizer: str = "Lion"
weight_decay: float = 0.0
cache_path: Optional[str] = None
- skip_arrow: bool = False
+ no_cache: bool = False
latest_checkpoint: bool = False
debug: bool = False
config_path: Optional[str] = None
@@ -376,7 +382,7 @@ def main():
image_column=args.image_column,
caption_column=args.caption_column,
save_path=args.dataset_save_path,
- save=not args.skip_arrow
+ save=not args.no_cache
)
elif args.dataset_name:
if args.cache_path:
@@ -392,9 +398,9 @@
print("Dataset doesn't support streaming, disabling streaming")
args.streaming = False
if args.cache_path:
- dataset = load_dataset(args.dataset_name, cache_dir=args.cache_path)["train"]
+ dataset = load_dataset(args.dataset_name, cache_dir=args.cache_path)[args.hf_split_name]
else:
- dataset = load_dataset(args.dataset_name)["train"]
+ dataset = load_dataset(args.dataset_name)[args.hf_split_name]
if args.resume_path is not None:
load = True
@@ -476,7 +482,7 @@ def main():
accelerator=accelerator,
)
-
+
current_step = 0
dataset = ImageDataset(

From 4effe4ac8c5aad1fbe2e3611a166a945c0643bee Mon Sep 17 00:00:00 2001
From: Alejandro Gil
Date: Sun, 11 Jun 2023 00:17:41 -0700
Subject: [PATCH 17/62] Added a check so that when the cache exists but the folder holding the original data for the dataset has been modified, the cache is removed and recreated; this helps keep the cache up to date with the dataset directory.
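In short: take the newest modification time of anything under data_root, compare it against the saved cache's mtime, and rebuild the cache when the source tree is newer. A minimal standalone sketch of the idea (the helper name here is illustrative; the real logic lives inline in get_dataset_from_dataroot below):

import os
import shutil
from datasets import load_from_disk

def load_cache_if_fresh(data_root, save_path):
    # Newest mtime of any file or subdirectory under data_root
    newest = max(
        os.stat(os.path.join(root, name)).st_mtime
        for root, dirs, files in os.walk(data_root)
        for name in files + dirs
    )
    if os.path.exists(save_path) and newest <= os.stat(save_path).st_mtime:
        return load_from_disk(save_path)  # cache is still up to date
    shutil.rmtree(save_path, ignore_errors=True)  # stale or missing: caller rebuilds
    return None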
--- muse_maskgit_pytorch/dataset.py | 33 +++++++++++++++++++++++++++++---- train_muse_maskgit.py | 1 + 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/muse_maskgit_pytorch/dataset.py b/muse_maskgit_pytorch/dataset.py index 3863fc5..53e2278 100644 --- a/muse_maskgit_pytorch/dataset.py +++ b/muse_maskgit_pytorch/dataset.py @@ -1,5 +1,5 @@ import os -import random +import random, shutil import sys import time from pathlib import Path @@ -7,7 +7,7 @@ import datasets import torch -from datasets import Image +from datasets import Image, load_from_disk from PIL import Image as pImage from PIL import ImageFile from torch.utils.data import DataLoader, Dataset, random_split @@ -287,11 +287,35 @@ def save_dataset_with_progress(dataset, save_path): time.sleep(1) -def get_dataset_from_dataroot(data_root, image_column="image", caption_column="caption", save_path="dataset", save=True): +def get_dataset_from_dataroot( + data_root, image_column="image", caption_column="caption", save_path="dataset", save=True, + ): # Check if data_root is a symlink and resolve it to its target location if it is if os.path.islink(data_root): data_root = os.path.realpath(data_root) + if os.path.exists(save_path): + # Get the modified time of save_path + save_path_mtime = os.stat(save_path).st_mtime + + if save: + # Traverse the directory tree of data_root and get the modified time of all files and subdirectories + print("Checking modified date of all the files and subdirectories in the dataset folder.") + data_root_mtime = max( + os.stat(os.path.join(root, f)).st_mtime + for root, dirs, files in os.walk(data_root) + for f in files + dirs + ) + + # Check if data_root is newer than save_path + if data_root_mtime > save_path_mtime: + print("The data_root folder has being updated recently. Removing previously saved dataset and updating it.") + shutil.rmtree(save_path, ignore_errors=True) + else: + print("The dataset is up-to-date. Loading...") + # Load the dataset from save_path if it is up-to-date + return load_from_disk(save_path) + extensions = ["jpg", "jpeg", "png", "webp"] image_paths = [] @@ -315,9 +339,10 @@ def get_dataset_from_dataroot(data_root, image_column="image", caption_column="c data_dict[caption_column].append(captions) dataset = datasets.Dataset.from_dict(data_dict) dataset = dataset.cast_column(image_column, Image()) - # dataset.save_to_disk(save_path) + if save: save_dataset_with_progress(dataset, save_path) + return dataset diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py index df5c440..0ad7035 100644 --- a/train_muse_maskgit.py +++ b/train_muse_maskgit.py @@ -407,6 +407,7 @@ class Arguments: logging_dir: str = "results/logs" vae_path: Optional[str] = None dataset_name: Optional[str] = None + hf_split_name: Optional[str] = None streaming: bool = False train_data_dir: Optional[str] = None num_train_steps: int = -1 From f624b62999bcb666f388d360a39d936229980bbb Mon Sep 17 00:00:00 2001 From: Alejandro Gil Date: Sun, 11 Jun 2023 00:34:45 -0700 Subject: [PATCH 18/62] Added random_crop argument for the maskgit training. 
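The dataset change boils down to picking one crop transform over the other while building the pipeline. A rough sketch (the Resize at the head of the list is an assumption, since the hunk below elides the start of transform_list):

import torchvision.transforms as T

def build_transforms(image_size, flip=True, center_crop=True, random_crop=False):
    transform_list = [T.Resize(image_size)]
    if flip:
        transform_list.append(T.RandomHorizontalFlip())
    if center_crop and not random_crop:
        transform_list.append(T.CenterCrop(image_size))
    if random_crop:
        # pad_if_needed keeps RandomCrop from failing on images smaller than image_size
        transform_list.append(T.RandomCrop(image_size, pad_if_needed=True))
    transform_list.append(T.ToTensor())
    return T.Compose(transform_list)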
---
muse_maskgit_pytorch/dataset.py | 14 +++++++++-----
train_muse_maskgit.py | 8 +++++++-
2 files changed, 16 insertions(+), 6 deletions(-)

diff --git a/muse_maskgit_pytorch/dataset.py b/muse_maskgit_pytorch/dataset.py
index 53e2278..21409eb 100644
--- a/muse_maskgit_pytorch/dataset.py
+++ b/muse_maskgit_pytorch/dataset.py
@@ -85,16 +85,18 @@ def __init__(
flip=True,
center_crop=True,
stream=False,
- using_taming=False
+ using_taming=False,
+ random_crop=False,
):
super().__init__(
dataset,
image_size=image_size,
image_column=image_column,
flip=flip,
- center_crop=center_crop,
stream=stream,
+ center_crop=center_crop,
+ using_taming=using_taming,
+ random_crop=random_crop,
)
self.caption_column: str = caption_column
self.tokenizer: T5Tokenizer = tokenizer
@@ -192,7 +194,7 @@ def __getitem__(self, index):

class LocalTextImageDataset(Dataset):
- def __init__(self, path, image_size, tokenizer, flip=True, center_crop=True, using_taming=False):
+ def __init__(self, path, image_size, tokenizer, flip=True, center_crop=True, using_taming=False, random_crop=False):
super().__init__()
self.tokenizer = tokenizer
self.using_taming = using_taming
@@ -226,8 +228,10 @@ def __init__(self, path, image_size, tokenizer, flip=True, center_crop=True, usi
]
if flip:
transform_list.append(T.RandomHorizontalFlip())
- if center_crop:
+ if center_crop and not random_crop:
transform_list.append(T.CenterCrop(image_size))
+ if random_crop:
+ transform_list.append(T.RandomCrop(image_size, pad_if_needed=True))
transform_list.append(T.ToTensor())
self.transform = T.Compose(transform_list)

diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py
index 0ad7035..1b88de2 100644
--- a/train_muse_maskgit.py
+++ b/train_muse_maskgit.py
@@ -69,6 +69,11 @@
action="store_true",
help="Don't do center crop.",
)
+parser.add_argument(
+ "--random_crop",
+ action="store_true",
+ help="Crop the images at random locations instead of cropping from the center.",
+ )
parser.add_argument(
"--no_flip",
action="store_true",
@@ -698,7 +703,8 @@ def main():
tokenizer=transformer.tokenizer,
center_crop=False if args.no_center_crop else True,
flip=False if args.no_flip else True,
- using_taming=False if not args.taming_model_path else True
+ using_taming=False if not args.taming_model_path else True,
+ random_crop=args.random_crop if args.random_crop else False,
)
elif args.link:
if not args.dataset_name:

From beae6189fd9d801dcf7077f59b416727bb5701e9 Mon Sep 17 00:00:00 2001
From: Korakoe <56580073+korakoe@users.noreply.github.com>
Date: Sun, 11 Jun 2023 16:24:03 +0800
Subject: [PATCH 19/62] for whatever reason my commits aren't going through, so...
--- .../trainers/vqvae_trainers.py | 133 +++++++++--------- 1 file changed, 69 insertions(+), 64 deletions(-) diff --git a/muse_maskgit_pytorch/trainers/vqvae_trainers.py b/muse_maskgit_pytorch/trainers/vqvae_trainers.py index cb8e13f..661c345 100644 --- a/muse_maskgit_pytorch/trainers/vqvae_trainers.py +++ b/muse_maskgit_pytorch/trainers/vqvae_trainers.py @@ -42,38 +42,38 @@ def exists(val): class VQGanVAETrainer(BaseAcceleratedTrainer): def __init__( - self, - vae: VQGanVAE, - dataloader: DataLoader, - valid_dataloader: DataLoader, - accelerator: Accelerator, - *, - current_step, - num_train_steps, - num_epochs: int = 5, - gradient_accumulation_steps=1, - max_grad_norm=None, - save_results_every=100, - save_model_every=1000, - results_dir="./results", - logging_dir="./results/logs", - apply_grad_penalty_every=4, - lr=3e-4, - lr_scheduler_type="constant", - lr_warmup_steps=500, - discr_max_grad_norm=None, - use_ema=True, - ema_beta=0.995, - ema_update_after_step=0, - ema_update_every=1, - clear_previous_experiments=False, - validation_image_scale: float = 1.0, - only_save_last_checkpoint=False, - optimizer="Adam", - weight_decay=0.0, - use_8bit_adam=False, - num_cycles=1, - scheduler_power=1.0 + self, + vae: VQGanVAE, + dataloader: DataLoader, + valid_dataloader: DataLoader, + accelerator: Accelerator, + *, + current_step, + num_train_steps, + num_epochs: int = 5, + gradient_accumulation_steps=1, + max_grad_norm=None, + save_results_every=100, + save_model_every=1000, + results_dir="./results", + logging_dir="./results/logs", + apply_grad_penalty_every=4, + lr=3e-4, + lr_scheduler_type="constant", + lr_warmup_steps=500, + discr_max_grad_norm=None, + use_ema=True, + ema_beta=0.995, + ema_update_after_step=0, + ema_update_every=1, + clear_previous_experiments=False, + validation_image_scale: float = 1.0, + only_save_last_checkpoint=False, + optimizer="Adam", + weight_decay=0.0, + use_8bit_adam=False, + num_cycles=1, + scheduler_power=1.0 ): super().__init__( dataloader, @@ -104,12 +104,17 @@ def __init__( # optimizers self.optim = get_optimizer(use_8bit_adam, optimizer, vae_parameters, lr, weight_decay) self.discr_optim = get_optimizer(use_8bit_adam, optimizer, discr_parameters, lr, weight_decay) - + + if self.num_train_steps > 0: + self.num_lr_steps = self.num_train_steps * self.gradient_accumulation_steps + else: + self.num_lr_steps = self.num_epochs * len(self.dl) + self.lr_scheduler: LRScheduler = get_scheduler( lr_scheduler_type, optimizer=self.optim, num_warmup_steps=lr_warmup_steps * self.gradient_accumulation_steps, - num_training_steps=self.num_train_steps * self.gradient_accumulation_steps, + num_training_steps=self.num_lr_steps, num_cycles=num_cycles, power=scheduler_power, ) @@ -118,7 +123,7 @@ def __init__( lr_scheduler_type, optimizer=self.discr_optim, num_warmup_steps=lr_warmup_steps * self.gradient_accumulation_steps, - num_training_steps=self.num_train_steps * self.gradient_accumulation_steps, + num_training_steps=self.num_lr_steps, num_cycles=num_cycles, power=scheduler_power, ) @@ -171,30 +176,34 @@ def save(self, path): ) self.accelerator.save(pkg, path) - def log_validation_images(self, models_to_evaluate, logs, steps): + def log_validation_images(self, logs, steps): log_imgs = [] - prompts = ["vae"] if len(models_to_evaluate) == 1 else ["vae", "ema"] - for model, filename in models_to_evaluate: - model.eval() + self.model.eval() + try: + valid_data = next(self.valid_dl_iter) + except StopIteration: + self.valid_dl_iter = iter(self.valid_dl) valid_data = 
next(self.valid_dl_iter) - valid_data = valid_data.to(self.device) - recons = model(valid_data, return_recons=True) + valid_data = valid_data.to(self.device) - # else save a grid of images + recons = self.model(valid_data, return_recons=True) - imgs_and_recons = torch.stack((valid_data, recons), dim=0) - imgs_and_recons = rearrange(imgs_and_recons, "r b ... -> (b r) ...") + # else save a grid of images - imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0.0, 1.0) - grid = make_grid(imgs_and_recons, nrow=2, normalize=True, value_range=(0, 1)) + imgs_and_recons = torch.stack((valid_data, recons), dim=0) + imgs_and_recons = rearrange(imgs_and_recons, "r b ... -> (b r) ...") - logs["reconstructions"] = grid - save_file = str(self.results_dir / f"{filename}.png") - save_image(grid, save_file) - log_imgs.append(Image.open(save_file)) - super().log_validation_images(log_imgs, steps, prompts=prompts) + imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0.0, 1.0) + grid = make_grid(imgs_and_recons, nrow=2, normalize=True, value_range=(0, 1)) + + logs["reconstructions"] = grid + save_file = str(self.results_dir / f"{steps}.png") + save_image(grid, save_file) + log_imgs.append(Image.open(save_file)) + super().log_validation_images(log_imgs, steps, prompts=["vae"]) + self.model.train() def train(self): self.steps = self.steps + 1 @@ -205,7 +214,7 @@ def train(self): proc_label = f"[P{self.accelerator.process_index:03d}][Master]" else: proc_label = f"[P{self.accelerator.process_index:03d}][Worker]" - + for epoch in range(self.num_epochs): for img in self.dl: loss = 0.0 @@ -234,7 +243,6 @@ def train(self): accum_log(logs, {"Train/vae_loss": loss.item() / self.gradient_accumulation_steps}) - self.lr_scheduler.step() self.lr_scheduler_discr.step() self.optim.step() @@ -279,13 +287,9 @@ def train(self): # sample results every so often if (steps % self.save_results_every) == 0: - vaes_to_evaluate = ((self.model, str(steps)),) - - if self.use_ema: - vaes_to_evaluate = ((ema_model.ema_model, f"{steps}.ema"),) + vaes_to_evaluate - - self.log_validation_images(vaes_to_evaluate, logs, steps) - self.accelerator.print(f"[E{epoch + 1}][S{steps:05d}]{proc_label}: saving to {str(self.results_dir)}") + self.log_validation_images(logs, steps) + self.accelerator.print( + f"[E{epoch + 1}][S{steps:05d}]{proc_label}: saving to {str(self.results_dir)}") # save model every so often self.accelerator.wait_for_everyone() @@ -301,7 +305,8 @@ def train(self): model_path = str(self.results_dir / file_name) self.accelerator.save(ema_state_dict, model_path) - self.accelerator.print(f"[E{epoch + 1}][S{steps:05d}]{proc_label}: saving model to {str(self.results_dir)}") + self.accelerator.print( + f"[E{epoch + 1}][S{steps:05d}]{proc_label}: saving model to {str(self.results_dir)}") self.steps += 1 @@ -309,7 +314,7 @@ def train(self): self.accelerator.print(f"[E{epoch + 1}][S{steps:05d}]{proc_label}: " f"[STOP EARLY]: Stopping training early...") break - + # Loop finished, save model self.accelerator.wait_for_everyone() if self.is_main_process: @@ -324,5 +329,5 @@ def train(self): model_path = str(self.results_dir / file_name) self.accelerator.save(ema_state_dict, model_path) - self.accelerator.print(f"[E{self.num_epochs}][S{steps:05d}]{proc_label}: saving model to {str(self.results_dir)}") - + self.accelerator.print( + f"[E{self.num_epochs}][S{steps:05d}]{proc_label}: saving model to {str(self.results_dir)}") From 0d7ee042a9296a297028d3b62183175140a7a2b5 Mon Sep 17 00:00:00 2001 From: Korakoe 
<56580073+korakoe@users.noreply.github.com> Date: Sun, 11 Jun 2023 16:28:43 +0800 Subject: [PATCH 20/62] fix dataset loading twice --- train_muse_maskgit.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py index 1b88de2..08b1335 100644 --- a/train_muse_maskgit.py +++ b/train_muse_maskgit.py @@ -503,13 +503,16 @@ def main(): # Load the dataset (main process first to download, rest will load from cache) with accelerator.main_process_first(): if args.train_data_dir is not None: - dataset = get_dataset_from_dataroot( - args.train_data_dir, - image_column=args.image_column, - caption_column=args.caption_column, - save_path=args.dataset_save_path, - save=not args.no_cache, - ) + if args.no_cache: + pass + else: + dataset = get_dataset_from_dataroot( + args.train_data_dir, + image_column=args.image_column, + caption_column=args.caption_column, + save_path=args.dataset_save_path, + save=, + ) elif args.dataset_name is not None: dataset = load_dataset( args.dataset_name, From 29905965c9ab2463f885ee0020b82a4bc601dc89 Mon Sep 17 00:00:00 2001 From: Korakoe <56580073+korakoe@users.noreply.github.com> Date: Sun, 11 Jun 2023 16:28:43 +0800 Subject: [PATCH 21/62] fix dataset loading twice --- train_muse_maskgit.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py index 1b88de2..08b1335 100644 --- a/train_muse_maskgit.py +++ b/train_muse_maskgit.py @@ -503,13 +503,16 @@ def main(): # Load the dataset (main process first to download, rest will load from cache) with accelerator.main_process_first(): if args.train_data_dir is not None: - dataset = get_dataset_from_dataroot( - args.train_data_dir, - image_column=args.image_column, - caption_column=args.caption_column, - save_path=args.dataset_save_path, - save=not args.no_cache, - ) + if args.no_cache: + pass + else: + dataset = get_dataset_from_dataroot( + args.train_data_dir, + image_column=args.image_column, + caption_column=args.caption_column, + save_path=args.dataset_save_path, + save=, + ) elif args.dataset_name is not None: dataset = load_dataset( args.dataset_name, From fd0a12debecf7bb43b992eafd6303c73d6e567fc Mon Sep 17 00:00:00 2001 From: Alejandro Gil Date: Sun, 11 Jun 2023 01:47:52 -0700 Subject: [PATCH 22/62] Fixed unnecessary clutter on the console caused by the transformers library internal logging. --- train_muse_maskgit.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py index a7e46e8..96fb1a2 100644 --- a/train_muse_maskgit.py +++ b/train_muse_maskgit.py @@ -51,6 +51,9 @@ traceback_install(show_locals=True, width=120, word_wrap=True) +# remove some unnecessary errors from transformer shown on the console. 
+transformers.logging.set_verbosity_error() + # Create the parser parser = argparse.ArgumentParser() parser.add_argument( @@ -497,7 +500,9 @@ def main(): if accelerator.is_main_process: accelerator.print(f"Preparing MaskGit for training on {accelerator.device.type}") - inspect(args, docs=False) + if args.debug: + inspect(args, docs=False) + accelerate.utils.set_seed(args.seed) # Load the dataset (main process first to download, rest will load from cache) @@ -511,7 +516,6 @@ def main(): image_column=args.image_column, caption_column=args.caption_column, save_path=args.dataset_save_path, - save=, ) elif args.dataset_name is not None: dataset = load_dataset( From b42fbdd796762c3083583da7eb2a8437e72d96db Mon Sep 17 00:00:00 2001 From: Alejandro Gil Date: Sun, 11 Jun 2023 01:50:04 -0700 Subject: [PATCH 23/62] Changed the prompt separator for when using multiple prompts for validation to be the pipe `|` character as its easier to distinguish it from the rest of the text on the prompt. --- muse_maskgit_pytorch/trainers/maskgit_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/muse_maskgit_pytorch/trainers/maskgit_trainer.py b/muse_maskgit_pytorch/trainers/maskgit_trainer.py index 2736f21..cfa4dfb 100644 --- a/muse_maskgit_pytorch/trainers/maskgit_trainer.py +++ b/muse_maskgit_pytorch/trainers/maskgit_trainer.py @@ -121,7 +121,7 @@ def save_validation_images( if self.accelerator.is_main_process: save_image(images, save_file, "png") - self.log_validation_images([Image.open(save_file)], step, ["\n---\n".join(validation_prompts)]) + self.log_validation_images([Image.open(save_file)], step, ["|".join(validation_prompts)]) return save_file def train(self): From cca687d4ec18a93e03cfa2eb88cb430edea95f7d Mon Sep 17 00:00:00 2001 From: Alejandro Gil Date: Sun, 11 Jun 2023 01:55:56 -0700 Subject: [PATCH 24/62] Improved validation image logging. --- .../trainers/base_accelerated_trainer.py | 55 +++++++++++-------- 1 file changed, 33 insertions(+), 22 deletions(-) diff --git a/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py b/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py index cf6aefe..966b8a3 100644 --- a/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py +++ b/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py @@ -265,29 +265,40 @@ def load(self, path: Union[str, PathLike]): def log_validation_images(self, images, step, prompts=None): if prompts: - self.print(f"Logging with prompts: {prompts}") + self.print(f"\nStep: {step} | Logging with prompts: {prompts}") if self.validation_image_scale != 1: - # Feel free to make pr for better solution! 
- output_size = ( - int(images[0].size[0] * self.validation_image_scale), - int(images[0].size[1] * self.validation_image_scale), - ) - for i in range(len(images)): - images[i] = images[i].resize(output_size) - if self.accelerator.is_main_process: - for tracker in self.accelerator.trackers: - if tracker.name == "tensorboard": - np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("validation", np_images, step, dataformats="NHWC") - elif tracker.name == "wandb": - tracker.log( - { - "validation": [ - wandb.Image(image, caption="" if not prompts else prompts[i]) - for i, image in enumerate(images) - ] - } - ) + # Calculate the new height based on the scale factor + new_height = int(images[0].shape[0] * self.validation_image_scale) + + # Calculate the aspect ratio of the original image + aspect_ratio = images[0].shape[1] / images[0].shape[0] + + # Calculate the new width based on the new height and aspect ratio + new_width = int(new_height * aspect_ratio) + + # Resize the images using the new width and height + output_size = (new_width, new_height) + images_pil = [Image.fromarray(image) for image in images] + images_pil_resized = [image_pil.resize(output_size) for image_pil in images_pil] + images = [np.array(image_pil) for image_pil in images_pil_resized] + + for tracker in self.accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images( + "validation", np_images, step, dataformats="NHWC" + ) + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image( + image, caption="" if not prompts else prompts[i] + ) + for i, image in enumerate(images) + ] + } + ) @property def device(self): From b14df9c0620e35a6d082061785148b1810e86feb Mon Sep 17 00:00:00 2001 From: Alejandro Gil Date: Sun, 11 Jun 2023 02:43:08 -0700 Subject: [PATCH 25/62] Fixed latest checkpoint not working. --- train_muse_maskgit.py | 33 +++++++++++++++++++++++---------- train_muse_vae.py | 11 +++++++---- 2 files changed, 30 insertions(+), 14 deletions(-) diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py index 96fb1a2..af29d5c 100644 --- a/train_muse_maskgit.py +++ b/train_muse_maskgit.py @@ -8,7 +8,6 @@ import diffusers from rich import inspect -import torch # noqa: F401 import transformers from datasets import load_dataset from diffusers.optimization import SchedulerType, get_scheduler @@ -54,6 +53,9 @@ # remove some unnecessary errors from transformer shown on the console. transformers.logging.set_verbosity_error() +# disable bitsandbytes welcome message. 
+os.environ['BITSANDBYTES_NOWELCOME'] = 1 + # Create the parser parser = argparse.ArgumentParser() parser.add_argument( @@ -359,9 +361,20 @@ help="whether to load a dataset with links instead of image (image column becomes URL column)", ) parser.add_argument( - "--latest_checkpoint", - action="store_true", - help="Automatically find and use the latest checkpoint in the folder.", + "--latest_checkpoint", + action="store_true", + help="Automatically find and use the latest checkpoint in the folder.", +) +parser.add_argument( + "--do_not_save_config", + action="store_true", + default=False, + help="Generate example YAML configuration file", +) +parser.add_argument( + "--use_l2_recon_loss", + action="store_true", + help="Use F.mse_loss instead of F.l1_loss.", ) parser.add_argument( "--debug", @@ -443,11 +456,12 @@ class Arguments: no_cache: bool = False link: bool = False latest_checkpoint: bool = False + do_not_save_config: bool = False + use_l2_recon_loss: bool = False debug: bool = False config_path: Optional[str] = None generate_config: bool = False - def main(): args = parser.parse_args(namespace=Arguments()) @@ -555,8 +569,7 @@ def main(): checkpoint_files = glob.glob(os.path.join(args.vae_path, "vae.*.pt")) if checkpoint_files: - latest_checkpoint_file = max(checkpoint_files, - key=lambda x: int(re.search(r'vae\.(\d+)\.pt', x).group(1))) + latest_checkpoint_file = max(checkpoint_files,key=lambda x: int(re.search(r'vae\.(\d+)\.pt$', x).group(1)) if not x.endswith('ema.pt') else -1) # Check if latest checkpoint is empty or unreadable if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(latest_checkpoint_file, os.R_OK): @@ -564,8 +577,7 @@ def main(): f"Warning: latest checkpoint {latest_checkpoint_file} is empty or unreadable.") if len(checkpoint_files) > 1: # Use the second last checkpoint as a fallback - latest_checkpoint_file = max(checkpoint_files[:-1], - key=lambda x: int(re.search(r'vae\.(\d+)\.pt', x).group(1))) + latest_checkpoint_file = max(checkpoint_files[:-1], key=lambda x: int(re.search(r'vae\.(\d+)\.pt$', x).group(1)) if not x.endswith('ema.pt') else -1) accelerator.print("Using second last checkpoint: ", latest_checkpoint_file) else: accelerator.print("No usable checkpoint found.") @@ -822,7 +834,8 @@ def main(): clear_previous_experiments=args.clear_previous_experiments, validation_image_scale=args.validation_image_scale, only_save_last_checkpoint=args.only_save_last_checkpoint, - num_epochs=args.num_epochs + num_epochs=args.num_epochs, + args=args, ) # Prepare the trainer for distributed training diff --git a/train_muse_vae.py b/train_muse_vae.py index 928a366..d14a8bb 100644 --- a/train_muse_vae.py +++ b/train_muse_vae.py @@ -22,6 +22,9 @@ from omegaconf import OmegaConf, ValidationError +# disable bitsandbytes welcome message. 
+os.environ['BITSANDBYTES_NOWELCOME'] = 1 + parser = argparse.ArgumentParser() parser.add_argument("--webdataset", type=str, default=None, help="Path to webdataset if using one.") parser.add_argument( @@ -326,6 +329,8 @@ class Arguments: cache_path: Optional[str] = None no_cache: bool = False latest_checkpoint: bool = False + do_not_save_config: bool = False + use_l2_recon_loss: bool = False debug: bool = False config_path: Optional[str] = None generate_config: bool = False @@ -422,8 +427,7 @@ def main(): checkpoint_files = glob.glob(os.path.join(args.resume_path, "vae.*.pt")) if checkpoint_files: - latest_checkpoint_file = max(checkpoint_files, - key=lambda x: int(re.search(r'vae\.(\d+)\.pt', x).group(1))) + latest_checkpoint_file = max(checkpoint_files, key=lambda x: int(re.search(r'vae\.(\d+)\.pt', x).group(1))) # Check if latest checkpoint is empty or unreadable if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(latest_checkpoint_file, os.R_OK): @@ -431,8 +435,7 @@ def main(): f"Warning: latest checkpoint {latest_checkpoint_file} is empty or unreadable.") if len(checkpoint_files) > 1: # Use the second last checkpoint as a fallback - latest_checkpoint_file = max(checkpoint_files[:-1], - key=lambda x: int(re.search(r'vae\.(\d+)\.pt', x).group(1))) + latest_checkpoint_file = max(checkpoint_files[:-1], key=lambda x: int(re.search(r'vae\.(\d+)\.pt', x).group(1))) accelerator.print("Using second last checkpoint: ", latest_checkpoint_file) else: accelerator.print("No usable checkpoint found.") From a976fbe1ede94bc1da646e179c2e6698e5a2eb01 Mon Sep 17 00:00:00 2001 From: Alejandro Gil Date: Sun, 11 Jun 2023 02:44:28 -0700 Subject: [PATCH 26/62] Added option to save a config file next to the model checkpoint so its easier to later know what parameters were used for that specific checkpoint. --- .../trainers/maskgit_trainer.py | 33 ++++++++++++- .../trainers/vqvae_trainers.py | 47 ++++++++++++------- 2 files changed, 62 insertions(+), 18 deletions(-) diff --git a/muse_maskgit_pytorch/trainers/maskgit_trainer.py b/muse_maskgit_pytorch/trainers/maskgit_trainer.py index cfa4dfb..8e9e63e 100644 --- a/muse_maskgit_pytorch/trainers/maskgit_trainer.py +++ b/muse_maskgit_pytorch/trainers/maskgit_trainer.py @@ -14,6 +14,8 @@ from muse_maskgit_pytorch.t5 import t5_encode_text_from_encoded from muse_maskgit_pytorch.trainers.base_accelerated_trainer import BaseAcceleratedTrainer +from omegaconf import OmegaConf + try: import torch_xla import torch_xla.core.xla_model as xm @@ -55,6 +57,7 @@ def __init__( clear_previous_experiments=False, validation_image_scale: float = 1.0, only_save_last_checkpoint=False, + args=None, ): super().__init__( dataloader=dataloader, @@ -77,6 +80,11 @@ def __init__( self.save_results_every = save_results_every self.log_metrics_every = log_metrics_every self.batch_size = batch_size + + # arguments used for the training script, + # we are going to use them later to save them to a config file. 
+ self.args = args + # maskgit maskgit.vae.requires_grad_(False) maskgit.transformer.t5.requires_grad_(False) @@ -113,7 +121,7 @@ def save_validation_images( cond_images=cond_image, cond_scale=cond_scale, temperature=temperature, - ).to("cpu") + ).to(self.accelerator.device) save_dir = self.results_dir.joinpath("MaskGit") save_dir.mkdir(exist_ok=True, parents=True) @@ -121,6 +129,7 @@ def save_validation_images( if self.accelerator.is_main_process: save_image(images, save_file, "png") + self.accelerator.print(f"\nStep: {step} | Logging with prompts: {[' | '.join(validation_prompts)]}") self.log_validation_images([Image.open(save_file)], step, ["|".join(validation_prompts)]) return save_file @@ -176,7 +185,7 @@ def train(self): self.accelerator.print(f"[E{epoch + 1}][S{steps:05d}]{proc_label}: " f"saving model to {self.results_dir}") else: - self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}: " + self.accelerator.print(f"[E{epoch + 1}]{proc_label}: " f"saving model to {self.results_dir}") state_dict = self.accelerator.unwrap_model(self.model).state_dict() @@ -191,6 +200,11 @@ def train(self): self.accelerator.wait_for_everyone() self.accelerator.save(state_dict, model_path) + if self.args and not self.args.do_not_save_config: + # save config file next to the model file. + conf = OmegaConf.create(vars(self.args)) + OmegaConf.save(conf, f"{model_path}.yaml") + if self.use_ema: if self.on_tpu: self.accelerator.print( @@ -210,6 +224,11 @@ def train(self): self.accelerator.wait_for_everyone() self.accelerator.save(ema_state_dict, model_path) + if self.args and not self.args.do_not_save_config: + # save config file next to the model file. + conf = OmegaConf.create(vars(self.args)) + OmegaConf.save(conf, f"{model_path}.yaml") + if not (steps % self.save_results_every): cond_image = None if self.model.cond_image_size: @@ -263,6 +282,11 @@ def train(self): self.accelerator.wait_for_everyone() self.accelerator.save(state_dict, model_path) + if self.args and not self.args.do_not_save_config: + # save config file next to the model file. + conf = OmegaConf.create(vars(self.args)) + OmegaConf.save(conf, f"{model_path}.yaml") + if self.use_ema: self.accelerator.print( f"[S{steps:05d}]{proc_label}[FINAL]: saving EMA model to {self.results_dir}" @@ -277,6 +301,11 @@ def train(self): self.accelerator.wait_for_everyone() self.accelerator.save(ema_state_dict, model_path) + if self.args and not self.args.do_not_save_config: + # save config file next to the model file. 
+ conf = OmegaConf.create(vars(self.args)) + OmegaConf.save(conf, f"{model_path}.yaml") + cond_image = None if self.model.cond_image_size: self.accelerator.print( diff --git a/muse_maskgit_pytorch/trainers/vqvae_trainers.py b/muse_maskgit_pytorch/trainers/vqvae_trainers.py index a76d660..afaac22 100644 --- a/muse_maskgit_pytorch/trainers/vqvae_trainers.py +++ b/muse_maskgit_pytorch/trainers/vqvae_trainers.py @@ -1,22 +1,13 @@ -from datetime import datetime -from pathlib import Path -from shutil import rmtree - -import numpy as np import torch -from accelerate import Accelerator, DistributedDataParallelKwargs, DistributedType -from beartype import beartype +from accelerate import Accelerator from diffusers.optimization import get_scheduler from einops import rearrange from ema_pytorch import EMA -from lion_pytorch import Lion from PIL import Image -from torch import nn -from torch.optim import Adam, AdamW from torch.optim.lr_scheduler import LRScheduler -from torch.utils.data import DataLoader, random_split -from torch.utils.tensorboard import SummaryWriter +from torch.utils.data import DataLoader from torchvision.utils import make_grid, save_image +from omegaconf import OmegaConf from muse_maskgit_pytorch.trainers.base_accelerated_trainer import ( BaseAcceleratedTrainer, @@ -24,7 +15,6 @@ ) from muse_maskgit_pytorch.vqgan_vae import VQGanVAE - def noop(*args, **kwargs): pass @@ -73,7 +63,8 @@ def __init__( weight_decay=0.0, use_8bit_adam=False, num_cycles=1, - scheduler_power=1.0 + scheduler_power=1.0, + args=None, ): super().__init__( dataloader, @@ -94,6 +85,10 @@ def __init__( only_save_last_checkpoint=only_save_last_checkpoint, ) + # arguments used for the training script, + # we are going to use them later to save them to a config file. + self.args = args + # vae self.model = vae @@ -104,12 +99,12 @@ def __init__( # optimizers self.optim = get_optimizer(use_8bit_adam, optimizer, vae_parameters, lr, weight_decay) self.discr_optim = get_optimizer(use_8bit_adam, optimizer, discr_parameters, lr, weight_decay) - + if self.num_train_steps > 0: self.num_lr_steps = self.num_train_steps * self.gradient_accumulation_steps else: self.num_lr_steps = self.num_epochs * len(self.dl) - + self.lr_scheduler: LRScheduler = get_scheduler( lr_scheduler_type, optimizer=self.optim, @@ -298,12 +293,22 @@ def train(self): model_path = str(self.results_dir / file_name) self.accelerator.save(state_dict, model_path) + if self.args and not self.args.do_not_save_config: + # save config file next to the model file. + conf = OmegaConf.create(vars(self.args)) + OmegaConf.save(conf, f"{model_path}.yaml") + if self.use_ema: ema_state_dict = self.accelerator.unwrap_model(self.ema_model).state_dict() file_name = f"vae.{steps}.ema.pt" if not self.only_save_last_checkpoint else "vae.ema.pt" model_path = str(self.results_dir / file_name) self.accelerator.save(ema_state_dict, model_path) + if self.args and not self.args.do_not_save_config: + # save config file next to the model file. + conf = OmegaConf.create(vars(self.args)) + OmegaConf.save(conf, f"{model_path}.yaml") + self.accelerator.print( f"[E{epoch + 1}][S{steps:05d}]{proc_label}: saving model to {str(self.results_dir)}") @@ -322,11 +327,21 @@ def train(self): model_path = str(self.results_dir / file_name) self.accelerator.save(state_dict, model_path) + if self.args and not self.args.do_not_save_config: + # save config file next to the model file. 
+                conf = OmegaConf.create(vars(self.args))
+                OmegaConf.save(conf, f"{model_path}.yaml")
+
             if self.use_ema:
                 ema_state_dict = self.accelerator.unwrap_model(self.ema_model).state_dict()
                 file_name = f"vae.{steps}.ema.pt" if not self.only_save_last_checkpoint else "vae.ema.pt"
                 model_path = str(self.results_dir / file_name)
                 self.accelerator.save(ema_state_dict, model_path)
 
+                if self.args and not self.args.do_not_save_config:
+                    # save config file next to the model file.
+                    conf = OmegaConf.create(vars(self.args))
+                    OmegaConf.save(conf, f"{model_path}.yaml")
+
             self.accelerator.print(
                 f"[E{self.num_epochs}][S{steps:05d}]{proc_label}: saving model to {str(self.results_dir)}")

From 99d77b41c5d76a5f6f8f027fe198e9a71c2c47b4 Mon Sep 17 00:00:00 2001
From: Alejandro Gil
Date: Sun, 11 Jun 2023 02:45:27 -0700
Subject: [PATCH 27/62] Move logging prompt message to the maskgit_trainer.py
 so it shows before the progress bar during image validation and for clarity.

---
 muse_maskgit_pytorch/trainers/base_accelerated_trainer.py | 2 --
 train_muse_maskgit.py | 2 +-
 train_muse_vae.py | 4 ++--
 3 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py b/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py
index 966b8a3..906d88a 100644
--- a/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py
+++ b/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py
@@ -264,8 +264,6 @@ def load(self, path: Union[str, PathLike]):
         return pkg
 
     def log_validation_images(self, images, step, prompts=None):
-        if prompts:
-            self.print(f"\nStep: {step} | Logging with prompts: {prompts}")
 
         if self.validation_image_scale != 1:
             # Calculate the new height based on the scale factor
diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py
index af29d5c..63cb884 100644
--- a/train_muse_maskgit.py
+++ b/train_muse_maskgit.py
@@ -54,7 +54,7 @@
 transformers.logging.set_verbosity_error()
 
 # disable bitsandbytes welcome message.
-os.environ['BITSANDBYTES_NOWELCOME'] = 1
+os.environ['BITSANDBYTES_NOWELCOME'] = '1'
 
 # Create the parser
 parser = argparse.ArgumentParser()
diff --git a/train_muse_vae.py b/train_muse_vae.py
index d14a8bb..8b8c358 100644
--- a/train_muse_vae.py
+++ b/train_muse_vae.py
@@ -20,10 +20,10 @@
 import glob
 import re
 
-from omegaconf import OmegaConf, ValidationError
+from omegaconf import OmegaConf
 
 # disable bitsandbytes welcome message.
-os.environ['BITSANDBYTES_NOWELCOME'] = 1
+os.environ['BITSANDBYTES_NOWELCOME'] = '1'
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--webdataset", type=str, default=None, help="Path to webdataset if using one.")

From e15a2a63197011899fde99ad2e2ff4c912dd4832 Mon Sep 17 00:00:00 2001
From: Alejandro Gil
Date: Sun, 11 Jun 2023 03:18:11 -0700
Subject: [PATCH 28/62] Moved message for the validation images so it's shown
 before the progress bar for readability.

---
 muse_maskgit_pytorch/trainers/maskgit_trainer.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/muse_maskgit_pytorch/trainers/maskgit_trainer.py b/muse_maskgit_pytorch/trainers/maskgit_trainer.py
index 8e9e63e..bc9ea96 100644
--- a/muse_maskgit_pytorch/trainers/maskgit_trainer.py
+++ b/muse_maskgit_pytorch/trainers/maskgit_trainer.py
@@ -116,6 +116,10 @@ def __init__(
     def save_validation_images(
         self, validation_prompts, step: int, cond_image=None, cond_scale=3, temperature=1
     ):
+        # moved the print to the top of the function so it shows before the progress bar for readability.
+        if validation_prompts:
+            self.accelerator.print(f"\nStep: {step} | Logging with prompts: {[' | '.join(validation_prompts)]}")
+
         images = self.model.generate(
             validation_prompts,
             cond_images=cond_image,
@@ -129,7 +133,6 @@
 
         if self.accelerator.is_main_process:
             save_image(images, save_file, "png")
-            self.accelerator.print(f"\nStep: {step} | Logging with prompts: {[' | '.join(validation_prompts)]}")
             self.log_validation_images([Image.open(save_file)], step, ["|".join(validation_prompts)])
         return save_file
 

From 626e819385d1c64f528c1155ef0ed12d79e75968 Mon Sep 17 00:00:00 2001
From: ButterCream <56580073+korakoe@users.noreply.github.com>
Date: Sun, 11 Jun 2023 19:31:01 +0800
Subject: [PATCH 29/62] Freeze the VAE before passing to the transformer

---
 train_muse_maskgit.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py
index 63cb884..728d47d 100644
--- a/train_muse_maskgit.py
+++ b/train_muse_maskgit.py
@@ -610,6 +610,10 @@ def main():
         raise ValueError(
             "You must pass either vae_path or taming_model_path + taming_config_path (but not both)"
         )
+
+
+    # freeze VAE before passing to the transformer
+    vae.requires_grad_(False)
 
     # then you plug the vae and transformer into your MaskGit like so:
 

From a6583085fea4446e41175641c3a071832b10c582 Mon Sep 17 00:00:00 2001
From: ButterCream <56580073+korakoe@users.noreply.github.com>
Date: Sun, 11 Jun 2023 19:31:01 +0800
Subject: [PATCH 30/62] Freeze the VAE before passing to the transformer

---
 train_muse_maskgit.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py
index 63cb884..728d47d 100644
--- a/train_muse_maskgit.py
+++ b/train_muse_maskgit.py
@@ -610,6 +610,10 @@ def main():
         raise ValueError(
             "You must pass either vae_path or taming_model_path + taming_config_path (but not both)"
         )
+
+
+    # freeze VAE before passing to the transformer
+    vae.requires_grad_(False)
 
     # then you plug the vae and transformer into your MaskGit like so:
 

From 3669068e883ca6314391791d84e87944fddac511 Mon Sep 17 00:00:00 2001
From: Alejandro Gil
Date: Sun, 11 Jun 2023 06:33:06 -0700
Subject: [PATCH 31/62] Cleaned some prints a bit and added proper handling of
 `cond_image_size` on maskgit_trainer.py

---
 .../trainers/maskgit_trainer.py | 27 ++++++++++---------
 1 file changed, 14 insertions(+), 13 deletions(-)

diff --git a/muse_maskgit_pytorch/trainers/maskgit_trainer.py b/muse_maskgit_pytorch/trainers/maskgit_trainer.py
index bc9ea96..b058a32 100644
--- a/muse_maskgit_pytorch/trainers/maskgit_trainer.py
+++ b/muse_maskgit_pytorch/trainers/maskgit_trainer.py
@@ -129,7 +129,7 @@ def save_validation_images(
 
         save_dir = self.results_dir.joinpath("MaskGit")
         save_dir.mkdir(exist_ok=True, parents=True)
-        save_file = save_dir.joinpath(f"maskgit_S{step:04d}.png")
+        save_file = save_dir.joinpath(f"maskgit_{step:04d}.png")
 
         if self.accelerator.is_main_process:
             save_image(images, save_file, "png")
@@ -174,7 +174,7 @@ def train(self):
                 logs = {"loss": train_loss, "lr": self.lr_scheduler.get_last_lr()[0]}
 
                 if self.on_tpu:
-                    self.accelerator.print(f"[E{epoch + 1}][S{steps:05d}]{proc_label}: "
+                    self.accelerator.print(f"[E{epoch + 1}][{steps:05d}]{proc_label}: "
                                            f"maskgit loss: {logs['loss']} - lr: {logs['lr']}")
                 else:
                     self.training_bar.update()
@@ -185,7 +185,7 @@ def train(self):
 
                 if not (steps % self.save_model_every):
                     if self.on_tpu:
-                        self.accelerator.print(f"[E{epoch + 1}][S{steps:05d}]{proc_label}: "
+                        self.accelerator.print(f"[E{epoch + 1}][{steps:05d}]{proc_label}: "
                                                f"saving 
model to {self.results_dir}") else: self.accelerator.print(f"[E{epoch + 1}]{proc_label}: " @@ -211,7 +211,7 @@ def train(self): if self.use_ema: if self.on_tpu: self.accelerator.print( - f"[E{epoch + 1}][S{steps:05d}]{proc_label}: " + f"[E{epoch + 1}][{steps:05d}]{proc_label}: " f"saving EMA model to {self.results_dir}") else: self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}: " @@ -249,14 +249,14 @@ def train(self): self.validation_prompts, steps, cond_image=cond_image ) if self.on_tpu: - self.accelerator.print(f"[E{epoch + 1}][S{steps:05d}]{proc_label}: saved to {saved_image}") + self.accelerator.print(f"[E{epoch + 1}][{steps:05d}]{proc_label}: saved to {saved_image}") else: self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}: " f"saved to {saved_image}") if met is not None and not (steps % self.log_metrics_every): if self.on_tpu: - self.accelerator.print(f"[E{epoch + 1}][S{steps:05d}]{proc_label}: metrics:") + self.accelerator.print(f"[E{epoch + 1}][{steps:05d}]{proc_label}: metrics:") else: self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}: metrics:") @@ -264,7 +264,7 @@ def train(self): if self.num_train_steps > 0 and self.steps >= int(self.steps.item()): if self.on_tpu: - self.accelerator.print(f"[E{epoch + 1}][S{int(self.steps.item()):05d}]{proc_label}" + self.accelerator.print(f"[E{epoch + 1}][{int(self.steps.item()):05d}]{proc_label}" f"[STOP EARLY]: Stopping training early...") else: self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}" @@ -272,7 +272,7 @@ def train(self): break # loop complete, save final model - self.accelerator.print(f"[E{epoch + 1}][S{steps:05d}]{proc_label}[FINAL]: saving model to {self.results_dir}") + self.accelerator.print(f"[E{epoch + 1}][{steps:05d}]{proc_label}[FINAL]: saving model to {self.results_dir}") state_dict = self.accelerator.unwrap_model(self.model).state_dict() maskgit_save_name = "maskgit_superres" if self.model.cond_image_size else "maskgit" file_name = ( @@ -292,7 +292,7 @@ def train(self): if self.use_ema: self.accelerator.print( - f"[S{steps:05d}]{proc_label}[FINAL]: saving EMA model to {self.results_dir}" + f"[{steps:05d}]{proc_label}[FINAL]: saving EMA model to {self.results_dir}" ) ema_state_dict = self.accelerator.unwrap_model(self.ema_model).state_dict() file_name = ( @@ -314,11 +314,12 @@ def train(self): self.accelerator.print( "With conditional image training, we recommend keeping the validation prompts to empty strings" ) - cond_image = F.interpolate(imgs[0], 256) + cond_image = F.interpolate(imgs, self.model.cond_image_size, mode="nearest") + steps = int(self.steps.item()) + 1 # get the final step count, plus one - self.accelerator.print(f"[S{steps:05d}]{proc_label}: Logging validation images") + self.accelerator.print(f"[{steps:05d}]{proc_label}: Logging validation images") saved_image = self.save_validation_images(self.validation_prompts, steps, cond_image=cond_image) - self.accelerator.print(f"[S{steps:05d}]{proc_label}: saved to {saved_image}") + self.accelerator.print(f"[{steps:05d}]{proc_label}: saved to {saved_image}") if met is not None and not (steps % self.log_metrics_every): - self.accelerator.print(f"[S{steps:05d}]{proc_label}: metrics:") + self.accelerator.print(f"[{steps:05d}]{proc_label}: metrics:") From a9d6c49bf5561e892d479ee109bcc7dd6f8e1f14 Mon Sep 17 00:00:00 2001 From: Alejandro Gil Date: Sun, 11 Jun 2023 06:42:22 -0700 Subject: [PATCH 32/62] - Changed `show_locals` in `traceback_install` to be true only when using `--debug`. 
- Added support for ProjectConfiguration for accelerate, which we will later use for properly handling the checkpoints and config files saved with them.
- Added the `checkpoint_limit` arg we will need for ProjectConfiguration and accelerate.

---
 .../trainers/vqvae_trainers.py |  10 +-
 train_muse_maskgit.py | 146 ++++++++++++------
 train_muse_vae.py | 16 +-
 3 files changed, 118 insertions(+), 54 deletions(-)

diff --git a/muse_maskgit_pytorch/trainers/vqvae_trainers.py b/muse_maskgit_pytorch/trainers/vqvae_trainers.py
index afaac22..4c37027 100644
--- a/muse_maskgit_pytorch/trainers/vqvae_trainers.py
+++ b/muse_maskgit_pytorch/trainers/vqvae_trainers.py
@@ -266,7 +266,7 @@ def train(self):
 
             # log
 
-            self.accelerator.print(f"[E{epoch + 1}][S{steps:05d}]{proc_label}: "
+            self.accelerator.print(f"[E{epoch + 1}][{steps:05d}]{proc_label}: "
                                    f"vae loss: {logs['Train/vae_loss']} - "
                                    f"discr loss: {logs['Train/discr_loss']} - "
                                    f"lr: {self.lr_scheduler.get_last_lr()[0]}")
@@ -283,7 +283,7 @@ def train(self):
 
             if (steps % self.save_results_every) == 0:
                 self.log_validation_images(logs, steps)
-                self.accelerator.print(f"[E{epoch + 1}][S{steps:05d}]{proc_label}: saving to {str(self.results_dir)}")
+                self.accelerator.print(f"[E{epoch + 1}][{steps:05d}]{proc_label}: saving to {str(self.results_dir)}")
 
             # save model every so often
             self.accelerator.wait_for_everyone()
@@ -310,12 +310,12 @@ def train(self):
                     OmegaConf.save(conf, f"{model_path}.yaml")
 
                 self.accelerator.print(
-                    f"[E{epoch + 1}][S{steps:05d}]{proc_label}: saving model to {str(self.results_dir)}")
+                    f"[E{epoch + 1}][{steps:05d}]{proc_label}: saving model to {str(self.results_dir)}")
 
             self.steps += 1
 
             if self.num_train_steps > 0 and self.steps >= int(self.steps.item()):
-                self.accelerator.print(f"[E{epoch + 1}][S{steps:05d}]{proc_label}: "
+                self.accelerator.print(f"[E{epoch + 1}][{steps:05d}]{proc_label}: "
                                        f"[STOP EARLY]: Stopping training early...")
                 break
 
@@ -344,4 +344,4 @@ def train(self):
                 OmegaConf.save(conf, f"{model_path}.yaml")
 
         self.accelerator.print(
-            f"[E{self.num_epochs}][S{steps:05d}]{proc_label}: saving model to {str(self.results_dir)}")
+            f"[E{self.num_epochs}][{steps:05d}]{proc_label}: saving model to {str(self.results_dir)}")
diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py
index 728d47d..805c3b5 100644
--- a/train_muse_maskgit.py
+++ b/train_muse_maskgit.py
@@ -1,7 +1,8 @@
 import argparse
 import logging
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union
+
 
 import accelerate
 import datasets
@@ -18,6 +19,7 @@
 import re
 
 from omegaconf import OmegaConf
+from accelerate.utils import ProjectConfiguration
 
 try:
     import torch_xla
@@ -44,12 +46,6 @@
 )
 from muse_maskgit_pytorch.trainers.base_accelerated_trainer import get_optimizer
 
-if accelerate.utils.is_rich_available():
-    from rich import print
-    from rich.traceback import install as traceback_install
-
-    traceback_install(show_locals=True, width=120, word_wrap=True)
-
 # remove some unnecessary errors from transformer shown on the console.
transformers.logging.set_verbosity_error() @@ -263,6 +259,12 @@ default=500, help="Save the model every N steps.", ) +parser.add_argument( + "--checkpoint_limit", + type=int, + default=None, + help="Keep only X number of checkpoints and delete the older ones.", +) parser.add_argument( "--vq_codebook_size", type=int, @@ -387,12 +389,6 @@ default=None, help="debug logging on", ) -parser.add_argument( - "--generate_config", - action="store_true", - help="whether to generate a model config (Recommended for training later)", -) - @dataclass class Arguments: @@ -439,6 +435,7 @@ class Arguments: gradient_accumulation_steps: int = 1 save_results_every: int = 100 save_model_every: int = 500 + checkpoint_limit: Union[int, str] = None vq_codebook_size: int = 256 vq_codebook_dim: int = 256 cond_drop_prob: float = 0.5 @@ -460,20 +457,18 @@ class Arguments: use_l2_recon_loss: bool = False debug: bool = False config_path: Optional[str] = None - generate_config: bool = False def main(): args = parser.parse_args(namespace=Arguments()) - if args.config_path: - print("Using config file and ignoring CLI args") + if accelerate.utils.is_rich_available(): + from rich import print + from rich.traceback import install as traceback_install - if args.generate_config: - conf = OmegaConf.structured(args) + traceback_install(show_locals=args.debug, width=120, word_wrap=True) - # dumps to file: - with open(args.config_path, "w") as f: - OmegaConf.save(conf, f) + if args.config_path: + print("Using config file and ignoring CLI args") try: conf = OmegaConf.load(args.config_path) @@ -498,11 +493,17 @@ def main(): else: logging.basicConfig(level=logging.INFO) + project_config = ProjectConfiguration( + project_dir=args.logging_dir, + total_limit=args.checkpoint_limit, + automatic_checkpoint_naming=True, + ) + accelerator: accelerate.Accelerator = get_accelerator( log_with=args.log_with, gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, - project_dir=args.logging_dir, + project_config=project_config, even_batches=True ) @@ -549,53 +550,56 @@ def main(): # Load the VAE with accelerator.main_process_first(): - if args.vae_path is not None: - load = True - accelerator.print(f"Using Muse VQGanVAE, loading from {args.vae_path}") - vae = VQGanVAE( - dim=args.dim, - vq_codebook_dim=args.vq_codebook_dim, - vq_codebook_size=args.vq_codebook_size, - accelerator=accelerator, - ) + if args.vae_path: + print("Loading Muse VQGanVAE") if args.latest_checkpoint: - accelerator.print("Finding latest checkpoint...") + print("Finding latest checkpoint...") orig_vae_path = args.vae_path + if os.path.isfile(args.vae_path) or '.pt' in args.vae_path: # If args.vae_path is a file, split it into directory and filename args.vae_path, _ = os.path.split(args.vae_path) checkpoint_files = glob.glob(os.path.join(args.vae_path, "vae.*.pt")) if checkpoint_files: - latest_checkpoint_file = max(checkpoint_files,key=lambda x: int(re.search(r'vae\.(\d+)\.pt$', x).group(1)) if not x.endswith('ema.pt') else -1) + latest_checkpoint_file = max(checkpoint_files, key=lambda x: int(re.search(r'vae\.(\d+)\.pt', x).group(1))) # Check if latest checkpoint is empty or unreadable if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(latest_checkpoint_file, os.R_OK): - accelerator.print( - f"Warning: latest checkpoint {latest_checkpoint_file} is empty or unreadable.") + print(f"Warning: latest checkpoint {latest_checkpoint_file} is empty or unreadable.") if len(checkpoint_files) > 1: # Use the second last checkpoint 
as a fallback - latest_checkpoint_file = max(checkpoint_files[:-1], key=lambda x: int(re.search(r'vae\.(\d+)\.pt$', x).group(1)) if not x.endswith('ema.pt') else -1) - accelerator.print("Using second last checkpoint: ", latest_checkpoint_file) + latest_checkpoint_file = max(checkpoint_files[:-1], key=lambda x: int(re.search(r'vae\.(\d+)\.pt', x).group(1))) + print("Using second last checkpoint: ", latest_checkpoint_file) else: - accelerator.print("No usable checkpoint found.") - load = False + print("No usable checkpoint found.") elif latest_checkpoint_file != orig_vae_path: - accelerator.print("Resuming VAE from latest checkpoint: ", latest_checkpoint_file) + print("Resuming VAE from latest checkpoint: ", latest_checkpoint_file) else: - accelerator.print("Using checkpoint specified in vae_path: ", orig_vae_path) + print("Using checkpoint specified in vae_path: ", orig_vae_path) args.vae_path = latest_checkpoint_file else: - accelerator.print("No checkpoints found in directory: ", args.vae_path) - load = False + print("No checkpoints found in directory: ", args.vae_path) else: - accelerator.print("Resuming VAE from: ", args.vae_path) + print("Resuming VAE from: ", args.vae_path) + + # use config next to checkpoint if there is one and merge the cli arguments to it + # the cli arguments will take priority so we can use it to override any value we want. + #if os.path.exists(f"{args.vae_path}.yaml"): + #print("Config file found, reusing config from it. Use cli arguments to override any desired value.") + #conf = OmegaConf.load(f"{args.vae_path}.yaml") + #cli_conf = OmegaConf.from_cli() + ## merge the config file and the cli arguments. + #conf = OmegaConf.merge(conf, cli_conf) + + vae = VQGanVAE(dim=args.dim, vq_codebook_dim=args.vq_codebook_dim, vq_codebook_size=args.vq_codebook_size, l2_recon_loss=args.use_l2_recon_loss).to( + accelerator.device + ) + vae.load(args.vae_path) - if load: - vae.load(args.vae_path, map="cpu") elif args.taming_model_path is not None and args.taming_config_path is not None: print(f"Using Taming VQGanVAE, loading from {args.taming_model_path}") @@ -610,8 +614,8 @@ def main(): raise ValueError( "You must pass either vae_path or taming_model_path + taming_config_path (but not both)" ) - - + + # freeze VAE before parsing to transformer vae.requires_grad_(False) @@ -633,6 +637,52 @@ def main(): t5_name=args.t5_name, cache_path=args.cache_path, ) + + # load the maskgit transformer from disk if we have previously trained one + if args.resume_path: + if args.latest_checkpoint: + accelerator.print("Finding latest checkpoint...") + orig_vae_path = args.resume_path + + + if os.path.isfile(args.resume_path) or '.pt' in args.resume_path: + # If args.resume_path is a file, split it into directory and filename + args.resume_path, _ = os.path.split(args.resume_path) + + checkpoint_files = glob.glob(os.path.join(args.resume_path, "maskgit.*.pt")) + if checkpoint_files: + latest_checkpoint_file = max(checkpoint_files,key=lambda x: int(re.search(r'maskgit\.(\d+)\.pt$', x).group(1)) if not x.endswith('ema.pt') else -1) + + # Check if latest checkpoint is empty or unreadable + if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(latest_checkpoint_file, os.R_OK): + accelerator.print(f"Warning: latest checkpoint {latest_checkpoint_file} is empty or unreadable.") + if len(checkpoint_files) > 1: + # Use the second last checkpoint as a fallback + latest_checkpoint_file = max(checkpoint_files[:-1], key=lambda x: int(re.search(r'maskgit\.(\d+)\.pt$', x).group(1)) if not 
x.endswith('ema.pt') else -1) + accelerator.print("Using second last checkpoint: ", latest_checkpoint_file) + else: + accelerator.print("No usable checkpoint found.") + elif latest_checkpoint_file != orig_vae_path: + accelerator.print("Resuming MaskGit from latest checkpoint: ", latest_checkpoint_file) + else: + accelerator.print("Using checkpoint specified in resume_path: ", orig_vae_path) + + args.resume_path = latest_checkpoint_file + else: + accelerator.print("No checkpoints found in directory: ", args.resume_path) + else: + accelerator.print("Resuming MaskGit from: ", args.resume_path) + + # use config next to checkpoint if there is one and merge the cli arguments to it + # the cli arguments will take priority so we can use it to override any value we want. + if os.path.exists(f"{args.resume_path}.yaml"): + accelerator.print("Config file found, reusing config from it. Use cli arguments to override any desired value.") + conf = OmegaConf.load(f"{args.resume_path}.yaml") + cli_conf = OmegaConf.from_cli() + # merge the config file and the cli arguments. + conf = OmegaConf.merge(conf, cli_conf) + + # (2) pass your trained VAE and the base transformer to MaskGit maskgit = MaskGit( vae=vae, # vqgan vae diff --git a/train_muse_vae.py b/train_muse_vae.py index 8b8c358..49ca8f0 100644 --- a/train_muse_vae.py +++ b/train_muse_vae.py @@ -21,6 +21,7 @@ import re from omegaconf import OmegaConf +from accelerate.utils import ProjectConfiguration # disable bitsandbytes welcome message. os.environ['BITSANDBYTES_NOWELCOME'] = '1' @@ -191,6 +192,12 @@ default=500, help="Save the model every this number of steps.", ) +parser.add_argument( + "--checkpoint_limit", + type=int, + default=None, + help="Keep only X number of checkpoints and delete the older ones.", +) parser.add_argument("--vq_codebook_size", type=int, default=256, help="Image Size.") parser.add_argument( "--vq_codebook_dim", @@ -314,6 +321,7 @@ class Arguments: gradient_accumulation_steps: int = 1 save_results_every: int = 100 save_model_every: int = 500 + checkpoint_limit: Union[int, str] = None vq_codebook_size: int = 256 vq_codebook_dim: int = 256 cond_drop_prob: float = 0.5 @@ -367,11 +375,17 @@ def main(): except FileNotFoundError: print("Could not find config, using default and parsed values...") + project_config = ProjectConfiguration( + project_dir=args.logging_dir, + total_limit=args.checkpoint_limit, + automatic_checkpoint_naming=True, + ) + accelerator = get_accelerator( log_with=args.log_with, gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, - project_dir=args.logging_dir, + project_config=project_config, even_batches=True ) if accelerator.is_main_process: From 7b0210ca13ad3a2b452452f3d580409b88b992c9 Mon Sep 17 00:00:00 2001 From: Alejandro Gil Date: Sun, 11 Jun 2023 07:03:02 -0700 Subject: [PATCH 33/62] Removed unused imports. 
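A note on the `--latest_checkpoint` discovery used by the preceding patches: stripped of the fallback handling, the glob/regex/max() combination reduces to the sketch below. The folder layout and the helper name are illustrative only, not code from the repository:

    import glob
    import os
    import re
    from typing import Optional

    def find_latest_checkpoint(folder: str, prefix: str = "maskgit") -> Optional[str]:
        # Collect e.g. maskgit.100.pt, maskgit.200.pt, maskgit.200.ema.pt ...
        candidates = glob.glob(os.path.join(folder, f"{prefix}.*.pt"))
        if not candidates:
            return None

        def step_of(path: str) -> int:
            # Rank EMA checkpoints last, mirroring the `else -1` in the diffs above.
            if path.endswith("ema.pt"):
                return -1
            match = re.search(rf"{prefix}\.(\d+)\.pt$", path)
            return int(match.group(1)) if match else -1

        return max(candidates, key=step_of)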
--- muse_maskgit_pytorch/muse_maskgit_pytorch.py | 1 - muse_maskgit_pytorch/trainers/base_accelerated_trainer.py | 4 ++-- muse_maskgit_pytorch/vqgan_vae.py | 4 +--- train_muse_maskgit.py | 2 -- 4 files changed, 3 insertions(+), 8 deletions(-) diff --git a/muse_maskgit_pytorch/muse_maskgit_pytorch.py b/muse_maskgit_pytorch/muse_maskgit_pytorch.py index 4c7aa02..4ce391a 100644 --- a/muse_maskgit_pytorch/muse_maskgit_pytorch.py +++ b/muse_maskgit_pytorch/muse_maskgit_pytorch.py @@ -11,7 +11,6 @@ from accelerate import Accelerator from beartype import beartype from einops import rearrange, repeat -from rich import inspect from torch import einsum, nn, isnan from tqdm.auto import tqdm from transformers import T5EncoderModel, T5Tokenizer diff --git a/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py b/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py index 906d88a..0698bf2 100644 --- a/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py +++ b/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py @@ -1,9 +1,9 @@ from os import PathLike from pathlib import Path from shutil import rmtree -from typing import Dict, Optional, Union +from typing import Optional, Union import accelerate - +from PIL import Image import numpy as np import torch from accelerate import Accelerator, DistributedDataParallelKwargs, DistributedType diff --git a/muse_maskgit_pytorch/vqgan_vae.py b/muse_maskgit_pytorch/vqgan_vae.py index 3f75ab4..f891730 100644 --- a/muse_maskgit_pytorch/vqgan_vae.py +++ b/muse_maskgit_pytorch/vqgan_vae.py @@ -1,8 +1,7 @@ import copy from functools import partial, wraps from pathlib import Path -from typing import List - +from torch import nn import timm import torch import torch.nn.functional as F @@ -10,7 +9,6 @@ from accelerate import Accelerator from beartype import beartype from einops import rearrange, repeat -from torch import nn, Tensor from torch.autograd import grad as torch_grad from vector_quantize_pytorch import VectorQuantize as VQ diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py index 805c3b5..d0bf285 100644 --- a/train_muse_maskgit.py +++ b/train_muse_maskgit.py @@ -608,8 +608,6 @@ def main(): vqgan_config_path=args.taming_config_path, accelerator=accelerator, ) - args.num_tokens = vae.codebook_size - args.seq_len = vae.get_encoded_fmap_size(args.image_size) ** 2 else: raise ValueError( "You must pass either vae_path or taming_model_path + taming_config_path (but not both)" From 363219065e68b974460db8bb8d5f6686bcce7095 Mon Sep 17 00:00:00 2001 From: Alejandro Gil Date: Sun, 11 Jun 2023 07:15:15 -0700 Subject: [PATCH 34/62] Reverted removed args that are needed for taming to work. --- train_muse_maskgit.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py index d0bf285..805c3b5 100644 --- a/train_muse_maskgit.py +++ b/train_muse_maskgit.py @@ -608,6 +608,8 @@ def main(): vqgan_config_path=args.taming_config_path, accelerator=accelerator, ) + args.num_tokens = vae.codebook_size + args.seq_len = vae.get_encoded_fmap_size(args.image_size) ** 2 else: raise ValueError( "You must pass either vae_path or taming_model_path + taming_config_path (but not both)" From 1f22d22f0cff0a10a04f9969c0847848be089a3a Mon Sep 17 00:00:00 2001 From: Alejandro Gil Date: Sun, 11 Jun 2023 07:40:03 -0700 Subject: [PATCH 35/62] Cleaned some prints and progress bar descriptions. 
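For reference, the config-file handling that the earlier patches introduced (saving a YAML next to each checkpoint, then merging it with CLI overrides on resume) boils down to roughly the following. The Namespace fields and the checkpoint path are made-up values for illustration:

    import argparse
    from omegaconf import OmegaConf

    args = argparse.Namespace(lr=1e-4, batch_size=512, do_not_save_config=False)
    model_path = "results/vae.500.pt"  # hypothetical checkpoint file

    # save config file next to the model file, unless disabled
    if not args.do_not_save_config:
        conf = OmegaConf.create(vars(args))
        OmegaConf.save(conf, f"{model_path}.yaml")

    # on resume, reload it and let CLI arguments take priority
    conf = OmegaConf.load(f"{model_path}.yaml")
    conf = OmegaConf.merge(conf, OmegaConf.from_cli())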
--- .../trainers/maskgit_trainer.py | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/muse_maskgit_pytorch/trainers/maskgit_trainer.py b/muse_maskgit_pytorch/trainers/maskgit_trainer.py index b058a32..483f67d 100644 --- a/muse_maskgit_pytorch/trainers/maskgit_trainer.py +++ b/muse_maskgit_pytorch/trainers/maskgit_trainer.py @@ -174,21 +174,21 @@ def train(self): logs = {"loss": train_loss, "lr": self.lr_scheduler.get_last_lr()[0]} if self.on_tpu: - self.accelerator.print(f"[E{epoch + 1}][{steps:05d}]{proc_label}: " + self.accelerator.print(f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: " f"maskgit loss: {logs['loss']} - lr: {logs['lr']}") else: self.training_bar.update() - self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}: " + self.info_bar.set_description_str(f"\n[E{epoch + 1}]{proc_label}: " f"maskgit loss: {logs['loss']} - lr: {logs['lr']}") self.accelerator.log(logs, step=steps) if not (steps % self.save_model_every): if self.on_tpu: - self.accelerator.print(f"[E{epoch + 1}][{steps:05d}]{proc_label}: " + self.accelerator.print(f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: " f"saving model to {self.results_dir}") else: - self.accelerator.print(f"[E{epoch + 1}]{proc_label}: " + self.accelerator.print(f"\n[E{epoch + 1}]{proc_label}: " f"saving model to {self.results_dir}") state_dict = self.accelerator.unwrap_model(self.model).state_dict() @@ -211,10 +211,10 @@ def train(self): if self.use_ema: if self.on_tpu: self.accelerator.print( - f"[E{epoch + 1}][{steps:05d}]{proc_label}: " + f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: " f"saving EMA model to {self.results_dir}") else: - self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}: " + self.info_bar.set_description_str(f"\n[E{epoch + 1}]{proc_label}: " f"saving EMA model to {self.results_dir}") ema_state_dict = self.accelerator.unwrap_model(self.ema_model).state_dict() @@ -239,40 +239,40 @@ def train(self): self.validation_prompts = [""] * self.batch_size if self.on_tpu: - self.accelerator.print(f"[E{epoch + 1}]{proc_label}: " + self.accelerator.print(f"\n[E{epoch + 1}]{proc_label}: " f"Logging validation images") else: - self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}: " + self.info_bar.set_description_str(f"\n[E{epoch + 1}]{proc_label}: " f"Logging validation images") saved_image = self.save_validation_images( self.validation_prompts, steps, cond_image=cond_image ) if self.on_tpu: - self.accelerator.print(f"[E{epoch + 1}][{steps:05d}]{proc_label}: saved to {saved_image}") + self.accelerator.print(f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: saved to {saved_image}") else: - self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}: " + self.info_bar.set_description_str(f"\n[E{epoch + 1}]{proc_label}: " f"saved to {saved_image}") if met is not None and not (steps % self.log_metrics_every): if self.on_tpu: - self.accelerator.print(f"[E{epoch + 1}][{steps:05d}]{proc_label}: metrics:") + self.accelerator.print(f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: metrics:") else: - self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}: metrics:") + self.info_bar.set_description_str(f"\n[E{epoch + 1}]{proc_label}: metrics:") self.steps += 1 if self.num_train_steps > 0 and self.steps >= int(self.steps.item()): if self.on_tpu: - self.accelerator.print(f"[E{epoch + 1}][{int(self.steps.item()):05d}]{proc_label}" + self.accelerator.print(f"\n[E{epoch + 1}][{int(self.steps.item()):05d}]{proc_label}" f"[STOP EARLY]: Stopping training early...") else: - 
self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}"
+                        self.info_bar.set_description_str(f"\n[E{epoch + 1}]{proc_label}"
                                                           f"[STOP EARLY]: Stopping training early...")
                     break
 
             # loop complete, save final model
-            self.accelerator.print(f"[E{epoch + 1}][{steps:05d}]{proc_label}[FINAL]: saving model to {self.results_dir}")
+            self.accelerator.print(f"\n[E{epoch + 1}][{steps:05d}]{proc_label}[FINAL]: saving model to {self.results_dir}")
             state_dict = self.accelerator.unwrap_model(self.model).state_dict()
             maskgit_save_name = "maskgit_superres" if self.model.cond_image_size else "maskgit"
             file_name = (
@@ -292,7 +292,7 @@ def train(self):
 
             if self.use_ema:
                 self.accelerator.print(
-                    f"[{steps:05d}]{proc_label}[FINAL]: saving EMA model to {self.results_dir}"
+                    f"\n[{steps:05d}]{proc_label}[FINAL]: saving EMA model to {self.results_dir}"
                 )
                 ema_state_dict = self.accelerator.unwrap_model(self.ema_model).state_dict()
                 file_name = (
@@ -317,9 +317,9 @@ def train(self):
             cond_image = F.interpolate(imgs, self.model.cond_image_size, mode="nearest")
 
         steps = int(self.steps.item()) + 1  # get the final step count, plus one
-        self.accelerator.print(f"[{steps:05d}]{proc_label}: Logging validation images")
+        self.accelerator.print(f"\n[{steps:05d}]{proc_label}: Logging validation images")
         saved_image = self.save_validation_images(self.validation_prompts, steps, cond_image=cond_image)
-        self.accelerator.print(f"[{steps:05d}]{proc_label}: saved to {saved_image}")
+        self.accelerator.print(f"\n[{steps:05d}]{proc_label}: saved to {saved_image}")
 
         if met is not None and not (steps % self.log_metrics_every):
-            self.accelerator.print(f"[{steps:05d}]{proc_label}: metrics:")
+            self.accelerator.print(f"\n[{steps:05d}]{proc_label}: metrics:")

From 92a4abb51cc95d4e6af8263b6484779b2a73f2d2 Mon Sep 17 00:00:00 2001
From: Alejandro Gil
Date: Sun, 11 Jun 2023 07:57:52 -0700
Subject: [PATCH 36/62] Reverted changes to the progress bar description which
 made it print a new line on each step, which was not intended to happen on
 the progress bar.
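The underlying tqdm behaviour is easy to reproduce in isolation: the bar redraws its description in place on the same terminal line, so a leading "\n" forces a fresh line on every refresh. A minimal demonstration, with arbitrary values:

    import time
    from tqdm import tqdm

    info_bar = tqdm(total=0, bar_format="{desc}")
    for step in range(3):
        # Bad: a "\n" here would print a new line on every update, e.g.
        # info_bar.set_description_str(f"\n[E1]: step {step}")
        info_bar.set_description_str(f"[E1]: step {step}")  # updates in place
        time.sleep(0.1)
    info_bar.close()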
--- muse_maskgit_pytorch/trainers/maskgit_trainer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/muse_maskgit_pytorch/trainers/maskgit_trainer.py b/muse_maskgit_pytorch/trainers/maskgit_trainer.py index 483f67d..10752a1 100644 --- a/muse_maskgit_pytorch/trainers/maskgit_trainer.py +++ b/muse_maskgit_pytorch/trainers/maskgit_trainer.py @@ -178,7 +178,7 @@ def train(self): f"maskgit loss: {logs['loss']} - lr: {logs['lr']}") else: self.training_bar.update() - self.info_bar.set_description_str(f"\n[E{epoch + 1}]{proc_label}: " + self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}: " f"maskgit loss: {logs['loss']} - lr: {logs['lr']}") self.accelerator.log(logs, step=steps) @@ -214,7 +214,7 @@ def train(self): f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: " f"saving EMA model to {self.results_dir}") else: - self.info_bar.set_description_str(f"\n[E{epoch + 1}]{proc_label}: " + self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}: " f"saving EMA model to {self.results_dir}") ema_state_dict = self.accelerator.unwrap_model(self.ema_model).state_dict() @@ -242,7 +242,7 @@ def train(self): self.accelerator.print(f"\n[E{epoch + 1}]{proc_label}: " f"Logging validation images") else: - self.info_bar.set_description_str(f"\n[E{epoch + 1}]{proc_label}: " + self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}: " f"Logging validation images") saved_image = self.save_validation_images( @@ -251,14 +251,14 @@ def train(self): if self.on_tpu: self.accelerator.print(f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: saved to {saved_image}") else: - self.info_bar.set_description_str(f"\n[E{epoch + 1}]{proc_label}: " + self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}: " f"saved to {saved_image}") if met is not None and not (steps % self.log_metrics_every): if self.on_tpu: self.accelerator.print(f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: metrics:") else: - self.info_bar.set_description_str(f"\n[E{epoch + 1}]{proc_label}: metrics:") + self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}: metrics:") self.steps += 1 @@ -267,7 +267,7 @@ def train(self): self.accelerator.print(f"\n[E{epoch + 1}][{int(self.steps.item()):05d}]{proc_label}" f"[STOP EARLY]: Stopping training early...") else: - self.info_bar.set_description_str(f"\n[E{epoch + 1}]{proc_label}" + self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}" f"[STOP EARLY]: Stopping training early...") break From f721c5b43712f0007617bb130ded635018ed81d9 Mon Sep 17 00:00:00 2001 From: Alejandro Gil Date: Sun, 11 Jun 2023 19:09:58 -0700 Subject: [PATCH 37/62] Added missing Union import. --- train_muse_vae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_muse_vae.py b/train_muse_vae.py index 49ca8f0..912b84c 100644 --- a/train_muse_vae.py +++ b/train_muse_vae.py @@ -1,6 +1,6 @@ import argparse from dataclasses import dataclass -from typing import Optional +from typing import Optional, Union from datasets import load_dataset From 4abad83b871205aa20038c4cf8828db2c85e77ba Mon Sep 17 00:00:00 2001 From: Alejandro Gil Date: Sun, 11 Jun 2023 19:37:53 -0700 Subject: [PATCH 38/62] Added tqdm progress bar to the vae training. 
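The diff below uses a common two-bar layout: a regular bar for progress, plus a zero-length bar whose bar_format is just '{desc}', acting as a persistent status line beneath it. A minimal sketch with made-up step counts:

    from tqdm import tqdm

    resume_step, total_steps = 40, 100  # hypothetical values
    training_bar = tqdm(initial=resume_step, total=total_steps)
    info_bar = tqdm(total=0, bar_format="{desc}")

    for step in range(resume_step, total_steps):
        training_bar.update()
        info_bar.set_description_str(f"[E1][{step:05d}]: vae loss: 0.123 - lr: 0.0001")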
---
 .../trainers/vqvae_trainers.py | 26 +++++++++++++++----
 1 file changed, 21 insertions(+), 5 deletions(-)

diff --git a/muse_maskgit_pytorch/trainers/vqvae_trainers.py b/muse_maskgit_pytorch/trainers/vqvae_trainers.py
index 4c37027..bab411b 100644
--- a/muse_maskgit_pytorch/trainers/vqvae_trainers.py
+++ b/muse_maskgit_pytorch/trainers/vqvae_trainers.py
@@ -7,6 +7,7 @@
 from torch.optim.lr_scheduler import LRScheduler
 from torch.utils.data import DataLoader
 from torchvision.utils import make_grid, save_image
+from tqdm import tqdm
 from omegaconf import OmegaConf
 
 from muse_maskgit_pytorch.trainers.base_accelerated_trainer import (
@@ -156,6 +157,15 @@ def __init__(
         )
 
         self.ema_model = accelerator.prepare(self.ema_model)
 
+        if not self.on_tpu:
+            if self.num_train_steps <= 0:
+                self.training_bar = tqdm(initial=int(self.steps.item()), total=len(self.dl) * self.num_epochs)
+            else:
+                self.training_bar = tqdm(initial=int(self.steps.item()), total=self.num_train_steps)
+
+            self.info_bar = tqdm(total=0, bar_format='{desc}')
+
     def load(self, path):
         pkg = super().load(path)
         self.discr_optim.load_state_dict(pkg["discr_optim"])
@@ -265,11 +275,17 @@ def train(self):
                 self.discr_optim.step()
 
             # log
-
-            self.accelerator.print(f"[E{epoch + 1}][{steps:05d}]{proc_label}: "
-                                   f"vae loss: {logs['Train/vae_loss']} - "
-                                   f"discr loss: {logs['Train/discr_loss']} - "
-                                   f"lr: {self.lr_scheduler.get_last_lr()[0]}")
+            if self.on_tpu:
+                self.accelerator.print(f"[E{epoch + 1}][{steps:05d}]{proc_label}: "
+                                       f"vae loss: {logs['Train/vae_loss']} - "
+                                       f"discr loss: {logs['Train/discr_loss']} - "
+                                       f"lr: {self.lr_scheduler.get_last_lr()[0]}")
+            else:
+                self.training_bar.update()
+                self.info_bar.set_description_str(f"[E{epoch + 1}][{steps:05d}]: "
+                                                  f"vae loss: {logs['Train/vae_loss']} - "
+                                                  f"discr loss: {logs['Train/discr_loss']} - "
+                                                  f"lr: {self.lr_scheduler.get_last_lr()[0]}")
 
             logs["lr"] = self.lr_scheduler.get_last_lr()[0]
             self.accelerator.log(logs, step=steps)

From 2bbd45ec54951df3e57ec4702716c2527b0184c8 Mon Sep 17 00:00:00 2001
From: Alejandro Gil
Date: Sun, 11 Jun 2023 19:42:49 -0700
Subject: [PATCH 39/62] Changed the position of some print statements so they
 are shown before the action they describe, such as saving the validation
 images or saving the model; previously they were printed after the action
 instead of before.

---
 muse_maskgit_pytorch/trainers/vqvae_trainers.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/muse_maskgit_pytorch/trainers/vqvae_trainers.py b/muse_maskgit_pytorch/trainers/vqvae_trainers.py
index bab411b..69ec2e5 100644
--- a/muse_maskgit_pytorch/trainers/vqvae_trainers.py
+++ b/muse_maskgit_pytorch/trainers/vqvae_trainers.py
@@ -282,6 +282,8 @@ def train(self):
                                        f"lr: {self.lr_scheduler.get_last_lr()[0]}")
             else:
                 self.training_bar.update()
+                # Note: we had to remove {proc_label} from the description
+                # to shorten it so it doesn't go beyond one line on each step.
self.info_bar.set_description_str(f"[E{epoch + 1}][{steps:05d}]: " f"vae loss: {logs['Train/vae_loss']} - " f"discr loss: {logs['Train/discr_loss']} - " @@ -298,12 +300,15 @@ def train(self): # sample results every so often if (steps % self.save_results_every) == 0: + self.accelerator.print(f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: saving to {str(self.results_dir)}") self.log_validation_images(logs, steps) - self.accelerator.print(f"[E{epoch + 1}][{steps:05d}]{proc_label}: saving to {str(self.results_dir)}") # save model every so often self.accelerator.wait_for_everyone() if self.is_main_process and (steps % self.save_model_every) == 0: + self.accelerator.print( + f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: saving model to {str(self.results_dir)}") + state_dict = self.accelerator.unwrap_model(self.model).state_dict() file_name = f"vae.{steps}.pt" if not self.only_save_last_checkpoint else "vae.pt" model_path = str(self.results_dir / file_name) @@ -325,19 +330,19 @@ def train(self): conf = OmegaConf.create(vars(self.args)) OmegaConf.save(conf, f"{model_path}.yaml") - self.accelerator.print( - f"[E{epoch + 1}][{steps:05d}]{proc_label}: saving model to {str(self.results_dir)}") - self.steps += 1 if self.num_train_steps > 0 and self.steps >= int(self.steps.item()): - self.accelerator.print(f"[E{epoch + 1}][{steps:05d}]{proc_label}: " + self.accelerator.print(f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: " f"[STOP EARLY]: Stopping training early...") break # Loop finished, save model self.accelerator.wait_for_everyone() if self.is_main_process: + self.accelerator.print( + f"[E{self.num_epochs}][{steps:05d}]{proc_label}: saving model to {str(self.results_dir)}") + state_dict = self.accelerator.unwrap_model(self.model).state_dict() file_name = f"vae.{steps}.pt" if not self.only_save_last_checkpoint else "vae.pt" model_path = str(self.results_dir / file_name) @@ -359,5 +364,3 @@ def train(self): conf = OmegaConf.create(vars(self.args)) OmegaConf.save(conf, f"{model_path}.yaml") - self.accelerator.print( - f"[E{self.num_epochs}][{steps:05d}]{proc_label}: saving model to {str(self.results_dir)}") From 67d09eec03318c12572654bc025b0d8c9d036960 Mon Sep 17 00:00:00 2001 From: Alejandro Gil Date: Sun, 11 Jun 2023 19:43:58 -0700 Subject: [PATCH 40/62] Removed some conditions that were unnecessary as we are doing the exact same action whether we are on a tpu or not. 
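The infer_vae.py rewrite in the following patch adds --chunk_size and --overlap_size options for reconstructing images larger than the training resolution. The idea is to run the VAE over overlapping tiles and average the overlaps to hide seams. The sketch below only illustrates that idea; the helper name is made up, it assumes a VAE forward that returns reconstructions when called with return_recons=True, and the actual implementation may differ:

    import torch

    def reconstruct_in_chunks(vae, image, chunk=256, overlap=32):
        # image: (B, C, H, W) float tensor
        _, _, h, w = image.shape
        out = torch.zeros_like(image)
        weight = torch.zeros_like(image)
        stride = chunk - overlap
        for top in range(0, max(h - overlap, 1), stride):
            for left in range(0, max(w - overlap, 1), stride):
                bottom, right = min(top + chunk, h), min(left + chunk, w)
                tile = image[:, :, top:bottom, left:right]
                recon = vae(tile, return_recons=True)  # assumed VAE call signature
                out[:, :, top:bottom, left:right] += recon
                weight[:, :, top:bottom, left:right] += 1
        return out / weight.clamp(min=1)  # average the overlapping regions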
--- .../trainers/maskgit_trainer.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/muse_maskgit_pytorch/trainers/maskgit_trainer.py b/muse_maskgit_pytorch/trainers/maskgit_trainer.py index 10752a1..d985233 100644 --- a/muse_maskgit_pytorch/trainers/maskgit_trainer.py +++ b/muse_maskgit_pytorch/trainers/maskgit_trainer.py @@ -184,12 +184,8 @@ def train(self): self.accelerator.log(logs, step=steps) if not (steps % self.save_model_every): - if self.on_tpu: - self.accelerator.print(f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: " - f"saving model to {self.results_dir}") - else: - self.accelerator.print(f"\n[E{epoch + 1}]{proc_label}: " - f"saving model to {self.results_dir}") + self.accelerator.print(f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: " + f"saving model to {self.results_dir}") state_dict = self.accelerator.unwrap_model(self.model).state_dict() maskgit_save_name = "maskgit_superres" if self.model.cond_image_size else "maskgit" @@ -209,13 +205,9 @@ def train(self): OmegaConf.save(conf, f"{model_path}.yaml") if self.use_ema: - if self.on_tpu: - self.accelerator.print( - f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: " - f"saving EMA model to {self.results_dir}") - else: - self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}: " - f"saving EMA model to {self.results_dir}") + self.accelerator.print( + f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: " + f"saving EMA model to {self.results_dir}") ema_state_dict = self.accelerator.unwrap_model(self.ema_model).state_dict() file_name = ( From d979fc5294a10bbc8b8edd54da4569371637f725 Mon Sep 17 00:00:00 2001 From: Alejandro Gil Date: Sun, 11 Jun 2023 19:44:27 -0700 Subject: [PATCH 41/62] Improved the infer_vae.py script. --- infer_vae.py | 513 ++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 387 insertions(+), 126 deletions(-) diff --git a/infer_vae.py b/infer_vae.py index f197649..564b68f 100644 --- a/infer_vae.py +++ b/infer_vae.py @@ -1,7 +1,10 @@ import torch +import accelerate +from dataclasses import dataclass from torchvision.utils import save_image from datasets import load_dataset, Dataset, Image -import os, random +import os, random, hashlib +from datetime import datetime from muse_maskgit_pytorch import ( VQGanVAE, VQGanVAETaming, @@ -11,118 +14,238 @@ get_dataset_from_dataroot, ImageDataset, ) - +from tqdm import tqdm import argparse +import PIL +import glob, re +from accelerate.utils import ProjectConfiguration +# Create the parser +parser = argparse.ArgumentParser() +parser.add_argument( + "--no_center_crop", + action="store_true", + help="Don't do center crop.", +) +parser.add_argument( + "--random_crop", + action="store_true", + help="Crop the images at random locations instead of cropping from the center.", +) +parser.add_argument( + "--no_flip", + action="store_true", + help="Don't flip image.", +) +parser.add_argument( + "--random_image", + action="store_true", + help="Get a random image from the dataset to use for the reconstruction.", +) +parser.add_argument( + "--dataset_save_path", + type=str, + default="dataset", + help="Path to save the dataset if you are making one from a directory", +) +parser.add_argument( + "--seed", + type=int, + default=42, + help="Seed for reproducibility. If set to -1 a random seed will be generated.", +) +parser.add_argument( + "--valid_frac", type=float, default=0.05, help="validation fraction." 
+) +parser.add_argument( + "--image_column", + type=str, + default="image", + help="The column of the dataset containing an image.", +) +parser.add_argument( + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help="Precision to train on.", +) +parser.add_argument( + "--results_dir", + type=str, + default="results", + help="Path to save the training samples and checkpoints", +) +parser.add_argument( + "--logging_dir", + type=str, + default="results/logs", + help="Path to log the losses and LR", +) -def parse_args(): - # Create the parser - parser = argparse.ArgumentParser() - parser.add_argument( - "--no_center_crop", - action="store_true", - help="Don't do center crop.", - ) - parser.add_argument( - "--no_flip", - action="store_true", - help="Don't flip image.", - ) - parser.add_argument( - "--random_image", - action="store_true", - help="Get a random image from the dataset to use for the reconstruction.", - ) - parser.add_argument( - "--dataset_save_path", - type=str, - default="dataset", - help="Path to save the dataset if you are making one from a directory", - ) - parser.add_argument( - "--seed", - type=int, - default=42, - help="Seed for reproducibility. If set to -1 a random seed will be generated.", - ) - parser.add_argument("--valid_frac", type=float, default=0.05, help="validation fraction.") - parser.add_argument( - "--image_column", - type=str, - default="image", - help="The column of the dataset containing an image.", - ) - parser.add_argument( - "--mixed_precision", - type=str, - default="no", - choices=["no", "fp16", "bf16"], - help="Precision to train on.", - ) - parser.add_argument( - "--results_dir", - type=str, - default="results", - help="Path to save the training samples and checkpoints", - ) - parser.add_argument( - "--logging_dir", - type=str, - default="results/logs", - help="Path to log the losses and LR", - ) - - # vae_trainer args - parser.add_argument( - "--vae_path", - type=str, - default=None, - help="Path to the vae model. eg. 'results/vae.steps.pt'", - ) - parser.add_argument( - "--dataset_name", - type=str, - default=None, - help="Name of the huggingface dataset used.", - ) - parser.add_argument( - "--train_data_dir", - type=str, - default=None, - help="Dataset folder where your input images for training are.", - ) - parser.add_argument("--dim", type=int, default=128, help="Model dimension.") - parser.add_argument("--batch_size", type=int, default=512, help="Batch Size.") - parser.add_argument("--lr", type=float, default=1e-4, help="Learning Rate.") - parser.add_argument("--vq_codebook_size", type=int, default=256, help="Image Size.") - parser.add_argument( - "--image_size", - type=int, - default=256, - help="Image size. You may want to start with small images, and then curriculum learn to larger ones, but because the vae is all convolution, it should generalize to 512 (as in paper) without training on it", - ) - parser.add_argument( - "--taming_model_path", - type=str, - default=None, - help="path to your trained VQGAN weights. This should be a .ckpt file. (only valid when taming option is enabled)", - ) +# vae_trainer args +parser.add_argument( + "--vae_path", + type=str, + default=None, + help="Path to the vae model. eg. 
'results/vae.steps.pt'",
+)
+parser.add_argument(
+    "--dataset_name",
+    type=str,
+    default=None,
+    help="Name of the huggingface dataset used.",
+)
+parser.add_argument(
+    "--train_data_dir",
+    type=str,
+    default=None,
+    help="Dataset folder where your input images for training are.",
+)
+parser.add_argument("--dim", type=int, default=128, help="Model dimension.")
+parser.add_argument("--batch_size", type=int, default=512, help="Batch Size.")
+parser.add_argument("--lr", type=float, default=1e-4, help="Learning Rate.")
+parser.add_argument("--vq_codebook_size", type=int, default=256, help="VQ codebook size.")
+parser.add_argument("--vq_codebook_dim", type=int, default=256, help="VQ Codebook dimensions.")
+parser.add_argument(
+    "--image_size",
+    type=int,
+    default=256,
+    help="Image size. You may want to start with small images, and then curriculum learn to larger ones, but because the vae is all convolution, it should generalize to 512 (as in paper) without training on it",
+)
+parser.add_argument(
+    "--chunk_size",
+    type=int,
+    default=256,
+    help="This is used to split big images into smaller chunks so we can still reconstruct them no matter the size.",
+)
+parser.add_argument(
+    "--min_chunk_size",
+    type=int,
+    default=8,
+    help="We use a minimum chunk size to ensure that the image is always reconstructed correctly.",
+)
+parser.add_argument(
+    "--overlap_size",
+    type=int,
+    default=256,
+    help="The overlap size used with --chunk_size to overlap the chunks and make sure the whole image is reconstructed, as well as to remove artifacts caused by doing the reconstruction in chunks.",
+)
+parser.add_argument(
+    "--min_overlap_size",
+    type=int,
+    default=1,
+    help="We use a minimum overlap size to ensure that the image is always reconstructed correctly.",
+)
+parser.add_argument(
+    "--taming_model_path",
+    type=str,
+    default=None,
+    help="path to your trained VQGAN weights. This should be a .ckpt file. (only valid when taming option is enabled)",
+)

-    parser.add_argument(
-        "--taming_config_path",
-        type=str,
-        default=None,
-        help="path to your trained VQGAN config. This should be a .yaml file. (only valid when taming option is enabled)",
-    )
-    parser.add_argument(
-        "--input_image",
-        type=str,
-        default=None,
-        help="Path to an image to use as input for reconstruction instead of using one from the dataset.",
-    )
+parser.add_argument(
+    "--taming_config_path",
+    type=str,
+    default=None,
+    help="path to your trained VQGAN config. This should be a .yaml file.
(only valid when taming option is enabled)",
+)
+parser.add_argument(
+    "--input_image",
+    type=str,
+    default=None,
+    help="Path to an image to use as input for reconstruction instead of using one from the dataset.",
+)
+parser.add_argument(
+    "--input_folder",
+    type=str,
+    default=None,
+    help="Path to a folder with images to use as input for creating a dataset for reconstructing all the images in it instead of just one image.",
+)
+parser.add_argument(
+    "--exclude_folders",
+    type=str,
+    default=None,
+    help="List of folders we want to exclude when doing reconstructions from an input folder.",
+)
+parser.add_argument(
+    "--gpu",
+    type=int,
+    default=0,
+    help="GPU to use in case we want to use a specific GPU for inference.",
+)
+parser.add_argument(
+    "--max_retries",
+    type=int,
+    default=30,
+    help="Max number of times to retry in case the reconstruction fails due to OOM or any other error.",
+)
+parser.add_argument(
+    "--latest_checkpoint",
+    action="store_true",
+    help="Use the latest checkpoint using the vae_path folder instead of using just a specific vae_path.",
+)

-    # Parse the argument
-    return parser.parse_args()

+@dataclass
+class Arguments:
+    only_save_last_checkpoint: bool = False
+    validation_image_scale: float = 1.0
+    no_center_crop: bool = False
+    no_flip: bool = False
+    random_crop: bool = False
+    random_image: bool = False
+    dataset_save_path: Optional[str] = None
+    clear_previous_experiments: bool = False
+    max_grad_norm: Optional[float] = None
+    discr_max_grad_norm: Optional[float] = None
+    num_tokens: int = 256
+    seq_len: int = 1024
+    seed: int = 42
+    valid_frac: float = 0.05
+    use_ema: bool = False
+    ema_beta: float = 0.995
+    ema_update_after_step: int = 1
+    ema_update_every: int = 1
+    apply_grad_penalty_every: int = 4
+    image_column: str = "image"
+    caption_column: str = "caption"
+    log_with: str = "wandb"
+    mixed_precision: str = "no"
+    use_8bit_adam: bool = False
+    results_dir: str = "results"
+    logging_dir: str = "results/logs"
+    resume_path: Optional[str] = None
+    dataset_name: Optional[str] = None
+    streaming: bool = False
+    train_data_dir: Optional[str] = None
+    num_train_steps: int = -1
+    num_epochs: int = 5
+    dim: int = 128
+    batch_size: int = 512
+    lr: float = 1e-5
+    gradient_accumulation_steps: int = 1
+    save_results_every: int = 100
+    save_model_every: int = 500
+    vq_codebook_size: int = 256
+    vq_codebook_dim: int = 256
+    cond_drop_prob: float = 0.5
+    image_size: int = 256
+    lr_scheduler: str = "constant"
+    scheduler_power: float = 1.0
+    lr_warmup_steps: int = 0
+    num_cycles: int = 1
+    taming_model_path: Optional[str] = None
+    taming_config_path: Optional[str] = None
+    optimizer: str = "Lion"
+    weight_decay: float = 0.0
+    cache_path: Optional[str] = None
+    no_cache: bool = False
+    latest_checkpoint: bool = False
+    do_not_save_config: bool = False
+    use_l2_recon_loss: bool = False
+    debug: bool = False
+    config_path: Optional[str] = None
+    generate_config: bool = False


 def seed_to_int(s):
@@ -151,36 +275,103 @@ def seed_to_int(s):


 def main():
-    args = parse_args()
-    accelerator = get_accelerator(
+    args = parser.parse_args(namespace=Arguments())
+
+    project_config = ProjectConfiguration(
+        project_dir=args.logging_dir,
+        automatic_checkpoint_naming=True,
+    )
+
+    accelerator: accelerate.Accelerator = get_accelerator(
+        log_with=args.log_with,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
-        project_dir=args.logging_dir,
+        project_config=project_config,
+        even_batches=True
     )

     # set pytorch seed for
reproducibility torch.manual_seed(seed_to_int(args.seed)) - if args.train_data_dir and not args.input_image: + if args.train_data_dir and not args.input_image and not args.input_folder: dataset = get_dataset_from_dataroot( args.train_data_dir, image_column=args.image_column, save_path=args.dataset_save_path, ) - elif args.dataset_name and not args.input_image: + elif args.dataset_name and not args.input_image and not args.input_folder: dataset = load_dataset(args.dataset_name)["train"] - elif args.input_image: + elif args.input_image and not args.input_folder: + # Create dataset from single input image dataset = Dataset.from_dict({"image": [args.input_image]}).cast_column("image", Image()) + if args.input_folder: + # Create dataset from input folder + extensions = ["jpg", "jpeg", "png", "webp"] + exclude_folders = args.exclude_folders.split(',') if args.exclude_folders else [] + + filepaths = [] + for root, dirs, files in os.walk(args.input_folder, followlinks=True): + # Resolve symbolic link to actual path and exclude based on actual path + resolved_root = os.path.realpath(root) + for exclude_folder in exclude_folders: + if exclude_folder in resolved_root: + dirs[:] = [] + break + for file in files: + if file.lower().endswith(tuple(extensions)): + filepaths.append(os.path.join(root, file)) + + if not filepaths: + print(f"No images with extensions {extensions} found in {args.input_folder}.") + sys.exit(1) + + dataset = Dataset.from_dict({"image": filepaths}).cast_column("image", Image()) + if args.vae_path and args.taming_model_path: raise Exception("You can't pass vae_path and taming args at the same time.") if args.vae_path: accelerator.print("Loading Muse VQGanVAE") - vae = VQGanVAE(dim=args.dim, vq_codebook_size=args.vq_codebook_size).to(accelerator.device) + vae = VQGanVAE(dim=args.dim, vq_codebook_size=args.vq_codebook_size, vq_codebook_dim=args.vq_codebook_dim).to( + accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}" + ) + + if args.latest_checkpoint: + accelerator.print("Finding latest checkpoint...") + orig_vae_path = args.vae_path + + + if os.path.isfile(args.vae_path) or '.pt' in args.vae_path: + # If args.vae_path is a file, split it into directory and filename + args.vae_path, _ = os.path.split(args.vae_path) + + checkpoint_files = glob.glob(os.path.join(args.vae_path, "vae.*.pt")) + if checkpoint_files: + latest_checkpoint_file = max(checkpoint_files,key=lambda x: int(re.search(r'vae\.(\d+)\.pt$', x).group(1)) if not x.endswith('ema.pt') else -1) + + # Check if latest checkpoint is empty or unreadable + if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(latest_checkpoint_file, os.R_OK): + accelerator.print(f"Warning: latest checkpoint {latest_checkpoint_file} is empty or unreadable.") + if len(checkpoint_files) > 1: + # Use the second last checkpoint as a fallback + latest_checkpoint_file = max(checkpoint_files[:-1], key=lambda x: int(re.search(r'vae\.(\d+)\.pt$', x).group(1)) if not x.endswith('ema.pt') else -1) + accelerator.print("Using second last checkpoint: ", latest_checkpoint_file) + else: + accelerator.print("No usable checkpoint found.") + elif latest_checkpoint_file != orig_vae_path: + accelerator.print("Resuming VAE from latest checkpoint: ", latest_checkpoint_file) + else: + accelerator.print("Using checkpoint specified in vae_path: ", orig_vae_path) + + args.vae_path = latest_checkpoint_file + else: + accelerator.print("No checkpoints found in directory: ", args.vae_path) + else: + accelerator.print("Resuming VAE from: ", args.vae_path) - 
accelerator.print("Resuming VAE from: ", args.vae_path) - vae.load(args.vae_path) # you will want to load the exponentially moving averaged VAE + vae.load(args.vae_path) elif args.taming_model_path: print("Loading Taming VQGanVAE") @@ -190,27 +380,98 @@ def main(): ) args.num_tokens = vae.codebook_size args.seq_len = vae.get_encoded_fmap_size(args.image_size) ** 2 - vae = vae.to(accelerator.device) + vae = vae.to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}") # then you plug the vae and transformer into your MaskGit as so dataset = ImageDataset( dataset, args.image_size, image_column=args.image_column, - center_crop=not args.no_center_crop, + center_crop=True if not args.no_center_crop and not args.random_crop else False, flip=not args.no_flip, + random_crop=args.random_crop if args.random_crop else False ) - image_id = 0 if not args.random_image else random.randint(0, len(dataset)) + if args.input_image and not args.input_folder: + image_id = 0 if not args.random_image else random.randint(0, len(dataset)) + + os.makedirs(f"{args.results_dir}/outputs", exist_ok=True) + + save_image(dataset[image_id], f"{args.results_dir}/outputs/input.{str(args.input_image).split('.')[-1]}") + + _, ids, _ = vae.encode(dataset[image_id][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")) + recon = vae.decode_from_ids(ids) + save_image(recon, f"{args.results_dir}/outputs/output.{str(args.input_image).split('.')[-1]}") + + if not args.input_image and not args.input_folder: + image_id = 0 if not args.random_image else random.randint(0, len(dataset)) + + os.makedirs(f"{args.results_dir}/outputs", exist_ok=True) + + save_image(dataset[image_id], f"{args.results_dir}/outputs/input.png") + + _, ids, _ = vae.encode(dataset[image_id][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")) + recon = vae.decode_from_ids(ids) + save_image(recon, f"{args.results_dir}/outputs/output.png") + + + if args.input_folder: + # Create output directory and save input images and reconstructions as grids + output_dir = os.path.join(args.results_dir, "outputs", os.path.basename(args.input_folder)) + os.makedirs(output_dir, exist_ok=True) + + for i in tqdm(range(len(dataset))): + retries = 0 + while True: + try: + save_image(dataset[i], f"{output_dir}/input.png") + + _, ids, _ = vae.encode(dataset[i][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")) + recon = vae.decode_from_ids(ids) + save_image(recon, f"{output_dir}/output.png") + + # Load input and output images + input_image = PIL.Image.open(f"{output_dir}/input.png") + output_image = PIL.Image.open(f"{output_dir}/output.png") + + # Create horizontal grid with input and output images + grid_image = PIL.Image.new('RGB', (input_image.width + output_image.width, input_image.height)) + grid_image.paste(input_image, (0, 0)) + grid_image.paste(output_image, (input_image.width, 0)) + + # Save grid + now = datetime.now().strftime("%m-%d-%Y_%H-%M-%S") + hash = hashlib.sha1(input_image.tobytes()).hexdigest() + + filename = f"{hash}_{now}-{os.path.basename(args.vae_path)}.png" + grid_image.save(f"{output_dir}/{filename}") + + # Remove input and output images after the grid was made. 
+                    os.remove(f"{output_dir}/input.png")
+                    os.remove(f"{output_dir}/output.png")
+
+                    del _
+                    del ids
+                    del recon
+
+                    torch.cuda.empty_cache()
+                    torch.cuda.ipc_collect()
+
+                    break  # Exit the retry loop if there were no errors

-    os.makedirs(f"{args.results_dir}/outputs", exist_ok=True)
+                except RuntimeError as e:
+                    if "out of memory" in str(e) and retries < args.max_retries:
+                        retries += 1
+                        #print(f"Out of Memory. Retry #{retries}")
+                        torch.cuda.empty_cache()
+                        torch.cuda.ipc_collect()
+                        continue  # Retry the loop

-    save_image(dataset[image_id], f"{args.results_dir}/outputs/input.png")
+                    else:
+                        print(f"Skipping image {i} after {retries} retries due to out of memory error")
+                        break  # Exit the retry loop after too many retries

-    _, ids, _ = vae.encode(dataset[image_id][None].to(accelerator.device))
-    recon = vae.decode_from_ids(ids)
-    save_image(recon, f"{args.results_dir}/outputs/output.png")

 if __name__ == "__main__":
-    main()
+    main()
\ No newline at end of file

From c3ef3fa233ab8f4e382620538f3317f966cc85cf Mon Sep 17 00:00:00 2001
From: Alejandro Gil
Date: Sun, 11 Jun 2023 19:46:14 -0700
Subject: [PATCH 42/62] Fixed bug with PIL that would sometimes just crash
 because of the decompression bomb error with some images on the dataset
 during training. This can be ignored and bypassed by setting
 "MAX_IMAGE_PIXELS = None" on PIL.

---
 muse_maskgit_pytorch/dataset.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/muse_maskgit_pytorch/dataset.py b/muse_maskgit_pytorch/dataset.py
index 21409eb..b185e05 100644
--- a/muse_maskgit_pytorch/dataset.py
+++ b/muse_maskgit_pytorch/dataset.py
@@ -28,6 +28,7 @@
 from io import BytesIO

 ImageFile.LOAD_TRUNCATED_IMAGES = True
+pImage.MAX_IMAGE_PIXELS = None


 class ImageDataset(Dataset):

From f2bd0de684bf9a872ddd1f9aa71838642dcbd57d Mon Sep 17 00:00:00 2001
From: Alejandro Gil
Date: Sun, 11 Jun 2023 20:33:12 -0700
Subject: [PATCH 43/62] Added some extra text for clarification when trying to
 find the latest checkpoint for the maskgit on a folder.
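
Both train scripts resolve the "latest" checkpoint the same way, so the
clearer messages refer to the same lookup. As a rough, simplified sketch of
that selection (the helper name here is illustrative; the real code also
falls back to the second-newest file when the newest one is empty or
unreadable):

    import glob, os, re
    from typing import Optional

    def latest_checkpoint(ckpt_dir: str, prefix: str = "maskgit") -> Optional[str]:
        # e.g. maskgit.500.pt, maskgit.1000.pt; "ema" snapshots are ranked last
        files = glob.glob(os.path.join(ckpt_dir, f"{prefix}.*.pt"))

        def step(path: str) -> int:
            match = re.search(rf"{prefix}\.(\d+)\.pt$", path)
            return int(match.group(1)) if match and not path.endswith("ema.pt") else -1

        return max(files, key=step) if files else None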
--- .../trainers/maskgit_trainer.py | 28 +++++++++---------- train_muse_maskgit.py | 10 +++---- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/muse_maskgit_pytorch/trainers/maskgit_trainer.py b/muse_maskgit_pytorch/trainers/maskgit_trainer.py index d985233..9219b16 100644 --- a/muse_maskgit_pytorch/trainers/maskgit_trainer.py +++ b/muse_maskgit_pytorch/trainers/maskgit_trainer.py @@ -129,7 +129,7 @@ def save_validation_images( save_dir = self.results_dir.joinpath("MaskGit") save_dir.mkdir(exist_ok=True, parents=True) - save_file = save_dir.joinpath(f"maskgit_{step:04d}.png") + save_file = save_dir.joinpath(f"maskgit_{step}.png") if self.accelerator.is_main_process: save_image(images, save_file, "png") @@ -141,9 +141,9 @@ def train(self): self.model.train() if self.accelerator.is_main_process: - proc_label = f"[P{self.accelerator.process_index:03d}][Master]" + proc_label = f"[P{self.accelerator.process_index}][Master]" else: - proc_label = f"[P{self.accelerator.process_index:03d}][Worker]" + proc_label = f"[P{self.accelerator.process_index}][Worker]" # logs for epoch in range(self.num_epochs): @@ -174,7 +174,7 @@ def train(self): logs = {"loss": train_loss, "lr": self.lr_scheduler.get_last_lr()[0]} if self.on_tpu: - self.accelerator.print(f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: " + self.accelerator.print(f"\n[E{epoch + 1}][{steps}]{proc_label}: " f"maskgit loss: {logs['loss']} - lr: {logs['lr']}") else: self.training_bar.update() @@ -184,7 +184,7 @@ def train(self): self.accelerator.log(logs, step=steps) if not (steps % self.save_model_every): - self.accelerator.print(f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: " + self.accelerator.print(f"\n[E{epoch + 1}][{steps}]{proc_label}: " f"saving model to {self.results_dir}") state_dict = self.accelerator.unwrap_model(self.model).state_dict() @@ -206,7 +206,7 @@ def train(self): if self.use_ema: self.accelerator.print( - f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: " + f"\n[E{epoch + 1}][{steps}]{proc_label}: " f"saving EMA model to {self.results_dir}") ema_state_dict = self.accelerator.unwrap_model(self.ema_model).state_dict() @@ -241,14 +241,14 @@ def train(self): self.validation_prompts, steps, cond_image=cond_image ) if self.on_tpu: - self.accelerator.print(f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: saved to {saved_image}") + self.accelerator.print(f"\n[E{epoch + 1}][{steps}]{proc_label}: saved to {saved_image}") else: self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}: " f"saved to {saved_image}") if met is not None and not (steps % self.log_metrics_every): if self.on_tpu: - self.accelerator.print(f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: metrics:") + self.accelerator.print(f"\n[E{epoch + 1}][{steps}]{proc_label}: metrics:") else: self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}: metrics:") @@ -256,7 +256,7 @@ def train(self): if self.num_train_steps > 0 and self.steps >= int(self.steps.item()): if self.on_tpu: - self.accelerator.print(f"\n[E{epoch + 1}][{int(self.steps.item()):05d}]{proc_label}" + self.accelerator.print(f"\n[E{epoch + 1}][{int(self.steps.item())}]{proc_label}" f"[STOP EARLY]: Stopping training early...") else: self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}" @@ -264,7 +264,7 @@ def train(self): break # loop complete, save final model - self.accelerator.print(f"\n[E{epoch + 1}][{steps:05d}]{proc_label}[FINAL]: saving model to {self.results_dir}") + self.accelerator.print(f"\n[E{epoch + 1}][{steps}]{proc_label}[FINAL]: saving model to {self.results_dir}") 
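+        # unwrap the model from the accelerator so the saved state dict keeps plain key names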
state_dict = self.accelerator.unwrap_model(self.model).state_dict() maskgit_save_name = "maskgit_superres" if self.model.cond_image_size else "maskgit" file_name = ( @@ -284,7 +284,7 @@ def train(self): if self.use_ema: self.accelerator.print( - f"\n[{steps:05d}]{proc_label}[FINAL]: saving EMA model to {self.results_dir}" + f"\n[{steps}]{proc_label}[FINAL]: saving EMA model to {self.results_dir}" ) ema_state_dict = self.accelerator.unwrap_model(self.ema_model).state_dict() file_name = ( @@ -309,9 +309,9 @@ def train(self): cond_image = F.interpolate(imgs, self.model.cond_image_size, mode="nearest") steps = int(self.steps.item()) + 1 # get the final step count, plus one - self.accelerator.print(f"\n[{steps:05d}]{proc_label}: Logging validation images") + self.accelerator.print(f"\n[{steps}]{proc_label}: Logging validation images") saved_image = self.save_validation_images(self.validation_prompts, steps, cond_image=cond_image) - self.accelerator.print(f"\n[{steps:05d}]{proc_label}: saved to {saved_image}") + self.accelerator.print(f"\n[{steps}]{proc_label}: saved to {saved_image}") if met is not None and not (steps % self.log_metrics_every): - self.accelerator.print(f"\n[{steps:05d}]{proc_label}: metrics:") + self.accelerator.print(f"\n[{steps}]{proc_label}: metrics:") diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py index 805c3b5..276b036 100644 --- a/train_muse_maskgit.py +++ b/train_muse_maskgit.py @@ -655,21 +655,21 @@ def main(): # Check if latest checkpoint is empty or unreadable if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(latest_checkpoint_file, os.R_OK): - accelerator.print(f"Warning: latest checkpoint {latest_checkpoint_file} is empty or unreadable.") + accelerator.print(f"Warning: latest MaskGit checkpoint {latest_checkpoint_file} is empty or unreadable.") if len(checkpoint_files) > 1: # Use the second last checkpoint as a fallback latest_checkpoint_file = max(checkpoint_files[:-1], key=lambda x: int(re.search(r'maskgit\.(\d+)\.pt$', x).group(1)) if not x.endswith('ema.pt') else -1) - accelerator.print("Using second last checkpoint: ", latest_checkpoint_file) + accelerator.print("Using second last MaskGit checkpoint: ", latest_checkpoint_file) else: - accelerator.print("No usable checkpoint found.") + accelerator.print("No usable MaskGit checkpoint found.") elif latest_checkpoint_file != orig_vae_path: accelerator.print("Resuming MaskGit from latest checkpoint: ", latest_checkpoint_file) else: - accelerator.print("Using checkpoint specified in resume_path: ", orig_vae_path) + accelerator.print("Using MaskGit checkpoint specified in resume_path: ", orig_vae_path) args.resume_path = latest_checkpoint_file else: - accelerator.print("No checkpoints found in directory: ", args.resume_path) + accelerator.print("No MaskGit checkpoints found in directory: ", args.resume_path) else: accelerator.print("Resuming MaskGit from: ", args.resume_path) From 51d6632364af6c08a8fca2ee8851b9ed66af8319 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Sat, 10 Jun 2023 12:14:26 +0100 Subject: [PATCH 44/62] start making a way to compare attn implementations --- .vscode/launch.json | 11 +++++++ attn/ein_attn.py | 78 +++++++++++++++++++++++++++++++++++++++++++++ attn/sdp_attn.py | 78 +++++++++++++++++++++++++++++++++++++++++++++ attn_test.py | 54 +++++++++++++++++++++++++++++++ 4 files changed, 221 insertions(+) create mode 100644 .vscode/launch.json create mode 100644 attn/ein_attn.py create mode 100644 attn/sdp_attn.py create mode 100644 attn_test.py diff --git 
a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..c7264c4 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,11 @@ +{ + "configurations": [ + { + "name": "Python: Verify attn impl equivalence", + "type": "python", + "request": "launch", + "module": "attn_test", + "justMyCode": true + } + ] +} \ No newline at end of file diff --git a/attn/ein_attn.py b/attn/ein_attn.py new file mode 100644 index 0000000..4ffa6fb --- /dev/null +++ b/attn/ein_attn.py @@ -0,0 +1,78 @@ +from torch import einsum, nn +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + +# helpers +def exists(val): + return val is not None + +def l2norm(t): + return F.normalize(t, dim=-1) + +class LayerNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.ones(dim)) + self.register_buffer("beta", torch.zeros(dim)) + + def forward(self, x): + return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) + +# TODO: make faster +class Attention(nn.Module): + def __init__(self, dim, dim_head=64, heads=8, cross_attend=False, scale=8): + super().__init__() + self.scale = scale + self.heads = heads + inner_dim = dim_head * heads + + self.cross_attend = cross_attend + self.norm = LayerNorm(dim) + + self.null_kv = nn.Parameter(torch.randn(2, heads, 1, dim_head)) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + + self.q_scale = nn.Parameter(torch.ones(dim_head)) + self.k_scale = nn.Parameter(torch.ones(dim_head)) + + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, context=None, context_mask=None): + assert not (exists(context) ^ self.cross_attend) + + h = self.heads + x = self.norm(x) + + kv_input = context if self.cross_attend else x + + q, k, v = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim=-1)) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + nk, nv = self.null_kv + nk, nv = map(lambda t: repeat(t, "h 1 d -> b h 1 d", b=x.shape[0]), (nk, nv)) + + k = torch.cat((nk, k), dim=-2) + v = torch.cat((nv, v), dim=-2) + + q, k = map(l2norm, (q, k)) + q = q * self.q_scale + k = k * self.k_scale + + sim = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale + + if exists(context_mask): + context_mask = rearrange(context_mask, "b j -> b 1 1 j") + context_mask = F.pad(context_mask, (1, 0), value=True) + + mask_value = -torch.finfo(sim.dtype).max + sim = sim.masked_fill(~context_mask, mask_value) + + attn = sim.softmax(dim=-1) + out = einsum("b h i j, b h j d -> b h i d", attn, v) + + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) \ No newline at end of file diff --git a/attn/sdp_attn.py b/attn/sdp_attn.py new file mode 100644 index 0000000..58a1af5 --- /dev/null +++ b/attn/sdp_attn.py @@ -0,0 +1,78 @@ +from torch import einsum, nn +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + +# helpers +def exists(val): + return val is not None + +def l2norm(t): + return F.normalize(t, dim=-1) + +class LayerNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.ones(dim)) + self.register_buffer("beta", torch.zeros(dim)) + + def forward(self, x): + return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) + +# TODO: change this to use torch sdp attn +class Attention(nn.Module): + def __init__(self, dim, dim_head=64, heads=8, cross_attend=False, scale=8): + super().__init__() + self.scale = scale + self.heads = heads 
+ inner_dim = dim_head * heads + + self.cross_attend = cross_attend + self.norm = LayerNorm(dim) + + self.null_kv = nn.Parameter(torch.randn(2, heads, 1, dim_head)) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + + self.q_scale = nn.Parameter(torch.ones(dim_head)) + self.k_scale = nn.Parameter(torch.ones(dim_head)) + + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, context=None, context_mask=None): + assert not (exists(context) ^ self.cross_attend) + + h = self.heads + x = self.norm(x) + + kv_input = context if self.cross_attend else x + + q, k, v = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim=-1)) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + nk, nv = self.null_kv + nk, nv = map(lambda t: repeat(t, "h 1 d -> b h 1 d", b=x.shape[0]), (nk, nv)) + + k = torch.cat((nk, k), dim=-2) + v = torch.cat((nv, v), dim=-2) + + q, k = map(l2norm, (q, k)) + q = q * self.q_scale + k = k * self.k_scale + + sim = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale + + if exists(context_mask): + context_mask = rearrange(context_mask, "b j -> b 1 1 j") + context_mask = F.pad(context_mask, (1, 0), value=True) + + mask_value = -torch.finfo(sim.dtype).max + sim = sim.masked_fill(~context_mask, mask_value) + + attn = sim.softmax(dim=-1) + out = einsum("b h i j, b h j d -> b h i d", attn, v) + + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) \ No newline at end of file diff --git a/attn_test.py b/attn_test.py new file mode 100644 index 0000000..ea95c21 --- /dev/null +++ b/attn_test.py @@ -0,0 +1,54 @@ +from attn.ein_attn import Attention as EinAttn +from attn.sdp_attn import Attention as SDPAttn +import torch +from torch import FloatTensor, BoolTensor, manual_seed, randn, arange, allclose, no_grad + +device = torch.device('cuda') +dtype = torch.float32 +seed = 42 + +# realistically this would be 320 in stable-diffusion, but I'm going smaller during testing +vision_dim = 64 + +attn_init_params = { + 'dim': vision_dim, + 'dim_head': 64, + # realistically this would be at least 5 + 'heads': 2, + 'cross_attend': True, + 'scale': 8, +} + +with no_grad(): + # seed RNG before we initialize any layers, so that both will end up with same params + manual_seed(seed) + ein_attn = EinAttn(**attn_init_params).to(device, dtype).eval() + manual_seed(seed) + sdp_attn = SDPAttn(**attn_init_params).to(device, dtype).eval() + + batch_size = 2 + + # realistically this would be 64**2 in stable-diffusion + vision_tokens = 32**2 # 1024 + + # generate rand on-CPU for cross-platform determinism of results + x: FloatTensor = randn(batch_size, vision_tokens, vision_dim, dtype=dtype).to(device) + + text_tokens = 16 # CLIP would be 77 + # there's no reason why these would **have** to be the same (in stable-diffusion text_dim is 768) + # but lucid didn't expose any separate param for customizing the cross attention input dim. + # easily fixed, but whatever I'll work with what's there. + text_dim = vision_dim + context: FloatTensor = randn(batch_size, text_tokens, text_dim, dtype=dtype).to(device) + + # attend to just the first two tokens in each text condition (e.g. 
if both were uncond, so [BOS, EOS] followed by PAD tokens) + context_mask: BoolTensor = (arange(text_tokens, device=device) < 2).expand(batch_size, -1) + + ein_result: FloatTensor = ein_attn.forward(x, context, context_mask) + sdp_result: FloatTensor = sdp_attn.forward(x, context, context_mask) + + # default relative and absolute tolerance + rtol=1e-5 + atol=1e-8 + assert allclose(ein_result, sdp_result, rtol=rtol, atol=atol), f"looks like attention implementations weren't equivalent, to tolerance rtol={rtol}, atol={atol}" + print(f'attention implementations returned equivalent result, to tolerance rtol={rtol}, atol={atol}') \ No newline at end of file From dfef69a2ac3db05b946639ecab4ab72dedf3d11f Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Sat, 10 Jun 2023 13:56:57 +0100 Subject: [PATCH 45/62] compute attn similarities using typical scale factor, by fusing the extra scaling into q_scale --- attn/sdp_attn.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/attn/sdp_attn.py b/attn/sdp_attn.py index 58a1af5..6577cab 100644 --- a/attn/sdp_attn.py +++ b/attn/sdp_attn.py @@ -23,7 +23,6 @@ def forward(self, x): class Attention(nn.Module): def __init__(self, dim, dim_head=64, heads=8, cross_attend=False, scale=8): super().__init__() - self.scale = scale self.heads = heads inner_dim = dim_head * heads @@ -35,7 +34,9 @@ def __init__(self, dim, dim_head=64, heads=8, cross_attend=False, scale=8): self.to_q = nn.Linear(dim, inner_dim, bias=False) self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) - self.q_scale = nn.Parameter(torch.ones(dim_head)) + self.typical_scale = dim_head ** -.5 + scale_ratio = scale/self.typical_scale + self.q_scale = nn.Parameter(torch.full((dim_head,), scale_ratio)) self.k_scale = nn.Parameter(torch.ones(dim_head)) self.to_out = nn.Linear(inner_dim, dim, bias=False) @@ -62,7 +63,7 @@ def forward(self, x, context=None, context_mask=None): q = q * self.q_scale k = k * self.k_scale - sim = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale + sim = einsum("b h i d, b h j d -> b h i j", q, k) * self.typical_scale if exists(context_mask): context_mask = rearrange(context_mask, "b j -> b 1 1 j") From 9a288b782366979d0817c431e04569ea50e7d090 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Sat, 10 Jun 2023 13:57:57 +0100 Subject: [PATCH 46/62] prefer matmul over einsum --- attn/sdp_attn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/attn/sdp_attn.py b/attn/sdp_attn.py index 6577cab..eb56e3e 100644 --- a/attn/sdp_attn.py +++ b/attn/sdp_attn.py @@ -63,7 +63,7 @@ def forward(self, x, context=None, context_mask=None): q = q * self.q_scale k = k * self.k_scale - sim = einsum("b h i d, b h j d -> b h i j", q, k) * self.typical_scale + sim = q @ k.transpose(-2, -1) * self.typical_scale if exists(context_mask): context_mask = rearrange(context_mask, "b j -> b 1 1 j") @@ -73,7 +73,7 @@ def forward(self, x, context=None, context_mask=None): sim = sim.masked_fill(~context_mask, mask_value) attn = sim.softmax(dim=-1) - out = einsum("b h i j, b h j d -> b h i d", attn, v) + out = attn @ v out = rearrange(out, "b h n d -> b n (h d)") return self.to_out(out) \ No newline at end of file From 5246f215e384013765838f1e250ab0c1ddbeadf2 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Sat, 10 Jun 2023 14:16:56 +0100 Subject: [PATCH 47/62] sdp attn --- attn/sdp_attn.py | 26 +++++++++----------------- attn_test.py | 3 ++- 2 files changed, 11 insertions(+), 18 deletions(-) diff --git a/attn/sdp_attn.py b/attn/sdp_attn.py index eb56e3e..97d76c1 
100644 --- a/attn/sdp_attn.py +++ b/attn/sdp_attn.py @@ -1,11 +1,9 @@ -from torch import einsum, nn +from torch import einsum, nn, FloatTensor import torch import torch.nn.functional as F +from torch.nn.functional import scaled_dot_product_attention from einops import rearrange, repeat - -# helpers -def exists(val): - return val is not None +from typing import Optional def l2norm(t): return F.normalize(t, dim=-1) @@ -34,15 +32,15 @@ def __init__(self, dim, dim_head=64, heads=8, cross_attend=False, scale=8): self.to_q = nn.Linear(dim, inner_dim, bias=False) self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) - self.typical_scale = dim_head ** -.5 - scale_ratio = scale/self.typical_scale + typical_scale = dim_head ** -.5 + scale_ratio = scale/typical_scale self.q_scale = nn.Parameter(torch.full((dim_head,), scale_ratio)) self.k_scale = nn.Parameter(torch.ones(dim_head)) self.to_out = nn.Linear(inner_dim, dim, bias=False) - def forward(self, x, context=None, context_mask=None): - assert not (exists(context) ^ self.cross_attend) + def forward(self, x: FloatTensor, context: Optional[FloatTensor]=None, context_mask=None): + assert (context is None) != self.cross_attend h = self.heads x = self.norm(x) @@ -63,17 +61,11 @@ def forward(self, x, context=None, context_mask=None): q = q * self.q_scale k = k * self.k_scale - sim = q @ k.transpose(-2, -1) * self.typical_scale - - if exists(context_mask): + if context_mask is not None: context_mask = rearrange(context_mask, "b j -> b 1 1 j") context_mask = F.pad(context_mask, (1, 0), value=True) - mask_value = -torch.finfo(sim.dtype).max - sim = sim.masked_fill(~context_mask, mask_value) - - attn = sim.softmax(dim=-1) - out = attn @ v + out: FloatTensor = scaled_dot_product_attention(q, k, v, context_mask) out = rearrange(out, "b h n d -> b n (h d)") return self.to_out(out) \ No newline at end of file diff --git a/attn_test.py b/attn_test.py index ea95c21..e99177e 100644 --- a/attn_test.py +++ b/attn_test.py @@ -43,12 +43,13 @@ # attend to just the first two tokens in each text condition (e.g. 
if both were uncond, so [BOS, EOS] followed by PAD tokens) context_mask: BoolTensor = (arange(text_tokens, device=device) < 2).expand(batch_size, -1) + # context_mask = None ein_result: FloatTensor = ein_attn.forward(x, context, context_mask) sdp_result: FloatTensor = sdp_attn.forward(x, context, context_mask) # default relative and absolute tolerance rtol=1e-5 - atol=1e-8 + atol=5e-7 assert allclose(ein_result, sdp_result, rtol=rtol, atol=atol), f"looks like attention implementations weren't equivalent, to tolerance rtol={rtol}, atol={atol}" print(f'attention implementations returned equivalent result, to tolerance rtol={rtol}, atol={atol}') \ No newline at end of file From 000c2da73d5de62c7a8859ee56354b169c1c88ca Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Sat, 10 Jun 2023 15:19:37 +0100 Subject: [PATCH 48/62] xformers attn working, so long as length of mask plus null token is a multiple of 8 --- .vscode/launch.json | 2 +- attn/sdp_attn.py | 1 - attn/xformers_attn.py | 74 +++++++++++++++++++++++++++++++++++++++++++ attn_test.py | 18 +++++++---- setup.py | 3 +- 5 files changed, 89 insertions(+), 9 deletions(-) create mode 100644 attn/xformers_attn.py diff --git a/.vscode/launch.json b/.vscode/launch.json index c7264c4..2537580 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -5,7 +5,7 @@ "type": "python", "request": "launch", "module": "attn_test", - "justMyCode": true + "justMyCode": false } ] } \ No newline at end of file diff --git a/attn/sdp_attn.py b/attn/sdp_attn.py index 97d76c1..da65da6 100644 --- a/attn/sdp_attn.py +++ b/attn/sdp_attn.py @@ -17,7 +17,6 @@ def __init__(self, dim): def forward(self, x): return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) -# TODO: change this to use torch sdp attn class Attention(nn.Module): def __init__(self, dim, dim_head=64, heads=8, cross_attend=False, scale=8): super().__init__() diff --git a/attn/xformers_attn.py b/attn/xformers_attn.py new file mode 100644 index 0000000..c7f1422 --- /dev/null +++ b/attn/xformers_attn.py @@ -0,0 +1,74 @@ +from torch import nn, FloatTensor +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from typing import Optional +from xformers.ops import memory_efficient_attention + +def l2norm(t): + return F.normalize(t, dim=-1) + +class LayerNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.ones(dim)) + self.register_buffer("beta", torch.zeros(dim)) + + def forward(self, x): + return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) + +class Attention(nn.Module): + def __init__(self, dim, dim_head=64, heads=8, cross_attend=False, scale=8): + super().__init__() + self.heads = heads + inner_dim = dim_head * heads + + self.cross_attend = cross_attend + self.norm = LayerNorm(dim) + + self.null_kv = nn.Parameter(torch.randn(2, heads, 1, dim_head)) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + + typical_scale = dim_head ** -.5 + scale_ratio = scale/typical_scale + self.q_scale = nn.Parameter(torch.full((dim_head,), scale_ratio)) + self.k_scale = nn.Parameter(torch.ones(dim_head)) + + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x: FloatTensor, context: Optional[FloatTensor]=None, context_mask=None): + assert (context is None) != self.cross_attend + + h = self.heads + x = self.norm(x) + + kv_input = context if self.cross_attend else x + + q, k, v = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim=-1)) + + q, k, v = 
map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q, k, v)) + + nk, nv = self.null_kv + nk, nv = map(lambda t: repeat(t, "h 1 d -> b 1 h d", b=x.shape[0]), (nk, nv)) + + k = torch.cat((nk, k), dim=-3) + v = torch.cat((nv, v), dim=-3) + + q, k = map(l2norm, (q, k)) + q = q * self.q_scale + k = k * self.k_scale + + if context_mask is None: + attn_bias = None + else: + context_mask = F.pad(context_mask, (1, 0), value=True) + context_mask = rearrange(context_mask, "b j -> b 1 1 j") + attn_bias = torch.where(context_mask == True, 0., -10000.) + attn_bias = attn_bias.expand(-1, h, q.size(1), -1) + + out: FloatTensor = memory_efficient_attention(q, k, v, attn_bias) + + out = rearrange(out, "b n h d -> b n (h d)") + return self.to_out(out) \ No newline at end of file diff --git a/attn_test.py b/attn_test.py index e99177e..14b2475 100644 --- a/attn_test.py +++ b/attn_test.py @@ -1,7 +1,9 @@ from attn.ein_attn import Attention as EinAttn from attn.sdp_attn import Attention as SDPAttn +from attn.xformers_attn import Attention as XformersAttn import torch from torch import FloatTensor, BoolTensor, manual_seed, randn, arange, allclose, no_grad +from torch.backends.cuda import sdp_kernel device = torch.device('cuda') dtype = torch.float32 @@ -23,8 +25,10 @@ # seed RNG before we initialize any layers, so that both will end up with same params manual_seed(seed) ein_attn = EinAttn(**attn_init_params).to(device, dtype).eval() + # manual_seed(seed) + # sdp_attn = SDPAttn(**attn_init_params).to(device, dtype).eval() manual_seed(seed) - sdp_attn = SDPAttn(**attn_init_params).to(device, dtype).eval() + xfo_attn = XformersAttn(**attn_init_params).to(device, dtype).eval() batch_size = 2 @@ -34,7 +38,7 @@ # generate rand on-CPU for cross-platform determinism of results x: FloatTensor = randn(batch_size, vision_tokens, vision_dim, dtype=dtype).to(device) - text_tokens = 16 # CLIP would be 77 + text_tokens = 15 # CLIP would be 77 # there's no reason why these would **have** to be the same (in stable-diffusion text_dim is 768) # but lucid didn't expose any separate param for customizing the cross attention input dim. # easily fixed, but whatever I'll work with what's there. @@ -42,14 +46,16 @@ context: FloatTensor = randn(batch_size, text_tokens, text_dim, dtype=dtype).to(device) # attend to just the first two tokens in each text condition (e.g. 
if both were uncond, so [BOS, EOS] followed by PAD tokens) - context_mask: BoolTensor = (arange(text_tokens, device=device) < 2).expand(batch_size, -1) - # context_mask = None + context_mask: BoolTensor = (arange(text_tokens, device=device) < 2).expand(batch_size, -1).contiguous() ein_result: FloatTensor = ein_attn.forward(x, context, context_mask) - sdp_result: FloatTensor = sdp_attn.forward(x, context, context_mask) + # with sdp_kernel(enable_math=False): + # sdp_result: FloatTensor = sdp_attn.forward(x, context, context_mask) + xfo_attn: FloatTensor = xfo_attn.forward(x, context, context_mask) # default relative and absolute tolerance rtol=1e-5 atol=5e-7 - assert allclose(ein_result, sdp_result, rtol=rtol, atol=atol), f"looks like attention implementations weren't equivalent, to tolerance rtol={rtol}, atol={atol}" + # assert allclose(ein_result, sdp_result, rtol=rtol, atol=atol), f"looks like attention implementations weren't equivalent, to tolerance rtol={rtol}, atol={atol}" + assert allclose(ein_result, xfo_attn, rtol=rtol, atol=atol), f"looks like attention implementations weren't equivalent, to tolerance rtol={rtol}, atol={atol}" print(f'attention implementations returned equivalent result, to tolerance rtol={rtol}, atol={atol}') \ No newline at end of file diff --git a/setup.py b/setup.py index f55dfee..c89a3d6 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,8 @@ "tqdm-loggable", "vector-quantize-pytorch>=0.10.14", "lion-pytorch", - "omegaconf" + "omegaconf", + "xformers>=0.0.20", ], classifiers=[ "Development Status :: 4 - Beta", From ba502c3949ed0c98b2e77f935ebb560eeddaec9e Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Sat, 10 Jun 2023 15:23:28 +0100 Subject: [PATCH 49/62] support arbitrary context lengths --- attn_test.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/attn_test.py b/attn_test.py index 14b2475..46f92a8 100644 --- a/attn_test.py +++ b/attn_test.py @@ -4,6 +4,7 @@ import torch from torch import FloatTensor, BoolTensor, manual_seed, randn, arange, allclose, no_grad from torch.backends.cuda import sdp_kernel +from torch.nn.functional import pad device = torch.device('cuda') dtype = torch.float32 @@ -38,7 +39,7 @@ # generate rand on-CPU for cross-platform determinism of results x: FloatTensor = randn(batch_size, vision_tokens, vision_dim, dtype=dtype).to(device) - text_tokens = 15 # CLIP would be 77 + text_tokens = 16 # CLIP would be 77 # there's no reason why these would **have** to be the same (in stable-diffusion text_dim is 768) # but lucid didn't expose any separate param for customizing the cross attention input dim. # easily fixed, but whatever I'll work with what's there. @@ -48,10 +49,22 @@ # attend to just the first two tokens in each text condition (e.g. 
if both were uncond, so [BOS, EOS] followed by PAD tokens) context_mask: BoolTensor = (arange(text_tokens, device=device) < 2).expand(batch_size, -1).contiguous() + # for xformers cutlassF kernel: masks are only supported for keys whose lengths are multiples of 8: + # https://gist.github.com/Birch-san/0c36d228e1d4b881a06d1c6e5289d569 + # so, we add whatever we feel like to the end of the key to extend it to a multiple of 8, + # and add "discard" tokens to the mask to get rid of the excess + # note: muse will add an extra "null" token to our context, so we'll account for that in advance + mask_length = context_mask.shape[-1] + 1 + extra_tokens_needed = 8 - (mask_length % 8) + # 0-pad mask to multiple of 8 tokens + xfo_context_mask = pad(context_mask, (0, extra_tokens_needed)) + # replicate-pad embedding to multiple of 8 tokens (mask will hide the extra tokens) + xfo_context = pad(context, (0, 0, 0, extra_tokens_needed,), 'replicate') + ein_result: FloatTensor = ein_attn.forward(x, context, context_mask) # with sdp_kernel(enable_math=False): # sdp_result: FloatTensor = sdp_attn.forward(x, context, context_mask) - xfo_attn: FloatTensor = xfo_attn.forward(x, context, context_mask) + xfo_attn: FloatTensor = xfo_attn.forward(x, xfo_context, xfo_context_mask) # default relative and absolute tolerance rtol=1e-5 From 8d7e020d46ef92c74753cddcbc18844919aa6420 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Sat, 10 Jun 2023 17:51:56 +0100 Subject: [PATCH 50/62] tidy / clarify --- attn/ein_attn.py | 1 - attn/sdp_attn.py | 6 ++++-- attn/xformers_attn.py | 6 ++++-- attn_test.py | 6 +++++- 4 files changed, 13 insertions(+), 6 deletions(-) diff --git a/attn/ein_attn.py b/attn/ein_attn.py index 4ffa6fb..e6dfd50 100644 --- a/attn/ein_attn.py +++ b/attn/ein_attn.py @@ -19,7 +19,6 @@ def __init__(self, dim): def forward(self, x): return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) -# TODO: make faster class Attention(nn.Module): def __init__(self, dim, dim_head=64, heads=8, cross_attend=False, scale=8): super().__init__() diff --git a/attn/sdp_attn.py b/attn/sdp_attn.py index da65da6..100f841 100644 --- a/attn/sdp_attn.py +++ b/attn/sdp_attn.py @@ -1,4 +1,4 @@ -from torch import einsum, nn, FloatTensor +from torch import nn, FloatTensor, BoolTensor import torch import torch.nn.functional as F from torch.nn.functional import scaled_dot_product_attention @@ -38,14 +38,16 @@ def __init__(self, dim, dim_head=64, heads=8, cross_attend=False, scale=8): self.to_out = nn.Linear(inner_dim, dim, bias=False) - def forward(self, x: FloatTensor, context: Optional[FloatTensor]=None, context_mask=None): + def forward(self, x: FloatTensor, context: Optional[FloatTensor]=None, context_mask: Optional[BoolTensor]=None): assert (context is None) != self.cross_attend h = self.heads + # TODO: you could fuse this layernorm with the linear that follows it, e.g. 
via TransformerEngine x = self.norm(x) kv_input = context if self.cross_attend else x + # TODO: to_q and to_kvs could be combined into one to_qkv q, k, v = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim=-1)) q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) diff --git a/attn/xformers_attn.py b/attn/xformers_attn.py index c7f1422..54f32c6 100644 --- a/attn/xformers_attn.py +++ b/attn/xformers_attn.py @@ -1,4 +1,4 @@ -from torch import nn, FloatTensor +from torch import nn, FloatTensor, BoolTensor import torch import torch.nn.functional as F from einops import rearrange, repeat @@ -38,14 +38,16 @@ def __init__(self, dim, dim_head=64, heads=8, cross_attend=False, scale=8): self.to_out = nn.Linear(inner_dim, dim, bias=False) - def forward(self, x: FloatTensor, context: Optional[FloatTensor]=None, context_mask=None): + def forward(self, x: FloatTensor, context: Optional[FloatTensor]=None, context_mask: Optional[BoolTensor]=None): assert (context is None) != self.cross_attend h = self.heads + # TODO: you could fuse this layernorm with the linear that follows it, e.g. via TransformerEngine x = self.norm(x) kv_input = context if self.cross_attend else x + # TODO: to_q and to_kvs could be combined into one to_qkv q, k, v = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim=-1)) q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q, k, v)) diff --git a/attn_test.py b/attn_test.py index 46f92a8..f2bdd47 100644 --- a/attn_test.py +++ b/attn_test.py @@ -26,6 +26,7 @@ # seed RNG before we initialize any layers, so that both will end up with same params manual_seed(seed) ein_attn = EinAttn(**attn_init_params).to(device, dtype).eval() + # commented-out scaled dot product attention because it didn't support flash attn, so we'll try with xformers instead. # manual_seed(seed) # sdp_attn = SDPAttn(**attn_init_params).to(device, dtype).eval() manual_seed(seed) @@ -39,6 +40,7 @@ # generate rand on-CPU for cross-platform determinism of results x: FloatTensor = randn(batch_size, vision_tokens, vision_dim, dtype=dtype).to(device) + # I've said text here simply as an example of something you could cross-attend to text_tokens = 16 # CLIP would be 77 # there's no reason why these would **have** to be the same (in stable-diffusion text_dim is 768) # but lucid didn't expose any separate param for customizing the cross attention input dim. @@ -62,12 +64,14 @@ xfo_context = pad(context, (0, 0, 0, extra_tokens_needed,), 'replicate') ein_result: FloatTensor = ein_attn.forward(x, context, context_mask) + # sdp attn works, but only supports flash attn when context_mask is None. 
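+    # (a boolean attn_mask keeps torch's SDP off its fused flash path, which is
+    #  why the masked case is checked against the xformers implementation instead.)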
# with sdp_kernel(enable_math=False): # sdp_result: FloatTensor = sdp_attn.forward(x, context, context_mask) xfo_attn: FloatTensor = xfo_attn.forward(x, xfo_context, xfo_context_mask) - # default relative and absolute tolerance + # default rtol rtol=1e-5 + # atol would normally be 1e-8 atol=5e-7 # assert allclose(ein_result, sdp_result, rtol=rtol, atol=atol), f"looks like attention implementations weren't equivalent, to tolerance rtol={rtol}, atol={atol}" assert allclose(ein_result, xfo_attn, rtol=rtol, atol=atol), f"looks like attention implementations weren't equivalent, to tolerance rtol={rtol}, atol={atol}" From 451221f2aa489fb3a45755e99cebe92cec84f256 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Sat, 10 Jun 2023 17:54:57 +0100 Subject: [PATCH 51/62] more clarification --- attn_test.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/attn_test.py b/attn_test.py index f2bdd47..2c41672 100644 --- a/attn_test.py +++ b/attn_test.py @@ -42,9 +42,8 @@ # I've said text here simply as an example of something you could cross-attend to text_tokens = 16 # CLIP would be 77 - # there's no reason why these would **have** to be the same (in stable-diffusion text_dim is 768) - # but lucid didn't expose any separate param for customizing the cross attention input dim. - # easily fixed, but whatever I'll work with what's there. + # for a *general* cross-attention Module: + # kv_in_dim could differ from q_in_dim, but this attention Module requires x and context to have same dim. text_dim = vision_dim context: FloatTensor = randn(batch_size, text_tokens, text_dim, dtype=dtype).to(device) From dfaeb14880d546a220c507a68242a08e860b2d31 Mon Sep 17 00:00:00 2001 From: Korakoe <56580073+korakoe@users.noreply.github.com> Date: Mon, 12 Jun 2023 14:51:52 +0800 Subject: [PATCH 52/62] Implement the attention types saves a surprising amount of vram --- muse_maskgit_pytorch/attn/__init__.py | 0 .../attn/attn_test.py | 6 +- .../attn}/ein_attn.py | 0 .../attn}/sdp_attn.py | 0 .../attn}/xformers_attn.py | 0 muse_maskgit_pytorch/muse_maskgit_pytorch.py | 106 +++++++----------- train_muse_maskgit.py | 40 ++++--- 7 files changed, 68 insertions(+), 84 deletions(-) create mode 100644 muse_maskgit_pytorch/attn/__init__.py rename attn_test.py => muse_maskgit_pytorch/attn/attn_test.py (94%) rename {attn => muse_maskgit_pytorch/attn}/ein_attn.py (100%) rename {attn => muse_maskgit_pytorch/attn}/sdp_attn.py (100%) rename {attn => muse_maskgit_pytorch/attn}/xformers_attn.py (100%) diff --git a/muse_maskgit_pytorch/attn/__init__.py b/muse_maskgit_pytorch/attn/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/attn_test.py b/muse_maskgit_pytorch/attn/attn_test.py similarity index 94% rename from attn_test.py rename to muse_maskgit_pytorch/attn/attn_test.py index 2c41672..052be45 100644 --- a/attn_test.py +++ b/muse_maskgit_pytorch/attn/attn_test.py @@ -1,9 +1,7 @@ -from attn.ein_attn import Attention as EinAttn -from attn.sdp_attn import Attention as SDPAttn -from attn.xformers_attn import Attention as XformersAttn +from muse_maskgit_pytorch.attn.ein_attn import Attention as EinAttn +from muse_maskgit_pytorch.attn.xformers_attn import Attention as XformersAttn import torch from torch import FloatTensor, BoolTensor, manual_seed, randn, arange, allclose, no_grad -from torch.backends.cuda import sdp_kernel from torch.nn.functional import pad device = torch.device('cuda') diff --git a/attn/ein_attn.py b/muse_maskgit_pytorch/attn/ein_attn.py similarity index 100% rename from attn/ein_attn.py 
rename to muse_maskgit_pytorch/attn/ein_attn.py diff --git a/attn/sdp_attn.py b/muse_maskgit_pytorch/attn/sdp_attn.py similarity index 100% rename from attn/sdp_attn.py rename to muse_maskgit_pytorch/attn/sdp_attn.py diff --git a/attn/xformers_attn.py b/muse_maskgit_pytorch/attn/xformers_attn.py similarity index 100% rename from attn/xformers_attn.py rename to muse_maskgit_pytorch/attn/xformers_attn.py diff --git a/muse_maskgit_pytorch/muse_maskgit_pytorch.py b/muse_maskgit_pytorch/muse_maskgit_pytorch.py index 4ce391a..6bedb36 100644 --- a/muse_maskgit_pytorch/muse_maskgit_pytorch.py +++ b/muse_maskgit_pytorch/muse_maskgit_pytorch.py @@ -19,6 +19,14 @@ from .vqgan_vae import VQGanVAE from .vqgan_vae_taming import VQGanVAETaming +from .attn import ein_attn, sdp_attn + +try: + from .attn import xformers_attn + xformer_attn = True +except ImportError: + xformer_attn = False + # helpers def exists(val): @@ -93,79 +101,43 @@ def FeedForward(dim, mult=4): ) -class Attention(nn.Module): - def __init__(self, dim, dim_head=64, heads=8, cross_attend=False, scale=8): - super().__init__() - self.scale = scale - self.heads = heads - inner_dim = dim_head * heads - - self.cross_attend = cross_attend - self.norm = LayerNorm(dim) - - self.null_kv = nn.Parameter(torch.randn(2, heads, 1, dim_head)) - - self.to_q = nn.Linear(dim, inner_dim, bias=False) - self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) - - self.q_scale = nn.Parameter(torch.ones(dim_head)) - self.k_scale = nn.Parameter(torch.ones(dim_head)) - - self.to_out = nn.Linear(inner_dim, dim, bias=False) - - def forward(self, x, context=None, context_mask=None): - assert not (exists(context) ^ self.cross_attend) - - h = self.heads - x = self.norm(x) - - kv_input = context if self.cross_attend else x - - q, k, v = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim=-1)) - - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) - - nk, nv = self.null_kv - nk, nv = map(lambda t: repeat(t, "h 1 d -> b h 1 d", b=x.shape[0]), (nk, nv)) - - k = torch.cat((nk, k), dim=-2) - v = torch.cat((nv, v), dim=-2) - - q, k = map(l2norm, (q, k)) - q = q * self.q_scale - k = k * self.k_scale - - sim = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale - - if exists(context_mask): - context_mask = rearrange(context_mask, "b j -> b 1 1 j") - context_mask = F.pad(context_mask, (1, 0), value=True) - - mask_value = -torch.finfo(sim.dtype).max - sim = sim.masked_fill(~context_mask, mask_value) - - attn = sim.softmax(dim=-1) - out = einsum("b h i j, b h j d -> b h i d", attn, v) - - out = rearrange(out, "b h n d -> b n (h d)") - return self.to_out(out) - - class TransformerBlocks(nn.Module): - def __init__(self, *, dim, depth, dim_head=64, heads=8, ff_mult=4): + def __init__(self, *, dim, depth, dim_head=64, heads=8, ff_mult=4, flash=True, xformers=False): super().__init__() self.layers = nn.ModuleList([]) for _ in range(depth): - self.layers.append( - nn.ModuleList( - [ - Attention(dim=dim, dim_head=dim_head, heads=heads), - Attention(dim=dim, dim_head=dim_head, heads=heads, cross_attend=True), - FeedForward(dim=dim, mult=ff_mult), - ] + if flash: + if xformers and xformer_attn: + self.layers.append( + nn.ModuleList( + [ + xformers_attn.Attention(dim=dim, dim_head=dim_head, heads=heads), + xformers_attn.Attention(dim=dim, dim_head=dim_head, heads=heads, cross_attend=True), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + else: + self.layers.append( + nn.ModuleList( + [ + sdp_attn.Attention(dim=dim, dim_head=dim_head, heads=heads), + 
sdp_attn.Attention(dim=dim, dim_head=dim_head, heads=heads, cross_attend=True), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + else: + self.layers.append( + nn.ModuleList( + [ + ein_attn.Attention(dim=dim, dim_head=dim_head, heads=heads), + ein_attn.Attention(dim=dim, dim_head=dim_head, heads=heads, cross_attend=True), + FeedForward(dim=dim, mult=ff_mult), + ] + ) ) - ) self.norm = LayerNorm(dim) diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py index fa354f5..036338d 100644 --- a/train_muse_maskgit.py +++ b/train_muse_maskgit.py @@ -389,6 +389,12 @@ default=None, help="debug logging on", ) +parser.add_argument( + "--attention_type", + type=str, + default="flash", + help="what type of attention to use [ein, flash, xformers] | Default: flash", +) @dataclass class Arguments: @@ -457,6 +463,8 @@ class Arguments: use_l2_recon_loss: bool = False debug: bool = False config_path: Optional[str] = None + attention_type: str = "flash" + def main(): args = parser.parse_args(namespace=Arguments()) @@ -557,7 +565,6 @@ def main(): print("Finding latest checkpoint...") orig_vae_path = args.vae_path - if os.path.isfile(args.vae_path) or '.pt' in args.vae_path: # If args.vae_path is a file, split it into directory and filename args.vae_path, _ = os.path.split(args.vae_path) @@ -589,18 +596,17 @@ def main(): # use config next to checkpoint if there is one and merge the cli arguments to it # the cli arguments will take priority so we can use it to override any value we want. #if os.path.exists(f"{args.vae_path}.yaml"): - #print("Config file found, reusing config from it. Use cli arguments to override any desired value.") - #conf = OmegaConf.load(f"{args.vae_path}.yaml") - #cli_conf = OmegaConf.from_cli() - ## merge the config file and the cli arguments. - #conf = OmegaConf.merge(conf, cli_conf) + #print("Config file found, reusing config from it. Use cli arguments to override any desired value.") + #conf = OmegaConf.load(f"{args.vae_path}.yaml") + #cli_conf = OmegaConf.from_cli() + ## merge the config file and the cli arguments. 
+ #conf = OmegaConf.merge(conf, cli_conf) vae = VQGanVAE(dim=args.dim, vq_codebook_dim=args.vq_codebook_dim, vq_codebook_size=args.vq_codebook_size, l2_recon_loss=args.use_l2_recon_loss).to( accelerator.device ) vae.load(args.vae_path) - elif args.taming_model_path is not None and args.taming_config_path is not None: print(f"Using Taming VQGanVAE, loading from {args.taming_model_path}") vae = VQGanVAETaming( @@ -614,11 +620,6 @@ def main(): raise ValueError( "You must pass either vae_path or taming_model_path + taming_config_path (but not both)" ) - - - # freeze VAE before parsing to transformer - vae.requires_grad_(False) - # freeze VAE before parsing to transformer vae.requires_grad_(False) @@ -626,6 +627,18 @@ def main(): # then you plug the vae and transformer into your MaskGit like so: # (1) create your transformer / attention network + if args.attention_type == "flash": + xformers = False + flash = True + elif args.attention_type == "xformers": + xformers = True + flash = True + elif args.attention_type == "ein": + xformers = False + flash = False + else: + raise NotImplementedError(f"Attention of type \"{args.attention_type}\" does not exist") + transformer: MaskGitTransformer = MaskGitTransformer( # num_tokens must be same as codebook size above num_tokens=args.num_tokens if args.num_tokens else args.vq_codebook_size, @@ -640,6 +653,8 @@ def main(): # name of your T5 model configuration t5_name=args.t5_name, cache_path=args.cache_path, + flash=flash, + xformers=xformers ) # load the maskgit transformer from disk if we have previously trained one @@ -648,7 +663,6 @@ def main(): accelerator.print("Finding latest checkpoint...") orig_vae_path = args.resume_path - if os.path.isfile(args.resume_path) or '.pt' in args.resume_path: # If args.resume_path is a file, split it into directory and filename args.resume_path, _ = os.path.split(args.resume_path) From b5070ef5ebdf377219e08b290f63c78fabfeb6e3 Mon Sep 17 00:00:00 2001 From: Alejandro Gil Date: Mon, 12 Jun 2023 03:26:04 -0700 Subject: [PATCH 53/62] Added option to specify the project name, run name and user/organization to use for trackers such as wandb. --- train_muse_maskgit.py | 32 +++++++++++++++++++++++++++++++- train_muse_vae.py | 33 ++++++++++++++++++++++++++++++++- 2 files changed, 63 insertions(+), 2 deletions(-) diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py index d98aa50..11bb4b2 100644 --- a/train_muse_maskgit.py +++ b/train_muse_maskgit.py @@ -17,6 +17,7 @@ import os import glob import re +import wandb from omegaconf import OmegaConf from accelerate.utils import ProjectConfiguration @@ -156,6 +157,25 @@ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' ), ) +parser.add_argument( + "--project_name", + type=str, + default="muse_maskgit", + help=("Name to use for the project to identify it when saved to a tracker such as wandb or tensorboard."), +) +parser.add_argument( + "--run_name", + type=str, + default=None, + help=("Name to use for the run to identify it when saved to a tracker such" + " as wandb or tensorboard. 
If not specified a random one will be generated."),
+)
+parser.add_argument(
+    "--wandb_user",
+    type=str,
+    default=None,
+    help=("Specify the name for the user or the organization in which the project will be saved when using wandb."),
+)
 parser.add_argument(
     "--mixed_precision",
     type=str,
@@ -878,7 +898,17 @@ def main():
     )
 
     if accelerator.is_main_process:
-        accelerator.init_trackers("muse_maskgit", config=vars(args))
+        accelerator.init_trackers(
+            args.project_name,
+            config=vars(args),
+            init_kwargs={
+                "wandb":{
+                    "entity": f"{args.wandb_user or wandb.api.default_entity}",
+                    "name": args.run_name,
+                },
+            }
+
+        )
 
     # Create the trainer
     accelerator.wait_for_everyone()
diff --git a/train_muse_vae.py b/train_muse_vae.py
index 912b84c..2380e62 100644
--- a/train_muse_vae.py
+++ b/train_muse_vae.py
@@ -19,6 +19,7 @@
 import os
 import glob
 import re
+import wandb
 
 from omegaconf import OmegaConf
 from accelerate.utils import ProjectConfiguration
@@ -110,6 +111,25 @@
         ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
     ),
 )
+parser.add_argument(
+    "--project_name",
+    type=str,
+    default="muse_vae",
+    help=("Name to use for the project to identify it when saved to a tracker such as wandb or tensorboard."),
+)
+parser.add_argument(
+    "--run_name",
+    type=str,
+    default=None,
+    help=("Name to use for the run to identify it when saved to a tracker such"
+          " as wandb or tensorboard. If not specified a random one will be generated."),
+)
+parser.add_argument(
+    "--wandb_user",
+    type=str,
+    default=None,
+    help=("Specify the name for the user or the organization in which the project will be saved when using wandb."),
+)
 parser.add_argument(
     "--mixed_precision",
     type=str,
@@ -389,7 +409,18 @@ def main():
         even_batches=True
     )
     if accelerator.is_main_process:
-        accelerator.init_trackers("muse_vae", config=vars(args))
+        accelerator.init_trackers(
+            args.project_name,
+            config=vars(args),
+            init_kwargs={
+                "wandb":{
+                    "entity": f"{args.wandb_user or wandb.api.default_entity}",
+                    "name": args.run_name,
+                },
+            }
+
+        )
+
 
     if args.webdataset is not None:
         import webdataset as wds

From c64fc2b8119618d01eae9360301e517aa5533f35 Mon Sep 17 00:00:00 2001
From: Andrew Powers-Holmes
Date: Mon, 12 Jun 2023 22:23:49 +1000
Subject: [PATCH 54/62] add pre-commit setup

---
 .pre-commit-config.yaml | 21 +++++++++++++++++++++
 pyproject.toml          | 11 +++++++----
 setup.py                |  9 ++++++++-
 3 files changed, 36 insertions(+), 5 deletions(-)
 create mode 100644 .pre-commit-config.yaml

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..7f63f19
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,21 @@
+# See https://pre-commit.com for more information
+# See https://pre-commit.com/hooks.html for more hooks
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.4.0
+    hooks:
+      - id: trailing-whitespace
+      - id: end-of-file-fixer
+      - id: check-yaml
+      - id: check-added-large-files
+
+  - repo: https://github.com/charliermarsh/ruff-pre-commit
+    rev: "v0.0.272"
+    hooks:
+      - id: ruff
+        args: [--fix, --exit-non-zero-on-fix]
+
+  - repo: https://github.com/psf/black
+    rev: 23.3.0
+    hooks:
+      - id: black
diff --git a/pyproject.toml b/pyproject.toml
index 70d39be..b019588 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,9 +1,6 @@
 [build-system]
 build-backend = "setuptools.build_meta"
-requires = [
-    "setuptools>=61.0.0",
-    "wheel",
-]
+requires = ["setuptools>=61.0.0", "wheel", "setuptools_scm[toml]>=6.2"]
 
 [tool.setuptools_scm]
 write_to = 
"muse_maskgit_pytorch/_version.py" @@ -15,7 +12,13 @@ target-version = ['py38', 'py39', 'py310'] [tool.ruff] line-length = 110 target-version = 'py38' +format = "grouped" +ignore-init-module-imports = true +select = ["E", "F", "I"] +ignore = ['F841', 'F401', 'E501'] [tool.ruff.isort] combine-as-imports = true force-wrap-aliases = true +known-local-folder = ["muse_maskgit_pytorch"] +known-first-party = ["muse_maskgit_pytorch"] diff --git a/setup.py b/setup.py index c89a3d6..826d3d9 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -from setuptools import setup, find_packages +from setuptools import find_packages, setup setup( name="muse-maskgit-pytorch", @@ -17,6 +17,13 @@ "attention mechanism", "text-to-image", ], + extras_require={ + "dev": [ + "pre-commit>=3.3.2", + "black>=23.3.0", + "ruff>=0.0.272", + ] + }, install_requires=[ "accelerate", "diffusers", From c459bd3a48ef3d9d33f96b600f121ea96851dbad Mon Sep 17 00:00:00 2001 From: Andrew Powers-Holmes Date: Mon, 12 Jun 2023 22:24:02 +1000 Subject: [PATCH 55/62] add vscode settings and extension recommendations --- .vscode/extensions.json | 9 +++++++++ .vscode/settings.json | 44 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) create mode 100644 .vscode/extensions.json create mode 100644 .vscode/settings.json diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 0000000..4e30ad8 --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,9 @@ +{ + "recommendations": [ + "ms-python.python", + "charliermarsh.ruff", + "redhat.vscode-yaml", + "codezombiech.gitignore", + "ms-python.black-formatter" + ] +} diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..5c2ddf2 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,44 @@ +{ + "editor.formatOnSaveMode": "file", + + "files.associations": { + ".config": "shellscript", + ".gitignore": "gitignore", + ".vscode/*.json": "jsonc", + "*.txt": "plaintext", + "requirements*.txt": "pip-requirements", + "setup.cfg": "ini", + }, + + "[json]": { + "editor.codeActionsOnSave": { + "source.fixAll.sortJSON": false + }, + "editor.defaultFormatter": "vscode.json-language-features", + "editor.formatOnSave": true, + "editor.tabSize": 4 + }, + "[jsonc]": { + "editor.codeActionsOnSave": { + "source.fixAll.sortJSON": false + }, + "editor.defaultFormatter": "vscode.json-language-features", + "editor.formatOnSave": true, + "editor.tabSize": 4 + }, + "json.format.keepLines": true, + + "[python]": { + "editor.formatOnSave": true, + "editor.defaultFormatter": "ms-python.black-formatter", + "editor.codeActionsOnSave": { + "source.organizeImports": true + } + }, + "python.formatting.provider": "none", + "ruff.organizeImports": true, + "ruff.args": [ "--line-length=110", "--extend-ignore=F401,F841" ], + "black-formatter.args": [ "--line-length", "110" ], + "python.linting.flake8Enabled": false, + "python.linting.mypyEnabled": false, +} From f5fd2c54b6c6f9d21d6fb0ef64a28c164fb454ca Mon Sep 17 00:00:00 2001 From: Andrew Powers-Holmes Date: Mon, 12 Jun 2023 22:24:57 +1000 Subject: [PATCH 56/62] actions: add pre-commit action --- .github/workflows/pre-commit.yml | 40 ++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 .github/workflows/pre-commit.yml diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml new file mode 100644 index 0000000..0826906 --- /dev/null +++ b/.github/workflows/pre-commit.yml @@ -0,0 +1,40 @@ +name: pre-commit + +on: + pull_request: + push: + 
branches:
+      - main
+      - dev
+
+jobs:
+  pre-commit:
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest]
+        python-version: ["3.8", "3.10"]
+
+    runs-on: ${{ matrix.os }}
+    steps:
+      - name: Checkout
+        id: checkout
+        uses: actions/checkout@v3
+        with:
+          submodules: "recursive"
+
+      - name: Set up Python
+        id: setup-python
+        uses: actions/setup-python@v3
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: "pip"
+
+      - name: Install package
+        id: install-package
+        run: |
+          python -m pip install --upgrade pip setuptools wheel
+          pip install -e '.[dev]'
+
+      - name: Run pre-commit
+        uses: pre-commit/action@v3.0.0

From 8ea7da74997f0ea98ba1d6eb956ec0eb3aba81b0 Mon Sep 17 00:00:00 2001
From: Andrew Powers-Holmes
Date: Mon, 12 Jun 2023 22:40:16 +1000
Subject: [PATCH 57/62] run pre-commit on everything

---
 .github/workflows/python-publish.yml          |   2 +-
 .vscode/launch.json                           |   2 +-
 infer_vae.py                                  | 104 ++++++----
 muse_maskgit_pytorch/attn/attn_test.py        |  44 ++--
 muse_maskgit_pytorch/attn/ein_attn.py         |   8 +-
 muse_maskgit_pytorch/attn/sdp_attn.py         |  20 +-
 muse_maskgit_pytorch/attn/xformers_attn.py    |  20 +-
 muse_maskgit_pytorch/dataset.py               |  37 ++--
 muse_maskgit_pytorch/muse_maskgit_pytorch.py  |  14 +-
 muse_maskgit_pytorch/t5.py                    |   2 +-
 .../trainers/base_accelerated_trainer.py      |  43 +++-
 .../trainers/maskgit_trainer.py               |  64 +++---
 .../trainers/vqvae_trainers.py                | 112 +++++-----
 muse_maskgit_pytorch/vqgan_vae.py             |  12 +-
 muse_maskgit_pytorch/vqgan_vae_taming.py      |   5 +-
 train_muse_maskgit.py                         | 191 ++++++++++--------
 train_muse_vae.py                             |  76 +++----
 17 files changed, 453 insertions(+), 303 deletions(-)

diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml
index 5f38eed..6c961fc 100644
--- a/.github/workflows/python-publish.yml
+++ b/.github/workflows/python-publish.yml
@@ -1,5 +1,5 @@
- 
+
 # This workflow will upload a Python Package using Twine when a release is created
 # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
 
diff --git a/.vscode/launch.json b/.vscode/launch.json
index 2537580..d291071 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -8,4 +8,4 @@
             "justMyCode": false
         }
     ]
-}
\ No newline at end of file
+}
diff --git a/infer_vae.py b/infer_vae.py
index 564b68f..88aea18 100644
--- a/infer_vae.py
+++ b/infer_vae.py
@@ -1,24 +1,30 @@
-import torch
-import accelerate
+import argparse
+import glob
+import hashlib
+import os
+import random
+import re
 from dataclasses import dataclass
-from torchvision.utils import save_image
-from datasets import load_dataset, Dataset, Image
-import os, random, hashlib
 from datetime import datetime
+from typing import Optional
+
+import accelerate
+import PIL
+import torch
+from accelerate.utils import ProjectConfiguration
+from datasets import Dataset, Image, load_dataset
+from torchvision.utils import save_image
+from tqdm import tqdm
+
 from muse_maskgit_pytorch import (
     VQGanVAE,
     VQGanVAETaming,
     get_accelerator,
 )
 from muse_maskgit_pytorch.dataset import (
-    get_dataset_from_dataroot,
     ImageDataset,
+    get_dataset_from_dataroot,
 )
-from tqdm import tqdm
-import argparse
-import PIL
-import glob, re
-from accelerate.utils import ProjectConfiguration
 
 # Create the parser
 parser = argparse.ArgumentParser()
@@ -54,9 +60,7 @@
     default=42,
     help="Seed for reproducibility. If set to -1 a random seed will be generated.",
 )
-parser.add_argument(
-    "--valid_frac", type=float, default=0.05, help="validation fraction." 
-) +parser.add_argument("--valid_frac", type=float, default=0.05, help="validation fraction.") parser.add_argument( "--image_column", type=str, @@ -186,6 +190,7 @@ help="Use the latest checkpoint using the vae_path folder instead of using just a specific vae_path.", ) + @dataclass class Arguments: only_save_last_checkpoint: bool = False @@ -277,16 +282,16 @@ def main(): args = parser.parse_args(namespace=Arguments()) project_config = ProjectConfiguration( - project_dir=args.logging_dir, - automatic_checkpoint_naming=True, - ) + project_dir=args.logging_dir, + automatic_checkpoint_naming=True, + ) accelerator: accelerate.Accelerator = get_accelerator( log_with=args.log_with, gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, project_config=project_config, - even_batches=True + even_batches=True, ) # set pytorch seed for reproducibility @@ -308,7 +313,7 @@ def main(): if args.input_folder: # Create dataset from input folder extensions = ["jpg", "jpeg", "png", "webp"] - exclude_folders = args.exclude_folders.split(',') if args.exclude_folders else [] + exclude_folders = args.exclude_folders.split(",") if args.exclude_folders else [] filepaths = [] for root, dirs, files in os.walk(args.input_folder, followlinks=True): @@ -324,7 +329,7 @@ def main(): if not filepaths: print(f"No images with extensions {extensions} found in {args.input_folder}.") - sys.exit(1) + exit(1) dataset = Dataset.from_dict({"image": filepaths}).cast_column("image", Image()) @@ -333,29 +338,42 @@ def main(): if args.vae_path: accelerator.print("Loading Muse VQGanVAE") - vae = VQGanVAE(dim=args.dim, vq_codebook_size=args.vq_codebook_size, vq_codebook_dim=args.vq_codebook_dim).to( - accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}" - ) + vae = VQGanVAE( + dim=args.dim, vq_codebook_size=args.vq_codebook_size, vq_codebook_dim=args.vq_codebook_dim + ).to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}") if args.latest_checkpoint: accelerator.print("Finding latest checkpoint...") orig_vae_path = args.vae_path - - if os.path.isfile(args.vae_path) or '.pt' in args.vae_path: + if os.path.isfile(args.vae_path) or ".pt" in args.vae_path: # If args.vae_path is a file, split it into directory and filename args.vae_path, _ = os.path.split(args.vae_path) checkpoint_files = glob.glob(os.path.join(args.vae_path, "vae.*.pt")) if checkpoint_files: - latest_checkpoint_file = max(checkpoint_files,key=lambda x: int(re.search(r'vae\.(\d+)\.pt$', x).group(1)) if not x.endswith('ema.pt') else -1) + latest_checkpoint_file = max( + checkpoint_files, + key=lambda x: int(re.search(r"vae\.(\d+)\.pt$", x).group(1)) + if not x.endswith("ema.pt") + else -1, + ) # Check if latest checkpoint is empty or unreadable - if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(latest_checkpoint_file, os.R_OK): - accelerator.print(f"Warning: latest checkpoint {latest_checkpoint_file} is empty or unreadable.") + if os.path.getsize(latest_checkpoint_file) == 0 or not os.access( + latest_checkpoint_file, os.R_OK + ): + accelerator.print( + f"Warning: latest checkpoint {latest_checkpoint_file} is empty or unreadable." 
+ ) if len(checkpoint_files) > 1: # Use the second last checkpoint as a fallback - latest_checkpoint_file = max(checkpoint_files[:-1], key=lambda x: int(re.search(r'vae\.(\d+)\.pt$', x).group(1)) if not x.endswith('ema.pt') else -1) + latest_checkpoint_file = max( + checkpoint_files[:-1], + key=lambda x: int(re.search(r"vae\.(\d+)\.pt$", x).group(1)) + if not x.endswith("ema.pt") + else -1, + ) accelerator.print("Using second last checkpoint: ", latest_checkpoint_file) else: accelerator.print("No usable checkpoint found.") @@ -389,7 +407,7 @@ def main(): image_column=args.image_column, center_crop=True if not args.no_center_crop and not args.random_crop else False, flip=not args.no_flip, - random_crop=args.random_crop if args.random_crop else False + random_crop=args.random_crop if args.random_crop else False, ) if args.input_image and not args.input_folder: @@ -397,9 +415,13 @@ def main(): os.makedirs(f"{args.results_dir}/outputs", exist_ok=True) - save_image(dataset[image_id], f"{args.results_dir}/outputs/input.{str(args.input_image).split('.')[-1]}") + save_image( + dataset[image_id], f"{args.results_dir}/outputs/input.{str(args.input_image).split('.')[-1]}" + ) - _, ids, _ = vae.encode(dataset[image_id][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")) + _, ids, _ = vae.encode( + dataset[image_id][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}") + ) recon = vae.decode_from_ids(ids) save_image(recon, f"{args.results_dir}/outputs/output.{str(args.input_image).split('.')[-1]}") @@ -410,11 +432,12 @@ def main(): save_image(dataset[image_id], f"{args.results_dir}/outputs/input.png") - _, ids, _ = vae.encode(dataset[image_id][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")) + _, ids, _ = vae.encode( + dataset[image_id][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}") + ) recon = vae.decode_from_ids(ids) save_image(recon, f"{args.results_dir}/outputs/output.png") - if args.input_folder: # Create output directory and save input images and reconstructions as grids output_dir = os.path.join(args.results_dir, "outputs", os.path.basename(args.input_folder)) @@ -426,7 +449,9 @@ def main(): try: save_image(dataset[i], f"{output_dir}/input.png") - _, ids, _ = vae.encode(dataset[i][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")) + _, ids, _ = vae.encode( + dataset[i][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}") + ) recon = vae.decode_from_ids(ids) save_image(recon, f"{output_dir}/output.png") @@ -435,7 +460,9 @@ def main(): output_image = PIL.Image.open(f"{output_dir}/output.png") # Create horizontal grid with input and output images - grid_image = PIL.Image.new('RGB', (input_image.width + output_image.width, input_image.height)) + grid_image = PIL.Image.new( + "RGB", (input_image.width + output_image.width, input_image.height) + ) grid_image.paste(input_image, (0, 0)) grid_image.paste(output_image, (input_image.width, 0)) @@ -462,7 +489,7 @@ def main(): except RuntimeError as e: if "out of memory" in str(e) and retries < args.max_retries: retries += 1 - #print(f"Out of Memory. Retry #{retries}") + # print(f"Out of Memory. 
Retry #{retries}") torch.cuda.empty_cache() torch.cuda.ipc_collect() continue # Retry the loop @@ -472,6 +499,5 @@ def main(): break # Exit the retry loop after too many retries - if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/muse_maskgit_pytorch/attn/attn_test.py b/muse_maskgit_pytorch/attn/attn_test.py index 052be45..ba313ae 100644 --- a/muse_maskgit_pytorch/attn/attn_test.py +++ b/muse_maskgit_pytorch/attn/attn_test.py @@ -1,10 +1,11 @@ -from muse_maskgit_pytorch.attn.ein_attn import Attention as EinAttn -from muse_maskgit_pytorch.attn.xformers_attn import Attention as XformersAttn import torch -from torch import FloatTensor, BoolTensor, manual_seed, randn, arange, allclose, no_grad +from torch import BoolTensor, FloatTensor, allclose, arange, manual_seed, no_grad, randn from torch.nn.functional import pad -device = torch.device('cuda') +from muse_maskgit_pytorch.attn.ein_attn import Attention as EinAttn +from muse_maskgit_pytorch.attn.xformers_attn import Attention as XformersAttn + +device = torch.device("cuda") dtype = torch.float32 seed = 42 @@ -12,12 +13,12 @@ vision_dim = 64 attn_init_params = { - 'dim': vision_dim, - 'dim_head': 64, + "dim": vision_dim, + "dim_head": 64, # realistically this would be at least 5 - 'heads': 2, - 'cross_attend': True, - 'scale': 8, + "heads": 2, + "cross_attend": True, + "scale": 8, } with no_grad(): @@ -33,13 +34,13 @@ batch_size = 2 # realistically this would be 64**2 in stable-diffusion - vision_tokens = 32**2 # 1024 + vision_tokens = 32**2 # 1024 # generate rand on-CPU for cross-platform determinism of results x: FloatTensor = randn(batch_size, vision_tokens, vision_dim, dtype=dtype).to(device) # I've said text here simply as an example of something you could cross-attend to - text_tokens = 16 # CLIP would be 77 + text_tokens = 16 # CLIP would be 77 # for a *general* cross-attention Module: # kv_in_dim could differ from q_in_dim, but this attention Module requires x and context to have same dim. text_dim = vision_dim @@ -58,7 +59,16 @@ # 0-pad mask to multiple of 8 tokens xfo_context_mask = pad(context_mask, (0, extra_tokens_needed)) # replicate-pad embedding to multiple of 8 tokens (mask will hide the extra tokens) - xfo_context = pad(context, (0, 0, 0, extra_tokens_needed,), 'replicate') + xfo_context = pad( + context, + ( + 0, + 0, + 0, + extra_tokens_needed, + ), + "replicate", + ) ein_result: FloatTensor = ein_attn.forward(x, context, context_mask) # sdp attn works, but only supports flash attn when context_mask is None. 
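
The hunk above is the heart of the equivalence test: the context is rounded up to a multiple of 8 tokens before the xformers path is exercised, with the boolean mask zero-padded (so the new positions read as False) and the embeddings replicate-padded. A minimal standalone sketch of that padding step, reusing the test's pad import from torch.nn.functional — the helper name here is hypothetical:

    def pad_context_for_xformers(context, context_mask, multiple=8):
        # context: (batch, tokens, dim) floats; context_mask: (batch, tokens) bools
        extra = -context.size(1) % multiple
        if extra == 0:
            return context, context_mask
        # constant-pad the mask with False so attention ignores the extra slots;
        # replicate-pad the embeddings so the padded values are well-formed
        return (
            pad(context, (0, 0, 0, extra), "replicate"),
            pad(context_mask, (0, extra)),
        )

Because the padded positions stay masked, the attention output for the real tokens is unchanged, which is what lets the allclose check below compare directly against the unpadded ein implementation.
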
@@ -67,9 +77,11 @@ xfo_attn: FloatTensor = xfo_attn.forward(x, xfo_context, xfo_context_mask) # default rtol - rtol=1e-5 + rtol = 1e-5 # atol would normally be 1e-8 - atol=5e-7 + atol = 5e-7 # assert allclose(ein_result, sdp_result, rtol=rtol, atol=atol), f"looks like attention implementations weren't equivalent, to tolerance rtol={rtol}, atol={atol}" - assert allclose(ein_result, xfo_attn, rtol=rtol, atol=atol), f"looks like attention implementations weren't equivalent, to tolerance rtol={rtol}, atol={atol}" - print(f'attention implementations returned equivalent result, to tolerance rtol={rtol}, atol={atol}') \ No newline at end of file + assert allclose( + ein_result, xfo_attn, rtol=rtol, atol=atol + ), f"looks like attention implementations weren't equivalent, to tolerance rtol={rtol}, atol={atol}" + print(f"attention implementations returned equivalent result, to tolerance rtol={rtol}, atol={atol}") diff --git a/muse_maskgit_pytorch/attn/ein_attn.py b/muse_maskgit_pytorch/attn/ein_attn.py index e6dfd50..3c91639 100644 --- a/muse_maskgit_pytorch/attn/ein_attn.py +++ b/muse_maskgit_pytorch/attn/ein_attn.py @@ -1,15 +1,18 @@ -from torch import einsum, nn import torch import torch.nn.functional as F from einops import rearrange, repeat +from torch import einsum, nn + # helpers def exists(val): return val is not None + def l2norm(t): return F.normalize(t, dim=-1) + class LayerNorm(nn.Module): def __init__(self, dim): super().__init__() @@ -19,6 +22,7 @@ def __init__(self, dim): def forward(self, x): return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) + class Attention(nn.Module): def __init__(self, dim, dim_head=64, heads=8, cross_attend=False, scale=8): super().__init__() @@ -74,4 +78,4 @@ def forward(self, x, context=None, context_mask=None): out = einsum("b h i j, b h j d -> b h i d", attn, v) out = rearrange(out, "b h n d -> b n (h d)") - return self.to_out(out) \ No newline at end of file + return self.to_out(out) diff --git a/muse_maskgit_pytorch/attn/sdp_attn.py b/muse_maskgit_pytorch/attn/sdp_attn.py index 100f841..e21f859 100644 --- a/muse_maskgit_pytorch/attn/sdp_attn.py +++ b/muse_maskgit_pytorch/attn/sdp_attn.py @@ -1,13 +1,16 @@ -from torch import nn, FloatTensor, BoolTensor +from typing import Optional + import torch import torch.nn.functional as F -from torch.nn.functional import scaled_dot_product_attention from einops import rearrange, repeat -from typing import Optional +from torch import BoolTensor, FloatTensor, nn +from torch.nn.functional import scaled_dot_product_attention + def l2norm(t): return F.normalize(t, dim=-1) + class LayerNorm(nn.Module): def __init__(self, dim): super().__init__() @@ -17,6 +20,7 @@ def __init__(self, dim): def forward(self, x): return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) + class Attention(nn.Module): def __init__(self, dim, dim_head=64, heads=8, cross_attend=False, scale=8): super().__init__() @@ -31,14 +35,16 @@ def __init__(self, dim, dim_head=64, heads=8, cross_attend=False, scale=8): self.to_q = nn.Linear(dim, inner_dim, bias=False) self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) - typical_scale = dim_head ** -.5 - scale_ratio = scale/typical_scale + typical_scale = dim_head**-0.5 + scale_ratio = scale / typical_scale self.q_scale = nn.Parameter(torch.full((dim_head,), scale_ratio)) self.k_scale = nn.Parameter(torch.ones(dim_head)) self.to_out = nn.Linear(inner_dim, dim, bias=False) - def forward(self, x: FloatTensor, context: Optional[FloatTensor]=None, context_mask: Optional[BoolTensor]=None): + def 
forward( + self, x: FloatTensor, context: Optional[FloatTensor] = None, context_mask: Optional[BoolTensor] = None + ): assert (context is None) != self.cross_attend h = self.heads @@ -69,4 +75,4 @@ def forward(self, x: FloatTensor, context: Optional[FloatTensor]=None, context_m out: FloatTensor = scaled_dot_product_attention(q, k, v, context_mask) out = rearrange(out, "b h n d -> b n (h d)") - return self.to_out(out) \ No newline at end of file + return self.to_out(out) diff --git a/muse_maskgit_pytorch/attn/xformers_attn.py b/muse_maskgit_pytorch/attn/xformers_attn.py index 54f32c6..c432ee0 100644 --- a/muse_maskgit_pytorch/attn/xformers_attn.py +++ b/muse_maskgit_pytorch/attn/xformers_attn.py @@ -1,13 +1,16 @@ -from torch import nn, FloatTensor, BoolTensor +from typing import Optional + import torch import torch.nn.functional as F from einops import rearrange, repeat -from typing import Optional +from torch import BoolTensor, FloatTensor, nn from xformers.ops import memory_efficient_attention + def l2norm(t): return F.normalize(t, dim=-1) + class LayerNorm(nn.Module): def __init__(self, dim): super().__init__() @@ -17,6 +20,7 @@ def __init__(self, dim): def forward(self, x): return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) + class Attention(nn.Module): def __init__(self, dim, dim_head=64, heads=8, cross_attend=False, scale=8): super().__init__() @@ -31,14 +35,16 @@ def __init__(self, dim, dim_head=64, heads=8, cross_attend=False, scale=8): self.to_q = nn.Linear(dim, inner_dim, bias=False) self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) - typical_scale = dim_head ** -.5 - scale_ratio = scale/typical_scale + typical_scale = dim_head**-0.5 + scale_ratio = scale / typical_scale self.q_scale = nn.Parameter(torch.full((dim_head,), scale_ratio)) self.k_scale = nn.Parameter(torch.ones(dim_head)) self.to_out = nn.Linear(inner_dim, dim, bias=False) - def forward(self, x: FloatTensor, context: Optional[FloatTensor]=None, context_mask: Optional[BoolTensor]=None): + def forward( + self, x: FloatTensor, context: Optional[FloatTensor] = None, context_mask: Optional[BoolTensor] = None + ): assert (context is None) != self.cross_attend h = self.heads @@ -67,10 +73,10 @@ def forward(self, x: FloatTensor, context: Optional[FloatTensor]=None, context_m else: context_mask = F.pad(context_mask, (1, 0), value=True) context_mask = rearrange(context_mask, "b j -> b 1 1 j") - attn_bias = torch.where(context_mask == True, 0., -10000.) 
+            attn_bias = torch.where(context_mask, 0.0, -10000.0)
             attn_bias = attn_bias.expand(-1, h, q.size(1), -1)
 
         out: FloatTensor = memory_efficient_attention(q, k, v, attn_bias)
 
         out = rearrange(out, "b n h d -> b n (h d)")
-        return self.to_out(out)
\ No newline at end of file
+        return self.to_out(out)
diff --git a/muse_maskgit_pytorch/dataset.py b/muse_maskgit_pytorch/dataset.py
index b185e05..055425a 100644
--- a/muse_maskgit_pytorch/dataset.py
+++ b/muse_maskgit_pytorch/dataset.py
@@ -1,5 +1,6 @@
 import os
-import random, shutil
+import random
+import shutil
 import sys
 import time
 from pathlib import Path
@@ -8,24 +9,26 @@
 import datasets
 import torch
 from datasets import Image, load_from_disk
-from PIL import Image as pImage
-from PIL import ImageFile
+from PIL import (
+    Image as pImage,
+    ImageFile,
+)
 from torch.utils.data import DataLoader, Dataset, random_split
 from torchvision import transforms as T
 
 try:
     import torch_xla
     import torch_xla.core.xla_model as xm
-
     from tqdm_loggable.auto import tqdm
 except ImportError:
     from tqdm import tqdm
 
+from io import BytesIO
+
+import requests
 from transformers import T5Tokenizer
 
 from muse_maskgit_pytorch.t5 import MAX_LENGTH
-import requests
-from io import BytesIO
 
 ImageFile.LOAD_TRUNCATED_IMAGES = True
 pImage.MAX_IMAGE_PIXELS = None
@@ -41,7 +44,7 @@
         center_crop=True,
         stream=False,
         using_taming=False,
-        random_crop = False,
+        random_crop=False,
     ):
         super().__init__()
         self.dataset = dataset
@@ -142,7 +145,7 @@
         caption_column="caption",
         flip=True,
         center_crop=True,
-        using_taming=True
+        using_taming=True,
     ):
         super().__init__(
             dataset,
@@ -150,7 +153,7 @@
             image_column=image_column,
             flip=flip,
             center_crop=center_crop,
-            using_taming=using_taming
+            using_taming=using_taming,
         )
         self.caption_column: str = caption_column
         self.tokenizer: T5Tokenizer = tokenizer
@@ -195,7 +198,9 @@
 
 
 class LocalTextImageDataset(Dataset):
-    def __init__(self, path, image_size, tokenizer, flip=True, center_crop=True, using_taming=False, random_crop=False):
+    def __init__(
+        self, path, image_size, tokenizer, flip=True, center_crop=True, using_taming=False, random_crop=False
+    ):
         super().__init__()
         self.tokenizer = tokenizer
         self.using_taming = using_taming
@@ -293,8 +298,12 @@ def save_dataset_with_progress(dataset, save_path):
 
 
 def get_dataset_from_dataroot(
-    data_root, image_column="image", caption_column="caption", save_path="dataset", save=True,
-    ):
+    data_root,
+    image_column="image",
+    caption_column="caption",
+    save_path="dataset",
+    save=True,
+):
     # Check if data_root is a symlink and resolve it to its target location if it is
     if os.path.islink(data_root):
         data_root = os.path.realpath(data_root)
@@ -314,7 +323,9 @@
 
     # Check if data_root is newer than save_path
     if data_root_mtime > save_path_mtime:
-        print("The data_root folder has being updated recently. Removing previously saved dataset and updating it.")
+        print(
+            "The data_root folder has been updated recently. Removing previously saved dataset and updating it."
+        )
         shutil.rmtree(save_path, ignore_errors=True)
     else:
         print("The dataset is up-to-date. 
Loading...") diff --git a/muse_maskgit_pytorch/muse_maskgit_pytorch.py b/muse_maskgit_pytorch/muse_maskgit_pytorch.py index 6bedb36..4af9686 100644 --- a/muse_maskgit_pytorch/muse_maskgit_pytorch.py +++ b/muse_maskgit_pytorch/muse_maskgit_pytorch.py @@ -11,18 +11,18 @@ from accelerate import Accelerator from beartype import beartype from einops import rearrange, repeat -from torch import einsum, nn, isnan +from torch import einsum, isnan, nn from tqdm.auto import tqdm from transformers import T5EncoderModel, T5Tokenizer +from .attn import ein_attn, sdp_attn from .t5 import DEFAULT_T5_NAME, get_encoded_dim, get_model_and_tokenizer, t5_encode_text from .vqgan_vae import VQGanVAE from .vqgan_vae_taming import VQGanVAETaming -from .attn import ein_attn, sdp_attn - try: from .attn import xformers_attn + xformer_attn = True except ImportError: xformer_attn = False @@ -113,7 +113,9 @@ def __init__(self, *, dim, depth, dim_head=64, heads=8, ff_mult=4, flash=True, x nn.ModuleList( [ xformers_attn.Attention(dim=dim, dim_head=dim_head, heads=heads), - xformers_attn.Attention(dim=dim, dim_head=dim_head, heads=heads, cross_attend=True), + xformers_attn.Attention( + dim=dim, dim_head=dim_head, heads=heads, cross_attend=True + ), FeedForward(dim=dim, mult=ff_mult), ] ) @@ -123,7 +125,9 @@ def __init__(self, *, dim, depth, dim_head=64, heads=8, ff_mult=4, flash=True, x nn.ModuleList( [ sdp_attn.Attention(dim=dim, dim_head=dim_head, heads=heads), - sdp_attn.Attention(dim=dim, dim_head=dim_head, heads=heads, cross_attend=True), + sdp_attn.Attention( + dim=dim, dim_head=dim_head, heads=heads, cross_attend=True + ), FeedForward(dim=dim, mult=ff_mult), ] ) diff --git a/muse_maskgit_pytorch/t5.py b/muse_maskgit_pytorch/t5.py index 490c9f5..7e50b65 100644 --- a/muse_maskgit_pytorch/t5.py +++ b/muse_maskgit_pytorch/t5.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, field from functools import cached_property from os import PathLike -from typing import List, Optional, Tuple, Union, Dict +from typing import Dict, List, Optional, Tuple, Union import torch from beartype import beartype diff --git a/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py b/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py index 0698bf2..253e6d6 100644 --- a/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py +++ b/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py @@ -2,20 +2,38 @@ from pathlib import Path from shutil import rmtree from typing import Optional, Union + import accelerate -from PIL import Image import numpy as np import torch from accelerate import Accelerator, DistributedDataParallelKwargs, DistributedType from beartype import beartype from datasets import Dataset from lion_pytorch import Lion +from PIL import Image from torch import nn from torch.optim import Adam, AdamW, Optimizer -from torch_optimizer import AdaBound, AdaMod, AccSGD, AdamP, AggMo, DiffGrad, \ - Lamb, NovoGrad, PID, QHAdam, QHM, RAdam, SGDP, SGDW, Shampoo, SWATS, Yogi -from transformers.optimization import Adafactor from torch.utils.data import DataLoader, random_split +from torch_optimizer import ( + PID, + QHM, + SGDP, + SGDW, + SWATS, + AccSGD, + AdaBound, + AdaMod, + AdamP, + AggMo, + DiffGrad, + Lamb, + NovoGrad, + QHAdam, + RAdam, + Shampoo, + Yogi, +) +from transformers.optimization import Adafactor try: from accelerate.data_loader import MpDeviceLoaderWrapper @@ -138,7 +156,14 @@ def get_optimizer( else Lion(parameters, lr=lr, weight_decay=weight_decay, **optimizer_kwargs) ) elif optimizer == "Adafactor": - return 
Adafactor(parameters, lr=lr, weight_decay=weight_decay, relative_step=False, scale_parameter=False, **optimizer_kwargs) + return Adafactor( + parameters, + lr=lr, + weight_decay=weight_decay, + relative_step=False, + scale_parameter=False, + **optimizer_kwargs, + ) elif optimizer == "AccSGD": return AccSGD(parameters, lr=lr, weight_decay=weight_decay) elif optimizer == "AdaBound": @@ -283,16 +308,12 @@ def log_validation_images(self, images, step, prompts=None): for tracker in self.accelerator.trackers: if tracker.name == "tensorboard": np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images( - "validation", np_images, step, dataformats="NHWC" - ) + tracker.writer.add_images("validation", np_images, step, dataformats="NHWC") if tracker.name == "wandb": tracker.log( { "validation": [ - wandb.Image( - image, caption="" if not prompts else prompts[i] - ) + wandb.Image(image, caption="" if not prompts else prompts[i]) for i, image in enumerate(images) ] } diff --git a/muse_maskgit_pytorch/trainers/maskgit_trainer.py b/muse_maskgit_pytorch/trainers/maskgit_trainer.py index 9219b16..cba2ba2 100644 --- a/muse_maskgit_pytorch/trainers/maskgit_trainer.py +++ b/muse_maskgit_pytorch/trainers/maskgit_trainer.py @@ -5,6 +5,7 @@ from accelerate import Accelerator from diffusers.optimization import SchedulerType from ema_pytorch import EMA +from omegaconf import OmegaConf from PIL import Image from torch.optim import Optimizer from torch.utils.data import DataLoader @@ -14,8 +15,6 @@ from muse_maskgit_pytorch.t5 import t5_encode_text_from_encoded from muse_maskgit_pytorch.trainers.base_accelerated_trainer import BaseAcceleratedTrainer -from omegaconf import OmegaConf - try: import torch_xla import torch_xla.core.xla_model as xm @@ -111,14 +110,16 @@ def __init__( else: self.training_bar = tqdm(initial=int(self.steps.item()), total=self.num_train_steps) - self.info_bar = tqdm(total=0, bar_format='{desc}') + self.info_bar = tqdm(total=0, bar_format="{desc}") def save_validation_images( self, validation_prompts, step: int, cond_image=None, cond_scale=3, temperature=1 ): # moved the print to the top of the function so it shows before the progress bar for reability. 
if validation_prompts: - self.accelerator.print(f"\nStep: {step} | Logging with prompts: {[' | '.join(validation_prompts)]}") + self.accelerator.print( + f"\nStep: {step} | Logging with prompts: {[' | '.join(validation_prompts)]}" + ) images = self.model.generate( validation_prompts, @@ -174,18 +175,22 @@ def train(self): logs = {"loss": train_loss, "lr": self.lr_scheduler.get_last_lr()[0]} if self.on_tpu: - self.accelerator.print(f"\n[E{epoch + 1}][{steps}]{proc_label}: " - f"maskgit loss: {logs['loss']} - lr: {logs['lr']}") + self.accelerator.print( + f"\n[E{epoch + 1}][{steps}]{proc_label}: " + f"maskgit loss: {logs['loss']} - lr: {logs['lr']}" + ) else: self.training_bar.update() - self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}: " - f"maskgit loss: {logs['loss']} - lr: {logs['lr']}") + self.info_bar.set_description_str( + f"[E{epoch + 1}]{proc_label}: " f"maskgit loss: {logs['loss']} - lr: {logs['lr']}" + ) self.accelerator.log(logs, step=steps) if not (steps % self.save_model_every): - self.accelerator.print(f"\n[E{epoch + 1}][{steps}]{proc_label}: " - f"saving model to {self.results_dir}") + self.accelerator.print( + f"\n[E{epoch + 1}][{steps}]{proc_label}: " f"saving model to {self.results_dir}" + ) state_dict = self.accelerator.unwrap_model(self.model).state_dict() maskgit_save_name = "maskgit_superres" if self.model.cond_image_size else "maskgit" @@ -207,7 +212,8 @@ def train(self): if self.use_ema: self.accelerator.print( f"\n[E{epoch + 1}][{steps}]{proc_label}: " - f"saving EMA model to {self.results_dir}") + f"saving EMA model to {self.results_dir}" + ) ema_state_dict = self.accelerator.unwrap_model(self.ema_model).state_dict() file_name = ( @@ -231,20 +237,23 @@ def train(self): self.validation_prompts = [""] * self.batch_size if self.on_tpu: - self.accelerator.print(f"\n[E{epoch + 1}]{proc_label}: " - f"Logging validation images") + self.accelerator.print(f"\n[E{epoch + 1}]{proc_label}: " f"Logging validation images") else: - self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}: " - f"Logging validation images") + self.info_bar.set_description_str( + f"[E{epoch + 1}]{proc_label}: " f"Logging validation images" + ) saved_image = self.save_validation_images( self.validation_prompts, steps, cond_image=cond_image ) if self.on_tpu: - self.accelerator.print(f"\n[E{epoch + 1}][{steps}]{proc_label}: saved to {saved_image}") + self.accelerator.print( + f"\n[E{epoch + 1}][{steps}]{proc_label}: saved to {saved_image}" + ) else: - self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}: " - f"saved to {saved_image}") + self.info_bar.set_description_str( + f"[E{epoch + 1}]{proc_label}: " f"saved to {saved_image}" + ) if met is not None and not (steps % self.log_metrics_every): if self.on_tpu: @@ -256,15 +265,20 @@ def train(self): if self.num_train_steps > 0 and self.steps >= int(self.steps.item()): if self.on_tpu: - self.accelerator.print(f"\n[E{epoch + 1}][{int(self.steps.item())}]{proc_label}" - f"[STOP EARLY]: Stopping training early...") + self.accelerator.print( + f"\n[E{epoch + 1}][{int(self.steps.item())}]{proc_label}" + f"[STOP EARLY]: Stopping training early..." + ) else: - self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}" - f"[STOP EARLY]: Stopping training early...") + self.info_bar.set_description_str( + f"[E{epoch + 1}]{proc_label}" f"[STOP EARLY]: Stopping training early..." 
+ ) break # loop complete, save final model - self.accelerator.print(f"\n[E{epoch + 1}][{steps}]{proc_label}[FINAL]: saving model to {self.results_dir}") + self.accelerator.print( + f"\n[E{epoch + 1}][{steps}]{proc_label}[FINAL]: saving model to {self.results_dir}" + ) state_dict = self.accelerator.unwrap_model(self.model).state_dict() maskgit_save_name = "maskgit_superres" if self.model.cond_image_size else "maskgit" file_name = ( @@ -283,9 +297,7 @@ def train(self): OmegaConf.save(conf, f"{model_path}.yaml") if self.use_ema: - self.accelerator.print( - f"\n[{steps}]{proc_label}[FINAL]: saving EMA model to {self.results_dir}" - ) + self.accelerator.print(f"\n[{steps}]{proc_label}[FINAL]: saving EMA model to {self.results_dir}") ema_state_dict = self.accelerator.unwrap_model(self.ema_model).state_dict() file_name = ( f"{maskgit_save_name}.{steps}.ema.pt" diff --git a/muse_maskgit_pytorch/trainers/vqvae_trainers.py b/muse_maskgit_pytorch/trainers/vqvae_trainers.py index 69ec2e5..5824ddc 100644 --- a/muse_maskgit_pytorch/trainers/vqvae_trainers.py +++ b/muse_maskgit_pytorch/trainers/vqvae_trainers.py @@ -3,12 +3,12 @@ from diffusers.optimization import get_scheduler from einops import rearrange from ema_pytorch import EMA +from omegaconf import OmegaConf from PIL import Image from torch.optim.lr_scheduler import LRScheduler from torch.utils.data import DataLoader from torchvision.utils import make_grid, save_image from tqdm import tqdm -from omegaconf import OmegaConf from muse_maskgit_pytorch.trainers.base_accelerated_trainer import ( BaseAcceleratedTrainer, @@ -16,6 +16,7 @@ ) from muse_maskgit_pytorch.vqgan_vae import VQGanVAE + def noop(*args, **kwargs): pass @@ -33,39 +34,39 @@ def exists(val): class VQGanVAETrainer(BaseAcceleratedTrainer): def __init__( - self, - vae: VQGanVAE, - dataloader: DataLoader, - valid_dataloader: DataLoader, - accelerator: Accelerator, - *, - current_step, - num_train_steps, - num_epochs: int = 5, - gradient_accumulation_steps=1, - max_grad_norm=None, - save_results_every=100, - save_model_every=1000, - results_dir="./results", - logging_dir="./results/logs", - apply_grad_penalty_every=4, - lr=3e-4, - lr_scheduler_type="constant", - lr_warmup_steps=500, - discr_max_grad_norm=None, - use_ema=True, - ema_beta=0.995, - ema_update_after_step=0, - ema_update_every=1, - clear_previous_experiments=False, - validation_image_scale: float = 1.0, - only_save_last_checkpoint=False, - optimizer="Adam", - weight_decay=0.0, - use_8bit_adam=False, - num_cycles=1, - scheduler_power=1.0, - args=None, + self, + vae: VQGanVAE, + dataloader: DataLoader, + valid_dataloader: DataLoader, + accelerator: Accelerator, + *, + current_step, + num_train_steps, + num_epochs: int = 5, + gradient_accumulation_steps=1, + max_grad_norm=None, + save_results_every=100, + save_model_every=1000, + results_dir="./results", + logging_dir="./results/logs", + apply_grad_penalty_every=4, + lr=3e-4, + lr_scheduler_type="constant", + lr_warmup_steps=500, + discr_max_grad_norm=None, + use_ema=True, + ema_beta=0.995, + ema_update_after_step=0, + ema_update_every=1, + clear_previous_experiments=False, + validation_image_scale: float = 1.0, + only_save_last_checkpoint=False, + optimizer="Adam", + weight_decay=0.0, + use_8bit_adam=False, + num_cycles=1, + scheduler_power=1.0, + args=None, ): super().__init__( dataloader, @@ -157,14 +158,13 @@ def __init__( ) self.ema_model = accelerator.prepare(self.ema_model) - if not self.on_tpu: if self.num_train_steps <= 0: self.training_bar = 
tqdm(initial=int(self.steps.item()), total=len(self.dl) * self.num_epochs) else: self.training_bar = tqdm(initial=int(self.steps.item()), total=self.num_train_steps) - self.info_bar = tqdm(total=0, bar_format='{desc}') + self.info_bar = tqdm(total=0, bar_format="{desc}") def load(self, path): pkg = super().load(path) @@ -276,18 +276,22 @@ def train(self): # log if self.on_tpu: - self.accelerator.print(f"[E{epoch + 1}][{steps:05d}]{proc_label}: " - f"vae loss: {logs['Train/vae_loss']} - " - f"discr loss: {logs['Train/discr_loss']} - " - f"lr: {self.lr_scheduler.get_last_lr()[0]}") + self.accelerator.print( + f"[E{epoch + 1}][{steps:05d}]{proc_label}: " + f"vae loss: {logs['Train/vae_loss']} - " + f"discr loss: {logs['Train/discr_loss']} - " + f"lr: {self.lr_scheduler.get_last_lr()[0]}" + ) else: self.training_bar.update() # Note: we had to remove {proc_label} from the description # to short it so it doenst go beyond one line on each step. - self.info_bar.set_description_str(f"[E{epoch + 1}][{steps:05d}]: " - f"vae loss: {logs['Train/vae_loss']} - " - f"discr loss: {logs['Train/discr_loss']} - " - f"lr: {self.lr_scheduler.get_last_lr()[0]}") + self.info_bar.set_description_str( + f"[E{epoch + 1}][{steps:05d}]: " + f"vae loss: {logs['Train/vae_loss']} - " + f"discr loss: {logs['Train/discr_loss']} - " + f"lr: {self.lr_scheduler.get_last_lr()[0]}" + ) logs["lr"] = self.lr_scheduler.get_last_lr()[0] self.accelerator.log(logs, step=steps) @@ -300,14 +304,17 @@ def train(self): # sample results every so often if (steps % self.save_results_every) == 0: - self.accelerator.print(f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: saving to {str(self.results_dir)}") + self.accelerator.print( + f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: saving to {str(self.results_dir)}" + ) self.log_validation_images(logs, steps) # save model every so often self.accelerator.wait_for_everyone() if self.is_main_process and (steps % self.save_model_every) == 0: self.accelerator.print( - f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: saving model to {str(self.results_dir)}") + f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: saving model to {str(self.results_dir)}" + ) state_dict = self.accelerator.unwrap_model(self.model).state_dict() file_name = f"vae.{steps}.pt" if not self.only_save_last_checkpoint else "vae.pt" @@ -321,7 +328,9 @@ def train(self): if self.use_ema: ema_state_dict = self.accelerator.unwrap_model(self.ema_model).state_dict() - file_name = f"vae.{steps}.ema.pt" if not self.only_save_last_checkpoint else "vae.ema.pt" + file_name = ( + f"vae.{steps}.ema.pt" if not self.only_save_last_checkpoint else "vae.ema.pt" + ) model_path = str(self.results_dir / file_name) self.accelerator.save(ema_state_dict, model_path) @@ -333,15 +342,17 @@ def train(self): self.steps += 1 if self.num_train_steps > 0 and self.steps >= int(self.steps.item()): - self.accelerator.print(f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: " - f"[STOP EARLY]: Stopping training early...") + self.accelerator.print( + f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: " f"[STOP EARLY]: Stopping training early..." 
+ ) break # Loop finished, save model self.accelerator.wait_for_everyone() if self.is_main_process: self.accelerator.print( - f"[E{self.num_epochs}][{steps:05d}]{proc_label}: saving model to {str(self.results_dir)}") + f"[E{self.num_epochs}][{steps:05d}]{proc_label}: saving model to {str(self.results_dir)}" + ) state_dict = self.accelerator.unwrap_model(self.model).state_dict() file_name = f"vae.{steps}.pt" if not self.only_save_last_checkpoint else "vae.pt" @@ -363,4 +374,3 @@ def train(self): # save config file next to the model file. conf = OmegaConf.create(vars(self.args)) OmegaConf.save(conf, f"{model_path}.yaml") - diff --git a/muse_maskgit_pytorch/vqgan_vae.py b/muse_maskgit_pytorch/vqgan_vae.py index f891730..f44814a 100644 --- a/muse_maskgit_pytorch/vqgan_vae.py +++ b/muse_maskgit_pytorch/vqgan_vae.py @@ -1,7 +1,7 @@ import copy from functools import partial, wraps from pathlib import Path -from torch import nn + import timm import torch import torch.nn.functional as F @@ -9,6 +9,7 @@ from accelerate import Accelerator from beartype import beartype from einops import rearrange, repeat +from torch import nn from torch.autograd import grad as torch_grad from vector_quantize_pytorch import VectorQuantize as VQ @@ -474,7 +475,7 @@ def forward( return_discr_loss=False, return_recons=False, add_gradient_penalty=True, - relu_loss=True + relu_loss=True, ): batch, channels, height, width, device = *img.shape, img.device @@ -581,7 +582,12 @@ def forward( # commit loss is loss in quanitizing in vq mse # gan loss is if relu_loss: - loss = F.relu(recon_loss) + F.relu(perceptual_loss) + F.relu(commit_loss) + F.relu(adaptive_weight) * F.relu(gen_loss) + loss = ( + F.relu(recon_loss) + + F.relu(perceptual_loss) + + F.relu(commit_loss) + + F.relu(adaptive_weight) * F.relu(gen_loss) + ) else: loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss diff --git a/muse_maskgit_pytorch/vqgan_vae_taming.py b/muse_maskgit_pytorch/vqgan_vae_taming.py index 3602b03..1eeaf00 100644 --- a/muse_maskgit_pytorch/vqgan_vae_taming.py +++ b/muse_maskgit_pytorch/vqgan_vae_taming.py @@ -1,16 +1,15 @@ import copy import importlib -from urllib.parse import urlparse from math import log, sqrt from pathlib import Path +from urllib.parse import urlparse import requests import torch import torch.nn.functional as F from accelerate import Accelerator from einops import rearrange -from omegaconf import OmegaConf, DictConfig - +from omegaconf import DictConfig, OmegaConf from taming.models.vqgan import VQModel from torch import nn from tqdm_loggable.auto import tqdm diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py index 11bb4b2..b7697ff 100644 --- a/train_muse_maskgit.py +++ b/train_muse_maskgit.py @@ -1,26 +1,22 @@ import argparse +import glob import logging +import os +import re from dataclasses import dataclass from typing import Optional, Union - import accelerate import datasets import diffusers -from rich import inspect - import transformers +import wandb +from accelerate.utils import ProjectConfiguration from datasets import load_dataset from diffusers.optimization import SchedulerType, get_scheduler -from torch.optim import Optimizer - -import os -import glob -import re -import wandb - from omegaconf import OmegaConf -from accelerate.utils import ProjectConfiguration +from rich import inspect +from torch.optim import Optimizer try: import torch_xla @@ -41,9 +37,9 @@ from muse_maskgit_pytorch.dataset import ( ImageTextDataset, LocalTextImageDataset, + URLTextDataset, 
get_dataset_from_dataroot,
     split_dataset_into_dataloaders,
-    URLTextDataset
 )
 from muse_maskgit_pytorch.trainers.base_accelerated_trainer import get_optimizer
 
@@ -51,7 +47,7 @@
 transformers.logging.set_verbosity_error()
 
 # disable bitsandbytes welcome message.
-os.environ['BITSANDBYTES_NOWELCOME'] = '1'
+os.environ["BITSANDBYTES_NOWELCOME"] = "1"
 
 # Create the parser
 parser = argparse.ArgumentParser()
@@ -72,10 +68,10 @@
     help="Don't do center crop.",
 )
 parser.add_argument(
-    "--random_crop",
-    action="store_true",
-    help="Crop the images at random locations instead of cropping from the center.",
-    )
+    "--random_crop",
+    action="store_true",
+    help="Crop the images at random locations instead of cropping from the center.",
+)
 parser.add_argument(
     "--no_flip",
     action="store_true",
@@ -167,14 +163,18 @@
     "--run_name",
     type=str,
     default=None,
-    help=("Name to use for the run to identify it when saved to a tracker such"
-          " as wandb or tensorboard. If not specified a random one will be generated."),
+    help=(
+        "Name to use for the run to identify it when saved to a tracker such"
+        " as wandb or tensorboard. If not specified a random one will be generated."
+    ),
 )
 parser.add_argument(
     "--wandb_user",
     type=str,
     default=None,
-    help=("Specify the name for the user or the organization in which the project will be saved when using wandb."),
+    help=(
+        "Specify the name for the user or the organization in which the project will be saved when using wandb."
+    ),
 )
 parser.add_argument(
     "--mixed_precision",
     type=str,
@@ -291,11 +291,7 @@
     default=256,
     help="Image Size.",
 )
-parser.add_argument(
-    "--vq_codebook_dim",
-    type=int,
-    default=256,
-    help="VQ Codebook dimensions.")
+parser.add_argument("--vq_codebook_dim", type=int, default=256, help="VQ Codebook dimensions.")
 parser.add_argument(
     "--cond_drop_prob",
     type=float,
@@ -319,7 +315,7 @@
     type=float,
     default=1.0,
     help="Controls the power of the polynomial decay schedule used by the CosineScheduleWithWarmup scheduler. "
-    "It determines the rate at which the learning rate decreases during the schedule.",
+    "It determines the rate at which the learning rate decreases during the schedule.",
 )
 parser.add_argument(
     "--lr_warmup_steps",
@@ -328,11 +324,11 @@
     help="Number of steps for the warmup in the lr scheduler.",
 )
 parser.add_argument(
-    "--num_cycles",
-    type=int,
-    default=1,
-    help="Number of cycles for the lr scheduler.",
-    )
+    "--num_cycles",
+    type=int,
+    default=1,
+    help="Number of cycles for the lr scheduler.",
+)
 parser.add_argument(
     "--resume_path",
     type=str,
@@ -356,9 +352,9 @@
     type=str,
    default="Adafactor",
     help="Optimizer to use. Choose between: ['Adam', 'AdamW','Lion', 'Adafactor', "
-    "'AdaBound', 'AdaMod', 'AccSGD', 'AdamP', 'AggMo', 'DiffGrad', 'Lamb', "
-    "'NovoGrad', 'PID', 'QHAdam', 'QHM', 'RAdam', 'SGDP', 'SGDW', 'Shampoo', "
-    "'SWATS', 'Yogi']. 
Default: Lion",
+    "'AdaBound', 'AdaMod', 'AccSGD', 'AdamP', 'AggMo', 'DiffGrad', 'Lamb', "
+    "'NovoGrad', 'PID', 'QHAdam', 'QHM', 'RAdam', 'SGDP', 'SGDW', 'Shampoo', "
+    "'SWATS', 'Yogi']. Default: Adafactor",
 )
 parser.add_argument(
     "--weight_decay",
@@ -397,7 +393,7 @@
     "--use_l2_recon_loss",
     action="store_true",
     help="Use F.mse_loss instead of F.l1_loss.",
-    )
+)
 parser.add_argument(
     "--debug",
     action="store_true",
@@ -416,6 +412,7 @@
     help="what type of attention to use [ein, flash, xformers] | Default: flash",
 )
 
+
 @dataclass
 class Arguments:
     only_save_last_checkpoint: bool = False
@@ -522,17 +519,17 @@ def main():
     logging.basicConfig(level=logging.INFO)
 
     project_config = ProjectConfiguration(
-        project_dir=args.logging_dir,
-        total_limit=args.checkpoint_limit,
-        automatic_checkpoint_naming=True,
-    )
+        project_dir=args.logging_dir,
+        total_limit=args.checkpoint_limit,
+        automatic_checkpoint_naming=True,
+    )
 
     accelerator: accelerate.Accelerator = get_accelerator(
         log_with=args.log_with,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         project_config=project_config,
-        even_batches=True
+        even_batches=True,
     )
 
     # Get these errors out of the way early
@@ -585,20 +582,27 @@ def main():
         print("Finding latest checkpoint...")
         orig_vae_path = args.vae_path
 
-        if os.path.isfile(args.vae_path) or '.pt' in args.vae_path:
+        if os.path.isfile(args.vae_path) or ".pt" in args.vae_path:
             # If args.vae_path is a file, split it into directory and filename
             args.vae_path, _ = os.path.split(args.vae_path)
 
         checkpoint_files = glob.glob(os.path.join(args.vae_path, "vae.*.pt"))
         if checkpoint_files:
-            latest_checkpoint_file = max(checkpoint_files, key=lambda x: int(re.search(r'vae\.(\d+)\.pt', x).group(1)))
+            latest_checkpoint_file = max(
+                checkpoint_files, key=lambda x: int(re.search(r"vae\.(\d+)\.pt", x).group(1))
+            )
 
             # Check if latest checkpoint is empty or unreadable
-            if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(latest_checkpoint_file, os.R_OK):
+            if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(
+                latest_checkpoint_file, os.R_OK
+            ):
                 print(f"Warning: latest checkpoint {latest_checkpoint_file} is empty or unreadable.")
                 if len(checkpoint_files) > 1:
                     # Use the second last checkpoint as a fallback
-                    latest_checkpoint_file = max(checkpoint_files[:-1], key=lambda x: int(re.search(r'vae\.(\d+)\.pt', x).group(1)))
+                    latest_checkpoint_file = max(
+                        checkpoint_files[:-1],
+                        key=lambda x: int(re.search(r"vae\.(\d+)\.pt", x).group(1)),
+                    )
                     print("Using second last checkpoint: ", latest_checkpoint_file)
                 else:
                     print("No usable checkpoint found.")
@@ -615,16 +619,19 @@ def main():
 
         # use config next to checkpoint if there is one and merge the cli arguments to it
         # the cli arguments will take priority so we can use it to override any value we want.
-        #if os.path.exists(f"{args.vae_path}.yaml"):
-        #print("Config file found, reusing config from it. Use cli arguments to override any desired value.")
-        #conf = OmegaConf.load(f"{args.vae_path}.yaml")
-        #cli_conf = OmegaConf.from_cli()
+        # if os.path.exists(f"{args.vae_path}.yaml"):
+        # print("Config file found, reusing config from it. Use cli arguments to override any desired value.")
+        # conf = OmegaConf.load(f"{args.vae_path}.yaml")
+        # cli_conf = OmegaConf.from_cli()
         ## merge the config file and the cli arguments. 
-    #conf = OmegaConf.merge(conf, cli_conf)
-
-    vae = VQGanVAE(dim=args.dim, vq_codebook_dim=args.vq_codebook_dim, vq_codebook_size=args.vq_codebook_size, l2_recon_loss=args.use_l2_recon_loss).to(
-        accelerator.device
-    )
+    # conf = OmegaConf.merge(conf, cli_conf)
+
+    vae = VQGanVAE(
+        dim=args.dim,
+        vq_codebook_dim=args.vq_codebook_dim,
+        vq_codebook_size=args.vq_codebook_size,
+        l2_recon_loss=args.use_l2_recon_loss,
+    ).to(accelerator.device)
 
         vae.load(args.vae_path)
 
     elif args.taming_model_path is not None and args.taming_config_path is not None:
@@ -657,7 +664,7 @@ def main():
         xformers = False
         flash = False
     else:
-        raise NotImplementedError(f"Attention of type \"{args.attention_type}\" does not exist")
+        raise NotImplementedError(f'Attention of type "{args.attention_type}" does not exist')
 
     transformer: MaskGitTransformer = MaskGitTransformer(
         # num_tokens must be same as codebook size above
@@ -674,7 +681,7 @@
         t5_name=args.t5_name,
         cache_path=args.cache_path,
         flash=flash,
-        xformers=xformers
+        xformers=xformers,
     )
 
     # load the maskgit transformer from disk if we have previously trained one
@@ -683,20 +690,34 @@
         accelerator.print("Finding latest checkpoint...")
         orig_vae_path = args.resume_path
 
-        if os.path.isfile(args.resume_path) or '.pt' in args.resume_path:
+        if os.path.isfile(args.resume_path) or ".pt" in args.resume_path:
            # If args.resume_path is a file, split it into directory and filename
            args.resume_path, _ = os.path.split(args.resume_path)
 
        checkpoint_files = glob.glob(os.path.join(args.resume_path, "maskgit.*.pt"))
        if checkpoint_files:
-            latest_checkpoint_file = max(checkpoint_files,key=lambda x: int(re.search(r'maskgit\.(\d+)\.pt$', x).group(1)) if not x.endswith('ema.pt') else -1)
+            latest_checkpoint_file = max(
+                checkpoint_files,
+                key=lambda x: int(re.search(r"maskgit\.(\d+)\.pt$", x).group(1))
+                if not x.endswith("ema.pt")
+                else -1,
+            )
 
            # Check if latest checkpoint is empty or unreadable
-            if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(latest_checkpoint_file, os.R_OK):
-                accelerator.print(f"Warning: latest MaskGit checkpoint {latest_checkpoint_file} is empty or unreadable.")
+            if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(
+                latest_checkpoint_file, os.R_OK
+            ):
+                accelerator.print(
+                    f"Warning: latest MaskGit checkpoint {latest_checkpoint_file} is empty or unreadable."
+                )
                if len(checkpoint_files) > 1:
                    # Use the second last checkpoint as a fallback
-                    latest_checkpoint_file = max(checkpoint_files[:-1], key=lambda x: int(re.search(r'maskgit\.(\d+)\.pt$', x).group(1)) if not x.endswith('ema.pt') else -1)
+                    latest_checkpoint_file = max(
+                        checkpoint_files[:-1],
+                        key=lambda x: int(re.search(r"maskgit\.(\d+)\.pt$", x).group(1))
+                        if not x.endswith("ema.pt")
+                        else -1,
+                    )
                    accelerator.print("Using second last MaskGit checkpoint: ", latest_checkpoint_file)
                else:
                    accelerator.print("No usable MaskGit checkpoint found.")
@@ -714,13 +735,14 @@
     # use config next to checkpoint if there is one and merge the cli arguments to it
     # the cli arguments will take priority so we can use it to override any value we want.
     if os.path.exists(f"{args.resume_path}.yaml"):
-        accelerator.print("Config file found, reusing config from it. Use cli arguments to override any desired value.")
+        accelerator.print(
+            "Config file found, reusing config from it. Use cli arguments to override any desired value."
+        )
         conf = OmegaConf.load(f"{args.resume_path}.yaml")
         cli_conf = OmegaConf.from_cli()
         # merge the config file and the cli arguments.
         conf = OmegaConf.merge(conf, cli_conf)
 
-
     # (2) pass your trained VAE and the base transformer to MaskGit
     maskgit = MaskGit(
         vae=vae,  # vqgan vae
@@ -741,7 +763,7 @@
             accelerator.print("Finding latest checkpoint...")
             orig_vae_path = args.resume_path
 
-            if os.path.isfile(args.resume_path) or '.pt' in args.resume_path:
+            if os.path.isfile(args.resume_path) or ".pt" in args.resume_path:
                 # If args.resume_path is a file, split it into directory and filename
                 args.resume_path, _ = os.path.split(args.resume_path)
 
@@ -752,26 +774,34 @@
             if checkpoint_files:
                 if args.cond_image_size:
-                    latest_checkpoint_file = max(checkpoint_files,
-                                                 key=lambda x: int(re.search(r'maskgit_superres\.(\d+)\.pt', x).group(1)))
+                    latest_checkpoint_file = max(
+                        checkpoint_files,
+                        key=lambda x: int(re.search(r"maskgit_superres\.(\d+)\.pt", x).group(1)),
+                    )
                 else:
-                    latest_checkpoint_file = max(checkpoint_files,
-                                                 key=lambda x: int(re.search(r'maskgit\.(\d+)\.pt', x).group(1)))
+                    latest_checkpoint_file = max(
+                        checkpoint_files, key=lambda x: int(re.search(r"maskgit\.(\d+)\.pt", x).group(1))
+                    )
 
                 # Check if latest checkpoint is empty or unreadable
-                if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(latest_checkpoint_file, os.R_OK):
+                if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(
+                    latest_checkpoint_file, os.R_OK
+                ):
                     accelerator.print(
-                        f"Warning: latest checkpoint {latest_checkpoint_file} is empty or unreadable.")
+                        f"Warning: latest checkpoint {latest_checkpoint_file} is empty or unreadable."
+                    )
                     if len(checkpoint_files) > 1:
                         # Use the second last checkpoint as a fallback
                         if args.cond_image_size:
-                            latest_checkpoint_file = max(checkpoint_files[:-1],
-                                                         key=lambda x: int(
-                                                             re.search(r'maskgit_superres\.(\d+)\.pt', x).group(1)))
+                            latest_checkpoint_file = max(
+                                checkpoint_files[:-1],
+                                key=lambda x: int(re.search(r"maskgit_superres\.(\d+)\.pt", x).group(1)),
+                            )
                         else:
-                            latest_checkpoint_file = max(checkpoint_files[:-1],
-                                                         key=lambda x: int(
-                                                             re.search(r'maskgit\.(\d+)\.pt', x).group(1)))
+                            latest_checkpoint_file = max(
+                                checkpoint_files[:-1],
+                                key=lambda x: int(re.search(r"maskgit\.(\d+)\.pt", x).group(1)),
+                            )
                         accelerator.print("Using second last checkpoint: ", latest_checkpoint_file)
                     else:
                         accelerator.print("No usable checkpoint found.")
@@ -829,7 +859,7 @@
             caption_column=args.caption_column,
             center_crop=False if args.no_center_crop else True,
             flip=False if args.no_flip else True,
-            using_taming=False if not args.taming_model_path else True
+            using_taming=False if not args.taming_model_path else True,
         )
     else:
         dataset = ImageTextDataset(
@@ -841,7 +871,7 @@
             center_crop=False if args.no_center_crop else True,
             flip=False if args.no_flip else True,
             stream=args.streaming,
-            using_taming=False if not args.taming_model_path else True
+            using_taming=False if not args.taming_model_path else True,
         )
 
     # Create the dataloaders
@@ -902,12 +932,11 @@
             args.project_name,
             config=vars(args),
             init_kwargs={
-                "wandb":{
+                "wandb": {
                     "entity": f"{args.wandb_user or wandb.api.default_entity}",
                     "name": args.run_name,
-                },
-            }
-
+                },
+            },
         )
 
     # Create the trainer
diff --git a/train_muse_vae.py b/train_muse_vae.py
index 2380e62..659e21d 100644
--- a/train_muse_vae.py
+++ b/train_muse_vae.py
@@ -1,8 +1,14 @@
 import argparse
+import glob
+import os
+import re
 from dataclasses import dataclass
 from typing import Optional, Union
 
+import wandb
+from accelerate.utils import ProjectConfiguration
 from datasets import load_dataset
+from omegaconf import OmegaConf
 
 from muse_maskgit_pytorch import (
     VQGanVAE,
@@ -16,16 +22,8 @@
     split_dataset_into_dataloaders,
 )
 
-import os
-import glob
-import re
-import wandb
-
-from omegaconf import OmegaConf
-from accelerate.utils import ProjectConfiguration
-
 # disable bitsandbytes welcome message.
-os.environ['BITSANDBYTES_NOWELCOME'] = '1'
+os.environ["BITSANDBYTES_NOWELCOME"] = "1"
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--webdataset", type=str, default=None, help="Path to webdataset if using one.")
@@ -51,10 +49,10 @@
     help="Don't flip image.",
 )
 parser.add_argument(
-        "--random_crop",
-        action="store_true",
-        help="Crop the images at random locations instead of cropping from the center.",
-    )
+    "--random_crop",
+    action="store_true",
+    help="Crop the images at random locations instead of cropping from the center.",
+)
 parser.add_argument(
     "--dataset_save_path",
     type=str,
@@ -121,14 +119,18 @@
     "--run_name",
     type=str,
     default=None,
-    help=("Name to use for the run to identify it when saved to a tracker such"
-          " as wandb or tensorboard. If not specified a random one will be generated."),
+    help=(
+        "Name to use for the run to identify it when saved to a tracker such"
+        " as wandb or tensorboard. If not specified a random one will be generated."
+    ),
 )
 parser.add_argument(
     "--wandb_user",
     type=str,
     default=None,
-    help=("Specify the name for the user or the organization in which the project will be saved when using wand."),
+    help=(
+        "Specify the name for the user or the organization in which the project will be saved when using wandb."
+    ),
 )
 parser.add_argument(
     "--mixed_precision",
@@ -219,11 +221,7 @@
     help="Keep only X number of checkpoints and delete the older ones.",
 )
 parser.add_argument("--vq_codebook_size", type=int, default=256, help="Image Size.")
-parser.add_argument(
-    "--vq_codebook_dim",
-    type=int,
-    default=256,
-    help="VQ Codebook dimensions.")
+parser.add_argument("--vq_codebook_dim", type=int, default=256, help="VQ Codebook dimensions.")
 parser.add_argument(
     "--image_size",
     type=int,
@@ -241,7 +239,7 @@
     type=float,
     default=1.0,
     help="Controls the power of the polynomial decay schedule used by the CosineScheduleWithWarmup scheduler. "
-         "It determines the rate at which the learning rate decreases during the schedule.",
+    "It determines the rate at which the learning rate decreases during the schedule.",
 )
 parser.add_argument(
     "--lr_warmup_steps",
@@ -302,6 +300,7 @@
     help="Whether to use the latest checkpoint",
 )
 
+
 @dataclass
 class Arguments:
     only_save_last_checkpoint: bool = False
@@ -406,19 +405,18 @@ def main():
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         project_config=project_config,
-        even_batches=True
+        even_batches=True,
     )
     if accelerator.is_main_process:
         accelerator.init_trackers(
             args.project_name,
             config=vars(args),
             init_kwargs={
-                "wandb":{
+                "wandb": {
                     "entity": f"{args.wandb_user or wandb.api.default_entity}",
                     "name": args.run_name,
-                },
-            }
-
+                },
+            },
         )
 
     if args.webdataset is not None:
@@ -432,7 +430,7 @@
             image_column=args.image_column,
             caption_column=args.caption_column,
             save_path=args.dataset_save_path,
-            save=not args.no_cache
+            save=not args.no_cache,
         )
     elif args.dataset_name:
         if args.cache_path:
@@ -466,21 +464,28 @@
             accelerator.print("Finding latest checkpoint...")
             orig_vae_path = args.resume_path
 
-            if os.path.isfile(args.resume_path) or '.pt' in args.resume_path:
+            if os.path.isfile(args.resume_path) or ".pt" in args.resume_path:
                 # If args.vae_path is a file, split it into directory and filename
                 args.resume_path, _ = os.path.split(args.resume_path)
 
             checkpoint_files = glob.glob(os.path.join(args.resume_path, "vae.*.pt"))
             if checkpoint_files:
-                latest_checkpoint_file = max(checkpoint_files, key=lambda x: int(re.search(r'vae\.(\d+)\.pt', x).group(1)))
+                latest_checkpoint_file = max(
+                    checkpoint_files, key=lambda x: int(re.search(r"vae\.(\d+)\.pt", x).group(1))
+                )
 
                 # Check if latest checkpoint is empty or unreadable
-                if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(latest_checkpoint_file, os.R_OK):
+                if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(
+                    latest_checkpoint_file, os.R_OK
+                ):
                     accelerator.print(
-                        f"Warning: latest checkpoint {latest_checkpoint_file} is empty or unreadable.")
+                        f"Warning: latest checkpoint {latest_checkpoint_file} is empty or unreadable."
+                    )
                     if len(checkpoint_files) > 1:
                         # Use the second last checkpoint as a fallback
-                        latest_checkpoint_file = max(checkpoint_files[:-1], key=lambda x: int(re.search(r'vae\.(\d+)\.pt', x).group(1)))
+                        latest_checkpoint_file = max(
+                            checkpoint_files[:-1], key=lambda x: int(re.search(r"vae\.(\d+)\.pt", x).group(1))
+                        )
                         accelerator.print("Using second last checkpoint: ", latest_checkpoint_file)
                     else:
                         accelerator.print("No usable checkpoint found.")
@@ -528,7 +533,6 @@ def main():
         vq_codebook_dim=args.vq_codebook_dim,
         vq_codebook_size=args.vq_codebook_size,
         accelerator=accelerator,
-
     )
 
     current_step = 0
@@ -540,7 +544,7 @@
         center_crop=not args.no_center_crop,
         flip=not args.no_flip,
         stream=args.streaming,
-        random_crop=args.random_crop
+        random_crop=args.random_crop,
     )
 
     # dataloader
@@ -576,7 +580,7 @@
         use_8bit_adam=args.use_8bit_adam,
         num_cycles=args.num_cycles,
         scheduler_power=args.scheduler_power,
-        num_epochs=args.num_epochs
+        num_epochs=args.num_epochs,
     )
 
     trainer.train()

From 71b2c6b698dd44cedff1db90174625b860d88fba Mon Sep 17 00:00:00 2001
From: Andrew Powers-Holmes
Date: Mon, 12 Jun 2023 22:46:54 +1000
Subject: [PATCH 58/62] attn: fixup attn_test.py formatting, raise() instead of assert()

---
 muse_maskgit_pytorch/attn/attn_test.py | 19 +++++--------------
 1 file changed, 5 insertions(+), 14 deletions(-)

diff --git a/muse_maskgit_pytorch/attn/attn_test.py b/muse_maskgit_pytorch/attn/attn_test.py
index ba313ae..022bcf8 100644
--- a/muse_maskgit_pytorch/attn/attn_test.py
+++ b/muse_maskgit_pytorch/attn/attn_test.py
@@ -59,16 +59,7 @@
         # 0-pad mask to multiple of 8 tokens
         xfo_context_mask = pad(context_mask, (0, extra_tokens_needed))
         # replicate-pad embedding to multiple of 8 tokens (mask will hide the extra tokens)
-        xfo_context = pad(
-            context,
-            (
-                0,
-                0,
-                0,
-                extra_tokens_needed,
-            ),
-            "replicate",
-        )
+        xfo_context = pad(context, (0, 0, 0, extra_tokens_needed), "replicate")
 
     ein_result: FloatTensor = ein_attn.forward(x, context, context_mask)
     # sdp attn works, but only supports flash attn when context_mask is None.
@@ -81,7 +72,7 @@
     # atol would normally be 1e-8
     atol = 5e-7
     # assert allclose(ein_result, sdp_result, rtol=rtol, atol=atol), f"looks like attention implementations weren't equivalent, to tolerance rtol={rtol}, atol={atol}"
-    assert allclose(
-        ein_result, xfo_attn, rtol=rtol, atol=atol
-    ), f"looks like attention implementations weren't equivalent, to tolerance rtol={rtol}, atol={atol}"
-    print(f"attention implementations returned equivalent result, to tolerance rtol={rtol}, atol={atol}")
+    if not allclose(ein_result, xfo_attn, rtol=rtol, atol=atol):
+        raise RuntimeError(
+            f"looks like attention implementations weren't equivalent, to tolerance rtol={rtol}, atol={atol}"
+        )

From 99aba6028d91ca274bbf7382db31ea936d60a4ef Mon Sep 17 00:00:00 2001
From: Andrew Powers-Holmes
Date: Mon, 12 Jun 2023 22:53:25 +1000
Subject: [PATCH 59/62] add setuptools-scm version file to gitignore

---
 .gitignore | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/.gitignore b/.gitignore
index 19d673c..7d149a3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -138,3 +138,6 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+
+# setuptools-scm version file
+muse_maskgit_pytorch/_version.py

From e08f7069419284b5b5f73c35230b1022c2146e7e Mon Sep 17 00:00:00 2001
From: Alejandro Gil
Date: Mon, 12 Jun 2023 06:36:20 -0700
Subject: [PATCH 60/62] Small fixes made by the pre-commit hook.
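
The visible effect of this pass is isort regrouping imports: `wandb` (in the two
training scripts) and `taming` (in vqgan_vae_taming.py) move out of the main
third-party block into their own section. A plausible reading, offered as an
editorial assumption rather than anything stated in the original commit: neither
distribution is installed inside pre-commit's isolated hook environment, so isort
cannot resolve them as third-party packages and files them into a separate
default section after the recognized third-party group.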
---
 muse_maskgit_pytorch/vqgan_vae_taming.py | 3 ++-
 train_muse_maskgit.py                    | 3 ++-
 train_muse_vae.py                        | 3 ++-
 3 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/muse_maskgit_pytorch/vqgan_vae_taming.py b/muse_maskgit_pytorch/vqgan_vae_taming.py
index 1eeaf00..2479161 100644
--- a/muse_maskgit_pytorch/vqgan_vae_taming.py
+++ b/muse_maskgit_pytorch/vqgan_vae_taming.py
@@ -10,10 +10,11 @@
 from accelerate import Accelerator
 from einops import rearrange
 from omegaconf import DictConfig, OmegaConf
-from taming.models.vqgan import VQModel
 from torch import nn
 from tqdm_loggable.auto import tqdm
 
+from taming.models.vqgan import VQModel
+
 # constants
 CACHE_PATH = Path.home().joinpath(".cache/taming")
 
diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py
index b7697ff..40ecdf4 100644
--- a/train_muse_maskgit.py
+++ b/train_muse_maskgit.py
@@ -10,7 +10,6 @@
 import datasets
 import diffusers
 import transformers
-import wandb
 from accelerate.utils import ProjectConfiguration
 from datasets import load_dataset
 from diffusers.optimization import SchedulerType, get_scheduler
@@ -18,6 +17,8 @@
 from rich import inspect
 from torch.optim import Optimizer
 
+import wandb
+
 try:
     import torch_xla
     import torch_xla.core.xla_model as xm
diff --git a/train_muse_vae.py b/train_muse_vae.py
index 659e21d..a2c53a3 100644
--- a/train_muse_vae.py
+++ b/train_muse_vae.py
@@ -5,11 +5,12 @@
 from dataclasses import dataclass
 from typing import Optional, Union
 
-import wandb
 from accelerate.utils import ProjectConfiguration
 from datasets import load_dataset
 from omegaconf import OmegaConf
 
+import wandb
+
 from muse_maskgit_pytorch import (
     VQGanVAE,
     VQGanVAETaming,

From 757082d143ab25fc691b3a5fc0d99f572064ad27 Mon Sep 17 00:00:00 2001
From: Andrew Powers-Holmes
Date: Tue, 13 Jun 2023 00:47:55 +1000
Subject: [PATCH 61/62] actions: fix pre-commit

---
 .github/workflows/pre-commit.yml | 1 -
 1 file changed, 1 deletion(-)

diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml
index 0826906..87febc1 100644
--- a/.github/workflows/pre-commit.yml
+++ b/.github/workflows/pre-commit.yml
@@ -28,7 +28,6 @@
         uses: actions/setup-python@v3
         with:
           python-version: ${{ matrix.python-version }}
-          cache: "pip"
 
       - name: Install package
         id: install-package

From 8ddfd3001d60ba53be181be95ac2bf5e72f2882e Mon Sep 17 00:00:00 2001
From: Andrew Powers-Holmes
Date: Tue, 13 Jun 2023 00:50:46 +1000
Subject: [PATCH 62/62] actions: fix pre-commit again

---
 .github/workflows/pre-commit.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml
index 87febc1..1bfa9c3 100644
--- a/.github/workflows/pre-commit.yml
+++ b/.github/workflows/pre-commit.yml
@@ -13,7 +13,7 @@
     fail-fast: false
     matrix:
       os: [ubuntu-latest]
-      python-version: [3.8, 3.10]
+      python-version: ["3.8", "3.10"]
 
     runs-on: ${{ matrix.os }}
     steps:
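
A closing note on patch 62, since the series ends here: the added quotes are not
cosmetic. YAML resolves a bare 3.10 as the float 3.1, so the unquoted matrix was
actually requesting Python 3.1 from the CI runner. The sketch below demonstrates
the pitfall with PyYAML; it is an editorial illustration only, not part of the
repository or the workflow itself:

    import yaml  # PyYAML; its scalar resolution exhibits the same 3.10 -> 3.1 pitfall

    # Unquoted: YAML reads 3.10 as a float, and the float 3.10 is just 3.1
    print(yaml.safe_load("python-version: [3.8, 3.10]"))
    # -> {'python-version': [3.8, 3.1]}

    # Quoted: the version numbers survive as strings
    print(yaml.safe_load('python-version: ["3.8", "3.10"]'))
    # -> {'python-version': ['3.8', '3.10']}

Quoting every entry in the version matrix, as patch 62 does, is the standard fix.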