Skip to content

Commit

Permalink
style: 💄 Black format.
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed Aug 12, 2024
1 parent 125894e commit e6d0277
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 5 deletions.
8 changes: 4 additions & 4 deletions src/cellmap_models/pytorch/untrained_models/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@

## Models

***ResNet***: Parameterizable 2D and 3D ResNet models with a variable number of layers and channels. This model is based on the original ResNet architecture with the addition of a decoding path, which mirrors the encoder, after the bottleneck, to produce an image output.
- **ResNet**: Parameterizable 2D and 3D ResNet models with a variable number of layers and channels. This model is based on the original ResNet architecture with the addition of a decoding path, which mirrors the encoder, after the bottleneck, to produce an image output.

***UNet2D***: A simple 2D UNet model with a variable number of output channels.
- **UNet2D**: A simple 2D UNet model with a variable number of output channels.

***UNet3D***: A simple 3D UNet model with a variable number of output channels.
- **UNet3D**: A simple 3D UNet model with a variable number of output channels.

***ViTVNet***: A 3D VNet model with a Vision Transformer (ViT) encoder. This model is based on the original VNet architecture with the addition of a ViT encoder in place of the original convolutional encoder.
- **ViTVNet**: A 3D VNet model with a Vision Transformer (ViT) encoder. This model is based on the original VNet architecture with the addition of a ViT encoder in place of the original convolutional encoder.

## Usage

Expand Down
3 changes: 3 additions & 0 deletions src/cellmap_models/pytorch/untrained_models/unet_2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# Original source code from:
# https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py


class UNet2D(nn.Module):
def __init__(self, n_channels, n_classes, trilinear=False):
super(UNet2D, self).__init__()
Expand Down Expand Up @@ -38,10 +39,12 @@ def forward(self, x):
logits = self.outc(x)
return logits


""" Parts of the 2D U-Net model """
# Original source code from:
# https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py


class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""

Expand Down
4 changes: 3 additions & 1 deletion src/cellmap_models/pytorch/untrained_models/unet_3D.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -7,6 +6,7 @@
# Original source code from:
# https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py


class UNet3D(nn.Module):
def __init__(self, n_channels, n_classes, trilinear=False):
super(UNet3D, self).__init__()
Expand Down Expand Up @@ -39,11 +39,13 @@ def forward(self, x):
logits = self.outc(x)
return logits


""" Parts of the U-Net model """
# Adapted from:
# https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py
# By Emma Avetissian, @aemmav


class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""

Expand Down

0 comments on commit e6d0277

Please sign in to comment.