From 25f0a104688d1291179ec8387c8fb904d4ca0abf Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 18 Jul 2024 22:54:18 +0800 Subject: [PATCH] Add C++ runtime for SenseVoice models (#1148) --- .github/scripts/test-offline-ctc.sh | 25 +- .github/scripts/test-python.sh | 12 + .../workflows/export-sense-voice-to-onnx.yaml | 2 +- .gitignore | 1 + CHANGELOG.md | 4 + CMakeLists.txt | 2 +- .../offline-sense-voice-ctc-decode-files.py | 67 ++++ scripts/sense-voice/export-onnx.py | 9 +- sherpa-onnx/c-api/c-api.cc | 10 + sherpa-onnx/c-api/c-api.h | 7 + sherpa-onnx/csrc/CMakeLists.txt | 2 + .../offline-ct-transformer-model-meta-data.h | 2 +- sherpa-onnx/csrc/offline-ctc-model.cc | 2 + sherpa-onnx/csrc/offline-model-config.cc | 14 +- sherpa-onnx/csrc/offline-model-config.h | 4 + .../csrc/offline-recognizer-ctc-impl.h | 5 +- sherpa-onnx/csrc/offline-recognizer-impl.cc | 45 +++ .../csrc/offline-recognizer-paraformer-impl.h | 21 +- .../offline-recognizer-sense-voice-impl.h | 363 ++++++++++++++++++ .../csrc/offline-sense-voice-model-config.cc | 55 +++ .../csrc/offline-sense-voice-model-config.h | 39 ++ .../offline-sense-voice-model-meta-data.h | 50 +++ sherpa-onnx/csrc/offline-sense-voice-model.cc | 156 ++++++++ sherpa-onnx/csrc/offline-sense-voice-model.h | 61 +++ sherpa-onnx/csrc/onnx-utils.cc | 57 ++- sherpa-onnx/csrc/onnx-utils.h | 13 +- sherpa-onnx/python/csrc/CMakeLists.txt | 1 + .../python/csrc/offline-model-config.cc | 7 +- .../csrc/offline-paraformer-model-config.cc | 1 + .../csrc/offline-sense-voice-model-config.cc | 26 ++ .../csrc/offline-sense-voice-model-config.h | 16 + .../python/sherpa_onnx/offline_recognizer.py | 83 ++++ swift-api-examples/SherpaOnnx.swift | 18 +- .../decode-file-non-streaming.swift | 19 +- 34 files changed, 1160 insertions(+), 39 deletions(-) create mode 100644 python-api-examples/offline-sense-voice-ctc-decode-files.py create mode 100644 sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h create mode 100644 sherpa-onnx/csrc/offline-sense-voice-model-config.cc create mode 100644 sherpa-onnx/csrc/offline-sense-voice-model-config.h create mode 100644 sherpa-onnx/csrc/offline-sense-voice-model-meta-data.h create mode 100644 sherpa-onnx/csrc/offline-sense-voice-model.cc create mode 100644 sherpa-onnx/csrc/offline-sense-voice-model.h create mode 100644 sherpa-onnx/python/csrc/offline-sense-voice-model-config.cc create mode 100644 sherpa-onnx/python/csrc/offline-sense-voice-model-config.h diff --git a/.github/scripts/test-offline-ctc.sh b/.github/scripts/test-offline-ctc.sh index 9ac0ace49..dea791b54 100755 --- a/.github/scripts/test-offline-ctc.sh +++ b/.github/scripts/test-offline-ctc.sh @@ -15,7 +15,30 @@ echo "PATH: $PATH" which $EXE -if false; then +log "------------------------------------------------------------" +log "Run SenseVoice models" +log "------------------------------------------------------------" +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 +tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 +rm sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 +repo=sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17 + +for m in model.onnx model.int8.onnx; do + for w in zh en yue ja ko; do + for use_itn in 0 1; do + echo "$m $w $use_itn" + time $EXE \ + --tokens=$repo/tokens.txt \ + --sense-voice-model=$repo/$m \ + --sense-voice-use-itn=$use_itn \ + $repo/test_wavs/$w.wav + done + done +done + +rm -rf $repo + +if true; then # It has problems with 
onnxruntime 1.18 log "------------------------------------------------------------" log "Run Wenet models" diff --git a/.github/scripts/test-python.sh b/.github/scripts/test-python.sh index e8685d936..bdfd6d370 100755 --- a/.github/scripts/test-python.sh +++ b/.github/scripts/test-python.sh @@ -10,6 +10,18 @@ log() { export GIT_CLONE_PROTECTION_ACTIVE=false +log "test offline SenseVoice CTC" +url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 +name=$(basename $url) +repo=$(basename -s .tar.bz2 $name) + +curl -SL -O $url +tar xvf $name +rm $name +ls -lh $repo +python3 ./python-api-examples/offline-sense-voice-ctc-decode-files.py +rm -rf $repo + log "test offline TeleSpeech CTC" url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04.tar.bz2 name=$(basename $url) diff --git a/.github/workflows/export-sense-voice-to-onnx.yaml b/.github/workflows/export-sense-voice-to-onnx.yaml index 8303dec41..41a9a31a6 100644 --- a/.github/workflows/export-sense-voice-to-onnx.yaml +++ b/.github/workflows/export-sense-voice-to-onnx.yaml @@ -73,7 +73,7 @@ jobs: echo "pwd: $PWD" ls -lh ../scripts/sense-voice - rm -rf ./ + rm -rf ./* cp -v ../scripts/sense-voice/*.onnx . cp -v ../scripts/sense-voice/tokens.txt . diff --git a/.gitignore b/.gitignore index a39684cee..5486ad51a 100644 --- a/.gitignore +++ b/.gitignore @@ -111,3 +111,4 @@ sherpa-onnx-telespeech-ctc-* *.fst .ccache lib*.a +sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17 diff --git a/CHANGELOG.md b/CHANGELOG.md index b685f7b66..b432e375c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 1.10.17 + +* Support SenseVoice CTC models. + ## 1.10.16 * Support zh-en TTS model from MeloTTS. diff --git a/CMakeLists.txt b/CMakeLists.txt index 2e2d1b2b3..661647020 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,7 +11,7 @@ project(sherpa-onnx) # ./nodejs-addon-examples # ./dart-api-examples/ # ./CHANGELOG.md -set(SHERPA_ONNX_VERSION "1.10.16") +set(SHERPA_ONNX_VERSION "1.10.17") # Disable warning about # diff --git a/python-api-examples/offline-sense-voice-ctc-decode-files.py b/python-api-examples/offline-sense-voice-ctc-decode-files.py new file mode 100644 index 000000000..b406288fc --- /dev/null +++ b/python-api-examples/offline-sense-voice-ctc-decode-files.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 + +""" +This file shows how to use a non-streaming SenseVoice CTC model from +https://github.com/FunAudioLLM/SenseVoice +to decode files. 
+ +Please download model files from +https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + +For instance, + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 +tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 +rm sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 +""" + +from pathlib import Path + +import sherpa_onnx +import soundfile as sf + + +def create_recognizer(): + model = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.int8.onnx" + tokens = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt" + test_wav = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/zh.wav" + # test_wav = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/en.wav" + # test_wav = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/ja.wav" + # test_wav = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/ko.wav" + # test_wav = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/yue.wav" + + if not Path(model).is_file() or not Path(test_wav).is_file(): + raise ValueError( + """Please download model files from + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + """ + ) + return ( + sherpa_onnx.OfflineRecognizer.from_sense_voice( + model=model, + tokens=tokens, + use_itn=True, + debug=True, + ), + test_wav, + ) + + +def main(): + recognizer, wave_filename = create_recognizer() + + audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True) + audio = audio[:, 0] # only use the first channel + + # audio is a 1-D float32 numpy array normalized to the range [-1, 1] + # sample_rate does not need to be 16000 Hz + + stream = recognizer.create_stream() + stream.accept_waveform(sample_rate, audio) + recognizer.decode_stream(stream) + print(wave_filename) + print(stream.result) + + +if __name__ == "__main__": + main() diff --git a/scripts/sense-voice/export-onnx.py b/scripts/sense-voice/export-onnx.py index b3923b981..97d9a506e 100755 --- a/scripts/sense-voice/export-onnx.py +++ b/scripts/sense-voice/export-onnx.py @@ -162,7 +162,9 @@ def main(): "neg_mean": neg_mean, "inv_stddev": inv_stddev, "model_type": "sense_voice_ctc", - "version": "1", + # version 1: Use QInt8 + # version 2: Use QUInt8 + "version": "2", "model_author": "iic", "maintainer": "k2-fsa", "vocab_size": vocab_size, @@ -185,7 +187,10 @@ def main(): model_input=filename, model_output=filename_int8, op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, + # Note that we have to use QUInt8 here. 
+ # + # When QInt8 is used, C++ onnxruntime produces incorrect results + weight_type=QuantType.QUInt8, ) diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index c63fad900..b91668746 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -310,6 +310,7 @@ struct SherpaOnnxOfflineStream { static sherpa_onnx::OfflineRecognizerConfig convertConfig( const SherpaOnnxOfflineRecognizerConfig *config); + SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( const SherpaOnnxOfflineRecognizerConfig *config) { sherpa_onnx::OfflineRecognizerConfig recognizer_config = @@ -391,6 +392,15 @@ sherpa_onnx::OfflineRecognizerConfig convertConfig( recognizer_config.model_config.telespeech_ctc = SHERPA_ONNX_OR(config->model_config.telespeech_ctc, ""); + recognizer_config.model_config.sense_voice.model = + SHERPA_ONNX_OR(config->model_config.sense_voice.model, ""); + + recognizer_config.model_config.sense_voice.language = + SHERPA_ONNX_OR(config->model_config.sense_voice.language, ""); + + recognizer_config.model_config.sense_voice.use_itn = + config->model_config.sense_voice.use_itn; + recognizer_config.lm_config.model = SHERPA_ONNX_OR(config->lm_config.model, ""); recognizer_config.lm_config.scale = diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index 13ea4f5c4..36b286972 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -379,6 +379,12 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineLMConfig { float scale; } SherpaOnnxOfflineLMConfig; +SHERPA_ONNX_API typedef struct SherpaOnnxOfflineSenseVoiceModelConfig { + const char *model; + const char *language; + int32_t use_itn; +} SherpaOnnxOfflineSenseVoiceModelConfig; + SHERPA_ONNX_API typedef struct SherpaOnnxOfflineModelConfig { SherpaOnnxOfflineTransducerModelConfig transducer; SherpaOnnxOfflineParaformerModelConfig paraformer; @@ -398,6 +404,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineModelConfig { const char *modeling_unit; const char *bpe_vocab; const char *telespeech_ctc; + SherpaOnnxOfflineSenseVoiceModelConfig sense_voice; } SherpaOnnxOfflineModelConfig; SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerConfig { diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 89d3d278e..a176c701d 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -36,6 +36,8 @@ set(sources offline-recognizer-impl.cc offline-recognizer.cc offline-rnn-lm.cc + offline-sense-voice-model-config.cc + offline-sense-voice-model.cc offline-stream.cc offline-tdnn-ctc-model.cc offline-tdnn-model-config.cc diff --git a/sherpa-onnx/csrc/offline-ct-transformer-model-meta-data.h b/sherpa-onnx/csrc/offline-ct-transformer-model-meta-data.h index eea37d73e..d94e24a35 100644 --- a/sherpa-onnx/csrc/offline-ct-transformer-model-meta-data.h +++ b/sherpa-onnx/csrc/offline-ct-transformer-model-meta-data.h @@ -1,4 +1,4 @@ -// sherpa-onnx/csrc/offline-ct-transformer-model-meta_data.h +// sherpa-onnx/csrc/offline-ct-transformer-model-meta-data.h // // Copyright (c) 2024 Xiaomi Corporation #ifndef SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_META_DATA_H_ diff --git a/sherpa-onnx/csrc/offline-ctc-model.cc b/sherpa-onnx/csrc/offline-ctc-model.cc index bd646ece3..6331a650b 100644 --- a/sherpa-onnx/csrc/offline-ctc-model.cc +++ b/sherpa-onnx/csrc/offline-ctc-model.cc @@ -93,6 +93,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, std::unique_ptr OfflineCtcModel::Create( const OfflineModelConfig &config) { + // 
TODO(fangjun): Refactor it. We don't need to use model_type here ModelType model_type = ModelType::kUnknown; std::string filename; @@ -148,6 +149,7 @@ std::unique_ptr OfflineCtcModel::Create( std::unique_ptr OfflineCtcModel::Create( AAssetManager *mgr, const OfflineModelConfig &config) { + // TODO(fangjun): Refactor it. We don't need to use model_type here ModelType model_type = ModelType::kUnknown; std::string filename; diff --git a/sherpa-onnx/csrc/offline-model-config.cc b/sherpa-onnx/csrc/offline-model-config.cc index 24a5a2141..862e4a60c 100644 --- a/sherpa-onnx/csrc/offline-model-config.cc +++ b/sherpa-onnx/csrc/offline-model-config.cc @@ -18,6 +18,7 @@ void OfflineModelConfig::Register(ParseOptions *po) { tdnn.Register(po); zipformer_ctc.Register(po); wenet_ctc.Register(po); + sense_voice.Register(po); po->Register("telespeech-ctc", &telespeech_ctc, "Path to model.onnx for telespeech ctc"); @@ -94,15 +95,21 @@ bool OfflineModelConfig::Validate() const { return wenet_ctc.Validate(); } + if (!sense_voice.model.empty()) { + return sense_voice.Validate(); + } + if (!telespeech_ctc.empty() && !FileExists(telespeech_ctc)) { SHERPA_ONNX_LOGE("telespeech_ctc: '%s' does not exist", telespeech_ctc.c_str()); return false; - } else { - return true; } - return transducer.Validate(); + if (!transducer.encoder_filename.empty()) { + return transducer.Validate(); + } + + return true; } std::string OfflineModelConfig::ToString() const { @@ -116,6 +123,7 @@ std::string OfflineModelConfig::ToString() const { os << "tdnn=" << tdnn.ToString() << ", "; os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", "; os << "wenet_ctc=" << wenet_ctc.ToString() << ", "; + os << "sense_voice=" << sense_voice.ToString() << ", "; os << "telespeech_ctc=\"" << telespeech_ctc << "\", "; os << "tokens=\"" << tokens << "\", "; os << "num_threads=" << num_threads << ", "; diff --git a/sherpa-onnx/csrc/offline-model-config.h b/sherpa-onnx/csrc/offline-model-config.h index 856a6f35d..8eb725e4e 100644 --- a/sherpa-onnx/csrc/offline-model-config.h +++ b/sherpa-onnx/csrc/offline-model-config.h @@ -8,6 +8,7 @@ #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h" #include "sherpa-onnx/csrc/offline-paraformer-model-config.h" +#include "sherpa-onnx/csrc/offline-sense-voice-model-config.h" #include "sherpa-onnx/csrc/offline-tdnn-model-config.h" #include "sherpa-onnx/csrc/offline-transducer-model-config.h" #include "sherpa-onnx/csrc/offline-wenet-ctc-model-config.h" @@ -24,6 +25,7 @@ struct OfflineModelConfig { OfflineTdnnModelConfig tdnn; OfflineZipformerCtcModelConfig zipformer_ctc; OfflineWenetCtcModelConfig wenet_ctc; + OfflineSenseVoiceModelConfig sense_voice; std::string telespeech_ctc; std::string tokens; @@ -53,6 +55,7 @@ struct OfflineModelConfig { const OfflineTdnnModelConfig &tdnn, const OfflineZipformerCtcModelConfig &zipformer_ctc, const OfflineWenetCtcModelConfig &wenet_ctc, + const OfflineSenseVoiceModelConfig &sense_voice, const std::string &telespeech_ctc, const std::string &tokens, int32_t num_threads, bool debug, const std::string &provider, const std::string &model_type, @@ -65,6 +68,7 @@ struct OfflineModelConfig { tdnn(tdnn), zipformer_ctc(zipformer_ctc), wenet_ctc(wenet_ctc), + sense_voice(sense_voice), telespeech_ctc(telespeech_ctc), tokens(tokens), num_threads(num_threads), diff --git a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h index 9c7252a06..05c1b7981 100644 --- a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h +++ 
b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h
@@ -212,10 +212,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
     }
   }
 
-  OfflineRecognizerConfig GetConfig() const override {
-    return config_;
-  }
-
+  OfflineRecognizerConfig GetConfig() const override { return config_; }
 
  private:
   // Decode a single stream.
diff --git a/sherpa-onnx/csrc/offline-recognizer-impl.cc b/sherpa-onnx/csrc/offline-recognizer-impl.cc
index dd96f2b8a..319e104d0 100644
--- a/sherpa-onnx/csrc/offline-recognizer-impl.cc
+++ b/sherpa-onnx/csrc/offline-recognizer-impl.cc
@@ -21,6 +21,7 @@
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
 #include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
+#include "sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h"
 #include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
 #include "sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h"
 #include "sherpa-onnx/csrc/offline-recognizer-whisper-impl.h"
@@ -31,6 +32,28 @@ namespace sherpa_onnx {
 
 std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
     const OfflineRecognizerConfig &config) {
+  if (!config.model_config.sense_voice.model.empty()) {
+    return std::make_unique<OfflineRecognizerSenseVoiceImpl>(config);
+  }
+
+  if (!config.model_config.paraformer.model.empty()) {
+    return std::make_unique<OfflineRecognizerParaformerImpl>(config);
+  }
+
+  if (!config.model_config.nemo_ctc.model.empty() ||
+      !config.model_config.zipformer_ctc.model.empty() ||
+      !config.model_config.tdnn.model.empty() ||
+      !config.model_config.wenet_ctc.model.empty()) {
+    return std::make_unique<OfflineRecognizerCtcImpl>(config);
+  }
+
+  if (!config.model_config.whisper.encoder.empty()) {
+    return std::make_unique<OfflineRecognizerWhisperImpl>(config);
+  }
+
+  // TODO(fangjun): Refactor it. We only need to use model type for the
+  // following models:
+  // 1. transducer and nemo_transducer
   if (!config.model_config.model_type.empty()) {
     const auto &model_type = config.model_config.model_type;
     if (model_type == "transducer") {
@@ -180,6 +203,28 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
 #if __ANDROID_API__ >= 9
 std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
     AAssetManager *mgr, const OfflineRecognizerConfig &config) {
+  if (!config.model_config.sense_voice.model.empty()) {
+    return std::make_unique<OfflineRecognizerSenseVoiceImpl>(mgr, config);
+  }
+
+  if (!config.model_config.paraformer.model.empty()) {
+    return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
+  }
+
+  if (!config.model_config.nemo_ctc.model.empty() ||
+      !config.model_config.zipformer_ctc.model.empty() ||
+      !config.model_config.tdnn.model.empty() ||
+      !config.model_config.wenet_ctc.model.empty()) {
+    return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
+  }
+
+  if (!config.model_config.whisper.encoder.empty()) {
+    return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
+  }
+
+  // TODO(fangjun): Refactor it. We only need to use model type for the
+  // following models:
+  // 1.
transducer and nemo_transducer if (!config.model_config.model_type.empty()) { const auto &model_type = config.model_config.model_type; if (model_type == "transducer") { diff --git a/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h b/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h index 13240cc01..525c92cc2 100644 --- a/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h @@ -102,9 +102,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl { exit(-1); } - // Paraformer models assume input samples are in the range - // [-32768, 32767], so we set normalize_samples to false - config_.feat_config.normalize_samples = false; + InitFeatConfig(); } #if __ANDROID_API__ >= 9 @@ -124,9 +122,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl { exit(-1); } - // Paraformer models assume input samples are in the range - // [-32768, 32767], so we set normalize_samples to false - config_.feat_config.normalize_samples = false; + InitFeatConfig(); } #endif @@ -211,11 +207,18 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl { } } - OfflineRecognizerConfig GetConfig() const override { - return config_; - } + OfflineRecognizerConfig GetConfig() const override { return config_; } private: + void InitFeatConfig() { + // Paraformer models assume input samples are in the range + // [-32768, 32767], so we set normalize_samples to false + config_.feat_config.normalize_samples = false; + config_.feat_config.window_type = "hamming"; + config_.feat_config.high_freq = 0; + config_.feat_config.snip_edges = true; + } + std::vector ApplyLFR(const std::vector &in) const { int32_t lfr_window_size = model_->LfrWindowSize(); int32_t lfr_window_shift = model_->LfrWindowShift(); diff --git a/sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h b/sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h new file mode 100644 index 000000000..6d7397dea --- /dev/null +++ b/sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h @@ -0,0 +1,363 @@ +// sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_SENSE_VOICE_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_SENSE_VOICE_IMPL_H_ + +#include +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h" +#include "sherpa-onnx/csrc/offline-model-config.h" +#include "sherpa-onnx/csrc/offline-recognizer-impl.h" +#include "sherpa-onnx/csrc/offline-recognizer.h" +#include "sherpa-onnx/csrc/offline-sense-voice-model.h" +#include "sherpa-onnx/csrc/pad-sequence.h" +#include "sherpa-onnx/csrc/symbol-table.h" + +namespace sherpa_onnx { + +static OfflineRecognitionResult ConvertSenseVoiceResult( + const OfflineCtcDecoderResult &src, const SymbolTable &sym_table, + int32_t frame_shift_ms, int32_t subsampling_factor) { + OfflineRecognitionResult r; + r.tokens.reserve(src.tokens.size()); + r.timestamps.reserve(src.timestamps.size()); + + std::string text; + + for (int32_t i = 4; i < src.tokens.size(); ++i) { + auto sym = sym_table[src.tokens[i]]; + text.append(sym); + + r.tokens.push_back(std::move(sym)); + } + r.text = std::move(text); + + float frame_shift_s = frame_shift_ms / 1000. 
* subsampling_factor; + + for (int32_t i = 4; i < src.timestamps.size(); ++i) { + float time = frame_shift_s * (src.timestamps[i] - 4); + r.timestamps.push_back(time); + } + + r.words = std::move(src.words); + + return r; +} + +class OfflineRecognizerSenseVoiceImpl : public OfflineRecognizerImpl { + public: + explicit OfflineRecognizerSenseVoiceImpl( + const OfflineRecognizerConfig &config) + : OfflineRecognizerImpl(config), + config_(config), + symbol_table_(config_.model_config.tokens), + model_(std::make_unique(config.model_config)) { + const auto &meta_data = model_->GetModelMetadata(); + if (config.decoding_method == "greedy_search") { + decoder_ = + std::make_unique(meta_data.blank_id); + } else { + SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s", + config.decoding_method.c_str()); + exit(-1); + } + + InitFeatConfig(); + } + +#if __ANDROID_API__ >= 9 + OfflineRecognizerSenseVoiceImpl(AAssetManager *mgr, + const OfflineRecognizerConfig &config) + : OfflineRecognizerImpl(mgr, config), + config_(config), + symbol_table_(mgr, config_.model_config.tokens), + model_(std::make_unique(mgr, + config.model_config)) { + const auto &meta_data = model_->GetModelMetadata(); + if (config.decoding_method == "greedy_search") { + decoder_ = + std::make_unique(meta_data.blank_id); + } else { + SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s", + config.decoding_method.c_str()); + exit(-1); + } + + InitFeatConfig(); + } +#endif + + std::unique_ptr CreateStream() const override { + return std::make_unique(config_.feat_config); + } + + void DecodeStreams(OfflineStream **ss, int32_t n) const override { + if (n == 1) { + DecodeOneStream(ss[0]); + return; + } + + const auto &meta_data = model_->GetModelMetadata(); + // 1. Apply LFR + // 2. 
Apply CMVN + // + // Please refer to + // https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/45555.pdf + // for what LFR means + // + // "Lower Frame Rate Neural Network Acoustic Models" + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + std::vector features; + features.reserve(n); + + int32_t feat_dim = config_.feat_config.feature_dim * meta_data.window_size; + + std::vector> features_vec(n); + std::vector features_length_vec(n); + for (int32_t i = 0; i != n; ++i) { + std::vector f = ss[i]->GetFrames(); + + f = ApplyLFR(f); + ApplyCMVN(&f); + + int32_t num_frames = f.size() / feat_dim; + features_vec[i] = std::move(f); + + features_length_vec[i] = num_frames; + + std::array shape = {num_frames, feat_dim}; + + Ort::Value x = Ort::Value::CreateTensor( + memory_info, features_vec[i].data(), features_vec[i].size(), + shape.data(), shape.size()); + features.push_back(std::move(x)); + } + + std::vector features_pointer(n); + for (int32_t i = 0; i != n; ++i) { + features_pointer[i] = &features[i]; + } + + std::array features_length_shape = {n}; + Ort::Value x_length = Ort::Value::CreateTensor( + memory_info, features_length_vec.data(), n, + features_length_shape.data(), features_length_shape.size()); + + // Caution(fangjun): We cannot pad it with log(eps), + // i.e., -23.025850929940457f + Ort::Value x = PadSequence(model_->Allocator(), features_pointer, 0); + + int32_t language = 0; + if (config_.model_config.sense_voice.language.empty()) { + language = 0; + } else if (meta_data.lang2id.count( + config_.model_config.sense_voice.language)) { + language = + meta_data.lang2id.at(config_.model_config.sense_voice.language); + } else { + SHERPA_ONNX_LOGE("Unknown language: %s. Use 0 instead.", + config_.model_config.sense_voice.language.c_str()); + } + + std::vector language_array(n); + std::fill(language_array.begin(), language_array.end(), language); + + std::vector text_norm_array(n); + std::fill(text_norm_array.begin(), text_norm_array.end(), + config_.model_config.sense_voice.use_itn + ? 
meta_data.with_itn_id + : meta_data.without_itn_id); + + Ort::Value language_tensor = Ort::Value::CreateTensor( + memory_info, language_array.data(), n, features_length_shape.data(), + features_length_shape.size()); + + Ort::Value text_norm_tensor = Ort::Value::CreateTensor( + memory_info, text_norm_array.data(), n, features_length_shape.data(), + features_length_shape.size()); + + Ort::Value logits{nullptr}; + try { + logits = model_->Forward(std::move(x), std::move(x_length), + std::move(language_tensor), + std::move(text_norm_tensor)); + } catch (const Ort::Exception &ex) { + SHERPA_ONNX_LOGE("\n\nCaught exception:\n\n%s\n\nReturn an empty result", + ex.what()); + return; + } + + // decoder_->Decode() requires that logits_length is of dtype int64 + std::vector features_length_vec_64; + features_length_vec_64.reserve(n); + for (auto i : features_length_vec) { + i += 4; + features_length_vec_64.push_back(i); + } + + Ort::Value logits_length = Ort::Value::CreateTensor( + memory_info, features_length_vec_64.data(), n, + features_length_shape.data(), features_length_shape.size()); + + auto results = + decoder_->Decode(std::move(logits), std::move(logits_length)); + + int32_t frame_shift_ms = 10; + int32_t subsampling_factor = meta_data.window_shift; + for (int32_t i = 0; i != n; ++i) { + auto r = ConvertSenseVoiceResult(results[i], symbol_table_, + frame_shift_ms, subsampling_factor); + r.text = ApplyInverseTextNormalization(std::move(r.text)); + ss[i]->SetResult(r); + } + } + + OfflineRecognizerConfig GetConfig() const override { return config_; } + + private: + void DecodeOneStream(OfflineStream *s) const { + const auto &meta_data = model_->GetModelMetadata(); + + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + int32_t feat_dim = config_.feat_config.feature_dim * meta_data.window_size; + std::vector f = s->GetFrames(); + f = ApplyLFR(f); + ApplyCMVN(&f); + int32_t num_frames = f.size() / feat_dim; + std::array shape = {1, num_frames, feat_dim}; + Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(), + shape.data(), shape.size()); + + int64_t scale_shape = 1; + + Ort::Value x_length = + Ort::Value::CreateTensor(memory_info, &num_frames, 1, &scale_shape, 1); + + int32_t language = 0; + if (config_.model_config.sense_voice.language.empty()) { + language = 0; + } else if (meta_data.lang2id.count( + config_.model_config.sense_voice.language)) { + language = + meta_data.lang2id.at(config_.model_config.sense_voice.language); + } else { + SHERPA_ONNX_LOGE("Unknown language: %s. Use 0 instead.", + config_.model_config.sense_voice.language.c_str()); + } + + int32_t text_norm = config_.model_config.sense_voice.use_itn + ? 
meta_data.with_itn_id + : meta_data.without_itn_id; + + Ort::Value language_tensor = + Ort::Value::CreateTensor(memory_info, &language, 1, &scale_shape, 1); + + Ort::Value text_norm_tensor = + Ort::Value::CreateTensor(memory_info, &text_norm, 1, &scale_shape, 1); + + Ort::Value logits{nullptr}; + try { + logits = model_->Forward(std::move(x), std::move(x_length), + std::move(language_tensor), + std::move(text_norm_tensor)); + } catch (const Ort::Exception &ex) { + SHERPA_ONNX_LOGE("\n\nCaught exception:\n\n%s\n\nReturn an empty result", + ex.what()); + return; + } + + int64_t new_num_frames = num_frames + 4; + Ort::Value logits_length = Ort::Value::CreateTensor( + memory_info, &new_num_frames, 1, &scale_shape, 1); + + auto results = + decoder_->Decode(std::move(logits), std::move(logits_length)); + + int32_t frame_shift_ms = 10; + int32_t subsampling_factor = meta_data.window_shift; + auto r = ConvertSenseVoiceResult(results[0], symbol_table_, frame_shift_ms, + subsampling_factor); + + r.text = ApplyInverseTextNormalization(std::move(r.text)); + s->SetResult(r); + } + + void InitFeatConfig() { + const auto &meta_data = model_->GetModelMetadata(); + + config_.feat_config.normalize_samples = meta_data.normalize_samples; + config_.feat_config.window_type = "hamming"; + config_.feat_config.high_freq = 0; + config_.feat_config.snip_edges = true; + } + std::vector ApplyLFR(const std::vector &in) const { + const auto &meta_data = model_->GetModelMetadata(); + + int32_t lfr_window_size = meta_data.window_size; + int32_t lfr_window_shift = meta_data.window_shift; + int32_t in_feat_dim = config_.feat_config.feature_dim; + + int32_t in_num_frames = in.size() / in_feat_dim; + int32_t out_num_frames = + (in_num_frames - lfr_window_size) / lfr_window_shift + 1; + int32_t out_feat_dim = in_feat_dim * lfr_window_size; + + std::vector out(out_num_frames * out_feat_dim); + + const float *p_in = in.data(); + float *p_out = out.data(); + + for (int32_t i = 0; i != out_num_frames; ++i) { + std::copy(p_in, p_in + out_feat_dim, p_out); + + p_out += out_feat_dim; + p_in += lfr_window_shift * in_feat_dim; + } + + return out; + } + + void ApplyCMVN(std::vector *v) const { + const auto &meta_data = model_->GetModelMetadata(); + + const std::vector &neg_mean = meta_data.neg_mean; + const std::vector &inv_stddev = meta_data.inv_stddev; + + int32_t dim = neg_mean.size(); + int32_t num_frames = v->size() / dim; + + float *p = v->data(); + + for (int32_t i = 0; i != num_frames; ++i) { + for (int32_t k = 0; k != dim; ++k) { + p[k] = (p[k] + neg_mean[k]) * inv_stddev[k]; + } + + p += dim; + } + } + + OfflineRecognizerConfig config_; + SymbolTable symbol_table_; + std::unique_ptr model_; + std::unique_ptr decoder_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_SENSE_VOICE_IMPL_H_ diff --git a/sherpa-onnx/csrc/offline-sense-voice-model-config.cc b/sherpa-onnx/csrc/offline-sense-voice-model-config.cc new file mode 100644 index 000000000..fc9884dd5 --- /dev/null +++ b/sherpa-onnx/csrc/offline-sense-voice-model-config.cc @@ -0,0 +1,55 @@ +// sherpa-onnx/csrc/offline-sense-voice-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-sense-voice-model-config.h" + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void OfflineSenseVoiceModelConfig::Register(ParseOptions *po) { + po->Register("sense-voice-model", &model, + "Path to model.onnx of SenseVoice."); + po->Register( + 
"sense-voice-language", &language, + "Valid values: auto, zh, en, ja, ko, yue. If left empty, auto is used"); + po->Register( + "sense-voice-use-itn", &use_itn, + "True to enable inverse text normalization. False to disable it."); +} + +bool OfflineSenseVoiceModelConfig::Validate() const { + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("SenseVoice model '%s' does not exist", model.c_str()); + return false; + } + + if (!language.empty()) { + if (language != "auto" && language != "zh" && language != "en" && + language != "ja" && language != "ko" && language != "yue") { + SHERPA_ONNX_LOGE( + "Invalid sense-voice-language: '%s'. Valid values are: auto, zh, en, " + "ja, ko, yue. Or you can leave it empty to use 'auto'", + language.c_str()); + + return false; + } + } + + return true; +} + +std::string OfflineSenseVoiceModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineSenseVoiceModelConfig("; + os << "model=\"" << model << "\", "; + os << "language=\"" << language << "\", "; + os << "use_itn=" << (use_itn ? "True" : "False") << ")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-sense-voice-model-config.h b/sherpa-onnx/csrc/offline-sense-voice-model-config.h new file mode 100644 index 000000000..2f724e446 --- /dev/null +++ b/sherpa-onnx/csrc/offline-sense-voice-model-config.h @@ -0,0 +1,39 @@ +// sherpa-onnx/csrc/offline-sense-voice-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct OfflineSenseVoiceModelConfig { + std::string model; + + // "" or "auto" to let the model recognize the language + // valid values: + // zh, en, ja, ko, yue, auto + std::string language = "auto"; + + // true to use inverse text normalization + // false to not use inverse text normalization + bool use_itn = false; + + OfflineSenseVoiceModelConfig() = default; + explicit OfflineSenseVoiceModelConfig(const std::string &model, + const std::string &language, + bool use_itn) + : model(model), language(language), use_itn(use_itn) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/csrc/offline-sense-voice-model-meta-data.h b/sherpa-onnx/csrc/offline-sense-voice-model-meta-data.h new file mode 100644 index 000000000..6065c93cc --- /dev/null +++ b/sherpa-onnx/csrc/offline-sense-voice-model-meta-data.h @@ -0,0 +1,50 @@ +// sherpa-onnx/csrc/offline-sense-voice-model-meta-data.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_META_DATA_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_META_DATA_H_ + +#include +#include +#include + +namespace sherpa_onnx { + +struct OfflineSenseVoiceModelMetaData { + // ID for using inverse text normalization + int32_t with_itn_id; + + // ID for not using inverse text normalization + int32_t without_itn_id; + + int32_t window_size; // lfr_m + int32_t window_shift; // lfr_n + int32_t vocab_size; + + int32_t subsampling_factor = 1; + + // Usually 0 for SenseVoice models. 
+ // 0 means samples are scaled to [-32768, 32767] before are sent to the + // feature extractor + int32_t normalize_samples = 0; + + int32_t blank_id = 0; + + // possible values: + // zh, en, ja, ko, yue, auto + // where + // zh is Chinese (Mandarin) + // en is English + // ja is Japanese + // ko is Korean + // yue is Cantonese + // auto is to let the model recognize the language + std::unordered_map lang2id; + + std::vector neg_mean; + std::vector inv_stddev; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_META_DATA_H_ diff --git a/sherpa-onnx/csrc/offline-sense-voice-model.cc b/sherpa-onnx/csrc/offline-sense-voice-model.cc new file mode 100644 index 000000000..2280a9be5 --- /dev/null +++ b/sherpa-onnx/csrc/offline-sense-voice-model.cc @@ -0,0 +1,156 @@ +// sherpa-onnx/csrc/offline-sense-voice-model.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-sense-voice-model.h" + +#include +#include +#include + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/session.h" +#include "sherpa-onnx/csrc/text-utils.h" + +namespace sherpa_onnx { + +class OfflineSenseVoiceModel::Impl { + public: + explicit Impl(const OfflineModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(config_.sense_voice.model); + Init(buf.data(), buf.size()); + } + +#if __ANDROID_API__ >= 9 + Impl(AAssetManager *mgr, const OfflineModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(mgr, config_.sense_voice.model); + Init(buf.data(), buf.size()); + } +#endif + + Ort::Value Forward(Ort::Value features, Ort::Value features_length, + Ort::Value language, Ort::Value text_norm) { + std::array inputs = { + std::move(features), + std::move(features_length), + std::move(language), + std::move(text_norm), + }; + + auto ans = + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), + output_names_ptr_.data(), output_names_ptr_.size()); + return std::move(ans[0]); + } + + const OfflineSenseVoiceModelMetaData &GetModelMetadata() const { + return meta_data_; + } + + OrtAllocator *Allocator() const { return allocator_; } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::make_unique(env_, model_data, model_data_length, + sess_opts_); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); + } + + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(meta_data_.vocab_size, "vocab_size"); + SHERPA_ONNX_READ_META_DATA(meta_data_.window_size, "lfr_window_size"); + SHERPA_ONNX_READ_META_DATA(meta_data_.window_shift, "lfr_window_shift"); + SHERPA_ONNX_READ_META_DATA(meta_data_.normalize_samples, + "normalize_samples"); + + SHERPA_ONNX_READ_META_DATA(meta_data_.with_itn_id, "with_itn"); + + SHERPA_ONNX_READ_META_DATA(meta_data_.without_itn_id, "without_itn"); + + int32_t lang_auto = 0; + int32_t lang_zh = 0; + int32_t lang_en = 0; + int32_t lang_ja = 0; + int32_t lang_ko = 0; + int32_t lang_yue = 0; + + SHERPA_ONNX_READ_META_DATA(lang_auto, "lang_auto"); + 
SHERPA_ONNX_READ_META_DATA(lang_zh, "lang_zh");
+    SHERPA_ONNX_READ_META_DATA(lang_en, "lang_en");
+    SHERPA_ONNX_READ_META_DATA(lang_ja, "lang_ja");
+    SHERPA_ONNX_READ_META_DATA(lang_ko, "lang_ko");
+    SHERPA_ONNX_READ_META_DATA(lang_yue, "lang_yue");
+
+    meta_data_.lang2id = {
+        {"auto", lang_auto}, {"zh", lang_zh}, {"en", lang_en},
+        {"ja", lang_ja},     {"ko", lang_ko}, {"yue", lang_yue},
+    };
+
+    SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(meta_data_.neg_mean, "neg_mean");
+    SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(meta_data_.inv_stddev, "inv_stddev");
+  }
+
+ private:
+  OfflineModelConfig config_;
+  Ort::Env env_;
+  Ort::SessionOptions sess_opts_;
+  Ort::AllocatorWithDefaultOptions allocator_;
+
+  std::unique_ptr<Ort::Session> sess_;
+
+  std::vector<std::string> input_names_;
+  std::vector<const char *> input_names_ptr_;
+
+  std::vector<std::string> output_names_;
+  std::vector<const char *> output_names_ptr_;
+
+  OfflineSenseVoiceModelMetaData meta_data_;
+};
+
+OfflineSenseVoiceModel::OfflineSenseVoiceModel(const OfflineModelConfig &config)
+    : impl_(std::make_unique<Impl>(config)) {}
+
+#if __ANDROID_API__ >= 9
+OfflineSenseVoiceModel::OfflineSenseVoiceModel(AAssetManager *mgr,
+                                               const OfflineModelConfig &config)
+    : impl_(std::make_unique<Impl>(mgr, config)) {}
+#endif
+
+OfflineSenseVoiceModel::~OfflineSenseVoiceModel() = default;
+
+Ort::Value OfflineSenseVoiceModel::Forward(Ort::Value features,
+                                           Ort::Value features_length,
+                                           Ort::Value language,
+                                           Ort::Value text_norm) const {
+  return impl_->Forward(std::move(features), std::move(features_length),
+                        std::move(language), std::move(text_norm));
+}
+
+const OfflineSenseVoiceModelMetaData &OfflineSenseVoiceModel::GetModelMetadata()
+    const {
+  return impl_->GetModelMetadata();
+}
+
+OrtAllocator *OfflineSenseVoiceModel::Allocator() const {
+  return impl_->Allocator();
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/offline-sense-voice-model.h b/sherpa-onnx/csrc/offline-sense-voice-model.h
new file mode 100644
index 000000000..29d31b286
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-sense-voice-model.h
@@ -0,0 +1,61 @@
+// sherpa-onnx/csrc/offline-sense-voice-model.h
+//
+// Copyright (c) 2022-2023 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_H_
+
+#include <memory>
+#include <string>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/offline-model-config.h"
+#include "sherpa-onnx/csrc/offline-sense-voice-model-meta-data.h"
+
+namespace sherpa_onnx {
+
+class OfflineSenseVoiceModel {
+ public:
+  explicit OfflineSenseVoiceModel(const OfflineModelConfig &config);
+
+#if __ANDROID_API__ >= 9
+  OfflineSenseVoiceModel(AAssetManager *mgr, const OfflineModelConfig &config);
+#endif
+
+  ~OfflineSenseVoiceModel();
+
+  /** Run the forward method of the model.
+   *
+   * @param features A tensor of shape (N, T, C). It is changed in-place.
+   * @param features_length A 1-D tensor of shape (N,) containing the number of
+   *                        valid frames in `features` before padding.
+   *                        Its dtype is int32_t.
+   * @param language A 1-D tensor of shape (N,) with dtype int32_t
+   * @param text_norm A 1-D tensor of shape (N,) with dtype int32_t
+   *
+   * @return Return logits of shape (N, T, C) with dtype float
+   *
+   * Note: The subsampling factor is 1 for SenseVoice, so there is
+   * no need to output logits_length.
+   */
+ */ + Ort::Value Forward(Ort::Value features, Ort::Value features_length, + Ort::Value language, Ort::Value text_norm) const; + + const OfflineSenseVoiceModelMetaData &GetModelMetadata() const; + + /** Return an allocator for allocating memory + */ + OrtAllocator *Allocator() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_H_ diff --git a/sherpa-onnx/csrc/onnx-utils.cc b/sherpa-onnx/csrc/onnx-utils.cc index ec0475b20..1a6bf704b 100644 --- a/sherpa-onnx/csrc/onnx-utils.cc +++ b/sherpa-onnx/csrc/onnx-utils.cc @@ -6,6 +6,8 @@ #include #include +#include +#include #include #include @@ -153,23 +155,60 @@ Ort::Value View(Ort::Value *v) { } } +float ComputeSum(const Ort::Value *v, int32_t n /*= -1*/) { + std::vector shape = v->GetTensorTypeAndShapeInfo().GetShape(); + auto size = static_cast(std::accumulate( + shape.begin(), shape.end(), 1, std::multiplies())); + if (n != -1 && n < size && n > 0) { + size = n; + } + + const float *p = v->GetTensorData(); + + return std::accumulate(p, p + size, 1.0f); +} + +float ComputeMean(const Ort::Value *v, int32_t n /*= -1*/) { + std::vector shape = v->GetTensorTypeAndShapeInfo().GetShape(); + auto size = static_cast(std::accumulate( + shape.begin(), shape.end(), 1, std::multiplies())); + + if (n != -1 && n < size && n > 0) { + size = n; + } + + auto sum = ComputeSum(v, n); + return sum / size; +} + +void PrintShape(const Ort::Value *v) { + std::vector shape = v->GetTensorTypeAndShapeInfo().GetShape(); + std::ostringstream os; + for (auto i : shape) { + os << i << ", "; + } + os << "\n"; + fprintf(stderr, "%s", os.str().c_str()); +} + template -void Print1D(Ort::Value *v) { +void Print1D(const Ort::Value *v) { std::vector shape = v->GetTensorTypeAndShapeInfo().GetShape(); const T *d = v->GetTensorData(); std::ostringstream os; for (int32_t i = 0; i != static_cast(shape[0]); ++i) { - os << *d << " "; + os << d[i] << " "; } os << "\n"; fprintf(stderr, "%s\n", os.str().c_str()); } -template void Print1D(Ort::Value *v); -template void Print1D(Ort::Value *v); +template void Print1D(const Ort::Value *v); +template void Print1D(const Ort::Value *v); +template void Print1D(const Ort::Value *v); template -void Print2D(Ort::Value *v) { +void Print2D(const Ort::Value *v) { std::vector shape = v->GetTensorTypeAndShapeInfo().GetShape(); const T *d = v->GetTensorData(); @@ -183,10 +222,10 @@ void Print2D(Ort::Value *v) { fprintf(stderr, "%s\n", os.str().c_str()); } -template void Print2D(Ort::Value *v); -template void Print2D(Ort::Value *v); +template void Print2D(const Ort::Value *v); +template void Print2D(const Ort::Value *v); -void Print3D(Ort::Value *v) { +void Print3D(const Ort::Value *v) { std::vector shape = v->GetTensorTypeAndShapeInfo().GetShape(); const float *d = v->GetTensorData(); @@ -202,7 +241,7 @@ void Print3D(Ort::Value *v) { fprintf(stderr, "\n"); } -void Print4D(Ort::Value *v) { +void Print4D(const Ort::Value *v) { std::vector shape = v->GetTensorTypeAndShapeInfo().GetShape(); const float *d = v->GetTensorData(); diff --git a/sherpa-onnx/csrc/onnx-utils.h b/sherpa-onnx/csrc/onnx-utils.h index da0abab82..98eb25137 100644 --- a/sherpa-onnx/csrc/onnx-utils.h +++ b/sherpa-onnx/csrc/onnx-utils.h @@ -68,19 +68,24 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v); // Return a shallow copy Ort::Value View(Ort::Value *v); +float ComputeSum(const Ort::Value *v, int32_t n = -1); +float ComputeMean(const Ort::Value *v, int32_t n = -1); + // Print a 
1-D tensor to stderr template -void Print1D(Ort::Value *v); +void Print1D(const Ort::Value *v); // Print a 2-D tensor to stderr template -void Print2D(Ort::Value *v); +void Print2D(const Ort::Value *v); // Print a 3-D tensor to stderr -void Print3D(Ort::Value *v); +void Print3D(const Ort::Value *v); // Print a 4-D tensor to stderr -void Print4D(Ort::Value *v); +void Print4D(const Ort::Value *v); + +void PrintShape(const Ort::Value *v); template void Fill(Ort::Value *tensor, T value) { diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt index 5e74a1721..fa5d32aff 100644 --- a/sherpa-onnx/python/csrc/CMakeLists.txt +++ b/sherpa-onnx/python/csrc/CMakeLists.txt @@ -15,6 +15,7 @@ set(srcs offline-paraformer-model-config.cc offline-punctuation.cc offline-recognizer.cc + offline-sense-voice-model-config.cc offline-stream.cc offline-tdnn-model-config.cc offline-transducer-model-config.cc diff --git a/sherpa-onnx/python/csrc/offline-model-config.cc b/sherpa-onnx/python/csrc/offline-model-config.cc index a72c182ea..f498bd7e2 100644 --- a/sherpa-onnx/python/csrc/offline-model-config.cc +++ b/sherpa-onnx/python/csrc/offline-model-config.cc @@ -10,6 +10,7 @@ #include "sherpa-onnx/csrc/offline-model-config.h" #include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h" #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" +#include "sherpa-onnx/python/csrc/offline-sense-voice-model-config.h" #include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h" #include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" #include "sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.h" @@ -26,6 +27,7 @@ void PybindOfflineModelConfig(py::module *m) { PybindOfflineTdnnModelConfig(m); PybindOfflineZipformerCtcModelConfig(m); PybindOfflineWenetCtcModelConfig(m); + PybindOfflineSenseVoiceModelConfig(m); using PyClass = OfflineModelConfig; py::class_(*m, "OfflineModelConfig") @@ -36,7 +38,8 @@ void PybindOfflineModelConfig(py::module *m) { const OfflineNemoEncDecCtcModelConfig &, const OfflineWhisperModelConfig &, const OfflineTdnnModelConfig &, const OfflineZipformerCtcModelConfig &, - const OfflineWenetCtcModelConfig &, const std::string &, + const OfflineWenetCtcModelConfig &, + const OfflineSenseVoiceModelConfig &, const std::string &, const std::string &, int32_t, bool, const std::string &, const std::string &, const std::string &, const std::string &>(), py::arg("transducer") = OfflineTransducerModelConfig(), @@ -46,6 +49,7 @@ void PybindOfflineModelConfig(py::module *m) { py::arg("tdnn") = OfflineTdnnModelConfig(), py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(), py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(), + py::arg("sense_voice") = OfflineSenseVoiceModelConfig(), py::arg("telespeech_ctc") = "", py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, py::arg("provider") = "cpu", py::arg("model_type") = "", @@ -57,6 +61,7 @@ void PybindOfflineModelConfig(py::module *m) { .def_readwrite("tdnn", &PyClass::tdnn) .def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc) .def_readwrite("wenet_ctc", &PyClass::wenet_ctc) + .def_readwrite("sense_voice", &PyClass::sense_voice) .def_readwrite("telespeech_ctc", &PyClass::telespeech_ctc) .def_readwrite("tokens", &PyClass::tokens) .def_readwrite("num_threads", &PyClass::num_threads) diff --git a/sherpa-onnx/python/csrc/offline-paraformer-model-config.cc b/sherpa-onnx/python/csrc/offline-paraformer-model-config.cc index 4b0ca491e..5162c932f 100644 --- 
a/sherpa-onnx/python/csrc/offline-paraformer-model-config.cc +++ b/sherpa-onnx/python/csrc/offline-paraformer-model-config.cc @@ -14,6 +14,7 @@ namespace sherpa_onnx { void PybindOfflineParaformerModelConfig(py::module *m) { using PyClass = OfflineParaformerModelConfig; py::class_(*m, "OfflineParaformerModelConfig") + .def(py::init<>()) .def(py::init(), py::arg("model")) .def_readwrite("model", &PyClass::model) .def("__str__", &PyClass::ToString); diff --git a/sherpa-onnx/python/csrc/offline-sense-voice-model-config.cc b/sherpa-onnx/python/csrc/offline-sense-voice-model-config.cc new file mode 100644 index 000000000..e9e9d0d05 --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-sense-voice-model-config.cc @@ -0,0 +1,26 @@ +// sherpa-onnx/python/csrc/offline-sense-voice-model-config.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-sense-voice-model-config.h" + +#include +#include + +#include "sherpa-onnx/python/csrc/offline-sense-voice-model-config.h" + +namespace sherpa_onnx { + +void PybindOfflineSenseVoiceModelConfig(py::module *m) { + using PyClass = OfflineSenseVoiceModelConfig; + py::class_(*m, "OfflineSenseVoiceModelConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("model"), py::arg("language"), py::arg("use_itn")) + .def_readwrite("model", &PyClass::model) + .def_readwrite("language", &PyClass::language) + .def_readwrite("use_itn", &PyClass::use_itn) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/offline-sense-voice-model-config.h b/sherpa-onnx/python/csrc/offline-sense-voice-model-config.h new file mode 100644 index 000000000..eda8336f9 --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-sense-voice-model-config.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/offline-sense-voice-model-config.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindOfflineSenseVoiceModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py index f0e9a45f2..dc1fcdf12 100644 --- a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py @@ -10,6 +10,7 @@ OfflineModelConfig, OfflineNemoEncDecCtcModelConfig, OfflineParaformerModelConfig, + OfflineSenseVoiceModelConfig, ) from _sherpa_onnx import OfflineRecognizer as _Recognizer from _sherpa_onnx import ( @@ -173,6 +174,88 @@ def from_transducer( self.config = recognizer_config return self + @classmethod + def from_sense_voice( + cls, + model: str, + tokens: str, + num_threads: int = 1, + sample_rate: int = 16000, + feature_dim: int = 80, + decoding_method: str = "greedy_search", + debug: bool = False, + provider: str = "cpu", + language: str = "", + use_itn: bool = False, + rule_fsts: str = "", + rule_fars: str = "", + ): + """ + Please refer to + ``_ + to download pre-trained models. + + Args: + tokens: + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two + columns:: + + symbol integer_id + + model: + Path to ``model.onnx``. + num_threads: + Number of threads for neural network computation. + sample_rate: + Sample rate of the training data used to train the model. 
+ feature_dim: + Dimension of the feature used to train the model. + decoding_method: + Valid values are greedy_search. + debug: + True to show debug messages. + provider: + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + language: + If not empty, then valid values are: auto, zh, en, ja, ko, yue + use_itn: + True to enable inverse text normalization; False to disable it. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. + """ + self = cls.__new__(cls) + model_config = OfflineModelConfig( + sense_voice=OfflineSenseVoiceModelConfig( + model=model, + language=language, + use_itn=use_itn, + ), + tokens=tokens, + num_threads=num_threads, + debug=debug, + provider=provider, + ) + + feat_config = FeatureExtractorConfig( + sampling_rate=sample_rate, + feature_dim=feature_dim, + ) + + recognizer_config = OfflineRecognizerConfig( + feat_config=feat_config, + model_config=model_config, + decoding_method=decoding_method, + rule_fsts=rule_fsts, + rule_fars=rule_fars, + ) + self.recognizer = _Recognizer(recognizer_config) + self.config = recognizer_config + return self + @classmethod def from_paraformer( cls, diff --git a/swift-api-examples/SherpaOnnx.swift b/swift-api-examples/SherpaOnnx.swift index eba8d8916..c6f8b51eb 100644 --- a/swift-api-examples/SherpaOnnx.swift +++ b/swift-api-examples/SherpaOnnx.swift @@ -355,6 +355,18 @@ func sherpaOnnxOfflineTdnnModelConfig( ) } +func sherpaOnnxOfflineSenseVoiceModelConfig( + model: String = "", + language: String = "", + useInverseTextNormalization: Bool = false +) -> SherpaOnnxOfflineSenseVoiceModelConfig { + return SherpaOnnxOfflineSenseVoiceModelConfig( + model: toCPointer(model), + language: toCPointer(language), + use_itn: useInverseTextNormalization ? 
1 : 0 + ) +} + func sherpaOnnxOfflineLMConfig( model: String = "", scale: Float = 1.0 @@ -378,7 +390,8 @@ func sherpaOnnxOfflineModelConfig( modelType: String = "", modelingUnit: String = "cjkchar", bpeVocab: String = "", - teleSpeechCtc: String = "" + teleSpeechCtc: String = "", + senseVoice: SherpaOnnxOfflineSenseVoiceModelConfig = sherpaOnnxOfflineSenseVoiceModelConfig() ) -> SherpaOnnxOfflineModelConfig { return SherpaOnnxOfflineModelConfig( transducer: transducer, @@ -393,7 +406,8 @@ func sherpaOnnxOfflineModelConfig( model_type: toCPointer(modelType), modeling_unit: toCPointer(modelingUnit), bpe_vocab: toCPointer(bpeVocab), - telespeech_ctc: toCPointer(teleSpeechCtc) + telespeech_ctc: toCPointer(teleSpeechCtc), + sense_voice: senseVoice ) } diff --git a/swift-api-examples/decode-file-non-streaming.swift b/swift-api-examples/decode-file-non-streaming.swift index ca9d9475e..a60777832 100644 --- a/swift-api-examples/decode-file-non-streaming.swift +++ b/swift-api-examples/decode-file-non-streaming.swift @@ -17,6 +17,7 @@ func run() { var modelConfig: SherpaOnnxOfflineModelConfig var modelType = "whisper" // modelType = "paraformer" + // modelType = "sense_voice" if modelType == "whisper" { let encoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx" @@ -47,6 +48,19 @@ func run() { debug: 0, modelType: "paraformer" ) + } else if modelType == "sense_voice" { + let model = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.int8.onnx" + let tokens = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt" + let senseVoiceConfig = sherpaOnnxOfflineSenseVoiceModelConfig( + model: model, + useInverseTextNormalization: true + ) + + modelConfig = sherpaOnnxOfflineModelConfig( + tokens: tokens, + debug: 0, + senseVoice: senseVoiceConfig + ) } else { print("Please specify a supported modelType \(modelType)") return @@ -63,7 +77,10 @@ func run() { recognizer = SherpaOnnxOfflineRecognizer(config: &config) - let filePath = "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav" + var filePath = "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav" + if modelType == "sense_voice" { + filePath = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/zh.wav" + } let fileURL: NSURL = NSURL(fileURLWithPath: filePath) let audioFile = try! AVAudioFile(forReading: fileURL as URL)
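
---- Notes (editorial, not part of the patch) ----

ApplyLFR() and ApplyCMVN() in offline-recognizer-sense-voice-impl.h implement
the low-frame-rate (LFR) frame stacking and CMVN referenced in the comments
above ("Lower Frame Rate Neural Network Acoustic Models"). A minimal NumPy
sketch of the same arithmetic, with illustrative names (lfr_m/lfr_n stand for
the model's lfr_window_size/lfr_window_shift metadata values):

    import numpy as np

    def apply_lfr(feats: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
        # feats: (T, feat_dim) -> (T', feat_dim * lfr_m).
        # Concatenate lfr_m consecutive frames; advance lfr_n frames per step.
        t = (len(feats) - lfr_m) // lfr_n + 1
        return np.stack(
            [feats[i * lfr_n : i * lfr_n + lfr_m].reshape(-1) for i in range(t)]
        )

    def apply_cmvn(feats: np.ndarray, neg_mean: np.ndarray,
                   inv_stddev: np.ndarray) -> np.ndarray:
        # Same arithmetic as ApplyCMVN(): p[k] = (p[k] + neg_mean[k]) * inv_stddev[k]
        return (feats + neg_mean) * inv_stddev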
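Building on apply_lfr()/apply_cmvn() from the previous note, the four inputs
that OfflineSenseVoiceModel::Forward() feeds to onnxruntime can be sketched in
Python roughly as below. The input names, lfr values, and ids here are
assumptions, not part of this patch; the C++ code reads the real names and
metadata from the model itself (compare with sess.get_inputs()):

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession("model.onnx")  # hypothetical path

    fbank = np.random.rand(100, 80).astype(np.float32)  # stand-in 80-dim fbank
    # Identity CMVN stand-ins; the real vectors come from the model metadata.
    neg_mean = np.zeros(80 * 7, dtype=np.float32)
    inv_stddev = np.ones(80 * 7, dtype=np.float32)
    feats = apply_cmvn(apply_lfr(fbank, lfr_m=7, lfr_n=6), neg_mean, inv_stddev)

    (logits,) = sess.run(None, {
        "x": feats[None],                                       # (N, T', C)
        "x_length": np.array([feats.shape[0]], dtype=np.int32),
        "language": np.array([0], dtype=np.int32),              # 0 == "auto"
        "text_norm": np.array([14], dtype=np.int32),            # with_itn/without_itn id from metadata
    })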
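ConvertSenseVoiceResult() starts reading tokens and timestamps at index 4:
assuming the usual SenseVoice output layout, the first four decoded units are
the special language / emotion / event / text-norm prefix tokens rather than
text, and the frame indices are shifted by the same offset. The same
conversion in Python terms (names illustrative; subsampling_factor is
meta_data.window_shift):

    def convert_sense_voice_result(token_ids, timestamps, id2sym,
                                   frame_shift_ms, subsampling_factor):
        frame_shift_s = frame_shift_ms / 1000.0 * subsampling_factor
        text = "".join(id2sym[i] for i in token_ids[4:])
        times = [frame_shift_s * (t - 4) for t in timestamps[4:]]
        return text, times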
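On the QInt8 -> QUInt8 switch in scripts/sense-voice/export-onnx.py: a hedged
sketch of how the two weight types could be produced side by side to
reproduce the comparison (output file names are hypothetical; per the patch,
only QUInt8 gives correct results with the C++ onnxruntime):

    from onnxruntime.quantization import QuantType, quantize_dynamic

    for weight_type, out in [(QuantType.QInt8, "model.qint8.onnx"),
                             (QuantType.QUInt8, "model.quint8.onnx")]:
        quantize_dynamic(
            model_input="model.onnx",
            model_output=out,
            op_types_to_quantize=["MatMul"],
            weight_type=weight_type,
        )
    # Decode the same wave with each file (e.g. with
    # python-api-examples/offline-sense-voice-ctc-decode-files.py) and compare.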