From f2bd0de684bf9a872ddd1f9aa71838642dcbd57d Mon Sep 17 00:00:00 2001
From: Alejandro Gil
Date: Sun, 11 Jun 2023 20:33:12 -0700
Subject: [PATCH 1/2] Added some extra text for clarification when trying to
 find the latest MaskGit checkpoint in a folder.

---
 .../trainers/maskgit_trainer.py | 28 +++++++++----------
 train_muse_maskgit.py           | 10 +++----
 2 files changed, 19 insertions(+), 19 deletions(-)

diff --git a/muse_maskgit_pytorch/trainers/maskgit_trainer.py b/muse_maskgit_pytorch/trainers/maskgit_trainer.py
index d985233..9219b16 100644
--- a/muse_maskgit_pytorch/trainers/maskgit_trainer.py
+++ b/muse_maskgit_pytorch/trainers/maskgit_trainer.py
@@ -129,7 +129,7 @@ def save_validation_images(
         save_dir = self.results_dir.joinpath("MaskGit")
         save_dir.mkdir(exist_ok=True, parents=True)
 
-        save_file = save_dir.joinpath(f"maskgit_{step:04d}.png")
+        save_file = save_dir.joinpath(f"maskgit_{step}.png")
 
         if self.accelerator.is_main_process:
             save_image(images, save_file, "png")
@@ -141,9 +141,9 @@ def train(self):
         self.model.train()
 
         if self.accelerator.is_main_process:
-            proc_label = f"[P{self.accelerator.process_index:03d}][Master]"
+            proc_label = f"[P{self.accelerator.process_index}][Master]"
         else:
-            proc_label = f"[P{self.accelerator.process_index:03d}][Worker]"
+            proc_label = f"[P{self.accelerator.process_index}][Worker]"
 
         # logs
         for epoch in range(self.num_epochs):
@@ -174,7 +174,7 @@ def train(self):
                 logs = {"loss": train_loss, "lr": self.lr_scheduler.get_last_lr()[0]}
 
                 if self.on_tpu:
-                    self.accelerator.print(f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: "
+                    self.accelerator.print(f"\n[E{epoch + 1}][{steps}]{proc_label}: "
                                            f"maskgit loss: {logs['loss']} - lr: {logs['lr']}")
                 else:
                     self.training_bar.update()
@@ -184,7 +184,7 @@ def train(self):
                 self.accelerator.log(logs, step=steps)
 
                 if not (steps % self.save_model_every):
-                    self.accelerator.print(f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: "
+                    self.accelerator.print(f"\n[E{epoch + 1}][{steps}]{proc_label}: "
                                            f"saving model to {self.results_dir}")
 
                     state_dict = self.accelerator.unwrap_model(self.model).state_dict()
@@ -206,7 +206,7 @@ def train(self):
 
                     if self.use_ema:
                         self.accelerator.print(
-                            f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: "
+                            f"\n[E{epoch + 1}][{steps}]{proc_label}: "
                             f"saving EMA model to {self.results_dir}")
 
                         ema_state_dict = self.accelerator.unwrap_model(self.ema_model).state_dict()
@@ -241,14 +241,14 @@ def train(self):
                         self.validation_prompts, steps, cond_image=cond_image
                     )
                     if self.on_tpu:
-                        self.accelerator.print(f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: saved to {saved_image}")
+                        self.accelerator.print(f"\n[E{epoch + 1}][{steps}]{proc_label}: saved to {saved_image}")
                     else:
                         self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}: "
                                                           f"saved to {saved_image}")
 
                 if met is not None and not (steps % self.log_metrics_every):
                     if self.on_tpu:
-                        self.accelerator.print(f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: metrics:")
+                        self.accelerator.print(f"\n[E{epoch + 1}][{steps}]{proc_label}: metrics:")
                     else:
                         self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}: metrics:")
 
@@ -256,7 +256,7 @@ def train(self):
 
             if self.num_train_steps > 0 and self.steps >= int(self.steps.item()):
                 if self.on_tpu:
-                    self.accelerator.print(f"\n[E{epoch + 1}][{int(self.steps.item()):05d}]{proc_label}"
+                    self.accelerator.print(f"\n[E{epoch + 1}][{int(self.steps.item())}]{proc_label}"
                                            f"[STOP EARLY]: Stopping training early...")
                 else:
                     self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}"
@@ -264,7 +264,7 @@ def train(self):
                 break
 
         # loop complete, save final model
-        self.accelerator.print(f"\n[E{epoch + 1}][{steps:05d}]{proc_label}[FINAL]: saving model to {self.results_dir}")
+        self.accelerator.print(f"\n[E{epoch + 1}][{steps}]{proc_label}[FINAL]: saving model to {self.results_dir}")
         state_dict = self.accelerator.unwrap_model(self.model).state_dict()
         maskgit_save_name = "maskgit_superres" if self.model.cond_image_size else "maskgit"
         file_name = (
@@ -284,7 +284,7 @@ def train(self):
 
         if self.use_ema:
             self.accelerator.print(
-                f"\n[{steps:05d}]{proc_label}[FINAL]: saving EMA model to {self.results_dir}"
+                f"\n[{steps}]{proc_label}[FINAL]: saving EMA model to {self.results_dir}"
             )
             ema_state_dict = self.accelerator.unwrap_model(self.ema_model).state_dict()
             file_name = (
@@ -309,9 +309,9 @@ def train(self):
                 cond_image = F.interpolate(imgs, self.model.cond_image_size, mode="nearest")
 
         steps = int(self.steps.item()) + 1  # get the final step count, plus one
-        self.accelerator.print(f"\n[{steps:05d}]{proc_label}: Logging validation images")
+        self.accelerator.print(f"\n[{steps}]{proc_label}: Logging validation images")
         saved_image = self.save_validation_images(self.validation_prompts, steps, cond_image=cond_image)
-        self.accelerator.print(f"\n[{steps:05d}]{proc_label}: saved to {saved_image}")
+        self.accelerator.print(f"\n[{steps}]{proc_label}: saved to {saved_image}")
 
         if met is not None and not (steps % self.log_metrics_every):
-            self.accelerator.print(f"\n[{steps:05d}]{proc_label}: metrics:")
+            self.accelerator.print(f"\n[{steps}]{proc_label}: metrics:")
diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py
index 805c3b5..276b036 100644
--- a/train_muse_maskgit.py
+++ b/train_muse_maskgit.py
@@ -655,21 +655,21 @@ def main():
 
             # Check if latest checkpoint is empty or unreadable
             if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(latest_checkpoint_file, os.R_OK):
-                accelerator.print(f"Warning: latest checkpoint {latest_checkpoint_file} is empty or unreadable.")
+                accelerator.print(f"Warning: latest MaskGit checkpoint {latest_checkpoint_file} is empty or unreadable.")
                 if len(checkpoint_files) > 1:
                     # Use the second last checkpoint as a fallback
                     latest_checkpoint_file = max(checkpoint_files[:-1], key=lambda x: int(re.search(r'maskgit\.(\d+)\.pt$', x).group(1)) if not x.endswith('ema.pt') else -1)
-                    accelerator.print("Using second last checkpoint: ", latest_checkpoint_file)
+                    accelerator.print("Using second last MaskGit checkpoint: ", latest_checkpoint_file)
                 else:
-                    accelerator.print("No usable checkpoint found.")
+                    accelerator.print("No usable MaskGit checkpoint found.")
             elif latest_checkpoint_file != orig_vae_path:
                 accelerator.print("Resuming MaskGit from latest checkpoint: ", latest_checkpoint_file)
             else:
-                accelerator.print("Using checkpoint specified in resume_path: ", orig_vae_path)
+                accelerator.print("Using MaskGit checkpoint specified in resume_path: ", orig_vae_path)
 
             args.resume_path = latest_checkpoint_file
         else:
-            accelerator.print("No checkpoints found in directory: ", args.resume_path)
+            accelerator.print("No MaskGit checkpoints found in directory: ", args.resume_path)
     else:
         accelerator.print("Resuming MaskGit from: ", args.resume_path)
 

From b5070ef5ebdf377219e08b290f63c78fabfeb6e3 Mon Sep 17 00:00:00 2001
From: Alejandro Gil
Date: Mon, 12 Jun 2023 03:26:04 -0700
Subject: [PATCH 2/2] Added option to specify the project name, run name and
 user/organization to use for trackers such as wandb.
---
 train_muse_maskgit.py | 32 +++++++++++++++++++++++++++++++-
 train_muse_vae.py     | 33 ++++++++++++++++++++++++++++++++-
 2 files changed, 63 insertions(+), 2 deletions(-)

diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py
index d98aa50..11bb4b2 100644
--- a/train_muse_maskgit.py
+++ b/train_muse_maskgit.py
@@ -17,6 +17,7 @@
 import os
 import glob
 import re
+import wandb
 
 from omegaconf import OmegaConf
 from accelerate.utils import ProjectConfiguration
@@ -156,6 +157,25 @@
         ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
     ),
 )
+parser.add_argument(
+    "--project_name",
+    type=str,
+    default="muse_maskgit",
+    help=("Name to use for the project to identify it when saved to a tracker such as wandb or tensorboard."),
+)
+parser.add_argument(
+    "--run_name",
+    type=str,
+    default=None,
+    help=("Name to use for the run to identify it when saved to a tracker such"
+          " as wandb or tensorboard. If not specified, a random one will be generated."),
+)
+parser.add_argument(
+    "--wandb_user",
+    type=str,
+    default=None,
+    help=("Specify the name of the user or the organization under which the project will be saved when using wandb."),
+)
 parser.add_argument(
     "--mixed_precision",
     type=str,
@@ -878,7 +898,17 @@ def main():
     )
 
     if accelerator.is_main_process:
-        accelerator.init_trackers("muse_maskgit", config=vars(args))
+        accelerator.init_trackers(
+            args.project_name,
+            config=vars(args),
+            init_kwargs={
+                "wandb": {
+                    "entity": args.wandb_user or wandb.api.default_entity,
+                    "name": args.run_name,
+                },
+            },
+
+        )
 
     # Create the trainer
     accelerator.wait_for_everyone()
diff --git a/train_muse_vae.py b/train_muse_vae.py
index 912b84c..2380e62 100644
--- a/train_muse_vae.py
+++ b/train_muse_vae.py
@@ -19,6 +19,7 @@
 import os
 import glob
 import re
+import wandb
 
 from omegaconf import OmegaConf
 from accelerate.utils import ProjectConfiguration
@@ -110,6 +111,25 @@
         ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
     ),
 )
+parser.add_argument(
+    "--project_name",
+    type=str,
+    default="muse_vae",
+    help=("Name to use for the project to identify it when saved to a tracker such as wandb or tensorboard."),
+)
+parser.add_argument(
+    "--run_name",
+    type=str,
+    default=None,
+    help=("Name to use for the run to identify it when saved to a tracker such"
+          " as wandb or tensorboard. If not specified, a random one will be generated."),
+)
+parser.add_argument(
+    "--wandb_user",
+    type=str,
+    default=None,
+    help=("Specify the name of the user or the organization under which the project will be saved when using wandb."),
+)
 parser.add_argument(
     "--mixed_precision",
     type=str,
@@ -389,7 +409,18 @@ def main():
         even_batches=True
     )
     if accelerator.is_main_process:
-        accelerator.init_trackers("muse_vae", config=vars(args))
+        accelerator.init_trackers(
+            args.project_name,
+            config=vars(args),
+            init_kwargs={
+                "wandb": {
+                    "entity": args.wandb_user or wandb.api.default_entity,
+                    "name": args.run_name,
+                },
+            },
+
+        )
+
 
     if args.webdataset is not None:
         import webdataset as wds
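
Usage note (not part of the patches above): a minimal sketch of how the new
--project_name / --run_name / --wandb_user flags are expected to flow into the
tracker setup. The standalone parser and the example argument values are
illustrative only; the entity fallback relies on the same
wandb.api.default_entity attribute the patch itself uses.

    import argparse

    import wandb

    parser = argparse.ArgumentParser()
    parser.add_argument("--project_name", type=str, default="muse_maskgit")
    parser.add_argument("--run_name", type=str, default=None)
    parser.add_argument("--wandb_user", type=str, default=None)

    # Example invocation; in a real run these come from the command line.
    args = parser.parse_args(["--project_name", "muse_maskgit", "--run_name", "baseline-256"])

    # Mirrors the patch: prefer --wandb_user, otherwise fall back to the
    # entity wandb is currently logged in as; a name of None lets wandb
    # generate a random run name.
    init_kwargs = {
        "wandb": {
            "entity": args.wandb_user or wandb.api.default_entity,
            "name": args.run_name,
        },
    }

    # The training scripts then hand these to accelerate on the main process:
    #   accelerator.init_trackers(args.project_name, config=vars(args), init_kwargs=init_kwargs)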