-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
105 lines (85 loc) · 3.39 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# import required dependencies
# https://docs.chainlit.io/integrations/langchain
import os
from langchain import hub
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.llms import Ollama
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
import chainlit as cl
from langchain.chains import RetrievalQA
# Absolute directory containing this script; the persisted Chroma index is
# expected in a "db" folder next to it.
ABS_PATH: str = os.path.dirname(os.path.abspath(__file__))
DB_DIR: str = os.path.join(ABS_PATH, "db")

# LangChain Hub handle of the Mistral-formatted RAG prompt fed to the
# RetrievalQA chain. The prompt is pulled once, when this module is imported
# (NOTE(review): presumably a network fetch from the Hub — confirm).
_RAG_PROMPT_HANDLE = "rlm/rag-prompt-mistral"
rag_prompt_mistral = hub.pull(_RAG_PROMPT_HANDLE)
def load_model(model: str = "mistral"):
    """Create the Ollama LLM used to answer questions.

    Generated tokens are echoed to stdout through a
    ``StreamingStdOutCallbackHandler``.

    Args:
        model: Name of the Ollama model to run. Defaults to ``"mistral"``,
            matching the model the rest of this module is written for.

    Returns:
        A configured ``Ollama`` LLM instance.
    """
    # ``callbacks`` replaces the deprecated ``callback_manager`` argument;
    # LangChain builds the callback manager from this list itself.
    llm = Ollama(
        model=model,
        verbose=True,
        callbacks=[StreamingStdOutCallbackHandler()],
    )
    return llm
def retrieval_qa_chain(llm, vectorstore):
    """Build a RetrievalQA chain that answers with *llm* over *vectorstore*.

    The chain is prompted with the module-level ``rag_prompt_mistral`` and is
    configured (``return_source_documents=True``) to include the retrieved
    documents in its output alongside the answer.
    """
    retriever = vectorstore.as_retriever()
    chain_options = {"prompt": rag_prompt_mistral}
    return RetrievalQA.from_chain_type(
        llm,
        retriever=retriever,
        chain_type_kwargs=chain_options,
        return_source_documents=True,
    )
def qa_bot():
    """Assemble the retrieval QA chain used for one chat session.

    Loads the Ollama LLM, opens the persisted Chroma store located at
    ``DB_DIR`` with Ollama ``"mistral"`` embeddings, and combines the two via
    ``retrieval_qa_chain``.

    Returns:
        The chain returned by ``retrieval_qa_chain``.
    """
    model = load_model()
    # NOTE(review): the embedding model presumably has to match the one used
    # when the index in DB_DIR was built — confirm against the ingest script.
    store = Chroma(
        persist_directory=DB_DIR,
        embedding_function=OllamaEmbeddings(model="mistral"),
    )
    return retrieval_qa_chain(model, store)
@cl.on_chat_start
async def start():
    """Prepare the QA bot when a new chat is opened.

    Builds a fresh retrieval QA chain, sends a "Starting the bot..." message,
    edits that same message into the welcome greeting, and finally stores the
    chain in the user session under ``"chain"`` for the message handler.
    """
    qa_chain = qa_bot()

    status_msg = cl.Message(content="Starting the bot...")
    await status_msg.send()

    greeting = (
        "Hi, Welcome to Chat With Documents using Ollama (mistral model) and LangChain."
    )
    status_msg.content = greeting
    await status_msg.update()

    cl.user_session.set("chain", qa_chain)
@cl.on_message
async def main(message):
    """Answer an incoming chat message using the session's QA chain.

    Fetches the chain that ``start`` stored in the user session, runs it on
    the message text with a Chainlit LangChain callback handler attached,
    and replies with the answer. Each retrieved source document is attached
    as a named ``cl.Text`` element, and the element names are appended to
    the answer (or "No sources found" when nothing was retrieved).

    Args:
        message: The incoming Chainlit message; only ``message.content`` is used.
    """
    chain = cl.user_session.get("chain")
    cb = cl.AsyncLangchainCallbackHandler()
    # NOTE(review): presumably makes the handler treat the output as the final
    # answer from the start — confirm against the Chainlit callback docs.
    cb.answer_reached = True

    # ``ainvoke`` supersedes the deprecated ``Chain.acall`` (which now just
    # delegates to it); callbacks are passed through the run config.
    res = await chain.ainvoke(message.content, config={"callbacks": [cb]})

    answer = res["result"]
    source_documents = res["source_documents"]

    # One named text element per retrieved document chunk.
    text_elements: list[cl.Text] = [
        cl.Text(content=source_doc.page_content, name=f"source_{source_idx}")
        for source_idx, source_doc in enumerate(source_documents or [])
    ]

    if text_elements:
        source_names = ", ".join(text_el.name for text_el in text_elements)
        answer += f"\nSources: {source_names}"
    else:
        answer += "\nNo sources found"

    await cl.Message(content=answer, elements=text_elements).send()