Skip to content

Commit

Permalink
Init commit
Browse files Browse the repository at this point in the history
  • Loading branch information
isamu-isozaki committed Feb 19, 2023
1 parent a5476c6 commit e3ad84e
Show file tree
Hide file tree
Showing 9 changed files with 1,396 additions and 9 deletions.
75 changes: 75 additions & 0 deletions muse_maskgit_pytorch/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""
Author: Isamu Isozaki ([email protected])
Description: description
Created: 2023-02-18T03:59:45.810Z
Modified: !date!
Modified By: modifier
"""
from torch.utils.data import Dataset
import torchvision.transforms as T
from PIL import Image, ImageFile
from pathlib import Path
from muse_maskgit_pytorch.t5 import MAX_LENGTH

ImageFile.LOAD_TRUNCATED_IMAGES = True


class ImageDataset(Dataset):
    """Map-style dataset that recursively collects image files under a folder
    and returns them as augmented tensors.

    Args:
        folder: root directory, searched recursively for matching files.
        image_size: target side length used by Resize and CenterCrop.
        exts: file extensions to include (immutable default avoids the
            shared-mutable-default pitfall).
    """

    def __init__(self, folder, image_size, exts=("jpg", "jpeg", "png")):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        # Gather every file with an accepted extension, at any depth.
        self.paths = [p for ext in exts for p in Path(f"{folder}").glob(f"**/*.{ext}")]

        print(f"{len(self.paths)} training samples found at {folder}")

        self.transform = T.Compose(
            [
                # Force 3-channel RGB so ToTensor always yields (3, H, W).
                T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
                T.Resize(image_size),
                T.RandomHorizontalFlip(),
                T.CenterCrop(image_size),
                T.ToTensor(),
            ]
        )

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        # Context manager closes the underlying file handle promptly;
        # PIL opens lazily and would otherwise keep it alive (the transform
        # forces the pixel data to load before the file is closed).
        with Image.open(path) as img:
            return self.transform(img)
class ImageTextDataset(Dataset):
    """Map-style dataset wrapping an (image, caption) source.

    Returns a `(image_tensor, input_ids, attention_mask)` triple per index,
    tokenizing the caption with the supplied tokenizer. When
    `caption_column` is None an empty caption is used.

    Args:
        dataset: indexable source whose items are mappings with
            `image_column` (and optionally `caption_column`) keys.
        image_size: target side length used by Resize and CenterCrop.
        tokenizer: HuggingFace-style tokenizer exposing `batch_encode_plus`.
        image_column: key of the image in each item.
        caption_column: key of the caption, or None for uncaptioned data.
    """

    def __init__(self, dataset, image_size, tokenizer, image_column="image", caption_column="caption"):
        super().__init__()
        self.image_column = image_column
        self.caption_column = caption_column
        self.tokenizer = tokenizer
        self.dataset = dataset
        self.transform = T.Compose(
            [
                # Force 3-channel RGB so ToTensor always yields (3, H, W).
                T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
                T.Resize(image_size),
                T.RandomHorizontalFlip(),
                T.CenterCrop(image_size),
                T.ToTensor(),
            ]
        )

    def __len__(self):
        # Required for map-style datasets (samplers/DataLoader sizing);
        # the original omitted it. Delegates to the wrapped dataset.
        return len(self.dataset)

    def __getitem__(self, index):
        image = self.dataset[index][self.image_column]
        # `is None` (identity test), not `== None`.
        if self.caption_column is None:
            text = ""
        else:
            text = self.dataset[index][self.caption_column]
        encoded = self.tokenizer.batch_encode_plus(
            [text],
            return_tensors="pt",
            padding="longest",
            max_length=MAX_LENGTH,
            truncation=True,
        )

        input_ids = encoded.input_ids
        attn_mask = encoded.attention_mask
        return self.transform(image), input_ids, attn_mask
12 changes: 6 additions & 6 deletions muse_maskgit_pytorch/muse_maskgit_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@
from beartype import beartype

from muse_maskgit_pytorch.vqgan_vae import VQGanVAE
from muse_maskgit_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME

from muse_maskgit_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME, get_model_and_tokenizer
from pathlib import Path
from tqdm.auto import tqdm

# helpers

def exists(val):
Expand Down Expand Up @@ -220,8 +219,8 @@ def __init__(

# text conditioning

self.encode_text = partial(t5_encode_text, name = t5_name)

self.tokenizer, self.t5 = get_model_and_tokenizer(t5_name)
self.encode_text = partial(t5_encode_text, tokenizer = self.tokenizer, t5=self.t5)
text_embed_dim = get_encoded_dim(t5_name)

self.text_embed_proj = nn.Linear(text_embed_dim, dim, bias = False) if text_embed_dim != dim else nn.Identity()
Expand Down Expand Up @@ -254,6 +253,7 @@ def forward_with_cond_scale(

def forward_with_neg_prompt(
self,
*args,
text_embed: torch.Tensor,
neg_text_embed: torch.Tensor,
cond_scale = 3.,
Expand All @@ -263,7 +263,7 @@ def forward_with_neg_prompt(
neg_logits = self.forward(*args, neg_text_embed = neg_text_embed, cond_drop_prob = 0., **kwargs)
pos_logits, embed = self.forward(*args, return_embed = True, text_embed = text_embed, cond_drop_prob = 0., **kwargs)

logits = neg_logits + (pos_logits - neg_logits) * cond_scale
scaled_logits = neg_logits + (pos_logits - neg_logits) * cond_scale

if return_embed:
return scaled_logits, embed
Expand Down
5 changes: 2 additions & 3 deletions muse_maskgit_pytorch/t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,13 @@ def get_encoded_dim(name):
@beartype
def t5_encode_text(
texts: Union[str, List[str]],
name = DEFAULT_T5_NAME,
tokenizer,
t5,
output_device = None
):
if isinstance(texts, str):
texts = [texts]

t5, tokenizer = get_model_and_tokenizer(name)

if torch.cuda.is_available():
t5 = t5.cuda()

Expand Down
8 changes: 8 additions & 0 deletions muse_maskgit_pytorch/trainers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""
Author: Isamu Isozaki ([email protected])
Description: description
Created: 2023-02-18T19:28:19.819Z
Modified: !date!
Modified By: modifier
"""

