Skip to content

Commit

Permalink
Added experimental t5 offloading to CPU.
Browse files Browse the repository at this point in the history
  • Loading branch information
ZeroCool940711 committed Dec 15, 2023
1 parent 6c7d258 commit bbcf752
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
5 changes: 4 additions & 1 deletion muse_maskgit_pytorch/trainers/maskgit_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(
clear_previous_experiments=False,
validation_image_scale: float = 1.0,
only_save_last_checkpoint=False,
t5_offloading=False,
args=None,
):
super().__init__(
Expand Down Expand Up @@ -88,6 +89,8 @@ def __init__(
# we are going to use them later to save them to a config file.
self.args = args

self.t5_offloading = t5_offloading

# maskgit
maskgit.vae.requires_grad_(False)
maskgit.transformer.t5.requires_grad_(False)
Expand Down Expand Up @@ -161,7 +164,7 @@ def train(self):
if not text_embeds:
with torch.no_grad():
text_embeds = t5_encode_text_from_encoded(
input_ids, attn_mask, self.model.transformer.t5, self.accelerator.device
input_ids, attn_mask, self.model.transformer.t5, self.accelerator.device if not self.t5_offloading else 'cpu'
)

with self.accelerator.accumulate(self.model), self.accelerator.autocast():
Expand Down
8 changes: 8 additions & 0 deletions train_muse_maskgit.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,12 @@ def decompress_pickle(file):
default="",
help="The path to save or load embeds",
)
parser.add_argument(
"--t5_offloading",
action="store_true",
default=False,
help="Wheter to offload the t5 model to cpu instad of loading it on GPU. Should help with loading bigger models but will affect performance.",
)


@dataclass
Expand Down Expand Up @@ -528,6 +534,7 @@ class Arguments:
attention_type: str = "flash"
precompute: bool = False
precompute_path: str = ""
t5_offloading = False


def main():
Expand Down Expand Up @@ -984,6 +991,7 @@ def main():
validation_image_scale=args.validation_image_scale,
only_save_last_checkpoint=args.only_save_last_checkpoint,
num_epochs=args.num_epochs,
t5_offloading=args.t5_offloading,
args=args,
)

Expand Down

0 comments on commit bbcf752

Please sign in to comment.