Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add customized score for hotwords & add Finalize to stream #281

Merged
merged 10 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions sherpa-ncnn/csrc/context-graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,27 @@
#include <utility>

namespace sherpa_ncnn {
void ContextGraph::Build(
const std::vector<std::vector<int32_t>> &token_ids) const {
void ContextGraph::Build(const std::vector<ContextItem> &token_ids) const {
for (int32_t i = 0; i < token_ids.size(); ++i) {
auto node = root_.get();
for (int32_t j = 0; j < token_ids[i].size(); ++j) {
int32_t token = token_ids[i][j];
auto &ids = std::get<0>(token_ids[i]);
float score = std::get<1>(token_ids[i]);
float token_score = score == 0.0 ? context_score_ : score / ids.size();
for (int32_t j = 0; j < ids.size(); ++j) {
int32_t token = ids[j];
bool is_end = j == ids.size() - 1;
if (0 == node->next.count(token)) {
bool is_end = j == token_ids[i].size() - 1;
node->next[token] = std::make_unique<ContextState>(
token, context_score_, node->node_score + context_score_,
is_end ? node->node_score + context_score_ : 0, is_end);
token, token_score, node->node_score + token_score,
is_end ? node->node_score + token_score : 0, is_end);
} else {
ContextState *current_node = node->next[token].get();
current_node->is_end = is_end || current_node->is_end;
current_node->token_score =
std::max(token_score, current_node->token_score);
current_node->node_score = node->node_score + current_node->token_score;
current_node->output_score =
current_node->is_end ? current_node->node_score : 0;
}
node = node->next[token].get();
}
Expand Down
7 changes: 3 additions & 4 deletions sherpa-ncnn/csrc/context-graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
#include <utility>
#include <vector>


namespace sherpa_ncnn {

class ContextGraph;
using ContextGraphPtr = std::shared_ptr<ContextGraph>;
using ContextItem = std::pair<std::vector<int32_t>, float>;

struct ContextState {
int32_t token;
Expand All @@ -39,8 +39,7 @@ struct ContextState {
class ContextGraph {
public:
ContextGraph() = default;
ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
float hotwords_score)
ContextGraph(const std::vector<ContextItem> &token_ids, float hotwords_score)
: context_score_(hotwords_score) {
root_ = std::make_unique<ContextState>(-1, 0, 0, 0, false);
root_->fail = root_.get();
Expand All @@ -57,7 +56,7 @@ class ContextGraph {
private:
float context_score_;
std::unique_ptr<ContextState> root_;
void Build(const std::vector<std::vector<int32_t>> &token_ids) const;
void Build(const std::vector<ContextItem> &token_ids) const;
void FillFailOutput() const;
};

Expand Down
77 changes: 1 addition & 76 deletions sherpa-ncnn/csrc/modified-beam-search-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,82 +117,7 @@ ncnn::Mat ModifiedBeamSearchDecoder::BuildDecoderInput(

void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out,
DecoderResult *result) {
int32_t context_size = model_->ContextSize();
Hypotheses cur = std::move(result->hyps);
/* encoder_out.w == encoder_out_dim, encoder_out.h == num_frames. */
for (int32_t t = 0; t != encoder_out.h; ++t) {
std::vector<Hypothesis> prev = cur.GetTopK(num_active_paths_, true);
cur.Clear();

ncnn::Mat decoder_input = BuildDecoderInput(prev);
ncnn::Mat decoder_out;
if (t == 0 && prev.size() == 1 && prev[0].ys.size() == context_size &&
!result->decoder_out.empty()) {
// When an endpoint is detected, we keep the decoder_out
decoder_out = result->decoder_out;
} else {
decoder_out = RunDecoder2D(model_, decoder_input);
}

// decoder_out.w == decoder_dim
// decoder_out.h == num_active_paths
ncnn::Mat encoder_out_t(encoder_out.w, 1, encoder_out.row(t));
// Note: encoder_out_t.h == 1, we rely on the binary op broadcasting
// in ncnn
// See https://github.com/Tencent/ncnn/wiki/binaryop-broadcasting
// broadcast B for outer axis, type 14
ncnn::Mat joiner_out = model_->RunJoiner(encoder_out_t, decoder_out);

// joiner_out.w == vocab_size
// joiner_out.h == num_active_paths
LogSoftmax(&joiner_out);

float *p_joiner_out = joiner_out;

for (int32_t i = 0; i != joiner_out.h; ++i) {
float prev_log_prob = prev[i].log_prob;
for (int32_t k = 0; k != joiner_out.w; ++k, ++p_joiner_out) {
*p_joiner_out += prev_log_prob;
}
}

auto topk = TopkIndex(static_cast<float *>(joiner_out),
joiner_out.w * joiner_out.h, num_active_paths_);

int32_t frame_offset = result->frame_offset;
for (auto i : topk) {
int32_t hyp_index = i / joiner_out.w;
int32_t new_token = i % joiner_out.w;

const float *p = joiner_out.row(hyp_index);

Hypothesis new_hyp = prev[hyp_index];

// blank id is fixed to 0
if (new_token != 0) {
new_hyp.ys.push_back(new_token);
new_hyp.num_trailing_blanks = 0;
new_hyp.timestamps.push_back(t + frame_offset);
} else {
++new_hyp.num_trailing_blanks;
}
// We have already added prev[hyp_index].log_prob to p[new_token]
new_hyp.log_prob = p[new_token];

cur.Add(std::move(new_hyp));
}
}

result->hyps = std::move(cur);
result->frame_offset += encoder_out.h;
auto hyp = result->hyps.GetMostProbable(true);

// set decoder_out in case of endpointing
ncnn::Mat decoder_input = BuildDecoderInput({hyp});
result->decoder_out = model_->RunDecoder(decoder_input);

result->tokens = std::move(hyp.ys);
result->num_trailing_blanks = hyp.num_trailing_blanks;
Decode(encoder_out, nullptr, result);
}

void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s,
Expand Down
26 changes: 18 additions & 8 deletions sherpa-ncnn/csrc/recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <utility>
#include <vector>

#include "sherpa-ncnn/csrc/context-graph.h"
#include "sherpa-ncnn/csrc/decoder.h"
#include "sherpa-ncnn/csrc/greedy-search-decoder.h"
#include "sherpa-ncnn/csrc/modified-beam-search-decoder.h"
Expand Down Expand Up @@ -225,7 +226,11 @@ class Recognizer::Impl {
}

RecognitionResult GetResult(Stream *s) const {
if (IsEndpoint(s)) {
s->Finalize();
}
DecoderResult decoder_result = s->GetResult();

decoder_->StripLeadingBlanks(&decoder_result);

// Those 2 parameters are figured out from sherpa source code
Expand Down Expand Up @@ -275,20 +280,25 @@ class Recognizer::Impl {

while (std::getline(is, line)) {
std::istringstream iss(line);
float tmp_score = 0.0;
while (iss >> word) {
if (sym_.contains(word)) {
int32_t number = sym_[word];
tmp.push_back(number);
} else {
NCNN_LOGE(
"Cannot find ID for hotword %s at line: %s. (Hint: words on the "
"same line are separated by spaces)",
word.c_str(), line.c_str());
exit(-1);
if (word[0] == ':') {
tmp_score = std::stof(word.substr(1));
} else {
NCNN_LOGE(
"Cannot find ID for hotword %s at line: %s. (Hint: words on "
"the "
"same line are separated by spaces)",
word.c_str(), line.c_str());
exit(-1);
}
}
}

hotwords_.push_back(std::move(tmp));
hotwords_.push_back(ContextItem(std::move(tmp), tmp_score));
}
}

Expand All @@ -298,7 +308,7 @@ class Recognizer::Impl {
std::unique_ptr<Decoder> decoder_;
Endpoint endpoint_;
SymbolTable sym_;
std::vector<std::vector<int32_t>> hotwords_;
std::vector<ContextItem> hotwords_;
};

Recognizer::Recognizer(const RecognizerConfig &config)
Expand Down
2 changes: 1 addition & 1 deletion sherpa-ncnn/csrc/sherpa-ncnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,10 @@ for a list of pre-trained models to download.
static_cast<int>(0.3 * expected_sampling_rate));
stream->AcceptWaveform(expected_sampling_rate, tail_paddings.data(),
tail_paddings.size());

while (recognizer.IsReady(stream.get())) {
recognizer.DecodeStream(stream.get());
}
stream->Finalize();
auto result = recognizer.GetResult(stream.get());
std::cout << "Done!\n";

Expand Down
16 changes: 16 additions & 0 deletions sherpa-ncnn/csrc/stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

#include "sherpa-ncnn/csrc/stream.h"

#include <iostream>

namespace sherpa_ncnn {

class Stream::Impl {
Expand Down Expand Up @@ -49,6 +51,18 @@ class Stream::Impl {
num_processed_frames_ = 0;
}

void Finalize() {
  // Without a context graph there are no hotword boosting scores to revert.
  if (!context_graph_) return;
  // For each active hypothesis, cancel the boosting score accumulated by a
  // partially matched hotword path and reset its context state.
  for (auto &entry : result_.hyps) {
    auto finalize_res = context_graph_->Finalize(entry.second.context_state);
    entry.second.log_prob += finalize_res.first;
    entry.second.context_state = finalize_res.second;
  }
  // Re-select the best hypothesis now that the scores have been adjusted.
  auto best = result_.hyps.GetMostProbable(true);
  result_.tokens = std::move(best.ys);
}

int32_t &GetNumProcessedFrames() { return num_processed_frames_; }

void SetResult(const DecoderResult &r) {
Expand Down Expand Up @@ -99,6 +113,8 @@ ncnn::Mat Stream::GetFrames(int32_t frame_index, int32_t n) const {

void Stream::Reset() { impl_->Reset(); }

void Stream::Finalize() { impl_->Finalize(); }

// Returns a mutable reference to the number of frames processed so far
// (before subsampling, per the header documentation); delegates to the impl.
int32_t &Stream::GetNumProcessedFrames() {
return impl_->GetNumProcessedFrames();
}
Expand Down
9 changes: 9 additions & 0 deletions sherpa-ncnn/csrc/stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ class Stream {

void Reset();

/**
* Finalize the decoding result. This is mainly for decoding with hotwords
* (i.e. when a context_graph is provided). It cancels the boosting score of
* any partially matching path. For example, if the hotword is "BCD", the path
* "ABC" receives the boosting score for "BC", but it fails to match the whole
* hotword "BCD", so we have to cancel the score for "BC" at the end.
*/
void Finalize();

// Return a reference to the number of processed frames so far
// before subsampling..
// Initially, it is 0. It is always less than NumFramesReady().
Expand Down
Loading