Skip to content

Commit

Permalink
Dataset Streaming, remove duplicate args and spelling
Browse files Browse the repository at this point in the history
  • Loading branch information
korakoe committed Mar 26, 2023
1 parent f955fc9 commit 6892924
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 9 deletions.
4 changes: 2 additions & 2 deletions muse_maskgit_pytorch/trainers/base_accelerated_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ema_pytorch import EMA
from torch.optim import Adam, AdamW
from lion_pytorch import Lion
from torch_optimizer import Adafactor

import numpy as np

Expand Down Expand Up @@ -116,11 +117,10 @@ def get_optimizer(use_8bit_adam, optimizer, parameters, lr, weight_decay):
optim = bnb.optim.AdamW8bit(parameters, lr=lr, weight_decay=weight_decay)
else:
optim = AdamW(parameters, lr=lr, weight_decay=weight_decay)

elif optimizer == "Lion":
optim = Lion(parameters, lr=lr, weight_decay=weight_decay)
if use_8bit_adam:
print("8bit is not supported by the Lion optimiser, Using standard Lion instead.")
print("8bit is not supported by the Lion optimizer, Using standard Lion instead.")
else:
raise NotImplementedError(f"{optimizer} optimizer not supported yet.")
return optim
Expand Down
7 changes: 6 additions & 1 deletion train_muse_maskgit.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,11 @@ def parse_args():
default=None,
help="Name of the huggingface dataset used.",
)
parser.add_argument(
"--streaming",
action="store_true",
help="Whether to stream the huggingface dataset",
)
parser.add_argument(
"--train_data_dir",
type=str,
Expand Down Expand Up @@ -285,7 +290,7 @@ def main():
save_path=args.dataset_save_path,
)
elif args.dataset_name:
dataset = load_dataset(args.dataset_name)["train"]
dataset = load_dataset(args.dataset_name, streaming=args.streaming)["train"]
if args.vae_path and args.taming_model_path:
raise Exception("You can't pass vae_path and taming args at the same time.")

Expand Down
12 changes: 6 additions & 6 deletions train_muse_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ def parse_args():
default=None,
help="Name of the huggingface dataset used.",
)
parser.add_argument(
"--streaming",
action="store_true",
help="Whether to stream the huggingface dataset",
)
parser.add_argument(
"--train_data_dir",
type=str,
Expand Down Expand Up @@ -200,11 +205,6 @@ def parse_args():
default=None,
help="Path to the last saved checkpoint. 'results/vae.steps.pt'",
)
parser.add_argument(
"--optimizer",type=str,
default='Lion',
help="Optimizer to use. Choose between: ['Adam', 'AdamW','Lion']. Default: Adam",
)
parser.add_argument(
"--weight_decay", type=float,
default=0.0,
Expand Down Expand Up @@ -267,7 +267,7 @@ def main():
save_path=args.dataset_save_path,
)
elif args.dataset_name:
dataset = load_dataset(args.dataset_name)["train"]
dataset = load_dataset(args.dataset_name, streaming=args.streaming)["train"]

vae = VQGanVAE(dim=args.dim, vq_codebook_size=args.vq_codebook_size)
if args.taming_model_path:
Expand Down

0 comments on commit 6892924

Please sign in to comment.