Merge pull request #9 from mlx-chat/MLC-20
[MLC-20] server: use a specialized embedding model for document indexing
stockeh authored Feb 28, 2024
2 parents 47a0c2d + 22a70f6 commit 5fc726a
Showing 3 changed files with 103 additions and 36 deletions.
80 changes: 80 additions & 0 deletions server/retriever/embeddings.py
@@ -0,0 +1,80 @@
import mlx.core as mx
import mlx.nn as nn

import torch
import torch.nn.functional as F
from torch import Tensor

from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizer
from abc import ABC, abstractmethod
from typing import Any, List


class Embeddings(ABC):
    """Interface for embedding models."""

    @abstractmethod
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs."""

    @abstractmethod
    def embed_query(self, text: str) -> List[float]:
        """Embed query text."""


class E5Embeddings(Embeddings):

    model: Any = None
    tokenizer: PreTrainedTokenizer = None

    def __init__(self, model_name: str = 'intfloat/multilingual-e5-small'):
        self.model = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def _average_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        # Zero out padding positions, then average over the sequence dimension.
        last_hidden = last_hidden_states.masked_fill(
            ~attention_mask[..., None].bool(), 0.0)
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    def embed_documents(self, texts: List[str], batch_size: int = 1) -> List[List[float]]:
        embeddings = []
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i+batch_size]
            batch_embeddings = self.embed_query(batch_texts, batch=True)
            embeddings.extend(batch_embeddings)
        return embeddings

    @torch.no_grad()
    def embed_query(self, texts: Any, batch: bool = False) -> List[Any]:
        # Accepts a single string, or a list of strings when batch=True.
        batch_dict = self.tokenizer(texts, max_length=512, padding=True,
                                    truncation=True, return_tensors='pt', return_attention_mask=True)
        outputs = self.model(**batch_dict)
        embeddings = self._average_pool(
            outputs.last_hidden_state, batch_dict['attention_mask'])
        embeddings = F.normalize(embeddings, p=2, dim=1)

        if batch:
            return embeddings.tolist()  # -> List[List[float]]

        return embeddings[0].tolist()  # -> List[float]


class ChatEmbeddings(Embeddings):

    model: nn.Module = None
    tokenizer: PreTrainedTokenizer = None

    def __init__(self, model: nn.Module, tokenizer: PreTrainedTokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_query(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        h = self.model.embed_tokens(mx.array(
            self.tokenizer.encode(text, add_special_tokens=False)))
        # Mean-pool over tokens, then normalize to unit length.
        h = mx.mean(h, axis=0)
        h = h / mx.linalg.norm(h)
        return h.tolist()
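
For orientation, a minimal sketch of how the new backend could be exercised directly. The import path is assumed from the repo layout, the document strings are made up, and it needs the torch/transformers dependencies used above:

    from server.retriever.embeddings import E5Embeddings

    # Specialized embedder: returns unit-normalized vectors (see F.normalize above).
    embedder = E5Embeddings()  # defaults to 'intfloat/multilingual-e5-small'

    docs = ['MLX is an array framework for Apple silicon.',  # hypothetical corpus
            'Chroma is an open-source vector database.']
    doc_vectors = embedder.embed_documents(docs, batch_size=1)
    query_vector = embedder.embed_query('What is MLX?')

    # With unit-normalized vectors, cosine similarity is just a dot product.
    scores = [sum(q * d for q, d in zip(query_vector, vec)) for vec in doc_vectors]
    print(scores)  # the first document should score highest

ChatEmbeddings keeps the old behavior, mean-pooling the chat model's own token-embedding table; a purpose-trained encoder like E5 is generally stronger at retrieval, which is the point of this PR.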
23 changes: 2 additions & 21 deletions server/retriever/vectorstore.py
@@ -12,7 +12,6 @@
    Callable,
    Iterable,
    Optional,
    Literal,
    Tuple,
    Type,
)
@@ -21,6 +20,8 @@
import chromadb.config
from chromadb.api.types import ID, OneOrMany, Where, WhereDocument

from .embeddings import Embeddings

Chroma = TypeVar('Chroma', bound='Chroma')


@@ -118,26 +119,6 @@ def maximal_marginal_relevance(
    return idxs


class Embeddings():

    type: Literal["Embeddings"] = "Embeddings"

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_query(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        h = self.model.embed_tokens(mx.array(
            self.tokenizer.encode(text, add_special_tokens=False)))
        # normalized to have unit length
        h = mx.mean(h, axis=0)
        h = h / mx.linalg.norm(h)
        return h.tolist()


class Chroma():
    """
    similarity_search
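The retrieval path still routes through maximal marginal relevance (the hunk above touches maximal_marginal_relevance, and server.py calls max_marginal_relevance_search). As a reference, a self-contained sketch of the standard MMR selection loop over unit-normalized embeddings; it illustrates the technique and is not the literal implementation in vectorstore.py:

    import numpy as np

    def mmr_rerank(query_emb: np.ndarray, doc_embs: np.ndarray,
                   k: int = 4, lambda_mult: float = 0.5) -> list:
        """Select k docs balancing query relevance against redundancy."""
        query_sim = doc_embs @ query_emb  # cosine similarity (unit-normalized inputs)
        selected: list = []
        candidates = list(range(len(doc_embs)))
        while candidates and len(selected) < k:
            if not selected:
                best = candidates[int(np.argmax(query_sim[candidates]))]
            else:
                best, best_score = None, -np.inf
                for i in candidates:
                    # Penalize similarity to anything already selected.
                    redundancy = float(np.max(doc_embs[selected] @ doc_embs[i]))
                    score = lambda_mult * query_sim[i] - (1 - lambda_mult) * redundancy
                    if score > best_score:
                        best, best_score = i, score
            selected.append(best)
            candidates.remove(best)
        return selected

With lambda_mult=1.0 this reduces to plain similarity ranking; lower values penalize redundancy harder, which is the diversity that reportedly made MMR work better than plain similarity search here.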
36 changes: 21 additions & 15 deletions server/server.py
@@ -14,7 +14,8 @@

from .retriever.loader import directory_loader
from .retriever.splitter import RecursiveCharacterTextSplitter
from .retriever.vectorstore import Chroma, Embeddings
from .retriever.vectorstore import Chroma
from .retriever.embeddings import ChatEmbeddings, E5Embeddings

_model: Optional[nn.Module] = None
_tokenizer: Optional[PreTrainedTokenizer] = None
@@ -27,16 +28,20 @@ def load_model(model_path: str, adapter_file: Optional[str] = None):
    _model, _tokenizer = load(model_path, adapter_file=adapter_file)


def load_database(directory: str):
def load_database(directory: str, use_embedding: bool = True):
    global _database
    # TODO: handle error from directory_loader on invalid directory
    raw_docs = directory_loader(directory)
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=4000, chunk_overlap=200, add_start_index=True
        chunk_size=512, chunk_overlap=32, add_start_index=True
    )
    embedding = E5Embeddings() if use_embedding else ChatEmbeddings(
        model=_model.model, tokenizer=_tokenizer)
    splits = text_splitter.split_documents(raw_docs)
    _database = Chroma.from_documents(
        documents=splits, embedding=Embeddings(_model.model, _tokenizer))
        documents=splits,
        embedding=embedding
    )


def create_response(chat_id, prompt, tokens, text):
@@ -66,24 +71,23 @@ def create_response(chat_id, prompt, tokens, text):
    return response


def format_messages(messages, condition):
def format_messages(messages, context):
    failedString = "ERROR"
    if condition:
    if context:
        messages[-1]['content'] = f"""
Only using the documents in the index, answer the following, Respond with just the answer, no "The answer is" or "Answer: " or anything like that.
Question:
Only using the documents in the index, answer the following, respond with just the answer without "The answer is:" or "Answer:" or anything like that.
<BEGIN_QUESTION>
{messages[-1]['content']}
</END_QUESTION>
Index:
{condition}
<BEGIN_INDEX>
{context}
</END_INDEX>
Remember, if you do not know the answer, just say "{failedString}",
Try to give as much detail as possible, but only from what is provided within the index.
If steps are given, you MUST ALWAYS use bullet points to list each of them and you MUST use markdown when applicable.
You MUST use markdown when applicable.
Only use information you can find in the index, do not make up knowledge.
Remember, use bullet points or numbered steps to better organize your answer if applicable.
NEVER try to make up the answer, always return "{failedString}" if you do not know the answer or it's not provided in the index.
@@ -122,9 +126,11 @@ def handle_post_request(self, post_data):
        chat_id = f'chatcmpl-{uuid.uuid4()}'

        load_database(body.get('directory', None))
        # empirically better than similarity_search
        # empirically better than `similarity_search`
        docs = _database.max_marginal_relevance_search(
            body['messages'][-1]['content'])
            body['messages'][-1]['content'],
            k=4  # number of documents to return
        )
        context = '\n'.join([doc.page_content for doc in docs])
        print(body, flush=True)
        print(('\n'+'--'*10+'\n').join([
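Taken together, the request path after this change looks roughly like the following condensed driver. This is a hypothetical sketch: the model path and documents directory are placeholders, and the import details depend on how the server package is run:

    import server.server as srv

    # Hypothetical model path; any MLX-compatible chat model would do here.
    srv.load_model('mlx-community/example-chat-model')

    # Index a directory with the specialized E5 embedder (the new default).
    srv.load_database('./docs', use_embedding=True)

    # Retrieve with MMR -- empirically better here than plain similarity search.
    question = 'How do I configure the splitter?'
    docs = srv._database.max_marginal_relevance_search(question, k=4)
    context = '\n'.join(doc.page_content for doc in docs)

    # Rewrite the last user message so the model answers only from the index.
    messages = [{'role': 'user', 'content': question}]
    srv.format_messages(messages, context)

format_messages mutates the last user message in place, so the chat model only ever sees the question wrapped with the retrieved index content.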
