Skip to content

Commit

Permalink
Merge pull request #70 from PyUtilib/import-fixes
Browse files Browse the repository at this point in the history
Improvements to import_file() and run_file()
  • Loading branch information
blnicho authored Dec 13, 2019
2 parents 12fddaf + d04eace commit 5e8ec7d
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 33 deletions.
66 changes: 33 additions & 33 deletions pyutilib/misc/import_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,11 @@ def import_file(filename, context=None, name=None, clear_cache=False):
else:
if clear_cache and modulename in sys.modules:
del sys.modules[modulename]
if dirname is not None:
sys.path.insert(0, dirname)
else:
sys.path.insert(0, implied_dirname)
try:
if dirname is not None:
sys.path.insert(0, dirname)
else:
sys.path.insert(0, implied_dirname)
module = __import__(modulename)
except ImportError:
pass
Expand Down Expand Up @@ -129,8 +129,8 @@ def import_file(filename, context=None, name=None, clear_cache=False):
[dirname])
fp.close()
else:
sys.path.insert(0, implied_dirname)
try:
sys.path.insert(0, implied_dirname)
# find_module will return the .py file
# (never .pyc)
fp, pathname, description = imp.find_module(modulename)
Expand Down Expand Up @@ -175,7 +175,7 @@ def run_file(filename, logfile=None, execdir=None):
#
# Open logfile
#
if not logfile is None:
if logfile is not None:
sys.stderr.flush()
sys.stdout.flush()
save_stdout = sys.stdout
Expand All @@ -187,44 +187,44 @@ def run_file(filename, logfile=None, execdir=None):
# Add the file directory to the system path
#
currdir_ = ''
if '/' in filename:
currdir_ = "/".join((filename).split("/")[:-1])
tmp_import = (filename).split("/")[-1]
sys.path.append(currdir_)
elif '\\' in filename:
currdir_ = "\\".join((filename).split("\\")[:-1])
tmp_import = (filename).split("\\")[-1]
sys.path.append(currdir_)
norm_file = os.path.normpath(filename)
assert norm_file[-1] not in '\\/'
split_path = []
while norm_file:
norm_file, tail = os.path.split(norm_file)
split_path.append(tail)
# Absolute paths can get stuck returning ('/', '')
if not tail:
split_path.append(norm_file)
norm_file = ''
split_path.reverse()
if len(split_path) > 1:
currdir_ = os.path.join(*tuple(split_path[:-1]))
else:
tmp_import = filename
currdir_ = '.'
tmp_import = split_path[-1]

name = ".".join((tmp_import).split(".")[:-1])
#
# Run the module
#
try:
if not execdir is None:
tmp_path = list(sys.path)
if execdir is not None:
tmp = os.getcwd()
os.chdir(execdir)
tmp_path = sys.path
sys.path = [execdir] + sys.path
# [JDS 191130] I am not sure why we put the target file's
# directory at the end of sys.path, but I am preserving that
# decision.
sys.path.append(currdir_)
runpy.run_module(name, None, "__main__")
if not execdir is None:
finally:
# Mandatory cleanup
sys.path = tmp_path
if execdir is not None:
os.chdir(tmp)
sys.path = tmp_path
except Exception: #pragma:nocover
if not logfile is None:
if logfile is not None:
OUTPUT.close()
sys.stdout = save_stdout
sys.stderr = save_stderr
raise
if currdir_ in sys.path:
sys.path.remove(currdir_)
if execdir in sys.path:
sys.path.remove(execdir)
#
# Close logfile
#
if not logfile is None:
OUTPUT.close()
sys.stdout = save_stdout
sys.stderr = save_stderr
3 changes: 3 additions & 0 deletions pyutilib/misc/tests/import_exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import sys

raise RuntimeError("raised during import")
5 changes: 5 additions & 0 deletions pyutilib/misc/tests/import_main_exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import sys

if __name__ == "__main__":
print("import_main_exception - main")
raise RuntimeError("raised from __main__")
1 change: 1 addition & 0 deletions pyutilib/misc/tests/import_main_exception.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
import_main_exception - main
23 changes: 23 additions & 0 deletions pyutilib/misc/tests/test_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,21 @@ def test_run_file3(self):
currdir + "import2.txt")[0])
os.remove(currdir + "import2.log")

def test_run_file_exception(self):
orig_path = list(sys.path)
with self.assertRaisesRegexp(RuntimeError, "raised from __main__"):
pyutilib.misc.run_file(
"import_main_exception.py",
logfile=currdir + "import_main_exception.log", execdir=currdir)

self.assertFalse(
pyutilib.misc.comparison.compare_file(
currdir + "import_main_exception.log",
currdir + "import_main_exception.txt")[0])
os.remove(currdir + "import_main_exception.log")
self.assertIsNot(orig_path, sys.path)
self.assertEqual(orig_path, sys.path)


class TestImportFile(unittest.TestCase):

Expand Down Expand Up @@ -83,6 +98,14 @@ def test_import_file_context3(self):
if not "import1" in globals():
self.fail("test_import_file - failed to import the import1.py file")

def test_import_exception(self):
orig_path = list(sys.path)
with self.assertRaisesRegexp(RuntimeError, "raised during import"):
pyutilib.misc.run_file(
"import_exception.py", execdir=currdir)
self.assertIsNot(orig_path, sys.path)
self.assertEqual(orig_path, sys.path)

def test1(self):
try:
pyutilib.misc.import_file('tfile.py')
Expand Down

0 comments on commit 5e8ec7d

Please sign in to comment.