diff --git a/CMakeLists.txt b/CMakeLists.txt index 358a772..f027de5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,7 +9,7 @@ if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0") cmake_policy(SET CMP0135 NEW) endif() -set(KALDI_DECODER_VERSION "0.2.2") +set(KALDI_DECODER_VERSION "0.2.3") set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") diff --git a/kaldi-decoder/csrc/decodable-ctc.cc b/kaldi-decoder/csrc/decodable-ctc.cc index c7ba25a..d09009d 100644 --- a/kaldi-decoder/csrc/decodable-ctc.cc +++ b/kaldi-decoder/csrc/decodable-ctc.cc @@ -8,7 +8,15 @@ namespace kaldi_decoder { -DecodableCtc::DecodableCtc(const FloatMatrix &feats) : feature_matrix_(feats) {} +DecodableCtc::DecodableCtc(const FloatMatrix &log_probs) + : log_probs_(log_probs) { + p_ = &log_probs_(0, 0); + num_rows_ = log_probs_.rows(); + num_cols_ = log_probs_.cols(); +} + +DecodableCtc::DecodableCtc(const float *p, int32_t num_rows, int32_t num_cols) + : p_(p), num_rows_(num_rows), num_cols_(num_cols) {} float DecodableCtc::LogLikelihood(int32_t frame, int32_t index) { // Note: We need to use index - 1 here since @@ -16,12 +24,12 @@ float DecodableCtc::LogLikelihood(int32_t frame, int32_t index) { // construction assert(index >= 1); - return feature_matrix_(frame, index - 1); + return *(p_ + frame * num_cols_ + index - 1); } -int32_t DecodableCtc::NumFramesReady() const { return feature_matrix_.rows(); } +int32_t DecodableCtc::NumFramesReady() const { return num_rows_; } -int32_t DecodableCtc::NumIndices() const { return feature_matrix_.cols(); } +int32_t DecodableCtc::NumIndices() const { return num_cols_; } bool DecodableCtc::IsLastFrame(int32_t frame) const { assert(frame < NumFramesReady()); diff --git a/kaldi-decoder/csrc/decodable-ctc.h b/kaldi-decoder/csrc/decodable-ctc.h index 9cf044d..7b9a419 100644 --- a/kaldi-decoder/csrc/decodable-ctc.h +++ b/kaldi-decoder/csrc/decodable-ctc.h @@ -12,7 +12,15 @@ namespace kaldi_decoder { class DecodableCtc : public DecodableInterface { public: - explicit DecodableCtc(const FloatMatrix &feats); + // It copies the input log_probs + explicit DecodableCtc(const FloatMatrix &log_probs); + + // It shares the memory with the input array. + // + // @param p Pointer to a 2-d array of shape (num_rows, num_cols). + // The array should be kept alive as long as this object is still + // alive. + DecodableCtc(const float *p, int32_t num_rows, int32_t num_cols); float LogLikelihood(int32_t frame, int32_t index) override; @@ -25,7 +33,11 @@ class DecodableCtc : public DecodableInterface { private: // it saves log_softmax output - FloatMatrix feature_matrix_; + FloatMatrix log_probs_; + + const float *p_ = nullptr; // pointer to a 2-d array + int32_t num_rows_; // number of rows in the 2-d array + int32_t num_cols_; // number of cols in the 2-d array }; } // namespace kaldi_decoder