Skip to content

Commit

Permalink
Small QOL when using the validation_folder_at_end_of_epoch argument, …
Browse files Browse the repository at this point in the history
…we now append the epoch number to the filename so we can identify at what epoch the image was saved. (Sygil-Dev#82)

- Added experimental t5 offloading to CPU.
  • Loading branch information
ZeroCool940711 authored Dec 15, 2023
2 parents a1ca68d + 4a148a6 commit 5c7b272
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 2 deletions.
5 changes: 4 additions & 1 deletion muse_maskgit_pytorch/trainers/maskgit_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(
clear_previous_experiments=False,
validation_image_scale: float = 1.0,
only_save_last_checkpoint=False,
t5_offloading=False,
args=None,
):
super().__init__(
Expand Down Expand Up @@ -88,6 +89,8 @@ def __init__(
# we are going to use them later to save them to a config file.
self.args = args

self.t5_offloading = t5_offloading

# maskgit
maskgit.vae.requires_grad_(False)
maskgit.transformer.t5.requires_grad_(False)
Expand Down Expand Up @@ -161,7 +164,7 @@ def train(self):
if not text_embeds:
with torch.no_grad():
text_embeds = t5_encode_text_from_encoded(
input_ids, attn_mask, self.model.transformer.t5, self.accelerator.device
input_ids, attn_mask, self.model.transformer.t5, self.accelerator.device if not self.t5_offloading else 'cpu'
)

with self.accelerator.accumulate(self.model), self.accelerator.autocast():
Expand Down
2 changes: 1 addition & 1 deletion muse_maskgit_pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def vae_folder_validation(accelerator, vae, dataset, args=None, checkpoint_name=
now = datetime.now().strftime("%m-%d-%Y_%H-%M-%S")
hash = hashlib.sha1(input_image.tobytes()).hexdigest()

filename = f"{hash}_{now}{'-' + epoch if epoch else ''}-{os.path.basename(checkpoint_name)}.png"
filename = f"{str(hash)}_{str(now)}{'-'}{'E' + str(epoch) if epoch is not None else ''}-{str(os.path.basename(checkpoint_name))}.png"
grid_image.save(f"{output_dir}/{filename}", format="PNG")

if not save_originals:
Expand Down
8 changes: 8 additions & 0 deletions train_muse_maskgit.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,12 @@ def decompress_pickle(file):
default="",
help="The path to save or load embeds",
)
parser.add_argument(
"--t5_offloading",
action="store_true",
default=False,
help="Wheter to offload the t5 model to cpu instad of loading it on GPU. Should help with loading bigger models but will affect performance.",
)


@dataclass
Expand Down Expand Up @@ -529,6 +535,7 @@ class Arguments:
attention_type: str = "flash"
precompute: bool = False
precompute_path: str = ""
t5_offloading = False


def main():
Expand Down Expand Up @@ -1016,6 +1023,7 @@ def main():
validation_image_scale=args.validation_image_scale,
only_save_last_checkpoint=args.only_save_last_checkpoint,
num_epochs=args.num_epochs,
t5_offloading=args.t5_offloading,
args=args,
)

Expand Down

0 comments on commit 5c7b272

Please sign in to comment.