diff --git a/chat-client/src/client/chat.ts b/chat-client/src/client/chat.ts index 6e44a921..2e4b050d 100644 --- a/chat-client/src/client/chat.ts +++ b/chat-client/src/client/chat.ts @@ -17,6 +17,8 @@ import { import { CHAT_REQUEST_METHOD, ChatParams, + END_CHAT_REQUEST_METHOD, + EndChatParams, FEEDBACK_NOTIFICATION_METHOD, FOLLOW_UP_CLICK_NOTIFICATION_METHOD, FeedbackParams, @@ -106,6 +108,9 @@ export const createChat = ( } const chatApi: OutboundChatApi = { + endChat: (params: EndChatParams) => { + sendMessageToClient({ command: END_CHAT_REQUEST_METHOD, params }) + }, sendChatPrompt: (params: ChatParams) => { sendMessageToClient({ command: CHAT_REQUEST_METHOD, params }) }, diff --git a/chat-client/src/client/messager.ts b/chat-client/src/client/messager.ts index feb6f6f3..3caae818 100644 --- a/chat-client/src/client/messager.ts +++ b/chat-client/src/client/messager.ts @@ -12,6 +12,7 @@ import { } from '@aws/chat-client-ui-types' import { ChatParams, + EndChatParams, FeedbackParams, FollowUpClickParams, InfoLinkClickParams, @@ -40,6 +41,7 @@ import { } from '../contracts/telemetry' export interface OutboundChatApi { + endChat(params: EndChatParams): void sendChatPrompt(params: ChatParams): void sendQuickActionCommand(params: QuickActionParams): void tabAdded(params: TabAddParams): void @@ -81,6 +83,10 @@ export class Messager { this.chatApi.telemetry({ ...params, tabId, name: SEND_TO_PROMPT_TELEMETRY_EVENT }) } + onStopChatResponse = (params: EndChatParams): void => { + this.chatApi.endChat(params) + } + onChatPrompt = (params: ChatParams, triggerType?: string): void => { // Let the server know about the latest trigger interaction on the tabId this.chatApi.telemetry({ diff --git a/chat-client/src/client/mynahUi.test.ts b/chat-client/src/client/mynahUi.test.ts index 73f807cf..061b0513 100644 --- a/chat-client/src/client/mynahUi.test.ts +++ b/chat-client/src/client/mynahUi.test.ts @@ -22,6 +22,7 @@ describe('MynahUI', () => { beforeEach(() => { outboundChatApi = { + endChat: sinon.stub(), sendChatPrompt: sinon.stub(), sendQuickActionCommand: sinon.stub(), tabAdded: sinon.stub(), diff --git a/chat-client/src/client/mynahUi.ts b/chat-client/src/client/mynahUi.ts index 9fcd5f04..2ad90821 100644 --- a/chat-client/src/client/mynahUi.ts +++ b/chat-client/src/client/mynahUi.ts @@ -221,6 +221,13 @@ export const createMynahUi = (messager: Messager, tabFactory: TabFactory): [Myna } messager.onSourceLinkClick(payload) }, + onStopChatResponse(tabId, _eventId) { + messager.onStopChatResponse({ tabId }) + mynahUi.updateStore(tabId, { + loadingChat: false, + promptInputDisabledState: false, + }) + }, onInfoLinkClick: (tabId, link, mouseEvent, eventId) => { mouseEvent?.preventDefault() mouseEvent?.stopPropagation() diff --git a/chat-client/src/contracts/serverContracts.ts b/chat-client/src/contracts/serverContracts.ts index 4e064f63..0c1ca8c8 100644 --- a/chat-client/src/contracts/serverContracts.ts +++ b/chat-client/src/contracts/serverContracts.ts @@ -19,11 +19,13 @@ import { SOURCE_LINK_CLICK_NOTIFICATION_METHOD, INFO_LINK_CLICK_NOTIFICATION_METHOD, QUICK_ACTION_REQUEST_METHOD, + END_CHAT_REQUEST_METHOD, } from '@aws/language-server-runtimes-types' export const TELEMETRY = 'telemetry/event' export type ServerMessageCommand = + | typeof END_CHAT_REQUEST_METHOD | typeof CHAT_REQUEST_METHOD | typeof TAB_ADD_NOTIFICATION_METHOD | typeof TAB_REMOVE_NOTIFICATION_METHOD diff --git a/client/vscode/src/chatActivation.ts b/client/vscode/src/chatActivation.ts index 4afa0a5b..5aa2f24d 100644 --- a/client/vscode/src/chatActivation.ts +++ b/client/vscode/src/chatActivation.ts @@ -12,6 +12,7 @@ import { quickActionRequestType, QuickActionResult, QuickActionParams, + END_CHAT_REQUEST_METHOD, } from '@aws/language-server-runtimes/protocol' import { v4 as uuidv4 } from 'uuid' import { Uri, ViewColumn, Webview, WebviewPanel, commands, window } from 'vscode' @@ -54,26 +55,64 @@ export function registerChat(languageClient: LanguageClient, extensionUri: Uri, languageClient.info(`vscode client: Received ${JSON.stringify(message)} from chat`) switch (message.command) { + case END_CHAT_REQUEST_METHOD: + languageClient.sendRequest(END_CHAT_REQUEST_METHOD, message.params) + break case INSERT_TO_CURSOR_POSITION: insertTextAtCursorPosition(message.params.code) break case AUTH_FOLLOW_UP_CLICKED: languageClient.info('AuthFollowUp clicked') break - case chatRequestType.method: + case chatRequestType.method: { + let chatResult: ChatResult | string = '' + const textDocument = window.visibleTextEditors.find(editor => editor.document.uri.scheme === 'file') + const documentUri = textDocument?.document.uri.toString() + const partialResultToken = uuidv4() - const chatDisposable = languageClient.onProgress(chatRequestType, partialResultToken, partialResult => + const chatDisposable = languageClient.onProgress(chatRequestType, partialResultToken, partialResult => { + chatResult = partialResult handlePartialResult(partialResult, encryptionKey, panel, message.params.tabId) + }) + + const chatRequest = await encryptRequest( + { + ...message.params, + textDocument: { uri: documentUri }, + }, + encryptionKey ) - const chatRequest = await encryptRequest(message.params, encryptionKey) - const chatResult = await languageClient.sendRequest(chatRequestType, { - ...chatRequest, - partialResultToken, - }) - handleCompleteResult(chatResult, encryptionKey, panel, message.params.tabId, chatDisposable) + try { + chatResult = await languageClient.sendRequest(chatRequestType, { + ...chatRequest, + partialResultToken, + }) + } catch (e) { + if (e instanceof Error) { + languageClient.info(`Client caught error during chat request: ${e.message}`) + + if (chatResult === '') { + if (e.message === 'Request cancelled') { + languageClient.info('Request cancelled before receiving any partial result') + chatResult = { body: 'Request Cancelled' } + } else { + chatResult = { body: `Error in chat: ${e.message}` } + } + } + } + } finally { + handleCompleteResult( + chatResult, + encryptionKey, + panel, + message.params.tabId, + chatDisposable + ) + } break - case quickActionRequestType.method: + } + case quickActionRequestType.method: { const quickActionPartialResultToken = uuidv4() const quickActionDisposable = languageClient.onProgress( quickActionRequestType, @@ -100,6 +139,7 @@ export function registerChat(languageClient: LanguageClient, extensionUri: Uri, quickActionDisposable ) break + } case followUpClickNotificationType.method: if (!isValidAuthFollowUpType(message.params.followUp.type)) languageClient.sendNotification(followUpClickNotificationType, message.params) diff --git a/core/aws-lsp-fqn/src/browser/fqnWorkerPool.ts b/core/aws-lsp-fqn/src/browser/fqnWorkerPool.ts index 0e6a8ff1..18354fe8 100644 --- a/core/aws-lsp-fqn/src/browser/fqnWorkerPool.ts +++ b/core/aws-lsp-fqn/src/browser/fqnWorkerPool.ts @@ -1,25 +1,28 @@ -import { ExtractorResult, FqnExtractorInput, IFqnWorkerPool } from '../common/types' +import { ExtractorResult, FqnExtractorInput, IFqnWorkerPool, Cancellable } from '../common/types' // TODO: implement logic for browser/webworker environment export class FqnWorkerPool implements IFqnWorkerPool { - public async exec(_input: FqnExtractorInput): Promise { - return Promise.resolve({ - success: true, - data: { - fullyQualified: { - declaredSymbols: [], - usedSymbols: [], + public exec(_input: FqnExtractorInput): Cancellable> { + return [ + Promise.resolve({ + success: true, + data: { + fullyQualified: { + declaredSymbols: [], + usedSymbols: [], + }, + simple: { + declaredSymbols: [], + usedSymbols: [], + }, + externalSimple: { + declaredSymbols: [], + usedSymbols: [], + }, }, - simple: { - declaredSymbols: [], - usedSymbols: [], - }, - externalSimple: { - declaredSymbols: [], - usedSymbols: [], - }, - }, - }) + }), + () => {}, + ] } public dispose() {} diff --git a/core/aws-lsp-fqn/src/common/commonFqnWorkerPool.ts b/core/aws-lsp-fqn/src/common/commonFqnWorkerPool.ts index 38ebfcd9..e04e1b27 100644 --- a/core/aws-lsp-fqn/src/common/commonFqnWorkerPool.ts +++ b/core/aws-lsp-fqn/src/common/commonFqnWorkerPool.ts @@ -1,6 +1,6 @@ import { pool, Pool } from 'workerpool' import { DEFAULT_MAX_QUEUE_SIZE, DEFAULT_MAX_WORKERS, DEFAULT_TIMEOUT, FQN_WORKER_ID } from './defaults' -import { ExtractorResult, FqnExtractorInput, IFqnWorkerPool, Logger, WorkerPoolConfig } from './types' +import { ExtractorResult, FqnExtractorInput, IFqnWorkerPool, Logger, WorkerPoolConfig, Cancellable } from './types' export class CommonFqnWorkerPool implements IFqnWorkerPool { #workerPool: Pool @@ -17,25 +17,32 @@ export class CommonFqnWorkerPool implements IFqnWorkerPool { }) } - public async exec(input: FqnExtractorInput): Promise { + public exec(input: FqnExtractorInput): Cancellable> { this.#logger?.log(`Extracting fully qualified names for ${input.languageId}`) - return this.#workerPool - .exec(FQN_WORKER_ID, [input]) - .timeout(this.#timeout) - .then(data => data as ExtractorResult) - .catch(error => { - const errorMessage = `Encountered error while extracting fully qualified names: ${ - error instanceof Error ? error.message : 'Unknown' - }` - - this.#logger?.error(errorMessage) - - return { - success: false, - error: errorMessage, - } - }) + const execPromise = this.#workerPool.exec(FQN_WORKER_ID, [input]).timeout(this.#timeout) + + return [ + // have to wrap this in promise since exec promise is not a true promise + new Promise(resolve => { + execPromise + .then(data => resolve(data as ExtractorResult)) + .catch(error => { + const errorMessage = `Encountered error while extracting fully qualified names: ${ + error instanceof Error ? error.message : 'Unknown' + }` + + this.#logger?.error(errorMessage) + + // using result pattern, so we will resolve with success: false + return resolve({ + success: false, + error, + }) + }) + }), + () => execPromise.cancel(), + ] } public dispose() { diff --git a/core/aws-lsp-fqn/src/common/types.ts b/core/aws-lsp-fqn/src/common/types.ts index 18993b2d..7567fbbc 100644 --- a/core/aws-lsp-fqn/src/common/types.ts +++ b/core/aws-lsp-fqn/src/common/types.ts @@ -20,7 +20,7 @@ export type Result = error: TError } -export type ExtractorResult = Result +export type ExtractorResult = Result export interface FullyQualifiedName { source: string[] @@ -72,6 +72,9 @@ export interface FqnExtractorInput { } export interface IFqnWorkerPool { - exec(input: FqnExtractorInput): Promise + exec(input: FqnExtractorInput): Cancellable> dispose(): void } + +export type CancelFn = () => void +export type Cancellable = [T, CancelFn] diff --git a/core/aws-lsp-fqn/src/index.ts b/core/aws-lsp-fqn/src/index.ts index f05264cc..7c7396b5 100644 --- a/core/aws-lsp-fqn/src/index.ts +++ b/core/aws-lsp-fqn/src/index.ts @@ -1,8 +1,9 @@ -import { ExtractorResult, FqnExtractorInput, IFqnWorkerPool, WorkerPoolConfig } from './common/types' +import { Cancellable, ExtractorResult, FqnExtractorInput, IFqnWorkerPool, WorkerPoolConfig } from './common/types' export * from './common/types' + export declare class FqnWorkerPool implements IFqnWorkerPool { constructor(workerPoolConfig?: WorkerPoolConfig) - exec(input: FqnExtractorInput): Promise + exec(input: FqnExtractorInput): Cancellable> dispose(): void } diff --git a/server/aws-lsp-codewhisperer/src/language-server/chat/chatController.test.ts b/server/aws-lsp-codewhisperer/src/language-server/chat/chatController.test.ts index 8b878efd..0502f328 100644 --- a/server/aws-lsp-codewhisperer/src/language-server/chat/chatController.test.ts +++ b/server/aws-lsp-codewhisperer/src/language-server/chat/chatController.test.ts @@ -12,7 +12,7 @@ import { ChatController } from './chatController' import { ChatSessionManagementService } from './chatSessionManagementService' import { ChatSessionService } from './chatSessionService' import { ChatTelemetryController } from './telemetry/chatTelemetryController' -import { DocumentContextExtractor } from './contexts/documentContext' +import { TriggerContextExtractor } from './contexts/triggerContextExtractor' import * as utils from './utils' import { DEFAULT_HELP_FOLLOW_UP_PROMPT, HELP_MESSAGE } from './constants' @@ -61,6 +61,8 @@ describe('ChatController', () => { } let removeConversationSpy: sinon.SinonSpy let emitConversationMetricStub: sinon.SinonStub + let abortRequestStub: sinon.SinonStub + let triggerContextCancelSpy: sinon.SinonSpy let testFeatures: TestFeatures let chatSessionManagementService: ChatSessionManagementService @@ -88,9 +90,10 @@ describe('ChatController', () => { activeTabSpy = sinon.spy(ChatTelemetryController.prototype, 'activeTabId', ['get', 'set']) removeConversationSpy = sinon.spy(ChatTelemetryController.prototype, 'removeConversation') emitConversationMetricStub = sinon.stub(ChatTelemetryController.prototype, 'emitConversationMetric') + triggerContextCancelSpy = sinon.spy(TriggerContextExtractor.prototype, 'cancel') disposeStub = sinon.stub(ChatSessionService.prototype, 'dispose') - + abortRequestStub = sinon.stub(ChatSessionService.prototype, 'abortRequest') chatSessionManagementService = ChatSessionManagementService.getInstance().withCredentialsProvider( testFeatures.credentialsProvider ) @@ -132,11 +135,8 @@ describe('ChatController', () => { chatController.onEndChat({ tabId: mockTabId }, mockCancellationToken) - sinon.assert.calledOnce(disposeStub) - - const hasSession = chatSessionManagementService.hasSession(mockTabId) - - assert.ok(!hasSession) + sinon.assert.calledOnce(abortRequestStub) + sinon.assert.calledOnce(triggerContextCancelSpy) }) it('onTabAdd sets active tab id in telemetryController', () => { @@ -340,7 +340,7 @@ describe('ChatController', () => { } beforeEach(() => { - extractDocumentContextStub = sinon.stub(DocumentContextExtractor.prototype, 'extractDocumentContext') + extractDocumentContextStub = sinon.stub(TriggerContextExtractor.prototype, 'extractDocumentContext') testFeatures.openDocument(typescriptDocument) }) diff --git a/server/aws-lsp-codewhisperer/src/language-server/chat/chatController.ts b/server/aws-lsp-codewhisperer/src/language-server/chat/chatController.ts index 30b9df4b..49f6882f 100644 --- a/server/aws-lsp-codewhisperer/src/language-server/chat/chatController.ts +++ b/server/aws-lsp-codewhisperer/src/language-server/chat/chatController.ts @@ -22,13 +22,13 @@ import { } from '../telemetry/types' import { Features, LspHandlers, Result } from '../types' import { ChatEventParser } from './chatEventParser' -import { createAuthFollowUpResult, getAuthFollowUpType, getDefaultChatResponse } from './utils' +import { CancellationError, createAuthFollowUpResult, getAuthFollowUpType, getDefaultChatResponse } from './utils' import { ChatSessionManagementService } from './chatSessionManagementService' import { ChatTelemetryController } from './telemetry/chatTelemetryController' import { QuickAction } from './quickActions' -import { getErrorMessage, isAwsError, isNullish, isObject } from '../utils' +import { getErrorMessage, hasCode, isAwsError, isNullish, isObject } from '../utils' import { Metric } from '../telemetry/metric' -import { QChatTriggerContext, TriggerContext } from './contexts/triggerContext' +import { TriggerContext, TriggerContextExtractor } from './contexts/triggerContextExtractor' import { HELP_MESSAGE } from './constants' type ChatHandlers = LspHandlers @@ -37,12 +37,12 @@ export class ChatController implements ChatHandlers { #features: Features #chatSessionManagementService: ChatSessionManagementService #telemetryController: ChatTelemetryController - #triggerContext: QChatTriggerContext + #triggerContext: TriggerContextExtractor constructor(chatSessionManagementService: ChatSessionManagementService, features: Features) { this.#features = features this.#chatSessionManagementService = chatSessionManagementService - this.#triggerContext = new QChatTriggerContext(features.workspace, features.logging) + this.#triggerContext = new TriggerContextExtractor(features.workspace, { logger: features.logging }) this.#telemetryController = new ChatTelemetryController(features) } @@ -67,108 +67,123 @@ export class ChatController implements ChatHandlers { return new ResponseError(ErrorCodes.InternalError, sessionResult.error) } - const metric = new Metric({ - cwsprChatConversationType: 'Chat', - }) + return this.#withLspCancellation(params.tabId, token, async checkIsCancelled => { + const metric = new Metric({ + cwsprChatConversationType: 'Chat', + }) - const triggerContext = await this.#getTriggerContext(params, metric) - const isNewConversation = !session.sessionId + const triggerContext = await this.#getTriggerContext(params, metric) - token.onCancellationRequested(() => { - this.#log('cancellation requested') - session.abortRequest() - }) + // return empty result since the other promise will have been resolved + if (checkIsCancelled()) { + return {} + } + + const isNewConversation = !session.sessionId + + let response: GenerateAssistantResponseCommandOutput + + const conversationIdentifier = session?.sessionId ?? 'New session' + try { + this.#log('Request for conversation id:', conversationIdentifier) + const requestInput = TriggerContextExtractor.getChatParamsFromTrigger(params, triggerContext) + + metric.recordStart() + response = await session.generateAssistantResponse(requestInput) + this.#log('Response for conversationId:', conversationIdentifier, JSON.stringify(response.$metadata)) + } catch (err) { + if (isObject(err) && 'name' in err && err.name === 'AbortError') { + throw new CancellationError('Q api request aborted') + } else if ( + isAwsError(err) || + (isObject(err) && 'statusCode' in err && typeof err.statusCode === 'number') + ) { + metric.setDimension('cwsprChatRepsonseCode', err.statusCode ?? 400) + this.#telemetryController.emitMessageResponseError(params.tabId, metric.metric) + } + + const authFollowType = getAuthFollowUpType(err) - let response: GenerateAssistantResponseCommandOutput + if (authFollowType) { + this.#log(`Q auth error: ${getErrorMessage(err)}`) - const conversationIdentifier = session?.sessionId ?? 'New session' - try { - this.#log('Request for conversation id:', conversationIdentifier) - const requestInput = this.#triggerContext.getChatParamsFromTrigger(params, triggerContext) + return createAuthFollowUpResult(authFollowType) + } - metric.recordStart() - response = await session.generateAssistantResponse(requestInput) - this.#log('Response for conversationId:', conversationIdentifier, JSON.stringify(response.$metadata)) - } catch (err) { - if (isAwsError(err) || (isObject(err) && 'statusCode' in err && typeof err.statusCode === 'number')) { - metric.setDimension('cwsprChatRepsonseCode', err.statusCode ?? 400) - this.#telemetryController.emitMessageResponseError(params.tabId, metric.metric) + this.#log(`Q api request error ${err instanceof Error ? err.message : 'unknown'}`) + return new ResponseError( + LSPErrorCodes.RequestFailed, + err instanceof Error ? err.message : 'Unknown request error' + ) } - const authFollowType = getAuthFollowUpType(err) + if (response.conversationId) { + this.#telemetryController.setConversationId(params.tabId, response.conversationId) - if (authFollowType) { - this.#log(`Q auth error: ${getErrorMessage(err)}`) + if (isNewConversation) { + this.#telemetryController.emitStartConversationMetric(params.tabId, metric.metric) + } + } - return createAuthFollowUpResult(authFollowType) + // return empty result since the other promise will have been resolved + if (checkIsCancelled()) { + return {} } - this.#log(`Q api request error ${err instanceof Error ? err.message : 'unknown'}`) - return new ResponseError( - LSPErrorCodes.RequestFailed, - err instanceof Error ? err.message : 'Unknown request error' - ) - } + try { + const result = await this.#processAssistantResponse( + response, + metric.mergeWith({ + cwsprChatResponseCode: response.$metadata.httpStatusCode, + cwsprChatMessageId: response.$metadata.requestId, + }), + params.partialResultToken + ) - if (response.conversationId) { - this.#telemetryController.setConversationId(params.tabId, response.conversationId) + this.#telemetryController.emitAddMessageMetric(params.tabId, metric.metric) - if (isNewConversation) { this.#telemetryController.updateTriggerInfo(params.tabId, { - startTrigger: { - hasUserSnippet: metric.metric.cwsprChatHasCodeSnippet ?? false, - triggerType: triggerContext.triggerType, + lastMessageTrigger: { + ...triggerContext, + messageId: response.$metadata.requestId, + followUpActions: new Set( + result.data?.followUp?.options + ?.map(option => option.prompt ?? '') + .filter(prompt => prompt.length > 0) + ), }, }) - this.#telemetryController.emitStartConversationMetric(params.tabId, metric.metric) - } - } - - try { - const result = await this.#processAssistantResponse( - response, - metric.mergeWith({ - cwsprChatResponseCode: response.$metadata.httpStatusCode, - cwsprChatMessageId: response.$metadata.requestId, - }), - params.partialResultToken - ) - - this.#telemetryController.emitAddMessageMetric(params.tabId, metric.metric) - - this.#telemetryController.updateTriggerInfo(params.tabId, { - lastMessageTrigger: { - ...triggerContext, - messageId: response.$metadata.requestId, - followUpActions: new Set( - result.data?.followUp?.options - ?.map(option => option.prompt ?? '') - .filter(prompt => prompt.length > 0) - ), - }, - }) + return result.success + ? result.data + : new ResponseError(LSPErrorCodes.RequestFailed, result.error) + } catch (err) { + if (hasCode(err) && err.code === 'ECONNRESET') { + this.#log('Response streaming aborted') + throw new CancellationError('Response streaming aborted') + } - return result.success - ? result.data - : new ResponseError(LSPErrorCodes.RequestFailed, result.error) - } catch (err) { - this.#log('Error encountered during response streaming:', err instanceof Error ? err.message : 'unknown') + this.#log( + 'Error encountered during response streaming:', + err instanceof Error ? err.message : 'unknown' + ) - return new ResponseError( - LSPErrorCodes.RequestFailed, - err instanceof Error ? err.message : 'Unknown error occured during response stream' - ) - } + return new ResponseError( + LSPErrorCodes.RequestFailed, + err instanceof Error ? err.message : 'Unknown error occured during response stream' + ) + } + }) } onCodeInsertToCursorPosition() {} onCopyCodeToClipboard() {} onEndChat(params: EndChatParams, _token: CancellationToken): boolean { - const { success } = this.#chatSessionManagementService.deleteSession(params.tabId) + this.#log('end chat') + this.#cancelRequest(params.tabId) - return success + return true } onFollowUpClicked() {} @@ -281,7 +296,7 @@ export class ChatController implements ChatHandlers { triggerContext = lastMessageTrigger } else { - triggerContext = await this.#triggerContext.getNewTriggerContext(params) + triggerContext = await this.#triggerContext.getTriggerContext(params) triggerContext.triggerType = this.#telemetryController.getCurrentTrigger(params.tabId) ?? 'click' } @@ -333,4 +348,42 @@ export class ChatController implements ChatHandlers { #log(...messages: string[]) { this.#features.logging.log(messages.join(' ')) } + + #cancelRequest(tabId: string) { + this.#triggerContext.cancel(tabId) + this.#chatSessionManagementService.getSession(tabId).data?.abortRequest() + } + + #withLspCancellation( + tabId: string, + token: CancellationToken, + action: (checkIsCancelled: () => boolean) => Promise> + ): Promise> { + let isCancelled = false + + return Promise.race([ + new Promise>(resolve => { + token.onCancellationRequested(() => { + this.#log('cancellation requested') + + this.#cancelRequest(tabId) + isCancelled = true + + resolve(new ResponseError(LSPErrorCodes.RequestCancelled, 'Request cancelled')) + }) + }), + // .race doesn't stop the "losing" promise from executing so we need to provide this boolean for early termination + action(() => isCancelled).catch(error => { + if (error instanceof CancellationError) { + this.#log('Request cancelled: ', error.message) + return new ResponseError(LSPErrorCodes.RequestCancelled, 'Request cancelled') + } + + return new ResponseError( + ErrorCodes.InternalError, + error instanceof Error ? error.message : 'Unknown error' + ) + }), + ]) + } } diff --git a/server/aws-lsp-codewhisperer/src/language-server/chat/chatSessionService.ts b/server/aws-lsp-codewhisperer/src/language-server/chat/chatSessionService.ts index 7390f24f..7822ee23 100644 --- a/server/aws-lsp-codewhisperer/src/language-server/chat/chatSessionService.ts +++ b/server/aws-lsp-codewhisperer/src/language-server/chat/chatSessionService.ts @@ -9,6 +9,7 @@ import { CredentialsProvider } from '@aws/language-server-runtimes/server-interf import { getBearerTokenFromProvider } from '../utils' export type ChatSessionServiceConfig = CodeWhispererStreamingClientConfig +export type Dispose = () => void export class ChatSessionService { public shareCodeWhispererContentWithAWS = false readonly #codeWhispererRegion = 'us-east-1' diff --git a/server/aws-lsp-codewhisperer/src/language-server/chat/contexts/documentContext.test.ts b/server/aws-lsp-codewhisperer/src/language-server/chat/contexts/documentContext.test.ts deleted file mode 100644 index 6618a51c..00000000 --- a/server/aws-lsp-codewhisperer/src/language-server/chat/contexts/documentContext.test.ts +++ /dev/null @@ -1,135 +0,0 @@ -import { EditorState } from '@amzn/codewhisperer-streaming' -import * as assert from 'assert' -import sinon from 'ts-sinon' -import { TextDocument } from 'vscode-languageserver-textdocument' -import { DocumentContext, DocumentContextExtractor } from './documentContext' -import { DocumentFqnExtractor } from './documentFqnExtractor' - -describe('DocumentContext', () => { - const mockTypescriptCodeBlock = `function test() { - console.log('test') -}` - const mockTSDocument = TextDocument.create('file://test.ts', 'typescript', 1, mockTypescriptCodeBlock) - - beforeEach(() => { - sinon.stub(DocumentFqnExtractor.prototype, 'extractDocumentSymbols').resolves([]) - }) - - afterEach(() => { - sinon.restore() - }) - - describe('documentContextExtractor.extractEditorState', () => { - it('extracts editor state for range selection', async () => { - const documentContextExtractor = new DocumentContextExtractor({ characterLimits: 19 }) - const expected: DocumentContext = { - programmingLanguage: { languageName: 'typescript' }, - relativeFilePath: 'file://test.ts', - documentSymbols: [], - text: "console.log('test')", - hasCodeSnippet: true, - totalEditorCharacters: mockTypescriptCodeBlock.length, - cursorState: { - range: { - start: { - line: 0, - character: 8, - }, - end: { - line: 0, - character: 11, - }, - }, - }, - } - - const result = await documentContextExtractor.extractDocumentContext(mockTSDocument, { - // highlighting "log" - range: { - start: { - line: 1, - character: 12, - }, - end: { - line: 1, - character: 15, - }, - }, - }) - - assert.deepStrictEqual(result, expected) - }) - - it('extracts editor state for collapsed position', async () => { - const documentContextExtractor = new DocumentContextExtractor({ characterLimits: 19 }) - const expected: DocumentContext = { - programmingLanguage: { languageName: 'typescript' }, - relativeFilePath: 'file://test.ts', - documentSymbols: [], - text: "console.log('test')", - hasCodeSnippet: true, - totalEditorCharacters: mockTypescriptCodeBlock.length, - cursorState: { - range: { - start: { - line: 0, - character: 9, - }, - end: { - line: 0, - character: 10, - }, - }, - }, - } - - const result = await documentContextExtractor.extractDocumentContext(mockTSDocument, { - // highlighting "o" in "log" - range: { - start: { - line: 1, - character: 13, - }, - end: { - line: 1, - character: 14, - }, - }, - }) - - assert.deepStrictEqual(result, expected) - }) - }) - - it('handles other languages correctly', async () => { - const documentContextExtractor = new DocumentContextExtractor({ characterLimits: 19 }) - - const mockGoCodeBLock = `func main() { - fmt.Println("test") -}` - const mockDocument = TextDocument.create('file://test.go', 'go', 1, mockGoCodeBLock) - - const expectedResult: DocumentContext = { - programmingLanguage: { languageName: 'go' }, - relativeFilePath: 'file://test.go', - documentSymbols: [], - text: 'fmt.Println("test")', - totalEditorCharacters: mockGoCodeBLock.length, - hasCodeSnippet: true, - cursorState: { - range: { - start: { line: 0, character: 0 }, - end: { line: 0, character: 19 }, - }, - }, - } - const result = await documentContextExtractor.extractDocumentContext(mockDocument, { - range: { - start: { line: 1, character: 4 }, - end: { line: 1, character: 23 }, - }, - }) - - assert.deepStrictEqual(result, expectedResult) - }) -}) diff --git a/server/aws-lsp-codewhisperer/src/language-server/chat/contexts/documentContext.ts b/server/aws-lsp-codewhisperer/src/language-server/chat/contexts/documentContext.ts deleted file mode 100644 index 14d8ecb9..00000000 --- a/server/aws-lsp-codewhisperer/src/language-server/chat/contexts/documentContext.ts +++ /dev/null @@ -1,84 +0,0 @@ -import { EditorState, TextDocument as CwsprTextDocument, DocumentSymbol } from '@amzn/codewhisperer-streaming' -import { CursorState } from '@aws/language-server-runtimes/server-interface' -import { Range, TextDocument } from 'vscode-languageserver-textdocument' -import { getLanguageId } from '../../languageDetection' -import { Features } from '../../types' -import { DocumentFqnExtractor, DocumentFqnExtractorConfig } from './documentFqnExtractor' -import { getExtendedCodeBlockRange, getSelectionWithinExtendedRange } from './utils' - -export type DocumentContext = CwsprTextDocument & { - cursorState?: EditorState['cursorState'] - hasCodeSnippet: boolean - totalEditorCharacters: number -} - -export interface DocumentContextExtractorConfig extends DocumentFqnExtractorConfig { - config?: DocumentFqnExtractorConfig - logger?: Features['logging'] - characterLimits?: number -} - -export class DocumentContextExtractor { - private static readonly DEFAULT_CHARACTER_LIMIT = 9000 - - #characterLimits: number - #logger?: Features['logging'] - #documentSymbolExtractor: DocumentFqnExtractor - - constructor(config?: DocumentContextExtractorConfig) { - const { characterLimits, ...fqnConfig } = config ?? {} - - this.#logger = config?.logger - this.#characterLimits = characterLimits ?? DocumentContextExtractor.DEFAULT_CHARACTER_LIMIT - this.#documentSymbolExtractor = new DocumentFqnExtractor(fqnConfig) - } - - public dispose() { - this.#documentSymbolExtractor.dispose() - } - - /** - * From the given the cursor state, we want to give Q context up to the characters limit - * on both sides of the cursor. - */ - public async extractDocumentContext(document: TextDocument, cursorState: CursorState): Promise { - const targetRange: Range = - 'position' in cursorState - ? { - start: cursorState.position, - end: cursorState.position, - } - : cursorState.range - - const codeBlockRange = getExtendedCodeBlockRange(document, targetRange, this.#characterLimits) - - const rangeWithinCodeBlock = getSelectionWithinExtendedRange(targetRange, codeBlockRange) - - const languageId = getLanguageId(document) - - let documentSymbols: DocumentSymbol[] = [] - - try { - // best effort to extract symbols - documentSymbols = await this.#documentSymbolExtractor.extractDocumentSymbols( - document, - codeBlockRange, - languageId - ) - } catch (e) { - this.#logger?.log( - `Error extracting document symbols but continuing on. ${e instanceof Error ? e.message : 'Unknown error'}` - ) - } - - return { - cursorState: rangeWithinCodeBlock ? { range: rangeWithinCodeBlock } : undefined, - text: document.getText(codeBlockRange), - programmingLanguage: languageId ? { languageName: languageId } : undefined, - relativeFilePath: document.uri, - documentSymbols, - hasCodeSnippet: Boolean(rangeWithinCodeBlock), - totalEditorCharacters: document.getText().length, - } - } -} diff --git a/server/aws-lsp-codewhisperer/src/language-server/chat/contexts/documentFqnExtractor.test.ts b/server/aws-lsp-codewhisperer/src/language-server/chat/contexts/documentFqnExtractor.test.ts index bb00802c..8f655ea0 100644 --- a/server/aws-lsp-codewhisperer/src/language-server/chat/contexts/documentFqnExtractor.test.ts +++ b/server/aws-lsp-codewhisperer/src/language-server/chat/contexts/documentFqnExtractor.test.ts @@ -25,10 +25,10 @@ describe('DocumentFQNExtractor', () => { it('returns symbols in the right shape', async () => { const documentFqnExtractor = new DocumentFqnExtractor() - extractorStub.returns(Promise.resolve({ success: true, data: mockExtractedSymbols })) - const documentSymbols = await documentFqnExtractor.extractDocumentSymbols(typescriptDocument, mockRange) + extractorStub.returns([Promise.resolve({ success: true, data: mockExtractedSymbols }), () => {}]) + const [documentSymbolsPromise] = documentFqnExtractor.extractDocumentSymbols(typescriptDocument, mockRange) - assert.deepStrictEqual(documentSymbols, expectedExtractedNames) + assert.deepStrictEqual(await documentSymbolsPromise, expectedExtractedNames) sinon.assert.calledOnceWithExactly(extractorStub, { fileText: typescriptDocument.getText(), selection: mockRange, @@ -39,41 +39,39 @@ describe('DocumentFQNExtractor', () => { it('returns empty array if language id is not supported', async () => { const documentFqnExtractor = new DocumentFqnExtractor() - extractorStub.returns(Promise.resolve({ success: true, data: mockExtractedSymbols })) - const documentSymbols = await documentFqnExtractor.extractDocumentSymbols( + extractorStub.returns([Promise.resolve({ success: true, data: mockExtractedSymbols }), () => {}]) + const [documentSymbolsPromise] = documentFqnExtractor.extractDocumentSymbols( typescriptDocument, mockRange, 'lolcode' ) - assert.deepStrictEqual(documentSymbols, []) + assert.deepStrictEqual(await documentSymbolsPromise, []) }) - it('resolves to empty array if not successful', async () => { - extractorStub.resolves({ success: false, data: mockExtractedSymbols }) + it('throws error if error is present', () => { + const mockError = new Error('mock error') - let documentSymbols = await documentFqnExtractor.extractDocumentSymbols(typescriptDocument, mockRange) + extractorStub.returns([ + Promise.resolve({ success: false, data: mockExtractedSymbols, error: mockError }), + () => {}, + ]) - assert.deepStrictEqual(documentSymbols, []) - - extractorStub.resolves({ success: false, data: undefined }) - - documentSymbols = await documentFqnExtractor.extractDocumentSymbols(typescriptDocument, mockRange) - - assert.deepStrictEqual(documentSymbols, []) + const [documentSymbolsPromise] = documentFqnExtractor.extractDocumentSymbols(typescriptDocument, mockRange) + assert.rejects(documentSymbolsPromise) }) it('uses language id if passed', async () => { const documentFqnExtractor = new DocumentFqnExtractor() - extractorStub.resolves({ success: true, data: mockExtractedSymbols }) - const documentSymbols = await documentFqnExtractor.extractDocumentSymbols( + extractorStub.returns([Promise.resolve({ success: true, data: mockExtractedSymbols }), () => {}]) + const [documentSymbolsPromise] = documentFqnExtractor.extractDocumentSymbols( typescriptDocument, mockRange, 'python' ) - assert.deepStrictEqual(documentSymbols, expectedExtractedNames) + assert.deepStrictEqual(await documentSymbolsPromise, expectedExtractedNames) sinon.assert.calledOnceWithExactly(extractorStub, { fileText: typescriptDocument.getText(), selection: mockRange, @@ -84,46 +82,55 @@ describe('DocumentFQNExtractor', () => { it('dedups symbols', async () => { const documentFqnExtractor = new DocumentFqnExtractor() - extractorStub.resolves({ - success: true, - data: { - fullyQualified: { - ...mockExtractedSymbols.fullyQualified, - usedSymbols: [ - ...mockExtractedSymbols.fullyQualified.usedSymbols.slice(0, 4), - ...mockExtractedSymbols.fullyQualified.usedSymbols.slice(3, 6), - ...mockExtractedSymbols.fullyQualified.usedSymbols.slice(5), - ], + extractorStub.returns([ + Promise.resolve({ + success: true, + data: { + fullyQualified: { + ...mockExtractedSymbols.fullyQualified, + usedSymbols: [ + ...mockExtractedSymbols.fullyQualified.usedSymbols.slice(0, 4), + ...mockExtractedSymbols.fullyQualified.usedSymbols.slice(3, 6), + ...mockExtractedSymbols.fullyQualified.usedSymbols.slice(5), + ], + }, }, - }, - }) + }), + () => {}, + ]) - const documentSymbols = await documentFqnExtractor.extractDocumentSymbols(typescriptDocument, mockRange) + const [documentSymbolsPromise] = documentFqnExtractor.extractDocumentSymbols(typescriptDocument, mockRange) - assert.deepStrictEqual(documentSymbols, expectedExtractedNames) + assert.deepStrictEqual(await documentSymbolsPromise, expectedExtractedNames) }) it('returns no more than the limit of symbols specify', async () => { const documentFqnExtractor = new DocumentFqnExtractor({ maxSymbols: 3 }) - extractorStub.resolves({ - success: true, - data: mockExtractedSymbols, - }) - const documentSymbols = await documentFqnExtractor.extractDocumentSymbols(typescriptDocument, mockRange) + extractorStub.returns([ + Promise.resolve({ + success: true, + data: mockExtractedSymbols, + }), + () => {}, + ]) + const [documentSymbolsPromise] = documentFqnExtractor.extractDocumentSymbols(typescriptDocument, mockRange) - assert.deepStrictEqual(documentSymbols, expectedExtractedNames.slice(0, 3)) + assert.deepStrictEqual(await documentSymbolsPromise, expectedExtractedNames.slice(0, 3)) }) it('filters out symbols if either name or source does not conform to the length limit', async () => { const documentFqnExtractor = new DocumentFqnExtractor({ nameMinLength: 5, nameMaxLength: 8 }) - extractorStub.resolves({ - success: true, - data: mockExtractedSymbols, - }) - const documentSymbols = await documentFqnExtractor.extractDocumentSymbols(typescriptDocument, mockRange) + extractorStub.returns([ + Promise.resolve({ + success: true, + data: mockExtractedSymbols, + }), + () => {}, + ]) + const [documentSymbolsPromise] = documentFqnExtractor.extractDocumentSymbols(typescriptDocument, mockRange) - assert.deepStrictEqual(documentSymbols, [{ name: 'mkdir', type: 'USAGE', source: 'node:fs' }]) + assert.deepStrictEqual(await documentSymbolsPromise, [{ name: 'mkdir', type: 'USAGE', source: 'node:fs' }]) }) }) diff --git a/server/aws-lsp-codewhisperer/src/language-server/chat/contexts/documentFqnExtractor.ts b/server/aws-lsp-codewhisperer/src/language-server/chat/contexts/documentFqnExtractor.ts index ba5d160c..57ce662c 100644 --- a/server/aws-lsp-codewhisperer/src/language-server/chat/contexts/documentFqnExtractor.ts +++ b/server/aws-lsp-codewhisperer/src/language-server/chat/contexts/documentFqnExtractor.ts @@ -1,7 +1,7 @@ import { DocumentSymbol, SymbolType } from '@amzn/codewhisperer-streaming' import { ExtractorResult, FqnSupportedLanguages, FqnWorkerPool, FullyQualifiedName, IFqnWorkerPool } from '@aws/lsp-fqn' import { Range, TextDocument } from 'vscode-languageserver-textdocument' -import { Features } from '../../types' +import { Cancellable, Features } from '../../types' export interface DocumentFqnExtractorConfig { nameMinLength?: number @@ -54,70 +54,92 @@ export class DocumentFqnExtractor { this.#workerPool.dispose() } - public async extractDocumentSymbols( + public extractDocumentSymbols( document: TextDocument, range: Range, languageId = document.languageId - ): Promise { + ): Cancellable> { return DocumentFqnExtractor.FQN_SUPPORTED_LANGUAGE_SET.has(languageId) ? this.#extractSymbols(document, range, languageId as FqnSupportedLanguages) - : [] + : [Promise.resolve([]), () => {}] } - async #extractSymbols(document: TextDocument, range: Range, languageId: FqnSupportedLanguages) { - const names = await this.#extractNames(document, range, languageId) - - const documentSymbols: DocumentSymbol[] = [] - - for (const name of names) { - if (documentSymbols.length >= this.#maxSymbols) { - break - } - - const sourceSymbolString = name.source.join('.') - const symbolFqn = { - name: name.symbol.join('.') ?? '', - type: SymbolType.USAGE, - source: sourceSymbolString ? sourceSymbolString : undefined, - } - - if ( - symbolFqn.name.length >= this.#nameMinLength && - symbolFqn.name.length < this.#nameMaxLength && - (symbolFqn.source === undefined || - (symbolFqn.source.length >= this.#nameMinLength && symbolFqn.source.length < this.#nameMaxLength)) - ) { - documentSymbols.push(symbolFqn) - } - } - - return documentSymbols + #extractSymbols( + document: TextDocument, + range: Range, + languageId: FqnSupportedLanguages + ): Cancellable> { + const [extractPromise, cancel] = this.#extractNames(document, range, languageId) + + return [ + extractPromise.then(names => { + const documentSymbols: DocumentSymbol[] = [] + + for (const name of names) { + if (documentSymbols.length >= this.#maxSymbols) { + break + } + + const sourceSymbolString = name.source.join('.') + const symbolFqn = { + name: name.symbol.join('.') ?? '', + type: SymbolType.USAGE, + source: sourceSymbolString ? sourceSymbolString : undefined, + } + + if ( + symbolFqn.name.length >= this.#nameMinLength && + symbolFqn.name.length < this.#nameMaxLength && + (symbolFqn.source === undefined || + (symbolFqn.source.length >= this.#nameMinLength && + symbolFqn.source.length < this.#nameMaxLength)) + ) { + documentSymbols.push(symbolFqn) + } + } + + return documentSymbols + }), + cancel, + ] } - async #extractNames( + #extractNames( document: TextDocument, range: Range, languageId: FqnSupportedLanguages - ): Promise { - const result = await this.#findNamesInRange(document.getText(), range, languageId) - - if (!result.success || !result.data.fullyQualified) { - return [] - } - - const dedupedUsedFullyQualifiedNames: { [key: string]: FullyQualifiedName } = Object.fromEntries( - result.data.fullyQualified.usedSymbols.map((name: FullyQualifiedName) => [ - JSON.stringify([name.source, name.symbol]), - { source: name.source, symbol: name.symbol }, - ]) - ) - - return Object.values(dedupedUsedFullyQualifiedNames).sort( - (name, other) => name.source.length + name.symbol.length - (other.source.length + other.symbol.length) - ) + ): Cancellable> { + const [extractPromise, cancel] = this.#findNamesInRange(document.getText(), range, languageId) + + return [ + extractPromise.then(result => { + if (!result.success) { + throw result.error + } else if (!result.data.fullyQualified) { + return [] + } + + const dedupedUsedFullyQualifiedNames: { [key: string]: FullyQualifiedName } = Object.fromEntries( + result.data.fullyQualified.usedSymbols.map((name: FullyQualifiedName) => [ + JSON.stringify([name.source, name.symbol]), + { source: name.source, symbol: name.symbol }, + ]) + ) + + return Object.values(dedupedUsedFullyQualifiedNames).sort( + (name, other) => + name.source.length + name.symbol.length - (other.source.length + other.symbol.length) + ) + }), + cancel, + ] } - #findNamesInRange(fileText: string, selection: Range, languageId: FqnSupportedLanguages): Promise { + #findNamesInRange( + fileText: string, + selection: Range, + languageId: FqnSupportedLanguages + ): Cancellable> { return this.#workerPool.exec({ /** * [\ue000-\uf8ff]: Private Use Area in Unicode diff --git a/server/aws-lsp-codewhisperer/src/language-server/chat/contexts/triggerContext.ts b/server/aws-lsp-codewhisperer/src/language-server/chat/contexts/triggerContext.ts deleted file mode 100644 index 10fea446..00000000 --- a/server/aws-lsp-codewhisperer/src/language-server/chat/contexts/triggerContext.ts +++ /dev/null @@ -1,105 +0,0 @@ -import { TriggerType } from '@aws/chat-client-ui-types' -import { ChatTriggerType, GenerateAssistantResponseCommandInput, UserIntent } from '@amzn/codewhisperer-streaming' -import { ChatParams, CursorState } from '@aws/language-server-runtimes/server-interface' -import { Features } from '../../types' -import { DocumentContext, DocumentContextExtractor } from './documentContext' - -export interface TriggerContext extends Partial { - userIntent?: string - triggerType?: TriggerType -} - -export class QChatTriggerContext { - private static readonly DEFAULT_CURSOR_STATE: CursorState = { position: { line: 0, character: 0 } } - - #workspace: Features['workspace'] - #documentContextExtractor: DocumentContextExtractor - - constructor(workspace: Features['workspace'], logger: Features['logging']) { - this.#workspace = workspace - this.#documentContextExtractor = new DocumentContextExtractor({ logger }) - } - - async getNewTriggerContext(params: ChatParams): Promise { - const documentContext: DocumentContext | undefined = await this.extractDocumentContext(params) - - return { - ...documentContext, - userIntent: this.#guessIntentFromPrompt(params.prompt.prompt), - } - } - - getChatParamsFromTrigger( - params: ChatParams, - triggerContext: TriggerContext - ): GenerateAssistantResponseCommandInput { - const { prompt } = params - - const data: GenerateAssistantResponseCommandInput = { - conversationState: { - chatTriggerType: ChatTriggerType.MANUAL, - currentMessage: { - userInputMessage: { - content: prompt.escapedPrompt ?? prompt.prompt, - userInputMessageContext: - triggerContext.cursorState && triggerContext.relativeFilePath - ? { - editorState: { - cursorState: triggerContext.cursorState, - document: { - text: triggerContext.text, - programmingLanguage: triggerContext.programmingLanguage, - relativeFilePath: triggerContext.relativeFilePath, - documentSymbols: triggerContext.documentSymbols, - }, - }, - } - : undefined, - userIntent: triggerContext.userIntent, - }, - }, - }, - } - - return data - } - - public dispose() { - this.#documentContextExtractor.dispose() - } - - // public for testing - async extractDocumentContext( - input: Pick - ): Promise { - const { textDocument: textDocumentIdentifier, cursorState } = input - - const textDocument = - textDocumentIdentifier?.uri && (await this.#workspace.getTextDocument(textDocumentIdentifier.uri)) - - return textDocument - ? this.#documentContextExtractor.extractDocumentContext( - textDocument, - // we want to include a default position if a text document is found so users can still ask questions about the opened file - // the range will be expanded up to the max characters downstream - cursorState?.[0] ?? QChatTriggerContext.DEFAULT_CURSOR_STATE - ) - : undefined - } - - #guessIntentFromPrompt(prompt?: string): UserIntent | undefined { - if (prompt === undefined) { - return undefined - } else if (/^explain/i.test(prompt)) { - return UserIntent.EXPLAIN_CODE_SELECTION - } else if (/^refactor/i.test(prompt)) { - return UserIntent.SUGGEST_ALTERNATE_IMPLEMENTATION - } else if (/^fix/i.test(prompt)) { - return UserIntent.APPLY_COMMON_BEST_PRACTICES - } else if (/^optimize/i.test(prompt)) { - return UserIntent.IMPROVE_CODE - } - - return undefined - } -} diff --git a/server/aws-lsp-codewhisperer/src/language-server/chat/contexts/triggerContextExtractor.test.ts b/server/aws-lsp-codewhisperer/src/language-server/chat/contexts/triggerContextExtractor.test.ts new file mode 100644 index 00000000..526b9b3b --- /dev/null +++ b/server/aws-lsp-codewhisperer/src/language-server/chat/contexts/triggerContextExtractor.test.ts @@ -0,0 +1,263 @@ +import * as assert from 'assert' +import sinon from 'ts-sinon' +import { TextDocument } from 'vscode-languageserver-textdocument' +import { DocumentContext, TriggerContextExtractor } from './triggerContextExtractor' +import { DocumentFqnExtractor } from './documentFqnExtractor' +import { TestFeatures } from '@aws/language-server-runtimes/testing' +import { UserIntent } from '@amzn/codewhisperer-streaming' + +describe('TriggerContextExtractor', () => { + let features: TestFeatures + const mockTypescriptCodeBlock = `function test() { + console.log('test') +}` + const mockTSDocument = TextDocument.create('file://test.ts', 'typescript', 1, mockTypescriptCodeBlock) + const mockTabId = 'tab-1' + + beforeEach(() => { + features = new TestFeatures() + sinon.stub(DocumentFqnExtractor.prototype, 'extractDocumentSymbols').returns([Promise.resolve([]), () => {}]) + }) + + afterEach(() => { + sinon.restore() + }) + + describe('TriggerContextExtractor.extractEditorState', () => { + it('extracts editor state for range selection', async () => { + const documentContextExtractor = new TriggerContextExtractor(features.workspace, { characterLimits: 19 }) + const expected: DocumentContext = { + programmingLanguage: { languageName: 'typescript' }, + relativeFilePath: 'file://test.ts', + documentSymbols: [], + text: "console.log('test')", + hasCodeSnippet: true, + totalEditorCharacters: mockTypescriptCodeBlock.length, + cursorState: { + range: { + start: { + line: 0, + character: 8, + }, + end: { + line: 0, + character: 11, + }, + }, + }, + } + + const result = await documentContextExtractor.extractDocumentContext(mockTabId, mockTSDocument, { + // highlighting "log" + range: { + start: { + line: 1, + character: 12, + }, + end: { + line: 1, + character: 15, + }, + }, + }) + + assert.deepStrictEqual(result, expected) + }) + + it('extracts editor state for collapsed position', async () => { + const documentContextExtractor = new TriggerContextExtractor(features.workspace, { characterLimits: 19 }) + const expected: DocumentContext = { + programmingLanguage: { languageName: 'typescript' }, + relativeFilePath: 'file://test.ts', + documentSymbols: [], + text: "console.log('test')", + hasCodeSnippet: true, + totalEditorCharacters: mockTypescriptCodeBlock.length, + cursorState: { + range: { + start: { + line: 0, + character: 9, + }, + end: { + line: 0, + character: 10, + }, + }, + }, + } + + const result = await documentContextExtractor.extractDocumentContext(mockTabId, mockTSDocument, { + // highlighting "o" in "log" + range: { + start: { + line: 1, + character: 13, + }, + end: { + line: 1, + character: 14, + }, + }, + }) + + assert.deepStrictEqual(result, expected) + }) + }) + + it('handles other languages correctly', async () => { + const documentContextExtractor = new TriggerContextExtractor(features.workspace, { characterLimits: 19 }) + + const mockGoCodeBLock = `func main() { + fmt.Println("test") +}` + const mockDocument = TextDocument.create('file://test.go', 'go', 1, mockGoCodeBLock) + + const expectedResult: DocumentContext = { + programmingLanguage: { languageName: 'go' }, + relativeFilePath: 'file://test.go', + documentSymbols: [], + text: 'fmt.Println("test")', + totalEditorCharacters: mockGoCodeBLock.length, + hasCodeSnippet: true, + cursorState: { + range: { + start: { line: 0, character: 0 }, + end: { line: 0, character: 19 }, + }, + }, + } + const result = await documentContextExtractor.extractDocumentContext(mockTabId, mockDocument, { + range: { + start: { line: 1, character: 4 }, + end: { line: 1, character: 23 }, + }, + }) + + assert.deepStrictEqual(result, expectedResult) + }) + + describe('TriggerContextExtractor.getTriggerContext', () => { + beforeEach(() => { + features.openDocument(mockTSDocument) + }) + + it('returns only userIntent if textDocument uri is not passed', async () => { + const documentContextExtractor = new TriggerContextExtractor(features.workspace, { characterLimits: 10 }) + + const result = await documentContextExtractor.getTriggerContext({ + prompt: { prompt: 'Explain this code' }, + tabId: 'tab1', + cursorState: [ + { + range: { + start: { line: 1, character: 4 }, + end: { line: 1, character: 23 }, + }, + }, + ], + }) + + assert.deepStrictEqual(result, { userIntent: UserIntent.EXPLAIN_CODE_SELECTION }) + }) + + it('returns only userIntent if textDocument is not found', async () => { + const documentContextExtractor = new TriggerContextExtractor(features.workspace, { characterLimits: 10 }) + + const result = await documentContextExtractor.getTriggerContext({ + prompt: { prompt: 'Fix this code' }, + tabId: 'tab1', + textDocument: { + uri: 'file://non-existent.ts', + }, + cursorState: [ + { + range: { + start: { line: 1, character: 4 }, + end: { line: 1, character: 23 }, + }, + }, + ], + }) + + assert.deepStrictEqual(result, { userIntent: UserIntent.APPLY_COMMON_BEST_PRACTICES }) + }) + + it('uses a default cursor state if cursor state is not defined', async () => { + const documentContextExtractor = new TriggerContextExtractor(features.workspace, { characterLimits: 10 }) + + const result = await documentContextExtractor.getTriggerContext({ + prompt: { prompt: 'Fix this code' }, + tabId: 'tab1', + textDocument: { + uri: mockTSDocument.uri, + }, + }) + + assert.deepStrictEqual(result, { + programmingLanguage: { languageName: 'typescript' }, + relativeFilePath: 'file://test.ts', + documentSymbols: [], + text: 'function t', + hasCodeSnippet: true, + totalEditorCharacters: mockTypescriptCodeBlock.length, + userIntent: UserIntent.APPLY_COMMON_BEST_PRACTICES, + cursorState: { + range: { + start: { + line: 0, + character: 0, + }, + end: { + line: 0, + + character: 0, + }, + }, + }, + }) + }) + + it('returns all context and userIntent', async () => { + const documentContextExtractor = new TriggerContextExtractor(features.workspace, { characterLimits: 5 }) + + const result = await documentContextExtractor.getTriggerContext({ + prompt: { prompt: 'Fix this code' }, + tabId: 'tab1', + textDocument: { + uri: mockTSDocument.uri, + }, + cursorState: [ + { + position: { + line: 1, + character: 7, + }, + }, + ], + }) + + assert.deepStrictEqual(result, { + programmingLanguage: { languageName: 'typescript' }, + relativeFilePath: 'file://test.ts', + documentSymbols: [], + text: 'conso', + hasCodeSnippet: true, + totalEditorCharacters: mockTypescriptCodeBlock.length, + userIntent: UserIntent.APPLY_COMMON_BEST_PRACTICES, + cursorState: { + range: { + start: { + line: 0, + character: 3, + }, + end: { + line: 0, + character: 3, + }, + }, + }, + }) + }) + }) +}) diff --git a/server/aws-lsp-codewhisperer/src/language-server/chat/contexts/triggerContextExtractor.ts b/server/aws-lsp-codewhisperer/src/language-server/chat/contexts/triggerContextExtractor.ts new file mode 100644 index 00000000..f50dd391 --- /dev/null +++ b/server/aws-lsp-codewhisperer/src/language-server/chat/contexts/triggerContextExtractor.ts @@ -0,0 +1,191 @@ +import { + EditorState, + TextDocument as CwsprTextDocument, + DocumentSymbol, + UserIntent, + GenerateAssistantResponseCommandInput, + ChatTriggerType, +} from '@amzn/codewhisperer-streaming' +import { TriggerType } from '@aws/chat-client-ui-types' +import { ChatParams, CursorState } from '@aws/language-server-runtimes/server-interface' +import { Range, TextDocument } from 'vscode-languageserver-textdocument' +import { getLanguageId } from '../../languageDetection' +import { Cancel, Features } from '../../types' +import { DocumentFqnExtractor, DocumentFqnExtractorConfig } from './documentFqnExtractor' +import { getExtendedCodeBlockRange, getSelectionWithinExtendedRange } from './utils' +import { CancellationError } from '../utils' + +export type TriggerContext = Partial & { + userIntent?: string + triggerType?: TriggerType +} + +export type DocumentContext = CwsprTextDocument & { + cursorState?: EditorState['cursorState'] + hasCodeSnippet: boolean + totalEditorCharacters: number +} + +export interface TriggerContextExtractorConfig extends DocumentFqnExtractorConfig { + config?: DocumentFqnExtractorConfig + logger?: Features['logging'] + characterLimits?: number +} + +export class TriggerContextExtractor { + private static readonly DEFAULT_CHARACTER_LIMIT = 9000 + private static readonly DEFAULT_CURSOR_STATE: CursorState = { position: { line: 0, character: 0 } } + + #characterLimits: number + #logger?: Features['logging'] + #documentSymbolExtractor: DocumentFqnExtractor + #workspace: Features['workspace'] + #cancellableByTabId: { [tabId: string]: Cancel } + + public static getChatParamsFromTrigger( + params: ChatParams, + triggerContext: TriggerContext + ): GenerateAssistantResponseCommandInput { + const { prompt } = params + + const data: GenerateAssistantResponseCommandInput = { + conversationState: { + chatTriggerType: ChatTriggerType.MANUAL, + currentMessage: { + userInputMessage: { + content: prompt.escapedPrompt ?? prompt.prompt, + userInputMessageContext: + triggerContext.cursorState && triggerContext.relativeFilePath + ? { + editorState: { + cursorState: triggerContext.cursorState, + document: { + text: triggerContext.text, + programmingLanguage: triggerContext.programmingLanguage, + relativeFilePath: triggerContext.relativeFilePath, + documentSymbols: triggerContext.documentSymbols, + }, + }, + } + : undefined, + userIntent: triggerContext.userIntent, + }, + }, + }, + } + + return data + } + + constructor(workspace: Features['workspace'], config?: TriggerContextExtractorConfig) { + const { characterLimits, ...fqnConfig } = config ?? {} + + this.#logger = config?.logger + this.#characterLimits = characterLimits ?? TriggerContextExtractor.DEFAULT_CHARACTER_LIMIT + this.#documentSymbolExtractor = new DocumentFqnExtractor(fqnConfig) + this.#workspace = workspace + this.#cancellableByTabId = {} + } + + public async getTriggerContext(params: ChatParams): Promise { + const { textDocument: textDocumentIdentifier, cursorState } = params + + const textDocument = + textDocumentIdentifier?.uri && (await this.#workspace.getTextDocument(textDocumentIdentifier.uri)) + + const documentContext = + textDocument && + (await this.extractDocumentContext( + params.tabId, + textDocument, + cursorState?.[0] ?? TriggerContextExtractor.DEFAULT_CURSOR_STATE + )) + + return { + ...documentContext, + userIntent: this.#guessIntentFromPrompt(params.prompt.prompt), + } + } + + /** + * From the given the cursor state, we want to give Q context up to the characters limit + * on both sides of the cursor. + */ + public async extractDocumentContext( + tabId: string, + document: TextDocument, + cursorState: CursorState + ): Promise { + const targetRange: Range = + 'position' in cursorState + ? { + start: cursorState.position, + end: cursorState.position, + } + : cursorState.range + + const codeBlockRange = getExtendedCodeBlockRange(document, targetRange, this.#characterLimits) + + const rangeWithinCodeBlock = getSelectionWithinExtendedRange(targetRange, codeBlockRange) + + const languageId = getLanguageId(document) + + let documentSymbols: DocumentSymbol[] | undefined + const [extractPromise, cancel] = this.#documentSymbolExtractor.extractDocumentSymbols( + document, + codeBlockRange, + languageId + ) + + this.#cancellableByTabId[tabId] = cancel + + try { + documentSymbols = await extractPromise + } catch (error: unknown) { + if (error instanceof Error && error.name === CancellationError.name) { + throw new CancellationError(error.message, error.stack) + } + + this.#logger?.log( + `Error extracting document symbols but continuing on. ${error instanceof Error ? error.message : 'Unknown error'}` + ) + } finally { + delete this.#cancellableByTabId[tabId] + } + + return { + cursorState: rangeWithinCodeBlock ? { range: rangeWithinCodeBlock } : undefined, + documentSymbols, + text: document.getText(codeBlockRange), + programmingLanguage: languageId ? { languageName: languageId } : undefined, + relativeFilePath: document.uri, + hasCodeSnippet: Boolean(rangeWithinCodeBlock), + totalEditorCharacters: document.getText().length, + } + } + + public cancel(tabId: string) { + this.#cancellableByTabId[tabId]?.() + } + + public dispose() { + Object.values(this.#cancellableByTabId).forEach(cancellable => cancellable?.()) + this.#documentSymbolExtractor.dispose() + } + + #guessIntentFromPrompt(prompt?: string): UserIntent | undefined { + if (prompt === undefined) { + return undefined + } else if (/^explain/i.test(prompt)) { + return UserIntent.EXPLAIN_CODE_SELECTION + } else if (/^refactor/i.test(prompt)) { + return UserIntent.SUGGEST_ALTERNATE_IMPLEMENTATION + } else if (/^fix/i.test(prompt)) { + return UserIntent.APPLY_COMMON_BEST_PRACTICES + } else if (/^optimize/i.test(prompt)) { + return UserIntent.IMPROVE_CODE + } + + return undefined + } +} diff --git a/server/aws-lsp-codewhisperer/src/language-server/chat/contexts/triggerContexts.test.ts b/server/aws-lsp-codewhisperer/src/language-server/chat/contexts/triggerContexts.test.ts deleted file mode 100644 index f003cde0..00000000 --- a/server/aws-lsp-codewhisperer/src/language-server/chat/contexts/triggerContexts.test.ts +++ /dev/null @@ -1,95 +0,0 @@ -import { TestFeatures } from '@aws/language-server-runtimes/testing' -import { QChatTriggerContext } from './triggerContext' -import assert = require('assert') -import { TextDocument } from 'vscode-languageserver-textdocument' -import { DocumentContext, DocumentContextExtractor } from './documentContext' -import sinon = require('sinon') - -describe('QChatTriggerContext', () => { - let testFeatures: TestFeatures - - const filePath = 'file://test.ts' - const mockTSDocument = TextDocument.create(filePath, 'typescript', 1, '') - const mockDocumentContext: DocumentContext = { - text: '', - programmingLanguage: { languageName: 'typescript' }, - relativeFilePath: 'file://test.ts', - documentSymbols: [], - hasCodeSnippet: false, - totalEditorCharacters: 0, - } - - beforeEach(() => { - testFeatures = new TestFeatures() - sinon.stub(DocumentContextExtractor.prototype, 'extractDocumentContext').resolves(mockDocumentContext) - }) - - afterEach(() => { - sinon.restore() - }) - - it('returns null if text document is not defined in params', async () => { - const triggerContext = new QChatTriggerContext(testFeatures.workspace, testFeatures.logging) - - const documentContext = await triggerContext.extractDocumentContext({ - cursorState: [ - { - position: { - line: 5, - character: 0, - }, - }, - ], - textDocument: undefined, - }) - - assert.deepStrictEqual(documentContext, undefined) - }) - - it('returns null if text document is not found', async () => { - const triggerContext = new QChatTriggerContext(testFeatures.workspace, testFeatures.logging) - - const documentContext = await triggerContext.extractDocumentContext({ - cursorState: [ - { - position: { - line: 5, - character: 0, - }, - }, - ], - textDocument: { - uri: filePath, - }, - }) - - assert.deepStrictEqual(documentContext, undefined) - }) - - it('passes default cursor state if no cursor is found', async () => { - const triggerContext = new QChatTriggerContext(testFeatures.workspace, testFeatures.logging) - - const documentContext = await triggerContext.extractDocumentContext({ - cursorState: [], - textDocument: { - uri: filePath, - }, - }) - - assert.deepStrictEqual(documentContext, undefined) - }) - - it('includes cursor state from the parameters and text document if found', async () => { - const triggerContext = new QChatTriggerContext(testFeatures.workspace, testFeatures.logging) - - testFeatures.openDocument(mockTSDocument) - const documentContext = await triggerContext.extractDocumentContext({ - cursorState: [], - textDocument: { - uri: filePath, - }, - }) - - assert.deepStrictEqual(documentContext, mockDocumentContext) - }) -}) diff --git a/server/aws-lsp-codewhisperer/src/language-server/chat/telemetry/chatTelemetryController.ts b/server/aws-lsp-codewhisperer/src/language-server/chat/telemetry/chatTelemetryController.ts index 745a7e8c..202005fd 100644 --- a/server/aws-lsp-codewhisperer/src/language-server/chat/telemetry/chatTelemetryController.ts +++ b/server/aws-lsp-codewhisperer/src/language-server/chat/telemetry/chatTelemetryController.ts @@ -15,8 +15,8 @@ import { isClientTelemetryEvent, } from './clientTelemetry' import { UserIntent } from '@amzn/codewhisperer-streaming' -import { TriggerContext } from '../contexts/triggerContext' import { AcceptedSuggestionEntry, CodeDiffTracker } from '../../telemetry/codeDiffTracker' +import { TriggerContext } from '../contexts/triggerContextExtractor' export const CONVERSATION_ID_METRIC_KEY = 'cwsprChatConversationId' @@ -35,14 +35,9 @@ interface MessageTrigger extends TriggerContext { followUpActions?: Set } -interface StartTrigger { - triggerType?: TriggerType - hasUserSnippet?: boolean -} - interface ConversationTriggerInfo { conversationId: string - startTrigger?: StartTrigger + startTriggerType?: TriggerType lastMessageTrigger?: MessageTrigger } @@ -255,9 +250,7 @@ export class ChatTelemetryController { case ChatUIEventName.TabAdd: this.#tabTelemetryInfoByTabId[params.tabId] = { ...this.#tabTelemetryInfoByTabId[params.tabId], - startTrigger: { - triggerType: params.triggerType, - }, + startTriggerType: params.triggerType, } break case ChatUIEventName.EnterFocusChat: diff --git a/server/aws-lsp-codewhisperer/src/language-server/chat/utils.ts b/server/aws-lsp-codewhisperer/src/language-server/chat/utils.ts index 2ea74488..347f42ad 100644 --- a/server/aws-lsp-codewhisperer/src/language-server/chat/utils.ts +++ b/server/aws-lsp-codewhisperer/src/language-server/chat/utils.ts @@ -62,3 +62,11 @@ export function getDefaultChatResponse(prompt?: string): ChatResult | undefined return undefined } + +export class CancellationError extends Error { + constructor(message?: string, stack?: string) { + super(message || 'Promise cancelled') + this.name = 'CancellationError' + this.stack = stack + } +} diff --git a/server/aws-lsp-codewhisperer/src/language-server/types.ts b/server/aws-lsp-codewhisperer/src/language-server/types.ts index 7cdd9820..3f7c5ab6 100644 --- a/server/aws-lsp-codewhisperer/src/language-server/types.ts +++ b/server/aws-lsp-codewhisperer/src/language-server/types.ts @@ -20,3 +20,6 @@ export type LspHandlers = { export type KeysMatching = { [TKey in keyof TMap]: TMap[TKey] extends TCriteria ? TKey : never }[keyof TMap] + +export type Cancel = () => void +export type Cancellable = [T, Cancel] diff --git a/server/aws-lsp-codewhisperer/src/language-server/utils.ts b/server/aws-lsp-codewhisperer/src/language-server/utils.ts index f503050a..e89962ca 100644 --- a/server/aws-lsp-codewhisperer/src/language-server/utils.ts +++ b/server/aws-lsp-codewhisperer/src/language-server/utils.ts @@ -12,7 +12,7 @@ export function isAwsError(error: unknown): error is AWSError { return error instanceof Error && hasCode(error) && hasTime(error) } -function hasCode(error: T): error is T & { code: string } { +export function hasCode(error: T): error is T & { code: string } { return typeof (error as { code?: unknown }).code === 'string' }