diff --git a/runtime/binding/python/py/decoder.py b/runtime/binding/python/py/decoder.py
index 47055d042..6ed782cf5 100644
--- a/runtime/binding/python/py/decoder.py
+++ b/runtime/binding/python/py/decoder.py
@@ -12,18 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List
+from typing import List, Optional
 
 import _wenet
 
+from .hub import Hub
+
 
 class Decoder:
+
     def __init__(self,
-                 model_dir: str,
+                 model_dir: Optional[str] = None,
                  lang: str = 'chs',
                  nbest: int = 1,
                  enable_timestamp: bool = False,
-                 context: List[str] = None,
+                 context: Optional[List[str]] = None,
                  context_score: float = 3.0):
         """ Init WeNet decoder
         Args:
@@ -34,7 +37,11 @@ def __init__(self,
             context: context words
             context_score: bonus score when the context is matched
         """
+        if model_dir is None:
+            model_dir = Hub.get_model_by_lang(lang)
+
         self.d = _wenet.wenet_init(model_dir)
+
         self.set_language(lang)
         self.set_nbest(nbest)
         self.enable_timestamp(enable_timestamp)
diff --git a/runtime/binding/python/py/hub.py b/runtime/binding/python/py/hub.py
new file mode 100644
index 000000000..03e5b9f68
--- /dev/null
+++ b/runtime/binding/python/py/hub.py
@@ -0,0 +1,128 @@
+# Copyright (c) 2022 Mddct(hamddct@gmail.com)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import shutil
+import tarfile
+from pathlib import Path
+from urllib.request import urlretrieve
+
+import tqdm
+
+
+def download(url: str, dest: str, only_child=True):
+    """Download the tar archive at ``url`` into ``dest`` and unpack it.
+
+    Args:
+        url: address of the archive (``*.tar.gz``)
+        dest: existing directory receiving the archive and its content
+        only_child: if True, copy each regular file stored below the
+            archive's top level straight into ``dest`` (the directory
+            layout is flattened); if False, unpack the archive as is
+
+    Raises:
+        FileNotFoundError: if ``dest`` is not an existing directory
+    """
+    if not os.path.isdir(dest):
+        raise FileNotFoundError(f"download directory not found: {dest}")
+
+    def progress_hook(t):
+        # reporthook receives the running block count, tqdm wants a delta
+        last_b = [0]
+
+        def update_to(b=1, bsize=1, tsize=None):
+            if tsize not in (None, -1):
+                t.total = tsize
+            displayed = t.update((b - last_b[0]) * bsize)
+            last_b[0] = b
+            return displayed
+
+        return update_to
+
+    # *.tar.gz
+    name = url.split("/")[-1]
+    tar_path = os.path.join(dest, name)
+    with tqdm.tqdm(unit='B',
+                   unit_scale=True,
+                   unit_divisor=1024,
+                   miniters=1,
+                   desc=name) as t:
+        urlretrieve(url,
+                    filename=tar_path,
+                    reporthook=progress_hook(t),
+                    data=None)
+        t.total = t.n
+
+    with tarfile.open(tar_path) as f:
+        if not only_child:
+            f.extractall(dest)
+        else:
+            for tarinfo in f:
+                # only regular files below the archive's top level
+                if "/" not in tarinfo.name or not tarinfo.isfile():
+                    continue
+                # TarFile.extract() would recreate the member's full path
+                # under its target directory, so copy the payload to
+                # dest/basename instead
+                name = os.path.basename(tarinfo.name)
+                with f.extractfile(tarinfo) as src, \
+                        open(os.path.join(dest, name), "wb") as dst:
+                    shutil.copyfileobj(src, dst)
+
+
+class Hub(object):
+    """Hub for wenet pretrain runtime model
+    """
+    # TODO(Mddct): make assets class to support other language
+    Assets = {
+        # wenetspeech
+        "chs":
+        "https://wenet-1256283475.cos.ap-shanghai.myqcloud.com/models/wenetspeech/20220506_u2pp_conformer_libtorch.tar.gz",
+        # gigaspeech
+        "en":
+        "https://wenet-1256283475.cos.ap-shanghai.myqcloud.com/models/gigaspeech/20210728_u2pp_conformer_libtorch.tar.gz"
+    }
+
+    def __init__(self) -> None:
+        pass
+
+    @staticmethod
+    def get_model_by_lang(lang: str) -> str:
+        """Return the local model directory for ``lang``.
+
+        The model is downloaded into ``~/.wenet/lang`` unless that
+        directory already contains ``final.zip`` and ``units.txt``.
+
+        Raises:
+            ValueError: if ``lang`` has no entry in ``Hub.Assets``
+        """
+        if lang not in Hub.Assets:
+            raise ValueError(f"unsupported lang '{lang}', "
+                             f"expected one of {sorted(Hub.Assets)}")
+        # NOTE(Mddct): model_dir structure
+        # Path.home()/.wenet
+        # - chs
+        #    - units.txt
+        #    - final.zip
+        # - en
+        #    - units.txt
+        #    - final.zip
+        model_url = Hub.Assets[lang]
+        model_dir = os.path.join(Path.home(), ".wenet", lang)
+        os.makedirs(model_dir, exist_ok=True)
+        # TODO(Mddct): model metadata
+        if {"final.zip", "units.txt"}.issubset(os.listdir(model_dir)):
+            return model_dir
+        download(model_url, model_dir, only_child=True)
+        return model_dir
diff --git a/runtime/binding/python/setup.py b/runtime/binding/python/setup.py
index 92b02521a..744e76f95 100644
--- a/runtime/binding/python/setup.py
+++ b/runtime/binding/python/setup.py
@@ -48,11 +48,13 @@ def build_extension(self, ext: setuptools.extension.Extension):
         libs = []
         torch_lib = 'fc_base/libtorch-src/lib'
         for ext in ['so', 'pyd']:
-            libs.extend(glob.glob(
-                f"{self.build_temp}/**/_wenet*.{ext}", recursive=True))
+            libs.extend(
+                glob.glob(f"{self.build_temp}/**/_wenet*.{ext}",
+                          recursive=True))
         for ext in ['so', 'dylib', 'dll']:
-            libs.extend(glob.glob(
-                f"{self.build_temp}/**/*wenet_api.{ext}", recursive=True))
+            libs.extend(
+                glob.glob(f"{self.build_temp}/**/*wenet_api.{ext}",
+                          recursive=True))
             libs.extend(glob.glob(f'{src_dir}/{torch_lib}/*c10.{ext}'))
             libs.extend(glob.glob(f'{src_dir}/{torch_lib}/*torch_cpu.{ext}'))
 
@@ -95,6 +97,8 @@ def read_long_description():
     ext_modules=[cmake_extension("_wenet")],
     cmdclass={"build_ext": BuildExtension},
     zip_safe=False,
+    setup_requires=["tqdm"],
+    install_requires=["tqdm"],
     classifiers=[
         "Programming Language :: C++",
         "Programming Language :: Python",