diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt
index 52eadcc60..cb4953c56 100644
--- a/sherpa-onnx/csrc/CMakeLists.txt
+++ b/sherpa-onnx/csrc/CMakeLists.txt
@@ -32,6 +32,7 @@ set(sources
   offline-recognizer.cc
   offline-rnn-lm.cc
   offline-stream.cc
+  offline-tdnn-ctc-model.cc
   offline-tdnn-model-config.cc
   offline-transducer-greedy-search-decoder.cc
   offline-transducer-model-config.cc
diff --git a/sherpa-onnx/csrc/offline-ctc-model.cc b/sherpa-onnx/csrc/offline-ctc-model.cc
index 1d19253da..a7b0565c5 100644
--- a/sherpa-onnx/csrc/offline-ctc-model.cc
+++ b/sherpa-onnx/csrc/offline-ctc-model.cc
@@ -11,12 +11,14 @@
 
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
+#include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
 
 namespace {
 
 enum class ModelType {
   kEncDecCTCModelBPE,
+  kTdnn,
   kUnkown,
 };
 
@@ -55,6 +57,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
 
   if (model_type.get() == std::string("EncDecCTCModelBPE")) {
     return ModelType::kEncDecCTCModelBPE;
+  } else if (model_type.get() == std::string("tdnn_lstm")) {
+    return ModelType::kTdnn;
   } else {
     SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
     return ModelType::kUnkown;
@@ -65,8 +69,18 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
     const OfflineModelConfig &config) {
   ModelType model_type = ModelType::kUnkown;
 
+  std::string filename;
+  if (!config.nemo_ctc.model.empty()) {
+    filename = config.nemo_ctc.model;
+  } else if (!config.tdnn.model.empty()) {
+    filename = config.tdnn.model;
+  } else {
+    SHERPA_ONNX_LOGE("Please specify a CTC model");
+    exit(-1);
+  }
+
   {
-    auto buffer = ReadFile(config.nemo_ctc.model);
+    auto buffer = ReadFile(filename);
 
     model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
   }
@@ -75,6 +89,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
     case ModelType::kEncDecCTCModelBPE:
       return std::make_unique<OfflineNemoEncDecCtcModel>(config);
       break;
+    case ModelType::kTdnn:
+      return std::make_unique<OfflineTdnnCtcModel>(config);
+      break;
     case ModelType::kUnkown:
       SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
       return nullptr;
diff --git a/sherpa-onnx/csrc/offline-ctc-model.h b/sherpa-onnx/csrc/offline-ctc-model.h
index 8be7f99b5..8ef43d554 100644
--- a/sherpa-onnx/csrc/offline-ctc-model.h
+++ b/sherpa-onnx/csrc/offline-ctc-model.h
@@ -39,10 +39,10 @@ class OfflineCtcModel {
 
   /** SubsamplingFactor of the model
    *
-   * For Citrinet, the subsampling factor is usually 4.
-   * For Conformer CTC, the subsampling factor is usually 8.
+   * For NeMo Citrinet, the subsampling factor is usually 4.
+   * For NeMo Conformer CTC, the subsampling factor is usually 8.
    */
-  virtual int32_t SubsamplingFactor() const = 0;
+  virtual int32_t SubsamplingFactor() const { return 1; }
 
   /** Return an allocator for allocating memory
    */
diff --git a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h
index 80c946c7d..d99d83163 100644
--- a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h
+++ b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h
@@ -27,6 +27,10 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
   std::string text;
   for (int32_t i = 0; i != src.tokens.size(); ++i) {
+    if (sym_table.contains("SIL") && src.tokens[i] == sym_table["SIL"]) {
+      // TDNN-LSTM from yesno has a SIL token, we should remove it.
+      continue;
+    }
     auto sym = sym_table[src.tokens[i]];
     text.append(sym);
 
     r.tokens.push_back(std::move(sym));
@@ -46,14 +50,21 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
         model_->FeatureNormalizationMethod();
 
     if (config.decoding_method == "greedy_search") {
-      if (!symbol_table_.contains("<blk>")) {
+      if (!symbol_table_.contains("<blk>") &&
+          !symbol_table_.contains("<eps>")) {
         SHERPA_ONNX_LOGE(
             "We expect that tokens.txt contains "
-            "the symbol <blk> and its ID.");
+            "the symbol <blk> or <eps> and its ID.");
         exit(-1);
       }
 
-      int32_t blank_id = symbol_table_["<blk>"];
+      int32_t blank_id = 0;
+      if (symbol_table_.contains("<blk>")) {
+        blank_id = symbol_table_["<blk>"];
+      } else if (symbol_table_.contains("<eps>")) {
+        blank_id = symbol_table_["<eps>"];
+      }
+
       decoder_ = std::make_unique<OfflineCtcGreedySearchDecoder>(blank_id);
     } else {
       SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
diff --git a/sherpa-onnx/csrc/offline-recognizer-impl.cc b/sherpa-onnx/csrc/offline-recognizer-impl.cc
index 5058a8ce2..d73c47198 100644
--- a/sherpa-onnx/csrc/offline-recognizer-impl.cc
+++ b/sherpa-onnx/csrc/offline-recognizer-impl.cc
@@ -27,6 +27,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
       return std::make_unique<OfflineRecognizerParaformerImpl>(config);
     } else if (model_type == "nemo_ctc") {
       return std::make_unique<OfflineRecognizerCtcImpl>(config);
+    } else if (model_type == "tdnn_lstm") {
+      return std::make_unique<OfflineRecognizerCtcImpl>(config);
    } else if (model_type == "whisper") {
      return std::make_unique<OfflineRecognizerWhisperImpl>(config);
    } else {
@@ -46,6 +48,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
     model_filename = config.model_config.paraformer.model;
   } else if (!config.model_config.nemo_ctc.model.empty()) {
     model_filename = config.model_config.nemo_ctc.model;
+  } else if (!config.model_config.tdnn.model.empty()) {
+    model_filename = config.model_config.tdnn.model;
   } else if (!config.model_config.whisper.encoder.empty()) {
     model_filename = config.model_config.whisper.encoder;
   } else {
@@ -84,6 +88,11 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
       "paraformer-onnxruntime-python-example/blob/main/add-model-metadata.py"
       "\n "
       "(3) Whisper"
+      "\n "
+      "(4) Tdnn models"
+      "\n "
+      "https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn"
+      "\n"
       "\n");
   exit(-1);
 }
@@ -102,6 +111,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
     return std::make_unique<OfflineRecognizerCtcImpl>(config);
   }
 
+  if (model_type == "tdnn_lstm") {
+    return std::make_unique<OfflineRecognizerCtcImpl>(config);
+  }
+
   if (strncmp(model_type.c_str(), "whisper", 7) == 0) {
     return std::make_unique<OfflineRecognizerWhisperImpl>(config);
   }
@@ -112,7 +125,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
       " - Non-streaming transducer models from icefall\n"
       " - Non-streaming Paraformer models from FunASR\n"
      " - EncDecCTCModelBPE models from NeMo\n"
-      " - Whisper models\n",
+      " - Whisper models\n"
+      " - Tdnn models\n",
       model_type.c_str());
 
   exit(-1);
diff --git a/sherpa-onnx/csrc/offline-tdnn-ctc-model.cc b/sherpa-onnx/csrc/offline-tdnn-ctc-model.cc
new file mode 100644
index 000000000..15695e2e1
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-tdnn-ctc-model.cc
@@ -0,0 +1,106 @@
+// sherpa-onnx/csrc/offline-tdnn-ctc-model.cc
+//
+// Copyright (c)  2023  Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h"
+
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
+#include "sherpa-onnx/csrc/text-utils.h"
+#include "sherpa-onnx/csrc/transpose.h"
+
+namespace sherpa_onnx {
+
+class OfflineTdnnCtcModel::Impl {
+ public:
+  explicit Impl(const OfflineModelConfig &config)
+      : config_(config),
+        env_(ORT_LOGGING_LEVEL_ERROR),
+        sess_opts_(GetSessionOptions(config)),
+        allocator_{} {
+    Init();
+  }
+
+  std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features) {
+    auto nnet_out =
+        sess_->Run({}, input_names_ptr_.data(), &features, 1,
+                   output_names_ptr_.data(), output_names_ptr_.size());
+
+    std::vector<int64_t> nnet_out_shape =
+        nnet_out[0].GetTensorTypeAndShapeInfo().GetShape();
+
+    std::vector<int64_t> out_length_vec(nnet_out_shape[0], nnet_out_shape[1]);
+    std::vector<int64_t> out_length_shape(1, 1);
+
+    auto memory_info =
+        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+    Ort::Value nnet_out_length = Ort::Value::CreateTensor(
+        memory_info, out_length_vec.data(), out_length_vec.size(),
+        out_length_shape.data(), out_length_shape.size());
+
+    return {std::move(nnet_out[0]), Clone(Allocator(), &nnet_out_length)};
+  }
+
+  int32_t VocabSize() const { return vocab_size_; }
+
+  OrtAllocator *Allocator() const { return allocator_; }
+
+ private:
+  void Init() {
+    auto buf = ReadFile(config_.tdnn.model);
+
+    sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
+                                           sess_opts_);
+
+    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
+
+    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
+
+    // get meta data
+    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
+    if (config_.debug) {
+      std::ostringstream os;
+      PrintModelMetadata(os, meta_data);
+      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
+    }
+
+    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
+    SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
+  }
+
+ private:
+  OfflineModelConfig config_;
+  Ort::Env env_;
+  Ort::SessionOptions sess_opts_;
+  Ort::AllocatorWithDefaultOptions allocator_;
+
+  std::unique_ptr<Ort::Session> sess_;
+
+  std::vector<std::string> input_names_;
+  std::vector<const char *> input_names_ptr_;
+
+  std::vector<std::string> output_names_;
+  std::vector<const char *> output_names_ptr_;
+
+  int32_t vocab_size_ = 0;
+};
+
+OfflineTdnnCtcModel::OfflineTdnnCtcModel(const OfflineModelConfig &config)
+    : impl_(std::make_unique<Impl>(config)) {}
+
+OfflineTdnnCtcModel::~OfflineTdnnCtcModel() = default;
+
+std::pair<Ort::Value, Ort::Value> OfflineTdnnCtcModel::Forward(
+    Ort::Value features, Ort::Value /*features_length*/) {
+  return impl_->Forward(std::move(features));
+}
+
+int32_t OfflineTdnnCtcModel::VocabSize() const { return impl_->VocabSize(); }
+
+OrtAllocator *OfflineTdnnCtcModel::Allocator() const {
+  return impl_->Allocator();
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/offline-tdnn-ctc-model.h b/sherpa-onnx/csrc/offline-tdnn-ctc-model.h
new file mode 100644
index 000000000..882e6e577
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-tdnn-ctc-model.h
@@ -0,0 +1,56 @@
+// sherpa-onnx/csrc/offline-tdnn-ctc-model.h
+//
+// Copyright (c)  2023  Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_TDNN_CTC_MODEL_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_TDNN_CTC_MODEL_H_
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/offline-ctc-model.h"
+#include "sherpa-onnx/csrc/offline-model-config.h"
+
+namespace sherpa_onnx {
+
+/** This class implements the tdnn model of the yesno recipe from icefall.
+ *
+ * See
+ * https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn
+ */
+class OfflineTdnnCtcModel : public OfflineCtcModel {
+ public:
+  explicit OfflineTdnnCtcModel(const OfflineModelConfig &config);
+  ~OfflineTdnnCtcModel() override;
+
+  /** Run the forward method of the model.
+   *
+   * @param features  A tensor of shape (N, T, C). It is changed in-place.
+   * @param features_length  A 1-D tensor of shape (N,) containing number of
+   *                         valid frames in `features` before padding.
+   *                         Its dtype is int64_t.
+   *
+   * @return Return a pair containing:
+   *  - log_probs:  A 3-D tensor of shape (N, T', vocab_size).
+   *  - log_probs_length  A 1-D tensor of shape (N,). Its dtype is int64_t
+   */
+  std::pair<Ort::Value, Ort::Value> Forward(
+      Ort::Value features, Ort::Value /*features_length*/) override;
+
+  /** Return the vocabulary size of the model
+   */
+  int32_t VocabSize() const override;
+
+  /** Return an allocator for allocating memory
+   */
+  OrtAllocator *Allocator() const override;
+
+ private:
+  class Impl;
+  std::unique_ptr<Impl> impl_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_TDNN_CTC_MODEL_H_
diff --git a/sherpa-onnx/csrc/offline-tdnn-model-config.cc b/sherpa-onnx/csrc/offline-tdnn-model-config.cc
index e36f8956e..be1b11cd8 100644
--- a/sherpa-onnx/csrc/offline-tdnn-model-config.cc
+++ b/sherpa-onnx/csrc/offline-tdnn-model-config.cc
@@ -10,7 +10,7 @@
 namespace sherpa_onnx {
 
 void OfflineTdnnModelConfig::Register(ParseOptions *po) {
-  po->Register("--tdnn-model", &model, "Path to onnx model");
+  po->Register("tdnn-model", &model, "Path to onnx model");
 }
 
 bool OfflineTdnnModelConfig::Validate() const {