-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathnon_local1D.py
72 lines (56 loc) · 2.53 KB
/
non_local1D.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.models as models
class Rs_GCN(nn.Module):
    """Relation-reasoning block over N node features (dot-product non-local form).

    Node features ``v`` of shape (B, in_channels, N) are projected by three
    1x1 convolutions (``theta``, ``phi``, ``g``) into ``inter_channels``.
    The process is:

    1. The pairwise affinity ``theta(v)^T @ phi(v)`` (shape (B, N, N)) is
       divided by N.
    2. The scaled affinity aggregates ``g(v)``.
    3. The aggregate is projected back to ``in_channels`` by ``W``.
    4. The result is added to ``v`` as a residual.

    Args:
        in_channels: Channel count D of the input node features.
        inter_channels: Channel count of the intermediate embedding space.
        bn_layer: If True, ``W`` is a 1x1 conv followed by ``BatchNorm1d``
            whose affine weight and bias are zero-initialised. Otherwise ``W``
            is a single 1x1 conv with zero-initialised weight and bias. In
            both cases ``W`` outputs zeros right after construction, so the
            block initially returns ``v`` unchanged.
    """

    def __init__(self, in_channels, inter_channels, bn_layer=True):
        super().__init__()
        self.in_channels = in_channels
        self.inter_channels = inter_channels

        # Submodules are created in the order g, W, theta, phi. This fixes
        # the order of random weight initialisation and of state_dict keys.
        self.g = self._pointwise(self.in_channels, self.inter_channels)

        up_projection = self._pointwise(self.inter_channels, self.in_channels)
        if bn_layer:
            self.W = nn.Sequential(up_projection, nn.BatchNorm1d(self.in_channels))
            final_layer = self.W[1]
        else:
            self.W = up_projection
            final_layer = self.W
        # Zero the last layer of W so W(y) == 0 at start and the residual
        # connection passes the input straight through.
        nn.init.constant_(final_layer.weight, 0)
        nn.init.constant_(final_layer.bias, 0)

        self.theta = self._pointwise(self.in_channels, self.inter_channels)
        self.phi = self._pointwise(self.in_channels, self.inter_channels)

    @staticmethod
    def _pointwise(in_ch, out_ch):
        """Build a 1x1 ``Conv1d`` (stride 1, no padding) mapping in_ch -> out_ch."""
        return nn.Conv1d(in_channels=in_ch, out_channels=out_ch,
                         kernel_size=1, stride=1, padding=0)

    def forward(self, v):
        """Apply one round of relation reasoning to ``v``.

        Args:
            v: Node features of shape (B, D, N), where D == in_channels.

        Returns:
            A tuple ``(v_star, R_div_C)``:
              * ``v_star``: ``W(y) + v``, with the same shape as ``v``.
              * ``R_div_C``: affinity matrix of shape (B, N, N), equal to
                ``theta(v)^T @ phi(v)`` divided by N.
        """
        batch = v.size(0)
        flat_shape = (batch, self.inter_channels, -1)

        # g and theta become (B, N, C'), while phi stays (B, C', N).
        values = self.g(v).view(*flat_shape).transpose(1, 2)
        queries = self.theta(v).view(*flat_shape).transpose(1, 2)
        keys = self.phi(v).view(*flat_shape)

        affinity = torch.matmul(queries, keys)       # (B, N, N)
        affinity = affinity / affinity.shape[-1]     # scale by node count N

        aggregated = torch.matmul(affinity, values)  # (B, N, C')
        aggregated = aggregated.transpose(1, 2).contiguous()
        aggregated = aggregated.view(batch, self.inter_channels, *v.shape[2:])

        # Residual connection around the back-projection W.
        return self.W(aggregated) + v, affinity