[text] fix whisper tokens and others #2179

Merged: 2 commits, Nov 29, 2023
Changes from all commits
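In brief: this PR routes Dataset construction through a tokenizer object built by init_tokenizer (replacing the raw symbol-table, bpe_model, and non_lang_syms arguments), swaps the removed init_asr_model for init_model, widens BpeTokenizer's bpe_model annotation to accept plain strings, and makes WhisperTokenizer picklable even after its lazily built tiktoken backend exists, with new parallel-tokenization tests covering that case.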
10 changes: 5 additions & 5 deletions examples/vkw2021/s0/local/vkw_kws_results.py
@@ -24,8 +24,8 @@
 from torch.utils.data import DataLoader

 from wenet.dataset.dataset import Dataset
-from wenet.transformer.asr_model import init_asr_model
 from wenet.utils.checkpoint import load_checkpoint
+from wenet.utils.init_model import init_model
+from wenet.utils.init_tokenizer import init_tokenizer

 from wenet.utils.common import get_subsample
 from wenet.utils.common import remove_duplicates_and_blank
@@ -186,11 +186,11 @@ def get_labformat_frames(timestamp, subsample, char_dict):
     cv_conf['speed_perturb'] = False
     cv_conf['spec_aug'] = False

+    tokenizer = init_tokenizer(cv_conf, args.symbol_table, args.bpe_model)
     cv_dataset = Dataset(args.data_type,
                          args.input_data,
-                         symbol_table,
+                         tokenizer,
                          cv_conf,
-                         None,
                          partition=False)

     cv_data_loader = DataLoader(cv_dataset,

@@ -205,7 +205,7 @@ def get_labformat_frames(timestamp, subsample, char_dict):
print("word_unit_list has the size of %d" % (len(word_unit_list)))

# Init asr model from configs
model = init_asr_model(configs)
model, configs = init_model(args, configs)
load_checkpoint(model, args.checkpoint)
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
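For reference, init_tokenizer itself is not part of this diff. Judging only from the call sites in this PR, it is a factory roughly like the sketch below; the branch conditions and config keys are assumptions, and the real implementation in wenet/utils/init_tokenizer.py may differ.

# Sketch reconstructed from the call sites in this PR; the 'whisper' config
# key and the per-branch tokenizer choices are assumptions, not shown here.
from wenet.text.bpe_tokenizer import BpeTokenizer
from wenet.text.char_tokenizer import CharTokenizer
from wenet.text.whisper_tokenizer import WhisperTokenizer


def init_tokenizer(configs, symbol_table, bpe_model=None, non_lang_syms=None):
    if configs.get('whisper', False):
        # assumed config layout for selecting the Whisper tokenizer
        return WhisperTokenizer(configs['whisper_conf']['is_multilingual'])
    if bpe_model is not None:
        return BpeTokenizer(bpe_model, symbol_table, non_lang_syms)
    return CharTokenizer(symbol_table, non_lang_syms)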
36 changes: 36 additions & 0 deletions test/wenet/text/test_parallel.py
@@ -25,6 +25,23 @@ def test_whisper_tokenzier_parallel():
     assert all(h == r for (h, r) in zip(results, inputs))


+def test_whisper_tokenzier_parallel_after_property():
+
+    inputs = ["it's ok", "wenet is simple", "test for new io"]
+    tokenizer = WhisperTokenizer(False)
+
+    _ = tokenizer.vocab_size
+    _ = tokenizer.symbol_table
+    partial_tokenize = partial(consistency, tokenizer)
+    with Pool(processes=len(inputs)) as pool:
+        results = pool.map(partial_tokenize, inputs)
+
+    inputs.sort()
+    results.sort()
+
+    assert all(h == r for (h, r) in zip(results, inputs))
+
+
 def test_bpe_tokenzier_parallel():

     symbol_table_path = "test/resources/librispeech.words.txt"

@@ -40,3 +57,22 @@ def test_bpe_tokenzier_parallel():
     results.sort()

     assert all(h == r for (h, r) in zip(results, inputs))
+
+
+def test_bpe_tokenizer_parallel_after_property():
+    symbol_table_path = "test/resources/librispeech.words.txt"
+    bpe_model = "test/resources/librispeech.train_960_unigram5000.bpemodel"
+
+    inputs = ["WENR IS SIMPLE", "GOOD"]
+    tokenizer = BpeTokenizer(bpe_model, symbol_table_path)
+    _ = tokenizer.vocab_size
+    _ = tokenizer.symbol_table
+
+    partial_tokenize = partial(consistency, tokenizer)
+    with Pool(processes=len(inputs)) as pool:
+        results = pool.map(partial_tokenize, inputs)
+
+    inputs.sort()
+    results.sort()
+
+    assert all(h == r for (h, r) in zip(results, inputs))
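The two *_after_property tests target the regression fixed in wenet/text/whisper_tokenizer.py below: reading vocab_size or symbol_table lazily builds the underlying tiktoken object, which cannot be pickled, so pool.map would fail once a property had been touched. The consistency helper is defined earlier in the file, outside the hunks shown; it presumably round-trips a line through the tokenizer, along these lines (a hypothetical sketch, not part of this diff):

# Hypothetical sketch of the shared test helper: encode a line to ids,
# decode it back, and return the text so it can be compared to the input.
def consistency(tokenizer, line: str) -> str:
    _, ids = tokenizer.tokenize(line)
    text, _ = tokenizer.detokenize(ids)
    return text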
9 changes: 5 additions & 4 deletions tools/onnx2horizonbin.py
@@ -51,6 +51,7 @@
 from wenet.utils.checkpoint import load_checkpoint
 from wenet.utils.file_utils import read_symbol_table
 from wenet.utils.init_model import init_model
+from wenet.utils.init_tokenizer import init_tokenizer
 from wenet.bin.export_onnx_cpu import to_numpy
 from wenet.bin.export_onnx_bpu import export_encoder, export_ctc

@@ -80,9 +81,9 @@ def save_data(tensor, dirs, prefix):
 def make_calibration_data(enc, args, conf):
     conf['shuffle'] = True
     logger.info(conf)
+    tokenizer = init_tokenizer(conf, args.symbol_table, args.bpe_model)
     dataset = Dataset(
-        "shard", args.cali_datalist, args.symbol_table, conf,
-        bpe_model=args.bpe_model, non_lang_syms=None, partition=False)
+        "shard", args.cali_datalist, tokenizer, conf, partition=False)
     dataloader = DataLoader(dataset, batch_size=None, num_workers=0)

     subsampling = enc.embed.subsampling_rate

@@ -148,9 +149,9 @@ def make_calibration_data(enc, args, conf):

 def check_wer(enc, ctc, args, conf):
     conf['shuffle'] = False
+    tokenizer = init_tokenizer(conf, args.symbol_table, args.bpe_model)
     dataset = Dataset(
-        "shard", args.wer_datalist, args.symbol_table, conf,
-        bpe_model=args.bpe_model, non_lang_syms=None, partition=False)
+        "shard", args.wer_datalist, tokenizer, conf, partition=False)
     dataloader = DataLoader(dataset, batch_size=None, num_workers=0)
     char_dict = {v: k for k, v in args.symbol_table.items()}
     eos = len(char_dict) - 1
9 changes: 3 additions & 6 deletions wenet/bin/alignment.py
@@ -28,10 +28,10 @@
 import math

 from wenet.dataset.dataset import Dataset
-from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols
 from wenet.utils.ctc_utils import force_align
 from wenet.utils.common import get_subsample
 from wenet.utils.init_model import init_model
+from wenet.utils.init_tokenizer import init_tokenizer


 def generator_textgrid(maxtime, lines, output):

@@ -183,7 +183,6 @@ def get_labformat(timestamp, subsample):
             char_dict[int(arr[1])] = arr[0]
     eos = len(char_dict) - 1

-    symbol_table = read_symbol_table(args.dict)

     # Init dataset and data loader
     ali_conf = copy.deepcopy(configs['dataset_conf'])

@@ -202,14 +201,12 @@
     ali_conf['fbank_conf']['dither'] = 0.0
     ali_conf['batch_conf']['batch_type'] = "static"
     ali_conf['batch_conf']['batch_size'] = args.batch_size
-    non_lang_syms = read_non_lang_symbols(args.non_lang_syms)

+    tokenizer = init_tokenizer(ali_conf, args.dict, args.bpe_model, args.non_lang_syms)
     ali_dataset = Dataset(args.data_type,
                           args.input_file,
-                          symbol_table,
+                          tokenizer,
                           ali_conf,
-                          args.bpe_model,
-                          non_lang_syms,
                           partition=False)

     ali_data_loader = DataLoader(ali_dataset, batch_size=None, num_workers=0)
7 changes: 3 additions & 4 deletions wenet/bin/recognize_onnx_gpu.py
@@ -47,8 +47,8 @@

 from wenet.dataset.dataset import Dataset
 from wenet.utils.common import IGNORE_ID
-from wenet.utils.file_utils import read_symbol_table
 from wenet.utils.config import override_config
+from wenet.utils.init_tokenizer import init_tokenizer

 import onnxruntime as rt
 import multiprocessing

@@ -118,7 +118,6 @@ def main():
     configs = override_config(configs, args.override_config)

     reverse_weight = configs["model_conf"].get("reverse_weight", 0.0)
-    symbol_table = read_symbol_table(args.dict)
     test_conf = copy.deepcopy(configs['dataset_conf'])
     test_conf['filter_conf']['max_length'] = 102400
     test_conf['filter_conf']['min_length'] = 0

@@ -136,11 +135,11 @@ def main():
     test_conf['batch_conf']['batch_type'] = "static"
     test_conf['batch_conf']['batch_size'] = args.batch_size

+    tokenizer = init_tokenizer(test_conf, args.dict, args.bpe_model)
     test_dataset = Dataset(args.data_type,
                            args.test_data,
-                           symbol_table,
+                           tokenizer,
                            test_conf,
-                           args.bpe_model,
                            partition=False)

     test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
2 changes: 1 addition & 1 deletion wenet/text/bpe_tokenizer.py
@@ -8,7 +8,7 @@ class BpeTokenizer(CharTokenizer):

     def __init__(
         self,
-        bpe_model: PathLike,
+        bpe_model: Union[PathLike, str],
         symbol_table: Union[str, PathLike, Dict],
         non_lang_syms: Optional[Union[str, PathLike, List]] = None,
         split_with_space: bool = False,
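The widened annotation matches how the tests above construct the tokenizer with plain str paths, for example:

from wenet.text.bpe_tokenizer import BpeTokenizer

# Paths taken from the tests above; a str (not just os.PathLike) now
# satisfies the bpe_model annotation.
tokenizer = BpeTokenizer("test/resources/librispeech.train_960_unigram5000.bpemodel",
                         "test/resources/librispeech.words.txt")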
11 changes: 11 additions & 0 deletions wenet/text/whisper_tokenizer.py
@@ -33,6 +33,16 @@ def __init__(
         # TODO(Mddct): add special tokens, like non_lang_syms
         del self.non_lang_syms

+    def __getstate__(self):
+        state = self.__dict__.copy()
+        del state['tokenizer']
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        recovery = {'tokenizer': None}
+        self.__dict__.update(recovery)
+
     def _build_tiktoken(self):
         if self.tokenizer is None:
             from whisper.tokenizer import get_tokenizer

@@ -87,6 +97,7 @@ def vocab_size(self) -> int:
         self._build_tiktoken()
         return len(self.t2i)

+    @property
     def symbol_table(self) -> Dict[str, int]:
         self._build_tiktoken()
         return self.t2i
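Net effect of the two changes to this file: a WhisperTokenizer now survives pickling even after a property access has built its tiktoken backend, which is what multiprocessing.Pool relies on in the tests above. A minimal check of that behavior, assuming the whisper package is installed (the positional False is the multilingual flag, as in the tests):

import pickle

from wenet.text.whisper_tokenizer import WhisperTokenizer

tokenizer = WhisperTokenizer(False)            # non-multilingual, as in the tests
_ = tokenizer.vocab_size                       # property access builds the tiktoken backend
clone = pickle.loads(pickle.dumps(tokenizer))  # __getstate__ drops the unpicklable 'tokenizer'
assert clone.tokenizer is None                 # __setstate__ restores it as None
_ = clone.symbol_table                         # rebuilt lazily by _build_tiktoken()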