Skip to content

Commit

Permalink
feat: add patch for sbl institution types
Browse files Browse the repository at this point in the history
  • Loading branch information
lchen-2101 committed Mar 6, 2024
1 parent ff20bdb commit 8c30d9b
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 21 deletions.
4 changes: 4 additions & 0 deletions src/entities/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
"SblTypeMappingDao",
"SblTypeAssociationDto",
"SblTypeAssociationDetailsDto",
"SblTypeAssociationPatchDto",
"VersionedData",
]

from .dao import (
Expand All @@ -46,4 +48,6 @@
AddressStateDto,
SblTypeAssociationDto,
SblTypeAssociationDetailsDto,
SblTypeAssociationPatchDto,
VersionedData,
)
6 changes: 6 additions & 0 deletions src/entities/models/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ class SblTypeMappingDao(Base):
details: Mapped[str] = mapped_column(nullable=True)
modified_by: Mapped[str] = mapped_column()

def __eq__(self, other: "SblTypeMappingDao") -> bool:
return self.lei == other.lei and self.type_id == other.type_id and self.details == other.details

def __hash__(self) -> int:
return hash((self.lei, self.type_id, self.details))

def as_db_dict(self):
data = {}
for attr, column in inspect(self.__class__).c.items():
Expand Down
14 changes: 13 additions & 1 deletion src/entities/models/dto.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
from typing import List, Set
from typing import Generic, List, Set, Sequence
from pydantic import BaseModel, model_validator
from typing import TypeVar

T = TypeVar("T")


class VersionedData(BaseModel, Generic[T]):
version: int
data: T


class FinancialInsitutionDomainBase(BaseModel):
Expand Down Expand Up @@ -45,6 +53,10 @@ class Config:
from_attributes = True


class SblTypeAssociationPatchDto(BaseModel):
sbl_institution_types: Sequence[SblTypeAssociationDto | str]


class FinancialInstitutionDto(FinancialInstitutionBase):
tax_id: str | None = None
rssd_id: int | None = None
Expand Down
52 changes: 35 additions & 17 deletions src/entities/repos/institutions_repo.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import List
from typing import List, Sequence, Set

from sqlalchemy import select, func
from sqlalchemy.orm import joinedload
from sqlalchemy.ext.asyncio import AsyncSession

from regtech_api_commons.models import AuthenticatedUser

from .repo_utils import query_type
from .repo_utils import get_associated_sbl_types, query_type

from entities.models import (
FinancialInstitutionDao,
Expand All @@ -18,17 +18,17 @@
DeniedDomainDao,
AddressStateDao,
FederalRegulatorDao,
SblTypeMappingDao,
SblTypeAssociationDto,
)


