-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathagent.ts
146 lines (124 loc) · 4.85 KB
/
agent.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import { OpenAIEmbeddings } from "@langchain/openai";
import { ChatOpenAI } from "@langchain/openai";
import { AIMessage, BaseMessage, HumanMessage } from "@langchain/core/messages";
import {
ChatPromptTemplate,
MessagesPlaceholder,
} from "@langchain/core/prompts";
import { StateGraph } from "@langchain/langgraph";
import { Annotation } from "@langchain/langgraph";
import { tool } from "@langchain/core/tools";
import { ToolNode } from "@langchain/langgraph/prebuilt";
import { MongoDBSaver } from "@langchain/langgraph-checkpoint-mongodb";
import { MongoDBAtlasVectorSearch } from "@langchain/mongodb";
import { MongoClient } from "mongodb";
import { z } from "zod";
import "dotenv/config";
/**
 * Runs the HR chatbot agent for one user query and returns the content of the
 * last message in the resulting graph state.
 *
 * On every call this builds a two-node LangGraph workflow:
 * - `agent` invokes a chat model served through OpenRouter, with tools bound;
 * - `tools` executes the `employee_lookup` tool, a vector similarity search
 *   over the `employees` collection of the `hr_database` database.
 *
 * The compiled graph uses a `MongoDBSaver` checkpointer on the same client and
 * database, and `thread_id` is passed to it as `configurable.thread_id`.
 *
 * Reads `OPENROUTER_API_KEY` (chat model) and `OPENAI_API_KEY` (embeddings)
 * from the environment.
 *
 * @param client - MongoDB client used for the employee collection and the checkpointer.
 * @param query - User message; sent to the graph as a single `HumanMessage`.
 * @param thread_id - Identifier forwarded to the checkpointer as `configurable.thread_id`.
 * @returns The `content` of the final message in the graph state.
 */
export async function callAgent(client: MongoClient, query: string, thread_id: string) {
  // The employee records and the graph checkpoints share this database.
  const dbName = "hr_database";
  const collection = client.db(dbName).collection("employees");

  // Graph state: a message list; each node's returned messages are appended.
  const GraphState = Annotation.Root({
    messages: Annotation<BaseMessage[]>({
      reducer: (x, y) => x.concat(y),
    }),
  });

  // Chat model reached through OpenRouter's endpoint.
  const model = new ChatOpenAI({
    configuration: {
      baseURL: "https://openrouter.ai/api/v1",
      apiKey: process.env.OPENROUTER_API_KEY,
      defaultHeaders: {
        "HTTP-Referer": "http://localhost:3000",
        "X-Title": "HR Chatbot Agent",
      },
    },
    modelName: "mistralai/mistral-7b-instruct",
    temperature: 0,
    maxTokens: 1000,
    maxRetries: 3,
  });

  // Builds the vector store over the employees collection. Embeddings are
  // requested with the OpenAI key, not the OpenRouter one.
  const createVectorStore = () =>
    new MongoDBAtlasVectorSearch(
      new OpenAIEmbeddings({
        openAIApiKey: process.env.OPENAI_API_KEY, // Note: Still need OpenAI for embeddings
      }),
      {
        collection: collection,
        indexName: "vector_index",
        textKey: "embedding_text",
        embeddingKey: "embedding",
      }
    );
  // Created on the first tool call and reused afterwards, so repeated lookups
  // within one run do not rebuild the embeddings client and store. Kept lazy so
  // that construction errors still occur inside the tool call, as before.
  let vectorStore: ReturnType<typeof createVectorStore> | undefined;

  // Tool exposed to the model: similarity search returning JSON-encoded
  // [document, score] results.
  const employeeLookupTool = tool(
    async ({ query: searchQuery, n = 10 }) => {
      console.log("Employee lookup tool called");
      vectorStore = vectorStore ?? createVectorStore();
      const result = await vectorStore.similaritySearchWithScore(searchQuery, n);
      return JSON.stringify(result);
    },
    {
      name: "employee_lookup",
      description: "Gathers employee details from the HR database",
      schema: z.object({
        query: z.string().describe("The search query"),
        n: z
          .number()
          .optional()
          .default(10)
          .describe("Number of results to return"),
      }),
    }
  );

  const tools = [employeeLookupTool];
  const toolNode = new ToolNode<typeof GraphState.State>(tools);

  // Bind tools to the model so it can emit tool calls.
  const modelWithTools = model.bindTools(tools);

  // The prompt template and the tool-name list do not change between model
  // steps, so they are built once here rather than inside callModel.
  const prompt = ChatPromptTemplate.fromMessages([
    [
      "system",
      `You are a helpful AI assistant, collaborating with other assistants. Use the provided tools to progress towards answering the question. If you are unable to fully answer, that's OK, another assistant with different tools will help where you left off. Execute what you can to make progress. If you or any of the other assistants have the final answer or deliverable, prefix your response with FINAL ANSWER so the team knows to stop. You have access to the following tools: {tool_names}.\n{system_message}\nCurrent time: {time}.`,
    ],
    new MessagesPlaceholder("messages"),
  ]);
  const toolNames = tools.map((t) => t.name).join(", ");

  // Routing after the agent node: run the tools when the last message carries
  // tool calls, otherwise finish.
  function shouldContinue(state: typeof GraphState.State) {
    const messages = state.messages;
    const lastMessage = messages[messages.length - 1] as AIMessage;
    if (lastMessage.tool_calls?.length) {
      return "tools";
    }
    return "__end__";
  }

  // Agent node: format the prompt with the current messages and timestamp,
  // invoke the model, and return its reply to be appended to the state.
  async function callModel(state: typeof GraphState.State) {
    const formattedPrompt = await prompt.formatMessages({
      system_message: "You are helpful HR Chatbot Agent.",
      time: new Date().toISOString(),
      tool_names: toolNames,
      messages: state.messages,
    });
    const result = await modelWithTools.invoke(formattedPrompt);
    return { messages: [result] };
  }

  // Graph: start -> agent -> (tools -> agent)* -> end
  const workflow = new StateGraph(GraphState)
    .addNode("agent", callModel)
    .addNode("tools", toolNode)
    .addEdge("__start__", "agent")
    .addConditionalEdges("agent", shouldContinue)
    .addEdge("tools", "agent");

  const checkpointer = new MongoDBSaver({ client, dbName });
  const app = workflow.compile({ checkpointer });

  const finalState = await app.invoke(
    {
      messages: [new HumanMessage(query)],
    },
    { recursionLimit: 15, configurable: { thread_id: thread_id } }
  );

  // Read the final message content once; it is both logged and returned.
  const finalContent = finalState.messages[finalState.messages.length - 1].content;
  console.log(finalContent);
  return finalContent;
}