Skip to content

Commit

Permalink
Respect order, number to skip and number to take of chat messages at …
Browse files Browse the repository at this point in the history
…underlying DB level rather than at higher service level
  • Loading branch information
glahaye committed Mar 27, 2024
1 parent b8e7b47 commit 4d8baba
Show file tree
Hide file tree
Showing 11 changed files with 144 additions and 30 deletions.
3 changes: 1 addition & 2 deletions webapi/Controllers/ChatArchiveController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ private async Task<List<Citation>> GetMemoryRecordsAndAppendToEmbeddingsAsync(
/// <returns>The list of chat messages in descending order of the timestamp</returns>
private async Task<List<CopilotChatMessage>> GetAllChatMessagesAsync(string chatId)
{
return (await this._chatMessageRepository.FindByChatIdAsync(chatId))
.OrderByDescending(m => m.Timestamp).ToList();
return (await this._chatMessageRepository.FindByChatIdAsync(chatId)).ToList();
}
}
16 changes: 6 additions & 10 deletions webapi/Controllers/ChatHistoryController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,12 @@ public async Task<IActionResult> GetAllChatSessionsAsync()
}

/// <summary>
/// Get all chat messages for a chat session.
/// The list will be ordered with the first entry being the most recent message.
/// Get chat messages for a chat session.
/// Messages are returned ordered from most recent to oldest.
/// </summary>
/// <param name="chatId">The chat id.</param>
/// <param name="startIdx">The start index at which the first message will be returned.</param>
/// <param name="count">The number of messages to return. -1 will return all messages starting from startIdx.</param>
/// <param name="skip">Number of messages to skip before starting to return messages.</param>
/// <param name="count">The number of messages to return. -1 returns all messages.</param>
[HttpGet]
[Route("chats/{chatId:guid}/messages")]
[ProducesResponseType(StatusCodes.Status200OK)]
Expand All @@ -183,19 +183,15 @@ public async Task<IActionResult> GetAllChatSessionsAsync()
[Authorize(Policy = AuthPolicyName.RequireChatParticipant)]
public async Task<IActionResult> GetChatMessagesAsync(
[FromRoute] Guid chatId,
[FromQuery] int startIdx = 0,
[FromQuery] int skip = 0,
[FromQuery] int count = -1)
{
// TODO: [Issue #48] the code mixes strings and Guid without being explicit about the serialization format
var chatMessages = await this._messageRepository.FindByChatIdAsync(chatId.ToString());
var chatMessages = await this._messageRepository.FindByChatIdAsync(chatId.ToString(), skip, count);
if (!chatMessages.Any())
{
return this.NotFound($"No messages found for chat id '{chatId}'.");
}

chatMessages = chatMessages.OrderByDescending(m => m.Timestamp).Skip(startIdx);
if (count >= 0) { chatMessages = chatMessages.Take(count); }

return this.Ok(chatMessages);
}

