forked from Jaraxxus-Me/LogiCity
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtransfer.py
143 lines (131 loc) · 5.72 KB
/
transfer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import os
import copy
import time
import yaml
import torch
import argparse
import importlib
import numpy as np
import pickle as pkl
from tqdm import trange
from logicity.core.config import *
from logicity.utils.load import CityLoader
from logicity.utils.logger import setup_logger
from logicity.utils.vis import visualize_city
# RL
from logicity.rl_agent.alg import *
from logicity.utils.gym_wrapper import GymCityWrapper
from stable_baselines3.common.vec_env import SubprocVecEnv
from logicity.utils.gym_callback import EvalCheckpointCallback, DreamerEvalCheckpointCallback
def parse_arguments(argv=None):
    """Parse the command-line options of this script.

    :param argv: Optional list of argument strings to parse. When ``None``
        (the default) ``argparse`` falls back to the process command line,
        which is the behaviour existing callers rely on.
    :return: The parsed namespace with ``log_dir``, ``exp``, ``seed`` and
        ``config`` attributes.
    """
    parser = argparse.ArgumentParser(description='Logic-based city simulation.')
    # logger
    parser.add_argument('--log_dir', type=str, default="./log_rl")
    parser.add_argument('--exp', type=str, default="bc_transfer_train")
    # seed
    parser.add_argument('--seed', type=int, default=2)
    # RL
    parser.add_argument(
        '--config',
        default='config/tasks/Nav/transfer/medium/algo/bc/bc_10.yaml',
        help='Configure file for this RL exp.',
    )
    # Forwarding argv lets tests and other scripts drive the parser without
    # touching sys.argv; None keeps the original command-line behaviour.
    return parser.parse_args(argv)
