Commit

Automatically download converted models from the Hugging Face Hub (SYSTRAN#70)

* Automatically download converted models from the Hugging Face Hub

* Remove unused import

* Remove unneeded requirements in dev mode

* Remove extra index URL when pip install in CI

* Allow downloading to a specific directory

* Update docstring

* Add argument to disable the progress bars

* Fix typo in docstring
guillaumekln authored Mar 24, 2023
1 parent 523ae21 commit de7682a
Showing 10 changed files with 105 additions and 53 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
@@ -25,7 +25,7 @@ jobs:
- name: Install module
run: |
pip install wheel
pip install .[dev] --extra-index-url https://download.pytorch.org/whl/cpu
pip install -e .[dev]
- name: Check code format with Black
run: |
@@ -55,11 +55,11 @@ jobs:
- name: Install module
run: |
pip install wheel
pip install .[dev] --extra-index-url https://download.pytorch.org/whl/cpu
pip install -e .[dev]
- name: Run pytest
run: |
pytest -v tests/test.py
pytest -v tests/
build-and-push-package:
49 changes: 24 additions & 25 deletions README.md
@@ -44,12 +44,6 @@ The module can be installed from [PyPI](https://pypi.org/project/faster-whisper/
pip install faster-whisper
```

The model conversion script requires the modules `transformers` and `torch` which can be installed with the `[conversion]` extra requirement:

```bash
pip install faster-whisper[conversion]
```

**Other installation methods:**

```bash
@@ -70,35 +64,20 @@ GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be inst

## Usage

### Model conversion

A Whisper model should be first converted into the CTranslate2 format. We provide a script to download and convert models from the [Hugging Face model repository](https://huggingface.co/models?sort=downloads&search=whisper).

For example the command below converts the "large-v2" Whisper model and saves the weights in FP16:

```bash
ct2-transformers-converter --model openai/whisper-large-v2 --output_dir whisper-large-v2-ct2 \
--copy_files tokenizer.json --quantization float16
```

If the option `--copy_files tokenizer.json` is not used, the tokenizer configuration is automatically downloaded when the model is loaded later.

Models can also be converted from the code. See the [conversion API](https://opennmt.net/CTranslate2/python/ctranslate2.converters.TransformersConverter.html).

### Transcription

```python
from faster_whisper import WhisperModel

model_path = "whisper-large-v2-ct2/"
model_size = "large-v2"

# Run on GPU with FP16
model = WhisperModel(model_path, device="cuda", compute_type="float16")
model = WhisperModel(model_size, device="cuda", compute_type="float16")

# or run on GPU with INT8
# model = WhisperModel(model_path, device="cuda", compute_type="int8_float16")
# model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
# or run on CPU with INT8
# model = WhisperModel(model_path, device="cpu", compute_type="int8")
# model = WhisperModel(model_size, device="cpu", compute_type="int8")

segments, info = model.transcribe("audio.mp3", beam_size=5)

@@ -120,6 +99,26 @@ for segment in segments:

See more model and transcription options in the [`WhisperModel`](https://github.com/guillaumekln/faster-whisper/blob/master/faster_whisper/transcribe.py) class implementation.
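
For instance, a minimal sketch of word-level timestamps (assuming each segment exposes a `words` list whose items carry `start`, `end`, and `word` fields, as exercised in `tests/test_transcribe.py` below):

```python
from faster_whisper import WhisperModel

model = WhisperModel("tiny")

# word_timestamps=True additionally populates per-word timings on each segment
segments, _ = model.transcribe("audio.mp3", word_timestamps=True)

for segment in segments:
    for word in segment.words:
        print("[%.2fs -> %.2fs] %s" % (word.start, word.end, word.word))
```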

## Model conversion

When loading a model by its size, such as `WhisperModel("large-v2")`, the corresponding CTranslate2 model is automatically downloaded from the [Hugging Face Hub](https://huggingface.co/guillaumekln).

We also provide a script to convert any Whisper model compatible with the Transformers library, whether an original OpenAI model or a user fine-tuned model.

For example, the command below converts the [original "large-v2" Whisper model](https://huggingface.co/openai/whisper-large-v2) and saves the weights in FP16:

```bash
pip install "transformers[torch]>=4.23"

ct2-transformers-converter --model openai/whisper-large-v2 --output_dir whisper-large-v2-ct2 \
--copy_files tokenizer.json --quantization float16
```

* The option `--model` accepts a model name on the Hub or a path to a model directory.
* If the option `--copy_files tokenizer.json` is not used, the tokenizer configuration is automatically downloaded when the model is loaded later.

Models can also be converted from the code. See the [conversion API](https://opennmt.net/CTranslate2/python/ctranslate2.converters.TransformersConverter.html).
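
As a rough sketch (assuming the same converter API used in the `tests/conftest.py` hunk removed below), the equivalent Python call could look like:

```python
import ctranslate2

# Convert the Transformers "large-v2" Whisper checkpoint to CTranslate2,
# copying the tokenizer and quantizing the weights to FP16.
ctranslate2.converters.TransformersConverter(
    "openai/whisper-large-v2",
    copy_files=["tokenizer.json"],
    load_as_float16=True,
).convert("whisper-large-v2-ct2", quantization="float16")
```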

## Comparing performance against other implementations

If you are comparing the performance against other Whisper implementations, you should make sure to run the comparison with similar settings. In particular:
3 changes: 2 additions & 1 deletion faster_whisper/__init__.py
@@ -1,9 +1,10 @@
from faster_whisper.audio import decode_audio
from faster_whisper.transcribe import WhisperModel
from faster_whisper.utils import format_timestamp
from faster_whisper.utils import download_model, format_timestamp

__all__ = [
"decode_audio",
"WhisperModel",
"download_model",
"format_timestamp",
]
12 changes: 10 additions & 2 deletions faster_whisper/transcribe.py
@@ -11,6 +11,7 @@
from faster_whisper.audio import decode_audio
from faster_whisper.feature_extractor import FeatureExtractor
from faster_whisper.tokenizer import Tokenizer
from faster_whisper.utils import download_model


class Word(NamedTuple):
@@ -57,7 +58,7 @@ class TranscriptionOptions(NamedTuple):
class WhisperModel:
def __init__(
self,
model_path: str,
model_size_or_path: str,
device: str = "auto",
device_index: Union[int, List[int]] = 0,
compute_type: str = "default",
@@ -67,7 +68,9 @@ def __init__(
"""Initializes the Whisper model.
Args:
model_path: Path to the converted model.
model_size_or_path: Size of the model to use (e.g. "large-v2", "small", "tiny.en", etc.)
or a path to a converted model directory. When a size is configured, the converted
model is downloaded from the Hugging Face Hub.
device: Device to use for computation ("cpu", "cuda", "auto").
device_index: Device ID to use.
The model can also be loaded on multiple GPUs by passing a list of IDs
@@ -82,6 +85,11 @@ def __init__(
(concurrent calls to self.model.generate() will run in parallel).
This can improve the global throughput at the cost of increased memory usage.
"""
if os.path.isdir(model_size_or_path):
model_path = model_size_or_path
else:
model_path = download_model(model_size_or_path)

self.model = ctranslate2.models.Whisper(
model_path,
device=device,
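
With this branch in place, both call forms below should resolve to a loadable model directory (a minimal sketch; "tiny" is one of the converted sizes published on the Hub):

```python
from faster_whisper import WhisperModel

# A known size: the converted model is fetched from the Hugging Face Hub.
model = WhisperModel("tiny")

# An existing local directory: loaded as-is, no download.
model = WhisperModel("whisper-large-v2-ct2/")
```
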
45 changes: 45 additions & 0 deletions faster_whisper/utils.py
@@ -1,3 +1,42 @@
from typing import Optional

import huggingface_hub

from tqdm.auto import tqdm


def download_model(
size: str,
output_dir: Optional[str] = None,
show_progress_bars: bool = True,
):
"""Downloads a CTranslate2 Whisper model from the Hugging Face Hub.
The model is downloaded from https://huggingface.co/guillaumekln.
Args:
size: Size of the model to download (tiny, tiny.en, base, base.en, small, small.en,
medium, medium.en, or large-v2).
output_dir: Directory where the model should be saved. If not set, the model is saved in
the standard Hugging Face cache directory.
show_progress_bars: Show the tqdm progress bars during the download.
Returns:
The path to the downloaded model.
"""
repo_id = "guillaumekln/faster-whisper-%s" % size
kwargs = {}

if output_dir is not None:
kwargs["local_dir"] = output_dir
kwargs["local_dir_use_symlinks"] = False

if not show_progress_bars:
kwargs["tqdm_class"] = disabled_tqdm

return huggingface_hub.snapshot_download(repo_id, **kwargs)


def format_timestamp(
seconds: float,
always_include_hours: bool = False,
@@ -19,3 +58,9 @@ def format_timestamp(
return (
f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
)


class disabled_tqdm(tqdm):
def __init__(self, *args, **kwargs):
kwargs["disable"] = True
super().__init__(*args, **kwargs)
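
A quick usage sketch of the new helper (mirroring `tests/test_utils.py` below; the output path is illustrative):

```python
from faster_whisper import download_model

# Fetch the converted "tiny" model into a local directory, without progress bars.
model_dir = download_model("tiny", output_dir="whisper-tiny-ct2", show_progress_bars=False)

print(model_dir)  # == "whisper-tiny-ct2"
```
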
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,3 +1,4 @@
av==10.*
ctranslate2>=3.10,<4
huggingface_hub>=0.13
tokenizers==0.13.*
3 changes: 1 addition & 2 deletions setup.py
@@ -48,8 +48,7 @@ def get_requirements(path):
install_requires=install_requires,
extras_require={
"conversion": conversion_requires,
"dev": conversion_requires
+ [
"dev": [
"black==23.*",
"flake8==6.*",
"isort==5.*",
18 changes: 0 additions & 18 deletions tests/conftest.py
@@ -1,6 +1,5 @@
import os

import ctranslate2
import pytest


@@ -12,20 +11,3 @@ def data_dir():
@pytest.fixture
def jfk_path(data_dir):
return os.path.join(data_dir, "jfk.flac")


@pytest.fixture(scope="session")
def tiny_model_dir(tmp_path_factory):
model_path = str(tmp_path_factory.mktemp("data") / "model")
convert_model("tiny", model_path)
return model_path


def convert_model(size, output_dir):
name = "openai/whisper-%s" % size

ctranslate2.converters.TransformersConverter(
name,
copy_files=["tokenizer.json"],
load_as_float16=True,
).convert(output_dir, quantization="float16")
4 changes: 2 additions & 2 deletions tests/test.py → tests/test_transcribe.py
@@ -1,8 +1,8 @@
from faster_whisper import WhisperModel


def test_transcribe(tiny_model_dir, jfk_path):
model = WhisperModel(tiny_model_dir)
def test_transcribe(jfk_path):
model = WhisperModel("tiny")
segments, info = model.transcribe(jfk_path, word_timestamps=True)

assert info.language == "en"
17 changes: 17 additions & 0 deletions tests/test_utils.py
@@ -0,0 +1,17 @@
import os

from faster_whisper import download_model


def test_download_model(tmpdir):
output_dir = str(tmpdir.join("model"))

model_dir = download_model("tiny", output_dir=output_dir)

assert model_dir == output_dir
assert os.path.isdir(model_dir)
assert not os.path.islink(model_dir)

for filename in os.listdir(model_dir):
path = os.path.join(model_dir, filename)
assert not os.path.islink(path)
