From 18339162fc869fad271b79f1ecbcecb82019255a Mon Sep 17 00:00:00 2001 From: Reda Oulbacha Date: Wed, 16 Aug 2023 22:47:37 -0400 Subject: [PATCH 1/3] refactor and simplify sys.path and mp context managers --- flojoy/flojoy_node_venv.py | 81 +++++++++++++++----------------------- 1 file changed, 31 insertions(+), 50 deletions(-) diff --git a/flojoy/flojoy_node_venv.py b/flojoy/flojoy_node_venv.py index 84c5cf5..51ae9e1 100644 --- a/flojoy/flojoy_node_venv.py +++ b/flojoy/flojoy_node_venv.py @@ -22,6 +22,7 @@ def TORCH_NODE(default: Matrix) -> Matrix: from typing import Callable import hashlib +from contextlib import contextmanager import importlib.metadata import inspect import logging @@ -41,41 +42,19 @@ def TORCH_NODE(default: Matrix) -> Matrix: __all__ = ["run_in_venv"] -class MultiprocessingExecutableContextManager: - """Temporarily change the executable used by multiprocessing.""" - def __init__(self, executable_path): - self.original_executable_path = sys.executable - self.executable_path = executable_path - # We need to save the original start method - # because it is set to "fork" by default on Linux while we ALWAYS want spawn - self.original_start_method = multiprocessing.get_start_method() - - def __enter__(self): - if self.original_start_method != "spawn": - multiprocessing.set_start_method("spawn", force=True) - multiprocessing.set_executable(self.executable_path) - - def __exit__(self, exc_type, exc_val, exc_tb): - if self.original_start_method != "spawn": - multiprocessing.set_start_method(self.original_start_method, force=True) - multiprocessing.set_executable(self.original_executable_path) - - -class SwapSysPath: +@contextmanager +def swap_sys_path(venv_executable: os.PathLike, extra_sys_path: list[str] = None): """Temporarily swap the sys.path of the child process with the sys.path of the parent process.""" + old_path = sys.path + try: + new_path = _get_venv_syspath(venv_executable) + extra_sys_path = [] if extra_sys_path is None else extra_sys_path + sys.path = new_path + extra_sys_path + yield + finally: + sys.path = old_path - def __init__(self, venv_executable, extra_sys_path): - self.new_path = _get_venv_syspath(venv_executable) - self.extra_sys_path = [] if extra_sys_path is None else extra_sys_path - self.old_path = None - - def __enter__(self): - self.old_path = sys.path - sys.path = self.new_path + self.extra_sys_path - - def __exit__(self, exc_type, exc_val, exc_tb): - sys.path = self.old_path def _install_pip_dependencies( @@ -122,9 +101,7 @@ def __init__( self._venv_executable = venv_executable def __call__(self, *args_serialized, **kwargs_serialized): - with SwapSysPath( - venv_executable=self._venv_executable, extra_sys_path=self._extra_sys_path - ): + with swap_sys_path(venv_executable=self._venv_executable, extra_sys_path=self._extra_sys_path): try: fn = cloudpickle.loads(self._func_serialized) args = [cloudpickle.loads(arg) for arg in args_serialized] @@ -225,8 +202,10 @@ def TORCH_NODE(default: Matrix) -> Matrix: def decorator(func): @wraps(func) def wrapper(*args, **kwargs): + # Generate a new multiprocessing context for the parent process in "spawn" mode + parent_mp_context = multiprocessing.get_context("spawn") + parent_conn, child_conn = parent_mp_context.Pipe() # Serialize function arguments using cloudpickle - parent_conn, child_conn = multiprocessing.Pipe() args_serialized = [cloudpickle.dumps(arg) for arg in args] kwargs_serialized = { key: cloudpickle.dumps(value) for key, value in kwargs.items() @@ -234,20 +213,22 @@ def wrapper(*args, **kwargs): pickleable_func_with_pipe = PickleableFunctionWithPipeIO( func, child_conn, venv_executable ) - # Start the context manager that will change the executable used by multiprocessing - with MultiprocessingExecutableContextManager(venv_executable): - # Create a new process that will run the Python code - process = multiprocessing.Process( - target=pickleable_func_with_pipe, - args=args_serialized, - kwargs=kwargs_serialized, - ) - # Start the process - process.start() - # Fetch result from the child process - serialized_result = parent_conn.recv_bytes() - # Wait for the process to finish - process.join() + # Create a new multiprocessing context for the child process in "spawn" mode + # while setting its executable to the virtual environment python executable + child_mp_context = multiprocessing.get_context("spawn") + child_mp_context.set_executable(venv_executable) + # Create a new process that will run the Python code + process = child_mp_context.Process( + target=pickleable_func_with_pipe, + args=args_serialized, + kwargs=kwargs_serialized, + ) + # Start the process + process.start() + # Fetch result from the child process + serialized_result = parent_conn.recv_bytes() + # Wait for the process to finish + process.join() # Check if the process sent an exception with a traceback result = cloudpickle.loads(serialized_result) if isinstance(result, tuple) and isinstance(result[0], Exception): From ced63af27b7e7ad81050f7fb7953ed1735b51018 Mon Sep 17 00:00:00 2001 From: Reda Oulbacha Date: Wed, 16 Aug 2023 22:48:03 -0400 Subject: [PATCH 2/3] run black formatter --- flojoy/flojoy_node_venv.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flojoy/flojoy_node_venv.py b/flojoy/flojoy_node_venv.py index 51ae9e1..2a60b67 100644 --- a/flojoy/flojoy_node_venv.py +++ b/flojoy/flojoy_node_venv.py @@ -42,7 +42,6 @@ def TORCH_NODE(default: Matrix) -> Matrix: __all__ = ["run_in_venv"] - @contextmanager def swap_sys_path(venv_executable: os.PathLike, extra_sys_path: list[str] = None): """Temporarily swap the sys.path of the child process with the sys.path of the parent process.""" @@ -56,7 +55,6 @@ def swap_sys_path(venv_executable: os.PathLike, extra_sys_path: list[str] = None sys.path = old_path - def _install_pip_dependencies( venv_executable: os.PathLike, pip_dependencies: tuple[str], verbose: bool = False ): @@ -101,7 +99,9 @@ def __init__( self._venv_executable = venv_executable def __call__(self, *args_serialized, **kwargs_serialized): - with swap_sys_path(venv_executable=self._venv_executable, extra_sys_path=self._extra_sys_path): + with swap_sys_path( + venv_executable=self._venv_executable, extra_sys_path=self._extra_sys_path + ): try: fn = cloudpickle.loads(self._func_serialized) args = [cloudpickle.loads(arg) for arg in args_serialized] @@ -204,7 +204,7 @@ def decorator(func): def wrapper(*args, **kwargs): # Generate a new multiprocessing context for the parent process in "spawn" mode parent_mp_context = multiprocessing.get_context("spawn") - parent_conn, child_conn = parent_mp_context.Pipe() + parent_conn, child_conn = parent_mp_context.Pipe() # Serialize function arguments using cloudpickle args_serialized = [cloudpickle.dumps(arg) for arg in args] kwargs_serialized = { From 85cbdf24aefd6ef8306b2b426dce2ab993dfee82 Mon Sep 17 00:00:00 2001 From: Reda Oulbacha Date: Wed, 16 Aug 2023 23:58:27 -0400 Subject: [PATCH 3/3] update extra_sys_path type hint --- flojoy/flojoy_node_venv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flojoy/flojoy_node_venv.py b/flojoy/flojoy_node_venv.py index 2a60b67..fc06d25 100644 --- a/flojoy/flojoy_node_venv.py +++ b/flojoy/flojoy_node_venv.py @@ -43,7 +43,7 @@ def TORCH_NODE(default: Matrix) -> Matrix: @contextmanager -def swap_sys_path(venv_executable: os.PathLike, extra_sys_path: list[str] = None): +def swap_sys_path(venv_executable: os.PathLike, extra_sys_path: list[str] | None = None): """Temporarily swap the sys.path of the child process with the sys.path of the parent process.""" old_path = sys.path try: