Skip to content

Commit

Permalink
Avoid breaking change in creating BundleWorkflow (#6950)
Browse files Browse the repository at this point in the history
Fixes # .

### Description
Avoid breaking changes introduced by
#6835
- when creating `BundleWorkflow`
- when using `load` API, add `return_state_dict` when `model` and
`net_name` are both `None`.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: KumoLiu <[email protected]>
Signed-off-by: YunLiu <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
  • Loading branch information
KumoLiu and ericspod authored Sep 7, 2023
1 parent 8ccde11 commit 6f13b8d
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 27 deletions.
77 changes: 55 additions & 22 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from monai.apps.mmars.mmars import _get_all_ngc_models
from monai.apps.utils import _basename, download_url, extractall, get_logger
from monai.bundle.config_item import ConfigComponent
from monai.bundle.config_parser import ConfigParser
from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA
from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow
Expand Down Expand Up @@ -247,7 +248,7 @@ def _process_bundle_dir(bundle_dir: PathLike | None = None) -> Path:
return Path(bundle_dir)


@deprecated_arg_default("source", "github", "monaihosting", since="1.3", replaced="1.4")
@deprecated_arg_default("source", "github", "monaihosting", since="1.3", replaced="1.5")
def download(
name: str | None = None,
version: str | None = None,
Expand Down Expand Up @@ -375,8 +376,9 @@ def download(
)


@deprecated_arg("net_name", since="1.3", removed="1.4", msg_suffix="please use ``model`` instead.")
@deprecated_arg("net_kwargs", since="1.3", removed="1.3", msg_suffix="please use ``model`` instead.")
@deprecated_arg("net_name", since="1.3", removed="1.5", msg_suffix="please use ``model`` instead.")
@deprecated_arg("net_kwargs", since="1.3", removed="1.5", msg_suffix="please use ``model`` instead.")
@deprecated_arg("return_state_dict", since="1.3", removed="1.5")
def load(
name: str,
model: torch.nn.Module | None = None,
Expand All @@ -395,8 +397,10 @@ def load(
workflow_name: str | BundleWorkflow | None = None,
args_file: str | None = None,
copy_model_args: dict | None = None,
return_state_dict: bool = True,
net_override: dict | None = None,
net_name: str | None = None,
**net_override: Any,
**net_kwargs: Any,
) -> object | tuple[torch.nn.Module, dict, dict] | Any:
"""
Load model weights or TorchScript module of a bundle.
Expand Down Expand Up @@ -441,7 +445,12 @@ def load(
workflow_name: specified bundle workflow name, should be a string or class, default to "ConfigWorkflow".
args_file: a JSON or YAML file to provide default values for all the args in "download" function.
copy_model_args: other arguments for the `monai.networks.copy_model_state` function.
net_override: id-value pairs to override the parameters in the network of the bundle.
return_state_dict: whether to return state dict, if True, return state_dict, else a corresponding network
from `_workflow.network_def` will be instantiated and load the achieved weights.
net_override: id-value pairs to override the parameters in the network of the bundle, default to `None`.
net_name: if not `None`, a corresponding network will be instantiated and load the achieved weights.
This argument only works when loading weights.
net_kwargs: other arguments that are used to instantiate the network class defined by `net_name`.
Returns:
1. If `load_ts_module` is `False` and `model` is `None`,
Expand All @@ -452,9 +461,15 @@ def load(
3. If `load_ts_module` is `True`, return a triple that include a TorchScript module,
the corresponding metadata dict, and extra files dict.
please check `monai.data.load_net_with_metadata` for more details.
4. If `return_state_dict` is True, return model weights, only used for compatibility
when `model` and `net_name` are all `None`.
"""
if return_state_dict and (model is not None or net_name is not None):
warnings.warn("Incompatible values: model and net_name are all specified, return state dict instead.")

bundle_dir_ = _process_bundle_dir(bundle_dir)
net_override = {} if net_override is None else net_override
copy_model_args = {} if copy_model_args is None else copy_model_args

if device is None:
Expand All @@ -466,7 +481,7 @@ def load(
if remove_prefix:
name = _remove_ngc_prefix(name, prefix=remove_prefix)
full_path = os.path.join(bundle_dir_, name, model_file)
if not os.path.exists(full_path) or model is None:
if not os.path.exists(full_path):
download(
name=name,
version=version,
Expand All @@ -477,34 +492,52 @@ def load(
progress=progress,
args_file=args_file,
)
train_config_file = bundle_dir_ / name / "configs" / f"{workflow_type}.json"
if train_config_file.is_file():
_net_override = {f"network_def#{key}": value for key, value in net_override.items()}
_workflow = create_workflow(
workflow_name=workflow_name,
args_file=args_file,
config_file=str(train_config_file),
workflow_type=workflow_type,
**_net_override,
)
else:
_workflow = None

# loading with `torch.jit.load`
if load_ts_module is True:
return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files)
# loading with `torch.load`
model_dict = torch.load(full_path, map_location=torch.device(device))

if not isinstance(model_dict, Mapping):
warnings.warn(f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.")
model_dict = get_state_dict(model_dict)

if model is None and _workflow is None:
if return_state_dict:
return model_dict
model = _workflow.network_def if model is None else model
model.to(device)

copy_model_state(dst=model, src=model_dict if key_in_ckpt is None else model_dict[key_in_ckpt], **copy_model_args)
_workflow = None
if model is None and net_name is None:
bundle_config_file = bundle_dir_ / name / "configs" / f"{workflow_type}.json"
if bundle_config_file.is_file():
_net_override = {f"network_def#{key}": value for key, value in net_override.items()}
_workflow = create_workflow(
workflow_name=workflow_name,
args_file=args_file,
config_file=str(bundle_config_file),
workflow_type=workflow_type,
**_net_override,
)
else:
warnings.warn(f"Cannot find the config file: {bundle_config_file}, return state dict instead.")
return model_dict
if _workflow is not None:
if not hasattr(_workflow, "network_def"):
warnings.warn("No available network definition in the bundle, return state dict instead.")
return model_dict
else:
model = _workflow.network_def
elif net_name is not None:
net_kwargs["_target_"] = net_name
configer = ConfigComponent(config=net_kwargs)
model = configer.instantiate() # type: ignore

model.to(device) # type: ignore

copy_model_state(
dst=model, src=model_dict if key_in_ckpt is None else model_dict[key_in_ckpt], **copy_model_args # type: ignore
)

return model


Expand Down
13 changes: 12 additions & 1 deletion monai/bundle/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ class BundleWorkflow(ABC):
or "infer", "inference", "eval", "evaluation" for a inference workflow,
other unsupported string will raise a ValueError.
default to `None` for common workflow.
workflow: specifies the workflow type: "train" or "training" for a training workflow,
or "infer", "inference", "eval", "evaluation" for a inference workflow,
other unsupported string will raise a ValueError.
default to `None` for common workflow.
"""

Expand All @@ -56,7 +60,8 @@ class BundleWorkflow(ABC):
new_name="workflow_type",
msg_suffix="please use `workflow_type` instead.",
)
def __init__(self, workflow_type: str | None = None):
def __init__(self, workflow_type: str | None = None, workflow: str | None = None):
workflow_type = workflow if workflow is not None else workflow_type
if workflow_type is None:
self.properties = copy(MetaProperties)
self.workflow_type = None
Expand Down Expand Up @@ -198,6 +203,10 @@ class ConfigWorkflow(BundleWorkflow):
or "infer", "inference", "eval", "evaluation" for a inference workflow,
other unsupported string will raise a ValueError.
default to `None` for common workflow.
workflow: specifies the workflow type: "train" or "training" for a training workflow,
or "infer", "inference", "eval", "evaluation" for a inference workflow,
other unsupported string will raise a ValueError.
default to `None` for common workflow.
override: id-value pairs to override or add the corresponding config content.
e.g. ``--net#input_chns 42``, ``--net %/data/other.json#net_arg``
Expand All @@ -221,8 +230,10 @@ def __init__(
final_id: str = "finalize",
tracking: str | dict | None = None,
workflow_type: str | None = None,
workflow: str | None = None,
**override: Any,
) -> None:
workflow_type = workflow if workflow is not None else workflow_type
super().__init__(workflow_type=workflow_type)
if config_file is not None:
_config_files = ensure_tuple(config_file)
Expand Down
7 changes: 6 additions & 1 deletion tests/ngc_bundle_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,12 @@ def test_ngc_download_bundle(self, bundle_name, version, remove_prefix, download
self.assertTrue(check_hash(filepath=full_file_path, val=hash_val))

model = load(
name=bundle_name, source="ngc", version=version, bundle_dir=tempdir, remove_prefix=remove_prefix
name=bundle_name,
source="ngc",
version=version,
bundle_dir=tempdir,
remove_prefix=remove_prefix,
return_state_dict=False,
)
assert_allclose(
model.state_dict()[TESTCASE_WEIGHTS["key"]],
Expand Down
31 changes: 28 additions & 3 deletions tests/test_bundle_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file)
source="github",
progress=False,
device=device,
return_state_dict=True,
)

# prepare network
Expand Down Expand Up @@ -174,21 +175,44 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file)
bundle_dir=tempdir,
progress=False,
device=device,
net_name=model_name,
source="github",
return_state_dict=False,
)
model_2.eval()
output_2 = model_2.forward(input_tensor)
assert_allclose(output_2, expected_output, atol=1e-4, rtol=1e-4, type_test=False)

# test compatibility with return_state_dict=True.
model_3 = load(
name=bundle_name,
model_file=model_file,
bundle_dir=tempdir,
progress=False,
device=device,
net_name=model_name,
source="github",
return_state_dict=False,
**net_args,
)
model_3.eval()
output_3 = model_3.forward(input_tensor)
assert_allclose(output_3, expected_output, atol=1e-4, rtol=1e-4, type_test=False)

@parameterized.expand([TEST_CASE_7])
@skip_if_quick
def test_load_weights_with_net_override(self, bundle_name, device, net_override):
with skip_if_downloading_fails():
# download bundle, and load weights from the downloaded path
with tempfile.TemporaryDirectory() as tempdir:
# load weights
model = load(name=bundle_name, bundle_dir=tempdir, source="monaihosting", progress=False, device=device)
model = load(
name=bundle_name,
bundle_dir=tempdir,
source="monaihosting",
progress=False,
device=device,
return_state_dict=False,
)

# prepare data and test
input_tensor = torch.rand(1, 1, 96, 96, 96).to(device)
Expand All @@ -209,7 +233,8 @@ def test_load_weights_with_net_override(self, bundle_name, device, net_override)
source="monaihosting",
progress=False,
device=device,
**net_override,
return_state_dict=False,
net_override=net_override,
)

# prepare data and test
Expand Down

0 comments on commit 6f13b8d

Please sign in to comment.