Skip to content

Commit

Permalink
Merge pull request #28 from askorama/feat/adds-regenerate
Browse files Browse the repository at this point in the history
feat: adds 'answerSession.regenerateLast' feature
  • Loading branch information
micheleriva authored Jul 23, 2024
2 parents eda873f + 13c11a1 commit c07e019
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 23 deletions.
52 changes: 37 additions & 15 deletions src/answerSession.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,20 @@ export type AnswerParams<UserContext = unknown> = {
onAnswerAborted?: (aborted: true) => void
onSourceChange?: <T = AnyDocument>(sources: Results<T>) => void
onQueryTranslated?: (query: SearchParams<AnyOrama>) => void
onRelatedQueries?: (relatedQueries: string[]) => void,
onNewInteractionStarted?: (interactionId: string) => void,
onRelatedQueries?: (relatedQueries: string[]) => void
onNewInteractionStarted?: (interactionId: string) => void
onStateChange?: (state: Interaction[]) => void
}
}

export type Interaction<T = AnyDocument> = {
interactionId: string,
query: string,
response: string,
relatedQueries: Nullable<string[]>,
sources: Nullable<Results<T>>,
translatedQuery: Nullable<SearchParams<AnyOrama>>,
aborted: boolean,
interactionId: string
query: string
response: string
relatedQueries: Nullable<string[]>
sources: Nullable<Results<T>>
translatedQuery: Nullable<SearchParams<AnyOrama>>
aborted: boolean
loading: boolean
}

Expand All @@ -60,6 +60,7 @@ export class AnswerSession {
private userContext?: AnswerParams['userContext']
private conversationID: string
private userID: string
private lastInteractionParams?: AskParams

public state: Interaction[] = []

Expand Down Expand Up @@ -105,21 +106,42 @@ export class AnswerSession {
this.messages = []
}

private addNewEmptyAssistantMessage(): void {
this.messages.push({ role: 'assistant', content: '' })
}

public abortAnswer() {
if (this.abortController) {
this.abortController.abort()
this.abortController = undefined
this.messages.pop()
this.state[this.state.length - 1].aborted = true
}
}

public async regenerateLast({ stream = true } = {}): Promise<string | AsyncGenerator<string>> {
if (this.state.length === 0 || this.messages.length === 0) {
throw new Error('No messages to regenerate')
}

const isLastMessageAssistant = this.messages.at(-1)?.role === 'assistant'

if (!isLastMessageAssistant) {
throw new Error('Last message is not an assistant message')
}

this.messages.pop()
this.state.pop()

if (stream) {
return this.askStream(this.lastInteractionParams as AskParams)
}

return this.ask(this.lastInteractionParams as AskParams)
}

private addNewEmptyAssistantMessage(): void {
this.messages.push({ role: 'assistant', content: '' })
}

private async *fetchAnswer(params: AskParams): AsyncGenerator<string> {
this.abortController = new AbortController()

this.lastInteractionParams = params
const interactionId = createId()

this.state.push({
Expand Down
4 changes: 2 additions & 2 deletions src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ export type AnswerSessionParams = {
onAnswerAborted?: (aborted: true) => void
onSourceChange?: <T = AnyDocument>(sources: Results<T>) => void
onQueryTranslated?: (query: SearchParams<AnyOrama>) => void
onRelatedQueries?: (relatedQueries: string[]) => void,
onNewInteractionStarted?: (interactionId: string) => void,
onRelatedQueries?: (relatedQueries: string[]) => void
onNewInteractionStarted?: (interactionId: string) => void
onStateChange?: (state: Interaction[]) => void
}
}
Expand Down
47 changes: 41 additions & 6 deletions tests/e2e.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ import { CloudManager } from '../src/manager/index.js'
import 'dotenv/config.js'
import { Interaction } from '../src/answerSession.js'

function createProxy() {
return new OramaProxy({
api_key: process.env.ORAMA_SECURE_PROXY_API_KEY_TEST || ''
})
}

await t.test('secure proxy', async t => {

await t.test('summaryStream should abort previous requests', async t => {
Expand Down Expand Up @@ -174,12 +180,6 @@ await t.test('secure proxy', async t => {
})
})

function createProxy() {
return new OramaProxy({
api_key: process.env.ORAMA_SECURE_PROXY_API_KEY_TEST || ''
})
}

await t.test('answer session', async t => {

if (!process.env.ORAMA_E2E_ENDPOINT || !process.env.ORAMA_E2E_API_KEY) {
Expand Down Expand Up @@ -339,4 +339,39 @@ await t.test('state management via answer session APIs', async t => {
assert.equal(state[1].query, 'labrador')
assert.equal(state[0].aborted, false)
})
})

await t.test('regenerate last answer', async t => {
if (!process.env.ORAMA_E2E_ENDPOINT || !process.env.ORAMA_E2E_API_KEY) {
t.skip('ORAMA_E2E_ENDPOINT and ORAMA_E2E_API_KEY are not set. E2e tests will be skipped.')
return
}

const client = new OramaClient({
endpoint: process.env.ORAMA_E2E_ENDPOINT!,
api_key: process.env.ORAMA_E2E_API_KEY!
})

let state: Interaction[] = []

const answerSession = client.createAnswerSession({
events: {
onStateChange: (newState) => {
state = newState
}
}
})

await answerSession.ask({
term: 'german'
})

await answerSession.ask({
term: 'labrador'
})

await answerSession.regenerateLast({ stream: false })

assert.equal(state.length, 2)
assert.equal(state[state.length - 1].query, 'labrador')
})

0 comments on commit c07e019

Please sign in to comment.