Skip to content

Commit

Permalink
Document Embedding - add SBERT method to widget
Browse files Browse the repository at this point in the history
  • Loading branch information
PrimozGodec committed Jun 29, 2022
1 parent 555fe72 commit b67016c
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 23 deletions.
55 changes: 33 additions & 22 deletions orangecontrib/text/widgets/owdocumentembedding.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Dict, Optional, Any

from AnyQt.QtCore import Qt
from AnyQt.QtWidgets import QGridLayout, QLabel, QPushButton, QStyle
from AnyQt.QtWidgets import QVBoxLayout, QPushButton, QStyle
from Orange.misc.utils.embedder_utils import EmbeddingConnectionError
from Orange.widgets import gui
from Orange.widgets.settings import Setting
Expand All @@ -13,7 +13,7 @@
LANGS_TO_ISO,
DocumentEmbedder,
)
from orangecontrib.text.widgets.utils import widgets
from orangecontrib.text.vectorization.sbert import SBERT
from orangecontrib.text.widgets.utils.owbasevectorizer import (
OWBaseVectorizer,
Vectorizer,
Expand All @@ -30,6 +30,7 @@ def _transform(self, callback):
self.new_corpus = embeddings
self.skipped_documents = skipped


class OWDocumentEmbedding(OWBaseVectorizer):
name = "Document Embedding"
description = "Document embedding using pretrained models."
Expand All @@ -40,7 +41,7 @@ class OWDocumentEmbedding(OWBaseVectorizer):
buttons_area_orientation = Qt.Vertical
settings_version = 2

Method = DocumentEmbedder
Methods = [DocumentEmbedder, SBERT]

class Outputs(OWBaseVectorizer.Outputs):
skipped = Output("Skipped documents", Corpus)
Expand All @@ -55,9 +56,9 @@ class Error(OWWidget.Error):
class Warning(OWWidget.Warning):
unsuccessful_embeddings = Msg("Some embeddings were unsuccessful.")

method = Setting(default=0)
language = Setting(default="English")
aggregator = Setting(default="Mean")
method: int = Setting(default=0)
language: str = Setting(default="English")
aggregator: str = Setting(default="Mean")

def __init__(self):
super().__init__()
Expand All @@ -69,32 +70,43 @@ def __init__(self):
self.cancel_button.setDisabled(True)

def create_configuration_layout(self):
layout = QGridLayout()
layout.setSpacing(10)

combo = widgets.ComboBox(
layout = QVBoxLayout()
rbtns = gui.radioButtons(None, self, "method", callback=self.on_change)
layout.addWidget(rbtns)

gui.appendRadioButton(rbtns, "fastText:")
ibox = gui.indentedBox(rbtns)
gui.comboBox(
ibox,
self,
"language",
items=LANGUAGES,
label="Language:",
sendSelectedValue=True, # value is actual string not index
orientation=Qt.Horizontal,
callback=self.on_change,
)
gui.comboBox(
ibox,
self,
"aggregator",
items=AGGREGATORS,
label="Aggregator:",
sendSelectedValue=True, # value is actual string not index
orientation=Qt.Horizontal,
callback=self.on_change,
)
combo.currentIndexChanged.connect(self.on_change)
layout.addWidget(QLabel("Language:"))
layout.addWidget(combo, 0, 1)

combo = widgets.ComboBox(self, "aggregator", items=AGGREGATORS)
combo.currentIndexChanged.connect(self.on_change)
layout.addWidget(QLabel("Aggregator:"))
layout.addWidget(combo, 1, 1)

gui.appendRadioButton(rbtns, "Multilingual SBERT:")
return layout

def update_method(self):
self.vectorizer = EmbeddingVectorizer(self.init_method(), self.corpus)

def init_method(self):
return self.Method(
language=LANGS_TO_ISO[self.language], aggregator=self.aggregator
)
params = dict(language=LANGS_TO_ISO[self.language], aggregator=self.aggregator)
kwargs = (params, {})[self.method]
return self.Methods[self.method](**kwargs)

@gui.deferred
def commit(self):
Expand Down Expand Up @@ -133,7 +145,6 @@ def migrate_settings(cls, settings: Dict[str, Any], version: Optional[int]):
settings["aggregator"] = AGGREGATORS[settings["aggregator"]]



if __name__ == "__main__":
from orangewidget.utils.widgetpreview import WidgetPreview

Expand Down
2 changes: 1 addition & 1 deletion orangecontrib/text/widgets/utils/owbasevectorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def on_change(self):
self.commit.deferred()

def send_report(self):
self.report_items(self.method.report())
self.report_items(self.vectorizer.method.report())

def create_configuration_layout(self):
raise NotImplementedError
Expand Down

0 comments on commit b67016c

Please sign in to comment.