From 7fa2dcf8cfd9db99b767aad6ed1cb403b83cff70 Mon Sep 17 00:00:00 2001 From: Geevarghese George Date: Sat, 5 Nov 2022 16:49:55 +0100 Subject: [PATCH 1/3] set bias=False for downsample-B Signed-off-by: Geevarghese George --- monai/networks/nets/resnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index e923c1bb7d..0b0984c874 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -277,7 +277,7 @@ def _make_layer( ) else: downsample = nn.Sequential( - conv_type(self.in_planes, planes * block.expansion, kernel_size=1, stride=stride), + conv_type(self.in_planes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), norm_type(planes * block.expansion), ) From b27a6a3397ac35535648ce7f4fa189ea2bfa838d Mon Sep 17 00:00:00 2001 From: Geevarghese George Date: Fri, 11 Nov 2022 21:50:23 +0100 Subject: [PATCH 2/3] add bias_downsample kwarg to ResNet Signed-off-by: Geevarghese George --- monai/networks/nets/resnet.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 0b0984c874..fca975d40e 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -175,6 +175,7 @@ class ResNet(nn.Module): widen_factor: widen output for each layer. num_classes: number of output (classifications). feed_forward: whether to add the FC layer for the output, default to `True`. + bias_downsample: whether to use bias term in the downsampling block when `shortcut_type` is 'B', default to `True`. """ @@ -192,6 +193,7 @@ def __init__( widen_factor: float = 1.0, num_classes: int = 400, feed_forward: bool = True, + bias_downsample: bool = True, # for backwards compatibility (also see PR #5477) ) -> None: super().__init__() @@ -216,6 +218,7 @@ def __init__( self.in_planes = block_inplanes[0] self.no_max_pool = no_max_pool + self.bias_downsample = bias_downsample conv1_kernel_size = ensure_tuple_rep(conv1_t_size, spatial_dims) conv1_stride = ensure_tuple_rep(conv1_t_stride, spatial_dims) @@ -277,7 +280,13 @@ def _make_layer( ) else: downsample = nn.Sequential( - conv_type(self.in_planes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), + conv_type( + self.in_planes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=self.bias_downsample, + ), norm_type(planes * block.expansion), ) @@ -323,7 +332,7 @@ def _resnet( progress: bool, **kwargs: Any, ) -> ResNet: - model: ResNet = ResNet(block, layers, block_inplanes, **kwargs) + model: ResNet = ResNet(block, layers, block_inplanes, bias_downsample=not pretrained, **kwargs) if pretrained: # Author of paper zipped the state_dict on googledrive, # so would need to download, unzip and read (2.8gb file for a ~150mb state dict). From a203493e706566bc83fe3c3e9a2ffcc74007d5e0 Mon Sep 17 00:00:00 2001 From: Geevarghese George Date: Mon, 14 Nov 2022 19:00:07 +0100 Subject: [PATCH 3/3] add to test_resnet: TEST_CASE_7 Signed-off-by: Geevarghese George --- tests/test_resnet.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/tests/test_resnet.py b/tests/test_resnet.py index ae05f36210..b09b97a450 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -140,11 +140,27 @@ (1, 3), ] +TEST_CASE_7 = [ # 1D, batch 1, 2 input channels, bias_downsample + { + "block": "bottleneck", + "layers": [3, 4, 6, 3], + "block_inplanes": [64, 128, 256, 512], + "spatial_dims": 1, + "n_input_channels": 2, + "num_classes": 3, + "conv1_t_size": [3], + "conv1_t_stride": 1, + "bias_downsample": False, # set to False if pretrained=True (PR #5477) + }, + (1, 2, 32), + (1, 3), +] + TEST_CASES = [] for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]: for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]: TEST_CASES.append([model, *case]) -for case in [TEST_CASE_5, TEST_CASE_5_A, TEST_CASE_6]: +for case in [TEST_CASE_5, TEST_CASE_5_A, TEST_CASE_6, TEST_CASE_7]: TEST_CASES.append([ResNet, *case]) TEST_SCRIPT_CASES = [