Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forward-merge branch-24.06 into branch-24.08 #47

Merged
merged 1 commit into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions rapids_dask_dependency/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from .dask_loader import DaskLoader
from .dask_loader import DaskFinder

DaskLoader.install()
DaskFinder.install()
68 changes: 37 additions & 31 deletions rapids_dask_dependency/dask_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,40 +7,42 @@
import sys
from contextlib import contextmanager

from rapids_dask_dependency.utils import patch_warning_stacklevel, update_spec
from rapids_dask_dependency.utils import patch_warning_stacklevel


class DaskLoader(importlib.abc.MetaPathFinder, importlib.abc.Loader):
def __init__(self):
self._blocklist = set()
class DaskLoader(importlib.machinery.SourceFileLoader):
def __init__(self, fullname, path, finder):
super().__init__(fullname, path)
self._finder = finder

def create_module(self, spec):
if spec.name.startswith("dask") or spec.name.startswith("distributed"):
with self.disable(spec.name):
try:
# Absolute import is important here to avoid shadowing the real dask
# and distributed modules in sys.modules. Bad things will happen if
# we use relative imports here.
proxy = importlib.import_module(
f"rapids_dask_dependency.patches.{spec.name}"
)
if hasattr(proxy, "load_module"):
return proxy.load_module(spec)
except ModuleNotFoundError:
pass
with self._finder.disable(spec.name):
try:
# Absolute import is important here to avoid shadowing the real dask
# and distributed modules in sys.modules. Bad things will happen if
# we use relative imports here.
proxy = importlib.import_module(
f"rapids_dask_dependency.patches.{spec.name}"
)
if hasattr(proxy, "load_module"):
return proxy.load_module()
except ModuleNotFoundError:
pass

# Three extra stack frames: 1) DaskLoader.create_module,
# 2) importlib.import_module, and 3) the patched warnings function (not
# including the internal frames, which warnings ignores).
with patch_warning_stacklevel(3):
mod = importlib.import_module(spec.name)

update_spec(spec, mod.__spec__)
return mod
# Three extra stack frames: 1) DaskLoader.create_module,
# 2) importlib.import_module, and 3) the patched warnings function (not
# including the internal frames, which warnings ignores).
with patch_warning_stacklevel(3):
return importlib.import_module(spec.name)

def exec_module(self, _):
pass


class DaskFinder(importlib.abc.MetaPathFinder):
def __init__(self):
self._blocklist = set()

