Skip to content

Commit

Permalink
Added initial code for experimental QOL which allow us to give the tr…
Browse files Browse the repository at this point in the history
…ainer a folder with images so those images are reconstructed at the end of each epoch, this should help understand better how the training is progressing since we will be reconstructing a fixed set of images at a fixed interval.
  • Loading branch information
ZeroCool940711 committed Sep 16, 2023
1 parent 6800db7 commit 9e9a86c
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 124 deletions.
109 changes: 2 additions & 107 deletions infer_vae.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,14 @@
import argparse
import glob
import hashlib
import os
import random
import re
import shutil
from dataclasses import dataclass
from datetime import datetime
from typing import Optional

import accelerate
import PIL
import torch
from accelerate.utils import ProjectConfiguration
from datasets import Dataset, Image, load_dataset
from torchvision.utils import save_image
from tqdm import tqdm

from muse_maskgit_pytorch import (
VQGanVAE,
Expand All @@ -28,6 +21,7 @@
)
from muse_maskgit_pytorch.utils import (
get_latest_checkpoints,
vae_folder_validation,
)
from muse_maskgit_pytorch.vqvae import VQVAE

Expand Down Expand Up @@ -458,106 +452,7 @@ def main():
save_image(recon, f"{args.results_dir}/outputs/output.png")

if args.input_folder:
# Create output directory and save input images and reconstructions as grids
output_dir = os.path.join(args.results_dir, "outputs", os.path.basename(args.input_folder))
os.makedirs(output_dir, exist_ok=True)

for i in tqdm(range(len(dataset))):
retries = 0
while True:
try:
save_image(dataset[i], f"{output_dir}/input.png")

if not args.use_paintmind:
# encode
_, ids, _ = vae.encode(
dataset[i][None].to(
"cpu"
if args.cpu
else accelerator.device
if args.gpu == 0
else f"cuda:{args.gpu}"
)
)
# decode
recon = vae.decode_from_ids(ids)
# print (recon.shape) # torch.Size([1, 3, 512, 1136])
save_image(recon, f"{output_dir}/output.png")
else:
# encode
encoded, _, _ = vae.encode(
dataset[i][None].to(
"cpu"
if args.cpu
else accelerator.device
if args.gpu == 0
else f"cuda:{args.gpu}"
)
)

# decode
recon = vae.decode(encoded).squeeze(0)
recon = torch.clamp(recon, -1.0, 1.0)
save_image(recon, f"{output_dir}/output.png")

# Load input and output images
input_image = PIL.Image.open(f"{output_dir}/input.png")
output_image = PIL.Image.open(f"{output_dir}/output.png")

# Create horizontal grid with input and output images
grid_image = PIL.Image.new(
"RGB" if args.channels == 3 else "RGBA",
(input_image.width + output_image.width, input_image.height),
)
grid_image.paste(input_image, (0, 0))
grid_image.paste(output_image, (input_image.width, 0))

# Save grid
now = datetime.now().strftime("%m-%d-%Y_%H-%M-%S")
hash = hashlib.sha1(input_image.tobytes()).hexdigest()

filename = f"{hash}_{now}-{os.path.basename(args.vae_path)}.png"
grid_image.save(f"{output_dir}/{filename}", format="PNG")

if not args.save_originals:
# Remove input and output images after the grid was made.
os.remove(f"{output_dir}/input.png")
os.remove(f"{output_dir}/output.png")
else:
os.makedirs(os.path.join(output_dir, "originals"), exist_ok=True)
shutil.move(
f"{output_dir}/input.png",
f"{os.path.join(output_dir, 'originals')}/input_{now}.png",
)
shutil.move(
f"{output_dir}/output.png",
f"{os.path.join(output_dir, 'originals')}/output_{now}.png",
)

del _
del ids
del recon

torch.cuda.empty_cache()
torch.cuda.ipc_collect()

break # Exit the retry loop if there were no errors

except RuntimeError as e:
if "out of memory" in str(e) and retries < args.max_retries:
retries += 1
# print(f"Out of Memory. Retry #{retries}")
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
continue # Retry the loop

else:
if "out of memory" not in str(e):
print(f"\n{e}")
else:
print(f"Skipping image {i} after {retries} retries due to out of memory error")
break # Exit the retry loop after too many retries

