Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Store chat history in Cosmos DB #2063

Merged
merged 15 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ However, you can try the [Azure pricing calculator](https://azure.com/e/a87a169b
- Azure AI Document Intelligence: S0 (Standard) tier using pre-built layout. Pricing per document page, sample documents have 261 pages total. [Pricing](https://azure.microsoft.com/pricing/details/form-recognizer/)
- Azure AI Search: Basic tier, 1 replica, free level of semantic search. Pricing per hour. [Pricing](https://azure.microsoft.com/pricing/details/search/)
- Azure Blob Storage: Standard tier with ZRS (Zone-redundant storage). Pricing per storage and read operations. [Pricing](https://azure.microsoft.com/pricing/details/storage/blobs/)
- Azure Cosmos DB: Serverless tier. Pricing per request unit and storage. [Pricing](https://azure.microsoft.com/pricing/details/cosmos-db/)
- Azure Monitor: Pay-as-you-go tier. Costs based on data ingested. [Pricing](https://azure.microsoft.com/pricing/details/monitor/)

To reduce costs, you can switch to free SKUs for various services, but those SKUs have limitations.
Expand Down
20 changes: 17 additions & 3 deletions app/backend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,15 @@
from approaches.chatreadretrievereadvision import ChatReadRetrieveReadVisionApproach
from approaches.retrievethenread import RetrieveThenReadApproach
from approaches.retrievethenreadvision import RetrieveThenReadVisionApproach
from chat_history.cosmosdb import chat_history_cosmosdb_bp
from config import (
CONFIG_ASK_APPROACH,
CONFIG_ASK_VISION_APPROACH,
CONFIG_AUTH_CLIENT,
CONFIG_BLOB_CONTAINER_CLIENT,
CONFIG_CHAT_APPROACH,
CONFIG_CHAT_HISTORY_BROWSER_ENABLED,
CONFIG_CHAT_HISTORY_COSMOS_ENABLED,
CONFIG_CHAT_VISION_APPROACH,
CONFIG_CREDENTIAL,
CONFIG_GPT4V_DEPLOYED,
Expand Down Expand Up @@ -224,7 +226,10 @@ async def chat(auth_claims: Dict[str, Any]):
# else creates a new session_id depending on the chat history options enabled.
session_state = request_json.get("session_state")
if session_state is None:
session_state = create_session_id(current_app.config[CONFIG_CHAT_HISTORY_BROWSER_ENABLED])
session_state = create_session_id(
current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED],
current_app.config[CONFIG_CHAT_HISTORY_BROWSER_ENABLED],
)
result = await approach.run(
request_json["messages"],
context=context,
Expand Down Expand Up @@ -255,7 +260,10 @@ async def chat_stream(auth_claims: Dict[str, Any]):
# else creates a new session_id depending on the chat history options enabled.
session_state = request_json.get("session_state")
if session_state is None:
session_state = create_session_id(current_app.config[CONFIG_CHAT_HISTORY_BROWSER_ENABLED])
session_state = create_session_id(
current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED],
current_app.config[CONFIG_CHAT_HISTORY_BROWSER_ENABLED],
)
result = await approach.run_stream(
request_json["messages"],
context=context,
Expand Down Expand Up @@ -289,6 +297,7 @@ def config():
"showSpeechOutputBrowser": current_app.config[CONFIG_SPEECH_OUTPUT_BROWSER_ENABLED],
"showSpeechOutputAzure": current_app.config[CONFIG_SPEECH_OUTPUT_AZURE_ENABLED],
"showChatHistoryBrowser": current_app.config[CONFIG_CHAT_HISTORY_BROWSER_ENABLED],
"showChatHistoryCosmos": current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED],
}
)

Expand Down Expand Up @@ -455,6 +464,7 @@ async def setup_clients():
USE_SPEECH_OUTPUT_BROWSER = os.getenv("USE_SPEECH_OUTPUT_BROWSER", "").lower() == "true"
USE_SPEECH_OUTPUT_AZURE = os.getenv("USE_SPEECH_OUTPUT_AZURE", "").lower() == "true"
USE_CHAT_HISTORY_BROWSER = os.getenv("USE_CHAT_HISTORY_BROWSER", "").lower() == "true"
USE_CHAT_HISTORY_COSMOS = os.getenv("USE_CHAT_HISTORY_COSMOS", "").lower() == "true"

# WEBSITE_HOSTNAME is always set by App Service, RUNNING_IN_PRODUCTION is set in main.bicep
RUNNING_ON_AZURE = os.getenv("WEBSITE_HOSTNAME") is not None or os.getenv("RUNNING_IN_PRODUCTION") is not None
Expand Down Expand Up @@ -484,6 +494,9 @@ async def setup_clients():
current_app.logger.info("Setting up Azure credential using AzureDeveloperCliCredential for home tenant")
azure_credential = AzureDeveloperCliCredential(process_timeout=60)

# Set the Azure credential in the app config for use in other parts of the app
current_app.config[CONFIG_CREDENTIAL] = azure_credential

# Set up clients for AI Search and Storage
search_client = SearchClient(
endpoint=f"https://{AZURE_SEARCH_SERVICE}.search.windows.net",
Expand Down Expand Up @@ -573,7 +586,6 @@ async def setup_clients():
current_app.config[CONFIG_SPEECH_SERVICE_VOICE] = AZURE_SPEECH_SERVICE_VOICE
# Wait until token is needed to fetch for the first time
current_app.config[CONFIG_SPEECH_SERVICE_TOKEN] = None
current_app.config[CONFIG_CREDENTIAL] = azure_credential

if OPENAI_HOST.startswith("azure"):
if OPENAI_HOST == "azure_custom":
Expand Down Expand Up @@ -628,6 +640,7 @@ async def setup_clients():
current_app.config[CONFIG_SPEECH_OUTPUT_BROWSER_ENABLED] = USE_SPEECH_OUTPUT_BROWSER
current_app.config[CONFIG_SPEECH_OUTPUT_AZURE_ENABLED] = USE_SPEECH_OUTPUT_AZURE
current_app.config[CONFIG_CHAT_HISTORY_BROWSER_ENABLED] = USE_CHAT_HISTORY_BROWSER
current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED] = USE_CHAT_HISTORY_COSMOS

# Various approaches to integrate GPT and external knowledge, most applications will use a single one of these patterns
# or some derivative, here we include several for exploration purposes
Expand Down Expand Up @@ -717,6 +730,7 @@ async def close_clients():
def create_app():
app = Quart(__name__)
app.register_blueprint(bp)
app.register_blueprint(chat_history_cosmosdb_bp)

if os.getenv("APPLICATIONINSIGHTS_CONNECTION_STRING"):
app.logger.info("APPLICATIONINSIGHTS_CONNECTION_STRING is set, enabling Azure Monitor")
Expand Down
Empty file.
192 changes: 192 additions & 0 deletions app/backend/chat_history/cosmosdb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
import os
import time
from typing import Any, Dict, Union

from azure.cosmos.aio import ContainerProxy, CosmosClient
from azure.identity.aio import AzureDeveloperCliCredential, ManagedIdentityCredential
from quart import Blueprint, current_app, jsonify, request

from config import (
CONFIG_CHAT_HISTORY_COSMOS_ENABLED,
CONFIG_COSMOS_HISTORY_CLIENT,
CONFIG_COSMOS_HISTORY_CONTAINER,
CONFIG_CREDENTIAL,
)
from decorators import authenticated
from error import error_response

chat_history_cosmosdb_bp = Blueprint("chat_history_cosmos", __name__, static_folder="static")


@chat_history_cosmosdb_bp.post("/chat_history")
@authenticated
async def post_chat_history(auth_claims: Dict[str, Any]):
if not current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED]:
return jsonify({"error": "Chat history not enabled"}), 400

container: ContainerProxy = current_app.config[CONFIG_COSMOS_HISTORY_CONTAINER]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is called in every function and I assume it never changes. Could this be cached or defined globally?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

current_app.config is basically our global dict, it's how we access objects that were setup at the beginning of the app start. I don't think there's a performance hit, since it should be O(1) retrieval.

if not container:
return jsonify({"error": "Chat history not enabled"}), 400

entra_oid = auth_claims.get("oid")
if not entra_oid:
return jsonify({"error": "User OID not found"}), 401

try:
request_json = await request.get_json()
id = request_json.get("id")
answers = request_json.get("answers")
title = answers[0][0][:50] + "..." if len(answers[0][0]) > 50 else answers[0][0]
timestamp = int(time.time() * 1000)

await container.upsert_item(
{"id": id, "entra_oid": entra_oid, "title": title, "answers": answers, "timestamp": timestamp}
)

return jsonify({}), 201

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder whether the status code here should be based on the response from container.upsert_item - if the item already exists, then it will be 200 instead of 201

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a nice idea! I can't figure out from the SDK return types what in the response would indicate that however, as all the function signatures just say that they return a dict: https://learn.microsoft.com/en-us/python/api/azure-cosmos/azure.cosmos.container.containerproxy?view=azure-python#azure-cosmos-container-containerproxy-upsert-item
And this example doesn't do anything with the dict: https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/cosmos/azure-cosmos/samples/document_management.py#L149-L156
So is there some listing of what key in the dict would indicate it already existing?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've taken a look at the dicts for a first-time ID versus a second-time ID, and I can't find any key that'd indicate whether the item already existed:
https://www.diffchecker.com/Fs0ECaUa/
So I don't think we should change the status code currently, as I'm not sure what condition to use, but happy to change in the future if you have suggestions.

except Exception as error:
return error_response(error, "/chat_history")


@chat_history_cosmosdb_bp.post("/chat_history/items")
@authenticated
async def get_chat_history(auth_claims: Dict[str, Any]):
if not current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED]:
return jsonify({"error": "Chat history not enabled"}), 400

container: ContainerProxy = current_app.config[CONFIG_COSMOS_HISTORY_CONTAINER]
if not container:
return jsonify({"error": "Chat history not enabled"}), 400

entra_oid = auth_claims.get("oid")
if not entra_oid:
return jsonify({"error": "User OID not found"}), 401

try:
request_json = await request.get_json()
count = request_json.get("count", 20)
continuation_token = request_json.get("continuation_token")

res = container.query_items(
query="SELECT c.id, c.entra_oid, c.title, c.timestamp FROM c WHERE c.entra_oid = @entra_oid ORDER BY c.timestamp DESC",
parameters=[dict(name="@entra_oid", value=entra_oid)],
max_item_count=count,
)

# set the continuation token for the next page
pager = res.by_page(continuation_token)

# Get the first page, and the continuation token
try:
page = await pager.__anext__()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider iterating over the by_page result directly and avoiding explicit calls to `await pager.__anext__()`, i.e. process each page as soon as it's available without awaiting when there are no more pages.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you mean to just do:
async for page in pager:
    async for item in page:

As we use similar code for other Azure Python SDKs elsewhere in this repo. That wouldn't give us the continuation token, right, as that would exhaust all the pages?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but you should still be able to get the continuation token. Something like (untested):

        items = []
        async for page in container.query_items(
            query="SELECT c.id, c.entra_oid, c.title, c.timestamp FROM c WHERE c.entra_oid = @entra_oid ORDER BY c.timestamp DESC",
            parameters=[{"name": "@entra_oid", "value": entra_oid}],
            max_item_count=count,
        ).by_page(continuation_token):
            async for item in page:
                items.append(
                    {
                        "id": item.get("id"),
                        "entra_oid": item.get("entra_oid"),
                        "title": item.get("title", "untitled"),
                        "timestamp": item.get("timestamp"),
                    }
                )
            # Update continuation token after processing the page
            continuation_token = page.continuation_token if hasattr(page, "continuation_token") else None

            # Break if no continuation token (i.e., last page)
            if not continuation_token:
                break

Just a suggestion. I think it's fine as it is :-)

continuation_token = pager.continuation_token # type: ignore

items = []
async for item in page:
items.append(
{
"id": item.get("id"),
"entra_oid": item.get("entra_oid"),
"title": item.get("title", "untitled"),
"timestamp": item.get("timestamp"),
}
)

# If there are no more pages, StopAsyncIteration is raised
except StopAsyncIteration:
items = []
continuation_token = None

return jsonify({"items": items, "continuation_token": continuation_token}), 200

except Exception as error:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice suggestion! I've looked into it and I think that we do want the general Exception catching here, to make sure that we always pass JSON down to the frontend if the server errors, so the user sees that there's been an error. But we might add special handling for the CosmosDB exceptions in future if it makes our logs easier to grok.

return error_response(error, "/chat_history/items")


@chat_history_cosmosdb_bp.get("/chat_history/items/<item_id>")
@authenticated
async def get_chat_history_session(auth_claims: Dict[str, Any], item_id: str):
    """Return a single stored chat session for the calling user.

    The item is read from the chat-history container using ``item_id`` as the
    item id and the caller's ``oid`` claim as the partition key.

    Responses:
        200 with ``id``, ``entra_oid``, ``title``, ``timestamp`` and ``answers``.
        400 when the Cosmos chat history flag is off or no container is configured.
        401 when ``auth_claims`` carries no ``oid``.
        Any exception raised while reading is handed to ``error_response``.
    """
    app_config = current_app.config

    if not app_config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED]:
        return jsonify({"error": "Chat history not enabled"}), 400

    history_container: ContainerProxy = app_config[CONFIG_COSMOS_HISTORY_CONTAINER]
    if not history_container:
        return jsonify({"error": "Chat history not enabled"}), 400

    user_oid = auth_claims.get("oid")
    if not user_oid:
        return jsonify({"error": "User OID not found"}), 401

    try:
        stored = await history_container.read_item(item=item_id, partition_key=user_oid)
        payload = {
            "id": stored.get("id"),
            "entra_oid": stored.get("entra_oid"),
            "title": stored.get("title", "untitled"),
            "timestamp": stored.get("timestamp"),
            "answers": stored.get("answers", []),
        }
        return jsonify(payload), 200
    except Exception as error:
        # Catch-all so the caller always gets a response built by error_response.
        return error_response(error, f"/chat_history/items/{item_id}")


@chat_history_cosmosdb_bp.delete("/chat_history/items/<item_id>")
@authenticated
async def delete_chat_history_session(auth_claims: Dict[str, Any], item_id: str):
    """Delete a single stored chat session belonging to the calling user.

    The item is addressed by ``item_id`` with the caller's ``oid`` claim as
    the partition key.

    Responses:
        204 with an empty JSON object once the delete call returns.
        400 when the Cosmos chat history flag is off or no container is configured.
        401 when ``auth_claims`` carries no ``oid``.
        Any exception raised while deleting is handed to ``error_response``.
    """
    app_config = current_app.config

    if not app_config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED]:
        return jsonify({"error": "Chat history not enabled"}), 400

    history_container: ContainerProxy = app_config[CONFIG_COSMOS_HISTORY_CONTAINER]
    if not history_container:
        return jsonify({"error": "Chat history not enabled"}), 400

    user_oid = auth_claims.get("oid")
    if not user_oid:
        return jsonify({"error": "User OID not found"}), 401

    try:
        await history_container.delete_item(item=item_id, partition_key=user_oid)
    except Exception as error:
        # Catch-all so the caller always gets a response built by error_response.
        return error_response(error, f"/chat_history/items/{item_id}")
    return jsonify({}), 204


@chat_history_cosmosdb_bp.before_app_serving
async def setup_clients():
    """Create the Cosmos DB client and container used for chat history.

    Registered on the blueprint's ``before_app_serving`` hook. When the
    ``USE_CHAT_HISTORY_COSMOS`` environment variable equals "true"
    (case-insensitive), a ``CosmosClient`` is built for the configured account
    with the credential stored under ``CONFIG_CREDENTIAL``, and both the client
    and the container proxy are saved in ``current_app.config``.

    Raises:
        ValueError: if the feature is enabled but ``AZURE_COSMOSDB_ACCOUNT``,
            ``AZURE_CHAT_HISTORY_DATABASE`` or ``AZURE_CHAT_HISTORY_CONTAINER``
            is unset or empty.
    """
    cosmos_enabled = os.getenv("USE_CHAT_HISTORY_COSMOS", "").lower() == "true"
    account_name = os.getenv("AZURE_COSMOSDB_ACCOUNT")
    database_name = os.getenv("AZURE_CHAT_HISTORY_DATABASE")
    container_name = os.getenv("AZURE_CHAT_HISTORY_CONTAINER")

    # Looked up before the feature check, so the credential entry must exist
    # in the app config even when Cosmos chat history is turned off.
    credential: Union[AzureDeveloperCliCredential, ManagedIdentityCredential] = current_app.config[
        CONFIG_CREDENTIAL
    ]

    if not cosmos_enabled:
        return

    current_app.logger.info("USE_CHAT_HISTORY_COSMOS is true, setting up CosmosDB client")

    # Checked one by one so each missing setting gets its own message.
    if not account_name:
        raise ValueError("AZURE_COSMOSDB_ACCOUNT must be set when USE_CHAT_HISTORY_COSMOS is true")
    if not database_name:
        raise ValueError("AZURE_CHAT_HISTORY_DATABASE must be set when USE_CHAT_HISTORY_COSMOS is true")
    if not container_name:
        raise ValueError("AZURE_CHAT_HISTORY_CONTAINER must be set when USE_CHAT_HISTORY_COSMOS is true")

    client = CosmosClient(url=f"https://{account_name}.documents.azure.com:443/", credential=credential)
    container = client.get_database_client(database_name).get_container_client(container_name)

    current_app.config[CONFIG_COSMOS_HISTORY_CLIENT] = client
    current_app.config[CONFIG_COSMOS_HISTORY_CONTAINER] = container


