Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add WebAssembly for Kws #648

Merged
merged 19 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ option(SHERPA_ONNX_ENABLE_GPU "Enable ONNX Runtime GPU support" OFF)
option(SHERPA_ONNX_ENABLE_WASM "Whether to enable WASM" OFF)
option(SHERPA_ONNX_ENABLE_WASM_TTS "Whether to enable WASM for TTS" OFF)
option(SHERPA_ONNX_ENABLE_WASM_ASR "Whether to enable WASM for ASR" OFF)
option(SHERPA_ONNX_ENABLE_WASM_KWS "Whether to enable WASM for KWS" OFF)
option(SHERPA_ONNX_ENABLE_WASM_NODEJS "Whether to enable WASM for NodeJS" OFF)
option(SHERPA_ONNX_ENABLE_BINARY "Whether to build binaries" ON)
option(SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY "True to link libstdc++ statically. Used only when BUILD_SHARED_LIBS is OFF on Linux" ON)
Expand Down Expand Up @@ -133,6 +134,10 @@ if(SHERPA_ONNX_ENABLE_WASM)
add_definitions(-DSHERPA_ONNX_ENABLE_WASM=1)
endif()

if(SHERPA_ONNX_ENABLE_WASM_KWS)
add_definitions(-DSHERPA_ONNX_ENABLE_WASM_KWS=1)
endif()

if(NOT CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
endif()
Expand Down
155 changes: 155 additions & 0 deletions sherpa-onnx/c-api/c-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/voice-activity-detector.h"
#include "sherpa-onnx/csrc/wave-writer.h"
#include "sherpa-onnx/csrc/keyword-spotter.h"

struct SherpaOnnxOnlineRecognizer {
std::unique_ptr<sherpa_onnx::OnlineRecognizer> impl;
Expand Down Expand Up @@ -648,3 +649,157 @@ int32_t SherpaOnnxWriteWave(const float *samples, int32_t n,
int32_t sample_rate, const char *filename) {
return sherpa_onnx::WriteWave(filename, sample_rate, samples, n);
}

struct SherpaOnnxOnlineKws {
std::unique_ptr<sherpa_onnx::KeywordSpotter> impl;
};

// ============================================================
// For KWS
// ============================================================
//
SherpaOnnxOnlineKws *CreateOnlineKws(
const SherpaOnnxOnlineKwsConfig *config) {

sherpa_onnx::KeywordSpotterConfig kws_config;

kws_config.feat_config.sampling_rate =
SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000);

kws_config.feat_config.feature_dim =
SHERPA_ONNX_OR(config->feat_config.feature_dim, 80);

kws_config.model_config.transducer.encoder =
SHERPA_ONNX_OR(config->model_config.transducer.encoder, "");

kws_config.model_config.transducer.decoder =
SHERPA_ONNX_OR(config->model_config.transducer.decoder, "");

kws_config.model_config.transducer.joiner =
SHERPA_ONNX_OR(config->model_config.transducer.joiner, "");

kws_config.model_config.tokens =
SHERPA_ONNX_OR(config->model_config.tokens, "");

kws_config.model_config.num_threads =
SHERPA_ONNX_OR(config->model_config.num_threads, 1);

kws_config.max_active_paths =
SHERPA_ONNX_OR(config->max_active_paths, 4);

kws_config.num_trailing_blanks =
SHERPA_ONNX_OR(config->num_trailing_blanks, 1);

kws_config.num_trailing_blanks =
SHERPA_ONNX_OR(config->keywords_score, 1.0);

kws_config.keywords_threshold =
SHERPA_ONNX_OR(config->keywords_threshold, 0.25);

kws_config.keywords_file = SHERPA_ONNX_OR(config->keywords, "");

SHERPA_ONNX_LOGE("%s\n", kws_config.ToString().c_str());

SherpaOnnxOnlineKws *kws_recognizer = new SherpaOnnxOnlineKws;

kws_recognizer->impl =
std::make_unique<sherpa_onnx::KeywordSpotter>(kws_config);

return kws_recognizer;
}

SherpaOnnxOnlineStream *CreateOnlineKwsStream(
const SherpaOnnxOnlineKws *kws_recognizer) {
SherpaOnnxOnlineStream *stream =
new SherpaOnnxOnlineStream(kws_recognizer->impl->CreateStream());
return stream;
}

void DestroyOnlineKwsStream(SherpaOnnxOnlineStream *stream) { delete stream; }

void DestroyOnlineKws(SherpaOnnxOnlineKws *recognizer) {
delete recognizer;
}

int32_t IsOnlineKwsStreamReady(SherpaOnnxOnlineKws *recognizer,
SherpaOnnxOnlineStream *stream) {
return recognizer->impl->IsReady(stream->impl.get());
}

void DecodeOnlineKwsStream(SherpaOnnxOnlineKws *recognizer,
SherpaOnnxOnlineStream *stream) {
recognizer->impl->DecodeStream(stream->impl.get());
}

const SherpaOnnxOnlineKwsResult *GetOnlineKwsStreamResult(
SherpaOnnxOnlineKws *recognizer, SherpaOnnxOnlineStream *stream) {
sherpa_onnx::KeywordResult result =
recognizer->impl->GetResult(stream->impl.get());
const auto &text = result.keyword;

auto r = new SherpaOnnxOnlineKwsResult;
memset(r, 0, sizeof(SherpaOnnxOnlineKwsResult));

// copy text
r->keyword = new char[text.size() + 1];
std::copy(text.begin(), text.end(), const_cast<char *>(r->keyword));
const_cast<char *>(r->keyword)[text.size()] = 0;

// copy json
const auto &json = result.AsJsonString();
r->json = new char[json.size() + 1];
std::copy(json.begin(), json.end(), const_cast<char *>(r->json));
const_cast<char *>(r->json)[json.size()] = 0;

// copy tokens
auto count = result.tokens.size();
if (count > 0) {
size_t total_length = 0;
for (const auto &token : result.tokens) {
// +1 for the null character at the end of each token
total_length += token.size() + 1;
}

// Each word ends with nullptr
r->tokens = new char[total_length];
memset(reinterpret_cast<void *>(const_cast<char *>(r->tokens)), 0,
total_length);
char **tokens_temp = new char *[count];
int32_t pos = 0;
for (int32_t i = 0; i < count; ++i) {
tokens_temp[i] = const_cast<char *>(r->tokens) + pos;
memcpy(reinterpret_cast<void *>(const_cast<char *>(r->tokens + pos)),
result.tokens[i].c_str(), result.tokens[i].size());
// +1 to move past the null character
pos += result.tokens[i].size() + 1;
}
r->tokens_arr = tokens_temp;

if (!result.timestamps.empty()) {
r->timestamps = new float[count];
std::copy(result.timestamps.begin(), result.timestamps.end(),
r->timestamps);
} else {
r->timestamps = nullptr;
}

} else {
r->timestamps = nullptr;
r->tokens = nullptr;
r->tokens_arr = nullptr;
}

return r;
}

void DestroyOnlineKwsResult(const SherpaOnnxOnlineKwsResult *r) {
if (r) {
delete[] r->keyword;
delete[] r->json;
delete[] r->tokens;
delete[] r->tokens_arr;
delete[] r->timestamps;
delete r;
}
}

109 changes: 109 additions & 0 deletions sherpa-onnx/c-api/c-api.h
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,115 @@ SHERPA_ONNX_API int32_t SherpaOnnxWriteWave(const float *samples, int32_t n,
int32_t sample_rate,
const char *filename);

// ============================================================
// For online KWS
// ============================================================

SHERPA_ONNX_API typedef struct SherpaOnnxOnlineKwsModelConfig {
SherpaOnnxOnlineTransducerModelConfig transducer;
const char *tokens;
int32_t num_threads;
} SherpaOnnxOnlineKwsModelConfig;

SHERPA_ONNX_API typedef struct SherpaOnnxOnlineKwsConfig {
SherpaOnnxFeatureConfig feat_config;
SherpaOnnxOnlineKwsModelConfig model_config;

/// Used only when decoding_method is modified_beam_search
/// Example value: 4
int32_t max_active_paths;
int32_t num_trailing_blanks;
float keywords_score;
float keywords_threshold;
/// Path to the keywords.
const char *keywords;
} SherpaOnnxOnlineKwsConfig;


SHERPA_ONNX_API typedef struct SherpaOnnxOnlineKwsResult {
// Recognized text
const char *keyword;

// Pointer to continuous memory which holds string based tokens
// which are separated by \0
const char *tokens;

// a pointer array containing the address of the first item in tokens
const char *const *tokens_arr;

// Pointer to continuous memory which holds timestamps
float *timestamps;

/** Return a json string.
*
* The returned string contains:
* {
* "keyword": "The kws keyword result",
* "tokens": [x, x, x],
* "timestamps": [x, x, x],
* }
*/
const char *json;
} SherpaOnnxOnlineKwsResult;

SHERPA_ONNX_API typedef struct SherpaOnnxOnlineKws SherpaOnnxOnlineKws;

/// @param config Config for the kws recognizer.
/// @return Return a pointer to the kws recognizer. The user has to invoke
// DestroyOnlineKws() to free it to avoid memory leak.
SHERPA_ONNX_API SherpaOnnxOnlineKws *CreateOnlineKws(
const SherpaOnnxOnlineKwsConfig *config);

SHERPA_ONNX_API SherpaOnnxOnlineStream *CreateOnlineKwsStream(
const SherpaOnnxOnlineKws *kws_recognizer);

/// Free a pointer returned by CreateOnlineKws()
/// @param recognizer A pointer returned by CreateOnlineKws()
SHERPA_ONNX_API void DestroyOnlineKws(
SherpaOnnxOnlineKws *recognizer);

/// Destroy an online stream.
/// @param stream A pointer returned by CreateOnlineStream()
SHERPA_ONNX_API void DestroyOnlineKwsStream(SherpaOnnxOnlineStream *stream);

/// Get the decoding results so far for an OnlineKwsStream.
///
/// @param recognizer A pointer returned by CreateOnlineKws().
/// @param stream A pointer returned by CreateOnlineKwsStream().
/// @return A pointer containing the result. The user has to invoke
/// DestroyOnlineKwsResult() to free the returned pointer to
/// avoid memory leak.
SHERPA_ONNX_API const SherpaOnnxOnlineKwsResult *GetOnlineKwsStreamResult(
SherpaOnnxOnlineKws *recognizer, SherpaOnnxOnlineStream *stream);

/// Destroy the pointer returned by GetOnlineKwsStreamResult().
///
/// @param r A pointer returned by GetOnlineKwsStreamResult()
SHERPA_ONNX_API void DestroyOnlineKwsResult(
const SherpaOnnxOnlineKwsResult *r);

/// Return 1 if there are enough number of feature frames for decoding.
/// Return 0 otherwise.
///
/// @param kws_recognizer A pointer returned by CreateOnlineKws
/// @param stream A pointer returned by CreateOnlineKwsStream
SHERPA_ONNX_API int32_t IsOnlineKwsStreamReady(
SherpaOnnxOnlineKws *kws_recognizer, SherpaOnnxOnlineStream *stream);


/// Call this function to run the neural network model and decoding.
//
/// Precondition for this function: IsOnlineStreamReady() MUST return 1.
///
/// Usage example:
///
/// while (IsOnlineKwsStreamReady(recognizer, stream)) {
/// DecodeOnlineKwsStream(recognizer, stream);
/// }
///
SHERPA_ONNX_API void DecodeOnlineKwsStream(SherpaOnnxOnlineKws *kws_recognizer,
SherpaOnnxOnlineStream *stream);

#if defined(__GNUC__)
#pragma GCC diagnostic pop
#endif
Expand Down
12 changes: 11 additions & 1 deletion sherpa-onnx/csrc/keyword-spotter-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,15 +266,25 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
}

