-
Notifications
You must be signed in to change notification settings - Fork 424
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e342705
commit 2c4b952
Showing
3 changed files
with
208 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
206 changes: 206 additions & 0 deletions
206
python-api-examples/speech-recognition-from-microphone-with-endpoint-detection-alsa.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Real-time speech recognition from a microphone with sherpa-onnx Python API | ||
# with endpoint detection. | ||
# | ||
# Note: This script uses ALSA and works only on Linux systems. | ||
# | ||
# Please refer to | ||
# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html | ||
# to download pre-trained models | ||
|
||
import argparse | ||
import sys | ||
from pathlib import Path | ||
import sherpa_onnx | ||
|
||
|
||
def assert_file_exists(filename: str): | ||
assert Path(filename).is_file(), ( | ||
f"{filename} does not exist!\n" | ||
"Please refer to " | ||
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" | ||
) | ||
|
||
|
||
def get_args(): | ||
parser = argparse.ArgumentParser( | ||
formatter_class=argparse.ArgumentDefaultsHelpFormatter | ||
) | ||
|
||
parser.add_argument( | ||
"--tokens", | ||
type=str, | ||
required=True, | ||
help="Path to tokens.txt", | ||
) | ||
|
||
parser.add_argument( | ||
"--encoder", | ||
type=str, | ||
required=True, | ||
help="Path to the encoder model", | ||
) | ||
|
||
parser.add_argument( | ||
"--decoder", | ||
type=str, | ||
required=True, | ||
help="Path to the decoder model", | ||
) | ||
|
||
parser.add_argument( | ||
"--joiner", | ||
type=str, | ||
required=True, | ||
help="Path to the joiner model", | ||
) | ||
|
||
parser.add_argument( | ||
"--decoding-method", | ||
type=str, | ||
default="greedy_search", | ||
help="Valid values are greedy_search and modified_beam_search", | ||
) | ||
|
||
parser.add_argument( | ||
"--provider", | ||
type=str, | ||
default="cpu", | ||
help="Valid values: cpu, cuda, coreml", | ||
) | ||
|
||
parser.add_argument( | ||
"--hotwords-file", | ||
type=str, | ||
default="", | ||
help=""" | ||
The file containing hotwords, one words/phrases per line, and for each | ||
phrase the bpe/cjkchar are separated by a space. For example: | ||
▁HE LL O ▁WORLD | ||
你 好 世 界 | ||
""", | ||
) | ||
|
||
parser.add_argument( | ||
"--hotwords-score", | ||
type=float, | ||
default=1.5, | ||
help=""" | ||
The hotword score of each token for biasing word/phrase. Used only if | ||
--hotwords-file is given. | ||
""", | ||
) | ||
|
||
parser.add_argument( | ||
"--blank-penalty", | ||
type=float, | ||
default=0.0, | ||
help=""" | ||
The penalty applied on blank symbol during decoding. | ||
Note: It is a positive value that would be applied to logits like | ||
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is | ||
[batch_size, vocab] and blank id is 0). | ||
""", | ||
) | ||
|
||
parser.add_argument( | ||
"--device-name", | ||
type=str, | ||
required=True, | ||
help=""" | ||
The device name specifies which microphone to use in case there are several | ||
on your system. You can use | ||
arecord -l | ||
to find all available microphones on your computer. For instance, if it outputs | ||
**** List of CAPTURE Hardware Devices **** | ||
card 3: UACDemoV10 [UACDemoV1.0], device 0: USB Audio [USB Audio] | ||
Subdevices: 1/1 | ||
Subdevice #0: subdevice #0 | ||
and if you want to select card 3 and the device 0 on that card, please use: | ||
plughw:3,0 | ||
as the device_name. | ||
""", | ||
) | ||
|
||
return parser.parse_args() | ||
|
||
|
||
def create_recognizer(args): | ||
assert_file_exists(args.encoder) | ||
assert_file_exists(args.decoder) | ||
assert_file_exists(args.joiner) | ||
assert_file_exists(args.tokens) | ||
# Please replace the model files if needed. | ||
# See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html | ||
# for download links. | ||
recognizer = sherpa_onnx.OnlineRecognizer.from_transducer( | ||
tokens=args.tokens, | ||
encoder=args.encoder, | ||
decoder=args.decoder, | ||
joiner=args.joiner, | ||
num_threads=1, | ||
sample_rate=16000, | ||
feature_dim=80, | ||
enable_endpoint_detection=True, | ||
rule1_min_trailing_silence=2.4, | ||
rule2_min_trailing_silence=1.2, | ||
rule3_min_utterance_length=300, # it essentially disables this rule | ||
decoding_method=args.decoding_method, | ||
provider=args.provider, | ||
hotwords_file=args.hotwords_file, | ||
hotwords_score=args.hotwords_score, | ||
blank_penalty=args.blank_penalty, | ||
) | ||
return recognizer | ||
|
||
|
||
def main(): | ||
args = get_args() | ||
device_name = args.device_name | ||
print(f"device_name: {device_name}") | ||
alsa = sherpa_onnx.Alsa(device_name) | ||
|
||
print("Creating recognizer") | ||
recognizer = create_recognizer(args) | ||
print("Started! Please speak") | ||
|
||
sample_rate = 16000 | ||
samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms | ||
|
||
stream = recognizer.create_stream() | ||
|
||
last_result = "" | ||
segment_id = 0 | ||
while True: | ||
samples = alsa.read(samples_per_read) # a blocking read | ||
stream.accept_waveform(sample_rate, samples) | ||
while recognizer.is_ready(stream): | ||
recognizer.decode_stream(stream) | ||
|
||
is_endpoint = recognizer.is_endpoint(stream) | ||
|
||
result = recognizer.get_result(stream) | ||
|
||
if result and (last_result != result): | ||
last_result = result | ||
print("\r{}:{}".format(segment_id, result), end="", flush=True) | ||
if is_endpoint: | ||
if result: | ||
print("\r{}:{}".format(segment_id, result), flush=True) | ||
segment_id += 1 | ||
recognizer.reset(stream) | ||
|
||
|
||
if __name__ == "__main__": | ||
try: | ||
main() | ||
except KeyboardInterrupt: | ||
print("\nCaught Ctrl + C. Exiting") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters