diff --git a/bin/addons/godot_whisper/audio_stream_to_text.gd b/bin/addons/godot_whisper/audio_stream_to_text.gd index 7310785..22223d9 100644 --- a/bin/addons/godot_whisper/audio_stream_to_text.gd +++ b/bin/addons/godot_whisper/audio_stream_to_text.gd @@ -3,63 +3,70 @@ class_name AudioStreamToText extends SpeechToText -# For Traditional Chinese "以下是普通話的句子。" -# For Simplified Chinese "以下是普通话的句子。" +## Initial prompt for the transcription +## For Traditional Chinese "以下是普通話的句子。" +## For Simplified Chinese "以下是普通话的句子。" @export var initial_prompt: String -func _get_configuration_warnings(): - if language_model == null: - return ["You need a language model."] - else: - return [] - -@export var audio_stream: AudioStreamWAV : +## Audio stream to be transcribed +@export var audio_stream: AudioStreamWAV: set(value): audio_stream = value text = get_text() get: return audio_stream + +## Transcribed text from the audio stream @export var text: String -## Download the model specified in the [member language_model_to_download] +## Flag to start transcription @export var start_transcribe := false: - set(x): + set(value): text = get_text() get: return false -func get_text(): +## Get the transcribed text from the audio stream +func get_text() -> String: + # Return early if audio stream is null + if audio_stream == null: + return "" + var start_time := Time.get_ticks_msec() var data := audio_stream.data - var data_float : PackedFloat32Array + var data_float: PackedFloat32Array + match audio_stream.format: AudioStreamWAV.FORMAT_8_BITS: - for i in data.size() / 2: - data_float.append((data.decode_s8(i * 2) * 1.0/128.0)) + for i in range(data.size() / 2): + data_float.append(data.decode_s8(i * 2) * 1.0 / 128.0) AudioStreamWAV.FORMAT_16_BITS: - for i in data.size() / 2: - data_float.append((data.decode_s16(i * 2) / 32768.0)) - var tokens = transcribe(data_float, initial_prompt, 0) + for i in range(data.size() / 2): + data_float.append(data.decode_s16(i * 2) / 32768.0) + + var tokens := transcribe(data_float, initial_prompt, 0) if tokens.is_empty(): return "" - var full_text : String = tokens.pop_front() + + var full_text: String = tokens.pop_front() var text := "" for token in tokens: if token["plog"] > 0: continue text += token["text"] text = full_text - print("Transcribe: " + str((Time.get_ticks_msec() - start_time)/ 1000.0)) + + print("Transcribe: " + str((Time.get_ticks_msec() - start_time) / 1000.0)) print(text) return _remove_special_characters(text) -func _remove_special_characters(message: String): - var special_characters := [ \ - { "start": "[", "end": "]" }, \ - { "start": "<", "end": ">" }] + +## Remove special characters from the transcribed text +func _remove_special_characters(message: String) -> String: + var special_characters := [{"start": "[", "end": "]"}, {"start": "<", "end": ">"}] for special_character in special_characters: - while(message.find(special_character["start"]) != -1): + while message.find(special_character["start"]) != -1: var begin_character := message.find(special_character["start"]) var end_character := message.find(special_character["end"]) if end_character != -1: @@ -67,8 +74,15 @@ func _remove_special_characters(message: String): var hallucinatory_character := [". you."] for special_character in hallucinatory_character: - while(message.find(special_character) != -1): + while message.find(special_character) != -1: var begin_character := message.find(special_character) var end_character := begin_character + len(special_character) message = message.substr(0, begin_character) + message.substr(end_character + 1) return message + + +## Get configuration warnings for the node +func _get_configuration_warnings() -> PackedStringArray: + if language_model == null: + return ["You need a language model."] + return [] diff --git a/bin/addons/godot_whisper/capture_stream_to_text.gd b/bin/addons/godot_whisper/capture_stream_to_text.gd index 0bec4aa..e3979a3 100644 --- a/bin/addons/godot_whisper/capture_stream_to_text.gd +++ b/bin/addons/godot_whisper/capture_stream_to_text.gd @@ -2,18 +2,14 @@ class_name CaptureStreamToText extends SpeechToText -signal transcribed_msg(is_partial, new_text) +signal transcribed_msg(is_partial: bool, new_text: String) -# For Traditional Chinese "以下是普通話的句子。" -# For Simplified Chinese "以下是普通话的句子。" +## Initial prompt for the transcription +## For Traditional Chinese "以下是普通話的句子。" +## For Simplified Chinese "以下是普通话的句子。" @export var initial_prompt: String -func _get_configuration_warnings(): - if language_model == null: - return ["You need a language model."] - else: - return [] - +## Flag to start/stop recording @export var recording := true: set(value): recording = value @@ -23,120 +19,146 @@ func _get_configuration_warnings(): thread.wait_to_finish() get: return recording + ## The interval at which transcribing is done. Use a value bigger than the time it takes to transcribe (eg. depends on model). @export var transcribe_interval := 0.3 + ## Using dynamic audio context speeds up transcribing but may result in mistakes. @export var use_dynamic_audio_context := true + ## How much time has to pass in seconds until we can consider a sentence. @export var minimum_sentence_time := 3 + ## Maximum time a sentence can have in seconds. @export var maximum_sentence_time := 15 -## How many tokens it's allowed to halucinate. Can provide useful info as it talks, but too much can provide useless text. -@export var halucinating_count := 1 -## The record bus has to have a AudioEffectCapture at index specified by [member audio_effect_capture_index] + +## How many tokens it's allowed to hallucinate. Can provide useful info as it talks, but too much can provide useless text. +@export var hallucinating_count := 1 + +## The record bus has to have an AudioEffectCapture at index specified by [member audio_effect_capture_index] @export var record_bus := "Record" + ## The index where the [AudioEffectCapture] is located at in the [member record_bus] @export var audio_effect_capture_index := 0 ## Character to consider when ending a sentence. @export var punctuation_characters := ".!?;。;?!" -@onready var _idx = AudioServer.get_bus_index(record_bus) -@onready var _effect_capture := AudioServer.get_bus_effect(_idx, audio_effect_capture_index) as AudioEffectCapture +var thread: Thread +var _accumulated_frames: PackedVector2Array -var thread : Thread +@onready var _idx := AudioServer.get_bus_index(record_bus) +@onready var _effect_capture := ( + AudioServer.get_bus_effect(_idx, audio_effect_capture_index) as AudioEffectCapture +) -func _ready(): + +## Ready function to initialize the thread and clear buffer +func _ready() -> void: if Engine.is_editor_hint(): return - if thread && thread.is_alive(): + if thread and thread.is_alive(): recording = false thread.wait_to_finish() thread = Thread.new() _effect_capture.clear_buffer() thread.start(transcribe_thread) -var _accumulated_frames: PackedVector2Array -func transcribe_thread(): +## Thread function to handle transcription +func transcribe_thread() -> void: var last_token_count := 0 while recording: - var start_time = Time.get_ticks_msec() - _accumulated_frames.append_array(_effect_capture.get_buffer(_effect_capture.get_frames_available())) - var resampled = resample(_accumulated_frames, SpeechToText.SRC_SINC_FASTEST) + var start_time := Time.get_ticks_msec() + _accumulated_frames.append_array( + _effect_capture.get_buffer(_effect_capture.get_frames_available()) + ) + var resampled := resample(_accumulated_frames, SpeechToText.SRC_SINC_FASTEST) if resampled.size() <= 0: OS.delay_msec(transcribe_interval * 1000) continue var no_activity := voice_activity_detection(resampled) - #if no_activity: - #print("no activity") - #continue - var total_time : float= (resampled.size() as float) / SpeechToText.SPEECH_SETTING_SAMPLE_RATE - var audio_ctx : int = total_time * 1500 / 30 + 128 - if !use_dynamic_audio_context: + var total_time: float = ( + (resampled.size() as float) / SpeechToText.SPEECH_SETTING_SAMPLE_RATE + ) + var audio_ctx: int = total_time * 1500 / 30 + 128 + if not use_dynamic_audio_context: audio_ctx = 0 var tokens := transcribe(resampled, initial_prompt, audio_ctx) if tokens.is_empty(): push_warning("No tokens generated") return - var full_text : String = tokens.pop_front() - var mix_rate : int = ProjectSettings.get_setting("audio/driver/mix_rate") - var finish_sentence = false + var full_text: String = tokens.pop_front() + var mix_rate: int = ProjectSettings.get_setting("audio/driver/mix_rate") + var finish_sentence := false if total_time > maximum_sentence_time: finish_sentence = true - var text : String + var text: String for token in tokens: text += token["text"] text = _remove_special_characters(text) - if _has_terminating_characters(text, punctuation_characters) || no_activity: + if _has_terminating_characters(text, punctuation_characters) or no_activity: finish_sentence = true - if total_time < minimum_sentence_time || abs(tokens.size() - last_token_count) > halucinating_count: + if ( + total_time < minimum_sentence_time + or abs(tokens.size() - last_token_count) > hallucinating_count + ): finish_sentence = false - var time_processing = (Time.get_ticks_msec() - start_time) + var time_processing := Time.get_ticks_msec() - start_time if no_activity: - #_accumulated_frames = [] continue if finish_sentence: - _accumulated_frames = _accumulated_frames.slice(_accumulated_frames.size() - (0.2 * mix_rate)) - #if !no_activity: + _accumulated_frames = _accumulated_frames.slice( + _accumulated_frames.size() - (0.2 * mix_rate) + ) call_deferred("emit_signal", "transcribed_msg", finish_sentence, full_text) last_token_count = tokens.size() - #print(text) print(full_text) - print("Transcribe " + str(time_processing/ 1000.0) + " s") - # Sleep remaining time - var interval_sleep = transcribe_interval * 1000 - time_processing + print("Transcribe " + str(time_processing / 1000.0) + " s") + var interval_sleep := transcribe_interval * 1000 - time_processing if interval_sleep > 0: OS.delay_msec(interval_sleep) -func _has_terminating_characters(message: String, characters: String): + +## Check if the message contains terminating characters +func _has_terminating_characters(message: String, characters: String) -> bool: for character in characters: if message.contains(character): return true return false -func _remove_special_characters(message: String): - var special_characters = [ \ - { "start": "[", "end": "]" }, \ - { "start": "<", "end": ">" }, \ - { "start": "♪", "end": "♪" }] + +## Remove special characters from the message +func _remove_special_characters(message: String) -> String: + var special_characters := [ + {"start": "[", "end": "]"}, {"start": "<", "end": ">"}, {"start": "♪", "end": "♪"} + ] for special_character in special_characters: - while(message.find(special_character["start"]) != -1): + while message.find(special_character["start"]) != -1: var begin_character := message.find(special_character["start"]) var end_character := message.find(special_character["end"]) if end_character != -1: message = message.substr(0, begin_character) + message.substr(end_character + 1) - var hallucinatory_character = [". you."] + var hallucinatory_character := [". you."] for special_character in hallucinatory_character: - while(message.find(special_character) != -1): + while message.find(special_character) != -1: var begin_character := message.find(special_character) - var end_character = begin_character + len(special_character) + var end_character := begin_character + len(special_character) message = message.substr(0, begin_character) + message.substr(end_character + 1) return message -func _notification(what): + +## Handle notifications +func _notification(what: int) -> void: if what == NOTIFICATION_WM_CLOSE_REQUEST: recording = false if thread.is_alive(): thread.wait_to_finish() + + +## Get configuration warnings for the node +func _get_configuration_warnings() -> PackedStringArray: + if language_model == null: + return ["You need a language model."] + return [] diff --git a/bin/addons/godot_whisper/label_transcribe.gd b/bin/addons/godot_whisper/label_transcribe.gd index 14d6889..d219148 100644 --- a/bin/addons/godot_whisper/label_transcribe.gd +++ b/bin/addons/godot_whisper/label_transcribe.gd @@ -1,23 +1,35 @@ extends RichTextLabel -func _ready(): +## Completed text from transcription +var completed_text := "" + +## Partial text from transcription +var partial_text := "" + + +## Ready function to initialize the label +func _ready() -> void: custom_minimum_size.x = 400 bbcode_enabled = true fit_content = true -func update_text(): + +## Update the text displayed on the label +func update_text() -> void: text = completed_text + "[color=green]" + partial_text + "[/color]" -func _process(_delta): + +## Process function to update text every frame +func _process(_delta: float) -> void: update_text() -var completed_text := "" -var partial_text := "" -func _on_speech_to_text_transcribed_msg(is_partial, new_text): - if is_partial == true: +## Handle the speech-to-text transcribed message +func _on_speech_to_text_transcribed_msg(is_partial: bool, new_text: String) -> void: + # Handle partial and complete transcriptions + if is_partial: completed_text += new_text partial_text = "" else: - if new_text!="": + if new_text != "": partial_text = new_text diff --git a/bin/addons/godot_whisper/model_downloader.gd b/bin/addons/godot_whisper/model_downloader.gd index a638b09..92e665c 100644 --- a/bin/addons/godot_whisper/model_downloader.gd +++ b/bin/addons/godot_whisper/model_downloader.gd @@ -1,29 +1,45 @@ @tool extends Node +## Option button for selecting the model @export var option_button: OptionButton -# Called when the HTTP request is completed. -func _http_request_completed(result, response_code, headers, body, file_path): - if result != HTTPRequest.RESULT_SUCCESS || response_code != 200: - push_error("Can't downloaded.") + +## Called when the HTTP request is completed. +func _http_request_completed( + result: int, + response_code: int, + _headers: PackedStringArray, + _body: PackedByteArray, + file_path: String +) -> void: + # Handle unsuccessful download + if result != HTTPRequest.RESULT_SUCCESS or response_code != 200: + push_error("Can't download.") return EditorInterface.get_resource_filesystem().scan() ResourceLoader.load(file_path, "WhisperResource", 2) print("Download successful. Check " + file_path) -func _on_button_pressed(): - var http_request = HTTPRequest.new() + +## Handle button press to start the download +func _on_button_pressed() -> void: + var http_request := HTTPRequest.new() add_child(http_request) http_request.use_threads = true DirAccess.make_dir_recursive_absolute("res://addons/godot_whisper/models") - var model = option_button.get_item_text(option_button.get_selected_id()) - var file_path : String = "res://addons/godot_whisper/models/gglm-" + model + ".bin" + var model: String = option_button.get_item_text(option_button.get_selected_id()) + var file_path: String = "res://addons/godot_whisper/models/gglm-" + model + ".bin" http_request.request_completed.connect(self._http_request_completed.bind(file_path)) http_request.download_file = file_path - var url : String = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-" + model + ".bin?download=true" + var url: String = ( + "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-" + + model + + ".bin?download=true" + ) print("Downloading file from " + url) # Perform a GET request. The URL below returns JSON as of writing. - var error = http_request.request(url) + var error: int = http_request.request(url) + # Handle HTTP request error if error != OK: push_error("An error occurred in the HTTP request.")