-
Notifications
You must be signed in to change notification settings - Fork 3
/
ensemble.py
56 lines (41 loc) · 1.31 KB
/
ensemble.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import argparse
import pickle
import os
import numpy as np
from tqdm import tqdm
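'''
Per-dataset-split ensemble weights (alpha values).
Each list is applied in the order the streams are fused below: [body, hand, leg].
NTU60 Xsub : [2, 1.8, 0.7]
NTU60 XView : [2.1, 1.6, 0.7]
NTU120 XSub : [2.2, 1.5, 0.7]
NTU120 XSet : [2.1, 2, 0.8]
NTU60x : [3.4, 2, 1]
NTU120x : [2.4, 2.2, 1]
'''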
def load_scores(path):
    """Load one stream's per-sample score vectors from a pickle file.

    The pickle must contain a mapping; its values are returned as a list in
    the mapping's iteration order, which is the order the ensemble relies on
    to line samples up across streams and with the labels.

    Args:
        path: Path of the pickled mapping.

    Returns:
        List of the mapping's values (one score vector per sample).
    """
    # SECURITY: pickle.load can execute arbitrary code; only open score files
    # produced by your own evaluation runs.
    with open(path, 'rb') as handle:
        return list(pickle.load(handle).values())


def load_labels(path):
    """Load the ground-truth labels used to score the ensemble.

    NOTE(review): the original script read ``label`` without ever loading it,
    so the on-disk format is an assumption -- this expects a pickled sequence
    of integer class indices in the same order as the score files. Confirm
    against the dataset's label file and adapt if it differs.

    Args:
        path: Path of the pickled label sequence.

    Returns:
        The labels as a NumPy array.
    """
    # SECURITY: same pickle caveat as load_scores -- trusted files only.
    with open(path, 'rb') as handle:
        return np.asarray(pickle.load(handle))


def ensemble_accuracy(streams, labels, alpha, progress=None):
    """Fuse several score streams with fixed weights and measure accuracy.

    For every sample the per-stream score vectors are combined as
    ``sum(alpha[k] * streams[k][i])``; the fused vector is then checked for a
    top-1 hit (argmax equals the label) and a top-5 hit (label is among the
    five highest-scoring classes).

    Args:
        streams: Sequence of streams; each stream is a sequence holding one
            score vector per sample.
        labels: Sequence of integer class indices, one per sample.
        alpha: One weight per stream.
        progress: Optional wrapper applied to the sample index range (for
            example ``tqdm``) to display progress.

    Returns:
        Tuple ``(top1, top5)`` of accuracies as fractions in ``[0, 1]``.

    Raises:
        ValueError: If there are no labels, if the number of weights differs
            from the number of streams, or if a stream's length differs from
            the number of labels.
    """
    if len(streams) != len(alpha):
        raise ValueError(
            'expected one alpha per stream, got {} streams and {} alphas'
            .format(len(streams), len(alpha)))

    total = len(labels)
    if total == 0:
        # The original divided by zero here; fail with a clear message.
        raise ValueError('no labels to evaluate')

    for position, stream in enumerate(streams):
        if len(stream) != total:
            # A length mismatch means scores and labels are not aligned, so
            # any accuracy computed from them would be meaningless.
            raise ValueError(
                'stream {} has {} samples but there are {} labels'
                .format(position, len(stream), total))

    indices = range(total)
    if progress is not None:
        indices = progress(indices)

    top1_hits = 0
    top5_hits = 0
    for i in indices:
        target = int(labels[i])
        fused = sum(
            weight * np.asarray(stream[i])
            for weight, stream in zip(alpha, streams)
        )
        # argsort is ascending, so the last five entries are the best five.
        top5_hits += int(target in fused.argsort()[-5:])
        top1_hits += int(np.argmax(fused) == target)

    return top1_hits / total, top5_hits / total


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Weighted ensemble of body, hand and leg stream scores.')
    # Defaults are placeholders: point them at the real files, either here or
    # on the command line.
    parser.add_argument('--body-scores', default="path_body_scores.pkl")
    parser.add_argument('--hand-scores', default="path_hand_scores.pkl")
    parser.add_argument('--leg-scores', default="path_leg_scores.pkl")
    parser.add_argument('--labels', default="path_labels.pkl")
    parser.add_argument(
        '--alpha', type=float, nargs=3, default=[1.8, 1.5, 0.5],
        metavar=('BODY', 'HAND', 'LEG'),
        help='ensemble weights for the body, hand and leg streams')
    arg = parser.parse_args()

    r1 = load_scores(arg.body_scores)
    r2 = load_scores(arg.hand_scores)
    # The original loaded the leg scores into r2, clobbering the hand scores
    # and leaving r3 undefined.
    r3 = load_scores(arg.leg_scores)
    label = load_labels(arg.labels)

    acc, acc5 = ensemble_accuracy(
        [r1, r2, r3], label, arg.alpha, progress=tqdm)
    print('Top1 Acc: {:.4f}%'.format(acc * 100))
    print('Top5 Acc: {:.4f}%'.format(acc5 * 100))