Skip to content

Commit

Permalink
Support TDNN models from the icefall yesno recipe
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Aug 12, 2023
1 parent 8d2870a commit 056c39d
Show file tree
Hide file tree
Showing 8 changed files with 214 additions and 9 deletions.
1 change: 1 addition & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ set(sources
offline-recognizer.cc
offline-rnn-lm.cc
offline-stream.cc
offline-tdnn-ctc-model.cc
offline-tdnn-model-config.cc
offline-transducer-greedy-search-decoder.cc
offline-transducer-model-config.cc
Expand Down
19 changes: 18 additions & 1 deletion sherpa-onnx/csrc/offline-ctc-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
#include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h"
#include "sherpa-onnx/csrc/onnx-utils.h"

namespace {

enum class ModelType {
kEncDecCTCModelBPE,
kTdnn,
kUnkown,
};

Expand Down Expand Up @@ -55,6 +57,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,

if (model_type.get() == std::string("EncDecCTCModelBPE")) {
return ModelType::kEncDecCTCModelBPE;
} else if (model_type.get() == std::string("tdnn_lstm")) {
return ModelType::kTdnn;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
return ModelType::kUnkown;
Expand All @@ -65,8 +69,18 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
const OfflineModelConfig &config) {
ModelType model_type = ModelType::kUnkown;

std::string filename;
if (!config.nemo_ctc.model.empty()) {
filename = config.nemo_ctc.model;
} else if (!config.tdnn.model.empty()) {
filename = config.tdnn.model;
} else {
SHERPA_ONNX_LOGE("Please specify a CTC model");
exit(-1);
}

{
auto buffer = ReadFile(config.nemo_ctc.model);
auto buffer = ReadFile(filename);

model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
}
Expand All @@ -75,6 +89,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
case ModelType::kEncDecCTCModelBPE:
return std::make_unique<OfflineNemoEncDecCtcModel>(config);
break;
case ModelType::kTdnn:
return std::make_unique<OfflineTdnnCtcModel>(config);
break;
case ModelType::kUnkown:
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
return nullptr;
Expand Down
6 changes: 3 additions & 3 deletions sherpa-onnx/csrc/offline-ctc-model.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ class OfflineCtcModel {

/** SubsamplingFactor of the model
*
* For Citrinet, the subsampling factor is usually 4.
* For Conformer CTC, the subsampling factor is usually 8.
* For NeMo Citrinet, the subsampling factor is usually 4.
* For NeMo Conformer CTC, the subsampling factor is usually 8.
*/
virtual int32_t SubsamplingFactor() const = 0;
virtual int32_t SubsamplingFactor() const { return 1; }

/** Return an allocator for allocating memory
*/
Expand Down
17 changes: 14 additions & 3 deletions sherpa-onnx/csrc/offline-recognizer-ctc-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
std::string text;

for (int32_t i = 0; i != src.tokens.size(); ++i) {
if (sym_table.contains("SIL") && src.tokens[i] == sym_table["SIL"]) {
// TDNN-LSTM from yesno has a SIL token, we should remove it.
continue;
}
auto sym = sym_table[src.tokens[i]];
text.append(sym);
r.tokens.push_back(std::move(sym));
Expand All @@ -46,14 +50,21 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
model_->FeatureNormalizationMethod();

if (config.decoding_method == "greedy_search") {
if (!symbol_table_.contains("<blk>")) {
if (!symbol_table_.contains("<blk>") &&
!symbol_table_.contains("<eps>")) {
SHERPA_ONNX_LOGE(
"We expect that tokens.txt contains "
"the symbol <blk> and its ID.");
"the symbol <blk> or <eps> and its ID.");
exit(-1);
}

int32_t blank_id = symbol_table_["<blk>"];
int32_t blank_id = 0;
if (symbol_table_.contains("<blk>")) {
blank_id = symbol_table_["<blk>"];
} else if (symbol_table_.contains("<eps>")) {
blank_id = symbol_table_["<eps>"];
}

decoder_ = std::make_unique<OfflineCtcGreedySearchDecoder>(blank_id);
} else {
SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
Expand Down
16 changes: 15 additions & 1 deletion sherpa-onnx/csrc/offline-recognizer-impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
} else if (model_type == "nemo_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(config);
} else if (model_type == "tdnn_lstm") {
return std::make_unique<OfflineRecognizerCtcImpl>(config);
} else if (model_type == "whisper") {
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
} else {
Expand All @@ -46,6 +48,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
model_filename = config.model_config.paraformer.model;
} else if (!config.model_config.nemo_ctc.model.empty()) {
model_filename = config.model_config.nemo_ctc.model;
} else if (!config.model_config.tdnn.model.empty()) {
model_filename = config.model_config.tdnn.model;
} else if (!config.model_config.whisper.encoder.empty()) {
model_filename = config.model_config.whisper.encoder;
} else {
Expand Down Expand Up @@ -84,6 +88,11 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
"paraformer-onnxruntime-python-example/blob/main/add-model-metadata.py"
"\n "
"(3) Whisper"
"\n "
"(4) Tdnn models"
"\n "
"https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn"
"\n"
"\n");
exit(-1);
}
Expand All @@ -102,6 +111,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
return std::make_unique<OfflineRecognizerCtcImpl>(config);
}

if (model_type == "tdnn_lstm") {
return std::make_unique<OfflineRecognizerCtcImpl>(config);
}

if (strncmp(model_type.c_str(), "whisper", 7) == 0) {
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
}
Expand All @@ -112,7 +125,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
" - Non-streaming transducer models from icefall\n"
" - Non-streaming Paraformer models from FunASR\n"
" - EncDecCTCModelBPE models from NeMo\n"
" - Whisper models\n",
" - Whisper models\n"
" - Tdnn models\n",
model_type.c_str());

exit(-1);
Expand Down
106 changes: 106 additions & 0 deletions sherpa-onnx/csrc/offline-tdnn-ctc-model.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// sherpa-onnx/csrc/offline-tdnn-ctc-model.cc
//
// Copyright (c) 2023 Xiaomi Corporation

#include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h"

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
#include "sherpa-onnx/csrc/transpose.h"

namespace sherpa_onnx {

// Pimpl implementation: owns the onnxruntime session for the TDNN CTC
// model (yesno recipe from icefall) and hides ORT details from the header.
class OfflineTdnnCtcModel::Impl {
 public:
  explicit Impl(const OfflineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    Init();
  }

  /** Run the TDNN model on a batch of features.
   *
   * @param features A 3-D float tensor of shape (N, T, C).
   * @return A pair (log_probs, log_probs_length):
   *         - log_probs: output of the model, shape (N, T', vocab_size)
   *         - log_probs_length: 1-D int64 tensor of shape (N,)
   *
   * Note: the model itself produces no length output, so every utterance
   * in the batch is reported as spanning all T' output frames.
   */
  std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features) {
    auto nnet_out =
        sess_->Run({}, input_names_ptr_.data(), &features, 1,
                   output_names_ptr_.data(), output_names_ptr_.size());

    std::vector<int64_t> nnet_out_shape =
        nnet_out[0].GetTensorTypeAndShapeInfo().GetShape();

    int64_t batch_size = nnet_out_shape[0];

    // One length entry per utterance, each equal to the frame count T'.
    std::vector<int64_t> out_length_vec(batch_size, nnet_out_shape[1]);

    // Fix: the length tensor must have shape (N,), not (1,). The previous
    // shape {1} disagreed with the N elements passed to CreateTensor for
    // any batch size > 1.
    std::vector<int64_t> out_length_shape = {batch_size};

    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    Ort::Value nnet_out_length = Ort::Value::CreateTensor(
        memory_info, out_length_vec.data(), out_length_vec.size(),
        out_length_shape.data(), out_length_shape.size());

    // nnet_out_length wraps stack-local storage (out_length_vec); clone it
    // into allocator-owned memory so the returned tensor outlives this frame.
    return {std::move(nnet_out[0]), Clone(Allocator(), &nnet_out_length)};
  }

  int32_t VocabSize() const { return vocab_size_; }

  OrtAllocator *Allocator() const { return allocator_; }

 private:
  // Load the onnx model from config_.tdnn.model, cache input/output names,
  // and read vocab_size from the model metadata.
  void Init() {
    auto buf = ReadFile(config_.tdnn.model);

    sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
                                           sess_opts_);

    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);

    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);

    // get meta data
    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      PrintModelMetadata(os, meta_data);
      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
    }

    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
    SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
  }

 private:
  OfflineModelConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> sess_;

  std::vector<std::string> input_names_;
  std::vector<const char *> input_names_ptr_;

  std::vector<std::string> output_names_;
  std::vector<const char *> output_names_ptr_;

  int32_t vocab_size_ = 0;
};

