From 9a9ae5dbc470d6db004e20b053d537bac4802d7d Mon Sep 17 00:00:00 2001 From: Reinier van der Leer Date: Mon, 21 Oct 2024 12:42:43 +0200 Subject: [PATCH] fix graph creation --- .../backend/backend/data/graph.py | 56 ++++++++++++++++--- .../backend/backend/server/model.py | 2 +- .../backend/backend/server/rest_api.py | 2 +- 3 files changed, 49 insertions(+), 11 deletions(-) diff --git a/autogpt_platform/backend/backend/data/graph.py b/autogpt_platform/backend/backend/data/graph.py index 68608cd6497e..b43320aebb16 100644 --- a/autogpt_platform/backend/backend/data/graph.py +++ b/autogpt_platform/backend/backend/data/graph.py @@ -49,19 +49,21 @@ def __hash__(self): return hash((self.source_id, self.sink_id, self.source_name, self.sink_name)) -class Node(BaseDbModel): +class CreatableNode(BaseDbModel): block_id: str input_default: BlockInput = {} # dict[input_name, default_value] metadata: dict[str, Any] = {} input_links: list[Link] = [] output_links: list[Link] = [] - graph_id: str - graph_version: int - webhook_id: Optional[str] = None webhook: Optional[Webhook] = None + +class Node(CreatableNode): + graph_id: str + graph_version: int + @staticmethod def from_db(node: AgentNode): if not node.AgentBlock: @@ -124,8 +126,7 @@ def from_agent_graph_execution(execution: AgentGraphExecution): ) -class GraphMeta(BaseDbModel): - user_id: str +class CreatableGraphMeta(BaseDbModel): version: int = 1 is_active: bool = True is_template: bool = False @@ -133,6 +134,10 @@ class GraphMeta(BaseDbModel): description: str executions: list[ExecutionMeta] | None = None + +class GraphMeta(CreatableGraphMeta): + user_id: str + @staticmethod def from_db(graph: AgentGraph): if graph.AgentGraphExecution: @@ -155,11 +160,15 @@ def from_db(graph: AgentGraph): ) -class Graph(GraphMeta): - nodes: list[Node] +class CreatableGraph(CreatableGraphMeta): + nodes: list[CreatableNode] links: list[Link] subgraphs: dict[str, list[str]] = {} # subgraph_id -> [node_id] + +class Graph(CreatableGraph, GraphMeta): + nodes: list[Node] + @property def starting_nodes(self) -> list[Node]: outbound_nodes = {link.sink_id for link in self.links} @@ -409,7 +418,7 @@ def _hide_credentials_in_input(input_data: dict[str, Any]) -> dict[str, Any]: } -# --------------------- Model functions --------------------- # +# --------------------- CRUD functions --------------------- # async def get_node(node_id: str) -> Node: @@ -640,3 +649,32 @@ async def __create_graph(tx, graph: Graph, user_id: str): for link in graph.links ] ) + + +# ------------------------ UTILITIES ------------------------ # + + +def graph_from_creatable(creatable_graph: CreatableGraph, user_id: str) -> Graph: + """ + Convert a CreatableGraph to a Graph, setting graph_id and graph_version on all nodes. + + Args: + creatable_graph (CreatableGraph): The creatable graph to convert. + user_id (str): The ID of the user creating the graph. + + Returns: + Graph: The converted Graph object. + """ + # Create a new Graph object, inheriting properties from CreatableGraph + return Graph( + **creatable_graph.model_dump(), + user_id=user_id, + nodes=[ + Node( + **creatable_node.model_dump(), + graph_id=creatable_graph.id, + graph_version=creatable_graph.version, + ) + for creatable_node in creatable_graph.nodes + ], + ) diff --git a/autogpt_platform/backend/backend/server/model.py b/autogpt_platform/backend/backend/server/model.py index 90795589a743..5610c67453b7 100644 --- a/autogpt_platform/backend/backend/server/model.py +++ b/autogpt_platform/backend/backend/server/model.py @@ -34,7 +34,7 @@ class SubscriptionDetails(pydantic.BaseModel): class CreateGraph(pydantic.BaseModel): template_id: str | None = None template_version: int | None = None - graph: backend.data.graph.Graph | None = None + graph: backend.data.graph.CreatableGraph | None = None class SetGraphActiveVersion(pydantic.BaseModel): diff --git a/autogpt_platform/backend/backend/server/rest_api.py b/autogpt_platform/backend/backend/server/rest_api.py index 0efb347275c1..32c1816b5fdb 100644 --- a/autogpt_platform/backend/backend/server/rest_api.py +++ b/autogpt_platform/backend/backend/server/rest_api.py @@ -434,7 +434,7 @@ async def create_graph( user_id: str, ) -> graph_db.Graph: if create_graph.graph: - graph = create_graph.graph + graph = graph_db.graph_from_creatable(create_graph.graph, user_id) elif create_graph.template_id: # Create a new graph from a template graph = await graph_db.get_graph(