Skip to content

Commit

Permalink
Merge pull request Sygil-Dev#70 from ZeroCool940711/dev
Browse files Browse the repository at this point in the history
Partially restored the ability to use the EMA model for training on the VAE trainer.
  • Loading branch information
ZeroCool940711 authored Sep 4, 2023
2 parents a6d0113 + 18c0125 commit aaddf6e
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions muse_maskgit_pytorch/trainers/vqvae_trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,10 @@ def save(self, path):
def log_validation_images(self, logs, steps):
log_imgs = []
prompts = []

self.model.eval()
if self.use_ema:
self.ema_model.eval()

try:
valid_data = next(self.valid_dl_iter)
Expand All @@ -197,25 +200,41 @@ def log_validation_images(self, logs, steps):
valid_data = valid_data.to(self.device)

recons = self.model(valid_data, return_recons=True)
if self.use_ema:
ema_recons = self.ema_model(valid_data, return_recons=True)

# else save a grid of images

for i in range(valid_data.shape[0]):
# Get sample and reconstruction
sample = valid_data[i]
recon = recons[i]
if self.use_ema:
ema_recon = ema_recons[i]

# Create grid for this sample
grid = make_grid([sample, recon], nrow=2)
if self.use_ema:
ema_grid = make_grid([sample, ema_recon], nrow=2)

# Save grid
grid_file = f"{steps}_{i}.png"
if self.use_ema:
ema_grid_file = f"{steps}_{i}.ema.png"

save_path = self.results_dir / grid_file
save_image(grid, str(save_path))

if self.use_ema:
ema_save_path = self.results_dir / ema_grid_file
save_image(ema_grid, str(ema_save_path))

# Log each saved grid image
log_imgs.append(Image.open(save_path))
prompts.append("vae")
if self.use_ema:
log_imgs.append(Image.open(ema_save_path))
prompts.append("ema")

super().log_validation_images(log_imgs, steps, prompts=prompts)
self.model.train()
Expand Down

0 comments on commit aaddf6e

Please sign in to comment.