-
Notifications
You must be signed in to change notification settings - Fork 457
/
init_large_model.py
68 lines (54 loc) · 2.61 KB
/
init_large_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
from transformers import AutoConfig
import os
import argparse
def parse_args(argv=None):
    """Parse the command-line arguments of this initialisation script.

    Args:
        argv: Optional list of argument strings. When ``None``, argparse
            reads the arguments from ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` holding ``save_directory``, ``text_model``
        and ``audio_model``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("save_directory", type=str, help="Directory where to save the model and the decoder.")
    # Both identifiers are handed to `from_pretrained` unconditionally in `main`,
    # so reject a missing value here with a usage error instead of failing later
    # inside the model-loading code with `None`.
    parser.add_argument(
        "--text_model", type=str, required=True, help="Repository id or path to the text encoder."
    )
    parser.add_argument(
        "--audio_model", type=str, required=True, help="Repository id or path to the audio encoder."
    )
    return parser.parse_args(argv)


def main(argv=None):
    """Create an untrained Parler-TTS model and save it under ``save_directory``.

    The script first builds and saves a randomly initialised decoder in
    ``<save_directory>/decoder``, then combines it with the given text encoder
    and audio encoder, sets the generation defaults, and saves the full model
    in ``<save_directory>/parler-tts-untrained-larger/``.

    Args:
        argv: Optional list of argument strings forwarded to ``parse_args``.
    """
    args = parse_args(argv)

    text_model = args.text_model
    audio_model = args.audio_model
    # The decoder is written to disk first and then re-loaded from this path
    # when the sub-models are assembled.
    decoder_dir = os.path.join(args.save_directory, "decoder")

    text_config = AutoConfig.from_pretrained(text_model)
    audio_config = AutoConfig.from_pretrained(audio_model)

    encodec_vocab_size = audio_config.codebook_size
    num_codebooks = audio_config.num_codebooks
    print("num_codebooks", num_codebooks)

    # Special tokens are placed right after the audio codebook entries;
    # pad and eos share the same id.
    pad_token_id = encodec_vocab_size
    eos_token_id = encodec_vocab_size
    bos_token_id = encodec_vocab_size + 1

    decoder_config = ParlerTTSDecoderConfig(
        vocab_size=encodec_vocab_size + 64,  # + 64 instead of +1 to have a multiple of 64
        max_position_embeddings=4096,  # 30 s = 2580
        num_hidden_layers=30,
        ffn_dim=6144,
        num_attention_heads=24,
        num_key_value_heads=24,
        layerdrop=0.0,
        use_cache=True,
        activation_function="gelu",
        hidden_size=1536,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        bos_token_id=bos_token_id,
        num_codebooks=num_codebooks,
    )

    decoder = ParlerTTSForCausalLM(decoder_config)
    decoder.save_pretrained(decoder_dir)

    model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
        text_encoder_pretrained_model_name_or_path=text_model,
        audio_encoder_pretrained_model_name_or_path=audio_model,
        decoder_pretrained_model_name_or_path=decoder_dir,
        vocab_size=text_config.vocab_size,
    )

    # set the appropriate bos/pad token ids
    model.generation_config.decoder_start_token_id = bos_token_id
    model.generation_config.pad_token_id = pad_token_id
    model.generation_config.eos_token_id = eos_token_id

    # set other default generation config params
    # 30 * frame_rate frames (presumably 30 seconds of audio, matching the
    # "30 s" note on max_position_embeddings above)
    model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
    model.generation_config.do_sample = True

    model.config.pad_token_id = pad_token_id
    model.config.decoder_start_token_id = bos_token_id

    model.save_pretrained(os.path.join(args.save_directory, "parler-tts-untrained-larger/"))


if __name__ == "__main__":
    main()