From 734bbd91dcafe5210a29773a1c1cb3ac52fb2199 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Fri, 1 Mar 2024 09:31:11 +0800 Subject: [PATCH] Add Python API for keyword spotting (#576) * Add alsa & microphone support for keyword spotting * Add python wrapper --- .github/scripts/test-python.sh | 58 +++++ .../keyword-spotter-from-microphone.py | 191 ++++++++++++++ python-api-examples/keyword-spotter.py | 242 ++++++++++++++++++ sherpa-onnx/csrc/CMakeLists.txt | 8 + .../csrc/sherpa-onnx-keyword-spotter-alsa.cc | 124 +++++++++ .../sherpa-onnx-keyword-spotter-microphone.cc | 148 +++++++++++ .../csrc/sherpa-onnx-keyword-spotter.cc | 1 - sherpa-onnx/python/csrc/CMakeLists.txt | 1 + sherpa-onnx/python/csrc/keyword-spotter.cc | 82 ++++++ sherpa-onnx/python/csrc/keyword-spotter.h | 16 ++ sherpa-onnx/python/csrc/sherpa-onnx.cc | 2 + sherpa-onnx/python/sherpa_onnx/__init__.py | 1 + .../python/sherpa_onnx/keyword_spotter.py | 147 +++++++++++ sherpa-onnx/python/tests/CMakeLists.txt | 1 + .../python/tests/test_keyword_spotter.py | 170 ++++++++++++ 15 files changed, 1191 insertions(+), 1 deletion(-) create mode 100755 python-api-examples/keyword-spotter-from-microphone.py create mode 100755 python-api-examples/keyword-spotter.py create mode 100644 sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-alsa.cc create mode 100644 sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-microphone.cc create mode 100644 sherpa-onnx/python/csrc/keyword-spotter.cc create mode 100644 sherpa-onnx/python/csrc/keyword-spotter.h create mode 100644 sherpa-onnx/python/sherpa_onnx/keyword_spotter.py create mode 100755 sherpa-onnx/python/tests/test_keyword_spotter.py diff --git a/.github/scripts/test-python.sh b/.github/scripts/test-python.sh index f63c2de66..befaab483 100755 --- a/.github/scripts/test-python.sh +++ b/.github/scripts/test-python.sh @@ -293,3 +293,61 @@ git clone https://github.com/pkufool/sherpa-test-data /tmp/sherpa-test-data python3 sherpa-onnx/python/tests/test_text2token.py --verbose rm -rf /tmp/sherpa-test-data + +mkdir -p /tmp/onnx-models +dir=/tmp/onnx-models + +log "Test keyword spotting models" + +python3 -c "import sherpa_onnx; print(sherpa_onnx.__file__)" +sherpa_onnx_version=$(python3 -c "import sherpa_onnx; print(sherpa_onnx.__version__)") + +echo "sherpa_onnx version: $sherpa_onnx_version" + +pwd +ls -lh + +repo=sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01 +log "Start testing ${repo}" + +pushd $dir +wget -qq https://github.com/pkufool/keyword-spotting-models/releases/download/v0.1/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz +tar xf sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz +popd + +repo=$dir/$repo +ls -lh $repo + +python3 ./python-api-examples/keyword-spotter.py \ + --tokens=$repo/tokens.txt \ + --encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \ + --decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \ + --joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \ + --keywords-file=$repo/test_wavs/test_keywords.txt \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav + +repo=sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01 +log "Start testing ${repo}" + +pushd $dir +wget -qq https://github.com/pkufool/keyword-spotting-models/releases/download/v0.1/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz +tar xf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz +popd + +repo=$dir/$repo +ls -lh $repo + +python3 ./python-api-examples/keyword-spotter.py \ + --tokens=$repo/tokens.txt \ + 
--encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \
+  --decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \
+  --joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \
+  --keywords-file=$repo/test_wavs/test_keywords.txt \
+  $repo/test_wavs/3.wav \
+  $repo/test_wavs/4.wav \
+  $repo/test_wavs/5.wav
+
+python3 sherpa-onnx/python/tests/test_keyword_spotter.py --verbose
+
+rm -r $dir
diff --git a/python-api-examples/keyword-spotter-from-microphone.py b/python-api-examples/keyword-spotter-from-microphone.py
new file mode 100755
index 000000000..5a0ebafe7
--- /dev/null
+++ b/python-api-examples/keyword-spotter-from-microphone.py
@@ -0,0 +1,191 @@
+#!/usr/bin/env python3
+
+# Real-time keyword spotting from a microphone with the sherpa-onnx Python API
+#
+# Please refer to
+# https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
+# to download pre-trained models
+
+import argparse
+import sys
+from pathlib import Path
+
+from typing import List
+
+try:
+    import sounddevice as sd
+except ImportError:
+    print("Please install sounddevice first. You can use")
+    print()
+    print("  pip install sounddevice")
+    print()
+    print("to install it")
+    sys.exit(-1)
+
+import sherpa_onnx
+
+
+def assert_file_exists(filename: str):
+    assert Path(filename).is_file(), (
+        f"{filename} does not exist!\n"
+        "Please refer to "
+        "https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html to download it"
+    )
+
+
+def get_args():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--tokens",
+        type=str,
+        help="Path to tokens.txt",
+    )
+
+    parser.add_argument(
+        "--encoder",
+        type=str,
+        help="Path to the transducer encoder model",
+    )
+
+    parser.add_argument(
+        "--decoder",
+        type=str,
+        help="Path to the transducer decoder model",
+    )
+
+    parser.add_argument(
+        "--joiner",
+        type=str,
+        help="Path to the transducer joiner model",
+    )
+
+    parser.add_argument(
+        "--num-threads",
+        type=int,
+        default=1,
+        help="Number of threads for neural network computation",
+    )
+
+    parser.add_argument(
+        "--provider",
+        type=str,
+        default="cpu",
+        help="Valid values: cpu, cuda, coreml",
+    )
+
+    parser.add_argument(
+        "--max-active-paths",
+        type=int,
+        default=4,
+        help="""
+        It specifies the number of active paths to keep during decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-trailing-blanks",
+        type=int,
+        default=1,
+        help="""The number of trailing blanks a keyword should be followed
+        by. Set it to a larger value (e.g., 8) if your keywords have
+        overlapping tokens with each other.
+        """,
+    )
+
+    parser.add_argument(
+        "--keywords-file",
+        type=str,
+        help="""
+        The file containing keywords, one word/phrase per line, and for each
+        phrase the bpe/cjkchar/pinyin units are separated by spaces. For example:
+
+        ▁HE LL O ▁WORLD
+        x iǎo ài t óng x ué
+        """,
+    )
+
+    parser.add_argument(
+        "--keywords-score",
+        type=float,
+        default=1.0,
+        help="""
+        The boosting score for each keyword token. The larger it is, the
+        easier it is for a keyword to survive beam search.
+        """,
+    )
+
+    parser.add_argument(
+        "--keywords-threshold",
+        type=float,
+        default=0.25,
+        help="""
+        The trigger threshold (i.e., probability) of a keyword. The larger
+        it is, the harder it is to trigger.
+        """,
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+
+    devices = sd.query_devices()
+    if len(devices) == 0:
+        print("No microphone devices found")
+        sys.exit(0)
+
+    print(devices)
+    default_input_device_idx = sd.default.device[0]
+    print(f'Use default device: {devices[default_input_device_idx]["name"]}')
+
+    assert_file_exists(args.tokens)
+    assert_file_exists(args.encoder)
+    assert_file_exists(args.decoder)
+    assert_file_exists(args.joiner)
+
+    assert Path(
+        args.keywords_file
+    ).is_file(), (
+        f"keywords_file: {args.keywords_file} does not exist, please provide a valid path."
+    )
+
+    keyword_spotter = sherpa_onnx.KeywordSpotter(
+        tokens=args.tokens,
+        encoder=args.encoder,
+        decoder=args.decoder,
+        joiner=args.joiner,
+        num_threads=args.num_threads,
+        max_active_paths=args.max_active_paths,
+        keywords_file=args.keywords_file,
+        keywords_score=args.keywords_score,
+        keywords_threshold=args.keywords_threshold,
+        num_trailing_blanks=args.num_trailing_blanks,
+        provider=args.provider,
+    )
+
+    print("Started! Please speak")
+
+    sample_rate = 16000
+    samples_per_read = int(0.1 * sample_rate)  # 0.1 second = 100 ms
+    stream = keyword_spotter.create_stream()
+    with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
+        while True:
+            samples, _ = s.read(samples_per_read)  # a blocking read
+            samples = samples.reshape(-1)
+            stream.accept_waveform(sample_rate, samples)
+            while keyword_spotter.is_ready(stream):
+                keyword_spotter.decode_stream(stream)
+                result = keyword_spotter.get_result(stream)
+                if result:
+                    print("\r{}".format(result), end="", flush=True)
+
+
+if __name__ == "__main__":
+    try:
+        main()
+    except KeyboardInterrupt:
+        print("\nCaught Ctrl + C. Exiting")
diff --git a/python-api-examples/keyword-spotter.py b/python-api-examples/keyword-spotter.py
new file mode 100755
index 000000000..1b1de77e3
--- /dev/null
+++ b/python-api-examples/keyword-spotter.py
@@ -0,0 +1,242 @@
+#!/usr/bin/env python3
+
+"""
+This file demonstrates how to use the sherpa-onnx Python API to do keyword
+spotting from wave file(s).
+
+Please refer to
+https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
+to download pre-trained models.
+"""
+import argparse
+import time
+import wave
+from pathlib import Path
+from typing import List, Tuple
+
+import numpy as np
+import sherpa_onnx
+
+
+def get_args():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--tokens",
+        type=str,
+        help="Path to tokens.txt",
+    )
+
+    parser.add_argument(
+        "--encoder",
+        type=str,
+        help="Path to the transducer encoder model",
+    )
+
+    parser.add_argument(
+        "--decoder",
+        type=str,
+        help="Path to the transducer decoder model",
+    )
+
+    parser.add_argument(
+        "--joiner",
+        type=str,
+        help="Path to the transducer joiner model",
+    )
+
+    parser.add_argument(
+        "--num-threads",
+        type=int,
+        default=1,
+        help="Number of threads for neural network computation",
+    )
+
+    parser.add_argument(
+        "--provider",
+        type=str,
+        default="cpu",
+        help="Valid values: cpu, cuda, coreml",
+    )
+
+    parser.add_argument(
+        "--max-active-paths",
+        type=int,
+        default=4,
+        help="""
+        It specifies the number of active paths to keep during decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-trailing-blanks",
+        type=int,
+        default=1,
+        help="""The number of trailing blanks a keyword should be followed
+        by. Set it to a larger value (e.g., 8) if your keywords have
+        overlapping tokens with each other.
+        """,
+    )
+
+    parser.add_argument(
+        "--keywords-file",
+        type=str,
+        help="""
+        The file containing keywords, one word/phrase per line, and for each
+        phrase the bpe/cjkchar/pinyin units are separated by spaces. For example:
+
+        ▁HE LL O ▁WORLD
+        x iǎo ài t óng x ué
+        """,
+    )
+
+    parser.add_argument(
+        "--keywords-score",
+        type=float,
+        default=1.0,
+        help="""
+        The boosting score for each keyword token. The larger it is, the
+        easier it is for a keyword to survive beam search.
+        """,
+    )
+
+    parser.add_argument(
+        "--keywords-threshold",
+        type=float,
+        default=0.25,
+        help="""
+        The trigger threshold (i.e., probability) of a keyword. The larger
+        it is, the harder it is to trigger.
+        """,
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to decode. Each file must be of WAVE "
+        "format with a single channel, and each sample should be 16-bit, "
+        "i.e., int16_t. "
+        "The sample rate of the file can be arbitrary and does not need to "
+        "be 16 kHz",
+    )
+
+    return parser.parse_args()
+
+
+def assert_file_exists(filename: str):
+    assert Path(filename).is_file(), (
+        f"{filename} does not exist!\n"
+        "Please refer to "
+        "https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html to download it"
+    )
+
+
+def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
+    """
+    Args:
+      wave_filename:
+        Path to a wave file. It should be single channel and each sample should
+        be 16-bit. Its sample rate does not need to be 16kHz.
+    Returns:
+      Return a tuple containing:
+       - A 1-D array of dtype np.float32 containing the samples, which are
+         normalized to the range [-1, 1].
+       - sample rate of the wave file
+    """
+
+    with wave.open(wave_filename) as f:
+        assert f.getnchannels() == 1, f.getnchannels()
+        assert f.getsampwidth() == 2, f.getsampwidth()  # it is in bytes
+        num_samples = f.getnframes()
+        samples = f.readframes(num_samples)
+        samples_int16 = np.frombuffer(samples, dtype=np.int16)
+        samples_float32 = samples_int16.astype(np.float32)
+
+        samples_float32 = samples_float32 / 32768
+        return samples_float32, f.getframerate()
+
+
+def main():
+    args = get_args()
+    assert_file_exists(args.tokens)
+    assert_file_exists(args.encoder)
+    assert_file_exists(args.decoder)
+    assert_file_exists(args.joiner)
+
+    assert Path(
+        args.keywords_file
+    ).is_file(), (
+        f"keywords_file: {args.keywords_file} does not exist, please provide a valid path."
+ ) + + keyword_spotter = sherpa_onnx.KeywordSpotter( + tokens=args.tokens, + encoder=args.encoder, + decoder=args.decoder, + joiner=args.joiner, + num_threads=args.num_threads, + max_active_paths=args.max_active_paths, + keywords_file=args.keywords_file, + keywords_score=args.keywords_score, + keywords_threshold=args.keywords_threshold, + num_trailing_blanks=args.num_trailing_blanks, + provider=args.provider, + ) + + print("Started!") + start_time = time.time() + + streams = [] + total_duration = 0 + for wave_filename in args.sound_files: + assert_file_exists(wave_filename) + samples, sample_rate = read_wave(wave_filename) + duration = len(samples) / sample_rate + total_duration += duration + + s = keyword_spotter.create_stream() + + s.accept_waveform(sample_rate, samples) + + tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32) + s.accept_waveform(sample_rate, tail_paddings) + + s.input_finished() + + streams.append(s) + + results = [""] * len(streams) + while True: + ready_list = [] + for i, s in enumerate(streams): + if keyword_spotter.is_ready(s): + ready_list.append(s) + r = keyword_spotter.get_result(s) + if r: + results[i] += f"{r}/" + print(f"{r} is detected.") + if len(ready_list) == 0: + break + keyword_spotter.decode_streams(ready_list) + end_time = time.time() + print("Done!") + + for wave_filename, result in zip(args.sound_files, results): + print(f"{wave_filename}\n{result}") + print("-" * 10) + + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / total_duration + print(f"num_threads: {args.num_threads}") + print(f"Wave duration: {total_duration:.3f} s") + print(f"Elapsed time: {elapsed_seconds:.3f} s") + print( + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" + ) + + +if __name__ == "__main__": + main() diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 17fc047ff..198e8fa7c 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -230,12 +230,14 @@ endif() if(SHERPA_ONNX_HAS_ALSA AND SHERPA_ONNX_ENABLE_BINARY) add_executable(sherpa-onnx-alsa sherpa-onnx-alsa.cc alsa.cc) + add_executable(sherpa-onnx-keyword-spotter-alsa sherpa-onnx-keyword-spotter-alsa.cc alsa.cc) add_executable(sherpa-onnx-offline-tts-play-alsa sherpa-onnx-offline-tts-play-alsa.cc alsa-play.cc) add_executable(sherpa-onnx-alsa-offline sherpa-onnx-alsa-offline.cc alsa.cc) add_executable(sherpa-onnx-alsa-offline-speaker-identification sherpa-onnx-alsa-offline-speaker-identification.cc alsa.cc) set(exes sherpa-onnx-alsa + sherpa-onnx-keyword-spotter-alsa sherpa-onnx-alsa-offline sherpa-onnx-offline-tts-play-alsa sherpa-onnx-alsa-offline-speaker-identification @@ -278,6 +280,11 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO AND SHERPA_ONNX_ENABLE_BINARY) microphone.cc ) + add_executable(sherpa-onnx-keyword-spotter-microphone + sherpa-onnx-keyword-spotter-microphone.cc + microphone.cc + ) + add_executable(sherpa-onnx-microphone sherpa-onnx-microphone.cc microphone.cc @@ -311,6 +318,7 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO AND SHERPA_ONNX_ENABLE_BINARY) set(exes sherpa-onnx-microphone + sherpa-onnx-keyword-spotter-microphone sherpa-onnx-microphone-offline sherpa-onnx-microphone-offline-speaker-identification sherpa-onnx-offline-tts-play diff --git a/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-alsa.cc b/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-alsa.cc new file mode 100644 index 000000000..ab61eb87c --- /dev/null +++ b/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-alsa.cc @@ -0,0 +1,124 @@ +// 
sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-alsa.cc
+//
+// Copyright (c) 2024 Xiaomi Corporation
+#include <signal.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include <string>
+#include <vector>
+
+#include "sherpa-onnx/csrc/alsa.h"
+#include "sherpa-onnx/csrc/display.h"
+#include "sherpa-onnx/csrc/keyword-spotter.h"
+#include "sherpa-onnx/csrc/parse-options.h"
+
+bool stop = false;
+
+static void Handler(int sig) {
+  stop = true;
+  fprintf(stderr, "\nCaught Ctrl + C. Exiting...\n");
+}
+
+int32_t main(int32_t argc, char *argv[]) {
+  signal(SIGINT, Handler);
+
+  const char *kUsageMessage = R"usage(
+Usage:
+  ./bin/sherpa-onnx-keyword-spotter-alsa \
+    --tokens=/path/to/tokens.txt \
+    --encoder=/path/to/encoder.onnx \
+    --decoder=/path/to/decoder.onnx \
+    --joiner=/path/to/joiner.onnx \
+    --provider=cpu \
+    --num-threads=2 \
+    --keywords-file=keywords.txt \
+    device_name
+
+Please refer to
+https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
+for a list of pre-trained models to download.
+
+The device name specifies which microphone to use in case there are several
+on your system. You can use
+
+  arecord -l
+
+to find all available microphones on your computer. For instance, if it outputs
+
+**** List of CAPTURE Hardware Devices ****
+card 3: UACDemoV10 [UACDemoV1.0], device 0: USB Audio [USB Audio]
+  Subdevices: 1/1
+  Subdevice #0: subdevice #0
+
+and if you want to select card 3 and device 0 on that card, please use:
+
+  hw:3,0
+
+or
+
+  plughw:3,0
+
+as the device_name.
+)usage";
+  sherpa_onnx::ParseOptions po(kUsageMessage);
+  sherpa_onnx::KeywordSpotterConfig config;
+
+  config.Register(&po);
+
+  po.Read(argc, argv);
+  if (po.NumArgs() != 1) {
+    fprintf(stderr, "Please provide only 1 argument: the device name\n");
+    po.PrintUsage();
+    exit(EXIT_FAILURE);
+  }
+
+  fprintf(stderr, "%s\n", config.ToString().c_str());
+
+  if (!config.Validate()) {
+    fprintf(stderr, "Errors in config!\n");
+    return -1;
+  }
+  sherpa_onnx::KeywordSpotter spotter(config);
+
+  int32_t expected_sample_rate = config.feat_config.sampling_rate;
+
+  std::string device_name = po.GetArg(1);
+  sherpa_onnx::Alsa alsa(device_name.c_str());
+  fprintf(stderr, "Use recording device: %s\n", device_name.c_str());
+
+  if (alsa.GetExpectedSampleRate() != expected_sample_rate) {
+    fprintf(stderr, "sample rate: %d != %d\n", alsa.GetExpectedSampleRate(),
+            expected_sample_rate);
+    exit(-1);
+  }
+
+  int32_t chunk = 0.1 * alsa.GetActualSampleRate();
+
+  std::string last_text;
+
+  auto stream = spotter.CreateStream();
+
+  sherpa_onnx::Display display;
+
+  int32_t keyword_index = 0;
+  while (!stop) {
+    const std::vector<float> &samples = alsa.Read(chunk);
+
+    stream->AcceptWaveform(expected_sample_rate, samples.data(),
+                           samples.size());
+
+    while (spotter.IsReady(stream.get())) {
+      spotter.DecodeStream(stream.get());
+    }
+
+    const auto r = spotter.GetResult(stream.get());
+    if (!r.keyword.empty()) {
+      display.Print(keyword_index, r.AsJsonString());
+      fflush(stderr);
+      keyword_index++;
+    }
+  }
+
+  return 0;
+}
diff --git a/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-microphone.cc b/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-microphone.cc
new file mode 100644
index 000000000..1f42da40a
--- /dev/null
+++ b/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-microphone.cc
@@ -0,0 +1,148 @@
+// sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-microphone.cc
+//
+// Copyright (c) 2024 Xiaomi Corporation
+
+#include <signal.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include <cstdint>
+
+#include "portaudio.h"  // NOLINT
+#include "sherpa-onnx/csrc/display.h"
+#include "sherpa-onnx/csrc/keyword-spotter.h"
+#include "sherpa-onnx/csrc/microphone.h"
+
+bool stop = false;
+
+static int32_t RecordCallback(const void *input_buffer,
+                              void * /*output_buffer*/,
+                              unsigned long frames_per_buffer,  // NOLINT
+                              const PaStreamCallbackTimeInfo * /*time_info*/,
+                              PaStreamCallbackFlags /*status_flags*/,
+                              void *user_data) {
+  auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(user_data);
+
+  stream->AcceptWaveform(16000, reinterpret_cast<const float *>(input_buffer),
+                         frames_per_buffer);
+
+  return stop ? paComplete : paContinue;
+}
+
+static void Handler(int32_t sig) {
+  stop = true;
+  fprintf(stderr, "\nCaught Ctrl + C. Exiting...\n");
+}
+
+int32_t main(int32_t argc, char *argv[]) {
+  signal(SIGINT, Handler);
+
+  const char *kUsageMessage = R"usage(
+This program uses streaming models with a microphone for keyword spotting.
+Usage:
+
+  ./bin/sherpa-onnx-keyword-spotter-microphone \
+    --tokens=/path/to/tokens.txt \
+    --encoder=/path/to/encoder.onnx \
+    --decoder=/path/to/decoder.onnx \
+    --joiner=/path/to/joiner.onnx \
+    --provider=cpu \
+    --num-threads=1 \
+    --keywords-file=keywords.txt
+
+Please refer to
+https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
+for a list of pre-trained models to download.
+)usage";
+
+  sherpa_onnx::ParseOptions po(kUsageMessage);
+  sherpa_onnx::KeywordSpotterConfig config;
+
+  config.Register(&po);
+  po.Read(argc, argv);
+  if (po.NumArgs() != 0) {
+    po.PrintUsage();
+    exit(EXIT_FAILURE);
+  }
+
+  fprintf(stderr, "%s\n", config.ToString().c_str());
+
+  if (!config.Validate()) {
+    fprintf(stderr, "Errors in config!\n");
+    return -1;
+  }
+
+  sherpa_onnx::KeywordSpotter spotter(config);
+  auto s = spotter.CreateStream();
+
+  sherpa_onnx::Microphone mic;
+
+  PaDeviceIndex num_devices = Pa_GetDeviceCount();
+  fprintf(stderr, "Num devices: %d\n", num_devices);
+
+  PaStreamParameters param;
+
+  param.device = Pa_GetDefaultInputDevice();
+  if (param.device == paNoDevice) {
+    fprintf(stderr, "No default input device found\n");
+    exit(EXIT_FAILURE);
+  }
+  fprintf(stderr, "Use default device: %d\n", param.device);
+
+  const PaDeviceInfo *info = Pa_GetDeviceInfo(param.device);
+  fprintf(stderr, "  Name: %s\n", info->name);
+  fprintf(stderr, "  Max input channels: %d\n", info->maxInputChannels);
+
+  param.channelCount = 1;
+  param.sampleFormat = paFloat32;
+
+  param.suggestedLatency = info->defaultLowInputLatency;
+  param.hostApiSpecificStreamInfo = nullptr;
+  float sample_rate = 16000;
+
+  PaStream *stream;
+  PaError err =
+      Pa_OpenStream(&stream, &param, nullptr, /* &outputParameters, */
+                    sample_rate,
+                    0,          // frames per buffer
+                    paClipOff,  // we won't output out of range samples
+                                // so don't bother clipping them
+                    RecordCallback, s.get());
+  if (err != paNoError) {
+    fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err));
+    exit(EXIT_FAILURE);
+  }
+
+  err = Pa_StartStream(stream);
+  fprintf(stderr, "Started\n");
+
+  if (err != paNoError) {
+    fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err));
+    exit(EXIT_FAILURE);
+  }
+
+  int32_t keyword_index = 0;
+  sherpa_onnx::Display display;
+  while (!stop) {
+    while (spotter.IsReady(s.get())) {
+      spotter.DecodeStream(s.get());
+    }
+
+    const auto r = spotter.GetResult(s.get());
+    if (!r.keyword.empty()) {
+      display.Print(keyword_index, r.AsJsonString());
+      fflush(stderr);
+      keyword_index++;
+    }
+
+    Pa_Sleep(20);  // sleep for 20ms
+  }
+
+  err = Pa_CloseStream(stream);
+  if (err != paNoError) {
+    fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err));
+    exit(EXIT_FAILURE);
+  }
+
+  return 0;
+}
diff --git a/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter.cc b/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter.cc
index d7ef2bd68..72053744d 100644
--- a/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter.cc
+++ b/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter.cc
@@ -12,7 +12,6 @@
 #include "sherpa-onnx/csrc/keyword-spotter.h"
 #include "sherpa-onnx/csrc/online-stream.h"
 #include "sherpa-onnx/csrc/parse-options.h"
-#include "sherpa-onnx/csrc/symbol-table.h"
 #include "sherpa-onnx/csrc/wave-reader.h"
 
 typedef struct {
diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt
index a94a1ff67..30f646216 100644
--- a/sherpa-onnx/python/csrc/CMakeLists.txt
+++ b/sherpa-onnx/python/csrc/CMakeLists.txt
@@ -5,6 +5,7 @@ pybind11_add_module(_sherpa_onnx
   display.cc
   endpoint.cc
   features.cc
+  keyword-spotter.cc
   offline-ctc-fst-decoder-config.cc
   offline-lm-config.cc
   offline-model-config.cc
diff --git a/sherpa-onnx/python/csrc/keyword-spotter.cc b/sherpa-onnx/python/csrc/keyword-spotter.cc
new file mode 100644
index 000000000..144992605
--- /dev/null
+++ b/sherpa-onnx/python/csrc/keyword-spotter.cc
@@ -0,0 +1,82 @@
+// sherpa-onnx/python/csrc/keyword-spotter.cc
+//
+// Copyright (c) 2024 Xiaomi Corporation
+
+#include "sherpa-onnx/python/csrc/keyword-spotter.h"
+
+#include <string>
+#include <vector>
+
+#include "sherpa-onnx/csrc/keyword-spotter.h"
+
+namespace sherpa_onnx {
+
+static void PybindKeywordResult(py::module *m) {
+  using PyClass = KeywordResult;
+  py::class_<PyClass>(*m, "KeywordResult")
+      .def_property_readonly(
+          "keyword",
+          [](PyClass &self) -> py::str {
+            return py::str(PyUnicode_DecodeUTF8(self.keyword.c_str(),
+                                                self.keyword.size(), "ignore"));
+          })
+      .def_property_readonly(
+          "tokens",
+          [](PyClass &self) -> std::vector<std::string> { return self.tokens; })
+      .def_property_readonly(
+          "timestamps",
+          [](PyClass &self) -> std::vector<float> { return self.timestamps; });
+}
+
+static void PybindKeywordSpotterConfig(py::module *m) {
+  using PyClass = KeywordSpotterConfig;
+  py::class_<PyClass>(*m, "KeywordSpotterConfig")
+      .def(py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
+                    int32_t, int32_t, float, float, const std::string &>(),
+           py::arg("feat_config"), py::arg("model_config"),
+           py::arg("max_active_paths") = 4, py::arg("num_trailing_blanks") = 1,
+           py::arg("keywords_score") = 1.0,
+           py::arg("keywords_threshold") = 0.25, py::arg("keywords_file") = "")
+      .def_readwrite("feat_config", &PyClass::feat_config)
+      .def_readwrite("model_config", &PyClass::model_config)
+      .def_readwrite("max_active_paths", &PyClass::max_active_paths)
+      .def_readwrite("num_trailing_blanks", &PyClass::num_trailing_blanks)
+      .def_readwrite("keywords_score", &PyClass::keywords_score)
+      .def_readwrite("keywords_threshold", &PyClass::keywords_threshold)
+      .def_readwrite("keywords_file", &PyClass::keywords_file)
+      .def("__str__", &PyClass::ToString);
+}
+
+void PybindKeywordSpotter(py::module *m) {
+  PybindKeywordResult(m);
+  PybindKeywordSpotterConfig(m);
+
+  using PyClass = KeywordSpotter;
+  py::class_<PyClass>(*m, "KeywordSpotter")
+      .def(py::init<const KeywordSpotterConfig &>(), py::arg("config"),
+           py::call_guard<py::gil_scoped_release>())
+      .def(
+          "create_stream",
+          [](const PyClass &self) { return self.CreateStream(); },
+          py::call_guard<py::gil_scoped_release>())
+      .def(
+          "create_stream",
+          [](PyClass &self, const std::string &keywords) {
+            return self.CreateStream(keywords);
+          },
+          py::arg("keywords"), py::call_guard<py::gil_scoped_release>())
+      .def("is_ready", &PyClass::IsReady,
+           py::call_guard<py::gil_scoped_release>())
+      .def("decode_stream", &PyClass::DecodeStream,
+           py::call_guard<py::gil_scoped_release>())
+      .def(
+          "decode_streams",
+          [](PyClass &self, std::vector<OnlineStream *> ss) {
+            self.DecodeStreams(ss.data(), ss.size());
+          },
+          py::call_guard<py::gil_scoped_release>())
+      .def("get_result", &PyClass::GetResult,
+           py::call_guard<py::gil_scoped_release>());
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/python/csrc/keyword-spotter.h b/sherpa-onnx/python/csrc/keyword-spotter.h
new file mode 100644
index 000000000..dce0bae02
--- /dev/null
+++ b/sherpa-onnx/python/csrc/keyword-spotter.h
@@ -0,0 +1,16 @@
+// sherpa-onnx/python/csrc/keyword-spotter.h
+//
+// Copyright (c) 2024 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_PYTHON_CSRC_KEYWORD_SPOTTER_H_
+#define SHERPA_ONNX_PYTHON_CSRC_KEYWORD_SPOTTER_H_
+
+#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
+
+namespace sherpa_onnx {
+
+void PybindKeywordSpotter(py::module *m);
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_PYTHON_CSRC_KEYWORD_SPOTTER_H_
diff --git a/sherpa-onnx/python/csrc/sherpa-onnx.cc b/sherpa-onnx/python/csrc/sherpa-onnx.cc
index 37728426d..bdc38bbe9 100644
--- a/sherpa-onnx/python/csrc/sherpa-onnx.cc
+++ b/sherpa-onnx/python/csrc/sherpa-onnx.cc
@@ -8,6 +8,7 @@
 #include "sherpa-onnx/python/csrc/display.h"
 #include "sherpa-onnx/python/csrc/endpoint.h"
 #include "sherpa-onnx/python/csrc/features.h"
+#include "sherpa-onnx/python/csrc/keyword-spotter.h"
 #include "sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h"
 #include "sherpa-onnx/python/csrc/offline-lm-config.h"
 #include "sherpa-onnx/python/csrc/offline-model-config.h"
@@ -35,6 +36,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
   PybindOnlineStream(&m);
   PybindEndpoint(&m);
   PybindOnlineRecognizer(&m);
+  PybindKeywordSpotter(&m);
 
   PybindDisplay(&m);
diff --git a/sherpa-onnx/python/sherpa_onnx/__init__.py b/sherpa-onnx/python/sherpa_onnx/__init__.py
index 0f13f38c4..926edbb8f 100644
--- a/sherpa-onnx/python/sherpa_onnx/__init__.py
+++ b/sherpa-onnx/python/sherpa_onnx/__init__.py
@@ -17,6 +17,7 @@
     VoiceActivityDetector,
 )
 
+from .keyword_spotter import KeywordSpotter
 from .offline_recognizer import OfflineRecognizer
 from .online_recognizer import OnlineRecognizer
 from .utils import text2token
diff --git a/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py b/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py
new file mode 100644
index 000000000..218628ea9
--- /dev/null
+++ b/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py
@@ -0,0 +1,147 @@
+# Copyright (c) 2023 Xiaomi Corporation
+
+from pathlib import Path
+from typing import List, Optional
+
+from _sherpa_onnx import (
+    FeatureExtractorConfig,
+    KeywordSpotterConfig,
+    OnlineModelConfig,
+    OnlineTransducerModelConfig,
+    OnlineStream,
+)
+
+from _sherpa_onnx import KeywordSpotter as _KeywordSpotter
+
+
+def _assert_file_exists(f: str):
+    assert Path(f).is_file(), f"{f} does not exist"
+
+
+class KeywordSpotter(object):
+    """A class for keyword spotting.
+
+    Please refer to the following files for usages
+     - https://github.com/k2-fsa/sherpa-onnx/blob/master/python-api-examples/keyword-spotter.py
+     - https://github.com/k2-fsa/sherpa-onnx/blob/master/python-api-examples/keyword-spotter-from-microphone.py
+    """
+
+    def __init__(
+        self,
+        tokens: str,
+        encoder: str,
+        decoder: str,
+        joiner: str,
+        keywords_file: str,
+        num_threads: int = 2,
+        sample_rate: float = 16000,
+        feature_dim: int = 80,
+        max_active_paths: int = 4,
+        keywords_score: float = 1.0,
+        keywords_threshold: float = 0.25,
+        num_trailing_blanks: int = 1,
+        provider: str = "cpu",
+    ):
+        """
+        Please refer to
+        `<https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html>`_
+        to download pre-trained models for different languages, e.g., Chinese,
+        English, etc.
+
+        Args:
+          tokens:
+            Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
+            columns::
+
+                symbol integer_id
+
+          encoder:
+            Path to ``encoder.onnx``.
+          decoder:
+            Path to ``decoder.onnx``.
+          joiner:
+            Path to ``joiner.onnx``.
+          keywords_file:
+            The file containing keywords, one word/phrase per line. For each
+            phrase, the bpe/cjkchar/pinyin units are separated by spaces.
+          num_threads:
+            Number of threads for neural network computation.
+          sample_rate:
+            Sample rate of the training data used to train the model.
+          feature_dim:
+            Dimension of the feature used to train the model.
+          max_active_paths:
+            It specifies the maximum number of active paths to keep during
+            beam search.
+          keywords_score:
+            The boosting score for each keyword token. The larger it is, the
+            easier it is for a keyword to survive beam search.
+          keywords_threshold:
+            The trigger threshold (i.e., probability) of a keyword. The larger
+            it is, the harder it is to trigger.
+          num_trailing_blanks:
+            The number of trailing blanks a keyword should be followed by.
+            Set it to a larger value (e.g., 8) if your keywords have
+            overlapping tokens with each other.
+          provider:
+            onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
+        """
+        _assert_file_exists(tokens)
+        _assert_file_exists(encoder)
+        _assert_file_exists(decoder)
+        _assert_file_exists(joiner)
+
+        assert num_threads > 0, num_threads
+
+        transducer_config = OnlineTransducerModelConfig(
+            encoder=encoder,
+            decoder=decoder,
+            joiner=joiner,
+        )
+
+        model_config = OnlineModelConfig(
+            transducer=transducer_config,
+            tokens=tokens,
+            num_threads=num_threads,
+            provider=provider,
+        )
+
+        feat_config = FeatureExtractorConfig(
+            sampling_rate=sample_rate,
+            feature_dim=feature_dim,
+        )
+
+        keywords_spotter_config = KeywordSpotterConfig(
+            feat_config=feat_config,
+            model_config=model_config,
+            max_active_paths=max_active_paths,
+            num_trailing_blanks=num_trailing_blanks,
+            keywords_score=keywords_score,
+            keywords_threshold=keywords_threshold,
+            keywords_file=keywords_file,
+        )
+        self.keyword_spotter = _KeywordSpotter(keywords_spotter_config)
+
+    def create_stream(self, keywords: Optional[str] = None):
+        if keywords is None:
+            return self.keyword_spotter.create_stream()
+        else:
+            return self.keyword_spotter.create_stream(keywords)
+
+    def decode_stream(self, s: OnlineStream):
+        self.keyword_spotter.decode_stream(s)
+
+    def decode_streams(self, ss: List[OnlineStream]):
+        self.keyword_spotter.decode_streams(ss)
+
+    def is_ready(self, s: OnlineStream) -> bool:
+        return self.keyword_spotter.is_ready(s)
+
+    def get_result(self, s: OnlineStream) -> str:
+        return self.keyword_spotter.get_result(s).keyword.strip()
+
+    def tokens(self, s: OnlineStream) -> List[str]:
+        return self.keyword_spotter.get_result(s).tokens
+
+    def timestamps(self, s: OnlineStream) -> List[float]:
+        return self.keyword_spotter.get_result(s).timestamps
diff --git a/sherpa-onnx/python/tests/CMakeLists.txt b/sherpa-onnx/python/tests/CMakeLists.txt
index e99636e2b..c82edc612 100644
--- a/sherpa-onnx/python/tests/CMakeLists.txt
+++ b/sherpa-onnx/python/tests/CMakeLists.txt
@@ -20,6 +20,7 @@ endfunction()
 # please sort the files in alphabetic order
 set(py_test_files
   test_feature_extractor_config.py
+  test_keyword_spotter.py
   test_offline_recognizer.py
   test_online_recognizer.py
   test_online_transducer_model_config.py
diff --git a/sherpa-onnx/python/tests/test_keyword_spotter.py b/sherpa-onnx/python/tests/test_keyword_spotter.py
new file mode 100755
index 000000000..bdefa5d10
--- /dev/null
+++ b/sherpa-onnx/python/tests/test_keyword_spotter.py
@@ -0,0 +1,170 @@
+# sherpa-onnx/python/tests/test_keyword_spotter.py
+#
+# Copyright (c) 2024 Xiaomi Corporation
+#
+# To run this single test, use
+#
+#  ctest --verbose -R test_keyword_spotter_py
+
+import unittest
+import wave
+from pathlib import Path
+from typing import Tuple
+
+import numpy as np
+import sherpa_onnx
+
+d = "/tmp/onnx-models"
+# Please refer to
+# https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
+# to download pre-trained models for testing
+
+
+def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
+    """
+    Args:
+      wave_filename:
+        Path to a wave file. It should be single channel and each sample should
+        be 16-bit. Its sample rate does not need to be 16kHz.
+    Returns:
+      Return a tuple containing:
+       - A 1-D array of dtype np.float32 containing the samples, which are
+         normalized to the range [-1, 1].
+       - sample rate of the wave file
+    """
+
+    with wave.open(wave_filename) as f:
+        assert f.getnchannels() == 1, f.getnchannels()
+        assert f.getsampwidth() == 2, f.getsampwidth()  # it is in bytes
+        num_samples = f.getnframes()
+        samples = f.readframes(num_samples)
+        samples_int16 = np.frombuffer(samples, dtype=np.int16)
+        samples_float32 = samples_int16.astype(np.float32)
+
+        samples_float32 = samples_float32 / 32768
+        return samples_float32, f.getframerate()
+
+
+class TestKeywordSpotter(unittest.TestCase):
+    def test_zipformer_transducer_en(self):
+        for use_int8 in [True, False]:
+            if use_int8:
+                encoder = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx"
+                decoder = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx"
+                joiner = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.int8.onnx"
+            else:
+                encoder = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.onnx"
+                decoder = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.onnx"
+                joiner = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.onnx"
+
+            tokens = (
+                f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/tokens.txt"
+            )
+            keywords_file = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt"
+            wave0 = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/test_wavs/0.wav"
+            wave1 = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/test_wavs/1.wav"
+
+            if not Path(encoder).is_file():
+                print("skipping test_zipformer_transducer_en()")
+                return
+            keyword_spotter = sherpa_onnx.KeywordSpotter(
+                encoder=encoder,
+                decoder=decoder,
+                joiner=joiner,
+                tokens=tokens,
+                num_threads=1,
+                keywords_file=keywords_file,
+                provider="cpu",
+            )
+            streams = []
+            waves = [wave0, wave1]
+            for wave in waves:
+                s = keyword_spotter.create_stream()
+                samples, sample_rate = read_wave(wave)
+                s.accept_waveform(sample_rate, samples)
+
+                tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
+                s.accept_waveform(sample_rate, tail_paddings)
+                s.input_finished()
+                streams.append(s)
+
+            results = [""] * len(streams)
+            while True:
+                ready_list = []
+                for i, s in enumerate(streams):
+                    if keyword_spotter.is_ready(s):
+                        ready_list.append(s)
+                    r = keyword_spotter.get_result(s)
+                    if r:
+                        print(f"{r} is detected.")
+                        results[i] += f"{r}/"
+                if len(ready_list) == 0:
+                    break
+                keyword_spotter.decode_streams(ready_list)
+            for wave_filename, result in zip(waves, results):
+                print(f"{wave_filename}\n{result[0:-1]}")
+                print("-" * 10)
+
+    def test_zipformer_transducer_cn(self):
+        for use_int8 in [True, False]:
+            if use_int8:
+                encoder = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx"
+                decoder = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx"
+                joiner = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.int8.onnx"
+            else:
+                encoder = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.onnx"
+                decoder = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.onnx"
+                joiner = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.onnx"
+
+            tokens = (
+                f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt"
+            )
+            keywords_file = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt"
+            wave0 = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav"
+            wave1 = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/4.wav"
+            wave2 = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/5.wav"
+
+            if not Path(encoder).is_file():
+                print("skipping test_zipformer_transducer_cn()")
+                return
+            keyword_spotter = sherpa_onnx.KeywordSpotter(
+                encoder=encoder,
+                decoder=decoder,
+                joiner=joiner,
+                tokens=tokens,
+                num_threads=1,
+                keywords_file=keywords_file,
+                provider="cpu",
+            )
+            streams = []
+            waves = [wave0, wave1, wave2]
+            for wave in waves:
+                s = keyword_spotter.create_stream()
+                samples, sample_rate = read_wave(wave)
+                s.accept_waveform(sample_rate, samples)
+
+                tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
+                s.accept_waveform(sample_rate, tail_paddings)
+                s.input_finished()
+                streams.append(s)
+
+            results = [""] * len(streams)
+            while True:
+                ready_list = []
+                for i, s in enumerate(streams):
+                    if keyword_spotter.is_ready(s):
+                        ready_list.append(s)
+                    r = keyword_spotter.get_result(s)
+                    if r:
+                        print(f"{r} is detected.")
+                        results[i] += f"{r}/"
+                if len(ready_list) == 0:
+                    break
+                keyword_spotter.decode_streams(ready_list)
+            for wave_filename, result in zip(waves, results):
+                print(f"{wave_filename}\n{result[0:-1]}")
+                print("-" * 10)
+
+
+if __name__ == "__main__":
+    unittest.main()
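
For quick reference, here is a minimal sketch of how the pieces added by this patch fit together end to end. It is distilled from the examples above; the model directory below is an assumption borrowed from the CI script (point it at wherever you extracted a model), and the one second of silence stands in for real microphone or file samples.

    #!/usr/bin/env python3
    # Minimal usage sketch of the KeywordSpotter Python API added by this
    # patch. NOTE: the model directory is an assumption taken from the CI
    # script above; replace it with the model you actually downloaded.

    import numpy as np
    import sherpa_onnx

    d = "/tmp/onnx-models/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01"

    spotter = sherpa_onnx.KeywordSpotter(
        tokens=f"{d}/tokens.txt",
        encoder=f"{d}/encoder-epoch-12-avg-2-chunk-16-left-64.onnx",
        decoder=f"{d}/decoder-epoch-12-avg-2-chunk-16-left-64.onnx",
        joiner=f"{d}/joiner-epoch-12-avg-2-chunk-16-left-64.onnx",
        keywords_file=f"{d}/test_wavs/test_keywords.txt",
    )

    stream = spotter.create_stream()

    # Feed audio as 1-D float32 samples in [-1, 1]; here one second of
    # silence just to show the call pattern.
    stream.accept_waveform(16000, np.zeros(16000, dtype=np.float32))
    stream.input_finished()

    while spotter.is_ready(stream):
        spotter.decode_stream(stream)
        result = spotter.get_result(stream)  # "" until a keyword fires
        if result:
            print(f"{result} is detected.")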