Skip to content

Commit

Permalink
Add support for bert-sklearn #minor (#74)
Browse files Browse the repository at this point in the history
  • Loading branch information
mirand863 authored Jan 17, 2023
1 parent cc6e992 commit c5023a4
Show file tree
Hide file tree
Showing 10 changed files with 140 additions and 9 deletions.
37 changes: 37 additions & 0 deletions docs/examples/plot_bert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-
"""
=====================
BERT sklearn
=====================
In order to use `bert-sklearn <https://github.com/charles9n/bert-sklearn>`_ with HiClass, some of scikit-learns checks need to be disabled.
The reason is that BERT expects text as input for the features, but scikit-learn expects numerical features.
Hence, the checks will fail.
To disable scikit-learn's checks, we can simply use the parameter `bert=True` in the constructor of the local hierarchical classifier.
"""
from bert_sklearn import BertClassifier
from hiclass import LocalClassifierPerParentNode

# Define data
X_train = X_test = [
"Batman",
"Rorschach",
]
Y_train = [
["Action", "The Dark Night"],
["Action", "Watchmen"],
]

# Use BERT for every node
bert = BertClassifier()
classifier = LocalClassifierPerParentNode(
local_classifier=bert,
bert=True,
)

# Train local classifier per node
classifier.fit(X_train, Y_train)

