diff --git a/src/entities/models/__init__.py b/src/entities/models/__init__.py index d61563a..89ec44b 100644 --- a/src/entities/models/__init__.py +++ b/src/entities/models/__init__.py @@ -20,6 +20,8 @@ "SblTypeMappingDao", "SblTypeAssociationDto", "SblTypeAssociationDetailsDto", + "SblTypeAssociationPatchDto", + "VersionedData", ] from .dao import ( @@ -46,4 +48,6 @@ AddressStateDto, SblTypeAssociationDto, SblTypeAssociationDetailsDto, + SblTypeAssociationPatchDto, + VersionedData, ) diff --git a/src/entities/models/dao.py b/src/entities/models/dao.py index 0eb3315..57502c6 100644 --- a/src/entities/models/dao.py +++ b/src/entities/models/dao.py @@ -23,6 +23,12 @@ class SblTypeMappingDao(Base): details: Mapped[str] = mapped_column(nullable=True) modified_by: Mapped[str] = mapped_column() + def __eq__(self, other: "SblTypeMappingDao") -> bool: + return self.lei == other.lei and self.type_id == other.type_id and self.details == other.details + + def __hash__(self) -> int: + return hash((self.lei, self.type_id, self.details)) + def as_db_dict(self): data = {} for attr, column in inspect(self.__class__).c.items(): diff --git a/src/entities/models/dto.py b/src/entities/models/dto.py index c7f312f..94b5dc1 100644 --- a/src/entities/models/dto.py +++ b/src/entities/models/dto.py @@ -1,5 +1,13 @@ -from typing import List, Set +from typing import Generic, List, Set, Sequence from pydantic import BaseModel, model_validator +from typing import TypeVar + +T = TypeVar("T") + + +class VersionedData(BaseModel, Generic[T]): + version: int + data: T class FinancialInsitutionDomainBase(BaseModel): @@ -45,6 +53,10 @@ class Config: from_attributes = True +class SblTypeAssociationPatchDto(BaseModel): + sbl_institution_types: Sequence[SblTypeAssociationDto | str] + + class FinancialInstitutionDto(FinancialInstitutionBase): tax_id: str | None = None rssd_id: int | None = None diff --git a/src/entities/repos/institutions_repo.py b/src/entities/repos/institutions_repo.py index b171a71..4d8f28c 100644 --- a/src/entities/repos/institutions_repo.py +++ b/src/entities/repos/institutions_repo.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Sequence, Set from sqlalchemy import select, func from sqlalchemy.orm import joinedload @@ -6,7 +6,7 @@ from regtech_api_commons.models import AuthenticatedUser -from .repo_utils import query_type +from .repo_utils import get_associated_sbl_types, query_type from entities.models import ( FinancialInstitutionDao, @@ -18,17 +18,17 @@ DeniedDomainDao, AddressStateDao, FederalRegulatorDao, - SblTypeMappingDao, + SblTypeAssociationDto, ) async def get_institutions( session: AsyncSession, - leis: List[str] = None, + leis: List[str] | None = None, domain: str = "", page: int = 0, count: int = 100, -) -> List[FinancialInstitutionDao]: +) -> Sequence[FinancialInstitutionDao]: async with session.begin(): stmt = ( select(FinancialInstitutionDao) @@ -44,7 +44,7 @@ async def get_institutions( return res.unique().all() -async def get_institution(session: AsyncSession, lei: str) -> FinancialInstitutionDao: +async def get_institution(session: AsyncSession, lei: str) -> FinancialInstitutionDao | None: async with session.begin(): stmt = ( select(FinancialInstitutionDao) @@ -54,19 +54,19 @@ async def get_institution(session: AsyncSession, lei: str) -> FinancialInstituti return await session.scalar(stmt) -async def get_sbl_types(session: AsyncSession) -> List[SBLInstitutionTypeDao]: +async def get_sbl_types(session: AsyncSession) -> Sequence[SBLInstitutionTypeDao]: return await query_type(session, SBLInstitutionTypeDao) -async def get_hmda_types(session: AsyncSession) -> List[HMDAInstitutionTypeDao]: +async def get_hmda_types(session: AsyncSession) -> Sequence[HMDAInstitutionTypeDao]: return await query_type(session, HMDAInstitutionTypeDao) -async def get_address_states(session: AsyncSession) -> List[AddressStateDao]: +async def get_address_states(session: AsyncSession) -> Sequence[AddressStateDao]: return await query_type(session, AddressStateDao) -async def get_federal_regulators(session: AsyncSession) -> List[FederalRegulatorDao]: +async def get_federal_regulators(session: AsyncSession) -> Sequence[FederalRegulatorDao]: return await query_type(session, FederalRegulatorDao) @@ -79,12 +79,7 @@ async def upsert_institution( fi_data.pop("version", None) if "sbl_institution_types" in fi_data: - types_association = [ - SblTypeMappingDao(type_id=t, lei=fi.lei, modified_by=user.id) - if isinstance(t, str) - else SblTypeMappingDao(type_id=t.id, details=t.details, lei=fi.lei, modified_by=user.id) - for t in fi.sbl_institution_types - ] + types_association = get_associated_sbl_types(fi.lei, user.id, fi.sbl_institution_types) fi_data["sbl_institution_types"] = types_association db_fi = await session.merge(FinancialInstitutionDao(**fi_data, modified_by=user.id)) @@ -93,9 +88,31 @@ async def upsert_institution( return db_fi +async def update_sbl_types( + session: AsyncSession, user: AuthenticatedUser, lei: str, sbl_types: Sequence[SblTypeAssociationDto | str] +) -> FinancialInstitutionDao | None: + if fi := await get_institution(session, lei): + new_types = set(get_associated_sbl_types(lei, user.id, sbl_types)) + old_types = set(fi.sbl_institution_types) + add_types = new_types.difference(old_types) + remove_types = old_types.difference(new_types) + + fi.sbl_institution_types = [type for type in fi.sbl_institution_types if type not in remove_types] + fi.sbl_institution_types.extend(add_types) + for type in fi.sbl_institution_types: + type.version = fi.version + await session.commit() + """ + load the async relational attributes so dto can be properly serialized + """ + for type in fi.sbl_institution_types: + await type.awaitable_attrs.sbl_type + return fi + + async def add_domains( session: AsyncSession, lei: str, domains: List[FinancialInsitutionDomainCreate] -) -> List[FinancialInstitutionDomainDao]: +) -> Set[FinancialInstitutionDomainDao]: async with session.begin(): daos = set( map( diff --git a/src/entities/repos/repo_utils.py b/src/entities/repos/repo_utils.py index faa48be..15c4b11 100644 --- a/src/entities/repos/repo_utils.py +++ b/src/entities/repos/repo_utils.py @@ -1,12 +1,24 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from typing import List, TypeVar +from typing import Sequence, TypeVar, Type +from entities.models import Base, SblTypeMappingDao, SblTypeAssociationDto -T = TypeVar("T") +T = TypeVar("T", bound=Base) -async def query_type(session: AsyncSession, type: T) -> List[T]: +async def query_type(session: AsyncSession, type: Type[T]) -> Sequence[T]: async with session.begin(): stmt = select(type) res = await session.scalars(stmt) return res.all() + + +def get_associated_sbl_types( + lei: str, user_id: str, types: Sequence[SblTypeAssociationDto | str] +) -> Sequence[SblTypeMappingDao]: + return [ + SblTypeMappingDao(type_id=t, lei=lei, modified_by=user_id) + if isinstance(t, str) + else SblTypeMappingDao(type_id=t.id, details=t.details, lei=lei, modified_by=user_id) + for t in types + ] diff --git a/src/routers/institutions.py b/src/routers/institutions.py index 49556bf..84141b3 100644 --- a/src/routers/institutions.py +++ b/src/routers/institutions.py @@ -1,4 +1,4 @@ -from fastapi import Depends, Request, HTTPException +from fastapi import Depends, Request, HTTPException, Response from http import HTTPStatus from regtech_api_commons.oauth2.oauth2_admin import OAuth2Admin from config import kc_settings @@ -16,6 +16,9 @@ InstitutionTypeDto, AddressStateDto, FederalRegulatorDto, + SblTypeAssociationDetailsDto, + SblTypeAssociationPatchDto, + VersionedData, ) from sqlalchemy.ext.asyncio import AsyncSession from starlette.authentication import requires @@ -105,6 +108,36 @@ async def get_institution( return res +@router.get("/{lei}/types/{type}", response_model=VersionedData[List[SblTypeAssociationDetailsDto]] | None) +@requires("authenticated") +async def get_types(request: Request, response: Response, lei: str, type: InstitutionType): + match type: + case "sbl": + if fi := await repo.get_institution(request.state.db_session, lei): + return VersionedData(version=fi.version, data=fi.sbl_institution_types) + else: + response.status_code = HTTPStatus.NO_CONTENT + case "hmda": + raise HTTPException(status_code=HTTPStatus.NOT_IMPLEMENTED, detail="HMDA type not yet supported") + + +@router.put("/{lei}/types/{type}", response_model=VersionedData[List[SblTypeAssociationDetailsDto]] | None) +@requires("authenticated") +async def update_types( + request: Request, response: Response, lei: str, type: InstitutionType, types_patch: SblTypeAssociationPatchDto +): + match type: + case "sbl": + if fi := await repo.update_sbl_types( + request.state.db_session, request.user, lei, types_patch.sbl_institution_types + ): + return VersionedData(version=fi.version, data=fi.sbl_institution_types) if fi else None + else: + response.status_code = HTTPStatus.NO_CONTENT + case "hmda": + raise HTTPException(status_code=HTTPStatus.NOT_IMPLEMENTED, detail="HMDA type not yet supported") + + @router.post("/{lei}/domains/", response_model=List[FinancialInsitutionDomainDto], dependencies=[Depends(check_domain)]) @requires(["query-groups", "manage-users"]) async def add_domains( diff --git a/tests/api/routers/test_institutions_api.py b/tests/api/routers/test_institutions_api.py index 59e58b7..0fd11d6 100644 --- a/tests/api/routers/test_institutions_api.py +++ b/tests/api/routers/test_institutions_api.py @@ -1,3 +1,4 @@ +from http import HTTPStatus from unittest.mock import Mock, ANY from fastapi import FastAPI @@ -13,6 +14,7 @@ HMDAInstitutionTypeDao, SBLInstitutionTypeDao, SblTypeMappingDao, + SblTypeAssociationDto, ) @@ -425,3 +427,105 @@ def test_get_federal_regulators(self, mocker: MockerFixture, app_fixture: FastAP client = TestClient(app_fixture) res = client.get("/v1/institutions/regulators") assert res.status_code == 200 + + def test_get_sbl_types(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock): + inst_version = 2 + get_institution_mock = mocker.patch("entities.repos.institutions_repo.get_institution") + get_institution_mock.return_value = FinancialInstitutionDao( + version=inst_version, + name="Test Bank 123", + lei="TESTBANK123", + is_active=True, + domains=[FinancialInstitutionDomainDao(domain="test.bank", lei="TESTBANK123")], + tax_id="123456789", + rssd_id=1234, + primary_federal_regulator_id="FRI1", + primary_federal_regulator=FederalRegulatorDao(id="FRI1", name="FRI1"), + hmda_institution_type_id="HIT1", + hmda_institution_type=HMDAInstitutionTypeDao(id="HIT1", name="HIT1"), + sbl_institution_types=[SblTypeMappingDao(sbl_type=SBLInstitutionTypeDao(id="SIT1", name="SIT1"))], + hq_address_street_1="Test Address Street 1", + hq_address_street_2="", + hq_address_city="Test City 1", + hq_address_state_code="GA", + hq_address_state=AddressStateDao(code="GA", name="Georgia"), + hq_address_zip="00000", + parent_lei="PARENTTESTBANK123", + parent_legal_name="PARENT TEST BANK 123", + parent_rssd_id=12345, + top_holder_lei="TOPHOLDERLEI123", + top_holder_legal_name="TOP HOLDER LEI 123", + top_holder_rssd_id=123456, + ) + client = TestClient(app_fixture) + test_lei = "TESTBANK123" + res = client.get(f"/v1/institutions/{test_lei}/types/sbl") + assert res.status_code == HTTPStatus.OK + result = res.json() + assert len(result["data"]) == 1 + assert result["version"] == inst_version + assert result["data"][0] == {"sbl_type": {"id": "SIT1", "name": "SIT1"}, "details": None} + + def test_get_sbl_types_no_institution(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock): + get_institution_mock = mocker.patch("entities.repos.institutions_repo.get_institution") + get_institution_mock.return_value = None + client = TestClient(app_fixture) + test_lei = "TESTBANK123" + res = client.get(f"/v1/institutions/{test_lei}/types/sbl") + assert res.status_code == HTTPStatus.NO_CONTENT + + def test_get_hmda_types(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock): + client = TestClient(app_fixture) + test_lei = "TESTBANK123" + res = client.get(f"/v1/institutions/{test_lei}/types/hmda") + assert res.status_code == HTTPStatus.NOT_IMPLEMENTED + + def test_update_institution_types(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock): + mock = mocker.patch("entities.repos.institutions_repo.update_sbl_types") + client = TestClient(app_fixture) + test_lei = "TESTBANK123" + res = client.put( + f"/v1/institutions/{test_lei}/types/sbl", + json={"sbl_institution_types": ["1", {"id": "2"}, {"id": "13", "details": "test"}]}, + ) + assert res.status_code == HTTPStatus.OK + mock.assert_called_once_with( + ANY, ANY, test_lei, ["1", SblTypeAssociationDto(id="2"), SblTypeAssociationDto(id="13", details="test")] + ) + + def test_update_non_existing_institution_types( + self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock + ): + get_institution_mock = mocker.patch("entities.repos.institutions_repo.get_institution") + get_institution_mock.return_value = None + client = TestClient(app_fixture) + test_lei = "TESTBANK123" + res = client.put( + f"/v1/institutions/{test_lei}/types/sbl", + json={"sbl_institution_types": ["1", {"id": "2"}, {"id": "13", "details": "test"}]}, + ) + assert res.status_code == HTTPStatus.NO_CONTENT + + def test_update_unsupported_institution_types( + self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock + ): + mock = mocker.patch("entities.repos.institutions_repo.update_sbl_types") + client = TestClient(app_fixture) + test_lei = "TESTBANK123" + res = client.put( + f"/v1/institutions/{test_lei}/types/hmda", + json={"sbl_institution_types": ["1", {"id": "2"}, {"id": "13", "details": "test"}]}, + ) + assert res.status_code == HTTPStatus.NOT_IMPLEMENTED + mock.assert_not_called() + + def test_update_wrong_institution_types(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock): + mock = mocker.patch("entities.repos.institutions_repo.update_sbl_types") + client = TestClient(app_fixture) + test_lei = "TESTBANK123" + res = client.put( + f"/v1/institutions/{test_lei}/types/test", + json={"sbl_institution_types": ["1", {"id": "2"}, {"id": "13", "details": "test"}]}, + ) + assert res.status_code == HTTPStatus.UNPROCESSABLE_ENTITY + mock.assert_not_called() diff --git a/tests/entities/repos/test_institutions_repo.py b/tests/entities/repos/test_institutions_repo.py index f764b45..152b383 100644 --- a/tests/entities/repos/test_institutions_repo.py +++ b/tests/entities/repos/test_institutions_repo.py @@ -1,4 +1,5 @@ import pytest +from pytest_mock import MockerFixture from sqlalchemy.ext.asyncio import AsyncSession from entities.models import ( @@ -333,3 +334,35 @@ async def test_institution_mapped_to_sbl_it_valid(self, query_session: AsyncSess async def test_institution_mapped_to_sbl_it_invalid(self, query_session: AsyncSession): res = await repo.get_institutions(query_session, leis=["TESTBANK456"]) assert res[0].sbl_institution_types[0].sbl_type.name != "Test SBL Instituion ID 1" + + async def test_update_sbl_institution_types( + self, mocker: MockerFixture, query_session: AsyncSession, transaction_session: AsyncSession + ): + test_lei = "TESTBANK123" + existing_inst = await repo.get_institution(query_session, test_lei) + sbl_types = [ + SblTypeAssociationDto(id="1"), + SblTypeAssociationDto(id="2"), + SblTypeAssociationDto(id="13", details="test"), + ] + commit_spy = mocker.patch.object(transaction_session, "commit", wraps=transaction_session.commit) + updated_inst = await repo.update_sbl_types(transaction_session, self.auth_user, test_lei, sbl_types) + commit_spy.assert_called_once() + assert len(existing_inst.sbl_institution_types) == 1 + assert len(updated_inst.sbl_institution_types) == 3 + diffs = set(updated_inst.sbl_institution_types).difference(set(existing_inst.sbl_institution_types)) + assert len(diffs) == 2 + + async def test_update_sbl_institution_types_inst_non_exist( + self, mocker: MockerFixture, transaction_session: AsyncSession + ): + test_lei = "NONEXISTINGBANK" + sbl_types = [ + SblTypeAssociationDto(id="1"), + SblTypeAssociationDto(id="2"), + SblTypeAssociationDto(id="13", details="test"), + ] + commit_spy = mocker.patch.object(transaction_session, "commit", wraps=transaction_session.commit) + res = await repo.update_sbl_types(transaction_session, self.auth_user, test_lei, sbl_types) + commit_spy.assert_not_called() + assert res is None