Skip to content

Commit

Permalink
[runtime] Whisper inference support in cpp runtime (#2320)
Browse files Browse the repository at this point in the history
* Working version

* Refactor how params are passed in

* Passing in through CLI

* Fix minnor issues

* Remove extra dump for debug

* Remove unused arrays for debugging

* Remove unused header

* Change naming style of Enum

* Move init_mel_filters to it's own method

* Fix one a bug introduced in the last two commit

* Use const instead of macro

---------

Co-authored-by: hzhou245 <[email protected]>
  • Loading branch information
zhr1201 and hzhou245 authored Jan 25, 2024
1 parent 6e68e01 commit baaa27a
Show file tree
Hide file tree
Showing 4 changed files with 232 additions and 38 deletions.
14 changes: 13 additions & 1 deletion runtime/core/decoder/params.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ DEFINE_int32(core_number, 1, "Core number of process");
// FeaturePipelineConfig flags
DEFINE_int32(num_bins, 80, "num mel bins for fbank feature");
DEFINE_int32(sample_rate, 16000, "sample rate for audio");
DEFINE_string(feat_type, "kaldi", "Type of feature extraction: kaldi, whisper");

// TLG fst
DEFINE_string(fst_path, "", "TLG fst path");
Expand Down Expand Up @@ -115,9 +116,20 @@ DEFINE_int32(language_type, 0,
DEFINE_bool(lowercase, true, "lowercase final result if needed");

namespace wenet {

FeatureType StringToFeatureType(const std::string& feat_type_str) {
if (feat_type_str == "kaldi")
return FeatureType::kKaldi;
else if (feat_type_str == "whisper")
return FeatureType::kWhisper;
else
throw std::invalid_argument("Unsupported feat type!");
}

std::shared_ptr<FeaturePipelineConfig> InitFeaturePipelineConfigFromFlags() {
FeatureType feat_type = StringToFeatureType(FLAGS_feat_type);
auto feature_config = std::make_shared<FeaturePipelineConfig>(
FLAGS_num_bins, FLAGS_sample_rate);
FLAGS_num_bins, FLAGS_sample_rate, feat_type);
return feature_config;
}

Expand Down
209 changes: 175 additions & 34 deletions runtime/core/frontend/fbank.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,39 @@ namespace wenet {

// This code is based on kaldi Fbank implementation, please see
// https://github.com/kaldi-asr/kaldi/blob/master/src/feat/feature-fbank.cc

static const int kS16AbsMax = 1 << 15;

enum class WindowType {
kPovey = 0,
kHanning,
};

enum class MelType {
kHTK = 0,
kSlaney,
};

enum class NormalizationType {
kKaldi = 0,
kWhisper,
};

enum class LogBase {
kBaseE = 0,
kBase10,
};

class Fbank {
public:
Fbank(int num_bins, int sample_rate, int frame_length, int frame_shift)
Fbank(int num_bins, int sample_rate, int frame_length, int frame_shift,
float low_freq = 20, bool pre_emphasis = true,
bool scale_input_to_unit = false,
float log_floor = std::numeric_limits<float>::epsilon(),
LogBase log_base = LogBase::kBaseE,
WindowType window_type = WindowType::kPovey,
MelType mel_type = MelType::kHTK,
NormalizationType norm_type = NormalizationType::kKaldi)
: num_bins_(num_bins),
sample_rate_(sample_rate),
frame_length_(frame_length),
Expand All @@ -39,40 +69,69 @@ class Fbank {
remove_dc_offset_(true),
generator_(0),
distribution_(0, 1.0),
dither_(0.0) {
dither_(0.0),
low_freq_(low_freq),
high_freq_(sample_rate / 2),
pre_emphasis_(pre_emphasis),
scale_input_to_unit_(scale_input_to_unit),
log_floor_(log_floor),
log_base_(log_base),
norm_type_(norm_type) {
fft_points_ = UpperPowerOfTwo(frame_length_);
// generate bit reversal table and trigonometric function table
const int fft_points_4 = fft_points_ / 4;
bitrev_.resize(fft_points_);
sintbl_.resize(fft_points_ + fft_points_4);
make_sintbl(fft_points_, sintbl_.data());
make_bitrev(fft_points_, bitrev_.data());
InitMelFilters(mel_type);
InitWindow(window_type);
}

void InitMelFilters(MelType mel_type) {
int num_fft_bins = fft_points_ / 2;
float fft_bin_width = static_cast<float>(sample_rate_) / fft_points_;
int low_freq = 20, high_freq = sample_rate_ / 2;
float mel_low_freq = MelScale(low_freq);
float mel_high_freq = MelScale(high_freq);
float mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1);
float mel_low_freq = MelScale(low_freq_, mel_type);
float mel_high_freq = MelScale(high_freq_, mel_type);
float mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins_ + 1);
bins_.resize(num_bins_);
center_freqs_.resize(num_bins_);
for (int bin = 0; bin < num_bins; ++bin) {

for (int bin = 0; bin < num_bins_; ++bin) {
float left_mel = mel_low_freq + bin * mel_freq_delta,
center_mel = mel_low_freq + (bin + 1) * mel_freq_delta,
right_mel = mel_low_freq + (bin + 2) * mel_freq_delta;
center_freqs_[bin] = InverseMelScale(center_mel);
center_freqs_[bin] = InverseMelScale(center_mel, mel_type);
std::vector<float> this_bin(num_fft_bins);
int first_index = -1, last_index = -1;
for (int i = 0; i < num_fft_bins; ++i) {
float freq = (fft_bin_width * i); // Center frequency of this fft
// bin.
float mel = MelScale(freq);
float mel = MelScale(freq, mel_type);
if (mel > left_mel && mel < right_mel) {
float weight;
if (mel <= center_mel)
weight = (mel - left_mel) / (center_mel - left_mel);
else
weight = (right_mel - mel) / (right_mel - center_mel);
if (mel_type == MelType::kHTK) {
if (mel <= center_mel)
weight = (mel - left_mel) / (center_mel - left_mel);
else if (mel > center_mel)
weight = (right_mel - mel) / (right_mel - center_mel);
} else if (mel_type == MelType::kSlaney) {
if (mel <= center_mel) {
weight = (InverseMelScale(mel, mel_type) -
InverseMelScale(left_mel, mel_type)) /
(InverseMelScale(center_mel, mel_type) -
InverseMelScale(left_mel, mel_type));
weight *= 2.0 / (InverseMelScale(right_mel, mel_type) -
InverseMelScale(left_mel, mel_type));
} else if (mel > center_mel) {
weight = (InverseMelScale(right_mel, mel_type) -
InverseMelScale(mel, mel_type)) /
(InverseMelScale(right_mel, mel_type) -
InverseMelScale(center_mel, mel_type));
weight *= 2.0 / (InverseMelScale(right_mel, mel_type) -
InverseMelScale(left_mel, mel_type));
}
}
this_bin[i] = weight;
if (first_index == -1) first_index = i;
last_index = i;
Expand All @@ -86,12 +145,20 @@ class Fbank {
bins_[bin].second[i] = this_bin[first_index + i];
}
}
}

// povey window
povey_window_.resize(frame_length_);
double a = M_2PI / (frame_length - 1);
for (int i = 0; i < frame_length; ++i) {
povey_window_[i] = pow(0.5 - 0.5 * cos(a * i), 0.85);
void InitWindow(WindowType window_type) {
window_.resize(frame_length_);
if (window_type == WindowType::kPovey) {
// povey window
double a = M_2PI / (frame_length_ - 1);
for (int i = 0; i < frame_length_; ++i)
window_[i] = pow(0.5 - 0.5 * cos(a * i), 0.85);
} else if (window_type == WindowType::kHanning) {
// periodic hanning window
double a = M_2PI / (frame_length_);
for (int i = 0; i < frame_length_; ++i)
window_[i] = 0.5 * (1.0 - cos(i * a));
}
}

Expand All @@ -105,12 +172,45 @@ class Fbank {

int num_bins() const { return num_bins_; }

static inline float InverseMelScale(float mel_freq) {
return 700.0f * (expf(mel_freq / 1127.0f) - 1.0f);
static inline float InverseMelScale(float mel_freq,
MelType mel_type = MelType::kHTK) {
if (mel_type == MelType::kHTK) {
return 700.0f * (expf(mel_freq / 1127.0f) - 1.0f);
} else if (mel_type == MelType::kSlaney) {
float f_min = 0.0;
float f_sp = 200.0f / 3.0f;
float min_log_hz = 1000.0;
float freq = f_min + f_sp * mel_freq;
float min_log_mel = (min_log_hz - f_min) / f_sp;
float logstep = logf(6.4) / 27.0f;
if (mel_freq >= min_log_mel) {
return min_log_hz * expf(logstep * (mel_freq - min_log_mel));
} else {
return freq;
}
} else {
throw std::invalid_argument("Unsupported mel type!");
}
}

static inline float MelScale(float freq) {
return 1127.0f * logf(1.0f + freq / 700.0f);
static inline float MelScale(float freq, MelType mel_type = MelType::kHTK) {
if (mel_type == MelType::kHTK) {
return 1127.0f * logf(1.0f + freq / 700.0f);
} else if (mel_type == MelType::kSlaney) {
float f_min = 0.0;
float f_sp = 200.0f / 3.0f;
float min_log_hz = 1000.0;
float mel = (freq - f_min) / f_sp;
float min_log_mel = (min_log_hz - f_min) / f_sp;
float logstep = logf(6.4) / 27.0f;
if (freq >= min_log_hz) {
return min_log_mel + logf(freq / min_log_hz) / logstep;
} else {
return mel;
}
} else {
throw std::invalid_argument("Unsupported mel type!");
}
}

static int UpperPowerOfTwo(int n) {
Expand All @@ -125,26 +225,50 @@ class Fbank {
(*data)[0] -= coeff * (*data)[0];
}

// Apply povey window on data in place
void Povey(std::vector<float>* data) const {
CHECK_GE(data->size(), povey_window_.size());
for (size_t i = 0; i < povey_window_.size(); ++i) {
(*data)[i] *= povey_window_[i];
// Apply window on data in place
void ApplyWindow(std::vector<float>* data) const {
CHECK_GE(data->size(), window_.size());
for (size_t i = 0; i < window_.size(); ++i) {
(*data)[i] *= window_[i];
}
}

void WhisperNorm(std::vector<std::vector<float>>* feat,
float max_mel_engery) {
int num_frames = feat->size();
for (int i = 0; i < num_frames; ++i) {
for (int j = 0; j < num_bins_; ++j) {
float energy = (*feat)[i][j];
if (energy < max_mel_engery - 8) energy = max_mel_engery - 8;
energy = (energy + 4.0) / 4.0;
(*feat)[i][j] = energy;
}
}
}

// Compute fbank feat, return num frames
int Compute(const std::vector<float>& wave,
std::vector<std::vector<float>>* feat) {
int num_samples = wave.size();

if (num_samples < frame_length_) return 0;
int num_frames = 1 + ((num_samples - frame_length_) / frame_shift_);
feat->resize(num_frames);
std::vector<float> fft_real(fft_points_, 0), fft_img(fft_points_, 0);
std::vector<float> power(fft_points_ / 2);

float max_mel_engery = std::numeric_limits<float>::min();

for (int i = 0; i < num_frames; ++i) {
std::vector<float> data(wave.data() + i * frame_shift_,
wave.data() + i * frame_shift_ + frame_length_);

if (scale_input_to_unit_) {
for (int j = 0; j < frame_length_; ++j) {
data[j] = data[j] / kS16AbsMax;
}
}

// optional add noise
if (dither_ != 0.0) {
for (size_t j = 0; j < data.size(); ++j)
Expand All @@ -158,8 +282,10 @@ class Fbank {
for (size_t j = 0; j < data.size(); ++j) data[j] -= mean;
}

PreEmphasis(0.97, &data);
Povey(&data);
if (pre_emphasis_) {
PreEmphasis(0.97, &data);
}
ApplyWindow(&data);
// copy data to fft_real
memset(fft_img.data(), 0, sizeof(float) * fft_points_);
memset(fft_real.data() + frame_length_, 0,
Expand All @@ -174,6 +300,7 @@ class Fbank {

(*feat)[i].resize(num_bins_);
// cepstral coefficients, triangle filter array

for (int j = 0; j < num_bins_; ++j) {
float mel_energy = 0.0;
int s = bins_[j].first;
Expand All @@ -182,14 +309,20 @@ class Fbank {
}
// optional use log
if (use_log_) {
if (mel_energy < std::numeric_limits<float>::epsilon())
mel_energy = std::numeric_limits<float>::epsilon();
mel_energy = logf(mel_energy);
}
if (mel_energy < log_floor_) mel_energy = log_floor_;

if (log_base_ == LogBase::kBaseE)
mel_energy = logf(mel_energy);
else if (log_base_ == LogBase::kBase10)
mel_energy = log10(mel_energy);
}
if (max_mel_engery < mel_energy) max_mel_engery = mel_energy;
(*feat)[i][j] = mel_energy;
}
}
if (norm_type_ == NormalizationType::kWhisper)
WhisperNorm(feat, max_mel_engery);

return num_frames;
}

Expand All @@ -200,9 +333,17 @@ class Fbank {
int fft_points_;
bool use_log_;
bool remove_dc_offset_;
bool pre_emphasis_;
bool scale_input_to_unit_;
float low_freq_;
float log_floor_;
float high_freq_;
LogBase log_base_;
NormalizationType norm_type_;

std::vector<float> center_freqs_;
std::vector<std::pair<int, std::vector<float>>> bins_;
std::vector<float> povey_window_;
std::vector<float> window_;
std::default_random_engine generator_;
std::normal_distribution<float> distribution_;
float dither_;
Expand Down
4 changes: 3 additions & 1 deletion runtime/core/frontend/feature_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ FeaturePipeline::FeaturePipeline(const FeaturePipelineConfig& config)
: config_(config),
feature_dim_(config.num_bins),
fbank_(config.num_bins, config.sample_rate, config.frame_length,
config.frame_shift),
config.frame_shift, config.low_freq, config.pre_emphasis,
config.scale_input_to_unit, config.log_floor, config.log_base,
config.window_type, config.mel_type, config.norm_type),
num_frames_(0),
input_finished_(false) {}

Expand Down
Loading

0 comments on commit baaa27a

Please sign in to comment.