From d87c8b36daeee6b3f55dcc985dce25e6829cc876 Mon Sep 17 00:00:00 2001 From: Travis Dent Date: Wed, 15 Jan 2025 13:43:45 -0800 Subject: [PATCH] Get all tests passing. --- agentstack/frameworks/crewai.py | 13 ++-- agentstack/frameworks/langgraph.py | 22 +++---- agentstack/proj_templates.py | 16 ++--- pyproject.toml | 2 +- .../frameworks/langgraph/entrypoint_max.py | 62 +++++++++++++++++++ .../frameworks/langgraph/entrypoint_min.py | 30 +++++++++ tests/test_frameworks.py | 11 ++-- tests/test_generation_agent.py | 4 +- tests/test_templates_config.py | 9 +-- 9 files changed, 127 insertions(+), 42 deletions(-) create mode 100644 tests/fixtures/frameworks/langgraph/entrypoint_max.py create mode 100644 tests/fixtures/frameworks/langgraph/entrypoint_min.py diff --git a/agentstack/frameworks/crewai.py b/agentstack/frameworks/crewai.py index 0ab477f..f3e7366 100644 --- a/agentstack/frameworks/crewai.py +++ b/agentstack/frameworks/crewai.py @@ -21,9 +21,6 @@ class CrewFile(asttools.File): Parses and manipulates the CrewAI entrypoint file. All AST interactions should happen within the methods of this class. """ - - _base_class: Optional[ast.ClassDef] = None - def write(self): """ Early versions of the crew entrypoint file used tabs instead of spaces. @@ -35,12 +32,10 @@ def write(self): def get_base_class(self) -> ast.ClassDef: """A base class is a class decorated with `@CrewBase`.""" - if self._base_class is None: # Gets cached to save repeat iteration - try: - self._base_class = asttools.find_class_with_decorator(self.tree, 'CrewBase')[0] - except IndexError: - raise ValidationError(f"`@CrewBase` decorated class not found in {ENTRYPOINT}") - return self._base_class + try: + return asttools.find_class_with_decorator(self.tree, 'CrewBase')[0] + except IndexError: + raise ValidationError(f"`@CrewBase` decorated class not found in {ENTRYPOINT}") def get_crew_method(self) -> ast.FunctionDef: """A `crew` method is a method decorated with `@crew`.""" diff --git a/agentstack/frameworks/langgraph.py b/agentstack/frameworks/langgraph.py index 745b120..d91f8f4 100644 --- a/agentstack/frameworks/langgraph.py +++ b/agentstack/frameworks/langgraph.py @@ -15,19 +15,15 @@ class LangGraphFile(asttools.File): """ Parses and manipulates the LangGraph entrypoint file. """ - _base_class: Optional[ast.ClassDef] = None - def get_base_class(self) -> ast.ClassDef: """ A base class is the first class inside of the file that follows the naming convention: `Graph` """ - if self._base_class is None: # gets cached to save repeat iteration - try: - self._base_class = asttools.find_class_with_regex(self.tree, r'\w+Graph$')[0] - except IndexError: - raise ValidationError(f"`Graph` class not found in {ENTRYPOINT}") - return self._base_class + try: + return asttools.find_class_with_regex(self.tree, r'\w+Graph$')[0] + except IndexError: + raise ValidationError(f"`Graph` class not found in {ENTRYPOINT}") def get_run_method(self) -> ast.FunctionDef: """A method named `run`.""" @@ -211,8 +207,8 @@ def add_agent_tools(self, agent_name: str, tool: ToolConfig): existing_global_elts.append(asttools.create_tool_node(tool.name)) new_global_node = ast.List(elts=existing_global_elts, ctx=ast.Load()) - start, end = self.get_node_range(existing_global_node) - self.edit_node_range(start, end, new_global_node) + global_start, global_end = self.get_node_range(existing_global_node) + self.edit_node_range(global_start, global_end, new_global_node) def remove_agent_tools(self, agent_name: str, tool: ToolConfig): """ @@ -232,11 +228,11 @@ def remove_agent_tools(self, agent_name: str, tool: ToolConfig): # remove the tool from the global tools list existing_global_node: ast.List = self.get_global_tools() - start, end = self.get_node_range(existing_global_node) + global_start, global_end = self.get_node_range(existing_global_node) for node in self.get_global_tool_nodes(): if tool.name == node.value.slice.value: # type: ignore[attr-defined] existing_global_node.elts.remove(node) - self.edit_node_range(start, end, existing_global_node) + self.edit_node_range(global_start, global_end, existing_global_node) def validate_project() -> None: @@ -259,7 +255,7 @@ def validate_project() -> None: # as a keyword argument. try: node = graph_file.get_run_method() - assert 'inputs' in (arg.arg for arg in node.args.kwonlyargs), \ + assert 'inputs' in (arg.arg for arg in node.args.args), \ f"Method `run` of `{class_node.name}` must accept `inputs` as a keyword argument." except (AssertionError, ValidationError) as e: raise e diff --git a/agentstack/proj_templates.py b/agentstack/proj_templates.py index 49bdda8..a38a397 100644 --- a/agentstack/proj_templates.py +++ b/agentstack/proj_templates.py @@ -8,6 +8,8 @@ from agentstack.utils import get_package_path +CURRENT_VERSION: int = 4 + class TemplateConfig_v1(pydantic.BaseModel): name: str description: str @@ -23,7 +25,7 @@ def to_v4(self) -> 'TemplateConfig': return TemplateConfig( name=self.name, description=self.description, - template_version=4, + template_version=CURRENT_VERSION, framework=self.framework, method=self.method, manager_agent=None, @@ -67,7 +69,7 @@ def to_v4(self) -> 'TemplateConfig': return TemplateConfig( name=self.name, description=self.description, - template_version=4, + template_version=CURRENT_VERSION, framework=self.framework, method=self.method, manager_agent=None, @@ -113,13 +115,13 @@ def to_v4(self) -> 'TemplateConfig': return TemplateConfig( name=self.name, description=self.description, - template_version=4, + template_version=CURRENT_VERSION, framework=self.framework, method=self.method, manager_agent=self.manager_agent, - agents=[TemplateConfig.Agent(**agent.dict()) for agent in self.agents], - tasks=[TemplateConfig.Task(**task.dict()) for task in self.tasks], - tools=[TemplateConfig.Tool(**tool.dict()) for tool in self.tools], + agents=[TemplateConfig.Agent(**agent.model_dump()) for agent in self.agents], + tasks=[TemplateConfig.Task(**task.model_dump()) for task in self.tasks], + tools=[TemplateConfig.Tool(**tool.model_dump()) for tool in self.tools], graph=[], inputs=self.inputs, ) @@ -181,7 +183,7 @@ class Node(pydantic.BaseModel): name: str description: str - template_version: Literal[4] + template_version: Literal[CURRENT_VERSION] framework: str method: str manager_agent: Optional[str] = None diff --git a/pyproject.toml b/pyproject.toml index 4d83d55..26e31a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ crewai = [ ] langgraph = [ "langgraph>=0.2.61", - "langchain_anthropic>=0.31" + "langchain_anthropic>=0.3.1" ] all = [ "agentstack[dev,test,crewai,langgraph]", diff --git a/tests/fixtures/frameworks/langgraph/entrypoint_max.py b/tests/fixtures/frameworks/langgraph/entrypoint_max.py new file mode 100644 index 0000000..a349e35 --- /dev/null +++ b/tests/fixtures/frameworks/langgraph/entrypoint_max.py @@ -0,0 +1,62 @@ +from typing import Annotated +from typing_extensions import TypedDict + +from langchain_openai import ChatOpenAI +from langchain_anthropic import ChatAnthropic +from langchain.prompts import ChatPromptTemplate +from langgraph.graph import StateGraph, START, END +from langgraph.graph.message import add_messages +from langgraph.prebuilt import ToolNode, tools_condition + +import agentstack + + +class State(TypedDict): + inputs: dict[str, str] + messages: Annotated[list, add_messages] + + +class TestGraph: + @agentstack.agent + def test_agent(self, state: State): + agent_config = agentstack.get_agent('test_agent') + messages = ChatPromptTemplate.from_messages([ + ("user", agent_config.prompt), + ]) + messages = messages.format_messages(**state['inputs']) + agent = ChatOpenAI(model=agent_config.model) + agent = agent.bind_tools([]) + response = agent.invoke( + messages + state['messages'], + ) + return {'messages': [response, ]} + + @agentstack.task + def test_task(self, state: State): + task_config = agentstack.get_task('test_task') + messages = ChatPromptTemplate.from_messages([ + ("user", task_config.prompt), + ]) + messages = messages.format_messages(**state['inputs']) + return {'messages': messages + state['messages']} + + def run(self, inputs: list[str]): + self.graph = StateGraph(State) + tools = ToolNode([]) + self.graph.add_node("tools", tools) + + self.graph.add_node("test_agent", self.test_agent) + self.graph.add_edge("test_agent", "tools") + self.graph.add_conditional_edges("test_agent", tools_condition) + + self.graph.add_edge(START, "test_task") + self.graph.add_edge("test_task", "test_agent") + self.graph.add_edge("test_agent", END) + + app = self.graph.compile() + result = app.invoke({ + 'inputs': inputs, + 'messages': [], + }) + print(result['messages'][-1].content) + diff --git a/tests/fixtures/frameworks/langgraph/entrypoint_min.py b/tests/fixtures/frameworks/langgraph/entrypoint_min.py new file mode 100644 index 0000000..4e3e0ab --- /dev/null +++ b/tests/fixtures/frameworks/langgraph/entrypoint_min.py @@ -0,0 +1,30 @@ +from typing import Annotated +from typing_extensions import TypedDict + +from langchain_openai import ChatOpenAI +from langchain_anthropic import ChatAnthropic +from langchain.prompts import ChatPromptTemplate +from langgraph.graph import StateGraph, START, END +from langgraph.graph.message import add_messages +from langgraph.prebuilt import ToolNode, tools_condition + +import agentstack + + +class State(TypedDict): + inputs: dict[str, str] + messages: Annotated[list, add_messages] + + +class TestGraph: + def run(self, inputs: list[str]): + self.graph = StateGraph(State) + tools = ToolNode([]) + self.graph.add_node("tools", tools) + + app = self.graph.compile() + result = app.invoke({ + 'inputs': inputs, + 'messages': [], + }) + diff --git a/tests/test_frameworks.py b/tests/test_frameworks.py index 0a87a07..8e0a61e 100644 --- a/tests/test_frameworks.py +++ b/tests/test_frameworks.py @@ -71,8 +71,7 @@ def test_add_tool(self): frameworks.add_tool(self._get_test_tool(), 'test_agent') entrypoint_src = open(frameworks.get_entrypoint_path(self.framework)).read() - # TODO these asserts are not framework agnostic - assert "tools=[*agentstack.tools['test_tool']" in entrypoint_src + assert "*agentstack.tools['test_tool']" in entrypoint_src def test_add_tool_invalid(self): self._populate_min_entrypoint() @@ -85,7 +84,7 @@ def test_remove_tool(self): frameworks.remove_tool(self._get_test_tool(), 'test_agent') entrypoint_src = open(frameworks.get_entrypoint_path(self.framework)).read() - assert "tools=[*agentstack.tools['test_tool']" not in entrypoint_src + assert "*agentstack.tools['test_tool']" not in entrypoint_src def test_add_multiple_tools(self): self._populate_max_entrypoint() @@ -94,8 +93,8 @@ def test_add_multiple_tools(self): entrypoint_src = open(frameworks.get_entrypoint_path(self.framework)).read() assert ( # ordering is not guaranteed - "tools=[*agentstack.tools['test_tool'], *agentstack.tools['test_tool_alt']" in entrypoint_src - or "tools=[*agentstack.tools['test_tool_alt'], *agentstack.tools['test_tool']" in entrypoint_src + "*agentstack.tools['test_tool'], *agentstack.tools['test_tool_alt']" in entrypoint_src + or "*agentstack.tools['test_tool_alt'], *agentstack.tools['test_tool']" in entrypoint_src ) def test_remove_one_tool_of_multiple(self): @@ -106,4 +105,4 @@ def test_remove_one_tool_of_multiple(self): entrypoint_src = open(frameworks.get_entrypoint_path(self.framework)).read() assert "*agentstack.tools['test_tool']" not in entrypoint_src - assert "tools=[*agentstack.tools['test_tool_alt']" in entrypoint_src + assert "*agentstack.tools['test_tool_alt']" in entrypoint_src diff --git a/tests/test_generation_agent.py b/tests/test_generation_agent.py index 3e3e90c..7bf0189 100644 --- a/tests/test_generation_agent.py +++ b/tests/test_generation_agent.py @@ -42,7 +42,7 @@ def test_add_agent(self): role='role', goal='goal', backstory='backstory', - llm='llm', + llm='openai/gpt-4o', ) entrypoint_path = frameworks.get_entrypoint_path(self.framework) @@ -60,5 +60,5 @@ def test_add_agent_exists(self): role='role', goal='goal', backstory='backstory', - llm='llm', + llm='openai/gpt-4o', ) diff --git a/tests/test_templates_config.py b/tests/test_templates_config.py index 63779d5..045bade 100644 --- a/tests/test_templates_config.py +++ b/tests/test_templates_config.py @@ -7,6 +7,7 @@ from parameterized import parameterized from agentstack.exceptions import ValidationError from agentstack.proj_templates import ( + CURRENT_VERSION, TemplateConfig, get_all_template_names, get_all_template_paths, @@ -88,12 +89,12 @@ def test_write_to_file_without_suffix(self): def test_from_user_input_url(self): config = TemplateConfig.from_user_input(VALID_TEMPLATE_URL) self.assertEqual(config.name, "content_creator") - self.assertEqual(config.template_version, 3) + self.assertEqual(config.template_version, CURRENT_VERSION) def test_from_user_input_name(self): config = TemplateConfig.from_user_input('content_creator') self.assertEqual(config.name, "content_creator") - self.assertEqual(config.template_version, 3) + self.assertEqual(config.template_version, CURRENT_VERSION) def test_from_user_input_local_file(self): test_file = self.project_dir / 'test_local_template.json' @@ -116,7 +117,7 @@ def test_from_user_input_local_file(self): config = TemplateConfig.from_user_input(str(test_file)) self.assertEqual(config.name, "test_local") - self.assertEqual(config.template_version, 3) + self.assertEqual(config.template_version, CURRENT_VERSION) def test_from_file_missing_file(self): non_existent_path = Path("/path/to/non_existent_file.json") @@ -158,7 +159,7 @@ def test_from_json_pydantic_validation_error(self): invalid_template = { "name": "invalid_template", "description": "A template with invalid data", - "template_version": 3, + "template_version": CURRENT_VERSION, "framework": "test", "method": "test", "manager_agent": None,