-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrag.py
45 lines (35 loc) · 1.98 KB
/
rag.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores.chroma import Chroma
from tools.prompt_templates import rag_prompt_template
class RetrievalAugmentedGeneration:
    """Build a Chroma-backed RetrievalQA pipeline over a local text file.

    Loads and chunks the file at ``text_file_path``, embeds the chunks into
    a Chroma vector store, and wires a "stuff" RetrievalQA chain around a
    GPT-4 chat model using ``rag_prompt_template``.
    """

    # Shared chunking configuration — used for both the initial load and any
    # later additions so chunk geometry stays consistent across the store.
    CHUNK_SIZE = 500
    CHUNK_OVERLAP = 50

    def __init__(self, api_base, text_file_path):
        """Initialize the pipeline.

        Args:
            api_base: Base URL of an OpenAI-compatible endpoint; forwarded to
                the chat model. May be ``None`` to use the default endpoint.
            text_file_path: Path to the plain-text file to index.
        """
        self.api_base = api_base
        self.text_file_path = text_file_path
        self.vector_store = self.create_vector_store()
        self.qa_chain = self.setup_qa_chain()

    def _make_splitter(self):
        """Return the single text splitter used for all chunking here."""
        return RecursiveCharacterTextSplitter(
            chunk_size=self.CHUNK_SIZE,
            chunk_overlap=self.CHUNK_OVERLAP,
            length_function=len,
        )

    def create_vector_store(self):
        """Embed the source file's chunks and return a Chroma retriever."""
        texts = self.load_texts(self.text_file_path)
        # TODO: replace with HuggingFaceBgeEmbeddings(model_name="BAAI/bge-small-en-v1.5")
        embeddings = OpenAIEmbeddings()
        return Chroma.from_documents(texts, embeddings).as_retriever()

    def setup_qa_chain(self):
        """Build the RetrievalQA chain over the stored retriever.

        Fix: ``api_base`` was accepted by ``__init__`` but never used; it is
        now forwarded so a non-default OpenAI-compatible deployment works.
        """
        llm = ChatOpenAI(model_name="gpt-4", openai_api_base=self.api_base)
        return RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=self.vector_store,
            chain_type_kwargs={"prompt": rag_prompt_template},
            return_source_documents=True,
        )

    def add_documents(self, texts):
        """Chunk *texts* and append the chunks to the vector store.

        NOTE(review): ``self.vector_store`` holds a retriever, not the Chroma
        store itself; ``VectorStoreRetriever.add_documents`` delegates to the
        underlying store only in recent langchain releases — confirm the
        pinned version supports it.
        """
        split_texts = self._make_splitter().split_documents(texts)
        self.vector_store.add_documents(split_texts)

    def load_texts(self, text_file_path):
        """Load the file at *text_file_path* and split it into chunks."""
        documents = TextLoader(text_file_path).load()
        return self._make_splitter().split_documents(documents)