214 changes: 214 additions & 0 deletions muse_maskgit_pytorch/trainers/base_accelerated_trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@

from pathlib import Path
from shutil import rmtree

from beartype import beartype

import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, random_split
from torchvision.utils import make_grid, save_image

from einops import rearrange

from accelerate import Accelerator, DistributedType, DistributedDataParallelKwargs

from ema_pytorch import EMA


import numpy as np
# wandb is an optional dependency: bind the name to None when it is not
# installed so later references fail with a clear AttributeError rather
# than a NameError (the original bare `except: None` left `wandb`
# undefined and silently swallowed every exception).
try:
    import wandb
except ImportError:
    wandb = None
def noop(*args, **kwargs):
    """Accept any arguments and do nothing; a stand-in callback."""
    return None
# helper functions

def identity(t, *args, **kwargs):
    """Return the first argument unchanged; extra arguments are ignored."""
    return t

def cycle(dl):
    """Yield items from `dl` forever, restarting iteration each pass.

    Unlike itertools.cycle this re-iterates the source (so a DataLoader
    reshuffles every epoch) instead of caching the first pass.
    """
    while True:
        yield from dl


def cast_tuple(t):
    """Wrap `t` in a 1-tuple unless it is already a tuple or list."""
    if isinstance(t, (tuple, list)):
        return t
    return (t,)


def yes_or_no(question):
    """Ask `question` interactively; True iff the reply is 'y'/'yes' (any case)."""
    reply = input(f"{question} (y/n) ")
    return reply.lower() in ("yes", "y")


def pair(val):
    """Return `val` unchanged if it is a tuple, else duplicate it into a pair."""
    if isinstance(val, tuple):
        return val
    return (val, val)


def convert_image_to_fn(img_type, image):
    """Convert `image` to mode `img_type`, returning it untouched if already there."""
    return image.convert(img_type) if image.mode != img_type else image


# image related helpers fnuctions and dataset


def get_accelerator(**accelerate_kwargs):
    """Build an Accelerator with DDP `find_unused_parameters=True`.

    A DistributedDataParallelKwargs handler is appended to any handlers the
    caller supplied. The handler list is copied first so a list object owned
    by the caller is never mutated in place (the original appended directly
    to it).

    Args:
        **accelerate_kwargs: forwarded verbatim to `Accelerator`.

    Returns:
        A configured `accelerate.Accelerator`.
    """
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

    # Copy before appending so the caller's list is left untouched.
    kwargs_handlers = list(accelerate_kwargs.get("kwargs_handlers", []))
    kwargs_handlers.append(ddp_kwargs)
    accelerate_kwargs.update(kwargs_handlers=kwargs_handlers)

    accelerator = Accelerator(**accelerate_kwargs)
    return accelerator
def split_dataset(dataset, valid_frac, accelerator, seed=42):
    """Split `dataset` into train/validation subsets.

    Fixes a NameError in the original, whose body referenced an undefined
    name `ds` instead of the `dataset` parameter in both branches.

    Args:
        dataset: map-style dataset (supports `len()` and indexing).
        valid_frac: fraction in [0, 1) held out for validation; when 0 the
            full dataset is shared between train and valid.
        accelerator: used only for rank-aware printing.
        seed: RNG seed so the split is reproducible across processes.

    Returns:
        `(train_ds, valid_ds)` tuple.
    """
    if valid_frac > 0:
        train_size = int((1 - valid_frac) * len(dataset))
        valid_size = len(dataset) - train_size
        dataset, valid_ds = random_split(
            dataset,
            [train_size, valid_size],
            generator=torch.Generator().manual_seed(seed),
        )
        accelerator.print(
            f"training with dataset of {len(dataset)} samples and validating with randomly splitted {len(valid_ds)} samples"
        )
    else:
        valid_ds = dataset
        accelerator.print(
            f"training with shared training and valid dataset of {len(dataset)} samples"
        )
    return dataset, valid_ds

# main trainer class

