Skip to content

Commit

Permalink
feat: set domain check to be endpoint specific, and remove unused end…
Browse files Browse the repository at this point in the history
…point
  • Loading branch information
lchen-2101 committed Sep 18, 2023
1 parent 342244c commit b0c184f
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 31 deletions.
20 changes: 4 additions & 16 deletions src/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,12 @@
from entities.engine import get_session
from entities.repos import institutions_repo as repo

OPEN_DOMAIN_REQUESTS = {
"/v1/admin/me": {"GET"},
"/v1/institutions": {"GET"},
"/v1/institutions/domains/allowed": {"GET"},
}


async def check_domain(request: Request, session: Annotated[AsyncSession, Depends(get_session)]) -> None:
if request_needs_domain_check(request):
if not request.user.is_authenticated:
raise HTTPException(status_code=HTTPStatus.FORBIDDEN)
if await email_domain_denied(session, request.user.email):
raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail="email domain denied")


def request_needs_domain_check(request: Request) -> bool:
path = request.scope["path"].rstrip("/")
return not (path in OPEN_DOMAIN_REQUESTS and request.scope["method"] in OPEN_DOMAIN_REQUESTS[path])
if not request.user.is_authenticated:
raise HTTPException(status_code=HTTPStatus.FORBIDDEN)
if await email_domain_denied(session, request.user.email):
raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail="email domain denied")


async def email_domain_denied(session: AsyncSession, email: str) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

log = logging.getLogger()

app = FastAPI(dependencies=[Depends(check_domain)])
app = FastAPI()


@app.exception_handler(HTTPException)
Expand Down
14 changes: 4 additions & 10 deletions src/routers/admin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from http import HTTPStatus
from typing import Dict, Any, Set
from fastapi import Request
from fastapi import Depends, Request
from starlette.authentication import requires
from dependencies import check_domain
from util import Router

from oauth2 import AuthenticatedUser, oauth2_admin
Expand All @@ -15,20 +16,13 @@ async def get_me(request: Request):
return request.user


@router.put("/me/", status_code=HTTPStatus.ACCEPTED)
@router.put("/me/", status_code=HTTPStatus.ACCEPTED, dependencies=[Depends(check_domain)])
@requires("manage-account")
async def update_me(request: Request, user: Dict[str, Any]):
oauth2_admin.update_user(request.user.id, user)


@router.put("/me/groups/", status_code=HTTPStatus.ACCEPTED)
@requires("manage-account")
async def associate_group(request: Request, groups: Set[str]):
for group in groups:
oauth2_admin.associate_to_group(request.user.id, group)


@router.put("/me/institutions/", status_code=HTTPStatus.ACCEPTED)
@router.put("/me/institutions/", status_code=HTTPStatus.ACCEPTED, dependencies=[Depends(check_domain)])
@requires("manage-account")
async def associate_lei(request: Request, leis: Set[str]):
for lei in leis:
Expand Down
6 changes: 3 additions & 3 deletions src/routers/institutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from http import HTTPStatus
from oauth2 import oauth2_admin
from util import Router
from dependencies import parse_leis
from dependencies import check_domain, parse_leis
from typing import Annotated, List, Tuple
from entities.engine import get_session
from entities.repos import institutions_repo as repo
Expand Down Expand Up @@ -35,7 +35,7 @@ async def get_institutions(
return await repo.get_institutions(request.state.db_session, leis, domain, page, count)


@router.post("/", response_model=Tuple[str, FinancialInstitutionDto])
@router.post("/", response_model=Tuple[str, FinancialInstitutionDto], dependencies=[Depends(check_domain)])
@requires(["query-groups", "manage-users"])
async def create_institution(
request: Request,
Expand All @@ -58,7 +58,7 @@ async def get_institution(
return res


@router.post("/{lei}/domains/", response_model=List[FinancialInsitutionDomainDto])
@router.post("/{lei}/domains/", response_model=List[FinancialInsitutionDomainDto], dependencies=[Depends(check_domain)])
@requires(["query-groups", "manage-users"])
async def add_domains(
request: Request,
Expand Down
20 changes: 19 additions & 1 deletion tests/api/routers/test_institutions_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import Mock
from unittest.mock import Mock, ANY

from fastapi import FastAPI
from fastapi.testclient import TestClient
Expand Down Expand Up @@ -81,6 +81,15 @@ def test_get_institution_authed(self, mocker: MockerFixture, app_fixture: FastAP
assert res.status_code == 200
assert res.json().get("name") == "Test Bank 123"

def test_get_institution_not_exists(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
get_institution_mock = mocker.patch("entities.repos.institutions_repo.get_institution")
get_institution_mock.return_value = None
client = TestClient(app_fixture)
lei_path = "testLeiPath"
res = client.get(f"/v1/institutions/{lei_path}")
get_institution_mock.assert_called_once_with(ANY, lei_path)
assert res.status_code == 404

def test_add_domains_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mock):
client = TestClient(app_fixture)

Expand Down Expand Up @@ -124,3 +133,12 @@ def test_add_domains_authed_with_denied_email_domain(
res = client.post(f"/v1/institutions/{lei_path}/domains/", json=[{"domain": "testDomain"}])
assert res.status_code == 403
assert "domain denied" in res.json()["detail"]

def test_check_domain_allowed(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
domain_allowed_mock = mocker.patch("entities.repos.institutions_repo.is_email_domain_allowed")
domain_allowed_mock.return_value = True
domain_to_check = "local.host"
client = TestClient(app_fixture)
res = client.get(f"/v1/institutions/domains/allowed?domain={domain_to_check}")
domain_allowed_mock.assert_called_once_with(ANY, domain_to_check)
assert res.json() == True

Check failure on line 144 in tests/api/routers/test_institutions_api.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (E712)

tests/api/routers/test_institutions_api.py:144:30: E712 Comparison to `True` should be `cond is True` or `if cond:`

0 comments on commit b0c184f

Please sign in to comment.