Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Allow LLM model to be configurable #42

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@
"dependencies": {
"@arcadeai/arcadejs": "^0.1.2",
"@googleapis/youtube": "^20.0.0",
"@langchain/anthropic": "^0.3.9",
"@langchain/anthropic": "^0.3.11",
"@langchain/community": "^0.3.22",
"@langchain/core": "^0.3.22",
"@langchain/core": "^0.3.27",
"@langchain/google-vertexai-web": "^0.1.2",
"@langchain/langgraph": "^0.2.31",
"@langchain/langgraph-sdk": "^0.0.31",
Expand All @@ -54,6 +54,7 @@
"express-session": "^1.18.1",
"file-type": "^19.6.0",
"google-auth-library": "^9.15.0",
"langchain": "^0.3.10",
"langsmith": "0.2.15-rc.2",
"moment": "^2.30.1",
"passport": "^0.7.0",
Expand Down
11 changes: 0 additions & 11 deletions scripts/generate-demo-post.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
import "dotenv/config";
import { Client } from "@langchain/langgraph-sdk";
// import {
// LINKEDIN_USER_ID,
// TWITTER_USER_ID,
// } from "../src/agents/generate-post/constants.js";

/**
* Generate a post based on the Open Canvas project.
Expand All @@ -22,13 +18,6 @@ async function invokeGraph() {
input: {
links: [link],
},
config: {
configurable: {
// By default, the graph will read these values from the environment
// [TWITTER_USER_ID]: process.env.TWITTER_USER_ID,
// [LINKEDIN_USER_ID]: process.env.LINKEDIN_USER_ID,
},
},
});
}

Expand Down
2 changes: 2 additions & 0 deletions src/agents/generate-post/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,5 @@ export const TWITTER_USER_ID = "twitterUserId";
export const TWITTER_TOKEN = "twitterToken";
export const TWITTER_TOKEN_SECRET = "twitterTokenSecret";
export const INGEST_TWITTER_USERNAME = "ingestTwitterUsername";

// Key under `config.configurable` that selects which LLM model to use for
// generations (e.g. "gemini-2.0-flash-exp" or "claude-3-5-sonnet-latest");
// read via `config.configurable?.[LLM_MODEL_NAME]` and declared as a
// configurable annotation field with a "gemini-2.0-flash-exp" default.
export const LLM_MODEL_NAME = "llmModel";
12 changes: 11 additions & 1 deletion src/agents/generate-post/generate-post-state.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { Annotation, END } from "@langchain/langgraph";
import { IngestDataAnnotation } from "../ingest-data/ingest-data-state.js";
import { POST_TO_LINKEDIN_ORGANIZATION } from "./constants.js";
import { LLM_MODEL_NAME, POST_TO_LINKEDIN_ORGANIZATION } from "./constants.js";
import { DateType } from "../types.js";

export type LangChainProduct = "langchain" | "langgraph" | "langsmith";
Expand Down Expand Up @@ -102,4 +102,14 @@ export const GeneratePostConfigurableAnnotation = Annotation.Root({
* If true, [LINKEDIN_ORGANIZATION_ID] is required.
*/
[POST_TO_LINKEDIN_ORGANIZATION]: Annotation<boolean | undefined>,
/**
* The name of the LLM to use for generations
* @default "gemini-2.0-flash-exp"
*/
[LLM_MODEL_NAME]: Annotation<
"gemini-2.0-flash-exp" | "claude-3-5-sonnet-latest" | undefined
>({
reducer: (_state, update) => update,
default: () => "gemini-2.0-flash-exp",
}),
});
8 changes: 4 additions & 4 deletions src/agents/generate-post/nodes/condense-post.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { ChatAnthropic } from "@langchain/anthropic";
import { GeneratePostAnnotation } from "../generate-post-state.js";
import { STRUCTURE_INSTRUCTIONS, RULES } from "./geterate-post/prompts.js";
import { parseGeneration } from "./geterate-post/utils.js";
import { removeUrls } from "../../utils.js";
import { getModelFromConfig, removeUrls } from "../../utils.js";
import { LangGraphRunnableConfig } from "@langchain/langgraph";

