Skip to content

Commit

Permalink
Support sharing memory with log_probs in DecodableCtc. (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Oct 7, 2023
1 parent 5d3ef5b commit 8085869
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 7 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0")
cmake_policy(SET CMP0135 NEW)
endif()

set(KALDI_DECODER_VERSION "0.2.2")
set(KALDI_DECODER_VERSION "0.2.3")

set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
Expand Down
16 changes: 12 additions & 4 deletions kaldi-decoder/csrc/decodable-ctc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,28 @@

namespace kaldi_decoder {

DecodableCtc::DecodableCtc(const FloatMatrix &feats) : feature_matrix_(feats) {}
DecodableCtc::DecodableCtc(const FloatMatrix &log_probs)
: log_probs_(log_probs) {
p_ = &log_probs_(0, 0);
num_rows_ = log_probs_.rows();
num_cols_ = log_probs_.cols();
}

DecodableCtc::DecodableCtc(const float *p, int32_t num_rows, int32_t num_cols)
: p_(p), num_rows_(num_rows), num_cols_(num_cols) {}

float DecodableCtc::LogLikelihood(int32_t frame, int32_t index) {
// Note: We need to use index - 1 here since
// all the input labels of the H are incremented during graph
// construction
assert(index >= 1);

return feature_matrix_(frame, index - 1);
return *(p_ + frame * num_cols_ + index - 1);
}

int32_t DecodableCtc::NumFramesReady() const { return feature_matrix_.rows(); }
int32_t DecodableCtc::NumFramesReady() const { return num_rows_; }

int32_t DecodableCtc::NumIndices() const { return feature_matrix_.cols(); }
int32_t DecodableCtc::NumIndices() const { return num_cols_; }

bool DecodableCtc::IsLastFrame(int32_t frame) const {
assert(frame < NumFramesReady());
Expand Down
16 changes: 14 additions & 2 deletions kaldi-decoder/csrc/decodable-ctc.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,15 @@ namespace kaldi_decoder {

class DecodableCtc : public DecodableInterface {
public:
explicit DecodableCtc(const FloatMatrix &feats);
// It copies the input log_probs
explicit DecodableCtc(const FloatMatrix &log_probs);

// It shares the memory with the input array.
//
// @param p Pointer to a 2-d array of shape (num_rows, num_cols).
// The array should be kept alive as long as this object is still
// alive.
DecodableCtc(const float *p, int32_t num_rows, int32_t num_cols);

float LogLikelihood(int32_t frame, int32_t index) override;

Expand All @@ -25,7 +33,11 @@ class DecodableCtc : public DecodableInterface {

private:
// it saves log_softmax output
FloatMatrix feature_matrix_;
FloatMatrix log_probs_;

const float *p_ = nullptr; // pointer to a 2-d array
int32_t num_rows_; // number of rows in the 2-d array
int32_t num_cols_; // number of cols in the 2-d array
};

} // namespace kaldi_decoder
Expand Down

0 comments on commit 8085869

Please sign in to comment.