From e19d50ef50bf9519a9485aca4937f44d81682b9e Mon Sep 17 00:00:00 2001 From: lchen-2101 <73617864+lchen-2101@users.noreply.github.com> Date: Tue, 3 Oct 2023 07:14:52 -0700 Subject: [PATCH] feat: add associated institutions endpoint, refactored domain parsing (#39) closes #34 --- src/dependencies.py | 10 +++- src/entities/models/__init__.py | 2 + src/entities/models/dao.py | 7 ++- src/entities/models/dto.py | 4 ++ src/entities/repos/institutions_repo.py | 9 +--- src/routers/institutions.py | 22 +++++++- tests/api/conftest.py | 15 +++++- tests/api/routers/test_institutions_api.py | 50 +++++++++++++++---- tests/app/conftest.py | 12 +++++ tests/app/test_dependencies.py | 31 ++++++++++++ .../entities/repos/test_institutions_repo.py | 4 +- 11 files changed, 139 insertions(+), 27 deletions(-) create mode 100644 tests/app/conftest.py create mode 100644 tests/app/test_dependencies.py diff --git a/src/dependencies.py b/src/dependencies.py index 0d36f97..90a3c94 100644 --- a/src/dependencies.py +++ b/src/dependencies.py @@ -14,12 +14,12 @@ async def check_domain(request: Request, session: Annotated[AsyncSession, Depends(get_session)]) -> None: if not request.user.is_authenticated: raise HTTPException(status_code=HTTPStatus.FORBIDDEN) - if await email_domain_denied(session, request.user.email): + if await email_domain_denied(session, get_email_domain(request.user.email)): raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail="email domain denied") async def email_domain_denied(session: AsyncSession, email: str) -> bool: - return not await repo.is_email_domain_allowed(session, email) + return not await repo.is_domain_allowed(session, email) def parse_leis(leis: List[str] = Query(None)) -> Optional[List]: @@ -35,3 +35,9 @@ def parse_leis(leis: List[str] = Query(None)) -> Optional[List]: return list(chain.from_iterable([x.split(",") for x in leis])) else: return None + + +def get_email_domain(email: str) -> str: + if email: + return email.split("@")[-1] + return None diff --git a/src/entities/models/__init__.py b/src/entities/models/__init__.py index c772155..3c1779c 100644 --- a/src/entities/models/__init__.py +++ b/src/entities/models/__init__.py @@ -6,6 +6,7 @@ "FinancialInstitutionWithDomainsDto", "FinancialInsitutionDomainDto", "FinancialInsitutionDomainCreate", + "FinanicialInstitutionAssociationDto", "DeniedDomainDao", "DeniedDomainDto", "AuthenticatedUser", @@ -22,6 +23,7 @@ FinancialInstitutionWithDomainsDto, FinancialInsitutionDomainDto, FinancialInsitutionDomainCreate, + FinanicialInstitutionAssociationDto, DeniedDomainDto, AuthenticatedUser, ) diff --git a/src/entities/models/dao.py b/src/entities/models/dao.py index 3213524..3599188 100644 --- a/src/entities/models/dao.py +++ b/src/entities/models/dao.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import List from sqlalchemy import ForeignKey, func from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import DeclarativeBase @@ -17,14 +18,16 @@ class FinancialInstitutionDao(AuditMixin, Base): __tablename__ = "financial_institutions" lei: Mapped[str] = mapped_column(unique=True, index=True, primary_key=True) name: Mapped[str] = mapped_column(index=True) - domains = relationship("FinancialInstitutionDomainDao", back_populates="fi") + domains: Mapped[List["FinancialInstitutionDomainDao"]] = relationship( + "FinancialInstitutionDomainDao", back_populates="fi" + ) class FinancialInstitutionDomainDao(AuditMixin, Base): __tablename__ = "financial_institution_domains" domain: Mapped[str] = mapped_column(index=True, primary_key=True) lei: Mapped[str] = mapped_column(ForeignKey("financial_institutions.lei"), index=True, primary_key=True) - fi = relationship("FinancialInstitutionDao", back_populates="domains") + fi: Mapped["FinancialInstitutionDao"] = relationship("FinancialInstitutionDao", back_populates="domains") class DeniedDomainDao(AuditMixin, Base): diff --git a/src/entities/models/dto.py b/src/entities/models/dto.py index ac389f8..5214f1c 100644 --- a/src/entities/models/dto.py +++ b/src/entities/models/dto.py @@ -41,6 +41,10 @@ class Config: orm_mode = True +class FinanicialInstitutionAssociationDto(FinancialInstitutionDto): + approved: bool + + class AuthenticatedUser(BaseUser, BaseModel): claims: Dict[str, Any] name: str diff --git a/src/entities/repos/institutions_repo.py b/src/entities/repos/institutions_repo.py index ba33cf1..b4820e5 100644 --- a/src/entities/repos/institutions_repo.py +++ b/src/entities/repos/institutions_repo.py @@ -78,17 +78,10 @@ async def add_domains( return daos -async def is_email_domain_allowed(session: AsyncSession, email: str) -> bool: - domain = get_email_domain(email) +async def is_domain_allowed(session: AsyncSession, domain: str) -> bool: if domain: async with session: stmt = select(func.count()).filter(DeniedDomainDao.domain == domain) res = await session.scalar(stmt) return res == 0 return False - - -def get_email_domain(email: str) -> str: - if email: - return email.split("@")[-1] - return None diff --git a/src/routers/institutions.py b/src/routers/institutions.py index 012179c..78b5d5b 100644 --- a/src/routers/institutions.py +++ b/src/routers/institutions.py @@ -2,7 +2,7 @@ from http import HTTPStatus from oauth2 import oauth2_admin from util import Router -from dependencies import check_domain, parse_leis +from dependencies import check_domain, parse_leis, get_email_domain from typing import Annotated, List, Tuple from entities.engine import get_session from entities.repos import institutions_repo as repo @@ -11,6 +11,8 @@ FinancialInstitutionWithDomainsDto, FinancialInsitutionDomainDto, FinancialInsitutionDomainCreate, + FinanicialInstitutionAssociationDto, + AuthenticatedUser, ) from sqlalchemy.ext.asyncio import AsyncSession from starlette.authentication import requires @@ -46,6 +48,22 @@ async def create_institution( return kc_id, db_fi +@router.get("/associated", response_model=List[FinanicialInstitutionAssociationDto]) +@requires("authenticated") +async def get_associated_institutions(request: Request): + user: AuthenticatedUser = request.user + email_domain = get_email_domain(user.email) + associated_institutions = await repo.get_institutions(request.state.db_session, user.institutions) + return [ + FinanicialInstitutionAssociationDto( + name=institution.name, + lei=institution.lei, + approved=email_domain in [inst_domain.domain for inst_domain in institution.domains], + ) + for institution in associated_institutions + ] + + @router.get("/{lei}", response_model=FinancialInstitutionWithDomainsDto) @requires("authenticated") async def get_institution( @@ -70,4 +88,4 @@ async def add_domains( @router.get("/domains/allowed", response_model=bool) async def is_domain_allowed(request: Request, domain: str): - return await repo.is_email_domain_allowed(request.state.db_session, domain) + return await repo.is_domain_allowed(request.state.db_session, domain) diff --git a/tests/api/conftest.py b/tests/api/conftest.py index 201c5fa..451ba6b 100644 --- a/tests/api/conftest.py +++ b/tests/api/conftest.py @@ -5,7 +5,7 @@ from pytest_mock import MockerFixture from starlette.authentication import AuthCredentials, UnauthenticatedUser -from entities.models import AuthenticatedUser +from entities.models import AuthenticatedUser, FinancialInstitutionDao, FinancialInstitutionDomainDao @pytest.fixture @@ -45,3 +45,16 @@ def authed_user_mock(auth_mock: Mock) -> Mock: def unauthed_user_mock(auth_mock: Mock) -> Mock: auth_mock.return_value = (AuthCredentials("unauthenticated"), UnauthenticatedUser()) return auth_mock + + +@pytest.fixture +def get_institutions_mock(mocker: MockerFixture) -> Mock: + mock = mocker.patch("entities.repos.institutions_repo.get_institutions") + mock.return_value = [ + FinancialInstitutionDao( + name="Test Bank 123", + lei="TESTBANK123", + domains=[FinancialInstitutionDomainDao(domain="test.bank", lei="TESTBANK123")], + ) + ] + return mock diff --git a/tests/api/routers/test_institutions_api.py b/tests/api/routers/test_institutions_api.py index 2465dec..3af621f 100644 --- a/tests/api/routers/test_institutions_api.py +++ b/tests/api/routers/test_institutions_api.py @@ -14,15 +14,9 @@ def test_get_institutions_unauthed(self, app_fixture: FastAPI, unauthed_user_moc res = client.get("/v1/institutions/") assert res.status_code == 403 - def test_get_institutions_authed(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock): - get_institutions_mock = mocker.patch("entities.repos.institutions_repo.get_institutions") - get_institutions_mock.return_value = [ - FinancialInstitutionDao( - name="Test Bank 123", - lei="TESTBANK123", - domains=[FinancialInstitutionDomainDao(domain="test.bank", lei="TESTBANK123")], - ) - ] + def test_get_institutions_authed( + self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock, get_institutions_mock: Mock + ): client = TestClient(app_fixture) res = client.get("/v1/institutions/") assert res.status_code == 200 @@ -135,10 +129,46 @@ def test_add_domains_authed_with_denied_email_domain( assert "domain denied" in res.json()["detail"] def test_check_domain_allowed(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock): - domain_allowed_mock = mocker.patch("entities.repos.institutions_repo.is_email_domain_allowed") + domain_allowed_mock = mocker.patch("entities.repos.institutions_repo.is_domain_allowed") domain_allowed_mock.return_value = True domain_to_check = "local.host" client = TestClient(app_fixture) res = client.get(f"/v1/institutions/domains/allowed?domain={domain_to_check}") domain_allowed_mock.assert_called_once_with(ANY, domain_to_check) assert res.json() is True + + def test_get_associated_institutions( + self, mocker: MockerFixture, app_fixture: FastAPI, auth_mock: Mock, get_institutions_mock: Mock + ): + get_institutions_mock.return_value = [ + FinancialInstitutionDao( + name="Test Bank 123", + lei="TESTBANK123", + domains=[FinancialInstitutionDomainDao(domain="test123.bank", lei="TESTBANK123")], + ), + FinancialInstitutionDao( + name="Test Bank 234", + lei="TESTBANK234", + domains=[FinancialInstitutionDomainDao(domain="test234.bank", lei="TESTBANK234")], + ), + ] + claims = { + "name": "test", + "preferred_username": "test_user", + "email": "test@test234.bank", + "sub": "testuser123", + "institutions": ["/TESTBANK123", "/TESTBANK234"], + } + auth_mock.return_value = ( + AuthCredentials(["authenticated"]), + AuthenticatedUser.from_claim(claims), + ) + client = TestClient(app_fixture) + res = client.get("/v1/institutions/associated") + assert res.status_code == 200 + get_institutions_mock.assert_called_once_with(ANY, ["TESTBANK123", "TESTBANK234"]) + data = res.json() + inst1 = next(filter(lambda inst: inst["lei"] == "TESTBANK123", data)) + inst2 = next(filter(lambda inst: inst["lei"] == "TESTBANK234", data)) + assert inst1["approved"] is False + assert inst2["approved"] is True diff --git a/tests/app/conftest.py b/tests/app/conftest.py new file mode 100644 index 0000000..18cc99d --- /dev/null +++ b/tests/app/conftest.py @@ -0,0 +1,12 @@ +import pytest + +from pytest_mock import MockerFixture + + +@pytest.fixture(autouse=True) +def setup(mocker: MockerFixture): + mocked_engine = mocker.patch("sqlalchemy.ext.asyncio.create_async_engine") + MockedEngine = mocker.patch("sqlalchemy.ext.asyncio.AsyncEngine") + mocked_engine.return_value = MockedEngine.return_value + mocker.patch("fastapi.security.OAuth2AuthorizationCodeBearer") + mocker.patch("entities.engine.get_session") diff --git a/tests/app/test_dependencies.py b/tests/app/test_dependencies.py new file mode 100644 index 0000000..f44692a --- /dev/null +++ b/tests/app/test_dependencies.py @@ -0,0 +1,31 @@ +from pytest_mock import MockerFixture +from sqlalchemy.ext.asyncio import AsyncSession + +import pytest + + +@pytest.fixture +def mock_session(mocker: MockerFixture) -> AsyncSession: + return mocker.patch("sqlalchemy.ext.asyncio.AsyncSession").return_value + + +async def test_domain_denied(mocker: MockerFixture, mock_session: AsyncSession): + domain_allowed_mock = mocker.patch("entities.repos.institutions_repo.is_domain_allowed") + domain_allowed_mock.return_value = False + from dependencies import email_domain_denied + + denied_domain = "denied.domain" + + assert await email_domain_denied(mock_session, denied_domain) is True + domain_allowed_mock.assert_called_once_with(mock_session, denied_domain) + + +async def test_domain_allowed(mocker: MockerFixture, mock_session: AsyncSession): + domain_allowed_mock = mocker.patch("entities.repos.institutions_repo.is_domain_allowed") + domain_allowed_mock.return_value = True + from dependencies import email_domain_denied + + allowed_domain = "allowed.domain" + + assert await email_domain_denied(mock_session, allowed_domain) is False + domain_allowed_mock.assert_called_once_with(mock_session, allowed_domain) diff --git a/tests/entities/repos/test_institutions_repo.py b/tests/entities/repos/test_institutions_repo.py index 5acdfd6..3b7279a 100644 --- a/tests/entities/repos/test_institutions_repo.py +++ b/tests/entities/repos/test_institutions_repo.py @@ -79,5 +79,5 @@ async def test_domain_allowed(self, transaction_session: AsyncSession): denied_domain = DeniedDomainDao(domain="yahoo.com") transaction_session.add(denied_domain) await transaction_session.commit() - assert await repo.is_email_domain_allowed(transaction_session, "test@yahoo.com") is False - assert await repo.is_email_domain_allowed(transaction_session, "test@gmail.com") is True + assert await repo.is_domain_allowed(transaction_session, "yahoo.com") is False + assert await repo.is_domain_allowed(transaction_session, "gmail.com") is True