Skip to content

Commit

Permalink
Merge pull request Sygil-Dev#71 from ZeroCool940711/dev
Browse files Browse the repository at this point in the history
Fixed the VAE validation images being sent to wandb too many times during validation, which caused network issues by hitting the API rate limit.
  • Loading branch information
ZeroCool940711 authored Sep 5, 2023
2 parents aaddf6e + d087179 commit 09e5f92
Show file tree
Hide file tree
Showing 7 changed files with 173 additions and 280 deletions.
105 changes: 11 additions & 94 deletions infer_maskgit.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import argparse
import glob
import os
import re
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
Expand All @@ -18,6 +16,9 @@
VQGanVAETaming,
get_accelerator,
)
from muse_maskgit_pytorch.utils import (
get_latest_checkpoints,
)

# Create the parser
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -211,43 +212,10 @@ def main():
print("Loading Muse VQGanVAE")

if args.latest_checkpoint:
print("Finding latest VAE checkpoint...")
orig_vae_path = args.vae_path

if os.path.isfile(args.vae_path) or ".pt" in args.vae_path:
# If args.vae_path is a file, split it into directory and filename
args.vae_path, _ = os.path.split(args.vae_path)

checkpoint_files = glob.glob(os.path.join(args.vae_path, "vae.*.pt"))
if checkpoint_files:
latest_checkpoint_file = max(
checkpoint_files, key=lambda x: int(re.search(r"vae\.(\d+)\.pt", x).group(1))
)

# Check if latest checkpoint is empty or unreadable
if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(
latest_checkpoint_file, os.R_OK
):
print(
f"Warning: latest VAE checkpoint {latest_checkpoint_file} is empty or unreadable."
)
if len(checkpoint_files) > 1:
# Use the second last checkpoint as a fallback
latest_checkpoint_file = max(
checkpoint_files[:-1],
key=lambda x: int(re.search(r"vae\.(\d+)\.pt", x).group(1)),
)
print("Using second last VAE checkpoint: ", latest_checkpoint_file)
else:
print("No usable checkpoint found.")
elif latest_checkpoint_file != orig_vae_path:
print("Resuming VAE from latest checkpoint: ", latest_checkpoint_file)
else:
print("Using VAE checkpoint specified in vae_path: ", orig_vae_path)

args.vae_path = latest_checkpoint_file
else:
print("No VAE checkpoints found in directory: ", args.vae_path)
args.vae_path, ema_model_path = get_latest_checkpoints(args.vae_path, use_ema=args.use_ema, model_type="vae")
print(f"Resuming VAE from latest checkpoint: {args.resume_path}")
#if args.use_ema:
# print(f"Resuming EMA VAE from latest checkpoint: {ema_model_path}")
else:
print("Resuming VAE from: ", args.vae_path)

Expand Down Expand Up @@ -341,61 +309,10 @@ def main():
accelerator.print("Loading Muse MaskGit...")

if args.latest_checkpoint:
accelerator.print("Finding latest MaskGit checkpoint...")
orig_vae_path = args.resume_path

if os.path.isfile(args.resume_path) or ".pt" in args.resume_path:
# If args.resume_path is a file, split it into directory and filename
args.resume_path, _ = os.path.split(args.resume_path)

if args.cond_image_size:
checkpoint_files = glob.glob(os.path.join(args.resume_path, "maskgit_superres.*.pt"))
else:
checkpoint_files = glob.glob(os.path.join(args.resume_path, "maskgit.*.pt"))

if checkpoint_files:
if args.cond_image_size:
latest_checkpoint_file = max(
checkpoint_files,
key=lambda x: int(re.search(r"maskgit_superres\.(\d+)\.pt", x).group(1)),
)
else:
latest_checkpoint_file = max(
checkpoint_files, key=lambda x: int(re.search(r"maskgit\.(\d+)\.pt", x).group(1))
)