const CONDENSE_POST_PROMPT = `You're a highly skilled marketer at LangChain, working on crafting thoughtful and engaging content for LangChain's LinkedIn and Twitter pages.
You wrote a post for the LangChain LinkedIn and Twitter pages, however it's a bit too long for Twitter, and thus needs to be condensed.
Expand Down Expand Up @@ -52,6 +52,7 @@ Follow all rules and instructions outlined above. The user message below will pr
*/
export async function condensePost(
state: typeof GeneratePostAnnotation.State,
config: LangGraphRunnableConfig,
): Promise<Partial<typeof GeneratePostAnnotation.State>> {
if (!state.post) {
throw new Error("No post found");
Expand All @@ -72,8 +73,7 @@ export async function condensePost(
.replace("{link}", state.relevantLinks[0])
.replace("{originalPostLength}", originalPostLength);

const condensePostModel = new ChatAnthropic({
model: "claude-3-5-sonnet-20241022",
const condensePostModel = await getModelFromConfig(config, {
temperature: 0.5,
});

Expand Down
7 changes: 3 additions & 4 deletions src/agents/generate-post/nodes/generate-content-report.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { LangGraphRunnableConfig } from "@langchain/langgraph";
import { GeneratePostAnnotation } from "../generate-post-state.js";
import { LANGCHAIN_PRODUCTS_CONTEXT } from "../prompts.js";
import { ChatAnthropic } from "@langchain/anthropic";
import { getModelFromConfig } from "../../utils.js";

const GENERATE_REPORT_PROMPT = `You are a highly regarded marketing employee at LangChain.
You have been tasked with writing a marketing report on content submitted to you from a third party which uses LangChain's products.
Expand Down Expand Up @@ -85,10 +85,9 @@ ${pageContents.map((content, index) => `<Content index={${index + 1}}>\n${conten

export async function generateContentReport(
state: typeof GeneratePostAnnotation.State,
_config: LangGraphRunnableConfig,
config: LangGraphRunnableConfig,
): Promise<Partial<typeof GeneratePostAnnotation.State>> {
const reportModel = new ChatAnthropic({
model: "claude-3-5-sonnet-20241022",
const reportModel = await getModelFromConfig(config, {
temperature: 0,
});

Expand Down
5 changes: 2 additions & 3 deletions src/agents/generate-post/nodes/geterate-post/index.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import { LangGraphRunnableConfig } from "@langchain/langgraph";
import { GeneratePostAnnotation } from "../../generate-post-state.js";
import { ChatAnthropic } from "@langchain/anthropic";
import { GENERATE_POST_PROMPT, REFLECTIONS_PROMPT } from "./prompts.js";
import { formatPrompt, parseGeneration } from "./utils.js";
import { ALLOWED_TIMES } from "../../constants.js";
import { getReflections, RULESET_KEY } from "../../../../utils/reflections.js";
import { getNextSaturdayDate } from "../../../../utils/date.js";
import { getModelFromConfig } from "../../../utils.js";

export async function generatePost(
state: typeof GeneratePostAnnotation.State,
Expand All @@ -17,8 +17,7 @@ export async function generatePost(
if (state.relevantLinks.length === 0) {
throw new Error("No relevant links found");
}
const postModel = new ChatAnthropic({
model: "claude-3-5-sonnet-20241022",
const postModel = await getModelFromConfig(config, {
temperature: 0.5,
});

Expand Down
3 changes: 2 additions & 1 deletion src/agents/generate-post/nodes/human-node/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ Here is the report that was generated for the posts:\n${report}

export async function humanNode(
state: typeof GeneratePostAnnotation.State,
_config: LangGraphRunnableConfig,
config: LangGraphRunnableConfig,
): Promise<Partial<typeof GeneratePostAnnotation.State>> {
if (!state.post) {
throw new Error("No post found");
Expand Down Expand Up @@ -153,6 +153,7 @@ export async function humanNode(
post: state.post,
dateOrPriority: defaultDateString,
userResponse: response.args,
config,
});

if (route === "rewrite_post") {
Expand Down
8 changes: 5 additions & 3 deletions src/agents/generate-post/nodes/human-node/route-response.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { ChatAnthropic } from "@langchain/anthropic";
import { z } from "zod";
import { getModelFromConfig } from "../../../utils.js";
import { LangGraphRunnableConfig } from "@langchain/langgraph";

const ROUTE_RESPONSE_PROMPT = `You are an AI assistant tasked with routing a user's response to one of two possible routes based on their intention. The two possible routes are:

Expand Down Expand Up @@ -55,15 +56,16 @@ interface RouteResponseArgs {
post: string;
dateOrPriority: string;
userResponse: string;
config: LangGraphRunnableConfig;
}

export async function routeResponse({
post,
dateOrPriority,
userResponse,
config,
}: RouteResponseArgs) {
const model = new ChatAnthropic({
model: "claude-3-5-sonnet-20241022",
const model = await getModelFromConfig(config, {
temperature: 0,
});

Expand Down
7 changes: 3 additions & 4 deletions src/agents/generate-post/nodes/rewrite-post.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { Client } from "@langchain/langgraph-sdk";
import { LangGraphRunnableConfig } from "@langchain/langgraph";
import { GeneratePostAnnotation } from "../generate-post-state.js";
import { ChatAnthropic } from "@langchain/anthropic";
import { getModelFromConfig } from "../../utils.js";

const REWRITE_POST_PROMPT = `You're a highly regarded marketing employee at LangChain, working on crafting thoughtful and engaging content for LangChain's LinkedIn and Twitter pages.
You wrote a post for the LangChain LinkedIn and Twitter pages, however your boss has asked for some changes to be made before it can be published.
Expand Down Expand Up @@ -44,7 +44,7 @@ async function runReflections({

export async function rewritePost(
state: typeof GeneratePostAnnotation.State,
_config: LangGraphRunnableConfig,
config: LangGraphRunnableConfig,
): Promise<Partial<typeof GeneratePostAnnotation.State>> {
if (!state.post) {
throw new Error("No post found");
Expand All @@ -53,8 +53,7 @@ export async function rewritePost(
throw new Error("No user response found");
}

const rewritePostModel = new ChatAnthropic({
model: "claude-3-5-sonnet-20241022",
const rewritePostModel = await getModelFromConfig(config, {
temperature: 0.5,
});

Expand Down
7 changes: 6 additions & 1 deletion src/agents/generate-post/nodes/schedule-post/index.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import { LangGraphRunnableConfig } from "@langchain/langgraph";
import { GeneratePostAnnotation } from "../../generate-post-state.js";
import { Client } from "@langchain/langgraph-sdk";
import { POST_TO_LINKEDIN_ORGANIZATION } from "../../constants.js";
import {
LLM_MODEL_NAME,
POST_TO_LINKEDIN_ORGANIZATION,
} from "../../constants.js";
import { getScheduledDateSeconds } from "./find-date.js";
import { SlackClient } from "../../../../clients/slack.js";
import { getFutureDate } from "./get-future-date.js";
Expand Down Expand Up @@ -40,6 +43,8 @@ export async function schedulePost(
[POST_TO_LINKEDIN_ORGANIZATION]:
config.configurable?.[POST_TO_LINKEDIN_ORGANIZATION] ||
process.env.POST_TO_LINKEDIN_ORGANIZATION,
[LLM_MODEL_NAME]:
config.configurable?.[LLM_MODEL_NAME] || "gemini-2.0-flash-exp",
},
},
afterSeconds,
Expand Down
13 changes: 8 additions & 5 deletions src/agents/generate-post/nodes/update-scheduled-date.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import { z } from "zod";
import { GeneratePostAnnotation } from "../generate-post-state.js";
import { ChatAnthropic } from "@langchain/anthropic";
import { toZonedTime } from "date-fns-tz";
import { DateType } from "../../types.js";
import { timezoneToUtc } from "../../../utils/date.js";
import { getModelFromConfig } from "../../utils.js";
import { LangGraphRunnableConfig } from "@langchain/langgraph";

const SCHEDULE_POST_DATE_PROMPT = `You're an intelligent AI assistant tasked with extracting the date to schedule a social media post from the user's message.

Expand Down Expand Up @@ -40,14 +41,16 @@ const scheduleDateSchema = z.object({

export async function updateScheduledDate(
state: typeof GeneratePostAnnotation.State,
config: LangGraphRunnableConfig,
): Promise<Partial<typeof GeneratePostAnnotation.State>> {
if (!state.userResponse) {
throw new Error("No user response found");
}
const model = new ChatAnthropic({
model: "claude-3-5-sonnet-20241022",
temperature: 0.5,
}).withStructuredOutput(scheduleDateSchema, {
const model = (
await getModelFromConfig(config, {
temperature: 0.5,
})
).withStructuredOutput(scheduleDateSchema, {
name: "scheduleDate",
});
const pstDate = toZonedTime(new Date(), "America/Los_Angeles");
Expand Down
7 changes: 6 additions & 1 deletion src/agents/ingest-data/ingest-data-graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ import {
} from "./ingest-data-state.js";
import { ingestSlackData } from "./nodes/ingest-slack.js";
import { Client } from "@langchain/langgraph-sdk";
import { POST_TO_LINKEDIN_ORGANIZATION } from "../generate-post/constants.js";
import {
LLM_MODEL_NAME,
POST_TO_LINKEDIN_ORGANIZATION,
} from "../generate-post/constants.js";
import { getUrlType } from "../utils.js";

/**
Expand Down Expand Up @@ -69,6 +72,8 @@ async function generatePostFromMessages(
config: {
configurable: {
[POST_TO_LINKEDIN_ORGANIZATION]: postToLinkedInOrg,
[LLM_MODEL_NAME]:
config.configurable?.[LLM_MODEL_NAME] || "gemini-2.0-flash-exp",
},
},
afterSeconds,
Expand Down
15 changes: 14 additions & 1 deletion src/agents/ingest-data/ingest-data-state.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import { Annotation } from "@langchain/langgraph";
import { SimpleSlackMessage } from "../../clients/slack.js";
import { POST_TO_LINKEDIN_ORGANIZATION } from "../generate-post/constants.js";
import {
LLM_MODEL_NAME,
POST_TO_LINKEDIN_ORGANIZATION,
} from "../generate-post/constants.js";

export type LangChainProduct = "langchain" | "langgraph" | "langsmith";
export type SimpleSlackMessageWithLinks = SimpleSlackMessage & {
Expand Down Expand Up @@ -50,4 +53,14 @@ export const IngestDataConfigurableAnnotation = Annotation.Root({
* If true, [LINKEDIN_ORGANIZATION_ID] is required.
*/
[POST_TO_LINKEDIN_ORGANIZATION]: Annotation<boolean | undefined>,
/**
* The name of the LLM to use for generations
* @default "gemini-2.0-flash-exp"
*/
[LLM_MODEL_NAME]: Annotation<
"gemini-2.0-flash-exp" | "claude-3-5-sonnet-latest" | undefined
>({
reducer: (_state, update) => update,
default: () => "gemini-2.0-flash-exp",
}),
});
5 changes: 2 additions & 3 deletions src/agents/reflection/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ import {
StateGraph,
} from "@langchain/langgraph";
import { z } from "zod";
import { ChatAnthropic } from "@langchain/anthropic";
import {
getReflections,
putReflections,
RULESET_KEY,
} from "../../utils/reflections.js";
import { REFLECTION_PROMPT, UPDATE_RULES_PROMPT } from "./prompts.js";
import { getModelFromConfig } from "../utils.js";

const newRuleSchema = z.object({
newRule: z.string().describe("The new rule to create."),
Expand Down Expand Up @@ -47,8 +47,7 @@ async function reflection(
if (!config.store) {
throw new Error("No store provided");
}
const model = new ChatAnthropic({
model: "claude-3-5-sonnet-latest",
const model = await getModelFromConfig(config, {
temperature: 0,
});

Expand Down
20 changes: 12 additions & 8 deletions src/agents/shared/nodes/verify-general.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import { LangGraphRunnableConfig } from "@langchain/langgraph";
import { GeneratePostAnnotation } from "../../generate-post/generate-post-state.js";
import { z } from "zod";
import { ChatAnthropic } from "@langchain/anthropic";
import { FireCrawlLoader } from "@langchain/community/document_loaders/web/firecrawl";
import { LANGCHAIN_PRODUCTS_CONTEXT } from "../../generate-post/prompts.js";
import { VerifyContentAnnotation } from "../shared-state.js";
import { RunnableLambda } from "@langchain/core/runnables";
import { getPageText } from "../../utils.js";
import { getModelFromConfig, getPageText } from "../../utils.js";

type VerifyGeneralContentReturn = {
relevantLinks: (typeof GeneratePostAnnotation.State)["relevantLinks"];
Expand Down Expand Up @@ -87,11 +86,13 @@ export async function getUrlContents(url: string): Promise<UrlContents> {

export async function verifyGeneralContentIsRelevant(
content: string,
config: LangGraphRunnableConfig,
): Promise<boolean> {
const relevancyModel = new ChatAnthropic({
model: "claude-3-5-sonnet-20241022",
temperature: 0,
}).withStructuredOutput(RELEVANCY_SCHEMA, {
const relevancyModel = (
await getModelFromConfig(config, {
temperature: 0,
})
).withStructuredOutput(RELEVANCY_SCHEMA, {
name: "relevancy",
});

Expand Down Expand Up @@ -125,14 +126,17 @@ export async function verifyGeneralContentIsRelevant(
*/
export async function verifyGeneralContent(
state: typeof VerifyContentAnnotation.State,
_config: LangGraphRunnableConfig,
config: LangGraphRunnableConfig,
): Promise<VerifyGeneralContentReturn> {
const urlContents = await new RunnableLambda<string, UrlContents>({
func: getUrlContents,
})
.withConfig({ runName: "get-url-contents" })
.invoke(state.link);
const relevant = await verifyGeneralContentIsRelevant(urlContents.content);
const relevant = await verifyGeneralContentIsRelevant(
urlContents.content,
config,
);

if (relevant) {
return {
Expand Down
Loading
Loading