refactor(yaml): try to pass unittest
xingchensong committed Dec 8, 2023
1 parent 1e21195 commit 6e2941e
Showing 5 changed files with 47 additions and 138 deletions.
14 changes: 11 additions & 3 deletions wenet/ctl_model/asr_model_ctl.py
@@ -48,11 +48,19 @@ def __init__(
         logit_temp: float = 0.1,
         n_negatives: int = 0,
         ctl_weight: float = 1,
+        special_tokens: dict = None,
     ):
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
-        super().__init__(vocab_size, encoder, decoder, ctc, ctc_weight,
-                         ignore_id, reverse_weight, lsm_weight,
-                         length_normalized_loss)
+        super().__init__(vocab_size,
+                         encoder,
+                         decoder,
+                         ctc,
+                         ctc_weight,
+                         ignore_id,
+                         reverse_weight,
+                         lsm_weight,
+                         length_normalized_loss,
+                         special_tokens=special_tokens)
 
         # For CTL Loss
         self.n_negatives = n_negatives
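Aside (not part of the diff): this hunk, and the matching hunks in wenet/k2/model.py and wenet/transducer/transducer.py below, forward a new special_tokens dict from each subclass constructor to the shared base-class constructor. One plausible motivation is letting the base model resolve token ids such as <sos>/<eos> from the tokenizer configuration instead of deriving them from the vocabulary size alone. A minimal, hypothetical Python sketch of that idea; the helper name, token keys, and fallback behaviour are assumptions, not code taken from wenet:

from typing import Optional, Tuple


def resolve_sos_eos(vocab_size: int,
                    special_tokens: Optional[dict] = None) -> Tuple[int, int]:
    """Hypothetical helper: look up <sos>/<eos> ids in a special_tokens dict,
    falling back to vocab_size - 1 when the dict or a key is missing."""
    if special_tokens is None:
        return vocab_size - 1, vocab_size - 1
    sos = special_tokens.get('<sos>', vocab_size - 1)
    eos = special_tokens.get('<eos>', vocab_size - 1)
    return sos, eos


# Example: resolve_sos_eos(5000) -> (4999, 4999)
#          resolve_sos_eos(5000, {'<sos>': 2, '<eos>': 2}) -> (2, 2)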
24 changes: 17 additions & 7 deletions wenet/k2/model.py
@@ -38,19 +38,29 @@ def __init__(
         lsm_weight: float = 0.0,
         length_normalized_loss: bool = False,
         lfmmi_dir: str = '',
+        special_tokens: dict = None,
     ):
-        super().__init__(vocab_size, encoder, decoder, ctc, ctc_weight,
-                         ignore_id, reverse_weight, lsm_weight,
-                         length_normalized_loss)
+        super().__init__(vocab_size,
+                         encoder,
+                         decoder,
+                         ctc,
+                         ctc_weight,
+                         ignore_id,
+                         reverse_weight,
+                         lsm_weight,
+                         length_normalized_loss,
+                         special_tokens=special_tokens)
         self.lfmmi_dir = lfmmi_dir
         if self.lfmmi_dir != '':
             self.load_lfmmi_resource()
 
     @torch.jit.ignore(drop=True)
-    def _forward_ctc(self, encoder_out: torch.Tensor,
-                     encoder_mask: torch.Tensor, text: torch.Tensor,
-                     text_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        loss_ctc, ctc_probs = self._calc_lfmmi_loss(encoder_out, encoder_mask, text)
+    def _forward_ctc(
+            self, encoder_out: torch.Tensor, encoder_mask: torch.Tensor,
+            text: torch.Tensor,
+            text_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        loss_ctc, ctc_probs = self._calc_lfmmi_loss(encoder_out, encoder_mask,
+                                                    text)
         return loss_ctc, ctc_probs
 
     @torch.jit.ignore(drop=True)
14 changes: 11 additions & 3 deletions wenet/transducer/transducer.py
@@ -41,11 +41,19 @@ def __init__(
         warmup_steps: float = 25000,
         lm_only_scale: float = 0.25,
         am_only_scale: float = 0.0,
+        special_tokens: dict = None,
     ) -> None:
         assert attention_weight + ctc_weight + transducer_weight == 1.0
-        super().__init__(vocab_size, encoder, attention_decoder, ctc,
-                         ctc_weight, ignore_id, reverse_weight, lsm_weight,
-                         length_normalized_loss)
+        super().__init__(vocab_size,
+                         encoder,
+                         attention_decoder,
+                         ctc,
+                         ctc_weight,
+                         ignore_id,
+                         reverse_weight,
+                         lsm_weight,
+                         length_normalized_loss,
+                         special_tokens=special_tokens)
 
         self.blank = blank
         self.transducer_weight = transducer_weight
13 changes: 8 additions & 5 deletions wenet/utils/init_model.py
@@ -122,6 +122,7 @@ def init_model(args, configs):
                            attention_decoder=decoder,
                            joint=joint,
                            ctc=ctc,
+                           special_tokens=configs['tokenizer_conf']['special_tokens'],
                            **configs['model_conf'])
     elif configs['model'] == 'paraformer':
         """ NOTE(Mddct): support fintune paraformer, if there is a need for
@@ -134,11 +135,13 @@ def init_model(args, configs):
         print(configs)
         return model, configs
     else:
-        model = WENET_MODEL_CLASSES[configs['model']](vocab_size=vocab_size,
-                                                      encoder=encoder,
-                                                      decoder=decoder,
-                                                      ctc=ctc,
-                                                      **configs['model_conf'])
+        model = WENET_MODEL_CLASSES[configs['model']](
+            vocab_size=vocab_size,
+            encoder=encoder,
+            decoder=decoder,
+            ctc=ctc,
+            special_tokens=configs['tokenizer_conf']['special_tokens'],
+            **configs['model_conf'])
 
     # If specify checkpoint, load some info from checkpoint
     if args.checkpoint is not None:
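Aside (not part of the diff): with this change init_model reads configs['tokenizer_conf']['special_tokens'], so the training YAML is expected to carry that section. A minimal sketch of the assumed shape, written as the Python dict that init_model would receive after YAML parsing; the model key, token names, and ids are placeholders, not values from any wenet recipe:

configs = {
    'model': 'transducer',            # any key of WENET_MODEL_CLASSES
    'tokenizer_conf': {
        'special_tokens': {           # forwarded as special_tokens=...
            '<blank>': 0,             # placeholder ids
            '<sos>': 2,
            '<eos>': 2,
        },
    },
    'model_conf': {},                 # remaining keyword args for the model
}

# What init_model passes through to the model constructor:
special = configs['tokenizer_conf']['special_tokens']
assert isinstance(special, dict)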
120 changes: 0 additions & 120 deletions wenet/utils/module_utils.py

This file was deleted.
