Skip to content

Commit

Permalink
refactor(model): Add special token to model (#2209)
Browse files (browse the repository at this point in the history)
  • Loading branch information
xingchensong authored Dec 8, 2023
1 parent da93958 commit b517fb5
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 9 deletions.
14 changes: 11 additions & 3 deletions wenet/ctl_model/asr_model_ctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,19 @@ def __init__(
logit_temp: float = 0.1,
n_negatives: int = 0,
ctl_weight: float = 1,
special_tokens: dict = None,
):
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
super().__init__(vocab_size, encoder, decoder, ctc, ctc_weight,
ignore_id, reverse_weight, lsm_weight,
length_normalized_loss)
super().__init__(vocab_size,
encoder,
decoder,
ctc,
ctc_weight,
ignore_id,
reverse_weight,
lsm_weight,
length_normalized_loss,
special_tokens=special_tokens)

# For CTL Loss
self.n_negatives = n_negatives
Expand Down
14 changes: 11 additions & 3 deletions wenet/k2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,18 @@ def __init__(
lsm_weight: float = 0.0,
length_normalized_loss: bool = False,
lfmmi_dir: str = '',
special_tokens: dict = None,
):
super().__init__(vocab_size, encoder, decoder, ctc, ctc_weight,
ignore_id, reverse_weight, lsm_weight,
length_normalized_loss)
super().__init__(vocab_size,
encoder,
decoder,
ctc,
ctc_weight,
ignore_id,
reverse_weight,
lsm_weight,
length_normalized_loss,
special_tokens=special_tokens)
self.lfmmi_dir = lfmmi_dir
if self.lfmmi_dir != '':
self.load_lfmmi_resource()
Expand Down
14 changes: 11 additions & 3 deletions wenet/transducer/transducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,19 @@ def __init__(
warmup_steps: float = 25000,
lm_only_scale: float = 0.25,
am_only_scale: float = 0.0,
special_tokens: dict = None,
) -> None:
assert attention_weight + ctc_weight + transducer_weight == 1.0
super().__init__(vocab_size, encoder, attention_decoder, ctc,
ctc_weight, ignore_id, reverse_weight, lsm_weight,
length_normalized_loss)
super().__init__(vocab_size,
encoder,
attention_decoder,
ctc,
ctc_weight,
ignore_id,
reverse_weight,
lsm_weight,
length_normalized_loss,
special_tokens=special_tokens)

self.blank = blank
self.transducer_weight = transducer_weight
Expand Down

0 comments on commit b517fb5

Please sign in to comment.