async def get_institutions(
session: AsyncSession,
leis: List[str] = None,
leis: List[str] | None = None,
domain: str = "",
page: int = 0,
count: int = 100,
) -> List[FinancialInstitutionDao]:
) -> Sequence[FinancialInstitutionDao]:
async with session.begin():
stmt = (
select(FinancialInstitutionDao)
Expand All @@ -44,7 +44,7 @@ async def get_institutions(
return res.unique().all()


async def get_institution(session: AsyncSession, lei: str) -> FinancialInstitutionDao:
async def get_institution(session: AsyncSession, lei: str) -> FinancialInstitutionDao | None:
async with session.begin():
stmt = (
select(FinancialInstitutionDao)
Expand All @@ -54,19 +54,19 @@ async def get_institution(session: AsyncSession, lei: str) -> FinancialInstituti
return await session.scalar(stmt)


async def get_sbl_types(session: AsyncSession) -> List[SBLInstitutionTypeDao]:
async def get_sbl_types(session: AsyncSession) -> Sequence[SBLInstitutionTypeDao]:
return await query_type(session, SBLInstitutionTypeDao)


async def get_hmda_types(session: AsyncSession) -> List[HMDAInstitutionTypeDao]:
async def get_hmda_types(session: AsyncSession) -> Sequence[HMDAInstitutionTypeDao]:
return await query_type(session, HMDAInstitutionTypeDao)


async def get_address_states(session: AsyncSession) -> List[AddressStateDao]:
async def get_address_states(session: AsyncSession) -> Sequence[AddressStateDao]:
return await query_type(session, AddressStateDao)


async def get_federal_regulators(session: AsyncSession) -> List[FederalRegulatorDao]:
async def get_federal_regulators(session: AsyncSession) -> Sequence[FederalRegulatorDao]:
return await query_type(session, FederalRegulatorDao)


Expand All @@ -79,12 +79,7 @@ async def upsert_institution(
fi_data.pop("version", None)

if "sbl_institution_types" in fi_data:
types_association = [
SblTypeMappingDao(type_id=t, lei=fi.lei, modified_by=user.id)
if isinstance(t, str)
else SblTypeMappingDao(type_id=t.id, details=t.details, lei=fi.lei, modified_by=user.id)
for t in fi.sbl_institution_types
]
types_association = get_associated_sbl_types(fi.lei, user.id, fi.sbl_institution_types)
fi_data["sbl_institution_types"] = types_association

db_fi = await session.merge(FinancialInstitutionDao(**fi_data, modified_by=user.id))
Expand All @@ -93,9 +88,32 @@ async def upsert_institution(
return db_fi


async def update_sbl_types(
session: AsyncSession, user: AuthenticatedUser, lei: str, sbl_types: Sequence[SblTypeAssociationDto | str]
) -> FinancialInstitutionDao | None:
fi = await get_institution(session, lei)
if fi:
new_types = set(get_associated_sbl_types(lei, user.id, sbl_types))
old_types = set(fi.sbl_institution_types)
add_types = new_types.difference(old_types)
remove_types = old_types.difference(new_types)

fi.sbl_institution_types = [type for type in fi.sbl_institution_types if type not in remove_types]
fi.sbl_institution_types.extend(add_types)
for type in fi.sbl_institution_types:
type.version = fi.version
await session.commit()
"""
load the async relational attributes so dto can be properly serialized
"""
for type in fi.sbl_institution_types:
await type.awaitable_attrs.sbl_type
return fi


async def add_domains(
session: AsyncSession, lei: str, domains: List[FinancialInsitutionDomainCreate]
) -> List[FinancialInstitutionDomainDao]:
) -> Set[FinancialInstitutionDomainDao]:
async with session.begin():
daos = set(
map(
Expand Down
18 changes: 15 additions & 3 deletions src/entities/repos/repo_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,24 @@
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, TypeVar
from typing import Sequence, TypeVar, Type
from entities.models import Base, SblTypeMappingDao, SblTypeAssociationDto

T = TypeVar("T")
T = TypeVar("T", bound=Base)


async def query_type(session: AsyncSession, type: T) -> List[T]:
async def query_type(session: AsyncSession, type: Type[T]) -> Sequence[T]:
async with session.begin():
stmt = select(type)
res = await session.scalars(stmt)
return res.all()


def get_associated_sbl_types(
lei: str, user_id: str, types: Sequence[SblTypeAssociationDto | str]
) -> Sequence[SblTypeMappingDao]:
return [
SblTypeMappingDao(type_id=t, lei=lei, modified_by=user_id)
if isinstance(t, str)
else SblTypeMappingDao(type_id=t.id, details=t.details, lei=lei, modified_by=user_id)
for t in types
]
16 changes: 16 additions & 0 deletions src/routers/institutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
InstitutionTypeDto,
AddressStateDto,
FederalRegulatorDto,
SblTypeAssociationDetailsDto,
SblTypeAssociationPatchDto,
VersionedData,
)
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.authentication import requires
Expand Down Expand Up @@ -105,6 +108,19 @@ async def get_institution(
return res


@router.patch("/{lei}/types/{type}", response_model=VersionedData[List[SblTypeAssociationDetailsDto]] | None)
@requires("authenticated")
async def update_sbl_types(request: Request, lei: str, type: InstitutionType, types_patch: SblTypeAssociationPatchDto):
match type:
case "sbl":
fi = await repo.update_sbl_types(
request.state.db_session, request.user, lei, types_patch.sbl_institution_types
)
return VersionedData(version=fi.version, data=fi.sbl_institution_types) if fi else None
case "hmda":
raise HTTPException(status_code=HTTPStatus.NOT_IMPLEMENTED, detail="HMDA type not yet supported")


@router.post("/{lei}/domains/", response_model=List[FinancialInsitutionDomainDto], dependencies=[Depends(check_domain)])
@requires(["query-groups", "manage-users"])
async def add_domains(
Expand Down
39 changes: 39 additions & 0 deletions tests/api/routers/test_institutions_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from http import HTTPStatus
from unittest.mock import Mock, ANY

from fastapi import FastAPI
Expand All @@ -13,6 +14,7 @@
HMDAInstitutionTypeDao,
SBLInstitutionTypeDao,
SblTypeMappingDao,
SblTypeAssociationDto,
)


Expand Down Expand Up @@ -425,3 +427,40 @@ def test_get_federal_regulators(self, mocker: MockerFixture, app_fixture: FastAP
client = TestClient(app_fixture)
res = client.get("/v1/institutions/regulators")
assert res.status_code == 200

def test_update_institution_types(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
mock = mocker.patch("entities.repos.institutions_repo.update_sbl_types")
client = TestClient(app_fixture)
test_lei = "TESTBANK123"
res = client.patch(
f"/v1/institutions/{test_lei}/types/sbl",
json={"sbl_institution_types": ["1", {"id": "2"}, {"id": "13", "details": "test"}]},
)
assert res.status_code == HTTPStatus.OK
mock.assert_called_once_with(
ANY, ANY, test_lei, ["1", SblTypeAssociationDto(id="2"), SblTypeAssociationDto(id="13", details="test")]
)

def test_update_unsupported_institution_types(
self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock
):
mock = mocker.patch("entities.repos.institutions_repo.update_sbl_types")
client = TestClient(app_fixture)
test_lei = "TESTBANK123"
res = client.patch(
f"/v1/institutions/{test_lei}/types/hmda",
json={"sbl_institution_types": ["1", {"id": "2"}, {"id": "13", "details": "test"}]},
)
assert res.status_code == HTTPStatus.NOT_IMPLEMENTED
mock.assert_not_called()

def test_update_wrong_institution_types(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
mock = mocker.patch("entities.repos.institutions_repo.update_sbl_types")
client = TestClient(app_fixture)
test_lei = "TESTBANK123"
res = client.patch(
f"/v1/institutions/{test_lei}/types/test",
json={"sbl_institution_types": ["1", {"id": "2"}, {"id": "13", "details": "test"}]},
)
assert res.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
mock.assert_not_called()
33 changes: 33 additions & 0 deletions tests/entities/repos/test_institutions_repo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from pytest_mock import MockerFixture
from sqlalchemy.ext.asyncio import AsyncSession

from entities.models import (
Expand Down Expand Up @@ -333,3 +334,35 @@ async def test_institution_mapped_to_sbl_it_valid(self, query_session: AsyncSess
async def test_institution_mapped_to_sbl_it_invalid(self, query_session: AsyncSession):
res = await repo.get_institutions(query_session, leis=["TESTBANK456"])
assert res[0].sbl_institution_types[0].sbl_type.name != "Test SBL Instituion ID 1"

async def test_update_sbl_institution_types(
self, mocker: MockerFixture, query_session: AsyncSession, transaction_session: AsyncSession
):
test_lei = "TESTBANK123"
existing_inst = await repo.get_institution(query_session, test_lei)
sbl_types = [
SblTypeAssociationDto(id="1"),
SblTypeAssociationDto(id="2"),
SblTypeAssociationDto(id="13", details="test"),
]
commit_spy = mocker.patch.object(transaction_session, "commit", wraps=transaction_session.commit)
updated_inst = await repo.update_sbl_types(transaction_session, self.auth_user, test_lei, sbl_types)
commit_spy.assert_called_once()
assert len(existing_inst.sbl_institution_types) == 1
assert len(updated_inst.sbl_institution_types) == 3
diffs = set(updated_inst.sbl_institution_types).difference(set(existing_inst.sbl_institution_types))
assert len(diffs) == 2

async def test_update_sbl_institution_types_inst_non_exist(
self, mocker: MockerFixture, transaction_session: AsyncSession
):
test_lei = "NONEXISTINGBANK"
sbl_types = [
SblTypeAssociationDto(id="1"),
SblTypeAssociationDto(id="2"),
SblTypeAssociationDto(id="13", details="test"),
]
commit_spy = mocker.patch.object(transaction_session, "commit", wraps=transaction_session.commit)
res = await repo.update_sbl_types(transaction_session, self.auth_user, test_lei, sbl_types)
commit_spy.assert_not_called()
assert res is None

0 comments on commit 8c30d9b

Please sign in to comment.