diff --git a/pyutilib/misc/import_file.py b/pyutilib/misc/import_file.py index 90db6c0a..ab9ea393 100644 --- a/pyutilib/misc/import_file.py +++ b/pyutilib/misc/import_file.py @@ -92,11 +92,11 @@ def import_file(filename, context=None, name=None, clear_cache=False): else: if clear_cache and modulename in sys.modules: del sys.modules[modulename] - if dirname is not None: - sys.path.insert(0, dirname) - else: - sys.path.insert(0, implied_dirname) try: + if dirname is not None: + sys.path.insert(0, dirname) + else: + sys.path.insert(0, implied_dirname) module = __import__(modulename) except ImportError: pass @@ -129,8 +129,8 @@ def import_file(filename, context=None, name=None, clear_cache=False): [dirname]) fp.close() else: - sys.path.insert(0, implied_dirname) try: + sys.path.insert(0, implied_dirname) # find_module will return the .py file # (never .pyc) fp, pathname, description = imp.find_module(modulename) @@ -175,7 +175,7 @@ def run_file(filename, logfile=None, execdir=None): # # Open logfile # - if not logfile is None: + if logfile is not None: sys.stderr.flush() sys.stdout.flush() save_stdout = sys.stdout @@ -187,44 +187,44 @@ def run_file(filename, logfile=None, execdir=None): # Add the file directory to the system path # currdir_ = '' - if '/' in filename: - currdir_ = "/".join((filename).split("/")[:-1]) - tmp_import = (filename).split("/")[-1] - sys.path.append(currdir_) - elif '\\' in filename: - currdir_ = "\\".join((filename).split("\\")[:-1]) - tmp_import = (filename).split("\\")[-1] - sys.path.append(currdir_) + norm_file = os.path.normpath(filename) + assert norm_file[-1] not in '\\/' + split_path = [] + while norm_file: + norm_file, tail = os.path.split(norm_file) + split_path.append(tail) + # Absolute paths can get stuck returning ('/', '') + if not tail: + split_path.append(norm_file) + norm_file = '' + split_path.reverse() + if len(split_path) > 1: + currdir_ = os.path.join(*tuple(split_path[:-1])) else: - tmp_import = filename + currdir_ = '.' + tmp_import = split_path[-1] + name = ".".join((tmp_import).split(".")[:-1]) # # Run the module # try: - if not execdir is None: + tmp_path = list(sys.path) + if execdir is not None: tmp = os.getcwd() os.chdir(execdir) - tmp_path = sys.path sys.path = [execdir] + sys.path + # [JDS 191130] I am not sure why we put the target file's + # directory at the end of sys.path, but I am preserving that + # decision. + sys.path.append(currdir_) runpy.run_module(name, None, "__main__") - if not execdir is None: + finally: + # Mandatory cleanup + sys.path = tmp_path + if execdir is not None: os.chdir(tmp) - sys.path = tmp_path - except Exception: #pragma:nocover - if not logfile is None: + if logfile is not None: OUTPUT.close() sys.stdout = save_stdout sys.stderr = save_stderr - raise - if currdir_ in sys.path: - sys.path.remove(currdir_) - if execdir in sys.path: - sys.path.remove(execdir) - # - # Close logfile - # - if not logfile is None: - OUTPUT.close() - sys.stdout = save_stdout - sys.stderr = save_stderr diff --git a/pyutilib/misc/tests/import_exception.py b/pyutilib/misc/tests/import_exception.py new file mode 100644 index 00000000..5480bd95 --- /dev/null +++ b/pyutilib/misc/tests/import_exception.py @@ -0,0 +1,3 @@ +import sys + +raise RuntimeError("raised during import") diff --git a/pyutilib/misc/tests/import_main_exception.py b/pyutilib/misc/tests/import_main_exception.py new file mode 100644 index 00000000..a3ef7612 --- /dev/null +++ b/pyutilib/misc/tests/import_main_exception.py @@ -0,0 +1,5 @@ +import sys + +if __name__ == "__main__": + print("import_main_exception - main") + raise RuntimeError("raised from __main__") diff --git a/pyutilib/misc/tests/import_main_exception.txt b/pyutilib/misc/tests/import_main_exception.txt new file mode 100644 index 00000000..588c98e3 --- /dev/null +++ b/pyutilib/misc/tests/import_main_exception.txt @@ -0,0 +1 @@ +import_main_exception - main diff --git a/pyutilib/misc/tests/test_import.py b/pyutilib/misc/tests/test_import.py index a5a57b07..af523ad2 100644 --- a/pyutilib/misc/tests/test_import.py +++ b/pyutilib/misc/tests/test_import.py @@ -54,6 +54,21 @@ def test_run_file3(self): currdir + "import2.txt")[0]) os.remove(currdir + "import2.log") + def test_run_file_exception(self): + orig_path = list(sys.path) + with self.assertRaisesRegexp(RuntimeError, "raised from __main__"): + pyutilib.misc.run_file( + "import_main_exception.py", + logfile=currdir + "import_main_exception.log", execdir=currdir) + + self.assertFalse( + pyutilib.misc.comparison.compare_file( + currdir + "import_main_exception.log", + currdir + "import_main_exception.txt")[0]) + os.remove(currdir + "import_main_exception.log") + self.assertIsNot(orig_path, sys.path) + self.assertEqual(orig_path, sys.path) + class TestImportFile(unittest.TestCase): @@ -83,6 +98,14 @@ def test_import_file_context3(self): if not "import1" in globals(): self.fail("test_import_file - failed to import the import1.py file") + def test_import_exception(self): + orig_path = list(sys.path) + with self.assertRaisesRegexp(RuntimeError, "raised during import"): + pyutilib.misc.run_file( + "import_exception.py", execdir=currdir) + self.assertIsNot(orig_path, sys.path) + self.assertEqual(orig_path, sys.path) + def test1(self): try: pyutilib.misc.import_file('tfile.py')