Skip to content

Commit

Permalink
fix(transcribe): fix censor
Browse files Browse the repository at this point in the history
`re` has been added to the imports and `censor_path` added to the parameters. The goal is to allow users to create their own censor JSON file rather than have one supplied to them. When the censor flag is set, a check verifies that the file exists; if it does not exist or is not the proper file type, censoring is disabled. Both the segments and the full text are censored. The returned dict was assigned to a variable called "data" to allow this to occur. The alternative would be text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]) if not censor else censor_text(tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]), forbidden_words)... which is much more difficult to read.

BREAKING CHANGE: I have not confirmed any issues yet; however, the censor may misbehave if the JSON file has an unexpected format or structure.

Signed-off-by: matt@aero <[email protected]>
  • Loading branch information
MotoMatt5040 committed Feb 5, 2025
1 parent 517a43e commit 8547848
Showing 1 changed file with 49 additions and 3 deletions.
52 changes: 49 additions & 3 deletions whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import os
import traceback
import warnings
import json
import re
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -52,6 +54,8 @@ def transcribe(
append_punctuations: str = "\"'.。,,!!??::”)]}、",
clip_timestamps: Union[str, List[float]] = "0",
hallucination_silence_threshold: Optional[float] = None,
censor: bool = False,
censor_path: str = None,
**decode_options,
):
"""
Expand Down Expand Up @@ -124,6 +128,8 @@ def transcribe(
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
"""


dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
if model.device == torch.device("cpu"):
if torch.cuda.is_available():
Expand Down Expand Up @@ -165,6 +171,21 @@ def transcribe(
task=task,
)

forbidden_words = []
if censor:
if (
censor_path is None
or not os.path.exists(censor_path)
or not censor_path.endswith(".json")
):
warnings.warn("Please provide a valid censor directory, censoring disabled.")
censor = False
else:
with open(f'{censor_path}', 'r') as f:
censor_data = json.load(f)

forbidden_words = censor_data.get(language, [])

if isinstance(clip_timestamps, str):
clip_timestamps = [
float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
Expand Down Expand Up @@ -243,16 +264,32 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
else:
initial_prompt_tokens = []

def censor_text(text, forbidden):
    """Replace every forbidden word in *text* with asterisks of equal length.

    Matching is case-insensitive and token-based: the text is split into
    word tokens (``\\w+``) and single non-space punctuation characters
    (``[^\\w\\s]``), so punctuation adjacent to a censored word is kept
    intact (e.g. ``"Damn!"`` -> ``"****!"``).

    Parameters
    ----------
    text : str
        The text to censor.
    forbidden : container of str
        Lowercase words to mask. Membership is tested with ``in``, so a
        ``set`` is the natural choice; a ``list`` also works.

    Returns
    -------
    str
        *text* with each forbidden word replaced by ``'*' * len(word)``.
    """

    def censor_match(match):
        word = match.group(0)
        # BUGFIX: compare against the `forbidden` parameter, not the
        # enclosing scope's `forbidden_words` — the parameter was dead
        # and the function silently depended on outer state.
        # Mask with one '*' per character so the output keeps the
        # original word length.
        return '*' * len(word) if word.lower() in forbidden else word

    censored_text = re.sub(r'\w+|[^\w\s]', censor_match, text)

    return censored_text

def new_segment(
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
):
tokens = tokens.tolist()
text_tokens = [token for token in tokens if token < tokenizer.eot]

if censor:
text = censor_text(tokenizer.decode(text_tokens), forbidden_words)
else:
text = tokenizer.decode(text_tokens)

return {
"seek": seek,
"start": start,
"end": end,
"text": tokenizer.decode(text_tokens),
"text": text,
"tokens": tokens,
"temperature": result.temperature,
"avg_logprob": result.avg_logprob,
Expand Down Expand Up @@ -507,12 +544,19 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]:
# update progress bar
pbar.update(min(content_frames, seek) - previous_seek)

return dict(
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
text = tokenizer.decode(all_tokens[len(initial_prompt_tokens) :])

if censor:
text = censor_text(text, forbidden_words)

data = dict(
text=text,
segments=all_segments,
language=language,
)

return data


def cli():
from . import available_models
Expand All @@ -533,6 +577,8 @@ def valid_model_name(name):
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
parser.add_argument("--censor", type=str2bool, default=True, help="(requires --censor_path=\"<path>\") whether to censor out profanity or not")
parser.add_argument("--censor_path", type=str2bool, default=True, help="censored words path. Use json format - {lang: [words]}")

parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
Expand Down

0 comments on commit 8547848

Please sign in to comment.