diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 582595263ab..4862a6235e8 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -821,8 +821,9 @@ jobs:
if: matrix.engine == 'python' || matrix.test_task == 'group_4'
- run: python -m pytest modin/test/interchange/dataframe_protocol/test_general.py
if: matrix.engine == 'python' || matrix.test_task == 'group_4'
- - run: python -m pytest modin/test/interchange/dataframe_protocol/pandas/test_protocol.py
+ - run: python -m pytest modin/pandas/test/interoperability
if: matrix.engine == 'python' || matrix.test_task == 'group_4'
+ - run: python -m pytest modin/test/interchange/dataframe_protocol/pandas/test_protocol.py
- uses: ./.github/workflows/upload-coverage
test-experimental:
diff --git a/.github/workflows/push-to-master.yml b/.github/workflows/push-to-master.yml
index 863b2dfaf11..304655d757f 100644
--- a/.github/workflows/push-to-master.yml
+++ b/.github/workflows/push-to-master.yml
@@ -86,6 +86,7 @@ jobs:
python -m pytest modin/pandas/test/test_general.py
python -m pytest modin/pandas/test/test_io.py
python -m pytest modin/experimental/pandas/test/test_io_exp.py
+ python -m pytest modin/pandas/test/interoperability
test-docs:
runs-on: ubuntu-latest
diff --git a/modin/pandas/test/interoperability/matplotlib/test_axes.py b/modin/pandas/test/interoperability/matplotlib/test_axes.py
new file mode 100644
index 00000000000..dc8a8c9c1cc
--- /dev/null
+++ b/modin/pandas/test/interoperability/matplotlib/test_axes.py
@@ -0,0 +1,154 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+import numpy as np
+import matplotlib.pyplot as plt
+import modin.pandas as pd
+from matplotlib.testing.decorators import (
+ check_figures_equal,
+)
+
+# Note: Some test cases are run twice: once normally and once with labeled data
+# These two must be defined in the same test function or need to have
+# different baseline images to prevent race conditions when pytest runs
+# the tests with multiple threads.
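+#
+# An explanatory aside (not from the upstream matplotlib tests): the
+# check_figures_equal decorator used below injects two Figure fixtures,
+# fig_test and fig_ref, into the test and asserts that the images rendered by
+# the two figures are identical for the listed extensions.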
+
+
+@check_figures_equal(extensions=["png"])
+def test_invisible_axes(fig_test, fig_ref):
+ ax = fig_test.subplots()
+ ax.set_visible(False)
+
+
+def test_boxplot_dates_pandas():
+ # smoke test for boxplot and dates in pandas
+ data = np.random.rand(5, 2)
+ years = pd.date_range("1/1/2000", periods=2, freq=pd.DateOffset(years=1)).year
+ plt.figure()
+ plt.boxplot(data, positions=years)
+
+
+def test_bar_pandas():
+ # Smoke test for pandas
+ df = pd.DataFrame(
+ {
+ "year": [2018, 2018, 2018],
+ "month": [1, 1, 1],
+ "day": [1, 2, 3],
+ "value": [1, 2, 3],
+ }
+ )
+ df["date"] = pd.to_datetime(df[["year", "month", "day"]])
+
+ monthly = df[["date", "value"]].groupby(["date"]).sum()
+ dates = monthly.index
+ forecast = monthly["value"]
+ baseline = monthly["value"]
+
+ fig, ax = plt.subplots()
+ ax.bar(dates, forecast, width=10, align="center")
+ ax.plot(dates, baseline, color="orange", lw=4)
+
+
+def test_bar_pandas_indexed():
+ # Smoke test for indexed pandas
+ df = pd.DataFrame({"x": [1.0, 2.0, 3.0], "width": [0.2, 0.4, 0.6]}, index=[1, 2, 3])
+ fig, ax = plt.subplots()
+ ax.bar(df.x, 1.0, width=df.width)
+
+
+def test_pandas_minimal_plot():
+ # smoke test that series and index objects do not warn
+ for x in [pd.Series([1, 2], dtype="float64"), pd.Series([1, 2], dtype="Float64")]:
+ plt.plot(x, x)
+ plt.plot(x.index, x)
+ plt.plot(x)
+ plt.plot(x.index)
+ df = pd.DataFrame({"col": [1, 2, 3]})
+ plt.plot(df)
+ plt.plot(df, df)
+
+
+@check_figures_equal(extensions=["png"])
+def test_violinplot_pandas_series(fig_test, fig_ref):
+ np.random.seed(110433579)
+ s1 = pd.Series(np.random.normal(size=7), index=[9, 8, 7, 6, 5, 4, 3])
+ s2 = pd.Series(np.random.normal(size=9), index=list("ABCDEFGHI"))
+ s3 = pd.Series(np.random.normal(size=11))
+ fig_test.subplots().violinplot([s1, s2, s3])
+ fig_ref.subplots().violinplot([s1.values, s2.values, s3.values])
+
+
+def test_pandas_pcolormesh():
+ time = pd.date_range("2000-01-01", periods=10)
+ depth = np.arange(20)
+ data = np.random.rand(19, 9)
+
+ fig, ax = plt.subplots()
+ ax.pcolormesh(time, depth, data)
+
+
+def test_pandas_indexing_dates():
+ dates = np.arange("2005-02", "2005-03", dtype="datetime64[D]")
+ values = np.sin(range(len(dates)))
+ df = pd.DataFrame({"dates": dates, "values": values})
+
+ ax = plt.gca()
+
+ without_zero_index = df[np.array(df.index) % 2 == 1].copy()
+ ax.plot("dates", "values", data=without_zero_index)
+
+
+def test_pandas_errorbar_indexing():
+ df = pd.DataFrame(
+ np.random.uniform(size=(5, 4)),
+ columns=["x", "y", "xe", "ye"],
+ index=[1, 2, 3, 4, 5],
+ )
+ fig, ax = plt.subplots()
+ ax.errorbar("x", "y", xerr="xe", yerr="ye", data=df)
+
+
+def test_pandas_index_shape():
+ df = pd.DataFrame({"XX": [4, 5, 6], "YY": [7, 1, 2]})
+ fig, ax = plt.subplots()
+ ax.plot(df.index, df["YY"])
+
+
+def test_pandas_indexing_hist():
+ ser_1 = pd.Series(data=[1, 2, 2, 3, 3, 4, 4, 4, 4, 5])
+ ser_2 = ser_1.iloc[1:]
+ fig, ax = plt.subplots()
+ ax.hist(ser_2)
+
+
+def test_pandas_bar_align_center():
+ # Tests fix for issue 8767
+ df = pd.DataFrame({"a": range(2), "b": range(2)})
+
+ fig, ax = plt.subplots(1)
+
+ ax.bar(df.loc[df["a"] == 1, "b"], df.loc[df["a"] == 1, "b"], align="center")
+
+ fig.canvas.draw()
+
+
+def test_scatter_series_non_zero_index():
+ # create non-zero index
+ ids = range(10, 18)
+ x = pd.Series(np.random.uniform(size=8), index=ids)
+ y = pd.Series(np.random.uniform(size=8), index=ids)
+ c = pd.Series([1, 1, 1, 1, 1, 0, 0, 0], index=ids)
+ plt.scatter(x, y, c)
diff --git a/modin/pandas/test/interoperability/matplotlib/test_cbook.py b/modin/pandas/test/interoperability/matplotlib/test_cbook.py
new file mode 100644
index 00000000000..788e36df54b
--- /dev/null
+++ b/modin/pandas/test/interoperability/matplotlib/test_cbook.py
@@ -0,0 +1,43 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+import numpy as np
+import modin.pandas as pd
+from matplotlib import cbook
+
+
+def test_reshape2d_pandas():
+ # separate to allow the rest of the tests to run if no pandas...
+ X = np.arange(30).reshape(10, 3)
+ x = pd.DataFrame(X, columns=["a", "b", "c"])
+ Xnew = cbook._reshape_2D(x, "x")
+ # Need to check each row because _reshape_2D returns a list of arrays:
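+ # (for a DataFrame input this means one 1-D array per column, which is why
+ # the loop below compares against the columns of X, i.e. X.T)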
+ for x, xnew in zip(X.T, Xnew):
+ np.testing.assert_array_equal(x, xnew)
+
+
+def test_index_of_pandas():
+ # separate to allow the rest of the tests to run if no pandas...
+ X = np.arange(30).reshape(10, 3)
+ x = pd.DataFrame(X, columns=["a", "b", "c"])
+ Idx, Xnew = cbook.index_of(x)
+ np.testing.assert_array_equal(X, Xnew)
+ IdxRef = np.arange(10)
+ np.testing.assert_array_equal(Idx, IdxRef)
+
+
+def test_safe_first_element_pandas_series():
+ # deliberately create a pandas series with index not starting from 0
+ s = pd.Series(range(5), index=range(10, 15))
+ actual = cbook._safe_first_finite(s)
+ assert actual == 0
diff --git a/modin/pandas/test/interoperability/matplotlib/test_collections.py b/modin/pandas/test/interoperability/matplotlib/test_collections.py
new file mode 100644
index 00000000000..f89118213cd
--- /dev/null
+++ b/modin/pandas/test/interoperability/matplotlib/test_collections.py
@@ -0,0 +1,31 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+import modin.pandas as pd
+from matplotlib.collections import Collection
+
+
+def test_pandas_indexing():
+ # Should not break when faced with a
+ # non-zero indexed series
+ index = [11, 12, 13]
+ ec = fc = pd.Series(["red", "blue", "green"], index=index)
+ lw = pd.Series([1, 2, 3], index=index)
+ ls = pd.Series(["solid", "dashed", "dashdot"], index=index)
+ aa = pd.Series([True, False, True], index=index)
+
+ Collection(edgecolors=ec)
+ Collection(facecolors=fc)
+ Collection(linewidths=lw)
+ Collection(linestyles=ls)
+ Collection(antialiaseds=aa)
diff --git a/modin/pandas/test/interoperability/matplotlib/test_colors.py b/modin/pandas/test/interoperability/matplotlib/test_colors.py
new file mode 100644
index 00000000000..57174217a20
--- /dev/null
+++ b/modin/pandas/test/interoperability/matplotlib/test_colors.py
@@ -0,0 +1,27 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+from numpy.testing import assert_array_equal
+import matplotlib.colors as mcolors
+import modin.pandas as pd
+
+
+def test_pandas_iterable():
+ # Using a list or series yields equivalent
+ # colormaps, i.e. the series isn't seen as
+ # a single color
+ lst = ["red", "blue", "green"]
+ s = pd.Series(lst)
+ cm1 = mcolors.ListedColormap(lst, N=5)
+ cm2 = mcolors.ListedColormap(s, N=5)
+ assert_array_equal(cm1.colors, cm2.colors)
diff --git a/modin/pandas/test/interoperability/matplotlib/test_dates.py b/modin/pandas/test/interoperability/matplotlib/test_dates.py
new file mode 100644
index 00000000000..0a18e900a0c
--- /dev/null
+++ b/modin/pandas/test/interoperability/matplotlib/test_dates.py
@@ -0,0 +1,1881 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+import datetime
+import dateutil.tz
+import dateutil.rrule
+import functools
+import numpy as np
+import pytest
+
+from matplotlib import _api, rc_context, style
+import matplotlib.dates as mdates
+import matplotlib.pyplot as plt
+from matplotlib.testing.decorators import image_comparison
+import matplotlib.ticker as mticker
+import modin.pandas as pd
+
+
+def test_date_numpyx():
+ # test that numpy dates work properly...
+ base = datetime.datetime(2017, 1, 1)
+ time = [base + datetime.timedelta(days=x) for x in range(0, 3)]
+ timenp = np.array(time, dtype="datetime64[ns]")
+ data = np.array([0.0, 2.0, 1.0])
+ fig = plt.figure(figsize=(10, 2))
+ ax = fig.add_subplot(1, 1, 1)
+ (h,) = ax.plot(time, data)
+ (hnp,) = ax.plot(timenp, data)
+ np.testing.assert_equal(h.get_xdata(orig=False), hnp.get_xdata(orig=False))
+ fig = plt.figure(figsize=(10, 2))
+ ax = fig.add_subplot(1, 1, 1)
+ (h,) = ax.plot(data, time)
+ (hnp,) = ax.plot(data, timenp)
+ np.testing.assert_equal(h.get_ydata(orig=False), hnp.get_ydata(orig=False))
+
+
+@pytest.mark.parametrize(
+ "t0",
+ [
+ datetime.datetime(2017, 1, 1, 0, 1, 1),
+ [
+ datetime.datetime(2017, 1, 1, 0, 1, 1),
+ datetime.datetime(2017, 1, 1, 1, 1, 1),
+ ],
+ [
+ [
+ datetime.datetime(2017, 1, 1, 0, 1, 1),
+ datetime.datetime(2017, 1, 1, 1, 1, 1),
+ ],
+ [
+ datetime.datetime(2017, 1, 1, 2, 1, 1),
+ datetime.datetime(2017, 1, 1, 3, 1, 1),
+ ],
+ ],
+ ],
+)
+@pytest.mark.parametrize(
+ "dtype", ["datetime64[s]", "datetime64[us]", "datetime64[ms]", "datetime64[ns]"]
+)
+def test_date_date2num_numpy(t0, dtype):
+ time = mdates.date2num(t0)
+ tnp = np.array(t0, dtype=dtype)
+ nptime = mdates.date2num(tnp)
+ np.testing.assert_equal(time, nptime)
+
+
+@pytest.mark.parametrize(
+ "dtype", ["datetime64[s]", "datetime64[us]", "datetime64[ms]", "datetime64[ns]"]
+)
+def test_date2num_NaT(dtype):
+ t0 = datetime.datetime(2017, 1, 1, 0, 1, 1)
+ tmpl = [mdates.date2num(t0), np.nan]
+ tnp = np.array([t0, "NaT"], dtype=dtype)
+ nptime = mdates.date2num(tnp)
+ np.testing.assert_array_equal(tmpl, nptime)
+
+
+@pytest.mark.parametrize("units", ["s", "ms", "us", "ns"])
+def test_date2num_NaT_scalar(units):
+ tmpl = mdates.date2num(np.datetime64("NaT", units))
+ assert np.isnan(tmpl)
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_date2num_masked():
+ # Without tzinfo
+ base = datetime.datetime(2022, 12, 15)
+ dates = np.ma.array(
+ [base + datetime.timedelta(days=(2 * i)) for i in range(7)],
+ mask=[0, 1, 1, 0, 0, 0, 1],
+ )
+ npdates = mdates.date2num(dates)
+ np.testing.assert_array_equal(
+ np.ma.getmask(npdates), (False, True, True, False, False, False, True)
+ )
+
+ # With tzinfo
+ base = datetime.datetime(2022, 12, 15, tzinfo=mdates.UTC)
+ dates = np.ma.array(
+ [base + datetime.timedelta(days=(2 * i)) for i in range(7)],
+ mask=[0, 1, 1, 0, 0, 0, 1],
+ )
+ npdates = mdates.date2num(dates)
+ np.testing.assert_array_equal(
+ np.ma.getmask(npdates), (False, True, True, False, False, False, True)
+ )
+
+
+def test_date_empty():
+ # make sure we do the right thing when told to plot dates even
+ # if no date data has been presented, cf
+ # http://sourceforge.net/tracker/?func=detail&aid=2850075&group_id=80706&atid=560720
+ fig, ax = plt.subplots()
+ ax.xaxis_date()
+ fig.draw_without_rendering()
+ np.testing.assert_allclose(
+ ax.get_xlim(),
+ [
+ mdates.date2num(np.datetime64("1970-01-01")),
+ mdates.date2num(np.datetime64("1970-01-02")),
+ ],
+ )
+
+ mdates._reset_epoch_test_example()
+ mdates.set_epoch("0000-12-31")
+ fig, ax = plt.subplots()
+ ax.xaxis_date()
+ fig.draw_without_rendering()
+ np.testing.assert_allclose(
+ ax.get_xlim(),
+ [
+ mdates.date2num(np.datetime64("1970-01-01")),
+ mdates.date2num(np.datetime64("1970-01-02")),
+ ],
+ )
+ mdates._reset_epoch_test_example()
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_date_not_empty():
+ fig = plt.figure()
+ ax = fig.add_subplot()
+
+ ax.plot([50, 70], [1, 2])
+ ax.xaxis.axis_date()
+ np.testing.assert_allclose(ax.get_xlim(), [50, 70])
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_axhline():
+ # make sure that axhline doesn't set the xlimits...
+ fig, ax = plt.subplots()
+ ax.axhline(1.5)
+ ax.plot([np.datetime64("2016-01-01"), np.datetime64("2016-01-02")], [1, 2])
+ np.testing.assert_allclose(
+ ax.get_xlim(),
+ [
+ mdates.date2num(np.datetime64("2016-01-01")),
+ mdates.date2num(np.datetime64("2016-01-02")),
+ ],
+ )
+
+ mdates._reset_epoch_test_example()
+ mdates.set_epoch("0000-12-31")
+ fig, ax = plt.subplots()
+ ax.axhline(1.5)
+ ax.plot([np.datetime64("2016-01-01"), np.datetime64("2016-01-02")], [1, 2])
+ np.testing.assert_allclose(
+ ax.get_xlim(),
+ [
+ mdates.date2num(np.datetime64("2016-01-01")),
+ mdates.date2num(np.datetime64("2016-01-02")),
+ ],
+ )
+ mdates._reset_epoch_test_example()
+
+
+@pytest.mark.skip(reason="Failing test")
+@image_comparison(["date_axhspan.png"])
+def test_date_axhspan():
+ # test axhspan with date inputs
+ t0 = datetime.datetime(2009, 1, 20)
+ tf = datetime.datetime(2009, 1, 21)
+ fig, ax = plt.subplots()
+ ax.axhspan(t0, tf, facecolor="blue", alpha=0.25)
+ ax.set_ylim(t0 - datetime.timedelta(days=5), tf + datetime.timedelta(days=5))
+ fig.subplots_adjust(left=0.25)
+
+
+@pytest.mark.skip(reason="Failing test")
+@image_comparison(["date_axvspan.png"])
+def test_date_axvspan():
+ # test axvspan with date inputs
+ t0 = datetime.datetime(2000, 1, 20)
+ tf = datetime.datetime(2010, 1, 21)
+ fig, ax = plt.subplots()
+ ax.axvspan(t0, tf, facecolor="blue", alpha=0.25)
+ ax.set_xlim(t0 - datetime.timedelta(days=720), tf + datetime.timedelta(days=720))
+ fig.autofmt_xdate()
+
+
+@pytest.mark.skip(reason="Failing test")
+@image_comparison(["date_axhline.png"])
+def test_date_axhline():
+ # test axhline with date inputs
+ t0 = datetime.datetime(2009, 1, 20)
+ tf = datetime.datetime(2009, 1, 31)
+ fig, ax = plt.subplots()
+ ax.axhline(t0, color="blue", lw=3)
+ ax.set_ylim(t0 - datetime.timedelta(days=5), tf + datetime.timedelta(days=5))
+ fig.subplots_adjust(left=0.25)
+
+
+@pytest.mark.skip(reason="Failing test")
+@image_comparison(["date_axvline.png"])
+def test_date_axvline():
+ # test axvline with date inputs
+ t0 = datetime.datetime(2000, 1, 20)
+ tf = datetime.datetime(2000, 1, 21)
+ fig, ax = plt.subplots()
+ ax.axvline(t0, color="red", lw=3)
+ ax.set_xlim(t0 - datetime.timedelta(days=5), tf + datetime.timedelta(days=5))
+ fig.autofmt_xdate()
+
+
+def test_too_many_date_ticks(caplog):
+ # Attempt to test SF 2715172, see
+ # https://sourceforge.net/tracker/?func=detail&aid=2715172&group_id=80706&atid=560720
+ # setting equal datetimes triggers an expander call in
+ # transforms.nonsingular which results in too many ticks in the
+ # DayLocator. This should emit a log at WARNING level.
+ caplog.set_level("WARNING")
+ t0 = datetime.datetime(2000, 1, 20)
+ tf = datetime.datetime(2000, 1, 20)
+ fig, ax = plt.subplots()
+ with pytest.warns(UserWarning) as rec:
+ ax.set_xlim((t0, tf), auto=True)
+ assert len(rec) == 1
+ assert "Attempting to set identical low and high xlims" in str(rec[0].message)
+ ax.plot([], [])
+ ax.xaxis.set_major_locator(mdates.DayLocator())
+ v = ax.xaxis.get_major_locator()()
+ assert len(v) > 1000
+ # The warning is emitted multiple times because the major locator is also
+ # called both when placing the minor ticks (for overstriking detection) and
+ # during tick label positioning.
+ assert caplog.records and all(
+ record.name == "matplotlib.ticker" and record.levelname == "WARNING"
+ for record in caplog.records
+ )
+ assert len(caplog.records) > 0
+
+
+def _new_epoch_decorator(thefunc):
+ @functools.wraps(thefunc)
+ def wrapper():
+ mdates._reset_epoch_test_example()
+ mdates.set_epoch("2000-01-01")
+ thefunc()
+ mdates._reset_epoch_test_example()
+
+ return wrapper
+
+
+@pytest.mark.skip(reason="Failing test")
+@image_comparison(["RRuleLocator_bounds.png"])
+def test_RRuleLocator():
+ import matplotlib.testing.jpl_units as units
+
+ units.register()
+ # This will cause the RRuleLocator to go out of bounds when it tries
+ # to add padding to the limits, so we make sure it caps at the correct
+ # boundary values.
+ t0 = datetime.datetime(1000, 1, 1)
+ tf = datetime.datetime(6000, 1, 1)
+
+ fig = plt.figure()
+ ax = plt.subplot()
+ ax.set_autoscale_on(True)
+ ax.plot([t0, tf], [0.0, 1.0], marker="o")
+
+ rrule = mdates.rrulewrapper(dateutil.rrule.YEARLY, interval=500)
+ locator = mdates.RRuleLocator(rrule)
+ ax.xaxis.set_major_locator(locator)
+ ax.xaxis.set_major_formatter(mdates.AutoDateFormatter(locator))
+
+ ax.autoscale_view()
+ fig.autofmt_xdate()
+
+
+def test_RRuleLocator_dayrange():
+ loc = mdates.DayLocator()
+ x1 = datetime.datetime(year=1, month=1, day=1, tzinfo=mdates.UTC)
+ y1 = datetime.datetime(year=1, month=1, day=16, tzinfo=mdates.UTC)
+ loc.tick_values(x1, y1)
+ # On success, no overflow error shall be thrown
+
+
+def test_RRuleLocator_close_minmax():
+ # if d1 and d2 are very close together, rrule cannot create
+ # reasonable tick intervals; ensure that this is handled properly
+ rrule = mdates.rrulewrapper(dateutil.rrule.SECONDLY, interval=5)
+ loc = mdates.RRuleLocator(rrule)
+ d1 = datetime.datetime(year=2020, month=1, day=1)
+ d2 = datetime.datetime(year=2020, month=1, day=1, microsecond=1)
+ expected = ["2020-01-01 00:00:00+00:00", "2020-01-01 00:00:00.000001+00:00"]
+ assert list(map(str, mdates.num2date(loc.tick_values(d1, d2)))) == expected
+
+
+@pytest.mark.skip(reason="Failing test")
+@image_comparison(["DateFormatter_fractionalSeconds.png"])
+def test_DateFormatter():
+ import matplotlib.testing.jpl_units as units
+
+ units.register()
+
+ # Let's make sure that DateFormatter will allow us to have tick marks
+ # at intervals of fractional seconds.
+
+ t0 = datetime.datetime(2001, 1, 1, 0, 0, 0)
+ tf = datetime.datetime(2001, 1, 1, 0, 0, 1)
+
+ fig = plt.figure()
+ ax = plt.subplot()
+ ax.set_autoscale_on(True)
+ ax.plot([t0, tf], [0.0, 1.0], marker="o")
+
+ # rrule = mpldates.rrulewrapper( dateutil.rrule.YEARLY, interval=500 )
+ # locator = mpldates.RRuleLocator( rrule )
+ # ax.xaxis.set_major_locator( locator )
+ # ax.xaxis.set_major_formatter( mpldates.AutoDateFormatter(locator) )
+
+ ax.autoscale_view()
+ fig.autofmt_xdate()
+
+
+def test_locator_set_formatter():
+ """
+ Test if setting the locator only will update the AutoDateFormatter to use
+ the new locator.
+ """
+ plt.rcParams["date.autoformatter.minute"] = "%d %H:%M"
+ t = [
+ datetime.datetime(2018, 9, 30, 8, 0),
+ datetime.datetime(2018, 9, 30, 8, 59),
+ datetime.datetime(2018, 9, 30, 10, 30),
+ ]
+ x = [2, 3, 1]
+
+ fig, ax = plt.subplots()
+ ax.plot(t, x)
+ ax.xaxis.set_major_locator(mdates.MinuteLocator((0, 30)))
+ fig.canvas.draw()
+ ticklabels = [tl.get_text() for tl in ax.get_xticklabels()]
+ expected = ["30 08:00", "30 08:30", "30 09:00", "30 09:30", "30 10:00", "30 10:30"]
+ assert ticklabels == expected
+
+ ax.xaxis.set_major_locator(mticker.NullLocator())
+ ax.xaxis.set_minor_locator(mdates.MinuteLocator((5, 55)))
+ decoy_loc = mdates.MinuteLocator((12, 27))
+ ax.xaxis.set_minor_formatter(mdates.AutoDateFormatter(decoy_loc))
+
+ ax.xaxis.set_minor_locator(mdates.MinuteLocator((15, 45)))
+ fig.canvas.draw()
+ ticklabels = [tl.get_text() for tl in ax.get_xticklabels(which="minor")]
+ expected = ["30 08:15", "30 08:45", "30 09:15", "30 09:45", "30 10:15"]
+ assert ticklabels == expected
+
+
+def test_date_formatter_callable():
+ class _Locator:
+ def _get_unit(self):
+ return -11
+
+ def callable_formatting_function(dates, _):
+ return [dt.strftime("%d-%m//%Y") for dt in dates]
+
+ formatter = mdates.AutoDateFormatter(_Locator())
+ formatter.scaled[-10] = callable_formatting_function
+ assert formatter([datetime.datetime(2014, 12, 25)]) == ["25-12//2014"]
+
+
+@pytest.mark.parametrize(
+ "delta, expected",
+ [
+ (
+ datetime.timedelta(weeks=52 * 200),
+ [r"$\mathdefault{%d}$" % year for year in range(1990, 2171, 20)],
+ ),
+ (
+ datetime.timedelta(days=30),
+ [r"$\mathdefault{1990{-}01{-}%02d}$" % day for day in range(1, 32, 3)],
+ ),
+ (
+ datetime.timedelta(hours=20),
+ [r"$\mathdefault{01{-}01\;%02d}$" % hour for hour in range(0, 21, 2)],
+ ),
+ (
+ datetime.timedelta(minutes=10),
+ [r"$\mathdefault{01\;00{:}%02d}$" % minu for minu in range(0, 11)],
+ ),
+ ],
+)
+def test_date_formatter_usetex(delta, expected):
+ style.use("default")
+
+ d1 = datetime.datetime(1990, 1, 1)
+ d2 = d1 + delta
+
+ locator = mdates.AutoDateLocator(interval_multiples=False)
+ locator.create_dummy_axis()
+ locator.axis.set_view_interval(mdates.date2num(d1), mdates.date2num(d2))
+
+ formatter = mdates.AutoDateFormatter(locator, usetex=True)
+ assert [formatter(loc) for loc in locator()] == expected
+
+
+def test_drange():
+ """
+ This test should check if drange works as expected, and if all the
+ rounding errors are fixed
+ """
+ start = datetime.datetime(2011, 1, 1, tzinfo=mdates.UTC)
+ end = datetime.datetime(2011, 1, 2, tzinfo=mdates.UTC)
+ delta = datetime.timedelta(hours=1)
+ # We expect 24 values in drange(start, end, delta), because drange returns
+ # dates from a half-open interval [start, end)
+ assert len(mdates.drange(start, end, delta)) == 24
+
+ # Same if interval ends slightly earlier
+ end = end - datetime.timedelta(microseconds=1)
+ assert len(mdates.drange(start, end, delta)) == 24
+
+ # if end is a little bit later, we expect the range to contain one more
+ # element
+ end = end + datetime.timedelta(microseconds=2)
+ assert len(mdates.drange(start, end, delta)) == 25
+
+ # reset end
+ end = datetime.datetime(2011, 1, 2, tzinfo=mdates.UTC)
+
+ # and test drange with "complicated" floats:
+ # 4 hours = 1/6 day, which is a "dangerous" float
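+ # (aside: 1/6 has no exact binary representation, so naively accumulating the
+ # step could drift past the interval end; drange is expected to guard against
+ # precisely this kind of rounding error)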
+ delta = datetime.timedelta(hours=4)
+ daterange = mdates.drange(start, end, delta)
+ assert len(daterange) == 6
+ assert mdates.num2date(daterange[-1]) == (end - delta)
+
+
+@_new_epoch_decorator
+def test_auto_date_locator():
+ def _create_auto_date_locator(date1, date2):
+ locator = mdates.AutoDateLocator(interval_multiples=False)
+ locator.create_dummy_axis()
+ locator.axis.set_view_interval(*mdates.date2num([date1, date2]))
+ return locator
+
+ d1 = datetime.datetime(1990, 1, 1)
+ results = (
+ [
+ datetime.timedelta(weeks=52 * 200),
+ [
+ "1990-01-01 00:00:00+00:00",
+ "2010-01-01 00:00:00+00:00",
+ "2030-01-01 00:00:00+00:00",
+ "2050-01-01 00:00:00+00:00",
+ "2070-01-01 00:00:00+00:00",
+ "2090-01-01 00:00:00+00:00",
+ "2110-01-01 00:00:00+00:00",
+ "2130-01-01 00:00:00+00:00",
+ "2150-01-01 00:00:00+00:00",
+ "2170-01-01 00:00:00+00:00",
+ ],
+ ],
+ [
+ datetime.timedelta(weeks=52),
+ [
+ "1990-01-01 00:00:00+00:00",
+ "1990-02-01 00:00:00+00:00",
+ "1990-03-01 00:00:00+00:00",
+ "1990-04-01 00:00:00+00:00",
+ "1990-05-01 00:00:00+00:00",
+ "1990-06-01 00:00:00+00:00",
+ "1990-07-01 00:00:00+00:00",
+ "1990-08-01 00:00:00+00:00",
+ "1990-09-01 00:00:00+00:00",
+ "1990-10-01 00:00:00+00:00",
+ "1990-11-01 00:00:00+00:00",
+ "1990-12-01 00:00:00+00:00",
+ ],
+ ],
+ [
+ datetime.timedelta(days=141),
+ [
+ "1990-01-05 00:00:00+00:00",
+ "1990-01-26 00:00:00+00:00",
+ "1990-02-16 00:00:00+00:00",
+ "1990-03-09 00:00:00+00:00",
+ "1990-03-30 00:00:00+00:00",
+ "1990-04-20 00:00:00+00:00",
+ "1990-05-11 00:00:00+00:00",
+ ],
+ ],
+ [
+ datetime.timedelta(days=40),
+ [
+ "1990-01-03 00:00:00+00:00",
+ "1990-01-10 00:00:00+00:00",
+ "1990-01-17 00:00:00+00:00",
+ "1990-01-24 00:00:00+00:00",
+ "1990-01-31 00:00:00+00:00",
+ "1990-02-07 00:00:00+00:00",
+ ],
+ ],
+ [
+ datetime.timedelta(hours=40),
+ [
+ "1990-01-01 00:00:00+00:00",
+ "1990-01-01 04:00:00+00:00",
+ "1990-01-01 08:00:00+00:00",
+ "1990-01-01 12:00:00+00:00",
+ "1990-01-01 16:00:00+00:00",
+ "1990-01-01 20:00:00+00:00",
+ "1990-01-02 00:00:00+00:00",
+ "1990-01-02 04:00:00+00:00",
+ "1990-01-02 08:00:00+00:00",
+ "1990-01-02 12:00:00+00:00",
+ "1990-01-02 16:00:00+00:00",
+ ],
+ ],
+ [
+ datetime.timedelta(minutes=20),
+ [
+ "1990-01-01 00:00:00+00:00",
+ "1990-01-01 00:05:00+00:00",
+ "1990-01-01 00:10:00+00:00",
+ "1990-01-01 00:15:00+00:00",
+ "1990-01-01 00:20:00+00:00",
+ ],
+ ],
+ [
+ datetime.timedelta(seconds=40),
+ [
+ "1990-01-01 00:00:00+00:00",
+ "1990-01-01 00:00:05+00:00",
+ "1990-01-01 00:00:10+00:00",
+ "1990-01-01 00:00:15+00:00",
+ "1990-01-01 00:00:20+00:00",
+ "1990-01-01 00:00:25+00:00",
+ "1990-01-01 00:00:30+00:00",
+ "1990-01-01 00:00:35+00:00",
+ "1990-01-01 00:00:40+00:00",
+ ],
+ ],
+ [
+ datetime.timedelta(microseconds=1500),
+ [
+ "1989-12-31 23:59:59.999500+00:00",
+ "1990-01-01 00:00:00+00:00",
+ "1990-01-01 00:00:00.000500+00:00",
+ "1990-01-01 00:00:00.001000+00:00",
+ "1990-01-01 00:00:00.001500+00:00",
+ "1990-01-01 00:00:00.002000+00:00",
+ ],
+ ],
+ )
+
+ for t_delta, expected in results:
+ d2 = d1 + t_delta
+ locator = _create_auto_date_locator(d1, d2)
+ assert list(map(str, mdates.num2date(locator()))) == expected
+
+ locator = mdates.AutoDateLocator(interval_multiples=False)
+ assert locator.maxticks == {0: 11, 1: 12, 3: 11, 4: 12, 5: 11, 6: 11, 7: 8}
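+ # (the dict keys are frequency constants; as the next assertion shows,
+ # dateutil.rrule.MONTHLY corresponds to key 1)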
+
+ locator = mdates.AutoDateLocator(maxticks={dateutil.rrule.MONTHLY: 5})
+ assert locator.maxticks == {0: 11, 1: 5, 3: 11, 4: 12, 5: 11, 6: 11, 7: 8}
+
+ locator = mdates.AutoDateLocator(maxticks=5)
+ assert locator.maxticks == {0: 5, 1: 5, 3: 5, 4: 5, 5: 5, 6: 5, 7: 5}
+
+
+@_new_epoch_decorator
+def test_auto_date_locator_intmult():
+ def _create_auto_date_locator(date1, date2):
+ locator = mdates.AutoDateLocator(interval_multiples=True)
+ locator.create_dummy_axis()
+ locator.axis.set_view_interval(*mdates.date2num([date1, date2]))
+ return locator
+
+ results = (
+ [
+ datetime.timedelta(weeks=52 * 200),
+ [
+ "1980-01-01 00:00:00+00:00",
+ "2000-01-01 00:00:00+00:00",
+ "2020-01-01 00:00:00+00:00",
+ "2040-01-01 00:00:00+00:00",
+ "2060-01-01 00:00:00+00:00",
+ "2080-01-01 00:00:00+00:00",
+ "2100-01-01 00:00:00+00:00",
+ "2120-01-01 00:00:00+00:00",
+ "2140-01-01 00:00:00+00:00",
+ "2160-01-01 00:00:00+00:00",
+ "2180-01-01 00:00:00+00:00",
+ "2200-01-01 00:00:00+00:00",
+ ],
+ ],
+ [
+ datetime.timedelta(weeks=52),
+ [
+ "1997-01-01 00:00:00+00:00",
+ "1997-02-01 00:00:00+00:00",
+ "1997-03-01 00:00:00+00:00",
+ "1997-04-01 00:00:00+00:00",
+ "1997-05-01 00:00:00+00:00",
+ "1997-06-01 00:00:00+00:00",
+ "1997-07-01 00:00:00+00:00",
+ "1997-08-01 00:00:00+00:00",
+ "1997-09-01 00:00:00+00:00",
+ "1997-10-01 00:00:00+00:00",
+ "1997-11-01 00:00:00+00:00",
+ "1997-12-01 00:00:00+00:00",
+ ],
+ ],
+ [
+ datetime.timedelta(days=141),
+ [
+ "1997-01-01 00:00:00+00:00",
+ "1997-01-15 00:00:00+00:00",
+ "1997-02-01 00:00:00+00:00",
+ "1997-02-15 00:00:00+00:00",
+ "1997-03-01 00:00:00+00:00",
+ "1997-03-15 00:00:00+00:00",
+ "1997-04-01 00:00:00+00:00",
+ "1997-04-15 00:00:00+00:00",
+ "1997-05-01 00:00:00+00:00",
+ "1997-05-15 00:00:00+00:00",
+ ],
+ ],
+ [
+ datetime.timedelta(days=40),
+ [
+ "1997-01-01 00:00:00+00:00",
+ "1997-01-05 00:00:00+00:00",
+ "1997-01-09 00:00:00+00:00",
+ "1997-01-13 00:00:00+00:00",
+ "1997-01-17 00:00:00+00:00",
+ "1997-01-21 00:00:00+00:00",
+ "1997-01-25 00:00:00+00:00",
+ "1997-01-29 00:00:00+00:00",
+ "1997-02-01 00:00:00+00:00",
+ "1997-02-05 00:00:00+00:00",
+ "1997-02-09 00:00:00+00:00",
+ ],
+ ],
+ [
+ datetime.timedelta(hours=40),
+ [
+ "1997-01-01 00:00:00+00:00",
+ "1997-01-01 04:00:00+00:00",
+ "1997-01-01 08:00:00+00:00",
+ "1997-01-01 12:00:00+00:00",
+ "1997-01-01 16:00:00+00:00",
+ "1997-01-01 20:00:00+00:00",
+ "1997-01-02 00:00:00+00:00",
+ "1997-01-02 04:00:00+00:00",
+ "1997-01-02 08:00:00+00:00",
+ "1997-01-02 12:00:00+00:00",
+ "1997-01-02 16:00:00+00:00",
+ ],
+ ],
+ [
+ datetime.timedelta(minutes=20),
+ [
+ "1997-01-01 00:00:00+00:00",
+ "1997-01-01 00:05:00+00:00",
+ "1997-01-01 00:10:00+00:00",
+ "1997-01-01 00:15:00+00:00",
+ "1997-01-01 00:20:00+00:00",
+ ],
+ ],
+ [
+ datetime.timedelta(seconds=40),
+ [
+ "1997-01-01 00:00:00+00:00",
+ "1997-01-01 00:00:05+00:00",
+ "1997-01-01 00:00:10+00:00",
+ "1997-01-01 00:00:15+00:00",
+ "1997-01-01 00:00:20+00:00",
+ "1997-01-01 00:00:25+00:00",
+ "1997-01-01 00:00:30+00:00",
+ "1997-01-01 00:00:35+00:00",
+ "1997-01-01 00:00:40+00:00",
+ ],
+ ],
+ [
+ datetime.timedelta(microseconds=1500),
+ [
+ "1996-12-31 23:59:59.999500+00:00",
+ "1997-01-01 00:00:00+00:00",
+ "1997-01-01 00:00:00.000500+00:00",
+ "1997-01-01 00:00:00.001000+00:00",
+ "1997-01-01 00:00:00.001500+00:00",
+ "1997-01-01 00:00:00.002000+00:00",
+ ],
+ ],
+ )
+
+ d1 = datetime.datetime(1997, 1, 1)
+ for t_delta, expected in results:
+ d2 = d1 + t_delta
+ locator = _create_auto_date_locator(d1, d2)
+ assert list(map(str, mdates.num2date(locator()))) == expected
+
+
+def test_concise_formatter_subsecond():
+ locator = mdates.AutoDateLocator(interval_multiples=True)
+ formatter = mdates.ConciseDateFormatter(locator)
+ year_1996 = 9861.0
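+ # (explanatory note: 9861.0 is a date2num value near the end of 1996 under
+ # the default 1970-01-01 epoch; the exact day is incidental, the sub-second
+ # offsets added below are what is being tested)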
+ strings = formatter.format_ticks(
+ [
+ year_1996,
+ year_1996 + 500 / mdates.MUSECONDS_PER_DAY,
+ year_1996 + 900 / mdates.MUSECONDS_PER_DAY,
+ ]
+ )
+ assert strings == ["00:00", "00.0005", "00.0009"]
+
+
+def test_concise_formatter():
+ def _create_auto_date_locator(date1, date2):
+ fig, ax = plt.subplots()
+
+ locator = mdates.AutoDateLocator(interval_multiples=True)
+ formatter = mdates.ConciseDateFormatter(locator)
+ ax.yaxis.set_major_locator(locator)
+ ax.yaxis.set_major_formatter(formatter)
+ ax.set_ylim(date1, date2)
+ fig.canvas.draw()
+ sts = [st.get_text() for st in ax.get_yticklabels()]
+ return sts
+
+ d1 = datetime.datetime(1997, 1, 1)
+ results = (
+ [datetime.timedelta(weeks=52 * 200), [str(t) for t in range(1980, 2201, 20)]],
+ [
+ datetime.timedelta(weeks=52),
+ [
+ "1997",
+ "Feb",
+ "Mar",
+ "Apr",
+ "May",
+ "Jun",
+ "Jul",
+ "Aug",
+ "Sep",
+ "Oct",
+ "Nov",
+ "Dec",
+ ],
+ ],
+ [
+ datetime.timedelta(days=141),
+ ["Jan", "15", "Feb", "15", "Mar", "15", "Apr", "15", "May", "15"],
+ ],
+ [
+ datetime.timedelta(days=40),
+ ["Jan", "05", "09", "13", "17", "21", "25", "29", "Feb", "05", "09"],
+ ],
+ [
+ datetime.timedelta(hours=40),
+ [
+ "Jan-01",
+ "04:00",
+ "08:00",
+ "12:00",
+ "16:00",
+ "20:00",
+ "Jan-02",
+ "04:00",
+ "08:00",
+ "12:00",
+ "16:00",
+ ],
+ ],
+ [datetime.timedelta(minutes=20), ["00:00", "00:05", "00:10", "00:15", "00:20"]],
+ [
+ datetime.timedelta(seconds=40),
+ ["00:00", "05", "10", "15", "20", "25", "30", "35", "40"],
+ ],
+ [
+ datetime.timedelta(seconds=2),
+ ["59.5", "00:00", "00.5", "01.0", "01.5", "02.0", "02.5"],
+ ],
+ )
+ for t_delta, expected in results:
+ d2 = d1 + t_delta
+ strings = _create_auto_date_locator(d1, d2)
+ assert strings == expected
+
+
+@pytest.mark.parametrize(
+ "t_delta, expected",
+ [
+ (datetime.timedelta(seconds=0.01), "1997-Jan-01 00:00"),
+ (datetime.timedelta(minutes=1), "1997-Jan-01 00:01"),
+ (datetime.timedelta(hours=1), "1997-Jan-01"),
+ (datetime.timedelta(days=1), "1997-Jan-02"),
+ (datetime.timedelta(weeks=1), "1997-Jan"),
+ (datetime.timedelta(weeks=26), ""),
+ (datetime.timedelta(weeks=520), ""),
+ ],
+)
+def test_concise_formatter_show_offset(t_delta, expected):
+ d1 = datetime.datetime(1997, 1, 1)
+ d2 = d1 + t_delta
+
+ fig, ax = plt.subplots()
+ locator = mdates.AutoDateLocator()
+ formatter = mdates.ConciseDateFormatter(locator)
+ ax.xaxis.set_major_locator(locator)
+ ax.xaxis.set_major_formatter(formatter)
+
+ ax.plot([d1, d2], [0, 0])
+ fig.canvas.draw()
+ assert formatter.get_offset() == expected
+
+
+def test_offset_changes():
+ fig, ax = plt.subplots()
+
+ d1 = datetime.datetime(1997, 1, 1)
+ d2 = d1 + datetime.timedelta(weeks=520)
+
+ locator = mdates.AutoDateLocator()
+ formatter = mdates.ConciseDateFormatter(locator)
+ ax.xaxis.set_major_locator(locator)
+ ax.xaxis.set_major_formatter(formatter)
+
+ ax.plot([d1, d2], [0, 0])
+ fig.draw_without_rendering()
+ assert formatter.get_offset() == ""
+ ax.set_xlim(d1, d1 + datetime.timedelta(weeks=3))
+ fig.draw_without_rendering()
+ assert formatter.get_offset() == "1997-Jan"
+ ax.set_xlim(d1 + datetime.timedelta(weeks=7), d1 + datetime.timedelta(weeks=30))
+ fig.draw_without_rendering()
+ assert formatter.get_offset() == "1997"
+ ax.set_xlim(d1, d1 + datetime.timedelta(weeks=520))
+ fig.draw_without_rendering()
+ assert formatter.get_offset() == ""
+
+
+@pytest.mark.parametrize(
+ "t_delta, expected",
+ [
+ (
+ datetime.timedelta(weeks=52 * 200),
+ ["$\\mathdefault{%d}$" % (t,) for t in range(1980, 2201, 20)],
+ ),
+ (
+ datetime.timedelta(days=40),
+ [
+ "Jan",
+ "$\\mathdefault{05}$",
+ "$\\mathdefault{09}$",
+ "$\\mathdefault{13}$",
+ "$\\mathdefault{17}$",
+ "$\\mathdefault{21}$",
+ "$\\mathdefault{25}$",
+ "$\\mathdefault{29}$",
+ "Feb",
+ "$\\mathdefault{05}$",
+ "$\\mathdefault{09}$",
+ ],
+ ),
+ (
+ datetime.timedelta(hours=40),
+ [
+ "Jan$\\mathdefault{{-}01}$",
+ "$\\mathdefault{04{:}00}$",
+ "$\\mathdefault{08{:}00}$",
+ "$\\mathdefault{12{:}00}$",
+ "$\\mathdefault{16{:}00}$",
+ "$\\mathdefault{20{:}00}$",
+ "Jan$\\mathdefault{{-}02}$",
+ "$\\mathdefault{04{:}00}$",
+ "$\\mathdefault{08{:}00}$",
+ "$\\mathdefault{12{:}00}$",
+ "$\\mathdefault{16{:}00}$",
+ ],
+ ),
+ (
+ datetime.timedelta(seconds=2),
+ [
+ "$\\mathdefault{59.5}$",
+ "$\\mathdefault{00{:}00}$",
+ "$\\mathdefault{00.5}$",
+ "$\\mathdefault{01.0}$",
+ "$\\mathdefault{01.5}$",
+ "$\\mathdefault{02.0}$",
+ "$\\mathdefault{02.5}$",
+ ],
+ ),
+ ],
+)
+def test_concise_formatter_usetex(t_delta, expected):
+ d1 = datetime.datetime(1997, 1, 1)
+ d2 = d1 + t_delta
+
+ locator = mdates.AutoDateLocator(interval_multiples=True)
+ locator.create_dummy_axis()
+ locator.axis.set_view_interval(mdates.date2num(d1), mdates.date2num(d2))
+
+ formatter = mdates.ConciseDateFormatter(locator, usetex=True)
+ assert formatter.format_ticks(locator()) == expected
+
+
+def test_concise_formatter_formats():
+ formats = ["%Y", "%m/%Y", "day: %d", "%H hr %M min", "%H hr %M min", "%S.%f sec"]
+
+ def _create_auto_date_locator(date1, date2):
+ fig, ax = plt.subplots()
+
+ locator = mdates.AutoDateLocator(interval_multiples=True)
+ formatter = mdates.ConciseDateFormatter(locator, formats=formats)
+ ax.yaxis.set_major_locator(locator)
+ ax.yaxis.set_major_formatter(formatter)
+ ax.set_ylim(date1, date2)
+ fig.canvas.draw()
+ sts = [st.get_text() for st in ax.get_yticklabels()]
+ return sts
+
+ d1 = datetime.datetime(1997, 1, 1)
+ results = (
+ [datetime.timedelta(weeks=52 * 200), [str(t) for t in range(1980, 2201, 20)]],
+ [
+ datetime.timedelta(weeks=52),
+ [
+ "1997",
+ "02/1997",
+ "03/1997",
+ "04/1997",
+ "05/1997",
+ "06/1997",
+ "07/1997",
+ "08/1997",
+ "09/1997",
+ "10/1997",
+ "11/1997",
+ "12/1997",
+ ],
+ ],
+ [
+ datetime.timedelta(days=141),
+ [
+ "01/1997",
+ "day: 15",
+ "02/1997",
+ "day: 15",
+ "03/1997",
+ "day: 15",
+ "04/1997",
+ "day: 15",
+ "05/1997",
+ "day: 15",
+ ],
+ ],
+ [
+ datetime.timedelta(days=40),
+ [
+ "01/1997",
+ "day: 05",
+ "day: 09",
+ "day: 13",
+ "day: 17",
+ "day: 21",
+ "day: 25",
+ "day: 29",
+ "02/1997",
+ "day: 05",
+ "day: 09",
+ ],
+ ],
+ [
+ datetime.timedelta(hours=40),
+ [
+ "day: 01",
+ "04 hr 00 min",
+ "08 hr 00 min",
+ "12 hr 00 min",
+ "16 hr 00 min",
+ "20 hr 00 min",
+ "day: 02",
+ "04 hr 00 min",
+ "08 hr 00 min",
+ "12 hr 00 min",
+ "16 hr 00 min",
+ ],
+ ],
+ [
+ datetime.timedelta(minutes=20),
+ [
+ "00 hr 00 min",
+ "00 hr 05 min",
+ "00 hr 10 min",
+ "00 hr 15 min",
+ "00 hr 20 min",
+ ],
+ ],
+ [
+ datetime.timedelta(seconds=40),
+ [
+ "00 hr 00 min",
+ "05.000000 sec",
+ "10.000000 sec",
+ "15.000000 sec",
+ "20.000000 sec",
+ "25.000000 sec",
+ "30.000000 sec",
+ "35.000000 sec",
+ "40.000000 sec",
+ ],
+ ],
+ [
+ datetime.timedelta(seconds=2),
+ [
+ "59.500000 sec",
+ "00 hr 00 min",
+ "00.500000 sec",
+ "01.000000 sec",
+ "01.500000 sec",
+ "02.000000 sec",
+ "02.500000 sec",
+ ],
+ ],
+ )
+ for t_delta, expected in results:
+ d2 = d1 + t_delta
+ strings = _create_auto_date_locator(d1, d2)
+ assert strings == expected
+
+
+def test_concise_formatter_zformats():
+ zero_formats = ["", "'%y", "%B", "%m-%d", "%S", "%S.%f"]
+
+ def _create_auto_date_locator(date1, date2):
+ fig, ax = plt.subplots()
+
+ locator = mdates.AutoDateLocator(interval_multiples=True)
+ formatter = mdates.ConciseDateFormatter(locator, zero_formats=zero_formats)
+ ax.yaxis.set_major_locator(locator)
+ ax.yaxis.set_major_formatter(formatter)
+ ax.set_ylim(date1, date2)
+ fig.canvas.draw()
+ sts = [st.get_text() for st in ax.get_yticklabels()]
+ return sts
+
+ d1 = datetime.datetime(1997, 1, 1)
+ results = (
+ [datetime.timedelta(weeks=52 * 200), [str(t) for t in range(1980, 2201, 20)]],
+ [
+ datetime.timedelta(weeks=52),
+ [
+ "'97",
+ "Feb",
+ "Mar",
+ "Apr",
+ "May",
+ "Jun",
+ "Jul",
+ "Aug",
+ "Sep",
+ "Oct",
+ "Nov",
+ "Dec",
+ ],
+ ],
+ [
+ datetime.timedelta(days=141),
+ [
+ "January",
+ "15",
+ "February",
+ "15",
+ "March",
+ "15",
+ "April",
+ "15",
+ "May",
+ "15",
+ ],
+ ],
+ [
+ datetime.timedelta(days=40),
+ [
+ "January",
+ "05",
+ "09",
+ "13",
+ "17",
+ "21",
+ "25",
+ "29",
+ "February",
+ "05",
+ "09",
+ ],
+ ],
+ [
+ datetime.timedelta(hours=40),
+ [
+ "01-01",
+ "04:00",
+ "08:00",
+ "12:00",
+ "16:00",
+ "20:00",
+ "01-02",
+ "04:00",
+ "08:00",
+ "12:00",
+ "16:00",
+ ],
+ ],
+ [datetime.timedelta(minutes=20), ["00", "00:05", "00:10", "00:15", "00:20"]],
+ [
+ datetime.timedelta(seconds=40),
+ ["00", "05", "10", "15", "20", "25", "30", "35", "40"],
+ ],
+ [
+ datetime.timedelta(seconds=2),
+ ["59.5", "00.0", "00.5", "01.0", "01.5", "02.0", "02.5"],
+ ],
+ )
+ for t_delta, expected in results:
+ d2 = d1 + t_delta
+ strings = _create_auto_date_locator(d1, d2)
+ assert strings == expected
+
+
+def test_concise_formatter_tz():
+ def _create_auto_date_locator(date1, date2, tz):
+ fig, ax = plt.subplots()
+
+ locator = mdates.AutoDateLocator(interval_multiples=True)
+ formatter = mdates.ConciseDateFormatter(locator, tz=tz)
+ ax.yaxis.set_major_locator(locator)
+ ax.yaxis.set_major_formatter(formatter)
+ ax.set_ylim(date1, date2)
+ fig.canvas.draw()
+ sts = [st.get_text() for st in ax.get_yticklabels()]
+ return sts, ax.yaxis.get_offset_text().get_text()
+
+ d1 = datetime.datetime(1997, 1, 1).replace(tzinfo=datetime.timezone.utc)
+ results = (
+ [
+ datetime.timedelta(hours=40),
+ [
+ "03:00",
+ "07:00",
+ "11:00",
+ "15:00",
+ "19:00",
+ "23:00",
+ "03:00",
+ "07:00",
+ "11:00",
+ "15:00",
+ "19:00",
+ ],
+ "1997-Jan-02",
+ ],
+ [
+ datetime.timedelta(minutes=20),
+ ["03:00", "03:05", "03:10", "03:15", "03:20"],
+ "1997-Jan-01",
+ ],
+ [
+ datetime.timedelta(seconds=40),
+ ["03:00", "05", "10", "15", "20", "25", "30", "35", "40"],
+ "1997-Jan-01 03:00",
+ ],
+ [
+ datetime.timedelta(seconds=2),
+ ["59.5", "03:00", "00.5", "01.0", "01.5", "02.0", "02.5"],
+ "1997-Jan-01 03:00",
+ ],
+ )
+
+ new_tz = datetime.timezone(datetime.timedelta(hours=3))
+ for t_delta, expected_strings, expected_offset in results:
+ d2 = d1 + t_delta
+ strings, offset = _create_auto_date_locator(d1, d2, new_tz)
+ assert strings == expected_strings
+ assert offset == expected_offset
+
+
+def test_auto_date_locator_intmult_tz():
+ def _create_auto_date_locator(date1, date2, tz):
+ locator = mdates.AutoDateLocator(interval_multiples=True, tz=tz)
+ locator.create_dummy_axis()
+ locator.axis.set_view_interval(*mdates.date2num([date1, date2]))
+ return locator
+
+ results = (
+ [
+ datetime.timedelta(weeks=52 * 200),
+ [
+ "1980-01-01 00:00:00-08:00",
+ "2000-01-01 00:00:00-08:00",
+ "2020-01-01 00:00:00-08:00",
+ "2040-01-01 00:00:00-08:00",
+ "2060-01-01 00:00:00-08:00",
+ "2080-01-01 00:00:00-08:00",
+ "2100-01-01 00:00:00-08:00",
+ "2120-01-01 00:00:00-08:00",
+ "2140-01-01 00:00:00-08:00",
+ "2160-01-01 00:00:00-08:00",
+ "2180-01-01 00:00:00-08:00",
+ "2200-01-01 00:00:00-08:00",
+ ],
+ ],
+ [
+ datetime.timedelta(weeks=52),
+ [
+ "1997-01-01 00:00:00-08:00",
+ "1997-02-01 00:00:00-08:00",
+ "1997-03-01 00:00:00-08:00",
+ "1997-04-01 00:00:00-08:00",
+ "1997-05-01 00:00:00-07:00",
+ "1997-06-01 00:00:00-07:00",
+ "1997-07-01 00:00:00-07:00",
+ "1997-08-01 00:00:00-07:00",
+ "1997-09-01 00:00:00-07:00",
+ "1997-10-01 00:00:00-07:00",
+ "1997-11-01 00:00:00-08:00",
+ "1997-12-01 00:00:00-08:00",
+ ],
+ ],
+ [
+ datetime.timedelta(days=141),
+ [
+ "1997-01-01 00:00:00-08:00",
+ "1997-01-15 00:00:00-08:00",
+ "1997-02-01 00:00:00-08:00",
+ "1997-02-15 00:00:00-08:00",
+ "1997-03-01 00:00:00-08:00",
+ "1997-03-15 00:00:00-08:00",
+ "1997-04-01 00:00:00-08:00",
+ "1997-04-15 00:00:00-07:00",
+ "1997-05-01 00:00:00-07:00",
+ "1997-05-15 00:00:00-07:00",
+ ],
+ ],
+ [
+ datetime.timedelta(days=40),
+ [
+ "1997-01-01 00:00:00-08:00",
+ "1997-01-05 00:00:00-08:00",
+ "1997-01-09 00:00:00-08:00",
+ "1997-01-13 00:00:00-08:00",
+ "1997-01-17 00:00:00-08:00",
+ "1997-01-21 00:00:00-08:00",
+ "1997-01-25 00:00:00-08:00",
+ "1997-01-29 00:00:00-08:00",
+ "1997-02-01 00:00:00-08:00",
+ "1997-02-05 00:00:00-08:00",
+ "1997-02-09 00:00:00-08:00",
+ ],
+ ],
+ [
+ datetime.timedelta(hours=40),
+ [
+ "1997-01-01 00:00:00-08:00",
+ "1997-01-01 04:00:00-08:00",
+ "1997-01-01 08:00:00-08:00",
+ "1997-01-01 12:00:00-08:00",
+ "1997-01-01 16:00:00-08:00",
+ "1997-01-01 20:00:00-08:00",
+ "1997-01-02 00:00:00-08:00",
+ "1997-01-02 04:00:00-08:00",
+ "1997-01-02 08:00:00-08:00",
+ "1997-01-02 12:00:00-08:00",
+ "1997-01-02 16:00:00-08:00",
+ ],
+ ],
+ [
+ datetime.timedelta(minutes=20),
+ [
+ "1997-01-01 00:00:00-08:00",
+ "1997-01-01 00:05:00-08:00",
+ "1997-01-01 00:10:00-08:00",
+ "1997-01-01 00:15:00-08:00",
+ "1997-01-01 00:20:00-08:00",
+ ],
+ ],
+ [
+ datetime.timedelta(seconds=40),
+ [
+ "1997-01-01 00:00:00-08:00",
+ "1997-01-01 00:00:05-08:00",
+ "1997-01-01 00:00:10-08:00",
+ "1997-01-01 00:00:15-08:00",
+ "1997-01-01 00:00:20-08:00",
+ "1997-01-01 00:00:25-08:00",
+ "1997-01-01 00:00:30-08:00",
+ "1997-01-01 00:00:35-08:00",
+ "1997-01-01 00:00:40-08:00",
+ ],
+ ],
+ )
+
+ tz = dateutil.tz.gettz("Canada/Pacific")
+ d1 = datetime.datetime(1997, 1, 1, tzinfo=tz)
+ for t_delta, expected in results:
+ with rc_context({"_internal.classic_mode": False}):
+ d2 = d1 + t_delta
+ locator = _create_auto_date_locator(d1, d2, tz)
+ st = list(map(str, mdates.num2date(locator(), tz=tz)))
+ assert st == expected
+
+
+@pytest.mark.skip(reason="Failing test")
+@image_comparison(["date_inverted_limit.png"])
+def test_date_inverted_limit():
+ # test axhline with date inputs
+ t0 = datetime.datetime(2009, 1, 20)
+ tf = datetime.datetime(2009, 1, 31)
+ fig, ax = plt.subplots()
+ ax.axhline(t0, color="blue", lw=3)
+ ax.set_ylim(t0 - datetime.timedelta(days=5), tf + datetime.timedelta(days=5))
+ ax.invert_yaxis()
+ fig.subplots_adjust(left=0.25)
+
+
+def _test_date2num_dst(date_range, tz_convert):
+ # Timezones
+
+ BRUSSELS = dateutil.tz.gettz("Europe/Brussels")
+ UTC = mdates.UTC
+
+ # Create a list of timezone-aware datetime objects in UTC
+ # Interval is 0b0.0000011 days, to prevent float rounding issues
+ dtstart = datetime.datetime(2014, 3, 30, 0, 0, tzinfo=UTC)
+ interval = datetime.timedelta(minutes=33, seconds=45)
+ interval_days = interval.seconds / 86400
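+ # (0b0.0000011 day = 1/64 + 1/128 = 3/128 day = 2025 s = 33 min 45 s, and
+ # 2025 / 86400 = 0.0234375 is exactly representable in binary floating point)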
+ N = 8
+
+ dt_utc = date_range(start=dtstart, freq=interval, periods=N)
+ dt_bxl = tz_convert(dt_utc, BRUSSELS)
+ t0 = 735322.0 + mdates.date2num(np.datetime64("0000-12-31"))
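+ # (735322.0 is dtstart expressed in the old "0000-12-31" epoch; adding
+ # date2num("0000-12-31") re-expresses it in the currently configured epoch,
+ # the same old-to-new conversion noted elsewhere in this file)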
+ expected_ordinalf = [t0 + (i * interval_days) for i in range(N)]
+ actual_ordinalf = list(mdates.date2num(dt_bxl))
+
+ assert actual_ordinalf == expected_ordinalf
+
+
+def test_date2num_dst():
+ # Test for github issue #3896, but in date2num around DST transitions
+ # with a dummy timezone-aware date_range object that mimics pandas behavior.
+
+ class dt_tzaware(datetime.datetime):
+ """
+ This bug specifically occurs because of the normalization behavior of
+ pandas Timestamp objects, so in order to replicate it, we need a
+ datetime-like object that applies timezone normalization after
+ subtraction.
+ """
+
+ def __sub__(self, other):
+ r = super().__sub__(other)
+ tzinfo = getattr(r, "tzinfo", None)
+
+ if tzinfo is not None:
+ localizer = getattr(tzinfo, "normalize", None)
+ if localizer is not None:
+ r = tzinfo.normalize(r)
+
+ if isinstance(r, datetime.datetime):
+ r = self.mk_tzaware(r)
+
+ return r
+
+ def __add__(self, other):
+ return self.mk_tzaware(super().__add__(other))
+
+ def astimezone(self, tzinfo):
+ dt = super().astimezone(tzinfo)
+ return self.mk_tzaware(dt)
+
+ @classmethod
+ def mk_tzaware(cls, datetime_obj):
+ kwargs = {}
+ attrs = (
+ "year",
+ "month",
+ "day",
+ "hour",
+ "minute",
+ "second",
+ "microsecond",
+ "tzinfo",
+ )
+
+ for attr in attrs:
+ val = getattr(datetime_obj, attr, None)
+ if val is not None:
+ kwargs[attr] = val
+
+ return cls(**kwargs)
+
+ # Define a date_range function similar to pandas.date_range
+ def date_range(start, freq, periods):
+ dtstart = dt_tzaware.mk_tzaware(start)
+
+ return [dtstart + (i * freq) for i in range(periods)]
+
+ # Define a tz_convert function that converts a list to a new timezone.
+ def tz_convert(dt_list, tzinfo):
+ return [d.astimezone(tzinfo) for d in dt_list]
+
+ _test_date2num_dst(date_range, tz_convert)
+
+
+def test_date2num_dst_pandas():
+ # Test for github issue #3896, but in date2num around DST transitions
+ # with a timezone-aware pandas date_range object.
+
+ def tz_convert(*args):
+ return pd.DatetimeIndex.tz_convert(*args).astype(object)
+
+ _test_date2num_dst(pd.date_range, tz_convert)
+
+
+def _test_rrulewrapper(attach_tz, get_tz):
+ SYD = get_tz("Australia/Sydney")
+
+ dtstart = attach_tz(datetime.datetime(2017, 4, 1, 0), SYD)
+ dtend = attach_tz(datetime.datetime(2017, 4, 4, 0), SYD)
+
+ rule = mdates.rrulewrapper(freq=dateutil.rrule.DAILY, dtstart=dtstart)
+
+ act = rule.between(dtstart, dtend)
+ exp = [
+ datetime.datetime(2017, 4, 1, 13, tzinfo=dateutil.tz.tzutc()),
+ datetime.datetime(2017, 4, 2, 14, tzinfo=dateutil.tz.tzutc()),
+ ]
+
+ assert act == exp
+
+
+def test_rrulewrapper():
+ def attach_tz(dt, zi):
+ return dt.replace(tzinfo=zi)
+
+ _test_rrulewrapper(attach_tz, dateutil.tz.gettz)
+
+ SYD = dateutil.tz.gettz("Australia/Sydney")
+ dtstart = datetime.datetime(2017, 4, 1, 0)
+ dtend = datetime.datetime(2017, 4, 4, 0)
+ rule = mdates.rrulewrapper(
+ freq=dateutil.rrule.DAILY, dtstart=dtstart, tzinfo=SYD, until=dtend
+ )
+ assert rule.after(dtstart) == datetime.datetime(2017, 4, 2, 0, 0, tzinfo=SYD)
+ assert rule.before(dtend) == datetime.datetime(2017, 4, 3, 0, 0, tzinfo=SYD)
+
+ # Test parts of __getattr__
+ assert rule._base_tzinfo == SYD
+ assert rule._interval == 1
+
+
+@pytest.mark.pytz
+def test_rrulewrapper_pytz():
+ # Test to make sure pytz zones are supported in rrules
+ pytz = pytest.importorskip("pytz")
+
+ def attach_tz(dt, zi):
+ return zi.localize(dt)
+
+ _test_rrulewrapper(attach_tz, pytz.timezone)
+
+
+@pytest.mark.pytz
+def test_yearlocator_pytz():
+ pytz = pytest.importorskip("pytz")
+
+ tz = pytz.timezone("America/New_York")
+ x = [
+ tz.localize(datetime.datetime(2010, 1, 1)) + datetime.timedelta(i)
+ for i in range(2000)
+ ]
+ locator = mdates.AutoDateLocator(interval_multiples=True, tz=tz)
+ locator.create_dummy_axis()
+ locator.axis.set_view_interval(
+ mdates.date2num(x[0]) - 1.0, mdates.date2num(x[-1]) + 1.0
+ )
+ t = np.array(
+ [
+ 733408.208333,
+ 733773.208333,
+ 734138.208333,
+ 734503.208333,
+ 734869.208333,
+ 735234.208333,
+ 735599.208333,
+ ]
+ )
+ # convert to new epoch from old...
+ t = t + mdates.date2num(np.datetime64("0000-12-31"))
+ np.testing.assert_allclose(t, locator())
+ expected = [
+ "2009-01-01 00:00:00-05:00",
+ "2010-01-01 00:00:00-05:00",
+ "2011-01-01 00:00:00-05:00",
+ "2012-01-01 00:00:00-05:00",
+ "2013-01-01 00:00:00-05:00",
+ "2014-01-01 00:00:00-05:00",
+ "2015-01-01 00:00:00-05:00",
+ ]
+ st = list(map(str, mdates.num2date(locator(), tz=tz)))
+ assert st == expected
+ assert np.allclose(
+ locator.tick_values(x[0], x[1]),
+ np.array(
+ [
+ 14610.20833333,
+ 14610.33333333,
+ 14610.45833333,
+ 14610.58333333,
+ 14610.70833333,
+ 14610.83333333,
+ 14610.95833333,
+ 14611.08333333,
+ 14611.20833333,
+ ]
+ ),
+ )
+ assert np.allclose(
+ locator.get_locator(x[1], x[0]).tick_values(x[0], x[1]),
+ np.array(
+ [
+ 14610.20833333,
+ 14610.33333333,
+ 14610.45833333,
+ 14610.58333333,
+ 14610.70833333,
+ 14610.83333333,
+ 14610.95833333,
+ 14611.08333333,
+ 14611.20833333,
+ ]
+ ),
+ )
+
+
+def test_YearLocator():
+ def _create_year_locator(date1, date2, **kwargs):
+ locator = mdates.YearLocator(**kwargs)
+ locator.create_dummy_axis()
+ locator.axis.set_view_interval(mdates.date2num(date1), mdates.date2num(date2))
+ return locator
+
+ d1 = datetime.datetime(1990, 1, 1)
+ results = (
+ [
+ datetime.timedelta(weeks=52 * 200),
+ {"base": 20, "month": 1, "day": 1},
+ [
+ "1980-01-01 00:00:00+00:00",
+ "2000-01-01 00:00:00+00:00",
+ "2020-01-01 00:00:00+00:00",
+ "2040-01-01 00:00:00+00:00",
+ "2060-01-01 00:00:00+00:00",
+ "2080-01-01 00:00:00+00:00",
+ "2100-01-01 00:00:00+00:00",
+ "2120-01-01 00:00:00+00:00",
+ "2140-01-01 00:00:00+00:00",
+ "2160-01-01 00:00:00+00:00",
+ "2180-01-01 00:00:00+00:00",
+ "2200-01-01 00:00:00+00:00",
+ ],
+ ],
+ [
+ datetime.timedelta(weeks=52 * 200),
+ {"base": 20, "month": 5, "day": 16},
+ [
+ "1980-05-16 00:00:00+00:00",
+ "2000-05-16 00:00:00+00:00",
+ "2020-05-16 00:00:00+00:00",
+ "2040-05-16 00:00:00+00:00",
+ "2060-05-16 00:00:00+00:00",
+ "2080-05-16 00:00:00+00:00",
+ "2100-05-16 00:00:00+00:00",
+ "2120-05-16 00:00:00+00:00",
+ "2140-05-16 00:00:00+00:00",
+ "2160-05-16 00:00:00+00:00",
+ "2180-05-16 00:00:00+00:00",
+ "2200-05-16 00:00:00+00:00",
+ ],
+ ],
+ [
+ datetime.timedelta(weeks=52 * 5),
+ {"base": 20, "month": 9, "day": 25},
+ ["1980-09-25 00:00:00+00:00", "2000-09-25 00:00:00+00:00"],
+ ],
+ )
+
+ for delta, arguments, expected in results:
+ d2 = d1 + delta
+ locator = _create_year_locator(d1, d2, **arguments)
+ assert list(map(str, mdates.num2date(locator()))) == expected
+
+
+def test_DayLocator():
+ with pytest.raises(ValueError):
+ mdates.DayLocator(interval=-1)
+ with pytest.raises(ValueError):
+ mdates.DayLocator(interval=-1.5)
+ with pytest.raises(ValueError):
+ mdates.DayLocator(interval=0)
+ with pytest.raises(ValueError):
+ mdates.DayLocator(interval=1.3)
+ mdates.DayLocator(interval=1.0)
+
+
+def test_tz_utc():
+ dt = datetime.datetime(1970, 1, 1, tzinfo=mdates.UTC)
+ assert dt.tzname() == "UTC"
+
+
+@pytest.mark.parametrize(
+ "x, tdelta",
+ [
+ (1, datetime.timedelta(days=1)),
+ ([1, 1.5], [datetime.timedelta(days=1), datetime.timedelta(days=1.5)]),
+ ],
+)
+def test_num2timedelta(x, tdelta):
+ dt = mdates.num2timedelta(x)
+ assert dt == tdelta
+
+
+def test_datetime64_in_list():
+ dt = [np.datetime64("2000-01-01"), np.datetime64("2001-01-01")]
+ dn = mdates.date2num(dt)
+ # convert fixed values from old to new epoch
+ t = np.array([730120.0, 730486.0]) + mdates.date2num(np.datetime64("0000-12-31"))
+ np.testing.assert_equal(dn, t)
+
+
+def test_change_epoch():
+ date = np.datetime64("2000-01-01")
+
+ # use private method to clear the epoch and allow it to be set...
+ mdates._reset_epoch_test_example()
+ mdates.get_epoch() # Set default.
+
+ with pytest.raises(RuntimeError):
+ # this should fail here because there is a sentinel on the epoch:
+ # once the epoch has been used, it cannot be set again.
+ mdates.set_epoch("0000-01-01")
+
+ mdates._reset_epoch_test_example()
+ mdates.set_epoch("1970-01-01")
+ dt = (date - np.datetime64("1970-01-01")).astype("datetime64[D]")
+ dt = dt.astype("int")
+ np.testing.assert_equal(mdates.date2num(date), float(dt))
+
+ mdates._reset_epoch_test_example()
+ mdates.set_epoch("0000-12-31")
+ np.testing.assert_equal(mdates.date2num(date), 730120.0)
+
+ mdates._reset_epoch_test_example()
+ mdates.set_epoch("1970-01-01T01:00:00")
+ np.testing.assert_allclose(mdates.date2num(date), dt - 1.0 / 24.0)
+ mdates._reset_epoch_test_example()
+ mdates.set_epoch("1970-01-01T00:00:00")
+ np.testing.assert_allclose(
+ mdates.date2num(np.datetime64("1970-01-01T12:00:00")), 0.5
+ )
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_change_converter():
+ plt.rcParams["date.converter"] = "concise"
+ dates = np.arange("2020-01-01", "2020-05-01", dtype="datetime64[D]")
+ fig, ax = plt.subplots()
+
+ ax.plot(dates, np.arange(len(dates)))
+ fig.canvas.draw()
+ assert ax.get_xticklabels()[0].get_text() == "Jan"
+ assert ax.get_xticklabels()[1].get_text() == "15"
+
+ plt.rcParams["date.converter"] = "auto"
+ fig, ax = plt.subplots()
+
+ ax.plot(dates, np.arange(len(dates)))
+ fig.canvas.draw()
+ assert ax.get_xticklabels()[0].get_text() == "Jan 01 2020"
+ assert ax.get_xticklabels()[1].get_text() == "Jan 15 2020"
+ with pytest.raises(ValueError):
+ plt.rcParams["date.converter"] = "boo"
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_change_interval_multiples():
+ plt.rcParams["date.interval_multiples"] = False
+ dates = np.arange("2020-01-10", "2020-05-01", dtype="datetime64[D]")
+ fig, ax = plt.subplots()
+
+ ax.plot(dates, np.arange(len(dates)))
+ fig.canvas.draw()
+ assert ax.get_xticklabels()[0].get_text() == "Jan 10 2020"
+ assert ax.get_xticklabels()[1].get_text() == "Jan 24 2020"
+
+ plt.rcParams["date.interval_multiples"] = "True"
+ fig, ax = plt.subplots()
+
+ ax.plot(dates, np.arange(len(dates)))
+ fig.canvas.draw()
+ assert ax.get_xticklabels()[0].get_text() == "Jan 15 2020"
+ assert ax.get_xticklabels()[1].get_text() == "Feb 01 2020"
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_julian2num():
+ with pytest.warns(_api.MatplotlibDeprecationWarning):
+ mdates._reset_epoch_test_example()
+ mdates.set_epoch("0000-12-31")
+ # 2440587.5 is julian date for 1970-01-01T00:00:00
+ # https://en.wikipedia.org/wiki/Julian_day
+ assert mdates.julian2num(2440588.5) == 719164.0
+ assert mdates.num2julian(719165.0) == 2440589.5
+ # set back to the default
+ mdates._reset_epoch_test_example()
+ mdates.set_epoch("1970-01-01T00:00:00")
+ assert mdates.julian2num(2440588.5) == 1.0
+ assert mdates.num2julian(2.0) == 2440589.5
+
+
+def test_DateLocator():
+ locator = mdates.DateLocator()
+ # Test nonsingular
+ assert locator.nonsingular(0, np.inf) == (0, 1)
+ assert locator.nonsingular(0, 1) == (0, 1)
+ assert locator.nonsingular(1, 0) == (0, 1)
+ assert locator.nonsingular(0, 0) == (-2, 2)
+ locator.create_dummy_axis()
+ # default values
+ assert locator.datalim_to_dt() == (
+ datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc),
+ datetime.datetime(1970, 1, 2, 0, 0, tzinfo=datetime.timezone.utc),
+ )
+
+ # Check default is UTC
+ assert locator.tz == mdates.UTC
+ tz_str = "Iceland"
+ iceland_tz = dateutil.tz.gettz(tz_str)
+ # Check not Iceland
+ assert locator.tz != iceland_tz
+ # Set it to Iceland
+ locator.set_tzinfo("Iceland")
+ # Check now it is Iceland
+ assert locator.tz == iceland_tz
+ locator.create_dummy_axis()
+ locator.axis.set_data_interval(*mdates.date2num(["2022-01-10", "2022-01-08"]))
+ assert locator.datalim_to_dt() == (
+ datetime.datetime(2022, 1, 8, 0, 0, tzinfo=iceland_tz),
+ datetime.datetime(2022, 1, 10, 0, 0, tzinfo=iceland_tz),
+ )
+
+ # Set rcParam
+ plt.rcParams["timezone"] = tz_str
+
+ # Create a new one in a similar way
+ locator = mdates.DateLocator()
+ # Check now it is Iceland
+ assert locator.tz == iceland_tz
+
+ # Test invalid tz values
+ with pytest.raises(ValueError, match="Aiceland is not a valid timezone"):
+ mdates.DateLocator(tz="Aiceland")
+ with pytest.raises(TypeError, match="tz must be string or tzinfo subclass."):
+ mdates.DateLocator(tz=1)
+
+
+def test_datestr2num():
+ assert mdates.datestr2num("2022-01-10") == 19002.0
+ dt = datetime.date(year=2022, month=1, day=10)
+ assert mdates.datestr2num("2022-01", default=dt) == 19002.0
+ assert np.all(
+ mdates.datestr2num(["2022-01", "2022-02"], default=dt)
+ == np.array([19002.0, 19033.0])
+ )
+ assert mdates.datestr2num([]).size == 0
+ assert mdates.datestr2num([], datetime.date(year=2022, month=1, day=10)).size == 0
+
+
+@pytest.mark.parametrize("kwarg", ("formats", "zero_formats", "offset_formats"))
+def test_concise_formatter_exceptions(kwarg):
+ locator = mdates.AutoDateLocator()
+ kwargs = {kwarg: ["", "%Y"]}
+ match = f"{kwarg} argument must be a list"
+ with pytest.raises(ValueError, match=match):
+ mdates.ConciseDateFormatter(locator, **kwargs)
+
+
+def test_concise_formatter_call():
+ locator = mdates.AutoDateLocator()
+ formatter = mdates.ConciseDateFormatter(locator)
+ assert formatter(19002.0) == "2022"
+ assert formatter.format_data_short(19002.0) == "2022-01-10 00:00:00"
+
+
+@pytest.mark.parametrize(
+ "span, expected_locator",
+ (
+ (0.02, mdates.MinuteLocator),
+ (1, mdates.HourLocator),
+ (19, mdates.DayLocator),
+ (40, mdates.WeekdayLocator),
+ (200, mdates.MonthLocator),
+ (2000, mdates.YearLocator),
+ ),
+)
+def test_date_ticker_factory(span, expected_locator):
+ with pytest.warns(_api.MatplotlibDeprecationWarning):
+ locator, _ = mdates.date_ticker_factory(span)
+ assert isinstance(locator, expected_locator)
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_datetime_masked():
+ # make sure that all-masked data falls back to the viewlim
+ # set in convert.axisinfo....
+ x = np.array([datetime.datetime(2017, 1, n) for n in range(1, 6)])
+ y = np.array([1, 2, 3, 4, 5])
+ m = np.ma.masked_greater(y, 0)
+
+ fig, ax = plt.subplots()
+ ax.plot(x, m)
+ assert ax.get_xlim() == (0, 1)
+
+
+@pytest.mark.parametrize("val", (-1000000, 10000000))
+def test_num2date_error(val):
+ with pytest.raises(ValueError, match=f"Date ordinal {val} converts"):
+ mdates.num2date(val)
+
+
+def test_num2date_roundoff():
+ assert mdates.num2date(100000.0000578702) == datetime.datetime(
+ 2243, 10, 17, 0, 0, 4, 999980, tzinfo=datetime.timezone.utc
+ )
+ # Slightly larger, steps of 20 microseconds
+ assert mdates.num2date(100000.0000578703) == datetime.datetime(
+ 2243, 10, 17, 0, 0, 5, tzinfo=datetime.timezone.utc
+ )
+
+
+def test_DateFormatter_settz():
+ time = mdates.date2num(datetime.datetime(2011, 1, 1, 0, 0, tzinfo=mdates.UTC))
+ formatter = mdates.DateFormatter("%Y-%b-%d %H:%M")
+ # Default UTC
+ assert formatter(time) == "2011-Jan-01 00:00"
+
+ # Set tzinfo
+ formatter.set_tzinfo("Pacific/Kiritimati")
+ assert formatter(time) == "2011-Jan-01 14:00"
diff --git a/modin/pandas/test/interoperability/matplotlib/test_preprocess_data.py b/modin/pandas/test/interoperability/matplotlib/test_preprocess_data.py
new file mode 100644
index 00000000000..56c3d65899a
--- /dev/null
+++ b/modin/pandas/test/interoperability/matplotlib/test_preprocess_data.py
@@ -0,0 +1,69 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+import numpy as np
+import pytest
+from matplotlib import _preprocess_data
+import modin.pandas as pd
+
+# Notes on testing the plotting functions themselves
+# * the individual decorated plotting functions are tested in 'test_axes.py'
+# * that pyplot functions accept a data kwarg is only tested in
+# test_axes.test_pie_linewidth_0
+
+
+# this gets used in multiple tests, so define it here
+@_preprocess_data(replace_names=["x", "y"], label_namer="y")
+def plot_func(ax, x, y, ls="x", label=None, w="xyz"):
+ return "x: %s, y: %s, ls: %s, w: %s, label: %s" % (list(x), list(y), ls, w, label)
+
+
+all_funcs = [plot_func]
+all_func_ids = ["plot_func"]
+
+
+@pytest.mark.parametrize("func", all_funcs, ids=all_func_ids)
+def test_function_call_with_pandas_data(func):
+ """Test with pandas dataframe -> label comes from ``data["col"].name``."""
+ data = pd.DataFrame(
+ {
+ "a": np.array([1, 2], dtype=np.int32),
+ "b": np.array([8, 9], dtype=np.int32),
+ "w": ["NOT", "NOT"],
+ }
+ )
+
+ assert (
+ func(None, "a", "b", data=data)
+ == "x: [1, 2], y: [8, 9], ls: x, w: xyz, label: b"
+ )
+ assert (
+ func(None, x="a", y="b", data=data)
+ == "x: [1, 2], y: [8, 9], ls: x, w: xyz, label: b"
+ )
+ assert (
+ func(None, "a", "b", label="", data=data)
+ == "x: [1, 2], y: [8, 9], ls: x, w: xyz, label: "
+ )
+ assert (
+ func(None, "a", "b", label="text", data=data)
+ == "x: [1, 2], y: [8, 9], ls: x, w: xyz, label: text"
+ )
+ assert (
+ func(None, x="a", y="b", label="", data=data)
+ == "x: [1, 2], y: [8, 9], ls: x, w: xyz, label: "
+ )
+ assert (
+ func(None, x="a", y="b", label="text", data=data)
+ == "x: [1, 2], y: [8, 9], ls: x, w: xyz, label: text"
+ )
diff --git a/modin/pandas/test/interoperability/plotly/test_io/test_to_from_plotly_json.py b/modin/pandas/test/interoperability/plotly/test_io/test_to_from_plotly_json.py
new file mode 100644
index 00000000000..4c6e13bc7f8
--- /dev/null
+++ b/modin/pandas/test/interoperability/plotly/test_io/test_to_from_plotly_json.py
@@ -0,0 +1,238 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+import pytest
+import plotly.io.json as pio
+import plotly.graph_objects as go
+import plotly.express as px
+import numpy as np
+import modin.pandas as pd
+import json
+import datetime
+import sys
+from pytz import timezone
+from _plotly_utils.optional_imports import get_module
+
+orjson = get_module("orjson")
+
+eastern = timezone("US/Eastern")
+
+
+# Testing helper
+def build_json_opts(pretty=False):
+ opts = {"sort_keys": True}
+ if pretty:
+ opts["indent"] = 2
+ else:
+ opts["separators"] = (",", ":")
+ return opts
+
+
+def to_json_test(value, pretty=False):
+ return json.dumps(value, **build_json_opts(pretty=pretty))
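+# For illustration (assumption, not part of the ported plotly tests): with the
+# default pretty=False options above, the reference serialization is compact and
+# key-sorted, e.g. to_json_test({"b": 1, "a": [1, 2]}) -> '{"a":[1,2],"b":1}'.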
+
+
+def isoformat_test(dt_value):
+ if isinstance(dt_value, np.datetime64):
+ return str(dt_value)
+ elif isinstance(dt_value, datetime.datetime):
+ return dt_value.isoformat()
+ else:
+ raise ValueError("Unsupported date type: {}".format(type(dt_value)))
+
+
+def build_test_dict(value):
+ return dict(a=value, b=[3, value], c=dict(Z=value))
+
+
+def build_test_dict_string(value_string, pretty=False):
+ if pretty:
+ non_pretty_str = build_test_dict_string(value_string, pretty=False)
+ return to_json_test(json.loads(non_pretty_str), pretty=True)
+ else:
+ value_string = str(value_string).replace(" ", "")
+ return """{"a":%s,"b":[3,%s],"c":{"Z":%s}}""" % tuple([value_string] * 3)
+
+
+def check_roundtrip(value, engine, pretty):
+ encoded = pio.to_json_plotly(value, engine=engine, pretty=pretty)
+ decoded = pio.from_json_plotly(encoded, engine=engine)
+ reencoded = pio.to_json_plotly(decoded, engine=engine, pretty=pretty)
+ assert encoded == reencoded
+
+ # Check from_plotly_json with bytes on Python 3
+ if sys.version_info.major == 3:
+ encoded_bytes = encoded.encode("utf8")
+ decoded_from_bytes = pio.from_json_plotly(encoded_bytes, engine=engine)
+ assert decoded == decoded_from_bytes
+
+
+# Fixtures
+if orjson is not None:
+ engines = ["json", "orjson", "auto"]
+else:
+ engines = ["json", "auto"]
+
+
+@pytest.fixture(scope="module", params=engines)
+def engine(request):
+ return request.param
+
+
+@pytest.fixture(scope="module", params=[False])
+def pretty(request):
+ return request.param
+
+
+@pytest.fixture(scope="module", params=["float64", "int32", "uint32"])
+def graph_object(request):
+ return request.param
+
+
+@pytest.fixture(scope="module", params=["float64", "int32", "uint32"])
+def numeric_numpy_array(request):
+ dtype = request.param
+ return np.linspace(-5, 5, 4, dtype=dtype)
+
+
+@pytest.fixture(scope="module")
+def object_numpy_array(request):
+ return np.array(["a", 1, [2, 3]])
+
+
+@pytest.fixture(scope="module")
+def numpy_unicode_array(request):
+ return np.array(["A", "BB", "CCC"], dtype="U")
+
+
+@pytest.fixture(
+ scope="module",
+ params=[
+ datetime.datetime(2003, 7, 12, 8, 34, 22),
+ datetime.datetime.now(),
+ np.datetime64(datetime.datetime.utcnow()),
+ pd.Timestamp(datetime.datetime.now()),
+ eastern.localize(datetime.datetime(2003, 7, 12, 8, 34, 22)),
+ eastern.localize(datetime.datetime.now()),
+ pd.Timestamp(datetime.datetime.now(), tzinfo=eastern),
+ ],
+)
+def datetime_value(request):
+ return request.param
+
+
+@pytest.fixture(
+ params=[
+ list, # plain list of datetime values
+ lambda a: pd.DatetimeIndex(a), # Pandas DatetimeIndex
+ lambda a: pd.Series(pd.DatetimeIndex(a)), # Pandas Datetime Series
+ lambda a: pd.DatetimeIndex(a).values, # Numpy datetime64 array
+ lambda a: np.array(a, dtype="object"), # Numpy object array of datetime
+ ]
+)
+def datetime_array(request, datetime_value):
+ return request.param([datetime_value] * 3)
+
+
+# Encoding tests
+def test_graph_object_input(engine, pretty):
+ scatter = go.Scatter(x=[1, 2, 3], y=np.array([4, 5, 6]))
+ result = pio.to_json_plotly(scatter, engine=engine)
+ expected = """{"x":[1,2,3],"y":[4,5,6],"type":"scatter"}"""
+ assert result == expected
+ check_roundtrip(result, engine=engine, pretty=pretty)
+
+
+def test_numeric_numpy_encoding(numeric_numpy_array, engine, pretty):
+ value = build_test_dict(numeric_numpy_array)
+ result = pio.to_json_plotly(value, engine=engine, pretty=pretty)
+
+ array_str = to_json_test(numeric_numpy_array.tolist())
+ expected = build_test_dict_string(array_str, pretty=pretty)
+ assert result == expected
+ check_roundtrip(result, engine=engine, pretty=pretty)
+
+
+def test_numpy_unicode_encoding(numpy_unicode_array, engine, pretty):
+ value = build_test_dict(numpy_unicode_array)
+ result = pio.to_json_plotly(value, engine=engine, pretty=pretty)
+
+ array_str = to_json_test(numpy_unicode_array.tolist())
+ expected = build_test_dict_string(array_str)
+ assert result == expected
+ check_roundtrip(result, engine=engine, pretty=pretty)
+
+
+@pytest.mark.skip(reason="fails in plotly")
+def test_object_numpy_encoding(object_numpy_array, engine, pretty):
+ value = build_test_dict(object_numpy_array)
+ result = pio.to_json_plotly(value, engine=engine, pretty=pretty)
+
+ array_str = to_json_test(object_numpy_array.tolist())
+ expected = build_test_dict_string(array_str)
+ assert result == expected
+ check_roundtrip(result, engine=engine, pretty=pretty)
+
+
+def test_datetime(datetime_value, engine, pretty):
+ value = build_test_dict(datetime_value)
+ result = pio.to_json_plotly(value, engine=engine, pretty=pretty)
+ expected = build_test_dict_string('"{}"'.format(isoformat_test(datetime_value)))
+ assert result == expected
+ check_roundtrip(result, engine=engine, pretty=pretty)
+
+
+@pytest.mark.skip(reason="Failing Test")
+def test_datetime_arrays(datetime_array, engine, pretty):
+ value = build_test_dict(datetime_array)
+ result = pio.to_json_plotly(value, engine=engine)
+
+ def to_str(v):
+ try:
+ v = v.isoformat(sep="T")
+ except (TypeError, AttributeError):
+ pass
+
+ return str(v)
+
+ if isinstance(datetime_array, list):
+ dt_values = [to_str(d) for d in datetime_array]
+ elif isinstance(datetime_array, pd.Series):
+ dt_values = [to_str(d) for d in datetime_array.dt.to_pydatetime().tolist()]
+ elif isinstance(datetime_array, pd.DatetimeIndex):
+ dt_values = [to_str(d) for d in datetime_array.to_pydatetime().tolist()]
+ else: # numpy datetime64 array
+ dt_values = [to_str(d) for d in datetime_array]
+
+ array_str = to_json_test(dt_values)
+ expected = build_test_dict_string(array_str)
+ assert result == expected
+ check_roundtrip(result, engine=engine, pretty=pretty)
+
+
+def test_object_array(engine, pretty):
+ fig = px.scatter(px.data.tips(), x="total_bill", y="tip", custom_data=["sex"])
+ result = fig.to_plotly_json()
+ check_roundtrip(result, engine=engine, pretty=pretty)
+
+
+def test_nonstring_key(engine, pretty):
+ value = build_test_dict({0: 1})
+ result = pio.to_json_plotly(value, engine=engine)
+ check_roundtrip(result, engine=engine, pretty=pretty)
+
+
+def test_mixed_string_nonstring_key(engine, pretty):
+ value = build_test_dict({0: 1, "a": 2})
+ result = pio.to_json_plotly(value, engine=engine)
+ check_roundtrip(result, engine=engine, pretty=pretty)
diff --git a/modin/pandas/test/interoperability/plotly/test_optional/test_matplotlylib/test_date_times.py b/modin/pandas/test/interoperability/plotly/test_optional/test_matplotlylib/test_date_times.py
new file mode 100644
index 00000000000..657739a8706
--- /dev/null
+++ b/modin/pandas/test/interoperability/plotly/test_optional/test_matplotlylib/test_date_times.py
@@ -0,0 +1,82 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+from __future__ import absolute_import
+import datetime
+import random
+from unittest import TestCase
+import pytest
+import modin.pandas as pd
+import plotly.tools as tls
+from plotly import optional_imports
+
+matplotlylib = optional_imports.get_module("plotly.matplotlylib")
+if matplotlylib:
+ from matplotlib.dates import date2num
+ import matplotlib.pyplot as plt
+
+
+@pytest.mark.skip
+class TestDateTimes(TestCase):
+ def test_normal_mpl_dates(self):
+ datetime_format = "%Y-%m-%d %H:%M:%S"
+ y = [1, 2, 3, 4]
+ date_strings = [
+ "2010-01-04 00:00:00",
+ "2010-01-04 10:00:00",
+ "2010-01-04 23:00:59",
+ "2010-01-05 00:00:00",
+ ]
+
+ # 1. create datetimes from the strings
+ dates = [
+ datetime.datetime.strptime(date_string, datetime_format)
+ for date_string in date_strings
+ ]
+
+ # 2. create the mpl_dates from these datetimes
+ mpl_dates = date2num(dates)
+
+ # make a figure in mpl
+ fig, ax = plt.subplots()
+ ax.plot_date(mpl_dates, y)
+
+ # convert this figure to plotly's graph_objs
+ pfig = tls.mpl_to_plotly(fig)
+
+ # we use the same format on both sides, so we expect equality
+ self.assertEqual(fig.axes[0].lines[0].get_xydata()[0][0], 7.33776000e05)
+ self.assertEqual(tuple(pfig["data"][0]["x"]), tuple(date_strings))
+
+ def test_pandas_time_series_date_formatter(self):
+ ndays = 3
+ x = pd.date_range("1/1/2001", periods=ndays, freq="D")
+ y = [random.randint(0, 10) for i in range(ndays)]
+ s = pd.DataFrame(y, columns=["a"])
+
+ s["Date"] = x
+ s.plot(x="Date")
+
+ fig = plt.gcf()
+ pfig = tls.mpl_to_plotly(fig)
+
+ expected_x = (
+ "2001-01-01 00:00:00",
+ "2001-01-02 00:00:00",
+ "2001-01-03 00:00:00",
+ )
+ expected_x0 = 11323.0 # this is floating point days since epoch
+
+ x0 = fig.axes[0].lines[0].get_xydata()[0][0]
+ self.assertEqual(x0, expected_x0)
+ self.assertEqual(pfig["data"][0]["x"], expected_x)
diff --git a/modin/pandas/test/interoperability/plotly/test_optional/test_px/test_facets.py b/modin/pandas/test/interoperability/plotly/test_optional/test_px/test_facets.py
new file mode 100644
index 00000000000..07f29840d43
--- /dev/null
+++ b/modin/pandas/test/interoperability/plotly/test_optional/test_px/test_facets.py
@@ -0,0 +1,145 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+import modin.pandas as pd
+import plotly.express as px
+from pytest import approx
+import pytest
+import random
+
+
+def test_facets():
+ df = px.data.tips()
+ fig = px.scatter(df, x="total_bill", y="tip")
+ assert "xaxis2" not in fig.layout
+ assert "yaxis2" not in fig.layout
+ assert fig.layout.xaxis.domain == (0.0, 1.0)
+ assert fig.layout.yaxis.domain == (0.0, 1.0)
+
+ fig = px.scatter(df, x="total_bill", y="tip", facet_row="sex", facet_col="smoker")
+ assert fig.layout.xaxis4.domain[0] - fig.layout.xaxis.domain[1] == approx(0.02)
+ assert fig.layout.yaxis4.domain[0] - fig.layout.yaxis.domain[1] == approx(0.03)
+
+ fig = px.scatter(df, x="total_bill", y="tip", facet_col="day", facet_col_wrap=2)
+ assert fig.layout.xaxis4.domain[0] - fig.layout.xaxis.domain[1] == approx(0.02)
+ assert fig.layout.yaxis4.domain[0] - fig.layout.yaxis.domain[1] == approx(0.07)
+
+ fig = px.scatter(
+ df,
+ x="total_bill",
+ y="tip",
+ facet_row="sex",
+ facet_col="smoker",
+ facet_col_spacing=0.09,
+ facet_row_spacing=0.08,
+ )
+ assert fig.layout.xaxis4.domain[0] - fig.layout.xaxis.domain[1] == approx(0.09)
+ assert fig.layout.yaxis4.domain[0] - fig.layout.yaxis.domain[1] == approx(0.08)
+
+ fig = px.scatter(
+ df,
+ x="total_bill",
+ y="tip",
+ facet_col="day",
+ facet_col_wrap=2,
+ facet_col_spacing=0.09,
+ facet_row_spacing=0.08,
+ )
+ assert fig.layout.xaxis4.domain[0] - fig.layout.xaxis.domain[1] == approx(0.09)
+ assert fig.layout.yaxis4.domain[0] - fig.layout.yaxis.domain[1] == approx(0.08)
+
+
+def test_facets_with_marginals():
+ df = px.data.tips()
+
+ fig = px.histogram(df, x="total_bill", facet_col="sex", marginal="rug")
+ assert len(fig.data) == 4
+ fig = px.histogram(df, x="total_bill", facet_row="sex", marginal="rug")
+ assert len(fig.data) == 2
+
+ fig = px.histogram(df, y="total_bill", facet_col="sex", marginal="rug")
+ assert len(fig.data) == 2
+ fig = px.histogram(df, y="total_bill", facet_row="sex", marginal="rug")
+ assert len(fig.data) == 4
+
+ fig = px.scatter(df, x="total_bill", y="tip", facet_col="sex", marginal_x="rug")
+ assert len(fig.data) == 4
+ fig = px.scatter(
+ df, x="total_bill", y="tip", facet_col="day", facet_col_wrap=2, marginal_x="rug"
+ )
+ assert len(fig.data) == 8 # ignore the wrap when marginal is used
+ fig = px.scatter(df, x="total_bill", y="tip", facet_col="sex", marginal_y="rug")
+ assert len(fig.data) == 2 # ignore the marginal in the facet direction
+
+ fig = px.scatter(df, x="total_bill", y="tip", facet_row="sex", marginal_x="rug")
+ assert len(fig.data) == 2 # ignore the marginal in the facet direction
+ fig = px.scatter(df, x="total_bill", y="tip", facet_row="sex", marginal_y="rug")
+ assert len(fig.data) == 4
+
+ fig = px.scatter(
+ df, x="total_bill", y="tip", facet_row="sex", marginal_y="rug", marginal_x="rug"
+ )
+ assert len(fig.data) == 4 # ignore the marginal in the facet direction
+ fig = px.scatter(
+ df, x="total_bill", y="tip", facet_col="sex", marginal_y="rug", marginal_x="rug"
+ )
+ assert len(fig.data) == 4 # ignore the marginal in the facet direction
+ fig = px.scatter(
+ df,
+ x="total_bill",
+ y="tip",
+ facet_row="sex",
+ facet_col="sex",
+ marginal_y="rug",
+ marginal_x="rug",
+ )
+ assert len(fig.data) == 2 # ignore all marginals
+
+
+@pytest.fixture
+def bad_facet_spacing_df():
+ NROWS = 101
+ NDATA = 1000
+ categories = [n % NROWS for n in range(NDATA)]
+ df = pd.DataFrame(
+ {
+ "x": [random.random() for _ in range(NDATA)],
+ "y": [random.random() for _ in range(NDATA)],
+ "category": categories,
+ }
+ )
+ return df
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_bad_facet_spacing_error(bad_facet_spacing_df):
+ df = bad_facet_spacing_df
+ with pytest.raises(
+ ValueError, match="Use the facet_row_spacing argument to adjust this spacing."
+ ):
+ px.scatter(df, x="x", y="y", facet_row="category", facet_row_spacing=0.01001)
+ with pytest.raises(
+ ValueError, match="Use the facet_col_spacing argument to adjust this spacing."
+ ):
+ px.scatter(df, x="x", y="y", facet_col="category", facet_col_spacing=0.01001)
+ # Check error is not raised when the spacing is OK
+ try:
+ px.scatter(df, x="x", y="y", facet_row="category", facet_row_spacing=0.01)
+ except ValueError:
+ # Error shouldn't be raised, so fail if it is
+ assert False
+ try:
+ px.scatter(df, x="x", y="y", facet_col="category", facet_col_spacing=0.01)
+ except ValueError:
+ # Error shouldn't be raised, so fail if it is
+ assert False
diff --git a/modin/pandas/test/interoperability/plotly/test_optional/test_px/test_px_functions.py b/modin/pandas/test/interoperability/plotly/test_optional/test_px/test_px_functions.py
new file mode 100644
index 00000000000..bf11b0be0e1
--- /dev/null
+++ b/modin/pandas/test/interoperability/plotly/test_optional/test_px/test_px_functions.py
@@ -0,0 +1,251 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+import plotly.express as px
+import plotly.graph_objects as go
+from numpy.testing import assert_array_equal
+import numpy as np
+import modin.pandas as pd
+import pytest
+
+
+def _compare_figures(go_trace, px_fig):
+ """Compare a figure created with a go trace and a figure created with
+ a px function call. Check that all values inside the go Figure are the
+ same in the px figure (which sets more parameters).
+ """
+ go_fig = go.Figure(go_trace)
+ go_fig = go_fig.to_plotly_json()
+ px_fig = px_fig.to_plotly_json()
+ del go_fig["layout"]["template"]
+ del px_fig["layout"]["template"]
+ for key in go_fig["data"][0]:
+ assert_array_equal(go_fig["data"][0][key], px_fig["data"][0][key])
+ for key in go_fig["layout"]:
+ assert go_fig["layout"][key] == px_fig["layout"][key]
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_sunburst_treemap_with_path():
+ vendors = ["A", "B", "C", "D", "E", "F", "G", "H"]
+ sectors = [
+ "Tech",
+ "Tech",
+ "Finance",
+ "Finance",
+ "Tech",
+ "Tech",
+ "Finance",
+ "Finance",
+ ]
+ regions = ["North", "North", "North", "North", "South", "South", "South", "South"]
+ values = [1, 3, 2, 4, 2, 2, 1, 4]
+ total = ["total"] * 8
+ df = pd.DataFrame(
+ dict(
+ vendors=vendors,
+ sectors=sectors,
+ regions=regions,
+ values=values,
+ total=total,
+ )
+ )
+ path = ["total", "regions", "sectors", "vendors"]
+ # No values
+ fig = px.sunburst(df, path=path)
+ assert fig.data[0].branchvalues == "total"
+ # Values passed
+ fig = px.sunburst(df, path=path, values="values")
+ assert fig.data[0].branchvalues == "total"
+ assert fig.data[0].values[-1] == np.sum(values)
+ # Error when values cannot be converted to numerical data type
+ df["values"] = ["1 000", "3 000", "2", "4", "2", "2", "1 000", "4 000"]
+ msg = "Column `values` of `df` could not be converted to a numerical data type."
+ with pytest.raises(ValueError, match=msg):
+ fig = px.sunburst(df, path=path, values="values")
+ # path is a mixture of column names and array-like
+ path = [df.total, "regions", df.sectors, "vendors"]
+ fig = px.sunburst(df, path=path)
+ assert fig.data[0].branchvalues == "total"
+ # Continuous colorscale
+ df["values"] = 1
+ fig = px.sunburst(df, path=path, values="values", color="values")
+ assert "coloraxis" in fig.data[0].marker
+ assert np.all(np.array(fig.data[0].marker.colors) == 1)
+ assert fig.data[0].values[-1] == 8
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_sunburst_treemap_with_path_color():
+ vendors = ["A", "B", "C", "D", "E", "F", "G", "H"]
+ sectors = [
+ "Tech",
+ "Tech",
+ "Finance",
+ "Finance",
+ "Tech",
+ "Tech",
+ "Finance",
+ "Finance",
+ ]
+ regions = ["North", "North", "North", "North", "South", "South", "South", "South"]
+ values = [1, 3, 2, 4, 2, 2, 1, 4]
+ calls = [8, 2, 1, 3, 2, 2, 4, 1]
+ total = ["total"] * 8
+ df = pd.DataFrame(
+ dict(
+ vendors=vendors,
+ sectors=sectors,
+ regions=regions,
+ values=values,
+ total=total,
+ calls=calls,
+ )
+ )
+ path = ["total", "regions", "sectors", "vendors"]
+ fig = px.sunburst(df, path=path, values="values", color="calls")
+ colors = fig.data[0].marker.colors
+ assert np.all(np.array(colors[:8]) == np.array(calls))
+ fig = px.sunburst(df, path=path, color="calls")
+ colors = fig.data[0].marker.colors
+ assert np.all(np.array(colors[:8]) == np.array(calls))
+
+ # Hover info
+ df["hover"] = [el.lower() for el in vendors]
+ fig = px.sunburst(df, path=path, color="calls", hover_data=["hover"])
+ custom = fig.data[0].customdata
+ assert np.all(custom[:8, 0] == df["hover"])
+ assert np.all(custom[8:, 0] == "(?)")
+ assert np.all(custom[:8, 1] == df["calls"])
+
+ # Discrete color
+ fig = px.sunburst(df, path=path, color="vendors")
+ assert len(np.unique(fig.data[0].marker.colors)) == 9
+
+ # Discrete color and color_discrete_map
+ cmap = {"Tech": "yellow", "Finance": "magenta", "(?)": "black"}
+ fig = px.sunburst(df, path=path, color="sectors", color_discrete_map=cmap)
+ assert np.all(np.in1d(fig.data[0].marker.colors, list(cmap.values())))
+
+ # Numerical column in path
+ df["regions"] = df["regions"].map({"North": 1, "South": 2})
+ path = ["total", "regions", "sectors", "vendors"]
+ fig = px.sunburst(df, path=path, values="values", color="calls")
+ colors = fig.data[0].marker.colors
+ assert np.all(np.array(colors[:8]) == np.array(calls))
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_sunburst_treemap_column_parent():
+ vendors = ["A", "B", "C", "D", "E", "F", "G", "H"]
+ sectors = [
+ "Tech",
+ "Tech",
+ "Finance",
+ "Finance",
+ "Tech",
+ "Tech",
+ "Finance",
+ "Finance",
+ ]
+ regions = ["North", "North", "North", "North", "South", "South", "South", "South"]
+ values = [1, 3, 2, 4, 2, 2, 1, 4]
+ df = pd.DataFrame(
+ dict(
+ id=vendors,
+ sectors=sectors,
+ parent=regions,
+ values=values,
+ )
+ )
+ path = ["parent", "sectors", "id"]
+ # One column of the path is a reserved name - this is ok and should not raise
+ px.sunburst(df, path=path, values="values")
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_sunburst_treemap_with_path_non_rectangular():
+ vendors = ["A", "B", "C", "D", None, "E", "F", "G", "H", None]
+ sectors = [
+ "Tech",
+ "Tech",
+ "Finance",
+ "Finance",
+ None,
+ "Tech",
+ "Tech",
+ "Finance",
+ "Finance",
+ "Finance",
+ ]
+ regions = [
+ "North",
+ "North",
+ "North",
+ "North",
+ "North",
+ "South",
+ "South",
+ "South",
+ "South",
+ "South",
+ ]
+ values = [1, 3, 2, 4, 1, 2, 2, 1, 4, 1]
+ total = ["total"] * 10
+ df = pd.DataFrame(
+ dict(
+ vendors=vendors,
+ sectors=sectors,
+ regions=regions,
+ values=values,
+ total=total,
+ )
+ )
+ path = ["total", "regions", "sectors", "vendors"]
+ msg = "Non-leaves rows are not permitted in the dataframe"
+ with pytest.raises(ValueError, match=msg):
+ fig = px.sunburst(df, path=path, values="values")
+ df.loc[df["vendors"].isnull(), "sectors"] = "Other"
+ fig = px.sunburst(df, path=path, values="values")
+ assert fig.data[0].values[-1] == np.sum(values)
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_timeline():
+ df = pd.DataFrame(
+ [
+ dict(Task="Job A", Start="2009-01-01", Finish="2009-02-28"),
+ dict(Task="Job B", Start="2009-03-05", Finish="2009-04-15"),
+ dict(Task="Job C", Start="2009-02-20", Finish="2009-05-30"),
+ ]
+ )
+ fig = px.timeline(df, x_start="Start", x_end="Finish", y="Task", color="Task")
+ assert len(fig.data) == 3
+ assert fig.layout.xaxis.type == "date"
+ assert fig.layout.xaxis.title.text is None
+ fig = px.timeline(df, x_start="Start", x_end="Finish", y="Task", facet_row="Task")
+ assert len(fig.data) == 3
+ assert fig.data[1].xaxis == "x2"
+ assert fig.layout.xaxis.type == "date"
+
+ msg = "Both x_start and x_end are required"
+ with pytest.raises(ValueError, match=msg):
+ px.timeline(df, x_start="Start", y="Task", color="Task")
+
+ msg = "Both x_start and x_end must refer to data convertible to datetimes."
+ with pytest.raises(TypeError, match=msg):
+ px.timeline(df, x_start="Start", x_end=["a", "b", "c"], y="Task", color="Task")
diff --git a/modin/pandas/test/interoperability/plotly/test_optional/test_px/test_px_hover.py b/modin/pandas/test/interoperability/plotly/test_optional/test_px/test_px_hover.py
new file mode 100644
index 00000000000..bb524925421
--- /dev/null
+++ b/modin/pandas/test/interoperability/plotly/test_optional/test_px/test_px_hover.py
@@ -0,0 +1,65 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+import plotly.express as px
+import numpy as np
+import modin.pandas as pd
+import pytest
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_newdatain_hover_data():
+ hover_dicts = [
+ {"comment": ["a", "b", "c"]},
+ {"comment": (1.234, 45.3455, 5666.234)},
+ {"comment": [1.234, 45.3455, 5666.234]},
+ {"comment": np.array([1.234, 45.3455, 5666.234])},
+ {"comment": pd.Series([1.234, 45.3455, 5666.234])},
+ ]
+ for hover_dict in hover_dicts:
+ fig = px.scatter(x=[1, 2, 3], y=[3, 4, 5], hover_data=hover_dict)
+ assert (
+ fig.data[0].hovertemplate
+ == "x=%{x}
y=%{y}
comment=%{customdata[0]}"
+ )
+ fig = px.scatter(
+ x=[1, 2, 3], y=[3, 4, 5], hover_data={"comment": (True, ["a", "b", "c"])}
+ )
+ assert (
+ fig.data[0].hovertemplate
+ == "x=%{x}
y=%{y}
comment=%{customdata[0]}"
+ )
+ hover_dicts = [
+ {"comment": (":.1f", (1.234, 45.3455, 5666.234))},
+ {"comment": (":.1f", [1.234, 45.3455, 5666.234])},
+ {"comment": (":.1f", np.array([1.234, 45.3455, 5666.234]))},
+ {"comment": (":.1f", pd.Series([1.234, 45.3455, 5666.234]))},
+ ]
+ for hover_dict in hover_dicts:
+ fig = px.scatter(
+ x=[1, 2, 3],
+ y=[3, 4, 5],
+ hover_data=hover_dict,
+ )
+ assert (
+ fig.data[0].hovertemplate
+ == "x=%{x}
y=%{y}
comment=%{customdata[0]:.1f}"
+ )
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_date_in_hover():
+ df = pd.DataFrame({"date": ["2015-04-04 19:31:30+1:00"], "value": [3]})
+ df["date"] = pd.to_datetime(df["date"])
+ fig = px.scatter(df, x="value", y="value", hover_data=["date"])
+ assert str(fig.data[0].customdata[0][0]) == str(df["date"][0])
diff --git a/modin/pandas/test/interoperability/plotly/test_optional/test_px/test_px_input.py b/modin/pandas/test/interoperability/plotly/test_optional/test_px/test_px_input.py
new file mode 100644
index 00000000000..11d2be2c4a1
--- /dev/null
+++ b/modin/pandas/test/interoperability/plotly/test_optional/test_px/test_px_input.py
@@ -0,0 +1,298 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+import plotly.express as px
+import plotly.graph_objects as go
+import numpy as np
+import modin.pandas as pd
+import pytest
+from plotly.express._core import build_dataframe
+from pandas.testing import assert_frame_equal
+
+
+def test_pandas_series():
+ tips = px.data.tips()
+ before_tip = tips.total_bill - tips.tip
+ fig = px.bar(tips, x="day", y=before_tip)
+ assert fig.data[0].hovertemplate == "day=%{x}<br>y=%{y}"
+ fig = px.bar(tips, x="day", y=before_tip, labels={"y": "bill"})
+ assert fig.data[0].hovertemplate == "day=%{x}<br>bill=%{y}"
+ # lock down that we can pass df.col to facet_*
+ fig = px.bar(tips, x="day", y="tip", facet_row=tips.day, facet_col=tips.day)
+ assert fig.data[0].hovertemplate == "day=%{x}<br>tip=%{y}"
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_several_dataframes():
+ df = pd.DataFrame(dict(x=[0, 1], y=[1, 10], z=[0.1, 0.8]))
+ df2 = pd.DataFrame(dict(time=[23, 26], money=[100, 200]))
+ fig = px.scatter(df, x="z", y=df2.money, size="x")
+ assert (
+ fig.data[0].hovertemplate
+ == "z=%{x}
y=%{y}
x=%{marker.size}"
+ )
+ fig = px.scatter(df2, x=df.z, y=df2.money, size=df.z)
+ assert (
+ fig.data[0].hovertemplate
+ == "x=%{x}
money=%{y}
size=%{marker.size}"
+ )
+ # Name conflict
+ with pytest.raises(NameError) as err_msg:
+ fig = px.scatter(df, x="z", y=df2.money, size="y")
+ assert "A name conflict was encountered for argument 'y'" in str(err_msg.value)
+ with pytest.raises(NameError) as err_msg:
+ fig = px.scatter(df, x="z", y=df2.money, size=df.y)
+ assert "A name conflict was encountered for argument 'y'" in str(err_msg.value)
+
+ # No conflict when the dataframe is not given, fields are used
+ df = pd.DataFrame(dict(x=[0, 1], y=[3, 4]))
+ df2 = pd.DataFrame(dict(x=[3, 5], y=[23, 24]))
+ fig = px.scatter(x=df.y, y=df2.y)
+ assert np.all(fig.data[0].x == np.array([3, 4]))
+ assert np.all(fig.data[0].y == np.array([23, 24]))
+ assert fig.data[0].hovertemplate == "x=%{x}<br>y=%{y}"
+
+ df = pd.DataFrame(dict(x=[0, 1], y=[3, 4]))
+ df2 = pd.DataFrame(dict(x=[3, 5], y=[23, 24]))
+ df3 = pd.DataFrame(dict(y=[0.1, 0.2]))
+ fig = px.scatter(x=df.y, y=df2.y, size=df3.y)
+ assert np.all(fig.data[0].x == np.array([3, 4]))
+ assert np.all(fig.data[0].y == np.array([23, 24]))
+ assert (
+ fig.data[0].hovertemplate
+ == "x=%{x}
y=%{y}
size=%{marker.size}"
+ )
+
+ df = pd.DataFrame(dict(x=[0, 1], y=[3, 4]))
+ df2 = pd.DataFrame(dict(x=[3, 5], y=[23, 24]))
+ df3 = pd.DataFrame(dict(y=[0.1, 0.2]))
+ fig = px.scatter(x=df.y, y=df2.y, hover_data=[df3.y])
+ assert np.all(fig.data[0].x == np.array([3, 4]))
+ assert np.all(fig.data[0].y == np.array([23, 24]))
+ assert (
+ fig.data[0].hovertemplate
+ == "x=%{x}
y=%{y}
hover_data_0=%{customdata[0]}"
+ )
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_name_heuristics():
+ df = pd.DataFrame(dict(x=[0, 1], y=[3, 4], z=[0.1, 0.2]))
+ fig = px.scatter(df, x=df.y, y=df.x, size=df.y)
+ assert np.all(fig.data[0].x == np.array([3, 4]))
+ assert np.all(fig.data[0].y == np.array([0, 1]))
+ assert fig.data[0].hovertemplate == "y=%{marker.size}<br>x=%{y}"
+
+
+def test_arrayattrable_numpy():
+ tips = px.data.tips()
+ fig = px.scatter(
+ tips, x="total_bill", y="tip", hover_data=[np.random.random(tips.shape[0])]
+ )
+ assert (
+ fig.data[0]["hovertemplate"]
+ == "total_bill=%{x}
tip=%{y}
hover_data_0=%{customdata[0]}"
+ )
+ tips = px.data.tips()
+ fig = px.scatter(
+ tips,
+ x="total_bill",
+ y="tip",
+ hover_data=[np.random.random(tips.shape[0])],
+ labels={"hover_data_0": "suppl"},
+ )
+ assert (
+ fig.data[0]["hovertemplate"]
+ == "total_bill=%{x}
tip=%{y}
suppl=%{customdata[0]}"
+ )
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_wrong_dimensions_mixed_case():
+ with pytest.raises(ValueError) as err_msg:
+ df = pd.DataFrame(dict(time=[1, 2, 3], temperature=[20, 30, 25]))
+ px.scatter(df, x="time", y="temperature", color=[1, 3, 9, 5])
+ assert "All arguments should have the same length." in str(err_msg.value)
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_multiindex_raise_error():
+ index = pd.MultiIndex.from_product(
+ [[1, 2, 3], ["a", "b"]], names=["first", "second"]
+ )
+ df = pd.DataFrame(np.random.random((6, 3)), index=index, columns=["A", "B", "C"])
+ # This is ok
+ px.scatter(df, x="A", y="B")
+ with pytest.raises(TypeError) as err_msg:
+ px.scatter(df, x=df.index, y="B")
+ assert "pandas MultiIndex is not supported by plotly express" in str(err_msg.value)
+
+
+def test_build_df_from_lists():
+ # Just lists
+ args = dict(x=[1, 2, 3], y=[2, 3, 4], color=[1, 3, 9])
+ output = {key: key for key in args}
+ df = pd.DataFrame(args)
+ args["data_frame"] = None
+ out = build_dataframe(args, go.Scatter)
+ assert_frame_equal(
+ df.sort_index(axis=1)._to_pandas(), out["data_frame"].sort_index(axis=1)
+ )
+ out.pop("data_frame")
+ assert out == output
+
+ # Arrays
+ args = dict(x=np.array([1, 2, 3]), y=np.array([2, 3, 4]), color=[1, 3, 9])
+ output = {key: key for key in args}
+ df = pd.DataFrame(args)
+ args["data_frame"] = None
+ out = build_dataframe(args, go.Scatter)
+ assert_frame_equal(
+ df.sort_index(axis=1)._to_pandas(), out["data_frame"].sort_index(axis=1)
+ )
+ out.pop("data_frame")
+ assert out == output
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_timezones():
+ df = pd.DataFrame({"date": ["2015-04-04 19:31:30+1:00"], "value": [3]})
+ df["date"] = pd.to_datetime(df["date"])
+ args = dict(data_frame=df, x="date", y="value")
+ out = build_dataframe(args, go.Scatter)
+ assert str(out["data_frame"]["date"][0]) == str(df["date"][0])
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_non_matching_index():
+ df = pd.DataFrame(dict(y=[1, 2, 3]), index=["a", "b", "c"])
+
+ expected = pd.DataFrame(dict(index=["a", "b", "c"], y=[1, 2, 3]))
+
+ args = dict(data_frame=df, x=df.index, y="y")
+ out = build_dataframe(args, go.Scatter)
+ assert_frame_equal(expected._to_pandas(), out["data_frame"])
+
+ expected = pd.DataFrame(dict(x=["a", "b", "c"], y=[1, 2, 3]))
+
+ args = dict(data_frame=None, x=df.index, y=df.y)
+
+ # args = dict(data_frame=None, x=df.index, y=[1, 2, 3])
+ out = build_dataframe(args, go.Scatter)
+ assert_frame_equal(expected._to_pandas(), out["data_frame"])
+
+ # args = dict(data_frame=None, x=["a", "b", "c"], y=df.y)
+ # out = build_dataframe(args, go.Scatter)
+ # assert_frame_equal(expected._to_pandas(), out["data_frame"])
+
+
+def test_int_col_names():
+ # DataFrame with int column names
+ lengths = pd.DataFrame(np.random.random(100))
+ fig = px.histogram(lengths, x=0)
+ assert np.all(np.array(lengths).flatten() == fig.data[0].x)
+ # Numpy array
+ ar = np.arange(100).reshape((10, 10))
+ fig = px.scatter(ar, x=2, y=8)
+ assert np.all(fig.data[0].x == ar[:, 2])
+
+
+@pytest.mark.skip(reason="Failing test")
+@pytest.mark.parametrize(
+ "fn,mode", [(px.violin, "violinmode"), (px.box, "boxmode"), (px.strip, "boxmode")]
+)
+@pytest.mark.parametrize(
+ "x,y,color,result",
+ [
+ ("categorical1", "numerical", None, "group"),
+ ("categorical1", "numerical", "categorical2", "group"),
+ ("categorical1", "numerical", "categorical1", "overlay"),
+ ("numerical", "categorical1", None, "group"),
+ ("numerical", "categorical1", "categorical2", "group"),
+ ("numerical", "categorical1", "categorical1", "overlay"),
+ ],
+)
+def test_auto_boxlike_overlay(fn, mode, x, y, color, result):
+ df = pd.DataFrame(
+ dict(
+ categorical1=["a", "a", "b", "b"],
+ categorical2=["a", "a", "b", "b"],
+ numerical=[1, 2, 3, 4],
+ )
+ )
+ assert fn(df, x=x, y=y, color=color).layout[mode] == result
+
+
+@pytest.mark.skip(reason="Failing test")
+@pytest.mark.parametrize("fn", [px.scatter, px.line, px.area, px.bar])
+def test_x_or_y(fn):
+ categorical = ["a", "a", "b", "b"]
+ numerical = [1, 2, 3, 4]
+ constant = [1, 1, 1, 1]
+ range_4 = [0, 1, 2, 3]
+ index = [11, 12, 13, 14]
+ numerical_df = pd.DataFrame(dict(col=numerical), index=index)
+ categorical_df = pd.DataFrame(dict(col=categorical), index=index)
+
+ fig = fn(x=numerical)
+ assert list(fig.data[0].x) == numerical
+ assert list(fig.data[0].y) == range_4
+ assert fig.data[0].orientation == "h"
+ fig = fn(y=numerical)
+ assert list(fig.data[0].x) == range_4
+ assert list(fig.data[0].y) == numerical
+ assert fig.data[0].orientation == "v"
+ fig = fn(numerical_df, x="col")
+ assert list(fig.data[0].x) == numerical
+ assert list(fig.data[0].y) == index
+ assert fig.data[0].orientation == "h"
+ fig = fn(numerical_df, y="col")
+ assert list(fig.data[0].x) == index
+ assert list(fig.data[0].y) == numerical
+ assert fig.data[0].orientation == "v"
+
+ if fn != px.bar:
+ fig = fn(x=categorical)
+ assert list(fig.data[0].x) == categorical
+ assert list(fig.data[0].y) == range_4
+ assert fig.data[0].orientation == "h"
+ fig = fn(y=categorical)
+ assert list(fig.data[0].x) == range_4
+ assert list(fig.data[0].y) == categorical
+ assert fig.data[0].orientation == "v"
+ fig = fn(categorical_df, x="col")
+ assert list(fig.data[0].x) == categorical
+ assert list(fig.data[0].y) == index
+ assert fig.data[0].orientation == "h"
+ fig = fn(categorical_df, y="col")
+ assert list(fig.data[0].x) == index
+ assert list(fig.data[0].y) == categorical
+ assert fig.data[0].orientation == "v"
+
+ else:
+ fig = fn(x=categorical)
+ assert list(fig.data[0].x) == categorical
+ assert list(fig.data[0].y) == constant
+ assert fig.data[0].orientation == "v"
+ fig = fn(y=categorical)
+ assert list(fig.data[0].x) == constant
+ assert list(fig.data[0].y) == categorical
+ assert fig.data[0].orientation == "h"
+ fig = fn(categorical_df, x="col")
+ assert list(fig.data[0].x) == categorical
+ assert list(fig.data[0].y) == constant
+ assert fig.data[0].orientation == "v"
+ fig = fn(categorical_df, y="col")
+ assert list(fig.data[0].x) == constant
+ assert list(fig.data[0].y) == categorical
+ assert fig.data[0].orientation == "h"
diff --git a/modin/pandas/test/interoperability/plotly/test_optional/test_px/test_px_wide.py b/modin/pandas/test/interoperability/plotly/test_optional/test_px/test_px_wide.py
new file mode 100644
index 00000000000..ca29f7a052d
--- /dev/null
+++ b/modin/pandas/test/interoperability/plotly/test_optional/test_px/test_px_wide.py
@@ -0,0 +1,812 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+import plotly.express as px
+import plotly.graph_objects as go
+import modin.pandas as pd
+from plotly.express._core import build_dataframe, _is_col_list
+from pandas.testing import assert_frame_equal
+import pytest
+
+
+def test_is_col_list():
+ df_input = pd.DataFrame(dict(a=[1, 2], b=[1, 2]))
+ assert _is_col_list(df_input, ["a"])
+ assert _is_col_list(df_input, ["a", "b"])
+ assert _is_col_list(df_input, [[3, 4]])
+ assert _is_col_list(df_input, [[3, 4], [3, 4]])
+ assert not _is_col_list(df_input, pytest)
+ assert not _is_col_list(df_input, False)
+ assert not _is_col_list(df_input, ["a", 1])
+ assert not _is_col_list(df_input, "a")
+ assert not _is_col_list(df_input, 1)
+ assert not _is_col_list(df_input, ["a", "b", "c"])
+ assert not _is_col_list(df_input, [1, 2])
+ df_input = pd.DataFrame([[1, 2], [1, 2]])
+ assert _is_col_list(df_input, [0])
+ assert _is_col_list(df_input, [0, 1])
+ assert _is_col_list(df_input, [[3, 4]])
+ assert _is_col_list(df_input, [[3, 4], [3, 4]])
+ assert not _is_col_list(df_input, pytest)
+ assert not _is_col_list(df_input, False)
+ assert not _is_col_list(df_input, ["a", 1])
+ assert not _is_col_list(df_input, "a")
+ assert not _is_col_list(df_input, 1)
+ assert not _is_col_list(df_input, [0, 1, 2])
+ assert not _is_col_list(df_input, ["a", "b"])
+ df_input = None
+ assert _is_col_list(df_input, [[3, 4]])
+ assert _is_col_list(df_input, [[3, 4], [3, 4]])
+ assert not _is_col_list(df_input, [0])
+ assert not _is_col_list(df_input, [0, 1])
+ assert not _is_col_list(df_input, pytest)
+ assert not _is_col_list(df_input, False)
+ assert not _is_col_list(df_input, ["a", 1])
+ assert not _is_col_list(df_input, "a")
+ assert not _is_col_list(df_input, 1)
+ assert not _is_col_list(df_input, [0, 1, 2])
+ assert not _is_col_list(df_input, ["a", "b"])
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_wide_mode_labels_external():
+ # here we prove that the _uglylabels_ can be renamed using the usual labels kwarg
+ df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6], c=[7, 8, 9]), index=[11, 12, 13])
+ fig = px.bar(df)
+ assert fig.layout.xaxis.title.text == "index"
+ assert fig.layout.yaxis.title.text == "value"
+ assert fig.layout.legend.title.text == "variable"
+ labels = dict(index="my index", value="my value", variable="my column")
+ fig = px.bar(df, labels=labels)
+ assert fig.layout.xaxis.title.text == "my index"
+ assert fig.layout.yaxis.title.text == "my value"
+ assert fig.layout.legend.title.text == "my column"
+ df.index.name = "my index"
+ df.columns.name = "my column"
+ fig = px.bar(df)
+ assert fig.layout.xaxis.title.text == "my index"
+ assert fig.layout.yaxis.title.text == "value"
+ assert fig.layout.legend.title.text == "my column"
+
+
+@pytest.mark.skip(reason="Failing test")
+# here we do basic exhaustive testing of the various graph_object permutations
+# via build_dataframe directly, which leads to more compact test code:
+# we pass in args (which includes df) and look at how build_dataframe mutates
+# both args and the df, and assume that since the rest of the downstream PX
+# machinery has no wide-mode-specific code, and the tests above pass, that this is
+# enough to prove things work
+@pytest.mark.parametrize(
+ "trace_type,x,y,color",
+ [
+ (go.Scatter, "index", "value", "variable"),
+ (go.Histogram2dContour, "index", "value", "variable"),
+ (go.Histogram2d, "index", "value", None),
+ (go.Bar, "index", "value", "variable"),
+ (go.Funnel, "index", "value", "variable"),
+ (go.Box, "variable", "value", None),
+ (go.Violin, "variable", "value", None),
+ (go.Histogram, "value", None, "variable"),
+ ],
+)
+@pytest.mark.parametrize("orientation", [None, "v", "h"])
+def test_wide_mode_internal(trace_type, x, y, color, orientation):
+ df_in = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]), index=[11, 12, 13])
+ args_in = dict(data_frame=df_in, color=None, orientation=orientation)
+ args_out = build_dataframe(args_in, trace_type)
+ df_out = args_out.pop("data_frame")
+ expected = dict(
+ variable=["a", "a", "a", "b", "b", "b"],
+ value=[1, 2, 3, 4, 5, 6],
+ )
+ if x == "index":
+ expected["index"] = [11, 12, 13, 11, 12, 13]
+ assert_frame_equal(
+ df_out.sort_index(axis=1),
+ pd.DataFrame(expected).sort_index(axis=1)._to_pandas(),
+ )
+ if trace_type in [go.Histogram2dContour, go.Histogram2d]:
+ if orientation is None or orientation == "v":
+ assert args_out == dict(x=x, y=y, color=color)
+ else:
+ assert args_out == dict(x=y, y=x, color=color)
+ else:
+ if (orientation is None and trace_type != go.Funnel) or orientation == "v":
+ assert args_out == dict(x=x, y=y, color=color, orientation="v")
+ else:
+ assert args_out == dict(x=y, y=x, color=color, orientation="h")
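+# Illustration only (assumption, spelling out what the expected dicts below encode):
+# in wide mode a frame such as pd.DataFrame(dict(a=[1, 2], b=[3, 4])) is melted so
+# that column names land in a "variable" column, cell values in a "value" column,
+# and, for trace types that use it, the original row labels are repeated in an
+# "index" column.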
+
+
+cases = []
+for transpose in [True, False]:
+ for tt in [go.Scatter, go.Bar, go.Funnel, go.Histogram2dContour, go.Histogram2d]:
+ color = None if tt == go.Histogram2d else "variable"
+ df_in = dict(a=[1, 2], b=[3, 4])
+ args = dict(x=None, y=["a", "b"], color=None, orientation=None)
+ df_exp = dict(
+ variable=["a", "a", "b", "b"],
+ value=[1, 2, 3, 4],
+ index=[0, 1, 0, 1],
+ )
+ cases.append((tt, df_in, args, "index", "value", color, df_exp, transpose))
+
+ df_in = dict(a=[1, 2], b=[3, 4], c=[5, 6])
+ args = dict(x="c", y=["a", "b"], color=None, orientation=None)
+ df_exp = dict(
+ variable=["a", "a", "b", "b"],
+ value=[1, 2, 3, 4],
+ c=[5, 6, 5, 6],
+ )
+ cases.append((tt, df_in, args, "c", "value", color, df_exp, transpose))
+
+ args = dict(x=None, y=[[1, 2], [3, 4]], color=None, orientation=None)
+ df_exp = dict(
+ variable=[
+ "wide_variable_0",
+ "wide_variable_0",
+ "wide_variable_1",
+ "wide_variable_1",
+ ],
+ value=[1, 2, 3, 4],
+ index=[0, 1, 0, 1],
+ )
+ cases.append((tt, None, args, "index", "value", color, df_exp, transpose))
+
+ for tt in [go.Bar]: # bar categorical exception
+ df_in = dict(a=["q", "r"], b=["s", "t"])
+ args = dict(x=None, y=["a", "b"], color=None, orientation=None)
+ df_exp = dict(
+ variable=["a", "a", "b", "b"],
+ value=["q", "r", "s", "t"],
+ index=[0, 1, 0, 1],
+ count=[1, 1, 1, 1],
+ )
+ cases.append((tt, df_in, args, "value", "count", "variable", df_exp, transpose))
+
+ for tt in [go.Violin, go.Box]:
+ df_in = dict(a=[1, 2], b=[3, 4])
+ args = dict(x=None, y=["a", "b"], color=None, orientation=None)
+ df_exp = dict(
+ variable=["a", "a", "b", "b"],
+ value=[1, 2, 3, 4],
+ )
+ cases.append((tt, df_in, args, "variable", "value", None, df_exp, transpose))
+
+ df_in = dict(a=[1, 2], b=[3, 4], c=[5, 6])
+ args = dict(x="c", y=["a", "b"], color=None, orientation=None)
+ df_exp = dict(
+ variable=["a", "a", "b", "b"],
+ value=[1, 2, 3, 4],
+ c=[5, 6, 5, 6],
+ )
+ cases.append((tt, df_in, args, "c", "value", None, df_exp, transpose))
+
+ args = dict(x=None, y=[[1, 2], [3, 4]], color=None, orientation=None)
+ df_exp = dict(
+ variable=[
+ "wide_variable_0",
+ "wide_variable_0",
+ "wide_variable_1",
+ "wide_variable_1",
+ ],
+ value=[1, 2, 3, 4],
+ )
+ cases.append((tt, None, args, "variable", "value", None, df_exp, transpose))
+
+ for tt in [go.Histogram]:
+ df_in = dict(a=[1, 2], b=[3, 4])
+ args = dict(x=None, y=["a", "b"], color=None, orientation=None)
+ df_exp = dict(
+ variable=["a", "a", "b", "b"],
+ value=[1, 2, 3, 4],
+ )
+ cases.append((tt, df_in, args, None, "value", "variable", df_exp, transpose))
+
+ df_in = dict(a=[1, 2], b=[3, 4], c=[5, 6])
+ args = dict(x="c", y=["a", "b"], color=None, orientation=None)
+ df_exp = dict(
+ variable=["a", "a", "b", "b"],
+ value=[1, 2, 3, 4],
+ c=[5, 6, 5, 6],
+ )
+ cases.append((tt, df_in, args, "c", "value", "variable", df_exp, transpose))
+
+ args = dict(x=None, y=[[1, 2], [3, 4]], color=None, orientation=None)
+ df_exp = dict(
+ variable=[
+ "wide_variable_0",
+ "wide_variable_0",
+ "wide_variable_1",
+ "wide_variable_1",
+ ],
+ value=[1, 2, 3, 4],
+ )
+ cases.append((tt, None, args, None, "value", "variable", df_exp, transpose))
+
+
+@pytest.mark.parametrize("tt,df_in,args_in,x,y,color,df_out_exp,transpose", cases)
+def test_wide_x_or_y(tt, df_in, args_in, x, y, color, df_out_exp, transpose):
+ if transpose:
+ args_in["y"], args_in["x"] = args_in["x"], args_in["y"]
+ args_in["data_frame"] = df_in
+ args_out = build_dataframe(args_in, tt)
+ df_out = args_out.pop("data_frame").sort_index(axis=1)
+ assert_frame_equal(df_out, pd.DataFrame(df_out_exp).sort_index(axis=1)._to_pandas())
+ if transpose:
+ args_exp = dict(x=y, y=x, color=color)
+ else:
+ args_exp = dict(x=x, y=y, color=color)
+ if tt not in [go.Histogram2dContour, go.Histogram2d]:
+ orientation_exp = args_in["orientation"]
+ if (args_in["x"] is None) != (args_in["y"] is None) and tt != go.Histogram:
+ orientation_exp = "h" if transpose else "v"
+ args_exp["orientation"] = orientation_exp
+ assert args_out == args_exp
+
+
+@pytest.mark.skip(reason="Failing test")
+@pytest.mark.parametrize("orientation", [None, "v", "h"])
+def test_wide_mode_internal_bar_exception(orientation):
+ df_in = pd.DataFrame(dict(a=["q", "r", "s"], b=["t", "u", "v"]), index=[11, 12, 13])
+ args_in = dict(data_frame=df_in, color=None, orientation=orientation)
+ args_out = build_dataframe(args_in, go.Bar)
+ df_out = args_out.pop("data_frame")
+ assert_frame_equal(
+ df_out.sort_index(axis=1),
+ pd.DataFrame(
+ dict(
+ index=[11, 12, 13, 11, 12, 13],
+ variable=["a", "a", "a", "b", "b", "b"],
+ value=["q", "r", "s", "t", "u", "v"],
+ count=[1, 1, 1, 1, 1, 1],
+ )
+ )
+ .sort_index(axis=1)
+ ._to_pandas(),
+ )
+ if orientation is None or orientation == "v":
+ assert args_out == dict(x="value", y="count", color="variable", orientation="v")
+ else:
+ assert args_out == dict(x="count", y="value", color="variable", orientation="h")
+
+
+# given all of the above tests, and given that the melt() code is not sensitive
+# to the trace type, we can do all sorts of special-case testing just by focusing
+# on build_dataframe(args, go.Scatter) for various values of args, and looking at
+# how args and df get mutated
+special_cases = []
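+# each entry is a (df_in, args_in, args_expect, df_expect) tuple consumed by
+# test_wide_mode_internal_special_cases below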
+
+
+def append_special_case(df_in, args_in, args_expect, df_expect):
+ special_cases.append((df_in, args_in, args_expect, df_expect))
+
+
+# input is single bare array: column comes out as string "0"
+append_special_case(
+ df_in=[1, 2, 3],
+ args_in=dict(x=None, y=None, color=None),
+ args_expect=dict(x="index", y="value", color="variable", orientation="v"),
+ df_expect=pd.DataFrame(
+ dict(index=[0, 1, 2], value=[1, 2, 3], variable=["0", "0", "0"])
+ ),
+)
+
+# input is single bare Series: column comes out as string "0"
+append_special_case(
+ df_in=pd.Series([1, 2, 3]),
+ args_in=dict(x=None, y=None, color=None),
+ args_expect=dict(x="index", y="value", color="variable", orientation="v"),
+ df_expect=pd.DataFrame(
+ dict(index=[0, 1, 2], value=[1, 2, 3], variable=["0", "0", "0"])
+ ),
+)
+
+# input is a Series from a DF: we pick up the name and index values automatically
+df = pd.DataFrame(dict(my_col=[1, 2, 3]), index=["a", "b", "c"])
+append_special_case(
+ df_in=df["my_col"],
+ args_in=dict(x=None, y=None, color=None),
+ args_expect=dict(x="index", y="value", color="variable", orientation="v"),
+ df_expect=pd.DataFrame(
+ dict(
+ index=["a", "b", "c"],
+ value=[1, 2, 3],
+ variable=["my_col", "my_col", "my_col"],
+ )
+ ),
+)
+
+# input is an index from a DF: treated like a Series basically
+df = pd.DataFrame(dict(my_col=[1, 2, 3]), index=["a", "b", "c"])
+df.index.name = "my_index"
+append_special_case(
+ df_in=df.index,
+ args_in=dict(x=None, y=None, color=None),
+ args_expect=dict(x="index", y="value", color="variable", orientation="v"),
+ df_expect=pd.DataFrame(
+ dict(
+ index=[0, 1, 2],
+ value=["a", "b", "c"],
+ variable=["my_index", "my_index", "my_index"],
+ )
+ ),
+)
+
+# input is a data frame with named row and col indices: we grab those
+df = pd.DataFrame(dict(my_col=[1, 2, 3]), index=["a", "b", "c"])
+df.index.name = "my_index"
+df.columns.name = "my_col_name"
+append_special_case(
+ df_in=df,
+ args_in=dict(x=None, y=None, color=None),
+ args_expect=dict(x="my_index", y="value", color="my_col_name", orientation="v"),
+ df_expect=pd.DataFrame(
+ dict(
+ my_index=["a", "b", "c"],
+ value=[1, 2, 3],
+ my_col_name=["my_col", "my_col", "my_col"],
+ )
+ ),
+)
+
+# input is array of arrays: treated as rows, columns come out as string "0", "1"
+append_special_case(
+ df_in=[[1, 2], [4, 5]],
+ args_in=dict(x=None, y=None, color=None),
+ args_expect=dict(x="index", y="value", color="variable", orientation="v"),
+ df_expect=pd.DataFrame(
+ dict(
+ index=[0, 1, 0, 1],
+ value=[1, 4, 2, 5],
+ variable=["0", "0", "1", "1"],
+ )
+ ),
+)
+
+# partial-melting by assigning symbol: we pick up that column and don't melt it
+append_special_case(
+ df_in=pd.DataFrame(dict(a=[1, 2], b=[3, 4], symbol_col=["q", "r"])),
+ args_in=dict(x=None, y=None, color=None, symbol="symbol_col"),
+ args_expect=dict(
+ x="index",
+ y="value",
+ color="variable",
+ symbol="symbol_col",
+ orientation="v",
+ ),
+ df_expect=pd.DataFrame(
+ dict(
+ index=[0, 1, 0, 1],
+ value=[1, 2, 3, 4],
+ variable=["a", "a", "b", "b"],
+ symbol_col=["q", "r", "q", "r"],
+ )
+ ),
+)
+
+# partial-melting by assigning the same column twice: we pick it up once
+append_special_case(
+ df_in=pd.DataFrame(dict(a=[1, 2], b=[3, 4], symbol_col=["q", "r"])),
+ args_in=dict(
+ x=None,
+ y=None,
+ color=None,
+ symbol="symbol_col",
+ custom_data=["symbol_col"],
+ ),
+ args_expect=dict(
+ x="index",
+ y="value",
+ color="variable",
+ symbol="symbol_col",
+ custom_data=["symbol_col"],
+ orientation="v",
+ ),
+ df_expect=pd.DataFrame(
+ dict(
+ index=[0, 1, 0, 1],
+ value=[1, 2, 3, 4],
+ variable=["a", "a", "b", "b"],
+ symbol_col=["q", "r", "q", "r"],
+ )
+ ),
+)
+
+# partial-melting by assigning more than one column: we pick them both up
+append_special_case(
+ df_in=pd.DataFrame(
+ dict(a=[1, 2], b=[3, 4], symbol_col=["q", "r"], data_col=["i", "j"])
+ ),
+ args_in=dict(
+ x=None,
+ y=None,
+ color=None,
+ symbol="symbol_col",
+ custom_data=["data_col"],
+ ),
+ args_expect=dict(
+ x="index",
+ y="value",
+ color="variable",
+ symbol="symbol_col",
+ custom_data=["data_col"],
+ orientation="v",
+ ),
+ df_expect=pd.DataFrame(
+ dict(
+ index=[0, 1, 0, 1],
+ value=[1, 2, 3, 4],
+ variable=["a", "a", "b", "b"],
+ symbol_col=["q", "r", "q", "r"],
+ data_col=["i", "j", "i", "j"],
+ )
+ ),
+)
+
+# partial-melting by assigning symbol to a bare array: we pick it up with the attr name
+append_special_case(
+ df_in=pd.DataFrame(dict(a=[1, 2], b=[3, 4])),
+ args_in=dict(x=None, y=None, color=None, symbol=["q", "r"]),
+ args_expect=dict(
+ x="index", y="value", color="variable", symbol="symbol", orientation="v"
+ ),
+ df_expect=pd.DataFrame(
+ dict(
+ index=[0, 1, 0, 1],
+ value=[1, 2, 3, 4],
+ variable=["a", "a", "b", "b"],
+ symbol=["q", "r", "q", "r"],
+ )
+ ),
+)
+
+# assigning color to variable explicitly: just works
+append_special_case(
+ df_in=pd.DataFrame(dict(a=[1, 2], b=[3, 4])),
+ args_in=dict(x=None, y=None, color="variable"),
+ args_expect=dict(x="index", y="value", color="variable", orientation="v"),
+ df_expect=pd.DataFrame(
+ dict(index=[0, 1, 0, 1], value=[1, 2, 3, 4], variable=["a", "a", "b", "b"])
+ ),
+)
+
+# assigning color to a different column: variable drops out of args
+append_special_case(
+ df_in=pd.DataFrame(dict(a=[1, 2], b=[3, 4], color_col=["q", "r"])),
+ args_in=dict(x=None, y=None, color="color_col"),
+ args_expect=dict(x="index", y="value", color="color_col", orientation="v"),
+ df_expect=pd.DataFrame(
+ dict(
+ index=[0, 1, 0, 1],
+ value=[1, 2, 3, 4],
+ variable=["a", "a", "b", "b"],
+ color_col=["q", "r", "q", "r"],
+ )
+ ),
+)
+
+# assigning variable to something else: just works
+append_special_case(
+ df_in=pd.DataFrame(dict(a=[1, 2], b=[3, 4])),
+ args_in=dict(x=None, y=None, color=None, symbol="variable"),
+ args_expect=dict(
+ x="index", y="value", color="variable", symbol="variable", orientation="v"
+ ),
+ df_expect=pd.DataFrame(
+ dict(
+ index=[0, 1, 0, 1],
+ value=[1, 2, 3, 4],
+ variable=["a", "a", "b", "b"],
+ )
+ ),
+)
+
+# swapping symbol and color: just works
+append_special_case(
+ df_in=pd.DataFrame(dict(a=[1, 2], b=[3, 4], color_col=["q", "r"])),
+ args_in=dict(x=None, y=None, color="color_col", symbol="variable"),
+ args_expect=dict(
+ x="index",
+ y="value",
+ color="color_col",
+ symbol="variable",
+ orientation="v",
+ ),
+ df_expect=pd.DataFrame(
+ dict(
+ index=[0, 1, 0, 1],
+ value=[1, 2, 3, 4],
+ variable=["a", "a", "b", "b"],
+ color_col=["q", "r", "q", "r"],
+ )
+ ),
+)
+
+# a DF with a named column index: have to use that instead of variable
+df = pd.DataFrame(dict(a=[1, 2], b=[3, 4]))
+df.columns.name = "my_col_name"
+append_special_case(
+ df_in=df,
+ args_in=dict(x=None, y=None, color=None, facet_row="my_col_name"),
+ args_expect=dict(
+ x="index",
+ y="value",
+ color="my_col_name",
+ facet_row="my_col_name",
+ orientation="v",
+ ),
+ df_expect=pd.DataFrame(
+ dict(
+ index=[0, 1, 0, 1],
+ value=[1, 2, 3, 4],
+ my_col_name=["a", "a", "b", "b"],
+ )
+ ),
+)
+
+# passing the DF index into some other attr: works
+df = pd.DataFrame(dict(a=[1, 2], b=[3, 4]))
+df.columns.name = "my_col_name"
+df.index.name = "my_index_name"
+append_special_case(
+ df_in=df,
+ args_in=dict(x=None, y=None, color=None, hover_name=df.index),
+ args_expect=dict(
+ x="my_index_name",
+ y="value",
+ color="my_col_name",
+ hover_name="my_index_name",
+ orientation="v",
+ ),
+ df_expect=pd.DataFrame(
+ dict(
+ my_index_name=[0, 1, 0, 1],
+ value=[1, 2, 3, 4],
+ my_col_name=["a", "a", "b", "b"],
+ )
+ ),
+)
+
+# assigning value to something: works
+df = pd.DataFrame(dict(a=[1, 2], b=[3, 4]))
+df.columns.name = "my_col_name"
+df.index.name = "my_index_name"
+append_special_case(
+ df_in=df,
+ args_in=dict(x=None, y=None, color=None, hover_name="value"),
+ args_expect=dict(
+ x="my_index_name",
+ y="value",
+ color="my_col_name",
+ hover_name="value",
+ orientation="v",
+ ),
+ df_expect=pd.DataFrame(
+ dict(
+ my_index_name=[0, 1, 0, 1],
+ value=[1, 2, 3, 4],
+ my_col_name=["a", "a", "b", "b"],
+ )
+ ),
+)
+
+# assigning a px.Constant: works
+df = pd.DataFrame(dict(a=[1, 2], b=[3, 4]))
+df.columns.name = "my_col_name"
+df.index.name = "my_index_name"
+append_special_case(
+ df_in=df,
+ args_in=dict(x=None, y=None, color=None, symbol=px.Constant(1)),
+ args_expect=dict(
+ x="my_index_name",
+ y="value",
+ color="my_col_name",
+ symbol="symbol",
+ orientation="v",
+ ),
+ df_expect=pd.DataFrame(
+ dict(
+ my_index_name=[0, 1, 0, 1],
+ value=[1, 2, 3, 4],
+ my_col_name=["a", "a", "b", "b"],
+ symbol=[1, 1, 1, 1],
+ )
+ ),
+)
+
+# df has columns named after every special string
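+# (build_dataframe is expected to de-collide the reserved names by prefixing an
+# underscore: index/value/variable become _index/_value/_variable)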
+df = pd.DataFrame(dict(index=[1, 2], value=[3, 4], variable=[5, 6]), index=[7, 8])
+append_special_case(
+ df_in=df,
+ args_in=dict(x=None, y=None, color=None),
+ args_expect=dict(
+ x="_index",
+ y="_value",
+ color="_variable",
+ orientation="v",
+ ),
+ df_expect=pd.DataFrame(
+ dict(
+ _index=[7, 8, 7, 8, 7, 8],
+ _value=[1, 2, 3, 4, 5, 6],
+ _variable=["index", "index", "value", "value", "variable", "variable"],
+ )
+ ),
+)
+
+# df has columns with name collisions with indexes
+df = pd.DataFrame(dict(a=[1, 2], b=[3, 4]), index=[7, 8])
+df.index.name = "a"
+df.columns.name = "b"
+append_special_case(
+ df_in=df,
+ args_in=dict(x=None, y=None, color=None),
+ args_expect=dict(
+ x="index",
+ y="value",
+ color="variable",
+ orientation="v",
+ ),
+ df_expect=pd.DataFrame(
+ dict(
+ index=[7, 8, 7, 8],
+ value=[1, 2, 3, 4],
+ variable=["a", "a", "b", "b"],
+ )
+ ),
+)
+
+# everything is called value, OMG
+df = pd.DataFrame(dict(b=[1, 2], value=[3, 4]), index=[7, 8])
+df.index.name = "value"
+df.columns.name = "value"
+append_special_case(
+ df_in=df,
+ args_in=dict(x=None, y=None, color=None),
+ args_expect=dict(
+ x="index",
+ y="_value",
+ color="variable",
+ orientation="v",
+ ),
+ df_expect=pd.DataFrame(
+ dict(
+ index=[7, 8, 7, 8],
+ _value=[1, 2, 3, 4],
+ variable=["b", "b", "value", "value"],
+ )
+ ),
+)
+
+# y = columns
+df = pd.DataFrame(dict(a=[1, 2], b=[3, 4]), index=[7, 8])
+df.index.name = "c"
+df.columns.name = "d"
+append_special_case(
+ df_in=df,
+ args_in=dict(x=df.index, y=df.columns, color=None),
+ args_expect=dict(x="c", y="value", color="d"),
+ df_expect=pd.DataFrame(
+ dict(c=[7, 8, 7, 8], d=["a", "a", "b", "b"], value=[1, 2, 3, 4])
+ ),
+)
+
+# y = columns subset
+df = pd.DataFrame(dict(a=[1, 2], b=[3, 4]), index=[7, 8])
+df.index.name = "c"
+df.columns.name = "d"
+append_special_case(
+ df_in=df,
+ args_in=dict(x=df.index, y=df.columns[:1], color=None),
+ args_expect=dict(x="c", y="value", color="variable"),
+ df_expect=pd.DataFrame(dict(c=[7, 8], variable=["a", "a"], value=[1, 2])),
+)
+
+# list-like hover_data
+df = pd.DataFrame(dict(a=[1, 2], b=[3, 4]), index=[7, 8])
+df.index.name = "c"
+df.columns.name = "d"
+append_special_case(
+ df_in=df,
+ args_in=dict(x=None, y=None, color=None, hover_data=dict(new=[5, 6])),
+ args_expect=dict(
+ x="c",
+ y="value",
+ color="d",
+ orientation="v",
+ hover_data=dict(new=(True, [5, 6])),
+ ),
+ df_expect=pd.DataFrame(
+ dict(
+ c=[7, 8, 7, 8], d=["a", "a", "b", "b"], new=[5, 6, 5, 6], value=[1, 2, 3, 4]
+ )
+ ),
+)
+
+# NO_COLOR
+df = pd.DataFrame(dict(a=[1, 2], b=[3, 4]))
+append_special_case(
+ df_in=df,
+ args_in=dict(x=None, y=None, color=px.NO_COLOR),
+ args_expect=dict(
+ x="index",
+ y="value",
+ color=None,
+ orientation="v",
+ ),
+ df_expect=pd.DataFrame(
+ dict(variable=["a", "a", "b", "b"], index=[0, 1, 0, 1], value=[1, 2, 3, 4])
+ ),
+)
+
+
+@pytest.mark.skip(reason="Failing test")
+@pytest.mark.parametrize("df_in, args_in, args_expect, df_expect", special_cases)
+def test_wide_mode_internal_special_cases(df_in, args_in, args_expect, df_expect):
+ args_in["data_frame"] = df_in
+ args_out = build_dataframe(args_in, go.Scatter)
+ df_out = args_out.pop("data_frame")
+ assert args_out == args_expect
+ assert_frame_equal(
+ df_out.sort_index(axis=1),
+ df_expect.sort_index(axis=1)._to_pandas(),
+ )
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_multi_index():
+ df = pd.DataFrame([[1, 2, 3, 4], [3, 4, 5, 6], [1, 2, 3, 4], [3, 4, 5, 6]])
+ df.index = [["a", "a", "b", "b"], ["c", "d", "c", "d"]]
+ with pytest.raises(TypeError) as err_msg:
+ px.scatter(df)
+ assert "pandas MultiIndex is not supported by plotly express" in str(err_msg.value)
+
+ df = pd.DataFrame([[1, 2, 3, 4], [3, 4, 5, 6], [1, 2, 3, 4], [3, 4, 5, 6]])
+ df.columns = [["e", "e", "f", "f"], ["g", "h", "g", "h"]]
+ with pytest.raises(TypeError) as err_msg:
+ px.scatter(df)
+ assert "pandas MultiIndex is not supported by plotly express" in str(err_msg.value)
+
+
+@pytest.mark.parametrize("df", [px.data.stocks(), dict(a=[1, 2], b=["1", "2"])])
+def test_mixed_input_error(df):
+ with pytest.raises(ValueError) as err_msg:
+ px.line(df)
+ assert (
+ "Plotly Express cannot process wide-form data with columns of different type"
+ in str(err_msg.value)
+ )
+
+
+def test_mixed_number_input():
+ df = pd.DataFrame(dict(a=[1, 2], b=[1.1, 2.1]))
+ fig = px.line(df)
+ assert len(fig.data) == 2
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_line_group():
+ df = pd.DataFrame(
+ data={
+ "who": ["a", "a", "b", "b"],
+ "x": [0, 1, 0, 1],
+ "score": [1.0, 2, 3, 4],
+ "miss": [3.2, 2.5, 1.3, 1.5],
+ }
+ )
+ fig = px.line(df, x="x", y=["miss", "score"])
+ assert len(fig.data) == 2
+ fig = px.line(df, x="x", y=["miss", "score"], color="who")
+ assert len(fig.data) == 4
+ fig = px.scatter(df, x="x", y=["miss", "score"], color="who")
+ assert len(fig.data) == 2
diff --git a/modin/pandas/test/interoperability/plotly/test_optional/test_px/test_trendline.py b/modin/pandas/test/interoperability/plotly/test_optional/test_px/test_trendline.py
new file mode 100644
index 00000000000..c400990815a
--- /dev/null
+++ b/modin/pandas/test/interoperability/plotly/test_optional/test_px/test_trendline.py
@@ -0,0 +1,254 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+import plotly.express as px
+import numpy as np
+import modin.pandas as pd
+import pytest
+from datetime import datetime
+
+
+@pytest.mark.parametrize(
+ "mode,options",
+ [
+ ("ols", None),
+ ("lowess", None),
+ ("lowess", dict(frac=0.3)),
+ ("rolling", dict(window=2)),
+ ("expanding", None),
+ ("ewm", dict(alpha=0.5)),
+ ],
+)
+def test_trendline_results_passthrough(mode, options):
+ df = px.data.gapminder().query("continent == 'Oceania'")
+ fig = px.scatter(
+ df,
+ x="year",
+ y="pop",
+ color="country",
+ trendline=mode,
+ trendline_options=options,
+ )
+ assert len(fig.data) == 4
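+    # data traces and trendline traces alternate: even indices are the scatter
+    # traces, odd indices are the fitted trendlines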
+ for trace in fig["data"][0::2]:
+ assert "trendline" not in trace.hovertemplate
+ for trendline in fig["data"][1::2]:
+ assert "trendline" in trendline.hovertemplate
+ if mode == "ols":
+ assert "R2" in trendline.hovertemplate
+ results = px.get_trendline_results(fig)
+ if mode == "ols":
+ assert len(results) == 2
+ assert results["country"].values[0] == "Australia"
+ au_result = results["px_fit_results"].values[0]
+ assert len(au_result.params) == 2
+ else:
+ assert len(results) == 0
+
+
+@pytest.mark.parametrize(
+ "mode,options",
+ [
+ ("ols", None),
+ ("lowess", None),
+ ("lowess", dict(frac=0.3)),
+ ("rolling", dict(window=2)),
+ ("expanding", None),
+ ("ewm", dict(alpha=0.5)),
+ ],
+)
+def test_trendline_enough_values(mode, options):
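+    # fewer than two valid (non-NaN) points still produces a trendline trace,
+    # but with its x left as None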
+ fig = px.scatter(x=[0, 1], y=[0, 1], trendline=mode, trendline_options=options)
+ assert len(fig.data) == 2
+ assert len(fig.data[1].x) == 2
+ fig = px.scatter(x=[0], y=[0], trendline=mode, trendline_options=options)
+ assert len(fig.data) == 2
+ assert fig.data[1].x is None
+ fig = px.scatter(x=[0, 1], y=[0, None], trendline=mode, trendline_options=options)
+ assert len(fig.data) == 2
+ assert fig.data[1].x is None
+ fig = px.scatter(
+ x=[0, 1], y=np.array([0, np.nan]), trendline=mode, trendline_options=options
+ )
+ assert len(fig.data) == 2
+ assert fig.data[1].x is None
+ fig = px.scatter(
+ x=[0, 1, None], y=[0, None, 1], trendline=mode, trendline_options=options
+ )
+ assert len(fig.data) == 2
+ assert fig.data[1].x is None
+ fig = px.scatter(
+ x=np.array([0, 1, np.nan]),
+ y=np.array([0, np.nan, 1]),
+ trendline=mode,
+ trendline_options=options,
+ )
+ assert len(fig.data) == 2
+ assert fig.data[1].x is None
+ fig = px.scatter(
+ x=[0, 1, None, 2], y=[1, None, 1, 2], trendline=mode, trendline_options=options
+ )
+ assert len(fig.data) == 2
+ assert len(fig.data[1].x) == 2
+ fig = px.scatter(
+ x=np.array([0, 1, np.nan, 2]),
+ y=np.array([1, np.nan, 1, 2]),
+ trendline=mode,
+ trendline_options=options,
+ )
+ assert len(fig.data) == 2
+ assert len(fig.data[1].x) == 2
+
+
+@pytest.mark.parametrize(
+ "mode,options",
+ [
+ ("ols", None),
+ ("ols", dict(add_constant=False, log_x=True, log_y=True)),
+ ("lowess", None),
+ ("lowess", dict(frac=0.3)),
+ ("rolling", dict(window=2)),
+ ("expanding", None),
+ ("ewm", dict(alpha=0.5)),
+ ],
+)
+def test_trendline_nan_values(mode, options):
+ df = px.data.gapminder().query("continent == 'Oceania'")
+ start_date = 1970
+ df["pop"][df["year"] < start_date] = np.nan
+ fig = px.scatter(
+ df,
+ x="year",
+ y="pop",
+ color="country",
+ trendline=mode,
+ trendline_options=options,
+ )
+ for trendline in fig["data"][1::2]:
+ assert trendline.x[0] >= start_date
+ assert len(trendline.x) == len(trendline.y)
+
+
+def test_ols_trendline_slopes():
+ fig = px.scatter(x=[0, 1], y=[0, 1], trendline="ols")
+ # should be "y = 1 * x + 0" but sometimes is some tiny number instead
+ assert "y = 1 * x + " in fig.data[1].hovertemplate
+ results = px.get_trendline_results(fig)
+ params = results["px_fit_results"].iloc[0].params
+ assert np.all(np.isclose(params, [0, 1]))
+
+ fig = px.scatter(x=[0, 1], y=[1, 2], trendline="ols")
+ assert "y = 1 * x + 1
" in fig.data[1].hovertemplate
+ results = px.get_trendline_results(fig)
+ params = results["px_fit_results"].iloc[0].params
+ assert np.all(np.isclose(params, [1, 1]))
+
+ fig = px.scatter(
+ x=[0, 1], y=[1, 2], trendline="ols", trendline_options=dict(add_constant=False)
+ )
+ assert "y = 2 * x
" in fig.data[1].hovertemplate
+ results = px.get_trendline_results(fig)
+ params = results["px_fit_results"].iloc[0].params
+ assert np.all(np.isclose(params, [2]))
+
+ fig = px.scatter(
+ x=[1, 1], y=[0, 0], trendline="ols", trendline_options=dict(add_constant=False)
+ )
+ assert "y = 0 * x
" in fig.data[1].hovertemplate
+ results = px.get_trendline_results(fig)
+ params = results["px_fit_results"].iloc[0].params
+ assert np.all(np.isclose(params, [0]))
+
+ fig = px.scatter(x=[1, 1], y=[0, 0], trendline="ols")
+ assert "y = 0
" in fig.data[1].hovertemplate
+ results = px.get_trendline_results(fig)
+ params = results["px_fit_results"].iloc[0].params
+ assert np.all(np.isclose(params, [0]))
+
+ fig = px.scatter(x=[1, 2], y=[0, 0], trendline="ols")
+ assert "y = 0 * x + 0
" in fig.data[1].hovertemplate
+ fig = px.scatter(x=[0, 0], y=[1, 1], trendline="ols")
+ assert "y = 0 * x + 1
" in fig.data[1].hovertemplate
+ fig = px.scatter(x=[0, 0], y=[1, 2], trendline="ols")
+ assert "y = 0 * x + 1.5
" in fig.data[1].hovertemplate
+
+
+@pytest.mark.parametrize(
+ "mode,options",
+ [
+ ("ols", None),
+ ("lowess", None),
+ ("lowess", dict(frac=0.3)),
+ ("rolling", dict(window=2)),
+ ("rolling", dict(window="10d")),
+ ("expanding", None),
+ ("ewm", dict(alpha=0.5)),
+ ],
+)
+def test_trendline_on_timeseries(mode, options):
+ df = px.data.stocks()
+
+ with pytest.raises(ValueError) as err_msg:
+ px.scatter(df, x="date", y="GOOG", trendline=mode, trendline_options=options)
+ assert "Could not convert value of 'x' ('date') into a numeric type." in str(
+ err_msg.value
+ )
+
+ df["date"] = pd.to_datetime(df["date"])
+ df["date"] = df["date"].dt.tz_localize("CET") # force a timezone
+ fig = px.scatter(df, x="date", y="GOOG", trendline=mode, trendline_options=options)
+ assert len(fig.data) == 2
+ assert len(fig.data[0].x) == len(fig.data[1].x)
+ assert type(fig.data[0].x[0]) == datetime
+ assert type(fig.data[1].x[0]) == datetime
+ assert np.all(fig.data[0].x == fig.data[1].x)
+ assert str(fig.data[0].x[0]) == str(fig.data[1].x[0])
+
+
+def test_overall_trendline():
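+    # trendline_scope="overall" fits a single trendline across all traces, so
+    # its parameters should match the plain un-grouped fit (fig1)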
+ df = px.data.tips()
+ fig1 = px.scatter(df, x="total_bill", y="tip", trendline="ols")
+ assert len(fig1.data) == 2
+ assert "trendline" in fig1.data[1].hovertemplate
+ results1 = px.get_trendline_results(fig1)
+ params1 = results1["px_fit_results"].iloc[0].params
+
+ fig2 = px.scatter(
+ df,
+ x="total_bill",
+ y="tip",
+ color="sex",
+ trendline="ols",
+ trendline_scope="overall",
+ )
+ assert len(fig2.data) == 3
+ assert "trendline" in fig2.data[2].hovertemplate
+ results2 = px.get_trendline_results(fig2)
+ params2 = results2["px_fit_results"].iloc[0].params
+
+ assert np.all(np.array_equal(params1, params2))
+
+ fig3 = px.scatter(
+ df,
+ x="total_bill",
+ y="tip",
+ facet_row="sex",
+ trendline="ols",
+ trendline_scope="overall",
+ )
+ assert len(fig3.data) == 4
+ assert "trendline" in fig3.data[3].hovertemplate
+ results3 = px.get_trendline_results(fig3)
+ params3 = results3["px_fit_results"].iloc[0].params
+
+ assert np.all(np.array_equal(params1, params3))
diff --git a/modin/pandas/test/interoperability/sklearn/compose/test_column_transformer.py b/modin/pandas/test/interoperability/sklearn/compose/test_column_transformer.py
new file mode 100644
index 00000000000..63da4d0d90f
--- /dev/null
+++ b/modin/pandas/test/interoperability/sklearn/compose/test_column_transformer.py
@@ -0,0 +1,2147 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+# flake8: noqa
+
+"""
+Test the ColumnTransformer.
+"""
+import re
+import pickle
+import numpy as np
+from scipy import sparse
+import pytest
+
+from numpy.testing import assert_allclose
+from sklearn.utils._testing import assert_array_equal
+from sklearn.utils._testing import assert_allclose_dense_sparse
+from sklearn.utils._testing import assert_almost_equal
+
+from sklearn.base import BaseEstimator, TransformerMixin
+from sklearn.compose import (
+ ColumnTransformer,
+ make_column_transformer,
+ make_column_selector,
+)
+from sklearn.exceptions import NotFittedError
+from sklearn.preprocessing import FunctionTransformer
+from sklearn.preprocessing import StandardScaler, Normalizer, OneHotEncoder
+
+
+class Trans(TransformerMixin, BaseEstimator):
+ def fit(self, X, y=None):
+ return self
+
+ def transform(self, X, y=None):
+ # 1D Series -> 2D DataFrame
+ if hasattr(X, "to_frame"):
+ return X.to_frame()
+ # 1D array -> 2D array
+ if X.ndim == 1:
+ return np.atleast_2d(X).T
+ return X
+
+
+class DoubleTrans(BaseEstimator):
+ def fit(self, X, y=None):
+ return self
+
+ def transform(self, X):
+ return 2 * X
+
+
+class SparseMatrixTrans(BaseEstimator):
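+    # ignores the input values; transform returns an n_samples x n_samples
+    # identity matrix in CSR format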
+ def fit(self, X, y=None):
+ return self
+
+ def transform(self, X, y=None):
+ n_samples = len(X)
+ return sparse.eye(n_samples, n_samples).tocsr()
+
+
+class TransNo2D(BaseEstimator):
+ def fit(self, X, y=None):
+ return self
+
+ def transform(self, X, y=None):
+ return X
+
+
+class TransRaise(BaseEstimator):
+ def fit(self, X, y=None):
+ raise ValueError("specific message")
+
+ def transform(self, X, y=None):
+ raise ValueError("specific message")
+
+
+def test_column_transformer():
+ X_array = np.array([[0, 1, 2], [2, 4, 6]]).T
+
+ X_res_first1D = np.array([0, 1, 2])
+ X_res_second1D = np.array([2, 4, 6])
+ X_res_first = X_res_first1D.reshape(-1, 1)
+ X_res_both = X_array
+
+ cases = [
+ # single column 1D / 2D
+ (0, X_res_first),
+ ([0], X_res_first),
+ # list-like
+ ([0, 1], X_res_both),
+ (np.array([0, 1]), X_res_both),
+ # slice
+ (slice(0, 1), X_res_first),
+ (slice(0, 2), X_res_both),
+ # boolean mask
+ (np.array([True, False]), X_res_first),
+ ([True, False], X_res_first),
+ (np.array([True, True]), X_res_both),
+ ([True, True], X_res_both),
+ ]
+
+ for selection, res in cases:
+ ct = ColumnTransformer([("trans", Trans(), selection)], remainder="drop")
+ assert_array_equal(ct.fit_transform(X_array), res)
+ assert_array_equal(ct.fit(X_array).transform(X_array), res)
+
+ # callable that returns any of the allowed specifiers
+ ct = ColumnTransformer(
+ [("trans", Trans(), lambda x: selection)], remainder="drop"
+ )
+ assert_array_equal(ct.fit_transform(X_array), res)
+ assert_array_equal(ct.fit(X_array).transform(X_array), res)
+
+ ct = ColumnTransformer([("trans1", Trans(), [0]), ("trans2", Trans(), [1])])
+ assert_array_equal(ct.fit_transform(X_array), X_res_both)
+ assert_array_equal(ct.fit(X_array).transform(X_array), X_res_both)
+ assert len(ct.transformers_) == 2
+
+ # test with transformer_weights
+ transformer_weights = {"trans1": 0.1, "trans2": 10}
+ both = ColumnTransformer(
+ [("trans1", Trans(), [0]), ("trans2", Trans(), [1])],
+ transformer_weights=transformer_weights,
+ )
+ res = np.vstack(
+ [
+ transformer_weights["trans1"] * X_res_first1D,
+ transformer_weights["trans2"] * X_res_second1D,
+ ]
+ ).T
+ assert_array_equal(both.fit_transform(X_array), res)
+ assert_array_equal(both.fit(X_array).transform(X_array), res)
+ assert len(both.transformers_) == 2
+
+ both = ColumnTransformer(
+ [("trans", Trans(), [0, 1])], transformer_weights={"trans": 0.1}
+ )
+ assert_array_equal(both.fit_transform(X_array), 0.1 * X_res_both)
+ assert_array_equal(both.fit(X_array).transform(X_array), 0.1 * X_res_both)
+ assert len(both.transformers_) == 1
+
+
+def test_column_transformer_tuple_transformers_parameter():
+ X_array = np.array([[0, 1, 2], [2, 4, 6]]).T
+
+ transformers = [("trans1", Trans(), [0]), ("trans2", Trans(), [1])]
+
+ ct_with_list = ColumnTransformer(transformers)
+ ct_with_tuple = ColumnTransformer(tuple(transformers))
+
+ assert_array_equal(
+ ct_with_list.fit_transform(X_array), ct_with_tuple.fit_transform(X_array)
+ )
+ assert_array_equal(
+ ct_with_list.fit(X_array).transform(X_array),
+ ct_with_tuple.fit(X_array).transform(X_array),
+ )
+
+
+def test_column_transformer_dataframe():
+ pd = pytest.importorskip("modin.pandas")
+
+ X_array = np.array([[0, 1, 2], [2, 4, 6]]).T
+ X_df = pd.DataFrame(X_array, columns=["first", "second"])
+
+ X_res_first = np.array([0, 1, 2]).reshape(-1, 1)
+ X_res_both = X_array
+
+ cases = [
+ # String keys: label based
+ # scalar
+ ("first", X_res_first),
+ # list
+ (["first"], X_res_first),
+ (["first", "second"], X_res_both),
+ # slice
+ (slice("first", "second"), X_res_both),
+ # int keys: positional
+ # scalar
+ (0, X_res_first),
+ # list
+ ([0], X_res_first),
+ ([0, 1], X_res_both),
+ (np.array([0, 1]), X_res_both),
+ # slice
+ (slice(0, 1), X_res_first),
+ (slice(0, 2), X_res_both),
+ # boolean mask
+ (np.array([True, False]), X_res_first),
+ (pd.Series([True, False], index=["first", "second"]), X_res_first),
+ ([True, False], X_res_first),
+ ]
+
+ for selection, res in cases:
+ ct = ColumnTransformer([("trans", Trans(), selection)], remainder="drop")
+ assert_array_equal(ct.fit_transform(X_df), res)
+ assert_array_equal(ct.fit(X_df).transform(X_df), res)
+
+ # callable that returns any of the allowed specifiers
+ ct = ColumnTransformer(
+ [("trans", Trans(), lambda X: selection)], remainder="drop"
+ )
+ assert_array_equal(ct.fit_transform(X_df), res)
+ assert_array_equal(ct.fit(X_df).transform(X_df), res)
+
+ ct = ColumnTransformer(
+ [("trans1", Trans(), ["first"]), ("trans2", Trans(), ["second"])]
+ )
+ assert_array_equal(ct.fit_transform(X_df), X_res_both)
+ assert_array_equal(ct.fit(X_df).transform(X_df), X_res_both)
+ assert len(ct.transformers_) == 2
+ assert ct.transformers_[-1][0] != "remainder"
+
+ ct = ColumnTransformer([("trans1", Trans(), [0]), ("trans2", Trans(), [1])])
+ assert_array_equal(ct.fit_transform(X_df), X_res_both)
+ assert_array_equal(ct.fit(X_df).transform(X_df), X_res_both)
+ assert len(ct.transformers_) == 2
+ assert ct.transformers_[-1][0] != "remainder"
+
+ # test with transformer_weights
+ transformer_weights = {"trans1": 0.1, "trans2": 10}
+ both = ColumnTransformer(
+ [("trans1", Trans(), ["first"]), ("trans2", Trans(), ["second"])],
+ transformer_weights=transformer_weights,
+ )
+ res = np.vstack(
+ [
+ transformer_weights["trans1"] * X_df["first"],
+ transformer_weights["trans2"] * X_df["second"],
+ ]
+ ).T
+ assert_array_equal(both.fit_transform(X_df), res)
+ assert_array_equal(both.fit(X_df).transform(X_df), res)
+ assert len(both.transformers_) == 2
+ assert both.transformers_[-1][0] != "remainder"
+
+ # test multiple columns
+ both = ColumnTransformer(
+ [("trans", Trans(), ["first", "second"])], transformer_weights={"trans": 0.1}
+ )
+ assert_array_equal(both.fit_transform(X_df), 0.1 * X_res_both)
+ assert_array_equal(both.fit(X_df).transform(X_df), 0.1 * X_res_both)
+ assert len(both.transformers_) == 1
+ assert both.transformers_[-1][0] != "remainder"
+
+ both = ColumnTransformer(
+ [("trans", Trans(), [0, 1])], transformer_weights={"trans": 0.1}
+ )
+ assert_array_equal(both.fit_transform(X_df), 0.1 * X_res_both)
+ assert_array_equal(both.fit(X_df).transform(X_df), 0.1 * X_res_both)
+ assert len(both.transformers_) == 1
+ assert both.transformers_[-1][0] != "remainder"
+
+ # ensure pandas object is passed through
+
+ class TransAssert(BaseEstimator):
+ def fit(self, X, y=None):
+ return self
+
+ def transform(self, X, y=None):
+ assert isinstance(X, (pd.DataFrame, pd.Series))
+ if isinstance(X, pd.Series):
+ X = X.to_frame()
+ return X
+
+ ct = ColumnTransformer([("trans", TransAssert(), "first")], remainder="drop")
+ ct.fit_transform(X_df)
+ ct = ColumnTransformer([("trans", TransAssert(), ["first", "second"])])
+ ct.fit_transform(X_df)
+
+ # integer column spec + integer column names -> still use positional
+ X_df2 = X_df.copy()
+ X_df2.columns = [1, 0]
+ ct = ColumnTransformer([("trans", Trans(), 0)], remainder="drop")
+ assert_array_equal(ct.fit_transform(X_df2), X_res_first)
+ assert_array_equal(ct.fit(X_df2).transform(X_df2), X_res_first)
+
+ assert len(ct.transformers_) == 2
+ assert ct.transformers_[-1][0] == "remainder"
+ assert ct.transformers_[-1][1] == "drop"
+ assert_array_equal(ct.transformers_[-1][2], [1])
+
+
+@pytest.mark.parametrize("pandas", [True, False], ids=["pandas", "numpy"])
+@pytest.mark.parametrize(
+ "column_selection",
+ [[], np.array([False, False]), [False, False]],
+ ids=["list", "bool", "bool_int"],
+)
+@pytest.mark.parametrize("callable_column", [False, True])
+def test_column_transformer_empty_columns(pandas, column_selection, callable_column):
+    # test case that ensures that the column transformer also works when
+    # a given transformer doesn't have any columns to work on
+ X_array = np.array([[0, 1, 2], [2, 4, 6]]).T
+ X_res_both = X_array
+
+ if pandas:
+ pd = pytest.importorskip("modin.pandas")
+ X = pd.DataFrame(X_array, columns=["first", "second"])
+ else:
+ X = X_array
+
+ if callable_column:
+ column = lambda X: column_selection # noqa
+ else:
+ column = column_selection
+
+ ct = ColumnTransformer(
+ [("trans1", Trans(), [0, 1]), ("trans2", TransRaise(), column)]
+ )
+ assert_array_equal(ct.fit_transform(X), X_res_both)
+ assert_array_equal(ct.fit(X).transform(X), X_res_both)
+ assert len(ct.transformers_) == 2
+ assert isinstance(ct.transformers_[1][1], TransRaise)
+
+ ct = ColumnTransformer(
+ [("trans1", TransRaise(), column), ("trans2", Trans(), [0, 1])]
+ )
+ assert_array_equal(ct.fit_transform(X), X_res_both)
+ assert_array_equal(ct.fit(X).transform(X), X_res_both)
+ assert len(ct.transformers_) == 2
+ assert isinstance(ct.transformers_[0][1], TransRaise)
+
+ ct = ColumnTransformer([("trans", TransRaise(), column)], remainder="passthrough")
+ assert_array_equal(ct.fit_transform(X), X_res_both)
+ assert_array_equal(ct.fit(X).transform(X), X_res_both)
+ assert len(ct.transformers_) == 2 # including remainder
+ assert isinstance(ct.transformers_[0][1], TransRaise)
+
+ fixture = np.array([[], [], []])
+ ct = ColumnTransformer([("trans", TransRaise(), column)], remainder="drop")
+ assert_array_equal(ct.fit_transform(X), fixture)
+ assert_array_equal(ct.fit(X).transform(X), fixture)
+ assert len(ct.transformers_) == 2 # including remainder
+ assert isinstance(ct.transformers_[0][1], TransRaise)
+
+
+def test_column_transformer_output_indices():
+ # Checks for the output_indices_ attribute
+ X_array = np.arange(6).reshape(3, 2)
+
+ ct = ColumnTransformer([("trans1", Trans(), [0]), ("trans2", Trans(), [1])])
+ X_trans = ct.fit_transform(X_array)
+ assert ct.output_indices_ == {
+ "trans1": slice(0, 1),
+ "trans2": slice(1, 2),
+ "remainder": slice(0, 0),
+ }
+ assert_array_equal(X_trans[:, [0]], X_trans[:, ct.output_indices_["trans1"]])
+ assert_array_equal(X_trans[:, [1]], X_trans[:, ct.output_indices_["trans2"]])
+
+ # test with transformer_weights and multiple columns
+ ct = ColumnTransformer(
+ [("trans", Trans(), [0, 1])], transformer_weights={"trans": 0.1}
+ )
+ X_trans = ct.fit_transform(X_array)
+ assert ct.output_indices_ == {"trans": slice(0, 2), "remainder": slice(0, 0)}
+ assert_array_equal(X_trans[:, [0, 1]], X_trans[:, ct.output_indices_["trans"]])
+ assert_array_equal(X_trans[:, []], X_trans[:, ct.output_indices_["remainder"]])
+
+    # test case that ensures that the attribute also works when
+    # a given transformer doesn't have any columns to work on
+ ct = ColumnTransformer([("trans1", Trans(), [0, 1]), ("trans2", TransRaise(), [])])
+ X_trans = ct.fit_transform(X_array)
+ assert ct.output_indices_ == {
+ "trans1": slice(0, 2),
+ "trans2": slice(0, 0),
+ "remainder": slice(0, 0),
+ }
+ assert_array_equal(X_trans[:, [0, 1]], X_trans[:, ct.output_indices_["trans1"]])
+ assert_array_equal(X_trans[:, []], X_trans[:, ct.output_indices_["trans2"]])
+ assert_array_equal(X_trans[:, []], X_trans[:, ct.output_indices_["remainder"]])
+
+ ct = ColumnTransformer([("trans", TransRaise(), [])], remainder="passthrough")
+ X_trans = ct.fit_transform(X_array)
+ assert ct.output_indices_ == {"trans": slice(0, 0), "remainder": slice(0, 2)}
+ assert_array_equal(X_trans[:, []], X_trans[:, ct.output_indices_["trans"]])
+ assert_array_equal(X_trans[:, [0, 1]], X_trans[:, ct.output_indices_["remainder"]])
+
+
+def test_column_transformer_output_indices_df():
+ # Checks for the output_indices_ attribute with data frames
+ pd = pytest.importorskip("modin.pandas")
+
+ X_df = pd.DataFrame(np.arange(6).reshape(3, 2), columns=["first", "second"])
+
+ ct = ColumnTransformer(
+ [("trans1", Trans(), ["first"]), ("trans2", Trans(), ["second"])]
+ )
+ X_trans = ct.fit_transform(X_df)
+ assert ct.output_indices_ == {
+ "trans1": slice(0, 1),
+ "trans2": slice(1, 2),
+ "remainder": slice(0, 0),
+ }
+ assert_array_equal(X_trans[:, [0]], X_trans[:, ct.output_indices_["trans1"]])
+ assert_array_equal(X_trans[:, [1]], X_trans[:, ct.output_indices_["trans2"]])
+ assert_array_equal(X_trans[:, []], X_trans[:, ct.output_indices_["remainder"]])
+
+ ct = ColumnTransformer([("trans1", Trans(), [0]), ("trans2", Trans(), [1])])
+ X_trans = ct.fit_transform(X_df)
+ assert ct.output_indices_ == {
+ "trans1": slice(0, 1),
+ "trans2": slice(1, 2),
+ "remainder": slice(0, 0),
+ }
+ assert_array_equal(X_trans[:, [0]], X_trans[:, ct.output_indices_["trans1"]])
+ assert_array_equal(X_trans[:, [1]], X_trans[:, ct.output_indices_["trans2"]])
+ assert_array_equal(X_trans[:, []], X_trans[:, ct.output_indices_["remainder"]])
+
+
+def test_column_transformer_sparse_array():
+ X_sparse = sparse.eye(3, 2).tocsr()
+
+ # no distinction between 1D and 2D
+ X_res_first = X_sparse[:, 0]
+ X_res_both = X_sparse
+
+ for col in [0, [0], slice(0, 1)]:
+ for remainder, res in [("drop", X_res_first), ("passthrough", X_res_both)]:
+ ct = ColumnTransformer(
+ [("trans", Trans(), col)], remainder=remainder, sparse_threshold=0.8
+ )
+ assert sparse.issparse(ct.fit_transform(X_sparse))
+ assert_allclose_dense_sparse(ct.fit_transform(X_sparse), res)
+ assert_allclose_dense_sparse(ct.fit(X_sparse).transform(X_sparse), res)
+
+ for col in [[0, 1], slice(0, 2)]:
+ ct = ColumnTransformer([("trans", Trans(), col)], sparse_threshold=0.8)
+ assert sparse.issparse(ct.fit_transform(X_sparse))
+ assert_allclose_dense_sparse(ct.fit_transform(X_sparse), X_res_both)
+ assert_allclose_dense_sparse(ct.fit(X_sparse).transform(X_sparse), X_res_both)
+
+
+def test_column_transformer_list():
+ X_list = [[1, float("nan"), "a"], [0, 0, "b"]]
+ expected_result = np.array(
+ [
+ [1, float("nan"), 1, 0],
+ [-1, 0, 0, 1],
+ ]
+ )
+
+ ct = ColumnTransformer(
+ [
+ ("numerical", StandardScaler(), [0, 1]),
+ ("categorical", OneHotEncoder(), [2]),
+ ]
+ )
+
+ assert_array_equal(ct.fit_transform(X_list), expected_result)
+ assert_array_equal(ct.fit(X_list).transform(X_list), expected_result)
+
+
+def test_column_transformer_sparse_stacking():
+ X_array = np.array([[0, 1, 2], [2, 4, 6]]).T
+ col_trans = ColumnTransformer(
+ [("trans1", Trans(), [0]), ("trans2", SparseMatrixTrans(), 1)],
+ sparse_threshold=0.8,
+ )
+ col_trans.fit(X_array)
+ X_trans = col_trans.transform(X_array)
+ assert sparse.issparse(X_trans)
+ assert X_trans.shape == (X_trans.shape[0], X_trans.shape[0] + 1)
+ assert_array_equal(X_trans.toarray()[:, 1:], np.eye(X_trans.shape[0]))
+ assert len(col_trans.transformers_) == 2
+ assert col_trans.transformers_[-1][0] != "remainder"
+
+ col_trans = ColumnTransformer(
+ [("trans1", Trans(), [0]), ("trans2", SparseMatrixTrans(), 1)],
+ sparse_threshold=0.1,
+ )
+ col_trans.fit(X_array)
+ X_trans = col_trans.transform(X_array)
+ assert not sparse.issparse(X_trans)
+ assert X_trans.shape == (X_trans.shape[0], X_trans.shape[0] + 1)
+ assert_array_equal(X_trans[:, 1:], np.eye(X_trans.shape[0]))
+
+
+def test_column_transformer_mixed_cols_sparse():
+ df = np.array([["a", 1, True], ["b", 2, False]], dtype="O")
+
+ ct = make_column_transformer(
+ (OneHotEncoder(), [0]), ("passthrough", [1, 2]), sparse_threshold=1.0
+ )
+
+    # this shouldn't fail, since booleans can be coerced into a numeric type
+ # See: https://github.com/scikit-learn/scikit-learn/issues/11912
+ X_trans = ct.fit_transform(df)
+ assert X_trans.getformat() == "csr"
+ assert_array_equal(X_trans.toarray(), np.array([[1, 0, 1, 1], [0, 1, 2, 0]]))
+
+ ct = make_column_transformer(
+ (OneHotEncoder(), [0]), ("passthrough", [0]), sparse_threshold=1.0
+ )
+ with pytest.raises(ValueError, match="For a sparse output, all columns should"):
+        # this fails since the strings `a` and `b` cannot be
+        # coerced into a numeric type.
+ ct.fit_transform(df)
+
+
+def test_column_transformer_sparse_threshold():
+ X_array = np.array([["a", "b"], ["A", "B"]], dtype=object).T
+ # above data has sparsity of 4 / 8 = 0.5
+
+ # apply threshold even if all sparse
+ col_trans = ColumnTransformer(
+ [("trans1", OneHotEncoder(), [0]), ("trans2", OneHotEncoder(), [1])],
+ sparse_threshold=0.2,
+ )
+ res = col_trans.fit_transform(X_array)
+ assert not sparse.issparse(res)
+ assert not col_trans.sparse_output_
+
+ # mixed -> sparsity of (4 + 2) / 8 = 0.75
+ for thres in [0.75001, 1]:
+ col_trans = ColumnTransformer(
+ [
+ ("trans1", OneHotEncoder(sparse_output=True), [0]),
+ ("trans2", OneHotEncoder(sparse_output=False), [1]),
+ ],
+ sparse_threshold=thres,
+ )
+ res = col_trans.fit_transform(X_array)
+ assert sparse.issparse(res)
+ assert col_trans.sparse_output_
+
+ for thres in [0.75, 0]:
+ col_trans = ColumnTransformer(
+ [
+ ("trans1", OneHotEncoder(sparse_output=True), [0]),
+ ("trans2", OneHotEncoder(sparse_output=False), [1]),
+ ],
+ sparse_threshold=thres,
+ )
+ res = col_trans.fit_transform(X_array)
+ assert not sparse.issparse(res)
+ assert not col_trans.sparse_output_
+
+ # if nothing is sparse -> no sparse
+ for thres in [0.33, 0, 1]:
+ col_trans = ColumnTransformer(
+ [
+ ("trans1", OneHotEncoder(sparse_output=False), [0]),
+ ("trans2", OneHotEncoder(sparse_output=False), [1]),
+ ],
+ sparse_threshold=thres,
+ )
+ res = col_trans.fit_transform(X_array)
+ assert not sparse.issparse(res)
+ assert not col_trans.sparse_output_
+
+
+def test_column_transformer_error_msg_1D():
+ X_array = np.array([[0.0, 1.0, 2.0], [2.0, 4.0, 6.0]]).T
+
+ col_trans = ColumnTransformer([("trans", StandardScaler(), 0)])
+ msg = "1D data passed to a transformer"
+ with pytest.raises(ValueError, match=msg):
+ col_trans.fit(X_array)
+
+ with pytest.raises(ValueError, match=msg):
+ col_trans.fit_transform(X_array)
+
+ col_trans = ColumnTransformer([("trans", TransRaise(), 0)])
+ for func in [col_trans.fit, col_trans.fit_transform]:
+ with pytest.raises(ValueError, match="specific message"):
+ func(X_array)
+
+
+def test_2D_transformer_output():
+ X_array = np.array([[0, 1, 2], [2, 4, 6]]).T
+
+ # if one transformer is dropped, test that name is still correct
+ ct = ColumnTransformer([("trans1", "drop", 0), ("trans2", TransNo2D(), 1)])
+
+ msg = "the 'trans2' transformer should be 2D"
+ with pytest.raises(ValueError, match=msg):
+ ct.fit_transform(X_array)
+    # because fit also performs the transform, this already raises on fit
+ with pytest.raises(ValueError, match=msg):
+ ct.fit(X_array)
+
+
+def test_2D_transformer_output_pandas():
+ pd = pytest.importorskip("modin.pandas")
+
+ X_array = np.array([[0, 1, 2], [2, 4, 6]]).T
+ X_df = pd.DataFrame(X_array, columns=["col1", "col2"])
+
+ # if one transformer is dropped, test that name is still correct
+ ct = ColumnTransformer([("trans1", TransNo2D(), "col1")])
+ msg = "the 'trans1' transformer should be 2D"
+ with pytest.raises(ValueError, match=msg):
+ ct.fit_transform(X_df)
+    # because fit also performs the transform, this already raises on fit
+ with pytest.raises(ValueError, match=msg):
+ ct.fit(X_df)
+
+
+@pytest.mark.parametrize("remainder", ["drop", "passthrough"])
+def test_column_transformer_invalid_columns(remainder):
+ X_array = np.array([[0, 1, 2], [2, 4, 6]]).T
+
+ # general invalid
+ for col in [1.5, ["string", 1], slice(1, "s"), np.array([1.0])]:
+ ct = ColumnTransformer([("trans", Trans(), col)], remainder=remainder)
+ with pytest.raises(ValueError, match="No valid specification"):
+ ct.fit(X_array)
+
+ # invalid for arrays
+ for col in ["string", ["string", "other"], slice("a", "b")]:
+ ct = ColumnTransformer([("trans", Trans(), col)], remainder=remainder)
+ with pytest.raises(ValueError, match="Specifying the columns"):
+ ct.fit(X_array)
+
+ # transformed n_features does not match fitted n_features
+ col = [0, 1]
+ ct = ColumnTransformer([("trans", Trans(), col)], remainder=remainder)
+ ct.fit(X_array)
+ X_array_more = np.array([[0, 1, 2], [2, 4, 6], [3, 6, 9]]).T
+ msg = "X has 3 features, but ColumnTransformer is expecting 2 features as input."
+ with pytest.raises(ValueError, match=msg):
+ ct.transform(X_array_more)
+ X_array_fewer = np.array(
+ [
+ [0, 1, 2],
+ ]
+ ).T
+ err_msg = (
+ "X has 1 features, but ColumnTransformer is expecting 2 features as input."
+ )
+ with pytest.raises(ValueError, match=err_msg):
+ ct.transform(X_array_fewer)
+
+
+def test_column_transformer_invalid_transformer():
+ class NoTrans(BaseEstimator):
+ def fit(self, X, y=None):
+ return self
+
+ def predict(self, X):
+ return X
+
+ X_array = np.array([[0, 1, 2], [2, 4, 6]]).T
+ ct = ColumnTransformer([("trans", NoTrans(), [0])])
+ msg = "All estimators should implement fit and transform"
+ with pytest.raises(TypeError, match=msg):
+ ct.fit(X_array)
+
+
+def test_make_column_transformer():
+ scaler = StandardScaler()
+ norm = Normalizer()
+ ct = make_column_transformer((scaler, "first"), (norm, ["second"]))
+ names, transformers, columns = zip(*ct.transformers)
+ assert names == ("standardscaler", "normalizer")
+ assert transformers == (scaler, norm)
+ assert columns == ("first", ["second"])
+
+
+def test_make_column_transformer_pandas():
+ pd = pytest.importorskip("modin.pandas")
+ X_array = np.array([[0, 1, 2], [2, 4, 6]]).T
+ X_df = pd.DataFrame(X_array, columns=["first", "second"])
+ norm = Normalizer()
+ ct1 = ColumnTransformer([("norm", Normalizer(), X_df.columns)])
+ ct2 = make_column_transformer((norm, X_df.columns))
+ assert_almost_equal(ct1.fit_transform(X_df), ct2.fit_transform(X_df))
+
+
+def test_make_column_transformer_kwargs():
+ scaler = StandardScaler()
+ norm = Normalizer()
+ ct = make_column_transformer(
+ (scaler, "first"),
+ (norm, ["second"]),
+ n_jobs=3,
+ remainder="drop",
+ sparse_threshold=0.5,
+ )
+ assert (
+ ct.transformers
+ == make_column_transformer((scaler, "first"), (norm, ["second"])).transformers
+ )
+ assert ct.n_jobs == 3
+ assert ct.remainder == "drop"
+ assert ct.sparse_threshold == 0.5
+ # invalid keyword parameters should raise an error message
+ msg = re.escape(
+ "make_column_transformer() got an unexpected keyword argument 'transformer_weights'"
+ )
+ with pytest.raises(TypeError, match=msg):
+ make_column_transformer(
+ (scaler, "first"),
+ (norm, ["second"]),
+ transformer_weights={"pca": 10, "Transf": 1},
+ )
+
+
+def test_make_column_transformer_remainder_transformer():
+ scaler = StandardScaler()
+ norm = Normalizer()
+ remainder = StandardScaler()
+ ct = make_column_transformer(
+ (scaler, "first"), (norm, ["second"]), remainder=remainder
+ )
+ assert ct.remainder == remainder
+
+
+def test_column_transformer_get_set_params():
+ ct = ColumnTransformer(
+ [("trans1", StandardScaler(), [0]), ("trans2", StandardScaler(), [1])]
+ )
+
+ exp = {
+ "n_jobs": None,
+ "remainder": "drop",
+ "sparse_threshold": 0.3,
+ "trans1": ct.transformers[0][1],
+ "trans1__copy": True,
+ "trans1__with_mean": True,
+ "trans1__with_std": True,
+ "trans2": ct.transformers[1][1],
+ "trans2__copy": True,
+ "trans2__with_mean": True,
+ "trans2__with_std": True,
+ "transformers": ct.transformers,
+ "transformer_weights": None,
+ "verbose_feature_names_out": True,
+ "verbose": False,
+ }
+
+ assert ct.get_params() == exp
+
+ ct.set_params(trans1__with_mean=False)
+ assert not ct.get_params()["trans1__with_mean"]
+
+ ct.set_params(trans1="passthrough")
+ exp = {
+ "n_jobs": None,
+ "remainder": "drop",
+ "sparse_threshold": 0.3,
+ "trans1": "passthrough",
+ "trans2": ct.transformers[1][1],
+ "trans2__copy": True,
+ "trans2__with_mean": True,
+ "trans2__with_std": True,
+ "transformers": ct.transformers,
+ "transformer_weights": None,
+ "verbose_feature_names_out": True,
+ "verbose": False,
+ }
+
+ assert ct.get_params() == exp
+
+
+def test_column_transformer_named_estimators():
+ X_array = np.array([[0.0, 1.0, 2.0], [2.0, 4.0, 6.0]]).T
+ ct = ColumnTransformer(
+ [
+ ("trans1", StandardScaler(), [0]),
+ ("trans2", StandardScaler(with_std=False), [1]),
+ ]
+ )
+ assert not hasattr(ct, "transformers_")
+ ct.fit(X_array)
+ assert hasattr(ct, "transformers_")
+ assert isinstance(ct.named_transformers_["trans1"], StandardScaler)
+ assert isinstance(ct.named_transformers_.trans1, StandardScaler)
+ assert isinstance(ct.named_transformers_["trans2"], StandardScaler)
+ assert isinstance(ct.named_transformers_.trans2, StandardScaler)
+ assert not ct.named_transformers_.trans2.with_std
+    # check that these are fitted transformers
+ assert ct.named_transformers_.trans1.mean_ == 1.0
+
+
+def test_column_transformer_cloning():
+ X_array = np.array([[0.0, 1.0, 2.0], [2.0, 4.0, 6.0]]).T
+
+ ct = ColumnTransformer([("trans", StandardScaler(), [0])])
+ ct.fit(X_array)
+ assert not hasattr(ct.transformers[0][1], "mean_")
+ assert hasattr(ct.transformers_[0][1], "mean_")
+
+ ct = ColumnTransformer([("trans", StandardScaler(), [0])])
+ ct.fit_transform(X_array)
+ assert not hasattr(ct.transformers[0][1], "mean_")
+ assert hasattr(ct.transformers_[0][1], "mean_")
+
+
+def test_column_transformer_get_feature_names():
+ X_array = np.array([[0.0, 1.0, 2.0], [2.0, 4.0, 6.0]]).T
+ ct = ColumnTransformer([("trans", Trans(), [0, 1])])
+ # raise correct error when not fitted
+ with pytest.raises(NotFittedError):
+ ct.get_feature_names_out()
+ # raise correct error when no feature names are available
+ ct.fit(X_array)
+ msg = re.escape(
+ "Transformer trans (type Trans) does not provide get_feature_names_out"
+ )
+ with pytest.raises(AttributeError, match=msg):
+ ct.get_feature_names_out()
+
+
+def test_column_transformer_special_strings():
+ # one 'drop' -> ignore
+ X_array = np.array([[0.0, 1.0, 2.0], [2.0, 4.0, 6.0]]).T
+ ct = ColumnTransformer([("trans1", Trans(), [0]), ("trans2", "drop", [1])])
+ exp = np.array([[0.0], [1.0], [2.0]])
+ assert_array_equal(ct.fit_transform(X_array), exp)
+ assert_array_equal(ct.fit(X_array).transform(X_array), exp)
+ assert len(ct.transformers_) == 2
+ assert ct.transformers_[-1][0] != "remainder"
+
+ # all 'drop' -> return shape 0 array
+ ct = ColumnTransformer([("trans1", "drop", [0]), ("trans2", "drop", [1])])
+ assert_array_equal(ct.fit(X_array).transform(X_array).shape, (3, 0))
+ assert_array_equal(ct.fit_transform(X_array).shape, (3, 0))
+ assert len(ct.transformers_) == 2
+ assert ct.transformers_[-1][0] != "remainder"
+
+ # 'passthrough'
+ X_array = np.array([[0.0, 1.0, 2.0], [2.0, 4.0, 6.0]]).T
+ ct = ColumnTransformer([("trans1", Trans(), [0]), ("trans2", "passthrough", [1])])
+ exp = X_array
+ assert_array_equal(ct.fit_transform(X_array), exp)
+ assert_array_equal(ct.fit(X_array).transform(X_array), exp)
+ assert len(ct.transformers_) == 2
+ assert ct.transformers_[-1][0] != "remainder"
+
+
+def test_column_transformer_remainder():
+ X_array = np.array([[0, 1, 2], [2, 4, 6]]).T
+
+ X_res_first = np.array([0, 1, 2]).reshape(-1, 1)
+ X_res_second = np.array([2, 4, 6]).reshape(-1, 1)
+ X_res_both = X_array
+
+ # default drop
+ ct = ColumnTransformer([("trans1", Trans(), [0])])
+ assert_array_equal(ct.fit_transform(X_array), X_res_first)
+ assert_array_equal(ct.fit(X_array).transform(X_array), X_res_first)
+ assert len(ct.transformers_) == 2
+ assert ct.transformers_[-1][0] == "remainder"
+ assert ct.transformers_[-1][1] == "drop"
+ assert_array_equal(ct.transformers_[-1][2], [1])
+
+ # specify passthrough
+ ct = ColumnTransformer([("trans", Trans(), [0])], remainder="passthrough")
+ assert_array_equal(ct.fit_transform(X_array), X_res_both)
+ assert_array_equal(ct.fit(X_array).transform(X_array), X_res_both)
+ assert len(ct.transformers_) == 2
+ assert ct.transformers_[-1][0] == "remainder"
+ assert ct.transformers_[-1][1] == "passthrough"
+ assert_array_equal(ct.transformers_[-1][2], [1])
+
+    # column order is not preserved (passed-through columns are appended at the end)
+ ct = ColumnTransformer([("trans1", Trans(), [1])], remainder="passthrough")
+ assert_array_equal(ct.fit_transform(X_array), X_res_both[:, ::-1])
+ assert_array_equal(ct.fit(X_array).transform(X_array), X_res_both[:, ::-1])
+ assert len(ct.transformers_) == 2
+ assert ct.transformers_[-1][0] == "remainder"
+ assert ct.transformers_[-1][1] == "passthrough"
+ assert_array_equal(ct.transformers_[-1][2], [0])
+
+ # passthrough when all actual transformers are skipped
+ ct = ColumnTransformer([("trans1", "drop", [0])], remainder="passthrough")
+ assert_array_equal(ct.fit_transform(X_array), X_res_second)
+ assert_array_equal(ct.fit(X_array).transform(X_array), X_res_second)
+ assert len(ct.transformers_) == 2
+ assert ct.transformers_[-1][0] == "remainder"
+ assert ct.transformers_[-1][1] == "passthrough"
+ assert_array_equal(ct.transformers_[-1][2], [1])
+
+ # check default for make_column_transformer
+ ct = make_column_transformer((Trans(), [0]))
+ assert ct.remainder == "drop"
+
+
+@pytest.mark.parametrize(
+ "key", [[0], np.array([0]), slice(0, 1), np.array([True, False])]
+)
+def test_column_transformer_remainder_numpy(key):
+ # test different ways that columns are specified with passthrough
+ X_array = np.array([[0, 1, 2], [2, 4, 6]]).T
+ X_res_both = X_array
+
+ ct = ColumnTransformer([("trans1", Trans(), key)], remainder="passthrough")
+ assert_array_equal(ct.fit_transform(X_array), X_res_both)
+ assert_array_equal(ct.fit(X_array).transform(X_array), X_res_both)
+ assert len(ct.transformers_) == 2
+ assert ct.transformers_[-1][0] == "remainder"
+ assert ct.transformers_[-1][1] == "passthrough"
+ assert_array_equal(ct.transformers_[-1][2], [1])
+
+
+@pytest.mark.parametrize(
+ "key",
+ [
+ [0],
+ slice(0, 1),
+ np.array([True, False]),
+ ["first"],
+ "pd-index",
+ np.array(["first"]),
+ np.array(["first"], dtype=object),
+ slice(None, "first"),
+ slice("first", "first"),
+ ],
+)
+def test_column_transformer_remainder_pandas(key):
+ # test different ways that columns are specified with passthrough
+ pd = pytest.importorskip("modin.pandas")
+ if isinstance(key, str) and key == "pd-index":
+ key = pd.Index(["first"])
+
+ X_array = np.array([[0, 1, 2], [2, 4, 6]]).T
+ X_df = pd.DataFrame(X_array, columns=["first", "second"])
+ X_res_both = X_array
+
+ ct = ColumnTransformer([("trans1", Trans(), key)], remainder="passthrough")
+ assert_array_equal(ct.fit_transform(X_df), X_res_both)
+ assert_array_equal(ct.fit(X_df).transform(X_df), X_res_both)
+ assert len(ct.transformers_) == 2
+ assert ct.transformers_[-1][0] == "remainder"
+ assert ct.transformers_[-1][1] == "passthrough"
+ assert_array_equal(ct.transformers_[-1][2], [1])
+
+
+@pytest.mark.parametrize(
+ "key", [[0], np.array([0]), slice(0, 1), np.array([True, False, False])]
+)
+def test_column_transformer_remainder_transformer(key):
+ X_array = np.array([[0, 1, 2], [2, 4, 6], [8, 6, 4]]).T
+ X_res_both = X_array.copy()
+
+ # second and third columns are doubled when remainder = DoubleTrans
+ X_res_both[:, 1:3] *= 2
+
+ ct = ColumnTransformer([("trans1", Trans(), key)], remainder=DoubleTrans())
+
+ assert_array_equal(ct.fit_transform(X_array), X_res_both)
+ assert_array_equal(ct.fit(X_array).transform(X_array), X_res_both)
+ assert len(ct.transformers_) == 2
+ assert ct.transformers_[-1][0] == "remainder"
+ assert isinstance(ct.transformers_[-1][1], DoubleTrans)
+ assert_array_equal(ct.transformers_[-1][2], [1, 2])
+
+
+def test_column_transformer_no_remaining_remainder_transformer():
+ X_array = np.array([[0, 1, 2], [2, 4, 6], [8, 6, 4]]).T
+
+ ct = ColumnTransformer([("trans1", Trans(), [0, 1, 2])], remainder=DoubleTrans())
+
+ assert_array_equal(ct.fit_transform(X_array), X_array)
+ assert_array_equal(ct.fit(X_array).transform(X_array), X_array)
+ assert len(ct.transformers_) == 1
+ assert ct.transformers_[-1][0] != "remainder"
+
+
+def test_column_transformer_drops_all_remainder_transformer():
+ X_array = np.array([[0, 1, 2], [2, 4, 6], [8, 6, 4]]).T
+
+ # columns are doubled when remainder = DoubleTrans
+ X_res_both = 2 * X_array.copy()[:, 1:3]
+
+ ct = ColumnTransformer([("trans1", "drop", [0])], remainder=DoubleTrans())
+
+ assert_array_equal(ct.fit_transform(X_array), X_res_both)
+ assert_array_equal(ct.fit(X_array).transform(X_array), X_res_both)
+ assert len(ct.transformers_) == 2
+ assert ct.transformers_[-1][0] == "remainder"
+ assert isinstance(ct.transformers_[-1][1], DoubleTrans)
+ assert_array_equal(ct.transformers_[-1][2], [1, 2])
+
+
+def test_column_transformer_sparse_remainder_transformer():
+ X_array = np.array([[0, 1, 2], [2, 4, 6], [8, 6, 4]]).T
+
+ ct = ColumnTransformer(
+ [("trans1", Trans(), [0])], remainder=SparseMatrixTrans(), sparse_threshold=0.8
+ )
+
+ X_trans = ct.fit_transform(X_array)
+ assert sparse.issparse(X_trans)
+ # SparseMatrixTrans creates 3 features for each column. There is
+ # one column in ``transformers``, thus:
+ assert X_trans.shape == (3, 3 + 1)
+
+ exp_array = np.hstack((X_array[:, 0].reshape(-1, 1), np.eye(3)))
+ assert_array_equal(X_trans.toarray(), exp_array)
+ assert len(ct.transformers_) == 2
+ assert ct.transformers_[-1][0] == "remainder"
+ assert isinstance(ct.transformers_[-1][1], SparseMatrixTrans)
+ assert_array_equal(ct.transformers_[-1][2], [1, 2])
+
+
+def test_column_transformer_drop_all_sparse_remainder_transformer():
+ X_array = np.array([[0, 1, 2], [2, 4, 6], [8, 6, 4]]).T
+ ct = ColumnTransformer(
+ [("trans1", "drop", [0])], remainder=SparseMatrixTrans(), sparse_threshold=0.8
+ )
+
+ X_trans = ct.fit_transform(X_array)
+ assert sparse.issparse(X_trans)
+
+ # SparseMatrixTrans creates 3 features for each column, thus:
+ assert X_trans.shape == (3, 3)
+ assert_array_equal(X_trans.toarray(), np.eye(3))
+ assert len(ct.transformers_) == 2
+ assert ct.transformers_[-1][0] == "remainder"
+ assert isinstance(ct.transformers_[-1][1], SparseMatrixTrans)
+ assert_array_equal(ct.transformers_[-1][2], [1, 2])
+
+
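+# Note: get_params() exposes the remainder estimator's own parameters under a
+# "remainder__" prefix, just like any named transformer ("trans1__*"), which is
+# what the expected dictionaries in the test below spell out.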
+def test_column_transformer_get_set_params_with_remainder():
+ ct = ColumnTransformer(
+ [("trans1", StandardScaler(), [0])], remainder=StandardScaler()
+ )
+
+ exp = {
+ "n_jobs": None,
+ "remainder": ct.remainder,
+ "remainder__copy": True,
+ "remainder__with_mean": True,
+ "remainder__with_std": True,
+ "sparse_threshold": 0.3,
+ "trans1": ct.transformers[0][1],
+ "trans1__copy": True,
+ "trans1__with_mean": True,
+ "trans1__with_std": True,
+ "transformers": ct.transformers,
+ "transformer_weights": None,
+ "verbose_feature_names_out": True,
+ "verbose": False,
+ }
+
+ assert ct.get_params() == exp
+
+ ct.set_params(remainder__with_std=False)
+ assert not ct.get_params()["remainder__with_std"]
+
+ ct.set_params(trans1="passthrough")
+ exp = {
+ "n_jobs": None,
+ "remainder": ct.remainder,
+ "remainder__copy": True,
+ "remainder__with_mean": True,
+ "remainder__with_std": False,
+ "sparse_threshold": 0.3,
+ "trans1": "passthrough",
+ "transformers": ct.transformers,
+ "transformer_weights": None,
+ "verbose_feature_names_out": True,
+ "verbose": False,
+ }
+ assert ct.get_params() == exp
+
+
+def test_column_transformer_no_estimators():
+ X_array = np.array([[0, 1, 2], [2, 4, 6], [8, 6, 4]]).astype("float").T
+ ct = ColumnTransformer([], remainder=StandardScaler())
+
+ params = ct.get_params()
+ assert params["remainder__with_mean"]
+
+ X_trans = ct.fit_transform(X_array)
+ assert X_trans.shape == X_array.shape
+ assert len(ct.transformers_) == 1
+ assert ct.transformers_[-1][0] == "remainder"
+ assert ct.transformers_[-1][2] == [0, 1, 2]
+
+
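+# Each expected pattern below contains one "Processing ..." line per fitted
+# step; transformers set to "drop" and remainder="drop" produce no output,
+# which is why some cases report "(1 of 2)"/"(2 of 2)" instead of three steps.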
+@pytest.mark.parametrize(
+ ["est", "pattern"],
+ [
+ (
+ ColumnTransformer(
+ [("trans1", Trans(), [0]), ("trans2", Trans(), [1])],
+ remainder=DoubleTrans(),
+ ),
+ (
+ r"\[ColumnTransformer\].*\(1 of 3\) Processing trans1.* total=.*\n"
+ r"\[ColumnTransformer\].*\(2 of 3\) Processing trans2.* total=.*\n"
+ r"\[ColumnTransformer\].*\(3 of 3\) Processing remainder.* total=.*\n$"
+ ),
+ ),
+ (
+ ColumnTransformer(
+ [("trans1", Trans(), [0]), ("trans2", Trans(), [1])],
+ remainder="passthrough",
+ ),
+ (
+ r"\[ColumnTransformer\].*\(1 of 3\) Processing trans1.* total=.*\n"
+ r"\[ColumnTransformer\].*\(2 of 3\) Processing trans2.* total=.*\n"
+ r"\[ColumnTransformer\].*\(3 of 3\) Processing remainder.* total=.*\n$"
+ ),
+ ),
+ (
+ ColumnTransformer(
+ [("trans1", Trans(), [0]), ("trans2", "drop", [1])],
+ remainder="passthrough",
+ ),
+ (
+ r"\[ColumnTransformer\].*\(1 of 2\) Processing trans1.* total=.*\n"
+ r"\[ColumnTransformer\].*\(2 of 2\) Processing remainder.* total=.*\n$"
+ ),
+ ),
+ (
+ ColumnTransformer(
+ [("trans1", Trans(), [0]), ("trans2", "passthrough", [1])],
+ remainder="passthrough",
+ ),
+ (
+ r"\[ColumnTransformer\].*\(1 of 3\) Processing trans1.* total=.*\n"
+ r"\[ColumnTransformer\].*\(2 of 3\) Processing trans2.* total=.*\n"
+ r"\[ColumnTransformer\].*\(3 of 3\) Processing remainder.* total=.*\n$"
+ ),
+ ),
+ (
+ ColumnTransformer([("trans1", Trans(), [0])], remainder="passthrough"),
+ (
+ r"\[ColumnTransformer\].*\(1 of 2\) Processing trans1.* total=.*\n"
+ r"\[ColumnTransformer\].*\(2 of 2\) Processing remainder.* total=.*\n$"
+ ),
+ ),
+ (
+ ColumnTransformer(
+ [("trans1", Trans(), [0]), ("trans2", Trans(), [1])], remainder="drop"
+ ),
+ (
+ r"\[ColumnTransformer\].*\(1 of 2\) Processing trans1.* total=.*\n"
+ r"\[ColumnTransformer\].*\(2 of 2\) Processing trans2.* total=.*\n$"
+ ),
+ ),
+ (
+ ColumnTransformer([("trans1", Trans(), [0])], remainder="drop"),
+ r"\[ColumnTransformer\].*\(1 of 1\) Processing trans1.* total=.*\n$",
+ ),
+ ],
+)
+@pytest.mark.parametrize("method", ["fit", "fit_transform"])
+def test_column_transformer_verbose(est, pattern, method, capsys):
+ X_array = np.array([[0, 1, 2], [2, 4, 6], [8, 6, 4]]).T
+
+ func = getattr(est, method)
+ est.set_params(verbose=False)
+ func(X_array)
+ assert not capsys.readouterr().out, "Got output for verbose=False"
+
+ est.set_params(verbose=True)
+ func(X_array)
+ assert re.match(pattern, capsys.readouterr()[0])
+
+
+def test_column_transformer_no_estimators_set_params():
+ ct = ColumnTransformer([]).set_params(n_jobs=2)
+ assert ct.n_jobs == 2
+
+
+def test_column_transformer_callable_specifier():
+ # assert that function gets the full array
+ X_array = np.array([[0, 1, 2], [2, 4, 6]]).T
+ X_res_first = np.array([[0, 1, 2]]).T
+
+ def func(X):
+ assert_array_equal(X, X_array)
+ return [0]
+
+ ct = ColumnTransformer([("trans", Trans(), func)], remainder="drop")
+ assert_array_equal(ct.fit_transform(X_array), X_res_first)
+ assert_array_equal(ct.fit(X_array).transform(X_array), X_res_first)
+ assert callable(ct.transformers[0][2])
+ assert ct.transformers_[0][2] == [0]
+
+
+def test_column_transformer_callable_specifier_dataframe():
+ # assert that function gets the full dataframe
+ pd = pytest.importorskip("modin.pandas")
+ X_array = np.array([[0, 1, 2], [2, 4, 6]]).T
+ X_res_first = np.array([[0, 1, 2]]).T
+
+ X_df = pd.DataFrame(X_array, columns=["first", "second"])
+
+ def func(X):
+ assert_array_equal(X.columns, X_df.columns)
+ assert_array_equal(X.values, X_df.values)
+ return ["first"]
+
+ ct = ColumnTransformer([("trans", Trans(), func)], remainder="drop")
+ assert_array_equal(ct.fit_transform(X_df), X_res_first)
+ assert_array_equal(ct.fit(X_df).transform(X_df), X_res_first)
+ assert callable(ct.transformers[0][2])
+ assert ct.transformers_[0][2] == ["first"]
+
+
+def test_column_transformer_negative_column_indexes():
+ X = np.random.randn(2, 2)
+ X_categories = np.array([[1], [2]])
+ X = np.concatenate([X, X_categories], axis=1)
+
+ ohe = OneHotEncoder()
+
+ tf_1 = ColumnTransformer([("ohe", ohe, [-1])], remainder="passthrough")
+ tf_2 = ColumnTransformer([("ohe", ohe, [2])], remainder="passthrough")
+ assert_array_equal(tf_1.fit_transform(X), tf_2.fit_transform(X))
+
+
+@pytest.mark.parametrize("array_type", [np.asarray, sparse.csr_matrix])
+def test_column_transformer_mask_indexing(array_type):
+ # Regression test for #14510
+ # Boolean array-like does not behave as boolean array with sparse matrices.
+ X = np.transpose([[1, 2, 3], [4, 5, 6], [5, 6, 7], [8, 9, 10]])
+ X = array_type(X)
+ column_transformer = ColumnTransformer(
+ [("identity", FunctionTransformer(), [False, True, False, True])]
+ )
+ X_trans = column_transformer.fit_transform(X)
+ assert X_trans.shape == (3, 2)
+
+
+def test_n_features_in():
+ # make sure n_features_in_ matches the number of features passed to the
+ # column transformer.
+
+ X = [[1, 2], [3, 4], [5, 6]]
+ ct = ColumnTransformer([("a", DoubleTrans(), [0]), ("b", DoubleTrans(), [1])])
+ assert not hasattr(ct, "n_features_in_")
+ ct.fit(X)
+ assert ct.n_features_in_ == 2
+
+
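+# Each case below lists the columns expected to be selected for a given regex
+# `pattern` combined with the dtype_include / dtype_exclude arguments of
+# make_column_selector.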
+@pytest.mark.parametrize(
+ "cols, pattern, include, exclude",
+ [
+ (["col_int", "col_float"], None, np.number, None),
+ (["col_int", "col_float"], None, None, object),
+ (["col_int", "col_float"], None, [int, float], None),
+ (["col_str"], None, [object], None),
+ (["col_str"], None, object, None),
+ (["col_float"], None, float, None),
+ (["col_float"], "at$", [np.number], None),
+ (["col_int"], None, [int], None),
+ (["col_int"], "^col_int", [np.number], None),
+ (["col_float", "col_str"], "float|str", None, None),
+ (["col_str"], "^col_s", None, [int]),
+ ([], "str$", float, None),
+ (["col_int", "col_float", "col_str"], None, [np.number, object], None),
+ ],
+)
+def test_make_column_selector_with_select_dtypes(cols, pattern, include, exclude):
+ pd = pytest.importorskip("modin.pandas")
+
+ X_df = pd.DataFrame(
+ {
+ "col_int": np.array([0, 1, 2], dtype=int),
+ "col_float": np.array([0.0, 1.0, 2.0], dtype=float),
+ "col_str": ["one", "two", "three"],
+ },
+ columns=["col_int", "col_float", "col_str"],
+ )
+
+ selector = make_column_selector(
+ dtype_include=include, dtype_exclude=exclude, pattern=pattern
+ )
+
+ assert_array_equal(selector(X_df), cols)
+
+
+def test_column_transformer_with_make_column_selector():
+ # Functional test for column transformer + column selector
+ pd = pytest.importorskip("modin.pandas")
+ X_df = pd.DataFrame(
+ {
+ "col_int": np.array([0, 1, 2], dtype=int),
+ "col_float": np.array([0.0, 1.0, 2.0], dtype=float),
+ "col_cat": ["one", "two", "one"],
+ "col_str": ["low", "middle", "high"],
+ },
+ columns=["col_int", "col_float", "col_cat", "col_str"],
+ )
+ X_df["col_str"] = X_df["col_str"].astype("category")
+
+ cat_selector = make_column_selector(dtype_include=["category", object])
+ num_selector = make_column_selector(dtype_include=np.number)
+
+ ohe = OneHotEncoder()
+ scaler = StandardScaler()
+
+ ct_selector = make_column_transformer((ohe, cat_selector), (scaler, num_selector))
+ ct_direct = make_column_transformer(
+ (ohe, ["col_cat", "col_str"]), (scaler, ["col_float", "col_int"])
+ )
+
+ X_selector = ct_selector.fit_transform(X_df)
+ X_direct = ct_direct.fit_transform(X_df)
+
+ assert_allclose(X_selector, X_direct)
+
+
+def test_make_column_selector_error():
+ selector = make_column_selector(dtype_include=np.number)
+ X = np.array([[0.1, 0.2]])
+ msg = "make_column_selector can only be applied to pandas dataframes"
+ with pytest.raises(ValueError, match=msg):
+ selector(X)
+
+
+def test_make_column_selector_pickle():
+ pd = pytest.importorskip("modin.pandas")
+
+ X_df = pd.DataFrame(
+ {
+ "col_int": np.array([0, 1, 2], dtype=int),
+ "col_float": np.array([0.0, 1.0, 2.0], dtype=float),
+ "col_str": ["one", "two", "three"],
+ },
+ columns=["col_int", "col_float", "col_str"],
+ )
+
+ selector = make_column_selector(dtype_include=[object])
+ selector_picked = pickle.loads(pickle.dumps(selector))
+
+ assert_array_equal(selector(X_df), selector_picked(X_df))
+
+
+@pytest.mark.parametrize(
+ "empty_col",
+ [[], np.array([], dtype=int), lambda x: []],
+ ids=["list", "array", "callable"],
+)
+def test_feature_names_empty_columns(empty_col):
+ pd = pytest.importorskip("modin.pandas")
+
+ df = pd.DataFrame({"col1": ["a", "a", "b"], "col2": ["z", "z", "z"]})
+
+ ct = ColumnTransformer(
+ transformers=[
+ ("ohe", OneHotEncoder(), ["col1", "col2"]),
+ ("empty_features", OneHotEncoder(), empty_col),
+ ],
+ )
+
+ ct.fit(df)
+ assert_array_equal(
+ ct.get_feature_names_out(), ["ohe__col1_a", "ohe__col1_b", "ohe__col2_z"]
+ )
+
+
+@pytest.mark.parametrize(
+ "selector",
+ [
+ [1],
+ lambda x: [1],
+ ["col2"],
+ lambda x: ["col2"],
+ [False, True],
+ lambda x: [False, True],
+ ],
+)
+def test_feature_names_out_pandas(selector):
+ """Checks name when selecting only the second column"""
+ pd = pytest.importorskip("modin.pandas")
+ df = pd.DataFrame({"col1": ["a", "a", "b"], "col2": ["z", "z", "z"]})
+ ct = ColumnTransformer([("ohe", OneHotEncoder(), selector)])
+ ct.fit(df)
+
+ assert_array_equal(ct.get_feature_names_out(), ["ohe__col2_z"])
+
+
+@pytest.mark.parametrize(
+ "selector", [[1], lambda x: [1], [False, True], lambda x: [False, True]]
+)
+def test_feature_names_out_non_pandas(selector):
+ """Checks name when selecting the second column with numpy array"""
+ X = [["a", "z"], ["a", "z"], ["b", "z"]]
+ ct = ColumnTransformer([("ohe", OneHotEncoder(), selector)])
+ ct.fit(X)
+
+ assert_array_equal(ct.get_feature_names_out(), ["ohe__x1_z"])
+
+
+@pytest.mark.parametrize("remainder", ["passthrough", StandardScaler()])
+def test_sk_visual_block_remainder(remainder):
+ # remainder='passthrough' or an estimator will be shown in repr_html
+ ohe = OneHotEncoder()
+ ct = ColumnTransformer(
+ transformers=[("ohe", ohe, ["col1", "col2"])], remainder=remainder
+ )
+ visual_block = ct._sk_visual_block_()
+ assert visual_block.names == ("ohe", "remainder")
+ assert visual_block.name_details == (["col1", "col2"], "")
+ assert visual_block.estimators == (ohe, remainder)
+
+
+def test_sk_visual_block_remainder_drop():
+ # remainder='drop' is not shown in repr_html
+ ohe = OneHotEncoder()
+ ct = ColumnTransformer(transformers=[("ohe", ohe, ["col1", "col2"])])
+ visual_block = ct._sk_visual_block_()
+ assert visual_block.names == ("ohe",)
+ assert visual_block.name_details == (["col1", "col2"],)
+ assert visual_block.estimators == (ohe,)
+
+
+@pytest.mark.parametrize("remainder", ["passthrough", StandardScaler()])
+def test_sk_visual_block_remainder_fitted_pandas(remainder):
+ # Remainder shows the columns after fitting
+ pd = pytest.importorskip("modin.pandas")
+ ohe = OneHotEncoder()
+ ct = ColumnTransformer(
+ transformers=[("ohe", ohe, ["col1", "col2"])], remainder=remainder
+ )
+ df = pd.DataFrame(
+ {
+ "col1": ["a", "b", "c"],
+ "col2": ["z", "z", "z"],
+ "col3": [1, 2, 3],
+ "col4": [3, 4, 5],
+ }
+ )
+ ct.fit(df)
+ visual_block = ct._sk_visual_block_()
+ assert visual_block.names == ("ohe", "remainder")
+ assert visual_block.name_details == (["col1", "col2"], ["col3", "col4"])
+ assert visual_block.estimators == (ohe, remainder)
+
+
+@pytest.mark.parametrize("remainder", ["passthrough", StandardScaler()])
+def test_sk_visual_block_remainder_fitted_numpy(remainder):
+ # Remainder shows the indices after fitting
+ X = np.array([[1, 2, 3], [4, 5, 6]], dtype=float)
+ scaler = StandardScaler()
+ ct = ColumnTransformer(
+ transformers=[("scale", scaler, [0, 2])], remainder=remainder
+ )
+ ct.fit(X)
+ visual_block = ct._sk_visual_block_()
+ assert visual_block.names == ("scale", "remainder")
+ assert visual_block.name_details == ([0, 2], [1])
+ assert visual_block.estimators == (scaler, remainder)
+
+
+@pytest.mark.parametrize("explicit_colname", ["first", "second", 0, 1])
+@pytest.mark.parametrize("remainder", [Trans(), "passthrough", "drop"])
+def test_column_transformer_reordered_column_names_remainder(
+ explicit_colname, remainder
+):
+ """Test the interaction between remainder and column transformer"""
+ pd = pytest.importorskip("modin.pandas")
+
+ X_fit_array = np.array([[0, 1, 2], [2, 4, 6]]).T
+ X_fit_df = pd.DataFrame(X_fit_array, columns=["first", "second"])
+
+ X_trans_array = np.array([[2, 4, 6], [0, 1, 2]]).T
+ X_trans_df = pd.DataFrame(X_trans_array, columns=["second", "first"])
+
+ tf = ColumnTransformer([("bycol", Trans(), explicit_colname)], remainder=remainder)
+
+ tf.fit(X_fit_df)
+ X_fit_trans = tf.transform(X_fit_df)
+
+ # Changing the order still works
+ X_trans = tf.transform(X_trans_df)
+ assert_allclose(X_trans, X_fit_trans)
+
+ # extra columns are ignored
+ X_extended_df = X_fit_df.copy()
+ X_extended_df["third"] = [3, 6, 9]
+ X_trans = tf.transform(X_extended_df)
+ assert_allclose(X_trans, X_fit_trans)
+
+ if isinstance(explicit_colname, str):
+ # Raise error if columns are specified by names but input only allows
+ # to specify by position, e.g. numpy array instead of a pandas df.
+ X_array = X_fit_array.copy()
+ err_msg = "Specifying the columns"
+ with pytest.raises(ValueError, match=err_msg):
+ tf.transform(X_array)
+
+
+def test_feature_name_validation_missing_columns_drop_passthough():
+ """Test the interaction between {'drop', 'passthrough'} and
+ missing column names."""
+ pd = pytest.importorskip("modin.pandas")
+
+ X = np.ones(shape=(3, 4))
+ df = pd.DataFrame(X, columns=["a", "b", "c", "d"])
+
+ df_dropped = df.drop("c", axis=1)
+
+ # with remainder='passthrough', all columns seen during `fit` must be
+ # present
+ tf = ColumnTransformer([("bycol", Trans(), [1])], remainder="passthrough")
+ tf.fit(df)
+ msg = r"columns are missing: {'c'}"
+ with pytest.raises(ValueError, match=msg):
+ tf.transform(df_dropped)
+
+ # with remainder='drop', it is allowed to have column 'c' missing
+ tf = ColumnTransformer([("bycol", Trans(), [1])], remainder="drop")
+ tf.fit(df)
+
+ df_dropped_trans = tf.transform(df_dropped)
+ df_fit_trans = tf.transform(df)
+ assert_allclose(df_dropped_trans, df_fit_trans)
+
+ # bycol drops 'c', thus it is allowed for 'c' to be missing
+ tf = ColumnTransformer([("bycol", "drop", ["c"])], remainder="passthrough")
+ tf.fit(df)
+ df_dropped_trans = tf.transform(df_dropped)
+ df_fit_trans = tf.transform(df)
+ assert_allclose(df_dropped_trans, df_fit_trans)
+
+
+def test_feature_names_in_():
+ """Feature names are stored in column transformer.
+
+ Column transformer deliberately does not check for column name consistency.
+ It only checks that the non-dropped names seen in `fit` are seen
+ in `transform`. This behavior is already tested in
+ `test_feature_name_validation_missing_columns_drop_passthough`"""
+
+ pd = pytest.importorskip("modin.pandas")
+
+ feature_names = ["a", "c", "d"]
+ df = pd.DataFrame([[1, 2, 3]], columns=feature_names)
+ ct = ColumnTransformer([("bycol", Trans(), ["a", "d"])], remainder="passthrough")
+
+ ct.fit(df)
+ assert_array_equal(ct.feature_names_in_, feature_names)
+ assert isinstance(ct.feature_names_in_, np.ndarray)
+ assert ct.feature_names_in_.dtype == object
+
+
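+# Helper transformer: echoes the configured feature names (or the input
+# features) from get_feature_names_out so the tests below can check how
+# ColumnTransformer assembles output names with and without verbose prefixes.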
+class TransWithNames(Trans):
+ def __init__(self, feature_names_out=None):
+ self.feature_names_out = feature_names_out
+
+ def get_feature_names_out(self, input_features=None):
+ if self.feature_names_out is not None:
+ return np.asarray(self.feature_names_out, dtype=object)
+ return input_features
+
+
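+# Expected names use the default verbose_feature_names_out=True, i.e. every
+# output name is prefixed with the transformer name ("bycol1__", "remainder__").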
+@pytest.mark.parametrize(
+ "transformers, remainder, expected_names",
+ [
+ (
+ [
+ ("bycol1", TransWithNames(), ["d", "c"]),
+ ("bycol2", "passthrough", ["d"]),
+ ],
+ "passthrough",
+ ["bycol1__d", "bycol1__c", "bycol2__d", "remainder__a", "remainder__b"],
+ ),
+ (
+ [
+ ("bycol1", TransWithNames(), ["d", "c"]),
+ ("bycol2", "passthrough", ["d"]),
+ ],
+ "drop",
+ ["bycol1__d", "bycol1__c", "bycol2__d"],
+ ),
+ (
+ [
+ ("bycol1", TransWithNames(), ["b"]),
+ ("bycol2", "drop", ["d"]),
+ ],
+ "passthrough",
+ ["bycol1__b", "remainder__a", "remainder__c"],
+ ),
+ (
+ [
+ ("bycol1", TransWithNames(["pca1", "pca2"]), ["a", "b", "d"]),
+ ],
+ "passthrough",
+ ["bycol1__pca1", "bycol1__pca2", "remainder__c"],
+ ),
+ (
+ [
+ ("bycol1", TransWithNames(["a", "b"]), ["d"]),
+ ("bycol2", "passthrough", ["b"]),
+ ],
+ "drop",
+ ["bycol1__a", "bycol1__b", "bycol2__b"],
+ ),
+ (
+ [
+ ("bycol1", TransWithNames([f"pca{i}" for i in range(2)]), ["b"]),
+ ("bycol2", TransWithNames([f"pca{i}" for i in range(2)]), ["b"]),
+ ],
+ "passthrough",
+ [
+ "bycol1__pca0",
+ "bycol1__pca1",
+ "bycol2__pca0",
+ "bycol2__pca1",
+ "remainder__a",
+ "remainder__c",
+ "remainder__d",
+ ],
+ ),
+ (
+ [
+ ("bycol1", "drop", ["d"]),
+ ],
+ "drop",
+ [],
+ ),
+ (
+ [
+ ("bycol1", TransWithNames(), slice(1, 3)),
+ ],
+ "drop",
+ ["bycol1__b", "bycol1__c"],
+ ),
+ (
+ [
+ ("bycol1", TransWithNames(), ["b"]),
+ ("bycol2", "drop", slice(3, 4)),
+ ],
+ "passthrough",
+ ["bycol1__b", "remainder__a", "remainder__c"],
+ ),
+ (
+ [
+ ("bycol1", TransWithNames(), ["d", "c"]),
+ ("bycol2", "passthrough", slice(3, 4)),
+ ],
+ "passthrough",
+ ["bycol1__d", "bycol1__c", "bycol2__d", "remainder__a", "remainder__b"],
+ ),
+ (
+ [
+ ("bycol1", TransWithNames(), slice("b", "c")),
+ ],
+ "drop",
+ ["bycol1__b", "bycol1__c"],
+ ),
+ (
+ [
+ ("bycol1", TransWithNames(), ["b"]),
+ ("bycol2", "drop", slice("c", "d")),
+ ],
+ "passthrough",
+ ["bycol1__b", "remainder__a"],
+ ),
+ (
+ [
+ ("bycol1", TransWithNames(), ["d", "c"]),
+ ("bycol2", "passthrough", slice("c", "d")),
+ ],
+ "passthrough",
+ [
+ "bycol1__d",
+ "bycol1__c",
+ "bycol2__c",
+ "bycol2__d",
+ "remainder__a",
+ "remainder__b",
+ ],
+ ),
+ ],
+)
+def test_verbose_feature_names_out_true(transformers, remainder, expected_names):
+ """Check feature_names_out for verbose_feature_names_out=True (default)"""
+ pd = pytest.importorskip("modin.pandas")
+ df = pd.DataFrame([[1, 2, 3, 4]], columns=["a", "b", "c", "d"])
+ ct = ColumnTransformer(
+ transformers,
+ remainder=remainder,
+ )
+ ct.fit(df)
+
+ names = ct.get_feature_names_out()
+ assert isinstance(names, np.ndarray)
+ assert names.dtype == object
+ assert_array_equal(names, expected_names)
+
+
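+# Similar scenarios with verbose_feature_names_out=False: names are emitted
+# without the transformer prefix and therefore must already be unique.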
+@pytest.mark.parametrize(
+ "transformers, remainder, expected_names",
+ [
+ (
+ [
+ ("bycol1", TransWithNames(), ["d", "c"]),
+ ("bycol2", "passthrough", ["a"]),
+ ],
+ "passthrough",
+ ["d", "c", "a", "b"],
+ ),
+ (
+ [
+ ("bycol1", TransWithNames(["a"]), ["d", "c"]),
+ ("bycol2", "passthrough", ["d"]),
+ ],
+ "drop",
+ ["a", "d"],
+ ),
+ (
+ [
+ ("bycol1", TransWithNames(), ["b"]),
+ ("bycol2", "drop", ["d"]),
+ ],
+ "passthrough",
+ ["b", "a", "c"],
+ ),
+ (
+ [
+ ("bycol1", TransWithNames(["pca1", "pca2"]), ["a", "b", "d"]),
+ ],
+ "passthrough",
+ ["pca1", "pca2", "c"],
+ ),
+ (
+ [
+ ("bycol1", TransWithNames(["a", "c"]), ["d"]),
+ ("bycol2", "passthrough", ["d"]),
+ ],
+ "drop",
+ ["a", "c", "d"],
+ ),
+ (
+ [
+ ("bycol1", TransWithNames([f"pca{i}" for i in range(2)]), ["b"]),
+ ("bycol2", TransWithNames([f"kpca{i}" for i in range(2)]), ["b"]),
+ ],
+ "passthrough",
+ ["pca0", "pca1", "kpca0", "kpca1", "a", "c", "d"],
+ ),
+ (
+ [
+ ("bycol1", "drop", ["d"]),
+ ],
+ "drop",
+ [],
+ ),
+ (
+ [
+ ("bycol1", TransWithNames(), slice(1, 2)),
+ ("bycol2", "drop", ["d"]),
+ ],
+ "passthrough",
+ ["b", "a", "c"],
+ ),
+ (
+ [
+ ("bycol1", TransWithNames(), ["b"]),
+ ("bycol2", "drop", slice(3, 4)),
+ ],
+ "passthrough",
+ ["b", "a", "c"],
+ ),
+ (
+ [
+ ("bycol1", TransWithNames(), ["d", "c"]),
+ ("bycol2", "passthrough", slice(0, 2)),
+ ],
+ "drop",
+ ["d", "c", "a", "b"],
+ ),
+ (
+ [
+ ("bycol1", TransWithNames(), slice("a", "b")),
+ ("bycol2", "drop", ["d"]),
+ ],
+ "passthrough",
+ ["a", "b", "c"],
+ ),
+ (
+ [
+ ("bycol1", TransWithNames(), ["b"]),
+ ("bycol2", "drop", slice("c", "d")),
+ ],
+ "passthrough",
+ ["b", "a"],
+ ),
+ (
+ [
+ ("bycol1", TransWithNames(), ["d", "c"]),
+ ("bycol2", "passthrough", slice("a", "b")),
+ ],
+ "drop",
+ ["d", "c", "a", "b"],
+ ),
+ (
+ [
+ ("bycol1", TransWithNames(), ["d", "c"]),
+ ("bycol2", "passthrough", slice("b", "b")),
+ ],
+ "drop",
+ ["d", "c", "b"],
+ ),
+ ],
+)
+def test_verbose_feature_names_out_false(transformers, remainder, expected_names):
+ """Check feature_names_out for verbose_feature_names_out=False"""
+ pd = pytest.importorskip("modin.pandas")
+ df = pd.DataFrame([[1, 2, 3, 4]], columns=["a", "b", "c", "d"])
+ ct = ColumnTransformer(
+ transformers,
+ remainder=remainder,
+ verbose_feature_names_out=False,
+ )
+ ct.fit(df)
+
+ names = ct.get_feature_names_out()
+ assert isinstance(names, np.ndarray)
+ assert names.dtype == object
+ assert_array_equal(names, expected_names)
+
+
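+# With verbose_feature_names_out=False, colliding output names raise an error;
+# `colliding_columns` is the (possibly truncated) list quoted in the message.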
+@pytest.mark.parametrize(
+ "transformers, remainder, colliding_columns",
+ [
+ (
+ [
+ ("bycol1", TransWithNames(), ["b"]),
+ ("bycol2", "passthrough", ["b"]),
+ ],
+ "drop",
+ "['b']",
+ ),
+ (
+ [
+ ("bycol1", TransWithNames(["c", "d"]), ["c"]),
+ ("bycol2", "passthrough", ["c"]),
+ ],
+ "drop",
+ "['c']",
+ ),
+ (
+ [
+ ("bycol1", TransWithNames(["a"]), ["b"]),
+ ("bycol2", "passthrough", ["b"]),
+ ],
+ "passthrough",
+ "['a']",
+ ),
+ (
+ [
+ ("bycol1", TransWithNames(["a"]), ["b"]),
+ ("bycol2", "drop", ["b"]),
+ ],
+ "passthrough",
+ "['a']",
+ ),
+ (
+ [
+ ("bycol1", TransWithNames(["c", "b"]), ["b"]),
+ ("bycol2", "passthrough", ["c", "b"]),
+ ],
+ "drop",
+ "['b', 'c']",
+ ),
+ (
+ [
+ ("bycol1", TransWithNames(["a"]), ["b"]),
+ ("bycol2", "passthrough", ["a"]),
+ ("bycol3", TransWithNames(["a"]), ["b"]),
+ ],
+ "passthrough",
+ "['a']",
+ ),
+ (
+ [
+ ("bycol1", TransWithNames(["a", "b"]), ["b"]),
+ ("bycol2", "passthrough", ["a"]),
+ ("bycol3", TransWithNames(["b"]), ["c"]),
+ ],
+ "passthrough",
+ "['a', 'b']",
+ ),
+ (
+ [
+ ("bycol1", TransWithNames([f"pca{i}" for i in range(6)]), ["b"]),
+ ("bycol2", TransWithNames([f"pca{i}" for i in range(6)]), ["b"]),
+ ],
+ "passthrough",
+ "['pca0', 'pca1', 'pca2', 'pca3', 'pca4', ...]",
+ ),
+ (
+ [
+ ("bycol1", TransWithNames(["a", "b"]), slice(1, 2)),
+ ("bycol2", "passthrough", ["a"]),
+ ("bycol3", TransWithNames(["b"]), ["c"]),
+ ],
+ "passthrough",
+ "['a', 'b']",
+ ),
+ (
+ [
+ ("bycol1", TransWithNames(["a", "b"]), ["b"]),
+ ("bycol2", "passthrough", slice(0, 1)),
+ ("bycol3", TransWithNames(["b"]), ["c"]),
+ ],
+ "passthrough",
+ "['a', 'b']",
+ ),
+ (
+ [
+ ("bycol1", TransWithNames(["a", "b"]), slice("b", "c")),
+ ("bycol2", "passthrough", ["a"]),
+ ("bycol3", TransWithNames(["b"]), ["c"]),
+ ],
+ "passthrough",
+ "['a', 'b']",
+ ),
+ (
+ [
+ ("bycol1", TransWithNames(["a", "b"]), ["b"]),
+ ("bycol2", "passthrough", slice("a", "a")),
+ ("bycol3", TransWithNames(["b"]), ["c"]),
+ ],
+ "passthrough",
+ "['a', 'b']",
+ ),
+ ],
+)
+def test_verbose_feature_names_out_false_errors(
+ transformers, remainder, colliding_columns
+):
+ """Check feature_names_out for verbose_feature_names_out=False"""
+
+ pd = pytest.importorskip("modin.pandas")
+ df = pd.DataFrame([[1, 2, 3, 4]], columns=["a", "b", "c", "d"])
+ ct = ColumnTransformer(
+ transformers,
+ remainder=remainder,
+ verbose_feature_names_out=False,
+ )
+ ct.fit(df)
+
+ msg = re.escape(
+ f"Output feature names: {colliding_columns} are not unique. Please set "
+ "verbose_feature_names_out=True to add prefixes to feature names"
+ )
+ with pytest.raises(ValueError, match=msg):
+ ct.get_feature_names_out()
+
+
+@pytest.mark.skip(reason="Failing test")
+@pytest.mark.parametrize("verbose_feature_names_out", [True, False])
+@pytest.mark.parametrize("remainder", ["drop", "passthrough"])
+def test_column_transformer_set_output(verbose_feature_names_out, remainder):
+ """Check column transformer behavior with set_output."""
+ pd = pytest.importorskip("modin.pandas")
+ df = pd.DataFrame([[1, 2, 3, 4]], columns=["a", "b", "c", "d"], index=[10])
+ ct = ColumnTransformer(
+ [("first", TransWithNames(), ["a", "c"]), ("second", TransWithNames(), ["d"])],
+ remainder=remainder,
+ verbose_feature_names_out=verbose_feature_names_out,
+ )
+ X_trans = ct.fit_transform(df)
+ assert isinstance(X_trans, np.ndarray)
+
+ ct.set_output(transform="modin.pandas")
+
+ df_test = pd.DataFrame([[1, 2, 3, 4]], columns=df.columns, index=[20])
+ X_trans = ct.transform(df_test)
+ assert isinstance(X_trans, pd.DataFrame)
+
+ feature_names_out = ct.get_feature_names_out()
+ assert_array_equal(X_trans.columns, feature_names_out)
+ assert_array_equal(X_trans.index, df_test.index)
+
+
+@pytest.mark.skip(reason="Failing test")
+@pytest.mark.parametrize("remainder", ["drop", "passthrough"])
+@pytest.mark.parametrize("fit_transform", [True, False])
+def test_column_transform_set_output_mixed(remainder, fit_transform):
+ """Check ColumnTransformer outputs mixed types correctly."""
+ pd = pytest.importorskip("modin.pandas")
+ df = pd.DataFrame(
+ {
+ "pet": pd.Series(["dog", "cat", "snake"], dtype="category"),
+ "color": pd.Series(["green", "blue", "red"], dtype="object"),
+ "age": [1.4, 2.1, 4.4],
+ "height": [20, 40, 10],
+ "distance": pd.Series([20, pd.NA, 100], dtype="Int32"),
+ }
+ )
+ ct = ColumnTransformer(
+ [
+ (
+ "color_encode",
+ OneHotEncoder(sparse_output=False, dtype="int8"),
+ ["color"],
+ ),
+ ("age", StandardScaler(), ["age"]),
+ ],
+ remainder=remainder,
+ verbose_feature_names_out=False,
+ ).set_output(transform="pandas")
+ if fit_transform:
+ X_trans = ct.fit_transform(df)
+ else:
+ X_trans = ct.fit(df).transform(df)
+
+ assert isinstance(X_trans, pd.DataFrame)
+ assert_array_equal(X_trans.columns, ct.get_feature_names_out())
+
+ expected_dtypes = {
+ "color_blue": "int8",
+ "color_green": "int8",
+ "color_red": "int8",
+ "age": "float64",
+ "pet": "category",
+ "height": "int64",
+ "distance": "Int32",
+ }
+ for col, dtype in X_trans.dtypes.items():
+ assert dtype == expected_dtypes[col]
+
+
+@pytest.mark.skip(reason="Failing test")
+@pytest.mark.parametrize("remainder", ["drop", "passthrough"])
+def test_column_transform_set_output_after_fitting(remainder):
+ pd = pytest.importorskip("modin.pandas")
+ df = pd.DataFrame(
+ {
+ "pet": pd.Series(["dog", "cat", "snake"], dtype="category"),
+ "age": [1.4, 2.1, 4.4],
+ "height": [20, 40, 10],
+ }
+ )
+ ct = ColumnTransformer(
+ [
+ (
+ "color_encode",
+ OneHotEncoder(sparse_output=False, dtype="int16"),
+ ["pet"],
+ ),
+ ("age", StandardScaler(), ["age"]),
+ ],
+ remainder=remainder,
+ verbose_feature_names_out=False,
+ )
+
+ # fit without calling set_output
+ X_trans = ct.fit_transform(df)
+ assert isinstance(X_trans, np.ndarray)
+ assert X_trans.dtype == "float64"
+
+ ct.set_output(transform="modin.pandas")
+ X_trans_df = ct.transform(df)
+ expected_dtypes = {
+ "pet_cat": "int16",
+ "pet_dog": "int16",
+ "pet_snake": "int16",
+ "height": "int64",
+ "age": "float64",
+ }
+ for col, dtype in X_trans_df.dtypes.items():
+ assert dtype == expected_dtypes[col]
+
+
+# PandasOutTransformer does not define get_feature_names_out and always expects
+# the input to be a DataFrame.
+class PandasOutTransformer(BaseEstimator):
+ def __init__(self, offset=1.0):
+ self.offset = offset
+
+ def fit(self, X, y=None):
+ pd = pytest.importorskip("modin.pandas")
+ assert isinstance(X, pd.DataFrame)
+ return self
+
+ def transform(self, X, y=None):
+ pd = pytest.importorskip("modin.pandas")
+ assert isinstance(X, pd.DataFrame)
+ return X - self.offset
+
+ def set_output(self, transform=None):
+ # This transformer will always output a DataFrame regardless of the
+ # configuration.
+ return self
+
+
+@pytest.mark.skip(reason="Failing test")
+@pytest.mark.parametrize(
+ "trans_1, expected_verbose_names, expected_non_verbose_names",
+ [
+ (
+ PandasOutTransformer(offset=2.0),
+ ["trans_0__feat1", "trans_1__feat0"],
+ ["feat1", "feat0"],
+ ),
+ (
+ "drop",
+ ["trans_0__feat1"],
+ ["feat1"],
+ ),
+ (
+ "passthrough",
+ ["trans_0__feat1", "trans_1__feat0"],
+ ["feat1", "feat0"],
+ ),
+ ],
+)
+def test_transformers_with_pandas_out_but_not_feature_names_out(
+ trans_1, expected_verbose_names, expected_non_verbose_names
+):
+ """Check that set_config(transform="pandas") is compatible with more transformers.
+
+ Specifically, transformers that return a DataFrame but do not define
+ `get_feature_names_out`.
+ """
+ pd = pytest.importorskip("modin.pandas")
+
+ X_df = pd.DataFrame({"feat0": [1.0, 2.0, 3.0], "feat1": [2.0, 3.0, 4.0]})
+ ct = ColumnTransformer(
+ [
+ ("trans_0", PandasOutTransformer(offset=3.0), ["feat1"]),
+ ("trans_1", trans_1, ["feat0"]),
+ ]
+ )
+ X_trans_np = ct.fit_transform(X_df)
+ assert isinstance(X_trans_np, np.ndarray)
+
+ # `ct` does not have `get_feature_names_out` because `PandasOutTransformer` does
+ # not define the method.
+ with pytest.raises(AttributeError, match="not provide get_feature_names_out"):
+ ct.get_feature_names_out()
+
+ # The feature names are prefixed because verbose_feature_names_out=True is default
+ ct.set_output(transform="modin.pandas")
+ X_trans_df0 = ct.fit_transform(X_df)
+ assert_array_equal(X_trans_df0.columns, expected_verbose_names)
+
+ ct.set_params(verbose_feature_names_out=False)
+ X_trans_df1 = ct.fit_transform(X_df)
+ assert_array_equal(X_trans_df1.columns, expected_non_verbose_names)
diff --git a/modin/pandas/test/interoperability/sklearn/conftest.py b/modin/pandas/test/interoperability/sklearn/conftest.py
new file mode 100644
index 00000000000..42f036e01ef
--- /dev/null
+++ b/modin/pandas/test/interoperability/sklearn/conftest.py
@@ -0,0 +1,252 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+# flake8: noqa
+
+from os import environ
+from functools import wraps
+import platform
+import sys
+
+import pytest
+import numpy as np
+from threadpoolctl import threadpool_limits
+from _pytest.doctest import DoctestItem
+
+from sklearn.utils import _IS_32BIT
+from sklearn.utils._openmp_helpers import _openmp_effective_n_threads
+from sklearn._min_dependencies import PYTEST_MIN_VERSION
+from sklearn.utils.fixes import parse_version
+from sklearn.datasets import fetch_20newsgroups
+from sklearn.datasets import fetch_20newsgroups_vectorized
+from sklearn.datasets import fetch_california_housing
+from sklearn.datasets import fetch_covtype
+from sklearn.datasets import fetch_kddcup99
+from sklearn.datasets import fetch_olivetti_faces
+from sklearn.datasets import fetch_rcv1
+from sklearn.tests import random_seed
+
+
+if parse_version(pytest.__version__) < parse_version(PYTEST_MIN_VERSION):
+ raise ImportError(
+ "Your version of pytest is too old, you should have "
+ "at least pytest >= {} installed.".format(PYTEST_MIN_VERSION)
+ )
+
+dataset_fetchers = {
+ "fetch_20newsgroups_fxt": fetch_20newsgroups,
+ "fetch_20newsgroups_vectorized_fxt": fetch_20newsgroups_vectorized,
+ "fetch_california_housing_fxt": fetch_california_housing,
+ "fetch_covtype_fxt": fetch_covtype,
+ "fetch_kddcup99_fxt": fetch_kddcup99,
+ "fetch_olivetti_faces_fxt": fetch_olivetti_faces,
+ "fetch_rcv1_fxt": fetch_rcv1,
+}
+
+_SKIP32_MARK = pytest.mark.skipif(
+ environ.get("SKLEARN_RUN_FLOAT32_TESTS", "0") != "1",
+ reason="Set SKLEARN_RUN_FLOAT32_TESTS=1 to run float32 dtype tests",
+)
+
+
+# Global fixtures
+@pytest.fixture(params=[pytest.param(np.float32, marks=_SKIP32_MARK), np.float64])
+def global_dtype(request):
+ yield request.param
+
+
+def _fetch_fixture(f):
+ """Fetch dataset (download if missing and requested by environment)."""
+ download_if_missing = environ.get("SKLEARN_SKIP_NETWORK_TESTS", "1") == "0"
+
+ @wraps(f)
+ def wrapped(*args, **kwargs):
+ kwargs["download_if_missing"] = download_if_missing
+ try:
+ return f(*args, **kwargs)
+ except IOError as e:
+ if str(e) != "Data not found and `download_if_missing` is False":
+ raise
+ pytest.skip("test is enabled when SKLEARN_SKIP_NETWORK_TESTS=0")
+
+ return pytest.fixture(lambda: wrapped)
+
+
+# Adds fixtures for fetching data
+fetch_20newsgroups_fxt = _fetch_fixture(fetch_20newsgroups)
+fetch_20newsgroups_vectorized_fxt = _fetch_fixture(fetch_20newsgroups_vectorized)
+fetch_california_housing_fxt = _fetch_fixture(fetch_california_housing)
+fetch_covtype_fxt = _fetch_fixture(fetch_covtype)
+fetch_kddcup99_fxt = _fetch_fixture(fetch_kddcup99)
+fetch_olivetti_faces_fxt = _fetch_fixture(fetch_olivetti_faces)
+fetch_rcv1_fxt = _fetch_fixture(fetch_rcv1)
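+# Illustrative example (not part of the test suite itself): a test requests one
+# of the fixtures above by name and calls it to obtain the dataset, e.g.
+#
+#     def test_example(fetch_california_housing_fxt):
+#         bunch = fetch_california_housing_fxt()
+#         assert bunch.data.shape[1] == 8
+#
+# Such tests are skipped unless SKLEARN_SKIP_NETWORK_TESTS=0 (see below).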
+
+
+def pytest_collection_modifyitems(config, items):
+ """Called after collect is completed.
+
+ Parameters
+ ----------
+ config : pytest config
+ items : list of collected items
+ """
+ run_network_tests = environ.get("SKLEARN_SKIP_NETWORK_TESTS", "1") == "0"
+ skip_network = pytest.mark.skip(
+ reason="test is enabled when SKLEARN_SKIP_NETWORK_TESTS=0"
+ )
+
+ # download datasets during collection to avoid thread unsafe behavior
+ # when running pytest in parallel with pytest-xdist
+ dataset_features_set = set(dataset_fetchers)
+ datasets_to_download = set()
+
+ for item in items:
+ if not hasattr(item, "fixturenames"):
+ continue
+ item_fixtures = set(item.fixturenames)
+ dataset_to_fetch = item_fixtures & dataset_features_set
+ if not dataset_to_fetch:
+ continue
+
+ if run_network_tests:
+ datasets_to_download |= dataset_to_fetch
+ else:
+ # network tests are skipped
+ item.add_marker(skip_network)
+
+ # Only download datasets on the first worker spawned by pytest-xdist
+ # to avoid thread unsafe behavior. If pytest-xdist is not used, we still
+ # download before tests run.
+ worker_id = environ.get("PYTEST_XDIST_WORKER", "gw0")
+ if worker_id == "gw0" and run_network_tests:
+ for name in datasets_to_download:
+ dataset_fetchers[name]()
+
+ for item in items:
+ # Known failure with GradientBoostingClassifier on ARM64
+ if (
+ item.name.endswith("GradientBoostingClassifier")
+ and platform.machine() == "aarch64"
+ ):
+ marker = pytest.mark.xfail(
+ reason=(
+ "know failure. See "
+ "https://github.com/scikit-learn/scikit-learn/issues/17797" # noqa
+ )
+ )
+ item.add_marker(marker)
+
+ skip_doctests = False
+ try:
+ import matplotlib # noqa
+ except ImportError:
+ skip_doctests = True
+ reason = "matplotlib is required to run the doctests"
+
+ if _IS_32BIT:
+ reason = "doctest are only run when the default numpy int is 64 bits."
+ skip_doctests = True
+ elif sys.platform.startswith("win32"):
+ reason = (
+ "doctests are not run for Windows because numpy arrays "
+ "repr is inconsistent across platforms."
+ )
+ skip_doctests = True
+
+ # Normally doctest has the entire module's scope. Here we set globs to an empty dict
+ # to remove the module's scope:
+ # https://docs.python.org/3/library/doctest.html#what-s-the-execution-context
+ for item in items:
+ if isinstance(item, DoctestItem):
+ item.dtest.globs = {}
+
+ if skip_doctests:
+ skip_marker = pytest.mark.skip(reason=reason)
+
+ for item in items:
+ if isinstance(item, DoctestItem):
+ # work-around an internal error with pytest if adding a skip
+ # mark to a doctest in a contextmanager, see
+ # https://github.com/pytest-dev/pytest/issues/8796 for more
+ # details.
+ if item.name != "sklearn._config.config_context":
+ item.add_marker(skip_marker)
+ try:
+ import PIL # noqa
+
+ pillow_installed = True
+ except ImportError:
+ pillow_installed = False
+
+ if not pillow_installed:
+ skip_marker = pytest.mark.skip(reason="pillow (or PIL) not installed!")
+ for item in items:
+ if item.name in [
+ "sklearn.feature_extraction.image.PatchExtractor",
+ "sklearn.feature_extraction.image.extract_patches_2d",
+ ]:
+ item.add_marker(skip_marker)
+
+
+@pytest.fixture(scope="function")
+def pyplot():
+ """Setup and teardown fixture for matplotlib.
+
+ This fixture checks if we can import matplotlib. If not, the tests will be
+ skipped. Otherwise, we close the figures before and after running the
+ functions.
+
+ Returns
+ -------
+ pyplot : module
+ The ``matplotlib.pyplot`` module.
+ """
+ pyplot = pytest.importorskip("matplotlib.pyplot")
+ pyplot.close("all")
+ yield pyplot
+ pyplot.close("all")
+
+
+def pytest_runtest_setup(item):
+ """Set the number of openmp threads based on the number of workers
+ xdist is using to prevent oversubscription.
+
+ Parameters
+ ----------
+ item : pytest item
+ item to be processed
+ """
+ xdist_worker_count = environ.get("PYTEST_XDIST_WORKER_COUNT")
+ if xdist_worker_count is None:
+ # returns if pytest-xdist is not installed
+ return
+ else:
+ xdist_worker_count = int(xdist_worker_count)
+
+ openmp_threads = _openmp_effective_n_threads()
+ threads_per_worker = max(openmp_threads // xdist_worker_count, 1)
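+ # e.g. with 8 effective OpenMP threads and 4 xdist workers, each worker is
+ # capped at 8 // 4 = 2 threads to avoid oversubscription.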
+ threadpool_limits(threads_per_worker, user_api="openmp")
+
+
+def pytest_configure(config):
+ # Use matplotlib agg backend during the tests including doctests
+ try:
+ import matplotlib
+
+ matplotlib.use("agg")
+ except ImportError:
+ pass
+
+ # Register global_random_seed plugin if it is not already registered
+ if not config.pluginmanager.hasplugin("sklearn.tests.random_seed"):
+ config.pluginmanager.register(random_seed)
diff --git a/modin/pandas/test/interoperability/sklearn/cross_decomposition/test_pls.py b/modin/pandas/test/interoperability/sklearn/cross_decomposition/test_pls.py
new file mode 100644
index 00000000000..50e2be83ace
--- /dev/null
+++ b/modin/pandas/test/interoperability/sklearn/cross_decomposition/test_pls.py
@@ -0,0 +1,650 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+# flake8: noqa
+
+import pytest
+import warnings
+import numpy as np
+from numpy.testing import assert_array_almost_equal, assert_array_equal, assert_allclose
+from sklearn.datasets import load_linnerud
+from sklearn.cross_decomposition._pls import (
+ _center_scale_xy,
+ _get_first_singular_vectors_power_method,
+ _get_first_singular_vectors_svd,
+ _svd_flip_1d,
+)
+from sklearn.cross_decomposition import CCA
+from sklearn.cross_decomposition import PLSSVD, PLSRegression, PLSCanonical
+from sklearn.datasets import make_regression
+from sklearn.utils import check_random_state
+from sklearn.utils.extmath import svd_flip
+from sklearn.exceptions import ConvergenceWarning
+
+
+def assert_matrix_orthogonal(M):
+ K = np.dot(M.T, M)
+ assert_array_almost_equal(K, np.diag(np.diag(K)))
+
+
+def test_pls_canonical_basics():
+ # Basic checks for PLSCanonical
+ d = load_linnerud()
+ X = d.data
+ Y = d.target
+
+ pls = PLSCanonical(n_components=X.shape[1])
+ pls.fit(X, Y)
+
+ assert_matrix_orthogonal(pls.x_weights_)
+ assert_matrix_orthogonal(pls.y_weights_)
+ assert_matrix_orthogonal(pls._x_scores)
+ assert_matrix_orthogonal(pls._y_scores)
+
+ # Check X = TP' and Y = UQ'
+ T = pls._x_scores
+ P = pls.x_loadings_
+ U = pls._y_scores
+ Q = pls.y_loadings_
+ # Need to scale first
+ Xc, Yc, x_mean, y_mean, x_std, y_std = _center_scale_xy(
+ X.copy(), Y.copy(), scale=True
+ )
+ assert_array_almost_equal(Xc, np.dot(T, P.T))
+ assert_array_almost_equal(Yc, np.dot(U, Q.T))
+
+ # Check that rotations on training data lead to scores
+ Xt = pls.transform(X)
+ assert_array_almost_equal(Xt, pls._x_scores)
+ Xt, Yt = pls.transform(X, Y)
+ assert_array_almost_equal(Xt, pls._x_scores)
+ assert_array_almost_equal(Yt, pls._y_scores)
+
+ # Check that inverse_transform works
+ X_back = pls.inverse_transform(Xt)
+ assert_array_almost_equal(X_back, X)
+ _, Y_back = pls.inverse_transform(Xt, Yt)
+ assert_array_almost_equal(Y_back, Y)
+
+
+def test_sanity_check_pls_regression():
+ # Sanity check for PLSRegression
+ # The results were checked against the R-packages plspm, mixOmics and pls
+
+ d = load_linnerud()
+ X = d.data
+ Y = d.target
+
+ pls = PLSRegression(n_components=X.shape[1])
+ X_trans, _ = pls.fit_transform(X, Y)
+
+ # FIXME: one would expect y_trans == pls.y_scores_ but this is not
+ # the case.
+ # xref: https://github.com/scikit-learn/scikit-learn/issues/22420
+ assert_allclose(X_trans, pls.x_scores_)
+
+ expected_x_weights = np.array(
+ [
+ [-0.61330704, -0.00443647, 0.78983213],
+ [-0.74697144, -0.32172099, -0.58183269],
+ [-0.25668686, 0.94682413, -0.19399983],
+ ]
+ )
+
+ expected_x_loadings = np.array(
+ [
+ [-0.61470416, -0.24574278, 0.78983213],
+ [-0.65625755, -0.14396183, -0.58183269],
+ [-0.51733059, 1.00609417, -0.19399983],
+ ]
+ )
+
+ expected_y_weights = np.array(
+ [
+ [+0.32456184, 0.29892183, 0.20316322],
+ [+0.42439636, 0.61970543, 0.19320542],
+ [-0.13143144, -0.26348971, -0.17092916],
+ ]
+ )
+
+ expected_y_loadings = np.array(
+ [
+ [+0.32456184, 0.29892183, 0.20316322],
+ [+0.42439636, 0.61970543, 0.19320542],
+ [-0.13143144, -0.26348971, -0.17092916],
+ ]
+ )
+
+ assert_array_almost_equal(np.abs(pls.x_loadings_), np.abs(expected_x_loadings))
+ assert_array_almost_equal(np.abs(pls.x_weights_), np.abs(expected_x_weights))
+ assert_array_almost_equal(np.abs(pls.y_loadings_), np.abs(expected_y_loadings))
+ assert_array_almost_equal(np.abs(pls.y_weights_), np.abs(expected_y_weights))
+
+ # The R / Python difference in the signs should be consistent across
+ # loadings, weights, etc.
+ x_loadings_sign_flip = np.sign(pls.x_loadings_ / expected_x_loadings)
+ x_weights_sign_flip = np.sign(pls.x_weights_ / expected_x_weights)
+ y_weights_sign_flip = np.sign(pls.y_weights_ / expected_y_weights)
+ y_loadings_sign_flip = np.sign(pls.y_loadings_ / expected_y_loadings)
+ assert_array_almost_equal(x_loadings_sign_flip, x_weights_sign_flip)
+ assert_array_almost_equal(y_loadings_sign_flip, y_weights_sign_flip)
+
+
+def test_sanity_check_pls_regression_constant_column_Y():
+ # Check behavior when the first column of Y is constant
+ # The results are checked against a modified version of plsreg2
+ # from the R-package plsdepot
+ d = load_linnerud()
+ X = d.data
+ Y = d.target
+ Y[:, 0] = 1
+ pls = PLSRegression(n_components=X.shape[1])
+ pls.fit(X, Y)
+
+ expected_x_weights = np.array(
+ [
+ [-0.6273573, 0.007081799, 0.7786994],
+ [-0.7493417, -0.277612681, -0.6011807],
+ [-0.2119194, 0.960666981, -0.1794690],
+ ]
+ )
+
+ expected_x_loadings = np.array(
+ [
+ [-0.6273512, -0.22464538, 0.7786994],
+ [-0.6643156, -0.09871193, -0.6011807],
+ [-0.5125877, 1.01407380, -0.1794690],
+ ]
+ )
+
+ expected_y_loadings = np.array(
+ [
+ [0.0000000, 0.0000000, 0.0000000],
+ [0.4357300, 0.5828479, 0.2174802],
+ [-0.1353739, -0.2486423, -0.1810386],
+ ]
+ )
+
+ assert_array_almost_equal(np.abs(expected_x_weights), np.abs(pls.x_weights_))
+ assert_array_almost_equal(np.abs(expected_x_loadings), np.abs(pls.x_loadings_))
+ # For the PLSRegression with default parameters, y_loadings == y_weights
+ assert_array_almost_equal(np.abs(pls.y_loadings_), np.abs(expected_y_loadings))
+ assert_array_almost_equal(np.abs(pls.y_weights_), np.abs(expected_y_loadings))
+
+ x_loadings_sign_flip = np.sign(expected_x_loadings / pls.x_loadings_)
+ x_weights_sign_flip = np.sign(expected_x_weights / pls.x_weights_)
+ # we ignore the first full-zeros row for y
+ y_loadings_sign_flip = np.sign(expected_y_loadings[1:] / pls.y_loadings_[1:])
+
+ assert_array_equal(x_loadings_sign_flip, x_weights_sign_flip)
+ assert_array_equal(x_loadings_sign_flip[1:], y_loadings_sign_flip)
+
+
+def test_sanity_check_pls_canonical():
+ # Sanity check for PLSCanonical
+ # The results were checked against the R-package plspm
+
+ d = load_linnerud()
+ X = d.data
+ Y = d.target
+
+ pls = PLSCanonical(n_components=X.shape[1])
+ pls.fit(X, Y)
+
+ expected_x_weights = np.array(
+ [
+ [-0.61330704, 0.25616119, -0.74715187],
+ [-0.74697144, 0.11930791, 0.65406368],
+ [-0.25668686, -0.95924297, -0.11817271],
+ ]
+ )
+
+ expected_x_rotations = np.array(
+ [
+ [-0.61330704, 0.41591889, -0.62297525],
+ [-0.74697144, 0.31388326, 0.77368233],
+ [-0.25668686, -0.89237972, -0.24121788],
+ ]
+ )
+
+ expected_y_weights = np.array(
+ [
+ [+0.58989127, 0.7890047, 0.1717553],
+ [+0.77134053, -0.61351791, 0.16920272],
+ [-0.23887670, -0.03267062, 0.97050016],
+ ]
+ )
+
+ expected_y_rotations = np.array(
+ [
+ [+0.58989127, 0.7168115, 0.30665872],
+ [+0.77134053, -0.70791757, 0.19786539],
+ [-0.23887670, -0.00343595, 0.94162826],
+ ]
+ )
+
+ assert_array_almost_equal(np.abs(pls.x_rotations_), np.abs(expected_x_rotations))
+ assert_array_almost_equal(np.abs(pls.x_weights_), np.abs(expected_x_weights))
+ assert_array_almost_equal(np.abs(pls.y_rotations_), np.abs(expected_y_rotations))
+ assert_array_almost_equal(np.abs(pls.y_weights_), np.abs(expected_y_weights))
+
+ x_rotations_sign_flip = np.sign(pls.x_rotations_ / expected_x_rotations)
+ x_weights_sign_flip = np.sign(pls.x_weights_ / expected_x_weights)
+ y_rotations_sign_flip = np.sign(pls.y_rotations_ / expected_y_rotations)
+ y_weights_sign_flip = np.sign(pls.y_weights_ / expected_y_weights)
+ assert_array_almost_equal(x_rotations_sign_flip, x_weights_sign_flip)
+ assert_array_almost_equal(y_rotations_sign_flip, y_weights_sign_flip)
+
+ assert_matrix_orthogonal(pls.x_weights_)
+ assert_matrix_orthogonal(pls.y_weights_)
+
+ assert_matrix_orthogonal(pls._x_scores)
+ assert_matrix_orthogonal(pls._y_scores)
+
+
+def test_sanity_check_pls_canonical_random():
+ # Sanity check for PLSCanonical on random data
+ # The results were checked against the R-package plspm
+ n = 500
+ p_noise = 10
+ q_noise = 5
+ # 2 latent variables:
+ rng = check_random_state(11)
+ l1 = rng.normal(size=n)
+ l2 = rng.normal(size=n)
+ latents = np.array([l1, l1, l2, l2]).T
+ X = latents + rng.normal(size=4 * n).reshape((n, 4))
+ Y = latents + rng.normal(size=4 * n).reshape((n, 4))
+ X = np.concatenate((X, rng.normal(size=p_noise * n).reshape(n, p_noise)), axis=1)
+ Y = np.concatenate((Y, rng.normal(size=q_noise * n).reshape(n, q_noise)), axis=1)
+
+ pls = PLSCanonical(n_components=3)
+ pls.fit(X, Y)
+
+ expected_x_weights = np.array(
+ [
+ [0.65803719, 0.19197924, 0.21769083],
+ [0.7009113, 0.13303969, -0.15376699],
+ [0.13528197, -0.68636408, 0.13856546],
+ [0.16854574, -0.66788088, -0.12485304],
+ [-0.03232333, -0.04189855, 0.40690153],
+ [0.1148816, -0.09643158, 0.1613305],
+ [0.04792138, -0.02384992, 0.17175319],
+ [-0.06781, -0.01666137, -0.18556747],
+ [-0.00266945, -0.00160224, 0.11893098],
+ [-0.00849528, -0.07706095, 0.1570547],
+ [-0.00949471, -0.02964127, 0.34657036],
+ [-0.03572177, 0.0945091, 0.3414855],
+ [0.05584937, -0.02028961, -0.57682568],
+ [0.05744254, -0.01482333, -0.17431274],
+ ]
+ )
+
+ expected_x_loadings = np.array(
+ [
+ [0.65649254, 0.1847647, 0.15270699],
+ [0.67554234, 0.15237508, -0.09182247],
+ [0.19219925, -0.67750975, 0.08673128],
+ [0.2133631, -0.67034809, -0.08835483],
+ [-0.03178912, -0.06668336, 0.43395268],
+ [0.15684588, -0.13350241, 0.20578984],
+ [0.03337736, -0.03807306, 0.09871553],
+ [-0.06199844, 0.01559854, -0.1881785],
+ [0.00406146, -0.00587025, 0.16413253],
+ [-0.00374239, -0.05848466, 0.19140336],
+ [0.00139214, -0.01033161, 0.32239136],
+ [-0.05292828, 0.0953533, 0.31916881],
+ [0.04031924, -0.01961045, -0.65174036],
+ [0.06172484, -0.06597366, -0.1244497],
+ ]
+ )
+
+ expected_y_weights = np.array(
+ [
+ [0.66101097, 0.18672553, 0.22826092],
+ [0.69347861, 0.18463471, -0.23995597],
+ [0.14462724, -0.66504085, 0.17082434],
+ [0.22247955, -0.6932605, -0.09832993],
+ [0.07035859, 0.00714283, 0.67810124],
+ [0.07765351, -0.0105204, -0.44108074],
+ [-0.00917056, 0.04322147, 0.10062478],
+ [-0.01909512, 0.06182718, 0.28830475],
+ [0.01756709, 0.04797666, 0.32225745],
+ ]
+ )
+
+ expected_y_loadings = np.array(
+ [
+ [0.68568625, 0.1674376, 0.0969508],
+ [0.68782064, 0.20375837, -0.1164448],
+ [0.11712173, -0.68046903, 0.12001505],
+ [0.17860457, -0.6798319, -0.05089681],
+ [0.06265739, -0.0277703, 0.74729584],
+ [0.0914178, 0.00403751, -0.5135078],
+ [-0.02196918, -0.01377169, 0.09564505],
+ [-0.03288952, 0.09039729, 0.31858973],
+ [0.04287624, 0.05254676, 0.27836841],
+ ]
+ )
+
+ assert_array_almost_equal(np.abs(pls.x_loadings_), np.abs(expected_x_loadings))
+ assert_array_almost_equal(np.abs(pls.x_weights_), np.abs(expected_x_weights))
+ assert_array_almost_equal(np.abs(pls.y_loadings_), np.abs(expected_y_loadings))
+ assert_array_almost_equal(np.abs(pls.y_weights_), np.abs(expected_y_weights))
+
+ x_loadings_sign_flip = np.sign(pls.x_loadings_ / expected_x_loadings)
+ x_weights_sign_flip = np.sign(pls.x_weights_ / expected_x_weights)
+ y_weights_sign_flip = np.sign(pls.y_weights_ / expected_y_weights)
+ y_loadings_sign_flip = np.sign(pls.y_loadings_ / expected_y_loadings)
+ assert_array_almost_equal(x_loadings_sign_flip, x_weights_sign_flip)
+ assert_array_almost_equal(y_loadings_sign_flip, y_weights_sign_flip)
+
+ assert_matrix_orthogonal(pls.x_weights_)
+ assert_matrix_orthogonal(pls.y_weights_)
+
+ assert_matrix_orthogonal(pls._x_scores)
+ assert_matrix_orthogonal(pls._y_scores)
+
+
+def test_convergence_fail():
+ # Make sure ConvergenceWarning is raised if max_iter is too small
+ d = load_linnerud()
+ X = d.data
+ Y = d.target
+ pls_nipals = PLSCanonical(n_components=X.shape[1], max_iter=2)
+ with pytest.warns(ConvergenceWarning):
+ pls_nipals.fit(X, Y)
+
+
+@pytest.mark.parametrize("Est", (PLSSVD, PLSRegression, PLSCanonical))
+def test_attibutes_shapes(Est):
+ # Make sure attributes are of the correct shape depending on n_components
+ d = load_linnerud()
+ X = d.data
+ Y = d.target
+ n_components = 2
+ pls = Est(n_components=n_components)
+ pls.fit(X, Y)
+ assert all(
+ attr.shape[1] == n_components for attr in (pls.x_weights_, pls.y_weights_)
+ )
+
+
+# TODO(1.3): remove the warning filter
+@pytest.mark.filterwarnings(
+ "ignore:The attribute `coef_` will be transposed in version 1.3"
+)
+@pytest.mark.parametrize("Est", (PLSRegression, PLSCanonical, CCA))
+def test_univariate_equivalence(Est):
+ # Ensure 2D Y with 1 column is equivalent to 1D Y
+ d = load_linnerud()
+ X = d.data
+ Y = d.target
+
+ est = Est(n_components=1)
+ one_d_coeff = est.fit(X, Y[:, 0]).coef_
+ two_d_coeff = est.fit(X, Y[:, :1]).coef_
+
+ assert one_d_coeff.shape == two_d_coeff.shape
+ assert_array_almost_equal(one_d_coeff, two_d_coeff)
+
+
+@pytest.mark.parametrize("Est", (PLSRegression, PLSCanonical, CCA, PLSSVD))
+def test_copy(Est):
+ # check that the "copy" keyword works
+ d = load_linnerud()
+ X = d.data
+ Y = d.target
+ X_orig = X.copy()
+
+ # copy=True won't modify inplace
+ pls = Est(copy=True).fit(X, Y)
+ assert_array_equal(X, X_orig)
+
+ # copy=False will modify inplace
+ with pytest.raises(AssertionError):
+ Est(copy=False).fit(X, Y)
+ assert_array_almost_equal(X, X_orig)
+
+ if Est is PLSSVD:
+ return # PLSSVD does not support copy param in predict or transform
+
+ X_orig = X.copy()
+ with pytest.raises(AssertionError):
+ pls.transform(X, Y, copy=False)
+ assert_array_almost_equal(X, X_orig)
+
+ X_orig = X.copy()
+ with pytest.raises(AssertionError):
+ pls.predict(X, copy=False)
+ assert_array_almost_equal(X, X_orig)
+
+ # Make sure copy=True gives the same transform and predictions as copy=False
+ assert_array_almost_equal(
+ pls.transform(X, Y, copy=True), pls.transform(X.copy(), Y.copy(), copy=False)
+ )
+ assert_array_almost_equal(
+ pls.predict(X, copy=True), pls.predict(X.copy(), copy=False)
+ )
+
+
+def _generate_test_scale_and_stability_datasets():
+ """Generate dataset for test_scale_and_stability"""
+ # dataset for non-regression test of issue #7818
+ rng = np.random.RandomState(0)
+ n_samples = 1000
+ n_targets = 5
+ n_features = 10
+ Q = rng.randn(n_targets, n_features)
+ Y = rng.randn(n_samples, n_targets)
+ X = np.dot(Y, Q) + 2 * rng.randn(n_samples, n_features) + 1
+ X *= 1000
+ yield X, Y
+
+ # Data set where one of the features is constant
+ X, Y = load_linnerud(return_X_y=True)
+ # causes X[:, -1].std() to be zero
+ X[:, -1] = 1.0
+ yield X, Y
+
+ X = np.array([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0], [2.0, 2.0, 2.0], [3.0, 5.0, 4.0]])
+ Y = np.array([[0.1, -0.2], [0.9, 1.1], [6.2, 5.9], [11.9, 12.3]])
+ yield X, Y
+
+ # Seeds that provide a non-regression test for #18746, where CCA fails
+ seeds = [530, 741]
+ for seed in seeds:
+ rng = np.random.RandomState(seed)
+ X = rng.randn(4, 3)
+ Y = rng.randn(4, 2)
+ yield X, Y
+
+
+@pytest.mark.parametrize("Est", (CCA, PLSCanonical, PLSRegression, PLSSVD))
+@pytest.mark.parametrize("X, Y", _generate_test_scale_and_stability_datasets())
+def test_scale_and_stability(Est, X, Y):
+ """scale=True is equivalent to scale=False on centered/scaled data
+ This allows to check numerical stability over platforms as well"""
+
+ X_s, Y_s, *_ = _center_scale_xy(X, Y)
+
+ X_score, Y_score = Est(scale=True).fit_transform(X, Y)
+ X_s_score, Y_s_score = Est(scale=False).fit_transform(X_s, Y_s)
+
+ assert_allclose(X_s_score, X_score, atol=1e-4)
+ assert_allclose(Y_s_score, Y_score, atol=1e-4)
+
+
+@pytest.mark.parametrize("Estimator", (PLSSVD, PLSRegression, PLSCanonical, CCA))
+def test_n_components_upper_bounds(Estimator):
+ """Check the validation of `n_components` upper bounds for `PLS` regressors."""
+ rng = np.random.RandomState(0)
+ X = rng.randn(10, 5)
+ Y = rng.randn(10, 3)
+ est = Estimator(n_components=10)
+ err_msg = "`n_components` upper bound is .*. Got 10 instead. Reduce `n_components`."
+ with pytest.raises(ValueError, match=err_msg):
+ est.fit(X, Y)
+
+
+@pytest.mark.parametrize("n_samples, n_features", [(100, 10), (100, 200)])
+@pytest.mark.parametrize("seed", range(10))
+def test_singular_value_helpers(n_samples, n_features, seed):
+ # Make sure SVD and power method give approximately the same results
+ X, Y = make_regression(n_samples, n_features, n_targets=5, random_state=seed)
+ u1, v1, _ = _get_first_singular_vectors_power_method(X, Y, norm_y_weights=True)
+ u2, v2 = _get_first_singular_vectors_svd(X, Y)
+
+ _svd_flip_1d(u1, v1)
+ _svd_flip_1d(u2, v2)
+
+ rtol = 1e-1
+ assert_allclose(u1, u2, rtol=rtol)
+ assert_allclose(v1, v2, rtol=rtol)
+
+
+def test_one_component_equivalence():
+ # PLSSVD, PLSRegression and PLSCanonical should all be equivalent when
+ # n_components is 1
+ X, Y = make_regression(100, 10, n_targets=5, random_state=0)
+ svd = PLSSVD(n_components=1).fit(X, Y).transform(X)
+ reg = PLSRegression(n_components=1).fit(X, Y).transform(X)
+ canonical = PLSCanonical(n_components=1).fit(X, Y).transform(X)
+
+ assert_allclose(svd, reg, rtol=1e-2)
+ assert_allclose(svd, canonical, rtol=1e-2)
+
+
+def test_svd_flip_1d():
+ # Make sure svd_flip_1d is equivalent to svd_flip
+ u = np.array([1, -4, 2])
+ v = np.array([1, 2, 3])
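+ # svd_flip makes the largest-magnitude entry of each column of u positive;
+ # here that entry is -4, so both u and v are expected to change sign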
+
+ u_expected, v_expected = svd_flip(u.reshape(-1, 1), v.reshape(1, -1))
+ _svd_flip_1d(u, v) # inplace
+
+ assert_allclose(u, u_expected.ravel())
+ assert_allclose(u, [-1, 4, -2])
+
+ assert_allclose(v, v_expected.ravel())
+ assert_allclose(v, [-1, -2, -3])
+
+
+def test_loadings_converges():
+ """Test that CCA converges. Non-regression test for #19549."""
+ X, y = make_regression(n_samples=200, n_features=20, n_targets=20, random_state=20)
+
+ cca = CCA(n_components=10, max_iter=500)
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("error", ConvergenceWarning)
+
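+ # escalating ConvergenceWarning to an error makes the fit below fail if
+ # CCA does not converge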
+ cca.fit(X, y)
+
+ # Loadings converge to reasonable values
+ assert np.all(np.abs(cca.x_loadings_) < 1)
+
+
+def test_pls_constant_y():
+ """Checks warning when y is constant. Non-regression test for #19831"""
+ rng = np.random.RandomState(42)
+ x = rng.rand(100, 3)
+ y = np.zeros(100)
+
+ pls = PLSRegression()
+
+ msg = "Y residual is constant at iteration"
+ with pytest.warns(UserWarning, match=msg):
+ pls.fit(x, y)
+
+ assert_allclose(pls.x_rotations_, 0)
+
+
+@pytest.mark.parametrize("PLSEstimator", [PLSRegression, PLSCanonical, CCA])
+def test_pls_coef_shape(PLSEstimator):
+ """Check the shape of `coef_` attribute.
+
+ Non-regression test for:
+ https://github.com/scikit-learn/scikit-learn/issues/12410
+ """
+ d = load_linnerud()
+ X = d.data
+ Y = d.target
+
+ pls = PLSEstimator(copy=True).fit(X, Y)
+
+ # TODO(1.3): remove the warning check
+ warning_msg = "The attribute `coef_` will be transposed in version 1.3"
+ with pytest.warns(FutureWarning, match=warning_msg):
+ assert pls.coef_.shape == (X.shape[1], Y.shape[1])
+
+ # Next accesses do not warn
+ with warnings.catch_warnings():
+ warnings.simplefilter("error", FutureWarning)
+ pls.coef_
+
+ # TODO(1.3): rename `_coef_` to `coef_`
+ assert pls._coef_.shape == (Y.shape[1], X.shape[1])
+
+
+# TODO (1.3): remove the filterwarnings and adapt the dot product between `X_trans` and
+# `pls.coef_`
+@pytest.mark.filterwarnings("ignore:The attribute `coef_` will be transposed")
+@pytest.mark.parametrize("scale", [True, False])
+@pytest.mark.parametrize("PLSEstimator", [PLSRegression, PLSCanonical, CCA])
+def test_pls_prediction(PLSEstimator, scale):
+ """Check the behaviour of the prediction function."""
+ d = load_linnerud()
+ X = d.data
+ Y = d.target
+
+ pls = PLSEstimator(copy=True, scale=scale).fit(X, Y)
+ Y_pred = pls.predict(X, copy=True)
+
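+ # reconstruct the prediction manually: center (and optionally scale) X,
+ # then apply the fitted linear model defined by `coef_` and `intercept_`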
+ y_mean = Y.mean(axis=0)
+ X_trans = X - X.mean(axis=0)
+ if scale:
+ X_trans /= X.std(axis=0, ddof=1)
+
+ assert_allclose(pls.intercept_, y_mean)
+ assert_allclose(Y_pred, X_trans @ pls.coef_ + pls.intercept_)
+
+
+@pytest.mark.parametrize("Klass", [CCA, PLSSVD, PLSRegression, PLSCanonical])
+def test_pls_feature_names_out(Klass):
+ """Check `get_feature_names_out` cross_decomposition module."""
+ X, Y = load_linnerud(return_X_y=True)
+
+ est = Klass().fit(X, Y)
+ names_out = est.get_feature_names_out()
+
+ class_name_lower = Klass.__name__.lower()
+ expected_names_out = np.array(
+ [f"{class_name_lower}{i}" for i in range(est.x_weights_.shape[1])],
+ dtype=object,
+ )
+ assert_array_equal(names_out, expected_names_out)
+
+
+@pytest.mark.skip(reason="Failing test")
+@pytest.mark.parametrize("Klass", [CCA, PLSSVD, PLSRegression, PLSCanonical])
+def test_pls_set_output(Klass):
+ """Check `set_output` in cross_decomposition module."""
+ pd = pytest.importorskip("modin.pandas")
+ X, Y = load_linnerud(return_X_y=True, as_frame=True)
+
+ est = Klass().set_output(transform="pandas").fit(X, Y)
+ X_trans, y_trans = est.transform(X, Y)
+ assert isinstance(y_trans, np.ndarray)
+ assert isinstance(X_trans, pd.DataFrame)
+ assert_array_equal(X_trans.columns, est.get_feature_names_out())
diff --git a/modin/pandas/test/interoperability/sklearn/datasets/test_20news.py b/modin/pandas/test/interoperability/sklearn/datasets/test_20news.py
new file mode 100644
index 00000000000..701f751ed3a
--- /dev/null
+++ b/modin/pandas/test/interoperability/sklearn/datasets/test_20news.py
@@ -0,0 +1,151 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+"""Test the 20news downloader, if the data is available,
+or if specifically requested via environment variable
+(e.g. for travis cron job)."""
+from functools import partial
+from unittest.mock import patch
+import pytest
+import numpy as np
+import scipy.sparse as sp
+from sklearn.datasets.tests.test_common import check_as_frame
+from sklearn.datasets.tests.test_common import check_pandas_dependency_message
+from sklearn.datasets.tests.test_common import check_return_X_y
+from sklearn.utils._testing import assert_allclose_dense_sparse
+from sklearn.preprocessing import normalize
+
+
+def test_20news(fetch_20newsgroups_fxt):
+ data = fetch_20newsgroups_fxt(subset="all", shuffle=False)
+ assert data.DESCR.startswith(".. _20newsgroups_dataset:")
+
+ # Extract a reduced dataset
+ data2cats = fetch_20newsgroups_fxt(
+ subset="all", categories=data.target_names[-1:-3:-1], shuffle=False
+ )
+ # Check that the ordering of the target_names is the same
+ # as the ordering in the full dataset
+ assert data2cats.target_names == data.target_names[-2:]
+ # Assert that we have only 0 and 1 as labels
+ assert np.unique(data2cats.target).tolist() == [0, 1]
+
+ # Check that the number of filenames is consistent with data/target
+ assert len(data2cats.filenames) == len(data2cats.target)
+ assert len(data2cats.filenames) == len(data2cats.data)
+
+ # Check that the first entry of the reduced dataset corresponds to
+ # the first entry of the corresponding category in the full dataset
+ entry1 = data2cats.data[0]
+ category = data2cats.target_names[data2cats.target[0]]
+ label = data.target_names.index(category)
+ entry2 = data.data[np.where(data.target == label)[0][0]]
+ assert entry1 == entry2
+
+ # check the return_X_y option
+ X, y = fetch_20newsgroups_fxt(subset="all", shuffle=False, return_X_y=True)
+ assert len(X) == len(data.data)
+ assert y.shape == data.target.shape
+
+
+def test_20news_length_consistency(fetch_20newsgroups_fxt):
+ """Checks the length consistencies within the bunch
+
+ This is a non-regression test for a bug present in 0.16.1.
+ """
+ # Extract the full dataset
+ data = fetch_20newsgroups_fxt(subset="all")
+ assert len(data["data"]) == len(data.data)
+ assert len(data["target"]) == len(data.target)
+ assert len(data["filenames"]) == len(data.filenames)
+
+
+def test_20news_vectorized(fetch_20newsgroups_vectorized_fxt):
+ # test subset = train
+ bunch = fetch_20newsgroups_vectorized_fxt(subset="train")
+ assert sp.isspmatrix_csr(bunch.data)
+ assert bunch.data.shape == (11314, 130107)
+ assert bunch.target.shape[0] == 11314
+ assert bunch.data.dtype == np.float64
+ assert bunch.DESCR.startswith(".. _20newsgroups_dataset:")
+
+ # test subset = test
+ bunch = fetch_20newsgroups_vectorized_fxt(subset="test")
+ assert sp.isspmatrix_csr(bunch.data)
+ assert bunch.data.shape == (7532, 130107)
+ assert bunch.target.shape[0] == 7532
+ assert bunch.data.dtype == np.float64
+ assert bunch.DESCR.startswith(".. _20newsgroups_dataset:")
+
+ # test return_X_y option
+ fetch_func = partial(fetch_20newsgroups_vectorized_fxt, subset="test")
+ check_return_X_y(bunch, fetch_func)
+
+ # test subset = all
+ bunch = fetch_20newsgroups_vectorized_fxt(subset="all")
+ assert sp.isspmatrix_csr(bunch.data)
+ assert bunch.data.shape == (11314 + 7532, 130107)
+ assert bunch.target.shape[0] == 11314 + 7532
+ assert bunch.data.dtype == np.float64
+ assert bunch.DESCR.startswith(".. _20newsgroups_dataset:")
+
+
+def test_20news_normalization(fetch_20newsgroups_vectorized_fxt):
+ X = fetch_20newsgroups_vectorized_fxt(normalize=False)
+ X_ = fetch_20newsgroups_vectorized_fxt(normalize=True)
+ X_norm = X_["data"][:100]
+ X = X["data"][:100]
+
+ assert_allclose_dense_sparse(X_norm, normalize(X))
+ assert np.allclose(np.linalg.norm(X_norm.todense(), axis=1), 1)
+
+
+def test_20news_as_frame(fetch_20newsgroups_vectorized_fxt):
+ pd = pytest.importorskip("modin.pandas")
+
+ bunch = fetch_20newsgroups_vectorized_fxt(as_frame=True)
+ check_as_frame(bunch, fetch_20newsgroups_vectorized_fxt)
+
+ frame = bunch.frame
+ assert frame.shape == (11314, 130108)
+ assert all([isinstance(col, pd.SparseDtype) for col in bunch.data.dtypes])
+
+ # Check a small subset of features
+ for expected_feature in [
+ "beginner",
+ "beginners",
+ "beginning",
+ "beginnings",
+ "begins",
+ "begley",
+ "begone",
+ ]:
+ assert expected_feature in frame.keys()
+ assert "category_class" in frame.keys()
+ assert bunch.target.name == "category_class"
+
+
+def test_as_frame_no_pandas(fetch_20newsgroups_vectorized_fxt, hide_available_pandas):
+ check_pandas_dependency_message(fetch_20newsgroups_vectorized_fxt)
+
+
+def test_outdated_pickle(fetch_20newsgroups_vectorized_fxt):
+ with patch("os.path.exists") as mock_is_exist:
+ with patch("joblib.load") as mock_load:
+ # mock that the dataset was cached
+ mock_is_exist.return_value = True
+ # mock that we have an outdated pickle with only X and y returned
+ mock_load.return_value = ("X", "y")
+ err_msg = "The cached dataset located in"
+ with pytest.raises(ValueError, match=err_msg):
+ fetch_20newsgroups_vectorized_fxt(as_frame=True)
diff --git a/modin/pandas/test/interoperability/sklearn/datasets/test_arff_parser.py b/modin/pandas/test/interoperability/sklearn/datasets/test_arff_parser.py
new file mode 100644
index 00000000000..23c8706ae4f
--- /dev/null
+++ b/modin/pandas/test/interoperability/sklearn/datasets/test_arff_parser.py
@@ -0,0 +1,298 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+from io import BytesIO
+import textwrap
+import pytest
+from sklearn.datasets._arff_parser import (
+ _liac_arff_parser,
+ _pandas_arff_parser,
+ _post_process_frame,
+ load_arff_from_gzip_file,
+)
+
+
+@pytest.mark.parametrize(
+ "feature_names, target_names",
+ [
+ (
+ [
+ "col_int_as_integer",
+ "col_int_as_numeric",
+ "col_float_as_real",
+ "col_float_as_numeric",
+ ],
+ ["col_categorical", "col_string"],
+ ),
+ (
+ [
+ "col_int_as_integer",
+ "col_int_as_numeric",
+ "col_float_as_real",
+ "col_float_as_numeric",
+ ],
+ ["col_categorical"],
+ ),
+ (
+ [
+ "col_int_as_integer",
+ "col_int_as_numeric",
+ "col_float_as_real",
+ "col_float_as_numeric",
+ ],
+ [],
+ ),
+ ],
+)
+def test_post_process_frame(feature_names, target_names):
+ """Check the behaviour of the post-processing function for splitting a dataframe."""
+ pd = pytest.importorskip("modin.pandas")
+
+ X_original = pd.DataFrame(
+ {
+ "col_int_as_integer": [1, 2, 3],
+ "col_int_as_numeric": [1, 2, 3],
+ "col_float_as_real": [1.0, 2.0, 3.0],
+ "col_float_as_numeric": [1.0, 2.0, 3.0],
+ "col_categorical": ["a", "b", "c"],
+ "col_string": ["a", "b", "c"],
+ }
+ )
+
+ X, y = _post_process_frame(X_original, feature_names, target_names)
+ assert isinstance(X, pd.DataFrame)
+ if len(target_names) >= 2:
+ assert isinstance(y, pd.DataFrame)
+ elif len(target_names) == 1:
+ assert isinstance(y, pd.Series)
+ else:
+ assert y is None
+
+
+def test_load_arff_from_gzip_file_error_parser():
+ """An error will be raised if the parser is not known."""
+ # None of the input parameters need to be valid since the parser check
+ # is carried out first.
+
+ err_msg = "Unknown parser: 'xxx'. Should be 'liac-arff' or 'pandas'"
+ with pytest.raises(ValueError, match=err_msg):
+ load_arff_from_gzip_file("xxx", "xxx", "xxx", "xxx", "xxx", "xxx")
+
+
+@pytest.mark.skip(reason="Failing test")
+@pytest.mark.parametrize("parser_func", [_liac_arff_parser, _pandas_arff_parser])
+def test_pandas_arff_parser_strip_single_quotes(parser_func):
+ """Check that we properly strip single quotes from the data."""
+ pd = pytest.importorskip("modin.pandas")
+
+ arff_file = BytesIO(
+ textwrap.dedent(
+ """
+ @relation 'toy'
+ @attribute 'cat_single_quote' {'A', 'B', 'C'}
+ @attribute 'str_single_quote' string
+ @attribute 'str_nested_quote' string
+ @attribute 'class' numeric
+ @data
+ 'A','some text','\"expect double quotes\"',0
+ """
+ ).encode("utf-8")
+ )
+
+ columns_info = {
+ "cat_single_quote": {
+ "data_type": "nominal",
+ "name": "cat_single_quote",
+ },
+ "str_single_quote": {
+ "data_type": "string",
+ "name": "str_single_quote",
+ },
+ "str_nested_quote": {
+ "data_type": "string",
+ "name": "str_nested_quote",
+ },
+ "class": {
+ "data_type": "numeric",
+ "name": "class",
+ },
+ }
+
+ feature_names = [
+ "cat_single_quote",
+ "str_single_quote",
+ "str_nested_quote",
+ ]
+ target_names = ["class"]
+
+ # We don't strip single quotes for string columns with the pandas parser.
+ expected_values = {
+ "cat_single_quote": "A",
+ "str_single_quote": (
+ "some text" if parser_func is _liac_arff_parser else "'some text'"
+ ),
+ "str_nested_quote": (
+ '"expect double quotes"'
+ if parser_func is _liac_arff_parser
+ else "'\"expect double quotes\"'"
+ ),
+ "class": 0,
+ }
+
+ _, _, frame, _ = parser_func(
+ arff_file,
+ output_arrays_type="pandas",
+ openml_columns_info=columns_info,
+ feature_names_to_select=feature_names,
+ target_names_to_select=target_names,
+ )
+
+ assert frame.columns.tolist() == feature_names + target_names
+ pd.testing.assert_series_equal(frame.iloc[0], pd.Series(expected_values, name=0))
+
+
+@pytest.mark.skip(reason="Failing test")
+@pytest.mark.parametrize("parser_func", [_liac_arff_parser, _pandas_arff_parser])
+def test_pandas_arff_parser_strip_double_quotes(parser_func):
+ """Check that we properly strip double quotes from the data."""
+ pd = pytest.importorskip("modin.pandas")
+
+ arff_file = BytesIO(
+ textwrap.dedent(
+ """
+ @relation 'toy'
+ @attribute 'cat_double_quote' {"A", "B", "C"}
+ @attribute 'str_double_quote' string
+ @attribute 'str_nested_quote' string
+ @attribute 'class' numeric
+ @data
+ "A","some text","\'expect double quotes\'",0
+ """
+ ).encode("utf-8")
+ )
+
+ columns_info = {
+ "cat_double_quote": {
+ "data_type": "nominal",
+ "name": "cat_double_quote",
+ },
+ "str_double_quote": {
+ "data_type": "string",
+ "name": "str_double_quote",
+ },
+ "str_nested_quote": {
+ "data_type": "string",
+ "name": "str_nested_quote",
+ },
+ "class": {
+ "data_type": "numeric",
+ "name": "class",
+ },
+ }
+
+ feature_names = [
+ "cat_double_quote",
+ "str_double_quote",
+ "str_nested_quote",
+ ]
+ target_names = ["class"]
+
+ expected_values = {
+ "cat_double_quote": "A",
+ "str_double_quote": "some text",
+ "str_nested_quote": "'expect double quotes'",
+ "class": 0,
+ }
+
+ _, _, frame, _ = parser_func(
+ arff_file,
+ output_arrays_type="pandas",
+ openml_columns_info=columns_info,
+ feature_names_to_select=feature_names,
+ target_names_to_select=target_names,
+ )
+
+ assert frame.columns.tolist() == feature_names + target_names
+ pd.testing.assert_series_equal(frame.iloc[0], pd.Series(expected_values, name=0))
+
+
+@pytest.mark.skip(reason="Failing test")
+@pytest.mark.parametrize(
+ "parser_func",
+ [
+ # internal quotes are not considered to follow the ARFF spec in LIAC ARFF
+ pytest.param(_liac_arff_parser, marks=pytest.mark.xfail),
+ _pandas_arff_parser,
+ ],
+)
+def test_pandas_arff_parser_strip_no_quotes(parser_func):
+ """Check that we properly parse with no quotes characters."""
+ pd = pytest.importorskip("modin.pandas")
+
+ arff_file = BytesIO(
+ textwrap.dedent(
+ """
+ @relation 'toy'
+ @attribute 'cat_without_quote' {A, B, C}
+ @attribute 'str_without_quote' string
+ @attribute 'str_internal_quote' string
+ @attribute 'class' numeric
+ @data
+ A,some text,'internal' quote,0
+ """
+ ).encode("utf-8")
+ )
+
+ columns_info = {
+ "cat_without_quote": {
+ "data_type": "nominal",
+ "name": "cat_without_quote",
+ },
+ "str_without_quote": {
+ "data_type": "string",
+ "name": "str_without_quote",
+ },
+ "str_internal_quote": {
+ "data_type": "string",
+ "name": "str_internal_quote",
+ },
+ "class": {
+ "data_type": "numeric",
+ "name": "class",
+ },
+ }
+
+ feature_names = [
+ "cat_without_quote",
+ "str_without_quote",
+ "str_internal_quote",
+ ]
+ target_names = ["class"]
+
+ expected_values = {
+ "cat_without_quote": "A",
+ "str_without_quote": "some text",
+ "str_internal_quote": "'internal' quote",
+ "class": 0,
+ }
+
+ _, _, frame, _ = parser_func(
+ arff_file,
+ output_arrays_type="pandas",
+ openml_columns_info=columns_info,
+ feature_names_to_select=feature_names,
+ target_names_to_select=target_names,
+ )
+
+ assert frame.columns.tolist() == feature_names + target_names
+ pd.testing.assert_series_equal(frame.iloc[0], pd.Series(expected_values, name=0))
diff --git a/modin/pandas/test/interoperability/sklearn/datasets/test_california_housing.py b/modin/pandas/test/interoperability/sklearn/datasets/test_california_housing.py
new file mode 100644
index 00000000000..383c3ef4598
--- /dev/null
+++ b/modin/pandas/test/interoperability/sklearn/datasets/test_california_housing.py
@@ -0,0 +1,48 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+"""Test the california_housing loader, if the data is available,
+or if specifically requested via environment variable
+(e.g. for travis cron job)."""
+import pytest
+from sklearn.datasets.tests.test_common import check_return_X_y
+from functools import partial
+
+
+def test_fetch(fetch_california_housing_fxt):
+ data = fetch_california_housing_fxt()
+ assert (20640, 8) == data.data.shape
+ assert (20640,) == data.target.shape
+ assert data.DESCR.startswith(".. _california_housing_dataset:")
+
+ # test return_X_y option
+ fetch_func = partial(fetch_california_housing_fxt)
+ check_return_X_y(data, fetch_func)
+
+
+def test_fetch_asframe(fetch_california_housing_fxt):
+ pd = pytest.importorskip("modin.pandas")
+ bunch = fetch_california_housing_fxt(as_frame=True)
+ frame = bunch.frame
+ assert hasattr(bunch, "frame") is True
+ assert frame.shape == (20640, 9)
+ assert isinstance(bunch.data, pd.DataFrame)
+ assert isinstance(bunch.target, pd.Series)
+
+
+def test_pandas_dependency_message(fetch_california_housing_fxt, hide_available_pandas):
+ # Check that pandas is imported lazily and that an informative error
+ # message is raised when pandas is missing:
+ expected_msg = "fetch_california_housing with as_frame=True requires pandas"
+ with pytest.raises(ImportError, match=expected_msg):
+ fetch_california_housing_fxt(as_frame=True)
diff --git a/modin/pandas/test/interoperability/sklearn/datasets/test_common_dataset.py b/modin/pandas/test/interoperability/sklearn/datasets/test_common_dataset.py
new file mode 100644
index 00000000000..2847b892699
--- /dev/null
+++ b/modin/pandas/test/interoperability/sklearn/datasets/test_common_dataset.py
@@ -0,0 +1,147 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+"""Test loaders for common functionality."""
+import inspect
+import os
+import pytest
+import numpy as np
+import sklearn.datasets
+
+
+def is_pillow_installed():
+ try:
+ import PIL # noqa
+
+ return True
+ except ImportError:
+ return False
+
+
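+# Maps a loader keyword (e.g. "return_X_y") to per-fetcher pytest marks that are
+# attached to the generated test cases in `_generate_func_supporting_param` below.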
+FETCH_PYTEST_MARKERS = {
+ "return_X_y": {
+ "fetch_20newsgroups": pytest.mark.xfail(
+ reason="X is a list and does not have a shape argument"
+ ),
+ "fetch_openml": pytest.mark.xfail(
+ reason="fetch_opeml requires a dataset name or id"
+ ),
+ "fetch_lfw_people": pytest.mark.skipif(
+ not is_pillow_installed(), reason="pillow is not installed"
+ ),
+ },
+ "as_frame": {
+ "fetch_openml": pytest.mark.xfail(
+ reason="fetch_opeml requires a dataset name or id"
+ ),
+ },
+}
+
+
+def check_pandas_dependency_message(fetch_func):
+ try:
+ import modin.pandas # noqa
+
+ pytest.skip("This test requires pandas to not be installed")
+ except ImportError:
+ # Check that pandas is imported lazily and that an informative error
+ # message is raised when pandas is missing:
+ name = fetch_func.__name__
+ expected_msg = f"{name} with as_frame=True requires pandas"
+ with pytest.raises(ImportError, match=expected_msg):
+ fetch_func(as_frame=True)
+
+
+def check_return_X_y(bunch, dataset_func):
+ X_y_tuple = dataset_func(return_X_y=True)
+ assert isinstance(X_y_tuple, tuple)
+ assert X_y_tuple[0].shape == bunch.data.shape
+ assert X_y_tuple[1].shape == bunch.target.shape
+
+
+def check_as_frame(
+ bunch, dataset_func, expected_data_dtype=None, expected_target_dtype=None
+):
+ pd = pytest.importorskip("modin.pandas")
+ frame_bunch = dataset_func(as_frame=True)
+ assert hasattr(frame_bunch, "frame")
+ assert isinstance(frame_bunch.frame, pd.DataFrame)
+ assert isinstance(frame_bunch.data, pd.DataFrame)
+ assert frame_bunch.data.shape == bunch.data.shape
+ if frame_bunch.target.ndim > 1:
+ assert isinstance(frame_bunch.target, pd.DataFrame)
+ else:
+ assert isinstance(frame_bunch.target, pd.Series)
+ assert frame_bunch.target.shape[0] == bunch.target.shape[0]
+ if expected_data_dtype is not None:
+ assert np.all(frame_bunch.data.dtypes == expected_data_dtype)
+ if expected_target_dtype is not None:
+ assert np.all(frame_bunch.target.dtypes == expected_target_dtype)
+
+ # Test for return_X_y and as_frame=True
+ frame_X, frame_y = dataset_func(as_frame=True, return_X_y=True)
+ assert isinstance(frame_X, pd.DataFrame)
+ if frame_y.ndim > 1:
+ assert isinstance(frame_y, pd.DataFrame)
+ else:
+ assert isinstance(frame_y, pd.Series)
+
+
+def _skip_network_tests():
+ return os.environ.get("SKLEARN_SKIP_NETWORK_TESTS", "1") == "1"
+
+
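+# Yield (name, function) pytest params for every sklearn.datasets loader/fetcher
+# whose signature exposes `param`, attaching network-skip and per-fetcher marks.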
+def _generate_func_supporting_param(param, dataset_type=("load", "fetch")):
+ markers_fetch = FETCH_PYTEST_MARKERS.get(param, {})
+ for name, obj in inspect.getmembers(sklearn.datasets):
+ if not inspect.isfunction(obj):
+ continue
+
+ is_dataset_type = any([name.startswith(t) for t in dataset_type])
+ is_support_param = param in inspect.signature(obj).parameters
+ if is_dataset_type and is_support_param:
+ # check if we should skip if we don't have network support
+ marks = [
+ pytest.mark.skipif(
+ condition=name.startswith("fetch") and _skip_network_tests(),
+ reason="Skip because fetcher requires internet network",
+ )
+ ]
+ if name in markers_fetch:
+ marks.append(markers_fetch[name])
+
+ yield pytest.param(name, obj, marks=marks)
+
+
+@pytest.mark.parametrize(
+ "name, dataset_func", _generate_func_supporting_param("return_X_y")
+)
+def test_common_check_return_X_y(name, dataset_func):
+ bunch = dataset_func()
+ check_return_X_y(bunch, dataset_func)
+
+
+@pytest.mark.skip(reason="Failing test")
+@pytest.mark.parametrize(
+ "name, dataset_func", _generate_func_supporting_param("as_frame")
+)
+def test_common_check_as_frame(name, dataset_func):
+ bunch = dataset_func()
+ check_as_frame(bunch, dataset_func)
+
+
+@pytest.mark.parametrize(
+ "name, dataset_func", _generate_func_supporting_param("as_frame")
+)
+def test_common_check_pandas_dependency(name, dataset_func):
+ check_pandas_dependency_message(dataset_func)
diff --git a/modin/pandas/test/interoperability/sklearn/datasets/test_covtype.py b/modin/pandas/test/interoperability/sklearn/datasets/test_covtype.py
new file mode 100644
index 00000000000..50fe2629c4d
--- /dev/null
+++ b/modin/pandas/test/interoperability/sklearn/datasets/test_covtype.py
@@ -0,0 +1,65 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+"""Test the covtype loader, if the data is available,
+or if specifically requested via environment variable
+(e.g. for travis cron job)."""
+from functools import partial
+import pytest
+from sklearn.datasets.tests.test_common import check_return_X_y
+
+
+def test_fetch(fetch_covtype_fxt):
+ data1 = fetch_covtype_fxt(shuffle=True, random_state=42)
+ data2 = fetch_covtype_fxt(shuffle=True, random_state=37)
+
+ X1, X2 = data1["data"], data2["data"]
+ assert (581012, 54) == X1.shape
+ assert X1.shape == X2.shape
+
+ assert X1.sum() == X2.sum()
+
+ y1, y2 = data1["target"], data2["target"]
+ assert (X1.shape[0],) == y1.shape
+ assert (X1.shape[0],) == y2.shape
+
+ descr_prefix = ".. _covtype_dataset:"
+ assert data1.DESCR.startswith(descr_prefix)
+ assert data2.DESCR.startswith(descr_prefix)
+
+ # test return_X_y option
+ fetch_func = partial(fetch_covtype_fxt)
+ check_return_X_y(data1, fetch_func)
+
+
+def test_fetch_asframe(fetch_covtype_fxt):
+ pytest.importorskip("modin.pandas")
+
+ bunch = fetch_covtype_fxt(as_frame=True)
+ assert hasattr(bunch, "frame")
+ frame = bunch.frame
+ assert frame.shape == (581012, 55)
+ assert bunch.data.shape == (581012, 54)
+ assert bunch.target.shape == (581012,)
+
+ column_names = set(frame.columns)
+
+ # enumerated names are added correctly
+ assert set(f"Wilderness_Area_{i}" for i in range(4)) < column_names
+ assert set(f"Soil_Type_{i}" for i in range(40)) < column_names
+
+
+def test_pandas_dependency_message(fetch_covtype_fxt, hide_available_pandas):
+ expected_msg = "fetch_covtype with as_frame=True requires pandas"
+ with pytest.raises(ImportError, match=expected_msg):
+ fetch_covtype_fxt(as_frame=True)
diff --git a/modin/pandas/test/interoperability/sklearn/datasets/test_openml.py b/modin/pandas/test/interoperability/sklearn/datasets/test_openml.py
new file mode 100644
index 00000000000..37073656e15
--- /dev/null
+++ b/modin/pandas/test/interoperability/sklearn/datasets/test_openml.py
@@ -0,0 +1,1656 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+# flake8: noqa
+
+"""Test the openml loader."""
+import gzip
+import json
+import os
+import re
+from functools import partial
+from io import BytesIO
+from urllib.error import HTTPError
+
+import numpy as np
+import scipy.sparse
+import pytest
+
+import sklearn
+from sklearn import config_context
+from sklearn.utils import Bunch, check_pandas_support
+from sklearn.utils.fixes import _open_binary
+from sklearn.utils._testing import (
+ SkipTest,
+ assert_allclose,
+ assert_array_equal,
+ fails_if_pypy,
+)
+
+from sklearn.datasets import fetch_openml as fetch_openml_orig
+from sklearn.datasets._openml import (
+ _OPENML_PREFIX,
+ _open_openml_url,
+ _get_local_path,
+ _retry_with_clean_cache,
+)
+
+
+OPENML_TEST_DATA_MODULE = "sklearn.datasets.tests.data.openml"
+# if True, urlopen will be monkey patched to only use local files
+test_offline = True
+
+
+class _MockHTTPResponse:
+ def __init__(self, data, is_gzip):
+ self.data = data
+ self.is_gzip = is_gzip
+
+ def read(self, amt=-1):
+ return self.data.read(amt)
+
+ def close(self):
+ self.data.close()
+
+ def info(self):
+ if self.is_gzip:
+ return {"Content-Encoding": "gzip"}
+ return {}
+
+ def __iter__(self):
+ return iter(self.data)
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ return False
+
+
+# Disable the disk-based cache when testing `fetch_openml`:
+# the mock data in sklearn/datasets/tests/data/openml/ is not always consistent
+# with the version on openml.org. If one were to load the dataset outside of
+# the tests, it may result in data that does not represent openml.org.
+fetch_openml = partial(fetch_openml_orig, data_home=None)
+
+
+def _monkey_patch_webbased_functions(context, data_id, gzip_response):
+ # monkey patches the urlopen function. Important note: Do NOT use this
+ # in combination with a regular cache directory, as the files that are
+ # stored as cache should not be mixed up with real openml datasets
+ url_prefix_data_description = "https://openml.org/api/v1/json/data/"
+ url_prefix_data_features = "https://openml.org/api/v1/json/data/features/"
+ url_prefix_download_data = "https://openml.org/data/v1/"
+ url_prefix_data_list = "https://openml.org/api/v1/json/data/list/"
+
+ path_suffix = ".gz"
+ read_fn = gzip.open
+
+ data_module = OPENML_TEST_DATA_MODULE + "." + f"id_{data_id}"
+
+ def _file_name(url, suffix):
+ output = (
+ re.sub(r"\W", "-", url[len("https://openml.org/") :]) + suffix + path_suffix
+ )
+ # Shorten the filenames for better compatibility with Windows 10 and
+ # its 260-character path limit
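+ # e.g. "https://openml.org/api/v1/json/data/61" becomes "api-v1-jd-61.json.gz"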
+ return (
+ output.replace("-json-data-list", "-jdl")
+ .replace("-json-data-features", "-jdf")
+ .replace("-json-data-qualities", "-jdq")
+ .replace("-json-data", "-jd")
+ .replace("-data_name", "-dn")
+ .replace("-download", "-dl")
+ .replace("-limit", "-l")
+ .replace("-data_version", "-dv")
+ .replace("-status", "-s")
+ .replace("-deactivated", "-dact")
+ .replace("-active", "-act")
+ )
+
+ def _mock_urlopen_shared(url, has_gzip_header, expected_prefix, suffix):
+ assert url.startswith(expected_prefix)
+
+ data_file_name = _file_name(url, suffix)
+
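+ # serve the gzipped fixture as-is when the client asked for gzip (and gzip
+ # responses are enabled); otherwise decompress it before building the response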
+ with _open_binary(data_module, data_file_name) as f:
+ if has_gzip_header and gzip_response:
+ fp = BytesIO(f.read())
+ return _MockHTTPResponse(fp, True)
+ else:
+ decompressed_f = read_fn(f, "rb")
+ fp = BytesIO(decompressed_f.read())
+ return _MockHTTPResponse(fp, False)
+
+ def _mock_urlopen_data_description(url, has_gzip_header):
+ return _mock_urlopen_shared(
+ url=url,
+ has_gzip_header=has_gzip_header,
+ expected_prefix=url_prefix_data_description,
+ suffix=".json",
+ )
+
+ def _mock_urlopen_data_features(url, has_gzip_header):
+ return _mock_urlopen_shared(
+ url=url,
+ has_gzip_header=has_gzip_header,
+ expected_prefix=url_prefix_data_features,
+ suffix=".json",
+ )
+
+ def _mock_urlopen_download_data(url, has_gzip_header):
+ return _mock_urlopen_shared(
+ url=url,
+ has_gzip_header=has_gzip_header,
+ expected_prefix=url_prefix_download_data,
+ suffix=".arff",
+ )
+
+ def _mock_urlopen_data_list(url, has_gzip_header):
+ assert url.startswith(url_prefix_data_list)
+
+ data_file_name = _file_name(url, ".json")
+
+ # load the file itself first, to decide whether to simulate an HTTP error
+ with _open_binary(data_module, data_file_name) as f:
+ decompressed_f = read_fn(f, "rb")
+ decoded_s = decompressed_f.read().decode("utf-8")
+ json_data = json.loads(decoded_s)
+ if "error" in json_data:
+ raise HTTPError(
+ url=None, code=412, msg="Simulated mock error", hdrs=None, fp=None
+ )
+
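+ # second pass: return the payload itself, mimicking the real endpoint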
+ with _open_binary(data_module, data_file_name) as f:
+ if has_gzip_header:
+ fp = BytesIO(f.read())
+ return _MockHTTPResponse(fp, True)
+ else:
+ decompressed_f = read_fn(f, "rb")
+ fp = BytesIO(decompressed_f.read())
+ return _MockHTTPResponse(fp, False)
+
+ def _mock_urlopen(request, *args, **kwargs):
+ url = request.get_full_url()
+ has_gzip_header = request.get_header("Accept-encoding") == "gzip"
+ if url.startswith(url_prefix_data_list):
+ return _mock_urlopen_data_list(url, has_gzip_header)
+ elif url.startswith(url_prefix_data_features):
+ return _mock_urlopen_data_features(url, has_gzip_header)
+ elif url.startswith(url_prefix_download_data):
+ return _mock_urlopen_download_data(url, has_gzip_header)
+ elif url.startswith(url_prefix_data_description):
+ return _mock_urlopen_data_description(url, has_gzip_header)
+ else:
+ raise ValueError("Unknown mocking URL pattern: %s" % url)
+
+ # XXX: Global variable
+ if test_offline:
+ context.setattr(sklearn.datasets._openml, "urlopen", _mock_urlopen)
+
+
+###############################################################################
+# Test the behaviour of `fetch_openml` depending on the input parameters.
+
+
+# Known failure of PyPy for OpenML. See the following issue:
+# https://github.com/scikit-learn/scikit-learn/issues/18906
+@fails_if_pypy
+@pytest.mark.skip(reason="Failing test")
+@pytest.mark.parametrize(
+ "data_id, dataset_params, n_samples, n_features, n_targets",
+ [
+ # iris
+ (61, {"data_id": 61}, 150, 4, 1),
+ (61, {"name": "iris", "version": 1}, 150, 4, 1),
+ # anneal
+ (2, {"data_id": 2}, 11, 38, 1),
+ (2, {"name": "anneal", "version": 1}, 11, 38, 1),
+ # cpu
+ (561, {"data_id": 561}, 209, 7, 1),
+ (561, {"name": "cpu", "version": 1}, 209, 7, 1),
+ # emotions
+ (40589, {"data_id": 40589}, 13, 72, 6),
+ # adult-census
+ (1119, {"data_id": 1119}, 10, 14, 1),
+ (1119, {"name": "adult-census"}, 10, 14, 1),
+ # miceprotein
+ (40966, {"data_id": 40966}, 7, 77, 1),
+ (40966, {"name": "MiceProtein"}, 7, 77, 1),
+ # titanic
+ (40945, {"data_id": 40945}, 1309, 13, 1),
+ ],
+)
+@pytest.mark.parametrize("parser", ["liac-arff", "pandas"])
+@pytest.mark.parametrize("gzip_response", [True, False])
+def test_fetch_openml_as_frame_true(
+ monkeypatch,
+ data_id,
+ dataset_params,
+ n_samples,
+ n_features,
+ n_targets,
+ parser,
+ gzip_response,
+):
+ """Check the behaviour of `fetch_openml` with `as_frame=True`.
+
+ Fetch by ID and/or name (depending on whether the file was previously cached).
+ """
+ pd = pytest.importorskip("modin.pandas")
+
+ _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response=gzip_response)
+ bunch = fetch_openml(
+ as_frame=True,
+ cache=False,
+ parser=parser,
+ **dataset_params,
+ )
+
+ assert int(bunch.details["id"]) == data_id
+ assert isinstance(bunch, Bunch)
+
+ assert isinstance(bunch.frame, pd.DataFrame)
+ assert bunch.frame.shape == (n_samples, n_features + n_targets)
+
+ assert isinstance(bunch.data, pd.DataFrame)
+ assert bunch.data.shape == (n_samples, n_features)
+
+ if n_targets == 1:
+ assert isinstance(bunch.target, pd.Series)
+ assert bunch.target.shape == (n_samples,)
+ else:
+ assert isinstance(bunch.target, pd.DataFrame)
+ assert bunch.target.shape == (n_samples, n_targets)
+
+ assert bunch.categories is None
+
+
+# Known failure of PyPy for OpenML. See the following issue:
+# https://github.com/scikit-learn/scikit-learn/issues/18906
+@fails_if_pypy
+@pytest.mark.parametrize(
+ "data_id, dataset_params, n_samples, n_features, n_targets",
+ [
+ # iris
+ (61, {"data_id": 61}, 150, 4, 1),
+ (61, {"name": "iris", "version": 1}, 150, 4, 1),
+ # anneal
+ (2, {"data_id": 2}, 11, 38, 1),
+ (2, {"name": "anneal", "version": 1}, 11, 38, 1),
+ # cpu
+ (561, {"data_id": 561}, 209, 7, 1),
+ (561, {"name": "cpu", "version": 1}, 209, 7, 1),
+ # emotions
+ (40589, {"data_id": 40589}, 13, 72, 6),
+ # adult-census
+ (1119, {"data_id": 1119}, 10, 14, 1),
+ (1119, {"name": "adult-census"}, 10, 14, 1),
+ # miceprotein
+ (40966, {"data_id": 40966}, 7, 77, 1),
+ (40966, {"name": "MiceProtein"}, 7, 77, 1),
+ ],
+)
+@pytest.mark.parametrize("parser", ["liac-arff", "pandas"])
+def test_fetch_openml_as_frame_false(
+ monkeypatch,
+ data_id,
+ dataset_params,
+ n_samples,
+ n_features,
+ n_targets,
+ parser,
+):
+ """Check the behaviour of `fetch_openml` with `as_frame=False`.
+
+ Fetch both by ID and by name + version.
+ """
+ pytest.importorskip("modin.pandas")
+
+ _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response=True)
+ bunch = fetch_openml(
+ as_frame=False,
+ cache=False,
+ parser=parser,
+ **dataset_params,
+ )
+ assert int(bunch.details["id"]) == data_id
+ assert isinstance(bunch, Bunch)
+
+ assert bunch.frame is None
+
+ assert isinstance(bunch.data, np.ndarray)
+ assert bunch.data.shape == (n_samples, n_features)
+
+ assert isinstance(bunch.target, np.ndarray)
+ if n_targets == 1:
+ assert bunch.target.shape == (n_samples,)
+ else:
+ assert bunch.target.shape == (n_samples, n_targets)
+
+ assert isinstance(bunch.categories, dict)
+
+
+# Known failure of PyPy for OpenML. See the following issue:
+# https://github.com/scikit-learn/scikit-learn/issues/18906
+@fails_if_pypy
+@pytest.mark.skip(reason="Failing test")
+@pytest.mark.parametrize("data_id", [61, 1119, 40945])
+def test_fetch_openml_consistency_parser(monkeypatch, data_id):
+ """Check the consistency of the LIAC-ARFF and pandas parsers."""
+ pd = pytest.importorskip("modin.pandas")
+
+ _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response=True)
+ bunch_liac = fetch_openml(
+ data_id=data_id,
+ as_frame=True,
+ cache=False,
+ parser="liac-arff",
+ )
+ bunch_pandas = fetch_openml(
+ data_id=data_id,
+ as_frame=True,
+ cache=False,
+ parser="pandas",
+ )
+
+ # The data frames for the input features should match up to some numerical
+ # dtype conversions (e.g. float64 <=> Int64) due to limitations of the
+ # LIAC-ARFF parser.
+ data_liac, data_pandas = bunch_liac.data, bunch_pandas.data
+
+ def convert_numerical_dtypes(series):
+ pandas_series = data_pandas[series.name]
+ if pd.api.types.is_numeric_dtype(pandas_series):
+ return series.astype(pandas_series.dtype)
+ else:
+ return series
+
+ data_liac_with_fixed_dtypes = data_liac.apply(convert_numerical_dtypes)
+ pd.testing.assert_frame_equal(data_liac_with_fixed_dtypes, data_pandas)
+
+ # Let's also check that the .frame attributes also match
+ frame_liac, frame_pandas = bunch_liac.frame, bunch_pandas.frame
+
+ # Note that the .frame attribute is a superset of the .data attribute:
+ pd.testing.assert_frame_equal(frame_pandas[bunch_pandas.feature_names], data_pandas)
+
+ # However the remaining columns, typically the target(s), are not necessarily
+ # dtyped similarly by both parsers due to limitations of the LIAC-ARFF parser.
+ # Therefore, extra dtype conversions are required for those columns:
+
+ def convert_numerical_and_categorical_dtypes(series):
+ pandas_series = frame_pandas[series.name]
+ if pd.api.types.is_numeric_dtype(pandas_series):
+ return series.astype(pandas_series.dtype)
+ elif pd.api.types.is_categorical_dtype(pandas_series):
+ # Compare categorical features: liac-arff uses strings to denote the
+ # categories, so we rename the categories to make them comparable to
+ # the pandas parser. Fixing this behavior in LIAC-ARFF would allow
+ # checking the consistency in the future, but we do not plan to
+ # maintain LIAC-ARFF in the long term.
+ return series.cat.rename_categories(pandas_series.cat.categories)
+ else:
+ return series
+
+ frame_liac_with_fixed_dtypes = frame_liac.apply(
+ convert_numerical_and_categorical_dtypes
+ )
+ pd.testing.assert_frame_equal(frame_liac_with_fixed_dtypes, frame_pandas)
+
+
+# Known failure of PyPy for OpenML. See the following issue:
+# https://github.com/scikit-learn/scikit-learn/issues/18906
+@fails_if_pypy
+@pytest.mark.parametrize("parser", ["liac-arff", "pandas"])
+def test_fetch_openml_equivalence_array_dataframe(monkeypatch, parser):
+ """Check the equivalence of the dataset when using `as_frame=False` and
+ `as_frame=True`.
+ """
+ pytest.importorskip("modin.pandas")
+
+ data_id = 61
+ _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response=True)
+ bunch_as_frame_true = fetch_openml(
+ data_id=data_id,
+ as_frame=True,
+ cache=False,
+ parser=parser,
+ )
+
+ bunch_as_frame_false = fetch_openml(
+ data_id=data_id,
+ as_frame=False,
+ cache=False,
+ parser=parser,
+ )
+
+ assert_allclose(bunch_as_frame_false.data, bunch_as_frame_true.data)
+ assert_array_equal(bunch_as_frame_false.target, bunch_as_frame_true.target)
+
+
+# Known failure of PyPy for OpenML. See the following issue:
+# https://github.com/scikit-learn/scikit-learn/issues/18906
+@fails_if_pypy
+@pytest.mark.skip(reason="Failing test")
+@pytest.mark.parametrize("parser", ["liac-arff", "pandas"])
+def test_fetch_openml_iris_pandas(monkeypatch, parser):
+ """Check fetching on a numerical only dataset with string labels."""
+ pd = pytest.importorskip("modin.pandas")
+ CategoricalDtype = pd.api.types.CategoricalDtype
+ data_id = 61
+ data_shape = (150, 4)
+ target_shape = (150,)
+ frame_shape = (150, 5)
+
+ target_dtype = CategoricalDtype(
+ ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
+ )
+ data_dtypes = [np.float64] * 4
+ data_names = ["sepallength", "sepalwidth", "petallength", "petalwidth"]
+ target_name = "class"
+
+ _monkey_patch_webbased_functions(monkeypatch, data_id, True)
+
+ bunch = fetch_openml(
+ data_id=data_id,
+ as_frame=True,
+ cache=False,
+ parser=parser,
+ )
+ data = bunch.data
+ target = bunch.target
+ frame = bunch.frame
+
+ assert isinstance(data, pd.DataFrame)
+ assert np.all(data.dtypes == data_dtypes)
+ assert data.shape == data_shape
+ assert np.all(data.columns == data_names)
+ assert np.all(bunch.feature_names == data_names)
+ assert bunch.target_names == [target_name]
+
+ assert isinstance(target, pd.Series)
+ assert target.dtype == target_dtype
+ assert target.shape == target_shape
+ assert target.name == target_name
+ assert target.index.is_unique
+
+ assert isinstance(frame, pd.DataFrame)
+ assert frame.shape == frame_shape
+ assert np.all(frame.dtypes == data_dtypes + [target_dtype])
+ assert frame.index.is_unique
+
+
+# Known failure of PyPy for OpenML. See the following issue:
+# https://github.com/scikit-learn/scikit-learn/issues/18906
+@fails_if_pypy
+@pytest.mark.skip(reason="Failing test")
+@pytest.mark.parametrize("parser", ["liac-arff", "pandas"])
+@pytest.mark.parametrize("target_column", ["petalwidth", ["petalwidth", "petallength"]])
+def test_fetch_openml_forcing_targets(monkeypatch, parser, target_column):
+ """Check that we can force the target to not be the default target."""
+ pd = pytest.importorskip("modin.pandas")
+
+ data_id = 61
+ _monkey_patch_webbased_functions(monkeypatch, data_id, True)
+ bunch_forcing_target = fetch_openml(
+ data_id=data_id,
+ as_frame=True,
+ cache=False,
+ target_column=target_column,
+ parser=parser,
+ )
+ bunch_default = fetch_openml(
+ data_id=data_id,
+ as_frame=True,
+ cache=False,
+ parser=parser,
+ )
+
+ pd.testing.assert_frame_equal(bunch_forcing_target.frame, bunch_default.frame)
+ if isinstance(target_column, list):
+ pd.testing.assert_index_equal(
+ bunch_forcing_target.target.columns, pd.Index(target_column)
+ )
+ assert bunch_forcing_target.data.shape == (150, 3)
+ else:
+ assert bunch_forcing_target.target.name == target_column
+ assert bunch_forcing_target.data.shape == (150, 4)
+
+
+# Known failure of PyPy for OpenML. See the following issue:
+# https://github.com/scikit-learn/scikit-learn/issues/18906
+@fails_if_pypy
+@pytest.mark.skip(reason="Failing test")
+@pytest.mark.parametrize("data_id", [61, 2, 561, 40589, 1119])
+@pytest.mark.parametrize("parser", ["liac-arff", "pandas"])
+def test_fetch_openml_equivalence_frame_return_X_y(monkeypatch, data_id, parser):
+ """Check the behaviour of `return_X_y=True` when `as_frame=True`."""
+ pd = pytest.importorskip("modin.pandas")
+
+ _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response=True)
+ bunch = fetch_openml(
+ data_id=data_id,
+ as_frame=True,
+ cache=False,
+ return_X_y=False,
+ parser=parser,
+ )
+ X, y = fetch_openml(
+ data_id=data_id,
+ as_frame=True,
+ cache=False,
+ return_X_y=True,
+ parser=parser,
+ )
+
+ pd.testing.assert_frame_equal(bunch.data, X)
+ if isinstance(y, pd.Series):
+ pd.testing.assert_series_equal(bunch.target, y)
+ else:
+ pd.testing.assert_frame_equal(bunch.target, y)
+
+
+# Known failure of PyPy for OpenML. See the following issue:
+# https://github.com/scikit-learn/scikit-learn/issues/18906
+@fails_if_pypy
+@pytest.mark.parametrize("data_id", [61, 561, 40589, 1119])
+@pytest.mark.parametrize("parser", ["liac-arff", "pandas"])
+def test_fetch_openml_equivalence_array_return_X_y(monkeypatch, data_id, parser):
+ """Check the behaviour of `return_X_y=True` when `as_frame=False`."""
+ pytest.importorskip("modin.pandas")
+
+ _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response=True)
+ bunch = fetch_openml(
+ data_id=data_id,
+ as_frame=False,
+ cache=False,
+ return_X_y=False,
+ parser=parser,
+ )
+ X, y = fetch_openml(
+ data_id=data_id,
+ as_frame=False,
+ cache=False,
+ return_X_y=True,
+ parser=parser,
+ )
+
+ assert_array_equal(bunch.data, X)
+ assert_array_equal(bunch.target, y)
+
+
+# Known failure of PyPy for OpenML. See the following issue:
+# https://github.com/scikit-learn/scikit-learn/issues/18906
+@fails_if_pypy
+def test_fetch_openml_difference_parsers(monkeypatch):
+ """Check the difference between liac-arff and pandas parser."""
+ pytest.importorskip("modin.pandas")
+
+ data_id = 1119
+ _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response=True)
+ # When `as_frame=False`, the categories will be ordinally encoded with
+ # the liac-arff parser, while this is not the case with the pandas parser.
+ as_frame = False
+ bunch_liac_arff = fetch_openml(
+ data_id=data_id,
+ as_frame=as_frame,
+ cache=False,
+ parser="liac-arff",
+ )
+ bunch_pandas = fetch_openml(
+ data_id=data_id,
+ as_frame=as_frame,
+ cache=False,
+ parser="pandas",
+ )
+
+ assert bunch_liac_arff.data.dtype.kind == "f"
+ assert bunch_pandas.data.dtype == "O"
+
+
+###############################################################################
+# Test the ARFF parsing on several datasets to check that it detects the
+# correct types (categories, integers, floats).
+
+
+@pytest.fixture(scope="module")
+def datasets_column_names():
+ """Returns the columns names for each dataset."""
+ return {
+ 61: ["sepallength", "sepalwidth", "petallength", "petalwidth", "class"],
+ 2: [
+ "family",
+ "product-type",
+ "steel",
+ "carbon",
+ "hardness",
+ "temper_rolling",
+ "condition",
+ "formability",
+ "strength",
+ "non-ageing",
+ "surface-finish",
+ "surface-quality",
+ "enamelability",
+ "bc",
+ "bf",
+ "bt",
+ "bw%2Fme",
+ "bl",
+ "m",
+ "chrom",
+ "phos",
+ "cbond",
+ "marvi",
+ "exptl",
+ "ferro",
+ "corr",
+ "blue%2Fbright%2Fvarn%2Fclean",
+ "lustre",
+ "jurofm",
+ "s",
+ "p",
+ "shape",
+ "thick",
+ "width",
+ "len",
+ "oil",
+ "bore",
+ "packing",
+ "class",
+ ],
+ 561: ["vendor", "MYCT", "MMIN", "MMAX", "CACH", "CHMIN", "CHMAX", "class"],
+ 40589: [
+ "Mean_Acc1298_Mean_Mem40_Centroid",
+ "Mean_Acc1298_Mean_Mem40_Rolloff",
+ "Mean_Acc1298_Mean_Mem40_Flux",
+ "Mean_Acc1298_Mean_Mem40_MFCC_0",
+ "Mean_Acc1298_Mean_Mem40_MFCC_1",
+ "Mean_Acc1298_Mean_Mem40_MFCC_2",
+ "Mean_Acc1298_Mean_Mem40_MFCC_3",
+ "Mean_Acc1298_Mean_Mem40_MFCC_4",
+ "Mean_Acc1298_Mean_Mem40_MFCC_5",
+ "Mean_Acc1298_Mean_Mem40_MFCC_6",
+ "Mean_Acc1298_Mean_Mem40_MFCC_7",
+ "Mean_Acc1298_Mean_Mem40_MFCC_8",
+ "Mean_Acc1298_Mean_Mem40_MFCC_9",
+ "Mean_Acc1298_Mean_Mem40_MFCC_10",
+ "Mean_Acc1298_Mean_Mem40_MFCC_11",
+ "Mean_Acc1298_Mean_Mem40_MFCC_12",
+ "Mean_Acc1298_Std_Mem40_Centroid",
+ "Mean_Acc1298_Std_Mem40_Rolloff",
+ "Mean_Acc1298_Std_Mem40_Flux",
+ "Mean_Acc1298_Std_Mem40_MFCC_0",
+ "Mean_Acc1298_Std_Mem40_MFCC_1",
+ "Mean_Acc1298_Std_Mem40_MFCC_2",
+ "Mean_Acc1298_Std_Mem40_MFCC_3",
+ "Mean_Acc1298_Std_Mem40_MFCC_4",
+ "Mean_Acc1298_Std_Mem40_MFCC_5",
+ "Mean_Acc1298_Std_Mem40_MFCC_6",
+ "Mean_Acc1298_Std_Mem40_MFCC_7",
+ "Mean_Acc1298_Std_Mem40_MFCC_8",
+ "Mean_Acc1298_Std_Mem40_MFCC_9",
+ "Mean_Acc1298_Std_Mem40_MFCC_10",
+ "Mean_Acc1298_Std_Mem40_MFCC_11",
+ "Mean_Acc1298_Std_Mem40_MFCC_12",
+ "Std_Acc1298_Mean_Mem40_Centroid",
+ "Std_Acc1298_Mean_Mem40_Rolloff",
+ "Std_Acc1298_Mean_Mem40_Flux",
+ "Std_Acc1298_Mean_Mem40_MFCC_0",
+ "Std_Acc1298_Mean_Mem40_MFCC_1",
+ "Std_Acc1298_Mean_Mem40_MFCC_2",
+ "Std_Acc1298_Mean_Mem40_MFCC_3",
+ "Std_Acc1298_Mean_Mem40_MFCC_4",
+ "Std_Acc1298_Mean_Mem40_MFCC_5",
+ "Std_Acc1298_Mean_Mem40_MFCC_6",
+ "Std_Acc1298_Mean_Mem40_MFCC_7",
+ "Std_Acc1298_Mean_Mem40_MFCC_8",
+ "Std_Acc1298_Mean_Mem40_MFCC_9",
+ "Std_Acc1298_Mean_Mem40_MFCC_10",
+ "Std_Acc1298_Mean_Mem40_MFCC_11",
+ "Std_Acc1298_Mean_Mem40_MFCC_12",
+ "Std_Acc1298_Std_Mem40_Centroid",
+ "Std_Acc1298_Std_Mem40_Rolloff",
+ "Std_Acc1298_Std_Mem40_Flux",
+ "Std_Acc1298_Std_Mem40_MFCC_0",
+ "Std_Acc1298_Std_Mem40_MFCC_1",
+ "Std_Acc1298_Std_Mem40_MFCC_2",
+ "Std_Acc1298_Std_Mem40_MFCC_3",
+ "Std_Acc1298_Std_Mem40_MFCC_4",
+ "Std_Acc1298_Std_Mem40_MFCC_5",
+ "Std_Acc1298_Std_Mem40_MFCC_6",
+ "Std_Acc1298_Std_Mem40_MFCC_7",
+ "Std_Acc1298_Std_Mem40_MFCC_8",
+ "Std_Acc1298_Std_Mem40_MFCC_9",
+ "Std_Acc1298_Std_Mem40_MFCC_10",
+ "Std_Acc1298_Std_Mem40_MFCC_11",
+ "Std_Acc1298_Std_Mem40_MFCC_12",
+ "BH_LowPeakAmp",
+ "BH_LowPeakBPM",
+ "BH_HighPeakAmp",
+ "BH_HighPeakBPM",
+ "BH_HighLowRatio",
+ "BHSUM1",
+ "BHSUM2",
+ "BHSUM3",
+ "amazed.suprised",
+ "happy.pleased",
+ "relaxing.calm",
+ "quiet.still",
+ "sad.lonely",
+ "angry.aggresive",
+ ],
+ 1119: [
+ "age",
+ "workclass",
+ "fnlwgt:",
+ "education:",
+ "education-num:",
+ "marital-status:",
+ "occupation:",
+ "relationship:",
+ "race:",
+ "sex:",
+ "capital-gain:",
+ "capital-loss:",
+ "hours-per-week:",
+ "native-country:",
+ "class",
+ ],
+ 40966: [
+ "DYRK1A_N",
+ "ITSN1_N",
+ "BDNF_N",
+ "NR1_N",
+ "NR2A_N",
+ "pAKT_N",
+ "pBRAF_N",
+ "pCAMKII_N",
+ "pCREB_N",
+ "pELK_N",
+ "pERK_N",
+ "pJNK_N",
+ "PKCA_N",
+ "pMEK_N",
+ "pNR1_N",
+ "pNR2A_N",
+ "pNR2B_N",
+ "pPKCAB_N",
+ "pRSK_N",
+ "AKT_N",
+ "BRAF_N",
+ "CAMKII_N",
+ "CREB_N",
+ "ELK_N",
+ "ERK_N",
+ "GSK3B_N",
+ "JNK_N",
+ "MEK_N",
+ "TRKA_N",
+ "RSK_N",
+ "APP_N",
+ "Bcatenin_N",
+ "SOD1_N",
+ "MTOR_N",
+ "P38_N",
+ "pMTOR_N",
+ "DSCR1_N",
+ "AMPKA_N",
+ "NR2B_N",
+ "pNUMB_N",
+ "RAPTOR_N",
+ "TIAM1_N",
+ "pP70S6_N",
+ "NUMB_N",
+ "P70S6_N",
+ "pGSK3B_N",
+ "pPKCG_N",
+ "CDK5_N",
+ "S6_N",
+ "ADARB1_N",
+ "AcetylH3K9_N",
+ "RRP1_N",
+ "BAX_N",
+ "ARC_N",
+ "ERBB4_N",
+ "nNOS_N",
+ "Tau_N",
+ "GFAP_N",
+ "GluR3_N",
+ "GluR4_N",
+ "IL1B_N",
+ "P3525_N",
+ "pCASP9_N",
+ "PSD95_N",
+ "SNCA_N",
+ "Ubiquitin_N",
+ "pGSK3B_Tyr216_N",
+ "SHH_N",
+ "BAD_N",
+ "BCL2_N",
+ "pS6_N",
+ "pCFOS_N",
+ "SYP_N",
+ "H3AcK18_N",
+ "EGR1_N",
+ "H3MeK4_N",
+ "CaNA_N",
+ "class",
+ ],
+ 40945: [
+ "pclass",
+ "survived",
+ "name",
+ "sex",
+ "age",
+ "sibsp",
+ "parch",
+ "ticket",
+ "fare",
+ "cabin",
+ "embarked",
+ "boat",
+ "body",
+ "home.dest",
+ ],
+ }
+
+
+@pytest.fixture(scope="module")
+def datasets_missing_values():
+ return {
+ 61: {},
+ 2: {
+ "family": 11,
+ "temper_rolling": 9,
+ "condition": 2,
+ "formability": 4,
+ "non-ageing": 10,
+ "surface-finish": 11,
+ "enamelability": 11,
+ "bc": 11,
+ "bf": 10,
+ "bt": 11,
+ "bw%2Fme": 8,
+ "bl": 9,
+ "m": 11,
+ "chrom": 11,
+ "phos": 11,
+ "cbond": 10,
+ "marvi": 11,
+ "exptl": 11,
+ "ferro": 11,
+ "corr": 11,
+ "blue%2Fbright%2Fvarn%2Fclean": 11,
+ "lustre": 8,
+ "jurofm": 11,
+ "s": 11,
+ "p": 11,
+ "oil": 10,
+ "packing": 11,
+ },
+ 561: {},
+ 40589: {},
+ 1119: {},
+ 40966: {"BCL2_N": 7},
+ 40945: {
+ "age": 263,
+ "fare": 1,
+ "cabin": 1014,
+ "embarked": 2,
+ "boat": 823,
+ "body": 1188,
+ "home.dest": 564,
+ },
+ }
+
+
+# Known failure of PyPy for OpenML. See the following issue:
+# https://github.com/scikit-learn/scikit-learn/issues/18906
+@fails_if_pypy
+@pytest.mark.parametrize(
+ "data_id, parser, expected_n_categories, expected_n_floats, expected_n_ints",
+ [
+ # iris dataset
+ (61, "liac-arff", 1, 4, 0),
+ (61, "pandas", 1, 4, 0),
+ # anneal dataset
+ (2, "liac-arff", 33, 6, 0),
+ (2, "pandas", 33, 2, 4),
+ # cpu dataset
+ (561, "liac-arff", 1, 7, 0),
+ (561, "pandas", 1, 0, 7),
+ # emotions dataset
+ (40589, "liac-arff", 6, 72, 0),
+ (40589, "pandas", 6, 69, 3),
+ # adult-census dataset
+ (1119, "liac-arff", 9, 6, 0),
+ (1119, "pandas", 9, 0, 6),
+ # miceprotein
+ # 1 column has only missing values with object dtype
+ (40966, "liac-arff", 1, 76, 0),
+ # with casting it will be transformed to either float or Int64
+ (40966, "pandas", 1, 77, 0),
+ # titanic
+ (40945, "liac-arff", 3, 5, 0),
+ (40945, "pandas", 3, 3, 3),
+ ],
+)
+@pytest.mark.parametrize("gzip_response", [True, False])
+def test_fetch_openml_types_inference(
+ monkeypatch,
+ data_id,
+ parser,
+ expected_n_categories,
+ expected_n_floats,
+ expected_n_ints,
+ gzip_response,
+ datasets_column_names,
+ datasets_missing_values,
+):
+ """Check that `fetch_openml` infer the right number of categories, integers, and
+ floats."""
+ pd = pytest.importorskip("modin.pandas")
+ CategoricalDtype = pd.api.types.CategoricalDtype
+
+ _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response=gzip_response)
+
+ bunch = fetch_openml(
+ data_id=data_id,
+ as_frame=True,
+ cache=False,
+ parser=parser,
+ )
+ frame = bunch.frame
+
+ n_categories = len(
+ [dtype for dtype in frame.dtypes if isinstance(dtype, CategoricalDtype)]
+ )
+ n_floats = len([dtype for dtype in frame.dtypes if dtype.kind == "f"])
+ n_ints = len([dtype for dtype in frame.dtypes if dtype.kind == "i"])
+
+ assert n_categories == expected_n_categories
+ assert n_floats == expected_n_floats
+ assert n_ints == expected_n_ints
+
+ assert frame.columns.tolist() == datasets_column_names[data_id]
+
+ frame_feature_to_n_nan = frame.isna().sum().to_dict()
+ for name, n_missing in frame_feature_to_n_nan.items():
+ expected_missing = datasets_missing_values[data_id].get(name, 0)
+ assert n_missing == expected_missing
+
+
+###############################################################################
+# Test some more specific behaviour
+
+
+# TODO(1.4): remove this filterwarning decorator
+@pytest.mark.filterwarnings("ignore:The default value of `parser` will change")
+@pytest.mark.parametrize(
+ "params, err_msg",
+ [
+ ({"parser": "unknown"}, "`parser` must be one of"),
+ ({"as_frame": "unknown"}, "`as_frame` must be one of"),
+ ],
+)
+def test_fetch_openml_validation_parameter(monkeypatch, params, err_msg):
+ data_id = 1119
+ _monkey_patch_webbased_functions(monkeypatch, data_id, True)
+ with pytest.raises(ValueError, match=err_msg):
+ fetch_openml(data_id=data_id, **params)
+
+
+@pytest.mark.parametrize(
+ "params",
+ [
+ {"as_frame": True, "parser": "auto"},
+ {"as_frame": "auto", "parser": "auto"},
+ {"as_frame": False, "parser": "pandas"},
+ ],
+)
+def test_fetch_openml_requires_pandas_error(monkeypatch, params):
+ """Check that we raise the proper errors when we require pandas."""
+ data_id = 1119
+ try:
+ check_pandas_support("test_fetch_openml_requires_pandas")
+ except ImportError:
+ _monkey_patch_webbased_functions(monkeypatch, data_id, True)
+ err_msg = "requires pandas to be installed. Alternatively, explicitely"
+ with pytest.raises(ImportError, match=err_msg):
+ fetch_openml(data_id=data_id, **params)
+ else:
+ raise SkipTest("This test requires pandas to not be installed.")
+
+
+# TODO(1.4): move this parameter option into `test_fetch_openml_requires_pandas_error`
+def test_fetch_openml_requires_pandas_in_future(monkeypatch):
+ """Check that we raise a warning that pandas will be required in the future."""
+ params = {"as_frame": False, "parser": "auto"}
+ data_id = 1119
+ try:
+ check_pandas_support("test_fetch_openml_requires_pandas")
+ except ImportError:
+ _monkey_patch_webbased_functions(monkeypatch, data_id, True)
+ warn_msg = (
+ "From version 1.4, `parser='auto'` with `as_frame=False` will use pandas"
+ )
+ with pytest.warns(FutureWarning, match=warn_msg):
+ fetch_openml(data_id=data_id, **params)
+ else:
+ raise SkipTest("This test requires pandas to not be installed.")
+
+
+@pytest.mark.filterwarnings("ignore:Version 1 of dataset Australian is inactive")
+# TODO(1.4): remove this filterwarning decorator for `parser`
+@pytest.mark.filterwarnings("ignore:The default value of `parser` will change")
+@pytest.mark.parametrize(
+ "params, err_msg",
+ [
+ (
+ {"parser": "pandas"},
+ "Sparse ARFF datasets cannot be loaded with parser='pandas'",
+ ),
+ (
+ {"as_frame": True},
+ "Sparse ARFF datasets cannot be loaded with as_frame=True.",
+ ),
+ (
+ {"parser": "pandas", "as_frame": True},
+ "Sparse ARFF datasets cannot be loaded with as_frame=True.",
+ ),
+ ],
+)
+def test_fetch_openml_sparse_arff_error(monkeypatch, params, err_msg):
+ """Check that we raise the expected error for sparse ARFF datasets and
+ a wrong set of incompatible parameters.
+ """
+ pytest.importorskip("modin.pandas")
+ data_id = 292
+
+ _monkey_patch_webbased_functions(monkeypatch, data_id, True)
+ with pytest.raises(ValueError, match=err_msg):
+ fetch_openml(
+ data_id=data_id,
+ cache=False,
+ **params,
+ )
+
+
+# Known failure of PyPy for OpenML. See the following issue:
+# https://github.com/scikit-learn/scikit-learn/issues/18906
+@fails_if_pypy
+@pytest.mark.skip(reason="Failing test")
+@pytest.mark.filterwarnings("ignore:Version 1 of dataset Australian is inactive")
+@pytest.mark.parametrize(
+ "data_id, data_type",
+ [
+ (61, "dataframe"), # iris dataset version 1
+ (292, "sparse"), # Australian dataset version 1
+ ],
+)
+def test_fetch_openml_auto_mode(monkeypatch, data_id, data_type):
+ """Check the auto mode of `fetch_openml`."""
+ pd = pytest.importorskip("modin.pandas")
+
+ _monkey_patch_webbased_functions(monkeypatch, data_id, True)
+ data = fetch_openml(data_id=data_id, as_frame="auto", parser="auto", cache=False)
+ klass = pd.DataFrame if data_type == "dataframe" else scipy.sparse.csr_matrix
+ assert isinstance(data.data, klass)
+
+
+# Known failure of PyPy for OpenML. See the following issue:
+# https://github.com/scikit-learn/scikit-learn/issues/18906
+@fails_if_pypy
+def test_convert_arff_data_dataframe_warning_low_memory_pandas(monkeypatch):
+ """Check that we raise a warning regarding the working memory when using
+ LIAC-ARFF parser."""
+ pytest.importorskip("modin.pandas")
+
+ data_id = 1119
+ _monkey_patch_webbased_functions(monkeypatch, data_id, True)
+
+ msg = "Could not adhere to working_memory config."
+ with pytest.warns(UserWarning, match=msg):
+ with config_context(working_memory=1e-6):
+ fetch_openml(
+ data_id=data_id,
+ as_frame=True,
+ cache=False,
+ parser="liac-arff",
+ )
+
+
+@pytest.mark.parametrize("gzip_response", [True, False])
+def test_fetch_openml_iris_warn_multiple_version(monkeypatch, gzip_response):
+ """Check that a warning is raised when multiple versions exist and no version is
+ requested."""
+ data_id = 61
+ data_name = "iris"
+
+ _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
+
+ msg = (
+ "Multiple active versions of the dataset matching the name"
+ " iris exist. Versions may be fundamentally different, "
+ "returning version 1."
+ )
+ with pytest.warns(UserWarning, match=msg):
+ fetch_openml(
+ name=data_name,
+ as_frame=False,
+ cache=False,
+ parser="liac-arff",
+ )
+
+
+@pytest.mark.parametrize("gzip_response", [True, False])
+def test_fetch_openml_no_target(monkeypatch, gzip_response):
+ """Check that we can get a dataset without target."""
+ data_id = 61
+ target_column = None
+ expected_observations = 150
+ expected_features = 5
+
+ _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
+ data = fetch_openml(
+ data_id=data_id,
+ target_column=target_column,
+ cache=False,
+ as_frame=False,
+ parser="liac-arff",
+ )
+ assert data.data.shape == (expected_observations, expected_features)
+ assert data.target is None
+
+
+@pytest.mark.parametrize("gzip_response", [True, False])
+@pytest.mark.parametrize("parser", ["liac-arff", "pandas"])
+def test_missing_values_pandas(monkeypatch, gzip_response, parser):
+ """check that missing values in categories are compatible with pandas
+ categorical"""
+ pytest.importorskip("pandas")
+
+ data_id = 42585
+ _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response=gzip_response)
+ penguins = fetch_openml(
+ data_id=data_id,
+ cache=False,
+ as_frame=True,
+ parser=parser,
+ )
+
+ cat_dtype = penguins.data.dtypes["sex"]
+ # there are nans in the categorical
+ assert penguins.data["sex"].isna().any()
+ assert_array_equal(cat_dtype.categories, ["FEMALE", "MALE", "_"])
+
+
+@pytest.mark.parametrize("gzip_response", [True, False])
+@pytest.mark.parametrize(
+ "dataset_params",
+ [
+ {"data_id": 40675},
+ {"data_id": None, "name": "glass2", "version": 1},
+ ],
+)
+def test_fetch_openml_inactive(monkeypatch, gzip_response, dataset_params):
+ """Check that we raise a warning when the dataset is inactive."""
+ data_id = 40675
+ _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
+ msg = "Version 1 of dataset glass2 is inactive,"
+ with pytest.warns(UserWarning, match=msg):
+ glass2 = fetch_openml(
+ cache=False, as_frame=False, parser="liac-arff", **dataset_params
+ )
+ assert glass2.data.shape == (163, 9)
+ assert glass2.details["id"] == "40675"
+
+
+@pytest.mark.parametrize("gzip_response", [True, False])
+@pytest.mark.parametrize(
+ "data_id, params, err_type, err_msg",
+ [
+ (40675, {"name": "glass2"}, ValueError, "No active dataset glass2 found"),
+ (
+ 61,
+ {"data_id": 61, "target_column": ["sepalwidth", "class"]},
+ ValueError,
+ "Can only handle homogeneous multi-target datasets",
+ ),
+ (
+ 40945,
+ {"data_id": 40945, "as_frame": False},
+ ValueError,
+ "STRING attributes are not supported for array representation. Try"
+ " as_frame=True",
+ ),
+ (
+ 2,
+ {"data_id": 2, "target_column": "family", "as_frame": True},
+ ValueError,
+ "Target column 'family'",
+ ),
+ (
+ 2,
+ {"data_id": 2, "target_column": "family", "as_frame": False},
+ ValueError,
+ "Target column 'family'",
+ ),
+ (
+ 61,
+ {"data_id": 61, "target_column": "undefined"},
+ KeyError,
+ "Could not find target_column='undefined'",
+ ),
+ (
+ 61,
+ {"data_id": 61, "target_column": ["undefined", "class"]},
+ KeyError,
+ "Could not find target_column='undefined'",
+ ),
+ ],
+)
+@pytest.mark.parametrize("parser", ["liac-arff", "pandas"])
+def test_fetch_openml_error(
+ monkeypatch, gzip_response, data_id, params, err_type, err_msg, parser
+):
+ _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
+ if params.get("as_frame", True) or parser == "pandas":
+ pytest.importorskip("pandas")
+ with pytest.raises(err_type, match=err_msg):
+ fetch_openml(cache=False, parser=parser, **params)
+
+
+@pytest.mark.parametrize(
+ "params, err_type, err_msg",
+ [
+ (
+ {"data_id": -1, "name": None, "version": "version"},
+ ValueError,
+ "Dataset data_id=-1 and version=version passed, but you can only",
+ ),
+ (
+ {"data_id": -1, "name": "nAmE"},
+ ValueError,
+ "Dataset data_id=-1 and name=name passed, but you can only",
+ ),
+ (
+ {"data_id": -1, "name": "nAmE", "version": "version"},
+ ValueError,
+ "Dataset data_id=-1 and name=name passed, but you can only",
+ ),
+ (
+ {},
+ ValueError,
+ "Neither name nor data_id are provided. Please provide name or data_id.",
+ ),
+ ],
+)
+def test_fetch_openml_raises_illegal_argument(params, err_type, err_msg):
+ with pytest.raises(err_type, match=err_msg):
+ fetch_openml(**params)
+
+
+@pytest.mark.parametrize("gzip_response", [True, False])
+def test_warn_ignore_attribute(monkeypatch, gzip_response):
+ data_id = 40966
+ expected_row_id_msg = "target_column='{}' has flag is_row_identifier."
+ expected_ignore_msg = "target_column='{}' has flag is_ignore."
+ _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
+ # single column test
+ target_col = "MouseID"
+ msg = expected_row_id_msg.format(target_col)
+ with pytest.warns(UserWarning, match=msg):
+ fetch_openml(
+ data_id=data_id,
+ target_column=target_col,
+ cache=False,
+ as_frame=False,
+ parser="liac-arff",
+ )
+ target_col = "Genotype"
+ msg = expected_ignore_msg.format(target_col)
+ with pytest.warns(UserWarning, match=msg):
+ fetch_openml(
+ data_id=data_id,
+ target_column=target_col,
+ cache=False,
+ as_frame=False,
+ parser="liac-arff",
+ )
+ # multi column test
+ target_col = "MouseID"
+ msg = expected_row_id_msg.format(target_col)
+ with pytest.warns(UserWarning, match=msg):
+ fetch_openml(
+ data_id=data_id,
+ target_column=[target_col, "class"],
+ cache=False,
+ as_frame=False,
+ parser="liac-arff",
+ )
+ target_col = "Genotype"
+ msg = expected_ignore_msg.format(target_col)
+ with pytest.warns(UserWarning, match=msg):
+ fetch_openml(
+ data_id=data_id,
+ target_column=[target_col, "class"],
+ cache=False,
+ as_frame=False,
+ parser="liac-arff",
+ )
+
+
+@pytest.mark.parametrize("gzip_response", [True, False])
+def test_dataset_with_openml_error(monkeypatch, gzip_response):
+ data_id = 1
+ _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
+ msg = "OpenML registered a problem with the dataset. It might be unusable. Error:"
+ with pytest.warns(UserWarning, match=msg):
+ fetch_openml(data_id=data_id, cache=False, as_frame=False, parser="liac-arff")
+
+
+@pytest.mark.parametrize("gzip_response", [True, False])
+def test_dataset_with_openml_warning(monkeypatch, gzip_response):
+ data_id = 3
+ _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
+ msg = "OpenML raised a warning on the dataset. It might be unusable. Warning:"
+ with pytest.warns(UserWarning, match=msg):
+ fetch_openml(data_id=data_id, cache=False, as_frame=False, parser="liac-arff")
+
+
+###############################################################################
+# Test cache, retry mechanisms, checksum, etc.
+
+
+@pytest.mark.parametrize("gzip_response", [True, False])
+def test_open_openml_url_cache(monkeypatch, gzip_response, tmpdir):
+ data_id = 61
+
+ _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
+ openml_path = sklearn.datasets._openml._DATA_FILE.format(data_id)
+ cache_directory = str(tmpdir.mkdir("scikit_learn_data"))
+ # first fill the cache
+ response1 = _open_openml_url(openml_path, cache_directory)
+ # assert file exists
+ location = _get_local_path(openml_path, cache_directory)
+ assert os.path.isfile(location)
+ # redownload, to utilize cache
+ response2 = _open_openml_url(openml_path, cache_directory)
+ assert response1.read() == response2.read()
+
+
+@pytest.mark.parametrize("write_to_disk", [True, False])
+def test_open_openml_url_unlinks_local_path(monkeypatch, tmpdir, write_to_disk):
+ data_id = 61
+ openml_path = sklearn.datasets._openml._DATA_FILE.format(data_id)
+ cache_directory = str(tmpdir.mkdir("scikit_learn_data"))
+ location = _get_local_path(openml_path, cache_directory)
+
+ def _mock_urlopen(request, *args, **kwargs):
+ if write_to_disk:
+ with open(location, "w") as f:
+ f.write("")
+ raise ValueError("Invalid request")
+
+ monkeypatch.setattr(sklearn.datasets._openml, "urlopen", _mock_urlopen)
+
+ with pytest.raises(ValueError, match="Invalid request"):
+ _open_openml_url(openml_path, cache_directory)
+
+ assert not os.path.exists(location)
+
+
+def test_retry_with_clean_cache(tmpdir):
+ data_id = 61
+ openml_path = sklearn.datasets._openml._DATA_FILE.format(data_id)
+ cache_directory = str(tmpdir.mkdir("scikit_learn_data"))
+ location = _get_local_path(openml_path, cache_directory)
+ os.makedirs(os.path.dirname(location))
+
+ with open(location, "w") as f:
+ f.write("")
+
+ @_retry_with_clean_cache(openml_path, cache_directory)
+ def _load_data():
+ # The first call will raise an error since location exists
+ if os.path.exists(location):
+ raise Exception("File exist!")
+ return 1
+
+ warn_msg = "Invalid cache, redownloading file"
+ with pytest.warns(RuntimeWarning, match=warn_msg):
+ result = _load_data()
+ assert result == 1
+
+
+def test_retry_with_clean_cache_http_error(tmpdir):
+ data_id = 61
+ openml_path = sklearn.datasets._openml._DATA_FILE.format(data_id)
+ cache_directory = str(tmpdir.mkdir("scikit_learn_data"))
+
+ @_retry_with_clean_cache(openml_path, cache_directory)
+ def _load_data():
+ raise HTTPError(
+ url=None, code=412, msg="Simulated mock error", hdrs=None, fp=None
+ )
+
+ error_msg = "Simulated mock error"
+ with pytest.raises(HTTPError, match=error_msg):
+ _load_data()
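+
+
+# The two tests above pin down the retry-on-invalid-cache behaviour: a generic
+# exception triggers an "Invalid cache, redownloading file" RuntimeWarning, the
+# cached file is removed, and the wrapped function is called a second time,
+# while network errors (HTTPError is a subclass of URLError) propagate
+# unchanged. The helper below is a minimal illustrative sketch of that pattern,
+# written only to make those expectations readable; it is not sklearn's
+# `_retry_with_clean_cache`, which (as used above) takes the OpenML path and
+# the cache directory rather than a precomputed local path.
+def _retry_with_clean_cache_sketch(local_path):
+    import os
+    import warnings
+    from functools import wraps
+    from urllib.error import URLError
+
+    def decorator(f):
+        @wraps(f)
+        def wrapper(*args, **kwargs):
+            try:
+                return f(*args, **kwargs)
+            except URLError:
+                # Network problems (including HTTPError) are not a cache
+                # issue, so re-raise them as-is.
+                raise
+            except Exception:
+                # Anything else is treated as a corrupted cache entry:
+                # warn, drop the cached file, and retry once.
+                warnings.warn("Invalid cache, redownloading file", RuntimeWarning)
+                if os.path.exists(local_path):
+                    os.unlink(local_path)
+                return f(*args, **kwargs)
+
+        return wrapper
+
+    return decorator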
+
+
+@pytest.mark.parametrize("gzip_response", [True, False])
+def test_fetch_openml_cache(monkeypatch, gzip_response, tmpdir):
+ def _mock_urlopen_raise(request, *args, **kwargs):
+ raise ValueError(
+ "This mechanism intends to test correct cache"
+ "handling. As such, urlopen should never be "
+ "accessed. URL: %s" % request.get_full_url()
+ )
+
+ data_id = 61
+ cache_directory = str(tmpdir.mkdir("scikit_learn_data"))
+ _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
+ X_fetched, y_fetched = fetch_openml(
+ data_id=data_id,
+ cache=True,
+ data_home=cache_directory,
+ return_X_y=True,
+ as_frame=False,
+ parser="liac-arff",
+ )
+
+ monkeypatch.setattr(sklearn.datasets._openml, "urlopen", _mock_urlopen_raise)
+
+ X_cached, y_cached = fetch_openml(
+ data_id=data_id,
+ cache=True,
+ data_home=cache_directory,
+ return_X_y=True,
+ as_frame=False,
+ parser="liac-arff",
+ )
+ np.testing.assert_array_equal(X_fetched, X_cached)
+ np.testing.assert_array_equal(y_fetched, y_cached)
+
+
+# Known failure of PyPy for OpenML. See the following issue:
+# https://github.com/scikit-learn/scikit-learn/issues/18906
+@fails_if_pypy
+@pytest.mark.parametrize(
+ "as_frame, parser",
+ [
+ (True, "liac-arff"),
+ (False, "liac-arff"),
+ (True, "pandas"),
+ (False, "pandas"),
+ ],
+)
+def test_fetch_openml_verify_checksum(monkeypatch, as_frame, cache, tmpdir, parser):
+ """Check that the checksum is working as expected."""
+ if as_frame or parser == "pandas":
+ pytest.importorskip("pandas")
+
+ data_id = 2
+ _monkey_patch_webbased_functions(monkeypatch, data_id, True)
+
+ # create a temporary modified arff file
+ original_data_module = OPENML_TEST_DATA_MODULE + "." + f"id_{data_id}"
+ original_data_file_name = "data-v1-dl-1666876.arff.gz"
+ corrupt_copy_path = tmpdir / "test_invalid_checksum.arff"
+ with _open_binary(original_data_module, original_data_file_name) as orig_file:
+ orig_gzip = gzip.open(orig_file, "rb")
+ data = bytearray(orig_gzip.read())
+ data[len(data) - 1] = 37
+
+ with gzip.GzipFile(corrupt_copy_path, "wb") as modified_gzip:
+ modified_gzip.write(data)
+
+ # Requests are already mocked by monkey_patch_webbased_functions.
+ # We want to re-use that mock for all requests except file download,
+ # hence creating a thin mock over the original mock
+ mocked_openml_url = sklearn.datasets._openml.urlopen
+
+ def swap_file_mock(request, *args, **kwargs):
+ url = request.get_full_url()
+ if url.endswith("data/v1/download/1666876"):
+ with open(corrupt_copy_path, "rb") as f:
+ corrupted_data = f.read()
+ return _MockHTTPResponse(BytesIO(corrupted_data), is_gzip=True)
+ else:
+ return mocked_openml_url(request)
+
+ monkeypatch.setattr(sklearn.datasets._openml, "urlopen", swap_file_mock)
+
+ # validate failed checksum
+ with pytest.raises(ValueError) as exc:
+ sklearn.datasets.fetch_openml(
+ data_id=data_id, cache=False, as_frame=as_frame, parser=parser
+ )
+    # The exception message should mention the corrupted file id.
+ assert exc.match("1666876")
+
+
+def test_open_openml_url_retry_on_network_error(monkeypatch):
+ def _mock_urlopen_network_error(request, *args, **kwargs):
+ raise HTTPError("", 404, "Simulated network error", None, None)
+
+ monkeypatch.setattr(
+ sklearn.datasets._openml, "urlopen", _mock_urlopen_network_error
+ )
+
+ invalid_openml_url = "invalid-url"
+
+ with pytest.warns(
+ UserWarning,
+ match=re.escape(
+ "A network error occurred while downloading"
+ f" {_OPENML_PREFIX + invalid_openml_url}. Retrying..."
+ ),
+ ) as record:
+ with pytest.raises(HTTPError, match="Simulated network error"):
+ _open_openml_url(invalid_openml_url, None, delay=0)
+ assert len(record) == 3
+
+
+###############################################################################
+# Non-regression tests
+
+
+@pytest.mark.parametrize("gzip_response", [True, False])
+@pytest.mark.parametrize("parser", ("liac-arff", "pandas"))
+def test_fetch_openml_with_ignored_feature(monkeypatch, gzip_response, parser):
+ """Check that we can load the "zoo" dataset.
+ Non-regression test for:
+ https://github.com/scikit-learn/scikit-learn/issues/14340
+ """
+ if parser == "pandas":
+ pytest.importorskip("pandas")
+ data_id = 62
+ _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
+
+ dataset = sklearn.datasets.fetch_openml(
+ data_id=data_id, cache=False, as_frame=False, parser=parser
+ )
+ assert dataset is not None
+ # The dataset has 17 features, including 1 ignored (animal),
+ # so we assert that we don't have the ignored feature in the final Bunch
+ assert dataset["data"].shape == (101, 16)
+ assert "animal" not in dataset["feature_names"]
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_fetch_openml_strip_quotes(monkeypatch):
+ """Check that we strip the single quotes when used as a string delimiter.
+
+ Non-regression test for:
+ https://github.com/scikit-learn/scikit-learn/issues/23381
+ """
+ pd = pytest.importorskip("modin.pandas")
+ data_id = 40966
+ _monkey_patch_webbased_functions(monkeypatch, data_id=data_id, gzip_response=False)
+
+ common_params = {"as_frame": True, "cache": False, "data_id": data_id}
+ mice_pandas = fetch_openml(parser="pandas", **common_params)
+ mice_liac_arff = fetch_openml(parser="liac-arff", **common_params)
+ pd.testing.assert_series_equal(mice_pandas.target, mice_liac_arff.target)
+ assert not mice_pandas.target.str.startswith("'").any()
+ assert not mice_pandas.target.str.endswith("'").any()
+
+ # similar behaviour should be observed when the column is not the target
+ mice_pandas = fetch_openml(parser="pandas", target_column="NUMB_N", **common_params)
+ mice_liac_arff = fetch_openml(
+ parser="liac-arff", target_column="NUMB_N", **common_params
+ )
+ pd.testing.assert_series_equal(
+ mice_pandas.frame["class"], mice_liac_arff.frame["class"]
+ )
+ assert not mice_pandas.frame["class"].str.startswith("'").any()
+ assert not mice_pandas.frame["class"].str.endswith("'").any()
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_fetch_openml_leading_whitespace(monkeypatch):
+ """Check that we can strip leading whitespace in pandas parser.
+
+ Non-regression test for:
+ https://github.com/scikit-learn/scikit-learn/issues/25311
+ """
+ pd = pytest.importorskip("modin.pandas")
+ data_id = 1590
+ _monkey_patch_webbased_functions(monkeypatch, data_id=data_id, gzip_response=False)
+
+ common_params = {"as_frame": True, "cache": False, "data_id": data_id}
+ adult_pandas = fetch_openml(parser="pandas", **common_params)
+ adult_liac_arff = fetch_openml(parser="liac-arff", **common_params)
+ pd.testing.assert_series_equal(
+ adult_pandas.frame["class"], adult_liac_arff.frame["class"]
+ )
+
+
+###############################################################################
+# Deprecation-changed parameters
+
+
+# TODO(1.4): remove this test
+def test_fetch_openml_deprecation_parser(monkeypatch):
+ """Check that we raise a deprecation warning for parser parameter."""
+ pytest.importorskip("modin.pandas")
+ data_id = 61
+ _monkey_patch_webbased_functions(monkeypatch, data_id=data_id, gzip_response=False)
+
+ with pytest.warns(FutureWarning, match="The default value of `parser` will change"):
+ sklearn.datasets.fetch_openml(data_id=data_id)
diff --git a/modin/pandas/test/interoperability/sklearn/ensemble/_hist_gradient_boosting/test_gradient_boosting.py b/modin/pandas/test/interoperability/sklearn/ensemble/_hist_gradient_boosting/test_gradient_boosting.py
new file mode 100644
index 00000000000..a778fbfb360
--- /dev/null
+++ b/modin/pandas/test/interoperability/sklearn/ensemble/_hist_gradient_boosting/test_gradient_boosting.py
@@ -0,0 +1,1396 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+# flake8: noqa
+
+import warnings
+import re
+import numpy as np
+import pytest
+from numpy.testing import assert_allclose, assert_array_equal
+from sklearn._loss.loss import (
+ AbsoluteError,
+ HalfBinomialLoss,
+ HalfSquaredError,
+ PinballLoss,
+)
+from sklearn.datasets import make_classification, make_regression
+from sklearn.datasets import make_low_rank_matrix
+from sklearn.preprocessing import KBinsDiscretizer, MinMaxScaler, OneHotEncoder
+from sklearn.model_selection import train_test_split, cross_val_score
+from sklearn.base import clone, BaseEstimator, TransformerMixin
+from sklearn.base import is_regressor
+from sklearn.pipeline import make_pipeline
+from sklearn.metrics import mean_poisson_deviance
+from sklearn.dummy import DummyRegressor
+from sklearn.exceptions import NotFittedError
+from sklearn.compose import make_column_transformer
+from sklearn.ensemble import HistGradientBoostingRegressor
+from sklearn.ensemble import HistGradientBoostingClassifier
+from sklearn.ensemble._hist_gradient_boosting.grower import TreeGrower
+from sklearn.ensemble._hist_gradient_boosting.binning import _BinMapper
+from sklearn.ensemble._hist_gradient_boosting.common import G_H_DTYPE
+from sklearn.utils import shuffle
+from sklearn.utils._openmp_helpers import _openmp_effective_n_threads
+
+
+n_threads = _openmp_effective_n_threads()
+
+X_classification, y_classification = make_classification(random_state=0)
+X_regression, y_regression = make_regression(random_state=0)
+X_multi_classification, y_multi_classification = make_classification(
+ n_classes=3, n_informative=3, random_state=0
+)
+
+
+def _make_dumb_dataset(n_samples):
+ """Make a dumb dataset to test early stopping."""
+ rng = np.random.RandomState(42)
+ X_dumb = rng.randn(n_samples, 1)
+ y_dumb = (X_dumb[:, 0] > 0).astype("int64")
+ return X_dumb, y_dumb
+
+
+@pytest.mark.parametrize(
+ "GradientBoosting, X, y",
+ [
+ (HistGradientBoostingClassifier, X_classification, y_classification),
+ (HistGradientBoostingRegressor, X_regression, y_regression),
+ ],
+)
+@pytest.mark.parametrize(
+ "params, err_msg",
+ [
+ (
+ {"interaction_cst": [0, 1]},
+ "Interaction constraints must be a sequence of tuples or lists",
+ ),
+ (
+ {"interaction_cst": [{0, 9999}]},
+ r"Interaction constraints must consist of integer indices in \[0,"
+ r" n_features - 1\] = \[.*\], specifying the position of features,",
+ ),
+ (
+ {"interaction_cst": [{-1, 0}]},
+ r"Interaction constraints must consist of integer indices in \[0,"
+ r" n_features - 1\] = \[.*\], specifying the position of features,",
+ ),
+ (
+ {"interaction_cst": [{0.5}]},
+ r"Interaction constraints must consist of integer indices in \[0,"
+ r" n_features - 1\] = \[.*\], specifying the position of features,",
+ ),
+ ],
+)
+def test_init_parameters_validation(GradientBoosting, X, y, params, err_msg):
+ with pytest.raises(ValueError, match=err_msg):
+ GradientBoosting(**params).fit(X, y)
+
+
+# TODO(1.3): remove
+@pytest.mark.filterwarnings("ignore::FutureWarning")
+def test_invalid_classification_loss():
+ binary_clf = HistGradientBoostingClassifier(loss="binary_crossentropy")
+ err_msg = (
+ "loss='binary_crossentropy' is not defined for multiclass "
+ "classification with n_classes=3, use "
+ "loss='log_loss' instead"
+ )
+ with pytest.raises(ValueError, match=err_msg):
+ binary_clf.fit(np.zeros(shape=(3, 2)), np.arange(3))
+
+
+@pytest.mark.parametrize(
+ "scoring, validation_fraction, early_stopping, n_iter_no_change, tol",
+ [
+ ("neg_mean_squared_error", 0.1, True, 5, 1e-7), # use scorer
+ ("neg_mean_squared_error", None, True, 5, 1e-1), # use scorer on train
+ (None, 0.1, True, 5, 1e-7), # same with default scorer
+ (None, None, True, 5, 1e-1),
+ ("loss", 0.1, True, 5, 1e-7), # use loss
+ ("loss", None, True, 5, 1e-1), # use loss on training data
+ (None, None, False, 5, 0.0), # no early stopping
+ ],
+)
+def test_early_stopping_regression(
+ scoring, validation_fraction, early_stopping, n_iter_no_change, tol
+):
+ max_iter = 200
+
+ X, y = make_regression(n_samples=50, random_state=0)
+
+ gb = HistGradientBoostingRegressor(
+ verbose=1, # just for coverage
+ min_samples_leaf=5, # easier to overfit fast
+ scoring=scoring,
+ tol=tol,
+ early_stopping=early_stopping,
+ validation_fraction=validation_fraction,
+ max_iter=max_iter,
+ n_iter_no_change=n_iter_no_change,
+ random_state=0,
+ )
+ gb.fit(X, y)
+
+ if early_stopping:
+ assert n_iter_no_change <= gb.n_iter_ < max_iter
+ else:
+ assert gb.n_iter_ == max_iter
+
+
+@pytest.mark.parametrize(
+ "data",
+ (
+ make_classification(n_samples=30, random_state=0),
+ make_classification(
+ n_samples=30, n_classes=3, n_clusters_per_class=1, random_state=0
+ ),
+ ),
+)
+@pytest.mark.parametrize(
+ "scoring, validation_fraction, early_stopping, n_iter_no_change, tol",
+ [
+ ("accuracy", 0.1, True, 5, 1e-7), # use scorer
+ ("accuracy", None, True, 5, 1e-1), # use scorer on training data
+ (None, 0.1, True, 5, 1e-7), # same with default scorer
+ (None, None, True, 5, 1e-1),
+ ("loss", 0.1, True, 5, 1e-7), # use loss
+ ("loss", None, True, 5, 1e-1), # use loss on training data
+ (None, None, False, 5, 0.0), # no early stopping
+ ],
+)
+def test_early_stopping_classification(
+ data, scoring, validation_fraction, early_stopping, n_iter_no_change, tol
+):
+ max_iter = 50
+
+ X, y = data
+
+ gb = HistGradientBoostingClassifier(
+ verbose=1, # just for coverage
+ min_samples_leaf=5, # easier to overfit fast
+ scoring=scoring,
+ tol=tol,
+ early_stopping=early_stopping,
+ validation_fraction=validation_fraction,
+ max_iter=max_iter,
+ n_iter_no_change=n_iter_no_change,
+ random_state=0,
+ )
+ gb.fit(X, y)
+
+ if early_stopping is True:
+ assert n_iter_no_change <= gb.n_iter_ < max_iter
+ else:
+ assert gb.n_iter_ == max_iter
+
+
+@pytest.mark.parametrize(
+ "GradientBoosting, X, y",
+ [
+ (HistGradientBoostingClassifier, *_make_dumb_dataset(10000)),
+ (HistGradientBoostingClassifier, *_make_dumb_dataset(10001)),
+ (HistGradientBoostingRegressor, *_make_dumb_dataset(10000)),
+ (HistGradientBoostingRegressor, *_make_dumb_dataset(10001)),
+ ],
+)
+def test_early_stopping_default(GradientBoosting, X, y):
+ # Test that early stopping is enabled by default if and only if there
+ # are more than 10000 samples
+ gb = GradientBoosting(max_iter=10, n_iter_no_change=2, tol=1e-1)
+ gb.fit(X, y)
+ if X.shape[0] > 10000:
+ assert gb.n_iter_ < gb.max_iter
+ else:
+ assert gb.n_iter_ == gb.max_iter
+
+
+@pytest.mark.parametrize(
+ "scores, n_iter_no_change, tol, stopping",
+ [
+ ([], 1, 0.001, False), # not enough iterations
+ ([1, 1, 1], 5, 0.001, False), # not enough iterations
+ ([1, 1, 1, 1, 1], 5, 0.001, False), # not enough iterations
+ ([1, 2, 3, 4, 5, 6], 5, 0.001, False), # significant improvement
+ ([1, 2, 3, 4, 5, 6], 5, 0.0, False), # significant improvement
+ ([1, 2, 3, 4, 5, 6], 5, 0.999, False), # significant improvement
+ ([1, 2, 3, 4, 5, 6], 5, 5 - 1e-5, False), # significant improvement
+ ([1] * 6, 5, 0.0, True), # no significant improvement
+ ([1] * 6, 5, 0.001, True), # no significant improvement
+ ([1] * 6, 5, 5, True), # no significant improvement
+ ],
+)
+def test_should_stop(scores, n_iter_no_change, tol, stopping):
+ gbdt = HistGradientBoostingClassifier(n_iter_no_change=n_iter_no_change, tol=tol)
+ assert gbdt._should_stop(scores) == stopping
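+
+
+# For readability, the stopping rule exercised above can be restated as: stop
+# once none of the last `n_iter_no_change` scores improves on the score
+# recorded just before that window by more than `tol` (higher scores are
+# better). The function below is an illustrative re-implementation used only
+# to document the parametrized cases; it is not the estimator's `_should_stop`.
+def _should_stop_sketch(scores, n_iter_no_change, tol):
+    if len(scores) <= n_iter_no_change:
+        # Not enough iterations yet to have a reference score.
+        return False
+    reference = scores[-(n_iter_no_change + 1)]
+    recent = scores[-n_iter_no_change:]
+    # Stop when no recent score beats the reference by more than tol.
+    return not any(score > reference + tol for score in recent)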
+
+
+def test_absolute_error():
+ # For coverage only.
+ X, y = make_regression(n_samples=500, random_state=0)
+ gbdt = HistGradientBoostingRegressor(loss="absolute_error", random_state=0)
+ gbdt.fit(X, y)
+ assert gbdt.score(X, y) > 0.9
+
+
+def test_absolute_error_sample_weight():
+ # non regression test for issue #19400
+ # make sure no error is thrown during fit of
+ # HistGradientBoostingRegressor with absolute_error loss function
+ # and passing sample_weight
+ rng = np.random.RandomState(0)
+ n_samples = 100
+ X = rng.uniform(-1, 1, size=(n_samples, 2))
+ y = rng.uniform(-1, 1, size=n_samples)
+ sample_weight = rng.uniform(0, 1, size=n_samples)
+ gbdt = HistGradientBoostingRegressor(loss="absolute_error")
+ gbdt.fit(X, y, sample_weight=sample_weight)
+
+
+@pytest.mark.parametrize("quantile", [0.2, 0.5, 0.8])
+def test_asymmetric_error(quantile):
+ """Test quantile regression for asymmetric distributed targets."""
+ n_samples = 10_000
+ rng = np.random.RandomState(42)
+ # take care that X @ coef + intercept > 0
+ X = np.concatenate(
+ (
+ np.abs(rng.randn(n_samples)[:, None]),
+ -rng.randint(2, size=(n_samples, 1)),
+ ),
+ axis=1,
+ )
+ intercept = 1.23
+ coef = np.array([0.5, -2])
+ # For an exponential distribution with rate lambda, e.g. exp(-lambda * x),
+ # the quantile at level q is:
+ # quantile(q) = - log(1 - q) / lambda
+ # scale = 1/lambda = -quantile(q) / log(1-q)
+ y = rng.exponential(
+ scale=-(X @ coef + intercept) / np.log(1 - quantile), size=n_samples
+ )
+ model = HistGradientBoostingRegressor(
+ loss="quantile",
+ quantile=quantile,
+ max_iter=25,
+ random_state=0,
+ max_leaf_nodes=10,
+ ).fit(X, y)
+ assert_allclose(np.mean(model.predict(X) > y), quantile, rtol=1e-2)
+
+ pinball_loss = PinballLoss(quantile=quantile)
+ loss_true_quantile = pinball_loss(y, X @ coef + intercept)
+ loss_pred_quantile = pinball_loss(y, model.predict(X))
+ # we are overfitting
+ assert loss_pred_quantile <= loss_true_quantile
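+
+
+# The scale used above follows from inverting the exponential CDF
+# P(Y <= y) = 1 - exp(-y / scale), which gives quantile(q) = -scale * log(1 - q).
+# Below is a quick numerical sanity check of that relation (an illustrative
+# sketch only, not part of the original test module):
+def _check_exponential_quantile_sketch(scale=2.0, quantile=0.8, seed=0):
+    rng = np.random.RandomState(seed)
+    samples = rng.exponential(scale=scale, size=100_000)
+    theoretical_quantile = -scale * np.log(1 - quantile)
+    # Empirical and theoretical quantiles should agree up to sampling noise.
+    assert np.quantile(samples, quantile) == pytest.approx(
+        theoretical_quantile, rel=5e-2
+    )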
+
+
+@pytest.mark.parametrize("y", [([1.0, -2.0, 0.0]), ([0.0, 0.0, 0.0])])
+def test_poisson_y_positive(y):
+    # Test that ValueError is raised if any y_i < 0 or sum(y_i) <= 0.
+ err_msg = r"loss='poisson' requires non-negative y and sum\(y\) > 0."
+ gbdt = HistGradientBoostingRegressor(loss="poisson", random_state=0)
+ with pytest.raises(ValueError, match=err_msg):
+ gbdt.fit(np.zeros(shape=(len(y), 1)), y)
+
+
+def test_poisson():
+ # For Poisson distributed target, Poisson loss should give better results
+ # than least squares measured in Poisson deviance as metric.
+ rng = np.random.RandomState(42)
+ n_train, n_test, n_features = 500, 100, 100
+ X = make_low_rank_matrix(
+ n_samples=n_train + n_test, n_features=n_features, random_state=rng
+ )
+ # We create a log-linear Poisson model and downscale coef as it will get
+ # exponentiated.
+ coef = rng.uniform(low=-2, high=2, size=n_features) / np.max(X, axis=0)
+ y = rng.poisson(lam=np.exp(X @ coef))
+ X_train, X_test, y_train, y_test = train_test_split(
+ X, y, test_size=n_test, random_state=rng
+ )
+ gbdt_pois = HistGradientBoostingRegressor(loss="poisson", random_state=rng)
+ gbdt_ls = HistGradientBoostingRegressor(loss="squared_error", random_state=rng)
+ gbdt_pois.fit(X_train, y_train)
+ gbdt_ls.fit(X_train, y_train)
+ dummy = DummyRegressor(strategy="mean").fit(X_train, y_train)
+
+ for X, y in [(X_train, y_train), (X_test, y_test)]:
+ metric_pois = mean_poisson_deviance(y, gbdt_pois.predict(X))
+ # squared_error might produce non-positive predictions => clip
+ metric_ls = mean_poisson_deviance(y, np.clip(gbdt_ls.predict(X), 1e-15, None))
+ metric_dummy = mean_poisson_deviance(y, dummy.predict(X))
+ assert metric_pois < metric_ls
+ assert metric_pois < metric_dummy
+
+
+def test_binning_train_validation_are_separated():
+ # Make sure training and validation data are binned separately.
+ # See issue 13926
+
+ rng = np.random.RandomState(0)
+ validation_fraction = 0.2
+ gb = HistGradientBoostingClassifier(
+ early_stopping=True, validation_fraction=validation_fraction, random_state=rng
+ )
+ gb.fit(X_classification, y_classification)
+ mapper_training_data = gb._bin_mapper
+
+ # Note that since the data is small there is no subsampling and the
+ # random_state doesn't matter
+ mapper_whole_data = _BinMapper(random_state=0)
+ mapper_whole_data.fit(X_classification)
+
+ n_samples = X_classification.shape[0]
+ assert np.all(
+ mapper_training_data.n_bins_non_missing_
+ == int((1 - validation_fraction) * n_samples)
+ )
+ assert np.all(
+ mapper_training_data.n_bins_non_missing_
+ != mapper_whole_data.n_bins_non_missing_
+ )
+
+
+def test_missing_values_trivial():
+ # sanity check for missing values support. With only one feature and
+ # y == isnan(X), the gbdt is supposed to reach perfect accuracy on the
+ # training set.
+
+ n_samples = 100
+ n_features = 1
+ rng = np.random.RandomState(0)
+
+ X = rng.normal(size=(n_samples, n_features))
+ mask = rng.binomial(1, 0.5, size=X.shape).astype(bool)
+ X[mask] = np.nan
+ y = mask.ravel()
+ gb = HistGradientBoostingClassifier()
+ gb.fit(X, y)
+
+ assert gb.score(X, y) == pytest.approx(1)
+
+
+@pytest.mark.parametrize("problem", ("classification", "regression"))
+@pytest.mark.parametrize(
+ "missing_proportion, expected_min_score_classification, "
+ "expected_min_score_regression",
+ [(0.1, 0.97, 0.89), (0.2, 0.93, 0.81), (0.5, 0.79, 0.52)],
+)
+def test_missing_values_resilience(
+ problem,
+ missing_proportion,
+ expected_min_score_classification,
+ expected_min_score_regression,
+):
+ # Make sure the estimators can deal with missing values and still yield
+ # decent predictions
+
+ rng = np.random.RandomState(0)
+ n_samples = 1000
+ n_features = 2
+ if problem == "regression":
+ X, y = make_regression(
+ n_samples=n_samples,
+ n_features=n_features,
+ n_informative=n_features,
+ random_state=rng,
+ )
+ gb = HistGradientBoostingRegressor()
+ expected_min_score = expected_min_score_regression
+ else:
+ X, y = make_classification(
+ n_samples=n_samples,
+ n_features=n_features,
+ n_informative=n_features,
+ n_redundant=0,
+ n_repeated=0,
+ random_state=rng,
+ )
+ gb = HistGradientBoostingClassifier()
+ expected_min_score = expected_min_score_classification
+
+ mask = rng.binomial(1, missing_proportion, size=X.shape).astype(bool)
+ X[mask] = np.nan
+
+ gb.fit(X, y)
+
+ assert gb.score(X, y) > expected_min_score
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ make_classification(random_state=0, n_classes=2),
+ make_classification(random_state=0, n_classes=3, n_informative=3),
+ ],
+ ids=["binary_log_loss", "multiclass_log_loss"],
+)
+def test_zero_division_hessians(data):
+ # non regression test for issue #14018
+ # make sure we avoid zero division errors when computing the leaves values.
+
+ # If the learning rate is too high, the raw predictions are bad and will
+ # saturate the softmax (or sigmoid in binary classif). This leads to
+ # probabilities being exactly 0 or 1, gradients being constant, and
+ # hessians being zero.
+ X, y = data
+ gb = HistGradientBoostingClassifier(learning_rate=100, max_iter=10)
+ gb.fit(X, y)
+
+
+def test_small_trainset():
+ # Make sure that the small trainset is stratified and has the expected
+ # length (10k samples)
+ n_samples = 20000
+ original_distrib = {0: 0.1, 1: 0.2, 2: 0.3, 3: 0.4}
+ rng = np.random.RandomState(42)
+ X = rng.randn(n_samples).reshape(n_samples, 1)
+ y = [
+ [class_] * int(prop * n_samples) for (class_, prop) in original_distrib.items()
+ ]
+ y = shuffle(np.concatenate(y))
+ gb = HistGradientBoostingClassifier()
+
+ # Compute the small training set
+ X_small, y_small, _ = gb._get_small_trainset(
+ X, y, seed=42, sample_weight_train=None
+ )
+
+ # Compute the class distribution in the small training set
+ unique, counts = np.unique(y_small, return_counts=True)
+ small_distrib = {class_: count / 10000 for (class_, count) in zip(unique, counts)}
+
+ # Test that the small training set has the expected length
+ assert X_small.shape[0] == 10000
+ assert y_small.shape[0] == 10000
+
+ # Test that the class distributions in the whole dataset and in the small
+ # training set are identical
+ assert small_distrib == pytest.approx(original_distrib)
+
+
+def test_missing_values_minmax_imputation():
+    # Compare the built-in missing value handling of Histogram GBC with an
+ # a-priori missing value imputation strategy that should yield the same
+ # results in terms of decision function.
+ #
+ # Each feature (containing NaNs) is replaced by 2 features:
+ # - one where the nans are replaced by min(feature) - 1
+ # - one where the nans are replaced by max(feature) + 1
+ # A split where nans go to the left has an equivalent split in the
+ # first (min) feature, and a split where nans go to the right has an
+ # equivalent split in the second (max) feature.
+ #
+ # Assuming the data is such that there is never a tie to select the best
+ # feature to split on during training, the learned decision trees should be
+ # strictly equivalent (learn a sequence of splits that encode the same
+ # decision function).
+ #
+ # The MinMaxImputer transformer is meant to be a toy implementation of the
+ # "Missing In Attributes" (MIA) missing value handling for decision trees
+ # https://www.sciencedirect.com/science/article/abs/pii/S0167865508000305
+ # The implementation of MIA as an imputation transformer was suggested by
+ # "Remark 3" in :arxiv:'<1902.06931>`
+
+ class MinMaxImputer(TransformerMixin, BaseEstimator):
+ def fit(self, X, y=None):
+ mm = MinMaxScaler().fit(X)
+ self.data_min_ = mm.data_min_
+ self.data_max_ = mm.data_max_
+ return self
+
+ def transform(self, X):
+ X_min, X_max = X.copy(), X.copy()
+
+ for feature_idx in range(X.shape[1]):
+ nan_mask = np.isnan(X[:, feature_idx])
+ X_min[nan_mask, feature_idx] = self.data_min_[feature_idx] - 1
+ X_max[nan_mask, feature_idx] = self.data_max_[feature_idx] + 1
+
+ return np.concatenate([X_min, X_max], axis=1)
+
+ def make_missing_value_data(n_samples=int(1e4), seed=0):
+ rng = np.random.RandomState(seed)
+ X, y = make_regression(n_samples=n_samples, n_features=4, random_state=rng)
+
+ # Pre-bin the data to ensure a deterministic handling by the 2
+ # strategies and also make it easier to insert np.nan in a structured
+ # way:
+ X = KBinsDiscretizer(n_bins=42, encode="ordinal").fit_transform(X)
+
+ # First feature has missing values completely at random:
+ rnd_mask = rng.rand(X.shape[0]) > 0.9
+ X[rnd_mask, 0] = np.nan
+
+ # Second and third features have missing values for extreme values
+ # (censoring missingness):
+ low_mask = X[:, 1] == 0
+ X[low_mask, 1] = np.nan
+
+ high_mask = X[:, 2] == X[:, 2].max()
+ X[high_mask, 2] = np.nan
+
+ # Make the last feature nan pattern very informative:
+ y_max = np.percentile(y, 70)
+ y_max_mask = y >= y_max
+ y[y_max_mask] = y_max
+ X[y_max_mask, 3] = np.nan
+
+ # Check that there is at least one missing value in each feature:
+ for feature_idx in range(X.shape[1]):
+ assert any(np.isnan(X[:, feature_idx]))
+
+ # Let's use a test set to check that the learned decision function is
+ # the same as evaluated on unseen data. Otherwise it could just be the
+ # case that we find two independent ways to overfit the training set.
+ return train_test_split(X, y, random_state=rng)
+
+    # n_samples needs to be large enough to minimize the likelihood of having
+ # several candidate splits with the same gain value in a given tree.
+ X_train, X_test, y_train, y_test = make_missing_value_data(
+ n_samples=int(1e4), seed=0
+ )
+
+ # Use a small number of leaf nodes and iterations so as to keep
+ # under-fitting models to minimize the likelihood of ties when training the
+ # model.
+ gbm1 = HistGradientBoostingRegressor(max_iter=100, max_leaf_nodes=5, random_state=0)
+ gbm1.fit(X_train, y_train)
+
+ gbm2 = make_pipeline(MinMaxImputer(), clone(gbm1))
+ gbm2.fit(X_train, y_train)
+
+    # Check that the two models reach the same score:
+ assert gbm1.score(X_train, y_train) == pytest.approx(gbm2.score(X_train, y_train))
+
+ assert gbm1.score(X_test, y_test) == pytest.approx(gbm2.score(X_test, y_test))
+
+    # Check that the individual predictions match, as a finer-grained
+    # decision function check.
+ assert_allclose(gbm1.predict(X_train), gbm2.predict(X_train))
+ assert_allclose(gbm1.predict(X_test), gbm2.predict(X_test))
+
+
+def test_infinite_values():
+ # Basic test for infinite values
+
+ X = np.array([-np.inf, 0, 1, np.inf]).reshape(-1, 1)
+ y = np.array([0, 0, 1, 1])
+
+ gbdt = HistGradientBoostingRegressor(min_samples_leaf=1)
+ gbdt.fit(X, y)
+ np.testing.assert_allclose(gbdt.predict(X), y, atol=1e-4)
+
+
+def test_consistent_lengths():
+ X = np.array([-np.inf, 0, 1, np.inf]).reshape(-1, 1)
+ y = np.array([0, 0, 1, 1])
+ sample_weight = np.array([0.1, 0.3, 0.1])
+ gbdt = HistGradientBoostingRegressor()
+ with pytest.raises(ValueError, match=r"sample_weight.shape == \(3,\), expected"):
+ gbdt.fit(X, y, sample_weight)
+
+ with pytest.raises(
+ ValueError, match="Found input variables with inconsistent number"
+ ):
+ gbdt.fit(X, y[1:])
+
+
+def test_infinite_values_missing_values():
+ # High level test making sure that inf and nan values are properly handled
+ # when both are present. This is similar to
+ # test_split_on_nan_with_infinite_values() in test_grower.py, though we
+ # cannot check the predictions for binned values here.
+
+ X = np.asarray([-np.inf, 0, 1, np.inf, np.nan]).reshape(-1, 1)
+ y_isnan = np.isnan(X.ravel())
+ y_isinf = X.ravel() == np.inf
+
+ stump_clf = HistGradientBoostingClassifier(
+ min_samples_leaf=1, max_iter=1, learning_rate=1, max_depth=2
+ )
+
+ assert stump_clf.fit(X, y_isinf).score(X, y_isinf) == 1
+ assert stump_clf.fit(X, y_isnan).score(X, y_isnan) == 1
+
+
+# TODO(1.3): remove
+@pytest.mark.filterwarnings("ignore::FutureWarning")
+def test_crossentropy_binary_problem():
+ # categorical_crossentropy should only be used if there are more than two
+ # classes present. PR #14869
+ X = [[1], [0]]
+ y = [0, 1]
+ gbrt = HistGradientBoostingClassifier(loss="categorical_crossentropy")
+ with pytest.raises(
+ ValueError, match="loss='categorical_crossentropy' is not suitable for"
+ ):
+ gbrt.fit(X, y)
+
+
+@pytest.mark.parametrize("scoring", [None, "loss"])
+def test_string_target_early_stopping(scoring):
+    # Regression test for #14709 where the targets need to be encoded before
+    # computing the score.
+ rng = np.random.RandomState(42)
+ X = rng.randn(100, 10)
+ y = np.array(["x"] * 50 + ["y"] * 50, dtype=object)
+ gbrt = HistGradientBoostingClassifier(n_iter_no_change=10, scoring=scoring)
+ gbrt.fit(X, y)
+
+
+def test_zero_sample_weights_regression():
+ # Make sure setting a SW to zero amounts to ignoring the corresponding
+ # sample
+
+ X = [[1, 0], [1, 0], [1, 0], [0, 1]]
+ y = [0, 0, 1, 0]
+ # ignore the first 2 training samples by setting their weight to 0
+ sample_weight = [0, 0, 1, 1]
+ gb = HistGradientBoostingRegressor(min_samples_leaf=1)
+ gb.fit(X, y, sample_weight=sample_weight)
+ assert gb.predict([[1, 0]])[0] > 0.5
+
+
+def test_zero_sample_weights_classification():
+ # Make sure setting a SW to zero amounts to ignoring the corresponding
+ # sample
+
+ X = [[1, 0], [1, 0], [1, 0], [0, 1]]
+ y = [0, 0, 1, 0]
+ # ignore the first 2 training samples by setting their weight to 0
+ sample_weight = [0, 0, 1, 1]
+ gb = HistGradientBoostingClassifier(loss="log_loss", min_samples_leaf=1)
+ gb.fit(X, y, sample_weight=sample_weight)
+ assert_array_equal(gb.predict([[1, 0]]), [1])
+
+ X = [[1, 0], [1, 0], [1, 0], [0, 1], [1, 1]]
+ y = [0, 0, 1, 0, 2]
+ # ignore the first 2 training samples by setting their weight to 0
+ sample_weight = [0, 0, 1, 1, 1]
+ gb = HistGradientBoostingClassifier(loss="log_loss", min_samples_leaf=1)
+ gb.fit(X, y, sample_weight=sample_weight)
+ assert_array_equal(gb.predict([[1, 0]]), [1])
+
+
+@pytest.mark.parametrize(
+ "problem", ("regression", "binary_classification", "multiclass_classification")
+)
+@pytest.mark.parametrize("duplication", ("half", "all"))
+def test_sample_weight_effect(problem, duplication):
+ # High level test to make sure that duplicating a sample is equivalent to
+ # giving it weight of 2.
+
+ # fails for n_samples > 255 because binning does not take sample weights
+ # into account. Keeping n_samples <= 255 makes
+ # sure only unique values are used so SW have no effect on binning.
+ n_samples = 255
+ n_features = 2
+ if problem == "regression":
+ X, y = make_regression(
+ n_samples=n_samples,
+ n_features=n_features,
+ n_informative=n_features,
+ random_state=0,
+ )
+ Klass = HistGradientBoostingRegressor
+ else:
+ n_classes = 2 if problem == "binary_classification" else 3
+ X, y = make_classification(
+ n_samples=n_samples,
+ n_features=n_features,
+ n_informative=n_features,
+ n_redundant=0,
+ n_clusters_per_class=1,
+ n_classes=n_classes,
+ random_state=0,
+ )
+ Klass = HistGradientBoostingClassifier
+
+ # This test can't pass if min_samples_leaf > 1 because that would force 2
+ # samples to be in the same node in est_sw, while these samples would be
+ # free to be separate in est_dup: est_dup would just group together the
+ # duplicated samples.
+ est = Klass(min_samples_leaf=1)
+
+ # Create dataset with duplicate and corresponding sample weights
+ if duplication == "half":
+ lim = n_samples // 2
+ else:
+ lim = n_samples
+ X_dup = np.r_[X, X[:lim]]
+ y_dup = np.r_[y, y[:lim]]
+ sample_weight = np.ones(shape=(n_samples))
+ sample_weight[:lim] = 2
+
+ est_sw = clone(est).fit(X, y, sample_weight=sample_weight)
+ est_dup = clone(est).fit(X_dup, y_dup)
+
+ # checking raw_predict is stricter than just predict for classification
+ assert np.allclose(est_sw._raw_predict(X_dup), est_dup._raw_predict(X_dup))
+
+
+@pytest.mark.parametrize("Loss", (HalfSquaredError, AbsoluteError))
+def test_sum_hessians_are_sample_weight(Loss):
+ # For losses with constant hessians, the sum_hessians field of the
+ # histograms must be equal to the sum of the sample weight of samples at
+ # the corresponding bin.
+
+ rng = np.random.RandomState(0)
+ n_samples = 1000
+ n_features = 2
+ X, y = make_regression(n_samples=n_samples, n_features=n_features, random_state=rng)
+ bin_mapper = _BinMapper()
+ X_binned = bin_mapper.fit_transform(X)
+
+ # While sample weights are supposed to be positive, this still works.
+ sample_weight = rng.normal(size=n_samples)
+
+ loss = Loss(sample_weight=sample_weight)
+ gradients, hessians = loss.init_gradient_and_hessian(
+ n_samples=n_samples, dtype=G_H_DTYPE
+ )
+ gradients, hessians = gradients.reshape((-1, 1)), hessians.reshape((-1, 1))
+ raw_predictions = rng.normal(size=(n_samples, 1))
+ loss.gradient_hessian(
+ y_true=y,
+ raw_prediction=raw_predictions,
+ sample_weight=sample_weight,
+ gradient_out=gradients,
+ hessian_out=hessians,
+ n_threads=n_threads,
+ )
+
+ # build sum_sample_weight which contains the sum of the sample weights at
+ # each bin (for each feature). This must be equal to the sum_hessians
+ # field of the corresponding histogram
+ sum_sw = np.zeros(shape=(n_features, bin_mapper.n_bins))
+ for feature_idx in range(n_features):
+ for sample_idx in range(n_samples):
+ sum_sw[feature_idx, X_binned[sample_idx, feature_idx]] += sample_weight[
+ sample_idx
+ ]
+
+ # Build histogram
+ grower = TreeGrower(
+ X_binned, gradients[:, 0], hessians[:, 0], n_bins=bin_mapper.n_bins
+ )
+ histograms = grower.histogram_builder.compute_histograms_brute(
+ grower.root.sample_indices
+ )
+
+ for feature_idx in range(n_features):
+ for bin_idx in range(bin_mapper.n_bins):
+ assert histograms[feature_idx, bin_idx]["sum_hessians"] == (
+ pytest.approx(sum_sw[feature_idx, bin_idx], rel=1e-5)
+ )
+
+
+def test_max_depth_max_leaf_nodes():
+ # Non regression test for
+ # https://github.com/scikit-learn/scikit-learn/issues/16179
+ # there was a bug when the max_depth and the max_leaf_nodes criteria were
+ # met at the same time, which would lead to max_leaf_nodes not being
+ # respected.
+ X, y = make_classification(random_state=0)
+ est = HistGradientBoostingClassifier(max_depth=2, max_leaf_nodes=3, max_iter=1).fit(
+ X, y
+ )
+ tree = est._predictors[0][0]
+ assert tree.get_max_depth() == 2
+ assert tree.get_n_leaf_nodes() == 3 # would be 4 prior to bug fix
+
+
+def test_early_stopping_on_test_set_with_warm_start():
+ # Non regression test for #16661 where second fit fails with
+ # warm_start=True, early_stopping is on, and no validation set
+ X, y = make_classification(random_state=0)
+ gb = HistGradientBoostingClassifier(
+ max_iter=1,
+ scoring="loss",
+ warm_start=True,
+ early_stopping=True,
+ n_iter_no_change=1,
+ validation_fraction=None,
+ )
+
+ gb.fit(X, y)
+ # does not raise on second call
+ gb.set_params(max_iter=2)
+ gb.fit(X, y)
+
+
+@pytest.mark.parametrize(
+ "Est", (HistGradientBoostingClassifier, HistGradientBoostingRegressor)
+)
+def test_single_node_trees(Est):
+ # Make sure it's still possible to build single-node trees. In that case
+ # the value of the root is set to 0. That's a correct value: if the tree is
+ # single-node that's because min_gain_to_split is not respected right from
+ # the root, so we don't want the tree to have any impact on the
+ # predictions.
+
+ X, y = make_classification(random_state=0)
+ y[:] = 1 # constant target will lead to a single root node
+
+ est = Est(max_iter=20)
+ est.fit(X, y)
+
+ assert all(len(predictor[0].nodes) == 1 for predictor in est._predictors)
+ assert all(predictor[0].nodes[0]["value"] == 0 for predictor in est._predictors)
+ # Still gives correct predictions thanks to the baseline prediction
+ assert_allclose(est.predict(X), y)
+
+
+@pytest.mark.parametrize(
+ "Est, loss, X, y",
+ [
+ (
+ HistGradientBoostingClassifier,
+ HalfBinomialLoss(sample_weight=None),
+ X_classification,
+ y_classification,
+ ),
+ (
+ HistGradientBoostingRegressor,
+ HalfSquaredError(sample_weight=None),
+ X_regression,
+ y_regression,
+ ),
+ ],
+)
+def test_custom_loss(Est, loss, X, y):
+ est = Est(loss=loss, max_iter=20)
+ est.fit(X, y)
+
+
+@pytest.mark.parametrize(
+ "HistGradientBoosting, X, y",
+ [
+ (HistGradientBoostingClassifier, X_classification, y_classification),
+ (HistGradientBoostingRegressor, X_regression, y_regression),
+ (
+ HistGradientBoostingClassifier,
+ X_multi_classification,
+ y_multi_classification,
+ ),
+ ],
+)
+def test_staged_predict(HistGradientBoosting, X, y):
+ # Test whether staged predictor eventually gives
+ # the same prediction.
+ X_train, X_test, y_train, y_test = train_test_split(
+ X, y, test_size=0.5, random_state=0
+ )
+ gb = HistGradientBoosting(max_iter=10)
+
+ # test raise NotFittedError if not fitted
+ with pytest.raises(NotFittedError):
+ next(gb.staged_predict(X_test))
+
+ gb.fit(X_train, y_train)
+
+ # test if the staged predictions of each iteration
+ # are equal to the corresponding predictions of the same estimator
+ # trained from scratch.
+    # This also tests the limit case when max_iter = 1.
+ method_names = (
+ ["predict"]
+ if is_regressor(gb)
+ else ["predict", "predict_proba", "decision_function"]
+ )
+ for method_name in method_names:
+ staged_method = getattr(gb, "staged_" + method_name)
+ staged_predictions = list(staged_method(X_test))
+ assert len(staged_predictions) == gb.n_iter_
+ for n_iter, staged_predictions in enumerate(staged_method(X_test), 1):
+ aux = HistGradientBoosting(max_iter=n_iter)
+ aux.fit(X_train, y_train)
+ pred_aux = getattr(aux, method_name)(X_test)
+
+ assert_allclose(staged_predictions, pred_aux)
+ assert staged_predictions.shape == pred_aux.shape
+
+
+@pytest.mark.parametrize("insert_missing", [False, True])
+@pytest.mark.parametrize(
+ "Est", (HistGradientBoostingRegressor, HistGradientBoostingClassifier)
+)
+@pytest.mark.parametrize("bool_categorical_parameter", [True, False])
+def test_unknown_categories_nan(insert_missing, Est, bool_categorical_parameter):
+ # Make sure no error is raised at predict if a category wasn't seen during
+ # fit. We also make sure they're treated as nans.
+
+ rng = np.random.RandomState(0)
+ n_samples = 1000
+ f1 = rng.rand(n_samples)
+ f2 = rng.randint(4, size=n_samples)
+ X = np.c_[f1, f2]
+ y = np.zeros(shape=n_samples)
+ y[X[:, 1] % 2 == 0] = 1
+
+ if bool_categorical_parameter:
+ categorical_features = [False, True]
+ else:
+ categorical_features = [1]
+
+ if insert_missing:
+ mask = rng.binomial(1, 0.01, size=X.shape).astype(bool)
+ assert mask.sum() > 0
+ X[mask] = np.nan
+
+ est = Est(max_iter=20, categorical_features=categorical_features).fit(X, y)
+ assert_array_equal(est.is_categorical_, [False, True])
+
+ # Make sure no error is raised on unknown categories and nans
+ # unknown categories will be treated as nans
+ X_test = np.zeros((10, X.shape[1]), dtype=float)
+ X_test[:5, 1] = 30
+ X_test[5:, 1] = np.nan
+ assert len(np.unique(est.predict(X_test))) == 1
+
+
+def test_categorical_encoding_strategies():
+ # Check native categorical handling vs different encoding strategies. We
+ # make sure that native encoding needs only 1 split to achieve a perfect
+ # prediction on a simple dataset. In contrast, OneHotEncoded data needs
+ # more depth / splits, and treating categories as ordered (just using
+ # OrdinalEncoder) requires even more depth.
+
+ # dataset with one random continuous feature, and one categorical feature
+ # with values in [0, 5], e.g. from an OrdinalEncoder.
+ # class == 1 iff categorical value in {0, 2, 4}
+ rng = np.random.RandomState(0)
+ n_samples = 10_000
+ f1 = rng.rand(n_samples)
+ f2 = rng.randint(6, size=n_samples)
+ X = np.c_[f1, f2]
+ y = np.zeros(shape=n_samples)
+ y[X[:, 1] % 2 == 0] = 1
+
+ # make sure dataset is balanced so that the baseline_prediction doesn't
+ # influence predictions too much with max_iter = 1
+ assert 0.49 < y.mean() < 0.51
+
+ native_cat_specs = [
+ [False, True],
+ [1],
+ ]
+ try:
+ import modin.pandas as pd
+
+ X = pd.DataFrame(X, columns=["f_0", "f_1"])
+ native_cat_specs.append(["f_1"])
+ except ImportError:
+ pass
+
+ for native_cat_spec in native_cat_specs:
+ clf_cat = HistGradientBoostingClassifier(
+ max_iter=1, max_depth=1, categorical_features=native_cat_spec
+ )
+
+ # Using native categorical encoding, we get perfect predictions with just
+ # one split
+ assert cross_val_score(clf_cat, X, y).mean() == 1
+
+ # quick sanity check for the bitset: 0, 2, 4 = 2**0 + 2**2 + 2**4 = 21
+ expected_left_bitset = [21, 0, 0, 0, 0, 0, 0, 0]
+ left_bitset = clf_cat.fit(X, y)._predictors[0][0].raw_left_cat_bitsets[0]
+ assert_array_equal(left_bitset, expected_left_bitset)
+
+ # Treating categories as ordered, we need more depth / more splits to get
+ # the same predictions
+ clf_no_cat = HistGradientBoostingClassifier(
+ max_iter=1, max_depth=4, categorical_features=None
+ )
+ assert cross_val_score(clf_no_cat, X, y).mean() < 0.9
+
+ clf_no_cat.set_params(max_depth=5)
+ assert cross_val_score(clf_no_cat, X, y).mean() == 1
+
+    # Using one-hot encoded data, we need fewer splits than with purely
+    # ordinal-encoded data, but we still need more splits than with the
+    # native categorical splits
+ ct = make_column_transformer(
+ (OneHotEncoder(sparse_output=False), [1]), remainder="passthrough"
+ )
+ X_ohe = ct.fit_transform(X)
+ clf_no_cat.set_params(max_depth=2)
+ assert cross_val_score(clf_no_cat, X_ohe, y).mean() < 0.9
+
+ clf_no_cat.set_params(max_depth=3)
+ assert cross_val_score(clf_no_cat, X_ohe, y).mean() == 1
+
+
+@pytest.mark.parametrize(
+ "Est", (HistGradientBoostingClassifier, HistGradientBoostingRegressor)
+)
+@pytest.mark.parametrize(
+ "categorical_features, monotonic_cst, expected_msg",
+ [
+ (
+ [b"hello", b"world"],
+ None,
+ re.escape(
+ "categorical_features must be an array-like of bool, int or str, "
+ "got: bytes40."
+ ),
+ ),
+ (
+ np.array([b"hello", 1.3], dtype=object),
+ None,
+ re.escape(
+ "categorical_features must be an array-like of bool, int or str, "
+ "got: bytes, float."
+ ),
+ ),
+ (
+ [0, -1],
+ None,
+ re.escape(
+ "categorical_features set as integer indices must be in "
+ "[0, n_features - 1]"
+ ),
+ ),
+ (
+ [True, True, False, False, True],
+ None,
+ re.escape(
+ "categorical_features set as a boolean mask must have shape "
+ "(n_features,)"
+ ),
+ ),
+ (
+ [True, True, False, False],
+ [0, -1, 0, 1],
+ "Categorical features cannot have monotonic constraints",
+ ),
+ ],
+)
+def test_categorical_spec_errors(
+ Est, categorical_features, monotonic_cst, expected_msg
+):
+ # Test errors when categories are specified incorrectly
+ n_samples = 100
+ X, y = make_classification(random_state=0, n_features=4, n_samples=n_samples)
+ rng = np.random.RandomState(0)
+ X[:, 0] = rng.randint(0, 10, size=n_samples)
+ X[:, 1] = rng.randint(0, 10, size=n_samples)
+ est = Est(categorical_features=categorical_features, monotonic_cst=monotonic_cst)
+
+ with pytest.raises(ValueError, match=expected_msg):
+ est.fit(X, y)
+
+
+@pytest.mark.parametrize(
+ "Est", (HistGradientBoostingClassifier, HistGradientBoostingRegressor)
+)
+def test_categorical_spec_errors_with_feature_names(Est):
+ pd = pytest.importorskip("modin.pandas")
+ n_samples = 10
+ X = pd.DataFrame(
+ {
+ "f0": range(n_samples),
+ "f1": range(n_samples),
+ "f2": [1.0] * n_samples,
+ }
+ )
+ y = [0, 1] * (n_samples // 2)
+
+ est = Est(categorical_features=["f0", "f1", "f3"])
+ expected_msg = re.escape(
+ "categorical_features has a item value 'f3' which is not a valid "
+ "feature name of the training data."
+ )
+ with pytest.raises(ValueError, match=expected_msg):
+ est.fit(X, y)
+
+ est = Est(categorical_features=["f0", "f1"])
+ expected_msg = re.escape(
+ "categorical_features should be passed as an array of integers or "
+ "as a boolean mask when the model is fitted on data without feature "
+ "names."
+ )
+ with pytest.raises(ValueError, match=expected_msg):
+ est.fit(X.to_numpy(), y)
+
+
+@pytest.mark.parametrize(
+ "Est", (HistGradientBoostingClassifier, HistGradientBoostingRegressor)
+)
+@pytest.mark.parametrize("categorical_features", ([False, False], []))
+@pytest.mark.parametrize("as_array", (True, False))
+def test_categorical_spec_no_categories(Est, categorical_features, as_array):
+ # Make sure we can properly detect that no categorical features are present
+ # even if the categorical_features parameter is not None
+ X = np.arange(10).reshape(5, 2)
+ y = np.arange(5)
+ if as_array:
+ categorical_features = np.asarray(categorical_features)
+ est = Est(categorical_features=categorical_features).fit(X, y)
+ assert est.is_categorical_ is None
+
+
+@pytest.mark.skip(reason="Failing test")
+@pytest.mark.parametrize(
+ "Est", (HistGradientBoostingClassifier, HistGradientBoostingRegressor)
+)
+@pytest.mark.parametrize(
+ "use_pandas, feature_name", [(False, "at index 0"), (True, "'f0'")]
+)
+def test_categorical_bad_encoding_errors(Est, use_pandas, feature_name):
+ # Test errors when categories are encoded incorrectly
+
+ gb = Est(categorical_features=[True], max_bins=2)
+
+ if use_pandas:
+ pd = pytest.importorskip("modin.pandas")
+ X = pd.DataFrame({"f0": [0, 1, 2]})
+ else:
+ X = np.array([[0, 1, 2]]).T
+ y = np.arange(3)
+ msg = (
+ f"Categorical feature {feature_name} is expected to have a "
+ "cardinality <= 2 but actually has a cardinality of 3."
+ )
+ with pytest.raises(ValueError, match=msg):
+ gb.fit(X, y)
+
+ if use_pandas:
+ X = pd.DataFrame({"f0": [0, 2]})
+ else:
+ X = np.array([[0, 2]]).T
+ y = np.arange(2)
+ msg = (
+ f"Categorical feature {feature_name} is expected to be encoded "
+ "with values < 2 but the largest value for the encoded categories "
+ "is 2.0."
+ )
+ with pytest.raises(ValueError, match=msg):
+ gb.fit(X, y)
+
+ # nans are ignored in the counts
+ X = np.array([[0, 1, np.nan]]).T
+ y = np.arange(3)
+ gb.fit(X, y)
+
+
+@pytest.mark.parametrize(
+ "Est", (HistGradientBoostingClassifier, HistGradientBoostingRegressor)
+)
+def test_uint8_predict(Est):
+ # Non regression test for
+ # https://github.com/scikit-learn/scikit-learn/issues/18408
+ # Make sure X can be of dtype uint8 (i.e. X_BINNED_DTYPE) in predict. It
+ # will be converted to X_DTYPE.
+
+ rng = np.random.RandomState(0)
+
+ X = rng.randint(0, 100, size=(10, 2)).astype(np.uint8)
+ y = rng.randint(0, 2, size=10).astype(np.uint8)
+ est = Est()
+ est.fit(X, y)
+ est.predict(X)
+
+
+@pytest.mark.parametrize(
+ "interaction_cst, n_features, result",
+ [
+ (None, 931, None),
+ ([{0, 1}], 2, [{0, 1}]),
+ ("pairwise", 2, [{0, 1}]),
+ ("pairwise", 4, [{0, 1}, {0, 2}, {0, 3}, {1, 2}, {1, 3}, {2, 3}]),
+ ("no_interactions", 2, [{0}, {1}]),
+ ("no_interactions", 4, [{0}, {1}, {2}, {3}]),
+ ([(1, 0), [5, 1]], 6, [{0, 1}, {1, 5}, {2, 3, 4}]),
+ ],
+)
+def test_check_interaction_cst(interaction_cst, n_features, result):
+ """Check that _check_interaction_cst returns the expected list of sets"""
+ est = HistGradientBoostingRegressor()
+ est.set_params(interaction_cst=interaction_cst)
+ assert est._check_interaction_cst(n_features) == result
+
+
+def test_interaction_cst_numerically():
+ """Check that interaction constraints have no forbidden interactions."""
+ rng = np.random.RandomState(42)
+ n_samples = 1000
+ X = rng.uniform(size=(n_samples, 2))
+ # Construct y with a strong interaction term
+ # y = x0 + x1 + 5 * x0 * x1
+ y = np.hstack((X, 5 * X[:, [0]] * X[:, [1]])).sum(axis=1)
+
+ est = HistGradientBoostingRegressor(random_state=42)
+ est.fit(X, y)
+ est_no_interactions = HistGradientBoostingRegressor(
+ interaction_cst=[{0}, {1}], random_state=42
+ )
+ est_no_interactions.fit(X, y)
+
+ delta = 0.25
+ # Make sure we do not extrapolate out of the training set as tree-based estimators
+    # are very bad at doing so.
+ X_test = X[(X[:, 0] < 1 - delta) & (X[:, 1] < 1 - delta)]
+ X_delta_d_0 = X_test + [delta, 0]
+ X_delta_0_d = X_test + [0, delta]
+ X_delta_d_d = X_test + [delta, delta]
+
+ # Note: For the y from above as a function of x0 and x1, we have
+ # y(x0+d, x1+d) = y(x0, x1) + 5 * d * (2/5 + x0 + x1) + 5 * d**2
+ # y(x0+d, x1) = y(x0, x1) + 5 * d * (1/5 + x1)
+ # y(x0, x1+d) = y(x0, x1) + 5 * d * (1/5 + x0)
+ # Without interaction constraints, we would expect a result of 5 * d**2 for the
+ # following expression, but zero with constraints in place.
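+    # Indeed, for y = x0 + x1 + 5 * x0 * x1 the second difference
+    # y(x0+d, x1+d) + y(x0, x1) - y(x0+d, x1) - y(x0, x1+d) equals exactly
+    # 5 * d**2, and only the interaction term can produce it.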
+ assert_allclose(
+ est_no_interactions.predict(X_delta_d_d)
+ + est_no_interactions.predict(X_test)
+ - est_no_interactions.predict(X_delta_d_0)
+ - est_no_interactions.predict(X_delta_0_d),
+ 0,
+ atol=1e-12,
+ )
+
+    # The correct result of the expression is 5 * delta**2, but this is hard to
+    # achieve with a fitted tree-based model. However, with 100 iterations the
+    # expression should at least be positive!
+ assert np.all(
+ est.predict(X_delta_d_d)
+ + est.predict(X_test)
+ - est.predict(X_delta_d_0)
+ - est.predict(X_delta_0_d)
+ > 0.01
+ )
+
+
+# TODO(1.3): Remove
+@pytest.mark.parametrize(
+ "old_loss, new_loss, Estimator",
+ [
+ ("auto", "log_loss", HistGradientBoostingClassifier),
+ ("binary_crossentropy", "log_loss", HistGradientBoostingClassifier),
+ ("categorical_crossentropy", "log_loss", HistGradientBoostingClassifier),
+ ],
+)
+def test_loss_deprecated(old_loss, new_loss, Estimator):
+ if old_loss == "categorical_crossentropy":
+ X, y = X_multi_classification[:10], y_multi_classification[:10]
+ assert len(np.unique(y)) > 2
+ else:
+ X, y = X_classification[:10], y_classification[:10]
+
+ est1 = Estimator(loss=old_loss, random_state=0)
+
+ with pytest.warns(FutureWarning, match=f"The loss '{old_loss}' was deprecated"):
+ est1.fit(X, y)
+
+ est2 = Estimator(loss=new_loss, random_state=0)
+ est2.fit(X, y)
+ assert_allclose(est1.predict(X), est2.predict(X))
+
+
+def test_no_user_warning_with_scoring():
+ """Check that no UserWarning is raised when scoring is set.
+
+ Non-regression test for #22907.
+ """
+ pd = pytest.importorskip("modin.pandas")
+ X, y = make_regression(n_samples=50, random_state=0)
+ X_df = pd.DataFrame(X, columns=[f"col{i}" for i in range(X.shape[1])])
+
+ est = HistGradientBoostingRegressor(
+ random_state=0, scoring="neg_mean_absolute_error", early_stopping=True
+ )
+ with warnings.catch_warnings():
+ warnings.simplefilter("error", UserWarning)
+ est.fit(X_df, y)
+
+
+def test_class_weights():
+ """High level test to check class_weights."""
+ n_samples = 255
+ n_features = 2
+
+ X, y = make_classification(
+ n_samples=n_samples,
+ n_features=n_features,
+ n_informative=n_features,
+ n_redundant=0,
+ n_clusters_per_class=1,
+ n_classes=2,
+ random_state=0,
+ )
+ y_is_1 = y == 1
+
+ # class_weight is the same as sample weights with the corresponding class
+ clf = HistGradientBoostingClassifier(
+ min_samples_leaf=2, random_state=0, max_depth=2
+ )
+ sample_weight = np.ones(shape=(n_samples))
+ sample_weight[y_is_1] = 3.0
+ clf.fit(X, y, sample_weight=sample_weight)
+
+ class_weight = {0: 1.0, 1: 3.0}
+ clf_class_weighted = clone(clf).set_params(class_weight=class_weight)
+ clf_class_weighted.fit(X, y)
+
+ assert_allclose(clf.decision_function(X), clf_class_weighted.decision_function(X))
+
+ # Check that sample_weight and class_weight are multiplicative
+ clf.fit(X, y, sample_weight=sample_weight**2)
+ clf_class_weighted.fit(X, y, sample_weight=sample_weight)
+ assert_allclose(clf.decision_function(X), clf_class_weighted.decision_function(X))
+
+ # Make imbalanced dataset
+ X_imb = np.concatenate((X[~y_is_1], X[y_is_1][:10]))
+ y_imb = np.concatenate((y[~y_is_1], y[y_is_1][:10]))
+
+ # class_weight="balanced" is the same as sample_weights to be
+ # inversely proportional to n_samples / (n_classes * np.bincount(y))
+ clf_balanced = clone(clf).set_params(class_weight="balanced")
+ clf_balanced.fit(X_imb, y_imb)
+
+ class_weight = y_imb.shape[0] / (2 * np.bincount(y_imb))
+ sample_weight = class_weight[y_imb]
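+    # Indexing class_weight with y_imb gives each sample the balanced weight of
+    # its own class, so clf_sample_weight should match clf_balanced below.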
+ clf_sample_weight = clone(clf).set_params(class_weight=None)
+ clf_sample_weight.fit(X_imb, y_imb, sample_weight=sample_weight)
+
+ assert_allclose(
+ clf_balanced.decision_function(X_imb),
+ clf_sample_weight.decision_function(X_imb),
+ )
+
+
+def test_unknown_category_that_are_negative():
+ """Check that unknown categories that are negative does not error.
+
+ Non-regression test for #24274.
+ """
+ rng = np.random.RandomState(42)
+ n_samples = 1000
+ X = np.c_[rng.rand(n_samples), rng.randint(4, size=n_samples)]
+ y = np.zeros(shape=n_samples)
+ y[X[:, 1] % 2 == 0] = 1
+
+ hist = HistGradientBoostingRegressor(
+ random_state=0,
+ categorical_features=[False, True],
+ max_iter=10,
+ ).fit(X, y)
+
+ # Check that negative values from the second column are treated like a
+ # missing category
+ X_test_neg = np.asarray([[1, -2], [3, -4]])
+ X_test_nan = np.asarray([[1, np.nan], [3, np.nan]])
+
+ assert_allclose(hist.predict(X_test_neg), hist.predict(X_test_nan))
diff --git a/modin/pandas/test/interoperability/sklearn/ensemble/_hist_gradient_boosting/test_monotonic_constraints.py b/modin/pandas/test/interoperability/sklearn/ensemble/_hist_gradient_boosting/test_monotonic_constraints.py
new file mode 100644
index 00000000000..521f7a75c6c
--- /dev/null
+++ b/modin/pandas/test/interoperability/sklearn/ensemble/_hist_gradient_boosting/test_monotonic_constraints.py
@@ -0,0 +1,439 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+# flake8: noqa
+
+import re
+import numpy as np
+import pytest
+from sklearn.ensemble._hist_gradient_boosting.grower import TreeGrower
+from sklearn.ensemble._hist_gradient_boosting.common import G_H_DTYPE
+from sklearn.ensemble._hist_gradient_boosting.common import X_BINNED_DTYPE
+from sklearn.ensemble._hist_gradient_boosting.common import MonotonicConstraint
+from sklearn.ensemble._hist_gradient_boosting.splitting import (
+ Splitter,
+ compute_node_value,
+)
+from sklearn.ensemble._hist_gradient_boosting.histogram import HistogramBuilder
+from sklearn.ensemble import HistGradientBoostingRegressor
+from sklearn.ensemble import HistGradientBoostingClassifier
+from sklearn.utils._openmp_helpers import _openmp_effective_n_threads
+
+n_threads = _openmp_effective_n_threads()
+
+
+def is_increasing(a):
+ return (np.diff(a) >= 0.0).all()
+
+
+def is_decreasing(a):
+ return (np.diff(a) <= 0.0).all()
+
+
+def assert_leaves_values_monotonic(predictor, monotonic_cst):
+ # make sure leaves values (from left to right) are either all increasing
+ # or all decreasing (or neither) depending on the monotonic constraint.
+ nodes = predictor.nodes
+
+ def get_leaves_values():
+ """get leaves values from left to right"""
+ values = []
+
+ def depth_first_collect_leaf_values(node_idx):
+ node = nodes[node_idx]
+ if node["is_leaf"]:
+ values.append(node["value"])
+ return
+ depth_first_collect_leaf_values(node["left"])
+ depth_first_collect_leaf_values(node["right"])
+
+ depth_first_collect_leaf_values(0) # start at root (0)
+ return values
+
+ values = get_leaves_values()
+
+ if monotonic_cst == MonotonicConstraint.NO_CST:
+ # some increasing, some decreasing
+ assert not is_increasing(values) and not is_decreasing(values)
+ elif monotonic_cst == MonotonicConstraint.POS:
+ # all increasing
+ assert is_increasing(values)
+ else: # NEG
+ # all decreasing
+ assert is_decreasing(values)
+
+
+def assert_children_values_monotonic(predictor, monotonic_cst):
+    # Make sure sibling values respect the monotonic constraints. The left
+    # child should be lower (resp. greater) than the right child if the
+    # constraint is POS (resp. NEG).
+    # Note that this property alone isn't enough to ensure full monotonicity,
+    # since we also need to guarantee that all the descendants of the left
+    # child won't be greater (resp. lower) than the right child, or its
+    # descendants. That's why we need to bound the predicted values (this is
+    # tested in assert_children_values_bounded).
+ nodes = predictor.nodes
+ left_lower = []
+ left_greater = []
+ for node in nodes:
+ if node["is_leaf"]:
+ continue
+
+ left_idx = node["left"]
+ right_idx = node["right"]
+
+ if nodes[left_idx]["value"] < nodes[right_idx]["value"]:
+ left_lower.append(node)
+ elif nodes[left_idx]["value"] > nodes[right_idx]["value"]:
+ left_greater.append(node)
+
+ if monotonic_cst == MonotonicConstraint.NO_CST:
+ assert left_lower and left_greater
+ elif monotonic_cst == MonotonicConstraint.POS:
+ assert left_lower and not left_greater
+ else: # NEG
+ assert not left_lower and left_greater
+
+
+def assert_children_values_bounded(grower, monotonic_cst):
+ # Make sure that the values of the children of a node are bounded by the
+ # middle value between that node and its sibling (if there is a monotonic
+ # constraint).
+ # As a bonus, we also check that the siblings values are properly ordered
+ # which is slightly redundant with assert_children_values_monotonic (but
+ # this check is done on the grower nodes whereas
+ # assert_children_values_monotonic is done on the predictor nodes)
+
+ if monotonic_cst == MonotonicConstraint.NO_CST:
+ return
+
+ def recursively_check_children_node_values(node, right_sibling=None):
+ if node.is_leaf:
+ return
+ if right_sibling is not None:
+ middle = (node.value + right_sibling.value) / 2
+ if monotonic_cst == MonotonicConstraint.POS:
+ assert node.left_child.value <= node.right_child.value <= middle
+ if not right_sibling.is_leaf:
+ assert (
+ middle
+ <= right_sibling.left_child.value
+ <= right_sibling.right_child.value
+ )
+ else: # NEG
+ assert node.left_child.value >= node.right_child.value >= middle
+ if not right_sibling.is_leaf:
+ assert (
+ middle
+ >= right_sibling.left_child.value
+ >= right_sibling.right_child.value
+ )
+
+ recursively_check_children_node_values(
+ node.left_child, right_sibling=node.right_child
+ )
+ recursively_check_children_node_values(node.right_child)
+
+ recursively_check_children_node_values(grower.root)
+
+
+@pytest.mark.parametrize("seed", range(3))
+@pytest.mark.parametrize(
+ "monotonic_cst",
+ (
+ MonotonicConstraint.NO_CST,
+ MonotonicConstraint.POS,
+ MonotonicConstraint.NEG,
+ ),
+)
+def test_nodes_values(monotonic_cst, seed):
+ # Build a single tree with only one feature, and make sure the nodes
+ # values respect the monotonic constraints.
+
+ # Considering the following tree with a monotonic POS constraint, we
+ # should have:
+ #
+ # root
+ # / \
+ # 5 10 # middle = 7.5
+ # / \ / \
+ # a b c d
+ #
+ # a <= b and c <= d (assert_children_values_monotonic)
+ # a, b <= middle <= c, d (assert_children_values_bounded)
+ # a <= b <= c <= d (assert_leaves_values_monotonic)
+ #
+    # The last one is a consequence of the others, but it can't hurt to check
+
+ rng = np.random.RandomState(seed)
+ n_samples = 1000
+ n_features = 1
+ X_binned = rng.randint(0, 255, size=(n_samples, n_features), dtype=np.uint8)
+ X_binned = np.asfortranarray(X_binned)
+
+ gradients = rng.normal(size=n_samples).astype(G_H_DTYPE)
+ hessians = np.ones(shape=1, dtype=G_H_DTYPE)
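+    # A single-element hessians array signals to TreeGrower that the hessians
+    # are constant (all equal to one here).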
+
+ grower = TreeGrower(
+ X_binned, gradients, hessians, monotonic_cst=[monotonic_cst], shrinkage=0.1
+ )
+ grower.grow()
+
+ # grow() will shrink the leaves values at the very end. For our comparison
+ # tests, we need to revert the shrinkage of the leaves, else we would
+ # compare the value of a leaf (shrunk) with a node (not shrunk) and the
+ # test would not be correct.
+ for leave in grower.finalized_leaves:
+ leave.value /= grower.shrinkage
+
+ # We pass undefined binning_thresholds because we won't use predict anyway
+ predictor = grower.make_predictor(
+ binning_thresholds=np.zeros((X_binned.shape[1], X_binned.max() + 1))
+ )
+
+ # The consistency of the bounds can only be checked on the tree grower
+ # as the node bounds are not copied into the predictor tree. The
+ # consistency checks on the values of node children and leaves can be
+ # done either on the grower tree or on the predictor tree. We only
+ # do those checks on the predictor tree as the latter is derived from
+ # the former.
+ assert_children_values_monotonic(predictor, monotonic_cst)
+ assert_children_values_bounded(grower, monotonic_cst)
+ assert_leaves_values_monotonic(predictor, monotonic_cst)
+
+
+@pytest.mark.parametrize("use_feature_names", (True, False))
+def test_predictions(global_random_seed, use_feature_names):
+ # Train a model with a POS constraint on the first feature and a NEG
+ # constraint on the second feature, and make sure the constraints are
+ # respected by checking the predictions.
+ # test adapted from lightgbm's test_monotone_constraint(), itself inspired
+ # by https://xgboost.readthedocs.io/en/latest/tutorials/monotonic.html
+
+ rng = np.random.RandomState(global_random_seed)
+
+ n_samples = 1000
+ f_0 = rng.rand(n_samples) # positive correlation with y
+    f_1 = rng.rand(n_samples)  # negative correlation with y
+ X = np.c_[f_0, f_1]
+ if use_feature_names:
+ pd = pytest.importorskip("modin.pandas")
+ X = pd.DataFrame(X, columns=["f_0", "f_1"])
+
+ noise = rng.normal(loc=0.0, scale=0.01, size=n_samples)
+ y = 5 * f_0 + np.sin(10 * np.pi * f_0) - 5 * f_1 - np.cos(10 * np.pi * f_1) + noise
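+    # Note: the sinusoidal terms make the target locally non-monotonic in each
+    # feature; only the overall trend is increasing in f_0 and decreasing in
+    # f_1, so monotonic predictions must come from the constraints rather than
+    # from the data alone.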
+
+ if use_feature_names:
+ monotonic_cst = {"f_0": +1, "f_1": -1}
+ else:
+ monotonic_cst = [+1, -1]
+
+ gbdt = HistGradientBoostingRegressor(monotonic_cst=monotonic_cst)
+ gbdt.fit(X, y)
+
+ linspace = np.linspace(0, 1, 100)
+ sin = np.sin(linspace)
+ constant = np.full_like(linspace, fill_value=0.5)
+
+    # We now assert that the predictions properly respect the constraints, on
+    # each feature. When testing for a feature we need to set the other one to
+    # a constant, because the monotonic constraints are only an "all else being
+    # equal" type of constraint:
+    # a constraint on the first feature only means that
+    # x0 < x0' => f(x0, x1) < f(x0', x1)
+    # while x1 stays constant.
+    # The constraint does not guarantee that
+    # x0 < x0' => f(x0, x1) < f(x0', x1')
+
+ # First feature (POS)
+ # assert pred is all increasing when f_0 is all increasing
+ X = np.c_[linspace, constant]
+ pred = gbdt.predict(X)
+ assert is_increasing(pred)
+ # assert pred actually follows the variations of f_0
+ X = np.c_[sin, constant]
+ pred = gbdt.predict(X)
+ assert np.all((np.diff(pred) >= 0) == (np.diff(sin) >= 0))
+
+ # Second feature (NEG)
+ # assert pred is all decreasing when f_1 is all increasing
+ X = np.c_[constant, linspace]
+ pred = gbdt.predict(X)
+ assert is_decreasing(pred)
+ # assert pred actually follows the inverse variations of f_1
+ X = np.c_[constant, sin]
+ pred = gbdt.predict(X)
+ assert ((np.diff(pred) <= 0) == (np.diff(sin) >= 0)).all()
+
+
+def test_input_error():
+ X = [[1, 2], [2, 3], [3, 4]]
+ y = [0, 1, 2]
+
+ gbdt = HistGradientBoostingRegressor(monotonic_cst=[1, 0, -1])
+ with pytest.raises(
+ ValueError, match=re.escape("monotonic_cst has shape (3,) but the input data")
+ ):
+ gbdt.fit(X, y)
+
+ for monotonic_cst in ([1, 3], [1, -3], [0.3, -0.7]):
+ gbdt = HistGradientBoostingRegressor(monotonic_cst=monotonic_cst)
+ expected_msg = re.escape(
+ "must be an array-like of -1, 0 or 1. Observed values:"
+ )
+ with pytest.raises(ValueError, match=expected_msg):
+ gbdt.fit(X, y)
+
+ gbdt = HistGradientBoostingClassifier(monotonic_cst=[0, 1])
+ with pytest.raises(
+ ValueError,
+ match="monotonic constraints are not supported for multiclass classification",
+ ):
+ gbdt.fit(X, y)
+
+
+def test_input_error_related_to_feature_names():
+ pd = pytest.importorskip("modin.pandas")
+ X = pd.DataFrame({"a": [0, 1, 2], "b": [0, 1, 2]})
+ y = np.array([0, 1, 0])
+
+ monotonic_cst = {"d": 1, "a": 1, "c": -1}
+ gbdt = HistGradientBoostingRegressor(monotonic_cst=monotonic_cst)
+ expected_msg = re.escape(
+ "monotonic_cst contains 2 unexpected feature names: ['c', 'd']."
+ )
+ with pytest.raises(ValueError, match=expected_msg):
+ gbdt.fit(X, y)
+
+ monotonic_cst = {k: 1 for k in "abcdefghijklmnopqrstuvwxyz"}
+ gbdt = HistGradientBoostingRegressor(monotonic_cst=monotonic_cst)
+ expected_msg = re.escape(
+ "monotonic_cst contains 24 unexpected feature names: "
+ "['c', 'd', 'e', 'f', 'g', '...']."
+ )
+ with pytest.raises(ValueError, match=expected_msg):
+ gbdt.fit(X, y)
+
+ monotonic_cst = {"a": 1}
+ gbdt = HistGradientBoostingRegressor(monotonic_cst=monotonic_cst)
+ expected_msg = re.escape(
+ "HistGradientBoostingRegressor was not fitted on data with feature "
+ "names. Pass monotonic_cst as an integer array instead."
+ )
+ with pytest.raises(ValueError, match=expected_msg):
+ gbdt.fit(X.values, y)
+
+ monotonic_cst = {"b": -1, "a": "+"}
+ gbdt = HistGradientBoostingRegressor(monotonic_cst=monotonic_cst)
+ expected_msg = re.escape("monotonic_cst['a'] must be either -1, 0 or 1. Got '+'.")
+ with pytest.raises(ValueError, match=expected_msg):
+ gbdt.fit(X, y)
+
+
+def test_bounded_value_min_gain_to_split():
+ # The purpose of this test is to show that when computing the gain at a
+ # given split, the value of the current node should be properly bounded to
+ # respect the monotonic constraints, because it strongly interacts with
+ # min_gain_to_split. We build a simple example where gradients are [1, 1,
+ # 100, 1, 1] (hessians are all ones). The best split happens on the 3rd
+ # bin, and depending on whether the value of the node is bounded or not,
+ # the min_gain_to_split constraint is or isn't satisfied.
+ l2_regularization = 0
+ min_hessian_to_split = 0
+ min_samples_leaf = 1
+ n_bins = n_samples = 5
+ X_binned = np.arange(n_samples).reshape(-1, 1).astype(X_BINNED_DTYPE)
+ sample_indices = np.arange(n_samples, dtype=np.uint32)
+ all_hessians = np.ones(n_samples, dtype=G_H_DTYPE)
+ all_gradients = np.array([1, 1, 100, 1, 1], dtype=G_H_DTYPE)
+ sum_gradients = all_gradients.sum()
+ sum_hessians = all_hessians.sum()
+ hessians_are_constant = False
+
+ builder = HistogramBuilder(
+ X_binned, n_bins, all_gradients, all_hessians, hessians_are_constant, n_threads
+ )
+ n_bins_non_missing = np.array([n_bins - 1] * X_binned.shape[1], dtype=np.uint32)
+ has_missing_values = np.array([False] * X_binned.shape[1], dtype=np.uint8)
+ monotonic_cst = np.array(
+ [MonotonicConstraint.NO_CST] * X_binned.shape[1], dtype=np.int8
+ )
+ is_categorical = np.zeros_like(monotonic_cst, dtype=np.uint8)
+ missing_values_bin_idx = n_bins - 1
+ children_lower_bound, children_upper_bound = -np.inf, np.inf
+
+ min_gain_to_split = 2000
+ splitter = Splitter(
+ X_binned,
+ n_bins_non_missing,
+ missing_values_bin_idx,
+ has_missing_values,
+ is_categorical,
+ monotonic_cst,
+ l2_regularization,
+ min_hessian_to_split,
+ min_samples_leaf,
+ min_gain_to_split,
+ hessians_are_constant,
+ )
+
+ histograms = builder.compute_histograms_brute(sample_indices)
+
+ # Since the gradient array is [1, 1, 100, 1, 1]
+ # the max possible gain happens on the 3rd bin (or equivalently in the 2nd)
+    # and is equal to about 1307, which is less than min_gain_to_split = 2000, so
+ # the node is considered unsplittable (gain = -1)
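+    # (With l2_regularization = 0, that gain is roughly
+    # G_L**2 / H_L + G_R**2 / H_R - G**2 / H = 102**2 / 3 + 2**2 / 2 - 104**2 / 5
+    # ~= 1306.8, consistent with the ~1307 mentioned above.)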
+ current_lower_bound, current_upper_bound = -np.inf, np.inf
+ value = compute_node_value(
+ sum_gradients,
+ sum_hessians,
+ current_lower_bound,
+ current_upper_bound,
+ l2_regularization,
+ )
+ # the unbounded value is equal to -sum_gradients / sum_hessians
+ assert value == pytest.approx(-104 / 5)
+ split_info = splitter.find_node_split(
+ n_samples,
+ histograms,
+ sum_gradients,
+ sum_hessians,
+ value,
+ lower_bound=children_lower_bound,
+ upper_bound=children_upper_bound,
+ )
+ assert split_info.gain == -1 # min_gain_to_split not respected
+
+ # here again the max possible gain is on the 3rd bin but we now cap the
+ # value of the node into [-10, inf].
+ # This means the gain is now about 2430 which is more than the
+ # min_gain_to_split constraint.
+ current_lower_bound, current_upper_bound = -10, np.inf
+ value = compute_node_value(
+ sum_gradients,
+ sum_hessians,
+ current_lower_bound,
+ current_upper_bound,
+ l2_regularization,
+ )
+ assert value == -10
+ split_info = splitter.find_node_split(
+ n_samples,
+ histograms,
+ sum_gradients,
+ sum_hessians,
+ value,
+ lower_bound=children_lower_bound,
+ upper_bound=children_upper_bound,
+ )
+ assert split_info.gain > min_gain_to_split
diff --git a/modin/pandas/test/interoperability/sklearn/feature_selection/test_from_model.py b/modin/pandas/test/interoperability/sklearn/feature_selection/test_from_model.py
new file mode 100644
index 00000000000..e57b680892d
--- /dev/null
+++ b/modin/pandas/test/interoperability/sklearn/feature_selection/test_from_model.py
@@ -0,0 +1,673 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+# flake8: noqa
+
+import re
+import pytest
+import numpy as np
+import warnings
+from unittest.mock import Mock
+from sklearn.utils._testing import assert_array_almost_equal
+from sklearn.utils._testing import assert_array_equal
+from sklearn.utils._testing import assert_allclose
+from sklearn.utils._testing import skip_if_32bit
+from sklearn.utils._testing import MinimalClassifier
+from sklearn import datasets
+from sklearn.cross_decomposition import CCA, PLSCanonical, PLSRegression
+from sklearn.datasets import make_friedman1
+from sklearn.exceptions import NotFittedError
+from sklearn.linear_model import (
+ LogisticRegression,
+ SGDClassifier,
+ Lasso,
+ LassoCV,
+ ElasticNet,
+ ElasticNetCV,
+)
+from sklearn.svm import LinearSVC
+from sklearn.feature_selection import SelectFromModel
+from sklearn.ensemble import RandomForestClassifier, HistGradientBoostingClassifier
+from sklearn.linear_model import PassiveAggressiveClassifier
+from sklearn.base import BaseEstimator
+from sklearn.pipeline import make_pipeline
+from sklearn.decomposition import PCA
+
+
+class NaNTag(BaseEstimator):
+ def _more_tags(self):
+ return {"allow_nan": True}
+
+
+class NoNaNTag(BaseEstimator):
+ def _more_tags(self):
+ return {"allow_nan": False}
+
+
+class NaNTagRandomForest(RandomForestClassifier):
+ def _more_tags(self):
+ return {"allow_nan": True}
+
+
+iris = datasets.load_iris()
+data, y = iris.data, iris.target
+rng = np.random.RandomState(0)
+
+
+def test_invalid_input():
+ clf = SGDClassifier(
+ alpha=0.1, max_iter=10, shuffle=True, random_state=None, tol=None
+ )
+ for threshold in ["gobbledigook", ".5 * gobbledigook"]:
+ model = SelectFromModel(clf, threshold=threshold)
+ model.fit(data, y)
+ with pytest.raises(ValueError):
+ model.transform(data)
+
+
+def test_input_estimator_unchanged():
+ # Test that SelectFromModel fits on a clone of the estimator.
+ est = RandomForestClassifier()
+ transformer = SelectFromModel(estimator=est)
+ transformer.fit(data, y)
+ assert transformer.estimator is est
+
+
+@pytest.mark.parametrize(
+ "max_features, err_type, err_msg",
+ [
+ (
+ data.shape[1] + 1,
+ ValueError,
+ "max_features ==",
+ ),
+ (
+ lambda X: 1.5,
+ TypeError,
+ "max_features must be an instance of int, not float.",
+ ),
+ (
+ lambda X: data.shape[1] + 1,
+ ValueError,
+ "max_features ==",
+ ),
+ (
+ lambda X: -1,
+ ValueError,
+ "max_features ==",
+ ),
+ ],
+)
+def test_max_features_error(max_features, err_type, err_msg):
+ err_msg = re.escape(err_msg)
+ clf = RandomForestClassifier(n_estimators=5, random_state=0)
+
+ transformer = SelectFromModel(
+ estimator=clf, max_features=max_features, threshold=-np.inf
+ )
+ with pytest.raises(err_type, match=err_msg):
+ transformer.fit(data, y)
+
+
+@pytest.mark.parametrize("max_features", [0, 2, data.shape[1], None])
+def test_inferred_max_features_integer(max_features):
+ """Check max_features_ and output shape for integer max_features."""
+ clf = RandomForestClassifier(n_estimators=5, random_state=0)
+ transformer = SelectFromModel(
+ estimator=clf, max_features=max_features, threshold=-np.inf
+ )
+ X_trans = transformer.fit_transform(data, y)
+ if max_features is not None:
+ assert transformer.max_features_ == max_features
+ assert X_trans.shape[1] == transformer.max_features_
+ else:
+ assert not hasattr(transformer, "max_features_")
+ assert X_trans.shape[1] == data.shape[1]
+
+
+@pytest.mark.parametrize(
+ "max_features",
+ [lambda X: 1, lambda X: X.shape[1], lambda X: min(X.shape[1], 10000)],
+)
+def test_inferred_max_features_callable(max_features):
+ """Check max_features_ and output shape for callable max_features."""
+ clf = RandomForestClassifier(n_estimators=5, random_state=0)
+ transformer = SelectFromModel(
+ estimator=clf, max_features=max_features, threshold=-np.inf
+ )
+ X_trans = transformer.fit_transform(data, y)
+ assert transformer.max_features_ == max_features(data)
+ assert X_trans.shape[1] == transformer.max_features_
+
+
+@pytest.mark.parametrize("max_features", [lambda X: round(len(X[0]) / 2), 2])
+def test_max_features_array_like(max_features):
+ X = [
+ [0.87, -1.34, 0.31],
+ [-2.79, -0.02, -0.85],
+ [-1.34, -0.48, -2.55],
+ [1.92, 1.48, 0.65],
+ ]
+ y = [0, 1, 0, 1]
+
+ clf = RandomForestClassifier(n_estimators=5, random_state=0)
+ transformer = SelectFromModel(
+ estimator=clf, max_features=max_features, threshold=-np.inf
+ )
+ X_trans = transformer.fit_transform(X, y)
+ assert X_trans.shape[1] == transformer.max_features_
+
+
+@pytest.mark.parametrize(
+ "max_features",
+ [lambda X: min(X.shape[1], 10000), lambda X: X.shape[1], lambda X: 1],
+)
+def test_max_features_callable_data(max_features):
+ """Tests that the callable passed to `fit` is called on X."""
+ clf = RandomForestClassifier(n_estimators=50, random_state=0)
+ m = Mock(side_effect=max_features)
+ transformer = SelectFromModel(estimator=clf, max_features=m, threshold=-np.inf)
+ transformer.fit_transform(data, y)
+ m.assert_called_with(data)
+
+
+class FixedImportanceEstimator(BaseEstimator):
+ def __init__(self, importances):
+ self.importances = importances
+
+ def fit(self, X, y=None):
+ self.feature_importances_ = np.array(self.importances)
+
+
+def test_max_features():
+ # Test max_features parameter using various values
+ X, y = datasets.make_classification(
+ n_samples=1000,
+ n_features=10,
+ n_informative=3,
+ n_redundant=0,
+ n_repeated=0,
+ shuffle=False,
+ random_state=0,
+ )
+ max_features = X.shape[1]
+ est = RandomForestClassifier(n_estimators=50, random_state=0)
+
+ transformer1 = SelectFromModel(estimator=est, threshold=-np.inf)
+ transformer2 = SelectFromModel(
+ estimator=est, max_features=max_features, threshold=-np.inf
+ )
+ X_new1 = transformer1.fit_transform(X, y)
+ X_new2 = transformer2.fit_transform(X, y)
+ assert_allclose(X_new1, X_new2)
+
+ # Test max_features against actual model.
+ transformer1 = SelectFromModel(estimator=Lasso(alpha=0.025, random_state=42))
+ X_new1 = transformer1.fit_transform(X, y)
+ scores1 = np.abs(transformer1.estimator_.coef_)
+ candidate_indices1 = np.argsort(-scores1, kind="mergesort")
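+    # kind="mergesort" gives a stable sort, so features with tied coefficients
+    # keep their original order and the candidate indices are deterministic.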
+
+ for n_features in range(1, X_new1.shape[1] + 1):
+ transformer2 = SelectFromModel(
+ estimator=Lasso(alpha=0.025, random_state=42),
+ max_features=n_features,
+ threshold=-np.inf,
+ )
+ X_new2 = transformer2.fit_transform(X, y)
+ scores2 = np.abs(transformer2.estimator_.coef_)
+ candidate_indices2 = np.argsort(-scores2, kind="mergesort")
+ assert_allclose(
+ X[:, candidate_indices1[:n_features]], X[:, candidate_indices2[:n_features]]
+ )
+ assert_allclose(transformer1.estimator_.coef_, transformer2.estimator_.coef_)
+
+
+def test_max_features_tiebreak():
+    # Test if max_features can break ties among feature importances
+ X, y = datasets.make_classification(
+ n_samples=1000,
+ n_features=10,
+ n_informative=3,
+ n_redundant=0,
+ n_repeated=0,
+ shuffle=False,
+ random_state=0,
+ )
+ max_features = X.shape[1]
+
+ feature_importances = np.array([4, 4, 4, 4, 3, 3, 3, 2, 2, 1])
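+    # The importances contain ties; selecting exactly n_features must break the
+    # ties by feature order, which the assertion below checks via
+    # np.arange(n_features).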
+ for n_features in range(1, max_features + 1):
+ transformer = SelectFromModel(
+ FixedImportanceEstimator(feature_importances),
+ max_features=n_features,
+ threshold=-np.inf,
+ )
+ X_new = transformer.fit_transform(X, y)
+ selected_feature_indices = np.where(transformer._get_support_mask())[0]
+ assert_array_equal(selected_feature_indices, np.arange(n_features))
+ assert X_new.shape[1] == n_features
+
+
+def test_threshold_and_max_features():
+ X, y = datasets.make_classification(
+ n_samples=1000,
+ n_features=10,
+ n_informative=3,
+ n_redundant=0,
+ n_repeated=0,
+ shuffle=False,
+ random_state=0,
+ )
+ est = RandomForestClassifier(n_estimators=50, random_state=0)
+
+ transformer1 = SelectFromModel(estimator=est, max_features=3, threshold=-np.inf)
+ X_new1 = transformer1.fit_transform(X, y)
+
+ transformer2 = SelectFromModel(estimator=est, threshold=0.04)
+ X_new2 = transformer2.fit_transform(X, y)
+
+ transformer3 = SelectFromModel(estimator=est, max_features=3, threshold=0.04)
+ X_new3 = transformer3.fit_transform(X, y)
+ assert X_new3.shape[1] == min(X_new1.shape[1], X_new2.shape[1])
+ selected_indices = transformer3.transform(np.arange(X.shape[1])[np.newaxis, :])
+ assert_allclose(X_new3, X[:, selected_indices[0]])
+
+
+@skip_if_32bit
+def test_feature_importances():
+ X, y = datasets.make_classification(
+ n_samples=1000,
+ n_features=10,
+ n_informative=3,
+ n_redundant=0,
+ n_repeated=0,
+ shuffle=False,
+ random_state=0,
+ )
+
+ est = RandomForestClassifier(n_estimators=50, random_state=0)
+ for threshold, func in zip(["mean", "median"], [np.mean, np.median]):
+ transformer = SelectFromModel(estimator=est, threshold=threshold)
+ transformer.fit(X, y)
+ assert hasattr(transformer.estimator_, "feature_importances_")
+
+ X_new = transformer.transform(X)
+ assert X_new.shape[1] < X.shape[1]
+ importances = transformer.estimator_.feature_importances_
+
+ feature_mask = np.abs(importances) > func(importances)
+ assert_array_almost_equal(X_new, X[:, feature_mask])
+
+
+def test_sample_weight():
+ # Ensure sample weights are passed to underlying estimator
+ X, y = datasets.make_classification(
+ n_samples=100,
+ n_features=10,
+ n_informative=3,
+ n_redundant=0,
+ n_repeated=0,
+ shuffle=False,
+ random_state=0,
+ )
+
+ # Check with sample weights
+ sample_weight = np.ones(y.shape)
+ sample_weight[y == 1] *= 100
+
+ est = LogisticRegression(random_state=0, fit_intercept=False)
+ transformer = SelectFromModel(estimator=est)
+ transformer.fit(X, y, sample_weight=None)
+ mask = transformer._get_support_mask()
+ transformer.fit(X, y, sample_weight=sample_weight)
+ weighted_mask = transformer._get_support_mask()
+ assert not np.all(weighted_mask == mask)
+ transformer.fit(X, y, sample_weight=3 * sample_weight)
+ reweighted_mask = transformer._get_support_mask()
+ assert np.all(weighted_mask == reweighted_mask)
+
+
+@pytest.mark.parametrize(
+ "estimator",
+ [
+ Lasso(alpha=0.1, random_state=42),
+ LassoCV(random_state=42),
+ ElasticNet(l1_ratio=1, random_state=42),
+ ElasticNetCV(l1_ratio=[1], random_state=42),
+ ],
+)
+def test_coef_default_threshold(estimator):
+ X, y = datasets.make_classification(
+ n_samples=100,
+ n_features=10,
+ n_informative=3,
+ n_redundant=0,
+ n_repeated=0,
+ shuffle=False,
+ random_state=0,
+ )
+
+ # For the Lasso and related models, the threshold defaults to 1e-5
+ transformer = SelectFromModel(estimator=estimator)
+ transformer.fit(X, y)
+ X_new = transformer.transform(X)
+ mask = np.abs(transformer.estimator_.coef_) > 1e-5
+ assert_array_almost_equal(X_new, X[:, mask])
+
+
+@skip_if_32bit
+def test_2d_coef():
+ X, y = datasets.make_classification(
+ n_samples=1000,
+ n_features=10,
+ n_informative=3,
+ n_redundant=0,
+ n_repeated=0,
+ shuffle=False,
+ random_state=0,
+ n_classes=4,
+ )
+
+ est = LogisticRegression()
+ for threshold, func in zip(["mean", "median"], [np.mean, np.median]):
+ for order in [1, 2, np.inf]:
+            # Fit SelectFromModel on a multi-class problem
+ transformer = SelectFromModel(
+ estimator=LogisticRegression(), threshold=threshold, norm_order=order
+ )
+ transformer.fit(X, y)
+ assert hasattr(transformer.estimator_, "coef_")
+ X_new = transformer.transform(X)
+ assert X_new.shape[1] < X.shape[1]
+
+ # Manually check that the norm is correctly performed
+ est.fit(X, y)
+ importances = np.linalg.norm(est.coef_, axis=0, ord=order)
+ feature_mask = importances > func(importances)
+ assert_array_almost_equal(X_new, X[:, feature_mask])
+
+
+def test_partial_fit():
+ est = PassiveAggressiveClassifier(
+ random_state=0, shuffle=False, max_iter=5, tol=None
+ )
+ transformer = SelectFromModel(estimator=est)
+ transformer.partial_fit(data, y, classes=np.unique(y))
+ old_model = transformer.estimator_
+ transformer.partial_fit(data, y, classes=np.unique(y))
+ new_model = transformer.estimator_
+ assert old_model is new_model
+
+ X_transform = transformer.transform(data)
+ transformer.fit(np.vstack((data, data)), np.concatenate((y, y)))
+ assert_array_almost_equal(X_transform, transformer.transform(data))
+
+ # check that if est doesn't have partial_fit, neither does SelectFromModel
+ transformer = SelectFromModel(estimator=RandomForestClassifier())
+ assert not hasattr(transformer, "partial_fit")
+
+
+def test_calling_fit_reinitializes():
+ est = LinearSVC(random_state=0)
+ transformer = SelectFromModel(estimator=est)
+ transformer.fit(data, y)
+ transformer.set_params(estimator__C=100)
+ transformer.fit(data, y)
+ assert transformer.estimator_.C == 100
+
+
+def test_prefit():
+ # Test all possible combinations of the prefit parameter.
+
+    # Passing an already fitted estimator with prefit=True and fitting an
+    # unfit model with prefit=False should give the same results.
+ clf = SGDClassifier(alpha=0.1, max_iter=10, shuffle=True, random_state=0, tol=None)
+ model = SelectFromModel(clf)
+ model.fit(data, y)
+ X_transform = model.transform(data)
+ clf.fit(data, y)
+ model = SelectFromModel(clf, prefit=True)
+ assert_array_almost_equal(model.transform(data), X_transform)
+ model.fit(data, y)
+ assert model.estimator_ is not clf
+
+    # Check that the model is refit if prefit=False and a fitted model is
+    # passed
+ model = SelectFromModel(clf, prefit=False)
+ model.fit(data, y)
+ assert_array_almost_equal(model.transform(data), X_transform)
+
+ # Check that passing an unfitted estimator with `prefit=True` raises a
+ # `ValueError`
+ clf = SGDClassifier(alpha=0.1, max_iter=10, shuffle=True, random_state=0, tol=None)
+ model = SelectFromModel(clf, prefit=True)
+ err_msg = "When `prefit=True`, `estimator` is expected to be a fitted estimator."
+ with pytest.raises(NotFittedError, match=err_msg):
+ model.fit(data, y)
+ with pytest.raises(NotFittedError, match=err_msg):
+ model.partial_fit(data, y)
+ with pytest.raises(NotFittedError, match=err_msg):
+ model.transform(data)
+
+ # Check that the internal parameters of prefitted model are not changed
+ # when calling `fit` or `partial_fit` with `prefit=True`
+ clf = SGDClassifier(alpha=0.1, max_iter=10, shuffle=True, tol=None).fit(data, y)
+ model = SelectFromModel(clf, prefit=True)
+ model.fit(data, y)
+ assert_allclose(model.estimator_.coef_, clf.coef_)
+ model.partial_fit(data, y)
+ assert_allclose(model.estimator_.coef_, clf.coef_)
+
+
+def test_prefit_max_features():
+ """Check the interaction between `prefit` and `max_features`."""
+ # case 1: an error should be raised at `transform` if `fit` was not called to
+ # validate the attributes
+ estimator = RandomForestClassifier(n_estimators=5, random_state=0)
+ estimator.fit(data, y)
+ model = SelectFromModel(estimator, prefit=True, max_features=lambda X: X.shape[1])
+
+ err_msg = (
+ "When `prefit=True` and `max_features` is a callable, call `fit` "
+ "before calling `transform`."
+ )
+ with pytest.raises(NotFittedError, match=err_msg):
+ model.transform(data)
+
+ # case 2: `max_features` is not validated and different from an integer
+ # FIXME: we cannot validate the upper bound of the attribute at transform
+ # and we should force calling `fit` if we intend to force the attribute
+ # to have such an upper bound.
+ max_features = 2.5
+ model.set_params(max_features=max_features)
+ with pytest.raises(ValueError, match="`max_features` must be an integer"):
+ model.transform(data)
+
+
+def test_prefit_get_feature_names_out():
+ """Check the interaction between prefit and the feature names."""
+ clf = RandomForestClassifier(n_estimators=2, random_state=0)
+ clf.fit(data, y)
+ model = SelectFromModel(clf, prefit=True, max_features=1)
+
+ # FIXME: the error message should be improved. Raising a `NotFittedError`
+    # would be better since it would force validating all class attributes and
+    # creating all the necessary fitted attributes
+ err_msg = "Unable to generate feature names without n_features_in_"
+ with pytest.raises(ValueError, match=err_msg):
+ model.get_feature_names_out()
+
+ model.fit(data, y)
+ feature_names = model.get_feature_names_out()
+ assert feature_names == ["x3"]
+
+
+def test_threshold_string():
+ est = RandomForestClassifier(n_estimators=50, random_state=0)
+ model = SelectFromModel(est, threshold="0.5*mean")
+ model.fit(data, y)
+ X_transform = model.transform(data)
+
+ # Calculate the threshold from the estimator directly.
+ est.fit(data, y)
+ threshold = 0.5 * np.mean(est.feature_importances_)
+ mask = est.feature_importances_ > threshold
+ assert_array_almost_equal(X_transform, data[:, mask])
+
+
+def test_threshold_without_refitting():
+ # Test that the threshold can be set without refitting the model.
+ clf = SGDClassifier(alpha=0.1, max_iter=10, shuffle=True, random_state=0, tol=None)
+ model = SelectFromModel(clf, threshold="0.1 * mean")
+ model.fit(data, y)
+ X_transform = model.transform(data)
+
+ # Set a higher threshold to filter out more features.
+ model.threshold = "1.0 * mean"
+ assert X_transform.shape[1] > model.transform(data).shape[1]
+
+
+def test_fit_accepts_nan_inf():
+ # Test that fit doesn't check for np.inf and np.nan values.
+ clf = HistGradientBoostingClassifier(random_state=0)
+
+ model = SelectFromModel(estimator=clf)
+
+ nan_data = data.copy()
+ nan_data[0] = np.NaN
+ nan_data[1] = np.Inf
+
+    model.fit(nan_data, y)
+
+
+def test_transform_accepts_nan_inf():
+ # Test that transform doesn't check for np.inf and np.nan values.
+ clf = NaNTagRandomForest(n_estimators=100, random_state=0)
+ nan_data = data.copy()
+
+ model = SelectFromModel(estimator=clf)
+ model.fit(nan_data, y)
+
+ nan_data[0] = np.NaN
+ nan_data[1] = np.Inf
+
+ model.transform(nan_data)
+
+
+def test_allow_nan_tag_comes_from_estimator():
+ allow_nan_est = NaNTag()
+ model = SelectFromModel(estimator=allow_nan_est)
+ assert model._get_tags()["allow_nan"] is True
+
+ no_nan_est = NoNaNTag()
+ model = SelectFromModel(estimator=no_nan_est)
+ assert model._get_tags()["allow_nan"] is False
+
+
+def _pca_importances(pca_estimator):
+ return np.abs(pca_estimator.explained_variance_)
+
+
+@pytest.mark.parametrize(
+ "estimator, importance_getter",
+ [
+ (
+ make_pipeline(PCA(random_state=0), LogisticRegression()),
+ "named_steps.logisticregression.coef_",
+ ),
+ (PCA(random_state=0), _pca_importances),
+ ],
+)
+def test_importance_getter(estimator, importance_getter):
+ selector = SelectFromModel(
+ estimator, threshold="mean", importance_getter=importance_getter
+ )
+ selector.fit(data, y)
+ assert selector.transform(data).shape[1] == 1
+
+
+@pytest.mark.parametrize("PLSEstimator", [CCA, PLSCanonical, PLSRegression])
+def test_select_from_model_pls(PLSEstimator):
+ """Check the behaviour of SelectFromModel with PLS estimators.
+
+ Non-regression test for:
+ https://github.com/scikit-learn/scikit-learn/issues/12410
+ """
+ X, y = make_friedman1(n_samples=50, n_features=10, random_state=0)
+ estimator = PLSEstimator(n_components=1)
+ model = make_pipeline(SelectFromModel(estimator), estimator).fit(X, y)
+ assert model.score(X, y) > 0.5
+
+
+def test_estimator_does_not_support_feature_names():
+ """SelectFromModel works with estimators that do not support feature_names_in_.
+
+ Non-regression test for #21949.
+ """
+ pytest.importorskip("modin.pandas")
+ X, y = datasets.load_iris(as_frame=True, return_X_y=True)
+ all_feature_names = set(X.columns)
+
+ def importance_getter(estimator):
+ return np.arange(X.shape[1])
+
+ selector = SelectFromModel(
+ MinimalClassifier(), importance_getter=importance_getter
+ ).fit(X, y)
+
+ # selector learns the feature names itself
+ assert_array_equal(selector.feature_names_in_, X.columns)
+
+ feature_names_out = set(selector.get_feature_names_out())
+ assert feature_names_out < all_feature_names
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("error", UserWarning)
+
+ selector.transform(X.iloc[1:3])
+
+
+@pytest.mark.parametrize(
+ "error, err_msg, max_features",
+ (
+ [ValueError, "max_features == 10, must be <= 4", 10],
+ [ValueError, "max_features == 5, must be <= 4", lambda x: x.shape[1] + 1],
+ ),
+)
+def test_partial_fit_validate_max_features(error, err_msg, max_features):
+ """Test that partial_fit from SelectFromModel validates `max_features`."""
+ X, y = datasets.make_classification(
+ n_samples=100,
+ n_features=4,
+ random_state=0,
+ )
+
+ with pytest.raises(error, match=err_msg):
+ SelectFromModel(
+ estimator=SGDClassifier(), max_features=max_features
+ ).partial_fit(X, y, classes=[0, 1])
+
+
+@pytest.mark.parametrize("as_frame", [True, False])
+def test_partial_fit_validate_feature_names(as_frame):
+ """Test that partial_fit from SelectFromModel validates `feature_names_in_`."""
+ pytest.importorskip("modin.pandas")
+ X, y = datasets.load_iris(as_frame=as_frame, return_X_y=True)
+
+ selector = SelectFromModel(estimator=SGDClassifier(), max_features=4).partial_fit(
+ X, y, classes=[0, 1, 2]
+ )
+ if as_frame:
+ assert_array_equal(selector.feature_names_in_, X.columns)
+ else:
+ assert not hasattr(selector, "feature_names_in_")
diff --git a/modin/pandas/test/interoperability/sklearn/impute/test_common.py b/modin/pandas/test/interoperability/sklearn/impute/test_common.py
new file mode 100644
index 00000000000..5a2648957e4
--- /dev/null
+++ b/modin/pandas/test/interoperability/sklearn/impute/test_common.py
@@ -0,0 +1,195 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+import pytest
+import numpy as np
+from scipy import sparse
+from sklearn.utils._testing import assert_allclose
+from sklearn.utils._testing import assert_allclose_dense_sparse
+from sklearn.utils._testing import assert_array_equal
+from sklearn.experimental import enable_iterative_imputer # noqa
+from sklearn.impute import IterativeImputer
+from sklearn.impute import KNNImputer
+from sklearn.impute import SimpleImputer
+
+
+def imputers():
+ return [IterativeImputer(tol=0.1), KNNImputer(), SimpleImputer()]
+
+
+def sparse_imputers():
+ return [SimpleImputer()]
+
+
+# ConvergenceWarning will be raised by the IterativeImputer
+@pytest.mark.filterwarnings("ignore::sklearn.exceptions.ConvergenceWarning")
+@pytest.mark.parametrize("imputer", imputers(), ids=lambda x: x.__class__.__name__)
+def test_imputation_missing_value_in_test_array(imputer):
+    # [Non-regression test for issue #13968] A missing value in the test set
+    # should not throw an error and should return a finite dataset
+ train = [[1], [2]]
+ test = [[3], [np.nan]]
+ imputer.set_params(add_indicator=True)
+ imputer.fit(train).transform(test)
+
+
+# ConvergenceWarning will be raised by the IterativeImputer
+@pytest.mark.filterwarnings("ignore::sklearn.exceptions.ConvergenceWarning")
+@pytest.mark.parametrize("marker", [np.nan, -1, 0])
+@pytest.mark.parametrize("imputer", imputers(), ids=lambda x: x.__class__.__name__)
+def test_imputers_add_indicator(marker, imputer):
+ X = np.array(
+ [
+ [marker, 1, 5, marker, 1],
+ [2, marker, 1, marker, 2],
+ [6, 3, marker, marker, 3],
+ [1, 2, 9, marker, 4],
+ ]
+ )
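+    # Columns 0-3 each contain at least one missing marker (column 3 is fully
+    # missing), so the indicator appends four columns; column 4 is complete.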
+ X_true_indicator = np.array(
+ [
+ [1.0, 0.0, 0.0, 1.0],
+ [0.0, 1.0, 0.0, 1.0],
+ [0.0, 0.0, 1.0, 1.0],
+ [0.0, 0.0, 0.0, 1.0],
+ ]
+ )
+ imputer.set_params(missing_values=marker, add_indicator=True)
+
+ X_trans = imputer.fit_transform(X)
+ assert_allclose(X_trans[:, -4:], X_true_indicator)
+ assert_array_equal(imputer.indicator_.features_, np.array([0, 1, 2, 3]))
+
+ imputer.set_params(add_indicator=False)
+ X_trans_no_indicator = imputer.fit_transform(X)
+ assert_allclose(X_trans[:, :-4], X_trans_no_indicator)
+
+
+# ConvergenceWarning will be raised by the IterativeImputer
+@pytest.mark.filterwarnings("ignore::sklearn.exceptions.ConvergenceWarning")
+@pytest.mark.parametrize("marker", [np.nan, -1])
+@pytest.mark.parametrize(
+ "imputer", sparse_imputers(), ids=lambda x: x.__class__.__name__
+)
+def test_imputers_add_indicator_sparse(imputer, marker):
+ X = sparse.csr_matrix(
+ [
+ [marker, 1, 5, marker, 1],
+ [2, marker, 1, marker, 2],
+ [6, 3, marker, marker, 3],
+ [1, 2, 9, marker, 4],
+ ]
+ )
+ X_true_indicator = sparse.csr_matrix(
+ [
+ [1.0, 0.0, 0.0, 1.0],
+ [0.0, 1.0, 0.0, 1.0],
+ [0.0, 0.0, 1.0, 1.0],
+ [0.0, 0.0, 0.0, 1.0],
+ ]
+ )
+ imputer.set_params(missing_values=marker, add_indicator=True)
+
+ X_trans = imputer.fit_transform(X)
+ assert_allclose_dense_sparse(X_trans[:, -4:], X_true_indicator)
+ assert_array_equal(imputer.indicator_.features_, np.array([0, 1, 2, 3]))
+
+ imputer.set_params(add_indicator=False)
+ X_trans_no_indicator = imputer.fit_transform(X)
+ assert_allclose_dense_sparse(X_trans[:, :-4], X_trans_no_indicator)
+
+
+# ConvergenceWarning will be raised by the IterativeImputer
+@pytest.mark.filterwarnings("ignore::sklearn.exceptions.ConvergenceWarning")
+@pytest.mark.parametrize("imputer", imputers(), ids=lambda x: x.__class__.__name__)
+@pytest.mark.parametrize("add_indicator", [True, False])
+def test_imputers_pandas_na_integer_array_support(imputer, add_indicator):
+ # Test pandas IntegerArray with pd.NA
+ pd = pytest.importorskip("modin.pandas")
+ marker = np.nan
+ imputer = imputer.set_params(add_indicator=add_indicator, missing_values=marker)
+
+ X = np.array(
+ [
+ [marker, 1, 5, marker, 1],
+ [2, marker, 1, marker, 2],
+ [6, 3, marker, marker, 3],
+ [1, 2, 9, marker, 4],
+ ]
+ )
+ # fit on numpy array
+ X_trans_expected = imputer.fit_transform(X)
+
+    # Create a dataframe with IntegerArray columns holding pd.NA
+ X_df = pd.DataFrame(X, dtype="Int16", columns=["a", "b", "c", "d", "e"])
+
+ # fit on pandas dataframe with IntegerArrays
+ X_trans = imputer.fit_transform(X_df)
+
+ assert_allclose(X_trans_expected, X_trans)
+
+
+@pytest.mark.parametrize("imputer", imputers(), ids=lambda x: x.__class__.__name__)
+@pytest.mark.parametrize("add_indicator", [True, False])
+def test_imputers_feature_names_out_pandas(imputer, add_indicator):
+ """Check feature names out for imputers."""
+ pd = pytest.importorskip("modin.pandas")
+ marker = np.nan
+ imputer = imputer.set_params(add_indicator=add_indicator, missing_values=marker)
+
+ X = np.array(
+ [
+ [marker, 1, 5, 3, marker, 1],
+ [2, marker, 1, 4, marker, 2],
+ [6, 3, 7, marker, marker, 3],
+ [1, 2, 9, 8, marker, 4],
+ ]
+ )
+ X_df = pd.DataFrame(X, columns=["a", "b", "c", "d", "e", "f"])
+ imputer.fit(X_df)
+
+ names = imputer.get_feature_names_out()
+
+ if add_indicator:
+ expected_names = [
+ "a",
+ "b",
+ "c",
+ "d",
+ "f",
+ "missingindicator_a",
+ "missingindicator_b",
+ "missingindicator_d",
+ "missingindicator_e",
+ ]
+ assert_array_equal(expected_names, names)
+ else:
+ expected_names = ["a", "b", "c", "d", "f"]
+ assert_array_equal(expected_names, names)
+
+
+@pytest.mark.parametrize("keep_empty_features", [True, False])
+@pytest.mark.parametrize("imputer", imputers(), ids=lambda x: x.__class__.__name__)
+def test_keep_empty_features(imputer, keep_empty_features):
+ """Check that the imputer keeps features with only missing values."""
+ X = np.array([[np.nan, 1], [np.nan, 2], [np.nan, 3]])
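+    # The first column is entirely missing: it is the "empty feature" that is
+    # either kept (imputed) or dropped depending on keep_empty_features.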
+ imputer = imputer.set_params(
+ add_indicator=False, keep_empty_features=keep_empty_features
+ )
+
+ for method in ["fit_transform", "transform"]:
+ X_imputed = getattr(imputer, method)(X)
+ if keep_empty_features:
+ assert X_imputed.shape == X.shape
+ else:
+ assert X_imputed.shape == (X.shape[0], X.shape[1] - 1)
diff --git a/modin/pandas/test/interoperability/sklearn/impute/test_impute.py b/modin/pandas/test/interoperability/sklearn/impute/test_impute.py
new file mode 100644
index 00000000000..ddeb7a40739
--- /dev/null
+++ b/modin/pandas/test/interoperability/sklearn/impute/test_impute.py
@@ -0,0 +1,1723 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+import pytest
+import warnings
+import numpy as np
+from scipy import sparse
+from scipy.stats import kstest
+import io
+from sklearn.utils._testing import _convert_container
+from sklearn.utils._testing import assert_allclose
+from sklearn.utils._testing import assert_allclose_dense_sparse
+from sklearn.utils._testing import assert_array_equal
+from sklearn.utils._testing import assert_array_almost_equal
+
+# make IterativeImputer available
+from sklearn.experimental import enable_iterative_imputer # noqa
+from sklearn.datasets import load_diabetes
+from sklearn.impute import MissingIndicator
+from sklearn.impute import SimpleImputer, IterativeImputer, KNNImputer
+from sklearn.dummy import DummyRegressor
+from sklearn.linear_model import BayesianRidge, ARDRegression, RidgeCV
+from sklearn.pipeline import Pipeline
+from sklearn.pipeline import make_union
+from sklearn.model_selection import GridSearchCV
+from sklearn import tree
+from sklearn.random_projection import _sparse_random_matrix
+from sklearn.exceptions import ConvergenceWarning
+from sklearn.impute._base import _most_frequent
+
+
+def _assert_array_equal_and_same_dtype(x, y):
+ assert_array_equal(x, y)
+ assert x.dtype == y.dtype
+
+
+def _assert_allclose_and_same_dtype(x, y):
+ assert_allclose(x, y)
+ assert x.dtype == y.dtype
+
+
+def _check_statistics(X, X_true, strategy, statistics, missing_values):
+ """Utility function for testing imputation for a given strategy.
+
+ Test with dense and sparse arrays
+
+ Check that:
+ - the statistics (mean, median, mode) are correct
+ - the missing values are imputed correctly"""
+
+ err_msg = "Parameters: strategy = %s, missing_values = %s, sparse = {0}" % (
+ strategy,
+ missing_values,
+ )
+
+ assert_ae = assert_array_equal
+
+ if X.dtype.kind == "f" or X_true.dtype.kind == "f":
+ assert_ae = assert_array_almost_equal
+
+ # Normal matrix
+ imputer = SimpleImputer(missing_values=missing_values, strategy=strategy)
+ X_trans = imputer.fit(X).transform(X.copy())
+ assert_ae(imputer.statistics_, statistics, err_msg=err_msg.format(False))
+ assert_ae(X_trans, X_true, err_msg=err_msg.format(False))
+
+ # Sparse matrix
+ imputer = SimpleImputer(missing_values=missing_values, strategy=strategy)
+ imputer.fit(sparse.csc_matrix(X))
+ X_trans = imputer.transform(sparse.csc_matrix(X.copy()))
+
+ if sparse.issparse(X_trans):
+ X_trans = X_trans.toarray()
+
+ assert_ae(imputer.statistics_, statistics, err_msg=err_msg.format(True))
+ assert_ae(X_trans, X_true, err_msg=err_msg.format(True))
+
+
+@pytest.mark.parametrize("strategy", ["mean", "median", "most_frequent", "constant"])
+def test_imputation_shape(strategy):
+ # Verify the shapes of the imputed matrix for different strategies.
+ X = np.random.randn(10, 2)
+ X[::2] = np.nan
+
+ imputer = SimpleImputer(strategy=strategy)
+ X_imputed = imputer.fit_transform(sparse.csr_matrix(X))
+ assert X_imputed.shape == (10, 2)
+ X_imputed = imputer.fit_transform(X)
+ assert X_imputed.shape == (10, 2)
+
+ iterative_imputer = IterativeImputer(initial_strategy=strategy)
+ X_imputed = iterative_imputer.fit_transform(X)
+ assert X_imputed.shape == (10, 2)
+
+
+@pytest.mark.parametrize("strategy", ["mean", "median", "most_frequent"])
+def test_imputation_deletion_warning(strategy):
+ X = np.ones((3, 5))
+ X[:, 0] = np.nan
+ imputer = SimpleImputer(strategy=strategy, verbose=1)
+
+ # TODO: Remove in 1.3
+ with pytest.warns(FutureWarning, match="The 'verbose' parameter"):
+ imputer.fit(X)
+
+ with pytest.warns(UserWarning, match="Skipping"):
+ imputer.transform(X)
+
+
+@pytest.mark.parametrize("strategy", ["mean", "median", "most_frequent"])
+def test_imputation_deletion_warning_feature_names(strategy):
+ pd = pytest.importorskip("modin.pandas")
+
+ missing_values = np.nan
+ feature_names = np.array(["a", "b", "c", "d"], dtype=object)
+ X = pd.DataFrame(
+ [
+ [missing_values, missing_values, 1, missing_values],
+ [4, missing_values, 2, 10],
+ ],
+ columns=feature_names,
+ )
+
+ imputer = SimpleImputer(strategy=strategy, verbose=1)
+
+ # TODO: Remove in 1.3
+ with pytest.warns(FutureWarning, match="The 'verbose' parameter"):
+ imputer.fit(X)
+
+ # check SimpleImputer returning feature name attribute correctly
+ assert_array_equal(imputer.feature_names_in_, feature_names)
+
+ # ensure that skipped feature warning includes feature name
+ with pytest.warns(
+ UserWarning, match=r"Skipping features without any observed values: \['b'\]"
+ ):
+ imputer.transform(X)
+
+
+@pytest.mark.parametrize("strategy", ["mean", "median", "most_frequent", "constant"])
+def test_imputation_error_sparse_0(strategy):
+ # check that an error is raised when missing_values=0 and the input is sparse
+ X = np.ones((3, 5))
+ X[0] = 0
+ X = sparse.csc_matrix(X)
+
+ imputer = SimpleImputer(strategy=strategy, missing_values=0)
+ with pytest.raises(ValueError, match="Provide a dense array"):
+ imputer.fit(X)
+
+ imputer.fit(X.toarray())
+ with pytest.raises(ValueError, match="Provide a dense array"):
+ imputer.transform(X)
+
+
+def safe_median(arr, *args, **kwargs):
+ # np.median([]) raises a TypeError for numpy >= 1.10.1
+ length = arr.size if hasattr(arr, "size") else len(arr)
+ return np.nan if length == 0 else np.median(arr, *args, **kwargs)
+
+
+def safe_mean(arr, *args, **kwargs):
+ # np.mean([]) raises a RuntimeWarning for numpy >= 1.10.1
+ length = arr.size if hasattr(arr, "size") else len(arr)
+ return np.nan if length == 0 else np.mean(arr, *args, **kwargs)
+
+
+def test_imputation_mean_median():
+ # Test imputation using the mean and median strategies, when
+ # missing_values != 0.
+ rng = np.random.RandomState(0)
+
+ dim = 10
+ dec = 10
+ shape = (dim * dim, dim + dec)
+
+ zeros = np.zeros(shape[0])
+ values = np.arange(1, shape[0] + 1)
+ values[4::2] = -values[4::2]
+
+ tests = [
+ ("mean", np.nan, lambda z, v, p: safe_mean(np.hstack((z, v)))),
+ ("median", np.nan, lambda z, v, p: safe_median(np.hstack((z, v)))),
+ ]
+
+ for strategy, test_missing_values, true_value_fun in tests:
+ X = np.empty(shape)
+ X_true = np.empty(shape)
+ true_statistics = np.empty(shape[1])
+
+ # Create a matrix X with columns
+ # - with only zeros,
+ # - with only missing values
+ # - with zeros, missing values and values
+ # And a matrix X_true containing all true values
+ for j in range(shape[1]):
+ nb_zeros = (j - dec + 1 > 0) * (j - dec + 1) * (j - dec + 1)
+ nb_missing_values = max(shape[0] + dec * dec - (j + dec) * (j + dec), 0)
+ nb_values = shape[0] - nb_zeros - nb_missing_values
+
+ z = zeros[:nb_zeros]
+ p = np.repeat(test_missing_values, nb_missing_values)
+ v = values[rng.permutation(len(values))[:nb_values]]
+
+ true_statistics[j] = true_value_fun(z, v, p)
+
+ # Create the columns
+ X[:, j] = np.hstack((v, z, p))
+
+ if 0 == test_missing_values:
+ # XXX unreached code as of v0.22
+ X_true[:, j] = np.hstack(
+ (v, np.repeat(true_statistics[j], nb_missing_values + nb_zeros))
+ )
+ else:
+ X_true[:, j] = np.hstack(
+ (v, z, np.repeat(true_statistics[j], nb_missing_values))
+ )
+
+ # Shuffle them the same way
+ np.random.RandomState(j).shuffle(X[:, j])
+ np.random.RandomState(j).shuffle(X_true[:, j])
+
+ # Mean doesn't support columns containing NaNs, median does
+ if strategy == "median":
+ cols_to_keep = ~np.isnan(X_true).any(axis=0)
+ else:
+ cols_to_keep = ~np.isnan(X_true).all(axis=0)
+
+ X_true = X_true[:, cols_to_keep]
+
+ _check_statistics(X, X_true, strategy, true_statistics, test_missing_values)
+
+
+def test_imputation_median_special_cases():
+ # Test median imputation with sparse boundary cases
+ X = np.array(
+ [
+ [0, np.nan, np.nan], # odd: implicit zero
+ [5, np.nan, np.nan], # odd: explicit nonzero
+ [0, 0, np.nan], # even: average two zeros
+ [-5, 0, np.nan], # even: avg zero and neg
+ [0, 5, np.nan], # even: avg zero and pos
+ [4, 5, np.nan], # even: avg nonzeros
+ [-4, -5, np.nan], # even: avg negatives
+ [-1, 2, np.nan], # even: crossing neg and pos
+ ]
+ ).transpose()
+
+ X_imputed_median = np.array(
+ [
+ [0, 0, 0],
+ [5, 5, 5],
+ [0, 0, 0],
+ [-5, 0, -2.5],
+ [0, 5, 2.5],
+ [4, 5, 4.5],
+ [-4, -5, -4.5],
+ [-1, 2, 0.5],
+ ]
+ ).transpose()
+ statistics_median = [0, 5, 0, -2.5, 2.5, 4.5, -4.5, 0.5]
+
+ _check_statistics(X, X_imputed_median, "median", statistics_median, np.nan)
+
+
+@pytest.mark.parametrize("strategy", ["mean", "median"])
+@pytest.mark.parametrize("dtype", [None, object, str])
+def test_imputation_mean_median_error_invalid_type(strategy, dtype):
+ X = np.array([["a", "b", 3], [4, "e", 6], ["g", "h", 9]], dtype=dtype)
+ msg = "non-numeric data:\ncould not convert string to float: '"
+ with pytest.raises(ValueError, match=msg):
+ imputer = SimpleImputer(strategy=strategy)
+ imputer.fit_transform(X)
+
+
+@pytest.mark.parametrize("strategy", ["mean", "median"])
+@pytest.mark.parametrize("type", ["list", "dataframe"])
+def test_imputation_mean_median_error_invalid_type_list_pandas(strategy, type):
+ X = [["a", "b", 3], [4, "e", 6], ["g", "h", 9]]
+ if type == "dataframe":
+ pd = pytest.importorskip("modin.pandas")
+ X = pd.DataFrame(X)
+ msg = "non-numeric data:\ncould not convert string to float: '"
+ with pytest.raises(ValueError, match=msg):
+ imputer = SimpleImputer(strategy=strategy)
+ imputer.fit_transform(X)
+
+
+@pytest.mark.parametrize("strategy", ["constant", "most_frequent"])
+@pytest.mark.parametrize("dtype", [str, np.dtype("U"), np.dtype("S")])
+def test_imputation_const_mostf_error_invalid_types(strategy, dtype):
+ # Test imputation on non-numeric data using "most_frequent" and "constant"
+ # strategy
+ X = np.array(
+ [
+ [np.nan, np.nan, "a", "f"],
+ [np.nan, "c", np.nan, "d"],
+ [np.nan, "b", "d", np.nan],
+ [np.nan, "c", "d", "h"],
+ ],
+ dtype=dtype,
+ )
+
+ err_msg = "SimpleImputer does not support data"
+ with pytest.raises(ValueError, match=err_msg):
+ imputer = SimpleImputer(strategy=strategy)
+ imputer.fit(X).transform(X)
+
+
+def test_imputation_most_frequent():
+ # Test imputation using the most-frequent strategy.
+ X = np.array(
+ [
+ [-1, -1, 0, 5],
+ [-1, 2, -1, 3],
+ [-1, 1, 3, -1],
+ [-1, 2, 3, 7],
+ ]
+ )
+
+ X_true = np.array(
+ [
+ [2, 0, 5],
+ [2, 3, 3],
+ [1, 3, 3],
+ [2, 3, 7],
+ ]
+ )
+
+ # scipy.stats.mode, used in SimpleImputer, doesn't return the first most
+ # frequent value as promised in its documentation but the lowest one. If this
+ # test starts failing after a scipy update, SimpleImputer will need to be
+ # updated to stay consistent with the new (correct) behaviour.
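+ # For instance, in the last column of X the values 5, 3 and 7 each appear
+ # once, so the lowest of them (3) is the expected statistic.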
+ _check_statistics(X, X_true, "most_frequent", [np.nan, 2, 3, 3], -1)
+
+
+@pytest.mark.parametrize("marker", [None, np.nan, "NAN", "", 0])
+def test_imputation_most_frequent_objects(marker):
+ # Test imputation using the most-frequent strategy.
+ X = np.array(
+ [
+ [marker, marker, "a", "f"],
+ [marker, "c", marker, "d"],
+ [marker, "b", "d", marker],
+ [marker, "c", "d", "h"],
+ ],
+ dtype=object,
+ )
+
+ X_true = np.array(
+ [
+ ["c", "a", "f"],
+ ["c", "d", "d"],
+ ["b", "d", "d"],
+ ["c", "d", "h"],
+ ],
+ dtype=object,
+ )
+
+ imputer = SimpleImputer(missing_values=marker, strategy="most_frequent")
+ X_trans = imputer.fit(X).transform(X)
+
+ assert_array_equal(X_trans, X_true)
+
+
+@pytest.mark.skip(reason="Failing test")
+@pytest.mark.parametrize("dtype", [object, "category"])
+def test_imputation_most_frequent_pandas(dtype):
+ # Test imputation using the most frequent strategy on pandas df
+ pd = pytest.importorskip("modin.pandas")
+
+ f = io.StringIO("Cat1,Cat2,Cat3,Cat4\n,i,x,\na,,y,\na,j,,\nb,j,x,")
+
+ df = pd.read_csv(f, dtype=dtype)
+
+ X_true = np.array(
+ [["a", "i", "x"], ["a", "j", "y"], ["a", "j", "x"], ["b", "j", "x"]],
+ dtype=object,
+ )
+
+ imputer = SimpleImputer(strategy="most_frequent")
+ X_trans = imputer.fit_transform(df)
+
+ assert_array_equal(X_trans, X_true)
+
+
+@pytest.mark.parametrize("X_data, missing_value", [(1, 0), (1.0, np.nan)])
+def test_imputation_constant_error_invalid_type(X_data, missing_value):
+ # Verify that exceptions are raised on invalid fill_value type
+ X = np.full((3, 5), X_data, dtype=float)
+ X[0, 0] = missing_value
+
+ with pytest.raises(ValueError, match="imputing numerical"):
+ imputer = SimpleImputer(
+ missing_values=missing_value, strategy="constant", fill_value="x"
+ )
+ imputer.fit_transform(X)
+
+
+def test_imputation_constant_integer():
+ # Test imputation using the constant strategy on integers
+ X = np.array([[-1, 2, 3, -1], [4, -1, 5, -1], [6, 7, -1, -1], [8, 9, 0, -1]])
+
+ X_true = np.array([[0, 2, 3, 0], [4, 0, 5, 0], [6, 7, 0, 0], [8, 9, 0, 0]])
+
+ imputer = SimpleImputer(missing_values=-1, strategy="constant", fill_value=0)
+ X_trans = imputer.fit_transform(X)
+
+ assert_array_equal(X_trans, X_true)
+
+
+@pytest.mark.parametrize("array_constructor", [sparse.csr_matrix, np.asarray])
+def test_imputation_constant_float(array_constructor):
+ # Test imputation using the constant strategy on floats
+ X = np.array(
+ [
+ [np.nan, 1.1, 0, np.nan],
+ [1.2, np.nan, 1.3, np.nan],
+ [0, 0, np.nan, np.nan],
+ [1.4, 1.5, 0, np.nan],
+ ]
+ )
+
+ X_true = np.array(
+ [[-1, 1.1, 0, -1], [1.2, -1, 1.3, -1], [0, 0, -1, -1], [1.4, 1.5, 0, -1]]
+ )
+
+ X = array_constructor(X)
+
+ X_true = array_constructor(X_true)
+
+ imputer = SimpleImputer(strategy="constant", fill_value=-1)
+ X_trans = imputer.fit_transform(X)
+
+ assert_allclose_dense_sparse(X_trans, X_true)
+
+
+@pytest.mark.parametrize("marker", [None, np.nan, "NAN", "", 0])
+def test_imputation_constant_object(marker):
+ # Test imputation using the constant strategy on objects
+ X = np.array(
+ [
+ [marker, "a", "b", marker],
+ ["c", marker, "d", marker],
+ ["e", "f", marker, marker],
+ ["g", "h", "i", marker],
+ ],
+ dtype=object,
+ )
+
+ X_true = np.array(
+ [
+ ["missing", "a", "b", "missing"],
+ ["c", "missing", "d", "missing"],
+ ["e", "f", "missing", "missing"],
+ ["g", "h", "i", "missing"],
+ ],
+ dtype=object,
+ )
+
+ imputer = SimpleImputer(
+ missing_values=marker, strategy="constant", fill_value="missing"
+ )
+ X_trans = imputer.fit_transform(X)
+
+ assert_array_equal(X_trans, X_true)
+
+
+@pytest.mark.skip(reason="Failing test")
+@pytest.mark.parametrize("dtype", [object, "category"])
+def test_imputation_constant_pandas(dtype):
+ # Test imputation using the constant strategy on pandas df
+ pd = pytest.importorskip("modin.pandas")
+
+ f = io.StringIO("Cat1,Cat2,Cat3,Cat4\n,i,x,\na,,y,\na,j,,\nb,j,x,")
+
+ df = pd.read_csv(f, dtype=dtype)
+
+ X_true = np.array(
+ [
+ ["missing_value", "i", "x", "missing_value"],
+ ["a", "missing_value", "y", "missing_value"],
+ ["a", "j", "missing_value", "missing_value"],
+ ["b", "j", "x", "missing_value"],
+ ],
+ dtype=object,
+ )
+
+ imputer = SimpleImputer(strategy="constant")
+ X_trans = imputer.fit_transform(df)
+
+ assert_array_equal(X_trans, X_true)
+
+
+@pytest.mark.parametrize("X", [[[1], [2]], [[1], [np.nan]]])
+def test_iterative_imputer_one_feature(X):
+ # check we exit early when there is a single feature
+ imputer = IterativeImputer().fit(X)
+ assert imputer.n_iter_ == 0
+ imputer = IterativeImputer()
+ imputer.fit([[1], [2]])
+ assert imputer.n_iter_ == 0
+ imputer.fit([[1], [np.nan]])
+ assert imputer.n_iter_ == 0
+
+
+def test_imputation_pipeline_grid_search():
+ # Test imputation within a pipeline + gridsearch.
+ X = _sparse_random_matrix(100, 100, density=0.10)
+ missing_values = X.data[0]
+
+ pipeline = Pipeline(
+ [
+ ("imputer", SimpleImputer(missing_values=missing_values)),
+ ("tree", tree.DecisionTreeRegressor(random_state=0)),
+ ]
+ )
+
+ parameters = {"imputer__strategy": ["mean", "median", "most_frequent"]}
+
+ Y = _sparse_random_matrix(100, 1, density=0.10).toarray()
+ gs = GridSearchCV(pipeline, parameters)
+ gs.fit(X, Y)
+
+
+def test_imputation_copy():
+ # Test imputation with copy
+ X_orig = _sparse_random_matrix(5, 5, density=0.75, random_state=0)
+
+ # copy=True, dense => copy
+ X = X_orig.copy().toarray()
+ imputer = SimpleImputer(missing_values=0, strategy="mean", copy=True)
+ Xt = imputer.fit(X).transform(X)
+ Xt[0, 0] = -1
+ assert not np.all(X == Xt)
+
+ # copy=True, sparse csr => copy
+ X = X_orig.copy()
+ imputer = SimpleImputer(missing_values=X.data[0], strategy="mean", copy=True)
+ Xt = imputer.fit(X).transform(X)
+ Xt.data[0] = -1
+ assert not np.all(X.data == Xt.data)
+
+ # copy=False, dense => no copy
+ X = X_orig.copy().toarray()
+ imputer = SimpleImputer(missing_values=0, strategy="mean", copy=False)
+ Xt = imputer.fit(X).transform(X)
+ Xt[0, 0] = -1
+ assert_array_almost_equal(X, Xt)
+
+ # copy=False, sparse csc => no copy
+ X = X_orig.copy().tocsc()
+ imputer = SimpleImputer(missing_values=X.data[0], strategy="mean", copy=False)
+ Xt = imputer.fit(X).transform(X)
+ Xt.data[0] = -1
+ assert_array_almost_equal(X.data, Xt.data)
+
+ # copy=False, sparse csr => copy
+ X = X_orig.copy()
+ imputer = SimpleImputer(missing_values=X.data[0], strategy="mean", copy=False)
+ Xt = imputer.fit(X).transform(X)
+ Xt.data[0] = -1
+ assert not np.all(X.data == Xt.data)
+
+ # Note: If X is sparse and if missing_values=0, then a (dense) copy of X is
+ # made, even if copy=False.
+
+
+def test_iterative_imputer_zero_iters():
+ rng = np.random.RandomState(0)
+
+ n = 100
+ d = 10
+ X = _sparse_random_matrix(n, d, density=0.10, random_state=rng).toarray()
+ missing_flag = X == 0
+ X[missing_flag] = np.nan
+
+ imputer = IterativeImputer(max_iter=0)
+ X_imputed = imputer.fit_transform(X)
+ # with max_iter=0, only initial imputation is performed
+ assert_allclose(X_imputed, imputer.initial_imputer_.transform(X))
+
+ # repeat but force n_iter_ to 0
+ imputer = IterativeImputer(max_iter=5).fit(X)
+ # transformed should not be equal to initial imputation
+ assert not np.all(imputer.transform(X) == imputer.initial_imputer_.transform(X))
+
+ imputer.n_iter_ = 0
+ # now they should be equal as only initial imputation is done
+ assert_allclose(imputer.transform(X), imputer.initial_imputer_.transform(X))
+
+
+def test_iterative_imputer_verbose():
+ rng = np.random.RandomState(0)
+
+ n = 100
+ d = 3
+ X = _sparse_random_matrix(n, d, density=0.10, random_state=rng).toarray()
+ imputer = IterativeImputer(missing_values=0, max_iter=1, verbose=1)
+ imputer.fit(X)
+ imputer.transform(X)
+ imputer = IterativeImputer(missing_values=0, max_iter=1, verbose=2)
+ imputer.fit(X)
+ imputer.transform(X)
+
+
+def test_iterative_imputer_all_missing():
+ n = 100
+ d = 3
+ X = np.zeros((n, d))
+ imputer = IterativeImputer(missing_values=0, max_iter=1)
+ X_imputed = imputer.fit_transform(X)
+ assert_allclose(X_imputed, imputer.initial_imputer_.transform(X))
+
+
+@pytest.mark.parametrize(
+ "imputation_order", ["random", "roman", "ascending", "descending", "arabic"]
+)
+def test_iterative_imputer_imputation_order(imputation_order):
+ rng = np.random.RandomState(0)
+ n = 100
+ d = 10
+ max_iter = 2
+ X = _sparse_random_matrix(n, d, density=0.10, random_state=rng).toarray()
+ X[:, 0] = 1 # this column should not be discarded by IterativeImputer
+
+ imputer = IterativeImputer(
+ missing_values=0,
+ max_iter=max_iter,
+ n_nearest_features=5,
+ sample_posterior=False,
+ skip_complete=True,
+ min_value=0,
+ max_value=1,
+ verbose=1,
+ imputation_order=imputation_order,
+ random_state=rng,
+ )
+ imputer.fit_transform(X)
+ ordered_idx = [i.feat_idx for i in imputer.imputation_sequence_]
+
+ assert len(ordered_idx) // imputer.n_iter_ == imputer.n_features_with_missing_
+
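+ # "roman" imputes features left to right and "arabic" right to left; column 0
+ # never appears because it has no missing values and skip_complete=True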
+ if imputation_order == "roman":
+ assert np.all(ordered_idx[: d - 1] == np.arange(1, d))
+ elif imputation_order == "arabic":
+ assert np.all(ordered_idx[: d - 1] == np.arange(d - 1, 0, -1))
+ elif imputation_order == "random":
+ ordered_idx_round_1 = ordered_idx[: d - 1]
+ ordered_idx_round_2 = ordered_idx[d - 1 :]
+ assert ordered_idx_round_1 != ordered_idx_round_2
+ elif "ending" in imputation_order:
+ assert len(ordered_idx) == max_iter * (d - 1)
+
+
+@pytest.mark.parametrize(
+ "estimator", [None, DummyRegressor(), BayesianRidge(), ARDRegression(), RidgeCV()]
+)
+def test_iterative_imputer_estimators(estimator):
+ rng = np.random.RandomState(0)
+
+ n = 100
+ d = 10
+ X = _sparse_random_matrix(n, d, density=0.10, random_state=rng).toarray()
+
+ imputer = IterativeImputer(
+ missing_values=0, max_iter=1, estimator=estimator, random_state=rng
+ )
+ imputer.fit_transform(X)
+
+ # check that types are correct for estimators
+ hashes = []
+ for triplet in imputer.imputation_sequence_:
+ expected_type = (
+ type(estimator) if estimator is not None else type(BayesianRidge())
+ )
+ assert isinstance(triplet.estimator, expected_type)
+ hashes.append(id(triplet.estimator))
+
+ # check that each estimator is unique
+ assert len(set(hashes)) == len(hashes)
+
+
+def test_iterative_imputer_clip():
+ rng = np.random.RandomState(0)
+ n = 100
+ d = 10
+ X = _sparse_random_matrix(n, d, density=0.10, random_state=rng).toarray()
+
+ imputer = IterativeImputer(
+ missing_values=0, max_iter=1, min_value=0.1, max_value=0.2, random_state=rng
+ )
+
+ Xt = imputer.fit_transform(X)
+ assert_allclose(np.min(Xt[X == 0]), 0.1)
+ assert_allclose(np.max(Xt[X == 0]), 0.2)
+ assert_allclose(Xt[X != 0], X[X != 0])
+
+
+def test_iterative_imputer_clip_truncnorm():
+ rng = np.random.RandomState(0)
+ n = 100
+ d = 10
+ X = _sparse_random_matrix(n, d, density=0.10, random_state=rng).toarray()
+ X[:, 0] = 1
+
+ imputer = IterativeImputer(
+ missing_values=0,
+ max_iter=2,
+ n_nearest_features=5,
+ sample_posterior=True,
+ min_value=0.1,
+ max_value=0.2,
+ verbose=1,
+ imputation_order="random",
+ random_state=rng,
+ )
+ Xt = imputer.fit_transform(X)
+ assert_allclose(np.min(Xt[X == 0]), 0.1)
+ assert_allclose(np.max(Xt[X == 0]), 0.2)
+ assert_allclose(Xt[X != 0], X[X != 0])
+
+
+def test_iterative_imputer_truncated_normal_posterior():
+ # Test that the values imputed with `sample_posterior=True` and boundaries
+ # (`min_value` and `max_value` not None) are drawn from a distribution that
+ # looks Gaussian according to a Kolmogorov-Smirnov test.
+ # Note that starting from the wrong random seed will make this test fail,
+ # because no random sampling occurs at all when the imputation target
+ # falls outside of the (min_value, max_value) range.
+ rng = np.random.RandomState(42)
+
+ X = rng.normal(size=(5, 5))
+ X[0][0] = np.nan
+
+ imputer = IterativeImputer(
+ min_value=0, max_value=0.5, sample_posterior=True, random_state=rng
+ )
+
+ imputer.fit_transform(X)
+ # generate multiple imputations for the single missing value
+ imputations = np.array([imputer.transform(X)[0][0] for _ in range(100)])
+
+ assert all(imputations >= 0)
+ assert all(imputations <= 0.5)
+
+ mu, sigma = imputations.mean(), imputations.std()
+ ks_statistic, p_value = kstest((imputations - mu) / sigma, "norm")
+ if sigma == 0:
+ sigma += 1e-12
+ ks_statistic, p_value = kstest((imputations - mu) / sigma, "norm")
+ # we want to fail to reject null hypothesis
+ # null hypothesis: distributions are the same
+ assert ks_statistic < 0.2 or p_value > 0.1, "The posterior does not appear to be normal"
+
+
+@pytest.mark.parametrize("strategy", ["mean", "median", "most_frequent"])
+def test_iterative_imputer_missing_at_transform(strategy):
+ rng = np.random.RandomState(0)
+ n = 100
+ d = 10
+ X_train = rng.randint(low=0, high=3, size=(n, d))
+ X_test = rng.randint(low=0, high=3, size=(n, d))
+
+ X_train[:, 0] = 1 # definitely no missing values in 0th column
+ X_test[0, 0] = 0 # definitely missing value in 0th column
+
+ imputer = IterativeImputer(
+ missing_values=0, max_iter=1, initial_strategy=strategy, random_state=rng
+ ).fit(X_train)
+ initial_imputer = SimpleImputer(missing_values=0, strategy=strategy).fit(X_train)
+
+ # if there were no missing values at time of fit, then imputer will
+ # only use the initial imputer for that feature at transform
+ assert_allclose(
+ imputer.transform(X_test)[:, 0], initial_imputer.transform(X_test)[:, 0]
+ )
+
+
+def test_iterative_imputer_transform_stochasticity():
+ rng1 = np.random.RandomState(0)
+ rng2 = np.random.RandomState(1)
+ n = 100
+ d = 10
+ X = _sparse_random_matrix(n, d, density=0.10, random_state=rng1).toarray()
+
+ # when sample_posterior=True, two transforms shouldn't be equal
+ imputer = IterativeImputer(
+ missing_values=0, max_iter=1, sample_posterior=True, random_state=rng1
+ )
+ imputer.fit(X)
+
+ X_fitted_1 = imputer.transform(X)
+ X_fitted_2 = imputer.transform(X)
+
+ # sufficient to assert that the means are not the same
+ assert np.mean(X_fitted_1) != pytest.approx(np.mean(X_fitted_2))
+
+ # when sample_posterior=False, and n_nearest_features=None
+ # and imputation_order is not random
+ # the two transforms should be identical even if rng are different
+ imputer1 = IterativeImputer(
+ missing_values=0,
+ max_iter=1,
+ sample_posterior=False,
+ n_nearest_features=None,
+ imputation_order="ascending",
+ random_state=rng1,
+ )
+
+ imputer2 = IterativeImputer(
+ missing_values=0,
+ max_iter=1,
+ sample_posterior=False,
+ n_nearest_features=None,
+ imputation_order="ascending",
+ random_state=rng2,
+ )
+ imputer1.fit(X)
+ imputer2.fit(X)
+
+ X_fitted_1a = imputer1.transform(X)
+ X_fitted_1b = imputer1.transform(X)
+ X_fitted_2 = imputer2.transform(X)
+
+ assert_allclose(X_fitted_1a, X_fitted_1b)
+ assert_allclose(X_fitted_1a, X_fitted_2)
+
+
+def test_iterative_imputer_no_missing():
+ rng = np.random.RandomState(0)
+ X = rng.rand(100, 100)
+ X[:, 0] = np.nan
+ m1 = IterativeImputer(max_iter=10, random_state=rng)
+ m2 = IterativeImputer(max_iter=10, random_state=rng)
+ pred1 = m1.fit(X).transform(X)
+ pred2 = m2.fit_transform(X)
+ # should exclude the first column entirely
+ assert_allclose(X[:, 1:], pred1)
+ # fit().transform() and fit_transform() should give identical results
+ assert_allclose(pred1, pred2)
+
+
+def test_iterative_imputer_rank_one():
+ rng = np.random.RandomState(0)
+ d = 50
+ A = rng.rand(d, 1)
+ B = rng.rand(1, d)
+ X = np.dot(A, B)
+ nan_mask = rng.rand(d, d) < 0.5
+ X_missing = X.copy()
+ X_missing[nan_mask] = np.nan
+
+ imputer = IterativeImputer(max_iter=5, verbose=1, random_state=rng)
+ X_filled = imputer.fit_transform(X_missing)
+ assert_allclose(X_filled, X, atol=0.02)
+
+
+@pytest.mark.parametrize("rank", [3, 5])
+def test_iterative_imputer_transform_recovery(rank):
+ rng = np.random.RandomState(0)
+ n = 70
+ d = 70
+ A = rng.rand(n, rank)
+ B = rng.rand(rank, d)
+ X_filled = np.dot(A, B)
+ nan_mask = rng.rand(n, d) < 0.5
+ X_missing = X_filled.copy()
+ X_missing[nan_mask] = np.nan
+
+ # split up data in half
+ n = n // 2
+ X_train = X_missing[:n]
+ X_test_filled = X_filled[n:]
+ X_test = X_missing[n:]
+
+ imputer = IterativeImputer(
+ max_iter=5, imputation_order="descending", verbose=1, random_state=rng
+ ).fit(X_train)
+ X_test_est = imputer.transform(X_test)
+ assert_allclose(X_test_filled, X_test_est, atol=0.1)
+
+
+def test_iterative_imputer_additive_matrix():
+ rng = np.random.RandomState(0)
+ n = 100
+ d = 10
+ A = rng.randn(n, d)
+ B = rng.randn(n, d)
+ X_filled = np.zeros(A.shape)
+ for i in range(d):
+ for j in range(d):
+ X_filled[:, (i + j) % d] += (A[:, i] + B[:, j]) / 2
+ # a quarter is randomly missing
+ nan_mask = rng.rand(n, d) < 0.25
+ X_missing = X_filled.copy()
+ X_missing[nan_mask] = np.nan
+
+ # split up data
+ n = n // 2
+ X_train = X_missing[:n]
+ X_test_filled = X_filled[n:]
+ X_test = X_missing[n:]
+
+ imputer = IterativeImputer(max_iter=10, verbose=1, random_state=rng).fit(X_train)
+ X_test_est = imputer.transform(X_test)
+ assert_allclose(X_test_filled, X_test_est, rtol=1e-3, atol=0.01)
+
+
+def test_iterative_imputer_early_stopping():
+ rng = np.random.RandomState(0)
+ n = 50
+ d = 5
+ A = rng.rand(n, 1)
+ B = rng.rand(1, d)
+ X = np.dot(A, B)
+ nan_mask = rng.rand(n, d) < 0.5
+ X_missing = X.copy()
+ X_missing[nan_mask] = np.nan
+
+ imputer = IterativeImputer(
+ max_iter=100, tol=1e-2, sample_posterior=False, verbose=1, random_state=rng
+ )
+ X_filled_100 = imputer.fit_transform(X_missing)
+ assert len(imputer.imputation_sequence_) == d * imputer.n_iter_
+
+ imputer = IterativeImputer(
+ max_iter=imputer.n_iter_, sample_posterior=False, verbose=1, random_state=rng
+ )
+ X_filled_early = imputer.fit_transform(X_missing)
+ assert_allclose(X_filled_100, X_filled_early, atol=1e-7)
+
+ imputer = IterativeImputer(
+ max_iter=100, tol=0, sample_posterior=False, verbose=1, random_state=rng
+ )
+ imputer.fit(X_missing)
+ assert imputer.n_iter_ == imputer.max_iter
+
+
+def test_iterative_imputer_catch_warning():
+ # check that we catch a RuntimeWarning due to a division by zero when a
+ # feature is constant in the dataset
+ X, y = load_diabetes(return_X_y=True)
+ n_samples, n_features = X.shape
+
+ # simulate a feature that is constant (contains a single value) during fit
+ X[:, 3] = 1
+
+ # add some missing values
+ rng = np.random.RandomState(0)
+ missing_rate = 0.15
+ for feat in range(n_features):
+ sample_idx = rng.choice(
+ np.arange(n_samples), size=int(n_samples * missing_rate), replace=False
+ )
+ X[sample_idx, feat] = np.nan
+
+ imputer = IterativeImputer(n_nearest_features=5, sample_posterior=True)
+ with warnings.catch_warnings():
+ warnings.simplefilter("error", RuntimeWarning)
+ X_fill = imputer.fit_transform(X, y)
+ assert not np.any(np.isnan(X_fill))
+
+
+@pytest.mark.parametrize(
+ "min_value, max_value, correct_output",
+ [
+ (0, 100, np.array([[0] * 3, [100] * 3])),
+ (None, None, np.array([[-np.inf] * 3, [np.inf] * 3])),
+ (-np.inf, np.inf, np.array([[-np.inf] * 3, [np.inf] * 3])),
+ ([-5, 5, 10], [100, 200, 300], np.array([[-5, 5, 10], [100, 200, 300]])),
+ (
+ [-5, -np.inf, 10],
+ [100, 200, np.inf],
+ np.array([[-5, -np.inf, 10], [100, 200, np.inf]]),
+ ),
+ ],
+ ids=["scalars", "None-default", "inf", "lists", "lists-with-inf"],
+)
+def test_iterative_imputer_min_max_array_like(min_value, max_value, correct_output):
+ # check that passing scalar or array-like
+ # for min_value and max_value in IterativeImputer works
+ X = np.random.RandomState(0).randn(10, 3)
+ imputer = IterativeImputer(min_value=min_value, max_value=max_value)
+ imputer.fit(X)
+
+ assert isinstance(imputer._min_value, np.ndarray) and isinstance(
+ imputer._max_value, np.ndarray
+ )
+ assert (imputer._min_value.shape[0] == X.shape[1]) and (
+ imputer._max_value.shape[0] == X.shape[1]
+ )
+
+ assert_allclose(correct_output[0, :], imputer._min_value)
+ assert_allclose(correct_output[1, :], imputer._max_value)
+
+
+@pytest.mark.parametrize(
+ "min_value, max_value, err_msg",
+ [
+ (100, 0, "min_value >= max_value."),
+ (np.inf, -np.inf, "min_value >= max_value."),
+ ([-5, 5], [100, 200, 0], "_value' should be of shape"),
+ ],
+)
+def test_iterative_imputer_catch_min_max_error(min_value, max_value, err_msg):
+ # check that an informative error is raised when min_value/max_value are
+ # inconsistent (min >= max) or have the wrong shape
+ X = np.random.random((10, 3))
+ imputer = IterativeImputer(min_value=min_value, max_value=max_value)
+ with pytest.raises(ValueError, match=err_msg):
+ imputer.fit(X)
+
+
+@pytest.mark.parametrize(
+ "min_max_1, min_max_2",
+ [([None, None], [-np.inf, np.inf]), ([-10, 10], [[-10] * 4, [10] * 4])],
+ ids=["None-vs-inf", "Scalar-vs-vector"],
+)
+def test_iterative_imputer_min_max_array_like_imputation(min_max_1, min_max_2):
+ # Test that None/inf and scalar/vector give the same imputation
+ X_train = np.array(
+ [
+ [np.nan, 2, 2, 1],
+ [10, np.nan, np.nan, 7],
+ [3, 1, np.nan, 1],
+ [np.nan, 4, 2, np.nan],
+ ]
+ )
+ X_test = np.array(
+ [[np.nan, 2, np.nan, 5], [2, 4, np.nan, np.nan], [np.nan, 1, 10, 1]]
+ )
+ imputer1 = IterativeImputer(
+ min_value=min_max_1[0], max_value=min_max_1[1], random_state=0
+ )
+ imputer2 = IterativeImputer(
+ min_value=min_max_2[0], max_value=min_max_2[1], random_state=0
+ )
+ X_test_imputed1 = imputer1.fit(X_train).transform(X_test)
+ X_test_imputed2 = imputer2.fit(X_train).transform(X_test)
+ assert_allclose(X_test_imputed1[:, 0], X_test_imputed2[:, 0])
+
+
+@pytest.mark.parametrize("skip_complete", [True, False])
+def test_iterative_imputer_skip_non_missing(skip_complete):
+ # check the imputing strategy when missing data are present in the
+ # testing set only.
+ # taken from: https://github.com/scikit-learn/scikit-learn/issues/14383
+ rng = np.random.RandomState(0)
+ X_train = np.array([[5, 2, 2, 1], [10, 1, 2, 7], [3, 1, 1, 1], [8, 4, 2, 2]])
+ X_test = np.array([[np.nan, 2, 4, 5], [np.nan, 4, 1, 2], [np.nan, 1, 10, 1]])
+ imputer = IterativeImputer(
+ initial_strategy="mean", skip_complete=skip_complete, random_state=rng
+ )
+ X_test_est = imputer.fit(X_train).transform(X_test)
+ if skip_complete:
+ # impute with the initial strategy: 'mean'
+ assert_allclose(X_test_est[:, 0], np.mean(X_train[:, 0]))
+ else:
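+ # the complete column is still modelled, so its missing values are
+ # predicted from the other features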
+ assert_allclose(X_test_est[:, 0], [11, 7, 12], rtol=1e-4)
+
+
+@pytest.mark.parametrize("rs_imputer", [None, 1, np.random.RandomState(seed=1)])
+@pytest.mark.parametrize("rs_estimator", [None, 1, np.random.RandomState(seed=1)])
+def test_iterative_imputer_dont_set_random_state(rs_imputer, rs_estimator):
+ class ZeroEstimator:
+ def __init__(self, random_state):
+ self.random_state = random_state
+
+ def fit(self, *args, **kwargs):
+ return self
+
+ def predict(self, X):
+ return np.zeros(X.shape[0])
+
+ estimator = ZeroEstimator(random_state=rs_estimator)
+ imputer = IterativeImputer(random_state=rs_imputer)
+ X_train = np.zeros((10, 3))
+ imputer.fit(X_train)
+ assert estimator.random_state == rs_estimator
+
+
+@pytest.mark.parametrize(
+ "X_fit, X_trans, params, msg_err",
+ [
+ (
+ np.array([[-1, 1], [1, 2]]),
+ np.array([[-1, 1], [1, -1]]),
+ {"features": "missing-only", "sparse": "auto"},
+ "have missing values in transform but have no missing values in fit",
+ ),
+ (
+ np.array([["a", "b"], ["c", "a"]], dtype=str),
+ np.array([["a", "b"], ["c", "a"]], dtype=str),
+ {},
+ "MissingIndicator does not support data with dtype",
+ ),
+ ],
+)
+def test_missing_indicator_error(X_fit, X_trans, params, msg_err):
+ indicator = MissingIndicator(missing_values=-1)
+ indicator.set_params(**params)
+ with pytest.raises(ValueError, match=msg_err):
+ indicator.fit(X_fit).transform(X_trans)
+
+
+@pytest.mark.parametrize(
+ "missing_values, dtype, arr_type",
+ [
+ (np.nan, np.float64, np.array),
+ (0, np.int32, np.array),
+ (-1, np.int32, np.array),
+ (np.nan, np.float64, sparse.csc_matrix),
+ (-1, np.int32, sparse.csc_matrix),
+ (np.nan, np.float64, sparse.csr_matrix),
+ (-1, np.int32, sparse.csr_matrix),
+ (np.nan, np.float64, sparse.coo_matrix),
+ (-1, np.int32, sparse.coo_matrix),
+ (np.nan, np.float64, sparse.lil_matrix),
+ (-1, np.int32, sparse.lil_matrix),
+ (np.nan, np.float64, sparse.bsr_matrix),
+ (-1, np.int32, sparse.bsr_matrix),
+ ],
+)
+@pytest.mark.parametrize(
+ "param_features, n_features, features_indices",
+ [("missing-only", 3, np.array([0, 1, 2])), ("all", 3, np.array([0, 1, 2]))],
+)
+def test_missing_indicator_new(
+ missing_values, arr_type, dtype, param_features, n_features, features_indices
+):
+ X_fit = np.array([[missing_values, missing_values, 1], [4, 2, missing_values]])
+ X_trans = np.array([[missing_values, missing_values, 1], [4, 12, 10]])
+ X_fit_expected = np.array([[1, 1, 0], [0, 0, 1]])
+ X_trans_expected = np.array([[1, 1, 0], [0, 0, 0]])
+
+ # convert the input to the right array format and right dtype
+ X_fit = arr_type(X_fit).astype(dtype)
+ X_trans = arr_type(X_trans).astype(dtype)
+ X_fit_expected = X_fit_expected.astype(dtype)
+ X_trans_expected = X_trans_expected.astype(dtype)
+
+ indicator = MissingIndicator(
+ missing_values=missing_values, features=param_features, sparse=False
+ )
+ X_fit_mask = indicator.fit_transform(X_fit)
+ X_trans_mask = indicator.transform(X_trans)
+
+ assert X_fit_mask.shape[1] == n_features
+ assert X_trans_mask.shape[1] == n_features
+
+ assert_array_equal(indicator.features_, features_indices)
+ assert_allclose(X_fit_mask, X_fit_expected[:, features_indices])
+ assert_allclose(X_trans_mask, X_trans_expected[:, features_indices])
+
+ assert X_fit_mask.dtype == bool
+ assert X_trans_mask.dtype == bool
+ assert isinstance(X_fit_mask, np.ndarray)
+ assert isinstance(X_trans_mask, np.ndarray)
+
+ indicator.set_params(sparse=True)
+ X_fit_mask_sparse = indicator.fit_transform(X_fit)
+ X_trans_mask_sparse = indicator.transform(X_trans)
+
+ assert X_fit_mask_sparse.dtype == bool
+ assert X_trans_mask_sparse.dtype == bool
+ assert X_fit_mask_sparse.format == "csc"
+ assert X_trans_mask_sparse.format == "csc"
+ assert_allclose(X_fit_mask_sparse.toarray(), X_fit_mask)
+ assert_allclose(X_trans_mask_sparse.toarray(), X_trans_mask)
+
+
+@pytest.mark.parametrize(
+ "arr_type",
+ [
+ sparse.csc_matrix,
+ sparse.csr_matrix,
+ sparse.coo_matrix,
+ sparse.lil_matrix,
+ sparse.bsr_matrix,
+ ],
+)
+def test_missing_indicator_raise_on_sparse_with_missing_0(arr_type):
+ # test for sparse input and missing_value == 0
+
+ missing_values = 0
+ X_fit = np.array([[missing_values, missing_values, 1], [4, missing_values, 2]])
+ X_trans = np.array([[missing_values, missing_values, 1], [4, 12, 10]])
+
+ # convert the input to the right array format
+ X_fit_sparse = arr_type(X_fit)
+ X_trans_sparse = arr_type(X_trans)
+
+ indicator = MissingIndicator(missing_values=missing_values)
+
+ with pytest.raises(ValueError, match="Sparse input with missing_values=0"):
+ indicator.fit_transform(X_fit_sparse)
+
+ indicator.fit_transform(X_fit)
+ with pytest.raises(ValueError, match="Sparse input with missing_values=0"):
+ indicator.transform(X_trans_sparse)
+
+
+@pytest.mark.parametrize("param_sparse", [True, False, "auto"])
+@pytest.mark.parametrize(
+ "missing_values, arr_type",
+ [
+ (np.nan, np.array),
+ (0, np.array),
+ (np.nan, sparse.csc_matrix),
+ (np.nan, sparse.csr_matrix),
+ (np.nan, sparse.coo_matrix),
+ (np.nan, sparse.lil_matrix),
+ ],
+)
+def test_missing_indicator_sparse_param(arr_type, missing_values, param_sparse):
+ # check the format of the output with different sparse parameter
+ X_fit = np.array([[missing_values, missing_values, 1], [4, missing_values, 2]])
+ X_trans = np.array([[missing_values, missing_values, 1], [4, 12, 10]])
+ X_fit = arr_type(X_fit).astype(np.float64)
+ X_trans = arr_type(X_trans).astype(np.float64)
+
+ indicator = MissingIndicator(missing_values=missing_values, sparse=param_sparse)
+ X_fit_mask = indicator.fit_transform(X_fit)
+ X_trans_mask = indicator.transform(X_trans)
+
+ if param_sparse is True:
+ assert X_fit_mask.format == "csc"
+ assert X_trans_mask.format == "csc"
+ elif param_sparse == "auto" and missing_values == 0:
+ assert isinstance(X_fit_mask, np.ndarray)
+ assert isinstance(X_trans_mask, np.ndarray)
+ elif param_sparse is False:
+ assert isinstance(X_fit_mask, np.ndarray)
+ assert isinstance(X_trans_mask, np.ndarray)
+ else:
+ if sparse.issparse(X_fit):
+ assert X_fit_mask.format == "csc"
+ assert X_trans_mask.format == "csc"
+ else:
+ assert isinstance(X_fit_mask, np.ndarray)
+ assert isinstance(X_trans_mask, np.ndarray)
+
+
+def test_missing_indicator_string():
+ X = np.array([["a", "b", "c"], ["b", "c", "a"]], dtype=object)
+ indicator = MissingIndicator(missing_values="a", features="all")
+ X_trans = indicator.fit_transform(X)
+ assert_array_equal(X_trans, np.array([[True, False, False], [False, False, True]]))
+
+
+@pytest.mark.parametrize(
+ "X, missing_values, X_trans_exp",
+ [
+ (
+ np.array([["a", "b"], ["b", "a"]], dtype=object),
+ "a",
+ np.array([["b", "b", True, False], ["b", "b", False, True]], dtype=object),
+ ),
+ (
+ np.array([[np.nan, 1.0], [1.0, np.nan]]),
+ np.nan,
+ np.array([[1.0, 1.0, True, False], [1.0, 1.0, False, True]]),
+ ),
+ (
+ np.array([[np.nan, "b"], ["b", np.nan]], dtype=object),
+ np.nan,
+ np.array([["b", "b", True, False], ["b", "b", False, True]], dtype=object),
+ ),
+ (
+ np.array([[None, "b"], ["b", None]], dtype=object),
+ None,
+ np.array([["b", "b", True, False], ["b", "b", False, True]], dtype=object),
+ ),
+ ],
+)
+def test_missing_indicator_with_imputer(X, missing_values, X_trans_exp):
+ trans = make_union(
+ SimpleImputer(missing_values=missing_values, strategy="most_frequent"),
+ MissingIndicator(missing_values=missing_values),
+ )
+ X_trans = trans.fit_transform(X)
+ assert_array_equal(X_trans, X_trans_exp)
+
+
+@pytest.mark.parametrize("imputer_constructor", [SimpleImputer, IterativeImputer])
+@pytest.mark.parametrize(
+ "imputer_missing_values, missing_value, err_msg",
+ [
+ ("NaN", np.nan, "Input X contains NaN"),
+ ("-1", -1, "types are expected to be both numerical."),
+ ],
+)
+def test_inconsistent_dtype_X_missing_values(
+ imputer_constructor, imputer_missing_values, missing_value, err_msg
+):
+ # regression test for issue #11390. Comparison between incoherent dtype
+ # for X and missing_values was not raising a proper error.
+ rng = np.random.RandomState(42)
+ X = rng.randn(10, 10)
+ X[0, 0] = missing_value
+
+ imputer = imputer_constructor(missing_values=imputer_missing_values)
+
+ with pytest.raises(ValueError, match=err_msg):
+ imputer.fit_transform(X)
+
+
+def test_missing_indicator_no_missing():
+ # check that all features are dropped if there are no missing values when
+ # features='missing-only' (#13491)
+ X = np.array([[1, 1], [1, 1]])
+
+ mi = MissingIndicator(features="missing-only", missing_values=-1)
+ Xt = mi.fit_transform(X)
+
+ assert Xt.shape[1] == 0
+
+
+def test_missing_indicator_sparse_no_explicit_zeros():
+ # Check that non missing values don't become explicit zeros in the mask
+ # generated by missing indicator when X is sparse. (#13491)
+ X = sparse.csr_matrix([[0, 1, 2], [1, 2, 0], [2, 0, 1]])
+
+ mi = MissingIndicator(features="all", missing_values=1)
+ Xt = mi.fit_transform(X)
+
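+ # if any non-missing entry were stored as an explicit False/zero, the number
+ # of stored values (nnz) would exceed the number of True entries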
+ assert Xt.getnnz() == Xt.sum()
+
+
+@pytest.mark.parametrize("imputer_constructor", [SimpleImputer, IterativeImputer])
+def test_imputer_without_indicator(imputer_constructor):
+ X = np.array([[1, 1], [1, 1]])
+ imputer = imputer_constructor()
+ imputer.fit(X)
+
+ assert imputer.indicator_ is None
+
+
+@pytest.mark.parametrize(
+ "arr_type",
+ [
+ sparse.csc_matrix,
+ sparse.csr_matrix,
+ sparse.coo_matrix,
+ sparse.lil_matrix,
+ sparse.bsr_matrix,
+ ],
+)
+def test_simple_imputation_add_indicator_sparse_matrix(arr_type):
+ X_sparse = arr_type([[np.nan, 1, 5], [2, np.nan, 1], [6, 3, np.nan], [1, 2, 9]])
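+ # expected output: the first three columns hold the mean-imputed values,
+ # the last three the missing-value indicator mask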
+ X_true = np.array(
+ [
+ [3.0, 1.0, 5.0, 1.0, 0.0, 0.0],
+ [2.0, 2.0, 1.0, 0.0, 1.0, 0.0],
+ [6.0, 3.0, 5.0, 0.0, 0.0, 1.0],
+ [1.0, 2.0, 9.0, 0.0, 0.0, 0.0],
+ ]
+ )
+
+ imputer = SimpleImputer(missing_values=np.nan, add_indicator=True)
+ X_trans = imputer.fit_transform(X_sparse)
+
+ assert sparse.issparse(X_trans)
+ assert X_trans.shape == X_true.shape
+ assert_allclose(X_trans.toarray(), X_true)
+
+
+@pytest.mark.parametrize(
+ "strategy, expected", [("most_frequent", "b"), ("constant", "missing_value")]
+)
+def test_simple_imputation_string_list(strategy, expected):
+ X = [["a", "b"], ["c", np.nan]]
+
+ X_true = np.array([["a", "b"], ["c", expected]], dtype=object)
+
+ imputer = SimpleImputer(strategy=strategy)
+ X_trans = imputer.fit_transform(X)
+
+ assert_array_equal(X_trans, X_true)
+
+
+@pytest.mark.parametrize(
+ "order, idx_order",
+ [("ascending", [3, 4, 2, 0, 1]), ("descending", [1, 0, 2, 4, 3])],
+)
+def test_imputation_order(order, idx_order):
+ # regression test for #15393
+ rng = np.random.RandomState(42)
+ X = rng.rand(100, 5)
+ X[:50, 1] = np.nan
+ X[:30, 0] = np.nan
+ X[:20, 2] = np.nan
+ X[:10, 4] = np.nan
+
+ with pytest.warns(ConvergenceWarning):
+ trs = IterativeImputer(max_iter=1, imputation_order=order, random_state=0).fit(
+ X
+ )
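+ # with the default skip_complete=False every feature is imputed, ordered by
+ # its number of missing values: columns 3, 4, 2, 0, 1 have 0, 10, 20, 30 and
+ # 50 missing entries respectively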
+ idx = [x.feat_idx for x in trs.imputation_sequence_]
+ assert idx == idx_order
+
+
+@pytest.mark.parametrize("missing_value", [-1, np.nan])
+def test_simple_imputation_inverse_transform(missing_value):
+ # Test the inverse_transform feature for different missing value markers
+ X_1 = np.array(
+ [
+ [9, missing_value, 3, -1],
+ [4, -1, 5, 4],
+ [6, 7, missing_value, -1],
+ [8, 9, 0, missing_value],
+ ]
+ )
+
+ X_2 = np.array(
+ [
+ [5, 4, 2, 1],
+ [2, 1, missing_value, 3],
+ [9, missing_value, 7, 1],
+ [6, 4, 2, missing_value],
+ ]
+ )
+
+ X_3 = np.array(
+ [
+ [1, missing_value, 5, 9],
+ [missing_value, 4, missing_value, missing_value],
+ [2, missing_value, 7, missing_value],
+ [missing_value, 3, missing_value, 8],
+ ]
+ )
+
+ X_4 = np.array(
+ [
+ [1, 1, 1, 3],
+ [missing_value, 2, missing_value, 1],
+ [2, 3, 3, 4],
+ [missing_value, 4, missing_value, 2],
+ ]
+ )
+
+ imputer = SimpleImputer(
+ missing_values=missing_value, strategy="mean", add_indicator=True
+ )
+
+ X_1_trans = imputer.fit_transform(X_1)
+ X_1_inv_trans = imputer.inverse_transform(X_1_trans)
+
+ X_2_trans = imputer.transform(X_2) # test on new data
+ X_2_inv_trans = imputer.inverse_transform(X_2_trans)
+
+ assert_array_equal(X_1_inv_trans, X_1)
+ assert_array_equal(X_2_inv_trans, X_2)
+
+ for X in [X_3, X_4]:
+ X_trans = imputer.fit_transform(X)
+ X_inv_trans = imputer.inverse_transform(X_trans)
+ assert_array_equal(X_inv_trans, X)
+
+
+@pytest.mark.parametrize("missing_value", [-1, np.nan])
+def test_simple_imputation_inverse_transform_exceptions(missing_value):
+ X_1 = np.array(
+ [
+ [9, missing_value, 3, -1],
+ [4, -1, 5, 4],
+ [6, 7, missing_value, -1],
+ [8, 9, 0, missing_value],
+ ]
+ )
+
+ imputer = SimpleImputer(missing_values=missing_value, strategy="mean")
+ X_1_trans = imputer.fit_transform(X_1)
+ with pytest.raises(
+ ValueError, match=f"Got 'add_indicator={imputer.add_indicator}'"
+ ):
+ imputer.inverse_transform(X_1_trans)
+
+
+@pytest.mark.parametrize(
+ "expected,array,dtype,extra_value,n_repeat",
+ [
+ # array of object dtype
+ ("extra_value", ["a", "b", "c"], object, "extra_value", 2),
+ (
+ "most_frequent_value",
+ ["most_frequent_value", "most_frequent_value", "value"],
+ object,
+ "extra_value",
+ 1,
+ ),
+ ("a", ["min_value", "min_valuevalue"], object, "a", 2),
+ ("min_value", ["min_value", "min_value", "value"], object, "z", 2),
+ # array of numeric dtype
+ (10, [1, 2, 3], int, 10, 2),
+ (1, [1, 1, 2], int, 10, 1),
+ (10, [20, 20, 1], int, 10, 2),
+ (1, [1, 1, 20], int, 10, 2),
+ ],
+)
+def test_most_frequent(expected, array, dtype, extra_value, n_repeat):
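+ # _most_frequent returns the mode of `array` extended with `extra_value`
+ # repeated `n_repeat` times; ties are broken by taking the smallest value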
+ assert expected == _most_frequent(
+ np.array(array, dtype=dtype), extra_value, n_repeat
+ )
+
+
+@pytest.mark.parametrize(
+ "initial_strategy", ["mean", "median", "most_frequent", "constant"]
+)
+def test_iterative_imputer_keep_empty_features(initial_strategy):
+ """Check the behaviour of the iterative imputer with different initial strategy
+ and keeping empty features (i.e. features containing only missing values).
+ """
+ X = np.array([[1, np.nan, 2], [3, np.nan, np.nan]])
+
+ imputer = IterativeImputer(
+ initial_strategy=initial_strategy, keep_empty_features=True
+ )
+ X_imputed = imputer.fit_transform(X)
+ assert_allclose(X_imputed[:, 1], 0)
+ X_imputed = imputer.transform(X)
+ assert_allclose(X_imputed[:, 1], 0)
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_iterative_imputer_constant_fill_value():
+ """Check that we propagate properly the parameter `fill_value`."""
+ X = np.array([[-1, 2, 3, -1], [4, -1, 5, -1], [6, 7, -1, -1], [8, 9, 0, -1]])
+
+ fill_value = 100
+ imputer = IterativeImputer(
+ missing_values=-1,
+ initial_strategy="constant",
+ fill_value=fill_value,
+ max_iter=0,
+ )
+ imputer.fit_transform(X)
+ assert_array_equal(imputer.initial_imputer_.statistics_, fill_value)
+
+
+@pytest.mark.parametrize("keep_empty_features", [True, False])
+def test_knn_imputer_keep_empty_features(keep_empty_features):
+ """Check the behaviour of `keep_empty_features` for `KNNImputer`."""
+ X = np.array([[1, np.nan, 2], [3, np.nan, np.nan]])
+
+ imputer = KNNImputer(keep_empty_features=keep_empty_features)
+
+ for method in ["fit_transform", "transform"]:
+ X_imputed = getattr(imputer, method)(X)
+ if keep_empty_features:
+ assert X_imputed.shape == X.shape
+ assert_array_equal(X_imputed[:, 1], 0)
+ else:
+ assert X_imputed.shape == (X.shape[0], X.shape[1] - 1)
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_simple_impute_pd_na():
+ pd = pytest.importorskip("modin.pandas")
+
+ # Impute pandas array of string types.
+ df = pd.DataFrame({"feature": pd.Series(["abc", None, "de"], dtype="string")})
+ imputer = SimpleImputer(missing_values=pd.NA, strategy="constant", fill_value="na")
+ _assert_array_equal_and_same_dtype(
+ imputer.fit_transform(df), np.array([["abc"], ["na"], ["de"]], dtype=object)
+ )
+
+ # Impute pandas array of string types without any missing values.
+ df = pd.DataFrame({"feature": pd.Series(["abc", "de", "fgh"], dtype="string")})
+ imputer = SimpleImputer(fill_value="ok", strategy="constant")
+ _assert_array_equal_and_same_dtype(
+ imputer.fit_transform(df), np.array([["abc"], ["de"], ["fgh"]], dtype=object)
+ )
+
+ # Impute pandas array of integer types.
+ df = pd.DataFrame({"feature": pd.Series([1.0, None, 3.0], dtype="Int64")})
+
+ imputer = SimpleImputer(missing_values=pd.NA, strategy="constant", fill_value=-1)
+ _assert_allclose_and_same_dtype(
+ np.array(imputer.fit_transform(df), dtype="float64"), # cast to a float64 ndarray before comparing
+ np.array([[1], [-1], [3]], dtype="float64"),
+ )
+
+ # Use `np.nan` also works.
+ imputer = SimpleImputer(missing_values=np.nan, strategy="constant", fill_value=-1)
+ _assert_allclose_and_same_dtype(
+ imputer.fit_transform(df), np.array([[1], [-1], [3]], dtype="float64")
+ )
+
+ # Impute pandas array of integer types with 'median' strategy.
+ df = pd.DataFrame({"feature": pd.Series([1, None, 2, 3], dtype="Int64")})
+ imputer = SimpleImputer(missing_values=pd.NA, strategy="median")
+ _assert_allclose_and_same_dtype(
+ imputer.fit_transform(df), np.array([[1], [2], [2], [3]], dtype="float64")
+ )
+
+ # Impute pandas array of integer types with 'mean' strategy.
+ df = pd.DataFrame({"feature": pd.Series([1, None, 2], dtype="Int64")})
+ imputer = SimpleImputer(missing_values=pd.NA, strategy="mean")
+ _assert_allclose_and_same_dtype(
+ imputer.fit_transform(df), np.array([[1], [1.5], [2]], dtype="float64")
+ )
+
+ # Impute pandas array of float types.
+ df = pd.DataFrame({"feature": pd.Series([1.0, None, 3.0], dtype="float64")})
+ imputer = SimpleImputer(missing_values=pd.NA, strategy="constant", fill_value=-2.0)
+ _assert_allclose_and_same_dtype(
+ imputer.fit_transform(df), np.array([[1.0], [-2.0], [3.0]], dtype="float64")
+ )
+
+ # Impute pandas array of float types with 'median' strategy.
+ df = pd.DataFrame({"feature": pd.Series([1.0, None, 2.0, 3.0], dtype="float64")})
+ imputer = SimpleImputer(missing_values=pd.NA, strategy="median")
+ _assert_allclose_and_same_dtype(
+ imputer.fit_transform(df),
+ np.array([[1.0], [2.0], [2.0], [3.0]], dtype="float64"),
+ )
+
+
+def test_missing_indicator_feature_names_out():
+ """Check that missing indicator return the feature names with a prefix."""
+ pd = pytest.importorskip("modin.pandas")
+
+ missing_values = np.nan
+ X = pd.DataFrame(
+ [
+ [missing_values, missing_values, 1, missing_values],
+ [4, missing_values, 2, 10],
+ ],
+ columns=["a", "b", "c", "d"],
+ )
+
+ indicator = MissingIndicator(missing_values=missing_values).fit(X)
+ feature_names = indicator.get_feature_names_out()
+ expected_names = ["missingindicator_a", "missingindicator_b", "missingindicator_d"]
+ assert_array_equal(expected_names, feature_names)
+
+
+def test_imputer_lists_fit_transform():
+ """Check transform uses object dtype when fitted on an object dtype.
+
+ Non-regression test for #19572.
+ """
+
+ X = [["a", "b"], ["c", "b"], ["a", "a"]]
+ imp_frequent = SimpleImputer(strategy="most_frequent").fit(X)
+ X_trans = imp_frequent.transform([[np.nan, np.nan]])
+ assert X_trans.dtype == object
+ assert_array_equal(X_trans, [["a", "b"]])
+
+
+@pytest.mark.parametrize("dtype_test", [np.float32, np.float64])
+def test_imputer_transform_preserves_numeric_dtype(dtype_test):
+ """Check transform preserves numeric dtype independent of fit dtype."""
+ X = np.asarray(
+ [[1.2, 3.4, np.nan], [np.nan, 1.2, 1.3], [4.2, 2, 1]], dtype=np.float64
+ )
+ imp = SimpleImputer().fit(X)
+
+ X_test = np.asarray([[np.nan, np.nan, np.nan]], dtype=dtype_test)
+ X_trans = imp.transform(X_test)
+ assert X_trans.dtype == dtype_test
+
+
+@pytest.mark.parametrize("array_type", ["array", "sparse"])
+@pytest.mark.parametrize("keep_empty_features", [True, False])
+def test_simple_imputer_constant_keep_empty_features(array_type, keep_empty_features):
+ """Check the behaviour of `keep_empty_features` with `strategy='constant'.
+ For backward compatibility, a column full of missing values will always be
+ fill and never dropped.
+ """
+ X = np.array([[np.nan, 2], [np.nan, 3], [np.nan, 6]])
+ X = _convert_container(X, array_type)
+ fill_value = 10
+ imputer = SimpleImputer(
+ strategy="constant",
+ fill_value=fill_value,
+ keep_empty_features=keep_empty_features,
+ )
+
+ for method in ["fit_transform", "transform"]:
+ X_imputed = getattr(imputer, method)(X)
+ assert X_imputed.shape == X.shape
+ constant_feature = (
+ X_imputed[:, 0].A if array_type == "sparse" else X_imputed[:, 0]
+ )
+ assert_array_equal(constant_feature, fill_value)
+
+
+@pytest.mark.parametrize("array_type", ["array", "sparse"])
+@pytest.mark.parametrize("strategy", ["mean", "median", "most_frequent"])
+@pytest.mark.parametrize("keep_empty_features", [True, False])
+def test_simple_imputer_keep_empty_features(strategy, array_type, keep_empty_features):
+ """Check the behaviour of `keep_empty_features` with all strategies but
+ 'constant'.
+ """
+ X = np.array([[np.nan, 2], [np.nan, 3], [np.nan, 6]])
+ X = _convert_container(X, array_type)
+ imputer = SimpleImputer(strategy=strategy, keep_empty_features=keep_empty_features)
+
+ for method in ["fit_transform", "transform"]:
+ X_imputed = getattr(imputer, method)(X)
+ if keep_empty_features:
+ assert X_imputed.shape == X.shape
+ constant_feature = (
+ X_imputed[:, 0].A if array_type == "sparse" else X_imputed[:, 0]
+ )
+ assert_array_equal(constant_feature, 0)
+ else:
+ assert X_imputed.shape == (X.shape[0], X.shape[1] - 1)
diff --git a/modin/pandas/test/interoperability/sklearn/inspection/_plot/test_boundary_decision_display.py b/modin/pandas/test/interoperability/sklearn/inspection/_plot/test_boundary_decision_display.py
new file mode 100644
index 00000000000..420f10ce015
--- /dev/null
+++ b/modin/pandas/test/interoperability/sklearn/inspection/_plot/test_boundary_decision_display.py
@@ -0,0 +1,370 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+# flake8: noqa
+
+import warnings
+import pytest
+import numpy as np
+from numpy.testing import assert_allclose
+from sklearn.base import BaseEstimator
+from sklearn.base import ClassifierMixin
+from sklearn.datasets import make_classification
+from sklearn.linear_model import LogisticRegression
+from sklearn.datasets import load_iris
+from sklearn.datasets import make_multilabel_classification
+from sklearn.tree import DecisionTreeRegressor
+from sklearn.tree import DecisionTreeClassifier
+
+from sklearn.inspection import DecisionBoundaryDisplay
+from sklearn.inspection._plot.decision_boundary import _check_boundary_response_method
+
+
+# TODO: Remove when https://github.com/numpy/numpy/issues/14397 is resolved
+pytestmark = pytest.mark.filterwarnings(
+ "ignore:In future, it will be an error for 'np.bool_':DeprecationWarning:"
+ "matplotlib.*"
+)
+
+
+X, y = make_classification(
+ n_informative=1,
+ n_redundant=1,
+ n_clusters_per_class=1,
+ n_features=2,
+ random_state=42,
+)
+
+
+@pytest.fixture(scope="module")
+def fitted_clf():
+ return LogisticRegression().fit(X, y)
+
+
+def test_input_data_dimension(pyplot):
+ """Check that we raise an error when `X` does not have exactly 2 features."""
+ X, y = make_classification(n_samples=10, n_features=4, random_state=0)
+
+ clf = LogisticRegression().fit(X, y)
+ msg = "n_features must be equal to 2. Got 4 instead."
+ with pytest.raises(ValueError, match=msg):
+ DecisionBoundaryDisplay.from_estimator(estimator=clf, X=X)
+
+
+def test_check_boundary_response_method_auto():
+ """Check _check_boundary_response_method behavior with 'auto'."""
+
+ class A:
+ def decision_function(self):
+ pass
+
+ a_inst = A()
+ method = _check_boundary_response_method(a_inst, "auto")
+ assert method == a_inst.decision_function
+
+ class B:
+ def predict_proba(self):
+ pass
+
+ b_inst = B()
+ method = _check_boundary_response_method(b_inst, "auto")
+ assert method == b_inst.predict_proba
+
+ class C:
+ def predict_proba(self):
+ pass
+
+ def decision_function(self):
+ pass
+
+ c_inst = C()
+ method = _check_boundary_response_method(c_inst, "auto")
+ assert method == c_inst.decision_function
+
+ class D:
+ def predict(self):
+ pass
+
+ d_inst = D()
+ method = _check_boundary_response_method(d_inst, "auto")
+ assert method == d_inst.predict
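+
+ # Illustrative summary (not part of the upstream sklearn test): with
+ # response_method="auto", the priority asserted above is
+ # decision_function -> predict_proba -> predict; e.g. for a fitted
+ # LogisticRegression, which defines decision_function:
+ #
+ #     clf = LogisticRegression().fit(X, y)
+ #     assert _check_boundary_response_method(clf, "auto") == clf.decision_function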
+
+
+@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
+def test_multiclass_error(pyplot, response_method):
+ """Check multiclass errors."""
+ X, y = make_classification(n_classes=3, n_informative=3, random_state=0)
+ X = X[:, [0, 1]]
+ lr = LogisticRegression().fit(X, y)
+
+ msg = (
+ "Multiclass classifiers are only supported when response_method is 'predict' or"
+ " 'auto'"
+ )
+ with pytest.raises(ValueError, match=msg):
+ DecisionBoundaryDisplay.from_estimator(lr, X, response_method=response_method)
+
+
+@pytest.mark.parametrize("response_method", ["auto", "predict"])
+def test_multiclass(pyplot, response_method):
+ """Check multiclass gives expected results."""
+ grid_resolution = 10
+ eps = 1.0
+ X, y = make_classification(n_classes=3, n_informative=3, random_state=0)
+ X = X[:, [0, 1]]
+ lr = LogisticRegression(random_state=0).fit(X, y)
+
+ disp = DecisionBoundaryDisplay.from_estimator(
+ lr, X, response_method=response_method, grid_resolution=grid_resolution, eps=1.0
+ )
+
+ x0_min, x0_max = X[:, 0].min() - eps, X[:, 0].max() + eps
+ x1_min, x1_max = X[:, 1].min() - eps, X[:, 1].max() + eps
+ xx0, xx1 = np.meshgrid(
+ np.linspace(x0_min, x0_max, grid_resolution),
+ np.linspace(x1_min, x1_max, grid_resolution),
+ )
+ response = lr.predict(np.c_[xx0.ravel(), xx1.ravel()])
+ assert_allclose(disp.response, response.reshape(xx0.shape))
+ assert_allclose(disp.xx0, xx0)
+ assert_allclose(disp.xx1, xx1)
+
+
+@pytest.mark.parametrize(
+ "kwargs, error_msg",
+ [
+ (
+ {"plot_method": "hello_world"},
+ r"plot_method must be one of contourf, contour, pcolormesh. Got hello_world"
+ r" instead.",
+ ),
+ (
+ {"grid_resolution": 1},
+ r"grid_resolution must be greater than 1. Got 1 instead",
+ ),
+ (
+ {"grid_resolution": -1},
+ r"grid_resolution must be greater than 1. Got -1 instead",
+ ),
+ ({"eps": -1.1}, r"eps must be greater than or equal to 0. Got -1.1 instead"),
+ ],
+)
+def test_input_validation_errors(pyplot, kwargs, error_msg, fitted_clf):
+ """Check input validation from_estimator."""
+ with pytest.raises(ValueError, match=error_msg):
+ DecisionBoundaryDisplay.from_estimator(fitted_clf, X, **kwargs)
+
+
+def test_display_plot_input_error(pyplot, fitted_clf):
+ """Check input validation for `plot`."""
+ disp = DecisionBoundaryDisplay.from_estimator(fitted_clf, X, grid_resolution=5)
+
+ with pytest.raises(ValueError, match="plot_method must be 'contourf'"):
+ disp.plot(plot_method="hello_world")
+
+
+@pytest.mark.parametrize(
+ "response_method", ["auto", "predict", "predict_proba", "decision_function"]
+)
+@pytest.mark.parametrize("plot_method", ["contourf", "contour"])
+def test_decision_boundary_display(pyplot, fitted_clf, response_method, plot_method):
+ """Check that decision boundary is correct."""
+ fig, ax = pyplot.subplots()
+ eps = 2.0
+ disp = DecisionBoundaryDisplay.from_estimator(
+ fitted_clf,
+ X,
+ grid_resolution=5,
+ response_method=response_method,
+ plot_method=plot_method,
+ eps=eps,
+ ax=ax,
+ )
+ assert isinstance(disp.surface_, pyplot.matplotlib.contour.QuadContourSet)
+ assert disp.ax_ == ax
+ assert disp.figure_ == fig
+
+ x0, x1 = X[:, 0], X[:, 1]
+
+ x0_min, x0_max = x0.min() - eps, x0.max() + eps
+ x1_min, x1_max = x1.min() - eps, x1.max() + eps
+
+ assert disp.xx0.min() == pytest.approx(x0_min)
+ assert disp.xx0.max() == pytest.approx(x0_max)
+ assert disp.xx1.min() == pytest.approx(x1_min)
+ assert disp.xx1.max() == pytest.approx(x1_max)
+
+ fig2, ax2 = pyplot.subplots()
+ # change plotting method for second plot
+ disp.plot(plot_method="pcolormesh", ax=ax2, shading="auto")
+ assert isinstance(disp.surface_, pyplot.matplotlib.collections.QuadMesh)
+ assert disp.ax_ == ax2
+ assert disp.figure_ == fig2
+
+
+@pytest.mark.parametrize(
+ "response_method, msg",
+ [
+ (
+ "predict_proba",
+ "MyClassifier has none of the following attributes: predict_proba",
+ ),
+ (
+ "decision_function",
+ "MyClassifier has none of the following attributes: decision_function",
+ ),
+ (
+ "auto",
+ "MyClassifier has none of the following attributes: decision_function, "
+ "predict_proba, predict",
+ ),
+ (
+ "bad_method",
+ "MyClassifier has none of the following attributes: bad_method",
+ ),
+ ],
+)
+def test_error_bad_response(pyplot, response_method, msg):
+ """Check errors for bad response."""
+
+ class MyClassifier(BaseEstimator, ClassifierMixin):
+ def fit(self, X, y):
+ self.fitted_ = True
+ self.classes_ = [0, 1]
+ return self
+
+ clf = MyClassifier().fit(X, y)
+
+ with pytest.raises(ValueError, match=msg):
+ DecisionBoundaryDisplay.from_estimator(clf, X, response_method=response_method)
+
+
+@pytest.mark.parametrize("response_method", ["auto", "predict", "predict_proba"])
+def test_multilabel_classifier_error(pyplot, response_method):
+ """Check that multilabel classifier raises correct error."""
+ X, y = make_multilabel_classification(random_state=0)
+ X = X[:, :2]
+ tree = DecisionTreeClassifier().fit(X, y)
+
+ msg = "Multi-label and multi-output multi-class classifiers are not supported"
+ with pytest.raises(ValueError, match=msg):
+ DecisionBoundaryDisplay.from_estimator(
+ tree,
+ X,
+ response_method=response_method,
+ )
+
+
+@pytest.mark.parametrize("response_method", ["auto", "predict", "predict_proba"])
+def test_multi_output_multi_class_classifier_error(pyplot, response_method):
+ """Check that multi-output multi-class classifier raises correct error."""
+ X = np.asarray([[0, 1], [1, 2]])
+ y = np.asarray([["tree", "cat"], ["cat", "tree"]])
+ tree = DecisionTreeClassifier().fit(X, y)
+
+ msg = "Multi-label and multi-output multi-class classifiers are not supported"
+ with pytest.raises(ValueError, match=msg):
+ DecisionBoundaryDisplay.from_estimator(
+ tree,
+ X,
+ response_method=response_method,
+ )
+
+
+def test_multioutput_regressor_error(pyplot):
+ """Check that multioutput regressor raises correct error."""
+ X = np.asarray([[0, 1], [1, 2]])
+ y = np.asarray([[0, 1], [4, 1]])
+ tree = DecisionTreeRegressor().fit(X, y)
+ with pytest.raises(ValueError, match="Multi-output regressors are not supported"):
+ DecisionBoundaryDisplay.from_estimator(tree, X)
+
+
+@pytest.mark.skip(reason="Failing test")
+@pytest.mark.filterwarnings(
+ # We expect to raise the following warning because the classifier is fit on a
+ # NumPy array
+ "ignore:X has feature names, but LogisticRegression was fitted without"
+)
+def test_dataframe_labels_used(pyplot, fitted_clf):
+ """Check that column names are used for pandas."""
+ pd = pytest.importorskip("modin.pandas")
+ df = pd.DataFrame(X, columns=["col_x", "col_y"])
+
+ # pandas column names are used by default
+ _, ax = pyplot.subplots()
+ disp = DecisionBoundaryDisplay.from_estimator(fitted_clf, df, ax=ax)
+ assert ax.get_xlabel() == "col_x"
+ assert ax.get_ylabel() == "col_y"
+
+ # second call to plot will have the names
+ fig, ax = pyplot.subplots()
+ disp.plot(ax=ax)
+ assert ax.get_xlabel() == "col_x"
+ assert ax.get_ylabel() == "col_y"
+
+ # axes with a label will not get overridden
+ fig, ax = pyplot.subplots()
+ ax.set(xlabel="hello", ylabel="world")
+ disp.plot(ax=ax)
+ assert ax.get_xlabel() == "hello"
+ assert ax.get_ylabel() == "world"
+
+ # labels get overridden only if provided to the `plot` method
+ disp.plot(ax=ax, xlabel="overwritten_x", ylabel="overwritten_y")
+ assert ax.get_xlabel() == "overwritten_x"
+ assert ax.get_ylabel() == "overwritten_y"
+
+ # labels do not get inferred if provided to `from_estimator`
+ _, ax = pyplot.subplots()
+ disp = DecisionBoundaryDisplay.from_estimator(
+ fitted_clf, df, ax=ax, xlabel="overwritten_x", ylabel="overwritten_y"
+ )
+ assert ax.get_xlabel() == "overwritten_x"
+ assert ax.get_ylabel() == "overwritten_y"
+
+
+def test_string_target(pyplot):
+ """Check that decision boundary works with classifiers trained on string labels."""
+ iris = load_iris()
+ X = iris.data[:, [0, 1]]
+
+ # Use strings as target
+ y = iris.target_names[iris.target]
+ log_reg = LogisticRegression().fit(X, y)
+
+ # Does not raise
+ DecisionBoundaryDisplay.from_estimator(
+ log_reg,
+ X,
+ grid_resolution=5,
+ response_method="predict",
+ )
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_dataframe_support(pyplot):
+ """Check that passing a dataframe at fit and to the Display does not
+ raise warnings.
+
+ Non-regression test for:
+ https://github.com/scikit-learn/scikit-learn/issues/23311
+ """
+ pd = pytest.importorskip("modin.pandas")
+ df = pd.DataFrame(X, columns=["col_x", "col_y"])
+ estimator = LogisticRegression().fit(df, y)
+
+ with warnings.catch_warnings():
+ # no warnings linked to feature names validation should be raised
+ warnings.simplefilter("error", UserWarning)
+ DecisionBoundaryDisplay.from_estimator(estimator, df, response_method="predict")
diff --git a/modin/pandas/test/interoperability/sklearn/inspection/_plot/test_plot_partial_dependence.py b/modin/pandas/test/interoperability/sklearn/inspection/_plot/test_plot_partial_dependence.py
new file mode 100644
index 00000000000..11b29effcc9
--- /dev/null
+++ b/modin/pandas/test/interoperability/sklearn/inspection/_plot/test_plot_partial_dependence.py
@@ -0,0 +1,1139 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+# flake8: noqa
+
+import numpy as np
+from scipy.stats.mstats import mquantiles
+import pytest
+from numpy.testing import assert_allclose
+import warnings
+from sklearn.datasets import load_diabetes
+from sklearn.datasets import load_iris
+from sklearn.datasets import make_classification, make_regression
+from sklearn.ensemble import GradientBoostingRegressor
+from sklearn.ensemble import GradientBoostingClassifier
+from sklearn.linear_model import LinearRegression
+from sklearn.utils._testing import _convert_container
+from sklearn.compose import make_column_transformer
+from sklearn.preprocessing import OneHotEncoder
+from sklearn.pipeline import make_pipeline
+from sklearn.inspection import PartialDependenceDisplay
+
+
+# TODO: Remove when https://github.com/numpy/numpy/issues/14397 is resolved
+pytestmark = pytest.mark.filterwarnings(
+ "ignore:In future, it will be an error for 'np.bool_':DeprecationWarning:"
+ "matplotlib.*",
+)
+
+
+@pytest.fixture(scope="module")
+def diabetes():
+ # diabetes dataset, subsampled for speed
+ data = load_diabetes()
+ data.data = data.data[:50]
+ data.target = data.target[:50]
+ return data
+
+
+@pytest.fixture(scope="module")
+def clf_diabetes(diabetes):
+ clf = GradientBoostingRegressor(n_estimators=10, random_state=1)
+ clf.fit(diabetes.data, diabetes.target)
+ return clf
+
+
+@pytest.mark.filterwarnings("ignore:A Bunch will be returned")
+@pytest.mark.parametrize("grid_resolution", [10, 20])
+def test_plot_partial_dependence(grid_resolution, pyplot, clf_diabetes, diabetes):
+ # Test partial dependence plot function.
+ # Use columns 0 & 2 as 1 is not quantitative (sex)
+ feature_names = diabetes.feature_names
+ disp = PartialDependenceDisplay.from_estimator(
+ clf_diabetes,
+ diabetes.data,
+ [0, 2, (0, 2)],
+ grid_resolution=grid_resolution,
+ feature_names=feature_names,
+ contour_kw={"cmap": "jet"},
+ )
+ fig = pyplot.gcf()
+ axs = fig.get_axes()
+ assert disp.figure_ is fig
+ assert len(axs) == 4
+
+ assert disp.bounding_ax_ is not None
+ assert disp.axes_.shape == (1, 3)
+ assert disp.lines_.shape == (1, 3)
+ assert disp.contours_.shape == (1, 3)
+ assert disp.deciles_vlines_.shape == (1, 3)
+ assert disp.deciles_hlines_.shape == (1, 3)
+
+ assert disp.lines_[0, 2] is None
+ assert disp.contours_[0, 0] is None
+ assert disp.contours_[0, 1] is None
+
+ # deciles lines: always show on xaxis, only show on yaxis if 2-way PDP
+ for i in range(3):
+ assert disp.deciles_vlines_[0, i] is not None
+ assert disp.deciles_hlines_[0, 0] is None
+ assert disp.deciles_hlines_[0, 1] is None
+ assert disp.deciles_hlines_[0, 2] is not None
+
+ assert disp.features == [(0,), (2,), (0, 2)]
+ assert np.all(disp.feature_names == feature_names)
+ assert len(disp.deciles) == 2
+ for i in [0, 2]:
+ assert_allclose(
+ disp.deciles[i],
+ mquantiles(diabetes.data[:, i], prob=np.arange(0.1, 1.0, 0.1)),
+ )
+
+ single_feature_positions = [(0, (0, 0)), (2, (0, 1))]
+ expected_ylabels = ["Partial dependence", ""]
+
+ for i, (feat_col, pos) in enumerate(single_feature_positions):
+ ax = disp.axes_[pos]
+ assert ax.get_ylabel() == expected_ylabels[i]
+ assert ax.get_xlabel() == diabetes.feature_names[feat_col]
+
+ line = disp.lines_[pos]
+
+ avg_preds = disp.pd_results[i]
+ assert avg_preds.average.shape == (1, grid_resolution)
+ target_idx = disp.target_idx
+
+ line_data = line.get_data()
+ assert_allclose(line_data[0], avg_preds["values"][0])
+ assert_allclose(line_data[1], avg_preds.average[target_idx].ravel())
+
+ # two feature position
+ ax = disp.axes_[0, 2]
+ contour = disp.contours_[0, 2]
+ assert contour.get_cmap().name == "jet"
+ assert ax.get_xlabel() == diabetes.feature_names[0]
+ assert ax.get_ylabel() == diabetes.feature_names[2]
+
+
+@pytest.mark.filterwarnings("ignore:A Bunch will be returned")
+@pytest.mark.parametrize(
+ "kind, centered, subsample, shape",
+ [
+ ("average", False, None, (1, 3)),
+ ("individual", False, None, (1, 3, 50)),
+ ("both", False, None, (1, 3, 51)),
+ ("individual", False, 20, (1, 3, 20)),
+ ("both", False, 20, (1, 3, 21)),
+ ("individual", False, 0.5, (1, 3, 25)),
+ ("both", False, 0.5, (1, 3, 26)),
+ ("average", True, None, (1, 3)),
+ ("individual", True, None, (1, 3, 50)),
+ ("both", True, None, (1, 3, 51)),
+ ("individual", True, 20, (1, 3, 20)),
+ ("both", True, 20, (1, 3, 21)),
+ ],
+)
+def test_plot_partial_dependence_kind(
+ pyplot,
+ kind,
+ centered,
+ subsample,
+ shape,
+ clf_diabetes,
+ diabetes,
+):
+ disp = PartialDependenceDisplay.from_estimator(
+ clf_diabetes,
+ diabetes.data,
+ [0, 1, 2],
+ kind=kind,
+ centered=centered,
+ subsample=subsample,
+ )
+
+ assert disp.axes_.shape == (1, 3)
+ assert disp.lines_.shape == shape
+ assert disp.contours_.shape == (1, 3)
+
+ assert disp.contours_[0, 0] is None
+ assert disp.contours_[0, 1] is None
+ assert disp.contours_[0, 2] is None
+
+ if centered:
+ assert all([ln._y[0] == 0.0 for ln in disp.lines_.ravel() if ln is not None])
+ else:
+ assert all([ln._y[0] != 0.0 for ln in disp.lines_.ravel() if ln is not None])
+
+
+@pytest.mark.filterwarnings("ignore:A Bunch will be returned")
+@pytest.mark.parametrize(
+ "input_type, feature_names_type",
+ [
+ ("dataframe", None),
+ ("dataframe", "list"),
+ ("list", "list"),
+ ("array", "list"),
+ ("dataframe", "array"),
+ ("list", "array"),
+ ("array", "array"),
+ ("dataframe", "series"),
+ ("list", "series"),
+ ("array", "series"),
+ ("dataframe", "index"),
+ ("list", "index"),
+ ("array", "index"),
+ ],
+)
+def test_plot_partial_dependence_str_features(
+ pyplot,
+ clf_diabetes,
+ diabetes,
+ input_type,
+ feature_names_type,
+):
+ if input_type == "dataframe":
+ pd = pytest.importorskip("modin.pandas")
+ X = pd.DataFrame(diabetes.data, columns=diabetes.feature_names)
+ elif input_type == "list":
+ X = diabetes.data.tolist()
+ else:
+ X = diabetes.data
+
+ if feature_names_type is None:
+ feature_names = None
+ else:
+ feature_names = _convert_container(diabetes.feature_names, feature_names_type)
+
+ grid_resolution = 25
+ # check with str features and array feature names and single column
+ disp = PartialDependenceDisplay.from_estimator(
+ clf_diabetes,
+ X,
+ [("age", "bmi"), "bmi"],
+ grid_resolution=grid_resolution,
+ feature_names=feature_names,
+ n_cols=1,
+ line_kw={"alpha": 0.8},
+ )
+ fig = pyplot.gcf()
+ axs = fig.get_axes()
+ assert len(axs) == 3
+
+ assert disp.figure_ is fig
+ assert disp.axes_.shape == (2, 1)
+ assert disp.lines_.shape == (2, 1)
+ assert disp.contours_.shape == (2, 1)
+ assert disp.deciles_vlines_.shape == (2, 1)
+ assert disp.deciles_hlines_.shape == (2, 1)
+
+ assert disp.lines_[0, 0] is None
+ assert disp.deciles_vlines_[0, 0] is not None
+ assert disp.deciles_hlines_[0, 0] is not None
+ assert disp.contours_[1, 0] is None
+ assert disp.deciles_hlines_[1, 0] is None
+ assert disp.deciles_vlines_[1, 0] is not None
+
+ # line
+ ax = disp.axes_[1, 0]
+ assert ax.get_xlabel() == "bmi"
+ assert ax.get_ylabel() == "Partial dependence"
+
+ line = disp.lines_[1, 0]
+ avg_preds = disp.pd_results[1]
+ target_idx = disp.target_idx
+ assert line.get_alpha() == 0.8
+
+ line_data = line.get_data()
+ assert_allclose(line_data[0], avg_preds["values"][0])
+ assert_allclose(line_data[1], avg_preds.average[target_idx].ravel())
+
+ # contour
+ ax = disp.axes_[0, 0]
+ assert ax.get_xlabel() == "age"
+ assert ax.get_ylabel() == "bmi"
+
+
+@pytest.mark.filterwarnings("ignore:A Bunch will be returned")
+def test_plot_partial_dependence_custom_axes(pyplot, clf_diabetes, diabetes):
+ grid_resolution = 25
+ fig, (ax1, ax2) = pyplot.subplots(1, 2)
+ disp = PartialDependenceDisplay.from_estimator(
+ clf_diabetes,
+ diabetes.data,
+ ["age", ("age", "bmi")],
+ grid_resolution=grid_resolution,
+ feature_names=diabetes.feature_names,
+ ax=[ax1, ax2],
+ )
+ assert fig is disp.figure_
+ assert disp.bounding_ax_ is None
+ assert disp.axes_.shape == (2,)
+ assert disp.axes_[0] is ax1
+ assert disp.axes_[1] is ax2
+
+ ax = disp.axes_[0]
+ assert ax.get_xlabel() == "age"
+ assert ax.get_ylabel() == "Partial dependence"
+
+ line = disp.lines_[0]
+ avg_preds = disp.pd_results[0]
+ target_idx = disp.target_idx
+
+ line_data = line.get_data()
+ assert_allclose(line_data[0], avg_preds["values"][0])
+ assert_allclose(line_data[1], avg_preds.average[target_idx].ravel())
+
+ # contour
+ ax = disp.axes_[1]
+ assert ax.get_xlabel() == "age"
+ assert ax.get_ylabel() == "bmi"
+
+
+@pytest.mark.filterwarnings("ignore:A Bunch will be returned")
+@pytest.mark.parametrize(
+ "kind, lines", [("average", 1), ("individual", 50), ("both", 51)]
+)
+def test_plot_partial_dependence_passing_numpy_axes(
+ pyplot, clf_diabetes, diabetes, kind, lines
+):
+ grid_resolution = 25
+ feature_names = diabetes.feature_names
+ disp1 = PartialDependenceDisplay.from_estimator(
+ clf_diabetes,
+ diabetes.data,
+ ["age", "bmi"],
+ kind=kind,
+ grid_resolution=grid_resolution,
+ feature_names=feature_names,
+ )
+ assert disp1.axes_.shape == (1, 2)
+ assert disp1.axes_[0, 0].get_ylabel() == "Partial dependence"
+ assert disp1.axes_[0, 1].get_ylabel() == ""
+ assert len(disp1.axes_[0, 0].get_lines()) == lines
+ assert len(disp1.axes_[0, 1].get_lines()) == lines
+
+ lr = LinearRegression()
+ lr.fit(diabetes.data, diabetes.target)
+
+ disp2 = PartialDependenceDisplay.from_estimator(
+ lr,
+ diabetes.data,
+ ["age", "bmi"],
+ kind=kind,
+ grid_resolution=grid_resolution,
+ feature_names=feature_names,
+ ax=disp1.axes_,
+ )
+
+ assert np.all(disp1.axes_ == disp2.axes_)
+ assert len(disp2.axes_[0, 0].get_lines()) == 2 * lines
+ assert len(disp2.axes_[0, 1].get_lines()) == 2 * lines
+
+
+@pytest.mark.filterwarnings("ignore:A Bunch will be returned")
+@pytest.mark.parametrize("nrows, ncols", [(2, 2), (3, 1)])
+def test_plot_partial_dependence_incorrect_num_axes(
+ pyplot, clf_diabetes, diabetes, nrows, ncols
+):
+ grid_resolution = 5
+ fig, axes = pyplot.subplots(nrows, ncols)
+ axes_formats = [list(axes.ravel()), tuple(axes.ravel()), axes]
+
+ msg = "Expected ax to have 2 axes, got {}".format(nrows * ncols)
+
+ disp = PartialDependenceDisplay.from_estimator(
+ clf_diabetes,
+ diabetes.data,
+ ["age", "bmi"],
+ grid_resolution=grid_resolution,
+ feature_names=diabetes.feature_names,
+ )
+
+ for ax_format in axes_formats:
+ with pytest.raises(ValueError, match=msg):
+ PartialDependenceDisplay.from_estimator(
+ clf_diabetes,
+ diabetes.data,
+ ["age", "bmi"],
+ grid_resolution=grid_resolution,
+ feature_names=diabetes.feature_names,
+ ax=ax_format,
+ )
+
+ # with axes object
+ with pytest.raises(ValueError, match=msg):
+ disp.plot(ax=ax_format)
+
+
+@pytest.mark.filterwarnings("ignore:A Bunch will be returned")
+def test_plot_partial_dependence_with_same_axes(pyplot, clf_diabetes, diabetes):
+ # The first call to plot_partial_dependence will create two new axes to
+ # place in the space of the passed in axes, which results in a total of
+ # three axes in the figure.
+ # Currently the API does not allow for the second call to
+ # plot_partial_dependence to use the same axes again, because it will
+ # create two new axes in the space resulting in five axes. To get the
+ # expected behavior one needs to pass the generated axes into the second
+ # call:
+ # disp1 = plot_partial_dependence(...)
+ # disp2 = plot_partial_dependence(..., ax=disp1.axes_)
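+ #
+ # With the API used in this file, the equivalent pattern is (illustrative
+ # sketch, mirroring test_plot_partial_dependence_passing_numpy_axes above):
+ #   disp1 = PartialDependenceDisplay.from_estimator(est, X, features)
+ #   disp2 = PartialDependenceDisplay.from_estimator(est, X, features, ax=disp1.axes_)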
+
+ grid_resolution = 25
+ fig, ax = pyplot.subplots()
+ PartialDependenceDisplay.from_estimator(
+ clf_diabetes,
+ diabetes.data,
+ ["age", "bmi"],
+ grid_resolution=grid_resolution,
+ feature_names=diabetes.feature_names,
+ ax=ax,
+ )
+
+ msg = (
+ "The ax was already used in another plot function, please set "
+ "ax=display.axes_ instead"
+ )
+
+ with pytest.raises(ValueError, match=msg):
+ PartialDependenceDisplay.from_estimator(
+ clf_diabetes,
+ diabetes.data,
+ ["age", "bmi"],
+ grid_resolution=grid_resolution,
+ feature_names=diabetes.feature_names,
+ ax=ax,
+ )
+
+
+@pytest.mark.filterwarnings("ignore:A Bunch will be returned")
+def test_plot_partial_dependence_feature_name_reuse(pyplot, clf_diabetes, diabetes):
+ # second call to plot does not change the feature names from the first
+ # call
+
+ feature_names = diabetes.feature_names
+ disp = PartialDependenceDisplay.from_estimator(
+ clf_diabetes,
+ diabetes.data,
+ [0, 1],
+ grid_resolution=10,
+ feature_names=feature_names,
+ )
+
+ PartialDependenceDisplay.from_estimator(
+ clf_diabetes, diabetes.data, [0, 1], grid_resolution=10, ax=disp.axes_
+ )
+
+ for i, ax in enumerate(disp.axes_.ravel()):
+ assert ax.get_xlabel() == feature_names[i]
+
+
+@pytest.mark.filterwarnings("ignore:A Bunch will be returned")
+def test_plot_partial_dependence_multiclass(pyplot):
+ grid_resolution = 25
+ clf_int = GradientBoostingClassifier(n_estimators=10, random_state=1)
+ iris = load_iris()
+
+ # Test partial dependence plot function on multi-class input.
+ clf_int.fit(iris.data, iris.target)
+ disp_target_0 = PartialDependenceDisplay.from_estimator(
+ clf_int, iris.data, [0, 3], target=0, grid_resolution=grid_resolution
+ )
+ assert disp_target_0.figure_ is pyplot.gcf()
+ assert disp_target_0.axes_.shape == (1, 2)
+ assert disp_target_0.lines_.shape == (1, 2)
+ assert disp_target_0.contours_.shape == (1, 2)
+ assert disp_target_0.deciles_vlines_.shape == (1, 2)
+ assert disp_target_0.deciles_hlines_.shape == (1, 2)
+ assert all(c is None for c in disp_target_0.contours_.flat)
+ assert disp_target_0.target_idx == 0
+
+ # now with symbol labels
+ target = iris.target_names[iris.target]
+ clf_symbol = GradientBoostingClassifier(n_estimators=10, random_state=1)
+ clf_symbol.fit(iris.data, target)
+ disp_symbol = PartialDependenceDisplay.from_estimator(
+ clf_symbol, iris.data, [0, 3], target="setosa", grid_resolution=grid_resolution
+ )
+ assert disp_symbol.figure_ is pyplot.gcf()
+ assert disp_symbol.axes_.shape == (1, 2)
+ assert disp_symbol.lines_.shape == (1, 2)
+ assert disp_symbol.contours_.shape == (1, 2)
+ assert disp_symbol.deciles_vlines_.shape == (1, 2)
+ assert disp_symbol.deciles_hlines_.shape == (1, 2)
+ assert all(c is None for c in disp_symbol.contours_.flat)
+ assert disp_symbol.target_idx == 0
+
+ for int_result, symbol_result in zip(
+ disp_target_0.pd_results, disp_symbol.pd_results
+ ):
+ assert_allclose(int_result.average, symbol_result.average)
+ assert_allclose(int_result["values"], symbol_result["values"])
+
+ # check that the pd plots are different for another target
+ disp_target_1 = PartialDependenceDisplay.from_estimator(
+ clf_int, iris.data, [0, 3], target=1, grid_resolution=grid_resolution
+ )
+ target_0_data_y = disp_target_0.lines_[0, 0].get_data()[1]
+ target_1_data_y = disp_target_1.lines_[0, 0].get_data()[1]
+ assert any(target_0_data_y != target_1_data_y)
+
+
+multioutput_regression_data = make_regression(n_samples=50, n_targets=2, random_state=0)
+
+
+@pytest.mark.filterwarnings("ignore:A Bunch will be returned")
+@pytest.mark.parametrize("target", [0, 1])
+def test_plot_partial_dependence_multioutput(pyplot, target):
+ # Test partial dependence plot function on multi-output input.
+ X, y = multioutput_regression_data
+ clf = LinearRegression().fit(X, y)
+
+ grid_resolution = 25
+ disp = PartialDependenceDisplay.from_estimator(
+ clf, X, [0, 1], target=target, grid_resolution=grid_resolution
+ )
+ fig = pyplot.gcf()
+ axs = fig.get_axes()
+ assert len(axs) == 3
+ assert disp.target_idx == target
+ assert disp.bounding_ax_ is not None
+
+ positions = [(0, 0), (0, 1)]
+ expected_label = ["Partial dependence", ""]
+
+ for i, pos in enumerate(positions):
+ ax = disp.axes_[pos]
+ assert ax.get_ylabel() == expected_label[i]
+ assert ax.get_xlabel() == f"x{i}"
+
+
+@pytest.mark.filterwarnings("ignore:A Bunch will be returned")
+def test_plot_partial_dependence_dataframe(pyplot, clf_diabetes, diabetes):
+ pd = pytest.importorskip("modin.pandas")
+ df = pd.DataFrame(diabetes.data, columns=diabetes.feature_names)
+
+ grid_resolution = 25
+
+ PartialDependenceDisplay.from_estimator(
+ clf_diabetes,
+ df,
+ ["bp", "s1"],
+ grid_resolution=grid_resolution,
+ feature_names=df.columns.tolist(),
+ )
+
+
+dummy_classification_data = make_classification(random_state=0)
+
+
+@pytest.mark.filterwarnings("ignore:A Bunch will be returned")
+@pytest.mark.parametrize(
+ "data, params, err_msg",
+ [
+ (
+ multioutput_regression_data,
+ {"target": None, "features": [0]},
+ "target must be specified for multi-output",
+ ),
+ (
+ multioutput_regression_data,
+ {"target": -1, "features": [0]},
+ r"target must be in \[0, n_tasks\]",
+ ),
+ (
+ multioutput_regression_data,
+ {"target": 100, "features": [0]},
+ r"target must be in \[0, n_tasks\]",
+ ),
+ (
+ dummy_classification_data,
+ {"features": ["foobar"], "feature_names": None},
+ "Feature 'foobar' not in feature_names",
+ ),
+ (
+ dummy_classification_data,
+ {"features": ["foobar"], "feature_names": ["abcd", "def"]},
+ "Feature 'foobar' not in feature_names",
+ ),
+ (
+ dummy_classification_data,
+ {"features": [(1, 2, 3)]},
+ "Each entry in features must be either an int, ",
+ ),
+ (
+ dummy_classification_data,
+ {"features": [1, {}]},
+ "Each entry in features must be either an int, ",
+ ),
+ (
+ dummy_classification_data,
+ {"features": [tuple()]},
+ "Each entry in features must be either an int, ",
+ ),
+ (
+ dummy_classification_data,
+ {"features": [123], "feature_names": ["blahblah"]},
+ "All entries of features must be less than ",
+ ),
+ (
+ dummy_classification_data,
+ {"features": [0, 1, 2], "feature_names": ["a", "b", "a"]},
+ "feature_names should not contain duplicates",
+ ),
+ (
+ dummy_classification_data,
+ {"features": [1, 2], "kind": ["both"]},
+ "When `kind` is provided as a list of strings, it should contain",
+ ),
+ (
+ dummy_classification_data,
+ {"features": [1], "subsample": -1},
+ "When an integer, subsample=-1 should be positive.",
+ ),
+ (
+ dummy_classification_data,
+ {"features": [1], "subsample": 1.2},
+ r"When a floating-point, subsample=1.2 should be in the \(0, 1\) range",
+ ),
+ (
+ dummy_classification_data,
+ {"features": [1, 2], "categorical_features": [1.0, 2.0]},
+ "Expected `categorical_features` to be an array-like of boolean,",
+ ),
+ (
+ dummy_classification_data,
+ {"features": [(1, 2)], "categorical_features": [2]},
+ "Two-way partial dependence plots are not supported for pairs",
+ ),
+ (
+ dummy_classification_data,
+ {"features": [1], "categorical_features": [1], "kind": "individual"},
+ "It is not possible to display individual effects",
+ ),
+ (
+ dummy_classification_data,
+ {"features": [1], "kind": "foo"},
+ "Values provided to `kind` must be one of",
+ ),
+ (
+ dummy_classification_data,
+ {"features": [0, 1], "kind": ["foo", "individual"]},
+ "Values provided to `kind` must be one of",
+ ),
+ ],
+)
+def test_plot_partial_dependence_error(pyplot, data, params, err_msg):
+ X, y = data
+ estimator = LinearRegression().fit(X, y)
+
+ with pytest.raises(ValueError, match=err_msg):
+ PartialDependenceDisplay.from_estimator(estimator, X, **params)
+
+
+@pytest.mark.filterwarnings("ignore:A Bunch will be returned")
+@pytest.mark.parametrize(
+ "params, err_msg",
+ [
+ ({"target": 4, "features": [0]}, "target not in est.classes_, got 4"),
+ ({"target": None, "features": [0]}, "target must be specified for multi-class"),
+ (
+ {"target": 1, "features": [4.5]},
+ "Each entry in features must be either an int,",
+ ),
+ ],
+)
+def test_plot_partial_dependence_multiclass_error(pyplot, params, err_msg):
+ iris = load_iris()
+ clf = GradientBoostingClassifier(n_estimators=10, random_state=1)
+ clf.fit(iris.data, iris.target)
+
+ with pytest.raises(ValueError, match=err_msg):
+ PartialDependenceDisplay.from_estimator(clf, iris.data, **params)
+
+
+def test_plot_partial_dependence_does_not_override_ylabel(
+ pyplot, clf_diabetes, diabetes
+):
+ # Non-regression test: make sure we do not override the ylabel if it has
+ # already been set.
+ # See https://github.com/scikit-learn/scikit-learn/issues/15772
+ _, axes = pyplot.subplots(1, 2)
+ axes[0].set_ylabel("Hello world")
+ PartialDependenceDisplay.from_estimator(
+ clf_diabetes, diabetes.data, [0, 1], ax=axes
+ )
+
+ assert axes[0].get_ylabel() == "Hello world"
+ assert axes[1].get_ylabel() == "Partial dependence"
+
+
+@pytest.mark.parametrize(
+ "categorical_features, array_type",
+ [
+ (["col_A", "col_C"], "dataframe"),
+ ([0, 2], "array"),
+ ([True, False, True], "array"),
+ ],
+)
+def test_plot_partial_dependence_with_categorical(
+ pyplot, categorical_features, array_type
+):
+ X = [[1, 1, "A"], [2, 0, "C"], [3, 2, "B"]]
+ column_name = ["col_A", "col_B", "col_C"]
+ X = _convert_container(X, array_type, columns_name=column_name)
+ y = np.array([1.2, 0.5, 0.45]).T
+
+ preprocessor = make_column_transformer((OneHotEncoder(), categorical_features))
+ model = make_pipeline(preprocessor, LinearRegression())
+ model.fit(X, y)
+
+ # single feature
+ disp = PartialDependenceDisplay.from_estimator(
+ model,
+ X,
+ features=["col_C"],
+ feature_names=column_name,
+ categorical_features=categorical_features,
+ )
+
+ assert disp.figure_ is pyplot.gcf()
+ assert disp.bars_.shape == (1, 1)
+ assert disp.bars_[0][0] is not None
+ assert disp.lines_.shape == (1, 1)
+ assert disp.lines_[0][0] is None
+ assert disp.contours_.shape == (1, 1)
+ assert disp.contours_[0][0] is None
+ assert disp.deciles_vlines_.shape == (1, 1)
+ assert disp.deciles_vlines_[0][0] is None
+ assert disp.deciles_hlines_.shape == (1, 1)
+ assert disp.deciles_hlines_[0][0] is None
+ assert disp.axes_[0, 0].get_legend() is None
+
+ # interaction between two features
+ disp = PartialDependenceDisplay.from_estimator(
+ model,
+ X,
+ features=[("col_A", "col_C")],
+ feature_names=column_name,
+ categorical_features=categorical_features,
+ )
+
+ assert disp.figure_ is pyplot.gcf()
+ assert disp.bars_.shape == (1, 1)
+ assert disp.bars_[0][0] is None
+ assert disp.lines_.shape == (1, 1)
+ assert disp.lines_[0][0] is None
+ assert disp.contours_.shape == (1, 1)
+ assert disp.contours_[0][0] is None
+ assert disp.deciles_vlines_.shape == (1, 1)
+ assert disp.deciles_vlines_[0][0] is None
+ assert disp.deciles_hlines_.shape == (1, 1)
+ assert disp.deciles_hlines_[0][0] is None
+ assert disp.axes_[0, 0].get_legend() is None
+
+
+def test_plot_partial_dependence_legend(pyplot):
+ pd = pytest.importorskip("modin.pandas")
+ X = pd.DataFrame(
+ {
+ "col_A": ["A", "B", "C"],
+ "col_B": [1, 0, 2],
+ "col_C": ["C", "B", "A"],
+ }
+ )
+ y = np.array([1.2, 0.5, 0.45]).T
+
+ categorical_features = ["col_A", "col_C"]
+ preprocessor = make_column_transformer((OneHotEncoder(), categorical_features))
+ model = make_pipeline(preprocessor, LinearRegression())
+ model.fit(X, y)
+
+ disp = PartialDependenceDisplay.from_estimator(
+ model,
+ X,
+ features=["col_B", "col_C"],
+ categorical_features=categorical_features,
+ kind=["both", "average"],
+ )
+
+ legend_text = disp.axes_[0, 0].get_legend().get_texts()
+ assert len(legend_text) == 1
+ assert legend_text[0].get_text() == "average"
+ assert disp.axes_[0, 1].get_legend() is None
+
+
+@pytest.mark.parametrize(
+ "kind, expected_shape",
+ [("average", (1, 2)), ("individual", (1, 2, 20)), ("both", (1, 2, 21))],
+)
+def test_plot_partial_dependence_subsampling(
+ pyplot, clf_diabetes, diabetes, kind, expected_shape
+):
+ # check that the subsampling is properly working
+ # non-regression test for:
+ # https://github.com/scikit-learn/scikit-learn/pull/18359
+ matplotlib = pytest.importorskip("matplotlib")
+ grid_resolution = 25
+ feature_names = diabetes.feature_names
+
+ disp1 = PartialDependenceDisplay.from_estimator(
+ clf_diabetes,
+ diabetes.data,
+ ["age", "bmi"],
+ kind=kind,
+ grid_resolution=grid_resolution,
+ feature_names=feature_names,
+ subsample=20,
+ random_state=0,
+ )
+
+ assert disp1.lines_.shape == expected_shape
+ assert all(
+ [isinstance(line, matplotlib.lines.Line2D) for line in disp1.lines_.ravel()]
+ )
+
+
+@pytest.mark.parametrize(
+ "kind, line_kw, label",
+ [
+ ("individual", {}, None),
+ ("individual", {"label": "xxx"}, None),
+ ("average", {}, None),
+ ("average", {"label": "xxx"}, "xxx"),
+ ("both", {}, "average"),
+ ("both", {"label": "xxx"}, "xxx"),
+ ],
+)
+def test_partial_dependence_overwrite_labels(
+ pyplot,
+ clf_diabetes,
+ diabetes,
+ kind,
+ line_kw,
+ label,
+):
+ """Test that make sure that we can overwrite the label of the PDP plot"""
+ disp = PartialDependenceDisplay.from_estimator(
+ clf_diabetes,
+ diabetes.data,
+ [0, 2],
+ grid_resolution=25,
+ feature_names=diabetes.feature_names,
+ kind=kind,
+ line_kw=line_kw,
+ )
+
+ for ax in disp.axes_.ravel():
+ if label is None:
+ assert ax.get_legend() is None
+ else:
+ legend_text = ax.get_legend().get_texts()
+ assert len(legend_text) == 1
+ assert legend_text[0].get_text() == label
+
+
+@pytest.mark.parametrize(
+ "categorical_features, array_type",
+ [
+ (["col_A", "col_C"], "dataframe"),
+ ([0, 2], "array"),
+ ([True, False, True], "array"),
+ ],
+)
+def test_grid_resolution_with_categorical(pyplot, categorical_features, array_type):
+ """Check that we raise a ValueError when the grid_resolution is too small
+ respect to the number of categories in the categorical features targeted.
+ """
+ X = [["A", 1, "A"], ["B", 0, "C"], ["C", 2, "B"]]
+ column_name = ["col_A", "col_B", "col_C"]
+ X = _convert_container(X, array_type, columns_name=column_name)
+ y = np.array([1.2, 0.5, 0.45]).T
+
+ preprocessor = make_column_transformer((OneHotEncoder(), categorical_features))
+ model = make_pipeline(preprocessor, LinearRegression())
+ model.fit(X, y)
+
+ err_msg = (
+ "resolution of the computed grid is less than the minimum number of categories"
+ )
+ with pytest.raises(ValueError, match=err_msg):
+ PartialDependenceDisplay.from_estimator(
+ model,
+ X,
+ features=["col_C"],
+ feature_names=column_name,
+ categorical_features=categorical_features,
+ grid_resolution=2,
+ )
+
+
+# TODO(1.3): remove
+def test_partial_dependence_display_deprecation(pyplot, clf_diabetes, diabetes):
+ """Check that we raise the proper warning in the display."""
+ disp = PartialDependenceDisplay.from_estimator(
+ clf_diabetes,
+ diabetes.data,
+ [0, 2],
+ grid_resolution=25,
+ feature_names=diabetes.feature_names,
+ )
+
+ deprecation_msg = "The `pdp_lim` parameter is deprecated"
+ overwriting_msg = (
+ "`pdp_lim` has been passed in both the constructor and the `plot` method"
+ )
+
+ disp.pdp_lim = None
+ # case when constructor and method parameters are the same
+ with pytest.warns(FutureWarning, match=deprecation_msg):
+ disp.plot(pdp_lim=None)
+ # case when constructor and method parameters are different
+ with warnings.catch_warnings(record=True) as record:
+ warnings.simplefilter("always", FutureWarning)
+ disp.plot(pdp_lim=(0, 1))
+ assert len(record) == 2
+ for warning in record:
+ assert warning.message.args[0].startswith((deprecation_msg, overwriting_msg))
+
+
+@pytest.mark.parametrize("kind", ["individual", "average", "both"])
+@pytest.mark.parametrize("centered", [True, False])
+def test_partial_dependence_plot_limits_one_way(
+ pyplot, clf_diabetes, diabetes, kind, centered
+):
+ """Check that the PD limit on the plots are properly set on one-way plots."""
+ disp = PartialDependenceDisplay.from_estimator(
+ clf_diabetes,
+ diabetes.data,
+ features=(0, 1),
+ kind=kind,
+ grid_resolution=25,
+ feature_names=diabetes.feature_names,
+ )
+
+ range_pd = np.array([-1, 1], dtype=np.float64)
+ for pd in disp.pd_results:
+ if "average" in pd:
+ pd["average"][...] = range_pd[1]
+ pd["average"][0, 0] = range_pd[0]
+ if "individual" in pd:
+ pd["individual"][...] = range_pd[1]
+ pd["individual"][0, 0, 0] = range_pd[0]
+
+ disp.plot(centered=centered)
+ # when centering, the y-limits are anchored at zero (i.e. at the x-axis)
+ y_lim = range_pd - range_pd[0] if centered else range_pd
+ padding = 0.05 * (y_lim[1] - y_lim[0])
+ y_lim[0] -= padding
+ y_lim[1] += padding
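+ # (illustrative arithmetic, not an extra check) with range_pd = [-1, 1]:
+ # centered -> y_lim = [0, 2], padding = 0.1, expected limits (-0.1, 2.1);
+ # uncentered -> y_lim = [-1, 1], padding = 0.1, expected limits (-1.1, 1.1).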
+ for ax in disp.axes_.ravel():
+ assert_allclose(ax.get_ylim(), y_lim)
+
+
+@pytest.mark.parametrize("centered", [True, False])
+def test_partial_dependence_plot_limits_two_way(
+ pyplot, clf_diabetes, diabetes, centered
+):
+ """Check that the PD limit on the plots are properly set on two-way plots."""
+ disp = PartialDependenceDisplay.from_estimator(
+ clf_diabetes,
+ diabetes.data,
+ features=[(0, 1)],
+ kind="average",
+ grid_resolution=25,
+ feature_names=diabetes.feature_names,
+ )
+
+ range_pd = np.array([-1, 1], dtype=np.float64)
+ for pd in disp.pd_results:
+ pd["average"][...] = range_pd[1]
+ pd["average"][0, 0] = range_pd[0]
+
+ disp.plot(centered=centered)
+ contours = disp.contours_[0, 0]
+ levels = range_pd - range_pd[0] if centered else range_pd
+
+ padding = 0.05 * (levels[1] - levels[0])
+ levels[0] -= padding
+ levels[1] += padding
+ expect_levels = np.linspace(*levels, num=8)
+ assert_allclose(contours.levels, expect_levels)
+
+
+def test_partial_dependence_kind_list(
+ pyplot,
+ clf_diabetes,
+ diabetes,
+):
+ """Check that we can provide a list of strings to kind parameter."""
+ matplotlib = pytest.importorskip("matplotlib")
+
+ disp = PartialDependenceDisplay.from_estimator(
+ clf_diabetes,
+ diabetes.data,
+ features=[0, 2, (1, 2)],
+ grid_resolution=20,
+ kind=["both", "both", "average"],
+ )
+
+ for idx in [0, 1]:
+ assert all(
+ [
+ isinstance(line, matplotlib.lines.Line2D)
+ for line in disp.lines_[0, idx].ravel()
+ ]
+ )
+ assert disp.contours_[0, idx] is None
+
+ assert disp.contours_[0, 2] is not None
+ assert all([line is None for line in disp.lines_[0, 2].ravel()])
+
+
+@pytest.mark.parametrize(
+ "features, kind",
+ [
+ ([0, 2, (1, 2)], "individual"),
+ ([0, 2, (1, 2)], "both"),
+ ([(0, 1), (0, 2), (1, 2)], "individual"),
+ ([(0, 1), (0, 2), (1, 2)], "both"),
+ ([0, 2, (1, 2)], ["individual", "individual", "individual"]),
+ ([0, 2, (1, 2)], ["both", "both", "both"]),
+ ],
+)
+def test_partial_dependence_kind_error(
+ pyplot,
+ clf_diabetes,
+ diabetes,
+ features,
+ kind,
+):
+ """Check that we raise an informative error when 2-way PD is requested
+ together with 1-way PD/ICE"""
+ warn_msg = (
+ "ICE plot cannot be rendered for 2-way feature interactions. 2-way "
+ "feature interactions mandates PD plots using the 'average' kind"
+ )
+ with pytest.raises(ValueError, match=warn_msg):
+ PartialDependenceDisplay.from_estimator(
+ clf_diabetes,
+ diabetes.data,
+ features=features,
+ grid_resolution=20,
+ kind=kind,
+ )
+
+
+@pytest.mark.filterwarnings("ignore:A Bunch will be returned")
+@pytest.mark.parametrize(
+ "line_kw, pd_line_kw, ice_lines_kw, expected_colors",
+ [
+ ({"color": "r"}, {"color": "g"}, {"color": "b"}, ("g", "b")),
+ (None, {"color": "g"}, {"color": "b"}, ("g", "b")),
+ ({"color": "r"}, None, {"color": "b"}, ("r", "b")),
+ ({"color": "r"}, {"color": "g"}, None, ("g", "r")),
+ ({"color": "r"}, None, None, ("r", "r")),
+ ({"color": "r"}, {"linestyle": "--"}, {"linestyle": "-."}, ("r", "r")),
+ ],
+)
+def test_plot_partial_dependence_lines_kw(
+ pyplot,
+ clf_diabetes,
+ diabetes,
+ line_kw,
+ pd_line_kw,
+ ice_lines_kw,
+ expected_colors,
+):
+ """Check that passing `pd_line_kw` and `ice_lines_kw` will act on the
+ specific lines in the plot.
+ """
+
+ disp = PartialDependenceDisplay.from_estimator(
+ clf_diabetes,
+ diabetes.data,
+ [0, 2],
+ grid_resolution=20,
+ feature_names=diabetes.feature_names,
+ n_cols=2,
+ kind="both",
+ line_kw=line_kw,
+ pd_line_kw=pd_line_kw,
+ ice_lines_kw=ice_lines_kw,
+ )
+
+ line = disp.lines_[0, 0, -1]
+ assert line.get_color() == expected_colors[0]
+ if pd_line_kw is not None and "linestyle" in pd_line_kw:
+ assert line.get_linestyle() == pd_line_kw["linestyle"]
+ else:
+ assert line.get_linestyle() == "--"
+
+ line = disp.lines_[0, 0, 0]
+ assert line.get_color() == expected_colors[1]
+ if ice_lines_kw is not None and "linestyle" in ice_lines_kw:
+ assert line.get_linestyle() == ice_lines_kw["linestyle"]
+ else:
+ assert line.get_linestyle() == "-"
+
+
+def test_partial_dependence_display_wrong_len_kind(
+ pyplot,
+ clf_diabetes,
+ diabetes,
+):
+ """Check that we raise an error when `kind` is a list with a wrong length.
+
+ This case can only be triggered using the `PartialDependenceDisplay.from_estimator`
+ method.
+ """
+ disp = PartialDependenceDisplay.from_estimator(
+ clf_diabetes,
+ diabetes.data,
+ features=[0, 2],
+ grid_resolution=20,
+ kind="average", # len(kind) != len(features)
+ )
+
+ # alter `kind` to be a list with a length different from length of `features`
+ disp.kind = ["average"]
+ err_msg = (
+ r"When `kind` is provided as a list of strings, it should contain as many"
+ r" elements as `features`. `kind` contains 1 element\(s\) and `features`"
+ r" contains 2 element\(s\)."
+ )
+ with pytest.raises(ValueError, match=err_msg):
+ disp.plot()
+
+
+@pytest.mark.parametrize(
+ "kind",
+ ["individual", "both", "average", ["average", "both"], ["individual", "both"]],
+)
+def test_partial_dependence_display_kind_centered_interaction(
+ pyplot,
+ kind,
+ clf_diabetes,
+ diabetes,
+):
+ """Check that we properly center ICE and PD when passing kind as a string and as a
+ list."""
+ disp = PartialDependenceDisplay.from_estimator(
+ clf_diabetes,
+ diabetes.data,
+ [0, 1],
+ kind=kind,
+ centered=True,
+ subsample=5,
+ )
+
+ assert all([ln._y[0] == 0.0 for ln in disp.lines_.ravel() if ln is not None])
diff --git a/modin/pandas/test/interoperability/sklearn/inspection/tests/test_partial_dependence.py b/modin/pandas/test/interoperability/sklearn/inspection/tests/test_partial_dependence.py
new file mode 100644
index 00000000000..6e5a84f18b0
--- /dev/null
+++ b/modin/pandas/test/interoperability/sklearn/inspection/tests/test_partial_dependence.py
@@ -0,0 +1,852 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+# flake8: noqa
+
+"""
+Testing for the partial dependence module.
+"""
+
+import numpy as np
+import pytest
+import sklearn
+from sklearn.inspection import partial_dependence
+from sklearn.inspection._partial_dependence import (
+ _grid_from_X,
+ _partial_dependence_brute,
+ _partial_dependence_recursion,
+)
+from sklearn.ensemble import GradientBoostingClassifier
+from sklearn.ensemble import GradientBoostingRegressor
+from sklearn.ensemble import RandomForestRegressor
+from sklearn.ensemble import HistGradientBoostingClassifier
+from sklearn.ensemble import HistGradientBoostingRegressor
+from sklearn.linear_model import LinearRegression
+from sklearn.linear_model import LogisticRegression
+from sklearn.linear_model import MultiTaskLasso
+from sklearn.tree import DecisionTreeRegressor
+from sklearn.datasets import load_iris
+from sklearn.datasets import make_classification, make_regression
+from sklearn.cluster import KMeans
+from sklearn.compose import make_column_transformer
+from sklearn.metrics import r2_score
+from sklearn.preprocessing import PolynomialFeatures
+from sklearn.preprocessing import StandardScaler
+from sklearn.preprocessing import RobustScaler
+from sklearn.preprocessing import scale
+from sklearn.pipeline import make_pipeline
+from sklearn.dummy import DummyClassifier
+from sklearn.base import BaseEstimator, ClassifierMixin, clone
+from sklearn.exceptions import NotFittedError
+from sklearn.utils._testing import assert_allclose
+from sklearn.utils._testing import assert_array_equal
+from sklearn.utils import _IS_32BIT
+from sklearn.utils.validation import check_random_state
+from sklearn.tree.tests.test_tree import assert_is_subtree
+
+
+# toy sample
+X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
+y = [-1, -1, -1, 1, 1, 1]
+
+
+# (X, y), n_targets <-- as expected in the output of partial_dep()
+binary_classification_data = (make_classification(n_samples=50, random_state=0), 1)
+multiclass_classification_data = (
+ make_classification(
+ n_samples=50, n_classes=3, n_clusters_per_class=1, random_state=0
+ ),
+ 3,
+)
+regression_data = (make_regression(n_samples=50, random_state=0), 1)
+multioutput_regression_data = (
+ make_regression(n_samples=50, n_targets=2, random_state=0),
+ 2,
+)
+
+# iris
+iris = load_iris()
+
+
+@pytest.mark.parametrize(
+ "Estimator, method, data",
+ [
+ (GradientBoostingClassifier, "auto", binary_classification_data),
+ (GradientBoostingClassifier, "auto", multiclass_classification_data),
+ (GradientBoostingClassifier, "brute", binary_classification_data),
+ (GradientBoostingClassifier, "brute", multiclass_classification_data),
+ (GradientBoostingRegressor, "auto", regression_data),
+ (GradientBoostingRegressor, "brute", regression_data),
+ (DecisionTreeRegressor, "brute", regression_data),
+ (LinearRegression, "brute", regression_data),
+ (LinearRegression, "brute", multioutput_regression_data),
+ (LogisticRegression, "brute", binary_classification_data),
+ (LogisticRegression, "brute", multiclass_classification_data),
+ (MultiTaskLasso, "brute", multioutput_regression_data),
+ ],
+)
+@pytest.mark.parametrize("grid_resolution", (5, 10))
+@pytest.mark.parametrize("features", ([1], [1, 2]))
+@pytest.mark.parametrize("kind", ("average", "individual", "both"))
+def test_output_shape(Estimator, method, data, grid_resolution, features, kind):
+ # Check that partial_dependence has consistent output shape for different
+ # kinds of estimators:
+ # - classifiers with binary and multiclass settings
+ # - regressors
+ # - multi-task regressors
+
+ est = Estimator()
+
+ # n_targets corresponds to the number of classes (1 for binary
+ # classification) or the number of tasks / outputs in multi-task settings.
+ # It is equal to 1 for classical regression data.
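+ # (illustrative, assuming grid_resolution=5 and features=[1, 2]) for the
+ # multiclass data above, kind="average" gives pdp.average.shape == (3, 5, 5),
+ # while kind="individual" adds the sample axis: (3, 50, 5, 5).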
+ (X, y), n_targets = data
+ n_instances = X.shape[0]
+
+ est.fit(X, y)
+ result = partial_dependence(
+ est,
+ X=X,
+ features=features,
+ method=method,
+ kind=kind,
+ grid_resolution=grid_resolution,
+ )
+ pdp, axes = result, result["values"]
+
+ expected_pdp_shape = (n_targets, *[grid_resolution for _ in range(len(features))])
+ expected_ice_shape = (
+ n_targets,
+ n_instances,
+ *[grid_resolution for _ in range(len(features))],
+ )
+ if kind == "average":
+ assert pdp.average.shape == expected_pdp_shape
+ elif kind == "individual":
+ assert pdp.individual.shape == expected_ice_shape
+ else: # 'both'
+ assert pdp.average.shape == expected_pdp_shape
+ assert pdp.individual.shape == expected_ice_shape
+
+ expected_axes_shape = (len(features), grid_resolution)
+ assert axes is not None
+ assert np.asarray(axes).shape == expected_axes_shape
+
+
+def test_grid_from_X():
+ # tests for _grid_from_X: sanity check for output, and for shapes.
+
+ # Make sure that the grid is a Cartesian product of the input (it will use
+ # the unique values instead of the percentiles)
+ percentiles = (0.05, 0.95)
+ grid_resolution = 100
+ is_categorical = [False, False]
+ X = np.asarray([[1, 2], [3, 4]])
+ grid, axes = _grid_from_X(X, percentiles, is_categorical, grid_resolution)
+ assert_array_equal(grid, [[1, 2], [1, 4], [3, 2], [3, 4]])
+ assert_array_equal(axes, X.T)
+
+ # test shapes of returned objects depending on the number of unique values
+ # for a feature.
+ rng = np.random.RandomState(0)
+ grid_resolution = 15
+
+ # n_unique_values > grid_resolution
+ X = rng.normal(size=(20, 2))
+ grid, axes = _grid_from_X(
+ X, percentiles, is_categorical, grid_resolution=grid_resolution
+ )
+ assert grid.shape == (grid_resolution * grid_resolution, X.shape[1])
+ assert np.asarray(axes).shape == (2, grid_resolution)
+
+ # n_unique_values < grid_resolution, will use actual values
+ n_unique_values = 12
+ X[n_unique_values - 1 :, 0] = 12345
+ rng.shuffle(X) # just to make sure the order is irrelevant
+ grid, axes = _grid_from_X(
+ X, percentiles, is_categorical, grid_resolution=grid_resolution
+ )
+ assert grid.shape == (n_unique_values * grid_resolution, X.shape[1])
+ # axes is a list of arrays of different shapes
+ assert axes[0].shape == (n_unique_values,)
+ assert axes[1].shape == (grid_resolution,)
+
+
+@pytest.mark.parametrize(
+ "grid_resolution",
+ [
+ 2, # since n_categories > 2, we should not use quantiles resampling
+ 100,
+ ],
+)
+def test_grid_from_X_with_categorical(grid_resolution):
+ """Check that `_grid_from_X` always sample from categories and does not
+ depend from the percentiles.
+ """
+ pd = pytest.importorskip("modin.pandas")
+ percentiles = (0.05, 0.95)
+ is_categorical = [True]
+ X = pd.DataFrame({"cat_feature": ["A", "B", "C", "A", "B", "D", "E"]})
+ grid, axes = _grid_from_X(
+ X, percentiles, is_categorical, grid_resolution=grid_resolution
+ )
+ assert grid.shape == (5, X.shape[1])
+ assert axes[0].shape == (5,)
+
+
+@pytest.mark.parametrize("grid_resolution", [3, 100])
+def test_grid_from_X_heterogeneous_type(grid_resolution):
+ """Check that `_grid_from_X` always sample from categories and does not
+ depend from the percentiles.
+ """
+ pd = pytest.importorskip("modin.pandas")
+ percentiles = (0.05, 0.95)
+ is_categorical = [True, False]
+ X = pd.DataFrame(
+ {
+ "cat": ["A", "B", "C", "A", "B", "D", "E", "A", "B", "D"],
+ "num": [1, 1, 1, 2, 5, 6, 6, 6, 6, 8],
+ }
+ )
+ nunique = X.nunique()
+
+ grid, axes = _grid_from_X(
+ X, percentiles, is_categorical, grid_resolution=grid_resolution
+ )
+ if grid_resolution == 3:
+ assert grid.shape == (15, 2)
+ assert axes[0].shape[0] == nunique["num"]
+ assert axes[1].shape[0] == grid_resolution
+ else:
+ assert grid.shape == (25, 2)
+ assert axes[0].shape[0] == nunique["cat"]
+ assert axes[1].shape[0] == nunique["cat"]
+
+
+@pytest.mark.parametrize(
+ "grid_resolution, percentiles, err_msg",
+ [
+ (2, (0, 0.0001), "percentiles are too close"),
+ (100, (1, 2, 3, 4), "'percentiles' must be a sequence of 2 elements"),
+ (100, 12345, "'percentiles' must be a sequence of 2 elements"),
+ (100, (-1, 0.95), r"'percentiles' values must be in \[0, 1\]"),
+ (100, (0.05, 2), r"'percentiles' values must be in \[0, 1\]"),
+ (100, (0.9, 0.1), r"percentiles\[0\] must be strictly less than"),
+ (1, (0.05, 0.95), "'grid_resolution' must be strictly greater than 1"),
+ ],
+)
+def test_grid_from_X_error(grid_resolution, percentiles, err_msg):
+ X = np.asarray([[1, 2], [3, 4]])
+ is_categorical = [False]
+ with pytest.raises(ValueError, match=err_msg):
+ _grid_from_X(X, percentiles, is_categorical, grid_resolution)
+
+
+@pytest.mark.parametrize("target_feature", range(5))
+@pytest.mark.parametrize(
+ "est, method",
+ [
+ (LinearRegression(), "brute"),
+ (GradientBoostingRegressor(random_state=0), "brute"),
+ (GradientBoostingRegressor(random_state=0), "recursion"),
+ (HistGradientBoostingRegressor(random_state=0), "brute"),
+ (HistGradientBoostingRegressor(random_state=0), "recursion"),
+ ],
+)
+def test_partial_dependence_helpers(est, method, target_feature):
+ # Check that what is returned by _partial_dependence_brute or
+ # _partial_dependence_recursion is equivalent to manually setting a target
+ # feature to a given value, and computing the average prediction over all
+ # samples.
+ # This also checks that the brute and recursion methods give the same
+ # output.
+ # Note that even on the trainset, the brute and the recursion methods
+ # aren't always strictly equivalent, in particular when the slow method
+ # generates unrealistic samples that have low mass in the joint
+ # distribution of the input features, and when some of the features are
+ # dependent. Hence the high tolerance on the checks.
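+ #
+ # (illustrative, not from upstream) in formula form, for a grid value v and
+ # target feature j: PD(v) ~ mean over samples i of est.predict(X with
+ # column j set to v), which is exactly what the loop over (0.5, 123)
+ # below computes by brute force.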
+
+ X, y = make_regression(random_state=0, n_features=5, n_informative=5)
+ # The 'init' estimator for GBDT (here the average prediction) isn't taken
+ # into account with the recursion method, for technical reasons. We set
+ # the mean to 0 so that this 'bug' doesn't have any effect.
+ y = y - y.mean()
+ est.fit(X, y)
+
+ # target feature will be set to .5 and then to 123
+ features = np.array([target_feature], dtype=np.int32)
+ grid = np.array([[0.5], [123]])
+
+ if method == "brute":
+ pdp, predictions = _partial_dependence_brute(
+ est, grid, features, X, response_method="auto"
+ )
+ else:
+ pdp = _partial_dependence_recursion(est, grid, features)
+
+ mean_predictions = []
+ for val in (0.5, 123):
+ X_ = X.copy()
+ X_[:, target_feature] = val
+ mean_predictions.append(est.predict(X_).mean())
+
+ pdp = pdp[0] # (shape is (1, 2) so make it (2,))
+
+ # allow for greater margin for error with recursion method
+ rtol = 1e-1 if method == "recursion" else 1e-3
+ assert np.allclose(pdp, mean_predictions, rtol=rtol)
+
+
+@pytest.mark.parametrize("seed", range(1))
+def test_recursion_decision_tree_vs_forest_and_gbdt(seed):
+ # Make sure that the recursion method gives the same results on a
+ # DecisionTreeRegressor and a GradientBoostingRegressor or a
+ # RandomForestRegressor with 1 tree and equivalent parameters.
+
+ rng = np.random.RandomState(seed)
+
+ # Purely random dataset to avoid correlated features
+ n_samples = 1000
+ n_features = 5
+ X = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples) * 10
+
+ # The 'init' estimator for GBDT (here the average prediction) isn't taken
+ # into account with the recursion method, for technical reasons. We set
+ # the mean to 0 so that this 'bug' doesn't have any effect.
+ y = y - y.mean()
+
+ # set max_depth not too high to avoid splits with same gain but different
+ # features
+ max_depth = 5
+
+ tree_seed = 0
+ forest = RandomForestRegressor(
+ n_estimators=1,
+ max_features=None,
+ bootstrap=False,
+ max_depth=max_depth,
+ random_state=tree_seed,
+ )
+ # The forest will use ensemble.base._set_random_states to set the
+ # random_state of the tree sub-estimator. We simulate this here to have
+ # equivalent estimators.
+ equiv_random_state = check_random_state(tree_seed).randint(np.iinfo(np.int32).max)
+ gbdt = GradientBoostingRegressor(
+ n_estimators=1,
+ learning_rate=1,
+ criterion="squared_error",
+ max_depth=max_depth,
+ random_state=equiv_random_state,
+ )
+ tree = DecisionTreeRegressor(max_depth=max_depth, random_state=equiv_random_state)
+
+ forest.fit(X, y)
+ gbdt.fit(X, y)
+ tree.fit(X, y)
+
+ # sanity check: if the trees aren't the same, the PD values won't be equal
+ try:
+ assert_is_subtree(tree.tree_, gbdt[0, 0].tree_)
+ assert_is_subtree(tree.tree_, forest[0].tree_)
+ except AssertionError:
+ # For some reason the trees aren't exactly equal on 32bits, so the PDs
+ # cannot be equal either. See
+ # https://github.com/scikit-learn/scikit-learn/issues/8853
+ assert _IS_32BIT, "this should only fail on 32 bit platforms"
+ return
+
+ grid = rng.randn(50).reshape(-1, 1)
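+ # the same random 1d grid is reused for every feature; the recursion
+ # method should give matching results for the tree, the forest and the GBDT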
+ for f in range(n_features):
+ features = np.array([f], dtype=np.int32)
+
+ pdp_forest = _partial_dependence_recursion(forest, grid, features)
+ pdp_gbdt = _partial_dependence_recursion(gbdt, grid, features)
+ pdp_tree = _partial_dependence_recursion(tree, grid, features)
+
+ np.testing.assert_allclose(pdp_gbdt, pdp_tree)
+ np.testing.assert_allclose(pdp_forest, pdp_tree)
+
+
+@pytest.mark.parametrize(
+ "est",
+ (
+ GradientBoostingClassifier(random_state=0),
+ HistGradientBoostingClassifier(random_state=0),
+ ),
+)
+@pytest.mark.parametrize("target_feature", (0, 1, 2, 3, 4, 5))
+def test_recursion_decision_function(est, target_feature):
+ # Make sure the recursion method (which implicitly uses decision_function)
+ # gives the same result as the brute method with
+ # response_method=decision_function
+
+ X, y = make_classification(n_classes=2, n_clusters_per_class=1, random_state=1)
+ assert np.mean(y) == 0.5 # make sure the init estimator predicts 0 anyway
+
+ est.fit(X, y)
+
+ preds_1 = partial_dependence(
+ est,
+ X,
+ [target_feature],
+ response_method="decision_function",
+ method="recursion",
+ kind="average",
+ )
+ preds_2 = partial_dependence(
+ est,
+ X,
+ [target_feature],
+ response_method="decision_function",
+ method="brute",
+ kind="average",
+ )
+
+ assert_allclose(preds_1["average"], preds_2["average"], atol=1e-7)
+
+
+@pytest.mark.parametrize(
+ "est",
+ (
+ LinearRegression(),
+ GradientBoostingRegressor(random_state=0),
+ HistGradientBoostingRegressor(
+ random_state=0, min_samples_leaf=1, max_leaf_nodes=None, max_iter=1
+ ),
+ DecisionTreeRegressor(random_state=0),
+ ),
+)
+@pytest.mark.parametrize("power", (1, 2))
+def test_partial_dependence_easy_target(est, power):
+ # If the target y only depends on one feature in an obvious way (linear or
+ # quadratic) then the partial dependence for that feature should reflect
+ # it.
+ # Here we fit a linear regression model (with polynomial features if
+ # needed) and compute the R^2 score to check that the partial dependence
+ # correctly reflects the target.
+
+ rng = np.random.RandomState(0)
+ n_samples = 200
+ target_variable = 2
+ X = rng.normal(size=(n_samples, 5))
+ y = X[:, target_variable] ** power
+
+ est.fit(X, y)
+
+ pdp = partial_dependence(
+ est, features=[target_variable], X=X, grid_resolution=1000, kind="average"
+ )
+
+ new_X = pdp["values"][0].reshape(-1, 1)
+ new_y = pdp["average"][0]
+ # add polynomial features if needed
+ new_X = PolynomialFeatures(degree=power).fit_transform(new_X)
+
+ lr = LinearRegression().fit(new_X, new_y)
+ r2 = r2_score(new_y, lr.predict(new_X))
+
+ assert r2 > 0.99
+
+
+@pytest.mark.parametrize(
+ "Estimator",
+ (
+ sklearn.tree.DecisionTreeClassifier,
+ sklearn.tree.ExtraTreeClassifier,
+ sklearn.ensemble.ExtraTreesClassifier,
+ sklearn.neighbors.KNeighborsClassifier,
+ sklearn.neighbors.RadiusNeighborsClassifier,
+ sklearn.ensemble.RandomForestClassifier,
+ ),
+)
+def test_multiclass_multioutput(Estimator):
+ # Make sure an error is raised for multiclass-multioutput classifiers
+
+ # make a multiclass-multioutput dataset
+ X, y = make_classification(n_classes=3, n_clusters_per_class=1, random_state=0)
+ y = np.array([y, y]).T
+
+ est = Estimator()
+ est.fit(X, y)
+
+ with pytest.raises(
+ ValueError, match="Multiclass-multioutput estimators are not supported"
+ ):
+ partial_dependence(est, X, [0])
+
+
+class NoPredictProbaNoDecisionFunction(ClassifierMixin, BaseEstimator):
+ def fit(self, X, y):
+ # simulate that we have some classes
+ self.classes_ = [0, 1]
+ return self
+
+
+@pytest.mark.filterwarnings("ignore:A Bunch will be returned")
+@pytest.mark.parametrize(
+ "estimator, params, err_msg",
+ [
+ (
+ KMeans(random_state=0, n_init="auto"),
+ {"features": [0]},
+ "'estimator' must be a fitted regressor or classifier",
+ ),
+ (
+ LinearRegression(),
+ {"features": [0], "response_method": "predict_proba"},
+ "The response_method parameter is ignored for regressors",
+ ),
+ (
+ GradientBoostingClassifier(random_state=0),
+ {
+ "features": [0],
+ "response_method": "predict_proba",
+ "method": "recursion",
+ },
+ "'recursion' method, the response_method must be 'decision_function'",
+ ),
+ (
+ GradientBoostingClassifier(random_state=0),
+ {"features": [0], "response_method": "predict_proba", "method": "auto"},
+ "'recursion' method, the response_method must be 'decision_function'",
+ ),
+ (
+ GradientBoostingClassifier(random_state=0),
+ {"features": [0], "response_method": "blahblah"},
+ "response_method blahblah is invalid. Accepted response_method",
+ ),
+ (
+ NoPredictProbaNoDecisionFunction(),
+ {"features": [0], "response_method": "auto"},
+ "The estimator has no predict_proba and no decision_function method",
+ ),
+ (
+ NoPredictProbaNoDecisionFunction(),
+ {"features": [0], "response_method": "predict_proba"},
+ "The estimator has no predict_proba method.",
+ ),
+ (
+ NoPredictProbaNoDecisionFunction(),
+ {"features": [0], "response_method": "decision_function"},
+ "The estimator has no decision_function method.",
+ ),
+ (
+ LinearRegression(),
+ {"features": [0], "method": "blahblah"},
+ "blahblah is invalid. Accepted method names are brute, recursion, auto",
+ ),
+ (
+ LinearRegression(),
+ {"features": [0], "method": "recursion", "kind": "individual"},
+ "The 'recursion' method only applies when 'kind' is set to 'average'",
+ ),
+ (
+ LinearRegression(),
+ {"features": [0], "method": "recursion", "kind": "both"},
+ "The 'recursion' method only applies when 'kind' is set to 'average'",
+ ),
+ (
+ LinearRegression(),
+ {"features": [0], "method": "recursion"},
+ "Only the following estimators support the 'recursion' method:",
+ ),
+ ],
+)
+def test_partial_dependence_error(estimator, params, err_msg):
+ X, y = make_classification(random_state=0)
+ estimator.fit(X, y)
+
+ with pytest.raises(ValueError, match=err_msg):
+ partial_dependence(estimator, X, **params)
+
+
+@pytest.mark.parametrize(
+ "with_dataframe, err_msg",
+ [
+ (True, "Only array-like or scalar are supported"),
+ (False, "Only array-like or scalar are supported"),
+ ],
+)
+def test_partial_dependence_slice_error(with_dataframe, err_msg):
+ X, y = make_classification(random_state=0)
+ if with_dataframe:
+ pd = pytest.importorskip("modin.pandas")
+ X = pd.DataFrame(X)
+ estimator = LogisticRegression().fit(X, y)
+
+ with pytest.raises(TypeError, match=err_msg):
+ partial_dependence(estimator, X, features=slice(0, 2, 1))
+
+
+@pytest.mark.parametrize(
+ "estimator", [LinearRegression(), GradientBoostingClassifier(random_state=0)]
+)
+@pytest.mark.parametrize("features", [-1, 10000])
+def test_partial_dependence_unknown_feature_indices(estimator, features):
+ X, y = make_classification(random_state=0)
+ estimator.fit(X, y)
+
+ err_msg = "all features must be in"
+ with pytest.raises(ValueError, match=err_msg):
+ partial_dependence(estimator, X, [features])
+
+
+@pytest.mark.parametrize(
+ "estimator", [LinearRegression(), GradientBoostingClassifier(random_state=0)]
+)
+def test_partial_dependence_unknown_feature_string(estimator):
+ pd = pytest.importorskip("modin.pandas")
+ X, y = make_classification(random_state=0)
+ df = pd.DataFrame(X)
+ estimator.fit(df, y)
+
+ features = ["random"]
+ err_msg = "A given column is not a column of the dataframe"
+ with pytest.raises(ValueError, match=err_msg):
+ partial_dependence(estimator, df, features)
+
+
+@pytest.mark.parametrize(
+ "estimator", [LinearRegression(), GradientBoostingClassifier(random_state=0)]
+)
+def test_partial_dependence_X_list(estimator):
+ # check that array-like objects are accepted
+ X, y = make_classification(random_state=0)
+ estimator.fit(X, y)
+ partial_dependence(estimator, list(X), [0], kind="average")
+
+
+def test_warning_recursion_non_constant_init():
+ # make sure that passing a non-constant init parameter to a GBDT and using
+ # the recursion method yields a warning.
+
+ gbc = GradientBoostingClassifier(init=DummyClassifier(), random_state=0)
+ gbc.fit(X, y)
+
+ with pytest.warns(
+ UserWarning, match="Using recursion method with a non-constant init predictor"
+ ):
+ partial_dependence(gbc, X, [0], method="recursion", kind="average")
+
+ with pytest.warns(
+ UserWarning, match="Using recursion method with a non-constant init predictor"
+ ):
+ partial_dependence(gbc, X, [0], method="recursion", kind="average")
+
+
+def test_partial_dependence_sample_weight():
+ # Test near-perfect correlation between the partial dependence and the
+ # diagonal when sample weights emphasize y = x predictions
+ # non-regression test for #13193
+ # TODO: extend to HistGradientBoosting once sample_weight is supported
+ N = 1000
+ rng = np.random.RandomState(123456)
+ mask = rng.randint(2, size=N, dtype=bool)
+
+ x = rng.rand(N)
+ # set y = x on mask and y = -x outside
+ y = x.copy()
+ y[~mask] = -y[~mask]
+ X = np.c_[mask, x]
+ # sample weights to emphasize data points where y = x
+ sample_weight = np.ones(N)
+ sample_weight[mask] = 1000.0
+
+ clf = GradientBoostingRegressor(n_estimators=10, random_state=1)
+ clf.fit(X, y, sample_weight=sample_weight)
+
+ pdp = partial_dependence(clf, X, features=[1], kind="average")
+
+ assert np.corrcoef(pdp["average"], pdp["values"])[0, 1] > 0.99
+
+
+def test_hist_gbdt_sw_not_supported():
+ # TODO: remove/fix when PDP supports HGBT with sample weights
+ clf = HistGradientBoostingRegressor(random_state=1)
+ clf.fit(X, y, sample_weight=np.ones(len(X)))
+
+ with pytest.raises(
+ NotImplementedError, match="does not support partial dependence"
+ ):
+ partial_dependence(clf, X, features=[1])
+
+
+def test_partial_dependence_pipeline():
+ # check that partial_dependence supports pipelines
+ iris = load_iris()
+
+ scaler = StandardScaler()
+ clf = DummyClassifier(random_state=42)
+ pipe = make_pipeline(scaler, clf)
+
+ clf.fit(scaler.fit_transform(iris.data), iris.target)
+ pipe.fit(iris.data, iris.target)
+
+ features = 0
+ pdp_pipe = partial_dependence(
+ pipe, iris.data, features=[features], grid_resolution=10, kind="average"
+ )
+ pdp_clf = partial_dependence(
+ clf,
+ scaler.transform(iris.data),
+ features=[features],
+ grid_resolution=10,
+ kind="average",
+ )
+ assert_allclose(pdp_pipe["average"], pdp_clf["average"])
+ assert_allclose(
+ pdp_pipe["values"][0],
+ pdp_clf["values"][0] * scaler.scale_[features] + scaler.mean_[features],
+ )
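+ # note: the pipeline's grid is expressed in the original feature units,
+ # while the bare classifier was fitted on standardized data, hence the
+ # inverse scaling (multiply by scale_, add mean_) before comparing grids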
+
+
+@pytest.mark.parametrize(
+ "estimator",
+ [
+ LogisticRegression(max_iter=1000, random_state=0),
+ GradientBoostingClassifier(random_state=0, n_estimators=5),
+ ],
+ ids=["estimator-brute", "estimator-recursion"],
+)
+@pytest.mark.parametrize(
+ "preprocessor",
+ [
+ None,
+ make_column_transformer(
+ (StandardScaler(), [iris.feature_names[i] for i in (0, 2)]),
+ (RobustScaler(), [iris.feature_names[i] for i in (1, 3)]),
+ ),
+ make_column_transformer(
+ (StandardScaler(), [iris.feature_names[i] for i in (0, 2)]),
+ remainder="passthrough",
+ ),
+ ],
+ ids=["None", "column-transformer", "column-transformer-passthrough"],
+)
+@pytest.mark.parametrize(
+ "features",
+ [[0, 2], [iris.feature_names[i] for i in (0, 2)]],
+ ids=["features-integer", "features-string"],
+)
+def test_partial_dependence_dataframe(estimator, preprocessor, features):
+ # check that partial_dependence supports dataframes and pipelines
+ # including a column transformer
+ pd = pytest.importorskip("modin.pandas")
+ df = pd.DataFrame(scale(iris.data), columns=iris.feature_names)
+
+ pipe = make_pipeline(preprocessor, estimator)
+ pipe.fit(df, iris.target)
+ pdp_pipe = partial_dependence(
+ pipe, df, features=features, grid_resolution=10, kind="average"
+ )
+
+ # the column transformer will reorder the columns when transforming, so we
+ # remap the feature indices to be sure that we are computing the partial
+ # dependence of the right columns
+ if preprocessor is not None:
+ X_proc = clone(preprocessor).fit_transform(df)
+ features_clf = [0, 1]
+ else:
+ X_proc = df
+ features_clf = [0, 2]
+
+ clf = clone(estimator).fit(X_proc, iris.target)
+ pdp_clf = partial_dependence(
+ clf,
+ X_proc,
+ features=features_clf,
+ method="brute",
+ grid_resolution=10,
+ kind="average",
+ )
+
+ assert_allclose(pdp_pipe["average"], pdp_clf["average"])
+ if preprocessor is not None:
+ scaler = preprocessor.named_transformers_["standardscaler"]
+ assert_allclose(
+ pdp_pipe["values"][1],
+ pdp_clf["values"][1] * scaler.scale_[1] + scaler.mean_[1],
+ )
+ else:
+ assert_allclose(pdp_pipe["values"][1], pdp_clf["values"][1])
+
+
+@pytest.mark.parametrize(
+ "features, expected_pd_shape",
+ [
+ (0, (3, 10)),
+ (iris.feature_names[0], (3, 10)),
+ ([0, 2], (3, 10, 10)),
+ ([iris.feature_names[i] for i in (0, 2)], (3, 10, 10)),
+ ([True, False, True, False], (3, 10, 10)),
+ ],
+ ids=["scalar-int", "scalar-str", "list-int", "list-str", "mask"],
+)
+def test_partial_dependence_feature_type(features, expected_pd_shape):
+ # check all possible feature types supported in PDP
+ pd = pytest.importorskip("modin.pandas")
+ df = pd.DataFrame(iris.data, columns=iris.feature_names)
+
+ preprocessor = make_column_transformer(
+ (StandardScaler(), [iris.feature_names[i] for i in (0, 2)]),
+ (RobustScaler(), [iris.feature_names[i] for i in (1, 3)]),
+ )
+ pipe = make_pipeline(
+ preprocessor, LogisticRegression(max_iter=1000, random_state=0)
+ )
+ pipe.fit(df, iris.target)
+ pdp_pipe = partial_dependence(
+ pipe, df, features=features, grid_resolution=10, kind="average"
+ )
+ assert pdp_pipe["average"].shape == expected_pd_shape
+ assert len(pdp_pipe["values"]) == len(pdp_pipe["average"].shape) - 1
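+ # "values" holds one grid per requested feature, while the leading axis of
+ # "average" indexes the three iris classes, hence the "- 1"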
+
+
+@pytest.mark.parametrize(
+ "estimator",
+ [
+ LinearRegression(),
+ LogisticRegression(),
+ GradientBoostingRegressor(),
+ GradientBoostingClassifier(),
+ ],
+)
+def test_partial_dependence_unfitted(estimator):
+ X = iris.data
+ preprocessor = make_column_transformer(
+ (StandardScaler(), [0, 2]), (RobustScaler(), [1, 3])
+ )
+ pipe = make_pipeline(preprocessor, estimator)
+ with pytest.raises(NotFittedError, match="is not fitted yet"):
+ partial_dependence(pipe, X, features=[0, 2], grid_resolution=10)
+ with pytest.raises(NotFittedError, match="is not fitted yet"):
+ partial_dependence(estimator, X, features=[0, 2], grid_resolution=10)
+
+
+@pytest.mark.parametrize(
+ "Estimator, data",
+ [
+ (LinearRegression, multioutput_regression_data),
+ (LogisticRegression, binary_classification_data),
+ ],
+)
+def test_kind_average_and_average_of_individual(Estimator, data):
+ est = Estimator()
+ (X, y), n_targets = data
+ est.fit(X, y)
+
+ pdp_avg = partial_dependence(est, X=X, features=[1, 2], kind="average")
+ pdp_ind = partial_dependence(est, X=X, features=[1, 2], kind="individual")
+ avg_ind = np.mean(pdp_ind["individual"], axis=1)
+ assert_allclose(avg_ind, pdp_avg["average"])
diff --git a/modin/pandas/test/interoperability/sklearn/inspection/tests/test_permutation_importance.py b/modin/pandas/test/interoperability/sklearn/inspection/tests/test_permutation_importance.py
new file mode 100644
index 00000000000..fb327ad20ce
--- /dev/null
+++ b/modin/pandas/test/interoperability/sklearn/inspection/tests/test_permutation_importance.py
@@ -0,0 +1,558 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+# flake8: noqa
+
+import pytest
+import numpy as np
+from numpy.testing import assert_allclose
+from sklearn.compose import ColumnTransformer
+from sklearn.datasets import load_diabetes
+from sklearn.datasets import load_iris
+from sklearn.datasets import make_classification
+from sklearn.datasets import make_regression
+from sklearn.dummy import DummyClassifier
+from sklearn.ensemble import RandomForestRegressor
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.linear_model import LinearRegression
+from sklearn.linear_model import LogisticRegression
+from sklearn.impute import SimpleImputer
+from sklearn.inspection import permutation_importance
+from sklearn.model_selection import train_test_split
+from sklearn.metrics import (
+ get_scorer,
+ mean_squared_error,
+ r2_score,
+)
+from sklearn.pipeline import make_pipeline
+from sklearn.preprocessing import KBinsDiscretizer
+from sklearn.preprocessing import OneHotEncoder
+from sklearn.preprocessing import StandardScaler
+from sklearn.preprocessing import scale
+from sklearn.utils import parallel_backend
+from sklearn.utils._testing import _convert_container
+
+
+@pytest.mark.skip(reason="Failing test")
+@pytest.mark.parametrize("n_jobs", [1, 2])
+@pytest.mark.parametrize("max_samples", [0.5, 1.0])
+def test_permutation_importance_correlated_feature_regression(n_jobs, max_samples):
+ # Make sure that features highly correlated with the target have a higher
+ # importance
+ rng = np.random.RandomState(42)
+ n_repeats = 5
+
+ X, y = load_diabetes(return_X_y=True)
+ y_with_little_noise = (y + rng.normal(scale=0.001, size=y.shape[0])).reshape(-1, 1)
+
+ X = np.hstack([X, y_with_little_noise])
+
+ clf = RandomForestRegressor(n_estimators=10, random_state=42)
+ clf.fit(X, y)
+
+ result = permutation_importance(
+ clf,
+ X,
+ y,
+ n_repeats=n_repeats,
+ random_state=rng,
+ n_jobs=n_jobs,
+ max_samples=max_samples,
+ )
+
+ assert result.importances.shape == (X.shape[1], n_repeats)
+
+ # the feature correlated with y was added as the last column and should
+ # have the highest importance
+ assert np.all(result.importances_mean[-1] > result.importances_mean[:-1])
+
+
+@pytest.mark.skip(reason="Failing test")
+@pytest.mark.parametrize("n_jobs", [1, 2])
+@pytest.mark.parametrize("max_samples", [0.5, 1.0])
+def test_permutation_importance_correlated_feature_regression_pandas(
+ n_jobs, max_samples
+):
+ pd = pytest.importorskip("modin.pandas")
+
+ # Make sure that features highly correlated with the target have a higher
+ # importance
+ rng = np.random.RandomState(42)
+ n_repeats = 5
+
+ dataset = load_iris()
+ X, y = dataset.data, dataset.target
+ y_with_little_noise = (y + rng.normal(scale=0.001, size=y.shape[0])).reshape(-1, 1)
+
+ # Adds feature correlated with y as the last column
+ X = pd.DataFrame(X, columns=dataset.feature_names)
+ X["correlated_feature"] = y_with_little_noise
+
+ clf = RandomForestClassifier(n_estimators=10, random_state=42)
+ clf.fit(X, y)
+
+ result = permutation_importance(
+ clf,
+ X,
+ y,
+ n_repeats=n_repeats,
+ random_state=rng,
+ n_jobs=n_jobs,
+ max_samples=max_samples,
+ )
+
+ assert result.importances.shape == (X.shape[1], n_repeats)
+
+ # the feature correlated with y was added as the last column and should
+ # have the highest importance
+ assert np.all(result.importances_mean[-1] > result.importances_mean[:-1])
+
+
+@pytest.mark.parametrize("n_jobs", [1, 2])
+@pytest.mark.parametrize("max_samples", [0.5, 1.0])
+def test_robustness_to_high_cardinality_noisy_feature(n_jobs, max_samples, seed=42):
+ # Permutation variable importance should not be affected by the high
+ # cardinality bias of traditional feature importances, especially when
+ # computed on a held-out test set:
+ rng = np.random.RandomState(seed)
+ n_repeats = 5
+ n_samples = 1000
+ n_classes = 5
+ n_informative_features = 2
+ n_noise_features = 1
+ n_features = n_informative_features + n_noise_features
+
+ # Generate a multiclass classification dataset and a set of informative
+ # binary features that can be used to predict some classes of y exactly
+ # while leaving some classes unexplained to make the problem harder.
+ classes = np.arange(n_classes)
+ y = rng.choice(classes, size=n_samples)
+ X = np.hstack([(y == c).reshape(-1, 1) for c in classes[:n_informative_features]])
+ X = X.astype(np.float32)
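+ # each informative column is the binary indicator of one of the first
+ # n_informative_features classes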
+
+ # Not all target classes are explained by the binary class indicator
+ # features:
+ assert n_informative_features < n_classes
+
+ # Add n_noise_features other noisy features with high cardinality
+ # (numerical) values that can be used to overfit the training data.
+ X = np.concatenate([X, rng.randn(n_samples, n_noise_features)], axis=1)
+ assert X.shape == (n_samples, n_features)
+
+ # Split the dataset to be able to evaluate on a held-out test set. The
+ # test size should be large enough for importance measurements to be
+ # stable:
+ X_train, X_test, y_train, y_test = train_test_split(
+ X, y, test_size=0.5, random_state=rng
+ )
+ clf = RandomForestClassifier(n_estimators=5, random_state=rng)
+ clf.fit(X_train, y_train)
+
+ # Variable importances computed by impurity decrease on the tree node
+ # splits often favor the noisy features. This can give the misleading
+ # impression that high-cardinality noisy variables are the most important:
+ tree_importances = clf.feature_importances_
+ informative_tree_importances = tree_importances[:n_informative_features]
+ noisy_tree_importances = tree_importances[n_informative_features:]
+ assert informative_tree_importances.max() < noisy_tree_importances.min()
+
+ # Let's check that permutation-based feature importances do not have this
+ # problem.
+ r = permutation_importance(
+ clf,
+ X_test,
+ y_test,
+ n_repeats=n_repeats,
+ random_state=rng,
+ n_jobs=n_jobs,
+ max_samples=max_samples,
+ )
+
+ assert r.importances.shape == (X.shape[1], n_repeats)
+
+ # Split the importances between informative and noisy features
+ informative_importances = r.importances_mean[:n_informative_features]
+ noisy_importances = r.importances_mean[n_informative_features:]
+
+ # Because we do not have a binary variable explaining each target class,
+ # the RF model will have to use the random variable to make some
+ # (overfitting) splits (as max_depth is not set). Therefore the noisy
+ # variables will have small but non-zero importances oscillating around
+ # zero:
+ assert max(np.abs(noisy_importances)) > 1e-7
+ assert noisy_importances.max() < 0.05
+
+ # The binary features correlated with y should have a higher importance
+ # than the high cardinality noisy features.
+ # The maximum test accuracy is 2 / 5 == 0.4, with each informative feature
+ # contributing a bit more than 0.2 of accuracy.
+ assert informative_importances.min() > 0.15
+
+
+def test_permutation_importance_mixed_types():
+ rng = np.random.RandomState(42)
+ n_repeats = 4
+
+ # Last column is correlated with y
+ X = np.array([[1.0, 2.0, 3.0, np.nan], [2, 1, 2, 1]]).T
+ y = np.array([0, 1, 0, 1])
+
+ clf = make_pipeline(SimpleImputer(), LogisticRegression(solver="lbfgs"))
+ clf.fit(X, y)
+ result = permutation_importance(clf, X, y, n_repeats=n_repeats, random_state=rng)
+
+ assert result.importances.shape == (X.shape[1], n_repeats)
+
+ # the feature correlated with y is the last column and should
+ # have the highest importance
+ assert np.all(result.importances_mean[-1] > result.importances_mean[:-1])
+
+ # use another random state
+ rng = np.random.RandomState(0)
+ result2 = permutation_importance(clf, X, y, n_repeats=n_repeats, random_state=rng)
+ assert result2.importances.shape == (X.shape[1], n_repeats)
+
+ assert not np.allclose(result.importances, result2.importances)
+
+ # the feature correlated with y is the last column and should
+ # have the highest importance
+ assert np.all(result2.importances_mean[-1] > result2.importances_mean[:-1])
+
+
+def test_permutation_importance_mixed_types_pandas():
+ pd = pytest.importorskip("modin.pandas")
+ rng = np.random.RandomState(42)
+ n_repeats = 5
+
+ # Last column is correlated with y
+ X = pd.DataFrame({"col1": [1.0, 2.0, 3.0, np.nan], "col2": ["a", "b", "a", "b"]})
+ y = np.array([0, 1, 0, 1])
+
+ num_preprocess = make_pipeline(SimpleImputer(), StandardScaler())
+ preprocess = ColumnTransformer(
+ [("num", num_preprocess, ["col1"]), ("cat", OneHotEncoder(), ["col2"])]
+ )
+ clf = make_pipeline(preprocess, LogisticRegression(solver="lbfgs"))
+ clf.fit(X, y)
+
+ result = permutation_importance(clf, X, y, n_repeats=n_repeats, random_state=rng)
+
+ assert result.importances.shape == (X.shape[1], n_repeats)
+ # the feature correlated with y is the last column and should
+ # have the highest importance
+ assert np.all(result.importances_mean[-1] > result.importances_mean[:-1])
+
+
+def test_permutation_importance_linear_regresssion():
+ X, y = make_regression(n_samples=500, n_features=10, random_state=0)
+
+ X = scale(X)
+ y = scale(y)
+
+ lr = LinearRegression().fit(X, y)
+
+ # this relationship can be computed in closed form
+ expected_importances = 2 * lr.coef_**2
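+ # (permuting feature j replaces x_j by an independent copy, which raises the
+ # expected squared error by roughly 2 * coef_j**2 * Var(x_j); with X and y
+ # scaled to unit variance this reduces to 2 * coef_j**2)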
+ results = permutation_importance(
+ lr, X, y, n_repeats=50, scoring="neg_mean_squared_error"
+ )
+ assert_allclose(
+ expected_importances, results.importances_mean, rtol=1e-1, atol=1e-6
+ )
+
+
+@pytest.mark.parametrize("max_samples", [500, 1.0])
+def test_permutation_importance_equivalence_sequential_parallel(max_samples):
+ # regression test to make sure that sequential and parallel calls will
+ # output the same results.
+ # Also tests that max_samples equal to the number of samples is equivalent to 1.0
+ X, y = make_regression(n_samples=500, n_features=10, random_state=0)
+ lr = LinearRegression().fit(X, y)
+
+ importance_sequential = permutation_importance(
+ lr, X, y, n_repeats=5, random_state=0, n_jobs=1, max_samples=max_samples
+ )
+
+ # First check that the problem is structured enough and that the model is
+ # complex enough to not yield trivial, constant importances:
+ imp_min = importance_sequential["importances"].min()
+ imp_max = importance_sequential["importances"].max()
+ assert imp_max - imp_min > 0.3
+
+ # Then actually check that parallelism does not impact the results,
+ # either with shared memory (threading) or with isolated memory
+ # via process-based parallelism using the default backend
+ # ('loky' or 'multiprocessing') depending on the joblib version:
+
+ # process-based parallelism (by default):
+ importance_processes = permutation_importance(
+ lr, X, y, n_repeats=5, random_state=0, n_jobs=2
+ )
+ assert_allclose(
+ importance_processes["importances"], importance_sequential["importances"]
+ )
+
+ # thread-based parallelism:
+ with parallel_backend("threading"):
+ importance_threading = permutation_importance(
+ lr, X, y, n_repeats=5, random_state=0, n_jobs=2
+ )
+ assert_allclose(
+ importance_threading["importances"], importance_sequential["importances"]
+ )
+
+
+@pytest.mark.skip(reason="Failing test")
+@pytest.mark.parametrize("n_jobs", [None, 1, 2])
+@pytest.mark.parametrize("max_samples", [0.5, 1.0])
+def test_permutation_importance_equivalence_array_dataframe(n_jobs, max_samples):
+ # This test checks that the column shuffling logic has the same behavior
+ # on both a dataframe and a simple numpy array.
+ pd = pytest.importorskip("modin.pandas")
+
+ # regression test to make sure that sequential and parallel calls will
+ # output the same results.
+ X, y = make_regression(n_samples=100, n_features=5, random_state=0)
+ X_df = pd.DataFrame(X)
+
+ # Add a categorical feature that is statistically linked to y:
+ binner = KBinsDiscretizer(n_bins=3, encode="ordinal")
+ cat_column = binner.fit_transform(y.reshape(-1, 1))
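+ # binning y itself guarantees that this extra (categorical) column is
+ # strongly informative about the target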
+
+ # Concatenate the extra column to the numpy array: integers will be
+ # cast to float values
+ X = np.hstack([X, cat_column])
+ assert X.dtype.kind == "f"
+
+ # Insert extra column as a non-numpy-native dtype (while keeping backward
+ # compat for old pandas versions):
+ if hasattr(pd, "Categorical"):
+ cat_column = pd.Categorical(cat_column.ravel())
+ else:
+ cat_column = cat_column.ravel()
+ new_col_idx = len(X_df.columns)
+ X_df[new_col_idx] = cat_column
+ assert X_df[new_col_idx].dtype == cat_column.dtype
+
+ # Stitch an arbitrary index to the dataframe:
+ X_df.index = np.arange(len(X_df)).astype(str)
+
+ rf = RandomForestRegressor(n_estimators=5, max_depth=3, random_state=0)
+ rf.fit(X, y)
+
+ n_repeats = 3
+ importance_array = permutation_importance(
+ rf,
+ X,
+ y,
+ n_repeats=n_repeats,
+ random_state=0,
+ n_jobs=n_jobs,
+ max_samples=max_samples,
+ )
+
+ # First check that the problem is structured enough and that the model is
+ # complex enough to not yield trivial, constant importances:
+ imp_min = importance_array["importances"].min()
+ imp_max = importance_array["importances"].max()
+ assert imp_max - imp_min > 0.3
+
+ # Now check that importances computed on the dataframe match the values
+ # of those computed on the array with the same data.
+ importance_dataframe = permutation_importance(
+ rf,
+ X_df,
+ y,
+ n_repeats=n_repeats,
+ random_state=0,
+ n_jobs=n_jobs,
+ max_samples=max_samples,
+ )
+ assert_allclose(
+ importance_array["importances"], importance_dataframe["importances"]
+ )
+
+
+@pytest.mark.parametrize("input_type", ["array", "dataframe"])
+def test_permutation_importance_large_memmaped_data(input_type):
+ # Smoke, non-regression test for:
+ # https://github.com/scikit-learn/scikit-learn/issues/15810
+ n_samples, n_features = int(5e4), 4
+ X, y = make_classification(
+ n_samples=n_samples, n_features=n_features, random_state=0
+ )
+ assert X.nbytes > 1e6 # trigger joblib memmapping
+
+ X = _convert_container(X, input_type)
+ clf = DummyClassifier(strategy="prior").fit(X, y)
+
+ # Actual smoke test: should not raise any error:
+ n_repeats = 5
+ r = permutation_importance(clf, X, y, n_repeats=n_repeats, n_jobs=2)
+
+ # Auxiliary check: DummyClassifier is feature independent:
+ # permuting features should not change the predictions
+ expected_importances = np.zeros((n_features, n_repeats))
+ assert_allclose(expected_importances, r.importances)
+
+
+def test_permutation_importance_sample_weight():
+ # Creating data with 2 features and 1000 samples, where the target
+ # variable is a linear combination of the two features, such that
+ # in half of the samples the impact of feature 1 is twice the impact of
+ # feature 2, and vice versa on the other half of the samples.
+ rng = np.random.RandomState(1)
+ n_samples = 1000
+ n_features = 2
+ n_half_samples = n_samples // 2
+ x = rng.normal(0.0, 0.001, (n_samples, n_features))
+ y = np.zeros(n_samples)
+ y[:n_half_samples] = 2 * x[:n_half_samples, 0] + x[:n_half_samples, 1]
+ y[n_half_samples:] = x[n_half_samples:, 0] + 2 * x[n_half_samples:, 1]
+
+ # Fitting linear regression with perfect prediction
+ lr = LinearRegression(fit_intercept=False)
+ lr.fit(x, y)
+
+ # When all samples are weighted with the same weights, the ratio of
+ # the two features' importances should equal 1 in expectation (when using
+ # mean absolute error as the loss function).
+ pi = permutation_importance(
+ lr, x, y, random_state=1, scoring="neg_mean_absolute_error", n_repeats=200
+ )
+ x1_x2_imp_ratio_w_none = pi.importances_mean[0] / pi.importances_mean[1]
+ assert x1_x2_imp_ratio_w_none == pytest.approx(1, 0.01)
+
+ # When passing a vector of ones as the sample_weight, results should be
+ # the same as in the case that sample_weight=None.
+ w = np.ones(n_samples)
+ pi = permutation_importance(
+ lr,
+ x,
+ y,
+ random_state=1,
+ scoring="neg_mean_absolute_error",
+ n_repeats=200,
+ sample_weight=w,
+ )
+ x1_x2_imp_ratio_w_ones = pi.importances_mean[0] / pi.importances_mean[1]
+ assert x1_x2_imp_ratio_w_ones == pytest.approx(x1_x2_imp_ratio_w_none, 0.01)
+
+ # When the ratio between the weights of the first half of the samples and
+ # the second half of the samples approaches infinity, the ratio of
+ # the two features' importances should equal 2 in expectation (when using
+ # mean absolute error as the loss function).
+ w = np.hstack(
+ [np.repeat(10.0**10, n_half_samples), np.repeat(1.0, n_half_samples)]
+ )
+ lr.fit(x, y, w)
+ pi = permutation_importance(
+ lr,
+ x,
+ y,
+ random_state=1,
+ scoring="neg_mean_absolute_error",
+ n_repeats=200,
+ sample_weight=w,
+ )
+ x1_x2_imp_ratio_w = pi.importances_mean[0] / pi.importances_mean[1]
+ assert x1_x2_imp_ratio_w / x1_x2_imp_ratio_w_none == pytest.approx(2, 0.01)
+
+
+def test_permutation_importance_no_weights_scoring_function():
+ # Creating a scorer function that does not take sample_weight
+ def my_scorer(estimator, X, y):
+ return 1
+
+ # Creating some data and estimator for the permutation test
+ x = np.array([[1, 2], [3, 4]])
+ y = np.array([1, 2])
+ w = np.array([1, 1])
+ lr = LinearRegression()
+ lr.fit(x, y)
+
+ # test that permutation_importance does not raise an error when
+ # sample_weight is None
+ try:
+ permutation_importance(lr, x, y, random_state=1, scoring=my_scorer, n_repeats=1)
+ except TypeError:
+ pytest.fail(
+ "permutation_test raised an error when using a scorer "
+ "function that does not accept sample_weight even though "
+ "sample_weight was None"
+ )
+
+ # test that permutation_importance raises an exception when sample_weight
+ # is not None
+ with pytest.raises(TypeError):
+ permutation_importance(
+ lr, x, y, random_state=1, scoring=my_scorer, n_repeats=1, sample_weight=w
+ )
+
+
+@pytest.mark.parametrize(
+ "list_single_scorer, multi_scorer",
+ [
+ (["r2", "neg_mean_squared_error"], ["r2", "neg_mean_squared_error"]),
+ (
+ ["r2", "neg_mean_squared_error"],
+ {
+ "r2": get_scorer("r2"),
+ "neg_mean_squared_error": get_scorer("neg_mean_squared_error"),
+ },
+ ),
+ (
+ ["r2", "neg_mean_squared_error"],
+ lambda estimator, X, y: {
+ "r2": r2_score(y, estimator.predict(X)),
+ "neg_mean_squared_error": -mean_squared_error(y, estimator.predict(X)),
+ },
+ ),
+ ],
+)
+def test_permutation_importance_multi_metric(list_single_scorer, multi_scorer):
+ # Test permutation importance when scoring contains multiple scorers
+
+ # Creating some data and estimator for the permutation test
+ x, y = make_regression(n_samples=500, n_features=10, random_state=0)
+ lr = LinearRegression().fit(x, y)
+
+ multi_importance = permutation_importance(
+ lr, x, y, random_state=1, scoring=multi_scorer, n_repeats=2
+ )
+ assert set(multi_importance.keys()) == set(list_single_scorer)
+
+ for scorer in list_single_scorer:
+ multi_result = multi_importance[scorer]
+ single_result = permutation_importance(
+ lr, x, y, random_state=1, scoring=scorer, n_repeats=2
+ )
+
+ assert_allclose(multi_result.importances, single_result.importances)
+
+
+@pytest.mark.parametrize("max_samples", [-1, 5])
+def test_permutation_importance_max_samples_error(max_samples):
+ """Check that a proper error message is raised when `max_samples` is not
+ set to a valid input value.
+ """
+ X = np.array([(1.0, 2.0, 3.0, 4.0)]).T
+ y = np.array([0, 1, 0, 1])
+
+ clf = LogisticRegression()
+ clf.fit(X, y)
+
+ err_msg = r"max_samples must be in \(0, n_samples\]"
+
+ with pytest.raises(ValueError, match=err_msg):
+ permutation_importance(clf, X, y, max_samples=max_samples)
diff --git a/modin/pandas/test/interoperability/sklearn/linear_model/test_base_lm.py b/modin/pandas/test/interoperability/sklearn/linear_model/test_base_lm.py
new file mode 100644
index 00000000000..007127a8bc8
--- /dev/null
+++ b/modin/pandas/test/interoperability/sklearn/linear_model/test_base_lm.py
@@ -0,0 +1,741 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+# Author: Alexandre Gramfort
+# Fabian Pedregosa
+# Maria Telenczuk
+#
+# License: BSD 3 clause
+
+import pytest
+import warnings
+
+import numpy as np
+from scipy import sparse
+from scipy import linalg
+
+from sklearn.utils._testing import assert_array_almost_equal
+from sklearn.utils._testing import assert_array_equal
+from sklearn.utils._testing import assert_allclose
+
+from sklearn.linear_model import LinearRegression
+from sklearn.linear_model._base import _deprecate_normalize
+from sklearn.linear_model._base import _preprocess_data
+from sklearn.linear_model._base import _rescale_data
+from sklearn.linear_model._base import make_dataset
+from sklearn.datasets import make_sparse_uncorrelated
+from sklearn.datasets import make_regression
+from sklearn.datasets import load_iris
+from sklearn.preprocessing import StandardScaler
+from sklearn.preprocessing import add_dummy_feature
+
+rtol = 1e-6
+
+
+def test_linear_regression():
+ # Test LinearRegression on a simple dataset.
+ # a simple dataset
+ X = [[1], [2]]
+ Y = [1, 2]
+
+ reg = LinearRegression()
+ reg.fit(X, Y)
+
+ assert_array_almost_equal(reg.coef_, [1])
+ assert_array_almost_equal(reg.intercept_, [0])
+ assert_array_almost_equal(reg.predict(X), [1, 2])
+
+ # test it also for degenerate input
+ X = [[1]]
+ Y = [0]
+
+ reg = LinearRegression()
+ reg.fit(X, Y)
+ assert_array_almost_equal(reg.coef_, [0])
+ assert_array_almost_equal(reg.intercept_, [0])
+ assert_array_almost_equal(reg.predict(X), [0])
+
+
+@pytest.mark.parametrize("array_constr", [np.array, sparse.csr_matrix])
+@pytest.mark.parametrize("fit_intercept", [True, False])
+def test_linear_regression_sample_weights(
+ array_constr, fit_intercept, global_random_seed
+):
+ rng = np.random.RandomState(global_random_seed)
+
+ # It would not work with under-determined systems
+ n_samples, n_features = 6, 5
+
+ X = array_constr(rng.normal(size=(n_samples, n_features)))
+ y = rng.normal(size=n_samples)
+
+ sample_weight = 1.0 + rng.uniform(size=n_samples)
+
+ # LinearRegression with explicit sample_weight
+ reg = LinearRegression(fit_intercept=fit_intercept)
+ reg.fit(X, y, sample_weight=sample_weight)
+ coefs1 = reg.coef_
+ inter1 = reg.intercept_
+
+ assert reg.coef_.shape == (X.shape[1],) # sanity checks
+
+ # Closed form of the weighted least square
+ # theta = (X^T W X)^(-1) @ X^T W y
+ W = np.diag(sample_weight)
+ X_aug = X if not fit_intercept else add_dummy_feature(X)
+
+ Xw = X_aug.T @ W @ X_aug
+ yw = X_aug.T @ W @ y
+ coefs2 = linalg.solve(Xw, yw)
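+ # with fit_intercept=True the prepended column of ones makes coefs2[0]
+ # the intercept and coefs2[1:] the slopes, as checked below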
+
+ if not fit_intercept:
+ assert_allclose(coefs1, coefs2)
+ else:
+ assert_allclose(coefs1, coefs2[1:])
+ assert_allclose(inter1, coefs2[0])
+
+
+def test_raises_value_error_if_positive_and_sparse():
+ error_msg = "A sparse matrix was passed, but dense data is required."
+ # X must not be sparse if positive == True
+ X = sparse.eye(10)
+ y = np.ones(10)
+
+ reg = LinearRegression(positive=True)
+
+ with pytest.raises(TypeError, match=error_msg):
+ reg.fit(X, y)
+
+
+@pytest.mark.parametrize("n_samples, n_features", [(2, 3), (3, 2)])
+def test_raises_value_error_if_sample_weights_greater_than_1d(n_samples, n_features):
+ # Sample weights must be either scalar or 1D
+ rng = np.random.RandomState(0)
+ X = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples)
+ sample_weights_OK = rng.randn(n_samples) ** 2 + 1
+ sample_weights_OK_1 = 1.0
+ sample_weights_OK_2 = 2.0
+
+ reg = LinearRegression()
+
+ # make sure the "OK" sample weights actually work
+ reg.fit(X, y, sample_weights_OK)
+ reg.fit(X, y, sample_weights_OK_1)
+ reg.fit(X, y, sample_weights_OK_2)
+
+
+def test_fit_intercept():
+ # Test assertions on betas shape.
+ X2 = np.array([[0.38349978, 0.61650022], [0.58853682, 0.41146318]])
+ X3 = np.array(
+ [[0.27677969, 0.70693172, 0.01628859], [0.08385139, 0.20692515, 0.70922346]]
+ )
+ y = np.array([1, 1])
+
+ lr2_without_intercept = LinearRegression(fit_intercept=False).fit(X2, y)
+ lr2_with_intercept = LinearRegression().fit(X2, y)
+
+ lr3_without_intercept = LinearRegression(fit_intercept=False).fit(X3, y)
+ lr3_with_intercept = LinearRegression().fit(X3, y)
+
+ assert lr2_with_intercept.coef_.shape == lr2_without_intercept.coef_.shape
+ assert lr3_with_intercept.coef_.shape == lr3_without_intercept.coef_.shape
+ assert lr2_without_intercept.coef_.ndim == lr3_without_intercept.coef_.ndim
+
+
+def test_error_on_wrong_normalize():
+ normalize = "wrong"
+ error_msg = "Leave 'normalize' to its default"
+ with pytest.raises(ValueError, match=error_msg):
+ _deprecate_normalize(normalize, "estimator")
+
+
+# TODO(1.4): remove
+@pytest.mark.parametrize("normalize", [True, False, "deprecated"])
+def test_deprecate_normalize(normalize):
+ # test all possible cases of the normalize parameter deprecation
+ if normalize == "deprecated":
+ # no warning
+ output = False
+ expected = None
+ warning_msg = []
+ else:
+ output = normalize
+ expected = FutureWarning
+ warning_msg = ["1.4"]
+ if not normalize:
+ warning_msg.append("default value")
+ else:
+ warning_msg.append("StandardScaler(")
+
+ if expected is None:
+ with warnings.catch_warnings():
+ warnings.simplefilter("error", FutureWarning)
+ _normalize = _deprecate_normalize(normalize, "estimator")
+ else:
+ with pytest.warns(expected) as record:
+ _normalize = _deprecate_normalize(normalize, "estimator")
+ assert all([warning in str(record[0].message) for warning in warning_msg])
+ assert _normalize == output
+
+
+def test_linear_regression_sparse(global_random_seed):
+ # Test that linear regression also works with sparse data
+ rng = np.random.RandomState(global_random_seed)
+ n = 100
+ X = sparse.eye(n, n)
+ beta = rng.rand(n)
+ y = X @ beta
+
+ ols = LinearRegression()
+ ols.fit(X, y.ravel())
+ assert_array_almost_equal(beta, ols.coef_ + ols.intercept_)
+
+ assert_array_almost_equal(ols.predict(X) - y.ravel(), 0)
+
+
+@pytest.mark.parametrize("fit_intercept", [True, False])
+def test_linear_regression_sparse_equal_dense(fit_intercept):
+ # Test that linear regression agrees between sparse and dense
+ rng = np.random.RandomState(0)
+ n_samples = 200
+ n_features = 2
+ X = rng.randn(n_samples, n_features)
+ X[X < 0.1] = 0.0
+ Xcsr = sparse.csr_matrix(X)
+ y = rng.rand(n_samples)
+ params = dict(fit_intercept=fit_intercept)
+ clf_dense = LinearRegression(**params)
+ clf_sparse = LinearRegression(**params)
+ clf_dense.fit(X, y)
+ clf_sparse.fit(Xcsr, y)
+ assert clf_dense.intercept_ == pytest.approx(clf_sparse.intercept_)
+ assert_allclose(clf_dense.coef_, clf_sparse.coef_)
+
+
+def test_linear_regression_multiple_outcome():
+ # Test multiple-outcome linear regressions
+ rng = np.random.RandomState(0)
+ X, y = make_regression(random_state=rng)
+
+ Y = np.vstack((y, y)).T
+ n_features = X.shape[1]
+
+ reg = LinearRegression()
+ reg.fit((X), Y)
+ assert reg.coef_.shape == (2, n_features)
+ Y_pred = reg.predict(X)
+ reg.fit(X, y)
+ y_pred = reg.predict(X)
+ assert_array_almost_equal(np.vstack((y_pred, y_pred)).T, Y_pred, decimal=3)
+
+
+def test_linear_regression_sparse_multiple_outcome(global_random_seed):
+ # Test multiple-outcome linear regressions with sparse data
+ rng = np.random.RandomState(global_random_seed)
+ X, y = make_sparse_uncorrelated(random_state=rng)
+ X = sparse.coo_matrix(X)
+ Y = np.vstack((y, y)).T
+ n_features = X.shape[1]
+
+ ols = LinearRegression()
+ ols.fit(X, Y)
+ assert ols.coef_.shape == (2, n_features)
+ Y_pred = ols.predict(X)
+ ols.fit(X, y.ravel())
+ y_pred = ols.predict(X)
+ assert_array_almost_equal(np.vstack((y_pred, y_pred)).T, Y_pred, decimal=3)
+
+
+def test_linear_regression_positive():
+ # Test nonnegative LinearRegression on a simple dataset.
+ X = [[1], [2]]
+ y = [1, 2]
+
+ reg = LinearRegression(positive=True)
+ reg.fit(X, y)
+
+ assert_array_almost_equal(reg.coef_, [1])
+ assert_array_almost_equal(reg.intercept_, [0])
+ assert_array_almost_equal(reg.predict(X), [1, 2])
+
+ # test it also for degenerate input
+ X = [[1]]
+ y = [0]
+
+ reg = LinearRegression(positive=True)
+ reg.fit(X, y)
+ assert_allclose(reg.coef_, [0])
+ assert_allclose(reg.intercept_, [0])
+ assert_allclose(reg.predict(X), [0])
+
+
+def test_linear_regression_positive_multiple_outcome(global_random_seed):
+ # Test multiple-outcome nonnegative linear regressions
+ rng = np.random.RandomState(global_random_seed)
+ X, y = make_sparse_uncorrelated(random_state=rng)
+ Y = np.vstack((y, y)).T
+ n_features = X.shape[1]
+
+ ols = LinearRegression(positive=True)
+ ols.fit(X, Y)
+ assert ols.coef_.shape == (2, n_features)
+ assert np.all(ols.coef_ >= 0.0)
+ Y_pred = ols.predict(X)
+ ols.fit(X, y.ravel())
+ y_pred = ols.predict(X)
+ assert_allclose(np.vstack((y_pred, y_pred)).T, Y_pred)
+
+
+def test_linear_regression_positive_vs_nonpositive(global_random_seed):
+ # Test differences with LinearRegression when positive=False.
+ rng = np.random.RandomState(global_random_seed)
+ X, y = make_sparse_uncorrelated(random_state=rng)
+
+ reg = LinearRegression(positive=True)
+ reg.fit(X, y)
+ regn = LinearRegression(positive=False)
+ regn.fit(X, y)
+
+ assert np.mean((reg.coef_ - regn.coef_) ** 2) > 1e-3
+
+
+def test_linear_regression_positive_vs_nonpositive_when_positive(global_random_seed):
+ # Test LinearRegression fitted coefficients
+ # when the problem is positive.
+ rng = np.random.RandomState(global_random_seed)
+ n_samples = 200
+ n_features = 4
+ X = rng.rand(n_samples, n_features)
+ y = X[:, 0] + 2 * X[:, 1] + 3 * X[:, 2] + 1.5 * X[:, 3]
+
+ reg = LinearRegression(positive=True)
+ reg.fit(X, y)
+ regn = LinearRegression(positive=False)
+ regn.fit(X, y)
+
+ assert np.mean((reg.coef_ - regn.coef_) ** 2) < 1e-6
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_linear_regression_pd_sparse_dataframe_warning():
+ pd = pytest.importorskip("modin.pandas")
+
+ # A warning is raised only when some of the columns are sparse
+ df = pd.DataFrame({"0": np.random.randn(10)})
+ for col in range(1, 4):
+ arr = np.random.randn(10)
+ arr[:8] = 0
+ # all columns but the first are sparse
+ if col != 0:
+ arr = pd.arrays.SparseArray(arr, fill_value=0)
+ df[str(col)] = arr
+
+ msg = "pandas.DataFrame with sparse columns found."
+
+ reg = LinearRegression()
+ with pytest.warns(UserWarning, match=msg):
+ reg.fit(df.iloc[:, 0:2], df.iloc[:, 3])
+
+ # does not warn when the whole dataframe is sparse
+ df["0"] = pd.arrays.SparseArray(df["0"], fill_value=0)
+ assert hasattr(df, "sparse")
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("error", UserWarning)
+ reg.fit(df.iloc[:, 0:2], df.iloc[:, 3])
+
+
+def test_preprocess_data(global_random_seed):
+ rng = np.random.RandomState(global_random_seed)
+ n_samples = 200
+ n_features = 2
+ X = rng.rand(n_samples, n_features)
+ y = rng.rand(n_samples)
+ expected_X_mean = np.mean(X, axis=0)
+ expected_X_scale = np.std(X, axis=0) * np.sqrt(X.shape[0])
+ expected_y_mean = np.mean(y, axis=0)
+
+ Xt, yt, X_mean, y_mean, X_scale = _preprocess_data(
+ X, y, fit_intercept=False, normalize=False
+ )
+ assert_array_almost_equal(X_mean, np.zeros(n_features))
+ assert_array_almost_equal(y_mean, 0)
+ assert_array_almost_equal(X_scale, np.ones(n_features))
+ assert_array_almost_equal(Xt, X)
+ assert_array_almost_equal(yt, y)
+
+ Xt, yt, X_mean, y_mean, X_scale = _preprocess_data(
+ X, y, fit_intercept=True, normalize=False
+ )
+ assert_array_almost_equal(X_mean, expected_X_mean)
+ assert_array_almost_equal(y_mean, expected_y_mean)
+ assert_array_almost_equal(X_scale, np.ones(n_features))
+ assert_array_almost_equal(Xt, X - expected_X_mean)
+ assert_array_almost_equal(yt, y - expected_y_mean)
+
+ Xt, yt, X_mean, y_mean, X_scale = _preprocess_data(
+ X, y, fit_intercept=True, normalize=True
+ )
+ assert_array_almost_equal(X_mean, expected_X_mean)
+ assert_array_almost_equal(y_mean, expected_y_mean)
+ assert_array_almost_equal(X_scale, expected_X_scale)
+ assert_array_almost_equal(Xt, (X - expected_X_mean) / expected_X_scale)
+ assert_array_almost_equal(yt, y - expected_y_mean)
+
+
+def test_preprocess_data_multioutput(global_random_seed):
+ rng = np.random.RandomState(global_random_seed)
+ n_samples = 200
+ n_features = 3
+ n_outputs = 2
+ X = rng.rand(n_samples, n_features)
+ y = rng.rand(n_samples, n_outputs)
+ expected_y_mean = np.mean(y, axis=0)
+
+ args = [X, sparse.csc_matrix(X)]
+ for X in args:
+ _, yt, _, y_mean, _ = _preprocess_data(
+ X, y, fit_intercept=False, normalize=False
+ )
+ assert_array_almost_equal(y_mean, np.zeros(n_outputs))
+ assert_array_almost_equal(yt, y)
+
+ _, yt, _, y_mean, _ = _preprocess_data(
+ X, y, fit_intercept=True, normalize=False
+ )
+ assert_array_almost_equal(y_mean, expected_y_mean)
+ assert_array_almost_equal(yt, y - y_mean)
+
+ _, yt, _, y_mean, _ = _preprocess_data(X, y, fit_intercept=True, normalize=True)
+ assert_array_almost_equal(y_mean, expected_y_mean)
+ assert_array_almost_equal(yt, y - y_mean)
+
+
+@pytest.mark.parametrize("is_sparse", [False, True])
+def test_preprocess_data_weighted(is_sparse, global_random_seed):
+ rng = np.random.RandomState(global_random_seed)
+ n_samples = 200
+ n_features = 4
+ # Generate random data with 50% of zero values to make sure
+ # that the sparse variant of this test is actually sparse. This also
+ # shifts the mean value for each column in X further away from
+ # zero.
+ X = rng.rand(n_samples, n_features)
+ X[X < 0.5] = 0.0
+
+ # Scale the first feature of X to be 10 times larger than the others to
+ # better check the impact of feature scaling.
+ X[:, 0] *= 10
+
+ # Constant non-zero feature.
+ X[:, 2] = 1.0
+
+ # Constant zero feature (non-materialized in the sparse case)
+ X[:, 3] = 0.0
+ y = rng.rand(n_samples)
+
+ sample_weight = rng.rand(n_samples)
+ expected_X_mean = np.average(X, axis=0, weights=sample_weight)
+ expected_y_mean = np.average(y, axis=0, weights=sample_weight)
+
+ X_sample_weight_avg = np.average(X, weights=sample_weight, axis=0)
+ X_sample_weight_var = np.average(
+ (X - X_sample_weight_avg) ** 2, weights=sample_weight, axis=0
+ )
+ constant_mask = X_sample_weight_var < 10 * np.finfo(X.dtype).eps
+ assert_array_equal(constant_mask, [0, 0, 1, 1])
+ expected_X_scale = np.sqrt(X_sample_weight_var) * np.sqrt(sample_weight.sum())
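+ # i.e. the weighted Euclidean norm of each centered column:
+ # sqrt(sum_i w_i * (X_ij - weighted_mean_j) ** 2)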
+
+ # near constant features should not be scaled
+ expected_X_scale[constant_mask] = 1
+
+ if is_sparse:
+ X = sparse.csr_matrix(X)
+
+ # normalize is False
+ Xt, yt, X_mean, y_mean, X_scale = _preprocess_data(
+ X,
+ y,
+ fit_intercept=True,
+ normalize=False,
+ sample_weight=sample_weight,
+ )
+ assert_array_almost_equal(X_mean, expected_X_mean)
+ assert_array_almost_equal(y_mean, expected_y_mean)
+ assert_array_almost_equal(X_scale, np.ones(n_features))
+ if is_sparse:
+ assert_array_almost_equal(Xt.toarray(), X.toarray())
+ else:
+ assert_array_almost_equal(Xt, X - expected_X_mean)
+ assert_array_almost_equal(yt, y - expected_y_mean)
+
+ # normalize is True
+ Xt, yt, X_mean, y_mean, X_scale = _preprocess_data(
+ X,
+ y,
+ fit_intercept=True,
+ normalize=True,
+ sample_weight=sample_weight,
+ )
+
+ assert_array_almost_equal(X_mean, expected_X_mean)
+ assert_array_almost_equal(y_mean, expected_y_mean)
+ assert_array_almost_equal(X_scale, expected_X_scale)
+
+ if is_sparse:
+ # X is not centered
+ assert_array_almost_equal(Xt.toarray(), X.toarray() / expected_X_scale)
+ else:
+ assert_array_almost_equal(Xt, (X - expected_X_mean) / expected_X_scale)
+
+ # _preprocess_data with normalize=True scales the data by the feature-wise
+ # euclidean norms while StandardScaler scales the data by the feature-wise
+ # standard deviations.
+ # The two are equivalent up to a ratio of np.sqrt(n_samples) if unweighted
+ # or np.sqrt(sample_weight.sum()) if weighted.
+ if is_sparse:
+ scaler = StandardScaler(with_mean=False).fit(X, sample_weight=sample_weight)
+
+ # Non-constant features are scaled similarly, up to a factor of
+ # np.sqrt(sample_weight.sum())
+ assert_array_almost_equal(
+ scaler.transform(X).toarray()[:, :2] / np.sqrt(sample_weight.sum()),
+ Xt.toarray()[:, :2],
+ )
+
+ # Constant features go through unscaled.
+ assert_array_almost_equal(
+ scaler.transform(X).toarray()[:, 2:], Xt.toarray()[:, 2:]
+ )
+ else:
+ scaler = StandardScaler(with_mean=True).fit(X, sample_weight=sample_weight)
+ assert_array_almost_equal(scaler.mean_, X_mean)
+ assert_array_almost_equal(
+ scaler.transform(X) / np.sqrt(sample_weight.sum()),
+ Xt,
+ )
+ assert_array_almost_equal(yt, y - expected_y_mean)
+
+
+def test_sparse_preprocess_data_offsets(global_random_seed):
+ rng = np.random.RandomState(global_random_seed)
+ n_samples = 200
+ n_features = 2
+ X = sparse.rand(n_samples, n_features, density=0.5, random_state=rng)
+ X = X.tolil()
+ y = rng.rand(n_samples)
+ XA = X.toarray()
+ expected_X_scale = np.std(XA, axis=0) * np.sqrt(X.shape[0])
+
+ Xt, yt, X_mean, y_mean, X_scale = _preprocess_data(
+ X, y, fit_intercept=False, normalize=False
+ )
+ assert_array_almost_equal(X_mean, np.zeros(n_features))
+ assert_array_almost_equal(y_mean, 0)
+ assert_array_almost_equal(X_scale, np.ones(n_features))
+ assert_array_almost_equal(Xt.A, XA)
+ assert_array_almost_equal(yt, y)
+
+ Xt, yt, X_mean, y_mean, X_scale = _preprocess_data(
+ X, y, fit_intercept=True, normalize=False
+ )
+ assert_array_almost_equal(X_mean, np.mean(XA, axis=0))
+ assert_array_almost_equal(y_mean, np.mean(y, axis=0))
+ assert_array_almost_equal(X_scale, np.ones(n_features))
+ assert_array_almost_equal(Xt.A, XA)
+ assert_array_almost_equal(yt, y - np.mean(y, axis=0))
+
+ Xt, yt, X_mean, y_mean, X_scale = _preprocess_data(
+ X, y, fit_intercept=True, normalize=True
+ )
+ assert_array_almost_equal(X_mean, np.mean(XA, axis=0))
+ assert_array_almost_equal(y_mean, np.mean(y, axis=0))
+ assert_array_almost_equal(X_scale, expected_X_scale)
+ assert_array_almost_equal(Xt.A, XA / expected_X_scale)
+ assert_array_almost_equal(yt, y - np.mean(y, axis=0))
+
+
+def test_csr_preprocess_data():
+ # Test output format of _preprocess_data, when input is csr
+ X, y = make_regression()
+ X[X < 2.5] = 0.0
+ csr = sparse.csr_matrix(X)
+ csr_, y, _, _, _ = _preprocess_data(csr, y, True)
+ assert csr_.getformat() == "csr"
+
+
+@pytest.mark.parametrize("is_sparse", (True, False))
+@pytest.mark.parametrize("to_copy", (True, False))
+def test_preprocess_copy_data_no_checks(is_sparse, to_copy):
+ X, y = make_regression()
+ X[X < 2.5] = 0.0
+
+ if is_sparse:
+ X = sparse.csr_matrix(X)
+
+ X_, y_, _, _, _ = _preprocess_data(X, y, True, copy=to_copy, check_input=False)
+
+ if to_copy and is_sparse:
+ assert not np.may_share_memory(X_.data, X.data)
+ elif to_copy:
+ assert not np.may_share_memory(X_, X)
+ elif is_sparse:
+ assert np.may_share_memory(X_.data, X.data)
+ else:
+ assert np.may_share_memory(X_, X)
+
+
+def test_dtype_preprocess_data(global_random_seed):
+ rng = np.random.RandomState(global_random_seed)
+ n_samples = 200
+ n_features = 2
+ X = rng.rand(n_samples, n_features)
+ y = rng.rand(n_samples)
+
+ X_32 = np.asarray(X, dtype=np.float32)
+ y_32 = np.asarray(y, dtype=np.float32)
+ X_64 = np.asarray(X, dtype=np.float64)
+ y_64 = np.asarray(y, dtype=np.float64)
+
+ for fit_intercept in [True, False]:
+ for normalize in [True, False]:
+ Xt_32, yt_32, X_mean_32, y_mean_32, X_scale_32 = _preprocess_data(
+ X_32,
+ y_32,
+ fit_intercept=fit_intercept,
+ normalize=normalize,
+ )
+
+ Xt_64, yt_64, X_mean_64, y_mean_64, X_scale_64 = _preprocess_data(
+ X_64,
+ y_64,
+ fit_intercept=fit_intercept,
+ normalize=normalize,
+ )
+
+ Xt_3264, yt_3264, X_mean_3264, y_mean_3264, X_scale_3264 = _preprocess_data(
+ X_32,
+ y_64,
+ fit_intercept=fit_intercept,
+ normalize=normalize,
+ )
+
+ Xt_6432, yt_6432, X_mean_6432, y_mean_6432, X_scale_6432 = _preprocess_data(
+ X_64,
+ y_32,
+ fit_intercept=fit_intercept,
+ normalize=normalize,
+ )
+
+ assert Xt_32.dtype == np.float32
+ assert yt_32.dtype == np.float32
+ assert X_mean_32.dtype == np.float32
+ assert y_mean_32.dtype == np.float32
+ assert X_scale_32.dtype == np.float32
+
+ assert Xt_64.dtype == np.float64
+ assert yt_64.dtype == np.float64
+ assert X_mean_64.dtype == np.float64
+ assert y_mean_64.dtype == np.float64
+ assert X_scale_64.dtype == np.float64
+
+ assert Xt_3264.dtype == np.float32
+ assert yt_3264.dtype == np.float32
+ assert X_mean_3264.dtype == np.float32
+ assert y_mean_3264.dtype == np.float32
+ assert X_scale_3264.dtype == np.float32
+
+ assert Xt_6432.dtype == np.float64
+ assert yt_6432.dtype == np.float64
+ assert X_mean_6432.dtype == np.float64
+ assert y_mean_6432.dtype == np.float64
+ assert X_scale_6432.dtype == np.float64
+
+ assert X_32.dtype == np.float32
+ assert y_32.dtype == np.float32
+ assert X_64.dtype == np.float64
+ assert y_64.dtype == np.float64
+
+ assert_array_almost_equal(Xt_32, Xt_64)
+ assert_array_almost_equal(yt_32, yt_64)
+ assert_array_almost_equal(X_mean_32, X_mean_64)
+ assert_array_almost_equal(y_mean_32, y_mean_64)
+ assert_array_almost_equal(X_scale_32, X_scale_64)
+
+
+@pytest.mark.parametrize("n_targets", [None, 2])
+def test_rescale_data_dense(n_targets, global_random_seed):
+ rng = np.random.RandomState(global_random_seed)
+ n_samples = 200
+ n_features = 2
+
+ sample_weight = 1.0 + rng.rand(n_samples)
+ X = rng.rand(n_samples, n_features)
+ if n_targets is None:
+ y = rng.rand(n_samples)
+ else:
+ y = rng.rand(n_samples, n_targets)
+ rescaled_X, rescaled_y, sqrt_sw = _rescale_data(X, y, sample_weight)
+ rescaled_X2 = X * sqrt_sw[:, np.newaxis]
+ if n_targets is None:
+ rescaled_y2 = y * sqrt_sw
+ else:
+ rescaled_y2 = y * sqrt_sw[:, np.newaxis]
+ assert_array_almost_equal(rescaled_X, rescaled_X2)
+ assert_array_almost_equal(rescaled_y, rescaled_y2)
+
+
+def test_fused_types_make_dataset():
+ iris = load_iris()
+
+ X_32 = iris.data.astype(np.float32)
+ y_32 = iris.target.astype(np.float32)
+ X_csr_32 = sparse.csr_matrix(X_32)
+ sample_weight_32 = np.arange(y_32.size, dtype=np.float32)
+
+ X_64 = iris.data.astype(np.float64)
+ y_64 = iris.target.astype(np.float64)
+ X_csr_64 = sparse.csr_matrix(X_64)
+ sample_weight_64 = np.arange(y_64.size, dtype=np.float64)
+
+ # array
+ dataset_32, _ = make_dataset(X_32, y_32, sample_weight_32)
+ dataset_64, _ = make_dataset(X_64, y_64, sample_weight_64)
+ xi_32, yi_32, _, _ = dataset_32._next_py()
+ xi_64, yi_64, _, _ = dataset_64._next_py()
+ xi_data_32, _, _ = xi_32
+ xi_data_64, _, _ = xi_64
+
+ assert xi_data_32.dtype == np.float32
+ assert xi_data_64.dtype == np.float64
+ assert_allclose(yi_64, yi_32, rtol=rtol)
+
+ # csr
+ datasetcsr_32, _ = make_dataset(X_csr_32, y_32, sample_weight_32)
+ datasetcsr_64, _ = make_dataset(X_csr_64, y_64, sample_weight_64)
+ xicsr_32, yicsr_32, _, _ = datasetcsr_32._next_py()
+ xicsr_64, yicsr_64, _, _ = datasetcsr_64._next_py()
+ xicsr_data_32, _, _ = xicsr_32
+ xicsr_data_64, _, _ = xicsr_64
+
+ assert xicsr_data_32.dtype == np.float32
+ assert xicsr_data_64.dtype == np.float64
+
+ assert_allclose(xicsr_data_64, xicsr_data_32, rtol=rtol)
+ assert_allclose(yicsr_64, yicsr_32, rtol=rtol)
+
+ assert_array_equal(xi_data_32, xicsr_data_32)
+ assert_array_equal(xi_data_64, xicsr_data_64)
+ assert_array_equal(yi_32, yicsr_32)
+ assert_array_equal(yi_64, yicsr_64)
diff --git a/modin/pandas/test/interoperability/sklearn/manifold/test_t_sne.py b/modin/pandas/test/interoperability/sklearn/manifold/test_t_sne.py
new file mode 100644
index 00000000000..0050afe7848
--- /dev/null
+++ b/modin/pandas/test/interoperability/sklearn/manifold/test_t_sne.py
@@ -0,0 +1,1220 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+# flake8: noqa
+
+import sys
+from io import StringIO
+import numpy as np
+from numpy.testing import assert_allclose
+import scipy.sparse as sp
+import pytest
+
+from sklearn import config_context
+from sklearn.neighbors import NearestNeighbors
+from sklearn.neighbors import kneighbors_graph
+from sklearn.exceptions import EfficiencyWarning
+from sklearn.utils._testing import ignore_warnings
+from sklearn.utils._testing import assert_almost_equal
+from sklearn.utils._testing import assert_array_equal
+from sklearn.utils._testing import assert_array_almost_equal
+from sklearn.utils._testing import skip_if_32bit
+from sklearn.utils import check_random_state
+from sklearn.manifold._t_sne import _joint_probabilities
+from sklearn.manifold._t_sne import _joint_probabilities_nn
+from sklearn.manifold._t_sne import _kl_divergence
+from sklearn.manifold._t_sne import _kl_divergence_bh
+from sklearn.manifold._t_sne import _gradient_descent
+from sklearn.manifold._t_sne import trustworthiness
+from sklearn.manifold import TSNE
+
+# mypy error: Module 'sklearn.manifold' has no attribute '_barnes_hut_tsne'
+from sklearn.manifold import _barnes_hut_tsne # type: ignore
+from sklearn.manifold._utils import _binary_search_perplexity
+from sklearn.datasets import make_blobs
+from scipy.optimize import check_grad
+from scipy.spatial.distance import pdist
+from scipy.spatial.distance import squareform
+from sklearn.metrics.pairwise import pairwise_distances
+from sklearn.metrics.pairwise import manhattan_distances
+from sklearn.metrics.pairwise import cosine_distances
+
+
+x = np.linspace(0, 1, 10)
+xx, yy = np.meshgrid(x, x)
+X_2d_grid = np.hstack(
+ [
+ xx.ravel().reshape(-1, 1),
+ yy.ravel().reshape(-1, 1),
+ ]
+)
+
+
+def test_gradient_descent_stops():
+ # Test stopping conditions of gradient descent.
+ class ObjectiveSmallGradient:
+ def __init__(self):
+ self.it = -1
+
+ def __call__(self, _, compute_error=True):
+ self.it += 1
+ return (10 - self.it) / 10.0, np.array([1e-5])
+
+ def flat_function(_, compute_error=True):
+ return 0.0, np.ones(1)
+
+ # Gradient norm
+ old_stdout = sys.stdout
+ sys.stdout = StringIO()
+ try:
+ _, error, it = _gradient_descent(
+ ObjectiveSmallGradient(),
+ np.zeros(1),
+ 0,
+ n_iter=100,
+ n_iter_without_progress=100,
+ momentum=0.0,
+ learning_rate=0.0,
+ min_gain=0.0,
+ min_grad_norm=1e-5,
+ verbose=2,
+ )
+ finally:
+ out = sys.stdout.getvalue()
+ sys.stdout.close()
+ sys.stdout = old_stdout
+ assert error == 1.0
+ assert it == 0
+ assert "gradient norm" in out
+
+ # Maximum number of iterations without improvement
+ old_stdout = sys.stdout
+ sys.stdout = StringIO()
+ try:
+ _, error, it = _gradient_descent(
+ flat_function,
+ np.zeros(1),
+ 0,
+ n_iter=100,
+ n_iter_without_progress=10,
+ momentum=0.0,
+ learning_rate=0.0,
+ min_gain=0.0,
+ min_grad_norm=0.0,
+ verbose=2,
+ )
+ finally:
+ out = sys.stdout.getvalue()
+ sys.stdout.close()
+ sys.stdout = old_stdout
+ assert error == 0.0
+ assert it == 11
+ assert "did not make any progress" in out
+
+ # Maximum number of iterations
+ old_stdout = sys.stdout
+ sys.stdout = StringIO()
+ try:
+ _, error, it = _gradient_descent(
+ ObjectiveSmallGradient(),
+ np.zeros(1),
+ 0,
+ n_iter=11,
+ n_iter_without_progress=100,
+ momentum=0.0,
+ learning_rate=0.0,
+ min_gain=0.0,
+ min_grad_norm=0.0,
+ verbose=2,
+ )
+ finally:
+ out = sys.stdout.getvalue()
+ sys.stdout.close()
+ sys.stdout = old_stdout
+ assert error == 0.0
+ assert it == 10
+ assert "Iteration 10" in out
+
+
+def test_binary_search():
+ # Test if the binary search finds Gaussians with desired perplexity.
+ random_state = check_random_state(0)
+ data = random_state.randn(50, 5)
+ distances = pairwise_distances(data).astype(np.float32)
+ desired_perplexity = 25.0
+ P = _binary_search_perplexity(distances, desired_perplexity, verbose=0)
+ P = np.maximum(P, np.finfo(np.double).eps)
+ mean_perplexity = np.mean(
+ [np.exp(-np.sum(P[i] * np.log(P[i]))) for i in range(P.shape[0])]
+ )
+ assert_almost_equal(mean_perplexity, desired_perplexity, decimal=3)
+
+
+def test_binary_search_underflow():
+ # Test if the binary search finds Gaussians with desired perplexity.
+ # A more challenging case than the one above, producing numeric
+ # underflow in float precision (see issue #19471 and PR #19472).
+ random_state = check_random_state(42)
+ data = random_state.randn(1, 90).astype(np.float32) + 100
+ desired_perplexity = 30.0
+ P = _binary_search_perplexity(data, desired_perplexity, verbose=0)
+ perplexity = 2 ** -np.nansum(P[0, 1:] * np.log2(P[0, 1:]))
+ assert_almost_equal(perplexity, desired_perplexity, decimal=3)
+
+
+def test_binary_search_neighbors():
+ # Binary perplexity search approximation.
+ # Should be approximately equal to the slow method when we use
+ # all points as neighbors.
+ n_samples = 200
+ desired_perplexity = 25.0
+ random_state = check_random_state(0)
+ data = random_state.randn(n_samples, 2).astype(np.float32, copy=False)
+ distances = pairwise_distances(data)
+ P1 = _binary_search_perplexity(distances, desired_perplexity, verbose=0)
+
+ # Test that when we use all the neighbors the results are identical
+ n_neighbors = n_samples - 1
+ nn = NearestNeighbors().fit(data)
+ distance_graph = nn.kneighbors_graph(n_neighbors=n_neighbors, mode="distance")
+ distances_nn = distance_graph.data.astype(np.float32, copy=False)
+ distances_nn = distances_nn.reshape(n_samples, n_neighbors)
+ P2 = _binary_search_perplexity(distances_nn, desired_perplexity, verbose=0)
+
+ indptr = distance_graph.indptr
+ P1_nn = np.array(
+ [
+ P1[k, distance_graph.indices[indptr[k] : indptr[k + 1]]]
+ for k in range(n_samples)
+ ]
+ )
+ assert_array_almost_equal(P1_nn, P2, decimal=4)
+
+ # Test that the highest P_ij are the same when fewer neighbors are used
+ for k in np.linspace(150, n_samples - 1, 5):
+ k = int(k)
+ topn = k * 10 # check the top 10 * k entries out of k * k entries
+ distance_graph = nn.kneighbors_graph(n_neighbors=k, mode="distance")
+ distances_nn = distance_graph.data.astype(np.float32, copy=False)
+ distances_nn = distances_nn.reshape(n_samples, k)
+ P2k = _binary_search_perplexity(distances_nn, desired_perplexity, verbose=0)
+ assert_array_almost_equal(P1_nn, P2, decimal=2)
+ idx = np.argsort(P1.ravel())[::-1]
+ P1top = P1.ravel()[idx][:topn]
+ idx = np.argsort(P2k.ravel())[::-1]
+ P2top = P2k.ravel()[idx][:topn]
+ assert_array_almost_equal(P1top, P2top, decimal=2)
+
+
+def test_binary_perplexity_stability():
+ # Binary perplexity search should be stable.
+ # The binary_search_perplexity had a bug wherein the P array
+ # was uninitialized, leading to sporadically failing tests.
+ n_neighbors = 10
+ n_samples = 100
+ random_state = check_random_state(0)
+ data = random_state.randn(n_samples, 5)
+ nn = NearestNeighbors().fit(data)
+ distance_graph = nn.kneighbors_graph(n_neighbors=n_neighbors, mode="distance")
+ distances = distance_graph.data.astype(np.float32, copy=False)
+ distances = distances.reshape(n_samples, n_neighbors)
+ last_P = None
+ desired_perplexity = 3
+ for _ in range(100):
+ P = _binary_search_perplexity(distances.copy(), desired_perplexity, verbose=0)
+ P1 = _joint_probabilities_nn(distance_graph, desired_perplexity, verbose=0)
+ # Convert the sparse matrix to a dense one for testing
+ P1 = P1.toarray()
+ if last_P is None:
+ last_P = P
+ last_P1 = P1
+ else:
+ assert_array_almost_equal(P, last_P, decimal=4)
+ assert_array_almost_equal(P1, last_P1, decimal=4)
+
+
+def test_gradient():
+ # Test gradient of Kullback-Leibler divergence.
+ random_state = check_random_state(0)
+
+ n_samples = 50
+ n_features = 2
+ n_components = 2
+ alpha = 1.0
+
+ distances = random_state.randn(n_samples, n_features).astype(np.float32)
+ distances = np.abs(distances.dot(distances.T))
+ np.fill_diagonal(distances, 0.0)
+ X_embedded = random_state.randn(n_samples, n_components).astype(np.float32)
+
+ P = _joint_probabilities(distances, desired_perplexity=25.0, verbose=0)
+
+ def fun(params):
+ return _kl_divergence(params, P, alpha, n_samples, n_components)[0]
+
+ def grad(params):
+ return _kl_divergence(params, P, alpha, n_samples, n_components)[1]
+
+ assert_almost_equal(check_grad(fun, grad, X_embedded.ravel()), 0.0, decimal=5)
+
+
+def test_trustworthiness():
+ # Test trustworthiness score.
+ random_state = check_random_state(0)
+
+ # Affine transformation
+ X = random_state.randn(100, 2)
+ assert trustworthiness(X, 5.0 + X / 10.0) == 1.0
+
+ # Randomly shuffled
+ X = np.arange(100).reshape(-1, 1)
+ X_embedded = X.copy()
+ random_state.shuffle(X_embedded)
+ assert trustworthiness(X, X_embedded) < 0.6
+
+ # Completely different
+ X = np.arange(5).reshape(-1, 1)
+ X_embedded = np.array([[0], [2], [4], [1], [3]])
+ assert_almost_equal(trustworthiness(X, X_embedded, n_neighbors=1), 0.2)
+
+
+def test_trustworthiness_n_neighbors_error():
+ """Raise an error when n_neighbors >= n_samples / 2.
+
+ Non-regression test for #18567.
+ """
+ regex = "n_neighbors .+ should be less than .+"
+ rng = np.random.RandomState(42)
+ X = rng.rand(7, 4)
+ X_embedded = rng.rand(7, 2)
+ with pytest.raises(ValueError, match=regex):
+ trustworthiness(X, X_embedded, n_neighbors=5)
+
+ trust = trustworthiness(X, X_embedded, n_neighbors=3)
+ assert 0 <= trust <= 1
+
+
+@pytest.mark.parametrize("method", ["exact", "barnes_hut"])
+@pytest.mark.parametrize("init", ("random", "pca"))
+def test_preserve_trustworthiness_approximately(method, init):
+ # Nearest neighbors should be preserved approximately.
+ random_state = check_random_state(0)
+ n_components = 2
+ X = random_state.randn(50, n_components).astype(np.float32)
+ tsne = TSNE(
+ n_components=n_components,
+ init=init,
+ random_state=0,
+ method=method,
+ n_iter=700,
+ learning_rate="auto",
+ )
+ X_embedded = tsne.fit_transform(X)
+ t = trustworthiness(X, X_embedded, n_neighbors=1)
+ assert t > 0.85
+
+
+def test_optimization_minimizes_kl_divergence():
+ """t-SNE should give a lower KL divergence with more iterations."""
+ random_state = check_random_state(0)
+ X, _ = make_blobs(n_features=3, random_state=random_state)
+ kl_divergences = []
+ for n_iter in [250, 300, 350]:
+ tsne = TSNE(
+ n_components=2,
+ init="random",
+ perplexity=10,
+ learning_rate=100.0,
+ n_iter=n_iter,
+ random_state=0,
+ )
+ tsne.fit_transform(X)
+ kl_divergences.append(tsne.kl_divergence_)
+ assert kl_divergences[1] <= kl_divergences[0]
+ assert kl_divergences[2] <= kl_divergences[1]
+
+
+@pytest.mark.parametrize("method", ["exact", "barnes_hut"])
+def test_fit_transform_csr_matrix(method):
+ # TODO: compare results on dense and sparse data as proposed in:
+ # https://github.com/scikit-learn/scikit-learn/pull/23585#discussion_r968388186
+ # X can be a sparse matrix.
+ rng = check_random_state(0)
+ X = rng.randn(50, 2)
+ X[(rng.randint(0, 50, 25), rng.randint(0, 2, 25))] = 0.0
+ X_csr = sp.csr_matrix(X)
+ tsne = TSNE(
+ n_components=2,
+ init="random",
+ perplexity=10,
+ learning_rate=100.0,
+ random_state=0,
+ method=method,
+ n_iter=750,
+ )
+ X_embedded = tsne.fit_transform(X_csr)
+ assert_allclose(trustworthiness(X_csr, X_embedded, n_neighbors=1), 1.0, rtol=1.1e-1)
+
+
+def test_preserve_trustworthiness_approximately_with_precomputed_distances():
+ # Nearest neighbors should be preserved approximately.
+ random_state = check_random_state(0)
+ for i in range(3):
+ X = random_state.randn(80, 2)
+ D = squareform(pdist(X), "sqeuclidean")
+ tsne = TSNE(
+ n_components=2,
+ perplexity=2,
+ learning_rate=100.0,
+ early_exaggeration=2.0,
+ metric="precomputed",
+ random_state=i,
+ verbose=0,
+ n_iter=500,
+ init="random",
+ )
+ X_embedded = tsne.fit_transform(D)
+ t = trustworthiness(D, X_embedded, n_neighbors=1, metric="precomputed")
+ assert t > 0.95
+
+
+def test_trustworthiness_not_euclidean_metric():
+ # Test trustworthiness with a metric different from 'euclidean' and
+ # 'precomputed'
+ random_state = check_random_state(0)
+ X = random_state.randn(100, 2)
+ assert trustworthiness(X, X, metric="cosine") == trustworthiness(
+ pairwise_distances(X, metric="cosine"), X, metric="precomputed"
+ )
+
+
+@pytest.mark.parametrize(
+ "method, retype",
+ [
+ ("exact", np.asarray),
+ ("barnes_hut", np.asarray),
+ ("barnes_hut", sp.csr_matrix),
+ ],
+)
+@pytest.mark.parametrize(
+ "D, message_regex",
+ [
+ ([[0.0], [1.0]], ".* square distance matrix"),
+ ([[0.0, -1.0], [1.0, 0.0]], ".* positive.*"),
+ ],
+)
+def test_bad_precomputed_distances(method, D, retype, message_regex):
+ tsne = TSNE(
+ metric="precomputed",
+ method=method,
+ init="random",
+ random_state=42,
+ perplexity=1,
+ )
+ with pytest.raises(ValueError, match=message_regex):
+ tsne.fit_transform(retype(D))
+
+
+def test_exact_no_precomputed_sparse():
+ tsne = TSNE(
+ metric="precomputed",
+ method="exact",
+ init="random",
+ random_state=42,
+ perplexity=1,
+ )
+ with pytest.raises(TypeError, match="sparse"):
+ tsne.fit_transform(sp.csr_matrix([[0, 5], [5, 0]]))
+
+
+def test_high_perplexity_precomputed_sparse_distances():
+ # Perplexity should be less than 50
+ dist = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
+ bad_dist = sp.csr_matrix(dist)
+ tsne = TSNE(metric="precomputed", init="random", random_state=42, perplexity=1)
+ msg = "3 neighbors per samples are required, but some samples have only 1"
+ with pytest.raises(ValueError, match=msg):
+ tsne.fit_transform(bad_dist)
+
+
+@ignore_warnings(category=EfficiencyWarning)
+def test_sparse_precomputed_distance():
+ """Make sure that TSNE works identically for sparse and dense matrix"""
+ random_state = check_random_state(0)
+ X = random_state.randn(100, 2)
+
+ D_sparse = kneighbors_graph(X, n_neighbors=100, mode="distance", include_self=True)
+ D = pairwise_distances(X)
+ assert sp.issparse(D_sparse)
+ assert_almost_equal(D_sparse.A, D)
+
+ tsne = TSNE(
+ metric="precomputed", random_state=0, init="random", learning_rate="auto"
+ )
+ Xt_dense = tsne.fit_transform(D)
+
+ for fmt in ["csr", "lil"]:
+ Xt_sparse = tsne.fit_transform(D_sparse.asformat(fmt))
+ assert_almost_equal(Xt_dense, Xt_sparse)
+
+
+def test_non_positive_computed_distances():
+ # Computed distance matrices must be positive.
+ def metric(x, y):
+ return -1
+
+ # Negative computed distances should be caught even if result is squared
+ tsne = TSNE(metric=metric, method="exact", perplexity=1)
+ X = np.array([[0.0, 0.0], [1.0, 1.0]])
+ with pytest.raises(ValueError, match="All distances .*metric given.*"):
+ tsne.fit_transform(X)
+
+
+def test_init_ndarray():
+ # Initialize TSNE with ndarray and test fit
+ tsne = TSNE(init=np.zeros((100, 2)), learning_rate="auto")
+ X_embedded = tsne.fit_transform(np.ones((100, 5)))
+ assert_array_equal(np.zeros((100, 2)), X_embedded)
+
+
+def test_init_ndarray_precomputed():
+ # Initialize TSNE with ndarray and metric 'precomputed'
+ # Make sure no FutureWarning is thrown from _fit
+ tsne = TSNE(
+ init=np.zeros((100, 2)),
+ metric="precomputed",
+ learning_rate=50.0,
+ )
+ tsne.fit(np.zeros((100, 100)))
+
+
+def test_pca_initialization_not_compatible_with_precomputed_kernel():
+ # Precomputed distance matrices cannot use PCA initialization.
+ tsne = TSNE(metric="precomputed", init="pca", perplexity=1)
+ with pytest.raises(
+ ValueError,
+ match='The parameter init="pca" cannot be used with metric="precomputed".',
+ ):
+ tsne.fit_transform(np.array([[0.0], [1.0]]))
+
+
+def test_pca_initialization_not_compatible_with_sparse_input():
+ # Sparse input matrices cannot use PCA initialization.
+ tsne = TSNE(init="pca", learning_rate=100.0, perplexity=1)
+ with pytest.raises(TypeError, match="PCA initialization.*"):
+ tsne.fit_transform(sp.csr_matrix([[0, 5], [5, 0]]))
+
+
+def test_n_components_range():
+ # barnes_hut method should only be used with n_components <= 3
+ tsne = TSNE(n_components=4, method="barnes_hut", perplexity=1)
+ with pytest.raises(ValueError, match="'n_components' should be .*"):
+ tsne.fit_transform(np.array([[0.0], [1.0]]))
+
+
+def test_early_exaggeration_used():
+ # check that the ``early_exaggeration`` parameter has an effect
+ random_state = check_random_state(0)
+ n_components = 2
+ methods = ["exact", "barnes_hut"]
+ X = random_state.randn(25, n_components).astype(np.float32)
+ for method in methods:
+ tsne = TSNE(
+ n_components=n_components,
+ perplexity=1,
+ learning_rate=100.0,
+ init="pca",
+ random_state=0,
+ method=method,
+ early_exaggeration=1.0,
+ n_iter=250,
+ )
+ X_embedded1 = tsne.fit_transform(X)
+ tsne = TSNE(
+ n_components=n_components,
+ perplexity=1,
+ learning_rate=100.0,
+ init="pca",
+ random_state=0,
+ method=method,
+ early_exaggeration=10.0,
+ n_iter=250,
+ )
+ X_embedded2 = tsne.fit_transform(X)
+
+ assert not np.allclose(X_embedded1, X_embedded2)
+
+
+def test_n_iter_used():
+ # check that the ``n_iter`` parameter has an effect
+ random_state = check_random_state(0)
+ n_components = 2
+ methods = ["exact", "barnes_hut"]
+ X = random_state.randn(25, n_components).astype(np.float32)
+ for method in methods:
+ for n_iter in [251, 500]:
+ tsne = TSNE(
+ n_components=n_components,
+ perplexity=1,
+ learning_rate=0.5,
+ init="random",
+ random_state=0,
+ method=method,
+ early_exaggeration=1.0,
+ n_iter=n_iter,
+ )
+ tsne.fit_transform(X)
+
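+ # When the full iteration budget is used, n_iter_ is the 0-based index
+ # of the last iteration, hence n_iter - 1.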
+ assert tsne.n_iter_ == n_iter - 1
+
+
+def test_answer_gradient_two_points():
+ # Test the tree with only a single set of children.
+ #
+ # These tests & answers have been checked against the reference
+ # implementation by LvdM.
+ pos_input = np.array([[1.0, 0.0], [0.0, 1.0]])
+ pos_output = np.array(
+ [[-4.961291e-05, -1.072243e-04], [9.259460e-05, 2.702024e-04]]
+ )
+ neighbors = np.array([[1], [0]])
+ grad_output = np.array(
+ [[-2.37012478e-05, -6.29044398e-05], [2.37012478e-05, 6.29044398e-05]]
+ )
+ _run_answer_test(pos_input, pos_output, neighbors, grad_output)
+
+
+def test_answer_gradient_four_points():
+ # Four points tests the tree with multiple levels of children.
+ #
+ # These tests & answers have been checked against the reference
+ # implementation by LvdM.
+ pos_input = np.array([[1.0, 0.0], [0.0, 1.0], [5.0, 2.0], [7.3, 2.2]])
+ pos_output = np.array(
+ [
+ [6.080564e-05, -7.120823e-05],
+ [-1.718945e-04, -4.000536e-05],
+ [-2.271720e-04, 8.663310e-05],
+ [-1.032577e-04, -3.582033e-05],
+ ]
+ )
+ neighbors = np.array([[1, 2, 3], [0, 2, 3], [1, 0, 3], [1, 2, 0]])
+ grad_output = np.array(
+ [
+ [5.81128448e-05, -7.78033454e-06],
+ [-5.81526851e-05, 7.80976444e-06],
+ [4.24275173e-08, -3.69569698e-08],
+ [-2.58720939e-09, 7.52706374e-09],
+ ]
+ )
+ _run_answer_test(pos_input, pos_output, neighbors, grad_output)
+
+
+def test_skip_num_points_gradient():
+ # Test the kwargs option skip_num_points.
+ #
+ # skip_num_points should make it such that the Barnes-Hut gradient
+ # is not calculated for indices below skip_num_points.
+ # Aside from skip_num_points=2 and the first two gradient rows
+ # being set to zero, these data points are the same as in
+ # test_answer_gradient_four_points()
+ pos_input = np.array([[1.0, 0.0], [0.0, 1.0], [5.0, 2.0], [7.3, 2.2]])
+ pos_output = np.array(
+ [
+ [6.080564e-05, -7.120823e-05],
+ [-1.718945e-04, -4.000536e-05],
+ [-2.271720e-04, 8.663310e-05],
+ [-1.032577e-04, -3.582033e-05],
+ ]
+ )
+ neighbors = np.array([[1, 2, 3], [0, 2, 3], [1, 0, 3], [1, 2, 0]])
+ grad_output = np.array(
+ [
+ [0.0, 0.0],
+ [0.0, 0.0],
+ [4.24275173e-08, -3.69569698e-08],
+ [-2.58720939e-09, 7.52706374e-09],
+ ]
+ )
+ _run_answer_test(pos_input, pos_output, neighbors, grad_output, False, 0.1, 2)
+
+
+def _run_answer_test(
+ pos_input,
+ pos_output,
+ neighbors,
+ grad_output,
+ verbose=False,
+ perplexity=0.1,
+ skip_num_points=0,
+):
+ distances = pairwise_distances(pos_input).astype(np.float32)
+ args = distances, perplexity, verbose
+ pos_output = pos_output.astype(np.float32)
+ neighbors = neighbors.astype(np.int64, copy=False)
+ pij_input = _joint_probabilities(*args)
+ pij_input = squareform(pij_input).astype(np.float32)
+ grad_bh = np.zeros(pos_output.shape, dtype=np.float32)
+
+ from scipy.sparse import csr_matrix
+
+ P = csr_matrix(pij_input)
+
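+ # The neighbors/indptr passed to the Cython gradient come from the CSR
+ # sparsity structure of P, superseding the `neighbors` argument above.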
+ neighbors = P.indices.astype(np.int64)
+ indptr = P.indptr.astype(np.int64)
+
+ _barnes_hut_tsne.gradient(
+ P.data, pos_output, neighbors, indptr, grad_bh, 0.5, 2, 1, skip_num_points=skip_num_points
+ )
+ assert_array_almost_equal(grad_bh, grad_output, decimal=4)
+
+
+def test_verbose():
+ # Verbose options write to stdout.
+ random_state = check_random_state(0)
+ tsne = TSNE(verbose=2, perplexity=4)
+ X = random_state.randn(5, 2)
+
+ old_stdout = sys.stdout
+ sys.stdout = StringIO()
+ try:
+ tsne.fit_transform(X)
+ finally:
+ out = sys.stdout.getvalue()
+ sys.stdout.close()
+ sys.stdout = old_stdout
+
+ assert "[t-SNE]" in out
+ assert "nearest neighbors..." in out
+ assert "Computed conditional probabilities" in out
+ assert "Mean sigma" in out
+ assert "early exaggeration" in out
+
+
+def test_chebyshev_metric():
+ # t-SNE should allow metrics that cannot be squared (issue #3526).
+ random_state = check_random_state(0)
+ tsne = TSNE(metric="chebyshev", perplexity=4)
+ X = random_state.randn(5, 2)
+ tsne.fit_transform(X)
+
+
+def test_reduction_to_one_component():
+ # t-SNE should allow reduction to one component (issue #4154).
+ random_state = check_random_state(0)
+ tsne = TSNE(n_components=1, perplexity=4)
+ X = random_state.randn(5, 2)
+ X_embedded = tsne.fit(X).embedding_
+ assert np.all(np.isfinite(X_embedded))
+
+
+@pytest.mark.parametrize("method", ["barnes_hut", "exact"])
+@pytest.mark.parametrize("dt", [np.float32, np.float64])
+def test_64bit(method, dt):
+ # Ensure 64bit arrays are handled correctly.
+ random_state = check_random_state(0)
+
+ X = random_state.randn(10, 2).astype(dt, copy=False)
+ tsne = TSNE(
+ n_components=2,
+ perplexity=2,
+ learning_rate=100.0,
+ random_state=0,
+ method=method,
+ verbose=0,
+ n_iter=300,
+ init="random",
+ )
+ X_embedded = tsne.fit_transform(X)
+ effective_type = X_embedded.dtype
+
+ # tsne cython code is only single precision, so the output will
+ # always be single precision, irrespective of the input dtype
+ assert effective_type == np.float32
+
+
+@pytest.mark.parametrize("method", ["barnes_hut", "exact"])
+def test_kl_divergence_not_nan(method):
+ # Ensure kl_divergence_ is computed at last iteration
+ # even though n_iter % n_iter_check != 0, i.e. 503 % 50 != 0
+ random_state = check_random_state(0)
+
+ X = random_state.randn(50, 2)
+ tsne = TSNE(
+ n_components=2,
+ perplexity=2,
+ learning_rate=100.0,
+ random_state=0,
+ method=method,
+ verbose=0,
+ n_iter=503,
+ init="random",
+ )
+ tsne.fit_transform(X)
+
+ assert not np.isnan(tsne.kl_divergence_)
+
+
+def test_barnes_hut_angle():
+ # When Barnes-Hut's angle=0 this corresponds to the exact method.
+ angle = 0.0
+ perplexity = 10
+ n_samples = 100
+ for n_components in [2, 3]:
+ n_features = 5
+ degrees_of_freedom = float(n_components - 1.0)
+
+ random_state = check_random_state(0)
+ data = random_state.randn(n_samples, n_features)
+ distances = pairwise_distances(data)
+ params = random_state.randn(n_samples, n_components)
+ P = _joint_probabilities(distances, perplexity, verbose=0)
+ kl_exact, grad_exact = _kl_divergence(
+ params, P, degrees_of_freedom, n_samples, n_components
+ )
+
+ n_neighbors = n_samples - 1
+ distances_csr = (
+ NearestNeighbors()
+ .fit(data)
+ .kneighbors_graph(n_neighbors=n_neighbors, mode="distance")
+ )
+ P_bh = _joint_probabilities_nn(distances_csr, perplexity, verbose=0)
+ kl_bh, grad_bh = _kl_divergence_bh(
+ params,
+ P_bh,
+ degrees_of_freedom,
+ n_samples,
+ n_components,
+ angle=angle,
+ skip_num_points=0,
+ verbose=0,
+ )
+
+ P = squareform(P)
+ P_bh = P_bh.toarray()
+ assert_array_almost_equal(P_bh, P, decimal=5)
+ assert_almost_equal(kl_exact, kl_bh, decimal=3)
+
+
+@skip_if_32bit
+def test_n_iter_without_progress():
+ # Use a dummy negative n_iter_without_progress and check output on stdout
+ random_state = check_random_state(0)
+ X = random_state.randn(100, 10)
+ for method in ["barnes_hut", "exact"]:
+ tsne = TSNE(
+ n_iter_without_progress=-1,
+ verbose=2,
+ learning_rate=1e8,
+ random_state=0,
+ method=method,
+ n_iter=351,
+ init="random",
+ )
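+ # Check convergence at every iteration and skip the early exaggeration phase.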
+ tsne._N_ITER_CHECK = 1
+ tsne._EXPLORATION_N_ITER = 0
+
+ old_stdout = sys.stdout
+ sys.stdout = StringIO()
+ try:
+ tsne.fit_transform(X)
+ finally:
+ out = sys.stdout.getvalue()
+ sys.stdout.close()
+ sys.stdout = old_stdout
+
+ # The output needs to contain the value of n_iter_without_progress
+ assert "did not make any progress during the last -1 episodes. Finished." in out
+
+
+def test_min_grad_norm():
+ # Make sure that the parameter min_grad_norm is used correctly
+ random_state = check_random_state(0)
+ X = random_state.randn(100, 2)
+ min_grad_norm = 0.002
+ tsne = TSNE(min_grad_norm=min_grad_norm, verbose=2, random_state=0, method="exact")
+
+ old_stdout = sys.stdout
+ sys.stdout = StringIO()
+ try:
+ tsne.fit_transform(X)
+ finally:
+ out = sys.stdout.getvalue()
+ sys.stdout.close()
+ sys.stdout = old_stdout
+
+ lines_out = out.split("\n")
+
+ # extract the gradient norm from the verbose output
+ gradient_norm_values = []
+ for line in lines_out:
+ # Once the computation is Finished, only an old gradient norm value
+ # is repeated, which we do not need to store
+ if "Finished" in line:
+ break
+
+ start_grad_norm = line.find("gradient norm")
+ if start_grad_norm >= 0:
+ line = line[start_grad_norm:]
+ line = line.replace("gradient norm = ", "").split(" ")[0]
+ gradient_norm_values.append(float(line))
+
+ # Compute how often the gradient norm is smaller than min_grad_norm
+ gradient_norm_values = np.array(gradient_norm_values)
+ n_smaller_gradient_norms = len(
+ gradient_norm_values[gradient_norm_values <= min_grad_norm]
+ )
+
+ # The gradient norm can be smaller than min_grad_norm at most once,
+ # because the moment it becomes smaller, the optimization stops
+ assert n_smaller_gradient_norms <= 1
+
+
+def test_accessible_kl_divergence():
+ # Ensures that the accessible kl_divergence matches the computed value
+ random_state = check_random_state(0)
+ X = random_state.randn(50, 2)
+ tsne = TSNE(
+ n_iter_without_progress=2, verbose=2, random_state=0, method="exact", n_iter=500
+ )
+
+ old_stdout = sys.stdout
+ sys.stdout = StringIO()
+ try:
+ tsne.fit_transform(X)
+ finally:
+ out = sys.stdout.getvalue()
+ sys.stdout.close()
+ sys.stdout = old_stdout
+
+ # The output needs to contain the accessible kl_divergence as the error at
+ # the last iteration
+ for line in out.split("\n")[::-1]:
+ if "Iteration" in line:
+ _, _, error = line.partition("error = ")
+ if error:
+ error, _, _ = error.partition(",")
+ break
+ assert_almost_equal(tsne.kl_divergence_, float(error), decimal=5)
+
+
+@pytest.mark.parametrize("method", ["barnes_hut", "exact"])
+def test_uniform_grid(method):
+ """Make sure that TSNE can approximately recover a uniform 2D grid
+
+ Due to ties in distances between points in X_2d_grid, this test is platform
+ dependent for ``method='barnes_hut'`` due to numerical imprecision.
+
+ Also, t-SNE is not assured to converge to the right solution because a bad
+ initialization can lead to convergence to a bad local minimum (the
+ optimization problem is non-convex). To avoid breaking the test too often,
+ we re-run t-SNE from the final point when the convergence is not good
+ enough.
+ """
+ seeds = range(3)
+ n_iter = 500
+ for seed in seeds:
+ tsne = TSNE(
+ n_components=2,
+ init="random",
+ random_state=seed,
+ perplexity=50,
+ n_iter=n_iter,
+ method=method,
+ learning_rate="auto",
+ )
+ Y = tsne.fit_transform(X_2d_grid)
+
+ try_name = "{}_{}".format(method, seed)
+ try:
+ assert_uniform_grid(Y, try_name)
+ except AssertionError:
+ # If the test fails a first time, re-run with init=Y to see if
+ # this was caused by a bad initialization. Note that this will
+ # also run an early_exaggeration step.
+ try_name += ":rerun"
+ tsne.init = Y
+ Y = tsne.fit_transform(X_2d_grid)
+ assert_uniform_grid(Y, try_name)
+
+
+def assert_uniform_grid(Y, try_name=None):
+ # Ensure that the resulting embedding leads to approximately
+ # uniformly spaced points: the distance to the closest neighbors
+ # should be non-zero and approximately constant.
+ nn = NearestNeighbors(n_neighbors=1).fit(Y)
+ dist_to_nn = nn.kneighbors(return_distance=True)[0].ravel()
+ assert dist_to_nn.min() > 0.1
+
+ smallest_to_mean = dist_to_nn.min() / np.mean(dist_to_nn)
+ largest_to_mean = dist_to_nn.max() / np.mean(dist_to_nn)
+
+ assert smallest_to_mean > 0.5, try_name
+ assert largest_to_mean < 2, try_name
+
+
+def test_bh_match_exact():
+ # check that the ``barnes_hut`` method matches the exact one when
+ # ``angle = 0`` and ``perplexity > n_samples / 3``
+ random_state = check_random_state(0)
+ n_features = 10
+ X = random_state.randn(30, n_features).astype(np.float32)
+ X_embeddeds = {}
+ n_iter = {}
+ for method in ["exact", "barnes_hut"]:
+ tsne = TSNE(
+ n_components=2,
+ method=method,
+ learning_rate=1.0,
+ init="random",
+ random_state=0,
+ n_iter=251,
+ perplexity=29.5,
+ angle=0,
+ )
+ # Kill the early_exaggeration
+ tsne._EXPLORATION_N_ITER = 0
+ X_embeddeds[method] = tsne.fit_transform(X)
+ n_iter[method] = tsne.n_iter_
+
+ assert n_iter["exact"] == n_iter["barnes_hut"]
+ assert_allclose(X_embeddeds["exact"], X_embeddeds["barnes_hut"], rtol=1e-4)
+
+
+def test_gradient_bh_multithread_match_sequential():
+ # check that the bh gradient with different num_threads gives the same
+ # results
+
+ n_features = 10
+ n_samples = 30
+ n_components = 2
+ degrees_of_freedom = 1
+
+ angle = 3
+ perplexity = 5
+
+ random_state = check_random_state(0)
+ data = random_state.randn(n_samples, n_features).astype(np.float32)
+ params = random_state.randn(n_samples, n_components)
+
+ n_neighbors = n_samples - 1
+ distances_csr = (
+ NearestNeighbors()
+ .fit(data)
+ .kneighbors_graph(n_neighbors=n_neighbors, mode="distance")
+ )
+ P_bh = _joint_probabilities_nn(distances_csr, perplexity, verbose=0)
+ kl_sequential, grad_sequential = _kl_divergence_bh(
+ params,
+ P_bh,
+ degrees_of_freedom,
+ n_samples,
+ n_components,
+ angle=angle,
+ skip_num_points=0,
+ verbose=0,
+ num_threads=1,
+ )
+ for num_threads in [2, 4]:
+ kl_multithread, grad_multithread = _kl_divergence_bh(
+ params,
+ P_bh,
+ degrees_of_freedom,
+ n_samples,
+ n_components,
+ angle=angle,
+ skip_num_points=0,
+ verbose=0,
+ num_threads=num_threads,
+ )
+
+ assert_allclose(kl_multithread, kl_sequential, rtol=1e-6)
+ assert_allclose(grad_multithread, grad_sequential, rtol=1e-6)
+
+
+@pytest.mark.parametrize(
+ "metric, dist_func",
+ [("manhattan", manhattan_distances), ("cosine", cosine_distances)],
+)
+@pytest.mark.parametrize("method", ["barnes_hut", "exact"])
+def test_tsne_with_different_distance_metrics(metric, dist_func, method):
+ """Make sure that TSNE works for different distance metrics"""
+
+ if method == "barnes_hut" and metric == "manhattan":
+ # The distances computed by `manhattan_distances` differ slightly from those
+ # computed internally by NearestNeighbors via the PairwiseDistancesReduction
+ # Cython code. This in turn causes T-SNE to converge to a different
+ # solution, but this should not impact the qualitative results of either
+ # method.
+ # NOTE: it's probably not valid from a mathematical point of view to use the
+ # Manhattan distance for T-SNE...
+ # TODO: re-enable this test if/when `manhattan_distances` is refactored to
+ # reuse the same underlying Cython code as NearestNeighbors.
+ # For reference, see:
+ # https://github.com/scikit-learn/scikit-learn/pull/23865/files#r925721573
+ pytest.xfail(
+ "Distance computations are different for method == 'barnes_hut' and metric"
+ " == 'manhattan', but this is expected."
+ )
+
+ random_state = check_random_state(0)
+ n_components_original = 3
+ n_components_embedding = 2
+ X = random_state.randn(50, n_components_original).astype(np.float32)
+ X_transformed_tsne = TSNE(
+ metric=metric,
+ method=method,
+ n_components=n_components_embedding,
+ random_state=0,
+ n_iter=300,
+ init="random",
+ learning_rate="auto",
+ ).fit_transform(X)
+ X_transformed_tsne_precomputed = TSNE(
+ metric="precomputed",
+ method=method,
+ n_components=n_components_embedding,
+ random_state=0,
+ n_iter=300,
+ init="random",
+ learning_rate="auto",
+ ).fit_transform(dist_func(X))
+ assert_array_equal(X_transformed_tsne, X_transformed_tsne_precomputed)
+
+
+@pytest.mark.parametrize("method", ["exact", "barnes_hut"])
+def test_tsne_n_jobs(method):
+ """Make sure that the n_jobs parameter doesn't impact the output"""
+ random_state = check_random_state(0)
+ n_features = 10
+ X = random_state.randn(30, n_features)
+ X_tr_ref = TSNE(
+ n_components=2,
+ method=method,
+ perplexity=25.0,
+ angle=0,
+ n_jobs=1,
+ random_state=0,
+ init="random",
+ learning_rate="auto",
+ ).fit_transform(X)
+ X_tr = TSNE(
+ n_components=2,
+ method=method,
+ perplexity=25.0,
+ angle=0,
+ n_jobs=2,
+ random_state=0,
+ init="random",
+ learning_rate="auto",
+ ).fit_transform(X)
+
+ assert_allclose(X_tr_ref, X_tr)
+
+
+def test_tsne_with_mahalanobis_distance():
+ """Make sure that method_parameters works with mahalanobis distance."""
+ random_state = check_random_state(0)
+ n_samples, n_features = 300, 10
+ X = random_state.randn(n_samples, n_features)
+ default_params = {
+ "perplexity": 40,
+ "n_iter": 250,
+ "learning_rate": "auto",
+ "init": "random",
+ "n_components": 3,
+ "random_state": 0,
+ }
+
+ tsne = TSNE(metric="mahalanobis", **default_params)
+ msg = "Must provide either V or VI for Mahalanobis distance"
+ with pytest.raises(ValueError, match=msg):
+ tsne.fit_transform(X)
+
+ precomputed_X = squareform(pdist(X, metric="mahalanobis"), checks=True)
+ X_trans_expected = TSNE(metric="precomputed", **default_params).fit_transform(
+ precomputed_X
+ )
+
+ X_trans = TSNE(
+ metric="mahalanobis", metric_params={"V": np.cov(X.T)}, **default_params
+ ).fit_transform(X)
+ assert_allclose(X_trans, X_trans_expected)
+
+
+# FIXME: remove in 1.3 after deprecation of `square_distances`
+def test_tsne_deprecation_square_distances():
+ """Check that we raise a warning regarding the removal of
+ `square_distances`.
+
+ Also check that the parameter does not have any effect.
+ """
+ random_state = check_random_state(0)
+ X = random_state.randn(30, 10)
+ tsne = TSNE(
+ n_components=2,
+ init="pca",
+ learning_rate="auto",
+ perplexity=25.0,
+ angle=0,
+ n_jobs=1,
+ random_state=0,
+ square_distances=True,
+ )
+ warn_msg = (
+ "The parameter `square_distances` has not effect and will be removed in"
+ " version 1.3"
+ )
+ with pytest.warns(FutureWarning, match=warn_msg):
+ X_trans_1 = tsne.fit_transform(X)
+
+ tsne = TSNE(
+ n_components=2,
+ init="pca",
+ learning_rate="auto",
+ perplexity=25.0,
+ angle=0,
+ n_jobs=1,
+ random_state=0,
+ )
+ X_trans_2 = tsne.fit_transform(X)
+ assert_allclose(X_trans_1, X_trans_2)
+
+
+@pytest.mark.parametrize("perplexity", (20, 30))
+def test_tsne_perplexity_validation(perplexity):
+ """Make sure that perplexity > n_samples results in a ValueError"""
+
+ random_state = check_random_state(0)
+ X = random_state.randn(20, 2)
+ est = TSNE(
+ learning_rate="auto",
+ init="pca",
+ perplexity=perplexity,
+ random_state=random_state,
+ )
+ msg = "perplexity must be less than n_samples"
+ with pytest.raises(ValueError, match=msg):
+ est.fit_transform(X)
+
+
+def test_tsne_works_with_pandas_output():
+ """Make sure that TSNE works when the output is set to "pandas".
+
+ Non-regression test for gh-25365.
+ """
+ pytest.importorskip("modin.pandas")
+ with config_context(transform_output="pandas"):
+ arr = np.arange(35 * 4).reshape(35, 4)
+ TSNE(n_components=2).fit_transform(arr)
diff --git a/modin/pandas/test/interoperability/sklearn/metrics/test_classification.py b/modin/pandas/test/interoperability/sklearn/metrics/test_classification.py
new file mode 100644
index 00000000000..93d1d99a0c3
--- /dev/null
+++ b/modin/pandas/test/interoperability/sklearn/metrics/test_classification.py
@@ -0,0 +1,2644 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+# flake8: noqa
+
+from functools import partial
+from itertools import product
+from itertools import chain
+from itertools import permutations
+import warnings
+import re
+
+import numpy as np
+from scipy import linalg
+from scipy.stats import bernoulli
+import pytest
+
+from sklearn import datasets
+from sklearn import svm
+
+from sklearn.datasets import make_multilabel_classification
+from sklearn.preprocessing import label_binarize, LabelBinarizer
+from sklearn.utils.validation import check_random_state
+from sklearn.utils._testing import assert_almost_equal
+from sklearn.utils._testing import assert_array_equal
+from sklearn.utils._testing import assert_array_almost_equal
+from sklearn.utils._testing import assert_allclose
+from sklearn.utils._testing import assert_no_warnings
+from sklearn.utils._testing import ignore_warnings
+from sklearn.utils._mocking import MockDataFrame
+
+from sklearn.metrics import accuracy_score
+from sklearn.metrics import average_precision_score
+from sklearn.metrics import balanced_accuracy_score
+from sklearn.metrics import class_likelihood_ratios
+from sklearn.metrics import classification_report
+from sklearn.metrics import cohen_kappa_score
+from sklearn.metrics import confusion_matrix
+from sklearn.metrics import f1_score
+from sklearn.metrics import fbeta_score
+from sklearn.metrics import hamming_loss
+from sklearn.metrics import hinge_loss
+from sklearn.metrics import jaccard_score
+from sklearn.metrics import log_loss
+from sklearn.metrics import matthews_corrcoef
+from sklearn.metrics import precision_recall_fscore_support
+from sklearn.metrics import precision_score
+from sklearn.metrics import recall_score
+from sklearn.metrics import zero_one_loss
+from sklearn.metrics import brier_score_loss
+from sklearn.metrics import multilabel_confusion_matrix
+
+from sklearn.metrics._classification import _check_targets
+from sklearn.exceptions import UndefinedMetricWarning
+
+from scipy.spatial.distance import hamming as sp_hamming
+
+###############################################################################
+# Utilities for testing
+
+
+def make_prediction(dataset=None, binary=False):
+ """Make some classification predictions on a toy dataset using a SVC
+
+ If binary is True, restrict to a binary classification problem instead of a
+ multiclass classification problem.
+ """
+
+ if dataset is None:
+ # import some data to play with
+ dataset = datasets.load_iris()
+
+ X = dataset.data
+ y = dataset.target
+
+ if binary:
+ # restrict to a binary classification task
+ X, y = X[y < 2], y[y < 2]
+
+ n_samples, n_features = X.shape
+ p = np.arange(n_samples)
+
+ rng = check_random_state(37)
+ rng.shuffle(p)
+ X, y = X[p], y[p]
+ half = int(n_samples / 2)
+
+ # add noisy features to make the problem harder and avoid perfect results
+ rng = np.random.RandomState(0)
+ X = np.c_[X, rng.randn(n_samples, 200 * n_features)]
+
+ # run classifier, get class probabilities and label predictions
+ clf = svm.SVC(kernel="linear", probability=True, random_state=0)
+ probas_pred = clf.fit(X[:half], y[:half]).predict_proba(X[half:])
+
+ if binary:
+ # only interested in probabilities of the positive case
+ # XXX: do we really want a special API for the binary case?
+ probas_pred = probas_pred[:, 1]
+
+ y_pred = clf.predict(X[half:])
+ y_true = y[half:]
+ return y_true, y_pred, probas_pred
+
+
+###############################################################################
+# Tests
+
+
+def test_classification_report_dictionary_output():
+ # Test performance report with dictionary output
+ iris = datasets.load_iris()
+ y_true, y_pred, _ = make_prediction(dataset=iris, binary=False)
+
+ # print classification report with class names
+ expected_report = {
+ "setosa": {
+ "precision": 0.82608695652173914,
+ "recall": 0.79166666666666663,
+ "f1-score": 0.8085106382978724,
+ "support": 24,
+ },
+ "versicolor": {
+ "precision": 0.33333333333333331,
+ "recall": 0.096774193548387094,
+ "f1-score": 0.15000000000000002,
+ "support": 31,
+ },
+ "virginica": {
+ "precision": 0.41860465116279072,
+ "recall": 0.90000000000000002,
+ "f1-score": 0.57142857142857151,
+ "support": 20,
+ },
+ "macro avg": {
+ "f1-score": 0.5099797365754813,
+ "precision": 0.5260083136726211,
+ "recall": 0.596146953405018,
+ "support": 75,
+ },
+ "accuracy": 0.5333333333333333,
+ "weighted avg": {
+ "f1-score": 0.47310435663627154,
+ "precision": 0.5137535108414785,
+ "recall": 0.5333333333333333,
+ "support": 75,
+ },
+ }
+
+ report = classification_report(
+ y_true,
+ y_pred,
+ labels=np.arange(len(iris.target_names)),
+ target_names=iris.target_names,
+ output_dict=True,
+ )
+
+ # assert the 2 dicts are equal.
+ assert report.keys() == expected_report.keys()
+ for key in expected_report:
+ if key == "accuracy":
+ assert isinstance(report[key], float)
+ assert report[key] == expected_report[key]
+ else:
+ assert report[key].keys() == expected_report[key].keys()
+ for metric in expected_report[key]:
+ assert_almost_equal(expected_report[key][metric], report[key][metric])
+
+ assert type(expected_report["setosa"]["precision"]) == float
+ assert type(expected_report["macro avg"]["precision"]) == float
+ assert type(expected_report["setosa"]["support"]) == int
+ assert type(expected_report["macro avg"]["support"]) == int
+
+
+def test_classification_report_output_dict_empty_input():
+ report = classification_report(y_true=[], y_pred=[], output_dict=True)
+ expected_report = {
+ "accuracy": 0.0,
+ "macro avg": {
+ "f1-score": np.nan,
+ "precision": np.nan,
+ "recall": np.nan,
+ "support": 0,
+ },
+ "weighted avg": {
+ "f1-score": 0.0,
+ "precision": 0.0,
+ "recall": 0.0,
+ "support": 0,
+ },
+ }
+ assert isinstance(report, dict)
+ # assert the 2 dicts are equal.
+ assert report.keys() == expected_report.keys()
+ for key in expected_report:
+ if key == "accuracy":
+ assert isinstance(report[key], float)
+ assert report[key] == expected_report[key]
+ else:
+ assert report[key].keys() == expected_report[key].keys()
+ for metric in expected_report[key]:
+ assert_almost_equal(expected_report[key][metric], report[key][metric])
+
+
+@pytest.mark.parametrize("zero_division", ["warn", 0, 1])
+def test_classification_report_zero_division_warning(zero_division):
+ y_true, y_pred = ["a", "b", "c"], ["a", "b", "d"]
+ with warnings.catch_warnings(record=True) as record:
+ classification_report(
+ y_true, y_pred, zero_division=zero_division, output_dict=True
+ )
+ if zero_division == "warn":
+ assert len(record) > 1
+ for item in record:
+ msg = "Use `zero_division` parameter to control this behavior."
+ assert msg in str(item.message)
+ else:
+ assert not record
+
+
+def test_multilabel_accuracy_score_subset_accuracy():
+ # Dense label indicator matrix format
+ y1 = np.array([[0, 1, 1], [1, 0, 1]])
+ y2 = np.array([[0, 0, 1], [1, 0, 1]])
+
+ assert accuracy_score(y1, y2) == 0.5
+ assert accuracy_score(y1, y1) == 1
+ assert accuracy_score(y2, y2) == 1
+ assert accuracy_score(y2, np.logical_not(y2)) == 0
+ assert accuracy_score(y1, np.logical_not(y1)) == 0
+ assert accuracy_score(y1, np.zeros(y1.shape)) == 0
+ assert accuracy_score(y2, np.zeros(y1.shape)) == 0
+
+
+def test_precision_recall_f1_score_binary():
+ # Test Precision Recall and F1 Score for binary classification task
+ y_true, y_pred, _ = make_prediction(binary=True)
+
+ # detailed measures for each class
+ p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average=None)
+ assert_array_almost_equal(p, [0.73, 0.85], 2)
+ assert_array_almost_equal(r, [0.88, 0.68], 2)
+ assert_array_almost_equal(f, [0.80, 0.76], 2)
+ assert_array_equal(s, [25, 25])
+
+ # individual scoring function that can be used for grid search: in the
+ # binary class case the score is the value of the measure for the positive
+ # class (e.g. label == 1). This is deprecated for average != 'binary'.
+ for kwargs, my_assert in [
+ ({}, assert_no_warnings),
+ ({"average": "binary"}, assert_no_warnings),
+ ]:
+ ps = my_assert(precision_score, y_true, y_pred, **kwargs)
+ assert_array_almost_equal(ps, 0.85, 2)
+
+ rs = my_assert(recall_score, y_true, y_pred, **kwargs)
+ assert_array_almost_equal(rs, 0.68, 2)
+
+ fs = my_assert(f1_score, y_true, y_pred, **kwargs)
+ assert_array_almost_equal(fs, 0.76, 2)
+
+ assert_almost_equal(
+ my_assert(fbeta_score, y_true, y_pred, beta=2, **kwargs),
+ (1 + 2**2) * ps * rs / (2**2 * ps + rs),
+ 2,
+ )
+
+
+@ignore_warnings
+def test_precision_recall_f_binary_single_class():
+ # Test precision, recall and F-scores behave with a single positive or
+ # negative class
+ # Such a case may occur with non-stratified cross-validation
+ assert 1.0 == precision_score([1, 1], [1, 1])
+ assert 1.0 == recall_score([1, 1], [1, 1])
+ assert 1.0 == f1_score([1, 1], [1, 1])
+ assert 1.0 == fbeta_score([1, 1], [1, 1], beta=0)
+
+ assert 0.0 == precision_score([-1, -1], [-1, -1])
+ assert 0.0 == recall_score([-1, -1], [-1, -1])
+ assert 0.0 == f1_score([-1, -1], [-1, -1])
+ assert 0.0 == fbeta_score([-1, -1], [-1, -1], beta=float("inf"))
+ assert fbeta_score([-1, -1], [-1, -1], beta=float("inf")) == pytest.approx(
+ fbeta_score([-1, -1], [-1, -1], beta=1e5)
+ )
+
+
+@ignore_warnings
+def test_precision_recall_f_extra_labels():
+ # Test handling of explicit additional (not in input) labels to PRF
+ y_true = [1, 3, 3, 2]
+ y_pred = [1, 1, 3, 2]
+ y_true_bin = label_binarize(y_true, classes=np.arange(5))
+ y_pred_bin = label_binarize(y_pred, classes=np.arange(5))
+ data = [(y_true, y_pred), (y_true_bin, y_pred_bin)]
+
+ for i, (y_true, y_pred) in enumerate(data):
+ # No average: zeros in array
+ actual = recall_score(y_true, y_pred, labels=[0, 1, 2, 3, 4], average=None)
+ assert_array_almost_equal([0.0, 1.0, 1.0, 0.5, 0.0], actual)
+
+ # Macro average is changed
+ actual = recall_score(y_true, y_pred, labels=[0, 1, 2, 3, 4], average="macro")
+ assert_array_almost_equal(np.mean([0.0, 1.0, 1.0, 0.5, 0.0]), actual)
+
+ # No effect otherwise
+ for average in ["micro", "weighted", "samples"]:
+ if average == "samples" and i == 0:
+ continue
+ assert_almost_equal(
+ recall_score(y_true, y_pred, labels=[0, 1, 2, 3, 4], average=average),
+ recall_score(y_true, y_pred, labels=None, average=average),
+ )
+
+ # Error when introducing invalid label in multilabel case
+ # (although it would only affect performance if average='macro'/None)
+ for average in [None, "macro", "micro", "samples"]:
+ with pytest.raises(ValueError):
+ recall_score(y_true_bin, y_pred_bin, labels=np.arange(6), average=average)
+ with pytest.raises(ValueError):
+ recall_score(
+ y_true_bin, y_pred_bin, labels=np.arange(-1, 4), average=average
+ )
+
+ # tests non-regression on issue #10307
+ y_true = np.array([[0, 1, 1], [1, 0, 0]])
+ y_pred = np.array([[1, 1, 1], [1, 0, 1]])
+ p, r, f, _ = precision_recall_fscore_support(
+ y_true, y_pred, average="samples", labels=[0, 1]
+ )
+ assert_almost_equal(np.array([p, r, f]), np.array([3 / 4, 1, 5 / 6]))
+
+
+@ignore_warnings
+def test_precision_recall_f_ignored_labels():
+ # Test a subset of labels may be requested for PRF
+ y_true = [1, 1, 2, 3]
+ y_pred = [1, 3, 3, 3]
+ y_true_bin = label_binarize(y_true, classes=np.arange(5))
+ y_pred_bin = label_binarize(y_pred, classes=np.arange(5))
+ data = [(y_true, y_pred), (y_true_bin, y_pred_bin)]
+
+ for i, (y_true, y_pred) in enumerate(data):
+ recall_13 = partial(recall_score, y_true, y_pred, labels=[1, 3])
+ recall_all = partial(recall_score, y_true, y_pred, labels=None)
+
+ assert_array_almost_equal([0.5, 1.0], recall_13(average=None))
+ assert_almost_equal((0.5 + 1.0) / 2, recall_13(average="macro"))
+ assert_almost_equal((0.5 * 2 + 1.0 * 1) / 3, recall_13(average="weighted"))
+ assert_almost_equal(2.0 / 3, recall_13(average="micro"))
+
+ # ensure the above were meaningful tests:
+ for average in ["macro", "weighted", "micro"]:
+ assert recall_13(average=average) != recall_all(average=average)
+
+
+def test_average_precision_score_score_non_binary_class():
+ # Test that average_precision_score function returns an error when trying
+ # to compute average_precision_score for multiclass task.
+ rng = check_random_state(404)
+ y_pred = rng.rand(10)
+
+ # y_true contains three different class values
+ y_true = rng.randint(0, 3, size=10)
+ err_msg = "multiclass format is not supported"
+ with pytest.raises(ValueError, match=err_msg):
+ average_precision_score(y_true, y_pred)
+
+
+def test_average_precision_score_duplicate_values():
+ # Duplicate values with precision-recall require different
+ # processing than when computing the AUC of a ROC, because the
+ # precision-recall curve is a decreasing curve.
+ # The following situation corresponds to a perfect
+ # test statistic; the average_precision_score should be 1.
+ y_true = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
+ y_score = [0, 0.1, 0.1, 0.4, 0.5, 0.6, 0.6, 0.9, 0.9, 1, 1]
+ assert average_precision_score(y_true, y_score) == 1
+
+
+def test_average_precision_score_tied_values():
+ # Here, if we go from left to right in y_true, the 0 values are
+ # separated from the 1 values, so it appears that we've
+ # correctly sorted our classifications. But in fact the first two
+ # values have the same score (0.5) and so the first two values
+ # could be swapped around, creating an imperfect sorting. This
+ # imperfection should come through in the end score, making it less
+ # than one.
+ y_true = [0, 1, 1]
+ y_score = [0.5, 0.5, 0.6]
+ assert average_precision_score(y_true, y_score) != 1.0
+
+
+@ignore_warnings
+def test_precision_recall_fscore_support_errors():
+ y_true, y_pred, _ = make_prediction(binary=True)
+
+ # Bad beta
+ with pytest.raises(ValueError):
+ precision_recall_fscore_support(y_true, y_pred, beta=-0.1)
+
+ # Bad pos_label
+ with pytest.raises(ValueError):
+ precision_recall_fscore_support(y_true, y_pred, pos_label=2, average="binary")
+
+ # Bad average option
+ with pytest.raises(ValueError):
+ precision_recall_fscore_support([0, 1, 2], [1, 2, 0], average="mega")
+
+
+def test_precision_recall_f_unused_pos_label():
+ # Check warning that pos_label unused when set to non-default value
+ # but average != 'binary'; even if data is binary.
+
+ msg = (
+ r"Note that pos_label \(set to 2\) is "
+ r"ignored when average != 'binary' \(got 'macro'\). You "
+ r"may use labels=\[pos_label\] to specify a single "
+ "positive class."
+ )
+ with pytest.warns(UserWarning, match=msg):
+ precision_recall_fscore_support(
+ [1, 2, 1], [1, 2, 2], pos_label=2, average="macro"
+ )
+
+
+def test_confusion_matrix_binary():
+ # Test confusion matrix - binary classification case
+ y_true, y_pred, _ = make_prediction(binary=True)
+
+ def test(y_true, y_pred):
+ cm = confusion_matrix(y_true, y_pred)
+ assert_array_equal(cm, [[22, 3], [8, 17]])
+
+ tp, fp, fn, tn = cm.flatten()
+ num = tp * tn - fp * fn
+ den = np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
+
+ true_mcc = 0 if den == 0 else num / den
+ mcc = matthews_corrcoef(y_true, y_pred)
+ assert_array_almost_equal(mcc, true_mcc, decimal=2)
+ assert_array_almost_equal(mcc, 0.57, decimal=2)
+
+ test(y_true, y_pred)
+ test([str(y) for y in y_true], [str(y) for y in y_pred])
+
+
+def test_multilabel_confusion_matrix_binary():
+ # Test multilabel confusion matrix - binary classification case
+ y_true, y_pred, _ = make_prediction(binary=True)
+
+ def test(y_true, y_pred):
+ cm = multilabel_confusion_matrix(y_true, y_pred)
+ assert_array_equal(cm, [[[17, 8], [3, 22]], [[22, 3], [8, 17]]])
+
+ test(y_true, y_pred)
+ test([str(y) for y in y_true], [str(y) for y in y_pred])
+
+
+def test_multilabel_confusion_matrix_multiclass():
+ # Test multilabel confusion matrix - multi-class case
+ y_true, y_pred, _ = make_prediction(binary=False)
+
+ def test(y_true, y_pred, string_type=False):
+ # compute confusion matrix with default labels introspection
+ cm = multilabel_confusion_matrix(y_true, y_pred)
+ assert_array_equal(
+ cm, [[[47, 4], [5, 19]], [[38, 6], [28, 3]], [[30, 25], [2, 18]]]
+ )
+
+ # compute confusion matrix with explicit label ordering
+ labels = ["0", "2", "1"] if string_type else [0, 2, 1]
+ cm = multilabel_confusion_matrix(y_true, y_pred, labels=labels)
+ assert_array_equal(
+ cm, [[[47, 4], [5, 19]], [[30, 25], [2, 18]], [[38, 6], [28, 3]]]
+ )
+
+ # compute confusion matrix with super set of present labels
+ labels = ["0", "2", "1", "3"] if string_type else [0, 2, 1, 3]
+ cm = multilabel_confusion_matrix(y_true, y_pred, labels=labels)
+ assert_array_equal(
+ cm,
+ [
+ [[47, 4], [5, 19]],
+ [[30, 25], [2, 18]],
+ [[38, 6], [28, 3]],
+ [[75, 0], [0, 0]],
+ ],
+ )
+
+ test(y_true, y_pred)
+ test([str(y) for y in y_true], [str(y) for y in y_pred], string_type=True)
+
+
+def test_multilabel_confusion_matrix_multilabel():
+ # Test multilabel confusion matrix - multilabel-indicator case
+ from scipy.sparse import csc_matrix, csr_matrix
+
+ y_true = np.array([[1, 0, 1], [0, 1, 0], [1, 1, 0]])
+ y_pred = np.array([[1, 0, 0], [0, 1, 1], [0, 0, 1]])
+ y_true_csr = csr_matrix(y_true)
+ y_pred_csr = csr_matrix(y_pred)
+ y_true_csc = csc_matrix(y_true)
+ y_pred_csc = csc_matrix(y_pred)
+
+ # cross test different types
+ sample_weight = np.array([2, 1, 3])
+ real_cm = [[[1, 0], [1, 1]], [[1, 0], [1, 1]], [[0, 2], [1, 0]]]
+ trues = [y_true, y_true_csr, y_true_csc]
+ preds = [y_pred, y_pred_csr, y_pred_csc]
+
+ for y_true_tmp in trues:
+ for y_pred_tmp in preds:
+ cm = multilabel_confusion_matrix(y_true_tmp, y_pred_tmp)
+ assert_array_equal(cm, real_cm)
+
+ # test support for samplewise
+ cm = multilabel_confusion_matrix(y_true, y_pred, samplewise=True)
+ assert_array_equal(cm, [[[1, 0], [1, 1]], [[1, 1], [0, 1]], [[0, 1], [2, 0]]])
+
+ # test support for labels
+ cm = multilabel_confusion_matrix(y_true, y_pred, labels=[2, 0])
+ assert_array_equal(cm, [[[0, 2], [1, 0]], [[1, 0], [1, 1]]])
+
+ # test support for labels with samplewise
+ cm = multilabel_confusion_matrix(y_true, y_pred, labels=[2, 0], samplewise=True)
+ assert_array_equal(cm, [[[0, 0], [1, 1]], [[1, 1], [0, 0]], [[0, 1], [1, 0]]])
+
+ # test support for sample_weight with samplewise
+ cm = multilabel_confusion_matrix(
+ y_true, y_pred, sample_weight=sample_weight, samplewise=True
+ )
+ assert_array_equal(cm, [[[2, 0], [2, 2]], [[1, 1], [0, 1]], [[0, 3], [6, 0]]])
+
+
+def test_multilabel_confusion_matrix_errors():
+ y_true = np.array([[1, 0, 1], [0, 1, 0], [1, 1, 0]])
+ y_pred = np.array([[1, 0, 0], [0, 1, 1], [0, 0, 1]])
+
+ # Bad sample_weight
+ with pytest.raises(ValueError, match="inconsistent numbers of samples"):
+ multilabel_confusion_matrix(y_true, y_pred, sample_weight=[1, 2])
+ with pytest.raises(ValueError, match="should be a 1d array"):
+ multilabel_confusion_matrix(
+ y_true, y_pred, sample_weight=[[1, 2, 3], [2, 3, 4], [3, 4, 5]]
+ )
+
+ # Bad labels
+ err_msg = r"All labels must be in \[0, n labels\)"
+ with pytest.raises(ValueError, match=err_msg):
+ multilabel_confusion_matrix(y_true, y_pred, labels=[-1])
+ err_msg = r"All labels must be in \[0, n labels\)"
+ with pytest.raises(ValueError, match=err_msg):
+ multilabel_confusion_matrix(y_true, y_pred, labels=[3])
+
+ # Using samplewise outside multilabel
+ with pytest.raises(ValueError, match="Samplewise metrics"):
+ multilabel_confusion_matrix([0, 1, 2], [1, 2, 0], samplewise=True)
+
+ # Bad y_type
+ err_msg = "multiclass-multioutput is not supported"
+ with pytest.raises(ValueError, match=err_msg):
+ multilabel_confusion_matrix([[0, 1, 2], [2, 1, 0]], [[1, 2, 0], [1, 0, 2]])
+
+
+@pytest.mark.parametrize(
+ "normalize, cm_dtype, expected_results",
+ [
+ ("true", "f", 0.333333333),
+ ("pred", "f", 0.333333333),
+ ("all", "f", 0.1111111111),
+ (None, "i", 2),
+ ],
+)
+def test_confusion_matrix_normalize(normalize, cm_dtype, expected_results):
+ y_test = [0, 1, 2] * 6
+ y_pred = list(chain(*permutations([0, 1, 2])))
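+ # Note (added for clarity): every (true, predicted) pair occurs exactly
+ # twice here, so the raw confusion matrix is all 2s; dividing by the row
+ # sums ("true"), column sums ("pred") or the grand total ("all") gives
+ # 2/6 = 1/3, 2/6 = 1/3 and 2/18 = 1/9, matching the parametrized values.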
+ cm = confusion_matrix(y_test, y_pred, normalize=normalize)
+ assert_allclose(cm, expected_results)
+ assert cm.dtype.kind == cm_dtype
+
+
+def test_confusion_matrix_normalize_single_class():
+ y_test = [0, 0, 0, 0, 1, 1, 1, 1]
+ y_pred = [0, 0, 0, 0, 0, 0, 0, 0]
+
+ cm_true = confusion_matrix(y_test, y_pred, normalize="true")
+ assert cm_true.sum() == pytest.approx(2.0)
+
+ # additionally check that no warnings are raised due to a division by zero
+ with warnings.catch_warnings():
+ warnings.simplefilter("error", RuntimeWarning)
+ cm_pred = confusion_matrix(y_test, y_pred, normalize="pred")
+
+ assert cm_pred.sum() == pytest.approx(1.0)
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("error", RuntimeWarning)
+ confusion_matrix(y_pred, y_test, normalize="true")
+
+
+@pytest.mark.parametrize(
+ "params, warn_msg",
+ [
+ # When y_test contains one class only and y_test==y_pred, LR+ is undefined
+ (
+ {
+ "y_true": np.array([0, 0, 0, 0, 0, 0]),
+ "y_pred": np.array([0, 0, 0, 0, 0, 0]),
+ },
+ "samples of only one class were seen during testing",
+ ),
+ # When `fp == 0` and `tp != 0`, LR+ is undefined
+ (
+ {
+ "y_true": np.array([1, 1, 1, 0, 0, 0]),
+ "y_pred": np.array([1, 1, 1, 0, 0, 0]),
+ },
+ "positive_likelihood_ratio ill-defined and being set to nan",
+ ),
+ # When `fp == 0` and `tp == 0`, LR+ is undefined
+ (
+ {
+ "y_true": np.array([1, 1, 1, 0, 0, 0]),
+ "y_pred": np.array([0, 0, 0, 0, 0, 0]),
+ },
+ "no samples predicted for the positive class",
+ ),
+ # When `tn == 0`, LR- is undefined
+ (
+ {
+ "y_true": np.array([1, 1, 1, 0, 0, 0]),
+ "y_pred": np.array([0, 0, 0, 1, 1, 1]),
+ },
+ "negative_likelihood_ratio ill-defined and being set to nan",
+ ),
+ # When `tp + fn == 0` both ratios are undefined
+ (
+ {
+ "y_true": np.array([0, 0, 0, 0, 0, 0]),
+ "y_pred": np.array([1, 1, 1, 0, 0, 0]),
+ },
+ "no samples of the positive class were present in the testing set",
+ ),
+ ],
+)
+def test_likelihood_ratios_warnings(params, warn_msg):
+ # likelihood_ratios must raise warnings when at
+ # least one of the ratios is ill-defined.
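+ # As a reminder (not from the original comments): LR+ is
+ # sensitivity / (1 - specificity) and LR- is (1 - sensitivity) / specificity,
+ # so fp == 0, tn == 0 or tp + fn == 0 each makes one (or both) of the
+ # ratios ill-defined, which is what the parametrized cases above exercise.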
+
+ with pytest.warns(UserWarning, match=warn_msg):
+ class_likelihood_ratios(**params)
+
+
+@pytest.mark.parametrize(
+ "params, err_msg",
+ [
+ (
+ {
+ "y_true": np.array([0, 1, 0, 1, 0]),
+ "y_pred": np.array([1, 1, 0, 0, 2]),
+ },
+ "class_likelihood_ratios only supports binary classification "
+ "problems, got targets of type: multiclass",
+ ),
+ ],
+)
+def test_likelihood_ratios_errors(params, err_msg):
+ # class_likelihood_ratios must raise an error when given
+ # non-binary classes to avoid Simpson's paradox
+ with pytest.raises(ValueError, match=err_msg):
+ class_likelihood_ratios(**params)
+
+
+def test_likelihood_ratios():
+ # Build confusion matrix with tn=9, fp=8, fn=1, tp=2,
+ # sensitivity=2/3, specificity=9/17, prevalence=3/20,
+ # LR+=34/24, LR-=17/27
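+ # Hand derivation of the expected values (not part of the original
+ # comment): LR+ = sensitivity / (1 - specificity) = (2/3) / (8/17) = 34/24
+ # and LR- = (1 - sensitivity) / specificity = (1/3) / (9/17) = 17/27.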
+ y_true = np.array([1] * 3 + [0] * 17)
+ y_pred = np.array([1] * 2 + [0] * 10 + [1] * 8)
+
+ pos, neg = class_likelihood_ratios(y_true, y_pred)
+ assert_allclose(pos, 34 / 24)
+ assert_allclose(neg, 17 / 27)
+
+ # Build limit case with y_pred = y_true
+ pos, neg = class_likelihood_ratios(y_true, y_true)
+ assert_array_equal(pos, np.nan * 2)
+ assert_allclose(neg, np.zeros(2), rtol=1e-12)
+
+ # Ignore last 5 samples to get tn=9, fp=3, fn=1, tp=2,
+ # sensitivity=2/3, specificity=9/12, prevalence=3/20,
+ # LR+=24/9, LR-=12/27
+ sample_weight = np.array([1.0] * 15 + [0.0] * 5)
+ pos, neg = class_likelihood_ratios(y_true, y_pred, sample_weight=sample_weight)
+ assert_allclose(pos, 24 / 9)
+ assert_allclose(neg, 12 / 27)
+
+
+def test_cohen_kappa():
+ # These label vectors reproduce the contingency matrix from Artstein and
+ # Poesio (2008), Table 1: np.array([[20, 20], [10, 50]]).
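+ # Rough hand check (not in the original test): observed agreement is
+ # (20 + 50) / 100 = 0.70, chance agreement is (40 * 30 + 60 * 70) / 100**2
+ # = 0.54, so kappa = (0.70 - 0.54) / (1 - 0.54) ~= 0.348.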
+ y1 = np.array([0] * 40 + [1] * 60)
+ y2 = np.array([0] * 20 + [1] * 20 + [0] * 10 + [1] * 50)
+ kappa = cohen_kappa_score(y1, y2)
+ assert_almost_equal(kappa, 0.348, decimal=3)
+ assert kappa == cohen_kappa_score(y2, y1)
+
+ # Add spurious labels and ignore them.
+ y1 = np.append(y1, [2] * 4)
+ y2 = np.append(y2, [2] * 4)
+ assert cohen_kappa_score(y1, y2, labels=[0, 1]) == kappa
+
+ assert_almost_equal(cohen_kappa_score(y1, y1), 1.0)
+
+ # Multiclass example: Artstein and Poesio, Table 4.
+ y1 = np.array([0] * 46 + [1] * 44 + [2] * 10)
+ y2 = np.array([0] * 52 + [1] * 32 + [2] * 16)
+ assert_almost_equal(cohen_kappa_score(y1, y2), 0.8013, decimal=4)
+
+ # Weighting example: none, linear, quadratic.
+ y1 = np.array([0] * 46 + [1] * 44 + [2] * 10)
+ y2 = np.array([0] * 50 + [1] * 40 + [2] * 10)
+ assert_almost_equal(cohen_kappa_score(y1, y2), 0.9315, decimal=4)
+ assert_almost_equal(cohen_kappa_score(y1, y2, weights="linear"), 0.9412, decimal=4)
+ assert_almost_equal(
+ cohen_kappa_score(y1, y2, weights="quadratic"), 0.9541, decimal=4
+ )
+
+
+def test_matthews_corrcoef_nan():
+ assert matthews_corrcoef([0], [1]) == 0.0
+ assert matthews_corrcoef([0, 0], [0, 1]) == 0.0
+
+
+def test_matthews_corrcoef_against_numpy_corrcoef():
+ rng = np.random.RandomState(0)
+ y_true = rng.randint(0, 2, size=20)
+ y_pred = rng.randint(0, 2, size=20)
+
+ assert_almost_equal(
+ matthews_corrcoef(y_true, y_pred), np.corrcoef(y_true, y_pred)[0, 1], 10
+ )
+
+
+def test_matthews_corrcoef_against_jurman():
+ # Check that the multiclass matthews_corrcoef agrees with the definition
+ # presented in Jurman, Riccadonna, Furlanello, (2012). A Comparison of MCC
+ # and CEN Error Measures in MultiClass Prediction
+ rng = np.random.RandomState(0)
+ y_true = rng.randint(0, 2, size=20)
+ y_pred = rng.randint(0, 2, size=20)
+ sample_weight = rng.rand(20)
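+ # The reference value below follows the covariance form of the MCC,
+ # mcc = cov(y_true, y_pred) / sqrt(cov(y_true, y_true) * cov(y_pred, y_pred)),
+ # with each covariance expanded in terms of confusion-matrix entries as in
+ # the Jurman et al. paper cited above.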
+
+ C = confusion_matrix(y_true, y_pred, sample_weight=sample_weight)
+ N = len(C)
+ cov_ytyp = sum(
+ [
+ C[k, k] * C[m, l] - C[l, k] * C[k, m]
+ for k in range(N)
+ for m in range(N)
+ for l in range(N)
+ ]
+ )
+ cov_ytyt = sum(
+ [
+ C[:, k].sum()
+ * np.sum([C[g, f] for f in range(N) for g in range(N) if f != k])
+ for k in range(N)
+ ]
+ )
+ cov_ypyp = np.sum(
+ [
+ C[k, :].sum()
+ * np.sum([C[f, g] for f in range(N) for g in range(N) if f != k])
+ for k in range(N)
+ ]
+ )
+ mcc_jurman = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)
+ mcc_ours = matthews_corrcoef(y_true, y_pred, sample_weight=sample_weight)
+
+ assert_almost_equal(mcc_ours, mcc_jurman, 10)
+
+
+def test_matthews_corrcoef():
+ rng = np.random.RandomState(0)
+ y_true = ["a" if i == 0 else "b" for i in rng.randint(0, 2, size=20)]
+
+ # corrcoef of same vectors must be 1
+ assert_almost_equal(matthews_corrcoef(y_true, y_true), 1.0)
+
+ # corrcoef, when the two vectors are opposites of each other, should be -1
+ y_true_inv = ["b" if i == "a" else "a" for i in y_true]
+ assert_almost_equal(matthews_corrcoef(y_true, y_true_inv), -1)
+
+ y_true_inv2 = label_binarize(y_true, classes=["a", "b"])
+ y_true_inv2 = np.where(y_true_inv2, "a", "b")
+ assert_almost_equal(matthews_corrcoef(y_true, y_true_inv2), -1)
+
+ # For the zero vector case, the corrcoef cannot be calculated and should
+ # output 0
+ assert_almost_equal(matthews_corrcoef([0, 0, 0, 0], [0, 0, 0, 0]), 0.0)
+
+ # And also for any other vector with 0 variance
+ assert_almost_equal(matthews_corrcoef(y_true, ["a"] * len(y_true)), 0.0)
+
+ # These two vectors have 0 correlation and hence mcc should be 0
+ y_1 = [1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1]
+ y_2 = [1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1]
+ assert_almost_equal(matthews_corrcoef(y_1, y_2), 0.0)
+
+ # Check that sample weight is able to selectively exclude
+ mask = [1] * 10 + [0] * 10
+ # Now only the first half of the vector elements are given a weight of 1
+ # and hence the mcc will not be a perfect 0 as in the previous case
+ with pytest.raises(AssertionError):
+ assert_almost_equal(matthews_corrcoef(y_1, y_2, sample_weight=mask), 0.0)
+
+
+def test_matthews_corrcoef_multiclass():
+ rng = np.random.RandomState(0)
+ ord_a = ord("a")
+ n_classes = 4
+ y_true = [chr(ord_a + i) for i in rng.randint(0, n_classes, size=20)]
+
+ # corrcoef of same vectors must be 1
+ assert_almost_equal(matthews_corrcoef(y_true, y_true), 1.0)
+
+ # with multiclass > 2 it is not possible to achieve -1
+ y_true = [0, 0, 1, 1, 2, 2]
+ y_pred_bad = [2, 2, 0, 0, 1, 1]
+ assert_almost_equal(matthews_corrcoef(y_true, y_pred_bad), -0.5)
+
+ # Maximizing false positives and negatives minimizes the MCC
+ # The minimum will be different depending on the input
+ y_true = [0, 0, 1, 1, 2, 2]
+ y_pred_min = [1, 1, 0, 0, 0, 0]
+ assert_almost_equal(matthews_corrcoef(y_true, y_pred_min), -12 / np.sqrt(24 * 16))
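+ # Hand check (added, not asserted separately): with 0 correct predictions,
+ # 6 samples, true counts [2, 2, 2] and predicted counts [4, 2, 0], the
+ # multiclass MCC is (0 * 6 - 12) / sqrt((36 - 20) * (36 - 12)), i.e.
+ # -12 / sqrt(16 * 24), matching the expected value above.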
+
+ # Zero variance will result in an mcc of zero
+ y_true = [0, 1, 2]
+ y_pred = [3, 3, 3]
+ assert_almost_equal(matthews_corrcoef(y_true, y_pred), 0.0)
+
+ # Also for ground truth with zero variance
+ y_true = [3, 3, 3]
+ y_pred = [0, 1, 2]
+ assert_almost_equal(matthews_corrcoef(y_true, y_pred), 0.0)
+
+ # These two vectors have 0 correlation and hence mcc should be 0
+ y_1 = [0, 1, 2, 0, 1, 2, 0, 1, 2]
+ y_2 = [1, 1, 1, 2, 2, 2, 0, 0, 0]
+ assert_almost_equal(matthews_corrcoef(y_1, y_2), 0.0)
+
+ # We can test that binary assumptions hold using the multiclass computation
+ # by masking the weight of samples not in the first two classes
+
+ # Masking the last label should let us get an MCC of -1
+ y_true = [0, 0, 1, 1, 2]
+ y_pred = [1, 1, 0, 0, 2]
+ sample_weight = [1, 1, 1, 1, 0]
+ assert_almost_equal(
+ matthews_corrcoef(y_true, y_pred, sample_weight=sample_weight), -1
+ )
+
+ # For the zero vector case, the corrcoef cannot be calculated and should
+ # output 0
+ y_true = [0, 0, 1, 2]
+ y_pred = [0, 0, 1, 2]
+ sample_weight = [1, 1, 0, 0]
+ assert_almost_equal(
+ matthews_corrcoef(y_true, y_pred, sample_weight=sample_weight), 0.0
+ )
+
+
+@pytest.mark.parametrize("n_points", [100, 10000])
+def test_matthews_corrcoef_overflow(n_points):
+ # https://github.com/scikit-learn/scikit-learn/issues/9622
+ rng = np.random.RandomState(20170906)
+
+ def mcc_safe(y_true, y_pred):
+ conf_matrix = confusion_matrix(y_true, y_pred)
+ true_pos = conf_matrix[1, 1]
+ false_pos = conf_matrix[1, 0]
+ false_neg = conf_matrix[0, 1]
+ n_points = len(y_true)
+ pos_rate = (true_pos + false_neg) / n_points
+ activity = (true_pos + false_pos) / n_points
+ mcc_numerator = true_pos / n_points - pos_rate * activity
+ mcc_denominator = activity * pos_rate * (1 - activity) * (1 - pos_rate)
+ return mcc_numerator / np.sqrt(mcc_denominator)
+
+ def random_ys(n_points): # binary
+ x_true = rng.random_sample(n_points)
+ x_pred = x_true + 0.2 * (rng.random_sample(n_points) - 0.5)
+ y_true = x_true > 0.5
+ y_pred = x_pred > 0.5
+ return y_true, y_pred
+
+ arr = np.repeat([0.0, 1.0], n_points) # binary
+ assert_almost_equal(matthews_corrcoef(arr, arr), 1.0)
+ arr = np.repeat([0.0, 1.0, 2.0], n_points) # multiclass
+ assert_almost_equal(matthews_corrcoef(arr, arr), 1.0)
+
+ y_true, y_pred = random_ys(n_points)
+ assert_almost_equal(matthews_corrcoef(y_true, y_true), 1.0)
+ assert_almost_equal(matthews_corrcoef(y_true, y_pred), mcc_safe(y_true, y_pred))
+
+
+def test_precision_recall_f1_score_multiclass():
+ # Test Precision Recall and F1 Score for multiclass classification task
+ y_true, y_pred, _ = make_prediction(binary=False)
+
+ # compute scores with default labels introspection
+ p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average=None)
+ assert_array_almost_equal(p, [0.83, 0.33, 0.42], 2)
+ assert_array_almost_equal(r, [0.79, 0.09, 0.90], 2)
+ assert_array_almost_equal(f, [0.81, 0.15, 0.57], 2)
+ assert_array_equal(s, [24, 31, 20])
+
+ # averaging tests
+ ps = precision_score(y_true, y_pred, pos_label=1, average="micro")
+ assert_array_almost_equal(ps, 0.53, 2)
+
+ rs = recall_score(y_true, y_pred, average="micro")
+ assert_array_almost_equal(rs, 0.53, 2)
+
+ fs = f1_score(y_true, y_pred, average="micro")
+ assert_array_almost_equal(fs, 0.53, 2)
+
+ ps = precision_score(y_true, y_pred, average="macro")
+ assert_array_almost_equal(ps, 0.53, 2)
+
+ rs = recall_score(y_true, y_pred, average="macro")
+ assert_array_almost_equal(rs, 0.60, 2)
+
+ fs = f1_score(y_true, y_pred, average="macro")
+ assert_array_almost_equal(fs, 0.51, 2)
+
+ ps = precision_score(y_true, y_pred, average="weighted")
+ assert_array_almost_equal(ps, 0.51, 2)
+
+ rs = recall_score(y_true, y_pred, average="weighted")
+ assert_array_almost_equal(rs, 0.53, 2)
+
+ fs = f1_score(y_true, y_pred, average="weighted")
+ assert_array_almost_equal(fs, 0.47, 2)
+
+ with pytest.raises(ValueError):
+ precision_score(y_true, y_pred, average="samples")
+ with pytest.raises(ValueError):
+ recall_score(y_true, y_pred, average="samples")
+ with pytest.raises(ValueError):
+ f1_score(y_true, y_pred, average="samples")
+ with pytest.raises(ValueError):
+ fbeta_score(y_true, y_pred, average="samples", beta=0.5)
+
+ # same prediction but with an explicit label ordering
+ p, r, f, s = precision_recall_fscore_support(
+ y_true, y_pred, labels=[0, 2, 1], average=None
+ )
+ assert_array_almost_equal(p, [0.83, 0.41, 0.33], 2)
+ assert_array_almost_equal(r, [0.79, 0.90, 0.10], 2)
+ assert_array_almost_equal(f, [0.81, 0.57, 0.15], 2)
+ assert_array_equal(s, [24, 20, 31])
+
+
+@pytest.mark.parametrize("average", ["samples", "micro", "macro", "weighted", None])
+def test_precision_recall_f1_score_multilabel_unordered_labels(average):
+ # test that labels need not be sorted in the multilabel case
+ y_true = np.array([[1, 1, 0, 0]])
+ y_pred = np.array([[0, 0, 1, 1]])
+ p, r, f, s = precision_recall_fscore_support(
+ y_true, y_pred, labels=[3, 0, 1, 2], warn_for=[], average=average
+ )
+ assert_array_equal(p, 0)
+ assert_array_equal(r, 0)
+ assert_array_equal(f, 0)
+ if average is None:
+ assert_array_equal(s, [0, 1, 1, 0])
+
+
+def test_precision_recall_f1_score_binary_averaged():
+ y_true = np.array([0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1])
+ y_pred = np.array([1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1])
+
+ # compute scores with default labels introspection
+ ps, rs, fs, _ = precision_recall_fscore_support(y_true, y_pred, average=None)
+ p, r, f, _ = precision_recall_fscore_support(y_true, y_pred, average="macro")
+ assert p == np.mean(ps)
+ assert r == np.mean(rs)
+ assert f == np.mean(fs)
+ p, r, f, _ = precision_recall_fscore_support(y_true, y_pred, average="weighted")
+ support = np.bincount(y_true)
+ assert p == np.average(ps, weights=support)
+ assert r == np.average(rs, weights=support)
+ assert f == np.average(fs, weights=support)
+
+
+def test_zero_precision_recall():
+ # Check that pathological cases do not produce NaNs
+
+ old_error_settings = np.seterr(all="raise")
+
+ try:
+ y_true = np.array([0, 1, 2, 0, 1, 2])
+ y_pred = np.array([2, 0, 1, 1, 2, 0])
+
+ assert_almost_equal(precision_score(y_true, y_pred, average="macro"), 0.0, 2)
+ assert_almost_equal(recall_score(y_true, y_pred, average="macro"), 0.0, 2)
+ assert_almost_equal(f1_score(y_true, y_pred, average="macro"), 0.0, 2)
+
+ finally:
+ np.seterr(**old_error_settings)
+
+
+def test_confusion_matrix_multiclass_subset_labels():
+ # Test confusion matrix - multi-class case with subset of labels
+ y_true, y_pred, _ = make_prediction(binary=False)
+
+ # compute confusion matrix with only first two labels considered
+ cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
+ assert_array_equal(cm, [[19, 4], [4, 3]])
+
+ # compute confusion matrix with explicit label ordering for only subset
+ # of labels
+ cm = confusion_matrix(y_true, y_pred, labels=[2, 1])
+ assert_array_equal(cm, [[18, 2], [24, 3]])
+
+ # a label not in y_true should result in zeros for that row/column
+ extra_label = np.max(y_true) + 1
+ cm = confusion_matrix(y_true, y_pred, labels=[2, extra_label])
+ assert_array_equal(cm, [[18, 0], [0, 0]])
+
+
+@pytest.mark.parametrize(
+ "labels, err_msg",
+ [
+ ([], "'labels' should contains at least one label."),
+ ([3, 4], "At least one label specified must be in y_true"),
+ ],
+ ids=["empty list", "unknown labels"],
+)
+def test_confusion_matrix_error(labels, err_msg):
+ y_true, y_pred, _ = make_prediction(binary=False)
+ with pytest.raises(ValueError, match=err_msg):
+ confusion_matrix(y_true, y_pred, labels=labels)
+
+
+@pytest.mark.parametrize(
+ "labels", (None, [0, 1], [0, 1, 2]), ids=["None", "binary", "multiclass"]
+)
+def test_confusion_matrix_on_zero_length_input(labels):
+ expected_n_classes = len(labels) if labels else 0
+ expected = np.zeros((expected_n_classes, expected_n_classes), dtype=int)
+ cm = confusion_matrix([], [], labels=labels)
+ assert_array_equal(cm, expected)
+
+
+def test_confusion_matrix_dtype():
+ y = [0, 1, 1]
+ weight = np.ones(len(y))
+ # confusion_matrix returns int64 by default
+ cm = confusion_matrix(y, y)
+ assert cm.dtype == np.int64
+ # The dtype of confusion_matrix is always 64 bit
+ for dtype in [np.bool_, np.int32, np.uint64]:
+ cm = confusion_matrix(y, y, sample_weight=weight.astype(dtype, copy=False))
+ assert cm.dtype == np.int64
+ for dtype in [np.float32, np.float64, None, object]:
+ cm = confusion_matrix(y, y, sample_weight=weight.astype(dtype, copy=False))
+ assert cm.dtype == np.float64
+
+ # np.iinfo(np.uint32).max should be accumulated correctly
+ weight = np.full(len(y), 4294967295, dtype=np.uint32)
+ cm = confusion_matrix(y, y, sample_weight=weight)
+ assert cm[0, 0] == 4294967295
+ assert cm[1, 1] == 8589934590
+
+ # np.iinfo(np.int64).max should cause an overflow
+ weight = np.full(len(y), 9223372036854775807, dtype=np.int64)
+ cm = confusion_matrix(y, y, sample_weight=weight)
+ assert cm[0, 0] == 9223372036854775807
+ assert cm[1, 1] == -2
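+ # Note (added for clarity): the two ones in y accumulate to
+ # 2 * (2**63 - 1) = 2**64 - 2, which wraps around to -2 in signed
+ # 64-bit arithmetic, hence the expected overflow above.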
+
+
+def test_classification_report_multiclass():
+ # Test performance report
+ iris = datasets.load_iris()
+ y_true, y_pred, _ = make_prediction(dataset=iris, binary=False)
+
+ # print classification report with class names
+ expected_report = """\
+ precision recall f1-score support
+
+ setosa 0.83 0.79 0.81 24
+ versicolor 0.33 0.10 0.15 31
+ virginica 0.42 0.90 0.57 20
+
+ accuracy 0.53 75
+ macro avg 0.53 0.60 0.51 75
+weighted avg 0.51 0.53 0.47 75
+"""
+ report = classification_report(
+ y_true,
+ y_pred,
+ labels=np.arange(len(iris.target_names)),
+ target_names=iris.target_names,
+ )
+ assert report == expected_report
+
+
+def test_classification_report_multiclass_balanced():
+ y_true, y_pred = [0, 0, 0, 1, 1, 1, 2, 2, 2], [0, 1, 2, 0, 1, 2, 0, 1, 2]
+
+ expected_report = """\
+ precision recall f1-score support
+
+ 0 0.33 0.33 0.33 3
+ 1 0.33 0.33 0.33 3
+ 2 0.33 0.33 0.33 3
+
+ accuracy 0.33 9
+ macro avg 0.33 0.33 0.33 9
+weighted avg 0.33 0.33 0.33 9
+"""
+ report = classification_report(y_true, y_pred)
+ assert report == expected_report
+
+
+def test_classification_report_multiclass_with_label_detection():
+ iris = datasets.load_iris()
+ y_true, y_pred, _ = make_prediction(dataset=iris, binary=False)
+
+ # print classification report with label detection
+ expected_report = """\
+ precision recall f1-score support
+
+ 0 0.83 0.79 0.81 24
+ 1 0.33 0.10 0.15 31
+ 2 0.42 0.90 0.57 20
+
+ accuracy 0.53 75
+ macro avg 0.53 0.60 0.51 75
+weighted avg 0.51 0.53 0.47 75
+"""
+ report = classification_report(y_true, y_pred)
+ assert report == expected_report
+
+
+def test_classification_report_multiclass_with_digits():
+ # Test performance report with added digits in floating point values
+ iris = datasets.load_iris()
+ y_true, y_pred, _ = make_prediction(dataset=iris, binary=False)
+
+ # print classification report with class names
+ expected_report = """\
+ precision recall f1-score support
+
+ setosa 0.82609 0.79167 0.80851 24
+ versicolor 0.33333 0.09677 0.15000 31
+ virginica 0.41860 0.90000 0.57143 20
+
+ accuracy 0.53333 75
+ macro avg 0.52601 0.59615 0.50998 75
+weighted avg 0.51375 0.53333 0.47310 75
+"""
+ report = classification_report(
+ y_true,
+ y_pred,
+ labels=np.arange(len(iris.target_names)),
+ target_names=iris.target_names,
+ digits=5,
+ )
+ assert report == expected_report
+
+
+def test_classification_report_multiclass_with_string_label():
+ y_true, y_pred, _ = make_prediction(binary=False)
+
+ y_true = np.array(["blue", "green", "red"])[y_true]
+ y_pred = np.array(["blue", "green", "red"])[y_pred]
+
+ expected_report = """\
+ precision recall f1-score support
+
+ blue 0.83 0.79 0.81 24
+ green 0.33 0.10 0.15 31
+ red 0.42 0.90 0.57 20
+
+ accuracy 0.53 75
+ macro avg 0.53 0.60 0.51 75
+weighted avg 0.51 0.53 0.47 75
+"""
+ report = classification_report(y_true, y_pred)
+ assert report == expected_report
+
+ expected_report = """\
+ precision recall f1-score support
+
+ a 0.83 0.79 0.81 24
+ b 0.33 0.10 0.15 31
+ c 0.42 0.90 0.57 20
+
+ accuracy 0.53 75
+ macro avg 0.53 0.60 0.51 75
+weighted avg 0.51 0.53 0.47 75
+"""
+ report = classification_report(y_true, y_pred, target_names=["a", "b", "c"])
+ assert report == expected_report
+
+
+def test_classification_report_multiclass_with_unicode_label():
+ y_true, y_pred, _ = make_prediction(binary=False)
+
+ labels = np.array(["blue\xa2", "green\xa2", "red\xa2"])
+ y_true = labels[y_true]
+ y_pred = labels[y_pred]
+
+ expected_report = """\
+ precision recall f1-score support
+
+ blue\xa2 0.83 0.79 0.81 24
+ green\xa2 0.33 0.10 0.15 31
+ red\xa2 0.42 0.90 0.57 20
+
+ accuracy 0.53 75
+ macro avg 0.53 0.60 0.51 75
+weighted avg 0.51 0.53 0.47 75
+"""
+ report = classification_report(y_true, y_pred)
+ assert report == expected_report
+
+
+def test_classification_report_multiclass_with_long_string_label():
+ y_true, y_pred, _ = make_prediction(binary=False)
+
+ labels = np.array(["blue", "green" * 5, "red"])
+ y_true = labels[y_true]
+ y_pred = labels[y_pred]
+
+ expected_report = """\
+ precision recall f1-score support
+
+ blue 0.83 0.79 0.81 24
+greengreengreengreengreen 0.33 0.10 0.15 31
+ red 0.42 0.90 0.57 20
+
+ accuracy 0.53 75
+ macro avg 0.53 0.60 0.51 75
+ weighted avg 0.51 0.53 0.47 75
+"""
+
+ report = classification_report(y_true, y_pred)
+ assert report == expected_report
+
+
+def test_classification_report_labels_target_names_unequal_length():
+ y_true = [0, 0, 2, 0, 0]
+ y_pred = [0, 2, 2, 0, 0]
+ target_names = ["class 0", "class 1", "class 2"]
+
+ msg = "labels size, 2, does not match size of target_names, 3"
+ with pytest.warns(UserWarning, match=msg):
+ classification_report(y_true, y_pred, labels=[0, 2], target_names=target_names)
+
+
+def test_classification_report_no_labels_target_names_unequal_length():
+ y_true = [0, 0, 2, 0, 0]
+ y_pred = [0, 2, 2, 0, 0]
+ target_names = ["class 0", "class 1", "class 2"]
+
+ err_msg = (
+ "Number of classes, 2, does not "
+ "match size of target_names, 3. "
+ "Try specifying the labels parameter"
+ )
+ with pytest.raises(ValueError, match=err_msg):
+ classification_report(y_true, y_pred, target_names=target_names)
+
+
+@ignore_warnings
+def test_multilabel_classification_report():
+ n_classes = 4
+ n_samples = 50
+
+ _, y_true = make_multilabel_classification(
+ n_features=1, n_samples=n_samples, n_classes=n_classes, random_state=0
+ )
+
+ _, y_pred = make_multilabel_classification(
+ n_features=1, n_samples=n_samples, n_classes=n_classes, random_state=1
+ )
+
+ expected_report = """\
+ precision recall f1-score support
+
+ 0 0.50 0.67 0.57 24
+ 1 0.51 0.74 0.61 27
+ 2 0.29 0.08 0.12 26
+ 3 0.52 0.56 0.54 27
+
+ micro avg 0.50 0.51 0.50 104
+ macro avg 0.45 0.51 0.46 104
+weighted avg 0.45 0.51 0.46 104
+ samples avg 0.46 0.42 0.40 104
+"""
+
+ report = classification_report(y_true, y_pred)
+ assert report == expected_report
+
+
+def test_multilabel_zero_one_loss_subset():
+ # Dense label indicator matrix format
+ y1 = np.array([[0, 1, 1], [1, 0, 1]])
+ y2 = np.array([[0, 0, 1], [1, 0, 1]])
+
+ assert zero_one_loss(y1, y2) == 0.5
+ assert zero_one_loss(y1, y1) == 0
+ assert zero_one_loss(y2, y2) == 0
+ assert zero_one_loss(y2, np.logical_not(y2)) == 1
+ assert zero_one_loss(y1, np.logical_not(y1)) == 1
+ assert zero_one_loss(y1, np.zeros(y1.shape)) == 1
+ assert zero_one_loss(y2, np.zeros(y1.shape)) == 1
+
+
+def test_multilabel_hamming_loss():
+ # Dense label indicator matrix format
+ y1 = np.array([[0, 1, 1], [1, 0, 1]])
+ y2 = np.array([[0, 0, 1], [1, 0, 1]])
+ w = np.array([1, 3])
+
+ assert hamming_loss(y1, y2) == 1 / 6
+ assert hamming_loss(y1, y1) == 0
+ assert hamming_loss(y2, y2) == 0
+ assert hamming_loss(y2, 1 - y2) == 1
+ assert hamming_loss(y1, 1 - y1) == 1
+ assert hamming_loss(y1, np.zeros(y1.shape)) == 4 / 6
+ assert hamming_loss(y2, np.zeros(y1.shape)) == 0.5
+ assert hamming_loss(y1, y2, sample_weight=w) == 1.0 / 12
+ assert hamming_loss(y1, 1 - y2, sample_weight=w) == 11.0 / 12
+ assert hamming_loss(y1, np.zeros_like(y1), sample_weight=w) == 2.0 / 3
+ # sp_hamming only works with 1-D arrays
+ assert hamming_loss(y1[0], y2[0]) == sp_hamming(y1[0], y2[0])
+
+
+def test_jaccard_score_validation():
+ y_true = np.array([0, 1, 0, 1, 1])
+ y_pred = np.array([0, 1, 0, 1, 1])
+ err_msg = r"pos_label=2 is not a valid label. It should be one of \[0, 1\]"
+ with pytest.raises(ValueError, match=err_msg):
+ jaccard_score(y_true, y_pred, average="binary", pos_label=2)
+
+ y_true = np.array([[0, 1, 1], [1, 0, 0]])
+ y_pred = np.array([[1, 1, 1], [1, 0, 1]])
+ msg1 = (
+ r"Target is multilabel-indicator but average='binary'. "
+ r"Please choose another average setting, one of \[None, "
+ r"'micro', 'macro', 'weighted', 'samples'\]."
+ )
+ with pytest.raises(ValueError, match=msg1):
+ jaccard_score(y_true, y_pred, average="binary", pos_label=-1)
+
+ y_true = np.array([0, 1, 1, 0, 2])
+ y_pred = np.array([1, 1, 1, 1, 0])
+ msg2 = (
+ r"Target is multiclass but average='binary'. Please choose "
+ r"another average setting, one of \[None, 'micro', 'macro', "
+ r"'weighted'\]."
+ )
+ with pytest.raises(ValueError, match=msg2):
+ jaccard_score(y_true, y_pred, average="binary")
+ msg3 = "Samplewise metrics are not available outside of multilabel classification."
+ with pytest.raises(ValueError, match=msg3):
+ jaccard_score(y_true, y_pred, average="samples")
+
+ msg = (
+ r"Note that pos_label \(set to 3\) is ignored when "
+ r"average != 'binary' \(got 'micro'\). You may use "
+ r"labels=\[pos_label\] to specify a single positive "
+ "class."
+ )
+ with pytest.warns(UserWarning, match=msg):
+ jaccard_score(y_true, y_pred, average="micro", pos_label=3)
+
+
+def test_multilabel_jaccard_score(recwarn):
+ # Dense label indicator matrix format
+ y1 = np.array([[0, 1, 1], [1, 0, 1]])
+ y2 = np.array([[0, 0, 1], [1, 0, 1]])
+
+ # size(y1 \inter y2) = [1, 2]
+ # size(y1 \union y2) = [2, 2]
+
+ assert jaccard_score(y1, y2, average="samples") == 0.75
+ assert jaccard_score(y1, y1, average="samples") == 1
+ assert jaccard_score(y2, y2, average="samples") == 1
+ assert jaccard_score(y2, np.logical_not(y2), average="samples") == 0
+ assert jaccard_score(y1, np.logical_not(y1), average="samples") == 0
+ assert jaccard_score(y1, np.zeros(y1.shape), average="samples") == 0
+ assert jaccard_score(y2, np.zeros(y1.shape), average="samples") == 0
+
+ y_true = np.array([[0, 1, 1], [1, 0, 0]])
+ y_pred = np.array([[1, 1, 1], [1, 0, 1]])
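+ # Hand check of the averages below (not extra assertions): per-label
+ # Jaccard scores are [1/2, 1, 1/2], so macro = 2/3; pooling all labels
+ # gives 3 intersections over 5 unions, so micro = 3/5; per-sample scores
+ # are [2/3, 1/2], so samples = 7/12.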
+ # average='macro'
+ assert_almost_equal(jaccard_score(y_true, y_pred, average="macro"), 2.0 / 3)
+ # average='micro'
+ assert_almost_equal(jaccard_score(y_true, y_pred, average="micro"), 3.0 / 5)
+ # average='samples'
+ assert_almost_equal(jaccard_score(y_true, y_pred, average="samples"), 7.0 / 12)
+ assert_almost_equal(
+ jaccard_score(y_true, y_pred, average="samples", labels=[0, 2]), 1.0 / 2
+ )
+ assert_almost_equal(
+ jaccard_score(y_true, y_pred, average="samples", labels=[1, 2]), 1.0 / 2
+ )
+ # average=None
+ assert_array_equal(
+ jaccard_score(y_true, y_pred, average=None), np.array([1.0 / 2, 1.0, 1.0 / 2])
+ )
+
+ y_true = np.array([[0, 1, 1], [1, 0, 1]])
+ y_pred = np.array([[1, 1, 1], [1, 0, 1]])
+ assert_almost_equal(jaccard_score(y_true, y_pred, average="macro"), 5.0 / 6)
+ # average='weighted'
+ assert_almost_equal(jaccard_score(y_true, y_pred, average="weighted"), 7.0 / 8)
+
+ msg2 = "Got 4 > 2"
+ with pytest.raises(ValueError, match=msg2):
+ jaccard_score(y_true, y_pred, labels=[4], average="macro")
+ msg3 = "Got -1 < 0"
+ with pytest.raises(ValueError, match=msg3):
+ jaccard_score(y_true, y_pred, labels=[-1], average="macro")
+
+ msg = (
+ "Jaccard is ill-defined and being set to 0.0 in labels "
+ "with no true or predicted samples."
+ )
+
+ with pytest.warns(UndefinedMetricWarning, match=msg):
+ assert (
+ jaccard_score(np.array([[0, 1]]), np.array([[0, 1]]), average="macro")
+ == 0.5
+ )
+
+ msg = (
+ "Jaccard is ill-defined and being set to 0.0 in samples "
+ "with no true or predicted labels."
+ )
+
+ with pytest.warns(UndefinedMetricWarning, match=msg):
+ assert (
+ jaccard_score(
+ np.array([[0, 0], [1, 1]]),
+ np.array([[0, 0], [1, 1]]),
+ average="samples",
+ )
+ == 0.5
+ )
+
+ assert not list(recwarn)
+
+
+def test_multiclass_jaccard_score(recwarn):
+ y_true = ["ant", "ant", "cat", "cat", "ant", "cat", "bird", "bird"]
+ y_pred = ["cat", "ant", "cat", "cat", "ant", "bird", "bird", "cat"]
+ labels = ["ant", "bird", "cat"]
+ lb = LabelBinarizer()
+ lb.fit(labels)
+ y_true_bin = lb.transform(y_true)
+ y_pred_bin = lb.transform(y_pred)
+ multi_jaccard_score = partial(jaccard_score, y_true, y_pred)
+ bin_jaccard_score = partial(jaccard_score, y_true_bin, y_pred_bin)
+ multi_labels_list = [
+ ["ant", "bird"],
+ ["ant", "cat"],
+ ["cat", "bird"],
+ ["ant"],
+ ["bird"],
+ ["cat"],
+ None,
+ ]
+ bin_labels_list = [[0, 1], [0, 2], [2, 1], [0], [1], [2], None]
+
+ # other than average='samples'/'none-samples', test everything else here
+ for average in ("macro", "weighted", "micro", None):
+ for m_label, b_label in zip(multi_labels_list, bin_labels_list):
+ assert_almost_equal(
+ multi_jaccard_score(average=average, labels=m_label),
+ bin_jaccard_score(average=average, labels=b_label),
+ )
+
+ y_true = np.array([[0, 0], [0, 0], [0, 0]])
+ y_pred = np.array([[0, 0], [0, 0], [0, 0]])
+ with ignore_warnings():
+ assert jaccard_score(y_true, y_pred, average="weighted") == 0
+
+ assert not list(recwarn)
+
+
+def test_average_binary_jaccard_score(recwarn):
+ # tp=0, fp=0, fn=1, tn=0
+ assert jaccard_score([1], [0], average="binary") == 0.0
+ # tp=0, fp=0, fn=0, tn=1
+ msg = (
+ "Jaccard is ill-defined and being set to 0.0 due to "
+ "no true or predicted samples"
+ )
+ with pytest.warns(UndefinedMetricWarning, match=msg):
+ assert jaccard_score([0, 0], [0, 0], average="binary") == 0.0
+
+ # tp=1, fp=0, fn=0, tn=0 (pos_label=0)
+ assert jaccard_score([0], [0], pos_label=0, average="binary") == 1.0
+ y_true = np.array([1, 0, 1, 1, 0])
+ y_pred = np.array([1, 0, 1, 1, 1])
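+ # Hand check (not an extra assertion): for pos_label=1 the intersection
+ # has 3 elements and the union 4, giving 3/4; for pos_label=0 it is 1/2.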
+ assert_almost_equal(jaccard_score(y_true, y_pred, average="binary"), 3.0 / 4)
+ assert_almost_equal(
+ jaccard_score(y_true, y_pred, average="binary", pos_label=0), 1.0 / 2
+ )
+
+ assert not list(recwarn)
+
+
+def test_jaccard_score_zero_division_warning():
+ # check that we raise a warning with the default behavior if a zero
+ # division happens
+ y_true = np.array([[1, 0, 1], [0, 0, 0]])
+ y_pred = np.array([[0, 0, 0], [0, 0, 0]])
+ msg = (
+ "Jaccard is ill-defined and being set to 0.0 in "
+ "samples with no true or predicted labels."
+ " Use `zero_division` parameter to control this behavior."
+ )
+ with pytest.warns(UndefinedMetricWarning, match=msg):
+ score = jaccard_score(y_true, y_pred, average="samples", zero_division="warn")
+ assert score == pytest.approx(0.0)
+
+
+@pytest.mark.parametrize("zero_division, expected_score", [(0, 0), (1, 0.5)])
+def test_jaccard_score_zero_division_set_value(zero_division, expected_score):
+ # check that we don't issue a warning when passing the zero_division parameter
+ y_true = np.array([[1, 0, 1], [0, 0, 0]])
+ y_pred = np.array([[0, 0, 0], [0, 0, 0]])
+ with warnings.catch_warnings():
+ warnings.simplefilter("error", UndefinedMetricWarning)
+ score = jaccard_score(
+ y_true, y_pred, average="samples", zero_division=zero_division
+ )
+ assert score == pytest.approx(expected_score)
+
+
+@ignore_warnings
+def test_precision_recall_f1_score_multilabel_1():
+ # Test precision_recall_f1_score on a crafted multilabel example
+ # First crafted example
+
+ y_true = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 1]])
+ y_pred = np.array([[0, 1, 0, 0], [0, 1, 0, 0], [1, 0, 1, 0]])
+
+ p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average=None)
+
+ # tp = [0, 1, 1, 0]
+ # fn = [1, 0, 0, 1]
+ # fp = [1, 1, 0, 0]
+ # Check per class
+
+ assert_array_almost_equal(p, [0.0, 0.5, 1.0, 0.0], 2)
+ assert_array_almost_equal(r, [0.0, 1.0, 1.0, 0.0], 2)
+ assert_array_almost_equal(f, [0.0, 1 / 1.5, 1, 0.0], 2)
+ assert_array_almost_equal(s, [1, 1, 1, 1], 2)
+
+ f2 = fbeta_score(y_true, y_pred, beta=2, average=None)
+ support = s
+ assert_array_almost_equal(f2, [0, 0.83, 1, 0], 2)
+
+ # Check macro
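+ # (Hand check, added: per-class precision is [0, 0.5, 1, 0] and per-class
+ # F1 is [0, 2/3, 1, 0], so the macro averages below are 1.5/4 and 5/12,
+ # the latter written as 2.5 / 1.5 * 0.25.)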
+ p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average="macro")
+ assert_almost_equal(p, 1.5 / 4)
+ assert_almost_equal(r, 0.5)
+ assert_almost_equal(f, 2.5 / 1.5 * 0.25)
+ assert s is None
+ assert_almost_equal(
+ fbeta_score(y_true, y_pred, beta=2, average="macro"), np.mean(f2)
+ )
+
+ # Check micro
+ p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average="micro")
+ assert_almost_equal(p, 0.5)
+ assert_almost_equal(r, 0.5)
+ assert_almost_equal(f, 0.5)
+ assert s is None
+ assert_almost_equal(
+ fbeta_score(y_true, y_pred, beta=2, average="micro"),
+ (1 + 4) * p * r / (4 * p + r),
+ )
+
+ # Check weighted
+ p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average="weighted")
+ assert_almost_equal(p, 1.5 / 4)
+ assert_almost_equal(r, 0.5)
+ assert_almost_equal(f, 2.5 / 1.5 * 0.25)
+ assert s is None
+ assert_almost_equal(
+ fbeta_score(y_true, y_pred, beta=2, average="weighted"),
+ np.average(f2, weights=support),
+ )
+ # Check samples
+ # |h(x_i) inter y_i | = [0, 1, 1]
+ # |y_i| = [1, 1, 2]
+ # |h(x_i)| = [1, 1, 2]
+ p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average="samples")
+ assert_almost_equal(p, 0.5)
+ assert_almost_equal(r, 0.5)
+ assert_almost_equal(f, 0.5)
+ assert s is None
+ assert_almost_equal(fbeta_score(y_true, y_pred, beta=2, average="samples"), 0.5)
+
+
+@ignore_warnings
+def test_precision_recall_f1_score_multilabel_2():
+ # Test precision_recall_f1_score on a crafted multilabel example 2
+ # Second crafted example
+ y_true = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 1, 1, 0]])
+ y_pred = np.array([[0, 0, 0, 1], [0, 0, 0, 1], [1, 1, 0, 0]])
+
+ # tp = [ 0. 1. 0. 0.]
+ # fp = [ 1. 0. 0. 2.]
+ # fn = [ 1. 1. 1. 0.]
+
+ p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average=None)
+ assert_array_almost_equal(p, [0.0, 1.0, 0.0, 0.0], 2)
+ assert_array_almost_equal(r, [0.0, 0.5, 0.0, 0.0], 2)
+ assert_array_almost_equal(f, [0.0, 0.66, 0.0, 0.0], 2)
+ assert_array_almost_equal(s, [1, 2, 1, 0], 2)
+
+ f2 = fbeta_score(y_true, y_pred, beta=2, average=None)
+ support = s
+ assert_array_almost_equal(f2, [0, 0.55, 0, 0], 2)
+
+ p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average="micro")
+ assert_almost_equal(p, 0.25)
+ assert_almost_equal(r, 0.25)
+ assert_almost_equal(f, 2 * 0.25 * 0.25 / 0.5)
+ assert s is None
+ assert_almost_equal(
+ fbeta_score(y_true, y_pred, beta=2, average="micro"),
+ (1 + 4) * p * r / (4 * p + r),
+ )
+
+ p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average="macro")
+ assert_almost_equal(p, 0.25)
+ assert_almost_equal(r, 0.125)
+ assert_almost_equal(f, 2 / 12)
+ assert s is None
+ assert_almost_equal(
+ fbeta_score(y_true, y_pred, beta=2, average="macro"), np.mean(f2)
+ )
+
+ p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average="weighted")
+ assert_almost_equal(p, 2 / 4)
+ assert_almost_equal(r, 1 / 4)
+ assert_almost_equal(f, 2 / 3 * 2 / 4)
+ assert s is None
+ assert_almost_equal(
+ fbeta_score(y_true, y_pred, beta=2, average="weighted"),
+ np.average(f2, weights=support),
+ )
+
+ p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average="samples")
+ # Check samples
+ # |h(x_i) inter y_i | = [0, 0, 1]
+ # |y_i| = [1, 1, 2]
+ # |h(x_i)| = [1, 1, 2]
+
+ assert_almost_equal(p, 1 / 6)
+ assert_almost_equal(r, 1 / 6)
+ assert_almost_equal(f, 2 / 4 * 1 / 3)
+ assert s is None
+ assert_almost_equal(
+ fbeta_score(y_true, y_pred, beta=2, average="samples"), 0.1666, 2
+ )
+
+
+@ignore_warnings
+@pytest.mark.parametrize("zero_division", ["warn", 0, 1])
+def test_precision_recall_f1_score_with_an_empty_prediction(zero_division):
+ y_true = np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 1, 1, 0]])
+ y_pred = np.array([[0, 0, 0, 0], [0, 0, 0, 1], [0, 1, 1, 0]])
+
+ # true_pos = [ 0. 1. 1. 0.]
+ # false_pos = [ 0. 0. 0. 1.]
+ # false_neg = [ 1. 1. 0. 0.]
+ zero_division = 1.0 if zero_division == 1.0 else 0.0
+ p, r, f, s = precision_recall_fscore_support(
+ y_true, y_pred, average=None, zero_division=zero_division
+ )
+ assert_array_almost_equal(p, [zero_division, 1.0, 1.0, 0.0], 2)
+ assert_array_almost_equal(r, [0.0, 0.5, 1.0, zero_division], 2)
+ assert_array_almost_equal(f, [0.0, 1 / 1.5, 1, 0.0], 2)
+ assert_array_almost_equal(s, [1, 2, 1, 0], 2)
+
+ f2 = fbeta_score(y_true, y_pred, beta=2, average=None, zero_division=zero_division)
+ support = s
+ assert_array_almost_equal(f2, [0, 0.55, 1, 0], 2)
+
+ p, r, f, s = precision_recall_fscore_support(
+ y_true, y_pred, average="macro", zero_division=zero_division
+ )
+ assert_almost_equal(p, (2 + zero_division) / 4)
+ assert_almost_equal(r, (1.5 + zero_division) / 4)
+ assert_almost_equal(f, 2.5 / (4 * 1.5))
+ assert s is None
+ assert_almost_equal(
+ fbeta_score(y_true, y_pred, beta=2, average="macro"), np.mean(f2)
+ )
+
+ p, r, f, s = precision_recall_fscore_support(
+ y_true, y_pred, average="micro", zero_division=zero_division
+ )
+ assert_almost_equal(p, 2 / 3)
+ assert_almost_equal(r, 0.5)
+ assert_almost_equal(f, 2 / 3 / (2 / 3 + 0.5))
+ assert s is None
+ assert_almost_equal(
+ fbeta_score(
+ y_true, y_pred, beta=2, average="micro", zero_division=zero_division
+ ),
+ (1 + 4) * p * r / (4 * p + r),
+ )
+
+ p, r, f, s = precision_recall_fscore_support(
+ y_true, y_pred, average="weighted", zero_division=zero_division
+ )
+ assert_almost_equal(p, 3 / 4 if zero_division == 0 else 1.0)
+ assert_almost_equal(r, 0.5)
+ assert_almost_equal(f, (2 / 1.5 + 1) / 4)
+ assert s is None
+ assert_almost_equal(
+ fbeta_score(
+ y_true, y_pred, beta=2, average="weighted", zero_division=zero_division
+ ),
+ np.average(f2, weights=support),
+ )
+
+ p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average="samples")
+ # |h(x_i) inter y_i | = [0, 0, 2]
+ # |y_i| = [1, 1, 2]
+ # |h(x_i)| = [0, 1, 2]
+ assert_almost_equal(p, 1 / 3)
+ assert_almost_equal(r, 1 / 3)
+ assert_almost_equal(f, 1 / 3)
+ assert s is None
+ assert_almost_equal(
+ fbeta_score(
+ y_true, y_pred, beta=2, average="samples", zero_division=zero_division
+ ),
+ 0.333,
+ 2,
+ )
+
+
+@pytest.mark.parametrize("beta", [1])
+@pytest.mark.parametrize("average", ["macro", "micro", "weighted", "samples"])
+@pytest.mark.parametrize("zero_division", [0, 1])
+def test_precision_recall_f1_no_labels(beta, average, zero_division):
+ y_true = np.zeros((20, 3))
+ y_pred = np.zeros_like(y_true)
+
+ p, r, f, s = assert_no_warnings(
+ precision_recall_fscore_support,
+ y_true,
+ y_pred,
+ average=average,
+ beta=beta,
+ zero_division=zero_division,
+ )
+ fbeta = assert_no_warnings(
+ fbeta_score,
+ y_true,
+ y_pred,
+ beta=beta,
+ average=average,
+ zero_division=zero_division,
+ )
+
+ zero_division = float(zero_division)
+ assert_almost_equal(p, zero_division)
+ assert_almost_equal(r, zero_division)
+ assert_almost_equal(f, zero_division)
+ assert s is None
+
+ assert_almost_equal(fbeta, float(zero_division))
+
+
+@pytest.mark.parametrize("average", ["macro", "micro", "weighted", "samples"])
+def test_precision_recall_f1_no_labels_check_warnings(average):
+ y_true = np.zeros((20, 3))
+ y_pred = np.zeros_like(y_true)
+
+ func = precision_recall_fscore_support
+ with pytest.warns(UndefinedMetricWarning):
+ p, r, f, s = func(y_true, y_pred, average=average, beta=1.0)
+
+ assert_almost_equal(p, 0)
+ assert_almost_equal(r, 0)
+ assert_almost_equal(f, 0)
+ assert s is None
+
+ with pytest.warns(UndefinedMetricWarning):
+ fbeta = fbeta_score(y_true, y_pred, average=average, beta=1.0)
+
+ assert_almost_equal(fbeta, 0)
+
+
+@pytest.mark.parametrize("zero_division", [0, 1])
+def test_precision_recall_f1_no_labels_average_none(zero_division):
+ y_true = np.zeros((20, 3))
+ y_pred = np.zeros_like(y_true)
+
+ # tp = [0, 0, 0]
+ # fn = [0, 0, 0]
+ # fp = [0, 0, 0]
+ # support = [0, 0, 0]
+ # |y_hat_i inter y_i | = [0, 0, 0]
+ # |y_i| = [0, 0, 0]
+ # |y_hat_i| = [0, 0, 0]
+
+ p, r, f, s = assert_no_warnings(
+ precision_recall_fscore_support,
+ y_true,
+ y_pred,
+ average=None,
+ beta=1.0,
+ zero_division=zero_division,
+ )
+ fbeta = assert_no_warnings(
+ fbeta_score, y_true, y_pred, beta=1.0, average=None, zero_division=zero_division
+ )
+
+ zero_division = float(zero_division)
+ assert_array_almost_equal(p, [zero_division, zero_division, zero_division], 2)
+ assert_array_almost_equal(r, [zero_division, zero_division, zero_division], 2)
+ assert_array_almost_equal(f, [zero_division, zero_division, zero_division], 2)
+ assert_array_almost_equal(s, [0, 0, 0], 2)
+
+ assert_array_almost_equal(fbeta, [zero_division, zero_division, zero_division], 2)
+
+
+def test_precision_recall_f1_no_labels_average_none_warn():
+ y_true = np.zeros((20, 3))
+ y_pred = np.zeros_like(y_true)
+
+ # tp = [0, 0, 0]
+ # fn = [0, 0, 0]
+ # fp = [0, 0, 0]
+ # support = [0, 0, 0]
+ # |y_hat_i inter y_i | = [0, 0, 0]
+ # |y_i| = [0, 0, 0]
+ # |y_hat_i| = [0, 0, 0]
+
+ with pytest.warns(UndefinedMetricWarning):
+ p, r, f, s = precision_recall_fscore_support(
+ y_true, y_pred, average=None, beta=1
+ )
+
+ assert_array_almost_equal(p, [0, 0, 0], 2)
+ assert_array_almost_equal(r, [0, 0, 0], 2)
+ assert_array_almost_equal(f, [0, 0, 0], 2)
+ assert_array_almost_equal(s, [0, 0, 0], 2)
+
+ with pytest.warns(UndefinedMetricWarning):
+ fbeta = fbeta_score(y_true, y_pred, beta=1, average=None)
+
+ assert_array_almost_equal(fbeta, [0, 0, 0], 2)
+
+
+def test_prf_warnings():
+ # average of per-label scores
+ f, w = precision_recall_fscore_support, UndefinedMetricWarning
+ for average in [None, "weighted", "macro"]:
+ msg = (
+ "Precision and F-score are ill-defined and "
+ "being set to 0.0 in labels with no predicted samples."
+ " Use `zero_division` parameter to control"
+ " this behavior."
+ )
+ with pytest.warns(w, match=msg):
+ f([0, 1, 2], [1, 1, 2], average=average)
+
+ msg = (
+ "Recall and F-score are ill-defined and "
+ "being set to 0.0 in labels with no true samples."
+ " Use `zero_division` parameter to control"
+ " this behavior."
+ )
+ with pytest.warns(w, match=msg):
+ f([1, 1, 2], [0, 1, 2], average=average)
+
+ # average of per-sample scores
+ msg = (
+ "Precision and F-score are ill-defined and "
+ "being set to 0.0 in samples with no predicted labels."
+ " Use `zero_division` parameter to control"
+ " this behavior."
+ )
+ with pytest.warns(w, match=msg):
+ f(np.array([[1, 0], [1, 0]]), np.array([[1, 0], [0, 0]]), average="samples")
+
+ msg = (
+ "Recall and F-score are ill-defined and "
+ "being set to 0.0 in samples with no true labels."
+ " Use `zero_division` parameter to control"
+ " this behavior."
+ )
+ with pytest.warns(w, match=msg):
+ f(np.array([[1, 0], [0, 0]]), np.array([[1, 0], [1, 0]]), average="samples")
+
+ # single score: micro-average
+ msg = (
+ "Precision and F-score are ill-defined and "
+ "being set to 0.0 due to no predicted samples."
+ " Use `zero_division` parameter to control"
+ " this behavior."
+ )
+ with pytest.warns(w, match=msg):
+ f(np.array([[1, 1], [1, 1]]), np.array([[0, 0], [0, 0]]), average="micro")
+
+ msg = (
+ "Recall and F-score are ill-defined and "
+ "being set to 0.0 due to no true samples."
+ " Use `zero_division` parameter to control"
+ " this behavior."
+ )
+ with pytest.warns(w, match=msg):
+ f(np.array([[0, 0], [0, 0]]), np.array([[1, 1], [1, 1]]), average="micro")
+
+ # single positive label
+ msg = (
+ "Precision and F-score are ill-defined and "
+ "being set to 0.0 due to no predicted samples."
+ " Use `zero_division` parameter to control"
+ " this behavior."
+ )
+ with pytest.warns(w, match=msg):
+ f([1, 1], [-1, -1], average="binary")
+
+ msg = (
+ "Recall and F-score are ill-defined and "
+ "being set to 0.0 due to no true samples."
+ " Use `zero_division` parameter to control"
+ " this behavior."
+ )
+ with pytest.warns(w, match=msg):
+ f([-1, -1], [1, 1], average="binary")
+
+ with warnings.catch_warnings(record=True) as record:
+ warnings.simplefilter("always")
+ precision_recall_fscore_support([0, 0], [0, 0], average="binary")
+ msg = (
+ "Recall and F-score are ill-defined and "
+ "being set to 0.0 due to no true samples."
+ " Use `zero_division` parameter to control"
+ " this behavior."
+ )
+ assert str(record.pop().message) == msg
+ msg = (
+ "Precision and F-score are ill-defined and "
+ "being set to 0.0 due to no predicted samples."
+ " Use `zero_division` parameter to control"
+ " this behavior."
+ )
+ assert str(record.pop().message) == msg
+
+
+@pytest.mark.parametrize("zero_division", [0, 1])
+def test_prf_no_warnings_if_zero_division_set(zero_division):
+ # average of per-label scores
+ f = precision_recall_fscore_support
+ for average in [None, "weighted", "macro"]:
+ assert_no_warnings(
+ f, [0, 1, 2], [1, 1, 2], average=average, zero_division=zero_division
+ )
+
+ assert_no_warnings(
+ f, [1, 1, 2], [0, 1, 2], average=average, zero_division=zero_division
+ )
+
+ # average of per-sample scores
+ assert_no_warnings(
+ f,
+ np.array([[1, 0], [1, 0]]),
+ np.array([[1, 0], [0, 0]]),
+ average="samples",
+ zero_division=zero_division,
+ )
+
+ assert_no_warnings(
+ f,
+ np.array([[1, 0], [0, 0]]),
+ np.array([[1, 0], [1, 0]]),
+ average="samples",
+ zero_division=zero_division,
+ )
+
+ # single score: micro-average
+ assert_no_warnings(
+ f,
+ np.array([[1, 1], [1, 1]]),
+ np.array([[0, 0], [0, 0]]),
+ average="micro",
+ zero_division=zero_division,
+ )
+
+ assert_no_warnings(
+ f,
+ np.array([[0, 0], [0, 0]]),
+ np.array([[1, 1], [1, 1]]),
+ average="micro",
+ zero_division=zero_division,
+ )
+
+ # single positive label
+ assert_no_warnings(
+ f, [1, 1], [-1, -1], average="binary", zero_division=zero_division
+ )
+
+ assert_no_warnings(
+ f, [-1, -1], [1, 1], average="binary", zero_division=zero_division
+ )
+
+ with warnings.catch_warnings(record=True) as record:
+ warnings.simplefilter("always")
+ precision_recall_fscore_support(
+ [0, 0], [0, 0], average="binary", zero_division=zero_division
+ )
+ assert len(record) == 0
+
+
+@pytest.mark.parametrize("zero_division", ["warn", 0, 1])
+def test_recall_warnings(zero_division):
+ assert_no_warnings(
+ recall_score,
+ np.array([[1, 1], [1, 1]]),
+ np.array([[0, 0], [0, 0]]),
+ average="micro",
+ zero_division=zero_division,
+ )
+ with warnings.catch_warnings(record=True) as record:
+ warnings.simplefilter("always")
+ recall_score(
+ np.array([[0, 0], [0, 0]]),
+ np.array([[1, 1], [1, 1]]),
+ average="micro",
+ zero_division=zero_division,
+ )
+ if zero_division == "warn":
+ assert (
+ str(record.pop().message) == "Recall is ill-defined and "
+ "being set to 0.0 due to no true samples."
+ " Use `zero_division` parameter to control"
+ " this behavior."
+ )
+ else:
+ assert len(record) == 0
+
+ recall_score([0, 0], [0, 0])
+ if zero_division == "warn":
+ assert (
+ str(record.pop().message) == "Recall is ill-defined and "
+ "being set to 0.0 due to no true samples."
+ " Use `zero_division` parameter to control"
+ " this behavior."
+ )
+
+
+@pytest.mark.parametrize("zero_division", ["warn", 0, 1])
+def test_precision_warnings(zero_division):
+ with warnings.catch_warnings(record=True) as record:
+ warnings.simplefilter("always")
+ precision_score(
+ np.array([[1, 1], [1, 1]]),
+ np.array([[0, 0], [0, 0]]),
+ average="micro",
+ zero_division=zero_division,
+ )
+ if zero_division == "warn":
+ assert (
+ str(record.pop().message) == "Precision is ill-defined and "
+ "being set to 0.0 due to no predicted samples."
+ " Use `zero_division` parameter to control"
+ " this behavior."
+ )
+ else:
+ assert len(record) == 0
+
+ precision_score([0, 0], [0, 0])
+ if zero_division == "warn":
+ assert (
+ str(record.pop().message) == "Precision is ill-defined and "
+ "being set to 0.0 due to no predicted samples."
+ " Use `zero_division` parameter to control"
+ " this behavior."
+ )
+
+ assert_no_warnings(
+ precision_score,
+ np.array([[0, 0], [0, 0]]),
+ np.array([[1, 1], [1, 1]]),
+ average="micro",
+ zero_division=zero_division,
+ )
+
+
+@pytest.mark.parametrize("zero_division", ["warn", 0, 1])
+def test_fscore_warnings(zero_division):
+ with warnings.catch_warnings(record=True) as record:
+ warnings.simplefilter("always")
+
+ for score in [f1_score, partial(fbeta_score, beta=2)]:
+ score(
+ np.array([[1, 1], [1, 1]]),
+ np.array([[0, 0], [0, 0]]),
+ average="micro",
+ zero_division=zero_division,
+ )
+ assert len(record) == 0
+
+ score(
+ np.array([[0, 0], [0, 0]]),
+ np.array([[1, 1], [1, 1]]),
+ average="micro",
+ zero_division=zero_division,
+ )
+ assert len(record) == 0
+
+ score(
+ np.array([[0, 0], [0, 0]]),
+ np.array([[0, 0], [0, 0]]),
+ average="micro",
+ zero_division=zero_division,
+ )
+ if zero_division == "warn":
+ assert (
+ str(record.pop().message) == "F-score is ill-defined and "
+ "being set to 0.0 due to no true nor predicted "
+ "samples. Use `zero_division` parameter to "
+ "control this behavior."
+ )
+ else:
+ assert len(record) == 0
+
+
+def test_prf_average_binary_data_non_binary():
+ # Error if user does not explicitly set non-binary average mode
+ y_true_mc = [1, 2, 3, 3]
+ y_pred_mc = [1, 2, 3, 1]
+ msg_mc = (
+ r"Target is multiclass but average='binary'. Please "
+ r"choose another average setting, one of \["
+ r"None, 'micro', 'macro', 'weighted'\]."
+ )
+ y_true_ind = np.array([[0, 1, 1], [1, 0, 0], [0, 0, 1]])
+ y_pred_ind = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 1]])
+ msg_ind = (
+ r"Target is multilabel-indicator but average='binary'. Please "
+ r"choose another average setting, one of \["
+ r"None, 'micro', 'macro', 'weighted', 'samples'\]."
+ )
+
+ for y_true, y_pred, msg in [
+ (y_true_mc, y_pred_mc, msg_mc),
+ (y_true_ind, y_pred_ind, msg_ind),
+ ]:
+ for metric in [
+ precision_score,
+ recall_score,
+ f1_score,
+ partial(fbeta_score, beta=2),
+ ]:
+ with pytest.raises(ValueError, match=msg):
+ metric(y_true, y_pred)
+
+
+def test__check_targets():
+ # Check that _check_targets correctly merges target types, squeezes
+ # output and fails if input lengths differ.
+ IND = "multilabel-indicator"
+ MC = "multiclass"
+ BIN = "binary"
+ CNT = "continuous"
+ MMC = "multiclass-multioutput"
+ MCN = "continuous-multioutput"
+ # all of length 3
+ EXAMPLES = [
+ (IND, np.array([[0, 1, 1], [1, 0, 0], [0, 0, 1]])),
+ # must not be considered binary
+ (IND, np.array([[0, 1], [1, 0], [1, 1]])),
+ (MC, [2, 3, 1]),
+ (BIN, [0, 1, 1]),
+ (CNT, [0.0, 1.5, 1.0]),
+ (MC, np.array([[2], [3], [1]])),
+ (BIN, np.array([[0], [1], [1]])),
+ (CNT, np.array([[0.0], [1.5], [1.0]])),
+ (MMC, np.array([[0, 2], [1, 3], [2, 3]])),
+ (MCN, np.array([[0.5, 2.0], [1.1, 3.0], [2.0, 3.0]])),
+ ]
+ # expected type given input types, or None for error
+ # (types will be tried in either order)
+ EXPECTED = {
+ (IND, IND): IND,
+ (MC, MC): MC,
+ (BIN, BIN): BIN,
+ (MC, IND): None,
+ (BIN, IND): None,
+ (BIN, MC): MC,
+ # Disallowed types
+ (CNT, CNT): None,
+ (MMC, MMC): None,
+ (MCN, MCN): None,
+ (IND, CNT): None,
+ (MC, CNT): None,
+ (BIN, CNT): None,
+ (MMC, CNT): None,
+ (MCN, CNT): None,
+ (IND, MMC): None,
+ (MC, MMC): None,
+ (BIN, MMC): None,
+ (MCN, MMC): None,
+ (IND, MCN): None,
+ (MC, MCN): None,
+ (BIN, MCN): None,
+ }
+
+ for (type1, y1), (type2, y2) in product(EXAMPLES, repeat=2):
+ try:
+ expected = EXPECTED[type1, type2]
+ except KeyError:
+ expected = EXPECTED[type2, type1]
+ if expected is None:
+ with pytest.raises(ValueError):
+ _check_targets(y1, y2)
+
+ if type1 != type2:
+ err_msg = (
+ "Classification metrics can't handle a mix "
+ "of {0} and {1} targets".format(type1, type2)
+ )
+ with pytest.raises(ValueError, match=err_msg):
+ _check_targets(y1, y2)
+
+ else:
+ if type1 not in (BIN, MC, IND):
+ err_msg = "{0} is not supported".format(type1)
+ with pytest.raises(ValueError, match=err_msg):
+ _check_targets(y1, y2)
+
+ else:
+ merged_type, y1out, y2out = _check_targets(y1, y2)
+ assert merged_type == expected
+ if merged_type.startswith("multilabel"):
+ assert y1out.format == "csr"
+ assert y2out.format == "csr"
+ else:
+ assert_array_equal(y1out, np.squeeze(y1))
+ assert_array_equal(y2out, np.squeeze(y2))
+ with pytest.raises(ValueError):
+ _check_targets(y1[:-1], y2)
+
+ # Make sure seq of seq is not supported
+ y1 = [(1, 2), (0, 2, 3)]
+ y2 = [(2,), (0, 2)]
+ msg = (
+ "You appear to be using a legacy multi-label data representation. "
+ "Sequence of sequences are no longer supported; use a binary array"
+ " or sparse matrix instead - the MultiLabelBinarizer"
+ " transformer can convert to this format."
+ )
+ with pytest.raises(ValueError, match=msg):
+ _check_targets(y1, y2)
+
+
+def test__check_targets_multiclass_with_both_y_true_and_y_pred_binary():
+ # https://github.com/scikit-learn/scikit-learn/issues/8098
+ y_true = [0, 1]
+ y_pred = [0, -1]
+ assert _check_targets(y_true, y_pred)[0] == "multiclass"
+
+
+def test_hinge_loss_binary():
+ y_true = np.array([-1, 1, 1, -1])
+ pred_decision = np.array([-8.5, 0.5, 1.5, -0.3])
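+    # Binary hinge loss is mean(max(0, 1 - y * d)) with y in {-1, +1};
+    # here the per-sample losses are [0, 0.5, 0, 0.7], hence 1.2 / 4.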
+ assert hinge_loss(y_true, pred_decision) == 1.2 / 4
+
+ y_true = np.array([0, 2, 2, 0])
+ pred_decision = np.array([-8.5, 0.5, 1.5, -0.3])
+ assert hinge_loss(y_true, pred_decision) == 1.2 / 4
+
+
+def test_hinge_loss_multiclass():
+ pred_decision = np.array(
+ [
+ [+0.36, -0.17, -0.58, -0.99],
+ [-0.54, -0.37, -0.48, -0.58],
+ [-1.45, -0.58, -0.38, -0.17],
+ [-0.54, -0.38, -0.48, -0.58],
+ [-2.36, -0.79, -0.27, +0.24],
+ [-1.45, -0.58, -0.38, -0.17],
+ ]
+ )
+ y_true = np.array([0, 1, 2, 1, 3, 2])
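+    # The multiclass hinge loss per sample is max(0, 1 + max_{r != y_i} d_r - d_{y_i});
+    # each line below picks the most confident wrong-class decision value by hand.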
+ dummy_losses = np.array(
+ [
+ 1 - pred_decision[0][0] + pred_decision[0][1],
+ 1 - pred_decision[1][1] + pred_decision[1][2],
+ 1 - pred_decision[2][2] + pred_decision[2][3],
+ 1 - pred_decision[3][1] + pred_decision[3][2],
+ 1 - pred_decision[4][3] + pred_decision[4][2],
+ 1 - pred_decision[5][2] + pred_decision[5][3],
+ ]
+ )
+ np.clip(dummy_losses, 0, None, out=dummy_losses)
+ dummy_hinge_loss = np.mean(dummy_losses)
+ assert hinge_loss(y_true, pred_decision) == dummy_hinge_loss
+
+
+def test_hinge_loss_multiclass_missing_labels_with_labels_none():
+ y_true = np.array([0, 1, 2, 2])
+ pred_decision = np.array(
+ [
+ [+1.27, 0.034, -0.68, -1.40],
+ [-1.45, -0.58, -0.38, -0.17],
+ [-2.36, -0.79, -0.27, +0.24],
+ [-2.36, -0.79, -0.27, +0.24],
+ ]
+ )
+ error_message = (
+ "Please include all labels in y_true or pass labels as third argument"
+ )
+ with pytest.raises(ValueError, match=error_message):
+ hinge_loss(y_true, pred_decision)
+
+
+def test_hinge_loss_multiclass_no_consistent_pred_decision_shape():
+ # test for inconsistency between multiclass problem and pred_decision
+ # argument
+ y_true = np.array([2, 1, 0, 1, 0, 1, 1])
+ pred_decision = np.array([0, 1, 2, 1, 0, 2, 1])
+ error_message = (
+ "The shape of pred_decision cannot be 1d array"
+ "with a multiclass target. pred_decision shape "
+ "must be (n_samples, n_classes), that is "
+ "(7, 3). Got: (7,)"
+ )
+ with pytest.raises(ValueError, match=re.escape(error_message)):
+ hinge_loss(y_true=y_true, pred_decision=pred_decision)
+
+ # test for inconsistency between pred_decision shape and labels number
+ pred_decision = np.array([[0, 1], [0, 1], [0, 1], [0, 1], [2, 0], [0, 1], [1, 0]])
+ labels = [0, 1, 2]
+ error_message = (
+ "The shape of pred_decision is not "
+ "consistent with the number of classes. "
+ "With a multiclass target, pred_decision "
+ "shape must be (n_samples, n_classes), that is "
+ "(7, 3). Got: (7, 2)"
+ )
+ with pytest.raises(ValueError, match=re.escape(error_message)):
+ hinge_loss(y_true=y_true, pred_decision=pred_decision, labels=labels)
+
+
+def test_hinge_loss_multiclass_with_missing_labels():
+ pred_decision = np.array(
+ [
+ [+0.36, -0.17, -0.58, -0.99],
+ [-0.55, -0.38, -0.48, -0.58],
+ [-1.45, -0.58, -0.38, -0.17],
+ [-0.55, -0.38, -0.48, -0.58],
+ [-1.45, -0.58, -0.38, -0.17],
+ ]
+ )
+ y_true = np.array([0, 1, 2, 1, 2])
+ labels = np.array([0, 1, 2, 3])
+ dummy_losses = np.array(
+ [
+ 1 - pred_decision[0][0] + pred_decision[0][1],
+ 1 - pred_decision[1][1] + pred_decision[1][2],
+ 1 - pred_decision[2][2] + pred_decision[2][3],
+ 1 - pred_decision[3][1] + pred_decision[3][2],
+ 1 - pred_decision[4][2] + pred_decision[4][3],
+ ]
+ )
+ np.clip(dummy_losses, 0, None, out=dummy_losses)
+ dummy_hinge_loss = np.mean(dummy_losses)
+ assert hinge_loss(y_true, pred_decision, labels=labels) == dummy_hinge_loss
+
+
+def test_hinge_loss_multiclass_missing_labels_only_two_unq_in_y_true():
+ # non-regression test for:
+ # https://github.com/scikit-learn/scikit-learn/issues/17630
+    # check that the hinge loss can be computed when an explicit `labels` array
+    # is provided, so that y_true does not need to contain every label
+ pred_decision = np.array(
+ [
+ [+0.36, -0.17, -0.58],
+ [-0.15, -0.58, -0.48],
+ [-1.45, -0.58, -0.38],
+ [-0.55, -0.78, -0.42],
+ [-1.45, -0.58, -0.38],
+ ]
+ )
+ y_true = np.array([0, 2, 2, 0, 2])
+ labels = np.array([0, 1, 2])
+ dummy_losses = np.array(
+ [
+ 1 - pred_decision[0][0] + pred_decision[0][1],
+ 1 - pred_decision[1][2] + pred_decision[1][0],
+ 1 - pred_decision[2][2] + pred_decision[2][1],
+ 1 - pred_decision[3][0] + pred_decision[3][2],
+ 1 - pred_decision[4][2] + pred_decision[4][1],
+ ]
+ )
+ np.clip(dummy_losses, 0, None, out=dummy_losses)
+ dummy_hinge_loss = np.mean(dummy_losses)
+ assert_almost_equal(
+ hinge_loss(y_true, pred_decision, labels=labels), dummy_hinge_loss
+ )
+
+
+def test_hinge_loss_multiclass_invariance_lists():
+ # Currently, invariance of string and integer labels cannot be tested
+ # in common invariance tests because invariance tests for multiclass
+    # decision functions are not implemented yet.
+ y_true = ["blue", "green", "red", "green", "white", "red"]
+ pred_decision = [
+ [+0.36, -0.17, -0.58, -0.99],
+ [-0.55, -0.38, -0.48, -0.58],
+ [-1.45, -0.58, -0.38, -0.17],
+ [-0.55, -0.38, -0.48, -0.58],
+ [-2.36, -0.79, -0.27, +0.24],
+ [-1.45, -0.58, -0.38, -0.17],
+ ]
+ dummy_losses = np.array(
+ [
+ 1 - pred_decision[0][0] + pred_decision[0][1],
+ 1 - pred_decision[1][1] + pred_decision[1][2],
+ 1 - pred_decision[2][2] + pred_decision[2][3],
+ 1 - pred_decision[3][1] + pred_decision[3][2],
+ 1 - pred_decision[4][3] + pred_decision[4][2],
+ 1 - pred_decision[5][2] + pred_decision[5][3],
+ ]
+ )
+ np.clip(dummy_losses, 0, None, out=dummy_losses)
+ dummy_hinge_loss = np.mean(dummy_losses)
+ assert hinge_loss(y_true, pred_decision) == dummy_hinge_loss
+
+
+def test_log_loss():
+ # binary case with symbolic labels ("no" < "yes")
+ y_true = ["no", "no", "no", "yes", "yes", "yes"]
+ y_pred = np.array(
+ [[0.5, 0.5], [0.1, 0.9], [0.01, 0.99], [0.9, 0.1], [0.75, 0.25], [0.001, 0.999]]
+ )
+ loss = log_loss(y_true, y_pred)
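+    # For binary targets log_loss is -mean(y*log(p) + (1-y)*log(1-p));
+    # bernoulli.logpmf(k, p) gives exactly that per-sample log-likelihood.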
+ loss_true = -np.mean(bernoulli.logpmf(np.array(y_true) == "yes", y_pred[:, 1]))
+ assert_almost_equal(loss, loss_true)
+
+ # multiclass case; adapted from http://bit.ly/RJJHWA
+ y_true = [1, 0, 2]
+ y_pred = [[0.2, 0.7, 0.1], [0.6, 0.2, 0.2], [0.6, 0.1, 0.3]]
+ loss = log_loss(y_true, y_pred, normalize=True)
+ assert_almost_equal(loss, 0.6904911)
+
+ # check that we got all the shapes and axes right
+ # by doubling the length of y_true and y_pred
+ y_true *= 2
+ y_pred *= 2
+ loss = log_loss(y_true, y_pred, normalize=False)
+ assert_almost_equal(loss, 0.6904911 * 6, decimal=6)
+
+ # check eps and handling of absolute zero and one probabilities
+ y_pred = np.asarray(y_pred) > 0.5
+ loss = log_loss(y_true, y_pred, normalize=True, eps=0.1)
+ assert_almost_equal(loss, log_loss(y_true, np.clip(y_pred, 0.1, 0.9)))
+
+ # binary case: check correct boundary values for eps = 0
+ assert log_loss([0, 1], [0, 1], eps=0) == 0
+ assert log_loss([0, 1], [0, 0], eps=0) == np.inf
+ assert log_loss([0, 1], [1, 1], eps=0) == np.inf
+
+ # multiclass case: check correct boundary values for eps = 0
+ assert log_loss([0, 1, 2], [[1, 0, 0], [0, 1, 0], [0, 0, 1]], eps=0) == 0
+ assert log_loss([0, 1, 2], [[0, 0.5, 0.5], [0, 1, 0], [0, 0, 1]], eps=0) == np.inf
+
+ # raise error if number of classes are not equal.
+ y_true = [1, 0, 2]
+ y_pred = [[0.2, 0.7], [0.6, 0.5], [0.4, 0.1]]
+ with pytest.raises(ValueError):
+ log_loss(y_true, y_pred)
+
+ # case when y_true is a string array object
+ y_true = ["ham", "spam", "spam", "ham"]
+ y_pred = [[0.2, 0.7], [0.6, 0.5], [0.4, 0.1], [0.7, 0.2]]
+ loss = log_loss(y_true, y_pred)
+ assert_almost_equal(loss, 1.0383217, decimal=6)
+
+ # test labels option
+
+ y_true = [2, 2]
+ y_pred = [[0.2, 0.7], [0.6, 0.5]]
+ y_score = np.array([[0.1, 0.9], [0.1, 0.9]])
+ error_str = (
+ r"y_true contains only one label \(2\). Please provide "
+ r"the true labels explicitly through the labels argument."
+ )
+ with pytest.raises(ValueError, match=error_str):
+ log_loss(y_true, y_pred)
+
+    y_pred = [[0.2, 0.7], [0.6, 0.5], [0.2, 0.3]]
+    error_str = "Found input variables with inconsistent numbers of samples: [3, 2]"
+    with pytest.raises(ValueError, match=re.escape(error_str)):
+        log_loss(y_true, y_pred)
+
+ # works when the labels argument is used
+
+ true_log_loss = -np.mean(np.log(y_score[:, 1]))
+ calculated_log_loss = log_loss(y_true, y_score, labels=[1, 2])
+ assert_almost_equal(calculated_log_loss, true_log_loss)
+
+ # ensure labels work when len(np.unique(y_true)) != y_pred.shape[1]
+ y_true = [1, 2, 2]
+ y_score2 = [[0.2, 0.7, 0.3], [0.6, 0.5, 0.3], [0.3, 0.9, 0.1]]
+ loss = log_loss(y_true, y_score2, labels=[1, 2, 3])
+ assert_almost_equal(loss, 1.0630345, decimal=6)
+
+
+def test_log_loss_eps_auto(global_dtype):
+ """Check the behaviour of `eps="auto"` that changes depending on the input
+ array dtype.
+ Non-regression test for:
+ https://github.com/scikit-learn/scikit-learn/issues/24315
+ """
+ y_true = np.array([0, 1], dtype=global_dtype)
+ y_pred = y_true.copy()
+
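+    # eps="auto" is documented to resolve to np.finfo(y_pred.dtype).eps, so
+    # hard 0/1 probabilities get clipped and the loss stays finite.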
+ loss = log_loss(y_true, y_pred, eps="auto")
+ assert np.isfinite(loss)
+
+
+def test_log_loss_eps_auto_float16():
+ """Check the behaviour of `eps="auto"` for np.float16"""
+ y_true = np.array([0, 1], dtype=np.float16)
+ y_pred = y_true.copy()
+
+ loss = log_loss(y_true, y_pred, eps="auto")
+ assert np.isfinite(loss)
+
+
+def test_log_loss_pandas_input():
+ # case when input is a pandas series and dataframe gh-5715
+ y_tr = np.array(["ham", "spam", "spam", "ham"])
+ y_pr = np.array([[0.2, 0.7], [0.6, 0.5], [0.4, 0.1], [0.7, 0.2]])
+ types = [(MockDataFrame, MockDataFrame)]
+ try:
+ from modin.pandas import Series, DataFrame
+
+ types.append((Series, DataFrame))
+ except ImportError:
+ pass
+ for TrueInputType, PredInputType in types:
+ # y_pred dataframe, y_true series
+ y_true, y_pred = TrueInputType(y_tr), PredInputType(y_pr)
+ loss = log_loss(y_true, y_pred)
+ assert_almost_equal(loss, 1.0383217, decimal=6)
+
+
+def test_brier_score_loss():
+ # Check brier_score_loss function
+ y_true = np.array([0, 1, 1, 0, 1, 1])
+ y_pred = np.array([0.1, 0.8, 0.9, 0.3, 1.0, 0.95])
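+    # The Brier score is mean((p_i - y_i) ** 2); squaring the Euclidean norm and
+    # dividing by the number of samples computes the same quantity.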
+ true_score = linalg.norm(y_true - y_pred) ** 2 / len(y_true)
+
+ assert_almost_equal(brier_score_loss(y_true, y_true), 0.0)
+ assert_almost_equal(brier_score_loss(y_true, y_pred), true_score)
+ assert_almost_equal(brier_score_loss(1.0 + y_true, y_pred), true_score)
+ assert_almost_equal(brier_score_loss(2 * y_true - 1, y_pred), true_score)
+ with pytest.raises(ValueError):
+ brier_score_loss(y_true, y_pred[1:])
+ with pytest.raises(ValueError):
+ brier_score_loss(y_true, y_pred + 1.0)
+ with pytest.raises(ValueError):
+ brier_score_loss(y_true, y_pred - 1.0)
+
+ # ensure to raise an error for multiclass y_true
+ y_true = np.array([0, 1, 2, 0])
+ y_pred = np.array([0.8, 0.6, 0.4, 0.2])
+ error_message = (
+ "Only binary classification is supported. The type of the target is multiclass"
+ )
+
+ with pytest.raises(ValueError, match=error_message):
+ brier_score_loss(y_true, y_pred)
+
+ # calculate correctly when there's only one class in y_true
+ assert_almost_equal(brier_score_loss([-1], [0.4]), 0.16)
+ assert_almost_equal(brier_score_loss([0], [0.4]), 0.16)
+ assert_almost_equal(brier_score_loss([1], [0.4]), 0.36)
+ assert_almost_equal(brier_score_loss(["foo"], [0.4], pos_label="bar"), 0.16)
+ assert_almost_equal(brier_score_loss(["foo"], [0.4], pos_label="foo"), 0.36)
+
+
+def test_balanced_accuracy_score_unseen():
+ msg = "y_pred contains classes not in y_true"
+ with pytest.warns(UserWarning, match=msg):
+ balanced_accuracy_score([0, 0, 0], [0, 0, 1])
+
+
+@pytest.mark.parametrize(
+ "y_true,y_pred",
+ [
+ (["a", "b", "a", "b"], ["a", "a", "a", "b"]),
+ (["a", "b", "c", "b"], ["a", "a", "a", "b"]),
+ (["a", "a", "a", "b"], ["a", "b", "c", "b"]),
+ ],
+)
+def test_balanced_accuracy_score(y_true, y_pred):
+ macro_recall = recall_score(
+ y_true, y_pred, average="macro", labels=np.unique(y_true)
+ )
+ with ignore_warnings():
+ # Warnings are tested in test_balanced_accuracy_score_unseen
+ balanced = balanced_accuracy_score(y_true, y_pred)
+ assert balanced == pytest.approx(macro_recall)
+ adjusted = balanced_accuracy_score(y_true, y_pred, adjusted=True)
+ chance = balanced_accuracy_score(y_true, np.full_like(y_true, y_true[0]))
+ assert adjusted == (balanced - chance) / (1 - chance)
diff --git a/modin/pandas/test/interoperability/sklearn/model_selection/test_search.py b/modin/pandas/test/interoperability/sklearn/model_selection/test_search.py
new file mode 100644
index 00000000000..ef67616d1ed
--- /dev/null
+++ b/modin/pandas/test/interoperability/sklearn/model_selection/test_search.py
@@ -0,0 +1,2430 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+# flake8: noqa
+
+"""Test the search module"""
+
+from collections.abc import Iterable, Sized
+from io import StringIO
+from itertools import chain, product
+from functools import partial
+import pickle
+import sys
+from types import GeneratorType
+import re
+
+import numpy as np
+import scipy.sparse as sp
+import pytest
+
+from sklearn.utils._testing import (
+ assert_array_equal,
+ assert_array_almost_equal,
+ assert_allclose,
+ assert_almost_equal,
+ ignore_warnings,
+ MinimalClassifier,
+ MinimalRegressor,
+ MinimalTransformer,
+)
+from sklearn.utils._mocking import CheckingClassifier, MockDataFrame
+
+from scipy.stats import bernoulli, expon, uniform
+
+from sklearn.base import BaseEstimator, ClassifierMixin
+from sklearn.base import is_classifier
+from sklearn.datasets import make_classification
+from sklearn.datasets import make_blobs
+from sklearn.datasets import make_multilabel_classification
+
+from sklearn.model_selection import train_test_split
+from sklearn.model_selection import KFold
+from sklearn.model_selection import StratifiedKFold
+from sklearn.model_selection import StratifiedShuffleSplit
+from sklearn.model_selection import LeaveOneGroupOut
+from sklearn.model_selection import LeavePGroupsOut
+from sklearn.model_selection import GroupKFold
+from sklearn.model_selection import GroupShuffleSplit
+from sklearn.model_selection import GridSearchCV
+from sklearn.model_selection import RandomizedSearchCV
+from sklearn.model_selection import ParameterGrid
+from sklearn.model_selection import ParameterSampler
+from sklearn.model_selection._search import BaseSearchCV
+
+from sklearn.model_selection._validation import FitFailedWarning
+
+from sklearn.svm import LinearSVC, SVC
+from sklearn.tree import DecisionTreeRegressor
+from sklearn.tree import DecisionTreeClassifier
+from sklearn.cluster import KMeans
+from sklearn.neighbors import KernelDensity
+from sklearn.neighbors import LocalOutlierFactor
+from sklearn.neighbors import KNeighborsClassifier
+from sklearn.metrics import f1_score
+from sklearn.metrics import recall_score
+from sklearn.metrics import accuracy_score
+from sklearn.metrics import make_scorer
+from sklearn.metrics import roc_auc_score
+from sklearn.metrics import confusion_matrix
+from sklearn.metrics import r2_score
+from sklearn.metrics.pairwise import euclidean_distances
+from sklearn.impute import SimpleImputer
+from sklearn.pipeline import Pipeline
+from sklearn.linear_model import Ridge, SGDClassifier, LinearRegression
+from sklearn.ensemble import HistGradientBoostingClassifier
+
+from sklearn.model_selection.tests.common import OneTimeSplitter
+
+
+# MockClassifier below intentionally does not inherit from BaseEstimator,
+# to test hyperparameter search on user-defined classifiers.
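+# GridSearchCV relies only on duck typing: fit/predict/score plus
+# get_params/set_params are enough for cloning and setting candidate parameters.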
+class MockClassifier:
+ """Dummy classifier to test the parameter search algorithms"""
+
+ def __init__(self, foo_param=0):
+ self.foo_param = foo_param
+
+ def fit(self, X, Y):
+ assert len(X) == len(Y)
+ self.classes_ = np.unique(Y)
+ return self
+
+ def predict(self, T):
+ return T.shape[0]
+
+ def transform(self, X):
+ return X + self.foo_param
+
+ def inverse_transform(self, X):
+ return X - self.foo_param
+
+ predict_proba = predict
+ predict_log_proba = predict
+ decision_function = predict
+
+ def score(self, X=None, Y=None):
+ if self.foo_param > 1:
+ score = 1.0
+ else:
+ score = 0.0
+ return score
+
+ def get_params(self, deep=False):
+ return {"foo_param": self.foo_param}
+
+ def set_params(self, **params):
+ self.foo_param = params["foo_param"]
+ return self
+
+
+class LinearSVCNoScore(LinearSVC):
+ """A LinearSVC classifier that has no score method."""
+
+ @property
+ def score(self):
+ raise AttributeError
+
+
+X = np.array([[-1, -1], [-2, -1], [1, 1], [2, 1]])
+y = np.array([1, 1, 2, 2])
+
+
+def assert_grid_iter_equals_getitem(grid):
+ assert list(grid) == [grid[i] for i in range(len(grid))]
+
+
+@pytest.mark.parametrize("klass", [ParameterGrid, partial(ParameterSampler, n_iter=10)])
+@pytest.mark.parametrize(
+ "input, error_type, error_message",
+ [
+ (0, TypeError, r"Parameter .* a dict or a list, got: 0 of type int"),
+ ([{"foo": [0]}, 0], TypeError, r"Parameter .* is not a dict \(0\)"),
+ (
+ {"foo": 0},
+ TypeError,
+ r"Parameter (grid|distribution) for parameter 'foo' (is not|needs to be) "
+ r"(a list or a numpy array|iterable or a distribution).*",
+ ),
+ ],
+)
+def test_validate_parameter_input(klass, input, error_type, error_message):
+ with pytest.raises(error_type, match=error_message):
+ klass(input)
+
+
+def test_parameter_grid():
+ # Test basic properties of ParameterGrid.
+ params1 = {"foo": [1, 2, 3]}
+ grid1 = ParameterGrid(params1)
+ assert isinstance(grid1, Iterable)
+ assert isinstance(grid1, Sized)
+ assert len(grid1) == 3
+ assert_grid_iter_equals_getitem(grid1)
+
+ params2 = {"foo": [4, 2], "bar": ["ham", "spam", "eggs"]}
+ grid2 = ParameterGrid(params2)
+ assert len(grid2) == 6
+
+ # loop to assert we can iterate over the grid multiple times
+ for i in range(2):
+ # tuple + chain transforms {"a": 1, "b": 2} to ("a", 1, "b", 2)
+ points = set(tuple(chain(*(sorted(p.items())))) for p in grid2)
+ assert points == set(
+ ("bar", x, "foo", y) for x, y in product(params2["bar"], params2["foo"])
+ )
+ assert_grid_iter_equals_getitem(grid2)
+
+ # Special case: empty grid (useful to get default estimator settings)
+ empty = ParameterGrid({})
+ assert len(empty) == 1
+ assert list(empty) == [{}]
+ assert_grid_iter_equals_getitem(empty)
+ with pytest.raises(IndexError):
+ empty[1]
+
+ has_empty = ParameterGrid([{"C": [1, 10]}, {}, {"C": [0.5]}])
+ assert len(has_empty) == 4
+ assert list(has_empty) == [{"C": 1}, {"C": 10}, {}, {"C": 0.5}]
+ assert_grid_iter_equals_getitem(has_empty)
+
+
+def test_grid_search():
+ # Test that the best estimator contains the right value for foo_param
+ clf = MockClassifier()
+ grid_search = GridSearchCV(clf, {"foo_param": [1, 2, 3]}, cv=3, verbose=3)
+ # make sure it selects the smallest parameter in case of ties
+ old_stdout = sys.stdout
+ sys.stdout = StringIO()
+ grid_search.fit(X, y)
+ sys.stdout = old_stdout
+ assert grid_search.best_estimator_.foo_param == 2
+
+ assert_array_equal(grid_search.cv_results_["param_foo_param"].data, [1, 2, 3])
+
+ # Smoke test the score etc:
+ grid_search.score(X, y)
+ grid_search.predict_proba(X)
+ grid_search.decision_function(X)
+ grid_search.transform(X)
+
+ # Test exception handling on scoring
+ grid_search.scoring = "sklearn"
+ with pytest.raises(ValueError):
+ grid_search.fit(X, y)
+
+
+def test_grid_search_pipeline_steps():
+ # check that parameters that are estimators are cloned before fitting
+ pipe = Pipeline([("regressor", LinearRegression())])
+ param_grid = {"regressor": [LinearRegression(), Ridge()]}
+ grid_search = GridSearchCV(pipe, param_grid, cv=2)
+ grid_search.fit(X, y)
+ regressor_results = grid_search.cv_results_["param_regressor"]
+ assert isinstance(regressor_results[0], LinearRegression)
+ assert isinstance(regressor_results[1], Ridge)
+ assert not hasattr(regressor_results[0], "coef_")
+ assert not hasattr(regressor_results[1], "coef_")
+ assert regressor_results[0] is not grid_search.best_estimator_
+ assert regressor_results[1] is not grid_search.best_estimator_
+ # check that we didn't modify the parameter grid that was passed
+ assert not hasattr(param_grid["regressor"][0], "coef_")
+ assert not hasattr(param_grid["regressor"][1], "coef_")
+
+
+@pytest.mark.parametrize("SearchCV", [GridSearchCV, RandomizedSearchCV])
+def test_SearchCV_with_fit_params(SearchCV):
+ X = np.arange(100).reshape(10, 10)
+ y = np.array([0] * 5 + [1] * 5)
+ clf = CheckingClassifier(expected_fit_params=["spam", "eggs"])
+ searcher = SearchCV(clf, {"foo_param": [1, 2, 3]}, cv=2, error_score="raise")
+
+ # The CheckingClassifier generates an assertion error if
+ # a parameter is missing or has length != len(X).
+ err_msg = r"Expected fit parameter\(s\) \['eggs'\] not seen."
+ with pytest.raises(AssertionError, match=err_msg):
+ searcher.fit(X, y, spam=np.ones(10))
+
+ err_msg = "Fit parameter spam has length 1; expected"
+ with pytest.raises(AssertionError, match=err_msg):
+ searcher.fit(X, y, spam=np.ones(1), eggs=np.zeros(10))
+ searcher.fit(X, y, spam=np.ones(10), eggs=np.zeros(10))
+
+
+@ignore_warnings
+def test_grid_search_no_score():
+ # Test grid-search on classifier that has no score function.
+ clf = LinearSVC(random_state=0)
+ X, y = make_blobs(random_state=0, centers=2)
+ Cs = [0.1, 1, 10]
+ clf_no_score = LinearSVCNoScore(random_state=0)
+ grid_search = GridSearchCV(clf, {"C": Cs}, scoring="accuracy")
+ grid_search.fit(X, y)
+
+ grid_search_no_score = GridSearchCV(clf_no_score, {"C": Cs}, scoring="accuracy")
+ # smoketest grid search
+ grid_search_no_score.fit(X, y)
+
+ # check that best params are equal
+ assert grid_search_no_score.best_params_ == grid_search.best_params_
+ # check that we can call score and that it gives the correct result
+ assert grid_search.score(X, y) == grid_search_no_score.score(X, y)
+
+ # giving no scoring function raises an error
+ grid_search_no_score = GridSearchCV(clf_no_score, {"C": Cs})
+ with pytest.raises(TypeError, match="no scoring"):
+ grid_search_no_score.fit([[1]])
+
+
+def test_grid_search_score_method():
+ X, y = make_classification(n_samples=100, n_classes=2, flip_y=0.2, random_state=0)
+ clf = LinearSVC(random_state=0)
+ grid = {"C": [0.1]}
+
+ search_no_scoring = GridSearchCV(clf, grid, scoring=None).fit(X, y)
+ search_accuracy = GridSearchCV(clf, grid, scoring="accuracy").fit(X, y)
+ search_no_score_method_auc = GridSearchCV(
+ LinearSVCNoScore(), grid, scoring="roc_auc"
+ ).fit(X, y)
+ search_auc = GridSearchCV(clf, grid, scoring="roc_auc").fit(X, y)
+
+ # Check warning only occurs in situation where behavior changed:
+ # estimator requires score method to compete with scoring parameter
+ score_no_scoring = search_no_scoring.score(X, y)
+ score_accuracy = search_accuracy.score(X, y)
+ score_no_score_auc = search_no_score_method_auc.score(X, y)
+ score_auc = search_auc.score(X, y)
+
+ # ensure the test is sane
+ assert score_auc < 1.0
+ assert score_accuracy < 1.0
+ assert score_auc != score_accuracy
+
+ assert_almost_equal(score_accuracy, score_no_scoring)
+ assert_almost_equal(score_auc, score_no_score_auc)
+
+
+def test_grid_search_groups():
+ # Check if ValueError (when groups is None) propagates to GridSearchCV
+ # And also check if groups is correctly passed to the cv object
+ rng = np.random.RandomState(0)
+
+ X, y = make_classification(n_samples=15, n_classes=2, random_state=0)
+ groups = rng.randint(0, 3, 15)
+
+ clf = LinearSVC(random_state=0)
+ grid = {"C": [1]}
+
+ group_cvs = [
+ LeaveOneGroupOut(),
+ LeavePGroupsOut(2),
+ GroupKFold(n_splits=3),
+ GroupShuffleSplit(),
+ ]
+ error_msg = "The 'groups' parameter should not be None."
+ for cv in group_cvs:
+ gs = GridSearchCV(clf, grid, cv=cv)
+ with pytest.raises(ValueError, match=error_msg):
+ gs.fit(X, y)
+ gs.fit(X, y, groups=groups)
+
+ non_group_cvs = [StratifiedKFold(), StratifiedShuffleSplit()]
+ for cv in non_group_cvs:
+ gs = GridSearchCV(clf, grid, cv=cv)
+ # Should not raise an error
+ gs.fit(X, y)
+
+
+def test_classes__property():
+ # Test that classes_ property matches best_estimator_.classes_
+ X = np.arange(100).reshape(10, 10)
+ y = np.array([0] * 5 + [1] * 5)
+ Cs = [0.1, 1, 10]
+
+ grid_search = GridSearchCV(LinearSVC(random_state=0), {"C": Cs})
+ grid_search.fit(X, y)
+ assert_array_equal(grid_search.best_estimator_.classes_, grid_search.classes_)
+
+ # Test that regressors do not have a classes_ attribute
+ grid_search = GridSearchCV(Ridge(), {"alpha": [1.0, 2.0]})
+ grid_search.fit(X, y)
+ assert not hasattr(grid_search, "classes_")
+
+ # Test that the grid searcher has no classes_ attribute before it's fit
+ grid_search = GridSearchCV(LinearSVC(random_state=0), {"C": Cs})
+ assert not hasattr(grid_search, "classes_")
+
+ # Test that the grid searcher has no classes_ attribute without a refit
+ grid_search = GridSearchCV(LinearSVC(random_state=0), {"C": Cs}, refit=False)
+ grid_search.fit(X, y)
+ assert not hasattr(grid_search, "classes_")
+
+
+def test_trivial_cv_results_attr():
+ # Test search over a "grid" with only one point.
+ clf = MockClassifier()
+ grid_search = GridSearchCV(clf, {"foo_param": [1]}, cv=3)
+ grid_search.fit(X, y)
+ assert hasattr(grid_search, "cv_results_")
+
+ random_search = RandomizedSearchCV(clf, {"foo_param": [0]}, n_iter=1, cv=3)
+ random_search.fit(X, y)
+    assert hasattr(random_search, "cv_results_")
+
+
+def test_no_refit():
+ # Test that GSCV can be used for model selection alone without refitting
+ clf = MockClassifier()
+ for scoring in [None, ["accuracy", "precision"]]:
+ grid_search = GridSearchCV(clf, {"foo_param": [1, 2, 3]}, refit=False, cv=3)
+ grid_search.fit(X, y)
+ assert (
+ not hasattr(grid_search, "best_estimator_")
+ and hasattr(grid_search, "best_index_")
+ and hasattr(grid_search, "best_params_")
+ )
+
+ # Make sure the functions predict/transform etc raise meaningful
+ # error messages
+ for fn_name in (
+ "predict",
+ "predict_proba",
+ "predict_log_proba",
+ "transform",
+ "inverse_transform",
+ ):
+ error_msg = (
+ f"`refit=False`. {fn_name} is available only after "
+ "refitting on the best parameters"
+ )
+ with pytest.raises(AttributeError, match=error_msg):
+ getattr(grid_search, fn_name)(X)
+
+ # Test that an invalid refit param raises appropriate error messages
+ error_msg = (
+ "For multi-metric scoring, the parameter refit must be set to a scorer key"
+ )
+ for refit in [True, "recall", "accuracy"]:
+ with pytest.raises(ValueError, match=error_msg):
+ GridSearchCV(
+ clf, {}, refit=refit, scoring={"acc": "accuracy", "prec": "precision"}
+ ).fit(X, y)
+
+
+def test_grid_search_error():
+ # Test that grid search will capture errors on data with different length
+ X_, y_ = make_classification(n_samples=200, n_features=100, random_state=0)
+
+ clf = LinearSVC()
+ cv = GridSearchCV(clf, {"C": [0.1, 1.0]})
+ with pytest.raises(ValueError):
+ cv.fit(X_[:180], y_)
+
+
+def test_grid_search_one_grid_point():
+ X_, y_ = make_classification(n_samples=200, n_features=100, random_state=0)
+ param_dict = {"C": [1.0], "kernel": ["rbf"], "gamma": [0.1]}
+
+ clf = SVC(gamma="auto")
+ cv = GridSearchCV(clf, param_dict)
+ cv.fit(X_, y_)
+
+ clf = SVC(C=1.0, kernel="rbf", gamma=0.1)
+ clf.fit(X_, y_)
+
+ assert_array_equal(clf.dual_coef_, cv.best_estimator_.dual_coef_)
+
+
+def test_grid_search_when_param_grid_includes_range():
+ # Test that the best estimator contains the right value for foo_param
+ clf = MockClassifier()
+    grid_search = GridSearchCV(clf, {"foo_param": range(1, 4)}, cv=3)
+ grid_search.fit(X, y)
+ assert grid_search.best_estimator_.foo_param == 2
+
+
+def test_grid_search_bad_param_grid():
+ X, y = make_classification(n_samples=10, n_features=5, random_state=0)
+ param_dict = {"C": 1}
+ clf = SVC(gamma="auto")
+ error_msg = re.escape(
+ "Parameter grid for parameter 'C' needs to be a list or "
+ "a numpy array, but got 1 (of type int) instead. Single "
+ "values need to be wrapped in a list with one element."
+ )
+ search = GridSearchCV(clf, param_dict)
+ with pytest.raises(TypeError, match=error_msg):
+ search.fit(X, y)
+
+ param_dict = {"C": []}
+ clf = SVC()
+ error_msg = re.escape(
+ "Parameter grid for parameter 'C' need to be a non-empty sequence, got: []"
+ )
+ search = GridSearchCV(clf, param_dict)
+ with pytest.raises(ValueError, match=error_msg):
+ search.fit(X, y)
+
+ param_dict = {"C": "1,2,3"}
+ clf = SVC(gamma="auto")
+ error_msg = re.escape(
+ "Parameter grid for parameter 'C' needs to be a list or a numpy array, "
+ "but got '1,2,3' (of type str) instead. Single values need to be "
+ "wrapped in a list with one element."
+ )
+ search = GridSearchCV(clf, param_dict)
+ with pytest.raises(TypeError, match=error_msg):
+ search.fit(X, y)
+
+ param_dict = {"C": np.ones((3, 2))}
+ clf = SVC()
+ search = GridSearchCV(clf, param_dict)
+ with pytest.raises(ValueError):
+ search.fit(X, y)
+
+
+def test_grid_search_sparse():
+ # Test that grid search works with both dense and sparse matrices
+ X_, y_ = make_classification(n_samples=200, n_features=100, random_state=0)
+
+ clf = LinearSVC()
+ cv = GridSearchCV(clf, {"C": [0.1, 1.0]})
+ cv.fit(X_[:180], y_[:180])
+ y_pred = cv.predict(X_[180:])
+ C = cv.best_estimator_.C
+
+ X_ = sp.csr_matrix(X_)
+ clf = LinearSVC()
+ cv = GridSearchCV(clf, {"C": [0.1, 1.0]})
+ cv.fit(X_[:180].tocoo(), y_[:180])
+ y_pred2 = cv.predict(X_[180:])
+ C2 = cv.best_estimator_.C
+
+ assert np.mean(y_pred == y_pred2) >= 0.9
+ assert C == C2
+
+
+def test_grid_search_sparse_scoring():
+ X_, y_ = make_classification(n_samples=200, n_features=100, random_state=0)
+
+ clf = LinearSVC()
+ cv = GridSearchCV(clf, {"C": [0.1, 1.0]}, scoring="f1")
+ cv.fit(X_[:180], y_[:180])
+ y_pred = cv.predict(X_[180:])
+ C = cv.best_estimator_.C
+
+ X_ = sp.csr_matrix(X_)
+ clf = LinearSVC()
+ cv = GridSearchCV(clf, {"C": [0.1, 1.0]}, scoring="f1")
+ cv.fit(X_[:180], y_[:180])
+ y_pred2 = cv.predict(X_[180:])
+ C2 = cv.best_estimator_.C
+
+ assert_array_equal(y_pred, y_pred2)
+ assert C == C2
+ # Smoke test the score
+ # np.testing.assert_allclose(f1_score(cv.predict(X_[:180]), y[:180]),
+ # cv.score(X_[:180], y[:180]))
+
+ # test loss where greater is worse
+ def f1_loss(y_true_, y_pred_):
+ return -f1_score(y_true_, y_pred_)
+
+ F1Loss = make_scorer(f1_loss, greater_is_better=False)
+ cv = GridSearchCV(clf, {"C": [0.1, 1.0]}, scoring=F1Loss)
+ cv.fit(X_[:180], y_[:180])
+ y_pred3 = cv.predict(X_[180:])
+ C3 = cv.best_estimator_.C
+
+ assert C == C3
+ assert_array_equal(y_pred, y_pred3)
+
+
+def test_grid_search_precomputed_kernel():
+ # Test that grid search works when the input features are given in the
+ # form of a precomputed kernel matrix
+ X_, y_ = make_classification(n_samples=200, n_features=100, random_state=0)
+
+ # compute the training kernel matrix corresponding to the linear kernel
+ K_train = np.dot(X_[:180], X_[:180].T)
+ y_train = y_[:180]
+
+ clf = SVC(kernel="precomputed")
+ cv = GridSearchCV(clf, {"C": [0.1, 1.0]})
+ cv.fit(K_train, y_train)
+
+ assert cv.best_score_ >= 0
+
+ # compute the test kernel matrix
+ K_test = np.dot(X_[180:], X_[:180].T)
+ y_test = y_[180:]
+
+ y_pred = cv.predict(K_test)
+
+ assert np.mean(y_pred == y_test) >= 0
+
+ # test error is raised when the precomputed kernel is not array-like
+ # or sparse
+ with pytest.raises(ValueError):
+ cv.fit(K_train.tolist(), y_train)
+
+
+def test_grid_search_precomputed_kernel_error_nonsquare():
+ # Test that grid search returns an error with a non-square precomputed
+ # training kernel matrix
+ K_train = np.zeros((10, 20))
+ y_train = np.ones((10,))
+ clf = SVC(kernel="precomputed")
+ cv = GridSearchCV(clf, {"C": [0.1, 1.0]})
+ with pytest.raises(ValueError):
+ cv.fit(K_train, y_train)
+
+
+class BrokenClassifier(BaseEstimator):
+ """Broken classifier that cannot be fit twice"""
+
+ def __init__(self, parameter=None):
+ self.parameter = parameter
+
+ def fit(self, X, y):
+ assert not hasattr(self, "has_been_fit_")
+ self.has_been_fit_ = True
+
+ def predict(self, X):
+ return np.zeros(X.shape[0])
+
+
+@ignore_warnings
+def test_refit():
+ # Regression test for bug in refitting
+ # Simulates re-fitting a broken estimator; this used to break with
+ # sparse SVMs.
+ X = np.arange(100).reshape(10, 10)
+ y = np.array([0] * 5 + [1] * 5)
+
+ clf = GridSearchCV(
+ BrokenClassifier(), [{"parameter": [0, 1]}], scoring="precision", refit=True
+ )
+ clf.fit(X, y)
+
+
+def test_refit_callable():
+ """
+ Test refit=callable, which adds flexibility in identifying the
+ "best" estimator.
+ """
+
+ def refit_callable(cv_results):
+ """
+        Dummy refit callable exercising the `refit=callable` interface.
+        Returns the index of the model with the lowest `mean_test_score`.
+ """
+ # Fit a dummy clf with `refit=True` to get a list of keys in
+ # clf.cv_results_.
+ X, y = make_classification(n_samples=100, n_features=4, random_state=42)
+ clf = GridSearchCV(
+ LinearSVC(random_state=42),
+ {"C": [0.01, 0.1, 1]},
+ scoring="precision",
+ refit=True,
+ )
+ clf.fit(X, y)
+ # Ensure that `best_index_ != 0` for this dummy clf
+ assert clf.best_index_ != 0
+
+ # Assert every key matches those in `cv_results`
+ for key in clf.cv_results_.keys():
+ assert key in cv_results
+
+ return cv_results["mean_test_score"].argmin()
+
+ X, y = make_classification(n_samples=100, n_features=4, random_state=42)
+ clf = GridSearchCV(
+ LinearSVC(random_state=42),
+ {"C": [0.01, 0.1, 1]},
+ scoring="precision",
+ refit=refit_callable,
+ )
+ clf.fit(X, y)
+
+ assert clf.best_index_ == 0
+ # Ensure `best_score_` is disabled when using `refit=callable`
+ assert not hasattr(clf, "best_score_")
+
+
+def test_refit_callable_invalid_type():
+ """
+ Test implementation catches the errors when 'best_index_' returns an
+ invalid result.
+ """
+
+ def refit_callable_invalid_type(cv_results):
+ """
+        Dummy refit callable that returns a non-integer 'best_index_'.
+ """
+ return None
+
+ X, y = make_classification(n_samples=100, n_features=4, random_state=42)
+
+ clf = GridSearchCV(
+ LinearSVC(random_state=42),
+ {"C": [0.1, 1]},
+ scoring="precision",
+ refit=refit_callable_invalid_type,
+ )
+ with pytest.raises(TypeError, match="best_index_ returned is not an integer"):
+ clf.fit(X, y)
+
+
+@pytest.mark.parametrize("out_bound_value", [-1, 2])
+@pytest.mark.parametrize("search_cv", [RandomizedSearchCV, GridSearchCV])
+def test_refit_callable_out_bound(out_bound_value, search_cv):
+ """
+ Test implementation catches the errors when 'best_index_' returns an
+ out of bound result.
+ """
+
+ def refit_callable_out_bound(cv_results):
+ """
+        Dummy refit callable that returns an out-of-bounds 'best_index_'.
+ """
+ return out_bound_value
+
+ X, y = make_classification(n_samples=100, n_features=4, random_state=42)
+
+ clf = search_cv(
+ LinearSVC(random_state=42),
+ {"C": [0.1, 1]},
+ scoring="precision",
+ refit=refit_callable_out_bound,
+ )
+ with pytest.raises(IndexError, match="best_index_ index out of range"):
+ clf.fit(X, y)
+
+
+def test_refit_callable_multi_metric():
+ """
+ Test refit=callable in multiple metric evaluation setting
+ """
+
+ def refit_callable(cv_results):
+ """
+        Dummy refit callable exercising the `refit=callable` interface.
+        Returns the index of the model with the lowest `mean_test_prec`.
+ """
+ assert "mean_test_prec" in cv_results
+ return cv_results["mean_test_prec"].argmin()
+
+ X, y = make_classification(n_samples=100, n_features=4, random_state=42)
+ scoring = {"Accuracy": make_scorer(accuracy_score), "prec": "precision"}
+ clf = GridSearchCV(
+ LinearSVC(random_state=42),
+ {"C": [0.01, 0.1, 1]},
+ scoring=scoring,
+ refit=refit_callable,
+ )
+ clf.fit(X, y)
+
+ assert clf.best_index_ == 0
+ # Ensure `best_score_` is disabled when using `refit=callable`
+ assert not hasattr(clf, "best_score_")
+
+
+def test_gridsearch_nd():
+    # Pass n-dimensional X and y arrays to GridSearchCV
+ X_4d = np.arange(10 * 5 * 3 * 2).reshape(10, 5, 3, 2)
+ y_3d = np.arange(10 * 7 * 11).reshape(10, 7, 11)
+
+ def check_X(x):
+ return x.shape[1:] == (5, 3, 2)
+
+ def check_y(x):
+ return x.shape[1:] == (7, 11)
+
+ clf = CheckingClassifier(
+ check_X=check_X,
+ check_y=check_y,
+ methods_to_check=["fit"],
+ )
+ grid_search = GridSearchCV(clf, {"foo_param": [1, 2, 3]})
+ grid_search.fit(X_4d, y_3d).score(X, y)
+ assert hasattr(grid_search, "cv_results_")
+
+
+def test_X_as_list():
+ # Pass X as list in GridSearchCV
+ X = np.arange(100).reshape(10, 10)
+ y = np.array([0] * 5 + [1] * 5)
+
+ clf = CheckingClassifier(
+ check_X=lambda x: isinstance(x, list),
+ methods_to_check=["fit"],
+ )
+ cv = KFold(n_splits=3)
+ grid_search = GridSearchCV(clf, {"foo_param": [1, 2, 3]}, cv=cv)
+ grid_search.fit(X.tolist(), y).score(X, y)
+ assert hasattr(grid_search, "cv_results_")
+
+
+def test_y_as_list():
+ # Pass y as list in GridSearchCV
+ X = np.arange(100).reshape(10, 10)
+ y = np.array([0] * 5 + [1] * 5)
+
+ clf = CheckingClassifier(
+ check_y=lambda x: isinstance(x, list),
+ methods_to_check=["fit"],
+ )
+ cv = KFold(n_splits=3)
+ grid_search = GridSearchCV(clf, {"foo_param": [1, 2, 3]}, cv=cv)
+ grid_search.fit(X, y.tolist()).score(X, y)
+ assert hasattr(grid_search, "cv_results_")
+
+
+@ignore_warnings
+def test_pandas_input():
+    # check that grid search does not destroy the (modin) pandas dataframe
+ types = [(MockDataFrame, MockDataFrame)]
+ try:
+ from modin.pandas import Series, DataFrame
+
+ types.append((DataFrame, Series))
+ except ImportError:
+ pass
+
+ X = np.arange(100).reshape(10, 10)
+ y = np.array([0] * 5 + [1] * 5)
+
+ for InputFeatureType, TargetType in types:
+ # X dataframe, y series
+ X_df, y_ser = InputFeatureType(X), TargetType(y)
+
+ def check_df(x):
+ return isinstance(x, InputFeatureType)
+
+ def check_series(x):
+ return isinstance(x, TargetType)
+
+ clf = CheckingClassifier(check_X=check_df, check_y=check_series)
+
+ grid_search = GridSearchCV(clf, {"foo_param": [1, 2, 3]})
+ grid_search.fit(X_df, y_ser).score(X_df, y_ser)
+ grid_search.predict(X_df)
+ assert hasattr(grid_search, "cv_results_")
+
+
+def test_unsupervised_grid_search():
+ # test grid-search with unsupervised estimator
+ X, y = make_blobs(n_samples=50, random_state=0)
+ km = KMeans(random_state=0, init="random", n_init=1)
+
+ # Multi-metric evaluation unsupervised
+ scoring = ["adjusted_rand_score", "fowlkes_mallows_score"]
+ for refit in ["adjusted_rand_score", "fowlkes_mallows_score"]:
+ grid_search = GridSearchCV(
+ km, param_grid=dict(n_clusters=[2, 3, 4]), scoring=scoring, refit=refit
+ )
+ grid_search.fit(X, y)
+ # Both ARI and FMS can find the right number :)
+ assert grid_search.best_params_["n_clusters"] == 3
+
+ # Single metric evaluation unsupervised
+ grid_search = GridSearchCV(
+ km, param_grid=dict(n_clusters=[2, 3, 4]), scoring="fowlkes_mallows_score"
+ )
+ grid_search.fit(X, y)
+ assert grid_search.best_params_["n_clusters"] == 3
+
+ # Now without a score, and without y
+ grid_search = GridSearchCV(km, param_grid=dict(n_clusters=[2, 3, 4]))
+ grid_search.fit(X)
+ assert grid_search.best_params_["n_clusters"] == 4
+
+
+def test_gridsearch_no_predict():
+ # test grid-search with an estimator without predict.
+ # slight duplication of a test from KDE
+ def custom_scoring(estimator, X):
+ return 42 if estimator.bandwidth == 0.1 else 0
+
+ X, _ = make_blobs(cluster_std=0.1, random_state=1, centers=[[0, 1], [1, 0], [0, 0]])
+ search = GridSearchCV(
+ KernelDensity(),
+ param_grid=dict(bandwidth=[0.01, 0.1, 1]),
+ scoring=custom_scoring,
+ )
+ search.fit(X)
+ assert search.best_params_["bandwidth"] == 0.1
+ assert search.best_score_ == 42
+
+
+def test_param_sampler():
+ # test basic properties of param sampler
+ param_distributions = {"kernel": ["rbf", "linear"], "C": uniform(0, 1)}
+ sampler = ParameterSampler(
+ param_distributions=param_distributions, n_iter=10, random_state=0
+ )
+ samples = [x for x in sampler]
+ assert len(samples) == 10
+ for sample in samples:
+ assert sample["kernel"] in ["rbf", "linear"]
+ assert 0 <= sample["C"] <= 1
+
+ # test that repeated calls yield identical parameters
+ param_distributions = {"C": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]}
+ sampler = ParameterSampler(
+ param_distributions=param_distributions, n_iter=3, random_state=0
+ )
+ assert [x for x in sampler] == [x for x in sampler]
+
+ param_distributions = {"C": uniform(0, 1)}
+ sampler = ParameterSampler(
+ param_distributions=param_distributions, n_iter=10, random_state=0
+ )
+ assert [x for x in sampler] == [x for x in sampler]
+
+
+def check_cv_results_array_types(search, param_keys, score_keys):
+    # Check that the arrays in the search's cv_results_ have the correct types
+ cv_results = search.cv_results_
+ assert all(isinstance(cv_results[param], np.ma.MaskedArray) for param in param_keys)
+ assert all(cv_results[key].dtype == object for key in param_keys)
+ assert not any(isinstance(cv_results[key], np.ma.MaskedArray) for key in score_keys)
+ assert all(
+ cv_results[key].dtype == np.float64
+ for key in score_keys
+ if not key.startswith("rank")
+ )
+
+ scorer_keys = search.scorer_.keys() if search.multimetric_ else ["score"]
+
+ for key in scorer_keys:
+ assert cv_results["rank_test_%s" % key].dtype == np.int32
+
+
+def check_cv_results_keys(cv_results, param_keys, score_keys, n_cand):
+    # Test that search.cv_results_ contains all the required result keys
+ assert_array_equal(
+ sorted(cv_results.keys()), sorted(param_keys + score_keys + ("params",))
+ )
+ assert all(cv_results[key].shape == (n_cand,) for key in param_keys + score_keys)
+
+
+def test_grid_search_cv_results():
+ X, y = make_classification(n_samples=50, n_features=4, random_state=42)
+
+ n_splits = 3
+ n_grid_points = 6
+ params = [
+ dict(
+ kernel=[
+ "rbf",
+ ],
+ C=[1, 10],
+ gamma=[0.1, 1],
+ ),
+ dict(
+ kernel=[
+ "poly",
+ ],
+ degree=[1, 2],
+ ),
+ ]
+
+ param_keys = ("param_C", "param_degree", "param_gamma", "param_kernel")
+ score_keys = (
+ "mean_test_score",
+ "mean_train_score",
+ "rank_test_score",
+ "split0_test_score",
+ "split1_test_score",
+ "split2_test_score",
+ "split0_train_score",
+ "split1_train_score",
+ "split2_train_score",
+ "std_test_score",
+ "std_train_score",
+ "mean_fit_time",
+ "std_fit_time",
+ "mean_score_time",
+ "std_score_time",
+ )
+ n_candidates = n_grid_points
+
+ search = GridSearchCV(
+ SVC(), cv=n_splits, param_grid=params, return_train_score=True
+ )
+ search.fit(X, y)
+ cv_results = search.cv_results_
+ # Check if score and timing are reasonable
+ assert all(cv_results["rank_test_score"] >= 1)
+    assert all(
+        np.all(cv_results[k] >= 0) for k in score_keys if k != "rank_test_score"
+    )
+    assert all(
+        np.all(cv_results[k] <= 1)
+        for k in score_keys
+        if "time" not in k and k != "rank_test_score"
+    )
+ # Check cv_results structure
+ check_cv_results_array_types(search, param_keys, score_keys)
+ check_cv_results_keys(cv_results, param_keys, score_keys, n_candidates)
+ # Check masking
+ cv_results = search.cv_results_
+ n_candidates = len(search.cv_results_["params"])
+ assert all(
+ (
+ cv_results["param_C"].mask[i]
+ and cv_results["param_gamma"].mask[i]
+ and not cv_results["param_degree"].mask[i]
+ )
+ for i in range(n_candidates)
+ if cv_results["param_kernel"][i] == "linear"
+ )
+ assert all(
+ (
+ not cv_results["param_C"].mask[i]
+ and not cv_results["param_gamma"].mask[i]
+ and cv_results["param_degree"].mask[i]
+ )
+ for i in range(n_candidates)
+ if cv_results["param_kernel"][i] == "rbf"
+ )
+
+
+def test_random_search_cv_results():
+ X, y = make_classification(n_samples=50, n_features=4, random_state=42)
+
+ n_splits = 3
+ n_search_iter = 30
+
+ params = [
+ {"kernel": ["rbf"], "C": expon(scale=10), "gamma": expon(scale=0.1)},
+ {"kernel": ["poly"], "degree": [2, 3]},
+ ]
+ param_keys = ("param_C", "param_degree", "param_gamma", "param_kernel")
+ score_keys = (
+ "mean_test_score",
+ "mean_train_score",
+ "rank_test_score",
+ "split0_test_score",
+ "split1_test_score",
+ "split2_test_score",
+ "split0_train_score",
+ "split1_train_score",
+ "split2_train_score",
+ "std_test_score",
+ "std_train_score",
+ "mean_fit_time",
+ "std_fit_time",
+ "mean_score_time",
+ "std_score_time",
+ )
+ n_cand = n_search_iter
+
+ search = RandomizedSearchCV(
+ SVC(),
+ n_iter=n_search_iter,
+ cv=n_splits,
+ param_distributions=params,
+ return_train_score=True,
+ )
+ search.fit(X, y)
+ cv_results = search.cv_results_
+ # Check results structure
+ check_cv_results_array_types(search, param_keys, score_keys)
+ check_cv_results_keys(cv_results, param_keys, score_keys, n_cand)
+ n_candidates = len(search.cv_results_["params"])
+ assert all(
+ (
+ cv_results["param_C"].mask[i]
+ and cv_results["param_gamma"].mask[i]
+ and not cv_results["param_degree"].mask[i]
+ )
+ for i in range(n_candidates)
+ if cv_results["param_kernel"][i] == "linear"
+ )
+ assert all(
+ (
+ not cv_results["param_C"].mask[i]
+ and not cv_results["param_gamma"].mask[i]
+ and cv_results["param_degree"].mask[i]
+ )
+ for i in range(n_candidates)
+ if cv_results["param_kernel"][i] == "rbf"
+ )
+
+
+@pytest.mark.parametrize(
+ "SearchCV, specialized_params",
+ [
+ (GridSearchCV, {"param_grid": {"C": [1, 10]}}),
+ (RandomizedSearchCV, {"param_distributions": {"C": [1, 10]}, "n_iter": 2}),
+ ],
+)
+def test_search_default_iid(SearchCV, specialized_params):
+    # Check that test scores are aggregated with an unweighted mean/std across
+    # non-iid CV folds (behaviour formerly controlled by the `iid` parameter)
+ # noise-free simple 2d-data
+ X, y = make_blobs(
+ centers=[[0, 0], [1, 0], [0, 1], [1, 1]],
+ random_state=0,
+ cluster_std=0.1,
+ shuffle=False,
+ n_samples=80,
+ )
+ # split dataset into two folds that are not iid
+ # first one contains data of all 4 blobs, second only from two.
+ mask = np.ones(X.shape[0], dtype=bool)
+ mask[np.where(y == 1)[0][::2]] = 0
+ mask[np.where(y == 2)[0][::2]] = 0
+ # this leads to perfect classification on one fold and a score of 1/3 on
+ # the other
+ # create "cv" for splits
+ cv = [[mask, ~mask], [~mask, mask]]
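+    # `cv` can be given as an explicit list of (train, test) splits; here each
+    # split is a pair of boolean masks, and the second split swaps the folds.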
+
+ common_params = {"estimator": SVC(), "cv": cv, "return_train_score": True}
+ search = SearchCV(**common_params, **specialized_params)
+ search.fit(X, y)
+
+ test_cv_scores = np.array(
+ [
+ search.cv_results_["split%d_test_score" % s][0]
+ for s in range(search.n_splits_)
+ ]
+ )
+ test_mean = search.cv_results_["mean_test_score"][0]
+ test_std = search.cv_results_["std_test_score"][0]
+
+ train_cv_scores = np.array(
+ [
+ search.cv_results_["split%d_train_score" % s][0]
+ for s in range(search.n_splits_)
+ ]
+ )
+ train_mean = search.cv_results_["mean_train_score"][0]
+ train_std = search.cv_results_["std_train_score"][0]
+
+ assert search.cv_results_["param_C"][0] == 1
+ # scores are the same as above
+ assert_allclose(test_cv_scores, [1, 1.0 / 3.0])
+ assert_allclose(train_cv_scores, [1, 1])
+ # Unweighted mean/std is used
+ assert test_mean == pytest.approx(np.mean(test_cv_scores))
+ assert test_std == pytest.approx(np.std(test_cv_scores))
+
+ # For the train scores, we do not take a weighted mean irrespective of
+ # i.i.d. or not
+ assert train_mean == pytest.approx(1)
+ assert train_std == pytest.approx(0)
+
+
+def test_grid_search_cv_results_multimetric():
+ X, y = make_classification(n_samples=50, n_features=4, random_state=42)
+
+ n_splits = 3
+ params = [
+ dict(
+ kernel=[
+ "rbf",
+ ],
+ C=[1, 10],
+ gamma=[0.1, 1],
+ ),
+ dict(
+ kernel=[
+ "poly",
+ ],
+ degree=[1, 2],
+ ),
+ ]
+
+ grid_searches = []
+ for scoring in (
+ {"accuracy": make_scorer(accuracy_score), "recall": make_scorer(recall_score)},
+ "accuracy",
+ "recall",
+ ):
+ grid_search = GridSearchCV(
+ SVC(), cv=n_splits, param_grid=params, scoring=scoring, refit=False
+ )
+ grid_search.fit(X, y)
+ grid_searches.append(grid_search)
+
+ compare_cv_results_multimetric_with_single(*grid_searches)
+
+
+def test_random_search_cv_results_multimetric():
+ X, y = make_classification(n_samples=50, n_features=4, random_state=42)
+
+ n_splits = 3
+ n_search_iter = 30
+
+ # Scipy 0.12's stats dists do not accept seed, hence we use param grid
+ params = dict(C=np.logspace(-4, 1, 3), gamma=np.logspace(-5, 0, 3, base=0.1))
+ for refit in (True, False):
+ random_searches = []
+ for scoring in (("accuracy", "recall"), "accuracy", "recall"):
+ # If True, for multi-metric pass refit='accuracy'
+ if refit:
+ probability = True
+ refit = "accuracy" if isinstance(scoring, tuple) else refit
+ else:
+ probability = False
+ clf = SVC(probability=probability, random_state=42)
+ random_search = RandomizedSearchCV(
+ clf,
+ n_iter=n_search_iter,
+ cv=n_splits,
+ param_distributions=params,
+ scoring=scoring,
+ refit=refit,
+ random_state=0,
+ )
+ random_search.fit(X, y)
+ random_searches.append(random_search)
+
+ compare_cv_results_multimetric_with_single(*random_searches)
+ compare_refit_methods_when_refit_with_acc(
+ random_searches[0], random_searches[1], refit
+ )
+
+
+def compare_cv_results_multimetric_with_single(search_multi, search_acc, search_rec):
+ """Compare multi-metric cv_results with the ensemble of multiple
+ single metric cv_results from single metric grid/random search"""
+
+ assert search_multi.multimetric_
+ assert_array_equal(sorted(search_multi.scorer_), ("accuracy", "recall"))
+
+ cv_results_multi = search_multi.cv_results_
+ cv_results_acc_rec = {
+ re.sub("_score$", "_accuracy", k): v for k, v in search_acc.cv_results_.items()
+ }
+ cv_results_acc_rec.update(
+ {re.sub("_score$", "_recall", k): v for k, v in search_rec.cv_results_.items()}
+ )
+
+ # Check if score and timing are reasonable, also checks if the keys
+ # are present
+ assert all(
+ (
+ np.all(cv_results_multi[k] <= 1)
+ for k in (
+ "mean_score_time",
+ "std_score_time",
+ "mean_fit_time",
+ "std_fit_time",
+ )
+ )
+ )
+
+ # Compare the keys, other than time keys, among multi-metric and
+ # single metric grid search results. np.testing.assert_equal performs a
+ # deep nested comparison of the two cv_results dicts
+ np.testing.assert_equal(
+ {k: v for k, v in cv_results_multi.items() if not k.endswith("_time")},
+ {k: v for k, v in cv_results_acc_rec.items() if not k.endswith("_time")},
+ )
+
+
+def compare_refit_methods_when_refit_with_acc(search_multi, search_acc, refit):
+ """Compare refit multi-metric search methods with single metric methods"""
+ assert search_acc.refit == refit
+ if refit:
+ assert search_multi.refit == "accuracy"
+ else:
+ assert not search_multi.refit
+ return # search cannot predict/score without refit
+
+ X, y = make_blobs(n_samples=100, n_features=4, random_state=42)
+ for method in ("predict", "predict_proba", "predict_log_proba"):
+ assert_almost_equal(
+ getattr(search_multi, method)(X), getattr(search_acc, method)(X)
+ )
+ assert_almost_equal(search_multi.score(X, y), search_acc.score(X, y))
+ for key in ("best_index_", "best_score_", "best_params_"):
+ assert getattr(search_multi, key) == getattr(search_acc, key)
+
+
+@pytest.mark.parametrize(
+ "search_cv",
+ [
+ RandomizedSearchCV(
+ estimator=DecisionTreeClassifier(),
+ param_distributions={"max_depth": [5, 10]},
+ ),
+ GridSearchCV(
+ estimator=DecisionTreeClassifier(), param_grid={"max_depth": [5, 10]}
+ ),
+ ],
+)
+def test_search_cv_score_samples_error(search_cv):
+ X, y = make_blobs(n_samples=100, n_features=4, random_state=42)
+ search_cv.fit(X, y)
+
+ # Make sure to error out when underlying estimator does not implement
+ # the method `score_samples`
+ err_msg = "'DecisionTreeClassifier' object has no attribute 'score_samples'"
+
+ with pytest.raises(AttributeError, match=err_msg):
+ search_cv.score_samples(X)
+
+
+@pytest.mark.parametrize(
+ "search_cv",
+ [
+ RandomizedSearchCV(
+ estimator=LocalOutlierFactor(novelty=True),
+ param_distributions={"n_neighbors": [5, 10]},
+ scoring="precision",
+ ),
+ GridSearchCV(
+ estimator=LocalOutlierFactor(novelty=True),
+ param_grid={"n_neighbors": [5, 10]},
+ scoring="precision",
+ ),
+ ],
+)
+def test_search_cv_score_samples_method(search_cv):
+ # Set parameters
+ rng = np.random.RandomState(42)
+ n_samples = 300
+ outliers_fraction = 0.15
+ n_outliers = int(outliers_fraction * n_samples)
+ n_inliers = n_samples - n_outliers
+
+ # Create dataset
+ X = make_blobs(
+ n_samples=n_inliers,
+ n_features=2,
+ centers=[[0, 0], [0, 0]],
+ cluster_std=0.5,
+ random_state=0,
+ )[0]
+ # Add some noisy points
+ X = np.concatenate([X, rng.uniform(low=-6, high=6, size=(n_outliers, 2))], axis=0)
+
+ # Define labels to be able to score the estimator with `search_cv`
+ y_true = np.array([1] * n_samples)
+ y_true[-n_outliers:] = -1
+
+ # Fit on data
+ search_cv.fit(X, y_true)
+
+ # Verify that the stand alone estimator yields the same results
+ # as the ones obtained with *SearchCV
+ assert_allclose(
+ search_cv.score_samples(X), search_cv.best_estimator_.score_samples(X)
+ )
+
+
+def test_search_cv_results_rank_tie_breaking():
+ X, y = make_blobs(n_samples=50, random_state=42)
+
+ # The two C values are close enough to give similar models
+ # which would result in a tie of their mean cv-scores
+ param_grid = {"C": [1, 1.001, 0.001]}
+
+ grid_search = GridSearchCV(SVC(), param_grid=param_grid, return_train_score=True)
+ random_search = RandomizedSearchCV(
+ SVC(), n_iter=3, param_distributions=param_grid, return_train_score=True
+ )
+
+ for search in (grid_search, random_search):
+ search.fit(X, y)
+ cv_results = search.cv_results_
+ # Check tie breaking strategy -
+ # Check that there is a tie in the mean scores between
+ # candidates 1 and 2 alone
+ assert_almost_equal(
+ cv_results["mean_test_score"][0], cv_results["mean_test_score"][1]
+ )
+ assert_almost_equal(
+ cv_results["mean_train_score"][0], cv_results["mean_train_score"][1]
+ )
+ assert not np.allclose(
+ cv_results["mean_test_score"][1], cv_results["mean_test_score"][2]
+ )
+ assert not np.allclose(
+ cv_results["mean_train_score"][1], cv_results["mean_train_score"][2]
+ )
+ # 'min' rank should be assigned to the tied candidates
+ assert_almost_equal(search.cv_results_["rank_test_score"], [1, 1, 3])
+
+
+def test_search_cv_results_none_param():
+ X, y = [[1], [2], [3], [4], [5]], [0, 0, 0, 0, 1]
+ estimators = (DecisionTreeRegressor(), DecisionTreeClassifier())
+ est_parameters = {"random_state": [0, None]}
+ cv = KFold()
+
+ for est in estimators:
+ grid_search = GridSearchCV(
+ est,
+ est_parameters,
+ cv=cv,
+ ).fit(X, y)
+ assert_array_equal(grid_search.cv_results_["param_random_state"], [0, None])
+
+
+@ignore_warnings()
+def test_search_cv_timing():
+ svc = LinearSVC(random_state=0)
+
+    X = [[1], [2], [3], [4]]
+ y = [0, 1, 1, 0]
+
+ gs = GridSearchCV(svc, {"C": [0, 1]}, cv=2, error_score=0)
+ rs = RandomizedSearchCV(svc, {"C": [0, 1]}, cv=2, error_score=0, n_iter=2)
+
+ for search in (gs, rs):
+ search.fit(X, y)
+ for key in ["mean_fit_time", "std_fit_time"]:
+            # NOTE: The precision of time.time on Windows is not high
+ # enough for the fit/score times to be non-zero for trivial X and y
+ assert np.all(search.cv_results_[key] >= 0)
+ assert np.all(search.cv_results_[key] < 1)
+
+ for key in ["mean_score_time", "std_score_time"]:
+ assert search.cv_results_[key][1] >= 0
+ assert search.cv_results_[key][0] == 0.0
+ assert np.all(search.cv_results_[key] < 1)
+
+ assert hasattr(search, "refit_time_")
+ assert isinstance(search.refit_time_, float)
+ assert search.refit_time_ >= 0
+
+
+def test_grid_search_correct_score_results():
+ # test that correct scores are used
+ n_splits = 3
+ clf = LinearSVC(random_state=0)
+ X, y = make_blobs(random_state=0, centers=2)
+ Cs = [0.1, 1, 10]
+ for score in ["f1", "roc_auc"]:
+ grid_search = GridSearchCV(clf, {"C": Cs}, scoring=score, cv=n_splits)
+ cv_results = grid_search.fit(X, y).cv_results_
+
+ # Test scorer names
+ result_keys = list(cv_results.keys())
+ expected_keys = ("mean_test_score", "rank_test_score") + tuple(
+ "split%d_test_score" % cv_i for cv_i in range(n_splits)
+ )
+ assert all(np.in1d(expected_keys, result_keys))
+
+ cv = StratifiedKFold(n_splits=n_splits)
+ n_splits = grid_search.n_splits_
+ for candidate_i, C in enumerate(Cs):
+ clf.set_params(C=C)
+ cv_scores = np.array(
+ [
+ grid_search.cv_results_["split%d_test_score" % s][candidate_i]
+ for s in range(n_splits)
+ ]
+ )
+ for i, (train, test) in enumerate(cv.split(X, y)):
+ clf.fit(X[train], y[train])
+ if score == "f1":
+ correct_score = f1_score(y[test], clf.predict(X[test]))
+ elif score == "roc_auc":
+ dec = clf.decision_function(X[test])
+ correct_score = roc_auc_score(y[test], dec)
+ assert_almost_equal(correct_score, cv_scores[i])
+
+
+def test_pickle():
+ # Test that a fit search can be pickled
+ clf = MockClassifier()
+ grid_search = GridSearchCV(clf, {"foo_param": [1, 2, 3]}, refit=True, cv=3)
+ grid_search.fit(X, y)
+ grid_search_pickled = pickle.loads(pickle.dumps(grid_search))
+ assert_array_almost_equal(grid_search.predict(X), grid_search_pickled.predict(X))
+
+ random_search = RandomizedSearchCV(
+ clf, {"foo_param": [1, 2, 3]}, refit=True, n_iter=3, cv=3
+ )
+ random_search.fit(X, y)
+ random_search_pickled = pickle.loads(pickle.dumps(random_search))
+ assert_array_almost_equal(
+ random_search.predict(X), random_search_pickled.predict(X)
+ )
+
+
+def test_grid_search_with_multioutput_data():
+ # Test search with multi-output estimator
+
+ X, y = make_multilabel_classification(return_indicator=True, random_state=0)
+
+ est_parameters = {"max_depth": [1, 2, 3, 4]}
+ cv = KFold()
+
+ estimators = [
+ DecisionTreeRegressor(random_state=0),
+ DecisionTreeClassifier(random_state=0),
+ ]
+
+ # Test with grid search cv
+ for est in estimators:
+ grid_search = GridSearchCV(est, est_parameters, cv=cv)
+ grid_search.fit(X, y)
+ res_params = grid_search.cv_results_["params"]
+ for cand_i in range(len(res_params)):
+ est.set_params(**res_params[cand_i])
+
+ for i, (train, test) in enumerate(cv.split(X, y)):
+ est.fit(X[train], y[train])
+ correct_score = est.score(X[test], y[test])
+ assert_almost_equal(
+ correct_score,
+ grid_search.cv_results_["split%d_test_score" % i][cand_i],
+ )
+
+ # Test with a randomized search
+ for est in estimators:
+ random_search = RandomizedSearchCV(est, est_parameters, cv=cv, n_iter=3)
+ random_search.fit(X, y)
+ res_params = random_search.cv_results_["params"]
+ for cand_i in range(len(res_params)):
+ est.set_params(**res_params[cand_i])
+
+ for i, (train, test) in enumerate(cv.split(X, y)):
+ est.fit(X[train], y[train])
+ correct_score = est.score(X[test], y[test])
+ assert_almost_equal(
+ correct_score,
+ random_search.cv_results_["split%d_test_score" % i][cand_i],
+ )
+
+
+def test_predict_proba_disabled():
+ # Test predict_proba when disabled on estimator.
+ X = np.arange(20).reshape(5, -1)
+ y = [0, 0, 1, 1, 1]
+ clf = SVC(probability=False)
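+ # With probability=False, SVC does not expose predict_proba, so the fitted
+ # search (which delegates to the best estimator) should not expose it either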
+ gs = GridSearchCV(clf, {}, cv=2).fit(X, y)
+ assert not hasattr(gs, "predict_proba")
+
+
+def test_grid_search_allows_nans():
+ # Test GridSearchCV with SimpleImputer
+ X = np.arange(20, dtype=np.float64).reshape(5, -1)
+ X[2, :] = np.nan
+ y = [0, 0, 1, 1, 1]
+ p = Pipeline(
+ [
+ ("imputer", SimpleImputer(strategy="mean", missing_values=np.nan)),
+ ("classifier", MockClassifier()),
+ ]
+ )
+ GridSearchCV(p, {"classifier__foo_param": [1, 2, 3]}, cv=2).fit(X, y)
+
+
+class FailingClassifier(BaseEstimator):
+ """Classifier that raises a ValueError on fit()"""
+
+ FAILING_PARAMETER = 2
+
+ def __init__(self, parameter=None):
+ self.parameter = parameter
+
+ def fit(self, X, y=None):
+ if self.parameter == FailingClassifier.FAILING_PARAMETER:
+ raise ValueError("Failing classifier failed as required")
+
+ def predict(self, X):
+ return np.zeros(X.shape[0])
+
+ def score(self, X=None, Y=None):
+ return 0.0
+
+
+def test_grid_search_failing_classifier():
+ # GridSearchCV with error_score != 'raise'
+ # Ensures that a warning is raised and score reset where appropriate.
+
+ X, y = make_classification(n_samples=20, n_features=10, random_state=0)
+
+ clf = FailingClassifier()
+
+ # refit=False because we only want to check that errors caused by fits
+ # to individual folds will be caught and warnings raised instead. If
+ # refit was done, then an exception would be raised on refit and not
+ # caught by grid_search (expected behavior), and this would cause an
+ # error in this test.
+ gs = GridSearchCV(
+ clf,
+ [{"parameter": [0, 1, 2]}],
+ scoring="accuracy",
+ refit=False,
+ error_score=0.0,
+ )
+
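+ # 3 candidates x 5 folds (the default cv) = 15 fits in total; only the 5
+ # fits of the failing candidate (parameter=2) are expected to raise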
+ warning_message = re.compile(
+ "5 fits failed.+total of 15.+The score on these"
+ r" train-test partitions for these parameters will be set to 0\.0.+"
+ "5 fits failed with the following error.+ValueError.+Failing classifier failed"
+ " as required",
+ flags=re.DOTALL,
+ )
+ with pytest.warns(FitFailedWarning, match=warning_message):
+ gs.fit(X, y)
+ n_candidates = len(gs.cv_results_["params"])
+
+ # Ensure that grid scores were set to zero as required for those fits
+ # that are expected to fail.
+ def get_cand_scores(i):
+ return np.array(
+ [gs.cv_results_["split%d_test_score" % s][i] for s in range(gs.n_splits_)]
+ )
+
+ assert all(
+ (
+ np.all(get_cand_scores(cand_i) == 0.0)
+ for cand_i in range(n_candidates)
+ if gs.cv_results_["param_parameter"][cand_i]
+ == FailingClassifier.FAILING_PARAMETER
+ )
+ )
+
+ gs = GridSearchCV(
+ clf,
+ [{"parameter": [0, 1, 2]}],
+ scoring="accuracy",
+ refit=False,
+ error_score=float("nan"),
+ )
+ warning_message = re.compile(
+ "5 fits failed.+total of 15.+The score on these"
+ r" train-test partitions for these parameters will be set to nan.+"
+ "5 fits failed with the following error.+ValueError.+Failing classifier failed"
+ " as required",
+ flags=re.DOTALL,
+ )
+ with pytest.warns(FitFailedWarning, match=warning_message):
+ gs.fit(X, y)
+ n_candidates = len(gs.cv_results_["params"])
+ assert all(
+ np.all(np.isnan(get_cand_scores(cand_i)))
+ for cand_i in range(n_candidates)
+ if gs.cv_results_["param_parameter"][cand_i]
+ == FailingClassifier.FAILING_PARAMETER
+ )
+
+ ranks = gs.cv_results_["rank_test_score"]
+
+ # Check that succeeded estimators have lower ranks
+ assert ranks[0] <= 2 and ranks[1] <= 2
+ # Check that failed estimator has the highest rank
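+ # (the failing candidate sits at index 2 of the [0, 1, 2] grid, which
+ # coincides with FAILING_PARAMETER)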
+ assert ranks[clf.FAILING_PARAMETER] == 3
+ assert gs.best_index_ != clf.FAILING_PARAMETER
+
+
+def test_grid_search_classifier_all_fits_fail():
+ X, y = make_classification(n_samples=20, n_features=10, random_state=0)
+
+ clf = FailingClassifier()
+
+ gs = GridSearchCV(
+ clf,
+ [{"parameter": [FailingClassifier.FAILING_PARAMETER] * 3}],
+ error_score=0.0,
+ )
+
+ warning_message = re.compile(
+ "All the 15 fits failed.+"
+ "15 fits failed with the following error.+ValueError.+Failing classifier failed"
+ " as required",
+ flags=re.DOTALL,
+ )
+ with pytest.raises(ValueError, match=warning_message):
+ gs.fit(X, y)
+
+
+def test_grid_search_failing_classifier_raise():
+ # GridSearchCV with error_score == 'raise' raises the error
+
+ X, y = make_classification(n_samples=20, n_features=10, random_state=0)
+
+ clf = FailingClassifier()
+
+ # refit=False because we want to test the behaviour of the grid search part
+ gs = GridSearchCV(
+ clf,
+ [{"parameter": [0, 1, 2]}],
+ scoring="accuracy",
+ refit=False,
+ error_score="raise",
+ )
+
+ # FailingClassifier issues a ValueError so this is what we look for.
+ with pytest.raises(ValueError):
+ gs.fit(X, y)
+
+
+def test_parameters_sampler_replacement():
+ # raise a warning if n_iter is bigger than the total parameter space
+ params = [
+ {"first": [0, 1], "second": ["a", "b", "c"]},
+ {"third": ["two", "values"]},
+ ]
+ sampler = ParameterSampler(params, n_iter=9)
+ n_iter = 9
+ grid_size = 8
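+ # 2 * 3 combinations from the first sub-grid plus 2 from the second = 8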
+ expected_warning = (
+ "The total space of parameters %d is smaller "
+ "than n_iter=%d. Running %d iterations. For "
+ "exhaustive searches, use GridSearchCV." % (grid_size, n_iter, grid_size)
+ )
+ with pytest.warns(UserWarning, match=expected_warning):
+ list(sampler)
+
+ # degenerates to GridSearchCV if n_iter is the same as grid_size
+ sampler = ParameterSampler(params, n_iter=8)
+ samples = list(sampler)
+ assert len(samples) == 8
+ for values in ParameterGrid(params):
+ assert values in samples
+ assert len(ParameterSampler(params, n_iter=1000)) == 8
+
+ # test sampling without replacement in a large grid
+ params = {"a": range(10), "b": range(10), "c": range(10)}
+ sampler = ParameterSampler(params, n_iter=99, random_state=42)
+ samples = list(sampler)
+ assert len(samples) == 99
+ hashable_samples = ["a%db%dc%d" % (p["a"], p["b"], p["c"]) for p in samples]
+ assert len(set(hashable_samples)) == 99
+
+ # doesn't go into infinite loops
+ params_distribution = {"first": bernoulli(0.5), "second": ["a", "b", "c"]}
+ sampler = ParameterSampler(params_distribution, n_iter=7)
+ samples = list(sampler)
+ assert len(samples) == 7
+
+
+def test_stochastic_gradient_loss_param():
+ # Make sure predict_proba works when the loss is specified
+ # as one of the parameters in the param_grid.
+ param_grid = {
+ "loss": ["log_loss"],
+ }
+ X = np.arange(24).reshape(6, -1)
+ y = [0, 0, 0, 1, 1, 1]
+ clf = GridSearchCV(
+ estimator=SGDClassifier(loss="hinge"), param_grid=param_grid, cv=3
+ )
+
+ # When the estimator is not fitted, `predict_proba` is not available as the
+ # loss is 'hinge'.
+ assert not hasattr(clf, "predict_proba")
+ clf.fit(X, y)
+ clf.predict_proba(X)
+ clf.predict_log_proba(X)
+
+ # Make sure `predict_proba` is not available when setting loss=['hinge']
+ # in param_grid
+ param_grid = {
+ "loss": ["hinge"],
+ }
+ clf = GridSearchCV(
+ estimator=SGDClassifier(loss="hinge"), param_grid=param_grid, cv=3
+ )
+ assert not hasattr(clf, "predict_proba")
+ clf.fit(X, y)
+ assert not hasattr(clf, "predict_proba")
+
+
+def test_search_train_scores_set_to_false():
+ X = np.arange(6).reshape(6, -1)
+ y = [0, 0, 0, 1, 1, 1]
+ clf = LinearSVC(random_state=0)
+
+ gs = GridSearchCV(clf, param_grid={"C": [0.1, 0.2]}, cv=3)
+ gs.fit(X, y)
+
+
+def test_grid_search_cv_splits_consistency():
+ # Check that a one-time iterable is accepted as a cv parameter.
+ n_samples = 100
+ n_splits = 5
+ X, y = make_classification(n_samples=n_samples, random_state=0)
+
+ gs = GridSearchCV(
+ LinearSVC(random_state=0),
+ param_grid={"C": [0.1, 0.2, 0.3]},
+ cv=OneTimeSplitter(n_splits=n_splits, n_samples=n_samples),
+ return_train_score=True,
+ )
+ gs.fit(X, y)
+
+ gs2 = GridSearchCV(
+ LinearSVC(random_state=0),
+ param_grid={"C": [0.1, 0.2, 0.3]},
+ cv=KFold(n_splits=n_splits),
+ return_train_score=True,
+ )
+ gs2.fit(X, y)
+
+ # Give generator as a cv parameter
+ assert isinstance(
+ KFold(n_splits=n_splits, shuffle=True, random_state=0).split(X, y),
+ GeneratorType,
+ )
+ gs3 = GridSearchCV(
+ LinearSVC(random_state=0),
+ param_grid={"C": [0.1, 0.2, 0.3]},
+ cv=KFold(n_splits=n_splits, shuffle=True, random_state=0).split(X, y),
+ return_train_score=True,
+ )
+ gs3.fit(X, y)
+
+ gs4 = GridSearchCV(
+ LinearSVC(random_state=0),
+ param_grid={"C": [0.1, 0.2, 0.3]},
+ cv=KFold(n_splits=n_splits, shuffle=True, random_state=0),
+ return_train_score=True,
+ )
+ gs4.fit(X, y)
+
+ def _pop_time_keys(cv_results):
+ for key in (
+ "mean_fit_time",
+ "std_fit_time",
+ "mean_score_time",
+ "std_score_time",
+ ):
+ cv_results.pop(key)
+ return cv_results
+
+ # Check if generators are supported as cv and
+ # that the splits are consistent
+ np.testing.assert_equal(
+ _pop_time_keys(gs3.cv_results_), _pop_time_keys(gs4.cv_results_)
+ )
+
+ # OneTimeSplitter is a non-re-entrant cv where split can be called only
+ # once. If ``cv.split`` were called once per param setting in
+ # GridSearchCV.fit, the 2nd and 3rd parameters would not be evaluated, as
+ # no train/test indices would be generated for the 2nd and subsequent
+ # cv.split calls. This is a check to make sure cv.split is not called
+ # once per param setting.
+ np.testing.assert_equal(
+ {k: v for k, v in gs.cv_results_.items() if not k.endswith("_time")},
+ {k: v for k, v in gs2.cv_results_.items() if not k.endswith("_time")},
+ )
+
+ # Check consistency of folds across the parameters
+ gs = GridSearchCV(
+ LinearSVC(random_state=0),
+ param_grid={"C": [0.1, 0.1, 0.2, 0.2]},
+ cv=KFold(n_splits=n_splits, shuffle=True),
+ return_train_score=True,
+ )
+ gs.fit(X, y)
+
+ # As the first two param settings (C=0.1) and the next two param
+ # settings (C=0.2) are the same, the test and train scores must also be
+ # the same, as long as the same train/test indices are generated for all
+ # the cv splits, for both param settings
+ for score_type in ("train", "test"):
+ per_param_scores = {}
+ for param_i in range(4):
+ per_param_scores[param_i] = [
+ gs.cv_results_["split%d_%s_score" % (s, score_type)][param_i]
+ for s in range(5)
+ ]
+
+ assert_array_almost_equal(per_param_scores[0], per_param_scores[1])
+ assert_array_almost_equal(per_param_scores[2], per_param_scores[3])
+
+
+def test_transform_inverse_transform_round_trip():
+ clf = MockClassifier()
+ grid_search = GridSearchCV(clf, {"foo_param": [1, 2, 3]}, cv=3, verbose=3)
+
+ grid_search.fit(X, y)
+ X_round_trip = grid_search.inverse_transform(grid_search.transform(X))
+ assert_array_equal(X, X_round_trip)
+
+
+def test_custom_run_search():
+ def check_results(results, gscv):
+ exp_results = gscv.cv_results_
+ assert sorted(results.keys()) == sorted(exp_results)
+ for k in results:
+ if not k.endswith("_time"):
+ # XXX: results['params'] is a list :|
+ results[k] = np.asanyarray(results[k])
+ if results[k].dtype.kind == "O":
+ assert_array_equal(
+ exp_results[k], results[k], err_msg="Checking " + k
+ )
+ else:
+ assert_allclose(exp_results[k], results[k], err_msg="Checking " + k)
+
+ def fit_grid(param_grid):
+ return GridSearchCV(clf, param_grid, return_train_score=True).fit(X, y)
+
+ class CustomSearchCV(BaseSearchCV):
+ def __init__(self, estimator, **kwargs):
+ super().__init__(estimator, **kwargs)
+
+ def _run_search(self, evaluate):
+ results = evaluate([{"max_depth": 1}, {"max_depth": 2}])
+ check_results(results, fit_grid({"max_depth": [1, 2]}))
+ results = evaluate([{"min_samples_split": 5}, {"min_samples_split": 10}])
+ check_results(
+ results,
+ fit_grid([{"max_depth": [1, 2]}, {"min_samples_split": [5, 10]}]),
+ )
+
+ # Using regressor to make sure each score differs
+ clf = DecisionTreeRegressor(random_state=0)
+ X, y = make_classification(n_samples=100, n_informative=4, random_state=0)
+ mycv = CustomSearchCV(clf, return_train_score=True).fit(X, y)
+ gscv = fit_grid([{"max_depth": [1, 2]}, {"min_samples_split": [5, 10]}])
+
+ results = mycv.cv_results_
+ check_results(results, gscv)
+ for attr in dir(gscv):
+ if (
+ attr[0].islower()
+ and attr[-1:] == "_"
+ and attr
+ not in {"cv_results_", "best_estimator_", "refit_time_", "classes_"}
+ ):
+ assert getattr(gscv, attr) == getattr(mycv, attr), (
+ "Attribute %s not equal" % attr
+ )
+
+
+def test__custom_fit_no_run_search():
+ class NoRunSearchSearchCV(BaseSearchCV):
+ def __init__(self, estimator, **kwargs):
+ super().__init__(estimator, **kwargs)
+
+ def fit(self, X, y=None, groups=None, **fit_params):
+ return self
+
+ # this should not raise any exceptions
+ NoRunSearchSearchCV(SVC()).fit(X, y)
+
+ class BadSearchCV(BaseSearchCV):
+ def __init__(self, estimator, **kwargs):
+ super().__init__(estimator, **kwargs)
+
+ with pytest.raises(NotImplementedError, match="_run_search not implemented."):
+ # this should raise a NotImplementedError
+ BadSearchCV(SVC()).fit(X, y)
+
+
+def test_empty_cv_iterator_error():
+ # Use global X, y
+
+ # create cv
+ cv = KFold(n_splits=3).split(X)
+
+ # exhaust the generator so that cv is empty, which should cause the
+ # expected ValueError when fitting below
+ list(cv)
+
+ train_size = 100
+ ridge = RandomizedSearchCV(Ridge(), {"alpha": [1e-3, 1e-2, 1e-1]}, cv=cv, n_jobs=4)
+
+ # assert that this raises an error
+ with pytest.raises(
+ ValueError,
+ match=(
+ "No fits were performed. "
+ "Was the CV iterator empty\\? "
+ "Were there no candidates\\?"
+ ),
+ ):
+ ridge.fit(X[:train_size], y[:train_size])
+
+
+def test_random_search_bad_cv():
+ # Use global X, y
+
+ class BrokenKFold(KFold):
+ def get_n_splits(self, *args, **kw):
+ return 1
+
+ # create bad cv
+ cv = BrokenKFold(n_splits=3)
+
+ train_size = 100
+ ridge = RandomizedSearchCV(Ridge(), {"alpha": [1e-3, 1e-2, 1e-1]}, cv=cv, n_jobs=4)
+
+ # assert that this raises an error
+ with pytest.raises(
+ ValueError,
+ match=(
+ "cv.split and cv.get_n_splits returned "
+ "inconsistent results. Expected \\d+ "
+ "splits, got \\d+"
+ ),
+ ):
+ ridge.fit(X[:train_size], y[:train_size])
+
+
+@pytest.mark.parametrize("return_train_score", [False, True])
+@pytest.mark.parametrize(
+ "SearchCV, specialized_params",
+ [
+ (GridSearchCV, {"param_grid": {"max_depth": [2, 3, 5, 8]}}),
+ (
+ RandomizedSearchCV,
+ {"param_distributions": {"max_depth": [2, 3, 5, 8]}, "n_iter": 4},
+ ),
+ ],
+)
+def test_searchcv_raise_warning_with_non_finite_score(
+ SearchCV, specialized_params, return_train_score
+):
+ # Non-regression test for:
+ # https://github.com/scikit-learn/scikit-learn/issues/10529
+ # Check that we raise a UserWarning when a non-finite score is
+ # computed in the SearchCV
+ X, y = make_classification(n_classes=2, random_state=0)
+
+ class FailingScorer:
+ """Scorer that will fail for some split but not all."""
+
+ def __init__(self):
+ self.n_counts = 0
+
+ def __call__(self, estimator, X, y):
+ self.n_counts += 1
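+ # every 5th scoring call returns NaN, so only some folds of some
+ # candidates end up with a non-finite score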
+ if self.n_counts % 5 == 0:
+ return np.nan
+ return 1
+
+ grid = SearchCV(
+ DecisionTreeClassifier(),
+ scoring=FailingScorer(),
+ cv=3,
+ return_train_score=return_train_score,
+ **specialized_params,
+ )
+
+ with pytest.warns(UserWarning) as warn_msg:
+ grid.fit(X, y)
+
+ set_with_warning = ["test", "train"] if return_train_score else ["test"]
+ assert len(warn_msg) == len(set_with_warning)
+ for msg, dataset in zip(warn_msg, set_with_warning):
+ assert f"One or more of the {dataset} scores are non-finite" in str(msg.message)
+
+ # all non-finite scores should be equally ranked last
+ last_rank = grid.cv_results_["rank_test_score"].max()
+ non_finite_mask = np.isnan(grid.cv_results_["mean_test_score"])
+ assert_array_equal(grid.cv_results_["rank_test_score"][non_finite_mask], last_rank)
+ # all finite scores should be better ranked than the non-finite scores
+ assert np.all(grid.cv_results_["rank_test_score"][~non_finite_mask] < last_rank)
+
+
+def test_callable_multimetric_confusion_matrix():
+ # Test that a callable returning multiple metrics inserts the correct
+ # names and metrics into the search cv object
+ def custom_scorer(clf, X, y):
+ y_pred = clf.predict(X)
+ cm = confusion_matrix(y, y_pred)
+ return {"tn": cm[0, 0], "fp": cm[0, 1], "fn": cm[1, 0], "tp": cm[1, 1]}
+
+ X, y = make_classification(n_samples=40, n_features=4, random_state=42)
+ est = LinearSVC(random_state=42)
+ search = GridSearchCV(est, {"C": [0.1, 1]}, scoring=custom_scorer, refit="fp")
+
+ search.fit(X, y)
+
+ score_names = ["tn", "fp", "fn", "tp"]
+ for name in score_names:
+ assert "mean_test_{}".format(name) in search.cv_results_
+
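+ # refit="fp" means the refitted search scores with the "fp" entry of the
+ # scorer dict, i.e. the false-positive count cm[0, 1]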
+ y_pred = search.predict(X)
+ cm = confusion_matrix(y, y_pred)
+ assert search.score(X, y) == pytest.approx(cm[0, 1])
+
+
+def test_callable_multimetric_same_as_list_of_strings():
+ # Test callable multimetric is the same as a list of strings
+ def custom_scorer(est, X, y):
+ y_pred = est.predict(X)
+ return {
+ "recall": recall_score(y, y_pred),
+ "accuracy": accuracy_score(y, y_pred),
+ }
+
+ X, y = make_classification(n_samples=40, n_features=4, random_state=42)
+ est = LinearSVC(random_state=42)
+ search_callable = GridSearchCV(
+ est, {"C": [0.1, 1]}, scoring=custom_scorer, refit="recall"
+ )
+ search_str = GridSearchCV(
+ est, {"C": [0.1, 1]}, scoring=["recall", "accuracy"], refit="recall"
+ )
+
+ search_callable.fit(X, y)
+ search_str.fit(X, y)
+
+ assert search_callable.best_score_ == pytest.approx(search_str.best_score_)
+ assert search_callable.best_index_ == search_str.best_index_
+ assert search_callable.score(X, y) == pytest.approx(search_str.score(X, y))
+
+
+def test_callable_single_metric_same_as_single_string():
+ # Tests callable scorer is the same as scoring with a single string
+ def custom_scorer(est, X, y):
+ y_pred = est.predict(X)
+ return recall_score(y, y_pred)
+
+ X, y = make_classification(n_samples=40, n_features=4, random_state=42)
+ est = LinearSVC(random_state=42)
+ search_callable = GridSearchCV(
+ est, {"C": [0.1, 1]}, scoring=custom_scorer, refit=True
+ )
+ search_str = GridSearchCV(est, {"C": [0.1, 1]}, scoring="recall", refit="recall")
+ search_list_str = GridSearchCV(
+ est, {"C": [0.1, 1]}, scoring=["recall"], refit="recall"
+ )
+ search_callable.fit(X, y)
+ search_str.fit(X, y)
+ search_list_str.fit(X, y)
+
+ assert search_callable.best_score_ == pytest.approx(search_str.best_score_)
+ assert search_callable.best_index_ == search_str.best_index_
+ assert search_callable.score(X, y) == pytest.approx(search_str.score(X, y))
+
+ assert search_list_str.best_score_ == pytest.approx(search_str.best_score_)
+ assert search_list_str.best_index_ == search_str.best_index_
+ assert search_list_str.score(X, y) == pytest.approx(search_str.score(X, y))
+
+
+def test_callable_multimetric_error_on_invalid_key():
+ # Raises when the callable scorer does not return a dict with `refit` key.
+ def bad_scorer(est, X, y):
+ return {"bad_name": 1}
+
+ X, y = make_classification(n_samples=40, n_features=4, random_state=42)
+ clf = GridSearchCV(
+ LinearSVC(random_state=42),
+ {"C": [0.1, 1]},
+ scoring=bad_scorer,
+ refit="good_name",
+ )
+
+ msg = (
+ "For multi-metric scoring, the parameter refit must be set to a "
+ "scorer key or a callable to refit"
+ )
+ with pytest.raises(ValueError, match=msg):
+ clf.fit(X, y)
+
+
+def test_callable_multimetric_error_failing_clf():
+ # Warns when there is an estimator that fails to fit with a float
+ # error_score
+ def custom_scorer(est, X, y):
+ return {"acc": 1}
+
+ X, y = make_classification(n_samples=20, n_features=10, random_state=0)
+
+ clf = FailingClassifier()
+ gs = GridSearchCV(
+ clf,
+ [{"parameter": [0, 1, 2]}],
+ scoring=custom_scorer,
+ refit=False,
+ error_score=0.1,
+ )
+
+ warning_message = re.compile(
+ "5 fits failed.+total of 15.+The score on these"
+ r" train-test partitions for these parameters will be set to 0\.1",
+ flags=re.DOTALL,
+ )
+ with pytest.warns(FitFailedWarning, match=warning_message):
+ gs.fit(X, y)
+
+ assert_allclose(gs.cv_results_["mean_test_acc"], [1, 1, 0.1])
+
+
+def test_callable_multimetric_clf_all_fits_fail():
+ # Warns and raises when all estimators fail to fit.
+ def custom_scorer(est, X, y):
+ return {"acc": 1}
+
+ X, y = make_classification(n_samples=20, n_features=10, random_state=0)
+
+ clf = FailingClassifier()
+
+ gs = GridSearchCV(
+ clf,
+ [{"parameter": [FailingClassifier.FAILING_PARAMETER] * 3}],
+ scoring=custom_scorer,
+ refit=False,
+ error_score=0.1,
+ )
+
+ individual_fit_error_message = "ValueError: Failing classifier failed as required"
+ error_message = re.compile(
+ "All the 15 fits failed.+your model is misconfigured.+"
+ f"{individual_fit_error_message}",
+ flags=re.DOTALL,
+ )
+
+ with pytest.raises(ValueError, match=error_message):
+ gs.fit(X, y)
+
+
+def test_n_features_in():
+ # make sure grid search and random search delegate n_features_in to the
+ # best estimator
+ n_features = 4
+ X, y = make_classification(n_features=n_features)
+ gbdt = HistGradientBoostingClassifier()
+ param_grid = {"max_iter": [3, 4]}
+ gs = GridSearchCV(gbdt, param_grid)
+ rs = RandomizedSearchCV(gbdt, param_grid, n_iter=1)
+ assert not hasattr(gs, "n_features_in_")
+ assert not hasattr(rs, "n_features_in_")
+ gs.fit(X, y)
+ rs.fit(X, y)
+ assert gs.n_features_in_ == n_features
+ assert rs.n_features_in_ == n_features
+
+
+@pytest.mark.parametrize("pairwise", [True, False])
+def test_search_cv_pairwise_property_delegated_to_base_estimator(pairwise):
+ """
+ Test that the implementation of BaseSearchCV has the pairwise tag
+ which matches the pairwise tag of its estimator.
+ This test makes sure the pairwise tag is delegated to the base estimator.
+
+ Non-regression test for issue #13920.
+ """
+
+ class TestEstimator(BaseEstimator):
+ def _more_tags(self):
+ return {"pairwise": pairwise}
+
+ est = TestEstimator()
+ attr_message = "BaseSearchCV pairwise tag must match estimator"
+ cv = GridSearchCV(est, {"n_neighbors": [10]})
+ assert pairwise == cv._get_tags()["pairwise"], attr_message
+
+
+def test_search_cv__pairwise_property_delegated_to_base_estimator():
+ """
+ Test that the implementation of BaseSearchCV has the pairwise property
+ which matches the pairwise tag of its estimator.
+ This test makes sure the pairwise tag is delegated to the base estimator.
+
+ Non-regression test for issue #13920.
+ """
+
+ class EstimatorPairwise(BaseEstimator):
+ def __init__(self, pairwise=True):
+ self.pairwise = pairwise
+
+ def _more_tags(self):
+ return {"pairwise": self.pairwise}
+
+ est = EstimatorPairwise()
+ attr_message = "BaseSearchCV _pairwise property must match estimator"
+
+ for _pairwise_setting in [True, False]:
+ est.set_params(pairwise=_pairwise_setting)
+ cv = GridSearchCV(est, {"n_neighbors": [10]})
+ assert _pairwise_setting == cv._get_tags()["pairwise"], attr_message
+
+
+def test_search_cv_pairwise_property_equivalence_of_precomputed():
+ """
+ Test that the implementation of BaseSearchCV has the pairwise tag
+ which matches the pairwise tag of its estimator.
+ This test ensures the equivalence of 'precomputed'.
+
+ Non-regression test for issue #13920.
+ """
+ n_samples = 50
+ n_splits = 2
+ X, y = make_classification(n_samples=n_samples, random_state=0)
+ grid_params = {"n_neighbors": [10]}
+
+ # defaults to euclidean metric (minkowski p = 2)
+ clf = KNeighborsClassifier()
+ cv = GridSearchCV(clf, grid_params, cv=n_splits)
+ cv.fit(X, y)
+ preds_original = cv.predict(X)
+
+ # precompute euclidean metric to validate pairwise is working
+ X_precomputed = euclidean_distances(X)
+ clf = KNeighborsClassifier(metric="precomputed")
+ cv = GridSearchCV(clf, grid_params, cv=n_splits)
+ cv.fit(X_precomputed, y)
+ preds_precomputed = cv.predict(X_precomputed)
+
+ attr_message = "GridSearchCV not identical with precomputed metric"
+ assert (preds_original == preds_precomputed).all(), attr_message
+
+
+@pytest.mark.parametrize(
+ "SearchCV, param_search",
+ [(GridSearchCV, {"a": [0.1, 0.01]}), (RandomizedSearchCV, {"a": uniform(1, 3)})],
+)
+def test_scalar_fit_param(SearchCV, param_search):
+ # unofficially sanctioned tolerance for scalar values in fit_params
+ # non-regression test for:
+ # https://github.com/scikit-learn/scikit-learn/issues/15805
+ class TestEstimator(ClassifierMixin, BaseEstimator):
+ def __init__(self, a=None):
+ self.a = a
+
+ def fit(self, X, y, r=None):
+ self.r_ = r
+
+ def predict(self, X):
+ return np.zeros(shape=(len(X)))
+
+ model = SearchCV(TestEstimator(), param_search)
+ X, y = make_classification(random_state=42)
+ model.fit(X, y, r=42)
+ assert model.best_estimator_.r_ == 42
+
+
+@pytest.mark.parametrize(
+ "SearchCV, param_search",
+ [
+ (GridSearchCV, {"alpha": [0.1, 0.01]}),
+ (RandomizedSearchCV, {"alpha": uniform(0.01, 0.1)}),
+ ],
+)
+def test_scalar_fit_param_compat(SearchCV, param_search):
+ # check support for scalar values in fit_params, for instance in libraries
+ # such as LightGBM that do not exactly respect the scikit-learn API contract
+ # but that we do not want to break without an explicit deprecation cycle and
+ # API recommendations for implementing early stopping with a user-provided
+ # validation set. Non-regression test for:
+ # https://github.com/scikit-learn/scikit-learn/issues/15805
+ X_train, X_valid, y_train, y_valid = train_test_split(
+ *make_classification(random_state=42), random_state=42
+ )
+
+ class _FitParamClassifier(SGDClassifier):
+ def fit(
+ self,
+ X,
+ y,
+ sample_weight=None,
+ tuple_of_arrays=None,
+ scalar_param=None,
+ callable_param=None,
+ ):
+ super().fit(X, y, sample_weight=sample_weight)
+ assert scalar_param > 0
+ assert callable(callable_param)
+
+ # The tuple of arrays should be preserved as tuple.
+ assert isinstance(tuple_of_arrays, tuple)
+ assert tuple_of_arrays[0].ndim == 2
+ assert tuple_of_arrays[1].ndim == 1
+ return self
+
+ def _fit_param_callable():
+ pass
+
+ model = SearchCV(_FitParamClassifier(), param_search)
+
+ # NOTE: `fit_params` should be data dependent (e.g. `sample_weight`) which
+ # is not the case for the following parameters. But this abuse is common in
+ # popular third-party libraries and we should tolerate this behavior for
+ # now and be careful not to break support for those without following a
+ # proper deprecation cycle.
+ fit_params = {
+ "tuple_of_arrays": (X_valid, y_valid),
+ "callable_param": _fit_param_callable,
+ "scalar_param": 42,
+ }
+ model.fit(X_train, y_train, **fit_params)
+
+
+# FIXME: Replace this test with a full `check_estimator` once we have API only
+# checks.
+@pytest.mark.filterwarnings("ignore:The total space of parameters 4 is")
+@pytest.mark.parametrize("SearchCV", [GridSearchCV, RandomizedSearchCV])
+@pytest.mark.parametrize("Predictor", [MinimalRegressor, MinimalClassifier])
+def test_search_cv_using_minimal_compatible_estimator(SearchCV, Predictor):
+ # Check that a third-party library can run tests without inheriting from
+ # BaseEstimator.
+ rng = np.random.RandomState(0)
+ X, y = rng.randn(25, 2), np.array([0] * 5 + [1] * 20)
+
+ model = Pipeline(
+ [("transformer", MinimalTransformer()), ("predictor", Predictor())]
+ )
+
+ params = {
+ "transformer__param": [1, 10],
+ "predictor__parama": [1, 10],
+ }
+ search = SearchCV(model, params, error_score="raise")
+ search.fit(X, y)
+
+ assert search.best_params_.keys() == params.keys()
+
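+ # The minimal estimators act like simple baselines (the classifier is
+ # expected to predict the majority class, 1 here, and the regressor the
+ # mean of y), which is what the assertions below rely on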
+ y_pred = search.predict(X)
+ if is_classifier(search):
+ assert_array_equal(y_pred, 1)
+ assert search.score(X, y) == pytest.approx(accuracy_score(y, y_pred))
+ else:
+ assert_allclose(y_pred, y.mean())
+ assert search.score(X, y) == pytest.approx(r2_score(y, y_pred))
+
+
+@pytest.mark.parametrize("return_train_score", [True, False])
+def test_search_cv_verbose_3(capsys, return_train_score):
+ """Check that search cv with verbose>2 shows the score for single
+ metrics. non-regression test for #19658."""
+ X, y = make_classification(n_samples=100, n_classes=2, flip_y=0.2, random_state=0)
+ clf = LinearSVC(random_state=0)
+ grid = {"C": [0.1]}
+
+ GridSearchCV(
+ clf,
+ grid,
+ scoring="accuracy",
+ verbose=3,
+ cv=3,
+ return_train_score=return_train_score,
+ ).fit(X, y)
+ captured = capsys.readouterr().out
+ if return_train_score:
+ match = re.findall(r"score=\(train=[\d\.]+, test=[\d.]+\)", captured)
+ else:
+ match = re.findall(r"score=[\d\.]+", captured)
+ assert len(match) == 3
diff --git a/modin/pandas/test/interoperability/sklearn/model_selection/test_split.py b/modin/pandas/test/interoperability/sklearn/model_selection/test_split.py
new file mode 100644
index 00000000000..123ca127916
--- /dev/null
+++ b/modin/pandas/test/interoperability/sklearn/model_selection/test_split.py
@@ -0,0 +1,1911 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+# flake8: noqa
+
+"""Test the split module"""
+import warnings
+import pytest
+import re
+import numpy as np
+from scipy.sparse import coo_matrix, csc_matrix, csr_matrix
+from scipy import stats
+from scipy.special import comb
+from itertools import combinations
+from itertools import combinations_with_replacement
+from itertools import permutations
+
+from sklearn.utils._testing import assert_allclose
+from sklearn.utils._testing import assert_array_almost_equal
+from sklearn.utils._testing import assert_array_equal
+from sklearn.utils._testing import ignore_warnings
+from sklearn.utils.validation import _num_samples
+from sklearn.utils._mocking import MockDataFrame
+
+from sklearn.model_selection import cross_val_score
+from sklearn.model_selection import KFold
+from sklearn.model_selection import StratifiedKFold
+from sklearn.model_selection import GroupKFold
+from sklearn.model_selection import TimeSeriesSplit
+from sklearn.model_selection import LeaveOneOut
+from sklearn.model_selection import LeaveOneGroupOut
+from sklearn.model_selection import LeavePOut
+from sklearn.model_selection import LeavePGroupsOut
+from sklearn.model_selection import ShuffleSplit
+from sklearn.model_selection import GroupShuffleSplit
+from sklearn.model_selection import StratifiedShuffleSplit
+from sklearn.model_selection import PredefinedSplit
+from sklearn.model_selection import check_cv
+from sklearn.model_selection import train_test_split
+from sklearn.model_selection import GridSearchCV
+from sklearn.model_selection import RepeatedKFold
+from sklearn.model_selection import RepeatedStratifiedKFold
+from sklearn.model_selection import StratifiedGroupKFold
+
+from sklearn.dummy import DummyClassifier
+
+from sklearn.model_selection._split import _validate_shuffle_split
+from sklearn.model_selection._split import _build_repr
+from sklearn.model_selection._split import _yields_constant_splits
+
+from sklearn.datasets import load_digits
+from sklearn.datasets import make_classification
+
+from sklearn.svm import SVC
+
+X = np.ones(10)
+y = np.arange(10) // 2
+P_sparse = coo_matrix(np.eye(5))
+test_groups = (
+ np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]),
+ np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]),
+ np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2]),
+ np.array([1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4]),
+ [1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3],
+ ["1", "1", "1", "1", "2", "2", "2", "3", "3", "3", "3", "3"],
+)
+digits = load_digits()
+
+
+@ignore_warnings
+def test_cross_validator_with_default_params():
+ n_samples = 4
+ n_unique_groups = 4
+ n_splits = 2
+ p = 2
+ n_shuffle_splits = 10 # (the default value)
+
+ X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
+ X_1d = np.array([1, 2, 3, 4])
+ y = np.array([1, 1, 2, 2])
+ groups = np.array([1, 2, 3, 4])
+ loo = LeaveOneOut()
+ lpo = LeavePOut(p)
+ kf = KFold(n_splits)
+ skf = StratifiedKFold(n_splits)
+ lolo = LeaveOneGroupOut()
+ lopo = LeavePGroupsOut(p)
+ ss = ShuffleSplit(random_state=0)
+ ps = PredefinedSplit([1, 1, 2, 2]) # n_splits = number of unique folds = 2
+ sgkf = StratifiedGroupKFold(n_splits)
+
+ loo_repr = "LeaveOneOut()"
+ lpo_repr = "LeavePOut(p=2)"
+ kf_repr = "KFold(n_splits=2, random_state=None, shuffle=False)"
+ skf_repr = "StratifiedKFold(n_splits=2, random_state=None, shuffle=False)"
+ lolo_repr = "LeaveOneGroupOut()"
+ lopo_repr = "LeavePGroupsOut(n_groups=2)"
+ ss_repr = (
+ "ShuffleSplit(n_splits=10, random_state=0, test_size=None, train_size=None)"
+ )
+ ps_repr = "PredefinedSplit(test_fold=array([1, 1, 2, 2]))"
+ sgkf_repr = "StratifiedGroupKFold(n_splits=2, random_state=None, shuffle=False)"
+
+ n_splits_expected = [
+ n_samples,
+ comb(n_samples, p),
+ n_splits,
+ n_splits,
+ n_unique_groups,
+ comb(n_unique_groups, p),
+ n_shuffle_splits,
+ 2,
+ n_splits,
+ ]
+
+ for i, (cv, cv_repr) in enumerate(
+ zip(
+ [loo, lpo, kf, skf, lolo, lopo, ss, ps, sgkf],
+ [
+ loo_repr,
+ lpo_repr,
+ kf_repr,
+ skf_repr,
+ lolo_repr,
+ lopo_repr,
+ ss_repr,
+ ps_repr,
+ sgkf_repr,
+ ],
+ )
+ ):
+ # Test if get_n_splits works correctly
+ assert n_splits_expected[i] == cv.get_n_splits(X, y, groups)
+
+ # Test if the cross-validator works as expected even if
+ # the data is 1d
+ np.testing.assert_equal(
+ list(cv.split(X, y, groups)), list(cv.split(X_1d, y, groups))
+ )
+ # Test that train, test indices returned are integers
+ for train, test in cv.split(X, y, groups):
+ assert np.asarray(train).dtype.kind == "i"
+ assert np.asarray(test).dtype.kind == "i"
+
+ # Test if the repr works without any errors
+ assert cv_repr == repr(cv)
+
+ # ValueError for get_n_splits methods
+ msg = "The 'X' parameter should not be None."
+ with pytest.raises(ValueError, match=msg):
+ loo.get_n_splits(None, y, groups)
+ with pytest.raises(ValueError, match=msg):
+ lpo.get_n_splits(None, y, groups)
+
+
+def test_2d_y():
+ # smoke test for 2d y and multi-label
+ n_samples = 30
+ rng = np.random.RandomState(1)
+ X = rng.randint(0, 3, size=(n_samples, 2))
+ y = rng.randint(0, 3, size=(n_samples,))
+ y_2d = y.reshape(-1, 1)
+ y_multilabel = rng.randint(0, 2, size=(n_samples, 3))
+ groups = rng.randint(0, 3, size=(n_samples,))
+ splitters = [
+ LeaveOneOut(),
+ LeavePOut(p=2),
+ KFold(),
+ StratifiedKFold(),
+ RepeatedKFold(),
+ RepeatedStratifiedKFold(),
+ StratifiedGroupKFold(),
+ ShuffleSplit(),
+ StratifiedShuffleSplit(test_size=0.5),
+ GroupShuffleSplit(),
+ LeaveOneGroupOut(),
+ LeavePGroupsOut(n_groups=2),
+ GroupKFold(n_splits=3),
+ TimeSeriesSplit(),
+ PredefinedSplit(test_fold=groups),
+ ]
+ for splitter in splitters:
+ list(splitter.split(X, y, groups))
+ list(splitter.split(X, y_2d, groups))
+ try:
+ list(splitter.split(X, y_multilabel, groups))
+ except ValueError as e:
+ allowed_target_types = ("binary", "multiclass")
+ msg = "Supported target types are: {}. Got 'multilabel".format(
+ allowed_target_types
+ )
+ assert msg in str(e)
+
+
+def check_valid_split(train, test, n_samples=None):
+ # Use python sets to get more informative assertion failure messages
+ train, test = set(train), set(test)
+
+ # Train and test split should not overlap
+ assert train.intersection(test) == set()
+
+ if n_samples is not None:
+ # Check that the union of the train and test split covers all the indices
+ assert train.union(test) == set(range(n_samples))
+
+
+def check_cv_coverage(cv, X, y, groups, expected_n_splits):
+ n_samples = _num_samples(X)
+ # Check that all the samples appear at least once in a test fold
+ assert cv.get_n_splits(X, y, groups) == expected_n_splits
+
+ collected_test_samples = set()
+ iterations = 0
+ for train, test in cv.split(X, y, groups):
+ check_valid_split(train, test, n_samples=n_samples)
+ iterations += 1
+ collected_test_samples.update(test)
+
+ # Check that the accumulated test samples cover the whole dataset
+ assert iterations == expected_n_splits
+ if n_samples is not None:
+ assert collected_test_samples == set(range(n_samples))
+
+
+def test_kfold_valueerrors():
+ X1 = np.array([[1, 2], [3, 4], [5, 6]])
+ X2 = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
+ # Check that an error is raised if there are not enough samples
+ with pytest.raises(ValueError):
+ next(KFold(4).split(X1))
+
+ # Check that a warning is raised if the least populated class has too few
+ # members.
+ y = np.array([3, 3, -1, -1, 3])
+
+ skf_3 = StratifiedKFold(3)
+ with pytest.warns(Warning, match="The least populated class"):
+ next(skf_3.split(X2, y))
+
+ sgkf_3 = StratifiedGroupKFold(3)
+ naive_groups = np.arange(len(y))
+ with pytest.warns(Warning, match="The least populated class"):
+ next(sgkf_3.split(X2, y, naive_groups))
+
+ # Check that despite the warning the folds are still computed even
+ # though all the classes are not necessarily represented on each
+ # side of the split at each split
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ check_cv_coverage(skf_3, X2, y, groups=None, expected_n_splits=3)
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ check_cv_coverage(sgkf_3, X2, y, groups=naive_groups, expected_n_splits=3)
+
+ # Check that errors are raised if all n_groups for individual
+ # classes are less than n_splits.
+ y = np.array([3, 3, -1, -1, 2])
+
+ with pytest.raises(ValueError):
+ next(skf_3.split(X2, y))
+ with pytest.raises(ValueError):
+ next(sgkf_3.split(X2, y))
+
+ # Error when number of folds is <= 1
+ with pytest.raises(ValueError):
+ KFold(0)
+ with pytest.raises(ValueError):
+ KFold(1)
+ error_string = "k-fold cross-validation requires at least one train/test split"
+ with pytest.raises(ValueError, match=error_string):
+ StratifiedKFold(0)
+ with pytest.raises(ValueError, match=error_string):
+ StratifiedKFold(1)
+ with pytest.raises(ValueError, match=error_string):
+ StratifiedGroupKFold(0)
+ with pytest.raises(ValueError, match=error_string):
+ StratifiedGroupKFold(1)
+
+ # When n_splits is not integer:
+ with pytest.raises(ValueError):
+ KFold(1.5)
+ with pytest.raises(ValueError):
+ KFold(2.0)
+ with pytest.raises(ValueError):
+ StratifiedKFold(1.5)
+ with pytest.raises(ValueError):
+ StratifiedKFold(2.0)
+ with pytest.raises(ValueError):
+ StratifiedGroupKFold(1.5)
+ with pytest.raises(ValueError):
+ StratifiedGroupKFold(2.0)
+
+ # When shuffle is not a bool:
+ with pytest.raises(TypeError):
+ KFold(n_splits=4, shuffle=None)
+
+
+def test_kfold_indices():
+ # Check all indices are returned in the test folds
+ X1 = np.ones(18)
+ kf = KFold(3)
+ check_cv_coverage(kf, X1, y=None, groups=None, expected_n_splits=3)
+
+ # Check all indices are returned in the test folds even when equal-sized
+ # folds are not possible
+ X2 = np.ones(17)
+ kf = KFold(3)
+ check_cv_coverage(kf, X2, y=None, groups=None, expected_n_splits=3)
+
+ # Check if get_n_splits returns the number of folds
+ assert 5 == KFold(5).get_n_splits(X2)
+
+
+def test_kfold_no_shuffle():
+ # Manually check that KFold preserves the data ordering on toy datasets
+ X2 = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
+
+ splits = KFold(2).split(X2[:-1])
+ train, test = next(splits)
+ assert_array_equal(test, [0, 1])
+ assert_array_equal(train, [2, 3])
+
+ train, test = next(splits)
+ assert_array_equal(test, [2, 3])
+ assert_array_equal(train, [0, 1])
+
+ splits = KFold(2).split(X2)
+ train, test = next(splits)
+ assert_array_equal(test, [0, 1, 2])
+ assert_array_equal(train, [3, 4])
+
+ train, test = next(splits)
+ assert_array_equal(test, [3, 4])
+ assert_array_equal(train, [0, 1, 2])
+
+
+def test_stratified_kfold_no_shuffle():
+ # Manually check that StratifiedKFold preserves the data ordering as much
+ # as possible on toy datasets in order to avoid hiding sample dependencies
+ # when possible
+ X, y = np.ones(4), [1, 1, 0, 0]
+ splits = StratifiedKFold(2).split(X, y)
+ train, test = next(splits)
+ assert_array_equal(test, [0, 2])
+ assert_array_equal(train, [1, 3])
+
+ train, test = next(splits)
+ assert_array_equal(test, [1, 3])
+ assert_array_equal(train, [0, 2])
+
+ X, y = np.ones(7), [1, 1, 1, 0, 0, 0, 0]
+ splits = StratifiedKFold(2).split(X, y)
+ train, test = next(splits)
+ assert_array_equal(test, [0, 1, 3, 4])
+ assert_array_equal(train, [2, 5, 6])
+
+ train, test = next(splits)
+ assert_array_equal(test, [2, 5, 6])
+ assert_array_equal(train, [0, 1, 3, 4])
+
+ # Check if get_n_splits returns the number of folds
+ assert 5 == StratifiedKFold(5).get_n_splits(X, y)
+
+ # Make sure string labels are also supported
+ X = np.ones(7)
+ y1 = ["1", "1", "1", "0", "0", "0", "0"]
+ y2 = [1, 1, 1, 0, 0, 0, 0]
+ np.testing.assert_equal(
+ list(StratifiedKFold(2).split(X, y1)), list(StratifiedKFold(2).split(X, y2))
+ )
+
+ # Check equivalence to KFold
+ y = [0, 1, 0, 1, 0, 1, 0, 1]
+ X = np.ones_like(y)
+ np.testing.assert_equal(
+ list(StratifiedKFold(3).split(X, y)), list(KFold(3).split(X, y))
+ )
+
+
+@pytest.mark.parametrize("shuffle", [False, True])
+@pytest.mark.parametrize("k", [4, 5, 6, 7, 8, 9, 10])
+@pytest.mark.parametrize("kfold", [StratifiedKFold, StratifiedGroupKFold])
+def test_stratified_kfold_ratios(k, shuffle, kfold):
+ # Check that stratified kfold preserves class ratios in individual splits
+ # Repeat with shuffling turned off and on
+ n_samples = 1000
+ X = np.ones(n_samples)
+ y = np.array(
+ [4] * int(0.10 * n_samples)
+ + [0] * int(0.89 * n_samples)
+ + [1] * int(0.01 * n_samples)
+ )
+ # ensure perfect stratification with StratifiedGroupKFold
+ groups = np.arange(len(y))
+ distr = np.bincount(y) / len(y)
+
+ test_sizes = []
+ random_state = None if not shuffle else 0
+ skf = kfold(k, random_state=random_state, shuffle=shuffle)
+ for train, test in skf.split(X, y, groups=groups):
+ assert_allclose(np.bincount(y[train]) / len(train), distr, atol=0.02)
+ assert_allclose(np.bincount(y[test]) / len(test), distr, atol=0.02)
+ test_sizes.append(len(test))
+ assert np.ptp(test_sizes) <= 1
+
+
+@pytest.mark.parametrize("shuffle", [False, True])
+@pytest.mark.parametrize("k", [4, 6, 7])
+@pytest.mark.parametrize("kfold", [StratifiedKFold, StratifiedGroupKFold])
+def test_stratified_kfold_label_invariance(k, shuffle, kfold):
+ # Check that stratified kfold gives the same indices regardless of labels
+ n_samples = 100
+ y = np.array(
+ [2] * int(0.10 * n_samples)
+ + [0] * int(0.89 * n_samples)
+ + [1] * int(0.01 * n_samples)
+ )
+ X = np.ones(len(y))
+ # ensure perfect stratification with StratifiedGroupKFold
+ groups = np.arange(len(y))
+
+ def get_splits(y):
+ random_state = None if not shuffle else 0
+ return [
+ (list(train), list(test))
+ for train, test in kfold(
+ k, random_state=random_state, shuffle=shuffle
+ ).split(X, y, groups=groups)
+ ]
+
+ splits_base = get_splits(y)
+ for perm in permutations([0, 1, 2]):
+ y_perm = np.take(perm, y)
+ splits_perm = get_splits(y_perm)
+ assert splits_perm == splits_base
+
+
+def test_kfold_balance():
+ # Check that KFold returns folds with balanced sizes
+ for i in range(11, 17):
+ kf = KFold(5).split(X=np.ones(i))
+ sizes = [len(test) for _, test in kf]
+
+ assert (np.max(sizes) - np.min(sizes)) <= 1
+ assert np.sum(sizes) == i
+
+
+@pytest.mark.parametrize("kfold", [StratifiedKFold, StratifiedGroupKFold])
+def test_stratifiedkfold_balance(kfold):
+ # Check that the stratified KFold variants return folds with balanced
+ # sizes (only when stratification is possible)
+ # Repeat with shuffling turned off and on
+ X = np.ones(17)
+ y = [0] * 3 + [1] * 14
+ # ensure perfect stratification with StratifiedGroupKFold
+ groups = np.arange(len(y))
+
+ for shuffle in (True, False):
+ cv = kfold(3, shuffle=shuffle)
+ for i in range(11, 17):
+ skf = cv.split(X[:i], y[:i], groups[:i])
+ sizes = [len(test) for _, test in skf]
+
+ assert (np.max(sizes) - np.min(sizes)) <= 1
+ assert np.sum(sizes) == i
+
+
+def test_shuffle_kfold():
+ # Check the indices are shuffled properly
+ kf = KFold(3)
+ kf2 = KFold(3, shuffle=True, random_state=0)
+ kf3 = KFold(3, shuffle=True, random_state=1)
+
+ X = np.ones(300)
+
+ all_folds = np.zeros(300)
+ for (tr1, te1), (tr2, te2), (tr3, te3) in zip(
+ kf.split(X), kf2.split(X), kf3.split(X)
+ ):
+ for tr_a, tr_b in combinations((tr1, tr2, tr3), 2):
+ # Assert that there is no complete overlap
+ assert len(np.intersect1d(tr_a, tr_b)) != len(tr1)
+
+ # Set all test indices in successive iterations of kf2 to 1
+ all_folds[te2] = 1
+
+ # Check that all indices are returned in the different test folds
+ assert sum(all_folds) == 300
+
+
+@pytest.mark.parametrize("kfold", [KFold, StratifiedKFold, StratifiedGroupKFold])
+def test_shuffle_kfold_stratifiedkfold_reproducibility(kfold):
+ X = np.ones(15) # Divisible by 3
+ y = [0] * 7 + [1] * 8
+ groups_1 = np.arange(len(y))
+ X2 = np.ones(16) # Not divisible by 3
+ y2 = [0] * 8 + [1] * 8
+ groups_2 = np.arange(len(y2))
+
+ # Check that when shuffle is True, multiple split calls produce the
+ # same split when random_state is an int
+ kf = kfold(3, shuffle=True, random_state=0)
+
+ np.testing.assert_equal(
+ list(kf.split(X, y, groups_1)), list(kf.split(X, y, groups_1))
+ )
+
+ # Check that when shuffle is True, multiple split calls often
+ # (not always) produce different splits when random_state is a
+ # RandomState instance or None
+ kf = kfold(3, shuffle=True, random_state=np.random.RandomState(0))
+ for data in zip((X, X2), (y, y2), (groups_1, groups_2)):
+ # Test if two consecutive calls to split give different splits
+ for (_, test_a), (_, test_b) in zip(kf.split(*data), kf.split(*data)):
+ # cv.split(...) yields tuples, each consisting of an array with
+ # train indices and an array with test indices
+ # Ensure that the splits for the data are not the same
+ # when random_state is not set to an int
+ with pytest.raises(AssertionError):
+ np.testing.assert_array_equal(test_a, test_b)
+
+
+def test_shuffle_stratifiedkfold():
+ # Check that shuffling is happening when requested, and for proper
+ # sample coverage
+ X_40 = np.ones(40)
+ y = [0] * 20 + [1] * 20
+ kf0 = StratifiedKFold(5, shuffle=True, random_state=0)
+ kf1 = StratifiedKFold(5, shuffle=True, random_state=1)
+ for (_, test0), (_, test1) in zip(kf0.split(X_40, y), kf1.split(X_40, y)):
+ assert set(test0) != set(test1)
+ check_cv_coverage(kf0, X_40, y, groups=None, expected_n_splits=5)
+
+ # Ensure that we shuffle each class's samples with different
+ # random_state in StratifiedKFold
+ # See https://github.com/scikit-learn/scikit-learn/pull/13124
+ X = np.arange(10)
+ y = [0] * 5 + [1] * 5
+ kf1 = StratifiedKFold(5, shuffle=True, random_state=0)
+ kf2 = StratifiedKFold(5, shuffle=True, random_state=1)
+ test_set1 = sorted([tuple(s[1]) for s in kf1.split(X, y)])
+ test_set2 = sorted([tuple(s[1]) for s in kf2.split(X, y)])
+ assert test_set1 != test_set2
+
+
+def test_kfold_can_detect_dependent_samples_on_digits(): # see #2372
+ # The digits samples are dependent: they are apparently grouped by authors
+ # although we don't have any information on the groups segment locations
+ # for this data. We can highlight this fact by computing k-fold cross-
+ # validation with and without shuffling: we observe that the shuffling case
+ # wrongly makes the IID assumption and is therefore too optimistic: it
+ # estimates a much higher accuracy (around 0.93) than that the non
+ # shuffling variant (around 0.81).
+
+ X, y = digits.data[:600], digits.target[:600]
+ model = SVC(C=10, gamma=0.005)
+
+ n_splits = 3
+
+ cv = KFold(n_splits=n_splits, shuffle=False)
+ mean_score = cross_val_score(model, X, y, cv=cv).mean()
+ assert 0.92 > mean_score
+ assert mean_score > 0.80
+
+ # Shuffling the data artificially breaks the dependency and hides the
+ # overfitting of the model with regards to the writing style of the authors
+ # by yielding a seriously overestimated score:
+
+ cv = KFold(n_splits, shuffle=True, random_state=0)
+ mean_score = cross_val_score(model, X, y, cv=cv).mean()
+ assert mean_score > 0.92
+
+ cv = KFold(n_splits, shuffle=True, random_state=1)
+ mean_score = cross_val_score(model, X, y, cv=cv).mean()
+ assert mean_score > 0.92
+
+ # Similarly, StratifiedKFold should try to shuffle the data as little
+ # as possible (while respecting the balanced class constraints)
+ # and thus be able to detect the dependency by not overestimating
+ # the CV score either. As the digits dataset is approximately balanced,
+ # the estimated mean score is close to the score measured with
+ # non-shuffled KFold.
+
+ cv = StratifiedKFold(n_splits)
+ mean_score = cross_val_score(model, X, y, cv=cv).mean()
+ assert 0.94 > mean_score
+ assert mean_score > 0.80
+
+
+def test_stratified_group_kfold_trivial():
+ sgkf = StratifiedGroupKFold(n_splits=3)
+ # Trivial example - groups with the same distribution
+ y = np.array([1] * 6 + [0] * 12)
+ X = np.ones_like(y).reshape(-1, 1)
+ groups = np.asarray((1, 2, 3, 4, 5, 6, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6))
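+ # each group holds one positive and two negative samples, so every test
+ # fold can match the global 1/3 vs 2/3 class distribution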
+ distr = np.bincount(y) / len(y)
+ test_sizes = []
+ for train, test in sgkf.split(X, y, groups):
+ # check group constraint
+ assert np.intersect1d(groups[train], groups[test]).size == 0
+ # check y distribution
+ assert_allclose(np.bincount(y[train]) / len(train), distr, atol=0.02)
+ assert_allclose(np.bincount(y[test]) / len(test), distr, atol=0.02)
+ test_sizes.append(len(test))
+ assert np.ptp(test_sizes) <= 1
+
+
+def test_stratified_group_kfold_approximate():
+ # Not perfect stratification (even though it is possible) because of
+ # iteration over groups
+ sgkf = StratifiedGroupKFold(n_splits=3)
+ y = np.array([1] * 6 + [0] * 12)
+ X = np.ones_like(y).reshape(-1, 1)
+ groups = np.array([1, 2, 3, 3, 4, 4, 1, 1, 2, 2, 3, 4, 5, 5, 5, 6, 6, 6])
+ expected = np.asarray([[0.833, 0.166], [0.666, 0.333], [0.5, 0.5]])
+ test_sizes = []
+ for (train, test), expect_dist in zip(sgkf.split(X, y, groups), expected):
+ # check group constraint
+ assert np.intersect1d(groups[train], groups[test]).size == 0
+ split_dist = np.bincount(y[test]) / len(test)
+ assert_allclose(split_dist, expect_dist, atol=0.001)
+ test_sizes.append(len(test))
+ assert np.ptp(test_sizes) <= 1
+
+
+@pytest.mark.parametrize(
+ "y, groups, expected",
+ [
+ (
+ np.array([0] * 6 + [1] * 6),
+ np.array([1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6]),
+ np.asarray([[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]),
+ ),
+ (
+ np.array([0] * 9 + [1] * 3),
+ np.array([1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5, 6]),
+ np.asarray([[0.75, 0.25], [0.75, 0.25], [0.75, 0.25]]),
+ ),
+ ],
+)
+def test_stratified_group_kfold_homogeneous_groups(y, groups, expected):
+ sgkf = StratifiedGroupKFold(n_splits=3)
+ X = np.ones_like(y).reshape(-1, 1)
+ for (train, test), expect_dist in zip(sgkf.split(X, y, groups), expected):
+ # check group constraint
+ assert np.intersect1d(groups[train], groups[test]).size == 0
+ split_dist = np.bincount(y[test]) / len(test)
+ assert_allclose(split_dist, expect_dist, atol=0.001)
+
+
+@pytest.mark.parametrize("cls_distr", [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8), (0.8, 0.2)])
+@pytest.mark.parametrize("n_groups", [5, 30, 70])
+def test_stratified_group_kfold_against_group_kfold(cls_distr, n_groups):
+ # Check that, given a sufficient amount of samples, StratifiedGroupKFold
+ # produces better stratified folds than regular GroupKFold
+ n_splits = 5
+ sgkf = StratifiedGroupKFold(n_splits=n_splits)
+ gkf = GroupKFold(n_splits=n_splits)
+ rng = np.random.RandomState(0)
+ n_points = 1000
+ y = rng.choice(2, size=n_points, p=cls_distr)
+ X = np.ones_like(y).reshape(-1, 1)
+ g = rng.choice(n_groups, n_points)
+ sgkf_folds = sgkf.split(X, y, groups=g)
+ gkf_folds = gkf.split(X, y, groups=g)
+ sgkf_entr = 0
+ gkf_entr = 0
+ for (sgkf_train, sgkf_test), (_, gkf_test) in zip(sgkf_folds, gkf_folds):
+ # check group constraint
+ assert np.intersect1d(g[sgkf_train], g[sgkf_test]).size == 0
+ sgkf_distr = np.bincount(y[sgkf_test]) / len(sgkf_test)
+ gkf_distr = np.bincount(y[gkf_test]) / len(gkf_test)
+ sgkf_entr += stats.entropy(sgkf_distr, qk=cls_distr)
+ gkf_entr += stats.entropy(gkf_distr, qk=cls_distr)
+ sgkf_entr /= n_splits
+ gkf_entr /= n_splits
+ assert sgkf_entr <= gkf_entr
+
+
+def test_shuffle_split():
+ ss1 = ShuffleSplit(test_size=0.2, random_state=0).split(X)
+ ss2 = ShuffleSplit(test_size=2, random_state=0).split(X)
+ ss3 = ShuffleSplit(test_size=np.int32(2), random_state=0).split(X)
+ ss4 = ShuffleSplit(test_size=int(2), random_state=0).split(X)
+ for t1, t2, t3, t4 in zip(ss1, ss2, ss3, ss4):
+ assert_array_equal(t1[0], t2[0])
+ assert_array_equal(t2[0], t3[0])
+ assert_array_equal(t3[0], t4[0])
+ assert_array_equal(t1[1], t2[1])
+ assert_array_equal(t2[1], t3[1])
+ assert_array_equal(t3[1], t4[1])
+
+
+@pytest.mark.parametrize("split_class", [ShuffleSplit, StratifiedShuffleSplit])
+@pytest.mark.parametrize(
+ "train_size, exp_train, exp_test", [(None, 9, 1), (8, 8, 2), (0.8, 8, 2)]
+)
+def test_shuffle_split_default_test_size(split_class, train_size, exp_train, exp_test):
+ # Check that the default value has the expected behavior, i.e. 0.1 if both
+ # are unspecified, or the complement of train_size unless both are specified.
+ X = np.ones(10)
+ y = np.ones(10)
+
+ X_train, X_test = next(split_class(train_size=train_size).split(X, y))
+
+ assert len(X_train) == exp_train
+ assert len(X_test) == exp_test
+
+
+@pytest.mark.parametrize(
+ "train_size, exp_train, exp_test", [(None, 8, 2), (7, 7, 3), (0.7, 7, 3)]
+)
+def test_group_shuffle_split_default_test_size(train_size, exp_train, exp_test):
+ # Check that the default value has the expected behavior, i.e. 0.2 if both
+ # are unspecified, or the complement of train_size unless both are specified.
+ X = np.ones(10)
+ y = np.ones(10)
+ groups = range(10)
+
+ X_train, X_test = next(GroupShuffleSplit(train_size=train_size).split(X, y, groups))
+
+ assert len(X_train) == exp_train
+ assert len(X_test) == exp_test
+
+
+@ignore_warnings
+def test_stratified_shuffle_split_init():
+ X = np.arange(7)
+ y = np.asarray([0, 1, 1, 1, 2, 2, 2])
+ # Check that error is raised if there is a class with only one sample
+ with pytest.raises(ValueError):
+ next(StratifiedShuffleSplit(3, test_size=0.2).split(X, y))
+
+ # Check that error is raised if the test set size is smaller than n_classes
+ with pytest.raises(ValueError):
+ next(StratifiedShuffleSplit(3, test_size=2).split(X, y))
+ # Check that error is raised if the train set size is smaller than
+ # n_classes
+ with pytest.raises(ValueError):
+ next(StratifiedShuffleSplit(3, test_size=3, train_size=2).split(X, y))
+
+ X = np.arange(9)
+ y = np.asarray([0, 0, 0, 1, 1, 1, 2, 2, 2])
+
+ # Train size or test size too small
+ with pytest.raises(ValueError):
+ next(StratifiedShuffleSplit(train_size=2).split(X, y))
+ with pytest.raises(ValueError):
+ next(StratifiedShuffleSplit(test_size=2).split(X, y))
+
+
+def test_stratified_shuffle_split_respects_test_size():
+ y = np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2])
+ test_size = 5
+ train_size = 10
+ sss = StratifiedShuffleSplit(
+ 6, test_size=test_size, train_size=train_size, random_state=0
+ ).split(np.ones(len(y)), y)
+ for train, test in sss:
+ assert len(train) == train_size
+ assert len(test) == test_size
+
+
+def test_stratified_shuffle_split_iter():
+ ys = [
+ np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]),
+ np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]),
+ np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2] * 2),
+ np.array([1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4]),
+ np.array([-1] * 800 + [1] * 50),
+ np.concatenate([[i] * (100 + i) for i in range(11)]),
+ [1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3],
+ ["1", "1", "1", "1", "2", "2", "2", "3", "3", "3", "3", "3"],
+ ]
+
+ for y in ys:
+ sss = StratifiedShuffleSplit(6, test_size=0.33, random_state=0).split(
+ np.ones(len(y)), y
+ )
+ y = np.asanyarray(y) # To make it indexable for y[train]
+ # this is how test-size is computed internally
+ # in _validate_shuffle_split
+ test_size = np.ceil(0.33 * len(y))
+ train_size = len(y) - test_size
+ for train, test in sss:
+ assert_array_equal(np.unique(y[train]), np.unique(y[test]))
+ # Checks that the folds preserve the class proportions
+ p_train = np.bincount(np.unique(y[train], return_inverse=True)[1]) / float(
+ len(y[train])
+ )
+ p_test = np.bincount(np.unique(y[test], return_inverse=True)[1]) / float(
+ len(y[test])
+ )
+ assert_array_almost_equal(p_train, p_test, 1)
+ assert len(train) + len(test) == y.size
+ assert len(train) == train_size
+ assert len(test) == test_size
+ assert_array_equal(np.intersect1d(train, test), [])
+
+
+def test_stratified_shuffle_split_even():
+ # Test that StratifiedShuffleSplit draws each index with
+ # equal probability
+ n_folds = 5
+ n_splits = 1000
+
+ def assert_counts_are_ok(idx_counts, p):
+ # Here we test that the distribution of the counts
+ # per index is close enough to a binomial
+ threshold = 0.05 / n_splits
+ bf = stats.binom(n_splits, p)
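+ # over n_splits independent splits, the number of times an index is drawn
+ # should look like a draw from Binomial(n_splits, p); counts whose pmf falls
+ # below the threshold indicate a biased draw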
+ for count in idx_counts:
+ prob = bf.pmf(count)
+ assert (
+ prob > threshold
+ ), "An index is not drawn with chance corresponding to even draws"
+
+ for n_samples in (6, 22):
+ groups = np.array((n_samples // 2) * [0, 1])
+ splits = StratifiedShuffleSplit(
+ n_splits=n_splits, test_size=1.0 / n_folds, random_state=0
+ )
+
+ train_counts = [0] * n_samples
+ test_counts = [0] * n_samples
+ n_splits_actual = 0
+ for train, test in splits.split(X=np.ones(n_samples), y=groups):
+ n_splits_actual += 1
+ for counter, ids in [(train_counts, train), (test_counts, test)]:
+ for id in ids:
+ counter[id] += 1
+ assert n_splits_actual == n_splits
+
+ n_train, n_test = _validate_shuffle_split(
+ n_samples, test_size=1.0 / n_folds, train_size=1.0 - (1.0 / n_folds)
+ )
+
+ assert len(train) == n_train
+ assert len(test) == n_test
+ assert len(set(train).intersection(test)) == 0
+
+ group_counts = np.unique(groups)
+ assert splits.test_size == 1.0 / n_folds
+ assert n_train + n_test == len(groups)
+ assert len(group_counts) == 2
+ ex_test_p = float(n_test) / n_samples
+ ex_train_p = float(n_train) / n_samples
+
+ assert_counts_are_ok(train_counts, ex_train_p)
+ assert_counts_are_ok(test_counts, ex_test_p)
+
+
+def test_stratified_shuffle_split_overlap_train_test_bug():
+ # See https://github.com/scikit-learn/scikit-learn/issues/6121 for
+ # the original bug report
+ y = [0, 1, 2, 3] * 3 + [4, 5] * 5
+ X = np.ones_like(y)
+
+ sss = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=0)
+
+ train, test = next(sss.split(X=X, y=y))
+
+ # no overlap
+ assert_array_equal(np.intersect1d(train, test), [])
+
+ # complete partition
+ assert_array_equal(np.union1d(train, test), np.arange(len(y)))
+
+
+def test_stratified_shuffle_split_multilabel():
+ # fix for issue 9037
+ for y in [
+ np.array([[0, 1], [1, 0], [1, 0], [0, 1]]),
+ np.array([[0, 1], [1, 1], [1, 1], [0, 1]]),
+ ]:
+ X = np.ones_like(y)
+ sss = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=0)
+ train, test = next(sss.split(X=X, y=y))
+ y_train = y[train]
+ y_test = y[test]
+
+ # no overlap
+ assert_array_equal(np.intersect1d(train, test), [])
+
+ # complete partition
+ assert_array_equal(np.union1d(train, test), np.arange(len(y)))
+
+ # correct stratification of entire rows
+ # (by design, here y[:, 0] uniquely determines the entire row of y)
+ expected_ratio = np.mean(y[:, 0])
+ assert expected_ratio == np.mean(y_train[:, 0])
+ assert expected_ratio == np.mean(y_test[:, 0])
+
+
+def test_stratified_shuffle_split_multilabel_many_labels():
+ # fix in PR #9922: for multilabel data with > 1000 labels, str(row)
+ # truncates with an ellipsis for elements in positions 4 through
+ # len(row) - 4, so labels were not being correctly split using the powerset
+ # method for transforming a multilabel problem to a multiclass one; this
+ # test checks that this problem is fixed.
+ row_with_many_zeros = [1, 0, 1] + [0] * 1000 + [1, 0, 1]
+ row_with_many_ones = [1, 0, 1] + [1] * 1000 + [1, 0, 1]
+ y = np.array([row_with_many_zeros] * 10 + [row_with_many_ones] * 100)
+ X = np.ones_like(y)
+
+ sss = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=0)
+ train, test = next(sss.split(X=X, y=y))
+ y_train = y[train]
+ y_test = y[test]
+
+ # correct stratification of entire rows
+ # (by design, here y[:, 4] uniquely determines the entire row of y)
+ expected_ratio = np.mean(y[:, 4])
+ assert expected_ratio == np.mean(y_train[:, 4])
+ assert expected_ratio == np.mean(y_test[:, 4])
+
+
+def test_predefinedsplit_with_kfold_split():
+ # Check that PredefinedSplit can reproduce a split generated by KFold.
+ folds = np.full(10, -1.0)
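+ # folds[i] holds the index of the KFold split in which sample i is used for testing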
+ kf_train = []
+ kf_test = []
+ for i, (train_ind, test_ind) in enumerate(KFold(5, shuffle=True).split(X)):
+ kf_train.append(train_ind)
+ kf_test.append(test_ind)
+ folds[test_ind] = i
+ ps = PredefinedSplit(folds)
+ # n_splits is simply the number of unique folds
+ assert len(np.unique(folds)) == ps.get_n_splits()
+ ps_train, ps_test = zip(*ps.split())
+ assert_array_equal(ps_train, kf_train)
+ assert_array_equal(ps_test, kf_test)
+
+
+def test_group_shuffle_split():
+ for groups_i in test_groups:
+ X = y = np.ones(len(groups_i))
+ n_splits = 6
+ test_size = 1.0 / 3
+ slo = GroupShuffleSplit(n_splits, test_size=test_size, random_state=0)
+
+ # Make sure the repr works
+ repr(slo)
+
+ # Test that the length is correct
+ assert slo.get_n_splits(X, y, groups=groups_i) == n_splits
+
+ l_unique = np.unique(groups_i)
+ l = np.asarray(groups_i)
+
+ for train, test in slo.split(X, y, groups=groups_i):
+ # First test: no train group is in the test set and vice versa
+ l_train_unique = np.unique(l[train])
+ l_test_unique = np.unique(l[test])
+ assert not np.any(np.in1d(l[train], l_test_unique))
+ assert not np.any(np.in1d(l[test], l_train_unique))
+
+ # Second test: train and test add up to all the data
+ assert l[train].size + l[test].size == l.size
+
+ # Third test: train and test are disjoint
+ assert_array_equal(np.intersect1d(train, test), [])
+
+ # Fourth test:
+ # unique train and test groups are correct, +- 1 for rounding error
+ assert abs(len(l_test_unique) - round(test_size * len(l_unique))) <= 1
+ assert (
+ abs(len(l_train_unique) - round((1.0 - test_size) * len(l_unique))) <= 1
+ )
+
+
+def test_leave_one_p_group_out():
+ logo = LeaveOneGroupOut()
+ lpgo_1 = LeavePGroupsOut(n_groups=1)
+ lpgo_2 = LeavePGroupsOut(n_groups=2)
+
+ # Make sure the repr works
+ assert repr(logo) == "LeaveOneGroupOut()"
+ assert repr(lpgo_1) == "LeavePGroupsOut(n_groups=1)"
+ assert repr(lpgo_2) == "LeavePGroupsOut(n_groups=2)"
+ assert repr(LeavePGroupsOut(n_groups=3)) == "LeavePGroupsOut(n_groups=3)"
+
+ for j, (cv, p_groups_out) in enumerate(((logo, 1), (lpgo_1, 1), (lpgo_2, 2))):
+ for i, groups_i in enumerate(test_groups):
+ n_groups = len(np.unique(groups_i))
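+ # one split per group for LeaveOneGroupOut; C(n_groups, 2) splits for
+ # LeavePGroupsOut(n_groups=2)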
+ n_splits = n_groups if p_groups_out == 1 else n_groups * (n_groups - 1) / 2
+ X = y = np.ones(len(groups_i))
+
+ # Test that the length is correct
+ assert cv.get_n_splits(X, y, groups=groups_i) == n_splits
+
+ groups_arr = np.asarray(groups_i)
+
+ # Split using the original list / array / list of string groups_i
+ for train, test in cv.split(X, y, groups=groups_i):
+ # First test: no train group is in the test set and vice versa
+ assert_array_equal(
+ np.intersect1d(groups_arr[train], groups_arr[test]).tolist(), []
+ )
+
+ # Second test: train and test add up to all the data
+ assert len(train) + len(test) == len(groups_i)
+
+ # Third test:
+ # The number of groups in test must be equal to p_groups_out
+ assert np.unique(groups_arr[test]).shape[0] == p_groups_out
+
+ # check get_n_splits() with dummy parameters
+ assert logo.get_n_splits(None, None, ["a", "b", "c", "b", "c"]) == 3
+ assert logo.get_n_splits(groups=[1.0, 1.1, 1.0, 1.2]) == 3
+ assert lpgo_2.get_n_splits(None, None, np.arange(4)) == 6
+ assert lpgo_1.get_n_splits(groups=np.arange(4)) == 4
+
+ # raise ValueError if a `groups` parameter is illegal
+ with pytest.raises(ValueError):
+ logo.get_n_splits(None, None, [0.0, np.nan, 0.0])
+ with pytest.raises(ValueError):
+ lpgo_2.get_n_splits(None, None, [0.0, np.inf, 0.0])
+
+ msg = "The 'groups' parameter should not be None."
+ with pytest.raises(ValueError, match=msg):
+ logo.get_n_splits(None, None, None)
+ with pytest.raises(ValueError, match=msg):
+ lpgo_1.get_n_splits(None, None, None)
+
+
+def test_leave_group_out_changing_groups():
+ # Check that LeaveOneGroupOut and LeavePGroupsOut work normally if
+ # the groups variable is changed before calling split
+ groups = np.array([0, 1, 2, 1, 1, 2, 0, 0])
+ X = np.ones(len(groups))
+ groups_changing = np.array(groups, copy=True)
+ lolo = LeaveOneGroupOut().split(X, groups=groups)
+ lolo_changing = LeaveOneGroupOut().split(X, groups=groups)
+ lplo = LeavePGroupsOut(n_groups=2).split(X, groups=groups)
+ lplo_changing = LeavePGroupsOut(n_groups=2).split(X, groups=groups)
+ groups_changing[:] = 0
+ for llo, llo_changing in [(lolo, lolo_changing), (lplo, lplo_changing)]:
+ for (train, test), (train_chan, test_chan) in zip(llo, llo_changing):
+ assert_array_equal(train, train_chan)
+ assert_array_equal(test, test_chan)
+
+ # n_splits = number of combinations of p=2 groups out of the 3 unique groups = C(3, 2) = 3
+ assert 3 == LeavePGroupsOut(n_groups=2).get_n_splits(X, y=X, groups=groups)
+ # n_splits = number of unique groups (C(n_unique_groups, 1) = n_unique_groups)
+ assert 3 == LeaveOneGroupOut().get_n_splits(X, y=X, groups=groups)
+
+
+def test_leave_group_out_order_dependence():
+ # Check that LeaveOneGroupOut orders the splits according to the index
+ # of the group left out.
+ groups = np.array([2, 2, 0, 0, 1, 1])
+ X = np.ones(len(groups))
+
+ splits = iter(LeaveOneGroupOut().split(X, groups=groups))
+
+ expected_indices = [
+ ([0, 1, 4, 5], [2, 3]),
+ ([0, 1, 2, 3], [4, 5]),
+ ([2, 3, 4, 5], [0, 1]),
+ ]
+
+ for expected_train, expected_test in expected_indices:
+ train, test = next(splits)
+ assert_array_equal(train, expected_train)
+ assert_array_equal(test, expected_test)
+
+
+def test_leave_one_p_group_out_error_on_fewer_number_of_groups():
+ X = y = groups = np.ones(0)
+ msg = re.escape("Found array with 0 sample(s)")
+ with pytest.raises(ValueError, match=msg):
+ next(LeaveOneGroupOut().split(X, y, groups))
+
+ X = y = groups = np.ones(1)
+ msg = re.escape(
+ f"The groups parameter contains fewer than 2 unique groups ({groups})."
+ " LeaveOneGroupOut expects at least 2."
+ )
+ with pytest.raises(ValueError, match=msg):
+ next(LeaveOneGroupOut().split(X, y, groups))
+
+ X = y = groups = np.ones(1)
+ msg = re.escape(
+ "The groups parameter contains fewer than (or equal to) n_groups "
+ f"(3) numbers of unique groups ({groups}). LeavePGroupsOut expects "
+ "that at least n_groups + 1 (4) unique groups "
+ "be present"
+ )
+ with pytest.raises(ValueError, match=msg):
+ next(LeavePGroupsOut(n_groups=3).split(X, y, groups))
+
+ X = y = groups = np.arange(3)
+ msg = re.escape(
+ "The groups parameter contains fewer than (or equal to) n_groups "
+ f"(3) numbers of unique groups ({groups}). LeavePGroupsOut expects "
+ "that at least n_groups + 1 (4) unique groups "
+ "be present"
+ )
+ with pytest.raises(ValueError, match=msg):
+ next(LeavePGroupsOut(n_groups=3).split(X, y, groups))
+
+
+@ignore_warnings
+def test_repeated_cv_value_errors():
+ # n_repeats is not integer or <= 0
+ for cv in (RepeatedKFold, RepeatedStratifiedKFold):
+ with pytest.raises(ValueError):
+ cv(n_repeats=0)
+ with pytest.raises(ValueError):
+ cv(n_repeats=1.5)
+
+
+@pytest.mark.parametrize("RepeatedCV", [RepeatedKFold, RepeatedStratifiedKFold])
+def test_repeated_cv_repr(RepeatedCV):
+ n_splits, n_repeats = 2, 6
+ repeated_cv = RepeatedCV(n_splits=n_splits, n_repeats=n_repeats)
+ repeated_cv_repr = "{}(n_repeats=6, n_splits=2, random_state=None)".format(
+ repeated_cv.__class__.__name__
+ )
+ assert repeated_cv_repr == repr(repeated_cv)
+
+
+def test_repeated_kfold_deterministic_split():
+ X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
+ random_state = 258173307
+ rkf = RepeatedKFold(n_splits=2, n_repeats=2, random_state=random_state)
+
+ # split should produce same and deterministic splits on
+ # each call
+ for _ in range(3):
+ splits = rkf.split(X)
+ train, test = next(splits)
+ assert_array_equal(train, [2, 4])
+ assert_array_equal(test, [0, 1, 3])
+
+ train, test = next(splits)
+ assert_array_equal(train, [0, 1, 3])
+ assert_array_equal(test, [2, 4])
+
+ train, test = next(splits)
+ assert_array_equal(train, [0, 1])
+ assert_array_equal(test, [2, 3, 4])
+
+ train, test = next(splits)
+ assert_array_equal(train, [2, 3, 4])
+ assert_array_equal(test, [0, 1])
+
+ with pytest.raises(StopIteration):
+ next(splits)
+
+
+def test_get_n_splits_for_repeated_kfold():
+ n_splits = 3
+ n_repeats = 4
+ rkf = RepeatedKFold(n_splits=n_splits, n_repeats=n_repeats)
+ expected_n_splits = n_splits * n_repeats
+ assert expected_n_splits == rkf.get_n_splits()
+
+
+def test_get_n_splits_for_repeated_stratified_kfold():
+ n_splits = 3
+ n_repeats = 4
+ rskf = RepeatedStratifiedKFold(n_splits=n_splits, n_repeats=n_repeats)
+ expected_n_splits = n_splits * n_repeats
+ assert expected_n_splits == rskf.get_n_splits()
+
+
+def test_repeated_stratified_kfold_deterministic_split():
+ X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
+ y = [1, 1, 1, 0, 0]
+ random_state = 1944695409
+ rskf = RepeatedStratifiedKFold(n_splits=2, n_repeats=2, random_state=random_state)
+
+ # split should produce same and deterministic splits on
+ # each call
+ for _ in range(3):
+ splits = rskf.split(X, y)
+ train, test = next(splits)
+ assert_array_equal(train, [1, 4])
+ assert_array_equal(test, [0, 2, 3])
+
+ train, test = next(splits)
+ assert_array_equal(train, [0, 2, 3])
+ assert_array_equal(test, [1, 4])
+
+ train, test = next(splits)
+ assert_array_equal(train, [2, 3])
+ assert_array_equal(test, [0, 1, 4])
+
+ train, test = next(splits)
+ assert_array_equal(train, [0, 1, 4])
+ assert_array_equal(test, [2, 3])
+
+ with pytest.raises(StopIteration):
+ next(splits)
+
+
+def test_train_test_split_errors():
+ pytest.raises(ValueError, train_test_split)
+
+ pytest.raises(ValueError, train_test_split, range(3), train_size=1.1)
+
+ pytest.raises(ValueError, train_test_split, range(3), test_size=0.6, train_size=0.6)
+ pytest.raises(
+ ValueError,
+ train_test_split,
+ range(3),
+ test_size=np.float32(0.6),
+ train_size=np.float32(0.6),
+ )
+ pytest.raises(ValueError, train_test_split, range(3), test_size="wrong_type")
+ pytest.raises(ValueError, train_test_split, range(3), test_size=2, train_size=4)
+ pytest.raises(TypeError, train_test_split, range(3), some_argument=1.1)
+ pytest.raises(ValueError, train_test_split, range(3), range(42))
+ pytest.raises(ValueError, train_test_split, range(10), shuffle=False, stratify=True)
+
+ with pytest.raises(
+ ValueError,
+ match=r"train_size=11 should be either positive and "
+ r"smaller than the number of samples 10 or a "
+ r"float in the \(0, 1\) range",
+ ):
+ train_test_split(range(10), train_size=11, test_size=1)
+
+
+@pytest.mark.parametrize(
+ "train_size, exp_train, exp_test", [(None, 7, 3), (8, 8, 2), (0.8, 8, 2)]
+)
+def test_train_test_split_default_test_size(train_size, exp_train, exp_test):
+ # Check that the default test_size has the expected behavior, i.e. 0.25 if
+ # both sizes are unspecified, or the complement of train_size when only
+ # train_size is given.
+ X_train, X_test = train_test_split(X, train_size=train_size)
+
+ assert len(X_train) == exp_train
+ assert len(X_test) == exp_test
+
+
+def test_train_test_split():
+ X = np.arange(100).reshape((10, 10))
+ X_s = coo_matrix(X)
+ y = np.arange(10)
+
+ # simple test
+ split = train_test_split(X, y, test_size=None, train_size=0.5)
+ X_train, X_test, y_train, y_test = split
+ assert len(y_test) == len(y_train)
+ # test correspondence of X and y
+ assert_array_equal(X_train[:, 0], y_train * 10)
+ assert_array_equal(X_test[:, 0], y_test * 10)
+
+ # don't convert lists to anything else by default
+ split = train_test_split(X, X_s, y.tolist())
+ X_train, X_test, X_s_train, X_s_test, y_train, y_test = split
+ assert isinstance(y_train, list)
+ assert isinstance(y_test, list)
+
+ # allow nd-arrays
+ X_4d = np.arange(10 * 5 * 3 * 2).reshape(10, 5, 3, 2)
+ y_3d = np.arange(10 * 7 * 11).reshape(10, 7, 11)
+ split = train_test_split(X_4d, y_3d)
+ assert split[0].shape == (7, 5, 3, 2)
+ assert split[1].shape == (3, 5, 3, 2)
+ assert split[2].shape == (7, 7, 11)
+ assert split[3].shape == (3, 7, 11)
+
+ # test stratification option
+ y = np.array([1, 1, 1, 1, 2, 2, 2, 2])
+ for test_size, exp_test_size in zip([2, 4, 0.25, 0.5, 0.75], [2, 4, 2, 4, 6]):
+ train, test = train_test_split(
+ y, test_size=test_size, stratify=y, random_state=0
+ )
+ assert len(test) == exp_test_size
+ assert len(test) + len(train) == len(y)
+ # check the 1:1 ratio of ones and twos in the data is preserved
+ assert np.sum(train == 1) == np.sum(train == 2)
+
+ # test unshuffled split
+ y = np.arange(10)
+ for test_size in [2, 0.2]:
+ train, test = train_test_split(y, shuffle=False, test_size=test_size)
+ assert_array_equal(test, [8, 9])
+ assert_array_equal(train, [0, 1, 2, 3, 4, 5, 6, 7])
+
+
+def test_train_test_split_32bit_overflow():
+ """Check for integer overflow on 32-bit platforms.
+
+ Non-regression test for:
+ https://github.com/scikit-learn/scikit-learn/issues/20774
+ """
+
+ # A number 'n' big enough for the expression 'n * n * train_size' to
+ # overflow a signed 32-bit integer
+ big_number = 100000
+
+ # The definition of 'y' is part of the reproduction: the population of at
+ # least one class should be of the same order of magnitude as the size of X
+ X = np.arange(big_number)
+ y = X > (0.99 * big_number)
+
+ split = train_test_split(X, y, stratify=y, train_size=0.25)
+ X_train, X_test, y_train, y_test = split
+
+ assert X_train.size + X_test.size == big_number
+ assert y_train.size + y_test.size == big_number
+
+
+@ignore_warnings
+def test_train_test_split_pandas():
+ # check train_test_split doesn't destroy pandas dataframe
+ types = [MockDataFrame]
+ try:
+ from modin.pandas import DataFrame
+
+ types.append(DataFrame)
+ except ImportError:
+ pass
+ for InputFeatureType in types:
+ # X dataframe
+ X_df = InputFeatureType(X)
+ X_train, X_test = train_test_split(X_df)
+ assert isinstance(X_train, InputFeatureType)
+ assert isinstance(X_test, InputFeatureType)
+
+
+def test_train_test_split_sparse():
+ # check that train_test_split converts scipy sparse matrices
+ # to csr, as stated in the documentation
+ X = np.arange(100).reshape((10, 10))
+ sparse_types = [csr_matrix, csc_matrix, coo_matrix]
+ for InputFeatureType in sparse_types:
+ X_s = InputFeatureType(X)
+ X_train, X_test = train_test_split(X_s)
+ assert isinstance(X_train, csr_matrix)
+ assert isinstance(X_test, csr_matrix)
+
+
+def test_train_test_split_mock_pandas():
+ # X mock dataframe
+ X_df = MockDataFrame(X)
+ X_train, X_test = train_test_split(X_df)
+ assert isinstance(X_train, MockDataFrame)
+ assert isinstance(X_test, MockDataFrame)
+ X_train_arr, X_test_arr = train_test_split(X_df)
+
+
+def test_train_test_split_list_input():
+ # Check that when y is a list / list of string labels, it works.
+ X = np.ones(7)
+ y1 = ["1"] * 4 + ["0"] * 3
+ y2 = np.hstack((np.ones(4), np.zeros(3)))
+ y3 = y2.tolist()
+
+ for stratify in (True, False):
+ X_train1, X_test1, y_train1, y_test1 = train_test_split(
+ X, y1, stratify=y1 if stratify else None, random_state=0
+ )
+ X_train2, X_test2, y_train2, y_test2 = train_test_split(
+ X, y2, stratify=y2 if stratify else None, random_state=0
+ )
+ X_train3, X_test3, y_train3, y_test3 = train_test_split(
+ X, y3, stratify=y3 if stratify else None, random_state=0
+ )
+
+ np.testing.assert_equal(X_train1, X_train2)
+ np.testing.assert_equal(y_train2, y_train3)
+ np.testing.assert_equal(X_test1, X_test3)
+ np.testing.assert_equal(y_test3, y_test2)
+
+
+@pytest.mark.parametrize(
+ "test_size, train_size",
+ [(2.0, None), (1.0, None), (0.1, 0.95), (None, 1j), (11, None), (10, None), (8, 3)],
+)
+def test_shufflesplit_errors(test_size, train_size):
+ with pytest.raises(ValueError):
+ next(ShuffleSplit(test_size=test_size, train_size=train_size).split(X))
+
+
+def test_shufflesplit_reproducible():
+ # Check that iterating twice on the ShuffleSplit gives the same
+ # sequence of train-test when the random_state is given
+ ss = ShuffleSplit(random_state=21)
+ assert_array_equal([a for a, b in ss.split(X)], [a for a, b in ss.split(X)])
+
+
+def test_stratifiedshufflesplit_list_input():
+ # Check that when y is a list / list of string labels, it works.
+ sss = StratifiedShuffleSplit(test_size=2, random_state=42)
+ X = np.ones(7)
+ y1 = ["1"] * 4 + ["0"] * 3
+ y2 = np.hstack((np.ones(4), np.zeros(3)))
+ y3 = y2.tolist()
+
+ np.testing.assert_equal(list(sss.split(X, y1)), list(sss.split(X, y2)))
+ np.testing.assert_equal(list(sss.split(X, y3)), list(sss.split(X, y2)))
+
+
+def test_train_test_split_allow_nans():
+ # Check that train_test_split allows input data with NaNs
+ X = np.arange(200, dtype=np.float64).reshape(10, -1)
+ X[2, :] = np.nan
+ y = np.repeat([0, 1], X.shape[0] / 2)
+ train_test_split(X, y, test_size=0.2, random_state=42)
+
+
+def test_check_cv():
+ X = np.ones(9)
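+ # For classifiers with binary or multiclass y, check_cv defaults to
+ # StratifiedKFold; otherwise (classifier=False, multilabel or multioutput y)
+ # it falls back to KFold.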
+ cv = check_cv(3, classifier=False)
+ # Use numpy.testing.assert_equal which recursively compares
+ # lists of lists
+ np.testing.assert_equal(list(KFold(3).split(X)), list(cv.split(X)))
+
+ y_binary = np.array([0, 1, 0, 1, 0, 0, 1, 1, 1])
+ cv = check_cv(3, y_binary, classifier=True)
+ np.testing.assert_equal(
+ list(StratifiedKFold(3).split(X, y_binary)), list(cv.split(X, y_binary))
+ )
+
+ y_multiclass = np.array([0, 1, 0, 1, 2, 1, 2, 0, 2])
+ cv = check_cv(3, y_multiclass, classifier=True)
+ np.testing.assert_equal(
+ list(StratifiedKFold(3).split(X, y_multiclass)), list(cv.split(X, y_multiclass))
+ )
+ # also works with 2d multiclass
+ y_multiclass_2d = y_multiclass.reshape(-1, 1)
+ cv = check_cv(3, y_multiclass_2d, classifier=True)
+ np.testing.assert_equal(
+ list(StratifiedKFold(3).split(X, y_multiclass_2d)),
+ list(cv.split(X, y_multiclass_2d)),
+ )
+
+ assert not np.all(
+ next(StratifiedKFold(3).split(X, y_multiclass_2d))[0]
+ == next(KFold(3).split(X, y_multiclass_2d))[0]
+ )
+
+ X = np.ones(5)
+ y_multilabel = np.array(
+ [[0, 0, 0, 0], [0, 1, 1, 0], [0, 0, 0, 1], [1, 1, 0, 1], [0, 0, 1, 0]]
+ )
+ cv = check_cv(3, y_multilabel, classifier=True)
+ np.testing.assert_equal(list(KFold(3).split(X)), list(cv.split(X)))
+
+ y_multioutput = np.array([[1, 2], [0, 3], [0, 0], [3, 1], [2, 0]])
+ cv = check_cv(3, y_multioutput, classifier=True)
+ np.testing.assert_equal(list(KFold(3).split(X)), list(cv.split(X)))
+
+ with pytest.raises(ValueError):
+ check_cv(cv="lolo")
+
+
+def test_cv_iterable_wrapper():
+ kf_iter = KFold().split(X, y)
+ kf_iter_wrapped = check_cv(kf_iter)
+ # Since the wrapped iterable is materialized into a list and stored,
+ # split can be called any number of times to produce
+ # consistent results.
+ np.testing.assert_equal(
+ list(kf_iter_wrapped.split(X, y)), list(kf_iter_wrapped.split(X, y))
+ )
+ # If the splits are randomized, successive calls to split yield different
+ # results
+ kf_randomized_iter = KFold(shuffle=True, random_state=0).split(X, y)
+ kf_randomized_iter_wrapped = check_cv(kf_randomized_iter)
+ # numpy's assert_equal properly compares nested lists
+ np.testing.assert_equal(
+ list(kf_randomized_iter_wrapped.split(X, y)),
+ list(kf_randomized_iter_wrapped.split(X, y)),
+ )
+
+ try:
+ splits_are_equal = True
+ np.testing.assert_equal(
+ list(kf_iter_wrapped.split(X, y)),
+ list(kf_randomized_iter_wrapped.split(X, y)),
+ )
+ except AssertionError:
+ splits_are_equal = False
+ assert not splits_are_equal, (
+ "If the splits are randomized, "
+ "successive calls to split should yield different results"
+ )
+
+
+@pytest.mark.parametrize("kfold", [GroupKFold, StratifiedGroupKFold])
+def test_group_kfold(kfold):
+ rng = np.random.RandomState(0)
+
+ # Parameters of the test
+ n_groups = 15
+ n_samples = 1000
+ n_splits = 5
+
+ X = y = np.ones(n_samples)
+
+ # Construct the test data
+ tolerance = 0.05 * n_samples # 5 percent error allowed
+ groups = rng.randint(0, n_groups, n_samples)
+
+ ideal_n_groups_per_fold = n_samples // n_splits
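+ # despite the name, this is the ideal number of samples per fold; fold sizes
+ # are compared by sample count below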
+
+ # Get the test fold indices from the test set indices of each fold
+ folds = np.zeros(n_samples)
+ lkf = kfold(n_splits=n_splits)
+ for i, (_, test) in enumerate(lkf.split(X, y, groups)):
+ folds[test] = i
+
+ # Check that folds have approximately the same size
+ assert len(folds) == len(groups)
+ for i in np.unique(folds):
+ assert tolerance >= abs(sum(folds == i) - ideal_n_groups_per_fold)
+
+ # Check that each group appears only in 1 fold
+ for group in np.unique(groups):
+ assert len(np.unique(folds[groups == group])) == 1
+
+ # Check that no group is on both sides of the split
+ groups = np.asarray(groups, dtype=object)
+ for train, test in lkf.split(X, y, groups):
+ assert len(np.intersect1d(groups[train], groups[test])) == 0
+
+ # Construct the test data
+ groups = np.array(
+ [
+ "Albert",
+ "Jean",
+ "Bertrand",
+ "Michel",
+ "Jean",
+ "Francis",
+ "Robert",
+ "Michel",
+ "Rachel",
+ "Lois",
+ "Michelle",
+ "Bernard",
+ "Marion",
+ "Laura",
+ "Jean",
+ "Rachel",
+ "Franck",
+ "John",
+ "Gael",
+ "Anna",
+ "Alix",
+ "Robert",
+ "Marion",
+ "David",
+ "Tony",
+ "Abel",
+ "Becky",
+ "Madmood",
+ "Cary",
+ "Mary",
+ "Alexandre",
+ "David",
+ "Francis",
+ "Barack",
+ "Abdoul",
+ "Rasha",
+ "Xi",
+ "Silvia",
+ ]
+ )
+
+ n_groups = len(np.unique(groups))
+ n_samples = len(groups)
+ n_splits = 5
+ tolerance = 0.05 * n_samples # 5 percent error allowed
+ ideal_n_groups_per_fold = n_samples // n_splits
+
+ X = y = np.ones(n_samples)
+
+ # Get the test fold indices from the test set indices of each fold
+ folds = np.zeros(n_samples)
+ for i, (_, test) in enumerate(lkf.split(X, y, groups)):
+ folds[test] = i
+
+ # Check that folds have approximately the same size
+ assert len(folds) == len(groups)
+ for i in np.unique(folds):
+ assert tolerance >= abs(sum(folds == i) - ideal_n_groups_per_fold)
+
+ # Check that each group appears only in 1 fold
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", FutureWarning)
+ for group in np.unique(groups):
+ assert len(np.unique(folds[groups == group])) == 1
+
+ # Check that no group is on both sides of the split
+ groups = np.asarray(groups, dtype=object)
+ for train, test in lkf.split(X, y, groups):
+ assert len(np.intersect1d(groups[train], groups[test])) == 0
+
+ # groups can also be a list
+ cv_iter = list(lkf.split(X, y, groups.tolist()))
+ for (train1, test1), (train2, test2) in zip(lkf.split(X, y, groups), cv_iter):
+ assert_array_equal(train1, train2)
+ assert_array_equal(test1, test2)
+
+ # Should fail if there are more folds than groups
+ groups = np.array([1, 1, 1, 2, 2])
+ X = y = np.ones(len(groups))
+ with pytest.raises(ValueError, match="Cannot have number of splits.*greater"):
+ next(GroupKFold(n_splits=3).split(X, y, groups))
+
+
+def test_time_series_cv():
+ X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14]]
+
+ # Should fail if there are more folds than samples
+ with pytest.raises(ValueError, match="Cannot have number of folds.*greater"):
+ next(TimeSeriesSplit(n_splits=7).split(X))
+
+ tscv = TimeSeriesSplit(2)
+
+ # Manually check that Time Series CV preserves the data
+ # ordering on toy datasets
+ splits = tscv.split(X[:-1])
+ train, test = next(splits)
+ assert_array_equal(train, [0, 1])
+ assert_array_equal(test, [2, 3])
+
+ train, test = next(splits)
+ assert_array_equal(train, [0, 1, 2, 3])
+ assert_array_equal(test, [4, 5])
+
+ splits = TimeSeriesSplit(2).split(X)
+
+ train, test = next(splits)
+ assert_array_equal(train, [0, 1, 2])
+ assert_array_equal(test, [3, 4])
+
+ train, test = next(splits)
+ assert_array_equal(train, [0, 1, 2, 3, 4])
+ assert_array_equal(test, [5, 6])
+
+ # Check get_n_splits returns the correct number of splits
+ splits = TimeSeriesSplit(2).split(X)
+ n_splits_actual = len(list(splits))
+ assert n_splits_actual == tscv.get_n_splits()
+ assert n_splits_actual == 2
+
+
+def _check_time_series_max_train_size(splits, check_splits, max_train_size):
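+ # The constrained splits must keep the same test sets, and each constrained
+ # training set must be the most recent `max_train_size` samples of the
+ # corresponding unconstrained training window.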
+ for (train, test), (check_train, check_test) in zip(splits, check_splits):
+ assert_array_equal(test, check_test)
+ assert len(check_train) <= max_train_size
+ suffix_start = max(len(train) - max_train_size, 0)
+ assert_array_equal(check_train, train[suffix_start:])
+
+
+def test_time_series_max_train_size():
+ X = np.zeros((6, 1))
+ splits = TimeSeriesSplit(n_splits=3).split(X)
+ check_splits = TimeSeriesSplit(n_splits=3, max_train_size=3).split(X)
+ _check_time_series_max_train_size(splits, check_splits, max_train_size=3)
+
+ # Test for the case where the size of a fold is greater than max_train_size
+ # (the split generators are recreated because zip() exhausted them above)
+ splits = TimeSeriesSplit(n_splits=3).split(X)
+ check_splits = TimeSeriesSplit(n_splits=3, max_train_size=2).split(X)
+ _check_time_series_max_train_size(splits, check_splits, max_train_size=2)
+
+ # Test for the case where the size of each fold is less than max_train_size
+ splits = TimeSeriesSplit(n_splits=3).split(X)
+ check_splits = TimeSeriesSplit(n_splits=3, max_train_size=5).split(X)
+ _check_time_series_max_train_size(splits, check_splits, max_train_size=5)
+
+
+def test_time_series_test_size():
+ X = np.zeros((10, 1))
+
+ # Test alone
+ splits = TimeSeriesSplit(n_splits=3, test_size=3).split(X)
+
+ train, test = next(splits)
+ assert_array_equal(train, [0])
+ assert_array_equal(test, [1, 2, 3])
+
+ train, test = next(splits)
+ assert_array_equal(train, [0, 1, 2, 3])
+ assert_array_equal(test, [4, 5, 6])
+
+ train, test = next(splits)
+ assert_array_equal(train, [0, 1, 2, 3, 4, 5, 6])
+ assert_array_equal(test, [7, 8, 9])
+
+ # Test with max_train_size
+ splits = TimeSeriesSplit(n_splits=2, test_size=2, max_train_size=4).split(X)
+
+ train, test = next(splits)
+ assert_array_equal(train, [2, 3, 4, 5])
+ assert_array_equal(test, [6, 7])
+
+ train, test = next(splits)
+ assert_array_equal(train, [4, 5, 6, 7])
+ assert_array_equal(test, [8, 9])
+
+ # Should fail with not enough data points for configuration
+ with pytest.raises(ValueError, match="Too many splits.*with test_size"):
+ splits = TimeSeriesSplit(n_splits=5, test_size=2).split(X)
+ next(splits)
+
+
+def test_time_series_gap():
+ X = np.zeros((10, 1))
+
+ # Test alone
+ splits = TimeSeriesSplit(n_splits=2, gap=2).split(X)
+
+ train, test = next(splits)
+ assert_array_equal(train, [0, 1])
+ assert_array_equal(test, [4, 5, 6])
+
+ train, test = next(splits)
+ assert_array_equal(train, [0, 1, 2, 3, 4])
+ assert_array_equal(test, [7, 8, 9])
+
+ # Test with max_train_size
+ splits = TimeSeriesSplit(n_splits=3, gap=2, max_train_size=2).split(X)
+
+ train, test = next(splits)
+ assert_array_equal(train, [0, 1])
+ assert_array_equal(test, [4, 5])
+
+ train, test = next(splits)
+ assert_array_equal(train, [2, 3])
+ assert_array_equal(test, [6, 7])
+
+ train, test = next(splits)
+ assert_array_equal(train, [4, 5])
+ assert_array_equal(test, [8, 9])
+
+ # Test with test_size
+ splits = TimeSeriesSplit(n_splits=2, gap=2, max_train_size=4, test_size=2).split(X)
+
+ train, test = next(splits)
+ assert_array_equal(train, [0, 1, 2, 3])
+ assert_array_equal(test, [6, 7])
+
+ train, test = next(splits)
+ assert_array_equal(train, [2, 3, 4, 5])
+ assert_array_equal(test, [8, 9])
+
+ # Test with additional test_size
+ splits = TimeSeriesSplit(n_splits=2, gap=2, test_size=3).split(X)
+
+ train, test = next(splits)
+ assert_array_equal(train, [0, 1])
+ assert_array_equal(test, [4, 5, 6])
+
+ train, test = next(splits)
+ assert_array_equal(train, [0, 1, 2, 3, 4])
+ assert_array_equal(test, [7, 8, 9])
+
+ # Verify proper error is thrown
+ with pytest.raises(ValueError, match="Too many splits.*and gap"):
+ splits = TimeSeriesSplit(n_splits=4, gap=2).split(X)
+ next(splits)
+
+
+def test_nested_cv():
+ # Test if nested cross validation works with different combinations of cv
+ rng = np.random.RandomState(0)
+
+ X, y = make_classification(n_samples=15, n_classes=2, random_state=0)
+ groups = rng.randint(0, 5, 15)
+
+ cvs = [
+ LeaveOneGroupOut(),
+ StratifiedKFold(n_splits=2),
+ GroupKFold(n_splits=3),
+ ]
+
+ for inner_cv, outer_cv in combinations_with_replacement(cvs, 2):
+ gs = GridSearchCV(
+ DummyClassifier(),
+ param_grid={"strategy": ["stratified", "most_frequent"]},
+ cv=inner_cv,
+ error_score="raise",
+ )
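+ # `groups` is used by the outer splitter and is also forwarded to the inner
+ # GridSearchCV via fit_params so that group-aware inner splitters receive it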
+ cross_val_score(
+ gs, X=X, y=y, groups=groups, cv=outer_cv, fit_params={"groups": groups}
+ )
+
+
+def test_build_repr():
+ class MockSplitter:
+ def __init__(self, a, b=0, c=None):
+ self.a = a
+ self.b = b
+ self.c = c
+
+ def __repr__(self):
+ return _build_repr(self)
+
+ assert repr(MockSplitter(5, 6)) == "MockSplitter(a=5, b=6, c=None)"
+
+
+@pytest.mark.parametrize(
+ "CVSplitter", (ShuffleSplit, GroupShuffleSplit, StratifiedShuffleSplit)
+)
+def test_shuffle_split_empty_trainset(CVSplitter):
+ cv = CVSplitter(test_size=0.99)
+ X, y = [[1]], [0] # 1 sample
+ with pytest.raises(
+ ValueError,
+ match=(
+ "With n_samples=1, test_size=0.99 and train_size=None, "
+ "the resulting train set will be empty"
+ ),
+ ):
+ next(cv.split(X, y, groups=[1]))
+
+
+def test_train_test_split_empty_trainset():
+ (X,) = [[1]] # 1 sample
+ with pytest.raises(
+ ValueError,
+ match=(
+ "With n_samples=1, test_size=0.99 and train_size=None, "
+ "the resulting train set will be empty"
+ ),
+ ):
+ train_test_split(X, test_size=0.99)
+
+ X = [[1], [1], [1]] # 3 samples, ask for more than 2 thirds
+ with pytest.raises(
+ ValueError,
+ match=(
+ "With n_samples=3, test_size=0.67 and train_size=None, "
+ "the resulting train set will be empty"
+ ),
+ ):
+ train_test_split(X, test_size=0.67)
+
+
+def test_leave_one_out_empty_trainset():
+ # LeaveOneGroupOut expects at least 2 groups, so no need to check it here
+ cv = LeaveOneOut()
+ X, y = [[1]], [0] # 1 sample
+ with pytest.raises(ValueError, match="Cannot perform LeaveOneOut with n_samples=1"):
+ next(cv.split(X, y))
+
+
+def test_leave_p_out_empty_trainset():
+ # No need to check LeavePGroupsOut
+ cv = LeavePOut(p=2)
+ X, y = [[1], [2]], [0, 3] # 2 samples
+ with pytest.raises(
+ ValueError, match="p=2 must be strictly less than the number of samples=2"
+ ):
+ next(cv.split(X, y, groups=[1, 2]))
+
+
+@pytest.mark.parametrize("Klass", (KFold, StratifiedKFold, StratifiedGroupKFold))
+def test_random_state_shuffle_false(Klass):
+ # passing a non-default random_state when shuffle=False makes no sense
+ with pytest.raises(ValueError, match="has no effect since shuffle is False"):
+ Klass(3, shuffle=False, random_state=0)
+
+
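+# A splitter yields constant splits iff it is deterministic: either it does not
+# shuffle at all, or it shuffles/resamples with a fixed integer random_state.
+# random_state=None or a RandomState instance makes repeated split() calls differ.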
+@pytest.mark.parametrize(
+ "cv, expected",
+ [
+ (KFold(), True),
+ (KFold(shuffle=True, random_state=123), True),
+ (StratifiedKFold(), True),
+ (StratifiedKFold(shuffle=True, random_state=123), True),
+ (StratifiedGroupKFold(shuffle=True, random_state=123), True),
+ (StratifiedGroupKFold(), True),
+ (RepeatedKFold(random_state=123), True),
+ (RepeatedStratifiedKFold(random_state=123), True),
+ (ShuffleSplit(random_state=123), True),
+ (GroupShuffleSplit(random_state=123), True),
+ (StratifiedShuffleSplit(random_state=123), True),
+ (GroupKFold(), True),
+ (TimeSeriesSplit(), True),
+ (LeaveOneOut(), True),
+ (LeaveOneGroupOut(), True),
+ (LeavePGroupsOut(n_groups=2), True),
+ (LeavePOut(p=2), True),
+ (KFold(shuffle=True, random_state=None), False),
+ (KFold(shuffle=True, random_state=None), False),
+ (StratifiedKFold(shuffle=True, random_state=np.random.RandomState(0)), False),
+ (StratifiedKFold(shuffle=True, random_state=np.random.RandomState(0)), False),
+ (RepeatedKFold(random_state=None), False),
+ (RepeatedKFold(random_state=np.random.RandomState(0)), False),
+ (RepeatedStratifiedKFold(random_state=None), False),
+ (RepeatedStratifiedKFold(random_state=np.random.RandomState(0)), False),
+ (ShuffleSplit(random_state=None), False),
+ (ShuffleSplit(random_state=np.random.RandomState(0)), False),
+ (GroupShuffleSplit(random_state=None), False),
+ (GroupShuffleSplit(random_state=np.random.RandomState(0)), False),
+ (StratifiedShuffleSplit(random_state=None), False),
+ (StratifiedShuffleSplit(random_state=np.random.RandomState(0)), False),
+ ],
+)
+def test_yields_constant_splits(cv, expected):
+ assert _yields_constant_splits(cv) == expected
diff --git a/modin/pandas/test/interoperability/sklearn/model_selection/test_successive_halving.py b/modin/pandas/test/interoperability/sklearn/model_selection/test_successive_halving.py
new file mode 100644
index 00000000000..7416100a749
--- /dev/null
+++ b/modin/pandas/test/interoperability/sklearn/model_selection/test_successive_halving.py
@@ -0,0 +1,788 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+# flake8: noqa
+
+from math import ceil
+
+import pytest
+from scipy.stats import norm, randint
+import numpy as np
+
+from sklearn.datasets import make_classification
+from sklearn.dummy import DummyClassifier
+from sklearn.experimental import enable_halving_search_cv # noqa
+from sklearn.model_selection import StratifiedKFold
+from sklearn.model_selection import StratifiedShuffleSplit
+from sklearn.model_selection import LeaveOneGroupOut
+from sklearn.model_selection import LeavePGroupsOut
+from sklearn.model_selection import GroupKFold
+from sklearn.model_selection import GroupShuffleSplit
+from sklearn.model_selection import HalvingGridSearchCV
+from sklearn.model_selection import HalvingRandomSearchCV
+from sklearn.model_selection import KFold, ShuffleSplit
+from sklearn.svm import LinearSVC
+from sklearn.model_selection._search_successive_halving import (
+ _SubsampleMetaSplitter,
+ _top_k,
+)
+
+
+class FastClassifier(DummyClassifier):
+ """Dummy classifier that accepts parameters a, b, ... z.
+
+ These parameters don't affect the predictions and are useful for fast
+ grid searching."""
+
+ # update the constraints such that we accept all parameters from a to z
+ _parameter_constraints: dict = {
+ **DummyClassifier._parameter_constraints,
+ **{
+ chr(key): "no_validation" # type: ignore
+ for key in range(ord("a"), ord("z") + 1)
+ },
+ }
+
+ def __init__(
+ self, strategy="stratified", random_state=None, constant=None, **kwargs
+ ):
+ super().__init__(
+ strategy=strategy, random_state=random_state, constant=constant
+ )
+
+ def get_params(self, deep=False):
+ params = super().get_params(deep=deep)
+ for char in range(ord("a"), ord("z") + 1):
+ params[chr(char)] = "whatever"
+ return params
+
+
+class SometimesFailClassifier(DummyClassifier):
+ def __init__(
+ self,
+ strategy="stratified",
+ random_state=None,
+ constant=None,
+ n_estimators=10,
+ fail_fit=False,
+ fail_predict=False,
+ a=0,
+ ):
+ self.fail_fit = fail_fit
+ self.fail_predict = fail_predict
+ self.n_estimators = n_estimators
+ self.a = a
+
+ super().__init__(
+ strategy=strategy, random_state=random_state, constant=constant
+ )
+
+ def fit(self, X, y):
+ if self.fail_fit:
+ raise Exception("fitting failed")
+ return super().fit(X, y)
+
+ def predict(self, X):
+ if self.fail_predict:
+ raise Exception("predict failed")
+ return super().predict(X)
+
+
+@pytest.mark.filterwarnings("ignore::sklearn.exceptions.FitFailedWarning")
+@pytest.mark.filterwarnings("ignore:Scoring failed:UserWarning")
+@pytest.mark.filterwarnings("ignore:One or more of the:UserWarning")
+@pytest.mark.parametrize("HalvingSearch", (HalvingGridSearchCV, HalvingRandomSearchCV))
+@pytest.mark.parametrize("fail_at", ("fit", "predict"))
+def test_nan_handling(HalvingSearch, fail_at):
+ """Check the selection of the best scores in presence of failure represented by
+ NaN values."""
+ n_samples = 1_000
+ X, y = make_classification(n_samples=n_samples, random_state=0)
+
+ search = HalvingSearch(
+ SometimesFailClassifier(),
+ {f"fail_{fail_at}": [False, True], "a": range(3)},
+ resource="n_estimators",
+ max_resources=6,
+ min_resources=1,
+ factor=2,
+ )
+
+ search.fit(X, y)
+
+ # estimators that failed during fit/predict should always rank lower
+ # than ones where the fit/predict succeeded
+ assert not search.best_params_[f"fail_{fail_at}"]
+ scores = search.cv_results_["mean_test_score"]
+ ranks = search.cv_results_["rank_test_score"]
+
+ # some scores should be NaN
+ assert np.isnan(scores).any()
+
+ unique_nan_ranks = np.unique(ranks[np.isnan(scores)])
+ # all NaN scores should have the same rank
+ assert unique_nan_ranks.shape[0] == 1
+ # NaNs should have the worst (numerically highest) rank
+ assert (unique_nan_ranks[0] >= ranks).all()
+
+
+@pytest.mark.parametrize("Est", (HalvingGridSearchCV, HalvingRandomSearchCV))
+@pytest.mark.parametrize(
+ "aggressive_elimination,"
+ "max_resources,"
+ "expected_n_iterations,"
+ "expected_n_required_iterations,"
+ "expected_n_possible_iterations,"
+ "expected_n_remaining_candidates,"
+ "expected_n_candidates,"
+ "expected_n_resources,",
+ [
+ # notice how it loops at the beginning
+ # also, the number of candidates evaluated at the last iteration is
+ # <= factor
+ (True, "limited", 4, 4, 3, 1, [60, 20, 7, 3], [20, 20, 60, 180]),
+ # no aggressive elimination: we end up with fewer iterations, and
+ # the number of candidates at the last iter is > factor, which isn't
+ # ideal
+ (False, "limited", 3, 4, 3, 3, [60, 20, 7], [20, 60, 180]),
+ # When the amount of resource isn't limited, aggressive_elimination
+ # has no effect. Here the default min_resources='exhaust' will take
+ # over.
+ (True, "unlimited", 4, 4, 4, 1, [60, 20, 7, 3], [37, 111, 333, 999]),
+ (False, "unlimited", 4, 4, 4, 1, [60, 20, 7, 3], [37, 111, 333, 999]),
+ ],
+)
+def test_aggressive_elimination(
+ Est,
+ aggressive_elimination,
+ max_resources,
+ expected_n_iterations,
+ expected_n_required_iterations,
+ expected_n_possible_iterations,
+ expected_n_remaining_candidates,
+ expected_n_candidates,
+ expected_n_resources,
+):
+ # Test the aggressive_elimination parameter.
+
+ n_samples = 1000
+ X, y = make_classification(n_samples=n_samples, random_state=0)
+ param_grid = {"a": ("l1", "l2"), "b": list(range(30))}
+ base_estimator = FastClassifier()
+
+ if max_resources == "limited":
+ max_resources = 180
+ else:
+ max_resources = n_samples
+
+ sh = Est(
+ base_estimator,
+ param_grid,
+ aggressive_elimination=aggressive_elimination,
+ max_resources=max_resources,
+ factor=3,
+ )
+ sh.set_params(verbose=True) # just for test coverage
+
+ if Est is HalvingRandomSearchCV:
+ # same number of candidates as with the grid
+ sh.set_params(n_candidates=2 * 30, min_resources="exhaust")
+
+ sh.fit(X, y)
+
+ assert sh.n_iterations_ == expected_n_iterations
+ assert sh.n_required_iterations_ == expected_n_required_iterations
+ assert sh.n_possible_iterations_ == expected_n_possible_iterations
+ assert sh.n_resources_ == expected_n_resources
+ assert sh.n_candidates_ == expected_n_candidates
+ assert sh.n_remaining_candidates_ == expected_n_remaining_candidates
+ assert ceil(sh.n_candidates_[-1] / sh.factor) == sh.n_remaining_candidates_
+
+
+@pytest.mark.parametrize("Est", (HalvingGridSearchCV, HalvingRandomSearchCV))
+@pytest.mark.parametrize(
+ "min_resources,"
+ "max_resources,"
+ "expected_n_iterations,"
+ "expected_n_possible_iterations,"
+ "expected_n_resources,",
+ [
+ # with enough resources
+ ("smallest", "auto", 2, 4, [20, 60]),
+ # with enough resources but min_resources set manually
+ (50, "auto", 2, 3, [50, 150]),
+ # without enough resources, only one iteration can be done
+ ("smallest", 30, 1, 1, [20]),
+ # with exhaust: use as many resources as possible at the last iter
+ ("exhaust", "auto", 2, 2, [333, 999]),
+ ("exhaust", 1000, 2, 2, [333, 999]),
+ ("exhaust", 999, 2, 2, [333, 999]),
+ ("exhaust", 600, 2, 2, [200, 600]),
+ ("exhaust", 599, 2, 2, [199, 597]),
+ ("exhaust", 300, 2, 2, [100, 300]),
+ ("exhaust", 60, 2, 2, [20, 60]),
+ ("exhaust", 50, 1, 1, [20]),
+ ("exhaust", 20, 1, 1, [20]),
+ ],
+)
+def test_min_max_resources(
+ Est,
+ min_resources,
+ max_resources,
+ expected_n_iterations,
+ expected_n_possible_iterations,
+ expected_n_resources,
+):
+ # Test the min_resources and max_resources parameters, and how they affect
+ # the number of resources used at each iteration
+ n_samples = 1000
+ X, y = make_classification(n_samples=n_samples, random_state=0)
+ param_grid = {"a": [1, 2], "b": [1, 2, 3]}
+ base_estimator = FastClassifier()
+
+ sh = Est(
+ base_estimator,
+ param_grid,
+ factor=3,
+ min_resources=min_resources,
+ max_resources=max_resources,
+ )
+ if Est is HalvingRandomSearchCV:
+ sh.set_params(n_candidates=6) # same number as with the grid
+
+ sh.fit(X, y)
+
+ expected_n_required_iterations = 2 # given 6 combinations and factor = 3
+ assert sh.n_iterations_ == expected_n_iterations
+ assert sh.n_required_iterations_ == expected_n_required_iterations
+ assert sh.n_possible_iterations_ == expected_n_possible_iterations
+ assert sh.n_resources_ == expected_n_resources
+ if min_resources == "exhaust":
+ assert sh.n_possible_iterations_ == sh.n_iterations_ == len(sh.n_resources_)
+
+
+@pytest.mark.parametrize("Est", (HalvingRandomSearchCV, HalvingGridSearchCV))
+@pytest.mark.parametrize(
+ "max_resources, n_iterations, n_possible_iterations",
+ [
+ ("auto", 5, 9), # all resources are used
+ (1024, 5, 9),
+ (700, 5, 8),
+ (512, 5, 8),
+ (511, 5, 7),
+ (32, 4, 4),
+ (31, 3, 3),
+ (16, 3, 3),
+ (4, 1, 1), # max_resources == min_resources, only one iteration is
+ # possible
+ ],
+)
+def test_n_iterations(Est, max_resources, n_iterations, n_possible_iterations):
+ # test the number of actual iterations that were run depending on
+ # max_resources
+
+ n_samples = 1024
+ X, y = make_classification(n_samples=n_samples, random_state=1)
+ param_grid = {"a": [1, 2], "b": list(range(10))}
+ base_estimator = FastClassifier()
+ factor = 2
+
+ sh = Est(
+ base_estimator,
+ param_grid,
+ cv=2,
+ factor=factor,
+ max_resources=max_resources,
+ min_resources=4,
+ )
+ if Est is HalvingRandomSearchCV:
+ sh.set_params(n_candidates=20) # same as for HalvingGridSearchCV
+ sh.fit(X, y)
+ assert sh.n_required_iterations_ == 5
+ assert sh.n_iterations_ == n_iterations
+ assert sh.n_possible_iterations_ == n_possible_iterations
+
+
+@pytest.mark.parametrize("Est", (HalvingRandomSearchCV, HalvingGridSearchCV))
+def test_resource_parameter(Est):
+ # Test the resource parameter
+
+ n_samples = 1000
+ X, y = make_classification(n_samples=n_samples, random_state=0)
+ param_grid = {"a": [1, 2], "b": list(range(10))}
+ base_estimator = FastClassifier()
+ sh = Est(base_estimator, param_grid, cv=2, resource="c", max_resources=10, factor=3)
+ sh.fit(X, y)
+ assert set(sh.n_resources_) == set([1, 3, 9])
+ for r_i, params, param_c in zip(
+ sh.cv_results_["n_resources"],
+ sh.cv_results_["params"],
+ sh.cv_results_["param_c"],
+ ):
+ assert r_i == params["c"] == param_c
+
+ with pytest.raises(
+ ValueError, match="Cannot use resource=1234 which is not supported "
+ ):
+ sh = HalvingGridSearchCV(
+ base_estimator, param_grid, cv=2, resource="1234", max_resources=10
+ )
+ sh.fit(X, y)
+
+ with pytest.raises(
+ ValueError,
+ match=(
+ "Cannot use parameter c as the resource since it is part "
+ "of the searched parameters."
+ ),
+ ):
+ param_grid = {"a": [1, 2], "b": [1, 2], "c": [1, 3]}
+ sh = HalvingGridSearchCV(
+ base_estimator, param_grid, cv=2, resource="c", max_resources=10
+ )
+ sh.fit(X, y)
+
+
+@pytest.mark.parametrize(
+ "max_resources, n_candidates, expected_n_candidates",
+ [
+ (512, "exhaust", 128), # generate exactly as much as needed
+ (32, "exhaust", 8),
+ (32, 8, 8),
+ (32, 7, 7), # ask for less than what we could
+ (32, 9, 9), # ask for more than 'reasonable'
+ ],
+)
+def test_random_search(max_resources, n_candidates, expected_n_candidates):
+ # Test random search and make sure the number of generated candidates is
+ # as expected
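+ # (with factor=2 and min_resources=4, n_candidates='exhaust' is expected to
+ # generate max_resources // min_resources candidates, e.g. 512 // 4 = 128)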
+
+ n_samples = 1024
+ X, y = make_classification(n_samples=n_samples, random_state=0)
+ param_grid = {"a": norm, "b": norm}
+ base_estimator = FastClassifier()
+ sh = HalvingRandomSearchCV(
+ base_estimator,
+ param_grid,
+ n_candidates=n_candidates,
+ cv=2,
+ max_resources=max_resources,
+ factor=2,
+ min_resources=4,
+ )
+ sh.fit(X, y)
+ assert sh.n_candidates_[0] == expected_n_candidates
+ if n_candidates == "exhaust":
+ # Make sure 'exhaust' makes the last iteration use as many resources
+ # as possible
+ assert sh.n_resources_[-1] == max_resources
+
+
+@pytest.mark.parametrize(
+ "param_distributions, expected_n_candidates",
+ [
+ ({"a": [1, 2]}, 2), # all lists, sample less than n_candidates
+ ({"a": randint(1, 3)}, 10), # not all list, respect n_candidates
+ ],
+)
+def test_random_search_discrete_distributions(
+ param_distributions, expected_n_candidates
+):
+ # Make sure random search samples the appropriate number of candidates when
+ # we ask for more than what's possible. How many parameters are sampled
+ # depends whether the distributions are 'all lists' or not (see
+ # ParameterSampler for details). This is somewhat redundant with the checks
+ # in ParameterSampler but interaction bugs were discovered during
+ # development of SH
+
+ n_samples = 1024
+ X, y = make_classification(n_samples=n_samples, random_state=0)
+ base_estimator = FastClassifier()
+ sh = HalvingRandomSearchCV(base_estimator, param_distributions, n_candidates=10)
+ sh.fit(X, y)
+ assert sh.n_candidates_[0] == expected_n_candidates
+
+
+@pytest.mark.parametrize("Est", (HalvingGridSearchCV, HalvingRandomSearchCV))
+@pytest.mark.parametrize(
+ "params, expected_error_message",
+ [
+ (
+ {"resource": "not_a_parameter"},
+ "Cannot use resource=not_a_parameter which is not supported",
+ ),
+ (
+ {"resource": "a", "max_resources": 100},
+ "Cannot use parameter a as the resource since it is part of",
+ ),
+ (
+ {"max_resources": "auto", "resource": "b"},
+ "resource can only be 'n_samples' when max_resources='auto'",
+ ),
+ (
+ {"min_resources": 15, "max_resources": 14},
+ "min_resources_=15 is greater than max_resources_=14",
+ ),
+ ({"cv": KFold(shuffle=True)}, "must yield consistent folds"),
+ ({"cv": ShuffleSplit()}, "must yield consistent folds"),
+ ],
+)
+def test_input_errors(Est, params, expected_error_message):
+ base_estimator = FastClassifier()
+ param_grid = {"a": [1]}
+ X, y = make_classification(100)
+
+ sh = Est(base_estimator, param_grid, **params)
+
+ with pytest.raises(ValueError, match=expected_error_message):
+ sh.fit(X, y)
+
+
+@pytest.mark.parametrize(
+ "params, expected_error_message",
+ [
+ (
+ {"n_candidates": "exhaust", "min_resources": "exhaust"},
+ "cannot be both set to 'exhaust'",
+ ),
+ ],
+)
+def test_input_errors_randomized(params, expected_error_message):
+ # tests specific to HalvingRandomSearchCV
+
+ base_estimator = FastClassifier()
+ param_grid = {"a": [1]}
+ X, y = make_classification(100)
+
+ sh = HalvingRandomSearchCV(base_estimator, param_grid, **params)
+
+ with pytest.raises(ValueError, match=expected_error_message):
+ sh.fit(X, y)
+
+
+@pytest.mark.parametrize(
+ "fraction, subsample_test, expected_train_size, expected_test_size",
+ [
+ (0.5, True, 40, 10),
+ (0.5, False, 40, 20),
+ (0.2, True, 16, 4),
+ (0.2, False, 16, 20),
+ ],
+)
+def test_subsample_splitter_shapes(
+ fraction, subsample_test, expected_train_size, expected_test_size
+):
+ # Make sure splits returned by SubsampleMetaSplitter are of appropriate
+ # size
+
+ n_samples = 100
+ X, y = make_classification(n_samples)
+ cv = _SubsampleMetaSplitter(
+ base_cv=KFold(5),
+ fraction=fraction,
+ subsample_test=subsample_test,
+ random_state=None,
+ )
+
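+ # With KFold(5) on 100 samples each base split has 80 train / 20 test samples;
+ # the meta-splitter subsamples the train set (and optionally the test set) by
+ # `fraction`.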
+ for train, test in cv.split(X, y):
+ assert train.shape[0] == expected_train_size
+ assert test.shape[0] == expected_test_size
+ if subsample_test:
+ assert train.shape[0] + test.shape[0] == int(n_samples * fraction)
+ else:
+ assert test.shape[0] == n_samples // cv.base_cv.get_n_splits()
+
+
+@pytest.mark.parametrize("subsample_test", (True, False))
+def test_subsample_splitter_determinism(subsample_test):
+ # Make sure _SubsampleMetaSplitter is consistent across calls to split():
+ # - we're OK having training sets differ (they're always sampled with a
+ # different fraction anyway)
+ # - when we don't subsample the test set, we want it to be always the same.
+ # This check is the most important. This is ensured by the determinism
+ # of the base_cv.
+
+ # Note: we could force both train and test splits to be always the same if
+ # we drew an int seed in _SubsampleMetaSplitter.__init__
+
+ n_samples = 100
+ X, y = make_classification(n_samples)
+ cv = _SubsampleMetaSplitter(
+ base_cv=KFold(5), fraction=0.5, subsample_test=subsample_test, random_state=None
+ )
+
+ folds_a = list(cv.split(X, y, groups=None))
+ folds_b = list(cv.split(X, y, groups=None))
+
+ for (train_a, test_a), (train_b, test_b) in zip(folds_a, folds_b):
+ assert not np.all(train_a == train_b)
+
+ if subsample_test:
+ assert not np.all(test_a == test_b)
+ else:
+ assert np.all(test_a == test_b)
+ assert np.all(X[test_a] == X[test_b])
+
+
+@pytest.mark.parametrize(
+ "k, itr, expected",
+ [
+ (1, 0, ["c"]),
+ (2, 0, ["a", "c"]),
+ (4, 0, ["d", "b", "a", "c"]),
+ (10, 0, ["d", "b", "a", "c"]),
+ (1, 1, ["e"]),
+ (2, 1, ["f", "e"]),
+ (10, 1, ["f", "e"]),
+ (1, 2, ["i"]),
+ (10, 2, ["g", "h", "i"]),
+ ],
+)
+def test_top_k(k, itr, expected):
+ results = { # this isn't a 'real world' result dict
+ "iter": [0, 0, 0, 0, 1, 1, 2, 2, 2],
+ "mean_test_score": [4, 3, 5, 1, 11, 10, 5, 6, 9],
+ "params": ["a", "b", "c", "d", "e", "f", "g", "h", "i"],
+ }
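+ # _top_k returns the params of the k best candidates of iteration `itr`,
+ # ordered from lowest to highest mean_test_score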
+ got = _top_k(results, k=k, itr=itr)
+ assert np.all(got == expected)
+
+
+@pytest.mark.skip(reason="Failing test")
+@pytest.mark.parametrize("Est", (HalvingRandomSearchCV, HalvingGridSearchCV))
+def test_cv_results(Est):
+ # test that the cv_results_ matches correctly the logic of the
+ # tournament: in particular that the candidates continued in each
+ # successive iteration are those that were best in the previous iteration
+ pd = pytest.importorskip("modin.pandas")
+
+ rng = np.random.RandomState(0)
+
+ n_samples = 1000
+ X, y = make_classification(n_samples=n_samples, random_state=0)
+ param_grid = {"a": ("l1", "l2"), "b": list(range(30))}
+ base_estimator = FastClassifier()
+
+ # generate random scores: we want to avoid ties, which would otherwise
+ # mess with the ordering and make testing harder
+ def scorer(est, X, y):
+ return rng.rand()
+
+ sh = Est(base_estimator, param_grid, factor=2, scoring=scorer)
+ if Est is HalvingRandomSearchCV:
+ # same number of candidates as with the grid
+ sh.set_params(n_candidates=2 * 30, min_resources="exhaust")
+
+ sh.fit(X, y)
+
+ # non-regression check for
+ # https://github.com/scikit-learn/scikit-learn/issues/19203
+ assert isinstance(sh.cv_results_["iter"], np.ndarray)
+ assert isinstance(sh.cv_results_["n_resources"], np.ndarray)
+
+ cv_results_df = pd.DataFrame(sh.cv_results_)
+
+ # just make sure we don't have ties
+ assert len(cv_results_df["mean_test_score"].unique()) == len(cv_results_df)
+
+ cv_results_df["params_str"] = cv_results_df["params"].apply(str)
+ table = cv_results_df.pivot(
+ index="params_str", columns="iter", values="mean_test_score"
+ )
+
+ # table looks like something like this:
+ # iter 0 1 2 3 4 5
+ # params_str
+ # {'a': 'l2', 'b': 23} 0.75 NaN NaN NaN NaN NaN
+ # {'a': 'l1', 'b': 30} 0.90 0.875 NaN NaN NaN NaN
+ # {'a': 'l1', 'b': 0} 0.75 NaN NaN NaN NaN NaN
+ # {'a': 'l2', 'b': 3} 0.85 0.925 0.9125 0.90625 NaN NaN
+ # {'a': 'l1', 'b': 5} 0.80 NaN NaN NaN NaN NaN
+ # ...
+
+ # where a NaN indicates that the candidate wasn't evaluated at a given
+ # iteration, because it wasn't part of the top-K at some previous
+ # iteration. We here make sure that candidates that aren't in the top-k at
+ # any given iteration are indeed not evaluated at the subsequent
+ # iterations.
+ nan_mask = pd.isna(table)
+ n_iter = sh.n_iterations_
+ for it in range(n_iter - 1):
+ already_discarded_mask = nan_mask[it]
+
+ # make sure that if a candidate is already discarded, we don't evaluate
+ # it later
+ assert (
+ already_discarded_mask & nan_mask[it + 1] == already_discarded_mask
+ ).all()
+
+ # make sure that the number of discarded candidate is correct
+ discarded_now_mask = ~already_discarded_mask & nan_mask[it + 1]
+ kept_mask = ~already_discarded_mask & ~discarded_now_mask
+ assert kept_mask.sum() == sh.n_candidates_[it + 1]
+
+ # make sure that all discarded candidates have a lower score than the
+ # kept candidates
+ discarded_max_score = table[it].where(discarded_now_mask).max()
+ kept_min_score = table[it].where(kept_mask).min()
+ assert discarded_max_score < kept_min_score
+
+ # We now make sure that the best candidate is chosen only from the last
+ # iteration.
+ # We also make sure this is true even if there were higher scores in
+ # earlier rounds (this isn't generally the case, but worth ensuring it's
+ # possible).
+
+ last_iter = cv_results_df["iter"].max()
+ idx_best_last_iter = cv_results_df[cv_results_df["iter"] == last_iter][
+ "mean_test_score"
+ ].idxmax()
+ idx_best_all_iters = cv_results_df["mean_test_score"].idxmax()
+
+ assert sh.best_params_ == cv_results_df.iloc[idx_best_last_iter]["params"]
+ assert (
+ cv_results_df.iloc[idx_best_last_iter]["mean_test_score"]
+ < cv_results_df.iloc[idx_best_all_iters]["mean_test_score"]
+ )
+ assert (
+ cv_results_df.iloc[idx_best_last_iter]["params"]
+ != cv_results_df.iloc[idx_best_all_iters]["params"]
+ )
+
+
+@pytest.mark.parametrize("Est", (HalvingGridSearchCV, HalvingRandomSearchCV))
+def test_base_estimator_inputs(Est):
+ # make sure that the base estimators are passed the correct parameters and
+ # number of samples at each iteration.
+ pd = pytest.importorskip("modin.pandas")
+
+ passed_n_samples_fit = []
+ passed_n_samples_predict = []
+ passed_params = []
+
+ class FastClassifierBookKeeping(FastClassifier):
+ def fit(self, X, y):
+ passed_n_samples_fit.append(X.shape[0])
+ return super().fit(X, y)
+
+ def predict(self, X):
+ passed_n_samples_predict.append(X.shape[0])
+ return super().predict(X)
+
+ def set_params(self, **params):
+ passed_params.append(params)
+ return super().set_params(**params)
+
+ n_samples = 1024
+ n_splits = 2
+ X, y = make_classification(n_samples=n_samples, random_state=0)
+ param_grid = {"a": ("l1", "l2"), "b": list(range(30))}
+ base_estimator = FastClassifierBookKeeping()
+
+ sh = Est(
+ base_estimator,
+ param_grid,
+ factor=2,
+ cv=n_splits,
+ return_train_score=False,
+ refit=False,
+ )
+ if Est is HalvingRandomSearchCV:
+ # same number of candidates as with the grid
+ sh.set_params(n_candidates=2 * 30, min_resources="exhaust")
+
+ sh.fit(X, y)
+
+ assert len(passed_n_samples_fit) == len(passed_n_samples_predict)
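+    # Each fit/predict pair corresponds to one (candidate, fold) evaluation;
+    # the train and test fold sizes add up to the number of resources used for
+    # that candidate at that iteration.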
+ passed_n_samples = [
+ x + y for (x, y) in zip(passed_n_samples_fit, passed_n_samples_predict)
+ ]
+
+    # Lists are of length n_splits * n_iter * n_candidates_at_iter.
+    # Each chunk of size n_splits corresponds to the n_splits folds for the
+    # same candidate at the same iteration, so they contain equal values. We
+    # subsample such that the lists are of length n_iter * n_candidates_at_iter.
+ passed_n_samples = passed_n_samples[::n_splits]
+ passed_params = passed_params[::n_splits]
+
+ cv_results_df = pd.DataFrame(sh.cv_results_)
+
+ assert len(passed_params) == len(passed_n_samples) == len(cv_results_df)
+
+ uniques, counts = np.unique(passed_n_samples, return_counts=True)
+ assert (sh.n_resources_ == uniques).all()
+ assert (sh.n_candidates_ == counts).all()
+
+ assert (cv_results_df["params"] == passed_params).all()
+ assert (cv_results_df["n_resources"] == passed_n_samples).all()
+
+
+@pytest.mark.parametrize("Est", (HalvingGridSearchCV, HalvingRandomSearchCV))
+def test_groups_support(Est):
+ # Check if ValueError (when groups is None) propagates to
+ # HalvingGridSearchCV and HalvingRandomSearchCV
+ # And also check if groups is correctly passed to the cv object
+ rng = np.random.RandomState(0)
+
+ X, y = make_classification(n_samples=50, n_classes=2, random_state=0)
+ groups = rng.randint(0, 3, 50)
+
+ clf = LinearSVC(random_state=0)
+ grid = {"C": [1]}
+
+ group_cvs = [
+ LeaveOneGroupOut(),
+ LeavePGroupsOut(2),
+ GroupKFold(n_splits=3),
+ GroupShuffleSplit(random_state=0),
+ ]
+ error_msg = "The 'groups' parameter should not be None."
+ for cv in group_cvs:
+ gs = Est(clf, grid, cv=cv, random_state=0)
+ with pytest.raises(ValueError, match=error_msg):
+ gs.fit(X, y)
+ gs.fit(X, y, groups=groups)
+
+ non_group_cvs = [StratifiedKFold(), StratifiedShuffleSplit(random_state=0)]
+ for cv in non_group_cvs:
+ gs = Est(clf, grid, cv=cv)
+ # Should not raise an error
+ gs.fit(X, y)
+
+
+@pytest.mark.parametrize("SearchCV", [HalvingRandomSearchCV, HalvingGridSearchCV])
+def test_min_resources_null(SearchCV):
+ """Check that we raise an error if the minimum resources is set to 0."""
+ base_estimator = FastClassifier()
+ param_grid = {"a": [1]}
+ X = np.empty(0).reshape(0, 3)
+
+ search = SearchCV(base_estimator, param_grid, min_resources="smallest")
+
+ err_msg = "min_resources_=0: you might have passed an empty dataset X."
+ with pytest.raises(ValueError, match=err_msg):
+ search.fit(X, [])
+
+
+@pytest.mark.parametrize("SearchCV", [HalvingGridSearchCV, HalvingRandomSearchCV])
+def test_select_best_index(SearchCV):
+ """Check the selection strategy of the halving search."""
+ results = { # this isn't a 'real world' result dict
+ "iter": np.array([0, 0, 0, 0, 1, 1, 2, 2, 2]),
+ "mean_test_score": np.array([4, 3, 5, 1, 11, 10, 5, 6, 9]),
+ "params": np.array(["a", "b", "c", "d", "e", "f", "g", "h", "i"]),
+ }
+
+    # the best candidate must come from the last iteration (iter 2), where 'i'
+    # has the highest mean_test_score (9); its index in the results is 8
+ best_index = SearchCV._select_best_index(None, None, results)
+ assert best_index == 8
diff --git a/modin/pandas/test/interoperability/sklearn/model_selection/test_validation_ms.py b/modin/pandas/test/interoperability/sklearn/model_selection/test_validation_ms.py
new file mode 100644
index 00000000000..a5fc93bdb1a
--- /dev/null
+++ b/modin/pandas/test/interoperability/sklearn/model_selection/test_validation_ms.py
@@ -0,0 +1,2400 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+# flake8: noqa
+
+"""Test the validation module"""
+import os
+import re
+import sys
+import tempfile
+import warnings
+from functools import partial
+from time import sleep
+
+import pytest
+import numpy as np
+from scipy.sparse import coo_matrix, csr_matrix
+from sklearn.exceptions import FitFailedWarning
+
+from sklearn.model_selection.tests.test_search import FailingClassifier
+
+from sklearn.utils._testing import assert_almost_equal
+from sklearn.utils._testing import assert_array_almost_equal
+from sklearn.utils._testing import assert_array_equal
+from sklearn.utils._testing import assert_allclose
+from sklearn.utils._mocking import CheckingClassifier, MockDataFrame
+
+from sklearn.utils.validation import _num_samples
+
+from sklearn.model_selection import cross_val_score, ShuffleSplit
+from sklearn.model_selection import cross_val_predict
+from sklearn.model_selection import cross_validate
+from sklearn.model_selection import permutation_test_score
+from sklearn.model_selection import KFold
+from sklearn.model_selection import StratifiedKFold
+from sklearn.model_selection import LeaveOneOut
+from sklearn.model_selection import LeaveOneGroupOut
+from sklearn.model_selection import LeavePGroupsOut
+from sklearn.model_selection import GroupKFold
+from sklearn.model_selection import GroupShuffleSplit
+from sklearn.model_selection import learning_curve
+from sklearn.model_selection import validation_curve
+from sklearn.model_selection._validation import _check_is_permutation
+from sklearn.model_selection._validation import _fit_and_score
+from sklearn.model_selection._validation import _score
+
+from sklearn.datasets import make_regression
+from sklearn.datasets import load_diabetes
+from sklearn.datasets import load_iris
+from sklearn.datasets import load_digits
+from sklearn.metrics import explained_variance_score
+from sklearn.metrics import make_scorer
+from sklearn.metrics import accuracy_score
+from sklearn.metrics import confusion_matrix
+from sklearn.metrics import precision_recall_fscore_support
+from sklearn.metrics import precision_score
+from sklearn.metrics import r2_score
+from sklearn.metrics import mean_squared_error
+from sklearn.metrics import check_scoring
+
+from sklearn.linear_model import Ridge, LogisticRegression, SGDClassifier
+from sklearn.linear_model import PassiveAggressiveClassifier, RidgeClassifier
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.neighbors import KNeighborsClassifier
+from sklearn.svm import SVC, LinearSVC
+from sklearn.cluster import KMeans
+from sklearn.neural_network import MLPRegressor
+
+from sklearn.impute import SimpleImputer
+
+from sklearn.preprocessing import LabelEncoder
+from sklearn.pipeline import Pipeline
+
+from io import StringIO
+from sklearn.base import BaseEstimator
+from sklearn.base import clone
+from sklearn.multiclass import OneVsRestClassifier
+from sklearn.utils import shuffle
+from sklearn.datasets import make_classification
+from sklearn.datasets import make_multilabel_classification
+
+from sklearn.model_selection.tests.common import OneTimeSplitter
+from sklearn.model_selection import GridSearchCV
+
+
+try:
+ WindowsError # type: ignore
+except NameError:
+ WindowsError = None
+
+
+class MockImprovingEstimator(BaseEstimator):
+ """Dummy classifier to test the learning curve"""
+
+ def __init__(self, n_max_train_sizes):
+ self.n_max_train_sizes = n_max_train_sizes
+ self.train_sizes = 0
+ self.X_subset = None
+
+ def fit(self, X_subset, y_subset=None):
+ self.X_subset = X_subset
+ self.train_sizes = X_subset.shape[0]
+ return self
+
+ def predict(self, X):
+ raise NotImplementedError
+
+ def score(self, X=None, Y=None):
+        # training score becomes worse (2 -> 1), test score improves (0 -> 1)
+ if self._is_training_data(X):
+ return 2.0 - float(self.train_sizes) / self.n_max_train_sizes
+ else:
+ return float(self.train_sizes) / self.n_max_train_sizes
+
+ def _is_training_data(self, X):
+ return X is self.X_subset
+
+
+class MockIncrementalImprovingEstimator(MockImprovingEstimator):
+ """Dummy classifier that provides partial_fit"""
+
+ def __init__(self, n_max_train_sizes, expected_fit_params=None):
+ super().__init__(n_max_train_sizes)
+ self.x = None
+ self.expected_fit_params = expected_fit_params
+
+ def _is_training_data(self, X):
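+        # Heuristic: X is treated as training data if it contains the first
+        # sample previously seen by partial_fit (stored in self.x).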
+ return self.x in X
+
+ def partial_fit(self, X, y=None, **params):
+ self.train_sizes += X.shape[0]
+ self.x = X[0]
+ if self.expected_fit_params:
+ missing = set(self.expected_fit_params) - set(params)
+ if missing:
+ raise AssertionError(
+ f"Expected fit parameter(s) {list(missing)} not seen."
+ )
+ for key, value in params.items():
+ if key in self.expected_fit_params and _num_samples(
+ value
+ ) != _num_samples(X):
+ raise AssertionError(
+ f"Fit parameter {key} has length {_num_samples(value)}"
+ f"; expected {_num_samples(X)}."
+ )
+
+
+class MockEstimatorWithParameter(BaseEstimator):
+ """Dummy classifier to test the validation curve"""
+
+ def __init__(self, param=0.5):
+ self.X_subset = None
+ self.param = param
+
+ def fit(self, X_subset, y_subset):
+ self.X_subset = X_subset
+ self.train_sizes = X_subset.shape[0]
+ return self
+
+ def predict(self, X):
+ raise NotImplementedError
+
+ def score(self, X=None, y=None):
+ return self.param if self._is_training_data(X) else 1 - self.param
+
+ def _is_training_data(self, X):
+ return X is self.X_subset
+
+
+class MockEstimatorWithSingleFitCallAllowed(MockEstimatorWithParameter):
+ """Dummy classifier that disallows repeated calls of fit method"""
+
+ def fit(self, X_subset, y_subset):
+ assert not hasattr(self, "fit_called_"), "fit is called the second time"
+ self.fit_called_ = True
+ return super().fit(X_subset, y_subset)
+
+ def predict(self, X):
+ raise NotImplementedError
+
+
+class MockClassifier:
+ """Dummy classifier to test the cross-validation"""
+
+ def __init__(self, a=0, allow_nd=False):
+ self.a = a
+ self.allow_nd = allow_nd
+
+ def fit(
+ self,
+ X,
+ Y=None,
+ sample_weight=None,
+ class_prior=None,
+ sparse_sample_weight=None,
+ sparse_param=None,
+ dummy_int=None,
+ dummy_str=None,
+ dummy_obj=None,
+ callback=None,
+ ):
+ """The dummy arguments are to test that this fit function can
+ accept non-array arguments through cross-validation, such as:
+ - int
+ - str (this is actually array-like)
+ - object
+ - function
+ """
+ self.dummy_int = dummy_int
+ self.dummy_str = dummy_str
+ self.dummy_obj = dummy_obj
+ if callback is not None:
+ callback(self)
+
+ if self.allow_nd:
+ X = X.reshape(len(X), -1)
+ if X.ndim >= 3 and not self.allow_nd:
+            raise ValueError("X cannot be more than 2d")
+ if sample_weight is not None:
+ assert sample_weight.shape[0] == X.shape[0], (
+ "MockClassifier extra fit_param "
+ "sample_weight.shape[0] is {0}, should be {1}".format(
+ sample_weight.shape[0], X.shape[0]
+ )
+ )
+ if class_prior is not None:
+ assert class_prior.shape[0] == len(np.unique(y)), (
+ "MockClassifier extra fit_param class_prior.shape[0]"
+ " is {0}, should be {1}".format(class_prior.shape[0], len(np.unique(y)))
+ )
+ if sparse_sample_weight is not None:
+ fmt = (
+ "MockClassifier extra fit_param sparse_sample_weight"
+ ".shape[0] is {0}, should be {1}"
+ )
+ assert sparse_sample_weight.shape[0] == X.shape[0], fmt.format(
+ sparse_sample_weight.shape[0], X.shape[0]
+ )
+ if sparse_param is not None:
+ fmt = (
+ "MockClassifier extra fit_param sparse_param.shape "
+ "is ({0}, {1}), should be ({2}, {3})"
+ )
+ assert sparse_param.shape == P_sparse.shape, fmt.format(
+ sparse_param.shape[0],
+ sparse_param.shape[1],
+ P_sparse.shape[0],
+ P_sparse.shape[1],
+ )
+ return self
+
+ def predict(self, T):
+ if self.allow_nd:
+ T = T.reshape(len(T), -1)
+ return T[:, 0]
+
+ def predict_proba(self, T):
+ return T
+
+ def score(self, X=None, Y=None):
+ return 1.0 / (1 + np.abs(self.a))
+
+ def get_params(self, deep=False):
+ return {"a": self.a, "allow_nd": self.allow_nd}
+
+
+# XXX: use 2D array, since 1D X is being detected as a single sample in
+# check_consistent_length
+X = np.ones((10, 2))
+X_sparse = coo_matrix(X)
+y = np.array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4])
+# The number of samples per class needs to be > n_splits,
+# for StratifiedKFold(n_splits=3)
+y2 = np.array([1, 1, 1, 2, 2, 2, 3, 3, 3, 3])
+P_sparse = coo_matrix(np.eye(5))
+
+
+def test_cross_val_score():
+ clf = MockClassifier()
+
+ for a in range(-10, 10):
+ clf.a = a
+ # Smoke test
+ scores = cross_val_score(clf, X, y2)
+ assert_array_equal(scores, clf.score(X, y2))
+
+ # test with multioutput y
+ multioutput_y = np.column_stack([y2, y2[::-1]])
+ scores = cross_val_score(clf, X_sparse, multioutput_y)
+ assert_array_equal(scores, clf.score(X_sparse, multioutput_y))
+
+ scores = cross_val_score(clf, X_sparse, y2)
+ assert_array_equal(scores, clf.score(X_sparse, y2))
+
+ # test with multioutput y
+ scores = cross_val_score(clf, X_sparse, multioutput_y)
+ assert_array_equal(scores, clf.score(X_sparse, multioutput_y))
+
+ # test with X and y as list
+ list_check = lambda x: isinstance(x, list)
+ clf = CheckingClassifier(check_X=list_check)
+ scores = cross_val_score(clf, X.tolist(), y2.tolist(), cv=3)
+
+ clf = CheckingClassifier(check_y=list_check)
+ scores = cross_val_score(clf, X, y2.tolist(), cv=3)
+
+ with pytest.raises(ValueError):
+ cross_val_score(clf, X, y2, scoring="sklearn")
+
+    # test with 3d X
+ X_3d = X[:, :, np.newaxis]
+ clf = MockClassifier(allow_nd=True)
+ scores = cross_val_score(clf, X_3d, y2)
+
+ clf = MockClassifier(allow_nd=False)
+ with pytest.raises(ValueError):
+ cross_val_score(clf, X_3d, y2, error_score="raise")
+
+
+def test_cross_validate_many_jobs():
+    # regression test for #12154: cv='warn' with n_jobs>1 triggered a copy of
+    # the parameters, leading to a failure in check_cv because the check used
+    # `cv is 'warn'` instead of `cv == 'warn'`.
+ X, y = load_iris(return_X_y=True)
+ clf = SVC(gamma="auto")
+ grid = GridSearchCV(clf, param_grid={"C": [1, 10]})
+ cross_validate(grid, X, y, n_jobs=2)
+
+
+def test_cross_validate_invalid_scoring_param():
+ X, y = make_classification(random_state=0)
+ estimator = MockClassifier()
+
+ # Test the errors
+ error_message_regexp = ".*must be unique strings.*"
+
+ # List/tuple of callables should raise a message advising users to use
+ # dict of names to callables mapping
+ with pytest.raises(ValueError, match=error_message_regexp):
+ cross_validate(
+ estimator,
+ X,
+ y,
+ scoring=(make_scorer(precision_score), make_scorer(accuracy_score)),
+ )
+ with pytest.raises(ValueError, match=error_message_regexp):
+ cross_validate(estimator, X, y, scoring=(make_scorer(precision_score),))
+
+ # So should empty lists/tuples
+ with pytest.raises(ValueError, match=error_message_regexp + "Empty list.*"):
+ cross_validate(estimator, X, y, scoring=())
+
+ # So should duplicated entries
+ with pytest.raises(ValueError, match=error_message_regexp + "Duplicate.*"):
+ cross_validate(estimator, X, y, scoring=("f1_micro", "f1_micro"))
+
+ # Nested Lists should raise a generic error message
+ with pytest.raises(ValueError, match=error_message_regexp):
+ cross_validate(estimator, X, y, scoring=[[make_scorer(precision_score)]])
+
+ error_message_regexp = (
+ ".*scoring is invalid.*Refer to the scoring glossary for details:.*"
+ )
+
+ # Empty dict should raise invalid scoring error
+ with pytest.raises(ValueError, match="An empty dict"):
+ cross_validate(estimator, X, y, scoring=(dict()))
+
+ # And so should any other invalid entry
+ with pytest.raises(ValueError, match=error_message_regexp):
+ cross_validate(estimator, X, y, scoring=5)
+
+ multiclass_scorer = make_scorer(precision_recall_fscore_support)
+
+ # Multiclass Scorers that return multiple values are not supported yet
+ # the warning message we're expecting to see
+ warning_message = (
+ "Scoring failed. The score on this train-test "
+ f"partition for these parameters will be set to {np.nan}. "
+ "Details: \n"
+ )
+
+ with pytest.warns(UserWarning, match=warning_message):
+ cross_validate(estimator, X, y, scoring=multiclass_scorer)
+
+ with pytest.warns(UserWarning, match=warning_message):
+ cross_validate(estimator, X, y, scoring={"foo": multiclass_scorer})
+
+ with pytest.raises(ValueError, match="'mse' is not a valid scoring value."):
+ cross_validate(SVC(), X, y, scoring="mse")
+
+
+def test_cross_validate_nested_estimator():
+ # Non-regression test to ensure that nested
+ # estimators are properly returned in a list
+ # https://github.com/scikit-learn/scikit-learn/pull/17745
+ (X, y) = load_iris(return_X_y=True)
+ pipeline = Pipeline(
+ [
+ ("imputer", SimpleImputer()),
+ ("classifier", MockClassifier()),
+ ]
+ )
+
+ results = cross_validate(pipeline, X, y, return_estimator=True)
+ estimators = results["estimator"]
+
+ assert isinstance(estimators, list)
+ assert all(isinstance(estimator, Pipeline) for estimator in estimators)
+
+
+def test_cross_validate():
+ # Compute train and test mse/r2 scores
+ cv = KFold()
+
+ # Regression
+ X_reg, y_reg = make_regression(n_samples=30, random_state=0)
+ reg = Ridge(random_state=0)
+
+ # Classification
+ X_clf, y_clf = make_classification(n_samples=30, random_state=0)
+ clf = SVC(kernel="linear", random_state=0)
+
+ for X, y, est in ((X_reg, y_reg, reg), (X_clf, y_clf, clf)):
+ # It's okay to evaluate regression metrics on classification too
+ mse_scorer = check_scoring(est, scoring="neg_mean_squared_error")
+ r2_scorer = check_scoring(est, scoring="r2")
+ train_mse_scores = []
+ test_mse_scores = []
+ train_r2_scores = []
+ test_r2_scores = []
+ fitted_estimators = []
+ for train, test in cv.split(X, y):
+ est = clone(reg).fit(X[train], y[train])
+ train_mse_scores.append(mse_scorer(est, X[train], y[train]))
+ train_r2_scores.append(r2_scorer(est, X[train], y[train]))
+ test_mse_scores.append(mse_scorer(est, X[test], y[test]))
+ test_r2_scores.append(r2_scorer(est, X[test], y[test]))
+ fitted_estimators.append(est)
+
+ train_mse_scores = np.array(train_mse_scores)
+ test_mse_scores = np.array(test_mse_scores)
+ train_r2_scores = np.array(train_r2_scores)
+ test_r2_scores = np.array(test_r2_scores)
+ fitted_estimators = np.array(fitted_estimators)
+
+ scores = (
+ train_mse_scores,
+ test_mse_scores,
+ train_r2_scores,
+ test_r2_scores,
+ fitted_estimators,
+ )
+
+ check_cross_validate_single_metric(est, X, y, scores)
+ check_cross_validate_multi_metric(est, X, y, scores)
+
+
+def check_cross_validate_single_metric(clf, X, y, scores):
+ (
+ train_mse_scores,
+ test_mse_scores,
+ train_r2_scores,
+ test_r2_scores,
+ fitted_estimators,
+ ) = scores
+ # Test single metric evaluation when scoring is string or singleton list
+ for return_train_score, dict_len in ((True, 4), (False, 3)):
+ # Single metric passed as a string
+ if return_train_score:
+ mse_scores_dict = cross_validate(
+ clf, X, y, scoring="neg_mean_squared_error", return_train_score=True
+ )
+ assert_array_almost_equal(mse_scores_dict["train_score"], train_mse_scores)
+ else:
+ mse_scores_dict = cross_validate(
+ clf, X, y, scoring="neg_mean_squared_error", return_train_score=False
+ )
+ assert isinstance(mse_scores_dict, dict)
+ assert len(mse_scores_dict) == dict_len
+ assert_array_almost_equal(mse_scores_dict["test_score"], test_mse_scores)
+
+ # Single metric passed as a list
+ if return_train_score:
+ # It must be True by default - deprecated
+ r2_scores_dict = cross_validate(
+ clf, X, y, scoring=["r2"], return_train_score=True
+ )
+ assert_array_almost_equal(r2_scores_dict["train_r2"], train_r2_scores, True)
+ else:
+ r2_scores_dict = cross_validate(
+ clf, X, y, scoring=["r2"], return_train_score=False
+ )
+ assert isinstance(r2_scores_dict, dict)
+ assert len(r2_scores_dict) == dict_len
+ assert_array_almost_equal(r2_scores_dict["test_r2"], test_r2_scores)
+
+ # Test return_estimator option
+ mse_scores_dict = cross_validate(
+ clf, X, y, scoring="neg_mean_squared_error", return_estimator=True
+ )
+ for k, est in enumerate(mse_scores_dict["estimator"]):
+ assert_almost_equal(est.coef_, fitted_estimators[k].coef_)
+ assert_almost_equal(est.intercept_, fitted_estimators[k].intercept_)
+
+
+def check_cross_validate_multi_metric(clf, X, y, scores):
+ # Test multimetric evaluation when scoring is a list / dict
+ (
+ train_mse_scores,
+ test_mse_scores,
+ train_r2_scores,
+ test_r2_scores,
+ fitted_estimators,
+ ) = scores
+
+ def custom_scorer(clf, X, y):
+ y_pred = clf.predict(X)
+ return {
+ "r2": r2_score(y, y_pred),
+ "neg_mean_squared_error": -mean_squared_error(y, y_pred),
+ }
+
+ all_scoring = (
+ ("r2", "neg_mean_squared_error"),
+ {
+ "r2": make_scorer(r2_score),
+ "neg_mean_squared_error": "neg_mean_squared_error",
+ },
+ custom_scorer,
+ )
+
+ keys_sans_train = {
+ "test_r2",
+ "test_neg_mean_squared_error",
+ "fit_time",
+ "score_time",
+ }
+ keys_with_train = keys_sans_train.union(
+ {"train_r2", "train_neg_mean_squared_error"}
+ )
+
+ for return_train_score in (True, False):
+ for scoring in all_scoring:
+ if return_train_score:
+ # return_train_score must be True by default - deprecated
+ cv_results = cross_validate(
+ clf, X, y, scoring=scoring, return_train_score=True
+ )
+ assert_array_almost_equal(cv_results["train_r2"], train_r2_scores)
+ assert_array_almost_equal(
+ cv_results["train_neg_mean_squared_error"], train_mse_scores
+ )
+ else:
+ cv_results = cross_validate(
+ clf, X, y, scoring=scoring, return_train_score=False
+ )
+ assert isinstance(cv_results, dict)
+ assert set(cv_results.keys()) == (
+ keys_with_train if return_train_score else keys_sans_train
+ )
+ assert_array_almost_equal(cv_results["test_r2"], test_r2_scores)
+ assert_array_almost_equal(
+ cv_results["test_neg_mean_squared_error"], test_mse_scores
+ )
+
+ # Make sure all the arrays are of np.ndarray type
+ assert type(cv_results["test_r2"]) == np.ndarray
+ assert type(cv_results["test_neg_mean_squared_error"]) == np.ndarray
+ assert type(cv_results["fit_time"]) == np.ndarray
+ assert type(cv_results["score_time"]) == np.ndarray
+
+ # Ensure all the times are within sane limits
+ assert np.all(cv_results["fit_time"] >= 0)
+ assert np.all(cv_results["fit_time"] < 10)
+ assert np.all(cv_results["score_time"] >= 0)
+ assert np.all(cv_results["score_time"] < 10)
+
+
+def test_cross_val_score_predict_groups():
+ # Check if ValueError (when groups is None) propagates to cross_val_score
+ # and cross_val_predict
+ # And also check if groups is correctly passed to the cv object
+ X, y = make_classification(n_samples=20, n_classes=2, random_state=0)
+
+ clf = SVC(kernel="linear")
+
+ group_cvs = [
+ LeaveOneGroupOut(),
+ LeavePGroupsOut(2),
+ GroupKFold(),
+ GroupShuffleSplit(),
+ ]
+ error_message = "The 'groups' parameter should not be None."
+ for cv in group_cvs:
+ with pytest.raises(ValueError, match=error_message):
+ cross_val_score(estimator=clf, X=X, y=y, cv=cv)
+ with pytest.raises(ValueError, match=error_message):
+ cross_val_predict(estimator=clf, X=X, y=y, cv=cv)
+
+
+@pytest.mark.filterwarnings("ignore: Using or importing the ABCs from")
+def test_cross_val_score_pandas():
+ # check cross_val_score doesn't destroy pandas dataframe
+ types = [(MockDataFrame, MockDataFrame)]
+ try:
+ from modin.pandas import Series, DataFrame
+
+ types.append((Series, DataFrame))
+ except ImportError:
+ pass
+ for TargetType, InputFeatureType in types:
+ # X dataframe, y series
+ # 3 fold cross val is used so we need at least 3 samples per class
+ X_df, y_ser = InputFeatureType(X), TargetType(y2)
+ check_df = lambda x: isinstance(x, InputFeatureType)
+ check_series = lambda x: isinstance(x, TargetType)
+ clf = CheckingClassifier(check_X=check_df, check_y=check_series)
+ cross_val_score(clf, X_df, y_ser, cv=3)
+
+
+def test_cross_val_score_mask():
+ # test that cross_val_score works with boolean masks
+ svm = SVC(kernel="linear")
+ iris = load_iris()
+ X, y = iris.data, iris.target
+ kfold = KFold(5)
+ scores_indices = cross_val_score(svm, X, y, cv=kfold)
+ kfold = KFold(5)
+ cv_masks = []
+ for train, test in kfold.split(X, y):
+ mask_train = np.zeros(len(y), dtype=bool)
+ mask_test = np.zeros(len(y), dtype=bool)
+ mask_train[train] = 1
+ mask_test[test] = 1
+        cv_masks.append((mask_train, mask_test))
+ scores_masks = cross_val_score(svm, X, y, cv=cv_masks)
+ assert_array_equal(scores_indices, scores_masks)
+
+
+def test_cross_val_score_precomputed():
+ # test for svm with precomputed kernel
+ svm = SVC(kernel="precomputed")
+ iris = load_iris()
+ X, y = iris.data, iris.target
+ linear_kernel = np.dot(X, X.T)
+ score_precomputed = cross_val_score(svm, linear_kernel, y)
+ svm = SVC(kernel="linear")
+ score_linear = cross_val_score(svm, X, y)
+ assert_array_almost_equal(score_precomputed, score_linear)
+
+ # test with callable
+ svm = SVC(kernel=lambda x, y: np.dot(x, y.T))
+ score_callable = cross_val_score(svm, X, y)
+ assert_array_almost_equal(score_precomputed, score_callable)
+
+ # Error raised for non-square X
+ svm = SVC(kernel="precomputed")
+ with pytest.raises(ValueError):
+ cross_val_score(svm, X, y)
+
+ # test error is raised when the precomputed kernel is not array-like
+ # or sparse
+ with pytest.raises(ValueError):
+ cross_val_score(svm, linear_kernel.tolist(), y)
+
+
+def test_cross_val_score_fit_params():
+ clf = MockClassifier()
+ n_samples = X.shape[0]
+ n_classes = len(np.unique(y))
+
+ W_sparse = coo_matrix(
+ (np.array([1]), (np.array([1]), np.array([0]))), shape=(10, 1)
+ )
+ P_sparse = coo_matrix(np.eye(5))
+
+ DUMMY_INT = 42
+ DUMMY_STR = "42"
+ DUMMY_OBJ = object()
+
+ def assert_fit_params(clf):
+ # Function to test that the values are passed correctly to the
+ # classifier arguments for non-array type
+
+ assert clf.dummy_int == DUMMY_INT
+ assert clf.dummy_str == DUMMY_STR
+ assert clf.dummy_obj == DUMMY_OBJ
+
+ fit_params = {
+ "sample_weight": np.ones(n_samples),
+ "class_prior": np.full(n_classes, 1.0 / n_classes),
+ "sparse_sample_weight": W_sparse,
+ "sparse_param": P_sparse,
+ "dummy_int": DUMMY_INT,
+ "dummy_str": DUMMY_STR,
+ "dummy_obj": DUMMY_OBJ,
+ "callback": assert_fit_params,
+ }
+ cross_val_score(clf, X, y, fit_params=fit_params)
+
+
+def test_cross_val_score_score_func():
+ clf = MockClassifier()
+ _score_func_args = []
+
+ def score_func(y_test, y_predict):
+ _score_func_args.append((y_test, y_predict))
+ return 1.0
+
+ with warnings.catch_warnings(record=True):
+ scoring = make_scorer(score_func)
+ score = cross_val_score(clf, X, y, scoring=scoring, cv=3)
+ assert_array_equal(score, [1.0, 1.0, 1.0])
+ # Test that score function is called only 3 times (for cv=3)
+ assert len(_score_func_args) == 3
+
+
+def test_cross_val_score_errors():
+ class BrokenEstimator:
+ pass
+
+ with pytest.raises(TypeError):
+ cross_val_score(BrokenEstimator(), X)
+
+
+def test_cross_val_score_with_score_func_classification():
+ iris = load_iris()
+ clf = SVC(kernel="linear")
+
+ # Default score (should be the accuracy score)
+ scores = cross_val_score(clf, iris.data, iris.target)
+ assert_array_almost_equal(scores, [0.97, 1.0, 0.97, 0.97, 1.0], 2)
+
+ # Correct classification score (aka. zero / one score) - should be the
+ # same as the default estimator score
+ zo_scores = cross_val_score(clf, iris.data, iris.target, scoring="accuracy")
+ assert_array_almost_equal(zo_scores, [0.97, 1.0, 0.97, 0.97, 1.0], 2)
+
+    # F1 score (classes are balanced, so the f1_score should equal the
+    # zero/one score)
+ f1_scores = cross_val_score(clf, iris.data, iris.target, scoring="f1_weighted")
+ assert_array_almost_equal(f1_scores, [0.97, 1.0, 0.97, 0.97, 1.0], 2)
+
+
+def test_cross_val_score_with_score_func_regression():
+ X, y = make_regression(n_samples=30, n_features=20, n_informative=5, random_state=0)
+ reg = Ridge()
+
+ # Default score of the Ridge regression estimator
+ scores = cross_val_score(reg, X, y)
+ assert_array_almost_equal(scores, [0.94, 0.97, 0.97, 0.99, 0.92], 2)
+
+    # R2 score (aka. coefficient of determination) - should be the
+ # same as the default estimator score
+ r2_scores = cross_val_score(reg, X, y, scoring="r2")
+ assert_array_almost_equal(r2_scores, [0.94, 0.97, 0.97, 0.99, 0.92], 2)
+
+ # Mean squared error; this is a loss function, so "scores" are negative
+ neg_mse_scores = cross_val_score(reg, X, y, scoring="neg_mean_squared_error")
+ expected_neg_mse = np.array([-763.07, -553.16, -274.38, -273.26, -1681.99])
+ assert_array_almost_equal(neg_mse_scores, expected_neg_mse, 2)
+
+ # Explained variance
+ scoring = make_scorer(explained_variance_score)
+ ev_scores = cross_val_score(reg, X, y, scoring=scoring)
+ assert_array_almost_equal(ev_scores, [0.94, 0.97, 0.97, 0.99, 0.92], 2)
+
+
+def test_permutation_score():
+ iris = load_iris()
+ X = iris.data
+ X_sparse = coo_matrix(X)
+ y = iris.target
+ svm = SVC(kernel="linear")
+ cv = StratifiedKFold(2)
+
+ score, scores, pvalue = permutation_test_score(
+ svm, X, y, n_permutations=30, cv=cv, scoring="accuracy"
+ )
+ assert score > 0.9
+ assert_almost_equal(pvalue, 0.0, 1)
+
+ score_group, _, pvalue_group = permutation_test_score(
+ svm,
+ X,
+ y,
+ n_permutations=30,
+ cv=cv,
+ scoring="accuracy",
+ groups=np.ones(y.size),
+ random_state=0,
+ )
+ assert score_group == score
+ assert pvalue_group == pvalue
+
+ # check that we obtain the same results with a sparse representation
+ svm_sparse = SVC(kernel="linear")
+ cv_sparse = StratifiedKFold(2)
+ score_group, _, pvalue_group = permutation_test_score(
+ svm_sparse,
+ X_sparse,
+ y,
+ n_permutations=30,
+ cv=cv_sparse,
+ scoring="accuracy",
+ groups=np.ones(y.size),
+ random_state=0,
+ )
+
+ assert score_group == score
+ assert pvalue_group == pvalue
+
+ # test with custom scoring object
+ def custom_score(y_true, y_pred):
+ return ((y_true == y_pred).sum() - (y_true != y_pred).sum()) / y_true.shape[0]
+
+ scorer = make_scorer(custom_score)
+ score, _, pvalue = permutation_test_score(
+ svm, X, y, n_permutations=100, scoring=scorer, cv=cv, random_state=0
+ )
+ assert_almost_equal(score, 0.93, 2)
+ assert_almost_equal(pvalue, 0.01, 3)
+
+ # set random y
+ y = np.mod(np.arange(len(y)), 3)
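+    # The labels now depend only on the sample index, not on the features, so
+    # accuracy should be close to chance and the permutation p-value large.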
+
+ score, scores, pvalue = permutation_test_score(
+ svm, X, y, n_permutations=30, cv=cv, scoring="accuracy"
+ )
+
+ assert score < 0.5
+ assert pvalue > 0.2
+
+
+def test_permutation_test_score_allow_nans():
+ # Check that permutation_test_score allows input data with NaNs
+ X = np.arange(200, dtype=np.float64).reshape(10, -1)
+ X[2, :] = np.nan
+ y = np.repeat([0, 1], X.shape[0] / 2)
+ p = Pipeline(
+ [
+ ("imputer", SimpleImputer(strategy="mean", missing_values=np.nan)),
+ ("classifier", MockClassifier()),
+ ]
+ )
+ permutation_test_score(p, X, y)
+
+
+def test_permutation_test_score_fit_params():
+ X = np.arange(100).reshape(10, 10)
+ y = np.array([0] * 5 + [1] * 5)
+ clf = CheckingClassifier(expected_sample_weight=True)
+
+ err_msg = r"Expected sample_weight to be passed"
+ with pytest.raises(AssertionError, match=err_msg):
+ permutation_test_score(clf, X, y)
+
+ err_msg = r"sample_weight.shape == \(1,\), expected \(8,\)!"
+ with pytest.raises(ValueError, match=err_msg):
+ permutation_test_score(clf, X, y, fit_params={"sample_weight": np.ones(1)})
+ permutation_test_score(clf, X, y, fit_params={"sample_weight": np.ones(10)})
+
+
+def test_cross_val_score_allow_nans():
+ # Check that cross_val_score allows input data with NaNs
+ X = np.arange(200, dtype=np.float64).reshape(10, -1)
+ X[2, :] = np.nan
+ y = np.repeat([0, 1], X.shape[0] / 2)
+ p = Pipeline(
+ [
+ ("imputer", SimpleImputer(strategy="mean", missing_values=np.nan)),
+ ("classifier", MockClassifier()),
+ ]
+ )
+ cross_val_score(p, X, y)
+
+
+def test_cross_val_score_multilabel():
+ X = np.array(
+ [
+ [-3, 4],
+ [2, 4],
+ [3, 3],
+ [0, 2],
+ [-3, 1],
+ [-2, 1],
+ [0, 0],
+ [-2, -1],
+ [-1, -2],
+ [1, -2],
+ ]
+ )
+ y = np.array(
+ [[1, 1], [0, 1], [0, 1], [0, 1], [1, 1], [0, 1], [1, 0], [1, 1], [1, 0], [0, 0]]
+ )
+ clf = KNeighborsClassifier(n_neighbors=1)
+ scoring_micro = make_scorer(precision_score, average="micro")
+ scoring_macro = make_scorer(precision_score, average="macro")
+ scoring_samples = make_scorer(precision_score, average="samples")
+ score_micro = cross_val_score(clf, X, y, scoring=scoring_micro)
+ score_macro = cross_val_score(clf, X, y, scoring=scoring_macro)
+ score_samples = cross_val_score(clf, X, y, scoring=scoring_samples)
+ assert_almost_equal(score_micro, [1, 1 / 2, 3 / 4, 1 / 2, 1 / 3])
+ assert_almost_equal(score_macro, [1, 1 / 2, 3 / 4, 1 / 2, 1 / 4])
+ assert_almost_equal(score_samples, [1, 1 / 2, 3 / 4, 1 / 2, 1 / 4])
+
+
+def test_cross_val_predict():
+ X, y = load_diabetes(return_X_y=True)
+ cv = KFold()
+
+ est = Ridge()
+
+ # Naive loop (should be same as cross_val_predict):
+ preds2 = np.zeros_like(y)
+ for train, test in cv.split(X, y):
+ est.fit(X[train], y[train])
+ preds2[test] = est.predict(X[test])
+
+ preds = cross_val_predict(est, X, y, cv=cv)
+ assert_array_almost_equal(preds, preds2)
+
+ preds = cross_val_predict(est, X, y)
+ assert len(preds) == len(y)
+
+ cv = LeaveOneOut()
+ preds = cross_val_predict(est, X, y, cv=cv)
+ assert len(preds) == len(y)
+
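+    # Sparsify the input: keep only entries above the median (others become 0),
+    # then convert to a COO sparse matrix.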
+ Xsp = X.copy()
+ Xsp *= Xsp > np.median(Xsp)
+ Xsp = coo_matrix(Xsp)
+ preds = cross_val_predict(est, Xsp, y)
+ assert_array_almost_equal(len(preds), len(y))
+
+ preds = cross_val_predict(KMeans(n_init="auto"), X)
+ assert len(preds) == len(y)
+
+ class BadCV:
+ def split(self, X, y=None, groups=None):
+ for i in range(4):
+ yield np.array([0, 1, 2, 3]), np.array([4, 5, 6, 7, 8])
+
+ with pytest.raises(ValueError):
+ cross_val_predict(est, X, y, cv=BadCV())
+
+ X, y = load_iris(return_X_y=True)
+
+ warning_message = (
+ r"Number of classes in training fold \(2\) does "
+ r"not match total number of classes \(3\). "
+ "Results may not be appropriate for your use case."
+ )
+ with pytest.warns(RuntimeWarning, match=warning_message):
+ cross_val_predict(
+ LogisticRegression(solver="liblinear"),
+ X,
+ y,
+ method="predict_proba",
+ cv=KFold(2),
+ )
+
+
+def test_cross_val_predict_decision_function_shape():
+ X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
+
+ preds = cross_val_predict(
+ LogisticRegression(solver="liblinear"), X, y, method="decision_function"
+ )
+ assert preds.shape == (50,)
+
+ X, y = load_iris(return_X_y=True)
+
+ preds = cross_val_predict(
+ LogisticRegression(solver="liblinear"), X, y, method="decision_function"
+ )
+ assert preds.shape == (150, 3)
+
+ # This specifically tests imbalanced splits for binary
+ # classification with decision_function. This is only
+ # applicable to classifiers that can be fit on a single
+ # class.
+ X = X[:100]
+ y = y[:100]
+ error_message = (
+ "Only 1 class/es in training fold,"
+ " but 2 in overall dataset. This"
+ " is not supported for decision_function"
+ " with imbalanced folds. To fix "
+ "this, use a cross-validation technique "
+ "resulting in properly stratified folds"
+ )
+ with pytest.raises(ValueError, match=error_message):
+ cross_val_predict(
+ RidgeClassifier(), X, y, method="decision_function", cv=KFold(2)
+ )
+
+ X, y = load_digits(return_X_y=True)
+ est = SVC(kernel="linear", decision_function_shape="ovo")
+
+ preds = cross_val_predict(est, X, y, method="decision_function")
+ assert preds.shape == (1797, 45)
+
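+    # Sorting the samples by class means the unshuffled KFold(3) training folds
+    # miss entire classes, so the ovo decision_function has fewer columns than
+    # expected for the full set of classes.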
+ ind = np.argsort(y)
+ X, y = X[ind], y[ind]
+ error_message_regexp = (
+ r"Output shape \(599L?, 21L?\) of "
+ "decision_function does not match number of "
+ r"classes \(7\) in fold. Irregular "
+ "decision_function .*"
+ )
+ with pytest.raises(ValueError, match=error_message_regexp):
+ cross_val_predict(est, X, y, cv=KFold(n_splits=3), method="decision_function")
+
+
+def test_cross_val_predict_predict_proba_shape():
+ X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
+
+ preds = cross_val_predict(
+ LogisticRegression(solver="liblinear"), X, y, method="predict_proba"
+ )
+ assert preds.shape == (50, 2)
+
+ X, y = load_iris(return_X_y=True)
+
+ preds = cross_val_predict(
+ LogisticRegression(solver="liblinear"), X, y, method="predict_proba"
+ )
+ assert preds.shape == (150, 3)
+
+
+def test_cross_val_predict_predict_log_proba_shape():
+ X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
+
+ preds = cross_val_predict(
+ LogisticRegression(solver="liblinear"), X, y, method="predict_log_proba"
+ )
+ assert preds.shape == (50, 2)
+
+ X, y = load_iris(return_X_y=True)
+
+ preds = cross_val_predict(
+ LogisticRegression(solver="liblinear"), X, y, method="predict_log_proba"
+ )
+ assert preds.shape == (150, 3)
+
+
+def test_cross_val_predict_input_types():
+ iris = load_iris()
+ X, y = iris.data, iris.target
+ X_sparse = coo_matrix(X)
+ multioutput_y = np.column_stack([y, y[::-1]])
+
+ clf = Ridge(fit_intercept=False, random_state=0)
+ # 3 fold cv is used --> at least 3 samples per class
+ # Smoke test
+ predictions = cross_val_predict(clf, X, y)
+ assert predictions.shape == (150,)
+
+ # test with multioutput y
+ predictions = cross_val_predict(clf, X_sparse, multioutput_y)
+ assert predictions.shape == (150, 2)
+
+ predictions = cross_val_predict(clf, X_sparse, y)
+ assert_array_equal(predictions.shape, (150,))
+
+ # test with multioutput y
+ predictions = cross_val_predict(clf, X_sparse, multioutput_y)
+ assert_array_equal(predictions.shape, (150, 2))
+
+ # test with X and y as list
+ list_check = lambda x: isinstance(x, list)
+ clf = CheckingClassifier(check_X=list_check)
+ predictions = cross_val_predict(clf, X.tolist(), y.tolist())
+
+ clf = CheckingClassifier(check_y=list_check)
+ predictions = cross_val_predict(clf, X, y.tolist())
+
+ # test with X and y as list and non empty method
+ predictions = cross_val_predict(
+ LogisticRegression(solver="liblinear"),
+ X.tolist(),
+ y.tolist(),
+ method="decision_function",
+ )
+ predictions = cross_val_predict(
+ LogisticRegression(solver="liblinear"),
+ X,
+ y.tolist(),
+ method="decision_function",
+ )
+
+    # test with 3d X
+ X_3d = X[:, :, np.newaxis]
+ check_3d = lambda x: x.ndim == 3
+ clf = CheckingClassifier(check_X=check_3d)
+ predictions = cross_val_predict(clf, X_3d, y)
+ assert_array_equal(predictions.shape, (150,))
+
+
+@pytest.mark.filterwarnings("ignore: Using or importing the ABCs from")
+# python3.7 deprecation warnings in pandas via matplotlib :-/
+def test_cross_val_predict_pandas():
+    # check cross_val_predict doesn't destroy pandas dataframe
+ types = [(MockDataFrame, MockDataFrame)]
+ try:
+ from modin.pandas import Series, DataFrame
+
+ types.append((Series, DataFrame))
+ except ImportError:
+ pass
+ for TargetType, InputFeatureType in types:
+ # X dataframe, y series
+ X_df, y_ser = InputFeatureType(X), TargetType(y2)
+ check_df = lambda x: isinstance(x, InputFeatureType)
+ check_series = lambda x: isinstance(x, TargetType)
+ clf = CheckingClassifier(check_X=check_df, check_y=check_series)
+ cross_val_predict(clf, X_df, y_ser, cv=3)
+
+
+def test_cross_val_predict_unbalanced():
+ X, y = make_classification(
+ n_samples=100,
+ n_features=2,
+ n_redundant=0,
+ n_informative=2,
+ n_clusters_per_class=1,
+ random_state=1,
+ )
+ # Change the first sample to a new class
+ y[0] = 2
+ clf = LogisticRegression(random_state=1, solver="liblinear")
+ cv = StratifiedKFold(n_splits=2)
+ train, test = list(cv.split(X, y))
+ yhat_proba = cross_val_predict(clf, X, y, cv=cv, method="predict_proba")
+ assert y[test[0]][0] == 2 # sanity check for further assertions
+ assert np.all(yhat_proba[test[0]][:, 2] == 0)
+ assert np.all(yhat_proba[test[0]][:, 0:1] > 0)
+ assert np.all(yhat_proba[test[1]] > 0)
+ assert_array_almost_equal(yhat_proba.sum(axis=1), np.ones(y.shape), decimal=12)
+
+
+def test_cross_val_predict_y_none():
+ # ensure that cross_val_predict works when y is None
+ mock_classifier = MockClassifier()
+ rng = np.random.RandomState(42)
+ X = rng.rand(100, 10)
+ y_hat = cross_val_predict(mock_classifier, X, y=None, cv=5, method="predict")
+ assert_allclose(X[:, 0], y_hat)
+ y_hat_proba = cross_val_predict(
+ mock_classifier, X, y=None, cv=5, method="predict_proba"
+ )
+ assert_allclose(X, y_hat_proba)
+
+
+def test_cross_val_score_sparse_fit_params():
+ iris = load_iris()
+ X, y = iris.data, iris.target
+ clf = MockClassifier()
+ fit_params = {"sparse_sample_weight": coo_matrix(np.eye(X.shape[0]))}
+ a = cross_val_score(clf, X, y, fit_params=fit_params, cv=3)
+ assert_array_equal(a, np.ones(3))
+
+
+def test_learning_curve():
+ n_samples = 30
+ n_splits = 3
+ X, y = make_classification(
+ n_samples=n_samples,
+ n_features=1,
+ n_informative=1,
+ n_redundant=0,
+ n_classes=2,
+ n_clusters_per_class=1,
+ random_state=0,
+ )
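+    # With KFold(3) on 30 samples, the training sets contain at most 20 samples.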
+ estimator = MockImprovingEstimator(n_samples * ((n_splits - 1) / n_splits))
+ for shuffle_train in [False, True]:
+ with warnings.catch_warnings(record=True) as w:
+ (
+ train_sizes,
+ train_scores,
+ test_scores,
+ fit_times,
+ score_times,
+ ) = learning_curve(
+ estimator,
+ X,
+ y,
+ cv=KFold(n_splits=n_splits),
+ train_sizes=np.linspace(0.1, 1.0, 10),
+ shuffle=shuffle_train,
+ return_times=True,
+ )
+ if len(w) > 0:
+ raise RuntimeError("Unexpected warning: %r" % w[0].message)
+ assert train_scores.shape == (10, 3)
+ assert test_scores.shape == (10, 3)
+ assert fit_times.shape == (10, 3)
+ assert score_times.shape == (10, 3)
+ assert_array_equal(train_sizes, np.linspace(2, 20, 10))
+ assert_array_almost_equal(train_scores.mean(axis=1), np.linspace(1.9, 1.0, 10))
+ assert_array_almost_equal(test_scores.mean(axis=1), np.linspace(0.1, 1.0, 10))
+
+ # Cannot use assert_array_almost_equal for fit and score times because
+    # the values are hardware-dependent
+ assert fit_times.dtype == "float64"
+ assert score_times.dtype == "float64"
+
+ # Test a custom cv splitter that can iterate only once
+ with warnings.catch_warnings(record=True) as w:
+ train_sizes2, train_scores2, test_scores2 = learning_curve(
+ estimator,
+ X,
+ y,
+ cv=OneTimeSplitter(n_splits=n_splits, n_samples=n_samples),
+ train_sizes=np.linspace(0.1, 1.0, 10),
+ shuffle=shuffle_train,
+ )
+ if len(w) > 0:
+ raise RuntimeError("Unexpected warning: %r" % w[0].message)
+ assert_array_almost_equal(train_scores2, train_scores)
+ assert_array_almost_equal(test_scores2, test_scores)
+
+
+def test_learning_curve_unsupervised():
+ X, _ = make_classification(
+ n_samples=30,
+ n_features=1,
+ n_informative=1,
+ n_redundant=0,
+ n_classes=2,
+ n_clusters_per_class=1,
+ random_state=0,
+ )
+ estimator = MockImprovingEstimator(20)
+ train_sizes, train_scores, test_scores = learning_curve(
+ estimator, X, y=None, cv=3, train_sizes=np.linspace(0.1, 1.0, 10)
+ )
+ assert_array_equal(train_sizes, np.linspace(2, 20, 10))
+ assert_array_almost_equal(train_scores.mean(axis=1), np.linspace(1.9, 1.0, 10))
+ assert_array_almost_equal(test_scores.mean(axis=1), np.linspace(0.1, 1.0, 10))
+
+
+def test_learning_curve_verbose():
+ X, y = make_classification(
+ n_samples=30,
+ n_features=1,
+ n_informative=1,
+ n_redundant=0,
+ n_classes=2,
+ n_clusters_per_class=1,
+ random_state=0,
+ )
+ estimator = MockImprovingEstimator(20)
+
+ old_stdout = sys.stdout
+ sys.stdout = StringIO()
+ try:
+ train_sizes, train_scores, test_scores = learning_curve(
+ estimator, X, y, cv=3, verbose=1
+ )
+ finally:
+ out = sys.stdout.getvalue()
+ sys.stdout.close()
+ sys.stdout = old_stdout
+
+ assert "[learning_curve]" in out
+
+
+def test_learning_curve_incremental_learning_not_possible():
+ X, y = make_classification(
+ n_samples=2,
+ n_features=1,
+ n_informative=1,
+ n_redundant=0,
+ n_classes=2,
+ n_clusters_per_class=1,
+ random_state=0,
+ )
+ # The mockup does not have partial_fit()
+ estimator = MockImprovingEstimator(1)
+ with pytest.raises(ValueError):
+ learning_curve(estimator, X, y, exploit_incremental_learning=True)
+
+
+def test_learning_curve_incremental_learning():
+ X, y = make_classification(
+ n_samples=30,
+ n_features=1,
+ n_informative=1,
+ n_redundant=0,
+ n_classes=2,
+ n_clusters_per_class=1,
+ random_state=0,
+ )
+ estimator = MockIncrementalImprovingEstimator(20)
+ for shuffle_train in [False, True]:
+ train_sizes, train_scores, test_scores = learning_curve(
+ estimator,
+ X,
+ y,
+ cv=3,
+ exploit_incremental_learning=True,
+ train_sizes=np.linspace(0.1, 1.0, 10),
+ shuffle=shuffle_train,
+ )
+ assert_array_equal(train_sizes, np.linspace(2, 20, 10))
+ assert_array_almost_equal(train_scores.mean(axis=1), np.linspace(1.9, 1.0, 10))
+ assert_array_almost_equal(test_scores.mean(axis=1), np.linspace(0.1, 1.0, 10))
+
+
+def test_learning_curve_incremental_learning_unsupervised():
+ X, _ = make_classification(
+ n_samples=30,
+ n_features=1,
+ n_informative=1,
+ n_redundant=0,
+ n_classes=2,
+ n_clusters_per_class=1,
+ random_state=0,
+ )
+ estimator = MockIncrementalImprovingEstimator(20)
+ train_sizes, train_scores, test_scores = learning_curve(
+ estimator,
+ X,
+ y=None,
+ cv=3,
+ exploit_incremental_learning=True,
+ train_sizes=np.linspace(0.1, 1.0, 10),
+ )
+ assert_array_equal(train_sizes, np.linspace(2, 20, 10))
+ assert_array_almost_equal(train_scores.mean(axis=1), np.linspace(1.9, 1.0, 10))
+ assert_array_almost_equal(test_scores.mean(axis=1), np.linspace(0.1, 1.0, 10))
+
+
+def test_learning_curve_batch_and_incremental_learning_are_equal():
+ X, y = make_classification(
+ n_samples=30,
+ n_features=1,
+ n_informative=1,
+ n_redundant=0,
+ n_classes=2,
+ n_clusters_per_class=1,
+ random_state=0,
+ )
+ train_sizes = np.linspace(0.2, 1.0, 5)
+ estimator = PassiveAggressiveClassifier(max_iter=1, tol=None, shuffle=False)
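+    # With a single epoch (max_iter=1) and no shuffling, a batch fit on the
+    # first k samples performs the same sequence of updates as the cumulative
+    # partial_fit calls, so both learning curves should coincide.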
+
+ train_sizes_inc, train_scores_inc, test_scores_inc = learning_curve(
+ estimator,
+ X,
+ y,
+ train_sizes=train_sizes,
+ cv=3,
+ exploit_incremental_learning=True,
+ )
+ train_sizes_batch, train_scores_batch, test_scores_batch = learning_curve(
+ estimator,
+ X,
+ y,
+ cv=3,
+ train_sizes=train_sizes,
+ exploit_incremental_learning=False,
+ )
+
+ assert_array_equal(train_sizes_inc, train_sizes_batch)
+ assert_array_almost_equal(
+ train_scores_inc.mean(axis=1), train_scores_batch.mean(axis=1)
+ )
+ assert_array_almost_equal(
+ test_scores_inc.mean(axis=1), test_scores_batch.mean(axis=1)
+ )
+
+
+def test_learning_curve_n_sample_range_out_of_bounds():
+ X, y = make_classification(
+ n_samples=30,
+ n_features=1,
+ n_informative=1,
+ n_redundant=0,
+ n_classes=2,
+ n_clusters_per_class=1,
+ random_state=0,
+ )
+ estimator = MockImprovingEstimator(20)
+ with pytest.raises(ValueError):
+ learning_curve(estimator, X, y, cv=3, train_sizes=[0, 1])
+ with pytest.raises(ValueError):
+ learning_curve(estimator, X, y, cv=3, train_sizes=[0.0, 1.0])
+ with pytest.raises(ValueError):
+ learning_curve(estimator, X, y, cv=3, train_sizes=[0.1, 1.1])
+ with pytest.raises(ValueError):
+ learning_curve(estimator, X, y, cv=3, train_sizes=[0, 20])
+ with pytest.raises(ValueError):
+ learning_curve(estimator, X, y, cv=3, train_sizes=[1, 21])
+
+
+def test_learning_curve_remove_duplicate_sample_sizes():
+ X, y = make_classification(
+ n_samples=3,
+ n_features=1,
+ n_informative=1,
+ n_redundant=0,
+ n_classes=2,
+ n_clusters_per_class=1,
+ random_state=0,
+ )
+ estimator = MockImprovingEstimator(2)
+ warning_message = (
+ "Removed duplicate entries from 'train_sizes'. Number of ticks "
+ "will be less than the size of 'train_sizes': 2 instead of 3."
+ )
+ with pytest.warns(RuntimeWarning, match=warning_message):
+ train_sizes, _, _ = learning_curve(
+ estimator, X, y, cv=3, train_sizes=np.linspace(0.33, 1.0, 3)
+ )
+ assert_array_equal(train_sizes, [1, 2])
+
+
+def test_learning_curve_with_boolean_indices():
+ X, y = make_classification(
+ n_samples=30,
+ n_features=1,
+ n_informative=1,
+ n_redundant=0,
+ n_classes=2,
+ n_clusters_per_class=1,
+ random_state=0,
+ )
+ estimator = MockImprovingEstimator(20)
+ cv = KFold(n_splits=3)
+ train_sizes, train_scores, test_scores = learning_curve(
+ estimator, X, y, cv=cv, train_sizes=np.linspace(0.1, 1.0, 10)
+ )
+ assert_array_equal(train_sizes, np.linspace(2, 20, 10))
+ assert_array_almost_equal(train_scores.mean(axis=1), np.linspace(1.9, 1.0, 10))
+ assert_array_almost_equal(test_scores.mean(axis=1), np.linspace(0.1, 1.0, 10))
+
+
+def test_learning_curve_with_shuffle():
+    # The following test case was designed this way to verify the code
+ # changes made in pull request: #7506.
+ X = np.array(
+ [
+ [1, 2],
+ [3, 4],
+ [5, 6],
+ [7, 8],
+ [11, 12],
+ [13, 14],
+ [15, 16],
+ [17, 18],
+ [19, 20],
+ [7, 8],
+ [9, 10],
+ [11, 12],
+ [13, 14],
+ [15, 16],
+ [17, 18],
+ ]
+ )
+ y = np.array([1, 1, 1, 2, 3, 4, 1, 1, 2, 3, 4, 1, 2, 3, 4])
+ groups = np.array([1, 1, 1, 1, 1, 1, 3, 3, 3, 3, 3, 4, 4, 4, 4])
+ # Splits on these groups fail without shuffle as the first iteration
+ # of the learning curve doesn't contain label 4 in the training set.
+ estimator = PassiveAggressiveClassifier(max_iter=5, tol=None, shuffle=False)
+
+ cv = GroupKFold(n_splits=2)
+ train_sizes_batch, train_scores_batch, test_scores_batch = learning_curve(
+ estimator,
+ X,
+ y,
+ cv=cv,
+ n_jobs=1,
+ train_sizes=np.linspace(0.3, 1.0, 3),
+ groups=groups,
+ shuffle=True,
+ random_state=2,
+ )
+ assert_array_almost_equal(
+ train_scores_batch.mean(axis=1), np.array([0.75, 0.3, 0.36111111])
+ )
+ assert_array_almost_equal(
+ test_scores_batch.mean(axis=1), np.array([0.36111111, 0.25, 0.25])
+ )
+ with pytest.raises(ValueError):
+ learning_curve(
+ estimator,
+ X,
+ y,
+ cv=cv,
+ n_jobs=1,
+ train_sizes=np.linspace(0.3, 1.0, 3),
+ groups=groups,
+ error_score="raise",
+ )
+
+ train_sizes_inc, train_scores_inc, test_scores_inc = learning_curve(
+ estimator,
+ X,
+ y,
+ cv=cv,
+ n_jobs=1,
+ train_sizes=np.linspace(0.3, 1.0, 3),
+ groups=groups,
+ shuffle=True,
+ random_state=2,
+ exploit_incremental_learning=True,
+ )
+ assert_array_almost_equal(
+ train_scores_inc.mean(axis=1), train_scores_batch.mean(axis=1)
+ )
+ assert_array_almost_equal(
+ test_scores_inc.mean(axis=1), test_scores_batch.mean(axis=1)
+ )
+
+
+def test_learning_curve_fit_params():
+ X = np.arange(100).reshape(10, 10)
+ y = np.array([0] * 5 + [1] * 5)
+ clf = CheckingClassifier(expected_sample_weight=True)
+
+ err_msg = r"Expected sample_weight to be passed"
+ with pytest.raises(AssertionError, match=err_msg):
+ learning_curve(clf, X, y, error_score="raise")
+
+ err_msg = r"sample_weight.shape == \(1,\), expected \(2,\)!"
+ with pytest.raises(ValueError, match=err_msg):
+ learning_curve(
+ clf, X, y, error_score="raise", fit_params={"sample_weight": np.ones(1)}
+ )
+ learning_curve(
+ clf, X, y, error_score="raise", fit_params={"sample_weight": np.ones(10)}
+ )
+
+
+def test_learning_curve_incremental_learning_fit_params():
+ X, y = make_classification(
+ n_samples=30,
+ n_features=1,
+ n_informative=1,
+ n_redundant=0,
+ n_classes=2,
+ n_clusters_per_class=1,
+ random_state=0,
+ )
+ estimator = MockIncrementalImprovingEstimator(20, ["sample_weight"])
+ err_msg = r"Expected fit parameter\(s\) \['sample_weight'\] not seen."
+ with pytest.raises(AssertionError, match=err_msg):
+ learning_curve(
+ estimator,
+ X,
+ y,
+ cv=3,
+ exploit_incremental_learning=True,
+ train_sizes=np.linspace(0.1, 1.0, 10),
+ error_score="raise",
+ )
+
+ err_msg = "Fit parameter sample_weight has length 3; expected"
+ with pytest.raises(AssertionError, match=err_msg):
+ learning_curve(
+ estimator,
+ X,
+ y,
+ cv=3,
+ exploit_incremental_learning=True,
+ train_sizes=np.linspace(0.1, 1.0, 10),
+ error_score="raise",
+ fit_params={"sample_weight": np.ones(3)},
+ )
+
+ learning_curve(
+ estimator,
+ X,
+ y,
+ cv=3,
+ exploit_incremental_learning=True,
+ train_sizes=np.linspace(0.1, 1.0, 10),
+ error_score="raise",
+ fit_params={"sample_weight": np.ones(2)},
+ )
+
+
+def test_validation_curve():
+ X, y = make_classification(
+ n_samples=2,
+ n_features=1,
+ n_informative=1,
+ n_redundant=0,
+ n_classes=2,
+ n_clusters_per_class=1,
+ random_state=0,
+ )
+ param_range = np.linspace(0, 1, 10)
+ with warnings.catch_warnings(record=True) as w:
+ train_scores, test_scores = validation_curve(
+ MockEstimatorWithParameter(),
+ X,
+ y,
+ param_name="param",
+ param_range=param_range,
+ cv=2,
+ )
+ if len(w) > 0:
+ raise RuntimeError("Unexpected warning: %r" % w[0].message)
+
+ assert_array_almost_equal(train_scores.mean(axis=1), param_range)
+ assert_array_almost_equal(test_scores.mean(axis=1), 1 - param_range)
+
+
+def test_validation_curve_clone_estimator():
+ X, y = make_classification(
+ n_samples=2,
+ n_features=1,
+ n_informative=1,
+ n_redundant=0,
+ n_classes=2,
+ n_clusters_per_class=1,
+ random_state=0,
+ )
+
+ param_range = np.linspace(1, 0, 10)
+ _, _ = validation_curve(
+ MockEstimatorWithSingleFitCallAllowed(),
+ X,
+ y,
+ param_name="param",
+ param_range=param_range,
+ cv=2,
+ )
+
+
+def test_validation_curve_cv_splits_consistency():
+ n_samples = 100
+ n_splits = 5
+ X, y = make_classification(n_samples=100, random_state=0)
+
+ scores1 = validation_curve(
+ SVC(kernel="linear", random_state=0),
+ X,
+ y,
+ param_name="C",
+ param_range=[0.1, 0.1, 0.2, 0.2],
+ cv=OneTimeSplitter(n_splits=n_splits, n_samples=n_samples),
+ )
+    # The OneTimeSplitter is a non-re-entrant cv splitter. Unless `split` is
+    # called anew for each parameter setting (which would exhaust the one-time
+    # splitter), the following should produce identical results for param
+    # setting 1 and param setting 2, as both have the same C value.
+ assert_array_almost_equal(*np.vsplit(np.hstack(scores1)[(0, 2, 1, 3), :], 2))
+
+ scores2 = validation_curve(
+ SVC(kernel="linear", random_state=0),
+ X,
+ y,
+ param_name="C",
+ param_range=[0.1, 0.1, 0.2, 0.2],
+ cv=KFold(n_splits=n_splits, shuffle=True),
+ )
+
+    # For scores2, compare the 1st and 2nd parameter settings' scores
+    # (since the C value for the first two param settings is 0.1, they must be
+    # consistent unless the train/test folds differ between the param settings)
+ assert_array_almost_equal(*np.vsplit(np.hstack(scores2)[(0, 2, 1, 3), :], 2))
+
+ scores3 = validation_curve(
+ SVC(kernel="linear", random_state=0),
+ X,
+ y,
+ param_name="C",
+ param_range=[0.1, 0.1, 0.2, 0.2],
+ cv=KFold(n_splits=n_splits),
+ )
+
+ # OneTimeSplitter is basically unshuffled KFold(n_splits=5). Sanity check.
+ assert_array_almost_equal(np.array(scores3), np.array(scores1))
+
+
+def test_validation_curve_fit_params():
+ X = np.arange(100).reshape(10, 10)
+ y = np.array([0] * 5 + [1] * 5)
+ clf = CheckingClassifier(expected_sample_weight=True)
+
+ err_msg = r"Expected sample_weight to be passed"
+ with pytest.raises(AssertionError, match=err_msg):
+ validation_curve(
+ clf,
+ X,
+ y,
+ param_name="foo_param",
+ param_range=[1, 2, 3],
+ error_score="raise",
+ )
+
+ err_msg = r"sample_weight.shape == \(1,\), expected \(8,\)!"
+ with pytest.raises(ValueError, match=err_msg):
+ validation_curve(
+ clf,
+ X,
+ y,
+ param_name="foo_param",
+ param_range=[1, 2, 3],
+ error_score="raise",
+ fit_params={"sample_weight": np.ones(1)},
+ )
+ validation_curve(
+ clf,
+ X,
+ y,
+ param_name="foo_param",
+ param_range=[1, 2, 3],
+ error_score="raise",
+ fit_params={"sample_weight": np.ones(10)},
+ )
+
+
+def test_check_is_permutation():
+ rng = np.random.RandomState(0)
+ p = np.arange(100)
+ rng.shuffle(p)
+ assert _check_is_permutation(p, 100)
+ assert not _check_is_permutation(np.delete(p, 23), 100)
+
+ p[0] = 23
+ assert not _check_is_permutation(p, 100)
+
+    # Check that additional duplicate indices are caught
+ assert not _check_is_permutation(np.hstack((p, 0)), 100)
+
+
+def test_cross_val_predict_sparse_prediction():
+ # check that cross_val_predict gives same result for sparse and dense input
+ X, y = make_multilabel_classification(
+ n_classes=2,
+ n_labels=1,
+ allow_unlabeled=False,
+ return_indicator=True,
+ random_state=1,
+ )
+ X_sparse = csr_matrix(X)
+ y_sparse = csr_matrix(y)
+ classif = OneVsRestClassifier(SVC(kernel="linear"))
+ preds = cross_val_predict(classif, X, y, cv=10)
+ preds_sparse = cross_val_predict(classif, X_sparse, y_sparse, cv=10)
+ preds_sparse = preds_sparse.toarray()
+ assert_array_almost_equal(preds_sparse, preds)
+
+
+def check_cross_val_predict_binary(est, X, y, method):
+ """Helper for tests of cross_val_predict with binary classification"""
+ cv = KFold(n_splits=3, shuffle=False)
+
+ # Generate expected outputs
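+    # decision_function returns a 1-d array for binary targets, while
+    # predict_proba / predict_log_proba return one column per class.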
+ if y.ndim == 1:
+ exp_shape = (len(X),) if method == "decision_function" else (len(X), 2)
+ else:
+ exp_shape = y.shape
+ expected_predictions = np.zeros(exp_shape)
+ for train, test in cv.split(X, y):
+ est = clone(est).fit(X[train], y[train])
+ expected_predictions[test] = getattr(est, method)(X[test])
+
+ # Check actual outputs for several representations of y
+ for tg in [y, y + 1, y - 2, y.astype("str")]:
+ assert_allclose(
+ cross_val_predict(est, X, tg, method=method, cv=cv), expected_predictions
+ )
+
+
+def check_cross_val_predict_multiclass(est, X, y, method):
+ """Helper for tests of cross_val_predict with multiclass classification"""
+ cv = KFold(n_splits=3, shuffle=False)
+
+ # Generate expected outputs
+ float_min = np.finfo(np.float64).min
+ default_values = {
+ "decision_function": float_min,
+ "predict_log_proba": float_min,
+ "predict_proba": 0,
+ }
+ expected_predictions = np.full(
+ (len(X), len(set(y))), default_values[method], dtype=np.float64
+ )
+ _, y_enc = np.unique(y, return_inverse=True)
+ for train, test in cv.split(X, y_enc):
+ est = clone(est).fit(X[train], y_enc[train])
+ fold_preds = getattr(est, method)(X[test])
+ i_cols_fit = np.unique(y_enc[train])
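+        # Only the columns of classes present in this training fold receive real
+        # predictions; the remaining columns keep the default fill value.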
+ expected_predictions[np.ix_(test, i_cols_fit)] = fold_preds
+
+ # Check actual outputs for several representations of y
+ for tg in [y, y + 1, y - 2, y.astype("str")]:
+ assert_allclose(
+ cross_val_predict(est, X, tg, method=method, cv=cv), expected_predictions
+ )
+
+
+def check_cross_val_predict_multilabel(est, X, y, method):
+ """Check the output of cross_val_predict for 2D targets using
+    Estimators which provide predictions as a list with one
+ element per class.
+ """
+ cv = KFold(n_splits=3, shuffle=False)
+
+ # Create empty arrays of the correct size to hold outputs
+ float_min = np.finfo(np.float64).min
+ default_values = {
+ "decision_function": float_min,
+ "predict_log_proba": float_min,
+ "predict_proba": 0,
+ }
+ n_targets = y.shape[1]
+ expected_preds = []
+ for i_col in range(n_targets):
+ n_classes_in_label = len(set(y[:, i_col]))
+ if n_classes_in_label == 2 and method == "decision_function":
+ exp_shape = (len(X),)
+ else:
+ exp_shape = (len(X), n_classes_in_label)
+ expected_preds.append(
+ np.full(exp_shape, default_values[method], dtype=np.float64)
+ )
+
+ # Generate expected outputs
+ y_enc_cols = [
+ np.unique(y[:, i], return_inverse=True)[1][:, np.newaxis]
+ for i in range(y.shape[1])
+ ]
+ y_enc = np.concatenate(y_enc_cols, axis=1)
+ for train, test in cv.split(X, y_enc):
+ est = clone(est).fit(X[train], y_enc[train])
+ fold_preds = getattr(est, method)(X[test])
+ for i_col in range(n_targets):
+ fold_cols = np.unique(y_enc[train][:, i_col])
+ if expected_preds[i_col].ndim == 1:
+ # Decision function with <=2 classes
+ expected_preds[i_col][test] = fold_preds[i_col]
+ else:
+ idx = np.ix_(test, fold_cols)
+ expected_preds[i_col][idx] = fold_preds[i_col]
+
+ # Check actual outputs for several representations of y
+ for tg in [y, y + 1, y - 2, y.astype("str")]:
+ cv_predict_output = cross_val_predict(est, X, tg, method=method, cv=cv)
+ assert len(cv_predict_output) == len(expected_preds)
+ for i in range(len(cv_predict_output)):
+ assert_allclose(cv_predict_output[i], expected_preds[i])
+
+
+def check_cross_val_predict_with_method_binary(est):
+ # This test includes the decision_function with two classes.
+ # This is a special case: it has only one column of output.
+ X, y = make_classification(n_classes=2, random_state=0)
+ for method in ["decision_function", "predict_proba", "predict_log_proba"]:
+ check_cross_val_predict_binary(est, X, y, method)
+
+
+def check_cross_val_predict_with_method_multiclass(est):
+ iris = load_iris()
+ X, y = iris.data, iris.target
+ X, y = shuffle(X, y, random_state=0)
+ for method in ["decision_function", "predict_proba", "predict_log_proba"]:
+ check_cross_val_predict_multiclass(est, X, y, method)
+
+
+def test_cross_val_predict_with_method():
+ check_cross_val_predict_with_method_binary(LogisticRegression(solver="liblinear"))
+ check_cross_val_predict_with_method_multiclass(
+ LogisticRegression(solver="liblinear")
+ )
+
+
+def test_cross_val_predict_method_checking():
+ # Regression test for issue #9639. Tests that cross_val_predict does not
+ # check estimator methods (e.g. predict_proba) before fitting
+ iris = load_iris()
+ X, y = iris.data, iris.target
+ X, y = shuffle(X, y, random_state=0)
+ for method in ["decision_function", "predict_proba", "predict_log_proba"]:
+ est = SGDClassifier(loss="log_loss", random_state=2)
+ check_cross_val_predict_multiclass(est, X, y, method)
+
+
+def test_gridsearchcv_cross_val_predict_with_method():
+ iris = load_iris()
+ X, y = iris.data, iris.target
+ X, y = shuffle(X, y, random_state=0)
+ est = GridSearchCV(
+ LogisticRegression(random_state=42, solver="liblinear"), {"C": [0.1, 1]}, cv=2
+ )
+ for method in ["decision_function", "predict_proba", "predict_log_proba"]:
+ check_cross_val_predict_multiclass(est, X, y, method)
+
+
+def test_cross_val_predict_with_method_multilabel_ovr():
+    # OVR does multilabel predictions, but only for arrays of
+    # binary indicator columns. The output of predict_proba
+ # is a 2D array with shape (n_samples, n_classes).
+ n_samp = 100
+ n_classes = 4
+ X, y = make_multilabel_classification(
+ n_samples=n_samp, n_labels=3, n_classes=n_classes, n_features=5, random_state=42
+ )
+ est = OneVsRestClassifier(LogisticRegression(solver="liblinear", random_state=0))
+ for method in ["predict_proba", "decision_function"]:
+ check_cross_val_predict_binary(est, X, y, method=method)
+
+
+class RFWithDecisionFunction(RandomForestClassifier):
+ # None of the current multioutput-multiclass estimators have
+ # decision function methods. Create a mock decision function
+ # to test the cross_val_predict function's handling of this case.
+ def decision_function(self, X):
+ probs = self.predict_proba(X)
+ msg = "This helper should only be used on multioutput-multiclass tasks"
+ assert isinstance(probs, list), msg
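+        # For a binary label, keep only the positive-class column so the output
+        # resembles a decision_function; multi-class labels keep all columns.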
+ probs = [p[:, -1] if p.shape[1] == 2 else p for p in probs]
+ return probs
+
+
+def test_cross_val_predict_with_method_multilabel_rf():
+ # The RandomForest allows multiple classes in each label.
+ # Output of predict_proba is a list of outputs of predict_proba
+ # for each individual label.
+ n_classes = 4
+ X, y = make_multilabel_classification(
+ n_samples=100, n_labels=3, n_classes=n_classes, n_features=5, random_state=42
+ )
+ y[:, 0] += y[:, 1] # Put three classes in the first column
+ for method in ["predict_proba", "predict_log_proba", "decision_function"]:
+ est = RFWithDecisionFunction(n_estimators=5, random_state=0)
+ with warnings.catch_warnings():
+ # Suppress "RuntimeWarning: divide by zero encountered in log"
+ warnings.simplefilter("ignore")
+ check_cross_val_predict_multilabel(est, X, y, method=method)
+
+
+def test_cross_val_predict_with_method_rare_class():
+ # Test a multiclass problem where one class will be missing from
+ # one of the CV training sets.
+ rng = np.random.RandomState(0)
+ X = rng.normal(0, 1, size=(14, 10))
+ y = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 3])
+ est = LogisticRegression(solver="liblinear")
+ for method in ["predict_proba", "predict_log_proba", "decision_function"]:
+ with warnings.catch_warnings():
+ # Suppress warning about too few examples of a class
+ warnings.simplefilter("ignore")
+ check_cross_val_predict_multiclass(est, X, y, method)
+
+
+def test_cross_val_predict_with_method_multilabel_rf_rare_class():
+ # The RandomForest allows anything for the contents of the labels.
+ # Output of predict_proba is a list of outputs of predict_proba
+ # for each individual label.
+ # In this test, the first label has a class with a single example.
+ # We'll have one CV fold where the training data don't include it.
+ rng = np.random.RandomState(0)
+ X = rng.normal(0, 1, size=(5, 10))
+ y = np.array([[0, 0], [1, 1], [2, 1], [0, 1], [1, 0]])
+ for method in ["predict_proba", "predict_log_proba"]:
+ est = RFWithDecisionFunction(n_estimators=5, random_state=0)
+ with warnings.catch_warnings():
+ # Suppress "RuntimeWarning: divide by zero encountered in log"
+ warnings.simplefilter("ignore")
+ check_cross_val_predict_multilabel(est, X, y, method=method)
+
+
+def get_expected_predictions(X, y, cv, classes, est, method):
+ expected_predictions = np.zeros([len(y), classes])
+ func = getattr(est, method)
+
+ for train, test in cv.split(X, y):
+ est.fit(X[train], y[train])
+ expected_predictions_ = func(X[test])
+ # To avoid 2 dimensional indexing
+ if method == "predict_proba":
+ exp_pred_test = np.zeros((len(test), classes))
+ else:
+ exp_pred_test = np.full(
+ (len(test), classes), np.finfo(expected_predictions.dtype).min
+ )
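+        # Fill only the columns of the classes seen during fit; classes absent
+        # from this training fold keep the default fill value.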
+ exp_pred_test[:, est.classes_] = expected_predictions_
+ expected_predictions[test] = exp_pred_test
+
+ return expected_predictions
+
+
+def test_cross_val_predict_class_subset():
+ X = np.arange(200).reshape(100, 2)
+ y = np.array([x // 10 for x in range(100)])
+ classes = 10
+
+ kfold3 = KFold(n_splits=3)
+ kfold4 = KFold(n_splits=4)
+
+ le = LabelEncoder()
+
+ methods = ["decision_function", "predict_proba", "predict_log_proba"]
+ for method in methods:
+ est = LogisticRegression(solver="liblinear")
+
+ # Test with n_splits=3
+ predictions = cross_val_predict(est, X, y, method=method, cv=kfold3)
+
+ # Runs a naive loop (should be same as cross_val_predict):
+ expected_predictions = get_expected_predictions(
+ X, y, kfold3, classes, est, method
+ )
+ assert_array_almost_equal(expected_predictions, predictions)
+
+ # Test with n_splits=4
+ predictions = cross_val_predict(est, X, y, method=method, cv=kfold4)
+ expected_predictions = get_expected_predictions(
+ X, y, kfold4, classes, est, method
+ )
+ assert_array_almost_equal(expected_predictions, predictions)
+
+ # Testing unordered labels
+ y = shuffle(np.repeat(range(10), 10), random_state=0)
+ predictions = cross_val_predict(est, X, y, method=method, cv=kfold3)
+ y = le.fit_transform(y)
+ expected_predictions = get_expected_predictions(
+ X, y, kfold3, classes, est, method
+ )
+ assert_array_almost_equal(expected_predictions, predictions)
+
+
+def test_score_memmap():
+ # Ensure a scalar score of memmap type is accepted
+ iris = load_iris()
+ X, y = iris.data, iris.target
+ clf = MockClassifier()
+ tf = tempfile.NamedTemporaryFile(mode="wb", delete=False)
+ tf.write(b"Hello world!!!!!")
+ tf.close()
+ scores = np.memmap(tf.name, dtype=np.float64)
+ score = np.memmap(tf.name, shape=(), mode="r", dtype=np.float64)
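+    # `score` is a zero-dimensional memmap (a scalar) and should be accepted;
+    # `scores` is a 1-d memmap array and must be rejected by the scorer check.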
+ try:
+ cross_val_score(clf, X, y, scoring=lambda est, X, y: score)
+ with pytest.raises(ValueError):
+ cross_val_score(clf, X, y, scoring=lambda est, X, y: scores)
+ finally:
+ # Best effort to release the mmap file handles before deleting the
+ # backing file under Windows
+ scores, score = None, None
+ for _ in range(3):
+ try:
+ os.unlink(tf.name)
+ break
+ except WindowsError:
+ sleep(1.0)
+
+
+@pytest.mark.filterwarnings("ignore: Using or importing the ABCs from")
+def test_permutation_test_score_pandas():
+ # check permutation_test_score doesn't destroy pandas dataframe
+ types = [(MockDataFrame, MockDataFrame)]
+ try:
+ from modin.pandas import Series, DataFrame
+
+ types.append((Series, DataFrame))
+ except ImportError:
+ pass
+ for TargetType, InputFeatureType in types:
+ # X dataframe, y series
+ iris = load_iris()
+ X, y = iris.data, iris.target
+ X_df, y_ser = InputFeatureType(X), TargetType(y)
+ check_df = lambda x: isinstance(x, InputFeatureType)
+ check_series = lambda x: isinstance(x, TargetType)
+ clf = CheckingClassifier(check_X=check_df, check_y=check_series)
+ permutation_test_score(clf, X_df, y_ser)
+
+
+def test_fit_and_score_failing():
+ # Create a failing classifier to deliberately fail
+ failing_clf = FailingClassifier(FailingClassifier.FAILING_PARAMETER)
+ # dummy X data
+ X = np.arange(1, 10)
+ y = np.ones(9)
+ fit_and_score_args = [failing_clf, X, None, dict(), None, None, 0, None, None]
+ # passing error score to trigger the warning message
+ fit_and_score_kwargs = {"error_score": "raise"}
+ # check if exception was raised, with default error_score='raise'
+ with pytest.raises(ValueError, match="Failing classifier failed as required"):
+ _fit_and_score(*fit_and_score_args, **fit_and_score_kwargs)
+
+ # check that functions upstream pass error_score param to _fit_and_score
+ error_message = re.escape(
+ "error_score must be the string 'raise' or a numeric value. (Hint: if "
+ "using 'raise', please make sure that it has been spelled correctly.)"
+ )
+ with pytest.raises(ValueError, match=error_message):
+ cross_validate(failing_clf, X, cv=3, error_score="unvalid-string")
+
+ with pytest.raises(ValueError, match=error_message):
+ cross_val_score(failing_clf, X, cv=3, error_score="unvalid-string")
+
+ with pytest.raises(ValueError, match=error_message):
+ learning_curve(failing_clf, X, y, cv=3, error_score="unvalid-string")
+
+ with pytest.raises(ValueError, match=error_message):
+ validation_curve(
+ failing_clf,
+ X,
+ y,
+ param_name="parameter",
+ param_range=[FailingClassifier.FAILING_PARAMETER],
+ cv=3,
+ error_score="unvalid-string",
+ )
+
+ assert failing_clf.score() == 0.0 # FailingClassifier coverage
+
+
+def test_fit_and_score_working():
+ X, y = make_classification(n_samples=30, random_state=0)
+ clf = SVC(kernel="linear", random_state=0)
+ train, test = next(ShuffleSplit().split(X))
+ # Test return_parameters option
+ fit_and_score_args = [clf, X, y, dict(), train, test, 0]
+ fit_and_score_kwargs = {
+ "parameters": {"max_iter": 100, "tol": 0.1},
+ "fit_params": None,
+ "return_parameters": True,
+ }
+ result = _fit_and_score(*fit_and_score_args, **fit_and_score_kwargs)
+ assert result["parameters"] == fit_and_score_kwargs["parameters"]
+
+
+class DataDependentFailingClassifier(BaseEstimator):
+ def __init__(self, max_x_value=None):
+ self.max_x_value = max_x_value
+
+ def fit(self, X, y=None):
+ num_values_too_high = (X > self.max_x_value).sum()
+ if num_values_too_high:
+ raise ValueError(
+ f"Classifier fit failed with {num_values_too_high} values too high"
+ )
+
+ def score(self, X=None, Y=None):
+ return 0.0
+
+
+@pytest.mark.parametrize("error_score", [np.nan, 0])
+def test_cross_validate_some_failing_fits_warning(error_score):
+ # Create a failing classifier to deliberately fail
+ failing_clf = DataDependentFailingClassifier(max_x_value=8)
+ # dummy X data
+ X = np.arange(1, 10)
+ y = np.ones(9)
+ # passing error score to trigger the warning message
+ cross_validate_args = [failing_clf, X, y]
+ cross_validate_kwargs = {"cv": 3, "error_score": error_score}
+ # check if the warning message type is as expected
+
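+    # X = 1..9 with cv=3: the value 9 (> max_x_value=8) lands in the training set
+    # for 2 of the 3 splits, so exactly 2 fits fail with one value too high.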
+ individual_fit_error_message = (
+ "ValueError: Classifier fit failed with 1 values too high"
+ )
+ warning_message = re.compile(
+ "2 fits failed.+total of 3.+The score on these"
+ " train-test partitions for these parameters will be set to"
+ f" {cross_validate_kwargs['error_score']}.+{individual_fit_error_message}",
+ flags=re.DOTALL,
+ )
+
+ with pytest.warns(FitFailedWarning, match=warning_message):
+ cross_validate(*cross_validate_args, **cross_validate_kwargs)
+
+
+@pytest.mark.parametrize("error_score", [np.nan, 0])
+def test_cross_validate_all_failing_fits_error(error_score):
+ # Create a failing classifier to deliberately fail
+ failing_clf = FailingClassifier(FailingClassifier.FAILING_PARAMETER)
+ # dummy X data
+ X = np.arange(1, 10)
+ y = np.ones(9)
+
+ cross_validate_args = [failing_clf, X, y]
+ cross_validate_kwargs = {"cv": 7, "error_score": error_score}
+
+ individual_fit_error_message = "ValueError: Failing classifier failed as required"
+ error_message = re.compile(
+ "All the 7 fits failed.+your model is misconfigured.+"
+ f"{individual_fit_error_message}",
+ flags=re.DOTALL,
+ )
+
+ with pytest.raises(ValueError, match=error_message):
+ cross_validate(*cross_validate_args, **cross_validate_kwargs)
+
+
+def _failing_scorer(estimator, X, y, error_msg):
+ raise ValueError(error_msg)
+
+
+@pytest.mark.filterwarnings("ignore:lbfgs failed to converge")
+@pytest.mark.parametrize("error_score", [np.nan, 0, "raise"])
+def test_cross_val_score_failing_scorer(error_score):
+ # check that an estimator can fail during scoring in `cross_val_score` and
+    # that we can optionally replace it with `error_score`
+ X, y = load_iris(return_X_y=True)
+ clf = LogisticRegression(max_iter=5).fit(X, y)
+
+ error_msg = "This scorer is supposed to fail!!!"
+ failing_scorer = partial(_failing_scorer, error_msg=error_msg)
+
+ if error_score == "raise":
+ with pytest.raises(ValueError, match=error_msg):
+ cross_val_score(
+ clf, X, y, cv=3, scoring=failing_scorer, error_score=error_score
+ )
+ else:
+ warning_msg = (
+ "Scoring failed. The score on this train-test partition for "
+ f"these parameters will be set to {error_score}"
+ )
+ with pytest.warns(UserWarning, match=warning_msg):
+ scores = cross_val_score(
+ clf, X, y, cv=3, scoring=failing_scorer, error_score=error_score
+ )
+ assert_allclose(scores, error_score)
+
+
+@pytest.mark.filterwarnings("ignore:lbfgs failed to converge")
+@pytest.mark.parametrize("error_score", [np.nan, 0, "raise"])
+@pytest.mark.parametrize("return_train_score", [True, False])
+@pytest.mark.parametrize("with_multimetric", [False, True])
+def test_cross_validate_failing_scorer(
+ error_score, return_train_score, with_multimetric
+):
+ # Check that an estimator can fail during scoring in `cross_validate` and
+ # that we can optionally replace it with `error_score`. In the multimetric
+ # case also check the result of a non-failing scorer where the other scorers
+ # are failing.
+ X, y = load_iris(return_X_y=True)
+ clf = LogisticRegression(max_iter=5).fit(X, y)
+
+ error_msg = "This scorer is supposed to fail!!!"
+ failing_scorer = partial(_failing_scorer, error_msg=error_msg)
+ if with_multimetric:
+ non_failing_scorer = make_scorer(mean_squared_error)
+ scoring = {
+ "score_1": failing_scorer,
+ "score_2": non_failing_scorer,
+ "score_3": failing_scorer,
+ }
+ else:
+ scoring = failing_scorer
+
+ if error_score == "raise":
+ with pytest.raises(ValueError, match=error_msg):
+ cross_validate(
+ clf,
+ X,
+ y,
+ cv=3,
+ scoring=scoring,
+ return_train_score=return_train_score,
+ error_score=error_score,
+ )
+ else:
+ warning_msg = (
+ "Scoring failed. The score on this train-test partition for "
+ f"these parameters will be set to {error_score}"
+ )
+ with pytest.warns(UserWarning, match=warning_msg):
+ results = cross_validate(
+ clf,
+ X,
+ y,
+ cv=3,
+ scoring=scoring,
+ return_train_score=return_train_score,
+ error_score=error_score,
+ )
+ for key in results:
+ if "_score" in key:
+ if "_score_2" in key:
+ # check the test (and optionally train) score for the
+ # scorer that should be non-failing
+ for i in results[key]:
+ assert isinstance(i, float)
+ else:
+ # check the test (and optionally train) score for all
+ # scorers that should be assigned to `error_score`.
+ assert_allclose(results[key], error_score)
+
+
+def three_params_scorer(i, j, k):
+ return 3.4213
+
+
+@pytest.mark.parametrize(
+ "train_score, scorer, verbose, split_prg, cdt_prg, expected",
+ [
+ (
+ False,
+ three_params_scorer,
+ 2,
+ (1, 3),
+ (0, 1),
+ r"\[CV\] END ...................................................."
+ r" total time= 0.\ds",
+ ),
+ (
+ True,
+ {"sc1": three_params_scorer, "sc2": three_params_scorer},
+ 3,
+ (1, 3),
+ (0, 1),
+ r"\[CV 2/3\] END sc1: \(train=3.421, test=3.421\) sc2: "
+ r"\(train=3.421, test=3.421\) total time= 0.\ds",
+ ),
+ (
+ False,
+ {"sc1": three_params_scorer, "sc2": three_params_scorer},
+ 10,
+ (1, 3),
+ (0, 1),
+ r"\[CV 2/3; 1/1\] END ....... sc1: \(test=3.421\) sc2: \(test=3.421\)"
+ r" total time= 0.\ds",
+ ),
+ ],
+)
+def test_fit_and_score_verbosity(
+ capsys, train_score, scorer, verbose, split_prg, cdt_prg, expected
+):
+ X, y = make_classification(n_samples=30, random_state=0)
+ clf = SVC(kernel="linear", random_state=0)
+ train, test = next(ShuffleSplit().split(X))
+
+ # test print without train score
+ fit_and_score_args = [clf, X, y, scorer, train, test, verbose, None, None]
+ fit_and_score_kwargs = {
+ "return_train_score": train_score,
+ "split_progress": split_prg,
+ "candidate_progress": cdt_prg,
+ }
+ _fit_and_score(*fit_and_score_args, **fit_and_score_kwargs)
+ out, _ = capsys.readouterr()
+ outlines = out.split("\n")
+ if len(outlines) > 2:
+ assert re.match(expected, outlines[1])
+ else:
+ assert re.match(expected, outlines[0])
+
+
+def test_score():
+ error_message = "scoring must return a number, got None"
+
+ def two_params_scorer(estimator, X_test):
+ return None
+
+ fit_and_score_args = [None, None, None, two_params_scorer]
+ with pytest.raises(ValueError, match=error_message):
+ _score(*fit_and_score_args, error_score=np.nan)
+
+
+def test_callable_multimetric_confusion_matrix_cross_validate():
+ def custom_scorer(clf, X, y):
+ y_pred = clf.predict(X)
+ cm = confusion_matrix(y, y_pred)
+ return {"tn": cm[0, 0], "fp": cm[0, 1], "fn": cm[1, 0], "tp": cm[1, 1]}
+
+ X, y = make_classification(n_samples=40, n_features=4, random_state=42)
+ est = LinearSVC(random_state=42)
+ est.fit(X, y)
+ cv_results = cross_validate(est, X, y, cv=5, scoring=custom_scorer)
+
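+    # The callable returns a dict, so cross_validate exposes one test_<name>
+    # entry per score name.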
+ score_names = ["tn", "fp", "fn", "tp"]
+ for name in score_names:
+ assert "test_{}".format(name) in cv_results
+
+
+def test_learning_curve_partial_fit_regressors():
+ """Check that regressors with partial_fit is supported.
+
+ Non-regression test for #22981.
+ """
+ X, y = make_regression(random_state=42)
+
+ # Does not error
+ learning_curve(MLPRegressor(), X, y, exploit_incremental_learning=True, cv=2)
diff --git a/modin/pandas/test/interoperability/sklearn/neural_network/test_mlp.py b/modin/pandas/test/interoperability/sklearn/neural_network/test_mlp.py
new file mode 100644
index 00000000000..d4644c3b770
--- /dev/null
+++ b/modin/pandas/test/interoperability/sklearn/neural_network/test_mlp.py
@@ -0,0 +1,942 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+# flake8: noqa
+
+"""
+Testing for Multi-layer Perceptron module (sklearn.neural_network)
+"""
+
+# Author: Issam H. Laradji
+# License: BSD 3 clause
+
+import pytest
+import sys
+import warnings
+import re
+
+import numpy as np
+import joblib
+
+from numpy.testing import (
+ assert_almost_equal,
+ assert_array_equal,
+ assert_allclose,
+)
+
+from sklearn.datasets import load_digits, load_iris
+from sklearn.datasets import make_regression, make_multilabel_classification
+from sklearn.exceptions import ConvergenceWarning
+from io import StringIO
+from sklearn.metrics import roc_auc_score
+from sklearn.neural_network import MLPClassifier
+from sklearn.neural_network import MLPRegressor
+from sklearn.preprocessing import LabelBinarizer
+from sklearn.preprocessing import MinMaxScaler, scale
+from scipy.sparse import csr_matrix
+from sklearn.utils._testing import ignore_warnings
+
+
+ACTIVATION_TYPES = ["identity", "logistic", "tanh", "relu"]
+
+X_digits, y_digits = load_digits(n_class=3, return_X_y=True)
+
+X_digits_multi = MinMaxScaler().fit_transform(X_digits[:200])
+y_digits_multi = y_digits[:200]
+
+X_digits, y_digits = load_digits(n_class=2, return_X_y=True)
+
+X_digits_binary = MinMaxScaler().fit_transform(X_digits[:200])
+y_digits_binary = y_digits[:200]
+
+classification_datasets = [
+ (X_digits_multi, y_digits_multi),
+ (X_digits_binary, y_digits_binary),
+]
+
+X_reg, y_reg = make_regression(
+ n_samples=200, n_features=10, bias=20.0, noise=100.0, random_state=7
+)
+y_reg = scale(y_reg)
+regression_datasets = [(X_reg, y_reg)]
+
+iris = load_iris()
+
+X_iris = iris.data
+y_iris = iris.target
+
+
+def test_alpha():
+ # Test that larger alpha yields weights closer to zero
+ X = X_digits_binary[:100]
+ y = y_digits_binary[:100]
+
+ alpha_vectors = []
+ alpha_values = np.arange(2)
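+    # alpha is the L2 penalty strength: compare alpha=0 (no penalty) with alpha=1.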
+ absolute_sum = lambda x: np.sum(np.abs(x))
+
+ for alpha in alpha_values:
+ mlp = MLPClassifier(hidden_layer_sizes=10, alpha=alpha, random_state=1)
+ with ignore_warnings(category=ConvergenceWarning):
+ mlp.fit(X, y)
+ alpha_vectors.append(
+ np.array([absolute_sum(mlp.coefs_[0]), absolute_sum(mlp.coefs_[1])])
+ )
+
+ for i in range(len(alpha_values) - 1):
+ assert (alpha_vectors[i] > alpha_vectors[i + 1]).all()
+
+
+def test_fit():
+ # Test that the algorithm solution is equal to a worked out example.
+ X = np.array([[0.6, 0.8, 0.7]])
+ y = np.array([0])
+ mlp = MLPClassifier(
+ solver="sgd",
+ learning_rate_init=0.1,
+ alpha=0.1,
+ activation="logistic",
+ random_state=1,
+ max_iter=1,
+ hidden_layer_sizes=2,
+ momentum=0,
+ )
+ # set weights
+ mlp.coefs_ = [0] * 2
+ mlp.intercepts_ = [0] * 2
+ mlp.n_outputs_ = 1
+ mlp.coefs_[0] = np.array([[0.1, 0.2], [0.3, 0.1], [0.5, 0]])
+ mlp.coefs_[1] = np.array([[0.1], [0.2]])
+ mlp.intercepts_[0] = np.array([0.1, 0.1])
+ mlp.intercepts_[1] = np.array([1.0])
+ mlp._coef_grads = [] * 2
+ mlp._intercept_grads = [] * 2
+ mlp.n_features_in_ = 3
+
+ # Initialize parameters
+ mlp.n_iter_ = 0
+ mlp.learning_rate_ = 0.1
+
+ # Compute the number of layers
+ mlp.n_layers_ = 3
+
+ # Pre-allocate gradient matrices
+ mlp._coef_grads = [0] * (mlp.n_layers_ - 1)
+ mlp._intercept_grads = [0] * (mlp.n_layers_ - 1)
+
+ mlp.out_activation_ = "logistic"
+ mlp.t_ = 0
+ mlp.best_loss_ = np.inf
+ mlp.loss_curve_ = []
+ mlp._no_improvement_count = 0
+ mlp._intercept_velocity = [
+ np.zeros_like(intercepts) for intercepts in mlp.intercepts_
+ ]
+ mlp._coef_velocity = [np.zeros_like(coefs) for coefs in mlp.coefs_]
+
+ mlp.partial_fit(X, y, classes=[0, 1])
+ # Manually worked out example
+ # h1 = g(X1 * W_i1 + b11) = g(0.6 * 0.1 + 0.8 * 0.3 + 0.7 * 0.5 + 0.1)
+ # = 0.679178699175393
+ # h2 = g(X2 * W_i2 + b12) = g(0.6 * 0.2 + 0.8 * 0.1 + 0.7 * 0 + 0.1)
+ # = 0.574442516811659
+ # o1 = g(h * W2 + b21) = g(0.679 * 0.1 + 0.574 * 0.2 + 1)
+ # = 0.7654329236196236
+ # d21 = -(0 - 0.765) = 0.765
+ # d11 = (1 - 0.679) * 0.679 * 0.765 * 0.1 = 0.01667
+ # d12 = (1 - 0.574) * 0.574 * 0.765 * 0.2 = 0.0374
+ # W1grad11 = X1 * d11 + alpha * W11 = 0.6 * 0.01667 + 0.1 * 0.1 = 0.0200
+    # W1grad12 = X1 * d12 + alpha * W12 = 0.6 * 0.0374 + 0.1 * 0.2 = 0.04244
+ # W1grad21 = X2 * d11 + alpha * W13 = 0.8 * 0.01667 + 0.1 * 0.3 = 0.043336
+ # W1grad22 = X2 * d12 + alpha * W14 = 0.8 * 0.0374 + 0.1 * 0.1 = 0.03992
+ # W1grad31 = X3 * d11 + alpha * W15 = 0.6 * 0.01667 + 0.1 * 0.5 = 0.060002
+ # W1grad32 = X3 * d12 + alpha * W16 = 0.6 * 0.0374 + 0.1 * 0 = 0.02244
+ # W2grad1 = h1 * d21 + alpha * W21 = 0.679 * 0.765 + 0.1 * 0.1 = 0.5294
+ # W2grad2 = h2 * d21 + alpha * W22 = 0.574 * 0.765 + 0.1 * 0.2 = 0.45911
+ # b1grad1 = d11 = 0.01667
+ # b1grad2 = d12 = 0.0374
+ # b2grad = d21 = 0.765
+ # W1 = W1 - eta * [W1grad11, .., W1grad32] = [[0.1, 0.2], [0.3, 0.1],
+ # [0.5, 0]] - 0.1 * [[0.0200, 0.04244], [0.043336, 0.03992],
+ # [0.060002, 0.02244]] = [[0.098, 0.195756], [0.2956664,
+ # 0.096008], [0.4939998, -0.002244]]
+ # W2 = W2 - eta * [W2grad1, W2grad2] = [[0.1], [0.2]] - 0.1 *
+ # [[0.5294], [0.45911]] = [[0.04706], [0.154089]]
+ # b1 = b1 - eta * [b1grad1, b1grad2] = 0.1 - 0.1 * [0.01667, 0.0374]
+ # = [0.098333, 0.09626]
+ # b2 = b2 - eta * b2grad = 1.0 - 0.1 * 0.765 = 0.9235
+ assert_almost_equal(
+ mlp.coefs_[0],
+ np.array([[0.098, 0.195756], [0.2956664, 0.096008], [0.4939998, -0.002244]]),
+ decimal=3,
+ )
+ assert_almost_equal(mlp.coefs_[1], np.array([[0.04706], [0.154089]]), decimal=3)
+ assert_almost_equal(mlp.intercepts_[0], np.array([0.098333, 0.09626]), decimal=3)
+ assert_almost_equal(mlp.intercepts_[1], np.array(0.9235), decimal=3)
+ # Testing output
+ # h1 = g(X1 * W_i1 + b11) = g(0.6 * 0.098 + 0.8 * 0.2956664 +
+ # 0.7 * 0.4939998 + 0.098333) = 0.677
+ # h2 = g(X2 * W_i2 + b12) = g(0.6 * 0.195756 + 0.8 * 0.096008 +
+ # 0.7 * -0.002244 + 0.09626) = 0.572
+ # o1 = h * W2 + b21 = 0.677 * 0.04706 +
+ # 0.572 * 0.154089 + 0.9235 = 1.043
+ # prob = sigmoid(o1) = 0.739
+ assert_almost_equal(mlp.predict_proba(X)[0, 1], 0.739, decimal=3)
+
+
+def test_gradient():
+ # Test gradient.
+
+ # This makes sure that the activation functions and their derivatives
+ # are correct. The numerical and analytical computation of the gradient
+ # should be close.
+ for n_labels in [2, 3]:
+ n_samples = 5
+ n_features = 10
+ random_state = np.random.RandomState(seed=42)
+ X = random_state.rand(n_samples, n_features)
+ y = 1 + np.mod(np.arange(n_samples) + 1, n_labels)
+ Y = LabelBinarizer().fit_transform(y)
+
+ for activation in ACTIVATION_TYPES:
+ mlp = MLPClassifier(
+ activation=activation,
+ hidden_layer_sizes=10,
+ solver="lbfgs",
+ alpha=1e-5,
+ learning_rate_init=0.2,
+ max_iter=1,
+ random_state=1,
+ )
+ mlp.fit(X, y)
+
+ theta = np.hstack([l.ravel() for l in mlp.coefs_ + mlp.intercepts_])
+
+ layer_units = [X.shape[1]] + [mlp.hidden_layer_sizes] + [mlp.n_outputs_]
+
+ activations = []
+ deltas = []
+ coef_grads = []
+ intercept_grads = []
+
+ activations.append(X)
+ for i in range(mlp.n_layers_ - 1):
+ activations.append(np.empty((X.shape[0], layer_units[i + 1])))
+ deltas.append(np.empty((X.shape[0], layer_units[i + 1])))
+
+ fan_in = layer_units[i]
+ fan_out = layer_units[i + 1]
+ coef_grads.append(np.empty((fan_in, fan_out)))
+ intercept_grads.append(np.empty(fan_out))
+
+ # analytically compute the gradients
+ def loss_grad_fun(t):
+ return mlp._loss_grad_lbfgs(
+ t, X, Y, activations, deltas, coef_grads, intercept_grads
+ )
+
+ [value, grad] = loss_grad_fun(theta)
+ numgrad = np.zeros(np.size(theta))
+ n = np.size(theta, 0)
+ E = np.eye(n)
+ epsilon = 1e-5
+ # numerically compute the gradients
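+            # central finite differences:
+            # df/dtheta_i ~ (f(theta + eps * e_i) - f(theta - eps * e_i)) / (2 * eps)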
+ for i in range(n):
+ dtheta = E[:, i] * epsilon
+ numgrad[i] = (
+ loss_grad_fun(theta + dtheta)[0] - loss_grad_fun(theta - dtheta)[0]
+ ) / (epsilon * 2.0)
+ assert_almost_equal(numgrad, grad)
+
+
+@pytest.mark.parametrize("X,y", classification_datasets)
+def test_lbfgs_classification(X, y):
+ # Test lbfgs on classification.
+ # It should achieve a score higher than 0.95 for the binary and multi-class
+ # versions of the digits dataset.
+ X_train = X[:150]
+ y_train = y[:150]
+ X_test = X[150:]
+ expected_shape_dtype = (X_test.shape[0], y_train.dtype.kind)
+
+ for activation in ACTIVATION_TYPES:
+ mlp = MLPClassifier(
+ solver="lbfgs",
+ hidden_layer_sizes=50,
+ max_iter=150,
+ shuffle=True,
+ random_state=1,
+ activation=activation,
+ )
+ mlp.fit(X_train, y_train)
+ y_predict = mlp.predict(X_test)
+ assert mlp.score(X_train, y_train) > 0.95
+ assert (y_predict.shape[0], y_predict.dtype.kind) == expected_shape_dtype
+
+
+@pytest.mark.parametrize("X,y", regression_datasets)
+def test_lbfgs_regression(X, y):
+ # Test lbfgs on the regression dataset.
+ for activation in ACTIVATION_TYPES:
+ mlp = MLPRegressor(
+ solver="lbfgs",
+ hidden_layer_sizes=50,
+ max_iter=150,
+ shuffle=True,
+ random_state=1,
+ activation=activation,
+ )
+ mlp.fit(X, y)
+ if activation == "identity":
+ assert mlp.score(X, y) > 0.80
+ else:
+            # Non-linear models perform much better than a linear bottleneck:
+ assert mlp.score(X, y) > 0.98
+
+
+@pytest.mark.parametrize("X,y", classification_datasets)
+def test_lbfgs_classification_maxfun(X, y):
+ # Test lbfgs parameter max_fun.
+ # It should independently limit the number of iterations for lbfgs.
+ max_fun = 10
+ # classification tests
+ for activation in ACTIVATION_TYPES:
+ mlp = MLPClassifier(
+ solver="lbfgs",
+ hidden_layer_sizes=50,
+ max_iter=150,
+ max_fun=max_fun,
+ shuffle=True,
+ random_state=1,
+ activation=activation,
+ )
+ with pytest.warns(ConvergenceWarning):
+ mlp.fit(X, y)
+ assert max_fun >= mlp.n_iter_
+
+
+@pytest.mark.parametrize("X,y", regression_datasets)
+def test_lbfgs_regression_maxfun(X, y):
+ # Test lbfgs parameter max_fun.
+ # It should independently limit the number of iterations for lbfgs.
+ max_fun = 10
+ # regression tests
+ for activation in ACTIVATION_TYPES:
+ mlp = MLPRegressor(
+ solver="lbfgs",
+ hidden_layer_sizes=50,
+ tol=0.0,
+ max_iter=150,
+ max_fun=max_fun,
+ shuffle=True,
+ random_state=1,
+ activation=activation,
+ )
+ with pytest.warns(ConvergenceWarning):
+ mlp.fit(X, y)
+ assert max_fun >= mlp.n_iter_
+
+
+def test_learning_rate_warmstart():
+    # Tests that warm_start reuses past solutions.
+ X = [[3, 2], [1, 6], [5, 6], [-2, -4]]
+ y = [1, 1, 1, 0]
+ for learning_rate in ["invscaling", "constant"]:
+ mlp = MLPClassifier(
+ solver="sgd",
+ hidden_layer_sizes=4,
+ learning_rate=learning_rate,
+ max_iter=1,
+ power_t=0.25,
+ warm_start=True,
+ )
+ with ignore_warnings(category=ConvergenceWarning):
+ mlp.fit(X, y)
+ prev_eta = mlp._optimizer.learning_rate
+ mlp.fit(X, y)
+ post_eta = mlp._optimizer.learning_rate
+
+ if learning_rate == "constant":
+ assert prev_eta == post_eta
+ elif learning_rate == "invscaling":
+ assert mlp.learning_rate_init / pow(8 + 1, mlp.power_t) == post_eta
+
+
+def test_multilabel_classification():
+ # Test that multi-label classification works as expected.
+ # test fit method
+ X, y = make_multilabel_classification(
+ n_samples=50, random_state=0, return_indicator=True
+ )
+ mlp = MLPClassifier(
+ solver="lbfgs",
+ hidden_layer_sizes=50,
+ alpha=1e-5,
+ max_iter=150,
+ random_state=0,
+ activation="logistic",
+ learning_rate_init=0.2,
+ )
+ mlp.fit(X, y)
+ assert mlp.score(X, y) > 0.97
+
+ # test partial fit method
+ mlp = MLPClassifier(
+ solver="sgd",
+ hidden_layer_sizes=50,
+ max_iter=150,
+ random_state=0,
+ activation="logistic",
+ alpha=1e-5,
+ learning_rate_init=0.2,
+ )
+ for i in range(100):
+ mlp.partial_fit(X, y, classes=[0, 1, 2, 3, 4])
+ assert mlp.score(X, y) > 0.9
+
+    # Make sure early stopping still works now that splitting is stratified by
+    # default (it is disabled for multilabel classification)
+ mlp = MLPClassifier(early_stopping=True)
+ mlp.fit(X, y).predict(X)
+
+
+def test_multioutput_regression():
+ # Test that multi-output regression works as expected
+ X, y = make_regression(n_samples=200, n_targets=5)
+ mlp = MLPRegressor(
+ solver="lbfgs", hidden_layer_sizes=50, max_iter=200, random_state=1
+ )
+ mlp.fit(X, y)
+ assert mlp.score(X, y) > 0.9
+
+
+def test_partial_fit_classes_error():
+ # Tests that passing different classes to partial_fit raises an error
+ X = [[3, 2]]
+ y = [0]
+ clf = MLPClassifier(solver="sgd")
+ clf.partial_fit(X, y, classes=[0, 1])
+ with pytest.raises(ValueError):
+ clf.partial_fit(X, y, classes=[1, 2])
+
+
+def test_partial_fit_classification():
+ # Test partial_fit on classification.
+ # `partial_fit` should yield the same results as 'fit' for binary and
+ # multi-class classification.
+ for X, y in classification_datasets:
+ mlp = MLPClassifier(
+ solver="sgd",
+ max_iter=100,
+ random_state=1,
+ tol=0,
+ alpha=1e-5,
+ learning_rate_init=0.2,
+ )
+
+ with ignore_warnings(category=ConvergenceWarning):
+ mlp.fit(X, y)
+ pred1 = mlp.predict(X)
+ mlp = MLPClassifier(
+ solver="sgd", random_state=1, alpha=1e-5, learning_rate_init=0.2
+ )
+ for i in range(100):
+ mlp.partial_fit(X, y, classes=np.unique(y))
+ pred2 = mlp.predict(X)
+ assert_array_equal(pred1, pred2)
+ assert mlp.score(X, y) > 0.95
+
+
+def test_partial_fit_unseen_classes():
+    # Non-regression test for bug 6994
+ # Tests for labeling errors in partial fit
+
+ clf = MLPClassifier(random_state=0)
+ clf.partial_fit([[1], [2], [3]], ["a", "b", "c"], classes=["a", "b", "c", "d"])
+ clf.partial_fit([[4]], ["d"])
+ assert clf.score([[1], [2], [3], [4]], ["a", "b", "c", "d"]) > 0
+
+
+def test_partial_fit_regression():
+ # Test partial_fit on regression.
+ # `partial_fit` should yield the same results as 'fit' for regression.
+ X = X_reg
+ y = y_reg
+
+ for momentum in [0, 0.9]:
+ mlp = MLPRegressor(
+ solver="sgd",
+ max_iter=100,
+ activation="relu",
+ random_state=1,
+ learning_rate_init=0.01,
+ batch_size=X.shape[0],
+ momentum=momentum,
+ )
+ with warnings.catch_warnings(record=True):
+ # catch convergence warning
+ mlp.fit(X, y)
+ pred1 = mlp.predict(X)
+ mlp = MLPRegressor(
+ solver="sgd",
+ activation="relu",
+ learning_rate_init=0.01,
+ random_state=1,
+ batch_size=X.shape[0],
+ momentum=momentum,
+ )
+ for i in range(100):
+ mlp.partial_fit(X, y)
+
+ pred2 = mlp.predict(X)
+ assert_allclose(pred1, pred2)
+ score = mlp.score(X, y)
+ assert score > 0.65
+
+
+def test_partial_fit_errors():
+ # Test partial_fit error handling.
+ X = [[3, 2], [1, 6]]
+ y = [1, 0]
+
+ # no classes passed
+ with pytest.raises(ValueError):
+ MLPClassifier(solver="sgd").partial_fit(X, y, classes=[2])
+
+ # lbfgs doesn't support partial_fit
+ assert not hasattr(MLPClassifier(solver="lbfgs"), "partial_fit")
+
+
+def test_nonfinite_params():
+ # Check that MLPRegressor throws ValueError when dealing with non-finite
+ # parameter values
+ rng = np.random.RandomState(0)
+ n_samples = 10
+ fmax = np.finfo(np.float64).max
+ X = fmax * rng.uniform(size=(n_samples, 2))
+ y = rng.standard_normal(size=n_samples)
+
+ clf = MLPRegressor()
+ msg = (
+ "Solver produced non-finite parameter weights. The input data may contain large"
+ " values and need to be preprocessed."
+ )
+ with pytest.raises(ValueError, match=msg):
+ clf.fit(X, y)
+
+
+def test_predict_proba_binary():
+ # Test that predict_proba works as expected for binary class.
+ X = X_digits_binary[:50]
+ y = y_digits_binary[:50]
+
+ clf = MLPClassifier(hidden_layer_sizes=5, activation="logistic", random_state=1)
+ with ignore_warnings(category=ConvergenceWarning):
+ clf.fit(X, y)
+ y_proba = clf.predict_proba(X)
+ y_log_proba = clf.predict_log_proba(X)
+
+ (n_samples, n_classes) = y.shape[0], 2
+
+ proba_max = y_proba.argmax(axis=1)
+ proba_log_max = y_log_proba.argmax(axis=1)
+
+ assert y_proba.shape == (n_samples, n_classes)
+ assert_array_equal(proba_max, proba_log_max)
+ assert_allclose(y_log_proba, np.log(y_proba))
+
+ assert roc_auc_score(y, y_proba[:, 1]) == 1.0
+
+
+def test_predict_proba_multiclass():
+ # Test that predict_proba works as expected for multi class.
+ X = X_digits_multi[:10]
+ y = y_digits_multi[:10]
+
+ clf = MLPClassifier(hidden_layer_sizes=5)
+ with ignore_warnings(category=ConvergenceWarning):
+ clf.fit(X, y)
+ y_proba = clf.predict_proba(X)
+ y_log_proba = clf.predict_log_proba(X)
+
+ (n_samples, n_classes) = y.shape[0], np.unique(y).size
+
+ proba_max = y_proba.argmax(axis=1)
+ proba_log_max = y_log_proba.argmax(axis=1)
+
+ assert y_proba.shape == (n_samples, n_classes)
+ assert_array_equal(proba_max, proba_log_max)
+ assert_allclose(y_log_proba, np.log(y_proba))
+
+
+def test_predict_proba_multilabel():
+ # Test that predict_proba works as expected for multilabel.
+    # Multilabel should not use softmax, which would make the probabilities sum to 1
+ X, Y = make_multilabel_classification(
+ n_samples=50, random_state=0, return_indicator=True
+ )
+ n_samples, n_classes = Y.shape
+
+ clf = MLPClassifier(solver="lbfgs", hidden_layer_sizes=30, random_state=0)
+ clf.fit(X, Y)
+ y_proba = clf.predict_proba(X)
+
+ assert y_proba.shape == (n_samples, n_classes)
+ assert_array_equal(y_proba > 0.5, Y)
+
+ y_log_proba = clf.predict_log_proba(X)
+ proba_max = y_proba.argmax(axis=1)
+ proba_log_max = y_log_proba.argmax(axis=1)
+
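+    # The squared deviation of the row sums from 1 is non-negligible, confirming
+    # that probabilities are not softmax-normalised across labels.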
+ assert (y_proba.sum(1) - 1).dot(y_proba.sum(1) - 1) > 1e-10
+ assert_array_equal(proba_max, proba_log_max)
+ assert_allclose(y_log_proba, np.log(y_proba))
+
+
+def test_shuffle():
+ # Test that the shuffle parameter affects the training process (it should)
+ X, y = make_regression(n_samples=50, n_features=5, n_targets=1, random_state=0)
+
+ # The coefficients will be identical if both do or do not shuffle
+ for shuffle in [True, False]:
+ mlp1 = MLPRegressor(
+ hidden_layer_sizes=1,
+ max_iter=1,
+ batch_size=1,
+ random_state=0,
+ shuffle=shuffle,
+ )
+ mlp2 = MLPRegressor(
+ hidden_layer_sizes=1,
+ max_iter=1,
+ batch_size=1,
+ random_state=0,
+ shuffle=shuffle,
+ )
+ mlp1.fit(X, y)
+ mlp2.fit(X, y)
+
+ assert np.array_equal(mlp1.coefs_[0], mlp2.coefs_[0])
+
+ # The coefficients will be slightly different if shuffle=True
+ mlp1 = MLPRegressor(
+ hidden_layer_sizes=1, max_iter=1, batch_size=1, random_state=0, shuffle=True
+ )
+ mlp2 = MLPRegressor(
+ hidden_layer_sizes=1, max_iter=1, batch_size=1, random_state=0, shuffle=False
+ )
+ mlp1.fit(X, y)
+ mlp2.fit(X, y)
+
+ assert not np.array_equal(mlp1.coefs_[0], mlp2.coefs_[0])
+
+
+def test_sparse_matrices():
+ # Test that sparse and dense input matrices output the same results.
+ X = X_digits_binary[:50]
+ y = y_digits_binary[:50]
+ X_sparse = csr_matrix(X)
+ mlp = MLPClassifier(solver="lbfgs", hidden_layer_sizes=15, random_state=1)
+ mlp.fit(X, y)
+ pred1 = mlp.predict(X)
+ mlp.fit(X_sparse, y)
+ pred2 = mlp.predict(X_sparse)
+ assert_almost_equal(pred1, pred2)
+ pred1 = mlp.predict(X)
+ pred2 = mlp.predict(X_sparse)
+ assert_array_equal(pred1, pred2)
+
+
+def test_tolerance():
+ # Test tolerance.
+ # It should force the solver to exit the loop when it converges.
+ X = [[3, 2], [1, 6]]
+ y = [1, 0]
+ clf = MLPClassifier(tol=0.5, max_iter=3000, solver="sgd")
+ clf.fit(X, y)
+ assert clf.max_iter > clf.n_iter_
+
+
+def test_verbose_sgd():
+ # Test verbose.
+ X = [[3, 2], [1, 6]]
+ y = [1, 0]
+ clf = MLPClassifier(solver="sgd", max_iter=2, verbose=10, hidden_layer_sizes=2)
+ old_stdout = sys.stdout
+ sys.stdout = output = StringIO()
+
+ with ignore_warnings(category=ConvergenceWarning):
+ clf.fit(X, y)
+ clf.partial_fit(X, y)
+
+ sys.stdout = old_stdout
+ assert "Iteration" in output.getvalue()
+
+
+@pytest.mark.parametrize("MLPEstimator", [MLPClassifier, MLPRegressor])
+def test_early_stopping(MLPEstimator):
+ X = X_digits_binary[:100]
+ y = y_digits_binary[:100]
+ tol = 0.2
+ mlp_estimator = MLPEstimator(
+ tol=tol, max_iter=3000, solver="sgd", early_stopping=True
+ )
+ mlp_estimator.fit(X, y)
+ assert mlp_estimator.max_iter > mlp_estimator.n_iter_
+
+ assert mlp_estimator.best_loss_ is None
+ assert isinstance(mlp_estimator.validation_scores_, list)
+
+ valid_scores = mlp_estimator.validation_scores_
+ best_valid_score = mlp_estimator.best_validation_score_
+ assert max(valid_scores) == best_valid_score
+ assert best_valid_score + tol > valid_scores[-2]
+ assert best_valid_score + tol > valid_scores[-1]
+
+ # check that the attributes `validation_scores_` and `best_validation_score_`
+ # are set to None when `early_stopping=False`
+ mlp_estimator = MLPEstimator(
+ tol=tol, max_iter=3000, solver="sgd", early_stopping=False
+ )
+ mlp_estimator.fit(X, y)
+ assert mlp_estimator.validation_scores_ is None
+ assert mlp_estimator.best_validation_score_ is None
+ assert mlp_estimator.best_loss_ is not None
+
+
+def test_adaptive_learning_rate():
+ X = [[3, 2], [1, 6]]
+ y = [1, 0]
+ clf = MLPClassifier(tol=0.5, max_iter=3000, solver="sgd", learning_rate="adaptive")
+ clf.fit(X, y)
+ assert clf.max_iter > clf.n_iter_
+ assert 1e-6 > clf._optimizer.learning_rate
+
+
+@ignore_warnings(category=RuntimeWarning)
+def test_warm_start():
+ X = X_iris
+ y = y_iris
+
+ y_2classes = np.array([0] * 75 + [1] * 75)
+ y_3classes = np.array([0] * 40 + [1] * 40 + [2] * 70)
+ y_3classes_alt = np.array([0] * 50 + [1] * 50 + [3] * 50)
+ y_4classes = np.array([0] * 37 + [1] * 37 + [2] * 38 + [3] * 38)
+ y_5classes = np.array([0] * 30 + [1] * 30 + [2] * 30 + [3] * 30 + [4] * 30)
+
+ # No error raised
+ clf = MLPClassifier(hidden_layer_sizes=2, solver="lbfgs", warm_start=True).fit(X, y)
+ clf.fit(X, y)
+ clf.fit(X, y_3classes)
+
+ for y_i in (y_2classes, y_3classes_alt, y_4classes, y_5classes):
+ clf = MLPClassifier(hidden_layer_sizes=2, solver="lbfgs", warm_start=True).fit(
+ X, y
+ )
+ message = (
+ "warm_start can only be used where `y` has the same "
+ "classes as in the previous call to fit."
+ " Previously got [0 1 2], `y` has %s" % np.unique(y_i)
+ )
+ with pytest.raises(ValueError, match=re.escape(message)):
+ clf.fit(X, y_i)
+
+
+@pytest.mark.parametrize("MLPEstimator", [MLPClassifier, MLPRegressor])
+def test_warm_start_full_iteration(MLPEstimator):
+ # Non-regression test for:
+ # https://github.com/scikit-learn/scikit-learn/issues/16812
+    # Check that the MLP estimator accomplishes `max_iter` with a
+    # warm-started estimator.
+ X, y = X_iris, y_iris
+ max_iter = 3
+ clf = MLPEstimator(
+ hidden_layer_sizes=2, solver="sgd", warm_start=True, max_iter=max_iter
+ )
+ clf.fit(X, y)
+ assert max_iter == clf.n_iter_
+ clf.fit(X, y)
+ assert 2 * max_iter == clf.n_iter_
+
+
+def test_n_iter_no_change():
+ # test n_iter_no_change using binary data set
+    # the classification fitting process is not prone to loss-curve fluctuations
+ X = X_digits_binary[:100]
+ y = y_digits_binary[:100]
+ tol = 0.01
+ max_iter = 3000
+
+ # test multiple n_iter_no_change
+ for n_iter_no_change in [2, 5, 10, 50, 100]:
+ clf = MLPClassifier(
+ tol=tol, max_iter=max_iter, solver="sgd", n_iter_no_change=n_iter_no_change
+ )
+ clf.fit(X, y)
+
+ # validate n_iter_no_change
+ assert clf._no_improvement_count == n_iter_no_change + 1
+ assert max_iter > clf.n_iter_
+
+
+@ignore_warnings(category=ConvergenceWarning)
+def test_n_iter_no_change_inf():
+ # test n_iter_no_change using binary data set
+ # the fitting process should go to max_iter iterations
+ X = X_digits_binary[:100]
+ y = y_digits_binary[:100]
+
+ # set a ridiculous tolerance
+ # this should always trigger _update_no_improvement_count()
+ tol = 1e9
+
+ # fit
+ n_iter_no_change = np.inf
+ max_iter = 3000
+ clf = MLPClassifier(
+ tol=tol, max_iter=max_iter, solver="sgd", n_iter_no_change=n_iter_no_change
+ )
+ clf.fit(X, y)
+
+ # validate n_iter_no_change doesn't cause early stopping
+ assert clf.n_iter_ == max_iter
+
+ # validate _update_no_improvement_count() was always triggered
+ assert clf._no_improvement_count == clf.n_iter_ - 1
+
+
+def test_early_stopping_stratified():
+ # Make sure data splitting for early stopping is stratified
+ X = [[1, 2], [2, 3], [3, 4], [4, 5]]
+ y = [0, 0, 0, 1]
+
+ mlp = MLPClassifier(early_stopping=True)
+ with pytest.raises(
+ ValueError, match="The least populated class in y has only 1 member"
+ ):
+ mlp.fit(X, y)
+
+
+def test_mlp_classifier_dtypes_casting():
+ # Compare predictions for different dtypes
+ mlp_64 = MLPClassifier(
+ alpha=1e-5, hidden_layer_sizes=(5, 3), random_state=1, max_iter=50
+ )
+ mlp_64.fit(X_digits[:300], y_digits[:300])
+ pred_64 = mlp_64.predict(X_digits[300:])
+ proba_64 = mlp_64.predict_proba(X_digits[300:])
+
+ mlp_32 = MLPClassifier(
+ alpha=1e-5, hidden_layer_sizes=(5, 3), random_state=1, max_iter=50
+ )
+ mlp_32.fit(X_digits[:300].astype(np.float32), y_digits[:300])
+ pred_32 = mlp_32.predict(X_digits[300:].astype(np.float32))
+ proba_32 = mlp_32.predict_proba(X_digits[300:].astype(np.float32))
+
+ assert_array_equal(pred_64, pred_32)
+ assert_allclose(proba_64, proba_32, rtol=1e-02)
+
+
+def test_mlp_regressor_dtypes_casting():
+ mlp_64 = MLPRegressor(
+ alpha=1e-5, hidden_layer_sizes=(5, 3), random_state=1, max_iter=50
+ )
+ mlp_64.fit(X_digits[:300], y_digits[:300])
+ pred_64 = mlp_64.predict(X_digits[300:])
+
+ mlp_32 = MLPRegressor(
+ alpha=1e-5, hidden_layer_sizes=(5, 3), random_state=1, max_iter=50
+ )
+ mlp_32.fit(X_digits[:300].astype(np.float32), y_digits[:300])
+ pred_32 = mlp_32.predict(X_digits[300:].astype(np.float32))
+
+ assert_allclose(pred_64, pred_32, rtol=1e-04)
+
+
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+@pytest.mark.parametrize("Estimator", [MLPClassifier, MLPRegressor])
+def test_mlp_param_dtypes(dtype, Estimator):
+ # Checks if input dtype is used for network parameters
+ # and predictions
+ X, y = X_digits.astype(dtype), y_digits
+ mlp = Estimator(alpha=1e-5, hidden_layer_sizes=(5, 3), random_state=1, max_iter=50)
+ mlp.fit(X[:300], y[:300])
+ pred = mlp.predict(X[300:])
+
+ assert all([intercept.dtype == dtype for intercept in mlp.intercepts_])
+
+ assert all([coef.dtype == dtype for coef in mlp.coefs_])
+
+ if Estimator == MLPRegressor:
+ assert pred.dtype == dtype
+
+
+def test_mlp_loading_from_joblib_partial_fit(tmp_path):
+ """Loading from MLP and partial fitting updates weights. Non-regression
+ test for #19626."""
+ pre_trained_estimator = MLPRegressor(
+ hidden_layer_sizes=(42,), random_state=42, learning_rate_init=0.01, max_iter=200
+ )
+ features, target = [[2]], [4]
+
+ # Fit on x=2, y=4
+ pre_trained_estimator.fit(features, target)
+
+ # dump and load model
+ pickled_file = tmp_path / "mlp.pkl"
+ joblib.dump(pre_trained_estimator, pickled_file)
+ load_estimator = joblib.load(pickled_file)
+
+    # Train for more epochs on the point x=2, y=1
+ fine_tune_features, fine_tune_target = [[2]], [1]
+
+ for _ in range(200):
+ load_estimator.partial_fit(fine_tune_features, fine_tune_target)
+
+    # the fine-tuned model learned the new target
+ predicted_value = load_estimator.predict(fine_tune_features)
+ assert_allclose(predicted_value, fine_tune_target, rtol=1e-4)
+
+
+@pytest.mark.parametrize("Estimator", [MLPClassifier, MLPRegressor])
+def test_preserve_feature_names(Estimator):
+ """Check that feature names are preserved when early stopping is enabled.
+
+ Feature names are required for consistency checks during scoring.
+
+ Non-regression test for gh-24846
+ """
+ pd = pytest.importorskip("modin.pandas")
+ rng = np.random.RandomState(0)
+
+ X = pd.DataFrame(data=rng.randn(10, 2), columns=["colname_a", "colname_b"])
+ y = pd.Series(data=np.full(10, 1), name="colname_y")
+
+ model = Estimator(early_stopping=True, validation_fraction=0.2)
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("error", UserWarning)
+ model.fit(X, y)
+
+
+@pytest.mark.parametrize("MLPEstimator", [MLPClassifier, MLPRegressor])
+def test_mlp_warm_start_with_early_stopping(MLPEstimator):
+ """Check that early stopping works with warm start."""
+ mlp = MLPEstimator(
+ max_iter=10, random_state=0, warm_start=True, early_stopping=True
+ )
+ mlp.fit(X_iris, y_iris)
+ n_validation_scores = len(mlp.validation_scores_)
+ mlp.set_params(max_iter=20)
+ mlp.fit(X_iris, y_iris)
+ assert len(mlp.validation_scores_) > n_validation_scores
diff --git a/modin/pandas/test/interoperability/sklearn/tests/test_base.py b/modin/pandas/test/interoperability/sklearn/tests/test_base.py
new file mode 100644
index 00000000000..09e1d232079
--- /dev/null
+++ b/modin/pandas/test/interoperability/sklearn/tests/test_base.py
@@ -0,0 +1,773 @@
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+# Author: Gael Varoquaux
+# License: BSD 3 clause
+
+# flake8: noqa
+
+import re
+import numpy as np
+import scipy.sparse as sp
+import pytest
+import warnings
+from numpy.testing import assert_allclose
+
+import sklearn
+from sklearn.utils._testing import assert_array_equal
+from sklearn.utils._testing import assert_no_warnings
+from sklearn.utils._testing import ignore_warnings
+
+from sklearn.base import BaseEstimator, clone, is_classifier
+from sklearn.svm import SVC
+from sklearn.preprocessing import StandardScaler
+from sklearn.utils._set_output import _get_output_config
+from sklearn.pipeline import Pipeline
+from sklearn.decomposition import PCA
+from sklearn.model_selection import GridSearchCV
+
+from sklearn.tree import DecisionTreeClassifier
+from sklearn.tree import DecisionTreeRegressor
+from sklearn import datasets
+
+from sklearn.base import TransformerMixin
+from sklearn.utils._mocking import MockDataFrame
+from sklearn import config_context
+import pickle
+
+
+#############################################################################
+# A few test classes
+class MyEstimator(BaseEstimator):
+ def __init__(self, l1=0, empty=None):
+ self.l1 = l1
+ self.empty = empty
+
+
+class K(BaseEstimator):
+ def __init__(self, c=None, d=None):
+ self.c = c
+ self.d = d
+
+
+class T(BaseEstimator):
+ def __init__(self, a=None, b=None):
+ self.a = a
+ self.b = b
+
+
+class NaNTag(BaseEstimator):
+ def _more_tags(self):
+ return {"allow_nan": True}
+
+
+class NoNaNTag(BaseEstimator):
+ def _more_tags(self):
+ return {"allow_nan": False}
+
+
+class OverrideTag(NaNTag):
+ def _more_tags(self):
+ return {"allow_nan": False}
+
+
+class DiamondOverwriteTag(NaNTag, NoNaNTag):
+ def _more_tags(self):
+ return dict()
+
+
+class InheritDiamondOverwriteTag(DiamondOverwriteTag):
+ pass
+
+
+class ModifyInitParams(BaseEstimator):
+ """Deprecated behavior.
+ Equal parameters but with a type cast.
+    Does not fulfill the `a is a` identity check.
+ """
+
+ def __init__(self, a=np.array([0])):
+ self.a = a.copy()
+
+
+class Buggy(BaseEstimator):
+ "A buggy estimator that does not set its parameters right."
+
+ def __init__(self, a=None):
+ self.a = 1
+
+
+class NoEstimator:
+ def __init__(self):
+ pass
+
+ def fit(self, X=None, y=None):
+ return self
+
+ def predict(self, X=None):
+ return None
+
+
+class VargEstimator(BaseEstimator):
+ """scikit-learn estimators shouldn't have vargs."""
+
+ def __init__(self, *vargs):
+ pass
+
+
+#############################################################################
+# The tests
+
+
+def test_clone():
+ # Tests that clone creates a correct deep copy.
+ # We create an estimator, make a copy of its original state
+ # (which, in this case, is the current state of the estimator),
+ # and check that the obtained copy is a correct deep copy.
+
+ from sklearn.feature_selection import SelectFpr, f_classif
+
+ selector = SelectFpr(f_classif, alpha=0.1)
+ new_selector = clone(selector)
+ assert selector is not new_selector
+ assert selector.get_params() == new_selector.get_params()
+
+ selector = SelectFpr(f_classif, alpha=np.zeros((10, 2)))
+ new_selector = clone(selector)
+ assert selector is not new_selector
+
+
+def test_clone_2():
+    # Tests that clone doesn't copy everything: an attribute set on the
+    # original estimator after construction must not appear on the clone.
+
+ from sklearn.feature_selection import SelectFpr, f_classif
+
+ selector = SelectFpr(f_classif, alpha=0.1)
+ selector.own_attribute = "test"
+ new_selector = clone(selector)
+ assert not hasattr(new_selector, "own_attribute")
+
+
+def test_clone_buggy():
+ # Check that clone raises an error on buggy estimators.
+ buggy = Buggy()
+ buggy.a = 2
+ with pytest.raises(RuntimeError):
+ clone(buggy)
+
+ no_estimator = NoEstimator()
+ with pytest.raises(TypeError):
+ clone(no_estimator)
+
+ varg_est = VargEstimator()
+ with pytest.raises(RuntimeError):
+ clone(varg_est)
+
+ est = ModifyInitParams()
+ with pytest.raises(RuntimeError):
+ clone(est)
+
+
+def test_clone_empty_array():
+ # Regression test for cloning estimators with empty arrays
+ clf = MyEstimator(empty=np.array([]))
+ clf2 = clone(clf)
+ assert_array_equal(clf.empty, clf2.empty)
+
+ clf = MyEstimator(empty=sp.csr_matrix(np.array([[0]])))
+ clf2 = clone(clf)
+ assert_array_equal(clf.empty.data, clf2.empty.data)
+
+
+def test_clone_nan():
+ # Regression test for cloning estimators with default parameter as np.nan
+ clf = MyEstimator(empty=np.nan)
+ clf2 = clone(clf)
+
+ assert clf.empty is clf2.empty
+
+
+def test_clone_sparse_matrices():
+ sparse_matrix_classes = [
+ getattr(sp, name) for name in dir(sp) if name.endswith("_matrix")
+ ]
+
+ for cls in sparse_matrix_classes:
+ sparse_matrix = cls(np.eye(5))
+ clf = MyEstimator(empty=sparse_matrix)
+ clf_cloned = clone(clf)
+ assert clf.empty.__class__ is clf_cloned.empty.__class__
+ assert_array_equal(clf.empty.toarray(), clf_cloned.empty.toarray())
+
+
+def test_clone_estimator_types():
+ # Check that clone works for parameters that are types rather than
+ # instances
+ clf = MyEstimator(empty=MyEstimator)
+ clf2 = clone(clf)
+
+ assert clf.empty is clf2.empty
+
+
+def test_clone_class_rather_than_instance():
+ # Check that clone raises expected error message when
+ # cloning class rather than instance
+ msg = "You should provide an instance of scikit-learn estimator"
+ with pytest.raises(TypeError, match=msg):
+ clone(MyEstimator)
+
+
+def test_repr():
+ # Smoke test the repr of the base estimator.
+ my_estimator = MyEstimator()
+ repr(my_estimator)
+ test = T(K(), K())
+ assert repr(test) == "T(a=K(), b=K())"
+
+ some_est = T(a=["long_params"] * 1000)
+ assert len(repr(some_est)) == 485
+
+
+def test_str():
+ # Smoke test the str of the base estimator
+ my_estimator = MyEstimator()
+ str(my_estimator)
+
+
+def test_get_params():
+ test = T(K(), K)
+
+ assert "a__d" in test.get_params(deep=True)
+ assert "a__d" not in test.get_params(deep=False)
+
+ test.set_params(a__d=2)
+ assert test.a.d == 2
+
+ with pytest.raises(ValueError):
+ test.set_params(a__a=2)
+
+
+def test_is_classifier():
+ svc = SVC()
+ assert is_classifier(svc)
+ assert is_classifier(GridSearchCV(svc, {"C": [0.1, 1]}))
+ assert is_classifier(Pipeline([("svc", svc)]))
+ assert is_classifier(Pipeline([("svc_cv", GridSearchCV(svc, {"C": [0.1, 1]}))]))
+
+
+def test_set_params():
+ # test nested estimator parameter setting
+ clf = Pipeline([("svc", SVC())])
+
+ # non-existing parameter in svc
+ with pytest.raises(ValueError):
+ clf.set_params(svc__stupid_param=True)
+
+ # non-existing parameter of pipeline
+ with pytest.raises(ValueError):
+ clf.set_params(svm__stupid_param=True)
+
+    # we don't currently check whether the steps of a pipeline are estimators
+ # bad_pipeline = Pipeline([("bad", NoEstimator())])
+ # assert_raises(AttributeError, bad_pipeline.set_params,
+ # bad__stupid_param=True)
+
+
+def test_set_params_passes_all_parameters():
+ # Make sure all parameters are passed together to set_params
+ # of nested estimator. Regression test for #9944
+
+ class TestDecisionTree(DecisionTreeClassifier):
+ def set_params(self, **kwargs):
+ super().set_params(**kwargs)
+ # expected_kwargs is in test scope
+ assert kwargs == expected_kwargs
+ return self
+
+ expected_kwargs = {"max_depth": 5, "min_samples_leaf": 2}
+ for est in [
+ Pipeline([("estimator", TestDecisionTree())]),
+ GridSearchCV(TestDecisionTree(), {}),
+ ]:
+ est.set_params(estimator__max_depth=5, estimator__min_samples_leaf=2)
+
+
+def test_set_params_updates_valid_params():
+ # Check that set_params tries to set SVC().C, not
+ # DecisionTreeClassifier().C
+ gscv = GridSearchCV(DecisionTreeClassifier(), {})
+ gscv.set_params(estimator=SVC(), estimator__C=42.0)
+ assert gscv.estimator.C == 42.0
+
+
+@pytest.mark.parametrize(
+ "tree,dataset",
+ [
+ (
+ DecisionTreeClassifier(max_depth=2, random_state=0),
+ datasets.make_classification(random_state=0),
+ ),
+ (
+ DecisionTreeRegressor(max_depth=2, random_state=0),
+ datasets.make_regression(random_state=0),
+ ),
+ ],
+)
+def test_score_sample_weight(tree, dataset):
+ rng = np.random.RandomState(0)
+    # check that the scores with and without sample weights differ
+ X, y = dataset
+
+ tree.fit(X, y)
+ # generate random sample weights
+ sample_weight = rng.randint(1, 10, size=len(y))
+ score_unweighted = tree.score(X, y)
+ score_weighted = tree.score(X, y, sample_weight=sample_weight)
+ msg = "Unweighted and weighted scores are unexpectedly equal"
+ assert score_unweighted != score_weighted, msg
+
+
+def test_clone_pandas_dataframe():
+ class DummyEstimator(TransformerMixin, BaseEstimator):
+ """This is a dummy class for generating numerical features
+
+ This feature extractor extracts numerical features from pandas data
+ frame.
+
+ Parameters
+ ----------
+
+ df: pandas data frame
+ The pandas data frame parameter.
+
+ Notes
+ -----
+ """
+
+ def __init__(self, df=None, scalar_param=1):
+ self.df = df
+ self.scalar_param = scalar_param
+
+ def fit(self, X, y=None):
+ pass
+
+ def transform(self, X):
+ pass
+
+ # build and clone estimator
+ d = np.arange(10)
+ df = MockDataFrame(d)
+ e = DummyEstimator(df, scalar_param=1)
+ cloned_e = clone(e)
+
+ # the test
+ assert (e.df == cloned_e.df).values.all()
+ assert e.scalar_param == cloned_e.scalar_param
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_clone_protocol():
+ """Checks that clone works with `__sklearn_clone__` protocol."""
+
+ class FrozenEstimator(BaseEstimator):
+ def __init__(self, fitted_estimator):
+ self.fitted_estimator = fitted_estimator
+
+ def __getattr__(self, name):
+ return getattr(self.fitted_estimator, name)
+
+ def __sklearn_clone__(self):
+ return self
+
+ def fit(self, *args, **kwargs):
+ return self
+
+ def fit_transform(self, *args, **kwargs):
+ return self.fitted_estimator.transform(*args, **kwargs)
+
+ X = np.array([[-1, -1], [-2, -1], [-3, -2]])
+ pca = PCA().fit(X)
+ components = pca.components_
+
+ frozen_pca = FrozenEstimator(pca)
+ assert_allclose(frozen_pca.components_, components)
+
+ # Calling PCA methods such as `get_feature_names_out` still works
+ assert_array_equal(frozen_pca.get_feature_names_out(), pca.get_feature_names_out())
+
+    # Fitting on new data does not alter `components_`
+ X_new = np.asarray([[-1, 2], [3, 4], [1, 2]])
+ frozen_pca.fit(X_new)
+ assert_allclose(frozen_pca.components_, components)
+
+ # `fit_transform` does not alter state
+ frozen_pca.fit_transform(X_new)
+ assert_allclose(frozen_pca.components_, components)
+
+    # Cloning the frozen estimator is a no-op
+ clone_frozen_pca = clone(frozen_pca)
+ assert clone_frozen_pca is frozen_pca
+ assert_allclose(clone_frozen_pca.components_, components)
+
+
+def test_pickle_version_warning_is_not_raised_with_matching_version():
+ iris = datasets.load_iris()
+ tree = DecisionTreeClassifier().fit(iris.data, iris.target)
+ tree_pickle = pickle.dumps(tree)
+ assert b"version" in tree_pickle
+ tree_restored = assert_no_warnings(pickle.loads, tree_pickle)
+
+ # test that we can predict with the restored decision tree classifier
+ score_of_original = tree.score(iris.data, iris.target)
+ score_of_restored = tree_restored.score(iris.data, iris.target)
+ assert score_of_original == score_of_restored
+
+
+class TreeBadVersion(DecisionTreeClassifier):
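+    # Stamps a bogus _sklearn_version into the pickled state, so unpickling
+    # under the installed scikit-learn version should trigger the
+    # version-mismatch warning.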
+ def __getstate__(self):
+ return dict(self.__dict__.items(), _sklearn_version="something")
+
+
+pickle_error_message = (
+ "Trying to unpickle estimator {estimator} from "
+ "version {old_version} when using version "
+ "{current_version}. This might "
+ "lead to breaking code or invalid results. "
+ "Use at your own risk."
+)
+
+
+class TreeNoVersion(DecisionTreeClassifier):
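+    # Returns the bare __dict__ so the pickled state carries no
+    # _sklearn_version entry, mimicking pickles created before
+    # scikit-learn 0.18.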
+ def __getstate__(self):
+ return self.__dict__
+
+
+@pytest.mark.skip(reason="Failing test")
+def test_pickle_version_warning_is_issued_when_no_version_info_in_pickle():
+ iris = datasets.load_iris()
+    # TreeNoVersion's pickled state carries no version info, like pre-0.18 pickles
+ tree = TreeNoVersion().fit(iris.data, iris.target)
+
+ tree_pickle_noversion = pickle.dumps(tree)
+ assert b"version" not in tree_pickle_noversion
+ message = pickle_error_message.format(
+ estimator="TreeNoVersion",
+ old_version="pre-0.18",
+ current_version=sklearn.__version__,
+ )
+ # check we got the warning about using pre-0.18 pickle
+ with pytest.warns(UserWarning, match=message):
+ pickle.loads(tree_pickle_noversion)
+
+
+def test_pickle_version_no_warning_is_issued_with_non_sklearn_estimator():
+ iris = datasets.load_iris()
+ tree = TreeNoVersion().fit(iris.data, iris.target)
+ tree_pickle_noversion = pickle.dumps(tree)
+ try:
+ module_backup = TreeNoVersion.__module__
+ TreeNoVersion.__module__ = "notsklearn"
+ assert_no_warnings(pickle.loads, tree_pickle_noversion)
+ finally:
+ TreeNoVersion.__module__ = module_backup
+
+
+class DontPickleAttributeMixin:
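+    # Blanks out `_attribute_not_pickled` in the pickled state and marks
+    # restored instances with a `_restored` flag on unpickling.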
+ def __getstate__(self):
+ data = self.__dict__.copy()
+ data["_attribute_not_pickled"] = None
+ return data
+
+ def __setstate__(self, state):
+ state["_restored"] = True
+ self.__dict__.update(state)
+
+
+class MultiInheritanceEstimator(DontPickleAttributeMixin, BaseEstimator):
+ def __init__(self, attribute_pickled=5):
+ self.attribute_pickled = attribute_pickled
+ self._attribute_not_pickled = None
+
+
+def test_pickling_when_getstate_is_overwritten_by_mixin():
+ estimator = MultiInheritanceEstimator()
+ estimator._attribute_not_pickled = "this attribute should not be pickled"
+
+ serialized = pickle.dumps(estimator)
+ estimator_restored = pickle.loads(serialized)
+ assert estimator_restored.attribute_pickled == 5
+ assert estimator_restored._attribute_not_pickled is None
+ assert estimator_restored._restored
+
+
+def test_pickling_when_getstate_is_overwritten_by_mixin_outside_of_sklearn():
+ try:
+ estimator = MultiInheritanceEstimator()
+ text = "this attribute should not be pickled"
+ estimator._attribute_not_pickled = text
+ old_mod = type(estimator).__module__
+ type(estimator).__module__ = "notsklearn"
+
+ serialized = estimator.__getstate__()
+ assert serialized == {"_attribute_not_pickled": None, "attribute_pickled": 5}
+
+ serialized["attribute_pickled"] = 4
+ estimator.__setstate__(serialized)
+ assert estimator.attribute_pickled == 4
+ assert estimator._restored
+ finally:
+ type(estimator).__module__ = old_mod
+
+
+class SingleInheritanceEstimator(BaseEstimator):
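+    # Same idea as MultiInheritanceEstimator, but __getstate__ is overridden
+    # directly in the child class rather than via a mixin.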
+ def __init__(self, attribute_pickled=5):
+ self.attribute_pickled = attribute_pickled
+ self._attribute_not_pickled = None
+
+ def __getstate__(self):
+ data = self.__dict__.copy()
+ data["_attribute_not_pickled"] = None
+ return data
+
+
+@ignore_warnings(category=UserWarning)
+def test_pickling_works_when_getstate_is_overwritten_in_the_child_class():
+ estimator = SingleInheritanceEstimator()
+ estimator._attribute_not_pickled = "this attribute should not be pickled"
+
+ serialized = pickle.dumps(estimator)
+ estimator_restored = pickle.loads(serialized)
+ assert estimator_restored.attribute_pickled == 5
+ assert estimator_restored._attribute_not_pickled is None
+
+
+def test_tag_inheritance():
+    # check how `allow_nan` resolves when `_more_tags` is overridden or
+    # combined through (diamond) inheritance
+
+ nan_tag_est = NaNTag()
+ no_nan_tag_est = NoNaNTag()
+ assert nan_tag_est._get_tags()["allow_nan"]
+ assert not no_nan_tag_est._get_tags()["allow_nan"]
+
+ redefine_tags_est = OverrideTag()
+ assert not redefine_tags_est._get_tags()["allow_nan"]
+
+ diamond_tag_est = DiamondOverwriteTag()
+ assert diamond_tag_est._get_tags()["allow_nan"]
+
+ inherit_diamond_tag_est = InheritDiamondOverwriteTag()
+ assert inherit_diamond_tag_est._get_tags()["allow_nan"]
+
+
+def test_raises_on_get_params_non_attribute():
+ class MyEstimator(BaseEstimator):
+ def __init__(self, param=5):
+ pass
+
+ def fit(self, X, y=None):
+ return self
+
+ est = MyEstimator()
+ msg = "'MyEstimator' object has no attribute 'param'"
+
+ with pytest.raises(AttributeError, match=msg):
+ est.get_params()
+
+
+def test_repr_mimebundle_():
+    # Checks that the display configuration flag controls the mimebundle output
+ tree = DecisionTreeClassifier()
+ output = tree._repr_mimebundle_()
+ assert "text/plain" in output
+ assert "text/html" in output
+
+ with config_context(display="text"):
+ output = tree._repr_mimebundle_()
+ assert "text/plain" in output
+ assert "text/html" not in output
+
+
+def test_repr_html_wraps():
+    # Checks that the display configuration flag controls the html output
+ tree = DecisionTreeClassifier()
+
+ output = tree._repr_html_()
+ assert "