diff --git a/muse_maskgit_pytorch/trainers/maskgit_trainer.py b/muse_maskgit_pytorch/trainers/maskgit_trainer.py
index eb60094..54655cf 100644
--- a/muse_maskgit_pytorch/trainers/maskgit_trainer.py
+++ b/muse_maskgit_pytorch/trainers/maskgit_trainer.py
@@ -14,6 +14,7 @@
 from torchvision.utils import make_grid, save_image
 from PIL import Image
 from muse_maskgit_pytorch.vqgan_vae import VQGanVAE
+import bitsandbytes as bnb
 
 from einops import rearrange
 
@@ -70,6 +71,7 @@ def __init__(
         only_save_last_checkpoint=False,
         optimizer="Lion",
         weight_decay=0.0,
+        use_8bit_adam=False
     ):
         super().__init__(
             dataloader,
@@ -104,11 +106,19 @@
 
         # optimizers
         if optimizer == "Adam":
-            self.optim = Adam(transformer_parameters, lr=lr, weight_decay=weight_decay)
+            if use_8bit_adam:
+                self.optim = bnb.optim.Adam8bit(transformer_parameters, lr=lr)
+            else:
+                self.optim = Adam(transformer_parameters, lr=lr, weight_decay=weight_decay)
         elif optimizer == "AdamW":
-            self.optim = AdamW(transformer_parameters, lr=lr, weight_decay=weight_decay)
+            if use_8bit_adam:
+                self.optim = bnb.optim.AdamW8bit(transformer_parameters, lr=lr)
+            else:
+                self.optim = AdamW(transformer_parameters, lr=lr, weight_decay=weight_decay)
         elif optimizer == "Lion":
             self.optim = Lion(transformer_parameters, lr=lr, weight_decay=weight_decay)
+            if use_8bit_adam:
+                print("8bit is not supported with the Lion optimiser. Using standard Lion instead.")
         else:
             print(f"{optimizer} optimizer not supported yet.")
 
diff --git a/muse_maskgit_pytorch/trainers/vqvae_trainers.py b/muse_maskgit_pytorch/trainers/vqvae_trainers.py
index 6cc681d..b4c46c5 100644
--- a/muse_maskgit_pytorch/trainers/vqvae_trainers.py
+++ b/muse_maskgit_pytorch/trainers/vqvae_trainers.py
@@ -13,7 +13,7 @@
 from torch.utils.data import DataLoader, random_split
 from torch.utils.tensorboard import SummaryWriter
 from torchvision.utils import make_grid, save_image
-
+import bitsandbytes as bnb
 from muse_maskgit_pytorch.vqgan_vae import VQGanVAE
 
 
@@ -74,6 +74,7 @@ def __init__(
         only_save_last_checkpoint=False,
         optimizer='Adam',
         weight_decay=0.0,
+        use_8bit_adam=False
     ):
         super().__init__(
             dataloader,
@@ -102,14 +103,24 @@
 
         # optimizers
        if optimizer == 'Adam':
-            self.optim = Adam(vae_parameters, lr=lr, weight_decay=weight_decay)
-            self.discr_optim = Adam(discr_parameters, lr=lr, weight_decay=weight_decay)
+            if use_8bit_adam:
+                self.optim = bnb.optim.Adam8bit(vae_parameters, lr=lr)
+                self.discr_optim = bnb.optim.Adam8bit(discr_parameters, lr=lr)
+            else:
+                self.optim = Adam(vae_parameters, lr=lr, weight_decay=weight_decay)
+                self.discr_optim = Adam(discr_parameters, lr=lr, weight_decay=weight_decay)
         elif optimizer == 'AdamW':
-            self.optim = AdamW(vae_parameters, lr=lr, weight_decay=weight_decay)
-            self.discr_optim = AdamW(discr_parameters, lr=lr)
+            if use_8bit_adam:
+                self.optim = bnb.optim.AdamW8bit(vae_parameters, lr=lr)
+                self.discr_optim = bnb.optim.AdamW8bit(discr_parameters, lr=lr)
+            else:
+                self.optim = AdamW(vae_parameters, lr=lr, weight_decay=weight_decay)
+                self.discr_optim = AdamW(discr_parameters, lr=lr)
         elif optimizer == 'Lion':
             self.optim = Lion(vae_parameters, lr=lr, weight_decay=weight_decay)
             self.discr_optim = Lion(discr_parameters, lr=lr, weight_decay=weight_decay)
+            if use_8bit_adam:
+                print("8bit is not supported with the Lion optimiser. Using standard Lion instead.")
         else:
             print(f"{optimizer} optimizer not supported yet.")
 
diff --git a/setup.py b/setup.py
index 632e784..92e81e3 100644
--- a/setup.py
+++ b/setup.py
@@ -28,6 +28,8 @@
     "pillow",
     "sentencepiece",
     "torch>=1.6",
+    "torchmetrics<0.8.0",
+    "pytorch-lightning<=1.7.7",
"taming-transformers>=0.0.1", "transformers", "torch>=1.6", diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py index 5bd488a..1c947c3 100644 --- a/train_muse_maskgit.py +++ b/train_muse_maskgit.py @@ -140,6 +140,11 @@ def parse_args(): choices=["no", "fp16", "bf16"], help="Precision to train on.", ) + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether to use the 8bit adam optimiser", + ) parser.add_argument( "--results_dir", type=str, @@ -384,6 +389,7 @@ def main(): only_save_last_checkpoint=args.only_save_last_checkpoint, optimizer=args.optimizer, weight_decay=args.weight_decay, + use_8bit_adam=args.use_8bit_adam ) trainer.train() diff --git a/train_muse_vae.py b/train_muse_vae.py index 207ee7e..050f1bf 100644 --- a/train_muse_vae.py +++ b/train_muse_vae.py @@ -13,6 +13,9 @@ import argparse +import torch.nn as nn +import bitsandbytes as bnb +from accelerate import init_empty_weights def parse_args(): # Create the parser @@ -110,6 +113,11 @@ def parse_args(): choices=["no", "fp16", "bf16"], help="Precision to train on.", ) + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether to use the 8bit adam optimiser", + ) parser.add_argument( "--results_dir", type=str, @@ -188,15 +196,17 @@ def parse_args(): default=None, help="Path to the last saved checkpoint. 'results/vae.steps.pt'", ) - parser.add_argument("--optimizer",type=str, + parser.add_argument( + "--optimizer",type=str, default='Lion', help="Optimizer to use. Choose between: ['Adam', 'AdamW','Lion']. Default: Adam", - ) - parser.add_argument("--weight_decay", type=float, - default=0.0, - help="Optimizer weight_decay to use. Default: 0.0", - ) - + ) + parser.add_argument( + "--weight_decay", type=float, + default=0.0, + help="Optimizer weight_decay to use. Default: 0.0", + ) + # Parse the argument return parser.parse_args() @@ -288,6 +298,7 @@ def main(): validation_image_scale=args.validation_image_scale, only_save_last_checkpoint=args.only_save_last_checkpoint, optimizer=args.optimizer, + use_8bit_adam=args.use_8bit_adam ) trainer.train()