diff --git a/infer_vae.py b/infer_vae.py
index 88aea18..cbf84cf 100644
--- a/infer_vae.py
+++ b/infer_vae.py
@@ -25,6 +25,7 @@
     ImageDataset,
     get_dataset_from_dataroot,
 )
+from muse_maskgit_pytorch.vqvae import VQVAE
 
 # Create the parser
 parser = argparse.ArgumentParser()
@@ -189,6 +190,11 @@
     action="store_true",
     help="Use the latest checkpoint using the vae_path folder instead of using just a specific vae_path.",
 )
+parser.add_argument(
+    "--use_paintmind",
+    action="store_true",
+    help="Use the PaintMind VAE.",
+)
 
 
 @dataclass
@@ -336,7 +342,7 @@ def main():
     if args.vae_path and args.taming_model_path:
         raise Exception("You can't pass vae_path and taming args at the same time.")
 
-    if args.vae_path:
+    if args.vae_path and not args.use_paintmind:
         accelerator.print("Loading Muse VQGanVAE")
         vae = VQGanVAE(
             dim=args.dim, vq_codebook_size=args.vq_codebook_size, vq_codebook_dim=args.vq_codebook_dim
@@ -390,6 +396,11 @@ def main():
 
         vae.load(args.vae_path)
 
+    if args.use_paintmind:
+        # load the pretrained PaintMind VQVAE from the Hugging Face Hub
+        accelerator.print("Loading VQVAE from 'neggles/vaedump/vit-s-vqgan-f4' ...")
+        vae: VQVAE = VQVAE.from_pretrained("neggles/vaedump", subfolder="vit-s-vqgan-f4")
+
     elif args.taming_model_path:
         print("Loading Taming VQGanVAE")
         vae = VQGanVAETaming(
@@ -398,7 +409,10 @@ def main():
         )
         args.num_tokens = vae.codebook_size
         args.seq_len = vae.get_encoded_fmap_size(args.image_size) ** 2
+
+    # move vae to device
     vae = vae.to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")
+
     # then you plug the vae and transformer into your MaskGit as so
 
     dataset = ImageDataset(
@@ -449,11 +463,21 @@ def main():
         try:
             save_image(dataset[i], f"{output_dir}/input.png")
 
-            _, ids, _ = vae.encode(
-                dataset[i][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")
-            )
-            recon = vae.decode_from_ids(ids)
-            save_image(recon, f"{output_dir}/output.png")
+            if not args.use_paintmind:
+                # encode the image to discrete codebook ids
+                _, ids, _ = vae.encode(dataset[i][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}"))
+                # decode the ids back to an image
+                recon = vae.decode_from_ids(ids)
+                # print(recon.shape)  # torch.Size([1, 3, 512, 1136])
+                save_image(recon, f"{output_dir}/output.png")
+            else:
+                # encode the image to continuous latents
+                encoded, _, _ = vae.encode(dataset[i][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}"))
+
+                # decode the latents and clamp to the expected value range
+                recon = vae.decode(encoded).squeeze(0)
+                recon = torch.clamp(recon, -1.0, 1.0)
+                save_image(recon, f"{output_dir}/output.png")
 
             # Load input and output images
             input_image = PIL.Image.open(f"{output_dir}/input.png")
@@ -495,7 +519,10 @@ def main():
 
                     continue  # Retry the loop
                 else:
-                    print(f"Skipping image {i} after {retries} retries due to out of memory error")
+                    if "out of memory" not in str(e):
+                        print(e)
+                    else:
+                        print(f"Skipping image {i} after {retries} retries due to out of memory error")
                     break  # Exit the retry loop after too many retries
 
 
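
Review note: the two branches round-trip an image through different APIs, and the diff makes that split inside a long try block. Below is a minimal standalone sketch of just that difference, assuming the call signatures shown in the diff; `load_image_tensor` is a hypothetical stand-in for the dataset/transform pipeline, and the VQGanVAE hyperparameters are placeholders, not the script's actual args.

    import torch
    from torchvision.utils import save_image

    from muse_maskgit_pytorch import VQGanVAE
    from muse_maskgit_pytorch.vqvae import VQVAE

    device = "cuda:0"
    use_paintmind = False  # toggle to exercise the PaintMind path
    image = load_image_tensor().to(device)  # hypothetical helper returning a (1, 3, H, W) tensor

    if not use_paintmind:
        # Muse VQGanVAE path: quantize to discrete codebook ids, then decode from those ids.
        vae = VQGanVAE(dim=128, vq_codebook_size=8192).to(device)  # placeholder hyperparameters
        _, ids, _ = vae.encode(image)
        recon = vae.decode_from_ids(ids)
    else:
        # PaintMind VQVAE path: decode the continuous latents directly, then clamp.
        vae = VQVAE.from_pretrained("neggles/vaedump", subfolder="vit-s-vqgan-f4").to(device)
        encoded, _, _ = vae.encode(image)
        recon = vae.decode(encoded).squeeze(0)
        recon = torch.clamp(recon, -1.0, 1.0)

    save_image(recon, "output.png")

The clamp mirrors what the diff does before save_image; whether the PaintMind decoder's output range actually requires it is worth confirming in review.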
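The last hunk also fixes the give-up branch of the retry loop: non-OOM runtime errors are now printed as themselves instead of being mislabeled as out-of-memory skips. A standalone sketch of that pattern follows; the loop scaffolding around the visible hunk is assumed, and `process_image` is a hypothetical stand-in for the encode/decode body.

    import torch

    def reconstruct_with_retries(process_image, i: int, max_retries: int = 3) -> None:
        retries = 0
        while True:
            try:
                process_image(i)
                return
            except RuntimeError as e:
                retries += 1
                if retries < max_retries:
                    torch.cuda.empty_cache()  # free cached blocks before retrying
                    continue  # retry the loop
                # give-up branch, as changed in the last hunk of this diff
                if "out of memory" not in str(e):
                    print(e)  # a genuine error: surface it instead of mislabeling it
                else:
                    print(f"Skipping image {i} after {retries} retries due to out of memory error")
                return  # exit the retry loop after too many retries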