Skip to content

Commit

Permalink
Get all tests passing.
Browse files Browse the repository at this point in the history
  • Loading branch information
tcdent committed Jan 15, 2025
1 parent fb6146b commit d87c8b3
Show file tree
Hide file tree
Showing 9 changed files with 127 additions and 42 deletions.
13 changes: 4 additions & 9 deletions agentstack/frameworks/crewai.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ class CrewFile(asttools.File):
Parses and manipulates the CrewAI entrypoint file.
All AST interactions should happen within the methods of this class.
"""

_base_class: Optional[ast.ClassDef] = None

def write(self):
"""
Early versions of the crew entrypoint file used tabs instead of spaces.
Expand All @@ -35,12 +32,10 @@ def write(self):

def get_base_class(self) -> ast.ClassDef:
"""A base class is a class decorated with `@CrewBase`."""
if self._base_class is None: # Gets cached to save repeat iteration
try:
self._base_class = asttools.find_class_with_decorator(self.tree, 'CrewBase')[0]
except IndexError:
raise ValidationError(f"`@CrewBase` decorated class not found in {ENTRYPOINT}")
return self._base_class
try:
return asttools.find_class_with_decorator(self.tree, 'CrewBase')[0]
except IndexError:
raise ValidationError(f"`@CrewBase` decorated class not found in {ENTRYPOINT}")

def get_crew_method(self) -> ast.FunctionDef:
"""A `crew` method is a method decorated with `@crew`."""
Expand Down
22 changes: 9 additions & 13 deletions agentstack/frameworks/langgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,15 @@ class LangGraphFile(asttools.File):
"""
Parses and manipulates the LangGraph entrypoint file.
"""
_base_class: Optional[ast.ClassDef] = None

def get_base_class(self) -> ast.ClassDef:
"""
A base class is the first class inside of the file that follows the
naming convention: `<FooBar>Graph`
"""
if self._base_class is None: # gets cached to save repeat iteration
try:
self._base_class = asttools.find_class_with_regex(self.tree, r'\w+Graph$')[0]
except IndexError:
raise ValidationError(f"`<FooBar>Graph` class not found in {ENTRYPOINT}")
return self._base_class
try:
return asttools.find_class_with_regex(self.tree, r'\w+Graph$')[0]
except IndexError:
raise ValidationError(f"`<FooBar>Graph` class not found in {ENTRYPOINT}")

def get_run_method(self) -> ast.FunctionDef:
"""A method named `run`."""
Expand Down Expand Up @@ -211,8 +207,8 @@ def add_agent_tools(self, agent_name: str, tool: ToolConfig):
existing_global_elts.append(asttools.create_tool_node(tool.name))

new_global_node = ast.List(elts=existing_global_elts, ctx=ast.Load())
start, end = self.get_node_range(existing_global_node)
self.edit_node_range(start, end, new_global_node)
global_start, global_end = self.get_node_range(existing_global_node)
self.edit_node_range(global_start, global_end, new_global_node)

