Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[binding] add chunk size interface and use non-streaming decoding by … #1970

Merged
merged 1 commit into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion runtime/binding/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ with wave.open(test_wav, 'rb') as fin:
assert fin.getnchannels() == 1
wav = fin.readframes(fin.getnframes())

decoder = wenet.Decoder(lang='chs')
decoder = wenet.Decoder(lang='chs', streaming=True)
# We suppose the wav is 16k, 16bits, and decode every 0.5 seconds
interval = int(0.5 * 16000) * 2
for i in range(0, len(wav), interval):
Expand Down
3 changes: 2 additions & 1 deletion runtime/binding/python/cpp/binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

namespace py = pybind11;


PYBIND11_MODULE(_wenet, m) {
m.doc() = "wenet pybind11 plugin"; // optional module docstring
m.def("wenet_init", &wenet_init, py::return_value_policy::reference,
Expand All @@ -36,4 +35,6 @@ PYBIND11_MODULE(_wenet, m) {
m.def("wenet_set_language", &wenet_set_language, "set language");
m.def("wenet_set_continuous_decoding", &wenet_set_continuous_decoding,
"enable continuous decoding or not");
m.def("wenet_set_chunk_size", &wenet_set_chunk_size,
"set decoding chunk size");
}
9 changes: 8 additions & 1 deletion runtime/binding/python/wenetruntime/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ def __init__(self,
enable_timestamp: bool = False,
context: Optional[List[str]] = None,
context_score: float = 3.0,
continuous_decoding: bool = False):
continuous_decoding: bool = False,
streaming: bool = False):
""" Init WeNet decoder
Args:
lang: language type of the model
Expand All @@ -44,6 +45,7 @@ def __init__(self,
context: context words
context_score: bonus score when the context is matched
continuous_decoding: enable countinous decoding or not
streaming: streaming mode
"""
if model_dir is None:
model_dir = Hub.get_model_by_lang(lang)
Expand All @@ -57,6 +59,8 @@ def __init__(self,
self.add_context(context)
self.set_context_score(context_score)
self.set_continuous_decoding(continuous_decoding)
chunk_size = 16 if streaming else -1
self.set_chunk_size(chunk_size)

def __del__(self):
_wenet.wenet_free(self.d)
Expand Down Expand Up @@ -90,6 +94,9 @@ def set_continuous_decoding(self, continuous_decoding: bool):
flag = 1 if continuous_decoding else 0
_wenet.wenet_set_continuous_decoding(self.d, flag)

def set_chunk_size(self, chunk_size: int):
_wenet.wenet_set_chunk_size(self.d, chunk_size)

def decode(self,
audio: Union[str, bytes, np.ndarray],
last: bool = True) -> str:
Expand Down
13 changes: 13 additions & 0 deletions runtime/binding/python/wenetruntime/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import argparse

from wenetruntime.decoder import Decoder
from _wenet import wenet_set_log_level as set_log_level # noqa


def get_args():
Expand All @@ -23,14 +24,26 @@ def get_args():
default='chs',
choices=['chs', 'en'],
help='select language')
parser.add_argument('-c',
'--chunk_size',
default=-1,
type=int,
help='set decoding chunk size')
parser.add_argument('-v',
'--verbose',
default=0,
type=int,
help='set log(glog backend) level')
parser.add_argument('audio', help='input audio file')
args = parser.parse_args()
return args


def main():
args = get_args()
set_log_level(args.verbose)
decoder = Decoder(lang=args.language)
decoder.set_chunk_size(args.chunk_size)
result = decoder.decode(args.audio)
print(result)

Expand Down
9 changes: 9 additions & 0 deletions runtime/core/api/wenet_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ class Recognizer {
}
resource_->post_processor =
std::make_shared<wenet::PostProcessor>(*post_process_opts_);
// Init decode options
decode_options_->chunk_size = chunk_size_;
// Init decoder
decoder_ = std::make_shared<wenet::AsrDecoder>(feature_pipeline_, resource_,
*decode_options_);
Expand Down Expand Up @@ -180,6 +182,7 @@ class Recognizer {
void set_context_score(float score) { context_score_ = score; }
void set_language(const char* lang) { language_ = lang; }
void set_continuous_decoding(bool flag) { continuous_decoding_ = flag; }
void set_chunk_size(int chunk_size) { chunk_size_ = chunk_size; }

private:
// NOTE(Binbin Zhang): All use shared_ptr for clone in the future
Expand All @@ -197,6 +200,7 @@ class Recognizer {
float context_score_;
std::string language_ = "chs";
bool continuous_decoding_ = false;
int chunk_size_ = 16;
};

void* wenet_init(const char* model_dir) {
Expand Down Expand Up @@ -255,3 +259,8 @@ void wenet_set_continuous_decoding(void* decoder, int flag) {
Recognizer* recognizer = reinterpret_cast<Recognizer*>(decoder);
recognizer->set_continuous_decoding(flag > 0);
}

void wenet_set_chunk_size(void* decoder, int chunk_size) {
Recognizer* recognizer = reinterpret_cast<Recognizer*>(decoder);
recognizer->set_chunk_size(chunk_size);
}
4 changes: 4 additions & 0 deletions runtime/core/api/wenet_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ void wenet_set_log_level(int level);
*/
void wenet_set_continuous_decoding(void* decoder, int flag);

/** Set chunk size for decoding, -1 for non-streaming decoding
*/
void wenet_set_chunk_size(void* decoder, int chunk_size);

#ifdef __cplusplus
}
#endif
Expand Down
Loading