@contextmanager
def disable(self, name):
# This is a context manager that prevents this finder from intercepting calls to
Expand All @@ -62,14 +64,18 @@ def find_spec(self, fullname: str, _, __=None):
or fullname.startswith("dask.")
or fullname.startswith("distributed.")
):
return importlib.machinery.ModuleSpec(
with self.disable(fullname):
if (real_spec := importlib.util.find_spec(fullname)) is None:
return None
spec = importlib.machinery.ModuleSpec(
name=fullname,
loader=self,
# Set these parameters dynamically in create_module
origin=None,
loader_state=None,
is_package=True,
loader=DaskLoader(fullname, real_spec.origin, self),
origin=real_spec.origin,
loader_state=real_spec.loader_state,
is_package=real_spec.submodule_search_locations is not None,
)
spec.submodule_search_locations = real_spec.submodule_search_locations
return spec
return None

@classmethod
Expand Down
55 changes: 0 additions & 55 deletions rapids_dask_dependency/importer.py

This file was deleted.

39 changes: 39 additions & 0 deletions rapids_dask_dependency/loaders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

import importlib
import importlib.util

from rapids_dask_dependency.utils import patch_warning_stacklevel

# Vendored files use a standard prefix to avoid name collisions.
DEFAULT_VENDORED_PREFIX = "__rdd_patch_"


def make_monkey_patch_loader(name, patch_func):
    """Create a loader for monkey-patched modules.

    Returns a zero-argument ``load_module`` callable that imports the real
    module corresponding to *name*, applies *patch_func* to it, and marks it
    as patched.
    """

    def load_module():
        # Strip our patch-package prefix to recover the real module's name.
        real_name = name.replace("rapids_dask_dependency.patches.", "")
        # Four extra stack frames: 1) DaskLoader.create_module, 2)
        # load_module, 3) importlib.import_module, and 4) the patched
        # warnings function.
        with patch_warning_stacklevel(4):
            module = importlib.import_module(real_name)
        patch_func(module)
        # Marker attribute so callers/tests can detect that patching ran.
        module._rapids_patched = True
        return module

    return load_module


def make_vendored_loader(name):
    """Create a loader for vendored modules.

    Returns a zero-argument ``load_module`` callable that imports the
    vendored copy of *name* (the last dotted component prefixed with
    ``DEFAULT_VENDORED_PREFIX``) and marks it as vendored.
    """

    def load_module():
        # Rewrite the final component to the vendored file's prefixed name.
        *package_parts, leaf = name.split(".")
        vendored_name = ".".join(package_parts + [DEFAULT_VENDORED_PREFIX + leaf])
        module = importlib.import_module(vendored_name)
        # Marker attribute so callers/tests can detect the vendored copy.
        module._rapids_vendored = True
        return module

    return load_module
5 changes: 2 additions & 3 deletions rapids_dask_dependency/patches/dask/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from rapids_dask_dependency.importer import MonkeyPatchImporter
from rapids_dask_dependency.loaders import make_monkey_patch_loader

_importer = MonkeyPatchImporter(__name__, lambda _: None)
load_module = _importer.load_module
load_module = make_monkey_patch_loader(__name__, lambda _: None)
7 changes: 3 additions & 4 deletions rapids_dask_dependency/patches/dask/dataframe/accessor.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
import sys

# Currently vendoring this module due to https://github.com/dask/dask/pull/11035
if sys.version_info >= (3, 11, 9):
from dask import __version__
from packaging.version import Version

if Version(__version__) < Version("2024.4.1"):
from rapids_dask_dependency.importer import VendoredImporter
from rapids_dask_dependency.loaders import make_vendored_loader

# Currently vendoring this module due to https://github.com/dask/dask/pull/11035
_importer = VendoredImporter(__name__)
load_module = _importer.load_module
load_module = make_vendored_loader(__name__)
5 changes: 2 additions & 3 deletions rapids_dask_dependency/patches/distributed/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from rapids_dask_dependency.importer import MonkeyPatchImporter
from rapids_dask_dependency.loaders import make_monkey_patch_loader

_importer = MonkeyPatchImporter(__name__, lambda _: None)
load_module = _importer.load_module
load_module = make_monkey_patch_loader(__name__, lambda _: None)
10 changes: 0 additions & 10 deletions rapids_dask_dependency/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,3 @@ def patch_warning_stacklevel(level):
warnings.warn = _make_warning_func(level)
yield
warnings.warn = previous_warn


# Note: The Python documentation does not make it clear whether we're guaranteed that
# spec is not a copy of the original spec, but that is the case for now. We need to
# assign this because the spec is used to update module attributes after it is
# initialized by create_module.
def update_spec(spec, original_spec):
spec.origin = original_spec.origin
spec.submodule_search_locations = original_spec.submodule_search_locations
return spec
53 changes: 53 additions & 0 deletions tests/test_patch.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import contextlib
import subprocess
import tempfile
from functools import wraps
from multiprocessing import Process

import pytest


def run_test_in_subprocess(func):
def redirect_stdout_stderr(func, stdout, stderr, *args, **kwargs):
Expand Down Expand Up @@ -49,3 +52,53 @@ def test_distributed():
import distributed

assert hasattr(distributed, "_rapids_patched")


@pytest.mark.parametrize("python_version", [(3, 11, 9), (3, 11, 8)])
@run_test_in_subprocess
def test_dask_accessor(python_version):
    """Verify the accessor module is vendored only for the Python/Dask
    combination that requires it."""
    import sys

    import dask

    # Pretend we are running under the given Python and an old-enough Dask so
    # the patch machinery decides whether the accessor module must be vendored.
    sys.version_info = python_version
    dask.__version__ = "2023.4.1"

    from dask.dataframe import accessor

    expect_vendored = python_version >= (3, 11, 9)
    assert hasattr(accessor, "_rapids_vendored") == expect_vendored


def _assert_cli_help_runs(command):
    """Run *command* with ``--help`` and require success.

    On failure, echo the subprocess's captured stdout/stderr before
    re-raising so CI logs show the real error instead of just the
    CalledProcessError summary.
    """
    try:
        subprocess.run([*command, "--help"], capture_output=True, check=True)
    except subprocess.CalledProcessError as e:
        print(e.stdout.decode())
        print(e.stderr.decode())
        raise


def test_dask_cli():
    # The `dask` console-script entry point must still work via our loader.
    _assert_cli_help_runs(["dask"])


def test_dask_as_module():
    # `python -m dask` exercises dask.__main__ through the import machinery.
    _assert_cli_help_runs(["python", "-m", "dask"])


def test_distributed_cli_dask_spec_as_module():
    # dask_spec runs as a script; ensure the loader handles submodule CLIs.
    _assert_cli_help_runs(["python", "-m", "distributed.cli.dask_spec"])