Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Sep 9, 2023
1 parent 05d3257 commit 5b2099a
Show file tree
Hide file tree
Showing 8 changed files with 93 additions and 47 deletions.
15 changes: 11 additions & 4 deletions infer_maskgit.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,11 @@ def main():
print("Loading Muse VQGanVAE")

if args.latest_checkpoint:
args.vae_path, ema_model_path = get_latest_checkpoints(args.vae_path, use_ema=args.use_ema, model_type="vae")
args.vae_path, ema_model_path = get_latest_checkpoints(
args.vae_path, use_ema=args.use_ema, model_type="vae"
)
print(f"Resuming VAE from latest checkpoint: {args.resume_path}")
#if args.use_ema:
# if args.use_ema:
# print(f"Resuming EMA VAE from latest checkpoint: {ema_model_path}")
else:
print("Resuming VAE from: ", args.vae_path)
Expand Down Expand Up @@ -309,9 +311,14 @@ def main():
accelerator.print("Loading Muse MaskGit...")

if args.latest_checkpoint:
args.resume_path, ema_model_path = get_latest_checkpoints(args.resume_path, use_ema=args.use_ema, model_type="maskgit", cond_image_size=args.cond_image_size)
args.resume_path, ema_model_path = get_latest_checkpoints(
args.resume_path,
use_ema=args.use_ema,
model_type="maskgit",
cond_image_size=args.cond_image_size,
)
print(f"Resuming MaskGit from latest checkpoint: {args.resume_path}")
#if args.use_ema:
# if args.use_ema:
# print(f"Resuming EMA MaskGit from latest checkpoint: {ema_model_path}")
else:
accelerator.print("Resuming MaskGit from: ", args.resume_path)
Expand Down
7 changes: 4 additions & 3 deletions infer_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from muse_maskgit_pytorch.utils import (
get_latest_checkpoints,
)

from muse_maskgit_pytorch.vqvae import VQVAE

# Create the parser
Expand Down Expand Up @@ -380,8 +379,10 @@ def main():

if args.latest_checkpoint:
args.vae_path, ema_model_path = get_latest_checkpoints(args.vae_path, use_ema=args.use_ema)
print(f"Resuming VAE from latest checkpoint: {args.vae_path if not args.use_ema else ema_model_path}")
#if args.use_ema:
print(
f"Resuming VAE from latest checkpoint: {args.vae_path if not args.use_ema else ema_model_path}"
)
# if args.use_ema:
# print(f"Resuming EMA VAE from latest checkpoint: {ema_model_path}")
else:
accelerator.print("Resuming VAE from: ", args.vae_path)
Expand Down
24 changes: 12 additions & 12 deletions muse_maskgit_pytorch/trainers/base_accelerated_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,21 +288,21 @@ def load(self, path: Union[str, PathLike]):
return pkg

def log_validation_images(self, images, step, prompts=None):
#if self.validation_image_scale > 1:
## Calculate the new height based on the scale factor
#new_height = int(np.array(images[0]).shape[0] * self.validation_image_scale)
# if self.validation_image_scale > 1:
## Calculate the new height based on the scale factor
# new_height = int(np.array(images[0]).shape[0] * self.validation_image_scale)

## Calculate the aspect ratio of the original image
#aspect_ratio = np.array(images[0]).shape[1] / np.array(images[0]).shape[0]
## Calculate the aspect ratio of the original image
# aspect_ratio = np.array(images[0]).shape[1] / np.array(images[0]).shape[0]

## Calculate the new width based on the new height and aspect ratio
#new_width = int(new_height * aspect_ratio)
## Calculate the new width based on the new height and aspect ratio
# new_width = int(new_height * aspect_ratio)

## Resize the images using the new width and height
#output_size = (new_width, new_height)
#images_pil = [Image.fromarray(np.array(image)) for image in images]
#images_pil_resized = [image_pil.resize(output_size) for image_pil in images_pil]
#images = [np.array(image_pil) for image_pil in images_pil_resized]
## Resize the images using the new width and height
# output_size = (new_width, new_height)
# images_pil = [Image.fromarray(np.array(image)) for image in images]
# images_pil_resized = [image_pil.resize(output_size) for image_pil in images_pil]
# images = [np.array(image_pil) for image_pil in images_pil_resized]

for tracker in self.accelerator.trackers:
if tracker.name == "tensorboard":
Expand Down
16 changes: 12 additions & 4 deletions muse_maskgit_pytorch/trainers/vqvae_trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,19 @@ def log_validation_images(self, logs, steps):

