Refactor handleStream to LLM Classes (#685)
timothycarambat authored Feb 7, 2024
1 parent e2a6a2d · commit aca5940
Showing 12 changed files with 374 additions and 307 deletions.
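Every provider file below follows the same two-step pattern: the `{ type: "<provider>Stream", ... }` wrapper objects are dropped so the streaming methods return the raw stream, and each LLM class gains its own `handleStream(response, stream, responseProps)` method that emits chunks and resolves with the full reply text. A rough before/after sketch of the call site (the real one lives in the chat stream handler, among the files not shown on this page, so the names here are assumptions):

```js
// Before this commit (assumed shape): providers tagged their streams and a
// central handler had to dispatch on the tag.
//   const { type, stream } = await LLMConnector.streamGetChatCompletion(...);
//   switch (type) { case "azureStream": ...; case "geminiStream": ...; }

// After: providers return the raw stream and own the parsing themselves.
const stream = await LLMConnector.streamGetChatCompletion(messages, {
  temperature: 0.7,
});
await LLMConnector.handleStream(response, stream, { uuid, sources });
```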
38 changes: 36 additions & 2 deletions server/utils/AiProviders/azureOpenAi/index.js
@@ -1,5 +1,6 @@
const { AzureOpenAiEmbedder } = require("../../EmbeddingEngines/azureOpenAi");
const { chatPrompt } = require("../../chats");
const { writeResponseChunk } = require("../../chats/stream");

class AzureOpenAiLLM {
constructor(embedder = null, _modelPreference = null) {
@@ -135,7 +136,7 @@ class AzureOpenAiLLM {
n: 1,
}
);
return { type: "azureStream", stream };
return stream;
}

async getChatCompletion(messages = [], { temperature = 0.7 }) {
@@ -165,7 +166,40 @@ class AzureOpenAiLLM {
n: 1,
}
);
return { type: "azureStream", stream };
return stream;
}

handleStream(response, stream, responseProps) {
const { uuid = uuidv4(), sources = [] } = responseProps;

return new Promise(async (resolve) => {
let fullText = "";
for await (const event of stream) {
for (const choice of event.choices) {
const delta = choice.delta?.content;
if (!delta) continue;
fullText += delta;
writeResponseChunk(response, {
uuid,
sources: [],
type: "textResponseChunk",
textResponse: delta,
close: false,
error: false,
});
}
}

writeResponseChunk(response, {
uuid,
sources,
type: "textResponseChunk",
textResponse: "",
close: true,
error: false,
});
resolve(fullText);
});
}

// Simple wrapper for dynamic embedder & normalize interface for all LLM implementations
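Each of the new `handleStream` implementations relies on `writeResponseChunk` from `../../chats/stream`, which never appears in the visible hunks. Assuming it follows the usual server-sent-events framing for an Express response, a minimal sketch:

```js
// Minimal sketch (assumption): the real helper lives in
// server/utils/chats/stream.js, outside the visible hunks. It frames each
// chunk object as a server-sent event on the open HTTP response.
function writeResponseChunk(response, data) {
  response.write(`data: ${JSON.stringify(data)}\n\n`);
}
```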
34 changes: 32 additions & 2 deletions server/utils/AiProviders/gemini/index.js
@@ -1,4 +1,5 @@
const { chatPrompt } = require("../../chats");
const { writeResponseChunk } = require("../../chats/stream");

class GeminiLLM {
constructor(embedder = null, modelPreference = null) {
@@ -164,7 +165,7 @@ class GeminiLLM {
if (!responseStream.stream)
throw new Error("Could not stream response stream from Gemini.");

return { type: "geminiStream", ...responseStream };
return responseStream.stream;
}

async streamGetChatCompletion(messages = [], _opts = {}) {
@@ -183,7 +184,7 @@
if (!responseStream.stream)
throw new Error("Could not stream response stream from Gemini.");

return { type: "geminiStream", ...responseStream };
return responseStream.stream;
}

async compressMessages(promptArgs = {}, rawHistory = []) {
@@ -192,6 +193,35 @@
return await messageArrayCompressor(this, messageArray, rawHistory);
}

handleStream(response, stream, responseProps) {
const { uuid = uuidv4(), sources = [] } = responseProps;

return new Promise(async (resolve) => {
let fullText = "";
for await (const chunk of stream) {
fullText += chunk.text();
writeResponseChunk(response, {
uuid,
sources: [],
type: "textResponseChunk",
textResponse: chunk.text(),
close: false,
error: false,
});
}

writeResponseChunk(response, {
uuid,
sources,
type: "textResponseChunk",
textResponse: "",
close: true,
error: false,
});
resolve(fullText);
});
}

// Simple wrapper for dynamic embedder & normalize interface for all LLM implementations
async embedTextInput(textInput) {
return await this.embedder.embedTextInput(textInput);
113 changes: 111 additions & 2 deletions server/utils/AiProviders/huggingface/index.js
@@ -1,6 +1,7 @@
const { NativeEmbedder } = require("../../EmbeddingEngines/native");
const { OpenAiEmbedder } = require("../../EmbeddingEngines/openAi");
const { chatPrompt } = require("../../chats");
const { writeResponseChunk } = require("../../chats/stream");

class HuggingFaceLLM {
constructor(embedder = null, _modelPreference = null) {
@@ -138,7 +139,7 @@ class HuggingFaceLLM {
},
{ responseType: "stream" }
);
return { type: "huggingFaceStream", stream: streamRequest };
return streamRequest;
}

async getChatCompletion(messages = null, { temperature = 0.7 }) {
@@ -162,7 +163,115 @@
},
{ responseType: "stream" }
);
return { type: "huggingFaceStream", stream: streamRequest };
return streamRequest;
}

handleStream(response, stream, responseProps) {
const { uuid = uuidv4(), sources = [] } = responseProps;

return new Promise((resolve) => {
let fullText = "";
let chunk = "";
stream.data.on("data", (data) => {
const lines = data
?.toString()
?.split("\n")
.filter((line) => line.trim() !== "");

for (const line of lines) {
let validJSON = false;
const message = chunk + line.replace(/^data:/, "");
if (message !== "[DONE]") {
// JSON chunk is incomplete and has not ended yet
// so we need to stitch it together. You would think JSON
// chunks would only come complete - but they don't!
try {
JSON.parse(message);
validJSON = true;
} catch {
console.log("Failed to parse message", message);
}

if (!validJSON) {
// It can be possible that the chunk decoding is running away
// and the message chunk fails to append due to string length.
// In this case abort the chunk and reset so we can continue.
// ref: https://github.com/Mintplex-Labs/anything-llm/issues/416
try {
chunk += message;
} catch (e) {
console.error(`Chunk appending error`, e);
chunk = "";
}
continue;
} else {
chunk = "";
}
}

if (message == "[DONE]") {
writeResponseChunk(response, {
uuid,
sources,
type: "textResponseChunk",
textResponse: "",
close: true,
error: false,
});
resolve(fullText);
} else {
let error = null;
let finishReason = null;
let token = "";
try {
const json = JSON.parse(message);
error = json?.error || null;
token = json?.choices?.[0]?.delta?.content;
finishReason = json?.choices?.[0]?.finish_reason || null;
} catch {
continue;
}

if (!!error) {
writeResponseChunk(response, {
uuid,
sources: [],
type: "textResponseChunk",
textResponse: null,
close: true,
error,
});
resolve("");
return;
}

if (token) {
fullText += token;
writeResponseChunk(response, {
uuid,
sources: [],
type: "textResponseChunk",
textResponse: token,
close: false,
error: false,
});
}

if (finishReason !== null) {
writeResponseChunk(response, {
uuid,
sources,
type: "textResponseChunk",
textResponse: "",
close: true,
error: false,
});
resolve(fullText);
}
}
}
});
});
}

// Simple wrapper for dynamic embedder & normalize interface for all LLM implementations
5 changes: 5 additions & 0 deletions server/utils/AiProviders/lmStudio/index.js
@@ -1,4 +1,5 @@
const { chatPrompt } = require("../../chats");
const { handleDefaultStreamResponse } = require("../../chats/stream");

// hybrid of openAi LLM chat completion for LMStudio
class LMStudioLLM {
@@ -174,6 +175,10 @@ class LMStudioLLM {
return streamRequest;
}

handleStream(response, stream, responseProps) {
return handleDefaultStreamResponse(response, stream, responseProps);
}

// Simple wrapper for dynamic embedder & normalize interface for all LLM implementations
async embedTextInput(textInput) {
return await this.embedder.embedTextInput(textInput);
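LMStudio, LocalAI, and Mistral all speak the OpenAI-compatible SSE format, so instead of carrying their own parsers they delegate to a shared `handleDefaultStreamResponse` helper in `../../chats/stream`. That helper is also outside the visible hunks; a plausible sketch, modeled on the HuggingFace parser above (with `uuidv4` and `writeResponseChunk` assumed to be in scope):

```js
// Sketch of the shared OpenAI-style SSE handler (an assumption; the real
// handleDefaultStreamResponse is not part of the visible hunks). It mirrors
// the HuggingFace parser above, minus the chunk-stitching workaround.
function handleDefaultStreamResponse(response, stream, responseProps) {
  const { uuid = uuidv4(), sources = [] } = responseProps;

  return new Promise((resolve) => {
    let fullText = "";
    stream.data.on("data", (data) => {
      const lines = data
        ?.toString()
        ?.split("\n")
        .filter((line) => line.trim() !== "");

      for (const line of lines) {
        const message = line.replace(/^data: /, "");
        if (message === "[DONE]") {
          // Flush a final, closing chunk and hand back the full reply.
          writeResponseChunk(response, {
            uuid,
            sources,
            type: "textResponseChunk",
            textResponse: "",
            close: true,
            error: false,
          });
          return resolve(fullText);
        }

        try {
          const token = JSON.parse(message)?.choices?.[0]?.delta?.content;
          if (!token) continue;
          fullText += token;
          writeResponseChunk(response, {
            uuid,
            sources: [],
            type: "textResponseChunk",
            textResponse: token,
            close: false,
            error: false,
          });
        } catch {
          // Partial JSON chunks are simply skipped in this sketch.
        }
      }
    });
  });
}
```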
5 changes: 5 additions & 0 deletions server/utils/AiProviders/localAi/index.js
@@ -1,4 +1,5 @@
const { chatPrompt } = require("../../chats");
const { handleDefaultStreamResponse } = require("../../chats/stream");

class LocalAiLLM {
constructor(embedder = null, modelPreference = null) {
@@ -174,6 +175,10 @@ class LocalAiLLM {
return streamRequest;
}

handleStream(response, stream, responseProps) {
return handleDefaultStreamResponse(response, stream, responseProps);
}

// Simple wrapper for dynamic embedder & normalize interface for all LLM implementations
async embedTextInput(textInput) {
return await this.embedder.embedTextInput(textInput);
5 changes: 5 additions & 0 deletions server/utils/AiProviders/mistral/index.js
@@ -1,4 +1,5 @@
const { chatPrompt } = require("../../chats");
const { handleDefaultStreamResponse } = require("../../chats/stream");

class MistralLLM {
constructor(embedder = null, modelPreference = null) {
@@ -164,6 +165,10 @@ class MistralLLM {
return streamRequest;
}

handleStream(response, stream, responseProps) {
return handleDefaultStreamResponse(response, stream, responseProps);
}

// Simple wrapper for dynamic embedder & normalize interface for all LLM implementations
async embedTextInput(textInput) {
return await this.embedder.embedTextInput(textInput);
36 changes: 36 additions & 0 deletions server/utils/AiProviders/native/index.js
@@ -2,6 +2,7 @@ const fs = require("fs");
const path = require("path");
const { NativeEmbedder } = require("../../EmbeddingEngines/native");
const { chatPrompt } = require("../../chats");
const { writeResponseChunk } = require("../../chats/stream");

// Docs: https://api.js.langchain.com/classes/chat_models_llama_cpp.ChatLlamaCpp.html
const ChatLlamaCpp = (...args) =>
@@ -170,6 +171,41 @@ class NativeLLM {
return responseStream;
}

handleStream(response, stream, responseProps) {
const { uuid = uuidv4(), sources = [] } = responseProps;

return new Promise(async (resolve) => {
let fullText = "";
for await (const chunk of stream) {
if (chunk === undefined)
throw new Error(
"Stream returned undefined chunk. Aborting reply - check model provider logs."
);

const content = chunk.hasOwnProperty("content") ? chunk.content : chunk;
fullText += content;
writeResponseChunk(response, {
uuid,
sources: [],
type: "textResponseChunk",
textResponse: content,
close: false,
error: false,
});
}

writeResponseChunk(response, {
uuid,
sources,
type: "textResponseChunk",
textResponse: "",
close: true,
error: false,
});
resolve(fullText);
});
}

// Simple wrapper for dynamic embedder & normalize interface for all LLM implementations
async embedTextInput(textInput) {
return await this.embedder.embedTextInput(textInput);
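Whatever the provider, the contract is now uniform: `handleStream` writes `textResponseChunk` events as they arrive, and its Promise resolves with the fully accumulated reply. A hedged sketch of how a streaming chat route can consume that (variable names and the `saveChatHistory` helper are hypothetical, not from this repo):

```js
// Hypothetical consumer. Streaming to the client and persisting history
// share one code path because handleStream resolves with the full text.
const completeText = await LLMConnector.handleStream(response, stream, {
  uuid,
  sources,
});
await saveChatHistory({ uuid, text: completeText, sources }); // hypothetical helper
response.end();
```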