# Check if latest checkpoint is empty or unreadable
if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(
latest_checkpoint_file, os.R_OK
):
accelerator.print(
f"Warning: latest MaskGit checkpoint {latest_checkpoint_file} is empty or unreadable."
)
if len(checkpoint_files) > 1:
# Use the second last checkpoint as a fallback
if args.cond_image_size:
latest_checkpoint_file = max(
checkpoint_files[:-1],
key=lambda x: int(re.search(r"maskgit_superres\.(\d+)\.pt", x).group(1)),
)
else:
latest_checkpoint_file = max(
checkpoint_files[:-1],
key=lambda x: int(re.search(r"maskgit\.(\d+)\.pt", x).group(1)),
)
accelerator.print("Using second last MaskGit checkpoint: ", latest_checkpoint_file)
else:
accelerator.print("No usable MaskGit checkpoint found.")
load = False
elif latest_checkpoint_file != orig_vae_path:
accelerator.print("Resuming MaskGit from latest checkpoint: ", latest_checkpoint_file)
else:
accelerator.print("Using MaskGit checkpoint specified in resume_path: ", orig_vae_path)

args.resume_path = latest_checkpoint_file
else:
accelerator.print("No MaskGit checkpoints found in directory: ", args.resume_path)
load = False
args.resume_path, ema_model_path = get_latest_checkpoints(args.resume_path, use_ema=args.use_ema, model_type="maskgit", cond_image_size=args.cond_image_size)
print(f"Resuming MaskGit from latest checkpoint: {args.resume_path}")
#if args.use_ema:
# print(f"Resuming EMA MaskGit from latest checkpoint: {ema_model_path}")
else:
accelerator.print("Resuming MaskGit from: ", args.resume_path)

Expand Down
54 changes: 11 additions & 43 deletions infer_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
ImageDataset,
get_dataset_from_dataroot,
)
from muse_maskgit_pytorch.utils import (
get_latest_checkpoints,
)

from muse_maskgit_pytorch.vqvae import VQVAE

# Create the parser
Expand Down Expand Up @@ -211,6 +215,8 @@
action="store_true",
help="Save the original input.png and output.png images to a subfolder instead of deleting them after the grid is made.",
)
parser.add_argument("--use_ema", action="store_true", help="Whether to use ema.")
parser.add_argument("--ema_beta", type=float, default=0.995, help="Ema beta.")


@dataclass
Expand Down Expand Up @@ -373,52 +379,14 @@ def main():
).to("cpu" if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")

if args.latest_checkpoint:
accelerator.print("Finding latest checkpoint...")
orig_vae_path = args.vae_path

if os.path.isfile(args.vae_path) or ".pt" in args.vae_path:
# If args.vae_path is a file, split it into directory and filename
args.vae_path, _ = os.path.split(args.vae_path)

checkpoint_files = glob.glob(os.path.join(args.vae_path, "vae.*.pt"))
if checkpoint_files:
latest_checkpoint_file = max(
checkpoint_files,
key=lambda x: int(re.search(r"vae\.(\d+)\.pt$", x).group(1))
if not x.endswith("ema.pt")
else -1,
)

# Check if latest checkpoint is empty or unreadable
if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(
latest_checkpoint_file, os.R_OK
):
accelerator.print(
f"Warning: latest checkpoint {latest_checkpoint_file} is empty or unreadable."
)
if len(checkpoint_files) > 1:
# Use the second last checkpoint as a fallback
latest_checkpoint_file = max(
checkpoint_files[:-1],
key=lambda x: int(re.search(r"vae\.(\d+)\.pt$", x).group(1))
if not x.endswith("ema.pt")
else -1,
)
accelerator.print("Using second last checkpoint: ", latest_checkpoint_file)
else:
accelerator.print("No usable checkpoint found.")
elif latest_checkpoint_file != orig_vae_path:
accelerator.print("Resuming VAE from latest checkpoint: ", latest_checkpoint_file)
else:
accelerator.print("Using checkpoint specified in vae_path: ", orig_vae_path)

args.vae_path = latest_checkpoint_file
else:
accelerator.print("No checkpoints found in directory: ", args.vae_path)
args.vae_path, ema_model_path = get_latest_checkpoints(args.vae_path, use_ema=args.use_ema)
print(f"Resuming VAE from latest checkpoint: {args.vae_path if not args.use_ema else ema_model_path}")
#if args.use_ema:
# print(f"Resuming EMA VAE from latest checkpoint: {ema_model_path}")
else:
accelerator.print("Resuming VAE from: ", args.vae_path)

vae.load(args.vae_path)
vae.load(args.vae_path if not args.use_ema or not ema_model_path else ema_model_path, map="cpu")

if args.use_paintmind:
# load VAE
Expand Down
2 changes: 1 addition & 1 deletion muse_maskgit_pytorch/trainers/vqvae_trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def log_validation_images(self, logs, steps):
log_imgs.append(Image.open(ema_save_path))
prompts.append("ema")

super().log_validation_images(log_imgs, steps, prompts=prompts)
super().log_validation_images(log_imgs, steps, prompts=prompts)
self.model.train()

def train(self):
Expand Down
122 changes: 122 additions & 0 deletions muse_maskgit_pytorch/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from __future__ import print_function
import re, glob, os, torch

def _rank_checkpoints(checkpoint_files, prefix, ema):
    """Return the paths named ``<prefix>.<step>[.ema].pt``, highest step first."""
    suffix = r"\.ema\.pt" if ema else r"\.pt"
    pattern = re.compile(re.escape(prefix) + r"\.(\d+)" + suffix)

    ranked = []
    for path in checkpoint_files:
        match = pattern.fullmatch(os.path.basename(path))
        # Anything that does not carry a numeric step (or belongs to the other
        # EMA/non-EMA flavour) is ignored instead of crashing the lookup.
        if match:
            ranked.append((int(match.group(1)), path))

    ranked.sort(reverse=True)
    return [path for _, path in ranked]


def _first_usable_checkpoint(ranked_files, label=""):
    """Return the first non-empty, readable path in ``ranked_files``, or None."""
    for position, path in enumerate(ranked_files):
        try:
            usable = os.path.getsize(path) > 0 and os.access(path, os.R_OK)
        except OSError:
            # The file vanished or cannot be stat'ed: treat it as unusable.
            usable = False

        if usable:
            if position:
                print(f"Falling back to earlier {label}checkpoint: ", path)
            return path

        print(f"Warning: {label}checkpoint {path} is empty or unreadable.")

    return None


def get_latest_checkpoints(resume_path, use_ema=False, model_type="vae", cond_image_size=False):
    """Find the newest usable non-EMA and EMA checkpoints for a model.

    Checkpoints are looked up by file name: ``vae.<step>.pt``,
    ``maskgit.<step>.pt`` or ``maskgit_superres.<step>.pt`` and their
    ``.<step>.ema.pt`` counterparts. "Newest" means the highest ``<step>``.
    Empty or unreadable files are skipped in favour of the next older one.

    Args:
        resume_path: Either the directory holding the checkpoints or the path
            of a checkpoint file inside that directory.
        use_ema: When true, also look for the newest EMA checkpoint.
        model_type: ``"vae"`` selects VAE checkpoints; any other value selects
            MaskGit checkpoints.
        cond_image_size: When truthy, look for ``maskgit_superres`` checkpoints
            regardless of ``model_type``.

    Returns:
        A tuple ``(non_ema_path, ema_path)``. ``ema_path`` is ``None`` when
        ``use_ema`` is false or no usable EMA checkpoint exists. If the
        directory holds no matching non-EMA checkpoint at all, ``resume_path``
        is returned unchanged as ``non_ema_path``; if matching files exist but
        none is usable, the newest one is returned so that loading fails
        visibly on that file.
    """
    # Only strip the last path component when it actually names a file;
    # otherwise a directory such as "results/vae" would be searched as "results".
    if os.path.isfile(resume_path) or resume_path.endswith(".pt"):
        checkpoint_dir = os.path.dirname(resume_path)
    else:
        checkpoint_dir = resume_path

    if cond_image_size:
        prefix = "maskgit_superres"
    elif model_type == "vae":
        prefix = "vae"
    else:
        prefix = "maskgit"

    checkpoint_files = glob.glob(os.path.join(checkpoint_dir, prefix + ".*.pt"))

    print(f"Finding latest {'VAE' if model_type == 'vae' else 'MaskGit'} checkpoint...")

    ranked_non_ema = _rank_checkpoints(checkpoint_files, prefix, ema=False)
    latest_non_ema_checkpoint_file = _first_usable_checkpoint(ranked_non_ema)
    if latest_non_ema_checkpoint_file is None:
        if ranked_non_ema:
            print("No usable checkpoint found.")
            latest_non_ema_checkpoint_file = ranked_non_ema[0]
        else:
            print(f"No {prefix} checkpoints found in directory: ", checkpoint_dir)
            latest_non_ema_checkpoint_file = resume_path

    latest_ema_checkpoint_file = None
    if use_ema:
        ranked_ema = _rank_checkpoints(checkpoint_files, prefix, ema=True)
        latest_ema_checkpoint_file = _first_usable_checkpoint(ranked_ema, label="EMA ")
        if latest_ema_checkpoint_file is None:
            # Never hand back a non-EMA file in the EMA slot.
            print("No usable EMA checkpoint found.")

    return latest_non_ema_checkpoint_file, latest_ema_checkpoint_file

def remove_duplicate_weights(ema_state_dict, non_ema_state_dict):
    """Drop EMA entries that are identical to their non-EMA counterpart.

    Args:
        ema_state_dict: State dictionary of the EMA model. It is not modified.
        non_ema_state_dict: State dictionary of the non-EMA model.

    Returns:
        A shallow copy of ``ema_state_dict`` without the keys whose value is
        ``torch.equal`` to the value stored under the same key in
        ``non_ema_state_dict``.
    """
    pruned = ema_state_dict.copy()

    # Collect the redundant names first, then remove them from the copy.
    redundant = [
        name
        for name in ema_state_dict
        if name in non_ema_state_dict
        and torch.equal(ema_state_dict[name], non_ema_state_dict[name])
    ]
    for name in redundant:
        pruned.pop(name)

    return pruned
6 changes: 5 additions & 1 deletion muse_maskgit_pytorch/vqgan_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,11 @@ def state_dict(self, *args, **kwargs):

@remove_vgg
def load_state_dict(self, *args, **kwargs):
    """Load a state dict, retrying non-strictly if the strict load fails.

    The arguments are forwarded unchanged to the parent class's
    ``load_state_dict``. If that call raises ``RuntimeError`` and the caller
    did not choose a ``strict`` value, the load is retried with
    ``strict=False`` and a warning is printed.

    NOTE(review): the effect of the ``remove_vgg`` decorator is not visible
    here; confirm against its definition.

    Raises:
        RuntimeError: If loading fails and the caller passed ``strict``
            explicitly, or if the non-strict retry fails as well.
    """
    try:
        return super().load_state_dict(*args, **kwargs)
    except RuntimeError as err:
        # When the caller already chose a strictness (as a keyword or as the
        # second positional argument), adding strict=False would raise a
        # TypeError for the duplicate argument and hide the real error.
        if len(args) > 1 or "strict" in kwargs:
            raise
        print(f"Warning: strict state dict loading failed ({err}); retrying with strict=False.")
        return super().load_state_dict(*args, strict=False, **kwargs)


def save(self, path):
if self.accelerator is not None:
Expand Down
Loading

0 comments on commit 09e5f92

Please sign in to comment.