Skip to content

Commit

Permalink
Merge pull request #141 from JSchmie/fix-inference-device-setting
Browse files Browse the repository at this point in the history
Fix: Now uses device if set in cli parameters
JSchmie authored Dec 2, 2024
2 parents d00ec2d + 1e40e4c commit 931f50c
Showing 2 changed files with 11 additions and 10 deletions.
1 change: 1 addition & 0 deletions scraibe/cli.py
Original file line number Diff line number Diff line change
@@ -100,6 +100,7 @@ def str2bool(string):
'whisper_type':arg_dict.pop("whisper_type"),
'dia_model': arg_dict.pop("diarization_directory"),
'use_auth_token': arg_dict.pop("hf_token"),
'device': arg_dict.pop('inference_device')
}

if arg_dict["whisper_model_directory"]:
20 changes: 10 additions & 10 deletions scraibe/transcriber.py
Original file line number Diff line number Diff line change
@@ -36,7 +36,7 @@
from abc import abstractmethod
import warnings

from .misc import WHISPER_DEFAULT_PATH, SCRAIBE_TORCH_DEVICE, SCRAIBE_NUM_THREADS
import scraibe.misc
whisper = TypeVar('whisper')


@@ -122,8 +122,8 @@ def save_transcript(transcript: str, save_path: str) -> None:
def load_model(cls,
model: str = "medium",
whisper_type: str = 'whisper',
download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
download_root: str = scraibe.misc.WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = scraibe.misc.SCRAIBE_TORCH_DEVICE,
in_memory: bool = False,
*args, **kwargs
) -> None:
@@ -204,8 +204,8 @@ def transcribe(self, audio: Union[str, Tensor, ndarray],
@classmethod
def load_model(cls,
model: str = "medium",
download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
download_root: str = scraibe.misc.WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = scraibe.misc.SCRAIBE_TORCH_DEVICE,
in_memory: bool = False,
*args, **kwargs
) -> 'WhisperTranscriber':
@@ -303,8 +303,8 @@ def transcribe(self, audio: Union[str, Tensor, ndarray],
@classmethod
def load_model(cls,
model: str = "medium",
download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
download_root: str = scraibe.misc.WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = scraibe.misc.SCRAIBE_TORCH_DEVICE,
*args, **kwargs
) -> 'FasterWhisperModel':
"""
@@ -349,7 +349,7 @@ def load_model(cls,
compute_type = 'int8'
_model = FasterWhisperModel(model, download_root=download_root,
device=device, compute_type=compute_type,
cpu_threads=SCRAIBE_NUM_THREADS)
cpu_threads=scraibe.misc.SCRAIBE_NUM_THREADS)

return cls(_model, model_name=model)

@@ -411,8 +411,8 @@ def __repr__(self) -> str:

def load_transcriber(model: str = "medium",
whisper_type: str = 'whisper',
download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
download_root: str = scraibe.misc.WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = scraibe.misc.SCRAIBE_TORCH_DEVICE,
in_memory: bool = False,
*args, **kwargs
) -> Union[WhisperTranscriber, FasterWhisperTranscriber]:

0 comments on commit 931f50c

Please sign in to comment.