# imagenet_run_exp.py
# ProxylessNAS: Direct Neural Architecture Search on Target Task and Hardware
# Han Cai, Ligeng Zhu, Song Han
# International Conference on Learning Representations (ICLR), 2019.
import argparse
import numpy as np
import os
import json
import torch
from models import *
from run_manager import RunManager
def _build_parser():
    """Create the command-line parser for the ImageNet experiment script.

    Options are registered in a fixed order (experiment control, schedule,
    data/optimisation settings, then network settings) so the generated
    ``--help`` output keeps the same layout.
    """
    cli = argparse.ArgumentParser()
    add = cli.add_argument

    def switch(option):
        # Boolean option: absent -> False, present -> True.
        add(option, action='store_true')

    # experiment control
    add('--path', default=None, type=str)
    add('--gpu', default='0,1,2,3', help='gpu available')
    switch('--train')
    add('--manual_seed', type=int, default=0)
    switch('--resume')
    add('--latency', default=None, type=str)

    # schedule
    add('--n_epochs', default=300, type=int)
    add('--init_lr', default=0.05, type=float)
    add('--lr_schedule_type', default='cosine', type=str)
    # lr_schedule_param

    # data and optimisation
    add('--dataset', default='imagenet', type=str, choices=['imagenet'])
    add('--train_batch_size', default=256, type=int)
    add('--test_batch_size', default=500, type=int)
    add('--valid_size', default=None, type=int)

    add('--opt_type', default='sgd', type=str, choices=['sgd'])
    add('--momentum', default=0.9, type=float)  # opt_param
    switch('--no_nesterov')  # opt_param
    add('--weight_decay', default=4e-5, type=float)
    add('--label_smoothing', default=0.1, type=float)
    add('--no_decay_keys', default='bn', type=str, choices=['None', 'bn', 'bn#bias'])

    add('--model_init', default='he_fout', type=str, choices=['he_fin', 'he_fout'])
    switch('--init_div_groups')
    add('--validation_frequency', default=1, type=int)
    add('--print_frequency', default=10, type=int)

    add('--n_worker', default=32, type=int)
    add('--resize_scale', default=0.08, type=float)
    add('--distort_color', default='strong', type=str, choices=['normal', 'strong', 'None'])

    # net config
    add('--bn_momentum', default=0.1, type=float)
    add('--bn_eps', default=1e-3, type=float)
    add(
        '--net',
        default='proxyless_mobile',
        type=str,
        choices=['proxyless_gpu', 'proxyless_cpu', 'proxyless_mobile', 'proxyless_mobile_14'],
    )
    add('--dropout', default=0, type=float)
    return cli