def load_config(config_path):
    """Read a YAML configuration file and return its parsed content.

    :param config_path: Path of the YAML file to read.
    :return: Whatever ``yaml.safe_load`` produces for the file content.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale default encoding.
    with open(config_path, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)
def dynamic_import(module_name, class_name):
    """Fetch a named attribute from a module that is imported at runtime.

    :param module_name: Dotted path of the module to import.
    :param class_name: Name of the attribute to look up on that module.
    :return: The object bound to ``class_name`` in the imported module.
    """
    return getattr(importlib.import_module(module_name), class_name)
def make_env(simulation_config, episode_cache=None, return_cache=False):
    """Create a gym-wrapped city from a simulation configuration.

    :param simulation_config: Mapping expanded into keyword arguments of
        ``CityLoader.from_yaml``.
    :param episode_cache: Forwarded to ``CityLoader.from_yaml`` unchanged.
    :param return_cache: When true, the second value returned by
        ``CityLoader.from_yaml`` is handed back alongside the environment.
    :return: The ``GymCityWrapper`` instance, or the tuple
        ``(env, cached_observation)`` when ``return_cache`` is set.
    """
    loaded_city, cached_observation = CityLoader.from_yaml(
        **simulation_config, episode_cache=episode_cache
    )
    wrapped_env = GymCityWrapper(loaded_city)
    if not return_cache:
        return wrapped_env
    return wrapped_env, cached_observation
def make_envs(simulation_config, rank):
    """
    Build an environment factory for multiprocessed training.
    :param simulation_config: The configuration passed to ``make_env`` when the factory runs.
    :param rank: Index of the environment; the seed is derived from it.
    :return: A zero-argument callable that creates and seeds one environment.
    """
    def _init():
        # Construction is deferred until the factory is actually invoked.
        worker_env = make_env(simulation_config)
        # Seed is offset by the rank so each factory uses a different value.
        worker_env.seed(rank + 1000)
        return worker_env

    return _init
def main(args, logger):
    """Resume training of a saved RL model on the configured simulation.

    Loads the YAML file at ``args.config``, builds the training
    environment(s), loads the model stored at the configured checkpoint
    path, continues learning for ``total_timesteps`` steps and saves the
    resulting model under the ``name_prefix`` of the ``eval_checkpoint``
    config section.

    :param args: Parsed command-line options; ``seed``, ``config`` and
        ``exp`` are read.
    :param logger: Logger used for the progress messages.
    :raises ValueError: If the ``train`` entry of the RL config is falsy.
    :raises FileNotFoundError: If the configured checkpoint file is missing.
    """
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    config = load_config(args.config)
    # simulation config
    simulation_config = config["simulation"]
    logger.info("Simulation config: {}".format(simulation_config))
    # RL config
    rl_config = config['stable_baselines']
    logger.info("RL config: {}".format(rl_config))
    # Dynamic import of the features extractor class
    if "features_extractor_module" in rl_config["policy_kwargs"]:
        features_extractor_class = dynamic_import(
            rl_config["policy_kwargs"]["features_extractor_module"],
            rl_config["policy_kwargs"]["features_extractor_class"]
        )
        # Prepare policy_kwargs with the dynamically imported class
        policy_kwargs = {
            "features_extractor_class": features_extractor_class,
            "features_extractor_kwargs": rl_config["policy_kwargs"]["features_extractor_kwargs"]
        }
    else:
        policy_kwargs = rl_config["policy_kwargs"]
    # Dynamic import of the RL algorithm
    algorithm_class = dynamic_import(
        "logicity.rl_agent.alg",  # Adjust the module path as needed
        rl_config["algorithm"]
    )
    # Load the entire eval_checkpoint configuration as a dictionary
    eval_checkpoint_config = config.get('eval_checkpoint', {})
    # Hyperparameters
    hyperparameters = rl_config["hyperparameters"]
    # This script only trains. Use an explicit exception instead of ``assert``
    # so the check is not stripped when Python runs with -O.
    if not rl_config["train"]:
        raise ValueError(
            "The 'train' entry of the 'stable_baselines' config must be true for transfer learning."
        )
    num_envs = rl_config["num_envs"]
    total_timesteps = rl_config["total_timesteps"]
    # Validate the checkpoint before building the environments so a bad path
    # fails fast, without first spawning worker processes.
    checkpoint_path = rl_config["checkpoint_path"]
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError("Checkpoint path not found: {}".format(checkpoint_path))
    # model training
    if num_envs > 1:
        logger.info("Running in RL mode with {} parallel environments.".format(num_envs))
        train_env = SubprocVecEnv([make_envs(simulation_config, i) for i in range(num_envs)])
    else:
        train_env = make_env(simulation_config)
    train_env.reset()
    logger.info("Continue Learning")
    logger.info("Loading the model from checkpoint: {}".format(checkpoint_path))
    # Deep copy so the loaded model gets its own policy_kwargs object rather
    # than a reference into the parsed configuration.
    policy_kwargs_use = copy.deepcopy(policy_kwargs)
    model = algorithm_class.load(
        checkpoint_path,
        train_env,
        **hyperparameters,
        policy_kwargs=policy_kwargs_use
    )
    # RL training mode
    # Create the custom checkpoint and evaluation callback
    eval_checkpoint_callback = EvalCheckpointCallback(exp_name=args.exp, **eval_checkpoint_config)
    # Train the model
    model.learn(
        total_timesteps=total_timesteps,
        callback=eval_checkpoint_callback,
        tb_log_name=args.exp
    )
    # Save the model
    model.save(eval_checkpoint_config["name_prefix"])
    return
def cal_step_metric(decision_step, succ_decision):
    """Compute decision success rates overall and per action.

    :param decision_step: Mapping from action to the number of decisions
        recorded for that action.
    :param succ_decision: Mapping from action to the number of successful
        decisions; it is indexed with every key of ``decision_step``.
    :return: A tuple ``(overall, average, per_action)`` where ``overall`` is
        total successes divided by total decisions, ``average`` is the mean
        of the per-action rates, and ``per_action`` maps each action to its
        own success rate.
    :raises KeyError: If an action of ``decision_step`` is missing from
        ``succ_decision``.
    """
    # Every denominator is clamped to at least 1 so that zero counts give a
    # rate of 0 instead of a ZeroDivisionError.
    total_decision = max(sum(decision_step.values()), 1)
    total_succ = sum(succ_decision.values())
    mean_decision_succ = {
        action: succ_decision[action] / max(num, 1)
        for action, num in decision_step.items()
    }
    # Same clamp for the number of actions: an empty ``decision_step``
    # previously divided by zero here.
    average_decision_succ = sum(mean_decision_succ.values()) / max(len(mean_decision_succ), 1)
    # mean decision succ (over all steps), average decision succ (over all actions), decision succ for each action
    return total_succ / total_decision, average_decision_succ, mean_decision_succ
if __name__ == "__main__":
    # Script entry point: parse the CLI, set up logging, then hand over to main().
    args = parse_arguments()
    logger = setup_logger(
        log_dir=args.log_dir,
        log_name=args.exp,
    )
    logger.info("Running in RL mode. ***Transfer Learning***")
    logger.info("Loading RL config from {}.".format(args.config))
    # RL mode, will use gym wrapper to learn and test an agent
    main(args, logger)