Skip to content

Commit

Permalink
Implement ZeroCool QoL
Browse files (browse the repository at this point in the history)
  • Loading branch information
korakoe committed Jun 3, 2023
1 parent 373febb commit 47c1be0
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 16 deletions.
30 changes: 25 additions & 5 deletions muse_maskgit_pytorch/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(
flip=True,
center_crop=True,
stream=False,
using_taming=False
):
super().__init__()
self.dataset = dataset
Expand All @@ -46,6 +47,7 @@ def __init__(
transform_list.append(T.CenterCrop(image_size))
transform_list.append(T.ToTensor())
self.transform = T.Compose(transform_list)
self.using_taming = using_taming

def __len__(self):
if not self.stream:
Expand All @@ -55,7 +57,10 @@ def __len__(self):

def __getitem__(self, index):
image = self.dataset[index][self.image_column]
return self.transform(image) - 0.5
if self.using_taming:
return self.transform(image) - 0.5
else:
return self.transform(image)


class ImageTextDataset(ImageDataset):
Expand All @@ -69,6 +74,7 @@ def __init__(
flip=True,
center_crop=True,
stream=False,
using_taming=False
):
super().__init__(
dataset,
Expand All @@ -77,6 +83,7 @@ def __init__(
flip=flip,
center_crop=center_crop,
stream=stream,
using_taming=using_taming
)
self.caption_column: str = caption_column
self.tokenizer: T5Tokenizer = tokenizer
Expand Down Expand Up @@ -104,7 +111,11 @@ def __getitem__(self, index):

input_ids = encoded.input_ids
attn_mask = encoded.attention_mask
return self.transform(image), input_ids[0], attn_mask[0]

if self.using_taming:
return self.transform(image) - 0.5, input_ids[0], attn_mask[0]
else:
return self.transform(image), input_ids[0], attn_mask[0]


class URLTextDataset(ImageDataset):
Expand All @@ -117,13 +128,15 @@ def __init__(
caption_column="caption",
flip=True,
center_crop=True,
using_taming=True
):
super().__init__(
dataset,
image_size=image_size,
image_column=image_column,
flip=flip,
center_crop=center_crop,
using_taming=using_taming
)
self.caption_column: str = caption_column
self.tokenizer: T5Tokenizer = tokenizer
Expand Down Expand Up @@ -161,13 +174,17 @@ def __getitem__(self, index):

input_ids = encoded.input_ids
attn_mask = encoded.attention_mask
return self.transform(image), input_ids[0], attn_mask[0]
if self.using_taming:
return self.transform(image) - 0.5, input_ids[0], attn_mask[0]
else:
return self.transform(image), input_ids[0], attn_mask[0]


class LocalTextImageDataset(Dataset):
def __init__(self, path, image_size, tokenizer, flip=True, center_crop=True):
def __init__(self, path, image_size, tokenizer, flip=True, center_crop=True, using_taming=False):
super().__init__()
self.tokenizer = tokenizer
self.using_taming = using_taming

print("Building dataset...")

Expand Down Expand Up @@ -226,7 +243,10 @@ def __getitem__(self, index):

input_ids = encoded.input_ids
attn_mask = encoded.attention_mask
return self.transform(image), input_ids[0], attn_mask[0]
if self.using_taming:
return self.transform(image) - 0.5, input_ids[0], attn_mask[0]
else:
return self.transform(image), input_ids[0], attn_mask[0]


def get_directory_size(path):
Expand Down
36 changes: 36 additions & 0 deletions muse_maskgit_pytorch/trainers/base_accelerated_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from lion_pytorch import Lion
from torch import nn
from torch.optim import Adam, AdamW, Optimizer
from torch_optimizer import AdaBound, AdaMod, AccSGD, AdamP, AggMo, DiffGrad, \
Lamb, NovoGrad, PID, QHAdam, QHM, RAdam, SGDP, SGDW, Shampoo, SWATS, Yogi
from transformers.optimization import Adafactor
from torch.utils.data import DataLoader, random_split

Expand Down Expand Up @@ -136,6 +138,40 @@ def get_optimizer(
)
elif optimizer == "Adafactor":
return Adafactor(parameters, lr=lr, weight_decay=weight_decay, relative_step=False, scale_parameter=False, **optimizer_kwargs)
elif optimizer == "AccSGD":
return AccSGD(parameters, lr=lr, weight_decay=weight_decay)
elif optimizer == "AdaBound":
return AdaBound(parameters, lr=lr, weight_decay=weight_decay)
elif optimizer == "AdaMod":
return AdaMod(parameters, lr=lr, weight_decay=weight_decay)
elif optimizer == "AdamP":
return AdamP(parameters, lr=lr, weight_decay=weight_decay)
elif optimizer == "AggMo":
return AggMo(parameters, lr=lr, weight_decay=weight_decay)
elif optimizer == "DiffGrad":
return DiffGrad(parameters, lr=lr, weight_decay=weight_decay)
elif optimizer == "Lamb":
return Lamb(parameters, lr=lr, weight_decay=weight_decay)
elif optimizer == "NovoGrad":
return NovoGrad(parameters, lr=lr, weight_decay=weight_decay)
elif optimizer == "PID":
return PID(parameters, lr=lr, weight_decay=weight_decay)
elif optimizer == "QHAdam":
return QHAdam(parameters, lr=lr, weight_decay=weight_decay)
elif optimizer == "QHM":
return QHM(parameters, lr=lr, weight_decay=weight_decay)
elif optimizer == "RAdam":
return RAdam(parameters, lr=lr, weight_decay=weight_decay)
elif optimizer == "SGDP":
return SGDP(parameters, lr=lr, weight_decay=weight_decay)
elif optimizer == "SGDW":
return SGDW(parameters, lr=lr, weight_decay=weight_decay)
elif optimizer == "Shampoo":
return Shampoo(parameters, lr=lr, weight_decay=weight_decay)
elif optimizer == "SWATS":
return SWATS(parameters, lr=lr, weight_decay=weight_decay)
elif optimizer == "Yogi":
return Yogi(parameters, lr=lr, weight_decay=weight_decay)
else:
raise NotImplementedError(f"{optimizer} optimizer not supported yet.")

Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
"taming-transformers @ git+https://github.com/neggles/[email protected]",
"transformers",
"torchvision",
"torch_optimizer",
"tqdm",
"vector-quantize-pytorch>=0.10.14",
"lion-pytorch",
Expand Down
120 changes: 109 additions & 11 deletions train_muse_maskgit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
from diffusers.optimization import SchedulerType, get_scheduler
from torch.optim import Optimizer

import os
import glob
import re

try:
import torch_xla
import torch_xla.core.xla_model as xm
Expand Down Expand Up @@ -287,7 +291,10 @@
"--optimizer",
type=str,
default="Adafactor",
help="Optimizer to use. Choose between: ['Adam', 'AdamW', 'Lion', 'Adafactor']. Default: Adafactor (paper recommended)",
help="Optimizer to use. Choose between: ['Adam', 'AdamW','Lion', 'Adafactor', "
"'AdaBound', 'AdaMod', 'AccSGD', 'AdamP', 'AggMo', 'DiffGrad', 'Lamb', "
"'NovoGrad', 'PID', 'QHAdam', 'QHM', 'RAdam', 'SGDP', 'SGDW', 'Shampoo', "
"'SWATS', 'Yogi']. Default: Lion",
)
parser.add_argument(
"--weight_decay",
Expand All @@ -311,6 +318,11 @@
action="store_true",
help="whether to load a dataset with links instead of image (image column becomes URL column)",
)
parser.add_argument(
"--latest_checkpoint",
action="store_true",
help="Automatically find and use the latest checkpoint in the folder.",
)
parser.add_argument(
"--debug",
action="store_true",
Expand Down Expand Up @@ -374,6 +386,7 @@ class Arguments:
cache_path: Optional[str] = None
skip_arrow: bool = False
link: bool = True
latest_checkpoint: bool = False
debug: bool = False


Expand Down Expand Up @@ -438,13 +451,54 @@ def main():
# Load the VAE
with accelerator.main_process_first():
if args.vae_path is not None:
load = True
accelerator.print(f"Using Muse VQGanVAE, loading from {args.vae_path}")
vae = VQGanVAE(
dim=args.dim,
vq_codebook_size=args.vq_codebook_size,
accelerator=accelerator,
)
vae.load(args.vae_path, map="cpu")

if args.latest_checkpoint:
accelerator.print("Finding latest checkpoint...")
orig_vae_path = args.vae_path

if os.path.isfile(args.vae_path) or '.pt' in args.vae_path:
# If args.vae_path is an existing file, or merely contains '.pt', keep only its directory part (the filename is discarded)
args.vae_path, _ = os.path.split(args.vae_path)

checkpoint_files = glob.glob(os.path.join(args.vae_path, "vae.*.pt"))
if checkpoint_files:
latest_checkpoint_file = max(checkpoint_files,
key=lambda x: int(re.search(r'vae\.(\d+)\.pt', x).group(1)))

# Check if latest checkpoint is empty or unreadable
if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(latest_checkpoint_file, os.R_OK):
accelerator.print(
f"Warning: latest checkpoint {latest_checkpoint_file} is empty or unreadable.")
if len(checkpoint_files) > 1:
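# Intended fallback: use the second-latest checkpoint.
# NOTE(review): glob.glob() returns matches in arbitrary order, so checkpoint_files[:-1] drops an arbitrary entry and max() may select the same empty/unreadable file again - confirm.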
latest_checkpoint_file = max(checkpoint_files[:-1],
key=lambda x: int(re.search(r'vae\.(\d+)\.pt', x).group(1)))
accelerator.print("Using second last checkpoint: ", latest_checkpoint_file)
else:
accelerator.print("No usable checkpoint found.")
load = False
elif latest_checkpoint_file != orig_vae_path:
accelerator.print("Resuming VAE from latest checkpoint: ", latest_checkpoint_file)
else:
accelerator.print("Using checkpoint specified in vae_path: ", orig_vae_path)

args.vae_path = latest_checkpoint_file
else:
accelerator.print("No checkpoints found in directory: ", args.vae_path)
load = False
else:
accelerator.print("Resuming VAE from: ", args.vae_path)

if load:
vae.load(args.vae_path, map="cpu")

elif args.taming_model_path is not None and args.taming_config_path is not None:
print(f"Using Taming VQGanVAE, loading from {args.taming_model_path}")
vae = VQGanVAETaming(
Expand Down Expand Up @@ -490,16 +544,57 @@ def main():
# load the maskgit transformer from disk if we have previously trained one
with accelerator.main_process_first():
if args.resume_path:
load = True
accelerator.print(f"Resuming MaskGit from: {args.resume_path}")
maskgit.load(args.resume_path)
resume_from_parts = args.resume_path.split(".")
for i in range(len(resume_from_parts) - 1, -1, -1):
if resume_from_parts[i].isdigit():
current_step = int(resume_from_parts[i])
accelerator.print(f"Found step {current_step} for the MaskGit model.")
break
if current_step == 0:
accelerator.print("No step found for the MaskGit model.")

if args.latest_checkpoint:
accelerator.print("Finding latest checkpoint...")
orig_vae_path = args.resume_path

if os.path.isfile(args.resume_path) or '.pt' in args.resume_path:
# If args.resume_path is an existing file, or merely contains '.pt', keep only its directory part (the filename is discarded)
args.resume_path, _ = os.path.split(args.resume_path)

checkpoint_files = glob.glob(os.path.join(args.resume_path, "maskgit.*.pt"))
if checkpoint_files:
latest_checkpoint_file = max(checkpoint_files,
key=lambda x: int(re.search(r'maskgit\.(\d+)\.pt', x).group(1)))

# Check if latest checkpoint is empty or unreadable
if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(latest_checkpoint_file, os.R_OK):
accelerator.print(
f"Warning: latest checkpoint {latest_checkpoint_file} is empty or unreadable.")
if len(checkpoint_files) > 1:
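# Intended fallback: use the second-latest checkpoint.
# NOTE(review): glob.glob() returns matches in arbitrary order, so checkpoint_files[:-1] drops an arbitrary entry and max() may select the same empty/unreadable file again - confirm.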
latest_checkpoint_file = max(checkpoint_files[:-1], key=lambda x: int(
re.search(r'maskgit\.(\d+)\.pt', x).group(1)))
accelerator.print("Using second last checkpoint: ", latest_checkpoint_file)
else:
accelerator.print("No usable checkpoint found.")
load = False
elif latest_checkpoint_file != orig_vae_path:
accelerator.print("Resuming MaskGit from latest checkpoint: ", latest_checkpoint_file)
else:
accelerator.print("Using checkpoint specified in resume_path: ", orig_vae_path)

args.resume_path = latest_checkpoint_file
else:
accelerator.print("No checkpoints found in directory: ", args.resume_path)
load = False
else:
accelerator.print("Resuming MaskGit from: ", args.resume_path)

if load:
maskgit.load(args.resume_path)

resume_from_parts = args.resume_path.split(".")
for i in range(len(resume_from_parts) - 1, -1, -1):
if resume_from_parts[i].isdigit():
current_step = int(resume_from_parts[i])
accelerator.print(f"Found step {current_step} for the MaskGit model.")
break
if current_step == 0:
accelerator.print("No step found for the MaskGit model.")
else:
accelerator.print("Initialized new empty MaskGit model.")
current_step = 0
Expand All @@ -513,6 +608,7 @@ def main():
tokenizer=transformer.tokenizer,
center_crop=False if args.no_center_crop else True,
flip=False if args.no_flip else True,
using_taming=False if not args.taming_model_path else True
)
elif args.link:
if not args.dataset_name:
Expand All @@ -526,6 +622,7 @@ def main():
caption_column=args.caption_column,
center_crop=False if args.no_center_crop else True,
flip=False if args.no_flip else True,
using_taming=False if not args.taming_model_path else True
)
else:
dataset = ImageTextDataset(
Expand All @@ -537,6 +634,7 @@ def main():
center_crop=False if args.no_center_crop else True,
flip=False if args.no_flip else True,
stream=args.streaming,
using_taming=False if not args.taming_model_path else True
)

# Create the dataloaders
Expand Down

0 comments on commit 47c1be0

Please sign in to comment.