-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for bert-sklearn #minor (#74)
- Loading branch information
Showing
10 changed files
with
140 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
===================== | ||
BERT sklearn | ||
===================== | ||
In order to use `bert-sklearn <https://github.com/charles9n/bert-sklearn>`_ with HiClass, some of scikit-learns checks need to be disabled. | ||
The reason is that BERT expects text as input for the features, but scikit-learn expects numerical features. | ||
Hence, the checks will fail. | ||
To disable scikit-learn's checks, we can simply use the parameter `bert=True` in the constructor of the local hierarchical classifier. | ||
""" | ||
from bert_sklearn import BertClassifier | ||
from hiclass import LocalClassifierPerParentNode | ||
|
||
# Define data | ||
X_train = X_test = [ | ||
"Batman", | ||
"Rorschach", | ||
] | ||
Y_train = [ | ||
["Action", "The Dark Night"], | ||
["Action", "Watchmen"], | ||
] | ||
|
||
# Use BERT for every node | ||
bert = BertClassifier() | ||
classifier = LocalClassifierPerParentNode( | ||
local_classifier=bert, | ||
bert=True, | ||
) | ||
|
||
# Train local classifier per node | ||
classifier.fit(X_train, Y_train) | ||
|
||
# Predict | ||
predictions = classifier.predict(X_test) | ||
print(predictions) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,3 +8,4 @@ matplotlib==3.5.2 | |
pandas==1.4.2 | ||
ray==1.13.0 | ||
numpy<1.24 | ||
git+https://github.com/charles9n/bert-sklearn.git@master |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters