Commit

Support reading multi-channel wave files with 8/16/32-bit encoded samples (#1258)
csukuangfj authored Aug 15, 2024
1 parent 62c4d4a commit ca729fa
Showing 5 changed files with 148 additions and 42 deletions.
28 changes: 21 additions & 7 deletions .github/scripts/test-offline-ctc.sh
@@ -38,14 +38,28 @@ done


# test wav reader for non-standard wav files
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/naudio.wav
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/junk-padding.wav
waves=(
naudio.wav
junk-padding.wav
int8-1-channel-zh.wav
int8-2-channel-zh.wav
int8-4-channel-zh.wav
int16-1-channel-zh.wav
int16-2-channel-zh.wav
int32-1-channel-zh.wav
int32-2-channel-zh.wav
float32-1-channel-zh.wav
float32-2-channel-zh.wav
)
for w in ${waves[@]}; do
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/$w

time $EXE \
--tokens=$repo/tokens.txt \
--sense-voice-model=$repo/model.int8.onnx \
./naudio.wav \
./junk-padding.wav
time $EXE \
--tokens=$repo/tokens.txt \
--sense-voice-model=$repo/model.int8.onnx \
$w
rm -v $w
done

rm -rf $repo

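Pieced together from the added lines above (indentation approximated), the new download-and-test loop fetches each listed wave file, decodes it once with the SenseVoice model, and then deletes it:

for w in ${waves[@]}; do
  curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/$w

  time $EXE \
    --tokens=$repo/tokens.txt \
    --sense-voice-model=$repo/model.int8.onnx \
    $w
  rm -v $w
done

The previous version downloaded only naudio.wav and junk-padding.wav and passed both to a single invocation; the extra int8/int16/int32/float32 and multi-channel files in the waves array exercise the new code paths added to wave-reader.cc below.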
17 changes: 8 additions & 9 deletions .github/workflows/linux.yaml
@@ -143,35 +143,34 @@ jobs:
name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
path: install/*

- name: Test online punctuation
- name: Test offline CTC
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-online-punctuation
export EXE=sherpa-onnx-offline
.github/scripts/test-online-punctuation.sh
.github/scripts/test-offline-ctc.sh
du -h -d1 .
- name: Test offline transducer
- name: Test online punctuation
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
export EXE=sherpa-onnx-online-punctuation
.github/scripts/test-offline-transducer.sh
.github/scripts/test-online-punctuation.sh
du -h -d1 .
- name: Test offline CTC
- name: Test offline transducer
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
.github/scripts/test-offline-ctc.sh
.github/scripts/test-offline-transducer.sh
du -h -d1 .
- name: Test online transducer
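This change only reorders the three test steps so that the offline CTC test, which now exercises the new wave-reader code paths, runs before the online punctuation and offline transducer tests; the step bodies themselves are unchanged. Reassembled from the added lines above (indentation approximated), the first of the reordered steps reads:

- name: Test offline CTC
  shell: bash
  run: |
    du -h -d1 .
    export PATH=$PWD/build/bin:$PATH
    export EXE=sherpa-onnx-offline
    .github/scripts/test-offline-ctc.sh
    du -h -d1 .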
1 change: 1 addition & 0 deletions sherpa-onnx/csrc/offline-tts-frontend.h
@@ -6,6 +6,7 @@
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_FRONTEND_H_
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/macros.h"
127 changes: 110 additions & 17 deletions sherpa-onnx/csrc/wave-reader.cc
@@ -50,6 +50,16 @@ struct WaveHeader {
};
static_assert(sizeof(WaveHeader) == 44);

/*
sox int16-1-channel-zh.wav -b 8 int8-1-channel-zh.wav
sox int16-1-channel-zh.wav -c 2 int16-2-channel-zh.wav
We use Audacity to generate int32-1-channel-zh.wav and float32-1-channel-zh.wav
because sox uses WAVE_FORMAT_EXTENSIBLE, which is not easy to support
in sherpa-onnx.
*/

// Read a wave file. If it contains more than one channel, only the first channel is used.
// Return its samples normalized to the range [-1, 1).
std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
Expand Down Expand Up @@ -114,9 +124,18 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
is.read(reinterpret_cast<char *>(&header.audio_format),
sizeof(header.audio_format));

if (header.audio_format != 1) { // 1 for PCM
if (header.audio_format != 1 && header.audio_format != 3) {
// 1 for integer PCM
// 3 for floating point PCM
// see https://www.mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html
// and https://github.com/microsoft/DirectXTK/wiki/Wave-Formats
SHERPA_ONNX_LOGE("Expected audio_format 1. Given: %d\n",
header.audio_format);

if (header.audio_format == static_cast<int16_t>(0xfffe)) {
SHERPA_ONNX_LOGE("We don't support WAVE_FORMAT_EXTENSIBLE files.");
}

*is_ok = false;
return {};
}
@@ -125,10 +144,9 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
sizeof(header.num_channels));

if (header.num_channels != 1) { // we support only single channel for now
SHERPA_ONNX_LOGE("Expected single channel. Given: %d\n",
header.num_channels);
*is_ok = false;
return {};
SHERPA_ONNX_LOGE(
"Warning: %d channels are found. We only use the first channel.\n",
header.num_channels);
}

is.read(reinterpret_cast<char *>(&header.sample_rate),
@@ -161,8 +179,9 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
return {};
}

if (header.bits_per_sample != 16) { // we support only 16 bits per sample
SHERPA_ONNX_LOGE("Expected bits_per_sample 16. Given: %d\n",
if (header.bits_per_sample != 8 && header.bits_per_sample != 16 &&
header.bits_per_sample != 32) {
SHERPA_ONNX_LOGE("Expected bits_per_sample 8, 16 or 32. Given: %d\n",
header.bits_per_sample);
*is_ok = false;
return {};
@@ -199,21 +218,95 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,

*sampling_rate = header.sample_rate;

// header.subchunk2_size contains the number of bytes in the data.
// As we assume each sample contains two bytes, so it is divided by 2 here
std::vector<int16_t> samples(header.subchunk2_size / 2);
std::vector<float> ans;

is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size);
if (!is) {
if (header.bits_per_sample == 16 && header.audio_format == 1) {
// header.subchunk2_size contains the number of bytes in the data.
// Each sample occupies two bytes, so it is divided by 2 here
std::vector<int16_t> samples(header.subchunk2_size / 2);
SHERPA_ONNX_LOGE("%d samples, bytes: %d", (int)samples.size(),
header.subchunk2_size);

is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size);
if (!is) {
SHERPA_ONNX_LOGE("Failed to read %d bytes", header.subchunk2_size);
*is_ok = false;
return {};
}

ans.resize(samples.size() / header.num_channels);

// samples are interleaved
for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {
ans[i] = samples[i * header.num_channels] / 32768.;
}
} else if (header.bits_per_sample == 8 && header.audio_format == 1) {
// number of samples == number of bytes for 8-bit encoded samples
//
// For 8-bit encoded samples, they are unsigned!
std::vector<uint8_t> samples(header.subchunk2_size);

is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size);
if (!is) {
SHERPA_ONNX_LOGE("Failed to read %d bytes", header.subchunk2_size);
*is_ok = false;
return {};
}

ans.resize(samples.size() / header.num_channels);
for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {
// Note(fangjun): We want to normalize each sample into the range [-1, 1).
// Each original 8-bit sample is unsigned and lies in [0, 255]; dividing
// by 128 maps it to [0, 2), so subtracting 1 gives [-1, 1).
//
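// For example, an input byte of 0 maps to -1.0, 128 maps to 0.0,
// and 255 maps to 0.9921875.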
ans[i] = samples[i * header.num_channels] / 128. - 1;
}
} else if (header.bits_per_sample == 32 && header.audio_format == 1) {
// 32 here is for int32
//
// header.subchunk2_size contains the number of bytes in the data.
// Each sample occupies 4 bytes, so it is divided by 4 here
std::vector<int32_t> samples(header.subchunk2_size / 4);

is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size);
if (!is) {
SHERPA_ONNX_LOGE("Failed to read %d bytes", header.subchunk2_size);
*is_ok = false;
return {};
}

ans.resize(samples.size() / header.num_channels);
for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {
// 2147483648.f is 2^31; (1 << 31) overflows a 32-bit int and would
// typically flip the sign of every sample.
ans[i] = static_cast<float>(samples[i * header.num_channels]) / 2147483648.f;
}
} else if (header.bits_per_sample == 32 && header.audio_format == 3) {
// 32 here is for float32
//
// header.subchunk2_size contains the number of bytes in the data.
// Each sample occupies 4 bytes, so it is divided by 4 here
std::vector<float> samples(header.subchunk2_size / 4);

is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size);
if (!is) {
SHERPA_ONNX_LOGE("Failed to read %d bytes", header.subchunk2_size);
*is_ok = false;
return {};
}

ans.resize(samples.size() / header.num_channels);
for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {
ans[i] = samples[i * header.num_channels];
}
} else {
SHERPA_ONNX_LOGE(
"Unsupported combination: %d bits per sample with audio format %d. "
"Supported: 8-, 16-, and 32-bit integer PCM (audio format 1) and "
"32-bit float (audio format 3).",
header.bits_per_sample, header.audio_format);
*is_ok = false;
return {};
}

std::vector<float> ans(samples.size());
for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {
ans[i] = samples[i] / 32768.;
}

*is_ok = true;
return ans;
}
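For context, here is a minimal sketch of how the reader is typically invoked. It assumes the public ReadWave() helper declared in sherpa-onnx/csrc/wave-reader.h, which takes a filename plus output parameters for the sample rate and a success flag and wraps ReadWaveImpl(); the exact signature may differ slightly. Whether the file stores 8/16/32-bit integer PCM or 32-bit float, and regardless of the channel count, the caller receives the first channel as float samples normalized to [-1, 1):

#include <cstdint>
#include <cstdio>
#include <vector>

#include "sherpa-onnx/csrc/wave-reader.h"

int main() {
  int32_t sampling_rate = 0;
  bool is_ok = false;

  // Any of the test files listed earlier works here, e.g. a 2-channel
  // 8-bit recording; only the first channel is kept.
  std::vector<float> samples =
      sherpa_onnx::ReadWave("int8-2-channel-zh.wav", &sampling_rate, &is_ok);

  if (!is_ok) {
    std::fprintf(stderr, "Failed to read the wave file\n");
    return 1;
  }

  std::printf("sample rate: %d, num samples: %d\n", sampling_rate,
              static_cast<int32_t>(samples.size()));
  return 0;
}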
17 changes: 8 additions & 9 deletions sherpa-onnx/jni/offline-recognizer.cc
@@ -264,13 +264,9 @@ Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_newFromFile(JNIEnv *env,
return (jlong)model;
}


SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_setConfig(JNIEnv *env,
jobject /*obj*/,
jlong ptr,
jobject _config) {
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_setConfig(
JNIEnv *env, jobject /*obj*/, jlong ptr, jobject _config) {
auto config = sherpa_onnx::GetOfflineConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());

@@ -350,9 +346,12 @@ Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_getResult(JNIEnv *env,
// [3]: lang, jstring
// [4]: emotion, jstring
// [5]: event, jstring
env->SetObjectArrayElement(obj_arr, 3, env->NewStringUTF(result.lang.c_str()));
env->SetObjectArrayElement(obj_arr, 4, env->NewStringUTF(result.emotion.c_str()));
env->SetObjectArrayElement(obj_arr, 5, env->NewStringUTF(result.event.c_str()));
env->SetObjectArrayElement(obj_arr, 3,
env->NewStringUTF(result.lang.c_str()));
env->SetObjectArrayElement(obj_arr, 4,
env->NewStringUTF(result.emotion.c_str()));
env->SetObjectArrayElement(obj_arr, 5,
env->NewStringUTF(result.event.c_str()));

return obj_arr;
}
