Skip to content

Commit

Permalink
feat: add associated institutions endpoint, refactored domain parsing (
Browse files Browse the repository at this point in the history
…#39)

closes #34
  • Loading branch information
lchen-2101 authored Oct 3, 2023
1 parent 6891a4f commit e19d50e
Show file tree
Hide file tree
Showing 11 changed files with 139 additions and 27 deletions.
10 changes: 8 additions & 2 deletions src/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
async def check_domain(request: Request, session: Annotated[AsyncSession, Depends(get_session)]) -> None:
if not request.user.is_authenticated:
raise HTTPException(status_code=HTTPStatus.FORBIDDEN)
if await email_domain_denied(session, request.user.email):
if await email_domain_denied(session, get_email_domain(request.user.email)):
raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail="email domain denied")


async def email_domain_denied(session: AsyncSession, email: str) -> bool:
return not await repo.is_email_domain_allowed(session, email)
return not await repo.is_domain_allowed(session, email)


def parse_leis(leis: List[str] = Query(None)) -> Optional[List]:
Expand All @@ -35,3 +35,9 @@ def parse_leis(leis: List[str] = Query(None)) -> Optional[List]:
return list(chain.from_iterable([x.split(",") for x in leis]))
else:
return None


def get_email_domain(email: str) -> str:
if email:
return email.split("@")[-1]
return None
2 changes: 2 additions & 0 deletions src/entities/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"FinancialInstitutionWithDomainsDto",
"FinancialInsitutionDomainDto",
"FinancialInsitutionDomainCreate",
"FinanicialInstitutionAssociationDto",
"DeniedDomainDao",
"DeniedDomainDto",
"AuthenticatedUser",
Expand All @@ -22,6 +23,7 @@
FinancialInstitutionWithDomainsDto,
FinancialInsitutionDomainDto,
FinancialInsitutionDomainCreate,
FinanicialInstitutionAssociationDto,
DeniedDomainDto,
AuthenticatedUser,
)
7 changes: 5 additions & 2 deletions src/entities/models/dao.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime
from typing import List
from sqlalchemy import ForeignKey, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.orm import DeclarativeBase
Expand All @@ -17,14 +18,16 @@ class FinancialInstitutionDao(AuditMixin, Base):
__tablename__ = "financial_institutions"
lei: Mapped[str] = mapped_column(unique=True, index=True, primary_key=True)
name: Mapped[str] = mapped_column(index=True)
domains = relationship("FinancialInstitutionDomainDao", back_populates="fi")
domains: Mapped[List["FinancialInstitutionDomainDao"]] = relationship(
"FinancialInstitutionDomainDao", back_populates="fi"
)


class FinancialInstitutionDomainDao(AuditMixin, Base):
__tablename__ = "financial_institution_domains"
domain: Mapped[str] = mapped_column(index=True, primary_key=True)
lei: Mapped[str] = mapped_column(ForeignKey("financial_institutions.lei"), index=True, primary_key=True)
fi = relationship("FinancialInstitutionDao", back_populates="domains")
fi: Mapped["FinancialInstitutionDao"] = relationship("FinancialInstitutionDao", back_populates="domains")


class DeniedDomainDao(AuditMixin, Base):
Expand Down
4 changes: 4 additions & 0 deletions src/entities/models/dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ class Config:
orm_mode = True


class FinanicialInstitutionAssociationDto(FinancialInstitutionDto):
approved: bool


