Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add patch for sbl institution types #110

Merged
merged 4 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/entities/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
"SblTypeMappingDao",
"SblTypeAssociationDto",
"SblTypeAssociationDetailsDto",
"SblTypeAssociationPatchDto",
"VersionedData",
]

from .dao import (
Expand All @@ -46,4 +48,6 @@
AddressStateDto,
SblTypeAssociationDto,
SblTypeAssociationDetailsDto,
SblTypeAssociationPatchDto,
VersionedData,
)
6 changes: 6 additions & 0 deletions src/entities/models/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ class SblTypeMappingDao(Base):
details: Mapped[str] = mapped_column(nullable=True)
modified_by: Mapped[str] = mapped_column()

def __eq__(self, other: "SblTypeMappingDao") -> bool:
return self.lei == other.lei and self.type_id == other.type_id and self.details == other.details

def __hash__(self) -> int:
return hash((self.lei, self.type_id, self.details))

def as_db_dict(self):
data = {}
for attr, column in inspect(self.__class__).c.items():
Expand Down
14 changes: 13 additions & 1 deletion src/entities/models/dto.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
from typing import List, Set
from typing import Generic, List, Set, Sequence
from pydantic import BaseModel, model_validator
from typing import TypeVar

T = TypeVar("T")


class VersionedData(BaseModel, Generic[T]):
version: int
data: T


class FinancialInsitutionDomainBase(BaseModel):
Expand Down Expand Up @@ -45,6 +53,10 @@ class Config:
from_attributes = True


class SblTypeAssociationPatchDto(BaseModel):
sbl_institution_types: Sequence[SblTypeAssociationDto | str]


class FinancialInstitutionDto(FinancialInstitutionBase):
tax_id: str | None = None
rssd_id: int | None = None
Expand Down
51 changes: 34 additions & 17 deletions src/entities/repos/institutions_repo.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import List
from typing import List, Sequence, Set

from sqlalchemy import select, func
from sqlalchemy.orm import joinedload
from sqlalchemy.ext.asyncio import AsyncSession

from regtech_api_commons.models import AuthenticatedUser

from .repo_utils import query_type
from .repo_utils import get_associated_sbl_types, query_type

from entities.models import (
FinancialInstitutionDao,
Expand All @@ -18,17 +18,17 @@
DeniedDomainDao,
AddressStateDao,
FederalRegulatorDao,
SblTypeMappingDao,
SblTypeAssociationDto,
)


