diff --git a/setup.py b/setup.py index f6b71a5..7619c11 100644 --- a/setup.py +++ b/setup.py @@ -49,6 +49,7 @@ "omegaconf", "xformers>=0.0.20", "wandb", + "bz2file", ], classifiers=[ "Development Status :: 4 - Beta", diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py index 30118d8..b8332cb 100644 --- a/train_muse_maskgit.py +++ b/train_muse_maskgit.py @@ -1,10 +1,12 @@ import argparse import logging import os +import pickle from dataclasses import dataclass from typing import Optional, Union import accelerate +import bz2file as bz2 import datasets import diffusers import torch @@ -48,6 +50,18 @@ ) from muse_maskgit_pytorch.trainers.base_accelerated_trainer import get_optimizer + +def compressed_pickle(title, data): + with bz2.BZ2File(title, "w") as f: + pickle.dump(data, f) + + +def decompress_pickle(file): + data = bz2.BZ2File(file, "rb") + data = pickle.load(data) + return data + + # remove some unnecessary errors from transformer shown on the console. transformers.logging.set_verbosity_error() @@ -431,7 +445,13 @@ "--precompute", action="store_true", default=False, - help="whether to precompute text embeds", + help="whether to precompute text embeds (only use if we wan to compute, not load)", +) +parser.add_argument( + "--precompute_path", + type=str, + default="", + help="The path to save or load embeds", ) @@ -507,6 +527,7 @@ class Arguments: config_path: Optional[str] = None attention_type: str = "flash" precompute: bool = False + precompute_path: str = "" def main(): @@ -757,6 +778,11 @@ def main(): accelerator.print(f"Total number of parameters: {format(total_params, ',d')}") + if args.precompute_path and not args.precompute: + embeds = decompress_pickle(args.precompute_path) + else: + embeds = [] + # Create the dataset objects with accelerator.main_process_first(): if args.no_cache and args.train_data_dir: @@ -769,6 +795,7 @@ def main(): using_taming=False if not args.taming_model_path else True, random_crop=args.random_crop if args.random_crop else False, alpha_channel=False if args.channels == 3 else True, + embeds=embeds, ) elif args.link: if not args.dataset_name: @@ -783,6 +810,7 @@ def main(): center_crop=False if args.no_center_crop else True, flip=False if args.no_flip else True, using_taming=False if not args.taming_model_path else True, + embeds=embeds, ) else: dataset = ImageTextDataset( @@ -795,6 +823,7 @@ def main(): flip=False if args.no_flip else True, stream=args.streaming, using_taming=False if not args.taming_model_path else True, + embeds=embeds, ) # Create the dataloaders @@ -871,6 +900,9 @@ def main(): embedding = t5_encode_text_from_encoded(input_ids, attn_mask, maskgit.transformer.t5, "cpu") embeds.append(embedding) + if args.precompute_path: + compressed_pickle(args.precompute_path, embeds) + with accelerator.main_process_first(): if args.no_cache and args.train_data_dir: dataset = LocalTextImageDataset(