diff --git a/rapids_dask_dependency/__init__.py b/rapids_dask_dependency/__init__.py index 07ae32d..7155d05 100644 --- a/rapids_dask_dependency/__init__.py +++ b/rapids_dask_dependency/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) 2024, NVIDIA CORPORATION. -from .dask_loader import DaskLoader +from .dask_loader import DaskFinder -DaskLoader.install() +DaskFinder.install() diff --git a/rapids_dask_dependency/dask_loader.py b/rapids_dask_dependency/dask_loader.py index 0a4aba4..da61d3e 100644 --- a/rapids_dask_dependency/dask_loader.py +++ b/rapids_dask_dependency/dask_loader.py @@ -7,40 +7,42 @@ import sys from contextlib import contextmanager -from rapids_dask_dependency.utils import patch_warning_stacklevel, update_spec +from rapids_dask_dependency.utils import patch_warning_stacklevel -class DaskLoader(importlib.abc.MetaPathFinder, importlib.abc.Loader): - def __init__(self): - self._blocklist = set() +class DaskLoader(importlib.machinery.SourceFileLoader): + def __init__(self, fullname, path, finder): + super().__init__(fullname, path) + self._finder = finder def create_module(self, spec): - if spec.name.startswith("dask") or spec.name.startswith("distributed"): - with self.disable(spec.name): - try: - # Absolute import is important here to avoid shadowing the real dask - # and distributed modules in sys.modules. Bad things will happen if - # we use relative imports here. - proxy = importlib.import_module( - f"rapids_dask_dependency.patches.{spec.name}" - ) - if hasattr(proxy, "load_module"): - return proxy.load_module(spec) - except ModuleNotFoundError: - pass + with self._finder.disable(spec.name): + try: + # Absolute import is important here to avoid shadowing the real dask + # and distributed modules in sys.modules. Bad things will happen if + # we use relative imports here. + proxy = importlib.import_module( + f"rapids_dask_dependency.patches.{spec.name}" + ) + if hasattr(proxy, "load_module"): + return proxy.load_module() + except ModuleNotFoundError: + pass - # Three extra stack frames: 1) DaskLoader.create_module, - # 2) importlib.import_module, and 3) the patched warnings function (not - # including the internal frames, which warnings ignores). - with patch_warning_stacklevel(3): - mod = importlib.import_module(spec.name) - - update_spec(spec, mod.__spec__) - return mod + # Three extra stack frames: 1) DaskLoader.create_module, + # 2) importlib.import_module, and 3) the patched warnings function (not + # including the internal frames, which warnings ignores). + with patch_warning_stacklevel(3): + return importlib.import_module(spec.name) def exec_module(self, _): pass + +class DaskFinder(importlib.abc.MetaPathFinder): + def __init__(self): + self._blocklist = set() + @contextmanager def disable(self, name): # This is a context manager that prevents this finder from intercepting calls to @@ -62,14 +64,18 @@ def find_spec(self, fullname: str, _, __=None): or fullname.startswith("dask.") or fullname.startswith("distributed.") ): - return importlib.machinery.ModuleSpec( + with self.disable(fullname): + if (real_spec := importlib.util.find_spec(fullname)) is None: + return None + spec = importlib.machinery.ModuleSpec( name=fullname, - loader=self, - # Set these parameters dynamically in create_module - origin=None, - loader_state=None, - is_package=True, + loader=DaskLoader(fullname, real_spec.origin, self), + origin=real_spec.origin, + loader_state=real_spec.loader_state, + is_package=real_spec.submodule_search_locations is not None, ) + spec.submodule_search_locations = real_spec.submodule_search_locations + return spec return None @classmethod diff --git a/rapids_dask_dependency/importer.py b/rapids_dask_dependency/importer.py deleted file mode 100644 index 613a808..0000000 --- a/rapids_dask_dependency/importer.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. - -import importlib -import importlib.util -from abc import abstractmethod - -from rapids_dask_dependency.utils import patch_warning_stacklevel, update_spec - - -class BaseImporter: - @abstractmethod - def load_module(self, spec): - pass - - -class MonkeyPatchImporter(BaseImporter): - """The base importer for modules that are monkey-patched.""" - - def __init__(self, name, patch_func): - self.name = name.replace("rapids_dask_dependency.patches.", "") - self.patch_func = patch_func - - def load_module(self, spec): - # Four extra stack frames: 1) DaskLoader.create_module, 2) - # MonkeyPatchImporter.load_module, 3) importlib.import_module, and 4) the - # patched warnings function (not including the internal frames, which warnings - # ignores). - with patch_warning_stacklevel(4): - mod = importlib.import_module(self.name) - self.patch_func(mod) - update_spec(spec, mod.__spec__) - mod._rapids_patched = True - return mod - - -class VendoredImporter(BaseImporter): - """The base importer for vendored modules.""" - - # Vendored files use a standard prefix to avoid name collisions. - default_prefix = "__rdd_patch_" - - def __init__(self, module): - self.real_module_name = module.replace("rapids_dask_dependency.patches.", "") - module_parts = module.split(".") - module_parts[-1] = self.default_prefix + module_parts[-1] - self.vendored_module_name = ".".join(module_parts) - - def load_module(self, spec): - vendored_module = importlib.import_module(self.vendored_module_name) - # At this stage the module loader must have been disabled for this module, so we - # can access the original module. We don't want to actually import it, we just - # want enough information on it to update the spec. - original_spec = importlib.util.find_spec(self.real_module_name) - update_spec(spec, original_spec) - return vendored_module diff --git a/rapids_dask_dependency/loaders.py b/rapids_dask_dependency/loaders.py new file mode 100644 index 0000000..6cf0c97 --- /dev/null +++ b/rapids_dask_dependency/loaders.py @@ -0,0 +1,39 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +import importlib +import importlib.util + +from rapids_dask_dependency.utils import patch_warning_stacklevel + +# Vendored files use a standard prefix to avoid name collisions. +DEFAULT_VENDORED_PREFIX = "__rdd_patch_" + + +def make_monkey_patch_loader(name, patch_func): + """Create a loader for monkey-patched modules.""" + + def load_module(): + # Four extra stack frames: 1) DaskLoader.create_module, 2) + # load_module, 3) importlib.import_module, and 4) the patched warnings function. + with patch_warning_stacklevel(4): + mod = importlib.import_module( + name.replace("rapids_dask_dependency.patches.", "") + ) + patch_func(mod) + mod._rapids_patched = True + return mod + + return load_module + + +def make_vendored_loader(name): + """Create a loader for vendored modules.""" + + def load_module(): + parts = name.split(".") + parts[-1] = DEFAULT_VENDORED_PREFIX + parts[-1] + mod = importlib.import_module(".".join(parts)) + mod._rapids_vendored = True + return mod + + return load_module diff --git a/rapids_dask_dependency/patches/dask/__init__.py b/rapids_dask_dependency/patches/dask/__init__.py index 5f7dc38..52fc400 100644 --- a/rapids_dask_dependency/patches/dask/__init__.py +++ b/rapids_dask_dependency/patches/dask/__init__.py @@ -1,6 +1,5 @@ # Copyright (c) 2024, NVIDIA CORPORATION. -from rapids_dask_dependency.importer import MonkeyPatchImporter +from rapids_dask_dependency.loaders import make_monkey_patch_loader -_importer = MonkeyPatchImporter(__name__, lambda _: None) -load_module = _importer.load_module +load_module = make_monkey_patch_loader(__name__, lambda _: None) diff --git a/rapids_dask_dependency/patches/dask/dataframe/accessor.py b/rapids_dask_dependency/patches/dask/dataframe/accessor.py index 1968f97..5c4864b 100644 --- a/rapids_dask_dependency/patches/dask/dataframe/accessor.py +++ b/rapids_dask_dependency/patches/dask/dataframe/accessor.py @@ -1,13 +1,12 @@ # Copyright (c) 2024, NVIDIA CORPORATION. import sys +# Currently vendoring this module due to https://github.com/dask/dask/pull/11035 if sys.version_info >= (3, 11, 9): from dask import __version__ from packaging.version import Version if Version(__version__) < Version("2024.4.1"): - from rapids_dask_dependency.importer import VendoredImporter + from rapids_dask_dependency.loaders import make_vendored_loader - # Currently vendoring this module due to https://github.com/dask/dask/pull/11035 - _importer = VendoredImporter(__name__) - load_module = _importer.load_module + load_module = make_vendored_loader(__name__) diff --git a/rapids_dask_dependency/patches/distributed/__init__.py b/rapids_dask_dependency/patches/distributed/__init__.py index 5f7dc38..52fc400 100644 --- a/rapids_dask_dependency/patches/distributed/__init__.py +++ b/rapids_dask_dependency/patches/distributed/__init__.py @@ -1,6 +1,5 @@ # Copyright (c) 2024, NVIDIA CORPORATION. -from rapids_dask_dependency.importer import MonkeyPatchImporter +from rapids_dask_dependency.loaders import make_monkey_patch_loader -_importer = MonkeyPatchImporter(__name__, lambda _: None) -load_module = _importer.load_module +load_module = make_monkey_patch_loader(__name__, lambda _: None) diff --git a/rapids_dask_dependency/utils.py b/rapids_dask_dependency/utils.py index 36d35bb..3adc184 100644 --- a/rapids_dask_dependency/utils.py +++ b/rapids_dask_dependency/utils.py @@ -24,13 +24,3 @@ def patch_warning_stacklevel(level): warnings.warn = _make_warning_func(level) yield warnings.warn = previous_warn - - -# Note: The Python documentation does not make it clear whether we're guaranteed that -# spec is not a copy of the original spec, but that is the case for now. We need to -# assign this because the spec is used to update module attributes after it is -# initialized by create_module. -def update_spec(spec, original_spec): - spec.origin = original_spec.origin - spec.submodule_search_locations = original_spec.submodule_search_locations - return spec diff --git a/tests/test_patch.py b/tests/test_patch.py index 4c611fd..586cdce 100644 --- a/tests/test_patch.py +++ b/tests/test_patch.py @@ -1,8 +1,11 @@ import contextlib +import subprocess import tempfile from functools import wraps from multiprocessing import Process +import pytest + def run_test_in_subprocess(func): def redirect_stdout_stderr(func, stdout, stderr, *args, **kwargs): @@ -49,3 +52,53 @@ def test_distributed(): import distributed assert hasattr(distributed, "_rapids_patched") + + +@pytest.mark.parametrize("python_version", [(3, 11, 9), (3, 11, 8)]) +@run_test_in_subprocess +def test_dask_accessor(python_version): + import sys + + import dask + + # Simulate the version of Python and Dask needed to trigger vendoring of the + # accessor module. + sys.version_info = python_version + dask.__version__ = "2023.4.1" + + from dask.dataframe import accessor + + assert (hasattr(accessor, "_rapids_vendored")) == (python_version >= (3, 11, 9)) + + +def test_dask_cli(): + try: + subprocess.run(["dask", "--help"], capture_output=True, check=True) + except subprocess.CalledProcessError as e: + print(e.stdout.decode()) + print(e.stderr.decode()) + raise + + +def test_dask_as_module(): + try: + subprocess.run( + ["python", "-m", "dask", "--help"], capture_output=True, check=True + ) + except subprocess.CalledProcessError as e: + print(e.stdout.decode()) + print(e.stderr.decode()) + raise + + +def test_distributed_cli_dask_spec_as_module(): + try: + subprocess.run( + ["python", "-m", "distributed.cli.dask_spec", "--help"], + capture_output=True, + check=True, + ) + except subprocess.CalledProcessError as e: + print(e.stdout.decode()) + print(e.stderr.decode()) + raise