Skip to content

Commit

Permalink
Added lr schedulers
Browse files Browse the repository at this point in the history
  • Loading branch information
isamu-isozaki committed Feb 27, 2023
1 parent 7ebb04f commit 7739c45
Show file tree
Hide file tree
Showing 12 changed files with 1,001 additions and 597 deletions.
14 changes: 12 additions & 2 deletions muse_maskgit_pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
# Public re-exports for the muse_maskgit_pytorch package.
# NOTE(review): the scraped diff showed each import twice (pre- and post-commit
# form); this is the deduplicated post-commit version.
from muse_maskgit_pytorch.vqgan_vae import VQGanVAE
from muse_maskgit_pytorch.muse_maskgit_pytorch import (
    Transformer,
    MaskGit,
    Muse,
    MaskGitTransformer,
    TokenCritic,
)

from muse_maskgit_pytorch.trainers import (
    VQGanVAETrainer,
    MaskGitTrainer,
    get_accelerator,
)
46 changes: 31 additions & 15 deletions muse_maskgit_pytorch/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
from torch.utils.data import Dataset, DataLoader, random_split
import os
from tqdm import tqdm

ImageFile.LOAD_TRUNCATED_IMAGES = True


class ImageDataset(Dataset):
def __init__(self, dataset, image_size, image_column="image"):
super().__init__()
Expand All @@ -31,11 +33,19 @@ def __len__(self):
return len(self.dataset)

def __getitem__(self, index):
image= self.dataset[index][self.image_column]
image = self.dataset[index][self.image_column]
return self.transform(image)


class ImageTextDataset(ImageDataset):
def __init__(self, dataset, image_size, tokenizer, image_column="image", caption_column="caption"):
def __init__(
self,
dataset,
image_size,
tokenizer,
image_column="image",
caption_column="caption",
):
super().__init__(dataset, image_size=image_size, image_column=image_column)
self.caption_column = caption_column
self.tokenizer = tokenizer
Expand Down Expand Up @@ -65,7 +75,10 @@ def __getitem__(self, index):
attn_mask = encoded.attention_mask
return self.transform(image), input_ids[0], attn_mask[0]

def get_dataset_from_dataroot(data_root, image_column="image", caption_column="caption", save_path="dataset"):

def get_dataset_from_dataroot(
data_root, image_column="image", caption_column="caption", save_path="dataset"
):
if os.path.exists(save_path):
return load_from_disk(save_path)
image_paths = list(Path(data_root).rglob("*.[jJ][pP][gG]"))
Expand All @@ -74,7 +87,7 @@ def get_dataset_from_dataroot(data_root, image_column="image", caption_column="c
for image_path in tqdm(image_paths):
caption_path = image_path.with_suffix(".txt")
if os.path.exists(str(caption_path)):
captions = caption_path.read_text(encoding="utf-8").split('\n')
captions = caption_path.read_text(encoding="utf-8").split("\n")
captions = list(filter(lambda t: len(t) > 0, captions))
else:
captions = []
Expand All @@ -86,24 +99,27 @@ def get_dataset_from_dataroot(data_root, image_column="image", caption_column="c
dataset.save_to_disk(save_path)
return dataset


def split_dataset_into_dataloaders(dataset, valid_frac=0.05, seed=42, batch_size=1):
    """Split *dataset* into train/validation DataLoaders.

    Args:
        dataset: any torch-style dataset (supports ``len`` and indexing).
        valid_frac: fraction of samples held out for validation; when 0 (or
            negative) the SAME dataset object is used for both loaders.
        seed: seed for the random split, so the partition is reproducible.
        batch_size: batch size for both loaders.

    Returns:
        (dataloader, validation_dataloader) tuple. Both loaders shuffle.
    """
    if valid_frac > 0:
        train_size = int((1 - valid_frac) * len(dataset))
        valid_size = len(dataset) - train_size
        # Seeded generator keeps the train/valid partition stable across runs.
        dataset, validation_dataset = random_split(
            dataset,
            [train_size, valid_size],
            generator=torch.Generator().manual_seed(seed),
        )
        print(
            f"training with dataset of {len(dataset)} samples and validating with randomly splitted {len(validation_dataset)} samples"
        )
    else:
        # No held-out data: validate on the full training set.
        validation_dataset = dataset
        print(
            f"training with shared training and valid dataset of {len(dataset)} samples"
        )
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    validation_dataloader = DataLoader(
        validation_dataset, batch_size=batch_size, shuffle=True
    )
    return dataloader, validation_dataloader
Loading

0 comments on commit 7739c45

Please sign in to comment.