Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/dev' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
ZeroCool940711 committed Jun 21, 2023
2 parents 2b4f9dc + 3c3f12b commit 06b4e9b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
3 changes: 2 additions & 1 deletion muse_maskgit_pytorch/t5.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from dataclasses import dataclass, field
from functools import cached_property
from os import PathLike
Expand All @@ -7,11 +8,11 @@
from beartype import beartype
from torch import Tensor
from transformers import T5Config, T5EncoderModel, T5Tokenizer
import warnings

# disable t5 warnings and a few others to keep the console clean and nice.
warnings.filterwarnings("ignore")


# dataclass for T5 model info
@dataclass
class T5ModelInfo:
Expand Down
12 changes: 9 additions & 3 deletions train_muse_maskgit.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,9 @@ def main():
if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(
latest_checkpoint_file, os.R_OK
):
print(f"Warning: latest VAE checkpoint {latest_checkpoint_file} is empty or unreadable.")
print(
f"Warning: latest VAE checkpoint {latest_checkpoint_file} is empty or unreadable."
)
if len(checkpoint_files) > 1:
# Use the second last checkpoint as a fallback
latest_checkpoint_file = max(
Expand Down Expand Up @@ -744,14 +746,18 @@ def main():
checkpoint_files[:-1],
key=lambda x: int(re.search(r"maskgit\.(\d+)\.pt", x).group(1)),
)
accelerator.print("Using second last MaskGit checkpoint: ", latest_checkpoint_file)
accelerator.print(
"Using second last MaskGit checkpoint: ", latest_checkpoint_file
)
else:
accelerator.print("No usable MaskGit checkpoint found.")
load = False
elif latest_checkpoint_file != orig_vae_path:
accelerator.print("Resuming MaskGit from latest checkpoint: ", latest_checkpoint_file)
else:
accelerator.print("Using MaskGit checkpoint specified in resume_path: ", orig_vae_path)
accelerator.print(
"Using MaskGit checkpoint specified in resume_path: ", orig_vae_path
)

args.resume_path = latest_checkpoint_file
else:
Expand Down

0 comments on commit 06b4e9b

Please sign in to comment.