parser = _build_parser()
if __name__ == '__main__':
    # Script entry point.
    #   1. Parse the CLI, pin the visible GPUs and seed the RNGs.
    #   2. Build the run config and the network, preferring `run.config` /
    #      `net.config` files already present under --path.
    #   3. Initialise weights (resume / `<path>/init` / pretrained download /
    #      random), optionally train.
    #   4. Evaluate and dump the metrics as JSON to `<path>/output`.
    args = parser.parse_args()
    if args.path is None:
        # os.makedirs(None) would otherwise fail with an opaque TypeError.
        parser.error('--path is required')

    # Restrict the visible GPUs before any CUDA-related call is issued.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # Seed the torch (CPU and CUDA) and numpy RNGs from --manual_seed.
    torch.manual_seed(args.manual_seed)
    torch.cuda.manual_seed_all(args.manual_seed)
    np.random.seed(args.manual_seed)

    os.makedirs(args.path, exist_ok=True)

    # prepare run config
    run_config_path = os.path.join(args.path, 'run.config')
    if os.path.isfile(run_config_path):
        # load run config from file; of the CLI options only --valid_size
        # is applied on top of the stored values
        with open(run_config_path, 'r') as config_file:
            run_config = ImagenetRunConfig(**json.load(config_file))
        if args.valid_size:
            run_config.valid_size = args.valid_size
    else:
        # build run config from args
        args.lr_schedule_param = None
        args.opt_param = {
            'momentum': args.momentum,
            'nesterov': not args.no_nesterov,
        }
        # argparse can only deliver the string 'None'; map it to a real None
        if args.no_decay_keys == 'None':
            args.no_decay_keys = None
        run_config = ImagenetRunConfig(**vars(args))

    print('Run config:')
    for k, v in run_config.config.items():
        print('\t%s: %s' % (k, v))

    # prepare network
    net_config_path = os.path.join(args.path, 'net.config')
    if os.path.isfile(net_config_path):
        # load net from file
        from models import get_net_by_name
        with open(net_config_path, 'r') as config_file:
            net_config = json.load(config_file)
        net = get_net_by_name(net_config['name']).build_from_config(net_config)
    elif 'proxyless' in args.net:
        # build net from args, using the published architecture config
        from models.normal_nets.proxyless_nets import proxyless_base
        net_config_url = 'https://raw.githubusercontent.com/han-cai/files/master/proxylessnas/%s.config' % args.net
        net = proxyless_base(
            net_config=net_config_url, n_classes=run_config.data_provider.n_classes,
            bn_param=(args.bn_momentum, args.bn_eps), dropout_rate=args.dropout,
        )
    else:
        raise ValueError('do not support: %s' % args.net)

    # build run manager
    run_manager = RunManager(args.path, net, run_config, measure_latency=args.latency)
    run_manager.save_config(print_info=True)

    # load checkpoints
    # NOTE: torch.load unpickles the file; only load checkpoints from sources
    # you trust.
    init_path = os.path.join(args.path, 'init')
    if args.resume:
        run_manager.load_model()
        if args.train and run_manager.best_acc == 0:
            # No best accuracy recorded yet: take the current validation
            # top-1 as the baseline before training continues.
            loss, acc1, acc5 = run_manager.validate(is_test=False, return_top5=True)
            run_manager.best_acc = acc1
    elif os.path.isfile(init_path):
        # Without CUDA, remap the stored tensors onto the CPU.
        map_location = None if torch.cuda.is_available() else 'cpu'
        checkpoint = torch.load(init_path, map_location=map_location)
        # Accept both a bare state dict and a {'state_dict': ...} wrapper.
        if 'state_dict' in checkpoint:
            checkpoint = checkpoint['state_dict']
        # NOTE(review): `.module` suggests run_manager.net is a parallel
        # wrapper around the model -- confirm in RunManager.
        run_manager.net.module.load_state_dict(checkpoint)
    elif 'proxyless' in args.net and not args.train:
        # Evaluation-only run without local weights: fetch the released ones.
        from utils.latency_estimator import download_url
        pretrained_weight_url = 'https://raw.githubusercontent.com/han-cai/files/master/proxylessnas/%s.pth' % args.net
        print('Load pretrained weights from %s' % pretrained_weight_url)
        init_path = download_url(pretrained_weight_url)
        init = torch.load(init_path, map_location='cpu')
        net.load_state_dict(init['state_dict'])
    else:
        print('Random initialization')

    # train
    if args.train:
        print('Start training')
        run_manager.train(print_top5=True)
        run_manager.save_model()

    output_dict = {}

    # validate
    if run_config.valid_size:
        print('Test on validation set')
        loss, acc1, acc5 = run_manager.validate(is_test=False, return_top5=True)
        log = 'valid_loss: %f\t valid_acc1: %f\t valid_acc5: %f' % (loss, acc1, acc5)
        run_manager.write_log(log, prefix='valid')
        # Formatted with '%f' exactly like the test metrics below (the former
        # ' % f' spec padded the values with stray leading spaces).
        output_dict.update({
            'valid_loss': '%f' % loss, 'valid_acc1': '%f' % acc1, 'valid_acc5': '%f' % acc5,
            'valid_size': run_config.valid_size,
        })

    # test
    print('Test on test set')
    loss, acc1, acc5 = run_manager.validate(is_test=True, return_top5=True)
    log = 'test_loss: %f\t test_acc1: %f\t test_acc5: %f' % (loss, acc1, acc5)
    run_manager.write_log(log, prefix='test')
    output_dict.update({
        'test_loss': '%f' % loss, 'test_acc1': '%f' % acc1, 'test_acc5': '%f' % acc5,
    })

    # Use a context manager so the metrics file is flushed and closed.
    with open(os.path.join(args.path, 'output'), 'w') as output_file:
        json.dump(output_dict, output_file, indent=4)