-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
52 lines (44 loc) · 1.57 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
from voc_classes import get_class_name
# Checkpoint file used by save_checkpoint / load_checkpoint; it is written with
# torch.save, so the ".tar" suffix is only a naming convention.
CHECKPOINT_PATH = "checkpoint.tar"
def save_checkpoint(model, optimizer, epoch, loss, path=None):
    """Save model/optimizer state and training progress with ``torch.save``.

    The stored object is a dict with the keys ``"model_state_dict"``,
    ``"optimizer_state_dict"``, ``"epoch"`` and ``"loss"``.

    Args:
        model: module whose ``state_dict()`` is stored.
        optimizer: optimizer whose ``state_dict()`` is stored.
        epoch: epoch number to record (stored as-is).
        loss: loss value to record (stored as-is).
            NOTE(review): a plain Python number (e.g. ``loss.item()``) is the
            safest choice for reloading with recent ``torch.load`` defaults.
        path: destination file; defaults to the module-level
            ``CHECKPOINT_PATH``. An existing file is overwritten.
    """
    if path is None:
        # Resolve at call time so reassigning CHECKPOINT_PATH is honored.
        path = CHECKPOINT_PATH
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epoch": epoch,
        "loss": loss,
    }
    torch.save(checkpoint, path)
def load_checkpoint(model, path=None, map_location=None):
    """Restore ``model`` weights from a checkpoint in the save_checkpoint format.

    The model's state dict is loaded in place. The optimizer state is returned
    rather than applied, so the caller can load it into its own optimizer.

    Args:
        model: module to load ``checkpoint["model_state_dict"]`` into.
        path: checkpoint file; defaults to the module-level
            ``CHECKPOINT_PATH``.
        map_location: forwarded to ``torch.load``. For example, ``"cpu"``
            loads a GPU-saved checkpoint on a CPU-only machine. ``None``
            (the default) matches the previous behavior.

    Returns:
        Tuple ``(epoch, optimizer_state_dict, loss)``.

    Raises:
        KeyError: if the loaded dict lacks one of the expected keys.
    """
    if path is None:
        # Resolve at call time so reassigning CHECKPOINT_PATH is honored.
        path = CHECKPOINT_PATH
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer_state_dict = checkpoint["optimizer_state_dict"]
    epoch = checkpoint["epoch"]
    loss = checkpoint["loss"]
    return epoch, optimizer_state_dict, loss
def write_result_to_disk(file_name, bboxes, type="gt"):
    """Write ground-truth or predicted boxes to a text file, one box per line.

    The output is intended for the evaluation tool described at
    https://github.com/rafaelpadilla/Object-Detection-Metrics#how-to-use-this-project

    Each line has the form
    ``<class_name> [<confidence>] <x_center> <y_center> <width> <height>``.
    The confidence column is written only when ``type == "pred"``.

    Args:
        file_name: path of the file to (over)write.
        bboxes: iterable of boxes, each indexable as
            ``(x1, y1, x2, y2, class, confidence)``. ``confidence`` is read
            only for predictions.
        type: ``"pred"`` for predictions. Any other value (default ``"gt"``)
            writes ground truth without a confidence column.

    NOTE(review): coordinates are written as *center*-based xywh. The tool's
    absolute ``xywh`` mode expects the top-left corner instead, so this output
    presumably relies on its relative (YOLO-style) coordinate mode. Confirm
    the flags used when running the evaluation.
    """
    lines = []
    for bbox in bboxes:
        fields = [get_class_name(bbox[4])]  # class name
        if type == "pred":
            fields.append(str(bbox[5]))  # confidence
        # Convert corner format (x1, y1, x2, y2) to center format (cx, cy, w, h).
        xmin, ymin, xmax, ymax = bbox[:4]
        fields += [
            str((xmin + xmax) / 2),
            str((ymin + ymax) / 2),
            str(xmax - xmin),
            str(ymax - ymin),
        ]
        lines.append(" ".join(fields) + "\n")
    # Build all lines first and write once, instead of quadratic string +=.
    with open(file_name, "w", encoding="utf-8") as f:
        f.writelines(lines)