-
Notifications
You must be signed in to change notification settings - Fork 269
/
train.py
89 lines (76 loc) · 2.85 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.nn.functional as F
def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=20, device="cpu"):
for epoch in range(epochs):
training_loss = 0.0
valid_loss = 0.0
model.train()
for batch in train_loader:
optimizer.zero_grad()
inputs, targets = batch
inputs = inputs.to(device)
targets = targets.to(device)
output = model(inputs)
loss = loss_fn(output, targets)
loss.backward()
optimizer.step()
training_loss += loss.data.item() * inputs.size(0)
training_loss /= len(train_loader.dataset)
model.eval()
num_correct = 0
num_examples = 0
for batch in val_loader:
inputs, targets = batch
inputs = inputs.to(device)
output = model(inputs)
targets = targets.to(device)
loss = loss_fn(output,targets)
valid_loss += loss.data.item() * inputs.size(0)
correct = torch.eq(torch.max(F.softmax(output), dim=1)[1], targets).view(-1)
num_correct += torch.sum(correct).item()
num_examples += correct.shape[0]
valid_loss /= len(val_loader.dataset)
print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}, accuracy = {:.2f}'.format(epoch, training_loss,
valid_loss, num_correct / num_examples))
def find_lr(model, loss_fn, optimizer, train_loader, init_value=1e-8, final_value=10.0, device="cpu"):
    """Learning-rate range test (Smith's LR finder): sweep the learning rate
    geometrically from `init_value` to `final_value` over at most one epoch,
    recording (lr, loss) pairs, and stop early once the loss explodes.

    Returns:
        (lrs, losses): parallel lists of learning rates and loss values.
        When more than 20 points were gathered, the noisy first 10 and last
        5 are trimmed off.

    Note: this mutates `model`'s weights and `optimizer`'s lr in place —
    callers should reset both afterwards before real training.
    """
    def _trim(lrs, losses):
        # Drop the noisy warm-up and blow-up ends when there are enough points.
        if len(lrs) > 20:
            return lrs[10:-5], losses[10:-5]
        return lrs, losses

    # max(1, ...) guards a single-batch loader, where len(loader) - 1 == 0
    # would make the exponent below divide by zero.
    number_in_epoch = max(1, len(train_loader) - 1)
    update_step = (final_value / init_value) ** (1 / number_in_epoch)
    lr = init_value
    optimizer.param_groups[0]["lr"] = lr
    best_loss = 0.0
    batch_num = 0
    losses = []
    log_lrs = []
    for data in train_loader:
        batch_num += 1
        inputs, targets = data
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
        # Extract a Python float immediately: keeping the loss *tensor* in
        # best_loss (as the original did) retains the whole autograd graph
        # across iterations, leaking memory for the length of the sweep.
        loss_value = loss.item()
        # Crash out if loss explodes past 4x the best seen so far.
        if batch_num > 1 and loss_value > 4 * best_loss:
            return _trim(log_lrs, losses)
        # Record the best loss.
        if loss_value < best_loss or batch_num == 1:
            best_loss = loss_value
        # Store the values.
        losses.append(loss_value)
        log_lrs.append(lr)
        # Do the backward pass and optimize.
        loss.backward()
        optimizer.step()
        # Update the lr for the next step and store.
        lr *= update_step
        optimizer.param_groups[0]["lr"] = lr
    return _trim(log_lrs, losses)