# Scale the images
if self.validation_image_scale > 1:
grid = torch.nn.functional.interpolate(grid.unsqueeze(0), scale_factor=self.validation_image_scale, mode="bicubic", align_corners=False)
grid = torch.nn.functional.interpolate(
grid.unsqueeze(0),
scale_factor=self.validation_image_scale,
mode="bicubic",
align_corners=False,
)
if self.use_ema:
ema_grid = torch.nn.functional.interpolate(ema_grid.unsqueeze(0), scale_factor=self.validation_image_scale, mode="bicubic", align_corners=False)
ema_grid = torch.nn.functional.interpolate(
ema_grid.unsqueeze(0),
scale_factor=self.validation_image_scale,
mode="bicubic",
align_corners=False,
)

# Save grid
grid_file = f"{steps}_{i}.png"
Expand All @@ -246,8 +256,6 @@ def log_validation_images(self, logs, steps):
super().log_validation_images(log_imgs, steps, prompts=prompts)
self.model.train()



def train(self):
self.steps = self.steps + 1
device = self.device
Expand Down
50 changes: 36 additions & 14 deletions muse_maskgit_pytorch/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from __future__ import print_function
import re, glob, os, torch

import glob
import os
import re

import torch


def get_latest_checkpoints(resume_path, use_ema=False, model_type="vae", cond_image_size=False):
"""Gets the latest checkpoint paths for both the non-ema and ema VAEs.
Expand All @@ -15,8 +21,10 @@ def get_latest_checkpoints(resume_path, use_ema=False, model_type="vae", cond_im
if cond_image_size:
checkpoint_files = glob.glob(os.path.join(vae_path, "maskgit_superres.*.pt"))
else:
checkpoint_files = glob.glob(os.path.join(vae_path, "vae.*.pt" if model_type == "vae" else "maskgit.*.pt"))
#print(checkpoint_files)
checkpoint_files = glob.glob(
os.path.join(vae_path, "vae.*.pt" if model_type == "vae" else "maskgit.*.pt")
)
# print(checkpoint_files)

print(f"Finding latest {'VAE' if model_type == 'vae' else 'MaskGit'} checkpoint...")

Expand All @@ -29,7 +37,9 @@ def get_latest_checkpoints(resume_path, use_ema=False, model_type="vae", cond_im
else:
latest_non_ema_checkpoint_file = max(
checkpoint_files,
key=lambda x: int(re.search(r"vae\.(\d+)\.pt$" if model_type == "vae" else r"maskgit\.(\d+)\.pt$", x).group(1))
key=lambda x: int(
re.search(r"vae\.(\d+)\.pt$" if model_type == "vae" else r"maskgit\.(\d+)\.pt$", x).group(1)
)
if not x.endswith("ema.pt")
else -1,
)
Expand All @@ -38,9 +48,7 @@ def get_latest_checkpoints(resume_path, use_ema=False, model_type="vae", cond_im
if os.path.getsize(latest_non_ema_checkpoint_file) == 0 or not os.access(
latest_non_ema_checkpoint_file, os.R_OK
):
print(
f"Warning: latest checkpoint {latest_non_ema_checkpoint_file} is empty or unreadable."
)
print(f"Warning: latest checkpoint {latest_non_ema_checkpoint_file} is empty or unreadable.")
if len(checkpoint_files) > 1:
# Use the second last checkpoint as a fallback
if cond_image_size:
Expand All @@ -51,7 +59,11 @@ def get_latest_checkpoints(resume_path, use_ema=False, model_type="vae", cond_im
else:
latest_non_ema_checkpoint_file = max(
checkpoint_files[:-1],
key=lambda x: int(re.search(r"vae\.(\d+)\.pt$" if model_type == "vae" else r"maskgit\.(\d+)\.pt$", x).group(1))
key=lambda x: int(
re.search(
r"vae\.(\d+)\.pt$" if model_type == "vae" else r"maskgit\.(\d+)\.pt$", x
).group(1)
)
if not x.endswith("ema.pt")
else -1,
)
Expand All @@ -71,17 +83,19 @@ def get_latest_checkpoints(resume_path, use_ema=False, model_type="vae", cond_im
else:
latest_ema_checkpoint_file = max(
checkpoint_files,
key=lambda x: int(re.search(r"vae\.(\d+)\.ema\.pt$" if model_type == "vae" else r"maskgit\.(\d+)\.ema\.pt$", x).group(1))
key=lambda x: int(
re.search(
r"vae\.(\d+)\.ema\.pt$" if model_type == "vae" else r"maskgit\.(\d+)\.ema\.pt$", x
).group(1)
)
if x.endswith("ema.pt")
else -1,
)

if os.path.getsize(latest_ema_checkpoint_file) == 0 or not os.access(
latest_ema_checkpoint_file, os.R_OK
):
print(
f"Warning: latest EMA checkpoint {latest_ema_checkpoint_file} is empty or unreadable."
)
print(f"Warning: latest EMA checkpoint {latest_ema_checkpoint_file} is empty or unreadable.")
if len(checkpoint_files) > 1:
# Use the second last checkpoint as a fallback
if cond_image_size:
Expand All @@ -94,7 +108,14 @@ def get_latest_checkpoints(resume_path, use_ema=False, model_type="vae", cond_im
else:
latest_ema_checkpoint_file = max(
checkpoint_files[:-1],
key=lambda x: int(re.search(r"vae\.(\d+)\.ema\.pt$" if model_type == "vae" else r"maskgit\.(\d+)\.ema\.pt$", x).group(1))
key=lambda x: int(
re.search(
r"vae\.(\d+)\.ema\.pt$"
if model_type == "vae"
else r"maskgit\.(\d+)\.ema\.pt$",
x,
).group(1)
)
if x.endswith("ema.pt")
else -1,
)
Expand All @@ -104,6 +125,7 @@ def get_latest_checkpoints(resume_path, use_ema=False, model_type="vae", cond_im

return latest_non_ema_checkpoint_file, latest_ema_checkpoint_file


def remove_duplicate_weights(ema_state_dict, non_ema_state_dict):
"""Removes duplicate weights from the ema state dictionary.
Expand All @@ -119,4 +141,4 @@ def remove_duplicate_weights(ema_state_dict, non_ema_state_dict):
for key, value in ema_state_dict.items():
if key in non_ema_state_dict and torch.equal(ema_state_dict[key], non_ema_state_dict[key]):
del ema_state_dict_copy[key]
return ema_state_dict_copy
return ema_state_dict_copy
1 change: 0 additions & 1 deletion muse_maskgit_pytorch/vqgan_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,6 @@ def load_state_dict(self, *args, **kwargs):
except RuntimeError:
return super().load_state_dict(*args, **kwargs, strict=False)


