Skip to content

Commit

Permalink
fix style issues
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Aug 31, 2023
1 parent 5662310 commit 51d9b8c
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 16 deletions.
9 changes: 3 additions & 6 deletions sherpa-ncnn/csrc/modified-beam-search-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out,
result->num_trailing_blanks = hyp.num_trailing_blanks;
}


void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s,
DecoderResult *result) {
int32_t context_size = model_->ContextSize();
Expand All @@ -205,7 +204,6 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s,
std::vector<Hypothesis> prev = cur.GetTopK(num_active_paths_, true);
cur.Clear();


ncnn::Mat decoder_input = BuildDecoderInput(prev);
ncnn::Mat decoder_out;
if (t == 0 && prev.size() == 1 && prev[0].ys.size() == context_size &&
Expand All @@ -218,14 +216,13 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s,

// decoder_out.w == decoder_dim
// decoder_out.h == num_active_paths
ncnn::Mat encoder_out_t(encoder_out.w, 1, encoder_out.row(t));
ncnn::Mat encoder_out_t(encoder_out.w, 1, encoder_out.row(t));

ncnn::Mat joiner_out = model_->RunJoiner(encoder_out_t, decoder_out);
// joiner_out.w == vocab_size
// joiner_out.h == num_active_paths
LogSoftmax(&joiner_out);


float *p_joiner_out = joiner_out;

for (int32_t i = 0; i != joiner_out.h; ++i) {
Expand Down Expand Up @@ -255,8 +252,8 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s,
new_hyp.num_trailing_blanks = 0;
new_hyp.timestamps.push_back(t + frame_offset);
if (s != nullptr && s->GetContextGraph() != nullptr) {
auto context_res = s->GetContextGraph()->ForwardOneStep(
context_state, new_token);
auto context_res =
s->GetContextGraph()->ForwardOneStep(context_state, new_token);
context_score = context_res.first;
new_hyp.context_state = context_res.second;
}
Expand Down
13 changes: 8 additions & 5 deletions sherpa-ncnn/csrc/sherpa-ncnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
*/

#include <stdio.h>
#include <fstream>

#include <algorithm>
#include <chrono> // NOLINT
#include <fstream>
#include <iostream>

#include "net.h" // NOLINT
Expand Down Expand Up @@ -72,17 +73,19 @@ for a list of pre-trained models to download.
config.decoder_config.method = method;
}
}
std::cout<<"decode method:"<<config.decoder_config.method<<std::endl;
if(argc >= 12) {
config.hotwords_file = argv[11];

if (argc >= 12) {
config.hotwords_file = argv[11];
} else {
config.hotwords_file = "";
}
if(argc == 13) {

if (argc == 13) {
config.hotwords_score = atof(argv[12]);
} else {
config.hotwords_file = 1.5;
}

config.feat_config.sampling_rate = expected_sampling_rate;
config.feat_config.feature_dim = 80;

Expand Down
8 changes: 5 additions & 3 deletions sherpa-ncnn/csrc/stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ namespace sherpa_ncnn {

class Stream::Impl {
public:
explicit Impl(const FeatureExtractorConfig &config,ContextGraphPtr context_graph)
explicit Impl(const FeatureExtractorConfig &config,
ContextGraphPtr context_graph)
: feat_extractor_(config), context_graph_(context_graph) {}

void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
Expand Down Expand Up @@ -73,7 +74,8 @@ class Stream::Impl {
std::vector<ncnn::Mat> states_;
};

Stream::Stream(const FeatureExtractorConfig &config, ContextGraphPtr context_graph)
Stream::Stream(const FeatureExtractorConfig &config,
ContextGraphPtr context_graph)
: impl_(std::make_unique<Impl>(config, context_graph)) {}

Stream::~Stream() = default;
Expand Down Expand Up @@ -113,5 +115,5 @@ std::vector<ncnn::Mat> &Stream::GetStates() { return impl_->GetStates(); }

const ContextGraphPtr &Stream::GetContextGraph() const {
return impl_->GetContextGraph();
}
}
} // namespace sherpa_ncnn
4 changes: 2 additions & 2 deletions sherpa-ncnn/csrc/stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@
#include <memory>
#include <vector>

#include "sherpa-ncnn/csrc/context-graph.h"
#include "sherpa-ncnn/csrc/decoder.h"
#include "sherpa-ncnn/csrc/features.h"
#include "sherpa-ncnn/csrc/context-graph.h"

namespace sherpa_ncnn {
class Stream {
public:
explicit Stream(const FeatureExtractorConfig &config = {},
ContextGraphPtr context_graph = nullptr);
ContextGraphPtr context_graph = nullptr);
~Stream();

/**
Expand Down

0 comments on commit 51d9b8c

Please sign in to comment.