diff --git a/tests/integration/metadata/test_multi_table.py b/tests/integration/metadata/test_multi_table.py index 23e16de42..c36fee1ca 100644 --- a/tests/integration/metadata/test_multi_table.py +++ b/tests/integration/metadata/test_multi_table.py @@ -4,7 +4,7 @@ from unittest.mock import patch from sdv.datasets.demo import download_demo -from sdv.metadata import MultiTableMetadata, SingleTableMetadata +from sdv.metadata import Metadata, MultiTableMetadata from tests.utils import get_multi_table_metadata @@ -311,18 +311,25 @@ def test_get_table_metadata(): table_metadata = metadata.get_table_metadata('nesreca') # Assert - assert isinstance(table_metadata, SingleTableMetadata) + assert isinstance(table_metadata, Metadata) expected_metadata = { - 'primary_key': 'id_nesreca', - 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', - 'columns': { - 'upravna_enota': {'sdtype': 'id'}, - 'id_nesreca': {'sdtype': 'id'}, - 'nesreca_val': {'sdtype': 'numerical'}, - 'latitude': {'sdtype': 'latitude', 'pii': True}, - 'longitude': {'sdtype': 'longitude', 'pii': True}, + 'METADATA_SPEC_VERSION': 'V1', + 'relationships': [], + 'tables': { + 'nesreca': { + 'column_relationships': [ + {'column_names': ['latitude', 'longitude'], 'type': 'gps'} + ], + 'columns': { + 'id_nesreca': {'sdtype': 'id'}, + 'latitude': {'pii': True, 'sdtype': 'latitude'}, + 'longitude': {'pii': True, 'sdtype': 'longitude'}, + 'nesreca_val': {'sdtype': 'numerical'}, + 'upravna_enota': {'sdtype': 'id'}, + }, + 'primary_key': 'id_nesreca', + } }, - 'column_relationships': [{'type': 'gps', 'column_names': ['latitude', 'longitude']}], } assert table_metadata.to_dict() == expected_metadata diff --git a/tests/unit/metadata/test_metadata.py b/tests/unit/metadata/test_metadata.py index ba1a79ba2..e405d3473 100644 --- a/tests/unit/metadata/test_metadata.py +++ b/tests/unit/metadata/test_metadata.py @@ -518,8 +518,7 @@ def test_validate_no_relationships(self, metadata_instance): - Instance of ``Metadata`` with all valid tables and no relationships. """ # Setup - metadata = metadata_instance - metadata_no_relationships = metadata.to_dict() + metadata_no_relationships = metadata_instance.to_dict() del metadata_no_relationships['relationships'] test_metadata = Metadata.load_from_dict(metadata_no_relationships)