diff --git a/docs/source/bundle.rst b/docs/source/bundle.rst index 94ca8813f4..7942e4d349 100644 --- a/docs/source/bundle.rst +++ b/docs/source/bundle.rst @@ -48,3 +48,4 @@ Model Bundle .. autofunction:: verify_metadata .. autofunction:: verify_net_in_out .. autofunction:: init_bundle +.. autofunction:: update_kwargs diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py index 711f1d2875..bd5db3cbea 100644 --- a/monai/bundle/__init__.py +++ b/monai/bundle/__init__.py @@ -29,6 +29,7 @@ run, run_workflow, trt_export, + update_kwargs, verify_metadata, verify_net_in_out, ) diff --git a/monai/bundle/config_parser.py b/monai/bundle/config_parser.py index 990ff00a42..829036af6f 100644 --- a/monai/bundle/config_parser.py +++ b/monai/bundle/config_parser.py @@ -412,13 +412,16 @@ def load_config_files(cls, files: PathLike | Sequence[PathLike] | dict, **kwargs Args: files: path of target files to load, supported postfixes: `.json`, `.yml`, `.yaml`. - if providing a list of files, wil merge the content of them. + if providing a list of files, will merge the content of them. + if providing a string with comma separated file paths, will merge the content of them. if providing a dictionary, return it directly. kwargs: other arguments for ``json.load`` or ```yaml.safe_load``, depends on the file format. """ if isinstance(files, dict): # already a config dict return files parser = ConfigParser(config={}) + if isinstance(files, str) and not Path(files).is_file() and "," in files: + files = files.split(",") for i in ensure_tuple(files): for k, v in (cls.load_config_file(i, **kwargs)).items(): parser[k] = v diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index c3aec09b5e..4607ef65b7 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -66,33 +66,46 @@ PPRINT_CONFIG_N = 5 -def _update_args(args: str | dict | None = None, ignore_none: bool = True, **kwargs: Any) -> dict: +def update_kwargs(args: str | dict | None = None, ignore_none: bool = True, **kwargs: Any) -> dict: """ - Update the `args` with the input `kwargs`. + Update the `args` dictionary with the input `kwargs`. For dict data, recursively update the content based on the keys. + Example:: + + from monai.bundle import update_kwargs + update_kwargs({'exist': 1}, exist=2, new_arg=3) + # return {'exist': 2, 'new_arg': 3} + Args: - args: source args to update. + args: source `args` dictionary (or a json/yaml filename to read as dictionary) to update. ignore_none: whether to ignore input args with None value, default to `True`. - kwargs: destination args to update. + kwargs: key=value pairs to be merged into `args`. """ args_: dict = args if isinstance(args, dict) else {} if isinstance(args, str): # args are defined in a structured file args_ = ConfigParser.load_config_file(args) + if isinstance(args, (tuple, list)) and all(isinstance(x, str) for x in args): + primary, overrides = args + args_ = update_kwargs(primary, ignore_none, **update_kwargs(overrides, ignore_none, **kwargs)) + if not isinstance(args_, dict): + return args_ # recursively update the default args with new args for k, v in kwargs.items(): - print(k, v) if ignore_none and v is None: continue if isinstance(v, dict) and isinstance(args_.get(k), dict): - args_[k] = _update_args(args_[k], ignore_none, **v) + args_[k] = update_kwargs(args_[k], ignore_none, **v) else: args_[k] = v return args_ +_update_args = update_kwargs # backward compatibility + + def _pop_args(src: dict, *args: Any, **kwargs: Any) -> tuple: """ Pop args from the `src` dictionary based on specified keys in `args` and (key, default value) pairs in `kwargs`. @@ -318,7 +331,7 @@ def download( so that the command line inputs can be simplified. """ - _args = _update_args( + _args = update_kwargs( args=args_file, name=name, version=version, @@ -834,7 +847,7 @@ def verify_metadata( """ - _args = _update_args( + _args = update_kwargs( args=args_file, meta_file=meta_file, filepath=filepath, @@ -958,7 +971,7 @@ def verify_net_in_out( """ - _args = _update_args( + _args = update_kwargs( args=args_file, net_id=net_id, meta_file=meta_file, @@ -1127,7 +1140,7 @@ def onnx_export( e.g. ``--_meta#network_data_format#inputs#image#num_channels 3``. """ - _args = _update_args( + _args = update_kwargs( args=args_file, net_id=net_id, filepath=filepath, @@ -1242,7 +1255,7 @@ def ckpt_export( e.g. ``--_meta#network_data_format#inputs#image#num_channels 3``. """ - _args = _update_args( + _args = update_kwargs( args=args_file, net_id=net_id, filepath=filepath, @@ -1401,7 +1414,7 @@ def trt_export( e.g. ``--_meta#network_data_format#inputs#image#num_channels 3``. """ - _args = _update_args( + _args = update_kwargs( args=args_file, net_id=net_id, filepath=filepath, @@ -1614,7 +1627,7 @@ def create_workflow( kwargs: arguments to instantiate the workflow class. """ - _args = _update_args(args=args_file, workflow_name=workflow_name, config_file=config_file, **kwargs) + _args = update_kwargs(args=args_file, workflow_name=workflow_name, config_file=config_file, **kwargs) _log_input_summary(tag="run", args=_args) (workflow_name, config_file) = _pop_args( _args, workflow_name=ConfigWorkflow, config_file=None diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index 585553a806..da3aa30141 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -14,7 +14,6 @@ import os import sys import time -import warnings from abc import ABC, abstractmethod from copy import copy from logging.config import fileConfig @@ -158,7 +157,7 @@ def add_property(self, name: str, required: str, desc: str | None = None) -> Non if self.properties is None: self.properties = {} if name in self.properties: - warnings.warn(f"property '{name}' already exists in the properties list, overriding it.") + logger.warn(f"property '{name}' already exists in the properties list, overriding it.") self.properties[name] = {BundleProperty.DESC: desc, BundleProperty.REQUIRED: required} def check_properties(self) -> list[str] | None: @@ -241,7 +240,7 @@ def __init__( for _config_file in _config_files: _config_file = Path(_config_file) if _config_file.parent != self.config_root_path: - warnings.warn( + logger.warn( f"Not all config files are in {self.config_root_path}. If logging_file and meta_file are" f"not specified, {self.config_root_path} will be used as the default config root directory." ) @@ -254,7 +253,7 @@ def __init__( if logging_file is not None: if not os.path.exists(logging_file): if logging_file == str(self.config_root_path / "logging.conf"): - warnings.warn(f"Default logging file in {logging_file} does not exist, skipping logging.") + logger.warn(f"Default logging file in {logging_file} does not exist, skipping logging.") else: raise FileNotFoundError(f"Cannot find the logging config file: {logging_file}.") else: @@ -265,7 +264,10 @@ def __init__( self.parser.read_config(f=config_file) meta_file = str(self.config_root_path / "metadata.json") if meta_file is None else meta_file if isinstance(meta_file, str) and not os.path.exists(meta_file): - raise FileNotFoundError(f"Cannot find the metadata config file: {meta_file}.") + logger.error( + f"Cannot find the metadata config file: {meta_file}. " + "Please see: https://docs.monai.io/en/stable/mb_specification.html" + ) else: self.parser.read_meta(f=meta_file) @@ -323,17 +325,17 @@ def check_properties(self) -> list[str] | None: """ ret = super().check_properties() if self.properties is None: - warnings.warn("No available properties had been set, skipping check.") + logger.warn("No available properties had been set, skipping check.") return None if ret: - warnings.warn(f"Loaded bundle does not contain the following required properties: {ret}") + logger.warn(f"Loaded bundle does not contain the following required properties: {ret}") # also check whether the optional properties use correct ID name if existing wrong_props = [] for n, p in self.properties.items(): if not p.get(BundleProperty.REQUIRED, False) and not self._check_optional_id(name=n, property=p): wrong_props.append(n) if wrong_props: - warnings.warn(f"Loaded bundle defines the following optional properties with wrong ID: {wrong_props}") + logger.warn(f"Loaded bundle defines the following optional properties with wrong ID: {wrong_props}") if ret is not None: ret.extend(wrong_props) return ret diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 9e93b3935e..cad0851a8e 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -462,7 +462,7 @@ def astype(self, dtype, device=None, *_args, **_kwargs): @property def affine(self) -> torch.Tensor: """Get the affine. Defaults to ``torch.eye(4, dtype=torch.float64)``""" - return self.meta.get(MetaKeys.AFFINE, self.get_default_affine()) + return self.meta.get(MetaKeys.AFFINE, self.get_default_affine()) # type: ignore @affine.setter def affine(self, d: NdarrayTensor) -> None: diff --git a/runtests.sh b/runtests.sh index bdd56f8de3..cfceb6976a 100755 --- a/runtests.sh +++ b/runtests.sh @@ -108,7 +108,7 @@ function print_usage { echo " -c, --clean : clean temporary files from tests and exit" echo " -h, --help : show this help message and exit" echo " -v, --version : show MONAI and system version information and exit" - echo " -p, --path : specify the path used for formatting" + echo " -p, --path : specify the path used for formatting, default is the current dir if unspecified" echo " --formatfix : format code using \"isort\" and \"black\" for user specified directories" echo "" echo "${separator}For bug reports and feature requests, please file an issue at:" @@ -359,10 +359,9 @@ if [ -e "$testdir" ] then homedir=$testdir else - print_error_msg "Incorrect path: $testdir provided, run under $currentdir" homedir=$currentdir fi -echo "run tests under $homedir" +echo "Run tests under $homedir" cd "$homedir" # python path diff --git a/tests/test_bundle_utils.py b/tests/test_bundle_utils.py index 391a56bc3c..181c08475c 100644 --- a/tests/test_bundle_utils.py +++ b/tests/test_bundle_utils.py @@ -18,6 +18,7 @@ import torch +from monai.bundle import update_kwargs from monai.bundle.utils import load_bundle_config from monai.networks.nets import UNet from monai.utils import pprint_edges @@ -141,6 +142,7 @@ def test_str(self): "[{'a': 1, 'b': 2},\n\n ... omitted 18 line(s)\n\n {'a': 1, 'b': 2}]", ) self.assertEqual(pprint_edges([{"a": 1, "b": 2}] * 8, 4), pprint_edges([{"a": 1, "b": 2}] * 8, 3)) + self.assertEqual(update_kwargs({"a": 1}, a=2, b=3), {"a": 2, "b": 3}) if __name__ == "__main__": diff --git a/tests/test_integration_bundle_run.py b/tests/test_integration_bundle_run.py index fb9e0e75c6..bd96f50c55 100644 --- a/tests/test_integration_bundle_run.py +++ b/tests/test_integration_bundle_run.py @@ -86,9 +86,8 @@ def test_tiny(self): with self.assertRaises(RuntimeError): # test wrong run_id="run" command_line_tests(cmd + ["run", "run", "--config_file", config_file]) - with self.assertRaises(RuntimeError): - # test missing meta file - command_line_tests(cmd + ["run", "training", "--config_file", config_file]) + # test missing meta file + self.assertIn("ERROR", command_line_tests(cmd + ["run", "training", "--config_file", config_file])) def test_scripts_fold(self): # test scripts directory has been added to Python search directories automatically @@ -150,9 +149,8 @@ def test_scripts_fold(self): print(output) self.assertTrue(expected_condition in output) - with self.assertRaises(RuntimeError): - # test missing meta file - command_line_tests(cmd + ["run", "training", "--config_file", config_file]) + # test missing meta file + self.assertIn("ERROR", command_line_tests(cmd + ["run", "training", "--config_file", config_file])) @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_shape(self, config_file, expected_shape): diff --git a/tests/utils.py b/tests/utils.py index 9f2b41adb7..a8efbe081e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -818,6 +818,7 @@ def command_line_tests(cmd, copy_env=True): try: normal_out = subprocess.run(cmd, env=test_env, check=True, capture_output=True) print(repr(normal_out).replace("\\n", "\n").replace("\\t", "\t")) + return repr(normal_out) except subprocess.CalledProcessError as e: output = repr(e.stdout).replace("\\n", "\n").replace("\\t", "\t") errors = repr(e.stderr).replace("\\n", "\n").replace("\\t", "\t")