Skip to content

Commit

Permalink
Merge pull request #47 from rapidsai/branch-24.06
Browse files Browse the repository at this point in the history
Forward-merge branch-24.06 into branch-24.08
  • Loading branch information
GPUtester authored May 21, 2024
2 parents aabe171 + ea774a5 commit 2d862cc
Show file tree
Hide file tree
Showing 9 changed files with 138 additions and 108 deletions.
4 changes: 2 additions & 2 deletions rapids_dask_dependency/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from .dask_loader import DaskLoader
from .dask_loader import DaskFinder

DaskLoader.install()
DaskFinder.install()
68 changes: 37 additions & 31 deletions rapids_dask_dependency/dask_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,40 +7,42 @@
import sys
from contextlib import contextmanager

from rapids_dask_dependency.utils import patch_warning_stacklevel, update_spec
from rapids_dask_dependency.utils import patch_warning_stacklevel


class DaskLoader(importlib.abc.MetaPathFinder, importlib.abc.Loader):
def __init__(self):
self._blocklist = set()
class DaskLoader(importlib.machinery.SourceFileLoader):
def __init__(self, fullname, path, finder):
super().__init__(fullname, path)
self._finder = finder

def create_module(self, spec):
if spec.name.startswith("dask") or spec.name.startswith("distributed"):
with self.disable(spec.name):
try:
# Absolute import is important here to avoid shadowing the real dask
# and distributed modules in sys.modules. Bad things will happen if
# we use relative imports here.
proxy = importlib.import_module(
f"rapids_dask_dependency.patches.{spec.name}"
)
if hasattr(proxy, "load_module"):
return proxy.load_module(spec)
except ModuleNotFoundError:
pass
with self._finder.disable(spec.name):
try:
# Absolute import is important here to avoid shadowing the real dask
# and distributed modules in sys.modules. Bad things will happen if
# we use relative imports here.
proxy = importlib.import_module(
f"rapids_dask_dependency.patches.{spec.name}"
)
if hasattr(proxy, "load_module"):
return proxy.load_module()
except ModuleNotFoundError:
pass

# Three extra stack frames: 1) DaskLoader.create_module,
# 2) importlib.import_module, and 3) the patched warnings function (not
# including the internal frames, which warnings ignores).
with patch_warning_stacklevel(3):
mod = importlib.import_module(spec.name)

update_spec(spec, mod.__spec__)
return mod
# Three extra stack frames: 1) DaskLoader.create_module,
# 2) importlib.import_module, and 3) the patched warnings function (not
# including the internal frames, which warnings ignores).
with patch_warning_stacklevel(3):
return importlib.import_module(spec.name)

def exec_module(self, _):
pass


class DaskFinder(importlib.abc.MetaPathFinder):
def __init__(self):
self._blocklist = set()