@chat_history_cosmosdb_bp.after_app_serving
async def close_clients():
    """Close the Cosmos DB client stored in the app config, if there is one.

    Registered on the blueprint's ``after_app_serving`` hook. Does nothing
    when no client was stored under ``CONFIG_COSMOS_HISTORY_CLIENT``.
    """
    cosmos_client: CosmosClient = current_app.config.get(CONFIG_COSMOS_HISTORY_CLIENT)
    if cosmos_client:
        await cosmos_client.close()
3 changes: 3 additions & 0 deletions app/backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,6 @@
CONFIG_SPEECH_SERVICE_TOKEN = "speech_service_token"
CONFIG_SPEECH_SERVICE_VOICE = "speech_service_voice"
CONFIG_CHAT_HISTORY_BROWSER_ENABLED = "chat_history_browser_enabled"
CONFIG_CHAT_HISTORY_COSMOS_ENABLED = "chat_history_cosmos_enabled"
CONFIG_COSMOS_HISTORY_CLIENT = "cosmos_history_client"
CONFIG_COSMOS_HISTORY_CONTAINER = "cosmos_history_container"
6 changes: 5 additions & 1 deletion app/backend/core/sessionhelper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
from typing import Union


def create_session_id(config_chat_history_browser_enabled: bool) -> Union[str, None]:
def create_session_id(
    config_chat_history_cosmos_enabled: bool, config_chat_history_browser_enabled: bool
) -> Union[str, None]:
    """Return a new random session id when any chat history option is enabled.

    Args:
        config_chat_history_cosmos_enabled: whether Cosmos DB chat history is on.
        config_chat_history_browser_enabled: whether in-browser chat history is on.

    Returns:
        A UUID4 string if either flag is truthy, otherwise ``None``.
    """
    history_enabled = config_chat_history_cosmos_enabled or config_chat_history_browser_enabled
    return str(uuid.uuid4()) if history_enabled else None
