Skip to content

Commit

Permalink
7047 simplify resnet pretrained (#7095)
Browse files Browse the repository at this point in the history
Fixes #7047

### Description

Resnet did not support `True` value (not implemented ) for its
pretrained flag.
2 implemented behavior: 
- When pretrained is True, download weights from
https://huggingface.co/TencentMedicalNet
- When pretrained is a string, loads weights from the path defined by
the string

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: vgrau98 <[email protected]>
  • Loading branch information
vgrau98 authored Oct 18, 2023
1 parent 2c9f44c commit 9f1168f
Show file tree
Hide file tree
Showing 8 changed files with 229 additions and 13 deletions.
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@ opencv-python-headless
onnx>=1.13.0
onnxruntime; python_version <= '3.10'
zarr
huggingface_hub
2 changes: 2 additions & 0 deletions monai/networks/nets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@
ResNet,
ResNetBlock,
ResNetBottleneck,
get_medicalnet_pretrained_resnet_args,
get_pretrained_resnet_medicalnet,
resnet10,
resnet18,
resnet34,
Expand Down
134 changes: 123 additions & 11 deletions monai/networks/nets/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@

from __future__ import annotations

import logging
import re
from collections.abc import Callable
from functools import partial
from pathlib import Path
from typing import Any

import torch
Expand All @@ -21,7 +24,13 @@
from monai.networks.layers.factories import Conv, Norm, Pool
from monai.networks.layers.utils import get_pool_layer
from monai.utils import ensure_tuple_rep
from monai.utils.module import look_up_option
from monai.utils.module import look_up_option, optional_import

hf_hub_download, _ = optional_import("huggingface_hub", name="hf_hub_download")
EntryNotFoundError, _ = optional_import("huggingface_hub.utils._errors", name="EntryNotFoundError")

MEDICALNET_HUGGINGFACE_REPO_BASENAME = "TencentMedicalNet/MedicalNet-Resnet"
MEDICALNET_HUGGINGFACE_FILES_BASENAME = "resnet_"

__all__ = [
"ResNet",
Expand All @@ -36,6 +45,8 @@
"resnet200",
]

logger = logging.getLogger(__name__)


def get_inplanes():
return [64, 128, 256, 512]
Expand Down Expand Up @@ -329,21 +340,54 @@ def _resnet(
block: type[ResNetBlock | ResNetBottleneck],
layers: list[int],
block_inplanes: list[int],
pretrained: bool,
pretrained: bool | str,
progress: bool,
**kwargs: Any,
) -> ResNet:
model: ResNet = ResNet(block, layers, block_inplanes, **kwargs)
if pretrained:
# Author of paper zipped the state_dict on googledrive,
# so would need to download, unzip and read (2.8gb file for a ~150mb state dict).
# Would like to load dict from url but need somewhere to save the state dicts.
raise NotImplementedError(
"Currently not implemented. You need to manually download weights provided by the paper's author"
" and load then to the model with `state_dict`. See https://github.com/Tencent/MedicalNet"
"Please ensure you pass the appropriate `shortcut_type` and `bias_downsample` args. as specified"
"here: https://github.com/Tencent/MedicalNet/tree/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b#update20190730"
)
device = "cuda" if torch.cuda.is_available() else "cpu"
if isinstance(pretrained, str):
if Path(pretrained).exists():
logger.info(f"Loading weights from {pretrained}...")
model_state_dict = torch.load(pretrained, map_location=device)
else:
# Throw error
raise FileNotFoundError("The pretrained checkpoint file is not found")
else:
# Also check bias downsample and shortcut.
if kwargs.get("spatial_dims", 3) == 3:
if kwargs.get("n_input_channels", 3) == 1 and kwargs.get("feed_forward", True) is False:
search_res = re.search(r"resnet(\d+)", arch)
if search_res:
resnet_depth = int(search_res.group(1))
else:
raise ValueError("arch argument should be as 'resnet_{resnet_depth}")

# Check model bias_downsample and shortcut_type
bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth)
if shortcut_type == kwargs.get("shortcut_type", "B") and (
bool(bias_downsample) == kwargs.get("bias_downsample", False) if bias_downsample != -1 else True
):
# Download the MedicalNet pretrained model
model_state_dict = get_pretrained_resnet_medicalnet(
resnet_depth, device=device, datasets23=True
)
else:
raise NotImplementedError(
f"Please set shortcut_type to {shortcut_type} and bias_downsample to"
f"{bool(bias_downsample) if bias_downsample!=-1 else 'True or False'}"
f"when using pretrained MedicalNet resnet{resnet_depth}"
)
else:
raise NotImplementedError(
"Please set n_input_channels to 1"
"and feed_forward to False in order to use MedicalNet pretrained weights"
)
else:
raise NotImplementedError("MedicalNet pretrained weights are only avalaible for 3D models")
model_state_dict = {key.replace("module.", ""): value for key, value in model_state_dict.items()}
model.load_state_dict(model_state_dict, strict=True)
return model


