Merge pull request #926 from siiddhantt/feature
Feature: Logging token usage info to MongoDB
dartpain authored Apr 22, 2024
2 parents 130c83e + ab43c20 commit 8873428
Showing 14 changed files with 561 additions and 296 deletions.
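Across these files, the PR threads a `user_api_key` through the answer routes into `RetrieverCreator.create_retriever` and `LLMCreator.create_llm`, so the LLM layer can attribute token usage to the calling API key and record it in MongoDB. The diff below covers only the routing changes; the actual usage write lives in the other changed files. As a rough sketch of what such logging might look like (the database name, the `token_usage` collection, the field names, and the helper are illustrative assumptions, not code from this commit):

import datetime

from pymongo import MongoClient

mongo = MongoClient("mongodb://localhost:27017/")  # assumed local MongoDB
db = mongo["docsgpt"]  # assumed database name
token_usage_collection = db["token_usage"]  # assumed collection name


def log_token_usage(user_api_key, prompt_tokens, generated_tokens):
    # Record one LLM call's token counts against the calling API key.
    token_usage_collection.insert_one(
        {
            "api_key": user_api_key,
            "prompt_tokens": prompt_tokens,
            "generated_tokens": generated_tokens,
            "timestamp": datetime.datetime.utcnow(),
        }
    )

An LLM wrapper would invoke a helper like this after each generation call, which is why every route below now passes `user_api_key` down to both the retriever and the LLM.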
222 changes: 150 additions & 72 deletions application/api/answer/routes.py
@@ -10,14 +10,12 @@
 from bson.objectid import ObjectId
 
 
-
 from application.core.settings import settings
 from application.llm.llm_creator import LLMCreator
 from application.retriever.retriever_creator import RetrieverCreator
 from application.error import bad_request
 
 
-
 logger = logging.getLogger(__name__)
 
 mongo = MongoClient(settings.MONGO_URI)
@@ -26,20 +24,22 @@
 vectors_collection = db["vectors"]
 prompts_collection = db["prompts"]
 api_key_collection = db["api_keys"]
-answer = Blueprint('answer', __name__)
+answer = Blueprint("answer", __name__)
 
 gpt_model = ""
 # to have some kind of default behaviour
 if settings.LLM_NAME == "openai":
-    gpt_model = 'gpt-3.5-turbo'
+    gpt_model = "gpt-3.5-turbo"
 elif settings.LLM_NAME == "anthropic":
-    gpt_model = 'claude-2'
+    gpt_model = "claude-2"
 
 if settings.MODEL_NAME:  # in case there is particular model name configured
     gpt_model = settings.MODEL_NAME
 
 # load the prompts
-current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+current_dir = os.path.dirname(
+    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+)
 with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f:
     chat_combine_template = f.read()
 
@@ -50,7 +50,7 @@
     chat_combine_creative = f.read()
 
 with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f:
-    chat_combine_strict = f.read()
+    chat_combine_strict = f.read()
 
 api_key_set = settings.API_KEY is not None
 embeddings_key_set = settings.EMBEDDINGS_KEY is not None
@@ -61,8 +61,6 @@ async def async_generate(chain, question, chat_history):
     return result
 
 
-
-
 def run_async_chain(chain, question, chat_history):
     loop = asyncio.new_event_loop()
     asyncio.set_event_loop(loop)
@@ -74,17 +72,18 @@ def run_async_chain(chain, question, chat_history):
         result["answer"] = answer
     return result
 
+
 def get_data_from_api_key(api_key):
     data = api_key_collection.find_one({"key": api_key})
     if data is None:
         return bad_request(401, "Invalid API key")
     return data
 
 
 def get_vectorstore(data):
     if "active_docs" in data:
         if data["active_docs"].split("/")[0] == "default":
-            vectorstore = ""
+            vectorstore = ""
         elif data["active_docs"].split("/")[0] == "local":
             vectorstore = "indexes/" + data["active_docs"]
         else:
@@ -98,52 +97,82 @@ def get_vectorstore(data):
 
 
 def is_azure_configured():
-    return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME
+    return (
+        settings.OPENAI_API_BASE
+        and settings.OPENAI_API_VERSION
+        and settings.AZURE_DEPLOYMENT_NAME
+    )
 
 
 def save_conversation(conversation_id, question, response, source_log_docs, llm):
     if conversation_id is not None and conversation_id != "None":
         conversations_collection.update_one(
             {"_id": ObjectId(conversation_id)},
-            {"$push": {"queries": {"prompt": question, "response": response, "sources": source_log_docs}}},
+            {
+                "$push": {
+                    "queries": {
+                        "prompt": question,
+                        "response": response,
+                        "sources": source_log_docs,
+                    }
+                }
+            },
         )
 
     else:
         # create new conversation
         # generate summary
-        messages_summary = [{"role": "assistant", "content": "Summarise following conversation in no more than 3 "
-                                                             "words, respond ONLY with the summary, use the same "
-                                                             "language as the system \n\nUser: " + question + "\n\n" +
-                                                             "AI: " +
-                                                             response},
-                            {"role": "user", "content": "Summarise following conversation in no more than 3 words, "
-                                                        "respond ONLY with the summary, use the same language as the "
-                                                        "system"}]
-
-        completion = llm.gen(model=gpt_model,
-                             messages=messages_summary, max_tokens=30)
+        messages_summary = [
+            {
+                "role": "assistant",
+                "content": "Summarise following conversation in no more than 3 "
+                "words, respond ONLY with the summary, use the same "
+                "language as the system \n\nUser: "
+                + question
+                + "\n\n"
+                + "AI: "
+                + response,
+            },
+            {
+                "role": "user",
+                "content": "Summarise following conversation in no more than 3 words, "
+                "respond ONLY with the summary, use the same language as the "
+                "system",
+            },
+        ]
+
+        completion = llm.gen(model=gpt_model, messages=messages_summary, max_tokens=30)
         conversation_id = conversations_collection.insert_one(
-            {"user": "local",
-             "date": datetime.datetime.utcnow(),
-             "name": completion,
-             "queries": [{"prompt": question, "response": response, "sources": source_log_docs}]}
+            {
+                "user": "local",
+                "date": datetime.datetime.utcnow(),
+                "name": completion,
+                "queries": [
+                    {
+                        "prompt": question,
+                        "response": response,
+                        "sources": source_log_docs,
+                    }
+                ],
+            }
         ).inserted_id
     return conversation_id
 
 
 def get_prompt(prompt_id):
-    if prompt_id == 'default':
+    if prompt_id == "default":
         prompt = chat_combine_template
-    elif prompt_id == 'creative':
+    elif prompt_id == "creative":
         prompt = chat_combine_creative
-    elif prompt_id == 'strict':
+    elif prompt_id == "strict":
         prompt = chat_combine_strict
     else:
         prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
     return prompt
 
 
-def complete_stream(question, retriever, conversation_id):
-
-
+def complete_stream(question, retriever, conversation_id, user_api_key):
 
     response_full = ""
     source_log_docs = []
     answer = retriever.gen()
@@ -155,9 +184,12 @@ def complete_stream(question, retriever, conversation_id):
         elif "source" in line:
             source_log_docs.append(line["source"])
 
-
-    llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
-    conversation_id = save_conversation(conversation_id, question, response_full, source_log_docs, llm)
+    llm = LLMCreator.create_llm(
+        settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
+    )
+    conversation_id = save_conversation(
+        conversation_id, question, response_full, source_log_docs, llm
+    )
 
     # send data.type = "end" to indicate that the stream has ended as json
     data = json.dumps({"type": "id", "id": str(conversation_id)})
@@ -180,41 +212,60 @@ def stream():
         conversation_id = None
     else:
         conversation_id = data["conversation_id"]
-    if 'prompt_id' in data:
+    if "prompt_id" in data:
         prompt_id = data["prompt_id"]
     else:
-        prompt_id = 'default'
-    if 'selectedDocs' in data and data['selectedDocs'] is None:
+        prompt_id = "default"
+    if "selectedDocs" in data and data["selectedDocs"] is None:
         chunks = 0
-    elif 'chunks' in data:
+    elif "chunks" in data:
         chunks = int(data["chunks"])
     else:
         chunks = 2
 
     prompt = get_prompt(prompt_id)
 
     # check if active_docs is set
 
     if "api_key" in data:
         data_key = get_data_from_api_key(data["api_key"])
         source = {"active_docs": data_key["source"]}
+        user_api_key = data["api_key"]
     elif "active_docs" in data:
         source = {"active_docs": data["active_docs"]}
+        user_api_key = None
     else:
         source = {}
+        user_api_key = None
 
-    if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local":
+    if (
+        source["active_docs"].split("/")[0] == "default"
+        or source["active_docs"].split("/")[0] == "local"
+    ):
         retriever_name = "classic"
     else:
-        retriever_name = source['active_docs']
-
-    retriever = RetrieverCreator.create_retriever(retriever_name, question=question,
-                                                  source=source, chat_history=history, prompt=prompt, chunks=chunks, gpt_model=gpt_model
-                                                  )
+        retriever_name = source["active_docs"]
+
+    retriever = RetrieverCreator.create_retriever(
+        retriever_name,
+        question=question,
+        source=source,
+        chat_history=history,
+        prompt=prompt,
+        chunks=chunks,
+        gpt_model=gpt_model,
+        user_api_key=user_api_key,
+    )
 
     return Response(
-        complete_stream(question=question, retriever=retriever,
-                        conversation_id=conversation_id), mimetype="text/event-stream")
+        complete_stream(
+            question=question,
+            retriever=retriever,
+            conversation_id=conversation_id,
+            user_api_key=user_api_key,
+        ),
+        mimetype="text/event-stream",
+    )
 
 
 @answer.route("/api/answer", methods=["POST"])
@@ -230,15 +281,15 @@ def api_answer():
     else:
         conversation_id = data["conversation_id"]
     print("-" * 5)
-    if 'prompt_id' in data:
+    if "prompt_id" in data:
         prompt_id = data["prompt_id"]
     else:
-        prompt_id = 'default'
-    if 'chunks' in data:
+        prompt_id = "default"
+    if "chunks" in data:
         chunks = int(data["chunks"])
     else:
         chunks = 2
 
     prompt = get_prompt(prompt_id)
 
     # use try and except to check for exception
@@ -247,30 +298,45 @@
         if "api_key" in data:
             data_key = get_data_from_api_key(data["api_key"])
             source = {"active_docs": data_key["source"]}
+            user_api_key = data["api_key"]
         else:
             source = {data}
+            user_api_key = None
 
-        if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local":
+        if (
+            source["active_docs"].split("/")[0] == "default"
+            or source["active_docs"].split("/")[0] == "local"
+        ):
             retriever_name = "classic"
         else:
-            retriever_name = source['active_docs']
-
-        retriever = RetrieverCreator.create_retriever(retriever_name, question=question,
-                                                      source=source, chat_history=history, prompt=prompt, chunks=chunks, gpt_model=gpt_model
-                                                      )
+            retriever_name = source["active_docs"]
+
+        retriever = RetrieverCreator.create_retriever(
+            retriever_name,
+            question=question,
+            source=source,
+            chat_history=history,
+            prompt=prompt,
+            chunks=chunks,
+            gpt_model=gpt_model,
+            user_api_key=user_api_key,
+        )
         source_log_docs = []
         response_full = ""
        for line in retriever.gen():
             if "source" in line:
                 source_log_docs.append(line["source"])
             elif "answer" in line:
                 response_full += line["answer"]
 
-        llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
-
-
+        llm = LLMCreator.create_llm(
+            settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
+        )
 
         result = {"answer": response_full, "sources": source_log_docs}
-        result["conversation_id"] = save_conversation(conversation_id, question, response_full, source_log_docs, llm)
+        result["conversation_id"] = save_conversation(
+            conversation_id, question, response_full, source_log_docs, llm
+        )
 
         return result
     except Exception as e:
@@ -289,23 +355,35 @@ def api_search():
     if "api_key" in data:
         data_key = get_data_from_api_key(data["api_key"])
         source = {"active_docs": data_key["source"]}
+        user_api_key = data["api_key"]
     elif "active_docs" in data:
         source = {"active_docs": data["active_docs"]}
+        user_api_key = None
     else:
         source = {}
-    if 'chunks' in data:
+        user_api_key = None
+    if "chunks" in data:
         chunks = int(data["chunks"])
     else:
         chunks = 2
 
-    if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local":
+    if (
+        source["active_docs"].split("/")[0] == "default"
+        or source["active_docs"].split("/")[0] == "local"
+    ):
         retriever_name = "classic"
     else:
-        retriever_name = source['active_docs']
-
-    retriever = RetrieverCreator.create_retriever(retriever_name, question=question,
-                                                  source=source, chat_history=[], prompt="default", chunks=chunks, gpt_model=gpt_model
-                                                  )
+        retriever_name = source["active_docs"]
+
+    retriever = RetrieverCreator.create_retriever(
+        retriever_name,
+        question=question,
+        source=source,
+        chat_history=[],
+        prompt="default",
+        chunks=chunks,
+        gpt_model=gpt_model,
+        user_api_key=user_api_key,
+    )
     docs = retriever.search()
     return docs
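With these changes, a client that authenticates with an `api_key` gets its token usage attributed to that key. A minimal call against the streaming route, as a sketch (the host, port, route path, and exact payload fields are assumptions based on the handler above, not part of this diff):

import requests

resp = requests.post(
    "http://localhost:7091/stream",  # assumed host/port for the answer blueprint
    json={
        "question": "What does DocsGPT do?",
        "history": "[]",
        "api_key": "your-api-key",  # usage is logged against this key
    },
    stream=True,
)
for line in resp.iter_lines():
    # The endpoint streams server-sent events; the final event carries
    # the conversation id.
    if line:
        print(line.decode())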

(Diff not shown for the remaining 13 changed files.)
