Skip to content

Commit

Permalink
6387 update_kwargs for merging multiple configs (#7109)
Browse files Browse the repository at this point in the history
Fixes #6387
Fixes #5899

### Description
- add api for update_kwargs
- add support of merging multiple configs files and dictionaries
- remove warning message of directory in `runtests.sh`

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Wenqi Li <[email protected]>
  • Loading branch information
wyli authored Oct 10, 2023
1 parent fc1350a commit 8d730cd
Show file tree
Hide file tree
Showing 10 changed files with 52 additions and 32 deletions.
1 change: 1 addition & 0 deletions docs/source/bundle.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,4 @@ Model Bundle
.. autofunction:: verify_metadata
.. autofunction:: verify_net_in_out
.. autofunction:: init_bundle
.. autofunction:: update_kwargs
1 change: 1 addition & 0 deletions monai/bundle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
run,
run_workflow,
trt_export,
update_kwargs,
verify_metadata,
verify_net_in_out,
)
Expand Down
5 changes: 4 additions & 1 deletion monai/bundle/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,13 +412,16 @@ def load_config_files(cls, files: PathLike | Sequence[PathLike] | dict, **kwargs
Args:
files: path of target files to load, supported postfixes: `.json`, `.yml`, `.yaml`.
if providing a list of files, wil merge the content of them.
if providing a list of files, will merge the content of them.
if providing a string with comma separated file paths, will merge the content of them.
if providing a dictionary, return it directly.
kwargs: other arguments for ``json.load`` or ```yaml.safe_load``, depends on the file format.
"""
if isinstance(files, dict): # already a config dict
return files
parser = ConfigParser(config={})
if isinstance(files, str) and not Path(files).is_file() and "," in files:
files = files.split(",")
for i in ensure_tuple(files):
for k, v in (cls.load_config_file(i, **kwargs)).items():
parser[k] = v
Expand Down
39 changes: 26 additions & 13 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,33 +66,46 @@
PPRINT_CONFIG_N = 5


def _update_args(args: str | dict | None = None, ignore_none: bool = True, **kwargs: Any) -> dict:
def update_kwargs(args: str | dict | None = None, ignore_none: bool = True, **kwargs: Any) -> dict:
"""
Update the `args` with the input `kwargs`.
Update the `args` dictionary with the input `kwargs`.
For dict data, recursively update the content based on the keys.
Example::
from monai.bundle import update_kwargs
update_kwargs({'exist': 1}, exist=2, new_arg=3)
# return {'exist': 2, 'new_arg': 3}
Args:
args: source args to update.
args: source `args` dictionary (or a json/yaml filename to read as dictionary) to update.
ignore_none: whether to ignore input args with None value, default to `True`.
kwargs: destination args to update.
kwargs: key=value pairs to be merged into `args`.
"""
args_: dict = args if isinstance(args, dict) else {}
if isinstance(args, str):
# args are defined in a structured file
args_ = ConfigParser.load_config_file(args)
if isinstance(args, (tuple, list)) and all(isinstance(x, str) for x in args):
primary, overrides = args
args_ = update_kwargs(primary, ignore_none, **update_kwargs(overrides, ignore_none, **kwargs))
if not isinstance(args_, dict):
return args_
# recursively update the default args with new args
for k, v in kwargs.items():
print(k, v)
if ignore_none and v is None:
continue
if isinstance(v, dict) and isinstance(args_.get(k), dict):
args_[k] = _update_args(args_[k], ignore_none, **v)
args_[k] = update_kwargs(args_[k], ignore_none, **v)
else:
args_[k] = v
return args_


_update_args = update_kwargs # backward compatibility


def _pop_args(src: dict, *args: Any, **kwargs: Any) -> tuple:
"""
Pop args from the `src` dictionary based on specified keys in `args` and (key, default value) pairs in `kwargs`.
Expand Down Expand Up @@ -318,7 +331,7 @@ def download(
so that the command line inputs can be simplified.
"""
_args = _update_args(
_args = update_kwargs(
args=args_file,
name=name,
version=version,
Expand Down Expand Up @@ -834,7 +847,7 @@ def verify_metadata(
"""

_args = _update_args(
_args = update_kwargs(
args=args_file,
meta_file=meta_file,
filepath=filepath,
Expand Down Expand Up @@ -958,7 +971,7 @@ def verify_net_in_out(
"""

_args = _update_args(
_args = update_kwargs(
args=args_file,
net_id=net_id,
meta_file=meta_file,
Expand Down Expand Up @@ -1127,7 +1140,7 @@ def onnx_export(
e.g. ``--_meta#network_data_format#inputs#image#num_channels 3``.
"""
_args = _update_args(
_args = update_kwargs(
args=args_file,
net_id=net_id,
filepath=filepath,
Expand Down Expand Up @@ -1242,7 +1255,7 @@ def ckpt_export(
e.g. ``--_meta#network_data_format#inputs#image#num_channels 3``.
"""
_args = _update_args(
_args = update_kwargs(
args=args_file,
net_id=net_id,
filepath=filepath,
Expand Down Expand Up @@ -1401,7 +1414,7 @@ def trt_export(
e.g. ``--_meta#network_data_format#inputs#image#num_channels 3``.
"""
_args = _update_args(
_args = update_kwargs(
args=args_file,
net_id=net_id,
filepath=filepath,
Expand Down Expand Up @@ -1614,7 +1627,7 @@ def create_workflow(
kwargs: arguments to instantiate the workflow class.
"""
_args = _update_args(args=args_file, workflow_name=workflow_name, config_file=config_file, **kwargs)
_args = update_kwargs(args=args_file, workflow_name=workflow_name, config_file=config_file, **kwargs)
_log_input_summary(tag="run", args=_args)
(workflow_name, config_file) = _pop_args(
_args, workflow_name=ConfigWorkflow, config_file=None
Expand Down
18 changes: 10 additions & 8 deletions monai/bundle/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import os
import sys
import time
import warnings
from abc import ABC, abstractmethod
from copy import copy
from logging.config import fileConfig
Expand Down Expand Up @@ -158,7 +157,7 @@ def add_property(self, name: str, required: str, desc: str | None = None) -> Non
if self.properties is None:
self.properties = {}
if name in self.properties:
warnings.warn(f"property '{name}' already exists in the properties list, overriding it.")
logger.warn(f"property '{name}' already exists in the properties list, overriding it.")
self.properties[name] = {BundleProperty.DESC: desc, BundleProperty.REQUIRED: required}

def check_properties(self) -> list[str] | None:
Expand Down Expand Up @@ -241,7 +240,7 @@ def __init__(
for _config_file in _config_files:
_config_file = Path(_config_file)
if _config_file.parent != self.config_root_path:
warnings.warn(
logger.warn(
f"Not all config files are in {self.config_root_path}. If logging_file and meta_file are"
f"not specified, {self.config_root_path} will be used as the default config root directory."
)
Expand All @@ -254,7 +253,7 @@ def __init__(
if logging_file is not None:
if not os.path.exists(logging_file):
if logging_file == str(self.config_root_path / "logging.conf"):
warnings.warn(f"Default logging file in {logging_file} does not exist, skipping logging.")
logger.warn(f"Default logging file in {logging_file} does not exist, skipping logging.")
else:
raise FileNotFoundError(f"Cannot find the logging config file: {logging_file}.")
else:
Expand All @@ -265,7 +264,10 @@ def __init__(
self.parser.read_config(f=config_file)
meta_file = str(self.config_root_path / "metadata.json") if meta_file is None else meta_file
if isinstance(meta_file, str) and not os.path.exists(meta_file):
raise FileNotFoundError(f"Cannot find the metadata config file: {meta_file}.")
logger.error(
f"Cannot find the metadata config file: {meta_file}. "
"Please see: https://docs.monai.io/en/stable/mb_specification.html"
)
else:
self.parser.read_meta(f=meta_file)

Expand Down Expand Up @@ -323,17 +325,17 @@ def check_properties(self) -> list[str] | None:
"""
ret = super().check_properties()
if self.properties is None:
warnings.warn("No available properties had been set, skipping check.")
logger.warn("No available properties had been set, skipping check.")
return None
if ret:
warnings.warn(f"Loaded bundle does not contain the following required properties: {ret}")
logger.warn(f"Loaded bundle does not contain the following required properties: {ret}")
# also check whether the optional properties use correct ID name if existing
wrong_props = []
for n, p in self.properties.items():
if not p.get(BundleProperty.REQUIRED, False) and not self._check_optional_id(name=n, property=p):
wrong_props.append(n)
if wrong_props:
warnings.warn(f"Loaded bundle defines the following optional properties with wrong ID: {wrong_props}")
logger.warn(f"Loaded bundle defines the following optional properties with wrong ID: {wrong_props}")
if ret is not None:
ret.extend(wrong_props)
return ret
Expand Down
2 changes: 1 addition & 1 deletion monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def astype(self, dtype, device=None, *_args, **_kwargs):
@property
def affine(self) -> torch.Tensor:
"""Get the affine. Defaults to ``torch.eye(4, dtype=torch.float64)``"""
return self.meta.get(MetaKeys.AFFINE, self.get_default_affine())
return self.meta.get(MetaKeys.AFFINE, self.get_default_affine()) # type: ignore

@affine.setter
def affine(self, d: NdarrayTensor) -> None:
Expand Down
5 changes: 2 additions & 3 deletions runtests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ function print_usage {
echo " -c, --clean : clean temporary files from tests and exit"
echo " -h, --help : show this help message and exit"
echo " -v, --version : show MONAI and system version information and exit"
echo " -p, --path : specify the path used for formatting"
echo " -p, --path : specify the path used for formatting, default is the current dir if unspecified"
echo " --formatfix : format code using \"isort\" and \"black\" for user specified directories"
echo ""
echo "${separator}For bug reports and feature requests, please file an issue at:"
Expand Down Expand Up @@ -359,10 +359,9 @@ if [ -e "$testdir" ]
then
homedir=$testdir
else
print_error_msg "Incorrect path: $testdir provided, run under $currentdir"
homedir=$currentdir
fi
echo "run tests under $homedir"
echo "Run tests under $homedir"
cd "$homedir"

# python path
Expand Down
2 changes: 2 additions & 0 deletions tests/test_bundle_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import torch

from monai.bundle import update_kwargs
from monai.bundle.utils import load_bundle_config
from monai.networks.nets import UNet
from monai.utils import pprint_edges
Expand Down Expand Up @@ -141,6 +142,7 @@ def test_str(self):
"[{'a': 1, 'b': 2},\n\n ... omitted 18 line(s)\n\n {'a': 1, 'b': 2}]",
)
self.assertEqual(pprint_edges([{"a": 1, "b": 2}] * 8, 4), pprint_edges([{"a": 1, "b": 2}] * 8, 3))
self.assertEqual(update_kwargs({"a": 1}, a=2, b=3), {"a": 2, "b": 3})


if __name__ == "__main__":
Expand Down
10 changes: 4 additions & 6 deletions tests/test_integration_bundle_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,8 @@ def test_tiny(self):
with self.assertRaises(RuntimeError):
# test wrong run_id="run"
command_line_tests(cmd + ["run", "run", "--config_file", config_file])
with self.assertRaises(RuntimeError):
# test missing meta file
command_line_tests(cmd + ["run", "training", "--config_file", config_file])
# test missing meta file
self.assertIn("ERROR", command_line_tests(cmd + ["run", "training", "--config_file", config_file]))

def test_scripts_fold(self):
# test scripts directory has been added to Python search directories automatically
Expand Down Expand Up @@ -150,9 +149,8 @@ def test_scripts_fold(self):
print(output)
self.assertTrue(expected_condition in output)

with self.assertRaises(RuntimeError):
# test missing meta file
command_line_tests(cmd + ["run", "training", "--config_file", config_file])
# test missing meta file
self.assertIn("ERROR", command_line_tests(cmd + ["run", "training", "--config_file", config_file]))

@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_shape(self, config_file, expected_shape):
Expand Down
1 change: 1 addition & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,7 @@ def command_line_tests(cmd, copy_env=True):
try:
normal_out = subprocess.run(cmd, env=test_env, check=True, capture_output=True)
print(repr(normal_out).replace("\\n", "\n").replace("\\t", "\t"))
return repr(normal_out)
except subprocess.CalledProcessError as e:
output = repr(e.stdout).replace("\\n", "\n").replace("\\t", "\t")
errors = repr(e.stderr).replace("\\n", "\n").replace("\\t", "\t")
Expand Down

0 comments on commit 8d730cd

Please sign in to comment.