Skip to content

Commit

Permalink
Fix hotword bugs (#262)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Aug 31, 2023
1 parent 401de81 commit 8be3e08
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 10 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
project(sherpa-ncnn)

set(SHERPA_NCNN_VERSION "2.1.0")
set(SHERPA_NCNN_VERSION "2.1.1")

# Disable warning about
#
Expand Down
2 changes: 1 addition & 1 deletion sherpa-ncnn/csrc/modified-beam-search-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s,
new_hyp.ys.push_back(new_token);
new_hyp.num_trailing_blanks = 0;
new_hyp.timestamps.push_back(t + frame_offset);
if (s != nullptr && s->GetContextGraph() != nullptr) {
if (s && s->GetContextGraph()) {
auto context_res =
s->GetContextGraph()->ForwardOneStep(context_state, new_token);
context_score = context_res.first;
Expand Down
14 changes: 6 additions & 8 deletions sherpa-ncnn/csrc/recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,19 +172,15 @@ class Recognizer::Impl {
void DecodeStream(Stream *s) const {
int32_t segment = model_->Segment();
int32_t offset = model_->Offset();
bool has_context_graph = false;

if (!has_context_graph && s->GetContextGraph()) {
has_context_graph = true;
}
ncnn::Mat features = s->GetFrames(s->GetNumProcessedFrames(), segment);
s->GetNumProcessedFrames() += offset;
std::vector<ncnn::Mat> states = s->GetStates();

ncnn::Mat encoder_out;
std::tie(encoder_out, states) = model_->RunEncoder(features, states);

if (has_context_graph) {
if (s->GetContextGraph()) {
decoder_->Decode(encoder_out, s, &s->GetResult());
} else {
decoder_->Decode(encoder_out, &s->GetResult());
Expand Down Expand Up @@ -216,7 +212,7 @@ class Recognizer::Impl {
}
// Caution: We need to keep the decoder output state
ncnn::Mat decoder_out = s->GetResult().decoder_out;
s->SetResult(decoder_->GetEmptyResult());
s->SetResult(r);
s->GetResult().decoder_out = decoder_out;

// don't reset encoder state
Expand Down Expand Up @@ -284,8 +280,10 @@ class Recognizer::Impl {
int32_t number = sym_[word];
tmp.push_back(number);
} else {
NCNN_LOGE("Cannot find ID for hotword %s at line: %s", word.c_str(),
line.c_str());
NCNN_LOGE(
"Cannot find ID for hotword %s at line: %s. (Hint: words on the "
"same line are separated by spaces)",
word.c_str(), line.c_str());
exit(-1);
}
}
Expand Down

0 comments on commit 8be3e08

Please sign in to comment.