Skip to content

Commit

Permalink
Refactored the code for finding the latest checkpoint and made it so …
Browse files Browse the repository at this point in the history
…we can also get the ema model, I partially also added support for loading them but for now it seems to be broken so it needs more work.
  • Loading branch information
ZeroCool940711 committed Sep 5, 2023
1 parent 3f4684d commit bc998c1
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 278 deletions.
105 changes: 11 additions & 94 deletions infer_maskgit.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import argparse
import glob
import os
import re
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
Expand All @@ -18,6 +16,9 @@
VQGanVAETaming,
get_accelerator,
)
from muse_maskgit_pytorch.utils import (
get_latest_checkpoints,
)

# Create the parser
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -211,43 +212,10 @@ def main():
print("Loading Muse VQGanVAE")

if args.latest_checkpoint:
print("Finding latest VAE checkpoint...")
orig_vae_path = args.vae_path

if os.path.isfile(args.vae_path) or ".pt" in args.vae_path:
# If args.vae_path is a file, split it into directory and filename
args.vae_path, _ = os.path.split(args.vae_path)

checkpoint_files = glob.glob(os.path.join(args.vae_path, "vae.*.pt"))
if checkpoint_files:
latest_checkpoint_file = max(
checkpoint_files, key=lambda x: int(re.search(r"vae\.(\d+)\.pt", x).group(1))
)

# Check if latest checkpoint is empty or unreadable
if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(
latest_checkpoint_file, os.R_OK
):
print(
f"Warning: latest VAE checkpoint {latest_checkpoint_file} is empty or unreadable."
)
if len(checkpoint_files) > 1:
# Use the second last checkpoint as a fallback
latest_checkpoint_file = max(
checkpoint_files[:-1],
key=lambda x: int(re.search(r"vae\.(\d+)\.pt", x).group(1)),
)
print("Using second last VAE checkpoint: ", latest_checkpoint_file)
else:
print("No usable checkpoint found.")
elif latest_checkpoint_file != orig_vae_path:
print("Resuming VAE from latest checkpoint: ", latest_checkpoint_file)
else:
print("Using VAE checkpoint specified in vae_path: ", orig_vae_path)

args.vae_path = latest_checkpoint_file
else:
print("No VAE checkpoints found in directory: ", args.vae_path)
args.vae_path, ema_model_path = get_latest_checkpoints(args.vae_path, use_ema=args.use_ema, model_type="vae")
print(f"Resuming VAE from latest checkpoint: {args.resume_path}")
#if args.use_ema:
# print(f"Resuming EMA VAE from latest checkpoint: {ema_model_path}")
else:
print("Resuming VAE from: ", args.vae_path)

Expand Down Expand Up @@ -341,61 +309,10 @@ def main():
accelerator.print("Loading Muse MaskGit...")

if args.latest_checkpoint:
accelerator.print("Finding latest MaskGit checkpoint...")
orig_vae_path = args.resume_path

if os.path.isfile(args.resume_path) or ".pt" in args.resume_path:
# If args.resume_path is a file, split it into directory and filename
args.resume_path, _ = os.path.split(args.resume_path)

if args.cond_image_size:
checkpoint_files = glob.glob(os.path.join(args.resume_path, "maskgit_superres.*.pt"))
else:
checkpoint_files = glob.glob(os.path.join(args.resume_path, "maskgit.*.pt"))

if checkpoint_files:
if args.cond_image_size:
latest_checkpoint_file = max(
checkpoint_files,
key=lambda x: int(re.search(r"maskgit_superres\.(\d+)\.pt", x).group(1)),
)
else:
latest_checkpoint_file = max(
checkpoint_files, key=lambda x: int(re.search(r"maskgit\.(\d+)\.pt", x).group(1))
)

# Check if latest checkpoint is empty or unreadable
if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(
latest_checkpoint_file, os.R_OK
):
accelerator.print(
f"Warning: latest MaskGit checkpoint {latest_checkpoint_file} is empty or unreadable."
)
if len(checkpoint_files) > 1:
# Use the second last checkpoint as a fallback
if args.cond_image_size:
latest_checkpoint_file = max(
checkpoint_files[:-1],
key=lambda x: int(re.search(r"maskgit_superres\.(\d+)\.pt", x).group(1)),
)
else:
latest_checkpoint_file = max(
checkpoint_files[:-1],
key=lambda x: int(re.search(r"maskgit\.(\d+)\.pt", x).group(1)),
)
accelerator.print("Using second last MaskGit checkpoint: ", latest_checkpoint_file)
else:
accelerator.print("No usable MaskGit checkpoint found.")
load = False
elif latest_checkpoint_file != orig_vae_path:
accelerator.print("Resuming MaskGit from latest checkpoint: ", latest_checkpoint_file)
else:
accelerator.print("Using MaskGit checkpoint specified in resume_path: ", orig_vae_path)

args.resume_path = latest_checkpoint_file
else:
accelerator.print("No MaskGit checkpoints found in directory: ", args.resume_path)
load = False
args.resume_path, ema_model_path = get_latest_checkpoints(args.resume_path, use_ema=args.use_ema, model_type="maskgit", cond_image_size=args.cond_image_size)
print(f"Resuming MaskGit from latest checkpoint: {args.resume_path}")
#if args.use_ema:
# print(f"Resuming EMA MaskGit from latest checkpoint: {ema_model_path}")
else:
accelerator.print("Resuming MaskGit from: ", args.resume_path)

Expand Down
54 changes: 11 additions & 43 deletions infer_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
ImageDataset,
get_dataset_from_dataroot,
)
from muse_maskgit_pytorch.utils import (
get_latest_checkpoints,
)

from muse_maskgit_pytorch.vqvae import VQVAE

# Create the parser
Expand Down Expand Up @@ -211,6 +215,8 @@
action="store_true",
help="Save the original input.png and output.png images to a subfolder instead of deleting them after the grid is made.",
)
parser.add_argument("--use_ema", action="store_true", help="Whether to use ema.")
parser.add_argument("--ema_beta", type=float, default=0.995, help="Ema beta.")


@dataclass
Expand Down Expand Up @@ -373,52 +379,14 @@ def main():
).to("cpu" if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")

if args.latest_checkpoint:
accelerator.print("Finding latest checkpoint...")
orig_vae_path = args.vae_path

if os.path.isfile(args.vae_path) or ".pt" in args.vae_path:
# If args.vae_path is a file, split it into directory and filename
args.vae_path, _ = os.path.split(args.vae_path)

checkpoint_files = glob.glob(os.path.join(args.vae_path, "vae.*.pt"))
if checkpoint_files:
latest_checkpoint_file = max(
checkpoint_files,
key=lambda x: int(re.search(r"vae\.(\d+)\.pt$", x).group(1))
if not x.endswith("ema.pt")
else -1,
)

# Check if latest checkpoint is empty or unreadable
if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(
latest_checkpoint_file, os.R_OK
):
accelerator.print(
f"Warning: latest checkpoint {latest_checkpoint_file} is empty or unreadable."
)
if len(checkpoint_files) > 1:
# Use the second last checkpoint as a fallback
latest_checkpoint_file = max(
checkpoint_files[:-1],
key=lambda x: int(re.search(r"vae\.(\d+)\.pt$", x).group(1))
if not x.endswith("ema.pt")
else -1,
)
accelerator.print("Using second last checkpoint: ", latest_checkpoint_file)
else:
accelerator.print("No usable checkpoint found.")
elif latest_checkpoint_file != orig_vae_path:
accelerator.print("Resuming VAE from latest checkpoint: ", latest_checkpoint_file)
else:
accelerator.print("Using checkpoint specified in vae_path: ", orig_vae_path)

args.vae_path = latest_checkpoint_file
else:
accelerator.print("No checkpoints found in directory: ", args.vae_path)
args.vae_path, ema_model_path = get_latest_checkpoints(args.vae_path, use_ema=args.use_ema)
print(f"Resuming VAE from latest checkpoint: {args.vae_path if not args.use_ema else ema_model_path}")
#if args.use_ema:
# print(f"Resuming EMA VAE from latest checkpoint: {ema_model_path}")
else:
accelerator.print("Resuming VAE from: ", args.vae_path)

vae.load(args.vae_path)
vae.load(args.vae_path if not args.use_ema or not ema_model_path else ema_model_path, map="cpu")

if args.use_paintmind:
# load VAE
Expand Down
122 changes: 122 additions & 0 deletions muse_maskgit_pytorch/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from __future__ import print_function
import re, glob, os, torch

def get_latest_checkpoints(resume_path, use_ema=False, model_type="vae", cond_image_size=False):
"""Gets the latest checkpoint paths for both the non-ema and ema VAEs.
Args:
resume_path: The path to the directory containing the VAE checkpoints.
Returns:
A tuple containing the paths to the latest non-ema and ema VAE checkpoints, respectively.
"""

vae_path, _ = os.path.split(resume_path)
if cond_image_size:
checkpoint_files = glob.glob(os.path.join(vae_path, "maskgit_superres.*.pt"))
else:
checkpoint_files = glob.glob(os.path.join(vae_path, "vae.*.pt" if model_type == "vae" else "maskgit.*.pt"))
#print(checkpoint_files)

print(f"Finding latest {'VAE' if model_type == 'vae' else 'MaskGit'} checkpoint...")

# Get the latest non-ema VAE checkpoint path
if cond_image_size:
latest_non_ema_checkpoint_file = max(
checkpoint_files,
key=lambda x: int(re.search(r"maskgit_superres\.(\d+)\.pt", x).group(1)),
)
else:
latest_non_ema_checkpoint_file = max(
checkpoint_files,
key=lambda x: int(re.search(r"vae\.(\d+)\.pt$" if model_type == "vae" else r"maskgit\.(\d+)\.pt$", x).group(1))
if not x.endswith("ema.pt")
else -1,
)

# Check if the latest checkpoints are empty or unreadable
if os.path.getsize(latest_non_ema_checkpoint_file) == 0 or not os.access(
latest_non_ema_checkpoint_file, os.R_OK
):
print(
f"Warning: latest checkpoint {latest_non_ema_checkpoint_file} is empty or unreadable."
)
if len(checkpoint_files) > 1:
# Use the second last checkpoint as a fallback
if cond_image_size:
latest_non_ema_checkpoint_file = max(
checkpoint_files[:-1],
key=lambda x: int(re.search(r"maskgit_superres\.(\d+)\.pt", x).group(1)),
)
else:
latest_non_ema_checkpoint_file = max(
checkpoint_files[:-1],
key=lambda x: int(re.search(r"vae\.(\d+)\.pt$" if model_type == "vae" else r"maskgit\.(\d+)\.pt$", x).group(1))
if not x.endswith("ema.pt")
else -1,
)
print("Using second last checkpoint: ", latest_non_ema_checkpoint_file)
else:
print("No usable checkpoint found.")

if use_ema:
# Get the latest ema VAE checkpoint path
if cond_image_size:
latest_ema_checkpoint_file = max(
checkpoint_files,
key=lambda x: int(re.search(r"maskgit_superres\.(\d+)\.ema\.pt$", x).group(1))
if x.endswith("ema.pt")
else -1,
)
else:
latest_ema_checkpoint_file = max(
checkpoint_files,
key=lambda x: int(re.search(r"vae\.(\d+)\.ema\.pt$" if model_type == "vae" else r"maskgit\.(\d+)\.ema\.pt$", x).group(1))
if x.endswith("ema.pt")
else -1,
)

if os.path.getsize(latest_ema_checkpoint_file) == 0 or not os.access(
latest_ema_checkpoint_file, os.R_OK
):
print(
f"Warning: latest EMA checkpoint {latest_ema_checkpoint_file} is empty or unreadable."
)
if len(checkpoint_files) > 1:
# Use the second last checkpoint as a fallback
if cond_image_size:
latest_ema_checkpoint_file = max(
checkpoint_files[:-1],
key=lambda x: int(re.search(r"maskgit_superres\.(\d+)\.ema\.pt$", x).group(1))
if x.endswith("ema.pt")
else -1,
)
else:
latest_ema_checkpoint_file = max(
checkpoint_files[:-1],
key=lambda x: int(re.search(r"vae\.(\d+)\.ema\.pt$" if model_type == "vae" else r"maskgit\.(\d+)\.ema\.pt$", x).group(1))
if x.endswith("ema.pt")
else -1,
)
print("Using second last EMA checkpoint: ", latest_ema_checkpoint_file)
else:
latest_ema_checkpoint_file = None

return latest_non_ema_checkpoint_file, latest_ema_checkpoint_file

def remove_duplicate_weights(ema_state_dict, non_ema_state_dict):
"""Removes duplicate weights from the ema state dictionary.
Args:
ema_state_dict: The state dictionary of the ema model.
non_ema_state_dict: The state dictionary of the non-ema model.
Returns:
The ema state dictionary with duplicate weights removed.
"""

ema_state_dict_copy = ema_state_dict.copy()
for key, value in ema_state_dict.items():
if key in non_ema_state_dict and torch.equal(ema_state_dict[key], non_ema_state_dict[key]):
del ema_state_dict_copy[key]
return ema_state_dict_copy
Loading

0 comments on commit bc998c1

Please sign in to comment.