From ea774a5b831e048b48b52b056f88d4171cf85458 Mon Sep 17 00:00:00 2001 From: Vyas Ramasubramani Date: Tue, 21 May 2024 14:18:19 -0700 Subject: [PATCH] Fix CLI invocations (#46) Module loads via `python -m *` take a different code path than direct usage of the dask/distributed CLI. The finder is still run to get the spec, but instead of the loader's `[create|exec]_module` methods being called next, the loader's `get_code` method is called to simply load the `__main__` module as a text file to be executed directly. To support this use case, this PR splits up the finder and the loader. The loader is now a `SourceFileLoader` that is instantiated once per file to be loaded, allowing it to support loading the text of that file on demand by command-line invocations. In order for this to work, the spec needs to be properly configured before the finder's `find_spec` call returns. The `update_spec` function I had in place before was a bit of a hack. This PR replaces that with a call to `importlib.util.find_spec` on the real module to get the full spec right up front. This allowed me to do quite a bit of other simplification; since the spec is now correct up front, I removed `update_spec` entirely and converted the loaders from stateful classes that updated the spec on load to simple factories since the spec is already guaranteed to be correct. Finally, I added some new tests of the CLI, module usage, and the accessor module. 
Closes #45 Closes https://github.com/rapidsai/docker/issues/668 Authors: - Vyas Ramasubramani (https://github.com/vyasr) Approvers: - Richard (Rick) Zamora (https://github.com/rjzamora) URL: https://github.com/rapidsai/rapids-dask-dependency/pull/46 --- rapids_dask_dependency/__init__.py | 4 +- rapids_dask_dependency/dask_loader.py | 68 ++++++++++--------- rapids_dask_dependency/importer.py | 55 --------------- rapids_dask_dependency/loaders.py | 39 +++++++++++ .../patches/dask/__init__.py | 5 +- .../patches/dask/dataframe/accessor.py | 7 +- .../patches/distributed/__init__.py | 5 +- rapids_dask_dependency/utils.py | 10 --- tests/test_patch.py | 53 +++++++++++++++ 9 files changed, 138 insertions(+), 108 deletions(-) delete mode 100644 rapids_dask_dependency/importer.py create mode 100644 rapids_dask_dependency/loaders.py diff --git a/rapids_dask_dependency/__init__.py b/rapids_dask_dependency/__init__.py index 07ae32d..7155d05 100644 --- a/rapids_dask_dependency/__init__.py +++ b/rapids_dask_dependency/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) 2024, NVIDIA CORPORATION. 
-from .dask_loader import DaskLoader +from .dask_loader import DaskFinder -DaskLoader.install() +DaskFinder.install() diff --git a/rapids_dask_dependency/dask_loader.py b/rapids_dask_dependency/dask_loader.py index 0a4aba4..da61d3e 100644 --- a/rapids_dask_dependency/dask_loader.py +++ b/rapids_dask_dependency/dask_loader.py @@ -7,40 +7,42 @@ import sys from contextlib import contextmanager -from rapids_dask_dependency.utils import patch_warning_stacklevel, update_spec +from rapids_dask_dependency.utils import patch_warning_stacklevel -class DaskLoader(importlib.abc.MetaPathFinder, importlib.abc.Loader): - def __init__(self): - self._blocklist = set() +class DaskLoader(importlib.machinery.SourceFileLoader): + def __init__(self, fullname, path, finder): + super().__init__(fullname, path) + self._finder = finder def create_module(self, spec): - if spec.name.startswith("dask") or spec.name.startswith("distributed"): - with self.disable(spec.name): - try: - # Absolute import is important here to avoid shadowing the real dask - # and distributed modules in sys.modules. Bad things will happen if - # we use relative imports here. - proxy = importlib.import_module( - f"rapids_dask_dependency.patches.{spec.name}" - ) - if hasattr(proxy, "load_module"): - return proxy.load_module(spec) - except ModuleNotFoundError: - pass + with self._finder.disable(spec.name): + try: + # Absolute import is important here to avoid shadowing the real dask + # and distributed modules in sys.modules. Bad things will happen if + # we use relative imports here. + proxy = importlib.import_module( + f"rapids_dask_dependency.patches.{spec.name}" + ) + if hasattr(proxy, "load_module"): + return proxy.load_module() + except ModuleNotFoundError: + pass - # Three extra stack frames: 1) DaskLoader.create_module, - # 2) importlib.import_module, and 3) the patched warnings function (not - # including the internal frames, which warnings ignores). 
- with patch_warning_stacklevel(3): - mod = importlib.import_module(spec.name) - - update_spec(spec, mod.__spec__) - return mod + # Three extra stack frames: 1) DaskLoader.create_module, + # 2) importlib.import_module, and 3) the patched warnings function (not + # including the internal frames, which warnings ignores). + with patch_warning_stacklevel(3): + return importlib.import_module(spec.name) def exec_module(self, _): pass + +class DaskFinder(importlib.abc.MetaPathFinder): + def __init__(self): + self._blocklist = set() + @contextmanager def disable(self, name): # This is a context manager that prevents this finder from intercepting calls to @@ -62,14 +64,18 @@ def find_spec(self, fullname: str, _, __=None): or fullname.startswith("dask.") or fullname.startswith("distributed.") ): - return importlib.machinery.ModuleSpec( + with self.disable(fullname): + if (real_spec := importlib.util.find_spec(fullname)) is None: + return None + spec = importlib.machinery.ModuleSpec( name=fullname, - loader=self, - # Set these parameters dynamically in create_module - origin=None, - loader_state=None, - is_package=True, + loader=DaskLoader(fullname, real_spec.origin, self), + origin=real_spec.origin, + loader_state=real_spec.loader_state, + is_package=real_spec.submodule_search_locations is not None, ) + spec.submodule_search_locations = real_spec.submodule_search_locations + return spec return None @classmethod diff --git a/rapids_dask_dependency/importer.py b/rapids_dask_dependency/importer.py deleted file mode 100644 index 613a808..0000000 --- a/rapids_dask_dependency/importer.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. 
- -import importlib -import importlib.util -from abc import abstractmethod - -from rapids_dask_dependency.utils import patch_warning_stacklevel, update_spec - - -class BaseImporter: - @abstractmethod - def load_module(self, spec): - pass - - -class MonkeyPatchImporter(BaseImporter): - """The base importer for modules that are monkey-patched.""" - - def __init__(self, name, patch_func): - self.name = name.replace("rapids_dask_dependency.patches.", "") - self.patch_func = patch_func - - def load_module(self, spec): - # Four extra stack frames: 1) DaskLoader.create_module, 2) - # MonkeyPatchImporter.load_module, 3) importlib.import_module, and 4) the - # patched warnings function (not including the internal frames, which warnings - # ignores). - with patch_warning_stacklevel(4): - mod = importlib.import_module(self.name) - self.patch_func(mod) - update_spec(spec, mod.__spec__) - mod._rapids_patched = True - return mod - - -class VendoredImporter(BaseImporter): - """The base importer for vendored modules.""" - - # Vendored files use a standard prefix to avoid name collisions. - default_prefix = "__rdd_patch_" - - def __init__(self, module): - self.real_module_name = module.replace("rapids_dask_dependency.patches.", "") - module_parts = module.split(".") - module_parts[-1] = self.default_prefix + module_parts[-1] - self.vendored_module_name = ".".join(module_parts) - - def load_module(self, spec): - vendored_module = importlib.import_module(self.vendored_module_name) - # At this stage the module loader must have been disabled for this module, so we - # can access the original module. We don't want to actually import it, we just - # want enough information on it to update the spec. 
- original_spec = importlib.util.find_spec(self.real_module_name) - update_spec(spec, original_spec) - return vendored_module diff --git a/rapids_dask_dependency/loaders.py b/rapids_dask_dependency/loaders.py new file mode 100644 index 0000000..6cf0c97 --- /dev/null +++ b/rapids_dask_dependency/loaders.py @@ -0,0 +1,39 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +import importlib +import importlib.util + +from rapids_dask_dependency.utils import patch_warning_stacklevel + +# Vendored files use a standard prefix to avoid name collisions. +DEFAULT_VENDORED_PREFIX = "__rdd_patch_" + + +def make_monkey_patch_loader(name, patch_func): + """Create a loader for monkey-patched modules.""" + + def load_module(): + # Four extra stack frames: 1) DaskLoader.create_module, 2) + # load_module, 3) importlib.import_module, and 4) the patched warnings function. + with patch_warning_stacklevel(4): + mod = importlib.import_module( + name.replace("rapids_dask_dependency.patches.", "") + ) + patch_func(mod) + mod._rapids_patched = True + return mod + + return load_module + + +def make_vendored_loader(name): + """Create a loader for vendored modules.""" + + def load_module(): + parts = name.split(".") + parts[-1] = DEFAULT_VENDORED_PREFIX + parts[-1] + mod = importlib.import_module(".".join(parts)) + mod._rapids_vendored = True + return mod + + return load_module diff --git a/rapids_dask_dependency/patches/dask/__init__.py b/rapids_dask_dependency/patches/dask/__init__.py index 5f7dc38..52fc400 100644 --- a/rapids_dask_dependency/patches/dask/__init__.py +++ b/rapids_dask_dependency/patches/dask/__init__.py @@ -1,6 +1,5 @@ # Copyright (c) 2024, NVIDIA CORPORATION. 
-from rapids_dask_dependency.importer import MonkeyPatchImporter +from rapids_dask_dependency.loaders import make_monkey_patch_loader -_importer = MonkeyPatchImporter(__name__, lambda _: None) -load_module = _importer.load_module +load_module = make_monkey_patch_loader(__name__, lambda _: None) diff --git a/rapids_dask_dependency/patches/dask/dataframe/accessor.py b/rapids_dask_dependency/patches/dask/dataframe/accessor.py index 1968f97..5c4864b 100644 --- a/rapids_dask_dependency/patches/dask/dataframe/accessor.py +++ b/rapids_dask_dependency/patches/dask/dataframe/accessor.py @@ -1,13 +1,12 @@ # Copyright (c) 2024, NVIDIA CORPORATION. import sys +# Currently vendoring this module due to https://github.com/dask/dask/pull/11035 if sys.version_info >= (3, 11, 9): from dask import __version__ from packaging.version import Version if Version(__version__) < Version("2024.4.1"): - from rapids_dask_dependency.importer import VendoredImporter + from rapids_dask_dependency.loaders import make_vendored_loader - # Currently vendoring this module due to https://github.com/dask/dask/pull/11035 - _importer = VendoredImporter(__name__) - load_module = _importer.load_module + load_module = make_vendored_loader(__name__) diff --git a/rapids_dask_dependency/patches/distributed/__init__.py b/rapids_dask_dependency/patches/distributed/__init__.py index 5f7dc38..52fc400 100644 --- a/rapids_dask_dependency/patches/distributed/__init__.py +++ b/rapids_dask_dependency/patches/distributed/__init__.py @@ -1,6 +1,5 @@ # Copyright (c) 2024, NVIDIA CORPORATION. 
-from rapids_dask_dependency.importer import MonkeyPatchImporter +from rapids_dask_dependency.loaders import make_monkey_patch_loader -_importer = MonkeyPatchImporter(__name__, lambda _: None) -load_module = _importer.load_module +load_module = make_monkey_patch_loader(__name__, lambda _: None) diff --git a/rapids_dask_dependency/utils.py b/rapids_dask_dependency/utils.py index 36d35bb..3adc184 100644 --- a/rapids_dask_dependency/utils.py +++ b/rapids_dask_dependency/utils.py @@ -24,13 +24,3 @@ def patch_warning_stacklevel(level): warnings.warn = _make_warning_func(level) yield warnings.warn = previous_warn - - -# Note: The Python documentation does not make it clear whether we're guaranteed that -# spec is not a copy of the original spec, but that is the case for now. We need to -# assign this because the spec is used to update module attributes after it is -# initialized by create_module. -def update_spec(spec, original_spec): - spec.origin = original_spec.origin - spec.submodule_search_locations = original_spec.submodule_search_locations - return spec diff --git a/tests/test_patch.py b/tests/test_patch.py index 4c611fd..586cdce 100644 --- a/tests/test_patch.py +++ b/tests/test_patch.py @@ -1,8 +1,11 @@ import contextlib +import subprocess import tempfile from functools import wraps from multiprocessing import Process +import pytest + def run_test_in_subprocess(func): def redirect_stdout_stderr(func, stdout, stderr, *args, **kwargs): @@ -49,3 +52,53 @@ def test_distributed(): import distributed assert hasattr(distributed, "_rapids_patched") + + +@pytest.mark.parametrize("python_version", [(3, 11, 9), (3, 11, 8)]) +@run_test_in_subprocess +def test_dask_accessor(python_version): + import sys + + import dask + + # Simulate the version of Python and Dask needed to trigger vendoring of the + # accessor module. 
+ sys.version_info = python_version + dask.__version__ = "2023.4.1" + + from dask.dataframe import accessor + + assert (hasattr(accessor, "_rapids_vendored")) == (python_version >= (3, 11, 9)) + + +def test_dask_cli(): + try: + subprocess.run(["dask", "--help"], capture_output=True, check=True) + except subprocess.CalledProcessError as e: + print(e.stdout.decode()) + print(e.stderr.decode()) + raise + + +def test_dask_as_module(): + try: + subprocess.run( + ["python", "-m", "dask", "--help"], capture_output=True, check=True + ) + except subprocess.CalledProcessError as e: + print(e.stdout.decode()) + print(e.stderr.decode()) + raise + + +def test_distributed_cli_dask_spec_as_module(): + try: + subprocess.run( + ["python", "-m", "distributed.cli.dask_spec", "--help"], + capture_output=True, + check=True, + ) + except subprocess.CalledProcessError as e: + print(e.stdout.decode()) + print(e.stderr.decode()) + raise