Merge pull request Sygil-Dev#5 from korakoe/adding_training_script
Merge to main
korakoe authored Jun 3, 2023
2 parents a5476c6 + 00baf98 commit 373febb
Showing 20 changed files with 3,319 additions and 967 deletions.
11 changes: 11 additions & 0 deletions .gitignore
@@ -1,3 +1,14 @@
*.ipynb
wandb
results
models
dataset
taming
~
input.png
output.png
muse_maskgit_pytorch/wt.py

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
22 changes: 21 additions & 1 deletion README.md
@@ -217,7 +217,26 @@ images = muse([

images # List[PIL.Image.Image]
```
## Training

Training should be done in 4 stages.
1. Train the base VAE (swap out the dataset_name below for your own Hugging Face dataset):
```
accelerate launch train_muse_vae.py --dataset_name="Isamu136/big-animal-dataset"
```
2. Once the base VAE has trained long enough, move your latest checkpoint to a new location, then run:
```
accelerate launch train_muse_maskgit.py --dataset_name="Isamu136/big-animal-dataset" --vae_path=path_to_vae_checkpoint
```
Alternatively, if you want to use a pretrained autoencoder, download one from [here](https://github.com/CompVis/taming-transformers) and extract it. The command below uses vqgan_imagenet_f16_1024; change the paths accordingly.
```
accelerate launch train_muse_maskgit.py --dataset_name="Isamu136/big-animal-dataset" --taming_model_path="models/image_net_f16/ckpts/last.ckpt" --taming_config_path="models/image_net_f16/configs/model.yaml" --validation_prompt="elephant"
```
Or, if you want to train on cifar10, try:

```
accelerate launch train_muse_maskgit.py --dataset_name="cifar10" --taming_model_path="models/image_net_f16/ckpts/last.ckpt" --taming_config_path="models/image_net_f16/configs/model.yaml" --validation_prompt="0" --image_column="img" --caption_column="label"
```
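To spot-check reconstructions from a trained VAE, you can use infer_vae.py. As a rough example (the checkpoint path is a placeholder for your own), an invocation might look like:
```
accelerate launch infer_vae.py --dataset_name="cifar10" --image_column="img" --vae_path="results/vae.10000.pt" --random_image
```
This writes the sampled image and its reconstruction to results/outputs/input.png and results/outputs/output.png.
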
## Appreciation

- <a href="https://stability.ai/">StabilityAI</a> for the sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.
@@ -232,7 +251,8 @@ images # List[PIL.Image.Image]
- [x] add optional self-conditioning on embeddings
- [x] combine with token critic paper, already implemented at <a href="https://github.com/lucidrains/phenaki-pytorch">Phenaki</a>

- [ ] hook up accelerate training code for maskgit
- [x] hook up accelerate training code for maskgit
- [ ] train a base model

## Citations

216 changes: 216 additions & 0 deletions infer_vae.py
@@ -0,0 +1,216 @@
import torch
from torchvision.utils import save_image
from datasets import load_dataset, Dataset, Image
import os, random
from muse_maskgit_pytorch import (
VQGanVAE,
VQGanVAETaming,
get_accelerator,
)
from muse_maskgit_pytorch.dataset import (
get_dataset_from_dataroot,
ImageDataset,
)

import argparse


def parse_args():
# Create the parser
parser = argparse.ArgumentParser()
parser.add_argument(
"--no_center_crop",
action="store_true",
help="Don't do center crop.",
)
parser.add_argument(
"--no_flip",
action="store_true",
help="Don't flip image.",
)
parser.add_argument(
"--random_image",
action="store_true",
help="Get a random image from the dataset to use for the reconstruction.",
)
parser.add_argument(
"--dataset_save_path",
type=str,
default="dataset",
help="Path to save the dataset if you are making one from a directory",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="Seed for reproducibility. If set to -1 a random seed will be generated.",
)
parser.add_argument("--valid_frac", type=float, default=0.05, help="validation fraction.")
parser.add_argument(
"--image_column",
type=str,
default="image",
help="The column of the dataset containing an image.",
)
parser.add_argument(
"--mixed_precision",
type=str,
default="no",
choices=["no", "fp16", "bf16"],
help="Precision to train on.",
)
parser.add_argument(
"--results_dir",
type=str,
default="results",
help="Path to save the training samples and checkpoints",
)
parser.add_argument(
"--logging_dir",
type=str,
default="results/logs",
help="Path to log the losses and LR",
)

# vae_trainer args
parser.add_argument(
"--vae_path",
type=str,
default=None,
help="Path to the vae model. eg. 'results/vae.steps.pt'",
)
parser.add_argument(
"--dataset_name",
type=str,
default=None,
help="Name of the huggingface dataset used.",
)
parser.add_argument(
"--train_data_dir",
type=str,
default=None,
help="Dataset folder where your input images for training are.",
)
parser.add_argument("--dim", type=int, default=128, help="Model dimension.")
parser.add_argument("--batch_size", type=int, default=512, help="Batch Size.")
parser.add_argument("--lr", type=float, default=1e-4, help="Learning Rate.")
parser.add_argument("--vq_codebook_size", type=int, default=256, help="Image Size.")
parser.add_argument(
"--image_size",
type=int,
default=256,
help="Image size. You may want to start with small images, and then curriculum learn to larger ones, but because the vae is all convolution, it should generalize to 512 (as in paper) without training on it",
)
parser.add_argument(
"--taming_model_path",
type=str,
default=None,
help="path to your trained VQGAN weights. This should be a .ckpt file. (only valid when taming option is enabled)",
)

parser.add_argument(
"--taming_config_path",
type=str,
default=None,
help="path to your trained VQGAN config. This should be a .yaml file. (only valid when taming option is enabled)",
)
parser.add_argument(
"--input_image",
type=str,
default=None,
help="Path to an image to use as input for reconstruction instead of using one from the dataset.",
)

# Parse the argument
return parser.parse_args()


def seed_to_int(s):
    if type(s) is int:
        # negative values (e.g. -1) mean "generate a random seed"
        return s if s >= 0 else random.randint(0, 2**32 - 1)
    if s is None or s == "":
        return random.randint(0, 2**32 - 1)

    if "," in s:
        s = s.split(",")

    if type(s) is list:
        # convert each entry, falling back to a random seed for empty entries
        seed_list = []
        for seed in s:
            if seed is None or seed == "":
                seed_list.append(random.randint(0, 2**32 - 1))
            else:
                seed_list.append(seed_to_int(seed))

        return seed_list

    # hash non-numeric strings into a deterministic 32-bit seed
    n = abs(int(s) if s.isdigit() else random.Random(s).randint(0, 2**32 - 1))
    while n >= 2**32:
        n = n >> 32
    return n
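

# A few illustrative calls (a sketch; behaviour follows the definition above):
#   seed_to_int(42)        -> 42 (non-negative ints pass through)
#   seed_to_int("")        -> a random 32-bit seed
#   seed_to_int("banana")  -> a deterministic seed hashed from the string
#   seed_to_int("1,2,3")   -> a list of seeds, one per comma-separated entry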


def main():
args = parse_args()
accelerator = get_accelerator(
mixed_precision=args.mixed_precision,
logging_dir=args.logging_dir,
)

# set pytorch seed for reproducibility
torch.manual_seed(seed_to_int(args.seed))

if args.train_data_dir and not args.input_image:
dataset = get_dataset_from_dataroot(
args.train_data_dir,
image_column=args.image_column,
save_path=args.dataset_save_path,
)
elif args.dataset_name and not args.input_image:
dataset = load_dataset(args.dataset_name)["train"]

elif args.input_image:
dataset = Dataset.from_dict({"image": [args.input_image]}).cast_column("image", Image())

if args.vae_path and args.taming_model_path:
raise Exception("You can't pass vae_path and taming args at the same time.")

if args.vae_path:
accelerator.print("Loading Muse VQGanVAE")
vae = VQGanVAE(dim=args.dim, vq_codebook_size=args.vq_codebook_size).to(accelerator.device)

accelerator.print("Resuming VAE from: ", args.vae_path)
vae.load(args.vae_path) # you will want to load the exponentially moving averaged VAE

elif args.taming_model_path:
print("Loading Taming VQGanVAE")
vae = VQGanVAETaming(
vqgan_model_path=args.taming_model_path,
vqgan_config_path=args.taming_config_path,
)
args.num_tokens = vae.codebook_size
args.seq_len = vae.get_encoded_fmap_size(args.image_size) ** 2
vae = vae.to(accelerator.device)
# wrap the dataset so images are resized and preprocessed to the VAE's expected size

dataset = ImageDataset(
dataset,
args.image_size,
image_column=args.image_column,
center_crop=not args.no_center_crop,
flip=not args.no_flip,
)

image_id = 0 if not args.random_image else random.randint(0, len(dataset) - 1)  # randint is inclusive on both ends

os.makedirs(f"{args.results_dir}/outputs", exist_ok=True)

save_image(dataset[image_id], f"{args.results_dir}/outputs/input.png")

_, ids, _ = vae.encode(dataset[image_id][None].to(accelerator.device))
recon = vae.decode_from_ids(ids)
save_image(recon, f"{args.results_dir}/outputs/output.png")


if __name__ == "__main__":
main()
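

# Example invocation (a sketch; the flags are defined in parse_args above, paths are placeholders):
#   accelerate launch infer_vae.py --vae_path="results/vae.10000.pt" --input_image="path/to/image.png"
# The input image and its reconstruction are written to results/outputs/input.png and
# results/outputs/output.png (relative to --results_dir, which defaults to "results").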
19 changes: 16 additions & 3 deletions muse_maskgit_pytorch/__init__.py
@@ -1,4 +1,17 @@
from muse_maskgit_pytorch.vqgan_vae import VQGanVAE
from muse_maskgit_pytorch.muse_maskgit_pytorch import Transformer, MaskGit, Muse, MaskGitTransformer, TokenCritic
from .muse_maskgit_pytorch import MaskGit, MaskGitTransformer, Muse, TokenCritic, Transformer
from .trainers import MaskGitTrainer, VQGanVAETrainer, get_accelerator
from .vqgan_vae import VQGanVAE
from .vqgan_vae_taming import VQGanVAETaming

from muse_maskgit_pytorch.trainers import VQGanVAETrainer
__all__ = [
"VQGanVAE",
"VQGanVAETaming",
"Transformer",
"MaskGit",
"Muse",
"MaskGitTransformer",
"TokenCritic",
"VQGanVAETrainer",
"MaskGitTrainer",
"get_accelerator",
]
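
With the imports flattened as above, the public pieces can be pulled straight from the package root; a minimal usage sketch (dim and codebook size are just illustrative values):
```
from muse_maskgit_pytorch import VQGanVAE, get_accelerator

accelerator = get_accelerator(mixed_precision="no", logging_dir="results/logs")
vae = VQGanVAE(dim=128, vq_codebook_size=256).to(accelerator.device)
```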