Skip to content

Commit

Permalink
fix graph creation
Browse files Browse the repository at this point in the history
  • Loading branch information
Pwuts committed Oct 21, 2024
1 parent f117d3f commit 9a9ae5d
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 11 deletions.
56 changes: 47 additions & 9 deletions autogpt_platform/backend/backend/data/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,21 @@ def __hash__(self):
return hash((self.source_id, self.sink_id, self.source_name, self.sink_name))


class Node(BaseDbModel):
class CreatableNode(BaseDbModel):
block_id: str
input_default: BlockInput = {} # dict[input_name, default_value]
metadata: dict[str, Any] = {}
input_links: list[Link] = []
output_links: list[Link] = []

graph_id: str
graph_version: int

webhook_id: Optional[str] = None
webhook: Optional[Webhook] = None


class Node(CreatableNode):
graph_id: str
graph_version: int

@staticmethod
def from_db(node: AgentNode):
if not node.AgentBlock:
Expand Down Expand Up @@ -124,15 +126,18 @@ def from_agent_graph_execution(execution: AgentGraphExecution):
)


class GraphMeta(BaseDbModel):
user_id: str
class CreatableGraphMeta(BaseDbModel):
version: int = 1
is_active: bool = True
is_template: bool = False
name: str
description: str
executions: list[ExecutionMeta] | None = None


class GraphMeta(CreatableGraphMeta):
user_id: str

@staticmethod
def from_db(graph: AgentGraph):
if graph.AgentGraphExecution:
Expand All @@ -155,11 +160,15 @@ def from_db(graph: AgentGraph):
)


class Graph(GraphMeta):
nodes: list[Node]
class CreatableGraph(CreatableGraphMeta):
nodes: list[CreatableNode]
links: list[Link]
subgraphs: dict[str, list[str]] = {} # subgraph_id -> [node_id]


class Graph(CreatableGraph, GraphMeta):
nodes: list[Node]

@property
def starting_nodes(self) -> list[Node]:
outbound_nodes = {link.sink_id for link in self.links}
Expand Down Expand Up @@ -409,7 +418,7 @@ def _hide_credentials_in_input(input_data: dict[str, Any]) -> dict[str, Any]:
}


# --------------------- Model functions --------------------- #
# --------------------- CRUD functions --------------------- #


async def get_node(node_id: str) -> Node:
Expand Down Expand Up @@ -640,3 +649,32 @@ async def __create_graph(tx, graph: Graph, user_id: str):
for link in graph.links
]
)


# ------------------------ UTILITIES ------------------------ #


def graph_from_creatable(creatable_graph: CreatableGraph, user_id: str) -> Graph:
"""
Convert a CreatableGraph to a Graph, setting graph_id and graph_version on all nodes.
Args:
creatable_graph (CreatableGraph): The creatable graph to convert.
user_id (str): The ID of the user creating the graph.
Returns:
Graph: The converted Graph object.
"""
# Create a new Graph object, inheriting properties from CreatableGraph
return Graph(
**creatable_graph.model_dump(),
user_id=user_id,
nodes=[
Node(
**creatable_node.model_dump(),
graph_id=creatable_graph.id,
graph_version=creatable_graph.version,
)
for creatable_node in creatable_graph.nodes
],
)
2 changes: 1 addition & 1 deletion autogpt_platform/backend/backend/server/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class SubscriptionDetails(pydantic.BaseModel):
class CreateGraph(pydantic.BaseModel):
template_id: str | None = None
template_version: int | None = None
graph: backend.data.graph.Graph | None = None
graph: backend.data.graph.CreatableGraph | None = None


class SetGraphActiveVersion(pydantic.BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion autogpt_platform/backend/backend/server/rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ async def create_graph(
user_id: str,
) -> graph_db.Graph:
if create_graph.graph:
graph = create_graph.graph
graph = graph_db.graph_from_creatable(create_graph.graph, user_id)
elif create_graph.template_id:
# Create a new graph from a template
graph = await graph_db.get_graph(
Expand Down

0 comments on commit 9a9ae5d

Please sign in to comment.