Skip to content

Commit

Permalink
Continued to add support for the EMA model.
Browse files — browse the repository at this point in the history
  • Loading branch information
ZeroCool940711 committed Sep 9, 2023
1 parent 1e63779 commit 05d3257
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 40 deletions.
25 changes: 12 additions & 13 deletions muse_maskgit_pytorch/trainers/base_accelerated_trainer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -10,7 +10,6 @@
from beartype import beartype
from datasets import Dataset
from lion_pytorch import Lion
from PIL import Image
from torch import nn
from torch.optim import Adam, AdamW, Optimizer
from torch.utils.data import DataLoader, random_split
Expand Down Expand Up @@ -289,21 +288,21 @@ def load(self, path: Union[str, PathLike]):
return pkg

def log_validation_images(self, images, step, prompts=None):
if self.validation_image_scale != 1:
# Calculate the new height based on the scale factor
new_height = int(np.array(images[0]).shape[0] * self.validation_image_scale)
#if self.validation_image_scale > 1:
## Calculate the new height based on the scale factor
#new_height = int(np.array(images[0]).shape[0] * self.validation_image_scale)

# Calculate the aspect ratio of the original image
aspect_ratio = np.array(images[0]).shape[1] / np.array(images[0]).shape[0]
## Calculate the aspect ratio of the original image
#aspect_ratio = np.array(images[0]).shape[1] / np.array(images[0]).shape[0]

# Calculate the new width based on the new height and aspect ratio
new_width = int(new_height * aspect_ratio)
## Calculate the new width based on the new height and aspect ratio
#new_width = int(new_height * aspect_ratio)

# Resize the images using the new width and height
output_size = (new_width, new_height)
images_pil = [Image.fromarray(np.array(image)) for image in images]
images_pil_resized = [image_pil.resize(output_size) for image_pil in images_pil]
images = [np.array(image_pil) for image_pil in images_pil_resized]
## Resize the images using the new width and height
#output_size = (new_width, new_height)
#images_pil = [Image.fromarray(np.array(image)) for image in images]
#images_pil_resized = [image_pil.resize(output_size) for image_pil in images_pil]
#images = [np.array(image_pil) for image_pil in images_pil_resized]

for tracker in self.accelerator.trackers:
if tracker.name == "tensorboard":
Expand Down
9 changes: 8 additions & 1 deletion muse_maskgit_pytorch/trainers/vqvae_trainers.py
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
import torch
from accelerate import Accelerator
from diffusers.optimization import get_scheduler
from einops import rearrange
from ema_pytorch import EMA
from omegaconf import OmegaConf
from PIL import Image
Expand Down Expand Up @@ -219,6 +218,12 @@ def log_validation_images(self, logs, steps):
if self.use_ema:
ema_grid = make_grid([sample, ema_recon], nrow=2)

# Scale the images
if self.validation_image_scale > 1:
grid = torch.nn.functional.interpolate(grid.unsqueeze(0), scale_factor=self.validation_image_scale, mode="bicubic", align_corners=False)
if self.use_ema:
ema_grid = torch.nn.functional.interpolate(ema_grid.unsqueeze(0), scale_factor=self.validation_image_scale, mode="bicubic", align_corners=False)

# Save grid
grid_file = f"{steps}_{i}.png"
if self.use_ema:
Expand All @@ -241,6 +246,8 @@ def log_validation_images(self, logs, steps):
super().log_validation_images(log_imgs, steps, prompts=prompts)
self.model.train()



def train(self):
self.steps = self.steps + 1
device = self.device
Expand Down
13 changes: 9 additions & 4 deletions train_muse_maskgit.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -706,10 +706,15 @@ def main():
accelerator.print("Loading Muse MaskGit...")

if args.latest_checkpoint:
args.resume_path, ema_model_path = get_latest_checkpoints(args.resume_path, use_ema=args.use_ema, model_type="maskgit", cond_image_size=args.cond_image_size)
print(f"Resuming MaskGit from latest checkpoint: {args.resume_path}")
#if args.use_ema:
# print(f"Resuming EMA MaskGit from latest checkpoint: {ema_model_path}")
try:
args.resume_path, ema_model_path = get_latest_checkpoints(args.resume_path, use_ema=args.use_ema, model_type="maskgit", cond_image_size=args.cond_image_size)
print(f"Resuming MaskGit from latest checkpoint: {args.resume_path}")
#if args.use_ema:
# print(f"Resuming EMA MaskGit from latest checkpoint: {ema_model_path}")

except ValueError:
load = False

else:
accelerator.print("Resuming MaskGit from: ", args.resume_path)

Expand Down
48 changes: 26 additions & 22 deletions train_muse_vae.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -480,28 +480,30 @@ def main():
)

if args.latest_checkpoint:
args.resume_path, ema_model_path = get_latest_checkpoints(args.resume_path, use_ema=args.use_ema, model_type="vae")
if ema_model_path:
ema_vae = VQGanVAE(
dim=args.dim,
vq_codebook_dim=args.vq_codebook_dim,
vq_codebook_size=args.vq_codebook_size,
l2_recon_loss=args.use_l2_recon_loss,
channels=args.channels,
layers=args.layers,
discr_layers=args.discr_layers,
accelerator=accelerator,
)
print(f"Resuming EMA VAE from latest checkpoint: {ema_model_path}")

ema_vae.load(ema_model_path, map="cpu")
else:
ema_vae = None

print(f"Resuming VAE from latest checkpoint: {args.resume_path}")
else:
accelerator.print("Resuming VAE from: ", args.resume_path)
ema_vae = None
try:
args.resume_path, ema_model_path = get_latest_checkpoints(args.resume_path, use_ema=args.use_ema, model_type="vae")

if ema_model_path:
ema_vae = VQGanVAE(
dim=args.dim,
vq_codebook_dim=args.vq_codebook_dim,
vq_codebook_size=args.vq_codebook_size,
l2_recon_loss=args.use_l2_recon_loss,
channels=args.channels,
layers=args.layers,
discr_layers=args.discr_layers,
accelerator=accelerator,
)
print(f"Resuming EMA VAE from latest checkpoint: {ema_model_path}")

ema_vae.load(ema_model_path, map="cpu")
else:
ema_vae = None

print(f"Resuming VAE from latest checkpoint: {args.resume_path}")

except ValueError:
load = False

if load:
#vae.load(args.resume_path if not args.use_ema or not ema_model_path else ema_model_path, map="cpu")
Expand All @@ -516,6 +518,8 @@ def main():
if current_step == 0:
accelerator.print("No step found for the VAE model.")
else:
#accelerator.print("Resuming VAE from: ", args.resume_path)
ema_vae = None
accelerator.print("No step found for the VAE model.")
current_step = 0

Expand Down

0 comments on commit 05d3257

Please sign in to comment.