# Predict
predictions = classifier.predict(X_test)
print(predictions)
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ matplotlib==3.5.2
pandas==1.4.2
ray==1.13.0
numpy<1.24
git+https://github.com/charles9n/bert-sklearn.git@master
14 changes: 11 additions & 3 deletions hiclass/HierarchicalClassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
edge_list: str = None,
replace_classifiers: bool = True,
n_jobs: int = 1,
bert: bool = False,
classifier_abbreviation: str = "",
):
"""
Expand All @@ -87,6 +88,8 @@ def __init__(
n_jobs : int, default=1
The number of jobs to run in parallel. Only :code:`fit` is parallelized.
If :code:`Ray` is installed it is used, otherwise it defaults to :code:`Joblib`.
bert : bool, default=False
If True, skip scikit-learn's checks and sample_weight passing for BERT.
classifier_abbreviation : str, default=""
The abbreviation of the local hierarchical classifier to be displayed during logging.
"""
Expand All @@ -95,6 +98,7 @@ def __init__(
self.edge_list = edge_list
self.replace_classifiers = replace_classifiers
self.n_jobs = n_jobs
self.bert = bert
self.classifier_abbreviation = classifier_abbreviation

def fit(self, X, y, sample_weight=None):
Expand Down Expand Up @@ -130,9 +134,13 @@ def _pre_fit(self, X, y, sample_weight):
# Check that X and y have correct shape
# and convert them to np.ndarray if need be

self.X_, self.y_ = self._validate_data(
X, y, multi_output=True, accept_sparse="csr"
)
if not self.bert:
self.X_, self.y_ = self._validate_data(
X, y, multi_output=True, accept_sparse="csr"
)
else:
self.X_ = np.array(X)
self.y_ = np.array(y)

if sample_weight is not None:
self.sample_weight_ = _check_sample_weight(sample_weight, X)
Expand Down
14 changes: 12 additions & 2 deletions hiclass/LocalClassifierPerLevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
edge_list: str = None,
replace_classifiers: bool = True,
n_jobs: int = 1,
bert: bool = False,
):
"""
Initialize a local classifier per level.
Expand All @@ -68,6 +69,8 @@ def __init__(
n_jobs : int, default=1
The number of jobs to run in parallel. Only :code:`fit` is parallelized.
If :code:`Ray` is installed it is used, otherwise it defaults to :code:`Joblib`.
bert : bool, default=False
If True, skip scikit-learn's checks and sample_weight passing for BERT.
"""
super().__init__(
local_classifier=local_classifier,
Expand All @@ -76,6 +79,7 @@ def __init__(
replace_classifiers=replace_classifiers,
n_jobs=n_jobs,
classifier_abbreviation="LCPL",
bert=bert,
)

def fit(self, X, y, sample_weight=None):
Expand Down Expand Up @@ -135,7 +139,10 @@ def predict(self, X):
check_is_fitted(self)

# Input validation
X = check_array(X, accept_sparse="csr")
if not self.bert:
X = check_array(X, accept_sparse="csr")
else:
X = np.array(X)

# Initialize array that holds predictions
y = np.empty((X.shape[0], self.max_levels_), dtype=self.dtype_)
Expand Down Expand Up @@ -242,7 +249,10 @@ def _fit_classifier(self, level, separator):
unique_y = np.unique(y)
if len(unique_y) == 1 and self.replace_classifiers:
classifier = ConstantClassifier()
classifier.fit(X, y, sample_weight)
if not self.bert:
classifier.fit(X, y, sample_weight)
else:
classifier.fit(X, y)
return classifier

@staticmethod
Expand Down
14 changes: 12 additions & 2 deletions hiclass/LocalClassifierPerNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
edge_list: str = None,
replace_classifiers: bool = True,
n_jobs: int = 1,
bert: bool = False,
):
"""
Initialize a local classifier per node.
Expand Down Expand Up @@ -74,6 +75,8 @@ def __init__(
n_jobs : int, default=1
The number of jobs to run in parallel. Only :code:`fit` is parallelized.
If :code:`Ray` is installed it is used, otherwise it defaults to :code:`Joblib`.
bert : bool, default=False
If True, skip scikit-learn's checks and sample_weight passing for BERT.
"""
super().__init__(
local_classifier=local_classifier,
Expand All @@ -82,6 +85,7 @@ def __init__(
replace_classifiers=replace_classifiers,
n_jobs=n_jobs,
classifier_abbreviation="LCPN",
bert=bert,
)
self.binary_policy = binary_policy

Expand Down Expand Up @@ -145,7 +149,10 @@ def predict(self, X):
check_is_fitted(self)

# Input validation
X = check_array(X, accept_sparse="csr")
if not self.bert:
X = check_array(X, accept_sparse="csr")
else:
X = np.array(X)

# Initialize array that holds predictions
y = np.empty((X.shape[0], self.max_levels_), dtype=self.dtype_)
Expand Down Expand Up @@ -233,7 +240,10 @@ def _fit_classifier(self, node):
unique_y = np.unique(y)
if len(unique_y) == 1 and self.replace_classifiers:
classifier = ConstantClassifier()
classifier.fit(X, y, sample_weight)
if not self.bert:
classifier.fit(X, y, sample_weight)
else:
classifier.fit(X, y)
return classifier

def _clean_up(self):
Expand Down
14 changes: 12 additions & 2 deletions hiclass/LocalClassifierPerParentNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(
edge_list: str = None,
replace_classifiers: bool = True,
n_jobs: int = 1,
bert: bool = False,
):
"""
Initialize a local classifier per parent node.
Expand All @@ -61,6 +62,8 @@ def __init__(
n_jobs : int, default=1
The number of jobs to run in parallel. Only :code:`fit` is parallelized.
If :code:`Ray` is installed it is used, otherwise it defaults to :code:`Joblib`.
bert : bool, default=False
If True, skip scikit-learn's checks and sample_weight passing for BERT.
"""
super().__init__(
local_classifier=local_classifier,
Expand All @@ -69,6 +72,7 @@ def __init__(
replace_classifiers=replace_classifiers,
n_jobs=n_jobs,
classifier_abbreviation="LCPPN",
bert=bert,
)

def fit(self, X, y, sample_weight=None):
Expand Down Expand Up @@ -128,7 +132,10 @@ def predict(self, X):
check_is_fitted(self)

# Input validation
X = check_array(X, accept_sparse="csr")
if not self.bert:
X = check_array(X, accept_sparse="csr")
else:
X = np.array(X)

# Initialize array that holds predictions
y = np.empty((X.shape[0], self.max_levels_), dtype=self.dtype_)
Expand Down Expand Up @@ -203,7 +210,10 @@ def _fit_classifier(self, node):
unique_y = np.unique(y)
if len(unique_y) == 1 and self.replace_classifiers:
classifier = ConstantClassifier()
classifier.fit(X, y, sample_weight)
if not self.bert:
classifier.fit(X, y, sample_weight)
else:
classifier.fit(X, y)
return classifier

def _fit_digraph(self, local_mode: bool = False, use_joblib: bool = False):
Expand Down
10 changes: 10 additions & 0 deletions tests/test_HierarchicalClassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,13 @@ def test_make_leveled_non_iterable_y(noniterable_y):
def test_fit_classifier():
with pytest.raises(NotImplementedError):
HierarchicalClassifier._fit_classifier(None, None)


def test_pre_fit_bert():
classifier = HierarchicalClassifier()
classifier.logger_ = logging.getLogger("HC")
classifier.bert = True
x = [[0, 1], [2, 3]]
y = [["a", "b"], ["c", "d"]]
sample_weight = None
classifier._pre_fit(x, y, sample_weight)
15 changes: 15 additions & 0 deletions tests/test_LocalClassifierPerLevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sklearn.utils.estimator_checks import parametrize_with_checks
from sklearn.utils.validation import check_is_fitted
from hiclass import LocalClassifierPerLevel
from hiclass.ConstantClassifier import ConstantClassifier


@parametrize_with_checks([LocalClassifierPerLevel()])
Expand Down Expand Up @@ -180,3 +181,17 @@ def test_empty_levels(empty_levels):
lcppn.root_,
]
assert_array_equal(ground_truth, predictions)


def test_fit_bert():
bert = ConstantClassifier()
lcpn = LocalClassifierPerLevel(
local_classifier=bert,
bert=True,
)
X = ["Text 1", "Text 2"]
y = ["a", "a"]
lcpn.fit(X, y)
check_is_fitted(lcpn)
predictions = lcpn.predict(X)
assert_array_equal(y, predictions)
15 changes: 15 additions & 0 deletions tests/test_LocalClassifierPerNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from hiclass import LocalClassifierPerNode
from hiclass.BinaryPolicy import ExclusivePolicy
from hiclass.ConstantClassifier import ConstantClassifier


@parametrize_with_checks([LocalClassifierPerNode()])
Expand Down Expand Up @@ -242,3 +243,17 @@ def test_empty_levels(empty_levels):
lcppn.root_,
]
assert_array_equal(ground_truth, predictions)


def test_fit_bert():
bert = ConstantClassifier()
lcpn = LocalClassifierPerNode(
local_classifier=bert,
bert=True,
)
X = ["Text 1", "Text 2"]
y = ["a", "a"]
lcpn.fit(X, y)
check_is_fitted(lcpn)
predictions = lcpn.predict(X)
assert_array_equal(y, predictions)
15 changes: 15 additions & 0 deletions tests/test_LocalClassifierPerParentNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sklearn.utils.validation import check_is_fitted

from hiclass import LocalClassifierPerParentNode
from hiclass.ConstantClassifier import ConstantClassifier


@parametrize_with_checks([LocalClassifierPerParentNode()])
Expand Down Expand Up @@ -226,3 +227,17 @@ def test_empty_levels(empty_levels):
lcppn.root_,
]
assert_array_equal(ground_truth, predictions)


def test_bert():
bert = ConstantClassifier()
lcpn = LocalClassifierPerParentNode(
local_classifier=bert,
bert=True,
)
X = ["Text 1", "Text 2"]
y = ["a", "a"]
lcpn.fit(X, y)
check_is_fitted(lcpn)
predictions = lcpn.predict(X)
assert_array_equal(y, predictions)

0 comments on commit c5023a4

Please sign in to comment.