-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature: Store chat history in Cosmos DB #2063
Changes from 13 commits
bbc74de
6846c4d
07fe2ae
0189dfc
44bfc40
11c4d9f
9863e67
b0aacb2
e0e85f1
ec12bca
867d3d3
50c6ea9
4dc3b48
c2a810f
e48535d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
import os | ||
import time | ||
from typing import Any, Dict, Union | ||
|
||
from azure.cosmos.aio import ContainerProxy, CosmosClient | ||
from azure.identity.aio import AzureDeveloperCliCredential, ManagedIdentityCredential | ||
from quart import Blueprint, current_app, jsonify, request | ||
|
||
from config import ( | ||
CONFIG_CHAT_HISTORY_COSMOS_ENABLED, | ||
CONFIG_COSMOS_HISTORY_CLIENT, | ||
CONFIG_COSMOS_HISTORY_CONTAINER, | ||
CONFIG_CREDENTIAL, | ||
) | ||
from decorators import authenticated | ||
from error import error_response | ||
|
||
chat_history_cosmosdb_bp = Blueprint("chat_history_cosmos", __name__, static_folder="static") | ||
|
||
|
||
@chat_history_cosmosdb_bp.post("/chat_history") | ||
@authenticated | ||
async def post_chat_history(auth_claims: Dict[str, Any]): | ||
if not current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED]: | ||
return jsonify({"error": "Chat history not enabled"}), 400 | ||
|
||
container: ContainerProxy = current_app.config[CONFIG_COSMOS_HISTORY_CONTAINER] | ||
if not container: | ||
return jsonify({"error": "Chat history not enabled"}), 400 | ||
|
||
entra_oid = auth_claims.get("oid") | ||
if not entra_oid: | ||
return jsonify({"error": "User OID not found"}), 401 | ||
|
||
try: | ||
request_json = await request.get_json() | ||
id = request_json.get("id") | ||
answers = request_json.get("answers") | ||
title = answers[0][0][:50] + "..." if len(answers[0][0]) > 50 else answers[0][0] | ||
timestamp = int(time.time() * 1000) | ||
|
||
await container.upsert_item( | ||
{"id": id, "entra_oid": entra_oid, "title": title, "answers": answers, "timestamp": timestamp} | ||
) | ||
|
||
return jsonify({}), 201 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. wonder whether the status code here should be based on the response from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's a nice idea! I can't figure out from the SDK return types what in the response would indicate that however, as all the function signatures just say that they return a dict: https://learn.microsoft.com/en-us/python/api/azure-cosmos/azure.cosmos.container.containerproxy?view=azure-python#azure-cosmos-container-containerproxy-upsert-item There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've taken a look at the dicts for a first-time ID versus a second-time ID, and I can't find any key that'd indicate whether the item already existed: |
||
except Exception as error: | ||
return error_response(error, "/chat_history") | ||
|
||
|
||
@chat_history_cosmosdb_bp.post("/chat_history/items") | ||
@authenticated | ||
async def get_chat_history(auth_claims: Dict[str, Any]): | ||
if not current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED]: | ||
return jsonify({"error": "Chat history not enabled"}), 400 | ||
|
||
container: ContainerProxy = current_app.config[CONFIG_COSMOS_HISTORY_CONTAINER] | ||
if not container: | ||
return jsonify({"error": "Chat history not enabled"}), 400 | ||
|
||
entra_oid = auth_claims.get("oid") | ||
if not entra_oid: | ||
return jsonify({"error": "User OID not found"}), 401 | ||
|
||
try: | ||
request_json = await request.get_json() | ||
count = request_json.get("count", 20) | ||
continuation_token = request_json.get("continuation_token") | ||
|
||
res = container.query_items( | ||
query="SELECT c.id, c.entra_oid, c.title, c.timestamp FROM c WHERE c.entra_oid = @entra_oid ORDER BY c.timestamp DESC", | ||
parameters=[dict(name="@entra_oid", value=entra_oid)], | ||
max_item_count=count, | ||
) | ||
|
||
# set the continuation token for the next page | ||
pager = res.by_page(continuation_token) | ||
|
||
# Get the first page, and the continuation token | ||
try: | ||
page = await pager.__anext__() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider iterating over the by_page result directly and avoiding explicit calls to await pager.anext(), i.e. process each page as soon as it’s available without awaiting when there are no more pages. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you mean to just do: As we use similar code for other Azure Python SDKs elsewhere in this repo. That wouldn't give us the continuation token, right, as that would exhaust all the pages? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, but you should still be able to get the continuation token. Something like (untested): items = []
async for page in container.query_items(
query="SELECT c.id, c.entra_oid, c.title, c.timestamp FROM c WHERE c.entra_oid = @entra_oid ORDER BY c.timestamp DESC",
parameters=[{"name": "@entra_oid", "value": entra_oid}],
max_item_count=count,
).by_page(continuation_token):
async for item in page:
items.append(
{
"id": item.get("id"),
"entra_oid": item.get("entra_oid"),
"title": item.get("title", "untitled"),
"timestamp": item.get("timestamp"),
}
)
# Update continuation token after processing the page
continuation_token = page.continuation_token if hasattr(page, "continuation_token") else None
# Break if no continuation token (i.e., last page)
if not continuation_token:
break Just a suggestion. I think its fine as it is :-) |
||
continuation_token = pager.continuation_token # type: ignore | ||
|
||
items = [] | ||
async for item in page: | ||
items.append( | ||
{ | ||
"id": item.get("id"), | ||
"entra_oid": item.get("entra_oid"), | ||
"title": item.get("title", "untitled"), | ||
"timestamp": item.get("timestamp"), | ||
} | ||
) | ||
|
||
# If there are no more pages, StopAsyncIteration is raised | ||
except StopAsyncIteration: | ||
items = [] | ||
continuation_token = None | ||
|
||
return jsonify({"items": items, "continuation_token": continuation_token}), 200 | ||
|
||
except Exception as error: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider using more specific cosmos exceptions, e.g. https://learn.microsoft.com/en-us/python/api/azure-cosmos/azure.cosmos.exceptions.cosmoshttpresponseerror?view=azure-python. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice suggestion! I've looked into it and I think that we do want the general Exception catching here, to make sure that we always pass JSON down to the frontend if the server errors, so the user sees that there's been an error. But we might add special handling for the CosmosDB exceptions in future if it makes our logs easier to grok. |
||
return error_response(error, "/chat_history/items") | ||
|
||
|
||
@chat_history_cosmosdb_bp.get("/chat_history/items/<item_id>") | ||
@authenticated | ||
async def get_chat_history_session(auth_claims: Dict[str, Any], item_id: str): | ||
if not current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED]: | ||
return jsonify({"error": "Chat history not enabled"}), 400 | ||
|
||
container: ContainerProxy = current_app.config[CONFIG_COSMOS_HISTORY_CONTAINER] | ||
if not container: | ||
return jsonify({"error": "Chat history not enabled"}), 400 | ||
|
||
entra_oid = auth_claims.get("oid") | ||
if not entra_oid: | ||
return jsonify({"error": "User OID not found"}), 401 | ||
|
||
try: | ||
res = await container.read_item(item=item_id, partition_key=entra_oid) | ||
return ( | ||
jsonify( | ||
{ | ||
"id": res.get("id"), | ||
"entra_oid": res.get("entra_oid"), | ||
"title": res.get("title", "untitled"), | ||
"timestamp": res.get("timestamp"), | ||
"answers": res.get("answers", []), | ||
} | ||
), | ||
200, | ||
) | ||
except Exception as error: | ||
return error_response(error, f"/chat_history/items/{item_id}") | ||
|
||
|
||
@chat_history_cosmosdb_bp.delete("/chat_history/items/<item_id>") | ||
@authenticated | ||
async def delete_chat_history_session(auth_claims: Dict[str, Any], item_id: str): | ||
if not current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED]: | ||
return jsonify({"error": "Chat history not enabled"}), 400 | ||
|
||
container: ContainerProxy = current_app.config[CONFIG_COSMOS_HISTORY_CONTAINER] | ||
if not container: | ||
return jsonify({"error": "Chat history not enabled"}), 400 | ||
|
||
entra_oid = auth_claims.get("oid") | ||
if not entra_oid: | ||
return jsonify({"error": "User OID not found"}), 401 | ||
|
||
try: | ||
await container.delete_item(item=item_id, partition_key=entra_oid) | ||
return jsonify({}), 204 | ||
except Exception as error: | ||
return error_response(error, f"/chat_history/items/{item_id}") | ||
|
||
|
||
@chat_history_cosmosdb_bp.before_app_serving | ||
async def setup_clients(): | ||
USE_CHAT_HISTORY_COSMOS = os.getenv("USE_CHAT_HISTORY_COSMOS", "").lower() == "true" | ||
AZURE_COSMOSDB_ACCOUNT = os.getenv("AZURE_COSMOSDB_ACCOUNT") | ||
AZURE_CHAT_HISTORY_DATABASE = os.getenv("AZURE_CHAT_HISTORY_DATABASE") | ||
AZURE_CHAT_HISTORY_CONTAINER = os.getenv("AZURE_CHAT_HISTORY_CONTAINER") | ||
|
||
azure_credential: Union[AzureDeveloperCliCredential, ManagedIdentityCredential] = current_app.config[ | ||
CONFIG_CREDENTIAL | ||
] | ||
|
||
if USE_CHAT_HISTORY_COSMOS: | ||
current_app.logger.info("USE_CHAT_HISTORY_COSMOS is true, setting up CosmosDB client") | ||
if not AZURE_COSMOSDB_ACCOUNT: | ||
raise ValueError("AZURE_COSMOSDB_ACCOUNT must be set when USE_CHAT_HISTORY_COSMOS is true") | ||
if not AZURE_CHAT_HISTORY_DATABASE: | ||
raise ValueError("AZURE_CHAT_HISTORY_DATABASE must be set when USE_CHAT_HISTORY_COSMOS is true") | ||
if not AZURE_CHAT_HISTORY_CONTAINER: | ||
raise ValueError("AZURE_CHAT_HISTORY_CONTAINER must be set when USE_CHAT_HISTORY_COSMOS is true") | ||
cosmos_client = CosmosClient( | ||
url=f"https://{AZURE_COSMOSDB_ACCOUNT}.documents.azure.com:443/", credential=azure_credential | ||
) | ||
cosmos_db = cosmos_client.get_database_client(AZURE_CHAT_HISTORY_DATABASE) | ||
cosmos_container = cosmos_db.get_container_client(AZURE_CHAT_HISTORY_CONTAINER) | ||
|
||
current_app.config[CONFIG_COSMOS_HISTORY_CLIENT] = cosmos_client | ||
current_app.config[CONFIG_COSMOS_HISTORY_CONTAINER] = cosmos_container | ||
|
||
|
||
@chat_history_cosmosdb_bp.after_app_serving | ||
async def close_clients(): | ||
if current_app.config.get(CONFIG_COSMOS_HISTORY_CLIENT): | ||
cosmos_client: CosmosClient = current_app.config[CONFIG_COSMOS_HISTORY_CLIENT] | ||
await cosmos_client.close() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is called in every function and I assume it never changes. Could this be cached or defined globally?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
current_app.config
is basically our global dict, it's how we access objects that were setup at the beginning of the app start. I don't think there's a performance hit, since it should be O(1) retrieval.