Skip to content

Commit

Permalink
Added some utility code to reconstruct a whole folder with the VAE, s…
Browse files Browse the repository at this point in the history
…ince this code is used in multiple places now, I decided to move it to the utils.py script to reuse it.
  • Loading branch information
ZeroCool940711 committed Sep 16, 2023
1 parent 70c7bde commit 6800db7
Showing 1 changed file with 91 additions and 2 deletions.
93 changes: 91 additions & 2 deletions muse_maskgit_pytorch/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from __future__ import print_function

import glob
import shutil
import os
import re

import PIL
import torch

import hashlib
from tqdm import tqdm
from torchvision.utils import save_image
from datetime import datetime

def get_latest_checkpoints(resume_path, use_ema=False, model_type="vae", cond_image_size=False):
"""Gets the latest checkpoint paths for both the non-ema and ema VAEs.
Expand Down Expand Up @@ -142,3 +146,88 @@ def remove_duplicate_weights(ema_state_dict, non_ema_state_dict):
if key in non_ema_state_dict and torch.equal(ema_state_dict[key], non_ema_state_dict[key]):
del ema_state_dict_copy[key]
return ema_state_dict_copy

def vae_folder_validation(accelerator, vae, dataset, args=None):
    """Reconstruct every item of ``dataset`` with ``vae`` and save comparison grids.

    For each item the input tensor and its clamped reconstruction are written
    as temporary PNGs, pasted side by side (input left, reconstruction right)
    into one grid image, and the grid is saved under
    ``<args.results_dir>/outputs/<basename of args.input_folder>/``.

    Args:
        accelerator: Object whose ``device`` attribute is used as the target
            device when ``args.cpu`` is falsy and ``args.gpu == 0``.
        vae: Model exposing ``encode(x)`` (returning a 3-tuple whose first
            element is fed to ``decode``) and ``decode(z)``.
        dataset: Indexable collection supporting ``len()`` whose items are
            image tensors accepted by ``torchvision.utils.save_image``.
        args: Namespace providing ``results_dir``, ``input_folder``, ``cpu``,
            ``gpu``, ``channels``, ``vae_path``, ``save_originals`` and
            ``max_retries``. Required despite the ``None`` default.

    Raises:
        ValueError: If ``args`` is ``None``.

    Notes:
        ``RuntimeError``s raised while processing an item are not propagated:
        out-of-memory errors are retried up to ``args.max_retries`` times and
        the item is then skipped; any other ``RuntimeError`` is printed and
        the item is skipped.
    """
    # Function-scope import: a bare ``import PIL`` does not by itself guarantee
    # that the ``PIL.Image`` submodule has been loaded.
    from PIL import Image

    if args is None:
        raise ValueError("vae_folder_validation() requires an 'args' namespace, got None")

    # Create output directory and save input images and reconstructions as grids
    output_dir = os.path.join(args.results_dir, "outputs", os.path.basename(args.input_folder))
    os.makedirs(output_dir, exist_ok=True)

    # Loop-invariant values, computed once instead of once per image.
    input_path = os.path.join(output_dir, "input.png")
    output_path = os.path.join(output_dir, "output.png")
    originals_dir = os.path.join(output_dir, "originals")
    grid_mode = "RGB" if args.channels == 3 else "RGBA"
    vae_name = os.path.basename(args.vae_path)

    if args.cpu:
        device = "cpu"
    elif args.gpu == 0:
        device = accelerator.device
    else:
        device = f"cuda:{args.gpu}"

    # Index-based iteration is kept on purpose: only ``len()`` and
    # ``__getitem__`` are required from ``dataset``.
    for i in tqdm(range(len(dataset))):
        retries = 0
        while True:
            try:
                # Fetch the sample once so the saved input is exactly the
                # tensor that gets encoded (a second ``dataset[i]`` call may
                # reload the file or re-apply transforms).
                image = dataset[i]
                save_image(image, input_path)

                # Pure inference: no autograd bookkeeping is needed, and
                # skipping it lowers peak memory.
                with torch.no_grad():
                    # encode
                    encoded, _, _ = vae.encode(image[None].to(device))

                    # decode
                    recon = vae.decode(encoded).squeeze(0)
                recon = torch.clamp(recon, -1.0, 1.0)
                save_image(recon, output_path)

                # Re-open the two PNGs and build the horizontal grid. The
                # context manager guarantees both files are closed before
                # they are removed or moved below.
                with Image.open(input_path) as input_image, Image.open(output_path) as output_image:
                    grid_image = Image.new(
                        grid_mode,
                        (input_image.width + output_image.width, input_image.height),
                    )
                    grid_image.paste(input_image, (0, 0))
                    grid_image.paste(output_image, (input_image.width, 0))
                    digest = hashlib.sha1(input_image.tobytes()).hexdigest()

                # Save grid
                now = datetime.now().strftime("%m-%d-%Y_%H-%M-%S")
                filename = f"{digest}_{now}-{vae_name}.png"
                grid_image.save(os.path.join(output_dir, filename), format="PNG")

                if args.save_originals:
                    # The timestamp only has one-second resolution, so the
                    # content hash is included to stop images processed within
                    # the same second from overwriting each other.
                    os.makedirs(originals_dir, exist_ok=True)
                    shutil.move(
                        input_path,
                        os.path.join(originals_dir, f"input_{digest}_{now}.png"),
                    )
                    shutil.move(
                        output_path,
                        os.path.join(originals_dir, f"output_{digest}_{now}.png"),
                    )
                else:
                    # Remove input and output images after the grid was made.
                    os.remove(input_path)
                    os.remove(output_path)

                # Drop tensor references before asking CUDA to release memory.
                del image, encoded, _, recon
                _free_cuda_memory()

                break  # Exit the retry loop if there were no errors

            except RuntimeError as e:
                out_of_memory = "out of memory" in str(e)

                if out_of_memory and retries < args.max_retries:
                    retries += 1
                    _free_cuda_memory()
                    continue  # Retry the same image

                if out_of_memory:
                    print(f"Skipping image {i} after {retries} retries due to out of memory error")
                else:
                    print(f"\n{e}")
                break  # Give up on this image and move on to the next one


def _free_cuda_memory():
    """Release cached CUDA memory; do nothing when CUDA is unavailable."""
    # ``torch.cuda.ipc_collect()`` initialises CUDA and therefore fails on
    # machines without it, so both calls are skipped for CPU-only runs.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

0 comments on commit 6800db7

Please sign in to comment.