Skip to content

Commit

Permalink
Merge pull request Sygil-Dev#48 from ZeroCool940711/dev
Browse files Browse the repository at this point in the history
Improved the epoch counter so that, when resuming from a checkpoint, it also accounts for the epochs already trained, derived from the current step count.
  • Loading branch information
ZeroCool940711 authored Jun 21, 2023
2 parents 9d0d07e + f05affa commit 5a77082
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
3 changes: 2 additions & 1 deletion muse_maskgit_pytorch/trainers/maskgit_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(
self.save_results_every = save_results_every
self.log_metrics_every = log_metrics_every
self.batch_size = batch_size
self.current_step = current_step

# arguments used for the training script,
# we are going to use them later to save them to a config file.
Expand Down Expand Up @@ -147,7 +148,7 @@ def train(self):
proc_label = f"[P{self.accelerator.process_index}][Worker]"

# logs
- for epoch in range(self.num_epochs):
+ for epoch in range(self.current_step // len(self.dl), self.num_epochs):
for imgs, input_ids, attn_mask in iter(self.dl):
train_loss = 0.0
steps = int(self.steps.item())
Expand Down
4 changes: 3 additions & 1 deletion muse_maskgit_pytorch/trainers/vqvae_trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ def __init__(
# we are going to use them later to save them to a config file.
self.args = args

self.current_step = current_step

# vae
self.model = vae

Expand Down Expand Up @@ -220,7 +222,7 @@ def train(self):
else:
proc_label = f"[P{self.accelerator.process_index:03d}][Worker]"

- for epoch in range(self.num_epochs):
+ for epoch in range(self.current_step // len(self.dl), self.num_epochs):
for img in self.dl:
loss = 0.0
steps = int(self.steps.item())
Expand Down

0 comments on commit 5a77082

Please sign in to comment.