Skip to content

Commit

Permalink
draft a demo code for memory.
Browse files Browse the repository at this point in the history
  • Loading branch information
lkk12014402 committed Nov 4, 2024
1 parent c1c5798 commit d204e5e
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 5 deletions.
11 changes: 10 additions & 1 deletion comps/agent/langchain/src/agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from comps.cores.proto.agents import AgentConfig
from .utils import load_python_prompt


Expand All @@ -10,6 +11,14 @@ def instantiate_agent(args, strategy="react_langchain", with_memory=False):
else:
custom_prompt = None

agent_config = AgentConfig(
model = args.llm_engine,
with_memory = with_memory,
custom_prompt = custom_prompt,
tools = args.tools
enable_session_persistence=False,
)

if strategy == "react_langchain":
from .strategy.react import ReActAgentwithLangchain

Expand All @@ -22,7 +31,7 @@ def instantiate_agent(args, strategy="react_langchain", with_memory=False):
print("Initializing ReAct Agent with LLAMA")
from .strategy.react import ReActAgentLlama

return ReActAgentLlama(args, with_memory, custom_prompt=custom_prompt)
return ReActAgentLlama(args, agent_config)
elif strategy == "plan_execute":
from .strategy.planexec import PlanExecuteAgentWithLangGraph

Expand Down
20 changes: 19 additions & 1 deletion comps/agent/langchain/src/strategy/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


class BaseAgent:
def __init__(self, args, local_vars=None, **kwargs) -> None:
def __init__(self, args, local_vars=None, agent_config=None, **kwargs) -> None:
self.llm = setup_chat_model(args)
self.tools_descriptions = get_tools_descriptions(args.tools)
self.app = None
Expand All @@ -18,6 +18,21 @@ def __init__(self, args, local_vars=None, **kwargs) -> None:
adapt_custom_prompt(local_vars, kwargs.get("custom_prompt"))
print(self.tools_descriptions)

self.storage = None
if agent_config.enable_session_persistence:
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore import KVStoreConfig
# need async
# self.persistence_store = await kvstore_impl(self.config.persistence_store)
self.persistence_store = await kvstore_impl(KVStoreConfig())

await self.persistence_store.set(
key=f"agent:{self.id}",
value=agent_config.json(),
)

self.storage = AgentPersistence(self.id, self.persistence_store)

@property
def is_vllm(self):
return self.args.llm_engine == "vllm"
Expand All @@ -38,3 +53,6 @@ def execute(self, state: dict):

def non_streaming_run(self, query, config):
raise NotImplementedError

async def create_session(self, name: str) -> str:
return await self.storage.create_session(name)
47 changes: 44 additions & 3 deletions comps/agent/langchain/src/strategy/react/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,8 @@ def __call__(self, state):


class ReActAgentLlama(BaseAgent):
def __init__(self, args, with_memory=False, **kwargs):
super().__init__(args, local_vars=globals(), **kwargs)
def __init__(self, args, agent_config=None, **kwargs):
super().__init__(args, local_vars=globals(), agent_config=agent_config, **kwargs)
agent = ReActAgentNodeLlama(tools=self.tools_descriptions, args=args)
tool_node = ToolNode(self.tools_descriptions)

Expand Down Expand Up @@ -265,7 +265,26 @@ def should_continue(self, state: AgentState):
return "continue"

def prepare_initial_state(self, query):
return {"messages": [HumanMessage(content=query)]}

session_info = await self.storage.get_session_info(request.session_id)
if session_info is None:
raise ValueError(f"Session {request.session_id} not found")

turns = await self.storage.get_session_turns(request.session_id)

messages = []
if len(turns) == 0 and self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))

for i, turn in enumerate(turns):
messages.extend(self.turn_to_messages(turn))

messages.extend(request.messages)

self.turn_id = str(uuid.uuid4())

# return {"messages": [HumanMessage(content=query)]}
return {"messages": messages}

async def stream_generator(self, query, config):
initial_state = self.prepare_initial_state(query)
Expand All @@ -277,6 +296,17 @@ async def stream_generator(self, query, config):
if v is not None:
yield f"{k}: {v}\n"

turn = Turn(
turn_id=turn_id,
session_id=request.session_id,
input_messages=request.messages,
output_message=output_message,
started_at=start_time,
completed_at=datetime.now(),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)

yield f"data: {repr(event)}\n\n"
yield "data: [DONE]\n\n"
except Exception as e:
Expand All @@ -292,6 +322,17 @@ async def non_streaming_run(self, query, config):
else:
message.pretty_print()

turn = Turn(
turn_id=turn_id,
session_id=request.session_id,
input_messages=request.messages,
output_message=output_message,
started_at=start_time,
completed_at=datetime.now(),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)

last_message = s["messages"][-1]
print("******Response: ", last_message.content)
return last_message.content
Expand Down
4 changes: 4 additions & 0 deletions comps/cores/proto/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .agents import AgentConfig
10 changes: 10 additions & 0 deletions comps/cores/proto/agents/agents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from pydantic import BaseModel, ConfigDict, Field


class AgentConfig(BaseModel):
model: str = None
instructions: str = None
enable_session_persistence: bool = False
with_memory: bool = False
tools: str = None
custom_prompt: str = None

0 comments on commit d204e5e

Please sign in to comment.