From faa1281d2765b630ecdf371f5ad569a2893a5ef6 Mon Sep 17 00:00:00 2001
From: Korakoe <56580073+korakoe@users.noreply.github.com>
Date: Sat, 11 Mar 2023 14:24:13 +0800
Subject: [PATCH] 8bit Optimisers and resolved potential Env issues

The change to setup.py resolves dependency issues on my machine: certain
features are deprecated in later versions of lightning and torchmetrics,
while the pinned versions still provide them (albeit with deprecation
warnings).

Also, regarding 8bit: replacing nn.Linear() with its 8bit equivalent seems
to break tensors, so I've opted to implement only the optimisers.
---
 .../trainers/maskgit_trainer.py | 14 +++++++++--
 .../trainers/vqvae_trainers.py  | 21 ++++++++++++----
 setup.py                        |  2 ++
 train_muse_maskgit.py           |  6 +++++
 train_muse_vae.py               | 25 +++++++++++++------
 5 files changed, 54 insertions(+), 14 deletions(-)

diff --git a/muse_maskgit_pytorch/trainers/maskgit_trainer.py b/muse_maskgit_pytorch/trainers/maskgit_trainer.py
index eb60094..54655cf 100644
--- a/muse_maskgit_pytorch/trainers/maskgit_trainer.py
+++ b/muse_maskgit_pytorch/trainers/maskgit_trainer.py
@@ -14,6 +14,7 @@
 from torchvision.utils import make_grid, save_image
 from PIL import Image
 
 from muse_maskgit_pytorch.vqgan_vae import VQGanVAE
+import bitsandbytes as bnb
 
 from einops import rearrange
@@ -70,6 +71,7 @@ def __init__(
         only_save_last_checkpoint=False,
         optimizer="Lion",
         weight_decay=0.0,
+        use_8bit_adam=False
     ):
         super().__init__(
             dataloader,
@@ -104,11 +106,19 @@ def __init__(
 
         # optimizers
 
         if optimizer == "Adam":
-            self.optim = Adam(transformer_parameters, lr=lr, weight_decay=weight_decay)
+            if use_8bit_adam:
+                self.optim = bnb.optim.Adam8bit(transformer_parameters, lr=lr)
+            else:
+                self.optim = Adam(transformer_parameters, lr=lr, weight_decay=weight_decay)
         elif optimizer == "AdamW":
-            self.optim = AdamW(transformer_parameters, lr=lr, weight_decay=weight_decay)
+            if use_8bit_adam:
+                self.optim = bnb.optim.AdamW8bit(transformer_parameters, lr=lr)
+            else:
+                self.optim = AdamW(transformer_parameters, lr=lr, weight_decay=weight_decay)
         elif optimizer == "Lion":
             self.optim = Lion(transformer_parameters, lr=lr, weight_decay=weight_decay)
+            if use_8bit_adam:
+                print("8bit is not supported with the Lion optimiser; using standard Lion instead.")
         else:
             print(f"{optimizer} optimizer not supported yet.")
diff --git a/muse_maskgit_pytorch/trainers/vqvae_trainers.py b/muse_maskgit_pytorch/trainers/vqvae_trainers.py
index 6cc681d..b4c46c5 100644
--- a/muse_maskgit_pytorch/trainers/vqvae_trainers.py
+++ b/muse_maskgit_pytorch/trainers/vqvae_trainers.py
@@ -13,7 +13,7 @@
 from torch.utils.data import DataLoader, random_split
 from torch.utils.tensorboard import SummaryWriter
 from torchvision.utils import make_grid, save_image
-
+import bitsandbytes as bnb
 from muse_maskgit_pytorch.vqgan_vae import VQGanVAE
 
 
@@ -74,6 +74,7 @@ def __init__(
         only_save_last_checkpoint=False,
         optimizer='Adam',
         weight_decay=0.0,
+        use_8bit_adam=False
     ):
         super().__init__(
             dataloader,
@@ -102,14 +103,24 @@
 
         # optimizers
 
         if optimizer == 'Adam':
-            self.optim = Adam(vae_parameters, lr=lr, weight_decay=weight_decay)
-            self.discr_optim = Adam(discr_parameters, lr=lr, weight_decay=weight_decay)
+            if use_8bit_adam:
+                self.optim = bnb.optim.Adam8bit(vae_parameters, lr=lr)
+                self.discr_optim = bnb.optim.Adam8bit(discr_parameters, lr=lr)
+            else:
+                self.optim = Adam(vae_parameters, lr=lr)
+                self.discr_optim = Adam(discr_parameters, lr=lr)
         elif optimizer == 'AdamW':
-            self.optim = AdamW(vae_parameters, lr=lr, weight_decay=weight_decay)
-            self.discr_optim = AdamW(discr_parameters, lr=lr)
+            if use_8bit_adam:
+                self.optim = bnb.optim.AdamW8bit(vae_parameters, lr=lr)
+                self.discr_optim = bnb.optim.AdamW8bit(discr_parameters, lr=lr)
+            else:
+                self.optim = AdamW(vae_parameters, lr=lr)
+                self.discr_optim = AdamW(discr_parameters, lr=lr)
         elif optimizer == 'Lion':
             self.optim = Lion(vae_parameters, lr=lr, weight_decay=weight_decay)
             self.discr_optim = Lion(discr_parameters, lr=lr, weight_decay=weight_decay)
+            if use_8bit_adam:
+                print("8bit is not supported with the Lion optimiser; using standard Lion instead.")
         else:
             print(f"{optimizer} optimizer not supported yet.")
diff --git a/setup.py b/setup.py
index 632e784..92e81e3 100644
--- a/setup.py
+++ b/setup.py
@@ -28,6 +28,8 @@
         "pillow",
         "sentencepiece",
         "torch>=1.6",
+        "torchmetrics<0.8.0",
+        "pytorch-lightning<=1.7.7",
         "taming-transformers>=0.0.1",
         "transformers",
         "torch>=1.6",
diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py
index 5bd488a..1c947c3 100644
--- a/train_muse_maskgit.py
+++ b/train_muse_maskgit.py
@@ -140,6 +140,11 @@ def parse_args():
         choices=["no", "fp16", "bf16"],
         help="Precision to train on.",
     )
+    parser.add_argument(
+        "--use_8bit_adam",
+        action="store_true",
+        help="Whether to use the 8bit Adam optimiser",
+    )
     parser.add_argument(
         "--results_dir",
         type=str,
@@ -384,6 +389,7 @@ def main():
         only_save_last_checkpoint=args.only_save_last_checkpoint,
         optimizer=args.optimizer,
         weight_decay=args.weight_decay,
+        use_8bit_adam=args.use_8bit_adam
     )
 
     trainer.train()
diff --git a/train_muse_vae.py b/train_muse_vae.py
index 207ee7e..050f1bf 100644
--- a/train_muse_vae.py
+++ b/train_muse_vae.py
@@ -13,6 +13,9 @@
 
 import argparse
 
+import torch.nn as nn
+import bitsandbytes as bnb
+from accelerate import init_empty_weights
 
 def parse_args():
     # Create the parser
@@ -110,6 +113,11 @@
         choices=["no", "fp16", "bf16"],
         help="Precision to train on.",
     )
+    parser.add_argument(
+        "--use_8bit_adam",
+        action="store_true",
+        help="Whether to use the 8bit Adam optimiser",
+    )
     parser.add_argument(
         "--results_dir",
         type=str,
@@ -188,15 +196,17 @@
         default=None,
         help="Path to the last saved checkpoint. 'results/vae.steps.pt'",
     )
 
-    parser.add_argument("--optimizer",type=str,
+    parser.add_argument(
+        "--optimizer", type=str,
         default='Lion',
         help="Optimizer to use. Choose between: ['Adam', 'AdamW','Lion']. Default: Adam",
-        )
-    parser.add_argument("--weight_decay", type=float,
-        default=0.0,
-        help="Optimizer weight_decay to use. Default: 0.0",
-        )
-
+    )
+    parser.add_argument(
+        "--weight_decay", type=float,
+        default=0.0,
+        help="Optimizer weight_decay to use. Default: 0.0",
+    )
+
     # Parse the argument
     return parser.parse_args()
@@ -288,6 +298,7 @@ def main():
         validation_image_scale=args.validation_image_scale,
         only_save_last_checkpoint=args.only_save_last_checkpoint,
         optimizer=args.optimizer,
+        use_8bit_adam=args.use_8bit_adam
     )
 
     trainer.train()
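
With the patch applied, the trainers pick the 8bit optimiser whenever --use_8bit_adam
is passed. The selection boils down to the pattern below; this is a minimal sketch,
assuming bitsandbytes is installed, and the model and parameter names are placeholders
rather than the actual trainer code:

    import torch
    import bitsandbytes as bnb
    from torch.optim import Adam

    # Placeholder module standing in for the transformer / VAE parameters.
    model = torch.nn.Linear(128, 128)

    use_8bit_adam = True   # mirrors the new --use_8bit_adam flag
    lr = 1e-4
    weight_decay = 0.0

    if use_8bit_adam:
        # 8bit optimiser state from bitsandbytes; weight_decay is omitted here,
        # matching the patch.
        optim = bnb.optim.Adam8bit(model.parameters(), lr=lr)
    else:
        optim = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

Only the optimiser state is quantised; the model's nn.Linear layers stay in full
precision, in line with the note above that swapping them for their 8bit
equivalents broke tensors.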