Token Usage (#39)
### Motivation and Context


This PR adds the token usage feature: token usage is calculated and shown per prompt and per session. Each token usage calculation is split into two values (sketched below):
1. total tokens used in the chat completion of the bot response
2. total tokens used by the dependency functions that generate the prompt
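
For illustration, a minimal sketch of how those two totals can be derived from the per-function counts. The `SystemCompletion` key appears in the diff below; the `ComputeTotals` helper itself is hypothetical, not part of the PR:

```csharp
using System.Collections.Generic;
using System.Linq;

public static class TokenUsageTotals
{
    // tokenUsage maps each semantic function (e.g. "SystemIntentExtraction",
    // "SystemMetaPrompt", "SystemCompletion") to the tokens it consumed.
    public static (int Completion, int Dependencies) ComputeTotals(Dictionary<string, int> tokenUsage)
    {
        // 1. Tokens used by the chat completion of the bot response.
        tokenUsage.TryGetValue("SystemCompletion", out int completion);

        // 2. Tokens used by the dependency functions that built the prompt.
        int dependencies = tokenUsage
            .Where(kv => kv.Key != "SystemCompletion")
            .Sum(kv => kv.Value);

        return (completion, dependencies);
    }
}
```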

### Description


webapi
- Token usage per prompt now persists as part of the ChatMessage object.
- Added the initial bot message (with its token usage) to ChatHistoryController, and added bot-response token usage tracking to the ChatMessage and ChatSession models.
- Updated ChatSkill to save token usage to context variables and return it with the bot response.
- Added token usage calculation in ChatSkill: total token usage for dependency functions and the chat completion is computed and sent to the client with the updated response, and each dependency's usage is copied back into the original chat context (see the sketch after this list).
- Memory extraction token usage is calculated by taking cumulative semantic memory token usage into account.
- Renamed Utilities to TokenUtilities and added a GetFunctionTokenUsage method.
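
A rough sketch of the copy-back step. The `<function>TokenUsage` variable-key convention mirrors the ChatSkill diff below; the standalone helper is illustrative only, not the literal PR code:

```csharp
using Microsoft.SemanticKernel.Orchestration;

public static class TokenUsageFlow
{
    // Each semantic function (e.g. "SystemIntentExtraction") records its token
    // count in its own SKContext under "<function>TokenUsage"; copying that
    // variable into the original chat context lets the final aggregation pass
    // (GetTokenUsagesAsync in the diff) see every dependency's usage.
    public static void CopyTokenUsage(SKContext source, SKContext target, string functionName)
    {
        string key = $"{functionName}TokenUsage";
        if (source.Variables.TryGetValue(key, out string? tokenUsage))
        {
            target.Variables.Set(key, tokenUsage);
        }
    }
}
```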

webapp
- AppState: added a new TokenUsage field to track total usage across all chats in an app session; appSlice has been updated to accumulate session token usage.
- Updated SignalRMiddleware to handle token usage when receiving message updates from the server: if tokenUsage is defined on the update, the message's tokenUsage property is updated; otherwise its content is updated (see the sketch after this list).
- Fixed ChatHistoryTextContent to include the TypingIndicator while the bot response is generating.
- Renamed the PromptDetails component to PromptDialog; it shows prompt details and a token usage graph.
- Removed TypingIndicatorRenderer.
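
That middleware branch follows from the server contract in the ChatSkill diff below: the server now sends the whole ChatMessage over a single `ReceiveMessageUpdate` event, with token usage unset until it has been calculated. Condensed (not verbatim) from the server-side changes:

```csharp
// While streaming, TokenUsage is still null, so the client updates content.
await foreach (string contentPiece in stream)
{
    chatMessage.Content += contentPiece;
    await this.UpdateMessageOnClient(chatMessage); // fires "ReceiveMessageUpdate"
}

// After the response completes, the same event carries the populated totals,
// so the client updates tokenUsage instead of content.
chatMessage.TokenUsage = this.GetTokenUsagesAsync(chatContext, chatMessage.Content);
await this.UpdateMessageOnClient(chatMessage);
```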

Token usage is shown for ChatMessages of type Message and Plan:

![image](https://github.com/microsoft/chat-copilot/assets/125500434/9ae5a262-67ed-400c-8e26-b486f0e307c8)

Hardcoded bot responses default to 0:

![image](https://github.com/microsoft/chat-copilot/assets/125500434/0240695a-14ea-4a53-90f7-e2c3f64df0fe)

Loading state:

![image](https://github.com/microsoft/chat-copilot/assets/125500434/cb3b0ab7-d76e-4404-8a09-7fa8078fbbf1)

Info

![image](https://github.com/microsoft/chat-copilot/assets/125500434/9747b239-daa8-4553-ab57-9210c5553211)

This is what it will look like in the settings dialog once the changes go in:

![image](https://github.com/microsoft/chat-copilot/assets/125500434/aca7d038-96f9-4d3e-95a7-339bede3ebc7)


### Contribution Checklist


- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [Contribution
Guidelines](https://github.com/microsoft/copilot-chat/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/copilot-chat/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- ~~All unit tests pass, and I have added new tests where possible~~
- [x] I didn't break anyone 😄

---------

Co-authored-by: GitHub Actions <[email protected]>
teresaqhoang and actions-user authored Jul 28, 2023
1 parent d76ab3e commit 6aefee8
Showing 38 changed files with 776 additions and 211 deletions.
14 changes: 9 additions & 5 deletions webapi/CopilotChat/Controllers/ChatHistoryController.cs
@@ -14,6 +14,7 @@
using SemanticKernel.Service.CopilotChat.Hubs;
using SemanticKernel.Service.CopilotChat.Models;
using SemanticKernel.Service.CopilotChat.Options;
using SemanticKernel.Service.CopilotChat.Skills;
using SemanticKernel.Service.CopilotChat.Storage;

namespace SemanticKernel.Service.CopilotChat.Controllers;
@@ -80,19 +81,22 @@ public async Task<IActionResult> CreateChatSessionAsync([FromBody] CreateChatPar
var newChat = new ChatSession(chatParameter.Title, this._promptOptions.SystemDescription);
await this._sessionRepository.CreateAsync(newChat);

var initialBotMessage = this._promptOptions.InitialBotMessage;
// The initial bot message doesn't need a prompt.
// Create initial bot message
var chatMessage = ChatMessage.CreateBotResponseMessage(
newChat.Id,
initialBotMessage,
string.Empty);
this._promptOptions.InitialBotMessage,
string.Empty, // The initial bot message doesn't need a prompt.
TokenUtilities.EmptyTokenUsages());
await this._messageRepository.CreateAsync(chatMessage);

// Add the user to the chat session
await this._participantRepository.CreateAsync(new ChatParticipant(chatParameter.UserId, newChat.Id));

this._logger.LogDebug("Created chat session with id {0}.", newChat.Id);
return this.CreatedAtAction(nameof(this.GetChatSessionByIdAsync), new { chatId = newChat.Id }, newChat);
return this.CreatedAtAction(
nameof(this.GetChatSessionByIdAsync),
new { chatId = newChat.Id },
new CreateChatResponse(newChat, chatMessage));
}

/// <summary>
2 changes: 1 addition & 1 deletion webapi/CopilotChat/Controllers/ChatMemoryController.cs
@@ -103,4 +103,4 @@ private bool ValidateMemoryName(string memoryName)
}

# endregion
}
}
2 changes: 1 addition & 1 deletion webapi/CopilotChat/Controllers/DocumentImportController.cs
@@ -548,7 +548,7 @@ await kernel.Memory.SaveInformationAsync(
id: key,
description: $"Document: {documentName}");
importResult.AddKey(key);
importResult.Tokens += Utilities.TokenCount(paragraph);
importResult.Tokens += TokenUtilities.TokenCount(paragraph);
}

this._logger.LogInformation(
17 changes: 14 additions & 3 deletions webapi/CopilotChat/Models/ChatMessage.cs
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Globalization;
using System.Text.Json;
using System.Text.Json.Serialization;
@@ -110,6 +111,12 @@ public enum ChatMessageType
[JsonPropertyName("type")]
public ChatMessageType Type { get; set; }

/// <summary>
/// Counts of total token usage used to generate bot response.
/// </summary>
[JsonPropertyName("tokenUsage")]
public Dictionary<string, int>? TokenUsage { get; set; }

/// <summary>
/// Create a new chat message. Timestamp is automatically generated.
/// </summary>
@@ -120,14 +127,16 @@ public enum ChatMessageType
/// <param name="prompt">The prompt used to generate the message</param>
/// <param name="authorRole">Role of the author</param>
/// <param name="type">Type of the message</param>
/// <param name="tokenUsage">Total token usages used to generate bot response</param>
public ChatMessage(
string userId,
string userName,
string chatId,
string content,
string prompt = "",
AuthorRoles authorRole = AuthorRoles.User,
ChatMessageType type = ChatMessageType.Message)
ChatMessageType type = ChatMessageType.Message,
Dictionary<string, int>? tokenUsage = null)
{
this.Timestamp = DateTimeOffset.Now;
this.UserId = userId;
@@ -138,6 +147,7 @@ public ChatMessage(
this.Prompt = prompt;
this.AuthorRole = authorRole;
this.Type = type;
this.TokenUsage = tokenUsage;
}

/// <summary>
@@ -146,9 +156,10 @@ public ChatMessage(
/// <param name="chatId">The chat ID that this message belongs to</param>
/// <param name="content">The message</param>
/// <param name="prompt">The prompt used to generate the message</param>
public static ChatMessage CreateBotResponseMessage(string chatId, string content, string prompt)
/// <param name="tokenUsage">Total token usage of response completion</param>
public static ChatMessage CreateBotResponseMessage(string chatId, string content, string prompt, Dictionary<string, int>? tokenUsage = null)
{
return new ChatMessage("bot", "bot", chatId, content, prompt, AuthorRoles.Bot, IsPlan(content) ? ChatMessageType.Plan : ChatMessageType.Message);
return new ChatMessage("bot", "bot", chatId, content, prompt, AuthorRoles.Bot, IsPlan(content) ? ChatMessageType.Plan : ChatMessageType.Message, tokenUsage);
}

/// <summary>
36 changes: 36 additions & 0 deletions webapi/CopilotChat/Models/CreateChatResponse.cs
@@ -0,0 +1,36 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Text.Json.Serialization;

namespace SemanticKernel.Service.CopilotChat.Models;

/// <summary>
/// Response to chatSession/create request.
/// </summary>
public class CreateChatResponse
{
/// <summary>
/// ID that is persistent and unique to new chat session.
/// </summary>
[JsonPropertyName("id")]
public string Id { get; set; }

/// <summary>
/// Title of the chat.
/// </summary>
[JsonPropertyName("title")]
public string Title { get; set; }

/// <summary>
/// Initial bot message.
/// </summary>
[JsonPropertyName("initialBotMessage")]
public ChatMessage? InitialBotMessage { get; set; }

public CreateChatResponse(ChatSession chatSession, ChatMessage initialBotMessage)
{
this.Id = chatSession.Id;
this.Title = chatSession.Title;
this.InitialBotMessage = initialBotMessage;
}
}
97 changes: 79 additions & 18 deletions webapi/CopilotChat/Skills/ChatSkills/ChatSkill.cs
@@ -116,7 +116,7 @@ public async Task<string> ExtractUserIntentAsync(SKContext context)
var historyTokenBudget =
tokenLimit -
this._promptOptions.ResponseTokenLimit -
Utilities.TokenCount(string.Join("\n", new string[]
TokenUtilities.TokenCount(string.Join("\n", new string[]
{
this._promptOptions.SystemDescription,
this._promptOptions.SystemIntent,
@@ -139,6 +139,9 @@ public async Task<string> ExtractUserIntentAsync(SKContext context)
settings: this.CreateIntentCompletionSettings()
);

// Get token usage from ChatCompletion result and add to context
TokenUtilities.GetFunctionTokenUsage(result, context, "SystemIntentExtraction");

if (result.ErrorOccurred)
{
context.Log.LogError("{0}: {1}", result.LastErrorDescription, result.LastException);
@@ -161,7 +164,7 @@ public async Task<string> ExtractAudienceAsync(SKContext context)
var historyTokenBudget =
tokenLimit -
this._promptOptions.ResponseTokenLimit -
Utilities.TokenCount(string.Join("\n", new string[]
TokenUtilities.TokenCount(string.Join("\n", new string[]
{
this._promptOptions.SystemAudience,
this._promptOptions.SystemAudienceContinuation,
@@ -182,6 +185,9 @@ public async Task<string> ExtractAudienceAsync(SKContext context)
settings: this.CreateIntentCompletionSettings()
);

// Get token usage from ChatCompletion result and add to context
TokenUtilities.GetFunctionTokenUsage(result, context, "SystemAudienceExtraction");

if (result.ErrorOccurred)
{
context.Log.LogError("{0}: {1}", result.LastErrorDescription, result.LastException);
@@ -229,7 +235,7 @@ public async Task<string> ExtractChatHistoryAsync(
}
}

var tokenCount = Utilities.TokenCount(formattedMessage);
var tokenCount = TokenUtilities.TokenCount(formattedMessage);

if (remainingToken - tokenCount >= 0)
{
@@ -262,7 +268,7 @@ public async Task<SKContext> ChatAsync(
SKContext context)
{
// Set the system description in the prompt options
await SetSystemDescriptionAsync(chatId);
await this.SetSystemDescriptionAsync(chatId);

// Save this new message to memory such that subsequent chat responses can use it
await this.UpdateBotResponseStatusOnClient(chatId, "Saving user message to chat history");
@@ -284,7 +290,7 @@ public async Task<SKContext> ChatAsync(
// Save hardcoded response if user cancelled plan
if (chatContext.Variables.ContainsKey("userCancelledPlan"))
{
await this.SaveNewResponseAsync("I am sorry the plan did not meet your goals.", string.Empty, chatId, userId);
await this.SaveNewResponseAsync("I am sorry the plan did not meet your goals.", string.Empty, chatId, userId, TokenUtilities.EmptyTokenUsages());
return context;
}

@@ -296,6 +302,7 @@ public async Task<SKContext> ChatAsync(
return context;
}

context.Variables.Set("tokenUsage", JsonSerializer.Serialize(chatMessage.TokenUsage));
return context;
}

@@ -350,7 +357,11 @@ public async Task<SKContext> ChatAsync(
chatContext.Variables.Set("prompt", prompt);

// Save a new response to the chat history with the proposed plan content
return await this.SaveNewResponseAsync(JsonSerializer.Serialize<ProposedPlan>(proposedPlan), prompt, chatId, userId);
return await this.SaveNewResponseAsync(
JsonSerializer.Serialize<ProposedPlan>(proposedPlan), prompt, chatId, userId,
// TODO: [Issue #2106] Accommodate plan token usage differently
this.GetTokenUsagesAsync(chatContext)
);
}

// Query relevant semantic and document memories
@@ -376,7 +387,7 @@ public async Task<SKContext> ChatAsync(
// Fill in the chat history if there is any token budget left
var chatContextComponents = new List<string>() { chatMemories, documentMemories, planResult };
var chatContextText = string.Join("\n\n", chatContextComponents.Where(c => !string.IsNullOrEmpty(c)));
var chatHistoryTokenLimit = remainingToken - Utilities.TokenCount(chatContextText);
var chatHistoryTokenLimit = remainingToken - TokenUtilities.TokenCount(chatContextText);
if (chatHistoryTokenLimit > 0)
{
await this.UpdateBotResponseStatusOnClient(chatId, "Extracting chat history");
@@ -399,6 +410,7 @@ public async Task<SKContext> ChatAsync(
this._promptOptions.SystemChatPrompt,
chatContext);
chatContext.Variables.Set("prompt", renderedPrompt);
chatContext.Variables.Set(TokenUtilities.GetFunctionKey(chatContext.Log, "SystemMetaPrompt")!, TokenUtilities.TokenCount(renderedPrompt).ToString(CultureInfo.InvariantCulture));

if (chatContext.ErrorOccurred)
{
@@ -417,9 +429,15 @@ await SemanticChatMemoryExtractor.ExtractSemanticChatMemoryAsync(
chatContext,
this._promptOptions);

// Save the message
// Calculate total token usage for dependency functions and prompt template and send to client
await this.UpdateBotResponseStatusOnClient(chatId, "Calculating token usage");
chatMessage.TokenUsage = this.GetTokenUsagesAsync(chatContext, chatMessage.Content);
await this.UpdateMessageOnClient(chatMessage);

// Save the message with final completion token usage
await this.UpdateBotResponseStatusOnClient(chatId, "Saving message to chat history");
await this._chatMessageRepository.UpsertAsync(chatMessage);

return chatMessage;
}

@@ -442,6 +460,13 @@ private async Task<string> GetAudienceAsync(SKContext context)

var audience = await this.ExtractAudienceAsync(audienceContext);

// Copy token usage into original chat context
var functionKey = TokenUtilities.GetFunctionKey(context.Log, "SystemAudienceExtraction")!;
if (audienceContext.Variables.TryGetValue(functionKey, out string? tokenUsage))
{
context.Variables.Set(functionKey, tokenUsage);
}

// Propagate the error
if (audienceContext.ErrorOccurred)
{
@@ -473,6 +498,14 @@ private async Task<string> GetUserIntentAsync(SKContext context)
);

userIntent = await this.ExtractUserIntentAsync(intentContext);

// Copy token usage into original chat context
var functionKey = TokenUtilities.GetFunctionKey(context.Log, "SystemIntentExtraction")!;
if (intentContext.Variables.TryGetValue(functionKey!, out string? tokenUsage))
{
context.Variables.Set(functionKey!, tokenUsage);
}

// Propagate the error
if (intentContext.ErrorOccurred)
{
@@ -579,8 +612,9 @@ private async Task<ChatMessage> SaveNewMessageAsync(string message, string userI
/// <param name="prompt">Prompt used to generate the response.</param>
/// <param name="chatId">The chat ID</param>
/// <param name="userId">The user ID</param>
/// <param name="tokenUsage">Total token usage of response completion</param>
/// <returns>The created chat message.</returns>
private async Task<ChatMessage> SaveNewResponseAsync(string response, string prompt, string chatId, string userId)
private async Task<ChatMessage> SaveNewResponseAsync(string response, string prompt, string chatId, string userId, Dictionary<string, int>? tokenUsage)
{
// Make sure the chat exists.
if (!await this._chatSessionRepository.TryFindByIdAsync(chatId, v => _ = v))
@@ -651,10 +685,10 @@ private int GetChatContextTokenLimit(string audience, string userIntent)
var tokenLimit = this._promptOptions.CompletionTokenLimit;
var remainingToken =
tokenLimit -
Utilities.TokenCount(audience) -
Utilities.TokenCount(userIntent) -
TokenUtilities.TokenCount(audience) -
TokenUtilities.TokenCount(userIntent) -
this._promptOptions.ResponseTokenLimit -
Utilities.TokenCount(string.Join("\n", new string[]
TokenUtilities.TokenCount(string.Join("\n", new string[]
{
this._promptOptions.SystemDescription,
this._promptOptions.SystemResponse,
@@ -665,6 +699,33 @@ private int GetChatContextTokenLimit(string audience, string userIntent)
return remainingToken;
}

/// <summary>
/// Gets token usage totals for each semantic function if not undefined.
/// </summary>
/// <param name="chatContext">Context maintained during response generation.</param>
/// <param name="content">String representing bot response. If null, response is still being generated or was hardcoded.</param>
/// <returns>Dictionary containing function to token usage mapping for each total that's defined.</returns>
private Dictionary<string, int> GetTokenUsagesAsync(SKContext chatContext, string? content = null)
{
var tokenUsageDict = new Dictionary<string, int>(StringComparer.OrdinalIgnoreCase);

// Total token usage of each semantic function
foreach (string function in TokenUtilities.semanticFunctions.Values)
{
if (chatContext.Variables.TryGetValue($"{function}TokenUsage", out string? tokenUsage))
{
tokenUsageDict.Add(function, int.Parse(tokenUsage, CultureInfo.InvariantCulture));
}
}

if (content != null)
{
tokenUsageDict.Add(TokenUtilities.semanticFunctions["SystemCompletion"]!, TokenUtilities.TokenCount(content));
}

return tokenUsageDict;
}

/// <summary>
/// Stream the response to the client.
/// </summary>
@@ -685,7 +746,7 @@ private async Task<ChatMessage> StreamResponseToClient(string chatId, string use
await foreach (string contentPiece in stream)
{
chatMessage.Content += contentPiece;
await this.UpdateMessageContentOnClient(chatId, chatMessage);
await this.UpdateMessageOnClient(chatMessage);
}

return chatMessage;
@@ -698,22 +759,22 @@ private async Task<ChatMessage> StreamResponseToClient(string chatId, string use
/// <param name="userId">The user ID</param>
/// <param name="prompt">Prompt used to generate the message</param>
/// <param name="content">Content of the message</param>
/// <param name="tokenUsage">Total token usage of response completion</param>
/// <returns>The created chat message</returns>
private async Task<ChatMessage> CreateBotMessageOnClient(string chatId, string userId, string prompt, string content)
private async Task<ChatMessage> CreateBotMessageOnClient(string chatId, string userId, string prompt, string content, Dictionary<string, int>? tokenUsage = null)
{
var chatMessage = ChatMessage.CreateBotResponseMessage(chatId, content, prompt);
var chatMessage = ChatMessage.CreateBotResponseMessage(chatId, content, prompt, tokenUsage);
await this._messageRelayHubContext.Clients.Group(chatId).SendAsync("ReceiveMessage", chatId, userId, chatMessage);
return chatMessage;
}

/// <summary>
/// Update the response on the client.
/// </summary>
/// <param name="chatId">The chat ID</param>
/// <param name="message">The message</param>
private async Task UpdateMessageContentOnClient(string chatId, ChatMessage message)
private async Task UpdateMessageOnClient(ChatMessage message)
{
await this._messageRelayHubContext.Clients.Group(chatId).SendAsync("ReceiveMessageStream", chatId, message.Id, message.Content);
await this._messageRelayHubContext.Clients.Group(message.ChatId).SendAsync("ReceiveMessageUpdate", message);
}

/// <summary>