diff --git a/muse_maskgit_pytorch/trainers/maskgit_trainer.py b/muse_maskgit_pytorch/trainers/maskgit_trainer.py
index 21dca82..eb60094 100644
--- a/muse_maskgit_pytorch/trainers/maskgit_trainer.py
+++ b/muse_maskgit_pytorch/trainers/maskgit_trainer.py
@@ -69,6 +69,7 @@ def __init__(
         validation_image_scale=1,
         only_save_last_checkpoint=False,
         optimizer="Lion",
+        weight_decay=0.0,
     ):
         super().__init__(
             dataloader,
@@ -103,11 +104,11 @@ def __init__(
 
         # optimizers
         if optimizer == "Adam":
-            self.optim = Adam(transformer_parameters, lr=lr)
+            self.optim = Adam(transformer_parameters, lr=lr, weight_decay=weight_decay)
         elif optimizer == "AdamW":
-            self.optim = Adam(transformer_parameters, lr=lr)
+            self.optim = AdamW(transformer_parameters, lr=lr, weight_decay=weight_decay)
         elif optimizer == "Lion":
-            self.optim = Lion(transformer_parameters, lr=lr)
+            self.optim = Lion(transformer_parameters, lr=lr, weight_decay=weight_decay)
         else:
             print(f"{optimizer} optimizer not supported yet.")
 
diff --git a/muse_maskgit_pytorch/trainers/vqvae_trainers.py b/muse_maskgit_pytorch/trainers/vqvae_trainers.py
index b2dfc26..6cc681d 100644
--- a/muse_maskgit_pytorch/trainers/vqvae_trainers.py
+++ b/muse_maskgit_pytorch/trainers/vqvae_trainers.py
@@ -72,7 +72,8 @@ def __init__(
         clear_previous_experiments=False,
         validation_image_scale=1,
         only_save_last_checkpoint=False,
         optimizer="Adam",
+        weight_decay=0.0,
     ):
         super().__init__(
             dataloader,
@@ -100,15 +101,15 @@ def __init__(
         vae_parameters = all_parameters - discr_parameters
 
         # optimizers
         if optimizer == "Adam":
-            self.optim = Adam(vae_parameters, lr=lr)
-            self.discr_optim = Adam(discr_parameters, lr=lr)
+            self.optim = Adam(vae_parameters, lr=lr, weight_decay=weight_decay)
+            self.discr_optim = Adam(discr_parameters, lr=lr, weight_decay=weight_decay)
         elif optimizer == "AdamW":
-            self.optim = AdamW(vae_parameters, lr=lr)
-            self.discr_optim = AdamW(discr_parameters, lr=lr)
+            self.optim = AdamW(vae_parameters, lr=lr, weight_decay=weight_decay)
+            self.discr_optim = AdamW(discr_parameters, lr=lr, weight_decay=weight_decay)
         elif optimizer == "Lion":
-            self.optim = Lion(vae_parameters, lr=lr)
-            self.discr_optim = Lion(discr_parameters, lr=lr)
+            self.optim = Lion(vae_parameters, lr=lr, weight_decay=weight_decay)
+            self.discr_optim = Lion(discr_parameters, lr=lr, weight_decay=weight_decay)
         else:
             print(f"{optimizer} optimizer not supported yet.")
 
diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py
index ac96b2a..5bd488a 100644
--- a/train_muse_maskgit.py
+++ b/train_muse_maskgit.py
@@ -249,6 +249,12 @@ def parse_args():
         default="Lion",
         help="Optimizer to use. Choose between: ['Adam', 'AdamW','Lion']. Default: Lion",
     )
+    parser.add_argument(
+        "--weight_decay",
+        type=float,
+        default=0.0,
+        help="Optimizer weight_decay to use. Default: 0.0",
+    )
 
     # Parse the argument
     return parser.parse_args()
@@ -379,6 +383,7 @@ def main():
         validation_image_scale=args.validation_image_scale,
         only_save_last_checkpoint=args.only_save_last_checkpoint,
         optimizer=args.optimizer,
+        weight_decay=args.weight_decay,
     )
 
     trainer.train()
diff --git a/train_muse_vae.py b/train_muse_vae.py
index cbbe80e..207ee7e 100644
--- a/train_muse_vae.py
+++ b/train_muse_vae.py
@@ -188,12 +188,18 @@ def parse_args():
         default=None,
         help="Path to the last saved checkpoint. 'results/vae.steps.pt'",
     )
     parser.add_argument(
         "--optimizer",
         type=str,
         default="Lion",
         help="Optimizer to use. Choose between: ['Adam', 'AdamW','Lion']. Default: Lion",
     )
+    parser.add_argument(
+        "--weight_decay",
+        type=float,
+        default=0.0,
+        help="Optimizer weight_decay to use. Default: 0.0",
+    )
 
     # Parse the argument
     return parser.parse_args()