forked from jiwoon-ahn/psa
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmanage_weights.py
68 lines (56 loc) · 1.96 KB
/
manage_weights.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Continuously poll for new weight checkpoints, run inference on each one,
# compute its mIoU, and save the result.
import os
import time
import pickle
import subprocess
from compute_miou import compute_miou
# Training-run id: only checkpoints whose 3rd '_'-separated filename field
# equals this are evaluated; results are pickled to '<run_num>_miou.p'.
run_num = '0514'
# Directory scanned for checkpoint (.pth) files.
weights_dir = 'caff_psa'
# Seconds to wait before rescanning when no unevaluated checkpoint exists;
# without a pause the loop would spin a CPU core re-reading the directory.
poll_seconds = 30
# GPUs exposed to the inference subprocess via CUDA_VISIBLE_DEVICES.
cuda_devices = '1,2'
# Passed to infer_cls.py as --voc12_root.
voc12_root = '/media/ssd1/austin/datasets/VOC/VOCdevkit/VOC2012'
# Ground-truth directory passed to compute_miou.
gt_path = '/media/ssd1/austin/datasets/VOC/VOCdevkit/VOC2012/AugSegClass'
# Passed to infer_cls.py as --out_cam_pred and to compute_miou as the
# prediction directory.
pred_path = 'out_cam_pred'
# Pickle holding the {epoch string: mIoU} results for this run.
miou_path = '{}_miou.p'.format(run_num)


def find_weights(weights_dir, run_num):
    """Map epoch string -> checkpoint filename for `run_num` in `weights_dir`.

    Filenames are split on '_': field 2 must equal `run_num`, and field 4 up
    to its first '.' is the epoch, e.g. 'a_b_0514_c_7.pth' -> {'7': ...}.
    Files that do not end in '.pth', have fewer than five fields, or whose
    epoch is not numeric are skipped instead of crashing the poll loop.
    """
    # NOTE(review): a checkpoint still being written may already be listed;
    # confirm the trainer saves atomically (e.g. write then rename).
    ep_2_fname = {}
    for fname in os.listdir(weights_dir):
        if not fname.endswith('.pth'):
            continue
        parts = fname.split('_')
        if len(parts) < 5 or parts[2] != run_num:
            continue
        ep = parts[4].split('.')[0]
        try:
            float(ep)  # epochs are ordered with key=float later
        except ValueError:
            continue
        ep_2_fname[ep] = fname
    return ep_2_fname


def load_miou(path):
    """Return the epoch -> mIoU dict pickled at `path`, or {} if it is absent."""
    try:
        with open(path, 'rb') as f:
            return pickle.load(f)
    except FileNotFoundError:
        return {}


def save_miou(ep_2_miou, path):
    """Pickle `ep_2_miou` to `path` via a temp file and os.replace, so an
    interrupted write cannot leave a truncated pickle behind."""
    tmp_path = path + '.tmp'
    with open(tmp_path, 'wb') as f:
        pickle.dump(ep_2_miou, f)
    os.replace(tmp_path, path)


def claim_next_epoch(ep_2_fname, ep_2_miou, path):
    """Return the highest-numbered epoch in `ep_2_fname` missing from
    `ep_2_miou`, or None if every epoch is already recorded.

    The chosen epoch is recorded in `ep_2_miou` with a placeholder mIoU of -1
    and saved to `path` immediately, so later scans skip it (presumably also
    other copies of this script sharing the pickle -- NOTE(review): confirm).
    A claimed epoch is never retried, even if its evaluation later fails.
    """
    for ep in sorted(ep_2_fname, key=float, reverse=True):
        if ep not in ep_2_miou:
            ep_2_miou[ep] = -1
            save_miou(ep_2_miou, path)
            return ep
    return None


def run_inference(weights_path):
    """Run infer_cls.py on voc12/val.txt with the checkpoint `weights_path`.

    Raises subprocess.CalledProcessError if the child exits non-zero, after
    printing the output it produced.
    """
    # An argument list (no shell) means weight filenames need no quoting; the
    # GPU selection goes through the environment instead of a shell prefix.
    cmd = ['python3', 'infer_cls.py',
           '--infer_list', 'voc12/val.txt',
           '--voc12_root', voc12_root,
           '--network', 'network.resnet38_cls',
           '--weights', weights_path,
           '--out_cam_pred', pred_path]
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=cuda_devices)
    print('CUDA_VISIBLE_DEVICES={} {}'.format(cuda_devices, ' '.join(cmd)))
    try:
        subprocess.check_output(cmd, stderr=subprocess.STDOUT, env=env)
    except subprocess.CalledProcessError as err:
        # check_output captured stdout+stderr; show it before propagating.
        print(err.output.decode(errors='replace'))
        raise


def main():
    """Poll `weights_dir` forever, evaluating each new checkpoint's mIoU."""
    while True:
        ep_2_fname = find_weights(weights_dir, run_num)
        ep_2_miou = load_miou(miou_path)
        myep = claim_next_epoch(ep_2_fname, ep_2_miou, miou_path)
        if myep is None:
            time.sleep(poll_seconds)  # nothing new yet; don't busy-spin
            continue
        start = time.time()
        print(start)
        run_inference(os.path.join(weights_dir, ep_2_fname[myep]))
        end = time.time()
        print(end)
        print('elapsed: {}'.format(end - start))
        miou = compute_miou(gt_path, pred_path)
        print(miou, myep)
        # NOTE(review): ep_2_miou was loaded before inference; if several
        # copies of this script share miou_path, entries they saved in the
        # meantime are overwritten here.
        ep_2_miou[myep] = miou
        print(ep_2_miou)
        save_miou(ep_2_miou, miou_path)


if __name__ == '__main__':
    main()