OfflineTdnnCtcModel::OfflineTdnnCtcModel(const OfflineModelConfig &config)
    : impl_(std::make_unique<Impl>(config)) {}

OfflineTdnnCtcModel::~OfflineTdnnCtcModel() = default;

// features_length is deliberately ignored: Impl::Forward derives per-utterance
// lengths from the shape of the model's output tensor instead.
std::pair<Ort::Value, Ort::Value> OfflineTdnnCtcModel::Forward(
    Ort::Value features, Ort::Value /*features_length*/) {
  return impl_->Forward(std::move(features));
}

int32_t OfflineTdnnCtcModel::VocabSize() const { return impl_->VocabSize(); }

OrtAllocator *OfflineTdnnCtcModel::Allocator() const {
  return impl_->Allocator();
}

} // namespace sherpa_onnx
56 changes: 56 additions & 0 deletions sherpa-onnx/csrc/offline-tdnn-ctc-model.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// sherpa-onnx/csrc/offline-tdnn-ctc-model.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TDNN_CTC_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TDNN_CTC_MODEL_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-ctc-model.h"
#include "sherpa-onnx/csrc/offline-model-config.h"

namespace sherpa_onnx {

/** This class implements the tdnn model of the yesno recipe from icefall.
*
* See
* https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn
*/
class OfflineTdnnCtcModel : public OfflineCtcModel {
 public:
  explicit OfflineTdnnCtcModel(const OfflineModelConfig &config);
  ~OfflineTdnnCtcModel() override;

