From 98635a02317c40674875a4ac83b5034990593cd2 Mon Sep 17 00:00:00 2001 From: lchen-2101 <73617864+lchen-2101@users.noreply.github.com> Date: Thu, 14 Sep 2023 09:10:14 -0700 Subject: [PATCH] feat: add denied domain check (#21) closes #12 --------- Co-authored-by: Hans Keeler --- README.md | 1 + src/.env.template | 16 +++ src/dependencies.py | 37 ++++++ src/entities/models/__init__.py | 10 +- src/entities/models/dao.py | 5 + src/entities/models/dto.py | 7 + src/entities/repos/institutions_repo.py | 19 ++- src/main.py | 9 +- src/routers/institutions.py | 30 +++-- tests/api/conftest.py | 10 +- tests/api/routers/test_institutions_api.py | 122 ++++++++++++------ .../entities/repos/test_institutions_repo.py | 8 ++ 12 files changed, 218 insertions(+), 56 deletions(-) create mode 100644 src/.env.template create mode 100644 src/dependencies.py diff --git a/README.md b/README.md index 8dcf1b7..2420ffd 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ poetry run uvicorn main:app --reload --port 8888 ``` ### Local development notes - [.env.template](.env.template) is added to allow VS Code to search the correct path for imports when writing tests, just copy the [.env.template](.env.template) file into `.env` file locally +- [src/.env.template](./src/.env.template) is added as the template for the app's environment variables, appropriate values are already provided in [.env.local](./src/.env.local) for local development. If `ENV` variable with default of `LOCAL` is changed, copy this template into `src/.env`, and provide appropriate values, and set all listed empty variables in the environment. --- ## Retrieve credentials diff --git a/src/.env.template b/src/.env.template new file mode 100644 index 0000000..7ccb51a --- /dev/null +++ b/src/.env.template @@ -0,0 +1,16 @@ +KC_URL= +KC_REALM= +KC_ADMIN_CLIENT_ID= +KC_ADMIN_CLIENT_SECRET= +KC_REALM_URL=${KC_URL}/realms/${KC_REALM} +KC_REALM_ADMIN_URL=${KC_URL}/admin/realms/${KC_REALM} +AUTH_URL=${KC_REALM_URL}/protocol/openid-connect/auth +TOKEN_URL=${KC_REALM_URL}/protocol/openid-connect/token +CERTS_URL=${KC_REALM_URL}/protocol/openid-connect/certs +AUTH_CLIENT= +INST_DB_NAME= +INST_DB_USER= +INST_DB_PWD= +INST_DB_HOST= +INST_DB_SCHEMA= +INST_CONN=postgresql+asyncpg://${INST_DB_USER}:${INST_DB_PWD}@${INST_DB_HOST}/${INST_DB_NAME} \ No newline at end of file diff --git a/src/dependencies.py b/src/dependencies.py new file mode 100644 index 0000000..384c7c9 --- /dev/null +++ b/src/dependencies.py @@ -0,0 +1,37 @@ +from http import HTTPStatus +from typing import Annotated +from fastapi import Depends, HTTPException, Request +from sqlalchemy.ext.asyncio import AsyncSession + +from entities.engine import get_session +from entities.repos import institutions_repo as repo + +OPEN_DOMAIN_REQUESTS = { + "/v1/admin/me": {"GET"}, + "/v1/institutions": {"GET"}, + "/v1/institutions/domains/allowed": {"GET"}, +} + + +async def check_domain( + request: Request, session: Annotated[AsyncSession, Depends(get_session)] +) -> None: + if request_needs_domain_check(request): + if not request.user.is_authenticated: + raise HTTPException(status_code=HTTPStatus.FORBIDDEN) + if await email_domain_denied(session, request.user.email): + raise HTTPException( + status_code=HTTPStatus.FORBIDDEN, detail="email domain denied" + ) + + +def request_needs_domain_check(request: Request) -> bool: + path = request.scope["path"].rstrip("/") + return not ( + path in OPEN_DOMAIN_REQUESTS + and request.scope["method"] in OPEN_DOMAIN_REQUESTS[path] + ) + + +async def email_domain_denied(session: AsyncSession, email: str) -> bool: + return not await repo.is_email_domain_allowed(session, email) diff --git a/src/entities/models/__init__.py b/src/entities/models/__init__.py index c8ac48a..6dfd58b 100644 --- a/src/entities/models/__init__.py +++ b/src/entities/models/__init__.py @@ -6,12 +6,20 @@ "FinancialInstitutionWithDomainsDto", "FinancialInsitutionDomainDto", "FinancialInsitutionDomainCreate", + "DeniedDomainDao", + "DeniedDomainDto", ] -from .dao import Base, FinancialInstitutionDao, FinancialInstitutionDomainDao +from .dao import ( + Base, + FinancialInstitutionDao, + FinancialInstitutionDomainDao, + DeniedDomainDao, +) from .dto import ( FinancialInstitutionDto, FinancialInstitutionWithDomainsDto, FinancialInsitutionDomainDto, FinancialInsitutionDomainCreate, + DeniedDomainDto, ) diff --git a/src/entities/models/dao.py b/src/entities/models/dao.py index 3581d0a..3059bd1 100644 --- a/src/entities/models/dao.py +++ b/src/entities/models/dao.py @@ -27,3 +27,8 @@ class FinancialInstitutionDomainDao(AuditMixin, Base): ForeignKey("financial_institutions.lei"), index=True, primary_key=True ) fi = relationship("FinancialInstitutionDao", back_populates="domains") + + +class DeniedDomainDao(AuditMixin, Base): + __tablename__ = "denied_domains" + domain: Mapped[str] = mapped_column(index=True, primary_key=True) diff --git a/src/entities/models/dto.py b/src/entities/models/dto.py index 5303ff5..bc84375 100644 --- a/src/entities/models/dto.py +++ b/src/entities/models/dto.py @@ -30,3 +30,10 @@ class Config: class FinancialInstitutionWithDomainsDto(FinancialInstitutionDto): domains: List[FinancialInsitutionDomainDto] = [] + + +class DeniedDomainDto(BaseModel): + domain: str + + class Config: + orm_mode = True diff --git a/src/entities/repos/institutions_repo.py b/src/entities/repos/institutions_repo.py index 960f9a1..887809b 100644 --- a/src/entities/repos/institutions_repo.py +++ b/src/entities/repos/institutions_repo.py @@ -1,6 +1,6 @@ from typing import List -from sqlalchemy import select +from sqlalchemy import select, func from sqlalchemy.orm import joinedload from sqlalchemy.ext.asyncio import AsyncSession @@ -9,6 +9,7 @@ FinancialInstitutionDomainDao, FinancialInstitutionDto, FinancialInsitutionDomainCreate, + DeniedDomainDao, ) @@ -73,3 +74,19 @@ async def add_domains( session.add_all(daos) await session.commit() return daos + + +async def is_email_domain_allowed(session: AsyncSession, email: str) -> bool: + domain = get_email_domain(email) + if domain: + async with session: + stmt = select(func.count()).filter(DeniedDomainDao.domain == domain) + res = await session.scalar(stmt) + return res == 0 + return False + + +def get_email_domain(email: str) -> str: + if email: + return email.split("@")[-1] + return None diff --git a/src/main.py b/src/main.py index b4e0a36..3fd265a 100644 --- a/src/main.py +++ b/src/main.py @@ -2,11 +2,12 @@ import logging import env # noqa: F401 from http import HTTPStatus -from fastapi import FastAPI, HTTPException, Request +from fastapi import FastAPI, HTTPException, Request, Depends from fastapi.responses import JSONResponse from fastapi.security import OAuth2AuthorizationCodeBearer from fastapi.middleware.cors import CORSMiddleware from starlette.middleware.authentication import AuthenticationMiddleware +from dependencies import check_domain from routers import admin_router, institutions_router @@ -14,7 +15,7 @@ log = logging.getLogger() -app = FastAPI() +app = FastAPI(dependencies=[Depends(check_domain)]) @app.exception_handler(HTTPException) @@ -23,7 +24,7 @@ async def http_exception_handler( ) -> JSONResponse: log.error(exception, exc_info=True, stack_info=True) return JSONResponse( - status_code=exception.status_code, content={"message": exception.detail} + status_code=exception.status_code, content={"detail": exception.detail} ) @@ -34,7 +35,7 @@ async def general_exception_handler( log.error(exception, exc_info=True, stack_info=True) return JSONResponse( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, - content={"message": "server error"}, + content={"detail": "server error"}, ) diff --git a/src/routers/institutions.py b/src/routers/institutions.py index 72a54f1..e54ed32 100644 --- a/src/routers/institutions.py +++ b/src/routers/institutions.py @@ -2,7 +2,7 @@ from http import HTTPStatus from oauth2 import oauth2_admin from util import Router -from typing import List, Tuple +from typing import Annotated, List, Tuple from entities.engine import get_session from entities.repos import institutions_repo as repo from entities.models import ( @@ -14,7 +14,14 @@ from sqlalchemy.ext.asyncio import AsyncSession from starlette.authentication import requires -router = Router() + +async def set_db( + request: Request, session: Annotated[AsyncSession, Depends(get_session)] +): + request.state.db_session = session + + +router = Router(dependencies=[Depends(set_db)]) @router.get("/", response_model=List[FinancialInstitutionWithDomainsDto]) @@ -24,9 +31,8 @@ async def get_institutions( domain: str = "", page: int = 0, count: int = 100, - session: AsyncSession = Depends(get_session), ): - return await repo.get_institutions(session, domain, page, count) + return await repo.get_institutions(request.state.db_session, domain, page, count) @router.post("/", response_model=Tuple[str, FinancialInstitutionDto]) @@ -34,9 +40,8 @@ async def get_institutions( async def create_institution( request: Request, fi: FinancialInstitutionDto, - session: AsyncSession = Depends(get_session), ): - db_fi = await repo.upsert_institution(session, fi) + db_fi = await repo.upsert_institution(request.state.db_session, fi) kc_id = oauth2_admin.upsert_group(fi.lei, fi.name) return kc_id, db_fi @@ -44,9 +49,10 @@ async def create_institution( @router.get("/{lei}", response_model=FinancialInstitutionWithDomainsDto) @requires("authenticated") async def get_institution( - request: Request, lei: str, session: AsyncSession = Depends(get_session) + request: Request, + lei: str, ): - res = await repo.get_institution(session, lei) + res = await repo.get_institution(request.state.db_session, lei) if not res: raise HTTPException(HTTPStatus.NOT_FOUND, f"{lei} not found.") return res @@ -58,6 +64,10 @@ async def add_domains( request: Request, lei: str, domains: List[FinancialInsitutionDomainCreate], - session: AsyncSession = Depends(get_session), ): - return await repo.add_domains(session, lei, domains) + return await repo.add_domains(request.state.db_session, lei, domains) + + +@router.get("/domains/allowed", response_model=bool) +async def is_domain_allowed(request: Request, domain: str): + return await repo.is_email_domain_allowed(request.state.db_session, domain) diff --git a/tests/api/conftest.py b/tests/api/conftest.py index 1174977..148a35c 100644 --- a/tests/api/conftest.py +++ b/tests/api/conftest.py @@ -2,18 +2,20 @@ import pytest from fastapi import FastAPI -from pytest_mock import MockerFixture, MockFixture +from pytest_mock import MockerFixture from starlette.authentication import AuthCredentials, UnauthenticatedUser from oauth2.oauth2_backend import AuthenticatedUser @pytest.fixture -def app_fixture(mocker: MockFixture) -> FastAPI: +def app_fixture(mocker: MockerFixture) -> FastAPI: mocked_engine = mocker.patch("sqlalchemy.ext.asyncio.create_async_engine") MockedEngine = mocker.patch("sqlalchemy.ext.asyncio.AsyncEngine") mocked_engine.return_value = MockedEngine.return_value mocker.patch("fastapi.security.OAuth2AuthorizationCodeBearer") + domain_denied_mock = mocker.patch("dependencies.email_domain_denied") + domain_denied_mock.return_value = False from main import app return app @@ -33,7 +35,9 @@ def authed_user_mock(auth_mock: Mock) -> Mock: "sub": "testuser123", } auth_mock.return_value = ( - AuthCredentials(["manage-account", "query-groups", "manage-users", "authenticated"]), + AuthCredentials( + ["manage-account", "query-groups", "manage-users", "authenticated"] + ), AuthenticatedUser.from_claim(claims), ) return auth_mock diff --git a/tests/api/routers/test_institutions_api.py b/tests/api/routers/test_institutions_api.py index 7278695..e4d0706 100644 --- a/tests/api/routers/test_institutions_api.py +++ b/tests/api/routers/test_institutions_api.py @@ -5,14 +5,13 @@ from pytest_mock import MockerFixture from starlette.authentication import AuthCredentials from oauth2.oauth2_backend import AuthenticatedUser -from entities.models import ( - FinancialInstitutionDao, - FinancialInstitutionDomainDao -) +from entities.models import FinancialInstitutionDao, FinancialInstitutionDomainDao class TestInstitutionsApi: - def test_get_institutions_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mock): + def test_get_institutions_unauthed( + self, app_fixture: FastAPI, unauthed_user_mock: Mock + ): client = TestClient(app_fixture) res = client.get("/v1/institutions/") assert res.status_code == 403 @@ -20,26 +19,38 @@ def test_get_institutions_unauthed(self, app_fixture: FastAPI, unauthed_user_moc def test_get_institutions_authed( self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock ): - get_institutions_mock = mocker.patch("entities.repos.institutions_repo.get_institutions") - get_institutions_mock.return_value = [FinancialInstitutionDao( - name="Test Bank 123", - lei="TESTBANK123", - domains=[ - FinancialInstitutionDomainDao(domain="test.bank", lei="TESTBANK123") - ], - )] + get_institutions_mock = mocker.patch( + "entities.repos.institutions_repo.get_institutions" + ) + get_institutions_mock.return_value = [ + FinancialInstitutionDao( + name="Test Bank 123", + lei="TESTBANK123", + domains=[ + FinancialInstitutionDomainDao(domain="test.bank", lei="TESTBANK123") + ], + ) + ] client = TestClient(app_fixture) res = client.get("/v1/institutions/") assert res.status_code == 200 assert res.json()[0].get("name") == "Test Bank 123" - def test_create_institution_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mock): + def test_create_institution_unauthed( + self, app_fixture: FastAPI, unauthed_user_mock: Mock + ): client = TestClient(app_fixture) - res = client.post("/v1/institutions/", json={"name": "testName", "lei": "testLei"}) + res = client.post( + "/v1/institutions/", json={"name": "testName", "lei": "testLei"} + ) assert res.status_code == 403 - def test_create_institution_authed(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock): - upsert_institution_mock = mocker.patch("entities.repos.institutions_repo.upsert_institution") + def test_create_institution_authed( + self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock + ): + upsert_institution_mock = mocker.patch( + "entities.repos.institutions_repo.upsert_institution" + ) upsert_institution_mock.return_value = FinancialInstitutionDao( name="testName", lei="testLei", @@ -50,11 +61,15 @@ def test_create_institution_authed(self, mocker: MockerFixture, app_fixture: Fas upsert_group_mock = mocker.patch("oauth2.oauth2_admin.OAuth2Admin.upsert_group") upsert_group_mock.return_value = "leiGroup" client = TestClient(app_fixture) - res = client.post("/v1/institutions/", json={"name": "testName", "lei": "testLei"}) + res = client.post( + "/v1/institutions/", json={"name": "testName", "lei": "testLei"} + ) assert res.status_code == 200 assert res.json()[1].get("name") == "testName" - def test_create_institution_authed_no_permission(self, app_fixture: FastAPI, auth_mock: Mock): + def test_create_institution_authed_no_permission( + self, app_fixture: FastAPI, auth_mock: Mock + ): claims = { "name": "test", "preferred_username": "test_user", @@ -66,17 +81,25 @@ def test_create_institution_authed_no_permission(self, app_fixture: FastAPI, aut AuthenticatedUser.from_claim(claims), ) client = TestClient(app_fixture) - res = client.post("/v1/institutions/", json={"name": "testName", "lei": "testLei"}) + res = client.post( + "/v1/institutions/", json={"name": "testName", "lei": "testLei"} + ) assert res.status_code == 403 - def test_get_institution_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mock): + def test_get_institution_unauthed( + self, app_fixture: FastAPI, unauthed_user_mock: Mock + ): client = TestClient(app_fixture) - leiPath = "testLeiPath" - res = client.get(f"/v1/institutions/{leiPath}") + lei_path = "testLeiPath" + res = client.get(f"/v1/institutions/{lei_path}") assert res.status_code == 403 - def test_get_institution_authed(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock): - get_institution_mock = mocker.patch("entities.repos.institutions_repo.get_institution") + def test_get_institution_authed( + self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock + ): + get_institution_mock = mocker.patch( + "entities.repos.institutions_repo.get_institution" + ) get_institution_mock.return_value = FinancialInstitutionDao( name="Test Bank 123", lei="TESTBANK123", @@ -85,27 +108,37 @@ def test_get_institution_authed(self, mocker: MockerFixture, app_fixture: FastAP ], ) client = TestClient(app_fixture) - leiPath = "testLeiPath" - res = client.get(f"/v1/institutions/{leiPath}") + lei_path = "testLeiPath" + res = client.get(f"/v1/institutions/{lei_path}") assert res.status_code == 200 assert res.json().get("name") == "Test Bank 123" def test_add_domains_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mock): client = TestClient(app_fixture) - leiPath = "testLeiPath" - res = client.post(f"/v1/institutions/{leiPath}/domains/", json=[{"domain": "testDomain"}]) + lei_path = "testLeiPath" + res = client.post( + f"/v1/institutions/{lei_path}/domains/", json=[{"domain": "testDomain"}] + ) assert res.status_code == 403 - def test_add_domains_authed(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock): + def test_add_domains_authed( + self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock + ): add_domains_mock = mocker.patch("entities.repos.institutions_repo.add_domains") - add_domains_mock.return_value = [FinancialInstitutionDomainDao(domain="test.bank", lei="TESTBANK123")] + add_domains_mock.return_value = [ + FinancialInstitutionDomainDao(domain="test.bank", lei="TESTBANK123") + ] client = TestClient(app_fixture) - leiPath = "testLeiPath" - res = client.post(f"/v1/institutions/{leiPath}/domains/", json=[{"domain": "testDomain"}]) + lei_path = "testLeiPath" + res = client.post( + f"/v1/institutions/{lei_path}/domains/", json=[{"domain": "testDomain"}] + ) assert res.status_code == 200 assert res.json()[0].get("domain") == "test.bank" - def test_add_domains_authed_no_permission(self, app_fixture: FastAPI, auth_mock: Mock): + def test_add_domains_authed_no_permission( + self, app_fixture: FastAPI, auth_mock: Mock + ): claims = { "name": "test", "preferred_username": "test_user", @@ -117,6 +150,21 @@ def test_add_domains_authed_no_permission(self, app_fixture: FastAPI, auth_mock: AuthenticatedUser.from_claim(claims), ) client = TestClient(app_fixture) - leiPath = "testLeiPath" - res = client.post(f"/v1/institutions/{leiPath}/domains/", json=[{"domain": "testDomain"}]) - assert res.status_code == 403 \ No newline at end of file + lei_path = "testLeiPath" + res = client.post( + f"/v1/institutions/{lei_path}/domains/", json=[{"domain": "testDomain"}] + ) + assert res.status_code == 403 + + def test_add_domains_authed_with_denied_email_domain( + self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock + ): + domain_denied_mock = mocker.patch("dependencies.email_domain_denied") + domain_denied_mock.return_value = True + client = TestClient(app_fixture) + lei_path = "testLeiPath" + res = client.post( + f"/v1/institutions/{lei_path}/domains/", json=[{"domain": "testDomain"}] + ) + assert res.status_code == 403 + assert "domain denied" in res.json()["detail"] diff --git a/tests/entities/repos/test_institutions_repo.py b/tests/entities/repos/test_institutions_repo.py index 3870836..1ac7f50 100644 --- a/tests/entities/repos/test_institutions_repo.py +++ b/tests/entities/repos/test_institutions_repo.py @@ -6,6 +6,7 @@ FinancialInstitutionDomainDao, FinancialInsitutionDomainCreate, ) +from entities.models import DeniedDomainDao from entities.repos import institutions_repo as repo @@ -60,3 +61,10 @@ async def test_add_domains(self, session: AsyncSession): ) fi = await repo.get_institution(session, "TESTBANK123") assert len(fi.domains) == 2 + + async def test_domain_allowed(self, session: AsyncSession): + denied_domain = DeniedDomainDao(domain="yahoo.com") + session.add(denied_domain) + await session.commit() + assert await repo.is_email_domain_allowed(session, "test@yahoo.com") is False + assert await repo.is_email_domain_allowed(session, "test@gmail.com") is True