Skip to content

Commit

Permalink
Fix computing features for CED audio tagging models.
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Sep 12, 2024
1 parent fa20ae1 commit 3395085
Showing 1 changed file with 35 additions and 1 deletion.
36 changes: 35 additions & 1 deletion sherpa-onnx/csrc/offline-stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <cassert>
#include <cmath>
#include <iomanip>
#include <limits>
#include <utility>

#include "kaldi-native-fbank/csrc/online-feature.h"
Expand Down Expand Up @@ -110,7 +111,7 @@ class OfflineStream::Impl {
config_.sampling_rate = opts_.frame_opts.samp_freq;
}

explicit Impl(CEDTag /*tag*/) {
explicit Impl(CEDTag /*tag*/) : is_ced_(true) {
// see
// https://github.com/RicherMans/CED/blob/main/onnx_inference_with_kaldi.py

Expand All @@ -123,7 +124,9 @@ class OfflineStream::Impl {

opts_.frame_opts.samp_freq = 16000; // fixed to 16000
opts_.mel_opts.num_bins = 64;
opts_.mel_opts.low_freq = 0;
opts_.mel_opts.high_freq = 8000;
opts_.use_log_fbank = false;

config_.sampling_rate = opts_.frame_opts.samp_freq;

Expand Down Expand Up @@ -216,6 +219,10 @@ class OfflineStream::Impl {

NemoNormalizeFeatures(features.data(), n, feature_dim);

if (is_ced_) {
AmplitudeToDB(features.data(), features.size());
}

return features;
}

Expand All @@ -226,6 +233,32 @@ class OfflineStream::Impl {
const ContextGraphPtr &GetContextGraph() const { return context_graph_; }

private:
// see
// https://github.com/pytorch/audio/blob/main/src/torchaudio/functional/functional.py#L359
void AmplitudeToDB(float *p, int32_t n) const {
float multiplier = 10;
float top_db = 120;
float amin = 1e-10;

float max_x = std::numeric_limits<float>::min();

for (int32_t i = 0; i != n; ++i) {
float x = p[i];
x = (x > amin) ? x : amin;
x = std::log10f(x) * multiplier;

max_x = (x > max_x) ? x : max_x;
p[i] = x;
}

float d = max_x - top_db;
for (int32_t i = 0; i != n; ++i) {
float x = p[i];
x = (x > d) ? x : d;
p[i] = x;
}
}

void NemoNormalizeFeatures(float *p, int32_t num_frames,
int32_t feature_dim) const {
if (config_.nemo_normalize_type.empty()) {
Expand Down Expand Up @@ -266,6 +299,7 @@ class OfflineStream::Impl {
knf::MfccOptions mfcc_opts_;
OfflineRecognitionResult r_;
ContextGraphPtr context_graph_;
bool is_ced_ = false;
};

OfflineStream::OfflineStream(const FeatureExtractorConfig &config /*= {}*/,
Expand Down

0 comments on commit 3395085

Please sign in to comment.