Skip to content

Commit

Permalink
feat: add versioning for type details
Browse files Browse the repository at this point in the history
  • Loading branch information
lchen-2101 committed Feb 13, 2024
1 parent df5a0d9 commit 761d366
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 19 deletions.
54 changes: 39 additions & 15 deletions src/entities/listeners.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,53 @@
from typing import List
from sqlalchemy import Connection, Table, event, inspect
from sqlalchemy.orm import Mapper

from .models.dao import Base, FinancialInstitutionDao
from .models.dao import Base, FinancialInstitutionDao, SblTypeMappingDao
from entities.engine.engine import engine


def inspect_fi(fi: FinancialInstitutionDao):
changes = {}
new_version = fi.version + 1 if fi.version else 1
state = inspect(fi)
for attr in state.attrs:
if attr.key == "event_time":
continue
if attr.key == "sbl_institution_types":
field_changes = inspect_type_fields(attr.value)
if attr.history.has_changes() or field_changes:
old_types = {"old": [o.as_db_dict() for o in attr.history.deleted]} if attr.history.deleted else {}
new_types = (
{"new": [{**n.as_db_dict(), "version": new_version} for n in attr.history.added]}
if attr.history.added
else {}
)
changes[attr.key] = {**old_types, **new_types, "field_changes": field_changes}
elif attr.history.has_changes():
changes[attr.key] = {"old": attr.history.deleted, "new": attr.history.added}
return changes


def inspect_type_fields(types: List[SblTypeMappingDao], fields: List[str] = ["details"]):
changes = []
for t in types:
state = inspect(t)
attr_changes = {
attr.key: {"old": attr.history.deleted, "new": attr.history.added}
for attr in state.attrs
if attr.key in fields and attr.history.has_changes()
}
if attr_changes:
changes.append({**t.as_db_dict(), **attr_changes})
return changes


def _setup_fi_history(fi_history: Table, mapping_history: Table):
def _insert_history(
mapper: Mapper[FinancialInstitutionDao], connection: Connection, target: FinancialInstitutionDao
):
new_version = target.version + 1 if target.version else 1
changes = {}
state = inspect(target)
for attr in state.attrs:
if attr.key == "event_time":
continue
attr_hist = attr.load_history()
if not attr_hist.has_changes():
continue
if attr.key == "sbl_institution_types":
old_types = [o.as_db_dict() for o in attr_hist.deleted]
new_types = [{**n.as_db_dict(), "version": new_version} for n in attr_hist.added]
changes[attr.key] = {"old": old_types, "new": new_types}
else:
changes[attr.key] = {"old": attr_hist.deleted, "new": attr_hist.added}
changes = inspect_fi(target)
if changes:
target.version = new_version
for t in target.sbl_institution_types:
Expand Down
63 changes: 59 additions & 4 deletions tests/entities/test_listeners.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from unittest.mock import Mock
import pytest
from unittest.mock import Mock, call
from pytest_mock import MockerFixture

from sqlalchemy import Connection, Table
from sqlalchemy import Connection, Insert, Table
from sqlalchemy.orm import Mapper, InstanceState, AttributeState
from sqlalchemy.orm.attributes import History

from entities.models.dao import FinancialInstitutionDao, SBLInstitutionTypeDao, SblTypeMappingDao

Expand Down Expand Up @@ -37,6 +39,14 @@ class TestListeners:
modified_by="test_user_id",
)

@pytest.fixture(autouse=True)
def setup(self):
self.fi_history.reset_mock()
self.fi_history.columns = {"name": "test"}
self.mapping_history.reset_mock()
self.mapper.reset_mock()
self.connection.reset_mock()