@contextmanager
def disable(self, name):
# This is a context manager that prevents this finder from intercepting calls to
Expand All @@ -62,14 +64,18 @@ def find_spec(self, fullname: str, _, __=None):
or fullname.startswith("dask.")
or fullname.startswith("distributed.")
):
return importlib.machinery.ModuleSpec(
with self.disable(fullname):
if (real_spec := importlib.util.find_spec(fullname)) is None:
return None
spec = importlib.machinery.ModuleSpec(
name=fullname,
loader=self,
# Set these parameters dynamically in create_module
origin=None,
loader_state=None,
is_package=True,
loader=DaskLoader(fullname, real_spec.origin, self),
origin=real_spec.origin,
loader_state=real_spec.loader_state,
is_package=real_spec.submodule_search_locations is not None,
)
spec.submodule_search_locations = real_spec.submodule_search_locations
return spec
return None

@classmethod
Expand Down
55 changes: 0 additions & 55 deletions rapids_dask_dependency/importer.py

This file was deleted.

39 changes: 39 additions & 0 deletions rapids_dask_dependency/loaders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

import importlib
import importlib.util

from rapids_dask_dependency.utils import patch_warning_stacklevel

# Vendored files use a standard prefix to avoid name collisions.
DEFAULT_VENDORED_PREFIX = "__rdd_patch_"


def make_monkey_patch_loader(name, patch_func):
"""Create a loader for monkey-patched modules."""

def load_module():
# Four extra stack frames: 1) DaskLoader.create_module, 2)
# load_module, 3) importlib.import_module, and 4) the patched warnings function.
with patch_warning_stacklevel(4):
mod = importlib.import_module(
name.replace("rapids_dask_dependency.patches.", "")
)
patch_func(mod)
mod._rapids_patched = True
return mod

return load_module


def make_vendored_loader(name):
"""Create a loader for vendored modules."""

def load_module():
parts = name.split(".")
parts[-1] = DEFAULT_VENDORED_PREFIX + parts[-1]
mod = importlib.import_module(".".join(parts))
mod._rapids_vendored = True
return mod

return load_module
5 changes: 2 additions & 3 deletions rapids_dask_dependency/patches/dask/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from rapids_dask_dependency.importer import MonkeyPatchImporter
from rapids_dask_dependency.loaders import make_monkey_patch_loader

_importer = MonkeyPatchImporter(__name__, lambda _: None)
load_module = _importer.load_module
load_module = make_monkey_patch_loader(__name__, lambda _: None)
7 changes: 3 additions & 4 deletions rapids_dask_dependency/patches/dask/dataframe/accessor.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
import sys

# Currently vendoring this module due to https://github.com/dask/dask/pull/11035
if sys.version_info >= (3, 11, 9):
from dask import __version__
from packaging.version import Version

if Version(__version__) < Version("2024.4.1"):
from rapids_dask_dependency.importer import VendoredImporter
from rapids_dask_dependency.loaders import make_vendored_loader

# Currently vendoring this module due to https://github.com/dask/dask/pull/11035
_importer = VendoredImporter(__name__)
load_module = _importer.load_module
load_module = make_vendored_loader(__name__)
5 changes: 2 additions & 3 deletions rapids_dask_dependency/patches/distributed/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from rapids_dask_dependency.importer import MonkeyPatchImporter
from rapids_dask_dependency.loaders import make_monkey_patch_loader

_importer = MonkeyPatchImporter(__name__, lambda _: None)
load_module = _importer.load_module
load_module = make_monkey_patch_loader(__name__, lambda _: None)
10 changes: 0 additions & 10 deletions rapids_dask_dependency/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,3 @@ def patch_warning_stacklevel(level):
warnings.warn = _make_warning_func(level)
yield
warnings.warn = previous_warn


# Note: The Python documentation does not make it clear whether we're guaranteed that
# spec is not a copy of the original spec, but that is the case for now. We need to
# assign this because the spec is used to update module attributes after it is
# initialized by create_module.
def update_spec(spec, original_spec):
spec.origin = original_spec.origin
spec.submodule_search_locations = original_spec.submodule_search_locations
return spec
53 changes: 53 additions & 0 deletions tests/test_patch.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import contextlib
import subprocess
import tempfile
from functools import wraps
from multiprocessing import Process

import pytest


def run_test_in_subprocess(func):
def redirect_stdout_stderr(func, stdout, stderr, *args, **kwargs):
Expand Down Expand Up @@ -49,3 +52,53 @@ def test_distributed():
import distributed

assert hasattr(distributed, "_rapids_patched")


@pytest.mark.parametrize("python_version", [(3, 11, 9), (3, 11, 8)])
@run_test_in_subprocess
def test_dask_accessor(python_version):
import sys

import dask

# Simulate the version of Python and Dask needed to trigger vendoring of the
# accessor module.
sys.version_info = python_version
dask.__version__ = "2023.4.1"

from dask.dataframe import accessor

assert (hasattr(accessor, "_rapids_vendored")) == (python_version >= (3, 11, 9))


def test_dask_cli():
try:
subprocess.run(["dask", "--help"], capture_output=True, check=True)
except subprocess.CalledProcessError as e:
print(e.stdout.decode())
print(e.stderr.decode())
raise


def test_dask_as_module():
try:
subprocess.run(
["python", "-m", "dask", "--help"], capture_output=True, check=True
)
except subprocess.CalledProcessError as e:
print(e.stdout.decode())
print(e.stderr.decode())
raise


def test_distributed_cli_dask_spec_as_module():
try:
subprocess.run(
["python", "-m", "distributed.cli.dask_spec", "--help"],
capture_output=True,
check=True,
)
except subprocess.CalledProcessError as e:
print(e.stdout.decode())
print(e.stderr.decode())
raise

0 comments on commit 2d862cc

Please sign in to comment.