diff --git a/docs/source/models.rst b/docs/source/models.rst index 06d6c8b86..4961abf19 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -49,6 +49,8 @@ ResNet .. autofunction:: rexnet1_5x +.. autofunction:: rexnet2_0x + .. autofunction:: rexnet2_2x diff --git a/holocron/models/resnet.py b/holocron/models/resnet.py index f78c99bd6..d1b475b2e 100644 --- a/holocron/models/resnet.py +++ b/holocron/models/resnet.py @@ -9,12 +9,13 @@ from math import ceil from collections import OrderedDict import torch.nn as nn +from torchvision.models.utils import load_state_dict_from_url from holocron.nn import SiLU, init __all__ = ['BasicBlock', 'Bottleneck', 'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', - 'SEBlock', 'ReXBlock', 'ReXNet', 'rexnet1_0x', 'rexnet1_3x', 'rexnet1_5x', 'rexnet2_2x'] + 'SEBlock', 'ReXBlock', 'ReXNet', 'rexnet1_0x', 'rexnet1_3x', 'rexnet1_5x', 'rexnet2_0x', 'rexnet2_2x'] default_cfgs = { @@ -33,11 +34,13 @@ 'resnext101_32x8d': {'block': 'Bottleneck', 'num_blocks': [3, 4, 23, 3], 'url': None}, 'rexnet1_0x': {'width_mult': 1.0, 'depth_mult': 1.0, - 'url': None}, + 'url': 'https://github.com/frgfm/Holocron/releases/download/v0.1.2/rexnet1_0_224-a120bf73.pth'}, 'rexnet1_3x': {'width_mult': 1.3, 'depth_mult': 1.0, - 'url': None}, + 'url': 'https://github.com/frgfm/Holocron/releases/download/v0.1.2/rexnet1_3_224-191b60f1.pth'}, 'rexnet1_5x': {'width_mult': 1.5, 'depth_mult': 1.0, - 'url': None}, + 'url': 'https://github.com/frgfm/Holocron/releases/download/v0.1.2/rexnet1_5_224-30ce6260.pth'}, + 'rexnet2_0x': {'width_mult': 2.0, 'depth_mult': 1.0, + 'url': 'https://github.com/frgfm/Holocron/releases/download/v0.1.2/rexnet2_0_224-e5243878.pth'}, 'rexnet2_2x': {'width_mult': 2.2, 'depth_mult': 1.0, 'url': None}, } @@ -241,7 +244,7 @@ def forward(self, x): class ReXNet(nn.Sequential): - def __init__(self, width_mult=1.0, depth_mult=1.0, num_classes=10, in_channels=3, in_planes=16, final_planes=180, + def __init__(self, width_mult=1.0, depth_mult=1.0, num_classes=1000, in_channels=3, in_planes=16, final_planes=180, use_se=True, se_ratio=12, dropout_ratio=0.2, bn_momentum=0.9, act_layer=None, norm_layer=None, drop_layer=None): """Mostly adapted from https://github.com/clovaai/rexnet/blob/master/rexnetv1.py""" @@ -478,6 +481,22 @@ def rexnet1_5x(pretrained=False, progress=True, **kwargs): return _rexnet('rexnet1_5x', pretrained, progress, **kwargs) +def rexnet2_0x(pretrained=False, progress=True, **kwargs): + """ReXNet-2.0x from + `"ReXNet: Diminishing Representational Bottleneck on Convolutional Neural Network" + `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + + Returns: + torch.nn.Module: classification model + """ + + return _rexnet('rexnet2_0x', pretrained, progress, **kwargs) + + def rexnet2_2x(pretrained=False, progress=True, **kwargs): """ReXNet-2.2x from `"ReXNet: Diminishing Representational Bottleneck on Convolutional Neural Network" diff --git a/test/test_models.py b/test/test_models.py index 150f663f7..8f9801ed7 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -104,7 +104,7 @@ def do_test(self, model_name=model_name): for model_name in ['darknet24', 'darknet19', 'darknet53', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', - 'rexnet1_0x', 'rexnet1_3x', 'rexnet1_5x', 'rexnet2_2x']: + 'rexnet1_0x', 'rexnet1_3x', 'rexnet1_5x', 'rexnet2_0x', 'rexnet2_2x']: def do_test(self, model_name=model_name): self._test_classification_model(model_name)