From bd571746f7d5967b3ee676dc298e522d0278c47d Mon Sep 17 00:00:00 2001 From: Mateusz Buda Date: Mon, 10 Jun 2019 23:10:25 -0400 Subject: [PATCH] torch hub entrypoint config --- hubconf.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 hubconf.py diff --git a/hubconf.py b/hubconf.py new file mode 100644 index 0000000..a033ba7 --- /dev/null +++ b/hubconf.py @@ -0,0 +1,22 @@ +dependencies = ["torch"] + +import torch + +from unet import UNet + + +def unet(pretrained=False, **kwargs): + """ + U-Net segmentation model with batch normalization for biomedical image segmentation + pretrained (bool): load pretrained weights into the model + in_channels (int): number of input channels + out_channels (int): number of output channels + init_features (int): number of feature-maps in the first encoder layer + """ + model = UNet(**kwargs) + + if pretrained: + state_dict = torch.load("weights/unet.pt") + model.load_state_dict(state_dict) + + return model