Upgrade to latest Dask and scikit-learn (#1008)
* Remove dask-expr from CI

* Fixup

* GLM utils

* More

* Python versions

* Typo

* Use pip for numba Python 3.13

* Fix PolynomialFeatures docstring

* Avoid persisting Futures

Closes #1003

* xfail whiten test

* force_all_finite -> ensure_all_finite

* Adjust long name test

* Bump scikit-learn

* added skip

* docs, envs

* format

* get_tags

* multinomial

* check_n_features / feature_names

---------

Co-authored-by: Tom Augspurger <[email protected]>
Co-authored-by: Tom Augspurger <[email protected]>
3 people authored Feb 7, 2025
1 parent ed8a2b7 commit 9e55d05
Showing 27 changed files with 151 additions and 151 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yaml
@@ -8,5 +8,5 @@ jobs:
       - uses: actions/[email protected]
       - uses: actions/setup-python@v5
         with:
-          python-version: '3.9'
+          python-version: '3.12'
       - uses: pre-commit/[email protected]
7 changes: 2 additions & 5 deletions .github/workflows/tests.yaml
@@ -10,14 +10,12 @@ jobs:
       matrix:
         # os: ["windows-latest", "ubuntu-latest", "macos-latest"]
         os: ["ubuntu-latest"]
-        python-version: ["3.9", "3.10", "3.11"]
-        query-planning: [true, false]
+        python-version: ["3.10", "3.11", "3.12", "3.13"]

     env:
       PYTHON_VERSION: ${{ matrix.python-version }}
       PARALLEL: "true"
       COVERAGE: "true"
-      DASK_DATAFRAME__QUERY_PLANNING: ${{ matrix.query-planning }}

     steps:
       - name: Checkout source
@@ -26,9 +24,8 @@ jobs:
           fetch-depth: 0 # Needed by codecov.io

       - name: Setup Conda Environment
-        uses: conda-incubator/setup-miniconda@v2
+        uses: conda-incubator/setup-miniconda@v3
         with:
-          miniforge-variant: Mambaforge
           miniforge-version: latest
           use-mamba: true
           channel-priority: strict
4 changes: 2 additions & 2 deletions .github/workflows/upstream.yml
@@ -46,8 +46,8 @@ jobs:
           miniforge-version: latest
           use-mamba: true
           channel-priority: strict
-          python-version: "3.9"
-          environment-file: ci/environment-3.9.yaml
+          python-version: "3.12"
+          environment-file: ci/environment-3.12.yaml
           activate-environment: test-environment
           auto-activate-base: false
3 changes: 1 addition & 2 deletions ci/environment-3.10.yaml
@@ -16,11 +16,10 @@ dependencies:
   - pytest-cov
   - pytest-mock
   - python=3.10.*
-  - scikit-learn >=1.2.0
+  - scikit-learn >=1.6.1
   - scipy
   - sparse
   - toolz
   - pip
   - pip:
-      - git+https://github.com/dask-contrib/dask-expr
       - git+https://github.com/dask/dask
3 changes: 1 addition & 2 deletions ci/environment-3.11.yaml
@@ -16,11 +16,10 @@ dependencies:
   - pytest-cov
   - pytest-mock
   - python=3.11.*
-  - scikit-learn >=1.2.0
+  - scikit-learn >=1.6.1
   - scipy
   - sparse
   - toolz
   - pip
   - pip:
-      - git+https://github.com/dask-contrib/dask-expr
       - git+https://github.com/dask/dask
7 changes: 3 additions & 4 deletions ci/environment-3.9.yaml → ci/environment-3.12.yaml
@@ -1,4 +1,4 @@
-name: dask-ml-3.9
+name: dask-ml-3.12
 channels:
   - conda-forge
   - defaults
@@ -15,12 +15,11 @@ dependencies:
   - pytest
   - pytest-cov
   - pytest-mock
-  - python=3.9.*
-  - scikit-learn >=1.2.0
+  - python=3.12.*
+  - scikit-learn >=1.6.1
   - scipy
   - sparse
   - toolz
   - pip
   - pip:
-      - git+https://github.com/dask-contrib/dask-expr
       - git+https://github.com/dask/dask
27 changes: 27 additions & 0 deletions ci/environment-3.13.yaml
@@ -0,0 +1,27 @@
+name: dask-ml-3.13
+channels:
+  - conda-forge
+  - defaults
+dependencies:
+  - dask-glm
+  - multipledispatch >=0.4.9
+  - mypy
+  # - numba  # Version with Python 3.13 not on conda-forge yet
+  - numpy
+  - numpydoc
+  - packaging
+  - pandas
+  - psutil
+  - pytest
+  - pytest-cov
+  - pytest-mock
+  - python=3.13.*
+  - scikit-learn >=1.6.1
+  - scipy
+  - sparse
+  - toolz
+  - pip
+  - pip:
+      - git+https://github.com/dask/dask
+      # Switch to conda for `numba` once a version with Python 3.13 support is on conda-forge
+      - numba >= 0.61.0
3 changes: 1 addition & 2 deletions ci/environment-docs.yaml
@@ -17,7 +17,7 @@ dependencies:
   - psutil
   - python=3.11
   - sortedcontainers
-  - scikit-learn >=1.2.0
+  - scikit-learn >=1.6.1
   - scipy
   - sparse
   - tornado
@@ -49,7 +49,6 @@ dependencies:
   - toolz
   - cloudpickle>=1.5.0
   - pandas>=1.4.0
-  - dask-expr
   - fsspec
   - scipy
   - pytest
5 changes: 3 additions & 2 deletions dask_ml/base.py
@@ -1,4 +1,5 @@
 import sklearn.base
+import sklearn.utils.validation

 from .utils import _check_y, check_array, check_X_y

@@ -61,7 +62,7 @@ def _validate_data(
         The validated input. A tuple is returned if both `X` and `y` are
         validated.
         """
-        self._check_feature_names(X, reset=reset)
+        sklearn.utils.validation._check_feature_names(self, X, reset=reset)

         if y is None and self._get_tags()["requires_y"]:
             raise ValueError(
@@ -94,7 +95,7 @@ def _validate_data(
         out = X, y

         if not no_val_X and check_params.get("ensure_2d", True):
-            self._check_n_features(X, reset=reset)
+            sklearn.utils.validation._check_n_features(self, X, reset=reset)

         return out
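
Context for the change above: scikit-learn 1.6 turned the `_check_n_features` / `_check_feature_names` estimator methods into module-level functions in `sklearn.utils.validation` that take the estimator as their first argument. A minimal sketch of the new calling convention; these are private scikit-learn helpers, so treat the exact signatures as an assumption that may change between releases:

import pandas as pd
import sklearn.utils.validation
from sklearn.base import BaseEstimator

class TinyEstimator(BaseEstimator):
    def fit(self, X, y=None):
        # scikit-learn >= 1.6: free functions, with the estimator passed explicitly.
        # Pre-1.6 these were methods: self._check_feature_names(X, reset=True).
        sklearn.utils.validation._check_feature_names(self, X, reset=True)
        sklearn.utils.validation._check_n_features(self, X, reset=True)
        return self

est = TinyEstimator().fit(pd.DataFrame({"a": [1.0], "b": [2.0]}))
print(est.n_features_in_)     # 2
print(est.feature_names_in_)  # ['a' 'b']
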
2 changes: 1 addition & 1 deletion dask_ml/impute.py
@@ -17,7 +17,7 @@ def _check_array(self, X):
             accept_dask_dataframe=True,
             accept_unknown_chunks=True,
             preserve_pandas_dataframe=True,
-            force_all_finite=False,
+            ensure_all_finite=False,
         )

     def fit(self, X, y=None):
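
The `force_all_finite` to `ensure_all_finite` rename tracks scikit-learn 1.6, which deprecated `check_array`'s old keyword in favor of the new name. A small sketch of what the flag controls, shown here against scikit-learn's own `check_array`:

import numpy as np
from sklearn.utils.validation import check_array

X = np.array([[1.0, np.nan], [3.0, 4.0]])

# ensure_all_finite=False admits NaN/inf, e.g. so an imputer can see the
# missing values it is supposed to fill.
check_array(X, ensure_all_finite=False)

# The default rejects non-finite input.
try:
    check_array(X)
except ValueError as exc:
    print(exc)  # e.g. "Input X contains NaN." (wording varies by version)
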
77 changes: 27 additions & 50 deletions dask_ml/linear_model/utils.py
@@ -5,56 +5,33 @@
 import numpy as np
 from multipledispatch import dispatch

-if getattr(dd, "_dask_expr_enabled", lambda: False)():
-    import dask_expr
-
-    @dispatch(dask_expr.FrameBase)
-    def exp(A):
-        return da.exp(A)
-
-    @dispatch(dask_expr.FrameBase)
-    def absolute(A):
-        return da.absolute(A)
-
-    @dispatch(dask_expr.FrameBase)
-    def sign(A):
-        return da.sign(A)
-
-    @dispatch(dask_expr.FrameBase)
-    def log1p(A):
-        return da.log1p(A)
-
-    @dispatch(dask_expr.FrameBase)  # noqa: F811
-    def add_intercept(X):  # noqa: F811
-        columns = X.columns
-        if "intercept" in columns:
-            raise ValueError("'intercept' column already in 'X'")
-        return X.assign(intercept=1)[["intercept"] + list(columns)]
-
-else:
-
-    @dispatch(dd._Frame)
-    def exp(A):
-        return da.exp(A)
-
-    @dispatch(dd._Frame)
-    def absolute(A):
-        return da.absolute(A)
-
-    @dispatch(dd._Frame)
-    def sign(A):
-        return da.sign(A)
-
-    @dispatch(dd._Frame)
-    def log1p(A):
-        return da.log1p(A)
-
-    @dispatch(dd._Frame)  # noqa: F811
-    def add_intercept(X):  # noqa: F811
-        columns = X.columns
-        if "intercept" in columns:
-            raise ValueError("'intercept' column already in 'X'")
-        return X.assign(intercept=1)[["intercept"] + list(columns)]
+
+@dispatch(dd.DataFrame)
+def exp(A):
+    return da.exp(A)
+
+
+@dispatch(dd.DataFrame)
+def absolute(A):
+    return da.absolute(A)
+
+
+@dispatch(dd.DataFrame)
+def sign(A):
+    return da.sign(A)
+
+
+@dispatch(dd.DataFrame)
+def log1p(A):
+    return da.log1p(A)
+
+
+@dispatch(dd.DataFrame)  # noqa: F811
+def add_intercept(X):  # noqa: F811
+    columns = X.columns
+    if "intercept" in columns:
+        raise ValueError("'intercept' column already in 'X'")
+    return X.assign(intercept=1)[["intercept"] + list(columns)]


 @dispatch(np.ndarray)  # noqa: F811
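
The collapse to a single `@dispatch(dd.DataFrame)` registration works because dask-expr is now the default (and only) dataframe implementation in recent Dask releases, so there is no longer a `dd._Frame` / `dask_expr.FrameBase` split to branch on. A minimal sketch of the multipledispatch pattern these utilities rely on, mirroring the `exp` case above:

import dask.array as da
import dask.dataframe as dd
import numpy as np
import pandas as pd
from multipledispatch import dispatch

@dispatch(dd.DataFrame)
def exp(A):
    # dask.array ufuncs also accept dask dataframes and apply elementwise.
    return da.exp(A)

@dispatch(np.ndarray)
def exp(A):  # noqa: F811
    return np.exp(A)

ddf = dd.from_pandas(pd.DataFrame({"x": [0.0, 1.0]}), npartitions=1)
print(exp(ddf).compute())         # routed to the dd.DataFrame implementation
print(exp(np.array([0.0, 1.0])))  # routed to the np.ndarray implementation
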
33 changes: 31 additions & 2 deletions dask_ml/model_selection/_incremental.py
@@ -248,7 +248,19 @@ def get_futures(partial_fit_calls):
         _specs[ident] = spec

     if DISTRIBUTED_2021_02_0:
-        _models, _scores, _specs = dask.persist(_models, _scores, _specs)
+        # https://github.com/dask/dask-ml/issues/1003
+        # We only want to persist dask collections, not Futures.
+        # So we build a collection without futures and bring them back later.
+        to_persist = {
+            "models": {k: v for k, v in _models.items() if not isinstance(v, Future)},
+            "scores": {k: v for k, v in _scores.items() if not isinstance(v, Future)},
+            "specs": {k: v for k, v in _specs.items() if not isinstance(v, Future)},
+        }
+        models_p, scores_p, specs_p = dask.persist(*list(to_persist.values()))
+        # Update with keys not present, which should just be futures
+        _models = {**_models, **models_p}
+        _scores = {**_scores, **scores_p}
+        _specs = {**_specs, **specs_p}
     else:
         _models, _scores, _specs = dask.persist(
             _models, _scores, _specs, priority={tuple(_specs.values()): -1}
@@ -315,7 +327,24 @@ def get_futures(partial_fit_calls):
             _specs[ident] = spec

         if DISTRIBUTED_2021_02_0:
-            _models2, _scores2, _specs2 = dask.persist(_models, _scores, _specs)
+            # https://github.com/dask/dask-ml/issues/1003
+            # We only want to persist dask collections, not Futures.
+            # So we build a collection without futures and bring them back later.
+            to_persist = {
+                "models": {
+                    k: v for k, v in _models.items() if not isinstance(v, Future)
+                },
+                "scores": {
+                    k: v for k, v in _scores.items() if not isinstance(v, Future)
+                },
+                "specs": {k: v for k, v in _specs.items() if not isinstance(v, Future)},
+            }
+            models2_p, scores2_p, specs2_p = dask.persist(*list(to_persist.values()))
+            # Update with keys not present, which should just be futures
+            _models2 = {**_models, **models2_p}
+            _scores2 = {**_scores, **scores2_p}
+            _specs2 = {**_specs, **specs2_p}
+
         else:
             _models2, _scores2, _specs2 = dask.persist(
                 _models, _scores, _specs, priority={tuple(_specs.values()): -1}
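
The motivation (dask/dask-ml#1003): `distributed.Future` objects are already materialized results on the cluster, and passing them through `dask.persist` alongside lazy collections broke on recent Dask versions. The fix persists only the non-Future entries and merges the Futures back afterwards. A standalone sketch of the same filter-persist-merge pattern; the helper name is illustrative, not part of dask-ml:

import dask
from distributed import Future

def persist_skipping_futures(d):
    """Persist dict values that are lazy dask collections, leaving Futures alone.

    Illustrative helper, assuming `d` maps identifiers to either Futures or
    dask collections (like the _models/_scores/_specs dicts above).
    """
    lazy = {k: v for k, v in d.items() if not isinstance(v, Future)}
    (persisted,) = dask.persist(lazy)  # dask.persist traverses dict values
    # Futures survive from `d`; persisted collections win for shared keys.
    return {**d, **persisted}
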
4 changes: 2 additions & 2 deletions dask_ml/model_selection/_search.py
@@ -31,7 +31,7 @@
     _CVIterableWrapper,
 )
 from sklearn.pipeline import FeatureUnion, Pipeline
-from sklearn.utils._tags import _safe_tags
+from sklearn.utils import get_tags
 from sklearn.utils.metaestimators import available_if
 from sklearn.utils.multiclass import type_of_target
 from sklearn.utils.validation import _num_samples, check_is_fitted
@@ -209,7 +209,7 @@ def build_cv_graph(
     X, y, groups = to_indexable(X, y, groups)
     cv = check_cv(cv, y, is_classifier(estimator))
     # "pairwise" estimators require a different graph for CV splitting
-    is_pairwise = _safe_tags(estimator, "pairwise")
+    is_pairwise = get_tags(estimator).input_tags.pairwise

     dsk = {}
     X_name, y_name, groups_name = to_keys(dsk, X, y, groups)
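
`get_tags` is the public replacement scikit-learn 1.6 introduced for the private `_safe_tags` helper; it returns a structured `Tags` dataclass rather than a dict, so string lookups become attribute access. A quick sketch:

from sklearn.svm import SVC
from sklearn.utils import get_tags

# scikit-learn >= 1.6: tags are a dataclass, grouped into sub-objects.
tags = get_tags(SVC(kernel="precomputed"))
print(tags.input_tags.pairwise)  # True: the estimator expects a kernel/distance matrix

# Pre-1.6 equivalent (now removed): _safe_tags(estimator, "pairwise")
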
9 changes: 5 additions & 4 deletions dask_ml/preprocessing/_encoders.py
@@ -5,6 +5,7 @@
 import numpy as np
 import pandas as pd
 import sklearn.preprocessing
+import sklearn.utils.validation

 from .._compat import SKLEARN_1_1_X
 from .._typing import ArrayLike, DataFrameType, DTypeLike, SeriesType
@@ -167,19 +168,19 @@ def _fit(
         self,
         X: Union[ArrayLike, DataFrameType],
         handle_unknown: str = "error",
-        force_all_finite: bool = True,
+        ensure_all_finite: bool = True,
         return_counts=False,
     ):
         X = self._validate_data(
             X, accept_dask_dataframe=True, dtype=None, preserve_pandas_dataframe=True
         )
-        self._check_n_features(X, reset=True)
-        self._check_feature_names(X, reset=True)
+        sklearn.utils.validation._check_n_features(self, X, reset=True)
+        sklearn.utils.validation._check_feature_names(self, X, reset=True)

         if isinstance(X, np.ndarray):
             kwargs = {
                 "handle_unknown": handle_unknown,
-                "force_all_finite": force_all_finite,
+                "ensure_all_finite": ensure_all_finite,
             }

             # `return_counts` expected as of scikit-learn 1.1
4 changes: 2 additions & 2 deletions dask_ml/preprocessing/data.py
@@ -1051,11 +1051,11 @@ class PolynomialFeatures(DaskMLBaseMixin, sklearn.preprocessing.PolynomialFeatur
         Using False (default) returns numpy or dask arrays and mimics
         sklearn's default behaviour
-    Examples
+    Attributes
     """

     splitted_orig_doc = sklearn.preprocessing.PolynomialFeatures.__doc__.split(
-        " Examples\n"
+        "Attributes\n"
     )
     __doc__ = "".join([splitted_orig_doc[0], __doc__, splitted_orig_doc[1]])
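
The docstring fix adjusts dask-ml's docstring-splicing trick: the class inherits scikit-learn's `PolynomialFeatures` docstring and splices its extra parameter docs in by splitting the upstream text at a section heading, and scikit-learn 1.6 reflowed that docstring, so the split anchor moved from "Examples" to "Attributes". A sketch of the technique; the inserted text and exact whitespace here are illustrative:

import sklearn.preprocessing

extra_param_doc = """    preserve_dataframe : bool, default=False
        Whether to keep dask/pandas dataframes as dataframes. (Illustrative.)

    Attributes
"""

# Split the inherited docstring at the "Attributes" heading and splice the
# extra parameter documentation in front of it, restoring the heading.
head, tail = sklearn.preprocessing.PolynomialFeatures.__doc__.split("Attributes\n", 1)
patched_doc = "".join([head, extra_param_doc, tail])
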