Skip to content

Commit

Permalink
Merge pull request #2 from mateuszbuda/hubconf
Browse files Browse the repository at this point in the history
torch hub entrypoint config
  • Loading branch information
mateuszbuda authored Jun 11, 2019
2 parents 217620b + bd57174 commit df5a0cd
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions hubconf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
dependencies = ["torch"]

import torch

from unet import UNet


def unet(pretrained=False, **kwargs):
"""
U-Net segmentation model with batch normalization for biomedical image segmentation
pretrained (bool): load pretrained weights into the model
in_channels (int): number of input channels
out_channels (int): number of output channels
init_features (int): number of feature-maps in the first encoder layer
"""
model = UNet(**kwargs)

if pretrained:
state_dict = torch.load("weights/unet.pt")
model.load_state_dict(state_dict)

return model

0 comments on commit df5a0cd

Please sign in to comment.