-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathtrain.py
55 lines (40 loc) · 1.91 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import logging
from agent.trainer import Trainer
from util.parser import get_parser
from util.config import Config
from util.mytorch import same_seeds
# Configure the root logger: DEBUG and above, each record tagged with its
# level, a timestamp, and the emitting source file. Per the stdlib docs,
# basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(
    format='[%(levelname)s] %(asctime)s | %(filename)s | %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.DEBUG,
)
# Logger scoped to this module's name.
logger = logging.getLogger(__name__)
def get_args(argv=None):
    """Build the training CLI on top of the project parser and parse it.

    The base parser comes from ``util.parser.get_parser`` (presumably an
    ``argparse.ArgumentParser`` -- confirm in util/parser.py); this function
    only adds the training-specific options listed below.

    Args:
        argv: Optional sequence of argument strings to parse instead of the
            process command line. The default ``None`` preserves the original
            behaviour of parsing the real command line, so existing callers
            are unaffected; passing a list makes the parser usable from code
            and tests.

    Returns:
        The namespace returned by ``parser.parse_args``.
    """
    parser = get_parser(description='Train')

    # Experiment definition: path to the YAML config loaded by the entry point.
    parser.add_argument('--config', '-c', default='./config/train_again-c4s.yaml',
                        help='config yaml file')

    # Run-mode switches. They are not read here; the whole namespace is handed
    # to Trainer, which presumably consumes them.
    parser.add_argument('--dry', action='store_true', help='whether to dry run')
    parser.add_argument('--debug', action='store_true', help='debugging mode')

    # Reproducibility: the entry point feeds this to same_seeds().
    parser.add_argument('--seed', type=int, help='random seed', default=961998)

    # Resume support: an empty string means "do not load a checkpoint"
    # (assumed from the default -- confirm against Trainer).
    parser.add_argument('--load', '-l', type=str, help='Load a checkpoint.', default='')
    # NOTE(review): help text is empty. The name suggests a worker/job count,
    # but its use is not visible here -- confirm against Trainer and document.
    parser.add_argument('--njobs', '-p', type=int, help='', default=4)

    # Step schedule, forwarded one-to-one to Trainer.train() by the entry point.
    parser.add_argument('--total-steps', type=int, help='Total training steps.',
                        default=100000)
    parser.add_argument('--verbose-steps', type=int,
                        help='The steps to update tqdm message.', default=10)
    parser.add_argument('--log-steps', type=int,
                        help='The steps to log data for the custom logger (wandb, tensorboard, etc.).',
                        default=500)
    parser.add_argument('--save-steps', type=int, help='The steps to save a checkpoint.',
                        default=5000)
    parser.add_argument('--eval-steps', type=int, help='The steps to evaluate.',
                        default=5000)

    return parser.parse_args(argv)
if __name__ == '__main__':
    # Parse the CLI, then load the experiment config that --config points at.
    args = get_args()
    config = Config(args.config)

    # Seed before the trainer is built so anything it initialises sees the
    # chosen seed (same_seeds is from util.mytorch; presumably seeds every RNG
    # in use -- confirm there).
    same_seeds(args.seed)

    trainer = Trainer(config, args)

    # Forward the step schedule to the training loop. The attribute names on
    # the namespace match Trainer.train's keyword parameters one-to-one.
    step_options = ('total_steps', 'verbose_steps', 'log_steps', 'save_steps', 'eval_steps')
    trainer.train(**{option: getattr(args, option) for option in step_options})