Skip to content

Commit

Permalink
refactor(whisper): remove tokenizer in WhisperModel (#2172)
Browse files Browse the repository at this point in the history
  • Loading branch information
xingchensong authored Nov 27, 2023
1 parent d9fb33c commit 16daf5d
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 25 deletions.
6 changes: 3 additions & 3 deletions test/wenet/whisper/test_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,9 @@ def test_model(model, audio_path):

# 6. Forward wenet.decoder
wenet_tokens, _ = add_whisper_tokens(
tokenizer, torch.tensor([dummy_tokens], dtype=torch.long), ignore_id=-1,
task_id=tokenizer.transcribe if task == "transcribe" else tokenizer.translate, # noqa
no_timestamp=True, language=language, use_prev=False
configs['model_conf']['special_tokens'],
torch.tensor([dummy_tokens], dtype=torch.long), ignore_id=-1,
task=task, no_timestamp=True, language=language, use_prev=False
)
L = wenet_tokens.size(1)
tgt_mask = ~make_pad_mask(
Expand Down
34 changes: 21 additions & 13 deletions wenet/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ def add_sos_eos(ys_pad: torch.Tensor, sos: int, eos: int,


def add_whisper_tokens(
tokenizer, ys_pad: torch.Tensor,
ignore_id: int, task_id: int, no_timestamp: bool,
special_tokens, ys_pad: torch.Tensor,
ignore_id: int, task: str, no_timestamp: bool,
language: str, use_prev: bool
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Add whisper-style tokens.
Expand All @@ -166,7 +166,7 @@ def add_whisper_tokens(
|--> [no speech(VAD)] ---------------------------------------------------------------------->| # noqa
Args:
tokenizer: get IDs of special tokens
special_tokens: get IDs of special tokens
ignore_id (int): index of padding
no_timestamp (bool): whether to add timestamps tokens
language (str): language tag
Expand All @@ -178,39 +178,47 @@ def add_whisper_tokens(
"""
if use_prev:
# i.e., hotword list
_prev = [tokenizer.sot_prev]
_prev = [special_tokens["sot_prev"]]
# append hotword list to _prev
# ...
raise NotImplementedError
else:
_prev = []

language_id = tokenizer.sot + 1 + WHISPER_LANGS.index(language)
_sot = _prev + [tokenizer.sot, language_id, task_id]
_eot = torch.tensor([tokenizer.eot],
language_id = special_tokens["sot"] + 1 + WHISPER_LANGS.index(language)
if task == "transcribe":
task_id = special_tokens["transcribe"]
elif task == "translate":
task_id = special_tokens["translate"]
elif task == "vad":
task_id = special_tokens["no_speech"]
else:
raise NotImplementedError("unsupported task {}".format(task))
_sot = _prev + [special_tokens["sot"], language_id, task_id]
_eot = torch.tensor([special_tokens["eot"]],
dtype=torch.long,
requires_grad=False,
device=ys_pad.device)
ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys

if task_id == tokenizer.transcribe or task_id == tokenizer.translate:
if task == "transcribe" or task == "translate":
if no_timestamp:
_sot.append(tokenizer.no_timestamps)
_sot.append(special_tokens["no_timestamps"])
else:
_sot.append(tokenizer.timestamp_begin)
_sot.append(special_tokens["timestamp_begin"])
# add subsequent tokens
# ...
raise NotImplementedError
elif task_id == tokenizer.no_speech:
_sot.append(tokenizer.no_speech)
elif task == "vad":
_sot.append(special_tokens["no_speech"])
else:
raise NotImplementedError

_sot = torch.tensor(_sot, dtype=torch.long,
requires_grad=False, device=ys_pad.device)
ys_in = [torch.cat([_sot, y], dim=0) for y in ys]
ys_out = [torch.cat([_sot[1:], y, _eot], dim=0) for y in ys]
return pad_list(ys_in, tokenizer.eot), pad_list(ys_out, ignore_id)
return pad_list(ys_in, special_tokens["eot"]), pad_list(ys_out, ignore_id)


def reverse_pad_list(ys_pad: torch.Tensor,
Expand Down
12 changes: 11 additions & 1 deletion wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,22 @@ def convert_to_wenet_yaml(tokenizer, dims, wenet_yaml_path: str):
configs['decoder_conf']['activation_type'] = "gelu"

configs['ctc_conf'] = {}
configs['ctc_conf']['ctc_blank_id'] = 50362 # <nospeech>
configs['ctc_conf']['ctc_blank_id'] = tokenizer.no_speech

configs['model_conf'] = {}
configs['model_conf']['ctc_weight'] = 0.3
configs['model_conf']['lsm_weight'] = 0.1
configs['model_conf']['length_normalized_loss'] = False
# Persist the tokenizer's special-token IDs in the model config so that
# consumers can look them up by name without holding a tokenizer instance.
configs['model_conf']['special_tokens'] = {}
configs['model_conf']['special_tokens']['sot'] = tokenizer.sot
# 'eot' must carry the end-of-transcript ID; it is read back as the EOS /
# padding ID, so storing tokenizer.sot here would alias EOT to SOT.
configs['model_conf']['special_tokens']['eot'] = tokenizer.eot
configs['model_conf']['special_tokens']['sot_prev'] = tokenizer.sot_prev
configs['model_conf']['special_tokens']['transcribe'] = tokenizer.transcribe
configs['model_conf']['special_tokens']['translate'] = tokenizer.translate
configs['model_conf']['special_tokens']['no_timestamps'] = tokenizer.no_timestamps
configs['model_conf']['special_tokens']['no_speech'] = tokenizer.no_speech
configs['model_conf']['special_tokens']['timestamp_begin'] = \
tokenizer.timestamp_begin

configs['dataset_conf'] = {}
configs['dataset_conf']['filter_conf'] = {}
Expand Down
13 changes: 5 additions & 8 deletions wenet/whisper/whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import torch

from typing import Tuple
from whisper.tokenizer import get_tokenizer

from wenet.transformer.asr_model import ASRModel
from wenet.transformer.ctc import CTC
Expand All @@ -38,16 +37,14 @@ def __init__(
reverse_weight: float = 0.0,
lsm_weight: float = 0.0,
length_normalized_loss: bool = False,
special_tokens: dict = None,
):
super().__init__(vocab_size, encoder, decoder, ctc, ctc_weight, ignore_id,
reverse_weight, lsm_weight, length_normalized_loss)
self.tokenizer = get_tokenizer(multilingual=self.is_multilingual,
num_languages=self.num_languages)
assert vocab_size == self.tokenizer.encoding.n_vocab, "{} v.s. {}".format(
vocab_size, self.tokenizer.encoding.n_vocab)
assert reverse_weight == 0.0
self.sos = self.tokenizer.sot
self.eos = self.tokenizer.eot
self.sos = special_tokens["sot"]
self.eos = special_tokens["eot"]
self.special_tokens = special_tokens

# TODO(xcsong): time align
def set_alignment_heads(self, dump: bytes):
Expand Down Expand Up @@ -75,7 +72,7 @@ def _calc_att_loss(
# TODO(xcsong): add args for no_timestamp, language, etc
prev_len = ys_pad.size(1)
ys_in_pad, ys_out_pad = add_whisper_tokens(
self.tokenizer, ys_pad, self.ignore_id, task_id=self.tokenizer.transcribe,
self.special_tokens, ys_pad, self.ignore_id, task="transcribe",
no_timestamp=True, language="zh", use_prev=False
)
cur_len = ys_in_pad.size(1)
Expand Down

0 comments on commit 16daf5d

Please sign in to comment.