[text] add WhisperTokenizer for test_whisper.py
Mddct committed Nov 27, 2023
1 parent 266a4fa commit c2ecc7c
Showing 4 changed files with 25 additions and 19 deletions.
29 changes: 17 additions & 12 deletions test/wenet/whisper/test_whisper.py
@@ -12,10 +12,10 @@
 import numpy as np
 import torch.nn.functional as F

-from whisper.tokenizer import get_tokenizer
 from whisper.audio import N_FFT, HOP_LENGTH, N_SAMPLES, N_FRAMES, pad_or_trim

 from wenet.dataset.processor import compute_log_mel_spectrogram
+from wenet.text.whisper_tokenizer import WhisperTokenizer
 from wenet.transformer.embedding import WhisperPositionalEncoding
 from wenet.whisper.convert_whisper_to_wenet_config_and_ckpt import (
     convert_to_wenet_yaml, convert_to_wenet_state_dict, convert_to_wenet_units
@@ -108,19 +108,19 @@ def test_model(model, audio_path):
     checkpoint = torch.load("{}/{}.pt".format(download_root, model), map_location="cpu")
     multilingual = checkpoint["dims"]['n_vocab'] >= 51865
     num_languages = checkpoint["dims"]['n_vocab'] - 51765 - int(multilingual)
-    tokenizer = get_tokenizer(multilingual, num_languages=num_languages,
-                              language=language, task=task)
+    tokenizer = WhisperTokenizer(multilingual, num_languages=num_languages,
+                                 language=language, task=task)

     convert_to_wenet_state_dict(
         checkpoint["model_state_dict"],
         os.path.join(download_root, 'wenet_whisper.pt')
     )
     convert_to_wenet_units(
-        tokenizer,
+        tokenizer.tokenizer,
         os.path.join(download_root, 'units.txt')
     )
     convert_to_wenet_yaml(
-        tokenizer, checkpoint["dims"],
+        tokenizer.tokenizer, checkpoint["dims"],
         os.path.join(download_root, 'train.yaml')
     )
     with open("{}/train.yaml".format(download_root), 'r') as fin:
@@ -132,7 +132,7 @@ def test_model(model, audio_path):
     wenet_model.eval()

     with torch.no_grad():
-        dummy_tokens = tokenizer.encode("WeNet x OpenAI")
+        _, dummy_tokens = tokenizer.tokenize("WeNet x OpenAI")

         # 3. Forward whisper.encoder
         mel1 = whisper.log_mel_spectrogram(
@@ -173,8 +173,8 @@ def test_model(model, audio_path):
                                     rtol=1e-7, atol=1e-10)

         # 4. Forward whisper.decoder
-        whisper_tokens = torch.tensor(list(tokenizer.sot_sequence)
-                                      + [tokenizer.no_timestamps]
+        whisper_tokens = torch.tensor(list(tokenizer.tokenizer.sot_sequence)
+                                      + [tokenizer.tokenizer.no_timestamps]
                                       + dummy_tokens,
                                       dtype=torch.long).unsqueeze(0) # (B=1, 9)
         whisper_decoder_embed = whisper_model.decoder.token_embedding(whisper_tokens)
@@ -273,10 +273,15 @@ def test_model(model, audio_path):

         # 6. Forward wenet.decoder
         wenet_tokens, _ = add_whisper_tokens(
-            tokenizer, torch.tensor([dummy_tokens], dtype=torch.long), ignore_id=-1,
-            task_id=tokenizer.transcribe if task == "transcribe" else tokenizer.translate, # noqa
-            no_timestamp=True, language=language, use_prev=False
-        )
+            tokenizer.tokenizer,
+            torch.tensor([dummy_tokens], dtype=torch.long),
+            ignore_id=-1,
+            task_id=tokenizer.tokenizer.transcribe
+            if task == "transcribe" else tokenizer.tokenizer.translate,
+            no_timestamp=True,
+            language=language,
+            use_prev=False)
+
         L = wenet_tokens.size(1)
         tgt_mask = ~make_pad_mask(
             torch.tensor([L], dtype=torch.long), L).unsqueeze(1) # (B=1, 1, L)
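For reference, a minimal sketch of the tokenizer usage this test now relies on: the wenet WhisperTokenizer wrapper is constructed instead of calling whisper's get_tokenizer directly, and the raw whisper tokenizer stays reachable through its .tokenizer attribute, as the hunks above show. The sketch assumes openai-whisper and wenet are installed; num_languages=99 matches the original multilingual checkpoints (n_vocab = 51865), and the first element returned by tokenize() is assumed to be the token strings, which the test discards.

    from wenet.text.whisper_tokenizer import WhisperTokenizer

    # Construct the wrapper the way the test does (argument values are illustrative).
    tokenizer = WhisperTokenizer(multilingual=True, num_languages=99,
                                 language="en", task="transcribe")

    # tokenize() returns a pair; only the second element, the token ids, is
    # needed here. They are later concatenated with whisper's special ids.
    _, dummy_tokens = tokenizer.tokenize("WeNet x OpenAI")

    # Whisper-specific special ids still come from the wrapped tokenizer.
    sot_sequence = list(tokenizer.tokenizer.sot_sequence)
    no_timestamps_id = tokenizer.tokenizer.no_timestamps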
2 changes: 1 addition & 1 deletion wenet/text/whisper_tokenizer.py
@@ -1,7 +1,6 @@
 from os import PathLike
 from typing import List, Optional, Tuple, Union
 from wenet.text.base_tokenizer import BaseTokenizer
-from whisper.tokenizer import get_tokenizer

 from wenet.utils.file_utils import read_non_lang_symbols

@@ -18,6 +17,7 @@ def __init__(
         *args,
         **kwargs,
     ) -> None:
+        from whisper.tokenizer import get_tokenizer
         self.tokenizer = get_tokenizer(multilingual=multilingual,
                                        num_languages=num_languages,
                                        language=language,
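The only change in this file is moving the whisper import from module scope into __init__. A short sketch of what that deferred import presumably buys, treating openai-whisper as an optional dependency: the module can be imported without the package, and an ImportError only surfaces once a tokenizer is actually constructed. The constructor arguments mirror the test's call and are illustrative.

    import importlib

    # Importing the module itself no longer pulls in openai-whisper.
    mod = importlib.import_module("wenet.text.whisper_tokenizer")

    try:
        # The deferred "from whisper.tokenizer import get_tokenizer" runs here.
        tok = mod.WhisperTokenizer(multilingual=True, num_languages=99,
                                   language="en", task="transcribe")
    except ImportError as err:
        print("openai-whisper is not installed:", err)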
5 changes: 2 additions & 3 deletions wenet/utils/common.py
@@ -20,9 +20,6 @@
 import torch
 from torch.nn.utils.rnn import pad_sequence

-from whisper.tokenizer import LANGUAGES as WhiserLanguages
-
-WHISPER_LANGS = tuple(WhiserLanguages.keys())
 IGNORE_ID = -1


@@ -176,6 +173,8 @@ def add_whisper_tokens(
         ys_out (torch.Tensor) : (B, Lmax + ?)
     """
+    from whisper.tokenizer import LANGUAGES as WhiserLanguages
+    WHISPER_LANGS = tuple(WhiserLanguages.keys())
     if use_prev:
         # i.e., hotword list
         _prev = [tokenizer.sot_prev]
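Same deferral here: LANGUAGES is imported and WHISPER_LANGS built inside add_whisper_tokens, so wenet.utils.common now imports without whisper installed. For completeness, a hedged usage sketch of the function as the updated test calls it; it assumes the function returns the (ys_in, ys_out) pair described in its docstring and that the first argument is now the raw whisper tokenizer rather than the wenet wrapper.

    import torch
    from wenet.text.whisper_tokenizer import WhisperTokenizer
    from wenet.utils.common import add_whisper_tokens

    tokenizer = WhisperTokenizer(multilingual=True, num_languages=99,
                                 language="en", task="transcribe")
    _, label_ids = tokenizer.tokenize("WeNet x OpenAI")

    # Build decoder input/label sequences with whisper-style special tokens
    # added around the labels; ignore_id=-1 marks positions excluded from loss.
    ys_in, ys_out = add_whisper_tokens(
        tokenizer.tokenizer,                      # raw whisper tokenizer
        torch.tensor([label_ids], dtype=torch.long),
        ignore_id=-1,
        task_id=tokenizer.tokenizer.transcribe,   # or .translate
        no_timestamp=True,
        language="en",
        use_prev=False)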
8 changes: 5 additions & 3 deletions wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py
@@ -89,6 +89,9 @@ def convert_to_wenet_yaml(tokenizer, dims, wenet_yaml_path: str):
     configs['decoder_conf']['key_bias'] = False
     configs['decoder_conf']['activation_type'] = "gelu"

+    configs['ctc_conf'] = {}
+    configs['ctc_conf']['ctc_blank_id'] = 50362 # <nospeech>
+
     configs['model_conf'] = {}
     configs['model_conf']['ctc_weight'] = 0.3
     configs['model_conf']['lsm_weight'] = 0.1
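The new ctc_conf hard-codes blank id 50362, which the inline comment identifies as whisper's <|nospeech|> token in the multilingual vocabulary. A quick, hedged sanity check, assuming openai-whisper is installed and exposes that id as the tokenizer's no_speech attribute:

    from whisper.tokenizer import get_tokenizer

    tok = get_tokenizer(multilingual=True)
    # For the original multilingual checkpoints the <|nospeech|> id is 50362,
    # i.e. the value written into train.yaml as the CTC blank above.
    assert tok.no_speech == 50362, tok.no_speech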
@@ -208,13 +211,12 @@ def convert_to_wenet_units(tokenizer, units_txt_path):
     n_vocab = tokenizer.encoding.n_vocab
     with open(units_txt_path, "+w") as f:
         for i in range(n_vocab):
-            unit = tokenizer.encoding.decode([i])
+            unit = str(tokenizer.encoding.decode_single_token_bytes(i))
             if len(unit) == 0:
                 unit = str(i)
                 print("can not decode id {}, convert to str({})".format(i, i))
             unit = unit.replace(" ", "<space>")
-            unit = bytes(unit, 'utf-8')
-            f.write("{} {}\n".format(str(unit), i))
+            f.write("{} {}\n".format(unit, i))
             f.flush()


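The unit-table conversion now goes through tiktoken's decode_single_token_bytes() instead of decode(). A hedged illustration of the difference, assuming openai-whisper (whose tokenizer exposes the underlying tiktoken Encoding as .encoding): decode() maps byte pieces that are not valid UTF-8 to the replacement character, so different ids can collide or, as the len(unit) == 0 check above anticipates, come back empty, whereas the bytes variant gives a distinct printable repr for every id.

    from whisper.tokenizer import get_tokenizer

    enc = get_tokenizer(multilingual=True).encoding  # a tiktoken Encoding

    token_id = 150  # an arbitrary id, purely illustrative
    lossy = enc.decode([token_id])                        # str; bad UTF-8 -> U+FFFD
    exact = str(enc.decode_single_token_bytes(token_id))  # e.g. "b'\xc3\xa9'"
    print(repr(lossy), exact)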
