From 2b4f9dc32b05c8349d662f03d982879d6c6a3a76 Mon Sep 17 00:00:00 2001 From: Alejandro Gil Date: Wed, 21 Jun 2023 13:48:44 -0700 Subject: [PATCH] Improved the epoch counter to also take into account the number of epochs trained before based on the current number of steps when resuming from a checkpoint. --- muse_maskgit_pytorch/trainers/maskgit_trainer.py | 3 ++- muse_maskgit_pytorch/trainers/vqvae_trainers.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/muse_maskgit_pytorch/trainers/maskgit_trainer.py b/muse_maskgit_pytorch/trainers/maskgit_trainer.py index 69eaeb5..3f6e7bc 100644 --- a/muse_maskgit_pytorch/trainers/maskgit_trainer.py +++ b/muse_maskgit_pytorch/trainers/maskgit_trainer.py @@ -79,6 +79,7 @@ def __init__( self.save_results_every = save_results_every self.log_metrics_every = log_metrics_every self.batch_size = batch_size + self.current_step = current_step # arguments used for the training script, # we are going to use them later to save them to a config file. @@ -147,7 +148,7 @@ def train(self): proc_label = f"[P{self.accelerator.process_index}][Worker]" # logs - for epoch in range(self.num_epochs): + for epoch in range(self.current_step // len(self.dl), self.num_epochs): for imgs, input_ids, attn_mask in iter(self.dl): train_loss = 0.0 steps = int(self.steps.item()) diff --git a/muse_maskgit_pytorch/trainers/vqvae_trainers.py b/muse_maskgit_pytorch/trainers/vqvae_trainers.py index f857171..fdd848f 100644 --- a/muse_maskgit_pytorch/trainers/vqvae_trainers.py +++ b/muse_maskgit_pytorch/trainers/vqvae_trainers.py @@ -91,6 +91,8 @@ def __init__( # we are going to use them later to save them to a config file. self.args = args + self.current_step = current_step + # vae self.model = vae @@ -220,7 +222,7 @@ def train(self): else: proc_label = f"[P{self.accelerator.process_index:03d}][Worker]" - for epoch in range(self.num_epochs): + for epoch in range(self.current_step // len(self.dl), self.num_epochs): for img in self.dl: loss = 0.0 steps = int(self.steps.item())