[fix] forgot resnets
ctr26 committed Sep 30, 2024
1 parent 951b2bc commit 0cd6f8f
Showing 1 changed file with 341 additions and 0 deletions.
341 changes: 341 additions & 0 deletions bioimage_embed/models/bolts/resnets.py
@@ -0,0 +1,341 @@
# Credit: adapted from https://github.com/Lightning-Universe/lightning-bolts

import torch
from torch import nn
from torch.nn import functional as F # noqa: N812



class Interpolate(nn.Module):
    """nn.Module wrapper for F.interpolate."""

    def __init__(self, size=None, scale_factor=None) -> None:
        super().__init__()
        self.size, self.scale_factor = size, scale_factor

    def forward(self, x):
        return F.interpolate(x, size=self.size, scale_factor=self.scale_factor)


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)



def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)



def resize_conv3x3(in_planes, out_planes, scale=1):
    """Upsample + 3x3 convolution with padding to avoid checkerboard artifact."""
    if scale == 1:
        return conv3x3(in_planes, out_planes)
    return nn.Sequential(Interpolate(scale_factor=scale), conv3x3(in_planes, out_planes))



def resize_conv1x1(in_planes, out_planes, scale=1):
    """Upsample + 1x1 convolution to avoid checkerboard artifact."""
    if scale == 1:
        return conv1x1(in_planes, out_planes)
    return nn.Sequential(Interpolate(scale_factor=scale), conv1x1(in_planes, out_planes))
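
# Design note: upsampling by interpolation followed by a plain convolution ("resize
# convolution") is used here instead of a transposed convolution, which can produce
# checkerboard artifacts when its kernel size is not divisible by its stride
# (see Odena et al., "Deconvolution and Checkerboard Artifacts").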



class EncoderBlock(nn.Module):
    """ResNet block, copied from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L35."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None) -> None:
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        return self.relu(out)



class EncoderBottleneck(nn.Module):
    """ResNet bottleneck, copied from
    https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L75."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None) -> None:
        super().__init__()
        width = planes  # this needs to change if we want wide resnets
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = conv3x3(width, width, stride)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        return self.relu(out)



class DecoderBlock(nn.Module):
    """ResNet block, but convs replaced with resize convs, and channel increase is in second conv, not first."""

    expansion = 1

    def __init__(self, inplanes, planes, scale=1, upsample=None) -> None:
        super().__init__()
        self.conv1 = resize_conv3x3(inplanes, inplanes)
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = resize_conv3x3(inplanes, planes, scale)
        self.bn2 = nn.BatchNorm2d(planes)
        self.upsample = upsample

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.upsample is not None:
            identity = self.upsample(x)

        out += identity
        return self.relu(out)



class DecoderBottleneck(nn.Module):
    """ResNet bottleneck, but convs replaced with resize convs."""

    expansion = 4

    def __init__(self, inplanes, planes, scale=1, upsample=None) -> None:
        super().__init__()
        width = planes  # this needs to change if we want wide resnets
        self.conv1 = resize_conv1x1(inplanes, width)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = resize_conv3x3(width, width, scale)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.upsample = upsample
        self.scale = scale

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.upsample is not None:
            identity = self.upsample(x)

        out += identity
        return self.relu(out)



class ResNetEncoder(nn.Module):
    """ResNet encoder: a configurable stem, four residual stages, and global average pooling to a flat feature vector of size 512 * block.expansion."""

    def __init__(self, block, layers, first_conv=False, maxpool1=False) -> None:
        super().__init__()

        self.inplanes = 64
        self.first_conv = first_conv
        self.maxpool1 = maxpool1

        # first_conv selects the ImageNet-style 7x7/stride-2 stem; otherwise a 3x3/stride-1 stem for small images
        if self.first_conv:
            self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        else:
            self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)

        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)

        if self.maxpool1:
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        else:
            self.maxpool = nn.MaxPool2d(kernel_size=1, stride=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        return torch.flatten(x, 1)



class ResNetDecoder(nn.Module):
    """ResNet in reverse order: projects the latent vector to a 4x4 feature map, then upsamples through residual stages back to a 3-channel image."""

    def __init__(self, block, layers, latent_dim, input_height, first_conv=False, maxpool1=False) -> None:
        super().__init__()

        self.expansion = block.expansion
        self.inplanes = 512 * block.expansion
        self.first_conv = first_conv
        self.maxpool1 = maxpool1
        self.input_height = input_height

        self.upscale_factor = 8  # layers 1-3 each upsample by 2; doubled below for maxpool1 and/or first_conv

        self.linear = nn.Linear(latent_dim, self.inplanes * 4 * 4)

        self.layer1 = self._make_layer(block, 256, layers[0], scale=2)
        self.layer2 = self._make_layer(block, 128, layers[1], scale=2)
        self.layer3 = self._make_layer(block, 64, layers[2], scale=2)

        if self.maxpool1:
            self.layer4 = self._make_layer(block, 64, layers[3], scale=2)
            self.upscale_factor *= 2
        else:
            self.layer4 = self._make_layer(block, 64, layers[3])

        if self.first_conv:
            self.upscale = Interpolate(scale_factor=2)
            self.upscale_factor *= 2
        else:
            self.upscale = Interpolate(scale_factor=1)

        # interpolate the 4x4 feature map after the linear layer up to input_height // upscale_factor
        self.upscale1 = Interpolate(size=input_height // self.upscale_factor)

        self.conv1 = nn.Conv2d(64 * block.expansion, 3, kernel_size=3, stride=1, padding=1, bias=False)

    def _make_layer(self, block, planes, blocks, scale=1):
        upsample = None
        if scale != 1 or self.inplanes != planes * block.expansion:
            upsample = nn.Sequential(
                resize_conv1x1(self.inplanes, planes * block.expansion, scale),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, scale, upsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.linear(x)

        # NOTE: the interpolation below was replaced by the Linear(latent_dim, 512 * expansion * 4 * 4) layer above
        # x = F.interpolate(x, scale_factor=4)

        x = x.view(x.size(0), 512 * self.expansion, 4, 4)
        x = self.upscale1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.upscale(x)

        return self.conv1(x)



def resnet18_encoder(first_conv, maxpool1):
    return ResNetEncoder(EncoderBlock, [2, 2, 2, 2], first_conv, maxpool1)



def resnet18_decoder(latent_dim, input_height, first_conv, maxpool1):
    return ResNetDecoder(DecoderBlock, [2, 2, 2, 2], latent_dim, input_height, first_conv, maxpool1)



def resnet50_encoder(first_conv, maxpool1):
    return ResNetEncoder(EncoderBottleneck, [3, 4, 6, 3], first_conv, maxpool1)



def resnet50_decoder(latent_dim, input_height, first_conv, maxpool1):
    return ResNetDecoder(DecoderBottleneck, [3, 4, 6, 3], latent_dim, input_height, first_conv, maxpool1)
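

# Usage sketch (illustrative only, not part of the upstream Bolts code): builds a
# resnet18 encoder/decoder pair and round-trips a batch of 3-channel 64x64 images.
# Here the encoder's flat 512-dim output is fed straight into a decoder constructed
# with latent_dim=512; in a VAE-style model a learned projection to the latent space
# would normally sit in between.
if __name__ == "__main__":
    encoder = resnet18_encoder(first_conv=False, maxpool1=False)
    decoder = resnet18_decoder(latent_dim=512, input_height=64, first_conv=False, maxpool1=False)

    x = torch.randn(2, 3, 64, 64)   # (batch, channels, height, width)
    z = encoder(x)                  # -> (2, 512) flat feature vector
    x_hat = decoder(z)              # -> (2, 3, 64, 64) reconstruction
    print(z.shape, x_hat.shape)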