void InitKeywords() {
// each line in keywords_file contains space-separated words

#ifdef SHERPA_ONNX_ENABLE_WASM_KWS
// Due to the limitations of the wasm file system,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you pass a keyword file from wasm?

We have been doing this for model files, such as tokens.txt.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, but it needs to be recompiled when the keywords is modified, which is very inconvenient. Because token.txt and keywords.txt are both packaged into the sherpa-onnx-wasm-kws-main.data file, they cannot be modified and can only be recompiled to generate. What do you think?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see. Then please keep using your current approach.

// the keyword_file variable is directly parsed as a string of keywords
// if WASM KWS on
std::istringstream is(config_.keywords_file);
InitKeywords(is);
#else
// each line in keywords_file contains space-separated words
std::ifstream is(config_.keywords_file);
if (!is) {
SHERPA_ONNX_LOGE("Open keywords file failed: %s",
config_.keywords_file.c_str());
exit(-1);
}
InitKeywords(is);
#endif


}

#if __ANDROID_API__ >= 9
Expand Down
4 changes: 4 additions & 0 deletions wasm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ if(SHERPA_ONNX_ENABLE_WASM_ASR)
add_subdirectory(asr)
endif()

if(SHERPA_ONNX_ENABLE_WASM_KWS)
add_subdirectory(kws)
endif()

if(SHERPA_ONNX_ENABLE_WASM_NODEJS)
add_subdirectory(nodejs)
endif()
55 changes: 55 additions & 0 deletions wasm/kws/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
if(NOT $ENV{SHERPA_ONNX_IS_USING_BUILD_WASM_SH})
message(FATAL_ERROR "Please use ./build-wasm-kws.sh to build for wasm KWS")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
message(FATAL_ERROR "Please use ./build-wasm-kws.sh to build for wasm KWS")
message(FATAL_ERROR "Please use ./build-wasm-simd-kws.sh to build for wasm KWS")

endif()

if(NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/assets/decoder-epoch-12-avg-2-chunk-16-left-64.onnx")
message(WARNING "${CMAKE_CURRENT_SOURCE_DIR}/assets/decoder-epoch-12-avg-2-chunk-16-left-64.onnx does not exist")
message(FATAL_ERROR "Please read ${CMAKE_CURRENT_SOURCE_DIR}/assets/README.md before you continue")
endif()

set(exported_functions
AcceptWaveform
CreateOnlineKws
CreateOnlineKwsStream
GetOnlineKwsStreamResult
DecodeOnlineKwsStream
DestroyOnlineKws
DestroyOnlineKwsResult
DestroyOnlineKwsStream
IsOnlineKwsStreamReady
InputFinished
)
set(mangled_exported_functions)
foreach(x IN LISTS exported_functions)
list(APPEND mangled_exported_functions "_${x}")
endforeach()

list(JOIN mangled_exported_functions "," all_exported_functions)

include_directories(${CMAKE_SOURCE_DIR})
set(MY_FLAGS "-s FORCE_FILESYSTEM=1 -s INITIAL_MEMORY=512MB -s ALLOW_MEMORY_GROWTH=1")
string(APPEND MY_FLAGS " -sSTACK_SIZE=10485760 ")
string(APPEND MY_FLAGS " -sEXPORTED_FUNCTIONS=[_CopyHeap,_malloc,_free,${all_exported_functions}] ")
string(APPEND MY_FLAGS "--preload-file ${CMAKE_CURRENT_SOURCE_DIR}/assets@. ")
string(APPEND MY_FLAGS " -sEXPORTED_RUNTIME_METHODS=['ccall','stringToUTF8','setValue','getValue','lengthBytesUTF8','UTF8ToString'] ")
message(STATUS "MY_FLAGS: ${MY_FLAGS}")

set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${MY_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${MY_FLAGS}")
set(CMAKE_EXECUTBLE_LINKER_FLAGS "${CMAKE_EXECUTBLE_LINKER_FLAGS} ${MY_FLAGS}")

add_executable(sherpa-onnx-wasm-kws-main sherpa-onnx-wasm-main-kws.cc)
target_link_libraries(sherpa-onnx-wasm-kws-main sherpa-onnx-c-api)
install(TARGETS sherpa-onnx-wasm-kws-main DESTINATION bin/wasm)

install(
FILES
"sherpa-onnx-kws.js"
"app.js"
"index.html"
"$<TARGET_FILE_DIR:sherpa-onnx-wasm-kws-main>/sherpa-onnx-wasm-kws-main.js"
"$<TARGET_FILE_DIR:sherpa-onnx-wasm-kws-main>/sherpa-onnx-wasm-kws-main.wasm"
"$<TARGET_FILE_DIR:sherpa-onnx-wasm-kws-main>/sherpa-onnx-wasm-kws-main.data"
DESTINATION
bin/wasm
)
Loading
Loading