@beartype
class BaseAcceleratedTrainer(nn.Module):
    """Accelerate-backed training-loop scaffold.

    Subclasses must assign `self.model` and `self.optim` and implement
    `train_step`; this base handles accelerator setup, results/checkpoint
    directories, dataloader cycling, checkpoint save/load, and tracker
    logging.
    """

    def __init__(
        self,
        dataloader,
        valid_dataloader,
        *,
        current_step,
        num_train_steps,
        batch_size,
        max_grad_norm=None,
        save_results_every=100,
        save_model_every=1000,
        results_dir="./results",
        logging_dir="./results/logs",
        apply_grad_penalty_every=4,
        **accelerate_kwargs,
    ):
        super().__init__()
        self.model = None
        # instantiate accelerator
        # FIX: `accelerate_kwargs` is a dict, so the original attribute
        # access (`accelerate_kwargs.gradient_accumulation_steps`) always
        # raised AttributeError. Read it with .get() — it stays in the
        # kwargs for Accelerator to consume as well.
        self.gradient_accumulation_steps = accelerate_kwargs.get(
            "gradient_accumulation_steps", 1
        )
        self.accelerator = get_accelerator(**accelerate_kwargs)
        self.results_dir = Path(results_dir)
        # Offer to wipe a non-empty results dir from a previous run.
        if len([*self.results_dir.glob("**/*")]) > 0 and yes_or_no(
            "do you want to clear previous experiment checkpoints and results?"
        ):
            rmtree(str(self.results_dir))
        self.results_dir.mkdir(parents=True, exist_ok=True)

        self.logging_dir = Path(logging_dir)
        self.logging_dir.mkdir(parents=True, exist_ok=True)

        # training params

        # Buffer so the step counter travels with the module's state_dict.
        self.register_buffer("steps", torch.Tensor([current_step]))
        self.num_train_steps = num_train_steps
        self.batch_size = batch_size
        self.max_grad_norm = max_grad_norm

        self.dl = dataloader
        self.valid_dl = valid_dataloader
        self.dl_iter = cycle(self.dl)
        self.valid_dl_iter = cycle(self.valid_dl)

        self.save_model_every = save_model_every
        self.save_results_every = save_results_every

        self.apply_grad_penalty_every = apply_grad_penalty_every

    def save(self, path):
        """Save model and optimizer state to `path` (local main process only)."""
        # FIX: the original checked `self.is_local_main_process`, which
        # does not exist on this class; the property is `is_local_main`.
        if not self.is_local_main:
            return

        pkg = dict(
            model=self.get_state_dict(self.model),
            optim=self.optim.state_dict(),
        )
        torch.save(pkg, path)

    def load(self, path):
        """Load a checkpoint written by `save`; returns the raw package dict."""
        path = Path(path)
        assert path.exists()
        pkg = torch.load(path)

        model = self.accelerator.unwrap_model(self.model)
        model.load_state_dict(pkg["model"])

        self.optim.load_state_dict(pkg["optim"])
        return pkg

    def log_validation_images(self, images, step, prompt=None):
        """Log validation images to any active tensorboard/wandb trackers."""
        for tracker in self.accelerator.trackers:
            if tracker.name == "tensorboard":
                np_images = np.stack([np.asarray(img) for img in images])
                tracker.writer.add_images("validation", np_images, step, dataformats="NHWC")
            if tracker.name == "wandb":
                tracker.log(
                    {
                        "validation": [
                            # FIX: original conditional precedence dropped
                            # the prompt exactly when one was given;
                            # include it in the caption when present.
                            wandb.Image(image, caption=f"{i}: {prompt}" if prompt else f"{i}")
                            for i, image in enumerate(images)
                        ]
                    }
                )

    def print(self, msg):
        """Rank-aware print via the accelerator."""
        self.accelerator.print(msg)

    def log(self, log_dict):
        """Forward metrics to the accelerator's trackers."""
        self.accelerator.log(log_dict)

    def prepare(self, *args):
        """Wrap models/optimizers/dataloaders for the current distributed setup."""
        return self.accelerator.prepare(*args)

    def get_state_dict(self, model):
        return self.accelerator.get_state_dict(model)

    def unwrap_model(self, model):
        return self.accelerator.unwrap_model(model)

    @property
    def device(self):
        return self.accelerator.device

    @property
    def is_distributed(self):
        # True for any multi-process or non-trivial distributed setup.
        return not (
            self.accelerator.distributed_type == DistributedType.NO
            and self.accelerator.num_processes == 1
        )

    @property
    def is_main(self):
        return self.accelerator.is_main_process

    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process

    def train_step(self):
        raise NotImplementedError("You are calling train_step on the base trainer with no models")

    def train(self, log_fn=noop):
        """Run `train_step` until `num_train_steps`, passing each step's logs to `log_fn`.

        Subclasses' `train_step` is expected to advance `self.steps`;
        otherwise this loop will not terminate.
        """
        self.model.train()
        while self.steps < self.num_train_steps:
            with self.accelerator.autocast():
                logs = self.train_step()
            log_fn(logs)
        # FIX: the original called `self.writer.close()`, but no `writer`
        # attribute is ever set on this class; end_training() flushes and
        # closes the accelerator's trackers instead.
        self.accelerator.end_training()
        self.print("training complete")

Loading

0 comments on commit e3ad84e

Please sign in to comment.