-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_transfer.py
102 lines (91 loc) · 2.54 KB
/
run_transfer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import argparse
import os
import sys

import matplotlib.pyplot as plt
import numpy as np

from bisim_transfer.bisimulation import *
# Directory that receives the transfer image and distance matrices.
TRANSFER_OUTPUT_DIR = 'transfer_logs'


def build_arg_parser():
    """Return the command-line parser for a bisimulation transfer run.

    The parsed namespace is passed unchanged to the bisimulation constructor
    (see ``make_bisimulation``), so the attribute names produced here
    (``src_env``, ``discount_kd``, ``debug``, ...) should be kept stable.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '--transfer',
        default='lax',
        type=str,
        help="transfer metric: 'basic', 'lax', 'pess' or 'optimistic'",
    )
    parser.add_argument('--src-env', default='FourSmallRooms_11', type=str)
    # Required: every output file name below embeds the target environment,
    # so a missing value could only fail after the full transfer had run.
    parser.add_argument('--tgt-env', required=True, type=str)
    parser.add_argument('--solver', default='pyemd', type=str)
    parser.add_argument('--lfp-iters', default=5, type=int)
    parser.add_argument('-th', '--threshold', default=0.01, type=float)
    parser.add_argument('-dfk', '--discount-kd', default=0.9, type=float)
    parser.add_argument('-dfr', '--discount-r', default=0.1, type=float)
    parser.add_argument(
        '--policy-dir',
        default='saved_qvalues/optimal_qvalues/',
        type=str,
    )
    parser.add_argument('-l', '--log-dir', default='logs/', type=str)
    parser.add_argument(
        '--save-dir',
        default='saved_qvalues/transferred_qvalues/',
        type=str,
    )
    parser.add_argument('--render', action='store_true')
    parser.add_argument(
        '-v', '--verbose',
        action='store_true',
        dest='debug',
        help='print debug information',
    )
    return parser


def make_bisimulation(args):
    """Instantiate the bisimulation class selected by ``args.transfer``.

    Raises:
        ValueError: if ``args.transfer`` is not one of the known metric names.
    """
    # An if/elif chain (rather than a dict) keeps each class name resolved
    # only when its branch is taken, exactly as the original script did.
    transfer = args.transfer
    # NOTE(review): 'basic' maps to LaxBisimulation, as in the original;
    # confirm this aliasing is intentional.
    if transfer in ('basic', 'lax'):
        return LaxBisimulation(args)
    if transfer == 'pess':
        return PessBisimulation(args)
    if transfer == 'optimistic':
        return OptBisimulation(args)
    raise ValueError("Provide a valid transfer metric")


def transfer_output_paths(args, out_dir=TRANSFER_OUTPUT_DIR):
    """Return ``(image_path, d_sa_path, dist_matrix_path)`` for a run.

    The image name carries both discount factors; the two ``.npy`` names
    carry only the source/target environment pair.
    """
    pair = args.src_env + '_' + args.tgt_env
    image_name = (pair + '-' + str(args.discount_r) + '-'
                  + str(args.discount_kd) + '.png')
    return (
        os.path.join(out_dir, image_name),
        os.path.join(out_dir, 'Dist-sa_' + pair + '.npy'),
        os.path.join(out_dir, 'Dist-matrix_' + pair + '.npy'),
    )


def main(argv=None):
    """Parse CLI arguments, run the transfer, and save its outputs.

    ``argv`` defaults to ``sys.argv[1:]`` (argparse behaviour).
    """
    args = build_arg_parser().parse_args(argv)
    bisimulation = make_bisimulation(args)
    bisimulation.execute_transfer()
    if args.render:
        bisimulation.render()

    image_path, d_sa_path, dist_matrix_path = transfer_output_paths(args)
    # The output directory is not guaranteed to exist on a fresh checkout.
    os.makedirs(TRANSFER_OUTPUT_DIR, exist_ok=True)
    bisimulation.tgt_env.save_transfer_image(
        bisimulation.transferred_agent.qvalues, image_path)
    np.save(d_sa_path, bisimulation.d_sa_final)
    np.save(dist_matrix_path, bisimulation.dist_matrix_final)


if __name__ == '__main__':
    main()