From 669f5ef44105138116d9d8dd5fe224ec3f6354a6 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 26 Oct 2024 14:34:07 +0800 Subject: [PATCH] Add C++ runtime and Python APIs for Moonshine models (#1473) --- .github/scripts/test-offline-moonshine.sh | 50 ++++ .github/scripts/test-python.sh | 10 + .github/workflows/linux.yaml | 13 + .github/workflows/macos.yaml | 11 +- .github/workflows/windows-x64.yaml | 8 + .github/workflows/windows-x86.yaml | 8 + python-api-examples/generate-subtitles.py | 117 +++++++- python-api-examples/non_streaming_server.py | 108 ++++++- .../offline-moonshine-decode-files.py | 82 +++++ .../offline-whisper-decode-files.py | 77 +++++ .../vad-with-non-streaming-asr.py | 83 +++++- sherpa-onnx/csrc/CMakeLists.txt | 3 + sherpa-onnx/csrc/offline-model-config.cc | 6 + sherpa-onnx/csrc/offline-model-config.h | 4 + sherpa-onnx/csrc/offline-moonshine-decoder.h | 34 +++ ...offline-moonshine-greedy-search-decoder.cc | 87 ++++++ .../offline-moonshine-greedy-search-decoder.h | 29 ++ .../csrc/offline-moonshine-model-config.cc | 88 ++++++ .../csrc/offline-moonshine-model-config.h | 37 +++ sherpa-onnx/csrc/offline-moonshine-model.cc | 282 ++++++++++++++++++ sherpa-onnx/csrc/offline-moonshine-model.h | 93 ++++++ sherpa-onnx/csrc/offline-recognizer-impl.cc | 15 + .../csrc/offline-recognizer-moonshine-impl.h | 150 ++++++++++ sherpa-onnx/csrc/offline-stream.cc | 27 +- sherpa-onnx/csrc/offline-stream.h | 12 +- sherpa-onnx/csrc/offline-whisper-model.cc | 9 +- sherpa-onnx/csrc/sherpa-onnx-offline.cc | 19 +- sherpa-onnx/csrc/symbol-table.cc | 2 + sherpa-onnx/python/csrc/CMakeLists.txt | 1 + .../python/csrc/offline-model-config.cc | 7 +- .../csrc/offline-moonshine-model-config.cc | 28 ++ .../csrc/offline-moonshine-model-config.h | 16 + .../python/sherpa_onnx/offline_recognizer.py | 92 +++++- 33 files changed, 1572 insertions(+), 36 deletions(-) create mode 100755 .github/scripts/test-offline-moonshine.sh create mode 100644 python-api-examples/offline-moonshine-decode-files.py create mode 100644 python-api-examples/offline-whisper-decode-files.py create mode 100644 sherpa-onnx/csrc/offline-moonshine-decoder.h create mode 100644 sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.cc create mode 100644 sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.h create mode 100644 sherpa-onnx/csrc/offline-moonshine-model-config.cc create mode 100644 sherpa-onnx/csrc/offline-moonshine-model-config.h create mode 100644 sherpa-onnx/csrc/offline-moonshine-model.cc create mode 100644 sherpa-onnx/csrc/offline-moonshine-model.h create mode 100644 sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h create mode 100644 sherpa-onnx/python/csrc/offline-moonshine-model-config.cc create mode 100644 sherpa-onnx/python/csrc/offline-moonshine-model-config.h diff --git a/.github/scripts/test-offline-moonshine.sh b/.github/scripts/test-offline-moonshine.sh new file mode 100755 index 000000000..1768e82ec --- /dev/null +++ b/.github/scripts/test-offline-moonshine.sh @@ -0,0 +1,50 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +export GIT_CLONE_PROTECTION_ACTIVE=false + +echo "EXE is $EXE" +echo "PATH: $PATH" + +which $EXE + +names=( +tiny +base +) + +for name in ${names[@]}; do + log "------------------------------------------------------------" + log "Run $name" + log "------------------------------------------------------------" + + 
repo_url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-$name-en-int8.tar.bz2
+  curl -SL -O $repo_url
+  tar xvf sherpa-onnx-moonshine-$name-en-int8.tar.bz2
+  rm sherpa-onnx-moonshine-$name-en-int8.tar.bz2
+  repo=sherpa-onnx-moonshine-$name-en-int8
+  log "Start testing ${repo_url}"
+
+  log "test int8 onnx"
+
+  time $EXE \
+    --moonshine-preprocessor=$repo/preprocess.onnx \
+    --moonshine-encoder=$repo/encode.int8.onnx \
+    --moonshine-uncached-decoder=$repo/uncached_decode.int8.onnx \
+    --moonshine-cached-decoder=$repo/cached_decode.int8.onnx \
+    --tokens=$repo/tokens.txt \
+    --num-threads=2 \
+    $repo/test_wavs/0.wav \
+    $repo/test_wavs/1.wav \
+    $repo/test_wavs/8k.wav
+
+  rm -rf $repo
+done
diff --git a/.github/scripts/test-python.sh b/.github/scripts/test-python.sh
index 8c9d303b0..91f6f66bc 100755
--- a/.github/scripts/test-python.sh
+++ b/.github/scripts/test-python.sh
@@ -8,6 +8,16 @@ log() {
   echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
 }
 
+log "test offline Moonshine"
+
+curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
+tar xvf sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
+rm sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
+
+python3 ./python-api-examples/offline-moonshine-decode-files.py
+
+rm -rf sherpa-onnx-moonshine-tiny-en-int8
+
 log "test offline speaker diarization"
 
 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
diff --git a/.github/workflows/linux.yaml b/.github/workflows/linux.yaml
index 2ad40e215..98c88e589 100644
--- a/.github/workflows/linux.yaml
+++ b/.github/workflows/linux.yaml
@@ -149,6 +149,19 @@ jobs:
           name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
           path: install/*
 
+      - name: Test offline Moonshine
+        if: matrix.build_type != 'Debug'
+        shell: bash
+        run: |
+          du -h -d1 .
+          export PATH=$PWD/build/bin:$PATH
+          export EXE=sherpa-onnx-offline
+
+          readelf -d build/bin/sherpa-onnx-offline
+
+          .github/scripts/test-offline-moonshine.sh
+          du -h -d1 .
+ - name: Test offline CTC shell: bash run: | diff --git a/.github/workflows/macos.yaml b/.github/workflows/macos.yaml index 849631015..9e17491bb 100644 --- a/.github/workflows/macos.yaml +++ b/.github/workflows/macos.yaml @@ -121,6 +121,15 @@ jobs: otool -L build/bin/sherpa-onnx otool -l build/bin/sherpa-onnx + - name: Test offline Moonshine + if: matrix.build_type != 'Debug' + shell: bash + run: | + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx-offline + + .github/scripts/test-offline-moonshine.sh + - name: Test C++ API shell: bash run: | @@ -243,8 +252,6 @@ jobs: .github/scripts/test-offline-whisper.sh - - - name: Test online transducer shell: bash run: | diff --git a/.github/workflows/windows-x64.yaml b/.github/workflows/windows-x64.yaml index 9435dcefd..50bf014d4 100644 --- a/.github/workflows/windows-x64.yaml +++ b/.github/workflows/windows-x64.yaml @@ -93,6 +93,14 @@ jobs: name: release-windows-x64-${{ matrix.shared_lib }}-${{ matrix.with_tts }} path: build/install/* + - name: Test offline Moonshine for windows x64 + shell: bash + run: | + export PATH=$PWD/build/bin/Release:$PATH + export EXE=sherpa-onnx-offline.exe + + .github/scripts/test-offline-moonshine.sh + - name: Test C++ API shell: bash run: | diff --git a/.github/workflows/windows-x86.yaml b/.github/workflows/windows-x86.yaml index 36089b2dd..8a1370959 100644 --- a/.github/workflows/windows-x86.yaml +++ b/.github/workflows/windows-x86.yaml @@ -93,6 +93,14 @@ jobs: name: release-windows-x86-${{ matrix.shared_lib }}-${{ matrix.with_tts }} path: build/install/* + - name: Test offline Moonshine for windows x86 + shell: bash + run: | + export PATH=$PWD/build/bin/Release:$PATH + export EXE=sherpa-onnx-offline.exe + + .github/scripts/test-offline-moonshine.sh + - name: Test C++ API shell: bash run: | diff --git a/python-api-examples/generate-subtitles.py b/python-api-examples/generate-subtitles.py index 85871016a..bf51f7627 100755 --- a/python-api-examples/generate-subtitles.py +++ b/python-api-examples/generate-subtitles.py @@ -47,7 +47,19 @@ --feature-dim=80 \ /path/to/test.mp4 -(3) For Whisper models +(3) For Moonshine models + +./python-api-examples/generate-subtitles.py \ + --silero-vad-model=/path/to/silero_vad.onnx \ + --moonshine-preprocessor=./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx \ + --moonshine-encoder=./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx \ + --moonshine-uncached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx \ + --moonshine-cached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx \ + --tokens=./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt \ + --num-threads=2 \ + /path/to/test.mp4 + +(4) For Whisper models ./python-api-examples/generate-subtitles.py \ --silero-vad-model=/path/to/silero_vad.onnx \ @@ -58,7 +70,7 @@ --num-threads=2 \ /path/to/test.mp4 -(4) For SenseVoice CTC models +(5) For SenseVoice CTC models ./python-api-examples/generate-subtitles.py \ --silero-vad-model=/path/to/silero_vad.onnx \ @@ -68,7 +80,7 @@ /path/to/test.mp4 -(5) For WeNet CTC models +(6) For WeNet CTC models ./python-api-examples/generate-subtitles.py \ --silero-vad-model=/path/to/silero_vad.onnx \ @@ -83,6 +95,7 @@ used in this file. 
""" import argparse +import datetime as dt import shutil import subprocess import sys @@ -157,7 +170,7 @@ def get_args(): parser.add_argument( "--num-threads", type=int, - default=1, + default=2, help="Number of threads for neural network computation", ) @@ -208,6 +221,34 @@ def get_args(): """, ) + parser.add_argument( + "--moonshine-preprocessor", + default="", + type=str, + help="Path to moonshine preprocessor model", + ) + + parser.add_argument( + "--moonshine-encoder", + default="", + type=str, + help="Path to moonshine encoder model", + ) + + parser.add_argument( + "--moonshine-uncached-decoder", + default="", + type=str, + help="Path to moonshine uncached decoder model", + ) + + parser.add_argument( + "--moonshine-cached-decoder", + default="", + type=str, + help="Path to moonshine cached decoder model", + ) + parser.add_argument( "--decoding-method", type=str, @@ -263,6 +304,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: assert len(args.wenet_ctc) == 0, args.wenet_ctc assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder assert_file_exists(args.encoder) assert_file_exists(args.decoder) @@ -284,6 +331,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: assert len(args.wenet_ctc) == 0, args.wenet_ctc assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder assert_file_exists(args.paraformer) @@ -300,6 +353,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: assert len(args.wenet_ctc) == 0, args.wenet_ctc assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder assert_file_exists(args.sense_voice) recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice( @@ -312,6 +371,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: elif args.wenet_ctc: assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder assert_file_exists(args.wenet_ctc) @@ -327,6 +392,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: elif args.whisper_encoder: assert_file_exists(args.whisper_encoder) assert_file_exists(args.whisper_decoder) + assert 
len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder recognizer = sherpa_onnx.OfflineRecognizer.from_whisper( encoder=args.whisper_encoder, @@ -339,6 +410,22 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: task=args.whisper_task, tail_paddings=args.whisper_tail_paddings, ) + elif args.moonshine_preprocessor: + assert_file_exists(args.moonshine_preprocessor) + assert_file_exists(args.moonshine_encoder) + assert_file_exists(args.moonshine_uncached_decoder) + assert_file_exists(args.moonshine_cached_decoder) + + recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine( + preprocessor=args.moonshine_preprocessor, + encoder=args.moonshine_encoder, + uncached_decoder=args.moonshine_uncached_decoder, + cached_decoder=args.moonshine_cached_decoder, + tokens=args.tokens, + num_threads=args.num_threads, + decoding_method=args.decoding_method, + debug=args.debug, + ) else: raise ValueError("Please specify at least one model") @@ -424,28 +511,32 @@ def main(): segment_list = [] print("Started!") + start_t = dt.datetime.now() + num_processed_samples = 0 - is_silence = False + is_eof = False # TODO(fangjun): Support multithreads while True: # *2 because int16_t has two bytes data = process.stdout.read(frames_per_read * 2) if not data: - if is_silence: + if is_eof: break - is_silence = True - # The converted audio file does not have a mute data of 1 second or more at the end, which will result in the loss of the last segment data + is_eof = True + # pad 1 second at the end of the file for the VAD data = np.zeros(1 * args.sample_rate, dtype=np.int16) samples = np.frombuffer(data, dtype=np.int16) samples = samples.astype(np.float32) / 32768 + num_processed_samples += samples.shape[0] + buffer = np.concatenate([buffer, samples]) while len(buffer) > window_size: vad.accept_waveform(buffer[:window_size]) buffer = buffer[window_size:] - if is_silence: + if is_eof: vad.flush() streams = [] @@ -471,6 +562,11 @@ def main(): seg.text = stream.result.text segment_list.append(seg) + end_t = dt.datetime.now() + elapsed_seconds = (end_t - start_t).total_seconds() + duration = num_processed_samples / 16000 + rtf = elapsed_seconds / duration + srt_filename = Path(args.sound_file).with_suffix(".srt") with open(srt_filename, "w", encoding="utf-8") as f: for i, seg in enumerate(segment_list): @@ -479,6 +575,9 @@ def main(): print("", file=f) print(f"Saved to {srt_filename}") + print(f"Audio duration:\t{duration:.3f} s") + print(f"Elapsed:\t{elapsed_seconds:.3f} s") + print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}") print("Done!") diff --git a/python-api-examples/non_streaming_server.py b/python-api-examples/non_streaming_server.py index 2194d6f54..3dd12564e 100755 --- a/python-api-examples/non_streaming_server.py +++ b/python-api-examples/non_streaming_server.py @@ -66,7 +66,21 @@ --wenet-ctc ./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \ --tokens ./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt -(5) Use a Whisper model +(5) Use a Moonshine model + +cd /path/to/sherpa-onnx +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2 +tar xvf sherpa-onnx-moonshine-tiny-en-int8.tar.bz2 +rm sherpa-onnx-moonshine-tiny-en-int8.tar.bz2 + +python3 
./python-api-examples/non_streaming_server.py \ + --moonshine-preprocessor=./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx \ + --moonshine-encoder=./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx \ + --moonshine-uncached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx \ + --moonshine-cached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx \ + --tokens=./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt + +(6) Use a Whisper model cd /path/to/sherpa-onnx curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2 @@ -78,7 +92,7 @@ --whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \ --tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt -(5) Use a tdnn model of the yesno recipe from icefall +(7) Use a tdnn model of the yesno recipe from icefall cd /path/to/sherpa-onnx @@ -92,7 +106,7 @@ --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \ --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt -(6) Use a Non-streaming SenseVoice model +(8) Use a Non-streaming SenseVoice model curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 @@ -254,6 +268,36 @@ def add_tdnn_ctc_model_args(parser: argparse.ArgumentParser): ) +def add_moonshine_model_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--moonshine-preprocessor", + default="", + type=str, + help="Path to moonshine preprocessor model", + ) + + parser.add_argument( + "--moonshine-encoder", + default="", + type=str, + help="Path to moonshine encoder model", + ) + + parser.add_argument( + "--moonshine-uncached-decoder", + default="", + type=str, + help="Path to moonshine uncached decoder model", + ) + + parser.add_argument( + "--moonshine-cached-decoder", + default="", + type=str, + help="Path to moonshine cached decoder model", + ) + + def add_whisper_model_args(parser: argparse.ArgumentParser): parser.add_argument( "--whisper-encoder", @@ -311,6 +355,7 @@ def add_model_args(parser: argparse.ArgumentParser): add_wenet_ctc_model_args(parser) add_tdnn_ctc_model_args(parser) add_whisper_model_args(parser) + add_moonshine_model_args(parser) parser.add_argument( "--tokens", @@ -876,6 +921,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_decoder) == 0, args.whisper_decoder assert len(args.tdnn_model) == 0, args.tdnn_model + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder assert_file_exists(args.encoder) assert_file_exists(args.decoder) @@ -903,6 +954,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_decoder) == 0, args.whisper_decoder assert len(args.tdnn_model) == 0, args.tdnn_model + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder 
assert_file_exists(args.paraformer) @@ -921,6 +978,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_decoder) == 0, args.whisper_decoder assert len(args.tdnn_model) == 0, args.tdnn_model + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder assert_file_exists(args.sense_voice) recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice( @@ -934,6 +997,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_decoder) == 0, args.whisper_decoder assert len(args.tdnn_model) == 0, args.tdnn_model + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder assert_file_exists(args.nemo_ctc) @@ -950,6 +1019,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_decoder) == 0, args.whisper_decoder assert len(args.tdnn_model) == 0, args.tdnn_model + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder assert_file_exists(args.wenet_ctc) @@ -966,6 +1041,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: assert len(args.tdnn_model) == 0, args.tdnn_model assert_file_exists(args.whisper_encoder) assert_file_exists(args.whisper_decoder) + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder recognizer = sherpa_onnx.OfflineRecognizer.from_whisper( encoder=args.whisper_encoder, @@ -980,6 +1061,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: ) elif args.tdnn_model: assert_file_exists(args.tdnn_model) + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder recognizer = sherpa_onnx.OfflineRecognizer.from_tdnn_ctc( model=args.tdnn_model, @@ -990,6 +1077,21 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: decoding_method=args.decoding_method, provider=args.provider, ) + elif args.moonshine_preprocessor: + assert_file_exists(args.moonshine_preprocessor) + assert_file_exists(args.moonshine_encoder) + assert_file_exists(args.moonshine_uncached_decoder) + assert_file_exists(args.moonshine_cached_decoder) + + recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine( + preprocessor=args.moonshine_preprocessor, + 
encoder=args.moonshine_encoder, + uncached_decoder=args.moonshine_uncached_decoder, + cached_decoder=args.moonshine_cached_decoder, + tokens=args.tokens, + num_threads=args.num_threads, + decoding_method=args.decoding_method, + ) else: raise ValueError("Please specify at least one model") diff --git a/python-api-examples/offline-moonshine-decode-files.py b/python-api-examples/offline-moonshine-decode-files.py new file mode 100644 index 000000000..f4d153d87 --- /dev/null +++ b/python-api-examples/offline-moonshine-decode-files.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 + +""" +This file shows how to use a non-streaming Moonshine model from +https://github.com/usefulsensors/moonshine +to decode files. + +Please download model files from +https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + +For instance, + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2 +tar xvf sherpa-onnx-moonshine-tiny-en-int8.tar.bz2 +rm sherpa-onnx-moonshine-tiny-en-int8.tar.bz2 +""" + +import datetime as dt +from pathlib import Path + +import sherpa_onnx +import soundfile as sf + + +def create_recognizer(): + preprocessor = "./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx" + encoder = "./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx" + uncached_decoder = "./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx" + cached_decoder = "./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx" + + tokens = "./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt" + test_wav = "./sherpa-onnx-moonshine-tiny-en-int8/test_wavs/0.wav" + + if not Path(preprocessor).is_file() or not Path(test_wav).is_file(): + raise ValueError( + """Please download model files from + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + """ + ) + return ( + sherpa_onnx.OfflineRecognizer.from_moonshine( + preprocessor=preprocessor, + encoder=encoder, + uncached_decoder=uncached_decoder, + cached_decoder=cached_decoder, + tokens=tokens, + debug=True, + ), + test_wav, + ) + + +def main(): + recognizer, wave_filename = create_recognizer() + + audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True) + audio = audio[:, 0] # only use the first channel + + # audio is a 1-D float32 numpy array normalized to the range [-1, 1] + # sample_rate does not need to be 16000 Hz + + start_t = dt.datetime.now() + + stream = recognizer.create_stream() + stream.accept_waveform(sample_rate, audio) + recognizer.decode_stream(stream) + + end_t = dt.datetime.now() + elapsed_seconds = (end_t - start_t).total_seconds() + duration = audio.shape[-1] / sample_rate + rtf = elapsed_seconds / duration + + print(stream.result) + print(wave_filename) + print("Text:", stream.result.text) + print(f"Audio duration:\t{duration:.3f} s") + print(f"Elapsed:\t{elapsed_seconds:.3f} s") + print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}") + + +if __name__ == "__main__": + main() diff --git a/python-api-examples/offline-whisper-decode-files.py b/python-api-examples/offline-whisper-decode-files.py new file mode 100644 index 000000000..aa85ccdf9 --- /dev/null +++ b/python-api-examples/offline-whisper-decode-files.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 + +""" +This file shows how to use a non-streaming whisper model from +https://github.com/openai/whisper +to decode files. 
+ +Please download model files from +https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + +For instance, + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2 +tar xvf sherpa-onnx-whisper-tiny.en.tar.bz2 +rm sherpa-onnx-whisper-tiny.en.tar.bz2 +""" + +import datetime as dt +from pathlib import Path + +import sherpa_onnx +import soundfile as sf + + +def create_recognizer(): + encoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx" + decoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx" + tokens = "./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt" + test_wav = "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav" + + if not Path(encoder).is_file() or not Path(test_wav).is_file(): + raise ValueError( + """Please download model files from + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + """ + ) + return ( + sherpa_onnx.OfflineRecognizer.from_whisper( + encoder=encoder, + decoder=decoder, + tokens=tokens, + debug=True, + ), + test_wav, + ) + + +def main(): + recognizer, wave_filename = create_recognizer() + + audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True) + audio = audio[:, 0] # only use the first channel + + # audio is a 1-D float32 numpy array normalized to the range [-1, 1] + # sample_rate does not need to be 16000 Hz + + start_t = dt.datetime.now() + + stream = recognizer.create_stream() + stream.accept_waveform(sample_rate, audio) + recognizer.decode_stream(stream) + + end_t = dt.datetime.now() + elapsed_seconds = (end_t - start_t).total_seconds() + duration = audio.shape[-1] / sample_rate + rtf = elapsed_seconds / duration + + print(stream.result) + print(wave_filename) + print("Text:", stream.result.text) + print(f"Audio duration:\t{duration:.3f} s") + print(f"Elapsed:\t{elapsed_seconds:.3f} s") + print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}") + + +if __name__ == "__main__": + main() diff --git a/python-api-examples/vad-with-non-streaming-asr.py b/python-api-examples/vad-with-non-streaming-asr.py index 7bb125d1a..f5bde30c6 100755 --- a/python-api-examples/vad-with-non-streaming-asr.py +++ b/python-api-examples/vad-with-non-streaming-asr.py @@ -35,7 +35,18 @@ --sample-rate=16000 \ --feature-dim=80 -(3) For Whisper models +(3) For Moonshine models + +./python-api-examples/vad-with-non-streaming-asr.py \ + --silero-vad-model=/path/to/silero_vad.onnx \ + --moonshine-preprocessor=./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx \ + --moonshine-encoder=./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx \ + --moonshine-uncached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx \ + --moonshine-cached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx \ + --tokens=./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt \ + --num-threads=2 + +(4) For Whisper models ./python-api-examples/vad-with-non-streaming-asr.py \ --silero-vad-model=/path/to/silero_vad.onnx \ @@ -45,7 +56,7 @@ --whisper-task=transcribe \ --num-threads=2 -(4) For SenseVoice CTC models +(5) For SenseVoice CTC models ./python-api-examples/vad-with-non-streaming-asr.py \ --silero-vad-model=/path/to/silero_vad.onnx \ @@ -192,6 +203,34 @@ def get_args(): """, ) + parser.add_argument( + "--moonshine-preprocessor", + default="", + type=str, + help="Path to moonshine preprocessor model", + ) + + parser.add_argument( + "--moonshine-encoder", + default="", + type=str, + help="Path to moonshine encoder model", + ) + + parser.add_argument( + 
"--moonshine-uncached-decoder", + default="", + type=str, + help="Path to moonshine uncached decoder model", + ) + + parser.add_argument( + "--moonshine-cached-decoder", + default="", + type=str, + help="Path to moonshine cached decoder model", + ) + parser.add_argument( "--blank-penalty", type=float, @@ -251,6 +290,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: assert len(args.sense_voice) == 0, args.sense_voice assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder assert_file_exists(args.encoder) assert_file_exists(args.decoder) @@ -272,6 +317,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: assert len(args.sense_voice) == 0, args.sense_voice assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder assert_file_exists(args.paraformer) @@ -287,6 +338,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: elif args.sense_voice: assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder assert_file_exists(args.sense_voice) recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice( @@ -299,6 +356,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: elif args.whisper_encoder: assert_file_exists(args.whisper_encoder) assert_file_exists(args.whisper_decoder) + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder recognizer = sherpa_onnx.OfflineRecognizer.from_whisper( encoder=args.whisper_encoder, @@ -311,6 +374,22 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: task=args.whisper_task, tail_paddings=args.whisper_tail_paddings, ) + elif args.moonshine_preprocessor: + assert_file_exists(args.moonshine_preprocessor) + assert_file_exists(args.moonshine_encoder) + assert_file_exists(args.moonshine_uncached_decoder) + assert_file_exists(args.moonshine_cached_decoder) + + recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine( + preprocessor=args.moonshine_preprocessor, + encoder=args.moonshine_encoder, + uncached_decoder=args.moonshine_uncached_decoder, + cached_decoder=args.moonshine_cached_decoder, + tokens=args.tokens, + num_threads=args.num_threads, + decoding_method=args.decoding_method, + debug=args.debug, + ) else: raise ValueError("Please 
specify at least one model")
diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt
index fafe5de96..f34c78609 100644
--- a/sherpa-onnx/csrc/CMakeLists.txt
+++ b/sherpa-onnx/csrc/CMakeLists.txt
@@ -29,6 +29,9 @@ set(sources
   offline-lm-config.cc
   offline-lm.cc
   offline-model-config.cc
+  offline-moonshine-greedy-search-decoder.cc
+  offline-moonshine-model-config.cc
+  offline-moonshine-model.cc
   offline-nemo-enc-dec-ctc-model-config.cc
   offline-nemo-enc-dec-ctc-model.cc
   offline-paraformer-greedy-search-decoder.cc
diff --git a/sherpa-onnx/csrc/offline-model-config.cc b/sherpa-onnx/csrc/offline-model-config.cc
index 862e4a60c..787290327 100644
--- a/sherpa-onnx/csrc/offline-model-config.cc
+++ b/sherpa-onnx/csrc/offline-model-config.cc
@@ -19,6 +19,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
   zipformer_ctc.Register(po);
   wenet_ctc.Register(po);
   sense_voice.Register(po);
+  moonshine.Register(po);
 
   po->Register("telespeech-ctc", &telespeech_ctc,
                "Path to model.onnx for telespeech ctc");
@@ -99,6 +100,10 @@ bool OfflineModelConfig::Validate() const {
     return sense_voice.Validate();
   }
 
+  if (!moonshine.preprocessor.empty()) {
+    return moonshine.Validate();
+  }
+
   if (!telespeech_ctc.empty() && !FileExists(telespeech_ctc)) {
     SHERPA_ONNX_LOGE("telespeech_ctc: '%s' does not exist",
                      telespeech_ctc.c_str());
@@ -124,6 +129,7 @@ std::string OfflineModelConfig::ToString() const {
   os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", ";
   os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
   os << "sense_voice=" << sense_voice.ToString() << ", ";
+  os << "moonshine=" << moonshine.ToString() << ", ";
   os << "telespeech_ctc=\"" << telespeech_ctc << "\", ";
   os << "tokens=\"" << tokens << "\", ";
   os << "num_threads=" << num_threads << ", ";
diff --git a/sherpa-onnx/csrc/offline-model-config.h b/sherpa-onnx/csrc/offline-model-config.h
index 8eb725e4e..cfff5eed2 100644
--- a/sherpa-onnx/csrc/offline-model-config.h
+++ b/sherpa-onnx/csrc/offline-model-config.h
@@ -6,6 +6,7 @@
 
 #include <string>
 
+#include "sherpa-onnx/csrc/offline-moonshine-model-config.h"
 #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
 #include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
 #include "sherpa-onnx/csrc/offline-sense-voice-model-config.h"
@@ -26,6 +27,7 @@ struct OfflineModelConfig {
   OfflineZipformerCtcModelConfig zipformer_ctc;
   OfflineWenetCtcModelConfig wenet_ctc;
   OfflineSenseVoiceModelConfig sense_voice;
+  OfflineMoonshineModelConfig moonshine;
   std::string telespeech_ctc;
 
   std::string tokens;
@@ -56,6 +58,7 @@ struct OfflineModelConfig {
                      const OfflineZipformerCtcModelConfig &zipformer_ctc,
                      const OfflineWenetCtcModelConfig &wenet_ctc,
                      const OfflineSenseVoiceModelConfig &sense_voice,
+                     const OfflineMoonshineModelConfig &moonshine,
                      const std::string &telespeech_ctc,
                      const std::string &tokens, int32_t num_threads,
                      bool debug, const std::string &provider,
                      const std::string &model_type,
@@ -69,6 +72,7 @@ struct OfflineModelConfig {
         zipformer_ctc(zipformer_ctc),
         wenet_ctc(wenet_ctc),
         sense_voice(sense_voice),
+        moonshine(moonshine),
         telespeech_ctc(telespeech_ctc),
         tokens(tokens),
         num_threads(num_threads),
diff --git a/sherpa-onnx/csrc/offline-moonshine-decoder.h b/sherpa-onnx/csrc/offline-moonshine-decoder.h
new file mode 100644
index 000000000..4d0b9ac93
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-moonshine-decoder.h
@@ -0,0 +1,34 @@
+// sherpa-onnx/csrc/offline-moonshine-decoder.h
+//
+// Copyright (c)  2023  Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_DECODER_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_DECODER_H_
+
+#include <vector>
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+
+namespace sherpa_onnx {
+
+struct OfflineMoonshineDecoderResult {
+  /// The decoded token IDs
+  std::vector<int32_t> tokens;
+};
+
+class OfflineMoonshineDecoder {
+ public:
+  virtual ~OfflineMoonshineDecoder() = default;
+
+  /** Run the decoder given the output from the moonshine encoder model.
+   *
+   * @param encoder_out A 3-D tensor of shape (batch_size, T, dim)
+   * @return Return a vector of size `batch_size` containing the decoded
+   *         results.
+   */
+  virtual std::vector<OfflineMoonshineDecoderResult> Decode(
+      Ort::Value encoder_out) = 0;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_DECODER_H_
diff --git a/sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.cc b/sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.cc
new file mode 100644
index 000000000..603bdd0cf
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.cc
@@ -0,0 +1,87 @@
+// sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.cc
+//
+// Copyright (c)  2023  Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.h"
+
+#include <algorithm>
+#include <utility>
+
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+
+namespace sherpa_onnx {
+
+std::vector<OfflineMoonshineDecoderResult>
+OfflineMoonshineGreedySearchDecoder::Decode(Ort::Value encoder_out) {
+  auto encoder_out_shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape();
+  if (encoder_out_shape[0] != 1) {
+    SHERPA_ONNX_LOGE("Support only batch size == 1. Given: %d\n",
+                     static_cast<int32_t>(encoder_out_shape[0]));
+    return {};
+  }
+
+  auto memory_info =
+      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+  // encoder_out_shape[1] * 384 is the number of audio samples
+  // (384 input samples per encoder frame is from the moonshine paper)
+  // and 16000 is the sample rate, so the audio lasts
+  // encoder_out_shape[1] * 384 / 16000 seconds; we allow the decoder to
+  // emit at most 6 tokens per second of audio.
+  int32_t max_len =
+      static_cast<int32_t>(encoder_out_shape[1] * 384 / 16000.0 * 6);
+
+  int32_t sos = 1;
+  int32_t eos = 2;
+  int32_t seq_len = 1;
+
+  std::vector<int32_t> tokens;
+
+  std::array<int64_t, 2> token_shape = {1, 1};
+  int64_t seq_len_shape = 1;
+
+  Ort::Value token_tensor = Ort::Value::CreateTensor(
+      memory_info, &sos, 1, token_shape.data(), token_shape.size());
+
+  Ort::Value seq_len_tensor =
+      Ort::Value::CreateTensor(memory_info, &seq_len, 1, &seq_len_shape, 1);
+
+  Ort::Value logits{nullptr};
+  std::vector<Ort::Value> states;
+
+  std::tie(logits, states) = model_->ForwardUnCachedDecoder(
+      std::move(token_tensor), std::move(seq_len_tensor), View(&encoder_out));
+
+  int32_t vocab_size = logits.GetTensorTypeAndShapeInfo().GetShape()[2];
+
+  for (int32_t i = 0; i != max_len; ++i) {
+    const float *p = logits.GetTensorData<float>();
+
+    int32_t max_token_id = static_cast<int32_t>(
+        std::distance(p, std::max_element(p, p + vocab_size)));
+    if (max_token_id == eos) {
+      break;
+    }
+    tokens.push_back(max_token_id);
+
+    seq_len += 1;
+
+    token_tensor = Ort::Value::CreateTensor(
+        memory_info, &tokens.back(), 1, token_shape.data(), token_shape.size());
+
+    seq_len_tensor =
+        Ort::Value::CreateTensor(memory_info, &seq_len, 1, &seq_len_shape, 1);
+
+    std::tie(logits, states) = model_->ForwardCachedDecoder(
+        std::move(token_tensor), std::move(seq_len_tensor), View(&encoder_out),
+        std::move(states));
+  }
+
+  OfflineMoonshineDecoderResult ans;
+  ans.tokens = std::move(tokens);
+
+  return {ans};
+}
+
+}  // namespace sherpa_onnx
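The token budget above follows from the Moonshine design: each encoder frame covers 384 input samples at 16 kHz, so T frames correspond to T * 384 / 16000 seconds of audio, and the decoder may emit at most 6 tokens per second. For a 10-second utterance, T * 384 = 160000 samples, so max_len = 60. The same greedy loop can be sketched in Python; this is only an illustration, where run_uncached and run_cached are hypothetical stand-ins for ForwardUnCachedDecoder/ForwardCachedDecoder, not a real API:

import numpy as np

def greedy_search(encoder_out, run_uncached, run_cached, sos=1, eos=2):
    # encoder_out: (1, T, dim); each frame covers 384 samples at 16 kHz
    num_frames = encoder_out.shape[1]
    max_len = int(num_frames * 384 / 16000.0 * 6)  # <= 6 tokens per second

    token = np.array([[sos]], dtype=np.int32)
    seq_len = np.array([1], dtype=np.int32)
    logits, states = run_uncached(token, seq_len, encoder_out)

    tokens = []
    for _ in range(max_len):
        next_token = int(logits[0, -1].argmax())  # greedy: take the arg max
        if next_token == eos:
            break
        tokens.append(next_token)
        seq_len = seq_len + 1
        token = np.array([[next_token]], dtype=np.int32)
        logits, states = run_cached(token, seq_len, encoder_out, states)
    return tokens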
diff --git a/sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.h b/sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.h
new file mode 100644
index 000000000..b215405db
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.h
@@ -0,0 +1,29 @@
+// sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.h
+//
+// Copyright (c)  2024  Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_GREEDY_SEARCH_DECODER_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_GREEDY_SEARCH_DECODER_H_
+
+#include <vector>
+
+#include "sherpa-onnx/csrc/offline-moonshine-decoder.h"
+#include "sherpa-onnx/csrc/offline-moonshine-model.h"
+
+namespace sherpa_onnx {
+
+class OfflineMoonshineGreedySearchDecoder : public OfflineMoonshineDecoder {
+ public:
+  explicit OfflineMoonshineGreedySearchDecoder(OfflineMoonshineModel *model)
+      : model_(model) {}
+
+  std::vector<OfflineMoonshineDecoderResult> Decode(
+      Ort::Value encoder_out) override;
+
+ private:
+  OfflineMoonshineModel *model_;  // not owned
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_GREEDY_SEARCH_DECODER_H_
diff --git a/sherpa-onnx/csrc/offline-moonshine-model-config.cc b/sherpa-onnx/csrc/offline-moonshine-model-config.cc
new file mode 100644
index 000000000..c687507e3
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-moonshine-model-config.cc
@@ -0,0 +1,88 @@
+// sherpa-onnx/csrc/offline-moonshine-model-config.cc
+//
+// Copyright (c)  2024  Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-moonshine-model-config.h"
+
+#include "sherpa-onnx/csrc/file-utils.h"
+#include "sherpa-onnx/csrc/macros.h"
+
+namespace sherpa_onnx {
+
+void OfflineMoonshineModelConfig::Register(ParseOptions *po) {
+  po->Register("moonshine-preprocessor", &preprocessor,
+               "Path to onnx preprocessor of moonshine, e.g., preprocess.onnx");
+
+  po->Register("moonshine-encoder", &encoder,
+               "Path to onnx encoder of moonshine, e.g., encode.onnx");
+
+  po->Register(
+      "moonshine-uncached-decoder", &uncached_decoder,
+      "Path to onnx uncached_decoder of moonshine, e.g., uncached_decode.onnx");
+
+  po->Register(
+      "moonshine-cached-decoder", &cached_decoder,
+      "Path to onnx cached_decoder of moonshine, e.g., cached_decode.onnx");
+}
+
+bool OfflineMoonshineModelConfig::Validate() const {
+  if (preprocessor.empty()) {
+    SHERPA_ONNX_LOGE("Please provide --moonshine-preprocessor");
+    return false;
+  }
+
+  if (!FileExists(preprocessor)) {
+    SHERPA_ONNX_LOGE("moonshine preprocessor file '%s' does not exist",
+                     preprocessor.c_str());
+    return false;
+  }
+
+  if (encoder.empty()) {
+    SHERPA_ONNX_LOGE("Please provide --moonshine-encoder");
+    return false;
+  }
+
+  if (!FileExists(encoder)) {
+    SHERPA_ONNX_LOGE("moonshine encoder file '%s' does not exist",
+                     encoder.c_str());
+    return false;
+  }
+
+  if (uncached_decoder.empty()) {
+    SHERPA_ONNX_LOGE("Please provide --moonshine-uncached-decoder");
+    return false;
+  }
+
+  if (!FileExists(uncached_decoder)) {
+    SHERPA_ONNX_LOGE("moonshine uncached decoder file '%s' does not exist",
+                     uncached_decoder.c_str());
+    return false;
+  }
+
+  if (cached_decoder.empty()) {
+    SHERPA_ONNX_LOGE("Please provide --moonshine-cached-decoder");
+    return false;
+  }
+
+  if (!FileExists(cached_decoder)) {
+    SHERPA_ONNX_LOGE("moonshine cached decoder file '%s' does not exist",
+                     cached_decoder.c_str());
+    return false;
+  }
+
+  return true;
+}
+
+std::string OfflineMoonshineModelConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "OfflineMoonshineModelConfig(";
+  os << "preprocessor=\"" << preprocessor << "\", ";
+  os << "encoder=\"" << encoder << "\", ";
+  os << "uncached_decoder=\"" << uncached_decoder << "\", ";
+  os << "cached_decoder=\"" << cached_decoder << "\")";
+
+  return os.str();
+}
+
+}  // namespace sherpa_onnx
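Validate() above reduces to: each of the four model paths must be non-empty and must name an existing file. An equivalent standalone Python check (a hypothetical pre-flight helper, not part of the sherpa-onnx API):

from pathlib import Path

def check_moonshine_files(preprocessor, encoder, uncached_decoder, cached_decoder):
    # Mirrors OfflineMoonshineModelConfig::Validate(): every flag must be
    # given and must point to an existing file.
    flags = {
        "--moonshine-preprocessor": preprocessor,
        "--moonshine-encoder": encoder,
        "--moonshine-uncached-decoder": uncached_decoder,
        "--moonshine-cached-decoder": cached_decoder,
    }
    for flag, path in flags.items():
        if not path:
            raise ValueError(f"Please provide {flag}")
        if not Path(path).is_file():
            raise ValueError(f"{flag}: '{path}' does not exist")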
diff --git a/sherpa-onnx/csrc/offline-moonshine-model-config.h b/sherpa-onnx/csrc/offline-moonshine-model-config.h
new file mode 100644
index 000000000..829ca520d
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-moonshine-model-config.h
@@ -0,0 +1,37 @@
+// sherpa-onnx/csrc/offline-moonshine-model-config.h
+//
+// Copyright (c)  2024  Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_
+
+#include <string>
+
+#include "sherpa-onnx/csrc/parse-options.h"
+
+namespace sherpa_onnx {
+
+struct OfflineMoonshineModelConfig {
+  std::string preprocessor;
+  std::string encoder;
+  std::string uncached_decoder;
+  std::string cached_decoder;
+
+  OfflineMoonshineModelConfig() = default;
+  OfflineMoonshineModelConfig(const std::string &preprocessor,
+                              const std::string &encoder,
+                              const std::string &uncached_decoder,
+                              const std::string &cached_decoder)
+      : preprocessor(preprocessor),
+        encoder(encoder),
+        uncached_decoder(uncached_decoder),
+        cached_decoder(cached_decoder) {}
+
+  void Register(ParseOptions *po);
+  bool Validate() const;
+
+  std::string ToString() const;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_
diff --git a/sherpa-onnx/csrc/offline-moonshine-model.cc b/sherpa-onnx/csrc/offline-moonshine-model.cc
new file mode 100644
index 000000000..ab71d000f
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-moonshine-model.cc
@@ -0,0 +1,282 @@
+// sherpa-onnx/csrc/offline-moonshine-model.cc
+//
+// Copyright (c)  2024  Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-moonshine-model.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
+#include "sherpa-onnx/csrc/text-utils.h"
+
+namespace sherpa_onnx {
+
+class OfflineMoonshineModel::Impl {
+ public:
+  explicit Impl(const OfflineModelConfig &config)
+      : config_(config),
+        env_(ORT_LOGGING_LEVEL_ERROR),
+        sess_opts_(GetSessionOptions(config)),
+        allocator_{} {
+    {
+      auto buf = ReadFile(config.moonshine.preprocessor);
+      InitPreprocessor(buf.data(), buf.size());
+    }
+
+    {
+      auto buf = ReadFile(config.moonshine.encoder);
+      InitEncoder(buf.data(), buf.size());
+    }
+
+    {
+      auto buf = ReadFile(config.moonshine.uncached_decoder);
+      InitUnCachedDecoder(buf.data(), buf.size());
+    }
+
+    {
+      auto buf = ReadFile(config.moonshine.cached_decoder);
+      InitCachedDecoder(buf.data(), buf.size());
+    }
+  }
+
+#if __ANDROID_API__ >= 9
+  Impl(AAssetManager *mgr, const OfflineModelConfig &config)
+      : config_(config),
+        env_(ORT_LOGGING_LEVEL_ERROR),
+        sess_opts_(GetSessionOptions(config)),
+        allocator_{} {
+    {
+      auto buf = ReadFile(mgr, config.moonshine.preprocessor);
+      InitPreprocessor(buf.data(), buf.size());
+    }
+
+    {
+      auto buf = ReadFile(mgr, config.moonshine.encoder);
+      InitEncoder(buf.data(), buf.size());
+    }
+
+    {
+      auto buf = ReadFile(mgr, config.moonshine.uncached_decoder);
+      InitUnCachedDecoder(buf.data(), buf.size());
+    }
+
+    {
+      auto buf = ReadFile(mgr, config.moonshine.cached_decoder);
+      InitCachedDecoder(buf.data(), buf.size());
+    }
+  }
+#endif
+
+  Ort::Value ForwardPreprocessor(Ort::Value audio) {
+    auto features = preprocessor_sess_->Run(
+        {}, preprocessor_input_names_ptr_.data(), &audio, 1,
+        preprocessor_output_names_ptr_.data(),
+        preprocessor_output_names_ptr_.size());
+
+    return std::move(features[0]);
+  }
+
+  Ort::Value ForwardEncoder(Ort::Value features, Ort::Value features_len) {
+    std::array<Ort::Value, 2> encoder_inputs{std::move(features),
+                                             std::move(features_len)};
+
+    auto encoder_out = encoder_sess_->Run(
+        {}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
+        encoder_inputs.size(), encoder_output_names_ptr_.data(),
+        encoder_output_names_ptr_.size());
+
+    return std::move(encoder_out[0]);
+  }
+
+  std::pair<Ort::Value, std::vector<Ort::Value>> ForwardUnCachedDecoder(
+      Ort::Value tokens, Ort::Value seq_len, Ort::Value encoder_out) {
+    std::array<Ort::Value, 3> uncached_decoder_input = {
+        std::move(tokens),
+        std::move(encoder_out),
+        std::move(seq_len),
+    };
+
+    auto uncached_decoder_out = uncached_decoder_sess_->Run(
+        {}, uncached_decoder_input_names_ptr_.data(),
+        uncached_decoder_input.data(), uncached_decoder_input.size(),
+        uncached_decoder_output_names_ptr_.data(),
+        uncached_decoder_output_names_ptr_.size());
+
+    std::vector<Ort::Value> states;
+    states.reserve(uncached_decoder_out.size() - 1);
+
+    int32_t i = -1;
+    for (auto &s : uncached_decoder_out) {
+      ++i;
+      if (i == 0) {
+        continue;
+      }
+
+      states.push_back(std::move(s));
+    }
+
+    return {std::move(uncached_decoder_out[0]), std::move(states)};
+  }
+
+  std::pair<Ort::Value, std::vector<Ort::Value>> ForwardCachedDecoder(
+      Ort::Value tokens, Ort::Value seq_len, Ort::Value encoder_out,
+      std::vector<Ort::Value> states) {
+    std::vector<Ort::Value> cached_decoder_input;
+    cached_decoder_input.reserve(3 + states.size());
+    cached_decoder_input.push_back(std::move(tokens));
+    cached_decoder_input.push_back(std::move(encoder_out));
+    cached_decoder_input.push_back(std::move(seq_len));
+
+    for (auto &s : states) {
+      cached_decoder_input.push_back(std::move(s));
+    }
+
+    auto cached_decoder_out = cached_decoder_sess_->Run(
+        {}, cached_decoder_input_names_ptr_.data(), cached_decoder_input.data(),
+        cached_decoder_input.size(), cached_decoder_output_names_ptr_.data(),
+        cached_decoder_output_names_ptr_.size());
+
+    std::vector<Ort::Value> next_states;
+    next_states.reserve(cached_decoder_out.size() - 1);
+
+    int32_t i = -1;
+    for (auto &s : cached_decoder_out) {
+      ++i;
+      if (i == 0) {
+        continue;
+      }
+
+      next_states.push_back(std::move(s));
+    }
+
+    return {std::move(cached_decoder_out[0]), std::move(next_states)};
+  }
+
+  OrtAllocator *Allocator() const { return allocator_; }
+
+ private:
+  void InitPreprocessor(void *model_data, size_t model_data_length) {
+    preprocessor_sess_ = std::make_unique<Ort::Session>(
+        env_, model_data, model_data_length, sess_opts_);
+
+    GetInputNames(preprocessor_sess_.get(), &preprocessor_input_names_,
+                  &preprocessor_input_names_ptr_);
+
+    GetOutputNames(preprocessor_sess_.get(), &preprocessor_output_names_,
+                   &preprocessor_output_names_ptr_);
+  }
+
+  void InitEncoder(void *model_data, size_t model_data_length) {
+    encoder_sess_ = std::make_unique<Ort::Session>(
+        env_, model_data, model_data_length, sess_opts_);
+
+    GetInputNames(encoder_sess_.get(), &encoder_input_names_,
+                  &encoder_input_names_ptr_);
+
+    GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
+                   &encoder_output_names_ptr_);
+  }
+
+  void InitUnCachedDecoder(void *model_data, size_t model_data_length) {
+    uncached_decoder_sess_ = std::make_unique<Ort::Session>(
+        env_, model_data, model_data_length, sess_opts_);
+
+    GetInputNames(uncached_decoder_sess_.get(), &uncached_decoder_input_names_,
+                  &uncached_decoder_input_names_ptr_);
+
+    GetOutputNames(uncached_decoder_sess_.get(),
+                   &uncached_decoder_output_names_,
+                   &uncached_decoder_output_names_ptr_);
+  }
+
+  void InitCachedDecoder(void *model_data, size_t model_data_length) {
+    cached_decoder_sess_ = std::make_unique<Ort::Session>(
+        env_, model_data, model_data_length, sess_opts_);
+
+    GetInputNames(cached_decoder_sess_.get(), &cached_decoder_input_names_,
+                  &cached_decoder_input_names_ptr_);
+
    GetOutputNames(cached_decoder_sess_.get(), &cached_decoder_output_names_,
+                   &cached_decoder_output_names_ptr_);
+  }
+
+ private:
+  OfflineModelConfig config_;
+  Ort::Env env_;
+  Ort::SessionOptions sess_opts_;
+  Ort::AllocatorWithDefaultOptions allocator_;
+
+  std::unique_ptr<Ort::Session> preprocessor_sess_;
+  std::unique_ptr<Ort::Session> encoder_sess_;
+  std::unique_ptr<Ort::Session> uncached_decoder_sess_;
+  std::unique_ptr<Ort::Session> cached_decoder_sess_;
+
+  std::vector<std::string> preprocessor_input_names_;
+  std::vector<const char *> preprocessor_input_names_ptr_;
+
+  std::vector<std::string> preprocessor_output_names_;
+  std::vector<const char *> preprocessor_output_names_ptr_;
+
+  std::vector<std::string> encoder_input_names_;
+  std::vector<const char *> encoder_input_names_ptr_;
+
+  std::vector<std::string> encoder_output_names_;
+  std::vector<const char *> encoder_output_names_ptr_;
+
+  std::vector<std::string> uncached_decoder_input_names_;
+  std::vector<const char *> uncached_decoder_input_names_ptr_;
+
+  std::vector<std::string> uncached_decoder_output_names_;
+  std::vector<const char *> uncached_decoder_output_names_ptr_;
+
+  std::vector<std::string> cached_decoder_input_names_;
+  std::vector<const char *> cached_decoder_input_names_ptr_;
+
+  std::vector<std::string> cached_decoder_output_names_;
+  std::vector<const char *> cached_decoder_output_names_ptr_;
+};
+
+OfflineMoonshineModel::OfflineMoonshineModel(const OfflineModelConfig &config)
+    : impl_(std::make_unique<Impl>(config)) {}
+
+#if __ANDROID_API__ >= 9
+OfflineMoonshineModel::OfflineMoonshineModel(AAssetManager *mgr,
+                                             const OfflineModelConfig &config)
+    : impl_(std::make_unique<Impl>(mgr, config)) {}
+#endif
+
+OfflineMoonshineModel::~OfflineMoonshineModel() = default;
+
+Ort::Value OfflineMoonshineModel::ForwardPreprocessor(Ort::Value audio) const {
+  return impl_->ForwardPreprocessor(std::move(audio));
+}
+
+Ort::Value OfflineMoonshineModel::ForwardEncoder(
+    Ort::Value features, Ort::Value features_len) const {
+  return impl_->ForwardEncoder(std::move(features), std::move(features_len));
+}
+
+std::pair<Ort::Value, std::vector<Ort::Value>>
+OfflineMoonshineModel::ForwardUnCachedDecoder(Ort::Value token,
+                                              Ort::Value seq_len,
+                                              Ort::Value encoder_out) const {
+  return impl_->ForwardUnCachedDecoder(std::move(token), std::move(seq_len),
+                                       std::move(encoder_out));
+}
+
+std::pair<Ort::Value, std::vector<Ort::Value>>
+OfflineMoonshineModel::ForwardCachedDecoder(
+    Ort::Value token, Ort::Value seq_len, Ort::Value encoder_out,
+    std::vector<Ort::Value> states) const {
+  return impl_->ForwardCachedDecoder(std::move(token), std::move(seq_len),
+                                     std::move(encoder_out), std::move(states));
+}
+
+OrtAllocator *OfflineMoonshineModel::Allocator() const {
+  return impl_->Allocator();
+}
+
+}  // namespace sherpa_onnx
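The Impl above looks up each session's input/output names once (via GetInputNames/GetOutputNames) and passes them to Run(). When inspecting an exported Moonshine ONNX file by hand, the same information is available from the onnxruntime Python API; a small sketch, with the path taken from the tiny int8 release archive used elsewhere in this patch:

import onnxruntime as ort

sess = ort.InferenceSession(
    "./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx",
    providers=["CPUExecutionProvider"],
)
# These names are exactly what the C++ wrappers feed to Run().
print([i.name for i in sess.get_inputs()])
print([o.name for o in sess.get_outputs()])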
diff --git a/sherpa-onnx/csrc/offline-moonshine-model.h b/sherpa-onnx/csrc/offline-moonshine-model.h
new file mode 100644
index 000000000..7065b1445
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-moonshine-model.h
@@ -0,0 +1,93 @@
+// sherpa-onnx/csrc/offline-moonshine-model.h
+//
+// Copyright (c)  2024  Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/offline-model-config.h"
+
+namespace sherpa_onnx {
+
+// please see
+// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/moonshine/test.py
+class OfflineMoonshineModel {
+ public:
+  explicit OfflineMoonshineModel(const OfflineModelConfig &config);
+
+#if __ANDROID_API__ >= 9
+  OfflineMoonshineModel(AAssetManager *mgr, const OfflineModelConfig &config);
+#endif
+
+  ~OfflineMoonshineModel();
+
+  /** Run the preprocessor model.
+   *
+   * @param audio A float32 tensor of shape (batch_size, num_samples)
+   *
+   * @return Return a float32 tensor of shape (batch_size, T, dim) that
+   *         can be used as the input of ForwardEncoder()
+   */
+  Ort::Value ForwardPreprocessor(Ort::Value audio) const;
+
+  /** Run the encoder model.
+   *
+   * @param features A float32 tensor of shape (batch_size, T, dim)
+   * @param features_len An int32 tensor of shape (batch_size,)
+   * @return Return a float32 tensor of shape (batch_size, T, dim)
+   */
+  Ort::Value ForwardEncoder(Ort::Value features, Ort::Value features_len) const;
+
+  /** Run the uncached decoder.
+   *
+   * @param token An int32 tensor of shape (batch_size, num_tokens)
+   * @param seq_len An int32 tensor of shape (batch_size,) containing the
+   *                number of predicted tokens so far
+   * @param encoder_out A float32 tensor of shape (batch_size, T, dim)
+   *
+   * @return Return a pair:
+   *
+   *         - logits, a float32 tensor of shape (batch_size, 1, vocab_size)
+   *         - states, a list of states
+   */
+  std::pair<Ort::Value, std::vector<Ort::Value>> ForwardUnCachedDecoder(
+      Ort::Value token, Ort::Value seq_len, Ort::Value encoder_out) const;
+
+  /** Run the cached decoder.
+   *
+   * @param token An int32 tensor of shape (batch_size, num_tokens)
+   * @param seq_len An int32 tensor of shape (batch_size,) containing the
+   *                number of predicted tokens so far
+   * @param encoder_out A float32 tensor of shape (batch_size, T, dim)
+   * @param states A list of previous states
+   *
+   * @return Return a pair:
+   *         - logits, a float32 tensor of shape (batch_size, 1, vocab_size)
+   *         - states, a list of new states
+   */
+  std::pair<Ort::Value, std::vector<Ort::Value>> ForwardCachedDecoder(
+      Ort::Value token, Ort::Value seq_len, Ort::Value encoder_out,
+      std::vector<Ort::Value> states) const;
+
+  /** Return an allocator for allocating memory
+   */
+  OrtAllocator *Allocator() const;
+
+ private:
+  class Impl;
+  std::unique_ptr<Impl> impl_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_H_
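At the Python API level, the preprocessor -> encoder -> uncached/cached decoder pipeline documented in this header is wrapped by OfflineRecognizer.from_moonshine(); a minimal end-to-end sketch (paths assume the tiny int8 release archive used in the examples above; test.wav is a placeholder):

import sherpa_onnx
import soundfile as sf

recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
    preprocessor="./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx",
    encoder="./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx",
    uncached_decoder="./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx",
    cached_decoder="./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx",
    tokens="./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt",
)

audio, sample_rate = sf.read("test.wav", dtype="float32", always_2d=True)
audio = audio[:, 0]  # use only the first channel

stream = recognizer.create_stream()
stream.accept_waveform(sample_rate, audio)  # resampled to 16 kHz if needed
recognizer.decode_stream(stream)
print(stream.result.text)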
diff --git a/sherpa-onnx/csrc/offline-recognizer-impl.cc b/sherpa-onnx/csrc/offline-recognizer-impl.cc
index f6c6e247c..07887df60 100644
--- a/sherpa-onnx/csrc/offline-recognizer-impl.cc
+++ b/sherpa-onnx/csrc/offline-recognizer-impl.cc
@@ -20,6 +20,7 @@
 #include "onnxruntime_cxx_api.h"  // NOLINT
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
+#include "sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h"
 #include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
 #include "sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h"
 #include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
@@ -51,6 +52,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
     return std::make_unique<OfflineRecognizerSenseVoiceImpl>(config);
   }
 
+  if (!config.model_config.moonshine.preprocessor.empty()) {
+    return std::make_unique<OfflineRecognizerMoonshineImpl>(config);
+  }
+
   // TODO(fangjun): Refactor it. We only need to use model type for the
   // following models:
   // 1. transducer and nemo_transducer
@@ -67,7 +72,11 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
       model_type == "telespeech_ctc") {
     return std::make_unique<OfflineRecognizerCtcImpl>(config);
   } else if (model_type == "whisper") {
+    // unreachable
     return std::make_unique<OfflineRecognizerWhisperImpl>(config);
+  } else if (model_type == "moonshine") {
+    // unreachable
+    return std::make_unique<OfflineRecognizerMoonshineImpl>(config);
   } else {
     SHERPA_ONNX_LOGE(
         "Invalid model_type: %s. Trying to load the model to get its type",
@@ -225,6 +234,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
     return std::make_unique<OfflineRecognizerSenseVoiceImpl>(mgr, config);
   }
 
+  if (!config.model_config.moonshine.preprocessor.empty()) {
+    return std::make_unique<OfflineRecognizerMoonshineImpl>(mgr, config);
+  }
+
   // TODO(fangjun): Refactor it. We only need to use model type for the
   // following models:
   // 1. transducer and nemo_transducer
@@ -242,6 +255,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
     return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
   } else if (model_type == "whisper") {
     return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
+  } else if (model_type == "moonshine") {
+    return std::make_unique<OfflineRecognizerMoonshineImpl>(mgr, config);
   } else {
     SHERPA_ONNX_LOGE(
         "Invalid model_type: %s. Trying to load the model to get its type",
Given %s", + config_.decoding_method.c_str()); + exit(-1); + } + } + + std::unique_ptr CreateStream() const override { + MoonshineTag tag; + return std::make_unique(tag); + } + + void DecodeStreams(OfflineStream **ss, int32_t n) const override { + // batch decoding is not implemented yet + for (int32_t i = 0; i != n; ++i) { + DecodeStream(ss[i]); + } + } + + OfflineRecognizerConfig GetConfig() const override { return config_; } + + private: + void DecodeStream(OfflineStream *s) const { + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + std::vector audio = s->GetFrames(); + + try { + std::array shape{1, static_cast(audio.size())}; + + Ort::Value audio_tensor = Ort::Value::CreateTensor( + memory_info, audio.data(), audio.size(), shape.data(), shape.size()); + + Ort::Value features = + model_->ForwardPreprocessor(std::move(audio_tensor)); + + int32_t features_len = features.GetTensorTypeAndShapeInfo().GetShape()[1]; + + int64_t features_shape = 1; + + Ort::Value features_len_tensor = Ort::Value::CreateTensor( + memory_info, &features_len, 1, &features_shape, 1); + + Ort::Value encoder_out = model_->ForwardEncoder( + std::move(features), std::move(features_len_tensor)); + + auto results = decoder_->Decode(std::move(encoder_out)); + + auto r = Convert(results[0], symbol_table_); + r.text = ApplyInverseTextNormalization(std::move(r.text)); + s->SetResult(r); + } catch (const Ort::Exception &ex) { + SHERPA_ONNX_LOGE( + "\n\nCaught exception:\n\n%s\n\nReturn an empty result. Number of " + "audio samples: %d", + ex.what(), static_cast(audio.size())); + return; + } + } + + private: + OfflineRecognizerConfig config_; + SymbolTable symbol_table_; + std::unique_ptr model_; + std::unique_ptr decoder_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_MOONSHINE_IMPL_H_ diff --git a/sherpa-onnx/csrc/offline-stream.cc b/sherpa-onnx/csrc/offline-stream.cc index 0f83807dc..6fa0ca6c8 100644 --- a/sherpa-onnx/csrc/offline-stream.cc +++ b/sherpa-onnx/csrc/offline-stream.cc @@ -133,6 +133,10 @@ class OfflineStream::Impl { fbank_ = std::make_unique(opts_); } + explicit Impl(MoonshineTag /*tag*/) : is_moonshine_(true) { + config_.sampling_rate = 16000; + } + void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { if (config_.normalize_samples) { AcceptWaveformImpl(sampling_rate, waveform, n); @@ -164,7 +168,9 @@ class OfflineStream::Impl { std::vector samples; resampler->Resample(waveform, n, true, &samples); - if (fbank_) { + if (is_moonshine_) { + samples_.insert(samples_.end(), samples.begin(), samples.end()); + } else if (fbank_) { fbank_->AcceptWaveform(config_.sampling_rate, samples.data(), samples.size()); fbank_->InputFinished(); @@ -181,7 +187,9 @@ class OfflineStream::Impl { return; } // if (sampling_rate != config_.sampling_rate) - if (fbank_) { + if (is_moonshine_) { + samples_.insert(samples_.end(), waveform, waveform + n); + } else if (fbank_) { fbank_->AcceptWaveform(sampling_rate, waveform, n); fbank_->InputFinished(); } else if (mfcc_) { @@ -194,10 +202,18 @@ class OfflineStream::Impl { } int32_t FeatureDim() const { + if (is_moonshine_) { + return samples_.size(); + } + return mfcc_ ? mfcc_opts_.num_ceps : opts_.mel_opts.num_bins; } std::vector GetFrames() const { + if (is_moonshine_) { + return samples_; + } + int32_t n = fbank_ ? fbank_->NumFramesReady() : mfcc_ ? 
diff --git a/sherpa-onnx/csrc/offline-stream.cc b/sherpa-onnx/csrc/offline-stream.cc
index 0f83807dc..6fa0ca6c8 100644
--- a/sherpa-onnx/csrc/offline-stream.cc
+++ b/sherpa-onnx/csrc/offline-stream.cc
@@ -133,6 +133,10 @@ class OfflineStream::Impl {
     fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
   }
 
+  explicit Impl(MoonshineTag /*tag*/) : is_moonshine_(true) {
+    config_.sampling_rate = 16000;
+  }
+
   void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
     if (config_.normalize_samples) {
       AcceptWaveformImpl(sampling_rate, waveform, n);
@@ -164,7 +168,9 @@ class OfflineStream::Impl {
       std::vector<float> samples;
       resampler->Resample(waveform, n, true, &samples);
 
-      if (fbank_) {
+      if (is_moonshine_) {
+        samples_.insert(samples_.end(), samples.begin(), samples.end());
+      } else if (fbank_) {
         fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
                                samples.size());
         fbank_->InputFinished();
@@ -181,7 +187,9 @@ class OfflineStream::Impl {
       return;
     }  // if (sampling_rate != config_.sampling_rate)
 
-    if (fbank_) {
+    if (is_moonshine_) {
+      samples_.insert(samples_.end(), waveform, waveform + n);
+    } else if (fbank_) {
       fbank_->AcceptWaveform(sampling_rate, waveform, n);
       fbank_->InputFinished();
     } else if (mfcc_) {
@@ -194,10 +202,18 @@ class OfflineStream::Impl {
   }
 
   int32_t FeatureDim() const {
+    if (is_moonshine_) {
+      return samples_.size();
+    }
+
     return mfcc_ ? mfcc_opts_.num_ceps : opts_.mel_opts.num_bins;
   }
 
   std::vector<float> GetFrames() const {
+    if (is_moonshine_) {
+      return samples_;
+    }
+
     int32_t n = fbank_   ? fbank_->NumFramesReady()
                 : mfcc_ ? mfcc_->NumFramesReady()
                         : whisper_fbank_->NumFramesReady();
@@ -300,6 +316,10 @@ class OfflineStream::Impl {
   OfflineRecognitionResult r_;
   ContextGraphPtr context_graph_;
  bool is_ced_ = false;
+  bool is_moonshine_ = false;
+
+  // used only when is_moonshine_ == true
+  std::vector<float> samples_;
 };
 
 OfflineStream::OfflineStream(const FeatureExtractorConfig &config /*= {}*/,
@@ -311,6 +331,9 @@ OfflineStream::OfflineStream(WhisperTag tag)
 OfflineStream::OfflineStream(CEDTag tag)
     : impl_(std::make_unique<Impl>(tag)) {}
 
+OfflineStream::OfflineStream(MoonshineTag tag)
+    : impl_(std::make_unique<Impl>(tag)) {}
+
 OfflineStream::~OfflineStream() = default;
 
 void OfflineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform,
diff --git a/sherpa-onnx/csrc/offline-stream.h b/sherpa-onnx/csrc/offline-stream.h
index 95bc80e83..e4bed1115 100644
--- a/sherpa-onnx/csrc/offline-stream.h
+++ b/sherpa-onnx/csrc/offline-stream.h
@@ -34,7 +34,7 @@ struct OfflineRecognitionResult {
   // event target of the audio.
   std::string event;
 
-  /// timestamps.size() == tokens.size()
+  /// timestamps.size() == tokens.size()
   /// timestamps[i] records the time in seconds when tokens[i] is decoded.
   std::vector<float> timestamps;
 
@@ -49,6 +49,10 @@ struct WhisperTag {
 
 struct CEDTag {};
 
+// Moonshine uses a neural network model (its preprocessor) to convert
+// audio samples into features
+struct MoonshineTag {};
+
 class OfflineStream {
  public:
   explicit OfflineStream(const FeatureExtractorConfig &config = {},
@@ -56,6 +60,7 @@ class OfflineStream {
 
   explicit OfflineStream(WhisperTag tag);
   explicit OfflineStream(CEDTag tag);
+  explicit OfflineStream(MoonshineTag tag);
   ~OfflineStream();
 
   /**
@@ -72,7 +77,10 @@ class OfflineStream {
   void AcceptWaveform(int32_t sampling_rate, const float *waveform,
                       int32_t n) const;
 
-  /// Return feature dim of this extractor
+  /// Return feature dim of this extractor.
+  ///
+  /// Note: if it is Moonshine, then it returns the number of audio samples
+  /// currently received.
int32_t FeatureDim() const; // Get all the feature frames of this stream in a 1-D array, which is diff --git a/sherpa-onnx/csrc/offline-whisper-model.cc b/sherpa-onnx/csrc/offline-whisper-model.cc index 6327e5347..485eaf93c 100644 --- a/sherpa-onnx/csrc/offline-whisper-model.cc +++ b/sherpa-onnx/csrc/offline-whisper-model.cc @@ -23,7 +23,6 @@ class OfflineWhisperModel::Impl { explicit Impl(const OfflineModelConfig &config) : config_(config), env_(ORT_LOGGING_LEVEL_ERROR), - debug_(config.debug), sess_opts_(GetSessionOptions(config)), allocator_{} { { @@ -40,7 +39,6 @@ class OfflineWhisperModel::Impl { explicit Impl(const SpokenLanguageIdentificationConfig &config) : lid_config_(config), env_(ORT_LOGGING_LEVEL_ERROR), - debug_(config_.debug), sess_opts_(GetSessionOptions(config)), allocator_{} { { @@ -60,7 +58,6 @@ class OfflineWhisperModel::Impl { env_(ORT_LOGGING_LEVEL_ERROR), sess_opts_(GetSessionOptions(config)), allocator_{} { - debug_ = config_.debug; { auto buf = ReadFile(mgr, config.whisper.encoder); InitEncoder(buf.data(), buf.size()); @@ -77,7 +74,6 @@ class OfflineWhisperModel::Impl { env_(ORT_LOGGING_LEVEL_ERROR), sess_opts_(GetSessionOptions(config)), allocator_{} { - debug_ = config_.debug; { auto buf = ReadFile(mgr, config.whisper.encoder); InitEncoder(buf.data(), buf.size()); @@ -164,7 +160,7 @@ class OfflineWhisperModel::Impl { } } - if (debug_) { + if (config_.debug) { SHERPA_ONNX_LOGE("Detected language: %s", GetID2Lang().at(lang_id).c_str()); } @@ -237,7 +233,7 @@ class OfflineWhisperModel::Impl { // get meta data Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata(); - if (debug_) { + if (config_.debug) { std::ostringstream os; os << "---encoder---\n"; PrintModelMetadata(os, meta_data); @@ -294,7 +290,6 @@ class OfflineWhisperModel::Impl { private: OfflineModelConfig config_; SpokenLanguageIdentificationConfig lid_config_; - bool debug_ = false; Ort::Env env_; Ort::SessionOptions sess_opts_; Ort::AllocatorWithDefaultOptions allocator_; diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline.cc b/sherpa-onnx/csrc/sherpa-onnx-offline.cc index 73e77299a..022f7569b 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline.cc @@ -43,7 +43,20 @@ See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/in --decoding-method=greedy_search \ /path/to/foo.wav [bar.wav foobar.wav ...] -(3) Whisper models +(3) Moonshine models + +See https://k2-fsa.github.io/sherpa/onnx/moonshine/index.html + + ./bin/sherpa-onnx-offline \ + --moonshine-preprocessor=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/preprocess.onnx \ + --moonshine-encoder=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/encode.int8.onnx \ + --moonshine-uncached-decoder=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/uncached_decode.int8.onnx \ + --moonshine-cached-decoder=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/cached_decode.int8.onnx \ + --tokens=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/tokens.txt \ + --num-threads=1 \ + /path/to/foo.wav [bar.wav foobar.wav ...] + +(4) Whisper models See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html @@ -54,7 +67,7 @@ See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html --num-threads=1 \ /path/to/foo.wav [bar.wav foobar.wav ...] 
-(4) NeMo CTC models
+(5) NeMo CTC models
 
 See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html
 
@@ -68,7 +81,7 @@ See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.htm
   ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/1.wav \
   ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/8k.wav
 
-(5) TDNN CTC model for the yesno recipe from icefall
+(6) TDNN CTC model for the yesno recipe from icefall
 
 See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/yesno/index.html
 //
diff --git a/sherpa-onnx/csrc/symbol-table.cc b/sherpa-onnx/csrc/symbol-table.cc
index a71225c38..eed7a1e53 100644
--- a/sherpa-onnx/csrc/symbol-table.cc
+++ b/sherpa-onnx/csrc/symbol-table.cc
@@ -109,6 +109,8 @@ const std::string SymbolTable::operator[](int32_t id) const {
 
   // for byte-level BPE
   // id 0 is blank, id 1 is sos/eos, id 2 is unk
+  //
+  // Note: For moonshine models, 0 is <unk>, 1 is <s>, 2 is </s>
   if (id >= 3 && id <= 258 && sym.size() == 6 && sym[0] == '<' &&
       sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') {
     std::ostringstream os;
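The branch guarded above maps byte-level symbols of the form <0xNN> (ids 3 through 258) back to raw bytes, which are then assembled and UTF-8 decoded. A small illustrative sketch of the convention (the token strings here are made up for the example):

    # Illustrative only: decode byte-level BPE symbols of the form <0xNN>.
    tokens = ["<0xE4>", "<0xBD>", "<0xA0>"]  # three bytes of one character
    data = bytes(int(t[1:-1], 16) for t in tokens)
    print(data.decode("utf-8"))  # -> 你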
&PyClass::zipformer_ctc) .def_readwrite("wenet_ctc", &PyClass::wenet_ctc) .def_readwrite("sense_voice", &PyClass::sense_voice) + .def_readwrite("moonshine", &PyClass::moonshine) .def_readwrite("telespeech_ctc", &PyClass::telespeech_ctc) .def_readwrite("tokens", &PyClass::tokens) .def_readwrite("num_threads", &PyClass::num_threads) diff --git a/sherpa-onnx/python/csrc/offline-moonshine-model-config.cc b/sherpa-onnx/python/csrc/offline-moonshine-model-config.cc new file mode 100644 index 000000000..14bea382b --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-moonshine-model-config.cc @@ -0,0 +1,28 @@ +// sherpa-onnx/python/csrc/offline-moonshine-model-config.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-moonshine-model-config.h" + +#include +#include + +#include "sherpa-onnx/python/csrc/offline-moonshine-model-config.h" + +namespace sherpa_onnx { + +void PybindOfflineMoonshineModelConfig(py::module *m) { + using PyClass = OfflineMoonshineModelConfig; + py::class_(*m, "OfflineMoonshineModelConfig") + .def(py::init(), + py::arg("preprocessor"), py::arg("encoder"), + py::arg("uncached_decoder"), py::arg("cached_decoder")) + .def_readwrite("preprocessor", &PyClass::preprocessor) + .def_readwrite("encoder", &PyClass::encoder) + .def_readwrite("uncached_decoder", &PyClass::uncached_decoder) + .def_readwrite("cached_decoder", &PyClass::cached_decoder) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/offline-moonshine-model-config.h b/sherpa-onnx/python/csrc/offline-moonshine-model-config.h new file mode 100644 index 000000000..1b30f9f94 --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-moonshine-model-config.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/offline-moonshine-model-config.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindOfflineMoonshineModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py index e96271a58..391666005 100644 --- a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py @@ -8,13 +8,14 @@ OfflineCtcFstDecoderConfig, OfflineLMConfig, OfflineModelConfig, + OfflineMoonshineModelConfig, OfflineNemoEncDecCtcModelConfig, OfflineParaformerModelConfig, - OfflineSenseVoiceModelConfig, ) from _sherpa_onnx import OfflineRecognizer as _Recognizer from _sherpa_onnx import ( OfflineRecognizerConfig, + OfflineSenseVoiceModelConfig, OfflineStream, OfflineTdnnModelConfig, OfflineTransducerModelConfig, @@ -503,12 +504,12 @@ def from_whisper( e.g., tiny, tiny.en, base, base.en, etc. Args: - encoder_model: - Path to the encoder model, e.g., tiny-encoder.onnx, - tiny-encoder.int8.onnx, tiny-encoder.ort, etc. - decoder_model: + encoder: Path to the encoder model, e.g., tiny-encoder.onnx, tiny-encoder.int8.onnx, tiny-encoder.ort, etc. + decoder: + Path to the decoder model, e.g., tiny-decoder.onnx, + tiny-decoder.int8.onnx, tiny-decoder.ort, etc. tokens: Path to ``tokens.txt``. 
diff --git a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
index e96271a58..391666005 100644
--- a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
+++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
@@ -8,13 +8,14 @@
     OfflineCtcFstDecoderConfig,
     OfflineLMConfig,
     OfflineModelConfig,
+    OfflineMoonshineModelConfig,
     OfflineNemoEncDecCtcModelConfig,
     OfflineParaformerModelConfig,
-    OfflineSenseVoiceModelConfig,
 )
 from _sherpa_onnx import OfflineRecognizer as _Recognizer
 from _sherpa_onnx import (
     OfflineRecognizerConfig,
+    OfflineSenseVoiceModelConfig,
     OfflineStream,
     OfflineTdnnModelConfig,
     OfflineTransducerModelConfig,
@@ -503,12 +504,12 @@ def from_whisper(
         e.g., tiny, tiny.en, base, base.en, etc.
 
         Args:
-          encoder_model:
-            Path to the encoder model, e.g., tiny-encoder.onnx,
-            tiny-encoder.int8.onnx, tiny-encoder.ort, etc.
-          decoder_model:
+          encoder:
             Path to the encoder model, e.g., tiny-encoder.onnx,
             tiny-encoder.int8.onnx, tiny-encoder.ort, etc.
+          decoder:
+            Path to the decoder model, e.g., tiny-decoder.onnx,
+            tiny-decoder.int8.onnx, tiny-decoder.ort, etc.
           tokens:
             Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
             columns::
@@ -570,6 +571,87 @@ def from_whisper(
+    @classmethod
+    def from_moonshine(
+        cls,
+        preprocessor: str,
+        encoder: str,
+        uncached_decoder: str,
+        cached_decoder: str,
+        tokens: str,
+        num_threads: int = 1,
+        decoding_method: str = "greedy_search",
+        debug: bool = False,
+        provider: str = "cpu",
+        rule_fsts: str = "",
+        rule_fars: str = "",
+    ):
+        """
+        Please refer to
+        `<https://k2-fsa.github.io/sherpa/onnx/moonshine/index.html>`_
+        to download pre-trained models for different kinds of Moonshine
+        models, e.g., tiny, base, etc.
+
+        Args:
+          preprocessor:
+            Path to the preprocessor model, e.g., preprocess.onnx.
+          encoder:
+            Path to the encoder model, e.g., encode.int8.onnx.
+          uncached_decoder:
+            Path to the uncached decoder model, e.g., uncached_decode.int8.onnx.
+          cached_decoder:
+            Path to the cached decoder model, e.g., cached_decode.int8.onnx.
+          tokens:
+            Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
+            columns::
+
+              symbol integer_id
+
+          num_threads:
+            Number of threads for neural network computation.
+          decoding_method:
+            Valid values: greedy_search.
+          debug:
+            True to show debug messages.
+          provider:
+            onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
+          rule_fsts:
+            If not empty, it specifies fsts for inverse text normalization.
+            If there are multiple fsts, they are separated by a comma.
+          rule_fars:
+            If not empty, it specifies fst archives for inverse text
+            normalization. If there are multiple archives, they are
+            separated by a comma.
+        """
+        self = cls.__new__(cls)
+        model_config = OfflineModelConfig(
+            moonshine=OfflineMoonshineModelConfig(
+                preprocessor=preprocessor,
+                encoder=encoder,
+                uncached_decoder=uncached_decoder,
+                cached_decoder=cached_decoder,
+            ),
+            tokens=tokens,
+            num_threads=num_threads,
+            debug=debug,
+            provider=provider,
+        )
+
+        unused_feat_config = FeatureExtractorConfig(
+            sampling_rate=16000,
+            feature_dim=80,
+        )
+
+        recognizer_config = OfflineRecognizerConfig(
+            model_config=model_config,
+            feat_config=unused_feat_config,
+            decoding_method=decoding_method,
+            rule_fsts=rule_fsts,
+            rule_fars=rule_fars,
+        )
+        self.recognizer = _Recognizer(recognizer_config)
+        self.config = recognizer_config
+        return self
+
     @classmethod
     def from_tdnn_ctc(
         cls,