Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
actions-user committed Oct 11, 2024
2 parents 1f2b45e + c5074f0 commit 36d5609
Show file tree
Hide file tree
Showing 9 changed files with 2,079 additions and 146 deletions.
194 changes: 82 additions & 112 deletions app/client/platforms/google.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,25 @@ import {
LLMUsage,
SpeechOptions,
} from "../api";
import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
import {
useAccessStore,
useAppConfig,
useChatStore,
usePluginStore,
ChatMessageTool,
} from "@/app/store";
import { stream } from "@/app/utils/chat";
import { getClientConfig } from "@/app/config/client";
import { GEMINI_BASE_URL } from "@/app/constant";
import Locale from "../../locales";
import {
EventStreamContentType,
fetchEventSource,
} from "@fortaine/fetch-event-source";
import { prettyObject } from "@/app/utils/format";

import {
getMessageTextContent,
getMessageImages,
isVisionModel,
} from "@/app/utils";
import { preProcessImageContent } from "@/app/utils/chat";
import { nanoid } from "nanoid";
import { RequestPayload } from "./openai";
import { fetch } from "@/app/utils/stream";

export class GeminiProApi implements LLMApi {
Expand Down Expand Up @@ -178,115 +182,81 @@ export class GeminiProApi implements LLMApi {
);

if (shouldStream) {
let responseText = "";
let remainText = "";
let finished = false;

const finish = () => {
if (!finished) {
finished = true;
options.onFinish(responseText + remainText);
}
};

// animate response to make it looks smooth
function animateResponseText() {
if (finished || controller.signal.aborted) {
responseText += remainText;
finish();
return;
}

if (remainText.length > 0) {
const fetchCount = Math.max(1, Math.round(remainText.length / 60));
const fetchText = remainText.slice(0, fetchCount);
responseText += fetchText;
remainText = remainText.slice(fetchCount);
options.onUpdate?.(responseText, fetchText);
}

requestAnimationFrame(animateResponseText);
}

// start animaion
animateResponseText();

controller.signal.onabort = finish;

fetchEventSource(chatPath, {
fetch: fetch as any,
...chatPayload,
async onopen(res) {
clearTimeout(requestTimeoutId);
const contentType = res.headers.get("content-type");
console.log(
"[Gemini] request response content type: ",
contentType,
);

if (contentType?.startsWith("text/plain")) {
responseText = await res.clone().text();
return finish();
}

if (
!res.ok ||
!res.headers
.get("content-type")
?.startsWith(EventStreamContentType) ||
res.status !== 200
) {
const responseTexts = [responseText];
let extraInfo = await res.clone().text();
try {
const resJson = await res.clone().json();
extraInfo = prettyObject(resJson);
} catch {}

if (res.status === 401) {
responseTexts.push(Locale.Error.Unauthorized);
}

if (extraInfo) {
responseTexts.push(extraInfo);
}

responseText = responseTexts.join("\n\n");

return finish();
}
},
onmessage(msg) {
if (msg.data === "[DONE]" || finished) {
return finish();
}
const text = msg.data;
try {
const json = JSON.parse(text);
const delta = apiClient.extractMessage(json);

if (delta) {
remainText += delta;
}
const [tools, funcs] = usePluginStore
.getState()
.getAsTools(
useChatStore.getState().currentSession().mask?.plugin || [],
);
return stream(
chatPath,
requestPayload,
getHeaders(),
// @ts-ignore
[{ functionDeclarations: tools.map((tool) => tool.function) }],
funcs,
controller,
// parseSSE
(text: string, runTools: ChatMessageTool[]) => {
// console.log("parseSSE", text, runTools);
const chunkJson = JSON.parse(text);

const blockReason = json?.promptFeedback?.blockReason;
if (blockReason) {
// being blocked
console.log(`[Google] [Safety Ratings] result:`, blockReason);
}
} catch (e) {
console.error("[Request] parse error", text, msg);
const functionCall = chunkJson?.candidates
?.at(0)
?.content.parts.at(0)?.functionCall;
if (functionCall) {
const { name, args } = functionCall;
runTools.push({
id: nanoid(),
type: "function",
function: {
name,
arguments: JSON.stringify(args), // utils.chat call function, using JSON.parse
},
});
}
return chunkJson?.candidates?.at(0)?.content.parts.at(0)?.text;
},
onclose() {
finish();
},
onerror(e) {
options.onError?.(e);
throw e;
// processToolMessage, include tool_calls message and tool call results
(
requestPayload: RequestPayload,
toolCallMessage: any,
toolCallResult: any[],
) => {
// @ts-ignore
requestPayload?.contents?.splice(
// @ts-ignore
requestPayload?.contents?.length,
0,
{
role: "model",
parts: toolCallMessage.tool_calls.map(
(tool: ChatMessageTool) => ({
functionCall: {
name: tool?.function?.name,
args: JSON.parse(tool?.function?.arguments as string),
},
}),
),
},
// @ts-ignore
...toolCallResult.map((result) => ({
role: "function",
parts: [
{
functionResponse: {
name: result.name,
response: {
name: result.name,
content: result.content, // TODO just text content...
},
},
},
],
})),
);
},
openWhenHidden: true,
});
options,
);
} else {
const res = await fetch(chatPath, chatPayload);
clearTimeout(requestTimeoutId);
Expand Down
2 changes: 1 addition & 1 deletion app/store/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ export const usePromptStore = createPersistStore(
if (typeof window === "undefined") {
return;
}

const PROMPT_URL = "./prompts.json";

type PromptList = Array<[string, string]>;
Expand Down
3 changes: 3 additions & 0 deletions app/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,9 @@ export function showPlugins(provider: ServiceProvider, model: string) {
if (provider == ServiceProvider.Anthropic && !model.includes("claude-2")) {
return true;
}
if (provider == ServiceProvider.Google && !model.includes("vision")) {
return true;
}
return false;
}

Expand Down
6 changes: 6 additions & 0 deletions app/utils/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,11 @@ export function stream(
)
.then((res) => {
let content = res.data || res?.statusText;
// hotfix #5614
content =
typeof content === "string"
? content
: JSON.stringify(content);
if (res.status >= 300) {
return Promise.reject(content);
}
Expand All @@ -245,6 +250,7 @@ export function stream(
return e.toString();
})
.then((content) => ({
name: tool.function.name,
role: "tool",
content,
tool_call_id: tool.id,
Expand Down
21 changes: 21 additions & 0 deletions jest.config.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import type { Config } from "jest";
import nextJest from "next/jest.js";

const createJestConfig = nextJest({
// Provide the path to your Next.js app to load next.config.js and .env files in your test environment
dir: "./",
});

// Add any custom config to be passed to Jest
const config: Config = {
coverageProvider: "v8",
testEnvironment: "jsdom",
testMatch: ["**/*.test.js", "**/*.test.ts", "**/*.test.jsx", "**/*.test.tsx"],
setupFilesAfterEnv: ["<rootDir>/jest.setup.ts"],
moduleNameMapper: {
"^@/(.*)$": "<rootDir>/$1",
},
};

// createJestConfig is exported this way to ensure that next/jest can load the Next.js config which is async
export default createJestConfig(config);
2 changes: 2 additions & 0 deletions jest.setup.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
// Learn more: https://github.com/testing-library/jest-dom
import "@testing-library/jest-dom";
16 changes: 12 additions & 4 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@
"mask": "npx tsx app/masks/build.ts",
"mask:watch": "npx watch \"yarn mask\" app/masks",
"dev": "concurrently -r \"yarn run mask:watch\" \"next dev\"",
"build": "yarn mask && cross-env BUILD_MODE=standalone next build",
"build": "yarn test:ci && yarn mask && cross-env BUILD_MODE=standalone next build",
"start": "next start",
"lint": "next lint",
"export": "yarn mask && cross-env BUILD_MODE=export BUILD_APP=1 next build",
"export": "yarn test:ci && yarn mask && cross-env BUILD_MODE=export BUILD_APP=1 next build",
"export:dev": "concurrently -r \"yarn mask:watch\" \"cross-env BUILD_MODE=export BUILD_APP=1 next dev\"",
"app:dev": "concurrently -r \"yarn mask:watch\" \"yarn tauri dev\"",
"app:build": "yarn mask && yarn tauri build",
"app:build": "yarn test:ci && yarn mask && yarn tauri build",
"prompts": "node ./scripts/fetch-prompts.mjs",
"prepare": "husky install",
"proxy-dev": "sh ./scripts/init-proxy.sh && proxychains -f ./scripts/proxychains.conf yarn dev"
"proxy-dev": "sh ./scripts/init-proxy.sh && proxychains -f ./scripts/proxychains.conf yarn dev",
"test": "jest --watch",
"test:ci": "jest --ci"
},
"dependencies": {
"@fortaine/fetch-event-source": "^3.0.6",
Expand Down Expand Up @@ -54,6 +56,9 @@
"devDependencies": {
"@tauri-apps/api": "^1.6.0",
"@tauri-apps/cli": "1.5.11",
"@testing-library/jest-dom": "^6.4.8",
"@testing-library/react": "^16.0.0",
"@types/jest": "^29.5.12",
"@types/js-yaml": "4.0.9",
"@types/lodash-es": "^4.17.12",
"@types/node": "^20.11.30",
Expand All @@ -69,8 +74,11 @@
"eslint-plugin-prettier": "^5.1.3",
"eslint-plugin-unused-imports": "^3.2.0",
"husky": "^8.0.0",
"jest": "^29.7.0",
"jest-environment-jsdom": "^29.7.0",
"lint-staged": "^13.2.2",
"prettier": "^3.0.2",
"ts-node": "^10.9.2",
"tsx": "^4.16.0",
"typescript": "5.2.2",
"watch": "^1.0.2",
Expand Down
9 changes: 9 additions & 0 deletions test/sum-module.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
function sum(a: number, b: number) {
return a + b;
}

describe("sum module", () => {
test("adds 1 + 2 to equal 3", () => {
expect(sum(1, 2)).toBe(3);
});
});
Loading

0 comments on commit 36d5609

Please sign in to comment.