From ccf100dcf0fc00de540b71f0d6f902c9ed27b7b9 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 31 Aug 2023 19:33:02 +0800 Subject: [PATCH] Fix hotword bugs --- CMakeLists.txt | 2 +- sherpa-ncnn/csrc/modified-beam-search-decoder.cc | 2 +- sherpa-ncnn/csrc/recognizer.cc | 14 ++++++-------- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index d59a019e..08082d8c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.13 FATAL_ERROR) project(sherpa-ncnn) -set(SHERPA_NCNN_VERSION "2.1.0") +set(SHERPA_NCNN_VERSION "2.1.1") # Disable warning about # diff --git a/sherpa-ncnn/csrc/modified-beam-search-decoder.cc b/sherpa-ncnn/csrc/modified-beam-search-decoder.cc index e72a6791..5e4a4ccc 100644 --- a/sherpa-ncnn/csrc/modified-beam-search-decoder.cc +++ b/sherpa-ncnn/csrc/modified-beam-search-decoder.cc @@ -251,7 +251,7 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s, new_hyp.ys.push_back(new_token); new_hyp.num_trailing_blanks = 0; new_hyp.timestamps.push_back(t + frame_offset); - if (s != nullptr && s->GetContextGraph() != nullptr) { + if (s && s->GetContextGraph()) { auto context_res = s->GetContextGraph()->ForwardOneStep(context_state, new_token); context_score = context_res.first; diff --git a/sherpa-ncnn/csrc/recognizer.cc b/sherpa-ncnn/csrc/recognizer.cc index 236a757e..12b373e6 100644 --- a/sherpa-ncnn/csrc/recognizer.cc +++ b/sherpa-ncnn/csrc/recognizer.cc @@ -172,11 +172,7 @@ class Recognizer::Impl { void DecodeStream(Stream *s) const { int32_t segment = model_->Segment(); int32_t offset = model_->Offset(); - bool has_context_graph = false; - if (!has_context_graph && s->GetContextGraph()) { - has_context_graph = true; - } ncnn::Mat features = s->GetFrames(s->GetNumProcessedFrames(), segment); s->GetNumProcessedFrames() += offset; std::vector states = s->GetStates(); @@ -184,7 +180,7 @@ class Recognizer::Impl { ncnn::Mat encoder_out; std::tie(encoder_out, states) = model_->RunEncoder(features, states); - if (has_context_graph) { + if (s->GetContextGraph()) { decoder_->Decode(encoder_out, s, &s->GetResult()); } else { decoder_->Decode(encoder_out, &s->GetResult()); @@ -216,7 +212,7 @@ class Recognizer::Impl { } // Caution: We need to keep the decoder output state ncnn::Mat decoder_out = s->GetResult().decoder_out; - s->SetResult(decoder_->GetEmptyResult()); + s->SetResult(r); s->GetResult().decoder_out = decoder_out; // don't reset encoder state @@ -284,8 +280,10 @@ class Recognizer::Impl { int32_t number = sym_[word]; tmp.push_back(number); } else { - NCNN_LOGE("Cannot find ID for hotword %s at line: %s", word.c_str(), - line.c_str()); + NCNN_LOGE( + "Cannot find ID for hotword %s at line: %s. (Hint: words on the " + "same line are separated by spaces)", + word.c_str(), line.c_str()); exit(-1); } }