Skip to content

Commit

Permalink
surface dithering constant, 0.0 disables dithering
Browse files Browse the repository at this point in the history
- currently, dithering is not yet implemented in https://github.com/csukuangfj/kaldi-native-fbank
- i can port it there from kaldi
  • Loading branch information
KarelVesely84 committed Mar 14, 2024
1 parent 20f7aca commit 4f04fb8
Show file tree
Hide file tree
Showing 8 changed files with 42 additions and 9 deletions.
1 change: 1 addition & 0 deletions cmake/kaldi-native-fbank.cmake
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
function(download_kaldi_native_fbank)
include(FetchContent)

# TODO: update is required, so that dithering works... (it was missing)
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.7.tar.gz")
set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.7.tar.gz")
set(kaldi_native_fbank_HASH "SHA256=e78fd9d481d83d7d6d1be0012752e6531cb614e030558a3491e3c033cb8e0e4e")
Expand Down
14 changes: 11 additions & 3 deletions sherpa-onnx/csrc/features.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,14 @@ void FeatureExtractorConfig::Register(ParseOptions *po) {
"Low cutoff frequency for mel bins");

po->Register("high-freq", &high_freq,
"High cutoff frequency for mel bins (if <= 0, offset from Nyquist)");
"High cutoff frequency for mel bins "
"(if <= 0, offset from Nyquist)");

po->Register("dither", &dither,
"Dithering constant (0.0 means no dither). "
"By default the audio samples are in range [-1,+1], "
"so 0.00003 is a good value, "
"equivalent to the default 1.0 from kaldi");
}

std::string FeatureExtractorConfig::ToString() const {
Expand All @@ -40,15 +47,16 @@ std::string FeatureExtractorConfig::ToString() const {
os << "sampling_rate=" << sampling_rate << ", ";
os << "feature_dim=" << feature_dim << ", ";
os << "low_freq=" << low_freq << ", ";
os << "high_freq=" << high_freq << ")";
os << "high_freq=" << high_freq << ", ";
os << "dither=" << dither << ")";

return os.str();
}

class FeatureExtractor::Impl {
public:
explicit Impl(const FeatureExtractorConfig &config) : config_(config) {
opts_.frame_opts.dither = 0;
opts_.frame_opts.dither = config.dither;
opts_.frame_opts.snip_edges = config.snip_edges;
opts_.frame_opts.samp_freq = config.sampling_rate;
opts_.frame_opts.frame_shift_ms = config.frame_shift_ms;
Expand Down
8 changes: 8 additions & 0 deletions sherpa-onnx/csrc/features.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ struct FeatureExtractorConfig {
// https://github.com/k2-fsa/sherpa-onnx/issues/514
float high_freq = -400.0f;

// dithering constant, useful for signals with hard-zeroes in non-speech parts
// this prevents large negative values in log-mel filterbanks
//
// In k2, audio samples are in range [-1..+1], in kaldi the range was [-32k..+32k],
// so the value 0.00003 is equivalent to kaldi default 1.0
//
float dither = 0.0f; // dithering disabled by default

// Set internally by some models, e.g., paraformer sets it to false.
// This parameter is not exposed to users from the commandline
// If true, the feature extractor expects inputs to be normalized to
Expand Down
4 changes: 4 additions & 0 deletions sherpa-onnx/csrc/keyword-spotter-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
unk_id_ = sym_["<unk>"];
}

model_->SetFeatureDim(config.feat_config.feature_dim);

InitKeywords();

decoder_ = std::make_unique<TransducerKeywordDecoder>(
Expand All @@ -89,6 +91,8 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
unk_id_ = sym_["<unk>"];
}

model_->SetFeatureDim(config.feat_config.feature_dim);

InitKeywords(mgr);

decoder_ = std::make_unique<TransducerKeywordDecoder>(
Expand Down
7 changes: 4 additions & 3 deletions sherpa-onnx/csrc/online-recognizer-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,12 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
model_(OnlineTransducerModel::Create(config.model_config)),
sym_(config.model_config.tokens),
endpoint_(config_.endpoint_config) {

model_->SetFeatureDim(config.feat_config.feature_dim);

if (sym_.contains("<unk>")) {
unk_id_ = sym_["<unk>"];
}

model_->SetFeatureDim(config.feat_config.feature_dim);

if (config.decoding_method == "modified_beam_search") {
if (!config_.hotwords_file.empty()) {
InitHotwords();
Expand Down Expand Up @@ -126,6 +125,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
unk_id_ = sym_["<unk>"];
}

model_->SetFeatureDim(config.feat_config.feature_dim);

if (config.decoding_method == "modified_beam_search") {
#if 0
// TODO(fangjun): Implement it
Expand Down
4 changes: 3 additions & 1 deletion sherpa-onnx/csrc/online-zipformer2-transducer-model.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel {

std::vector<Ort::Value> GetEncoderInitStates() override;

void SetFeatureDim(int32_t feature_dim) override { feature_dim_ = feature_dim; }
void SetFeatureDim(int32_t feature_dim) override {
feature_dim_ = feature_dim;
}

std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
Ort::Value features, std::vector<Ort::Value> states,
Expand Down
6 changes: 4 additions & 2 deletions sherpa-onnx/python/csrc/features.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,17 @@ namespace sherpa_onnx {
static void PybindFeatureExtractorConfig(py::module *m) {
using PyClass = FeatureExtractorConfig;
py::class_<PyClass>(*m, "FeatureExtractorConfig")
.def(py::init<int32_t, int32_t, float, float>(),
.def(py::init<int32_t, int32_t, float, float, float>(),
py::arg("sampling_rate") = 16000,
py::arg("feature_dim") = 80,
py::arg("low_freq") = 20.0f,
py::arg("high_freq") = -400.0f)
py::arg("high_freq") = -400.0f,
py::arg("dither") = 0.0f)
.def_readwrite("sampling_rate", &PyClass::sampling_rate)
.def_readwrite("feature_dim", &PyClass::feature_dim)
.def_readwrite("low_freq", &PyClass::low_freq)
.def_readwrite("high_freq", &PyClass::high_freq)
.def_readwrite("dither", &PyClass::high_freq)
.def("__str__", &PyClass::ToString);
}

Expand Down
7 changes: 7 additions & 0 deletions sherpa-onnx/python/sherpa_onnx/online_recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def from_transducer(
feature_dim: int = 80,
low_freq: float = 20.0,
high_freq: float = -400.0,
dither: float = 0.0,
enable_endpoint_detection: bool = False,
rule1_min_trailing_silence: float = 2.4,
rule2_min_trailing_silence: float = 1.2,
Expand Down Expand Up @@ -87,6 +88,11 @@ def from_transducer(
high_freq:
High cutoff frequency for mel bins in feature extraction
(if <= 0, offset from Nyquist)
dither:
Dithering constant (0.0 means no dither).
By default the audio samples are in range [-1,+1],
so dithering constant 0.00003 is a good value,
equivalent to the default 1.0 from kaldi
enable_endpoint_detection:
True to enable endpoint detection. False to disable endpoint
detection.
Expand Down Expand Up @@ -149,6 +155,7 @@ def from_transducer(
feature_dim=feature_dim,
low_freq=low_freq,
high_freq=high_freq,
dither=dither,
)

endpoint_config = EndpointConfig(
Expand Down

0 comments on commit 4f04fb8

Please sign in to comment.