diff --git a/.changeset/spicy-singers-flow.md b/.changeset/spicy-singers-flow.md
new file mode 100644
index 00000000..6c09c5cf
--- /dev/null
+++ b/.changeset/spicy-singers-flow.md
@@ -0,0 +1,5 @@
+---
+"@browserbasehq/stagehand": minor
+---
+
+Exposed `llmClient` in the Stagehand constructor
diff --git a/README.md b/README.md
index 141bde86..b75fe2dc 100644
--- a/README.md
+++ b/README.md
@@ -150,6 +150,7 @@ This constructor is used to create an instance of Stagehand.
     - `1`: SDK-level logging
     - `2`: LLM-client level logging (most granular)
   - `debugDom`: a `boolean` that draws bounding boxes around elements presented to the LLM during automation.
+  - `llmClient`: (optional) a custom `LLMClient` implementation; when provided, it is used instead of the client resolved from `modelName`.
 
 - **Returns:**
 
diff --git a/examples/external_client.ts b/examples/external_client.ts
new file mode 100644
index 00000000..26c40661
--- /dev/null
+++ b/examples/external_client.ts
@@ -0,0 +1,46 @@
+import { type ConstructorParams, type LogLine, Stagehand } from "../lib";
+import { z } from "zod";
+import { OllamaClient } from "./external_clients/ollama";
+
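+// NOTE: this example assumes Browserbase credentials in BROWSERBASE_API_KEY /
+// BROWSERBASE_PROJECT_ID and a local Ollama server (default
+// http://localhost:11434) with the llama3.2 model already pulled
+// (`ollama pull llama3.2`).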
"ollama", + }); + this.logger = logger; + this.cache = cache; + this.enableCaching = enableCaching; + this.modelName = modelName as AvailableModel; + } + + async createChatCompletion( + options: ChatCompletionOptions, + retries = 3, + ): Promise { + const { image, requestId, ...optionsWithoutImageAndRequestId } = options; + + // TODO: Implement vision support + if (image) { + throw new Error( + "Image provided. Vision is not currently supported for Ollama", + ); + } + + this.logger({ + category: "ollama", + message: "creating chat completion", + level: 1, + auxiliary: { + options: { + value: JSON.stringify({ + ...optionsWithoutImageAndRequestId, + requestId, + }), + type: "object", + }, + modelName: { + value: this.modelName, + type: "string", + }, + }, + }); + + const cacheOptions = { + model: this.modelName, + messages: options.messages, + temperature: options.temperature, + top_p: options.top_p, + frequency_penalty: options.frequency_penalty, + presence_penalty: options.presence_penalty, + image: image, + response_model: options.response_model, + }; + + if (options.image) { + const screenshotMessage: ChatMessage = { + role: "user", + content: [ + { + type: "image_url", + image_url: { + url: `data:image/jpeg;base64,${options.image.buffer.toString("base64")}`, + }, + }, + ...(options.image.description + ? [{ type: "text", text: options.image.description }] + : []), + ], + }; + + options.messages.push(screenshotMessage); + } + + if (this.enableCaching && this.cache) { + const cachedResponse = await this.cache.get( + cacheOptions, + options.requestId, + ); + + if (cachedResponse) { + this.logger({ + category: "llm_cache", + message: "LLM cache hit - returning cached response", + level: 1, + auxiliary: { + requestId: { + value: options.requestId, + type: "string", + }, + cachedResponse: { + value: JSON.stringify(cachedResponse), + type: "object", + }, + }, + }); + return cachedResponse; + } + + this.logger({ + category: "llm_cache", + message: "LLM cache miss - no cached response found", + level: 1, + auxiliary: { + requestId: { + value: options.requestId, + type: "string", + }, + }, + }); + } + + let responseFormat = undefined; + if (options.response_model) { + responseFormat = zodResponseFormat( + options.response_model.schema, + options.response_model.name, + ); + } + + /* eslint-disable */ + // Remove unsupported options + const { response_model, ...ollamaOptions } = { + ...optionsWithoutImageAndRequestId, + model: this.modelName, + }; + + this.logger({ + category: "ollama", + message: "creating chat completion", + level: 1, + auxiliary: { + ollamaOptions: { + value: JSON.stringify(ollamaOptions), + type: "object", + }, + }, + }); + + const formattedMessages: ChatCompletionMessageParam[] = + options.messages.map((message) => { + if (Array.isArray(message.content)) { + const contentParts = message.content.map((content) => { + if ("image_url" in content) { + const imageContent: ChatCompletionContentPartImage = { + image_url: { + url: content.image_url.url, + }, + type: "image_url", + }; + return imageContent; + } else { + const textContent: ChatCompletionContentPartText = { + text: content.text, + type: "text", + }; + return textContent; + } + }); + + if (message.role === "system") { + const formattedMessage: ChatCompletionSystemMessageParam = { + ...message, + role: "system", + content: contentParts.filter( + (content): content is ChatCompletionContentPartText => + content.type === "text", + ), + }; + return formattedMessage; + } else if (message.role === "user") { + const 
+export class OllamaClient extends LLMClient {
+  public type = "ollama" as const;
+  private client: OpenAI;
+  private cache: LLMCache | undefined;
+  public logger: (message: LogLine) => void;
+  private enableCaching: boolean;
+  public clientOptions: ClientOptions;
+
+  constructor(
+    logger: (message: LogLine) => void,
+    enableCaching = false,
+    cache: LLMCache | undefined,
+    modelName: "llama3.2",
+    clientOptions?: ClientOptions,
+  ) {
+    super(modelName as AvailableModel);
+    this.client = new OpenAI({
+      ...clientOptions,
+      baseURL: clientOptions?.baseURL || "http://localhost:11434/v1",
+      apiKey: "ollama",
+    });
+    this.logger = logger;
+    this.cache = cache;
+    this.enableCaching = enableCaching;
+    this.modelName = modelName as AvailableModel;
+  }
+
+  async createChatCompletion<T = ChatCompletion>(
+    options: ChatCompletionOptions,
+    retries = 3,
+  ): Promise<T> {
+    const { image, requestId, ...optionsWithoutImageAndRequestId } = options;
+
+    // TODO: Implement vision support
+    if (image) {
+      throw new Error(
+        "Image provided. Vision is not currently supported for Ollama",
+      );
+    }
+
+    this.logger({
+      category: "ollama",
+      message: "creating chat completion",
+      level: 1,
+      auxiliary: {
+        options: {
+          value: JSON.stringify({
+            ...optionsWithoutImageAndRequestId,
+            requestId,
+          }),
+          type: "object",
+        },
+        modelName: {
+          value: this.modelName,
+          type: "string",
+        },
+      },
+    });
+
+    const cacheOptions = {
+      model: this.modelName,
+      messages: options.messages,
+      temperature: options.temperature,
+      top_p: options.top_p,
+      frequency_penalty: options.frequency_penalty,
+      presence_penalty: options.presence_penalty,
+      image: image,
+      response_model: options.response_model,
+    };
+
+    if (options.image) {
+      const screenshotMessage: ChatMessage = {
+        role: "user",
+        content: [
+          {
+            type: "image_url",
+            image_url: {
+              url: `data:image/jpeg;base64,${options.image.buffer.toString("base64")}`,
+            },
+          },
+          ...(options.image.description
+            ? [{ type: "text", text: options.image.description }]
+            : []),
+        ],
+      };
+
+      options.messages.push(screenshotMessage);
+    }
+
+    if (this.enableCaching && this.cache) {
+      const cachedResponse = await this.cache.get(
+        cacheOptions,
+        options.requestId,
+      );
+
+      if (cachedResponse) {
+        this.logger({
+          category: "llm_cache",
+          message: "LLM cache hit - returning cached response",
+          level: 1,
+          auxiliary: {
+            requestId: {
+              value: options.requestId,
+              type: "string",
+            },
+            cachedResponse: {
+              value: JSON.stringify(cachedResponse),
+              type: "object",
+            },
+          },
+        });
+        return cachedResponse;
+      }
+
+      this.logger({
+        category: "llm_cache",
+        message: "LLM cache miss - no cached response found",
+        level: 1,
+        auxiliary: {
+          requestId: {
+            value: options.requestId,
+            type: "string",
+          },
+        },
+      });
+    }
+
+    let responseFormat = undefined;
+    if (options.response_model) {
+      responseFormat = zodResponseFormat(
+        options.response_model.schema,
+        options.response_model.name,
+      );
+    }
+
+    /* eslint-disable */
+    // Remove unsupported options
+    const { response_model, ...ollamaOptions } = {
+      ...optionsWithoutImageAndRequestId,
+      model: this.modelName,
+    };
+
+    this.logger({
+      category: "ollama",
+      message: "creating chat completion",
+      level: 1,
+      auxiliary: {
+        ollamaOptions: {
+          value: JSON.stringify(ollamaOptions),
+          type: "object",
+        },
+      },
+    });
+
+    const formattedMessages: ChatCompletionMessageParam[] =
+      options.messages.map((message) => {
+        if (Array.isArray(message.content)) {
+          const contentParts = message.content.map((content) => {
+            if ("image_url" in content) {
+              const imageContent: ChatCompletionContentPartImage = {
+                image_url: {
+                  url: content.image_url.url,
+                },
+                type: "image_url",
+              };
+              return imageContent;
+            } else {
+              const textContent: ChatCompletionContentPartText = {
+                text: content.text,
+                type: "text",
+              };
+              return textContent;
+            }
+          });
+
+          if (message.role === "system") {
+            const formattedMessage: ChatCompletionSystemMessageParam = {
+              ...message,
+              role: "system",
+              content: contentParts.filter(
+                (content): content is ChatCompletionContentPartText =>
+                  content.type === "text",
+              ),
+            };
+            return formattedMessage;
+          } else if (message.role === "user") {
+            const formattedMessage: ChatCompletionUserMessageParam = {
+              ...message,
+              role: "user",
+              content: contentParts,
+            };
+            return formattedMessage;
+          } else {
+            const formattedMessage: ChatCompletionAssistantMessageParam = {
+              ...message,
+              role: "assistant",
+              content: contentParts.filter(
+                (content): content is ChatCompletionContentPartText =>
+                  content.type === "text",
+              ),
+            };
+            return formattedMessage;
+          }
+        }
+
+        const formattedMessage: ChatCompletionUserMessageParam = {
+          role: "user",
+          content: message.content,
+        };
+
+        return formattedMessage;
+      });
+
+    const body: ChatCompletionCreateParamsNonStreaming = {
+      ...ollamaOptions,
+      model: this.modelName,
+      messages: formattedMessages,
+      response_format: responseFormat,
+      stream: false,
+      tools: options.tools?.filter((tool) => "function" in tool), // ensure only OpenAI-compatible tools are used
+    };
+
+    const response = await this.client.chat.completions.create(body);
+
+    this.logger({
+      category: "ollama",
+      message: "response",
+      level: 1,
+      auxiliary: {
+        response: {
+          value: JSON.stringify(response),
+          type: "object",
+        },
+        requestId: {
+          value: requestId,
+          type: "string",
+        },
+      },
+    });
+
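+    // When a response_model is provided, the model's text output is parsed as
+    // JSON and validated against the Zod schema; on a mismatch the call retries
+    // itself (up to `retries` times) before throwing.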
+    if (options.response_model) {
+      const extractedData = response.choices[0].message.content;
+      const parsedData = JSON.parse(extractedData);
+
+      if (!validateZodSchema(options.response_model.schema, parsedData)) {
+        if (retries > 0) {
+          return this.createChatCompletion(options, retries - 1);
+        }
+
+        throw new Error("Invalid response schema");
+      }
+
+      if (this.enableCaching) {
+        this.cache.set(
+          cacheOptions,
+          {
+            ...parsedData,
+          },
+          options.requestId,
+        );
+      }
+
+      return parsedData;
+    }
+
+    if (this.enableCaching) {
+      this.logger({
+        category: "llm_cache",
+        message: "caching response",
+        level: 1,
+        auxiliary: {
+          requestId: {
+            value: options.requestId,
+            type: "string",
+          },
+          cacheOptions: {
+            value: JSON.stringify(cacheOptions),
+            type: "object",
+          },
+          response: {
+            value: JSON.stringify(response),
+            type: "object",
+          },
+        },
+      });
+      this.cache.set(cacheOptions, response, options.requestId);
+    }
+
+    return response as T;
+  }
+}
diff --git a/lib/index.ts b/lib/index.ts
index 9d8527c5..5bbc80e1 100644
--- a/lib/index.ts
+++ b/lib/index.ts
@@ -341,6 +341,7 @@ export class Stagehand {
       verbose,
       debugDom,
       llmProvider,
+      llmClient,
       headless,
       logger,
       browserbaseSessionCreateParams,
@@ -365,7 +366,7 @@ export class Stagehand {
     this.projectId = projectId ?? process.env.BROWSERBASE_PROJECT_ID;
     this.verbose = verbose ?? 0;
     this.debugDom = debugDom ?? false;
-    this.llmClient = this.llmProvider.getClient(
+    this.llmClient = llmClient || this.llmProvider.getClient(
       modelName ?? DEFAULT_MODEL_NAME,
       modelClientOptions,
     );
diff --git a/lib/llm/LLMClient.ts b/lib/llm/LLMClient.ts
index 254454ca..c93fab49 100644
--- a/lib/llm/LLMClient.ts
+++ b/lib/llm/LLMClient.ts
@@ -65,7 +65,7 @@ export interface ChatCompletionOptions {
 export type LLMResponse = AnthropicTransformedResponse | ChatCompletion;
 
 export abstract class LLMClient {
-  public type: "openai" | "anthropic";
+  public type: "openai" | "anthropic" | string;
   public modelName: AvailableModel;
   public hasVision: boolean;
   public clientOptions: ClientOptions;
diff --git a/package.json b/package.json
index 690f7bd9..fde65cc9 100644
--- a/package.json
+++ b/package.json
@@ -9,6 +9,7 @@
     "2048": "npm run build-dom-scripts && tsx examples/2048.ts",
     "example": "npm run build-dom-scripts && tsx examples/example.ts",
     "debug-url": "npm run build-dom-scripts && tsx examples/debugUrl.ts",
+    "external-client": "npm run build-dom-scripts && tsx examples/external_client.ts",
    "format": "prettier --write .",
    "prettier": "prettier --check .",
    "prettier:fix": "prettier --write .",
diff --git a/types/stagehand.ts b/types/stagehand.ts
index 612fc4df..7eeecb56 100644
--- a/types/stagehand.ts
+++ b/types/stagehand.ts
@@ -4,6 +4,7 @@ import { z } from "zod";
 import { LLMProvider } from "../lib/llm/LLMProvider";
 import { LogLine } from "./log";
 import { AvailableModel, ClientOptions } from "./model";
+import { LLMClient } from "../lib/llm/LLMClient";
 
 export interface ConstructorParams {
   env: "LOCAL" | "BROWSERBASE";
@@ -19,6 +20,7 @@
   enableCaching?: boolean;
   browserbaseSessionID?: string;
   modelName?: AvailableModel;
+  llmClient?: LLMClient;
   modelClientOptions?: ClientOptions;
 }