Skip to content

Commit

Permalink
Merge pull request #410 from aidangomez/patch-1
Browse files Browse the repository at this point in the history
Add CohereVectorizer for Embed 3
  • Loading branch information
okhat authored Feb 21, 2024
2 parents eb74b1d + 74b66cd commit 44941a2
Showing 1 changed file with 45 additions and 1 deletion.
46 changes: 45 additions & 1 deletion dsp/modules/sentence_vectorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,50 @@ def __call__(self, inp_examples: List["Example"]) -> np.ndarray:
return embeddings


class CohereVectorizer(BaseSentenceVectorizer):
'''
This vectorizer uses the Cohere API to convert texts to embeddings.
More about the available models: https://docs.cohere.com/reference/embed
`api_key` should be passed as an argument and can be retrieved
from https://dashboard.cohere.com/api-keys
'''
def __init__(
self,
api_key: str,
model: str = 'embed-english-v3.0',
embed_batch_size: int = 96,
embedding_type: str = 'search_document' # for details check Cohere embed docs
):
self.model = model
self.embed_batch_size = embed_batch_size
self.embedding_type = embedding_type

import cohere
self.client = cohere.Client(api_key)

def __call__(self, inp_examples: List["Example"]) -> np.ndarray:
text_to_vectorize = self._extract_text_from_examples(inp_examples)

embeddings_list = []

n_batches = (len(text_to_vectorize) - 1) // self.embed_batch_size + 1
for cur_batch_idx in range(n_batches):
start_idx = cur_batch_idx * self.embed_batch_size
end_idx = (cur_batch_idx + 1) * self.embed_batch_size
cur_batch = text_to_vectorize[start_idx: end_idx]

response = self.client.embed(
texts=cur_batch,
model=self.model,
input_type=self.embedding_type
)

embeddings_list.extend(response.embeddings)

embeddings = np.array(embeddings_list, dtype=np.float32)
return embeddings


try:
OPENAI_LEGACY = int(openai.version.__version__[0]) == 0
except Exception:
Expand Down Expand Up @@ -158,4 +202,4 @@ def __call__(self, inp_examples: List["Example"]) -> np.ndarray:
embeddings_list.extend(cur_batch_embeddings)

embeddings = np.array(embeddings_list, dtype=np.float32)
return embeddings
return embeddings

0 comments on commit 44941a2

Please sign in to comment.