From 761d36652eb04b93d132d0b9b9a96946256cd8cf Mon Sep 17 00:00:00 2001 From: lchen-2101 <73617864+lchen-2101@users.noreply.github.com> Date: Tue, 13 Feb 2024 15:42:52 -0500 Subject: [PATCH] feat: add versioning for type details --- src/entities/listeners.py | 54 ++++++++++++++++++------- tests/entities/test_listeners.py | 63 +++++++++++++++++++++++++++-- tests/migrations/test_migrations.py | 8 ++++ tests/migrations/test_schema.py | 9 +++++ 4 files changed, 115 insertions(+), 19 deletions(-) diff --git a/src/entities/listeners.py b/src/entities/listeners.py index bfb6be0..38b64a3 100644 --- a/src/entities/listeners.py +++ b/src/entities/listeners.py @@ -1,29 +1,53 @@ +from typing import List from sqlalchemy import Connection, Table, event, inspect from sqlalchemy.orm import Mapper -from .models.dao import Base, FinancialInstitutionDao +from .models.dao import Base, FinancialInstitutionDao, SblTypeMappingDao from entities.engine.engine import engine +def inspect_fi(fi: FinancialInstitutionDao): + changes = {} + new_version = fi.version + 1 if fi.version else 1 + state = inspect(fi) + for attr in state.attrs: + if attr.key == "event_time": + continue + if attr.key == "sbl_institution_types": + field_changes = inspect_type_fields(attr.value) + if attr.history.has_changes() or field_changes: + old_types = {"old": [o.as_db_dict() for o in attr.history.deleted]} if attr.history.deleted else {} + new_types = ( + {"new": [{**n.as_db_dict(), "version": new_version} for n in attr.history.added]} + if attr.history.added + else {} + ) + changes[attr.key] = {**old_types, **new_types, "field_changes": field_changes} + elif attr.history.has_changes(): + changes[attr.key] = {"old": attr.history.deleted, "new": attr.history.added} + return changes + + +def inspect_type_fields(types: List[SblTypeMappingDao], fields: List[str] = ["details"]): + changes = [] + for t in types: + state = inspect(t) + attr_changes = { + attr.key: {"old": attr.history.deleted, "new": attr.history.added} + for attr in state.attrs + if attr.key in fields and attr.history.has_changes() + } + if attr_changes: + changes.append({**t.as_db_dict(), **attr_changes}) + return changes + + def _setup_fi_history(fi_history: Table, mapping_history: Table): def _insert_history( mapper: Mapper[FinancialInstitutionDao], connection: Connection, target: FinancialInstitutionDao ): new_version = target.version + 1 if target.version else 1 - changes = {} - state = inspect(target) - for attr in state.attrs: - if attr.key == "event_time": - continue - attr_hist = attr.load_history() - if not attr_hist.has_changes(): - continue - if attr.key == "sbl_institution_types": - old_types = [o.as_db_dict() for o in attr_hist.deleted] - new_types = [{**n.as_db_dict(), "version": new_version} for n in attr_hist.added] - changes[attr.key] = {"old": old_types, "new": new_types} - else: - changes[attr.key] = {"old": attr_hist.deleted, "new": attr_hist.added} + changes = inspect_fi(target) if changes: target.version = new_version for t in target.sbl_institution_types: diff --git a/tests/entities/test_listeners.py b/tests/entities/test_listeners.py index 88dc853..9c8f908 100644 --- a/tests/entities/test_listeners.py +++ b/tests/entities/test_listeners.py @@ -1,8 +1,10 @@ -from unittest.mock import Mock +import pytest +from unittest.mock import Mock, call from pytest_mock import MockerFixture -from sqlalchemy import Connection, Table +from sqlalchemy import Connection, Insert, Table from sqlalchemy.orm import Mapper, InstanceState, AttributeState +from sqlalchemy.orm.attributes import History from entities.models.dao import FinancialInstitutionDao, SBLInstitutionTypeDao, SblTypeMappingDao @@ -37,6 +39,14 @@ class TestListeners: modified_by="test_user_id", ) + @pytest.fixture(autouse=True) + def setup(self): + self.fi_history.reset_mock() + self.fi_history.columns = {"name": "test"} + self.mapping_history.reset_mock() + self.mapper.reset_mock() + self.connection.reset_mock() + def test_fi_history_listener(self, mocker: MockerFixture): inspect_mock = mocker.patch("entities.listeners.inspect") attr_mock1: AttributeState = Mock(AttributeState) @@ -45,10 +55,55 @@ def test_fi_history_listener(self, mocker: MockerFixture): attr_mock2.key = "event_time" state_mock: InstanceState = Mock(InstanceState) state_mock.attrs = [attr_mock1, attr_mock2] - self.fi_history.columns = {"name": "test"} inspect_mock.return_value = state_mock fi_listener = _setup_fi_history(self.fi_history, self.mapping_history) fi_listener(self.mapper, self.connection, self.target) inspect_mock.assert_called_once_with(self.target) - attr_mock1.load_history.assert_called_once() self.fi_history.insert.assert_called_once() + self.mapping_history.insert.assert_called_once() + + def _get_fi_inspect_mock(self): + fi_attr_mock: AttributeState = Mock(AttributeState) + fi_attr_mock.key = "sbl_institution_types" + fi_attr_mock.value = self.target.sbl_institution_types + fi_attr_mock.history = History(added=[], deleted=[], unchanged=[]) + fi_state_mock: InstanceState = Mock(InstanceState) + fi_state_mock.attrs = [fi_attr_mock] + return fi_state_mock + + def _get_mapping_inspect_mock(self): + mapping_attr_mock: AttributeState = Mock(AttributeState) + mapping_attr_mock.key = "details" + mapping_attr_mock.history = History(added=["new type"], deleted=["old type"], unchanged=[]) + mapping_state_mock: InstanceState = Mock(InstanceState) + mapping_state_mock.attrs = [mapping_attr_mock] + return mapping_state_mock + + def test_fi_mapping_changed(self, mocker: MockerFixture): + inspect_mock = mocker.patch("entities.listeners.inspect") + fi_state_mock = self._get_fi_inspect_mock() + mapping_state_mock = self._get_mapping_inspect_mock() + + def inspect_side_effect(inspect_target): + if inspect_target == self.target: + return fi_state_mock + elif inspect_target == self.target.sbl_institution_types[0]: + return mapping_state_mock + + inspect_mock.side_effect = inspect_side_effect + fi_insert_mock = Mock(Insert) + self.fi_history.insert.return_value = fi_insert_mock + mapping_insert_mock = Mock(Insert) + self.mapping_history.insert.return_value = mapping_insert_mock + fi_listener = _setup_fi_history(self.fi_history, self.mapping_history) + fi_listener(self.mapper, self.connection, self.target) + inspect_mock.assert_has_calls([call(self.target), call(self.target.sbl_institution_types[0])]) + self.fi_history.insert.assert_called_once() + self.mapping_history.insert.assert_called_once() + fi_insert_mock.values.assert_called_once() + args, _ = fi_insert_mock.values.call_args + insert_data = args[0] + assert insert_data["changeset"]["sbl_institution_types"]["field_changes"][0]["details"] == { + "old": ["old type"], + "new": ["new type"], + } diff --git a/tests/migrations/test_migrations.py b/tests/migrations/test_migrations.py index 5059002..1c1b66d 100644 --- a/tests/migrations/test_migrations.py +++ b/tests/migrations/test_migrations.py @@ -39,3 +39,11 @@ def test_tables_not_exist_migrate_down_to_base(alembic_runner: MigrationContext, assert "denied_domains" not in tables assert "financial_institutions" not in tables assert "financial_institution_domains" not in tables + + +def test_fi_history_tables_8106d83ff594(alembic_runner: MigrationContext, alembic_engine: Engine): + alembic_runner.migrate_up_to("8106d83ff594") + inspector = sqlalchemy.inspect(alembic_engine) + tables = inspector.get_table_names() + assert "financial_institutions_history" in tables + assert "fi_to_type_mapping_history" in tables diff --git a/tests/migrations/test_schema.py b/tests/migrations/test_schema.py index 88273ba..fcf9092 100644 --- a/tests/migrations/test_schema.py +++ b/tests/migrations/test_schema.py @@ -64,3 +64,12 @@ def test_fi_types_table_6826f05140cd(alembic_runner: MigrationContext, alembic_e columns_names = [column.get("name") for column in columns] assert columns_names == expected_columns + + +def test_fi_versioning_tables_3f893e52d05c(alembic_runner: MigrationContext, alembic_engine: Engine): + alembic_runner.migrate_up_to("3f893e52d05c") + inspector = sqlalchemy.inspect(alembic_engine) + fi_columns = inspector.get_columns("financial_institutions") + assert "version" in [column.get("name") for column in fi_columns] + mapping_columns = inspector.get_columns("fi_to_type_mapping") + assert "version" in [column.get("name") for column in mapping_columns]