Fix the BLSTM forward state error and forcibly turn off use_dynamic_chunk during bias module training
kaixunhuang0 committed Sep 21, 2023
1 parent d009779 commit 762e199
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
wenet/bin/train.py (3 additions, 1 deletion)
@@ -276,14 +276,16 @@ def main():
     num_params = sum(p.numel() for p in model.parameters())
     print('the number of model params: {:,d}'.format(num_params)) if local_rank == 0 else None  # noqa
 
-    # Freeze other parts of the model during training context bias module
     if 'context_module_conf' in configs:
+        # Freeze other parts of the model during training context bias module
         for p in model.parameters():
             p.requires_grad = False
         for p in model.context_module.parameters():
             p.requires_grad = True
         for p in model.context_module.context_decoder_ctc_linear.parameters():
             p.requires_grad = False
+        # Turn off dynamic chunk because it will affect the training of bias
+        model.encoder.use_dynamic_chunk = False
 
     # !!!IMPORTANT!!!
     # Try to export the model by script, if fails, we should refine
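The freeze block relies on assignment order: all parameters are frozen first, the whole context module is unfrozen, and its CTC linear projection is then frozen again, so only the remaining bias-module parameters receive gradients. A minimal self-contained sketch of the same pattern, using dummy stand-in modules rather than wenet's real ASRModel:

import torch

# Dummy stand-ins for wenet's real modules, for illustration only.
class DummyContextModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blstm = torch.nn.LSTM(8, 8, bidirectional=True)
        self.context_decoder_ctc_linear = torch.nn.Linear(16, 10)

class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Linear(8, 8)
        self.context_module = DummyContextModule()

model = DummyModel()
# Same order as the commit: freeze all, unfreeze the context module,
# re-freeze its CTC linear layer.
for p in model.parameters():
    p.requires_grad = False
for p in model.context_module.parameters():
    p.requires_grad = True
for p in model.context_module.context_decoder_ctc_linear.parameters():
    p.requires_grad = False

trainable = sorted(n for n, p in model.named_parameters() if p.requires_grad)
print(trainable)  # only context_module.blstm.* parameters remain trainable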
wenet/transformer/context_module.py (2 additions, 2 deletions)
@@ -51,8 +51,8 @@ def forward(self, sen_batch, sen_lengths):
         _, last_state = self.sen_rnn(pack_seq)
         laste_h = last_state[0]
         laste_c = last_state[1]
-        state = torch.cat([laste_h[-1, :, :], laste_h[0, :, :],
-                           laste_c[-1, :, :], laste_c[0, :, :]], dim=-1)
+        state = torch.cat([laste_h[-1, :, :], laste_h[-2, :, :],
+                           laste_c[-1, :, :], laste_c[-2, :, :]], dim=-1)
         return state
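The fix addresses how torch.nn.LSTM stacks final states for a multi-layer bidirectional RNN: h_n has shape (num_layers * num_directions, batch, hidden), ordered layer by layer with the forward direction before the backward one. The last layer's forward state is therefore h_n[-2] and its backward state h_n[-1]; the old index 0 selected the first layer's forward state, which only coincides with the last layer when num_layers == 1. A standalone sketch of the layout (illustrative, not wenet code):

import torch

rnn = torch.nn.LSTM(input_size=8, hidden_size=16, num_layers=2,
                    bidirectional=True, batch_first=True)
x = torch.randn(4, 10, 8)               # (batch, time, feature)
_, (h_n, c_n) = rnn(x)
print(h_n.shape)                        # torch.Size([4, 4, 16]): 2 layers * 2 directions
# Separate layers and directions explicitly, as documented for nn.LSTM.
h = h_n.view(2, 2, 4, 16)               # (num_layers, num_directions, batch, hidden)
assert torch.equal(h[-1, 0], h_n[-2])   # last layer, forward direction
assert torch.equal(h[-1, 1], h_n[-1])   # last layer, backward direction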
