
Commit ddca0f9

Merge pull request Sygil-Dev#21 from ZeroCool940711/dev
Added option to specify the project name, run name and user/organization to use for trackers such as wandb.
ZeroCool940711 authored Jun 12, 2023
2 parents b1d4e5a + b5070ef commit ddca0f9
Showing 3 changed files with 82 additions and 21 deletions.
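Both train_muse_maskgit.py and train_muse_vae.py gain the same three flags. A typical invocation might look like the following sketch (flag values are illustrative and every other training argument is omitted):

    accelerate launch train_muse_maskgit.py \
        --report_to wandb \
        --project_name muse_maskgit \
        --run_name baseline-256px \
        --wandb_user my-team

If --run_name is omitted the tracker generates a random run name, and if --wandb_user is omitted the run falls back to the default wandb entity.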
28 changes: 14 additions & 14 deletions muse_maskgit_pytorch/trainers/maskgit_trainer.py
@@ -129,7 +129,7 @@ def save_validation_images(

save_dir = self.results_dir.joinpath("MaskGit")
save_dir.mkdir(exist_ok=True, parents=True)
- save_file = save_dir.joinpath(f"maskgit_{step:04d}.png")
+ save_file = save_dir.joinpath(f"maskgit_{step}.png")

if self.accelerator.is_main_process:
save_image(images, save_file, "png")
@@ -141,9 +141,9 @@ def train(self):
self.model.train()

if self.accelerator.is_main_process:
- proc_label = f"[P{self.accelerator.process_index:03d}][Master]"
+ proc_label = f"[P{self.accelerator.process_index}][Master]"
else:
- proc_label = f"[P{self.accelerator.process_index:03d}][Worker]"
+ proc_label = f"[P{self.accelerator.process_index}][Worker]"

# logs
for epoch in range(self.num_epochs):
@@ -174,7 +174,7 @@ def train(self):
logs = {"loss": train_loss, "lr": self.lr_scheduler.get_last_lr()[0]}

if self.on_tpu:
self.accelerator.print(f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: "
self.accelerator.print(f"\n[E{epoch + 1}][{steps}]{proc_label}: "
f"maskgit loss: {logs['loss']} - lr: {logs['lr']}")
else:
self.training_bar.update()
@@ -184,7 +184,7 @@ def train(self):
self.accelerator.log(logs, step=steps)

if not (steps % self.save_model_every):
self.accelerator.print(f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: "
self.accelerator.print(f"\n[E{epoch + 1}][{steps}]{proc_label}: "
f"saving model to {self.results_dir}")

state_dict = self.accelerator.unwrap_model(self.model).state_dict()
@@ -206,7 +206,7 @@ def train(self):

if self.use_ema:
self.accelerator.print(
f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: "
f"\n[E{epoch + 1}][{steps}]{proc_label}: "
f"saving EMA model to {self.results_dir}")

ema_state_dict = self.accelerator.unwrap_model(self.ema_model).state_dict()
@@ -241,30 +241,30 @@ def train(self):
self.validation_prompts, steps, cond_image=cond_image
)
if self.on_tpu:
self.accelerator.print(f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: saved to {saved_image}")
self.accelerator.print(f"\n[E{epoch + 1}][{steps}]{proc_label}: saved to {saved_image}")
else:
self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}: "
f"saved to {saved_image}")

if met is not None and not (steps % self.log_metrics_every):
if self.on_tpu:
self.accelerator.print(f"\n[E{epoch + 1}][{steps:05d}]{proc_label}: metrics:")
self.accelerator.print(f"\n[E{epoch + 1}][{steps}]{proc_label}: metrics:")
else:
self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}: metrics:")

self.steps += 1

if self.num_train_steps > 0 and self.steps >= int(self.steps.item()):
if self.on_tpu:
self.accelerator.print(f"\n[E{epoch + 1}][{int(self.steps.item()):05d}]{proc_label}"
self.accelerator.print(f"\n[E{epoch + 1}][{int(self.steps.item())}]{proc_label}"
f"[STOP EARLY]: Stopping training early...")
else:
self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}"
f"[STOP EARLY]: Stopping training early...")
break

# loop complete, save final model
self.accelerator.print(f"\n[E{epoch + 1}][{steps:05d}]{proc_label}[FINAL]: saving model to {self.results_dir}")
self.accelerator.print(f"\n[E{epoch + 1}][{steps}]{proc_label}[FINAL]: saving model to {self.results_dir}")
state_dict = self.accelerator.unwrap_model(self.model).state_dict()
maskgit_save_name = "maskgit_superres" if self.model.cond_image_size else "maskgit"
file_name = (
@@ -284,7 +284,7 @@ def train(self):

if self.use_ema:
self.accelerator.print(
f"\n[{steps:05d}]{proc_label}[FINAL]: saving EMA model to {self.results_dir}"
f"\n[{steps}]{proc_label}[FINAL]: saving EMA model to {self.results_dir}"
)
ema_state_dict = self.accelerator.unwrap_model(self.ema_model).state_dict()
file_name = (
@@ -309,9 +309,9 @@ def train(self):
cond_image = F.interpolate(imgs, self.model.cond_image_size, mode="nearest")

steps = int(self.steps.item()) + 1 # get the final step count, plus one
self.accelerator.print(f"\n[{steps:05d}]{proc_label}: Logging validation images")
self.accelerator.print(f"\n[{steps}]{proc_label}: Logging validation images")
saved_image = self.save_validation_images(self.validation_prompts, steps, cond_image=cond_image)
self.accelerator.print(f"\n[{steps:05d}]{proc_label}: saved to {saved_image}")
self.accelerator.print(f"\n[{steps}]{proc_label}: saved to {saved_image}")

if met is not None and not (steps % self.log_metrics_every):
self.accelerator.print(f"\n[{steps:05d}]{proc_label}: metrics:")
self.accelerator.print(f"\n[{steps}]{proc_label}: metrics:")
42 changes: 36 additions & 6 deletions train_muse_maskgit.py
Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@
import os
import glob
import re
+ import wandb

from omegaconf import OmegaConf
from accelerate.utils import ProjectConfiguration
@@ -156,6 +157,25 @@
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
+ parser.add_argument(
+     "--project_name",
+     type=str,
+     default="muse_maskgit",
+     help=("Name to use for the project to identify it when saved to a tracker such as wandb or tensorboard."),
+ )
+ parser.add_argument(
+     "--run_name",
+     type=str,
+     default=None,
+     help=("Name to use for the run to identify it when saved to a tracker such"
+           " as wandb or tensorboard. If not specified, a random one will be generated."),
+ )
+ parser.add_argument(
+     "--wandb_user",
+     type=str,
+     default=None,
+     help=("Specify the name of the user or the organization under which the project will be saved when using wandb."),
+ )
parser.add_argument(
"--mixed_precision",
type=str,
@@ -673,21 +693,21 @@ def main():

# Check if latest checkpoint is empty or unreadable
if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(latest_checkpoint_file, os.R_OK):
accelerator.print(f"Warning: latest checkpoint {latest_checkpoint_file} is empty or unreadable.")
accelerator.print(f"Warning: latest MaskGit checkpoint {latest_checkpoint_file} is empty or unreadable.")
if len(checkpoint_files) > 1:
# Use the second last checkpoint as a fallback
latest_checkpoint_file = max(checkpoint_files[:-1], key=lambda x: int(re.search(r'maskgit\.(\d+)\.pt$', x).group(1)) if not x.endswith('ema.pt') else -1)
accelerator.print("Using second last checkpoint: ", latest_checkpoint_file)
accelerator.print("Using second last MaskGit checkpoint: ", latest_checkpoint_file)
else:
accelerator.print("No usable checkpoint found.")
accelerator.print("No usable MaskGit checkpoint found.")
elif latest_checkpoint_file != orig_vae_path:
accelerator.print("Resuming MaskGit from latest checkpoint: ", latest_checkpoint_file)
else:
accelerator.print("Using checkpoint specified in resume_path: ", orig_vae_path)
accelerator.print("Using MaskGit checkpoint specified in resume_path: ", orig_vae_path)

args.resume_path = latest_checkpoint_file
else:
accelerator.print("No checkpoints found in directory: ", args.resume_path)
accelerator.print("No MaskGit checkpoints found in directory: ", args.resume_path)
else:
accelerator.print("Resuming MaskGit from: ", args.resume_path)

@@ -878,7 +898,17 @@ def main():
)

if accelerator.is_main_process:
accelerator.init_trackers("muse_maskgit", config=vars(args))
accelerator.init_trackers(
args.project_name,
config=vars(args),
init_kwargs={
"wandb":{
"entity": f"{args.wandb_user or wandb.api.default_entity}",
"name": args.run_name,
},
}

)

# Create the trainer
accelerator.wait_for_everyone()
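For context on the init_kwargs added above: Accelerate forwards the "wandb" entry to wandb.init() when it sets up the tracker. Roughly, the call behaves like this sketch (an assumption about the pass-through, not the library's actual source):

    import wandb

    def init_wandb_tracker(project_name, config, entity=None, name=None):
        # Mirrors init_kwargs["wandb"] above: entity selects the user or
        # organization, name labels the run; wandb generates a random run
        # name when name is None.
        return wandb.init(project=project_name, config=config, entity=entity, name=name)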
33 changes: 32 additions & 1 deletion train_muse_vae.py
Original file line number Diff line number Diff line change
@@ -19,6 +19,7 @@
import os
import glob
import re
+ import wandb

from omegaconf import OmegaConf
from accelerate.utils import ProjectConfiguration
@@ -110,6 +111,25 @@
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
+ parser.add_argument(
+     "--project_name",
+     type=str,
+     default="muse_vae",
+     help=("Name to use for the project to identify it when saved to a tracker such as wandb or tensorboard."),
+ )
+ parser.add_argument(
+     "--run_name",
+     type=str,
+     default=None,
+     help=("Name to use for the run to identify it when saved to a tracker such"
+           " as wandb or tensorboard. If not specified, a random one will be generated."),
+ )
+ parser.add_argument(
+     "--wandb_user",
+     type=str,
+     default=None,
+     help=("Specify the name of the user or the organization under which the project will be saved when using wandb."),
+ )
parser.add_argument(
"--mixed_precision",
type=str,
@@ -389,7 +409,18 @@ def main():
even_batches=True
)
if accelerator.is_main_process:
accelerator.init_trackers("muse_vae", config=vars(args))
accelerator.init_trackers(
args.project_name,
config=vars(args),
init_kwargs={
"wandb":{
"entity": f"{args.wandb_user or wandb.api.default_entity}",
"name": args.run_name,
},
}

)

if args.webdataset is not None:
import webdataset as wds

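One nuance shared by both scripts: because the entity is built with an f-string, a missing value is coerced to the literal string "None" instead of staying unset. A minimal sketch of the coercion, independent of wandb itself:

    user = None
    default_entity = None                       # e.g. no default entity configured
    print(repr(f"{user or default_entity}"))    # -> 'None'  (a string)
    print(repr(user or default_entity))         # -> None    (actual None)

Passing the unwrapped user or default_entity expression through would let wandb resolve the default entity on its own.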

