train.py
import math
import random
import matplotlib.pyplot as plt
import numpy as np
from collections import deque, namedtuple
from PIL import Image
from th10.game import TH10
from config import *
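# The wildcard import above is expected to supply the hyperparameters used throughout this
# script: IMG_CHANNELS, NUM_OF_ACTIONS, TRANSFORM_HEIGHT, TRANSFORM_WIDTH, EPS_START, EPS_END,
# EPS_DECAY, EXP_REPLAY_MEMORY, BATCH_SIZE, GAMMA, STEPS_SAVE and TARGET_UPDATE.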
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
game = TH10()
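# Experience replay buffer; it is trimmed to EXP_REPLAY_MEMORY entries by hand in the training
# loop below (deque(maxlen=EXP_REPLAY_MEMORY) would be an equivalent alternative).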
memory = deque()

class DQN(nn.Module):
    def __init__(self):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(IMG_CHANNELS, 32, kernel_size=8, stride=2)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=8, stride=2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2)
        self.conv5 = nn.Conv2d(256, 512, kernel_size=4, stride=1)
        self.fc1 = nn.Linear(2048, 512)
        self.head = nn.Linear(512, NUM_OF_ACTIONS)

    def forward(self, x):
        x = self.bn1(F.relu(self.conv1(x)))
        x = F.relu(self.conv2(x))
        x = self.bn2(F.relu(self.conv3(x)))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        x = x.view(x.size(0), -1)  # flatten to (batch, features)
        x = F.relu(self.fc1(x))
        x = self.head(x)
        return x
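
# Note: fc1 expects 2048 input features. Assuming the input frames are 128x128 (matching the
# commented-out test_input below), the spatial size shrinks 128 -> 61 -> 27 -> 12 -> 5 -> 2
# through the conv stack, so the flattened tensor is 512 * 2 * 2 = 2048.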

class ToTensorWithoutScaling(object):
    def __call__(self, picture):
        return torch.ByteTensor(np.array(picture)).unsqueeze(0)
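
# Unlike T.ToTensor(), which converts to float and rescales pixel values to [0, 1], this custom
# transform keeps the raw 0-255 byte values and only adds a channel dimension; transform_state()
# below casts the result to float without rescaling.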
transform = T.Compose([T.Grayscale(), T.Resize((TRANSFORM_HEIGHT, TRANSFORM_WIDTH)), ToTensorWithoutScaling()])

def transform_state(single_state):
    # PIL -> Grayscale -> Resize -> ToTensor
    single_state = transform(single_state)
    single_state = single_state.unsqueeze(0)
    single_state = single_state.to(device, dtype=torch.float)
    return single_state
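
# transform_state() returns a float tensor of shape (1, 1, TRANSFORM_HEIGHT, TRANSFORM_WIDTH) on
# `device`; the training loop stacks four of these along the channel dimension, so IMG_CHANNELS
# is presumably 4.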

if __name__ == '__main__':
    policy_net = DQN().to(device)
    # Resume from previously saved weights (the checkpoint file must already exist)
    policy_net.load_state_dict(torch.load(f'./weights_{NUM_OF_ACTIONS}'))
    # Separate target network used for the bootstrapped targets; synced every TARGET_UPDATE steps
    target_net = DQN().to(device)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()
    """
    test_input = torch.rand(1, 4, 128, 128).to(device, dtype=torch.float)
    policy_net(test_input)
    """
    optimizer = optim.Adam(policy_net.parameters())
    Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward', 'is_terminal'))

    # Bootstrap the frame stack with four copies of the first observation
    state, _, _ = game.play(-1)
    state = transform_state(state)
    state = torch.cat((state, state, state, state), 1)
    steps = 0
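
    # Main training loop: pick an action epsilon-greedily, step the game, push the transition
    # into the replay buffer, then (once enough transitions are stored) sample a batch and take
    # one DQN optimization step.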
    while True:
        loss = 0
        train_q = torch.tensor([0])

        # Epsilon-greedy action selection; the threshold decays from EPS_START towards EPS_END
        eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps / EPS_DECAY)
        if random.random() <= eps_threshold:
            # choose a random action
            action = torch.tensor([[random.randrange(NUM_OF_ACTIONS)]], device=device, dtype=torch.long)
            q_val = 0
        else:
            # input the stack of 4 frames and take the action with the highest predicted Q value
            q = policy_net(state).max(1)
            action = q[1].view(1, 1)
            q_val = q[0].item()

        next_state, reward, is_terminal = game.play(action.item())
        if next_state is None:
            continue
        # Push the newest frame onto the stack and drop the oldest one
        next_state = transform_state(next_state)
        next_state = torch.cat((next_state, state[:, :3]), 1)
        reward = torch.tensor([reward], device=device, dtype=torch.float)
        is_terminal = torch.tensor([is_terminal], device=device, dtype=torch.uint8)
        '''
        We need enough transitions in the experience replay deque before we can draw a random
        sample of BATCH_SIZE from it, so we keep observing the environment until then.
        '''
        memory.append((state, action, next_state, reward, is_terminal))
        if len(memory) > EXP_REPLAY_MEMORY:
            memory.popleft()
        # Optimize
        if len(memory) > BATCH_SIZE:
            # Sample a random batch of transitions and regroup them field-wise
            transitions = random.sample(memory, BATCH_SIZE)
            batch = Transition(*zip(*transitions))
            state_batch = torch.cat(batch.state)
            next_state_batch = torch.cat(batch.next_state)
            action_batch = torch.cat(batch.action)
            reward_batch = torch.cat(batch.reward)
            is_terminal_batch = torch.cat(batch.is_terminal)

            # Q(s, a) predicted by the policy network for the actions that were actually taken
            state_action_values = policy_net(state_batch).gather(1, action_batch)

            # V(s') from the target network for non-terminal next states; terminal states stay at 0
            non_final_mask = is_terminal_batch == 0
            non_final_next_states = next_state_batch[non_final_mask]
            next_state_values = torch.zeros(BATCH_SIZE, device=device)
            next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()

            # Compute the expected Q values: (current reward) + GAMMA * (value of the next state)
            expected_state_action_values = reward_batch + (next_state_values * GAMMA)
            train_q = expected_state_action_values

            # Optimize with mean squared error
            # loss = F.mse_loss(state_action_values, expected_state_action_values.unsqueeze(1))
            # Optimize with Huber loss
            loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))
            optimizer.zero_grad()
            loss.backward()
            # Clamp gradients to [-1, 1] to keep updates stable
            for param in policy_net.parameters():
                param.grad.data.clamp_(-1, 1)
            optimizer.step()
        state = next_state
        steps += 1

        # Periodically checkpoint the policy network and sync the target network
        if steps % STEPS_SAVE == 0:
            torch.save(policy_net.state_dict(), f'./weights_{NUM_OF_ACTIONS}')
        if steps % TARGET_UPDATE == 0:
            target_net.load_state_dict(policy_net.state_dict())

        '''
        # Debugging: dump the three newest frames of the stack as an image
        img = state[0, 0:3]
        img = img.data.cpu().numpy()
        img = img.transpose((1, 2, 0))
        plt.imshow(img)
        plt.savefig(f'steps/{steps}.png')
        '''

        print("Timestep: %d, Action: %d, Reward: %.2f, q: %.2f, train_q_min: %.2f, train_q_max: %.2f, Loss: %.2f" %
              (steps, action.item(), reward.item(), q_val, torch.min(train_q), torch.max(train_q), loss))