diff --git a/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py b/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py index a8a020a..9447306 100644 --- a/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py +++ b/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py @@ -17,6 +17,7 @@ from ema_pytorch import EMA from torch.optim import Adam, AdamW from lion_pytorch import Lion +from torch_optimizer import Adafactor import numpy as np @@ -116,11 +117,10 @@ def get_optimizer(use_8bit_adam, optimizer, parameters, lr, weight_decay): optim = bnb.optim.AdamW8bit(parameters, lr=lr, weight_decay=weight_decay) else: optim = AdamW(parameters, lr=lr, weight_decay=weight_decay) - elif optimizer == "Lion": optim = Lion(parameters, lr=lr, weight_decay=weight_decay) if use_8bit_adam: - print("8bit is not supported by the Lion optimiser, Using standard Lion instead.") + print("8bit is not supported by the Lion optimizer, Using standard Lion instead.") else: raise NotImplementedError(f"{optimizer} optimizer not supported yet.") return optim diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py index 2a00390..f245d2f 100644 --- a/train_muse_maskgit.py +++ b/train_muse_maskgit.py @@ -171,6 +171,11 @@ def parse_args(): default=None, help="Name of the huggingface dataset used.", ) + parser.add_argument( + "--streaming", + action="store_true", + help="Whether to stream the huggingface dataset", + ) parser.add_argument( "--train_data_dir", type=str, @@ -285,7 +290,7 @@ def main(): save_path=args.dataset_save_path, ) elif args.dataset_name: - dataset = load_dataset(args.dataset_name)["train"] + dataset = load_dataset(args.dataset_name, streaming=args.streaming)["train"] if args.vae_path and args.taming_model_path: raise Exception("You can't pass vae_path and taming args at the same time.") diff --git a/train_muse_vae.py b/train_muse_vae.py index 631ef80..1924c8c 100644 --- a/train_muse_vae.py +++ b/train_muse_vae.py @@ -142,6 +142,11 @@ def parse_args(): default=None, help="Name of the huggingface dataset used.", ) + parser.add_argument( + "--streaming", + action="store_true", + help="Whether to stream the huggingface dataset", + ) parser.add_argument( "--train_data_dir", type=str, @@ -200,11 +205,6 @@ def parse_args(): default=None, help="Path to the last saved checkpoint. 'results/vae.steps.pt'", ) - parser.add_argument( - "--optimizer",type=str, - default='Lion', - help="Optimizer to use. Choose between: ['Adam', 'AdamW','Lion']. Default: Adam", - ) parser.add_argument( "--weight_decay", type=float, default=0.0, @@ -267,7 +267,7 @@ def main(): save_path=args.dataset_save_path, ) elif args.dataset_name: - dataset = load_dataset(args.dataset_name)["train"] + dataset = load_dataset(args.dataset_name, streaming=args.streaming)["train"] vae = VQGanVAE(dim=args.dim, vq_codebook_size=args.vq_codebook_size) if args.taming_model_path: