diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py
index b8332cb..4d11423 100644
--- a/train_muse_maskgit.py
+++ b/train_muse_maskgit.py
@@ -12,7 +12,7 @@
 import torch
 import transformers
 from accelerate.utils import ProjectConfiguration
-from datasets import load_dataset
+from datasets import concatenate_datasets, load_dataset
 from diffusers.optimization import SchedulerType, get_scheduler
 from omegaconf import OmegaConf
 from rich import inspect
@@ -237,7 +237,8 @@ def decompress_pickle(file):
     "--dataset_name",
     type=str,
     default=None,
-    help="ID of HuggingFace dataset to use (cannot be used with --train_data_dir)",
+    help="ID of HuggingFace dataset to use (cannot be used with --train_data_dir). Load multiple datasets by "
+    "separating their IDs with '|'; they must share the same image and text columns.",
 )
 parser.add_argument(
     "--hf_split_name",
@@ -605,18 +606,49 @@ def main():
             save_path=args.dataset_save_path,
         )
     elif args.dataset_name is not None:
-        dataset = load_dataset(
-            args.dataset_name,
-            streaming=args.streaming,
-            cache_dir=args.cache_path,
-            save_infos=True,
-            split="train",
-        )
-        if args.streaming:
-            if args.cache_path:
-                dataset = load_dataset(args.dataset_name, cache_dir=args.cache_path)[args.hf_split_name]
-            else:
-                dataset = load_dataset(args.dataset_name)[args.hf_split_name]
+        if "|" in args.dataset_name:
+            loaded_datasets = []
+            for name in args.dataset_name.split("|"):
+                accelerator.print(f"Loading {name}")
+                data_to_add = load_dataset(
+                    name,
+                    streaming=args.streaming,
+                    cache_dir=args.cache_path,
+                    save_infos=True,
+                    split="train",
+                )
+
+                # remove_columns returns a new dataset, so reassign the result;
+                # keep only the shared image and caption columns.
+                data_to_add = data_to_add.remove_columns(
+                    [
+                        col
+                        for col in data_to_add.column_names
+                        if col not in (args.image_column, args.caption_column)
+                    ]
+                )
+
+                loaded_datasets.append(data_to_add)
+
+            try:
+                dataset = concatenate_datasets(loaded_datasets)
+            except ValueError as err:
+                raise ValueError(
+                    "Failed concatenating datasets. Make sure they use the same columns!"
+                ) from err
+
+        else:
+            dataset = load_dataset(
+                args.dataset_name,
+                streaming=args.streaming,
+                cache_dir=args.cache_path,
+                save_infos=True,
+                split="train",
+            )
+
+        if args.streaming:
+            if args.cache_path:
+                dataset = load_dataset(args.dataset_name, cache_dir=args.cache_path)[
+                    args.hf_split_name
+                ]
+            else:
+                dataset = load_dataset(args.dataset_name)[args.hf_split_name]
     else:
         raise ValueError("You must pass either train_data_dir or dataset_name (but not both)")
diff --git a/train_muse_vae.py b/train_muse_vae.py
index c336ccf..ff280b0 100644
--- a/train_muse_vae.py
+++ b/train_muse_vae.py
@@ -3,11 +3,11 @@
 from dataclasses import dataclass
 from typing import Optional, Union
 
-import wandb
 from accelerate.utils import ProjectConfiguration
-from datasets import load_dataset, Dataset, Image
+from datasets import Dataset, Image, concatenate_datasets, load_dataset
 from omegaconf import OmegaConf
+
+import wandb
 from muse_maskgit_pytorch import (
     VQGanVAE,
     VQGanVAETaming,
@@ -163,7 +163,8 @@
     "--dataset_name",
     type=str,
     default=None,
-    help="Name of the huggingface dataset used.",
+    help="ID of HuggingFace dataset to use (cannot be used with --train_data_dir). Load multiple datasets by "
+    "separating their IDs with '|'; they must share the same image and text columns.",
 )
 parser.add_argument(
     "--hf_split_name",
@@ -409,7 +410,7 @@ def main():
     args = parser.parse_args(namespace=Arguments())
 
     if args.config_path:
-        accelerator.print("Using config file and ignoring CLI args")
+        print("Using config file and ignoring CLI args")
 
         try:
             conf = OmegaConf.load(args.config_path)
@@ -420,10 +421,10 @@
                 try:
                     args_to_convert[key] = conf[key]
                 except KeyError:
-                    accelerator.print(f"Error parsing config - {key}: {conf[key]} | Using default or parsed")
+                    print(f"Error parsing config - {key}: {conf[key]} | Using default or parsed")
 
         except FileNotFoundError:
-            accelerator.print("Could not find config, using default and parsed values...")
+            print("Could not find config, using default and parsed values...")
 
     project_config = ProjectConfiguration(
         project_dir=args.logging_dir if args.logging_dir else os.path.join(args.results_dir, "logs"),
@@ -464,18 +465,43 @@ def main():
             save=not args.no_cache,
         )
     elif args.dataset_name:
-        if args.cache_path:
-            dataset = load_dataset(args.dataset_name, streaming=args.streaming, cache_dir=args.cache_path)[
-                "train"
-            ]
+        if "|" in args.dataset_name:
+            loaded_datasets = []
+            for name in args.dataset_name.split("|"):
+                accelerator.print(f"Loading {name}")
+                data_to_add = load_dataset(
+                    name,
+                    streaming=args.streaming,
+                    cache_dir=args.cache_path,
+                    save_infos=True,
+                    split="train",
+                )
+
+                # remove_columns returns a new dataset, so reassign the result;
+                # keep only the shared image and caption columns.
+                data_to_add = data_to_add.remove_columns(
+                    [
+                        col
+                        for col in data_to_add.column_names
+                        if col not in (args.image_column, args.caption_column)
+                    ]
+                )
+
+                loaded_datasets.append(data_to_add)
+
+            try:
+                dataset = concatenate_datasets(loaded_datasets)
+            except ValueError as err:
+                raise ValueError(
+                    "Failed concatenating datasets. Make sure they use the same columns!"
+                ) from err
+
         else:
-            dataset = load_dataset(args.dataset_name, streaming=args.streaming, cache_dir=args.cache_path)[
-                "train"
-            ]
-        if args.streaming:
-            if dataset.info.dataset_size is None:
-                accelerator.print("Dataset doesn't support streaming, disabling streaming")
-                args.streaming = False
+            dataset = load_dataset(
+                args.dataset_name,
+                streaming=args.streaming,
+                cache_dir=args.cache_path,
+                save_infos=True,
+                split="train",
+            )
+
+        if args.streaming:
             if args.cache_path:
                 dataset = load_dataset(args.dataset_name, cache_dir=args.cache_path)[args.hf_split_name]
             else:
                 dataset = load_dataset(args.dataset_name)[args.hf_split_name]
@@ -610,7 +636,9 @@ def main():
                     filepaths.append(os.path.join(root, file))
 
         if not filepaths:
-            print(f"No images with extensions {extensions} found in {args.validation_folder_at_end_of_epoch}.")
+            print(
+                f"No images with extensions {extensions} found in {args.validation_folder_at_end_of_epoch}."
+            )
             exit(1)
 
         epoch_validation_dataset = Dataset.from_dict({"image": filepaths}).cast_column("image", Image())
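
For reference, the "|" handling added to both scripts reduces to the standalone sketch below. The dataset IDs and column names are hypothetical placeholders, and the streaming/cache options from the patch are omitted:

    from datasets import concatenate_datasets, load_dataset

    dataset_name = "user/dataset_a|user/dataset_b"  # hypothetical IDs
    image_column, caption_column = "image", "caption"  # hypothetical column names

    loaded_datasets = []
    for name in dataset_name.split("|"):
        ds = load_dataset(name, split="train")
        # remove_columns returns a new Dataset; keep only the shared columns
        ds = ds.remove_columns(
            [col for col in ds.column_names if col not in (image_column, caption_column)]
        )
        loaded_datasets.append(ds)

    # concatenate_datasets raises ValueError if the remaining features don't match
    dataset = concatenate_datasets(loaded_datasets)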