refactor multilingual option (#1148)
* Added a test for the `multilingual` option with English-German audio
* Removed the `output_language` argument as redundant; the same functionality is available with `task="translate"` (see the sketch below)
* Used the correct `encoder_output` for language detection in sequential transcription
* Enabled `multilingual` functionality for batched inference
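
For anyone migrating off `output_language`, a hedged sketch of the replacement usage: `multilingual=True` keeps per-segment language detection, while `task="translate"` reproduces what `output_language="en"` used to do. The model size and audio path below are placeholders.

```python
from faster_whisper import WhisperModel

model = WhisperModel("tiny")  # placeholder model size

# Previously: transcribe(..., multilingual=True, output_language="en")
# Now: per-segment detection stays behind `multilingual`; English output
# is requested through the standard `task` argument instead.
segments, info = model.transcribe(
    "audio.mp3",        # placeholder path
    multilingual=True,  # re-run language detection on every segment
    task="translate",   # replaces output_language="en"
)
for segment in segments:
    print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")
```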
MahmoudAshraf97 authored Nov 19, 2024
1 parent be9fb36 commit bcd8ce0
Showing 3 changed files with 88 additions and 39 deletions.
70 changes: 31 additions & 39 deletions faster_whisper/transcribe.py
@@ -93,7 +93,6 @@ class TranscriptionOptions:
prepend_punctuations: str
append_punctuations: str
multilingual: bool
output_language: Optional[str]
max_new_tokens: Optional[int]
clip_timestamps: Union[str, List[float]]
hallucination_silence_threshold: Optional[float]
@@ -210,10 +209,21 @@ def generate_segment_batched(
)

encoder_output = self.model.encode(features)
prompts = [prompt.copy() for _ in range(batch_size)]

if options.multilingual:
language_tokens = [
tokenizer.tokenizer.token_to_id(segment_langs[0][0])
for segment_langs in self.model.model.detect_language(encoder_output)
]
language_token_index = prompt.index(tokenizer.language)

for i, language_token in enumerate(language_tokens):
prompts[i][language_token_index] = language_token

results = self.model.model.generate(
encoder_output,
[prompt] * batch_size,
prompts,
beam_size=options.beam_size,
patience=options.patience,
length_penalty=options.length_penalty,
@@ -279,7 +289,6 @@ def transcribe(
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,,!!??::”)]}、",
multilingual: bool = False,
output_language: Optional[str] = None,
vad_filter: bool = True,
vad_parameters: Optional[Union[dict, VadOptions]] = None,
max_new_tokens: Optional[int] = None,
@@ -322,6 +331,7 @@ def transcribe(
with the next word
append_punctuations: If word_timestamps is True, merge these punctuation symbols
with the previous word
multilingual: Perform language detection on every segment.
vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
without speech. This step is using the Silero VAD model
https://github.com/snakers4/silero-vad.
@@ -360,10 +370,6 @@ def transcribe(
Arg has effect only if condition_on_previous_text is True. Set at 0.5
prefix: Optional text to provide as a prefix at the beginning of each window.
max_initial_timestamp: The initial timestamp cannot be later than this, set at 0.0.
multilingual: If True, perform transcription on multilingual videos. Set as False.
output_language: Valid only if multilingual is set to True.
Specifies the string representing the output language. One of
'en' (English) or 'hybrid' (code-switched transcription). set as None.
hallucination_silence_threshold: Optional[float]
When word_timestamps is True, skip silent periods longer than this threshold
(in seconds) when a possible hallucination is detected. set as None.
@@ -376,6 +382,13 @@

sampling_rate = self.model.feature_extractor.sampling_rate

if multilingual and not self.model.model.is_multilingual:
self.model.logger.warning(
"The current model is English-only but the multilingual parameter is set to"
"True; setting to False instead."
)
multilingual = False

if not isinstance(audio, np.ndarray):
audio = decode_audio(audio, sampling_rate=sampling_rate)
duration = audio.shape[0] / sampling_rate
@@ -498,8 +511,7 @@ def transcribe(
condition_on_previous_text=False,
clip_timestamps=clip_timestamps,
prompt_reset_on_temperature=0.5,
multilingual=False,
output_language=None,
multilingual=multilingual,
without_timestamps=without_timestamps,
max_initial_timestamp=0.0,
)
@@ -721,7 +733,6 @@ def transcribe(
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,,!!??::”)]}、",
multilingual: bool = False,
output_language: Optional[str] = None,
vad_filter: bool = False,
vad_parameters: Optional[Union[dict, VadOptions]] = None,
max_new_tokens: Optional[int] = None,
@@ -781,12 +792,7 @@ def transcribe(
with the next word
append_punctuations: If word_timestamps is True, merge these punctuation symbols
with the previous word
multilingual: If True, perform transcription on multilingual videos
and return the transcript based
on the 'output_language' flag.
output_language: Valid only if multilingual is set to True.
Specifies the string representing the output language. One of
'en' (English) or 'hybrid' (code-switched transcription).
multilingual: Perform language detection on every segment.
vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
without speech. This step is using the Silero VAD model
https://github.com/snakers4/silero-vad.
@@ -817,6 +823,13 @@ def transcribe(

sampling_rate = self.feature_extractor.sampling_rate

if multilingual and not self.model.is_multilingual:
self.logger.warning(
"The current model is English-only but the multilingual parameter is set to"
"True; setting to False instead."
)
multilingual = False

if not isinstance(audio, np.ndarray):
audio = decode_audio(audio, sampling_rate=sampling_rate)

@@ -863,13 +876,6 @@ def transcribe(
encoder_output = None
all_language_probs = None

# setting output_language for multilingual videos
if multilingual:
if output_language is None:
output_language = "en"
elif output_language not in ["en", "hybrid"]:
raise ValueError("Output language needs to be one of 'en'/'hybrid'.")

# detecting the language if not provided
if language is None:
if not self.model.is_multilingual:
@@ -949,7 +955,6 @@ def transcribe(
prepend_punctuations=prepend_punctuations,
append_punctuations=append_punctuations,
multilingual=multilingual,
output_language=output_language,
max_new_tokens=max_new_tokens,
clip_timestamps=clip_timestamps,
hallucination_silence_threshold=hallucination_silence_threshold,
@@ -1139,27 +1144,17 @@ def generate_segments(

previous_tokens = all_tokens[prompt_reset_since:]

if encoder_output is None:
if seek > 0 or encoder_output is None:
encoder_output = self.encode(segment)

# Perform language detection at every segment to update task based on output language,
# if the language is english, task is transcribe,
# else the task is translate to english (default)
# or transcribe if 'output_language' is 'hybrid'.
if options.multilingual:
results = self.model.detect_language(encoder_output)
language_token, language_probability = results[0][0]
language = language_token[2:-2]
if options.output_language == "en" and language != "en":
task = "translate"
else:
task = "transcribe"

# Update tokenizer based on task and language
tokenizer.task = tokenizer.tokenizer.token_to_id(f"<|{task}|>")
tokenizer.language = tokenizer.tokenizer.token_to_id(language_token)
tokenizer.language_code = language
# Update prompt based on task and language

prompt = self.get_prompt(
tokenizer,
previous_tokens,
@@ -1168,9 +1163,6 @@
hotwords=options.hotwords,
)

if seek > 0 or encoder_output is None:
encoder_output = self.encode(segment)

(
result,
avg_logprob,
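
The hunk above also moves segment encoding ahead of the multilingual check (`if seek > 0 or encoder_output is None`), so per-segment detection runs on the window currently being decoded rather than a stale encoder output. A small, self-contained sketch of how the top detection result becomes a language code, with the detector output stubbed:

```python
from typing import List, Tuple

def top_language(results: List[List[Tuple[str, float]]]) -> Tuple[str, float]:
    """Return (language_code, probability) for the current segment,
    stripping the "<|" and "|>" wrappers from the language token."""
    language_token, probability = results[0][0]  # e.g. ("<|de|>", 0.97)
    return language_token[2:-2], probability

print(top_language([[("<|de|>", 0.97), ("<|en|>", 0.02)]]))  # ('de', 0.97)
```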
Binary file added tests/data/multilingual.mp3
57 changes: 57 additions & 0 deletions tests/test_transcribe.py
@@ -158,6 +158,63 @@ def test_stereo_diarization(data_dir):
assert transcription == "The horizon seems extremely distant."


def test_multilingual_transcription(data_dir):
model = WhisperModel("tiny")
pipeline = BatchedInferencePipeline(model)

audio_path = os.path.join(data_dir, "multilingual.mp3")
audio = decode_audio(audio_path)

segments, info = model.transcribe(
audio,
multilingual=True,
without_timestamps=True,
condition_on_previous_text=False,
)
segments = list(segments)

assert (
segments[0].text
== " Permission is hereby granted, free of charge, to any person obtaining a copy of the"
" software and associated documentation files to deal in the software without restriction,"
" including without limitation the rights to use, copy, modify, merge, publish, distribute"
", sublicence, and or cell copies of the software, and to permit persons to whom the "
"software is furnished to do so, subject to the following conditions. The above copyright"
" notice and this permission notice, shall be included in all copies or substantial "
"portions of the software."
)

assert (
segments[1].text
== " Jedem, der dieses Software und die dazu gehöregen Dokumentationsdatein erhält, wird "
"hiermit unengeltlich die Genehmigung erteilt, wird der Software und eingeschränkt zu "
"verfahren. Dies umfasst insbesondere das Recht, die Software zu verwenden, zu "
"vervielfältigen, zu modifizieren, zu Samenzofügen, zu veröffentlichen, zu verteilen, "
"unterzulizenzieren und oder kopieren der Software zu verkaufen und diese Rechte "
"unterfolgen den Bedingungen anderen zu übertragen."
)

segments, info = pipeline.transcribe(audio, multilingual=True)
segments = list(segments)

assert (
segments[0].text
== " Permission is hereby granted, free of charge, to any person obtaining a copy of the"
" software and associated documentation files to deal in the software without restriction,"
" including without limitation the rights to use, copy, modify, merge, publish, distribute"
", sublicence, and or cell copies of the software, and to permit persons to whom the "
"software is furnished to do so, subject to the following conditions. The above copyright"
" notice and this permission notice, shall be included in all copies or substantial "
"portions of the software."
)
assert (
"Dokumentationsdatein erhält, wird hiermit unengeltlich die Genehmigung erteilt,"
" wird der Software und eingeschränkt zu verfahren. Dies umfasst insbesondere das Recht,"
" die Software zu verwenden, zu vervielfältigen, zu modifizieren"
in segments[1].text
)


def test_suppressed_tokens_minus_1():
model = WhisperModel("tiny.en")