def remove_agent_tools(self, agent_name: str, tool: ToolConfig):
"""
Expand All @@ -232,11 +228,11 @@ def remove_agent_tools(self, agent_name: str, tool: ToolConfig):

# remove the tool from the global tools list
existing_global_node: ast.List = self.get_global_tools()
start, end = self.get_node_range(existing_global_node)
global_start, global_end = self.get_node_range(existing_global_node)
for node in self.get_global_tool_nodes():
if tool.name == node.value.slice.value: # type: ignore[attr-defined]
existing_global_node.elts.remove(node)
self.edit_node_range(start, end, existing_global_node)
self.edit_node_range(global_start, global_end, existing_global_node)


def validate_project() -> None:
Expand All @@ -259,7 +255,7 @@ def validate_project() -> None:
# as a keyword argument.
try:
node = graph_file.get_run_method()
assert 'inputs' in (arg.arg for arg in node.args.kwonlyargs), \
assert 'inputs' in (arg.arg for arg in node.args.args), \
f"Method `run` of `{class_node.name}` must accept `inputs` as a keyword argument."
except (AssertionError, ValidationError) as e:
raise e
Expand Down
16 changes: 9 additions & 7 deletions agentstack/proj_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from agentstack.utils import get_package_path


CURRENT_VERSION: int = 4

class TemplateConfig_v1(pydantic.BaseModel):
name: str
description: str
Expand All @@ -23,7 +25,7 @@ def to_v4(self) -> 'TemplateConfig':
return TemplateConfig(
name=self.name,
description=self.description,
template_version=4,
template_version=CURRENT_VERSION,
framework=self.framework,
method=self.method,
manager_agent=None,
Expand Down Expand Up @@ -67,7 +69,7 @@ def to_v4(self) -> 'TemplateConfig':
return TemplateConfig(
name=self.name,
description=self.description,
template_version=4,
template_version=CURRENT_VERSION,
framework=self.framework,
method=self.method,
manager_agent=None,
Expand Down Expand Up @@ -113,13 +115,13 @@ def to_v4(self) -> 'TemplateConfig':
return TemplateConfig(
name=self.name,
description=self.description,
template_version=4,
template_version=CURRENT_VERSION,
framework=self.framework,
method=self.method,
manager_agent=self.manager_agent,
agents=[TemplateConfig.Agent(**agent.dict()) for agent in self.agents],
tasks=[TemplateConfig.Task(**task.dict()) for task in self.tasks],
tools=[TemplateConfig.Tool(**tool.dict()) for tool in self.tools],
agents=[TemplateConfig.Agent(**agent.model_dump()) for agent in self.agents],
tasks=[TemplateConfig.Task(**task.model_dump()) for task in self.tasks],
tools=[TemplateConfig.Tool(**tool.model_dump()) for tool in self.tools],
graph=[],
inputs=self.inputs,
)
Expand Down Expand Up @@ -181,7 +183,7 @@ class Node(pydantic.BaseModel):

name: str
description: str
template_version: Literal[4]
template_version: Literal[CURRENT_VERSION]
framework: str
method: str
manager_agent: Optional[str] = None
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ crewai = [
]
langgraph = [
"langgraph>=0.2.61",
"langchain_anthropic>=0.31"
"langchain_anthropic>=0.3.1"
]
all = [
"agentstack[dev,test,crewai,langgraph]",
Expand Down
62 changes: 62 additions & 0 deletions tests/fixtures/frameworks/langgraph/entrypoint_max.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from typing import Annotated
from typing_extensions import TypedDict

from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain.prompts import ChatPromptTemplate
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition

import agentstack


class State(TypedDict):
inputs: dict[str, str]
messages: Annotated[list, add_messages]


class TestGraph:
@agentstack.agent
def test_agent(self, state: State):
agent_config = agentstack.get_agent('test_agent')
messages = ChatPromptTemplate.from_messages([
("user", agent_config.prompt),
])
messages = messages.format_messages(**state['inputs'])
agent = ChatOpenAI(model=agent_config.model)
agent = agent.bind_tools([])
response = agent.invoke(
messages + state['messages'],
)
return {'messages': [response, ]}

@agentstack.task
def test_task(self, state: State):
task_config = agentstack.get_task('test_task')
messages = ChatPromptTemplate.from_messages([
("user", task_config.prompt),
])
messages = messages.format_messages(**state['inputs'])
return {'messages': messages + state['messages']}

def run(self, inputs: list[str]):
self.graph = StateGraph(State)
tools = ToolNode([])
self.graph.add_node("tools", tools)

self.graph.add_node("test_agent", self.test_agent)
self.graph.add_edge("test_agent", "tools")
self.graph.add_conditional_edges("test_agent", tools_condition)

self.graph.add_edge(START, "test_task")
self.graph.add_edge("test_task", "test_agent")
self.graph.add_edge("test_agent", END)

app = self.graph.compile()
result = app.invoke({
'inputs': inputs,
'messages': [],
})
print(result['messages'][-1].content)

30 changes: 30 additions & 0 deletions tests/fixtures/frameworks/langgraph/entrypoint_min.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Annotated
from typing_extensions import TypedDict

from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain.prompts import ChatPromptTemplate
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition

import agentstack


class State(TypedDict):
inputs: dict[str, str]
messages: Annotated[list, add_messages]


class TestGraph:
def run(self, inputs: list[str]):
self.graph = StateGraph(State)
tools = ToolNode([])
self.graph.add_node("tools", tools)

app = self.graph.compile()
result = app.invoke({
'inputs': inputs,
'messages': [],
})

11 changes: 5 additions & 6 deletions tests/test_frameworks.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ def test_add_tool(self):
frameworks.add_tool(self._get_test_tool(), 'test_agent')

entrypoint_src = open(frameworks.get_entrypoint_path(self.framework)).read()
# TODO these asserts are not framework agnostic
assert "tools=[*agentstack.tools['test_tool']" in entrypoint_src
assert "*agentstack.tools['test_tool']" in entrypoint_src

def test_add_tool_invalid(self):
self._populate_min_entrypoint()
Expand All @@ -85,7 +84,7 @@ def test_remove_tool(self):
frameworks.remove_tool(self._get_test_tool(), 'test_agent')

entrypoint_src = open(frameworks.get_entrypoint_path(self.framework)).read()
assert "tools=[*agentstack.tools['test_tool']" not in entrypoint_src
assert "*agentstack.tools['test_tool']" not in entrypoint_src

def test_add_multiple_tools(self):
self._populate_max_entrypoint()
Expand All @@ -94,8 +93,8 @@ def test_add_multiple_tools(self):

entrypoint_src = open(frameworks.get_entrypoint_path(self.framework)).read()
assert ( # ordering is not guaranteed
"tools=[*agentstack.tools['test_tool'], *agentstack.tools['test_tool_alt']" in entrypoint_src
or "tools=[*agentstack.tools['test_tool_alt'], *agentstack.tools['test_tool']" in entrypoint_src
"*agentstack.tools['test_tool'], *agentstack.tools['test_tool_alt']" in entrypoint_src
or "*agentstack.tools['test_tool_alt'], *agentstack.tools['test_tool']" in entrypoint_src
)

def test_remove_one_tool_of_multiple(self):
Expand All @@ -106,4 +105,4 @@ def test_remove_one_tool_of_multiple(self):

entrypoint_src = open(frameworks.get_entrypoint_path(self.framework)).read()
assert "*agentstack.tools['test_tool']" not in entrypoint_src
assert "tools=[*agentstack.tools['test_tool_alt']" in entrypoint_src
assert "*agentstack.tools['test_tool_alt']" in entrypoint_src
4 changes: 2 additions & 2 deletions tests/test_generation_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_add_agent(self):
role='role',
goal='goal',
backstory='backstory',
llm='llm',
llm='openai/gpt-4o',
)

entrypoint_path = frameworks.get_entrypoint_path(self.framework)
Expand All @@ -60,5 +60,5 @@ def test_add_agent_exists(self):
role='role',
goal='goal',
backstory='backstory',
llm='llm',
llm='openai/gpt-4o',
)
9 changes: 5 additions & 4 deletions tests/test_templates_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from parameterized import parameterized
from agentstack.exceptions import ValidationError
from agentstack.proj_templates import (
CURRENT_VERSION,
TemplateConfig,
get_all_template_names,
get_all_template_paths,
Expand Down Expand Up @@ -88,12 +89,12 @@ def test_write_to_file_without_suffix(self):
def test_from_user_input_url(self):
config = TemplateConfig.from_user_input(VALID_TEMPLATE_URL)
self.assertEqual(config.name, "content_creator")
self.assertEqual(config.template_version, 3)
self.assertEqual(config.template_version, CURRENT_VERSION)

def test_from_user_input_name(self):
config = TemplateConfig.from_user_input('content_creator')
self.assertEqual(config.name, "content_creator")
self.assertEqual(config.template_version, 3)
self.assertEqual(config.template_version, CURRENT_VERSION)

def test_from_user_input_local_file(self):
test_file = self.project_dir / 'test_local_template.json'
Expand All @@ -116,7 +117,7 @@ def test_from_user_input_local_file(self):

config = TemplateConfig.from_user_input(str(test_file))
self.assertEqual(config.name, "test_local")
self.assertEqual(config.template_version, 3)
self.assertEqual(config.template_version, CURRENT_VERSION)

def test_from_file_missing_file(self):
non_existent_path = Path("/path/to/non_existent_file.json")
Expand Down Expand Up @@ -158,7 +159,7 @@ def test_from_json_pydantic_validation_error(self):
invalid_template = {
"name": "invalid_template",
"description": "A template with invalid data",
"template_version": 3,
"template_version": CURRENT_VERSION,
"framework": "test",
"method": "test",
"manager_agent": None,
Expand Down

0 comments on commit d87c8b3

Please sign in to comment.