Skip to content

Commit

Permalink
[FEATURE] Added download function that works in threads
Browse files Browse the repository at this point in the history
[TASK] Added download arguments to the plugin tts method
  • Loading branch information
Sharrnah committed Apr 5, 2023
1 parent 7d8ba14 commit 5f38da2
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def stt(self, text, result_obj):
pass

@abstractmethod
def tts(self, text, device_index):
def tts(self, text, device_index, websocket_connection=None, download=False):
pass

@abstractmethod
Expand Down
54 changes: 53 additions & 1 deletion downloader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import threading
import zipfile
import os
from best_download import download_file

from contextlib import closing
import urllib.request
from urllib.parse import urlparse
import hashlib

# import logging
# logging.basicConfig(filename="download.log", level=logging.INFO)
Expand All @@ -18,3 +22,51 @@ def download_extract(urls, extract_dir, checksum):
os.remove(local_dl_file)

return success


def sha256_checksum(file_path):
sha256_hash = hashlib.sha256()
with open(file_path, 'rb') as file:
for chunk in iter(lambda: file.read(4096), b''):
sha256_hash.update(chunk)
return sha256_hash.hexdigest()


def download_file_simple(url, target_path, expected_sha256=None):
progress_lock = threading.Lock()

def show_progress(count, total_size):
with progress_lock:
percentage = int(count * 100 / total_size)
print(f'\rDownloading {url}: {percentage}%', end='')

if os.path.isdir(target_path):
file_name = os.path.basename(urlparse(url).path)
target_path = os.path.join(target_path, file_name)

with closing(urllib.request.urlopen(url)) as remote_file:
headers = remote_file.info()
total_size = int(headers.get('Content-Length', -1))

with open(target_path, 'wb') as local_file:
block_size = 8192
downloaded_size = 0
for block in iter(lambda: remote_file.read(block_size), b''):
local_file.write(block)
downloaded_size += len(block)
show_progress(downloaded_size, total_size)
print()

if expected_sha256:
actual_sha256 = sha256_checksum(target_path)
if actual_sha256.lower() != expected_sha256.lower():
os.remove(target_path)
raise ValueError(f"Downloaded file has incorrect SHA256 hash. Expected {expected_sha256}, but got {actual_sha256}.")
else:
print("SHA256 hash verified.")


def download_thread(url, extract_dir, checksum):
dl_thread = threading.Thread(target=download_file_simple, args=(url, extract_dir, checksum))
dl_thread.start()
dl_thread.join()
1 change: 1 addition & 0 deletions ignorelist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Thanks for watching!
Thank you for watching!
Thanks for watching.
Thank you for watching.
Thanks for watching! Please like and subscribe!
Please subscribe!
Please subscribe to my channel!
Please subscribe to my channel.
Expand Down
12 changes: 9 additions & 3 deletions websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def tts_request(msgObj, websocket):
print("TTS failed")


def tts_plugin_process(msgObj, websocket):
def tts_plugin_process(msgObj, websocket, download=False):
text = msgObj["value"]["text"]
device = None
if msgObj["value"]["to_device"]:
Expand All @@ -49,7 +49,7 @@ def tts_plugin_process(msgObj, websocket):
device = settings.GetOption("device_out_index")

for plugin_inst in Plugins.plugins:
plugin_inst.tts(text, device)
plugin_inst.tts(text, device, websocket, download)


def ocr_req(msgObj, websocket):
Expand Down Expand Up @@ -114,12 +114,18 @@ def websocketMessageHandler(msgObj, websocket):
tts_thread = threading.Thread(target=tts_request, args=(msgObj, websocket))
tts_thread.start()
else:
tts_thread = threading.Thread(target=tts_plugin_process, args=(msgObj, websocket))
download = False
if not msgObj["value"]["to_device"]:
download = True
tts_thread = threading.Thread(target=tts_plugin_process, args=(msgObj, websocket, download))
tts_thread.start()

if msgObj["type"] == "tts_voice_save_req":
if silero.init():
silero.tts.save_voice()
else:
tts_thread = threading.Thread(target=tts_plugin_process, args=(msgObj, websocket, True))
tts_thread.start()

if msgObj["type"] == "get_windows_list":
windows_list = WindowCapture.list_window_names()
Expand Down

0 comments on commit 5f38da2

Please sign in to comment.