diff --git a/databuilder/models/table_metadata.py b/databuilder/models/table_metadata.py index 17ff52f7..f30c5b59 100644 --- a/databuilder/models/table_metadata.py +++ b/databuilder/models/table_metadata.py @@ -117,7 +117,7 @@ def _create_record_iterator(self) -> Iterator[RDSModel]: # TODO: this should inherit from ProgrammaticDescription in amundsen-common -class DescriptionMetadata: +class DescriptionMetadata(GraphSerializable): DESCRIPTION_NODE_LABEL = DESCRIPTION_NODE_LABEL_VAL PROGRAMMATIC_DESCRIPTION_NODE_LABEL = 'Programmatic_Description' DESCRIPTION_KEY_FORMAT = '{description}' @@ -132,7 +132,10 @@ class DescriptionMetadata: def __init__(self, text: Optional[str], - source: str = DEFAULT_SOURCE + source: str = DEFAULT_SOURCE, + description_key: Optional[str] = None, + start_label: Optional[str] = None, # Table, Column, Schema + start_key: Optional[str] = None, ): """ :param source: The unique source of what is populating this description. @@ -146,17 +149,28 @@ def __init__(self, else: self.label = self.PROGRAMMATIC_DESCRIPTION_NODE_LABEL + self.start_label = start_label + self.start_key = start_key + self.description_key = description_key or self.get_description_default_key(start_key) + + self._node_iter = self._create_node_iterator() + self._relation_iter = self._create_relation_iterator() + @staticmethod def create_description_metadata(text: Union[None, str], - source: Optional[str] = DEFAULT_SOURCE + source: Optional[str] = DEFAULT_SOURCE, + description_key: Optional[str] = None, + start_label: Optional[str] = None, # Table, Column, Schema + start_key: Optional[str] = None, ) -> Optional['DescriptionMetadata']: # We do not want to create a node if there is no description text! if text is None: return None - if not source: - description_node = DescriptionMetadata(text=text, source=DescriptionMetadata.DEFAULT_SOURCE) - else: - description_node = DescriptionMetadata(text=text, source=source) + description_node = DescriptionMetadata(text=text, + source=source or DescriptionMetadata.DEFAULT_SOURCE, + description_key=description_key, + start_label=start_label, + start_key=start_key) return description_node def get_description_id(self) -> str: @@ -165,8 +179,8 @@ def get_description_id(self) -> str: else: return "_" + self.source + "_description" - def __repr__(self) -> str: - return f'DescriptionMetadata({self.source!r}, {self.text!r})' + def get_description_default_key(self, start_key: Optional[str]) -> Optional[str]: + return f'{start_key}/{self.get_description_id()}' if start_key else None def get_node(self, node_key: str) -> GraphNode: node = GraphNode( @@ -179,7 +193,11 @@ def get_node(self, node_key: str) -> GraphNode: ) return node - def get_relation(self, start_node: str, start_key: Any, end_key: Any) -> GraphRelationship: + def get_relation(self, + start_node: str, + start_key: str, + end_key: str, + ) -> GraphRelationship: relationship = GraphRelationship( start_label=start_node, start_key=start_key, @@ -191,6 +209,40 @@ def get_relation(self, start_node: str, start_key: Any, end_key: Any) -> GraphRe ) return relationship + def create_next_node(self) -> Optional[GraphNode]: + # return the string representation of the data + try: + return next(self._node_iter) + except StopIteration: + return None + + def create_next_relation(self) -> Optional[GraphRelationship]: + try: + return next(self._relation_iter) + except StopIteration: + return None + + def _create_node_iterator(self) -> Iterator[GraphNode]: + if not self.description_key: + raise Exception('Required description node key cannot be None') + yield self.get_node(self.description_key) + + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + if not self.start_label: + raise Exception('Required relation start node label cannot be None') + if not self.start_key: + raise Exception('Required relation start key cannot be None') + if not self.description_key: + raise Exception('Required relation end key cannot be None') + yield self.get_relation( + start_node=self.start_label, + start_key=self.start_key, + end_key=self.description_key + ) + + def __repr__(self) -> str: + return f'DescriptionMetadata({self.source!r}, {self.text!r})' + class ColumnMetadata: COLUMN_NODE_LABEL = 'Column' diff --git a/setup.py b/setup.py index 3ec65e8d..b44a4c31 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ from setuptools import find_packages, setup -__version__ = '4.2.0' +__version__ = '4.2.1' requirements = [ diff --git a/tests/unit/models/test_description_metadata.py b/tests/unit/models/test_description_metadata.py new file mode 100644 index 00000000..eba86c81 --- /dev/null +++ b/tests/unit/models/test_description_metadata.py @@ -0,0 +1,124 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from databuilder.models.table_metadata import DescriptionMetadata +from databuilder.serializers import neo4_serializer + + +class TestDescriptionMetadata(unittest.TestCase): + def test_raise_exception_when_missing_data(self) -> None: + # assert raise when missing description node key + self.assertRaises( + Exception, + DescriptionMetadata(text='test_text').next_node + ) + DescriptionMetadata(text='test_text', description_key='test_key').next_node() + DescriptionMetadata(text='test_text', start_key='start_key').next_node() + + # assert raise when missing relation start label + self.assertRaises( + Exception, + DescriptionMetadata(text='test_text', start_key='start_key').next_relation + ) + DescriptionMetadata(text='test_text', start_key='test_key', start_label='Table').next_relation() + + # assert raise when missing relation start key + self.assertRaises( + Exception, + DescriptionMetadata(text='test_text', description_key='test_key', start_label='Table').next_relation + ) + + def test_serialize_table_description_metadata(self) -> None: + description_metadata = DescriptionMetadata( + text='test text 1', + start_label='Table', + start_key='test_start_key' + ) + node_row = description_metadata.next_node() + actual = [] + while node_row: + node_row_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_row_serialized) + node_row = description_metadata.next_node() + expected = [ + {'description': 'test text 1', 'KEY': 'test_start_key/_description', + 'LABEL': 'Description', 'description_source': 'description'}, + ] + self.assertEqual(actual, expected) + + relation_row = description_metadata.next_relation() + actual = [] + while relation_row: + relation_row_serialized = neo4_serializer.serialize_relationship(relation_row) + actual.append(relation_row_serialized) + relation_row = description_metadata.next_relation() + expected = [ + {'START_KEY': 'test_start_key', 'START_LABEL': 'Table', 'END_KEY': 'test_start_key/_description', + 'END_LABEL': 'Description', 'TYPE': 'DESCRIPTION', 'REVERSE_TYPE': 'DESCRIPTION_OF'} + ] + self.assertEqual(actual, expected) + + def test_serialize_column_description_metadata(self) -> None: + description_metadata = DescriptionMetadata( + text='test text 2', + start_label='Column', + start_key='test_start_key', + description_key='customized_key' + ) + node_row = description_metadata.next_node() + actual = [] + while node_row: + node_row_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_row_serialized) + node_row = description_metadata.next_node() + expected = [ + {'description': 'test text 2', 'KEY': 'customized_key', + 'LABEL': 'Description', 'description_source': 'description'}, + ] + self.assertEqual(actual, expected) + + relation_row = description_metadata.next_relation() + actual = [] + while relation_row: + relation_row_serialized = neo4_serializer.serialize_relationship(relation_row) + actual.append(relation_row_serialized) + relation_row = description_metadata.next_relation() + expected = [ + {'START_KEY': 'test_start_key', 'START_LABEL': 'Column', 'END_KEY': 'customized_key', + 'END_LABEL': 'Description', 'TYPE': 'DESCRIPTION', 'REVERSE_TYPE': 'DESCRIPTION_OF'} + ] + self.assertEqual(actual, expected) + + def test_serialize_column_with_source_description_metadata(self) -> None: + description_metadata = DescriptionMetadata( + text='test text 3', + start_label='Column', + start_key='test_start_key', + description_key='customized_key', + source='external', + ) + node_row = description_metadata.next_node() + actual = [] + while node_row: + node_row_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_row_serialized) + node_row = description_metadata.next_node() + expected = [ + {'description': 'test text 3', 'KEY': 'customized_key', + 'LABEL': 'Programmatic_Description', 'description_source': 'external'}, + ] + self.assertEqual(actual, expected) + + relation_row = description_metadata.next_relation() + actual = [] + while relation_row: + relation_row_serialized = neo4_serializer.serialize_relationship(relation_row) + actual.append(relation_row_serialized) + relation_row = description_metadata.next_relation() + expected = [ + {'START_KEY': 'test_start_key', 'START_LABEL': 'Column', 'END_KEY': 'customized_key', + 'END_LABEL': 'Programmatic_Description', 'TYPE': 'DESCRIPTION', 'REVERSE_TYPE': 'DESCRIPTION_OF'} + ] + self.assertEqual(actual, expected)