Skip to content

Commit

Permalink
feat: Added Rexnet 2.0x and pretrained weights (#60)
Browse files Browse the repository at this point in the history
* feat: Added ReXNet-2.0x

* test: Updated unittest

* docs: Updated documentation

* feat: Added pretrained weights from clovaai/rexnet

Added pretrained weights from https://github.com/clovaai/rexnet

* style: Fixed lint

* fix: Fixed weight loading
  • Loading branch information
frgfm authored Jul 21, 2020
1 parent 1e45838 commit 59c3124
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 6 deletions.
2 changes: 2 additions & 0 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ ResNet

.. autofunction:: rexnet1_5x

.. autofunction:: rexnet2_0x

.. autofunction:: rexnet2_2x


Expand Down
29 changes: 24 additions & 5 deletions holocron/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
from math import ceil
from collections import OrderedDict
import torch.nn as nn
from torchvision.models.utils import load_state_dict_from_url
from holocron.nn import SiLU, init


__all__ = ['BasicBlock', 'Bottleneck', 'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
'resnext50_32x4d', 'resnext101_32x8d',
'SEBlock', 'ReXBlock', 'ReXNet', 'rexnet1_0x', 'rexnet1_3x', 'rexnet1_5x', 'rexnet2_2x']
'SEBlock', 'ReXBlock', 'ReXNet', 'rexnet1_0x', 'rexnet1_3x', 'rexnet1_5x', 'rexnet2_0x', 'rexnet2_2x']


default_cfgs = {
Expand All @@ -33,11 +34,13 @@
'resnext101_32x8d': {'block': 'Bottleneck', 'num_blocks': [3, 4, 23, 3],
'url': None},
'rexnet1_0x': {'width_mult': 1.0, 'depth_mult': 1.0,
'url': None},
'url': 'https://github.com/frgfm/Holocron/releases/download/v0.1.2/rexnet1_0_224-a120bf73.pth'},
'rexnet1_3x': {'width_mult': 1.3, 'depth_mult': 1.0,
'url': None},
'url': 'https://github.com/frgfm/Holocron/releases/download/v0.1.2/rexnet1_3_224-191b60f1.pth'},
'rexnet1_5x': {'width_mult': 1.5, 'depth_mult': 1.0,
'url': None},
'url': 'https://github.com/frgfm/Holocron/releases/download/v0.1.2/rexnet1_5_224-30ce6260.pth'},
'rexnet2_0x': {'width_mult': 2.0, 'depth_mult': 1.0,
'url': 'https://github.com/frgfm/Holocron/releases/download/v0.1.2/rexnet2_0_224-e5243878.pth'},
'rexnet2_2x': {'width_mult': 2.2, 'depth_mult': 1.0,
'url': None},
}
Expand Down Expand Up @@ -241,7 +244,7 @@ def forward(self, x):


class ReXNet(nn.Sequential):
def __init__(self, width_mult=1.0, depth_mult=1.0, num_classes=10, in_channels=3, in_planes=16, final_planes=180,
def __init__(self, width_mult=1.0, depth_mult=1.0, num_classes=1000, in_channels=3, in_planes=16, final_planes=180,
use_se=True, se_ratio=12, dropout_ratio=0.2, bn_momentum=0.9,
act_layer=None, norm_layer=None, drop_layer=None):
"""Mostly adapted from https://github.com/clovaai/rexnet/blob/master/rexnetv1.py"""
Expand Down Expand Up @@ -478,6 +481,22 @@ def rexnet1_5x(pretrained=False, progress=True, **kwargs):
return _rexnet('rexnet1_5x', pretrained, progress, **kwargs)


def rexnet2_0x(pretrained=False, progress=True, **kwargs):
"""ReXNet-2.0x from
`"ReXNet: Diminishing Representational Bottleneck on Convolutional Neural Network"
<https://arxiv.org/pdf/2007.00992.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
Returns:
torch.nn.Module: classification model
"""

return _rexnet('rexnet2_0x', pretrained, progress, **kwargs)


def rexnet2_2x(pretrained=False, progress=True, **kwargs):
"""ReXNet-2.2x from
`"ReXNet: Diminishing Representational Bottleneck on Convolutional Neural Network"
Expand Down
2 changes: 1 addition & 1 deletion test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def do_test(self, model_name=model_name):
for model_name in ['darknet24', 'darknet19', 'darknet53',
'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
'resnext50_32x4d', 'resnext101_32x8d',
'rexnet1_0x', 'rexnet1_3x', 'rexnet1_5x', 'rexnet2_2x']:
'rexnet1_0x', 'rexnet1_3x', 'rexnet1_5x', 'rexnet2_0x', 'rexnet2_2x']:
def do_test(self, model_name=model_name):
self._test_classification_model(model_name)

Expand Down

0 comments on commit 59c3124

Please sign in to comment.