vae_folder_validation(accelerator, vae, dataset, args=args, checkpoint_name=args.vae_path, save_originals=args.save_originals)

if __name__ == "__main__":
main()
26 changes: 23 additions & 3 deletions muse_maskgit_pytorch/trainers/vqvae_trainers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import torch
import torch, os
from accelerate import Accelerator
from diffusers.optimization import get_scheduler
from ema_pytorch import EMA
Expand All @@ -8,7 +8,10 @@
from torch.utils.data import DataLoader
from torchvision.utils import make_grid, save_image
from tqdm import tqdm

from typing import Optional
from muse_maskgit_pytorch.utils import (
vae_folder_validation,
)
from muse_maskgit_pytorch.trainers.base_accelerated_trainer import (
BaseAcceleratedTrainer,
get_optimizer,
Expand Down Expand Up @@ -66,6 +69,7 @@ def __init__(
use_8bit_adam=False,
num_cycles=1,
scheduler_power=1.0,
validation_folder_at_end_of_epoch: Optional[DataLoader] = None,
args=None,
):
super().__init__(
Expand All @@ -91,6 +95,8 @@ def __init__(
# we are going to use them later to save them to a config file.
self.args = args

self.validation_folder_at_end_of_epoch = validation_folder_at_end_of_epoch

self.current_step = current_step

# vae
Expand Down Expand Up @@ -266,6 +272,7 @@ def train(self):
else:
proc_label = f"[P{self.accelerator.process_index:03d}][Worker]"


for epoch in range(self.current_step // len(self.dl), self.num_epochs):
for img in self.dl:
loss = 0.0
Expand Down Expand Up @@ -340,7 +347,11 @@ def train(self):
)

logs["lr"] = self.lr_scheduler.get_last_lr()[0]
self.accelerator.log(logs, step=steps)
try:
self.accelerator.log(logs, step=steps)
except ConnectionResetError:
print ("There was an error with the Wandb connection. Retrying...")
self.accelerator.log(logs, step=steps)

# update exponential moving averaged generator

Expand Down Expand Up @@ -386,6 +397,15 @@ def train(self):

self.steps += 1

#

if self.validation_folder_at_end_of_epoch:
vae_folder_validation(self.accelerator, self.model, self.validation_folder_at_end_of_epoch,
self.args,
checkpoint_name=os.path.join(self.results_dir, f'vae.{steps}.pt'),

)

# if self.num_train_steps > 0 and int(self.steps.item()) >= self.num_train_steps:
# self.accelerator.print(
# f"\n[E{epoch + 1}][{steps}]{proc_label}: " f"[STOP EARLY]: Stopping training early..."
Expand Down
39 changes: 26 additions & 13 deletions muse_maskgit_pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,11 @@ def remove_duplicate_weights(ema_state_dict, non_ema_state_dict):
del ema_state_dict_copy[key]
return ema_state_dict_copy

def vae_folder_validation(accelerator, vae, dataset, args=None):
def vae_folder_validation(accelerator, vae, dataset, args=None, checkpoint_name="vae", save_originals=False):

# Create output directory and save input images and reconstructions as grids
output_dir = os.path.join(args.results_dir, "outputs", os.path.basename(args.input_folder))
output_dir = os.path.join(args.results_dir, "outputs",
os.path.basename(args.input_folder if args.input_folder else args.validation_folder_at_end_of_epoch))
os.makedirs(output_dir, exist_ok=True)

for i in tqdm(range(len(dataset))):
Expand All @@ -159,16 +160,26 @@ def vae_folder_validation(accelerator, vae, dataset, args=None):
try:
save_image(dataset[i], f"{output_dir}/input.png")

# encode
encoded, _, _ = vae.encode(
dataset[i][None].to(
"cpu"
if args.cpu
else accelerator.device
if args.gpu == 0
else f"cuda:{args.gpu}"
try:
# encode
encoded, _, _ = vae.encode(
dataset[i][None].to(
"cpu"
if args.cpu
else accelerator.device
if args.gpu == 0
else f"cuda:{args.gpu}"
)
)
except AttributeError:
# encode
encoded, _, _ = vae.encode(
dataset[i][None].to(
accelerator.device
if accelerator.device
else f"cuda:{args.gpu}"
)
)
)

# decode
recon = vae.decode(encoded).squeeze(0)
Expand All @@ -191,10 +202,10 @@ def vae_folder_validation(accelerator, vae, dataset, args=None):
now = datetime.now().strftime("%m-%d-%Y_%H-%M-%S")
hash = hashlib.sha1(input_image.tobytes()).hexdigest()

filename = f"{hash}_{now}-{os.path.basename(args.vae_path)}.png"
filename = f"{hash}_{now}-{os.path.basename(checkpoint_name)}.png"
grid_image.save(f"{output_dir}/{filename}", format="PNG")

if not args.save_originals:
if not save_originals:
# Remove input and output images after the grid was made.
os.remove(f"{output_dir}/input.png")
os.remove(f"{output_dir}/output.png")
Expand All @@ -215,6 +226,8 @@ def vae_folder_validation(accelerator, vae, dataset, args=None):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

dataset[i][None].to("cpu")

break # Exit the retry loop if there were no errors

except RuntimeError as e:
Expand Down
56 changes: 55 additions & 1 deletion train_muse_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import wandb
from accelerate.utils import ProjectConfiguration
from datasets import load_dataset
from datasets import load_dataset, Dataset, Image
from omegaconf import OmegaConf

from muse_maskgit_pytorch import (
Expand Down Expand Up @@ -319,6 +319,19 @@
action="store_true",
help="Use F.mse_loss instead of F.l1_loss.",
)
parser.add_argument(
"--validation_folder_at_end_of_epoch",
type=str,
default=None,
help="Path to a folder containing images that will be used for validation/reconstruction."
" At the end of each Epoch this folder will be used for validation and reconstructions will be saved to a subfolder called 'outputs/validation'.",
)
parser.add_argument(
"--exclude_folders",
type=str,
default=None,
help="List of folders we want to exclude when doing reconstructions from an input folder.",
)


@dataclass
Expand Down Expand Up @@ -383,6 +396,9 @@ class Arguments:
use_l2_recon_loss: bool = False
debug: bool = False
config_path: Optional[str] = None
validation_folder_at_end_of_epoch: Optional[str] = None
input_folder = None
exclude_folders: Optional[str] = None


def preprocess_webdataset(args, image):
Expand Down Expand Up @@ -575,6 +591,43 @@ def main():
dataloader, validation_dataloader = split_dataset_into_dataloaders(
dataset, args.valid_frac, args.seed, args.batch_size
)

if args.validation_folder_at_end_of_epoch:
# Create dataset from input folder
extensions = ["jpg", "jpeg", "png", "webp"]
exclude_folders = args.exclude_folders.split(",") if args.exclude_folders else []

filepaths = []
for root, dirs, files in os.walk(args.validation_folder_at_end_of_epoch, followlinks=True):
# Resolve symbolic link to actual path and exclude based on actual path
resolved_root = os.path.realpath(root)
for exclude_folder in exclude_folders:
if exclude_folder in resolved_root:
dirs[:] = []
break
for file in files:
if file.lower().endswith(tuple(extensions)):
filepaths.append(os.path.join(root, file))

if not filepaths:
print(f"No images with extensions {extensions} found in {args.validation_folder_at_end_of_epoch}.")
exit(1)

epoch_validation_dataset = Dataset.from_dict({"image": filepaths}).cast_column("image", Image())

epoch_validation_dataset = ImageDataset(
epoch_validation_dataset,
image_size=512,
image_column=args.image_column,
center_crop=False,
flip=False,
random_crop=False,
alpha_channel=False if args.channels == 3 else True,
)

else:
epoch_validation_dataset = None

trainer = VQGanVAETrainer(
vae,
dataloader,
Expand Down Expand Up @@ -606,6 +659,7 @@ def main():
num_cycles=args.num_cycles,
scheduler_power=args.scheduler_power,
num_epochs=args.num_epochs,
validation_folder_at_end_of_epoch=epoch_validation_dataset,
args=args,
)

Expand Down

0 comments on commit 9e9a86c

Please sign in to comment.