From 05d3257337368d8b6067f790927263bb8ad555ca Mon Sep 17 00:00:00 2001
From: Alejandro Gil
Date: Sat, 9 Sep 2023 01:44:03 -0700
Subject: [PATCH] Continued to add support for the EMA model.

---
 .../trainers/base_accelerated_trainer.py | 25 +++++-----
 .../trainers/vqvae_trainers.py           |  9 +++-
 train_muse_maskgit.py                    | 13 +++--
 train_muse_vae.py                        | 48 ++++++++++---------
 4 files changed, 55 insertions(+), 40 deletions(-)

diff --git a/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py b/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py
index 9bdce76..09d07bb 100644
--- a/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py
+++ b/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py
@@ -10,7 +10,6 @@
 from beartype import beartype
 from datasets import Dataset
 from lion_pytorch import Lion
-from PIL import Image
 from torch import nn
 from torch.optim import Adam, AdamW, Optimizer
 from torch.utils.data import DataLoader, random_split
@@ -289,21 +288,21 @@ def load(self, path: Union[str, PathLike]):
         return pkg
 
     def log_validation_images(self, images, step, prompts=None):
-        if self.validation_image_scale != 1:
-            # Calculate the new height based on the scale factor
-            new_height = int(np.array(images[0]).shape[0] * self.validation_image_scale)
+        #if self.validation_image_scale > 1:
+            ## Calculate the new height based on the scale factor
+            #new_height = int(np.array(images[0]).shape[0] * self.validation_image_scale)
 
-            # Calculate the aspect ratio of the original image
-            aspect_ratio = np.array(images[0]).shape[1] / np.array(images[0]).shape[0]
+            ## Calculate the aspect ratio of the original image
+            #aspect_ratio = np.array(images[0]).shape[1] / np.array(images[0]).shape[0]
 
-            # Calculate the new width based on the new height and aspect ratio
-            new_width = int(new_height * aspect_ratio)
+            ## Calculate the new width based on the new height and aspect ratio
+            #new_width = int(new_height * aspect_ratio)
 
-            # Resize the images using the new width and height
-            output_size = (new_width, new_height)
-            images_pil = [Image.fromarray(np.array(image)) for image in images]
-            images_pil_resized = [image_pil.resize(output_size) for image_pil in images_pil]
-            images = [np.array(image_pil) for image_pil in images_pil_resized]
+            ## Resize the images using the new width and height
+            #output_size = (new_width, new_height)
+            #images_pil = [Image.fromarray(np.array(image)) for image in images]
+            #images_pil_resized = [image_pil.resize(output_size) for image_pil in images_pil]
+            #images = [np.array(image_pil) for image_pil in images_pil_resized]
         for tracker in self.accelerator.trackers:
             if tracker.name == "tensorboard":
diff --git a/muse_maskgit_pytorch/trainers/vqvae_trainers.py b/muse_maskgit_pytorch/trainers/vqvae_trainers.py
index ea939f0..45a3aa6 100644
--- a/muse_maskgit_pytorch/trainers/vqvae_trainers.py
+++ b/muse_maskgit_pytorch/trainers/vqvae_trainers.py
@@ -1,7 +1,6 @@
 import torch
 from accelerate import Accelerator
 from diffusers.optimization import get_scheduler
-from einops import rearrange
 from ema_pytorch import EMA
 from omegaconf import OmegaConf
 from PIL import Image
@@ -219,6 +218,12 @@ def log_validation_images(self, logs, steps):
             if self.use_ema:
                 ema_grid = make_grid([sample, ema_recon], nrow=2)
 
+            # Scale the images
+            if self.validation_image_scale > 1:
+                grid = torch.nn.functional.interpolate(grid.unsqueeze(0), scale_factor=self.validation_image_scale, mode="bicubic", align_corners=False)
+                if self.use_ema:
+                    ema_grid = torch.nn.functional.interpolate(ema_grid.unsqueeze(0), scale_factor=self.validation_image_scale, mode="bicubic", align_corners=False)
+
             # Save grid
             grid_file = f"{steps}_{i}.png"
             if self.use_ema:
@@ -241,6 +246,8 @@ def log_validation_images(self, logs, steps):
         super().log_validation_images(log_imgs, steps, prompts=prompts)
         self.model.train()
 
+
+
     def train(self):
         self.steps = self.steps + 1
         device = self.device
diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py
index ab5e1b0..0a48b4a 100644
--- a/train_muse_maskgit.py
+++ b/train_muse_maskgit.py
@@ -706,10 +706,15 @@ def main():
         accelerator.print("Loading Muse MaskGit...")
 
         if args.latest_checkpoint:
-            args.resume_path, ema_model_path = get_latest_checkpoints(args.resume_path, use_ema=args.use_ema, model_type="maskgit", cond_image_size=args.cond_image_size)
-            print(f"Resuming MaskGit from latest checkpoint: {args.resume_path}")
-            #if args.use_ema:
-                # print(f"Resuming EMA MaskGit from latest checkpoint: {ema_model_path}")
+            try:
+                args.resume_path, ema_model_path = get_latest_checkpoints(args.resume_path, use_ema=args.use_ema, model_type="maskgit", cond_image_size=args.cond_image_size)
+                print(f"Resuming MaskGit from latest checkpoint: {args.resume_path}")
+                #if args.use_ema:
+                    # print(f"Resuming EMA MaskGit from latest checkpoint: {ema_model_path}")
+
+            except ValueError:
+                load = False
+
         else:
             accelerator.print("Resuming MaskGit from: ", args.resume_path)
 
diff --git a/train_muse_vae.py b/train_muse_vae.py
index 7c3598b..515836b 100644
--- a/train_muse_vae.py
+++ b/train_muse_vae.py
@@ -480,28 +480,30 @@ def main():
         )
 
         if args.latest_checkpoint:
-            args.resume_path, ema_model_path = get_latest_checkpoints(args.resume_path, use_ema=args.use_ema, model_type="vae")
-            if ema_model_path:
-                ema_vae = VQGanVAE(
-                    dim=args.dim,
-                    vq_codebook_dim=args.vq_codebook_dim,
-                    vq_codebook_size=args.vq_codebook_size,
-                    l2_recon_loss=args.use_l2_recon_loss,
-                    channels=args.channels,
-                    layers=args.layers,
-                    discr_layers=args.discr_layers,
-                    accelerator=accelerator,
-                )
-                print(f"Resuming EMA VAE from latest checkpoint: {ema_model_path}")
-
-                ema_vae.load(ema_model_path, map="cpu")
-            else:
-                ema_vae = None
-
-            print(f"Resuming VAE from latest checkpoint: {args.resume_path}")
-        else:
-            accelerator.print("Resuming VAE from: ", args.resume_path)
-            ema_vae = None
+            try:
+                args.resume_path, ema_model_path = get_latest_checkpoints(args.resume_path, use_ema=args.use_ema, model_type="vae")
+
+                if ema_model_path:
+                    ema_vae = VQGanVAE(
+                        dim=args.dim,
+                        vq_codebook_dim=args.vq_codebook_dim,
+                        vq_codebook_size=args.vq_codebook_size,
+                        l2_recon_loss=args.use_l2_recon_loss,
+                        channels=args.channels,
+                        layers=args.layers,
+                        discr_layers=args.discr_layers,
+                        accelerator=accelerator,
+                    )
+                    print(f"Resuming EMA VAE from latest checkpoint: {ema_model_path}")
+
+                    ema_vae.load(ema_model_path, map="cpu")
+                else:
+                    ema_vae = None
+
+                print(f"Resuming VAE from latest checkpoint: {args.resume_path}")
+
+            except ValueError:
+                load = False
 
         if load:
             #vae.load(args.resume_path if not args.use_ema or not ema_model_path else ema_model_path, map="cpu")
@@ -516,6 +518,8 @@
             if current_step == 0:
                 accelerator.print("No step found for the VAE model.")
         else:
+            #accelerator.print("Resuming VAE from: ", args.resume_path)
+            ema_vae = None
             accelerator.print("No step found for the VAE model.")
             current_step = 0
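
Note on the new scaling path in `vqvae_trainers.py`: `torch.nn.functional.interpolate` in `bicubic` mode expects a 4D `(N, C, H, W)` tensor, which is why the grid gets an `unsqueeze(0)` before being resized. Below is a minimal standalone sketch of that step; the dummy tensors, the scale value, and the output file name are placeholders for illustration, not the trainer's actual state.

```python
import torch
import torch.nn.functional as F
from torchvision.utils import make_grid, save_image

# Dummy original/reconstruction pair standing in for the trainer's images.
sample = torch.rand(3, 256, 256)
recon = torch.rand(3, 256, 256)
grid = make_grid([sample, recon], nrow=2)  # (C, H, W)

validation_image_scale = 2  # assumed scale value (mirrors self.validation_image_scale)

if validation_image_scale > 1:
    # Bicubic interpolation requires a batch dimension: (1, C, H, W).
    grid = F.interpolate(
        grid.unsqueeze(0),
        scale_factor=validation_image_scale,
        mode="bicubic",
        align_corners=False,
    )

# save_image accepts both 3D images and 4D mini-batches, so the extra
# batch dimension added above is harmless when writing the grid to disk.
save_image(grid, "0_0.png")
```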
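Note on the resume logic: both `train_muse_maskgit.py` and `train_muse_vae.py` now wrap `get_latest_checkpoints` in `try/except ValueError`, so an empty or missing checkpoint directory falls back to training from scratch (`load = False`) instead of aborting. The sketch below only mirrors that control flow; `find_latest` and the file-name patterns are hypothetical stand-ins, not the repo's real helper.

```python
from pathlib import Path
from typing import Optional, Tuple

def find_latest(folder: str, use_ema: bool) -> Tuple[str, Optional[str]]:
    """Hypothetical stand-in for get_latest_checkpoints: return the newest
    regular and (optionally) EMA checkpoint, raising ValueError when the
    folder contains no checkpoints at all."""
    ckpts = sorted(Path(folder).glob("vae.*.pt"), key=lambda p: p.stat().st_mtime)
    regular = [p for p in ckpts if ".ema." not in p.name]
    ema = [p for p in ckpts if ".ema." in p.name]
    if not regular:
        raise ValueError(f"no checkpoints found in {folder}")
    return str(regular[-1]), str(ema[-1]) if use_ema and ema else None

# Mirrors the new control flow in the training scripts.
load = True
ema_model_path = None
try:
    resume_path, ema_model_path = find_latest("results", use_ema=True)
    print(f"Resuming from latest checkpoint: {resume_path}")
    if ema_model_path:
        print(f"Resuming EMA weights from: {ema_model_path}")
except ValueError:
    # Nothing to resume from: skip loading and train a fresh model instead.
    load = False
```

Only the try/except shape matters here; the real helper also distinguishes model types (`model_type="vae"` vs `"maskgit"`) and the `cond_image_size` variant, as the diff shows.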