From 06f4b718f824e9d074a2f7c1d01f466c0dadf37e Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 22 Jan 2024 23:19:54 +0800 Subject: [PATCH] First working version --- .../speaker/identification/MainActivity.kt | 2 + .../onnx/speaker/identification/Speaker.kt | 61 ++++- .../speaker/identification/screens/Home.kt | 218 +++++++++++++++++- .../identification/screens/Register.kt | 116 ++++++---- .../app/src/main/res/values/strings.xml | 1 + sherpa-onnx/jni/jni.cc | 16 +- 6 files changed, 350 insertions(+), 64 deletions(-) diff --git a/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/MainActivity.kt b/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/MainActivity.kt index 10dff4a88..262f1973c 100644 --- a/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/MainActivity.kt +++ b/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/MainActivity.kt @@ -32,6 +32,7 @@ import androidx.navigation.compose.NavHost import androidx.navigation.compose.composable import androidx.navigation.compose.currentBackStackEntryAsState import androidx.navigation.compose.rememberNavController +import com.k2fsa.sherpa.onnx.SpeakerRecognition import com.k2fsa.sherpa.onnx.speaker.identification.screens.HelpScreen import com.k2fsa.sherpa.onnx.speaker.identification.screens.HomeScreen import com.k2fsa.sherpa.onnx.speaker.identification.screens.RegisterScreen @@ -59,6 +60,7 @@ class MainActivity : ComponentActivity() { ActivityCompat.requestPermissions(this, permissions, REQUEST_RECORD_AUDIO_PERMISSION) + SpeakerRecognition.initExtractor(this.assets) } @Deprecated("Deprecated in Java") diff --git a/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/Speaker.kt b/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/Speaker.kt index fcad730b0..0f9337510 100644 --- a/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/Speaker.kt +++ b/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/Speaker.kt @@ -1,6 +1,8 @@ package com.k2fsa.sherpa.onnx import android.content.res.AssetManager +import android.util.Log +import com.k2fsa.sherpa.onnx.speaker.identification.TAG data class SpeakerEmbeddingExtractorConfig( @@ -11,7 +13,8 @@ data class SpeakerEmbeddingExtractorConfig( ) class SpeakerEmbeddingExtractorStream(var ptr: Long) { - fun acceptWaveform(samples: FloatArray, sampleRate: Int) = acceptWaveform(ptr, samples, sampleRate) + fun acceptWaveform(samples: FloatArray, sampleRate: Int) = + acceptWaveform(ptr, samples, sampleRate) fun inputFinished() = inputFinished(ptr) @@ -28,6 +31,7 @@ class SpeakerEmbeddingExtractorStream(var ptr: Long) { private external fun inputFinished(ptr: Long) private external fun delete(ptr: Long) + companion object { init { System.loadLibrary("sherpa-onnx-jni") @@ -108,7 +112,9 @@ class SpeakerEmbeddingManager(val dim: Int) { fun add(name: String, embedding: Array) = addList(ptr, name, embedding) fun remove(name: String) = remove(ptr, name) fun search(embedding: FloatArray, threshold: Float) = search(ptr, embedding, threshold) - fun verify(name: String, embedding: FloatArray, threshold: Float) = verify(ptr, name, embedding, threshold) + fun verify(name: String, embedding: FloatArray, threshold: Float) = + verify(ptr, name, embedding, threshold) + fun contains(name: String) = contains(ptr, name) fun numSpeakers() = numSpeakers(ptr) @@ -118,7 +124,13 @@ class SpeakerEmbeddingManager(val dim: Int) { private external fun addList(ptr: Long, name: String, embedding: Array): Boolean private external fun remove(ptr: Long, name: String): Boolean private external fun search(ptr: Long, embedding: FloatArray, threshold: Float): String - private external fun verify(ptr: Long, name: String, embedding: FloatArray, threshold: Float): Boolean + private external fun verify( + ptr: Long, + name: String, + embedding: FloatArray, + threshold: Float + ): Boolean + private external fun contains(ptr: Long, name: String): Boolean private external fun numSpeakers(ptr: Long): Int @@ -128,3 +140,46 @@ class SpeakerEmbeddingManager(val dim: Int) { } } } + +// Please download the model file from +// https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models +// and put it inside the assets directory. +// +// Please don't put it in a subdirectory of assets +private val modelName = "3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx" + +object SpeakerRecognition { + var _extractor: SpeakerEmbeddingExtractor? = null + var _manager: SpeakerEmbeddingManager? = null + + val extractor: SpeakerEmbeddingExtractor + get() { + return _extractor!! + } + + val manager: SpeakerEmbeddingManager + get() { + return _manager!! + } + + fun initExtractor(assetManager: AssetManager? = null) { + synchronized(this) { + if (_extractor != null) { + return + } + Log.i(TAG, "Initializing speaker embedding extractor") + + _extractor = SpeakerEmbeddingExtractor( + assetManager = assetManager, + config = SpeakerEmbeddingExtractorConfig( + model = modelName, + numThreads = 2, + debug = false, + provider = "cpu", + ) + ) + + _manager = SpeakerEmbeddingManager(dim=_extractor!!.dim()) + } + } +} diff --git a/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/screens/Home.kt b/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/screens/Home.kt index 1f2cf15c5..ddaaa0e3a 100644 --- a/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/screens/Home.kt +++ b/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/screens/Home.kt @@ -1,14 +1,228 @@ package com.k2fsa.sherpa.onnx.speaker.identification.screens +import android.Manifest +import android.annotation.SuppressLint +import android.app.Activity +import android.content.pm.PackageManager +import android.media.AudioFormat +import android.media.AudioRecord +import android.media.MediaRecorder +import android.util.Log +import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Box +import androidx.compose.foundation.layout.Column +import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.Spacer import androidx.compose.foundation.layout.fillMaxSize +import androidx.compose.foundation.layout.fillMaxWidth +import androidx.compose.foundation.layout.height +import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.layout.width +import androidx.compose.material3.Button +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.Slider import androidx.compose.material3.Text import androidx.compose.runtime.Composable +import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.remember +import androidx.compose.runtime.setValue +import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier +import androidx.compose.ui.platform.LocalContext +import androidx.compose.ui.res.stringResource +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.unit.dp +import androidx.core.app.ActivityCompat +import com.k2fsa.sherpa.onnx.SpeakerRecognition +import com.k2fsa.sherpa.onnx.speaker.identification.R +import com.k2fsa.sherpa.onnx.speaker.identification.TAG +import kotlin.concurrent.thread +private var audioRecord: AudioRecord? = null +private var sampleList: MutableList? = null + +private val clearedResult = "-cleared-" @Composable fun HomeScreen() { - Box(modifier= Modifier.fillMaxSize()) { - Text("Home") + val activity = LocalContext.current as Activity + var threshold by remember { + mutableStateOf(0.5F) + } + + var detectedName by remember { + mutableStateOf(clearedResult) + } + + var isStarted by remember { mutableStateOf(false) } + val onRecordingButtonClick: () -> Unit = { + isStarted = !isStarted + + if (isStarted) { + if (ActivityCompat.checkSelfPermission( + activity, + Manifest.permission.RECORD_AUDIO + ) != PackageManager.PERMISSION_GRANTED + ) { + Log.i(TAG, "Recording is not allowed") + } else { + // recording is allowed + val audioSource = MediaRecorder.AudioSource.MIC + val channelConfig = AudioFormat.CHANNEL_IN_MONO + val audioFormat = AudioFormat.ENCODING_PCM_16BIT + val numBytes = + AudioRecord.getMinBufferSize(sampleRateInHz, channelConfig, audioFormat) + + audioRecord = AudioRecord( + audioSource, + sampleRateInHz, + AudioFormat.CHANNEL_IN_MONO, + AudioFormat.ENCODING_PCM_16BIT, + numBytes * 2 // a sample has two bytes as we are using 16-bit PCM + ) + + sampleList = null + detectedName = clearedResult + + // recording is started here + thread(true) { + Log.i(TAG, "processing samples") + + val interval = 0.1 // i.e., 100 ms + val bufferSize = (interval * sampleRateInHz).toInt() // in samples + val buffer = ShortArray(bufferSize) + audioRecord?.let { + it.startRecording() + + while (isStarted) { + val ret = audioRecord?.read(buffer, 0, buffer.size) + ret?.let { n -> + val samples = FloatArray(n) { buffer[it] / 32768.0f } + if (sampleList == null) { + sampleList = mutableListOf(samples) + } else { + sampleList?.add(samples) + } + } + } + } + + Log.i(TAG, "Home: Recording is stopped. ${sampleList?.count()}") + } + } + } else { + // recording is stopped here + audioRecord?.stop() + audioRecord?.release() + audioRecord = null + + sampleList?.let { + val stream = SpeakerRecognition.extractor.createStream() + for (samples in it) { + stream.acceptWaveform(samples = samples, sampleRate = sampleRateInHz) + } + stream.inputFinished() + if (SpeakerRecognition.extractor.isReady(stream)) { + val embedding = SpeakerRecognition.extractor.compute(stream) + detectedName = SpeakerRecognition.manager.search( + embedding = embedding, + threshold = threshold, + ) + } + } + } + } + + val onThresholdChange = { newValue: Float -> + threshold = newValue + } + + Box( + modifier = Modifier.fillMaxSize(), + contentAlignment = Alignment.TopCenter, + ) { + Column( + horizontalAlignment = Alignment.CenterHorizontally, + ) { + HomeThresholdRow( + threshold = threshold, + onValueChange = onThresholdChange, + ) + HomeButtonRow( + isStarted = isStarted, + onRecordingButtonClick = onRecordingButtonClick, + onClearButtonClick = { + detectedName = clearedResult + }, + ) + + Spacer(modifier = Modifier.height(48.dp)) + + if(detectedName == clearedResult) { + // do nothing + } else if (detectedName.length > 0) { + Text( + text = "Speaker: ${detectedName}", + style = MaterialTheme.typography.headlineLarge, + fontWeight = FontWeight.Bold, + ) + } else { + Text( + text = "Unknown speaker", + style = MaterialTheme.typography.headlineLarge, + fontWeight = FontWeight.Bold, + ) + } + } + } +} + +@SuppressLint("UnrememberedMutableState") +@Composable +private fun HomeButtonRow( + modifier: Modifier = Modifier, + isStarted: Boolean, + onRecordingButtonClick: () -> Unit, + onClearButtonClick: () -> Unit, +) { + val numSpeakers: Int by mutableStateOf(SpeakerRecognition.manager.numSpeakers()) + Row( + modifier = modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.Center, + ) { + Button( + enabled = numSpeakers > 0, + onClick = onRecordingButtonClick + ) { + Text(text = stringResource(if (isStarted) R.string.stop else R.string.start)) + } + + Spacer(modifier = Modifier.width(24.dp)) + + Button(onClick = onClearButtonClick) { + Text(text = stringResource(id = R.string.clear)) + } + } +} + +@Composable +fun HomeThresholdRow( + modifier: Modifier = Modifier, + threshold: Float, + onValueChange: (Float) -> Unit, +) { + Column(modifier = Modifier) { + Text( + text = "Threshold: " + String.format("%.2f", threshold), + style = MaterialTheme.typography.headlineMedium, + fontWeight = FontWeight.Bold, + modifier = modifier.padding(bottom = 8.dp, top = 8.dp), + ) + Slider( + value = threshold, + onValueChange = onValueChange, + valueRange = 0.1F..1.0F, + modifier = modifier.fillMaxWidth(), + ) } } \ No newline at end of file diff --git a/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/screens/Register.kt b/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/screens/Register.kt index 15638d37b..7ac895d17 100644 --- a/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/screens/Register.kt +++ b/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/screens/Register.kt @@ -4,13 +4,11 @@ import android.Manifest import android.annotation.SuppressLint import android.app.Activity import android.content.pm.PackageManager -import android.media.AudioAttributes import android.media.AudioFormat -import android.media.AudioManager import android.media.AudioRecord -import android.media.AudioTrack import android.media.MediaRecorder import android.util.Log +import android.widget.Toast import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Column @@ -37,6 +35,7 @@ import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.unit.dp import androidx.core.app.ActivityCompat +import com.k2fsa.sherpa.onnx.SpeakerRecognition import com.k2fsa.sherpa.onnx.speaker.identification.R import com.k2fsa.sherpa.onnx.speaker.identification.TAG import kotlin.concurrent.thread @@ -45,9 +44,9 @@ private var audioRecord: AudioRecord? = null private var sampleList: MutableList? = null -private var allSampleList: MutableList>? = null +private var embeddingList: MutableList? = null -private var number = 0 +val sampleRateInHz = 16000 @SuppressLint("UnrememberedMutableState") @Preview @@ -59,11 +58,10 @@ fun RegisterScreen(modifier: Modifier = Modifier) { if (firstTime) { firstTime = false // clear states - - number = 0 + embeddingList = null } - var numberAudio by mutableStateOf(number) + val numberAudio: Int by mutableStateOf(embeddingList?.count() ?: 0) Box( modifier = Modifier.fillMaxSize(), @@ -73,7 +71,7 @@ fun RegisterScreen(modifier: Modifier = Modifier) { val onSpeakerNameChange = { newName: String -> speakerName = newName } var isStarted by remember { mutableStateOf(false) } - val onRecordButtonClick: () -> Unit = { + val onRecordingButtonClick: () -> Unit = { isStarted = !isStarted if (isStarted) { @@ -86,7 +84,6 @@ fun RegisterScreen(modifier: Modifier = Modifier) { } else { // recording is allowed val audioSource = MediaRecorder.AudioSource.MIC - val sampleRateInHz = 16000 val channelConfig = AudioFormat.CHANNEL_IN_MONO val audioFormat = AudioFormat.ENCODING_PCM_16BIT val numBytes = @@ -127,7 +124,6 @@ fun RegisterScreen(modifier: Modifier = Modifier) { Log.i(TAG, "Recording is stopped. ${sampleList?.count()}") - ++number } } } else { @@ -136,40 +132,59 @@ fun RegisterScreen(modifier: Modifier = Modifier) { audioRecord?.release() audioRecord = null - Log.i(TAG, "Start to play the recorded samples") - val sampleRate = 16000 - val bufLength = AudioTrack.getMinBufferSize( - sampleRate, - AudioFormat.CHANNEL_OUT_MONO, - AudioFormat.ENCODING_PCM_FLOAT - ) - Log.i(TAG, "sampleRate: ${sampleRate}, buffLength: ${bufLength}") - - val attr = - AudioAttributes.Builder().setContentType(AudioAttributes.CONTENT_TYPE_SPEECH) - .setUsage(AudioAttributes.USAGE_MEDIA) - .build() - - val format = AudioFormat.Builder() - .setEncoding(AudioFormat.ENCODING_PCM_FLOAT) - .setChannelMask(AudioFormat.CHANNEL_OUT_MONO) - .setSampleRate(sampleRate) - .build() - - val track = AudioTrack( - attr, format, bufLength, AudioTrack.MODE_STREAM, - AudioManager.AUDIO_SESSION_ID_GENERATE - ) - track.play() - - for (samples in sampleList!!) { - track.write(samples, 0, samples.size, AudioTrack.WRITE_BLOCKING) + sampleList?.let { + val stream = SpeakerRecognition.extractor.createStream() + for (samples in it) { + stream.acceptWaveform(samples=samples, sampleRate=sampleRateInHz) + } + stream.inputFinished() + if(SpeakerRecognition.extractor.isReady(stream)) { + val embedding = SpeakerRecognition.extractor.compute(stream) + if(embeddingList == null) { + embeddingList = mutableListOf(embedding) + } else { + embeddingList?.add(embedding) + } + } } - track.stop() - track.release() - Log.i(TAG, "released audio track") + } + } + + val onAddButtonClick: () -> Unit = { + if(speakerName.isEmpty() || speakerName.isBlank()) { + Toast.makeText( + activity, + "please input a speaker name", + Toast.LENGTH_SHORT + ).show() + } else if(SpeakerRecognition.manager.contains(speakerName.trim())) { + Toast.makeText( + activity, + "A speaker with $speakerName already exists. Please choose a new name", + Toast.LENGTH_SHORT + ).show() + } else { + val ok = SpeakerRecognition.manager.add(speakerName.trim(), embedding = embeddingList!!.toTypedArray()) + if(ok) { + Log.i(TAG, "Added ${speakerName.trim()} successfully") + Toast.makeText( + activity, + "Added ${speakerName.trim()}", + Toast.LENGTH_SHORT + ).show() - // play the recorded audio to check that the recording is working + embeddingList = null + sampleList = null + speakerName = "" + firstTime = true + } else { + Log.i(TAG, "Failed to add ${speakerName.trim()}") + Toast.makeText( + activity, + "Failed to add ${speakerName.trim()}", + Toast.LENGTH_SHORT + ).show() + } } } @@ -184,7 +199,8 @@ fun RegisterScreen(modifier: Modifier = Modifier) { RegisterSpeakerButtonRow( modifier, isStarted = isStarted, - onButtonClick = onRecordButtonClick, + onRecordingButtonClick = onRecordingButtonClick, + onAddButtonClick = onAddButtonClick, ) } } @@ -209,23 +225,29 @@ fun SpeakerNameRow( ) } +@SuppressLint("UnrememberedMutableState") @Composable fun RegisterSpeakerButtonRow( modifier: Modifier = Modifier, isStarted: Boolean, - onButtonClick: () -> Unit, + onRecordingButtonClick: () -> Unit, + onAddButtonClick: () -> Unit, ) { + val numberAudio: Int by mutableStateOf(embeddingList?.count() ?: 0) Row( modifier = modifier.fillMaxWidth(), horizontalArrangement = Arrangement.Center, ) { - Button(onClick = onButtonClick) { + Button(onClick = onRecordingButtonClick) { Text(text = stringResource(if (isStarted) R.string.stop else R.string.start)) } Spacer(modifier = Modifier.width(24.dp)) - Button(onClick = {}) { + Button( + enabled = numberAudio > 0, + onClick = onAddButtonClick, + ) { Text(text = stringResource(id = R.string.add)) } } diff --git a/android/SherpaOnnxSpeakerIdentification/app/src/main/res/values/strings.xml b/android/SherpaOnnxSpeakerIdentification/app/src/main/res/values/strings.xml index 349b44723..0766efd7d 100644 --- a/android/SherpaOnnxSpeakerIdentification/app/src/main/res/values/strings.xml +++ b/android/SherpaOnnxSpeakerIdentification/app/src/main/res/values/strings.xml @@ -3,4 +3,5 @@ Start recording Stop recording Add speaker + Clear result \ No newline at end of file diff --git a/sherpa-onnx/jni/jni.cc b/sherpa-onnx/jni/jni.cc index f974f56e6..4773da4f4 100644 --- a/sherpa-onnx/jni/jni.cc +++ b/sherpa-onnx/jni/jni.cc @@ -234,7 +234,7 @@ class SherpaOnnxSpeakerEmbeddingExtractor { #if __ANDROID_API__ >= 9 SherpaOnnxSpeakerEmbeddingExtractor( AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config) - : extractor_(mgr, config), stream_(extractor_.CreateStream()) {} + : extractor_(mgr, config) {} #endif explicit SherpaOnnxSpeakerEmbeddingExtractor( @@ -865,11 +865,7 @@ Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_new(JNIEnv *env, } #endif auto config = sherpa_onnx::GetSpeakerEmbeddingExtractorConfig(env, _config); - SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); - - if (!config.Validate()) { - SHERPA_ONNX_LOGE("Errors found in config!"); - } + SHERPA_ONNX_LOGE("new config:\n%s", config.ToString().c_str()); auto extractor = new sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor( #if __ANDROID_API__ >= 9 @@ -885,7 +881,7 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_newFromFile( JNIEnv *env, jobject /*obj*/, jobject _config) { auto config = sherpa_onnx::GetSpeakerEmbeddingExtractorConfig(env, _config); - SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + SHERPA_ONNX_LOGE("newFromFile config:\n%s", config.ToString().c_str()); if (!config.Validate()) { SHERPA_ONNX_LOGE("Errors found in config!"); @@ -1160,7 +1156,7 @@ Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_numSpeakers(JNIEnv *env, jobject /*obj*/, jlong ptr) { auto manager = reinterpret_cast(ptr); - return manager->Dim(); + return manager->NumSpeakers(); } SHERPA_ONNX_EXTERN_C @@ -1175,10 +1171,6 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_new( auto config = sherpa_onnx::GetOfflineTtsConfig(env, _config); SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); - if (!config.Validate()) { - SHERPA_ONNX_LOGE("Errors found in config!"); - } - auto tts = new sherpa_onnx::SherpaOnnxOfflineTts( #if __ANDROID_API__ >= 9 mgr,