Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

7047 simplify resnet pretrained #7095

Merged
merged 39 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
0613c2f
Simplify resnet pretrained flag
vgrau98 Oct 4, 2023
1005fe8
add tests + typos
vgrau98 Oct 6, 2023
8b75095
add MedicalNet resnet 3D pretrained models support
vgrau98 Oct 6, 2023
00ec022
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 6, 2023
f5e09b1
add optional import
vgrau98 Oct 7, 2023
c7a827b
simplify user pretrained weights loading
vgrau98 Oct 7, 2023
fa60fad
Manage MedicalNet resnet model validation with pretrained flag
vgrau98 Oct 7, 2023
e9ba99d
update resnet tests
vgrau98 Oct 7, 2023
0ddfb04
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 7, 2023
06ed8b0
update resnet unit tests
vgrau98 Oct 8, 2023
955dcf1
fix incorrect optional import
vgrau98 Oct 8, 2023
ff7f6d3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2023
a707d1c
Merge branch 'dev' into 7047-simplify-resnet-pretrained
vgrau98 Oct 8, 2023
bb6830f
Line shortening
vgrau98 Oct 8, 2023
d21a022
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2023
5ac3627
Merge branch 'dev' into 7047-simplify-resnet-pretrained
vgrau98 Oct 9, 2023
3735733
update resnet tests and deployment files
vgrau98 Oct 9, 2023
8b6782a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 9, 2023
801edf6
Merge branch 'dev' into 7047-simplify-resnet-pretrained
wyli Oct 10, 2023
02a360a
[MONAI] code formatting
monai-bot Oct 10, 2023
1993f04
Update utils.py
vgrau98 Oct 10, 2023
e344771
Update utils.py
vgrau98 Oct 10, 2023
0ea682b
Merge branch 'dev' into 7047-simplify-resnet-pretrained
vgrau98 Oct 10, 2023
be516c0
Update resnet.py
vgrau98 Oct 10, 2023
3dd89de
Update utils.py
vgrau98 Oct 10, 2023
b74d48a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 10, 2023
46e7acc
Merge branch 'dev' into 7047-simplify-resnet-pretrained
vgrau98 Oct 11, 2023
6896216
fix lint error
vgrau98 Oct 16, 2023
b30ea73
minor refactos
vgrau98 Oct 16, 2023
9ee9e45
Merge branch 'dev' into 7047-simplify-resnet-pretrained
vgrau98 Oct 16, 2023
2945704
fix lint error
vgrau98 Oct 16, 2023
7a01bb5
fix typo
vgrau98 Oct 16, 2023
1a29bb2
Merge branch 'dev' into 7047-simplify-resnet-pretrained
vgrau98 Oct 16, 2023
8198203
fix mypy error
vgrau98 Oct 17, 2023
2930f97
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 17, 2023
7a0a2c3
fix lint error
vgrau98 Oct 17, 2023
5b279b6
Merge branch 'dev' into 7047-simplify-resnet-pretrained
wyli Oct 18, 2023
741970d
update unit test
wyli Oct 18, 2023
89eda2c
local torch.cuda check
wyli Oct 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 39 additions & 10 deletions monai/networks/nets/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
from collections.abc import Callable
from functools import partial
from typing import Any
from pathlib import Path
import logging
import re
from huggingface_hub import hf_hub_download
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved

import torch
import torch.nn as nn
Expand All @@ -23,6 +27,9 @@
from monai.utils import ensure_tuple_rep
from monai.utils.module import look_up_option

MEDICALNET_HUGGINGFACE_REPO_BASENAME = "TencentMedicalNet/MedicalNet-Resnet"
MEDICALNET_HUGGINGFACE_FILES_BASENAME = "resnet_"

__all__ = [
"ResNet",
"ResNetBlock",
Expand All @@ -36,6 +43,9 @@
"resnet200",
]

logger = logging.getLogger(__name__)

device = "cuda" if torch.cuda.is_available() else "cpu"

