# render_agent.py
from pathlib import Path
from stable_baselines3 import PPO
from stable_baselines3.common.atari_wrappers import WarpFrame
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize, DummyVecEnv
from scobi import Environment
from utils.parser.parser import render_parser, get_highest_version
from utils.renderer import Renderer
from viper_extract import DTClassifierModel
from joblib import load
def flist(l):
    """Format every element of ``l`` with two decimal places.

    Args:
        l: Iterable of values accepted by the ``%.2f`` format.

    Returns:
        A list of strings, one per element, in the original order.
    """
    formatted = []
    for value in l:
        formatted.append("%.2f" % value)
    return formatted
# Helper function to load from a dt and not a checkpoint directly
def _load_viper(exp_name, path_provided):
if path_provided:
viper_path = Path(exp_name)
model = load(sorted(viper_path.glob("*_best.viper"))[0])
else:
viper_path = Path("resources/viper_extracts/extract_output", exp_name + "-extraction")
model = load(sorted(viper_path.glob("*_best.viper"))[0])
wrapped = DTClassifierModel(model)
return wrapped
# Helper function ensuring that a checkpoint has completed training
def _ensure_completeness(path):
checkpoint = path / "best_model.zip"
return checkpoint.is_file()
def main():
    """Render (and optionally record) a trained agent chosen via CLI flags.

    Reads the flags returned by ``render_parser()``, resolves the checkpoint
    folder under ``resources/checkpoints``, builds the evaluation
    environment, loads either the PPO checkpoint or a VIPER decision tree,
    and hands everything to ``Renderer``. If the checkpoint folder has no
    ``best_model.zip`` a message is printed and the function returns without
    rendering.
    """
    flag_dictionary = render_parser()
    version = int(flag_dictionary["version"])
    exp_name = flag_dictionary["exp_name"]
    variant = flag_dictionary["variant"]
    env_str = flag_dictionary["env_str"]
    pruned_ff_name = flag_dictionary["pruned_ff_name"]
    hide_properties = flag_dictionary["hide_properties"]
    viper = flag_dictionary["viper"]
    record = flag_dictionary["record"]
    nb_frames = flag_dictionary["nb_frames"]
    print_reward = flag_dictionary["print_reward"]

    # -1 selects the highest existing version, 0 means "no version suffix";
    # any other value is appended to the experiment name as-is.
    if version == -1:
        version = get_highest_version(exp_name)
    elif version == 0:
        version = ""
    exp_name += str(version)

    checkpoint_str = "best_model"  # alternative: "model_5000000_steps"
    vecnorm_str = "best_vecnormalize.pkl"
    ff_file_path = Path("resources/checkpoints", exp_name)
    model_path = ff_file_path / checkpoint_str
    vecnorm_path = ff_file_path / vecnorm_str
    EVAL_ENV_SEED = 84

    if not _ensure_completeness(ff_file_path):
        print(f"The folder {ff_file_path} does not contain a completed training checkpoint.")
        return

    if variant == "rgb":
        env = make_vec_env(env_str, seed=EVAL_ENV_SEED, wrapper_class=WarpFrame)
    else:
        # NOTE(review): the source's indentation was lost; seeding and the
        # VecNormalize wrapping are placed in this branch because
        # make_vec_env already seeds and returns a vectorized env, whose
        # reset() yields a single observation rather than a pair — confirm
        # against the upstream file.
        scobi_env = Environment(
            env_str,
            focus_dir=ff_file_path,
            focus_file=pruned_ff_name,
            hide_properties=hide_properties,
            draw_features=True,  # implement feature attribution
            reward=0,  # env reward only for evaluation
        )
        scobi_env.reset(seed=EVAL_ENV_SEED)
        # The factory closes over a name that is never rebound, so it cannot
        # accidentally return the VecNormalize wrapper assigned to `env`.
        dummy_vecenv = DummyVecEnv([lambda: scobi_env])
        env = VecNormalize.load(vecnorm_path, dummy_vecenv)
        # Evaluation mode: keep the stored normalization statistics fixed
        # and leave rewards un-normalized.
        env.training = False
        env.norm_reward = False

    if viper:
        print("loading viper tree of " + exp_name)
        if isinstance(viper, str):
            # A string flag value is treated as an explicit extract path.
            model = _load_viper(viper, True)
        else:
            model = _load_viper(exp_name, False)
    else:
        model = PPO.load(model_path)

    # Reset before handing over to the renderer; the returned observation
    # is not needed here.
    env.reset()
    renderer = Renderer(env, model, ff_file_path, record, nb_frames)
    renderer.print_reward = print_reward
    renderer.run()
# Run the renderer only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()