diff --git a/test/wenet/whisper/test_whisper.py b/test/wenet/whisper/test_whisper.py
index dd7a6ad9b..06b046859 100644
--- a/test/wenet/whisper/test_whisper.py
+++ b/test/wenet/whisper/test_whisper.py
@@ -12,10 +12,10 @@
 import numpy as np
 import torch.nn.functional as F
 
-from whisper.tokenizer import get_tokenizer
 from whisper.audio import N_FFT, HOP_LENGTH, N_SAMPLES, N_FRAMES, pad_or_trim
 
 from wenet.dataset.processor import compute_log_mel_spectrogram
+from wenet.text.whisper_tokenizer import WhisperTokenizer
 from wenet.transformer.embedding import WhisperPositionalEncoding
 from wenet.whisper.convert_whisper_to_wenet_config_and_ckpt import (
     convert_to_wenet_yaml, convert_to_wenet_state_dict, convert_to_wenet_units
@@ -108,19 +108,19 @@ def test_model(model, audio_path):
     checkpoint = torch.load("{}/{}.pt".format(download_root, model),
                             map_location="cpu")
     multilingual = checkpoint["dims"]['n_vocab'] >= 51865
     num_languages = checkpoint["dims"]['n_vocab'] - 51765 - int(multilingual)
-    tokenizer = get_tokenizer(multilingual, num_languages=num_languages,
-                              language=language, task=task)
+    tokenizer = WhisperTokenizer(multilingual, num_languages=num_languages,
+                                 language=language, task=task)
     convert_to_wenet_state_dict(
         checkpoint["model_state_dict"],
         os.path.join(download_root, 'wenet_whisper.pt')
     )
     convert_to_wenet_units(
-        tokenizer,
+        tokenizer.tokenizer,
         os.path.join(download_root, 'units.txt')
     )
     convert_to_wenet_yaml(
-        tokenizer, checkpoint["dims"],
+        tokenizer.tokenizer, checkpoint["dims"],
         os.path.join(download_root, 'train.yaml')
     )
     with open("{}/train.yaml".format(download_root), 'r') as fin:
@@ -132,7 +132,7 @@ def test_model(model, audio_path):
     wenet_model.eval()
 
     with torch.no_grad():
-        dummy_tokens = tokenizer.encode("WeNet x OpenAI")
+        _, dummy_tokens = tokenizer.tokenize("WeNet x OpenAI")
 
         # 3. Forward whisper.encoder
         mel1 = whisper.log_mel_spectrogram(
@@ -173,8 +173,8 @@ def test_model(model, audio_path):
                                    rtol=1e-7, atol=1e-10)
 
         # 4. Forward whisper.decoder
-        whisper_tokens = torch.tensor(list(tokenizer.sot_sequence)
-                                      + [tokenizer.no_timestamps]
+        whisper_tokens = torch.tensor(list(tokenizer.tokenizer.sot_sequence)
+                                      + [tokenizer.tokenizer.no_timestamps]
                                       + dummy_tokens,
                                       dtype=torch.long).unsqueeze(0)  # (B=1, 9)
         whisper_decoder_embed = whisper_model.decoder.token_embedding(whisper_tokens)
@@ -273,10 +273,15 @@ def test_model(model, audio_path):
 
         # 6. Forward wenet.decoder
         wenet_tokens, _ = add_whisper_tokens(
-            tokenizer, torch.tensor([dummy_tokens], dtype=torch.long), ignore_id=-1,
-            task_id=tokenizer.transcribe if task == "transcribe" else tokenizer.translate,  # noqa
-            no_timestamp=True, language=language, use_prev=False
-        )
+            tokenizer.tokenizer,
+            torch.tensor([dummy_tokens], dtype=torch.long),
+            ignore_id=-1,
+            task_id=tokenizer.tokenizer.transcribe
+            if task == "transcribe" else tokenizer.tokenizer.translate,
+            no_timestamp=True,
+            language=language,
+            use_prev=False)
+
         L = wenet_tokens.size(1)
         tgt_mask = ~make_pad_mask(
             torch.tensor([L], dtype=torch.long), L).unsqueeze(1)  # (B=1, 1, L)
diff --git a/wenet/text/whisper_tokenizer.py b/wenet/text/whisper_tokenizer.py
index f1adc41a1..61056a0e7 100644
--- a/wenet/text/whisper_tokenizer.py
+++ b/wenet/text/whisper_tokenizer.py
@@ -1,7 +1,6 @@
 from os import PathLike
 from typing import List, Optional, Tuple, Union
 
 from wenet.text.base_tokenizer import BaseTokenizer
-from whisper.tokenizer import get_tokenizer
 
 from wenet.utils.file_utils import read_non_lang_symbols
@@ -18,6 +17,7 @@ def __init__(
         *args,
         **kwargs,
     ) -> None:
+        from whisper.tokenizer import get_tokenizer
         self.tokenizer = get_tokenizer(multilingual=multilingual,
                                        num_languages=num_languages,
                                        language=language,
diff --git a/wenet/utils/common.py b/wenet/utils/common.py
index 824e77581..c73741a5b 100644
--- a/wenet/utils/common.py
+++ b/wenet/utils/common.py
@@ -20,9 +20,6 @@
 import torch
 from torch.nn.utils.rnn import pad_sequence
 
-from whisper.tokenizer import LANGUAGES as WhiserLanguages
-
-WHISPER_LANGS = tuple(WhiserLanguages.keys())
 
 IGNORE_ID = -1
 
@@ -176,6 +173,8 @@ def add_whisper_tokens(
         ys_out (torch.Tensor) : (B, Lmax + ?)
 
     """
+    from whisper.tokenizer import LANGUAGES as WhiserLanguages
+    WHISPER_LANGS = tuple(WhiserLanguages.keys())
     if use_prev:
         # i.e., hotword list
         _prev = [tokenizer.sot_prev]
diff --git a/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py b/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py
index 45c36d970..443aa07b2 100644
--- a/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py
+++ b/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py
@@ -89,6 +89,9 @@ def convert_to_wenet_yaml(tokenizer, dims, wenet_yaml_path: str):
     configs['decoder_conf']['key_bias'] = False
     configs['decoder_conf']['activation_type'] = "gelu"
 
+    configs['ctc_conf'] = {}
+    configs['ctc_conf']['ctc_blank_id'] = 50362  # <|nospeech|>
+
     configs['model_conf'] = {}
     configs['model_conf']['ctc_weight'] = 0.3
     configs['model_conf']['lsm_weight'] = 0.1
@@ -208,13 +211,12 @@ def convert_to_wenet_units(tokenizer, units_txt_path):
     n_vocab = tokenizer.encoding.n_vocab
     with open(units_txt_path, "+w") as f:
         for i in range(n_vocab):
-            unit = tokenizer.encoding.decode([i])
+            unit = str(tokenizer.encoding.decode_single_token_bytes(i))
             if len(unit) == 0:
                 unit = str(i)
                 print("can not decode id {}, convert to str({})".format(i, i))
             unit = unit.replace(" ", "")
-            unit = bytes(unit, 'utf-8')
-            f.write("{} {}\n".format(str(unit), i))
+            f.write("{} {}\n".format(unit, i))
         f.flush()
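
For context, a minimal usage sketch of the `WhisperTokenizer` wrapper that the test now goes through. The constructor arguments mirror the call in `test_whisper.py` above; per wenet's `BaseTokenizer` interface, `tokenize()` returns a `(tokens, ids)` pair, which is why the test unpacks `_, dummy_tokens`. The concrete argument values below are illustrative, not taken from the diff (`num_languages=99` matches pre-large-v3 multilingual checkpoints).

```python
# Usage sketch, not part of the diff. Assumes wenet and openai-whisper
# are installed; argument values are illustrative.
from wenet.text.whisper_tokenizer import WhisperTokenizer

tokenizer = WhisperTokenizer(True, num_languages=99,
                             language="en", task="transcribe")
tokens, ids = tokenizer.tokenize("WeNet x OpenAI")  # subword strings + their ids

# The raw whisper tokenizer stays reachable via .tokenizer, which is why the
# diff passes tokenizer.tokenizer to helpers that need whisper-specific fields
# such as sot_sequence, no_timestamps, transcribe and translate.
print(tokenizer.tokenizer.sot_sequence, tokenizer.tokenizer.no_timestamps)
```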
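
The import moves in `wenet/text/whisper_tokenizer.py` and `wenet/utils/common.py` are the standard deferred-import pattern: `openai-whisper` is no longer required just to import these modules, only once Whisper functionality is actually exercised. A minimal sketch of the pattern:

```python
# Deferred-import sketch: if whisper is missing, the ImportError surfaces at
# the first call rather than when this module is imported.
def whisper_langs():
    from whisper.tokenizer import LANGUAGES  # resolved on first use, then cached by Python
    return tuple(LANGUAGES.keys())
```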
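
On the `convert_to_wenet_units` change: whisper's `tokenizer.encoding` is a `tiktoken.Encoding`, and a single byte-level BPE token is often an incomplete UTF-8 sequence, so `decode()` maps many distinct ids to the same U+FFFD replacement character and `units.txt` ends up with colliding symbols. `decode_single_token_bytes()` returns the raw bytes, whose `str()` repr (e.g. `b'\xe4\xbd'`) is lossless and unique per id, making the old `bytes(unit, 'utf-8')` round-trip unnecessary. An illustrative sketch using `tiktoken`'s `gpt2` encoding (an assumption standing in for whisper's actual encoding):

```python
# Illustrative sketch, not part of the diff: a lone BPE token need not be
# valid UTF-8, so str-decoding it is lossy while bytes-decoding is not.
import tiktoken

enc = tiktoken.get_encoding("gpt2")
for i in enc.encode("你好"):  # each 3-byte char is split across byte-level tokens
    print(i,
          repr(enc.decode([i])),             # lossy: partial UTF-8 becomes '�'
          enc.decode_single_token_bytes(i))  # lossless raw bytes
```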