
ResNet implementation: set bias=False for downsample-B #5477

Merged
5 commits merged into Project-MONAI:dev from the 5465-resnet-downsample branch on Nov 15, 2022


thatgeeman
Contributor

@thatgeeman thatgeeman commented Nov 5, 2022

Signed-off-by: Geevarghese George <[email protected]>

Description

This is a simple fix following #5465.
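The downsampling layer is not expected to have a bias term: its 1x1 convolution is immediately followed by a batch-normalization layer, whose learnable shift makes a convolution bias redundant. The previous implementation did not explicitly set bias=False and therefore fell back to the PyTorch Conv2d/Conv3d default of bias=True. With this change, the correct number of parameter tensors (62 for resnet18) is returned: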

```python
from torchvision import models
from monai.networks import nets

d2net_torch = models.resnet18()
d2net_monai = nets.resnet18(spatial_dims=2)
d3net_monai = nets.resnet18(spatial_dims=3)

# number of parameter tensors in each network
print(len(list(d2net_torch.parameters())), len(list(d2net_monai.parameters())), len(list(d3net_monai.parameters())))
# with this change: 62 62 62
# before:           62 65 65
```
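Why the bias is redundant can be checked with a few lines of plain Python: batch normalization in training mode subtracts each channel's batch mean before scaling, so a constant per-channel bias added by the preceding convolution cancels exactly, and the norm layer's own learnable shift can supply any offset that is needed. This is a minimal sketch with a hypothetical `batch_norm` helper (one channel, no affine scale/shift), not PyTorch's BatchNorm implementation.

```python
import statistics


def batch_norm(values, eps=1e-5):
    """Normalize one channel over the batch (training-mode statistics, no affine step)."""
    mean = statistics.fmean(values)
    var = statistics.pvariance(values, mu=mean)
    return [(v - mean) / (var + eps) ** 0.5 for v in values]


conv_out = [0.5, -1.25, 2.0, 0.75]  # hypothetical 1x1-conv outputs for one channel
bias = 3.0  # a constant per-channel convolution bias

with_bias = batch_norm([v + bias for v in conv_out])
without_bias = batch_norm(conv_out)
print(all(abs(a - b) < 1e-9 for a, b in zip(with_bias, without_bias)))  # True
```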

The deeper 2D ResNet architectures are likewise structurally comparable to the PyTorch implementation, so the pretrained/weights parameter could be allowed for these networks too. Currently, passing pretrained=True raises a NotImplementedError, even for 2D ResNets.
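The 62 vs. 65 figures can be reproduced by counting parameter tensors by hand. The helper below is a sketch that assumes the standard ResNet layout: bias-free convolutions each followed by a norm layer with a weight and a bias, a type-B 1x1-conv shortcut wherever the stride or channel count changes, and a final linear layer. `count_param_tensors` is a hypothetical name, not a MONAI or torchvision function. Under the same assumptions, the Bottleneck-based resnet50 layout comes out at 161.

```python
def count_param_tensors(layers, expansion, bias_downsample, planes=(64, 128, 256, 512), stem_planes=64):
    """Count parameter tensors (not scalar parameters) in a standard ResNet layout."""
    total = 1 + 2  # stem: conv weight (bias=False) + norm weight and bias
    convs_per_block = 2 if expansion == 1 else 3  # basic block vs. bottleneck block
    in_planes = stem_planes
    for stage, (n_blocks, stage_planes) in enumerate(zip(layers, planes)):
        stride = 1 if stage == 0 else 2
        out_planes = stage_planes * expansion
        if stride != 1 or in_planes != out_planes:
            # type-B shortcut: 1x1 conv weight, optional conv bias, norm weight and bias
            total += 1 + int(bias_downsample) + 2
        # each conv in a block has only a weight and is followed by a norm (weight, bias)
        total += n_blocks * convs_per_block * 3
        in_planes = out_planes
    return total + 2  # final linear layer: weight and bias


print(count_param_tensors([2, 2, 2, 2], expansion=1, bias_downsample=False))  # resnet18: 62
print(count_param_tensors([2, 2, 2, 2], expansion=1, bias_downsample=True))  # resnet18: 65
print(count_param_tensors([3, 4, 6, 3], expansion=4, bias_downsample=False))  # resnet50: 161
```

The three extra tensors in the "before" column are exactly the biases of the three shortcut convolutions in stages 2 to 4; stage 1 of resnet18 keeps 64 channels at stride 1 and needs no shortcut.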

Types of changes

  • [x] Non-breaking change (fix or new feature that would not break existing functionality).
  • [ ] Breaking change (fix or new feature that would cause existing functionality to change).
  • [ ] New tests added to cover the changes.
  • [x] Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • [x] Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • [ ] In-line docstrings updated.
  • [ ] Documentation updated, tested make html command in the docs/ folder.

Signed-off-by: Geevarghese George <[email protected]>
@thatgeeman thatgeeman changed the title from "ResNet implementaion: set bias=False for downsample-B" to "ResNet implementation: set bias=False for downsample-B" Nov 5, 2022
@ericspod
Member

ericspod commented Nov 6, 2022

We hadn't implemented downloading pre-trained weights, so that part isn't affected by this, but other saved instances of this network, like here, won't load correctly. To maintain backwards compatibility, we should add a bias argument at the end of the constructor's arguments, defaulting to True, that sets the downsampling layer's bias. When a standard ResNet compatible with other pretrained weights is requested, this argument would then be set to False.
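Why previously saved instances would break can be seen from the key comparison that strict state-dict loading performs: keys the model expects but the file lacks are "missing", and keys the file has but the model does not are "unexpected". This plain-Python sketch mimics only that comparison; `diff_state_dict_keys` is a hypothetical helper, not PyTorch's code, and the key names assume a ResNet-18-style network whose shortcut is an `nn.Sequential` of a conv followed by a norm layer (so `downsample.0` is the 1x1 conv).

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Simplified sketch of the key check behind strict state_dict loading."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)  # expected by the model, absent from the file
    unexpected = sorted(checkpoint_keys - model_keys)  # present in the file, unknown to the model
    return missing, unexpected


# Hypothetical keys for the three type-B shortcut convs of a ResNet-18-style network
weight_keys = [f"layer{i}.0.downsample.0.weight" for i in (2, 3, 4)]
bias_keys = [f"layer{i}.0.downsample.0.bias" for i in (2, 3, 4)]

# A checkpoint saved with bias=True, loaded into a model built with bias=False
missing, unexpected = diff_state_dict_keys(weight_keys, weight_keys + bias_keys)
print(missing)  # []
print(unexpected)  # ['layer2.0.downsample.0.bias', 'layer3.0.downsample.0.bias', 'layer4.0.downsample.0.bias']
```

Keeping True as the default means a model rebuilt with default arguments still has the same keys as older checkpoints; only callers that explicitly ask for the bias-free layout get the different key set.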

@thatgeeman thatgeeman marked this pull request as draft November 6, 2022 16:38
@thatgeeman
Contributor Author

Yes, that makes sense; I'll make the additions in the coming days. Making this PR a draft for now.

Signed-off-by: Geevarghese George <[email protected]>
@thatgeeman
Contributor Author

Made the requisite changes to accept an additional keyword argument (bias_downsample) as discussed. How does this look, @ericspod?

@thatgeeman thatgeeman marked this pull request as ready for review November 12, 2022 12:01
@ericspod
Member

Looks good to me. We should also add a test case to test_resnet.py to cover this change.

@thatgeeman
Contributor Author

Agreed! What would the test case look like? Would it be more of a sanity check that the expected shapes are returned with pretrained=True?

@ericspod
Copy link
Member

Yes, I don't think much more than that is needed. A version of an existing unit test with bias_downsample=False would be enough to show that the network is still correct.
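One lightweight way to build such a variant is to reuse each existing parameterized case with only bias_downsample=False added, since the bias does not change any output shape. The case tuples below are hypothetical placeholders (constructor kwargs, input shape, expected output shape), not the actual contents of test_resnet.py.

```python
# Hypothetical existing cases: (constructor kwargs, input shape, expected output shape)
TEST_CASES = [
    ({"spatial_dims": 2, "n_input_channels": 2, "num_classes": 3}, (1, 2, 32, 64), (1, 3)),
    ({"spatial_dims": 3, "n_input_channels": 1, "num_classes": 2}, (1, 1, 16, 16, 16), (1, 2)),
]


def with_bias_downsample_false(cases):
    """Copy each case with bias_downsample=False; the expected shapes stay unchanged."""
    return [({**kwargs, "bias_downsample": False}, in_shape, out_shape) for kwargs, in_shape, out_shape in cases]


ALL_CASES = TEST_CASES + with_bias_downsample_false(TEST_CASES)
print(len(ALL_CASES))  # 4
```

The dict unpacking creates new kwargs dicts, so the original cases keep testing the default (bias=True) behaviour.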

@thatgeeman
Contributor Author

thatgeeman commented Nov 14, 2022

Just added the test case to test_resnet as discussed (see the commit above).

Contributor

@wyli wyli left a comment


Thanks, it looks good to me.

@wyli
Contributor

wyli commented Nov 14, 2022

/build

@wyli wyli enabled auto-merge (squash) November 14, 2022 18:50
@wyli
Contributor

wyli commented Nov 15, 2022

/build

@wyli
Contributor

wyli commented Nov 15, 2022

/build

@wyli wyli merged commit 5e6f105 into Project-MONAI:dev Nov 15, 2022
@thatgeeman thatgeeman deleted the 5465-resnet-downsample branch November 15, 2022 11:52
drbeh pushed a commit to drbeh/MONAI that referenced this pull request Nov 23, 2022
…#5477)

Signed-off-by: Geevarghese George <[email protected]>

Signed-off-by: Geevarghese George <[email protected]>
Signed-off-by: Behrooz <[email protected]>
@acerdur

acerdur commented Aug 1, 2023

Currently, the bias_downsample=False argument conflicts with pretraining, because it is hard-coded as bias_downsample=not pretrained in the call to the ResNet constructor:

```python
model: ResNet = ResNet(block, layers, block_inplanes, bias_downsample=not pretrained, **kwargs)
if pretrained:
    # Author of paper zipped the state_dict on googledrive,
    # so would need to download, unzip and read (2.8gb file for a ~150mb state dict).
    # Would like to load dict from url but need somewhere to save the state dicts.
    raise NotImplementedError(
        "Currently not implemented. You need to manually download weights provided by the paper's author"
        " and load then to the model with `state_dict`. See https://github.com/Tencent/MedicalNet"
    )
return model
```

When manually loading MedicalNet weights, the downsample bias terms raise errors because they are not present in the loaded weights. It is also not possible to turn off bias_downsample by setting pretrained=True, because that raises NotImplementedError.
So, could you please remove the hard coding from the model constructor call in the source code?
Thanks
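The hard coding also blocks the obvious workaround of passing bias_downsample=False directly: a factory that sets a keyword explicitly and also forwards **kwargs makes Python reject the duplicate keyword with a TypeError. The sketch below mimics that call pattern with hypothetical, torch-free stand-ins (`build_resnet`, `resnet_factory`); the `setdefault` variant is only one possible way to let an explicit caller value take precedence, not MONAI's code.

```python
def build_resnet(bias_downsample=True, **kwargs):
    """Hypothetical stand-in for the ResNet constructor: just records its arguments."""
    return {"bias_downsample": bias_downsample, **kwargs}


def resnet_factory(pretrained=False, **kwargs):
    # Same call shape as the excerpt: the keyword is set explicitly AND kwargs are forwarded
    return build_resnet(bias_downsample=not pretrained, **kwargs)


def resnet_factory_overridable(pretrained=False, **kwargs):
    # Fall back to `not pretrained` only when the caller did not pass bias_downsample
    kwargs.setdefault("bias_downsample", not pretrained)
    return build_resnet(**kwargs)


try:
    resnet_factory(bias_downsample=False)
except TypeError as err:
    print(err)  # reports multiple values for keyword argument 'bias_downsample'

print(resnet_factory_overridable(bias_downsample=False))  # {'bias_downsample': False}
```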

@thatgeeman
Contributor Author

thatgeeman commented Aug 7, 2023

Hi @acerdur @wyli
Since there were no error logs provided in the above comments, I'm assuming the issue comes only from the shortcut_type used.

Details:
The sole purpose of passing bias_downsample=False is to match the MedNet and official PyTorch (torchvision) implementations of ResNet, which set bias=False in the downsampling layer. For the downsampling layer, there are two variants in MONAI:
```python
shortcut_type='B'  # uses a 1x1 conv (followed by a norm layer) as the downsampling layer
shortcut_type='A'  # uses a 1x1 avg pool with zero-padded channels (no learnable parameters)
```
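A type-A shortcut has no learnable parameters, which is why it has no bias that could mismatch a checkpoint: an average pool with kernel size 1 and stride s simply keeps every s-th position, and the missing output channels are filled with zeros. This sketch shows the idea on nested Python lists with a single spatial axis; `shortcut_a` is a hypothetical helper, not MONAI's implementation, which works on 2D/3D tensors.

```python
def shortcut_a(x, planes, stride):
    """Parameter-free downsampling of x (channels x length) to `planes` channels."""
    # An average pool with kernel_size=1 averages a single element, so it reduces to strided slicing
    pooled = [channel[::stride] for channel in x]
    length = len(pooled[0])
    # Zero-pad the channel dimension up to the block's output width
    padding = [[0.0] * length for _ in range(planes - len(pooled))]
    return pooled + padding


x = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]  # 2 channels, length 4
print(shortcut_a(x, planes=4, stride=2))
# [[1.0, 3.0], [5.0, 7.0], [0.0, 0.0], [0.0, 0.0]]
```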

AFAIU the error you are facing comes from the shortcut layer. To correctly load the MedNet pretrained weights, you should initialize the model with the matching architecture, i.e. shortcut_type='A' for MedNet:

```python
from collections import OrderedDict

import torch
from monai.networks import nets

device = torch.device('cpu')  # or 'cuda' if a GPU is available

# MONAI ResNet18 with the MedNet-style (type A) shortcut
net = nets.resnet18(pretrained=False, spatial_dims=3, n_input_channels=1, num_classes=2, shortcut_type='A')
wt_path = 'resnet_18.pth'  # path to weights from Google Drive of Tencent
pretrained_weights = torch.load(f=wt_path, map_location=device)

# match the keys: strip the 'module.' prefix left by DataParallel
weights = OrderedDict()
for k, v in pretrained_weights['state_dict'].items():
    weights[k.replace('module.', '')] = v

net.load_state_dict(weights, strict=False)  # _IncompatibleKeys(missing_keys=['fc.weight', 'fc.bias'], unexpected_keys=[])
```

The only incompatible keys belong to the last linear layer (fc.weight and fc.bias). This is expected for fine-tuning, since the MedNet weights neither provide this classification head nor allow it to be inferred.
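Because strict=False silently ignores every mismatch (including stray downsample biases), it can be worth asserting that the only mismatches are the expected classification head. A minimal sketch, assuming the head's keys start with "fc." as in the output shown above; `check_only_head_mismatch` is a hypothetical helper that takes the missing_keys and unexpected_keys lists of the result returned by load_state_dict.

```python
def check_only_head_mismatch(missing_keys, unexpected_keys, head_prefix="fc."):
    """Raise if anything other than the classification head failed to load."""
    other_missing = [k for k in missing_keys if not k.startswith(head_prefix)]
    if other_missing or unexpected_keys:
        raise RuntimeError(f"missing: {other_missing}, unexpected: {list(unexpected_keys)}")


# With PyTorch, the lists come from the load result, e.g.:
#   result = net.load_state_dict(weights, strict=False)
#   check_only_head_mismatch(result.missing_keys, result.unexpected_keys)
check_only_head_mismatch(["fc.weight", "fc.bias"], [])  # passes: only the head is missing
```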

Related: #6811

4 participants