This repository has been archived by the owner on Oct 10, 2023. It is now read-only.

Refactor of sys.path swapping context manager, simplify multiprocessing contexts for parent and child process #75

Closed · wants to merge 4 commits
Changes from all commits
81 changes: 31 additions & 50 deletions flojoy/flojoy_node_venv.py
@@ -22,6 +22,7 @@ def TORCH_NODE(default: Matrix) -> Matrix:
 from typing import Callable

 import hashlib
+from contextlib import contextmanager
 import importlib.metadata
 import inspect
 import logging
@@ -41,41 +42,17 @@ def TORCH_NODE(default: Matrix) -> Matrix:
 __all__ = ["run_in_venv"]


-class MultiprocessingExecutableContextManager:
-    """Temporarily change the executable used by multiprocessing."""
-
-    def __init__(self, executable_path):
-        self.original_executable_path = sys.executable
-        self.executable_path = executable_path
-        # We need to save the original start method
-        # because it is set to "fork" by default on Linux while we ALWAYS want spawn
-        self.original_start_method = multiprocessing.get_start_method()
-
-    def __enter__(self):
-        if self.original_start_method != "spawn":
-            multiprocessing.set_start_method("spawn", force=True)
-        multiprocessing.set_executable(self.executable_path)
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        if self.original_start_method != "spawn":
-            multiprocessing.set_start_method(self.original_start_method, force=True)
-        multiprocessing.set_executable(self.original_executable_path)
-
-
-class SwapSysPath:
+@contextmanager
+def swap_sys_path(venv_executable: os.PathLike, extra_sys_path: list[str] | None = None):
     """Temporarily swap the sys.path of the child process with the sys.path of the parent process."""
-
-    def __init__(self, venv_executable, extra_sys_path):
-        self.new_path = _get_venv_syspath(venv_executable)
-        self.extra_sys_path = [] if extra_sys_path is None else extra_sys_path
-        self.old_path = None
-
-    def __enter__(self):
-        self.old_path = sys.path
-        sys.path = self.new_path + self.extra_sys_path
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        sys.path = self.old_path
+    old_path = sys.path
+    try:
+        new_path = _get_venv_syspath(venv_executable)
+        extra_sys_path = [] if extra_sys_path is None else extra_sys_path
+        sys.path = new_path + extra_sys_path
+        yield
+    finally:
+        sys.path = old_path


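For reference, the generator-based rewrite gets its exception safety from `try`/`finally`: `sys.path` is restored even if the wrapped code raises, and because `_get_venv_syspath` is now called inside the `try`, a failure there also leaves the caller's path untouched. Below is a minimal self-contained sketch of the same pattern; the `/tmp/demo_site_packages` entry is a made-up placeholder, and unlike the PR's version, which replaces `sys.path` with the venv's entries, this sketch merely prepends for demonstration:

import sys
from contextlib import contextmanager

@contextmanager
def prepend_sys_path(entries):
    """Temporarily prepend entries to sys.path, restoring it on exit."""
    old_path = sys.path
    try:
        sys.path = list(entries) + old_path
        yield
    finally:
        # Runs on normal exit and on exception alike.
        sys.path = old_path

with prepend_sys_path(["/tmp/demo_site_packages"]):
    assert sys.path[0] == "/tmp/demo_site_packages"
assert "/tmp/demo_site_packages" not in sys.path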
@@ -122,7 +99,7 @@ def __init__(
         self._venv_executable = venv_executable

     def __call__(self, *args_serialized, **kwargs_serialized):
-        with SwapSysPath(
+        with swap_sys_path(
             venv_executable=self._venv_executable, extra_sys_path=self._extra_sys_path
         ):
             try:
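The diff only shows the `__call__` side of `PickleableFunctionWithPipeIO`, but the shape of the protocol is visible from the parent's `recv_bytes` call and the error-tuple check in the last hunk. A hypothetical, much-simplified sketch of such a pipe-backed callable; the class name and the exact error-tuple shape here are illustrative assumptions, not the repository's actual implementation:

import cloudpickle

class PipeCallableSketch:
    """Illustrative stand-in for PickleableFunctionWithPipeIO."""

    def __init__(self, func, child_conn):
        # Pre-serialize the function with cloudpickle so this instance
        # survives the stdlib pickling that "spawn" applies to Process args.
        self._func_bytes = cloudpickle.dumps(func)
        self._child_conn = child_conn

    def __call__(self, *args_serialized, **kwargs_serialized):
        func = cloudpickle.loads(self._func_bytes)
        args = [cloudpickle.loads(a) for a in args_serialized]
        kwargs = {k: cloudpickle.loads(v) for k, v in kwargs_serialized.items()}
        try:
            result = func(*args, **kwargs)
        except Exception as exc:
            # Mirror the parent-side check: a tuple whose first element is
            # an Exception signals failure (traceback payload assumed).
            result = (exc, "traceback placeholder")
        self._child_conn.send_bytes(cloudpickle.dumps(result))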
@@ -225,29 +202,33 @@ def TORCH_NODE(default: Matrix) -> Matrix:
     def decorator(func):
         @wraps(func)
         def wrapper(*args, **kwargs):
-            parent_conn, child_conn = multiprocessing.Pipe()
+            # Generate a new multiprocessing context for the parent process in "spawn" mode
+            parent_mp_context = multiprocessing.get_context("spawn")
+            parent_conn, child_conn = parent_mp_context.Pipe()
             # Serialize function arguments using cloudpickle
             args_serialized = [cloudpickle.dumps(arg) for arg in args]
             kwargs_serialized = {
                 key: cloudpickle.dumps(value) for key, value in kwargs.items()
             }
             pickleable_func_with_pipe = PickleableFunctionWithPipeIO(
                 func, child_conn, venv_executable
             )
-            # Start the context manager that will change the executable used by multiprocessing
-            with MultiprocessingExecutableContextManager(venv_executable):
-                # Create a new process that will run the Python code
-                process = multiprocessing.Process(
-                    target=pickleable_func_with_pipe,
-                    args=args_serialized,
-                    kwargs=kwargs_serialized,
-                )
-                # Start the process
-                process.start()
-                # Fetch result from the child process
-                serialized_result = parent_conn.recv_bytes()
-                # Wait for the process to finish
-                process.join()
+            # Create a new multiprocessing context for the child process in "spawn" mode
+            # while setting its executable to the virtual environment python executable
+            child_mp_context = multiprocessing.get_context("spawn")
+            child_mp_context.set_executable(venv_executable)
+            # Create a new process that will run the Python code
+            process = child_mp_context.Process(
+                target=pickleable_func_with_pipe,
+                args=args_serialized,
+                kwargs=kwargs_serialized,
+            )
+            # Start the process
+            process.start()
+            # Fetch result from the child process
+            serialized_result = parent_conn.recv_bytes()
+            # Wait for the process to finish
+            process.join()
             # Check if the process sent an exception with a traceback
             result = cloudpickle.loads(serialized_result)
             if isinstance(result, tuple) and isinstance(result[0], Exception):
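For orientation: `multiprocessing.get_context("spawn")` returns a context object exposing the module's API, so the "spawn" choice is scoped to that context instead of being forced globally via `set_start_method(..., force=True)`, which is what made the old save/restore dance necessary. A minimal runnable sketch of the two-context pattern follows; the venv path is a placeholder, and note that, at least in CPython, `set_executable` on a context still delegates to a process-wide setting under the hood, so the main simplification is dropping the start-method bookkeeping:

import multiprocessing
import sys

def child_job(conn):
    # Report back which interpreter actually ran this child process.
    conn.send_bytes(sys.executable.encode())
    conn.close()

if __name__ == "__main__":
    # Parent-side "spawn" context: scoped to this call site, so the global
    # default start method (e.g. "fork" on Linux) is never force-overridden
    # and never needs restoring afterwards.
    parent_ctx = multiprocessing.get_context("spawn")
    parent_conn, child_conn = parent_ctx.Pipe()

    # Child-side "spawn" context; pointing it at a venv interpreter would
    # look like this (placeholder path, commented out to keep this runnable):
    child_ctx = multiprocessing.get_context("spawn")
    # child_ctx.set_executable("/path/to/venv/bin/python")

    process = child_ctx.Process(target=child_job, args=(child_conn,))
    process.start()
    print("child ran under:", parent_conn.recv_bytes().decode())
    process.join()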