-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
42 lines (36 loc) · 1.58 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
from utils.data_creat import load_data
from model import *
from utils.loss import *
batch_size = 32
train_iter = load_data(batch_size)
net = TinySSD(num_classes=1)
net = net.to('cuda'if torch.cuda.is_available() else 'cpu')
trainer = torch.optim.SGD(net.parameters(), lr=0.2, weight_decay=5e-4)
num_epochs = 50
for epoch in range(num_epochs):
print('epoch: ', epoch)
# 训练精确度的和,训练精确度的和中的示例数
# 绝对误差的和,绝对误差的和中的示例数
metric = Accumulator(4)
net.train()
for features, target in train_iter:
trainer.zero_grad()
X, Y = features.to('cuda'), target.to('cuda')
# 生成多尺度的锚框,为每个锚框预测类别和偏移量
anchors, cls_preds, bbox_preds = net(X)
# 为每个锚框标注类别和偏移量
bbox_labels, bbox_masks, cls_labels = multibox_target(anchors, Y)
# 根据类别和偏移量的预测和标注值计算损失函数
l = calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels,
bbox_masks)
l.mean().backward()
trainer.step()
metric.add(cls_eval(cls_preds, cls_labels), cls_labels.numel(),
bbox_eval(bbox_preds, bbox_labels, bbox_masks),
bbox_labels.numel())
cls_err, bbox_mae = 1 - metric.data[0] / metric.data[1], metric.data[2] / metric.data[3]
print(f'class err {cls_err:.2e}, bbox mae {bbox_mae:.2e}')
# 保存模型参数
if epoch % 10 == 0:
torch.save(net.state_dict(), 'weights/net_' + str(epoch) + '.pkl')