-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathBiFPN.py
155 lines (121 loc) · 8.06 KB
/
BiFPN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import torch.nn as nn
import torch
import math
class BiFPN(nn.Module):
    """One BiFPN layer (EfficientDet-style weighted bi-directional FPN).

    Takes five feature maps ``[P3, P4, P5, P6, P7]`` (finest to coarsest)
    and fuses them along a top-down then a bottom-up path.  Each fusion is
    a "fast normalized" weighted sum: learnable scalar weights divided by
    their sum plus a small epsilon.  Every output level has ``W_bifpn``
    (= 64) channels.

    NOTE(review): this implementation fuses P5/P6/P7 without resampling,
    so it assumes those three levels share the same spatial size, while
    P3 = 2x P4 = 4x P5 spatially — confirm against the backbone taps.

    Args:
        fpn_sizes: channel counts ``(C3, C4, C5, C6, C7)`` of the inputs.
    """

    def __init__(self, fpn_sizes):
        super(BiFPN, self).__init__()
        P3_channels, P4_channels, P5_channels, P6_channels, P7_channels = fpn_sizes
        self.W_bifpn = 64  # common channel width of all fused levels

        # Per level: a 3x3 projection conv (backbone channels -> W_bifpn),
        # a depthwise 3x3 conv applied after fusion, then ReLU + BatchNorm.
        #
        # BUGFIX: the fusion weights below were plain tensors created with
        # requires_grad=True.  Bare tensors are NOT registered by nn.Module,
        # so they were never optimized, never saved in state_dict, and never
        # moved by .to()/.cuda().  They must be nn.Parameter.

        # ---- top-down path (coarse -> fine) ----
        self.p6_td_conv = nn.Conv2d(P6_channels, self.W_bifpn, kernel_size=3, stride=1, bias=True, padding=1)
        self.p6_td_conv_2 = nn.Conv2d(self.W_bifpn, self.W_bifpn, kernel_size=3, stride=1, groups=self.W_bifpn, bias=True, padding=1)
        self.p6_td_act = nn.ReLU()
        self.p6_td_conv_bn = nn.BatchNorm2d(self.W_bifpn)
        self.p6_td_w1 = nn.Parameter(torch.ones(1))
        self.p6_td_w2 = nn.Parameter(torch.ones(1))

        self.p5_td_conv = nn.Conv2d(P5_channels, self.W_bifpn, kernel_size=3, stride=1, bias=True, padding=1)
        self.p5_td_conv_2 = nn.Conv2d(self.W_bifpn, self.W_bifpn, kernel_size=3, stride=1, groups=self.W_bifpn, bias=True, padding=1)
        self.p5_td_act = nn.ReLU()
        self.p5_td_conv_bn = nn.BatchNorm2d(self.W_bifpn)
        self.p5_td_w1 = nn.Parameter(torch.ones(1))
        self.p5_td_w2 = nn.Parameter(torch.ones(1))

        self.p4_td_conv = nn.Conv2d(P4_channels, self.W_bifpn, kernel_size=3, stride=1, bias=True, padding=1)
        self.p4_td_conv_2 = nn.Conv2d(self.W_bifpn, self.W_bifpn, kernel_size=3, stride=1, groups=self.W_bifpn, bias=True, padding=1)
        self.p4_td_act = nn.ReLU()
        self.p4_td_conv_bn = nn.BatchNorm2d(self.W_bifpn)
        self.p4_td_w1 = nn.Parameter(torch.ones(1))
        self.p4_td_w2 = nn.Parameter(torch.ones(1))
        self.p5_upsample = nn.Upsample(scale_factor=2, mode='nearest')

        # ---- bottom-up (output) path (fine -> coarse) ----
        self.p3_out_conv = nn.Conv2d(P3_channels, self.W_bifpn, kernel_size=3, stride=1, bias=True, padding=1)
        self.p3_out_conv_2 = nn.Conv2d(self.W_bifpn, self.W_bifpn, kernel_size=3, stride=1, groups=self.W_bifpn, bias=True, padding=1)
        self.p3_out_act = nn.ReLU()
        self.p3_out_conv_bn = nn.BatchNorm2d(self.W_bifpn)
        self.p3_out_w1 = nn.Parameter(torch.ones(1))
        self.p3_out_w2 = nn.Parameter(torch.ones(1))
        self.p4_upsample = nn.Upsample(scale_factor=2, mode='nearest')

        self.p4_out_conv = nn.Conv2d(self.W_bifpn, self.W_bifpn, kernel_size=3, stride=1, groups=self.W_bifpn, bias=True, padding=1)
        self.p4_out_act = nn.ReLU()
        self.p4_out_conv_bn = nn.BatchNorm2d(self.W_bifpn)
        self.p4_out_w1 = nn.Parameter(torch.ones(1))
        self.p4_out_w2 = nn.Parameter(torch.ones(1))
        self.p4_out_w3 = nn.Parameter(torch.ones(1))
        self.p3_downsample = nn.MaxPool2d(kernel_size=2)

        self.p5_out_conv = nn.Conv2d(self.W_bifpn, self.W_bifpn, kernel_size=3, stride=1, groups=self.W_bifpn, bias=True, padding=1)
        self.p5_out_act = nn.ReLU()
        self.p5_out_conv_bn = nn.BatchNorm2d(self.W_bifpn)
        self.p5_out_w1 = nn.Parameter(torch.ones(1))
        self.p5_out_w2 = nn.Parameter(torch.ones(1))
        self.p5_out_w3 = nn.Parameter(torch.ones(1))
        self.p4_downsample = nn.MaxPool2d(kernel_size=2)

        self.p6_out_conv = nn.Conv2d(self.W_bifpn, self.W_bifpn, kernel_size=3, stride=1, groups=self.W_bifpn, bias=True, padding=1)
        self.p6_out_act = nn.ReLU()
        self.p6_out_conv_bn = nn.BatchNorm2d(self.W_bifpn)
        self.p6_out_w1 = nn.Parameter(torch.ones(1))
        self.p6_out_w2 = nn.Parameter(torch.ones(1))
        self.p6_out_w3 = nn.Parameter(torch.ones(1))

        self.p7_out_conv = nn.Conv2d(P7_channels, self.W_bifpn, kernel_size=3, stride=1, bias=True, padding=1)
        self.p7_out_conv_2 = nn.Conv2d(self.W_bifpn, self.W_bifpn, kernel_size=3, stride=1, groups=self.W_bifpn, bias=True, padding=1)
        self.p7_out_act = nn.ReLU()
        self.p7_out_conv_bn = nn.BatchNorm2d(self.W_bifpn)
        self.p7_out_w1 = nn.Parameter(torch.ones(1))
        self.p7_out_w2 = nn.Parameter(torch.ones(1))

    def forward(self, inputs):
        """Fuse ``[P3, P4, P5, P6, P7]`` and return five W_bifpn-channel maps.

        Args:
            inputs: sequence of five feature tensors, finest (P3) first.

        Returns:
            list ``[P3_out, P4_out, P5_out, P6_out, P7_out]``.
        """
        epsilon = 0.0001  # keeps the normalized fusion denominators nonzero
        P3, P4, P5, P6, P7 = inputs

        # ---- top-down path ----
        P7_td = self.p7_out_conv(P7)
        P6_td_inp = self.p6_td_conv(P6)
        P6_td = self.p6_td_conv_2((self.p6_td_w1 * P6_td_inp + self.p6_td_w2 * P7_td) /
                                  (self.p6_td_w1 + self.p6_td_w2 + epsilon))
        P6_td = self.p6_td_act(P6_td)
        P6_td = self.p6_td_conv_bn(P6_td)

        P5_td_inp = self.p5_td_conv(P5)
        # P6_td is fused into P5 without resampling (same spatial size assumed).
        P5_td = self.p5_td_conv_2((self.p5_td_w1 * P5_td_inp + self.p5_td_w2 * P6_td) /
                                  (self.p5_td_w1 + self.p5_td_w2 + epsilon))
        P5_td = self.p5_td_act(P5_td)
        P5_td = self.p5_td_conv_bn(P5_td)

        P4_td_inp = self.p4_td_conv(P4)
        P4_td = self.p4_td_conv_2((self.p4_td_w1 * P4_td_inp + self.p4_td_w2 * self.p5_upsample(P5_td)) /
                                  (self.p4_td_w1 + self.p4_td_w2 + epsilon))
        P4_td = self.p4_td_act(P4_td)
        P4_td = self.p4_td_conv_bn(P4_td)

        # ---- bottom-up (output) path ----
        P3_td = self.p3_out_conv(P3)
        P3_out = self.p3_out_conv_2((self.p3_out_w1 * P3_td + self.p3_out_w2 * self.p4_upsample(P4_td)) /
                                    (self.p3_out_w1 + self.p3_out_w2 + epsilon))
        P3_out = self.p3_out_act(P3_out)
        P3_out = self.p3_out_conv_bn(P3_out)

        P4_out = self.p4_out_conv((self.p4_out_w1 * P4_td_inp + self.p4_out_w2 * P4_td + self.p4_out_w3 * self.p3_downsample(P3_out)) /
                                  (self.p4_out_w1 + self.p4_out_w2 + self.p4_out_w3 + epsilon))
        P4_out = self.p4_out_act(P4_out)
        P4_out = self.p4_out_conv_bn(P4_out)

        # BUGFIX: the denominator previously omitted p5_out_w1 even though the
        # numerator uses it, so this fusion was not properly normalized.
        P5_out = self.p5_out_conv((self.p5_out_w1 * P5_td_inp + self.p5_out_w2 * P5_td + self.p5_out_w3 * self.p4_downsample(P4_out)) /
                                  (self.p5_out_w1 + self.p5_out_w2 + self.p5_out_w3 + epsilon))
        P5_out = self.p5_out_act(P5_out)
        P5_out = self.p5_out_conv_bn(P5_out)

        # P5_out is fused into P6 without resampling (same spatial size assumed).
        P6_out = self.p6_out_conv((self.p6_out_w1 * P6_td_inp + self.p6_out_w2 * P6_td + self.p6_out_w3 * P5_out) /
                                  (self.p6_out_w1 + self.p6_out_w2 + self.p6_out_w3 + epsilon))
        P6_out = self.p6_out_act(P6_out)
        P6_out = self.p6_out_conv_bn(P6_out)

        P7_out = self.p7_out_conv_2((self.p7_out_w1 * P7_td + self.p7_out_w2 * P6_out) /
                                    (self.p7_out_w1 + self.p7_out_w2 + epsilon))
        P7_out = self.p7_out_act(P7_out)
        P7_out = self.p7_out_conv_bn(P7_out)

        return [P3_out, P4_out, P5_out, P6_out, P7_out]
if __name__ == "__main__":
    # Smoke test with channel counts matching the expected backbone taps
    # (P3..P7).  Guarded so importing this module has no side effects.
    fpn = BiFPN([40, 112, 192, 192, 1280])
    c1 = torch.randn([1, 40, 64, 64])
    c2 = torch.randn([1, 112, 32, 32])
    c3 = torch.randn([1, 192, 16, 16])
    c4 = torch.randn([1, 192, 16, 16])
    c5 = torch.randn([1, 1280, 16, 16])
    feats = [c1, c2, c3, c4, c5]
    # BUGFIX: call the module, not .forward() directly, so nn.Module's
    # __call__ machinery (hooks, etc.) runs.
    output = fpn(feats)