class AuthenticatedUser(BaseUser, BaseModel):
claims: Dict[str, Any]
name: str
Expand Down
9 changes: 1 addition & 8 deletions src/entities/repos/institutions_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,10 @@ async def add_domains(
return daos


async def is_email_domain_allowed(session: AsyncSession, email: str) -> bool:
domain = get_email_domain(email)
async def is_domain_allowed(session: AsyncSession, domain: str) -> bool:
if domain:
async with session:
stmt = select(func.count()).filter(DeniedDomainDao.domain == domain)
res = await session.scalar(stmt)
return res == 0
return False


def get_email_domain(email: str) -> str:
if email:
return email.split("@")[-1]
return None
22 changes: 20 additions & 2 deletions src/routers/institutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from http import HTTPStatus
from oauth2 import oauth2_admin
from util import Router
from dependencies import check_domain, parse_leis
from dependencies import check_domain, parse_leis, get_email_domain
from typing import Annotated, List, Tuple
from entities.engine import get_session
from entities.repos import institutions_repo as repo
Expand All @@ -11,6 +11,8 @@
FinancialInstitutionWithDomainsDto,
FinancialInsitutionDomainDto,
FinancialInsitutionDomainCreate,
FinanicialInstitutionAssociationDto,
AuthenticatedUser,
)
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.authentication import requires
Expand Down Expand Up @@ -46,6 +48,22 @@ async def create_institution(
return kc_id, db_fi


@router.get("/associated", response_model=List[FinanicialInstitutionAssociationDto])
@requires("authenticated")
async def get_associated_institutions(request: Request):
user: AuthenticatedUser = request.user
email_domain = get_email_domain(user.email)
associated_institutions = await repo.get_institutions(request.state.db_session, user.institutions)
return [
FinanicialInstitutionAssociationDto(
name=institution.name,
lei=institution.lei,
approved=email_domain in [inst_domain.domain for inst_domain in institution.domains],
)
for institution in associated_institutions
]


@router.get("/{lei}", response_model=FinancialInstitutionWithDomainsDto)
@requires("authenticated")
async def get_institution(
Expand All @@ -70,4 +88,4 @@ async def add_domains(

@router.get("/domains/allowed", response_model=bool)
async def is_domain_allowed(request: Request, domain: str):
return await repo.is_email_domain_allowed(request.state.db_session, domain)
return await repo.is_domain_allowed(request.state.db_session, domain)
15 changes: 14 additions & 1 deletion tests/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pytest_mock import MockerFixture
from starlette.authentication import AuthCredentials, UnauthenticatedUser

from entities.models import AuthenticatedUser
from entities.models import AuthenticatedUser, FinancialInstitutionDao, FinancialInstitutionDomainDao


@pytest.fixture
Expand Down Expand Up @@ -45,3 +45,16 @@ def authed_user_mock(auth_mock: Mock) -> Mock:
def unauthed_user_mock(auth_mock: Mock) -> Mock:
auth_mock.return_value = (AuthCredentials("unauthenticated"), UnauthenticatedUser())
return auth_mock


@pytest.fixture
def get_institutions_mock(mocker: MockerFixture) -> Mock:
mock = mocker.patch("entities.repos.institutions_repo.get_institutions")
mock.return_value = [
FinancialInstitutionDao(
name="Test Bank 123",
lei="TESTBANK123",
domains=[FinancialInstitutionDomainDao(domain="test.bank", lei="TESTBANK123")],
)
]
return mock
50 changes: 40 additions & 10 deletions tests/api/routers/test_institutions_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,9 @@ def test_get_institutions_unauthed(self, app_fixture: FastAPI, unauthed_user_moc
res = client.get("/v1/institutions/")
assert res.status_code == 403

def test_get_institutions_authed(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
get_institutions_mock = mocker.patch("entities.repos.institutions_repo.get_institutions")
get_institutions_mock.return_value = [
FinancialInstitutionDao(
name="Test Bank 123",
lei="TESTBANK123",
domains=[FinancialInstitutionDomainDao(domain="test.bank", lei="TESTBANK123")],
)
]
def test_get_institutions_authed(
self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock, get_institutions_mock: Mock
):
client = TestClient(app_fixture)
res = client.get("/v1/institutions/")
assert res.status_code == 200
Expand Down Expand Up @@ -135,10 +129,46 @@ def test_add_domains_authed_with_denied_email_domain(
assert "domain denied" in res.json()["detail"]

def test_check_domain_allowed(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
domain_allowed_mock = mocker.patch("entities.repos.institutions_repo.is_email_domain_allowed")
domain_allowed_mock = mocker.patch("entities.repos.institutions_repo.is_domain_allowed")
domain_allowed_mock.return_value = True
domain_to_check = "local.host"
client = TestClient(app_fixture)
res = client.get(f"/v1/institutions/domains/allowed?domain={domain_to_check}")
domain_allowed_mock.assert_called_once_with(ANY, domain_to_check)
assert res.json() is True

def test_get_associated_institutions(
self, mocker: MockerFixture, app_fixture: FastAPI, auth_mock: Mock, get_institutions_mock: Mock
):
get_institutions_mock.return_value = [
FinancialInstitutionDao(
name="Test Bank 123",
lei="TESTBANK123",
domains=[FinancialInstitutionDomainDao(domain="test123.bank", lei="TESTBANK123")],
),
FinancialInstitutionDao(
name="Test Bank 234",
lei="TESTBANK234",
domains=[FinancialInstitutionDomainDao(domain="test234.bank", lei="TESTBANK234")],
),
]
claims = {
"name": "test",
"preferred_username": "test_user",
"email": "[email protected]",
"sub": "testuser123",
"institutions": ["/TESTBANK123", "/TESTBANK234"],
}
auth_mock.return_value = (
AuthCredentials(["authenticated"]),
AuthenticatedUser.from_claim(claims),
)
client = TestClient(app_fixture)
res = client.get("/v1/institutions/associated")
assert res.status_code == 200
get_institutions_mock.assert_called_once_with(ANY, ["TESTBANK123", "TESTBANK234"])
data = res.json()
inst1 = next(filter(lambda inst: inst["lei"] == "TESTBANK123", data))
inst2 = next(filter(lambda inst: inst["lei"] == "TESTBANK234", data))
assert inst1["approved"] is False
assert inst2["approved"] is True
12 changes: 12 additions & 0 deletions tests/app/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import pytest

from pytest_mock import MockerFixture


@pytest.fixture(autouse=True)
def setup(mocker: MockerFixture):
mocked_engine = mocker.patch("sqlalchemy.ext.asyncio.create_async_engine")
MockedEngine = mocker.patch("sqlalchemy.ext.asyncio.AsyncEngine")
mocked_engine.return_value = MockedEngine.return_value
mocker.patch("fastapi.security.OAuth2AuthorizationCodeBearer")
mocker.patch("entities.engine.get_session")
31 changes: 31 additions & 0 deletions tests/app/test_dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from pytest_mock import MockerFixture
from sqlalchemy.ext.asyncio import AsyncSession

import pytest


@pytest.fixture
def mock_session(mocker: MockerFixture) -> AsyncSession:
return mocker.patch("sqlalchemy.ext.asyncio.AsyncSession").return_value


async def test_domain_denied(mocker: MockerFixture, mock_session: AsyncSession):
domain_allowed_mock = mocker.patch("entities.repos.institutions_repo.is_domain_allowed")
domain_allowed_mock.return_value = False
from dependencies import email_domain_denied

denied_domain = "denied.domain"

assert await email_domain_denied(mock_session, denied_domain) is True
domain_allowed_mock.assert_called_once_with(mock_session, denied_domain)


async def test_domain_allowed(mocker: MockerFixture, mock_session: AsyncSession):
domain_allowed_mock = mocker.patch("entities.repos.institutions_repo.is_domain_allowed")
domain_allowed_mock.return_value = True
from dependencies import email_domain_denied

allowed_domain = "allowed.domain"

assert await email_domain_denied(mock_session, allowed_domain) is False
domain_allowed_mock.assert_called_once_with(mock_session, allowed_domain)
4 changes: 2 additions & 2 deletions tests/entities/repos/test_institutions_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,5 @@ async def test_domain_allowed(self, transaction_session: AsyncSession):
denied_domain = DeniedDomainDao(domain="yahoo.com")
transaction_session.add(denied_domain)
await transaction_session.commit()
assert await repo.is_email_domain_allowed(transaction_session, "test@yahoo.com") is False
assert await repo.is_email_domain_allowed(transaction_session, "test@gmail.com") is True
assert await repo.is_domain_allowed(transaction_session, "yahoo.com") is False
assert await repo.is_domain_allowed(transaction_session, "gmail.com") is True

0 comments on commit e19d50e

Please sign in to comment.