-
-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathautopilot_model.py
35 lines (27 loc) · 1.06 KB
/
autopilot_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
import torchvision
from efficientnet_pytorch import EfficientNet
from torch2trt import TRTModule
from torch2trt import torch2trt
# Hyper-parameters of the regression head built in AutopilotModel:
#   OUTPUT_SIZE  - out_features of the head's final Linear layer
#                  (presumably one value per control signal — TODO confirm).
#   DROPOUT_PROB - drop probability shared by every Dropout layer in the head.
OUTPUT_SIZE, DROPOUT_PROB = 2, 0.5
class AutopilotModel(torch.nn.Module):
    """ResNet-18 backbone whose ``fc`` layer is replaced by a small Linear head.

    The head maps the backbone's pooled features to ``OUTPUT_SIZE`` values
    through 128- and 64-unit Linear layers, each preceded by Dropout with
    probability ``DROPOUT_PROB``.
    """

    def __init__(self, pretrained, device="cuda"):
        """Build the network and move it to ``device``.

        Args:
            pretrained: Forwarded to ``torchvision.models.resnet18``; when true
                the backbone starts from torchvision's pretrained weights.
            device: Where to place the network after construction. Defaults to
                ``"cuda"`` (the original GPU-only behaviour). Pass ``"cpu"`` to
                build on a machine without CUDA, or ``None`` to leave the
                parameters wherever torch created them.
        """
        super().__init__()
        # NOTE(review): newer torchvision releases deprecate ``pretrained=`` in
        # favour of ``weights=``; kept for compatibility with older torchvision.
        self.network = torchvision.models.resnet18(pretrained=pretrained)
        # The position of each layer in this Sequential becomes part of the
        # state_dict keys (e.g. ``network.fc.1.weight``). Reordering or
        # inserting layers would therefore break loading of existing checkpoints.
        # NOTE(review): there is no activation between the Linear layers, so
        # with dropout disabled the head is a single affine map. Confirm this is
        # intended before the next retraining.
        self.network.fc = torch.nn.Sequential(
            torch.nn.Dropout(p=DROPOUT_PROB),
            torch.nn.Linear(in_features=self.network.fc.in_features, out_features=128),
            torch.nn.Dropout(p=DROPOUT_PROB),
            torch.nn.Linear(in_features=128, out_features=64),
            torch.nn.Dropout(p=DROPOUT_PROB),
            torch.nn.Linear(in_features=64, out_features=OUTPUT_SIZE),
        )
        if device is not None:
            self.network.to(device)

    def forward(self, x):
        """Run ``x`` through the backbone and head and return the raw outputs.

        ``x`` is presumably an image batch shaped (N, 3, H, W) on the same
        device as the model. TODO: confirm against the caller.
        """
        return self.network(x)

    def save_to_path(self, path):
        """Write this module's ``state_dict`` (parameters and buffers) to ``path``."""
        torch.save(self.state_dict(), path)

    def load_from_path(self, path):
        """Load a ``state_dict`` previously written by ``save_to_path``.

        Tensors are deserialized onto the CPU first. ``load_state_dict`` then
        copies them into this module's parameters on whatever device those
        live on. This lets a checkpoint saved on a GPU load on a machine
        without CUDA, or onto a different GPU.

        Raises:
            RuntimeError: from ``load_state_dict`` when keys or shapes in the
                checkpoint do not match this module.

        NOTE(review): ``torch.load`` unpickles the file. Only load trusted
        checkpoints.
        """
        state_dict = torch.load(path, map_location="cpu")
        self.load_state_dict(state_dict)