diff --git a/classy_classification/classifiers/classy_skeleton.py b/classy_classification/classifiers/classy_skeleton.py index c3e3f20..bfe2bdc 100644 --- a/classy_classification/classifiers/classy_skeleton.py +++ b/classy_classification/classifiers/classy_skeleton.py @@ -1,3 +1,4 @@ +import collections import importlib.util from typing import List, Union @@ -48,9 +49,11 @@ def __init__( "max_cross_validation_folds": 5 }. """ + self.multi_label = multi_label - self.data = data + self.data = collections.OrderedDict(sorted(data.items())) + self.name = name self.nlp = nlp self.verbose = verbose @@ -125,7 +128,7 @@ def set_config(self, config: Union[dict, None] = None): if len(self.label_list) > 1: config = { "C": [1, 2, 5, 10, 20, 50, 100], - "kernel": ["linear"], + "kernel": ["linear", "rbf", "poly", "sigmoid"], "max_cross_validation_folds": 5, "seed": None, } diff --git a/classy_classification/classifiers/classy_standalone.py b/classy_classification/classifiers/classy_standalone.py index 173561c..eae0ea4 100644 --- a/classy_classification/classifiers/classy_standalone.py +++ b/classy_classification/classifiers/classy_standalone.py @@ -1,3 +1,4 @@ +import collections from typing import List, Union from .classy_spacy import ClassyExternal, ClassySkeletonFewShot @@ -59,7 +60,7 @@ def __init__( }. """ self.multi_label = multi_label - self.data = data + self.data = collections.OrderedDict(sorted(data.items())) self.model = model self.device = device self.verbose = verbose diff --git a/poetry.lock b/poetry.lock index c2f3b02..0af8244 100644 --- a/poetry.lock +++ b/poetry.lock @@ -355,7 +355,7 @@ tests = ["asttokens", "littleutils", "pytest", "rich"] [[package]] name = "fast-sentence-transformers" -version = "0.4.0" +version = "0.4.1" description = "This repository contains code to run faster sentence-transformers using tools like quantization, ONNX and pruning." category = "main" optional = true @@ -553,6 +553,14 @@ category = "dev" optional = false python-versions = ">=3.7" +[[package]] +name = "InstructorEmbedding" +version = "1.0.0" +description = "Text embedding tool" +category = "main" +optional = false +python-versions = "*" + [[package]] name = "ipykernel" version = "6.20.1" @@ -2473,7 +2481,7 @@ onnx = ["fast-sentence-transformers"] [metadata] lock-version = "1.1" python-versions = ">=3.8,<3.12" -content-hash = "071e1d8953fe722e7884f8ac4a003d1f55101263694f7b22340b3f281b02f470" +content-hash = "c1b34a37910e4f6ef4436feb1890e176edc243416c2148261cb8528709c227c5" [metadata.files] anyio = [ @@ -2760,8 +2768,8 @@ executing = [ {file = "executing-1.2.0.tar.gz", hash = "sha256:19da64c18d2d851112f09c287f8d3dbbdf725ab0e569077efb6cdcbd3497c107"}, ] fast-sentence-transformers = [ - {file = "fast_sentence_transformers-0.4.0-py3-none-any.whl", hash = "sha256:e41c37078d37f230323e74828e27c3d7bab501c349295e13b25d48a4a97336a5"}, - {file = "fast_sentence_transformers-0.4.0.tar.gz", hash = "sha256:1627cb744a50f0167a705563b2085c3f31eff50c63c6a778215866f0e0ecad95"}, + {file = "fast_sentence_transformers-0.4.1-py3-none-any.whl", hash = "sha256:83d47d623c2ad5b5d2bf7d5fffcfeb32a994df82be71414fd60c0e849026a94b"}, + {file = "fast_sentence_transformers-0.4.1.tar.gz", hash = "sha256:d4d963b5495cb070701e67517f89c7d9a0f11179281e1d9c4f4dcafd230ada0f"}, ] fastjsonschema = [ {file = "fastjsonschema-2.16.2-py3-none-any.whl", hash = "sha256:21f918e8d9a1a4ba9c22e09574ba72267a6762d47822db9add95f6454e51cc1c"}, @@ -2823,6 +2831,10 @@ iniconfig = [ {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +InstructorEmbedding = [ + {file = "InstructorEmbedding-1.0.0-py2.py3-none-any.whl", hash = "sha256:f73e0f1fa3649329a70f29fc5abe7d03bac96429c4e2704df8c77d6970b4bc32"}, + {file = "InstructorEmbedding-1.0.0.tar.gz", hash = "sha256:d68f51bcfcc98afd556ba40249790eccaf928440f8a002a178fb6bf31858eb02"}, +] ipykernel = [ {file = "ipykernel-6.20.1-py3-none-any.whl", hash = "sha256:a314e6782a4f9e277783382976b3a93608a3787cd70a235b558b47f875134be1"}, {file = "ipykernel-6.20.1.tar.gz", hash = "sha256:f6016ecbf581d0ea6e29ba16cee6cc1a9bbde3835900c46c6571a791692f4139"}, diff --git a/pyproject.toml b/pyproject.toml index 187d13b..c62e71d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ sentence-transformers = "^2.0" scikit-learn = "^1.0" pandas = "^1.4" fast-sentence-transformers = { version = "^0.4.1", optional = true } +InstructorEmbedding = "^1.0.0" [tool.poetry.extras] onnx = ["fast-sentence-transformers"]