[MLC-20] server: use a specialized embedding model for document indexing #9

Merged 1 commit on Feb 28, 2024
80 changes: 80 additions & 0 deletions server/retriever/embeddings.py
@@ -0,0 +1,80 @@
import mlx.core as mx
import mlx.nn as nn

import torch
import torch.nn.functional as F
from torch import Tensor

from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizer
from abc import ABC, abstractmethod
from typing import Any, List


class Embeddings(ABC):
    """Interface for embedding models."""

    @abstractmethod
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs."""

    @abstractmethod
    def embed_query(self, text: str) -> List[float]:
        """Embed query text."""


class E5Embeddings(Embeddings):

    model: Any = None
    tokenizer: PreTrainedTokenizer = None

    def __init__(self, model_name: str = 'intfloat/multilingual-e5-small'):
        self.model = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def _average_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        last_hidden = last_hidden_states.masked_fill(
            ~attention_mask[..., None].bool(), 0.0)
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    def embed_documents(self, texts: List[str], batch_size: int = 1) -> List[List[float]]:
        embeddings = []
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i+batch_size]
            batch_embeddings = self.embed_query(batch_texts, batch=True)
            embeddings.extend(batch_embeddings)
        return embeddings

    @torch.no_grad()
    def embed_query(self, texts: Any, batch: bool = False) -> List[Any]:
        batch_dict = self.tokenizer(texts, max_length=512, padding=True,
                                    truncation=True, return_tensors='pt', return_attention_mask=True)
        outputs = self.model(**batch_dict)
        embeddings = self._average_pool(
            outputs.last_hidden_state, batch_dict['attention_mask'])
        embeddings = F.normalize(embeddings, p=2, dim=1)

        if batch:
            return embeddings.tolist()  # -> List[List[float]]

        return embeddings[0].tolist()  # -> List[float]


class ChatEmbeddings(Embeddings):

    model: nn.Module = None
    tokenizer: PreTrainedTokenizer = None

    def __init__(self, model: nn.Module, tokenizer: PreTrainedTokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_query(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        h = self.model.embed_tokens(mx.array(
            self.tokenizer.encode(text, add_special_tokens=False)))
        # normalized to have unit length
        h = mx.mean(h, axis=0)
        h = h / mx.linalg.norm(h)
        return h.tolist()
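
A minimal usage sketch of the new `E5Embeddings` class (illustrative only, not part of the diff; the sample sentences are made up, and the import path assumes the `server` package layout shown above):

```python
# Illustrative sketch, not part of this PR. Requires torch and transformers,
# plus a download of intfloat/multilingual-e5-small on first use.
from server.retriever.embeddings import E5Embeddings

embedder = E5Embeddings()

# embed_documents pushes texts through embed_query(batch=True) in small batches.
doc_vectors = embedder.embed_documents(
    ["MLX is an array framework for Apple silicon.",
     "Chroma is an open-source vector database."],
    batch_size=2,
)
query_vector = embedder.embed_query("What is MLX?")

# Vectors are L2-normalized, so a plain dot product acts as cosine similarity.
scores = [sum(q * d for q, d in zip(query_vector, doc)) for doc in doc_vectors]
print(scores)  # the MLX sentence should score highest
```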
23 changes: 2 additions & 21 deletions server/retriever/vectorstore.py
@@ -12,7 +12,6 @@
    Callable,
    Iterable,
    Optional,
    Literal,
    Tuple,
    Type,
)
@@ -21,6 +20,8 @@
import chromadb.config
from chromadb.api.types import ID, OneOrMany, Where, WhereDocument

from .embeddings import Embeddings

Chroma = TypeVar('Chroma', bound='Chroma')


@@ -118,26 +119,6 @@ def maximal_marginal_relevance(
    return idxs


class Embeddings():

    type: Literal["Embeddings"] = "Embeddings"

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_query(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        h = self.model.embed_tokens(mx.array(
            self.tokenizer.encode(text, add_special_tokens=False)))
        # normalized to have unit length
        h = mx.mean(h, axis=0)
        h = h / mx.linalg.norm(h)
        return h.tolist()


class Chroma():
    """
    similarity_search
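
With `Embeddings` now living in `server/retriever/embeddings.py`, the vector store only depends on the two-method interface, so any object providing `embed_documents` and `embed_query` can be handed to `Chroma`. A hypothetical stand-in (e.g. for tests), sketched under that assumption:

```python
# Hypothetical test double, not part of this PR; it only illustrates the
# interface the Chroma vector store expects from an embedding object.
from typing import List

from server.retriever.embeddings import Embeddings


class FakeEmbeddings(Embeddings):
    """Deterministic toy embedder: character count and vowel count."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_query(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        return [float(len(text)), float(sum(c in "aeiou" for c in text.lower()))]
```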
36 changes: 21 additions & 15 deletions server/server.py
@@ -14,7 +14,8 @@

from .retriever.loader import directory_loader
from .retriever.splitter import RecursiveCharacterTextSplitter
from .retriever.vectorstore import Chroma, Embeddings
from .retriever.vectorstore import Chroma
from .retriever.embeddings import ChatEmbeddings, E5Embeddings

_model: Optional[nn.Module] = None
_tokenizer: Optional[PreTrainedTokenizer] = None
@@ -27,16 +28,20 @@ def load_model(model_path: str, adapter_file: Optional[str] = None):
    _model, _tokenizer = load(model_path, adapter_file=adapter_file)


def load_database(directory: str):
def load_database(directory: str, use_embedding: bool = True):
    global _database
    # TODO: handle error from directory_loader on invalid
    raw_docs = directory_loader(directory)
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=4000, chunk_overlap=200, add_start_index=True
        chunk_size=512, chunk_overlap=32, add_start_index=True
    )
    embedding = E5Embeddings() if use_embedding else ChatEmbeddings(
        model=_model.model, tokenizer=_tokenizer)
    splits = text_splitter.split_documents(raw_docs)
    _database = Chroma.from_documents(
        documents=splits, embedding=Embeddings(_model.model, _tokenizer))
        documents=splits,
        embedding=embedding
    )


def create_response(chat_id, prompt, tokens, text):
@@ -66,24 +71,23 @@ def create_response(chat_id, prompt, tokens, text):
    return response


def format_messages(messages, condition):
def format_messages(messages, context):
    failedString = "ERROR"
    if condition:
    if context:
        messages[-1]['content'] = f"""
Only using the documents in the index, answer the following, Respond with just the answer, no "The answer is" or "Answer: " or anything like that.

Question:
Only using the documents in the index, answer the following, respond with just the answer without "The answer is:" or "Answer:" or anything like that.

<BEGIN_QUESTION>
{messages[-1]['content']}
</END_QUESTION>

Index:

{condition}
<BEGIN_INDEX>
{context}
</END_INDEX>

Remember, if you do not know the answer, just say "{failedString}",
Try to give as much detail as possible, but only from what is provided within the index.
If steps are given, you MUST ALWAYS use bullet points to list each of them and you MUST use markdown when applicable.
You MUST use markdown when applicable.
Only use information you can find in the index, do not make up knowledge.
Remember, use bullet points or numbered steps to better organize your answer if applicable.
NEVER try to make up the answer, always return "{failedString}" if you do not know the answer or it's not provided in the index.
@@ -122,9 +126,11 @@ def handle_post_request(self, post_data):
        chat_id = f'chatcmpl-{uuid.uuid4()}'

        load_database(body.get('directory', None))
        # empirically better than similarity_search
        # empirically better than `similarity_search`
        docs = _database.max_marginal_relevance_search(
            body['messages'][-1]['content'])
            body['messages'][-1]['content'],
            k=4  # number of documents to return
        )
        context = '\n'.join([doc.page_content for doc in docs])
        print(body, flush=True)
        print(('\n'+'--'*10+'\n').join([
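
Taken together, the indexing and retrieval path introduced by this PR looks roughly like the following sketch (built only from calls visible in the diff; the `./docs` directory and the question string are invented):

```python
# Illustrative sketch of the new indexing + retrieval path; the ./docs
# directory and the question string are invented for this example.
from server.retriever.loader import directory_loader
from server.retriever.splitter import RecursiveCharacterTextSplitter
from server.retriever.vectorstore import Chroma
from server.retriever.embeddings import E5Embeddings

raw_docs = directory_loader("./docs")
splits = RecursiveCharacterTextSplitter(
    chunk_size=512, chunk_overlap=32, add_start_index=True
).split_documents(raw_docs)

database = Chroma.from_documents(documents=splits, embedding=E5Embeddings())

# MMR search trades pure similarity for diversity among the returned chunks.
docs = database.max_marginal_relevance_search("How do I start the server?", k=4)
context = "\n".join(doc.page_content for doc in docs)
print(context)
```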