Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add stop generating button #424

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/App.vue
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
<SettingsModal v-model:open="isSettingsOpen" />
<ConfirmModal ref="confirmModal" />
<UpdateNotification></UpdateNotification>
<StopGenerating v-if="isGenerating"/>
<ShortcutGuide
ref="shortcutGuideRef"
v-model:open="isShortcutGuideOpen"
Expand Down Expand Up @@ -125,6 +126,7 @@ import FooterBar from "@/components/Footer/FooterBar.vue";
import UpdateNotification from "@/components/Notification/UpdateNotificationModal.vue";
import FindModal from "@/components/FindModal.vue";
import ShortcutGuide from "@/components/ShortcutGuide/ShortcutGuide.vue";
import StopGenerating from "@/components/StopGenerating/StopGenerating.vue";

// Styles
import "@mdi/font/css/materialdesignicons.css";
Expand All @@ -150,6 +152,10 @@ const isSettingsOpen = ref(false);
const isChatDrawerOpen = ref(store.state.isChatDrawerOpen);
const chatDrawerRef = ref();

const isGenerating = computed(() => {
const messages = store.getters.currentChat.messages || [];
return messages.filter(v => v.type === 'response').some(v => !v.done);
});
const columns = computed(() => store.state.columns);

const changeColumns = (columns) => store.commit("changeColumns", columns);
Expand Down
8 changes: 7 additions & 1 deletion src/bots/Bot.js
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,13 @@ export default class Bot {
async _sendPrompt(prompt, onUpdateResponse, callbackParam) {
throw new Error(i18n.global.t("bot.notImplemented"));
}
/* eslint-enable no-unused-vars */

async _stopGenerating() {
}

stopGenerating() {
this._stopGenerating();
}

async sendPrompt(prompt, onUpdateResponse, callbackParam) {
// If not logged in, handle the error
Expand Down
190 changes: 150 additions & 40 deletions src/bots/LangChainBot.js
Original file line number Diff line number Diff line change
@@ -1,56 +1,166 @@
import Bot from "@/bots/Bot";
import { HumanMessage, AIMessage, SystemMessage } from "langchain/schema";
import store from "@/store";
import { SSE } from 'sse.js';

export default class LangChainBot extends Bot {
static _brandId = "langChainBot";
static _chatModel = undefined; // ChatModel instance

constructor() {
super();
super();
this.source = null;
}
async _sendPrompt(prompt, onUpdateResponse, callbackParam) {
let messages = await this.getChatContext();
// Remove old messages if exceeding the pastRounds limit
while (messages.length > store.state.openaiApi.pastRounds * 2) {
messages.shift();
}

async _sendPrompt(prompt, onUpdateResponse, callbackParam) {
let messages = await this.getChatContext();
// Remove old messages if exceeding the pastRounds limit
while (messages.length > this.getPastRounds() * 2) {
messages.shift();
// Send the prompt to the OpenAI API
try {
const headers = {
'Content-Type': 'application/json',
Authorization: `Bearer ${store.state.openaiApi.apiKey}`,
};

messages.push({ role: 'user', content: `‘${prompt}’` });
const payload = JSON.stringify({
model: this.constructor._model,
messages: messages,
temperature: store.state.openaiApi.temperature,
stream: true,
});

const requestConfig = {
headers,
method: 'POST',
payload,
};

let res = '';
return new Promise((resolve, reject) => {
// call OpenAI API
const apiUrl
= store.state.openaiApi.alterUrl
|| 'https://api.openai.com/v1/chat/completions';
this.source = new SSE(apiUrl, requestConfig);
this.source.addEventListener('message', event => {
const regex = /^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{6}$/;
if (event.data === '[DONE]') {
onUpdateResponse(callbackParam, { done: true });
messages.push({ role: 'assistant', content: res });
this.setChatContext(messages);
this.source.close();
resolve();
}
else if (regex.test(event.data)) {
// Ignore the timestamp
return;
}
else {
if (event?.source?.chunk?.startsWith('{')) {
const { code, msg, error } = JSON.parse(event?.source?.chunk || '{}');
if (error && error?.message) {
this.source.close();
reject(error?.message);
return;
}
if (code >= 400) {
this.source.close();
reject(`${code}: ${msg}`);
return;
}
}
try {
const data = JSON.parse(event.data);
const partialText = data.choices?.[0]?.delta?.content;
if (partialText) {
res += partialText;
onUpdateResponse(callbackParam, { content: res, done: false });
}
}
catch (e) {
this.source.close();
reject(e);
}
}
});
this.source.addEventListener('error', error => {
try {
const data = (() => {
if (error?.data) {
return error?.data?.startsWith('{')
? JSON.parse(error.data) : typeof error?.data === 'object'
? JSON.stringify(error.data) : error.data;
}
return error;
})();
this.source.close();
reject(data?.error?.message || data?.error?.msg || data?.data || data || '');
}
catch (e) {
this.source.close();
console.error(e);
reject(e);
}

});
this.source.stream();
});
}
catch (error) {
console.error('Error sending prompt to OpenAIAPI:', error);
throw error;
}
}

// Deserialize the messages and convert them to the correct format
messages = messages.map((item) => {
let storedMessage = JSON.parse(item); // Deserialize
if (storedMessage.type === "human") {
return new HumanMessage(storedMessage.data);
} else if (storedMessage.type === "ai") {
return new AIMessage(storedMessage.data);
} else if (storedMessage.type === "system") {
return new SystemMessage(storedMessage.data);
}
});
_stopGenerating() {
this.source && this.source.close();
}

// Add the prompt to the messages
messages.push(new HumanMessage(prompt));
// async _sendPrompt(prompt, onUpdateResponse, callbackParam) {
// let messages = await this.getChatContext();
// // Remove old messages if exceeding the pastRounds limit
// while (messages.length > this.getPastRounds() * 2) {
// messages.shift();
// }

let res = "";
const model = this.constructor._chatModel;
const callbacks = [
{
handleLLMNewToken(token) {
res += token;
onUpdateResponse(callbackParam, { content: res, done: false });
},
handleLLMEnd() {
onUpdateResponse(callbackParam, { done: true });
},
},
];
model.callbacks = callbacks;
await model.call(messages);
messages.push(new AIMessage(res));
// Serialize the messages before storing
messages = messages.map((item) => JSON.stringify(item.toDict()));
this.setChatContext(messages);
}
// // Deserialize the messages and convert them to the correct format
// messages = messages.map((item) => {
// let storedMessage = JSON.parse(item); // Deserialize
// if (storedMessage.type === "human") {
// return new HumanMessage(storedMessage.data);
// } else if (storedMessage.type === "ai") {
// return new AIMessage(storedMessage.data);
// } else if (storedMessage.type === "system") {
// return new SystemMessage(storedMessage.data);
// }
// });

// // Add the prompt to the messages
// messages.push(new HumanMessage(prompt));

// let res = "";
// const model = this.constructor._chatModel;
// const callbacks = [
// {
// handleLLMNewToken(token) {
// res += token;
// onUpdateResponse(callbackParam, { content: res, done: false });
// },
// handleLLMEnd() {
// onUpdateResponse(callbackParam, { done: true });
// },
// },
// ];
// model.callbacks = callbacks;
// await model.call(messages);
// messages.push(new AIMessage(res));
// // Serialize the messages before storing
// messages = messages.map((item) => JSON.stringify(item.toDict()));
// this.setChatContext(messages);
// }

async createChatContext() {
return [];
Expand Down
15 changes: 10 additions & 5 deletions src/components/Footer/FooterBar.vue
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@
<v-btn
class="send-prompt-btn"
elevation="2"
:disabled="
prompt.trim() === '' ||
favBots.filter((favBot) => activeBots[favBot.classname]).length === 0
"
:disabled="isSendDisabled"
@click="sendPromptToBots"
>
{{ $t("footer.sendPrompt") }}
Expand Down Expand Up @@ -121,6 +118,12 @@ const favBots = computed(() => {
});

const prompt = ref("");
const isSendDisabled = computed(() => {
const messages = store.getters.currentChat.messages || [];
return prompt.value.trim() === ''
|| favBots.value.filter(favBot => activeBots[favBot.classname]).length === 0
|| messages.filter(v => v.type === 'response').some(v => !v.done);
});
const clickedBot = ref(null);
const isMakeAvailableOpen = ref(false);

Expand Down Expand Up @@ -184,7 +187,9 @@ function filterEnterKey(event) {
!event.metaKey
) {
event.preventDefault();
sendPromptToBots();
if (!isSendDisabled.value) {
sendPromptToBots();
}
}
}

Expand Down
55 changes: 55 additions & 0 deletions src/components/StopGenerating/StopGenerating.vue
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
<template>
<v-btn
class="stop-generating-btn"
:class="{'drawer-open': store.state.isChatDrawerOpen}"
@click="stopGenerating"
>
{{ $t("footer.stopGenerating") }}
</v-btn>
</template>

<script setup>
import store from '@/store';
import _bots from '@/bots';
import {computed} from 'vue';

const favBots = computed(() => {
const _favBots = [];
store.getters.currentChat.favBots.forEach(favBot => {
_favBots.push({
...favBot,
instance: _bots.getBotByClassName(favBot.classname),
});
});
return _favBots;
});

const stopGenerating = () => {
const toBots = favBots.value
.map(favBot => favBot.instance);

if (toBots.length === 0) {
return;
}

store.dispatch('stopGenerating', {
bots: toBots,
});
};
</script>

<style>
.stop-generating-btn {
width: 100px;
position: fixed !important;
bottom: 60px;
left: 0;
right: 0;
z-index: 9999;
color: rgb(var(--v-theme-font));
margin: 0 auto;
}
.stop-generating-btn.drawer-open {
right: -255px;
}
</style>
1 change: 1 addition & 0 deletions src/i18n/locales/zh.json
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
},
"footer": {
"chooseFavorite": "选择你喜欢的 AI",
"stopGenerating": "停止生成",
"sendPrompt": "发送到:",
"promptPlaceholder": "输入消息。(Shift+Enter 换行)"
},
Expand Down
11 changes: 11 additions & 0 deletions src/store/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,17 @@ export default createStore({
},
},
actions: {
stopGenerating({ state, }, { bots }) {
bots?.map(v => v?.stopGenerating());
const currentChat = state.chats[state.currentChatIndex];
if (currentChat.messages) {
currentChat.messages.forEach(
item => {
item.done = true;
}
);
}
},
sendPrompt({ commit, state, dispatch }, { prompt, bots, promptIndex }) {
isThrottle = false;
const currentChat = state.chats[state.currentChatIndex];
Expand Down