Skip to content

Commit

Permalink
MAINT use public import for metadata routing (#1113)
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre authored Dec 20, 2024
1 parent bc94b25 commit 2d65471
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
7 changes: 7 additions & 0 deletions imblearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,19 @@
import numpy as np
from sklearn.base import BaseEstimator, OneToOneFeatureMixin
from sklearn.preprocessing import label_binarize
from sklearn.utils._metadata_requests import METHODS
from sklearn.utils.multiclass import check_classification_targets

from .utils import check_sampling_strategy, check_target_type
from .utils._sklearn_compat import _fit_context, get_tags, validate_data
from .utils._validation import ArraysTransformer

if "fit_predict" not in METHODS:
METHODS.append("fit_predict")
if "fit_transform" not in METHODS:
METHODS.append("fit_transform")
METHODS.append("fit_resample")


class SamplerMixin(metaclass=ABCMeta):
"""Mixin class for samplers with abstract method.
Expand Down
14 changes: 4 additions & 10 deletions imblearn/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,18 @@
from sklearn.base import clone
from sklearn.exceptions import NotFittedError
from sklearn.utils import Bunch
from sklearn.utils._metadata_requests import (
METHODS,
from sklearn.utils._param_validation import HasMethods
from sklearn.utils.fixes import parse_version
from sklearn.utils.metadata_routing import (
MetadataRouter,
MethodMapping,
_routing_enabled,
get_routing_for_object,
)
from sklearn.utils._param_validation import HasMethods
from sklearn.utils.fixes import parse_version
from sklearn.utils.metaestimators import available_if
from sklearn.utils.validation import check_is_fitted, check_memory

from .base import METHODS
from .utils._sklearn_compat import (
_fit_context,
_print_elapsed_time,
Expand All @@ -43,12 +43,6 @@
validate_params,
)

if "fit_predict" not in METHODS:
METHODS.append("fit_predict")
if "fit_transform" not in METHODS:
METHODS.append("fit_transform")
METHODS.append("fit_resample")

__all__ = ["Pipeline", "make_pipeline"]


Expand Down

0 comments on commit 2d65471

Please sign in to comment.