diff --git a/test/wenet/whisper/test_whisper.py b/test/wenet/whisper/test_whisper.py
index dd7a6ad9b..3e61fb2cc 100644
--- a/test/wenet/whisper/test_whisper.py
+++ b/test/wenet/whisper/test_whisper.py
@@ -273,9 +273,9 @@ def test_model(model, audio_path):
 
         # 6. Forward wenet.decoder
         wenet_tokens, _ = add_whisper_tokens(
-            tokenizer, torch.tensor([dummy_tokens], dtype=torch.long), ignore_id=-1,
-            task_id=tokenizer.transcribe if task == "transcribe" else tokenizer.translate,  # noqa
-            no_timestamp=True, language=language, use_prev=False
+            configs['model_conf']['special_tokens'],
+            torch.tensor([dummy_tokens], dtype=torch.long), ignore_id=-1,
+            task=task, no_timestamp=True, language=language, use_prev=False
         )
         L = wenet_tokens.size(1)
         tgt_mask = ~make_pad_mask(
diff --git a/wenet/utils/common.py b/wenet/utils/common.py
index 824e77581..5da1fb341 100644
--- a/wenet/utils/common.py
+++ b/wenet/utils/common.py
@@ -148,8 +148,8 @@ def add_sos_eos(ys_pad: torch.Tensor, sos: int, eos: int,
 
 
 def add_whisper_tokens(
-    tokenizer, ys_pad: torch.Tensor,
-    ignore_id: int, task_id: int, no_timestamp: bool,
+    special_tokens, ys_pad: torch.Tensor,
+    ignore_id: int, task: str, no_timestamp: bool,
     language: str, use_prev: bool
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Add whisper-style tokens.
@@ -166,7 +166,7 @@
         |--> [no speech(VAD)] ---------------------------------------------------------------------->|  # noqa
 
     Args:
-        tokenizer: get IDs of special tokens
+        special_tokens: dict mapping special token names to their IDs
         ignore_id (int): index of padding
         no_timestamp (bool): whether to add timestamps tokens
         language (str): language tag
@@ -178,31 +178,39 @@
 
     """
     if use_prev:  # i.e., hotword list
-        _prev = [tokenizer.sot_prev]
+        _prev = [special_tokens["sot_prev"]]
         # append hotword list to _prev
         # ...
         raise NotImplementedError
     else:
         _prev = []
 
-    language_id = tokenizer.sot + 1 + WHISPER_LANGS.index(language)
-    _sot = _prev + [tokenizer.sot, language_id, task_id]
-    _eot = torch.tensor([tokenizer.eot],
+    language_id = special_tokens["sot"] + 1 + WHISPER_LANGS.index(language)
+    if task == "transcribe":
+        task_id = special_tokens["transcribe"]
+    elif task == "translate":
+        task_id = special_tokens["translate"]
+    elif task == "vad":
+        task_id = special_tokens["no_speech"]
+    else:
+        raise NotImplementedError("unsupported task {}".format(task))
+    _sot = _prev + [special_tokens["sot"], language_id, task_id]
+    _eot = torch.tensor([special_tokens["eot"]],
                         dtype=torch.long, requires_grad=False, device=ys_pad.device)
     ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys
 
-    if task_id == tokenizer.transcribe or task_id == tokenizer.translate:
+    if task == "transcribe" or task == "translate":
         if no_timestamp:
-            _sot.append(tokenizer.no_timestamps)
+            _sot.append(special_tokens["no_timestamps"])
         else:
-            _sot.append(tokenizer.timestamp_begin)
+            _sot.append(special_tokens["timestamp_begin"])
             # add subsequent tokens
             # ...
             raise NotImplementedError
-    elif task_id == tokenizer.no_speech:
-        _sot.append(tokenizer.no_speech)
+    elif task == "vad":
+        _sot.append(special_tokens["no_speech"])
     else:
         raise NotImplementedError
 
@@ -210,7 +218,7 @@
                         requires_grad=False, device=ys_pad.device)
     ys_in = [torch.cat([_sot, y], dim=0) for y in ys]
     ys_out = [torch.cat([_sot[1:], y, _eot], dim=0) for y in ys]
-    return pad_list(ys_in, tokenizer.eot), pad_list(ys_out, ignore_id)
+    return pad_list(ys_in, special_tokens["eot"]), pad_list(ys_out, ignore_id)
 
 
 def reverse_pad_list(ys_pad: torch.Tensor,
diff --git a/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py b/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py
index 443aa07b2..2dced14fd 100644
--- a/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py
+++ b/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py
@@ -90,12 +90,22 @@ def convert_to_wenet_yaml(tokenizer, dims, wenet_yaml_path: str):
     configs['decoder_conf']['activation_type'] = "gelu"
 
     configs['ctc_conf'] = {}
-    configs['ctc_conf']['ctc_blank_id'] = 50362  #
+    configs['ctc_conf']['ctc_blank_id'] = tokenizer.no_speech
 
     configs['model_conf'] = {}
     configs['model_conf']['ctc_weight'] = 0.3
     configs['model_conf']['lsm_weight'] = 0.1
     configs['model_conf']['length_normalized_loss'] = False
+    configs['model_conf']['special_tokens'] = {}
+    configs['model_conf']['special_tokens']['sot'] = tokenizer.sot
+    configs['model_conf']['special_tokens']['eot'] = tokenizer.eot
+    configs['model_conf']['special_tokens']['sot_prev'] = tokenizer.sot_prev
+    configs['model_conf']['special_tokens']['transcribe'] = tokenizer.transcribe
+    configs['model_conf']['special_tokens']['translate'] = tokenizer.translate
+    configs['model_conf']['special_tokens']['no_timestamps'] = tokenizer.no_timestamps
+    configs['model_conf']['special_tokens']['no_speech'] = tokenizer.no_speech
+    configs['model_conf']['special_tokens']['timestamp_begin'] = \
+        tokenizer.timestamp_begin
 
     configs['dataset_conf'] = {}
     configs['dataset_conf']['filter_conf'] = {}
diff --git a/wenet/whisper/whisper.py b/wenet/whisper/whisper.py
index 2c83a346d..457b636ed 100644
--- a/wenet/whisper/whisper.py
+++ b/wenet/whisper/whisper.py
@@ -17,7 +17,6 @@
 import torch
 
 from typing import Tuple
-from whisper.tokenizer import get_tokenizer
 
 from wenet.transformer.asr_model import ASRModel
 from wenet.transformer.ctc import CTC
@@ -38,16 +37,14 @@ def __init__(
         reverse_weight: float = 0.0,
         lsm_weight: float = 0.0,
         length_normalized_loss: bool = False,
+        special_tokens: dict = None,
     ):
         super().__init__(vocab_size, encoder, decoder, ctc, ctc_weight, ignore_id,
                          reverse_weight, lsm_weight, length_normalized_loss)
-        self.tokenizer = get_tokenizer(multilingual=self.is_multilingual,
-                                       num_languages=self.num_languages)
-        assert vocab_size == self.tokenizer.encoding.n_vocab, "{} v.s. {}".format(
{}".format( - vocab_size, self.tokenizer.encoding.n_vocab) assert reverse_weight == 0.0 - self.sos = self.tokenizer.sot - self.eos = self.tokenizer.eot + self.sos = special_tokens["sot"] + self.eos = special_tokens["eot"] + self.special_tokens = special_tokens # TODO(xcsong): time align def set_alignment_heads(self, dump: bytes): @@ -75,7 +72,7 @@ def _calc_att_loss( # TODO(xcsong): add args for no_timestamp, language, etc prev_len = ys_pad.size(1) ys_in_pad, ys_out_pad = add_whisper_tokens( - self.tokenizer, ys_pad, self.ignore_id, task_id=self.tokenizer.transcribe, + self.special_tokens, ys_pad, self.ignore_id, task="transcribe", no_timestamp=True, language="zh", use_prev=False ) cur_len = ys_in_pad.size(1)