diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 9759a16e9b4..5cb66c8bb99 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -361,7 +361,7 @@ def _resnet( raise KeyError( "The checkpoint should contain the pretrained model state dict with the following key: 'state_dict'" ) - + model.load_state_dict(model_state_dict, strict=True) return model diff --git a/tests/test_resnet.py b/tests/test_resnet.py index 6151b13c453..f02da422c80 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -196,18 +196,18 @@ def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape "state_dict": net.state_dict() }, tmp_ckpt_filename) - + cp_input_param = copy.copy(input_param) cp_input_param["pretrained"] = tmp_ckpt_filename pretrained_net = model(**cp_input_param) assert str(net.state_dict()) == str(pretrained_net.state_dict()) - + with self.assertRaises(NotImplementedError): cp_input_param["pretrained"] = True - bool_pretrained_net = model(**cp_input_param) - + model(**cp_input_param) + os.remove(tmp_ckpt_filename) - + @parameterized.expand(TEST_SCRIPT_CASES) def test_script(self, model, input_param, input_shape, expected_shape): net = model(**input_param)