13 changes: 8 additions & 5 deletions app/backend/decorators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from functools import wraps
from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, TypeVar, cast

from quart import abort, current_app, request

Expand Down Expand Up @@ -37,19 +37,22 @@ async def auth_handler(path=""):
return auth_handler


def authenticated(route_fn: Callable[[Dict[str, Any]], Any]):
_C = TypeVar("_C", bound=Callable[..., Any])


def authenticated(route_fn: _C) -> _C:
"""
Decorator for routes that might require access control. Unpacks Authorization header information into an auth_claims dictionary
"""

@wraps(route_fn)
async def auth_handler():
async def auth_handler(*args, **kwargs):
auth_helper = current_app.config[CONFIG_AUTH_CLIENT]
try:
auth_claims = await auth_helper.get_auth_claims_if_enabled(request.headers)
except AuthError:
abort(403)

return await route_fn(auth_claims)
return await route_fn(auth_claims, *args, **kwargs)

return auth_handler
return cast(_C, auth_handler)
1 change: 1 addition & 0 deletions app/backend/requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ tiktoken
tenacity
azure-ai-documentintelligence
azure-cognitiveservices-speech
azure-cosmos
azure-search-documents==11.6.0b6
azure-storage-blob
azure-storage-file-datalake
Expand Down
4 changes: 4 additions & 0 deletions app/backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ azure-core==1.30.2
# via
# azure-ai-documentintelligence
# azure-core-tracing-opentelemetry
# azure-cosmos
# azure-identity
# azure-monitor-opentelemetry
# azure-monitor-opentelemetry-exporter
Expand All @@ -44,6 +45,8 @@ azure-core==1.30.2
# msrest
azure-core-tracing-opentelemetry==1.0.0b11
# via azure-monitor-opentelemetry
azure-cosmos==4.7.0
# via -r requirements.in
azure-identity==1.17.1
# via
# -r requirements.in
Expand Down Expand Up @@ -402,6 +405,7 @@ typing-extensions==4.12.2
# via
# azure-ai-documentintelligence
# azure-core
# azure-cosmos
# azure-identity
# azure-search-documents
# azure-storage-blob
Expand Down
Loading
Loading