Skip to content

Commit

Permalink
Add timm ConvNeXt 'atto' weights, change test resolution for FB ConvNeXt 224x224 weights, add support for different dw kernel_size
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Aug 16, 2022
1 parent 7c4682d commit 1d8ada3
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 26 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ Thanks to the following for hardware support:
And a big thanks to all GitHub sponsors who helped with some of my costs before I joined Hugging Face.

## What's New

### Aug 15, 2022
* ConvNeXt atto weights added
  * `convnext_atto` - 75.7 @ 224, 77.0 @ 288
  * `convnext_atto_ols` - 75.9 @ 224, 77.2 @ 288

### Aug 5, 2022
* More custom ConvNeXt smaller model defs with weights
  * `convnext_femto` - 77.5 @ 224, 78.7 @ 288
Expand Down
77 changes: 52 additions & 25 deletions timm/models/convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d,\
create_conv2d, make_divisible
create_conv2d, get_act_layer, make_divisible, to_ntuple
from .registry import register_model


Expand All @@ -40,14 +39,13 @@ def _cfg(url='', **kwargs):


default_cfgs = dict(
convnext_tiny=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth"),
convnext_small=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth"),
convnext_base=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth"),
convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"),

# timm specific variants
convnext_atto=_cfg(url=''),
convnext_atto_ols=_cfg(url=''),
convnext_atto=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
test_input_size=(3, 288, 288), test_crop_pct=0.95),
convnext_atto_ols=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth',
test_input_size=(3, 288, 288), test_crop_pct=0.95),
convnext_femto=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
test_input_size=(3, 288, 288), test_crop_pct=0.95),
Expand All @@ -70,16 +68,34 @@ def _cfg(url='', **kwargs):
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),

convnext_tiny=_cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
test_input_size=(3, 288, 288), test_crop_pct=1.0),
convnext_small=_cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
test_input_size=(3, 288, 288), test_crop_pct=1.0),
convnext_base=_cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
test_input_size=(3, 288, 288), test_crop_pct=1.0),
convnext_large=_cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
test_input_size=(3, 288, 288), test_crop_pct=1.0),

convnext_tiny_in22ft1k=_cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth'),
url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
convnext_small_in22ft1k=_cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth'),
url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
convnext_base_in22ft1k=_cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth'),
url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
convnext_large_in22ft1k=_cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth'),
url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
convnext_xlarge_in22ft1k=_cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth'),
url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth',
test_input_size=(3, 288, 288), test_crop_pct=1.0),

convnext_tiny_384_in22ft1k=_cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth',
Expand Down Expand Up @@ -121,37 +137,39 @@ class ConvNeXtBlock(nn.Module):
is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
Args:
dim (int): Number of input channels.
in_chs (int): Number of input channels.
drop_path (float): Stochastic depth rate. Default: 0.0
ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
"""

def __init__(
self,
dim,
dim_out=None,
in_chs,
out_chs=None,
kernel_size=7,
stride=1,
dilation=1,
mlp_ratio=4,
conv_mlp=False,
conv_bias=True,
ls_init_value=1e-6,
act_layer='gelu',
norm_layer=None,
act_layer=nn.GELU,
drop_path=0.,
):
super().__init__()
dim_out = dim_out or dim
out_chs = out_chs or in_chs
act_layer = get_act_layer(act_layer)
if not norm_layer:
norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
mlp_layer = ConvMlp if conv_mlp else Mlp
self.use_conv_mlp = conv_mlp

self.conv_dw = create_conv2d(
dim, dim_out, kernel_size=7, stride=stride, dilation=dilation, depthwise=True, bias=conv_bias)
self.norm = norm_layer(dim_out)
self.mlp = mlp_layer(dim_out, int(mlp_ratio * dim_out), act_layer=act_layer)
self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out)) if ls_init_value > 0 else None
in_chs, out_chs, kernel_size=kernel_size, stride=stride, dilation=dilation, depthwise=True, bias=conv_bias)
self.norm = norm_layer(out_chs)
self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value > 0 else None
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

def forward(self, x):
Expand All @@ -178,13 +196,15 @@ def __init__(
self,
in_chs,
out_chs,
kernel_size=7,
stride=2,
depth=2,
dilation=(1, 1),
drop_path_rates=None,
ls_init_value=1.0,
conv_mlp=False,
conv_bias=True,
act_layer='gelu',
norm_layer=None,
norm_layer_cl=None
):
Expand All @@ -208,13 +228,15 @@ def __init__(
stage_blocks = []
for i in range(depth):
stage_blocks.append(ConvNeXtBlock(
dim=in_chs,
dim_out=out_chs,
in_chs=in_chs,
out_chs=out_chs,
kernel_size=kernel_size,
dilation=dilation[1],
drop_path=drop_path_rates[i],
ls_init_value=ls_init_value,
conv_mlp=conv_mlp,
conv_bias=conv_bias,
act_layer=act_layer,
norm_layer=norm_layer if conv_mlp else norm_layer_cl
))
in_chs = out_chs
Expand Down Expand Up @@ -252,19 +274,22 @@ def __init__(
output_stride=32,
depths=(3, 3, 9, 3),
dims=(96, 192, 384, 768),
kernel_sizes=7,
ls_init_value=1e-6,
stem_type='patch',
patch_size=4,
head_init_scale=1.,
head_norm_first=False,
conv_mlp=False,
conv_bias=True,
act_layer='gelu',
norm_layer=None,
drop_rate=0.,
drop_path_rate=0.,
):
super().__init__()
assert output_stride in (8, 16, 32)
kernel_sizes = to_ntuple(4)(kernel_sizes)
if norm_layer is None:
norm_layer = partial(LayerNorm2d, eps=1e-6)
norm_layer_cl = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
Expand Down Expand Up @@ -312,13 +337,15 @@ def __init__(
stages.append(ConvNeXtStage(
prev_chs,
out_chs,
kernel_size=kernel_sizes[i],
stride=stride,
dilation=(first_dilation, dilation),
depth=depths[i],
drop_path_rates=dp_rates[i],
ls_init_value=ls_init_value,
conv_mlp=conv_mlp,
conv_bias=conv_bias,
act_layer=act_layer,
norm_layer=norm_layer,
norm_layer_cl=norm_layer_cl
))
Expand Down
2 changes: 1 addition & 1 deletion timm/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.6.8'
__version__ = '0.6.9'

0 comments on commit 1d8ada3

Please sign in to comment.