diff --git a/.github/scripts/test-python.sh b/.github/scripts/test-python.sh index c03b95426..0dbb8b99d 100755 --- a/.github/scripts/test-python.sh +++ b/.github/scripts/test-python.sh @@ -256,7 +256,18 @@ if [[ x$OS != x'windows-latest' ]]; then $repo/test_wavs/3.wav \ $repo/test_wavs/8k.wav + ln -s $repo $PWD/ + + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn_zh_number.fst + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn-zh-number.wav + + python3 ./python-api-examples/inverse-text-normalization-online-asr.py + python3 sherpa-onnx/python/tests/test_online_recognizer.py --verbose + + rm -rfv sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 + + rm -rf $repo fi log "Test non-streaming transducer models" diff --git a/python-api-examples/inverse-text-normalization-online-asr.py b/python-api-examples/inverse-text-normalization-online-asr.py new file mode 100755 index 000000000..8524c20f3 --- /dev/null +++ b/python-api-examples/inverse-text-normalization-online-asr.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2024 Xiaomi Corporation + +""" +This script shows how to use inverse text normalization with streaming ASR. + +Usage: + +(1) Download and extract the test model + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 +tar xvf sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 + +(2) Download rule fst + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn_zh_number.fst + +Please refer to +https://github.com/k2-fsa/colab/blob/master/sherpa-onnx/itn_zh_number.ipynb +for how itn_zh_number.fst is generated.
+ +(3) Download test wave + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn-zh-number.wav + +(4) Run this script + +python3 ./python-api-examples/inverse-text-normalization-online-asr.py +""" +from pathlib import Path + +import sherpa_onnx +import soundfile as sf + + +def create_recognizer(): + encoder = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.int8.onnx" + decoder = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx" + joiner = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.int8.onnx" + tokens = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt" + rule_fsts = "./itn_zh_number.fst" + + if ( + not Path(encoder).is_file() + or not Path(decoder).is_file() + or not Path(joiner).is_file() + or not Path(tokens).is_file() + or not Path(rule_fsts).is_file() + ): + raise ValueError( + """Please download model files from + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + """ + ) + return sherpa_onnx.OnlineRecognizer.from_transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + tokens=tokens, + debug=True, + rule_fsts=rule_fsts, + ) + + +def main(): + recognizer = create_recognizer() + wave_filename = "./itn-zh-number.wav" + if not Path(wave_filename).is_file(): + raise ValueError( + """Please download model files from + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + """ + ) + audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True) + audio = audio[:, 0] # only use the first channel + + stream = recognizer.create_stream() + stream.accept_waveform(sample_rate, audio) + + tail_padding = [0] * int(0.3 * sample_rate) + stream.accept_waveform(sample_rate, tail_padding) + + while recognizer.is_ready(stream): + recognizer.decode_stream(stream) + + print(wave_filename) + print(recognizer.get_result_all(stream)) + + +if __name__ == "__main__": + main() 
diff --git a/sherpa-onnx/csrc/online-recognizer-ctc-impl.h b/sherpa-onnx/csrc/online-recognizer-ctc-impl.h index 4d8ce2961..7dd9d8b18 100644 --- a/sherpa-onnx/csrc/online-recognizer-ctc-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-ctc-impl.h @@ -68,7 +68,8 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src, class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { public: explicit OnlineRecognizerCtcImpl(const OnlineRecognizerConfig &config) - : config_(config), + : OnlineRecognizerImpl(config), + config_(config), model_(OnlineCtcModel::Create(config.model_config)), sym_(config.model_config.tokens), endpoint_(config_.endpoint_config) { @@ -84,7 +85,8 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { #if __ANDROID_API__ >= 9 explicit OnlineRecognizerCtcImpl(AAssetManager *mgr, const OnlineRecognizerConfig &config) - : config_(config), + : OnlineRecognizerImpl(mgr, config), + config_(config), model_(OnlineCtcModel::Create(mgr, config.model_config)), sym_(mgr, config.model_config.tokens), endpoint_(config_.endpoint_config) { @@ -182,8 +184,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { // TODO(fangjun): Remember to change these constants if needed int32_t frame_shift_ms = 10; int32_t subsampling_factor = 4; - return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor, - s->GetCurrentSegment(), s->GetNumFramesSinceStart()); + auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor, + s->GetCurrentSegment(), s->GetNumFramesSinceStart()); + r.text = ApplyInverseTextNormalization(r.text); + return r; } bool IsEndpoint(OnlineStream *s) const override { diff --git a/sherpa-onnx/csrc/online-recognizer-impl.cc b/sherpa-onnx/csrc/online-recognizer-impl.cc index 2de905772..89d172f97 100644 --- a/sherpa-onnx/csrc/online-recognizer-impl.cc +++ b/sherpa-onnx/csrc/online-recognizer-impl.cc @@ -4,11 +4,22 @@ #include "sherpa-onnx/csrc/online-recognizer-impl.h" +#if __ANDROID_API__ >= 9 
+#include + +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "fst/extensions/far/far.h" +#include "kaldifst/csrc/kaldi-fst-io.h" +#include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/online-recognizer-ctc-impl.h" #include "sherpa-onnx/csrc/online-recognizer-paraformer-impl.h" #include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h" #include "sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h" #include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/text-utils.h" namespace sherpa_onnx { @@ -78,4 +89,110 @@ std::unique_ptr OnlineRecognizerImpl::Create( } #endif +OnlineRecognizerImpl::OnlineRecognizerImpl(const OnlineRecognizerConfig &config) + : config_(config) { + if (!config.rule_fsts.empty()) { + std::vector files; + SplitStringToVector(config.rule_fsts, ",", false, &files); + itn_list_.reserve(files.size()); + for (const auto &f : files) { + if (config.model_config.debug) { + SHERPA_ONNX_LOGE("rule fst: %s", f.c_str()); + } + itn_list_.push_back(std::make_unique(f)); + } + } + + if (!config.rule_fars.empty()) { + if (config.model_config.debug) { + SHERPA_ONNX_LOGE("Loading FST archives"); + } + std::vector files; + SplitStringToVector(config.rule_fars, ",", false, &files); + + itn_list_.reserve(files.size() + itn_list_.size()); + + for (const auto &f : files) { + if (config.model_config.debug) { + SHERPA_ONNX_LOGE("rule far: %s", f.c_str()); + } + std::unique_ptr> reader( + fst::FarReader::Open(f)); + for (; !reader->Done(); reader->Next()) { + std::unique_ptr r( + fst::CastOrConvertToConstFst(reader->GetFst()->Copy())); + + itn_list_.push_back( + std::make_unique(std::move(r))); + } + } + + if (config.model_config.debug) { + SHERPA_ONNX_LOGE("FST archives loaded!"); + } + } +} + +#if __ANDROID_API__ >= 9 +OnlineRecognizerImpl::OnlineRecognizerImpl(AAssetManager *mgr, + const OnlineRecognizerConfig &config) + : config_(config) { + if (!config.rule_fsts.empty()) { + std::vector 
files; + SplitStringToVector(config.rule_fsts, ",", false, &files); + itn_list_.reserve(files.size()); + for (const auto &f : files) { + if (config.model_config.debug) { + SHERPA_ONNX_LOGE("rule fst: %s", f.c_str()); + } + auto buf = ReadFile(mgr, f); + std::istrstream is(buf.data(), buf.size()); + itn_list_.push_back(std::make_unique(is)); + } + } + + if (!config.rule_fars.empty()) { + std::vector files; + SplitStringToVector(config.rule_fars, ",", false, &files); + itn_list_.reserve(files.size() + itn_list_.size()); + + for (const auto &f : files) { + if (config.model_config.debug) { + SHERPA_ONNX_LOGE("rule far: %s", f.c_str()); + } + + auto buf = ReadFile(mgr, f); + + std::unique_ptr s( + new std::istrstream(buf.data(), buf.size())); + + std::unique_ptr> reader( + fst::FarReader::Open(std::move(s))); + + for (; !reader->Done(); reader->Next()) { + std::unique_ptr r( + fst::CastOrConvertToConstFst(reader->GetFst()->Copy())); + + itn_list_.push_back( + std::make_unique(std::move(r))); + } // for (; !reader->Done(); reader->Next()) + } // for (const auto &f : files) + } // if (!config.rule_fars.empty()) +} +#endif + +std::string OnlineRecognizerImpl::ApplyInverseTextNormalization( + std::string text) const { + if (!itn_list_.empty()) { + for (const auto &tn : itn_list_) { + text = tn->Normalize(text); + if (config_.model_config.debug) { + SHERPA_ONNX_LOGE("After inverse text normalization: %s", text.c_str()); + } + } + } + + return text; +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-recognizer-impl.h b/sherpa-onnx/csrc/online-recognizer-impl.h index 72efedec7..8b569f3af 100644 --- a/sherpa-onnx/csrc/online-recognizer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-impl.h @@ -9,6 +9,12 @@ #include #include +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "kaldifst/csrc/text-normalizer.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/online-recognizer.h" 
#include "sherpa-onnx/csrc/online-stream.h" @@ -17,10 +23,15 @@ namespace sherpa_onnx { class OnlineRecognizerImpl { public: + explicit OnlineRecognizerImpl(const OnlineRecognizerConfig &config); + static std::unique_ptr Create( const OnlineRecognizerConfig &config); #if __ANDROID_API__ >= 9 + OnlineRecognizerImpl(AAssetManager *mgr, + const OnlineRecognizerConfig &config); + static std::unique_ptr Create( AAssetManager *mgr, const OnlineRecognizerConfig &config); #endif @@ -50,6 +61,15 @@ class OnlineRecognizerImpl { virtual bool IsEndpoint(OnlineStream *s) const = 0; virtual void Reset(OnlineStream *s) const = 0; + + std::string ApplyInverseTextNormalization(std::string text) const; + + private: + OnlineRecognizerConfig config_; + // for inverse text normalization. Used only if + // config.rule_fsts is not empty or + // config.rule_fars is not empty + std::vector> itn_list_; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-recognizer-paraformer-impl.h b/sherpa-onnx/csrc/online-recognizer-paraformer-impl.h index 8303af5e3..26fdb08c3 100644 --- a/sherpa-onnx/csrc/online-recognizer-paraformer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-paraformer-impl.h @@ -96,7 +96,8 @@ static void Scale(const float *x, int32_t n, float scale, float *y) { class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl { public: explicit OnlineRecognizerParaformerImpl(const OnlineRecognizerConfig &config) - : config_(config), + : OnlineRecognizerImpl(config), + config_(config), model_(config.model_config), sym_(config.model_config.tokens), endpoint_(config_.endpoint_config) { @@ -116,7 +117,8 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl { #if __ANDROID_API__ >= 9 explicit OnlineRecognizerParaformerImpl(AAssetManager *mgr, const OnlineRecognizerConfig &config) - : config_(config), + : OnlineRecognizerImpl(mgr, config), + config_(config), model_(mgr, config.model_config), sym_(mgr, config.model_config.tokens), 
endpoint_(config_.endpoint_config) { @@ -160,7 +162,9 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl { OnlineRecognizerResult GetResult(OnlineStream *s) const override { auto decoder_result = s->GetParaformerResult(); - return Convert(decoder_result, sym_); + auto r = Convert(decoder_result, sym_); + r.text = ApplyInverseTextNormalization(r.text); + return r; } bool IsEndpoint(OnlineStream *s) const override { diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index a2531b10c..2bea765cb 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -80,7 +80,8 @@ OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { public: explicit OnlineRecognizerTransducerImpl(const OnlineRecognizerConfig &config) - : config_(config), + : OnlineRecognizerImpl(config), + config_(config), model_(OnlineTransducerModel::Create(config.model_config)), sym_(config.model_config.tokens), endpoint_(config_.endpoint_config) { @@ -124,7 +125,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { #if __ANDROID_API__ >= 9 explicit OnlineRecognizerTransducerImpl(AAssetManager *mgr, const OnlineRecognizerConfig &config) - : config_(config), + : OnlineRecognizerImpl(mgr, config), + config_(config), model_(OnlineTransducerModel::Create(mgr, config.model_config)), sym_(mgr, config.model_config.tokens), endpoint_(config_.endpoint_config) { @@ -332,8 +334,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { // TODO(fangjun): Remember to change these constants if needed int32_t frame_shift_ms = 10; int32_t subsampling_factor = 4; - return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor, - s->GetCurrentSegment(), s->GetNumFramesSinceStart()); + auto r = Convert(decoder_result, sym_, frame_shift_ms, 
subsampling_factor, + s->GetCurrentSegment(), s->GetNumFramesSinceStart()); + r.text = ApplyInverseTextNormalization(std::move(r.text)); + return r; } bool IsEndpoint(OnlineStream *s) const override { diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h index 2391efb1f..700054dc2 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h @@ -42,7 +42,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { public: explicit OnlineRecognizerTransducerNeMoImpl( const OnlineRecognizerConfig &config) - : config_(config), + : OnlineRecognizerImpl(config), + config_(config), symbol_table_(config.model_config.tokens), endpoint_(config_.endpoint_config), model_( @@ -61,7 +62,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { #if __ANDROID_API__ >= 9 explicit OnlineRecognizerTransducerNeMoImpl( AAssetManager *mgr, const OnlineRecognizerConfig &config) - : config_(config), + : OnlineRecognizerImpl(mgr, config), + config_(config), symbol_table_(mgr, config.model_config.tokens), endpoint_(config_.endpoint_config), model_(std::make_unique( @@ -94,9 +96,11 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { // TODO(fangjun): Remember to change these constants if needed int32_t frame_shift_ms = 10; int32_t subsampling_factor = model_->SubsamplingFactor(); - return Convert(s->GetResult(), symbol_table_, frame_shift_ms, - subsampling_factor, s->GetCurrentSegment(), - s->GetNumFramesSinceStart()); + auto r = Convert(s->GetResult(), symbol_table_, frame_shift_ms, + subsampling_factor, s->GetCurrentSegment(), + s->GetNumFramesSinceStart()); + r.text = ApplyInverseTextNormalization(std::move(r.text)); + return r; } bool IsEndpoint(OnlineStream *s) const override { diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc index 
fcb9169ef..a49a62f6a 100644 --- a/sherpa-onnx/csrc/online-recognizer.cc +++ b/sherpa-onnx/csrc/online-recognizer.cc @@ -14,7 +14,9 @@ #include #include +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/online-recognizer-impl.h" +#include "sherpa-onnx/csrc/text-utils.h" namespace sherpa_onnx { @@ -100,6 +102,15 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { "now support greedy_search and modified_beam_search."); po->Register("temperature-scale", &temperature_scale, "Temperature scale for confidence computation in decoding."); + po->Register( + "rule-fsts", &rule_fsts, + "If not empty, it specifies fsts for inverse text normalization. " + "If there are multiple fsts, they are separated by a comma."); + + po->Register( + "rule-fars", &rule_fars, + "If not empty, it specifies fst archives for inverse text normalization. " + "If there are multiple archives, they are separated by a comma."); } bool OnlineRecognizerConfig::Validate() const { @@ -129,6 +140,34 @@ bool OnlineRecognizerConfig::Validate() const { return false; } + if (!hotwords_file.empty() && !FileExists(hotwords_file)) { + SHERPA_ONNX_LOGE("--hotwords-file: '%s' does not exist", + hotwords_file.c_str()); + return false; + } + + if (!rule_fsts.empty()) { + std::vector files; + SplitStringToVector(rule_fsts, ",", false, &files); + for (const auto &f : files) { + if (!FileExists(f)) { + SHERPA_ONNX_LOGE("Rule fst '%s' does not exist. ", f.c_str()); + return false; + } + } + } + + if (!rule_fars.empty()) { + std::vector files; + SplitStringToVector(rule_fars, ",", false, &files); + for (const auto &f : files) { + if (!FileExists(f)) { + SHERPA_ONNX_LOGE("Rule far '%s' does not exist. 
", f.c_str()); + return false; + } + } + } + return model_config.Validate(); } @@ -147,7 +186,9 @@ std::string OnlineRecognizerConfig::ToString() const { os << "hotwords_file=\"" << hotwords_file << "\", "; os << "decoding_method=\"" << decoding_method << "\", "; os << "blank_penalty=" << blank_penalty << ", "; - os << "temperature_scale=" << temperature_scale << ")"; + os << "temperature_scale=" << temperature_scale << ", "; + os << "rule_fsts=\"" << rule_fsts << "\", "; + os << "rule_fars=\"" << rule_fars << "\")"; return os.str(); } diff --git a/sherpa-onnx/csrc/online-recognizer.h b/sherpa-onnx/csrc/online-recognizer.h index f7fcf2f21..7fde367fb 100644 --- a/sherpa-onnx/csrc/online-recognizer.h +++ b/sherpa-onnx/csrc/online-recognizer.h @@ -100,6 +100,12 @@ struct OnlineRecognizerConfig { float temperature_scale = 2.0; + // If there are multiple rules, they are applied from left to right. + std::string rule_fsts; + + // If there are multiple FST archives, they are applied from left to right. 
+ std::string rule_fars; + OnlineRecognizerConfig() = default; OnlineRecognizerConfig( @@ -109,7 +115,8 @@ struct OnlineRecognizerConfig { const OnlineCtcFstDecoderConfig &ctc_fst_decoder_config, bool enable_endpoint, const std::string &decoding_method, int32_t max_active_paths, const std::string &hotwords_file, - float hotwords_score, float blank_penalty, float temperature_scale) + float hotwords_score, float blank_penalty, float temperature_scale, + const std::string &rule_fsts, const std::string &rule_fars) : feat_config(feat_config), model_config(model_config), lm_config(lm_config), @@ -121,7 +128,9 @@ struct OnlineRecognizerConfig { hotwords_file(hotwords_file), hotwords_score(hotwords_score), blank_penalty(blank_penalty), - temperature_scale(temperature_scale) {} + temperature_scale(temperature_scale), + rule_fsts(rule_fsts), + rule_fars(rule_fars) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/python/csrc/online-recognizer.cc b/sherpa-onnx/python/csrc/online-recognizer.cc index 148f73ee5..fe6cd454a 100644 --- a/sherpa-onnx/python/csrc/online-recognizer.cc +++ b/sherpa-onnx/python/csrc/online-recognizer.cc @@ -54,19 +54,20 @@ static void PybindOnlineRecognizerResult(py::module *m) { static void PybindOnlineRecognizerConfig(py::module *m) { using PyClass = OnlineRecognizerConfig; py::class_(*m, "OnlineRecognizerConfig") - .def( - py::init(), - py::arg("feat_config"), py::arg("model_config"), - py::arg("lm_config") = OnlineLMConfig(), - py::arg("endpoint_config") = EndpointConfig(), - py::arg("ctc_fst_decoder_config") = OnlineCtcFstDecoderConfig(), - py::arg("enable_endpoint"), py::arg("decoding_method"), - py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", - py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0, - py::arg("temperature_scale") = 2.0) + .def(py::init(), + py::arg("feat_config"), py::arg("model_config"), + py::arg("lm_config") = OnlineLMConfig(), + py::arg("endpoint_config") = 
EndpointConfig(), + py::arg("ctc_fst_decoder_config") = OnlineCtcFstDecoderConfig(), + py::arg("enable_endpoint"), py::arg("decoding_method"), + py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", + py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0, + py::arg("temperature_scale") = 2.0, py::arg("rule_fsts") = "", + py::arg("rule_fars") = "") .def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("model_config", &PyClass::model_config) .def_readwrite("lm_config", &PyClass::lm_config) @@ -79,6 +80,8 @@ static void PybindOnlineRecognizerConfig(py::module *m) { .def_readwrite("hotwords_score", &PyClass::hotwords_score) .def_readwrite("blank_penalty", &PyClass::blank_penalty) .def_readwrite("temperature_scale", &PyClass::temperature_scale) + .def_readwrite("rule_fsts", &PyClass::rule_fsts) + .def_readwrite("rule_fars", &PyClass::rule_fars) .def("__str__", &PyClass::ToString); } diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index 97f7472b4..82b2e3b42 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -64,6 +64,8 @@ def from_transducer( lm_scale: float = 0.1, temperature_scale: float = 2.0, debug: bool = False, + rule_fsts: str = "", + rule_fars: str = "", ): """ Please refer to @@ -148,6 +150,12 @@ def from_transducer( the log probability, you can get it from the directory where your bpe model is generated. Only used when hotwords provided and the modeling unit is bpe or cjkchar+bpe. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. 
""" self = cls.__new__(cls) _assert_file_exists(tokens) @@ -217,6 +225,8 @@ def from_transducer( hotwords_file=hotwords_file, blank_penalty=blank_penalty, temperature_scale=temperature_scale, + rule_fsts=rule_fsts, + rule_fars=rule_fars, ) self.recognizer = _Recognizer(recognizer_config) @@ -239,6 +249,8 @@ def from_paraformer( decoding_method: str = "greedy_search", provider: str = "cpu", debug: bool = False, + rule_fsts: str = "", + rule_fars: str = "", ): """ Please refer to @@ -283,6 +295,12 @@ def from_paraformer( The only valid value is greedy_search. provider: onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. """ self = cls.__new__(cls) _assert_file_exists(tokens) @@ -322,6 +340,8 @@ def from_paraformer( endpoint_config=endpoint_config, enable_endpoint=enable_endpoint_detection, decoding_method=decoding_method, + rule_fsts=rule_fsts, + rule_fars=rule_fars, ) self.recognizer = _Recognizer(recognizer_config) @@ -345,6 +365,8 @@ def from_zipformer2_ctc( ctc_max_active: int = 3000, provider: str = "cpu", debug: bool = False, + rule_fsts: str = "", + rule_fars: str = "", ): """ Please refer to @@ -393,6 +415,12 @@ def from_zipformer2_ctc( active paths at a time. provider: onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. 
""" self = cls.__new__(cls) _assert_file_exists(tokens) @@ -433,6 +461,8 @@ def from_zipformer2_ctc( ctc_fst_decoder_config=ctc_fst_decoder_config, enable_endpoint=enable_endpoint_detection, decoding_method=decoding_method, + rule_fsts=rule_fsts, + rule_fars=rule_fars, ) self.recognizer = _Recognizer(recognizer_config) @@ -454,6 +484,8 @@ def from_nemo_ctc( decoding_method: str = "greedy_search", provider: str = "cpu", debug: bool = False, + rule_fsts: str = "", + rule_fars: str = "", ): """ Please refer to @@ -497,6 +529,12 @@ def from_nemo_ctc( onnxruntime execution providers. Valid values are: cpu, cuda, coreml. debug: True to show meta data in the model. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. """ self = cls.__new__(cls) _assert_file_exists(tokens) @@ -533,6 +571,8 @@ def from_nemo_ctc( endpoint_config=endpoint_config, enable_endpoint=enable_endpoint_detection, decoding_method=decoding_method, + rule_fsts=rule_fsts, + rule_fars=rule_fars, ) self.recognizer = _Recognizer(recognizer_config) @@ -556,6 +596,8 @@ def from_wenet_ctc( decoding_method: str = "greedy_search", provider: str = "cpu", debug: bool = False, + rule_fsts: str = "", + rule_fars: str = "", ): """ Please refer to @@ -602,6 +644,12 @@ def from_wenet_ctc( The only valid value is greedy_search. provider: onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. 
""" self = cls.__new__(cls) _assert_file_exists(tokens) @@ -640,6 +688,8 @@ def from_wenet_ctc( endpoint_config=endpoint_config, enable_endpoint=enable_endpoint_detection, decoding_method=decoding_method, + rule_fsts=rule_fsts, + rule_fars=rule_fars, ) self.recognizer = _Recognizer(recognizer_config)