Skip to content

Commit

Permalink
export per-token scores also for greedy-search (online-transducer)
Browse files Browse the repository at this point in the history
- export un-scaled lm_probs (modified-beam search, online-transducer)
- polishing
  • Loading branch information
KarelVesely84 committed Feb 15, 2024
1 parent 3d4f212 commit a4bc688
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 10 deletions.
3 changes: 2 additions & 1 deletion sherpa-onnx/csrc/hypothesis.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ struct Hypothesis {

// The acoustic probability for each token in ys.
// Used for keyword spotting task.
// For transducer mofified beam-search, this is filled with log_posterior scores.
// For transducer modified beam-search and greedy-search,
// this is filled with log_posterior scores.
std::vector<float> ys_probs;

// lm_probs[i] contains the lm score for each token in ys.
Expand Down
12 changes: 7 additions & 5 deletions sherpa-onnx/csrc/online-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ namespace sherpa_onnx {
/// Helper for `OnlineRecognizerResult::AsJsonString()`
template<typename T>
const std::string& VecToString(const std::vector<T>& vec, int32_t precision = 6) {
std::ostringstrean oss;
std::ostringstream oss;
oss << std::fixed << std::setprecision(precision);
oss << "[ " <<
oss << "[ ";
std::string sep = "";
for (auto item : vec) {
oss << sep << item;
Expand All @@ -35,9 +35,11 @@ const std::string& VecToString(const std::vector<T>& vec, int32_t precision = 6)

/// Helper for `OnlineRecognizerResult::AsJsonString()`
template<> // explicit specialization for T = std::string
const std::string& VecToString<std::string>(const std::vector<T>& vec, int32_t) { // ignore 2nd arg
std::ostringstrean oss;
oss << "[ " <<
const std::string& VecToString<std::string>(const std::vector<std::string>& vec,
int32_t) // ignore 2nd arg
{
std::ostringstream oss;
oss << "[ ";
std::string sep = "";
for (auto item : vec) {
oss << sep << "\"" << item << "\"";
Expand Down
8 changes: 8 additions & 0 deletions sherpa-onnx/csrc/online-transducer-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
frame_offset = other.frame_offset;
timestamps = other.timestamps;

ys_probs = other.ys_probs;
lm_probs = other.lm_probs;
context_scores = other.context_scores;

return *this;
}

Expand All @@ -60,6 +64,10 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
frame_offset = other.frame_offset;
timestamps = std::move(other.timestamps);

ys_probs = std::move(other.ys_probs);
lm_probs = std::move(other.lm_probs);
context_scores = std::move(other.context_scores);

return *this;
}

Expand Down
16 changes: 16 additions & 0 deletions sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,11 @@ void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks(
r->tokens = std::vector<int64_t>(start, end);
}


void OnlineTransducerGreedySearchDecoder::Decode(
Ort::Value encoder_out,
std::vector<OnlineTransducerDecoderResult> *result) {

std::vector<int64_t> encoder_out_shape =
encoder_out.GetTensorTypeAndShapeInfo().GetShape();

Expand All @@ -97,6 +99,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
break;
}
}

if (is_batch_decoder_out_cached) {
auto &r = result->front();
std::vector<int64_t> decoder_out_shape =
Expand Down Expand Up @@ -124,6 +127,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
if (blank_penalty_ > 0.0) {
p_logit[0] -= blank_penalty_; // assuming blank id is 0
}

auto y = static_cast<int32_t>(std::distance(
static_cast<const float *>(p_logit),
std::max_element(static_cast<const float *>(p_logit),
Expand All @@ -138,6 +142,18 @@ void OnlineTransducerGreedySearchDecoder::Decode(
} else {
++r.num_trailing_blanks;
}

// export the per-token log scores
if (y != 0 && y != unk_id_) {
LogSoftmax(p_logit, vocab_size); // renormalize probabilities,
// save time by doing it only for
// emitted symbols
float *p_logprob = p_logit; // rename p_logit as p_logprob,
// now it contains normalized
// probability
r.ys_probs.push_back(p_logprob[y]);
}

}
if (emitted) {
Ort::Value decoder_input = model_->BuildDecoderInput(*result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks(
r->tokens = std::move(tokens);
r->timestamps = std::move(hyp.timestamps);


// export per-token scores
r->ys_probs = std::move(hyp.ys_probs);
r->lm_probs = std::move(hyp.lm_probs);
Expand Down Expand Up @@ -149,8 +148,6 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
}
p_logprob = p_logit; // we changed p_logprob in the above for loop

// KarelVesely: Sholud the context score be added already before taking topk tokens ?

for (int32_t b = 0; b != batch_size; ++b) {
int32_t frame_offset = (*result)[b].frame_offset;
int32_t start = hyps_row_splits[b];
Expand Down Expand Up @@ -190,14 +187,17 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
prev_lm_log_prob; // log_prob only includes the
// score of the transducer
// export the per-token log scores
{
if (new_token != 0 && new_token != unk_id_) {
const Hypothesis& prev_i = prev[hyp_index];
// subtract 'prev[i]' path scores, which were added before
// for getting topk tokens
float y_prob = p_logprob[k] - prev_i.log_prob - prev_i.lm_log_prob;
new_hyp.ys_probs.push_back(y_prob);

float lm_prob = new_hyp.lm_log_prob - prev_lm_log_prob;
if (lm_scale_ != 0.0) {
lm_prob /= lm_scale_; // remove lm-scale
}
new_hyp.lm_probs.push_back(lm_prob);

new_hyp.context_scores.push_back(context_score);
Expand Down
3 changes: 3 additions & 0 deletions sherpa-onnx/python/sherpa_onnx/online_recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,9 @@ def is_ready(self, s: OnlineStream) -> bool:
def get_result(self, s: OnlineStream) -> str:
return self.recognizer.get_result(s).text.strip()

def get_result_as_json_string(self, s: OnlineStream) -> str:
return self.recognizer.get_result(s).as_json_string()

def tokens(self, s: OnlineStream) -> List[str]:
return self.recognizer.get_result(s).tokens

Expand Down

0 comments on commit a4bc688

Please sign in to comment.