diff --git a/Models/STT/faster_whisper.py b/Models/STT/faster_whisper.py index 9667cfb..db2c7ec 100644 --- a/Models/STT/faster_whisper.py +++ b/Models/STT/faster_whisper.py @@ -1,3 +1,6 @@ +import gc + +import torch from faster_whisper import WhisperModel from pathlib import Path @@ -537,7 +540,7 @@ def needs_download(model: str, compute_type: str = "float32"): if compute_type not in MODEL_LINKS[model]: if compute_type == "float32": model_path = Path(model_cache_path / (model + "-ct2-fp16")) - if compute_type == "float16": + elif compute_type == "float16": model_path = Path(model_cache_path / (model + "-ct2")) pretrained_lang_model_file = Path(model_path / "model.bin") @@ -564,11 +567,10 @@ def download_model(model: str, compute_type: str = "float32"): if compute_type == "float32": compute_type = "float16" model_path = Path(model_cache_path / (model + "-ct2-fp16")) - if compute_type == "float16": + elif compute_type == "float16": compute_type = "float32" model_path = Path(model_cache_path / (model + "-ct2")) - pretrained_lang_model_file = Path(model_path / "model.bin") if not Path(model_path).exists() or not pretrained_lang_model_file.is_file(): @@ -595,24 +597,62 @@ def download_model(model: str, compute_type: str = "float32"): class FasterWhisper(metaclass=SingletonMeta): model = None loaded_model_size = "" + loaded_settings = {} + + transcription_count = 0 + reload_after_transcriptions = 0 def __init__(self, model: str, device: str = "cpu", compute_type: str = "float32", cpu_threads: int = 0, num_workers: int = 1): if self.model is None: self.load_model(model, device, compute_type, cpu_threads, num_workers) + def set_reload_after_transcriptions(self, reload_after_transcriptions: int): + self.reload_after_transcriptions = reload_after_transcriptions + + def release_model(self): + print("Reloading model...") + if self.model is not None: + if hasattr(self.model, 'model'): + del self.model.model + if hasattr(self.model, 'feature_extractor'): + del self.model.feature_extractor + if hasattr(self.model, 'hf_tokenizer'): + del self.model.hf_tokenizer + del self.model + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + self.load_model( + self.loaded_settings["model"], + self.loaded_settings["device"], + self.loaded_settings["compute_type"], + self.loaded_settings["cpu_threads"], + self.loaded_settings["num_workers"], + ) + def load_model(self, model: str, device: str = "cpu", compute_type: str = "float32", cpu_threads: int = 0, num_workers: int = 1): + + self.loaded_settings = { + "model": model, + "device": device, + "compute_type": compute_type, + "cpu_threads": cpu_threads, + "num_workers": num_workers + } + model_cache_path = Path(".cache/whisper") os.makedirs(model_cache_path, exist_ok=True) model_folder_name = model + "-ct2" - if compute_type == "float16" or compute_type == "int8_float16": + if compute_type == "float16" or compute_type == "int8_float16" or compute_type == "int16" or compute_type == "int8": model_folder_name = model + "-ct2-fp16" # special case for models that are only available in one precision (as float16 vs float32 showed no difference in large-v3 and distilled versions) if compute_type not in MODEL_LINKS[model]: if compute_type == "float32": model_folder_name = model + "-ct2-fp16" - if compute_type == "float16": + elif compute_type == "float16": model_folder_name = model + "-ct2" model_path = Path(model_cache_path / model_folder_name) @@ -622,14 +662,15 @@ def load_model(self, model: str, device: str = "cpu", compute_type: str = "float # 
temporary fix for large-v3 loading (https://github.com/guillaumekln/faster-whisper/issues/547) # @TODO: this is a temporary fix for large-v3 - n_mels = 80 - use_tf_tokenizer = False - if model == "large-v3": - n_mels = 128 - #use_tf_tokenizer = True + #n_mels = 80 + #use_tf_tokenizer = False + #if model == "large-v3": + # n_mels = 128 + #self.model = WhisperModel(str(Path(model_path).resolve()), device=device, compute_type=compute_type, + # cpu_threads=cpu_threads, num_workers=num_workers, feature_size=n_mels, use_tf_tokenizer=use_tf_tokenizer) self.model = WhisperModel(str(Path(model_path).resolve()), device=device, compute_type=compute_type, - cpu_threads=cpu_threads, num_workers=num_workers, feature_size=n_mels, use_tf_tokenizer=use_tf_tokenizer) + cpu_threads=cpu_threads, num_workers=num_workers) def transcribe(self, audio_sample, task, language, condition_on_previous_text, initial_prompt, logprob_threshold, no_speech_threshold, temperature, beam_size, @@ -660,4 +701,8 @@ def transcribe(self, audio_sample, task, language, condition_on_previous_text, 'language': audio_info.language } + #self.transcription_count += 1 + #if self.reload_after_transcriptions > 0 and (self.transcription_count % self.reload_after_transcriptions == 0): + # self.release_model() + return result diff --git a/Models/STT/nemo_canary.py b/Models/STT/nemo_canary.py new file mode 100644 index 0000000..470056b --- /dev/null +++ b/Models/STT/nemo_canary.py @@ -0,0 +1,292 @@ +import json +import os + +import torch + +import yaml +from nemo.collections.asr.models import EncDecMultiTaskModel +from Models.Singleton import SingletonMeta + +from pathlib import Path +import downloader + +import soundfile as sf +import tempfile + +#try: +# from pytorch_quantization import nn as quant_nn +# from pytorch_quantization import quant_modules +#except ImportError: +# raise ImportError( +# "pytorch-quantization is not installed. Install from " +# "https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization." 
+# ) + +LANGUAGES = { + "en": "English", + "de": "German", + "fr": "French", + "es": "Spanish", +} + + +class NemoCanary(metaclass=SingletonMeta): + model = None + previous_model = None + processor = None + compute_type = "float32" + compute_device = "cpu" + + sample_rate = 16000 + + text_correction_model = None + + currently_downloading = False + model_cache_path = Path(".cache/nemo-canary") + MODEL_LINKS = {} + MODELS_LIST_URLS = [ + "https://usc1.contabostorage.com/8fcf133c506f4e688c7ab9ad537b5c18:ai-models/nemo-canary/models.yaml", + "https://eu2.contabostorage.com/bf1a89517e2643359087e5d8219c0c67:ai-models/nemo-canary/models.yaml", + "https://s3.libs.space:9000/ai-models/nemo-canary/models.yaml", + ] + _debug_skip_dl = False + + def __init__(self, compute_type="float32", device="cpu"): + os.makedirs(self.model_cache_path, exist_ok=True) + self.compute_type = compute_type + self.compute_device = device + + self.load_model_list() + + #if self._debug_skip_dl: + # # generate models.yaml + # self.generate_models_yaml(self.model_cache_path, "models.yaml") + + @staticmethod + def get_languages(): + return tuple([{"code": code, "name": language} for code, language in LANGUAGES.items()]) + + def _str_to_dtype_dict(self, dtype_str): + if dtype_str == "float16": + return {'dtype': torch.float16, '4bit': False, '8bit': False} + elif dtype_str == "float32": + return {'dtype': torch.float32, '4bit': False, '8bit': False} + elif dtype_str == "4bit": + return {'dtype': torch.float32, '4bit': True, '8bit': False} + elif dtype_str == "8bit": + return {'dtype': torch.float16, '4bit': False, '8bit': True} + else: + return {'dtype': torch.float16, '4bit': False, '8bit': False} + + def set_compute_type(self, compute_type): + self.compute_type = compute_type + + def set_compute_device(self, device): + self.compute_device = device + + def load_model_list(self): + if not self._debug_skip_dl: + if not downloader.download_extract(self.MODELS_LIST_URLS, + str(self.model_cache_path.resolve()), + '', title="Speech 2 Text (NeMo Canary Model list)", extract_format="none"): + print("Model list not downloaded. 
Using cached version.") + + # Load model list + if Path(self.model_cache_path / "models.yaml").exists(): + with open(self.model_cache_path / "models.yaml", "r") as file: + self.MODEL_LINKS = yaml.load(file, Loader=yaml.FullLoader) + file.close() + + def download_model(self, model_name): + model_directory = Path(self.model_cache_path / model_name) + os.makedirs(str(model_directory.resolve()), exist_ok=True) + + # if one of the files does not exist, break the loop and download the files + needs_download = False + for file in self.MODEL_LINKS[model_name]["files"]: + if not Path(model_directory / Path(file["urls"][0]).name).exists(): + needs_download = True + break + + if not needs_download: + for file in self.MODEL_LINKS[model_name]["files"]: + if Path(file["urls"][0]).name == "WS_VERSION": + checksum = downloader.sha256_checksum(str(model_directory.resolve() / Path(file["urls"][0]).name)) + if checksum != file["checksum"]: + needs_download = True + break + + # iterate over all self.MODEL_LINKS[model_name]["files"] entries and download them + if needs_download and not self.currently_downloading: + self.currently_downloading = True + for file in self.MODEL_LINKS[model_name]["files"]: + if not downloader.download_extract(file["urls"], + str(model_directory.resolve()), + file["checksum"], title="Speech 2 Text (NeMo Canary) - " + model_name, extract_format="none"): + print(f"Download failed: {file}") + + self.currently_downloading = False + + def load_model(self, model='canary-1b', compute_type="float32", device="cpu"): + #self.model = EncDecMultiTaskModel.from_pretrained('nvidia/canary-1b') + #if self.model is None: + + if not self._debug_skip_dl: + self.download_model(model) + + torch.set_grad_enabled(False) + + #quant_modules.initialize() + + if self.previous_model is None or self.model is None or model != self.previous_model: + print(f"Loading NeMo Canary model: {model} on {device} with {compute_type} precision...") + self.model = EncDecMultiTaskModel.restore_from(str(Path(self.model_cache_path / model / (model+".nemo")).resolve()), map_location=torch.device(device)) + #self.model.half() + #self.model.cuda() + self.model.eval() + self.previous_model = model + + def generate_models_yaml(self, directory, filename): + # Prepare the data + data = {} + + # Iterate through the directory + for root, dirs, files in os.walk(directory): + ws_version_file = None + # Get the model name from the directory name + model_name = os.path.basename(root) + for file in files: + # Calculate the SHA256 checksum + checksum = downloader.sha256_checksum(os.path.join(root, file)) + + # Initialize the model in the data dictionary if it doesn't exist + if model_name not in data: + data[model_name] = { + 'files': [] + } + + # Add the file details to the model's files list + file_data = { + 'urls': [ + f'https://usc1.contabostorage.com/8fcf133c506f4e688c7ab9ad537b5c18:ai-models/nemo-canary/{model_name}/{file}', + f'https://eu2.contabostorage.com/bf1a89517e2643359087e5d8219c0c67:ai-models/nemo-canary/{model_name}/{file}', + f'https://s3.libs.space:9000/ai-models/nemo-canary/{model_name}/{file}' + ], + 'checksum': checksum + } + if file == "WS_VERSION": + ws_version_file = file_data + else: + data[model_name]['files'].append(file_data) + + if ws_version_file is not None: + data[model_name]['files'].insert(0, ws_version_file) + + # Write to YAML file + with open(os.path.join(directory, filename), 'w') as file: + yaml.dump(data, file, default_flow_style=False) + + def transcribe(self, audio_sample, task, source_lang=None, 
target_lang='en', + return_timestamps=False, **kwargs) -> dict: + + model = "canary-1b" + if "model" in kwargs: + model = kwargs["model"] + + self.load_model(model, self.compute_type, self.compute_device) + + beam_size = 4 + if "beam_size" in kwargs: + beam_size = kwargs["beam_size"] + length_penalty = 1.0 + if "length_penalty" in kwargs: + length_penalty = kwargs["length_penalty"] + temperature = 1.0 + if "temperature" in kwargs: + temperature = kwargs["temperature"] + + #taskname = "asr" + #if task == "transcription": + # taskname = "asr" + # source_lang = target_lang + #if task == "translation": + # taskname = "ast" + + # transcription + if source_lang == target_lang: + taskname = "asr" + # translation + else: + taskname = "s2t_translation" + + self.model.change_decoding_strategy(None) + decode_cfg = self.model.cfg.decoding + changed_cfg = False + if beam_size != decode_cfg.beam.beam_size: + decode_cfg.beam.beam_size = beam_size + changed_cfg = True + if length_penalty != decode_cfg.beam.len_pen: + decode_cfg.beam.len_pen = length_penalty + changed_cfg = True + if temperature != decode_cfg.temperature: + decode_cfg.temperature = temperature + changed_cfg = True + + if changed_cfg: + self.model.change_decoding_strategy(decode_cfg) + + # setup for buffered inference + self.model.cfg.preprocessor.dither = 0.0 + self.model.cfg.preprocessor.pad_to = 0 + + #feature_stride = self.model.cfg.preprocessor['window_stride'] + #model_stride_in_secs = feature_stride * 8 # 8 = model stride, which is 8 for FastConformer + + #transcript = self.model.transcribe([audio_sample], batch_size=8, num_workers=2, taskname=task, source_lang=source_lang, target_lang=target_lang,) + #transcript = self.model.transcribe([audio_sample], batch_size=8, num_workers=2,) + + with tempfile.TemporaryDirectory() as tmpdirname: + audio_path = os.path.join(tmpdirname, "audio.wav") + # Save the numpy array as a WAV file + sf.write(audio_path, audio_sample, self.sample_rate, 'PCM_16') + + # calculate audio duration + number_of_samples = audio_sample.shape[0] + duration_seconds = number_of_samples / self.sample_rate + + # Prepare the manifest data + manifest_data = [{ + "audio_filepath": audio_path, + "duration": duration_seconds, + "taskname": taskname, + "source_lang": source_lang, + "target_lang": target_lang, + "pnc": "yes", + #"answer": "na", + "answer": "predict", + }] + + manifest_path = os.path.join(tmpdirname, "manifest.json") + with open(manifest_path, "w") as manifest_file: + for entry in manifest_data: + manifest_file.write(json.dumps(entry) + "\n") + + compute_type = self._str_to_dtype_dict(self.compute_type).get('dtype', torch.float32) + + # Transcribe using the model + if not self.compute_device.startswith("cuda"): + with torch.no_grad(): + predicted_text = self.model.transcribe(manifest_path, batch_size=16) + else: + with torch.cuda.amp.autocast(dtype=compute_type): + with torch.no_grad(): + predicted_text = self.model.transcribe(manifest_path, batch_size=16) + + result = { + 'text': "".join(predicted_text), + 'type': "transcribe", + 'language': source_lang, + 'target_lang': target_lang + } + + return result diff --git a/Models/STT/tansformer_whisper.py b/Models/STT/tansformer_whisper.py new file mode 100644 index 0000000..56c5c13 --- /dev/null +++ b/Models/STT/tansformer_whisper.py @@ -0,0 +1,218 @@ +import os + +import torch +import gc + +import yaml +from transformers import WhisperForConditionalGeneration, WhisperProcessor +from Models.Singleton import SingletonMeta + +from pathlib import Path +import 
downloader + + +class TransformerWhisper(metaclass=SingletonMeta): + model = None + previous_model = None + processor = None + compute_type = "float32" + compute_device = "cpu" + + text_correction_model = None + + currently_downloading = False + model_cache_path = Path(".cache/whisper-transformer") + MODEL_LINKS = {} + MODELS_LIST_URLS = [ + "https://usc1.contabostorage.com/8fcf133c506f4e688c7ab9ad537b5c18:ai-models/whisper-transformer/models.yaml", + "https://eu2.contabostorage.com/bf1a89517e2643359087e5d8219c0c67:ai-models/whisper-transformer/models.yaml", + "https://s3.libs.space:9000/ai-models/whisper-transformer/models.yaml", + ] + _debug_skip_dl = False + + def __init__(self, compute_type="float32", device="cpu"): + os.makedirs(self.model_cache_path, exist_ok=True) + self.compute_type = compute_type + self.compute_device = device + self.load_model_list() + + #if self._debug_skip_dl: + # # generate models.yaml + # self.generate_models_yaml(self.model_cache_path, "models.yaml") + + def _str_to_dtype_dict(self, dtype_str): + if dtype_str == "float16": + return {'dtype': torch.float16, '4bit': False, '8bit': False} + elif dtype_str == "float32": + return {'dtype': torch.float32, '4bit': False, '8bit': False} + elif dtype_str == "4bit": + return {'dtype': torch.float32, '4bit': True, '8bit': False} + elif dtype_str == "8bit": + return {'dtype': torch.float16, '4bit': False, '8bit': True} + else: + return {'dtype': torch.float16, '4bit': False, '8bit': False} + + def set_compute_type(self, compute_type): + self.compute_type = compute_type + + def set_compute_device(self, device): + self.compute_device = device + + def load_model_list(self): + if not self._debug_skip_dl: + if not downloader.download_extract(self.MODELS_LIST_URLS, + str(self.model_cache_path.resolve()), + '', title="Speech 2 Text (Whisper-Transformer Model list)", extract_format="none"): + print("Model list not downloaded. 
Using cached version.") + + # Load model list + if Path(self.model_cache_path / "models.yaml").exists(): + with open(self.model_cache_path / "models.yaml", "r") as file: + self.MODEL_LINKS = yaml.load(file, Loader=yaml.FullLoader) + file.close() + + def download_model(self, model_name): + model_directory = Path(self.model_cache_path / model_name) + os.makedirs(str(model_directory.resolve()), exist_ok=True) + + # if one of the files does not exist, break the loop and download the files + needs_download = False + for file in self.MODEL_LINKS[model_name]["files"]: + if not Path(model_directory / Path(file["urls"][0]).name).exists(): + needs_download = True + break + + if not needs_download: + for file in self.MODEL_LINKS[model_name]["files"]: + if Path(file["urls"][0]).name == "WS_VERSION": + checksum = downloader.sha256_checksum(str(model_directory.resolve() / Path(file["urls"][0]).name)) + if checksum != file["checksum"]: + needs_download = True + break + + # iterate over all self.MODEL_LINKS[model_name]["files"] entries and download them + if needs_download and not self.currently_downloading: + self.currently_downloading = True + for file in self.MODEL_LINKS[model_name]["files"]: + if not downloader.download_extract(file["urls"], + str(model_directory.resolve()), + file["checksum"], title="Speech 2 Text (Whisper-Transformer) - " + model_name, extract_format="none"): + print(f"Download failed: {file}") + + self.currently_downloading = False + + def load_model(self, model='small', compute_type="float32", device="cpu"): + if self.previous_model is None or model != self.previous_model: + compute_dtype = self._str_to_dtype_dict(compute_type).get('dtype', torch.float32) + compute_4bit = self._str_to_dtype_dict(self.compute_type).get('4bit', False) + compute_8bit = self._str_to_dtype_dict(self.compute_type).get('8bit', False) + self.compute_type = compute_type + + self.compute_device = device + + if not self._debug_skip_dl: + self.download_model(model) + + if self.model is None or model != self.previous_model: + if self.model is not None: + self.release_model() + + self.previous_model = model + self.release_model() + print(f"Loading Whisper-Transformer model: {model} on {device} with {compute_type} precision...") + self.model = WhisperForConditionalGeneration.from_pretrained(str(Path(self.model_cache_path / model).resolve()), torch_dtype=compute_dtype, load_in_8bit=compute_8bit, load_in_4bit=compute_4bit) + if not compute_8bit and not compute_4bit: + self.model = self.model.to(self.compute_device) + self.processor = WhisperProcessor.from_pretrained(str(Path(self.model_cache_path / model).resolve())) + + self.model.config.forced_decoder_ids = None + + def transcribe(self, audio_sample, model, task, language, + return_timestamps=False, beam_size=4) -> dict: + self.load_model(model, self.compute_type, self.compute_device) + + compute_dtype = self._str_to_dtype_dict(self.compute_type).get('dtype', torch.float32) + + if self.model is not None and self.processor is not None: + input_features = self.processor(audio_sample, sampling_rate=16000, return_tensors="pt").to(self.compute_device).to(compute_dtype).input_features + + transcriptions = [""] + + with torch.no_grad(): + predicted_ids = self.model.generate(input_features, + task=task, language=language, num_beams=beam_size, + return_timestamps=return_timestamps, + ) + transcriptions = self.processor.batch_decode(predicted_ids, skip_special_tokens=True) + + #result_text = self.processor.tokenizer._normalize(transcription) + + result_text = 
''.join(transcriptions).strip() + + return { + 'text': result_text, + 'type': task, + 'language': language + } + else: + return { + 'text': "", + 'type': task, + 'language': language + } + + def release_model(self): + print("Releasing Whisper-Transformer model...") + if self.model is not None: + if hasattr(self.model, 'model'): + del self.model.model + if hasattr(self.model, 'feature_extractor'): + del self.model.feature_extractor + if hasattr(self.model, 'hf_tokenizer'): + del self.model.hf_tokenizer + del self.model + if self.processor is not None: + del self.processor + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + def generate_models_yaml(self, directory, filename): + # Prepare the data + data = {} + + # Iterate through the directory + for root, dirs, files in os.walk(directory): + ws_version_file = None + # Get the model name from the directory name + model_name = os.path.basename(root) + for file in files: + # Calculate the SHA256 checksum + checksum = downloader.sha256_checksum(os.path.join(root, file)) + + # Initialize the model in the data dictionary if it doesn't exist + if model_name not in data: + data[model_name] = { + 'files': [] + } + + # Add the file details to the model's files list + file_data = { + 'urls': [ + f'https://usc1.contabostorage.com/8fcf133c506f4e688c7ab9ad537b5c18:ai-models/whisper-transformer/{model_name}/{file}', + f'https://eu2.contabostorage.com/bf1a89517e2643359087e5d8219c0c67:ai-models/whisper-transformer/{model_name}/{file}', + f'https://s3.libs.space:9000/ai-models/whisper-transformer/{model_name}/{file}' + ], + 'checksum': checksum + } + if file == "WS_VERSION": + ws_version_file = file_data + else: + data[model_name]['files'].append(file_data) + + if ws_version_file is not None: + data[model_name]['files'].insert(0, ws_version_file) + + # Write to YAML file + with open(os.path.join(directory, filename), 'w') as file: + yaml.dump(data, file, default_flow_style=False) diff --git a/Models/STT/wav2vec_bert.py b/Models/STT/wav2vec_bert.py new file mode 100644 index 0000000..437ebeb --- /dev/null +++ b/Models/STT/wav2vec_bert.py @@ -0,0 +1,230 @@ +import os + +import torch +import gc + +import yaml +from transformers import Wav2Vec2BertForCTC, Wav2Vec2BertProcessor +from Models.Singleton import SingletonMeta +from Models.TextCorrection import T5 + +from pathlib import Path +import downloader + + +class Wav2VecBert(metaclass=SingletonMeta): + model = None + previous_model = None + processor = None + compute_type = "float32" + compute_device = "cpu" + + text_correction_model = None + + currently_downloading = False + model_cache_path = Path(".cache/wav2vec-bert2.0") + MODEL_LINKS = {} + MODELS_LIST_URLS = [ + "https://usc1.contabostorage.com/8fcf133c506f4e688c7ab9ad537b5c18:ai-models/Wav2VecBert/models.yaml", + "https://eu2.contabostorage.com/bf1a89517e2643359087e5d8219c0c67:ai-models/Wav2VecBert/models.yaml", + "https://s3.libs.space:9000/ai-models/Wav2VecBert/models.yaml", + ] + _debug_skip_dl = False + + def __init__(self, compute_type="float32", device="cpu"): + os.makedirs(self.model_cache_path, exist_ok=True) + self.compute_type = compute_type + self.compute_device = device + self.load_model_list() + + #if self._debug_skip_dl: + # # generate models.yaml + # self.generate_models_yaml(self.model_cache_path, "models.yaml") + + def _str_to_dtype_dict(self, dtype_str): + if dtype_str == "float16": + return {'dtype': torch.float16, '4bit': False, '8bit': False} + elif dtype_str == "float32": + return {'dtype': torch.float32, 
'4bit': False, '8bit': False} + elif dtype_str == "4bit": + return {'dtype': torch.float32, '4bit': True, '8bit': False} + elif dtype_str == "8bit": + return {'dtype': torch.float32, '4bit': False, '8bit': True} + else: + return {'dtype': torch.float32, '4bit': False, '8bit': False} + + def set_compute_type(self, compute_type): + self.compute_type = compute_type + + def set_compute_device(self, device): + self.compute_device = device + + def load_model_list(self): + if not self._debug_skip_dl: + if not downloader.download_extract(self.MODELS_LIST_URLS, + str(self.model_cache_path.resolve()), + '', title="Speech 2 Text (Wav2VecBert2 Model list)", extract_format="none"): + print("Model list not downloaded. Using cached version.") + + # Load model list + if Path(self.model_cache_path / "models.yaml").exists(): + with open(self.model_cache_path / "models.yaml", "r") as file: + self.MODEL_LINKS = yaml.load(file, Loader=yaml.FullLoader) + file.close() + + def get_languages(self): + if not self.MODEL_LINKS: + # Return a default value or message. Here, we return an empty tuple as a fallback. + return () + + # Generate a list of dictionaries, each containing the language code and language name + languages = [] + for language, details in self.MODEL_LINKS.items(): + # Extract the lang_code for the current language entry + lang_name = details.get("lang_name", "") # Fallback to an empty string if not found + languages.append({"code": language, "name": lang_name}) + return tuple(languages) + + def download_model(self, model_name): + model_directory = Path(self.model_cache_path / model_name) + os.makedirs(str(model_directory.resolve()), exist_ok=True) + + # if one of the files does not exist, break the loop and download the files + needs_download = False + for file in self.MODEL_LINKS[model_name]["files"]: + if not Path(model_directory / Path(file["urls"][0]).name).exists(): + needs_download = True + break + + if not needs_download: + for file in self.MODEL_LINKS[model_name]["files"]: + if Path(file["urls"][0]).name == "WS_VERSION": + checksum = downloader.sha256_checksum(str(model_directory.resolve() / Path(file["urls"][0]).name)) + if checksum != file["checksum"]: + needs_download = True + break + + # iterate over all self.MODEL_LINKS[model_name]["files"] entries and download them + if needs_download and not self.currently_downloading: + self.currently_downloading = True + for file in self.MODEL_LINKS[model_name]["files"]: + if not downloader.download_extract(file["urls"], + str(model_directory.resolve()), + file["checksum"], title="Speech 2 Text (Wav2VecBert2) - " + model_name, extract_format="none"): + print(f"Download failed: {file}") + + self.currently_downloading = False + + def load_model(self, model='english', compute_type="float32", device="cpu"): + if self.previous_model is None or model != self.previous_model: + compute_dtype = self._str_to_dtype_dict(compute_type).get('dtype', torch.float32) + compute_4bit = self._str_to_dtype_dict(self.compute_type).get('4bit', False) + compute_8bit = self._str_to_dtype_dict(self.compute_type).get('8bit', False) + self.compute_type = compute_type + + self.compute_device = device + + if not self._debug_skip_dl: + self.download_model(model) + + if self.model is None or model != self.previous_model: + if self.model is not None: + self.release_model() + + self.previous_model = model + self.release_model() + print(f"Loading wav2vec model: {model} on {device} with {compute_type} precision...") + self.model = 
Wav2Vec2BertForCTC.from_pretrained(str(Path(self.model_cache_path / model).resolve()), torch_dtype=compute_dtype, load_in_8bit=compute_8bit, load_in_4bit=compute_4bit) + if not compute_8bit and not compute_4bit: + self.model = self.model.to(self.compute_device) + self.processor = Wav2Vec2BertProcessor.from_pretrained(str(Path(self.model_cache_path / model).resolve())) + + # load text correction model + self.text_correction_model = T5.TextCorrectionT5(compute_type, device) + + def transcribe(self, audio_sample, task, language) -> dict: + self.load_model(language, self.compute_type, self.compute_device) + + compute_dtype = self._str_to_dtype_dict(self.compute_type).get('dtype', torch.float32) + + if self.model is not None and self.processor is not None: + input_features = self.processor(audio=audio_sample, sampling_rate=16000, return_tensors="pt").to(self.compute_device).to(compute_dtype) + + with torch.no_grad(): + logits = self.model(**input_features).logits + + pred_ids = torch.argmax(logits, dim=-1) + + result_text = self.processor.batch_decode(pred_ids) + + if self.text_correction_model is not None and result_text[0] != "": + result_text[0] = self.text_correction_model.translate(result_text[0], language) + + return { + 'text': result_text[0], + 'type': task, + 'language': language + } + else: + return { + 'text': "", + 'type': task, + 'language': language + } + + def release_model(self): + if self.model is not None: + print("Releasing wav2vec model...") + if hasattr(self.model, 'model'): + del self.model.model + if hasattr(self.model, 'feature_extractor'): + del self.model.feature_extractor + if hasattr(self.model, 'hf_tokenizer'): + del self.model.hf_tokenizer + del self.model + if self.processor is not None: + del self.processor + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + def generate_models_yaml(self, directory, filename): + # Prepare the data + data = {} + + # Iterate through the directory + for root, dirs, files in os.walk(directory): + ws_version_file = None + # Get the model name from the directory name + model_name = os.path.basename(root) + for file in files: + # Calculate the SHA256 checksum + checksum = downloader.sha256_checksum(os.path.join(root, file)) + + # Initialize the model in the data dictionary if it doesn't exist + if model_name not in data: + data[model_name] = { + 'lang_name': model_name.capitalize(), + 'files': [] + } + + # Add the file details to the model's files list + file_data = { + 'urls': [ + f'https://usc1.contabostorage.com/8fcf133c506f4e688c7ab9ad537b5c18:ai-models/Wav2VecBert/{model_name}/{file}', + f'https://eu2.contabostorage.com/bf1a89517e2643359087e5d8219c0c67:ai-models/Wav2VecBert/{model_name}/{file}', + f'https://s3.libs.space:9000/ai-models/Wav2VecBert/{model_name}/{file}' + ], + 'checksum': checksum + } + if file == "WS_VERSION": + ws_version_file = file_data + else: + data[model_name]['files'].append(file_data) + + if ws_version_file is not None: + data[model_name]['files'].insert(0, ws_version_file) + + # Write to YAML file + with open(os.path.join(directory, filename), 'w') as file: + yaml.dump(data, file, default_flow_style=False) diff --git a/Models/STT/whisper_audio_markers.py b/Models/STT/whisper_audio_markers.py index d9e4a1d..612c398 100644 --- a/Models/STT/whisper_audio_markers.py +++ b/Models/STT/whisper_audio_markers.py @@ -5,6 +5,7 @@ import audio_tools import settings +from Models.Singleton import SingletonMeta language_mapping_iso3_to_iso_1 = { "eng": "en", @@ -24,7 +25,7 @@ def 
iso3_to_iso_1(iso3): return iso3 -class WhisperVoiceMarker: +class WhisperVoiceMarker(metaclass=SingletonMeta): audio_sample = None audio_model = None try_count = 0 diff --git a/Models/TTS/silero.py b/Models/TTS/silero.py index 6467ee8..9cf7372 100644 --- a/Models/TTS/silero.py +++ b/Models/TTS/silero.py @@ -423,6 +423,7 @@ def tts(self, text): audio = plugin_audio['audio'] except Exception as e: + print(e) return None, None return audio, self.sample_rate diff --git a/Models/TextCorrection/T5.py b/Models/TextCorrection/T5.py new file mode 100644 index 0000000..153f534 --- /dev/null +++ b/Models/TextCorrection/T5.py @@ -0,0 +1,228 @@ +import os +import re + +import torch +import gc + +import yaml +from transformers import T5Tokenizer, T5ForConditionalGeneration +from Models.Singleton import SingletonMeta + +from pathlib import Path +import downloader + + +# https://huggingface.co/vennify/t5-base-grammar-correction +# https://huggingface.co/flexudy/t5-small-wav2vec2-grammar-fixer +# https://huggingface.co/aiassociates/t5-small-grammar-correction-german + +class TextCorrectionT5(metaclass=SingletonMeta): + model = None + previous_model = None + tokenizer = None + compute_type = "float32" + compute_device = "cpu" + prompt_template = "{text}" + capitalize_text = False + cleanup = None + + currently_downloading = False + model_cache_path = Path(".cache/text_correction_t5") + MODEL_LINKS = {} + MODELS_LIST_URLS = [ + "https://usc1.contabostorage.com/8fcf133c506f4e688c7ab9ad537b5c18:ai-models/TextCorrectionT5/models.yaml", + "https://eu2.contabostorage.com/bf1a89517e2643359087e5d8219c0c67:ai-models/TextCorrectionT5/models.yaml", + "https://s3.libs.space:9000/ai-models/TextCorrectionT5/models.yaml", + ] + _debug_skip_dl = False + + def __init__(self, compute_type="float32", device="cpu"): + os.makedirs(self.model_cache_path, exist_ok=True) + self.compute_type = compute_type + self.compute_device = device + self.load_model_list() + + #if self._debug_skip_dl: + # # generate models.yaml + # self.generate_models_yaml(self.model_cache_path, "models.yaml") + + def _str_to_dtype_dict(self, dtype_str): + if dtype_str == "float16": + return {'dtype': torch.float16, '4bit': False, '8bit': False} + elif dtype_str == "float32": + return {'dtype': torch.float32, '4bit': False, '8bit': False} + elif dtype_str == "4bit": + return {'dtype': torch.float32, '4bit': True, '8bit': False} + elif dtype_str == "8bit": + return {'dtype': torch.float32, '4bit': False, '8bit': True} + else: + return {'dtype': torch.float32, '4bit': False, '8bit': False} + + def set_compute_type(self, compute_type): + self.compute_type = compute_type + + def set_compute_device(self, device): + self.compute_device = device + + def load_model_list(self): + if not self._debug_skip_dl: + if not downloader.download_extract(self.MODELS_LIST_URLS, + str(self.model_cache_path.resolve()), + '', title="Text Correction (T5 Model list)", extract_format="none"): + print("Model list not downloaded. 
Using cached version.") + + # Load model list + if Path(self.model_cache_path / "models.yaml").exists(): + with open(self.model_cache_path / "models.yaml", "r") as file: + self.MODEL_LINKS = yaml.load(file, Loader=yaml.FullLoader) + file.close() + + def download_model(self, model_name): + model_directory = Path(self.model_cache_path / model_name) + os.makedirs(str(model_directory.resolve()), exist_ok=True) + + # if one of the files does not exist, break the loop and download the files + needs_download = False + for file in self.MODEL_LINKS[model_name]["files"]: + if not Path(model_directory / Path(file["urls"][0]).name).exists(): + needs_download = True + break + + if not needs_download: + for file in self.MODEL_LINKS[model_name]["files"]: + if Path(file["urls"][0]).name == "WS_VERSION": + checksum = downloader.sha256_checksum(str(model_directory.resolve() / Path(file["urls"][0]).name)) + if checksum != file["checksum"]: + needs_download = True + break + + # iterate over all self.MODEL_LINKS[model_name]["files"] entries and download them + if needs_download and not self.currently_downloading: + self.currently_downloading = True + for file in self.MODEL_LINKS[model_name]["files"]: + if not downloader.download_extract(file["urls"], + str(model_directory.resolve()), + file["checksum"], title="Text Correction (T5) - " + model_name, extract_format="none"): + print(f"Download failed: {file}") + + self.currently_downloading = False + + def load_model(self, model='english', compute_type="float32", device="cpu"): + if self.previous_model is None or model != self.previous_model: + compute_dtype = self._str_to_dtype_dict(compute_type).get('dtype', torch.float32) + compute_4bit = self._str_to_dtype_dict(self.compute_type).get('4bit', False) + compute_8bit = self._str_to_dtype_dict(self.compute_type).get('8bit', False) + self.compute_type = compute_type + + self.compute_device = device + + if not self._debug_skip_dl: + self.download_model(model) + + if self.model is None or model != self.previous_model: + if self.model is not None: + self.release_model() + + self.previous_model = model + self.release_model() + print(f"Loading T5 model: {model} on {device} with {compute_type} precision...") + self.model = T5ForConditionalGeneration.from_pretrained(str(Path(self.model_cache_path / model).resolve()), torch_dtype=compute_dtype, load_in_8bit=compute_8bit, load_in_4bit=compute_4bit) + if not compute_8bit and not compute_4bit: + self.model = self.model.to(self.compute_device) + self.tokenizer = T5Tokenizer.from_pretrained(str(Path(self.model_cache_path / model).resolve()), torch_dtype=compute_dtype) + self.prompt_template = self.MODEL_LINKS[model].get("prompt_template", "") + self.capitalize_text = self.MODEL_LINKS[model].get("capitalize", False) + self.cleanup = self.MODEL_LINKS[model].get("cleanup", None) + + def translate(self, text, language) -> str: + self.load_model(language, self.compute_type, self.compute_device) + + if self.model is not None and self.tokenizer is not None: + try: + input_text = self.prompt_template.format(text=text) + + input_ids = self.tokenizer.encode(input_text, return_tensors="pt", max_length=256, truncation=True, add_special_tokens=True).to(self.compute_device) + + with torch.no_grad(): + outputs = self.model.generate( + input_ids=input_ids, + max_length=256, + num_beams=4, + repetition_penalty=1.0, + length_penalty=1.0, + early_stopping=True + ) + + result_sentence = self.tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True) + 
result_sentence_text = ' '.join(result_sentence).strip() + + if self.cleanup is not None: + for cleanup_entry in self.cleanup: + if 'pattern' in cleanup_entry and 'replace' in cleanup_entry: + result_sentence_text = re.sub(cleanup_entry['pattern'], cleanup_entry['replace'], result_sentence_text) + + return result_sentence_text + except Exception as e: + print(f"Error: {e}") + return text + else: + return text + + def release_model(self): + if self.model is not None: + print("Releasing T5 model...") + if hasattr(self.model, 'model'): + del self.model.model + if hasattr(self.model, 'feature_extractor'): + del self.model.feature_extractor + if hasattr(self.model, 'hf_tokenizer'): + del self.model.hf_tokenizer + del self.model + if self.tokenizer is not None: + del self.tokenizer + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + def generate_models_yaml(self, directory, filename): + # Prepare the data + data = {} + + # Iterate through the directory + for root, dirs, files in os.walk(directory): + ws_version_file = None + # Get the model name from the directory name + model_name = os.path.basename(root) + for file in files: + # Calculate the SHA256 checksum + checksum = downloader.sha256_checksum(os.path.join(root, file)) + + # Initialize the model in the data dictionary if it doesn't exist + if model_name not in data: + data[model_name] = { + 'prompt_template': "{text}", + 'capitalize': False, + 'files': [], + 'cleanup': {'pattern': '(\?|!)\.', 'replace': '\\1'} + } + + # Add the file details to the model's files list + file_data = { + 'urls': [ + f'https://usc1.contabostorage.com/8fcf133c506f4e688c7ab9ad537b5c18:ai-models/TextCorrectionT5/{model_name}/{file}', + f'https://eu2.contabostorage.com/bf1a89517e2643359087e5d8219c0c67:ai-models/TextCorrectionT5/{model_name}/{file}', + f'https://s3.libs.space:9000/ai-models/TextCorrectionT5/{model_name}/{file}' + ], + 'checksum': checksum + } + if file == "WS_VERSION": + ws_version_file = file_data + else: + data[model_name]['files'].append(file_data) + + if ws_version_file is not None: + data[model_name]['files'].insert(0, ws_version_file) + + # Write to YAML file + with open(os.path.join(directory, filename), 'w') as file: + yaml.dump(data, file, default_flow_style=False) diff --git a/VRC_OSCLib.py b/VRC_OSCLib.py index afb85c3..2c03081 100644 --- a/VRC_OSCLib.py +++ b/VRC_OSCLib.py @@ -110,7 +110,7 @@ def set_min_time_between_messages(time_in_seconds): def _send_osc_message(): global last_message_sent_time, min_time_between_messages - print("min_time_between_messages: " + str(min_time_between_messages)) + #print("min_time_between_messages: " + str(min_time_between_messages)) while True: try: # Wait for a message to be available in the queue. This will block until a message is available. 
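The new model wrappers introduced in this patch (FasterWhisper.release_model, TransformerWhisper, Wav2VecBert, TextCorrectionT5) all follow the same release-and-reload pattern: delete the references to the loaded weights, empty the CUDA cache, run the garbage collector, and only then construct a replacement. A minimal standalone sketch of that pattern is shown below; the class and method names are illustrative only and are not part of the patch.

import gc

import torch


class ReloadableModel:
    # Minimal sketch of the release/reload pattern used by the wrappers above.
    # `loader` is any callable that builds and returns the underlying model.
    def __init__(self, loader):
        self.loader = loader
        self.model = loader()

    def release(self):
        # Drop the strong reference so the weights become collectable.
        if self.model is not None:
            del self.model
            self.model = None
        # Return cached GPU memory to the driver, then collect Python garbage.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    def reload(self):
        # Release first so the old and new weights never coexist in memory.
        self.release()
        self.model = self.loader()

Releasing before reloading keeps peak memory at roughly one model instead of two, which matches how the wrappers call their release_model() before the next from_pretrained/restore_from call.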
diff --git a/audioWhisper.py b/audioWhisper.py index 5b71701..c5f3abf 100644 --- a/audioWhisper.py +++ b/audioWhisper.py @@ -1,1211 +1,1231 @@ # -*- encoding: utf-8 -*- -import os -import platform -import sys -import json -import traceback +if __name__ == '__main__': + import multiprocessing + multiprocessing.freeze_support() -import Utilities -import downloader -import processmanager -import atexit + import os + import platform + import sys + import json + import traceback -from Models.TTS import silero + import Utilities + import downloader + import processmanager + import atexit -# set environment variable CT2_CUDA_ALLOW_FP16 to 1 (before ctranslate2 is imported) -# to allow using FP16 computation on GPU even if the device does not have efficient FP16 support. -os.environ["CT2_CUDA_ALLOW_FP16"] = "1" + from Models.TTS import silero + # set environment variable CT2_CUDA_ALLOW_FP16 to 1 (before ctranslate2 is imported) + # to allow using FP16 computation on GPU even if the device does not have efficient FP16 support. + os.environ["CT2_CUDA_ALLOW_FP16"] = "1" -atexit.register(processmanager.cleanup_subprocesses) + # enable fast GPU mode for safetensors (https://huggingface.co/docs/safetensors/speed) + os.environ["SAFETENSORS_FAST_GPU"] = "1" + atexit.register(processmanager.cleanup_subprocesses) -def handle_exception(exc_type, exc_value, exc_traceback): - error_msg = ''.join(traceback.format_exception(exc_type, exc_value, exc_traceback)) - print(error_msg, file=sys.stderr) # print to standard error stream + def handle_exception(exc_type, exc_value, exc_traceback): + error_msg = ''.join(traceback.format_exception(exc_type, exc_value, exc_traceback)) - # Format the traceback and error message as a JSON string - error_dict = { - 'type': "error", - 'message': str(exc_value), - 'traceback': traceback.format_tb(exc_traceback) - } - error_json = json.dumps(error_dict) - print(error_json, file=sys.stderr) # print to standard error stream + print(error_msg, file=sys.stderr) # print to standard error stream + # Format the traceback and error message as a JSON string + error_dict = { + 'type': "error", + 'message': str(exc_value), + 'traceback': traceback.format_tb(exc_traceback) + } + error_json = json.dumps(error_dict) + print(error_json, file=sys.stderr) # print to standard error stream -sys.excepthook = handle_exception -import io -import signal -import time -import threading + sys.excepthook = handle_exception -# import speech_recognition_patch as sr # this is a patched version of speech_recognition. (disabled for now because of freeze issues) -import speech_recognition as sr -import audioprocessor -from pathlib import Path -import click -import VRC_OSCLib -import websocket -import settings -import remote_opener -from Models.STT import faster_whisper -from Models.Multi import seamless_m4t -from Models.TextTranslation import texttranslate -from Models import languageClassification -from Models import sentence_split -import pyaudiowpatch as pyaudio -from whisper import available_models, audio as whisper_audio + import io + import signal + import time + import threading -import keyboard + # import speech_recognition_patch as sr # this is a patched version of speech_recognition. 
(disabled for now because of freeze issues) + import speech_recognition as sr + import audioprocessor + from pathlib import Path + import click + import VRC_OSCLib + import websocket + import settings + import remote_opener + from Models.STT import faster_whisper + from Models.Multi import seamless_m4t + from Models.TextTranslation import texttranslate + from Models import languageClassification + from Models import sentence_split + import pyaudiowpatch as pyaudio + from whisper import available_models, audio as whisper_audio -import numpy as np -import torch -import torchaudio -import audio_tools -import sounddevice as sd + import keyboard -import wave + import numpy as np + import torch -from Models.STS import DeepFilterNet + torch.backends.cudnn.benchmark = True + import audio_tools + import sounddevice as sd -def save_to_wav(data, filename, sample_rate, channels=1): - with wave.open(filename, 'wb') as wf: - wf.setnchannels(channels) - wf.setsampwidth(2) # Assuming 16-bit audio - wf.setframerate(sample_rate) - wf.writeframes(data) + import wave + from Models.STS import DeepFilterNet -#torchaudio.set_audio_backend("soundfile") -py_audio = pyaudio.PyAudio() -FORMAT = pyaudio.paInt16 -CHANNELS = 1 -SAMPLE_RATE = whisper_audio.SAMPLE_RATE -CHUNK = int(SAMPLE_RATE / 10) -cache_vad_path = Path(Path.cwd() / ".cache" / "silero-vad") -os.makedirs(cache_vad_path, exist_ok=True) + def save_to_wav(data, filename, sample_rate, channels=1): + with wave.open(filename, 'wb') as wf: + wf.setnchannels(channels) + wf.setsampwidth(2) # Assuming 16-bit audio + wf.setframerate(sample_rate) + wf.writeframes(data) -def sigterm_handler(_signo, _stack_frame): - processmanager.cleanup_subprocesses() + #torchaudio.set_audio_backend("soundfile") + py_audio = pyaudio.PyAudio() + FORMAT = pyaudio.paInt16 + CHANNELS = 1 + SAMPLE_RATE = whisper_audio.SAMPLE_RATE + CHUNK = int(SAMPLE_RATE / 10) - # reset process id - settings.SetOption("process_id", 0) + cache_vad_path = Path(Path.cwd() / ".cache" / "silero-vad") + os.makedirs(cache_vad_path, exist_ok=True) - # it raises SystemExit(0): - print('Process died') - sys.exit(0) + def sigterm_handler(_signo, _stack_frame): + processmanager.cleanup_subprocesses() -signal.signal(signal.SIGTERM, sigterm_handler) -signal.signal(signal.SIGINT, sigterm_handler) -signal.signal(signal.SIGABRT, sigterm_handler) + # reset process id + settings.SetOption("process_id", 0) + # it raises SystemExit(0): + print('Process died') + sys.exit(0) -# Taken from utils_vad.py -def validate(model, - inputs: torch.Tensor): - with torch.no_grad(): - outs = model(inputs) - return outs + signal.signal(signal.SIGTERM, sigterm_handler) + signal.signal(signal.SIGINT, sigterm_handler) + signal.signal(signal.SIGABRT, sigterm_handler) -# Provided by Alexander Veysov -def int2float(sound): - abs_max = np.abs(sound).max() - sound = sound.astype('float32') - if abs_max > 0: - sound *= 1 / abs_max - sound = sound.squeeze() # depends on the use case - return sound + # Taken from utils_vad.py + def validate(model, + inputs: torch.Tensor): + with torch.no_grad(): + outs = model(inputs) + return outs + + # Provided by Alexander Veysov + def int2float(sound): + abs_max = np.abs(sound).max() + sound = sound.astype('float32') + if abs_max > 0: + sound *= 1 / abs_max + sound = sound.squeeze() # depends on the use case + return sound + + + def call_plugin_timer(plugins): + # Call the method every x seconds + timer = threading.Timer(settings.GetOption("plugin_timer"), call_plugin_timer, args=[plugins]) + timer.start() + if 
not settings.GetOption("plugin_timer_stopped"): + for plugin_inst in plugins.plugins: + if plugin_inst.is_enabled(False) and hasattr(plugin_inst, 'timer'): + plugin_inst.timer() + else: + if settings.GetOption("plugin_current_timer") <= 0.0: + settings.SetOption("plugin_current_timer", settings.GetOption("plugin_timer_timeout")) + else: + settings.SetOption("plugin_current_timer", + settings.GetOption("plugin_current_timer") - settings.GetOption("plugin_timer")) + if settings.GetOption("plugin_current_timer") <= 0.0: + settings.SetOption("plugin_timer_stopped", False) + settings.SetOption("plugin_current_timer", 0.0) -def call_plugin_timer(plugins): - # Call the method every x seconds - timer = threading.Timer(settings.GetOption("plugin_timer"), call_plugin_timer, args=[plugins]) - timer.start() - if not settings.GetOption("plugin_timer_stopped"): + def call_plugin_sts(plugins, wavefiledata, sample_rate): for plugin_inst in plugins.plugins: - if plugin_inst.is_enabled(False) and hasattr(plugin_inst, 'timer'): - plugin_inst.timer() - else: - if settings.GetOption("plugin_current_timer") <= 0.0: - settings.SetOption("plugin_current_timer", settings.GetOption("plugin_timer_timeout")) + if plugin_inst.is_enabled(False) and hasattr(plugin_inst, 'sts'): + plugin_inst.sts(wavefiledata, sample_rate) + + + #def call_plugin_sts_chunk(plugins, wavefiledata, sample_rate): + # for plugin_inst in plugins.plugins: + # if plugin_inst.is_enabled(False) and hasattr(plugin_inst, 'sts_chunk'): + # plugin_inst.sts_chunk(wavefiledata, sample_rate) + + + def audio_bytes_to_wav(audio_bytes): + final_wavfile = io.BytesIO() + wavefile = wave.open(final_wavfile, 'wb') + wavefile.setnchannels(CHANNELS) + wavefile.setsampwidth(2) + wavefile.setframerate(SAMPLE_RATE) + wavefile.writeframes(audio_bytes) + + final_wavfile.seek(0) + return_data = final_wavfile.read() + wavefile.close() + return return_data + + + def typing_indicator_function(osc_ip, osc_port, send_websocket=True): + if osc_ip != "0" and settings.GetOption("osc_auto_processing_enabled") and settings.GetOption( + "osc_typing_indicator"): + VRC_OSCLib.Bool(True, "/chatbox/typing", IP=osc_ip, PORT=osc_port) + if send_websocket and settings.GetOption("websocket_ip") != "0": + threading.Thread( + target=websocket.BroadcastMessage, + args=(json.dumps({"type": "processing_start", "data": True}),) + ).start() + + + def process_audio_chunk(audio_chunk, vad_model, sample_rate): + audio_int16 = np.frombuffer(audio_chunk, np.int16) + audio_float32 = int2float(audio_int16) + if vad_model is not None: + new_confidence = vad_model(torch.from_numpy(audio_float32), sample_rate).item() else: - settings.SetOption("plugin_current_timer", - settings.GetOption("plugin_current_timer") - settings.GetOption("plugin_timer")) - if settings.GetOption("plugin_current_timer") <= 0.0: - settings.SetOption("plugin_timer_stopped", False) - settings.SetOption("plugin_current_timer", 0.0) - - -def call_plugin_sts(plugins, wavefiledata, sample_rate): - for plugin_inst in plugins.plugins: - if plugin_inst.is_enabled(False) and hasattr(plugin_inst, 'sts'): - plugin_inst.sts(wavefiledata, sample_rate) - - -def audio_bytes_to_wav(audio_bytes): - final_wavfile = io.BytesIO() - wavefile = wave.open(final_wavfile, 'wb') - wavefile.setnchannels(CHANNELS) - wavefile.setsampwidth(2) - wavefile.setframerate(SAMPLE_RATE) - wavefile.writeframes(audio_bytes) - - final_wavfile.seek(0) - return_data = final_wavfile.read() - wavefile.close() - return return_data - - -def typing_indicator_function(osc_ip, 
osc_port, send_websocket=True): - if osc_ip != "0" and settings.GetOption("osc_auto_processing_enabled") and settings.GetOption( - "osc_typing_indicator"): - VRC_OSCLib.Bool(True, "/chatbox/typing", IP=osc_ip, PORT=osc_port) - if send_websocket and settings.GetOption("websocket_ip") != "0": - threading.Thread( - target=websocket.BroadcastMessage, - args=(json.dumps({"type": "processing_start", "data": True}),) - ).start() - - -def process_audio_chunk(audio_chunk, vad_model, sample_rate): - audio_int16 = np.frombuffer(audio_chunk, np.int16) - audio_float32 = int2float(audio_int16) - if vad_model is not None: - new_confidence = vad_model(torch.from_numpy(audio_float32), sample_rate).item() - else: - new_confidence = 9.9 - peak_amplitude = np.max(np.abs(audio_int16)) - - # clear the variables - audio_int16 = None - del audio_int16 - audio_float32 = None - del audio_float32 - return new_confidence, peak_amplitude - - -def should_start_recording(peak_amplitude, energy, new_confidence, confidence_threshold, keyboard_key=None): - return ((keyboard_key is not None and keyboard.is_pressed( - keyboard_key)) or (0 < energy <= peak_amplitude and new_confidence >= confidence_threshold)) - - -def should_stop_recording(new_confidence, confidence_threshold, peak_amplitude, energy, pause_time, pause, - keyboard_key=None): - return (keyboard_key is not None and not keyboard.is_pressed(keyboard_key)) or ( - 0 < energy > peak_amplitude and (new_confidence < confidence_threshold or confidence_threshold == 0.0) and ( - time.time() - pause_time) > pause > 0.0) - - -def get_host_audio_api_names(): - audio = pyaudio.PyAudio() - host_api_count = audio.get_host_api_count() - host_api_names = {} - for i in range(host_api_count): - host_api_info = audio.get_host_api_info_by_index(i) - host_api_names[i] = host_api_info["name"] - return host_api_names - - -def get_default_audio_device_index_by_api(api, is_input=True): - devices = sd.query_devices() - api_info = sd.query_hostapis() - host_api_index = None - - for i, host_api in enumerate(api_info): - if api.lower() in host_api['name'].lower(): - host_api_index = i - break - - if host_api_index is None: - return None - - api_pyaudio_index, _ = get_audio_api_index_by_name(api) - - default_device_index = api_info[host_api_index]['default_input_device' if is_input else 'default_output_device'] - default_device_name = devices[default_device_index]['name'] - return get_audio_device_index_by_name_and_api(default_device_name, api_pyaudio_index, is_input) - - -def get_audio_device_index_by_name_and_api(name, api, is_input=True, default=None): - audio = pyaudio.PyAudio() - device_count = audio.get_device_count() - for i in range(device_count): - device_info = audio.get_device_info_by_index(i) - device_name = device_info["name"] - if isinstance(device_name, bytes): - device_name = Utilities.safe_decode(device_name) - if isinstance(name, bytes): - name = Utilities.safe_decode(name) - - if device_info["hostApi"] == api and device_info[ - "maxInputChannels" if is_input else "maxOutputChannels"] > 0 and name in device_name: - return i - return default - - -def get_audio_api_index_by_name(name): - audio = pyaudio.PyAudio() - host_api_count = audio.get_host_api_count() - for i in range(host_api_count): - host_api_info = audio.get_host_api_info_by_index(i) - if name.lower() in host_api_info["name"].lower(): - return i, host_api_info["name"] - return 0, "" - - -def record_highest_peak_amplitude(device_index=-1, record_time=10): - py_audio = pyaudio.PyAudio() + new_confidence = 9.9 + 
peak_amplitude = np.max(np.abs(audio_int16)) - default_sample_rate = SAMPLE_RATE - - stream, needs_sample_rate_conversion, recorded_sample_rate, is_mono = audio_tools.start_recording_audio_stream( - device_index, - sample_format=FORMAT, - sample_rate=SAMPLE_RATE, - channels=CHANNELS, - chunk=CHUNK, - py_audio=py_audio, - ) - - highest_peak_amplitude = 0 - start_time = time.time() - - while time.time() - start_time < record_time: - audio_chunk = stream.read(CHUNK, exception_on_overflow=False) - # special case which seems to be needed for WASAPI - if needs_sample_rate_conversion: - audio_chunk = audio_tools.resample_audio(audio_chunk, recorded_sample_rate, default_sample_rate, -1, - is_mono=is_mono).tobytes() - - _, peak_amplitude = process_audio_chunk(audio_chunk, None, default_sample_rate) - highest_peak_amplitude = max(highest_peak_amplitude, peak_amplitude) - - stream.stop_stream() - stream.close() - - return highest_peak_amplitude - - -class AudioProcessor: - last_callback_time = time.time() - - def __init__(self, - default_sample_rate=SAMPLE_RATE, - previous_audio_chunk=None, - start_rec_on_volume_threshold=None, - push_to_talk_key=None, - keyboard_rec_force_stop=None, - vad_model=None, - needs_sample_rate_conversion=False, - recorded_sample_rate=None, - is_mono=False, - - plugins=None, - audio_enhancer=None, - - osc_ip=None, - osc_port=None, - - chunk=None, - channels=None, - sample_format=None, - - verbose=False - ): - if plugins is None: - plugins = [] - self.frames = [] - self.default_sample_rate = default_sample_rate - self.previous_audio_chunk = previous_audio_chunk - self.start_rec_on_volume_threshold = start_rec_on_volume_threshold - self.push_to_talk_key = push_to_talk_key - self.keyboard_rec_force_stop = keyboard_rec_force_stop - - self.vad_model = vad_model - - self.needs_sample_rate_conversion = needs_sample_rate_conversion - self.recorded_sample_rate = recorded_sample_rate - self.is_mono = is_mono - - self.Plugins = plugins - self.audio_enhancer = audio_enhancer - - self.osc_ip = osc_ip - self.osc_port = osc_port - - self.verbose = verbose - - self.start_time = time.time() - self.pause_time = time.time() - self.intermediate_time_start = time.time() - - self.block_size_samples = int(self.default_sample_rate * 0.400) # calculate block size in samples. (0.400 is the default block size of pyloudnorm) - - self.chunk = chunk - self.channels = channels - self.sample_format = sample_format - # run callback after timeout even if no audio was detected (and such callback not called by pyAudio) - #self.timer_reset_event = threading.Event() - #self.timer_thread = threading.Thread(target=self.timer_expired) - #self.timer_thread.start() - #self.timer_reset_event.set() - #self.last_callback_time = time.time() - - # The function to call when the timer expires - #def timer_expired(self): - # while True: - # current_time = time.time() - # time_since_last_callback = current_time - self.last_callback_time - # if self.recorded_sample_rate is not None: - # # wait double the chunk size to not accidentally call callback twice - # self.timer_reset_event.wait(timeout=(self.chunk / self.recorded_sample_rate)*2) - # if time_since_last_callback >= (self.chunk / self.recorded_sample_rate)*2 and len(self.frames) > 0: - # #print("Timer expired. Triggering callback.") - # try: - # print("Timer expired. 
Triggering callback.") - # self.callback(None, None, None, None) - # except Exception as e: - # print(e) - # self.timer_reset_event.clear() - - def callback(self, in_data, frame_count, time_info, status): - # Reset the timer each time the callback is triggered - #self.last_callback_time = time.time() - #self.timer_reset_event.set() - - if not settings.GetOption("stt_enabled"): - return None, pyaudio.paContinue - - # disable gradient calculation - with torch.no_grad(): - phrase_time_limit = settings.GetOption("phrase_time_limit") - pause = settings.GetOption("pause") - energy = settings.GetOption("energy") - if phrase_time_limit == 0: - phrase_time_limit = None - - silence_cutting_enabled = settings.GetOption("silence_cutting_enabled") - silence_offset = settings.GetOption("silence_offset") - max_silence_length = settings.GetOption("max_silence_length") - keep_silence_length = settings.GetOption("keep_silence_length") - - normalize_enabled = settings.GetOption("normalize_enabled") - normalize_lower_threshold = settings.GetOption("normalize_lower_threshold") - normalize_upper_threshold = settings.GetOption("normalize_upper_threshold") - normalize_gain_factor = settings.GetOption("normalize_gain_factor") - - clip_duration = phrase_time_limit - fps = 0 - if clip_duration is not None: - fps = int(self.recorded_sample_rate / CHUNK * clip_duration) - - end_time = time.time() - elapsed_time = end_time - self.start_time - elapsed_intermediate_time = end_time - self.intermediate_time_start - - confidence_threshold = float(settings.GetOption("vad_confidence_threshold")) - - #if settings.GetOption("denoise_audio") and audio_enhancer is not None and num_samples < DeepFilterNet.ModelParams().hop_size: - # #print("increase num_samples for denoising") - # num_samples = DeepFilterNet.ModelParams().hop_size - - #audio_chunk = stream.read(num_samples, exception_on_overflow=False) - - # denoise audio chunk - #if settings.GetOption("denoise_audio") and audio_enhancer is not None: - # record more audio to denoise if it's too short - #if len(audio_chunk) < DeepFilterNet.ModelParams().hop_size: - #while len(audio_chunk) < DeepFilterNet.ModelParams().hop_size: - # audio_chunk += stream.read(num_samples, exception_on_overflow=False) - #audio_chunk = audio_enhancer.enhance_audio(audio_chunk, recorded_sample_rate, default_sample_rate, is_mono=is_mono) - #needs_sample_rate_conversion = False - - # denoise audio chunk - #if settings.GetOption("denoise_audio") and audio_enhancer is not None: - # #if len(audio_chunk) < DeepFilterNet.ModelParams().hop_size: - # # while len(audio_chunk) < DeepFilterNet.ModelParams().hop_size * 2: - # # audio_chunk += stream.read(num_samples, exception_on_overflow=False) - # audio_chunk = audio_enhancer.enhance_audio(audio_chunk, recorded_sample_rate, default_sample_rate, is_mono=is_mono) - # needs_sample_rate_conversion = False - # #recorded_sample_rate = audio_enhancer.df_state.sr() - - test_audio_chunk = in_data - audio_chunk = in_data - # special case which seems to be needed for WASAPI - if self.needs_sample_rate_conversion and test_audio_chunk is not None: - test_audio_chunk = audio_tools.resample_audio(test_audio_chunk, self.recorded_sample_rate, self.default_sample_rate, -1, - is_mono=self.is_mono).tobytes() - - new_confidence, peak_amplitude = 0, 0 - if test_audio_chunk is not None: - new_confidence, peak_amplitude = process_audio_chunk(test_audio_chunk, self.vad_model, self.default_sample_rate) - - # put frames with recognized speech into a list and send to whisper - if 
(clip_duration is not None and len(self.frames) > fps) or ( - elapsed_time > pause > 0.0 and len(self.frames) > 0) or ( - self.keyboard_rec_force_stop and self.push_to_talk_key is not None and not keyboard.is_pressed( - self.push_to_talk_key) and len(self.frames) > 0): - - clip = [] - # merge all frames to one audio clip - for i in range(0, len(self.frames)): - if self.frames[i] is not None: - clip.append(self.frames[i]) - - if len(clip) > 0: - wavefiledata = b''.join(clip) - else: - return None, pyaudio.paContinue + # clear the variables + audio_int16 = None + del audio_int16 + audio_float32 = None + del audio_float32 + return new_confidence, peak_amplitude - if self.needs_sample_rate_conversion: - wavefiledata = audio_tools.resample_audio(wavefiledata, self.recorded_sample_rate, self.default_sample_rate, -1, - is_mono=self.is_mono).tobytes() - # normalize audio (and make sure it's longer or equal the default block size by pyloudnorm) - if normalize_enabled and len(wavefiledata) >= self.block_size_samples: - wavefiledata = audio_tools.convert_audio_datatype_to_float(np.frombuffer(wavefiledata, np.int16)) - wavefiledata, lufs = audio_tools.normalize_audio_lufs( - wavefiledata, self.default_sample_rate, normalize_lower_threshold, normalize_upper_threshold, - normalize_gain_factor, verbose=self.verbose - ) - wavefiledata = audio_tools.convert_audio_datatype_to_integer(wavefiledata, np.int16) - wavefiledata = wavefiledata.tobytes() - - # remove silence from audio - if silence_cutting_enabled: - wavefiledata_np = np.frombuffer(wavefiledata, np.int16) - if len(wavefiledata_np) >= self.block_size_samples: - wavefiledata = audio_tools.remove_silence_parts( - wavefiledata_np, self.default_sample_rate, - silence_offset=silence_offset, max_silence_length=max_silence_length, keep_silence_length=keep_silence_length, - verbose=self.verbose - ) - wavefiledata = wavefiledata.tobytes() + def should_start_recording(peak_amplitude, energy, new_confidence, confidence_threshold, keyboard_key=None): + return ((keyboard_key is not None and keyboard.is_pressed( + keyboard_key)) or (0 < energy <= peak_amplitude and new_confidence >= confidence_threshold)) - # debug save of audio clip - # save_to_wav(wavefiledata, "resampled_audio_chunk.wav", self.default_sample_rate) - # check if the full audio clip is above the confidence threshold - vad_clip_test = settings.GetOption("vad_on_full_clip") - full_audio_confidence = 0. 
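Annotation: the normalization branch above converts the joined int16 clip to float, runs it through audio_tools.normalize_audio_lufs with lower/upper LUFS thresholds plus a gain factor, and converts back to int16 bytes; the measurement underneath is pyloudnorm's ITU-R BS.1770 meter, which is also why clips shorter than the 400 ms block size are skipped. One detail worth flagging: the guard appears to compare len(wavefiledata) — a byte count, two bytes per int16 sample — against block_size_samples, a sample count, so clips of roughly half the intended minimum length still reach the normalizer. A bare-bones version of the round trip, using a single target loudness instead of the project's threshold/gain logic:

# Sketch: loudness-normalize an int16 PCM clip with pyloudnorm (simplified, target-based).
import numpy as np
import pyloudnorm as pyln

def normalize_clip(pcm_bytes: bytes, sample_rate: int, target_lufs: float = -23.0) -> bytes:
    audio = np.frombuffer(pcm_bytes, np.int16).astype(np.float32) / 32768.0
    meter = pyln.Meter(sample_rate)                # default block size is 0.400 s
    loudness = meter.integrated_loudness(audio)    # needs at least one full block of audio
    normalized = pyln.normalize.loudness(audio, loudness, target_lufs)
    normalized = np.clip(normalized, -1.0, 1.0)
    return (normalized * 32767.0).astype(np.int16).tobytes()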
- if vad_clip_test: - audio_full_int16 = np.frombuffer(wavefiledata, np.int16) - audio_full_float32 = int2float(audio_full_int16) - full_audio_confidence = self.vad_model(torch.from_numpy(audio_full_float32), self.default_sample_rate).item() - print(full_audio_confidence) + def should_stop_recording(new_confidence, confidence_threshold, peak_amplitude, energy, pause_time, pause, + keyboard_key=None): + return (keyboard_key is not None and not keyboard.is_pressed(keyboard_key)) or ( + 0 < energy > peak_amplitude and (new_confidence < confidence_threshold or confidence_threshold == 0.0) and ( + time.time() - pause_time) > pause > 0.0) - if ((not vad_clip_test) or (vad_clip_test and full_audio_confidence >= confidence_threshold)) and len( - wavefiledata) > 0: - # denoise audio - if settings.GetOption("denoise_audio") and self.audio_enhancer is not None: - wavefiledata = self.audio_enhancer.enhance_audio(wavefiledata).tobytes() - # call sts plugin methods - call_plugin_sts(self.Plugins, wavefiledata, self.default_sample_rate) + def get_host_audio_api_names(): + audio = pyaudio.PyAudio() + host_api_count = audio.get_host_api_count() + host_api_names = {} + for i in range(host_api_count): + host_api_info = audio.get_host_api_info_by_index(i) + host_api_names[i] = host_api_info["name"] + return host_api_names + + + def get_default_audio_device_index_by_api(api, is_input=True): + devices = sd.query_devices() + api_info = sd.query_hostapis() + host_api_index = None + + for i, host_api in enumerate(api_info): + if api.lower() in host_api['name'].lower(): + host_api_index = i + break - audioprocessor.q.put( - {'time': time.time_ns(), 'data': audio_bytes_to_wav(wavefiledata), 'final': True}) - # vad_iterator.reset_states() # reset model states after each audio + if host_api_index is None: + return None - # write wav file if configured to do so - transcription_save_audio_dir = settings.GetOption("transcription_save_audio_dir") - if transcription_save_audio_dir is not None and transcription_save_audio_dir != "": - start_time_str = Utilities.ns_to_datetime(time.time_ns(), formatting='%Y-%m-%d %H_%M_%S-%f') - audio_file_name = f"audio_transcript_{start_time_str}.wav" + api_pyaudio_index, _ = get_audio_api_index_by_name(api) - transcription_save_audio_dir = Path(transcription_save_audio_dir) - audio_file_path = transcription_save_audio_dir / audio_file_name + default_device_index = api_info[host_api_index]['default_input_device' if is_input else 'default_output_device'] + default_device_name = devices[default_device_index]['name'] + return get_audio_device_index_by_name_and_api(default_device_name, api_pyaudio_index, is_input) - threading.Thread( - target=save_to_wav, - args=(wavefiledata, str(audio_file_path.resolve()), self.default_sample_rate,) - ).start() - # set typing indicator for VRChat and Websocket clients - typing_indicator_thread = threading.Thread(target=typing_indicator_function, - args=(self.osc_ip, self.osc_port, True)) - typing_indicator_thread.start() + def get_audio_device_index_by_name_and_api(name, api, is_input=True, default=None): + audio = pyaudio.PyAudio() + device_count = audio.get_device_count() + for i in range(device_count): + device_info = audio.get_device_info_by_index(i) + device_name = device_info["name"] + if isinstance(device_name, bytes): + device_name = Utilities.safe_decode(device_name) + if isinstance(name, bytes): + name = Utilities.safe_decode(name) + + if device_info["hostApi"] == api and device_info[ + "maxInputChannels" if is_input else "maxOutputChannels"] > 0 
and name in device_name: + return i + return default + + + def get_audio_api_index_by_name(name): + audio = pyaudio.PyAudio() + host_api_count = audio.get_host_api_count() + for i in range(host_api_count): + host_api_info = audio.get_host_api_info_by_index(i) + if name.lower() in host_api_info["name"].lower(): + return i, host_api_info["name"] + return 0, "" + - self.frames = [] - self.start_time = time.time() - self.intermediate_time_start = time.time() - self.keyboard_rec_force_stop = False + def record_highest_peak_amplitude(device_index=-1, record_time=10): + py_audio = pyaudio.PyAudio() + + default_sample_rate = SAMPLE_RATE - if audio_chunk is None: + stream, needs_sample_rate_conversion, recorded_sample_rate, is_mono = audio_tools.start_recording_audio_stream( + device_index, + sample_format=FORMAT, + sample_rate=SAMPLE_RATE, + channels=CHANNELS, + chunk=CHUNK, + py_audio=py_audio, + ) + + highest_peak_amplitude = 0 + start_time = time.time() + + while time.time() - start_time < record_time: + audio_chunk = stream.read(CHUNK, exception_on_overflow=False) + # special case which seems to be needed for WASAPI + if needs_sample_rate_conversion: + audio_chunk = audio_tools.resample_audio(audio_chunk, recorded_sample_rate, default_sample_rate, -1, + is_mono=is_mono).tobytes() + + _, peak_amplitude = process_audio_chunk(audio_chunk, None, default_sample_rate) + highest_peak_amplitude = max(highest_peak_amplitude, peak_amplitude) + + stream.stop_stream() + stream.close() + + return highest_peak_amplitude + + + class AudioProcessor: + last_callback_time = time.time() + + def __init__(self, + default_sample_rate=SAMPLE_RATE, + previous_audio_chunk=None, + start_rec_on_volume_threshold=None, + push_to_talk_key=None, + keyboard_rec_force_stop=None, + vad_model=None, + needs_sample_rate_conversion=False, + recorded_sample_rate=None, + is_mono=False, + + plugins=None, + audio_enhancer=None, + + osc_ip=None, + osc_port=None, + + chunk=None, + channels=None, + sample_format=None, + + verbose=False + ): + if plugins is None: + plugins = [] + self.frames = [] + self.default_sample_rate = default_sample_rate + self.previous_audio_chunk = previous_audio_chunk + self.start_rec_on_volume_threshold = start_rec_on_volume_threshold + self.push_to_talk_key = push_to_talk_key + self.keyboard_rec_force_stop = keyboard_rec_force_stop + + self.vad_model = vad_model + + self.needs_sample_rate_conversion = needs_sample_rate_conversion + self.recorded_sample_rate = recorded_sample_rate + self.is_mono = is_mono + + self.Plugins = plugins + self.audio_enhancer = audio_enhancer + + self.osc_ip = osc_ip + self.osc_port = osc_port + + self.verbose = verbose + + self.start_time = time.time() + self.pause_time = time.time() + self.intermediate_time_start = time.time() + + self.block_size_samples = int(self.default_sample_rate * 0.400) # calculate block size in samples. 
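Annotation: the device helpers above (get_host_audio_api_names, get_audio_device_index_by_name_and_api, get_audio_api_index_by_name, get_default_audio_device_index_by_api) all walk PyAudio's host-API and device tables to turn a human-readable device name plus API name into a concrete device index. A compact illustration of the same lookup pattern:

# Sketch: resolve an input device index from a (partial) device name within a named host API.
import pyaudio

def find_input_device(name_fragment: str, api_fragment: str):
    pa = pyaudio.PyAudio()
    try:
        api_index = None
        for i in range(pa.get_host_api_count()):
            if api_fragment.lower() in pa.get_host_api_info_by_index(i)["name"].lower():
                api_index = i
                break
        if api_index is None:
            return None
        for i in range(pa.get_device_count()):
            info = pa.get_device_info_by_index(i)
            if (info["hostApi"] == api_index and info["maxInputChannels"] > 0
                    and name_fragment in str(info["name"])):
                return i
        return None
    finally:
        pa.terminate()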
(0.400 is the default block size of pyloudnorm) + + self.chunk = chunk + self.channels = channels + self.sample_format = sample_format + # run callback after timeout even if no audio was detected (and such callback not called by pyAudio) + #self.timer_reset_event = threading.Event() + #self.timer_thread = threading.Thread(target=self.timer_expired) + #self.timer_thread.start() + #self.timer_reset_event.set() + #self.last_callback_time = time.time() + + # The function to call when the timer expires + #def timer_expired(self): + # while True: + # current_time = time.time() + # time_since_last_callback = current_time - self.last_callback_time + # if self.recorded_sample_rate is not None: + # # wait double the chunk size to not accidentally call callback twice + # self.timer_reset_event.wait(timeout=(self.chunk / self.recorded_sample_rate)*2) + # if time_since_last_callback >= (self.chunk / self.recorded_sample_rate)*2 and len(self.frames) > 0: + # #print("Timer expired. Triggering callback.") + # try: + # print("Timer expired. Triggering callback.") + # self.callback(None, None, None, None) + # except Exception as e: + # print(e) + # self.timer_reset_event.clear() + + def callback(self, in_data, frame_count, time_info, status): + # Reset the timer each time the callback is triggered + #self.last_callback_time = time.time() + #self.timer_reset_event.set() + + if not settings.GetOption("stt_enabled"): return None, pyaudio.paContinue - # set start recording variable to true if the volume and voice confidence is above the threshold - if should_start_recording(peak_amplitude, energy, new_confidence, confidence_threshold, - keyboard_key=self.push_to_talk_key): - if self.verbose: - print("start recording - new_confidence: " + str(new_confidence) + " peak_amplitude: " + str(peak_amplitude)) - if not self.start_rec_on_volume_threshold: - # start processing_start event - typing_indicator_thread = threading.Thread(target=typing_indicator_function, - args=(self.osc_ip, self.osc_port, True)) - typing_indicator_thread.start() - if self.push_to_talk_key is not None and keyboard.is_pressed(self.push_to_talk_key): - self.keyboard_rec_force_stop = True - self.start_rec_on_volume_threshold = True - self.pause_time = time.time() - - # append audio frame to the list if the recording var is set and voice confidence is above the threshold (So it only adds the audio parts with speech) - if self.start_rec_on_volume_threshold and new_confidence >= confidence_threshold: - if self.verbose: - print("add chunk - new_confidence: " + str(new_confidence) + " peak_amplitude: " + str(peak_amplitude)) - # append previous audio chunk to improve recognition on too late audio recording starts - if self.previous_audio_chunk is not None: - self.frames.append(self.previous_audio_chunk) - - self.frames.append(audio_chunk) - self.start_time = time.time() - if settings.GetOption("realtime"): + # disable gradient calculation + with torch.no_grad(): + phrase_time_limit = settings.GetOption("phrase_time_limit") + pause = settings.GetOption("pause") + energy = settings.GetOption("energy") + if phrase_time_limit == 0: + phrase_time_limit = None + + silence_cutting_enabled = settings.GetOption("silence_cutting_enabled") + silence_offset = settings.GetOption("silence_offset") + max_silence_length = settings.GetOption("max_silence_length") + keep_silence_length = settings.GetOption("keep_silence_length") + + normalize_enabled = settings.GetOption("normalize_enabled") + normalize_lower_threshold = settings.GetOption("normalize_lower_threshold") 
+ normalize_upper_threshold = settings.GetOption("normalize_upper_threshold") + normalize_gain_factor = settings.GetOption("normalize_gain_factor") + + clip_duration = phrase_time_limit + fps = 0 + if clip_duration is not None: + fps = int(self.recorded_sample_rate / CHUNK * clip_duration) + + end_time = time.time() + elapsed_time = end_time - self.start_time + elapsed_intermediate_time = end_time - self.intermediate_time_start + + confidence_threshold = float(settings.GetOption("vad_confidence_threshold")) + + #if settings.GetOption("denoise_audio") and audio_enhancer is not None and num_samples < DeepFilterNet.ModelParams().hop_size: + # #print("increase num_samples for denoising") + # num_samples = DeepFilterNet.ModelParams().hop_size + + #audio_chunk = stream.read(num_samples, exception_on_overflow=False) + + # denoise audio chunk + #if settings.GetOption("denoise_audio") and audio_enhancer is not None: + # record more audio to denoise if it's too short + #if len(audio_chunk) < DeepFilterNet.ModelParams().hop_size: + #while len(audio_chunk) < DeepFilterNet.ModelParams().hop_size: + # audio_chunk += stream.read(num_samples, exception_on_overflow=False) + #audio_chunk = audio_enhancer.enhance_audio(audio_chunk, recorded_sample_rate, default_sample_rate, is_mono=is_mono) + #needs_sample_rate_conversion = False + + # denoise audio chunk + #if settings.GetOption("denoise_audio") and audio_enhancer is not None: + # #if len(audio_chunk) < DeepFilterNet.ModelParams().hop_size: + # # while len(audio_chunk) < DeepFilterNet.ModelParams().hop_size * 2: + # # audio_chunk += stream.read(num_samples, exception_on_overflow=False) + # audio_chunk = audio_enhancer.enhance_audio(audio_chunk, recorded_sample_rate, default_sample_rate, is_mono=is_mono) + # needs_sample_rate_conversion = False + # #recorded_sample_rate = audio_enhancer.df_state.sr() + + test_audio_chunk = in_data + audio_chunk = in_data + # special case which seems to be needed for WASAPI + if self.needs_sample_rate_conversion and test_audio_chunk is not None: + test_audio_chunk = audio_tools.resample_audio(test_audio_chunk, self.recorded_sample_rate, self.default_sample_rate, -1, + is_mono=self.is_mono).tobytes() + + new_confidence, peak_amplitude = 0, 0 + if test_audio_chunk is not None: + new_confidence, peak_amplitude = process_audio_chunk(test_audio_chunk, self.vad_model, self.default_sample_rate) + + # put frames with recognized speech into a list and send to whisper + if (clip_duration is not None and len(self.frames) > fps) or ( + elapsed_time > pause > 0.0 and len(self.frames) > 0) or ( + self.keyboard_rec_force_stop and self.push_to_talk_key is not None and not keyboard.is_pressed( + self.push_to_talk_key) and len(self.frames) > 0): + clip = [] - frame_count = len(self.frames) - # send realtime intermediate results every x frames and every x seconds (making sure its at least x frame length) - if frame_count % settings.GetOption( - "realtime_frame_multiply") == 0 and elapsed_intermediate_time > settings.GetOption( - "realtime_frequency_time"): - # set typing indicator for VRChat but not websocket - typing_indicator_thread = threading.Thread(target=typing_indicator_function, - args=(self.osc_ip, self.osc_port, False)) - typing_indicator_thread.start() - # merge all frames to one audio clip - for i in range(0, len(self.frames)): + # merge all frames to one audio clip + for i in range(0, len(self.frames)): + if self.frames[i] is not None: clip.append(self.frames[i]) - if len(clip) > 0: - wavefiledata = b''.join(clip) - else: - return 
None, pyaudio.paContinue - - if self.needs_sample_rate_conversion: - wavefiledata = audio_tools.resample_audio(wavefiledata, self.recorded_sample_rate, self.default_sample_rate, -1, - is_mono=self.is_mono).tobytes() - - # normalize audio (and make sure it's longer or equal the default block size by pyloudnorm) - if normalize_enabled and len(wavefiledata) >= self.block_size_samples: - wavefiledata = audio_tools.convert_audio_datatype_to_float(np.frombuffer(wavefiledata, np.int16)) - wavefiledata, lufs = audio_tools.normalize_audio_lufs( - wavefiledata, self.default_sample_rate, normalize_lower_threshold, - normalize_upper_threshold, normalize_gain_factor, + if len(clip) > 0: + wavefiledata = b''.join(clip) + else: + return None, pyaudio.paContinue + + if self.needs_sample_rate_conversion: + wavefiledata = audio_tools.resample_audio(wavefiledata, self.recorded_sample_rate, self.default_sample_rate, -1, + is_mono=self.is_mono).tobytes() + + # normalize audio (and make sure it's longer or equal the default block size by pyloudnorm) + if normalize_enabled and len(wavefiledata) >= self.block_size_samples: + wavefiledata = audio_tools.convert_audio_datatype_to_float(np.frombuffer(wavefiledata, np.int16)) + wavefiledata, lufs = audio_tools.normalize_audio_lufs( + wavefiledata, self.default_sample_rate, normalize_lower_threshold, normalize_upper_threshold, + normalize_gain_factor, verbose=self.verbose + ) + wavefiledata = audio_tools.convert_audio_datatype_to_integer(wavefiledata, np.int16) + wavefiledata = wavefiledata.tobytes() + + # remove silence from audio + if silence_cutting_enabled: + wavefiledata_np = np.frombuffer(wavefiledata, np.int16) + if len(wavefiledata_np) >= self.block_size_samples: + wavefiledata = audio_tools.remove_silence_parts( + wavefiledata_np, self.default_sample_rate, + silence_offset=silence_offset, max_silence_length=max_silence_length, keep_silence_length=keep_silence_length, verbose=self.verbose ) - wavefiledata = audio_tools.convert_audio_datatype_to_integer(wavefiledata, np.int16) wavefiledata = wavefiledata.tobytes() - # remove silence from audio - if silence_cutting_enabled: - wavefiledata_np = np.frombuffer(wavefiledata, np.int16) - if len(wavefiledata_np) >= self.block_size_samples: - wavefiledata = audio_tools.remove_silence_parts( - wavefiledata_np, self.default_sample_rate, - silence_offset=silence_offset, max_silence_length=max_silence_length, keep_silence_length=keep_silence_length, + # debug save of audio clip + # save_to_wav(wavefiledata, "resampled_audio_chunk.wav", self.default_sample_rate) + + # check if the full audio clip is above the confidence threshold + vad_clip_test = settings.GetOption("vad_on_full_clip") + full_audio_confidence = 0. 
+ if vad_clip_test: + audio_full_int16 = np.frombuffer(wavefiledata, np.int16) + audio_full_float32 = int2float(audio_full_int16) + full_audio_confidence = self.vad_model(torch.from_numpy(audio_full_float32), self.default_sample_rate).item() + print(full_audio_confidence) + + if ((not vad_clip_test) or (vad_clip_test and full_audio_confidence >= confidence_threshold)) and len( + wavefiledata) > 0: + # denoise audio + if settings.GetOption("denoise_audio") and self.audio_enhancer is not None: + wavefiledata = self.audio_enhancer.enhance_audio(wavefiledata).tobytes() + + # call sts plugin methods + call_plugin_sts(self.Plugins, wavefiledata, self.default_sample_rate) + + audioprocessor.q.put( + {'time': time.time_ns(), 'data': audio_bytes_to_wav(wavefiledata), 'final': True}) + # vad_iterator.reset_states() # reset model states after each audio + + # write wav file if configured to do so + transcription_save_audio_dir = settings.GetOption("transcription_save_audio_dir") + if transcription_save_audio_dir is not None and transcription_save_audio_dir != "": + start_time_str = Utilities.ns_to_datetime(time.time_ns(), formatting='%Y-%m-%d %H_%M_%S-%f') + audio_file_name = f"audio_transcript_{start_time_str}.wav" + + transcription_save_audio_dir = Path(transcription_save_audio_dir) + audio_file_path = transcription_save_audio_dir / audio_file_name + + threading.Thread( + target=save_to_wav, + args=(wavefiledata, str(audio_file_path.resolve()), self.default_sample_rate,) + ).start() + + # set typing indicator for VRChat and Websocket clients + typing_indicator_thread = threading.Thread(target=typing_indicator_function, + args=(self.osc_ip, self.osc_port, True)) + typing_indicator_thread.start() + + self.frames = [] + self.start_time = time.time() + self.intermediate_time_start = time.time() + self.keyboard_rec_force_stop = False + + if audio_chunk is None: + return None, pyaudio.paContinue + + # set start recording variable to true if the volume and voice confidence is above the threshold + if should_start_recording(peak_amplitude, energy, new_confidence, confidence_threshold, + keyboard_key=self.push_to_talk_key): + if self.verbose: + print("start recording - new_confidence: " + str(new_confidence) + " peak_amplitude: " + str(peak_amplitude)) + if not self.start_rec_on_volume_threshold: + # start processing_start event + typing_indicator_thread = threading.Thread(target=typing_indicator_function, + args=(self.osc_ip, self.osc_port, True)) + typing_indicator_thread.start() + if self.push_to_talk_key is not None and keyboard.is_pressed(self.push_to_talk_key): + self.keyboard_rec_force_stop = True + self.start_rec_on_volume_threshold = True + self.pause_time = time.time() + + # append audio frame to the list if the recording var is set and voice confidence is above the threshold (So it only adds the audio parts with speech) + if self.start_rec_on_volume_threshold and new_confidence >= confidence_threshold: + if self.verbose: + print("add chunk - new_confidence: " + str(new_confidence) + " peak_amplitude: " + str(peak_amplitude)) + # append previous audio chunk to improve recognition on too late audio recording starts + if self.previous_audio_chunk is not None: + self.frames.append(self.previous_audio_chunk) + + # TODO? 
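Annotation: finished clips are handed off to the transcription worker through audioprocessor.q as dicts of the form {'time': <ns timestamp>, 'data': <wav bytes>, 'final': True}; the realtime branch further down pushes the same shape with 'final': False for intermediate results. A hypothetical consumer loop, only to make the message shape explicit (the real worker lives in audioprocessor and is not part of this patch):

# Sketch: consuming the transcription queue (hypothetical worker, illustrative only).
import queue

def worker_loop(q: "queue.Queue", transcribe):
    while True:
        item = q.get()  # {'time': int_ns, 'data': wav_bytes, 'final': bool}
        text = transcribe(item["data"])
        if item["final"]:
            print("final result:", text)
        else:
            print("intermediate result:", text)
        q.task_done()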
send audio_chunk to plugins (RVC ) + # threading.Thread(target=call_plugin_sts_chunk, args=(self.Plugins, test_audio_chunk, self.default_sample_rate,)).start() + + self.frames.append(audio_chunk) + self.start_time = time.time() + if settings.GetOption("realtime"): + clip = [] + frame_count = len(self.frames) + # send realtime intermediate results every x frames and every x seconds (making sure its at least x frame length) + if frame_count % settings.GetOption( + "realtime_frame_multiply") == 0 and elapsed_intermediate_time > settings.GetOption( + "realtime_frequency_time"): + # set typing indicator for VRChat but not websocket + typing_indicator_thread = threading.Thread(target=typing_indicator_function, + args=(self.osc_ip, self.osc_port, False)) + typing_indicator_thread.start() + # merge all frames to one audio clip + for i in range(0, len(self.frames)): + clip.append(self.frames[i]) + + if len(clip) > 0: + wavefiledata = b''.join(clip) + else: + return None, pyaudio.paContinue + + if self.needs_sample_rate_conversion: + wavefiledata = audio_tools.resample_audio(wavefiledata, self.recorded_sample_rate, self.default_sample_rate, -1, + is_mono=self.is_mono).tobytes() + + # normalize audio (and make sure it's longer or equal the default block size by pyloudnorm) + if normalize_enabled and len(wavefiledata) >= self.block_size_samples: + wavefiledata = audio_tools.convert_audio_datatype_to_float(np.frombuffer(wavefiledata, np.int16)) + wavefiledata, lufs = audio_tools.normalize_audio_lufs( + wavefiledata, self.default_sample_rate, normalize_lower_threshold, + normalize_upper_threshold, normalize_gain_factor, verbose=self.verbose ) + wavefiledata = audio_tools.convert_audio_datatype_to_integer(wavefiledata, np.int16) wavefiledata = wavefiledata.tobytes() - if wavefiledata is not None and len(wavefiledata) > 0: - # denoise audio - if settings.GetOption("denoise_audio") and self.audio_enhancer is not None: - wavefiledata = self.audio_enhancer.enhance_audio(wavefiledata).tobytes() - - audioprocessor.q.put( - {'time': time.time_ns(), 'data': audio_bytes_to_wav(wavefiledata), 'final': False}) - else: - self.frames = [] - - self.intermediate_time_start = time.time() - - # stop recording if no speech is detected for pause seconds - if should_stop_recording(new_confidence, confidence_threshold, peak_amplitude, energy, self.pause_time, pause, - keyboard_key=self.push_to_talk_key): - self.start_rec_on_volume_threshold = False - self.intermediate_time_start = time.time() - if self.push_to_talk_key is not None and not keyboard.is_pressed( - self.push_to_talk_key) and self.keyboard_rec_force_stop: - self.keyboard_rec_force_stop = True + # remove silence from audio + if silence_cutting_enabled: + wavefiledata_np = np.frombuffer(wavefiledata, np.int16) + if len(wavefiledata_np) >= self.block_size_samples: + wavefiledata = audio_tools.remove_silence_parts( + wavefiledata_np, self.default_sample_rate, + silence_offset=silence_offset, max_silence_length=max_silence_length, keep_silence_length=keep_silence_length, + verbose=self.verbose + ) + wavefiledata = wavefiledata.tobytes() + + if wavefiledata is not None and len(wavefiledata) > 0: + # denoise audio + if settings.GetOption("denoise_audio") and self.audio_enhancer is not None: + wavefiledata = self.audio_enhancer.enhance_audio(wavefiledata).tobytes() + + audioprocessor.q.put( + {'time': time.time_ns(), 'data': audio_bytes_to_wav(wavefiledata), 'final': False}) + else: + self.frames = [] + + self.intermediate_time_start = time.time() + + # stop recording if 
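Annotation: realtime (intermediate) transcription is throttled on two axes at once — it only fires on every Nth accumulated frame (realtime_frame_multiply) and only once the intermediate timer has exceeded realtime_frequency_time seconds. Factored out, the gate is simply:

# Sketch: the two-condition gate for emitting realtime intermediate results.
def should_emit_realtime(frame_count: int, elapsed_since_last: float,
                         frame_multiply: int, min_interval_s: float) -> bool:
    return frame_count % frame_multiply == 0 and elapsed_since_last > min_interval_s

# e.g. with frame_multiply=15 and min_interval_s=1.0 (illustrative values):
# should_emit_realtime(30, 1.2, 15, 1.0) -> True
# should_emit_realtime(31, 1.2, 15, 1.0) -> False  (not on a frame boundary)
# should_emit_realtime(30, 0.4, 15, 1.0) -> False  (too soon since the last intermediate result)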
no speech is detected for pause seconds + if should_stop_recording(new_confidence, confidence_threshold, peak_amplitude, energy, self.pause_time, pause, + keyboard_key=self.push_to_talk_key): + self.start_rec_on_volume_threshold = False + self.intermediate_time_start = time.time() + if self.push_to_talk_key is not None and not keyboard.is_pressed( + self.push_to_talk_key) and self.keyboard_rec_force_stop: + self.keyboard_rec_force_stop = True + else: + self.keyboard_rec_force_stop = False + + # save chunk as previous audio chunk to reuse later + if not self.start_rec_on_volume_threshold and ( + new_confidence < confidence_threshold or confidence_threshold == 0.0): + self.previous_audio_chunk = audio_chunk else: - self.keyboard_rec_force_stop = False + self.previous_audio_chunk = None + + #self.last_callback_time = time.time() + return in_data, pyaudio.paContinue + + + @click.command() + @click.option('--detect_energy', default=False, is_flag=True, + help='detect energy level after set time of seconds recording.', type=bool) + @click.option('--detect_energy_time', default=10, help='detect energy level time it records for.', type=int) + @click.option('--audio_input_device', default="Default", help='audio input device name. (used for detect_energy', + type=str) + @click.option('--ui_download', default=False, is_flag=True, + help='use UI application for downloads.', type=bool) + @click.option('--devices', default='False', help='print all available devices id', type=str) + @click.option('--device_index', default=-1, help='the id of the input device (-1 = default active Mic)', type=int) + @click.option('--device_out_index', default=-1, help='the id of the output device (-1 = default active Speaker)', + type=int) + @click.option('--audio_api', default='MME', help='the name of the audio API. ("MME", "DirectSound", "WASAPI")', + type=str) + @click.option('--sample_rate', default=whisper_audio.SAMPLE_RATE, help='sample rate of recording', type=int) + @click.option("--task", default="transcribe", + help="task for the model whether to only transcribe the audio or translate the audio to english", + type=click.Choice(["transcribe", "translate"])) + @click.option("--model", default="small", help="Model to use", type=click.Choice(available_models())) + @click.option("--language", default=None, + help="language spoken in the audio, specify None to perform language detection", + type=click.Choice(audioprocessor.whisper_get_languages_list_keys())) + @click.option("--condition_on_previous_text", default=False, + help="Feed it the previous result to keep it consistent across recognition windows, but makes it more prone to getting stuck in a failure loop", + is_flag=True, + type=bool) + @click.option("--energy", default=300, help="Energy level for mic to detect", type=int) + @click.option("--dynamic_energy", default=False, is_flag=True, help="Flag to enable dynamic engergy", type=bool) + @click.option("--pause", default=0.8, help="Pause time before entry ends", type=float) + @click.option("--phrase_time_limit", default=None, + help="phrase time limit before entry ends to break up long recognitions.", type=float) + @click.option("--osc_ip", default="127.0.0.1", help="IP to send OSC message to. Set to '0' to disable", type=str) + @click.option("--osc_port", default=9000, help="Port to send OSC message to. ('9000' as default for VRChat)", type=int) + @click.option("--osc_address", default="/chatbox/input", + help="The Address the OSC messages are send to. 
('/chatbox/input' as default for VRChat)", type=str) + @click.option("--osc_convert_ascii", default='False', help="Convert Text to ASCII compatible when sending over OSC.", + type=str) + @click.option("--websocket_ip", default="0", help="IP where Websocket Server listens on. Set to '0' to disable", + type=str) + @click.option("--websocket_port", default=5000, help="Port where Websocket Server listens on. ('5000' as default)", + type=int) + @click.option("--ai_device", default=None, + help="The Device the AI is loaded on. can be 'cuda' or 'cpu'. default does autodetect", + type=click.Choice(["cuda", "cpu"])) + @click.option("--txt_translator", default="NLLB200", + help="The Model the AI is loading for text translations. can be 'NLLB200', 'M2M100' or 'None'. default is NLLB200", + type=click.Choice(["NLLB200", "M2M100"])) + @click.option("--txt_translator_size", default="small", + help="The Model size if M2M100 or NLLB200 text translator is used. can be 'small', 'medium' or 'large' for NLLB200 or 'small' or 'large' for M2M100. default is small.", + type=click.Choice(["small", "medium", "large"])) + @click.option("--txt_translator_device", default="auto", + help="The device used for text translation.", + type=click.Choice(["auto", "cuda", "cpu"])) + @click.option("--ocr_window_name", default="VRChat", + help="Window name of the application for OCR translations. (Default: 'VRChat')", type=str) + @click.option("--open_browser", default=False, + help="Open default Browser with websocket-remote on start. (requires --websocket_ip to be set as well)", + is_flag=True, type=bool) + @click.option("--config", default=None, + help="Use the specified config file instead of the default 'settings.yaml' (relative to the current path) [overwrites without asking!!!]", + type=str) + @click.option("--verbose", default=False, help="Whether to print verbose output", is_flag=True, type=bool) + @click.pass_context + def main(ctx, detect_energy, detect_energy_time, ui_download, devices, sample_rate, dynamic_energy, open_browser, + config, verbose, + **kwargs): + if str2bool(devices): + host_audio_api_names = get_host_audio_api_names() + audio = pyaudio.PyAudio() + # print all available host apis + print("-------------------------------------------------------------------") + print(" Host APIs ") + print("-------------------------------------------------------------------") + for i in range(audio.get_host_api_count()): + print(f"Host API {i}: {audio.get_host_api_info_by_index(i)['name']}") + print("") + print("-------------------------------------------------------------------") + print(" Input Devices ") + print(" In form of: DEVICE_NAME [Sample Rate=?] [Loopback?] 
(Index=INDEX) ") + print("-------------------------------------------------------------------") + for device in audio.get_device_info_generator(): + device_list_index = device["index"] + device_list_api = host_audio_api_names[device["hostApi"]] + device_list_name = device["name"] + device_list_sample_rate = int(device["defaultSampleRate"]) + device_list_max_channels = audio.get_device_info_by_index(device_list_index)['maxInputChannels'] + if device_list_max_channels >= 1: + print( + f"{device_list_name} [Sample Rate={device_list_sample_rate}, API={device_list_api}] (Index={device_list_index})") + print("") + print("-------------------------------------------------------------------") + print(" Output Devices ") + print("-------------------------------------------------------------------") + for device in audio.get_device_info_generator(): + device_list_index = device["index"] + device_list_api = host_audio_api_names[device["hostApi"]] + device_list_name = device["name"] + device_list_sample_rate = int(device["defaultSampleRate"]) + device_list_max_channels = audio.get_device_info_by_index(device_list_index)['maxOutputChannels'] + if device_list_max_channels >= 1: + print( + f"{device_list_name} [Sample Rate={device_list_sample_rate}, API={device_list_api}] (Index={device_list_index})") + return + + # is set to run energy detection + if detect_energy: + # get selected audio api + audio_api = "MME" + if settings.IsArgumentSetting(ctx, "audio_api"): + audio_api = ctx.params["audio_api"] + audio_api_index, audio_api_name = get_audio_api_index_by_name(audio_api) + + # get selected audio input device + device_index = None + if settings.IsArgumentSetting(ctx, "device_index"): + device_index = ctx.params["device_index"] + device_default_in_index = get_default_audio_device_index_by_api(audio_api, True) + + # get selected audio input device by name if possible + if settings.IsArgumentSetting(ctx, "audio_input_device"): + audio_input_device = ctx.params["audio_input_device"] + if audio_input_device is not None and audio_input_device != "": + if audio_input_device.lower() == "Default".lower(): + device_index = None + else: + device_index = get_audio_device_index_by_name_and_api(audio_input_device, audio_api_index, True, + device_index) + if device_index is None or device_index < 0: + device_index = device_default_in_index + + max_detected_energy = record_highest_peak_amplitude(device_index, detect_energy_time) + print("detected_energy: " + str(max_detected_energy)) + return + + # Load settings from file + if config is not None: + settings.SETTINGS_PATH = Path(Path.cwd() / config) + settings.LoadYaml(settings.SETTINGS_PATH) + + # set process id + settings.SetOption("process_id", os.getpid()) + + settings.SetOption("ui_download", ui_download) + + # enable stt by default + settings.SetOption("stt_enabled", True) + + # set initial settings + settings.SetOption("whisper_task", settings.GetArgumentSettingFallback(ctx, "task", "whisper_task")) + + # set audio settings + device_index = settings.GetArgumentSettingFallback(ctx, "device_index", "device_index") + settings.SetOption("device_index", + (device_index if device_index is None or device_index > -1 else None)) + device_out_index = settings.GetArgumentSettingFallback(ctx, "device_out_index", "device_out_index") + settings.SetOption("device_out_index", + (device_out_index if device_out_index is None or device_out_index > -1 else None)) + + audio_api = settings.SetOption("audio_api", settings.GetArgumentSettingFallback(ctx, "audio_api", "audio_api")) + 
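Annotation: note the two flag styles in the option list above — --detect_energy and --ui_download are genuine click flags (is_flag=True), while --devices stays a plain string with default 'False' and is coerced later with str2bool, presumably so string-style invocations like --devices true keep working. A minimal standalone reproduction of both styles (not the project's CLI; str2bool here is an illustrative stand-in for the project's helper):

# Sketch: a real click flag vs. a string option coerced with str2bool.
import click

def str2bool(value: str) -> bool:
    return str(value).strip().lower() in ("1", "true", "yes", "y", "on")

@click.command()
@click.option("--detect_energy", default=False, is_flag=True, type=bool)
@click.option("--devices", default="False", type=str)
def demo(detect_energy, devices):
    click.echo(f"detect_energy={detect_energy} devices={str2bool(devices)}")

if __name__ == "__main__":
    demo()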
audio_api_index, audio_api_name = get_audio_api_index_by_name(audio_api) - # save chunk as previous audio chunk to reuse later - if not self.start_rec_on_volume_threshold and ( - new_confidence < confidence_threshold or confidence_threshold == 0.0): - self.previous_audio_chunk = audio_chunk + audio_input_device = settings.GetOption("audio_input_device") + if audio_input_device is not None and audio_input_device != "": + if audio_input_device.lower() == "Default".lower(): + device_index = None else: - self.previous_audio_chunk = None - - #self.last_callback_time = time.time() - return in_data, pyaudio.paContinue - - -@click.command() -@click.option('--detect_energy', default=False, is_flag=True, - help='detect energy level after set time of seconds recording.', type=bool) -@click.option('--detect_energy_time', default=10, help='detect energy level time it records for.', type=int) -@click.option('--audio_input_device', default="Default", help='audio input device name. (used for detect_energy', - type=str) -@click.option('--ui_download', default=False, is_flag=True, - help='use UI application for downloads.', type=bool) -@click.option('--devices', default='False', help='print all available devices id', type=str) -@click.option('--device_index', default=-1, help='the id of the input device (-1 = default active Mic)', type=int) -@click.option('--device_out_index', default=-1, help='the id of the output device (-1 = default active Speaker)', - type=int) -@click.option('--audio_api', default='MME', help='the name of the audio API. ("MME", "DirectSound", "WASAPI")', - type=str) -@click.option('--sample_rate', default=whisper_audio.SAMPLE_RATE, help='sample rate of recording', type=int) -@click.option("--task", default="transcribe", - help="task for the model whether to only transcribe the audio or translate the audio to english", - type=click.Choice(["transcribe", "translate"])) -@click.option("--model", default="small", help="Model to use", type=click.Choice(available_models())) -@click.option("--language", default=None, - help="language spoken in the audio, specify None to perform language detection", - type=click.Choice(audioprocessor.whisper_get_languages_list_keys())) -@click.option("--condition_on_previous_text", default=False, - help="Feed it the previous result to keep it consistent across recognition windows, but makes it more prone to getting stuck in a failure loop", - is_flag=True, - type=bool) -@click.option("--energy", default=300, help="Energy level for mic to detect", type=int) -@click.option("--dynamic_energy", default=False, is_flag=True, help="Flag to enable dynamic engergy", type=bool) -@click.option("--pause", default=0.8, help="Pause time before entry ends", type=float) -@click.option("--phrase_time_limit", default=None, - help="phrase time limit before entry ends to break up long recognitions.", type=float) -@click.option("--osc_ip", default="127.0.0.1", help="IP to send OSC message to. Set to '0' to disable", type=str) -@click.option("--osc_port", default=9000, help="Port to send OSC message to. ('9000' as default for VRChat)", type=int) -@click.option("--osc_address", default="/chatbox/input", - help="The Address the OSC messages are send to. ('/chatbox/input' as default for VRChat)", type=str) -@click.option("--osc_convert_ascii", default='False', help="Convert Text to ASCII compatible when sending over OSC.", - type=str) -@click.option("--websocket_ip", default="0", help="IP where Websocket Server listens on. 
Set to '0' to disable", - type=str) -@click.option("--websocket_port", default=5000, help="Port where Websocket Server listens on. ('5000' as default)", - type=int) -@click.option("--ai_device", default=None, - help="The Device the AI is loaded on. can be 'cuda' or 'cpu'. default does autodetect", - type=click.Choice(["cuda", "cpu"])) -@click.option("--txt_translator", default="NLLB200", - help="The Model the AI is loading for text translations. can be 'NLLB200', 'M2M100' or 'None'. default is NLLB200", - type=click.Choice(["NLLB200", "M2M100"])) -@click.option("--txt_translator_size", default="small", - help="The Model size if M2M100 or NLLB200 text translator is used. can be 'small', 'medium' or 'large' for NLLB200 or 'small' or 'large' for M2M100. default is small.", - type=click.Choice(["small", "medium", "large"])) -@click.option("--txt_translator_device", default="auto", - help="The device used for text translation.", - type=click.Choice(["auto", "cuda", "cpu"])) -@click.option("--ocr_window_name", default="VRChat", - help="Window name of the application for OCR translations. (Default: 'VRChat')", type=str) -@click.option("--open_browser", default=False, - help="Open default Browser with websocket-remote on start. (requires --websocket_ip to be set as well)", - is_flag=True, type=bool) -@click.option("--config", default=None, - help="Use the specified config file instead of the default 'settings.yaml' (relative to the current path) [overwrites without asking!!!]", - type=str) -@click.option("--verbose", default=False, help="Whether to print verbose output", is_flag=True, type=bool) -@click.pass_context -def main(ctx, detect_energy, detect_energy_time, ui_download, devices, sample_rate, dynamic_energy, open_browser, - config, verbose, - **kwargs): - if str2bool(devices): - host_audio_api_names = get_host_audio_api_names() - audio = pyaudio.PyAudio() - # print all available host apis - print("-------------------------------------------------------------------") - print(" Host APIs ") - print("-------------------------------------------------------------------") - for i in range(audio.get_host_api_count()): - print(f"Host API {i}: {audio.get_host_api_info_by_index(i)['name']}") - print("") - print("-------------------------------------------------------------------") - print(" Input Devices ") - print(" In form of: DEVICE_NAME [Sample Rate=?] [Loopback?] 
(Index=INDEX) ") - print("-------------------------------------------------------------------") - for device in audio.get_device_info_generator(): - device_list_index = device["index"] - device_list_api = host_audio_api_names[device["hostApi"]] - device_list_name = device["name"] - device_list_sample_rate = int(device["defaultSampleRate"]) - device_list_max_channels = audio.get_device_info_by_index(device_list_index)['maxInputChannels'] - if device_list_max_channels >= 1: - print( - f"{device_list_name} [Sample Rate={device_list_sample_rate}, API={device_list_api}] (Index={device_list_index})") - print("") - print("-------------------------------------------------------------------") - print(" Output Devices ") - print("-------------------------------------------------------------------") - for device in audio.get_device_info_generator(): - device_list_index = device["index"] - device_list_api = host_audio_api_names[device["hostApi"]] - device_list_name = device["name"] - device_list_sample_rate = int(device["defaultSampleRate"]) - device_list_max_channels = audio.get_device_info_by_index(device_list_index)['maxOutputChannels'] - if device_list_max_channels >= 1: - print( - f"{device_list_name} [Sample Rate={device_list_sample_rate}, API={device_list_api}] (Index={device_list_index})") - return - - # is set to run energy detection - if detect_energy: - # get selected audio api - audio_api = "MME" - if settings.IsArgumentSetting(ctx, "audio_api"): - audio_api = ctx.params["audio_api"] - audio_api_index, audio_api_name = get_audio_api_index_by_name(audio_api) + device_index = get_audio_device_index_by_name_and_api(audio_input_device, audio_api_index, True, + device_index) + settings.SetOption("device_index", device_index) + + audio_output_device = settings.GetOption("audio_output_device") + if audio_output_device is not None and audio_output_device != "": + if audio_output_device.lower() == "Default".lower(): + device_out_index = None + else: + device_out_index = get_audio_device_index_by_name_and_api(audio_output_device, audio_api_index, False, + device_out_index) + settings.SetOption("device_out_index", device_out_index) - # get selected audio input device - device_index = None - if settings.IsArgumentSetting(ctx, "device_index"): - device_index = ctx.params["device_index"] + # set default devices: device_default_in_index = get_default_audio_device_index_by_api(audio_api, True) + device_default_out_index = get_default_audio_device_index_by_api(audio_api, False) + settings.SetOption("device_default_in_index", device_default_in_index) + settings.SetOption("device_default_out_index", device_default_out_index) - # get selected audio input device by name if possible - if settings.IsArgumentSetting(ctx, "audio_input_device"): - audio_input_device = ctx.params["audio_input_device"] - if audio_input_device is not None and audio_input_device != "": - if audio_input_device.lower() == "Default".lower(): - device_index = None - else: - device_index = get_audio_device_index_by_name_and_api(audio_input_device, audio_api_index, True, - device_index) - if device_index is None or device_index < 0: - device_index = device_default_in_index - - max_detected_energy = record_highest_peak_amplitude(device_index, detect_energy_time) - print("detected_energy: " + str(max_detected_energy)) - return + settings.SetOption("condition_on_previous_text", + settings.GetArgumentSettingFallback(ctx, "condition_on_previous_text", + "condition_on_previous_text")) + model = settings.SetOption("model", 
settings.GetArgumentSettingFallback(ctx, "model", "model")) - # Load settings from file - if config is not None: - settings.SETTINGS_PATH = Path(Path.cwd() / config) - settings.LoadYaml(settings.SETTINGS_PATH) + language = settings.SetOption("current_language", + settings.GetArgumentSettingFallback(ctx, "language", "current_language")) - # set process id - settings.SetOption("process_id", os.getpid()) + settings.SetOption("phrase_time_limit", settings.GetArgumentSettingFallback(ctx, "phrase_time_limit", + "phrase_time_limit")) - settings.SetOption("ui_download", ui_download) + pause = settings.SetOption("pause", settings.GetArgumentSettingFallback(ctx, "pause", "pause")) - # enable stt by default - settings.SetOption("stt_enabled", True) + energy = settings.SetOption("energy", settings.GetArgumentSettingFallback(ctx, "energy", "energy")) - # set initial settings - settings.SetOption("whisper_task", settings.GetArgumentSettingFallback(ctx, "task", "whisper_task")) + print("###################################") + print("# Whispering Tiger is starting... #") + print("###################################") - # set audio settings - device_index = settings.GetArgumentSettingFallback(ctx, "device_index", "device_index") - settings.SetOption("device_index", - (device_index if device_index is None or device_index > -1 else None)) - device_out_index = settings.GetArgumentSettingFallback(ctx, "device_out_index", "device_out_index") - settings.SetOption("device_out_index", - (device_out_index if device_out_index is None or device_out_index > -1 else None)) - - audio_api = settings.SetOption("audio_api", settings.GetArgumentSettingFallback(ctx, "audio_api", "audio_api")) - audio_api_index, audio_api_name = get_audio_api_index_by_name(audio_api) + print("running Python: " + platform.python_implementation() + " / v" + platform.python_version()) + print("using Audio API: " + audio_api_name) + print("") - audio_input_device = settings.GetOption("audio_input_device") - if audio_input_device is not None and audio_input_device != "": - if audio_input_device.lower() == "Default".lower(): - device_index = None + # check if english only model is loaded, and configure STT languages accordingly. + if model.endswith(".en") and "_whisper" in settings.GetOption("stt_type"): + if language is not None and language not in {"en", "English"}: + print(f"{model} is an English-only model but received '{language}' as language; using English instead.") + + print(f"{model} is an English-only model. only English speech is supported.") + settings.SetOption("whisper_languages", ({"code": "", "name": "Auto"}, {"code": "en", "name": "English"},)) + settings.SetOption("current_language", "en") + elif "_whisper" in settings.GetOption("stt_type") or "whisper_" in settings.GetOption("stt_type"): + settings.SetOption("whisper_languages", audioprocessor.whisper_get_languages()) + elif settings.GetOption("stt_type") == "seamless_m4t": + settings.SetOption("whisper_languages", audioprocessor.seamless_m4t_get_languages()) + elif settings.GetOption("stt_type") == "speech_t5": + # speech t5 only supports english + print(f"speechT5 is an English-only model. 
only English speech is supported.") + settings.SetOption("whisper_languages", ({"code": "", "name": "Auto"}, {"code": "en", "name": "English"},)) + settings.SetOption("current_language", "en") + elif settings.GetOption("stt_type") == "wav2vec_bert": + settings.SetOption("whisper_languages", audioprocessor.wav2vec_bert_get_languages()) + elif settings.GetOption("stt_type") == "nemo_canary": + settings.SetOption("whisper_languages", audioprocessor.nemo_canary_get_languages()) else: - device_index = get_audio_device_index_by_name_and_api(audio_input_device, audio_api_index, True, - device_index) - settings.SetOption("device_index", device_index) - - audio_output_device = settings.GetOption("audio_output_device") - if audio_output_device is not None and audio_output_device != "": - if audio_output_device.lower() == "Default".lower(): - device_out_index = None - else: - device_out_index = get_audio_device_index_by_name_and_api(audio_output_device, audio_api_index, False, - device_out_index) - settings.SetOption("device_out_index", device_out_index) - - # set default devices: - device_default_in_index = get_default_audio_device_index_by_api(audio_api, True) - device_default_out_index = get_default_audio_device_index_by_api(audio_api, False) - settings.SetOption("device_default_in_index", device_default_in_index) - settings.SetOption("device_default_out_index", device_default_out_index) - - settings.SetOption("condition_on_previous_text", - settings.GetArgumentSettingFallback(ctx, "condition_on_previous_text", - "condition_on_previous_text")) - model = settings.SetOption("model", settings.GetArgumentSettingFallback(ctx, "model", "model")) - - language = settings.SetOption("current_language", - settings.GetArgumentSettingFallback(ctx, "language", "current_language")) - - settings.SetOption("phrase_time_limit", settings.GetArgumentSettingFallback(ctx, "phrase_time_limit", - "phrase_time_limit")) - - pause = settings.SetOption("pause", settings.GetArgumentSettingFallback(ctx, "pause", "pause")) - - energy = settings.SetOption("energy", settings.GetArgumentSettingFallback(ctx, "energy", "energy")) - - print("###################################") - print("# Whispering Tiger is starting... #") - print("###################################") - - print("running Python: " + platform.python_implementation() + " / v" + platform.python_version()) - print("using Audio API: " + audio_api_name) - print("") - - # check if english only model is loaded, and configure STT languages accordingly. - if model.endswith(".en") and "_whisper" in settings.GetOption("stt_type"): - if language is not None and language not in {"en", "English"}: - print(f"{model} is an English-only model but received '{language}' as language; using English instead.") - - print(f"{model} is an English-only model. only English speech is supported.") - settings.SetOption("whisper_languages", ({"code": "", "name": "Auto"}, {"code": "en", "name": "English"},)) - settings.SetOption("current_language", "en") - elif "_whisper" in settings.GetOption("stt_type") or "whisper_" in settings.GetOption("stt_type"): - settings.SetOption("whisper_languages", audioprocessor.whisper_get_languages()) - elif settings.GetOption("stt_type") == "seamless_m4t": - settings.SetOption("whisper_languages", audioprocessor.seamless_m4t_get_languages()) - elif settings.GetOption("stt_type") == "speech_t5": - # speech t5 only supports english - print(f"speechT5 is an English-only model. 
only English speech is supported.") - settings.SetOption("whisper_languages", ({"code": "", "name": "Auto"}, {"code": "en", "name": "English"},)) - settings.SetOption("current_language", "en") - else: - # show no language if unspecified STT type - settings.SetOption("whisper_languages", ({"code": "", "name": ""},)) - - settings.SetOption("ai_device", settings.GetArgumentSettingFallback(ctx, "ai_device", "ai_device")) - settings.SetOption("verbose", verbose) - - osc_ip = settings.SetOption("osc_ip", settings.GetArgumentSettingFallback(ctx, "osc_ip", "osc_ip")) - osc_port = settings.SetOption("osc_port", settings.GetArgumentSettingFallback(ctx, "osc_port", "osc_port")) - settings.SetOption("osc_address", settings.GetArgumentSettingFallback(ctx, "osc_address", "osc_address")) - settings.SetOption("osc_convert_ascii", - str2bool(settings.GetArgumentSettingFallback(ctx, "osc_convert_ascii", "osc_convert_ascii"))) - osc_min_time_between_messages = settings.SetOption("osc_min_time_between_messages", settings.GetArgumentSettingFallback(ctx, "osc_min_time_between_messages", "osc_min_time_between_messages")) - VRC_OSCLib.set_min_time_between_messages(osc_min_time_between_messages) - - websocket_ip = settings.SetOption("websocket_ip", - settings.GetArgumentSettingFallback(ctx, "websocket_ip", "websocket_ip")) - websocket_port = settings.SetOption("websocket_port", - settings.GetArgumentSettingFallback(ctx, "websocket_port", "websocket_port")) - - txt_translator = settings.SetOption("txt_translator", - settings.GetArgumentSettingFallback(ctx, "txt_translator", "txt_translator")) - settings.SetOption("txt_translator_size", - settings.GetArgumentSettingFallback(ctx, "txt_translator_size", "txt_translator_size")) - - txt_translator_device = settings.SetOption("txt_translator_device", - settings.GetArgumentSettingFallback(ctx, "txt_translator_device", - "txt_translator_device")) - texttranslate.SetDevice(txt_translator_device) - - settings.SetOption("ocr_window_name", - settings.GetArgumentSettingFallback(ctx, "ocr_window_name", "ocr_window_name")) - - if websocket_ip != "0": - websocket.StartWebsocketServer(websocket_ip, websocket_port) - if open_browser: - open_url = 'file://' + os.getcwd() + '/websocket_clients/websocket-remote/index.html' + '?ws_server=ws://' + ( - "127.0.0.1" if websocket_ip == "0.0.0.0" else websocket_ip) + ':' + str(websocket_port) - remote_opener.openBrowser(open_url) - - if websocket_ip == "0" and open_browser: - print("--open_browser flag requres --websocket_ip to be set.") - - # initialize Silero TTS - try: - silero.init() - except Exception as e: - print(e) - - if ui_download: - # wait until ui is connected - print("waiting for ui to connect...") - max_wait = 15 # wait max 15 seconds for ui to connect - last_wait_time = time.time() - while len(websocket.WS_CLIENTS) == 0 and websocket.UI_CONNECTED["value"] is False: - time.sleep(0.1) - if time.time() - last_wait_time > max_wait: - print("timeout while waiting for ui to connect.") - ui_download = False - settings.SetOption("ui_download", ui_download) - break - if ui_download: # still true? 
then ui did connect - print("ui connected.") - time.sleep(0.5) - - # initialize plugins - import Plugins - print("initializing plugins...") - for plugin_inst in Plugins.plugins: - plugin_inst.init() - if plugin_inst.is_enabled(False): - print(plugin_inst.__class__.__name__ + " is enabled") - else: - print(plugin_inst.__class__.__name__ + " is disabled") - - # Load textual translation dependencies - if txt_translator.lower() != "none" and txt_translator != "": - websocket.set_loading_state("txt_transl_loading", True) + # show no language if unspecified STT type + settings.SetOption("whisper_languages", ({"code": "", "name": ""},)) + + settings.SetOption("ai_device", settings.GetArgumentSettingFallback(ctx, "ai_device", "ai_device")) + settings.SetOption("verbose", verbose) + + osc_ip = settings.SetOption("osc_ip", settings.GetArgumentSettingFallback(ctx, "osc_ip", "osc_ip")) + osc_port = settings.SetOption("osc_port", settings.GetArgumentSettingFallback(ctx, "osc_port", "osc_port")) + settings.SetOption("osc_address", settings.GetArgumentSettingFallback(ctx, "osc_address", "osc_address")) + settings.SetOption("osc_convert_ascii", + str2bool(settings.GetArgumentSettingFallback(ctx, "osc_convert_ascii", "osc_convert_ascii"))) + osc_min_time_between_messages = settings.SetOption("osc_min_time_between_messages", settings.GetArgumentSettingFallback(ctx, "osc_min_time_between_messages", "osc_min_time_between_messages")) + VRC_OSCLib.set_min_time_between_messages(osc_min_time_between_messages) + + websocket_ip = settings.SetOption("websocket_ip", + settings.GetArgumentSettingFallback(ctx, "websocket_ip", "websocket_ip")) + websocket_port = settings.SetOption("websocket_port", + settings.GetArgumentSettingFallback(ctx, "websocket_port", "websocket_port")) + + txt_translator = settings.SetOption("txt_translator", + settings.GetArgumentSettingFallback(ctx, "txt_translator", "txt_translator")) + settings.SetOption("txt_translator_size", + settings.GetArgumentSettingFallback(ctx, "txt_translator_size", "txt_translator_size")) + + txt_translator_device = settings.SetOption("txt_translator_device", + settings.GetArgumentSettingFallback(ctx, "txt_translator_device", + "txt_translator_device")) + texttranslate.SetDevice(txt_translator_device) + + settings.SetOption("ocr_window_name", + settings.GetArgumentSettingFallback(ctx, "ocr_window_name", "ocr_window_name")) + + if websocket_ip != "0": + websocket.StartWebsocketServer(websocket_ip, websocket_port) + if open_browser: + open_url = 'file://' + os.getcwd() + '/websocket_clients/websocket-remote/index.html' + '?ws_server=ws://' + ( + "127.0.0.1" if websocket_ip == "0.0.0.0" else websocket_ip) + ':' + str(websocket_port) + remote_opener.openBrowser(open_url) + + if websocket_ip == "0" and open_browser: + print("--open_browser flag requres --websocket_ip to be set.") + + # initialize Silero TTS try: - texttranslate.InstallLanguages() + silero.init() except Exception as e: print(e) - pass - websocket.set_loading_state("txt_transl_loading", False) - - # load nltk sentence splitting dependency - sentence_split.load_model() - - # Load language identification dependencies - languageClassification.download_model() - - # Download faster-whisper model - if settings.GetOption("stt_type") == "faster_whisper": - whisper_model = settings.GetOption("model") - whisper_precision = settings.GetOption("whisper_precision") - realtime_whisper_model = settings.GetOption("realtime_whisper_model") - realtime_whisper_precision = settings.GetOption("realtime_whisper_precision") - 
# download the model here since its only possible in the main thread - if faster_whisper.needs_download(whisper_model, whisper_precision): - websocket.set_loading_state("downloading_whisper_model", True) - faster_whisper.download_model(whisper_model, whisper_precision) - websocket.set_loading_state("downloading_whisper_model", False) - # download possibly needed realtime model - if realtime_whisper_model != "" and faster_whisper.needs_download(realtime_whisper_model, - realtime_whisper_precision): - websocket.set_loading_state("downloading_whisper_model", True) - faster_whisper.download_model(realtime_whisper_model, realtime_whisper_precision) - websocket.set_loading_state("downloading_whisper_model", False) - if settings.GetOption("stt_type") == "seamless_m4t": - stt_model_size = settings.GetOption("model") - if seamless_m4t.SeamlessM4T.needs_download(stt_model_size): - websocket.set_loading_state("downloading_whisper_model", True) - seamless_m4t.SeamlessM4T.download_model(stt_model_size) - websocket.set_loading_state("downloading_whisper_model", False) - - # load audio filter model - audio_enhancer = None - if settings.GetOption("denoise_audio"): - websocket.set_loading_state("loading_denoiser", True) - post_filter = settings.GetOption("denoise_audio_post_filter") - audio_enhancer = DeepFilterNet.DeepFilterNet(post_filter=post_filter) - websocket.set_loading_state("loading_denoiser", False) - - # prepare the plugin timer calls - call_plugin_timer(Plugins) - - vad_enabled = settings.SetOption("vad_enabled", - settings.GetArgumentSettingFallback(ctx, "vad_enabled", "vad_enabled")) - try: - vad_thread_num = int(float(settings.SetOption("vad_thread_num", - settings.GetArgumentSettingFallback(ctx, "vad_thread_num", "vad_thread_num")))) - except ValueError as e: - print("Error assigning vad_thread_num. using 1") - print(e) - vad_thread_num = int(1) - - if vad_enabled: - torch.hub.set_dir(str(Path(cache_vad_path).resolve())) - torch.set_num_threads(vad_thread_num) - try: - vad_model, vad_utils = torch.hub.load(trust_repo=True, skip_validation=True, - repo_or_dir="snakers4/silero-vad", model="silero_vad", onnx=False - ) - except: + + if ui_download: + # wait until ui is connected + print("waiting for ui to connect...") + max_wait = 15 # wait max 15 seconds for ui to connect + last_wait_time = time.time() + while len(websocket.WS_CLIENTS) == 0 and websocket.UI_CONNECTED["value"] is False: + time.sleep(0.1) + if time.time() - last_wait_time > max_wait: + print("timeout while waiting for ui to connect.") + ui_download = False + settings.SetOption("ui_download", ui_download) + break + if ui_download: # still true? 
then ui did connect + print("ui connected.") + time.sleep(0.5) + + # initialize plugins + import Plugins + print("initializing plugins...") + for plugin_inst in Plugins.plugins: + plugin_inst.init() + if plugin_inst.is_enabled(False): + print(plugin_inst.__class__.__name__ + " is enabled") + else: + print(plugin_inst.__class__.__name__ + " is disabled") + + # Load textual translation dependencies + if txt_translator.lower() != "none" and txt_translator != "": + websocket.set_loading_state("txt_transl_loading", True) try: - vad_model, vad_utils = torch.hub.load(trust_repo=True, skip_validation=True, - source="local", model="silero_vad", onnx=False, - repo_or_dir=str(Path( - cache_vad_path / "snakers4_silero-vad_master").resolve()) - ) + texttranslate.InstallLanguages() except Exception as e: - print("Error loading vad model trying to load from fallback server...") print(e) + pass + websocket.set_loading_state("txt_transl_loading", False) + + # load nltk sentence splitting dependency + sentence_split.load_model() + + # Load language identification dependencies + languageClassification.download_model() + + # Download faster-whisper model + if settings.GetOption("stt_type") == "faster_whisper": + whisper_model = settings.GetOption("model") + whisper_precision = settings.GetOption("whisper_precision") + realtime_whisper_model = settings.GetOption("realtime_whisper_model") + realtime_whisper_precision = settings.GetOption("realtime_whisper_precision") + # download the model here since its only possible in the main thread + if faster_whisper.needs_download(whisper_model, whisper_precision): + websocket.set_loading_state("downloading_whisper_model", True) + faster_whisper.download_model(whisper_model, whisper_precision) + websocket.set_loading_state("downloading_whisper_model", False) + # download possibly needed realtime model + if realtime_whisper_model != "" and faster_whisper.needs_download(realtime_whisper_model, + realtime_whisper_precision): + websocket.set_loading_state("downloading_whisper_model", True) + faster_whisper.download_model(realtime_whisper_model, realtime_whisper_precision) + websocket.set_loading_state("downloading_whisper_model", False) + if settings.GetOption("stt_type") == "seamless_m4t": + stt_model_size = settings.GetOption("model") + if seamless_m4t.SeamlessM4T.needs_download(stt_model_size): + websocket.set_loading_state("downloading_whisper_model", True) + seamless_m4t.SeamlessM4T.download_model(stt_model_size) + websocket.set_loading_state("downloading_whisper_model", False) + + # load audio filter model + audio_enhancer = None + if settings.GetOption("denoise_audio"): + websocket.set_loading_state("loading_denoiser", True) + post_filter = settings.GetOption("denoise_audio_post_filter") + audio_enhancer = DeepFilterNet.DeepFilterNet(post_filter=post_filter) + websocket.set_loading_state("loading_denoiser", False) + + # prepare the plugin timer calls + call_plugin_timer(Plugins) + + vad_enabled = settings.SetOption("vad_enabled", + settings.GetArgumentSettingFallback(ctx, "vad_enabled", "vad_enabled")) + try: + vad_thread_num = int(float(settings.SetOption("vad_thread_num", + settings.GetArgumentSettingFallback(ctx, "vad_thread_num", "vad_thread_num")))) + except ValueError as e: + print("Error assigning vad_thread_num. 
using 1") + print(e) + vad_thread_num = int(1) - vad_fallback_server = { - "urls": [ - "https://eu2.contabostorage.com/bf1a89517e2643359087e5d8219c0c67:ai-models/silero/silero-vad.zip", - "https://usc1.contabostorage.com/8fcf133c506f4e688c7ab9ad537b5c18:ai-models/silero/silero-vad.zip", - "https://s3.libs.space:9000/ai-models/silero/silero-vad.zip" - ], - "sha256": "097cfacdc2b2f5b09e0da1273b3e30b0e96c3588445958171a7e339cc5805683", - } - + if vad_enabled: + torch.hub.set_dir(str(Path(cache_vad_path).resolve())) + torch.set_num_threads(vad_thread_num) + try: + vad_model, vad_utils = torch.hub.load(trust_repo=True, skip_validation=True, + repo_or_dir="snakers4/silero-vad", model="silero_vad", onnx=False + ) + except: try: - downloader.download_extract(vad_fallback_server["urls"], - str(Path(cache_vad_path).resolve()), - vad_fallback_server["sha256"], - alt_fallback=True, - fallback_extract_func=downloader.extract_zip, - fallback_extract_func_args=( - str(Path(cache_vad_path / "silero-vad.zip")), - str(Path(cache_vad_path).resolve()), - ), - title="Silero VAD", extract_format="zip") - vad_model, vad_utils = torch.hub.load(trust_repo=True, skip_validation=True, source="local", model="silero_vad", onnx=False, repo_or_dir=str(Path( cache_vad_path / "snakers4_silero-vad_master").resolve()) ) - except Exception as e: - print("Error loading vad model.") + print("Error loading vad model trying to load from fallback server...") print(e) - # num_samples = 1536 - vad_frames_per_buffer = int(settings.SetOption("vad_frames_per_buffer", - settings.GetArgumentSettingFallback(ctx, "vad_frames_per_buffer", - "vad_frames_per_buffer"))) - - # set default devices if not set - if device_index is None or device_index < 0: - device_index = device_default_in_index - - #frames = [] - - default_sample_rate = SAMPLE_RATE - - previous_audio_chunk = None - - start_rec_on_volume_threshold = False - - push_to_talk_key = settings.GetOption("push_to_talk_key") - if push_to_talk_key == "": - push_to_talk_key = None - keyboard_rec_force_stop = False - - processor = AudioProcessor( - default_sample_rate=default_sample_rate, - previous_audio_chunk=previous_audio_chunk, - start_rec_on_volume_threshold=start_rec_on_volume_threshold, - push_to_talk_key=push_to_talk_key, - keyboard_rec_force_stop=keyboard_rec_force_stop, - vad_model=vad_model, - plugins=Plugins, - audio_enhancer=audio_enhancer, - osc_ip=osc_ip, - osc_port=osc_port, - chunk=vad_frames_per_buffer, - channels=CHANNELS, - sample_format=FORMAT, - verbose=verbose, - ) - - # initialize audio stream - stream, needs_sample_rate_conversion, recorded_sample_rate, is_mono = audio_tools.start_recording_audio_stream( - device_index, - sample_format=FORMAT, - sample_rate=SAMPLE_RATE, - channels=CHANNELS, - chunk=vad_frames_per_buffer, - py_audio=py_audio, - audio_processor=processor, - ) - - # Start the stream - stream.start_stream() - - #orig_recorded_sample_rate = recorded_sample_rate - - audioprocessor.start_whisper_thread() - - #continue_recording = True - - while stream.is_active(): - time.sleep(0.1) - #if not settings.GetOption("stt_enabled"): - # time.sleep(0.1) - # continue - - else: - # load the speech recognizer and set the initial energy threshold and pause threshold - r = sr.Recognizer() - r.energy_threshold = energy - r.pause_threshold = pause - r.dynamic_energy_threshold = dynamic_energy - - with sr.Microphone(sample_rate=whisper_audio.SAMPLE_RATE, - device_index=device_index) as source: + vad_fallback_server = { + "urls": [ + 
"https://eu2.contabostorage.com/bf1a89517e2643359087e5d8219c0c67:ai-models/silero/silero-vad.zip", + "https://usc1.contabostorage.com/8fcf133c506f4e688c7ab9ad537b5c18:ai-models/silero/silero-vad.zip", + "https://s3.libs.space:9000/ai-models/silero/silero-vad.zip" + ], + "sha256": "097cfacdc2b2f5b09e0da1273b3e30b0e96c3588445958171a7e339cc5805683", + } + + try: + downloader.download_extract(vad_fallback_server["urls"], + str(Path(cache_vad_path).resolve()), + vad_fallback_server["sha256"], + alt_fallback=True, + fallback_extract_func=downloader.extract_zip, + fallback_extract_func_args=( + str(Path(cache_vad_path / "silero-vad.zip")), + str(Path(cache_vad_path).resolve()), + ), + title="Silero VAD", extract_format="zip") + + vad_model, vad_utils = torch.hub.load(trust_repo=True, skip_validation=True, + source="local", model="silero_vad", onnx=False, + repo_or_dir=str(Path( + cache_vad_path / "snakers4_silero-vad_master").resolve()) + ) + + except Exception as e: + print("Error loading vad model.") + print(e) + + # num_samples = 1536 + vad_frames_per_buffer = int(settings.SetOption("vad_frames_per_buffer", + settings.GetArgumentSettingFallback(ctx, "vad_frames_per_buffer", + "vad_frames_per_buffer"))) + + # set default devices if not set + if device_index is None or device_index < 0: + device_index = device_default_in_index + + #frames = [] + + default_sample_rate = SAMPLE_RATE + + previous_audio_chunk = None + + start_rec_on_volume_threshold = False + + push_to_talk_key = settings.GetOption("push_to_talk_key") + if push_to_talk_key == "": + push_to_talk_key = None + keyboard_rec_force_stop = False + + processor = AudioProcessor( + default_sample_rate=default_sample_rate, + previous_audio_chunk=previous_audio_chunk, + start_rec_on_volume_threshold=start_rec_on_volume_threshold, + push_to_talk_key=push_to_talk_key, + keyboard_rec_force_stop=keyboard_rec_force_stop, + vad_model=vad_model, + plugins=Plugins, + audio_enhancer=audio_enhancer, + osc_ip=osc_ip, + osc_port=osc_port, + chunk=vad_frames_per_buffer, + channels=CHANNELS, + sample_format=FORMAT, + verbose=verbose, + ) + + # initialize audio stream + stream, needs_sample_rate_conversion, recorded_sample_rate, is_mono = audio_tools.start_recording_audio_stream( + device_index, + sample_format=FORMAT, + sample_rate=SAMPLE_RATE, + channels=CHANNELS, + chunk=vad_frames_per_buffer, + py_audio=py_audio, + audio_processor=processor, + ) + + # Start the stream + stream.start_stream() + + #orig_recorded_sample_rate = recorded_sample_rate audioprocessor.start_whisper_thread() - while True: - if not settings.GetOption("stt_enabled"): - time.sleep(0.1) - continue - - phrase_time_limit = settings.GetOption("phrase_time_limit") - if phrase_time_limit == 0: - phrase_time_limit = None - pause = settings.GetOption("pause") - energy = settings.GetOption("energy") + #continue_recording = True - r.energy_threshold = energy - r.pause_threshold = pause + while stream.is_active(): + time.sleep(0.1) + #if not settings.GetOption("stt_enabled"): + # time.sleep(0.1) + # continue - # get and save audio to wav file - audio = r.listen(source, phrase_time_limit=phrase_time_limit) - - audio_data = audio.get_wav_data() - - silence_cutting_enabled = settings.GetOption("silence_cutting_enabled") - silence_offset = settings.GetOption("silence_offset") - max_silence_length = settings.GetOption("max_silence_length") - keep_silence_length = settings.GetOption("keep_silence_length") - - normalize_enabled = settings.GetOption("normalize_enabled") - normalize_lower_threshold = 
settings.GetOption("normalize_lower_threshold") - normalize_upper_threshold = settings.GetOption("normalize_upper_threshold") - normalize_gain_factor = settings.GetOption("normalize_gain_factor") - block_size_samples = int(whisper_audio.SAMPLE_RATE * 0.400) - # normalize audio (and make sure it's longer or equal the default block size by pyloudnorm) - if normalize_enabled and len(audio_data) >= block_size_samples: - audio_data = audio_tools.convert_audio_datatype_to_float(np.frombuffer(audio_data, np.int16)) - audio_data, lufs = audio_tools.normalize_audio_lufs( - audio_data, whisper_audio.SAMPLE_RATE, normalize_lower_threshold, normalize_upper_threshold, - normalize_gain_factor, verbose=verbose - ) - audio_data = audio_tools.convert_audio_datatype_to_integer(audio_data, np.int16) - audio_data = audio_data.tobytes() - - # remove silence from audio - if silence_cutting_enabled: - audio_data_np = np.frombuffer(audio_data, np.int16) - if len(audio_data_np) >= block_size_samples: - audio_data = audio_tools.remove_silence_parts( - audio_data_np, whisper_audio.SAMPLE_RATE, - silence_offset=silence_offset, max_silence_length=max_silence_length, keep_silence_length=keep_silence_length, - verbose=verbose + else: + # load the speech recognizer and set the initial energy threshold and pause threshold + r = sr.Recognizer() + r.energy_threshold = energy + r.pause_threshold = pause + r.dynamic_energy_threshold = dynamic_energy + + with sr.Microphone(sample_rate=whisper_audio.SAMPLE_RATE, + device_index=device_index) as source: + + audioprocessor.start_whisper_thread() + + while True: + if not settings.GetOption("stt_enabled"): + time.sleep(0.1) + continue + + phrase_time_limit = settings.GetOption("phrase_time_limit") + if phrase_time_limit == 0: + phrase_time_limit = None + pause = settings.GetOption("pause") + energy = settings.GetOption("energy") + + r.energy_threshold = energy + r.pause_threshold = pause + + # get and save audio to wav file + audio = r.listen(source, phrase_time_limit=phrase_time_limit) + + audio_data = audio.get_wav_data() + + silence_cutting_enabled = settings.GetOption("silence_cutting_enabled") + silence_offset = settings.GetOption("silence_offset") + max_silence_length = settings.GetOption("max_silence_length") + keep_silence_length = settings.GetOption("keep_silence_length") + + normalize_enabled = settings.GetOption("normalize_enabled") + normalize_lower_threshold = settings.GetOption("normalize_lower_threshold") + normalize_upper_threshold = settings.GetOption("normalize_upper_threshold") + normalize_gain_factor = settings.GetOption("normalize_gain_factor") + block_size_samples = int(whisper_audio.SAMPLE_RATE * 0.400) + # normalize audio (and make sure it's longer or equal the default block size by pyloudnorm) + if normalize_enabled and len(audio_data) >= block_size_samples: + audio_data = audio_tools.convert_audio_datatype_to_float(np.frombuffer(audio_data, np.int16)) + audio_data, lufs = audio_tools.normalize_audio_lufs( + audio_data, whisper_audio.SAMPLE_RATE, normalize_lower_threshold, normalize_upper_threshold, + normalize_gain_factor, verbose=verbose ) + audio_data = audio_tools.convert_audio_datatype_to_integer(audio_data, np.int16) audio_data = audio_data.tobytes() - # denoise audio - if settings.GetOption("denoise_audio") and audio_enhancer is not None: - audio_data = audio_enhancer.enhance_audio(audio_data).tobytes() + # remove silence from audio + if silence_cutting_enabled: + audio_data_np = np.frombuffer(audio_data, np.int16) + if len(audio_data_np) >= 
block_size_samples: + audio_data = audio_tools.remove_silence_parts( + audio_data_np, whisper_audio.SAMPLE_RATE, + silence_offset=silence_offset, max_silence_length=max_silence_length, keep_silence_length=keep_silence_length, + verbose=verbose + ) + audio_data = audio_data.tobytes() - # add audio data to the queue - audioprocessor.q.put({'time': time.time_ns(), 'data': audio_bytes_to_wav(audio_data), 'final': True}) + # denoise audio + if settings.GetOption("denoise_audio") and audio_enhancer is not None: + audio_data = audio_enhancer.enhance_audio(audio_data).tobytes() - # set typing indicator for VRChat and websocket clients - typing_indicator_thread = threading.Thread(target=typing_indicator_function, - args=(osc_ip, osc_port, True)) - typing_indicator_thread.start() + # add audio data to the queue + audioprocessor.q.put({'time': time.time_ns(), 'data': audio_bytes_to_wav(audio_data), 'final': True}) + # set typing indicator for VRChat and websocket clients + typing_indicator_thread = threading.Thread(target=typing_indicator_function, + args=(osc_ip, osc_port, True)) + typing_indicator_thread.start() -def str2bool(string): - if type(string) == str: - str2val = {"true": True, "false": False} - if string.lower() in str2val: - return str2val[string.lower()] + + def str2bool(string): + if type(string) == str: + str2val = {"true": True, "false": False} + if string.lower() in str2val: + return str2val[string.lower()] + else: + raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") else: - raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") - else: - return bool(string) + return bool(string) -main() + #freeze_support() + main() diff --git a/audioWhisper.spec b/audioWhisper.spec index 89a0e77..0fa7ee6 100644 --- a/audioWhisper.spec +++ b/audioWhisper.spec @@ -7,7 +7,7 @@ import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5) datas = [] binaries = [] -hiddenimports = ['torch', 'pytorch', 'torchaudio.lib.libtorchaudio', 'scipy.signal', 'transformers.models.nllb', 'sentencepiece', 'df.deepfilternet3', 'bitsandbytes', 'faiss-cpu', 'praat-parselmouth', 'parselmouth', 'pyworld', 'torchcrepe', 'grpcio', 'grpc'] +hiddenimports = ['torch', 'pytorch', 'torchaudio.lib.libtorchaudio', 'scipy.signal', 'transformers.models.nllb', 'sentencepiece', 'df.deepfilternet3', 'bitsandbytes', 'faiss-cpu', 'praat-parselmouth', 'parselmouth', 'pyworld', 'torchcrepe', 'grpcio', 'grpc', 'annotated_types', 'Cython', 'nemo_toolkit', 'nemo'] datas += collect_data_files('torch') datas += collect_data_files('whisper') datas += collect_data_files('pykakasi') @@ -72,6 +72,14 @@ tmp_ret = collect_all('grpcio') datas += tmp_ret[0]; binaries += tmp_ret[1]; hiddenimports += tmp_ret[2] tmp_ret = collect_all('grpc') datas += tmp_ret[0]; binaries += tmp_ret[1]; hiddenimports += tmp_ret[2] +tmp_ret = collect_all('annotated_types') +datas += tmp_ret[0]; binaries += tmp_ret[1]; hiddenimports += tmp_ret[2] +tmp_ret = collect_all('Cython') +datas += tmp_ret[0]; binaries += tmp_ret[1]; hiddenimports += tmp_ret[2] +tmp_ret = collect_all('nemo_toolkit') +datas += tmp_ret[0]; binaries += tmp_ret[1]; hiddenimports += tmp_ret[2] +tmp_ret = collect_all('nemo') +datas += tmp_ret[0]; binaries += tmp_ret[1]; hiddenimports += tmp_ret[2] workdir = os.environ.get('WORKDIR_WIN', r'\drone\src') workdir = "C:" + workdir # Now workdir = "C:\drone\src" diff --git a/audioprocessor.py b/audioprocessor.py index 5c51a43..9156571 100644 --- a/audioprocessor.py +++ b/audioprocessor.py @@ -24,6 +24,9 @@ import 
Models.STT.faster_whisper as faster_whisper import Models.STT.whisper_audio_markers as whisper_audio_markers import Models.STT.speecht5 as speech_t5 +import Models.STT.tansformer_whisper as transformer_whisper +import Models.STT.wav2vec_bert as wav2vec_bert +import Models.STT.nemo_canary as nemo_canary import Models.Multi.seamless_m4t as seamless_m4t import csv from datetime import datetime @@ -134,6 +137,15 @@ def seamless_m4t_get_languages(): return tuple([{"code": code, "name": language} for code, language in languages.items()]) +def wav2vec_bert_get_languages(): + wav2vec_bert_model = wav2vec_bert.Wav2VecBert() + return wav2vec_bert_model.get_languages() + + +def nemo_canary_get_languages(): + return nemo_canary.NemoCanary.get_languages() + + def remove_repetitions(text, language='english'): do_txt_translate = settings.GetOption("txt_translate") src_lang = settings.GetOption("src_lang") @@ -402,6 +414,24 @@ def load_whisper(model, ai_device): return speech_t5.SpeechT5STT(device=ai_device) except Exception as e: print("Failed to load speech t5 model. Application exits. " + str(e)) + elif stt_type == "transformer_whisper": + compute_dtype = settings.GetOption("whisper_precision") + try: + return transformer_whisper.TransformerWhisper(compute_type=compute_dtype, device=ai_device) + except Exception as e: + print("Failed to load transformer_whisper model. Application exits. " + str(e)) + elif stt_type == "wav2vec_bert": + compute_dtype = settings.GetOption("whisper_precision") + try: + return wav2vec_bert.Wav2VecBert(compute_type=compute_dtype, device=ai_device) + except Exception as e: + print("Failed to load Wav2VecBert model. Application exits. " + str(e)) + elif stt_type == "nemo_canary": + compute_dtype = settings.GetOption("whisper_precision") + try: + return nemo_canary.NemoCanary(compute_type=compute_dtype, device=ai_device) + except Exception as e: + print("Failed to load Nemo Canary model. Application exits. 
" + str(e)) # return None if no stt model is loaded return None @@ -422,6 +452,15 @@ def load_realtime_whisper(model, ai_device): return seamless_m4t.SeamlessM4T(model=model, compute_type=compute_dtype, device=ai_device) elif settings.GetOption("stt_type") == "speech_t5": return speech_t5.SpeechT5STT(device=ai_device) + elif settings.GetOption("stt_type") == "transformer_whisper": + compute_dtype = settings.GetOption("realtime_whisper_precision") + return transformer_whisper.TransformerWhisper(compute_type=compute_dtype, device=ai_device) + elif settings.GetOption("stt_type") == "wav2vec_bert": + compute_dtype = settings.GetOption("realtime_whisper_precision") + return wav2vec_bert.Wav2VecBert(compute_type=compute_dtype, device=ai_device) + elif settings.GetOption("stt_type") == "nemo_canary": + compute_dtype = settings.GetOption("realtime_whisper_precision") + return nemo_canary.NemoCanary(compute_type=compute_dtype, device=ai_device) def convert_audio(audio_bytes: bytes): @@ -678,11 +717,35 @@ def whisper_ai_thread(audio_data, current_audio_timestamp, audio_model, audio_mo repetition_penalty=repetition_penalty, no_repeat_ngram_size=no_repeat_ngram_size) - elif settings.GetOption("stt_type") == "speech_t5": # microsoft SpeechT5 result = audio_model.transcribe(audio_sample) + elif settings.GetOption("stt_type") == "transformer_whisper": + # Whisper Huggingface Transformer + audio_model.set_compute_type(settings.GetOption("whisper_precision")) + audio_model.set_compute_device(settings.GetOption("ai_device")) + result = audio_model.transcribe(audio_sample, model=settings.GetOption("model"), task=whisper_task, + language=whisper_language, return_timestamps=False, + beam_size=whisper_beam_size) + elif settings.GetOption("stt_type") == "wav2vec_bert": + # Wav2VecBert + audio_model.set_compute_type(settings.GetOption("whisper_precision")) + audio_model.set_compute_device(settings.GetOption("ai_device")) + result = audio_model.transcribe(audio_sample, task=whisper_task, + language=whisper_language) + + elif settings.GetOption("stt_type") == "nemo_canary": + # Nemo Canary + audio_model.set_compute_type(settings.GetOption("whisper_precision")) + audio_model.set_compute_device(settings.GetOption("ai_device")) + result = audio_model.transcribe(audio_sample, task=whisper_task, + source_lang=whisper_language, + target_lang=stt_target_language, + beam_size=whisper_beam_size, + length_penalty=whisper_faster_length_penalty, + temperature=1.0,) + if result is None or (last_whisper_result == result.get('text').strip() and not final_audio): print("skipping... 
result: ", result) return @@ -751,9 +814,13 @@ def whisper_worker(): # start processing audio thread if audio_model is not None if audio_model is not None: - threading.Thread(target=whisper_ai_thread, args=( - audio, audio_timestamp, audio_model, audio_model_realtime, last_whisper_result, final_audio), - daemon=True).start() + if settings.GetOption("thread_per_transcription"): + threading.Thread(target=whisper_ai_thread, args=( + audio, audio_timestamp, audio_model, audio_model_realtime, last_whisper_result, final_audio), + daemon=True).start() + else: + whisper_ai_thread(audio, audio_timestamp, audio_model, audio_model_realtime, last_whisper_result, + final_audio) def start_whisper_thread(): diff --git a/build-standalone.bat b/build-standalone.bat index eec4d4c..4c35040 100644 --- a/build-standalone.bat +++ b/build-standalone.bat @@ -19,6 +19,10 @@ rem --hidden-import=pyworld ^ rem --hidden-import=torchcrepe ^ rem --hidden-import=grpcio ^ rem --hidden-import=grpc ^ +rem --hidden-import=annotated_types ^ +rem --hidden-import=Cython ^ +rem --hidden-import=nemo_toolkit ^ +rem --hidden-import=nemo ^ rem --copy-metadata rich ^ rem --copy-metadata tqdm ^ rem --copy-metadata regex ^ @@ -57,6 +61,10 @@ rem --collect-all pyworld ^ rem --collect-all torchcrepe ^ rem --collect-all grpcio ^ rem --collect-all grpc ^ +rem --collect-all annotated_types ^ +rem --collect-all Cython ^ +rem --collect-all nemo_toolkit ^ +rem --collect-all nemo ^ rem --collect-submodules fairseq ^ rem --add-data ".cache/nltk/tokenizers/punkt;./nltk_data/tokenizers/punkt" ^ rem -i app-icon.ico diff --git a/downloader.py b/downloader.py index e02e5b8..0df5498 100644 --- a/downloader.py +++ b/downloader.py @@ -60,7 +60,7 @@ def download_extract(urls, extract_dir, checksum, title="", extract_format="", a "extract_format": extract_format}})) while True: if os.path.isfile(local_dl_file + ".finished"): - if sha256_checksum(local_dl_file + ".finished") == checksum: + if sha256_checksum(local_dl_file + ".finished") == checksum or checksum == "": success = True break else: diff --git a/ignorelist.txt b/ignorelist.txt index 4c1ec38..671e3ec 100644 --- a/ignorelist.txt +++ b/ignorelist.txt @@ -1,5 +1,6 @@ - +Thank you for watching Thanks for watching! Thanks for watching Thank you for watching! @@ -98,4 +99,5 @@ I hope you enjoyed this video, and I'll see you in the next one! Bye! and we'll see you next time. Subtitles by Stephanie Geiges Subtitles by the Amara.org community -PARROT TV \ No newline at end of file +PARROT TV +I'll see you next time. 
\ No newline at end of file diff --git a/requirements.nvidia.txt b/requirements.nvidia.txt index f2e0498..9f27e1c 100644 --- a/requirements.nvidia.txt +++ b/requirements.nvidia.txt @@ -1,5 +1,4 @@ -#--extra-index-url https://download.pytorch.org/whl/cu117 ---extra-index-url https://download.pytorch.org/whl/nightly/cu118 -torch==2.2.0.dev20231211+cu118 -torchaudio==2.2.0.dev20231211+cu118 -torchvision==0.17.0.dev20231211+cu118 +--extra-index-url https://download.pytorch.org/whl/cu121 +torch==2.2.1 +torchvision==0.17.1 +torchaudio==2.2.1 diff --git a/requirements.txt b/requirements.txt index d3952d2..8bcdb74 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,19 +8,24 @@ librosa==0.10.1 #transformers @ https://github.com/Sharrnah/transformers/archive/refs/heads/add_seamless-m4t.zip #transformers @ https://github.com/huggingface/transformers/archive/84724efd101af52ed3d6af878e41ff8fd651a9cc.zip #transformers==4.35.0 -transformers @ https://github.com/huggingface/transformers/archive/235e5d4991e8a0984aa78db91087b49622c7740e.zip +#transformers @ https://github.com/huggingface/transformers/archive/235e5d4991e8a0984aa78db91087b49622c7740e.zip +transformers==4.38.1 tensorboardX==2.6.2.2 -accelerate==0.21.0 +accelerate==0.26.1 +#optimum +#flash-attn #bitsandbytes==0.41.1 -bitsandbytes @ git+https://github.com/Keith-Hon/bitsandbytes-windows.git +# https://github.com/TimDettmers/bitsandbytes/actions/runs/7787696861/job/21236774833?pr=949 +#bitsandbytes @ git+https://github.com/Keith-Hon/bitsandbytes-windows.git +bitsandbytes @ https://s3.libs.space:9000/projects/wheels/bitsandbytes-0.43.0.dev0-cp311-cp311-win_amd64.whl ffmpeg-python==0.2.0 click>=8.1.3 -PyAudio==0.2.13 +PyAudio==0.2.14 PyAudioWPatch==0.2.12.6 resampy==0.4.2 sounddevice==0.4.6 -SpeechRecognition==3.10.0 +SpeechRecognition==3.10.1 pydub>=0.25.1 git+https://github.com/openai/whisper.git #triton @ https://github.com/PrashantSaikia/Triton-for-Windows/raw/84739dfcb724845b301fbde6a738e15c3ed25905/triton-2.0.0-cp310-cp310-win_amd64.whl @@ -29,15 +34,15 @@ triton @ https://s3.libs.space:9000/projects/wheels/triton-2.1.0-cp311-cp311-win soundfile==0.12.1 python-osc>=1.8.0 websockets>=10.4 -unidecode>=1.3.6 +unidecode==1.3.8 pykakasi>=2.2.1 -ctranslate2==3.22.0 -sentencepiece==0.1.99 +ctranslate2==4.0.0 +sentencepiece==0.2.0 protobuf==3.20.3 -progressbar2 +progressbar2==4.3.2 fasttext-wheel #best-download -robust-downloader @ https://github.com/Sharrnah/robust-downloader/archive/refs/heads/main.zip +robust-downloader @ https://github.com/fedebotu/robust-downloader/archive/refs/heads/main.zip # pywin32 required for easyOCR pywin32 #easyocr==1.7.0 @@ -45,8 +50,8 @@ pywin32 easyocr @ https://github.com/JaidedAI/EasyOCR/archive/refs/tags/v1.7.1.zip mss==7.0.1 scipy==1.10.1 -num2words==0.5.12 -onnxruntime==1.15.1 +num2words==0.5.13 +onnxruntime==1.17.1 requests==2.29.0 # downgradea of scikit-image to v1.19.3 to prevent https://github.com/scikit-image/scikit-image/issues/6784 scikit-image==v0.22.0 @@ -55,6 +60,11 @@ deepfilternet==0.5.6 pyloudnorm nltk +# NVIDIA Nemo (Canary) dependency +Cython +youtokentome @ https://github.com/gburlet/YouTokenToMe/archive/refs/heads/dependencies.zip +git+https://github.com/NVIDIA/NeMo.git@r1.23.0#egg=nemo_toolkit[asr] + # plugin dependencies omegaconf==2.2.3 PyYAML>=6.0 @@ -68,10 +78,12 @@ annotated_types==0.6.0 #fairseq @ https://github.com/Sharrnah/fairseq/archive/refs/heads/main.zip #fairseq @ https://github.com/Sharrnah/fairseq/releases/download/v0.12.4/fairseq-0.12.4-cp310-cp310-win_amd64.whl fairseq @ 
https://github.com/Sharrnah/fairseq/releases/download/v0.12.4/fairseq-0.12.4-cp311-cp311-win_amd64.whl -faiss-cpu==1.7.3 +faiss-cpu==1.7.4 praat-parselmouth>=0.4.2 pyworld==0.3.4 torchcrepe==0.0.22 #faster-whisper @ https://github.com/guillaumekln/faster-whisper/archive/refs/heads/master.zip -faster-whisper @ https://github.com/Sharrnah/faster-whisper/archive/refs/heads/large-v3.zip +faster-whisper @ https://github.com/SYSTRAN/faster-whisper/archive/refs/heads/master.zip +#faster-whisper @ https://github.com/Sharrnah/faster-whisper/archive/refs/heads/large-v3.zip +#whisperx @ https://github.com/Sharrnah/whisperX/archive/refs/heads/main.zip diff --git a/settings.py b/settings.py index abd9675..9529f6b 100644 --- a/settings.py +++ b/settings.py @@ -1,9 +1,9 @@ # noinspection PyPackageRequirements +import sys import yaml import os from pathlib import Path from click import core -from whisper import available_models import threading import Utilities @@ -60,7 +60,7 @@ "repetition_penalty": 1.0, # penalize the score of previously generated tokens (set > 1 to penalize) "no_repeat_ngram_size": 0, # prevent repetitions of ngrams with this size "whisper_precision": "float32", # for original Whisper can be "float16" or "float32", for faster-whisper "default", "auto", "int8", "int8_float16", "int16", "float16", "float32". - "stt_type": "faster_whisper", # can be "faster_whisper", "original_whisper", "speech_t5" or "seamless_m4t". + "stt_type": "faster_whisper", # can be "faster_whisper", "original_whisper", "transformer_whisper", "speech_t5", "seamless_m4t" etc. "temperature_fallback": True, # Set to False to disable temperature fallback which is the reason for some slowdowns, but decreases quality. "beam_size": 5, # Beam size for beam search. (higher = more accurate, but slower) "whisper_cpu_threads": 0, # Number of threads to use when running on CPU (4 by default) @@ -90,6 +90,7 @@ "normalize_gain_factor": 2.0, "denoise_audio": False, # if enabled, audio will be de-noised before processing. "denoise_audio_post_filter": False, # Enable post filter for some minor, extra noise reduction. + "thread_per_transcription": False, # Use a separate thread for each transcription. "realtime": False, # if enabled, Whisper will process audio in realtime. "realtime_whisper_model": "", # model used for realtime transcription. 
(empty for using same model as model setting) @@ -226,7 +227,10 @@ def GetArgumentSettingFallback(ctx, argument_name, fallback_setting_name): def get_available_models(): - available_models_list = available_models() + available_models_list = [] + if 'whisper' not in sys.modules or 'available_models' not in dir(sys.modules['whisper']): + from whisper import available_models + available_models_list = available_models() # add custom models to list if GetOption("stt_type") == "faster_whisper": @@ -256,18 +260,18 @@ def GetAvailableSettingValues(): "ai_device": ["None", "cuda", "cpu"], "model": get_available_models(), "whisper_task": ["transcribe", "translate"], - "stt_type": ["faster_whisper", "original_whisper", "seamless_m4t", "speech_t5", ""], + "stt_type": ["faster_whisper", "original_whisper", "transformer_whisper", "seamless_m4t", "speech_t5", "wav2vec_bert", "nemo_canary", ""], "tts_ai_device": ["cuda", "cpu"], "txt_translator_device": ["cuda", "cpu"], "txt_translator": ["", "NLLB200_CT2", "NLLB200", "M2M100", "Seamless_M4T"], "txt_translator_size": ["small", "medium", "large"], - "txt_translator_precision": ["float32", "float16", "int16", "int8_float16", "int8", "bfloat16", "int8_bfloat16"], + "txt_translator_precision": ["float32", "float16", "int16", "int8_float16", "int8", "bfloat16", "int8_bfloat16", "4bit", "8bit"], "tts_prosody_rate": ["", "x-slow", "slow", "medium", "fast", "x-fast"], "tts_prosody_pitch": ["", "x-low", "low", "medium", "high", "x-high"], - "whisper_precision": ["float32", "float16", "int16", "int8_float16", "int8", "bfloat16", "int8_bfloat16"], + "whisper_precision": ["float32", "float16", "int16", "int8_float16", "int8", "bfloat16", "int8_bfloat16", "4bit", "8bit"], #"whisper_precision": ["float32", "float16", "int16", "int8_float16", "int8"], "realtime_whisper_model": [""] + get_available_models(), - "realtime_whisper_precision": ["float32", "float16", "int16", "int8_float16", "int8", "bfloat16", "int8_bfloat16"], + "realtime_whisper_precision": ["float32", "float16", "int16", "int8_float16", "int8", "bfloat16", "int8_bfloat16", "4bit", "8bit"], #"realtime_whisper_precision": ["float32", "float16", "int16", "int8_float16", "int8"], "osc_type_transfer": ["source", "translation_result", "both", "both_inverted"], "osc_send_type": ["full", "full_or_scroll", "scroll", "chunks"],