Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure that the entire chat request is cancellable #433

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
5 changes: 5 additions & 0 deletions chat-client/src/client/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import {
import {
CHAT_REQUEST_METHOD,
ChatParams,
END_CHAT_REQUEST_METHOD,
EndChatParams,
FEEDBACK_NOTIFICATION_METHOD,
FOLLOW_UP_CLICK_NOTIFICATION_METHOD,
FeedbackParams,
Expand Down Expand Up @@ -106,6 +108,9 @@ export const createChat = (
}

const chatApi: OutboundChatApi = {
endChat: (params: EndChatParams) => {
sendMessageToClient({ command: END_CHAT_REQUEST_METHOD, params })
},
sendChatPrompt: (params: ChatParams) => {
sendMessageToClient({ command: CHAT_REQUEST_METHOD, params })
},
Expand Down
6 changes: 6 additions & 0 deletions chat-client/src/client/messager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import {
} from '@aws/chat-client-ui-types'
import {
ChatParams,
EndChatParams,
FeedbackParams,
FollowUpClickParams,
InfoLinkClickParams,
Expand Down Expand Up @@ -40,6 +41,7 @@ import {
} from '../contracts/telemetry'

export interface OutboundChatApi {
endChat(params: EndChatParams): void
sendChatPrompt(params: ChatParams): void
sendQuickActionCommand(params: QuickActionParams): void
tabAdded(params: TabAddParams): void
Expand Down Expand Up @@ -81,6 +83,10 @@ export class Messager {
this.chatApi.telemetry({ ...params, tabId, name: SEND_TO_PROMPT_TELEMETRY_EVENT })
}

onStopChatResponse = (params: EndChatParams): void => {
this.chatApi.endChat(params)
}

onChatPrompt = (params: ChatParams, triggerType?: string): void => {
// Let the server know about the latest trigger interaction on the tabId
this.chatApi.telemetry({
Expand Down
1 change: 1 addition & 0 deletions chat-client/src/client/mynahUi.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ describe('MynahUI', () => {

beforeEach(() => {
outboundChatApi = {
endChat: sinon.stub(),
sendChatPrompt: sinon.stub(),
sendQuickActionCommand: sinon.stub(),
tabAdded: sinon.stub(),
Expand Down
7 changes: 7 additions & 0 deletions chat-client/src/client/mynahUi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,13 @@ export const createMynahUi = (messager: Messager, tabFactory: TabFactory): [Myna
}
messager.onSourceLinkClick(payload)
},
onStopChatResponse(tabId, _eventId) {
messager.onStopChatResponse({ tabId })
mynahUi.updateStore(tabId, {
loadingChat: false,
promptInputDisabledState: false,
})
},
onInfoLinkClick: (tabId, link, mouseEvent, eventId) => {
mouseEvent?.preventDefault()
mouseEvent?.stopPropagation()
Expand Down
2 changes: 2 additions & 0 deletions chat-client/src/contracts/serverContracts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@ import {
SOURCE_LINK_CLICK_NOTIFICATION_METHOD,
INFO_LINK_CLICK_NOTIFICATION_METHOD,
QUICK_ACTION_REQUEST_METHOD,
END_CHAT_REQUEST_METHOD,
} from '@aws/language-server-runtimes-types'

export const TELEMETRY = 'telemetry/event'

export type ServerMessageCommand =
| typeof END_CHAT_REQUEST_METHOD
| typeof CHAT_REQUEST_METHOD
| typeof TAB_ADD_NOTIFICATION_METHOD
| typeof TAB_REMOVE_NOTIFICATION_METHOD
Expand Down
58 changes: 49 additions & 9 deletions client/vscode/src/chatActivation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import {
quickActionRequestType,
QuickActionResult,
QuickActionParams,
END_CHAT_REQUEST_METHOD,
} from '@aws/language-server-runtimes/protocol'
import { v4 as uuidv4 } from 'uuid'
import { Uri, ViewColumn, Webview, WebviewPanel, commands, window } from 'vscode'
Expand Down Expand Up @@ -54,26 +55,64 @@ export function registerChat(languageClient: LanguageClient, extensionUri: Uri,
languageClient.info(`vscode client: Received ${JSON.stringify(message)} from chat`)

switch (message.command) {
case END_CHAT_REQUEST_METHOD:
Copy link
Contributor

@ege0zcan ege0zcan Aug 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i would expect this case to be already taken care of by the 'default' case, I think we can remove this specific case I see the default case doesn't send requests right now and only notifications, this comment can be ignored

languageClient.sendRequest(END_CHAT_REQUEST_METHOD, message.params)
break
case INSERT_TO_CURSOR_POSITION:
insertTextAtCursorPosition(message.params.code)
break
case AUTH_FOLLOW_UP_CLICKED:
languageClient.info('AuthFollowUp clicked')
break
case chatRequestType.method:
case chatRequestType.method: {
let chatResult: ChatResult | string = ''
const textDocument = window.visibleTextEditors.find(editor => editor.document.uri.scheme === 'file')
const documentUri = textDocument?.document.uri.toString()

const partialResultToken = uuidv4()
const chatDisposable = languageClient.onProgress(chatRequestType, partialResultToken, partialResult =>
const chatDisposable = languageClient.onProgress(chatRequestType, partialResultToken, partialResult => {
chatResult = partialResult
handlePartialResult<ChatResult>(partialResult, encryptionKey, panel, message.params.tabId)
})

const chatRequest = await encryptRequest<ChatParams>(
{
...message.params,
textDocument: { uri: documentUri },
},
encryptionKey
)

const chatRequest = await encryptRequest<ChatParams>(message.params, encryptionKey)
const chatResult = await languageClient.sendRequest(chatRequestType, {
...chatRequest,
partialResultToken,
})
handleCompleteResult<ChatResult>(chatResult, encryptionKey, panel, message.params.tabId, chatDisposable)
try {
chatResult = await languageClient.sendRequest(chatRequestType, {
...chatRequest,
partialResultToken,
})
} catch (e) {
if (e instanceof Error) {
languageClient.info(`Client caught error during chat request: ${e.message}`)

if (chatResult === '') {
if (e.message === 'Request cancelled') {
languageClient.info('Request cancelled before receiving any partial result')
chatResult = { body: 'Request Cancelled' }
} else {
chatResult = { body: `Error in chat: ${e.message}` }
}
}
}
} finally {
handleCompleteResult<ChatResult>(
chatResult,
encryptionKey,
panel,
message.params.tabId,
chatDisposable
)
}
break
case quickActionRequestType.method:
}
case quickActionRequestType.method: {
const quickActionPartialResultToken = uuidv4()
const quickActionDisposable = languageClient.onProgress(
quickActionRequestType,
Expand All @@ -100,6 +139,7 @@ export function registerChat(languageClient: LanguageClient, extensionUri: Uri,
quickActionDisposable
)
break
}
case followUpClickNotificationType.method:
if (!isValidAuthFollowUpType(message.params.followUp.type))
languageClient.sendNotification(followUpClickNotificationType, message.params)
Expand Down
39 changes: 21 additions & 18 deletions core/aws-lsp-fqn/src/browser/fqnWorkerPool.ts
Original file line number Diff line number Diff line change
@@ -1,25 +1,28 @@
import { ExtractorResult, FqnExtractorInput, IFqnWorkerPool } from '../common/types'
import { ExtractorResult, FqnExtractorInput, IFqnWorkerPool, Cancellable } from '../common/types'

// TODO: implement logic for browser/webworker environment
export class FqnWorkerPool implements IFqnWorkerPool {
public async exec(_input: FqnExtractorInput): Promise<ExtractorResult> {
return Promise.resolve({
success: true,
data: {
fullyQualified: {
declaredSymbols: [],
usedSymbols: [],
public exec(_input: FqnExtractorInput): Cancellable<Promise<ExtractorResult>> {
return [
Promise.resolve({
success: true,
data: {
fullyQualified: {
declaredSymbols: [],
usedSymbols: [],
},
simple: {
declaredSymbols: [],
usedSymbols: [],
},
externalSimple: {
declaredSymbols: [],
usedSymbols: [],
},
},
simple: {
declaredSymbols: [],
usedSymbols: [],
},
externalSimple: {
declaredSymbols: [],
usedSymbols: [],
},
},
})
}),
() => {},
]
}

public dispose() {}
Expand Down
43 changes: 25 additions & 18 deletions core/aws-lsp-fqn/src/common/commonFqnWorkerPool.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { pool, Pool } from 'workerpool'
import { DEFAULT_MAX_QUEUE_SIZE, DEFAULT_MAX_WORKERS, DEFAULT_TIMEOUT, FQN_WORKER_ID } from './defaults'
import { ExtractorResult, FqnExtractorInput, IFqnWorkerPool, Logger, WorkerPoolConfig } from './types'
import { ExtractorResult, FqnExtractorInput, IFqnWorkerPool, Logger, WorkerPoolConfig, Cancellable } from './types'

export class CommonFqnWorkerPool implements IFqnWorkerPool {
#workerPool: Pool
Expand All @@ -17,25 +17,32 @@ export class CommonFqnWorkerPool implements IFqnWorkerPool {
})
}

public async exec(input: FqnExtractorInput): Promise<ExtractorResult> {
public exec(input: FqnExtractorInput): Cancellable<Promise<ExtractorResult>> {
this.#logger?.log(`Extracting fully qualified names for ${input.languageId}`)

return this.#workerPool
.exec(FQN_WORKER_ID, [input])
.timeout(this.#timeout)
.then(data => data as ExtractorResult)
.catch(error => {
const errorMessage = `Encountered error while extracting fully qualified names: ${
error instanceof Error ? error.message : 'Unknown'
}`

this.#logger?.error(errorMessage)

return {
success: false,
error: errorMessage,
}
})
const execPromise = this.#workerPool.exec(FQN_WORKER_ID, [input]).timeout(this.#timeout)

return [
// have to wrap this in promise since exec promise is not a true promise
new Promise<ExtractorResult>(resolve => {
execPromise
.then(data => resolve(data as ExtractorResult))
.catch(error => {
const errorMessage = `Encountered error while extracting fully qualified names: ${
error instanceof Error ? error.message : 'Unknown'
}`

this.#logger?.error(errorMessage)

// using result pattern, so we will resolve with success: false
return resolve({
success: false,
error,
})
})
}),
() => execPromise.cancel(),
]
}

public dispose() {
Expand Down
7 changes: 5 additions & 2 deletions core/aws-lsp-fqn/src/common/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ export type Result<TData, TError> =
error: TError
}

export type ExtractorResult = Result<FqnExtractorOutput, string>
export type ExtractorResult = Result<FqnExtractorOutput, Error>

export interface FullyQualifiedName {
source: string[]
Expand Down Expand Up @@ -72,6 +72,9 @@ export interface FqnExtractorInput {
}

export interface IFqnWorkerPool {
exec(input: FqnExtractorInput): Promise<ExtractorResult>
exec(input: FqnExtractorInput): Cancellable<Promise<ExtractorResult>>
dispose(): void
}

export type CancelFn = () => void
export type Cancellable<T> = [T, CancelFn]
5 changes: 3 additions & 2 deletions core/aws-lsp-fqn/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import { ExtractorResult, FqnExtractorInput, IFqnWorkerPool, WorkerPoolConfig } from './common/types'
import { Cancellable, ExtractorResult, FqnExtractorInput, IFqnWorkerPool, WorkerPoolConfig } from './common/types'

export * from './common/types'

export declare class FqnWorkerPool implements IFqnWorkerPool {
constructor(workerPoolConfig?: WorkerPoolConfig)
exec(input: FqnExtractorInput): Promise<ExtractorResult>
exec(input: FqnExtractorInput): Cancellable<Promise<ExtractorResult>>
dispose(): void
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import { ChatController } from './chatController'
import { ChatSessionManagementService } from './chatSessionManagementService'
import { ChatSessionService } from './chatSessionService'
import { ChatTelemetryController } from './telemetry/chatTelemetryController'
import { DocumentContextExtractor } from './contexts/documentContext'
import { TriggerContextExtractor } from './contexts/triggerContextExtractor'
import * as utils from './utils'
import { DEFAULT_HELP_FOLLOW_UP_PROMPT, HELP_MESSAGE } from './constants'

Expand Down Expand Up @@ -61,6 +61,8 @@ describe('ChatController', () => {
}
let removeConversationSpy: sinon.SinonSpy
let emitConversationMetricStub: sinon.SinonStub
let abortRequestStub: sinon.SinonStub
let triggerContextCancelSpy: sinon.SinonSpy

let testFeatures: TestFeatures
let chatSessionManagementService: ChatSessionManagementService
Expand Down Expand Up @@ -88,9 +90,10 @@ describe('ChatController', () => {
activeTabSpy = sinon.spy(ChatTelemetryController.prototype, 'activeTabId', ['get', 'set'])
removeConversationSpy = sinon.spy(ChatTelemetryController.prototype, 'removeConversation')
emitConversationMetricStub = sinon.stub(ChatTelemetryController.prototype, 'emitConversationMetric')
triggerContextCancelSpy = sinon.spy(TriggerContextExtractor.prototype, 'cancel')

disposeStub = sinon.stub(ChatSessionService.prototype, 'dispose')

abortRequestStub = sinon.stub(ChatSessionService.prototype, 'abortRequest')
chatSessionManagementService = ChatSessionManagementService.getInstance().withCredentialsProvider(
testFeatures.credentialsProvider
)
Expand Down Expand Up @@ -132,11 +135,8 @@ describe('ChatController', () => {

chatController.onEndChat({ tabId: mockTabId }, mockCancellationToken)

sinon.assert.calledOnce(disposeStub)

const hasSession = chatSessionManagementService.hasSession(mockTabId)

assert.ok(!hasSession)
sinon.assert.calledOnce(abortRequestStub)
sinon.assert.calledOnce(triggerContextCancelSpy)
})

it('onTabAdd sets active tab id in telemetryController', () => {
Expand Down Expand Up @@ -340,7 +340,7 @@ describe('ChatController', () => {
}

beforeEach(() => {
extractDocumentContextStub = sinon.stub(DocumentContextExtractor.prototype, 'extractDocumentContext')
extractDocumentContextStub = sinon.stub(TriggerContextExtractor.prototype, 'extractDocumentContext')
testFeatures.openDocument(typescriptDocument)
})

Expand Down
Loading
Loading