Expand Down Expand Up @@ -429,3 +473,71 @@ def resnet200(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet("resnet200", ResNetBottleneck, [3, 24, 36, 3], get_inplanes(), pretrained, progress, **kwargs)


def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", datasets23: bool = True):
"""
Donwlad resnet pretrained weights from https://huggingface.co/TencentMedicalNet
Args:
resnet_depth: depth of the pretrained model. Supported values are 10, 18, 34, 50, 101, 152 and 200
device: device on which the returned state dict will be loaded. "cpu" or "cuda" for example.
datasets23: if True, get the weights trained on more datasets (23).
Not all depths are available. If not, standard weights are returned.
Returns:
Pretrained state dict
Raises:
huggingface_hub.utils._errors.EntryNotFoundError: if pretrained weights are not found on huggingface hub
NotImplementedError: if `resnet_depth` is not supported
"""

medicalnet_huggingface_repo_basename = "TencentMedicalNet/MedicalNet-Resnet"
medicalnet_huggingface_files_basename = "resnet_"
supported_depth = [10, 18, 34, 50, 101, 152, 200]

logger.info(
f"Loading MedicalNet pretrained model from https://huggingface.co/{medicalnet_huggingface_repo_basename}{resnet_depth}"
)

if resnet_depth in supported_depth:
filename = (
f"{medicalnet_huggingface_files_basename}{resnet_depth}.pth"
if not datasets23
else f"{medicalnet_huggingface_files_basename}{resnet_depth}_23dataset.pth"
)
try:
pretrained_path = hf_hub_download(
repo_id=f"{medicalnet_huggingface_repo_basename}{resnet_depth}", filename=filename
)
except Exception:
if datasets23:
logger.info(f"{filename} not available for resnet{resnet_depth}")
filename = f"{medicalnet_huggingface_files_basename}{resnet_depth}.pth"
logger.info(f"Trying with {filename}")
pretrained_path = hf_hub_download(
repo_id=f"{medicalnet_huggingface_repo_basename}{resnet_depth}", filename=filename
)
else:
raise EntryNotFoundError(
f"{filename} not found on {medicalnet_huggingface_repo_basename}{resnet_depth}"
) from None
checkpoint = torch.load(pretrained_path, map_location=torch.device(device))
else:
raise NotImplementedError("Supported resnet_depth are: [10, 18, 34, 50, 101, 152, 200]")
logger.info(f"{filename} downloaded")
return checkpoint.get("state_dict")


def get_medicalnet_pretrained_resnet_args(resnet_depth: int):
"""
Return correct shortcut_type and bias_downsample
for pretrained MedicalNet weights according to resnet depth
"""
# After testing
# False: 10, 50, 101, 152, 200
# Any: 18, 34
bias_downsample = -1 if resnet_depth in [18, 34] else 0 # 18, 10, 34
shortcut_type = "A" if resnet_depth in [18, 34] else "B"
return bias_downsample, shortcut_type
1 change: 1 addition & 0 deletions monai/networks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
onnxreference, _ = optional_import("onnx.reference")
onnxruntime, _ = optional_import("onnxruntime")


__all__ = [
"one_hot",
"predict_segmentation",
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,4 @@ filelock!=3.12.0 # https://github.com/microsoft/nni/issues/5523
zarr
lpips==0.1.4
nvidia-ml-py
huggingface_hub
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ all =
zarr
lpips==0.1.4
nvidia-ml-py
huggingface_hub
nibabel =
nibabel
ninja =
Expand Down
85 changes: 83 additions & 2 deletions tests/test_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,32 @@

from __future__ import annotations

import copy
import os
import re
import sys
import unittest
from typing import TYPE_CHECKING

import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.nets import ResNet, resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200
from monai.networks.nets import (
ResNet,
get_medicalnet_pretrained_resnet_args,
get_pretrained_resnet_medicalnet,
resnet10,
resnet18,
resnet34,
resnet50,
resnet101,
resnet152,
resnet200,
)
from monai.networks.nets.resnet import ResNetBlock
from monai.utils import optional_import
from tests.utils import test_script_save
from tests.utils import equal_state_dict, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick, test_script_save

if TYPE_CHECKING:
import torchvision
Expand All @@ -30,6 +45,10 @@
else:
torchvision, has_torchvision = optional_import("torchvision")

has_hf_modules = "huggingface_hub" in sys.modules and "huggingface_hub.utils._errors" in sys.modules

# from torchvision.models import ResNet50_Weights, resnet50

device = "cuda" if torch.cuda.is_available() else "cpu"

TEST_CASE_1 = [ # 3D, batch 3, 2 input channel
Expand Down Expand Up @@ -159,9 +178,11 @@
]

TEST_CASES = []
PRETRAINED_TEST_CASES = []
for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]:
for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]:
TEST_CASES.append([model, *case])
PRETRAINED_TEST_CASES.append([model, *case])
for case in [TEST_CASE_5, TEST_CASE_5_A, TEST_CASE_6, TEST_CASE_7]:
TEST_CASES.append([ResNet, *case])

