Skip to content

Commit

Permalink
add generator and discriminator log info
Browse files Browse the repository at this point in the history
  • Loading branch information
robin1001 committed Jan 9, 2024
1 parent 25c60ba commit 89a4a5d
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
1 change: 0 additions & 1 deletion wenet/tts/vits/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,7 +885,6 @@ def forward(self, batch: dict, device: torch.device):
'loss_dur': loss_dur,
'loss_kl': loss_kl,
}
print('optimizer_idx', optimizer_idx)
return losses

def infer(self, text: torch.Tensor):
Expand Down
3 changes: 3 additions & 0 deletions wenet/utils/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def train(self, model, optimizer, scheduler, train_data_loader,
else:
context = nullcontext
num_opt = len(optimizer) if isinstance(optimizer, list) else 1
loss_dict = {}
for opt_idx in range(num_opt):
batch_dict['optimizer_idx'] = opt_idx
with context():
Expand All @@ -84,6 +85,8 @@ def train(self, model, optimizer, scheduler, train_data_loader,
scaler,
info_dict,
)
loss_dict.update(info_dict['loss_dict'])
info_dict['loss_dict'] = loss_dict
save_interval = info_dict.get('save_interval', 10000)
if self.step % save_interval == 0 and self.step != 0 \
and (batch_idx + 1) % info_dict["accum_grad"] == 0:
Expand Down

0 comments on commit 89a4a5d

Please sign in to comment.