From 13c11a11a89b81be9eefd2a5d4511d993fa8e3c5 Mon Sep 17 00:00:00 2001 From: Michele Riva Date: Tue, 23 Jul 2024 14:53:49 -0700 Subject: [PATCH] feat: adds 'regenerateLast' feature --- src/answerSession.ts | 52 +++++++++++++++++++++++++++++++------------- src/client.ts | 4 ++-- tests/e2e.test.ts | 47 ++++++++++++++++++++++++++++++++++----- 3 files changed, 80 insertions(+), 23 deletions(-) diff --git a/src/answerSession.ts b/src/answerSession.ts index ccade3a..a4c6895 100644 --- a/src/answerSession.ts +++ b/src/answerSession.ts @@ -25,20 +25,20 @@ export type AnswerParams = { onAnswerAborted?: (aborted: true) => void onSourceChange?: (sources: Results) => void onQueryTranslated?: (query: SearchParams) => void - onRelatedQueries?: (relatedQueries: string[]) => void, - onNewInteractionStarted?: (interactionId: string) => void, + onRelatedQueries?: (relatedQueries: string[]) => void + onNewInteractionStarted?: (interactionId: string) => void onStateChange?: (state: Interaction[]) => void } } export type Interaction = { - interactionId: string, - query: string, - response: string, - relatedQueries: Nullable, - sources: Nullable>, - translatedQuery: Nullable>, - aborted: boolean, + interactionId: string + query: string + response: string + relatedQueries: Nullable + sources: Nullable> + translatedQuery: Nullable> + aborted: boolean loading: boolean } @@ -60,6 +60,7 @@ export class AnswerSession { private userContext?: AnswerParams['userContext'] private conversationID: string private userID: string + private lastInteractionParams?: AskParams public state: Interaction[] = [] @@ -105,21 +106,42 @@ export class AnswerSession { this.messages = [] } - private addNewEmptyAssistantMessage(): void { - this.messages.push({ role: 'assistant', content: '' }) - } - public abortAnswer() { if (this.abortController) { this.abortController.abort() this.abortController = undefined - this.messages.pop() + this.state[this.state.length - 1].aborted = true + } + } + + public async regenerateLast({ stream = true } = {}): Promise> { + if (this.state.length === 0 || this.messages.length === 0) { + throw new Error('No messages to regenerate') + } + + const isLastMessageAssistant = this.messages.at(-1)?.role === 'assistant' + + if (!isLastMessageAssistant) { + throw new Error('Last message is not an assistant message') } + + this.messages.pop() + this.state.pop() + + if (stream) { + return this.askStream(this.lastInteractionParams as AskParams) + } + + return this.ask(this.lastInteractionParams as AskParams) + } + + private addNewEmptyAssistantMessage(): void { + this.messages.push({ role: 'assistant', content: '' }) } private async *fetchAnswer(params: AskParams): AsyncGenerator { this.abortController = new AbortController() - + this.lastInteractionParams = params const interactionId = createId() this.state.push({ diff --git a/src/client.ts b/src/client.ts index 18f71cc..3c42e50 100644 --- a/src/client.ts +++ b/src/client.ts @@ -52,8 +52,8 @@ export type AnswerSessionParams = { onAnswerAborted?: (aborted: true) => void onSourceChange?: (sources: Results) => void onQueryTranslated?: (query: SearchParams) => void - onRelatedQueries?: (relatedQueries: string[]) => void, - onNewInteractionStarted?: (interactionId: string) => void, + onRelatedQueries?: (relatedQueries: string[]) => void + onNewInteractionStarted?: (interactionId: string) => void onStateChange?: (state: Interaction[]) => void } } diff --git a/tests/e2e.test.ts b/tests/e2e.test.ts index f896e22..3f74d91 100644 --- a/tests/e2e.test.ts +++ b/tests/e2e.test.ts @@ -9,6 +9,12 @@ import { CloudManager } from '../src/manager/index.js' import 'dotenv/config.js' import { Interaction } from '../src/answerSession.js' +function createProxy() { + return new OramaProxy({ + api_key: process.env.ORAMA_SECURE_PROXY_API_KEY_TEST || '' + }) +} + await t.test('secure proxy', async t => { await t.test('summaryStream should abort previous requests', async t => { @@ -174,12 +180,6 @@ await t.test('secure proxy', async t => { }) }) -function createProxy() { - return new OramaProxy({ - api_key: process.env.ORAMA_SECURE_PROXY_API_KEY_TEST || '' - }) -} - await t.test('answer session', async t => { if (!process.env.ORAMA_E2E_ENDPOINT || !process.env.ORAMA_E2E_API_KEY) { @@ -339,4 +339,39 @@ await t.test('state management via answer session APIs', async t => { assert.equal(state[1].query, 'labrador') assert.equal(state[0].aborted, false) }) +}) + +await t.test('regenerate last answer', async t => { + if (!process.env.ORAMA_E2E_ENDPOINT || !process.env.ORAMA_E2E_API_KEY) { + t.skip('ORAMA_E2E_ENDPOINT and ORAMA_E2E_API_KEY are not set. E2e tests will be skipped.') + return + } + + const client = new OramaClient({ + endpoint: process.env.ORAMA_E2E_ENDPOINT!, + api_key: process.env.ORAMA_E2E_API_KEY! + }) + + let state: Interaction[] = [] + + const answerSession = client.createAnswerSession({ + events: { + onStateChange: (newState) => { + state = newState + } + } + }) + + await answerSession.ask({ + term: 'german' + }) + + await answerSession.ask({ + term: 'labrador' + }) + + await answerSession.regenerateLast({ stream: false }) + + assert.equal(state.length, 2) + assert.equal(state[state.length - 1].query, 'labrador') }) \ No newline at end of file