Expand All @@ -171,6 +192,16 @@


class TestResNet(unittest.TestCase):
def setUp(self):
self.tmp_ckpt_filename = os.path.join("tests", "monai_unittest_tmp_ckpt.pth")

def tearDown(self):
if os.path.exists(self.tmp_ckpt_filename):
try:
os.remove(self.tmp_ckpt_filename)
except BaseException:
pass

@parameterized.expand(TEST_CASES)
def test_resnet_shape(self, model, input_param, input_shape, expected_shape):
net = model(**input_param).to(device)
Expand All @@ -181,6 +212,56 @@ def test_resnet_shape(self, model, input_param, input_shape, expected_shape):
else:
self.assertTrue(result.shape in expected_shape)

@parameterized.expand(PRETRAINED_TEST_CASES)
@skip_if_quick
@skip_if_no_cuda
def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape):
net = model(**input_param).to(device)
# Save ckpt
torch.save(net.state_dict(), self.tmp_ckpt_filename)

cp_input_param = copy.copy(input_param)
# Custom pretrained weights
cp_input_param["pretrained"] = self.tmp_ckpt_filename
pretrained_net = model(**cp_input_param)
self.assertTrue(equal_state_dict(net.state_dict(), pretrained_net.state_dict()))

if has_hf_modules:
# True flag
cp_input_param["pretrained"] = True
resnet_depth = int(re.search(r"resnet(\d+)", model.__name__).group(1))

bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth)

# With orig. test cases
if (
input_param.get("spatial_dims", 3) == 3
and input_param.get("n_input_channels", 3) == 1
and input_param.get("feed_forward", True) is False
and input_param.get("shortcut_type", "B") == shortcut_type
and (
input_param.get("bias_downsample", True) == bool(bias_downsample) if bias_downsample != -1 else True
)
):
model(**cp_input_param)
else:
with self.assertRaises(NotImplementedError):
model(**cp_input_param)

# forcing MedicalNet pretrained download for 3D tests cases
cp_input_param["n_input_channels"] = 1
cp_input_param["feed_forward"] = False
cp_input_param["shortcut_type"] = shortcut_type
cp_input_param["bias_downsample"] = bool(bias_downsample) if bias_downsample != -1 else True
if cp_input_param.get("spatial_dims", 3) == 3:
with skip_if_downloading_fails():
pretrained_net = model(**cp_input_param).to(device)
medicalnet_state_dict = get_pretrained_resnet_medicalnet(resnet_depth, device=device)
medicalnet_state_dict = {
key.replace("module.", ""): value for key, value in medicalnet_state_dict.items()
}
self.assertTrue(equal_state_dict(pretrained_net.state_dict(), medicalnet_state_dict))

@parameterized.expand(TEST_SCRIPT_CASES)
def test_script(self, model, input_param, input_shape, expected_shape):
net = model(**input_param)
Expand Down
17 changes: 17 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,23 @@ def command_line_tests(cmd, copy_env=True):
raise RuntimeError(f"subprocess call error {e.returncode}: {errors}, {output}") from e


def equal_state_dict(st_1, st_2):
"""
Compare 2 torch state dicts.
"""
r = True
for key_st_1, val_st_1 in st_1.items():
if key_st_1 in st_2:
val_st_2 = st_2.get(key_st_1)
if not torch.equal(val_st_1, val_st_2):
r = False
break
else:
r = False
break
return r


TEST_TORCH_TENSORS: tuple = (torch.as_tensor,)
if torch.cuda.is_available():
gpu_tensor: Callable = partial(torch.as_tensor, device="cuda")
Expand Down

0 comments on commit 9f1168f

Please sign in to comment.