From 4467fbe132d517e21e549289afb22ceedae4c544 Mon Sep 17 00:00:00 2001 From: vsd-vector Date: Tue, 10 Sep 2024 18:02:32 +0300 Subject: [PATCH] Preserve previous result as context for next segment --- .../csrc/online-recognizer-transducer-impl.h | 23 +++++++++++++------ 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index ab1e165f3..a61eafa28 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -360,11 +360,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { } void Reset(OnlineStream *s) const override { + int32_t context_size = model_->ContextSize(); + { // segment is incremented only when the last - // result is not empty + // result is not empty, contains non-blanks and longer than context_size) const auto &r = s->GetResult(); - if (!r.tokens.empty() && r.tokens.back() != 0) { + if (!r.tokens.empty() && r.tokens.back() != 0 && r.tokens.size() > context_size) { s->GetCurrentSegment() += 1; } } @@ -372,10 +374,6 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { // reset encoder states // s->SetStates(model_->GetEncoderInitStates()); - // we keep the decoder_out - decoder_->UpdateDecoderOut(&s->GetResult()); - Ort::Value decoder_out = std::move(s->GetResult().decoder_out); - auto r = decoder_->GetEmptyResult(); if (config_.decoding_method == "modified_beam_search" && nullptr != s->GetContextGraph()) { @@ -383,8 +381,19 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { it->second.context_state = s->GetContextGraph()->Root(); } } + + auto last_result = s->GetResult(); + // if last result is not empty, then + // preserve last tokens as the context for next result + if (static_cast(last_result.tokens.size()) > context_size) { + std::vector context(last_result.tokens.end() - context_size, last_result.tokens.end()); + + Hypotheses context_hyp({{context, 0}}); + r.hyps = std::move(context_hyp); + r.tokens = std::move(context); + } + s->SetResult(r); - s->GetResult().decoder_out = std::move(decoder_out); // Note: We only update counters. The underlying audio samples // are not discarded.