-
Notifications
You must be signed in to change notification settings - Fork 424
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support TDNN models from the yesno recipe from icefall
- Loading branch information
1 parent
8d2870a
commit 056c39d
Showing
8 changed files
with
214 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
// sherpa-onnx/csrc/offline-tdnn-ctc-model.cc | ||
// | ||
// Copyright (c) 2023 Xiaomi Corporation | ||
|
||
#include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h" | ||
|
||
#include "sherpa-onnx/csrc/macros.h" | ||
#include "sherpa-onnx/csrc/onnx-utils.h" | ||
#include "sherpa-onnx/csrc/session.h" | ||
#include "sherpa-onnx/csrc/text-utils.h" | ||
#include "sherpa-onnx/csrc/transpose.h" | ||
|
||
namespace sherpa_onnx { | ||
|
||
/** Private implementation (pimpl) of OfflineTdnnCtcModel.
 *
 * Owns the onnxruntime environment, session options and session, and caches
 * the session's input/output names together with the vocabulary size.
 */
class OfflineTdnnCtcModel::Impl {
 public:
  /** Create the onnxruntime session for the model named by
   * `config.tdnn.model`.
   *
   * @param config Model configuration. A copy is stored in this object.
   */
  explicit Impl(const OfflineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    Init();
  }

  /** Run the network on a batch of features.
   *
   * @param features The single input tensor fed to the session.
   *
   * @return A pair containing:
   *  - The first output of the session, returned unchanged. Its first two
   *    dimensions are interpreted as (N, T'); the output is assumed to have
   *    at least two dimensions.
   *  - A 1-D int64_t tensor of shape (N,) in which every entry equals T'.
   *    No per-utterance input length is consulted.
   */
  std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features) {
    auto nnet_out =
        sess_->Run({}, input_names_ptr_.data(), &features, 1,
                   output_names_ptr_.data(), output_names_ptr_.size());

    std::vector<int64_t> nnet_out_shape =
        nnet_out[0].GetTensorTypeAndShapeInfo().GetShape();

    const int64_t batch_size = nnet_out_shape[0];
    const int64_t num_frames = nnet_out_shape[1];

    // One length per utterance, all equal to the number of output frames.
    std::vector<int64_t> out_length_vec(batch_size, num_frames);

    // The tensor is 1-D with one entry per utterance, i.e. shape (N,).
    // Using a fixed shape of (1,) here would hide all but the first length
    // whenever the batch holds more than one utterance.
    std::vector<int64_t> out_length_shape = {batch_size};

    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    // This tensor only wraps the storage of `out_length_vec`, which dies at
    // the end of this function, so a clone is returned instead of the
    // wrapper itself.
    Ort::Value nnet_out_length = Ort::Value::CreateTensor(
        memory_info, out_length_vec.data(), out_length_vec.size(),
        out_length_shape.data(), out_length_shape.size());

    return {std::move(nnet_out[0]), Clone(Allocator(), &nnet_out_length)};
  }

  /** Return the vocabulary size obtained while initializing the model. */
  int32_t VocabSize() const { return vocab_size_; }

  /** Return the allocator owned by this object. */
  OrtAllocator *Allocator() const { return allocator_; }

 private:
  /** Create the session and cache input/output names and model metadata. */
  void Init() {
    // The session is built from an in-memory buffer holding the model file.
    auto buf = ReadFile(config_.tdnn.model);

    sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
                                           sess_opts_);

    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);

    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);

    // get meta data
    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      PrintModelMetadata(os, meta_data);
      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
    }

    // The macro below refers to a local variable named `allocator`.
    Ort::AllocatorWithDefaultOptions allocator;
    SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
  }

 private:
  // NOTE: members are initialized in declaration order; keep this order in
  // sync with the constructor's initializer list.
  OfflineModelConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> sess_;

  // The *_ptr_ vectors hold the `const char *` views handed to Run().
  // Presumably they point into the matching std::string vectors; confirm
  // against GetInputNames()/GetOutputNames().
  std::vector<std::string> input_names_;
  std::vector<const char *> input_names_ptr_;

  std::vector<std::string> output_names_;
  std::vector<const char *> output_names_ptr_;

  int32_t vocab_size_ = 0;
};
|
||
OfflineTdnnCtcModel::OfflineTdnnCtcModel(const OfflineModelConfig &config) | ||
: impl_(std::make_unique<Impl>(config)) {} | ||
|
||
OfflineTdnnCtcModel::~OfflineTdnnCtcModel() = default; | ||
|
||
// Hands `features` over to the implementation. The second argument
// (features_length) is intentionally unnamed and unused.
std::pair<Ort::Value, Ort::Value> OfflineTdnnCtcModel::Forward(
    Ort::Value features, Ort::Value /*features_length*/) {
  auto outputs = impl_->Forward(std::move(features));
  return outputs;
}
|
||
// Forwards to the implementation object.
int32_t OfflineTdnnCtcModel::VocabSize() const {
  const int32_t vocab_size = impl_->VocabSize();
  return vocab_size;
}
|
||
// Forwards to the implementation object, which owns the allocator.
OrtAllocator *OfflineTdnnCtcModel::Allocator() const {
  OrtAllocator *const allocator = impl_->Allocator();
  return allocator;
}
|
||
} // namespace sherpa_onnx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
// sherpa-onnx/csrc/offline-tdnn-ctc-model.h | ||
// | ||
// Copyright (c) 2023 Xiaomi Corporation | ||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TDNN_CTC_MODEL_H_ | ||
#define SHERPA_ONNX_CSRC_OFFLINE_TDNN_CTC_MODEL_H_ | ||
#include <memory> | ||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include "onnxruntime_cxx_api.h" // NOLINT | ||
#include "sherpa-onnx/csrc/offline-ctc-model.h" | ||
#include "sherpa-onnx/csrc/offline-model-config.h" | ||
|
||
namespace sherpa_onnx { | ||
|
||
/** This class implements the tdnn model of the yesno recipe from icefall. | ||
* | ||
* See | ||
* https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn | ||
*/ | ||
class OfflineTdnnCtcModel : public OfflineCtcModel { | ||
public: | ||
explicit OfflineTdnnCtcModel(const OfflineModelConfig &config); | ||
~OfflineTdnnCtcModel() override; | ||
|
||
/** Run the forward method of the model. | ||
* | ||
* @param features A tensor of shape (N, T, C). It is changed in-place. | ||
* @param features_length A 1-D tensor of shape (N,) containing number of | ||
* valid frames in `features` before padding. | ||
* Its dtype is int64_t. | ||
* | ||
* @return Return a pair containing: | ||
* - log_probs: A 3-D tensor of shape (N, T', vocab_size). | ||
* - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t | ||
*/ | ||
std::pair<Ort::Value, Ort::Value> Forward( | ||
Ort::Value features, Ort::Value /*features_length*/) override; | ||
|
||
/** Return the vocabulary size of the model | ||
*/ | ||
int32_t VocabSize() const override; | ||
|
||
/** Return an allocator for allocating memory | ||
*/ | ||
OrtAllocator *Allocator() const override; | ||
|
||
private: | ||
class Impl; | ||
std::unique_ptr<Impl> impl_; | ||
}; | ||
|
||
} // namespace sherpa_onnx | ||
|
||
#endif // SHERPA_ONNX_CSRC_OFFLINE_TDNN_CTC_MODEL_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters