-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
86 lines (63 loc) · 2.21 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""
Entrypoint for the application.

Defines the FastAPI app and its `/agent/`, `/feedback/` and `/upload_file/` endpoints.
"""
from dotenv import load_dotenv
import os
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from typing import List
from utils.io import Output, ChatRequest, FileUploadRequest, generate_chat_history, generate_reference_output, generate_formatted_docs, clean_text
from ingest import ingest_docs
from chain import get_chain
from upload import upload_to_gcs
# Run dotenv's loader first so that later os.getenv() lookups (e.g. the
# GOOGLE_CLOUD_PROJECT_ID read in upload_file) happen after it.
load_dotenv()

app = FastAPI(
    title="LangChain Server",
    description="A simple API server using LangChain's Runnable interfaces",
)

# CORS policy: every origin, method and request header is accepted,
# credentials are allowed, and every response header is exposed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive -- confirm this is intended outside of development.
_cors_settings = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)
def _build_agent_response(request: ChatRequest) -> dict:
    """Run the blocking retrieval + generation pipeline for one chat request.

    Kept synchronous on purpose: every helper called here is a plain
    (non-awaitable) call, so ``agent`` runs this function in a worker thread
    instead of on the event loop.

    Args:
        request: Chat request; the attributes read are ``params``
            (``langchain_params`` and ``model_params``), ``file_name``,
            ``model``, ``chat_history``, ``input``, ``regen_count`` and
            ``instruction``.

    Returns:
        A dict with the keys ``model``, ``model_params``, ``input``,
        ``answer``, ``reference1`` and ``langchain_params``.
    """
    db = ingest_docs(request.params.langchain_params, file_names=request.file_name)
    chain = get_chain(request.model, request.params.model_params.dict())
    chat_history = generate_chat_history(request.chat_history)

    # The number of documents requested grows with the regeneration count.
    docs = db.similarity_search(request.input, k=request.regen_count + 1)
    formatted_docs = clean_text(generate_formatted_docs(docs))
    # NOTE: generate_reference_output(docs) and the "page" response field are
    # currently disabled.

    model_output = chain.run(
        {
            "instruction": request.instruction,
            "input": request.input,
            "chat_history": chat_history,
            "retrieved_document": formatted_docs,
        }
    )

    return {
        "model": request.model,
        "model_params": request.params.model_params,
        "input": request.input,
        "answer": model_output,
        "reference1": formatted_docs,
        "langchain_params": request.params.langchain_params,
    }


@app.post("/agent/")
async def agent(request: ChatRequest) -> dict:
    """Handle a chat request and return the model answer with its references.

    The document ingestion, similarity search and chain execution are all
    synchronous calls; running them directly inside this coroutine would
    block the event loop (and therefore every other request) until they
    finish, so the work is delegated to the server's threadpool.

    Args:
        request: The chat request body.

    Returns:
        The response dict built by ``_build_agent_response``.
    """
    # Local import keeps this block self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    return await run_in_threadpool(_build_agent_response, request)
@app.post("/feedback/")
async def feedback(request):
    """Accept a feedback submission, print it to stdout and acknowledge it.

    NOTE(review): ``request`` carries no type annotation, so FastAPI will
    presumably read it from the query string rather than from a JSON body --
    confirm this matches what clients actually send.

    Args:
        request: The submitted feedback value.

    Returns:
        A fixed acknowledgement payload, ``{"output": "OK"}``.
    """
    acknowledgement = {"output": "OK"}
    print(request)
    return acknowledgement
@app.post("/upload_file/")
async def upload_file(request: FileUploadRequest):
    """Handle a file upload to GCS.

    ``upload_to_gcs`` is a synchronous call; invoking it directly inside this
    coroutine would block the event loop for the whole transfer, so it is run
    in the server's threadpool instead.

    Args:
        request: Upload request; its ``url`` and ``file_name`` attributes are
            forwarded to ``upload_to_gcs``.

    Returns:
        ``{"message": "file uploaded successfully"}`` once the upload call
        returns without raising.
    """
    # Local import keeps this block self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    await run_in_threadpool(
        upload_to_gcs,
        # May be None when the variable is unset; passed through unchanged.
        project_id=os.getenv("GOOGLE_CLOUD_PROJECT_ID"),
        url=request.url,
        file_name=request.file_name,
    )
    return {"message": "file uploaded successfully"}
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when this file is run as a
    # script rather than imported as a module.
    import uvicorn

    # Listen on all interfaces, port 8000.
    bind_host, bind_port = "0.0.0.0", 8000
    uvicorn.run(app, host=bind_host, port=bind_port)