diff --git a/README.md b/README.md index 847653232c..2d4f80454e 100644 --- a/README.md +++ b/README.md @@ -92,6 +92,7 @@ However, you can try the [Azure pricing calculator](https://azure.com/e/a87a169b - Azure AI Document Intelligence: SO (Standard) tier using pre-built layout. Pricing per document page, sample documents have 261 pages total. [Pricing](https://azure.microsoft.com/pricing/details/form-recognizer/) - Azure AI Search: Basic tier, 1 replica, free level of semantic search. Pricing per hour. [Pricing](https://azure.microsoft.com/pricing/details/search/) - Azure Blob Storage: Standard tier with ZRS (Zone-redundant storage). Pricing per storage and read operations. [Pricing](https://azure.microsoft.com/pricing/details/storage/blobs/) +- Azure Cosmos DB: Serverless tier. Pricing per request unit and storage. [Pricing](https://azure.microsoft.com/pricing/details/cosmos-db/) - Azure Monitor: Pay-as-you-go tier. Costs based on data ingested. [Pricing](https://azure.microsoft.com/pricing/details/monitor/) To reduce costs, you can switch to free SKUs for various services, but those SKUs have limitations. diff --git a/app/backend/app.py b/app/backend/app.py index 4098a12e1d..b83efefe62 100644 --- a/app/backend/app.py +++ b/app/backend/app.py @@ -53,6 +53,7 @@ from approaches.chatreadretrievereadvision import ChatReadRetrieveReadVisionApproach from approaches.retrievethenread import RetrieveThenReadApproach from approaches.retrievethenreadvision import RetrieveThenReadVisionApproach +from chat_history.cosmosdb import chat_history_cosmosdb_bp from config import ( CONFIG_ASK_APPROACH, CONFIG_ASK_VISION_APPROACH, @@ -60,6 +61,7 @@ CONFIG_BLOB_CONTAINER_CLIENT, CONFIG_CHAT_APPROACH, CONFIG_CHAT_HISTORY_BROWSER_ENABLED, + CONFIG_CHAT_HISTORY_COSMOS_ENABLED, CONFIG_CHAT_VISION_APPROACH, CONFIG_CREDENTIAL, CONFIG_GPT4V_DEPLOYED, @@ -224,7 +226,10 @@ async def chat(auth_claims: Dict[str, Any]): # else creates a new session_id depending on the chat history options enabled. session_state = request_json.get("session_state") if session_state is None: - session_state = create_session_id(current_app.config[CONFIG_CHAT_HISTORY_BROWSER_ENABLED]) + session_state = create_session_id( + current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED], + current_app.config[CONFIG_CHAT_HISTORY_BROWSER_ENABLED], + ) result = await approach.run( request_json["messages"], context=context, @@ -255,7 +260,10 @@ async def chat_stream(auth_claims: Dict[str, Any]): # else creates a new session_id depending on the chat history options enabled. session_state = request_json.get("session_state") if session_state is None: - session_state = create_session_id(current_app.config[CONFIG_CHAT_HISTORY_BROWSER_ENABLED]) + session_state = create_session_id( + current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED], + current_app.config[CONFIG_CHAT_HISTORY_BROWSER_ENABLED], + ) result = await approach.run_stream( request_json["messages"], context=context, @@ -289,6 +297,7 @@ def config(): "showSpeechOutputBrowser": current_app.config[CONFIG_SPEECH_OUTPUT_BROWSER_ENABLED], "showSpeechOutputAzure": current_app.config[CONFIG_SPEECH_OUTPUT_AZURE_ENABLED], "showChatHistoryBrowser": current_app.config[CONFIG_CHAT_HISTORY_BROWSER_ENABLED], + "showChatHistoryCosmos": current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED], } ) @@ -455,6 +464,7 @@ async def setup_clients(): USE_SPEECH_OUTPUT_BROWSER = os.getenv("USE_SPEECH_OUTPUT_BROWSER", "").lower() == "true" USE_SPEECH_OUTPUT_AZURE = os.getenv("USE_SPEECH_OUTPUT_AZURE", "").lower() == "true" USE_CHAT_HISTORY_BROWSER = os.getenv("USE_CHAT_HISTORY_BROWSER", "").lower() == "true" + USE_CHAT_HISTORY_COSMOS = os.getenv("USE_CHAT_HISTORY_COSMOS", "").lower() == "true" # WEBSITE_HOSTNAME is always set by App Service, RUNNING_IN_PRODUCTION is set in main.bicep RUNNING_ON_AZURE = os.getenv("WEBSITE_HOSTNAME") is not None or os.getenv("RUNNING_IN_PRODUCTION") is not None @@ -484,6 +494,9 @@ async def setup_clients(): current_app.logger.info("Setting up Azure credential using AzureDeveloperCliCredential for home tenant") azure_credential = AzureDeveloperCliCredential(process_timeout=60) + # Set the Azure credential in the app config for use in other parts of the app + current_app.config[CONFIG_CREDENTIAL] = azure_credential + # Set up clients for AI Search and Storage search_client = SearchClient( endpoint=f"https://{AZURE_SEARCH_SERVICE}.search.windows.net", @@ -573,7 +586,6 @@ async def setup_clients(): current_app.config[CONFIG_SPEECH_SERVICE_VOICE] = AZURE_SPEECH_SERVICE_VOICE # Wait until token is needed to fetch for the first time current_app.config[CONFIG_SPEECH_SERVICE_TOKEN] = None - current_app.config[CONFIG_CREDENTIAL] = azure_credential if OPENAI_HOST.startswith("azure"): if OPENAI_HOST == "azure_custom": @@ -628,6 +640,7 @@ async def setup_clients(): current_app.config[CONFIG_SPEECH_OUTPUT_BROWSER_ENABLED] = USE_SPEECH_OUTPUT_BROWSER current_app.config[CONFIG_SPEECH_OUTPUT_AZURE_ENABLED] = USE_SPEECH_OUTPUT_AZURE current_app.config[CONFIG_CHAT_HISTORY_BROWSER_ENABLED] = USE_CHAT_HISTORY_BROWSER + current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED] = USE_CHAT_HISTORY_COSMOS # Various approaches to integrate GPT and external knowledge, most applications will use a single one of these patterns # or some derivative, here we include several for exploration purposes @@ -717,6 +730,7 @@ async def close_clients(): def create_app(): app = Quart(__name__) app.register_blueprint(bp) + app.register_blueprint(chat_history_cosmosdb_bp) if os.getenv("APPLICATIONINSIGHTS_CONNECTION_STRING"): app.logger.info("APPLICATIONINSIGHTS_CONNECTION_STRING is set, enabling Azure Monitor") diff --git a/app/backend/chat_history/__init__.py b/app/backend/chat_history/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/app/backend/chat_history/cosmosdb.py b/app/backend/chat_history/cosmosdb.py new file mode 100644 index 0000000000..49760970f7 --- /dev/null +++ b/app/backend/chat_history/cosmosdb.py @@ -0,0 +1,192 @@ +import os +import time +from typing import Any, Dict, Union + +from azure.cosmos.aio import ContainerProxy, CosmosClient +from azure.identity.aio import AzureDeveloperCliCredential, ManagedIdentityCredential +from quart import Blueprint, current_app, jsonify, request + +from config import ( + CONFIG_CHAT_HISTORY_COSMOS_ENABLED, + CONFIG_COSMOS_HISTORY_CLIENT, + CONFIG_COSMOS_HISTORY_CONTAINER, + CONFIG_CREDENTIAL, +) +from decorators import authenticated +from error import error_response + +chat_history_cosmosdb_bp = Blueprint("chat_history_cosmos", __name__, static_folder="static") + + +@chat_history_cosmosdb_bp.post("/chat_history") +@authenticated +async def post_chat_history(auth_claims: Dict[str, Any]): + if not current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED]: + return jsonify({"error": "Chat history not enabled"}), 400 + + container: ContainerProxy = current_app.config[CONFIG_COSMOS_HISTORY_CONTAINER] + if not container: + return jsonify({"error": "Chat history not enabled"}), 400 + + entra_oid = auth_claims.get("oid") + if not entra_oid: + return jsonify({"error": "User OID not found"}), 401 + + try: + request_json = await request.get_json() + id = request_json.get("id") + answers = request_json.get("answers") + title = answers[0][0][:50] + "..." if len(answers[0][0]) > 50 else answers[0][0] + timestamp = int(time.time() * 1000) + + await container.upsert_item( + {"id": id, "entra_oid": entra_oid, "title": title, "answers": answers, "timestamp": timestamp} + ) + + return jsonify({}), 201 + except Exception as error: + return error_response(error, "/chat_history") + + +@chat_history_cosmosdb_bp.post("/chat_history/items") +@authenticated +async def get_chat_history(auth_claims: Dict[str, Any]): + if not current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED]: + return jsonify({"error": "Chat history not enabled"}), 400 + + container: ContainerProxy = current_app.config[CONFIG_COSMOS_HISTORY_CONTAINER] + if not container: + return jsonify({"error": "Chat history not enabled"}), 400 + + entra_oid = auth_claims.get("oid") + if not entra_oid: + return jsonify({"error": "User OID not found"}), 401 + + try: + request_json = await request.get_json() + count = request_json.get("count", 20) + continuation_token = request_json.get("continuation_token") + + res = container.query_items( + query="SELECT c.id, c.entra_oid, c.title, c.timestamp FROM c WHERE c.entra_oid = @entra_oid ORDER BY c.timestamp DESC", + parameters=[dict(name="@entra_oid", value=entra_oid)], + max_item_count=count, + ) + + # set the continuation token for the next page + pager = res.by_page(continuation_token) + + # Get the first page, and the continuation token + try: + page = await pager.__anext__() + continuation_token = pager.continuation_token # type: ignore + + items = [] + async for item in page: + items.append( + { + "id": item.get("id"), + "entra_oid": item.get("entra_oid"), + "title": item.get("title", "untitled"), + "timestamp": item.get("timestamp"), + } + ) + + # If there are no more pages, StopAsyncIteration is raised + except StopAsyncIteration: + items = [] + continuation_token = None + + return jsonify({"items": items, "continuation_token": continuation_token}), 200 + + except Exception as error: + return error_response(error, "/chat_history/items") + + +@chat_history_cosmosdb_bp.get("/chat_history/items/") +@authenticated +async def get_chat_history_session(auth_claims: Dict[str, Any], item_id: str): + if not current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED]: + return jsonify({"error": "Chat history not enabled"}), 400 + + container: ContainerProxy = current_app.config[CONFIG_COSMOS_HISTORY_CONTAINER] + if not container: + return jsonify({"error": "Chat history not enabled"}), 400 + + entra_oid = auth_claims.get("oid") + if not entra_oid: + return jsonify({"error": "User OID not found"}), 401 + + try: + res = await container.read_item(item=item_id, partition_key=entra_oid) + return ( + jsonify( + { + "id": res.get("id"), + "entra_oid": res.get("entra_oid"), + "title": res.get("title", "untitled"), + "timestamp": res.get("timestamp"), + "answers": res.get("answers", []), + } + ), + 200, + ) + except Exception as error: + return error_response(error, f"/chat_history/items/{item_id}") + + +@chat_history_cosmosdb_bp.delete("/chat_history/items/") +@authenticated +async def delete_chat_history_session(auth_claims: Dict[str, Any], item_id: str): + if not current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED]: + return jsonify({"error": "Chat history not enabled"}), 400 + + container: ContainerProxy = current_app.config[CONFIG_COSMOS_HISTORY_CONTAINER] + if not container: + return jsonify({"error": "Chat history not enabled"}), 400 + + entra_oid = auth_claims.get("oid") + if not entra_oid: + return jsonify({"error": "User OID not found"}), 401 + + try: + await container.delete_item(item=item_id, partition_key=entra_oid) + return jsonify({}), 204 + except Exception as error: + return error_response(error, f"/chat_history/items/{item_id}") + + +@chat_history_cosmosdb_bp.before_app_serving +async def setup_clients(): + USE_CHAT_HISTORY_COSMOS = os.getenv("USE_CHAT_HISTORY_COSMOS", "").lower() == "true" + AZURE_COSMOSDB_ACCOUNT = os.getenv("AZURE_COSMOSDB_ACCOUNT") + AZURE_CHAT_HISTORY_DATABASE = os.getenv("AZURE_CHAT_HISTORY_DATABASE") + AZURE_CHAT_HISTORY_CONTAINER = os.getenv("AZURE_CHAT_HISTORY_CONTAINER") + + azure_credential: Union[AzureDeveloperCliCredential, ManagedIdentityCredential] = current_app.config[ + CONFIG_CREDENTIAL + ] + + if USE_CHAT_HISTORY_COSMOS: + current_app.logger.info("USE_CHAT_HISTORY_COSMOS is true, setting up CosmosDB client") + if not AZURE_COSMOSDB_ACCOUNT: + raise ValueError("AZURE_COSMOSDB_ACCOUNT must be set when USE_CHAT_HISTORY_COSMOS is true") + if not AZURE_CHAT_HISTORY_DATABASE: + raise ValueError("AZURE_CHAT_HISTORY_DATABASE must be set when USE_CHAT_HISTORY_COSMOS is true") + if not AZURE_CHAT_HISTORY_CONTAINER: + raise ValueError("AZURE_CHAT_HISTORY_CONTAINER must be set when USE_CHAT_HISTORY_COSMOS is true") + cosmos_client = CosmosClient( + url=f"https://{AZURE_COSMOSDB_ACCOUNT}.documents.azure.com:443/", credential=azure_credential + ) + cosmos_db = cosmos_client.get_database_client(AZURE_CHAT_HISTORY_DATABASE) + cosmos_container = cosmos_db.get_container_client(AZURE_CHAT_HISTORY_CONTAINER) + + current_app.config[CONFIG_COSMOS_HISTORY_CLIENT] = cosmos_client + current_app.config[CONFIG_COSMOS_HISTORY_CONTAINER] = cosmos_container + + +@chat_history_cosmosdb_bp.after_app_serving +async def close_clients(): + if current_app.config.get(CONFIG_COSMOS_HISTORY_CLIENT): + cosmos_client: CosmosClient = current_app.config[CONFIG_COSMOS_HISTORY_CLIENT] + await cosmos_client.close() diff --git a/app/backend/config.py b/app/backend/config.py index e5f274fa38..eaba154116 100644 --- a/app/backend/config.py +++ b/app/backend/config.py @@ -23,3 +23,6 @@ CONFIG_SPEECH_SERVICE_TOKEN = "speech_service_token" CONFIG_SPEECH_SERVICE_VOICE = "speech_service_voice" CONFIG_CHAT_HISTORY_BROWSER_ENABLED = "chat_history_browser_enabled" +CONFIG_CHAT_HISTORY_COSMOS_ENABLED = "chat_history_cosmos_enabled" +CONFIG_COSMOS_HISTORY_CLIENT = "cosmos_history_client" +CONFIG_COSMOS_HISTORY_CONTAINER = "cosmos_history_container" diff --git a/app/backend/core/sessionhelper.py b/app/backend/core/sessionhelper.py index 28dd0e811b..ddda8e03b7 100644 --- a/app/backend/core/sessionhelper.py +++ b/app/backend/core/sessionhelper.py @@ -2,7 +2,11 @@ from typing import Union -def create_session_id(config_chat_history_browser_enabled: bool) -> Union[str, None]: +def create_session_id( + config_chat_history_cosmos_enabled: bool, config_chat_history_browser_enabled: bool +) -> Union[str, None]: + if config_chat_history_cosmos_enabled: + return str(uuid.uuid4()) if config_chat_history_browser_enabled: return str(uuid.uuid4()) return None diff --git a/app/backend/decorators.py b/app/backend/decorators.py index 32f6b9a2b5..f4becc4b70 100644 --- a/app/backend/decorators.py +++ b/app/backend/decorators.py @@ -1,6 +1,6 @@ import logging from functools import wraps -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, TypeVar, cast from quart import abort, current_app, request @@ -37,19 +37,22 @@ async def auth_handler(path=""): return auth_handler -def authenticated(route_fn: Callable[[Dict[str, Any]], Any]): +_C = TypeVar("_C", bound=Callable[..., Any]) + + +def authenticated(route_fn: _C) -> _C: """ Decorator for routes that might require access control. Unpacks Authorization header information into an auth_claims dictionary """ @wraps(route_fn) - async def auth_handler(): + async def auth_handler(*args, **kwargs): auth_helper = current_app.config[CONFIG_AUTH_CLIENT] try: auth_claims = await auth_helper.get_auth_claims_if_enabled(request.headers) except AuthError: abort(403) - return await route_fn(auth_claims) + return await route_fn(auth_claims, *args, **kwargs) - return auth_handler + return cast(_C, auth_handler) diff --git a/app/backend/requirements.in b/app/backend/requirements.in index 99cb44e678..765a72c486 100644 --- a/app/backend/requirements.in +++ b/app/backend/requirements.in @@ -7,6 +7,7 @@ tiktoken tenacity azure-ai-documentintelligence azure-cognitiveservices-speech +azure-cosmos azure-search-documents==11.6.0b6 azure-storage-blob azure-storage-file-datalake diff --git a/app/backend/requirements.txt b/app/backend/requirements.txt index fe339f08c1..0dcf9f86a4 100644 --- a/app/backend/requirements.txt +++ b/app/backend/requirements.txt @@ -34,6 +34,7 @@ azure-core==1.30.2 # via # azure-ai-documentintelligence # azure-core-tracing-opentelemetry + # azure-cosmos # azure-identity # azure-monitor-opentelemetry # azure-monitor-opentelemetry-exporter @@ -44,6 +45,8 @@ azure-core==1.30.2 # msrest azure-core-tracing-opentelemetry==1.0.0b11 # via azure-monitor-opentelemetry +azure-cosmos==4.7.0 + # via -r requirements.in azure-identity==1.17.1 # via # -r requirements.in @@ -402,6 +405,7 @@ typing-extensions==4.12.2 # via # azure-ai-documentintelligence # azure-core + # azure-cosmos # azure-identity # azure-search-documents # azure-storage-blob diff --git a/app/frontend/src/api/api.ts b/app/frontend/src/api/api.ts index f2cd36507e..76636d4d05 100644 --- a/app/frontend/src/api/api.ts +++ b/app/frontend/src/api/api.ts @@ -1,6 +1,6 @@ const BACKEND_URI = ""; -import { ChatAppResponse, ChatAppResponseOrError, ChatAppRequest, Config, SimpleAPIResponse } from "./models"; +import { ChatAppResponse, ChatAppResponseOrError, ChatAppRequest, Config, SimpleAPIResponse, HistoryListApiResponse, HistroyApiResponse } from "./models"; import { useLogin, getToken, isUsingAppServicesLogin } from "../authConfig"; export async function getHeaders(idToken: string | undefined): Promise> { @@ -126,3 +126,65 @@ export async function listUploadedFilesApi(idToken: string): Promise { const dataResponse: string[] = await response.json(); return dataResponse; } + +export async function postChatHistoryApi(item: any, idToken: string): Promise { + const headers = await getHeaders(idToken); + const response = await fetch("/chat_history", { + method: "POST", + headers: { ...headers, "Content-Type": "application/json" }, + body: JSON.stringify(item) + }); + + if (!response.ok) { + throw new Error(`Posting chat history failed: ${response.statusText}`); + } + + const dataResponse: any = await response.json(); + return dataResponse; +} + +export async function getChatHistoryListApi(count: number, continuationToken: string | undefined, idToken: string): Promise { + const headers = await getHeaders(idToken); + const response = await fetch("/chat_history/items", { + method: "POST", + headers: { ...headers, "Content-Type": "application/json" }, + body: JSON.stringify({ count: count, continuation_token: continuationToken }) + }); + + if (!response.ok) { + throw new Error(`Getting chat histories failed: ${response.statusText}`); + } + + const dataResponse: HistoryListApiResponse = await response.json(); + return dataResponse; +} + +export async function getChatHistoryApi(id: string, idToken: string): Promise { + const headers = await getHeaders(idToken); + const response = await fetch(`/chat_history/items/${id}`, { + method: "GET", + headers: { ...headers, "Content-Type": "application/json" } + }); + + if (!response.ok) { + throw new Error(`Getting chat history failed: ${response.statusText}`); + } + + const dataResponse: HistroyApiResponse = await response.json(); + return dataResponse; +} + +export async function deleteChatHistoryApi(id: string, idToken: string): Promise { + const headers = await getHeaders(idToken); + const response = await fetch(`/chat_history/items/${id}`, { + method: "DELETE", + headers: { ...headers, "Content-Type": "application/json" } + }); + + if (!response.ok) { + throw new Error(`Deleting chat history failed: ${response.statusText}`); + } + + const dataResponse: any = await response.json(); + return dataResponse; +} diff --git a/app/frontend/src/api/models.ts b/app/frontend/src/api/models.ts index 633af8bd3f..ef1fa154b0 100644 --- a/app/frontend/src/api/models.ts +++ b/app/frontend/src/api/models.ts @@ -91,6 +91,7 @@ export type Config = { showSpeechOutputBrowser: boolean; showSpeechOutputAzure: boolean; showChatHistoryBrowser: boolean; + showChatHistoryCosmos: boolean; }; export type SimpleAPIResponse = { @@ -104,3 +105,21 @@ export interface SpeechConfig { isPlaying: boolean; setIsPlaying: (isPlaying: boolean) => void; } + +export type HistoryListApiResponse = { + items: { + id: string; + entra_oid: string; + title: string; + timestamp: number; + }[]; + continuation_token?: string; +}; + +export type HistroyApiResponse = { + id: string; + entra_oid: string; + title: string; + answers: any; + timestamp: number; +}; diff --git a/app/frontend/src/components/HistoryPanel/HistoryPanel.tsx b/app/frontend/src/components/HistoryPanel/HistoryPanel.tsx index b70a723305..acaf3b7870 100644 --- a/app/frontend/src/components/HistoryPanel/HistoryPanel.tsx +++ b/app/frontend/src/components/HistoryPanel/HistoryPanel.tsx @@ -1,4 +1,6 @@ -import { Panel, PanelType } from "@fluentui/react"; +import { useMsal } from "@azure/msal-react"; +import { getToken, useLogin } from "../../authConfig"; +import { Panel, PanelType, Spinner } from "@fluentui/react"; import { useEffect, useMemo, useRef, useState } from "react"; import { HistoryData, HistoryItem } from "../HistoryItem"; import { Answers, HistoryProviderOptions } from "../HistoryProviders/IProvider"; @@ -26,6 +28,8 @@ export const HistoryPanel = ({ const [isLoading, setIsLoading] = useState(false); const [hasMoreHistory, setHasMoreHistory] = useState(false); + const client = useLogin ? useMsal().instance : undefined; + useEffect(() => { if (!isOpen) return; if (notify) { @@ -37,7 +41,8 @@ export const HistoryPanel = ({ const loadMoreHistory = async () => { setIsLoading(() => true); - const items = await historyManager.getNextItems(HISTORY_COUNT_PER_LOAD); + const token = client ? await getToken(client) : undefined; + const items = await historyManager.getNextItems(HISTORY_COUNT_PER_LOAD, token); if (items.length === 0) { setHasMoreHistory(false); } @@ -46,14 +51,16 @@ export const HistoryPanel = ({ }; const handleSelect = async (id: string) => { - const item = await historyManager.getItem(id); + const token = client ? await getToken(client) : undefined; + const item = await historyManager.getItem(id, token); if (item) { onChatSelected(item); } }; const handleDelete = async (id: string) => { - await historyManager.deleteItem(id); + const token = client ? await getToken(client) : undefined; + await historyManager.deleteItem(id, token); setHistory(prevHistory => prevHistory.filter(item => item.id !== id)); }; @@ -85,7 +92,8 @@ export const HistoryPanel = ({ ))} ))} - {history.length === 0 &&

{t("history.noHistory")}

} + {isLoading && } + {history.length === 0 && !isLoading &&

{t("history.noHistory")}

} {hasMoreHistory && !isLoading && } diff --git a/app/frontend/src/components/HistoryProviders/CosmosDB.ts b/app/frontend/src/components/HistoryProviders/CosmosDB.ts new file mode 100644 index 0000000000..4d613b28a8 --- /dev/null +++ b/app/frontend/src/components/HistoryProviders/CosmosDB.ts @@ -0,0 +1,51 @@ +import { IHistoryProvider, Answers, HistoryProviderOptions, HistoryMetaData } from "./IProvider"; +import { deleteChatHistoryApi, getChatHistoryApi, getChatHistoryListApi, postChatHistoryApi } from "../../api"; + +export class CosmosDBProvider implements IHistoryProvider { + getProviderName = () => HistoryProviderOptions.CosmosDB; + + private continuationToken: string | undefined; + private isItemEnd: boolean = false; + + resetContinuationToken() { + this.continuationToken = undefined; + this.isItemEnd = false; + } + + async getNextItems(count: number, idToken?: string): Promise { + if (this.isItemEnd) { + return []; + } + + try { + const response = await getChatHistoryListApi(count, this.continuationToken, idToken || ""); + this.continuationToken = response.continuation_token; + if (!this.continuationToken) { + this.isItemEnd = true; + } + return response.items.map(item => ({ + id: item.id, + title: item.title, + timestamp: item.timestamp + })); + } catch (e) { + console.error(e); + return []; + } + } + + async addItem(id: string, answers: Answers, idToken?: string): Promise { + await postChatHistoryApi({ id, answers }, idToken || ""); + return; + } + + async getItem(id: string, idToken?: string): Promise { + const response = await getChatHistoryApi(id, idToken || ""); + return response.answers || null; + } + + async deleteItem(id: string, idToken?: string): Promise { + await deleteChatHistoryApi(id, idToken || ""); + return; + } +} diff --git a/app/frontend/src/components/HistoryProviders/HistoryManager.ts b/app/frontend/src/components/HistoryProviders/HistoryManager.ts index fa31c6314d..e796d04933 100644 --- a/app/frontend/src/components/HistoryProviders/HistoryManager.ts +++ b/app/frontend/src/components/HistoryProviders/HistoryManager.ts @@ -2,12 +2,15 @@ import { useMemo } from "react"; import { IHistoryProvider, HistoryProviderOptions } from "../HistoryProviders/IProvider"; import { NoneProvider } from "../HistoryProviders/None"; import { IndexedDBProvider } from "../HistoryProviders/IndexedDB"; +import { CosmosDBProvider } from "../HistoryProviders/CosmosDB"; export const useHistoryManager = (provider: HistoryProviderOptions): IHistoryProvider => { const providerInstance = useMemo(() => { switch (provider) { case HistoryProviderOptions.IndexedDB: return new IndexedDBProvider("chat-database", "chat-history"); + case HistoryProviderOptions.CosmosDB: + return new CosmosDBProvider(); case HistoryProviderOptions.None: default: return new NoneProvider(); diff --git a/app/frontend/src/components/HistoryProviders/IProvider.ts b/app/frontend/src/components/HistoryProviders/IProvider.ts index 330437a8ec..026443d681 100644 --- a/app/frontend/src/components/HistoryProviders/IProvider.ts +++ b/app/frontend/src/components/HistoryProviders/IProvider.ts @@ -5,14 +5,15 @@ export type Answers = [user: string, response: ChatAppResponse][]; export const enum HistoryProviderOptions { None = "none", - IndexedDB = "indexedDB" + IndexedDB = "indexedDB", + CosmosDB = "cosmosDB" } export interface IHistoryProvider { getProviderName(): HistoryProviderOptions; resetContinuationToken(): void; - getNextItems(count: number): Promise; - addItem(id: string, answers: Answers): Promise; - getItem(id: string): Promise; - deleteItem(id: string): Promise; + getNextItems(count: number, idToken?: string): Promise; + addItem(id: string, answers: Answers, idToken?: string): Promise; + getItem(id: string, idToken?: string): Promise; + deleteItem(id: string, idToken?: string): Promise; } diff --git a/app/frontend/src/pages/chat/Chat.tsx b/app/frontend/src/pages/chat/Chat.tsx index 1fa09ec511..7dabff8c7f 100644 --- a/app/frontend/src/pages/chat/Chat.tsx +++ b/app/frontend/src/pages/chat/Chat.tsx @@ -83,6 +83,7 @@ const Chat = () => { const [showSpeechOutputBrowser, setShowSpeechOutputBrowser] = useState(false); const [showSpeechOutputAzure, setShowSpeechOutputAzure] = useState(false); const [showChatHistoryBrowser, setShowChatHistoryBrowser] = useState(false); + const [showChatHistoryCosmos, setShowChatHistoryCosmos] = useState(false); const audio = useRef(new Audio()).current; const [isPlaying, setIsPlaying] = useState(false); @@ -109,6 +110,7 @@ const Chat = () => { setShowSpeechOutputBrowser(config.showSpeechOutputBrowser); setShowSpeechOutputAzure(config.showSpeechOutputAzure); setShowChatHistoryBrowser(config.showChatHistoryBrowser); + setShowChatHistoryCosmos(config.showChatHistoryCosmos); }); }; @@ -158,7 +160,11 @@ const Chat = () => { const client = useLogin ? useMsal().instance : undefined; const { loggedIn } = useContext(LoginContext); - const historyProvider: HistoryProviderOptions = showChatHistoryBrowser ? HistoryProviderOptions.IndexedDB : HistoryProviderOptions.None; + const historyProvider: HistoryProviderOptions = (() => { + if (useLogin && showChatHistoryCosmos) return HistoryProviderOptions.CosmosDB; + if (showChatHistoryBrowser) return HistoryProviderOptions.IndexedDB; + return HistoryProviderOptions.None; + })(); const historyManager = useHistoryManager(historyProvider); const makeApiRequest = async (question: string) => { @@ -216,7 +222,8 @@ const Chat = () => { const parsedResponse: ChatAppResponse = await handleAsyncRequest(question, answers, response.body); setAnswers([...answers, [question, parsedResponse]]); if (typeof parsedResponse.session_state === "string" && parsedResponse.session_state !== "") { - historyManager.addItem(parsedResponse.session_state, [...answers, [question, parsedResponse]]); + const token = client ? await getToken(client) : undefined; + historyManager.addItem(parsedResponse.session_state, [...answers, [question, parsedResponse]], token); } } else { const parsedResponse: ChatAppResponseOrError = await response.json(); @@ -225,7 +232,8 @@ const Chat = () => { } setAnswers([...answers, [question, parsedResponse as ChatAppResponse]]); if (typeof parsedResponse.session_state === "string" && parsedResponse.session_state !== "") { - historyManager.addItem(parsedResponse.session_state, [...answers, [question, parsedResponse as ChatAppResponse]]); + const token = client ? await getToken(client) : undefined; + historyManager.addItem(parsedResponse.session_state, [...answers, [question, parsedResponse as ChatAppResponse]], token); } } setSpeechUrls([...speechUrls, null]); @@ -348,7 +356,9 @@ const Chat = () => {
- {showChatHistoryBrowser && setIsHistoryPanelOpen(!isHistoryPanelOpen)} />} + {((useLogin && showChatHistoryCosmos) || showChatHistoryBrowser) && ( + setIsHistoryPanelOpen(!isHistoryPanelOpen)} /> + )}
@@ -457,7 +467,7 @@ const Chat = () => { /> )} - {showChatHistoryBrowser && ( + {((useLogin && showChatHistoryCosmos) || showChatHistoryBrowser) && ( \n" + } + + +@pytest.mark.asyncio +async def test_chathistory_query(auth_public_documents_client, monkeypatch, snapshot): + + def mock_query_items(container_proxy, query, **kwargs): + return MockCosmosDBResultsIterator() + + monkeypatch.setattr(ContainerProxy, "query_items", mock_query_items) + + response = await auth_public_documents_client.post( + "/chat_history/items", + headers={"Authorization": "Bearer MockToken"}, + json={"count": 20}, + ) + assert response.status_code == 200 + result = await response.get_json() + snapshot.assert_match(json.dumps(result, indent=4), "result.json") + + +@pytest.mark.asyncio +async def test_chathistory_query_continuation(auth_public_documents_client, monkeypatch, snapshot): + + def mock_query_items(container_proxy, query, **kwargs): + return MockCosmosDBResultsIterator(empty=True) + + monkeypatch.setattr(ContainerProxy, "query_items", mock_query_items) + + response = await auth_public_documents_client.post( + "/chat_history/items", + headers={"Authorization": "Bearer MockToken"}, + json={"count": 20}, + ) + assert response.status_code == 200 + result = await response.get_json() + snapshot.assert_match(json.dumps(result, indent=4), "result.json") + + +@pytest.mark.asyncio +async def test_chathistory_query_error_disabled(client, monkeypatch): + + response = await client.post( + "/chat_history/items", + headers={"Authorization": "Bearer MockToken"}, + json={ + "id": "123", + "answers": [["This is a test message"]], + }, + ) + assert response.status_code == 400 + + +@pytest.mark.asyncio +async def test_chathistory_query_error_container(auth_public_documents_client, monkeypatch): + auth_public_documents_client.app.config["cosmos_history_container"] = None + response = await auth_public_documents_client.post( + "/chat_history/items", + headers={"Authorization": "Bearer MockToken"}, + json={ + "id": "123", + "answers": [["This is a test message"]], + }, + ) + assert response.status_code == 400 + + +@pytest.mark.asyncio +async def test_chathistory_query_error_entra(auth_public_documents_client, monkeypatch): + response = await auth_public_documents_client.post( + "/chat_history/items", + json={ + "id": "123", + "answers": [["This is a test message"]], + }, + ) + assert response.status_code == 401 + + +@pytest.mark.asyncio +async def test_chathistory_query_error_runtime(auth_public_documents_client, monkeypatch): + + def mock_query_items(container_proxy, query, **kwargs): + raise Exception("Test Exception") + + monkeypatch.setattr(ContainerProxy, "query_items", mock_query_items) + + response = await auth_public_documents_client.post( + "/chat_history/items", + headers={"Authorization": "Bearer MockToken"}, + json={"count": 20}, + ) + assert response.status_code == 500 + assert (await response.get_json()) == { + "error": "The app encountered an error processing your request.\nIf you are an administrator of the app, view the full error in the logs. See aka.ms/appservice-logs for more information.\nError type: \n" + } + + +# Tests for getting an individual chat history item +@pytest.mark.asyncio +async def test_chathistory_getitem(auth_public_documents_client, monkeypatch, snapshot): + + async def mock_read_item(container_proxy, item, partition_key, **kwargs): + return { + "id": "123", + "entra_oid": "OID_X", + "title": "This is a test message", + "timestamp": 123456789, + "answers": [["This is a test message"]], + } + + monkeypatch.setattr(ContainerProxy, "read_item", mock_read_item) + + response = await auth_public_documents_client.get( + "/chat_history/items/123", + headers={"Authorization": "Bearer MockToken"}, + ) + assert response.status_code == 200 + result = await response.get_json() + snapshot.assert_match(json.dumps(result, indent=4), "result.json") + + +# Error handling tests for getting an individual chat history item +@pytest.mark.asyncio +async def test_chathistory_getitem_error_disabled(client, monkeypatch): + + response = await client.get( + "/chat_history/items/123", + headers={"Authorization": "BearerMockToken"}, + ) + assert response.status_code == 400 + + +@pytest.mark.asyncio +async def test_chathistory_getitem_error_container(auth_public_documents_client, monkeypatch): + auth_public_documents_client.app.config["cosmos_history_container"] = None + response = await auth_public_documents_client.get( + "/chat_history/items/123", + headers={"Authorization": "BearerMockToken"}, + ) + assert response.status_code == 400 + + +@pytest.mark.asyncio +async def test_chathistory_getitem_error_entra(auth_public_documents_client, monkeypatch): + response = await auth_public_documents_client.get( + "/chat_history/items/123", + ) + assert response.status_code == 401 + + +@pytest.mark.asyncio +async def test_chathistory_getitem_error_runtime(auth_public_documents_client, monkeypatch): + + async def mock_read_item(container_proxy, item, partition_key, **kwargs): + raise Exception("Test Exception") + + monkeypatch.setattr(ContainerProxy, "read_item", mock_read_item) + + response = await auth_public_documents_client.get( + "/chat_history/items/123", + headers={"Authorization": "Bearer MockToken"}, + ) + assert response.status_code == 500 + + +# Tests for deleting an individual chat history item +@pytest.mark.asyncio +async def test_chathistory_deleteitem(auth_public_documents_client, monkeypatch): + + async def mock_delete_item(container_proxy, item, partition_key, **kwargs): + assert item == "123" + assert partition_key == "OID_X" + + monkeypatch.setattr(ContainerProxy, "delete_item", mock_delete_item) + + response = await auth_public_documents_client.delete( + "/chat_history/items/123", + headers={"Authorization": "Bearer MockToken"}, + ) + assert response.status_code == 204 + + +@pytest.mark.asyncio +async def test_chathistory_deleteitem_error_disabled(client, monkeypatch): + + response = await client.delete( + "/chat_history/items/123", + headers={"Authorization": "Bearer MockToken"}, + ) + assert response.status_code == 400 + + +@pytest.mark.asyncio +async def test_chathistory_deleteitem_error_container(auth_public_documents_client, monkeypatch): + auth_public_documents_client.app.config["cosmos_history_container"] = None + response = await auth_public_documents_client.delete( + "/chat_history/items/123", + headers={"Authorization": "Bearer MockToken"}, + ) + assert response.status_code == 400 + + +@pytest.mark.asyncio +async def test_chathistory_deleteitem_error_entra(auth_public_documents_client, monkeypatch): + response = await auth_public_documents_client.delete( + "/chat_history/items/123", + ) + assert response.status_code == 401 + + +@pytest.mark.asyncio +async def test_chathistory_deleteitem_error_runtime(auth_public_documents_client, monkeypatch): + + async def mock_delete_item(container_proxy, item, partition_key, **kwargs): + raise Exception("Test Exception") + + monkeypatch.setattr(ContainerProxy, "delete_item", mock_delete_item) + + response = await auth_public_documents_client.delete( + "/chat_history/items/123", + headers={"Authorization": "Bearer MockToken"}, + ) + assert response.status_code == 500