async def get_institutions(
session: AsyncSession,
leis: List[str] = None,
leis: List[str] | None = None,
domain: str = "",
page: int = 0,
count: int = 100,
) -> List[FinancialInstitutionDao]:
) -> Sequence[FinancialInstitutionDao]:
async with session.begin():
stmt = (
select(FinancialInstitutionDao)
Expand All @@ -44,7 +44,7 @@ async def get_institutions(
return res.unique().all()


async def get_institution(session: AsyncSession, lei: str) -> FinancialInstitutionDao:
async def get_institution(session: AsyncSession, lei: str) -> FinancialInstitutionDao | None:
async with session.begin():
stmt = (
select(FinancialInstitutionDao)
Expand All @@ -54,19 +54,19 @@ async def get_institution(session: AsyncSession, lei: str) -> FinancialInstituti
return await session.scalar(stmt)


async def get_sbl_types(session: AsyncSession) -> List[SBLInstitutionTypeDao]:
async def get_sbl_types(session: AsyncSession) -> Sequence[SBLInstitutionTypeDao]:
return await query_type(session, SBLInstitutionTypeDao)


async def get_hmda_types(session: AsyncSession) -> List[HMDAInstitutionTypeDao]:
async def get_hmda_types(session: AsyncSession) -> Sequence[HMDAInstitutionTypeDao]:
return await query_type(session, HMDAInstitutionTypeDao)


async def get_address_states(session: AsyncSession) -> List[AddressStateDao]:
async def get_address_states(session: AsyncSession) -> Sequence[AddressStateDao]:
return await query_type(session, AddressStateDao)


async def get_federal_regulators(session: AsyncSession) -> List[FederalRegulatorDao]:
async def get_federal_regulators(session: AsyncSession) -> Sequence[FederalRegulatorDao]:
return await query_type(session, FederalRegulatorDao)


Expand All @@ -79,12 +79,7 @@ async def upsert_institution(
fi_data.pop("version", None)

if "sbl_institution_types" in fi_data:
types_association = [
SblTypeMappingDao(type_id=t, lei=fi.lei, modified_by=user.id)
if isinstance(t, str)
else SblTypeMappingDao(type_id=t.id, details=t.details, lei=fi.lei, modified_by=user.id)
for t in fi.sbl_institution_types
]
types_association = get_associated_sbl_types(fi.lei, user.id, fi.sbl_institution_types)
fi_data["sbl_institution_types"] = types_association

db_fi = await session.merge(FinancialInstitutionDao(**fi_data, modified_by=user.id))
Expand All @@ -93,9 +88,31 @@ async def upsert_institution(
return db_fi


async def update_sbl_types(
session: AsyncSession, user: AuthenticatedUser, lei: str, sbl_types: Sequence[SblTypeAssociationDto | str]
) -> FinancialInstitutionDao | None:
if fi := await get_institution(session, lei):
new_types = set(get_associated_sbl_types(lei, user.id, sbl_types))
old_types = set(fi.sbl_institution_types)
add_types = new_types.difference(old_types)
remove_types = old_types.difference(new_types)

fi.sbl_institution_types = [type for type in fi.sbl_institution_types if type not in remove_types]
fi.sbl_institution_types.extend(add_types)
for type in fi.sbl_institution_types:
type.version = fi.version
await session.commit()
"""
load the async relational attributes so dto can be properly serialized
"""
for type in fi.sbl_institution_types:
await type.awaitable_attrs.sbl_type
return fi


async def add_domains(
session: AsyncSession, lei: str, domains: List[FinancialInsitutionDomainCreate]
) -> List[FinancialInstitutionDomainDao]:
) -> Set[FinancialInstitutionDomainDao]:
async with session.begin():
daos = set(
map(
Expand Down
18 changes: 15 additions & 3 deletions src/entities/repos/repo_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,24 @@
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, TypeVar
from typing import Sequence, TypeVar, Type
from entities.models import Base, SblTypeMappingDao, SblTypeAssociationDto

T = TypeVar("T")
T = TypeVar("T", bound=Base)


async def query_type(session: AsyncSession, type: T) -> List[T]:
async def query_type(session: AsyncSession, type: Type[T]) -> Sequence[T]:
async with session.begin():
stmt = select(type)
res = await session.scalars(stmt)
return res.all()


def get_associated_sbl_types(
lei: str, user_id: str, types: Sequence[SblTypeAssociationDto | str]
) -> Sequence[SblTypeMappingDao]:
return [
SblTypeMappingDao(type_id=t, lei=lei, modified_by=user_id)
if isinstance(t, str)
else SblTypeMappingDao(type_id=t.id, details=t.details, lei=lei, modified_by=user_id)
for t in types
]
35 changes: 34 additions & 1 deletion src/routers/institutions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from fastapi import Depends, Request, HTTPException
from fastapi import Depends, Request, HTTPException, Response
from http import HTTPStatus
from regtech_api_commons.oauth2.oauth2_admin import OAuth2Admin
from config import kc_settings
Expand All @@ -16,6 +16,9 @@
InstitutionTypeDto,
AddressStateDto,
FederalRegulatorDto,
SblTypeAssociationDetailsDto,
SblTypeAssociationPatchDto,
VersionedData,
)
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.authentication import requires
Expand Down Expand Up @@ -105,6 +108,36 @@ async def get_institution(
return res


@router.get("/{lei}/types/{type}", response_model=VersionedData[List[SblTypeAssociationDetailsDto]] | None)
@requires("authenticated")
async def get_types(request: Request, response: Response, lei: str, type: InstitutionType):
match type:
case "sbl":
if fi := await repo.get_institution(request.state.db_session, lei):
return VersionedData(version=fi.version, data=fi.sbl_institution_types)
else:
response.status_code = HTTPStatus.NO_CONTENT
case "hmda":
raise HTTPException(status_code=HTTPStatus.NOT_IMPLEMENTED, detail="HMDA type not yet supported")


@router.put("/{lei}/types/{type}", response_model=VersionedData[List[SblTypeAssociationDetailsDto]] | None)
@requires("authenticated")
async def update_types(
request: Request, response: Response, lei: str, type: InstitutionType, types_patch: SblTypeAssociationPatchDto
):
match type:
case "sbl":
if fi := await repo.update_sbl_types(
request.state.db_session, request.user, lei, types_patch.sbl_institution_types
):
return VersionedData(version=fi.version, data=fi.sbl_institution_types) if fi else None
else:
response.status_code = HTTPStatus.NO_CONTENT
case "hmda":
raise HTTPException(status_code=HTTPStatus.NOT_IMPLEMENTED, detail="HMDA type not yet supported")


@router.post("/{lei}/domains/", response_model=List[FinancialInsitutionDomainDto], dependencies=[Depends(check_domain)])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just the question I put in Mattermost when you lost power:

Wondering if we should be returning 204 NO_CONTENT instead of None? I've done that with a few of the endpoints in filing, and added it to the wiki for others to bring up if we want to decide on a consistent approach

Thoughts?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeap, the sbl case I've updated to use no_content instead

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmda one is a question mark at the moment; if we do allow more open modifications later down the line, I'm guessing the hmda one will be more like the other normal fields, rather than the special case types dealing

@requires(["query-groups", "manage-users"])
async def add_domains(
Expand Down
104 changes: 104 additions & 0 deletions tests/api/routers/test_institutions_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from http import HTTPStatus
from unittest.mock import Mock, ANY

from fastapi import FastAPI
Expand All @@ -13,6 +14,7 @@
HMDAInstitutionTypeDao,
SBLInstitutionTypeDao,
SblTypeMappingDao,
SblTypeAssociationDto,
)


Expand Down Expand Up @@ -425,3 +427,105 @@ def test_get_federal_regulators(self, mocker: MockerFixture, app_fixture: FastAP
client = TestClient(app_fixture)
res = client.get("/v1/institutions/regulators")
assert res.status_code == 200

def test_get_sbl_types(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
inst_version = 2
get_institution_mock = mocker.patch("entities.repos.institutions_repo.get_institution")
get_institution_mock.return_value = FinancialInstitutionDao(
version=inst_version,
name="Test Bank 123",
lei="TESTBANK123",
is_active=True,
domains=[FinancialInstitutionDomainDao(domain="test.bank", lei="TESTBANK123")],
tax_id="123456789",
rssd_id=1234,
primary_federal_regulator_id="FRI1",
primary_federal_regulator=FederalRegulatorDao(id="FRI1", name="FRI1"),
hmda_institution_type_id="HIT1",
hmda_institution_type=HMDAInstitutionTypeDao(id="HIT1", name="HIT1"),
sbl_institution_types=[SblTypeMappingDao(sbl_type=SBLInstitutionTypeDao(id="SIT1", name="SIT1"))],
hq_address_street_1="Test Address Street 1",
hq_address_street_2="",
hq_address_city="Test City 1",
hq_address_state_code="GA",
hq_address_state=AddressStateDao(code="GA", name="Georgia"),
hq_address_zip="00000",
parent_lei="PARENTTESTBANK123",
parent_legal_name="PARENT TEST BANK 123",
parent_rssd_id=12345,
top_holder_lei="TOPHOLDERLEI123",
top_holder_legal_name="TOP HOLDER LEI 123",
top_holder_rssd_id=123456,
)
client = TestClient(app_fixture)
test_lei = "TESTBANK123"
res = client.get(f"/v1/institutions/{test_lei}/types/sbl")
assert res.status_code == HTTPStatus.OK
result = res.json()
assert len(result["data"]) == 1
assert result["version"] == inst_version
assert result["data"][0] == {"sbl_type": {"id": "SIT1", "name": "SIT1"}, "details": None}

def test_get_sbl_types_no_institution(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
get_institution_mock = mocker.patch("entities.repos.institutions_repo.get_institution")
get_institution_mock.return_value = None
client = TestClient(app_fixture)
test_lei = "TESTBANK123"
res = client.get(f"/v1/institutions/{test_lei}/types/sbl")
assert res.status_code == HTTPStatus.NO_CONTENT

def test_get_hmda_types(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
client = TestClient(app_fixture)
test_lei = "TESTBANK123"
res = client.get(f"/v1/institutions/{test_lei}/types/hmda")
assert res.status_code == HTTPStatus.NOT_IMPLEMENTED

def test_update_institution_types(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
mock = mocker.patch("entities.repos.institutions_repo.update_sbl_types")
client = TestClient(app_fixture)
test_lei = "TESTBANK123"
res = client.put(
f"/v1/institutions/{test_lei}/types/sbl",
json={"sbl_institution_types": ["1", {"id": "2"}, {"id": "13", "details": "test"}]},
)
assert res.status_code == HTTPStatus.OK
mock.assert_called_once_with(
ANY, ANY, test_lei, ["1", SblTypeAssociationDto(id="2"), SblTypeAssociationDto(id="13", details="test")]
)

def test_update_non_existing_institution_types(
self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock
):
get_institution_mock = mocker.patch("entities.repos.institutions_repo.get_institution")
get_institution_mock.return_value = None
client = TestClient(app_fixture)
test_lei = "TESTBANK123"
res = client.put(
f"/v1/institutions/{test_lei}/types/sbl",
json={"sbl_institution_types": ["1", {"id": "2"}, {"id": "13", "details": "test"}]},
)
assert res.status_code == HTTPStatus.NO_CONTENT

def test_update_unsupported_institution_types(
self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock
):
mock = mocker.patch("entities.repos.institutions_repo.update_sbl_types")
client = TestClient(app_fixture)
test_lei = "TESTBANK123"
res = client.put(
f"/v1/institutions/{test_lei}/types/hmda",
json={"sbl_institution_types": ["1", {"id": "2"}, {"id": "13", "details": "test"}]},
)
assert res.status_code == HTTPStatus.NOT_IMPLEMENTED
mock.assert_not_called()

def test_update_wrong_institution_types(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
mock = mocker.patch("entities.repos.institutions_repo.update_sbl_types")
client = TestClient(app_fixture)
test_lei = "TESTBANK123"
res = client.put(
f"/v1/institutions/{test_lei}/types/test",
json={"sbl_institution_types": ["1", {"id": "2"}, {"id": "13", "details": "test"}]},
)
assert res.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
mock.assert_not_called()
Loading
Loading