-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
206 lines (175 loc) · 6.54 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import carla
import time
from Utils.utils import *
from Utils.HUD import HUD as HUD
from World import World
import argparse
import logging
import os
from stable_baselines3.common.callbacks import CallbackList
from callbacks import *
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import CheckpointCallback
from sb3_contrib import RecurrentPPO
# Per-run output directory, named after the Unix timestamp (whole seconds) at
# import time. game_loop() hands it to the Monitor wrapper, the TensorBoard
# logger and the save/checkpoint callbacks.
logdir = f"logs/{int(time.time())}/"
# exist_ok=True replaces the separate os.path.exists() check, which could race
# with another process creating the same directory between check and create.
os.makedirs(logdir, exist_ok=True)
def game_loop(args):
    """Connect to CARLA, build the monitored environment and train RecurrentPPO.

    Loads ``args.map`` on the server at ``args.host``:``args.port``, switches the
    simulator to synchronous mode with a fixed step of ``1 / args.FPS`` seconds,
    wraps the project ``World`` in a ``Monitor`` writing to ``logdir`` and trains
    an ``MlpLstmPolicy`` agent for 50000 timesteps with the TensorBoard,
    save-on-best-reward and checkpoint callbacks attached.

    Args:
        args: Parsed command-line namespace. This function itself reads
            ``host``, ``port``, ``map`` and ``FPS``; the whole namespace is also
            passed on to ``World``.

    On exit, whether training finished or an exception was raised, the
    simulator settings captured before the switch to synchronous mode are
    re-applied and ``destroy()`` is called on the environment.
    """
    env = None
    carla_world = None
    original_settings = None
    try:
        client = carla.Client(args.host, args.port)
        client.set_timeout(100.0)
        hud = HUD()

        # The world handle is fetched explicitly after the map has been loaded.
        client.load_world(args.map)
        carla_world = client.get_world()

        # Remember the settings in force before training so synchronous mode
        # can be undone in the finally block below.
        original_settings = carla_world.get_settings()
        carla_world.apply_settings(carla.WorldSettings(
            no_rendering_mode=False,
            synchronous_mode=True,
            fixed_delta_seconds=1 / args.FPS))

        # Keep a handle on the raw environment: cleanup calls destroy() on it
        # directly instead of relying on the Monitor wrapper to forward the
        # attribute lookup.
        env = World(client, carla_world, hud, args)
        world = Monitor(env, logdir)
        world.reset()

        model = RecurrentPPO(
            'MlpLstmPolicy',
            world,
            verbose=2,
            learning_rate=0.0003,
            n_steps=1280,
            n_epochs=20,
            batch_size=128,
            ent_coef=0.01,
            tensorboard_log=logdir,
        )

        # Create callbacks
        save_callback = SaveOnBestTrainingRewardCallback(check_freq=500, log_dir=logdir, verbose=1)
        tensor = TensorboardCallback()
        checkpoint = CheckpointCallback(save_freq=500, save_path=logdir, verbose=1)

        TIMESTEPS = 50000  # how long is each training iteration - individual steps
        model.learn(
            total_timesteps=TIMESTEPS,
            reset_num_timesteps=False,
            tb_log_name="PPO",
            progress_bar=True,
            callback=CallbackList([tensor, save_callback, checkpoint]),
        )
    finally:
        try:
            # Leaving the server in synchronous mode after the client is gone
            # can block it waiting for ticks that never arrive, so put the
            # previous settings back first.
            if original_settings is not None:
                carla_world.apply_settings(original_settings)
        finally:
            # Runs even if restoring the settings failed.
            if env is not None:
                env.destroy()
