[TASK] Add direct-ml support
Sharrnah committed Jul 20, 2024
1 parent 593447d commit f5cef2d
Showing 12 changed files with 90 additions and 32 deletions.
26 changes: 16 additions & 10 deletions Models/Multi/mms.py
@@ -1135,6 +1135,7 @@ class Mms(metaclass=SingletonMeta):
     device = None
     compute_type = "float32"
     compute_device = "cpu"
+    compute_device_str = "cpu"
     precision = None
     load_in_8bit = False
 
@@ -1146,9 +1147,6 @@ class Mms(metaclass=SingletonMeta):
     language_identification = None
 
     def __init__(self, model='mms-1b-fl102', compute_type="float32", device="cpu"):
-        self.compute_type = compute_type
-        self.compute_device = device
-
         self.load_model(model_size=model, compute_type=compute_type, device=device)
 
     @staticmethod
@@ -1167,13 +1165,21 @@ def _str_to_dtype_dict(self, dtype_str):
         else:
             return {'dtype': torch.float32, '4bit': False, '8bit': False}
 
-    def set_device(self, device: str):
-        if device == "cuda" or device == "auto":
-            device = "cuda" if torch.cuda.is_available() else "cpu"
-        elif device == "direct-ml":
+    def set_device(self, device: str | None):
+        self.compute_device_str = device
+        if device is None or device == "cuda" or device == "auto" or device == "":
+            self.compute_device_str = "cuda" if torch.cuda.is_available() else "cpu"
+            device = torch.device(self.compute_device_str)
+        elif device == "cpu":
+            device = torch.device("cpu")
+        elif device.startswith("direct-ml"):
+            device_id = 0
+            device_id_split = device.split(":")
+            if len(device_id_split) > 1:
+                device_id = int(device_id_split[1])
             import torch_directml
-            device = torch_directml.device()
-        self.device = device
+            device = torch_directml.device(device_id)
+        self.compute_device = device
 
     def load_model(self, model_size='mms-1b-fl102', compute_type="float32", device="cpu"):
         model_path = Path(model_cache_path / model_size)
@@ -1183,7 +1189,7 @@ def load_model(self, model_size='mms-1b-fl102', compute_type="float32", device="cpu"):
         compute_4bit = self._str_to_dtype_dict(self.compute_type).get('4bit', False)
         compute_8bit = self._str_to_dtype_dict(self.compute_type).get('8bit', False)
         self.compute_type = compute_type
-        self.compute_device = device
+        self.set_device(device)
 
         if self.model is None or model_size != self.previous_model:
            if self.model is not None:
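Note: the "direct-ml[:N]" branch added to set_device above reappears nearly verbatim in the files below. As a reading aid, here is a minimal standalone sketch of that shared pattern (the helper name resolve_device is hypothetical, the commit inlines this logic in each class):

    import torch

    def resolve_device(device: str | None) -> torch.device:
        # "cuda", "auto", "" or None: prefer CUDA, fall back to CPU.
        if device in (None, "", "auto", "cuda"):
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # "direct-ml" or "direct-ml:<N>": pick the N-th DirectML adapter.
        if device.startswith("direct-ml"):
            parts = device.split(":")
            device_id = int(parts[1]) if len(parts) > 1 else 0
            import torch_directml  # lazy import, only needed on this path
            return torch_directml.device(device_id)
        return torch.device(device)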
17 changes: 16 additions & 1 deletion Models/Multi/seamless_m4t.py
@@ -178,14 +178,29 @@ def __init__(self, model='medium', compute_type="float32", device="cpu"):
 
         if self.device is None:
             self.device = device
-        if device == "cuda":
+        if device == "cuda" or device == "auto":
             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        elif device.startswith("direct-ml"):
+            device_id = 0
+            device_id_split = device.split(":")
+            if len(device_id_split) > 1:
+                device_id = int(device_id_split[1])
+            import torch_directml
+            self.device = torch_directml.device(device_id)
         if self.model is None or self.processor is None:
             self.load_model(model_size=model)
+
     def set_device(self, device: str):
         self.device_str = device
         if device == "cuda" or device == "auto":
             device = "cuda" if torch.cuda.is_available() else "cpu"
+        elif device.startswith("direct-ml"):
+            device_id = 0
+            device_id_split = device.split(":")
+            if len(device_id_split) > 1:
+                device_id = int(device_id_split[1])
+            import torch_directml
+            device = torch_directml.device(device_id)
         self.device = device
 
     @staticmethod
7 changes: 7 additions & 0 deletions Models/STT/speecht5.py
@@ -15,6 +15,13 @@ def __init__(self, device="cpu"):
         self.device = device
         if device == "cuda":
             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        elif device.startswith("direct-ml"):
+            device_id = 0
+            device_id_split = device.split(":")
+            if len(device_id_split) > 1:
+                device_id = int(device_id_split[1])
+            import torch_directml
+            self.device = torch_directml.device(device_id)
         if self.model is None:
             self.load_model()
 
22 changes: 18 additions & 4 deletions Models/STT/tansformer_whisper.py
@@ -17,6 +17,7 @@ class TransformerWhisper(metaclass=SingletonMeta):
     processor = None
     compute_type = "float32"
     compute_device = "cpu"
+    compute_device_str = "cpu"
 
     text_correction_model = None
 
@@ -33,7 +34,7 @@ class TransformerWhisper(metaclass=SingletonMeta):
     def __init__(self, compute_type="float32", device="cpu"):
         os.makedirs(self.model_cache_path, exist_ok=True)
         self.compute_type = compute_type
-        self.compute_device = device
+        self.set_compute_device(device)
         self.load_model_list()
 
         #if self._debug_skip_dl:
@@ -56,6 +57,19 @@ def set_compute_type(self, compute_type):
         self.compute_type = compute_type
 
     def set_compute_device(self, device):
+        self.compute_device_str = device
+        if device is None or device == "cuda" or device == "auto" or device == "":
+            self.compute_device_str = "cuda" if torch.cuda.is_available() else "cpu"
+            device = torch.device(self.compute_device_str)
+        elif device == "cpu":
+            device = torch.device("cpu")
+        elif device.startswith("direct-ml"):
+            device_id = 0
+            device_id_split = device.split(":")
+            if len(device_id_split) > 1:
+                device_id = int(device_id_split[1])
+            import torch_directml
+            device = torch_directml.device(device_id)
         self.compute_device = device
 
     def load_model_list(self):
@@ -108,7 +122,7 @@ def load_model(self, model='small', compute_type="float32", device="cpu"):
         compute_8bit = self._str_to_dtype_dict(self.compute_type).get('8bit', False)
         self.compute_type = compute_type
 
-        self.compute_device = device
+        self.set_compute_device(device)
 
         if not self._debug_skip_dl:
             self.download_model(model)
@@ -120,7 +134,7 @@ def load_model(self, model='small', compute_type="float32", device="cpu"):
             self.previous_model = model
             self.release_model()
             print(f"Loading Whisper-Transformer model: {model} on {device} with {compute_type} precision...")
-            self.model = WhisperForConditionalGeneration.from_pretrained(str(Path(self.model_cache_path / model).resolve()), torch_dtype=compute_dtype, load_in_8bit=compute_8bit, load_in_4bit=compute_4bit)
+            self.model = WhisperForConditionalGeneration.from_pretrained(str(Path(self.model_cache_path / model).resolve()), torch_dtype=compute_dtype, load_in_8bit=compute_8bit, load_in_4bit=compute_4bit, device_map=self.compute_device)
             if not compute_8bit and not compute_4bit:
                 self.model = self.model.to(self.compute_device)
             self.processor = WhisperProcessor.from_pretrained(str(Path(self.model_cache_path / model).resolve()))
@@ -129,7 +143,7 @@ def load_model(self, model='small', compute_type="float32", device="cpu"):
 
     def transcribe(self, audio_sample, model, task, language,
                    return_timestamps=False, beam_size=4) -> dict:
-        self.load_model(model, self.compute_type, self.compute_device)
+        self.load_model(model, self.compute_type, self.compute_device_str)
 
         compute_dtype = self._str_to_dtype_dict(self.compute_type).get('dtype', torch.float32)
 
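A note on the from_pretrained change above (my reading, the commit does not explain it): weights quantized by bitsandbytes via load_in_8bit/load_in_4bit have to be placed on their device at load time through device_map, because moving a quantized model afterwards with .to() is not supported; that is presumably why the .to(self.compute_device) call stays guarded. A minimal sketch under those assumptions, with a hypothetical model_dir:

    import torch
    from transformers import WhisperForConditionalGeneration

    model_dir = ".cache/whisper-transformer/small"  # hypothetical local path
    use_8bit = True  # illustrative flag
    if use_8bit:
        # Quantized load: placement happens via device_map, no .to() afterwards.
        model = WhisperForConditionalGeneration.from_pretrained(
            model_dir, load_in_8bit=True, device_map={"": 0})
    else:
        # Full-precision load: move the model explicitly.
        model = WhisperForConditionalGeneration.from_pretrained(
            model_dir, torch_dtype=torch.float16).to("cuda:0")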
12 changes: 10 additions & 2 deletions Models/TTS/silero.py
@@ -194,7 +194,7 @@ class Silero:
     sample_rate = 48000
     speaker = 'random'
     models = []
-    device = "cpu" # cpu or cuda
+    device = "cpu" # cpu, cuda or direct-ml
     rate = ""
     pitch = ""
 
@@ -339,7 +339,15 @@ def load(self):
         self.set_language(settings.GetOption('tts_model')[0])
         self.set_model(settings.GetOption('tts_model')[1])
 
-        device = torch.device(self.device)
+        if self.device.startswith("direct-ml"):
+            device_id = 0
+            device_id_split = self.device.split(":")
+            if len(device_id_split) > 1:
+                device_id = int(device_id_split[1])
+            import torch_directml
+            device = torch_directml.device(device_id)
+        else:
+            device = torch.device(self.device)
 
         # set cache path
         torch.hub.set_dir(str(Path(cache_path).resolve()))
3 changes: 2 additions & 1 deletion audioWhisper.py
@@ -547,6 +547,7 @@ def main(ctx, detect_energy, detect_energy_time, ui_download, devices, sample_ra
         print(e)
         vad_thread_num = int(1)
 
+    vad_model = None
    if vad_enabled:
        vad_model = VAD.VAD(vad_thread_num)
 
@@ -567,7 +568,7 @@ def main(ctx, detect_energy, detect_energy_time, ui_download, devices, sample_ra
     # prepare the plugin timer calls
     call_plugin_timer(Plugins)
 
-    if vad_enabled:
+    if vad_enabled and vad_model is not None:
         # num_samples = 1536
         vad_frames_per_buffer = int(settings.SETTINGS.SetOption("vad_frames_per_buffer",
                                                                 settings.SETTINGS.get_argument_setting_fallback(ctx, "vad_frames_per_buffer",
4 changes: 3 additions & 1 deletion audioWhisper.spec
@@ -11,7 +11,7 @@ binaries = []
 # Collect dynamic libraries from onnxruntime
 binaries= collect_dynamic_libs('onnxruntime', destdir='onnxruntime/capi')
 
-hiddenimports = ['torch', 'pytorch', 'torchaudio.lib.libtorchaudio', 'scipy.signal', 'transformers.models.nllb', 'sentencepiece', 'df.deepfilternet3', 'bitsandbytes', 'faiss', 'faiss-cpu', 'praat-parselmouth', 'parselmouth', 'pyworld', 'torchcrepe', 'grpcio', 'grpc', 'annotated_types', 'Cython', 'nemo_toolkit', 'nemo', 'speechbrain', 'pyannote', 'pyannote.audio', 'pyannote.pipeline', 'noisereduce', 'frozendict']
+hiddenimports = ['torch', 'pytorch', 'torchaudio.lib.libtorchaudio', 'scipy.signal', 'transformers.models.nllb', 'sentencepiece', 'df.deepfilternet3', 'bitsandbytes', 'faiss', 'faiss-cpu', 'praat-parselmouth', 'parselmouth', 'pyworld', 'torchcrepe', 'grpcio', 'grpc', 'annotated_types', 'Cython', 'nemo_toolkit', 'nemo', 'speechbrain', 'pyannote', 'pyannote.audio', 'pyannote.pipeline', 'noisereduce', 'frozendict', 'torch_directml']
 datas += collect_data_files('torch')
 datas += collect_data_files('whisper')
 datas += collect_data_files('pykakasi')
@@ -110,6 +110,8 @@ tmp_ret = collect_all('noisereduce')
 datas += tmp_ret[0]; binaries += tmp_ret[1]; hiddenimports += tmp_ret[2]
 tmp_ret = collect_all('frozendict')
 datas += tmp_ret[0]; binaries += tmp_ret[1]; hiddenimports += tmp_ret[2]
+tmp_ret = collect_all('torch_directml')
+datas += tmp_ret[0]; binaries += tmp_ret[1]; hiddenimports += tmp_ret[2]
 
 workdir = os.environ.get('WORKDIR_WIN', r'\drone\src')
 workdir = "C:" + workdir # Now workdir = "C:\drone\src"
9 changes: 7 additions & 2 deletions audioprocessor.py
@@ -428,9 +428,14 @@ def load_whisper(model, ai_device):
     if stt_type == "original_whisper":
         try:
             set_ai_device = ai_device
-            if ai_device == "direct-ml":
+
+            if ai_device.startswith("direct-ml"):
+                device_id = 0
+                device_id_split = ai_device.split(":")
+                if len(device_id_split) > 1:
+                    device_id = int(device_id_split[1])
                 import torch_directml
-                set_ai_device = torch_directml.device()
+                set_ai_device = torch_directml.device(device_id)
             return whisper.load_model(model, download_root=".cache/whisper", device=set_ai_device)
         except Exception as e:
             print("Failed to load whisper model. Application exits. " + str(e))
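For reference, the same device selection used standalone with openai-whisper, assuming torch_directml is installed and at least one DirectML adapter is present (whether a given model runs well on DirectML still depends on operator support):

    import torch_directml
    import whisper

    device = torch_directml.device(0)  # adapter 0, i.e. the "direct-ml:0" setting
    model = whisper.load_model("small", download_root=".cache/whisper", device=device)
    result = model.transcribe("audio.wav")  # hypothetical input file
    print(result["text"])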
1 change: 0 additions & 1 deletion requirements.amd.txt
@@ -7,4 +7,3 @@
 torch==2.2.0.dev20231211+cpu
 torchaudio==2.2.0.dev20231211+cpu
 torchvision==0.17.0.dev20231211+cpu
-torch-directml
6 changes: 3 additions & 3 deletions requirements.nvidia.txt
@@ -1,4 +1,4 @@
 --extra-index-url https://download.pytorch.org/whl/cu121
-torch==2.2.1
-torchvision==0.17.1
-torchaudio==2.2.1
+torch==2.3.1
+torchvision==0.18.1
+torchaudio==2.3.1
13 changes: 7 additions & 6 deletions requirements.txt
@@ -1,18 +1,19 @@
 # numpy v1.23.4 required for whisper
 numpy==1.24.2
-tqdm
+tqdm==4.66.4
 rich==12.6.0
-more-itertools
-librosa==0.10.1
+more-itertools==10.3.0
+librosa==0.10.2.post1
 #transformers==4.33.2
 #transformers @ https://github.com/Sharrnah/transformers/archive/refs/heads/add_seamless-m4t.zip
 #transformers @ https://github.com/huggingface/transformers/archive/84724efd101af52ed3d6af878e41ff8fd651a9cc.zip
 #transformers==4.35.0
 #transformers @ https://github.com/huggingface/transformers/archive/235e5d4991e8a0984aa78db91087b49622c7740e.zip
-transformers==4.42.3
+transformers==4.42.4
 
+torch-directml
 tensorboardX==2.6.2.2
-accelerate==0.30.1
+accelerate==0.32.1
 #optimum
 #flash-attn
 #bitsandbytes==0.41.1
@@ -25,7 +26,7 @@ PyAudio==0.2.14
 PyAudioWPatch==0.2.12.6
 resampy==0.4.3
 sounddevice==0.4.7
-SpeechRecognition==3.10.1
+SpeechRecognition==3.10.4
 pydub>=0.25.1
 git+https://github.com/openai/whisper.git
 #triton @ https://github.com/PrashantSaikia/Triton-for-Windows/raw/84739dfcb724845b301fbde6a738e15c3ed25905/triton-2.0.0-cp310-cp310-win_amd64.whl
2 changes: 1 addition & 1 deletion settings.py
@@ -267,7 +267,7 @@ def get_available_models(self):
 
     def get_available_setting_values(self):
         possible_settings = {
-            "ai_device": ["None", "cuda", "cpu"],
+            "ai_device": ["None", "cuda", "cpu", "direct-ml:0", "direct-ml:1"],
             "model": self.get_available_models(),
             "whisper_task": ["transcribe", "translate"],
             "stt_type": ["faster_whisper", "original_whisper", "transformer_whisper", "seamless_m4t", "mms", "speech_t5", "wav2vec_bert", "nemo_canary", ""],
