Added options to infer_vae.py for using the PaintMind VAE; this is a WIP and does not work properly yet.
ZeroCool940711 committed Jun 17, 2023
1 parent 0a8847f commit 5048d15
Showing 1 changed file with 34 additions and 7 deletions.
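
As a quick orientation, the sketch below reproduces the new PaintMind code path outside the CLI wrapper. The Hugging Face repo id, the subfolder, and the encode/decode/clamp sequence are taken straight from the diff below; the dummy input tensor, its size, and the output filename are illustrative assumptions only.

import torch
from torchvision.utils import save_image

from muse_maskgit_pytorch.vqvae import VQVAE

# Pick a device; the script itself resolves this from --gpu and accelerate.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the PaintMind-style VQVAE weights from the Hugging Face Hub, as in the diff.
vae = VQVAE.from_pretrained("neggles/vaedump", subfolder="vit-s-vqgan-f4").to(device)

# A single fake image batch in [-1, 1]; a real run would take dataset[i][None] from ImageDataset.
image = torch.rand(1, 3, 256, 256, device=device) * 2 - 1

with torch.no_grad():
    encoded, _, _ = vae.encode(image)       # continuous latents, per the new branch
    recon = vae.decode(encoded).squeeze(0)  # reconstruct and drop the batch dimension
    recon = recon.clamp(-1.0, 1.0)

save_image(recon, "output.png")

In the script itself this path is selected by passing the new --use_paintmind flag to infer_vae.py, alongside the dataset and device arguments the script already accepts.
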
41 changes: 34 additions & 7 deletions infer_vae.py
@@ -25,6 +25,7 @@
ImageDataset,
get_dataset_from_dataroot,
)
from muse_maskgit_pytorch.vqvae import VQVAE

# Create the parser
parser = argparse.ArgumentParser()
@@ -189,6 +190,11 @@
action="store_true",
help="Use the latest checkpoint using the vae_path folder instead of using just a specific vae_path.",
)
parser.add_argument(
"--use_paintmind",
action="store_true",
help="Use PaintMind VAE..",
)


@dataclass
@@ -336,7 +342,7 @@ def main():
if args.vae_path and args.taming_model_path:
raise Exception("You can't pass vae_path and taming args at the same time.")

if args.vae_path:
if args.vae_path and not args.use_paintmind:
accelerator.print("Loading Muse VQGanVAE")
vae = VQGanVAE(
dim=args.dim, vq_codebook_size=args.vq_codebook_size, vq_codebook_dim=args.vq_codebook_dim
@@ -390,6 +396,11 @@ def main():

vae.load(args.vae_path)

if args.use_paintmind:
# load VAE
accelerator.print(f"Loading VQVAE from 'neggles/vaedump/vit-s-vqgan-f4' ...")
vae: VQVAE = VQVAE.from_pretrained("neggles/vaedump", subfolder="vit-s-vqgan-f4")

elif args.taming_model_path:
print("Loading Taming VQGanVAE")
vae = VQGanVAETaming(
@@ -398,7 +409,10 @@
)
args.num_tokens = vae.codebook_size
args.seq_len = vae.get_encoded_fmap_size(args.image_size) ** 2

# move vae to device
vae = vae.to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")

# then you plug the vae and transformer into your MaskGit as so

dataset = ImageDataset(
@@ -449,11 +463,21 @@
try:
save_image(dataset[i], f"{output_dir}/input.png")

_, ids, _ = vae.encode(
dataset[i][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")
)
recon = vae.decode_from_ids(ids)
save_image(recon, f"{output_dir}/output.png")
if not args.use_paintmind:
# encode
_, ids, _ = vae.encode(dataset[i][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}"))
# decode
recon = vae.decode_from_ids(ids)
# print(recon.shape)  # torch.Size([1, 3, 512, 1136])
save_image(recon, f"{output_dir}/output.png")
else:
# encode
encoded, _, _ = vae.encode(dataset[i][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}"))

# decode
recon = vae.decode(encoded).squeeze(0)
recon = torch.clamp(recon, -1.0, 1.0)
save_image(recon, f"{output_dir}/output.png")

# Load input and output images
input_image = PIL.Image.open(f"{output_dir}/input.png")
@@ -495,7 +519,10 @@ def main():
continue # Retry the loop

else:
print(f"Skipping image {i} after {retries} retries due to out of memory error")
if"out of memory" not in str(e):
print(e)
else:
print(f"Skipping image {i} after {retries} retries due to out of memory error")
break # Exit the retry loop after too many retries

