diff --git a/scraibe/cli.py b/scraibe/cli.py index e4eeaad..323a1b1 100644 --- a/scraibe/cli.py +++ b/scraibe/cli.py @@ -100,6 +100,7 @@ def str2bool(string): 'whisper_type':arg_dict.pop("whisper_type"), 'dia_model': arg_dict.pop("diarization_directory"), 'use_auth_token': arg_dict.pop("hf_token"), + 'device': arg_dict.pop('inference_device') } if arg_dict["whisper_model_directory"]: diff --git a/scraibe/transcriber.py b/scraibe/transcriber.py index bc341dc..a5d6f37 100644 --- a/scraibe/transcriber.py +++ b/scraibe/transcriber.py @@ -36,7 +36,7 @@ from abc import abstractmethod import warnings -from .misc import WHISPER_DEFAULT_PATH, SCRAIBE_TORCH_DEVICE, SCRAIBE_NUM_THREADS +import scraibe.misc whisper = TypeVar('whisper') @@ -122,8 +122,8 @@ def save_transcript(transcript: str, save_path: str) -> None: def load_model(cls, model: str = "medium", whisper_type: str = 'whisper', - download_root: str = WHISPER_DEFAULT_PATH, - device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE, + download_root: str = scraibe.misc.WHISPER_DEFAULT_PATH, + device: Optional[Union[str, device]] = scraibe.misc.SCRAIBE_TORCH_DEVICE, in_memory: bool = False, *args, **kwargs ) -> None: @@ -204,8 +204,8 @@ def transcribe(self, audio: Union[str, Tensor, ndarray], @classmethod def load_model(cls, model: str = "medium", - download_root: str = WHISPER_DEFAULT_PATH, - device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE, + download_root: str = scraibe.misc.WHISPER_DEFAULT_PATH, + device: Optional[Union[str, device]] = scraibe.misc.SCRAIBE_TORCH_DEVICE, in_memory: bool = False, *args, **kwargs ) -> 'WhisperTranscriber': @@ -303,8 +303,8 @@ def transcribe(self, audio: Union[str, Tensor, ndarray], @classmethod def load_model(cls, model: str = "medium", - download_root: str = WHISPER_DEFAULT_PATH, - device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE, + download_root: str = scraibe.misc.WHISPER_DEFAULT_PATH, + device: Optional[Union[str, device]] = scraibe.misc.SCRAIBE_TORCH_DEVICE, *args, **kwargs ) -> 'FasterWhisperModel': """ @@ -349,7 +349,7 @@ def load_model(cls, compute_type = 'int8' _model = FasterWhisperModel(model, download_root=download_root, device=device, compute_type=compute_type, - cpu_threads=SCRAIBE_NUM_THREADS) + cpu_threads=scraibe.misc.SCRAIBE_NUM_THREADS) return cls(_model, model_name=model) @@ -411,8 +411,8 @@ def __repr__(self) -> str: def load_transcriber(model: str = "medium", whisper_type: str = 'whisper', - download_root: str = WHISPER_DEFAULT_PATH, - device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE, + download_root: str = scraibe.misc.WHISPER_DEFAULT_PATH, + device: Optional[Union[str, device]] = scraibe.misc.SCRAIBE_TORCH_DEVICE, in_memory: bool = False, *args, **kwargs ) -> Union[WhisperTranscriber, FasterWhisperTranscriber]: