ResNet implementation: set bias=False for downsample-B #5477
Signed-off-by: Geevarghese George <[email protected]>
We hadn't implemented downloading pre-trained weights, so that part isn't affected by this, but other saved instances of this network (like here) won't load correctly. To maintain backwards compatibility we should add a …
Yes, that makes sense; I'll make the additions in the coming days. Making this PR a draft for now.
Made the requisite changes to accept an additional kwarg, as discussed. How does this look, @ericspod?
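A minimal sketch of why a keyword argument that defaults to the old behaviour keeps previously saved networks loadable: whether the downsample convolution has a bias changes the state-dict keys. The flag name `bias_downsample`, the helper `downsample_keys`, and the assumed `nn.Sequential(conv, batch_norm)` layout (conv at index 0, norm at index 1) are illustrative assumptions, not taken from this thread.

```python
from typing import List

# Assumed layout: one downsample-B shortcut is nn.Sequential(conv, batch_norm),
# so the conv's entries sit under ".downsample.0" and the norm's under ".downsample.1".
BN_ENTRIES = ("weight", "bias", "running_mean", "running_var", "num_batches_tracked")


def downsample_keys(prefix: str, bias_downsample: bool = True) -> List[str]:
    """Return the state-dict keys of the downsample block under `prefix`.

    `bias_downsample` is an illustrative flag name; defaulting it to True
    reproduces the legacy layout, where the conv used PyTorch's default bias=True.
    """
    keys = [f"{prefix}.downsample.0.weight"]
    if bias_downsample:
        keys.append(f"{prefix}.downsample.0.bias")
    # Batch-norm affine parameters plus its running-statistics buffers.
    keys += [f"{prefix}.downsample.1.{name}" for name in BN_ENTRIES]
    return keys


# Keys that a checkpoint saved before the fix would contain for this block.
old_checkpoint = set(downsample_keys("layer2.0", bias_downsample=True))

# The default keeps the old layout, so strict loading of old checkpoints still works ...
print(old_checkpoint == set(downsample_keys("layer2.0")))  # True
# ... while opting out drops exactly one key per downsample block.
print(old_checkpoint - set(downsample_keys("layer2.0", bias_downsample=False)))
# {'layer2.0.downsample.0.bias'}
```

Defaulting to the legacy value means existing users see no change, and only callers who opt in get the bias-free layout.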
Looks good to me; we should add a test case to …
Agreed! What would the test case look like? Would it be more of a sanity check to see if the expected shapes are returned with …
Yes, I don't think much more is needed than that, if we had a version of an existing unit test with …
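One way to follow that suggestion, sketched under assumptions: reuse the existing cases and add copies with the new keyword toggled, keeping the expected output shapes unchanged, since removing a bias term should not change any shape. The case layout `[constructor kwargs, input shape, expected shape]`, the concrete values, the helper `with_flag`, and the keyword name `bias_downsample` are all hypothetical and need not match `test_resnet`.

```python
import copy

# Hypothetical existing cases: [constructor kwargs, input shape, expected output shape].
BASE_CASES = [
    [{"spatial_dims": 3, "n_input_channels": 2, "num_classes": 3}, (1, 2, 32, 64, 48), (1, 3)],
    [{"spatial_dims": 2, "n_input_channels": 2, "num_classes": 3}, (1, 2, 32, 64), (1, 3)],
]


def with_flag(cases, flag_name, value):
    """Return deep copies of `cases` with `flag_name=value` added to the kwargs.

    The expected shape is kept as-is: toggling a bias term must not change shapes.
    """
    expanded = []
    for params, in_shape, out_shape in cases:
        new_params = copy.deepcopy(params)  # leave the original cases untouched
        new_params[flag_name] = value
        expanded.append([new_params, in_shape, out_shape])
    return expanded


# 'bias_downsample' is an illustrative name for the new keyword argument.
TEST_CASES = BASE_CASES + with_flag(BASE_CASES, "bias_downsample", False)
print(len(TEST_CASES))  # 4
```

Each entry of `TEST_CASES` would then feed the same shape-checking test body that already exists.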
Just added to the `test_resnet` case as discussed.
Thanks, it looks good to me.
/build
…#5477)

Signed-off-by: Geevarghese George <[email protected]>

### Description

This is a simple fix following Project-MONAI#5465.

The downsampling layer is not expected to have a bias term. The previous implementation did not explicitly set `bias=False` and therefore fell back to the PyTorch `Conv3d`/`Conv2d` default of `bias=True`. With this change, the correct number of parameter tensors (62 for resnet18) is returned:

```python
from torchvision import models
from monai.networks import nets

d2net_torch = models.resnet18()
d2net_monai = nets.resnet18(spatial_dims=2)
d3net_monai = nets.resnet18(spatial_dims=3)

print(
    len(list(d2net_torch.parameters())),
    len(list(d2net_monai.parameters())),
    len(list(d3net_monai.parameters())),
)
# 62 62 62
# before: 62 65 65
```

Other, deeper 2D ResNet architectures are also comparable to the PyTorch implementation, so the `pretrained`/`weights` parameter could be allowed for these networks. Currently, passing `pretrained=True` raises a `NotImplementedError` even for 2D ResNets.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: Geevarghese George <[email protected]>
Signed-off-by: Behrooz <[email protected]>
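The 62 vs 65 counts quoted in the description can be reproduced by hand. A minimal pure-Python sketch, assuming the torchvision-style BasicBlock layout spelled out in the comments; `count_param_tensors` is a hypothetical helper, not a MONAI or torchvision function.

```python
def count_param_tensors(blocks_per_layer, bias_downsample):
    """Count parameter tensors, i.e. len(list(model.parameters())), for a BasicBlock ResNet.

    Assumed layout: bias-free convs, a BatchNorm (weight + bias) after every conv,
    and a conv + BatchNorm downsample in the first block of every layer but the first.
    """
    stem = 1 + 2  # conv1.weight, bn1.weight, bn1.bias
    per_block = 2 * (1 + 2)  # two (conv weight, BN weight, BN bias) triples
    downsample = 1 + 2 + (1 if bias_downsample else 0)  # conv weight, BN pair, optional conv bias
    head = 2  # fc.weight, fc.bias

    total = stem + head
    for layer_idx, n_blocks in enumerate(blocks_per_layer):
        total += n_blocks * per_block
        if layer_idx > 0:  # layers 2-4 change stride and channels, so they need a downsample
            total += downsample
    return total


RESNET18 = [2, 2, 2, 2]
print(count_param_tensors(RESNET18, bias_downsample=False))  # 62
print(count_param_tensors(RESNET18, bias_downsample=True))   # 65
```

The three extra tensors are exactly the three downsample conv biases in `layer2`, `layer3`, and `layer4`.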
Currently, the …

When manually loading MedicalNet weights, the downsample bias terms raise errors because they are not present in the loaded weights. It is also not possible to remove …
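The error described above comes from strict key matching: `torch.nn.Module.load_state_dict` with `strict=True` raises when the model expects keys that the checkpoint lacks (missing keys) or the checkpoint contains keys that the model lacks (unexpected keys). A pure-Python sketch of that classification, assuming a model whose downsample conv still has a bias and a checkpoint without it; the key names and the `diff_keys` helper are illustrative.

```python
from typing import Iterable, List, NamedTuple


class KeyDiff(NamedTuple):
    missing_keys: List[str]     # expected by the model, absent from the checkpoint
    unexpected_keys: List[str]  # present in the checkpoint, unknown to the model


def diff_keys(model_keys: Iterable[str], checkpoint_keys: Iterable[str]) -> KeyDiff:
    """Classify key mismatches the way strict state-dict loading reports them."""
    model_set, ckpt_set = set(model_keys), set(checkpoint_keys)
    return KeyDiff(sorted(model_set - ckpt_set), sorted(ckpt_set - model_set))


# Illustrative keys: the model's downsample conv has a bias, the checkpoint's does not.
model_keys = ["layer2.0.downsample.0.weight", "layer2.0.downsample.0.bias", "layer2.0.downsample.1.weight"]
checkpoint_keys = ["layer2.0.downsample.0.weight", "layer2.0.downsample.1.weight"]

result = diff_keys(model_keys, checkpoint_keys)
print(result.missing_keys)     # ['layer2.0.downsample.0.bias']
print(result.unexpected_keys)  # []
```

Passing `strict=False` suppresses the error, but the missing bias then keeps its random initialization, so building the model without that parameter is the cleaner fix.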
Hi @acerdur @wyli,

AFAIU, the error you are facing comes from the shortcut layer. To correctly load the pretrained MedNet weights, you should initialize the model with the matching architecture (note `shortcut_type='A'`):

```python
from collections import OrderedDict

import torch
from monai.networks import nets

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# MONAI ResNet18
net = nets.resnet18(pretrained=False, spatial_dims=3, n_input_channels=1, num_classes=2, shortcut_type='A')

wt_path = 'resnet_18.pth'  # path to weights from Google Drive of Tencent
pretrained_weights = torch.load(f=wt_path, map_location=device)

# match the keys: drop the 'module.' prefix from each checkpoint key
weights = OrderedDict()
for k, v in pretrained_weights['state_dict'].items():
    weights.update({k.replace('module.', ''): v})

net.load_state_dict(weights, strict=False)
# _IncompatibleKeys(missing_keys=['fc.weight', 'fc.bias'], unexpected_keys=[])
```

The only incompatible keys are those of the final linear layer. This is expected for fine-tuning, since those weights are neither provided in nor inferable from the MedNet weights.

Related: #6811
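The key-matching loop in the snippet above uses `k.replace('module.', '')`, which removes every occurrence of `module.` in a key, not just a leading one (the prefix typically comes from saving a model wrapped in `torch.nn.DataParallel`). A slightly stricter, pure-Python sketch that strips only a leading prefix and then checks what remains unmatched; `strip_prefix` is a hypothetical helper and the integer values stand in for tensors.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of `state_dict` with a leading `prefix` removed from each key.

    Unlike k.replace(prefix, ""), keys that merely contain the prefix text
    somewhere in the middle are left intact.
    """
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in state_dict.items()
    )


# Illustrative keys from a DataParallel-style checkpoint.
checkpoint = OrderedDict([("module.conv1.weight", 0), ("module.bn1.weight", 1), ("module.bn1.bias", 2)])
model_keys = {"conv1.weight", "bn1.weight", "bn1.bias", "fc.weight", "fc.bias"}

cleaned_keys = set(strip_prefix(checkpoint))
print(sorted(model_keys - cleaned_keys))  # ['fc.bias', 'fc.weight']  (missing: the new classifier head)
print(sorted(cleaned_keys - model_keys))  # []  (nothing unexpected)
```

For MedNet-style checkpoints both approaches should give the same result; the prefix-only version just avoids surprises if a submodule name happens to contain `module.`.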