Skip to content

Commit

Permalink
Fix/1349 (#1350)
Browse files Browse the repository at this point in the history
* Add packaging dependency

* Change use of distutils to packaging

* Update missed usage of distutils to packaging

* Inline comparison to clear up confusion
  • Loading branch information
PGijsbers authored Sep 16, 2024
1 parent de983ac commit fa7e9db
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 133 deletions.
25 changes: 12 additions & 13 deletions openml/extensions/sklearn/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import traceback
import warnings
from collections import OrderedDict
from distutils.version import LooseVersion
from json.decoder import JSONDecodeError
from re import IGNORECASE
from typing import Any, Callable, List, Sized, cast
Expand All @@ -25,6 +24,7 @@
import sklearn.base
import sklearn.model_selection
import sklearn.pipeline
from packaging.version import Version

import openml
from openml.exceptions import PyOpenMLError
Expand All @@ -48,7 +48,7 @@
r"(?P<version>(\d+\.)?(\d+\.)?(\d+)?(dev)?[0-9]*))?$",
)

sctypes = np.sctypes if LooseVersion(np.__version__) < "2.0" else np.core.sctypes
sctypes = np.sctypes if Version(np.__version__) < Version("2.0") else np.core.sctypes
SIMPLE_NUMPY_TYPES = [
nptype
for type_cat, nptypes in sctypes.items()
Expand Down Expand Up @@ -237,14 +237,13 @@ def _min_dependency_str(cls, sklearn_version: str) -> str:
-------
str
"""
openml_major_version = int(LooseVersion(openml.__version__).version[1])
# This explicit check is necessary to support existing entities on the OpenML servers
# that used the fixed dependency string (in the else block)
if openml_major_version > 11:
if Version(openml.__version__) > Version("0.11"):
# OpenML v0.11 onwards supports sklearn>=0.24
# assumption: 0.24 onwards sklearn should contain a _min_dependencies.py file with
# variables declared for extracting minimum dependency for that version
if LooseVersion(sklearn_version) >= "0.24":
if Version(sklearn_version) >= Version("0.24"):
from sklearn import _min_dependencies as _mindep

dependency_list = {
Expand All @@ -253,18 +252,18 @@ def _min_dependency_str(cls, sklearn_version: str) -> str:
"joblib": f"{_mindep.JOBLIB_MIN_VERSION}",
"threadpoolctl": f"{_mindep.THREADPOOLCTL_MIN_VERSION}",
}
elif LooseVersion(sklearn_version) >= "0.23":
elif Version(sklearn_version) >= Version("0.23"):
dependency_list = {
"numpy": "1.13.3",
"scipy": "0.19.1",
"joblib": "0.11",
"threadpoolctl": "2.0.0",
}
if LooseVersion(sklearn_version).version[2] == 0:
if Version(sklearn_version).micro == 0:
dependency_list.pop("threadpoolctl")
elif LooseVersion(sklearn_version) >= "0.21":
elif Version(sklearn_version) >= Version("0.21"):
dependency_list = {"numpy": "1.11.0", "scipy": "0.17.0", "joblib": "0.11"}
elif LooseVersion(sklearn_version) >= "0.19":
elif Version(sklearn_version) >= Version("0.19"):
dependency_list = {"numpy": "1.8.2", "scipy": "0.13.3"}
else:
dependency_list = {"numpy": "1.6.1", "scipy": "0.9"}
Expand Down Expand Up @@ -1226,8 +1225,8 @@ def _check_dependencies(
version = match.group("version")

module = importlib.import_module(dependency_name)
required_version = LooseVersion(version)
installed_version = LooseVersion(module.__version__) # type: ignore
required_version = Version(version)
installed_version = Version(module.__version__) # type: ignore

if operation == "==":
check = required_version == installed_version
Expand Down Expand Up @@ -1258,7 +1257,7 @@ def _serialize_type(self, o: Any) -> OrderedDict[str, str]:
np.int32: "np.int32",
np.int64: "np.int64",
}
if LooseVersion(np.__version__) < "1.24":
if Version(np.__version__) < Version("1.24"):
mapping[float] = "np.float"
mapping[int] = "np.int"

Expand All @@ -1278,7 +1277,7 @@ def _deserialize_type(self, o: str) -> Any:
}

# TODO(eddiebergman): Might be able to remove this
if LooseVersion(np.__version__) < "1.24":
if Version(np.__version__) < Version("1.24"):
mapping["np.float"] = np.float # type: ignore # noqa: NPY001
mapping["np.int"] = np.int # type: ignore # noqa: NPY001

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies = [
"numpy>=1.6.2",
"minio",
"pyarrow",
"packaging",
]
requires-python = ">=3.8"
authors = [
Expand Down
Loading

0 comments on commit fa7e9db

Please sign in to comment.