# ==============================================================================
# -- main() --------------------------------------------------------------------
# ==============================================================================
def _build_argparser():
    """Return the argument parser describing every command-line option."""
    argparser = argparse.ArgumentParser(
        description='CARLA Manual Control Client')
    argparser.add_argument(
        '-v', '--verbose',
        action='store_true',
        dest='debug',
        help='print debug information')
    argparser.add_argument(
        '--host',
        metavar='H',
        default='127.0.0.1',
        help='IP of the host server (default: 127.0.0.1)')
    argparser.add_argument(
        '-p', '--port',
        metavar='P',
        default=2000,
        type=int,
        help='TCP port to listen to (default: 2000)')
    argparser.add_argument(
        '-a', '--autopilot',
        action='store_true',
        help='enable autopilot')
    argparser.add_argument(
        '--res',
        metavar='WIDTHxHEIGHT',
        default='1280x720',
        help='window resolution (default: 1280x720)')
    argparser.add_argument(
        '--filter',
        metavar='PATTERN',
        default='vehicle.*',
        help='actor filter (default: "vehicle.*")')
    argparser.add_argument(
        '--rolename',
        metavar='NAME',
        default='hero',
        help='actor role name (default: "hero")')
    argparser.add_argument(
        '--gamma',
        default=2.2,
        type=float,
        help='Gamma correction of the camera (default: 2.2)')
    argparser.add_argument(
        '--map',
        metavar='NAME',
        default='Town04',
        help='simulation map (default: "Town04")')
    # NOTE: --spawn_x / --spawn_y declare no type=, so both the defaults and
    # user-supplied values reach the rest of the program as strings.
    argparser.add_argument(
        '--spawn_x',
        metavar='x',
        default='-16.75',  # town04 = -16.75
        help='x position to spawn the agent')
    argparser.add_argument(
        '--spawn_y',
        metavar='y',
        default='-223.55',  # town04 = -223.55
        help='y position to spawn the agent')
    argparser.add_argument(
        '--random_spawn',
        metavar='RS',
        default=0,
        type=int,
        help='Random spawn agent')
    argparser.add_argument(
        '--vehicle_id',
        metavar='NAME',
        # default='vehicle.jeep.wrangler_rubicon',
        default='vehicle.tesla.model3',
        help='vehicle to spawn, available options : vehicle.audi.a2 vehicle.audi.tt vehicle.carlamotors.carlacola vehicle.citroen.c3 vehicle.dodge_charger.police vehicle.jeep.wrangler_rubicon vehicle.yamaha.yzf vehicle.nissan.patrol vehicle.gazelle.omafiets vehicle.bh.crossbike vehicle.ford.mustang vehicle.bmw.isetta vehicle.audi.etron vehicle.harley-davidson.low rider vehicle.mercedes-benz.coupe vehicle.bmw.grandtourer vehicle.toyota.prius vehicle.diamondback.century vehicle.tesla.model3 vehicle.seat.leon vehicle.lincoln.mkz2017 vehicle.kawasaki.ninja vehicle.volkswagen.t2 vehicle.nissan.micra vehicle.chevrolet.impala vehicle.mini.cooperst')
    argparser.add_argument(
        '--vehicle_wheelbase',
        metavar='NAME',
        type=float,
        default=2.89,
        help='vehicle wheelbase used for model predict control')
    argparser.add_argument(
        '--waypoint_resolution',
        metavar='WR',
        default=1.0,
        type=float,
        help='waypoint resulution for control')
    argparser.add_argument(
        '--waypoint_lookahead_distance',
        metavar='WLD',
        default=5.0,
        type=float,
        help='waypoint look ahead distance for control')
    argparser.add_argument(
        '--desired_speed',
        metavar='SPEED',
        default=15.0,
        type=float,
        help='desired speed for highway driving')
    argparser.add_argument(
        '--control_mode',
        metavar='CONT',
        default='MPC',
        help='Controller')
    argparser.add_argument(
        '--planning_horizon',
        metavar='HORIZON',
        type=int,
        default=5,
        help='Planning horizon for MPC')
    argparser.add_argument(
        '--time_step',
        metavar='DT',
        default=0.2,
        type=float,
        help='Planning time step for MPC')
    argparser.add_argument(
        '--FPS',
        metavar='FPS',
        default=15,
        type=int,
        help='Frame per second for simulation')
    return argparser


def main(argv=None):
    """Parse command-line options, configure logging and start training.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what a plain ``main()``
            call did before this parameter existed.

    Raises:
        SystemExit: With status 2 when ``--res`` is not ``WIDTHxHEIGHT`` or
            ``--FPS`` is not positive (reported through ``argparser.error``).

    A ``KeyboardInterrupt`` raised while training is caught and reported with
    a short message instead of a traceback.
    """
    argparser = _build_argparser()
    args = argparser.parse_args(argv)

    # Both a non-numeric component and a wrong number of components raise
    # ValueError here; turn either into a normal usage error.
    try:
        args.width, args.height = [int(x) for x in args.res.split('x')]
    except ValueError:
        argparser.error('--res must be of the form WIDTHxHEIGHT, e.g. 1280x720')

    # game_loop() computes the simulation step as 1 / FPS, so zero would raise
    # ZeroDivisionError and a negative value would give a negative time step.
    if args.FPS <= 0:
        argparser.error('--FPS must be a positive integer')

    log_level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(format='%(levelname)s: %(message)s', level=log_level)

    logging.info('listening to server %s:%s', args.host, args.port)

    # Only print the module docstring when there is one; otherwise this would
    # print the literal text "None".
    if __doc__:
        print(__doc__)

    try:
        game_loop(args)
    except KeyboardInterrupt:
        print('\nCancelled by user. Bye!')
# Start training only when this file is executed as a script; importing it as a
# module must not launch a run.
if __name__ == "__main__":
    main()