diff --git a/runtime/binding/python/README.md b/runtime/binding/python/README.md index 30d9c84b4..238d54f1f 100644 --- a/runtime/binding/python/README.md +++ b/runtime/binding/python/README.md @@ -91,7 +91,7 @@ with wave.open(test_wav, 'rb') as fin: assert fin.getnchannels() == 1 wav = fin.readframes(fin.getnframes()) -decoder = wenet.Decoder(lang='chs') +decoder = wenet.Decoder(lang='chs', streaming=True) # We suppose the wav is 16k, 16bits, and decode every 0.5 seconds interval = int(0.5 * 16000) * 2 for i in range(0, len(wav), interval): diff --git a/runtime/binding/python/cpp/binding.cc b/runtime/binding/python/cpp/binding.cc index 42578f211..e3dbb3e02 100644 --- a/runtime/binding/python/cpp/binding.cc +++ b/runtime/binding/python/cpp/binding.cc @@ -18,7 +18,6 @@ namespace py = pybind11; - PYBIND11_MODULE(_wenet, m) { m.doc() = "wenet pybind11 plugin"; // optional module docstring m.def("wenet_init", &wenet_init, py::return_value_policy::reference, @@ -36,4 +35,6 @@ PYBIND11_MODULE(_wenet, m) { m.def("wenet_set_language", &wenet_set_language, "set language"); m.def("wenet_set_continuous_decoding", &wenet_set_continuous_decoding, "enable continuous decoding or not"); + m.def("wenet_set_chunk_size", &wenet_set_chunk_size, + "set decoding chunk size"); } diff --git a/runtime/binding/python/wenetruntime/decoder.py b/runtime/binding/python/wenetruntime/decoder.py index c0e835f33..20c1f7ca0 100644 --- a/runtime/binding/python/wenetruntime/decoder.py +++ b/runtime/binding/python/wenetruntime/decoder.py @@ -34,7 +34,8 @@ def __init__(self, enable_timestamp: bool = False, context: Optional[List[str]] = None, context_score: float = 3.0, - continuous_decoding: bool = False): + continuous_decoding: bool = False, + streaming: bool = False): """ Init WeNet decoder Args: lang: language type of the model @@ -44,6 +45,7 @@ def __init__(self, context: context words context_score: bonus score when the context is matched continuous_decoding: enable countinous decoding or not + 
streaming: decode in streaming mode (chunk size 16) instead of non-streaming mode (chunk size -1) """ if model_dir is None: model_dir = Hub.get_model_by_lang(lang) @@ -57,6 +59,8 @@ def __init__(self, self.add_context(context) self.set_context_score(context_score) self.set_continuous_decoding(continuous_decoding) + chunk_size = 16 if streaming else -1 + self.set_chunk_size(chunk_size) def __del__(self): _wenet.wenet_free(self.d) @@ -90,6 +94,9 @@ def set_continuous_decoding(self, continuous_decoding: bool): flag = 1 if continuous_decoding else 0 _wenet.wenet_set_continuous_decoding(self.d, flag) + def set_chunk_size(self, chunk_size: int): + _wenet.wenet_set_chunk_size(self.d, chunk_size) + def decode(self, audio: Union[str, bytes, np.ndarray], last: bool = True) -> str: diff --git a/runtime/binding/python/wenetruntime/main.py b/runtime/binding/python/wenetruntime/main.py index 26fbf16c5..5c49d149b 100644 --- a/runtime/binding/python/wenetruntime/main.py +++ b/runtime/binding/python/wenetruntime/main.py @@ -15,6 +15,7 @@ import argparse from wenetruntime.decoder import Decoder +from _wenet import wenet_set_log_level as set_log_level # noqa def get_args(): @@ -23,6 +24,16 @@ def get_args(): default='chs', choices=['chs', 'en'], help='select language') + parser.add_argument('-c', + '--chunk_size', + default=-1, + type=int, + help='set decoding chunk size') + parser.add_argument('-v', + '--verbose', + default=0, + type=int, + help='set log(glog backend) level') parser.add_argument('audio', help='input audio file') args = parser.parse_args() return args @@ -30,7 +41,9 @@ def get_args(): def main(): args = get_args() + set_log_level(args.verbose) decoder = Decoder(lang=args.language) + decoder.set_chunk_size(args.chunk_size) result = decoder.decode(args.audio) print(result) diff --git a/runtime/core/api/wenet_api.cc b/runtime/core/api/wenet_api.cc index 0852a079b..ff5f98111 100644 --- a/runtime/core/api/wenet_api.cc +++ b/runtime/core/api/wenet_api.cc @@ -109,6 +109,8 @@ class Recognizer { } resource_->post_processor = 
std::make_shared(*post_process_opts_); + // Init decode options + decode_options_->chunk_size = chunk_size_; // Init decoder decoder_ = std::make_shared(feature_pipeline_, resource_, *decode_options_); @@ -180,6 +182,7 @@ class Recognizer { void set_context_score(float score) { context_score_ = score; } void set_language(const char* lang) { language_ = lang; } void set_continuous_decoding(bool flag) { continuous_decoding_ = flag; } + void set_chunk_size(int chunk_size) { chunk_size_ = chunk_size; } private: // NOTE(Binbin Zhang): All use shared_ptr for clone in the future @@ -197,6 +200,7 @@ class Recognizer { float context_score_; std::string language_ = "chs"; bool continuous_decoding_ = false; + int chunk_size_ = 16; }; void* wenet_init(const char* model_dir) { @@ -255,3 +259,8 @@ void wenet_set_continuous_decoding(void* decoder, int flag) { Recognizer* recognizer = reinterpret_cast<Recognizer*>(decoder); recognizer->set_continuous_decoding(flag > 0); } + +void wenet_set_chunk_size(void* decoder, int chunk_size) { + Recognizer* recognizer = reinterpret_cast<Recognizer*>(decoder); + recognizer->set_chunk_size(chunk_size); +} diff --git a/runtime/core/api/wenet_api.h b/runtime/core/api/wenet_api.h index fe524fc51..edeb9dd24 100644 --- a/runtime/core/api/wenet_api.h +++ b/runtime/core/api/wenet_api.h @@ -103,6 +103,10 @@ void wenet_set_log_level(int level); */ void wenet_set_continuous_decoding(void* decoder, int flag); +/** Set chunk size for decoding, -1 for non-streaming decoding + */ +void wenet_set_chunk_size(void* decoder, int chunk_size); + #ifdef __cplusplus } #endif