Add Model2Vec & light-weight installation (#253)
MaartenGr authored Feb 5, 2025
1 parent f0f96a6 commit b5f25d5
Showing 7 changed files with 195 additions and 12 deletions.
28 changes: 28 additions & 0 deletions docs/guides/embeddings.md
@@ -20,6 +20,34 @@
sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
kw_model = KeyBERT(model=sentence_model)
```

### **Model2Vec**

For blazingly fast embedding models, [Model2Vec](https://github.com/MinishLab/model2vec) is an incredible framework. To use it in KeyBERT, you only need to pass its `StaticModel`:

```python
from keybert import KeyBERT
from model2vec import StaticModel

embedding_model = StaticModel.from_pretrained("minishlab/potion-base-8M")
kw_model = KeyBERT(embedding_model)
```
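
Once the `StaticModel` is passed in, keyword extraction works exactly as with any other backend. A minimal sketch (the example document and output are illustrative):

```python
from keybert import KeyBERT
from model2vec import StaticModel

embedding_model = StaticModel.from_pretrained("minishlab/potion-base-8M")
kw_model = KeyBERT(embedding_model)

doc = (
    "Supervised learning is the machine learning task of learning a function "
    "that maps an input to an output based on example input-output pairs."
)
keywords = kw_model.extract_keywords(doc)
# e.g. [('supervised learning', ...), ('machine learning', ...)]
```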

If you want to distill a sentence-transformers model with the vocabulary of the documents,
run the following:

```python
from keybert.backend import Model2VecBackend

embedding_model = Model2VecBackend("sentence-transformers/all-MiniLM-L6-v2", distill=True)
```

Note that distilling is especially helpful if you have a very large dataset (I wouldn't recommend it for small datasets).
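
The backend can then be passed to KeyBERT like any other embedding model. A sketch, assuming `docs` is a list of strings (distillation runs only once, on the first embedding call):

```python
from keybert import KeyBERT
from keybert.backend import Model2VecBackend

embedding_model = Model2VecBackend("sentence-transformers/all-MiniLM-L6-v2", distill=True)
kw_model = KeyBERT(embedding_model)

# The vocabulary for distillation is built from `docs` on this first call
keywords = kw_model.extract_keywords(docs)
```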

!!! Tip
    If you also want a light-weight installation without (sentence-)transformers, you can install KeyBERT as follows:
    `pip install keybert --no-deps`
    `pip install scikit-learn model2vec`
    This will make the installation much smaller and the import much quicker.
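
As a quick sanity check of such a light-weight environment, the Model2Vec path should work while the sentence-transformers backend only fails once it is actually requested. A sketch, assuming `model2vec` is installed and `sentence-transformers` is not:

```python
from keybert import KeyBERT
from model2vec import StaticModel

# Works: model2vec is installed
kw_model = KeyBERT(StaticModel.from_pretrained("minishlab/potion-base-8M"))

# Would raise a ModuleNotFoundError with an install hint,
# since sentence-transformers is absent:
# kw_model = KeyBERT("all-MiniLM-L6-v2")
```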

### 🤗 **Hugging Face Transformers**
To use a Hugging Face transformers model, load in a pipeline and point
to any model found on their model hub (https://huggingface.co/models):
2 changes: 1 addition & 1 deletion docs/guides/keyllm.md
@@ -26,7 +26,7 @@
If you want the full performance and easiest method, you can skip the use cases
!!! Tip
    If you want to use KeyLLM without any of the HuggingFace packages, you can install it as follows:
    `pip install keybert --no-deps`
-   `pip install scikit-learn numpy rich tqdm`
+   `pip install scikit-learn rich tqdm`
    This will make the installation much smaller and the import much quicker.
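
For instance, with such a minimal installation you can still run `KeyLLM` against an API-based LLM. A sketch following the OpenAI pattern used elsewhere in this guide (the API key and documents are placeholders):

```python
import openai
from keybert.llm import OpenAI
from keybert import KeyLLM

# Create your LLM (key handling is illustrative)
client = openai.OpenAI(api_key="MY_API_KEY")
llm = OpenAI(client)

# Load it in KeyLLM and extract keywords
kw_model = KeyLLM(llm)
keywords = kw_model.extract_keywords(MY_DOCUMENTS)
```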

## 1. **Create** Keywords with `KeyLLM`
25 changes: 18 additions & 7 deletions keybert/_highlight.py
@@ -1,14 +1,13 @@
from typing import Tuple, List
-from rich.console import Console
-from rich.highlighter import RegexHighlighter
-from sklearn.feature_extraction.text import CountVectorizer
+
+try:
+    from rich.console import Console
+
+    HAS_RICH = True
+except ModuleNotFoundError:
+    HAS_RICH = False

-class NullHighlighter(RegexHighlighter):
-    """Basic highlighter."""
-
-    base_style = ""
-    highlights = [r""]
+from sklearn.feature_extraction.text import CountVectorizer


def highlight_document(doc: str, keywords: List[Tuple[str, float]], vectorizer: CountVectorizer):
@@ -24,6 +23,10 @@ def highlight_document(doc: str, keywords: List[Tuple[str, float]], vectorizer:
        highlighted_text: The document with additional tags to highlight keywords
                          according to the rich package.
    """
+   if not HAS_RICH:
+       raise ModuleNotFoundError(
+           "The `rich` package is required for highlighting which you can install with `pip install rich`."
+       )
    keywords_only = [keyword for keyword, _ in keywords]
    max_len = vectorizer.ngram_range[1]
@@ -32,6 +35,14 @@ def highlight_document(doc: str, keywords: List[Tuple[str, float]], vectorizer:
    else:
        highlighted_text = _highlight_n_gram(doc, keywords_only, vectorizer)

+   from rich.highlighter import RegexHighlighter
+
+   class NullHighlighter(RegexHighlighter):
+       """Basic highlighter."""
+
+       base_style = ""
+       highlights = [r""]
+
    console = Console(highlighter=NullHighlighter())
    console.print(highlighted_text)
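With this guard, `rich` becomes a truly optional dependency: importing KeyBERT no longer requires it, and only an actual highlight call does. A sketch of the resulting behavior, assuming `rich` is not installed but an embedding backend is:

```python
from keybert import KeyBERT

kw_model = KeyBERT()
doc = "Supervised learning maps inputs to outputs based on example pairs."

# Plain extraction works without rich
keywords = kw_model.extract_keywords(doc)

# Highlighting now fails lazily, with an actionable message
try:
    kw_model.extract_keywords(doc, highlight=True)
except ModuleNotFoundError as error:
    print(error)  # "The `rich` package is required for highlighting ..."
```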
2 changes: 1 addition & 1 deletion keybert/_utils.py
@@ -11,7 +11,7 @@ def __init__(self, tool, dep, custom_msg=None):
        if custom_msg is not None:
            msg += custom_msg
        else:
-           msg += f"pip install bertopic[{self.dep}]\n\n"
+           msg += f"pip install keybert[{self.dep}]\n\n"
        self.msg = msg

    def __getattr__(self, *args, **kwargs):
21 changes: 18 additions & 3 deletions keybert/backend/__init__.py
@@ -1,4 +1,19 @@
-from ._base import BaseEmbedder
-from ._sentencetransformers import SentenceTransformerBackend
+from keybert.backend._base import BaseEmbedder
+from keybert._utils import NotInstalled

-__all__ = ["BaseEmbedder", "SentenceTransformerBackend"]
+# Sentence Transformers
+try:
+    from ._sentencetransformers import SentenceTransformerBackend
+except ModuleNotFoundError:
+    msg = "`pip install sentence-transformers`"
+    SentenceTransformerBackend = NotInstalled("Sentence-Transformers", "sentence-transformers", custom_msg=msg)
+
+# Model2Vec
+try:
+    from ._model2vec import Model2VecBackend
+except ModuleNotFoundError:
+    msg = "`pip install model2vec`"
+    Model2VecBackend = NotInstalled("Model2Vec", "Model2Vec", custom_msg=msg)
+
+
+__all__ = ["BaseEmbedder", "SentenceTransformerBackend", "Model2VecBackend"]
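
The `NotInstalled` placeholder keeps `from keybert.backend import ...` working even when an optional dependency is missing; the error only surfaces on use. A sketch, assuming `model2vec` is not installed:

```python
# The import succeeds: the name is bound to a NotInstalled placeholder
from keybert.backend import Model2VecBackend

# Only using the placeholder raises, with the custom install hint
backend = Model2VecBackend("minishlab/potion-base-8M")
# ModuleNotFoundError: ... `pip install model2vec`
```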
123 changes: 123 additions & 0 deletions keybert/backend/_model2vec.py
@@ -0,0 +1,123 @@
import numpy as np
from typing import List, Union
from model2vec import StaticModel
from sklearn.feature_extraction.text import CountVectorizer

from keybert.backend import BaseEmbedder


class Model2VecBackend(BaseEmbedder):
    """Model2Vec embedding model.

    Arguments:
        embedding_model: Either a model2vec model or a
                         string pointing to a model2vec model
        distill: Indicates whether to distill a sentence-transformers compatible model.
                 The distillation happens the first time documents are embedded.
                 NOTE: Only works if `embedding_model` is a string.
        distill_kwargs: Keyword arguments to pass to the distillation process
                        of `model2vec.distill.distill`
        distill_vectorizer: A CountVectorizer used for creating a custom vocabulary
                            based on the same documents used for keyword extraction.
                            NOTE: If "vocabulary" is in `distill_kwargs`, this will be ignored.

    Examples:
    To use Model2Vec, simply pass it to the KeyBERT model:

    ```python
    from keybert import KeyBERT
    from model2vec import StaticModel

    embedding_model = StaticModel.from_pretrained("minishlab/potion-base-8M")

    # Extract keywords
    kw_model = KeyBERT(embedding_model)
    keywords = kw_model.extract_keywords(my_docs)
    ```

    If you want to distill a sentence-transformers model with the vocabulary of the documents,
    run the following:

    ```python
    from keybert.backend import Model2VecBackend

    embedding_model = Model2VecBackend("sentence-transformers/all-MiniLM-L6-v2", distill=True)
    ```
    """

    def __init__(
        self,
        embedding_model: Union[str, StaticModel],
        distill: bool = False,
        distill_kwargs: dict = {},
        distill_vectorizer: CountVectorizer = None,
    ):
        super().__init__()

        self.distill = distill
        self.distill_kwargs = distill_kwargs
        self.distill_vectorizer = distill_vectorizer
        self._has_distilled = False

        # When we distill, we need a string pointing to a sentence-transformer model
        if self.distill:
            self._check_model2vec_installation()
            if not self.distill_vectorizer:
                self.distill_vectorizer = CountVectorizer()
            if isinstance(embedding_model, str):
                self.embedding_model = embedding_model
            else:
                raise ValueError("Please pass a string pointing to a sentence-transformer model when distilling.")

        # If we don't distill, we can pass a model2vec model directly or load it from a string
        elif isinstance(embedding_model, StaticModel):
            self.embedding_model = embedding_model
        elif isinstance(embedding_model, str):
            self.embedding_model = StaticModel.from_pretrained(embedding_model)
        else:
            raise ValueError(
                "Please select a correct Model2Vec model: \n"
                "`from model2vec import StaticModel` \n"
                "`model = StaticModel.from_pretrained('minishlab/potion-base-8M')`"
            )

    def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray:
        """Embed a list of n documents/words into an n-dimensional
        matrix of embeddings.

        Arguments:
            documents: A list of documents or words to be embedded
            verbose: Controls the verbosity of the process

        Returns:
            Document/words embeddings with shape (n, m) with `n` documents/words
            that each have an embedding size of `m`
        """
        # Distill the model on first use
        if self.distill and not self._has_distilled:
            from model2vec.distill import distill

            # Distill with the vocabulary of the documents
            if not self.distill_kwargs.get("vocabulary"):
                X = self.distill_vectorizer.fit_transform(documents)
                word_counts = np.array(X.sum(axis=0)).flatten()
                words = self.distill_vectorizer.get_feature_names_out()
                vocabulary = [word for word, _ in sorted(zip(words, word_counts), key=lambda x: x[1], reverse=True)]
                self.distill_kwargs["vocabulary"] = vocabulary

            # Distill the model
            self.embedding_model = distill(self.embedding_model, **self.distill_kwargs)

            # Distillation should happen only once and not for every embed call;
            # it should only happen the first time, on the entire vocabulary
            self._has_distilled = True

        # Embed the documents
        embeddings = self.embedding_model.encode(documents, show_progress_bar=verbose)
        return embeddings

    def _check_model2vec_installation(self):
        try:
            from model2vec.distill import distill  # noqa: F401
        except ImportError:
            raise ImportError("To distill a model using model2vec, you need to run `pip install model2vec[distill]`")
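
To steer the vocabulary building, a custom `distill_vectorizer` and extra `distill_kwargs` can be supplied. A sketch (the vectorizer settings and `pca_dims` value are illustrative assumptions, not defaults):

```python
from sklearn.feature_extraction.text import CountVectorizer

from keybert import KeyBERT
from keybert.backend import Model2VecBackend

# Illustrative settings: unigrams appearing in at least two documents,
# distilled down to 256-dimensional static embeddings
backend = Model2VecBackend(
    "sentence-transformers/all-MiniLM-L6-v2",
    distill=True,
    distill_vectorizer=CountVectorizer(ngram_range=(1, 1), min_df=2),
    distill_kwargs={"pca_dims": 256},
)

kw_model = KeyBERT(backend)
keywords = kw_model.extract_keywords(docs)  # distillation runs on this first call
```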
6 changes: 6 additions & 0 deletions keybert/backend/_utils.py
@@ -13,6 +13,12 @@ def select_backend(embedding_model) -> BaseEmbedder:
    if isinstance(embedding_model, BaseEmbedder):
        return embedding_model

+   # Model2Vec embeddings
+   if "model2vec" in str(type(embedding_model)):
+       from keybert.backend._model2vec import Model2VecBackend
+
+       return Model2VecBackend(embedding_model)
+
    # Flair word embeddings
    if "flair" in str(type(embedding_model)):
        from keybert.backend._flair import FlairBackend
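
This duck-typed check is what lets a raw `StaticModel` be handed straight to `KeyBERT` without wrapping it manually. A small sketch of the dispatch:

```python
from model2vec import StaticModel

model = StaticModel.from_pretrained("minishlab/potion-base-8M")

# The selector matches on the type's module path
print("model2vec" in str(type(model)))  # True -> wrapped in Model2VecBackend
```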
