From 35bddfc56e4175aaab7c9943d939c92694018942 Mon Sep 17 00:00:00 2001
From: Waleed Alzarooni
Date: Thu, 23 Jan 2025 10:09:13 -0800
Subject: [PATCH] DKG amendments to include dynamic time_label integration

---
 camel/storages/graph_storages/nebula_graph.py | 54 +++++++++++++++++--
 .../graph_storages/test_nebula_graph.py       | 36 +++++++++++++
 2 files changed, 87 insertions(+), 3 deletions(-)

diff --git a/camel/storages/graph_storages/nebula_graph.py b/camel/storages/graph_storages/nebula_graph.py
index 14e8a48caa..60a321b851 100644
--- a/camel/storages/graph_storages/nebula_graph.py
+++ b/camel/storages/graph_storages/nebula_graph.py
@@ -292,10 +292,15 @@ def add_node(
         node_id = re.sub(r'[^a-zA-Z0-9\u4e00-\u9fa5]', '', node_id)
         tag_name = re.sub(r'[^a-zA-Z0-9\u4e00-\u9fa5]', '', tag_name)
 
-        self.ensure_tag_exists(tag_name, time_label)
+        # Ensure the tag exists
+        self.ensure_tag_exists(tag_name)
 
-        # Insert node with or without time_label property
-        if time_label is not None:
+        # Add `time_label` to schema if it's provided and not in the schema
+        if time_label:
+            self.ensure_field_in_schema(tag_name, "time_label", "string")
+
+        # Construct the insert query based on the presence of `time_label`
+        if time_label:
             time_label = self._validate_time_label(time_label)
             insert_stmt = (
                 f'INSERT VERTEX IF NOT EXISTS {tag_name}(time_label) VALUES '
@@ -307,6 +312,7 @@ def add_node(
                 f'"{node_id}":()'
             )
 
+        # Execute the insert query with retries
         for attempt in range(MAX_RETRIES):
             res = self.query(insert_stmt)
             if res.is_succeeded():
@@ -321,6 +327,48 @@ def add_node(
                     f" {MAX_RETRIES} attempts: {res.error_msg()}"
                 )
 
+    def ensure_field_in_schema(
+        self, tag_name: str, field_name: str, field_type: str
+    ) -> None:
+        r"""Ensure a field exists in the tag's schema.
+
+        Runs `DESCRIBE TAG` and, when `field_name` is not among the
+        returned fields, adds it with `ALTER TAG ... ADD`.
+
+        Args:
+            tag_name (str): The tag name to check.
+            field_name (str): The field name to ensure exists.
+            field_type (str): The type of the field (e.g., 'string', 'int').
+
+        Raises:
+            Exception: If describing or altering the tag fails.
+        """
+        # Describe the tag to check its schema
+        schema_stmt = f"DESCRIBE TAG {tag_name}"
+        res = self.query(schema_stmt)
+
+        if not res.is_succeeded():
+            raise Exception(
+                f"Failed to describe tag `{tag_name}`: {res.error_msg()}"
+            )
+
+        # `rows()` yields undecoded rows rather than strings; the first
+        # value of each `DESCRIBE TAG` row is the field name as bytes.
+        schema_fields = {
+            row.values[0].get_sVal().decode('utf-8') for row in res.rows()
+        }
+        if field_name not in schema_fields:
+            # Add the field to the schema
+            alter_stmt = (
+                f"ALTER TAG {tag_name} ADD ({field_name} {field_type})"
+            )
+            alter_res = self.query(alter_stmt)
+            if not alter_res.is_succeeded():
+                raise Exception(
+                    f"Failed to add field `{field_name}` to tag `{tag_name}`: "
+                    f"{alter_res.error_msg()}"
+                )
+
     def _extract_nodes(self, graph_elements: List[Any]) -> List[Dict]:
         r"""Extracts unique nodes from graph elements.
 
diff --git a/test/storages/graph_storages/test_nebula_graph.py b/test/storages/graph_storages/test_nebula_graph.py
index 32c03cb24e..638414eedc 100644
--- a/test/storages/graph_storages/test_nebula_graph.py
+++ b/test/storages/graph_storages/test_nebula_graph.py
@@ -112,6 +112,42 @@ def test_add_node(self):
         )
         self.graph.query.assert_called_with(insert_stmt)
 
+    def test_add_node_with_time_label_not_in_schema(self):
+        node_id = 'node1'
+        tag_name = 'Tag1'
+        time_label = '2025-01-21T12:00:00'
+
+        # Mock dependencies
+        self.graph.ensure_tag_exists = Mock()
+        self.graph.ensure_field_in_schema = Mock()
+        self.graph.query = Mock()
+        self.graph._validate_time_label = Mock(return_value=time_label)
+
+        # Mock query success
+        self.graph.query.return_value.is_succeeded = Mock(return_value=True)
+
+        # Call the method
+        self.graph.add_node(node_id, tag_name, time_label)
+
+        # Ensure the tag existence check was performed
+        self.graph.ensure_tag_exists.assert_called_once_with(tag_name)
+
+        # Ensure the time_label field was added to the schema
+        self.graph.ensure_field_in_schema.assert_called_once_with(
+            tag_name, "time_label", "string"
+        )
+
+        # Validate the time_label
+        self.graph._validate_time_label.assert_called_once_with(time_label)
+
+        # Ensure the correct query was executed
+        insert_stmt = (
+            f'INSERT VERTEX IF NOT EXISTS {tag_name}(time_label) VALUES '
+            f'"{node_id}":("{time_label}")'
+        )
+
+        self.graph.query.assert_called_once_with(insert_stmt)
+
     def test_ensure_tag_exists_success(self):
         tag_name = 'Tag1'
         # Mock query to return a successful result