Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Whisper model validation. #241

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 80 additions & 9 deletions software/source/server/services/stt/local-whisper/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from datetime import datetime
import os
import contextlib
import platform
import tempfile
import shutil
import ffmpeg
import subprocess

import urllib.request


Expand Down Expand Up @@ -56,21 +56,92 @@ def install(service_dir):
print("Whisper Rust executable already exists. Skipping build.")

WHISPER_MODEL_PATH = os.path.join(service_dir, "model")

WHISPER_MODEL_NAME = os.getenv("WHISPER_MODEL_NAME", "ggml-tiny.en.bin")
WHISPER_MODEL_URL = os.getenv(
"WHISPER_MODEL_URL",
"https://huggingface.co/ggerganov/whisper.cpp/resolve/main/",
)

if not os.path.isfile(os.path.join(WHISPER_MODEL_PATH, WHISPER_MODEL_NAME)):
while not valid_model(WHISPER_MODEL_PATH, WHISPER_MODEL_NAME):
print(f"Downloading Whisper model '{WHISPER_MODEL_NAME}'.")
WHISPER_MODEL_URL = os.getenv(
"WHISPER_MODEL_URL",
"https://huggingface.co/ggerganov/whisper.cpp/resolve/main/",
)
os.makedirs(WHISPER_MODEL_PATH, exist_ok=True)
urllib.request.urlretrieve(
f"{WHISPER_MODEL_URL}{WHISPER_MODEL_NAME}",
os.path.join(WHISPER_MODEL_PATH, WHISPER_MODEL_NAME),
)
else:
print("Whisper model already exists. Skipping download.")
print(f"Whisper model '{WHISPER_MODEL_NAME}' installed.")


def valid_model(model_path: str, model_file: str) -> bool:
    """Validate a downloaded Whisper model file via SHA-256 comparison.

    Fetches the Git-LFS pointer file for *model_file* from the Hugging Face
    repo (its second line is ``oid sha256:<hex digest>``) and compares that
    expected digest against the local file's actual SHA-256.

    Args:
        model_path: Directory that should contain the model file.
        model_file: Model file name (e.g. ``ggml-tiny.en.bin``).

    Returns:
        ``False`` only when the model file does not exist locally.
        ``True`` otherwise — validation problems (no network, unparseable
        pointer file, hash mismatch) are reported but treated as non-fatal,
        so the caller's ``while not valid_model(...)`` download loop cannot
        spin forever.
    """
    # Local import keeps this fix self-contained; hashlib is stdlib.
    import hashlib

    model_file_path = os.path.join(model_path, model_file)
    if not os.path.isfile(model_file_path):
        return False

    # Download the Git-LFS pointer text that holds the expected hash.
    details_file = f"https://huggingface.co/ggerganov/whisper.cpp/raw/main/{model_file}"
    try:
        with urllib.request.urlopen(details_file) as response:
            body_bytes = response.read()
    except OSError:
        # urllib.error.URLError subclasses OSError; a bare except here would
        # also swallow KeyboardInterrupt/SystemExit.
        print("Internet connection not detected. Skipping validation.")
        return True

    try:
        # Second line is "oid sha256:<digest>"; take everything after the
        # first colon.
        pointer_line = body_bytes.splitlines()[1]
        details_hash = pointer_line.partition(b":")[2].decode().strip()
    except (IndexError, UnicodeDecodeError):
        # e.g. an HTML error page instead of the pointer file.
        print("Could not parse expected model hash. Skipping validation.")
        return True

    # Hash the file with hashlib instead of shelling out to platform-specific
    # tools (shasum/sha256sum/certutil/pwsh): portable, no shell-injection
    # risk, and no breakage on paths containing spaces.
    digest = hashlib.sha256()
    with open(model_file_path, "rb") as f:
        for chunk in iter(lambda: f.read(1024 * 1024), b""):
            digest.update(chunk)
    model_hash = digest.hexdigest()

    if details_hash == model_hash:
        print(f"Whisper model '{model_file}' file is valid.")
    else:
        msg = f'''
        The model '{model_file}' did not validate. STT may not function correctly.
        The model path is '{model_path}'.
        Manually download and verify the model's hash to get better functionality.
        Continuing.
        '''
        print(msg)

    return True


def convert_mime_type_to_format(mime_type: str) -> str:
Expand Down