def test_fi_history_listener(self, mocker: MockerFixture):
inspect_mock = mocker.patch("entities.listeners.inspect")
attr_mock1: AttributeState = Mock(AttributeState)
Expand All @@ -45,10 +55,55 @@ def test_fi_history_listener(self, mocker: MockerFixture):
attr_mock2.key = "event_time"
state_mock: InstanceState = Mock(InstanceState)
state_mock.attrs = [attr_mock1, attr_mock2]
self.fi_history.columns = {"name": "test"}
inspect_mock.return_value = state_mock
fi_listener = _setup_fi_history(self.fi_history, self.mapping_history)
fi_listener(self.mapper, self.connection, self.target)
inspect_mock.assert_called_once_with(self.target)
attr_mock1.load_history.assert_called_once()
self.fi_history.insert.assert_called_once()
self.mapping_history.insert.assert_called_once()

def _get_fi_inspect_mock(self):
fi_attr_mock: AttributeState = Mock(AttributeState)
fi_attr_mock.key = "sbl_institution_types"
fi_attr_mock.value = self.target.sbl_institution_types
fi_attr_mock.history = History(added=[], deleted=[], unchanged=[])
fi_state_mock: InstanceState = Mock(InstanceState)
fi_state_mock.attrs = [fi_attr_mock]
return fi_state_mock

def _get_mapping_inspect_mock(self):
mapping_attr_mock: AttributeState = Mock(AttributeState)
mapping_attr_mock.key = "details"
mapping_attr_mock.history = History(added=["new type"], deleted=["old type"], unchanged=[])
mapping_state_mock: InstanceState = Mock(InstanceState)
mapping_state_mock.attrs = [mapping_attr_mock]
return mapping_state_mock

def test_fi_mapping_changed(self, mocker: MockerFixture):
inspect_mock = mocker.patch("entities.listeners.inspect")
fi_state_mock = self._get_fi_inspect_mock()
mapping_state_mock = self._get_mapping_inspect_mock()

def inspect_side_effect(inspect_target):
if inspect_target == self.target:
return fi_state_mock
elif inspect_target == self.target.sbl_institution_types[0]:
return mapping_state_mock

inspect_mock.side_effect = inspect_side_effect
fi_insert_mock = Mock(Insert)
self.fi_history.insert.return_value = fi_insert_mock
mapping_insert_mock = Mock(Insert)
self.mapping_history.insert.return_value = mapping_insert_mock
fi_listener = _setup_fi_history(self.fi_history, self.mapping_history)
fi_listener(self.mapper, self.connection, self.target)
inspect_mock.assert_has_calls([call(self.target), call(self.target.sbl_institution_types[0])])
self.fi_history.insert.assert_called_once()
self.mapping_history.insert.assert_called_once()
fi_insert_mock.values.assert_called_once()
args, _ = fi_insert_mock.values.call_args
insert_data = args[0]
assert insert_data["changeset"]["sbl_institution_types"]["field_changes"][0]["details"] == {
"old": ["old type"],
"new": ["new type"],
}
8 changes: 8 additions & 0 deletions tests/migrations/test_migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,11 @@ def test_tables_not_exist_migrate_down_to_base(alembic_runner: MigrationContext,
assert "denied_domains" not in tables
assert "financial_institutions" not in tables
assert "financial_institution_domains" not in tables


def test_fi_history_tables_8106d83ff594(alembic_runner: MigrationContext, alembic_engine: Engine):
alembic_runner.migrate_up_to("8106d83ff594")
inspector = sqlalchemy.inspect(alembic_engine)
tables = inspector.get_table_names()
assert "financial_institutions_history" in tables
assert "fi_to_type_mapping_history" in tables
9 changes: 9 additions & 0 deletions tests/migrations/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,12 @@ def test_fi_types_table_6826f05140cd(alembic_runner: MigrationContext, alembic_e
columns_names = [column.get("name") for column in columns]

assert columns_names == expected_columns


def test_fi_versioning_tables_3f893e52d05c(alembic_runner: MigrationContext, alembic_engine: Engine):
alembic_runner.migrate_up_to("3f893e52d05c")
inspector = sqlalchemy.inspect(alembic_engine)
fi_columns = inspector.get_columns("financial_institutions")
assert "version" in [column.get("name") for column in fi_columns]
mapping_columns = inspector.get_columns("fi_to_type_mapping")
assert "version" in [column.get("name") for column in mapping_columns]

0 comments on commit 761d366

Please sign in to comment.