-
Notifications
You must be signed in to change notification settings - Fork 1
/
cli.py
32 lines (25 loc) · 1.11 KB
/
cli.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from configs.model_config import *
from chains.local_doc_qa import LocalDocQA
import nltk
import models.shared as shared
# Put the project's bundled NLTK data directory ahead of the default search
# locations, so it is consulted first. This rebinds nltk.data.path to a new list.
nltk.data.path = [NLTK_DATA_PATH, *nltk.data.path]
def main(filepath="D:\\qa"):
    """Run an interactive command-line Q&A session over a local document set.

    Configures a ``LocalDocQA`` instance with the embedding settings from
    ``configs.model_config``, builds a vector store from the documents at
    *filepath*, and loads the LLM named by ``LLM_MODEL``. It then repeatedly
    reads a question from stdin and prints the knowledge-based answer, until
    end of input (EOF) or Ctrl+C at the prompt.

    Args:
        filepath: Path to the document(s) to index. Defaults to the previously
            hard-coded ``D:\\qa``, so existing ``main()`` callers keep working.
    """
    local_doc_qa = LocalDocQA()
    local_doc_qa.init_cfg(embedding_model=EMBEDDING_MODEL,
                          embedding_device=EMBEDDING_DEVICE,
                          top_k=VECTOR_SEARCH_TOP_K)

    vs_path, _ = local_doc_qa.init_knowledge_vector_store(filepath)
    # NOTE(review): assumes init_knowledge_vector_store returns a falsy path
    # (e.g. None) when nothing could be loaded -- confirm against
    # chains.local_doc_qa. Checking here avoids loading the LLM for nothing.
    if not vs_path:
        print(f"Failed to build a knowledge base from {filepath!r}; exiting.")
        return

    llm_model_ins = shared.loaderLLM(LLM_MODEL)
    llm_model_ins.history_len = LLM_HISTORY_LEN
    local_doc_qa.llm = llm_model_ins

    history = []
    while True:
        try:
            query = input("Input your question 请输入问题:")
        except (EOFError, KeyboardInterrupt):
            # End of input or Ctrl+C at the prompt: leave cleanly instead of
            # dying with a traceback.
            print()
            break
        if not query.strip():
            # Don't send blank questions to the model.
            continue

        printed_len = 0
        # `history` is rebound on every yield, so the next question sees the
        # updated conversation.
        for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
                                                                     vs_path=vs_path,
                                                                     chat_history=history,
                                                                     streaming=STREAMING):
            if STREAMING:
                # Assumes each streamed chunk carries the cumulative answer so
                # far; print only the part not yet shown instead of
                # re-printing the whole text on every chunk.
                print(resp["result"][printed_len:], end="", flush=True)
                printed_len = len(resp["result"])
            else:
                print(resp["result"])
        if STREAMING:
            # Terminate the streamed answer line before the next prompt.
            print()
# Start the interactive session only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()