Merge pull request #23 from Sharrnah/wav2vec_bert2
updates, additional ai models
Sharrnah authored Feb 28, 2024
2 parents fd47214 + 0329162 commit a5c59a9
Showing 17 changed files with 2,283 additions and 1,148 deletions.
67 changes: 56 additions & 11 deletions Models/STT/faster_whisper.py
@@ -1,3 +1,6 @@
+import gc
+
+import torch
 from faster_whisper import WhisperModel
 
 from pathlib import Path
@@ -537,7 +540,7 @@ def needs_download(model: str, compute_type: str = "float32"):
     if compute_type not in MODEL_LINKS[model]:
         if compute_type == "float32":
             model_path = Path(model_cache_path / (model + "-ct2-fp16"))
-        if compute_type == "float16":
+        elif compute_type == "float16":
             model_path = Path(model_cache_path / (model + "-ct2"))
 
     pretrained_lang_model_file = Path(model_path / "model.bin")
@@ -564,11 +567,10 @@ def download_model(model: str, compute_type: str = "float32"):
         if compute_type == "float32":
             compute_type = "float16"
             model_path = Path(model_cache_path / (model + "-ct2-fp16"))
-        if compute_type == "float16":
+        elif compute_type == "float16":
             compute_type = "float32"
             model_path = Path(model_cache_path / (model + "-ct2"))
 
-
     pretrained_lang_model_file = Path(model_path / "model.bin")
 
     if not Path(model_path).exists() or not pretrained_lang_model_file.is_file():
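Note on the `if` → `elif` changes in the two hunks above: in `download_model` the float32 branch reassigns `compute_type` to "float16" before the next test runs, so with a plain `if` the second branch fired immediately afterwards and flipped both the compute type and the model path back. A minimal sketch of the corrected fallback logic (a simplified stand-in for illustration, not the repository's full function):

```python
# Simplified stand-in for download_model's precision fallback; the folder
# suffix is returned instead of building the full cache path.
def pick_fallback(compute_type: str) -> tuple[str, str]:
    suffix = ""
    if compute_type == "float32":
        compute_type = "float16"  # fall back to the fp16 model variant
        suffix = "-ct2-fp16"
    elif compute_type == "float16":
        # With a plain `if`, the reassignment above would re-trigger this
        # branch and undo the fallback; `elif` prevents the fall-through.
        compute_type = "float32"
        suffix = "-ct2"
    return compute_type, suffix

assert pick_fallback("float32") == ("float16", "-ct2-fp16")
assert pick_fallback("float16") == ("float32", "-ct2")
```

With `if` instead of `elif`, the first call would have returned ("float32", "-ct2") and pointed at the wrong model folder.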
@@ -595,24 +597,62 @@ def download_model(model: str, compute_type: str = "float32"):
 class FasterWhisper(metaclass=SingletonMeta):
     model = None
     loaded_model_size = ""
+    loaded_settings = {}
 
+    transcription_count = 0
+    reload_after_transcriptions = 0
+
     def __init__(self, model: str, device: str = "cpu", compute_type: str = "float32", cpu_threads: int = 0,
                  num_workers: int = 1):
         if self.model is None:
             self.load_model(model, device, compute_type, cpu_threads, num_workers)
 
+    def set_reload_after_transcriptions(self, reload_after_transcriptions: int):
+        self.reload_after_transcriptions = reload_after_transcriptions
+
+    def release_model(self):
+        print("Reloading model...")
+        if self.model is not None:
+            if hasattr(self.model, 'model'):
+                del self.model.model
+            if hasattr(self.model, 'feature_extractor'):
+                del self.model.feature_extractor
+            if hasattr(self.model, 'hf_tokenizer'):
+                del self.model.hf_tokenizer
+            del self.model
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        gc.collect()
+
+        self.load_model(
+            self.loaded_settings["model"],
+            self.loaded_settings["device"],
+            self.loaded_settings["compute_type"],
+            self.loaded_settings["cpu_threads"],
+            self.loaded_settings["num_workers"],
+        )
+
     def load_model(self, model: str, device: str = "cpu", compute_type: str = "float32", cpu_threads: int = 0,
                    num_workers: int = 1):
 
+        self.loaded_settings = {
+            "model": model,
+            "device": device,
+            "compute_type": compute_type,
+            "cpu_threads": cpu_threads,
+            "num_workers": num_workers
+        }
+
         model_cache_path = Path(".cache/whisper")
         os.makedirs(model_cache_path, exist_ok=True)
         model_folder_name = model + "-ct2"
-        if compute_type == "float16" or compute_type == "int8_float16":
+        if compute_type == "float16" or compute_type == "int8_float16" or compute_type == "int16" or compute_type == "int8":
             model_folder_name = model + "-ct2-fp16"
         # special case for models that are only available in one precision (as float16 vs float32 showed no difference in large-v3 and distilled versions)
         if compute_type not in MODEL_LINKS[model]:
             if compute_type == "float32":
                 model_folder_name = model + "-ct2-fp16"
-            if compute_type == "float16":
+            elif compute_type == "float16":
                 model_folder_name = model + "-ct2"
         model_path = Path(model_cache_path / model_folder_name)
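`FasterWhisper` relies on a `SingletonMeta` metaclass that this diff references but does not show. A typical implementation of such a helper looks like the following (an assumption about its shape, not code from this commit):

```python
# Assumed shape of SingletonMeta (not shown in this diff): a metaclass that
# caches the first instance per class, so repeated FasterWhisper(...) calls
# reuse the already-loaded model instead of constructing a new one.
class SingletonMeta(type):
    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]
```

Under this pattern the model is loaded at most once per process, which is why `release_model` has to reload in place rather than letting callers construct a fresh instance.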
@@ -622,14 +662,15 @@ def load_model(self, model: str, device: str = "cpu", compute_type: str = "float
 
         # temporary fix for large-v3 loading (https://github.com/guillaumekln/faster-whisper/issues/547)
         # @TODO: this is a temporary fix for large-v3
-        n_mels = 80
-        use_tf_tokenizer = False
-        if model == "large-v3":
-            n_mels = 128
-            #use_tf_tokenizer = True
+        #n_mels = 80
+        #use_tf_tokenizer = False
+        #if model == "large-v3":
+        #    n_mels = 128
 
+        #self.model = WhisperModel(str(Path(model_path).resolve()), device=device, compute_type=compute_type,
+        #                          cpu_threads=cpu_threads, num_workers=num_workers, feature_size=n_mels, use_tf_tokenizer=use_tf_tokenizer)
         self.model = WhisperModel(str(Path(model_path).resolve()), device=device, compute_type=compute_type,
-                                  cpu_threads=cpu_threads, num_workers=num_workers, feature_size=n_mels, use_tf_tokenizer=use_tf_tokenizer)
+                                  cpu_threads=cpu_threads, num_workers=num_workers)
 
     def transcribe(self, audio_sample, task, language, condition_on_previous_text,
                    initial_prompt, logprob_threshold, no_speech_threshold, temperature, beam_size,
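The dropped `feature_size`/`use_tf_tokenizer` keyword arguments are not part of the stock `faster_whisper.WhisperModel` constructor; they appear to have come from a patched build used for the earlier large-v3 workaround, and recent faster-whisper releases handle large-v3's 128 mel bins themselves. The call now matches the standard constructor, sketched here with assumed example values:

```python
from faster_whisper import WhisperModel

# Standard construction, as used above after the cleanup; the model size,
# device, and compute type are illustrative example values.
model = WhisperModel("large-v3", device="cpu", compute_type="int8",
                     cpu_threads=4, num_workers=1)
```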
@@ -660,4 +701,8 @@ def transcribe(self, audio_sample, task, language, condition_on_previous_text,
             'language': audio_info.language
         }
 
+        #self.transcription_count += 1
+        #if self.reload_after_transcriptions > 0 and (self.transcription_count % self.reload_after_transcriptions == 0):
+        #    self.release_model()
+
         return result
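The counter that would trigger `release_model()` automatically after N transcriptions is still commented out in this commit, so any periodic reload has to be driven by the caller. A hypothetical usage sketch of the new hooks (`audio_samples` and `run_transcription` are illustrative placeholders, since `transcribe`'s full argument list is elided in this diff):

```python
# Hypothetical caller-side reload loop mirroring the commented-out counter
# in transcribe() above; audio_samples / run_transcription are placeholders.
whisper = FasterWhisper("small", device="cpu", compute_type="float32")
whisper.set_reload_after_transcriptions(50)

def run_transcription(sample):
    """Stand-in for a whisper.transcribe(...) call with the app's settings."""
    ...

for count, sample in enumerate(audio_samples, start=1):
    run_transcription(sample)
    if (whisper.reload_after_transcriptions > 0
            and count % whisper.reload_after_transcriptions == 0):
        whisper.release_model()  # frees the model, clears the CUDA cache, reloads
```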