diff --git a/wenet/transformer/asr_model.py b/wenet/transformer/asr_model.py index 9cd2aff4f..722807e4f 100644 --- a/wenet/transformer/asr_model.py +++ b/wenet/transformer/asr_model.py @@ -54,9 +54,9 @@ def __init__( super().__init__() # note that eos is the same as sos (equivalent ID) self.sos = (vocab_size - 1 if special_tokens is None else - special_tokens.get("sos", vocab_size - 1)) + special_tokens.get("", vocab_size - 1)) self.eos = (vocab_size - 1 if special_tokens is None else - special_tokens.get("eos", vocab_size - 1)) + special_tokens.get("", vocab_size - 1)) self.vocab_size = vocab_size self.special_tokens = special_tokens self.ignore_id = ignore_id