Merge pull request #839 from djukicn/document-embedding-add-sbert
[ENH] Document Embedding: add SBERT
djukicn authored Jun 30, 2022
2 parents 7f3baed + 170cd4c commit be87655
Showing 6 changed files with 173 additions and 143 deletions.
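
For orientation, here is a minimal usage sketch of the vectorizer this PR adds, pieced together from the tests and from SBERT._transform below. It assumes a working Orange3-text installation and network access to the embedding server; the printed value mirrors the EMB_DIM constant (384) checked in the tests.

    from orangecontrib.text import Corpus
    from orangecontrib.text.vectorization.sbert import SBERT

    corpus = Corpus.from_file("deerwester")  # small demo corpus used in the tests
    sbert = SBERT()

    # transform() returns the embedded corpus plus a corpus of skipped
    # documents (None when nothing was skipped); embeddings are appended
    # as hidden attributes Dim1 .. Dim384.
    embedded, skipped = sbert.transform(corpus)
    print(len(embedded.domain.attributes))  # 384

    # __call__ returns raw embeddings: one vector per document, or None
    # for documents the server failed to embed.
    vectors = sbert(corpus.documents)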
67 changes: 21 additions & 46 deletions orangecontrib/text/tests/test_sbert.py
@@ -3,12 +3,7 @@
from collections.abc import Iterator
import asyncio

from orangecontrib.text.vectorization.sbert import (
SBERT,
MIN_CHUNKS,
MAX_PACKAGE_SIZE,
EMB_DIM
)
from orangecontrib.text.vectorization.sbert import SBERT, EMB_DIM
from orangecontrib.text import Corpus

PATCH_METHOD = 'httpx.AsyncClient.post'
@@ -37,47 +32,17 @@ async def dummy_post(url, headers, data):


class TestSBERT(unittest.TestCase):

def setUp(self):
self.sbert = SBERT()
self.sbert.clear_cache()
self.corpus = Corpus.from_file('deerwester')

def tearDown(self):
self.sbert.clear_cache()

def test_make_chunks_small(self):
chunks = self.sbert._make_chunks(
self.corpus.documents, [100] * len(self.corpus.documents)
)
self.assertEqual(len(chunks), min(len(self.corpus.documents), MIN_CHUNKS))

def test_make_chunks_medium(self):
num_docs = len(self.corpus.documents)
documents = self.corpus.documents
if num_docs < MIN_CHUNKS:
documents = [documents[0]] * MIN_CHUNKS
chunks = self.sbert._make_chunks(
documents, [MAX_PACKAGE_SIZE / MIN_CHUNKS - 1] * len(documents)
)
self.assertEqual(len(chunks), MIN_CHUNKS)

def test_make_chunks_large(self):
num_docs = len(self.corpus.documents)
documents = self.corpus.documents
if num_docs < MIN_CHUNKS:
documents = [documents[0]] * MIN_CHUNKS * 100
mps = MAX_PACKAGE_SIZE
chunks = self.sbert._make_chunks(
documents,
[mps / 100] * (len(documents) - 2) + [0.3 * mps, 0.9 * mps, mps]
)
self.assertGreater(len(chunks), MIN_CHUNKS)

@patch(PATCH_METHOD)
def test_empty_corpus(self, mock):
self.assertEqual(
len(self.sbert(self.corpus.documents[:0])), 0
)
self.assertEqual(len(self.sbert(self.corpus.documents[:0])), 0)
mock.request.assert_not_called()
mock.get_response.assert_not_called()
self.assertEqual(
@@ -95,14 +60,24 @@ def test_none_result(self):
result = self.sbert(self.corpus.documents)
self.assertEqual(result, IDEAL_RESPONSE[:-1] + [None])

@patch(PATCH_METHOD, make_dummy_post(RESPONSE[0]))
def test_success_chunks(self):
num_docs = len(self.corpus.documents)
documents = self.corpus.documents
if num_docs < MIN_CHUNKS:
documents = [documents[0]] * MIN_CHUNKS
result = self.sbert(documents)
self.assertEqual(len(result), MIN_CHUNKS)
@patch(PATCH_METHOD, make_dummy_post(iter(RESPONSE)))
def test_transform(self):
res, skipped = self.sbert.transform(self.corpus)
self.assertIsNone(skipped)
self.assertEqual(len(self.corpus), len(res))
self.assertTupleEqual(self.corpus.domain.metas, res.domain.metas)
self.assertEqual(384, len(res.domain.attributes))

@patch(PATCH_METHOD, make_dummy_post(iter(RESPONSE[:-1] + [None] * 3)))
def test_transform_skipped(self):
res, skipped = self.sbert.transform(self.corpus)
self.assertEqual(len(self.corpus) - 1, len(res))
self.assertTupleEqual(self.corpus.domain.metas, res.domain.metas)
self.assertEqual(384, len(res.domain.attributes))

self.assertEqual(1, len(skipped))
self.assertTupleEqual(self.corpus.domain.metas, skipped.domain.metas)
self.assertEqual(0, len(skipped.domain.attributes))


if __name__ == "__main__":
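The tests above never reach the real server: make_dummy_post (defined near the top of the file and only partially visible in this hunk) patches httpx.AsyncClient.post with a coroutine that returns a canned payload. A rough sketch of that pattern follows; the response shape and the MagicMock wrapper are assumptions for illustration, not the exact helper from the file.

    import json
    from unittest.mock import MagicMock, patch

    PATCH_METHOD = "httpx.AsyncClient.post"

    def make_dummy_post(response):
        # Stand-in for httpx.AsyncClient.post that ignores the request and
        # returns `response` as the reply body. @staticmethod keeps the
        # patched attribute from being bound to the AsyncClient instance.
        @staticmethod
        async def dummy_post(url, headers, data):
            return MagicMock(content=json.dumps({"embedding": response}).encode())
        return dummy_post

    # used as in the tests above:
    # @patch(PATCH_METHOD, make_dummy_post(RESPONSE[0]))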
9 changes: 6 additions & 3 deletions orangecontrib/text/vectorization/document_embedder.py
@@ -157,16 +157,19 @@ def _transform(

return new_corpus, skipped_corpus

def report(self) -> Tuple[Tuple[str, str], Tuple[str, str]]:
def report(self) -> Tuple[Tuple[str, str], ...]:
"""Reports on current parameters of DocumentEmbedder.
Returns
-------
tuple
Tuple of parameters.
"""
return (('Language', self.language),
('Aggregator', self.aggregator))
return (
("Embedder", "fastText"),
("Language", self.language),
("Aggregator", self.aggregator),
)

def clear_cache(self):
"""Clears embedder cache"""
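Returning to the report() change above: the return type widens from a fixed pair of pairs to a variable-length tuple because the embedder name is now reported alongside the language and aggregator. A quick sketch of the new output, assuming the default constructor arguments (the exact default strings are not part of this diff):

    from orangecontrib.text.vectorization.document_embedder import DocumentEmbedder

    embedder = DocumentEmbedder()
    print(embedder.report())
    # e.g. (('Embedder', 'fastText'), ('Language', 'en'), ('Aggregator', 'Mean'))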
164 changes: 93 additions & 71 deletions orangecontrib/text/vectorization/sbert.py
@@ -1,30 +1,29 @@
import json
import base64
import warnings
import zlib
import sys
from typing import Any, List, Optional, Callable
from typing import Any, List, Optional, Callable, Tuple

import numpy as np

from Orange.misc.server_embedder import ServerEmbedderCommunicator
from Orange.util import dummy_callback

from orangecontrib.text import Corpus
from orangecontrib.text.vectorization.base import BaseVectorizer

# maximum document size that we still send to the server
MAX_PACKAGE_SIZE = 3000000
# maximum size of a chunk - when one document is longer, send it as a chunk
# with a single document
MAX_CHUNK_SIZE = 50000
MIN_CHUNKS = 20
EMB_DIM = 384


class SBERT:
class SBERT(BaseVectorizer):
def __init__(self) -> None:
self._server_communicator = _ServerCommunicator(
model_name='sbert',
model_name="sbert",
max_parallel_requests=100,
server_url='https://api.garaza.io',
embedder_type='text',
server_url="https://api.garaza.io",
embedder_type="text",
)

def __call__(
@@ -41,78 +40,101 @@ def __call__(
-------
An array of embeddings.
"""

if len(texts) == 0:
return []
# sort texts by length so that longer texts start embedding first; this
# prevents a long document with a long embedding time from starting last
# and adding extra time to the total embedding time
sorted_texts = sorted(
enumerate(texts),
key=lambda x: len(x[1][0]) if x[1] is not None else 0,
reverse=True,
)
indices, sorted_texts = zip(*sorted_texts)
# embed - send the texts to the server
results = self._server_communicator.embedd_data(sorted_texts, callback=callback)
# unsort and unpack
return [x[0] if x else None for _, x in sorted(zip(indices, results))]

def _transform(
self, corpus: Corpus, _, callback=dummy_callback
) -> Tuple[Corpus, Optional[Corpus]]:
"""
Computes embeddings for the given corpus and appends the results to the corpus.
skipped = list()

encoded_texts = list()
sizes = list()
chunks = list()
for i, text in enumerate(texts):
encoded = base64.b64encode(zlib.compress(
text.encode('utf-8', 'replace'), level=-1)
).decode('utf-8', 'replace')
size = sys.getsizeof(encoded)
if size > MAX_PACKAGE_SIZE:
skipped.append(i)
continue
encoded_texts.append(encoded)
sizes.append(size)

chunks = self._make_chunks(encoded_texts, sizes)

result_ = self._server_communicator.embedd_data(chunks, callback=callback)
if result_ is None:
return [None] * len(texts)

result = list()
assert len(result_) == len(chunks)
for res_chunk, orig_chunk in zip(result_, chunks):
if res_chunk is None:
# when embedder fails (Timeout or other error) result will be None
result.extend([None] * len(orig_chunk))
else:
result.extend(res_chunk)

results = list()
idx = 0
for i in range(len(texts)):
if i in skipped:
results.append(None)
else:
results.append(result[idx])
idx += 1

return results

def _make_chunks(self, encoded_texts, sizes, depth=0):
chunks = np.array_split(encoded_texts, MIN_CHUNKS if depth == 0 else 2)
chunk_sizes = np.array_split(sizes, MIN_CHUNKS if depth == 0 else 2)
result = list()
for i in range(len(chunks)):
# having more than one text in a chunk prevents infinite recursion
# when a single text is bigger than MAX_CHUNK_SIZE
if len(chunks[i]) > 1 and np.sum(chunk_sizes[i]) > MAX_CHUNK_SIZE:
result.extend(self._make_chunks(chunks[i], chunk_sizes[i], depth + 1))
else:
result.append(chunks[i])
return [list(r) for r in result if len(r) > 0]
Parameters
----------
corpus
Corpus on which transform is performed.
Returns
-------
Embeddings
Corpus with new features added.
Skipped documents
Corpus of documents that were not embedded
"""
embs = self(corpus.documents, callback)

# Check whether some documents in the corpus weren't embedded
# for some reason. This is a very rare case.
skipped_documents = [emb is None for emb in embs]
embedded_documents = np.logical_not(skipped_documents)

new_corpus = None
if np.any(embedded_documents):
# if at least one embedding is not None, extend attributes
new_corpus = corpus[embedded_documents]
new_corpus = new_corpus.extend_attributes(
np.array(
[e for e in embs if e],
dtype=float,
),
["Dim{}".format(i + 1) for i in range(EMB_DIM)],
var_attrs={
"embedding-feature": True,
"hidden": True,
},
)

skipped_corpus = None
if np.any(skipped_documents):
skipped_corpus = corpus[skipped_documents].copy()
skipped_corpus.name = "Skipped documents"
warnings.warn(
"Some documents were not embedded for unknown reason. Those "
"documents are skipped.",
RuntimeWarning,
)

return new_corpus, skipped_corpus

def report(self) -> Tuple[Tuple[str, str], ...]:
"""Reports on current parameters of DocumentEmbedder.
Returns
-------
tuple
Tuple of parameters.
"""
return (("Embedder", "Multilingual SBERT"),)

def clear_cache(self):
if self._server_communicator:
self._server_communicator.clear_cache()

def __enter__(self):
return self


class _ServerCommunicator(ServerEmbedderCommunicator):

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.content_type = 'application/json'
self.content_type = "application/json"

async def _encode_data_instance(self, data_instance: Any) -> Optional[bytes]:
return json.dumps(data_instance).encode('utf-8', 'replace')
data = base64.b64encode(
zlib.compress(data_instance.encode("utf-8", "replace"), level=-1)
).decode("utf-8", "replace")
if sys.getsizeof(data) > 500000:
# The document is too large: the size limit is 500 KB
# (after compression), so the document is skipped.
return None
return json.dumps([data]).encode("utf-8", "replace")
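
The per-document size guard in _encode_data_instance is easy to check in isolation. The sketch below reproduces the encoding steps from the diff (zlib-compress, base64-encode, JSON-wrap a single-element list) as a standalone function; encode_document is a hypothetical name used only for illustration.

    import base64
    import json
    import os
    import sys
    import zlib
    from typing import Optional

    def encode_document(text: str) -> Optional[bytes]:
        # Mirrors _encode_data_instance: compress and base64-encode the text,
        # and skip it (return None) if the encoded form exceeds ~500 KB.
        data = base64.b64encode(
            zlib.compress(text.encode("utf-8", "replace"), level=-1)
        ).decode("utf-8", "replace")
        if sys.getsizeof(data) > 500000:
            return None
        return json.dumps([data]).encode("utf-8", "replace")

    print(encode_document("a short document") is not None)  # True: well under the limit

    incompressible = os.urandom(1_000_000).hex()  # ~2 MB of text that barely compresses
    print(encode_document(incompressible) is None)  # True: compressed + encoded size exceeds 500 KB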
(Diffs for the remaining 3 changed files are not shown here.)