def save(self, path):
if self.accelerator is not None:
self.accelerator.save(self.state_dict(), path)
Expand Down
10 changes: 8 additions & 2 deletions train_muse_maskgit.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from omegaconf import OmegaConf
from rich import inspect
from torch.optim import Optimizer

from muse_maskgit_pytorch.utils import (
get_latest_checkpoints,
)
Expand Down Expand Up @@ -707,9 +708,14 @@ def main():

if args.latest_checkpoint:
try:
args.resume_path, ema_model_path = get_latest_checkpoints(args.resume_path, use_ema=args.use_ema, model_type="maskgit", cond_image_size=args.cond_image_size)
args.resume_path, ema_model_path = get_latest_checkpoints(
args.resume_path,
use_ema=args.use_ema,
model_type="maskgit",
cond_image_size=args.cond_image_size,
)
print(f"Resuming MaskGit from latest checkpoint: {args.resume_path}")
#if args.use_ema:
# if args.use_ema:
# print(f"Resuming EMA MaskGit from latest checkpoint: {ema_model_path}")

except ValueError:
Expand Down
17 changes: 10 additions & 7 deletions train_muse_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
from accelerate.utils import ProjectConfiguration
from datasets import load_dataset
from omegaconf import OmegaConf
from muse_maskgit_pytorch.utils import (
get_latest_checkpoints,
)

from muse_maskgit_pytorch import (
VQGanVAE,
VQGanVAETaming,
Expand All @@ -21,6 +19,9 @@
get_dataset_from_dataroot,
split_dataset_into_dataloaders,
)
from muse_maskgit_pytorch.utils import (
get_latest_checkpoints,
)

# disable bitsandbytes welcome message.
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
Expand Down Expand Up @@ -467,7 +468,7 @@ def main():
if args.resume_path is not None and len(args.resume_path) > 1:
load = True

accelerator.print(f"Loading Muse VQGanVAE...")
accelerator.print("Loading Muse VQGanVAE...")
vae = VQGanVAE(
dim=args.dim,
vq_codebook_dim=args.vq_codebook_dim,
Expand All @@ -481,7 +482,9 @@ def main():

if args.latest_checkpoint:
try:
args.resume_path, ema_model_path = get_latest_checkpoints(args.resume_path, use_ema=args.use_ema, model_type="vae")
args.resume_path, ema_model_path = get_latest_checkpoints(
args.resume_path, use_ema=args.use_ema, model_type="vae"
)

if ema_model_path:
ema_vae = VQGanVAE(
Expand All @@ -506,7 +509,7 @@ def main():
load = False

if load:
#vae.load(args.resume_path if not args.use_ema or not ema_model_path else ema_model_path, map="cpu")
# vae.load(args.resume_path if not args.use_ema or not ema_model_path else ema_model_path, map="cpu")
vae.load(args.resume_path, map="cpu")

resume_from_parts = args.resume_path.split(".")
Expand All @@ -518,7 +521,7 @@ def main():
if current_step == 0:
accelerator.print("No step found for the VAE model.")
else:
#accelerator.print("Resuming VAE from: ", args.resume_path)
# accelerator.print("Resuming VAE from: ", args.resume_path)
ema_vae = None
accelerator.print("No step found for the VAE model.")
current_step = 0
Expand Down

0 comments on commit 5b2099a

Please sign in to comment.