def get_inplanes():
return [64, 128, 256, 512]
Expand Down Expand Up @@ -329,21 +339,40 @@ def _resnet(
block: type[ResNetBlock | ResNetBottleneck],
layers: list[int],
block_inplanes: list[int],
pretrained: bool,
pretrained: bool | str,
progress: bool,
**kwargs: Any,
) -> ResNet:
model: ResNet = ResNet(block, layers, block_inplanes, **kwargs)
if pretrained:
# Author of paper zipped the state_dict on googledrive,
# so would need to download, unzip and read (2.8gb file for a ~150mb state dict).
# Would like to load dict from url but need somewhere to save the state dicts.
raise NotImplementedError(
"Currently not implemented. You need to manually download weights provided by the paper's author"
" and load then to the model with `state_dict`. See https://github.com/Tencent/MedicalNet"
"Please ensure you pass the appropriate `shortcut_type` and `bias_downsample` args. as specified"
"here: https://github.com/Tencent/MedicalNet/tree/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b#update20190730"
)
if isinstance(pretrained, str):
if Path(pretrained).exists():
logger.info(f"Loading weights from {pretrained}...")
checkpoint = torch.load(pretrained, map_location=device)
else:
### Throw error
raise FileNotFoundError("The pretrained checkpoint file is not found")
else:
resnet_depth = int(re.search(r"resnet(\d+)", arch).group(1))
# Download the MedicalNet pretrained model
logger.info(f"Loading MedicalNet pretrained model from https://huggingface.co/{MEDICALNET_HUGGINGFACE_REPO_BASENAME}{resnet_depth}")
pretrained_path = hf_hub_download(
repo_id=f"{MEDICALNET_HUGGINGFACE_REPO_BASENAME}{resnet_depth}",
filename=f"{MEDICALNET_HUGGINGFACE_FILES_BASENAME}{resnet_depth}.pth")

checkpoint = torch.load(pretrained_path, map_location=device)

if "state_dict" in checkpoint:
model_state_dict = checkpoint["state_dict"]
model_state_dict = {key.replace("module.", ""): value for key, value in model_state_dict.items()}
else:
### Throw error
raise KeyError(
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
"The checkpoint should contain the pretrained model state dict with the following key: 'state_dict'"
)

model.load_state_dict(model_state_dict, strict=True)

return model


Expand Down
27 changes: 27 additions & 0 deletions tests/test_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from __future__ import annotations

import unittest
import copy
import os
from typing import TYPE_CHECKING

import torch
Expand All @@ -30,6 +32,8 @@
else:
torchvision, has_torchvision = optional_import("torchvision")

# from torchvision.models import ResNet50_Weights, resnet50

device = "cuda" if torch.cuda.is_available() else "cpu"

TEST_CASE_1 = [ # 3D, batch 3, 2 input channel
Expand Down Expand Up @@ -159,9 +163,11 @@
]

TEST_CASES = []
PRETRAINED_TEST_CASES = []
for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]:
for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]:
TEST_CASES.append([model, *case])
PRETRAINED_TEST_CASES.append([model, *case])
for case in [TEST_CASE_5, TEST_CASE_5_A, TEST_CASE_6, TEST_CASE_7]:
TEST_CASES.append([ResNet, *case])

Expand All @@ -181,6 +187,27 @@ def test_resnet_shape(self, model, input_param, input_shape, expected_shape):
else:
self.assertTrue(result.shape in expected_shape)

@parameterized.expand(PRETRAINED_TEST_CASES)
def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape):
net = model(**input_param).to(device)
tmp_ckpt_filename = "monai_unittest_tmp_ckpt.pth"
# Save ckpt
torch.save({
"state_dict": net.state_dict()
},
tmp_ckpt_filename)

cp_input_param = copy.copy(input_param)
cp_input_param["pretrained"] = tmp_ckpt_filename
pretrained_net = model(**cp_input_param)
assert str(net.state_dict()) == str(pretrained_net.state_dict())

with self.assertRaises(NotImplementedError):
cp_input_param["pretrained"] = True
model(**cp_input_param)

os.remove(tmp_ckpt_filename)

@parameterized.expand(TEST_SCRIPT_CASES)
def test_script(self, model, input_param, input_shape, expected_shape):
net = model(**input_param)
Expand Down