  /** Run the forward method of the model.
   *
   * @param features A tensor of shape (N, T, C). It is changed in-place.
   * @param features_length A 1-D tensor of shape (N,) containing number of
   *                        valid frames in `features` before padding.
   *                        Its dtype is int64_t.
   *                        Note: this model ignores it; the returned lengths
   *                        are derived from the output tensor's shape.
   *
   * @return Return a pair containing:
   *  - log_probs: A 3-D tensor of shape (N, T', vocab_size).
   *  - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t
   */
  std::pair<Ort::Value, Ort::Value> Forward(
      Ort::Value features, Ort::Value /*features_length*/) override;

  /** Return the vocabulary size of the model
   */
  int32_t VocabSize() const override;

  /** Return an allocator for allocating memory
   */
  OrtAllocator *Allocator() const override;

 private:
  class Impl;  // pimpl: keeps onnxruntime session details out of this header
  std::unique_ptr<Impl> impl_;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_OFFLINE_TDNN_CTC_MODEL_H_
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/offline-tdnn-model-config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
namespace sherpa_onnx {

// Register the command-line option for the TDNN model path.
// NOTE: ParseOptions expects option names WITHOUT the leading "--"
// (it prepends the dashes itself); registering "--tdnn-model" would
// make the option unreachable from the command line.
void OfflineTdnnModelConfig::Register(ParseOptions *po) {
  po->Register("tdnn-model", &model, "Path to onnx model");
}

bool OfflineTdnnModelConfig::Validate() const {
Expand Down

0 comments on commit 056c39d

Please sign in to comment.