support multiple datasets at once
multiple dataset names can be separated with '|', the same way the validation image prompts are split

the text and image columns **MUST** match across all of the datasets
korakoe committed Oct 2, 2023
1 parent 80f8a10 commit 8f5f1e3
Showing 2 changed files with 92 additions and 32 deletions.
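A minimal sketch of the loading pattern this commit introduces (not the committed code verbatim): each name in a '|'-separated `--dataset_name` value is loaded, reduced to the shared caption and image columns, and the pieces are concatenated. The dataset IDs and column names below are placeholders.

```python
from datasets import concatenate_datasets, load_dataset

# Placeholder values; in the scripts these come from --dataset_name,
# --caption_column and --image_column.
dataset_name = "user/dataset_a|user/dataset_b"
caption_column, image_column = "text", "image"

loaded_datasets = []
for name in dataset_name.split("|"):
    ds = load_dataset(name, split="train")
    # Keep only the columns every dataset is expected to share.
    ds = ds.remove_columns(
        [col for col in ds.column_names if col not in (caption_column, image_column)]
    )
    loaded_datasets.append(ds)

# Fails with a ValueError if the remaining features cannot be aligned.
dataset = concatenate_datasets(loaded_datasets)
```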
60 changes: 46 additions & 14 deletions train_muse_maskgit.py
@@ -12,7 +12,7 @@
import torch
import transformers
from accelerate.utils import ProjectConfiguration
from datasets import load_dataset
from datasets import concatenate_datasets, load_dataset
from diffusers.optimization import SchedulerType, get_scheduler
from omegaconf import OmegaConf
from rich import inspect
@@ -237,7 +237,8 @@ def decompress_pickle(file):
"--dataset_name",
type=str,
default=None,
help="ID of HuggingFace dataset to use (cannot be used with --train_data_dir)",
help="ID of HuggingFace dataset to use (cannot be used with --train_data_dir, use multiple by splitting with '|', "
"they must have the same image column and text column)",
)
parser.add_argument(
"--hf_split_name",
@@ -605,18 +606,49 @@ def main():
save_path=args.dataset_save_path,
)
elif args.dataset_name is not None:
dataset = load_dataset(
args.dataset_name,
streaming=args.streaming,
cache_dir=args.cache_path,
save_infos=True,
split="train",
)
if args.streaming:
if args.cache_path:
dataset = load_dataset(args.dataset_name, cache_dir=args.cache_path)[args.hf_split_name]
else:
dataset = load_dataset(args.dataset_name)[args.hf_split_name]
if "|" in args.dataset_name:
loaded_datasets = []
for name in args.dataset_name.split("|"):
accelerator.print(f"Loading {name}")
data_to_add = load_dataset(
name,
streaming=args.streaming,
cache_dir=args.cache_path,
save_infos=True,
split="train",
)

# remove_columns returns a new dataset, so reassign the result;
# keep only the caption and image columns that every dataset must share
data_to_add = data_to_add.remove_columns(
    [
        col
        for col in data_to_add.column_names
        if col != args.caption_column and col != args.image_column
    ]
)

loaded_datasets.append(data_to_add)

try:
    dataset = concatenate_datasets(loaded_datasets)
except ValueError as err:
    raise ValueError(
        "Failed to concatenate datasets... Make sure they use the same columns!"
    ) from err

else:
dataset = load_dataset(
args.dataset_name,
streaming=args.streaming,
cache_dir=args.cache_path,
save_infos=True,
split="train",
)

if args.streaming:
if args.cache_path:
dataset = load_dataset(args.dataset_name, cache_dir=args.cache_path)[
args.hf_split_name
]
else:
dataset = load_dataset(args.dataset_name)[args.hf_split_name]
else:
raise ValueError("You must pass either train_data_dir or dataset_name (but not both)")

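As the updated `--dataset_name` help text says, all datasets must expose the same image and text columns. A small illustrative example with throwaway in-memory datasets (not taken from the repository), showing the concatenation step in isolation:

```python
from datasets import Dataset, concatenate_datasets

# Two toy datasets with identical "image" and "text" columns.
a = Dataset.from_dict({"image": ["a1.png", "a2.png"], "text": ["cat", "dog"]})
b = Dataset.from_dict({"image": ["b1.png"], "text": ["bird"]})

combined = concatenate_datasets([a, b])
print(combined.column_names, len(combined))  # ['image', 'text'] 3
```

When the features of the pieces cannot be aligned (for example, the same column name holding different types), concatenate_datasets raises a ValueError, which both scripts catch and re-raise with a reminder to check the columns.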
64 changes: 46 additions & 18 deletions train_muse_vae.py
@@ -3,11 +3,11 @@
from dataclasses import dataclass
from typing import Optional, Union

import wandb
from accelerate.utils import ProjectConfiguration
from datasets import load_dataset, Dataset, Image
from datasets import Dataset, Image, concatenate_datasets, load_dataset
from omegaconf import OmegaConf

import wandb
from muse_maskgit_pytorch import (
VQGanVAE,
VQGanVAETaming,
@@ -163,7 +163,8 @@
"--dataset_name",
type=str,
default=None,
help="Name of the huggingface dataset used.",
help="ID of HuggingFace dataset to use (cannot be used with --train_data_dir, use multiple by splitting with '|', "
"they must have the same image column and text column)",
)
parser.add_argument(
"--hf_split_name",
@@ -409,7 +410,7 @@ def main():
args = parser.parse_args(namespace=Arguments())

if args.config_path:
accelerator.print("Using config file and ignoring CLI args")
print("Using config file and ignoring CLI args")

try:
conf = OmegaConf.load(args.config_path)
@@ -420,10 +421,10 @@
try:
args_to_convert[key] = conf[key]
except KeyError:
accelerator.print(f"Error parsing config - {key}: {conf[key]} | Using default or parsed")
print(f"Error parsing config - {key}: {conf[key]} | Using default or parsed")

except FileNotFoundError:
accelerator.print("Could not find config, using default and parsed values...")
print("Could not find config, using default and parsed values...")

project_config = ProjectConfiguration(
project_dir=args.logging_dir if args.logging_dir else os.path.join(args.results_dir, "logs"),
@@ -464,18 +465,43 @@ def main():
save=not args.no_cache,
)
elif args.dataset_name:
if args.cache_path:
dataset = load_dataset(args.dataset_name, streaming=args.streaming, cache_dir=args.cache_path)[
"train"
]
if "|" in args.dataset_name:
loaded_datasets = []
for name in args.dataset_name.split("|"):
accelerator.print(f"Loading {name}")
data_to_add = load_dataset(
name,
streaming=args.streaming,
cache_dir=args.cache_path,
save_infos=True,
split="train",
)

# remove_columns returns a new dataset, so reassign the result;
# keep only the caption and image columns that every dataset must share
data_to_add = data_to_add.remove_columns(
    [
        col
        for col in data_to_add.column_names
        if col != args.caption_column and col != args.image_column
    ]
)

loaded_datasets.append(data_to_add)

try:
    dataset = concatenate_datasets(loaded_datasets)
except ValueError as err:
    raise ValueError(
        "Failed to concatenate datasets... Make sure they use the same columns!"
    ) from err

else:
dataset = load_dataset(args.dataset_name, streaming=args.streaming, cache_dir=args.cache_path)[
"train"
]
if args.streaming:
if dataset.info.dataset_size is None:
accelerator.print("Dataset doesn't support streaming, disabling streaming")
args.streaming = False
dataset = load_dataset(
args.dataset_name,
streaming=args.streaming,
cache_dir=args.cache_path,
save_infos=True,
split="train",
)

if args.streaming:
if args.cache_path:
dataset = load_dataset(args.dataset_name, cache_dir=args.cache_path)[args.hf_split_name]
else:
@@ -610,7 +636,9 @@ def main():
filepaths.append(os.path.join(root, file))

if not filepaths:
print(f"No images with extensions {extensions} found in {args.validation_folder_at_end_of_epoch}.")
print(
f"No images with extensions {extensions} found in {args.validation_folder_at_end_of_epoch}."
)
exit(1)

epoch_validation_dataset = Dataset.from_dict({"image": filepaths}).cast_column("image", Image())
