
Commit

Merge pull request Sygil-Dev#14 from ZeroCool940711/adding_training_script

Added support for weight_decay on the different optimizers.
isamu-isozaki authored Mar 10, 2023
2 parents d846fd9 + daf6842 commit de94178
Showing 4 changed files with 29 additions and 20 deletions.
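
For context, every optimizer branch in both trainers now forwards a weight_decay keyword to the optimizer constructor. Below is a minimal sketch of that shared dispatch pattern, not taken from the commit: it assumes Adam and AdamW come from torch.optim and Lion from the lion-pytorch package (all three accept a weight_decay keyword), and the build_optimizer name is illustrative.

# Sketch of the dispatch pattern (assumptions noted above; the actual
# trainers only print a warning for unknown optimizers, raising is stricter).
from torch.optim import Adam, AdamW
from lion_pytorch import Lion

def build_optimizer(parameters, optimizer="Lion", lr=1e-4, weight_decay=0.0):
    if optimizer == "Adam":
        # Adam applies weight decay as an L2 penalty added to the gradient.
        return Adam(parameters, lr=lr, weight_decay=weight_decay)
    elif optimizer == "AdamW":
        # AdamW applies decoupled weight decay directly to the weights.
        return AdamW(parameters, lr=lr, weight_decay=weight_decay)
    elif optimizer == "Lion":
        # lion-pytorch's Lion also exposes a weight_decay keyword.
        return Lion(parameters, lr=lr, weight_decay=weight_decay)
    raise ValueError(f"{optimizer} optimizer not supported yet.")

One caveat worth keeping in mind: Adam and AdamW interpret weight_decay differently (L2 penalty versus decoupled decay), so the same numeric value will not regularize identically across the two.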
muse_maskgit_pytorch/trainers/maskgit_trainer.py (7 changes: 4 additions & 3 deletions)
@@ -69,6 +69,7 @@ def __init__(
         validation_image_scale=1,
         only_save_last_checkpoint=False,
         optimizer="Lion",
+        weight_decay=0.0,
     ):
         super().__init__(
             dataloader,
@@ -103,11 +104,11 @@ def __init__(
 
         # optimizers
         if optimizer == "Adam":
-            self.optim = Adam(transformer_parameters, lr=lr)
+            self.optim = Adam(transformer_parameters, lr=lr, weight_decay=weight_decay)
         elif optimizer == "AdamW":
-            self.optim = Adam(transformer_parameters, lr=lr)
+            self.optim = AdamW(transformer_parameters, lr=lr, weight_decay=weight_decay)
         elif optimizer == "Lion":
-            self.optim = Lion(transformer_parameters, lr=lr)
+            self.optim = Lion(transformer_parameters, lr=lr, weight_decay=weight_decay)
         else:
             print(f"{optimizer} optimizer not supported yet.")

muse_maskgit_pytorch/trainers/vqvae_trainers.py (21 changes: 11 additions & 10 deletions)
@@ -72,7 +72,8 @@ def __init__(
         clear_previous_experiments=False,
         validation_image_scale=1,
         only_save_last_checkpoint=False,
-        optimizer="Adam",
+        optimizer='Adam',
+        weight_decay=0.0,
     ):
         super().__init__(
             dataloader,
@@ -100,15 +101,15 @@ def __init__(
         vae_parameters = all_parameters - discr_parameters
 
         # optimizers
-        if optimizer == "Adam":
-            self.optim = Adam(vae_parameters, lr=lr)
-            self.discr_optim = Adam(discr_parameters, lr=lr)
-        elif optimizer == "AdamW":
-            self.optim = AdamW(vae_parameters, lr=lr)
-            self.discr_optim = AdamW(discr_parameters, lr=lr)
-        elif optimizer == "Lion":
-            self.optim = Lion(vae_parameters, lr=lr)
-            self.discr_optim = Lion(discr_parameters, lr=lr)
+        if optimizer == 'Adam':
+            self.optim = Adam(vae_parameters, lr=lr, weight_decay=weight_decay)
+            self.discr_optim = Adam(discr_parameters, lr=lr, weight_decay=weight_decay)
+        elif optimizer == 'AdamW':
+            self.optim = AdamW(vae_parameters, lr=lr, weight_decay=weight_decay)
+            self.discr_optim = AdamW(discr_parameters, lr=lr)
+        elif optimizer == 'Lion':
+            self.optim = Lion(vae_parameters, lr=lr, weight_decay=weight_decay)
+            self.discr_optim = Lion(discr_parameters, lr=lr, weight_decay=weight_decay)
         else:
             print(f"{optimizer} optimizer not supported yet.")

train_muse_maskgit.py (5 changes: 5 additions & 0 deletions)
@@ -249,6 +249,10 @@ def parse_args():
         default="Lion",
         help="Optimizer to use. Choose between: ['Adam', 'AdamW','Lion']. Default: Lion",
     )
+    parser.add_argument("--weight_decay", type=float,
+        default=0.0,
+        help="Optimizer weight_decay to use. Default: 0.0",
+    )
     # Parse the argument
     return parser.parse_args()

@@ -379,6 +383,7 @@ def main():
         validation_image_scale=args.validation_image_scale,
         only_save_last_checkpoint=args.only_save_last_checkpoint,
         optimizer=args.optimizer,
+        weight_decay=args.weight_decay,
     )
 
     trainer.train()
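With the argument wired through to the trainer, weight decay can now be set from the command line. A hypothetical invocation (the flag names come from the diff above; the values are illustrative):

    python train_muse_maskgit.py --optimizer AdamW --weight_decay 0.01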
train_muse_vae.py (16 changes: 9 additions & 7 deletions)
@@ -188,13 +188,15 @@ def parse_args():
         default=None,
         help="Path to the last saved checkpoint. 'results/vae.steps.pt'",
     )
-    parser.add_argument(
-        "--optimizer",
-        type=str,
-        default="Lion",
-        help="Optimizer to use. Choose between: ['Adam', 'AdamW','Lion']. Default: Lion",
-    )
-
+    parser.add_argument("--optimizer",type=str,
+        default='Lion',
+        help="Optimizer to use. Choose between: ['Adam', 'AdamW','Lion']. Default: Adam",
+    )
+    parser.add_argument("--weight_decay", type=float,
+        default=0.0,
+        help="Optimizer weight_decay to use. Default: 0.0",
+    )
+
     # Parse the argument
     return parser.parse_args()

