You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
def __init__(self, in_channel, out_channel, mid_channel, is_fuse=True):
    """Build the projection, output and (optional) attention sub-modules.

    Args:
        in_channel: channels of the input ``x`` handed to ``forward``.
        out_channel: channels produced by ``conv_last``.
        mid_channel: channels produced by ``conv_first``; fused residual
            features must also have this many channels, because
            ``att_conv`` expects ``2 * mid_channel`` input channels.
        is_fuse: when truthy, build ``att_conv`` (1x1 conv to 2 maps +
            sigmoid); otherwise ``att_conv`` is ``None`` and no fusion
            happens in ``forward``.
    """
    super(ABF, self).__init__()
    # 1x1 projection in_channel -> mid_channel, no bias (BN follows).
    first_conv = nn.Conv2d(in_channel, mid_channel, kernel_size=1, bias=False)
    self.conv_first = nn.Sequential(first_conv, nn.BatchNorm2d(mid_channel))
    # 3x3, stride 1, padding 1: keeps the spatial size, mid -> out channels.
    last_conv = nn.Conv2d(
        mid_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False
    )
    self.conv_last = nn.Sequential(last_conv, nn.BatchNorm2d(out_channel))
    if is_fuse:
        # Two sigmoid maps computed from the concatenated [x, y] features.
        self.att_conv = nn.Sequential(
            nn.Conv2d(2 * mid_channel, 2, kernel_size=1),
            nn.Sigmoid(),
        )
    else:
        self.att_conv = None
    self.__init_weights()
def __init_weights(self):
    """Re-initialise the convolution weights of both branches.

    Applies ``nn.init.kaiming_uniform_(..., a=1)`` to the first module of
    ``conv_first`` and then of ``conv_last`` (in that order). The attention
    convolution and the batch-norm layers are not touched here.
    """
    for branch in (self.conv_first, self.conv_last):
        conv = branch[0]
        nn.init.kaiming_uniform_(conv.weight, a=1)
def forward(self, x, y=None, shape=None):
    """Project ``x`` and, if fusion is enabled, blend it with ``y``.

    Args:
        x: 4-D input feature map. Presumably (N, in_channel, H, W) — the
            channel requirement comes from ``conv_first``; TODO confirm.
        y: residual features to fuse. Required when ``self.att_conv`` is not
            ``None``; ignored otherwise. Must have as many channels as the
            output of ``conv_first`` (i.e. ``mid_channel``), because
            ``att_conv`` consumes the channel-wise concatenation of both.
        shape: target spatial size for upsampling ``y`` via
            ``F.interpolate``. Defaults to ``x``'s own ``(H, W)``; any value
            must resolve to ``(H, W)`` for the concatenation to succeed.

    Returns:
        ``(out, res)``: ``conv_last`` applied to the (fused) mid-level
        features, and those mid-level features themselves (presumably fed
        back as ``y`` to the next stage by the caller).

    Raises:
        ValueError: if ``x`` is not 4-D, or fusion is enabled and ``y`` is
            missing or its channel count differs from ``mid_channel``.
    """
    if x.dim() != 4:
        raise ValueError(f"expected a 4-D input, got shape {tuple(x.shape)}")
    _, _, height, width = x.shape
    x = self.conv_first(x)
    if self.att_conv is not None:
        if y is None:
            raise ValueError(
                "residual features `y` are required when fusion is enabled"
            )
        if y.shape[1] != x.shape[1]:
            # att_conv was built for 2 * mid_channel inputs; fail with a
            # clear message instead of an opaque conv shape error.
            raise ValueError(
                f"residual features have {y.shape[1]} channels, but "
                f"{x.shape[1]} (mid_channel) are required for fusion"
            )
        if shape is None:
            # conv_first is 1x1 with stride 1, so x still has size (H, W).
            shape = (height, width)
        # up sample residual features
        y = F.interpolate(y, shape, mode="nearest")
        # fusion: att_conv yields two per-pixel weight maps; map 0 scales x,
        # map 1 scales y. Slicing with 0:1 / 1:2 keeps the channel dim so
        # the maps broadcast over all channels.
        z = self.att_conv(torch.cat([x, y], dim=1))
        x = x * z[:, 0:1] + y * z[:, 1:2]
    y = self.conv_last(x)
    return y, x
In the `forward` function, it seems that `y` must have exactly `mid_channel` channels for `self.att_conv` to work. But `y` is `res_features`, and the number of channels of `res_features` does not seem to be guaranteed to equal `mid_channel`.
The text was updated successfully, but these errors were encountered:
class ABF(nn.Module):
In the `forward` function, it seems that `y` must have exactly `mid_channel` channels for `self.att_conv` to work. But `y` is `res_features`, and the number of channels of `res_features` does not seem to be guaranteed to equal `mid_channel`.
The text was updated successfully, but these errors were encountered: