From b0c184f65f27bd020432de159145dd37f16d2fc2 Mon Sep 17 00:00:00 2001 From: lchen-2101 <73617864+lchen-2101@users.noreply.github.com> Date: Mon, 18 Sep 2023 17:48:26 -0400 Subject: [PATCH] feat: set domain check to be endpoint specific, and remove unused endpoint --- src/dependencies.py | 20 ++++---------------- src/main.py | 2 +- src/routers/admin.py | 14 ++++---------- src/routers/institutions.py | 6 +++--- tests/api/routers/test_institutions_api.py | 20 +++++++++++++++++++- 5 files changed, 31 insertions(+), 31 deletions(-) diff --git a/src/dependencies.py b/src/dependencies.py index 02ddcb0..0d36f97 100644 --- a/src/dependencies.py +++ b/src/dependencies.py @@ -10,24 +10,12 @@ from entities.engine import get_session from entities.repos import institutions_repo as repo -OPEN_DOMAIN_REQUESTS = { - "/v1/admin/me": {"GET"}, - "/v1/institutions": {"GET"}, - "/v1/institutions/domains/allowed": {"GET"}, -} - async def check_domain(request: Request, session: Annotated[AsyncSession, Depends(get_session)]) -> None: - if request_needs_domain_check(request): - if not request.user.is_authenticated: - raise HTTPException(status_code=HTTPStatus.FORBIDDEN) - if await email_domain_denied(session, request.user.email): - raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail="email domain denied") - - -def request_needs_domain_check(request: Request) -> bool: - path = request.scope["path"].rstrip("/") - return not (path in OPEN_DOMAIN_REQUESTS and request.scope["method"] in OPEN_DOMAIN_REQUESTS[path]) + if not request.user.is_authenticated: + raise HTTPException(status_code=HTTPStatus.FORBIDDEN) + if await email_domain_denied(session, request.user.email): + raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail="email domain denied") async def email_domain_denied(session: AsyncSession, email: str) -> bool: diff --git a/src/main.py b/src/main.py index bc54b7a..3ef9690 100644 --- a/src/main.py +++ b/src/main.py @@ -15,7 +15,7 @@ log = logging.getLogger() -app = FastAPI(dependencies=[Depends(check_domain)]) +app = FastAPI() @app.exception_handler(HTTPException) diff --git a/src/routers/admin.py b/src/routers/admin.py index 676dc1e..06ec490 100644 --- a/src/routers/admin.py +++ b/src/routers/admin.py @@ -1,7 +1,8 @@ from http import HTTPStatus from typing import Dict, Any, Set -from fastapi import Request +from fastapi import Depends, Request from starlette.authentication import requires +from dependencies import check_domain from util import Router from oauth2 import AuthenticatedUser, oauth2_admin @@ -15,20 +16,13 @@ async def get_me(request: Request): return request.user -@router.put("/me/", status_code=HTTPStatus.ACCEPTED) +@router.put("/me/", status_code=HTTPStatus.ACCEPTED, dependencies=[Depends(check_domain)]) @requires("manage-account") async def update_me(request: Request, user: Dict[str, Any]): oauth2_admin.update_user(request.user.id, user) -@router.put("/me/groups/", status_code=HTTPStatus.ACCEPTED) -@requires("manage-account") -async def associate_group(request: Request, groups: Set[str]): - for group in groups: - oauth2_admin.associate_to_group(request.user.id, group) - - -@router.put("/me/institutions/", status_code=HTTPStatus.ACCEPTED) +@router.put("/me/institutions/", status_code=HTTPStatus.ACCEPTED, dependencies=[Depends(check_domain)]) @requires("manage-account") async def associate_lei(request: Request, leis: Set[str]): for lei in leis: diff --git a/src/routers/institutions.py b/src/routers/institutions.py index 1de7b94..012179c 100644 --- a/src/routers/institutions.py +++ b/src/routers/institutions.py @@ -2,7 +2,7 @@ from http import HTTPStatus from oauth2 import oauth2_admin from util import Router -from dependencies import parse_leis +from dependencies import check_domain, parse_leis from typing import Annotated, List, Tuple from entities.engine import get_session from entities.repos import institutions_repo as repo @@ -35,7 +35,7 @@ async def get_institutions( return await repo.get_institutions(request.state.db_session, leis, domain, page, count) -@router.post("/", response_model=Tuple[str, FinancialInstitutionDto]) +@router.post("/", response_model=Tuple[str, FinancialInstitutionDto], dependencies=[Depends(check_domain)]) @requires(["query-groups", "manage-users"]) async def create_institution( request: Request, @@ -58,7 +58,7 @@ async def get_institution( return res -@router.post("/{lei}/domains/", response_model=List[FinancialInsitutionDomainDto]) +@router.post("/{lei}/domains/", response_model=List[FinancialInsitutionDomainDto], dependencies=[Depends(check_domain)]) @requires(["query-groups", "manage-users"]) async def add_domains( request: Request, diff --git a/tests/api/routers/test_institutions_api.py b/tests/api/routers/test_institutions_api.py index 7130437..a66e44c 100644 --- a/tests/api/routers/test_institutions_api.py +++ b/tests/api/routers/test_institutions_api.py @@ -1,4 +1,4 @@ -from unittest.mock import Mock +from unittest.mock import Mock, ANY from fastapi import FastAPI from fastapi.testclient import TestClient @@ -81,6 +81,15 @@ def test_get_institution_authed(self, mocker: MockerFixture, app_fixture: FastAP assert res.status_code == 200 assert res.json().get("name") == "Test Bank 123" + def test_get_institution_not_exists(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock): + get_institution_mock = mocker.patch("entities.repos.institutions_repo.get_institution") + get_institution_mock.return_value = None + client = TestClient(app_fixture) + lei_path = "testLeiPath" + res = client.get(f"/v1/institutions/{lei_path}") + get_institution_mock.assert_called_once_with(ANY, lei_path) + assert res.status_code == 404 + def test_add_domains_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mock): client = TestClient(app_fixture) @@ -124,3 +133,12 @@ def test_add_domains_authed_with_denied_email_domain( res = client.post(f"/v1/institutions/{lei_path}/domains/", json=[{"domain": "testDomain"}]) assert res.status_code == 403 assert "domain denied" in res.json()["detail"] + + def test_check_domain_allowed(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock): + domain_allowed_mock = mocker.patch("entities.repos.institutions_repo.is_email_domain_allowed") + domain_allowed_mock.return_value = True + domain_to_check = "local.host" + client = TestClient(app_fixture) + res = client.get(f"/v1/institutions/domains/allowed?domain={domain_to_check}") + domain_allowed_mock.assert_called_once_with(ANY, domain_to_check) + assert res.json() == True