Merge pull request Sygil-Dev#5 from ZeroCool940711/adding_training_script

Added support for learning rate schedulers and warmup steps.
isamu-isozaki authored Feb 27, 2023
2 parents 9d80f56 + afe92c0 commit 7422a01
Showing 2 changed files with 34 additions and 11 deletions.
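
For reference, a minimal sketch of how the scheduler introduced by this commit can be constructed with `get_scheduler` from `diffusers.optimization`; the concrete values (learning rate, warmup steps, step counts) are illustrative placeholders, not taken from the repository:

```python
import torch
from diffusers.optimization import get_scheduler

# Illustrative placeholder values (not the repository's defaults).
lr = 3e-4
lr_warmup_steps = 500
num_train_steps = 10_000
gradient_accumulation_steps = 1

model = torch.nn.Linear(8, 8)
optim = torch.optim.Adam(model.parameters(), lr=lr)

# Scheduler name can be any of: "linear", "cosine", "cosine_with_restarts",
# "polynomial", "constant", "constant_with_warmup".
lr_scheduler = get_scheduler(
    "constant_with_warmup",
    optimizer=optim,
    num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
    num_training_steps=num_train_steps * gradient_accumulation_steps,
)
```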
30 changes: 27 additions & 3 deletions muse_maskgit_pytorch/trainers/vqvae_trainers.py
@@ -1,6 +1,7 @@

from pathlib import Path
from shutil import rmtree
from datetime import datetime

from beartype import beartype
from PIL import Image
@@ -11,6 +12,7 @@
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid, save_image


from muse_maskgit_pytorch.vqgan_vae import VQGanVAE

from einops import rearrange
@@ -20,6 +22,8 @@
from ema_pytorch import EMA
import numpy as np
from muse_maskgit_pytorch.trainers.base_accelerated_trainer import BaseAcceleratedTrainer
from diffusers.optimization import get_scheduler

def noop(*args, **kwargs):
pass

@@ -49,6 +53,8 @@ def __init__(
logging_dir="./results/logs",
apply_grad_penalty_every=4,
lr=3e-4,
lr_scheduler_type='constant',
lr_warmup_steps= 500,
discr_max_grad_norm=None,
use_ema=True,
ema_beta=0.995,
@@ -71,6 +77,20 @@ def __init__(
# optimizers
self.optim = Adam(vae_parameters, lr=lr)
self.discr_optim = Adam(discr_parameters, lr=lr)

self.lr_scheduler = get_scheduler(
lr_scheduler_type,
optimizer=self.optim,
num_warmup_steps=lr_warmup_steps * self.gradient_accumulation_steps,
num_training_steps=self.num_train_steps * self.gradient_accumulation_steps,
)

self.lr_scheduler_discr = get_scheduler(
lr_scheduler_type,
optimizer=self.discr_optim,
num_warmup_steps=lr_warmup_steps * self.gradient_accumulation_steps,
num_training_steps=self.num_train_steps * self.gradient_accumulation_steps,
)

self.discr_max_grad_norm = discr_max_grad_norm

@@ -169,6 +189,9 @@ def train_step(self):

if exists(self.max_grad_norm):
self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

self.lr_scheduler.step()
self.lr_scheduler_discr.step()
self.optim.step()
self.optim.zero_grad()

@@ -192,9 +215,10 @@ def train_step(self):

self.discr_optim.step()

# log

# self.print(f"{steps}: vae loss: {logs['Train/vae_loss']} - discr loss: {logs['Train/discr_loss']}")
# log

self.print(f"{steps}: vae loss: {logs['Train/vae_loss']} - discr loss: {logs['Train/discr_loss']} - lr: {self.lr_scheduler.get_last_lr()[0]}")
logs['lr'] = self.lr_scheduler.get_last_lr()[0]
self.accelerator.log(logs, step=steps)

# update exponential moving averaged generator
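
The hunks above step both schedulers once per training step and log the current learning rate. A generic sketch of that pattern outside the trainer (plain PyTorch; note that PyTorch's convention is to call `optimizer.step()` before `scheduler.step()`, and the loss here is a stand-in, not the VAE/discriminator losses used in the repository):

```python
import torch
from diffusers.optimization import get_scheduler

model = torch.nn.Linear(8, 8)
optim = torch.optim.Adam(model.parameters(), lr=3e-4)
lr_scheduler = get_scheduler(
    "linear", optimizer=optim, num_warmup_steps=10, num_training_steps=100
)

for step in range(100):
    loss = model(torch.randn(4, 8)).pow(2).mean()  # stand-in loss
    loss.backward()
    optim.step()
    lr_scheduler.step()  # advance the schedule once per optimizer step
    optim.zero_grad()
    if step % 10 == 0:
        # get_last_lr() returns one learning rate per parameter group.
        print(f"{step}: loss {loss.item():.4f} - lr {lr_scheduler.get_last_lr()[0]:.2e}")
```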
15 changes: 7 additions & 8 deletions train_muse_vae.py
@@ -88,12 +88,6 @@ def parse_args():
)

# vae_trainer args
parser.add_argument(
"--resume_from",
type=str,
default="",
help="Path to the vae model. eg. 'results/vae.steps.pt'",
)
parser.add_argument(
"--dataset_name",
type=str,
@@ -137,6 +131,8 @@ def parse_args():
default=256,
help="Image size. You may want to start with small images, and then curriculum learn to larger ones, but because the vae is all convolution, it should generalize to 512 (as in paper) without training on it",
)
parser.add_argument("--lr_scheduler", type=str, default="constant", help='The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]')
parser.add_argument("--lr_warmup_steps", type=int, default=0, help='Number of steps for the warmup in the lr scheduler.')
parser.add_argument(
"--resume_path",
type=str,
@@ -156,6 +152,7 @@ def main():
elif args.dataset_name:
dataset = load_dataset(args.dataset_name)["train"]
vae = VQGanVAE(dim=args.dim, vq_codebook_size=args.vq_codebook_size)

if args.resume_path:
print (f'Resuming VAE from: {args.resume_path}')
vae.load(args.resume_path)
@@ -177,9 +174,11 @@ def main():
dataloader,
validation_dataloader,
accelerator,
current_step=0,
current_step=current_step,
num_train_steps=args.num_train_steps,
lr=args.lr,
lr_scheduler = args.lr_scheduler,
lr_warmup_steps = args.lr_warmup_steps,
max_grad_norm=args.max_grad_norm,
discr_max_grad_norm=args.discr_max_grad_norm,
save_results_every=args.save_results_every,
@@ -200,4 +199,4 @@ def main():


if __name__ == "__main__":
main()
main()
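
With these changes, the new options can presumably be passed on the command line, e.g. `python train_muse_vae.py --lr 3e-4 --lr_scheduler cosine --lr_warmup_steps 500`; this invocation is an illustrative guess based on the argparse definitions above, not a documented example from the repository.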
