Skip to content

Commit

Permalink
Add bundle root directory to Python search directories automatically (#…
Browse files Browse the repository at this point in the history
…6910)

Fixes #6722 .

### Description
Add scripts directory to Python search directories automatically in the
`run` function in `ConfigWorkflow`.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: KumoLiu <[email protected]>
  • Loading branch information
KumoLiu authored Aug 31, 2023
1 parent a4e4894 commit be4e1f5
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 8 deletions.
23 changes: 15 additions & 8 deletions monai/bundle/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from __future__ import annotations

import os
import sys
import time
import warnings
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -170,6 +171,7 @@ class ConfigWorkflow(BundleWorkflow):
"""
Specification for the config-based bundle workflow.
Standardized the `initialize`, `run`, `finalize` behavior in a config-based training, evaluation, or inference.
Before `run`, we add bundle root directory to Python search directories automatically.
For more information: https://docs.monai.io/en/latest/mb_specification.html.
Args:
Expand Down Expand Up @@ -224,23 +226,23 @@ def __init__(
super().__init__(workflow_type=workflow_type)
if config_file is not None:
_config_files = ensure_tuple(config_file)
config_root_path = Path(_config_files[0]).parent
self.config_root_path = Path(_config_files[0]).parent
for _config_file in _config_files:
_config_file = Path(_config_file)
if _config_file.parent != config_root_path:
if _config_file.parent != self.config_root_path:
warnings.warn(
f"Not all config files are in {config_root_path}. If logging_file and meta_file are"
f"not specified, {config_root_path} will be used as the default config root directory."
f"Not all config files are in {self.config_root_path}. If logging_file and meta_file are"
f"not specified, {self.config_root_path} will be used as the default config root directory."
)
if not _config_file.is_file():
raise FileNotFoundError(f"Cannot find the config file: {_config_file}.")
else:
config_root_path = Path("configs")
self.config_root_path = Path("configs")

logging_file = str(config_root_path / "logging.conf") if logging_file is None else logging_file
logging_file = str(self.config_root_path / "logging.conf") if logging_file is None else logging_file
if logging_file is not None:
if not os.path.exists(logging_file):
if logging_file == str(config_root_path / "logging.conf"):
if logging_file == str(self.config_root_path / "logging.conf"):
warnings.warn(f"Default logging file in {logging_file} does not exist, skipping logging.")
else:
raise FileNotFoundError(f"Cannot find the logging config file: {logging_file}.")
Expand All @@ -250,7 +252,7 @@ def __init__(

self.parser = ConfigParser()
self.parser.read_config(f=config_file)
meta_file = str(config_root_path / "metadata.json") if meta_file is None else meta_file
meta_file = str(self.config_root_path / "metadata.json") if meta_file is None else meta_file
if isinstance(meta_file, str) and not os.path.exists(meta_file):
raise FileNotFoundError(f"Cannot find the metadata config file: {meta_file}.")
else:
Expand Down Expand Up @@ -283,8 +285,13 @@ def initialize(self) -> Any:
def run(self) -> Any:
"""
Run the bundle workflow, it can be a training, evaluation or inference.
Before run, we add bundle root directory to Python search directories automatically.
"""
_bundle_root_path = (
self.config_root_path.parent if self.config_root_path.name == "configs" else self.config_root_path
)
sys.path.insert(1, str(_bundle_root_path))
if self.run_id not in self.parser:
raise ValueError(f"run ID '{self.run_id}' doesn't exist in the config file.")
return self._run_expr(id=self.run_id)
Expand Down
72 changes: 72 additions & 0 deletions tests/test_integration_bundle_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import json
import os
import shutil
import subprocess
import sys
import tempfile
import unittest
Expand Down Expand Up @@ -44,6 +45,14 @@ def run(self):
return self.val


class _Runnable43:
def __init__(self, func):
self.func = func

def run(self):
self.func()


class TestBundleRun(unittest.TestCase):
def setUp(self):
self.data_dir = tempfile.mkdtemp()
Expand Down Expand Up @@ -77,6 +86,69 @@ def test_tiny(self):
with self.assertRaises(RuntimeError):
# test wrong run_id="run"
command_line_tests(cmd + ["run", "run", "--config_file", config_file])
with self.assertRaises(RuntimeError):
# test missing meta file
command_line_tests(cmd + ["run", "training", "--config_file", config_file])

def test_scripts_fold(self):
# test scripts directory has been added to Python search directories automatically
config_file = os.path.join(self.data_dir, "tiny_config.json")
meta_file = os.path.join(self.data_dir, "tiny_meta.json")
scripts_dir = os.path.join(self.data_dir, "scripts")
script_file = os.path.join(scripts_dir, "test_scripts_fold.py")
init_file = os.path.join(scripts_dir, "__init__.py")

with open(config_file, "w") as f:
json.dump(
{
"imports": ["$import scripts"],
"trainer": {
"_target_": "tests.test_integration_bundle_run._Runnable43",
"func": "$scripts.tiny_test",
},
# keep this test case to cover the "runner_id" arg
"training": "[email protected]()",
},
f,
)
with open(meta_file, "w") as f:
json.dump(
{"version": "0.1.0", "monai_version": "1.1.0", "pytorch_version": "1.13.1", "numpy_version": "1.22.2"},
f,
)

os.mkdir(scripts_dir)
script_file_lines = ["def tiny_test():\n", " print('successfully added scripts fold!') \n"]
init_file_line = "from .test_scripts_fold import tiny_test\n"
with open(script_file, "w") as f:
f.writelines(script_file_lines)
f.close()
with open(init_file, "w") as f:
f.write(init_file_line)
f.close()

cmd = ["coverage", "run", "-m", "monai.bundle"]
# test both CLI entry "run" and "run_workflow"
expected_condition = "successfully added scripts fold!"
command_run = cmd + ["run", "training", "--config_file", config_file, "--meta_file", meta_file]
completed_process = subprocess.run(command_run, check=True, capture_output=True, text=True)
output = repr(completed_process.stdout).replace("\\n", "\n").replace("\\t", "\t") # Get the captured output
print(output)

self.assertTrue(expected_condition in output)
command_run_workflow = cmd + [
"run_workflow",
"--run_id",
"training",
"--config_file",
config_file,
"--meta_file",
meta_file,
]
completed_process = subprocess.run(command_run_workflow, check=True, capture_output=True, text=True)
output = repr(completed_process.stdout).replace("\\n", "\n").replace("\\t", "\t") # Get the captured output
print(output)
self.assertTrue(expected_condition in output)

with self.assertRaises(RuntimeError):
# test missing meta file
Expand Down

0 comments on commit be4e1f5

Please sign in to comment.