-
Notifications
You must be signed in to change notification settings - Fork 279
/
train.py
109 lines (83 loc) · 3.22 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import torch.nn.functional as F
import torch.optim as optim
from envs import create_atari_env
from model import ActorCritic
def ensure_shared_grads(model, shared_model):
    """Hand the local model's gradients to the shared model.

    Parameters of ``model`` and ``shared_model`` are walked pairwise, in
    the order returned by ``parameters()``. Each shared parameter whose
    gradient is still unset receives the matching local gradient. As soon
    as a shared parameter that already has a gradient is encountered, the
    function stops and leaves it and all later parameters untouched.

    Args:
        model: Module whose parameters carry the gradients to hand over.
        shared_model: Module whose parameters receive those gradients.
    """
    local_params = model.parameters()
    global_params = shared_model.parameters()
    for local_p, global_p in zip(local_params, global_params):
        if global_p.grad is None:
            # Assign the local gradient object itself; no copy is made here.
            global_p._grad = local_p.grad
        else:
            # Shared gradients are already populated; nothing more to do.
            return
def train(rank, args, shared_model, counter, lock, optimizer=None):
    """Run one actor-critic training worker; this function never returns.

    The worker repeatedly copies the shared model's weights into a private
    model, collects up to ``args.num_steps`` environment steps, builds the
    policy and value losses from that rollout, back-propagates through the
    private model, hands the gradients to ``shared_model`` and steps the
    optimizer.

    Args:
        rank: Worker index; added to ``args.seed`` to seed torch and the
            environment.
        args: Settings object. Attributes read here: ``seed``, ``env_name``,
            ``lr``, ``num_steps``, ``max_episode_length``, ``gamma``,
            ``gae_lambda``, ``entropy_coef``, ``value_loss_coef`` and
            ``max_grad_norm``.
        shared_model: Model whose ``state_dict`` is loaded at the start of
            every rollout and whose parameters the optimizer updates.
        counter: Object with a ``value`` attribute, incremented once per
            environment step while ``lock`` is held.
        lock: Context manager guarding ``counter``.
        optimizer: Optimizer over ``shared_model``'s parameters. When
            ``None``, an ``Adam`` optimizer with ``lr=args.lr`` is created.
    """
    torch.manual_seed(args.seed + rank)

    env = create_atari_env(args.env_name)
    env.seed(args.seed + rank)

    model = ActorCritic(env.observation_space.shape[0], env.action_space)

    if optimizer is None:
        optimizer = optim.Adam(shared_model.parameters(), lr=args.lr)

    model.train()

    state = env.reset()
    state = torch.from_numpy(state)
    # Starting with done=True makes the first pass create fresh hx/cx.
    done = True

    episode_length = 0
    while True:
        # Sync with the shared model
        model.load_state_dict(shared_model.state_dict())
        if done:
            # New episode: start from an all-zero recurrent state.
            cx = torch.zeros(1, 256)
            hx = torch.zeros(1, 256)
        else:
            # Same episode: keep the state values but cut the graph so
            # backward() does not reach into the previous rollout.
            cx = cx.detach()
            hx = hx.detach()

        values = []
        log_probs = []
        rewards = []
        entropies = []

        for _ in range(args.num_steps):
            episode_length += 1
            value, logit, (hx, cx) = model((state.unsqueeze(0),
                                            (hx, cx)))
            prob = F.softmax(logit, dim=-1)
            log_prob = F.log_softmax(logit, dim=-1)
            entropy = -(log_prob * prob).sum(1, keepdim=True)
            entropies.append(entropy)

            # Sample an action and keep only its log-probability.
            action = prob.multinomial(num_samples=1).detach()
            log_prob = log_prob.gather(1, action)

            state, reward, done, _ = env.step(action.numpy())
            # Over-long episodes are treated as finished.
            done = done or episode_length >= args.max_episode_length
            # Clip the reward to the range [-1, 1].
            reward = max(min(reward, 1), -1)

            with lock:
                counter.value += 1

            if done:
                episode_length = 0
                state = env.reset()

            state = torch.from_numpy(state)
            values.append(value)
            log_probs.append(log_prob)
            rewards.append(reward)

            if done:
                break

        # Bootstrap value: zero if the episode ended, otherwise the model's
        # value estimate for the state the rollout stopped in.
        R = torch.zeros(1, 1)
        if not done:
            value, _, _ = model((state.unsqueeze(0), (hx, cx)))
            R = value.detach()

        values.append(R)
        policy_loss = 0
        value_loss = 0
        gae = torch.zeros(1, 1)
        # Walk the rollout backwards so returns accumulate step by step.
        for i in reversed(range(len(rewards))):
            R = args.gamma * R + rewards[i]
            advantage = R - values[i]
            value_loss = value_loss + 0.5 * advantage.pow(2)

            # Generalized Advantage Estimation
            delta_t = rewards[i] + args.gamma * \
                values[i + 1] - values[i]
            gae = gae * args.gamma * args.gae_lambda + delta_t

            # gae is detached so the policy term does not train the critic.
            policy_loss = policy_loss - \
                log_probs[i] * gae.detach() - args.entropy_coef * entropies[i]

        optimizer.zero_grad()
        # The optimizer only owns shared_model's parameters. When
        # zero_grad() resets gradients to None, the call above merely drops
        # the shared model's reference and the local gradient tensors keep
        # their old values, so backward() would add to the previous
        # rollouts' gradients. Clear the local model explicitly as well.
        model.zero_grad()

        (policy_loss + args.value_loss_coef * value_loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

        ensure_shared_grads(model, shared_model)
        optimizer.step()