diff --git a/orangecontrib/text/widgets/owdocumentembedding.py b/orangecontrib/text/widgets/owdocumentembedding.py
index 1a8524d6e..e83b56c4f 100644
--- a/orangecontrib/text/widgets/owdocumentembedding.py
+++ b/orangecontrib/text/widgets/owdocumentembedding.py
@@ -1,7 +1,7 @@
 from typing import Dict, Optional, Any
 
 from AnyQt.QtCore import Qt
-from AnyQt.QtWidgets import QGridLayout, QLabel, QPushButton, QStyle
+from AnyQt.QtWidgets import QVBoxLayout, QPushButton, QStyle
 from Orange.misc.utils.embedder_utils import EmbeddingConnectionError
 from Orange.widgets import gui
 from Orange.widgets.settings import Setting
@@ -13,7 +13,7 @@
     LANGS_TO_ISO,
     DocumentEmbedder,
 )
-from orangecontrib.text.widgets.utils import widgets
+from orangecontrib.text.vectorization.sbert import SBERT
 from orangecontrib.text.widgets.utils.owbasevectorizer import (
     OWBaseVectorizer,
     Vectorizer,
@@ -30,6 +30,7 @@ def _transform(self, callback):
         self.new_corpus = embeddings
         self.skipped_documents = skipped
 
+
 class OWDocumentEmbedding(OWBaseVectorizer):
     name = "Document Embedding"
     description = "Document embedding using pretrained models."
@@ -40,7 +41,7 @@ class OWDocumentEmbedding(OWBaseVectorizer):
     buttons_area_orientation = Qt.Vertical
     settings_version = 2
 
-    Method = DocumentEmbedder
+    Methods = [DocumentEmbedder, SBERT]
 
     class Outputs(OWBaseVectorizer.Outputs):
         skipped = Output("Skipped documents", Corpus)
@@ -55,9 +56,9 @@ class Error(OWWidget.Error):
     class Warning(OWWidget.Warning):
         unsuccessful_embeddings = Msg("Some embeddings were unsuccessful.")
 
-    method = Setting(default=0)
-    language = Setting(default="English")
-    aggregator = Setting(default="Mean")
+    method: int = Setting(default=0)
+    language: str = Setting(default="English")
+    aggregator: str = Setting(default="Mean")
 
     def __init__(self):
         super().__init__()
@@ -69,32 +70,43 @@ def __init__(self):
         self.cancel_button.setDisabled(True)
 
     def create_configuration_layout(self):
-        layout = QGridLayout()
-        layout.setSpacing(10)
-
-        combo = widgets.ComboBox(
+        layout = QVBoxLayout()
+        rbtns = gui.radioButtons(None, self, "method", callback=self.on_change)
+        layout.addWidget(rbtns)
+
+        gui.appendRadioButton(rbtns, "fastText:")
+        ibox = gui.indentedBox(rbtns)
+        gui.comboBox(
+            ibox,
             self,
             "language",
             items=LANGUAGES,
+            label="Language:",
+            sendSelectedValue=True,  # value is actual string not index
+            orientation=Qt.Horizontal,
+            callback=self.on_change,
+        )
+        gui.comboBox(
+            ibox,
+            self,
+            "aggregator",
+            items=AGGREGATORS,
+            label="Aggregator:",
+            sendSelectedValue=True,  # value is actual string not index
+            orientation=Qt.Horizontal,
+            callback=self.on_change,
         )
-        combo.currentIndexChanged.connect(self.on_change)
-        layout.addWidget(QLabel("Language:"))
-        layout.addWidget(combo, 0, 1)
-
-        combo = widgets.ComboBox(self, "aggregator", items=AGGREGATORS)
-        combo.currentIndexChanged.connect(self.on_change)
-        layout.addWidget(QLabel("Aggregator:"))
-        layout.addWidget(combo, 1, 1)
+        gui.appendRadioButton(rbtns, "Multilingual SBERT:")
 
         return layout
 
     def update_method(self):
         self.vectorizer = EmbeddingVectorizer(self.init_method(), self.corpus)
 
     def init_method(self):
-        return self.Method(
-            language=LANGS_TO_ISO[self.language], aggregator=self.aggregator
-        )
+        params = dict(language=LANGS_TO_ISO[self.language], aggregator=self.aggregator)
+        kwargs = (params, {})[self.method]
+        return self.Methods[self.method](**kwargs)
 
     @gui.deferred
     def commit(self):
@@ -133,7 +145,6 @@ def migrate_settings(cls, settings: Dict[str, Any], version: Optional[int]):
             settings["aggregator"] = AGGREGATORS[settings["aggregator"]]
 
 
-
 if __name__ == "__main__":
     from orangewidget.utils.widgetpreview import WidgetPreview
 
diff --git a/orangecontrib/text/widgets/utils/owbasevectorizer.py b/orangecontrib/text/widgets/utils/owbasevectorizer.py
index 41b029775..0434250bd 100644
--- a/orangecontrib/text/widgets/utils/owbasevectorizer.py
+++ b/orangecontrib/text/widgets/utils/owbasevectorizer.py
@@ -117,7 +117,7 @@ def on_change(self):
         self.commit.deferred()
 
     def send_report(self):
-        self.report_items(self.method.report())
+        self.report_items(self.vectorizer.method.report())
 
     def create_configuration_layout(self):
         raise NotImplementedError