-
Notifications
You must be signed in to change notification settings - Fork 46
/
Copy pathtrain-parallel.py
78 lines (67 loc) · 2.79 KB
/
train-parallel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File : train-parallel.py
@Date : 2021/01/05, Tue
@Author : Atomicoo
@Version : 1.0
@Contact : [email protected]
@License : (C)Copyright 2020-2021, ShiGroup-NLP-XMU
@Desc : Training parallel model.
'''
# Author metadata exposed as a module attribute (mirrors @Author in the module docstring).
__author__ = 'Atomicoo'
import argparse
import os
import os.path as osp
import torch
from utils.hparams import HParam
from helpers.logger import Logger
from helpers.trainer import ParallelTrainer
from utils.utils import select_device
# Optional GPU auto-selection helper. `gm` stays None when helpers.manager
# cannot be imported; the device-selection code in __main__ then falls back
# to GPU index 0 (or the CPU when CUDA is unavailable).
gm = None
try:
    from helpers.manager import GPUManager
except ImportError as import_error:
    # Report why auto-selection is unavailable, but keep going.
    print(import_error)
else:
    gm = GPUManager()
def build_parser():
    """Build the command-line parser for parallel-model training.

    The module docstring is used as the ``--help`` description and defaults
    are shown in the help text (``ArgumentDefaultsHelpFormatter``).

    Returns:
        argparse.ArgumentParser: parser with batch size 64, 300 epochs,
        Adam learning rate 0.002 and a checkpoint every 10 epochs by default.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--batch_size", default=64, type=int, help="Batch size")
    parser.add_argument("--epochs", default=300, type=int, help="Training epochs")
    # type=float: the learning rate is fractional. With type=int any explicit
    # value such as `--adam_lr 0.001` was rejected by argparse.
    parser.add_argument("--adam_lr", default=0.002, type=float, help="Initial learning rate for adam")
    parser.add_argument("--ground_truth", action='store_true', help='Ground-truth melspectrogram')
    parser.add_argument("--checkpoint", default=None, type=str, help="Checkpoint file path")
    # Previously hard-coded to 10 in trainer.fit(); the default keeps that behavior.
    parser.add_argument("--chkpt_every", default=10, type=int, help="Save a checkpoint every N epochs")
    parser.add_argument("--device", default=None, type=str, help='cuda device or cpu')
    parser.add_argument("--name", default="parallel", type=str, help="Append to logdir name")
    parser.add_argument("--enable_wandb", action='store_true', help="Enable wandb or not")
    parser.add_argument("--project", default="parallel-speech", type=str, help="Project for wandb")
    parser.add_argument("--entity", default="atomicoo", type=str, help="Entity for wandb")
    parser.add_argument('--config', default=None, type=str, help='Config file path')
    return parser


if __name__ == '__main__':
    args = build_parser().parse_args()

    # Device choice: with CUDA, an explicit --device wins; otherwise GPUManager
    # (when importable) picks a GPU, else GPU 0. Without CUDA the CPU is used
    # and --device is ignored.
    if torch.cuda.is_available():
        index = args.device if args.device else str(0 if gm is None else gm.auto_choice())
    else:
        index = 'cpu'
    device = select_device(index)

    # Without --config, fall back to ./config/default.yaml under the current
    # working directory.
    config_path = args.config if args.config else osp.join(osp.abspath(os.getcwd()), 'config', 'default.yaml')
    hparams = HParam(config_path)

    loggers = Logger(
        hparams.trainer.logdir,
        hparams.data.dataset, args.name,
        wandb_info={"project": args.project, "entity": args.entity} if args.enable_wandb else None
    )

    # --ground_truth forces ground-truth mels on; otherwise the config decides.
    ground_truth = args.ground_truth or hparams.parallel.ground_truth

    trainer = ParallelTrainer(
        hparams=hparams,
        adam_lr=args.adam_lr,
        ground_truth=ground_truth,
        device=device
    )
    trainer.fit(
        batch_size=args.batch_size,
        epochs=args.epochs,
        chkpt_every=args.chkpt_every,
        checkpoint=args.checkpoint,
        loggers=loggers
    )