Expand Down
8 changes: 4 additions & 4 deletions webapi/Extensions/ServiceExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ internal static IServiceCollection AddCorsPolicy(this IServiceCollection service
public static IServiceCollection AddPersistentChatStore(this IServiceCollection services)
{
IStorageContext<ChatSession> chatSessionStorageContext;
IStorageContext<CopilotChatMessage> chatMessageStorageContext;
ICopilotChatMessageStorageContext chatMessageStorageContext;
IStorageContext<MemorySource> chatMemorySourceStorageContext;
IStorageContext<ChatParticipant> chatParticipantStorageContext;

Expand All @@ -175,7 +175,7 @@ public static IServiceCollection AddPersistentChatStore(this IServiceCollection
case ChatStoreOptions.ChatStoreType.Volatile:
{
chatSessionStorageContext = new VolatileContext<ChatSession>();
chatMessageStorageContext = new VolatileContext<CopilotChatMessage>();
chatMessageStorageContext = new VolatileCopilotChatMessageContext();
chatMemorySourceStorageContext = new VolatileContext<MemorySource>();
chatParticipantStorageContext = new VolatileContext<ChatParticipant>();
break;
Expand All @@ -192,7 +192,7 @@ public static IServiceCollection AddPersistentChatStore(this IServiceCollection
string directory = Path.GetDirectoryName(fullPath) ?? string.Empty;
chatSessionStorageContext = new FileSystemContext<ChatSession>(
new FileInfo(Path.Combine(directory, $"{Path.GetFileNameWithoutExtension(fullPath)}_sessions{Path.GetExtension(fullPath)}")));
chatMessageStorageContext = new FileSystemContext<CopilotChatMessage>(
chatMessageStorageContext = new FileSystemCopilotChatMessageContext(
new FileInfo(Path.Combine(directory, $"{Path.GetFileNameWithoutExtension(fullPath)}_messages{Path.GetExtension(fullPath)}")));
chatMemorySourceStorageContext = new FileSystemContext<MemorySource>(
new FileInfo(Path.Combine(directory, $"{Path.GetFileNameWithoutExtension(fullPath)}_memorysources{Path.GetExtension(fullPath)}")));
Expand All @@ -210,7 +210,7 @@ public static IServiceCollection AddPersistentChatStore(this IServiceCollection
#pragma warning disable CA2000 // Dispose objects before losing scope - objects are singletons for the duration of the process and disposed when the process exits.
chatSessionStorageContext = new CosmosDbContext<ChatSession>(
chatStoreConfig.Cosmos.ConnectionString, chatStoreConfig.Cosmos.Database, chatStoreConfig.Cosmos.ChatSessionsContainer);
chatMessageStorageContext = new CosmosDbContext<CopilotChatMessage>(
chatMessageStorageContext = new CosmosDbCopilotChatMessageContext(
chatStoreConfig.Cosmos.ConnectionString, chatStoreConfig.Cosmos.Database, chatStoreConfig.Cosmos.ChatMessagesContainer);
chatMemorySourceStorageContext = new CosmosDbContext<MemorySource>(
chatStoreConfig.Cosmos.ConnectionString, chatStoreConfig.Cosmos.Database, chatStoreConfig.Cosmos.ChatMemorySourcesContainer);
Expand Down
3 changes: 1 addition & 2 deletions webapi/Plugins/Chat/ChatPlugin.cs
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,7 @@ private async Task<string> GetAllowedChatHistoryAsync(
ChatHistory? chatHistory = null,
CancellationToken cancellationToken = default)
{
var messages = await this._chatMessageRepository.FindByChatIdAsync(chatId);
var sortedMessages = messages.OrderByDescending(m => m.Timestamp);
var sortedMessages = await this._chatMessageRepository.FindByChatIdAsync(chatId, 0, 100);

ChatHistory allottedChatHistory = new();
var remainingToken = tokenLimit;
Expand Down
14 changes: 8 additions & 6 deletions webapi/Storage/ChatMessageRepository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ namespace CopilotChat.WebApi.Storage;
/// <summary>
/// A repository for chat messages.
/// </summary>
public class ChatMessageRepository : Repository<CopilotChatMessage>
public class ChatMessageRepository : CopilotChatMessageRepository
{
/// <summary>
/// Initializes a new instance of the ChatMessageRepository class.
/// </summary>
/// <param name="storageContext">The storage context.</param>
public ChatMessageRepository(IStorageContext<CopilotChatMessage> storageContext)
public ChatMessageRepository(ICopilotChatMessageStorageContext storageContext)
: base(storageContext)
{
}
Expand All @@ -25,10 +25,12 @@ public ChatMessageRepository(IStorageContext<CopilotChatMessage> storageContext)
/// Finds chat messages by chat id.
/// </summary>
/// <param name="chatId">The chat id.</param>
/// <returns>A list of ChatMessages matching the given chatId.</returns>
public Task<IEnumerable<CopilotChatMessage>> FindByChatIdAsync(string chatId)
/// <param name="skip">Number of messages to skip before starting to return messages.</param>
/// <param name="count">The number of messages to return. -1 returns all messages.</param>
/// <returns>A list of ChatMessages matching the given chatId sorted from most recent to oldest.</returns>
public Task<IEnumerable<CopilotChatMessage>> FindByChatIdAsync(string chatId, int skip = 0, int count = -1)
{
return base.StorageContext.QueryEntitiesAsync(e => e.ChatId == chatId);
return base.QueryEntitiesAsync(e => e.ChatId == chatId, skip, count);
}

/// <summary>
Expand All @@ -38,7 +40,7 @@ public Task<IEnumerable<CopilotChatMessage>> FindByChatIdAsync(string chatId)
/// <returns>The most recent ChatMessage matching the given chatId.</returns>
public async Task<CopilotChatMessage> FindLastByChatIdAsync(string chatId)
{
var chatMessages = await this.FindByChatIdAsync(chatId);
var chatMessages = await this.FindByChatIdAsync(chatId, 0, 1);
var first = chatMessages.MaxBy(e => e.Timestamp);
return first ?? throw new KeyNotFoundException($"No messages found for chat '{chatId}'.");
}
Expand Down
30 changes: 29 additions & 1 deletion webapi/Storage/CosmosDbContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Linq;
using System.Net;
using System.Threading.Tasks;
using CopilotChat.WebApi.Models.Storage;
using Microsoft.Azure.Cosmos;

namespace CopilotChat.WebApi.Storage;
Expand All @@ -22,7 +23,9 @@ public class CosmosDbContext<T> : IStorageContext<T>, IDisposable where T : ISto
/// <summary>
/// CosmosDB container.
/// </summary>
private readonly Container _container;
#pragma warning disable CA1051 // Do not declare visible instance fields
protected readonly Container _container;
#pragma warning restore CA1051 // Do not declare visible instance fields

/// <summary>
/// Initializes a new instance of the CosmosDbContext class.
Expand Down Expand Up @@ -117,3 +120,28 @@ protected virtual void Dispose(bool disposing)
}
}
}

/// <summary>
/// Specialization of CosmosDbContext<T> for CopilotChatMessage.
/// </summary>
public class CosmosDbCopilotChatMessageContext : CosmosDbContext<CopilotChatMessage>, ICopilotChatMessageStorageContext
{
/// <summary>
/// Initializes a new instance of the CosmosDbCopilotChatMessageContext class.
/// </summary>
/// <param name="connectionString">The CosmosDB connection string.</param>
/// <param name="database">The CosmosDB database name.</param>
/// <param name="container">The CosmosDB container name.</param>
public CosmosDbCopilotChatMessageContext(string connectionString, string database, string container) :
base(connectionString, database, container)
{
}

/// <inheritdoc/>
public Task<IEnumerable<CopilotChatMessage>> QueryEntitiesAsync(Func<CopilotChatMessage, bool> predicate, int skip, int count)
{
return Task.Run<IEnumerable<CopilotChatMessage>>(
() => this._container.GetItemLinqQueryable<CopilotChatMessage>(true)
.Where(predicate).OrderByDescending(m => m.Timestamp).Skip(skip).Take(count).AsEnumerable());
}
}
32 changes: 30 additions & 2 deletions webapi/Storage/FileSystemContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Linq;
using System.Text.Json;
using System.Threading.Tasks;
using CopilotChat.WebApi.Models.Storage;

namespace CopilotChat.WebApi.Storage;

Expand Down Expand Up @@ -99,14 +100,16 @@ public Task UpsertAsync(T entity)
/// <summary>
/// A concurrent dictionary to store entities in memory.
/// </summary>
private sealed class EntityDictionary : ConcurrentDictionary<string, T>
protected sealed class EntityDictionary : ConcurrentDictionary<string, T>
{
}

/// <summary>
/// Using a concurrent dictionary to store entities in memory.
/// </summary>
private readonly EntityDictionary _entities;
#pragma warning disable CA1051 // Do not declare visible instance fields
protected readonly EntityDictionary _entities;
#pragma warning restore CA1051 // Do not declare visible instance fields

/// <summary>
/// The file path to store entities on disk.
Expand Down Expand Up @@ -164,3 +167,28 @@ private EntityDictionary Load(FileInfo fileInfo)
}
}
}

/// <summary>
/// Specialization of FileSystemContext<T> for CopilotChatMessage.
/// </summary>
public class FileSystemCopilotChatMessageContext : FileSystemContext<CopilotChatMessage>, ICopilotChatMessageStorageContext
{
/// <summary>
/// Initializes a new instance of the CosmosDbContext class.
/// </summary>
/// <param name="connectionString">The CosmosDB connection string.</param>
/// <param name="database">The CosmosDB database name.</param>
/// <param name="container">The CosmosDB container name.</param>
public FileSystemCopilotChatMessageContext(FileInfo filePath) :
base(filePath)
{
}

/// <inheritdoc/>
public Task<IEnumerable<CopilotChatMessage>> QueryEntitiesAsync(Func<CopilotChatMessage, bool> predicate, int skip, int count)
{
return Task.Run<IEnumerable<CopilotChatMessage>>(
() => this._entities.Values
.Where(predicate).OrderByDescending(m => m.Timestamp).Skip(skip).Take(count));
}
}
17 changes: 17 additions & 0 deletions webapi/Storage/IStorageContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using CopilotChat.WebApi.Models.Storage;

namespace CopilotChat.WebApi.Storage;

Expand All @@ -13,6 +14,7 @@ public interface IStorageContext<T> where T : IStorageEntity
{
/// <summary>
/// Query entities in the storage context.
/// <param name="predicate">Predicate that needs to evaluate to true for a particular entryto be returned.</param>
/// </summary>
Task<IEnumerable<T>> QueryEntitiesAsync(Func<T, bool> predicate);

Expand Down Expand Up @@ -42,3 +44,18 @@ public interface IStorageContext<T> where T : IStorageEntity
/// <param name="entity">The entity to be deleted from the context.</param>
Task DeleteAsync(T entity);
}

/// <summary>
/// Specialization of IStorageContext<T> for CopilotChatMessage.
/// </summary>
public interface ICopilotChatMessageStorageContext : IStorageContext<CopilotChatMessage>
{
/// <summary>
/// Query entities in the storage context.
/// </summary>
/// <param name="predicate">Predicate that needs to evaluate to true for a particular entryto be returned.</param>
/// <param name="skip">Number of messages to skip before starting to return messages.</param>
/// <param name="count">The number of messages to return. -1 returns all messages.</param>
/// <returns>A list of ChatMessages matching the given chatId sorted from most recent to oldest.</returns>
Task<IEnumerable<CopilotChatMessage>> QueryEntitiesAsync(Func<CopilotChatMessage, bool> predicate, int skip = 0, int count = -1);
}
28 changes: 28 additions & 0 deletions webapi/Storage/Repository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using CopilotChat.WebApi.Models.Storage;

namespace CopilotChat.WebApi.Storage;

Expand Down Expand Up @@ -70,3 +71,30 @@ public Task UpsertAsync(T entity)
return this.StorageContext.UpsertAsync(entity);
}
}

/// <summary>
/// Specialization of Repository<T> for CopilotChatMessage.
/// </summary>
public class CopilotChatMessageRepository : Repository<CopilotChatMessage>
{
private readonly ICopilotChatMessageStorageContext _messageStorageContext;

public CopilotChatMessageRepository(ICopilotChatMessageStorageContext storageContext)
: base(storageContext)
{
this._messageStorageContext = storageContext;
}

/// <summary>
/// Finds chat messages matching a predicate.
/// </summary>
/// <param name="predicate">Predicate that needs to evaluate to true for a particular entryto be returned.</param>
/// <param name="skip">Number of messages to skip before starting to return messages.</param>
/// <param name="count">The number of messages to return. -1 returns all messages.</param>
/// <returns>A list of ChatMessages matching the given chatId sorted from most recent to oldest.</returns>
public async Task<IEnumerable<CopilotChatMessage>> QueryEntitiesAsync(Func<CopilotChatMessage, bool> predicate, int skip = 0, int count = -1)
{
return await Task.Run<IEnumerable<CopilotChatMessage>>(
() => this._messageStorageContext.QueryEntitiesAsync(predicate, skip, count));
}
}
19 changes: 18 additions & 1 deletion webapi/Storage/VolatileContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Diagnostics;
using System.Linq;
using System.Threading.Tasks;
using CopilotChat.WebApi.Models.Storage;

namespace CopilotChat.WebApi.Storage;

Expand All @@ -18,7 +19,9 @@ public class VolatileContext<T> : IStorageContext<T> where T : IStorageEntity
/// <summary>
/// Using a concurrent dictionary to store entities in memory.
/// </summary>
private readonly ConcurrentDictionary<string, T> _entities;
#pragma warning disable CA1051 // Do not declare visible instance fields
protected readonly ConcurrentDictionary<string, T> _entities;
#pragma warning restore CA1051 // Do not declare visible instance fields

/// <summary>
/// Initializes a new instance of the InMemoryContext class.
Expand Down Expand Up @@ -94,3 +97,17 @@ private string GetDebuggerDisplay()
return this.ToString() ?? string.Empty;
}
}

/// <summary>
/// Specialization of VolatileContext<T> for CopilotChatMessage.
/// </summary>
public class VolatileCopilotChatMessageContext : VolatileContext<CopilotChatMessage>, ICopilotChatMessageStorageContext
{
/// <inheritdoc/>
public Task<IEnumerable<CopilotChatMessage>> QueryEntitiesAsync(Func<CopilotChatMessage, bool> predicate, int skip, int count)
{
return Task.Run<IEnumerable<CopilotChatMessage>>(
() => this._entities.Values
.Where(predicate).OrderByDescending(m => m.Timestamp).Skip(skip).Take(count));
}
}
4 changes: 2 additions & 2 deletions webapp/src/libs/services/ChatService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ export class ChatService extends BaseService {

public getChatMessagesAsync = async (
chatId: string,
startIdx: number,
skip: number,
count: number,
accessToken: string,
): Promise<IChatMessage[]> => {
const result = await this.getResponseAsync<IChatMessage[]>(
{
commandPath: `chats/${chatId}/messages?startIdx=${startIdx}&count=${count}`,
commandPath: `chats/${chatId}/messages?skip=${skip}&count=${count}`,
method: 'GET',
},
accessToken,
Expand Down

0 comments on commit 4d8baba

Please sign in to comment.