feat(whisper): support whisper arch #2141
Conversation
Looking forward to the ...
Confirm that:

import torchaudio
import numpy as np
from subprocess import CalledProcessError, run

wav_file = "BAC009S0724W0121.wav"

# 1. torchaudio
waveform_torchaudio, sample_rate = torchaudio.load(wav_file)
waveform_torchaudio = waveform_torchaudio.numpy().flatten().astype(np.float32)

# 2. whisper
SAMPLE_RATE = 16000
N_FFT = 400
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
N_FRAMES = N_SAMPLES // HOP_LENGTH  # 3000 frames in a mel spectrogram input
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions have stride 2
FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH  # 10ms per audio frame
TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN  # 20ms per audio token

def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
    Open an audio file and read as mono waveform, resampling as necessary

    Parameters
    ----------
    file: str
        The audio file to open
    sr: int
        The sample rate to resample the audio if necessary

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.
    """
    # This launches a subprocess to decode audio while down-mixing
    # and resampling as necessary. Requires the ffmpeg CLI in PATH.
    # fmt: off
    cmd = [
        "ffmpeg",
        "-nostdin",
        "-threads", "0",
        "-i", file,
        "-f", "s16le",
        "-ac", "1",
        "-acodec", "pcm_s16le",
        "-ar", str(sr),
        "-"
    ]
    # fmt: on
    try:
        out = run(cmd, capture_output=True, check=True).stdout
    except CalledProcessError as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0

waveform_whisper = load_audio(wav_file)

# 3. compare
print(waveform_torchaudio.shape, waveform_torchaudio[:10])
print(waveform_whisper.shape, waveform_whisper[:10])
print(np.allclose(waveform_torchaudio, waveform_whisper))
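The wav used here is already 16 kHz mono (AISHELL-1); if it were not, the torchaudio side would need an explicit down-mix and resample before the two waveforms could match. A minimal sketch using torchaudio.functional.resample:

import torchaudio
import torchaudio.functional as AF

waveform, sr = torchaudio.load(wav_file)
if waveform.shape[0] > 1:
    waveform = waveform.mean(dim=0, keepdim=True)  # down-mix to mono
if sr != SAMPLE_RATE:
    waveform = AF.resample(waveform, orig_freq=sr, new_freq=SAMPLE_RATE)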
Does the fbank part need any changes?
That is exactly what I am verifying.
Compared to:

import torch
import torchaudio
import librosa
import numpy as np
import torch.nn.functional as F
import torchaudio.transforms as T
from functools import lru_cache
from typing import Optional, Union
from subprocess import CalledProcessError, run

wav_file = "BAC009S0724W0121.wav"
N_MEL = 128
SAMPLE_RATE = 16000
N_FFT = 400
HOP_LENGTH = 160
WINDOW_LENGTH = N_FFT

# 1. torchaudio
waveform_torchaudio, sample_rate = torchaudio.load(wav_file)

def torchaudio_log_mel_spectrogram(
    audio: torch.Tensor,
    n_mels: int = 80,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
):
    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))
    window = torch.hann_window(N_FFT).to(audio.device)
    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
    magnitudes = stft[..., :-1].abs() ** 2
    mel_transform = T.MelScale(
        n_mels=n_mels,  # use the argument instead of the global N_MEL
        sample_rate=SAMPLE_RATE,
        n_stft=(N_FFT // 2) + 1,
        norm="slaney",
        mel_scale="slaney"
    )
    mel_spec = mel_transform(magnitudes).squeeze(0)
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec

mat_torchaudio = torchaudio_log_mel_spectrogram(waveform_torchaudio,
                                                n_mels=N_MEL).numpy()

# 2. librosa
def librosa_log_mel_spectrogram(
    audio: torch.Tensor,
    n_mels: int = 80,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
):
    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))
    window = torch.hann_window(N_FFT).to(audio.device)
    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
    magnitudes = stft[..., :-1].abs() ** 2
    filters = torch.from_numpy(
        librosa.filters.mel(sr=SAMPLE_RATE, n_fft=N_FFT, n_mels=n_mels)
    ).to(magnitudes.device)
    mel_spec = filters @ magnitudes
    log_spec = torch.clamp(mel_spec, min=1e-10).log10().squeeze(0)
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec

mat_librosa = librosa_log_mel_spectrogram(waveform_torchaudio,
                                          n_mels=N_MEL).numpy()

# 3. whisper
def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
    Open an audio file and read as mono waveform, resampling as necessary

    Parameters
    ----------
    file: str
        The audio file to open
    sr: int
        The sample rate to resample the audio if necessary

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.
    """
    # This launches a subprocess to decode audio while down-mixing
    # and resampling as necessary. Requires the ffmpeg CLI in PATH.
    # fmt: off
    cmd = [
        "ffmpeg",
        "-nostdin",
        "-threads", "0",
        "-i", file,
        "-f", "s16le",
        "-ac", "1",
        "-acodec", "pcm_s16le",
        "-ar", str(sr),
        "-"
    ]
    # fmt: on
    try:
        out = run(cmd, capture_output=True, check=True).stdout
    except CalledProcessError as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0

@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int) -> torch.Tensor:
    """
    Load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
    Allows decoupling the librosa dependency; saved using:

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
            mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
        )
    """
    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
    filters_path = "mel_filters.npz"
    with np.load(filters_path, allow_pickle=False) as f:
        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)

def log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int = 80,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
):
    """
    Compute the log-Mel spectrogram of the audio

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        The path to audio, or either a NumPy array or Tensor
        containing the audio waveform at 16 kHz
    n_mels: int
        The number of Mel-frequency filters, only 80 and 128 are supported
    padding: int
        Number of zero samples to pad to the right
    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before STFT

    Returns
    -------
    torch.Tensor, shape = (n_mels, n_frames)
        A Tensor that contains the Mel spectrogram
    """
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio = torch.from_numpy(audio)
    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))
    window = torch.hann_window(N_FFT).to(audio.device)
    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
    magnitudes = stft[..., :-1].abs() ** 2
    filters = mel_filters(audio.device, n_mels)
    mel_spec = filters @ magnitudes
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec

mat_whisper = log_mel_spectrogram(
    wav_file, n_mels=N_MEL, padding=0, device="cpu").numpy()

# 4. compare
print("torchaudio\n", mat_torchaudio.shape, mat_torchaudio[:10])
print("librosa\n", mat_librosa.shape, mat_librosa[:10])
print("whisper\n", mat_whisper.shape, mat_whisper[:10])
print("=================== librosa v.s. whisper =====================")
print("librosa v.s. whisper", np.allclose(mat_librosa, mat_whisper, atol=1e-06))
np.testing.assert_allclose(mat_librosa, mat_whisper, atol=1e-06)
print("=================== torchaudio v.s. whisper =====================")
print("torchaudio v.s. whisper", np.allclose(mat_torchaudio, mat_whisper, atol=1e-04))
np.testing.assert_allclose(mat_torchaudio, mat_whisper, atol=1e-04)
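The different tolerances (1e-6 for librosa, 1e-4 for torchaudio) can be made concrete by printing the maximum absolute deviations; a tiny follow-up using the matrices computed above:

print("max |librosa - whisper|   :", np.abs(mat_librosa - mat_whisper).max())
print("max |torchaudio - whisper|:", np.abs(mat_torchaudio - mat_whisper).max())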
Confirm that:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright [2023-11-15] <[email protected], Xingchen Song>

import torch
import math
import numpy as np

def wenet_sinusoids(length, channels):
    """Returns sinusoids for positional embedding"""
    d_model = channels
    xscale = math.sqrt(d_model)  # unused here; kept to mirror wenet's code
    max_len = length
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len,
                            dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32) *
        -(math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

def whisper_sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding"""
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)

wenet_pe = wenet_sinusoids(100, 512).numpy()
whisper_pe = whisper_sinusoids(100, 512).numpy()
print(wenet_pe.shape)
print(whisper_pe.shape)
np.testing.assert_allclose(wenet_pe, whisper_pe, atol=1e-8)
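Note that, as written, the assertion should fail: the two constructions differ both in layout (wenet interleaves sin/cos per dimension pair, whisper concatenates the full sin half and cos half) and in timescale (the exponent denominator is d_model in wenet but channels // 2 - 1 in whisper). A small sketch, using the arrays computed above, that undoes the layout difference and exposes the remaining timescale gap:

half = 512 // 2
# column perm mapping whisper's [sin | cos] layout onto wenet's interleaving:
# wenet[:, 2 * i] = whisper[:, i], wenet[:, 2 * i + 1] = whisper[:, half + i]
perm = torch.stack([torch.arange(half), torch.arange(half) + half], dim=1).flatten()
whisper_interleaved = torch.from_numpy(whisper_pe)[:, perm].numpy()
# still nonzero, because 10000 ** (-i / 256) != 10000 ** (-i / 255)
print(np.abs(wenet_pe - whisper_interleaved).max())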
A suggestion: just bring in whisper directly as a dependency; whisper's requirements basically don't conflict with ours.
If we import the whisper package and only use its log_mel_spec and tokenizer, then mel_filter.npz and the xx.tiktoken files no longer need to be downloaded manually. The CLI could also reuse whisper's download-related functions to fetch checkpoints.
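A minimal sketch of that reuse (assuming a recent openai-whisper install, new enough to support n_mels=128; these are whisper's public helpers, and the wav path is just an example):

from whisper.audio import load_audio, log_mel_spectrogram
from whisper.tokenizer import get_tokenizer

# the package ships mel_filters.npz and the *.tiktoken vocabularies,
# so neither file has to be fetched by hand
mel = log_mel_spectrogram(load_audio("BAC009S0724W0121.wav"), n_mels=128)
tokenizer = get_tokenizer(multilingual=True)
print(mel.shape, tokenizer.encoding.encode("他还增设了一种叫消防费的收费项目"))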
whisper-style decoder input:

# excerpt: `torch`, `Tuple`, `pad_list` and `WHISPER_LANGS` come from the
# surrounding wenet module
def add_whisper_tokens(
    tokenizer, ys_pad: torch.Tensor,
    ignore_id: int, task_id: int, no_timestamp: bool,
    language: str, use_prev: bool
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Add whisper-style tokens.

    ([PREV] -> [previous text tokens or hotwords]).optional --
      ┌------------------------------------------------------↲
      ↓
    [sot] -> [language id] -> [transcribe] -> [begin time] -> [text tokens] -> [end time] -> ... -> [eot]  # noqa
        |         |               |-------> [no timestamps] -> [text tokens] ------------------------↑     # noqa
        |         |                                                                                  |     # noqa
        |         |--------> [translate] -> [begin time] -> [text tokens] -> [end time] -> ... ----->|     # noqa
        |                         |-------> [no timestamps] -> [text tokens] ----------------------->|     # noqa
        |                                                                                            |     # noqa
        |--> [no speech(VAD)] ---------------------------------------------------------------------->|     # noqa

    Args:
        tokenizer: get IDs of special tokens
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
        ignore_id (int): index of padding
        task_id (int): id of the task token ([transcribe] / [translate] / ...)
        no_timestamp (bool): whether to add timestamps tokens
        language (str): language tag
        use_prev (bool): whether to prepend [PREV] context (e.g. hotwords)

    Returns:
        ys_in (torch.Tensor) : (B, Lmax + ?)
        ys_out (torch.Tensor) : (B, Lmax + ?)
    """
    if use_prev:
        # i.e., hotword list
        _prev = [tokenizer.sot_prev]
        # append hotword list to _prev
        # ...
        raise NotImplementedError
    else:
        _prev = []

    language_id = tokenizer.sot + 1 + WHISPER_LANGS.index(language)
    _sot = _prev + [tokenizer.sot, language_id, task_id]
    _eot = torch.tensor([tokenizer.eot],
                        dtype=torch.long,
                        requires_grad=False,
                        device=ys_pad.device)
    ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys

    if task_id == tokenizer.transcribe or task_id == tokenizer.translate:
        if no_timestamp:
            _sot.append(tokenizer.no_timestamps)
        else:
            _sot.append(tokenizer.timestamp_begin)
            # add subsequent tokens
            # ...
            raise NotImplementedError
    elif task_id == tokenizer.no_speech:
        _sot.append(tokenizer.no_speech)
    else:
        raise NotImplementedError

    _sot = torch.tensor(_sot, dtype=torch.long,
                        requires_grad=False, device=ys_pad.device)
    ys_in = [torch.cat([_sot, y], dim=0) for y in ys]
    ys_out = [torch.cat([_sot[1:], y, _eot], dim=0) for y in ys]
    return pad_list(ys_in, tokenizer.eot), pad_list(ys_out, ignore_id)
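To make the teacher-forcing layout concrete, a toy walk-through with a stub tokenizer (every id below is invented for illustration and is not whisper's real vocabulary):

import torch
from types import SimpleNamespace

tok = SimpleNamespace(sot=100, transcribe=103, no_timestamps=104, eot=105)
lang_id = 101                 # stands in for sot + 1 + WHISPER_LANGS.index(language)
y = torch.tensor([7, 8, 9])   # text tokens, padding already stripped

_sot = torch.tensor([tok.sot, lang_id, tok.transcribe, tok.no_timestamps])
ys_in = torch.cat([_sot, y])                                # decoder input
ys_out = torch.cat([_sot[1:], y, torch.tensor([tok.eot])])  # decoder target
print(ys_in.tolist())   # [100, 101, 103, 104, 7, 8, 9]
print(ys_out.tolist())  # [101, 103, 104, 7, 8, 9, 105]
# each input position predicts the next token: ys_out is ys_in shifted left
# with [eot] appended, so the loss never asks the model to predict [sot]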
Added unit tests (the unit tests include ...
Judging from the post-softmax values, once the probabilities are renormalized this should not affect the decoding result.
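The general point can be sketched on toy numbers: renormalizing probabilities over the same candidates rescales them all by one positive constant, so their ranking, and therefore the decoded token, cannot change.

import torch

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
p = torch.softmax(logits, dim=-1)
keep = torch.tensor([0, 1, 3])          # e.g. a filtered vocabulary subset
q = p[keep] / p[keep].sum()             # renormalize the kept entries
print(p[keep].argsort(descending=True).tolist())  # [0, 1, 2]
print(q.argsort(descending=True).tolist())        # identical ordering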
Unit tests pass, ready for final review @Mddct @robin1001 @whiteshirt0429. P.S. The unit tests currently take about 7-8 minutes.
From binbin:

Related implementations from some open-source tools that integrate whisper (for reference): ...

Specify ...
A potential whisper issue: whisper's encoder and decoder both use fixed-length embeddings (encoder: 30 s of audio, decoder: 448 tokens), so with speed_perturb enabled, a 30 s audio clip at a 1.1 speed factor may trigger an assertion error. #2171
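One way to guard against that failure mode is to filter by worst-case length before training; a hypothetical sketch for a wenet-style pipeline (the constants come from whisper, the helper name is invented):

MAX_AUDIO_SECONDS = 30.0   # whisper encoder receptive field
MAX_TEXT_TOKENS = 448      # whisper decoder positions

def fits_whisper(num_samples: int, num_tokens: int,
                 sample_rate: int = 16000,
                 slowest_speed: float = 0.9) -> bool:
    """Reject samples that could overflow whisper's fixed-length embeddings,
    including the worst-case lengthening from speed perturbation
    (a 0.9x speed factor stretches duration by 1 / 0.9)."""
    worst_case_seconds = num_samples / sample_rate / slowest_speed
    return worst_case_seconds <= MAX_AUDIO_SECONDS and num_tokens <= MAX_TEXT_TOKENS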
Efficiency comparison between whisper's tokenizer and the char tokenizer commonly used for Chinese:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright [2023-11-30] <[email protected], Xingchen Song>

import statistics as s

from whisper.tokenizer import get_tokenizer

tokenizer = get_tokenizer(multilingual=True, num_languages=100)
char_lens, token_lens = [], []
with open("data/train/text", "r") as f:
    lines = f.readlines()
for l in lines:
    l = l.strip().split()[1]
    char_len = len(l)
    token_len = len(tokenizer.encoding.encode(l))
    char_lens.append(char_len)
    token_lens.append(token_len)
    print("{} {} {}".format(char_len, token_len, l))
print("Mean: CharTokenizer {}, WhisperTokenizer {}".format(s.mean(char_lens),
                                                           s.mean(token_lens)))
print("Var : CharTokenizer {}, WhisperTokenizer {}".format(s.variance(char_lens),
                                                           s.variance(token_lens)))

Output (excerpt):

...
16 21 他还增设了一种叫消防费的收费项目
10 11 总共收上来一四万馀元
8 9 都揣进了自己腰包
20 32 改装小货车撞上出租车钢筋将出租车后座顶出
12 14 八月二零日晚将近一零点钟
15 23 一辆改装的小货车撞上一辆出租车
15 25 货车装载的钢筋将出租车后座对穿
9 12 幸亏车后座没有乘客
13 20 钢筋穿过后排座位并顶出车外
11 11 案件性质关系到国家安全
Mean: CharTokenizer 14.405843561091775, WhisperTokenizer 17.837832436843243
Var : CharTokenizer 18.59714879629654, WhisperTokenizer 35.86092595601854
Efficiency comparison between whisper's tokenizer and the BPE tokenizer commonly used for English:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright [2023-11-30] <[email protected], Xingchen Song>

import statistics as s

import sentencepiece as spm
from whisper.tokenizer import get_tokenizer
from wenet.text.tokenize_utils import tokenize_by_bpe_model

tokenizer = get_tokenizer(multilingual=True, num_languages=100)
bpe_model = spm.SentencePieceProcessor()
bpe_model.load("../../../test/resources/librispeech.train_960_unigram5000.bpemodel")
bpe_lens, token_lens = [], []
with open("data/test_clean/text", "r") as f:
    lines = f.readlines()
for l in lines:
    l = " ".join(l.strip().split()[1:])
    bpe_len = len(tokenize_by_bpe_model(bpe_model, l))
    token_len = len(tokenizer.encoding.encode(l))
    bpe_lens.append(bpe_len)
    token_lens.append(token_len)
    print("{} {} {}".format(bpe_len, token_len, l))
print("Mean: BpeTokenizer {}, WhisperTokenizer {}".format(s.mean(bpe_lens),
                                                          s.mean(token_lens)))
print("Var : BpeTokenizer {}, WhisperTokenizer {}".format(s.variance(bpe_lens),
                                                          s.variance(token_lens)))

Output (excerpt):

...
18 25 I THANK ALL WHO HAVE LOVED ME IN THEIR HEARTS WITH THANKS AND LOVE FROM MINE
36 62 THEN I LONG TRIED BY NATURAL ILLS RECEIVED THE COMFORT FAST WHILE BUDDING AT THY SIGHT MY PILGRIM'S STAFF GAVE OUT GREEN LEAVES WITH MORNING DEWS IMPEARLED
24 32 I LOVE THEE FREELY AS MEN STRIVE FOR RIGHT I LOVE THEE PURELY AS THEY TURN FROM PRAISE
22 34 I LOVE THEE WITH THE PASSION PUT TO USE IN MY OLD GRIEFS AND WITH MY CHILDHOOD'S FAITH
43 58 I LOVE THEE WITH A LOVE I SEEMED TO LOSE WITH MY LOST SAINTS I LOVE THEE WITH THE BREATH SMILES TEARS OF ALL MY LIFE AND IF GOD CHOOSE I SHALL BUT LOVE THEE BETTER AFTER DEATH
Mean: BpeTokenizer 25.176335877862595, WhisperTokenizer 40.09809160305343
Var : BpeTokenizer 314.1025325790101, WhisperTokenizer 814.1969417556378
...
Training a 1.5-billion-parameter model places high demands on CPU memory; on systems with 160 GB of memory or less, issues like the following are likely to occur.
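A rough back-of-the-envelope for where that pressure comes from (assuming plain fp32 training with Adam and one full copy of the training state per DDP rank; both are assumptions, not the PR's measured numbers):

PARAMS = 1.5e9        # ~1.5B parameters
BYTES_PER_PARAM = 4   # fp32
STATE_COPIES = 4      # weights + grads + Adam exp_avg + Adam exp_avg_sq
per_rank_gb = PARAMS * BYTES_PER_PARAM * STATE_COPIES / 1024**3
print(f"~{per_rank_gb:.0f} GB per rank")  # ~22 GB
for ranks in (1, 4, 8):
    print(f"{ranks} ranks -> ~{per_rank_gb * ranks:.0f} GB of host memory")
# 8 ranks on one node already approach ~180 GB before activations, dataloader
# workers, or the CPU-side checkpoint load are counted, consistent with
# trouble on <= 160 GB machines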
whisper's basic architecture is roughly equivalent to wenet.transformer. Reusing most of the existing code plus renaming checkpoint weights, rather than copying openai-whisper's model definition over directly, was chosen for the following reasons: ...
TODO (This PR)
TODO (Next PR)
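A minimal sketch of the checkpoint-renaming approach (the prefix pairs below are hypothetical placeholders, not the PR's actual mapping table):

import torch

RENAME_PREFIXES = {
    "encoder.blocks.": "encoder.encoders.",   # hypothetical
    "decoder.blocks.": "decoder.decoders.",   # hypothetical
}

def convert_whisper_state_dict(whisper_sd: dict) -> dict:
    """Remap openai-whisper parameter names onto wenet-style names."""
    wenet_sd = {}
    for name, tensor in whisper_sd.items():
        for old, new in RENAME_PREFIXES.items():
            if name.startswith(old):
                name = new + name[len(old):]
                break
        wenet_sd[name] = tensor
    return wenet_sd

# usage sketch:
#   ckpt = torch.load("large-v3.pt", map_location="cpu")
#   model.load_state_dict(convert_whisper_state_dict(ckpt["model_state_dict"]))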