Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Kotlin API for speaker diarization #1415

Merged
merged 1 commit into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions kotlin-api-examples/OfflineSpeakerDiarization.kt
31 changes: 31 additions & 0 deletions kotlin-api-examples/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,37 @@ function testPunctuation() {
java -Djava.library.path=../build/lib -jar $out_filename
}

function testOfflineSpeakerDiarization() {
if [ ! -f ./sherpa-onnx-pyannote-segmentation-3-0/model.onnx ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
fi

if [ ! -f ./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
fi

if [ ! -f ./0-four-speakers-zh.wav ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav
fi

out_filename=test_offline_speaker_diarization.jar
kotlinc-jvm -include-runtime -d $out_filename \
test_offline_speaker_diarization.kt \
OfflineSpeakerDiarization.kt \
Speaker.kt \
OnlineStream.kt \
WaveReader.kt \
faked-asset-manager.kt \
faked-log.kt

ls -lh $out_filename

java -Djava.library.path=../build/lib -jar $out_filename
}

testOfflineSpeakerDiarization
testSpeakerEmbeddingExtractor
testOnlineAsr
testTts
Expand Down
53 changes: 53 additions & 0 deletions kotlin-api-examples/test_offline_speaker_diarization.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package com.k2fsa.sherpa.onnx

fun main() {
testOfflineSpeakerDiarization()
}

fun callback(numProcessedChunks: Int, numTotalChunks: Int, arg: Long): Int {
val progress = numProcessedChunks.toFloat() / numTotalChunks * 100
val s = "%.2f".format(progress)
println("Progress: ${s}%");

return 0
}

fun testOfflineSpeakerDiarization() {
var config = OfflineSpeakerDiarizationConfig(
segmentation=OfflineSpeakerSegmentationModelConfig(
pyannote=OfflineSpeakerSegmentationPyannoteModelConfig("./sherpa-onnx-pyannote-segmentation-3-0/model.onnx"),
),
embedding=SpeakerEmbeddingExtractorConfig(
model="./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx",
),

// The test wave file ./0-four-speakers-zh.wav contains four speakers, so
// we use numClusters=4 here. If you don't know the number of speakers
// in the test wave file, please set the threshold like below.
//
// clustering=FastClusteringConfig(threshold=0.5),
//
// WARNING: You need to tune threshold by yourself.
// A larger threshold leads to fewer clusters, i.e., few speakers.
// A smaller threshold leads to more clusters, i.e., more speakers.
//
clustering=FastClusteringConfig(numClusters=4),
)

val sd = OfflineSpeakerDiarization(config=config)

val waveData = WaveReader.readWave(
filename = "./0-four-speakers-zh.wav",
)

if (sd.sampleRate() != waveData.sampleRate) {
println("Expected sample rate: ${sd.sampleRate()}, given: ${waveData.sampleRate}")
return
}

// val segments = sd.process(waveData.samples) // this one is also ok
val segments = sd.processWithCallback(waveData.samples, callback=::callback)
for (segment in segments) {
println("${segment.start} -- ${segment.end} speaker_${segment.speaker}")
}
}
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/offline-speaker-diarization-result.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class OfflineSpeakerDiarizationResult {
std::vector<std::vector<OfflineSpeakerDiarizationSegment>> SortBySpeaker()
const;

public:
private:
std::vector<OfflineSpeakerDiarizationSegment> segments_;
};

Expand Down
6 changes: 6 additions & 0 deletions sherpa-onnx/jni/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ if(SHERPA_ONNX_ENABLE_TTS)
)
endif()

if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
list(APPEND sources
offline-speaker-diarization.cc
)
endif()

add_library(sherpa-onnx-jni ${sources})

target_compile_definitions(sherpa-onnx-jni PRIVATE SHERPA_ONNX_BUILD_SHARED_LIBS=1)
Expand Down
219 changes: 219 additions & 0 deletions sherpa-onnx/jni/offline-speaker-diarization.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
// sherpa-onnx/jni/offline-speaker-diarization.cc
//
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa-onnx/csrc/offline-speaker-diarization.h"

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/jni/common.h"

namespace sherpa_onnx {

static OfflineSpeakerDiarizationConfig GetOfflineSpeakerDiarizationConfig(
JNIEnv *env, jobject config) {
OfflineSpeakerDiarizationConfig ans;

jclass cls = env->GetObjectClass(config);
jfieldID fid;

//---------- segmentation ----------
fid = env->GetFieldID(
cls, "segmentation",
"Lcom/k2fsa/sherpa/onnx/OfflineSpeakerSegmentationModelConfig;");
jobject segmentation_config = env->GetObjectField(config, fid);
jclass segmentation_config_cls = env->GetObjectClass(segmentation_config);

fid = env->GetFieldID(
segmentation_config_cls, "pyannote",
"Lcom/k2fsa/sherpa/onnx/OfflineSpeakerSegmentationPyannoteModelConfig;");
jobject pyannote_config = env->GetObjectField(segmentation_config, fid);
jclass pyannote_config_cls = env->GetObjectClass(pyannote_config);

fid = env->GetFieldID(pyannote_config_cls, "model", "Ljava/lang/String;");
jstring s = (jstring)env->GetObjectField(pyannote_config, fid);
const char *p = env->GetStringUTFChars(s, nullptr);
ans.segmentation.pyannote.model = p;
env->ReleaseStringUTFChars(s, p);

fid = env->GetFieldID(segmentation_config_cls, "numThreads", "I");
ans.segmentation.num_threads = env->GetIntField(segmentation_config, fid);

fid = env->GetFieldID(segmentation_config_cls, "debug", "Z");
ans.segmentation.debug = env->GetBooleanField(segmentation_config, fid);

fid = env->GetFieldID(segmentation_config_cls, "provider",
"Ljava/lang/String;");
s = (jstring)env->GetObjectField(segmentation_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.segmentation.provider = p;
env->ReleaseStringUTFChars(s, p);

//---------- embedding ----------
fid = env->GetFieldID(
cls, "embedding",
"Lcom/k2fsa/sherpa/onnx/SpeakerEmbeddingExtractorConfig;");
jobject embedding_config = env->GetObjectField(config, fid);
jclass embedding_config_cls = env->GetObjectClass(embedding_config);

fid = env->GetFieldID(embedding_config_cls, "model", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(embedding_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.embedding.model = p;
env->ReleaseStringUTFChars(s, p);

fid = env->GetFieldID(embedding_config_cls, "numThreads", "I");
ans.embedding.num_threads = env->GetIntField(embedding_config, fid);

fid = env->GetFieldID(embedding_config_cls, "debug", "Z");
ans.embedding.debug = env->GetBooleanField(embedding_config, fid);

fid = env->GetFieldID(embedding_config_cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(embedding_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.embedding.provider = p;
env->ReleaseStringUTFChars(s, p);

//---------- clustering ----------
fid = env->GetFieldID(cls, "clustering",
"Lcom/k2fsa/sherpa/onnx/FastClusteringConfig;");
jobject clustering_config = env->GetObjectField(config, fid);
jclass clustering_config_cls = env->GetObjectClass(clustering_config);

fid = env->GetFieldID(clustering_config_cls, "numClusters", "I");
ans.clustering.num_clusters = env->GetIntField(clustering_config, fid);

fid = env->GetFieldID(clustering_config_cls, "threshold", "F");
ans.clustering.threshold = env->GetFloatField(clustering_config, fid);

// its own fields
fid = env->GetFieldID(cls, "minDurationOn", "F");
ans.min_duration_on = env->GetFloatField(config, fid);

fid = env->GetFieldID(cls, "minDurationOff", "F");
ans.min_duration_off = env->GetFloatField(config, fid);

return ans;
}

} // namespace sherpa_onnx

SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_newFromAsset(
JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
return 0;
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_newFromFile(
JNIEnv *env, jobject /*obj*/, jobject _config) {
auto config = sherpa_onnx::GetOfflineSpeakerDiarizationConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());

if (!config.Validate()) {
SHERPA_ONNX_LOGE("Errors found in config!");
return 0;
}

auto sd = new sherpa_onnx::OfflineSpeakerDiarization(config);

return (jlong)sd;
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_setConfig(
JNIEnv *env, jobject /*obj*/, jlong ptr, jobject _config) {
auto config = sherpa_onnx::GetOfflineSpeakerDiarizationConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());

auto sd = reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr);
sd->SetConfig(config);
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_delete(JNIEnv * /*env*/,
jobject /*obj*/,
jlong ptr) {
delete reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr);
}

static jobjectArray ProcessImpl(
JNIEnv *env,
const std::vector<sherpa_onnx::OfflineSpeakerDiarizationSegment>
&segments) {
jclass cls =
env->FindClass("com/k2fsa/sherpa/onnx/OfflineSpeakerDiarizationSegment");

jobjectArray obj_arr =
(jobjectArray)env->NewObjectArray(segments.size(), cls, nullptr);

jmethodID constructor = env->GetMethodID(cls, "<init>", "(FFI)V");

for (int32_t i = 0; i != segments.size(); ++i) {
const auto &s = segments[i];
jobject segment =
env->NewObject(cls, constructor, s.Start(), s.End(), s.Speaker());
env->SetObjectArrayElement(obj_arr, i, segment);
}

return obj_arr;
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_process(
JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples) {
auto sd = reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr);

jfloat *p = env->GetFloatArrayElements(samples, nullptr);
jsize n = env->GetArrayLength(samples);
auto segments = sd->Process(p, n).SortByStartTime();
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);

return ProcessImpl(env, segments);
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_processWithCallback(
JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
jobject callback, jlong arg) {
std::function<int32_t(int32_t, int32_t, void *)> callback_wrapper =
[env, callback](int32_t num_processed_chunks, int32_t num_total_chunks,
void *data) -> int {
jclass cls = env->GetObjectClass(callback);

jmethodID mid = env->GetMethodID(cls, "invoke", "(IIJ)Ljava/lang/Integer;");
if (mid == nullptr) {
SHERPA_ONNX_LOGE("Failed to get the callback. Ignore it.");
return 0;
}

jobject ret = env->CallObjectMethod(callback, mid, num_processed_chunks,
num_total_chunks, (jlong)data);
jclass jklass = env->GetObjectClass(ret);
jmethodID int_value_mid = env->GetMethodID(jklass, "intValue", "()I");
return env->CallIntMethod(ret, int_value_mid);
};

auto sd = reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr);

jfloat *p = env->GetFloatArrayElements(samples, nullptr);
jsize n = env->GetArrayLength(samples);
auto segments =
sd->Process(p, n, callback_wrapper, (void *)arg).SortByStartTime();
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);

return ProcessImpl(env, segments);
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT jint JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_getSampleRate(
JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
return reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr)
->SampleRate();
}
Loading
Loading